├── .gitignore ├── README.md ├── custom_ops ├── compiled.py ├── native.py └── source_code │ ├── correlation_op.cc │ ├── correlation_op.cu.cc │ ├── correlation_op.h │ ├── decode_flo_op.cc │ └── decode_ppm_op.cc ├── data_loader.py ├── directories.py ├── main.py ├── network.py ├── readme_images ├── example_loss.png ├── example_training_flow1.gif ├── example_training_flow2.gif ├── example_training_flow3.gif ├── example_training_flow4.gif ├── example_validation_flow1.gif ├── example_validation_flow2.gif ├── example_validation_flow3.gif └── example_validation_flow4.gif ├── task.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | FlyingChairs/ 2 | dynamic_libs/ 3 | logged_data/ 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PWC Net TensorFlow 2 | 3 | Tensorflow implementation of Pyramid, Warping and Cost Volume (PWC) Networks based on the [paper](https://arxiv.org/abs/1709.02371) presented at CVPR 2018.
4 | Currently, [main.py](https://github.com/djl11/PWC_Net_TensorFlow/blob/master/main.py) simply downloads the FlyingChairs Dataset and starts training, following the outlined [schedule](https://arxiv.org/abs/1709.02371).
5 | This code could easily be adapted to train on other datasets though.

6 | 7 | ## Tested Environment 8 | 9 | Ubuntu 16.04
10 | Python3
11 | TensorFlow 1.8
12 | Cuda 9.0
13 | 14 | ## Acknowledgements 15 | 16 | This repo uses 3 custom TensorFlow ops written in C++ 17 | 18 | The correlation op was taken from [this](https://github.com/simonmeister/UnFlow) TensorFlow implementation of UnFlow by Simon Meister
19 | The ppm and flo decoding ops were taken from [this](https://github.com/lmb-freiburg/lmbspecialops) collection of tf ops, from the Computer Vision Group, Albert-Ludwigs-Universität Freiburg
20 | 21 | ## Usage 22 | 23 | ```bash 24 | python3 main.py 25 | ``` 26 | 27 | A TensorBoard session will automatically be started in a new tmux window (so that the visualisations are still available after the Python session has ended).
28 | This TensorBoard session will log the training/validation losses, as well as GIFs of the flow as it trains. 29 | 30 | Some general hyperparameters regarding the logging of data can be changed through [task.py](https://github.com/djl11/PWC_Net_TensorFlow/blob/master/task.py)
31 | Other hyperparameters relating to the training schedule can be changed in the constructor of [network.py](https://github.com/djl11/PWC_Net_TensorFlow/blob/master/network.py)
32 | The default training/validation split is to have 90% training, with 10% left for validation.
33 | 34 | 35 | ## Example visualisations following training 36 | 37 | From left to right, the images below show the RGB image, ground-truth flow, predicted flow, and flow error
38 | 39 | Examples from the training set:
40 | 41 | ![Example Training Flow Result 1](readme_images/example_training_flow1.gif)
42 | ![Example Training Flow Result 2](readme_images/example_training_flow2.gif)
43 | ![Example Training Flow Result 3](readme_images/example_training_flow3.gif)
44 | ![Example Training Flow Result 4](readme_images/example_training_flow4.gif)
45 | 46 | Examples from the validation set:
47 | 48 | ![Example Validation Flow Result 1](readme_images/example_validation_flow1.gif)
49 | ![Example Validation Flow Result 2](readme_images/example_validation_flow2.gif)
50 | ![Example Validation Flow Result 3](readme_images/example_validation_flow3.gif)
51 | ![Example Validation Flow Result 4](readme_images/example_validation_flow4.gif)
52 | 53 | ## Example Training Loss 54 | 55 | This is an example of the loss when training on the full FlyingChairs dataset (no validation was used on this occasion).
56 | 57 | ![Example Loss](readme_images/example_loss.png)
58 | -------------------------------------------------------------------------------- /custom_ops/compiled.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | import subprocess 5 | from tensorflow.python.framework import ops 6 | 7 | # Register ops for compilation here 8 | OP_NAMES = ['correlation', 'decode_flo', 'decode_ppm'] 9 | 10 | cwd = os.getcwd() 11 | current_file_dir = os.path.dirname(os.path.realpath(__file__)) 12 | 13 | dynamic_lib_dir = 'dynamic_libs/' 14 | dynamic_lib_dir = current_file_dir + '/' + dynamic_lib_dir 15 | 16 | source_code_dir = 'source_code/' 17 | source_code_dir = current_file_dir + '/' + source_code_dir 18 | 19 | os.chdir(source_code_dir) 20 | 21 | if not os.path.isdir(dynamic_lib_dir): 22 | os.mkdir(dynamic_lib_dir) 23 | 24 | def compile(op=None): 25 | if op is not None: 26 | to_compile = [op] 27 | else: 28 | to_compile = OP_NAMES 29 | 30 | tf_cflags = " ".join(tf.sysconfig.get_compile_flags()) 31 | tf_lflags = " ".join(tf.sysconfig.get_link_flags()) 32 | for n in to_compile: 33 | 34 | print('\n\ncompiling custom operation: ' + n + '\n\n') 35 | 36 | base = n + "_op" 37 | fn_cu_cc = base + ".cu.cc" 38 | fn_cc = base + ".cc" 39 | fn_cu_o = dynamic_lib_dir + base + ".cu.o" 40 | fn_so = dynamic_lib_dir + base + ".so" 41 | 42 | out, err = subprocess.Popen(['which', 'nvcc'], stdout=subprocess.PIPE).communicate() 43 | cuda_dir = out.decode().split('/cuda')[0] 44 | 45 | 46 | if os.path.isfile(os.getcwd() + '/' + fn_cu_cc): 47 | nvcc_cmd = "nvcc -std=c++11 -c -o {} {} {} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -I " + cuda_dir + " --expt-relaxed-constexpr" 48 | nvcc_cmd = nvcc_cmd.format(" ".join([fn_cu_o, fn_cu_cc]), tf_cflags, tf_lflags) 49 | subprocess.check_output(nvcc_cmd, shell=True) 50 | 51 | gcc_cmd = "{} -std=c++11 -shared -o {} {} -fPIC -L " + cuda_dir + "/cuda/lib64 -lcudart {} -O2 -D GOOGLE_CUDA=1" 52 | gcc_cmd = gcc_cmd.format('g++'," 
".join([fn_so, fn_cu_o, fn_cc]), tf_cflags, tf_lflags) 53 | else: 54 | gcc_cmd = "{} -std=c++11 -shared {} -o {} -fPIC {} {} -O2" 55 | gcc_cmd = gcc_cmd.format('g++', fn_cc, fn_so, tf_cflags, tf_lflags) 56 | print('gcc_cmd: ' + gcc_cmd) 57 | subprocess.check_output(gcc_cmd, shell=True) 58 | 59 | 60 | module = sys.modules[__name__] 61 | for n in OP_NAMES: 62 | lib_path = './{}_op.so'.format(n) 63 | try: 64 | os.chdir(dynamic_lib_dir) 65 | op_lib = tf.load_op_library(lib_path) 66 | except: 67 | os.chdir(source_code_dir) 68 | compile(n) 69 | os.chdir(dynamic_lib_dir) 70 | op_lib = tf.load_op_library(lib_path) 71 | setattr(module, '_' + n + '_module', op_lib) 72 | 73 | os.chdir(cwd) 74 | 75 | # functions # 76 | #-----------# 77 | 78 | def correlation(first, second, **kwargs): 79 | return _correlation_module.correlation(first, second, **kwargs)[0] 80 | 81 | decode_flo = _decode_flo_module.decode_flo 82 | decode_ppm = _decode_ppm_module.decode_ppm 83 | 84 | # Register op gradients 85 | 86 | @ops.RegisterGradient("Correlation") 87 | def _CorrelationGrad(op, in_grad, in_grad1, in_grad2): 88 | grad0, grad1 = _correlation_module.correlation_grad( 89 | in_grad, op.inputs[0], op.inputs[1], 90 | op.outputs[1], op.outputs[2], 91 | kernel_size=op.get_attr('kernel_size'), 92 | max_displacement=op.get_attr('max_displacement'), 93 | pad=op.get_attr('pad'), 94 | stride_1=op.get_attr('stride_1'), 95 | stride_2=op.get_attr('stride_2')) 96 | return [grad0, grad1] 97 | 98 | ops.NotDifferentiable("DecodeFlo") 99 | ops.NotDifferentiable("DecodePpm") -------------------------------------------------------------------------------- /custom_ops/native.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import moviepy.editor as mpy 3 | import tempfile 4 | import cv2 5 | import numpy as np 6 | 7 | min_denominator = 1e-12 8 | 9 | # Image Warp # 10 | #------------# 11 | 12 | def image_warp(im, flow): 13 | 14 | num_batch, height, 
width, channels = tf.unstack(tf.shape(im)) 15 | max_x = tf.cast(width - 1, 'int32') 16 | max_y = tf.cast(height - 1, 'int32') 17 | zero = tf.zeros([], dtype='int32') 18 | 19 | # We have to flatten our tensors to vectorize the interpolation 20 | im_flat = tf.reshape(im, [-1, channels]) 21 | flow_flat = tf.reshape(flow, [-1, 2]) 22 | 23 | # Floor the flow, as the final indices are integers 24 | # The fractional part is used to control the bilinear interpolation. 25 | flow_floor = tf.to_int32(tf.floor(flow_flat)) 26 | bilinear_weights = flow_flat - tf.floor(flow_flat) 27 | 28 | # Construct base indices which are displaced with the flow 29 | pos_x = tf.tile(tf.range(width), [height * num_batch]) 30 | grid_y = tf.tile(tf.expand_dims(tf.range(height), 1), [1, width]) 31 | pos_y = tf.tile(tf.reshape(grid_y, [-1]), [num_batch]) 32 | 33 | x = flow_floor[:, 0] 34 | y = flow_floor[:, 1] 35 | xw = bilinear_weights[:, 0] 36 | yw = bilinear_weights[:, 1] 37 | 38 | # Compute interpolation weights for 4 adjacent pixels 39 | # expand to num_batch * height * width x 1 for broadcasting in add_n below 40 | wa = tf.expand_dims((1 - xw) * (1 - yw), 1) # top left pixel 41 | wb = tf.expand_dims((1 - xw) * yw, 1) # bottom left pixel 42 | wc = tf.expand_dims(xw * (1 - yw), 1) # top right pixel 43 | wd = tf.expand_dims(xw * yw, 1) # bottom right pixel 44 | 45 | x0 = pos_x + x 46 | x1 = x0 + 1 47 | y0 = pos_y + y 48 | y1 = y0 + 1 49 | 50 | x0 = tf.clip_by_value(x0, zero, max_x) 51 | x1 = tf.clip_by_value(x1, zero, max_x) 52 | y0 = tf.clip_by_value(y0, zero, max_y) 53 | y1 = tf.clip_by_value(y1, zero, max_y) 54 | 55 | dim1 = width * height 56 | batch_offsets = tf.range(num_batch) * dim1 57 | base_grid = tf.tile(tf.expand_dims(batch_offsets, 1), [1, dim1]) 58 | base = tf.reshape(base_grid, [-1]) 59 | 60 | base_y0 = base + y0 * width 61 | base_y1 = base + y1 * width 62 | idx_a = base_y0 + x0 63 | idx_b = base_y1 + x0 64 | idx_c = base_y0 + x1 65 | idx_d = base_y1 + x1 66 | 67 | Ia = 
tf.gather(im_flat, idx_a) 68 | Ib = tf.gather(im_flat, idx_b) 69 | Ic = tf.gather(im_flat, idx_c) 70 | Id = tf.gather(im_flat, idx_d) 71 | 72 | warped_flat = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) 73 | warped = tf.reshape(warped_flat, [num_batch, height, width, channels]) 74 | 75 | return warped 76 | 77 | # Visualisation # 78 | #---------------# 79 | 80 | def make_color_wheel(): 81 | """ 82 | Generate color wheel according Middlebury color code 83 | :return: Color wheel 84 | """ 85 | RY = 15 86 | YG = 6 87 | GC = 4 88 | CB = 11 89 | BM = 13 90 | MR = 6 91 | 92 | ncols = RY + YG + GC + CB + BM + MR 93 | 94 | colorwheel = np.zeros([ncols, 3]) 95 | 96 | col = 0 97 | 98 | # RY 99 | colorwheel[0:RY, 0] = 255 100 | colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) 101 | col += RY 102 | 103 | # YG 104 | colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) 105 | colorwheel[col:col + YG, 1] = 255 106 | col += YG 107 | 108 | # GC 109 | colorwheel[col:col + GC, 1] = 255 110 | colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) 111 | col += GC 112 | 113 | # CB 114 | colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) 115 | colorwheel[col:col + CB, 2] = 255 116 | col += CB 117 | 118 | # BM 119 | colorwheel[col:col + BM, 2] = 255 120 | colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) 121 | col += + BM 122 | 123 | # MR 124 | colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 125 | colorwheel[col:col + MR, 0] = 255 126 | 127 | return colorwheel 128 | 129 | def compute_color(u, v): 130 | """ 131 | compute optical flow color map 132 | :param u: optical flow horizontal map 133 | :param v: optical flow vertical map 134 | :return: optical flow in color code 135 | """ 136 | [h, w] = u.shape 137 | img = np.zeros([h, w, 3]) 138 | nanIdx = np.isnan(u) | np.isnan(v) 139 | u[nanIdx] = 0 140 | 
v[nanIdx] = 0 141 | 142 | colorwheel = make_color_wheel() 143 | ncols = np.size(colorwheel, 0) 144 | 145 | rad = np.sqrt(u**2+v**2) 146 | 147 | a = np.arctan2(-v, -u) / np.pi 148 | 149 | fk = (a+1) / 2 * (ncols - 1) + 1 150 | 151 | k0 = np.floor(fk).astype(int) 152 | 153 | k1 = k0 + 1 154 | k1[k1 == ncols+1] = 1 155 | f = fk - k0 156 | 157 | for i in range(0, np.size(colorwheel,1)): 158 | tmp = colorwheel[:, i] 159 | col0 = tmp[k0-1] / 255 160 | col1 = tmp[k1-1] / 255 161 | col = (1-f) * col0 + f * col1 162 | 163 | idx = rad <= 1 164 | col[idx] = 1-rad[idx]*(1-col[idx]) 165 | notidx = np.logical_not(idx) 166 | 167 | col[notidx] *= 0.75 168 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 169 | 170 | return img 171 | 172 | def flow_to_image(flow): 173 | """ 174 | Convert flow into middlebury color code image 175 | :param flow: optical flow map 176 | :return: optical flow image in middlebury color 177 | """ 178 | u = flow[:, :, 0] 179 | v = flow[:, :, 1] 180 | 181 | idxUnknow = (abs(u) > 1e7) | (abs(v) > 1e7) 182 | u[idxUnknow] = 0 183 | v[idxUnknow] = 0 184 | 185 | rad = np.sqrt(u ** 2 + v ** 2) 186 | maxrad = max(-1, np.max(rad)) 187 | 188 | u = u / (maxrad + np.finfo(float).eps) 189 | v = v / (maxrad + np.finfo(float).eps) 190 | 191 | img = compute_color(u, v) 192 | 193 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 194 | img[idx] = 0 195 | 196 | return np.uint8(img) 197 | 198 | def modify_images_for_vis(x_images, gt_flow, predicted_flow): 199 | 200 | images = list() 201 | for i in range(2): 202 | x_image = x_images[i] 203 | unscaled_predicted_flow = cv2.resize(predicted_flow, (448,384)) 204 | diff_flow = flow_to_image(gt_flow-unscaled_predicted_flow) 205 | 206 | gt_flow_im = flow_to_image(gt_flow) 207 | predicted_flow_im = flow_to_image(unscaled_predicted_flow) 208 | diff_flow_im = flow_to_image(diff_flow) 209 | 210 | combined_image = np.concatenate((x_image,gt_flow_im,predicted_flow_im,diff_flow_im),1).astype(np.uint8) # change this to 
combine all of above 211 | 212 | images.append(combined_image) 213 | 214 | return np.asarray(images) 215 | 216 | def convert_array_to_gif_summary(images_arr, tag, fps): 217 | 218 | summary = tf.Summary() 219 | 220 | if len(images_arr.shape) == 5: 221 | # concatenate batch dimension horizontally 222 | images_arr = np.concatenate(list(images_arr), axis=-2) 223 | if len(images_arr.shape) != 4: 224 | raise ValueError('Tensors must be 4-D or 5-D for gif summary.') 225 | if images_arr.shape[-1] != 3: 226 | raise ValueError('Tensors must have 3 channels.') 227 | 228 | # encode sequence of images into gif string 229 | clip = mpy.ImageSequenceClip(list(images_arr), fps=fps) 230 | with tempfile.NamedTemporaryFile() as f: 231 | filename = f.name + '.gif' 232 | clip.write_gif(filename, verbose=False, program='ffmpeg') 233 | with open(filename, 'rb') as f: 234 | encoded_image_string = f.read() 235 | 236 | image = tf.Summary.Image() 237 | image.height = images_arr.shape[-3] 238 | image.width = images_arr.shape[-2] 239 | image.colorspace = 3 # code for 'RGB' 240 | image.encoded_image_string = encoded_image_string 241 | summary.value.add(tag=tag, image=image) 242 | return summary -------------------------------------------------------------------------------- /custom_ops/source_code/correlation_op.cc: -------------------------------------------------------------------------------- 1 | #define EIGEN_USE_THREADS 2 | 3 | #include 4 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 5 | #include "tensorflow/core/framework/op_kernel.h" 6 | #include "tensorflow/core/framework/register_types.h" 7 | #include "tensorflow/core/framework/tensor.h" 8 | #include "tensorflow/core/framework/tensor_shape.h" 9 | #include "tensorflow/core/framework/types.h" 10 | #include "tensorflow/core/lib/core/status.h" 11 | #include "tensorflow/core/platform/logging.h" 12 | #include "tensorflow/core/framework/op.h" 13 | #include "tensorflow/core/framework/shape_inference.h" 14 | #include 
"tensorflow/core/framework/common_shape_fns.h" /* NOTE(review): angle-bracketed text appears to have been stripped from this copy (the include on line 3 has no header, TTypes has no template arguments, tensor() and eigen_device() have no type arguments); restore it from the original source before compiling. */ 15 | 16 | #include "correlation_op.h" 17 | 18 | typedef Eigen::GpuDevice GPUDevice; 19 | 20 | using namespace tensorflow; 21 | 22 | /* GPU launchers declared here; their definitions live in correlation_op.cu.cc. */ void Correlation(const GPUDevice& d, 23 | typename TTypes::ConstTensor input_0, 24 | typename TTypes::ConstTensor input_1, 25 | typename TTypes::Tensor output, 26 | typename TTypes::Tensor padded_0, 27 | typename TTypes::Tensor padded_1, 28 | CorrelationState params); 29 | 30 | void CorrelationGrad(const GPUDevice& d, 31 | typename TTypes::ConstTensor input_grad, 32 | typename TTypes::ConstTensor padded_0, 33 | typename TTypes::ConstTensor padded_1, 34 | typename TTypes::Tensor output_grad_0, 35 | typename TTypes::Tensor output_grad_1, 36 | CorrelationState params); 37 | 38 | /* Forward kernel: requires input_0 and input_1 to have the same shape and reads dimensions 0-3 as batch, channels, height, width. It allocates output 0 as (batch, out_channels, out_height, out_width) and outputs 1-2 as padded buffers shaped (batch, padded_height, padded_width, channels), then runs the GPU Correlation(). */ class CorrelationOp : public OpKernel { 39 | public: 40 | explicit CorrelationOp(OpKernelConstruction* context) 41 | : OpKernel(context), attrs(context) {} 42 | 43 | void Compute(OpKernelContext* context) override { 44 | const Tensor& input_0 = context->input(0); 45 | const Tensor& input_1 = context->input(1); 46 | 47 | OP_REQUIRES(context, input_0.shape() == input_1.shape(), 48 | errors::InvalidArgument("Input shapes have to be the same")); 49 | 50 | typename TTypes::ConstTensor input_0_data = input_0.tensor(); 51 | typename TTypes::ConstTensor input_1_data = input_1.tensor(); 52 | 53 | const int batch = input_0_data.dimension(0); 54 | const int in_channels = input_0_data.dimension(1); 55 | const int in_height = input_0_data.dimension(2); 56 | const int in_width = input_0_data.dimension(3); 57 | 58 | CorrelationState st(attrs, in_height, in_width, in_channels); 59 | 60 | OP_REQUIRES(context, st.out_width * st.out_height > 0, 61 | errors::InvalidArgument("Invalid correlation settings")); 62 | 63 | Tensor* output = NULL; 64 | TensorShape output_shape({batch, st.out_channels, st.out_height, st.out_width}); 65 | OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 66 | 67 | Tensor* padded_0 =
NULL; 68 | Tensor* padded_1 = NULL; 69 | TensorShape padded_shape({batch, st.padded_height, st.padded_width, in_channels}); 70 | OP_REQUIRES_OK(context, context->allocate_output(1, padded_shape, &padded_0)); 71 | OP_REQUIRES_OK(context, context->allocate_output(2, padded_shape, &padded_1)); 72 | 73 | typename TTypes::Tensor output_data = output->tensor(); 74 | typename TTypes::Tensor padded_0_data = padded_0->tensor(); 75 | typename TTypes::Tensor padded_1_data = padded_1->tensor(); 76 | 77 | Correlation(context->eigen_device(), 78 | input_0_data, input_1_data, output_data, 79 | padded_0_data, padded_1_data, 80 | st); 81 | } 82 | 83 | private: 84 | CorrelationAttrs attrs; 85 | }; 86 | 87 | /* Backward kernel: inputs are (input_grad, input_0, input_1, padded_0, padded_1). It writes gradients for both original inputs via the GPU CorrelationGrad(). input_1 is fetched but otherwise unused. */ class CorrelationOpGrad : public OpKernel { 88 | public: 89 | explicit CorrelationOpGrad(OpKernelConstruction* context) 90 | : OpKernel(context), attrs(context) {} 91 | 92 | void Compute(OpKernelContext* context) override { 93 | const Tensor& input_grad = context->input(0); 94 | const Tensor& input_0 = context->input(1); 95 | const Tensor& input_1 = context->input(2); 96 | const Tensor& padded_0 = context->input(3); 97 | const Tensor& padded_1 = context->input(4); 98 | 99 | typename TTypes::ConstTensor input_grad_data = input_grad.tensor(); 100 | typename TTypes::ConstTensor input_0_data = input_0.tensor(); 101 | //typename TTypes::ConstTensor input_1_data = input_1.tensor(); 102 | typename TTypes::ConstTensor padded_0_data = padded_0.tensor(); 103 | typename TTypes::ConstTensor padded_1_data = padded_1.tensor(); 104 | 105 | const int in_channels = input_0_data.dimension(1); 106 | const int in_height = input_0_data.dimension(2); 107 | const int in_width = input_0_data.dimension(3); 108 | 109 | CorrelationState st(attrs, in_height, in_width, in_channels); 110 | 111 | Tensor* output_grad_0 = NULL; 112 | OP_REQUIRES_OK(context, context->allocate_output(0, input_0.shape(), 113 | &output_grad_0)); 114 | Tensor* output_grad_1 = NULL; 115 | OP_REQUIRES_OK(context,
/* NOTE(review): output_grad_1 is allocated with input_0's shape. This is only correct because the forward op requires equal input shapes; input_1.shape() would state the intent directly. */ context->allocate_output(1, input_0.shape(), 116 | &output_grad_1)); 117 | 118 | typename TTypes::Tensor output_grad_0_data = output_grad_0->tensor(); 119 | typename TTypes::Tensor output_grad_1_data = output_grad_1->tensor(); 120 | 121 | CorrelationGrad(context->eigen_device(), 122 | input_grad_data, 123 | padded_0_data, padded_1_data, 124 | output_grad_0_data, output_grad_1_data, 125 | st); 126 | } 127 | private: 128 | CorrelationAttrs attrs; 129 | }; 130 | 131 | using shape_inference::DimensionHandle;; 132 | 133 | /* Shape function below infers only batch and out_channels = (2 * (max_displacement / stride_2) + 1)^2. Output height and width are left unknown, and the GetAttr statuses are not checked. */ REGISTER_OP("Correlation") 134 | .Input("input_0: float") 135 | .Input("input_1: float") 136 | .Attr("kernel_size: int = 1") 137 | .Attr("max_displacement: int = 20") 138 | .Attr("pad: int = 20") 139 | .Attr("stride_1: int = 1") 140 | .Attr("stride_2: int = 2") 141 | .Output("correlation: float") 142 | .Output("padded_0: float") 143 | .Output("padded_1: float") 144 | .SetShapeFn([](shape_inference::InferenceContext* c) { 145 | CorrelationAttrs attrs; 146 | c->GetAttr("kernel_size", &attrs.kernel_size); 147 | c->GetAttr("max_displacement", &attrs.max_displacement); 148 | c->GetAttr("pad", &attrs.pad_size); 149 | c->GetAttr("stride_1", &attrs.stride_1); 150 | c->GetAttr("stride_2", &attrs.stride_2); 151 | 152 | DimensionHandle batch = c->Dim(c->input(0), 0); 153 | 154 | //padded_height = in_height + 2 * pad_size; 155 | //padded_width = in_width + 2 * pad_size; 156 | //kernel_radius = (kernel_size - 1) / 2; 157 | //border_size = max_displacement + kernel_radius; 158 | int neighborhood_grid_radius = attrs.max_displacement / attrs.stride_2; 159 | int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; 160 | //out_width = ceil((float)(padded_width - border_size *2) / (float)stride_1); 161 | //out_height = ceil((float)(padded_height - border_size *2) / (float)stride_1); 162 | int out_channels = neighborhood_grid_width * neighborhood_grid_width; 163 | 164 | // TODO: support passing on output width and height 165 | 166 | c->set_output(0,
c->MakeShape({batch, out_channels, c->UnknownDim(), c->UnknownDim()})); 167 | return Status::OK(); 168 | }); 169 | 170 | REGISTER_OP("CorrelationGrad") 171 | .Input("input_grad: float") 172 | .Input("original_input_0: float") 173 | .Input("original_input_1: float") 174 | .Input("padded_0: float") 175 | .Input("padded_1: float") 176 | .Attr("kernel_size: int = 1") 177 | .Attr("max_displacement: int = 20") 178 | .Attr("pad: int = 20") 179 | .Attr("stride_1: int = 1") 180 | .Attr("stride_2: int = 2") 181 | .Output("output_grad_0: float") 182 | .Output("output_grad_1: float") 183 | .SetShapeFn([](shape_inference::InferenceContext* c) { 184 | c->set_output(0, c->input(1)); 185 | c->set_output(1, c->input(2)); 186 | return Status::OK(); 187 | }); 188 | /* Kernels are registered for DEVICE_GPU only; no CPU kernel is registered in this file. */ 189 | #if GOOGLE_CUDA 190 | 191 | REGISTER_KERNEL_BUILDER(Name("Correlation").Device(DEVICE_GPU), CorrelationOp); 192 | REGISTER_KERNEL_BUILDER(Name("CorrelationGrad").Device(DEVICE_GPU), CorrelationOpGrad); 193 | 194 | #endif // GOOGLE_CUDA 195 | -------------------------------------------------------------------------------- /custom_ops/source_code/correlation_op.cu.cc: -------------------------------------------------------------------------------- 1 | #if GOOGLE_CUDA 2 | 3 | #define EIGEN_USE_GPU 4 | 5 | #include "tensorflow/core/framework/register_types.h" 6 | #include "tensorflow/core/framework/tensor_types.h" 7 | #include "tensorflow/core/platform/types.h" 8 | #include "tensorflow/core/util/cuda_kernel_helper.h" 9 | 10 | #include "correlation_op.h" 11 | 12 | using namespace tensorflow; 13 | using CPUDevice = Eigen::ThreadPoolDevice; 14 | using GPUDevice = Eigen::GpuDevice; 15 | // --------------------------------------------------------- 16 | // DIRECT PORT OF CAFFE CODE WITH MINIMAL CHANGES 17 | // --------------------------------------------------------- 18 | 19 | #define ROUND_OFF 50000 20 | 21 | #define WARPS_PER_BLOCK 1 22 | #define THREADS_PER_WARP 32 23 | 24 | const int CAFFE_CUDA_NUM_THREADS = 512; 25 | 26 |
inline int CAFFE_GET_BLOCKS(const int N) { 27 | return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS; 28 | } 29 | 30 | template 31 | __global__ void blob_rearrange_kernel2(const Dtype* in, Dtype* out, int num, int channels, int width, int height, int widthheight, int padding, int pwidthheight) 32 | { 33 | int xy = blockIdx.x*blockDim.x + threadIdx.x; 34 | if(xy>=widthheight) 35 | return; 36 | 37 | int ch = blockIdx.y; 38 | int n = blockIdx.z; 39 | 40 | Dtype value=in[(n*channels+ch)*widthheight+xy]; 41 | 42 | __syncthreads(); 43 | 44 | int xpad = (xy % width + padding); 45 | int ypad = (xy / width + padding); 46 | int xypad = ypad * (width+2*padding) + xpad; 47 | 48 | out[(n*pwidthheight+xypad)*channels + ch] = value; 49 | } 50 | 51 | template 52 | __global__ void CorrelateData(const int nthreads, int num, int topwidth, int topheight, int topchannels, int topcount, 53 | int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int kernel_size, int stride1, int stride2, 54 | int bottomwidth, int bottomheight, int bottomchannels, 55 | const Dtype *bottom0, const Dtype *bottom1, Dtype *top) 56 | { 57 | extern __shared__ char patch_data_char[]; 58 | 59 | Dtype *patch_data = (Dtype *)patch_data_char; 60 | 61 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 62 | int x1 = blockIdx.x*stride1 + max_displacement; 63 | int y1 = blockIdx.y*stride1 + max_displacement; 64 | int item = blockIdx.z; 65 | int ch_off = threadIdx.x; 66 | 67 | // Load 3D patch into shared shared memory 68 | for(int j = 0; j < kernel_size; j++) { // HEIGHT 69 | for(int i = 0; i < kernel_size; i++) { // WIDTH 70 | int ji_off = ((j * kernel_size) + i) * bottomchannels; 71 | for(int ch = ch_off; ch < bottomchannels; ch += (WARPS_PER_BLOCK*THREADS_PER_WARP)) { // CHANNELS 72 | int idx1 = ((item * bottomheight + y1+j) * bottomwidth + x1+i) * bottomchannels + ch; 73 | int idxPatchData = 
ji_off + ch; 74 | patch_data[idxPatchData] = bottom0[idx1]; 75 | } 76 | } 77 | } 78 | 79 | __syncthreads(); 80 | 81 | __shared__ Dtype sum[WARPS_PER_BLOCK*THREADS_PER_WARP]; 82 | 83 | // Compute correlation 84 | for(int top_channel = 0; top_channel < topchannels; top_channel++) { 85 | sum[ch_off] = 0; 86 | 87 | int s2o = (top_channel % neighborhood_grid_width - neighborhood_grid_radius) * stride2; 88 | int s2p = (top_channel / neighborhood_grid_width - neighborhood_grid_radius) * stride2; 89 | 90 | for(int j = 0; j < kernel_size; j++) { // HEIGHT 91 | for(int i = 0; i < kernel_size; i++) { // WIDTH 92 | int ji_off = ((j * kernel_size) + i) * bottomchannels; 93 | for(int ch = ch_off; ch < bottomchannels; ch += (WARPS_PER_BLOCK*THREADS_PER_WARP)) { // CHANNELS 94 | int x2 = x1 + s2o; 95 | int y2 = y1 + s2p; 96 | 97 | int idxPatchData = ji_off + ch; 98 | int idx2 = ((item * bottomheight + y2+j) * bottomwidth + x2+i) * bottomchannels + ch; 99 | 100 | sum[ch_off] += patch_data[idxPatchData] * bottom1[idx2]; 101 | } 102 | } 103 | } 104 | 105 | __syncthreads(); 106 | 107 | if(ch_off == 0) { 108 | Dtype total_sum = 0; 109 | for(int idx = 0; idx < WARPS_PER_BLOCK*THREADS_PER_WARP; idx++) { 110 | total_sum += sum[idx]; 111 | } 112 | const int sumelems = kernel_size*kernel_size*bottomchannels; 113 | const int index = ((top_channel*topheight + blockIdx.y)*topwidth)+blockIdx.x; 114 | top[index + item*topcount] = total_sum / (float)sumelems; 115 | } 116 | } 117 | } 118 | 119 | template 120 | __global__ void CorrelateDataBackward0(const int nthreads, int num, int item, int topwidth, int topheight, int topchannels, 121 | int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2, 122 | int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight, int bottomchannels, int bottomcount, int pad_size, 123 | Dtype *bottom0diff, const Dtype *bottom1, const Dtype *topdiff) 124 | { 125 | CUDA_1D_KERNEL_LOOP(index, 
nthreads) { 126 | int n = index % bottomchannels; //channels 127 | int l = (index / bottomchannels) % bottomwidth + pad_size; //w-pos 128 | int m = (index / bottomchannels / bottomwidth) % bottomheight + pad_size; //h-pos 129 | 130 | //Get X,Y ranges and clamp 131 | // round_off is a trick to enable integer division with ceil, even for negative numbers 132 | // We use a large offset, for the inner part not to become negative. 133 | const int round_off = ROUND_OFF; 134 | const int round_off_s1 = stride1 * round_off; 135 | 136 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 137 | int xmin = (l - 2*kernel_radius - max_displacement + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement) / stride1 138 | int ymin = (m - 2*kernel_radius - max_displacement + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement) / stride1 139 | 140 | // Same here: 141 | int xmax = (l - max_displacement + round_off_s1) / stride1 - round_off; // floor (l - max_displacement) / stride1 142 | int ymax = (m - max_displacement + round_off_s1) / stride1 - round_off; // floor (m - max_displacement) / stride1 143 | 144 | 145 | Dtype sum = 0; 146 | if(xmax>=0 && ymax>=0 && (xmin<=topwidth-1) && (ymin<=topheight-1)) 147 | { 148 | xmin = max(0,xmin); 149 | xmax = min(topwidth-1,xmax); 150 | 151 | ymin = max(0,ymin); 152 | ymax = min(topheight-1,ymax); 153 | 154 | for(int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) { 155 | for(int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) { 156 | 157 | // Get bottom1 data: 158 | int s2o = stride2 * o; 159 | int s2p = stride2 * p; 160 | int idxbot1 = ((item * pbottomheight + (m+s2p)) * pbottomwidth + (l+s2o)) * bottomchannels + n; 161 | Dtype bot1tmp = bottom1[idxbot1]; // bottom1[l+s2o,m+s2p,n] 162 | 163 | // Index offset for topdiff in following loops: 164 | int op = 
(p+neighborhood_grid_radius) * neighborhood_grid_width + (o+neighborhood_grid_radius); // index [o,p] 165 | int idxopoffset = (item * topchannels + op); 166 | 167 | for(int y = ymin; y <= ymax; y++) { 168 | for(int x = xmin; x <= xmax; x++) { 169 | int idxtopdiff = (idxopoffset * topheight + y) * topwidth + x; // topdiff[x,y,o,p] 170 | sum += topdiff[idxtopdiff] * bot1tmp; 171 | } 172 | } 173 | } 174 | } 175 | } 176 | const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels; 177 | const int bot0index = ((n * bottomheight) + (m-pad_size)) * bottomwidth + (l-pad_size); 178 | bottom0diff[bot0index + item*bottomcount] = sum / (float)sumelems; 179 | } 180 | 181 | } 182 | 183 | template 184 | __global__ void CorrelateDataBackward1(const int nthreads, int num, int item, int topwidth, int topheight, int topchannels, 185 | int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2, 186 | int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight, int bottomchannels, int bottomcount, int pad_size, 187 | const Dtype *bottom0, Dtype *bottom1diff, const Dtype *topdiff) 188 | { 189 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 190 | //int l = index % bottomwidth + pad_size; //w-pos 191 | //int m = (index / bottomwidth) % bottomheight + pad_size; //h-pos 192 | //int n = (index / bottomwidth / bottomheight) % bottomchannels; //channels 193 | int n = index % bottomchannels; //channels 194 | int l = (index / bottomchannels) % bottomwidth + pad_size; //w-pos 195 | int m = (index / bottomchannels / bottomwidth) % bottomheight + pad_size; //h-pos 196 | 197 | // round_off is a trick to enable integer division with ceil, even for negative numbers 198 | // We use a large offset, for the inner part not to become negative. 
199 | const int round_off = ROUND_OFF; 200 | const int round_off_s1 = stride1 * round_off; 201 | 202 | Dtype sum = 0; 203 | for(int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) { 204 | for(int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) { 205 | 206 | int s2o = stride2 * o; 207 | int s2p = stride2 * p; 208 | 209 | //Get X,Y ranges and clamp 210 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 211 | int xmin = (l - 2*kernel_radius - max_displacement - s2o + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement - s2o) / stride1 212 | int ymin = (m - 2*kernel_radius - max_displacement - s2p + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement - s2o) / stride1 213 | 214 | // Same here: 215 | int xmax = (l - max_displacement - s2o + round_off_s1) / stride1 - round_off; // floor (l - max_displacement - s2o) / stride1 216 | int ymax = (m - max_displacement - s2p + round_off_s1) / stride1 - round_off; // floor (m - max_displacement - s2p) / stride1 217 | 218 | if(xmax>=0 && ymax>=0 && (xmin<=topwidth-1) && (ymin<=topheight-1)) 219 | { 220 | xmin = max(0,xmin); 221 | xmax = min(topwidth-1,xmax); 222 | 223 | ymin = max(0,ymin); 224 | ymax = min(topheight-1,ymax); 225 | 226 | // Get bottom0 data: 227 | int idxbot0 = ((item * pbottomheight + (m-s2p)) * pbottomwidth + (l-s2o)) * bottomchannels + n; 228 | Dtype bot0tmp = bottom0[idxbot0]; // bottom1[l+s2o,m+s2p,n] 229 | 230 | // Index offset for topdiff in following loops: 231 | int op = (p+neighborhood_grid_radius) * neighborhood_grid_width + (o+neighborhood_grid_radius); // index [o,p] 232 | int idxOpOffset = (item * topchannels + op); 233 | 234 | for(int y = ymin; y <= ymax; y++) { 235 | for(int x = xmin; x <= xmax; x++) { 236 | int idxtopdiff = (idxOpOffset * topheight + y) * topwidth + x; // topdiff[x,y,o,p] 237 | sum += 
topdiff[idxtopdiff] * bot0tmp; 238 | } 239 | } 240 | } 241 | } 242 | } 243 | const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels; 244 | const int bot1index = ((n * bottomheight) + (m-pad_size)) * bottomwidth + (l-pad_size); 245 | bottom1diff[bot1index + item*bottomcount] = sum / (float)sumelems; 246 | } 247 | 248 | } 249 | 250 | void Correlation(const GPUDevice& d, 251 | typename TTypes::ConstTensor input_0, 252 | typename TTypes::ConstTensor input_1, 253 | typename TTypes::Tensor output, 254 | typename TTypes::Tensor padded_0, 255 | typename TTypes::Tensor padded_1, 256 | CorrelationState st) { 257 | 258 | const int top_channels_ = output.dimension(1); 259 | const int top_height_ = output.dimension(2); 260 | const int top_width_ = output.dimension(3); 261 | const int pad_size_ = st.pad_size; 262 | const int stride1_ = st.stride_1; 263 | const int stride2_ = st.stride_2; 264 | const int kernel_size_ = st.kernel_size; 265 | const int kernel_radius_ = st.kernel_radius; 266 | const int max_displacement_ = st.max_displacement; 267 | const int neighborhood_grid_radius_ = st.neighborhood_grid_radius; 268 | const int neighborhood_grid_width_ = st.neighborhood_grid_width; 269 | 270 | // PORTED CAFFE CODE 271 | 272 | const int bnum = input_0.dimension(0); 273 | const int bchannels = input_0.dimension(1); 274 | const int bheight = input_0.dimension(2); 275 | const int bwidth = input_0.dimension(3); 276 | const int bwidthheight = bwidth * bheight; 277 | 278 | const int topcount = top_width_ * top_height_ * top_channels_; 279 | 280 | dim3 threadsPerBlock(THREADS_PER_WARP * WARPS_PER_BLOCK); 281 | 282 | cudaMemset(padded_0.data(), 0, padded_0.size()*sizeof(float)); 283 | cudaMemset(padded_1.data(), 0, padded_1.size()*sizeof(float)); 284 | 285 | int threads_per_block=16; 286 | dim3 totalBlocksRearr((bwidthheight-1)/threads_per_block+1, bchannels, bnum); 287 | const int pwidthheight = (bwidth + 2 * pad_size_) * (bheight + 2 * pad_size_); 288 | 289 | 
blob_rearrange_kernel2<<>> 290 | (input_0.data(),padded_0.data(),bnum,bchannels,bwidth,bheight,bwidthheight,pad_size_,pwidthheight); 291 | 292 | blob_rearrange_kernel2<<>> 293 | (input_1.data(),padded_1.data(),bnum,bchannels,bwidth,bheight,bwidthheight,pad_size_,pwidthheight); 294 | 295 | const int num = bnum; 296 | const int channels = bchannels; 297 | const int height = bheight + 2*pad_size_; 298 | const int width = bwidth + 2*pad_size_; 299 | 300 | const int shared_memory_per_block = (kernel_size_*kernel_size_)*bchannels; 301 | 302 | // CorrelationLayer 303 | int topThreadCount = topcount; 304 | 305 | dim3 totalBlocksCorr(top_width_, top_height_, num); 306 | 307 | CorrelateData<<>>( 308 | topThreadCount, 309 | num, top_width_, top_height_, top_channels_, topcount, 310 | max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, kernel_size_, 311 | stride1_, stride2_, 312 | width, height, channels, 313 | padded_0.data(), padded_1.data(), output.data() 314 | ); 315 | } 316 | 317 | void CorrelationGrad(const GPUDevice& d, 318 | typename TTypes::ConstTensor input_grad, 319 | typename TTypes::ConstTensor padded_0, 320 | typename TTypes::ConstTensor padded_1, 321 | typename TTypes::Tensor output_grad_0, 322 | typename TTypes::Tensor output_grad_1, 323 | CorrelationState st) { 324 | 325 | const int top_channels_ = input_grad.dimension(1); 326 | const int top_height_ = input_grad.dimension(2); 327 | const int top_width_ = input_grad.dimension(3); 328 | 329 | const int pad_size_ = st.pad_size; 330 | const int stride1_ = st.stride_1; 331 | const int stride2_ = st.stride_2; 332 | const int kernel_size_ = st.kernel_size; 333 | const int kernel_radius_ = st.kernel_radius; 334 | const int max_displacement_ = st.max_displacement; 335 | const int neighborhood_grid_radius_ = st.neighborhood_grid_radius; 336 | const int neighborhood_grid_width_ = st.neighborhood_grid_width; 337 | 338 | // PORTED CAFFE CODE 339 | 340 | // Get top diff, compute bottom 
diff 341 | const float* top_diff = input_grad.data(); 342 | 343 | float* bottom0_diff = output_grad_0.data(); 344 | float* bottom1_diff = output_grad_1.data(); 345 | 346 | const int num = output_grad_0.dimension(0); 347 | const int channels = output_grad_0.dimension(1); 348 | const int height = output_grad_0.dimension(2); 349 | const int width = output_grad_0.dimension(3); 350 | 351 | const int paddedheight = height + 2*pad_size_; 352 | const int paddedwidth = width + 2*pad_size_; 353 | 354 | const int bottomcount = channels * height * width; 355 | 356 | int botThreadCount = bottomcount; 357 | 358 | // CorrelationLayerBackward 359 | 360 | // == Run kernel Backward 0 361 | dim3 totalBlocksBackward0(width, height, channels * num); //First dim is fastest 362 | dim3 threadsPerBlockBackward0(THREADS_PER_WARP * WARPS_PER_BLOCK); 363 | const int buffer_size_backw0 = ((int)ceil((float)(2 * kernel_radius_) / (float)stride1_) + 1) * top_channels_; 364 | 365 | // == Run kernel Backward 0 366 | for(int n = 0; n < num; n++) { 367 | //Bottom0: 368 | CorrelateDataBackward0<<>>( 369 | botThreadCount, 370 | num, n, top_width_, top_height_, top_channels_, 371 | max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, 372 | stride1_, stride2_, 373 | width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_, 374 | bottom0_diff, padded_1.data(), top_diff 375 | ); 376 | } 377 | 378 | // == Run kernel Backward 1 379 | for(int n = 0; n < num; n++) { 380 | CorrelateDataBackward1<<>>( 381 | botThreadCount, 382 | num, n, top_width_, top_height_, top_channels_, 383 | max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, 384 | stride1_, stride2_, 385 | width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_, 386 | padded_0.data(), bottom1_diff, top_diff 387 | ); 388 | } 389 | 390 | } 391 | 392 | #endif // GOOGLE_CUDA 393 | 
-------------------------------------------------------------------------------- /custom_ops/source_code/correlation_op.h: -------------------------------------------------------------------------------- 1 | #define EIGEN_USE_THREADS 2 | 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include "tensorflow/core/framework/op.h" 5 | 6 | using namespace tensorflow; 7 | 8 | struct CorrelationAttrs { 9 | CorrelationAttrs(OpKernelConstruction* c) { 10 | OP_REQUIRES_OK(c, c->GetAttr("kernel_size", &kernel_size)); 11 | OP_REQUIRES_OK(c, c->GetAttr("max_displacement", &max_displacement)); 12 | OP_REQUIRES_OK(c, c->GetAttr("pad", &pad_size)); 13 | OP_REQUIRES_OK(c, c->GetAttr("stride_1", &stride_1)); 14 | OP_REQUIRES_OK(c, c->GetAttr("stride_2", &stride_2)); 15 | 16 | OP_REQUIRES(c, kernel_size % 2 != 0, 17 | errors::InvalidArgument("kernel_size must be odd")); 18 | } 19 | CorrelationAttrs() {} 20 | 21 | int pad_size; 22 | int stride_1; 23 | int stride_2; 24 | int max_displacement; 25 | int kernel_size; 26 | }; 27 | 28 | struct CorrelationState { 29 | CorrelationState(CorrelationAttrs attrs, int in_height, int in_width, int in_channels) { 30 | pad_size = attrs.pad_size; 31 | stride_1 = attrs.stride_1; 32 | stride_2 = attrs.stride_2; 33 | max_displacement = attrs.max_displacement; 34 | kernel_size = attrs.kernel_size; 35 | 36 | padded_height = in_height + 2 * pad_size; 37 | padded_width = in_width + 2 * pad_size; 38 | 39 | // Compute size of unreachable border region (on each side) 40 | kernel_radius = (kernel_size - 1) / 2; 41 | border_size = max_displacement + kernel_radius; 42 | 43 | // Given a center position in image 1, how many displaced positions in -x / +x 44 | // direction do we consider in image 2 (neighborhoodGridWidth): 45 | neighborhood_grid_radius = max_displacement / stride_2; 46 | neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; 47 | 48 | out_width = ceil((float)(padded_width - border_size *2) / (float)stride_1); 49 | out_height = 
ceil((float)(padded_height - border_size *2) / (float)stride_1); 50 | // Top Channels amount to displacement combinations in X and Y direction: 51 | out_channels = neighborhood_grid_width * neighborhood_grid_width; 52 | } 53 | 54 | int pad_size; 55 | int stride_1; 56 | int stride_2; 57 | int kernel_radius; 58 | int max_displacement; 59 | int kernel_size; 60 | int neighborhood_grid_radius; 61 | int neighborhood_grid_width; 62 | int padded_height; 63 | int padded_width; 64 | int border_size; 65 | int out_height; 66 | int out_width; 67 | int out_channels; 68 | }; 69 | -------------------------------------------------------------------------------- /custom_ops/source_code/decode_flo_op.cc: -------------------------------------------------------------------------------- 1 | // 2 | // lmbspecialops - a collection of tensorflow ops 3 | // Copyright (C) 2017 Albert Ludwigs University of Freiburg, Pattern Recognition and Image Processing, Computer Vision Group 4 | // Author(s): Lukas Voegtle 5 | // 6 | // This program is free software: you can redistribute it and/or modify 7 | // it under the terms of the GNU General Public License as published by 8 | // the Free Software Foundation, either version 3 of the License, or 9 | // (at your option) any later version. 10 | // 11 | // This program is distributed in the hope that it will be useful, 12 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | // GNU General Public License for more details. 15 | // 16 | // You should have received a copy of the GNU General Public License 17 | // along with this program. If not, see . 
//
//#include "config.h"
// NOTE(review): this listing appears to have had template arguments stripped
// (e.g. `contents.scalar()()`, `static_cast(width)`, `output->flat()` are
// missing their `<...>`), so it does not compile as shown; confirm against
// the repository source before editing the code.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
// Duplicate of the first include (harmless).
#include "tensorflow/core/framework/op.h"

using namespace tensorflow;

// DecodeFlo: decodes the bytes of a .flo optical-flow file into a float
// tensor of shape [height, width, 2].
// The shape function only fixes the output rank and last dimension. It does
// not check the input rank; Compute() enforces a scalar input at run time.
REGISTER_OP("DecodeFlo")
    .Input("contents: string")
    .Output("image: float")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      using namespace ::tensorflow::shape_inference;
      c->set_output(0,
                    c->MakeShape({InferenceContext::kUnknownDim,
                                  InferenceContext::kUnknownDim, 2}));
      return Status::OK();
    })
    .Doc(R"doc(
Decode a FLO-encoded image to a float tensor.
contents: 0-D. The FLO-encoded image.
image: 3-D with shape `[height, width, 2]`.
)doc");




