├── Examples
│   ├── 00110.png
│   ├── 03deb7ad95.mp4
│   └── readme.txt
├── Output
│   └── readme.txt
├── README.md
├── caps_layers_cond.py
├── caps_main.py
├── caps_network_test.py
├── caps_network_train.py
├── config.py
├── inference.py
├── load_youtube_data_multi.py
├── load_youtubevalid_data.py
├── network_parts
│   ├── lstm_capsnet_cond2_test.py
│   └── lstm_capsnet_cond2_train.py
├── network_saves
│   └── readme.txt
└── network_saves_best
    └── readme.txt
/Examples/00110.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KevinDuarte/CapsuleVOS/0b390da90291b3f5527e444cb7fe943d319e6952/Examples/00110.png
--------------------------------------------------------------------------------
/Examples/03deb7ad95.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KevinDuarte/CapsuleVOS/0b390da90291b3f5527e444cb7fe943d319e6952/Examples/03deb7ad95.mp4
--------------------------------------------------------------------------------
/Examples/readme.txt:
--------------------------------------------------------------------------------
1 | ## Include the .mp4 files which will be used in inference here
2 |
--------------------------------------------------------------------------------
/Output/readme.txt:
--------------------------------------------------------------------------------
1 | This folder will hold the inference frames.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CapsuleVOS
2 |
3 | This is the code for the ICCV 2019 paper CapsuleVOS: Semi-Supervised Video Object Segmentation Using Capsule Routing.
4 |
5 | Arxiv Link: https://arxiv.org/abs/1910.00132
6 |
7 | The network is implemented using TensorFlow 1.4.1.
8 |
9 | Python packages used: numpy, scipy, scikit-video
10 |
11 | ## Files and their use
12 |
13 | 1. caps_layers_cond.py: Contains the functions required to construct capsule layers (primary, convolutional, and fully-connected) and to perform conditional capsule routing.
14 | 2. caps_network_train.py: Contains the CapsuleVOS model for training.
15 | 3. caps_network_test.py: Contains the CapsuleVOS model for testing.
16 | 4. caps_main.py: Contains the main function, which is called to train the network.
17 | 5. config.py: Contains the hyperparameters used for the network, training, and inference.
18 | 6. inference.py: Contains the inference code.
19 | 7. load_youtube_data_multi.py: Contains the training data-generator for the YoutubeVOS 2018 dataset.
20 | 8. load_youtubevalid_data.py: Contains the validation data-generator for the YoutubeVOS 2018 dataset.
21 |
22 | ## Data Used
23 |
24 | We have supplied the code for training and inference of the model on the YoutubeVOS-2018 dataset. The files load_youtube_data_multi.py and load_youtubevalid_data.py create two DataLoaders - one for training and one for validation. The data_loc variable at the top of each file should be set to the base directory which contains the frames and annotations.
25 |
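As a minimal sketch, the only edit needed in each loader is the path at the top of the file (the path below is a placeholder, not the actual dataset location):

```python
# top of load_youtube_data_multi.py / load_youtubevalid_data.py
data_loc = '/path/to/YoutubeVOS2018/'  # placeholder: base directory containing the frames and annotations
```
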
26 | To run this code, you need to do the following:
27 | 1. Download the YoutubeVOS dataset
28 | 2. Perform interpolation for the training frames following the paper's instructions
29 |
30 | ## Training the Model
31 |
32 | Once the data is set up, you can train (and test) the network by calling `python3 caps_main.py`.
33 |
34 | The `config.py` file contains several hyperparameters which are useful for training the network.
35 |
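For reference, a few of the training-related settings defined in config.py (excerpted from the file included below):

```python
# excerpt from config.py
batch_size = 4
n_epochs = 200
learning_rate, beta1, epsilon = 0.0001, 0.5, 1e-6
n_frames = 8
save_every_n_epochs = 50
start_at_epoch = 1  # set to a later epoch to resume from a saved checkpoint
```
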
36 | ## Output File
37 |
38 | During training and validation, the losses and accuracies are printed to stdout and also written to an output*.txt file.
39 |
40 | ## Saved Weights
41 |
42 | Pretrained weights for the network are available [here](https://drive.google.com/open?id=12zvvqd5i3EVNzPF-hEfq_hi2CEzRRSjS). To use them for inference, place them in the `network_saves_best` folder.
43 |
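Note that inference.py restores the checkpoint path constructed in config.py; with the defaults in this repository (model_num = 4, epoch_save = 394) that path resolves as in the sketch below, so the downloaded files should keep a matching prefix:

```python
import config

# Path prefix that inference.py passes to capsnet.load(); the downloaded
# checkpoint files (.index, .meta, .data-*) should share this prefix.
print(config.save_file_inference % config.epoch_save)
# -> ./network_saves_best/model4_394.ckpt
```
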
44 | ## Inference
45 |
46 | If you just want to test the trained model with the weights above, run the inference code by calling `python3 inference.py`. This code reads in an .mp4 file and a reference segmentation mask, and outputs the segmented frames of the video to the Output folder.
47 |
48 | An example video is available in the Examples folder.
49 |
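To run inference on a different clip, the hard-coded paths in inf() inside inference.py can be edited. A minimal sketch, assuming a new video and its first-frame mask (an indexed PNG like the provided 00110.png) have been placed in the Examples folder; my_clip.mp4 and my_clip_mask.png are placeholder names:

```python
# inside inf() in inference.py -- placeholder file names, not part of the repository
video_name = 'my_clip'
video, orig_dims = load_video('./Examples/' + video_name + '.mp4')
first_frame, img_palette = load_first_frame('./Examples/my_clip_mask.png')
```
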
50 |
--------------------------------------------------------------------------------
/caps_layers_cond.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from functools import reduce
4 | import config
5 |
6 | epsilon = 1e-7
7 |
8 |
9 | # Standard EM Routing
10 | def em_routing(v, a_i, beta_v, beta_a, n_iterations=3):
11 | batch_size = tf.shape(v)[0]
12 | _, _, n_caps_j, mat_len = v.get_shape().as_list()
13 | n_caps_j, mat_len = map(int, [n_caps_j, mat_len])
14 | n_caps_i = tf.shape(v)[1]
15 |
16 | a_i = tf.expand_dims(a_i, axis=-1)
17 |
18 | # Prior probabilities for routing
19 | r = tf.ones(shape=(batch_size, n_caps_i, n_caps_j, 1), dtype=tf.float32)/float(n_caps_j)
20 | r = tf.multiply(r, a_i)
21 |
22 | den = tf.reduce_sum(r, axis=1, keep_dims=True) + epsilon
23 |
24 | # Mean: shape=(N, 1, Ch_j, mat_len)
25 | m_num = tf.reduce_sum(v*r, axis=1, keep_dims=True)
26 | m = m_num/(den + epsilon)
27 |
28 | # Stddev: shape=(N, 1, Ch_j, mat_len)
29 | s_num = tf.reduce_sum(r * tf.square(v - m), axis=1, keep_dims=True)
30 | s = s_num/(den + epsilon)
31 |
32 | # cost_h: shape=(N, 1, Ch_j, mat_len)
33 | cost = (beta_v + tf.log(tf.sqrt(s + epsilon) + epsilon)) * den
34 | # cost_h: shape=(N, 1, Ch_j, 1)
35 | cost = tf.reduce_sum(cost, axis=-1, keep_dims=True)
36 |
37 | # calculates the mean and std_deviation of the cost, used for numerical stability
38 | cost_mean = tf.reduce_mean(cost, axis=-2, keep_dims=True)
39 | cost_stdv = tf.sqrt(
40 | tf.reduce_sum(
41 | tf.square(cost - cost_mean), axis=-2, keep_dims=True
42 | ) / n_caps_j + epsilon
43 | )
44 |
45 | # calculates the activations for the capsules in layer j
46 | a_j = tf.sigmoid(float(config.inv_temp) * (beta_a + (cost_mean - cost) / (cost_stdv + epsilon)))
47 |
48 | def condition(mean, stdsqr, act_j, r_temp, counter):
49 | return tf.less(counter, n_iterations)
50 |
51 | def route(mean, stdsqr, act_j, r_temp, counter):
52 | exp = tf.reduce_sum(tf.square(v - mean) / (2 * stdsqr + epsilon), axis=-1)
53 | coef = 0 - .5 * tf.reduce_sum(tf.log(2 * np.pi * stdsqr + epsilon), axis=-1)
54 | log_p_j = coef - exp
55 |
56 | log_ap = tf.reshape(tf.log(act_j + epsilon), (batch_size, 1, n_caps_j)) + log_p_j
57 | r_ij = tf.nn.softmax(log_ap + epsilon) # ap / (tf.reduce_sum(ap, axis=-1, keep_dims=True) + epsilon)
58 |
59 | r_ij = tf.multiply(tf.expand_dims(r_ij, axis=-1), a_i)
60 |
61 | denom = tf.reduce_sum(r_ij, axis=1, keep_dims=True) + epsilon
62 | m_numer = tf.reduce_sum(v * r_ij, axis=1, keep_dims=True)
63 | mean = m_numer / (denom + epsilon)
64 |
65 | s_numer = tf.reduce_sum(r_ij * tf.square(v - mean), axis=1, keep_dims=True)
66 | stdsqr = s_numer / (denom + epsilon)
67 |
68 | cost_h = (beta_v + tf.log(tf.sqrt(stdsqr) + epsilon)) * denom
69 |
70 | cost_h = tf.reduce_sum(cost_h, axis=-1, keep_dims=True)
71 | cost_h_mean = tf.reduce_mean(cost_h, axis=-2, keep_dims=True)
72 | cost_h_stdv = tf.sqrt(
73 | tf.reduce_sum(
74 | tf.square(cost_h - cost_h_mean), axis=-2, keep_dims=True
75 | ) / n_caps_j
76 | )
77 |
78 | inv_temp = config.inv_temp + counter * config.inv_temp_delta
79 | act_j = tf.sigmoid(inv_temp * (beta_a + (cost_h_mean - cost_h) / (cost_h_stdv + epsilon)))
80 |
81 | return mean, stdsqr, act_j, r_ij, tf.add(counter, 1)
82 |
83 | [mean, _, act_j, r_new, _] = tf.while_loop(condition, route, [m, s, a_j, r, 1.0])
84 |
85 | return tf.reshape(mean, (batch_size, n_caps_j, mat_len)), tf.reshape(act_j, (batch_size, n_caps_j, 1))
86 |
87 |
88 | # Attention Routing layer
89 | def em_routing_cond(v_v1, v_v2, a_i_v, v_f, a_i_f, beta_v_v, beta_a_v, beta_v_f, beta_a_f, n_iterations=3):
90 | batch_size = tf.shape(v_f)[0]
91 | _, _, n_caps_j, mat_len = v_f.get_shape().as_list()
92 | n_caps_j, mat_len = map(int, [n_caps_j, mat_len])
93 | n_caps_i_f = tf.shape(v_f)[1]
94 |
95 | a_i_f = tf.expand_dims(a_i_f, axis=-1)
96 |
97 | # Prior probabilities for routing
98 | r_f = tf.ones(shape=(batch_size, n_caps_i_f, n_caps_j, 1), dtype=tf.float32)/float(n_caps_j)
99 | r_f = tf.multiply(r_f, a_i_f)
100 |
101 | den_f = tf.reduce_sum(r_f, axis=1, keep_dims=True) + epsilon
102 |
103 | m_num_f = tf.reduce_sum(v_f*r_f, axis=1, keep_dims=True) # Mean: shape=(N, 1, Ch_j, mat_len)
104 | m_f = m_num_f/(den_f + epsilon)
105 |
106 | s_num_f = tf.reduce_sum(r_f * tf.square(v_f - m_f), axis=1, keep_dims=True) # Stddev: shape=(N, 1, Ch_j, mat_len)
107 | s_f = s_num_f/(den_f + epsilon)
108 |
109 | cost_f = (beta_v_f + tf.log(tf.sqrt(s_f + epsilon) + epsilon)) * den_f # cost_h: shape=(N, 1, Ch_j, mat_len)
110 | cost_f = tf.reduce_sum(cost_f, axis=-1, keep_dims=True) # cost_h: shape=(N, 1, Ch_j, 1)
111 |
112 | # calculates the mean and std_deviation of the cost
113 | cost_mean_f = tf.reduce_mean(cost_f, axis=-2, keep_dims=True)
114 | cost_stdv_f = tf.sqrt(
115 | tf.reduce_sum(
116 | tf.square(cost_f - cost_mean_f), axis=-2, keep_dims=True
117 | ) / n_caps_j + epsilon
118 | )
119 |
120 | # calculates the activations for the capsules in layer j for the frame capsules
121 | a_j_f = tf.sigmoid(float(config.inv_temp) * (beta_a_f + (cost_mean_f - cost_f) / (cost_stdv_f + epsilon)))
122 |
123 | def condition(mean_f, stdsqr_f, act_j_f, counter):
124 | return tf.less(counter, n_iterations)
125 |
126 | def route(mean_f, stdsqr_f, act_j_f, counter):
127 | # Performs E-step for frames
128 | exp_f = tf.reduce_sum(tf.square(v_f - mean_f) / (2 * stdsqr_f + epsilon), axis=-1)
129 | coef_f = 0 - .5 * tf.reduce_sum(tf.log(2 * np.pi * stdsqr_f + epsilon), axis=-1)
130 | log_p_j_f = coef_f - exp_f
131 |
132 | log_ap_f = tf.reshape(tf.log(act_j_f + epsilon), (batch_size, 1, n_caps_j)) + log_p_j_f
133 | r_ij_f = tf.nn.softmax(log_ap_f + epsilon)
134 |
135 | # Performs M-step for frames
136 | r_ij_f = tf.multiply(tf.expand_dims(r_ij_f, axis=-1), a_i_f)
137 |
138 | denom_f = tf.reduce_sum(r_ij_f, axis=1, keep_dims=True) + epsilon
139 | m_numer_f = tf.reduce_sum(v_f * r_ij_f, axis=1, keep_dims=True)
140 | mean_f = m_numer_f / (denom_f + epsilon)
141 |
142 | s_numer_f = tf.reduce_sum(r_ij_f * tf.square(v_f - mean_f), axis=1, keep_dims=True)
143 | stdsqr_f = s_numer_f / (denom_f + epsilon)
144 |
145 | cost_h_f = (beta_v_f + tf.log(tf.sqrt(stdsqr_f + epsilon) + epsilon)) * denom_f
146 |
147 | cost_h_f = tf.reduce_sum(cost_h_f, axis=-1, keep_dims=True)
148 | cost_h_mean_f = tf.reduce_mean(cost_h_f, axis=-2, keep_dims=True)
149 | cost_h_stdv_f = tf.sqrt(
150 | tf.reduce_sum(
151 | tf.square(cost_h_f - cost_h_mean_f), axis=-2, keep_dims=True
152 | ) / n_caps_j + epsilon
153 | )
154 |
155 | inv_temp = config.inv_temp + counter * config.inv_temp_delta
156 | act_j_f = tf.sigmoid(inv_temp * (beta_a_f + (cost_h_mean_f - cost_h_f) / (cost_h_stdv_f + epsilon)))
157 |
158 | return mean_f, stdsqr_f, act_j_f, tf.add(counter, 1)
159 |
160 | [mean_f_fin, _, act_j_f_fin, _] = tf.while_loop(condition, route, [m_f, s_f, a_j_f, 1.0])
161 |
162 | # performs m step for the video capsules
163 | a_i_v = tf.expand_dims(a_i_v, axis=-1)
164 |
165 | dist_v = tf.reduce_sum(tf.square(v_v1 - mean_f_fin), axis=-1)
166 | r_v = tf.expand_dims(tf.nn.softmax(0 - dist_v), axis=-1) * a_i_v
167 |
168 | den_v = tf.reduce_sum(r_v, axis=1, keep_dims=True) + epsilon
169 |
170 | m_num_v = tf.reduce_sum(v_v2 * r_v, axis=1, keep_dims=True) # Mean: shape=(N, 1, Ch_j, mat_len)
171 | m_v = m_num_v / (den_v + epsilon)
172 |
173 | s_num_v = tf.reduce_sum(r_v * tf.square(v_v2 - m_v), axis=1, keep_dims=True) # Stddev: shape=(N, 1, Ch_j, mat_len)
174 | s_v = s_num_v / (den_v + epsilon)
175 |
176 | cost_v = (beta_v_v + tf.log(tf.sqrt(s_v + epsilon) + epsilon)) * den_v # cost_h: shape=(N, 1, Ch_j, mat_len)
177 | cost_v = tf.reduce_sum(cost_v, axis=-1, keep_dims=True) # cost_h: shape=(N, 1, Ch_j, 1)
178 |
179 | # calculates the mean and std_deviation of the cost
180 | cost_mean_v = tf.reduce_mean(cost_v, axis=-2, keep_dims=True)
181 | cost_stdv_v = tf.sqrt(
182 | tf.reduce_sum(
183 | tf.square(cost_v - cost_mean_v), axis=-2, keep_dims=True
184 | ) / n_caps_j + epsilon
185 | )
186 |
187 | # calculates the activations for the capsules in layer j for the frame capsules
188 | a_j_v = tf.sigmoid(float(config.inv_temp) * (beta_a_v + (cost_mean_v - cost_v) / (cost_stdv_v + epsilon)))
189 |
190 | return (tf.reshape(m_v, (batch_size, n_caps_j, mat_len)), tf.reshape(a_j_v, (batch_size, n_caps_j, 1))), (tf.reshape(mean_f_fin, (batch_size, n_caps_j, mat_len)), tf.reshape(act_j_f_fin, (batch_size, n_caps_j, 1)))
191 |
192 |
193 | def create_prim_conv3d_caps(inputs, channels, kernel_size, strides, name, padding='VALID', activation=None, mdim=4):
194 | mdim2 = mdim*mdim
195 | batch_size = tf.shape(inputs)[0]
196 | poses = tf.layers.conv3d(inputs=inputs, filters=channels * mdim2, kernel_size=kernel_size,
197 | strides=strides, padding=padding, activation=activation, name=name+'_pose')
198 |
199 | _, d, h, w, _ = poses.get_shape().as_list()
200 | d, h, w = map(int, [d, h, w])
201 |
202 | pose = tf.reshape(poses, (batch_size, d, h, w, channels, mdim2), name=name+'_pose_res')
203 | #pose = tf.nn.l2_normalize(pose, dim=-1)
204 |
205 | acts = tf.layers.conv3d(inputs=inputs, filters=channels, kernel_size=kernel_size,
206 | strides=strides, padding=padding, activation=tf.nn.sigmoid, name=name+'_act')
207 | activation = tf.reshape(acts, (batch_size, d, h, w, channels, 1), name=name+'_act_res')
208 |
209 | return pose, activation
210 |
211 |
212 | def create_coords_mat(pose, rel_center, mdim=4):
213 | """
214 |
215 | :param pose: the incoming map of pose matrices, shape (N, ..., Ch_i, 16), where '...' is the dimensions of the map, which
216 | can be 1, 2 or 3 dimensional.
217 | :param rel_center: whether or not the coordinates are relative to the center of the map
218 | :return: returns the coordinates (padded to 16) for the incoming capsules
219 | """
220 | batch_size = tf.shape(pose)[0]
221 | shape_list = [int(x) for x in pose.get_shape().as_list()[1:-2]]
222 | ch = int(pose.get_shape().as_list()[-2])
223 | n_dims = len(shape_list)
224 |
225 | if n_dims == 3:
226 | d, h, w = shape_list
227 | elif n_dims == 2:
228 | d = 1
229 | h, w = shape_list
230 | else:
231 | d, h = 1, 1
232 | w = shape_list[0]
233 |
234 | subs = [0, 0, 0]
235 | if rel_center:
236 | subs = [int(d / 2), int(h / 2), int(w / 2)]
237 |
238 | c_mats = []
239 | if n_dims >= 3:
240 | c_mats.append(tf.tile(tf.reshape(tf.range(d, dtype=tf.float32), (1, d, 1, 1, 1, 1)), [batch_size, 1, h, w, ch, 1])-subs[0])
241 | if n_dims >= 2:
242 | c_mats.append(tf.tile(tf.reshape(tf.range(h, dtype=tf.float32), (1, 1, h, 1, 1, 1)), [batch_size, d, 1, w, ch, 1])-subs[1])
243 | if n_dims >= 1:
244 | c_mats.append(tf.tile(tf.reshape(tf.range(w, dtype=tf.float32), (1, 1, 1, w, 1, 1)), [batch_size, d, h, 1, ch, 1])-subs[2])
245 | add_coords = tf.concat(c_mats, axis=-1)
246 | add_coords = tf.cast(tf.reshape(add_coords, (batch_size, d*h*w, ch, n_dims)), dtype=tf.float32)
247 |
248 | mdim2 = mdim*mdim
249 | zeros = tf.zeros((batch_size, d*h*w, ch, mdim2-n_dims))
250 |
251 | return tf.concat([zeros, add_coords], axis=-1)
252 |
253 |
254 | def create_dense_caps(inputs, n_caps_j, name, route_min=0.0, coord_add=False, rel_center=False,
255 | ch_same_w=True, mdim=4):
256 | """
257 |
258 | :param inputs: The input capsule layer. Shape ((N, K, Ch_i, 16), (N, K, Ch_i, 1)) or
259 | ((N, ..., Ch_i, 16), (N, ..., Ch_i, 1)) where K is the number of capsules per channel and '...' is if you are
260 | inputting a map of capsules like W or HxW or DxHxW.
261 | :param n_caps_j: The number of capsules in the following layer
262 | :param name: name of the capsule layer
263 | :param route_min: A threshold activation to route
264 | :param coord_add: A boolean, whether or not to do coordinate addition
265 | :param rel_center: A boolean, whether or not the coordinate addition is relative to the center
266 | :param ch_same_w: A boolean, whether capsules in the same channel share transformation weights across spatial positions
267 | :return: Returns a layer of capsules. Shape ((N, n_caps_j, 16), (N, n_caps_j, 1))
268 | """
269 | mdim2 = mdim*mdim
270 | pose, activation = inputs
271 | batch_size = tf.shape(pose)[0]
272 | shape_list = [int(x) for x in pose.get_shape().as_list()[1:]]
273 | ch = int(shape_list[-2])
274 | n_capsch_i = 1 if len(shape_list) == 2 else reduce((lambda x, y: x * y), shape_list[:-2])
275 |
276 | u_i = tf.reshape(pose, (batch_size, n_capsch_i, ch, mdim2))
277 | activation = tf.reshape(activation, (batch_size, n_capsch_i, ch, 1))
278 | coords = create_coords_mat(pose, rel_center) if coord_add else tf.zeros_like(u_i)
279 |
280 | # reshapes the input capsules
281 | u_i = tf.reshape(u_i, (batch_size, n_capsch_i, ch, mdim, mdim))
282 | u_i = tf.expand_dims(u_i, axis=-3)
283 | u_i = tf.tile(u_i, [1, 1, 1, n_caps_j, 1, 1])
284 |
285 | if ch_same_w:
286 | weights = tf.get_variable(name=name + '_weights', shape=(ch, n_caps_j, mdim, mdim),
287 | initializer=tf.initializers.random_normal(stddev=0.1),
288 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
289 |
290 | votes = tf.einsum('ijab,ntijbc->ntijac', weights, u_i)
291 | votes = tf.reshape(votes, (batch_size, n_capsch_i * ch, n_caps_j, mdim2), name=name+'_votes')
292 | else:
293 | weights = tf.get_variable(name=name + '_weights', shape=(n_capsch_i, ch, n_caps_j, mdim, mdim),
294 | initializer=tf.initializers.random_normal(stddev=0.1),
295 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
296 | votes = tf.einsum('tijab,ntijbc->ntijac', weights, u_i)
297 | votes = tf.reshape(votes, (batch_size, n_capsch_i * ch, n_caps_j, mdim2), name=name+'_votes')
298 |
299 | if coord_add:
300 | coords = tf.reshape(coords, (batch_size, n_capsch_i * ch, 1, mdim2))
301 | votes = votes + tf.tile(coords, [1, 1, n_caps_j, 1])
302 |
303 | acts = tf.reshape(activation, (batch_size, n_capsch_i * ch, 1))
304 | activations = tf.where(tf.greater_equal(acts, tf.constant(route_min)), acts, tf.zeros_like(acts))
305 |
306 | beta_v = tf.get_variable(name=name + '_beta_v', shape=(n_caps_j, mdim2),
307 | initializer=tf.initializers.random_normal(stddev=0.1),
308 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
309 |
310 | beta_a = tf.get_variable(name=name + '_beta_a', shape=(n_caps_j, 1),
311 | initializer=tf.initializers.random_normal(stddev=0.1),
312 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
313 |
314 | capsules = em_routing(votes, activations, beta_v, beta_a)
315 |
316 | return capsules
317 |
318 |
319 | def create_conv3d_caps(inputs, channels, kernel_size, strides, name, padding='VALID',
320 | coord_add=False, rel_center=True, route_mean=True, ch_same_w=True, mdim=4):
321 | mdim2 = mdim*mdim
322 | inputs = tf.concat(inputs, axis=-1)
323 |
324 | if padding == 'SAME':
325 | d_padding, h_padding, w_padding = int(float(kernel_size[0]) / 2), int(float(kernel_size[1]) / 2), int(float(kernel_size[2]) / 2)
326 | u_padded = tf.pad(inputs, [[0, 0], [d_padding, d_padding], [h_padding, h_padding], [w_padding, w_padding], [0, 0], [0, 0]])
327 | else:
328 | u_padded = inputs
329 |
330 | batch_size = tf.shape(u_padded)[0]
331 | _, d, h, w, ch, _ = u_padded.get_shape().as_list()
332 | d, h, w, ch = map(int, [d, h, w, ch])
333 |
334 | # gets indices for kernels
335 | d_offsets = [[(d_ + k) for k in range(kernel_size[0])] for d_ in range(0, d + 1 - kernel_size[0], strides[0])]
336 | h_offsets = [[(h_ + k) for k in range(kernel_size[1])] for h_ in range(0, h + 1 - kernel_size[1], strides[1])]
337 | w_offsets = [[(w_ + k) for k in range(kernel_size[2])] for w_ in range(0, w + 1 - kernel_size[2], strides[2])]
338 |
339 | # output dimensions
340 | d_out, h_out, w_out = len(d_offsets), len(h_offsets), len(w_offsets)
341 |
342 | # gathers the capsules into shape (N, D2, H2, W2, KD, KH, KW, Ch_in, 17)
343 | d_gathered = tf.gather(u_padded, d_offsets, axis=1)
344 | h_gathered = tf.gather(d_gathered, h_offsets, axis=3)
345 | w_gathered = tf.gather(h_gathered, w_offsets, axis=5)
346 | w_gathered = tf.transpose(w_gathered, [0, 1, 3, 5, 2, 4, 6, 7, 8])
347 |
348 | if route_mean:
349 | kernels_reshaped = tf.reshape(w_gathered, [batch_size * d_out * h_out * w_out, kernel_size[0]* kernel_size[1]* kernel_size[2], ch, mdim2 + 1])
350 | kernels_reshaped = tf.reduce_mean(kernels_reshaped, axis=1)
351 | capsules = create_dense_caps((kernels_reshaped[:, :, :-1], kernels_reshaped[:, :, -1:]), channels, name,
352 | ch_same_w=ch_same_w, mdim=mdim)
353 | else:
354 | kernels_reshaped = tf.reshape(w_gathered, [batch_size * d_out * h_out * w_out, kernel_size[0], kernel_size[1], kernel_size[2], ch, mdim2 + 1])
355 | capsules = create_dense_caps((kernels_reshaped[:, :, :, :, :, :-1], kernels_reshaped[:, :, :, :, :, -1:]),
356 | channels, name, coord_add=coord_add, rel_center=rel_center, ch_same_w=ch_same_w, mdim=mdim)
357 |
358 | poses = tf.reshape(capsules[0][:, :, :mdim2], (batch_size, d_out, h_out, w_out, channels, mdim2), name=name+'_pose')
359 | activations = tf.reshape(capsules[1], (batch_size, d_out, h_out, w_out, channels, 1), name=name+'_act')
360 |
361 | return poses, activations
362 |
363 |
364 | def create_dense_caps_cond(inputs, n_caps_j, name, coord_add=False, rel_center=False,
365 | ch_same_w=True, mdim=4, n_cond_caps=0):
366 | """
367 |
368 | :param inputs: The input capsule layer. Shape ((N, K, Ch_i, 16), (N, K, Ch_i, 1)) or
369 | ((N, ..., Ch_i, 16), (N, ..., Ch_i, 1)) where K is the number of capsules per channel and '...' is if you are
370 | inputting a map of capsules like W or HxW or DxHxW.
371 | :param n_caps_j: The number of capsules in the following layer
372 | :param name: name of the capsule layer
373 | :param coord_add: A boolean, whether or not to do coordinate addition
374 | :param rel_center: A boolean, whether or not the coordinate addition is relative to the center
375 | :param n_cond_caps: The number of input capsule channels treated as conditioning capsules and routed with the conditional (attention) routing
376 | :return: Returns a layer of capsules. Shape ((N, n_caps_j, 16), (N, n_caps_j, 1))
377 | """
378 | mdim2 = mdim*mdim
379 | pose, activation = inputs
380 | batch_size = tf.shape(pose)[0]
381 | shape_list = [int(x) for x in pose.get_shape().as_list()[1:]]
382 | ch = int(shape_list[-2])
383 | n_capsch_i = 1 if len(shape_list) == 2 else reduce((lambda x, y: x * y), shape_list[:-2])
384 |
385 | u_i = tf.reshape(pose, (batch_size, n_capsch_i, ch, mdim2))
386 | activation = tf.reshape(activation, (batch_size, n_capsch_i, ch, 1))
387 | coords = create_coords_mat(pose, rel_center) if coord_add else tf.zeros_like(u_i)
388 |
389 |
390 | # reshapes the input capsules
391 | u_i = tf.reshape(u_i, (batch_size, n_capsch_i, ch, mdim, mdim))
392 | u_i = tf.expand_dims(u_i, axis=-3)
393 | u_i = tf.tile(u_i, [1, 1, 1, n_caps_j, 1, 1])
394 |
395 | if ch_same_w:
396 | weights = tf.get_variable(name=name + '_weights', shape=(ch, n_caps_j, mdim, mdim),
397 | initializer=tf.initializers.random_normal(stddev=0.1),
398 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
399 |
400 | votes = tf.einsum('ijab,ntijbc->ntijac', weights, u_i)
401 | votes = tf.reshape(votes, (batch_size, n_capsch_i * ch, n_caps_j, mdim2), name=name+'_votes')
402 | else:
403 | weights = tf.get_variable(name=name + '_weights', shape=(n_capsch_i, ch, n_caps_j, mdim, mdim),
404 | initializer=tf.initializers.random_normal(stddev=0.1),
405 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
406 | votes = tf.einsum('tijab,ntijbc->ntijac', weights, u_i)
407 | votes = tf.reshape(votes, (batch_size, n_capsch_i * ch, n_caps_j, mdim2), name=name+'_votes')
408 |
409 | if coord_add:
410 | coords = tf.reshape(coords, (batch_size, n_capsch_i * ch, 1, mdim2))
411 | votes = votes + tf.tile(coords, [1, 1, n_caps_j, 1])
412 |
413 |
414 | if n_cond_caps == 0:
415 | beta_v = tf.get_variable(name=name + '_beta_v', shape=(n_caps_j, mdim2),
416 | initializer=tf.initializers.random_normal(stddev=0.1),
417 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
418 | beta_a = tf.get_variable(name=name + '_beta_a', shape=(n_caps_j, 1),
419 | initializer=tf.initializers.random_normal(stddev=0.1),
420 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
421 |
422 | acts = tf.reshape(activation, (batch_size, n_capsch_i * ch, 1))
423 |
424 | capsules1 = em_routing(votes, acts, beta_v, beta_a)
425 | capsules2 = capsules1
426 | else:
427 | beta_v1 = tf.get_variable(name=name + '_beta_v1', shape=(n_caps_j, mdim2),
428 | initializer=tf.initializers.random_normal(stddev=0.1),
429 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
430 | beta_a1 = tf.get_variable(name=name + '_beta_a1', shape=(n_caps_j, 1),
431 | initializer=tf.initializers.random_normal(stddev=0.1),
432 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
433 |
434 | beta_v2 = tf.get_variable(name=name + '_beta_v2', shape=(n_caps_j, mdim2),
435 | initializer=tf.initializers.random_normal(stddev=0.1),
436 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
437 | beta_a2 = tf.get_variable(name=name + '_beta_a2', shape=(n_caps_j, 1),
438 | initializer=tf.initializers.random_normal(stddev=0.1),
439 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
440 |
441 | votes = tf.reshape(votes, (batch_size, n_capsch_i, ch, n_caps_j, mdim2))
442 |
443 | votes1 = tf.reshape(votes[:, :, :ch - n_cond_caps],
444 | (batch_size, n_capsch_i * (ch - n_cond_caps), n_caps_j, mdim2))
445 | votes2 = tf.reshape(votes[:, :, ch - n_cond_caps:], (batch_size, n_capsch_i * n_cond_caps, n_caps_j, mdim2))
446 |
447 | acts = tf.reshape(activation, (batch_size, n_capsch_i, ch, 1))
448 |
449 | acts1 = tf.reshape(acts[:, :, :ch - n_cond_caps], (batch_size, n_capsch_i * (ch - n_cond_caps), 1))
450 | acts2 = tf.reshape(acts[:, :, ch - n_cond_caps:], (batch_size, n_capsch_i * n_cond_caps, 1))
451 |
452 | weights_2 = tf.get_variable(name=name + '_weights_2', shape=(ch - n_cond_caps, n_caps_j, mdim, mdim),
453 | initializer=tf.initializers.random_normal(stddev=0.1),
454 | regularizer=tf.contrib.layers.l2_regularizer(0.1))
455 |
456 | votes_2 = tf.einsum('ijab,ntijbc->ntijac', weights_2, u_i[:, :, :ch - n_cond_caps])
457 | votes_2 = tf.reshape(votes_2, (batch_size, n_capsch_i * (ch - n_cond_caps), n_caps_j, mdim2), name=name + '_votes2')
458 |
459 | capsules1, capsules2 = em_routing_cond(votes1, votes_2, acts1, votes2, acts2, beta_v1, beta_a1, beta_v2,
460 | beta_a2)
461 |
462 | return capsules1, capsules2
463 |
464 |
465 | def create_conv3d_caps_cond(inputs, channels, kernel_size, strides, name, padding='VALID',
466 | coord_add=False, rel_center=True, route_mean=True, ch_same_w=True, mdim=4, n_cond_caps=0):
467 | mdim2 = mdim*mdim
468 | inputs = tf.concat(inputs, axis=-1)
469 |
470 | if padding == 'SAME':
471 | d_padding, h_padding, w_padding = int(float(kernel_size[0]) / 2), int(float(kernel_size[1]) / 2), int(float(kernel_size[2]) / 2)
472 | u_padded = tf.pad(inputs, [[0, 0], [d_padding, d_padding], [h_padding, h_padding], [w_padding, w_padding], [0, 0], [0, 0]])
473 | else:
474 | u_padded = inputs
475 |
476 | batch_size = tf.shape(u_padded)[0]
477 | _, d, h, w, ch, _ = u_padded.get_shape().as_list()
478 | d, h, w, ch = map(int, [d, h, w, ch])
479 |
480 | # gets indices for kernels
481 | d_offsets = [[(d_ + k) for k in range(kernel_size[0])] for d_ in range(0, d + 1 - kernel_size[0], strides[0])]
482 | h_offsets = [[(h_ + k) for k in range(kernel_size[1])] for h_ in range(0, h + 1 - kernel_size[1], strides[1])]
483 | w_offsets = [[(w_ + k) for k in range(kernel_size[2])] for w_ in range(0, w + 1 - kernel_size[2], strides[2])]
484 |
485 | # output dimensions
486 | d_out, h_out, w_out = len(d_offsets), len(h_offsets), len(w_offsets)
487 |
488 | # gathers the capsules into shape (N, D2, H2, W2, KD, KH, KW, Ch_in, 17)
489 | d_gathered = tf.gather(u_padded, d_offsets, axis=1)
490 | h_gathered = tf.gather(d_gathered, h_offsets, axis=3)
491 | w_gathered = tf.gather(h_gathered, w_offsets, axis=5)
492 | w_gathered = tf.transpose(w_gathered, [0, 1, 3, 5, 2, 4, 6, 7, 8])
493 |
494 | if route_mean:
495 | kernels_reshaped = tf.reshape(w_gathered, [batch_size * d_out * h_out * w_out, kernel_size[0]* kernel_size[1]* kernel_size[2], ch, mdim2 + 1])
496 | kernels_reshaped = tf.reduce_mean(kernels_reshaped, axis=1)
497 | capsules1, capsules2 = create_dense_caps_cond((kernels_reshaped[:, :, :-1], kernels_reshaped[:, :, -1:]), channels, name,
498 | ch_same_w=ch_same_w, mdim=mdim, n_cond_caps=n_cond_caps)
499 | else:
500 | kernels_reshaped = tf.reshape(w_gathered, [batch_size * d_out * h_out * w_out, kernel_size[0], kernel_size[1], kernel_size[2], ch, mdim2 + 1])
501 | capsules1, capsules2 = create_dense_caps_cond((kernels_reshaped[:, :, :, :, :, :-1], kernels_reshaped[:, :, :, :, :, -1:]),
502 | channels, name, coord_add=coord_add, rel_center=rel_center,
503 | ch_same_w=ch_same_w, mdim=mdim, n_cond_caps=n_cond_caps)
504 |
505 | poses1 = tf.reshape(capsules1[0][:, :, :mdim2], (batch_size, d_out, h_out, w_out, channels, mdim2), name=name+'_pose1')
506 | activations1 = tf.reshape(capsules1[1], (batch_size, d_out, h_out, w_out, channels, 1), name=name+'_act1')
507 |
508 | poses2 = tf.reshape(capsules2[0][:, :, :mdim2], (batch_size, d_out, h_out, w_out, channels, mdim2), name=name + '_pose2')
509 | activations2 = tf.reshape(capsules2[1], (batch_size, d_out, h_out, w_out, channels, 1), name=name + '_act2')
510 |
511 | return (poses1, activations1), (poses2, activations2)
512 |
513 |
514 | def layer_shape(layer):
515 | return str(layer[0].get_shape()) + ' ' + str(layer[1].get_shape())
516 |
517 |
--------------------------------------------------------------------------------
/caps_main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import config
3 | from caps_network_train import CapsNet
4 | import sys
5 | from load_youtube_data_multi import YoutubeTrainDataGen as TrainDataGen
6 | from load_youtubevalid_data import YoutubeValidDataGen as ValidDataGen
7 | import numpy as np
8 | import time
9 |
10 |
11 | def get_num_params():
12 | total_parameters = 0
13 | for variable in tf.trainable_variables():
14 | # shape is an array of tf.Dimension
15 | shape = variable.get_shape()
16 | variable_parameters = 1
17 | for dim in shape:
18 | variable_parameters *= dim.value
19 | total_parameters += variable_parameters
20 | print('Num of parameters:', total_parameters)
21 | sys.stdout.flush()
22 |
23 |
24 | def train_one_epoch(sess, capsnet, data_gen, epoch):
25 | start_time = time.time()
26 | # continues until no more training data is generated
27 | batch, s_losses, seg_acc, reg_losses = 0.0, 0, 0, 0
28 |
29 | while data_gen.has_data():
30 | x_batch, seg_batch, crop1_batch, crop2_batch = data_gen.get_batch(config.batch_size)
31 |
32 | if config.multi_gpu and len(x_batch) == 1:
33 | print('Batch size of one, not running')
34 | continue
35 |
36 | n_samples = len(x_batch)
37 |
38 | use_gt_seg = epoch <= config.n_epochs_for_gt_seg
39 | use_gt_crop = epoch <= config.n_epochs_for_gt_crop
40 |
41 | hr_lstm_input = np.zeros((n_samples, config.hr_lstm_size[0], config.hr_lstm_size[1], config.hr_lstm_feats))
42 | lr_lstm_input = np.zeros((n_samples, config.lr_lstm_size[0], config.lr_lstm_size[1], config.lr_lstm_feats))
43 |
44 | outputs = sess.run([capsnet.train_op, capsnet.segmentation_loss, capsnet.pred_caps, capsnet.seg_acc,
45 | capsnet.regression_loss],
46 | feed_dict={capsnet.x_input_video: x_batch, capsnet.y_segmentation: seg_batch,
47 | capsnet.hr_cond_input: hr_lstm_input, capsnet.lr_cond_input: lr_lstm_input,
48 | capsnet.use_gt_seg: use_gt_seg, capsnet.use_gt_crop: use_gt_crop,
49 | capsnet.gt_crops1: crop1_batch, capsnet.gt_crops2: crop2_batch})
50 |
51 | _, s_loss, cap_vals, s_acc, reg_loss = outputs
52 | s_losses += s_loss
53 | seg_acc += s_acc
54 | reg_losses += reg_loss
55 |
56 | batch += 1
57 |
58 | if np.isnan(cap_vals[0][0]):
59 | print(cap_vals[0][:10])
60 | print('NAN encountered.')
61 | config.write_output('NAN encountered.\n')
62 | return -1, -1, -1
63 |
64 | if batch % config.batches_until_print == 0:
65 | print('Finished %d batches. %d(s) since start. Avg Segmentation Loss is %.4f. Avg Regression Loss is %.4f. '
66 | 'Seg Acc is %.4f.'
67 | % (batch, time.time() - start_time, s_losses / batch, reg_losses / batch, seg_acc / batch))
68 | sys.stdout.flush()
69 |
70 | print('Finish Epoch in %d(s). Avg Segmentation Loss is %.4f. Avg Regression Loss is %.4f. Seg Acc is %.4f.' %
71 | (time.time() - start_time, s_losses / batch, reg_losses / batch, seg_acc / batch))
72 | sys.stdout.flush()
73 |
74 | return s_losses / batch, reg_losses / batch, seg_acc / batch
75 |
76 |
77 | def validate(sess, capsnet, data_gen):
78 | batch, s_losses, seg_acc = 0.0, 0, 0
79 | start_time = time.time()
80 |
81 | while data_gen.has_data():
82 | batch_data = data_gen.get_batch(config.batch_size)
83 | x_batch, seg_batch, crop1_batch = batch_data
84 |
85 | hr_lstm_input = np.zeros((len(x_batch), config.hr_lstm_size[0], config.hr_lstm_size[1], config.hr_lstm_feats))
86 | lr_lstm_input = np.zeros((len(x_batch), config.lr_lstm_size[0], config.lr_lstm_size[1], config.lr_lstm_feats))
87 |
88 | val_outputs = sess.run([capsnet.val_seg_loss, capsnet.val_seg_acc],
89 | feed_dict={capsnet.x_input_video: x_batch, capsnet.y_segmentation: seg_batch,
90 | capsnet.hr_cond_input: hr_lstm_input, capsnet.lr_cond_input: lr_lstm_input,
91 | capsnet.use_gt_seg: True, capsnet.use_gt_crop: True,
92 | capsnet.gt_crops1: crop1_batch, capsnet.gt_crops2: crop1_batch})
93 |
94 | s_loss, s_acc = val_outputs
95 |
96 | s_losses += s_loss
97 | seg_acc += s_acc
98 |
99 | batch += 1
100 |
101 | if batch % config.batches_until_print == 0:
102 | print('Tested %d batches. %d(s) since start.' % (batch, time.time() - start_time))
103 | sys.stdout.flush()
104 |
105 | print('Evaluation done in %d(s).' % (time.time() - start_time))
106 | print('Test Segmentation Loss: %.4f. Test Segmentation Acc: %.4f' % (s_losses / batch, seg_acc / batch))
107 | sys.stdout.flush()
108 |
109 | return s_losses / batch, seg_acc / batch
110 |
111 |
112 | def train_network(gpu_config):
113 | capsnet = CapsNet()
114 |
115 | with tf.Session(graph=capsnet.graph, config=gpu_config) as sess:
116 | tf.global_variables_initializer().run()
117 |
118 | get_num_params()
119 | if config.start_at_epoch <= 1:
120 | config.clear_output()
121 | else:
122 | capsnet.load(sess, config.save_file_best_name % (config.start_at_epoch - 1))
123 | print('Loading from epoch %d.' % (config.start_at_epoch-1))
124 |
125 | best_loss = 1000000
126 | best_epoch = 1
127 | print('Training on YoutubeVOS')
128 | for ep in range(config.start_at_epoch, config.n_epochs + 1):
129 | print(20 * '*', 'epoch', ep, 20 * '*')
130 | sys.stdout.flush()
131 |
132 | # Trains network for 1 epoch
133 | nan_tries = 0
134 | while nan_tries < 3:
135 | data_gen = TrainDataGen(config.wait_for_data, crop_size=config.hr_frame_size, n_frames=config.n_frames,
136 | rand_frame_skip=config.rand_frame_skip, multi_objects=config.multiple_objects)
137 | seg_loss, reg_loss, seg_acc = train_one_epoch(sess, capsnet, data_gen, ep)
138 |
139 | if seg_loss < 0 or seg_acc < 0:
140 | nan_tries += 1
141 | capsnet.load(sess, config.save_file_best_name % best_epoch) # loads in the previous epoch
142 | while data_gen.has_data():
143 | data_gen.get_batch(config.batch_size)
144 | else:
145 | config.write_output('Epoch %d: SL: %.4f. RL: %.4f. SegAcc: %.4f.\n' % (ep, seg_loss, reg_loss, seg_acc))
146 | break
147 |
148 | if nan_tries == 3:
149 | print('Network cannot be trained. Too many NaN issues.')
150 | exit()
151 |
152 | # Validates network
153 | data_gen = ValidDataGen(config.wait_for_data, crop_size=config.hr_frame_size, n_frames=config.n_frames)
154 | seg_loss, seg_acc = validate(sess, capsnet, data_gen)
155 |
156 | config.write_output('Validation\tSL: %.4f. SA: %.4f.\n' % (seg_loss, seg_acc))
157 |
158 | # saves every 10 epochs
159 | if ep % config.save_every_n_epochs == 0:
160 | try:
161 | capsnet.save(sess, config.save_file_name % ep)
162 | config.write_output('Saved Network\n')
163 | except:
164 | print('Failed to save network!!!')
165 | sys.stdout.flush()
166 |
167 | # saves when validation loss becomes smaller (after 50 epochs to save space)
168 | t_loss = seg_loss
169 |
170 | if t_loss < best_loss:
171 | best_loss = t_loss
172 | try:
173 | capsnet.save(sess, config.save_file_best_name % ep)
174 | best_epoch = ep
175 | config.write_output('Saved Network - Minimum val\n')
176 | except:
177 | print('Failed to save network!!!')
178 | sys.stdout.flush()
179 |
180 | tf.reset_default_graph()
181 |
182 |
183 | def main():
184 | gpu_config = tf.ConfigProto(allow_soft_placement=True)
185 | gpu_config.gpu_options.allow_growth = True
186 |
187 | train_network(gpu_config)
188 |
189 |
190 | if __name__ == '__main__':
191 | main()
192 |
193 |
--------------------------------------------------------------------------------
/caps_network_test.py:
--------------------------------------------------------------------------------
1 | import config
2 | import tensorflow as tf
3 | import sys
4 | from network_parts.lstm_capsnet_cond2_test import create_network
5 |
6 |
7 | class CapsNet(object):
8 | def __init__(self, graph=None):
9 | if graph is None:
10 | self.graph = tf.Graph()
11 | else:
12 | self.graph = graph
13 |
14 | with self.graph.as_default():
15 | hr_h, hr_w = config.hr_frame_size
16 |
17 | n_frames = config.n_frames
18 |
19 | self.x_input_video = tf.placeholder(dtype=tf.float32, shape=(None, n_frames, hr_h, hr_w, 3),
20 | name='x_input_video')
21 | self.y_segmentation = tf.placeholder(dtype=tf.float32, shape=(None, n_frames, hr_h, hr_w, 1),
22 | name='y_segmentation')
23 |
24 | self.x_first_seg = tf.placeholder(dtype=tf.float32, shape=(None, hr_h, hr_w, 1), name='x_first_seg')
25 |
26 | self.use_gt_crop = tf.placeholder(dtype=tf.bool)
27 |
28 | self.gt_crops1 = tf.placeholder(dtype=tf.float32, shape=(None, 4)) # [y1, x1, alpha, 0] between 0 and 1
29 |
30 | cond_h, cond_w = config.hr_lstm_size
31 | self.hr_cond_input = tf.placeholder(dtype=tf.float32, shape=(None, cond_h, cond_w, config.hr_lstm_feats),
32 | name='hr_lstm_input')
33 | cond_h, cond_w = config.lr_lstm_size
34 | self.lr_cond_input = tf.placeholder(dtype=tf.float32, shape=(None, cond_h, cond_w, config.lr_lstm_feats),
35 | name='lr_lstm_input')
36 |
37 | self.init_network()
38 |
39 | #self.init_seg_loss()
40 | #self.init_regression_loss()
41 | #self.total_loss = self.segmentation_loss + self.regression_loss
42 |
43 | #self.init_optimizer()
44 |
45 | self.saver = tf.train.Saver()
46 |
47 | def init_network(self):
48 | print('Building Caps3d Model')
49 |
50 | with tf.variable_scope('network') as scope:
51 | #scope.reuse_variables()
52 | network_outputs = create_network(self.x_input_video, self.x_first_seg, self.hr_cond_input, self.lr_cond_input, self.use_gt_crop, self.gt_crops1)
53 | self.segment_layer, self.segment_layer_sig, self.pred_caps, self.state_t, self.state_t_lr, self.pred_crops1 = network_outputs
54 |
55 | sys.stdout.flush()
56 |
57 | def init_seg_loss(self):
58 | # Segmentation loss
59 | segment = self.segment_layer
60 | y_seg = self.segment_layer#self.y_segmentation
61 |
62 | segmentation_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_seg, logits=segment)
63 | segmentation_loss = tf.reduce_mean(tf.reduce_sum(segmentation_loss, axis=[1, 2, 3, 4]))
64 |
65 | pred_seg = tf.cast(tf.greater(segment, 0.0), tf.float32)
66 | seg_acc = tf.reduce_mean(tf.cast(tf.equal(pred_seg, y_seg), tf.float32))
67 |
68 | frame_segment = self.segment_layer[:, 0, :, :, :]
69 | y_frame_segment = self.segment_layer[:, 0, :, :, :]#self.y_segmentation[:, 0, :, :, :]
70 |
71 | val_seg_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_frame_segment, logits=frame_segment)
72 | val_seg_loss = tf.reduce_mean(tf.reduce_sum(val_seg_loss, axis=[1, 2, 3]))
73 |
74 | val_pred_seg = tf.cast(tf.greater(frame_segment, 0.0), tf.float32)
75 | val_seg_acc = tf.reduce_mean(tf.cast(tf.equal(val_pred_seg, y_frame_segment), tf.float32))
76 |
77 | self.segmentation_loss = segmentation_loss
78 | self.val_seg_loss = val_seg_loss
79 |
80 | self.seg_acc = seg_acc
81 | self.val_seg_acc = val_seg_acc
82 |
83 | print('Segmentation Loss Initialized')
84 |
85 | def init_regression_loss(self):
86 | regression_loss = tf.square(self.gt_crops1 - self.pred_crops1)
87 | self.regression_loss = tf.reduce_mean(tf.reduce_sum(regression_loss, axis=1))
88 |
89 | print('Regression Loss Initialized')
90 |
91 | def init_optimizer(self):
92 | optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate, beta1=config.beta1, name='Adam',
93 | epsilon=config.epsilon)
94 |
95 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
96 | with tf.control_dependencies(update_ops):
97 | self.train_op = optimizer.minimize(loss=self.total_loss, colocate_gradients_with_ops=True)
98 |
99 | def save(self, sess, file_name):
100 | save_path = self.saver.save(sess, file_name)
101 | print("Model saved in file: %s" % save_path)
102 | sys.stdout.flush()
103 |
104 | def load(self, sess, file_name):
105 | self.saver.restore(sess, file_name)
106 | print('Model restored.')
107 | sys.stdout.flush()
108 |
109 |
--------------------------------------------------------------------------------
/caps_network_train.py:
--------------------------------------------------------------------------------
1 | import config
2 | import tensorflow as tf
3 | import sys
4 | from network_parts.lstm_capsnet_cond2_train import create_network
5 |
6 |
7 | class CapsNet(object):
8 | def __init__(self, reg_mult=1.0):
9 | self.graph = tf.Graph()
10 |
11 | with self.graph.as_default():
12 | hr_h, hr_w = config.hr_frame_size
13 |
14 | n_frames = config.n_frames*2-1
15 |
16 | self.x_input_video = tf.placeholder(dtype=tf.float32, shape=(None, n_frames, hr_h, hr_w, 3),
17 | name='x_input_video')
18 | self.y_segmentation = tf.placeholder(dtype=tf.float32, shape=(None, n_frames, hr_h, hr_w, 1),
19 | name='y_segmentation')
20 |
21 | self.use_gt_seg = tf.placeholder(dtype=tf.bool)
22 | self.use_gt_crop = tf.placeholder(dtype=tf.bool)
23 |
24 | self.gt_crops1 = tf.placeholder(dtype=tf.float32, shape=(None, 4)) # [y1, x1, y2, x2] between 0 and 1
25 | self.gt_crops2 = tf.placeholder(dtype=tf.float32, shape=(None, 4)) # [y1, x1, y2, x2] between 0 and 1
26 |
27 | cond_h, cond_w = config.hr_lstm_size
28 | self.hr_cond_input = tf.placeholder(dtype=tf.float32, shape=(None, cond_h, cond_w, config.hr_lstm_feats),
29 | name='hr_lstm_input')
30 | cond_h, cond_w = config.lr_lstm_size
31 | self.lr_cond_input = tf.placeholder(dtype=tf.float32, shape=(None, cond_h, cond_w, config.lr_lstm_feats),
32 | name='lr_lstm_input')
33 | self.init_network()
34 |
35 | self.init_seg_loss()
36 | self.init_regression_loss()
37 | self.total_loss = self.segmentation_loss + self.regression_loss*reg_mult
38 |
39 | self.init_optimizer()
40 |
41 | self.saver = tf.train.Saver()
42 |
43 | def init_network(self):
44 | print('Building Caps3d Model')
45 |
46 | with tf.variable_scope('network') as scope:
47 | if config.multi_gpu:
48 | b = tf.cast(tf.shape(self.x_input_video)[0] / 2, tf.int32)
49 | with tf.device(config.devices[0]):
50 | segment_layer1, segment_layer_sig1, prim_caps1, state_t1, pred_crops11, pred_crops21 = create_network(self.x_input_video[:b],
51 | self.y_segmentation[:b],
52 | self.hr_cond_input[:b],
53 | self.lr_cond_input[:b],
54 | self.use_gt_seg,
55 | self.use_gt_crop,
56 | self.gt_crops1[:b],
57 | self.gt_crops2[:b])
58 |
59 | scope.reuse_variables()
60 | with tf.device(config.devices[1]):
61 | segment_layer2, segment_layer_sig2, prim_caps2, state_t2, pred_crops12, pred_crops22 = create_network(self.x_input_video[b:],
62 | self.y_segmentation[b:],
63 | self.hr_cond_input[b:],
64 | self.lr_cond_input[b:],
65 | self.use_gt_seg,
66 | self.use_gt_crop,
67 | self.gt_crops1[b:],
68 | self.gt_crops2[b:])
69 |
70 | self.segment_layer = tf.concat([segment_layer1, segment_layer2], axis=0)
71 | self.segment_layer_sig = tf.concat([segment_layer_sig1, segment_layer_sig2], axis=0)
72 | self.pred_caps = tf.concat([prim_caps1, prim_caps2], axis=0)
73 | self.state_t = tf.concat([state_t1, state_t2], axis=0)
74 | self.pred_crops1 = tf.concat([pred_crops11, pred_crops12], axis=0)
75 | self.pred_crops2 = tf.concat([pred_crops21, pred_crops22], axis=0)
76 | else:
77 | network_outputs = create_network(self.x_input_video, self.y_segmentation, self.hr_cond_input, self.lr_cond_input, self.use_gt_seg, self.use_gt_crop, self.gt_crops1, self.gt_crops2)
78 | self.segment_layer, self.segment_layer_sig, self.pred_caps, self.state_t, self.pred_crops1, self.pred_crops2 = network_outputs
79 |
80 | sys.stdout.flush()
81 |
82 | def init_seg_loss(self):
83 | # Segmentation loss
84 | segment = self.segment_layer
85 | y_seg = self.y_segmentation
86 |
87 | segmentation_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_seg, logits=segment)
88 | #segmentation_loss = tf.reduce_mean(tf.reduce_sum(segmentation_loss, axis=[1, 2, 3, 4]))
89 | segmentation_loss = tf.reduce_mean(tf.reduce_mean(segmentation_loss, axis=[1, 2, 3, 4]))
90 |
91 | pred_seg = tf.cast(tf.greater(segment, 0.0), tf.float32)
92 | seg_acc = tf.reduce_mean(tf.cast(tf.equal(pred_seg, y_seg), tf.float32))
93 |
94 | frame_segment = self.segment_layer[:, 0, :, :, :]
95 | y_frame_segment = self.y_segmentation[:, 0, :, :, :]
96 |
97 | val_seg_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_frame_segment, logits=frame_segment)
98 | val_seg_loss = tf.reduce_mean(tf.reduce_sum(val_seg_loss, axis=[1, 2, 3]))
99 |
100 | val_pred_seg = tf.cast(tf.greater(frame_segment, 0.0), tf.float32)
101 | val_seg_acc = tf.reduce_mean(tf.cast(tf.equal(val_pred_seg, y_frame_segment), tf.float32))
102 |
103 | segment_sig = self.segment_layer_sig
104 |
105 | p_times_r = segment_sig*y_seg
106 | p_plus_r = segment_sig + y_seg
107 | inv_p_times_r = (1-segment_sig) * (1-y_seg)
108 | inv_p_plus_r = 2-segment_sig-y_seg
109 | eps = 1e-8
110 |
111 | term1 = (tf.reduce_sum(p_times_r, axis=[1, 2, 3])+eps)/(tf.reduce_sum(p_plus_r, axis=[1, 2, 3])+eps)
112 | term2 = (tf.reduce_sum(inv_p_times_r, axis=[1, 2, 3])+eps)/(tf.reduce_sum(inv_p_plus_r, axis=[1, 2, 3])+eps)
113 |
114 | dice_loss = tf.reduce_mean(1 - term1 - term2)
115 |
116 | self.segmentation_loss = segmentation_loss + dice_loss
117 | # self.segmentation_loss = segmentation_loss
118 | self.val_seg_loss = val_seg_loss
119 |
120 | self.seg_acc = seg_acc
121 | self.val_seg_acc = val_seg_acc
122 |
123 |
124 | print('Segmentation Loss Initialized')
125 |
126 | def init_regression_loss(self):
127 | regression_loss = tf.square(self.gt_crops1 - self.pred_crops1) + tf.square(self.gt_crops2 - self.pred_crops2)
128 | self.regression_loss = tf.reduce_mean(tf.reduce_sum(regression_loss, axis=1))
129 |
130 | print('Regression Loss Initialized')
131 |
132 | def init_optimizer(self):
133 | optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate, beta1=config.beta1, name='Adam',
134 | epsilon=config.epsilon)
135 |
136 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
137 | with tf.control_dependencies(update_ops):
138 | self.train_op = optimizer.minimize(loss=self.total_loss, colocate_gradients_with_ops=True)
139 |
140 | def save(self, sess, file_name):
141 | save_path = self.saver.save(sess, file_name)
142 | print("Model saved in file: %s" % save_path)
143 | sys.stdout.flush()
144 |
145 | def load(self, sess, file_name):
146 | self.saver.restore(sess, file_name)
147 | print('Model restored.')
148 | sys.stdout.flush()
149 |
150 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 | devices = ['/gpu:0', '/gpu:0']
4 | multi_gpu = len(set(devices)) == 2
5 |
6 | batch_size = 4
7 | n_epochs = 200
8 |
9 | learning_rate, beta1, epsilon = 0.0001, 0.5, 1e-6
10 |
11 | n_frames = 8
12 | n_epochs_for_gt_seg = 0
13 | n_epochs_for_gt_crop = 500
14 |
15 | hr_frame_size = (128*4, 224*4)
16 | hr_lstm_size = (1, 1)
17 | hr_lstm_feats = 256
18 |
19 | lr_frame_size = (128, 224)
20 | lr_lstm_size = (lr_frame_size[0]//4, lr_frame_size[1]//4)
21 | lr_lstm_feats = 256
22 |
23 | model_num = 4
24 |
25 | save_every_n_epochs = 50
26 | output_file_name = './output%d.txt' % model_num
27 | save_file_name = ('./network_saves/model%d' % model_num) + '_%d.ckpt'
28 | save_file_best_name = ('./network_saves_best/model%d' % model_num) + '_%d.ckpt'
29 | start_at_epoch = 1
30 |
31 | output_inference_file = './Anns2/Annotations/'
32 | epoch_save = 394
33 | save_file_inference = ('./network_saves_best/model%d' % model_num) + '_%d.ckpt'
34 |
35 | multiple_objects = False
36 | rand_frame_skip = 1
37 | wait_for_data = 5 # in seconds
38 | batches_until_print = 1
39 |
40 | inv_temp = 0.5
41 | inv_temp_delta = 0.1
42 | pose_dimension = 4
43 |
44 | print_layers = True
45 |
46 |
47 | def clear_output():
48 | with open(output_file_name, 'w') as f:
49 | print('Writing to ' + output_file_name)
50 | f.write('Model #: %d. Batch Size: %d.\n' % (model_num, batch_size))
51 |
52 |
53 | def write_output(string):
54 | try:
55 | output_log = open(output_file_name, 'a')
56 | output_log.write(string)
57 | output_log.close()
58 | except:
59 | print('Unable to save to output log')
60 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from caps_network_test import CapsNet
3 | from skvideo.io import vread
4 | import numpy as np
5 | from PIL import Image
6 | from scipy.misc import imresize
7 | import os
8 | import config
9 | import time
10 |
11 |
12 | def load_video(video_name):
13 | video = vread(video_name)
14 |
15 | t, h, w, _ = video.shape
16 |
17 | resized_video = []
18 | for frame in video:
19 | resized_video.append(imresize(frame, config.hr_frame_size))
20 |
21 | resized_video = np.stack(resized_video, axis=0)
22 |
23 | return resized_video / 255., (h, w)
24 |
25 |
26 | def load_first_frame(frame_name):
27 | image = Image.open(frame_name)
28 | palette = image.getpalette()
29 | image_np = np.array(image)
30 |
31 | return imresize(image_np, config.hr_frame_size, interp='nearest'), palette
32 |
33 |
34 | def process_first_frame(first_frame):
35 | unique_seg_colors = np.unique(first_frame)
36 |
37 | fin_segmentations = {}
38 | for color in unique_seg_colors:
39 | if color == 0:
40 | continue
41 |
42 | gt_seg = np.where(first_frame == color, 1, 0)
43 | if np.sum(gt_seg) == 0:
44 | continue
45 | gt_seg = np.expand_dims(gt_seg, axis=-1)
46 | fin_segmentations[color] = (0, gt_seg)
47 |
48 | return fin_segmentations
49 |
50 |
51 | def get_bounds(img):
52 | h_sum = np.sum(img, axis=1)
53 | w_sum = np.sum(img, axis=0)
54 |
55 | hs = np.where(h_sum > 0)
56 | ws = np.where(w_sum > 0)
57 |
58 | try:
59 | h0 = hs[0][0]
60 | h1 = hs[0][-1]
61 | w0 = ws[0][0]
62 | w1 = ws[0][-1]
63 | except:
64 | return -1, -1, -1, -1
65 |
66 | return h0, h1, w0, w1
67 |
68 |
69 | def get_crop_to_use(h0, h1, w0, w1, h, w, prev_crop):
70 | # uses the h, w of predicted, and the center of gt
71 | if h0 == -1:
72 | use_gt_crop = False
73 | crop_to_use = prev_crop
74 | else:
75 | use_gt_crop = False
76 | crop_to_use = np.clip(np.array([((h0+h1)/2)/h, ((w0+w1)/2)/w, 1.0, 0]), 0, 1)
77 |
78 | return use_gt_crop, crop_to_use
79 |
80 |
81 | def get_seg_for_clip_gt(sess, capsnet, clip, frame_start, lstm_cond, lstm_cond_lr, prev_crop):
82 | f, h, w, _ = clip.shape
83 |
84 | # print(frame_start.min())
85 | # print(frame_start.max())
86 | first_frame_seg_full = frame_start # np.round(frame_start) # frame_start #
87 |
88 | new_video_in = clip
89 | len_clip = new_video_in.shape[0]
90 | if len_clip < config.n_frames:
91 | new_video_in = np.concatenate((new_video_in, np.tile(new_video_in[-1:],
92 | [config.n_frames - len_clip, 1, 1, 1])), axis=0)
93 |
94 | # gets the bounds of the segmentations
95 | h0, h1, w0, w1 = get_bounds(np.round(first_frame_seg_full[:, :, 0]))
96 |
97 | use_gt_crop, crop_to_use = get_crop_to_use(h0, h1, w0, w1, h, w, prev_crop)
98 |
99 | # runs through the network
100 | seg_pred, lstm_cond, lstm_cond_lr, pred_crops = sess.run([capsnet.segment_layer_sig, capsnet.state_t, capsnet.state_t_lr, capsnet.pred_crops1],
101 | feed_dict={capsnet.x_input_video: [new_video_in],
102 | capsnet.x_first_seg: [first_frame_seg_full],
103 | capsnet.hr_cond_input: lstm_cond,
104 | capsnet.lr_cond_input: lstm_cond_lr,
105 | capsnet.use_gt_crop: use_gt_crop,
106 | capsnet.gt_crops1: [crop_to_use]})
107 |
108 | # resizes crop and places it back into original frame size
109 | seg_pred = seg_pred[0]
110 |
111 | overlap_frames = 3
112 |
113 | if use_gt_crop:
114 | crop_to_use = crop_to_use
115 | else:
116 | crop_to_use = np.concatenate((crop_to_use[:2], pred_crops[0][2:]), axis=-1)
117 |
118 | return seg_pred, lstm_cond, lstm_cond_lr, overlap_frames, crop_to_use
119 |
120 |
121 | def generate_inference(sess, capsnet, video, segmentations, orig_dim, vid_name, img_palette):
122 | orig_h, orig_w = orig_dim
123 | n_objects = int(max(segmentations.keys()))
124 |
125 | lstm_conds = np.zeros((n_objects + 1, config.hr_lstm_size[0], config.hr_lstm_size[1], config.hr_lstm_feats))
126 | lstm_conds_lr = np.zeros((n_objects + 1, config.lr_lstm_size[0], config.lr_lstm_size[1], config.lr_lstm_feats))
127 |
128 | prev_coords = np.zeros((n_objects + 1, 4))
129 |
130 | f, h, w, _ = video.shape
131 |
132 | segmentation_maps = np.zeros((config.n_frames, h, w, n_objects + 1))
133 | segmentation_maps[:, :, :, 0] = 0.5
134 | final_segmentation = np.zeros((h, w, 1))
135 | cur_i = np.ones((n_objects + 1,), np.uint8)
136 | overlaps = np.ones((n_objects + 1,), np.uint8)
137 |
138 | vid_dir = 'Output/' + vid_name + '/'
139 | mkdir(vid_dir)
140 |
141 | for i in range(f):
142 | for color in range(1, n_objects + 1):
143 | if color not in segmentations.keys():
144 | continue
145 |
146 | cur_i[color] += 1
147 | start_frame, start_seg = segmentations[color]
148 |
149 | if i < start_frame: # the current frame occurs before the object appears
150 | cur_i[color] = 7
151 | continue
152 | elif i == start_frame: # the current frame is the first frame of the object (use given segmentation)
153 | segmentation_maps[-1, :, :, color:color + 1] = start_seg
154 | cur_i[color] = 7
155 | continue
156 |
157 | if cur_i[color] != config.n_frames - overlaps[color] + 1: # the current frame's segmentation has been predicted
158 | # cur_overlap[color] -= 1
159 | # segmentation_maps[:-1, :, :, color] = segmentation_maps[1:, :, :, color]
160 | continue
161 |
162 | cur_0 = cur_i[color] - 1
163 |
164 | # cond_frame_seg = segmentation_maps[i-1, :, :, color:color+1] # This is the naive approach
165 | # cond_frame_seg = (final_segmentation[i-1] == color).astype(np.float32) # winner take all approach
166 | cond_frame_seg = ((final_segmentation == color).astype(np.float32) + (final_segmentation == 0).astype(
167 | np.float32)) * segmentation_maps[cur_0, :, :, color:color + 1] # winner take all approach 2
168 |
169 | vid_to_use = video[i - 1:i + config.n_frames - 1]
170 | orig_len = vid_to_use.shape[0]
171 | if orig_len < config.n_frames:
172 | vid_to_use = np.concatenate(
173 | [vid_to_use] + [vid_to_use[-1:] for reps in range(config.n_frames - orig_len)], axis=0)
174 |
175 | # use previous frame to generate segmentation for future N frames
176 | pred_seg, lstm_cond, lstm_cond_lr, overlap_frames, coords_used = get_seg_for_clip_gt(sess, capsnet,
177 | vid_to_use,
178 | cond_frame_seg,
179 | lstm_conds[
180 | color:color + 1],
181 | lstm_conds_lr[
182 | color:color + 1],
183 | prev_coords[color])
184 |
185 | segmentation_maps[:, :, :, color:color + 1] = pred_seg
186 | lstm_conds[color:color + 1] = lstm_cond
187 | lstm_conds_lr[color:color + 1] = lstm_cond_lr
188 | overlaps[color] = overlap_frames
189 |
190 | prev_coords[color] = coords_used
191 | # print(overlap_frames)
192 |
193 | cur_i[color] = 1
194 |
195 | final_segmentation[:, :, 0] = np.argmax(segmentation_maps[cur_i, :, :, range(n_objects + 1)], axis=0)
196 |
197 | frame_name = vid_dir + ('%d' % i).zfill(5) + '.png'
198 |
199 | fb_segs_argmax = imresize(final_segmentation[:, :, 0].astype(dtype=np.uint8), (orig_h, orig_w), interp='nearest')
200 | c = Image.fromarray(fb_segs_argmax, mode='P')
201 | c.putpalette(img_palette)
202 | c.save(frame_name, "PNG", mode='P')
203 |
204 |
205 | def mkdir(dl_path):
206 | if not os.path.exists(dl_path):
207 | print("path doesn't exist. trying to make %s" % dl_path)
208 | os.mkdir(dl_path)
209 | else:
210 | print('%s exists, cannot make directory' % dl_path)
211 |
212 |
213 | def inf():
214 | gpu_config = tf.ConfigProto()
215 | gpu_config.gpu_options.allow_growth = True
216 |
217 | capsnet = CapsNet()
218 | with tf.Session(graph=capsnet.graph, config=gpu_config) as sess:
219 | capsnet.load(sess, config.save_file_inference % config.epoch_save)
220 |
221 | # loads in video
222 | video_name = '03deb7ad95'
223 | video, orig_dims = load_video('./Examples/' + video_name + '.mp4')
224 | first_frame, img_palette = load_first_frame('./Examples/00110.png')
225 |
226 | processed_first_frame = process_first_frame(first_frame)
227 |
228 | start_time = time.time()
229 | print('Starting Inference')
230 |
231 | generate_inference(sess, capsnet, video, processed_first_frame, orig_dims, video_name, img_palette)
232 |
233 |         print('Finished Inference in %d seconds' % (time.time()-start_time))
234 |
235 |
236 | if __name__ == "__main__":
237 | inf()
238 |
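# Assuming a trained checkpoint exists at the path given by
# config.save_file_inference % config.epoch_save, inference on the bundled
# example clip can be run directly:
#
#     python3 inference.py
#
# The predicted segmentation frames are written to Output/03deb7ad95/.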
--------------------------------------------------------------------------------
/load_youtube_data_multi.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from scipy.misc import imread, imresize
4 | import numpy as np
5 | import random
6 | from threading import Thread
7 | import time
8 | from PIL import Image
9 | import sys
10 | #from scipy.signal import medfilt, medfilt2d
11 | import matplotlib.pyplot as plt
12 | from scipy.ndimage.filters import median_filter
13 |
14 | data_loc = '/home/kevin/HD2TB/Datasets/YoutubeVOS2018/'
15 | max_vids = 25
16 |
17 |
18 | def get_split_names(tr_or_val):
19 | split_file = data_loc + '%s/meta.json' % tr_or_val
20 |
21 | all_files = []
22 | with open(split_file) as f:
23 | data = json.load(f)
24 | files = sorted(list(data['videos'].keys()))
25 | for file_name in files:
26 | fdict = data['videos'][file_name]['objects']
27 | frames = sorted(list(set([fr for x in fdict.keys() for fr in fdict[x]['frames']])))
28 | all_files.append((file_name, data['videos'][file_name]['objects'], frames))
29 |
30 | return all_files
31 |
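# get_split_names() only touches the fields accessed above. A minimal meta.json
# that satisfies it would look roughly like this (the video id, object ids and
# frame names below are made up for illustration):
#
# {
#   "videos": {
#     "0a1b2c3d4e": {
#       "objects": {
#         "1": {"frames": ["00000", "00005", "00010"]},
#         "2": {"frames": ["00005", "00010"]}
#       }
#     }
#   }
# }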
32 |
33 | def load_video(file_name, allowable_frames, tr_or_val='train', shuffle=True, n_frames=8, frame_skip=1):
34 | video_dir = data_loc + ('%s_all_frames/JPEGImages/%s/' % (tr_or_val, file_name))
35 | segment_dir = data_loc + ('%s_all_frames/Annotations/%s/' % (tr_or_val, file_name))
36 |
37 | frame_names = sorted(os.listdir(video_dir))
38 | seg_frame_names = sorted(os.listdir(segment_dir))
39 | frame_names = sorted([x for x in frame_names if x[:-4] + '.png' in seg_frame_names])
40 |
41 | start_ind = frame_names.index(allowable_frames[0] + '.jpg')
42 |
43 | if shuffle:
44 | start_frame = np.random.randint(start_ind, max(len(frame_names) - n_frames * frame_skip, start_ind + 1))
45 | else:
46 | start_frame = start_ind
47 |
48 | while start_frame > start_ind and frame_names[start_frame][:-4] not in allowable_frames: # ensures the first frame is not interpolated
49 | start_frame -= 1
50 |
51 | # loads video
52 | frames = []
53 | for f in range(start_frame, start_frame + n_frames*frame_skip, frame_skip):
54 | try:
55 | frames.append(imread(video_dir + frame_names[f]))
56 | #frames.append(np.array(Image.open(video_dir + frame_names[f])))
57 | except:
58 | frames.append(frames[-1])
59 |
60 | video = np.stack(frames, axis=0)
61 |
62 | # loads segmentations
63 | frames = []
64 | for f in range(start_frame, start_frame + n_frames*frame_skip, frame_skip):
65 | try:
66 | frames.append(np.array(Image.open(segment_dir + seg_frame_names[f])))
67 | except:
68 | frames.append(frames[-1])
69 |
70 | segmentation = np.stack(frames, axis=0)
71 |
72 | return video, segmentation, frame_names[start_frame][:-4]
73 |
74 |
75 | def resize_video(video, segmentation, target_size=(120, 120)):
76 | frames, h, w, _ = video.shape
77 |
78 | t_h, t_w = target_size
79 |
80 | video_res = np.zeros((frames, t_h, t_w, 3), np.uint8)
81 | segment_res = np.zeros((frames, t_h, t_w), np.uint8)
82 | for frame in range(frames):
83 | video_res[frame] = imresize(video[frame], (t_h, t_w))
84 | segment_res[frame] = imresize(segmentation[frame], (t_h, t_w), interp='nearest')
85 |
86 | return video_res/255., segment_res
87 |
88 |
89 | def flip_clip(clip, segment_clip):
90 | flip_y = np.random.random_sample()
91 | #flip_x = np.random.random_sample()
92 | if flip_y >= 0.5:
93 | clip = np.flip(clip, axis=2)
94 | segment_clip = np.flip(segment_clip, axis=2)
95 | # if flip_x >= 0.5:
96 | # clip = np.flip(clip, axis=1)
97 | # segment_clip = np.flip(segment_clip, axis=1)
98 |
99 | return clip, segment_clip
100 |
101 |
102 | def get_bounds(img):
103 | h_sum = np.sum(img, axis=1)
104 | w_sum = np.sum(img, axis=0)
105 |
106 | hs = np.where(h_sum > 0)
107 | ws = np.where(w_sum > 0)
108 |
109 | try:
110 | h0 = hs[0][0]
111 | h1 = hs[0][-1]
112 | w0 = ws[0][0]
113 | w1 = ws[0][-1]
114 | except:
115 | return 0, img.shape[0], 0, img.shape[1]
116 |
117 | return h0, h1, w0, w1
118 |
119 |
120 | def get_bounds2(img):
121 | h_sum = np.sum(img, axis=1)
122 | w_sum = np.sum(img, axis=0)
123 |
124 | hs = np.where(h_sum > 0)
125 | ws = np.where(w_sum > 0)
126 |
127 | try:
128 | h0 = hs[0][0]
129 | h1 = hs[0][-1]
130 | w0 = ws[0][0]
131 | w1 = ws[0][-1]
132 | except:
133 | return -1, -1, -1, -1
134 |
135 | return h0, h1, w0, w1
136 |
137 |
138 | def get_bounds_frames(frames):
139 | y0, y1, x0, x1 = [], [], [], []
140 |
141 | y1_0, y2_0, x1_0, x2_0 = 0, 0, 0, 0
142 |
143 | for i in range(frames.shape[0]):
144 | h0, h1, w0, w1 = get_bounds2(frames[i])
145 | if i == 0 and h0 == -1:
146 | return -1, -1, -1, -1, -1, -1
147 | elif h0 == -1:
148 | continue
149 | else:
150 | y0.append(h0)
151 | y1.append(h1)
152 | x0.append(w0)
153 | x1.append(w1)
154 |
155 | if i == 0:
156 | y1_0, y2_0, x1_0, x2_0 = h0, h1, w0, w1
157 |
158 | center_y1, center_x1 = ((y2_0 + y1_0) / 2), ((x2_0 + x1_0) / 2)
159 |
160 | min_y, max_y, min_x, max_x = min(y0), max(y1), min(x0), max(x1)
161 |
162 | return center_y1, center_x1, min_y, max_y, min_x, max_x
163 |
164 |
165 | def get_crops2(segmentations, n_frames, leeway=.15):
166 | _, h, w = segmentations.shape
167 | leeway_h, leeway_w = leeway * h, leeway * w
168 |
169 | y1_0, y2_0, x1_0, x2_0 = get_bounds(segmentations[0])
170 | y1_1, y2_1, x1_1, x2_1 = get_bounds(segmentations[n_frames - 1])
171 | y1_2, y2_2, x1_2, x2_2 = get_bounds(segmentations[n_frames*2 - 2])
172 |
173 | min_y = min(y1_0, y1_1)
174 | min_x = min(x1_0, x1_1)
175 | max_y = max(y2_0, y2_1)
176 | max_x = max(x2_0, x2_1)
177 |
178 | center_y1, center_x1 = ((y2_0 + y1_0)/2), ((x2_0 + x1_0)/2)
179 | h1, w1 = 2*max(center_y1 - min_y, max_y - center_y1) + leeway_h, 2*max(center_x1 - min_x, max_x - center_x1) + leeway_w
180 | h1, w1 = max(h1, h//8), max(w1, w//8)
181 | alpha1 = max(h1/h, w1/w)
182 | h1, w1 = alpha1*h, alpha1*w
183 |
184 | y1 = center_y1 - 0.5*h1
185 | x1 = center_x1 - 0.5*w1
186 | if y1 + h1 > h:
187 | y1 -= (y1+h1-h)
188 | if x1 + w1 > w:
189 | x1 -= (x1+w1-w)
190 |
191 | min_y = min(y1_2, y1_1)
192 | min_x = min(x1_2, x1_1)
193 | max_y = max(y2_2, y2_1)
194 | max_x = max(x2_2, x2_1)
195 |
196 | center_y2, center_x2 = ((y2_1 + y1_1) / 2), ((x2_1 + x1_1) / 2)
197 | h2, w2 = 2 * max(center_y2 - min_y, max_y - center_y2) + leeway_h, 2 * max(center_x2 - min_x, max_x - center_x2) + leeway_w
198 | h2, w2 = max(h2, h // 8), max(w2, w // 8)
199 | alpha2 = max(h2 / h, w2 / w)
200 | h2, w2 = alpha2 * h, alpha2 * w
201 |
202 | y2 = center_y2 - 0.5 * h2
203 | x2 = center_x2 - 0.5 * w2
204 | if y2 + h2 > h:
205 | y2 -= (y2+h2-h)
206 | if x2 + w2 > w:
207 | x2 -= (x2+w2-w)
208 |
209 | return np.clip(np.array([y1/h, x1/w, alpha1+0.01, 0]), 0, 1), np.clip(np.array([y2/h, x2/w, alpha2+0.01, 0]), 0, 1)
210 |
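# The crops returned above are encoded as [y1 / h, x1 / w, alpha, 0]: a
# normalized top-left corner plus a single scale alpha, so the crop window is
# alpha * (h, w) and keeps the frame's aspect ratio. A worked example with
# illustrative numbers, for a 256 x 448 frame and an object box spanning rows
# 100..140 and columns 200..280 with leeway = 0.15:
#
#   center = (120, 240)
#   h1     = 2 * 20 + 0.15 * 256 = 78.4
#   w1     = 2 * 40 + 0.15 * 448 = 147.2
#   alpha  = max(78.4 / 256, 147.2 / 448) ~= 0.329
#   crop   = [(120 - 0.5 * 0.329 * 256) / 256,
#             (240 - 0.5 * 0.329 * 448) / 448,
#             0.329 + 0.01, 0]
#          ~= [0.304, 0.371, 0.339, 0]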
211 |
212 | def get_crops_fin(segmentations, n_frames, leeway=0.15):
213 | n_in_frames, h, w = segmentations.shape
214 | leeway_h, leeway_w = leeway * h, leeway * w
215 |
216 | assert (n_in_frames - n_frames) % (n_frames-1) == 0
217 |
218 | crops = []
219 | i = 0
220 |
221 | while i < n_in_frames:
222 | frames = segmentations[i:i+n_frames]
223 |
224 | center_y1, center_x1, min_y, max_y, min_x, max_x = get_bounds_frames(frames)
225 |
226 | if center_y1 == -1:
227 | crops.append(crops[-1])
228 | i += 7
229 | continue
230 |
231 | h1, w1 = 2 * max(center_y1 - min_y, max_y - center_y1) + leeway_h, 2 * max(center_x1 - min_x, max_x - center_x1) + leeway_w
232 | h1, w1 = max(h1, h // 8), max(w1, w // 8)
233 | alpha1 = max(h1 / h, w1 / w)
234 | h1, w1 = alpha1 * h, alpha1 * w
235 |
236 | y1 = center_y1 - 0.5 * h1
237 | x1 = center_x1 - 0.5 * w1
238 | if y1 + h1 > h:
239 | y1 -= (y1 + h1 - h)
240 | if x1 + w1 > w:
241 | x1 -= (x1 + w1 - w)
242 |
243 | crops.append(np.clip(np.array([y1/h, x1/w, alpha1+0.01, 0]), 0, 1))
244 | i += 7
245 |
246 | return crops
247 |
248 |
249 | class YoutubeTrainDataGen(object):
250 | def __init__(self, sec_to_wait=5, n_threads=10, crop_size=(256, 448), augment_data=True, n_frames=8,
251 | rand_frame_skip=4, multi_objects=False):
252 | self.train_files = get_split_names('train')
253 |
254 | self.sec_to_wait = sec_to_wait
255 |
256 | self.augment = augment_data
257 | self.rand_frame_skip = rand_frame_skip
258 |
259 | self.crop_size = crop_size
260 | self.multi_objects = multi_objects
261 |
262 | self.n_frames = n_frames
263 |
264 | np.random.seed(None)
265 | random.shuffle(self.train_files)
266 |
267 | self.data_queue = []
268 |
269 | self.thread_list = []
270 | for i in range(n_threads):
271 | load_thread = Thread(target=self.__load_and_process_data)
272 | load_thread.start()
273 | self.thread_list.append(load_thread)
274 |
275 | print('Waiting %d (s) to load data' % sec_to_wait)
276 | sys.stdout.flush()
277 | time.sleep(self.sec_to_wait)
278 |
279 | def __load_and_process_data(self):
280 | while self.train_files:
281 | while len(self.data_queue) >= max_vids:
282 | time.sleep(1)
283 |
284 | try:
285 | vid_name, fdict, allowable_frames = self.train_files.pop()
286 | except:
287 | continue # Thread issue
288 |
289 | frame_skip = np.random.randint(self.rand_frame_skip) + 1
290 |
291 | video, segmentation, frame_name = load_video(vid_name, allowable_frames, tr_or_val='train', n_frames=self.n_frames*2-1, frame_skip=frame_skip)
292 |
293 | video_res, seg_res = resize_video(video, segmentation, self.crop_size)
294 |
295 | # find objects in the first frame
296 | allowable_colors = []
297 | for obj_id in fdict.keys():
298 | if frame_name in fdict[obj_id]['frames']:
299 | allowable_colors.append(int(obj_id))
300 |
301 | # no objects in the first frame
302 | if len(allowable_colors) == 0:
303 |                 print(vid_name, 'has no objects in its first frame - this should not happen; possible data or loader bug')
304 | continue
305 |
306 | # selects the objects which will be chosen from clip
307 | colors_to_select = [x+1 for x in range(len(allowable_colors))]
308 |
309 | if self.multi_objects:
310 | n_colors_foreground = random.choice(colors_to_select)
311 | else:
312 | n_colors_foreground = 1
313 |
314 | random.shuffle(allowable_colors)
315 | selected_colors = allowable_colors[:n_colors_foreground]
316 |
317 | gt_seg = np.zeros_like(seg_res, dtype=np.float32)
318 | for color in selected_colors:
319 | gt_seg += np.where(seg_res == color, 1, 0)
320 | gt_seg = np.clip(gt_seg, 0, 1)
321 |
322 | #gt_seg = np.round(medfilt(gt_seg, (1, 3, 3)))
323 | for frame in range(gt_seg.shape[0]):
324 | if frame % 5 == 0:
325 | continue
326 | gt_seg[frame] = np.round(median_filter(gt_seg[frame], 3))
327 |
328 | if np.sum(gt_seg[0]) == 0:
329 |                 print('No foreground segmentation in the first frame; skipping clip.')
330 | continue
331 |
332 | if self.augment:
333 | video_res, gt_seg = flip_clip(video_res, gt_seg)
334 |
335 | leeway = np.random.random_sample() * 0.1 + 0.15
336 |                 gt_crop1, gt_crop2 = get_crops2(gt_seg, self.n_frames, leeway=leeway)
337 |
338 | gt_seg = np.expand_dims(gt_seg, axis=-1)
339 |
340 |                 self.data_queue.append((video_res, gt_seg, gt_crop1, gt_crop2))
341 | print('Loading data thread finished')
342 | sys.stdout.flush()
343 |
344 | def get_batch(self, batch_size=5):
345 | while len(self.data_queue) < batch_size and self.train_files:
346 | print('Waiting on data. # Already Loaded = %s' % str(len(self.data_queue)))
347 | sys.stdout.flush()
348 | time.sleep(self.sec_to_wait)
349 |
350 | batch_size = min(batch_size, len(self.data_queue))
351 | batch_x, batch_seg, batch_crop1, batch_crop2 = [], [], [], []
352 | for i in range(batch_size):
353 | vid, seg, cr1, cr2 = self.data_queue.pop(0)
354 | batch_x.append(vid)
355 | batch_seg.append(seg)
356 | batch_crop1.append(cr1)
357 | batch_crop2.append(cr2)
358 | # print(cr1)
359 | # print(cr2)
360 |
361 | return batch_x, batch_seg, batch_crop1, batch_crop2
362 |
363 | def has_data(self):
364 | return self.data_queue != [] or self.train_files != []
365 |
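# A minimal consumption loop (sketch only; data_loc above must point to a local
# YoutubeVOS-2018 copy, and the batch size here is just an example):
#
# gen = YoutubeTrainDataGen(n_threads=10, crop_size=(256, 448), n_frames=8)
# while gen.has_data():
#     batch_x, batch_seg, batch_crop1, batch_crop2 = gen.get_batch(4)
#     # feed the batch into the training graph here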
366 |
367 | #
368 | # def main():
369 | # a = YoutubeTrainDataGen(n_threads=10, crop_size=(256*2, 448*2))
370 | # i = 0
371 | # while a.has_data():
372 | # v, s, cr1, cr2 = a.get_batch(1)
373 | # print(i)
374 | # i+=1
375 | # print(cr1)
376 | # cy, cx, h, w = cr1[0]
377 | # cy, cx, h, w = cy*256*2, cx*448*2, h*256*2, w*448*2
378 | #
379 | # mask = np.ones_like(s[0][0, :, :, 0])*2
380 | # mask[np.clip(int(cy-0.5*h), 0, 256*2):np.clip(int(cy+0.5*h), 0, 256*2), np.clip(int(cx-0.5*w), 0, 448*2):np.clip(int(cx+0.5*w), 0, 448*2)] = 0
381 | #
382 | # plt.imshow(s[0][0, :, :, 0] + mask)
383 | # plt.show(plt)
384 | # plt.imshow(s[0][7, :, :, 0] + mask)
385 | # plt.show(plt)
386 | #
387 | # cy, cx, h, w = cr2[0]
388 | # cy, cx, h, w = cy * 256 * 2, cx * 448 * 2, h * 256 * 2, w * 448 * 2
389 | #
390 | # mask = np.ones_like(s[0][0, :, :, 0]) * 2
391 | # mask[np.clip(int(cy - 0.5 * h), 0, 256 * 2):np.clip(int(cy + 0.5 * h), 0, 256 * 2),
392 | # np.clip(int(cx - 0.5 * w), 0, 448 * 2):np.clip(int(cx + 0.5 * w), 0, 448 * 2)] = 0
393 | #
394 | # plt.imshow(s[0][7, :, :, 0] + mask)
395 | # plt.show(plt)
396 | # plt.imshow(s[0][7, :, :, -1] + mask)
397 | # plt.show(plt)
398 | #
399 | #
400 | # main()
401 |
--------------------------------------------------------------------------------
/load_youtubevalid_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from scipy.misc import imread, imresize
4 | import numpy as np
5 | from threading import Thread
6 | import time
7 | from PIL import Image
8 | import sys
9 | import matplotlib.pyplot as plt
10 |
11 | data_loc = '/home/kevin/HD2TB/Datasets/YoutubeVOS2018/'
12 | max_vids = 32
13 |
14 |
15 | def get_split_names(tr_or_val):
16 | split_file = data_loc + '%s/meta.json' % tr_or_val
17 |
18 | all_files = []
19 | with open(split_file) as f:
20 | data = json.load(f)
21 | files = sorted(list(data['videos'].keys()))
22 | for file_name in files:
23 | all_files.append((file_name, data['videos'][file_name]['objects']))
24 |
25 | return all_files
26 |
27 |
28 | # Loads n_frames frames starting at the first annotated frame, plus that frame's segmentation
29 | def load_video(file_name, tr_or_val='train', n_frames=8):
30 | video_dir = data_loc + ('%s_all_frames/JPEGImages/%s/' % (tr_or_val, file_name))
31 |
32 | # loads segmentations
33 | segment_dir = data_loc + ('%s/Annotations/%s/' % (tr_or_val, file_name))
34 |
35 | seg_frame_name = sorted(os.listdir(segment_dir))[0]
36 | seg_frame = np.array(Image.open(segment_dir + seg_frame_name))
37 |
38 | frame_names = sorted(os.listdir(video_dir))
39 |
40 | start_frame = frame_names.index(seg_frame_name[:-4] + '.jpg')
41 |
42 | # loads video
43 | frames = []
44 | for f in range(start_frame, start_frame+n_frames):
45 | try:
46 | frames.append(imread(video_dir + frame_names[f]))
47 | except:
48 | frames.append(frames[-1])
49 |
50 | video = np.stack(frames, axis=0)
51 |
52 | return video, seg_frame
53 |
54 |
55 | def resize_video(video, segmentation, target_size=(120, 120)):
56 | frames, h, w, _ = video.shape
57 |
58 | t_h, t_w = target_size
59 |
60 | video_res = np.zeros((frames, t_h, t_w, 3), np.uint8)
61 | for frame in range(frames):
62 | video_res[frame] = imresize(video[frame], (t_h, t_w))
63 |
64 | segment_res = imresize(segmentation, (t_h, t_w), interp='nearest')
65 |
66 | return video_res/255., segment_res
67 |
68 |
69 | def get_bounds(img):
70 | h_sum = np.sum(img, axis=1)
71 | w_sum = np.sum(img, axis=0)
72 |
73 | hs = np.where(h_sum > 0)
74 | ws = np.where(w_sum > 0)
75 |
76 | try:
77 | h0 = hs[0][0]
78 | h1 = hs[0][-1]
79 | w0 = ws[0][0]
80 | w1 = ws[0][-1]
81 | except:
82 | return 0, img.shape[0], 0, img.shape[1]
83 |
84 | return h0, h1, w0, w1
85 |
86 |
87 | def perform_window_crop(gt_seg):
88 | h, w = gt_seg.shape
89 |
90 | # gets the bounds of the segmentation and the center of the object
91 | h0, h1, w0, w1 = get_bounds(gt_seg)
92 | obj_h, obj_w = h1 - h0, w1 - w0
93 | center_h, center_w = h0 + int(obj_h / 2), w0 + int(obj_w / 2)
94 |
95 | # defines the window size around the object
96 | if obj_h <= 0.2*h and obj_w <= 0.2*w:
97 | crop_dims = (0.25*h, 0.25*w)
98 | elif obj_h <= 0.4*h and obj_w <= 0.4*w:
99 | crop_dims = (0.5*h, 0.5*w)
100 | elif obj_h <= 0.65*h and obj_w <= 0.65*w:
101 | crop_dims = (0.75*h, 0.75*w)
102 | else:
103 | crop_dims = (h, w)
104 |
105 | y1 = max(0, center_h - crop_dims[0] / 2)
106 | x1 = max(0, center_w - crop_dims[1] / 2)
107 | if y1 + crop_dims[0] > h:
108 | y1 -= (y1+crop_dims[0]-h)
109 | if x1 + crop_dims[1] > w:
110 | x1 -= (x1+crop_dims[1]-w)
111 |
112 | return np.clip(np.array([y1/h, x1/w, crop_dims[0]/h, crop_dims[1]/w]), 0, 1)
113 |
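# perform_window_crop() above snaps the window to one of four sizes (25%, 50%,
# 75% or 100% of the frame) based on the object's extent and then shifts it to
# stay inside the image. perform_window_crop2() below instead uses the same
# leeway-based parameterization [y1/h, x1/w, alpha, 0] as get_crops2() in
# load_youtube_data_multi.py.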
114 |
115 | def perform_window_crop2(gt_seg):
116 | h, w = gt_seg.shape
117 |
118 | # gets the bounds of the segmentation and the center of the object
119 | h0, h1, w0, w1 = get_bounds(gt_seg)
120 | obj_h, obj_w = h1 - h0, w1 - w0
121 | center_h, center_w = h0 + int(obj_h / 2), w0 + int(obj_w / 2)
122 | min_y, max_y, min_x, max_x = h0, h1, w0, w1
123 |
124 | leeway = 0.15
125 | leeway_h, leeway_w = leeway * h, leeway * w
126 | center_y1, center_x1 = h0 + int(obj_h / 2), w0 + int(obj_w / 2)
127 | h1, w1 = 2 * max(center_y1 - min_y, max_y - center_y1) + leeway_h, 2 * max(center_x1 - min_x,
128 | max_x - center_x1) + leeway_w
129 | h1, w1 = max(h1, h // 8), max(w1, w // 8)
130 | alpha1 = max(h1 / h, w1 / w)
131 | h1, w1 = alpha1 * h, alpha1 * w
132 |
133 | y1 = center_y1 - 0.5 * h1
134 | x1 = center_x1 - 0.5 * w1
135 | if y1 + h1 > h:
136 | y1 -= (y1 + h1 - h)
137 | if x1 + w1 > w:
138 | x1 -= (x1 + w1 - w)
139 |
140 | return np.clip(np.array([y1/h, x1/w, alpha1+0.01, 0]), 0, 1)
141 |
142 |
143 | class YoutubeValidDataGen(object):
144 | def __init__(self, sec_to_wait=5, n_threads=10, crop_size=(256, 448), n_frames=8):
145 | self.train_files = get_split_names('valid')
146 |
147 | self.sec_to_wait = sec_to_wait
148 |
149 | self.crop_size = crop_size
150 |
151 | self.n_frames = n_frames
152 |
153 | self.data_queue = []
154 |
155 | self.thread_list = []
156 | for i in range(n_threads):
157 | load_thread = Thread(target=self.__load_and_process_data)
158 | load_thread.start()
159 | self.thread_list.append(load_thread)
160 |
161 | print('Waiting %d (s) to load data' % sec_to_wait)
162 | sys.stdout.flush()
163 | time.sleep(self.sec_to_wait)
164 |
165 | def __load_and_process_data(self):
166 | while self.train_files:
167 | while len(self.data_queue) >= max_vids:
168 | time.sleep(1)
169 |
170 | try:
171 | vid_name, fdict = self.train_files.pop()
172 | except:
173 | continue # Thread issue
174 |
175 | video, segmentation = load_video(vid_name, tr_or_val='valid', n_frames=self.n_frames*2-1)
176 |
177 | video, segmentation = resize_video(video, segmentation, target_size=self.crop_size)
178 |
179 | # find objects in the first frame
180 | color, gt_seg = 0, 0
181 | for obj_id in sorted(fdict.keys()):
182 | color = int(obj_id)
183 | gt_seg = np.where(segmentation == color, 1, 0)
184 | if np.sum(gt_seg) > 0:
185 | break
186 | else:
187 | color = 0
188 |
189 | # no objects in the first frame
190 | if color == 0:
191 | print('%s has no foreground segmentation.' % vid_name)
192 | continue
193 |
194 | gt_crop = perform_window_crop2(gt_seg)
195 |
196 | gt_seg = np.expand_dims(gt_seg, axis=-1)
197 |
198 | seg_vid = [gt_seg]
199 |
200 | for i in range(self.n_frames*2-1 - 1):
201 | seg_vid.append(np.zeros_like(gt_seg))
202 | seg = np.stack(seg_vid, axis=0)
203 |
204 | self.data_queue.append((video, seg, gt_crop))
205 |
206 | print('Loading data thread finished')
207 | sys.stdout.flush()
208 |
209 | def get_batch(self, batch_size=5):
210 | while len(self.data_queue) < batch_size and self.train_files:
211 | print('Waiting on data')
212 | sys.stdout.flush()
213 | time.sleep(self.sec_to_wait)
214 |
215 | batch_size = min(batch_size, len(self.data_queue))
216 | batch_x, batch_seg, batch_crop = [], [], []
217 | for i in range(batch_size):
218 | vid, seg, crp = self.data_queue.pop(0)
219 | batch_x.append(vid)
220 | batch_seg.append(seg)
221 | batch_crop.append(crp)
222 |
223 | return batch_x, batch_seg, batch_crop
224 |
225 | def has_data(self):
226 | return self.data_queue != [] or self.train_files != []
227 |
228 |
229 | # def main():
230 | # a = YoutubeValidDataGen(n_threads=1, crop_size=(256*2, 448*2))
231 | #
232 | # while a.has_data():
233 | # v, s, cr1 = a.get_batch(1)
234 | # print(cr1)
235 | # cy, cx, alpha, _ = cr1[0]
236 | # cy, cx, h, w = int(cy*256*2), int(cx*448*2), int(alpha*256*2), int(alpha*448*2)
237 | #
238 | # mask = np.ones_like(s[0][0, :, :, 0])*2
239 | # mask[cy:cy+h, cx:cx+w] = 0
240 | #
241 | # plt.imshow(s[0][0, :, :, 0] + mask)
242 | # plt.show(plt)
243 | #
244 | #
245 | #
246 | #
247 | # main()
248 |
249 |
250 |
251 |
--------------------------------------------------------------------------------
/network_parts/lstm_capsnet_cond2_test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from caps_layers_cond import create_prim_conv3d_caps, create_conv3d_caps, layer_shape, create_conv3d_caps_cond
3 | import config
4 |
5 |
6 | # Basic network:
7 | # HR frame branch has an LSTM right before the prediction; the bounding box is parameterized as (y1, x1, alpha), where alpha preserves the 128x224 aspect ratio
8 | # LR frame branch has ConvLSTM at 1/4 LR input resolution
9 | # Skip connections only from conditioned capsules and video capsules (no frame capsules given to decoder)
10 |
11 | # Batch size = 3 on TitanXp
12 |
13 | def conv_lstm_layer(x_t, c_tm1, h_tm1, n_feats, kernel_size=(3, 3), strides=(1, 1), padding='SAME', name='lstm_convs'):
14 | inp_cat = tf.concat([x_t, h_tm1], axis=-1)
15 | conv_outputs = tf.layers.conv2d(inp_cat, n_feats*4, kernel_size, strides, padding=padding, name=name)
16 |
17 | input_gate, new_input, forget_gate, output_gate = tf.split(conv_outputs, 4, axis=-1)
18 |
19 | forget_bias = 0.0 # changed this from 1.0
20 |
21 | c_t = tf.nn.sigmoid(forget_gate + forget_bias) * c_tm1
22 | c_t += tf.nn.sigmoid(input_gate) * tf.nn.tanh(new_input) # tf.nn.tanh(new_input)
23 | h_t = tf.nn.tanh(c_t) * tf.sigmoid(output_gate)
24 |
25 | return c_t, h_t
26 |
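# Shape sketch for the ConvLSTM step above: with x_t of shape (b, h, w, c_in)
# and c_tm1 / h_tm1 of shape (b, h, w, n_feats), the 'SAME'-padded convolution
# preserves the spatial dims (for the default strides of 1), so c_t and h_t
# both come out as (b, h, w, n_feats).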
27 |
28 | def transposed_conv3d(inputs, n_units, kernel_size, strides, padding='VALID', activation=tf.nn.relu, name='deconv', use_bias=True):
29 | conv = tf.layers.conv3d_transpose(inputs, n_units, kernel_size=kernel_size, strides=strides, padding=padding,
30 | use_bias=False, activation=activation, name=name)
31 | if use_bias:
32 | bias = tf.get_variable(name=name + '_bias', shape=(1, 1, 1, 1, n_units))
33 | return activation(conv + bias)
34 | else:
35 | return activation(conv)
36 |
37 |
38 | def create_skip_connection(in_caps_layer, n_units, kernel_size, strides=(1, 1, 1), padding='VALID', name='skip', activation=tf.nn.relu):
39 | in_caps_layer = in_caps_layer[0]
40 | batch_size = tf.shape(in_caps_layer)[0]
41 | _, d, h, w, ch, dim = in_caps_layer.get_shape()
42 | d, h, w, ch, dim = map(int, [d, h, w, ch, dim])
43 |
44 | in_caps_res = tf.reshape(in_caps_layer, [batch_size, d, h, w, ch * dim])
45 |
46 | return tf.layers.conv3d_transpose(in_caps_res, n_units, kernel_size=kernel_size, strides=strides, padding=padding,
47 | use_bias=False, activation=activation, name=name)
48 |
49 |
50 | def video_encoder(x_input):
51 | x = tf.layers.conv3d(x_input, 32, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1],
52 | activation=tf.nn.relu, name='conv1_2d')
53 | x = tf.layers.conv3d(x, 64, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
54 | activation=tf.nn.relu, name='conv1_1d')
55 |
56 | x = tf.layers.conv3d(x, 64, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 2, 2],
57 | activation=tf.nn.relu, name='conv2_2d')
58 | x = tf.layers.conv3d(x, 128, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
59 | activation=tf.nn.relu, name='conv2_1d')
60 |
61 | x = tf.layers.conv3d(x, 256, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1],
62 | activation=tf.nn.relu, name='conv3a_2d')
63 | x = tf.layers.conv3d(x, 256, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
64 | activation=tf.nn.relu, name='conv3a_1d')
65 | x = tf.layers.conv3d(x, 256, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 2, 2],
66 | activation=tf.nn.relu, name='conv3b_2d')
67 | x = tf.layers.conv3d(x, 256, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
68 | activation=tf.nn.relu, name='conv3b_1d')
69 |
70 | x = tf.layers.conv3d(x, 512, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1],
71 | activation=tf.nn.relu, name='conv4a_2d')
72 | x = tf.layers.conv3d(x, 512, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
73 | activation=tf.nn.relu, name='conv4a_1d')
74 | x = tf.layers.conv3d(x, 512, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1],
75 | activation=tf.nn.relu, name='conv4b_2d')
76 | x = tf.layers.conv3d(x, 512, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
77 | activation=tf.nn.relu, name='conv4b_1d')
78 |
79 | return x
80 |
81 |
82 | def lr_frame_encoder(frame_plus_seg, c_tm1_lr, h_tm1_lr):
83 | fr_conv1 = tf.layers.conv2d(frame_plus_seg, 32, kernel_size=[3, 3], padding='SAME', strides=[1, 1],
84 | activation=tf.nn.relu, name='lr_fr_conv1')
85 | fr_conv2 = tf.layers.conv2d(fr_conv1, 64, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
86 | activation=tf.nn.relu, name='lr_fr_conv2')
87 |
88 | fr_conv3 = tf.layers.conv2d(fr_conv2, 64, kernel_size=[3, 3], padding='SAME', strides=[1, 1],
89 | activation=tf.nn.relu, name='lr_fr_conv3')
90 | fr_conv4 = tf.layers.conv2d(fr_conv3, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
91 | activation=tf.nn.relu, name='lr_fr_conv4')
92 |
93 | fr_conv5 = tf.layers.conv2d(fr_conv4, config.lr_lstm_feats // 2, kernel_size=[3, 3], padding='SAME', strides=[1, 1],
94 | activation=tf.nn.relu, name='lr_fr_conv5')
95 | c_t, h_t = conv_lstm_layer(fr_conv5, c_tm1_lr, h_tm1_lr, config.lr_lstm_feats // 2, name='lr_fr_lstm')
96 |
97 | return c_t, h_t
98 |
99 |
100 | def hr_frame_encoder(frame_plus_seg, c_tm1, h_tm1):
101 | fr_conv1 = tf.layers.conv2d(frame_plus_seg, 32, kernel_size=[3, 3], padding='SAME', strides=[1, 1],
102 | activation=tf.nn.relu, name='hr_fr_conv1')
103 | fr_conv2 = tf.layers.conv2d(fr_conv1, 64, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
104 | activation=tf.nn.relu, name='hr_fr_conv2') # 256, 448
105 |
106 | fr_conv3 = tf.layers.conv2d(fr_conv2, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
107 | activation=tf.nn.relu, name='hr_fr_conv3') # 128, 224
108 | fr_conv4 = tf.layers.conv2d(fr_conv3, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
109 | activation=tf.nn.relu, name='hr_fr_conv4') # 64, 112
110 |
111 | fr_conv5 = tf.layers.conv2d(fr_conv4, 256, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
112 | activation=tf.nn.relu, name='hr_fr_conv5') # 32, 56
113 | fr_conv6 = tf.layers.conv2d(fr_conv5, 256, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
114 | activation=tf.nn.relu, name='hr_fr_conv6') # 16, 28
115 |
116 | fr_conv7 = tf.layers.conv2d(fr_conv6, 512, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
117 | activation=tf.nn.relu, name='hr_fr_conv7') # 8, 14
118 |
119 | fr_conv8 = tf.layers.conv2d(fr_conv7, config.hr_lstm_feats // 2, kernel_size=[8, 14], padding='VALID',
120 | strides=[1, 1], activation=tf.nn.relu, name='hr_fr_conv8') # 1, 1
121 |
122 | c_t, h_t = conv_lstm_layer(fr_conv8, c_tm1, h_tm1, config.hr_lstm_feats // 2, kernel_size=(1, 1))
123 |
124 | crop_pred = tf.layers.conv2d(h_t, 3, kernel_size=[1, 1], padding='VALID', strides=[1, 1],
125 | activation=tf.nn.sigmoid, name='crop_reg')
126 |
127 | y1, x1, alpha = tf.split(crop_pred[:, 0, 0, :], 3, axis=-1)
128 |
129 |     alpha = 0.875 * alpha + 0.125  # rescales the sigmoid output so alpha lies in [0.125, 1.0]
130 | zero = tf.zeros_like(alpha)
131 |
132 | return c_t, h_t, tf.concat((y1, x1, alpha, zero), axis=-1)
133 |
134 |
135 | def convert_crop(crop, use_gt):
136 | y1, x1, alpha, _ = tf.split(crop, 4, axis=-1)
137 |     y1 = tf.cond(use_gt, lambda: y1, lambda: y1)  # no-op: both branches are identical (same for x1 and alpha below)
138 | x1 = tf.cond(use_gt, lambda: x1, lambda: x1)
139 | alpha = tf.cond(use_gt, lambda: alpha, lambda: alpha)
140 | y1, x1, alpha = y1 - 0.05, x1 - 0.05, alpha + 0.1
141 | #y1, x1, alpha = y1 - 0.1, x1 - 0.1, alpha + 0.2
142 |
143 | y2 = y1+alpha
144 | x2 = x1+alpha
145 |
146 | crop_new = tf.concat((y1, x1, y2, x2), axis=-1)
147 | crop_new = tf.clip_by_value(crop_new, 0, 1)
148 |
149 | return crop_new
150 |
151 |
152 | def uncrop(seg_out, crop_used):
153 | hr_h, hr_w = config.hr_frame_size
154 |
155 | y1, x1, y2, x2 = tf.split(crop_used, 4, axis=-1)
156 |
157 | up_padding = tf.cast(tf.floor(y1 * hr_h), tf.int32)
158 | left_padding = tf.cast(tf.floor(x1 * hr_w), tf.int32)
159 |
160 | img_res_h = tf.cast(tf.ceil((y2 - y1) * hr_h), tf.int32)
161 | img_res_w = tf.cast(tf.ceil((x2 - x1) * hr_w), tf.int32)
162 |
163 | down_padding = hr_h - img_res_h - up_padding
164 | right_padding = hr_w - img_res_w - left_padding
165 |
166 | fin_seg_init = tf.TensorArray(tf.float32, size=tf.shape(seg_out)[0])
167 |
168 | def cond(fin_seg, counter):
169 | return tf.less(counter, tf.shape(seg_out)[0])
170 |
171 | def res_and_pad(fin_seg, counter):
172 | segmentation = seg_out[counter]
173 |
174 | res_img = tf.image.resize_images(segmentation, (img_res_h[counter, 0], img_res_w[counter, 0]))
175 |
176 | padded_img = tf.pad(res_img, [[up_padding[counter, 0], down_padding[counter, 0]],
177 | [left_padding[counter, 0], right_padding[counter, 0]],
178 | [0, 0]], constant_values=-1000)
179 |
180 | fin_seg = fin_seg.write(counter, padded_img)
181 |
182 | return fin_seg, counter+1
183 |
184 | fin_seg, _ = tf.while_loop(cond, res_and_pad, [fin_seg_init, 0])
185 |
186 | return fin_seg.stack()
187 |
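# uncrop() inverts the crop: each predicted low-resolution segmentation is
# resized back to the crop's extent in full-resolution pixels and padded with a
# large negative constant (-1000), so that after the sigmoid applied later the
# region outside the crop is effectively zero.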
188 |
189 | def create_decoder_network(pred_caps, sec_caps, prim_caps, print_layers=True):
190 | deconv1 = create_skip_connection(pred_caps, 128, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME',
191 | name='deconv1')
192 |
193 | skip_connection1 = create_skip_connection(sec_caps, 128, kernel_size=[3, 3, 3], strides=[1, 1, 1], padding='SAME',
194 | name='skip_1')
195 | deconv1 = tf.concat([deconv1, skip_connection1], axis=-1)
196 |
197 | deconv2 = transposed_conv3d(deconv1, 128, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME',
198 | activation=tf.nn.relu, name='deconv2')
199 |
200 | skip_connection2 = create_skip_connection(prim_caps, 128, kernel_size=[1, 3, 3],
201 | strides=[1, 1, 1], padding='SAME', name='skip_2')
202 | deconv2 = tf.concat([deconv2, skip_connection2], axis=-1)
203 |
204 | deconv3 = transposed_conv3d(deconv2, 256, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME',
205 | activation=tf.nn.relu, name='deconv3')
206 | deconv4 = transposed_conv3d(deconv3, 256, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME',
207 | activation=tf.nn.relu, name='deconv4')
208 | deconv5 = transposed_conv3d(deconv4, 128, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME',
209 | activation=tf.nn.relu, name='deconv5')
210 |
211 | segment_layer = tf.layers.conv3d(deconv5, 1, kernel_size=[1, 1, 1], strides=[1, 1, 1], padding='SAME',
212 | activation=None, name='segment_layer')
213 | #segment_layer_sig = tf.nn.sigmoid(segment_layer)
214 |
215 | if print_layers:
216 | print('Deconv Layer 1:', deconv1.get_shape())
217 | print('Deconv Layer 2:', deconv2.get_shape())
218 | print('Deconv Layer 3:', deconv3.get_shape())
219 | print('Deconv Layer 4:', deconv4.get_shape())
220 | print('Deconv Layer 5:', deconv5.get_shape())
221 | print('Segment Layer:', segment_layer.get_shape())
222 |
223 | return segment_layer
224 |
225 |
226 | def create_network_one_pass(hr_frames, hr_first_frame_seg, c_tm1, h_tm1, c_tm1_lr, h_tm1_lr, use_gt, gt_crop,
227 | coord_addition, print_layers=True):
228 | # encodes hr frame, and gets the predicted crop
229 | hr_frame_plus_seg = tf.concat([hr_frames[:, 0], hr_first_frame_seg], axis=-1)
230 | c_t, h_t, pred_crop = hr_frame_encoder(hr_frame_plus_seg, c_tm1, h_tm1)
231 |
232 | center_y, center_x, alpha_gt, _ = tf.split(gt_crop, 4, axis=-1)
233 | _, _, alpha, _ = tf.split(pred_crop, 4, -1)
234 |
235 | y0 = tf.clip_by_value(center_y - 0.5*alpha, 0, 1)
236 | x0 = tf.clip_by_value(center_x - 0.5 * alpha, 0, 1)
237 | y1 = tf.clip_by_value(center_y - 0.5 * alpha_gt, 0, 1)
238 | x1 = tf.clip_by_value(center_x - 0.5 * alpha_gt, 0, 1)
239 |
240 | crop_to_use = tf.cond(use_gt, lambda:tf.concat([y1, x1, alpha_gt, alpha_gt], axis=-1), lambda: tf.concat([y0, x0, alpha, alpha], axis=-1))
241 | # crop_to_use = tf.cond(use_gt, lambda: gt_crop, lambda: pred_crop)
242 |
243 | # crop_to_use = tf.cond(use_gt, lambda: gt_crop, lambda: pred_crop)
244 | crop_to_use = convert_crop(crop_to_use, use_gt)
245 |
246 | frame_h, frame_w = config.lr_frame_size
247 |
248 | # crops the low resolution frame+seg
249 | range_crops = tf.range(tf.shape(crop_to_use)[0])
250 | lr_frame_plus_seg = tf.image.crop_and_resize(hr_frame_plus_seg, crop_to_use, range_crops, (frame_h, frame_w))
251 |
252 | c_t_lr, h_t_lr = lr_frame_encoder(lr_frame_plus_seg, c_tm1_lr, h_tm1_lr)
253 |
254 | # crops the video
255 | tiled_crop_to_use = tf.reshape(tf.tile(tf.expand_dims(crop_to_use, 1), [1, config.n_frames, 1]), (-1, 4))
256 | video_res = tf.reshape(hr_frames, (-1, config.hr_frame_size[0], config.hr_frame_size[1], 3))
257 | cropped_video = tf.image.crop_and_resize(video_res, tiled_crop_to_use, tf.range(tf.shape(tiled_crop_to_use)[0]),
258 | (frame_h, frame_w))
259 | cropped_video = tf.reshape(cropped_video, (-1, config.n_frames, frame_h, frame_w, 3))
260 |
261 | # creates video capsules
262 | lr_video_encoding = video_encoder(cropped_video)
263 |
264 | vid_caps = create_prim_conv3d_caps(lr_video_encoding, 12, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME',
265 | name='vid_caps')
266 |
267 | vid_caps2 = vid_caps
268 | # vid_caps2 = (vid_caps[0] + coord_addition, vid_caps[1])
269 |
270 | # creates frame capsules, tiles them, and performs coordinate addition
271 | frame_caps = create_prim_conv3d_caps(tf.expand_dims(h_t_lr, axis=1), 8, kernel_size=[1, 3, 3],
272 | strides=[1, 2, 2], padding='SAME', name='frame_caps')
273 | frame_caps = (tf.tile(frame_caps[0], [1, config.n_frames, 1, 1, 1, 1]),
274 | tf.tile(frame_caps[1], [1, config.n_frames, 1, 1, 1, 1]))
275 |
276 | frame_caps = (frame_caps[0] + coord_addition, frame_caps[1])
277 |
278 | # merges video and frame capsules
279 | prim_caps = (tf.concat([vid_caps2[0], frame_caps[0]], axis=-2), tf.concat([vid_caps2[1], frame_caps[1]], axis=-2))
280 |
281 | # performs capsule routing
282 | sec_caps, _ = create_conv3d_caps_cond(prim_caps, 16, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME',
283 | name='sec_caps', route_mean=True, n_cond_caps=8)
284 | pred_caps = create_conv3d_caps(sec_caps, 16, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME',
285 | name='third_caps', route_mean=True)
286 | fin_caps = tf.reduce_mean(pred_caps[1], [2, 3, 4, 5])
287 |
288 | if print_layers:
289 | print('Primary Caps:', layer_shape(prim_caps))
290 | print('Secondary Caps:', layer_shape(sec_caps))
291 | print('Prediction Caps:', layer_shape(pred_caps))
292 |
293 | # obtains the segmentations
294 | seg = create_decoder_network(pred_caps, sec_caps, vid_caps, print_layers=print_layers)
295 |
296 | hr_seg_out = uncrop(tf.reshape(seg, (-1, frame_h, frame_w, 1)), tiled_crop_to_use)
297 | hr_seg = tf.reshape(hr_seg_out, (-1, config.n_frames, config.hr_frame_size[0], config.hr_frame_size[1], 1))
298 | hr_seg_sig = tf.nn.sigmoid(hr_seg)
299 |
300 | return hr_seg, hr_seg_sig, c_t, h_t, c_t_lr, h_t_lr, pred_crop, fin_caps
301 |
302 |
303 | def create_network(x_input, y_segmentation, f_tm1, f_tm1_lr, use_gt_crop, gt_crop1):
304 | coords_to_add = tf.reshape(tf.range(config.n_frames, dtype=tf.float32) / (config.n_frames - 1),
305 | (1, config.n_frames, 1, 1, 1, 1))
306 | zeros_to_add = tf.zeros((1, config.n_frames, 1, 1, 1, 15), dtype=tf.float32)
307 | coords_to_add = tf.concat((zeros_to_add, coords_to_add), axis=-1)
308 |
309 | c_tm1, h_tm1 = f_tm1[:, :, :, :config.hr_lstm_feats // 2], f_tm1[:, :, :, config.hr_lstm_feats // 2:]
310 | c_tm1_lr, h_tm1_lr = f_tm1_lr[:, :, :, :config.lr_lstm_feats // 2], f_tm1_lr[:, :, :, config.lr_lstm_feats // 2:]
311 |
312 | network_outs1 = create_network_one_pass(x_input[:, :config.n_frames], y_segmentation, c_tm1, h_tm1, c_tm1_lr,
313 | h_tm1_lr, use_gt_crop, gt_crop1, coords_to_add, print_layers=True)
314 |
315 | seg, seg_sig, c_t, h_t, c_t_lr, h_t_lr, pred_crop1, fin_caps = network_outs1
316 |
317 | return seg, seg_sig, fin_caps, tf.concat([c_t, h_t], axis=-1), tf.concat([c_t_lr, h_t_lr], axis=-1), pred_crop1
318 |
--------------------------------------------------------------------------------
/network_parts/lstm_capsnet_cond2_train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from caps_layers_cond import create_prim_conv3d_caps, create_conv3d_caps, layer_shape, create_conv3d_caps_cond
3 | import config
4 |
5 |
6 | # Basic network:
7 | # HR frame branch has an LSTM right before the prediction; the bounding box is parameterized as (y1, x1, alpha), where alpha preserves the 128x224 aspect ratio
8 | # LR frame branch has ConvLSTM at 1/4 LR input resolution
9 | # Skip connections only from conditioned capsules and video capsules (no frame capsules given to decoder)
10 |
11 | # Batch size = 4 on TitanXp
12 |
13 | def conv_lstm_layer(x_t, c_tm1, h_tm1, n_feats, kernel_size=(3, 3), strides=(1, 1), padding='SAME', name='lstm_convs'):
14 | inp_cat = tf.concat([x_t, h_tm1], axis=-1)
15 | conv_outputs = tf.layers.conv2d(inp_cat, n_feats*4, kernel_size, strides, padding=padding, name=name)
16 |
17 | input_gate, new_input, forget_gate, output_gate = tf.split(conv_outputs, 4, axis=-1)
18 |
19 | forget_bias = 0.0 # changed this from 1.0
20 |
21 | c_t = tf.nn.sigmoid(forget_gate + forget_bias) * c_tm1
22 | c_t += tf.nn.sigmoid(input_gate) * tf.nn.tanh(new_input) # tf.nn.tanh(new_input)
23 | h_t = tf.nn.tanh(c_t) * tf.sigmoid(output_gate)
24 |
25 | return c_t, h_t
26 |
27 |
28 | def transposed_conv3d(inputs, n_units, kernel_size, strides, padding='VALID', activation=tf.nn.relu, name='deconv', use_bias=True):
29 | conv = tf.layers.conv3d_transpose(inputs, n_units, kernel_size=kernel_size, strides=strides, padding=padding,
30 | use_bias=False, activation=activation, name=name)
31 | if use_bias:
32 | bias = tf.get_variable(name=name + '_bias', shape=(1, 1, 1, 1, n_units))
33 | return activation(conv + bias)
34 | else:
35 | return activation(conv)
36 |
37 |
38 | def create_skip_connection(in_caps_layer, n_units, kernel_size, strides=(1, 1, 1), padding='VALID', name='skip', activation=tf.nn.relu):
39 | in_caps_layer = in_caps_layer[0]
40 | batch_size = tf.shape(in_caps_layer)[0]
41 | _, d, h, w, ch, dim = in_caps_layer.get_shape()
42 | d, h, w, ch, dim = map(int, [d, h, w, ch, dim])
43 |
44 | in_caps_res = tf.reshape(in_caps_layer, [batch_size, d, h, w, ch * dim])
45 |
46 | return tf.layers.conv3d_transpose(in_caps_res, n_units, kernel_size=kernel_size, strides=strides, padding=padding,
47 | use_bias=False, activation=activation, name=name)
48 |
49 |
50 | def video_encoder(x_input):
51 | x = tf.layers.conv3d(x_input, 32, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1],
52 | activation=tf.nn.relu, name='conv1_2d')
53 | x = tf.layers.conv3d(x, 64, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
54 | activation=tf.nn.relu, name='conv1_1d')
55 |
56 | x = tf.layers.conv3d(x, 64, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 2, 2],
57 | activation=tf.nn.relu, name='conv2_2d')
58 | x = tf.layers.conv3d(x, 128, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
59 | activation=tf.nn.relu, name='conv2_1d')
60 |
61 | x = tf.layers.conv3d(x, 256, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1],
62 | activation=tf.nn.relu, name='conv3a_2d')
63 | x = tf.layers.conv3d(x, 256, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
64 | activation=tf.nn.relu, name='conv3a_1d')
65 | x = tf.layers.conv3d(x, 256, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 2, 2],
66 | activation=tf.nn.relu, name='conv3b_2d')
67 | x = tf.layers.conv3d(x, 256, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
68 | activation=tf.nn.relu, name='conv3b_1d')
69 |
70 | x = tf.layers.conv3d(x, 512, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1],
71 | activation=tf.nn.relu, name='conv4a_2d')
72 | x = tf.layers.conv3d(x, 512, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
73 | activation=tf.nn.relu, name='conv4a_1d')
74 | x = tf.layers.conv3d(x, 512, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1],
75 | activation=tf.nn.relu, name='conv4b_2d')
76 | x = tf.layers.conv3d(x, 512, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1],
77 | activation=tf.nn.relu, name='conv4b_1d')
78 |
79 | return x
80 |
81 |
82 | def lr_frame_encoder(frame_plus_seg, c_tm1_lr, h_tm1_lr):
83 | fr_conv1 = tf.layers.conv2d(frame_plus_seg, 32, kernel_size=[3, 3], padding='SAME', strides=[1, 1],
84 | activation=tf.nn.relu, name='lr_fr_conv1')
85 | fr_conv2 = tf.layers.conv2d(fr_conv1, 64, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
86 | activation=tf.nn.relu, name='lr_fr_conv2')
87 |
88 | fr_conv3 = tf.layers.conv2d(fr_conv2, 64, kernel_size=[3, 3], padding='SAME', strides=[1, 1],
89 | activation=tf.nn.relu, name='lr_fr_conv3')
90 | fr_conv4 = tf.layers.conv2d(fr_conv3, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
91 | activation=tf.nn.relu, name='lr_fr_conv4')
92 |
93 | fr_conv5 = tf.layers.conv2d(fr_conv4, config.lr_lstm_feats // 2, kernel_size=[3, 3], padding='SAME', strides=[1, 1],
94 | activation=tf.nn.relu, name='lr_fr_conv5')
95 | c_t, h_t = conv_lstm_layer(fr_conv5, c_tm1_lr, h_tm1_lr, config.lr_lstm_feats // 2, name='lr_fr_lstm')
96 |
97 | return c_t, h_t
98 |
99 |
100 | def hr_frame_encoder(frame_plus_seg, c_tm1, h_tm1):
101 | fr_conv1 = tf.layers.conv2d(frame_plus_seg, 32, kernel_size=[3, 3], padding='SAME', strides=[1, 1],
102 | activation=tf.nn.relu, name='hr_fr_conv1')
103 | fr_conv2 = tf.layers.conv2d(fr_conv1, 64, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
104 | activation=tf.nn.relu, name='hr_fr_conv2') # 256, 448
105 |
106 | fr_conv3 = tf.layers.conv2d(fr_conv2, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
107 | activation=tf.nn.relu, name='hr_fr_conv3') # 128, 224
108 | fr_conv4 = tf.layers.conv2d(fr_conv3, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
109 | activation=tf.nn.relu, name='hr_fr_conv4') # 64, 112
110 |
111 | fr_conv5 = tf.layers.conv2d(fr_conv4, 256, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
112 | activation=tf.nn.relu, name='hr_fr_conv5') # 32, 56
113 | fr_conv6 = tf.layers.conv2d(fr_conv5, 256, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
114 | activation=tf.nn.relu, name='hr_fr_conv6') # 16, 28
115 |
116 | fr_conv7 = tf.layers.conv2d(fr_conv6, 512, kernel_size=[3, 3], padding='SAME', strides=[2, 2],
117 | activation=tf.nn.relu, name='hr_fr_conv7') # 8, 14
118 |
119 | fr_conv8 = tf.layers.conv2d(fr_conv7, config.hr_lstm_feats // 2, kernel_size=[8, 14], padding='VALID',
120 | strides=[1, 1], activation=tf.nn.relu, name='hr_fr_conv8') # 1, 1
121 |
122 | c_t, h_t = conv_lstm_layer(fr_conv8, c_tm1, h_tm1, config.hr_lstm_feats // 2, kernel_size=(1, 1))
123 |
124 | crop_pred = tf.layers.conv2d(h_t, 3, kernel_size=[1, 1], padding='VALID', strides=[1, 1],
125 | activation=tf.nn.sigmoid, name='crop_reg')
126 |
127 | y1, x1, alpha = tf.split(crop_pred[:, 0, 0, :], 3, axis=-1)
128 |
129 |     alpha = 0.875 * alpha + 0.125  # rescales the sigmoid output so alpha lies in [0.125, 1.0]
130 | zero = tf.zeros_like(alpha)
131 |
132 | return c_t, h_t, tf.concat((y1, x1, alpha, zero), axis=-1)
133 |
134 |
135 | def convert_crop(crop):
136 | y1, x1, alpha, _ = tf.split(crop, 4, axis=-1)
137 | #y1, x1, alpha = y1 - 0.05, x1 - 0.05, alpha + 0.1
138 |
139 | y2 = y1+alpha
140 | x2 = x1+alpha
141 |
142 | crop_new = tf.concat((y1, x1, y2, x2), axis=-1)
143 | crop_new = tf.clip_by_value(crop_new, 0, 1)
144 |
145 | return crop_new
146 |
147 |
148 | def uncrop(seg_out, crop_used):
149 | hr_h, hr_w = config.hr_frame_size
150 |
151 | y1, x1, y2, x2 = tf.split(crop_used, 4, axis=-1)
152 |
153 | up_padding = tf.cast(tf.floor(y1 * hr_h), tf.int32)
154 | left_padding = tf.cast(tf.floor(x1 * hr_w), tf.int32)
155 |
156 | img_res_h = tf.cast(tf.ceil((y2 - y1) * hr_h), tf.int32)
157 | img_res_w = tf.cast(tf.ceil((x2 - x1) * hr_w), tf.int32)
158 |
159 | down_padding = hr_h - img_res_h - up_padding
160 | right_padding = hr_w - img_res_w - left_padding
161 |
162 | fin_seg_init = tf.TensorArray(tf.float32, size=tf.shape(seg_out)[0])
163 |
164 | def cond(fin_seg, counter):
165 | return tf.less(counter, tf.shape(seg_out)[0])
166 |
167 | def res_and_pad(fin_seg, counter):
168 | segmentation = seg_out[counter]
169 |
170 | res_img = tf.image.resize_images(segmentation, (img_res_h[counter, 0], img_res_w[counter, 0]))
171 |
172 | padded_img = tf.pad(res_img, [[up_padding[counter, 0], down_padding[counter, 0]],
173 | [left_padding[counter, 0], right_padding[counter, 0]],
174 | [0, 0]], constant_values=-1000)
175 |
176 | fin_seg = fin_seg.write(counter, padded_img)
177 |
178 | return fin_seg, counter+1
179 |
180 | fin_seg, _ = tf.while_loop(cond, res_and_pad, [fin_seg_init, 0])
181 |
182 | return fin_seg.stack()
183 |
184 |
185 | def create_decoder_network(pred_caps, sec_caps, prim_caps, print_layers=True):
186 | deconv1 = create_skip_connection(pred_caps, 128, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME',
187 | name='deconv1')
188 |
189 | skip_connection1 = create_skip_connection(sec_caps, 128, kernel_size=[3, 3, 3], strides=[1, 1, 1], padding='SAME',
190 | name='skip_1')
191 | deconv1 = tf.concat([deconv1, skip_connection1], axis=-1)
192 |
193 | deconv2 = transposed_conv3d(deconv1, 128, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME',
194 | activation=tf.nn.relu, name='deconv2')
195 |
196 | skip_connection2 = create_skip_connection(prim_caps, 128, kernel_size=[1, 3, 3],
197 | strides=[1, 1, 1], padding='SAME', name='skip_2')
198 | deconv2 = tf.concat([deconv2, skip_connection2], axis=-1)
199 |
200 | deconv3 = transposed_conv3d(deconv2, 256, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME',
201 | activation=tf.nn.relu, name='deconv3')
202 | deconv4 = transposed_conv3d(deconv3, 256, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME',
203 | activation=tf.nn.relu, name='deconv4')
204 | deconv5 = transposed_conv3d(deconv4, 128, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME',
205 | activation=tf.nn.relu, name='deconv5')
206 |
207 | segment_layer = tf.layers.conv3d(deconv5, 1, kernel_size=[1, 1, 1], strides=[1, 1, 1], padding='SAME',
208 | activation=None, name='segment_layer')
209 | #segment_layer_sig = tf.nn.sigmoid(segment_layer)
210 |
211 | if print_layers:
212 | print('Deconv Layer 1:', deconv1.get_shape())
213 | print('Deconv Layer 2:', deconv2.get_shape())
214 | print('Deconv Layer 3:', deconv3.get_shape())
215 | print('Deconv Layer 4:', deconv4.get_shape())
216 | print('Deconv Layer 5:', deconv5.get_shape())
217 | print('Segment Layer:', segment_layer.get_shape())
218 |
219 | return segment_layer
220 |
221 |
222 | def create_network_one_pass(hr_frames, hr_first_frame_seg, c_tm1, h_tm1, c_tm1_lr, h_tm1_lr, use_gt, gt_crop,
223 | coord_addition, print_layers=True):
224 | # encodes hr frame, and gets the predicted crop
225 | hr_frame_plus_seg = tf.concat([hr_frames[:, 0], hr_first_frame_seg], axis=-1)
226 | c_t, h_t, pred_crop = hr_frame_encoder(hr_frame_plus_seg, c_tm1, h_tm1)
227 |
228 | crop_to_use = tf.cond(use_gt, lambda: gt_crop, lambda: pred_crop)
229 | crop_to_use = convert_crop(crop_to_use)
230 |
231 | frame_h, frame_w = config.lr_frame_size
232 |
233 | # crops the low resolution frame+seg
234 | range_crops = tf.range(tf.shape(crop_to_use)[0])
235 | lr_frame_plus_seg = tf.image.crop_and_resize(hr_frame_plus_seg, crop_to_use, range_crops, (frame_h, frame_w))
236 |
237 | c_t_lr, h_t_lr = lr_frame_encoder(lr_frame_plus_seg, c_tm1_lr, h_tm1_lr)
238 |
239 | # crops the video
240 | tiled_crop_to_use = tf.reshape(tf.tile(tf.expand_dims(crop_to_use, 1), [1, config.n_frames, 1]), (-1, 4))
241 | video_res = tf.reshape(hr_frames, (-1, config.hr_frame_size[0], config.hr_frame_size[1], 3))
242 | cropped_video = tf.image.crop_and_resize(video_res, tiled_crop_to_use, tf.range(tf.shape(tiled_crop_to_use)[0]),
243 | (frame_h, frame_w))
244 | cropped_video = tf.reshape(cropped_video, (-1, config.n_frames, frame_h, frame_w, 3))
245 |
246 | # creates video capsules
247 | lr_video_encoding = video_encoder(cropped_video)
248 |
249 | vid_caps = create_prim_conv3d_caps(lr_video_encoding, 12, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME',
250 | name='vid_caps')
251 |
252 | vid_caps2 = vid_caps
253 | # vid_caps2 = (vid_caps[0] + coord_addition, vid_caps[1])
254 |
255 | # creates frame capsules, tiles them, and performs coordinate addition
256 | frame_caps = create_prim_conv3d_caps(tf.expand_dims(h_t_lr, axis=1), 8, kernel_size=[1, 3, 3],
257 | strides=[1, 2, 2], padding='SAME', name='frame_caps')
258 | frame_caps = (tf.tile(frame_caps[0], [1, config.n_frames, 1, 1, 1, 1]),
259 | tf.tile(frame_caps[1], [1, config.n_frames, 1, 1, 1, 1]))
260 |
261 | frame_caps = (frame_caps[0] + coord_addition, frame_caps[1])
262 |
263 | # merges video and frame capsules
264 | prim_caps = (tf.concat([vid_caps2[0], frame_caps[0]], axis=-2), tf.concat([vid_caps2[1], frame_caps[1]], axis=-2))
265 |
266 | # performs capsule routing
267 | sec_caps, _ = create_conv3d_caps_cond(prim_caps, 16, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME',
268 | name='sec_caps', route_mean=True, n_cond_caps=8)
269 | pred_caps = create_conv3d_caps(sec_caps, 16, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME',
270 | name='third_caps', route_mean=True)
271 | fin_caps = tf.reduce_mean(pred_caps[1], [2, 3, 4, 5])
272 |
273 | if print_layers:
274 | print('Primary Caps:', layer_shape(prim_caps))
275 | print('Secondary Caps:', layer_shape(sec_caps))
276 | print('Prediction Caps:', layer_shape(pred_caps))
277 |
278 | # obtains the segmentations
279 | seg = create_decoder_network(pred_caps, sec_caps, vid_caps, print_layers=print_layers)
280 |
281 | hr_seg_out = uncrop(tf.reshape(seg, (-1, frame_h, frame_w, 1)), tiled_crop_to_use)
282 | hr_seg = tf.reshape(hr_seg_out, (-1, config.n_frames, config.hr_frame_size[0], config.hr_frame_size[1], 1))
283 | hr_seg_sig = tf.nn.sigmoid(hr_seg)
284 |
285 | return hr_seg, hr_seg_sig, c_t, h_t, c_t_lr, h_t_lr, pred_crop, fin_caps
286 |
287 |
288 | def create_network(x_input, y_segmentation, f_tm1, f_tm1_lr, use_gt_seg, use_gt_crop, gt_crop1, gt_crop2):
289 | coords_to_add = tf.reshape(tf.range(config.n_frames, dtype=tf.float32) / (config.n_frames - 1),
290 | (1, config.n_frames, 1, 1, 1, 1))
291 | zeros_to_add = tf.zeros((1, config.n_frames, 1, 1, 1, 15), dtype=tf.float32)
292 | coords_to_add = tf.concat((zeros_to_add, coords_to_add), axis=-1)
293 |
294 | c_tm1, h_tm1 = f_tm1[:, :, :, :config.hr_lstm_feats // 2], f_tm1[:, :, :, config.hr_lstm_feats // 2:]
295 | c_tm1_lr, h_tm1_lr = f_tm1_lr[:, :, :, :config.lr_lstm_feats // 2], f_tm1_lr[:, :, :, config.lr_lstm_feats // 2:]
296 |
297 | network_outs1 = create_network_one_pass(x_input[:, :config.n_frames], y_segmentation[:, 0], c_tm1, h_tm1, c_tm1_lr,
298 | h_tm1_lr, use_gt_crop, gt_crop1, coords_to_add, print_layers=True)
299 |
300 | hr_seg1, hr_seg_sig1, c_t, h_t, c_t_lr, h_t_lr, pred_crop1, fin_caps = network_outs1
301 |
302 | # reuse variables
303 | tf.get_variable_scope().reuse_variables()
304 |
305 | first_seg2 = tf.cond(use_gt_seg, lambda: y_segmentation[:, config.n_frames - 1], lambda: hr_seg_sig1[:, -1])
306 |
307 | network_outs2 = create_network_one_pass(x_input[:, config.n_frames-1:], first_seg2, c_t, h_t, c_t_lr, h_t_lr,
308 | use_gt_crop, gt_crop2, coords_to_add, print_layers=False)
309 |
310 | hr_seg2, hr_seg_sig2, c_t2, h_t2, c_t2_lr, h_t2_lr, pred_crop2, _ = network_outs2
311 |
312 | fin_seg = tf.concat([hr_seg1, hr_seg2[:, 1:]], axis=1)
313 | fin_seg_sig = tf.concat([hr_seg_sig1, hr_seg_sig2[:, 1:]], axis=1)
314 |
315 | return fin_seg, fin_seg_sig, fin_caps, tf.concat([c_t2, h_t2], axis=-1), pred_crop1, pred_crop2
316 |
--------------------------------------------------------------------------------
/network_saves/readme.txt:
--------------------------------------------------------------------------------
1 | This will hold network weights, which are periodically saved.
2 |
--------------------------------------------------------------------------------
/network_saves_best/readme.txt:
--------------------------------------------------------------------------------
1 | This will hold network weights, which are saved when the validation loss is minimized.
2 |
--------------------------------------------------------------------------------