├── .gitignore
├── README.md
├── custom_ops
├── compiled.py
├── native.py
└── source_code
│ ├── correlation_op.cc
│ ├── correlation_op.cu.cc
│ ├── correlation_op.h
│ ├── decode_flo_op.cc
│ └── decode_ppm_op.cc
├── data_loader.py
├── directories.py
├── main.py
├── network.py
├── readme_images
├── example_loss.png
├── example_training_flow1.gif
├── example_training_flow2.gif
├── example_training_flow3.gif
├── example_training_flow4.gif
├── example_validation_flow1.gif
├── example_validation_flow2.gif
├── example_validation_flow3.gif
└── example_validation_flow4.gif
├── task.py
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | FlyingChairs/
2 | dynamic_libs/
3 | logged_data/
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PWC Net TensorFlow
2 |
3 | Tensorflow implementation of Pyramid, Warping and Cost Volume (PWC) Networks based on the [paper](https://arxiv.org/abs/1709.02371) presented at CVPR 2018.
4 | Currently, [main.py](https://github.com/djl11/PWC_Net_TensorFlow/blob/master/main.py) simply downloads the FlyingChairs Dataset and starts training, following the outlined [schedule](https://arxiv.org/abs/1709.02371).
5 | This code could easily be adapted to train on other datasets though.
6 |
7 | ## Tested Environment
8 |
9 | Ubuntu 16.04
10 | Python3
11 | Tensorflow 1.8
12 | Cuda 9.0
13 |
14 | ## Acknowledgements
15 |
16 | This repo uses 3 custom-written TensorFlow ops in C++.
17 |
18 | The correlation op was taken from [this](https://github.com/simonmeister/UnFlow) tensorflow implementation of UnFlow by Simon Meister
19 | The ppm and flo decoding ops were taken from [this](https://github.com/lmb-freiburg/lmbspecialops) collection of tf ops, from the Computer Vision Group, Albert-Ludwigs-Universität Freiburg
20 |
21 | ## Usage
22 |
23 | ```python
24 | python3 main.py
25 | ```
26 |
27 | A tensorboard session will automatically be started in a new tmux window (so that the visualisations are still available after the python session has ended).
28 | This tensorboard session will log the training/validation losses, as well as GIFs of the flow as it trains.
29 |
30 | Some general hyperparameters regarding the logging of data can be changed through [task.py](https://github.com/djl11/PWC_Net_TensorFlow/blob/master/task.py)
31 | Other hyperparameters relating to the training schedule can be changed in the constructor of [network.py](https://github.com/djl11/PWC_Net_TensorFlow/blob/master/network.py)
32 | The default training/validation split is to have 90% training, with 10% left for validation.
33 |
34 |
35 | ## Example visualisations following training
36 |
37 | From left to right, the images below indicate rgb image, ground truth flow, predicted flow, flow error
38 |
39 | Examples from the training set:
40 |
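41 | ![Training example 1](readme_images/example_training_flow1.gif)
42 | ![Training example 2](readme_images/example_training_flow2.gif)
43 | ![Training example 3](readme_images/example_training_flow3.gif)
44 | ![Training example 4](readme_images/example_training_flow4.gif)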
45 |
46 | Examples from the validation set:
47 |
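48 | ![Validation example 1](readme_images/example_validation_flow1.gif)
49 | ![Validation example 2](readme_images/example_validation_flow2.gif)
50 | ![Validation example 3](readme_images/example_validation_flow3.gif)
51 | ![Validation example 4](readme_images/example_validation_flow4.gif)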
52 |
53 | ## Example Training Loss
54 |
55 | This is an example of the loss when training on the full FlyingChairs dataset (no validation was used on this occasion).
56 |
57 | ![Example training loss](readme_images/example_loss.png)
58 |
--------------------------------------------------------------------------------
/custom_ops/compiled.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tensorflow as tf
4 | import subprocess
5 | from tensorflow.python.framework import ops
6 |
# Names of the custom ops. Each name <n> is built from source_code/<n>_op.cc
# (plus <n>_op.cu.cc when that file exists) into dynamic_libs/<n>_op.so.
OP_NAMES = ['correlation', 'decode_flo', 'decode_ppm']

# Working directory at import time; the module changes directory while it
# builds/loads the libraries and switches back to this one afterwards.
cwd = os.getcwd()
current_file_dir = os.path.dirname(os.path.realpath(__file__))

# Absolute directories (kept with a trailing slash, as the build code
# concatenates file names directly onto them).
dynamic_lib_dir = '{}/{}'.format(current_file_dir, 'dynamic_libs/')
source_code_dir = '{}/{}'.format(current_file_dir, 'source_code/')

# The build commands refer to the sources by bare file name.
os.chdir(source_code_dir)

# Make sure the output directory for the shared libraries exists.
if os.path.isdir(dynamic_lib_dir) is False:
    os.mkdir(dynamic_lib_dir)
23 |
def _find_cuda_dir():
    """Return the directory that contains the ``cuda`` install of the nvcc on PATH.

    ``which nvcc`` is expected to print something like
    ``/usr/local/cuda/bin/nvcc``; everything before the first ``/cuda`` is
    returned (``/usr/local`` in that example). An empty string is returned
    when ``which`` prints nothing.
    """
    out, _ = subprocess.Popen(['which', 'nvcc'], stdout=subprocess.PIPE).communicate()
    # Strip the trailing newline printed by `which`: when the path has no
    # '/cuda' component the split keeps the whole string, and an embedded
    # newline would cut the shell command lines built from it in two.
    return out.decode().strip().split('/cuda')[0]


def compile(op=None):
    """Build the shared libraries for the custom TensorFlow ops.

    Source files are referenced by bare file name, so the current working
    directory must be the source directory when this is called. The
    resulting ``<name>_op.so`` files are written to ``dynamic_lib_dir``.

    :param op: name of a single op to build; when None, every op listed in
        ``OP_NAMES`` is built.
    :raises subprocess.CalledProcessError: if nvcc or g++ exits with an error.
    """
    to_compile = OP_NAMES if op is None else [op]

    tf_cflags = " ".join(tf.sysconfig.get_compile_flags())
    tf_lflags = " ".join(tf.sysconfig.get_link_flags())

    # Looked up at most once, and only if some op actually has a CUDA source.
    cuda_dir = None

    for n in to_compile:

        print('\n\ncompiling custom operation: ' + n + '\n\n')

        base = n + "_op"
        fn_cu_cc = base + ".cu.cc"
        fn_cc = base + ".cc"
        fn_cu_o = dynamic_lib_dir + base + ".cu.o"
        fn_so = dynamic_lib_dir + base + ".so"

        if os.path.isfile(os.getcwd() + '/' + fn_cu_cc):
            if cuda_dir is None:
                cuda_dir = _find_cuda_dir()

            # Step 1: compile the CUDA translation unit to an object file.
            nvcc_cmd = "nvcc -std=c++11 -c -o {} {} {} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -I " + cuda_dir + " --expt-relaxed-constexpr"
            nvcc_cmd = nvcc_cmd.format(" ".join([fn_cu_o, fn_cu_cc]), tf_cflags, tf_lflags)
            subprocess.check_output(nvcc_cmd, shell=True)

            # Step 2: link the CUDA object and the host source into the .so.
            gcc_cmd = "{} -std=c++11 -shared -o {} {} -fPIC -L " + cuda_dir + "/cuda/lib64 -lcudart {} -O2 -D GOOGLE_CUDA=1"
            gcc_cmd = gcc_cmd.format('g++', " ".join([fn_so, fn_cu_o, fn_cc]), tf_cflags, tf_lflags)
        else:
            # Host-only op: a single g++ invocation is enough.
            gcc_cmd = "{} -std=c++11 -shared {} -o {} -fPIC {} {} -O2"
            gcc_cmd = gcc_cmd.format('g++', fn_cc, fn_so, tf_cflags, tf_lflags)
        print('gcc_cmd: ' + gcc_cmd)
        # NOTE: the commands are shell strings assembled from module-internal
        # names and TensorFlow's own flags, not from external input.
        subprocess.check_output(gcc_cmd, shell=True)
58 |
59 |
# Load every op library and expose it as the module attribute
# ``_<name>_module``. A library that is missing or fails to load is built
# from source and loaded again.
module = sys.modules[__name__]
try:
    for n in OP_NAMES:
        lib_path = './{}_op.so'.format(n)
        try:
            os.chdir(dynamic_lib_dir)
            op_lib = tf.load_op_library(lib_path)
        except Exception:
            # Catch Exception rather than everything, so Ctrl-C or
            # SystemExit during import does not trigger a rebuild.
            os.chdir(source_code_dir)
            compile(n)
            os.chdir(dynamic_lib_dir)
            op_lib = tf.load_op_library(lib_path)
        setattr(module, '_' + n + '_module', op_lib)
finally:
    # Always return to the directory the process started in, including
    # when a build or load error propagates out of the loop.
    os.chdir(cwd)
74 |
75 | # functions #
76 | #-----------#
77 |
def correlation(first, second, **kwargs):
    """Run the compiled correlation op and return only its first output.

    :param first: first input passed to the op
    :param second: second input passed to the op
    :param kwargs: forwarded unchanged to the compiled op
    :return: element 0 of the op's outputs; the remaining outputs are dropped
    """
    outputs = _correlation_module.correlation(first, second, **kwargs)
    return outputs[0]
80 |
# Re-export the decoding ops of the loaded libraries under short names.
decode_flo = _decode_flo_module.decode_flo
decode_ppm = _decode_ppm_module.decode_ppm
83 |
84 | # Register op gradients
85 |
@ops.RegisterGradient("Correlation")
def _CorrelationGrad(op, in_grad, in_grad1, in_grad2):
    """Gradient function registered for the "Correlation" op.

    Only ``in_grad`` (the gradient for the op's first output) is used;
    ``in_grad1`` and ``in_grad2`` are accepted but ignored.

    :param op: the forward Correlation operation
    :return: list with the gradients for the op's two inputs
    """
    attr_names = ('kernel_size', 'max_displacement', 'pad', 'stride_1', 'stride_2')
    attrs = {name: op.get_attr(name) for name in attr_names}

    first_input, second_input = op.inputs[0], op.inputs[1]
    # Outputs 1 and 2 of the forward op are handed back to the gradient op.
    padded_first, padded_second = op.outputs[1], op.outputs[2]

    grad0, grad1 = _correlation_module.correlation_grad(
        in_grad, first_input, second_input,
        padded_first, padded_second,
        **attrs)
    return [grad0, grad1]
97 |
98 | ops.NotDifferentiable("DecodeFlo")
99 | ops.NotDifferentiable("DecodePpm")
--------------------------------------------------------------------------------
/custom_ops/native.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import moviepy.editor as mpy
3 | import tempfile
4 | import cv2
5 | import numpy as np
6 |
# Small positive constant. It is not referenced in the visible part of this
# module -- presumably a floor for denominators used by importers (TODO confirm).
min_denominator = 1e-12
8 |
9 | # Image Warp #
10 | #------------#
11 |
def image_warp(im, flow):
    """Sample `im` at pixel positions displaced by `flow`, with bilinear weights.

    For every pixel the source position is the pixel's own (x, y) plus the
    flow vector. The four integer neighbours of that position are clamped to
    the image bounds, gathered and blended.

    :param im: 4-D tensor, unstacked as (batch, height, width, channels)
    :param flow: tensor reshaped to rows of 2; column 0 is added to x, column 1 to y
    :return: tensor with the same shape as `im`
    """
    num_batch, height, width, channels = tf.unstack(tf.shape(im))
    zero = tf.zeros([], dtype='int32')
    max_x = tf.cast(width - 1, 'int32')
    max_y = tf.cast(height - 1, 'int32')

    # Flatten to one row per pixel so a single gather handles every pixel.
    pixels = tf.reshape(im, [-1, channels])
    displacement = tf.reshape(flow, [-1, 2])

    # Integer part selects the neighbours, fractional part weights them.
    floored = tf.floor(displacement)
    whole = tf.to_int32(floored)
    frac = displacement - floored
    frac_x, frac_y = frac[:, 0], frac[:, 1]

    # Per-pixel column and row index, repeated for each batch item.
    col_index = tf.tile(tf.range(width), [height * num_batch])
    row_grid = tf.tile(tf.expand_dims(tf.range(height), 1), [1, width])
    row_index = tf.tile(tf.reshape(row_grid, [-1]), [num_batch])

    # Weights as column vectors so they broadcast over the channel axis.
    # Order: top-left, bottom-left, top-right, bottom-right.
    weights = [
        tf.expand_dims((1 - frac_x) * (1 - frac_y), 1),
        tf.expand_dims((1 - frac_x) * frac_y, 1),
        tf.expand_dims(frac_x * (1 - frac_y), 1),
        tf.expand_dims(frac_x * frac_y, 1),
    ]

    left = col_index + whole[:, 0]
    top = row_index + whole[:, 1]
    x_lo = tf.clip_by_value(left, zero, max_x)
    x_hi = tf.clip_by_value(left + 1, zero, max_x)
    y_lo = tf.clip_by_value(top, zero, max_y)
    y_hi = tf.clip_by_value(top + 1, zero, max_y)

    # Offset of each batch item's first pixel inside the flattened image.
    pixels_per_image = width * height
    image_offsets = tf.range(num_batch) * pixels_per_image
    image_start = tf.reshape(
        tf.tile(tf.expand_dims(image_offsets, 1), [1, pixels_per_image]), [-1])

    row_lo = image_start + y_lo * width
    row_hi = image_start + y_hi * width
    # Same order as `weights`.
    corner_indices = [row_lo + x_lo, row_hi + x_lo, row_lo + x_hi, row_hi + x_hi]

    warped_flat = tf.add_n(
        [w * tf.gather(pixels, idx) for w, idx in zip(weights, corner_indices)])
    return tf.reshape(warped_flat, [num_batch, height, width, channels])
76 |
77 | # Visualisation #
78 | #---------------#
79 |
def make_color_wheel():
    """
    Build the colour wheel used by the Middlebury flow colour code
    :return: float array of shape (55, 3); one RGB row (values 0..255) per hue step
    """
    # One entry per hue transition:
    # (number of steps, channel held at 255, channel that ramps, ramp falls?)
    segments = (
        (15, 0, 1, False),  # RY: red   -> yellow
        (6, 1, 0, True),    # YG: yellow -> green
        (4, 1, 2, False),   # GC: green -> cyan
        (11, 2, 1, True),   # CB: cyan  -> blue
        (13, 2, 0, False),  # BM: blue  -> magenta
        (6, 0, 2, True),    # MR: magenta -> red
    )

    ncols = sum(steps for steps, _, _, _ in segments)
    colorwheel = np.zeros([ncols, 3])

    start = 0
    for steps, full_channel, ramp_channel, falling in segments:
        ramp = np.floor(255 * np.arange(0, steps) / steps)
        rows = slice(start, start + steps)
        colorwheel[rows, full_channel] = 255
        colorwheel[rows, ramp_channel] = (255 - ramp) if falling else ramp
        start += steps

    return colorwheel
128 |
def compute_color(u, v):
    """
    compute optical flow color map

    The flow direction selects a (linearly interpolated) colour from the
    colour wheel, and the magnitude controls its saturation. Pixels where
    either component is NaN come out black. The input arrays are not
    modified.

    :param u: optical flow horizontal map (2-D array)
    :param v: optical flow vertical map (2-D array, same shape as u)
    :return: optical flow in color code, float array of shape (h, w, 3)
    """
    [h, w] = u.shape
    img = np.zeros([h, w, 3])

    # Replace NaNs on copies: the caller's arrays must stay untouched.
    nanIdx = np.isnan(u) | np.isnan(v)
    u = np.where(nanIdx, 0, u)
    v = np.where(nanIdx, 0, v)

    colorwheel = make_color_wheel()
    ncols = np.size(colorwheel, 0)

    rad = np.sqrt(u**2 + v**2)

    # Angle mapped to [-1, 1], then to a fractional 1-based wheel position.
    a = np.arctan2(-v, -u) / np.pi
    fk = (a + 1) / 2 * (ncols - 1) + 1

    # Neighbouring wheel entries (1-based); the upper one wraps around.
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1
    f = fk - k0

    # Loop-invariant masks: magnitude within / beyond 1.
    idx = rad <= 1
    notidx = np.logical_not(idx)

    for i in range(0, np.size(colorwheel, 1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1] / 255
        col1 = tmp[k1 - 1] / 255
        col = (1 - f) * col0 + f * col1

        # Small magnitudes fade towards white, larger ones are dimmed.
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        col[notidx] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))

    return img
171 |
def flow_to_image(flow):
    """
    Convert flow into middlebury color code image

    Components with an absolute value above 1e7 are treated as unknown and
    rendered black. The flow is normalised by its largest magnitude before
    colouring. The input array is not modified.

    :param flow: optical flow map; [:, :, 0] is u and [:, :, 1] is v
    :return: optical flow image in middlebury color (uint8)
    """
    u = flow[:, :, 0]
    v = flow[:, :, 1]

    idxUnknow = (abs(u) > 1e7) | (abs(v) > 1e7)
    # u and v are views into `flow`; zero the unknown entries on copies so
    # the caller's array is left as it was.
    u = np.where(idxUnknow, 0, u)
    v = np.where(idxUnknow, 0, v)

    rad = np.sqrt(u ** 2 + v ** 2)
    maxrad = max(-1, np.max(rad))

    # eps keeps the division defined when every flow vector is zero.
    u = u / (maxrad + np.finfo(float).eps)
    v = v / (maxrad + np.finfo(float).eps)

    img = compute_color(u, v)

    idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
    img[idx] = 0

    return np.uint8(img)
197 |
def modify_images_for_vis(x_images, gt_flow, predicted_flow):
    """
    Build the side-by-side visualisation frames for one sample.

    Each frame is, from left to right: input image, ground-truth flow,
    predicted flow and flow error, all rendered in the flow colour code.

    :param x_images: indexable holding the two input images (indices 0 and 1);
        each must have the same height as gt_flow
    :param gt_flow: ground-truth flow, array of shape (height, width, 2)
    :param predicted_flow: predicted flow; resized to gt_flow's resolution
    :return: uint8 array stacking the two combined frames
    """
    # cv2.resize takes the target size as (width, height).
    height, width = gt_flow.shape[:2]
    unscaled_predicted_flow = cv2.resize(predicted_flow, (width, height))

    # The error is a flow field itself and is colour-coded exactly once.
    # It is computed before any colour conversion touches gt_flow.
    diff_flow = gt_flow - unscaled_predicted_flow

    # The three flow panels are the same for both frames.
    gt_flow_im = flow_to_image(gt_flow)
    predicted_flow_im = flow_to_image(unscaled_predicted_flow)
    diff_flow_im = flow_to_image(diff_flow)

    images = [
        np.concatenate(
            (x_images[i], gt_flow_im, predicted_flow_im, diff_flow_im), 1
        ).astype(np.uint8)
        for i in range(2)
    ]

    return np.asarray(images)
215 |
def convert_array_to_gif_summary(images_arr, tag, fps):
    """
    Encode a sequence of RGB frames as an animated GIF inside a tf.Summary.

    :param images_arr: array of frames, 4-D (time, height, width, 3), or 5-D
        with a leading batch axis, in which case the batch items are placed
        side by side horizontally
    :param tag: tag under which the image is stored in the summary
    :param fps: frame rate of the GIF
    :return: tf.Summary holding one image value with the encoded GIF
    :raises ValueError: if the array is not 4-D/5-D or does not have 3 channels
    """
    summary = tf.Summary()

    if len(images_arr.shape) == 5:
        # concatenate batch dimension horizontally
        images_arr = np.concatenate(list(images_arr), axis=-2)
    if len(images_arr.shape) != 4:
        raise ValueError('Tensors must be 4-D or 5-D for gif summary.')
    if images_arr.shape[-1] != 3:
        raise ValueError('Tensors must have 3 channels.')

    # encode sequence of images into gif string
    clip = mpy.ImageSequenceClip(list(images_arr), fps=fps)
    # Write straight into the temporary file (the '.gif' suffix tells the
    # writer which format to produce) so it is removed when the block exits
    # instead of leaving a stray '<tmpname>.gif' behind on every call.
    with tempfile.NamedTemporaryFile(suffix='.gif') as tmp:
        clip.write_gif(tmp.name, verbose=False, program='ffmpeg')
        with open(tmp.name, 'rb') as gif_file:
            encoded_image_string = gif_file.read()

    image = tf.Summary.Image()
    image.height = images_arr.shape[-3]
    image.width = images_arr.shape[-2]
    image.colorspace = 3  # code for 'RGB'
    image.encoded_image_string = encoded_image_string
    summary.value.add(tag=tag, image=image)
    return summary
--------------------------------------------------------------------------------
/custom_ops/source_code/correlation_op.cc:
--------------------------------------------------------------------------------
1 | #define EIGEN_USE_THREADS
2 |
3 | #include
4 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
5 | #include "tensorflow/core/framework/op_kernel.h"
6 | #include "tensorflow/core/framework/register_types.h"
7 | #include "tensorflow/core/framework/tensor.h"
8 | #include "tensorflow/core/framework/tensor_shape.h"
9 | #include "tensorflow/core/framework/types.h"
10 | #include "tensorflow/core/lib/core/status.h"
11 | #include "tensorflow/core/platform/logging.h"
12 | #include "tensorflow/core/framework/op.h"
13 | #include "tensorflow/core/framework/shape_inference.h"
14 | #include "tensorflow/core/framework/common_shape_fns.h"
15 |
16 | #include "correlation_op.h"
17 |
18 | typedef Eigen::GpuDevice GPUDevice;
19 |
20 | using namespace tensorflow;
21 |
// Forward declarations of the GPU entry points called by the op kernels
// below. Correlation is defined in correlation_op.cu.cc; CorrelationGrad is
// presumably defined there as well (its definition is not visible here).

// Forward pass: fills `output` from the two inputs and also fills the two
// padded buffers, which the kernel exposes as extra op outputs.
void Correlation(const GPUDevice& d,
                 typename TTypes::ConstTensor input_0,
                 typename TTypes::ConstTensor input_1,
                 typename TTypes::Tensor output,
                 typename TTypes::Tensor padded_0,
                 typename TTypes::Tensor padded_1,
                 CorrelationState params);

// Backward pass: takes the incoming gradient and the padded buffers saved by
// the forward pass and fills the gradients for both inputs.
void CorrelationGrad(const GPUDevice& d,
                     typename TTypes::ConstTensor input_grad,
                     typename TTypes::ConstTensor padded_0,
                     typename TTypes::ConstTensor padded_1,
                     typename TTypes::Tensor output_grad_0,
                     typename TTypes::Tensor output_grad_1,
                     CorrelationState params);
37 |
// Op kernel for the forward "Correlation" op.
//
// Inputs are indexed as (batch, channels, height, width). Output 0 is the
// correlation volume with shape (batch, out_channels, out_height, out_width).
// Outputs 1 and 2 are padded buffers with shape
// (batch, padded_height, padded_width, channels).
class CorrelationOp : public OpKernel {
public:
  // Reads the op attributes once, when the kernel is constructed.
  explicit CorrelationOp(OpKernelConstruction* context)
  : OpKernel(context), attrs(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_0 = context->input(0);
    const Tensor& input_1 = context->input(1);

    // Both inputs must have identical shapes.
    OP_REQUIRES(context, input_0.shape() == input_1.shape(),
                errors::InvalidArgument("Input shapes have to be the same"));

    typename TTypes::ConstTensor input_0_data = input_0.tensor();
    typename TTypes::ConstTensor input_1_data = input_1.tensor();

    const int batch = input_0_data.dimension(0);
    const int in_channels = input_0_data.dimension(1);
    const int in_height = input_0_data.dimension(2);
    const int in_width = input_0_data.dimension(3);

    // Derives padded/output sizes from the attributes and the input size.
    CorrelationState st(attrs, in_height, in_width, in_channels);

    // Reject attribute/input combinations that leave no output pixels.
    OP_REQUIRES(context, st.out_width * st.out_height > 0,
                errors::InvalidArgument("Invalid correlation settings"));

    Tensor* output = NULL;
    TensorShape output_shape({batch, st.out_channels, st.out_height, st.out_width});
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // The padded inputs are op outputs so the gradient op can reuse them.
    Tensor* padded_0 = NULL;
    Tensor* padded_1 = NULL;
    TensorShape padded_shape({batch, st.padded_height, st.padded_width, in_channels});
    OP_REQUIRES_OK(context, context->allocate_output(1, padded_shape, &padded_0));
    OP_REQUIRES_OK(context, context->allocate_output(2, padded_shape, &padded_1));

    typename TTypes::Tensor output_data = output->tensor();
    typename TTypes::Tensor padded_0_data = padded_0->tensor();
    typename TTypes::Tensor padded_1_data = padded_1->tensor();

    Correlation(context->eigen_device(),
                input_0_data, input_1_data, output_data,
                padded_0_data, padded_1_data,
                st);
  }

private:
  // Attribute values captured at construction.
  CorrelationAttrs attrs;
};
86 |
// Op kernel for "CorrelationGrad".
//
// Inputs: 0 = incoming gradient, 1/2 = the original forward inputs,
// 3/4 = the padded buffers produced by the forward op. Both output gradients
// are allocated with the shape of original input 0.
class CorrelationOpGrad : public OpKernel {
public:
  // Reads the op attributes once, when the kernel is constructed.
  explicit CorrelationOpGrad(OpKernelConstruction* context)
  : OpKernel(context), attrs(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_grad = context->input(0);
    const Tensor& input_0 = context->input(1);
    // input_1 is fetched but not used below; only input_0 supplies sizes.
    const Tensor& input_1 = context->input(2);
    const Tensor& padded_0 = context->input(3);
    const Tensor& padded_1 = context->input(4);

    typename TTypes::ConstTensor input_grad_data = input_grad.tensor();
    typename TTypes::ConstTensor input_0_data = input_0.tensor();
    //typename TTypes::ConstTensor input_1_data = input_1.tensor();
    typename TTypes::ConstTensor padded_0_data = padded_0.tensor();
    typename TTypes::ConstTensor padded_1_data = padded_1.tensor();

    // Same (batch, channels, height, width) indexing as the forward kernel.
    const int in_channels = input_0_data.dimension(1);
    const int in_height = input_0_data.dimension(2);
    const int in_width = input_0_data.dimension(3);

    CorrelationState st(attrs, in_height, in_width, in_channels);

    Tensor* output_grad_0 = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_0.shape(),
                                                     &output_grad_0));
    Tensor* output_grad_1 = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(1, input_0.shape(),
                                                     &output_grad_1));

    typename TTypes::Tensor output_grad_0_data = output_grad_0->tensor();
    typename TTypes::Tensor output_grad_1_data = output_grad_1->tensor();

    CorrelationGrad(context->eigen_device(),
                    input_grad_data,
                    padded_0_data, padded_1_data,
                    output_grad_0_data, output_grad_1_data,
                    st);
  }
private:
  // Attribute values captured at construction.
  CorrelationAttrs attrs;
};
130 |
131 | using shape_inference::DimensionHandle;;
132 |
// Interface of the forward op: two float inputs, five integer attributes and
// three float outputs (the correlation volume plus the two padded inputs).
REGISTER_OP("Correlation")
  .Input("input_0: float")
  .Input("input_1: float")
  .Attr("kernel_size: int = 1")
  .Attr("max_displacement: int = 20")
  .Attr("pad: int = 20")
  .Attr("stride_1: int = 1")
  .Attr("stride_2: int = 2")
  .Output("correlation: float")
  .Output("padded_0: float")
  .Output("padded_1: float")
  .SetShapeFn([](shape_inference::InferenceContext* c) {
    // Static shape inference: only the batch and channel dimensions of
    // output 0 are inferred; its height and width are left unknown.
    // NOTE(review): the status returned by each GetAttr call is ignored, and
    // no shape is set for outputs 1 and 2 -- confirm this is intended.
    CorrelationAttrs attrs;
    c->GetAttr("kernel_size", &attrs.kernel_size);
    c->GetAttr("max_displacement", &attrs.max_displacement);
    c->GetAttr("pad", &attrs.pad_size);
    c->GetAttr("stride_1", &attrs.stride_1);
    c->GetAttr("stride_2", &attrs.stride_2);

    DimensionHandle batch = c->Dim(c->input(0), 0);

    //padded_height = in_height + 2 * pad_size;
    //padded_width = in_width + 2 * pad_size;
    //kernel_radius = (kernel_size - 1) / 2;
    //border_size = max_displacement + kernel_radius;
    // One output channel per displacement on a square grid of
    // (2 * radius + 1) x (2 * radius + 1) positions.
    int neighborhood_grid_radius = attrs.max_displacement / attrs.stride_2;
    int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
    //out_width = ceil((float)(padded_width - border_size *2) / (float)stride_1);
    //out_height = ceil((float)(padded_height - border_size *2) / (float)stride_1);
    int out_channels = neighborhood_grid_width * neighborhood_grid_width;

    // TODO: support passing on output width and height

    c->set_output(0, c->MakeShape({batch, out_channels, c->UnknownDim(), c->UnknownDim()}));
    return Status::OK();
  });
169 |
// Interface of the gradient op. It takes the incoming gradient, both original
// inputs and both padded buffers, with the same attributes as "Correlation".
REGISTER_OP("CorrelationGrad")
  .Input("input_grad: float")
  .Input("original_input_0: float")
  .Input("original_input_1: float")
  .Input("padded_0: float")
  .Input("padded_1: float")
  .Attr("kernel_size: int = 1")
  .Attr("max_displacement: int = 20")
  .Attr("pad: int = 20")
  .Attr("stride_1: int = 1")
  .Attr("stride_2: int = 2")
  .Output("output_grad_0: float")
  .Output("output_grad_1: float")
  .SetShapeFn([](shape_inference::InferenceContext* c) {
    // Each gradient has the shape of the original input it belongs to.
    c->set_output(0, c->input(1));
    c->set_output(1, c->input(2));
    return Status::OK();
  });
188 |
// Only GPU kernels are registered, and only when built with GOOGLE_CUDA set.
#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("Correlation").Device(DEVICE_GPU), CorrelationOp);
REGISTER_KERNEL_BUILDER(Name("CorrelationGrad").Device(DEVICE_GPU), CorrelationOpGrad);

#endif // GOOGLE_CUDA
195 |
--------------------------------------------------------------------------------
/custom_ops/source_code/correlation_op.cu.cc:
--------------------------------------------------------------------------------
1 | #if GOOGLE_CUDA
2 |
3 | #define EIGEN_USE_GPU
4 |
5 | #include "tensorflow/core/framework/register_types.h"
6 | #include "tensorflow/core/framework/tensor_types.h"
7 | #include "tensorflow/core/platform/types.h"
8 | #include "tensorflow/core/util/cuda_kernel_helper.h"
9 |
10 | #include "correlation_op.h"
11 |
12 | using namespace tensorflow;
13 | using CPUDevice = Eigen::ThreadPoolDevice;
14 | using GPUDevice = Eigen::GpuDevice;
15 | // ---------------------------------------------------------
16 | // DIRECT PORT OF CAFFE CODE WITH MINIMAL CHANGES
17 | // ---------------------------------------------------------
18 |
// Large offset added before integer division in the backward kernels so the
// dividend stays non-negative (see the round_off comments there).
#define ROUND_OFF 50000

// WARPS_PER_BLOCK * THREADS_PER_WARP is the channel stride of each thread in
// CorrelateData and the size of its shared partial-sum array.
#define WARPS_PER_BLOCK 1
#define THREADS_PER_WARP 32
23 |
// Block size assumed by CAFFE_GET_BLOCKS.
const int CAFFE_CUDA_NUM_THREADS = 512;

// Number of CAFFE_CUDA_NUM_THREADS-sized blocks needed to cover N items
// (integer ceiling division).
inline int CAFFE_GET_BLOCKS(const int N) {
  const int rounded_up = N + (CAFFE_CUDA_NUM_THREADS - 1);
  return rounded_up / CAFFE_CUDA_NUM_THREADS;
}
29 |
// Copies `in`, indexed as [n][ch][xy], into `out`, indexed as
// [n][padded xy][ch], shifting every pixel by `padding` in x and y.
// One thread handles one (xy, ch, n) element: xy comes from the x dimension
// of the launch grid, ch from blockIdx.y and n from blockIdx.z.
// Only the shifted positions are written; the border of `out` is not touched.
template
__global__ void blob_rearrange_kernel2(const Dtype* in, Dtype* out, int num, int channels, int width, int height, int widthheight, int padding, int pwidthheight)
{
  int xy = blockIdx.x*blockDim.x + threadIdx.x;
  // Threads past the end of the image have nothing to copy.
  if(xy>=widthheight)
    return;

  int ch = blockIdx.y;
  int n = blockIdx.z;

  Dtype value=in[(n*channels+ch)*widthheight+xy];

  // NOTE(review): this barrier follows an early return, so not every thread
  // of the block reaches it; no shared memory is used in this kernel.
  __syncthreads();

  // Position of the pixel inside the padded image of row length width+2*padding.
  int xpad = (xy % width + padding);
  int ypad = (xy / width + padding);
  int xypad = ypad * (width+2*padding) + xpad;

  out[(n*pwidthheight+xypad)*channels + ch] = value;
}
50 |
// Forward correlation kernel.
// Each block handles one output pixel (blockIdx.x, blockIdx.y) of one batch
// item (blockIdx.z). The threads of the block split the channel dimension:
// thread t handles channels t, t + 32, ...
// bottom0/bottom1 are indexed as [item][y][x][channel]; `top` is indexed as
// [item][top_channel][y][x], with `topcount` elements per item.
// The dynamic shared memory must hold kernel_size * kernel_size *
// bottomchannels values of Dtype.
template
__global__ void CorrelateData(const int nthreads, int num, int topwidth, int topheight, int topchannels, int topcount,
  int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int kernel_size, int stride1, int stride2,
  int bottomwidth, int bottomheight, int bottomchannels,
  const Dtype *bottom0, const Dtype *bottom1, Dtype *top)
{
  extern __shared__ char patch_data_char[];

  Dtype *patch_data = (Dtype *)patch_data_char;

  // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
  int x1 = blockIdx.x*stride1 + max_displacement;
  int y1 = blockIdx.y*stride1 + max_displacement;
  int item = blockIdx.z;
  int ch_off = threadIdx.x;

  // Load 3D patch into shared shared memory
  for(int j = 0; j < kernel_size; j++) { // HEIGHT
    for(int i = 0; i < kernel_size; i++) { // WIDTH
      int ji_off = ((j * kernel_size) + i) * bottomchannels;
      for(int ch = ch_off; ch < bottomchannels; ch += (WARPS_PER_BLOCK*THREADS_PER_WARP)) { // CHANNELS
        int idx1 = ((item * bottomheight + y1+j) * bottomwidth + x1+i) * bottomchannels + ch;
        int idxPatchData = ji_off + ch;
        patch_data[idxPatchData] = bottom0[idx1];
      }
    }
  }

  // The whole patch must be loaded before any thread reads it.
  __syncthreads();

  // One partial sum per thread.
  __shared__ Dtype sum[WARPS_PER_BLOCK*THREADS_PER_WARP];

  // Compute correlation
  // One output channel per displacement (s2o, s2p) of the neighbourhood grid.
  // NOTE(review): sum[ch_off] is reset at the top of the next iteration with
  // no barrier after thread 0's reduction below; this appears to rely on the
  // block's 32 threads running in lockstep -- confirm for the target GPUs.
  for(int top_channel = 0; top_channel < topchannels; top_channel++) {
    sum[ch_off] = 0;

    int s2o = (top_channel % neighborhood_grid_width - neighborhood_grid_radius) * stride2;
    int s2p = (top_channel / neighborhood_grid_width - neighborhood_grid_radius) * stride2;

    for(int j = 0; j < kernel_size; j++) { // HEIGHT
      for(int i = 0; i < kernel_size; i++) { // WIDTH
        int ji_off = ((j * kernel_size) + i) * bottomchannels;
        for(int ch = ch_off; ch < bottomchannels; ch += (WARPS_PER_BLOCK*THREADS_PER_WARP)) { // CHANNELS
          int x2 = x1 + s2o;
          int y2 = y1 + s2p;

          int idxPatchData = ji_off + ch;
          int idx2 = ((item * bottomheight + y2+j) * bottomwidth + x2+i) * bottomchannels + ch;

          sum[ch_off] += patch_data[idxPatchData] * bottom1[idx2];
        }
      }
    }

    // All partial sums must be written before thread 0 adds them up.
    __syncthreads();

    if(ch_off == 0) {
      Dtype total_sum = 0;
      for(int idx = 0; idx < WARPS_PER_BLOCK*THREADS_PER_WARP; idx++) {
        total_sum += sum[idx];
      }
      // Normalise by the number of multiplied element pairs.
      const int sumelems = kernel_size*kernel_size*bottomchannels;
      const int index = ((top_channel*topheight + blockIdx.y)*topwidth)+blockIdx.x;
      top[index + item*topcount] = total_sum / (float)sumelems;
    }
  }
}
118 |
// Gradient of the correlation with respect to the first input, for the single
// batch item `item`.
// Each loop index covers one (channel n, x, y) element of the unpadded input.
// bottom1 is read from the padded buffer, indexed as [item][y][x][channel].
// topdiff is indexed as [item][top_channel][y][x]. The result is written to
// bottom0diff, indexed as [item][channel][y][x], with `bottomcount` elements
// per item.
template
__global__ void CorrelateDataBackward0(const int nthreads, int num, int item, int topwidth, int topheight, int topchannels,
  int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2,
  int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight, int bottomchannels, int bottomcount, int pad_size,
  Dtype *bottom0diff, const Dtype *bottom1, const Dtype *topdiff)
{
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Decompose the flat index; l and m are positions in the padded image.
    int n = index % bottomchannels; //channels
    int l = (index / bottomchannels) % bottomwidth + pad_size; //w-pos
    int m = (index / bottomchannels / bottomwidth) % bottomheight + pad_size; //h-pos

    //Get X,Y ranges and clamp
    // round_off is a trick to enable integer division with ceil, even for negative numbers
    // We use a large offset, for the inner part not to become negative.
    const int round_off = ROUND_OFF;
    const int round_off_s1 = stride1 * round_off;

    // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
    int xmin = (l - 2*kernel_radius - max_displacement + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement) / stride1
    int ymin = (m - 2*kernel_radius - max_displacement + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (m - 2*kernel_radius - max_displacement) / stride1

    // Same here:
    int xmax = (l - max_displacement + round_off_s1) / stride1 - round_off; // floor (l - max_displacement) / stride1
    int ymax = (m - max_displacement + round_off_s1) / stride1 - round_off; // floor (m - max_displacement) / stride1


    Dtype sum = 0;
    // Skip pixels whose range of contributing outputs lies outside the output.
    if(xmax>=0 && ymax>=0 && (xmin<=topwidth-1) && (ymin<=topheight-1))
    {
      xmin = max(0,xmin);
      xmax = min(topwidth-1,xmax);

      ymin = max(0,ymin);
      ymax = min(topheight-1,ymax);

      for(int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) {
        for(int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) {

          // Get bottom1 data:
          int s2o = stride2 * o;
          int s2p = stride2 * p;
          int idxbot1 = ((item * pbottomheight + (m+s2p)) * pbottomwidth + (l+s2o)) * bottomchannels + n;
          Dtype bot1tmp = bottom1[idxbot1]; // bottom1[l+s2o,m+s2p,n]

          // Index offset for topdiff in following loops:
          int op = (p+neighborhood_grid_radius) * neighborhood_grid_width + (o+neighborhood_grid_radius); // index [o,p]
          int idxopoffset = (item * topchannels + op);

          for(int y = ymin; y <= ymax; y++) {
            for(int x = xmin; x <= xmax; x++) {
              int idxtopdiff = (idxopoffset * topheight + y) * topwidth + x; // topdiff[x,y,o,p]
              sum += topdiff[idxtopdiff] * bot1tmp;
            }
          }
        }
      }
    }
    // Same normalisation as the forward pass; the write goes to the unpadded position.
    const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels;
    const int bot0index = ((n * bottomheight) + (m-pad_size)) * bottomwidth + (l-pad_size);
    bottom0diff[bot0index + item*bottomcount] = sum / (float)sumelems;
  }

}
182 |
// Gradient of the correlation with respect to the second input, for the
// single batch item `item`.
// Mirrors CorrelateDataBackward0, except that the range of contributing
// outputs depends on the displacement, so it is recomputed inside the
// displacement loops. bottom0 is read from the padded buffer at the position
// shifted by the negative displacement. The result is written to bottom1diff,
// indexed as [item][channel][y][x].
template
__global__ void CorrelateDataBackward1(const int nthreads, int num, int item, int topwidth, int topheight, int topchannels,
  int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2,
  int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight, int bottomchannels, int bottomcount, int pad_size,
  const Dtype *bottom0, Dtype *bottom1diff, const Dtype *topdiff)
{
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    //int l = index % bottomwidth + pad_size; //w-pos
    //int m = (index / bottomwidth) % bottomheight + pad_size; //h-pos
    //int n = (index / bottomwidth / bottomheight) % bottomchannels; //channels
    // Decompose the flat index; l and m are positions in the padded image.
    int n = index % bottomchannels; //channels
    int l = (index / bottomchannels) % bottomwidth + pad_size; //w-pos
    int m = (index / bottomchannels / bottomwidth) % bottomheight + pad_size; //h-pos

    // round_off is a trick to enable integer division with ceil, even for negative numbers
    // We use a large offset, for the inner part not to become negative.
    const int round_off = ROUND_OFF;
    const int round_off_s1 = stride1 * round_off;

    Dtype sum = 0;
    for(int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) {
      for(int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) {

        int s2o = stride2 * o;
        int s2p = stride2 * p;

        //Get X,Y ranges and clamp
        // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
        int xmin = (l - 2*kernel_radius - max_displacement - s2o + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement - s2o) / stride1
        int ymin = (m - 2*kernel_radius - max_displacement - s2p + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (m - 2*kernel_radius - max_displacement - s2p) / stride1

        // Same here:
        int xmax = (l - max_displacement - s2o + round_off_s1) / stride1 - round_off; // floor (l - max_displacement - s2o) / stride1
        int ymax = (m - max_displacement - s2p + round_off_s1) / stride1 - round_off; // floor (m - max_displacement - s2p) / stride1

        // Skip displacements whose contributing outputs lie outside the output.
        if(xmax>=0 && ymax>=0 && (xmin<=topwidth-1) && (ymin<=topheight-1))
        {
          xmin = max(0,xmin);
          xmax = min(topwidth-1,xmax);

          ymin = max(0,ymin);
          ymax = min(topheight-1,ymax);

          // Get bottom0 data:
          int idxbot0 = ((item * pbottomheight + (m-s2p)) * pbottomwidth + (l-s2o)) * bottomchannels + n;
          Dtype bot0tmp = bottom0[idxbot0]; // bottom0[l-s2o,m-s2p,n]

          // Index offset for topdiff in following loops:
          int op = (p+neighborhood_grid_radius) * neighborhood_grid_width + (o+neighborhood_grid_radius); // index [o,p]
          int idxOpOffset = (item * topchannels + op);

          for(int y = ymin; y <= ymax; y++) {
            for(int x = xmin; x <= xmax; x++) {
              int idxtopdiff = (idxOpOffset * topheight + y) * topwidth + x; // topdiff[x,y,o,p]
              sum += topdiff[idxtopdiff] * bot0tmp;
            }
          }
        }
      }
    }
    // Same normalisation as the forward pass; the write goes to the unpadded position.
    const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels;
    const int bot1index = ((n * bottomheight) + (m-pad_size)) * bottomwidth + (l-pad_size);
    bottom1diff[bot1index + item*bottomcount] = sum / (float)sumelems;
  }

}
249 |
// Forward-pass launcher for the correlation op.
// Steps, as visible below: zero-fill both padded buffers with cudaMemset, run
// blob_rearrange_kernel2 once per input to populate them, then launch
// CorrelateData to write the result into `output`.
// Sizes are read with dimension(0..3) as (num, channels, height, width).
// NOTE(review): in this copy the template arguments after `TTypes` and the
// contents of the `<<>>` launch configurations appear to have been stripped
// (anything between angle brackets is missing) - restore them from the
// upstream source before compiling.
250 | void Correlation(const GPUDevice& d,
251 | typename TTypes::ConstTensor input_0,
252 | typename TTypes::ConstTensor input_1,
253 | typename TTypes::Tensor output,
254 | typename TTypes::Tensor padded_0,
255 | typename TTypes::Tensor padded_1,
256 | CorrelationState st) {
257 |
// Output ("top") extents and the op parameters carried by CorrelationState.
258 | const int top_channels_ = output.dimension(1);
259 | const int top_height_ = output.dimension(2);
260 | const int top_width_ = output.dimension(3);
261 | const int pad_size_ = st.pad_size;
262 | const int stride1_ = st.stride_1;
263 | const int stride2_ = st.stride_2;
264 | const int kernel_size_ = st.kernel_size;
265 | const int kernel_radius_ = st.kernel_radius;
266 | const int max_displacement_ = st.max_displacement;
267 | const int neighborhood_grid_radius_ = st.neighborhood_grid_radius;
268 | const int neighborhood_grid_width_ = st.neighborhood_grid_width;
269 |
270 | // PORTED CAFFE CODE
271 |
// Input ("bottom") extents, taken from input_0 only; input_1 is presumably
// the same shape - TODO confirm against the op's shape checks.
272 | const int bnum = input_0.dimension(0);
273 | const int bchannels = input_0.dimension(1);
274 | const int bheight = input_0.dimension(2);
275 | const int bwidth = input_0.dimension(3);
276 | const int bwidthheight = bwidth * bheight;
277 |
// Number of output elements per batch item.
278 | const int topcount = top_width_ * top_height_ * top_channels_;
279 |
// Not referenced in the visible text; presumably part of a stripped launch config.
280 | dim3 threadsPerBlock(THREADS_PER_WARP * WARPS_PER_BLOCK);
281 |
// Clear the padded buffers so the border stays zero after rearranging.
// The byte count uses sizeof(float), i.e. the buffers are treated as float.
282 | cudaMemset(padded_0.data(), 0, padded_0.size()*sizeof(float));
283 | cudaMemset(padded_1.data(), 0, padded_1.size()*sizeof(float));
284 |
// Rearrange grid: ceil(bwidthheight / threads_per_block) x channels x batch.
285 | int threads_per_block=16;
286 | dim3 totalBlocksRearr((bwidthheight-1)/threads_per_block+1, bchannels, bnum);
287 | const int pwidthheight = (bwidth + 2 * pad_size_) * (bheight + 2 * pad_size_);
288 |
289 | blob_rearrange_kernel2<<>>
290 | (input_0.data(),padded_0.data(),bnum,bchannels,bwidth,bheight,bwidthheight,pad_size_,pwidthheight);
291 |
292 | blob_rearrange_kernel2<<>>
293 | (input_1.data(),padded_1.data(),bnum,bchannels,bwidth,bheight,bwidthheight,pad_size_,pwidthheight);
294 |
// From here on width/height refer to the padded extents.
295 | const int num = bnum;
296 | const int channels = bchannels;
297 | const int height = bheight + 2*pad_size_;
298 | const int width = bwidth + 2*pad_size_;
299 |
// Element count (not bytes) of one kernel_size x kernel_size x channels patch;
// not referenced in the visible text - presumably used in the stripped launch config.
300 | const int shared_memory_per_block = (kernel_size_*kernel_size_)*bchannels;
301 |
302 | // CorrelationLayer
303 | int topThreadCount = topcount;
304 |
// One block per (output x, output y, batch item).
305 | dim3 totalBlocksCorr(top_width_, top_height_, num);
306 |
307 | CorrelateData<<>>(
308 | topThreadCount,
309 | num, top_width_, top_height_, top_channels_, topcount,
310 | max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, kernel_size_,
311 | stride1_, stride2_,
312 | width, height, channels,
313 | padded_0.data(), padded_1.data(), output.data()
314 | );
315 | }
316 |
// Backward-pass launcher for the correlation op.
// For every batch item it launches CorrelateDataBackward0 (writes
// output_grad_0 from padded_1 and the incoming gradient) and then
// CorrelateDataBackward1 (writes output_grad_1 from padded_0 and the incoming
// gradient).
// NOTE(review): the template arguments after `TTypes` and the contents of the
// `<<>>` launch configurations appear to have been stripped from this copy -
// restore them from the upstream source before compiling.
317 | void CorrelationGrad(const GPUDevice& d,
318 | typename TTypes::ConstTensor input_grad,
319 | typename TTypes::ConstTensor padded_0,
320 | typename TTypes::ConstTensor padded_1,
321 | typename TTypes::Tensor output_grad_0,
322 | typename TTypes::Tensor output_grad_1,
323 | CorrelationState st) {
324 |
// Extents of the incoming gradient, read as (_, channels, height, width).
325 | const int top_channels_ = input_grad.dimension(1);
326 | const int top_height_ = input_grad.dimension(2);
327 | const int top_width_ = input_grad.dimension(3);
328 |
329 | const int pad_size_ = st.pad_size;
330 | const int stride1_ = st.stride_1;
331 | const int stride2_ = st.stride_2;
332 | const int kernel_size_ = st.kernel_size;
333 | const int kernel_radius_ = st.kernel_radius;
334 | const int max_displacement_ = st.max_displacement;
335 | const int neighborhood_grid_radius_ = st.neighborhood_grid_radius;
336 | const int neighborhood_grid_width_ = st.neighborhood_grid_width;
337 |
338 | // PORTED CAFFE CODE
339 |
340 | // Get top diff, compute bottom diff
341 | const float* top_diff = input_grad.data();
342 |
343 | float* bottom0_diff = output_grad_0.data();
344 | float* bottom1_diff = output_grad_1.data();
345 |
// Unpadded input extents are taken from output_grad_0 only.
346 | const int num = output_grad_0.dimension(0);
347 | const int channels = output_grad_0.dimension(1);
348 | const int height = output_grad_0.dimension(2);
349 | const int width = output_grad_0.dimension(3);
350 |
351 | const int paddedheight = height + 2*pad_size_;
352 | const int paddedwidth = width + 2*pad_size_;
353 |
// Elements per batch item in each input gradient; also passed as the thread count.
354 | const int bottomcount = channels * height * width;
355 |
356 | int botThreadCount = bottomcount;
357 |
358 | // CorrelationLayerBackward
359 |
360 | // == Run kernel Backward 0
// The three values below are not referenced in the visible text; presumably
// they belong to the stripped launch configurations - TODO confirm.
361 | dim3 totalBlocksBackward0(width, height, channels * num); //First dim is fastest
362 | dim3 threadsPerBlockBackward0(THREADS_PER_WARP * WARPS_PER_BLOCK);
363 | const int buffer_size_backw0 = ((int)ceil((float)(2 * kernel_radius_) / (float)stride1_) + 1) * top_channels_;
364 |
365 | // == Run kernel Backward 0
// One launch per batch item; `n` is passed to the kernel as the item index.
366 | for(int n = 0; n < num; n++) {
367 | //Bottom0:
368 | CorrelateDataBackward0<<>>(
369 | botThreadCount,
370 | num, n, top_width_, top_height_, top_channels_,
371 | max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_,
372 | stride1_, stride2_,
373 | width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_,
374 | bottom0_diff, padded_1.data(), top_diff
375 | );
376 | }
377 |
378 | // == Run kernel Backward 1
// Same launch pattern, with the roles of the two inputs swapped.
379 | for(int n = 0; n < num; n++) {
380 | CorrelateDataBackward1<<>>(
381 | botThreadCount,
382 | num, n, top_width_, top_height_, top_channels_,
383 | max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_,
384 | stride1_, stride2_,
385 | width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_,
386 | padded_0.data(), bottom1_diff, top_diff
387 | );
388 | }
389 |
390 | }
391 |
392 | #endif // GOOGLE_CUDA
393 |
--------------------------------------------------------------------------------
/custom_ops/source_code/correlation_op.h:
--------------------------------------------------------------------------------
1 | #define EIGEN_USE_THREADS
2 |
3 | #include "tensorflow/core/framework/op_kernel.h"
4 | #include "tensorflow/core/framework/op.h"
5 |
6 | using namespace tensorflow;
7 |
// User-supplied attributes of the correlation op, read once at kernel
// construction time.
8 | struct CorrelationAttrs {
// Reads "kernel_size", "max_displacement", "pad", "stride_1" and "stride_2"
// from the op definition and rejects even kernel sizes.
9 | CorrelationAttrs(OpKernelConstruction* c) {
10 | OP_REQUIRES_OK(c, c->GetAttr("kernel_size", &kernel_size));
11 | OP_REQUIRES_OK(c, c->GetAttr("max_displacement", &max_displacement));
12 | OP_REQUIRES_OK(c, c->GetAttr("pad", &pad_size));
13 | OP_REQUIRES_OK(c, c->GetAttr("stride_1", &stride_1));
14 | OP_REQUIRES_OK(c, c->GetAttr("stride_2", &stride_2));
15 |
// An odd size is required so that (kernel_size - 1) / 2 is an exact radius.
16 | OP_REQUIRES(c, kernel_size % 2 != 0,
17 | errors::InvalidArgument("kernel_size must be odd"));
18 | }
// Default constructor: assigns nothing, so the int fields below have no
// defined value until they are set by the caller.
19 | CorrelationAttrs() {}
20 |
// Value of the "pad" attribute.
21 | int pad_size;
22 | int stride_1;
23 | int stride_2;
24 | int max_displacement;
25 | int kernel_size;
26 | };
27 |
// Attributes plus every size derived from them for one concrete input shape.
// The assignments below are order-dependent: later fields are computed from
// earlier ones.
28 | struct CorrelationState {
// in_channels is accepted but not used by this constructor.
29 | CorrelationState(CorrelationAttrs attrs, int in_height, int in_width, int in_channels) {
30 | pad_size = attrs.pad_size;
31 | stride_1 = attrs.stride_1;
32 | stride_2 = attrs.stride_2;
33 | max_displacement = attrs.max_displacement;
34 | kernel_size = attrs.kernel_size;
35 |
// Input extents after adding pad_size on both sides.
36 | padded_height = in_height + 2 * pad_size;
37 | padded_width = in_width + 2 * pad_size;
38 |
39 | // Compute size of unreachable border region (on each side)
40 | kernel_radius = (kernel_size - 1) / 2;
41 | border_size = max_displacement + kernel_radius;
42 |
43 | // Given a center position in image 1, how many displaced positions in -x / +x
44 | // direction do we consider in image 2 (neighborhoodGridWidth):
// Integer division: max_displacement is rounded down to a multiple of stride_2.
45 | neighborhood_grid_radius = max_displacement / stride_2;
46 | neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
47 |
// Output extent = ceil((padded extent - 2 * border) / stride_1), computed in
// float and truncated on assignment to int.
48 | out_width = ceil((float)(padded_width - border_size *2) / (float)stride_1);
49 | out_height = ceil((float)(padded_height - border_size *2) / (float)stride_1);
50 | // Top Channels amount to displacement combinations in X and Y direction:
51 | out_channels = neighborhood_grid_width * neighborhood_grid_width;
52 | }
53 |
54 | int pad_size;
55 | int stride_1;
56 | int stride_2;
57 | int kernel_radius;
58 | int max_displacement;
59 | int kernel_size;
60 | int neighborhood_grid_radius;
61 | int neighborhood_grid_width;
62 | int padded_height;
63 | int padded_width;
64 | int border_size;
65 | int out_height;
66 | int out_width;
67 | int out_channels;
68 | };
69 |
--------------------------------------------------------------------------------
/custom_ops/source_code/decode_flo_op.cc:
--------------------------------------------------------------------------------
1 | //
2 | // lmbspecialops - a collection of tensorflow ops
3 | // Copyright (C) 2017 Albert Ludwigs University of Freiburg, Pattern Recognition and Image Processing, Computer Vision Group
4 | // Author(s): Lukas Voegtle
5 | //
6 | // This program is free software: you can redistribute it and/or modify
7 | // it under the terms of the GNU General Public License as published by
8 | // the Free Software Foundation, either version 3 of the License, or
9 | // (at your option) any later version.
10 | //
11 | // This program is distributed in the hope that it will be useful,
12 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | // GNU General Public License for more details.
15 | //
16 | // You should have received a copy of the GNU General Public License
17 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | //
19 | //#include "config.h"
20 | #include "tensorflow/core/framework/op.h"
21 | #include "tensorflow/core/framework/shape_inference.h"
22 | #include "tensorflow/core/framework/op_kernel.h"
23 | #include "tensorflow/core/framework/op.h"
24 |
25 | using namespace tensorflow;
26 |
// Op registration for DecodeFlo: one string input, one float output.
// The shape function declares the output as rank 3 with unknown height and
// width and a fixed last dimension of 2.
27 | REGISTER_OP("DecodeFlo")
28 | .Input("contents: string")
29 | .Output("image: float")
30 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
31 | using namespace ::tensorflow::shape_inference;
32 | c->set_output(0,
33 | c->MakeShape({InferenceContext::kUnknownDim,
34 | InferenceContext::kUnknownDim, 2}));
35 | return Status::OK();
36 | })
37 | .Doc(R"doc(
38 | Decode a FLO-encoded image to a float tensor.
39 | contents: 0-D. The FLO-encoded image.
40 | image: 3-D with shape `[height, width, 2]`.
41 | )doc");
42 |
43 |
44 |
45 |
46 | // Decode the contents of a FLO file
// CPU kernel that decodes the bytes of a .flo file into a [height, width, 2]
// float tensor.
// Layout as read below: bytes 0-3 magic ("PIEH", which must also read as the
// float 202021.25), bytes 4-7 width, bytes 8-11 height, then
// width * height * 2 four-byte values copied verbatim into the output.
// NOTE(review): `scalar()()`, `static_cast(...)` and `flat()` appear to have
// lost their template arguments in this copy - restore from upstream before
// compiling.
47 | class DecodeFloOp : public OpKernel {
48 | public:
49 | explicit DecodeFloOp(OpKernelConstruction* context) : OpKernel(context) {
50 | }
51 |
52 | void Compute(OpKernelContext* context) override {
53 | const Tensor& contents = context->input(0);
54 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
55 | errors::InvalidArgument("contents must be scalar, got shape ",
56 | contents.shape().DebugString()));
57 |
58 | // Start decoding image to get shape details
59 | const StringPiece data = contents.scalar()();
// 12 bytes = 4-byte magic + 4-byte width + 4-byte height.
60 | if (data.size() < 12) {
61 | OP_REQUIRES(context, false,
62 | errors::InvalidArgument("Invalid FLO data size, expected at least 12"));
63 | }
64 |
65 | if (!data.starts_with("PIEH")) {
66 | OP_REQUIRES(context, false,
67 | errors::InvalidArgument("Invalid FLO header, expected 'PIEH'"));
68 | }
// Reinterprets the first four bytes as a float and compares with the magic value.
69 | if (*((float*)(data.data())) != 202021.25f) {
70 | OP_REQUIRES(context, false,
71 | errors::InvalidArgument("Invalid FLO header, expected 202021.25 (sanity check failed)"));
72 | }
// Width and height are read in host byte order straight from the buffer.
73 | uint32 width = *((uint32*)(data.data() + 4));
74 | uint32 height = *((uint32*)(data.data() + 8));
75 |
76 | // Verify that width and height are not too large:
77 | // - verify width and height don't overflow int.
78 | // - width can later be multiplied by channels_ and sizeof(uint16), so
79 | // verify single dimension is not too large.
80 | // - verify when width and height are multiplied together, there are a few
81 | // bits to spare as well.
// width/height are uint32, so the `<= 0` tests below only reject 0.
82 | const int64 total_size =
83 | static_cast(width) * static_cast(height);
84 | if (width != static_cast(width) || width <= 0 ||
85 | width >= (1LL << 27) || height != static_cast(height) ||
86 | height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
87 | OP_REQUIRES(context, false,
88 | errors::InvalidArgument("FLO size too large for int: ",
89 | width, " by ", height));
90 | }
91 |
// Exact-size check: header plus 2 channels of 4 bytes per pixel.
92 | if (data.size() != 12 + width * height * 2 * 4) {
93 | OP_REQUIRES(context, false,
94 | errors::InvalidArgument("Invalid FLO data size, expected ", 12 + width * height * 2 * 4));
95 | }
96 |
97 | // Allocate tensor
98 | Tensor* output = nullptr;
99 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({height, width, 2}), &output));
100 |
101 | // Finish decoding image
// The payload is copied byte-for-byte; no per-element conversion is done.
102 | const uint8* innerData = (const uint8*)(data.data() + 12);
103 | memcpy(output->flat().data(), innerData, height * width * 2 * sizeof(float));
104 | }
105 | };
// Registers the kernel above for the DecodeFlo op on CPU devices.
106 | REGISTER_KERNEL_BUILDER(Name("DecodeFlo").Device(DEVICE_CPU), DecodeFloOp);
107 |
108 |
--------------------------------------------------------------------------------
/custom_ops/source_code/decode_ppm_op.cc:
--------------------------------------------------------------------------------
1 | //
2 | // lmbspecialops - a collection of tensorflow ops
3 | // Copyright (C) 2017 Albert Ludwigs University of Freiburg, Pattern Recognition and Image Processing, Computer Vision Group
4 | // Author(s): Lukas Voegtle
5 | //
6 | // This program is free software: you can redistribute it and/or modify
7 | // it under the terms of the GNU General Public License as published by
8 | // the Free Software Foundation, either version 3 of the License, or
9 | // (at your option) any later version.
10 | //
11 | // This program is distributed in the hope that it will be useful,
12 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | // GNU General Public License for more details.
15 | //
16 | // You should have received a copy of the GNU General Public License
17 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | //
19 | #include "tensorflow/core/framework/op.h"
20 | #include "tensorflow/core/framework/shape_inference.h"
21 | #include "tensorflow/core/framework/op_kernel.h"
22 | #include "tensorflow/core/framework/op.h"
23 |
24 | using namespace tensorflow;
25 |
// Op registration for DecodePpm: one string input, a `dtype` attribute
// restricted to uint8/uint16 (default uint8), and two outputs of that dtype.
// The shape function requires a rank-0 input, declares output 0 as
// [unknown, unknown, 3] and output 1 as a scalar.
26 | REGISTER_OP("DecodePpm")
27 | .Input("contents: string")
28 | .Attr("dtype: {uint8, uint16} = DT_UINT8")
29 | .Output("image: dtype")
30 | .Output("maxval: dtype")
31 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
32 | using namespace ::tensorflow::shape_inference;
33 | ShapeHandle unused;
34 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
35 | c->set_output(0,
36 | c->MakeShape({InferenceContext::kUnknownDim,
37 | InferenceContext::kUnknownDim, 3}));
38 | c->set_output(1, c->Scalar());
39 | return Status::OK();
40 | })
41 | .Doc(R"doc(
42 | Decode a PPM-encoded image to a uint8 or uint16 tensor.
43 | contents: 0-D. The PPM-encoded image.
44 | image: 3-D with shape `[height, width, 3]`.
45 | maxval: maxval from the ppm file.
46 | )doc");
47 |
48 |
49 |
50 |
51 | // Decode the contents of a PPM file
// CPU kernel that decodes a binary ("P6") PPM file into a [height, width, 3]
// tensor of uint8 or uint16 plus a scalar holding the file's maxval.
// NOTE(review): `scalar()()`, `static_cast(...)` and `flat()` appear to have
// lost their template arguments in this copy - restore from upstream before
// compiling.
52 | class DecodePpmOp : public OpKernel {
53 | public:
// Reads the "dtype" attribute and records the requested bits per channel (8 or 16).
54 | explicit DecodePpmOp(OpKernelConstruction* context) : OpKernel(context) {
55 | DataType dt;
56 | OP_REQUIRES_OK(context, context->GetAttr("dtype", &dt));
57 | OP_REQUIRES(
58 | context, dt == DataType::DT_UINT8 || dt == DataType::DT_UINT16,
59 | errors::InvalidArgument("Type must be UINT8 or UINT16, got ", dt));
60 | if (dt == DataType::DT_UINT8) {
61 | desired_channel_bits_ = 8;
62 | } else {
63 | desired_channel_bits_ = 16;
64 | }
65 | }
66 |
// Advances *index past header whitespace.
// The character at *index must be a space, tab, CR, LF or '#'; otherwise the
// context is failed and false is returned. With onlyOne == true exactly one
// character is consumed. Otherwise consumption continues until the first
// non-whitespace character, skipping each '#' comment up to its CR/LF.
67 | static bool skipWhitespace(OpKernelContext* context, const StringPiece& data, size_t* index, bool onlyOne = false) {
68 | if ((*index) >= data.size()) {
69 | context->CtxFailure(errors::InvalidArgument("Invalid PPM header, unexpected end of file"));
70 | return false;
71 | }
72 | char c = data[(*index)];
73 | if (c != ' ' && c != '\t' && c != '\r' && c != '\n' && c != '#') {
74 | context->CtxFailure(errors::InvalidArgument("Invalid PPM header, expected whitespace at ", (*index)));
75 | return false;
76 | }
77 | (*index)++;
78 | if (onlyOne) {
79 | return true;
80 | }
81 | for (; (*index) < data.size(); (*index)++) {
82 | c = data[(*index)];
83 | // Skip comments
84 | if (c == '#') {
// Leaves c on the terminating CR/LF so the outer loop treats it as whitespace.
85 | for (; (*index) < data.size(); (*index)++) {
86 | c = data[(*index)];
87 | if (c == '\r' || c == '\n') {
88 | break;
89 | }
90 | }
91 | }
92 | if (c != ' ' && c != '\t' && c != '\r' && c != '\n') {
93 | break;
94 | }
95 | }
96 | return true;
97 | }
98 |
// Parses a run of decimal digits starting at *index into *number and leaves
// *index on the first non-digit. Fails the context and returns false at end
// of data or when the first character is not a digit.
// No overflow check is performed on the accumulated uint32.
// NOTE(review): the error text says "expected float" although only unsigned
// decimal digits are accepted.
99 | static bool readInt(OpKernelContext* context, const StringPiece& data, size_t* index, uint32* number) {
100 | if ((*index) >= data.size()) {
101 | context->CtxFailure(errors::InvalidArgument("Invalid PPM header, unexpected end of file"));
102 | return false;
103 | }
104 | char c = data[(*index)];
105 | if (c < '0' || c > '9') {
106 | context->CtxFailure(errors::InvalidArgument("Invalid PPM header, expected float at ", (*index)));
107 | return false;
108 | }
109 | *number = 0;
110 | for (; (*index) < data.size(); (*index)++) {
111 | c = data[(*index)];
112 | if (c >= '0' && c <= '9') {
113 | *number *= 10;
114 | *number += c - '0';
115 | } else {
116 | break;
117 | }
118 | }
119 | return true;
120 | }
121 |
// Header grammar as parsed below:
//   "P6" ws width ws height ws maxval <one whitespace char> pixel-bytes
122 | void Compute(OpKernelContext* context) override {
123 | const Tensor& contents = context->input(0);
124 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
125 | errors::InvalidArgument("contents must be scalar, got shape ",
126 | contents.shape().DebugString()));
127 |
128 | // Start decoding image to get shape details
129 | const StringPiece data = contents.scalar()();
130 | if (data.size() < 9) {
131 | //P6000
132 | OP_REQUIRES(context, false,
133 | errors::InvalidArgument("Invalid PPM data size, data too small for PPM file"));
134 | }
135 | if (!data.starts_with("P6")) {
136 | OP_REQUIRES(context, false,
137 | errors::InvalidArgument("Invalid PPM header, expected 'P6'"));
138 | }
139 |
// Parsing starts right after the two-character magic.
140 | size_t index = 2;
141 | if (!skipWhitespace(context, data, &index)) {
142 | return;
143 | }
144 | uint32 width = 0, height = 0, maxval = 0;
145 | if (!readInt(context, data, &index, &width)) {
146 | return;
147 | }
148 | if (!skipWhitespace(context, data, &index)) {
149 | return;
150 | }
151 | if (!readInt(context, data, &index, &height)) {
152 | return;
153 | }
154 | if (!skipWhitespace(context, data, &index)) {
155 | return;
156 | }
157 | if (!readInt(context, data, &index, &maxval)) {
158 | return;
159 | }
// Exactly one separator is consumed here so that pixel bytes which happen to
// look like whitespace are not skipped.
160 | if (!skipWhitespace(context, data, &index, true)) {
161 | return;
162 | }
163 | // Verify that width and height are not too large:
164 | // - verify width and height don't overflow int.
165 | // - width can later be multiplied by channels_ and sizeof(uint16), so
166 | // verify single dimension is not too large.
167 | // - verify when width and height are multiplied together, there are a few
168 | // bits to spare as well.
// width/height are uint32, so the `<= 0` tests below only reject 0.
169 | const int64 total_size =
170 | static_cast(width) * static_cast(height);
171 | if (width != static_cast(width) || width <= 0 ||
172 | width >= (1LL << 27) || height != static_cast(height) ||
173 | height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
174 | OP_REQUIRES(context, false,
175 | errors::InvalidArgument("PPM size too large for int: ",
176 | width, " by ", height));
177 | }
178 |
179 | // Allocate tensor
180 | Tensor* output = nullptr;
181 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({height, width, 3}), &output));
182 | Tensor* output_maxval = nullptr;
183 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), &output_maxval));
184 |
// 8-bit path: requested uint8 and the file's maxval fits in one byte.
185 | if (desired_channel_bits_ == 8 && maxval <= 255) {
186 | if (data.size() != index + width * height * 3 * sizeof(uint8)) {
187 | OP_REQUIRES(context, false,
188 | errors::InvalidArgument("Invalid PPM data size, expected ", index + width * height * 3 * sizeof(uint8)));
189 | }
190 | // Finish decoding image
191 | const uint8* innerData = (const uint8*)(data.data() + index);
192 | uint8* dstData = output->flat().data();
193 | std::memcpy(dstData, innerData, height * width * 3 * sizeof(uint8));
194 | output_maxval->scalar()() = (uint8)maxval;
// 16-bit path: requested uint16 and maxval needs more than one byte.
// maxval is not checked against 65535 here; the cast below truncates.
195 | } else if (desired_channel_bits_ == 16 && maxval > 255) {
196 | if (data.size() != index + width * height * 3 * sizeof(uint16)) {
197 | OP_REQUIRES(context, false,
198 | errors::InvalidArgument("Invalid PPM data size, expected ", index + width * height * 3 * sizeof(uint16)));
199 | }
200 | // Finish decoding image
201 | const uint8* innerData = (const uint8*)(data.data() + index);
202 | uint16* dstData = output->flat().data();
203 | std::memcpy(dstData, innerData, height * width * 3 * sizeof(uint16));
204 | // PPM data is always in big endian
205 | if (port::kLittleEndian) {
206 | // Change endianness from big endian to system endianness
// Swaps the two bytes of every uint16 sample in place.
207 | size_t size = height * width * 3 * sizeof(uint16);
208 | uint8* bytes = (uint8*)dstData;
209 | for (size_t i = 0; i < size; i+=sizeof(uint16)) {
210 | for (size_t j = 0; j < sizeof(uint16) / 2; j++) {
211 | std::swap(bytes[i + j], bytes[i + (sizeof(uint16) - j - 1)]);
212 | }
213 | }
214 | }
215 |
216 | output_maxval->scalar()() = (uint16)maxval;
217 | } else {
// Requested dtype and the file's maxval disagree (e.g. uint8 requested for a 16-bit file).
218 | OP_REQUIRES(context, false,
219 | errors::InvalidArgument("PPM maxval ", maxval, " does not match requested bit depth format ", desired_channel_bits_));
220 | }
221 | }
222 |
223 | private:
// 8 or 16, chosen from the "dtype" attribute in the constructor.
224 | int desired_channel_bits_;
225 | };
// Registers the kernel above for the DecodePpm op on CPU devices.
226 | REGISTER_KERNEL_BUILDER(Name("DecodePpm").Device(DEVICE_CPU), DecodePpmOp);
227 |
228 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import tensorflow as tf
4 | import custom_ops.compiled as compiled_ops
5 | import urllib
6 | import zipfile
7 |
class DataLoader:
    """Fetch the FlyingChairs dataset and expose it as batched tf.data pipelines.

    Typical call order: construct (may download the dataset), ``set_session``,
    then ``prime_data``; afterwards ``next_element``, the two init ops and the
    two string handles are available as attributes.
    """

    def __init__(self, dirs, total_num_examples):
        """Store the configuration and make sure the dataset is on disk.

        Args:
            dirs: object providing ``rgb_image_dir``, ``flow_image_dir``,
                ``rgb_format`` and ``flow_format``.
            total_num_examples: number of examples in the dataset.
        """
        self.dirs = dirs
        self.total_num_examples = total_num_examples

        self.download_and_extract_data_if_necessary()

    def download_and_extract_data_if_necessary(self):
        """Download and unpack FlyingChairs into the working directory.

        Does nothing when ``<cwd>/FlyingChairs`` already exists. Otherwise the
        user is asked for confirmation; any answer other than ``'y'`` ends the
        process with exit status 0.
        """
        # A bare `import urllib` does not load the `request` submodule, so
        # import it explicitly where it is needed.
        import urllib.request

        cwd = os.getcwd()
        url = 'https://lmb.informatik.uni-freiburg.de/data/FlyingChairs/FlyingChairs.zip'
        filepath = cwd + '/' + url.split('/')[-1]
        dataset_dir = filepath[:-4]  # archive path without the '.zip' suffix

        if os.path.isdir(dataset_dir):
            return

        result = input('About to download and extract the flying chairs dataset.'
                       '\nThis requires ~85GB of free space. Continue? [y/n]\n')

        if result != 'y':
            # Same effect as the site-provided exit(0), without relying on it.
            raise SystemExit(0)

        print('downloading flying chairs dataset as zip, this may take ~30-60 minutes...')
        urllib.request.urlretrieve(url, filepath)

        print('extracting flying chairs dataset, this may take a while...')

        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(filepath, 'r') as zip_ref:
            zip_ref.extractall(cwd)
        os.remove(filepath)
        # The archive unpacks to '<name>_release'; rename it to '<name>'.
        os.rename(dataset_dir + '_release', dataset_dir)

    def input_parser(self, img_paths, flow_img_paths):
        """Decode one example's image and flow files.

        Each entry of the path tensors is indexed with ``[0]`` before being
        read. Images are decoded with ``decode_ppm`` (first output only) as
        uint8, flow fields with ``decode_flo`` as float32.

        Returns:
            Tuple ``(imgs, flow_imgs)`` of stacked decoded tensors.
        """
        imgs = tf.map_fn(
            lambda img_path: compiled_ops.decode_ppm(tf.read_file(img_path[0]))[0],
            img_paths, dtype=tf.uint8)
        flow_imgs = tf.map_fn(
            lambda flow_img_path: compiled_ops.decode_flo(tf.read_file(flow_img_path[0])),
            flow_img_paths, dtype=tf.float32)
        return imgs, flow_imgs

    def prime_image_data(self, image_dir, starting_example, ending_example, file_extension):
        """Group the files of ``image_dir`` by example number.

        The example number is parsed from the part of the file name before the
        first ``'_'`` (a single leading non-digit character is ignored) and
        converted to zero-based. Files are visited in sorted name order.

        Args:
            image_dir: directory path; it is concatenated directly with the
                file names, so it must end with a separator.
            starting_example: first zero-based example number to include.
            ending_example: last zero-based example number to include.
            file_extension: only files ending with this suffix are considered.

        Returns:
            List of lists of file paths, one inner list per example.
        """
        image_filenames = sorted(
            name for name in os.listdir(image_dir) if name.endswith(file_extension))

        grouped_images = []
        group = []
        image_num = starting_example
        for image_filename in image_filenames:
            image_path = image_dir + image_filename
            prefix = image_path.split('/')[-1].split('_')[0]
            digits = prefix if prefix.isdigit() else prefix[1:]
            current_image_num = int(digits) - 1  # file names are one-based

            if current_image_num < starting_example:
                continue
            if current_image_num != image_num:
                # A new example begins: stop if it is past the requested range,
                # otherwise close the previous group.
                image_num = current_image_num
                if current_image_num > ending_example:
                    break
                grouped_images.append(group)
                group = []
            group.append(image_path)

        # Flush the group that was being collected when the loop ended.
        grouped_images.append(group)

        return grouped_images

    def prime_data_for_loading(self, starting_example, ending_example, batch_size, training):
        """Build a shuffled, batched, repeating dataset for an example range.

        Args:
            starting_example: first zero-based example number.
            ending_example: last zero-based example number.
            batch_size: examples per batch; also stored on ``self.batch_size``.
            training: accepted for interface compatibility; not used here.

        Returns:
            A ``tf.data.Dataset`` yielding ``(imgs, flow_imgs)`` batches.
        """
        self.batch_size = batch_size

        sources = (
            (self.dirs.rgb_image_dir, self.dirs.rgb_format),
            (self.dirs.flow_image_dir, self.dirs.flow_format),
        )
        image_paths_tensor_list = []
        for image_dir, file_format in sources:
            grouped_paths = self.prime_image_data(
                image_dir, starting_example, ending_example, file_format)
            # Trailing axis of size 1 is what input_parser indexes with [0].
            image_paths_tensor_list.append(tf.expand_dims(tf.constant(grouped_paths), -1))

        data = tf.data.Dataset.from_tensor_slices(tuple(image_paths_tensor_list))
        data = data.map(map_func=self.input_parser, num_parallel_calls=8)
        data = data.shuffle(tf.cast(batch_size*100, tf.int64))
        data = data.apply(tf.contrib.data.batch_and_drop_remainder(tf.cast(batch_size, tf.int64)))
        data = data.prefetch(1)
        data = data.repeat()

        return data

    def prime_data(self, start_it_train, end_it_train, start_it_valid, end_it_valid, batch_size, data_type):
        """Create training/validation datasets and a handle-driven iterator.

        ``set_session`` must have been called first because the string handles
        are evaluated with ``self.sess``.

        Args:
            start_it_train, end_it_train: example range for training.
            start_it_valid, end_it_valid: example range for validation.
            batch_size: examples per batch.
            data_type: string handle tensor selecting the active iterator.
        """
        training_dataset = self.prime_data_for_loading(start_it_train, end_it_train, batch_size, True)
        self.training_data_len = end_it_train - start_it_train

        validation_dataset = self.prime_data_for_loading(start_it_valid, end_it_valid, batch_size, False)
        self.validation_data_len = end_it_valid - start_it_valid

        iterator = tf.data.Iterator.from_string_handle(
            data_type, training_dataset.output_types, training_dataset.output_shapes)
        self.next_element = iterator.get_next()

        training_iterator = training_dataset.make_initializable_iterator()
        validation_iterator = validation_dataset.make_initializable_iterator()

        self.training_init_op = training_iterator.make_initializer(training_dataset)
        self.validation_init_op = validation_iterator.make_initializer(validation_dataset)

        self.training_handle = self.sess.run(training_iterator.string_handle())
        self.validation_handle = self.sess.run(validation_iterator.string_handle())

    def set_session(self, sess):
        """Remember the session used to evaluate iterator handles."""
        self.sess = sess
--------------------------------------------------------------------------------
/directories.py:
--------------------------------------------------------------------------------
1 | import os
2 |
class Directories:
    """Read-only registry of the path fragments used by the project.

    Freshly constructed, most entries are relative fragments such as
    ``'/log/'``; ``update_for_task`` turns them into full paths.
    """

    def __init__(self):
        self.init_directories()

    def init_directories(self):
        """(Re)set every entry to its default, un-prefixed value."""
        self.__project_dir = os.getcwd()

        # Locations that update_for_task later prefixes.
        self.__networks_dir = '/networks/'
        self.__chkpt_dir = '/chkpts/'
        self.__log_dir = '/log/'
        self.__rgb_image_dir = '/FlyingChairs/data/'
        self.__flow_image_dir = '/FlyingChairs/data/'

        # Sub-directory names and file suffixes, never prefixed.
        self.__training_dir = 'training/'
        self.__validation_dir = 'validation/'
        self.__rgb_format = '.ppm'
        self.__flow_format = '.flo'

    def update_for_task(self, current_dir, network_dir):
        """Reset all entries, then prefix them for a concrete task.

        Args:
            current_dir: prefix for the networks and dataset directories.
            network_dir: prefix for the checkpoint and log directories.
        """
        self.init_directories()

        # Everything that lives under the project checkout.
        self.__rgb_image_dir = current_dir + self.__rgb_image_dir
        self.__flow_image_dir = current_dir + self.__flow_image_dir
        self.__networks_dir = current_dir + self.__networks_dir

        # Everything that lives under the selected network.
        self.__log_dir = network_dir + self.__log_dir
        self.__chkpt_dir = network_dir + self.__chkpt_dir

    # Read-only accessors #
    #---------------------#

    @property
    def project_dir(self):
        """Working directory captured by the last init_directories call."""
        return self.__project_dir

    @property
    def networks_dir(self):
        """Directory holding the networks."""
        return self.__networks_dir

    @property
    def chkpt_dir(self):
        """Checkpoint directory."""
        return self.__chkpt_dir

    @property
    def log_dir(self):
        """Log directory."""
        return self.__log_dir

    @property
    def rgb_image_dir(self):
        """Directory containing the RGB images."""
        return self.__rgb_image_dir

    @property
    def flow_image_dir(self):
        """Directory containing the flow files."""
        return self.__flow_image_dir

    @property
    def image_data_filename(self):
        # NOTE(review): nothing in this class assigns the backing attribute,
        # so reading this property raises AttributeError.
        return self.__image_data_filename

    @property
    def rgb_format(self):
        """File suffix of RGB images."""
        return self.__rgb_format

    @property
    def flow_format(self):
        """File suffix of flow files."""
        return self.__flow_format

    @property
    def training_dir(self):
        """Name of the training sub-directory."""
        return self.__training_dir

    @property
    def validation_dir(self):
        """Name of the validation sub-directory."""
        return self.__validation_dir
66 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import faulthandler
3 | import numpy as np
4 | from task import Task
5 |
def main():
    """Seed NumPy, enable faulthandler, ask for a mode and run the task.

    The prompt repeats until the user enters ``t`` (train) or ``v``
    (visualise).
    """
    np.random.seed(1)

    faulthandler.enable()

    prompt = '\n\ntrain (t), visualise (v) ?\n'
    mode = input(prompt)
    while mode not in ('t', 'v'):
        mode = input(prompt)

    task = Task()
    task.run(train=(mode == 't'))


if __name__ == '__main__':
    main()
27 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import math
4 | import numpy as np
5 | import custom_ops.native as native_ops
6 | import custom_ops.compiled as compiled_ops
7 |
class Dims:
    """Fixed tensor dimensions used by the network."""

    def __init__(self):
        self.input_vector = 0
        # Listed in the order they are later passed to reshape: two spatial sizes, then channels.
        self.input_image = [384, 448, 3]
        self.output = [384, 448, 2]
13 |
14 | class Network():
15 |
16 | def __init__(self, data_loader, dirs):
17 |
18 | self.data_loader = data_loader
19 | self.chkpt_dir = dirs.chkpt_dir
20 | self.log_dir = dirs.log_dir
21 | self.log_dir_training = dirs.log_dir + 'training/'
22 | self.log_dir_validation = dirs.log_dir + 'validation/'
23 |
24 | self.dims = Dims()
25 |
26 | self.global_step = tf.placeholder(tf.int32)
27 |
28 | self.window_size = 11
29 | self.max_dist = int(math.floor(self.window_size/2))
30 |
31 | self.batch_size = 8
32 |
33 | self.initial_learning_rate = 2e-4
34 | self.learning_decrement_rate = 0.2e6
35 | self.learning_decrement = 0.5
36 | self.min_learning_rate = 6.25e-6
37 | self.max_learning_rate = 1e-4
38 |
39 | self.total_iterations = 1e6
40 |
41 | self.validation_ratio = 0.1
42 | self.num_training_examples = int(round(self.data_loader.total_num_examples * (1 - self.validation_ratio)))
43 | self.num_validation_examples = self.data_loader.total_num_examples - self.num_training_examples
44 | self.total_num_examples = self.num_training_examples + self.num_validation_examples
45 |
46 | self.define_placeholders()
47 |
48 | @staticmethod
49 | def directory():
50 | return os.path.dirname(os.path.realpath(__file__))
51 |
52 |
53 | # Operations #
54 | #------------#
55 |
def warp(self, ct, w):
    """Warp ``ct`` with ``w`` by delegating to ``native_ops.image_warp``."""
    warped = native_ops.image_warp(ct, w)
    return warped
58 |
def cost_volume(self, cwt, ctm1):
    """Correlate ``cwt`` with ``ctm1`` using the compiled correlation op.

    The op is always called with the same fixed settings (listed below).
    """
    correlation_settings = dict(pad=4, kernel_size=1, max_displacement=4, stride_1=1, stride_2=1)
    return compiled_ops.correlation(cwt, ctm1, **correlation_settings)
61 |
62 |
63 | # Inputs #
64 | #--------#
65 |
def define_placeholders(self):
    """Create the feedable control placeholders, in a fixed order.

    ``data_type`` (string), ``vis_mode`` (bool) and ``step_placeholder`` (int32)
    are stored as attributes of the same name.
    """
    placeholder_specs = (
        ('data_type', tf.string),
        ('vis_mode', tf.bool),
        ('step_placeholder', tf.int32),
    )
    for attribute, dtype in placeholder_specs:
        setattr(self, attribute, tf.placeholder(dtype))
71 |
def define_data_loader(self):
    """Take the loader's next element and cast its first two entries to float32."""
    self.loaded_data = list(self.data_loader.next_element)

    raw_images = self.loaded_data[0]
    raw_flows = self.loaded_data[1]

    self.loaded_images = tf.cast(raw_images, tf.float32)

    # keep only index 0 along axis 1 (as a length-1 slice) and divide the values by 20
    first_flow = raw_flows[:, 0:1]
    self.loaded_flow_images = tf.cast(first_flow, tf.float32) / 20
77 |
def define_network_inputs(self):
    """Build the randomly cropped image pairs and the matching flow image.

    The rgb frames and the flow are concatenated along the last axis so a
    single crop is applied to all of them, then split apart again.

    The random horizontal offset is drawn once per image PAIR. Previously the
    crop was mapped over the flattened (batch * 2) frames, so each frame of a
    pair received its own independent offset and the flow no longer lined up
    with the second frame.

    Sets ``x_images`` (axes reordered with [0,1,4,2,3]) and ``x_flow_images``
    (the flow channels of the first frame of every pair).
    """
    source_h, source_w = 384, 512
    crop_h, crop_w = self.dims.input_image[0], self.dims.input_image[1]

    # rgb images
    self.stacked_images = tf.reshape(self.loaded_images, (-1, source_h, source_w, 3))

    # flow images: a zero flow is appended so both frames of a pair carry flow channels
    self.appended_flow_images = tf.concat(
        (self.loaded_flow_images, tf.zeros((self.batch_size, 1, source_h, source_w, 2))), 1)
    self.stacked_flow_images = tf.reshape(self.appended_flow_images, (-1, source_h, source_w, 2))

    # combined images
    self.stacked_combined_images = tf.concat((self.stacked_images, self.stacked_flow_images), -1)

    # regroup the frames into pairs so both frames share one crop window
    paired_combined_images = tf.reshape(self.stacked_combined_images,
                                        (self.batch_size, 2, source_h, source_w, 5))

    def crop_pair(pair):
        # pair holds the two frames of one example; one offset is drawn for both
        offset_w = tf.random_uniform((), 0, source_w - crop_w, tf.int32)
        return tf.image.crop_to_bounding_box(pair, 0, offset_w, crop_h, crop_w)

    self.unstacked_combined_images = tf.map_fn(crop_pair, paired_combined_images)

    # flattened view of the crops, kept for parity with the previous attribute
    self.cropped_combined_images = tf.reshape(self.unstacked_combined_images,
                                              (-1, crop_h, crop_w, 5))

    # rgb images
    self.unstacked_images = self.unstacked_combined_images[:, :, :, :, 0:3]
    self.x_images = tf.transpose(self.unstacked_images, [0, 1, 4, 2, 3])

    # flow images
    self.x_flow_images = self.unstacked_combined_images[:, 0, :, :, 3:5]
102 |
def define_network_targets(self):
    """Create five ground-truth flow targets, smallest resolution first.

    Target ``i`` is ``x_flow_images`` resized to the input size divided by
    2 ** (6 - i), with the last axis moved to position 1.
    """
    full_h = self.dims.input_image[0]
    full_w = self.dims.input_image[1]

    # ground truth outputs
    self.gt_ys = []
    for level in range(5):
        divisor = math.pow(2, 6 - level)
        target_size = (int(full_h / divisor), int(full_w / divisor))
        resized_flow = tf.image.resize_images(self.x_flow_images, target_size)
        self.gt_ys.append(tf.transpose(resized_flow, [0, 3, 1, 2]))
111 |
112 |
113 | # Architecture #
114 | #--------------#
115 |
def feature_pyramid_forward_pass(self, I):
    """Run ``I`` through six stages of two 3x3 convolutions each.

    Every stage is a stride-2 convolution followed by a stride-1 convolution
    with the same filter count (16, 32, 64, 64, 64, 64), all channels-first
    with leaky-relu activation.

    Returns the six stage outputs ordered from the last stage to the first.

    NOTE(review): no ``name``/``reuse`` is passed to the layers, so each call
    presumably creates its own variables - confirm whether the two frames are
    meant to share weights.
    """
    def conv3x3(inputs, filters, stride):
        return tf.layers.conv2d(inputs, filters, [3, 3], (stride, stride), 'same',
                                'channels_first', (1, 1), tf.nn.leaky_relu)

    stage_outputs = []
    current = I
    for filters in (16, 32, 64, 64, 64, 64):
        downsampled = conv3x3(current, filters, 2)
        current = conv3x3(downsampled, filters, 1)
        stage_outputs.append(current)

    return stage_outputs[::-1]
137 |
def flow_estimator_forward_pass(self, ctm1, ct, w):
    """Estimate flow features and a 2-channel flow from two feature maps.

    ``ct`` is warped with ``w`` (both temporarily moved to channels-last for
    the warp op), correlated against ``ctm1``, and the cost volume, ``w`` and
    ``ctm1`` are concatenated along axis -3 before a stack of 3x3 convolutions.

    Returns ``(ft, wt)``: the 32-filter feature layer and the 2-filter layer on top of it.
    """
    to_channels_last = [0, 2, 3, 1]
    to_channels_first = [0, 3, 1, 2]

    def conv3x3(inputs, filters):
        return tf.layers.conv2d(inputs, filters, [3, 3], (1, 1), 'same',
                                'channels_first', (1, 1), tf.nn.leaky_relu)

    ct_last = tf.transpose(ct, to_channels_last)
    w_last = tf.transpose(w, to_channels_last)
    warped_last = self.warp(ct_last, w_last)
    cwt = tf.transpose(warped_last, to_channels_first)

    cvt = self.cost_volume(cwt, ctm1)

    hidden = tf.concat((cvt, w, ctm1), -3)
    for filters in (128, 128, 96, 64, 32):
        hidden = conv3x3(hidden, filters)

    ft = hidden
    wt = conv3x3(ft, 2)

    return ft, wt
159 |
def context_forward_pass(self, ft, wt):
    """Refine ``wt`` with a stack of dilated 3x3 convolutions.

    ``ft`` and ``wt`` are concatenated along axis -3 and passed through seven
    channels-first leaky-relu convolutions whose (filters, dilation) pairs are
    listed below. The result of the last layer is added to ``wt``.
    """
    layer_specs = (
        (128, 1),
        (128, 2),
        (128, 4),
        (96, 8),
        (64, 16),
        (32, 1),
        (2, 1),
    )

    hidden = tf.concat((ft, wt), -3)
    for filters, dilation in layer_specs:
        hidden = tf.layers.conv2d(hidden, filters, [3, 3], (1, 1), 'same',
                                  'channels_first', (dilation, dilation), tf.nn.leaky_relu)

    return hidden + wt
173 |
def define_network_structure(self):
    """Assemble the coarse-to-fine flow network from the two feature pyramids.

    For every pyramid level the previous level's flow is upsampled by 2,
    a flow is estimated from the two feature maps, refined by the context
    network, and the sum is stored as that level's flow term. The first
    level starts from an all-zero flow.

    ``network_outputs`` holds one flow term per level (the initial zero flow
    in ``flow_terms[0]`` is excluded).
    """
    first_frame = self.x_images[:, 0]
    second_frame = self.x_images[:, 1]

    self.features1 = self.feature_pyramid_forward_pass(first_frame)
    self.features2 = self.feature_pyramid_forward_pass(second_frame)

    self.flow_terms = []
    self.upsampled_flow_terms = []
    self.flow_features = []

    for level, (feature1, feature2) in enumerate(zip(self.features1, self.features2)):

        if level == 0:
            # nothing to upsample yet: start from a zero flow
            wt_upsampled = tf.zeros((self.batch_size, 2, 6, 7))
            wt_upsampled_scaled = wt_upsampled
            self.flow_terms.append(wt_upsampled)
        else:
            previous_flow = self.flow_terms[level]
            doubled_size = (int(previous_flow.shape[-2] * 2), int(previous_flow.shape[-1] * 2))
            # resize_images works on channels-last data, so transpose around it
            previous_last = tf.transpose(previous_flow, [0, 2, 3, 1])
            upsampled_last = tf.image.resize_images(previous_last, doubled_size)
            wt_upsampled = tf.transpose(upsampled_last, [0, 3, 1, 2])
            wt_upsampled_scaled = wt_upsampled * 20 / math.pow(2, 6 - level)

        self.upsampled_flow_terms.append(wt_upsampled)

        flow_features, estimated_flow = self.flow_estimator_forward_pass(
            feature1, feature2, wt_upsampled_scaled)
        self.flow_features.append(flow_features)

        refined_flow = self.context_forward_pass(flow_features, estimated_flow)

        self.flow_terms.append(wt_upsampled + estimated_flow + refined_flow)

    self.network_outputs = self.flow_terms[1:]
211 |
212 |
213 | # Build #
214 | #-------#
215 |
def build_model(self):
    """Build the whole graph in dependency order, then initialise all variables.

    The initialiser is executed with ``.run()``, which relies on a default
    session being installed.
    """
    construction_steps = (
        self.define_data_loader,
        self.define_network_inputs,
        self.define_network_targets,
        self.define_network_structure,
        self.define_cost,
        self.define_optimiser,
        self.define_optimisation,
    )
    for construct in construction_steps:
        construct()

    tf.global_variables_initializer().run()
229 |
def load_params(self, sess, saver, chkpt_num=None, test=False):
    """Restore variables from a checkpoint and report where to resume.

    sess      -- session the variables are restored into
    saver     -- saver used for the restore
    chkpt_num -- step of the checkpoint to load (``model-<chkpt_num>``), or
                 None to load the latest checkpoint in ``self.chkpt_dir``
    test      -- unused; kept so existing callers keep working

    Returns ``(success, starting_iteration)``. ``(False, 0)`` is returned when
    nothing could be restored. After a successful restore the starting
    iteration is the checkpoint's step + 1, or 0 when the checkpoint name
    carries no numeric step (e.g. a checkpoint saved as plain ``model``).
    """
    # NOTE(review): presumably forces tf.contrib's lazily loaded ops to be
    # registered before restoring - confirm before removing.
    dir(tf.contrib)

    try:
        if chkpt_num is None:
            chkpt_path = tf.train.latest_checkpoint(self.chkpt_dir)
            if chkpt_path is None:
                return False, 0
        else:
            chkpt_path = self.chkpt_dir + 'model-' + str(chkpt_num)
        saver.restore(sess, chkpt_path)
    except Exception:
        # best effort: any restore failure means "start from scratch"
        return False, 0

    if chkpt_num is not None:
        return True, chkpt_num + 1

    # Only the text after the LAST '-' of the file name is the step, so a '-'
    # elsewhere in the checkpoint path cannot corrupt the parse.
    step_text = os.path.basename(chkpt_path).rsplit('-', 1)[-1]
    starting_it = int(step_text) + 1 if step_text.isdigit() else 0
    return True, starting_it
246 |
247 |
248 | # Optimisation #
249 | #--------------#
250 |
def define_cost(self):
    """Define the training cost: weighted squared flow error plus L2 weight decay.

    The first five network outputs are compared against the five ground-truth
    targets; each level's summed squared error is weighted by its own alpha.
    The L2 loss of every trainable variable, scaled by 0.0004, is added on top.
    """
    self.n_targets = self.network_outputs

    # cost
    level_weights = [0.32, 0.08, 0.02, 0.01, 0.005]
    self.cost_target = tf.constant([0.])
    for alpha, prediction, target in zip(level_weights, self.n_targets, self.gt_ys):
        squared_error = tf.pow(prediction - target, 2)
        self.cost_target += alpha * tf.reduce_sum(squared_error)

    # weight decay over all trainable variables
    self.regularisation_loss = tf.constant([0.])
    for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        self.regularisation_loss += tf.nn.l2_loss(variable)
    factor_l2 = 0.0004

    self.cost_target += factor_l2 * self.regularisation_loss
    self.cost_aux = tf.identity(0.0)
    total_cost = tf.add(self.cost_target, self.cost_aux)
    self.cost = tf.squeeze(total_cost)

    self.cost_last_step = self.cost
273 |
def define_optimiser(self):
    """Create the Adam optimiser with a stair-cased, clipped learning rate.

    The exponentially decayed rate is bounded below by ``min_learning_rate``
    and above by ``max_learning_rate``.
    """
    decayed_rate = tf.train.exponential_decay(self.initial_learning_rate, self.global_step,
                                              self.learning_decrement_rate, self.learning_decrement,
                                              staircase=True)
    floored_rate = tf.maximum(decayed_rate, self.min_learning_rate)
    self.learning_rate = tf.minimum(floored_rate, self.max_learning_rate)

    self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
282 |
def define_optimisation(self):
    """Create the train step, switched on ``vis_mode``, and the ``finished`` flag.

    NOTE(review): the update op is created before ``tf.cond`` rather than
    inside a branch function - verify that it is really skipped when
    ``vis_mode`` is fed True.
    """
    self.updating_train_step = self.optimizer.minimize(self.cost)

    def skip_update():
        return True

    def apply_update():
        return self.updating_train_step

    self.train_step = tf.cond(self.vis_mode, skip_update, apply_update)

    self.finished = tf.constant(False, tf.bool)
289 |
290 |
291 | # Log Summaries #
292 | #---------------#
293 |
def init_summary(self):
    """Create the summary ops for logging, visualisation and performance stats.

    ``log_summary_op`` merges cost and learning rate. ``vis_summary_op`` is a
    plain ``tf.Summary`` object that ``get_summary`` recognises by identity.
    ``perflog_summary_op`` merges gpu memory plus the four fed usage values.
    """
    scalar = tf.summary.scalar

    self.cost_summary = scalar('cost', self.cost)
    self.learning_rate_summary = scalar('learning_rate', self.learning_rate)
    self.log_summary_op = tf.summary.merge([self.cost_summary, self.learning_rate_summary])

    # index of the batch element that gets visualised
    self.random_batch_num = tf.random_uniform([], 0, self.batch_size, tf.int32)
    self.vis_summary_op = tf.Summary()

    self.gpu_usage = tf.placeholder(tf.int32)
    self.cpu_usage = tf.placeholder(tf.int32)
    self.ram_usage = tf.placeholder(tf.float32)
    self.throughput = tf.placeholder(tf.float32)

    # bytes / 8e7 is a percentage of 8e9 bytes - presumably an 8 GB device; confirm
    bytes_in_use = tf.cast(tf.contrib.memory_stats.BytesInUse(), tf.float32)
    gpu_memory_percent = tf.divide(bytes_in_use, tf.constant(8e7))

    performance_summaries = [
        scalar('percent gpu memory', gpu_memory_percent),
        scalar('percent gpu usage', self.gpu_usage),
        scalar('percent cpu usage', self.cpu_usage),
        scalar('percent ram usage', self.ram_usage),
        scalar('throughput MB', self.throughput),
    ]
    self.perflog_summary_op = tf.summary.merge(performance_summaries)
320 |
def get_image_summary(self, sess, dict_feed, fps):
    """Evaluate one randomly chosen batch element and turn it into a gif summary.

    sess      -- session used to evaluate the tensors
    dict_feed -- feed dictionary; ``step_placeholder`` is set to 0 in place
    fps       -- frame rate forwarded to the gif summary

    Returns the value of ``native_ops.convert_array_to_gif_summary``.

    The three indexed tensors are created once and cached on the instance.
    Indexing a tensor adds new ops to the graph, so building them on every
    call would grow the graph with each visualisation.
    """
    dict_feed[self.step_placeholder] = 0

    fetches = getattr(self, '_image_summary_fetches', None)
    if fetches is None:
        batch_index = self.random_batch_num
        fetches = (self.x_images[batch_index],
                   self.x_flow_images[batch_index],
                   self.network_outputs[-1][batch_index])
        self._image_summary_fetches = fetches

    # all three are fetched in one run so they share the same random index
    x_images_out, gt_flow_out, predicted_flow_out = sess.run(fetches, dict_feed)

    # move the axis at position 1 (images) / 0 (flow) to the end
    x_images_trans = np.transpose(x_images_out, (0, 2, 3, 1))
    predicted_flow_trans = np.transpose(predicted_flow_out, (1, 2, 0))

    images_arr = native_ops.modify_images_for_vis(x_images_trans, gt_flow_out, predicted_flow_trans)

    return native_ops.convert_array_to_gif_summary(images_arr, tag='images', fps=fps)
335 |
def get_log_summary(self, sess, dict_feed):
    """Evaluate ``log_summary_op`` in ``sess`` with ``dict_feed`` and return the result."""
    serialized_summary = sess.run(self.log_summary_op, dict_feed)
    return serialized_summary
338 |
def get_summary(self, sess, summary_op, dict_feed):
    """Dispatch on ``summary_op`` (compared by identity) to the matching getter.

    The visualisation op yields an image summary at 2 fps, the log op yields
    the log summary; any other op yields None.
    """
    if summary_op is self.vis_summary_op:
        frames_per_second = 2
        return self.get_image_summary(sess, dict_feed, frames_per_second)
    if summary_op is self.log_summary_op:
        return self.get_log_summary(sess, dict_feed)
    return None
345 |
def write_summary(self, step, summary, writer):
    """Add ``summary`` to ``writer`` at ``step`` and flush it immediately."""
    writer.add_summary(summary, global_step=step) if False else writer.add_summary(summary, step)
    writer.flush()
349 |
def write_summaries(self, sess, i, dict_feed, summary_op, data_handles, summary_writers):
    """Write one summary for the training data and one for the validation data.

    ``data_handles`` holds (training, validation) handles and
    ``summary_writers`` the writers in the same order. ``dict_feed`` is
    modified in place: its ``data_type`` entry ends up set to the
    validation handle.
    """
    training_handle = data_handles[0]
    validation_handle = data_handles[1]

    for handle, writer_index in ((training_handle, 0), (validation_handle, 1)):
        dict_feed[self.data_type] = handle
        summary = self.get_summary(sess, summary_op, dict_feed)
        self.write_summary(i, summary, summary_writers[writer_index])
364 |
--------------------------------------------------------------------------------
/readme_images/example_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_loss.png
--------------------------------------------------------------------------------
/readme_images/example_training_flow1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_training_flow1.gif
--------------------------------------------------------------------------------
/readme_images/example_training_flow2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_training_flow2.gif
--------------------------------------------------------------------------------
/readme_images/example_training_flow3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_training_flow3.gif
--------------------------------------------------------------------------------
/readme_images/example_training_flow4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_training_flow4.gif
--------------------------------------------------------------------------------
/readme_images/example_validation_flow1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_validation_flow1.gif
--------------------------------------------------------------------------------
/readme_images/example_validation_flow2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_validation_flow2.gif
--------------------------------------------------------------------------------
/readme_images/example_validation_flow3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_validation_flow3.gif
--------------------------------------------------------------------------------
/readme_images/example_validation_flow4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_validation_flow4.gif
--------------------------------------------------------------------------------
/task.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from trainer import Trainer
4 | from data_loader import DataLoader
5 | from directories import Directories
6 | from network import Network
7 |
class Task():
    """Wire together directories, data loader, network and trainer, and run them."""

    def __init__(self):
        """Create the directories object, then the data loader, then the network."""
        self._dirs = Directories()
        self._init_data_loader()
        self._init_network()

    def _update_dirs(self, log_folder):
        """Point the directories at this file's folder and the network's log folder.

        log_folder -- suffix appended to the network's source directory
        """
        current_dir = os.path.dirname(os.path.realpath(__file__))
        network_dir = Network.directory() + log_folder
        self._dirs.update_for_task(current_dir, network_dir)

    def _init_network(self):
        """Create the network from the data loader and directories."""
        self.__network = Network(data_loader=self.__data_loader,
                                 dirs=self._dirs)

    def _init_data_loader(self):
        """Create the data loader with the fixed total number of examples."""
        self.__data_loader = DataLoader(dirs=self._dirs,
                                        total_num_examples=22872)

    def _init_trainer(self):
        """Create the trainer with its fixed save/log/vis frequencies and return it."""
        self.__trainer = Trainer(data_loader=self.__data_loader,
                                 network=self.__network,
                                 dirs=self._dirs,
                                 ld_chkpt=True,
                                 save_freq=10000,
                                 log_freq=200,
                                 vis_freq=5000)
        return self.__trainer

    def run(self, train):
        """Train when ``train`` is truthy, otherwise visualise.

        The trainer's session is closed afterwards in every case, including
        when training or visualisation raises, so it is never left open.
        """
        self._init_trainer()
        log_folder = '/logged_data'

        self._update_dirs(log_folder)
        self.__network.chkpt_dir = self._dirs.chkpt_dir

        try:
            if train:
                self.__trainer.run_trainer()
            else:
                self.__trainer.run_visualiser()
        finally:
            self.__trainer.sess.close()
56 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf # add tensorflow framework
2 | import os
3 | import libtmux
4 | import pathlib
5 |
6 | class Trainer():
7 |
def __init__(self, data_loader, network, dirs, ld_chkpt, save_freq, log_freq,
             vis_freq):
    """Open an interactive session and store the collaborators and settings.

    data_loader -- receives the new session through ``set_session``
    network     -- provides the training / validation example counts
    dirs        -- provides ``networks_dir``; also kept for later lookups
    ld_chkpt    -- whether a checkpoint should be loaded before training
    save_freq, log_freq, vis_freq -- step intervals for saving, logging, visualising
    """
    session_config = tf.ConfigProto(log_device_placement=False)
    self.sess = tf.InteractiveSession(config=session_config)

    self.data_loader = data_loader
    self.data_loader.set_session(self.sess)

    self.network = network
    self.num_training_examples = self.network.num_training_examples
    self.num_validation_examples = self.network.num_validation_examples
    self.total_num_examples = self.num_training_examples + self.num_validation_examples

    self.dirs = dirs
    self.networks_dir = dirs.networks_dir
    self.ld_chkpt = ld_chkpt

    self.save_freq, self.log_freq, self.vis_freq = save_freq, log_freq, vis_freq
28 |
def init_saver(self):
    """Create a saver for all global variables except those named ``placeholder...``.

    Also makes sure the checkpoint directory exists.
    """
    self.chkpt_dir = self.dirs.chkpt_dir

    variables_to_save = [
        candidate
        for candidate in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        if not candidate.name.startswith('placeholder')
    ]
    self.saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=None)

    pathlib.Path(self.chkpt_dir).mkdir(parents=True, exist_ok=True)
35 |
def initial_save(self):
    """Save an un-numbered ``model`` checkpoint unless ``model.meta`` already exists."""
    meta_file = self.chkpt_dir + 'model.meta'
    if os.path.exists(meta_file):
        return
    self.saver.save(self.sess, self.chkpt_dir + 'model')
39 |
def make_log_dirs(self):
    """Create every directory in ``self.log_dirs``, including missing parents."""
    for directory in map(pathlib.Path, self.log_dirs):
        directory.mkdir(parents=True, exist_ok=True)
43 |
def init_log_dirs(self):
    """Derive the plot and visualisation log directories and create them.

    Paths are ``<log_dir>plots/<sub>`` and ``<log_dir>vis/<sub>`` for the
    training and validation sub-directories given by ``self.dirs``.
    """
    self.log_dir = self.dirs.log_dir

    plots_root = self.log_dir + 'plots/'
    vis_root = self.log_dir + 'vis/'
    training_sub = self.dirs.training_dir
    validation_sub = self.dirs.validation_dir

    self.log_dir_train = plots_root + training_sub
    self.log_dir_valid = plots_root + validation_sub

    self.vis_dir_train = vis_root + training_sub
    self.vis_dir_valid = vis_root + validation_sub

    self.log_dirs = [self.log_dir_train, self.log_dir_valid,
                     self.vis_dir_train, self.vis_dir_valid]

    self.make_log_dirs()
61 |
def init_log_summary_writers(self):
    """Create the training / validation log writers and store them in slots 0 and 1."""
    self.sw_log_train = tf.summary.FileWriter(self.log_dir_train, graph=tf.get_default_graph())
    self.sw_log_valid = tf.summary.FileWriter(self.log_dir_valid, graph=tf.get_default_graph())

    for slot, writer in enumerate((self.sw_log_train, self.sw_log_valid)):
        self.summary_writers[slot] = writer
69 |
def init_vis_summary_writers(self):
    """Create the training / validation visualisation writers and store them in slots 2 and 3."""
    self.sw_vis_train = tf.summary.FileWriter(self.vis_dir_train, graph=tf.get_default_graph())
    self.sw_vis_valid = tf.summary.FileWriter(self.vis_dir_valid, graph=tf.get_default_graph())

    for slot, writer in enumerate((self.sw_vis_train, self.sw_vis_valid), start=2):
        self.summary_writers[slot] = writer
77 |
def init_summary_writers(self):
    """Initialise the network's summary ops, then fill the four writer slots."""
    self.network.init_summary()

    # placeholder entries; replaced by the log (0, 1) and vis (2, 3) writers below
    self.summary_writers = [0, 0, 0, 0]

    for fill_slots in (self.init_log_summary_writers, self.init_vis_summary_writers):
        fill_slots()
85 |
def init_logger(self):
    """Prepare logging: create the log directories first, then the summary writers."""
    for prepare in (self.init_log_dirs, self.init_summary_writers):
        prepare()
89 |
def init_dataset_loader(self, i):
    """Run the training and then the validation dataset initialisers.

    i -- value fed for the network's ``global_step`` in both runs
    """
    loader = self.data_loader
    handle_key = self.network.data_type

    dict_feed = {handle_key: loader.training_handle, self.network.global_step: i}
    self.sess.run(loader.training_init_op, dict_feed)

    # same feed, now pointing at the validation data
    dict_feed[handle_key] = loader.validation_handle
    self.sess.run(loader.validation_init_op, dict_feed)
97 |
def write_summaries(self, i, dict_feed, summary_op, summary_writers):
    """Forward to the network's ``write_summaries`` with the session and both data handles."""
    loader = self.data_loader
    handles = (loader.training_handle, loader.validation_handle)
    self.network.write_summaries(self.sess, i, dict_feed, summary_op, handles, summary_writers)
101 |
def log(self, i):
    """Write the log summaries for step ``i`` with writers 0-1 and report it on stdout."""
    log_writers = self.summary_writers[0:2]
    self.write_summaries(i, self.dict_feed, self.network.log_summary_op, log_writers)
    print(f'logged, step {i}')
105 |
def vis(self, i):
    """Write the visualisation summaries for step ``i`` with writers 2-3."""
    vis_writers = self.summary_writers[2:4]
    self.write_summaries(i, self.dict_feed, self.network.vis_summary_op, vis_writers)
108 |
def save(self, i):
    """Save a checkpoint for step ``i`` (without the meta graph) and report it on stdout."""
    checkpoint_prefix = self.chkpt_dir + 'model'
    self.saver.save(self.sess, checkpoint_prefix, global_step=i, write_meta_graph=False)
    print(f'saved, step {i}')
112 |
def train(self, starting_it, vis_mode=False):
    """Run the main loop from ``starting_it`` until ``total_iterations``.

    starting_it -- first global step to execute
    vis_mode    -- when True, visualise every step, skip logging and saving,
                   and wait for the user to press enter after each step

    A ``total_iterations`` of -1 makes the loop run without an upper bound.
    Always returns True.
    """
    net = self.network
    global_step = starting_it

    self.dict_feed = {}
    self.dict_feed[net.vis_mode] = vis_mode
    self.dict_feed[net.global_step] = global_step
    self.dict_feed[net.data_type] = self.data_loader.training_handle

    # nothing left to do when resuming exactly at the end
    if starting_it == net.total_iterations:
        return True

    if vis_mode:
        self.vis_freq = 1
        self.dict_feed[net.vis_mode] = True

    while global_step < net.total_iterations or net.total_iterations == -1:

        self.dict_feed.update({
            net.global_step: global_step,
            net.data_type: self.data_loader.training_handle,
            net.step_placeholder: 0,
        })

        fetches = (net.train_step, net.finished, net.cost_last_step)
        _, _finished, _cost = self.sess.run(fetches, self.dict_feed)

        if global_step % self.log_freq == 0 and not vis_mode:
            self.log(global_step)
        if global_step % self.vis_freq == 0:
            self.vis(global_step)
        if global_step % self.save_freq == 0 and not vis_mode:
            self.save(global_step)

        global_step += 1

        if vis_mode:
            input('press enter to visualise another example')

    return True
145 |
def start_tensorboard(self):
    """Restart tensorboard inside a detached tmux session named ``tensorboard``.

    Any existing session is stopped first. The tensorboard command is typed
    into a new pane and reads from ``<cwd>/logged_data`` on port 6006.
    """
    self.stop_tensorboard()

    server = libtmux.Server()
    session_name = 'tensorboard'
    os.system(f'tmux new-session -s {session_name} -d')

    session = server.find_where({'session_name': session_name})
    self.tmux_window = session.attached_window
    pane = self.tmux_window.split_window(attach=False)

    port = 6006
    log_root = os.getcwd() + '/logged_data'
    pane.send_keys(f'tensorboard --logdir={log_root} --port {port}')
156 |
def stop_tensorboard(self):
    """Kill the window of the ``tensorboard`` tmux session, if there is one.

    This is best effort: when no such session exists, or tmux cannot be
    reached, the method returns quietly. Only ordinary exceptions are
    suppressed, so KeyboardInterrupt and SystemExit still propagate.
    """
    try:
        server = libtmux.Server()
        session_name = 'tensorboard'
        session = server.find_where({'session_name': session_name})
        if session is None:
            # no tensorboard session is running: nothing to stop
            return
        self.tmux_window = session.attached_window
        self.tmux_window.kill_window()
    except NameError:
        print('No session was running')
    except Exception:
        # stopping tensorboard must never prevent training from starting
        return
168 |
def prime_data(self):
    """Ask the data loader to prime the training and validation index ranges.

    Training covers indices 0 .. num_training_examples - 1 and validation
    covers num_training_examples .. total_num_examples - 1.
    """
    first_validation_index = self.num_training_examples
    priming_arguments = dict(
        start_it_train=0,
        end_it_train=first_validation_index - 1,
        start_it_valid=first_validation_index,
        end_it_valid=self.total_num_examples - 1,
        batch_size=self.network.batch_size,
        data_type=self.network.data_type,
    )
    self.data_loader.prime_data(**priming_arguments)
177 |
def __init_data_and_model(self):
    """Prime the data, build the model, optionally restore it, and start tensorboard.

    Returns the iteration to start from: the value reported by the network's
    ``load_params`` when checkpoint loading is enabled, otherwise 0.
    """
    self.prime_data()

    starting_iteration = 0
    self.network.build_model()
    self.init_saver()

    status_message = 'model built'
    if self.ld_chkpt is True:
        load_success, starting_iteration = self.network.load_params(self.sess, self.saver)
        if load_success is not False:
            status_message = 'model loaded'
    print(status_message)

    self.init_dataset_loader(starting_iteration)
    self.initial_save()
    self.start_tensorboard()

    return starting_iteration
199 |
def run_trainer(self):
    """Set up data, model and logging, then train from the resumed iteration.

    Returns whatever ``train`` returns.
    """
    print('started trainer')
    first_iteration = self.__init_data_and_model()
    self.init_logger()
    outcome = self.train(first_iteration)
    return outcome
205 |
def run_visualiser(self):
    """Set up data, model and logging, then run the loop in visualisation mode from step 0."""
    print('started visualiser')
    self.__init_data_and_model()
    self.init_logger()
    start_step, visualise = 0, True
    self.train(start_step, visualise)
211 |
--------------------------------------------------------------------------------