// Decode the contents of a FLO file
// Expected layout, as checked below: a 12-byte header (4-byte tag "PIEH",
// uint32 width at offset 4, uint32 height at offset 8), followed by exactly
// width * height * 2 floats that are copied verbatim into the output.
class DecodeFloOp : public OpKernel {
 public:
  explicit DecodeFloOp(OpKernelConstruction* context) : OpKernel(context) {
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& contents = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
                errors::InvalidArgument("contents must be scalar, got shape ",
                                        contents.shape().DebugString()));

    // Start decoding image to get shape details
    const StringPiece data = contents.scalar()();
    // The header alone is 12 bytes.
    if (data.size() < 12) {
      OP_REQUIRES(context, false,
                  errors::InvalidArgument("Invalid FLO data size, expected at least 12"));
    }

    if (!data.starts_with("PIEH")) {
      OP_REQUIRES(context, false,
                  errors::InvalidArgument("Invalid FLO header, expected 'PIEH'"));
    }
    // Re-checks the same 4 tag bytes, reinterpreted as a float.
    // NOTE(review): this and the two uint32 reads below are type-punned,
    // possibly unaligned pointer reads (memcpy would be the well-defined
    // form). They also interpret the bytes in host byte order, which
    // presumably assumes a little-endian host and file — confirm.
    if (*((float*)(data.data())) != 202021.25f) {
      OP_REQUIRES(context, false,
                  errors::InvalidArgument("Invalid FLO header, expected 202021.25 (sanity check failed)"));
    }
    uint32 width = *((uint32*)(data.data() + 4));
    uint32 height = *((uint32*)(data.data() + 8));

