├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── __init__.py ├── config ├── __init__.py ├── config.ini └── extract_config.py ├── data_augmentation.py ├── datasets.py ├── flowlib.py ├── images ├── dance.gif ├── sintel_benchmark.png └── test_images │ ├── frame_0004.png │ ├── frame_0005.png │ ├── frame_0006.png │ ├── frame_0007.png │ └── frame_0008.png ├── img_list ├── KITTI │ ├── train_2015.txt │ ├── train_raw_2015_5f_with_id.txt │ └── val_2015.txt ├── Sintel │ └── train_all_5f_with_id.txt ├── sintel_raw_clip_split.txt └── test_img_list.txt ├── main.py ├── models ├── KITTI 2015 │ ├── unsupervise_with_self_supervision.data-00000-of-00001 │ └── unsupervise_with_self_supervision.index └── Sintel │ ├── supervise_finetune.data-00000-of-00001 │ └── supervise_finetune.index ├── network.py ├── requirements.txt ├── selflow_model.py ├── utils.py └── warp.py /.gitignore: -------------------------------------------------------------------------------- 1 | # python cache 2 | *.pyc 3 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:1.15.2-gpu-py3 2 | RUN apt-get update && apt-get install -y \ 3 | libsm6 \ 4 | libxext6 \ 5 | libxrender-dev 6 | COPY ./requirements.txt / 7 | RUN pip3 install -r /requirements.txt 8 | COPY . /SelFlow 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Pengpeng Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SelFlow: Self-Supervised Learning of Optical Flow 2 | The official Tensorflow implementation of [SelFlow](https://arxiv.org/abs/1904.09117) (CVPR 2019 Oral). 3 | 4 | Authors: [Pengpeng liu](https://ppliuboy.github.io/), [Michael R. Lyu](http://www.cse.cuhk.edu.hk/lyu/), [Irwin King](https://www.cse.cuhk.edu.hk/irwin.king/), [Jia Xu](http://pages.cs.wisc.edu/~jiaxu/index.html) 5 | 6 | - Testing code and part of pre-trained models are available. 7 | - Training code: please refer to [DDFlow](https://github.com/ppliuboy/DDFlow) to implement. With the current testing code and DDFlow code, the only thing you need to do is write a superpixel generation script. We use [skimage.segmentation.slic](https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic) to generate superpixels. 8 | - Raw Sintel data used in the paper: download images from the [link](https://media.xiph.org/sintel/sintel-2k-png/), resize the image to resolution [1024*436] with nearest neighbor sampling and find the training clips in [clip split](./img_list/sintel_raw_clip_split.txt) 9 | 10 | ![](./images/dance.gif) 11 | 12 | Our SelFlow is the 1st place winner on [Sintel Optical Flow Benchmark](http://sintel.is.tue.mpg.de/results) from November 2018 to November 2019. 13 | ![](./images/sintel_benchmark.png) 14 | 15 | ## Requirements 16 | - **Software:** The code was developed with python (both python 2 and python 3 are supported), opencv, tensorflow 1.8 and anaconda (optional). It's okay to run without anaconda, but you may need to install the lacking packages by yourself when needed. For tensorflow of different versions, you may need to modify some functions accordingly. 17 | ### Dockerfile 18 | There is a dockerfile with the neccesary dependencies which you can build with the command below. 19 | 20 | ```docker build --network=host -t selflow .``` 21 | 22 | You can run the docker image with command below. 23 | 24 | ```docker run -it --rm --network=host -w /SelFlow selflow``` 25 | 26 | You can then follow the instructions below to test the model 27 | 28 | ## Usage 29 | **By default, you can get the testing results using the pre-trained Sintel model by running:** 30 | 31 | python main.py 32 | 33 | Both forward and backward optical flow and their visualization will be written to the output folder. 34 | 35 | **Please refer to the configuration file template [config](config/config.ini) for a detailed description of the different operating modes.** 36 | 37 | 38 | #### Testing 39 | - Edit [config](config/config.ini), set *mode = test*. 40 | - Create or edit a file, where the first three columns are the input image names, and the last column is the saving name. 41 | - Edit [config](config/config.ini) and set *data_list_file* to the file directory. 42 | - Edit [config](config/config.ini) and set *img_dir* to the directory of your image directory. 43 | - Run *python main.py*. 44 | - **Note** 45 | - Supervised pre-trained model: we normalize each channel to be standard normal distribution, please set *is_normalize_img=True*. 46 | - Unsupervised pre-trained model: please set *is_normalize_img=False*. 47 | 48 | ## Pre-trained Models 49 | Check [models](./models) for our pre-trained models on different datasets. 50 | 51 | ## Citation 52 | If you find SelFlow useful in your research, please consider citing: 53 | 54 | @inproceedings{Liu:2019:SelFlow, 55 | title = {SelFlow: Self-Supervised Learning of Optical Flow}, 56 | author = {Pengpeng Liu and Michael R. Lyu and Irwin King and Jia Xu}, 57 | booktitle = {CVPR}, 58 | year = {2019} 59 | } 60 | 61 | @inproceedings{Liu:2019:DDFlow, 62 | title = {DDFlow: Learning Optical Flow with Unlabeled Data Distillation}, 63 | author = {Pengpeng Liu and Irwin King and Michael R. Lyu and Jia Xu}, 64 | booktitle = {AAAI}, 65 | year = {2019}} 66 | 67 | 68 | ## Acknowledgement 69 | Part of our codes are adapted from [PWC-Net](https://github.com/NVlabs/PWC-Net) and [UnFlow](https://github.com/simonmeister/UnFlow), we thank the authors for their contributions. 70 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/__init__.py -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/config/__init__.py -------------------------------------------------------------------------------- /config/config.ini: -------------------------------------------------------------------------------- 1 | [run] 2 | # Total batch size, must be divisible by the number of GPUs. 3 | batch_size = 4 4 | 5 | # Total iteration step. 6 | iter_steps = 400000 7 | 8 | # The initial learning rate. 9 | initial_learning_rate = 1e-4 10 | 11 | # Interval for decaying the learning rate. 12 | decay_steps = 8e4 13 | 14 | # The decay rate. 15 | decay_rate = 0.5 16 | 17 | # Whether to scale optical flow during downsampling or upsampling. 18 | is_scale = True 19 | 20 | # Number of threads for loading input examples. 21 | num_input_threads = 4 22 | 23 | # 'beta1' for Adam optimizer: the exponential decay rate for the 1st moment estimates. 24 | beta1 = 0.9 25 | 26 | # Number of elements the new dataset will sample. 27 | buffer_size = 5000 28 | 29 | # Number of gpus to use. 30 | num_gpus = 1 31 | 32 | # CPU that guides mul-gpu trainging. 33 | cpu_device = /cpu:0 34 | 35 | # How many steps to save checkpoint. 36 | save_checkpoint_interval = 5000 37 | 38 | # How many steps to write summary. 39 | write_summary_interval = 200 40 | 41 | # How many steps to display log on the terminal. 42 | display_log_interval = 50 43 | 44 | # tf.ConfigProto parameters. 45 | allow_soft_placement = True 46 | log_device_placement = False 47 | 48 | # L2 weight decay. 49 | regularizer_scale = 1e-4 50 | 51 | # save direcory of model, summary, sample and so on, better save it as dataset name. 52 | save_dir = Sintel 53 | 54 | # Home directpty for checkpoints, summary and sample. 55 | model_name = supervise_finetune 56 | 57 | # Checkpoints directory, it shall be 'save_dir/model_name/checkpoint_dir'. 58 | checkpoint_dir = checkpoints 59 | 60 | # Summary directory, it shall be 'save_dir/model_name/summary_dir'. 61 | summary_dir = summary 62 | 63 | # Sample directory, it shall be 'save_dir/model_name/sample_dir'. 64 | sample_dir = sample 65 | 66 | # Mode, one of {train, test, generate_fake_flow_occlusion}. 67 | mode = test 68 | 69 | # Training mode, one of {no_self_supervision, self_supervision}. 70 | training_mode = no_self_supervision 71 | 72 | # Bool type, whether restore model from a checkpoint. 73 | is_restore_model = False 74 | 75 | # Restoration model name. If is_restore_model=True, restore this checkpoint 76 | restore_model = ./KITTI/models/no_census_no_occlusion 77 | 78 | [dataset] 79 | # Cropping height for training. 80 | crop_h = 320 81 | 82 | # Cropping width for training. 83 | crop_w = 896 84 | 85 | # Image name list. 86 | # For testing and supervised fine-tuning: 4 columns, first 3 columns are the name of three input images, the last column is the saving image name 87 | # For unsupervised training: 6 columns, first 5 columns are 5 input images, the last column is the saving image name, used for self-supervised training to match flow and occlusion map. 88 | 89 | # data_list_file = ./dataset/KITTI/train_raw_2015_with_id.txt 90 | data_list_file = ./img_list/test_img_list.txt 91 | 92 | # Image storage direcory. 93 | img_dir = ./images/test_images 94 | 95 | # Whether to normalize image as input 96 | is_normalize_img = True 97 | 98 | [self_supervision] 99 | # Image patch height for self-supervised training. 100 | target_h = 256 101 | 102 | # Image patch width for self-supervised training. 103 | target_w = 640 104 | 105 | # Generated flow and occlusion map directory. 106 | fake_flow_occ_dir = ./KITTI/sample/kitti_2015_raw 107 | 108 | 109 | [test] 110 | # Restoration model name. 111 | restore_model = ./models/Sintel/supervise_finetune 112 | save_dir = ./images/test_images 113 | 114 | [generate_fake_flow_occlusion] 115 | # Restoration model name. 116 | restore_model = ./models/KITTI/census_occlusion 117 | save_dir = ./KITTI/sample/kitti_2015_raw 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /config/extract_config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | 3 | def config_dict(config_path): 4 | """Returns the config as dictionary, 5 | where the elements have intuitively correct types. 6 | """ 7 | 8 | config = configparser.ConfigParser() 9 | config.read(config_path) 10 | 11 | d = dict() 12 | for section_key in config.sections(): 13 | sd = dict() 14 | section = config[section_key] 15 | for key in section: 16 | val = section[key] 17 | try: 18 | sd[key] = int(val) 19 | except ValueError: 20 | try: 21 | sd[key] = float(val) 22 | except ValueError: 23 | try: 24 | sd[key] = section.getboolean(key) 25 | except ValueError: 26 | sd[key] = val 27 | d[section_key] = sd 28 | return d 29 | 30 | -------------------------------------------------------------------------------- /data_augmentation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def random_crop(img_list, crop_h, crop_w): 5 | img_size = tf.shape(img_list[0]) 6 | # crop image and flow 7 | rand_offset_h = tf.random_uniform([], 0, img_size[0]-crop_h+1, dtype=tf.int32) 8 | rand_offset_w = tf.random_uniform([], 0, img_size[1]-crop_w+1, dtype=tf.int32) 9 | 10 | for i, img in enumerate(img_list): 11 | img_list[i] = tf.image.crop_to_bounding_box(img, rand_offset_h, rand_offset_w, crop_h, crop_w) 12 | 13 | return img_list 14 | 15 | def flow_vertical_flip(flow): 16 | flow = tf.image.flip_up_down(flow) 17 | flow_u, flow_v = tf.unstack(flow, axis=-1) 18 | flow_v = flow_v * -1 19 | flow = tf.stack([flow_u, flow_v], axis=-1) 20 | return flow 21 | 22 | def flow_horizontal_flip(flow): 23 | flow = tf.image.flip_left_right(flow) 24 | flow_u, flow_v = tf.unstack(flow, axis=-1) 25 | flow_u = flow_u * -1 26 | flow = tf.stack([flow_u, flow_v], axis=-1) 27 | return flow 28 | 29 | def random_flip(img_list): 30 | is_flip = tf.random_uniform([2], minval=0, maxval=2, dtype=tf.int32) 31 | 32 | for i in range(len(img_list)): 33 | img_list[i] = tf.where(is_flip[0] > 0, tf.image.flip_left_right(img_list[i]), img_list[i]) 34 | img_list[i] = tf.where(is_flip[1] > 0, tf.image.flip_up_down(img_list[i]), img_list[i]) 35 | return img_list 36 | 37 | def random_flip_with_flow(img_list, flow_list): 38 | is_flip = tf.random_uniform([2], minval=0, maxval=2, dtype=tf.int32) 39 | for i in range(len(img_list)): 40 | img_list[i] = tf.where(is_flip[0] > 0, tf.image.flip_left_right(img_list[i]), img_list[i]) 41 | img_list[i] = tf.where(is_flip[1] > 0, tf.image.flip_up_down(img_list[i]), img_list[i]) 42 | for i in range(len(flow_list)): 43 | flow_list[i] = tf.where(is_flip[0] > 0, flow_horizontal_flip(flow_list[i]), flow_list[i]) 44 | flow_list[i] = tf.where(is_flip[1] > 0, flow_vertical_flip(flow_list[i]), flow_list[i]) 45 | return img_list, flow_list 46 | 47 | 48 | def random_channel_swap(img_list): 49 | channel_permutation = tf.constant([[0, 1, 2], 50 | [0, 2, 1], 51 | [1, 0, 2], 52 | [1, 2, 0], 53 | [2, 0, 1], 54 | [2, 1, 0]]) 55 | rand_i = tf.random_uniform([], minval=0, maxval=6, dtype=tf.int32) 56 | perm = channel_permutation[rand_i] 57 | for i, img in enumerate(img_list): 58 | channel_1 = img[:, :, perm[0]] 59 | channel_2 = img[:, :, perm[1]] 60 | channel_3 = img[:, :, perm[2]] 61 | img_list[i] = tf.stack([channel_1, channel_2, channel_3], axis=-1) 62 | return img_list 63 | 64 | def flow_resize(flow, out_size, is_scale=True, method=0): 65 | ''' 66 | method: 0 mean bilinear, 1 means nearest 67 | ''' 68 | flow_size = tf.to_float(tf.shape(flow)[-3:-1]) 69 | flow = tf.image.resize_images(flow, out_size, method=method, align_corners=True) 70 | if is_scale: 71 | scale = tf.to_float(out_size) / flow_size 72 | scale = tf.stack([scale[1], scale[0]]) 73 | flow = tf.multiply(flow, scale) 74 | return flow 75 | 76 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | import scipy.misc as misc 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | from flowlib import read_flo, read_pfm 8 | from data_augmentation import * 9 | from utils import mvn 10 | 11 | class BasicDataset(object): 12 | def __init__(self, crop_h=320, crop_w=896, batch_size=4, data_list_file='path_to_your_data_list_file', 13 | img_dir='path_to_your_image_directory', fake_flow_occ_dir='path_to_your_fake_flow_occlusion_directory', is_normalize_img=True): 14 | self.crop_h = crop_h 15 | self.crop_w = crop_w 16 | self.batch_size = batch_size 17 | self.img_dir = img_dir 18 | self.data_list = np.loadtxt(data_list_file, dtype=bytes).astype(np.str) 19 | self.data_num = self.data_list.shape[0] 20 | self.fake_flow_occ_dir = fake_flow_occ_dir 21 | self.is_normalize_img = is_normalize_img 22 | 23 | # KITTI's data format for storing flow and mask 24 | # The first two channels are flow, the third channel is mask 25 | def extract_flow_and_mask(self, flow): 26 | optical_flow = flow[:, :, :2] 27 | optical_flow = (optical_flow - 32768) / 64.0 28 | mask = tf.cast(tf.greater(flow[:, :, 2], 0), tf.float32) 29 | #mask = tf.cast(flow[:, :, 2], tf.float32) 30 | mask = tf.expand_dims(mask, -1) 31 | return optical_flow, mask 32 | 33 | # The default image type is PNG. 34 | def read_and_decode(self, filename_queue): 35 | img0_name = tf.string_join([self.img_dir, '/', filename_queue[0]]) 36 | img1_name = tf.string_join([self.img_dir, '/', filename_queue[1]]) 37 | img2_name = tf.string_join([self.img_dir, '/', filename_queue[2]]) 38 | 39 | img0 = tf.image.decode_png(tf.read_file(img0_name), channels=3) 40 | img0 = tf.cast(img0, tf.float32) 41 | img1 = tf.image.decode_png(tf.read_file(img1_name), channels=3) 42 | img1 = tf.cast(img1, tf.float32) 43 | img2 = tf.image.decode_png(tf.read_file(img2_name), channels=3) 44 | img2 = tf.cast(img2, tf.float32) 45 | 46 | return img0, img1, img2 47 | 48 | # For Validation or Testing 49 | def preprocess_one_shot(self, filename_queue): 50 | img0, img1, img2 = self.read_and_decode(filename_queue) 51 | img0 = img0 / 255. 52 | img1 = img1 / 255. 53 | img2 = img2 / 255. 54 | 55 | if self.is_normalize_img: 56 | img0 = mvn(img0) 57 | img1 = mvn(img1) 58 | img2 = mvn(img2) 59 | return img0, img1, img2 60 | 61 | 62 | def create_one_shot_iterator(self, data_list, num_parallel_calls=4): 63 | """ For Validation or Testing 64 | Generate image and flow one_by_one without cropping, image and flow size may change every iteration 65 | """ 66 | data_list = tf.convert_to_tensor(data_list, dtype=tf.string) 67 | dataset = tf.data.Dataset.from_tensor_slices(data_list) 68 | dataset = dataset.map(self.preprocess_one_shot, num_parallel_calls=num_parallel_calls) 69 | dataset = dataset.batch(1) 70 | dataset = dataset.repeat() 71 | iterator = dataset.make_initializable_iterator() 72 | return iterator -------------------------------------------------------------------------------- /flowlib.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import re 5 | import sys 6 | 7 | def read_flo(filename): 8 | with open(filename, 'rb') as f: 9 | magic = np.fromfile(f, np.float32, count=1) 10 | if 202021.25 != magic: 11 | print('Magic number incorrect. Invalid .flo file') 12 | else: 13 | w = np.fromfile(f, np.int32, count=1) 14 | h = np.fromfile(f, np.int32, count=1) 15 | data = np.fromfile(f, np.float32, count=int(2*w*h)) 16 | # Reshape data into 3D array (columns, rows, bands) 17 | data2D = np.resize(data, (h[0], w[0],2)) 18 | return data2D 19 | 20 | def write_flo(filename, flow): 21 | """ 22 | write optical flow in Middlebury .flo format 23 | :param flow: optical flow map 24 | :param filename: optical flow file path to be saved 25 | :return: None 26 | """ 27 | f = open(filename, 'wb') 28 | magic = np.array([202021.25], dtype=np.float32) 29 | (height, width) = flow.shape[0:2] 30 | w = np.array([width], dtype=np.int32) 31 | h = np.array([height], dtype=np.int32) 32 | magic.tofile(f) 33 | w.tofile(f) 34 | h.tofile(f) 35 | flow.tofile(f) 36 | f.close() 37 | 38 | 39 | def read_pfm(file): 40 | file = open(file, 'rb') 41 | 42 | color = None 43 | width = None 44 | height = None 45 | scale = None 46 | endian = None 47 | 48 | header = file.readline().rstrip() 49 | header = header.decode('utf-8') 50 | if header == 'PF': 51 | color = True 52 | elif header == 'Pf': 53 | color = False 54 | else: 55 | raise Exception('Not a PFM file.') 56 | 57 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 58 | if dim_match: 59 | width, height = map(int, dim_match.groups()) 60 | else: 61 | raise Exception('Malformed PFM header.') 62 | 63 | scale = float(file.readline().rstrip().decode('utf-8')) 64 | if scale < 0: # little-endian 65 | endian = '<' 66 | scale = -scale 67 | else: 68 | endian = '>' # big-endian 69 | 70 | data = np.fromfile(file, endian + 'f') 71 | shape = (height, width, 3) if color else (height, width) 72 | 73 | data = np.reshape(data, shape) 74 | data = np.flipud(data) 75 | data = data[:, :, :2] 76 | return data 77 | 78 | 79 | def write_pfm(file, image, scale=1): 80 | file = open(file, 'wb') 81 | 82 | color = None 83 | 84 | if image.dtype.name != 'float32': 85 | raise Exception('Image dtype must be float32.') 86 | 87 | image = np.flipud(image) 88 | 89 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 90 | color = True 91 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 92 | color = False 93 | else: 94 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 95 | 96 | file.write('PF\n' if color else 'Pf\n') 97 | file.write('%d %d\n' % (image.shape[1], image.shape[0])) 98 | 99 | endian = image.dtype.byteorder 100 | 101 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 102 | scale = -scale 103 | 104 | file.write('%f\n' % scale) 105 | 106 | image.tofile(file) 107 | 108 | 109 | def flow_to_color(flow, mask=None, max_flow=None): 110 | """Converts flow to 3-channel color image. 111 | 112 | Args: 113 | flow: tensor of shape [num_batch, height, width, 2]. 114 | mask: flow validity mask of shape [num_batch, height, width, 1]. 115 | """ 116 | n = 8 117 | num_batch, height, width, _ = tf.unstack(tf.shape(flow)) 118 | mask = tf.ones([num_batch, height, width, 1]) if mask is None else mask 119 | flow_u, flow_v = tf.unstack(flow, axis=3) 120 | if max_flow is not None: 121 | max_flow = tf.maximum(tf.to_float(max_flow), 1.) 122 | else: 123 | max_flow = tf.reduce_max(tf.abs(flow * mask)) 124 | mag = tf.sqrt(tf.reduce_sum(tf.square(flow), 3)) 125 | angle = tf.atan2(flow_v, flow_u) 126 | 127 | im_h = tf.mod(angle / (2 * np.pi) + 1.0, 1.0) 128 | im_s = tf.clip_by_value(mag * n / max_flow, 0, 1) 129 | im_v = tf.clip_by_value(n - im_s, 0, 1) 130 | im_hsv = tf.stack([im_h, im_s, im_v], 3) 131 | im = tf.image.hsv_to_rgb(im_hsv) 132 | return im * mask 133 | 134 | 135 | def flow_error_image(flow_1, flow_2, mask_occ, mask_noc=None, log_colors=True): 136 | """Visualize the error between two flows as 3-channel color image. 137 | 138 | Adapted from the KITTI C++ devkit. 139 | 140 | Args: 141 | flow_1: first flow of shape [num_batch, height, width, 2]. 142 | flow_2: second flow (ground truth) 143 | mask_occ: flow validity mask of shape [num_batch, height, width, 1]. 144 | Equals 1 at (occluded and non-occluded) valid pixels. 145 | mask_noc: Is 1 only at valid pixels which are not occluded. 146 | """ 147 | mask_noc = tf.ones(tf.shape(mask_occ)) if mask_noc is None else mask_noc 148 | diff_sq = (flow_1 - flow_2) ** 2 149 | diff = tf.sqrt(tf.reduce_sum(diff_sq, [3], keepdims=True)) 150 | if log_colors: 151 | num_batch, height, width, _ = tf.unstack(tf.shape(flow_1)) 152 | colormap = [ 153 | [0,0.0625,49,54,149], 154 | [0.0625,0.125,69,117,180], 155 | [0.125,0.25,116,173,209], 156 | [0.25,0.5,171,217,233], 157 | [0.5,1,224,243,248], 158 | [1,2,254,224,144], 159 | [2,4,253,174,97], 160 | [4,8,244,109,67], 161 | [8,16,215,48,39], 162 | [16,1000000000.0,165,0,38]] 163 | colormap = np.asarray(colormap, dtype=np.float32) 164 | colormap[:, 2:5] = colormap[:, 2:5] / 255 165 | mag = tf.sqrt(tf.reduce_sum(tf.square(flow_2), 3, keepdims=True)) 166 | error = tf.minimum(diff / 3, 20 * diff / mag) 167 | im = tf.zeros([num_batch, height, width, 3]) 168 | for i in range(colormap.shape[0]): 169 | colors = colormap[i, :] 170 | cond = tf.logical_and(tf.greater_equal(error, colors[0]), 171 | tf.less(error, colors[1])) 172 | im = tf.where(tf.tile(cond, [1, 1, 1, 3]), 173 | tf.ones([num_batch, height, width, 1]) * colors[2:5], 174 | im) 175 | im = tf.where(tf.tile(tf.cast(mask_noc, tf.bool), [1, 1, 1, 3]), 176 | im, im * 0.5) 177 | im = im * mask_occ 178 | else: 179 | error = (tf.minimum(diff, 5) / 5) * mask_occ 180 | im_r = error # errors in occluded areas will be red 181 | im_g = error * mask_noc 182 | im_b = error * mask_noc 183 | im = tf.concat(axis=3, values=[im_r, im_g, im_b]) 184 | return im -------------------------------------------------------------------------------- /images/dance.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/images/dance.gif -------------------------------------------------------------------------------- /images/sintel_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/images/sintel_benchmark.png -------------------------------------------------------------------------------- /images/test_images/frame_0004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/images/test_images/frame_0004.png -------------------------------------------------------------------------------- /images/test_images/frame_0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/images/test_images/frame_0005.png -------------------------------------------------------------------------------- /images/test_images/frame_0006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/images/test_images/frame_0006.png -------------------------------------------------------------------------------- /images/test_images/frame_0007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/images/test_images/frame_0007.png -------------------------------------------------------------------------------- /images/test_images/frame_0008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/images/test_images/frame_0008.png -------------------------------------------------------------------------------- /img_list/KITTI/train_2015.txt: -------------------------------------------------------------------------------- 1 | 000186 2 | 000163 3 | 000072 4 | 000056 5 | 000071 6 | 000050 7 | 000018 8 | 000041 9 | 000029 10 | 000198 11 | 000086 12 | 000135 13 | 000112 14 | 000175 15 | 000188 16 | 000096 17 | 000130 18 | 000128 19 | 000031 20 | 000027 21 | 000095 22 | 000074 23 | 000089 24 | 000026 25 | 000022 26 | 000165 27 | 000040 28 | 000122 29 | 000117 30 | 000194 31 | 000116 32 | 000069 33 | 000145 34 | 000150 35 | 000193 36 | 000035 37 | 000176 38 | 000062 39 | 000118 40 | 000097 41 | 000164 42 | 000161 43 | 000063 44 | 000002 45 | 000091 46 | 000113 47 | 000126 48 | 000103 49 | 000121 50 | 000115 51 | 000155 52 | 000141 53 | 000013 54 | 000007 55 | 000190 56 | 000119 57 | 000061 58 | 000192 59 | 000199 60 | 000181 61 | 000036 62 | 000010 63 | 000034 64 | 000060 65 | 000148 66 | 000153 67 | 000149 68 | 000133 69 | 000104 70 | 000174 71 | 000058 72 | 000076 73 | 000087 74 | 000178 75 | 000196 76 | 000043 77 | 000024 78 | 000080 79 | 000009 80 | 000108 81 | 000147 82 | 000173 83 | 000195 84 | 000092 85 | 000143 86 | 000011 87 | 000187 88 | 000102 89 | 000144 90 | 000110 91 | 000025 92 | 000016 93 | 000093 94 | 000094 95 | 000179 96 | 000167 97 | 000012 98 | 000125 99 | 000184 100 | 000059 101 | 000055 102 | 000079 103 | 000142 104 | 000185 105 | 000067 106 | 000124 107 | 000052 108 | 000019 109 | 000084 110 | 000068 111 | 000075 112 | 000106 113 | 000172 114 | 000182 115 | 000001 116 | 000051 117 | 000077 118 | 000078 119 | 000048 120 | 000033 121 | 000131 122 | 000006 123 | 000090 124 | 000152 125 | 000170 126 | 000028 127 | 000032 128 | 000134 129 | 000014 130 | 000046 131 | 000005 132 | 000168 133 | 000082 134 | 000042 135 | 000171 136 | 000020 137 | 000139 138 | 000158 139 | 000137 140 | 000159 141 | 000162 142 | 000049 143 | 000100 144 | 000021 145 | 000045 146 | 000047 147 | 000088 148 | 000197 149 | 000183 150 | 000138 151 | 000151 152 | 000101 153 | 000114 154 | 000127 155 | 000123 156 | 000015 157 | 000008 158 | 000044 159 | 000191 160 | 000140 161 | 000099 162 | 000073 163 | 000180 164 | 000160 165 | 000129 166 | 000177 167 | 000136 168 | 000037 169 | 000146 170 | 000054 171 | -------------------------------------------------------------------------------- /img_list/KITTI/val_2015.txt: -------------------------------------------------------------------------------- 1 | 000109 2 | 000107 3 | 000105 4 | 000111 5 | 000120 6 | 000070 7 | 000003 8 | 000065 9 | 000066 10 | 000085 11 | 000017 12 | 000053 13 | 000157 14 | 000189 15 | 000154 16 | 000039 17 | 000000 18 | 000030 19 | 000064 20 | 000156 21 | 000132 22 | 000169 23 | 000166 24 | 000004 25 | 000023 26 | 000083 27 | 000057 28 | 000098 29 | 000081 30 | 000038 31 | -------------------------------------------------------------------------------- /img_list/sintel_raw_clip_split.txt: -------------------------------------------------------------------------------- 1 | clip_0000 00000555.png 00000752.png 2 | clip_0001 00000753.png 00001018.png 3 | clip_0002 00002062.png 00002111.png 4 | clip_0003 00002126.png 00002155.png 5 | clip_0004 00002546.png 00002624.png 6 | clip_0005 00002625.png 00002780.png 7 | clip_0006 00002781.png 00002835.png 8 | clip_0007 00001170.png 00001295.png 9 | clip_0008 00001019.png 00001036.png 10 | clip_0009 00001101.png 00001136.png 11 | clip_0010 00001340.png 00001370.png 12 | clip_0011 00001456.png 00001494.png 13 | clip_0012 00001553.png 00001637.png 14 | clip_0013 00001638.png 00001663.png 15 | clip_0014 00001664.png 00001777.png 16 | clip_0015 00002882.png 00002955.png 17 | clip_0016 00002956.png 00003032.png 18 | clip_0017 00003033.png 00003144.png 19 | clip_0018 00003145.png 00003217.png 20 | clip_0019 00003316.png 00003384.png 21 | clip_0020 00003385.png 00003509.png 22 | clip_0021 00003636.png 00003740.png 23 | clip_0022 00003741.png 00003890.png 24 | clip_0023 00003891.png 00004234.png 25 | clip_0024 00004262.png 00004322.png 26 | clip_0025 00004369.png 00004425.png 27 | clip_0026 00004426.png 00004481.png 28 | clip_0027 00004482.png 00004555.png 29 | clip_0028 00004556.png 00004605.png 30 | clip_0029 00004606.png 00004671.png 31 | clip_0030 00004672.png 00004714.png 32 | clip_0031 00004715.png 00004743.png 33 | clip_0032 00004744.png 00004875.png 34 | clip_0033 00004876.png 00004920.png 35 | clip_0034 00004921.png 00004944.png 36 | clip_0035 00004945.png 00005082.png 37 | clip_0036 00005083.png 00005158.png 38 | clip_0037 00005159.png 00005267.png 39 | clip_0038 00005268.png 00005379.png 40 | clip_0039 00005380.png 00005956.png 41 | clip_0040 00006375.png 00006519.png 42 | clip_0041 00006520.png 00006558.png 43 | clip_0042 00006559.png 00006615.png 44 | clip_0043 00006616.png 00006691.png 45 | clip_0044 00006692.png 00006803.png 46 | clip_0045 00006804.png 00006900.png 47 | clip_0046 00006901.png 00006944.png 48 | clip_0047 00006945.png 00007134.png 49 | clip_0048 00005957.png 00006021.png 50 | clip_0049 00006022.png 00006071.png 51 | clip_0050 00006113.png 00006139.png 52 | clip_0051 00006140.png 00006221.png 53 | clip_0052 00006222.png 00006318.png 54 | clip_0053 00006319.png 00006374.png 55 | clip_0054 00007135.png 00007266.png 56 | clip_0055 00007267.png 00007351.png 57 | clip_0056 00007352.png 00007414.png 58 | clip_0057 00007415.png 00007467.png 59 | clip_0058 00007468.png 00007503.png 60 | clip_0059 00007504.png 00007559.png 61 | clip_0060 00007574.png 00007621.png 62 | clip_0061 00007622.png 00007656.png 63 | clip_0062 00007657.png 00007733.png 64 | clip_0063 00007734.png 00007769.png 65 | clip_0064 00007770.png 00007876.png 66 | clip_0065 00007877.png 00007929.png 67 | clip_0066 00007930.png 00007971.png 68 | clip_0067 00007988.png 00008047.png 69 | clip_0068 00008048.png 00008099.png 70 | clip_0069 00008100.png 00008144.png 71 | clip_0070 00008145.png 00008215.png 72 | clip_0071 00008216.png 00008394.png 73 | clip_0072 00008396.png 00008435.png 74 | clip_0073 00008436.png 00008480.png 75 | clip_0074 00009191.png 00009245.png 76 | clip_0075 00009262.png 00009421.png 77 | clip_0076 00009526.png 00009721.png 78 | clip_0077 00009791.png 00009962.png 79 | clip_0078 00009988.png 00010047.png 80 | clip_0079 00010176.png 00010220.png 81 | clip_0080 00010221.png 00010358.png 82 | clip_0081 00010359.png 00010424.png 83 | clip_0082 00010514.png 00010617.png 84 | clip_0083 00010618.png 00010664.png 85 | clip_0084 00010665.png 00010740.png 86 | clip_0085 00010741.png 00010811.png 87 | clip_0086 00010812.png 00010907.png 88 | clip_0087 00010908.png 00010950.png 89 | clip_0088 00010951.png 00010993.png 90 | clip_0089 00010994.png 00011050.png 91 | clip_0090 00011051.png 00011093.png 92 | clip_0091 00011094.png 00011199.png 93 | clip_0092 00011200.png 00011309.png 94 | clip_0093 00011401.png 00011486.png 95 | clip_0094 00011487.png 00011609.png 96 | clip_0095 00011612.png 00011700.png 97 | clip_0096 00012060.png 00012139.png 98 | clip_0097 00012140.png 00012210.png 99 | clip_0098 00012305.png 00012380.png 100 | clip_0099 00012418.png 00012562.png 101 | clip_0100 00012606.png 00012638.png 102 | clip_0101 00012639.png 00012807.png 103 | clip_0102 00012808.png 00012872.png 104 | clip_0103 00012873.png 00012944.png 105 | clip_0104 00012945.png 00012986.png 106 | clip_0105 00012987.png 00013011.png 107 | clip_0106 00013105.png 00013130.png 108 | clip_0107 00013233.png 00013319.png 109 | clip_0108 00013320.png 00013372.png 110 | clip_0109 00013373.png 00013406.png 111 | clip_0110 00013470.png 00013564.png 112 | clip_0111 00013753.png 00013830.png 113 | clip_0112 00013831.png 00013892.png 114 | clip_0113 00013961.png 00014057.png 115 | clip_0114 00014058.png 00014092.png 116 | clip_0115 00014142.png 00014177.png 117 | clip_0116 00014249.png 00014288.png 118 | clip_0117 00014289.png 00014354.png 119 | clip_0118 00014456.png 00014493.png 120 | clip_0119 00014554.png 00014620.png 121 | clip_0120 00014621.png 00014685.png 122 | clip_0121 00014686.png 00014731.png 123 | clip_0122 00014803.png 00014876.png 124 | clip_0123 00015783.png 00015825.png 125 | clip_0124 00016124.png 00016165.png 126 | clip_0125 00016862.png 00016911.png 127 | clip_0126 00001142.png 00001169.png 128 | clip_0127 00001296.png 00001319.png 129 | clip_0128 00001037.png 00001056.png 130 | clip_0129 00001057.png 00001078.png 131 | clip_0130 00001079.png 00001100.png 132 | clip_0131 00001371.png 00001425.png 133 | clip_0132 00001320.png 00001339.png 134 | clip_0133 00001495.png 00001552.png 135 | clip_0134 00004235.png 00004261.png 136 | clip_0135 00004323.png 00004368.png 137 | clip_0136 00006072.png 00006112.png 138 | clip_0137 00009109.png 00009133.png 139 | clip_0138 00009134.png 00009190.png 140 | clip_0139 00009449.png 00009525.png 141 | clip_0140 00012563.png 00012605.png 142 | clip_0141 00013180.png 00013216.png 143 | clip_0142 00013217.png 00013232.png 144 | clip_0143 00013906.png 00013960.png 145 | -------------------------------------------------------------------------------- /img_list/test_img_list.txt: -------------------------------------------------------------------------------- 1 | frame_0004.png frame_0005.png frame_0006.png 0005 2 | frame_0005.png frame_0006.png frame_0007.png 0006 3 | frame_0006.png frame_0007.png frame_0008.png 0007 4 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | 5 | from selflow_model import SelFlowModel 6 | from config.extract_config import config_dict 7 | 8 | # manually select one or several free gpu 9 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3' 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 11 | 12 | # autonatically select one free gpu 13 | #os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 14 | #os.environ['CUDA_VISIBLE_DEVICES']=str(np.argmax([int(x.split()[2]) for x in open('tmp','r').readlines()])) 15 | #os.system('rm tmp') 16 | 17 | 18 | def main(_): 19 | config = config_dict('./config/config.ini') 20 | run_config = config['run'] 21 | dataset_config = config['dataset'] 22 | self_supervision_config = config['self_supervision'] 23 | model = SelFlowModel(batch_size=run_config['batch_size'], 24 | iter_steps=run_config['iter_steps'], 25 | initial_learning_rate=run_config['initial_learning_rate'], 26 | decay_steps=run_config['decay_steps'], 27 | decay_rate=run_config['decay_rate'], 28 | is_scale=run_config['is_scale'], 29 | num_input_threads=run_config['num_input_threads'], 30 | buffer_size=run_config['buffer_size'], 31 | beta1=run_config['beta1'], 32 | num_gpus=run_config['num_gpus'], 33 | save_checkpoint_interval=run_config['save_checkpoint_interval'], 34 | write_summary_interval=run_config['write_summary_interval'], 35 | display_log_interval=run_config['display_log_interval'], 36 | allow_soft_placement=run_config['allow_soft_placement'], 37 | log_device_placement=run_config['log_device_placement'], 38 | regularizer_scale=run_config['regularizer_scale'], 39 | cpu_device=run_config['cpu_device'], 40 | save_dir=run_config['save_dir'], 41 | checkpoint_dir=run_config['checkpoint_dir'], 42 | model_name=run_config['model_name'], 43 | sample_dir=run_config['sample_dir'], 44 | summary_dir=run_config['summary_dir'], 45 | training_mode=run_config['training_mode'], 46 | is_restore_model=run_config['is_restore_model'], 47 | restore_model=run_config['restore_model'], 48 | dataset_config=dataset_config, 49 | self_supervision_config=self_supervision_config 50 | ) 51 | 52 | if run_config['mode'] == 'test': 53 | model.test(restore_model=config['test']['restore_model'], 54 | save_dir=config['test']['save_dir'], 55 | is_normalize_img=dataset_config['is_normalize_img']) 56 | else: 57 | raise ValueError('Invalid mode. Mode should be one of {test}') 58 | 59 | if __name__ == '__main__': 60 | tf.app.run() 61 | -------------------------------------------------------------------------------- /models/KITTI 2015/unsupervise_with_self_supervision.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/models/KITTI 2015/unsupervise_with_self_supervision.data-00000-of-00001 -------------------------------------------------------------------------------- /models/KITTI 2015/unsupervise_with_self_supervision.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/models/KITTI 2015/unsupervise_with_self_supervision.index -------------------------------------------------------------------------------- /models/Sintel/supervise_finetune.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/models/Sintel/supervise_finetune.data-00000-of-00001 -------------------------------------------------------------------------------- /models/Sintel/supervise_finetune.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ppliuboy/SelFlow/0525a2a61b07785bb7c9d6c278769c218c57fc35/models/Sintel/supervise_finetune.index -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import slim 3 | from data_augmentation import flow_resize 4 | from utils import lrelu 5 | from warp import tf_warp 6 | 7 | def feature_extractor(x, train=True, trainable=True, reuse=None, regularizer=None, name='feature_extractor'): 8 | with tf.variable_scope(name, reuse=reuse, regularizer=regularizer): 9 | with slim.arg_scope([slim.conv2d], activation_fn=lrelu, kernel_size=3, padding='SAME', trainable=trainable): 10 | net = {} 11 | net['conv1_1'] = slim.conv2d(x, 16, stride=2, scope='conv1_1') 12 | net['conv1_2'] = slim.conv2d(net['conv1_1'], 16, stride=1, scope='conv1_2') 13 | 14 | net['conv2_1'] = slim.conv2d(net['conv1_2'], 32, stride=2, scope='conv2_1') 15 | net['conv2_2'] = slim.conv2d(net['conv2_1'], 32, stride=1, scope='conv2_2') 16 | 17 | net['conv3_1'] = slim.conv2d(net['conv2_2'], 64, stride=2, scope='conv3_1') 18 | net['conv3_2'] = slim.conv2d(net['conv3_1'], 64, stride=1, scope='conv3_2') 19 | 20 | net['conv4_1'] = slim.conv2d(net['conv3_2'], 96, stride=2, scope='conv4_1') 21 | net['conv4_2'] = slim.conv2d(net['conv4_1'], 96, stride=1, scope='conv4_2') 22 | 23 | net['conv5_1'] = slim.conv2d(net['conv4_2'], 128, stride=2, scope='conv5_1') 24 | net['conv5_2'] = slim.conv2d(net['conv5_1'], 128, stride=1, scope='conv5_2') 25 | 26 | net['conv6_1'] = slim.conv2d(net['conv5_2'], 192, stride=2, scope='conv6_1') 27 | net['conv6_2'] = slim.conv2d(net['conv6_1'], 192, stride=1, scope='conv6_2') 28 | 29 | return net 30 | 31 | def context_network(x, flow, train=True, trainable=True, reuse=None, regularizer=None, name='context_network'): 32 | x_input = tf.concat([x, flow], axis=-1) 33 | with tf.variable_scope(name, reuse=reuse, regularizer=regularizer): 34 | with slim.arg_scope([slim.conv2d], activation_fn=lrelu, kernel_size=3, padding='SAME', trainable=trainable): 35 | net = {} 36 | net['dilated_conv1'] = slim.conv2d(x_input, 128, rate=1, scope='dilated_conv1') 37 | net['dilated_conv2'] = slim.conv2d(net['dilated_conv1'], 128, rate=2, scope='dilated_conv2') 38 | net['dilated_conv3'] = slim.conv2d(net['dilated_conv2'], 128, rate=4, scope='dilated_conv3') 39 | net['dilated_conv4'] = slim.conv2d(net['dilated_conv3'], 96, rate=8, scope='dilated_conv4') 40 | net['dilated_conv5'] = slim.conv2d(net['dilated_conv4'], 64, rate=16, scope='dilated_conv5') 41 | net['dilated_conv6'] = slim.conv2d(net['dilated_conv5'], 32, rate=1, scope='dilated_conv6') 42 | net['dilated_conv7'] = slim.conv2d(net['dilated_conv6'], 2, rate=1, activation_fn=None, scope='dilated_conv7') 43 | 44 | refined_flow = net['dilated_conv7'] 45 | return refined_flow 46 | 47 | def estimator_network(x1, cost_volume, flow, train=True, trainable=True, reuse=None, regularizer=None, name='estimator'): 48 | net_input = tf.concat([cost_volume, x1, flow], axis=-1) 49 | with tf.variable_scope(name, reuse=reuse, regularizer=regularizer): 50 | with slim.arg_scope([slim.conv2d], activation_fn=lrelu, kernel_size=3, padding='SAME', trainable=trainable): 51 | net = {} 52 | net['conv1'] = slim.conv2d(net_input, 128, scope='conv1') 53 | net['conv2'] = slim.conv2d(net['conv1'], 128, scope='conv2') 54 | net['conv3'] = slim.conv2d(net['conv2'], 96, scope='conv3') 55 | net['conv4'] = slim.conv2d(net['conv3'], 64, scope='conv4') 56 | net['conv5'] = slim.conv2d(net['conv4'], 32, scope='conv5') 57 | net['conv6'] = slim.conv2d(net['conv5'], 2, activation_fn=None, scope='conv6') 58 | 59 | #flow_estimated = net['conv6'] 60 | 61 | return net 62 | 63 | def compute_cost_volume(x1, x2, H, W, channel, d=9): 64 | x1 = tf.nn.l2_normalize(x1, axis=3) 65 | x2 = tf.nn.l2_normalize(x2, axis=3) 66 | 67 | # choice 1: use tf.extract_image_patches, may not work for some tensorflow versions 68 | x2_patches = tf.extract_image_patches(x2, [1, d, d, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='SAME') 69 | 70 | # choice 2: use convolution, but is slower than choice 1 71 | # out_channels = d * d 72 | # w = tf.eye(out_channels*channel, dtype=tf.float32) 73 | # w = tf.reshape(w, (d, d, channel, out_channels*channel)) 74 | # x2_patches = tf.nn.conv2d(x2, w, strides=[1, 1, 1, 1], padding='SAME') 75 | 76 | x2_patches = tf.reshape(x2_patches, [-1, H, W, d, d, channel]) 77 | x1_reshape = tf.reshape(x1, [-1, H, W, 1, 1, channel]) 78 | x1_dot_x2 = tf.multiply(x1_reshape, x2_patches) 79 | 80 | cost_volume = tf.reduce_sum(x1_dot_x2, axis=-1) 81 | #cost_volume = tf.reduce_mean(x1_dot_x2, axis=-1) 82 | cost_volume = tf.reshape(cost_volume, [-1, H, W, d*d]) 83 | return cost_volume 84 | 85 | def estimator(x0, x1, x2, flow_fw, flow_bw, train=True, trainable=True, reuse=None, regularizer=None, name='estimator'): 86 | # warp x2 according to flow 87 | if train: 88 | x_shape = x1.get_shape().as_list() 89 | else: 90 | x_shape = tf.shape(x1) 91 | H = x_shape[1] 92 | W = x_shape[2] 93 | channel = x_shape[3] 94 | x2_warp = tf_warp(x2, flow_fw, H, W) 95 | x0_warp = tf_warp(x0, flow_bw, H, W) 96 | 97 | # ---------------cost volume----------------- 98 | 99 | cost_volume_fw = compute_cost_volume(x1, x2_warp, H, W, channel, d=9) 100 | cost_volume_bw = compute_cost_volume(x1, x0_warp, H, W, channel, d=9) 101 | 102 | cv_concat_fw = tf.concat([cost_volume_fw, cost_volume_bw], -1) 103 | cv_concat_bw = tf.concat([cost_volume_bw, cost_volume_fw], -1) 104 | 105 | flow_concat_fw = tf.concat([flow_fw, -flow_bw], -1) 106 | flow_concat_bw = tf.concat([flow_bw, -flow_fw], -1) 107 | 108 | net_fw = estimator_network(x1, cv_concat_fw, flow_concat_fw, train=train, trainable=trainable, reuse=reuse, regularizer=regularizer, name=name) 109 | net_bw = estimator_network(x1, cv_concat_bw, flow_concat_bw, train=train, trainable=trainable, reuse=True, regularizer=regularizer, name=name) 110 | 111 | return net_fw, net_bw 112 | 113 | def pyramid_processing_three_frame(batch_img, x0_feature, x1_feature, x2_feature, train=True, trainable=True, reuse=None, regularizer=None, is_scale=True): 114 | x_shape = tf.shape(x1_feature['conv6_2']) 115 | initial_flow_fw = tf.zeros([x_shape[0], x_shape[1], x_shape[2], 2], dtype=tf.float32, name='initial_flow_fw') 116 | initial_flow_bw = tf.zeros([x_shape[0], x_shape[1], x_shape[2], 2], dtype=tf.float32, name='initial_flow_bw') 117 | flow_fw = {} 118 | flow_bw = {} 119 | net_fw, net_bw = estimator(x0_feature['conv6_2'], x1_feature['conv6_2'], x2_feature['conv6_2'], 120 | initial_flow_fw, initial_flow_bw, train=train, trainable=trainable, reuse=reuse, regularizer=regularizer, name='estimator_level_6') 121 | flow_fw['level_6'] = net_fw['conv6'] 122 | flow_bw['level_6'] = net_bw['conv6'] 123 | 124 | 125 | for i in range(4): 126 | feature_name = 'conv%d_2' % (5-i) 127 | level = 'level_%d' % (5-i) 128 | feature_size = tf.shape(x1_feature[feature_name])[1:3] 129 | initial_flow_fw = flow_resize(flow_fw['level_%d' % (6-i)], feature_size, is_scale=is_scale) 130 | initial_flow_bw = flow_resize(flow_bw['level_%d' % (6-i)], feature_size, is_scale=is_scale) 131 | net_fw, net_bw = estimator(x0_feature[feature_name], x1_feature[feature_name], x2_feature[feature_name], 132 | initial_flow_fw, initial_flow_bw, train=train, trainable=trainable, reuse=reuse, regularizer=regularizer, name='estimator_level_%d' % (5-i)) 133 | flow_fw[level] = net_fw['conv6'] 134 | flow_bw[level] = net_bw['conv6'] 135 | 136 | flow_concat_fw = tf.concat([flow_fw['level_2'], -flow_bw['level_2']], -1) 137 | flow_concat_bw = tf.concat([flow_bw['level_2'], -flow_fw['level_2']], -1) 138 | 139 | x_feature = tf.concat([net_fw['conv5'], net_bw['conv5']], axis=-1) 140 | flow_fw['refined'] = context_network(x_feature, flow_concat_fw, train=train, trainable=trainable, reuse=reuse, regularizer=regularizer, name='context_network') 141 | flow_size = tf.shape(batch_img)[1:3] 142 | flow_fw['full_res'] = flow_resize(flow_fw['refined'], flow_size, is_scale=is_scale) 143 | 144 | x_feature = tf.concat([net_bw['conv5'], net_fw['conv5']], axis=-1) 145 | flow_bw['refined'] = context_network(x_feature, flow_concat_bw, train=train, trainable=trainable, reuse=True, regularizer=regularizer, name='context_network') 146 | flow_bw['full_res'] = flow_resize(flow_bw['refined'], flow_size, is_scale=is_scale) 147 | 148 | return flow_fw, flow_bw 149 | 150 | 151 | def pyramid_processing(batch_img0, batch_img1, batch_img2, train=True, trainable=True, reuse=None, regularizer=None, is_scale=True): 152 | x0_feature = feature_extractor(batch_img0, train=train, trainable=trainable, regularizer=regularizer, name='feature_extractor') 153 | x1_feature = feature_extractor(batch_img1, train=train, trainable=trainable, reuse=True, regularizer=regularizer, name='feature_extractor') 154 | x2_feature = feature_extractor(batch_img2, train=train, trainable=trainable, reuse=True, regularizer=regularizer, name='feature_extractor') 155 | 156 | flow_fw, flow_bw = pyramid_processing_three_frame(batch_img0, x0_feature, x1_feature, x2_feature, train=train, trainable=trainable, 157 | reuse=reuse, regularizer=regularizer, is_scale=is_scale) 158 | return flow_fw, flow_bw 159 | 160 | 161 | def pyramid_processing_five_frame(batch_img0, batch_img1, batch_img2, batch_img3, batch_img4, train=True, trainable=True, regularizer=None, is_scale=True): 162 | x0_feature = feature_extractor(batch_img0, train=train, trainable=trainable, regularizer=regularizer, name='feature_extractor') 163 | x1_feature = feature_extractor(batch_img1, train=train, trainable=trainable, reuse=True, regularizer=regularizer, name='feature_extractor') 164 | x2_feature = feature_extractor(batch_img2, train=train, trainable=trainable, reuse=True, regularizer=regularizer, name='feature_extractor') 165 | x3_feature = feature_extractor(batch_img3, train=train, trainable=trainable, reuse=True, regularizer=regularizer, name='feature_extractor') 166 | x4_feature = feature_extractor(batch_img4, train=train, trainable=trainable, reuse=True, regularizer=regularizer, name='feature_extractor') 167 | 168 | flow_fw_12, flow_bw_10 = pyramid_processing_three_frame(batch_img0, x0_feature, x1_feature, x2_feature, train=train, trainable=trainable, reuse=None, regularizer=regularizer, is_scale=is_scale) 169 | flow_fw_23, flow_bw_21 = pyramid_processing_three_frame(batch_img0, x1_feature, x2_feature, x3_feature, train=train, trainable=trainable, reuse=True, regularizer=regularizer, is_scale=is_scale) 170 | flow_fw_34, flow_bw_32 = pyramid_processing_three_frame(batch_img0, x2_feature, x3_feature, x4_feature, train=train, trainable=trainable, reuse=True, regularizer=regularizer, is_scale=is_scale) 171 | 172 | return flow_fw_12, flow_bw_10, flow_fw_23, flow_bw_21, flow_fw_34, flow_bw_32 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pillow==7.1.2 2 | scipy==0.19.1 3 | opencv-python==4.2.0.34 4 | matplotlib==3.2.1 -------------------------------------------------------------------------------- /selflow_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division, print_function, absolute_import 3 | import tensorflow as tf 4 | import numpy as np 5 | import os 6 | import sys 7 | import time 8 | import cv2 9 | 10 | from six.moves import xrange 11 | from scipy import misc, io 12 | from tensorflow.contrib import slim 13 | 14 | import matplotlib.pyplot as plt 15 | from network import pyramid_processing 16 | from datasets import BasicDataset 17 | from utils import average_gradients, lrelu, occlusion, rgb_bgr 18 | from data_augmentation import flow_resize 19 | from flowlib import flow_to_color, write_flo 20 | from warp import tf_warp 21 | 22 | class SelFlowModel(object): 23 | def __init__(self, batch_size=8, iter_steps=1000000, initial_learning_rate=1e-4, decay_steps=2e5, 24 | decay_rate=0.5, is_scale=True, num_input_threads=4, buffer_size=5000, 25 | beta1=0.9, num_gpus=1, save_checkpoint_interval=5000, write_summary_interval=200, 26 | display_log_interval=50, allow_soft_placement=True, log_device_placement=False, 27 | regularizer_scale=1e-4, cpu_device='/cpu:0', save_dir='KITTI', checkpoint_dir='checkpoints', 28 | model_name='model', sample_dir='sample', summary_dir='summary', training_mode="no_distillation", 29 | is_restore_model=False, restore_model='./models/KITTI/no_census_no_occlusion', 30 | dataset_config={}, self_supervision_config={}): 31 | self.batch_size = batch_size 32 | self.iter_steps = iter_steps 33 | self.initial_learning_rate = initial_learning_rate 34 | self.decay_steps = decay_steps 35 | self.decay_rate = decay_rate 36 | self.is_scale = is_scale 37 | self.num_input_threads = num_input_threads 38 | self.buffer_size = buffer_size 39 | self.beta1 = beta1 40 | self.num_gpus = num_gpus 41 | self.save_checkpoint_interval = save_checkpoint_interval 42 | self.write_summary_interval = write_summary_interval 43 | self.display_log_interval = display_log_interval 44 | self.allow_soft_placement = allow_soft_placement 45 | self.log_device_placement = log_device_placement 46 | self.regularizer_scale = regularizer_scale 47 | self.training_mode = training_mode 48 | self.is_restore_model = is_restore_model 49 | self.restore_model = restore_model 50 | self.dataset_config = dataset_config 51 | self.self_supervision_config = self_supervision_config 52 | self.shared_device = '/gpu:0' if self.num_gpus == 1 else cpu_device 53 | assert(np.mod(batch_size, num_gpus) == 0) 54 | self.batch_size_per_gpu = int(batch_size / np.maximum(num_gpus, 1)) 55 | 56 | self.save_dir = save_dir 57 | if not os.path.exists(self.save_dir): 58 | os.makedirs(self.save_dir) 59 | 60 | self.checkpoint_dir = '/'.join([self.save_dir, checkpoint_dir]) 61 | if not os.path.exists(self.checkpoint_dir): 62 | os.makedirs(self.checkpoint_dir) 63 | 64 | self.model_name = model_name 65 | if not os.path.exists('/'.join([self.checkpoint_dir, model_name])): 66 | os.makedirs(('/'.join([self.checkpoint_dir, self.model_name]))) 67 | 68 | self.sample_dir = '/'.join([self.save_dir, sample_dir]) 69 | if not os.path.exists(self.sample_dir): 70 | os.makedirs(self.sample_dir) 71 | if not os.path.exists('/'.join([self.sample_dir, self.model_name])): 72 | os.makedirs(('/'.join([self.sample_dir, self.model_name]))) 73 | 74 | self.summary_dir = '/'.join([self.save_dir, summary_dir]) 75 | if not os.path.exists(self.summary_dir): 76 | os.makedirs(self.summary_dir) 77 | if not os.path.exists('/'.join([self.summary_dir, 'train'])): 78 | os.makedirs(('/'.join([self.summary_dir, 'train']))) 79 | if not os.path.exists('/'.join([self.summary_dir, 'test'])): 80 | os.makedirs(('/'.join([self.summary_dir, 'test']))) 81 | 82 | 83 | def test(self, restore_model, save_dir, is_normalize_img=True): 84 | dataset = BasicDataset(data_list_file=self.dataset_config['data_list_file'], img_dir=self.dataset_config['img_dir'], is_normalize_img=is_normalize_img) 85 | save_name_list = dataset.data_list[:, -1] 86 | iterator = dataset.create_one_shot_iterator(dataset.data_list, num_parallel_calls=self.num_input_threads) 87 | batch_img0, batch_img1, batch_img2 = iterator.get_next() 88 | img_shape = tf.shape(batch_img0) 89 | h = img_shape[1] 90 | w = img_shape[2] 91 | 92 | new_h = tf.where(tf.equal(tf.mod(h, 64), 0), h, (tf.to_int32(tf.floor(h / 64) + 1)) * 64) 93 | new_w = tf.where(tf.equal(tf.mod(w, 64), 0), w, (tf.to_int32(tf.floor(w / 64) + 1)) * 64) 94 | 95 | batch_img0 = tf.image.resize_images(batch_img0, [new_h, new_w], method=1, align_corners=True) 96 | batch_img1 = tf.image.resize_images(batch_img1, [new_h, new_w], method=1, align_corners=True) 97 | batch_img2 = tf.image.resize_images(batch_img2, [new_h, new_w], method=1, align_corners=True) 98 | 99 | flow_fw, flow_bw = pyramid_processing(batch_img0, batch_img1, batch_img2, train=False, trainable=False, is_scale=True) 100 | flow_fw['full_res'] = flow_resize(flow_fw['full_res'], [h, w], method=1) 101 | flow_bw['full_res'] = flow_resize(flow_bw['full_res'], [h, w], method=1) 102 | 103 | flow_fw_color = flow_to_color(flow_fw['full_res'], mask=None, max_flow=256) 104 | flow_bw_color = flow_to_color(flow_bw['full_res'], mask=None, max_flow=256) 105 | 106 | restore_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 107 | saver = tf.train.Saver(var_list=restore_vars) 108 | sess = tf.Session() 109 | sess.run(tf.global_variables_initializer()) 110 | sess.run(iterator.initializer) 111 | saver.restore(sess, restore_model) 112 | if not os.path.exists(save_dir): 113 | os.makedirs(save_dir) 114 | for i in range(dataset.data_num): 115 | np_flow_fw, np_flow_bw, np_flow_fw_color, np_flow_bw_color = sess.run([flow_fw['full_res'], flow_bw['full_res'], flow_fw_color, flow_bw_color]) 116 | misc.imsave('%s/flow_fw_color_%s.png' % (save_dir, save_name_list[i]), np_flow_fw_color[0]) 117 | misc.imsave('%s/flow_bw_color_%s.png' % (save_dir, save_name_list[i]), np_flow_bw_color[0]) 118 | write_flo('%s/flow_fw_%s.flo' % (save_dir, save_name_list[i]), np_flow_fw[0]) 119 | write_flo('%s/flow_bw_%s.flo' % (save_dir, save_name_list[i]), np_flow_bw[0]) 120 | print('Finish %d/%d' % (i+1, dataset.data_num)) 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from warp import tf_warp 5 | 6 | def mvn(img): 7 | # minus mean color and divided by standard variance 8 | mean, var = tf.nn.moments(img, axes=[0, 1], keep_dims=True) 9 | img = (img - mean) / tf.sqrt(var + 1e-12) 10 | return img 11 | 12 | def lrelu(x, leak=0.2, name='leaky_relu'): 13 | return tf.maximum(x, leak*x) 14 | 15 | def imshow(img, re_normalize=False): 16 | if re_normalize: 17 | min_value = np.min(img) 18 | max_value = np.max(img) 19 | img = (img - min_value) / (max_value - min_value) 20 | img = img * 255 21 | elif np.max(img) <= 1.: 22 | img = img * 255 23 | img = img.astype('uint8') 24 | shape = img.shape 25 | if len(shape) == 2: 26 | img = np.repeat(np.expand_dims(img, -1), 3, -1) 27 | elif shape[2] == 1: 28 | img = np.repeat(img, 3, -1) 29 | plt.imshow(img) 30 | plt.show() 31 | 32 | def rgb_bgr(img): 33 | tmp = np.copy(img[:, :, 0]) 34 | img[:, :, 0] = np.copy(img[:, :, 2]) 35 | img[:, :, 2] = np.copy(tmp) 36 | return img 37 | 38 | def compute_Fl(flow_gt, flow_est, mask): 39 | # F1 measure 40 | err = tf.multiply(flow_gt - flow_est, mask) 41 | err_norm = tf.norm(err, axis=-1) 42 | 43 | flow_gt_norm = tf.maximum(tf.norm(flow_gt, axis=-1), 1e-12) 44 | F1_logic = tf.logical_and(err_norm > 3, tf.divide(err_norm, flow_gt_norm) > 0.05) 45 | F1_logic = tf.cast(tf.logical_and(tf.expand_dims(F1_logic, -1), mask > 0), tf.float32) 46 | F1 = tf.reduce_sum(F1_logic) / (tf.reduce_sum(mask) + 1e-6) 47 | return F1 48 | 49 | def average_gradients(tower_grads): 50 | """Calculate the average gradient for each shared variable across all towers. 51 | Note that this function provides a synchronization point across all towers. 52 | Args: 53 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 54 | is over individual gradients. The inner list is over the gradient 55 | calculation for each tower. 56 | Returns: 57 | List of pairs of (gradient, variable) where the gradient has been averaged 58 | across all towers. 59 | """ 60 | average_grads = [] 61 | for grad_and_vars in zip(*tower_grads): 62 | # Note that each grad_and_vars looks like the following: 63 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 64 | grads = [] 65 | for g, _ in grad_and_vars: 66 | if g is not None: 67 | # Add 0 dimension to the gradients to represent the tower. 68 | expanded_g = tf.expand_dims(g, 0) 69 | 70 | # Append on a 'tower' dimension which we will average over below. 71 | grads.append(expanded_g) 72 | if grads != []: 73 | # Average over the 'tower' dimension. 74 | grad = tf.concat(grads, 0) 75 | grad = tf.reduce_mean(grad, 0) 76 | 77 | # Keep in mind that the Variables are redundant because they are shared 78 | # across towers. So .. we will just return the first tower's pointer to 79 | # the Variable. 80 | v = grad_and_vars[0][1] 81 | grad_and_var = (grad, v) 82 | average_grads.append(grad_and_var) 83 | return average_grads 84 | 85 | 86 | def length_sq(x): 87 | return tf.reduce_sum(tf.square(x), 3, keepdims=True) 88 | 89 | def occlusion(flow_fw, flow_bw): 90 | x_shape = tf.shape(flow_fw) 91 | H = x_shape[1] 92 | W = x_shape[2] 93 | flow_bw_warped = tf_warp(flow_bw, flow_fw, H, W) 94 | flow_fw_warped = tf_warp(flow_fw, flow_bw, H, W) 95 | flow_diff_fw = flow_fw + flow_bw_warped 96 | flow_diff_bw = flow_bw + flow_fw_warped 97 | mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) 98 | mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped) 99 | occ_thresh_fw = 0.01 * mag_sq_fw + 0.5 100 | occ_thresh_bw = 0.01 * mag_sq_bw + 0.5 101 | occ_fw = tf.cast(length_sq(flow_diff_fw) > occ_thresh_fw, tf.float32) 102 | occ_bw = tf.cast(length_sq(flow_diff_bw) > occ_thresh_bw, tf.float32) 103 | 104 | return occ_fw, occ_bw -------------------------------------------------------------------------------- /warp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def get_pixel_value(img, x, y): 4 | """ 5 | Utility function to get pixel value for coordinate 6 | vectors x and y from a 4D tensor image. 7 | Input 8 | ----- 9 | - img: tensor of shape (B, H, W, C) 10 | - x: flattened tensor of shape (B*H*W, ) 11 | - y: flattened tensor of shape (B*H*W, ) 12 | Returns 13 | ------- 14 | - output: tensor of shape (B, H, W, C) 15 | """ 16 | shape = tf.shape(x) 17 | batch_size = shape[0] 18 | height = shape[1] 19 | width = shape[2] 20 | 21 | batch_idx = tf.range(0, batch_size) 22 | batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1)) 23 | b = tf.tile(batch_idx, (1, height, width)) 24 | 25 | indices = tf.stack([b, y, x], 3) 26 | 27 | return tf.gather_nd(img, indices) 28 | 29 | def tf_warp(img, flow, H, W): 30 | # H = 256 31 | # W = 256 32 | x,y = tf.meshgrid(tf.range(W), tf.range(H)) 33 | x = tf.expand_dims(x,0) 34 | x = tf.expand_dims(x,-1) 35 | 36 | y = tf.expand_dims(y,0) 37 | y = tf.expand_dims(y,-1) 38 | 39 | x = tf.cast(x, tf.float32) 40 | y = tf.cast(y, tf.float32) 41 | grid = tf.concat([x,y],axis = -1) 42 | # print grid.shape 43 | flows = grid+flow 44 | #print(flows.shape) 45 | max_y = tf.cast(H - 1, tf.int32) 46 | max_x = tf.cast(W - 1, tf.int32) 47 | zero = tf.zeros([], dtype=tf.int32) 48 | 49 | x = flows[:,:,:, 0] 50 | y = flows[:,:,:, 1] 51 | x0 = x 52 | y0 = y 53 | x0 = tf.cast(x0, tf.int32) 54 | x1 = x0 + 1 55 | y0 = tf.cast(y0, tf.int32) 56 | y1 = y0 + 1 57 | 58 | # clip to range [0, H/W] to not violate img boundaries 59 | x0 = tf.clip_by_value(x0, zero, max_x) 60 | x1 = tf.clip_by_value(x1, zero, max_x) 61 | y0 = tf.clip_by_value(y0, zero, max_y) 62 | y1 = tf.clip_by_value(y1, zero, max_y) 63 | 64 | # get pixel value at corner coords 65 | Ia = get_pixel_value(img, x0, y0) 66 | Ib = get_pixel_value(img, x0, y1) 67 | Ic = get_pixel_value(img, x1, y0) 68 | Id = get_pixel_value(img, x1, y1) 69 | 70 | # recast as float for delta calculation 71 | x0 = tf.cast(x0, tf.float32) 72 | x1 = tf.cast(x1, tf.float32) 73 | y0 = tf.cast(y0, tf.float32) 74 | y1 = tf.cast(y1, tf.float32) 75 | 76 | 77 | # calculate deltas 78 | wa = (x1-x) * (y1-y) 79 | wb = (x1-x) * (y-y0) 80 | wc = (x-x0) * (y1-y) 81 | wd = (x-x0) * (y-y0) 82 | 83 | # add dimension for addition 84 | wa = tf.expand_dims(wa, axis=3) 85 | wb = tf.expand_dims(wb, axis=3) 86 | wc = tf.expand_dims(wc, axis=3) 87 | wd = tf.expand_dims(wd, axis=3) 88 | 89 | # compute output 90 | out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id]) 91 | return out --------------------------------------------------------------------------------