├── README.md ├── data ├── README.md └── extract_rosbag_to_npy.py ├── motorcycle_flow.png └── src ├── EVFlowNet.py ├── basic_layers.py ├── config.py ├── data_loader.py ├── eval_utils.py ├── losses.py ├── test.py ├── train.py └── vis_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # EVFlowNet-pytorch 2 | 3 | Author: CyrilSterling 4 | 5 | EV-FlowNet implemented in PyTorch. 6 | 7 | The following is the README file from the original project. 8 | 9 | ![Predicted flow from the MVSEC motorcycle sequence.](motorcycle_flow.png) 10 | Left: subsampled flow overlaid on top of grayscale image. Right: predicted flow at each pixel with an event. 11 | 12 | This repo contains the code associated with [EV-FlowNet: Self-Supervised Optical Flow Estimation for Event-based Cameras](https://arxiv.org/abs/1802.06898). This code has most recently been tested with TensorFlow 1.10.0, but should work with anything down to 1.4.0. This code will only work with python 2. Python 3 is not supported. 13 | 14 | ## Introduction 15 | In this work, we present a method to train a deep neural network to learn to predict optical flow from only events from an event-based camera, without any ground truth optical flow labels. Instead, supervision is provided to the network through the grayscale images also generated by an event-based camera such as the [DAVIS camera](https://ieeexplore.ieee.org/abstract/document/6889103/). 16 | 17 | ## Citations 18 | If you use this work in an academic publication, please cite the following work: 19 | [Alex Zihao Zhu, Liangzhe Yuan, Kenneth Chaney, Kostas Daniilidis. "EV-FlowNet: Self-Supervised Optical Flow Estimation for Event-based Cameras", Proceedings of Robotics: Science and Systems, 2018. DOI: 10.15607/RSS.2018.XIV.062.](http://www.roboticsproceedings.org/rss14/p62.html) 20 | 21 | ``` 22 | @INPROCEEDINGS{Zhu-RSS-18, 23 | AUTHOR = {Alex Zhu AND Liangzhe Yuan AND Kenneth Chaney AND Kostas Daniilidis}, 24 | TITLE = {EV-FlowNet: Self-Supervised Optical Flow Estimation for Event-based Cameras}, 25 | BOOKTITLE = {Proceedings of Robotics: Science and Systems}, 26 | YEAR = {2018}, 27 | ADDRESS = {Pittsburgh, Pennsylvania}, 28 | MONTH = {June}, 29 | DOI = {10.15607/RSS.2018.XIV.062} 30 | } 31 | ``` 32 | 33 | An updated arXiv version is also available: 34 | [Zhu, Alex Zihao, et al. "EV-FlowNet: Self-Supervised Optical Flow Estimation for Event-based Cameras." arXiv preprint arXiv:1802.06898 (2018).](https://arxiv.org/abs/1802.06898) 35 | 36 | ## Data 37 | For convenience, the converted data for the ```outdoor_day``` and ```indoor_flying``` sequences can be found [__**here**__](https://drive.google.com/drive/folders/1sW5PPL8tyOPKafMKQkoRdsjL_cq6MJin?usp=sharing). Note that the ```outdoor_day1``` sequence has been shortened to the test sequence as described in the paper. The data can also be generated by running [extract_all_bags.sh](data/extract_all_bags.sh). Note that the converted data takes up **167GB** of space for all of the indoor_flying and outdoor_day sequences. 38 | 39 | Ground truth flow computed from the paper can also be downloaded [__**here**__](https://drive.google.com/drive/folders/1XS0AQTuCwUaWOmtjyJWRHkbXjj_igJLp?usp=sharing). The ground truth takes up **31.7GB** for all of the indoor_flying and outdoor_day sequences. 40 | 41 | The data used to train and test our models is from the [Multi Vehicle Stereo Event Camera dataset](https://daniilidis-group.github.io/mvsec/). 
For efficient batched reading, we convert the data from the [ROS bag](http://wiki.ros.org/rosbag) format to ```png``` for the grayscale images and ```TFRecords``` for the events, using [extract_rosbag_to_tf.py](data/extract_rosbag_to_tf.py). 42 | 43 | The default option is for the data to be downloaded into [mvsec_data](mvsec_data). 44 | 45 | ### Pre-trained Model 46 | A pre-trained model can be downloaded [__**here**__](https://drive.google.com/drive/folders/1tHu1_ajMi1xdZdyDvDe6z6gOyX5PXDeQ?usp=sharing). The default option is for the model folder (ev-flownet) to be placed in [data/log/saver](data/log/saver). 47 | 48 | ## Updates to the architecture 49 | This code has undergone a few improvements since publication in RSS. Most notably, they are: 50 | * Random 2D rotations are applied as data augmentation at training time. 51 | * Batch norm is used at training. 52 | * The original model had 2 output channels before the predict_flow layer, which caused it to be quite sensitive and difficult to train. We have now increased this to 32 output channels to resolve these issues. 53 | 54 | These improvements have greatly improved the robustness of the models, and they should now work reasonably in most environments. This decreases the AEE error for the dt=1 frames, but slightly increases them for the dt=4 frames. We will update the numbers in our arxiv submission with the new numbers shortly. 55 | 56 | ## Testing 57 | The model can be tested using [test.py](src/test.py). The basic syntax is: 58 | ```python test.py --training_instance ev-flownet --test_sequence outdoor_day1``` 59 | This will evaluate the model on the ```outdoor_day1``` sequence. 60 | 61 | The ```--test_plot``` argument plots the predicted flow, while ```--gt_path``` allows you to specify the path to the groundtruth flow npy file for evaluation. ```--test_skip_frames``` allows for testing with inputs 4 frames apart, as described in the paper. 62 | 63 | ```--save_test_output``` will cause the predicted flows and event histograms, as well as optionally the GT flows to be saved to a npz file {$SEQUENCE}\_output.npz or {$SEQUENCE}\_output_gt.npz if GT is available. The saved data can be accessed from the npz file under the keys 'output_flows', 'event_images' and 'gt_flows'. 64 | 65 | Useful evaluation functions can be found in [eval_utils.py](src/eval_utils.py). In particular, we provide a function that interpolates optical flow over multiple ground truth frames (for the 4 frames apart scenario). 66 | 67 | ## Training 68 | To train a new model from scratch, the basic syntax is: 69 | ```python train.py --training_instance model_name``` 70 | 71 | More parameters are described by running: 72 | ```python train.py -h``` 73 | To modify the training data (it is only ```outdoor_day2``` by default), you can modify [train_bags.txt](data/train_bags.txt), with a separate sequence per line. 74 | 75 | ## Authors 76 | [Alex Zihao Zhu](https://fling.seas.upenn.edu/~alexzhu/dynamic/), [Liangzhe Yuan](https://yuanliangzhe.github.io/), Kenneth Chaney and [Kostas Daniilidis](https://www.cis.upenn.edu/~kostas/). 
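
## Examples

### Reading a converted sample
Each sample written by [extract_rosbag_to_npy.py](data/extract_rosbag_to_npy.py) is an ```.npy``` object array with three entries: the event count images, the event timestamp images, and the grayscale image timestamps for ```max_aug``` consecutive frames. A minimal sketch for inspecting one sample (the path below is only an example; point it at your own ```--output_folder```):

```python
import numpy as np

# Hypothetical output location; substitute your own output folder and sample index.
sample_path = 'data/mvsec/outdoor_day2/left_event00000.npy'

# The file stores an object array of three entries, unpacked the same way as in
# src/data_loader.py.
event_count_images, event_time_images, image_times = np.load(
    sample_path, encoding='bytes', allow_pickle=True)

print(event_count_images.shape)  # (max_aug, 260, 346, 2): positive/negative event counts per frame
print(event_time_images.shape)   # (max_aug, 260, 346, 2): most recent event timestamp per pixel
print(image_times.shape)         # (max_aug + 1,): timestamps of the bracketing grayscale images
```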
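
### Running the network directly
The model itself lives in [src/EVFlowNet.py](src/EVFlowNet.py); it takes a 4-channel event image (two count channels and two timestamp channels) and returns a dict of flow predictions at four scales, from ```'flow0'``` (coarsest) to ```'flow3'``` (full resolution). A minimal forward-pass sketch, run from ```src/``` (the checkpoint path is an assumption; use whichever checkpoint you trained with [train.py](src/train.py)):

```python
import os
import torch
from config import configs
from EVFlowNet import EVFlowNet

args = configs()
model = EVFlowNet(args).cuda()

# Hypothetical checkpoint; train.py saves one state_dict per epoch as 'model<epoch>'.
checkpoint = os.path.join(args.load_path, args.training_instance, 'model91')
model.load_state_dict(torch.load(checkpoint))
model.eval()

# Dummy 4-channel event image: [pos counts, neg counts, pos timestamps, neg timestamps].
event_image = torch.rand(1, 4, args.image_height, args.image_width).cuda()

with torch.no_grad():
    flow_dict = model(event_image)

print(flow_dict['flow3'].shape)  # (1, 2, 256, 256) at the default image size
```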
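
### Reading the saved test output
The arrays written by ```--save_test_output``` can be read back directly with NumPy. A short sketch, assuming the ```outdoor_day1``` sequence was evaluated with ground truth available:

```python
import numpy as np

data = np.load('outdoor_day1_output_gt.npz')

pred_flows   = data['output_flows']  # (N, H, W, 2) predicted flow for each test frame
event_images = data['event_images']  # (N, H, W) event count images (zero where no events fired)
gt_flows     = data['gt_flows']      # (N, H, W, 2) interpolated ground truth flow

print(pred_flows.shape, event_images.shape, gt_flows.shape)
```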
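
### Evaluating against ground truth flow
[eval_utils.py](src/eval_utils.py) contains the pieces needed for a standalone evaluation: ```estimate_corresponding_gt_flow``` interpolates the ground truth flow between two image timestamps (including the 4-frames-apart case), and ```flow_error_dense``` computes the average endpoint error over pixels with events. A sketch under assumed inputs, run from ```src/``` (the frame pair and the stand-in prediction below are placeholders):

```python
import numpy as np
from eval_utils import estimate_corresponding_gt_flow, flow_error_dense

gt = np.load('indoor_flying1_gt_flow_dist.npz')
gt_timestamps = gt['timestamps']
x_flow_all = gt['x_flow_dist']
y_flow_all = gt['y_flow_dist']

# Interpolate ground truth displacement between two timestamps (placeholder frame pair).
start_time, end_time = gt_timestamps[100], gt_timestamps[104]
u_gt, v_gt = estimate_corresponding_gt_flow(x_flow_all, y_flow_all,
                                            gt_timestamps, start_time, end_time)
gt_flow = np.stack((u_gt, v_gt), axis=2)

# Stand-ins for the network prediction and its event count image.
pred_flow = np.zeros_like(gt_flow)
event_count_image = np.ones(gt_flow.shape[:2])

AEE, percent_AEE, n_points = flow_error_dense(gt_flow, pred_flow, event_count_image,
                                              is_car=False)
print(AEE, percent_AEE, n_points)
```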
77 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Please change line 214 and line 224 to your path before you run extract_rosbag_to_npy.py 2 | 3 | use following code to run the script 4 | python extract_rosbag_to_npy.py --bag indoor_flying1.bag --prefix indoor_flying1 --start_time 4.0 --max_aug 6 --n_skip 1 --output_folder [your location of the output folder] 5 | 6 | and the start time are: 7 | indoor_flying1---4.0 8 | indoor_flying2---9.0 9 | indoor_flying3---7.0 10 | indoor_flying4---6.0 11 | 12 | outdoor_day1---3.0 13 | outdoor_day2---45.0 14 | 15 | outdoor_night1---0.0 16 | outdoor_night2---0.0 17 | outdoor_night3---0.0 -------------------------------------------------------------------------------- /data/extract_rosbag_to_npy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import math 4 | import os 5 | import argparse 6 | 7 | import rospy 8 | from rosbag import Bag 9 | from cv_bridge import CvBridge 10 | 11 | import cv2 12 | import numpy as np 13 | 14 | def _int64_feature(value): 15 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 16 | 17 | def _bytes_feature(value): 18 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 19 | 20 | def _save_events(args, 21 | events, 22 | image_times, 23 | event_count_images, 24 | event_time_images, 25 | event_image_times, 26 | rows, 27 | cols, 28 | max_aug, 29 | n_skip, 30 | event_image_iter, 31 | prefix, 32 | cam, 33 | t_start_ros): 34 | event_iter = 0 35 | cutoff_event_iter = 0 36 | image_iter = 0 37 | curr_image_time = (image_times[image_iter] - t_start_ros).to_sec() 38 | 39 | event_count_image = np.zeros((rows, cols, 2), dtype=np.uint16) 40 | event_time_image = np.zeros((rows, cols, 2), dtype=np.float32) 41 | 42 | while image_iter < len(image_times) and \ 43 | events[-1][2] > curr_image_time: 44 | x = events[event_iter][0] 45 | y = events[event_iter][1] 46 | t = events[event_iter][2] 47 | 48 | if t > curr_image_time: 49 | event_count_images.append(event_count_image) 50 | event_count_image = np.zeros((rows, cols, 2), dtype=np.uint16) 51 | event_time_images.append(event_time_image) 52 | event_time_image = np.zeros((rows, cols, 2), dtype=np.float32) 53 | cutoff_event_iter = event_iter 54 | event_image_times.append(image_times[image_iter].to_sec()) 55 | image_iter += n_skip 56 | if (image_iter < len(image_times)): 57 | curr_image_time = (image_times[image_iter] - t_start_ros).to_sec() 58 | 59 | if events[event_iter][3] > 0: 60 | event_count_image[y, x, 0] += 1 61 | event_time_image[y, x, 0] = t 62 | else: 63 | event_count_image[y, x, 1] += 1 64 | event_time_image[y, x, 1] = t 65 | 66 | event_iter += 1 67 | 68 | del image_times[:image_iter] 69 | del events[:cutoff_event_iter] 70 | 71 | if len(event_count_images) >= max_aug: 72 | n_to_save = len(event_count_images) - max_aug + 1 73 | for i in range(n_to_save): 74 | image_times_out = np.array(event_image_times[i:i+max_aug+1]) 75 | image_times_out = image_times_out.astype(np.float64) 76 | event_time_images_np = np.array(event_time_images[i:i+max_aug], dtype=np.float32) 77 | event_time_images_np -= image_times_out[0] - t_start_ros.to_sec() 78 | event_time_images_np = np.clip(event_time_images_np, a_min=0, a_max=None) 79 | image_shape = np.array(event_time_images_np.shape, dtype=np.uint16) 80 | 81 | now = 
np.array([np.array(event_count_images[i:i+max_aug]),event_time_images_np,image_times_out]) 82 | np.save(args.output_folder+cam+'_event'+str(event_image_iter).rjust(5,'0'),now) 83 | event_image_iter += n_skip 84 | 85 | del event_count_images[:n_to_save] 86 | del event_time_images[:n_to_save] 87 | del event_image_times[:n_to_save] 88 | return event_image_iter 89 | 90 | def filter_events(events, ts): 91 | r'''Removes all events with timestamp lower than the specified one 92 | 93 | Args: 94 | events (list): the list of events in form of (x, y, t, p) 95 | ts (float): the timestamp to split events 96 | 97 | Return: 98 | (list): a list of events with timestamp above the threshold 99 | ''' 100 | tss = np.array([e[2] for e in events]) 101 | idx_array = np.argsort(tss) # I hope it's not needed 102 | i = np.searchsorted(tss[idx_array], ts) 103 | return [events[k] for k in idx_array[i:]] 104 | 105 | def main(): 106 | parser = argparse.ArgumentParser( 107 | description=("Extracts grayscale and event images from a ROS bag and " 108 | "saves them as TFRecords for training in TensorFlow.")) 109 | parser.add_argument("--bag", dest="bag", 110 | help="Path to ROS bag.", 111 | default='/media/cyrilsterling/D/EV-FlowNet-pth/data/outdoor_day2.bag') 112 | parser.add_argument("--prefix", dest="prefix", 113 | help="Output file prefix.", 114 | default='outdoor_day2') 115 | parser.add_argument("--output_folder", dest="output_folder", 116 | help="Output folder.", 117 | default='/media/cyrilsterling/D/EV-FlowNet-pth/data/mvsec/outdoor_day2') 118 | parser.add_argument("--max_aug", dest="max_aug", 119 | help="Maximum number of images to combine for augmentation.", 120 | type=int, 121 | default = 6) 122 | parser.add_argument("--n_skip", dest="n_skip", 123 | help="Maximum number of images to combine for augmentation.", 124 | type=int, 125 | default = 1) 126 | parser.add_argument("--start_time", dest="start_time", 127 | help="Time to start in the bag.", 128 | type=float, 129 | default = 45.0) 130 | parser.add_argument("--end_time", dest="end_time", 131 | help="Time to end in the bag.", 132 | type=float, 133 | default = -1.0) 134 | 135 | args = parser.parse_args() 136 | 137 | bridge = CvBridge() 138 | 139 | n_msgs = 0 140 | left_event_image_iter = 0 141 | right_event_image_iter = 0 142 | left_image_iter = 0 143 | right_image_iter = 0 144 | first_left_image_time = -1 145 | first_right_image_time = -1 146 | 147 | left_events = [] 148 | right_events = [] 149 | left_images = [] 150 | right_images = [] 151 | left_image_times = [] 152 | right_image_times = [] 153 | left_event_count_images = [] 154 | left_event_time_images = [] 155 | left_event_image_times = [] 156 | 157 | right_event_count_images = [] 158 | right_event_time_images = [] 159 | right_event_image_times = [] 160 | 161 | cols = 346 162 | rows = 260 163 | print("Processing bag") 164 | bag = Bag(args.bag) 165 | # Get actual time for the start of the bag. 166 | t_start = bag.get_start_time() 167 | t_start_ros = rospy.Time(t_start) 168 | # Set the time at which the bag reading should end. 
169 | if args.end_time == -1.0: 170 | t_end = bag.get_end_time() 171 | else: 172 | t_end = t_start + args.end_time 173 | 174 | eps = 0.1 175 | ifis = 0 176 | for topic, msg, t in bag.read_messages( 177 | topics=['/davis/left/image_raw', 178 | '/davis/right/image_raw', 179 | '/davis/left/events', 180 | '/davis/right/events'], 181 | start_time=rospy.Time(max(args.start_time, eps) - eps + t_start), 182 | end_time=rospy.Time(t_end)): 183 | # Check to make sure we're working with stereo messages. 184 | if not ('left' in topic or 'right' in topic): 185 | print('ERROR: topic {} does not contain left or right, is this stereo?' 186 | 'If not, you will need to modify the topic names in the code.'. 187 | format(topic)) 188 | return 189 | 190 | # Counter for status updates. 191 | n_msgs += 1 192 | if n_msgs % 500 == 0: 193 | print("Processed {} msgs, {} images, {} rgb images, time is {}.".format(n_msgs, 194 | left_event_image_iter, 195 | left_image_iter, 196 | t.to_sec() - t_start)) 197 | 198 | isLeft = 'left' in topic 199 | 200 | if 'image' in topic: 201 | width = msg.width 202 | height = msg.height 203 | if width != cols or height != rows: 204 | print("Image dimensions are not what we expected: set: ({} {}) vs got:({} {})" 205 | .format(cols, rows, width, height)) 206 | return 207 | time = msg.header.stamp 208 | if time.to_sec() - t_start < args.start_time: 209 | continue 210 | image = np.asarray(bridge.imgmsg_to_cv2(msg, msg.encoding)) 211 | image = np.reshape(image, (height, width)) 212 | 213 | if isLeft: 214 | cv2.imwrite("/media/cyrilsterling/D/EV-FlowNet-pth/data/mvsec2/indoor_flying1/left_image{:05d}.png".format(left_image_iter),image) 215 | if left_image_iter > 0: 216 | left_image_times.append(time) 217 | else: 218 | first_left_image_time = time 219 | left_event_image_times.append(time.to_sec()) 220 | # filter events we added previously 221 | left_events = filter_events(left_events, left_event_image_times[-1] - t_start) 222 | left_image_iter += 1 223 | else: 224 | cv2.imwrite("/media/cyrilsterling/D/EV-FlowNet-pth/data/mvsec/outdoor_day2/right_image{:05d}.png".format(right_image_iter),image) 225 | if right_image_iter > 0: 226 | right_image_times.append(time) 227 | else: 228 | first_right_image_time = time 229 | right_event_image_times.append(time.to_sec()) 230 | # filter events we added previously 231 | right_events = filter_events(right_events, left_event_image_times[-1] - t_start) 232 | 233 | right_image_iter += 1 234 | 235 | elif 'events' in topic and msg.events: 236 | # Add events to list. 
237 | for event in msg.events: 238 | ts = event.ts 239 | event = [event.x, 240 | event.y, 241 | (ts - t_start_ros).to_sec(), 242 | (float(event.polarity) - 0.5) * 2] 243 | if isLeft: 244 | # add event if it was after the first image or we haven't seen the first image 245 | if first_left_image_time == -1 or ts > first_left_image_time: 246 | left_events.append(event) 247 | elif first_right_image_time == -1 or ts > first_right_image_time: 248 | right_events.append(event) 249 | if isLeft: 250 | if len(left_image_times) >= args.max_aug and\ 251 | left_events[-1][2] > (left_image_times[args.max_aug-1]-t_start_ros).to_sec(): 252 | left_event_image_iter = _save_events(args, 253 | left_events, 254 | left_image_times, 255 | left_event_count_images, 256 | left_event_time_images, 257 | left_event_image_times, 258 | rows, 259 | cols, 260 | args.max_aug, 261 | args.n_skip, 262 | left_event_image_iter, 263 | args.prefix, 264 | 'left', 265 | t_start_ros) 266 | else: 267 | if len(right_image_times) >= args.max_aug and\ 268 | right_events[-1][2] > (right_image_times[args.max_aug-1]-t_start_ros).to_sec(): 269 | right_event_image_iter = _save_events(args, 270 | right_events, 271 | right_image_times, 272 | right_event_count_images, 273 | right_event_time_images, 274 | right_event_image_times, 275 | rows, 276 | cols, 277 | args.max_aug, 278 | args.n_skip, 279 | right_event_image_iter, 280 | args.prefix, 281 | 'right', 282 | t_start_ros) 283 | 284 | 285 | image_counter_file = open(os.path.join(args.output_folder, args.prefix, "n_images.txt") , 'w') 286 | image_counter_file.write("{} {}".format(left_event_image_iter, right_event_image_iter)) 287 | image_counter_file.close() 288 | 289 | if __name__ == "__main__": 290 | main() 291 | -------------------------------------------------------------------------------- /motorcycle_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CyrilSterling/EVFlowNet-pytorch/fdd800d0fc546cf63c516c59b478be04ed3b5d4b/motorcycle_flow.png -------------------------------------------------------------------------------- /src/EVFlowNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from basic_layers import * 4 | 5 | _BASE_CHANNELS = 64 6 | 7 | class EVFlowNet(nn.Module): 8 | def __init__(self, args): 9 | super(EVFlowNet,self).__init__() 10 | self._args = args 11 | 12 | self.encoder1 = general_conv2d(in_channels = 4, out_channels=_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) 13 | self.encoder2 = general_conv2d(in_channels = _BASE_CHANNELS, out_channels=2*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) 14 | self.encoder3 = general_conv2d(in_channels = 2*_BASE_CHANNELS, out_channels=4*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) 15 | self.encoder4 = general_conv2d(in_channels = 4*_BASE_CHANNELS, out_channels=8*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) 16 | 17 | self.resnet_block = nn.Sequential(*[build_resnet_block(8*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) for i in range(2)]) 18 | 19 | self.decoder1 = upsample_conv2d_and_predict_flow(in_channels=16*_BASE_CHANNELS, 20 | out_channels=4*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) 21 | 22 | self.decoder2 = upsample_conv2d_and_predict_flow(in_channels=8*_BASE_CHANNELS+2, 23 | out_channels=2*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) 24 | 25 | self.decoder3 = 
upsample_conv2d_and_predict_flow(in_channels=4*_BASE_CHANNELS+2, 26 | out_channels=_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) 27 | 28 | self.decoder4 = upsample_conv2d_and_predict_flow(in_channels=2*_BASE_CHANNELS+2, 29 | out_channels=int(_BASE_CHANNELS/2), do_batch_norm=not self._args.no_batch_norm) 30 | 31 | def forward(self,inputs): 32 | # encoder 33 | skip_connections = {} 34 | inputs = self.encoder1(inputs) 35 | skip_connections['skip0'] = inputs.clone() 36 | inputs = self.encoder2(inputs) 37 | skip_connections['skip1'] = inputs.clone() 38 | inputs = self.encoder3(inputs) 39 | skip_connections['skip2'] = inputs.clone() 40 | inputs = self.encoder4(inputs) 41 | skip_connections['skip3'] = inputs.clone() 42 | 43 | # transition 44 | inputs = self.resnet_block(inputs) 45 | 46 | # decoder 47 | flow_dict = {} 48 | inputs = torch.cat([inputs, skip_connections['skip3']], dim=1) 49 | inputs, flow = self.decoder1(inputs) 50 | flow_dict['flow0'] = flow.clone() 51 | 52 | inputs = torch.cat([inputs, skip_connections['skip2']], dim=1) 53 | inputs, flow = self.decoder2(inputs) 54 | flow_dict['flow1'] = flow.clone() 55 | 56 | inputs = torch.cat([inputs, skip_connections['skip1']], dim=1) 57 | inputs, flow = self.decoder3(inputs) 58 | flow_dict['flow2'] = flow.clone() 59 | 60 | inputs = torch.cat([inputs, skip_connections['skip0']], dim=1) 61 | inputs, flow = self.decoder4(inputs) 62 | flow_dict['flow3'] = flow.clone() 63 | 64 | return flow_dict 65 | 66 | 67 | if __name__ == "__main__": 68 | from config import configs 69 | import time 70 | from data_loader import EventData 71 | ''' 72 | args = configs() 73 | model = EVFlowNet(args).cuda() 74 | input_ = torch.rand(8,4,256,256).cuda() 75 | a = time.time() 76 | output = model(input_) 77 | b = time.time() 78 | print(b-a) 79 | print(output['flow0'].shape, output['flow1'].shape, output['flow2'].shape, output['flow3'].shape) 80 | #print(model.state_dict().keys()) 81 | #print(model) 82 | ''' 83 | import numpy as np 84 | args = configs() 85 | model = EVFlowNet(args).cuda() 86 | EventDataset = EventData(args.data_path, 'train') 87 | EventDataLoader = torch.utils.data.DataLoader(dataset=EventDataset, batch_size=args.batch_size, shuffle=True) 88 | #model = nn.DataParallel(model) 89 | #model.load_state_dict(torch.load(args.load_path+'/model18')) 90 | for input_, _, _, _ in EventDataLoader: 91 | input_ = input_.cuda() 92 | a = time.time() 93 | (model(input_)) 94 | b = time.time() 95 | print(b-a) -------------------------------------------------------------------------------- /src/basic_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class build_resnet_block(nn.Module): 5 | """ 6 | a resnet block which includes two general_conv2d 7 | """ 8 | def __init__(self, channels, layers=2, do_batch_norm=False): 9 | super(build_resnet_block,self).__init__() 10 | self._channels = channels 11 | self._layers = layers 12 | 13 | self.res_block = nn.Sequential(*[general_conv2d(in_channels=self._channels, 14 | out_channels=self._channels, 15 | strides=1, 16 | do_batch_norm=do_batch_norm) for i in range(self._layers)]) 17 | 18 | def forward(self,input_res): 19 | inputs = input_res.clone() 20 | input_res = self.res_block(input_res) 21 | return input_res + inputs 22 | 23 | class upsample_conv2d_and_predict_flow(nn.Module): 24 | """ 25 | an upsample convolution layer which includes a nearest interpolate and a general_conv2d 26 | """ 27 | def __init__(self, in_channels, 
out_channels, ksize=3, do_batch_norm=False): 28 | super(upsample_conv2d_and_predict_flow, self).__init__() 29 | self._in_channels = in_channels 30 | self._out_channels = out_channels 31 | self._ksize = ksize 32 | self._do_batch_norm = do_batch_norm 33 | 34 | self.general_conv2d = general_conv2d(in_channels=self._in_channels, 35 | out_channels=self._out_channels, 36 | ksize=self._ksize, 37 | strides=1, 38 | do_batch_norm=self._do_batch_norm, 39 | padding=0) 40 | 41 | self.pad = nn.ReflectionPad2d(padding=(int((self._ksize-1)/2), int((self._ksize-1)/2), 42 | int((self._ksize-1)/2), int((self._ksize-1)/2)))#对称padding 43 | 44 | self.predict_flow = general_conv2d(in_channels=self._out_channels, 45 | out_channels=2, 46 | ksize=1, 47 | strides=1, 48 | padding=0, 49 | activation='tanh') 50 | 51 | def forward(self, conv): 52 | shape = conv.shape 53 | conv = nn.functional.interpolate(conv,size=[shape[2]*2,shape[3]*2],mode='nearest')#最近邻插值上采样 54 | conv = self.pad(conv) 55 | conv = self.general_conv2d(conv) 56 | 57 | flow = self.predict_flow(conv) * 256. 58 | 59 | return torch.cat([conv,flow.clone()], dim=1), flow 60 | 61 | def general_conv2d(in_channels,out_channels, ksize=3, strides=2, padding=1, do_batch_norm=False, activation='relu'): 62 | """ 63 | a general convolution layer which includes a conv2d, a relu and a batch_normalize 64 | """ 65 | if activation == 'relu': 66 | if do_batch_norm: 67 | conv2d = nn.Sequential( 68 | nn.Conv2d(in_channels = in_channels,out_channels = out_channels,kernel_size = ksize, 69 | stride=strides,padding=padding), 70 | nn.ReLU(inplace=True), 71 | nn.BatchNorm2d(out_channels,eps=1e-5,momentum=0.99) 72 | ) 73 | else: 74 | conv2d = nn.Sequential( 75 | nn.Conv2d(in_channels = in_channels,out_channels = out_channels,kernel_size = ksize, 76 | stride=strides,padding=padding), 77 | nn.ReLU(inplace=True) 78 | ) 79 | elif activation == 'tanh': 80 | if do_batch_norm: 81 | conv2d = nn.Sequential( 82 | nn.Conv2d(in_channels = in_channels,out_channels = out_channels,kernel_size = ksize, 83 | stride=strides,padding=padding), 84 | nn.Tanh(), 85 | nn.BatchNorm2d(out_channels,eps=1e-5,momentum=0.99) 86 | ) 87 | else: 88 | conv2d = nn.Sequential( 89 | nn.Conv2d(in_channels = in_channels,out_channels = out_channels,kernel_size = ksize, 90 | stride=strides,padding=padding), 91 | nn.Tanh() 92 | ) 93 | return conv2d 94 | 95 | if __name__ == "__main__": 96 | a = upsample_conv2d_and_predict_flow(1,4) 97 | b = build_resnet_block(2) 98 | c = torch.Tensor([[1,2,3,4,5],[4,5,6,2,3],[7,8,9,4,7],[2,3,5,4,6],[4,6,7,4,5]]).reshape(1,1,5,5) 99 | _, out = a(c) 100 | out = b(out) 101 | print(out.shape) 102 | print(a) 103 | print(b) 104 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def configs(): 4 | parser = argparse.ArgumentParser() 5 | 6 | parser.add_argument('--data_path', 7 | type=str, 8 | help="Path to data directory.", 9 | default='D://EV-FlowNet-pth/data/mvsec/') 10 | parser.add_argument('--load_path', 11 | type=str, 12 | help="Path to saved model.", 13 | default='D://EV-FlowNet-pth/data/log/saver/') 14 | parser.add_argument('--training_instance', 15 | type=str, 16 | help="Specific saved model to load. 
A new one will be generated if empty.", 17 | default='') 18 | parser.add_argument('--batch_size', 19 | type=int, 20 | help="Training batch size.", 21 | default=8) 22 | parser.add_argument('--initial_learning_rate', 23 | type=float, 24 | help="Initial learning rate.", 25 | default=3e-4) 26 | parser.add_argument('--learning_rate_decay', 27 | type=float, 28 | help='Rate at which the learning rate is decayed.', 29 | default=0.9) 30 | parser.add_argument('--smoothness_weight', 31 | type=float, 32 | help='Weight for the smoothness term in the loss function.', 33 | default=0.5) 34 | parser.add_argument('--image_height', 35 | type=int, 36 | help="Image height.", 37 | default=256) 38 | parser.add_argument('--image_width', 39 | type=int, 40 | help="Image width.", 41 | default=256) 42 | parser.add_argument('--no_batch_norm', 43 | action='store_true', 44 | help='If true, batch norm will not be performed at each layer', 45 | default=False) 46 | 47 | # Args for testing only. 48 | 49 | parser.add_argument('--test_skip_frames', 50 | action='store_true', 51 | help='If true, input images will be 4 frames apart.') 52 | 53 | parser.add_argument('--test_sequence', 54 | type=str, 55 | help="Name of the test sequence.", 56 | default='indoor_flying1') 57 | parser.add_argument('--gt_path', 58 | type=str, 59 | help='Path to optical flow ground truth npz file.', 60 | default='D:\mvsec\indoor_flying1_gt_flow_dist.npz') 61 | parser.add_argument('--test_plot', 62 | action='store_true', 63 | help='If true, the flow predictions will be visualized during testing.', 64 | default=True) 65 | parser.add_argument('--save_test_output', 66 | action='store_true', 67 | help='If true, output flow will be saved to a npz file.', 68 | default='') 69 | 70 | 71 | args = parser.parse_args() 72 | return args -------------------------------------------------------------------------------- /src/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torchvision import transforms 4 | import torchvision.transforms.functional as F 5 | from config import configs 6 | from PIL import Image 7 | import os 8 | import numpy as np 9 | import cv2 10 | 11 | _MAX_SKIP_FRAMES = 6 12 | _TEST_SKIP_FRAMES = 4 13 | _N_SKIP = 1 14 | 15 | class EventData(Dataset): 16 | """ 17 | args: 18 | data_folder_path:the path of data 19 | split:'train' or 'test' 20 | """ 21 | def __init__(self, data_folder_path, split, count_only=False, time_only=False, skip_frames=False): 22 | self._data_folder_path = data_folder_path 23 | self._split = split 24 | self._count_only = count_only 25 | self._time_only = time_only 26 | self._skip_frames = skip_frames 27 | self.args = configs() 28 | self.event_data_paths, self.n_ima = self.read_file_paths(self._data_folder_path, self._split) 29 | 30 | def __getitem__(self, index): 31 | # 获得image_times event_count_images event_time_images image_iter prefix cam 32 | image_iter = 0 33 | for i in self.n_ima: 34 | if index < i: 35 | break 36 | image_iter += 1 37 | image_iter -= 1 38 | if image_iter % 2 == 0: 39 | cam = 'left' 40 | else: 41 | cam = 'right' 42 | prefix = self.event_data_paths[image_iter] 43 | image_iter = index - self.n_ima[image_iter] 44 | 45 | event_count_images, event_time_images, image_times = np.load(prefix + "/" + cam + "_event" +\ 46 | str(image_iter).rjust(5,'0') + ".npy", encoding='bytes', allow_pickle=True) 47 | event_count_images = torch.from_numpy(event_count_images.astype(np.int16)) 48 | event_time_images = 
torch.from_numpy(event_time_images.astype(np.float32)) 49 | image_times = torch.from_numpy(image_times.astype(np.float64)) 50 | 51 | if self._split is 'test': 52 | if self._skip_frames: 53 | n_frames = _TEST_SKIP_FRAMES 54 | else: 55 | n_frames = 1 56 | else: 57 | n_frames = np.random.randint(low=1, high=_MAX_SKIP_FRAMES+1) * _N_SKIP 58 | timestamps = [image_times[0], image_times[n_frames]] 59 | event_count_image, event_time_image = self._read_events(event_count_images, event_time_images, n_frames) 60 | 61 | prev_img_path = prefix + "/" + cam + "_image" + str(image_iter).rjust(5,'0') + ".png" 62 | next_img_path = prefix + "/" + cam + "_image" + str(image_iter+n_frames).rjust(5,'0') + ".png" 63 | 64 | prev_image = Image.open(prev_img_path) 65 | next_image = Image.open(next_img_path) 66 | 67 | #transforms 68 | rand_flip = np.random.randint(low=0, high=2) 69 | rand_rotate = np.random.randint(low=-30, high=30) 70 | x = np.random.randint(low=1, high=(event_count_image.shape[1]-self.args.image_height)) 71 | y = np.random.randint(low=1, high=(event_count_image.shape[2]-self.args.image_width)) 72 | if self._split == 'train': 73 | if self._count_only: 74 | event_count_image = F.to_pil_image(event_count_image / 255.) 75 | # random_flip 76 | if rand_flip == 0: 77 | event_count_image = event_count_image.transpose(Image.FLIP_LEFT_RIGHT) 78 | # random_rotate 79 | event_image = event_count_image.rotate(rand_rotate) 80 | # random_crop 81 | event_image = F.to_tensor(event_image) * 255. 82 | event_image = event_image[:,x:x+self.args.image_height,y:y+self.args.image_width] 83 | elif self._time_only: 84 | event_time_image = F.to_pil_image(event_time_image) 85 | # random_flip 86 | if rand_flip == 0: 87 | event_time_image = event_time_image.transpose(Image.FLIP_LEFT_RIGHT) 88 | # random_rotate 89 | event_image = event_time_image.rotate(rand_rotate) 90 | # random_crop 91 | event_image = F.to_tensor(event_image) 92 | event_image = event_image[:,x:x+self.args.image_height,y:y+self.args.image_width] 93 | else: 94 | event_count_image = F.to_pil_image(event_count_image / 255.) 95 | event_time_image = F.to_pil_image(event_time_image) 96 | # random_flip 97 | if rand_flip == 0: 98 | event_count_image = event_count_image.transpose(Image.FLIP_LEFT_RIGHT) 99 | event_time_image = event_time_image.transpose(Image.FLIP_LEFT_RIGHT) 100 | # random_rotate 101 | event_count_image = event_count_image.rotate(rand_rotate) 102 | event_time_image = event_time_image.rotate(rand_rotate) 103 | # random_crop 104 | event_count_image = F.to_tensor(event_count_image) 105 | event_time_image = F.to_tensor(event_time_image) * 255. 106 | event_image = torch.cat((event_count_image,event_time_image), dim=0) 107 | event_image = event_image[...,x:x+self.args.image_height,y:y+self.args.image_width] 108 | 109 | if rand_flip == 0: 110 | prev_image = prev_image.transpose(Image.FLIP_LEFT_RIGHT) 111 | next_image = next_image.transpose(Image.FLIP_LEFT_RIGHT) 112 | prev_image = prev_image.rotate(rand_rotate) 113 | next_image = next_image.rotate(rand_rotate) 114 | prev_image = F.to_tensor(prev_image) 115 | next_image = F.to_tensor(next_image) 116 | prev_image = prev_image[...,x:x+self.args.image_height,y:y+self.args.image_width] 117 | next_image = next_image[...,x:x+self.args.image_height,y:y+self.args.image_width] 118 | 119 | else: 120 | if self._count_only: 121 | event_image = F.to_tensor(F.center_crop(F.to_pil_image(event_count_image / 255.), 122 | (self.args.image_height, self.args.image_width))) 123 | event_image = event_image * 255. 
124 | elif self._time_only: 125 | event_image = F.to_tensor(F.center_crop(F.to_pil_image(event_time_image), 126 | (self.args.image_height, self.args.image_width))) 127 | else: 128 | event_image = torch.cat((event_count_image / 255.,event_time_image), dim=0) 129 | event_image = F.to_tensor(F.center_crop(F.to_pil_image(event_image), 130 | (self.args.image_height, self.args.image_width))) 131 | event_image[:2,...] = event_image[:2,...] * 255. 132 | prev_image = F.to_tensor(F.center_crop(prev_image, (self.args.image_height, self.args.image_width))) 133 | next_image = F.to_tensor(F.center_crop(next_image, (self.args.image_height, self.args.image_width))) 134 | 135 | return event_image, prev_image, next_image, timestamps 136 | 137 | def __len__(self): 138 | return self.n_ima[-1] 139 | 140 | def _read_events(self, 141 | event_count_images, 142 | event_time_images, 143 | n_frames): 144 | #event_count_images = event_count_images.reshape(shape).type(torch.float32) 145 | event_count_image = event_count_images[:n_frames, :, :, :] 146 | event_count_image = torch.sum(event_count_image, dim=0).type(torch.float32) 147 | p = torch.max(event_count_image) 148 | event_count_image = event_count_image.permute(2,0,1) 149 | 150 | #event_time_images = event_time_images.reshape(shape).type(torch.float32) 151 | event_time_image = event_time_images[:n_frames, :, :, :] 152 | event_time_image = torch.max(event_time_image, dim=0)[0] 153 | 154 | event_time_image /= torch.max(event_time_image) 155 | event_time_image = event_time_image.permute(2,0,1) 156 | 157 | ''' 158 | if self._count_only: 159 | event_image = event_count_image 160 | elif self._time_only: 161 | event_image = event_time_image 162 | else: 163 | event_image = torch.cat([event_count_image, event_time_image], dim=2) 164 | 165 | event_image = event_image.permute(2,0,1).type(torch.float32) 166 | ''' 167 | 168 | return event_count_image, event_time_image 169 | 170 | def read_file_paths(self, 171 | data_folder_path, 172 | split, 173 | sequence=None): 174 | """ 175 | return: event_data_paths,paths of event data (left and right in one folder is two) 176 | n_ima: the sum number of event pictures in every path and the paths before 177 | """ 178 | event_data_paths = [] 179 | n_ima = 0 180 | if sequence is None: 181 | bag_list_file = open(os.path.join(data_folder_path, "{}_bags.txt".format(split)), 'r') 182 | lines = bag_list_file.read().splitlines() 183 | bag_list_file.close() 184 | else: 185 | if isinstance(sequence, (list, )): 186 | lines = sequence 187 | else: 188 | lines = [sequence] 189 | 190 | n_ima = [0] 191 | for line in lines: 192 | bag_name = line 193 | 194 | event_data_paths.append(os.path.join(data_folder_path,bag_name)) 195 | num_ima_file = open(os.path.join(data_folder_path, bag_name, 'n_images.txt'), 'r') 196 | num_imas = num_ima_file.read() 197 | num_ima_file.close() 198 | num_imas_split = num_imas.split(' ') 199 | n_left_ima = int(num_imas_split[0]) - _MAX_SKIP_FRAMES 200 | n_ima.append(n_left_ima + n_ima[-1]) 201 | 202 | n_right_ima = int(num_imas_split[1]) - _MAX_SKIP_FRAMES 203 | if n_right_ima > 0 and not split is 'test': 204 | n_ima.append(n_right_ima + n_ima[-1]) 205 | else: 206 | n_ima.append(n_ima[-1]) 207 | event_data_paths.append(os.path.join(data_folder_path,bag_name)) 208 | 209 | return event_data_paths, n_ima 210 | 211 | if __name__ == "__main__": 212 | data = EventData('/media/cyrilsterling/D/EV-FlowNet-pth/data/mvsec/', 'train') 213 | EventDataLoader = torch.utils.data.DataLoader(dataset=data, batch_size=1,shuffle=True) 214 | it = 0 215 
| for i in EventDataLoader: 216 | a = i[0][0].numpy() 217 | b = i[1][0].numpy() 218 | c = i[2][0].numpy() 219 | cv2.namedWindow('a') 220 | cv2.namedWindow('b') 221 | cv2.namedWindow('c') 222 | a = a[2,...]+a[3,...] 223 | print(np.max(a)) 224 | a = (a-np.min(a))/(np.max(a)-np.min(a)) 225 | b = np.transpose(b,(1,2,0)) 226 | c = np.transpose(c,(1,2,0)) 227 | cv2.imshow('a',a) 228 | cv2.imshow('b',b) 229 | cv2.imshow('c',c) 230 | cv2.waitKey(1) -------------------------------------------------------------------------------- /src/eval_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import numpy as np 3 | import math 4 | import cv2 5 | 6 | """ 7 | Calculates per pixel flow error between flow_pred and flow_gt. 8 | event_img is used to mask out any pixels without events (are 0). 9 | If is_car is True, only the top 190 rows of the images will be evaluated to remove the hood of 10 | the car which does not appear in the GT. 11 | """ 12 | def flow_error_dense(flow_gt, flow_pred, event_img, is_car=False): 13 | max_row = flow_gt.shape[1] 14 | 15 | if is_car: 16 | max_row = 190 17 | 18 | event_img_cropped = np.squeeze(event_img)[:max_row, :] 19 | flow_gt_cropped = flow_gt[:max_row, :, :] 20 | 21 | flow_pred_cropped = flow_pred[:max_row, :, :] 22 | 23 | event_mask = event_img_cropped > 0 24 | 25 | # Only compute error over points that are valid in the GT (not inf or 0). 26 | flow_mask = np.logical_and( 27 | np.logical_and(~np.isinf(flow_gt_cropped[:, :, 0]), ~np.isinf(flow_gt_cropped[:, :, 1])), 28 | np.linalg.norm(flow_gt_cropped, axis=2) > 0) 29 | total_mask = np.squeeze(np.logical_and(event_mask, flow_mask)) 30 | 31 | gt_masked = flow_gt_cropped[total_mask, :] 32 | pred_masked = flow_pred_cropped[total_mask, :] 33 | 34 | # Average endpoint error. 35 | EE = np.linalg.norm(gt_masked - pred_masked, axis=-1) 36 | n_points = EE.shape[0] 37 | AEE = np.mean(EE) 38 | 39 | # Percentage of points with EE < 3 pixels. 40 | thresh = 3. 41 | percent_AEE = float((EE < thresh).sum()) / float(EE.shape[0] + 1e-5) 42 | 43 | return AEE, percent_AEE, n_points 44 | 45 | """ 46 | Propagates x_indices and y_indices by their flow, as defined in x_flow, y_flow. 47 | x_mask and y_mask are zeroed out at each pixel where the indices leave the image. 48 | The optional scale_factor will scale the final displacement. 49 | """ 50 | def prop_flow(x_flow, y_flow, x_indices, y_indices, x_mask, y_mask, scale_factor=1.0): 51 | flow_x_interp = cv2.remap(x_flow, 52 | x_indices, 53 | y_indices, 54 | cv2.INTER_NEAREST) 55 | 56 | flow_y_interp = cv2.remap(y_flow, 57 | x_indices, 58 | y_indices, 59 | cv2.INTER_NEAREST) 60 | 61 | x_mask[flow_x_interp == 0] = False 62 | y_mask[flow_y_interp == 0] = False 63 | 64 | x_indices += flow_x_interp * scale_factor 65 | y_indices += flow_y_interp * scale_factor 66 | 67 | return 68 | 69 | """ 70 | The ground truth flow maps are not time synchronized with the grayscale images. Therefore, we 71 | need to propagate the ground truth flow over the time between two images. 72 | This function assumes that the ground truth flow is in terms of pixel displacement, not velocity. 73 | 74 | Pseudo code for this process is as follows: 75 | 76 | x_orig = range(cols) 77 | y_orig = range(rows) 78 | x_prop = x_orig 79 | y_prop = y_orig 80 | Find all GT flows that fit in [image_timestamp, image_timestamp+image_dt]. 
81 | for all of these flows: 82 | x_prop = x_prop + gt_flow_x(x_prop, y_prop) 83 | y_prop = y_prop + gt_flow_y(x_prop, y_prop) 84 | 85 | The final flow, then, is x_prop - x-orig, y_prop - y_orig. 86 | Note that this is flow in terms of pixel displacement, with units of pixels, not pixel velocity. 87 | 88 | Inputs: 89 | x_flow_in, y_flow_in - list of numpy arrays, each array corresponds to per pixel flow at 90 | each timestamp. 91 | gt_timestamps - timestamp for each flow array. 92 | start_time, end_time - gt flow will be estimated between start_time and end time. 93 | """ 94 | def estimate_corresponding_gt_flow(x_flow_in, 95 | y_flow_in, 96 | gt_timestamps, 97 | start_time, 98 | end_time): 99 | # Each gt flow at timestamp gt_timestamps[gt_iter] represents the displacement between 100 | # gt_iter and gt_iter+1. 101 | gt_iter = np.searchsorted(gt_timestamps, start_time, side='right') - 1 102 | gt_dt = gt_timestamps[gt_iter+1] - gt_timestamps[gt_iter] 103 | x_flow = np.squeeze(x_flow_in[gt_iter, ...]) 104 | y_flow = np.squeeze(y_flow_in[gt_iter, ...]) 105 | 106 | dt = end_time - start_time 107 | 108 | # No need to propagate if the desired dt is shorter than the time between gt timestamps. 109 | if gt_dt > dt: 110 | return x_flow * dt / gt_dt, y_flow * dt / gt_dt 111 | 112 | x_indices, y_indices = np.meshgrid(np.arange(x_flow.shape[1]), 113 | np.arange(x_flow.shape[0])) 114 | x_indices = x_indices.astype(np.float32) 115 | y_indices = y_indices.astype(np.float32) 116 | 117 | orig_x_indices = np.copy(x_indices) 118 | orig_y_indices = np.copy(y_indices) 119 | 120 | # Mask keeps track of the points that leave the image, and zeros out the flow afterwards. 121 | x_mask = np.ones(x_indices.shape, dtype=bool) 122 | y_mask = np.ones(y_indices.shape, dtype=bool) 123 | 124 | scale_factor = (gt_timestamps[gt_iter+1] - start_time) / gt_dt 125 | total_dt = gt_timestamps[gt_iter+1] - start_time 126 | 127 | prop_flow(x_flow, y_flow, 128 | x_indices, y_indices, 129 | x_mask, y_mask, 130 | scale_factor=scale_factor) 131 | 132 | gt_iter += 1 133 | 134 | while gt_timestamps[gt_iter+1] < end_time: 135 | x_flow = np.squeeze(x_flow_in[gt_iter, ...]) 136 | y_flow = np.squeeze(y_flow_in[gt_iter, ...]) 137 | 138 | prop_flow(x_flow, y_flow, 139 | x_indices, y_indices, 140 | x_mask, y_mask) 141 | total_dt += gt_timestamps[gt_iter+1] - gt_timestamps[gt_iter] 142 | 143 | gt_iter += 1 144 | 145 | final_dt = end_time - gt_timestamps[gt_iter] 146 | total_dt += final_dt 147 | 148 | final_gt_dt = gt_timestamps[gt_iter+1] - gt_timestamps[gt_iter] 149 | 150 | x_flow = np.squeeze(x_flow_in[gt_iter, ...]) 151 | y_flow = np.squeeze(y_flow_in[gt_iter, ...]) 152 | 153 | scale_factor = final_dt / final_gt_dt 154 | 155 | prop_flow(x_flow, y_flow, 156 | x_indices, y_indices, 157 | x_mask, y_mask, 158 | scale_factor) 159 | 160 | x_shift = x_indices - orig_x_indices 161 | y_shift = y_indices - orig_y_indices 162 | x_shift[~x_mask] = 0 163 | y_shift[~y_mask] = 0 164 | 165 | return x_shift, y_shift 166 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import cv2 4 | import numpy as np 5 | 6 | def warp_images_with_flow(images, flow): 7 | """ 8 | Generates a prediction of an image given the optical flow, as in Spatial Transformer Networks. 
9 | """ 10 | dim3 = 0 11 | if images.dim() == 3: 12 | dim3 = 1 13 | images = images.unsqueeze(0) 14 | flow = flow.unsqueeze(0) 15 | height = images.shape[2] 16 | width = images.shape[3] 17 | flow_x,flow_y = flow[:,0,...],flow[:,1,...] 18 | coord_x, coord_y = torch.meshgrid(torch.arange(height), torch.arange(width)) 19 | 20 | pos_x = coord_x.reshape(height,width).type(torch.float32).cuda() + flow_x 21 | pos_y = coord_y.reshape(height,width).type(torch.float32).cuda() + flow_y 22 | pos_x = (pos_x-(height-1)/2)/((height-1)/2) 23 | pos_y = (pos_y-(width-1)/2)/((width-1)/2) 24 | 25 | pos = torch.stack((pos_y,pos_x),3).type(torch.float32) 26 | result = torch.nn.functional.grid_sample(images, pos, mode='bilinear', padding_mode='zeros') 27 | if dim3 == 1: 28 | result = result.squeeze() 29 | 30 | return result 31 | 32 | 33 | def charbonnier_loss(delta, alpha=0.45, epsilon=1e-3): 34 | """ 35 | Robust Charbonnier loss, as defined in equation (4) of the paper. 36 | """ 37 | loss = torch.mean(torch.pow((delta ** 2 + epsilon ** 2), alpha)) 38 | return loss 39 | 40 | 41 | def compute_smoothness_loss(flow): 42 | """ 43 | Local smoothness loss, as defined in equation (5) of the paper. 44 | The neighborhood here is defined as the 8-connected region around each pixel. 45 | """ 46 | flow_ucrop = flow[..., 1:] 47 | flow_dcrop = flow[..., :-1] 48 | flow_lcrop = flow[..., 1:, :] 49 | flow_rcrop = flow[..., :-1, :] 50 | 51 | flow_ulcrop = flow[..., 1:, 1:] 52 | flow_drcrop = flow[..., :-1, :-1] 53 | flow_dlcrop = flow[..., :-1, 1:] 54 | flow_urcrop = flow[..., 1:, :-1] 55 | 56 | smoothness_loss = charbonnier_loss(flow_lcrop - flow_rcrop) +\ 57 | charbonnier_loss(flow_ucrop - flow_dcrop) +\ 58 | charbonnier_loss(flow_ulcrop - flow_drcrop) +\ 59 | charbonnier_loss(flow_dlcrop - flow_urcrop) 60 | smoothness_loss /= 4. 61 | 62 | return smoothness_loss 63 | 64 | 65 | def compute_photometric_loss(prev_images, next_images, flow_dict): 66 | """ 67 | Multi-scale photometric loss, as defined in equation (3) of the paper. 68 | """ 69 | total_photometric_loss = 0. 70 | loss_weight_sum = 0. 71 | for i in range(len(flow_dict)): 72 | for image_num in range(prev_images.shape[0]): 73 | flow = flow_dict["flow{}".format(i)][image_num] 74 | height = flow.shape[1] 75 | width = flow.shape[2] 76 | 77 | prev_images_resize = F.to_tensor(F.resize(F.to_pil_image(prev_images[image_num].cpu()), 78 | [height, width])).cuda() 79 | next_images_resize = F.to_tensor(F.resize(F.to_pil_image(next_images[image_num].cpu()), 80 | [height, width])).cuda() 81 | 82 | next_images_warped = warp_images_with_flow(next_images_resize, flow) 83 | 84 | distance = next_images_warped - prev_images_resize 85 | photometric_loss = charbonnier_loss(distance) 86 | total_photometric_loss += photometric_loss 87 | loss_weight_sum += 1. 
88 | total_photometric_loss /= loss_weight_sum 89 | 90 | return total_photometric_loss 91 | 92 | 93 | class TotalLoss(torch.nn.Module): 94 | def __init__(self, smoothness_weight, weight_decay_weight=1e-4): 95 | super(TotalLoss, self).__init__() 96 | self._smoothness_weight = smoothness_weight 97 | self._weight_decay_weight = weight_decay_weight 98 | 99 | def forward(self, flow_dict, prev_image, next_image, EVFlowNet_model): 100 | # weight decay loss 101 | weight_decay_loss = 0 102 | for i in EVFlowNet_model.parameters(): 103 | weight_decay_loss += torch.sum(i**2)/2*self._weight_decay_weight 104 | 105 | # smoothness loss 106 | smoothness_loss = 0 107 | for i in range(len(flow_dict)): 108 | smoothness_loss += compute_smoothness_loss(flow_dict["flow{}".format(i)]) 109 | smoothness_loss *= self._smoothness_weight / 4. 110 | 111 | # Photometric loss. 112 | photometric_loss = compute_photometric_loss(prev_image, 113 | next_image, 114 | flow_dict) 115 | 116 | # Warped next image for debugging. 117 | #next_image_warped = warp_images_with_flow(next_image, 118 | # flow_dict['flow3']) 119 | 120 | loss = weight_decay_loss + photometric_loss + smoothness_loss 121 | 122 | return loss 123 | 124 | if __name__ == "__main__": 125 | ''' 126 | a = torch.rand(7,7) 127 | b = torch.rand(7,7) 128 | flow = {} 129 | flow['flow0'] = torch.rand(1,2,3,3) 130 | loss = compute_photometric_loss(a,b,0,flow) 131 | print(loss) 132 | ''' 133 | a = torch.rand(1,5,5).cuda() 134 | #b = torch.rand(5,5)*5 135 | b = torch.rand((5,5)).type(torch.float32).cuda() 136 | b.requires_grad = True 137 | #c = torch.rand(5,5)*5 138 | c = torch.rand((5,5)).type(torch.float32).cuda() 139 | c.requires_grad = True 140 | d = torch.stack((b,c),0) 141 | print(a) 142 | print(b) 143 | print(c) 144 | r = warp_images_with_flow(a,d) 145 | print(r) 146 | r = torch.mean(r) 147 | r.backward() 148 | print(b.grad) 149 | print(c.grad) -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import time 4 | import cv2 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from config import * 10 | from data_loader import EventData 11 | from eval_utils import * 12 | from vis_utils import * 13 | from EVFlowNet import EVFlowNet 14 | 15 | def drawImageTitle(img, title): 16 | cv2.putText(img, 17 | title, 18 | (60, 20), 19 | cv2.FONT_HERSHEY_SIMPLEX, 20 | 0.5, 21 | (255, 255, 255), 22 | thickness=2, 23 | bottomLeftOrigin=False) 24 | return img 25 | 26 | def test(args, EVFlowNet_model, EventDataLoder): 27 | if args.test_plot: 28 | cv2.namedWindow('EV-FlowNet Results', cv2.WINDOW_NORMAL) 29 | 30 | if args.gt_path: 31 | print("Loading ground truth {}".format(args.gt_path)) 32 | gt = np.load(args.gt_path) 33 | gt_timestamps = gt['timestamps'] 34 | U_gt_all = gt['x_flow_dist'] 35 | V_gt_all = gt['y_flow_dist'] 36 | print("Ground truth loaded") 37 | 38 | AEE_sum = 0. 39 | percent_AEE_sum = 0. 
40 | AEE_list = [] 41 | 42 | if args.save_test_output: 43 | output_flow_list = [] 44 | gt_flow_list = [] 45 | event_image_list = [] 46 | 47 | max_flow_sum = 0 48 | min_flow_sum = 0 49 | iters = 0 50 | 51 | for event_image, prev_image, next_image, image_timestamps in EventDataLoder: 52 | image_timestamps[0] = image_timestamps[0].numpy() 53 | image_timestamps[1] = image_timestamps[1].numpy() 54 | prev_image = prev_image.numpy() 55 | next_image = next_image.numpy() 56 | prev_image = np.transpose(prev_image, (0,2,3,1)) 57 | next_image = np.transpose(next_image, (0,2,3,1)) 58 | 59 | start_time = time.time() 60 | flow_dict = EVFlowNet_model(event_image.cuda()) 61 | network_duration = time.time() - start_time 62 | 63 | pred_flow = np.squeeze(flow_dict['flow3'].detach().cpu().numpy()) 64 | pred_flow = np.transpose(pred_flow, (1,2,0)) 65 | pred_flow = np.flip(pred_flow, 2) 66 | 67 | max_flow_sum += np.max(pred_flow) 68 | min_flow_sum += np.min(pred_flow) 69 | 70 | event_count_image = torch.sum(event_image[:, :2, ...], dim=1).numpy() 71 | event_count_image = (event_count_image * 255 / event_count_image.max()).astype(np.uint8) 72 | event_count_image = np.squeeze(event_count_image) 73 | 74 | if args.save_test_output: 75 | output_flow_list.append(pred_flow) 76 | event_image_list.append(event_count_image) 77 | 78 | if args.gt_path: 79 | U_gt, V_gt = estimate_corresponding_gt_flow(U_gt_all, V_gt_all, 80 | gt_timestamps, 81 | image_timestamps[0], 82 | image_timestamps[1]) 83 | 84 | gt_flow = np.stack((U_gt, V_gt), axis=2) 85 | 86 | if args.save_test_output: 87 | gt_flow_list.append(gt_flow) 88 | 89 | image_size = pred_flow.shape 90 | full_size = gt_flow.shape 91 | xsize = full_size[1] 92 | ysize = full_size[0] 93 | xcrop = image_size[1] 94 | ycrop = image_size[0] 95 | xoff = (xsize - xcrop) // 2 96 | yoff = (ysize - ycrop) // 2 97 | 98 | gt_flow = gt_flow[yoff:-yoff, xoff:-xoff, :] 99 | 100 | # Calculate flow error. 101 | AEE, percent_AEE, n_points = flow_error_dense(gt_flow, 102 | pred_flow, 103 | event_count_image, 104 | 'outdoor' in args.test_sequence) 105 | AEE_list.append(AEE) 106 | AEE_sum += AEE 107 | percent_AEE_sum += percent_AEE 108 | 109 | iters += 1 110 | if iters % 100 == 0: 111 | print('-------------------------------------------------------') 112 | print('Iter: {}, time: {:f}, run time: {:.3f}s\n' 113 | 'Mean max flow: {:.2f}, mean min flow: {:.2f}' 114 | .format(iters, image_timestamps[0][0], network_duration, 115 | max_flow_sum / iters, min_flow_sum / iters)) 116 | if args.gt_path: 117 | print('Mean AEE: {:.2f}, mean %AEE: {:.2f}, # pts: {:.2f}' 118 | .format(AEE_sum / iters, 119 | percent_AEE_sum / iters, 120 | n_points)) 121 | 122 | # Prep outputs for nice visualization. 123 | if args.test_plot: 124 | pred_flow_rgb = flow_viz_np(pred_flow[..., 0], pred_flow[..., 1]) 125 | pred_flow_rgb = drawImageTitle(pred_flow_rgb, 'Predicted Flow') 126 | 127 | event_time_image = np.squeeze(np.amax(event_image[:, 2:, ...].numpy(), axis=1)) 128 | event_time_image = (event_time_image * 255 / event_time_image.max()).astype(np.uint8) 129 | event_time_image = np.tile(event_time_image[..., np.newaxis], [1, 1, 3]) 130 | 131 | event_count_image = np.tile(event_count_image[..., np.newaxis], [1, 1, 3]) 132 | 133 | event_time_image = drawImageTitle(event_time_image, 'Timestamp Image') 134 | event_count_image = drawImageTitle(event_count_image, 'Count Image') 135 | 136 | prev_image = np.squeeze(prev_image) 137 | prev_image = prev_image * 255. 
138 | prev_image = np.tile(prev_image[..., np.newaxis], [1, 1, 3]) 139 | 140 | prev_image = drawImageTitle(prev_image, 'Grayscale Image') 141 | 142 | gt_flow_rgb = np.zeros(pred_flow_rgb.shape) 143 | errors = np.zeros(pred_flow_rgb.shape) 144 | 145 | gt_flow_rgb = drawImageTitle(gt_flow_rgb, 'GT Flow - No GT') 146 | errors = drawImageTitle(errors, 'Flow Error - No GT') 147 | 148 | if args.gt_path: 149 | errors = np.linalg.norm(gt_flow - pred_flow, axis=-1) 150 | errors[np.isinf(errors)] = 0 151 | errors[np.isnan(errors)] = 0 152 | errors = (errors * 255. / errors.max()).astype(np.uint8) 153 | errors = np.tile(errors[..., np.newaxis], [1, 1, 3]) 154 | errors[event_count_image == 0] = 0 155 | 156 | if 'outdoor' in args.test_sequence: 157 | errors[190:, :] = 0 158 | 159 | gt_flow_rgb = flow_viz_np(gt_flow[...,0], gt_flow[...,1]) 160 | 161 | gt_flow_rgb = drawImageTitle(gt_flow_rgb, 'GT Flow') 162 | errors= drawImageTitle(errors, 'Flow Error') 163 | 164 | top_cat = np.concatenate([event_count_image, prev_image, pred_flow_rgb], axis=1) 165 | bottom_cat = np.concatenate([event_time_image, errors, gt_flow_rgb], axis=1) 166 | cat = np.concatenate([top_cat, bottom_cat], axis=0) 167 | cat = cat.astype(np.uint8) 168 | cv2.imshow('EV-FlowNet Results', cat) 169 | cv2.waitKey(1) 170 | 171 | print('Testing done. ') 172 | if args.gt_path: 173 | print('mean AEE {:02f}, mean %AEE {:02f}' 174 | .format(AEE_sum / iters, 175 | percent_AEE_sum / iters)) 176 | if args.save_test_output: 177 | if args.gt_path: 178 | print('Saving data to {}_output_gt.npz'.format(args.test_sequence)) 179 | np.savez('{}_output_gt.npz'.format(args.test_sequence), 180 | output_flows=np.stack(output_flow_list, axis=0), 181 | gt_flows=np.stack(gt_flow_list, axis=0), 182 | event_images=np.stack(event_image_list, axis=0)) 183 | else: 184 | print('Saving data to {}_output.npz'.format(args.test_sequence)) 185 | np.savez('{}_output.npz'.format(args.test_sequence), 186 | output_flows=np.stack(output_flow_list, axis=0), 187 | event_images=np.stack(event_image_list, axis=0)) 188 | 189 | 190 | def main(): 191 | args = configs() 192 | args.load_path = os.path.join(args.load_path, args.training_instance) 193 | 194 | EVFlowNet_model = EVFlowNet(args).cuda() 195 | EVFlowNet_model.load_state_dict(torch.load(args.load_path+'/model91')) 196 | #para = np.load('D://p.npy').item() 197 | #EVFlowNet_model.load_state_dict(para) 198 | EventDataset = EventData(args.data_path, 'test', skip_frames=args.test_skip_frames) 199 | EventDataLoader = torch.utils.data.DataLoader(dataset=EventDataset, batch_size=1, shuffle=False) 200 | 201 | if not args.load_path: 202 | raise Exception("You need to set `load_path` and `training_instance`.") 203 | 204 | EVFlowNet_model.eval() 205 | ''' 206 | event,pre,next_,_ = next(iter(EventDataLoader)) 207 | flow = EVFlowNet_model(event.cuda()) 208 | a = flow['flow3'] 209 | x = a[0,0].detach().cpu().numpy() 210 | y = a[0,1].detach().cpu().numpy() 211 | x[np.isnan(x)] = 0 212 | x[np.isinf(x)] = np.max(x[~np.isinf(x)]) 213 | y[np.isnan(y)] = 0 214 | y[np.isinf(y)] = np.max(y[~np.isinf(y)]) 215 | a = np.sqrt(x**2+y**2) 216 | b = np.arctan(y/x) 217 | b[np.isnan(b)] = 0 218 | b[np.isinf(b)] = np.max(b[~np.isinf(b)]) 219 | a = 255*(a-np.min(a))/(np.max(a)-np.min(a)) 220 | a = a.astype(np.uint8) 221 | b = 180*(b-np.min(b))/(np.max(b)-np.min(b)) 222 | b = b.astype(np.uint8) 223 | c = 255*np.ones(a.shape).astype(np.uint8) 224 | a = np.stack((b,a,c),axis=2) 225 | a = cv2.cvtColor(a,cv2.COLOR_HSV2BGR) 226 | cv2.namedWindow('w') 227 | 
cv2.imshow('w',a) 228 | cv2.waitKey() 229 | ''' 230 | test(args, EVFlowNet_model, EventDataLoader) 231 | 232 | 233 | if __name__ == "__main__": 234 | main() 235 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import trange 3 | from tqdm import tqdm 4 | import numpy as np 5 | from datetime import datetime 6 | from losses import * 7 | 8 | import torch 9 | 10 | from config import configs 11 | from data_loader import EventData 12 | from EVFlowNet import EVFlowNet 13 | 14 | def main(): 15 | args = configs() 16 | 17 | if args.training_instance: 18 | args.load_path = os.path.join(args.load_path, args.training_instance) 19 | else: 20 | args.load_path = os.path.join(args.load_path, 21 | "evflownet_{}".format(datetime.now() 22 | .strftime("%m%d_%H%M%S"))) 23 | if not os.path.exists(args.load_path): 24 | os.makedirs(args.load_path) 25 | 26 | EventDataset = EventData(args.data_path, 'train') 27 | EventDataLoader = torch.utils.data.DataLoader(dataset=EventDataset, batch_size=args.batch_size, shuffle=True) 28 | 29 | # model 30 | EVFlowNet_model = EVFlowNet(args).cuda() 31 | 32 | #para = np.load('D://p.npy', allow_pickle=True).item() 33 | #EVFlowNet_model.load_state_dict(para) 34 | EVFlowNet_model.load_state_dict(torch.load(args.load_path+'/../model')) 35 | 36 | #EVFlowNet_parallelmodel = torch.nn.DataParallel(EVFlowNet_model) 37 | # optimizer 38 | optimizer = torch.optim.Adam(EVFlowNet_model.parameters(), lr=args.initial_learning_rate) 39 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.learning_rate_decay) 40 | loss_fun = TotalLoss(args.smoothness_weight) 41 | 42 | iteration = 0 43 | size = 0 44 | EVFlowNet_model.train() 45 | for epoch in range(100): 46 | loss_sum = 0.0 47 | print('*****************************************') 48 | print('epoch:'+str(epoch)) 49 | for event_image, prev_image, next_image, _ in tqdm(EventDataLoader): 50 | event_image = event_image.cuda() 51 | prev_image = prev_image.cuda() 52 | next_image = next_image.cuda() 53 | 54 | optimizer.zero_grad() 55 | flow_dict = EVFlowNet_model(event_image) 56 | 57 | loss = loss_fun(flow_dict, prev_image, next_image, EVFlowNet_model) 58 | 59 | loss.backward() 60 | optimizer.step() 61 | loss_sum += loss.item() 62 | iteration += 1 63 | size += 1 64 | print(loss) 65 | if iteration % 100 == 0: 66 | print('iteration:', iteration) 67 | print('loss:', loss_sum/100) 68 | loss_sum = 0.0 69 | torch.save(EVFlowNet_model.state_dict(), args.load_path+'/model%d'%epoch) 70 | if epoch % 4 == 3: 71 | scheduler.step() 72 | print('iteration:', iteration) 73 | print('loss:', loss_sum/size) 74 | size = 0 75 | 76 | 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /src/vis_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import numpy as np 3 | import math 4 | import cv2 5 | 6 | """ 7 | Generates an RGB image where each point corresponds to flow in that direction from the center, 8 | as visualized by flow_viz_tf. 
9 | Output: color_wheel_rgb: [1, width, height, 3] 10 | def draw_color_wheel_tf(width, height): 11 | color_wheel_x = tf.lin_space(-width / 2., 12 | width / 2., 13 | width) 14 | color_wheel_y = tf.lin_space(-height / 2., 15 | height / 2., 16 | height) 17 | color_wheel_X, color_wheel_Y = tf.meshgrid(color_wheel_x, color_wheel_y) 18 | color_wheel_flow = tf.stack([color_wheel_X, color_wheel_Y], axis=2) 19 | color_wheel_flow = tf.expand_dims(color_wheel_flow, 0) 20 | color_wheel_rgb, flow_norm, flow_ang = flow_viz_tf(color_wheel_flow) 21 | return color_wheel_rgb 22 | """ 23 | 24 | def draw_color_wheel_np(width, height): 25 | color_wheel_x = np.linspace(-width / 2., 26 | width / 2., 27 | width) 28 | color_wheel_y = np.linspace(-height / 2., 29 | height / 2., 30 | height) 31 | color_wheel_X, color_wheel_Y = np.meshgrid(color_wheel_x, color_wheel_y) 32 | color_wheel_rgb = flow_viz_np(color_wheel_X, color_wheel_Y) 33 | return color_wheel_rgb 34 | 35 | """ 36 | Visualizes optical flow in HSV space using TensorFlow, with orientation as H, magnitude as V. 37 | Returned as RGB. 38 | Input: flow: [batch_size, width, height, 2] 39 | Output: flow_rgb: [batch_size, width, height, 3] 40 | def flow_viz_tf(flow): 41 | flow_norm = tf.norm(flow, axis=3) 42 | 43 | flow_ang_rad = tf.atan2(flow[:, :, :, 1], flow[:, :, :, 0]) 44 | flow_ang = (flow_ang_rad / math.pi) / 2. + 0.5 45 | 46 | const_mat = tf.ones(tf.shape(flow_norm)) 47 | hsv = tf.stack([flow_ang, const_mat, flow_norm], axis=3) 48 | flow_rgb = tf.image.hsv_to_rgb(hsv) 49 | return flow_rgb, flow_norm, flow_ang_rad 50 | """ 51 | 52 | def flow_viz_np(flow_x, flow_y): 53 | import cv2 54 | flows = np.stack((flow_x, flow_y), axis=2) 55 | flows[np.isinf(flows)] = 0 56 | flows[np.isnan(flows)] = 0 57 | mag = np.linalg.norm(flows, axis=2) 58 | ang = np.arctan2(flow_y, flow_x) 59 | p = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 60 | ang += np.pi 61 | ang *= 180. / np.pi / 2. 62 | ang = ang.astype(np.uint8) 63 | hsv = np.zeros([flow_x.shape[0], flow_x.shape[1], 3], dtype=np.uint8) 64 | hsv[:, :, 0] = ang 65 | hsv[:, :, 1] = 255 66 | hsv[:, :, 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 67 | flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 68 | return flow_rgb 69 | 70 | --------------------------------------------------------------------------------
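For reference, the visualization helpers in [src/vis_utils.py](src/vis_utils.py) can be exercised on their own. A small sketch with a synthetic flow field (run from ```src/```; the output filenames are arbitrary):

```python
import numpy as np
import cv2
from vis_utils import draw_color_wheel_np, flow_viz_np

# Color wheel legend: hue encodes flow direction, value encodes magnitude.
wheel = draw_color_wheel_np(256, 256)

# Synthetic flow field at the DAVIS resolution (260x346), varying smoothly over the image.
flow_x, flow_y = np.meshgrid(np.linspace(-5., 5., 346), np.linspace(-5., 5., 260))
flow_rgb = flow_viz_np(flow_x.astype(np.float32), flow_y.astype(np.float32))

cv2.imwrite('color_wheel.png', wheel)
cv2.imwrite('flow_vis.png', flow_rgb)
```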