    // Verify that width and height are not too large:
    // - verify width and height don't overflow int.
    // - width can later be multiplied by channels_ and sizeof(uint16), so
    //   verify single dimension is not too large.
    // - verify when width and height are multiplied together, there are a few
    //   bits to spare as well.
    // (width and height are unsigned, so `<= 0` here is effectively `== 0`.)
    const int64 total_size =
        static_cast(width) * static_cast(height);
    if (width != static_cast(width) || width <= 0 ||
        width >= (1LL << 27) || height != static_cast(height) ||
        height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
      OP_REQUIRES(context, false,
                  errors::InvalidArgument("FLO size too large for int: ",
                                          width, " by ", height));
    }

    // The payload must be exactly 2 floats (4 bytes each) per pixel. These
    // products are computed in 32-bit unsigned arithmetic; the total_size
    // bound above keeps width * height * 2 * 4 below 2^32.
    if (data.size() != 12 + width * height * 2 * 4) {
      OP_REQUIRES(context, false,
                  errors::InvalidArgument("Invalid FLO data size, expected ", 12 + width * height * 2 * 4));
    }

    // Allocate tensor
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({height, width, 2}), &output));

    // Finish decoding image
    // Raw byte copy: the file's float payload is taken as-is in the order
    // [height, width, 2], with no byte swapping.
    const uint8* innerData = (const uint8*)(data.data() + 12);
    memcpy(output->flat().data(), innerData, height * width * 2 * sizeof(float));
  }
};
REGISTER_KERNEL_BUILDER(Name("DecodeFlo").Device(DEVICE_CPU), DecodeFloOp);

-------------------------------------------------------------------------------- /custom_ops/source_code/decode_ppm_op.cc: -------------------------------------------------------------------------------- 1 | // 2 | // lmbspecialops - a collection of tensorflow ops 3 | // Copyright (C) 2017 Albert Ludwigs University of Freiburg, Pattern Recognition and Image Processing, Computer Vision Group 4 | // Author(s): Lukas Voegtle 5 | // 6 | // This program is free software: you can redistribute it and/or modify 7 | // it under the terms of the GNU General Public
License as published by 8 | // the Free Software Foundation, either version 3 of the License, or 9 | // (at your option) any later version. 10 | // 11 | // This program is distributed in the hope that it will be useful, 12 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | // GNU General Public License for more details. 15 | // 16 | // You should have received a copy of the GNU General Public License 17 | // along with this program. If not, see . 18 | // 19 | #include "tensorflow/core/framework/op.h" 20 | #include "tensorflow/core/framework/shape_inference.h" 21 | #include "tensorflow/core/framework/op_kernel.h" 22 | #include "tensorflow/core/framework/op.h" 23 | 24 | using namespace tensorflow; 25 | 26 | REGISTER_OP("DecodePpm") 27 | .Input("contents: string") 28 | .Attr("dtype: {uint8, uint16} = DT_UINT8") 29 | .Output("image: dtype") 30 | .Output("maxval: dtype") 31 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 32 | using namespace ::tensorflow::shape_inference; 33 | ShapeHandle unused; 34 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); 35 | c->set_output(0, 36 | c->MakeShape({InferenceContext::kUnknownDim, 37 | InferenceContext::kUnknownDim, 3})); 38 | c->set_output(1, c->Scalar()); 39 | return Status::OK(); 40 | }) 41 | .Doc(R"doc( 42 | Decode a PPM-encoded image to a uint8 or uint16 tensor. 43 | contents: 0-D. The PPM-encoded image. 44 | image: 3-D with shape `[height, width, 3]`. 45 | maxval: maxval from the ppm file. 
46 | )doc"); 47 | 48 | 49 | 50 | 51 | // Decode the contents of a PPM file 52 | class DecodePpmOp : public OpKernel { 53 | public: 54 | explicit DecodePpmOp(OpKernelConstruction* context) : OpKernel(context) { 55 | DataType dt; 56 | OP_REQUIRES_OK(context, context->GetAttr("dtype", &dt)); 57 | OP_REQUIRES( 58 | context, dt == DataType::DT_UINT8 || dt == DataType::DT_UINT16, 59 | errors::InvalidArgument("Type must be UINT8 or UINT16, got ", dt)); 60 | if (dt == DataType::DT_UINT8) { 61 | desired_channel_bits_ = 8; 62 | } else { 63 | desired_channel_bits_ = 16; 64 | } 65 | } 66 | 67 | static bool skipWhitespace(OpKernelContext* context, const StringPiece& data, size_t* index, bool onlyOne = false) { 68 | if ((*index) >= data.size()) { 69 | context->CtxFailure(errors::InvalidArgument("Invalid PPM header, unexpected end of file")); 70 | return false; 71 | } 72 | char c = data[(*index)]; 73 | if (c != ' ' && c != '\t' && c != '\r' && c != '\n' && c != '#') { 74 | context->CtxFailure(errors::InvalidArgument("Invalid PPM header, expected whitespace at ", (*index))); 75 | return false; 76 | } 77 | (*index)++; 78 | if (onlyOne) { 79 | return true; 80 | } 81 | for (; (*index) < data.size(); (*index)++) { 82 | c = data[(*index)]; 83 | // Skip comments 84 | if (c == '#') { 85 | for (; (*index) < data.size(); (*index)++) { 86 | c = data[(*index)]; 87 | if (c == '\r' || c == '\n') { 88 | break; 89 | } 90 | } 91 | } 92 | if (c != ' ' && c != '\t' && c != '\r' && c != '\n') { 93 | break; 94 | } 95 | } 96 | return true; 97 | } 98 | 99 | static bool readInt(OpKernelContext* context, const StringPiece& data, size_t* index, uint32* number) { 100 | if ((*index) >= data.size()) { 101 | context->CtxFailure(errors::InvalidArgument("Invalid PPM header, unexpected end of file")); 102 | return false; 103 | } 104 | char c = data[(*index)]; 105 | if (c < '0' || c > '9') { 106 | context->CtxFailure(errors::InvalidArgument("Invalid PPM header, expected float at ", (*index))); 107 | return false; 
108 | } 109 | *number = 0; 110 | for (; (*index) < data.size(); (*index)++) { 111 | c = data[(*index)]; 112 | if (c >= '0' && c <= '9') { 113 | *number *= 10; 114 | *number += c - '0'; 115 | } else { 116 | break; 117 | } 118 | } 119 | return true; 120 | } 121 | 122 | void Compute(OpKernelContext* context) override { 123 | const Tensor& contents = context->input(0); 124 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), 125 | errors::InvalidArgument("contents must be scalar, got shape ", 126 | contents.shape().DebugString())); 127 | 128 | // Start decoding image to get shape details 129 | const StringPiece data = contents.scalar()(); 130 | if (data.size() < 9) { 131 | //P6000 132 | OP_REQUIRES(context, false, 133 | errors::InvalidArgument("Invalid PPM data size, data too small for PPM file")); 134 | } 135 | if (!data.starts_with("P6")) { 136 | OP_REQUIRES(context, false, 137 | errors::InvalidArgument("Invalid PPM header, expected 'P6'")); 138 | } 139 | 140 | size_t index = 2; 141 | if (!skipWhitespace(context, data, &index)) { 142 | return; 143 | } 144 | uint32 width = 0, height = 0, maxval = 0; 145 | if (!readInt(context, data, &index, &width)) { 146 | return; 147 | } 148 | if (!skipWhitespace(context, data, &index)) { 149 | return; 150 | } 151 | if (!readInt(context, data, &index, &height)) { 152 | return; 153 | } 154 | if (!skipWhitespace(context, data, &index)) { 155 | return; 156 | } 157 | if (!readInt(context, data, &index, &maxval)) { 158 | return; 159 | } 160 | if (!skipWhitespace(context, data, &index, true)) { 161 | return; 162 | } 163 | // Verify that width and height are not too large: 164 | // - verify width and height don't overflow int. 165 | // - width can later be multiplied by channels_ and sizeof(uint16), so 166 | // verify single dimension is not too large. 167 | // - verify when width and height are multiplied together, there are a few 168 | // bits to spare as well. 
169 | const int64 total_size = 170 | static_cast(width) * static_cast(height); 171 | if (width != static_cast(width) || width <= 0 || 172 | width >= (1LL << 27) || height != static_cast(height) || 173 | height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) { 174 | OP_REQUIRES(context, false, 175 | errors::InvalidArgument("PPM size too large for int: ", 176 | width, " by ", height)); 177 | } 178 | 179 | // Allocate tensor 180 | Tensor* output = nullptr; 181 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({height, width, 3}), &output)); 182 | Tensor* output_maxval = nullptr; 183 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), &output_maxval)); 184 | 185 | if (desired_channel_bits_ == 8 && maxval <= 255) { 186 | if (data.size() != index + width * height * 3 * sizeof(uint8)) { 187 | OP_REQUIRES(context, false, 188 | errors::InvalidArgument("Invalid PPM data size, expected ", index + width * height * 3 * sizeof(uint8))); 189 | } 190 | // Finish decoding image 191 | const uint8* innerData = (const uint8*)(data.data() + index); 192 | uint8* dstData = output->flat().data(); 193 | std::memcpy(dstData, innerData, height * width * 3 * sizeof(uint8)); 194 | output_maxval->scalar()() = (uint8)maxval; 195 | } else if (desired_channel_bits_ == 16 && maxval > 255) { 196 | if (data.size() != index + width * height * 3 * sizeof(uint16)) { 197 | OP_REQUIRES(context, false, 198 | errors::InvalidArgument("Invalid PPM data size, expected ", index + width * height * 3 * sizeof(uint16))); 199 | } 200 | // Finish decoding image 201 | const uint8* innerData = (const uint8*)(data.data() + index); 202 | uint16* dstData = output->flat().data(); 203 | std::memcpy(dstData, innerData, height * width * 3 * sizeof(uint16)); 204 | // PPM data is always in big endian 205 | if (port::kLittleEndian) { 206 | // Change endianness from big endian to system endianness 207 | size_t size = height * width * 3 * sizeof(uint16); 208 | uint8* bytes = 
(uint8*)dstData; 209 | for (size_t i = 0; i < size; i+=sizeof(uint16)) { 210 | for (size_t j = 0; j < sizeof(uint16) / 2; j++) { 211 | std::swap(bytes[i + j], bytes[i + (sizeof(uint16) - j - 1)]); 212 | } 213 | } 214 | } 215 | 216 | output_maxval->scalar()() = (uint16)maxval; 217 | } else { 218 | OP_REQUIRES(context, false, 219 | errors::InvalidArgument("PPM maxval ", maxval, " does not match requested bit depth format ", desired_channel_bits_)); 220 | } 221 | } 222 | 223 | private: 224 | int desired_channel_bits_; 225 | }; 226 | REGISTER_KERNEL_BUILDER(Name("DecodePpm").Device(DEVICE_CPU), DecodePpmOp); 227 | 228 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import tensorflow as tf 4 | import custom_ops.compiled as compiled_ops 5 | import urllib 6 | import zipfile 7 | 8 | class DataLoader: 9 | 10 | def __init__(self, dirs, total_num_examples): 11 | 12 | self.dirs = dirs 13 | self.total_num_examples = total_num_examples 14 | 15 | self.download_and_extract_data_if_necessary() 16 | 17 | def download_and_extract_data_if_necessary(self): 18 | cwd = os.getcwd() 19 | url = 'https://lmb.informatik.uni-freiburg.de/data/FlyingChairs/FlyingChairs.zip' 20 | filepath = cwd + '/' + url.split('/')[-1] 21 | 22 | if not os.path.isdir(filepath[:-4]): 23 | 24 | result = input('About to download and extract the flying chairs dataset.' 25 | '\nThis requires ~85GB of free space. Continue? 
[y/n]\n') 26 | 27 | if result == 'y': 28 | print('downloading flying chairs dataset as zip, this may take ~30-60 minutes...') 29 | urllib.request.urlretrieve(url, filepath) 30 | else: 31 | exit(0) 32 | 33 | print('extracting flying chairs dataset, this may take a while...') 34 | 35 | zip_ref = zipfile.ZipFile(filepath, 'r') 36 | zip_ref.extractall(cwd) 37 | zip_ref.close() 38 | os.remove(filepath) 39 | os.rename(filepath[:-4] + '_release', filepath[:-4]) 40 | 41 | def input_parser(self, img_paths, flow_img_paths): 42 | 43 | imgs = tf.map_fn(lambda img_path: compiled_ops.decode_ppm(tf.read_file(img_path[0]))[0], img_paths, dtype=tf.uint8) 44 | flow_imgs = tf.map_fn(lambda flow_img_path: compiled_ops.decode_flo(tf.read_file(flow_img_path[0])), flow_img_paths, dtype=tf.float32) 45 | return imgs, flow_imgs 46 | 47 | def prime_image_data(self, image_dir, starting_example, ending_example, file_extension): 48 | 49 | image_filenames = [file for file in os.listdir(image_dir) if file.endswith(file_extension)] 50 | image_filenames.sort() 51 | image_paths = [image_dir + image_filename for image_filename in image_filenames] 52 | 53 | grouped_images = [] 54 | group = [] 55 | image_num = starting_example 56 | for image_path in image_paths: 57 | trimmed_string = image_path.split('/')[-1].split('_')[0] 58 | current_image_num = int(trimmed_string) - 1 if str.isdigit(trimmed_string) else int(trimmed_string[1:]) - 1 59 | if current_image_num >= starting_example: 60 | if current_image_num != image_num: 61 | image_num = current_image_num 62 | if current_image_num > ending_example: 63 | break 64 | grouped_images.append(group) 65 | group = [] 66 | group.append(image_path) 67 | 68 | grouped_images.append(group) 69 | 70 | return grouped_images 71 | 72 | 73 | def prime_data_for_loading(self, starting_example, ending_example, batch_size, training): 74 | 75 | self.batch_size = batch_size 76 | 77 | image_paths_list = list() 78 | image_paths_tensor_list = list() 79 | 80 | rgb_image_dir = 
self.dirs.rgb_image_dir 81 | grouped_rgb_images = self.prime_image_data(rgb_image_dir, starting_example, ending_example, self.dirs.rgb_format) 82 | rgb_img_paths_tensor = tf.expand_dims(tf.constant(grouped_rgb_images), -1) 83 | 84 | image_paths_list.append(grouped_rgb_images) 85 | image_paths_tensor_list.append(rgb_img_paths_tensor) 86 | 87 | flow_image_dir = self.dirs.flow_image_dir 88 | grouped_flow_images = self.prime_image_data(flow_image_dir, starting_example, ending_example, self.dirs.flow_format) 89 | flow_img_paths_tensor = tf.expand_dims(tf.constant(grouped_flow_images),-1) 90 | 91 | image_paths_list.append(grouped_flow_images) 92 | image_paths_tensor_list.append(flow_img_paths_tensor) 93 | 94 | data = tf.data.Dataset.from_tensor_slices(tuple(image_paths_tensor_list)) 95 | data = data.map(map_func=self.input_parser, num_parallel_calls=8) 96 | data = data.shuffle(tf.cast(batch_size*100,tf.int64)) 97 | data = data.apply(tf.contrib.data.batch_and_drop_remainder(tf.cast(batch_size,tf.int64))) 98 | data = data.prefetch(1) 99 | data = data.repeat() 100 | 101 | return data 102 | 103 | def prime_data(self, start_it_train, end_it_train, start_it_valid, end_it_valid, batch_size, data_type): 104 | 105 | training_dataset = self.prime_data_for_loading(start_it_train, end_it_train, batch_size, True) 106 | self.training_data_len = end_it_train - start_it_train 107 | 108 | validation_dataset = self.prime_data_for_loading(start_it_valid, end_it_valid, batch_size, False) 109 | self.validation_data_len = end_it_valid - start_it_valid 110 | 111 | iterator = tf.data.Iterator.from_string_handle(data_type, training_dataset.output_types, training_dataset.output_shapes) 112 | self.next_element = iterator.get_next() 113 | 114 | training_iterator = training_dataset.make_initializable_iterator() 115 | validation_iterator = validation_dataset.make_initializable_iterator() 116 | 117 | self.training_init_op = training_iterator.make_initializer(training_dataset) 118 | 
self.validation_init_op = validation_iterator.make_initializer(validation_dataset) 119 | 120 | self.training_handle = self.sess.run(training_iterator.string_handle()) 121 | self.validation_handle = self.sess.run(validation_iterator.string_handle()) 122 | 123 | def set_session(self, sess): 124 | self.sess = sess -------------------------------------------------------------------------------- /directories.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Directories: 4 | 5 | def __init__(self): 6 | self.init_directories() 7 | 8 | def init_directories(self): 9 | self.__project_dir = os.getcwd() 10 | self.__chkpt_dir = '/chkpts/' 11 | self.__log_dir = '/log/' 12 | self.__rgb_image_dir = '/FlyingChairs/data/' 13 | self.__flow_image_dir = '/FlyingChairs/data/' 14 | self.__rgb_format = '.ppm' 15 | self.__flow_format = '.flo' 16 | self.__training_dir = 'training/' 17 | self.__validation_dir = 'validation/' 18 | self.__networks_dir = '/networks/' 19 | 20 | def update_for_task(self, current_dir, network_dir): 21 | 22 | self.init_directories() 23 | 24 | self.__networks_dir = current_dir + self.__networks_dir 25 | self.__chkpt_dir = network_dir + self.__chkpt_dir 26 | self.__log_dir = network_dir + self.__log_dir 27 | self.__rgb_image_dir = current_dir + self.__rgb_image_dir 28 | self.__flow_image_dir = current_dir + self.__flow_image_dir 29 | 30 | # Getters # 31 | #---------# 32 | 33 | @property 34 | def project_dir(self): 35 | return self.__project_dir 36 | @property 37 | def networks_dir(self): 38 | return self.__networks_dir 39 | @property 40 | def chkpt_dir(self): 41 | return self.__chkpt_dir 42 | @property 43 | def log_dir(self): 44 | return self.__log_dir 45 | @property 46 | def rgb_image_dir(self): 47 | return self.__rgb_image_dir 48 | @property 49 | def flow_image_dir(self): 50 | return self.__flow_image_dir 51 | @property 52 | def image_data_filename(self): 53 | return self.__image_data_filename 54 | 
@property 55 | def rgb_format(self): 56 | return self.__rgb_format 57 | @property 58 | def flow_format(self): 59 | return self.__flow_format 60 | @property 61 | def training_dir(self): 62 | return self.__training_dir 63 | @property 64 | def validation_dir(self): 65 | return self.__validation_dir 66 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import faulthandler 3 | import numpy as np 4 | from task import Task 5 | 6 | def main(): 7 | 8 | seed = 1 9 | np.random.seed(seed) 10 | 11 | faulthandler.enable() 12 | 13 | mode = input('\n\ntrain (t), visualise (v) ?\n') 14 | while mode != 't' and mode != 'v': 15 | mode = input('\n\ntrain (t), visualise (v) ?\n') 16 | 17 | task = Task() 18 | 19 | if mode == 't': 20 | task.run(train=True) 21 | else: 22 | task.run(train=False) 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import math 4 | import numpy as np 5 | import custom_ops.native as native_ops 6 | import custom_ops.compiled as compiled_ops 7 | 8 | class Dims(): 9 | def __init__(self): 10 | self.input_vector = 0 11 | self.input_image = [384,448,3] 12 | self.output = [384,448,2] 13 | 14 | class Network(): 15 | 16 | def __init__(self, data_loader, dirs): 17 | 18 | self.data_loader = data_loader 19 | self.chkpt_dir = dirs.chkpt_dir 20 | self.log_dir = dirs.log_dir 21 | self.log_dir_training = dirs.log_dir + 'training/' 22 | self.log_dir_validation = dirs.log_dir + 'validation/' 23 | 24 | self.dims = Dims() 25 | 26 | self.global_step = tf.placeholder(tf.int32) 27 | 28 | self.window_size = 11 29 | self.max_dist = int(math.floor(self.window_size/2)) 30 | 31 | self.batch_size = 8 32 | 33 | 
self.initial_learning_rate = 2e-4 34 | self.learning_decrement_rate = 0.2e6 35 | self.learning_decrement = 0.5 36 | self.min_learning_rate = 6.25e-6 37 | self.max_learning_rate = 1e-4 38 | 39 | self.total_iterations = 1e6 40 | 41 | self.validation_ratio = 0.1 42 | self.num_training_examples = int(round(self.data_loader.total_num_examples * (1 - self.validation_ratio))) 43 | self.num_validation_examples = self.data_loader.total_num_examples - self.num_training_examples 44 | self.total_num_examples = self.num_training_examples + self.num_validation_examples 45 | 46 | self.define_placeholders() 47 | 48 | @staticmethod 49 | def directory(): 50 | return os.path.dirname(os.path.realpath(__file__)) 51 | 52 | 53 | # Operations # 54 | #------------# 55 | 56 | def warp(self, ct, w): 57 | return native_ops.image_warp(ct, w) 58 | 59 | def cost_volume(self, cwt, ctm1): 60 | return compiled_ops.correlation(cwt, ctm1, pad=4, kernel_size=1, max_displacement=4, stride_1=1, stride_2=1) 61 | 62 | 63 | # Inputs # 64 | #--------# 65 | 66 | def define_placeholders(self): 67 | 68 | self.data_type = tf.placeholder(tf.string) 69 | self.vis_mode = tf.placeholder(tf.bool) 70 | self.step_placeholder = tf.placeholder(tf.int32) 71 | 72 | def define_data_loader(self): 73 | 74 | self.loaded_data = list(self.data_loader.next_element) 75 | self.loaded_images = tf.cast(self.loaded_data[0],tf.float32) 76 | self.loaded_flow_images = tf.cast(self.loaded_data[1][:,0:1],tf.float32)/20 77 | 78 | def define_network_inputs(self): 79 | 80 | # rgb images 81 | self.stacked_images = tf.reshape(self.loaded_images, (-1, 384, 512, 3)) 82 | 83 | # flow images 84 | self.appended_flow_images = tf.concat((self.loaded_flow_images, tf.zeros((self.batch_size,1, 384, 512, 2))), 1) 85 | self.stacked_flow_images = tf.reshape(self.appended_flow_images, (-1, 384, 512, 2)) 86 | 87 | # combined images 88 | self.stacked_combined_images = tf.concat((self.stacked_images, self.stacked_flow_images),-1) 89 | 
self.cropped_combined_images = tf.map_fn(lambda img: tf.image.crop_to_bounding_box(img, 90 | 0, tf.random_uniform((), 0, 512 - 448, tf.int32), 91 | 384, 448), self.stacked_combined_images) 92 | 93 | self.unstacked_combined_images = tf.reshape(self.cropped_combined_images, (self.batch_size, 2, self.dims.input_image[0], 94 | self.dims.input_image[1], 5)) 95 | 96 | # rgb images 97 | self.unstacked_images = self.unstacked_combined_images[:,:,:,:,0:3] 98 | self.x_images = tf.transpose(self.unstacked_images, [0,1,4,2,3]) 99 | 100 | # flow images 101 | self.x_flow_images = self.unstacked_combined_images[:,0,:,:,3:5] 102 | 103 | def define_network_targets(self): 104 | 105 | # ground truth outputs 106 | self.gt_ys = list() 107 | for i in range(5): 108 | image = tf.image.resize_images(self.x_flow_images, (int(self.dims.input_image[0]/(math.pow(2,6-i))), 109 | int(self.dims.input_image[1]/(math.pow(2,6-i))))) 110 | self.gt_ys.append(tf.transpose(image, [0,3,1,2])) 111 | 112 | 113 | # Architecture # 114 | #--------------# 115 | 116 | def feature_pyramid_forward_pass(self, I): 117 | 118 | ct0_5 = tf.layers.conv2d(I,16,[3,3],(2,2),'same','channels_first',(1,1),tf.nn.leaky_relu) 119 | ct1 = tf.layers.conv2d(ct0_5, 16, [3, 3], (1, 1), 'same', 'channels_first', (1, 1), tf.nn.leaky_relu) 120 | 121 | ct1_5 = tf.layers.conv2d(ct1,32,[3,3],(2,2),'same','channels_first',(1,1),tf.nn.leaky_relu) 122 | ct2 = tf.layers.conv2d(ct1_5, 32, [3, 3], (1, 1), 'same', 'channels_first', (1, 1), tf.nn.leaky_relu) 123 | 124 | ct2_5 = tf.layers.conv2d(ct2,64,[3,3],(2,2),'same','channels_first',(1,1),tf.nn.leaky_relu) 125 | ct3 = tf.layers.conv2d(ct2_5, 64, [3, 3], (1, 1), 'same', 'channels_first', (1, 1), tf.nn.leaky_relu) 126 | 127 | ct3_5 = tf.layers.conv2d(ct3,64,[3,3],(2,2),'same','channels_first',(1,1),tf.nn.leaky_relu) 128 | ct4 = tf.layers.conv2d(ct3_5, 64, [3, 3], (1, 1), 'same', 'channels_first', (1, 1), tf.nn.leaky_relu) 129 | 130 | ct4_5 = 
tf.layers.conv2d(ct4,64,[3,3],(2,2),'same','channels_first',(1,1),tf.nn.leaky_relu) 131 | ct5 = tf.layers.conv2d(ct4_5, 64, [3, 3], (1, 1), 'same', 'channels_first', (1, 1), tf.nn.leaky_relu) 132 | 133 | ct5_5 = tf.layers.conv2d(ct5,64,[3,3],(2,2),'same','channels_first',(1,1),tf.nn.leaky_relu) 134 | ct6 = tf.layers.conv2d(ct5_5, 64, [3, 3], (1, 1), 'same', 'channels_first', (1, 1), tf.nn.leaky_relu) 135 | 136 | return [ct6, ct5, ct4, ct3, ct2, ct1] 137 | 138 | def flow_estimator_forward_pass(self, ctm1, ct, w): 139 | 140 | ct_trans = tf.transpose(ct, [0,2,3,1]) 141 | w_trans = tf.transpose(w, [0,2,3,1]) 142 | 143 | cwt_trans = self.warp(ct_trans,w_trans) 144 | 145 | cwt = tf.transpose(cwt_trans, [0,3,1,2]) 146 | 147 | cvt = self.cost_volume(cwt,ctm1) 148 | 149 | concatted = tf.concat((cvt,w,ctm1),-3) 150 | 151 | conv1 = tf.layers.conv2d(concatted,128,[3,3],(1,1),'same','channels_first',(1,1),tf.nn.leaky_relu) 152 | conv2 = tf.layers.conv2d(conv1,128,[3,3],(1,1),'same','channels_first',(1,1),tf.nn.leaky_relu) 153 | conv3 = tf.layers.conv2d(conv2,96,[3,3],(1,1),'same','channels_first',(1,1),tf.nn.leaky_relu) 154 | conv4 = tf.layers.conv2d(conv3,64,[3,3],(1,1),'same','channels_first',(1,1),tf.nn.leaky_relu) 155 | ft = tf.layers.conv2d(conv4,32,[3,3],(1,1),'same','channels_first',(1,1),tf.nn.leaky_relu) 156 | wt = tf.layers.conv2d(ft,2,[3,3],(1,1),'same','channels_first',(1,1),tf.nn.leaky_relu) 157 | 158 | return ft, wt 159 | 160 | def context_forward_pass(self, ft, wt): 161 | 162 | concatted = tf.concat((ft, wt), -3) 163 | 164 | conv1 = tf.layers.conv2d(concatted,128,[3,3],(1,1),'same','channels_first',(1,1),tf.nn.leaky_relu) 165 | conv2 = tf.layers.conv2d(conv1,128,[3,3],(1,1),'same','channels_first',(2,2),tf.nn.leaky_relu) 166 | conv3 = tf.layers.conv2d(conv2,128,[3,3],(1,1),'same','channels_first',(4,4),tf.nn.leaky_relu) 167 | conv4 = tf.layers.conv2d(conv3,96,[3,3],(1,1),'same','channels_first',(8,8),tf.nn.leaky_relu) 168 | conv5 = 
tf.layers.conv2d(conv4,64,[3,3],(1,1),'same','channels_first',(16,16),tf.nn.leaky_relu) 169 | conv6 = tf.layers.conv2d(conv5,32,[3,3],(1,1),'same','channels_first',(1,1),tf.nn.leaky_relu) 170 | conv7 = tf.layers.conv2d(conv6,2,[3,3],(1,1),'same','channels_first',(1,1),tf.nn.leaky_relu) 171 | 172 | return conv7 + wt 173 | 174 | def define_network_structure(self): 175 | 176 | I1 = self.x_images[:,0] 177 | I2 = self.x_images[:,1] 178 | 179 | self.features1 = self.feature_pyramid_forward_pass(I1) 180 | self.features2 = self.feature_pyramid_forward_pass(I2) 181 | 182 | self.flow_terms = list() 183 | self.upsampled_flow_terms = list() 184 | self.flow_features = list() 185 | 186 | for i, features in enumerate(zip(self.features1,self.features2)): 187 | 188 | feature1 = features[0] 189 | feature2 = features[1] 190 | 191 | if i == 0: 192 | wt_upsampled = tf.zeros((self.batch_size,2,6,7)) 193 | wt_upsampled_scaled = wt_upsampled 194 | self.flow_terms.append(wt_upsampled) 195 | else: 196 | wt = self.flow_terms[i] 197 | wt_trans = tf.transpose(wt, [0, 2, 3, 1]) 198 | wt_trans_upsampled = tf.image.resize_images(wt_trans, ((int(wt.shape[-2] * 2), int(wt.shape[-1] * 2)))) 199 | wt_upsampled = tf.transpose(wt_trans_upsampled, [0, 3, 1, 2]) 200 | wt_upsampled_scaled = wt_upsampled*20/math.pow(2,6-i) 201 | 202 | self.upsampled_flow_terms.append(wt_upsampled) 203 | flow_features, estimated_flow = self.flow_estimator_forward_pass(feature1,feature2,wt_upsampled_scaled) 204 | self.flow_features.append(flow_features) 205 | 206 | refined_flow = self.context_forward_pass(flow_features, estimated_flow) 207 | 208 | self.flow_terms.append(wt_upsampled + estimated_flow + refined_flow) 209 | 210 | self.network_outputs = self.flow_terms[1:] 211 | 212 | 213 | # Build # 214 | #-------# 215 | 216 | def build_model(self): 217 | 218 | self.define_data_loader() 219 | 220 | self.define_network_inputs() 221 | self.define_network_targets() 222 | 223 | self.define_network_structure() 224 | 
self.define_cost() 225 | self.define_optimiser() 226 | self.define_optimisation() 227 | 228 | tf.global_variables_initializer().run() 229 | 230 | def load_params(self, sess, saver, chkpt_num=None, test=False): 231 | dir(tf.contrib) 232 | if chkpt_num is None: 233 | try: 234 | saver.restore(sess, tf.train.latest_checkpoint(self.chkpt_dir)) 235 | except: 236 | return False, 0 237 | with open(self.chkpt_dir + 'checkpoint', 'r') as checkpoint_file: 238 | starting_it = int(checkpoint_file.read().split('\n', 1)[0].split('-', 1)[-1][:-1]) + 1 239 | else: 240 | try: 241 | saver.restore(sess, self.chkpt_dir + 'model-' + str(chkpt_num)) 242 | except: 243 | return False, 0 244 | starting_it = chkpt_num + 1 245 | return True, starting_it 246 | 247 | 248 | # Optimisation # 249 | #--------------# 250 | 251 | def define_cost(self): 252 | 253 | self.n_targets = self.network_outputs 254 | 255 | # cost 256 | alphas = [0.32,0.08,0.02,0.01,0.005] 257 | self.cost_target = tf.constant([0.]) 258 | for i in range(5): 259 | cost_target_terms = tf.pow(self.n_targets[i]-self.gt_ys[i],2) 260 | self.cost_target += alphas[i]*tf.reduce_sum(cost_target_terms) 261 | 262 | self.regularisation_loss = tf.constant([0.]) 263 | weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 264 | for weight in weights: 265 | self.regularisation_loss += tf.nn.l2_loss(weight) 266 | factor_l2 = 0.0004 267 | 268 | self.cost_target += factor_l2*self.regularisation_loss 269 | self.cost_aux = tf.identity(float(0)) 270 | self.cost = tf.squeeze(tf.add(self.cost_target, self.cost_aux)) 271 | 272 | self.cost_last_step = self.cost 273 | 274 | def define_optimiser(self): 275 | 276 | self.learning_rate = tf.minimum(tf.maximum( 277 | tf.train.exponential_decay(self.initial_learning_rate, self.global_step, 278 | self.learning_decrement_rate, self.learning_decrement, staircase=True), 279 | self.min_learning_rate),self.max_learning_rate) 280 | 281 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate) 282 | 283 | 
def define_optimisation(self): 284 | 285 | self.updating_train_step = self.optimizer.minimize(self.cost) 286 | self.train_step = tf.cond(self.vis_mode, lambda: True, lambda: self.updating_train_step) 287 | 288 | self.finished = tf.constant(False,tf.bool) 289 | 290 | 291 | # Log Summaries # 292 | #---------------# 293 | 294 | def init_summary(self): 295 | 296 | self.cost_summary = tf.summary.scalar('cost', self.cost) 297 | self.learning_rate_summary = tf.summary.scalar('learning_rate', self.learning_rate) 298 | self.log_summary_op = tf.summary.merge([self.cost_summary, self.learning_rate_summary]) 299 | 300 | self.random_batch_num = tf.random_uniform([], 0, self.batch_size, tf.int32) 301 | self.vis_summary_op = tf.Summary() 302 | 303 | self.gpu_usage = tf.placeholder(tf.int32) 304 | self.cpu_usage = tf.placeholder(tf.int32) 305 | self.ram_usage = tf.placeholder(tf.float32) 306 | self.throughput = tf.placeholder(tf.float32) 307 | 308 | gpu_memory_summary = tf.summary.scalar('percent gpu memory', 309 | tf.divide(tf.cast(tf.contrib.memory_stats.BytesInUse(),tf.float32),tf.constant(8e7))) 310 | gpu_usage_summary = tf.summary.scalar('percent gpu usage', self.gpu_usage) 311 | cpu_usage_summary = tf.summary.scalar('percent cpu usage', self.cpu_usage) 312 | ram_usage_summary = tf.summary.scalar('percent ram usage', self.ram_usage) 313 | throughput_summary = tf.summary.scalar('throughput MB', self.throughput) 314 | 315 | self.perflog_summary_op = tf.summary.merge([gpu_memory_summary, 316 | gpu_usage_summary, 317 | cpu_usage_summary, 318 | ram_usage_summary, 319 | throughput_summary]) 320 | 321 | def get_image_summary(self, sess, dict_feed, fps): 322 | 323 | dict_feed[self.step_placeholder] = 0 324 | 325 | x_images_out, gt_flow_out, predicted_flow_out = \ 326 | sess.run((self.x_images[self.random_batch_num], self.x_flow_images[self.random_batch_num], 327 | self.network_outputs[-1][self.random_batch_num]), dict_feed) 328 | 329 | x_images_trans = np.transpose(x_images_out, 
(0,2,3,1)) 330 | predicted_flow_trans = np.transpose(predicted_flow_out, (1,2,0)) 331 | 332 | images_arr = native_ops.modify_images_for_vis(x_images_trans, gt_flow_out, predicted_flow_trans) 333 | 334 | return native_ops.convert_array_to_gif_summary(images_arr, tag='images', fps=fps) 335 | 336 | def get_log_summary(self, sess, dict_feed): 337 | return sess.run(self.log_summary_op, dict_feed) 338 | 339 | def get_summary(self, sess, summary_op, dict_feed): 340 | if summary_op is self.vis_summary_op: 341 | fps = 2 342 | return self.get_image_summary(sess, dict_feed, fps) 343 | elif summary_op is self.log_summary_op: 344 | return self.get_log_summary(sess, dict_feed) 345 | 346 | def write_summary(self, step, summary, writer): 347 | writer.add_summary(summary, step) 348 | writer.flush() 349 | 350 | def write_summaries(self, sess, i, dict_feed, summary_op, data_handles, summary_writers): 351 | 352 | training_handle = data_handles[0] 353 | validation_handle = data_handles[1] 354 | 355 | dict_feed[self.data_type] = training_handle 356 | 357 | training_summ = self.get_summary(sess, summary_op, dict_feed) 358 | self.write_summary(i,training_summ,summary_writers[0]) 359 | 360 | dict_feed[self.data_type] = validation_handle 361 | 362 | validation_summ = self.get_summary(sess, summary_op, dict_feed) 363 | self.write_summary(i,validation_summ,summary_writers[1]) 364 | -------------------------------------------------------------------------------- /readme_images/example_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_loss.png -------------------------------------------------------------------------------- /readme_images/example_training_flow1.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_training_flow1.gif -------------------------------------------------------------------------------- /readme_images/example_training_flow2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_training_flow2.gif -------------------------------------------------------------------------------- /readme_images/example_training_flow3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_training_flow3.gif -------------------------------------------------------------------------------- /readme_images/example_training_flow4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_training_flow4.gif -------------------------------------------------------------------------------- /readme_images/example_validation_flow1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_validation_flow1.gif -------------------------------------------------------------------------------- /readme_images/example_validation_flow2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_validation_flow2.gif -------------------------------------------------------------------------------- /readme_images/example_validation_flow3.gif: 
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_validation_flow3.gif
# --------------------------------------------------------------------------------
# /readme_images/example_validation_flow4.gif:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/djl11/PWCNetTensorFlow/42cfc60b36c8540cf9e71ebb28e7c1f08bb5589b/readme_images/example_validation_flow4.gif
# --------------------------------------------------------------------------------
# /task.py:
# --------------------------------------------------------------------------------

import os

from trainer import Trainer
from data_loader import DataLoader
from directories import Directories
from network import Network

class Task():
    """Wires the directories, data loader, network and trainer together."""

    def __init__(self):

        self._dirs = Directories()
        self._init_data_loader()
        self._init_network()

    def _update_dirs(self, log_folder):
        """Root data dirs at this file's directory and checkpoint/log dirs at ``Network.directory() + log_folder``."""
        current_dir = os.path.dirname(os.path.realpath(__file__))
        network_dir = Network.directory() + log_folder
        self._dirs.update_for_task(current_dir, network_dir)

    def _init_network(self):

        self.__network = Network(data_loader=self.__data_loader,
                                 dirs=self._dirs)

    def _init_data_loader(self):
        self.__data_loader = DataLoader(dirs=self._dirs,
                                        total_num_examples=22872)

    def _init_trainer(self):
        """Create (and return) the trainer with the fixed save/log/vis frequencies."""
        self.__trainer = Trainer(data_loader=self.__data_loader,
                                 network=self.__network,
                                 dirs=self._dirs,
                                 ld_chkpt=True,
                                 save_freq=10000,
                                 log_freq=200,
                                 vis_freq=5000)
        return self.__trainer


    def run(self, train):
        """Train (``train=True``) or visualise; the TF session is closed afterwards in both cases."""
        self._init_trainer()
        log_folder = '/logged_data'

        self._update_dirs(log_folder)
        # The network cached chkpt_dir when it was constructed, before the directories
        # were updated for this task, so refresh it.
        self.__network.chkpt_dir = self._dirs.chkpt_dir

        try:
            if train:
                self.__trainer.run_trainer()
            else:
                self.__trainer.run_visualiser()
        finally:
            self.__trainer.sess.close()

# --------------------------------------------------------------------------------
# /trainer.py:
# --------------------------------------------------------------------------------

import tensorflow as tf # add tensorflow framework
import os
import libtmux
import pathlib

class Trainer():
    """Runs training or visualisation: session, checkpoints, summaries and TensorBoard."""

    def __init__(self, data_loader, network, dirs, ld_chkpt, save_freq, log_freq,
                 vis_freq):
        """Create the interactive session and store collaborators and step frequencies.

        Args:
            data_loader: dataset provider; it is given this trainer's session.
            network: the model whose graph is built in ``__init_data_and_model``.
            dirs: ``Directories``; checkpoint/log paths are read from it later
                (``init_saver`` / ``init_log_dirs``), not here.
            ld_chkpt: if True, try to restore the latest checkpoint.
            save_freq, log_freq, vis_freq: step intervals for checkpointing, scalar
                logging and visualisation.
        """
        self.sess = tf.InteractiveSession(config=tf.ConfigProto(log_device_placement=False))

        self.data_loader = data_loader
        self.data_loader.set_session(self.sess)

        self.network = network
        self.num_training_examples = self.network.num_training_examples
        self.num_validation_examples = self.network.num_validation_examples
        self.total_num_examples = self.num_training_examples + self.num_validation_examples

        self.networks_dir = dirs.networks_dir
        self.ld_chkpt = ld_chkpt

        self.save_freq = save_freq
        self.log_freq = log_freq
        self.vis_freq = vis_freq
        self.dirs = dirs

    def init_saver(self):
        """Create a Saver over all global variables not named 'placeholder*', and ensure the checkpoint dir exists."""
        self.chkpt_dir = self.dirs.chkpt_dir
        all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        variables_to_save = [variable for variable in all_variables if not variable.name.startswith('placeholder')]
        self.saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=None)
        pathlib.Path(self.chkpt_dir).mkdir(parents=True, exist_ok=True)

    def initial_save(self):
        """Save an un-numbered 'model' checkpoint (with meta graph) unless one already exists."""
        if not os.path.exists(self.chkpt_dir + 'model.meta'):
            self.saver.save(self.sess, self.chkpt_dir + 'model')

    def make_log_dirs(self):
        for log_dir in self.log_dirs:
            pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

    def init_log_dirs(self):
        """Create <log_dir>/{plots,vis}/{training,validation}/ directories."""
        self.log_dir = self.dirs.log_dir

        self.log_dir_train = self.log_dir + 'plots/' + self.dirs.training_dir
        self.log_dir_valid = self.log_dir + 'plots/' + self.dirs.validation_dir

        self.vis_dir_train = self.log_dir + 'vis/' + self.dirs.training_dir
        self.vis_dir_valid = self.log_dir + 'vis/' + self.dirs.validation_dir

        self.log_dirs = [
            self.log_dir_train,
            self.log_dir_valid,
            self.vis_dir_train,
            self.vis_dir_valid]

        self.make_log_dirs()

    def init_log_summary_writers(self):
        """Writers 0/1: scalar plots for training/validation."""
        self.sw_log_train = tf.summary.FileWriter(self.log_dir_train, graph=tf.get_default_graph())
        self.sw_log_valid = tf.summary.FileWriter(self.log_dir_valid, graph=tf.get_default_graph())

        self.summary_writers[0] = self.sw_log_train
        self.summary_writers[1] = self.sw_log_valid

    def init_vis_summary_writers(self):
        """Writers 2/3: visualisations for training/validation."""
        self.sw_vis_train = tf.summary.FileWriter(self.vis_dir_train, graph=tf.get_default_graph())
        self.sw_vis_valid = tf.summary.FileWriter(self.vis_dir_valid, graph=tf.get_default_graph())

        self.summary_writers[2] = self.sw_vis_train
        self.summary_writers[3] = self.sw_vis_valid

    def init_summary_writers(self):

        self.network.init_summary()
        self.summary_writers = [0]*4

        self.init_log_summary_writers()
        self.init_vis_summary_writers()

    def init_logger(self):
        self.init_log_dirs()
        self.init_summary_writers()

    def init_dataset_loader(self, i):
        """Initialise the training and validation iterators."""
        dict_feed = {self.network.data_type: self.data_loader.training_handle, self.network.global_step: i}
        self.sess.run(self.data_loader.training_init_op,dict_feed)

        dict_feed[self.network.data_type] = self.data_loader.validation_handle
        self.sess.run(self.data_loader.validation_init_op,dict_feed)

    def write_summaries(self, i, dict_feed, summary_op, summary_writers):
        data_handles = (self.data_loader.training_handle, self.data_loader.validation_handle)
        self.network.write_summaries(self.sess, i, dict_feed, summary_op, data_handles, summary_writers)

    def log(self, i):
        self.write_summaries(i, self.dict_feed, self.network.log_summary_op, self.summary_writers[0:2])
        print('logged, step ' + str(i))

    def vis(self, i):
        self.write_summaries(i, self.dict_feed, self.network.vis_summary_op, self.summary_writers[2:4])

    def save(self, i):
        self.saver.save(self.sess, self.chkpt_dir + 'model', global_step=i, write_meta_graph=False)
        print('saved, step ' + str(i))

    def train(self, starting_it, vis_mode=False):
        """Run the step loop from ``starting_it`` up to ``network.total_iterations``.

        In training mode each step applies one optimiser update, with periodic logging,
        visualisation and checkpointing. In ``vis_mode`` no update is applied: each step
        only evaluates the cost (advancing the input pipeline), writes a visualisation
        and waits for the user.

        Returns:
            True when the loop ends (immediately if already at the final iteration).
        """
        global_step = starting_it

        self.dict_feed = {}
        self.dict_feed[self.network.vis_mode] = vis_mode
        self.dict_feed[self.network.global_step] = global_step
        self.dict_feed[self.network.data_type] = self.data_loader.training_handle

        if starting_it == self.network.total_iterations: return True

        if vis_mode:
            self.vis_freq = 1
            self.dict_feed[self.network.vis_mode] = True
            # Fetch only the cost. network.train_step wraps an update op created outside
            # its tf.cond, so fetching it would still change the weights.
            step_fetches = self.network.cost_last_step
        else:
            step_fetches = (self.network.train_step, self.network.cost_last_step)

        while global_step < self.network.total_iterations or self.network.total_iterations == -1:

            self.dict_feed[self.network.global_step] = global_step
            self.dict_feed[self.network.data_type] = self.data_loader.training_handle
            self.dict_feed[self.network.step_placeholder] = 0

            self.sess.run(step_fetches, self.dict_feed)

            if global_step % self.log_freq == 0 and not vis_mode:
                self.log(global_step)
            if global_step % self.vis_freq == 0:
                self.vis(global_step)
            if global_step % self.save_freq == 0 and not vis_mode:
                self.save(global_step)

            global_step += 1

            if vis_mode:
                input('press enter to visualise another example')

        return True

    def start_tensorboard(self):
        """(Re)start TensorBoard in a detached tmux session named 'tensorboard'.

        Running it in tmux keeps it alive after this process exits. It is pointed at
        ``self.dirs.log_dir``, the root of the directories the summary writers use.
        """
        import shlex
        import subprocess

        self.stop_tensorboard()
        server = libtmux.Server()
        session_name = 'tensorboard'
        subprocess.run(['tmux', 'new-session', '-s', session_name, '-d'])
        session = server.find_where({'session_name': session_name})
        self.tmux_window = session.attached_window
        pane = self.tmux_window.split_window(attach=False)
        port = 6006
        # send_keys types into a shell, so quote the path in case it contains spaces.
        pane.send_keys('tensorboard --logdir=' + shlex.quote(self.dirs.log_dir) + ' --port ' + str(port))

    def stop_tensorboard(self):
        """Kill the window of an existing 'tensorboard' tmux session, if any (best effort)."""
        try:
            server = libtmux.Server()
            session = server.find_where({'session_name': 'tensorboard'})
            if session is None:
                print('No session was running')
                return
            self.tmux_window = session.attached_window
            self.tmux_window.kill_window()
        except Exception:
            # Best effort, e.g. no tmux server running yet: there is nothing to stop.
            return

    def prime_data(self):
        """Split examples into [0, n_train) for training and [n_train, total) for validation."""
        self.data_loader.prime_data(start_it_train=0,
                                    end_it_train=self.num_training_examples-1,
                                    start_it_valid=self.num_training_examples,
                                    end_it_valid=self.total_num_examples-1,
                                    batch_size=self.network.batch_size,
                                    data_type=self.network.data_type)

    def __init_data_and_model(self):
        """Prime data, build (and optionally restore) the model, save, and start TensorBoard.

        Returns:
            The iteration to resume from (0 if nothing was restored).
        """
        self.prime_data()

        starting_iteration = 0
        self.network.build_model()
        self.init_saver()
        load_success = False
        if self.ld_chkpt is True:
            load_success, starting_iteration = self.network.load_params(self.sess, self.saver)
        print('model loaded' if load_success else 'model built')

        self.init_dataset_loader(starting_iteration)
        self.initial_save()
        self.start_tensorboard()

        return starting_iteration

    def run_trainer(self):
        print('started trainer')
        starting_iteration = self.__init_data_and_model()
        self.init_logger()
        return self.train(starting_iteration)

    def run_visualiser(self):
        print('started visualiser')
        self.__init_data_and_model()
        self.init_logger()
        self.train(0,True)
--------------------------------------------------------------------------------