├── .gitignore
├── LICENSE
├── NeuFlow.pdf
├── NeuFlow
│   ├── backbone.py
│   ├── config.py
│   ├── matching.py
│   ├── neuflow.py
│   ├── refine.py
│   ├── transformer.py
│   ├── upsample.py
│   └── utils.py
├── NeuFlow_v2
│   └── neuflow_v2_plot.png
├── README.md
├── data_utils
│   ├── datasets.py
│   ├── evaluate.py
│   ├── flow_viz.py
│   ├── frame_utils.py
│   └── transforms.py
├── dataset_preview
│   ├── dataset_1.png
│   └── dataset_2.png
├── dist_utils.py
├── eval.py
├── load_model.py
├── loss.py
├── neuflow_sintel.pth
├── neuflow_things.pth
├── train.py
└── write_occ.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | datasets
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /NeuFlow.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neufieldrobotics/NeuFlow/1881c3cd85b87c6155d1b0726b8a345b6a1bdb7e/NeuFlow.pdf -------------------------------------------------------------------------------- /NeuFlow/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class ConvBlock(torch.nn.Module): 6 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding): 7 | super(ConvBlock, self).__init__() 8 | 9 | self.conv1 = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode='zeros', bias=False) 10 | 11 | self.conv2 = torch.nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) 12 | 13 | self.relu = torch.nn.LeakyReLU(negative_slope=0.1, inplace=False) 14 | 15 | self.norm = torch.nn.BatchNorm2d(out_planes, eps=1e-06, affine=False) 16 | 17 | # self.dropout = torch.nn.Dropout(p=0.1) 18 | 19 | def forward(self, x): 20 | 21 | # x = self.dropout(x) 22 | 23 | x1 = self.relu(self.conv1(x)) 24 | x2 = self.relu(self.conv2(x1)) 25 | 26 | return self.norm(x1 + x2) 27 | 28 | class DownDimBlock(torch.nn.Module): 29 | def __init__(self, in_planes, out_planes): 30 | super(DownDimBlock, self).__init__() 31 | 32 | self.relu = torch.nn.LeakyReLU(negative_slope=0.1, inplace=False) 33 | 34 | self.conv_block = ConvBlock(in_planes, out_planes, kernel_size=1, stride=1, padding=0) 35 | 36 | def forward(self, x): 37 | 38 | return self.conv_block(self.relu(x)) 39 | 40 | class CNNEncoder(torch.nn.Module): 41 | def __init__(self, feature_dim): 42 | super(CNNEncoder, self).__init__() 43 | 44 | # self.conv0 = torch.nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, padding_mode='zeros', bias=False) # rgb2gray 45 | # self.norm0 = torch.nn.BatchNorm2d(16, eps=1e-06, affine=False) 46 | 47 | self.block1_1 = ConvBlock(3, feature_dim, kernel_size=8, stride=8, padding=0) # 1/1 48 | 49 | self.block1_2 = ConvBlock(3, feature_dim, kernel_size=8, stride=4, padding=2) # 1/2 50 | 51 | self.block1_3 = ConvBlock(3, feature_dim, kernel_size=8, stride=2, padding=3) # 1/4 52 | 53 | self.block1_4 = ConvBlock(3, feature_dim, kernel_size=7, stride=1, padding=3) # 1/8 54 | 55 | self.block1_dd = DownDimBlock(feature_dim * 4, feature_dim) # pick features 56 | self.block1_ds = ConvBlock(feature_dim, feature_dim, kernel_size=2, stride=2, padding=0) 57 | 58 | self.block2 = ConvBlock(3, feature_dim, kernel_size=7, stride=1, padding=3) # 1/16 59 | self.block2_dd = DownDimBlock(feature_dim * 2, feature_dim) # pick features 60 | 61 | def init_pos(self, batch_size, height, width): 62 | 63 | ys, xs = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij') 64 | ys = ys.cuda() / (height-1) 65 | xs = xs.cuda() / (width-1) 66 | pos = torch.stack([ys, xs]) 67 | return pos[None].repeat(batch_size,1,1,1) 68 | 69 | def init_pos_12(self, batch_size, height, width): 70 | self.pos_1 = self.init_pos(batch_size, height, width) 71 | self.pos_2 = self.init_pos(batch_size, height//2, width//2) 72 | 73 | def forward(self, img): 74 | 75 | b = img.shape[0] 76 | 77 | # x = self.relu(self.norm0(self.conv0(x))) 78 | 79 | x1_1 = self.block1_1(img) 80 | 81 | img = F.avg_pool2d(img, kernel_size=2, stride=2) 82 | 83 | x1_2 = self.block1_2(img) 84 | 85 | img = F.avg_pool2d(img, kernel_size=2, 
stride=2) 86 | 87 | x1_3 = self.block1_3(img) 88 | 89 | img = F.avg_pool2d(img, kernel_size=2, stride=2) 90 | 91 | x1_4 = self.block1_4(img) 92 | 93 | x1 = torch.cat([x1_1, x1_2, x1_3, x1_4], dim=1) 94 | x1 = self.block1_dd(x1) 95 | 96 | img = F.avg_pool2d(img, kernel_size=2, stride=2) 97 | 98 | x2 = self.block2(img) 99 | 100 | x2 = torch.cat([self.block1_ds(x1), x2], dim=1) 101 | x2 = self.block2_dd(x2) 102 | 103 | x1 = torch.cat([x1, self.pos_1], dim=1) 104 | x2 = torch.cat([x2, self.pos_2], dim=1) 105 | 106 | # x2 = self.self_attn(x2, x2) 107 | 108 | return [x1, x2] 109 | -------------------------------------------------------------------------------- /NeuFlow/config.py: -------------------------------------------------------------------------------- 1 | feature_dim = 90 2 | -------------------------------------------------------------------------------- /NeuFlow/matching.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | from NeuFlow import utils 4 | 5 | 6 | class Matching: 7 | 8 | def init_grid(self, batch_size, height, width): 9 | self.grid = utils.coords_grid(batch_size, height, width).cuda() # [B, 2, H, W] 10 | self.flatten_grid = self.grid.view(batch_size, 2, -1).permute(0, 2, 1) # [B, H*W, 2] 11 | 12 | def global_correlation_softmax(self, feature0, feature1): 13 | 14 | b, c, h, w = feature0.shape 15 | 16 | feature0 = feature0.flatten(-2).permute(0, 2, 1) 17 | feature1 = feature1.flatten(-2).permute(0, 2, 1) 18 | 19 | correspondence = F.scaled_dot_product_attention(feature0, feature1, self.flatten_grid) 20 | 21 | correspondence = correspondence.view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] 22 | 23 | flow = correspondence - self.grid 24 | 25 | return flow 26 | -------------------------------------------------------------------------------- /NeuFlow/neuflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from NeuFlow import backbone 5 | from NeuFlow import transformer 6 | from NeuFlow import matching 7 | from NeuFlow import refine 8 | from NeuFlow import upsample 9 | from NeuFlow import utils 10 | 11 | from NeuFlow import config 12 | 13 | 14 | class NeuFlow(torch.nn.Module): 15 | def __init__(self): 16 | super(NeuFlow, self).__init__() 17 | 18 | self.backbone = backbone.CNNEncoder(config.feature_dim) 19 | self.cross_attn_s16 = transformer.FeatureAttention(config.feature_dim+2, num_layers=2, bidir=True, ffn=True, ffn_dim_expansion=1, post_norm=True) 20 | 21 | self.matching_s16 = matching.Matching() 22 | 23 | self.flow_attn_s16 = transformer.FlowAttention(config.feature_dim+2) 24 | 25 | self.merge_s8 = torch.nn.Sequential(torch.nn.Conv2d((config.feature_dim+2) * 2, config.feature_dim * 2, kernel_size=3, stride=1, padding=1, bias=False), 26 | torch.nn.GELU(), 27 | torch.nn.Conv2d(config.feature_dim * 2, config.feature_dim, kernel_size=3, stride=1, padding=1, bias=False)) 28 | 29 | self.refine_s8 = refine.Refine(config.feature_dim, patch_size=7, num_layers=6) 30 | 31 | self.conv_s8 = backbone.ConvBlock(3, config.feature_dim, kernel_size=8, stride=8, padding=0) 32 | 33 | self.upsample_s1 = upsample.UpSample(config.feature_dim, upsample_factor=8) 34 | 35 | for p in self.parameters(): 36 | if p.dim() > 1: 37 | torch.nn.init.xavier_uniform_(p) 38 | 39 | def init_bhw(self, batch_size, height, width): 40 | self.backbone.init_pos_12(batch_size, height//8, width//8) 41 | self.matching_s16.init_grid(batch_size, 
height//16, width//16) 42 | self.img_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda() 43 | self.img_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda() 44 | 45 | def forward(self, img0, img1): 46 | 47 | flow_list = [] 48 | 49 | img0 = utils.normalize_img(img0, self.img_mean, self.img_std) 50 | img1 = utils.normalize_img(img1, self.img_mean, self.img_std) 51 | 52 | feature0_s8, feature0_s16 = self.backbone(img0) 53 | feature1_s8, feature1_s16 = self.backbone(img1) 54 | 55 | feature0_s16, feature1_s16 = self.cross_attn_s16(feature0_s16, feature1_s16) 56 | flow0 = self.matching_s16.global_correlation_softmax(feature0_s16, feature1_s16) 57 | 58 | flow0 = self.flow_attn_s16(feature0_s16, flow0) 59 | 60 | feature0_s16 = F.interpolate(feature0_s16, scale_factor=2, mode='nearest') 61 | feature1_s16 = F.interpolate(feature1_s16, scale_factor=2, mode='nearest') 62 | 63 | feature0_s8 = self.merge_s8(torch.cat([feature0_s8, feature0_s16], dim=1)) 64 | feature1_s8 = self.merge_s8(torch.cat([feature1_s8, feature1_s16], dim=1)) 65 | 66 | flow0 = F.interpolate(flow0, scale_factor=2, mode='nearest') * 2 67 | 68 | delta_flow = self.refine_s8(feature0_s8, utils.flow_warp(feature1_s8, flow0), flow0) 69 | flow0 = flow0 + delta_flow 70 | 71 | if self.training: 72 | up_flow0 = F.interpolate(flow0, scale_factor=8, mode='bilinear', align_corners=True) * 8 73 | flow_list.append(up_flow0) 74 | 75 | feature0_s8 = self.conv_s8(img0) 76 | 77 | flow0 = self.upsample_s1(feature0_s8, flow0) 78 | flow_list.append(flow0) 79 | 80 | return flow_list 81 | -------------------------------------------------------------------------------- /NeuFlow/refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from spatial_correlation_sampler import SpatialCorrelationSampler 4 | 5 | class ConvBlock(torch.nn.Module): 6 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding): 7 | super(ConvBlock, self).__init__() 8 | 9 | self.conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode='zeros', bias=True) 10 | self.relu = torch.nn.LeakyReLU(negative_slope=0.1, inplace=False) 11 | 12 | def forward(self, x): 13 | return self.relu(self.conv(x)) 14 | 15 | class Refine(torch.nn.Module): 16 | def __init__(self, feature_dim, patch_size, num_layers): 17 | super(Refine, self).__init__() 18 | 19 | self.patch_size = patch_size 20 | 21 | self.correlation_sampler = SpatialCorrelationSampler(kernel_size=1, patch_size=patch_size, stride=1, padding=0, dilation=1) 22 | 23 | self.conv1 = ConvBlock(patch_size**2+feature_dim+2, 96, kernel_size=3, stride=1, padding=1) 24 | 25 | self.conv_layers = torch.nn.ModuleList([ConvBlock(96, 96, kernel_size=3, stride=1, padding=1) 26 | for i in range(num_layers)]) 27 | 28 | self.conv2 = ConvBlock(96, 64, kernel_size=3, stride=1, padding=1) 29 | self.conv3 = ConvBlock(64, 32, kernel_size=3, stride=1, padding=1) 30 | self.conv4 = torch.nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1, padding_mode='zeros', bias=True) 31 | 32 | def forward(self, feature_0, feature_1, flow_0): 33 | 34 | b, c, h, w = feature_0.shape 35 | 36 | attn = self.correlation_sampler(feature_0, feature_1).view(b, -1, h, w) 37 | # attn = F.softmax(attn, dim=1) 38 | 39 | x = torch.cat([attn, feature_0, flow_0], dim=1) 40 | 41 | x = self.conv1(x) 42 | 43 | for layer in self.conv_layers: 44 | x = layer(x) 45 | 46 | x = self.conv2(x) 47 | x = self.conv3(x) 48 | 
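        # conv4 projects the 32-channel refinement features down to a
        # 2-channel residual flow; the caller (NeuFlow.forward) adds this
        # delta to the coarse 1/8-resolution flow.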
49 | return self.conv4(x) -------------------------------------------------------------------------------- /NeuFlow/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class TransformerLayer(torch.nn.Module): 5 | def __init__(self, 6 | feature_dim, 7 | ffn=True, 8 | ffn_dim_expansion=1 9 | ): 10 | super(TransformerLayer, self).__init__() 11 | 12 | # multi-head attention 13 | self.q_proj = torch.nn.Linear(feature_dim, feature_dim) 14 | self.k_proj = torch.nn.Linear(feature_dim, feature_dim) 15 | self.v_proj = torch.nn.Linear(feature_dim, feature_dim) 16 | 17 | self.merge = torch.nn.Linear(feature_dim, feature_dim) 18 | 19 | # self.multi_head_attn = torch.nn.MultiheadAttention(feature_dim, 2, batch_first=True, device='cuda') 20 | 21 | self.norm1 = torch.nn.LayerNorm(feature_dim) 22 | 23 | self.ffn = ffn 24 | 25 | if self.ffn: 26 | in_channels = feature_dim * 2 27 | self.mlp = torch.nn.Sequential( 28 | torch.nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), 29 | torch.nn.GELU(), 30 | torch.nn.Linear(in_channels * ffn_dim_expansion, feature_dim, bias=False), 31 | ) 32 | 33 | self.norm2 = torch.nn.LayerNorm(feature_dim) 34 | 35 | def forward(self, source, target): 36 | # source, target: [B, L, C] 37 | query, key, value = source, target, target 38 | 39 | # single-head attention 40 | query = self.q_proj(query) # [B, L, C] 41 | key = self.k_proj(key) # [B, L, C] 42 | value = self.v_proj(value) # [B, L, C] 43 | 44 | message = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0) 45 | 46 | message = self.merge(message) 47 | 48 | # message, _ = self.multi_head_attn(query, key, value, need_weights=False) 49 | message = self.norm1(message) 50 | 51 | if self.ffn: 52 | message = self.mlp(torch.cat([source, message], dim=-1)) 53 | message = self.norm2(message) 54 | 55 | return source + message 56 | 57 | class FeatureAttention(torch.nn.Module): 58 | def __init__(self, feature_dim, num_layers, bidir=True, ffn=True, ffn_dim_expansion=1, post_norm=False): 59 | super(FeatureAttention, self).__init__() 60 | 61 | self.bidir = bidir 62 | 63 | self.layers = torch.nn.ModuleList([ 64 | TransformerLayer(feature_dim, ffn=ffn, ffn_dim_expansion=ffn_dim_expansion 65 | ) 66 | for i in range(num_layers)]) 67 | 68 | self.post_norm = post_norm 69 | 70 | if self.post_norm: 71 | self.norm = torch.nn.LayerNorm(feature_dim, eps=1e-06) 72 | 73 | def forward(self, feature0, feature1): 74 | 75 | b, c, h, w = feature0.shape 76 | 77 | feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] 78 | feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] 79 | 80 | if self.bidir: 81 | 82 | concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] 83 | concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] 84 | 85 | for layer in self.layers: 86 | concat0 = layer(concat0, concat1) 87 | concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) 88 | 89 | if self.post_norm: 90 | concat0 = self.norm(concat0) 91 | 92 | feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] 93 | 94 | # reshape back 95 | feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] 96 | feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] 97 | 98 | return feature0, feature1 99 | 100 | else: 101 | for layer in self.layers: 102 | feature0 = layer(feature0, feature1) 103 | 104 | if self.post_norm: 105 | feature0 = self.norm(feature0) 
106 | 107 | feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] 108 | 109 | return feature0 110 | 111 | class FlowAttention(torch.nn.Module): 112 | """ 113 | flow propagation with self-attention on feature 114 | query: feature0, key: feature0, value: flow 115 | """ 116 | 117 | def __init__(self, feature_dim): 118 | super(FlowAttention, self).__init__() 119 | 120 | self.q_proj = torch.nn.Linear(feature_dim, feature_dim) 121 | self.k_proj = torch.nn.Linear(feature_dim, feature_dim) 122 | 123 | def forward(self, feature, flow): 124 | # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] 125 | b, c, h, w = feature.size() 126 | 127 | feature = feature.flatten(-2).permute(0, 2, 1) # [B, H*W, C] 128 | 129 | flow = flow.flatten(-2).permute(0, 2, 1) 130 | 131 | query = self.q_proj(feature) # [B, H*W, C] 132 | key = self.k_proj(feature) # [B, H*W, C] 133 | 134 | flow = F.scaled_dot_product_attention(query, key, flow) 135 | 136 | flow = flow.view(b, h, w, 2).permute(0, 3, 1, 2) 137 | 138 | return flow 139 | -------------------------------------------------------------------------------- /NeuFlow/upsample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | # from spatial_correlation_sampler import SpatialCorrelationSampler 4 | 5 | class UpSample(torch.nn.Module): 6 | def __init__(self, feature_dim, upsample_factor): 7 | super(UpSample, self).__init__() 8 | 9 | self.upsample_factor = upsample_factor 10 | 11 | self.conv1 = torch.nn.Conv2d(2 + feature_dim, 256, 3, 1, 1) 12 | self.conv2 = torch.nn.Conv2d(256, 512, 3, 1, 1) 13 | self.conv3 = torch.nn.Conv2d(512, upsample_factor ** 2 * 9, 1, 1, 0) 14 | self.relu = torch.nn.ReLU(inplace=True) 15 | 16 | def forward(self, feature, flow): 17 | 18 | concat = torch.cat((flow, feature), dim=1) 19 | 20 | mask = self.conv3(self.relu(self.conv2(self.relu(self.conv1(concat))))) 21 | 22 | b, _, h, w = flow.shape 23 | 24 | mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w) # [B, 1, 9, K, K, H, W] 25 | mask = torch.softmax(mask, dim=2) 26 | 27 | up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1) 28 | up_flow = up_flow.view(b, 2, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] 29 | 30 | up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] 31 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] 32 | up_flow = up_flow.reshape(b, 2, self.upsample_factor * h, 33 | self.upsample_factor * w) # [B, 2, K*H, K*W] 34 | 35 | return up_flow -------------------------------------------------------------------------------- /NeuFlow/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def normalize_img(img, mean, std): 5 | return (img / 255. 
- mean) / std
6 | 
7 | def coords_grid(b, h, w):
8 |     ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')  # [H, W]
9 | 
10 |     stacks = [xs, ys]
11 | 
12 |     grid = torch.stack(stacks, dim=0).float()  # [2, H, W]
13 | 
14 |     grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W]
15 | 
16 |     return grid
17 | 
18 | def bilinear_sample(img, sample_coords):
19 | 
20 |     b, _, h, w = sample_coords.shape
21 | 
22 |     # Normalize to [-1, 1]
23 |     x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
24 |     y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
25 | 
26 |     grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]
27 | 
28 |     img = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
29 | 
30 |     return img
31 | 
32 | def flow_warp(feature, flow):
33 | 
34 |     b, c, h, w = feature.size()
35 | 
36 |     grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]
37 | 
38 |     return bilinear_sample(feature, grid)
39 | 
--------------------------------------------------------------------------------
/NeuFlow_v2/neuflow_v2_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neufieldrobotics/NeuFlow/1881c3cd85b87c6155d1b0726b8a345b6a1bdb7e/NeuFlow_v2/neuflow_v2_plot.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeuFlow v1 (Deprecated): Please move to [NeuFlow v2](https://github.com/neufieldrobotics/NeuFlow_v2)
2 | 
3 | Official PyTorch implementation of the paper:
4 | 
5 | [NeuFlow: Real-time, High-accuracy Optical Flow Estimation on Robots Using Edge Devices](NeuFlow.pdf)
6 | 
7 | Authors: [Zhiyong Zhang](https://www.linkedin.com/in/zhiyong-zhang-0772a0159/), [Huaizu Jiang](https://jianghz.me/), [Hanumant Singh](https://scholar.google.com/citations?user=1UEU5PEAAAAJ)
8 | 
9 | ## [NeuFlow v2](https://github.com/neufieldrobotics/NeuFlow_v2) has been published, significantly improving real-world accuracy.
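As a quick orientation before the detailed instructions below, end-to-end inference looks roughly like this minimal sketch (assumptions, not the official entry point: a CUDA device, input height/width divisible by 16, and a checkpoint that stores its weights under a `'model'` key; see `eval.py` and `load_model.py` for the official loading path):

```python
# Minimal inference sketch; see eval.py / load_model.py for the real path.
import torch
from NeuFlow.neuflow import NeuFlow

model = NeuFlow().cuda().eval()
ckpt = torch.load('neuflow_things.pth', map_location='cuda')
model.load_state_dict(ckpt['model'])  # the 'model' key is an assumption

# Two RGB frames as float tensors in [0, 255]; H and W divisible by 16.
img0 = torch.rand(1, 3, 384, 768).cuda() * 255
img1 = torch.rand(1, 3, 384, 768).cuda() * 255

model.init_bhw(1, 384, 768)  # precompute position grids for this input size
with torch.no_grad():
    flow = model(img0, img1)[-1]  # [1, 2, 384, 768], forward flow in pixels
```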
10 | 
11 | 
12 | 
13 | ## Installation (PyTorch >= 2.0 is required)
14 | 
15 | ```
16 | conda create --name neuflow python==3.8
17 | conda activate neuflow
18 | conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
19 | pip install numpy opencv-python
20 | ```
21 | 
22 | [Pytorch-Correlation-extension 0.4.0](https://github.com/ClementPinard/Pytorch-Correlation-extension/tree/0.4.0) (we recommend building it from source):
23 | 
24 | ```
25 | git clone -b 0.4.0 https://github.com/ClementPinard/Pytorch-Correlation-extension.git
26 | cd Pytorch-Correlation-extension/
27 | python setup.py install
28 | ```
29 | 
30 | ## Datasets
31 | 
32 | The datasets used to train and evaluate NeuFlow are as follows:
33 | 
34 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
35 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
36 | * [Sintel](http://sintel.is.tue.mpg.de/)
37 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
38 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/)
39 | 
40 | By default, the dataloader assumes the datasets are located in the folder `datasets` and are organized as follows:
41 | 
42 | ```
43 | datasets
44 | ├── FlyingChairs_release
45 | │   └── data
46 | ├── FlyingThings3D
47 | │   ├── frames_cleanpass
48 | │   ├── frames_finalpass
49 | │   └── optical_flow
50 | ├── HD1K
51 | │   ├── hd1k_challenge
52 | │   ├── hd1k_flow_gt
53 | │   ├── hd1k_flow_uncertainty
54 | │   └── hd1k_input
55 | ├── KITTI
56 | │   ├── testing
57 | │   └── training
58 | ├── Sintel
59 | │   ├── test
60 | │   └── training
61 | ```
62 | 
63 | Symlink your dataset root to `datasets`:
64 | 
65 | ```shell
66 | ln -s $YOUR_DATASET_ROOT datasets
67 | ```
68 | 
69 | ## Training
70 | 
71 | We trained the model for approximately one week on a single RTX 4090 GPU with an i9-13900K CPU to reach the accuracy reported in the paper. (The CPU matters here: image loading can become the bottleneck.)
72 | 
73 | ```
74 | python train.py \
75 | --checkpoint_dir $YOUR_CHECKPOINT_DIR \
76 | --stage things \
77 | --val_dataset things sintel \
78 | --batch_size 64 \
79 | --num_workers 8 \
80 | --lr 2e-4 \
81 | --weight_decay 1e-4 \
82 | --val_freq 2000 \
83 | --max_flow 400
84 | ```
85 | 
86 | ## Optional
87 | 
88 | Write occlusion files for FlyingThings3D so that image pairs with minimal overlap can be excluded during training (the dataloader zeroes the valid mask for any pair whose precomputed map covers less than 30% of the pixels; see `load_occlusion` in `data_utils/datasets.py`):
89 | ``` 90 | python write_occ.py 91 | ``` 92 | 93 | ## Evaluation 94 | 95 | ``` 96 | python eval.py \ 97 | --resume neuflow_things.pth 98 | ``` 99 | -------------------------------------------------------------------------------- /data_utils/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | 7 | import os 8 | import random 9 | from glob import glob 10 | import os.path as osp 11 | 12 | from data_utils import frame_utils 13 | from data_utils.transforms import FlowAugmentor, SparseFlowAugmentor 14 | 15 | 16 | class FlowDataset(data.Dataset): 17 | def __init__(self, aug_params=None, sparse=False, virtual=False, 18 | load_occlusion=False, 19 | ): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | self.virtual = virtual 23 | 24 | if aug_params is not None: 25 | if sparse: 26 | self.augmentor = SparseFlowAugmentor(**aug_params) 27 | else: 28 | self.augmentor = FlowAugmentor(**aug_params) 29 | 30 | self.is_test = False 31 | self.init_seed = False 32 | self.flow_list = [] 33 | self.image_list = [] 34 | self.extra_info = [] 35 | 36 | self.load_occlusion = load_occlusion 37 | self.occ_list = [] 38 | 39 | def __getitem__(self, index): 40 | 41 | if self.is_test: 42 | img1 = frame_utils.read_gen(self.image_list[index][0]) 43 | img2 = frame_utils.read_gen(self.image_list[index][1]) 44 | 45 | img1 = np.array(img1).astype(np.uint8)[..., :3] 46 | img2 = np.array(img2).astype(np.uint8)[..., :3] 47 | 48 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 49 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 50 | 51 | return img1, img2, self.extra_info[index] 52 | 53 | if not self.init_seed: 54 | worker_info = torch.utils.data.get_worker_info() 55 | if worker_info is not None: 56 | torch.manual_seed(worker_info.id) 57 | np.random.seed(worker_info.id) 58 | random.seed(worker_info.id) 59 | self.init_seed = True 60 | 61 | index = index % len(self.image_list) 62 | valid = None 63 | 64 | if self.sparse: 65 | if self.virtual: 66 | flow, valid = frame_utils.read_vkitti_png_flow(self.flow_list[index]) # [H, W, 2], [H, W] 67 | else: 68 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) # [H, W, 2], [H, W] 69 | else: 70 | flow = frame_utils.read_gen(self.flow_list[index]) 71 | 72 | if self.load_occlusion: 73 | occlusion = frame_utils.read_gen(self.occ_list[index]) # [H, W], 0 or 255 (occluded) 74 | 75 | img1 = frame_utils.read_gen(self.image_list[index][0]) 76 | img2 = frame_utils.read_gen(self.image_list[index][1]) 77 | 78 | flow = np.array(flow).astype(np.float32) 79 | img1 = np.array(img1).astype(np.uint8) 80 | img2 = np.array(img2).astype(np.uint8) 81 | 82 | if self.load_occlusion: 83 | occlusion = np.array(occlusion).astype(np.float32) 84 | 85 | # grayscale images 86 | if len(img1.shape) == 2: 87 | img1 = np.tile(img1[..., None], (1, 1, 3)) 88 | img2 = np.tile(img2[..., None], (1, 1, 3)) 89 | else: 90 | img1 = img1[..., :3] 91 | img2 = img2[..., :3] 92 | 93 | if self.augmentor is not None: 94 | if self.sparse: 95 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 96 | else: 97 | if self.load_occlusion: 98 | img1, img2, flow, occlusion = self.augmentor(img1, img2, flow, occlusion=occlusion) 99 | else: 100 | img1, img2, flow = self.augmentor(img1, img2, flow) 101 | 102 | if self.load_occlusion: 103 | 104 | if np.count_nonzero(occlusion) / 
(occlusion.shape[0]*occlusion.shape[1]) < 0.3: 105 | valid = np.zeros(flow.shape[:-1]) 106 | 107 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 108 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 109 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 110 | 111 | # if self.load_occlusion: 112 | # occlusion = torch.from_numpy(occlusion) # [H, W] 113 | 114 | if valid is not None: 115 | valid = torch.from_numpy(valid) 116 | else: 117 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 118 | 119 | # # mask out occluded pixels 120 | # if self.load_occlusion: 121 | # # non-occlusion: 0, occlusion: 255 122 | # noc_valid = 1 - occlusion / 255. # 0 or 1 123 | 124 | # return img1, img2, flow, valid.float(), noc_valid.float() 125 | 126 | return img1, img2, flow, valid.float() 127 | 128 | def __rmul__(self, v): 129 | self.flow_list = v * self.flow_list 130 | self.image_list = v * self.image_list 131 | self.occ_list = v * self.occ_list 132 | 133 | return self 134 | 135 | def __len__(self): 136 | return len(self.image_list) 137 | 138 | 139 | class MpiSintel(FlowDataset): 140 | def __init__(self, aug_params=None, split='training', 141 | root='datasets/Sintel', 142 | dstype='clean', 143 | load_occlusion=False, 144 | ): 145 | super(MpiSintel, self).__init__(aug_params, 146 | load_occlusion=load_occlusion, 147 | ) 148 | 149 | flow_root = osp.join(root, split, 'flow') 150 | image_root = osp.join(root, split, dstype) 151 | 152 | if load_occlusion: 153 | occlusion_root = osp.join(root, split, 'occlusions') 154 | 155 | if split == 'test': 156 | self.is_test = True 157 | 158 | for scene in os.listdir(image_root): 159 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 160 | for i in range(len(image_list) - 1): 161 | self.image_list += [[image_list[i], image_list[i + 1]]] 162 | self.extra_info += [(scene, i)] # scene and frame_id 163 | 164 | if split != 'test': 165 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 166 | 167 | if load_occlusion: 168 | self.occ_list += sorted(glob(osp.join(occlusion_root, scene, '*.png'))) 169 | 170 | 171 | class FlyingChairs(FlowDataset): 172 | def __init__(self, aug_params=None, split='train', 173 | root='datasets/FlyingChairs_release/data', 174 | ): 175 | super(FlyingChairs, self).__init__(aug_params) 176 | 177 | images = sorted(glob(osp.join(root, '*.ppm'))) 178 | flows = sorted(glob(osp.join(root, '*.flo'))) 179 | assert (len(images) // 2 == len(flows)) 180 | 181 | split_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'chairs_split.txt') 182 | split_list = np.loadtxt(split_file, dtype=np.int32) 183 | for i in range(len(flows)): 184 | xid = split_list[i] 185 | if (split == 'training' and xid == 1) or (split == 'validation' and xid == 2): 186 | self.flow_list += [flows[i]] 187 | self.image_list += [[images[2 * i], images[2 * i + 1]]] 188 | 189 | 190 | class FlyingThings3D(FlowDataset): 191 | def __init__(self, aug_params=None, 192 | root='datasets/FlyingThings3D', 193 | dstype='frames_cleanpass', 194 | test_set=False, 195 | validate_subset=True, 196 | load_occlusion=False, 197 | only_left=True, 198 | ): 199 | super(FlyingThings3D, self).__init__(aug_params, load_occlusion=load_occlusion) 200 | 201 | img_dir = root 202 | flow_dir = root 203 | 204 | if only_left: 205 | cam_list = ['left'] 206 | else: 207 | cam_list = ['left', 'right'] 208 | 209 | for cam in cam_list: 210 | for direction in ['into_future', 'into_past']: 211 | if test_set: 212 | image_dirs = sorted(glob(osp.join(img_dir, dstype, 
'TEST/*/*'))) 213 | else: 214 | image_dirs = sorted(glob(osp.join(img_dir, dstype, 'TRAIN/*/*'))) 215 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 216 | 217 | if test_set: 218 | flow_dirs = sorted(glob(osp.join(flow_dir, 'optical_flow/TEST/*/*'))) 219 | else: 220 | flow_dirs = sorted(glob(osp.join(flow_dir, 'optical_flow/TRAIN/*/*'))) 221 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 222 | 223 | for idir, fdir in zip(image_dirs, flow_dirs): 224 | images = sorted(glob(osp.join(idir, '*.png'))) 225 | flows = sorted(glob(osp.join(fdir, '*.pfm'))) 226 | occs = sorted(glob(osp.join(fdir, '*.png'))) 227 | for i in range(len(flows) - 1): 228 | if direction == 'into_future': 229 | self.image_list += [[images[i], images[i + 1]]] 230 | self.flow_list += [flows[i]] 231 | if load_occlusion: 232 | self.occ_list += [occs[i]] 233 | elif direction == 'into_past': 234 | self.image_list += [[images[i + 1], images[i]]] 235 | self.flow_list += [flows[i + 1]] 236 | if load_occlusion: 237 | self.occ_list += [occs[i]] 238 | 239 | if test_set and validate_subset: 240 | num_val_samples = 1024 241 | all_test_samples = len(self.image_list) # 7866 242 | 243 | stride = all_test_samples // num_val_samples 244 | remove = all_test_samples % num_val_samples 245 | 246 | self.image_list = self.image_list[:-remove][::stride] 247 | self.flow_list = self.flow_list[:-remove][::stride] 248 | 249 | 250 | class KITTI(FlowDataset): 251 | def __init__(self, aug_params=None, split='training', 252 | root='datasets/KITTI', 253 | ): 254 | super(KITTI, self).__init__(aug_params, sparse=True, 255 | ) 256 | if split == 'testing': 257 | self.is_test = True 258 | 259 | root = osp.join(root, split) 260 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 261 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 262 | 263 | for img1, img2 in zip(images1, images2): 264 | frame_id = img1.split('/')[-1] 265 | self.extra_info += [[frame_id]] 266 | self.image_list += [[img1, img2]] 267 | 268 | if split == 'training': 269 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 270 | 271 | 272 | class HD1K(FlowDataset): 273 | def __init__(self, aug_params=None, root='datasets/HD1K'): 274 | super(HD1K, self).__init__(aug_params, sparse=True) 275 | 276 | seq_ix = 0 277 | while 1: 278 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 279 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 280 | 281 | if len(flows) == 0: 282 | break 283 | 284 | for i in range(len(flows) - 1): 285 | self.flow_list += [flows[i]] 286 | self.image_list += [[images[i], images[i + 1]]] 287 | 288 | seq_ix += 1 289 | 290 | 291 | class NeuSim(FlowDataset): 292 | def __init__(self, aug_params=None, 293 | root='datasets/NeuSim' 294 | ): 295 | super(NeuSim, self).__init__(aug_params) 296 | 297 | image_dirs = sorted(glob(osp.join(root, '*/image'))) 298 | 299 | fw_flow_dirs = sorted(glob(osp.join(root, '*/forward_flow'))) 300 | bw_flow_dirs = sorted(glob(osp.join(root, '*/backward_flow'))) 301 | 302 | for image_dir, fw_flow_dir, bw_flow_dir in zip(image_dirs, fw_flow_dirs, bw_flow_dirs): 303 | images = sorted(glob(osp.join(image_dir, '*.png'))) 304 | fw_flows = sorted(glob(osp.join(fw_flow_dir, '*.npy'))) 305 | bw_flows = sorted(glob(osp.join(bw_flow_dir, '*.npy'))) 306 | for i in range(len(fw_flows) - 1): 307 | self.image_list += [[images[i], images[i + 1]]] 308 | self.flow_list += [fw_flows[i]] 309 | self.image_list += [[images[i + 1], 
images[i]]]
310 |                 self.flow_list += [bw_flows[i]]
311 | 
312 | 
313 | def build_train_dataset(stage):
314 |     if stage == 'chairs':
315 |         aug_params = {'crop_size': (384, 512), 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
316 | 
317 |         train_dataset = FlyingChairs(aug_params, split='training')
318 | 
319 |     elif stage == 'things':
320 |         aug_params = {'crop_size': (384, 768), 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
321 | 
322 |         clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass', load_occlusion=True)
323 |         final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass', load_occlusion=True)
324 |         train_dataset = clean_dataset + final_dataset
325 | 
326 |     elif stage == 'sintel':
327 |         crop_size = (320, 896)
328 |         aug_params = {'crop_size': crop_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
329 | 
330 |         things_clean = FlyingThings3D(aug_params, dstype='frames_cleanpass', load_occlusion=True)
331 |         things_final = FlyingThings3D(aug_params, dstype='frames_finalpass', load_occlusion=True)
332 | 
333 |         sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
334 |         sintel_final = MpiSintel(aug_params, split='training', dstype='final')
335 | 
336 |         aug_params = {'crop_size': crop_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}
337 | 
338 |         kitti = KITTI(aug_params=aug_params, split='training')
339 | 
340 |         aug_params = {'crop_size': crop_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}
341 | 
342 |         hd1k = HD1K(aug_params=aug_params)
343 | 
344 |         train_dataset = 40 * sintel_clean + 40 * sintel_final + 200 * kitti + 10 * hd1k + things_clean + things_final
345 | 
346 |     elif stage == 'kitti':
347 |         aug_params = {'crop_size': (320, 1152), 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
348 | 
349 |         train_dataset = KITTI(aug_params, split='training')
350 | 
351 |     elif stage == 'neusim':
352 |         crop_size = (320, 896)
353 |         aug_params = {'crop_size': crop_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
354 |         things_clean = FlyingThings3D(aug_params, dstype='frames_cleanpass', load_occlusion=True)
355 |         things_final = FlyingThings3D(aug_params, dstype='frames_finalpass', load_occlusion=True)
356 | 
357 |         aug_params = {'crop_size': crop_size, 'min_scale': -1, 'max_scale': 0, 'do_flip': False}
358 |         neu_dataset = NeuSim(aug_params)
359 | 
360 |         train_dataset = things_clean + things_final + 2 * neu_dataset
361 | 
362 |     return train_dataset
363 | 
--------------------------------------------------------------------------------
/data_utils/evaluate.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import os
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | 
8 | from data_utils import datasets
9 | 
10 | from data_utils import frame_utils
11 | from data_utils import flow_viz
12 | 
13 | 
14 | @torch.no_grad()
15 | def validate_chairs(model):
16 |     """ Perform evaluation on the FlyingChairs (test) split """
17 |     model.eval()
18 |     epe_list = []
19 |     results = {}
20 | 
21 |     val_dataset = datasets.FlyingChairs(split='validation')
22 | 
23 |     print('Number of validation image pairs: %d' % len(val_dataset))
24 | 
25 |     for val_id in range(len(val_dataset)):
26 |         image1, image2, flow_gt, _ = val_dataset[val_id]
27 | 
28 |         image1 = image1[None].cuda()
29 |         image2 = image2[None].cuda()
30 | 
31 |         model.init_bhw(image1.shape[0], image1.shape[-2], image1.shape[-1])
32 | 
33 |         results_dict = model(image1, image2)
34 | 
35 |         flow_pr = results_dict[-1]  # [B, 2, H, W]
36 | 
37 |         assert flow_pr.size()[-2:] == flow_gt.size()[-2:]
38 | 
39 |         epe = torch.sum((flow_pr[0].cpu() - flow_gt) ** 2, dim=0).sqrt()
40 |         epe_list.append(epe.view(-1).numpy())
41 | 
42 |     epe_all = np.concatenate(epe_list)
43 |     print("Validation Chairs EPE: %.3f" % (np.mean(epe_all)))
44 |     results['chairs_epe'] = np.mean(epe_all)
45 | 
46 |     return results
47 | 
48 | 
49 | @torch.no_grad()
50 | def validate_things(model,
51 |                     dstype,
52 |                     validate_subset,
53 |                     padding_factor=16,
54 |                     max_val_flow=400
55 |                     ):
56 |     """ Perform validation using the Things (test) split """
57 |     model.eval()
58 |     results = {}
59 | 
60 |     val_dataset = datasets.FlyingThings3D(dstype=dstype, test_set=True, validate_subset=validate_subset,
61 |                                           )
62 |     print('Number of validation image pairs: %d' % len(val_dataset))
63 |     epe_list = []
64 | 
65 |     for val_id in range(len(val_dataset)):
66 |         image1, image2, flow_gt, valid_gt = val_dataset[val_id]
67 |         image1 = image1[None].cuda()
68 |         image2 = image2[None].cuda()
69 | 
70 |         padder = frame_utils.InputPadder(image1.shape, padding_factor=padding_factor)
71 |         image1, image2 = padder.pad(image1, image2)
72 | 
73 |         model.init_bhw(image1.shape[0], image1.shape[-2], image1.shape[-1])
74 | 
75 |         results_dict = model(image1, image2)
76 |         flow_pr = results_dict[-1]
77 | 
78 |         flow = padder.unpad(flow_pr[0]).cpu()
79 | 
80 |         # Evaluation on flow <= max_val_flow
81 |         flow_gt_speed = torch.sum(flow_gt ** 2, dim=0).sqrt()
82 |         valid_gt = valid_gt * (flow_gt_speed < max_val_flow)
83 |         valid_gt = valid_gt.contiguous()
84 | 
85 |         epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt()
86 |         val = valid_gt >= 0.5
87 |         epe_list.append(epe[val].cpu().numpy())
88 | 
89 |     epe_all = np.concatenate(epe_list)
90 | 
91 |     epe = np.mean(epe_all)
92 | 
93 |     print("Validation Things test set (%s) EPE: %.3f" % (dstype, epe))
94 |     results[dstype + '_epe'] = epe
95 | 
96 |     return results
97 | 
98 | 
99 | @torch.no_grad()
100 | def validate_sintel(model,
101 |                     dstype,
102 |                     padding_factor=16
103 |                     ):
104 |     """ Perform validation using the Sintel (train) split """
105 |     model.eval()
106 |     results = {}
107 | 
108 |     val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
109 | 
110 |     print('Number of validation image pairs: %d' % len(val_dataset))
111 |     epe_list = []
112 | 
113 |     for val_id in range(len(val_dataset)):
114 | 
115 |         image1, image2, flow_gt, _ = val_dataset[val_id]
116 | 
117 |         image1 = image1[None].cuda()
118 |         image2 = image2[None].cuda()
119 | 
120 |         padder = frame_utils.InputPadder(image1.shape, padding_factor=padding_factor)
121 |         image1, image2 = padder.pad(image1, image2)
122 | 
123 |         model.init_bhw(image1.shape[0], image1.shape[-2], image1.shape[-1])
124 | 
125 |         results_dict = model(image1, image2)
126 | 
127 |         # useful when using parallel branches
128 |         flow_pr = results_dict[-1]
129 | 
130 |         flow = padder.unpad(flow_pr[0]).cpu()
131 |         # flow = flow_pr[0].cpu()
132 | 
133 |         epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt()
134 |         epe_list.append(epe.view(-1).numpy())
135 | 
136 |     epe_all = np.concatenate(epe_list)
137 |     epe = np.mean(epe_all)
138 | 
139 |     print("Validation Sintel (%s) EPE: %.3f" % (dstype, epe))
140 | 
141 |     dstype = 'sintel_' + dstype
142 | 
143 |     results[dstype + '_epe'] = epe
144 | 
145 |     return results
146 | 
147 | 
148 | @torch.no_grad()
149 | def validate_kitti(model,
150 |                    padding_factor=16
151 |                    ):
152 |     """ Perform validation using the KITTI-2015 (train) split """
153 |     model.eval()
154 | 
155 |     val_dataset = datasets.KITTI(split='training')
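    # KITTI ground truth is sparse: valid_gt marks only the annotated pixels.
    # The F1 metric computed below counts a pixel as an outlier when its EPE
    # exceeds 3 px and is more than 5% of the ground-truth flow magnitude.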
156 |     print('Number of validation image pairs: %d' % len(val_dataset))
157 | 
158 |     out_list, epe_list = [], []
159 |     results = {}
160 | 
161 |     for val_id in range(len(val_dataset)):
162 |         image1, image2, flow_gt, valid_gt = val_dataset[val_id]
163 |         image1 = image1[None].cuda()
164 |         image2 = image2[None].cuda()
165 | 
166 |         padder = frame_utils.InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor)
167 |         image1, image2 = padder.pad(image1, image2)
168 | 
169 |         model.init_bhw(image1.shape[0], image1.shape[-2], image1.shape[-1])
170 | 
171 |         results_dict = model(image1, image2)
172 | 
173 |         # useful when using parallel branches
174 |         flow_pr = results_dict[-1]
175 | 
176 |         flow = padder.unpad(flow_pr[0]).cpu()
177 | 
178 |         epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt()
179 |         mag = torch.sum(flow_gt ** 2, dim=0).sqrt()
180 | 
181 |         epe = epe.view(-1)
182 |         mag = mag.view(-1)
183 |         val = valid_gt.view(-1) >= 0.5
184 | 
185 |         out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
186 | 
187 |         epe_list.append(epe[val].cpu().numpy())
188 | 
189 |         out_list.append(out[val].cpu().numpy())
190 | 
191 |     epe_list = np.concatenate(epe_list)
192 |     out_list = np.concatenate(out_list)
193 | 
194 |     epe = np.mean(epe_list)
195 |     f1 = 100 * np.mean(out_list)
196 | 
197 |     print("Validation KITTI EPE: %.3f" % (epe))
198 |     print("Validation KITTI F1: %.3f" % (f1))
199 |     results['kitti_epe'] = epe
200 |     results['kitti_f1'] = f1
201 | 
202 |     return results
203 | 
204 | @torch.no_grad()
205 | def create_kitti_submission(model, output_path='datasets/kitti_submission/flow', padding_factor=16, save_vis_flow=False):
206 |     """ Create submission for the KITTI leaderboard """
207 |     model.eval()
208 |     test_dataset = datasets.KITTI(split='testing', aug_params=None)
209 | 
210 |     if not os.path.exists(output_path):
211 |         os.makedirs(output_path)
212 | 
213 |     for test_id in range(len(test_dataset)):
214 |         image1, image2, (frame_id,) = test_dataset[test_id]
215 |         padder = frame_utils.InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor)
216 |         image1_pad, image2_pad = padder.pad(image1[None].cuda(), image2[None].cuda())
217 | 
218 |         model.init_bhw(image1_pad.shape[0], image1_pad.shape[-2], image1_pad.shape[-1])
219 | 
220 |         results_dict = model(image1_pad, image2_pad)
221 | 
222 |         flow_pr = results_dict[-1]
223 | 
224 |         flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
225 | 
226 |         output_filename = os.path.join(output_path, frame_id)
227 | 
228 |         if save_vis_flow:
229 |             vis_flow_file = output_filename
230 |             flow_viz.save_vis_flow_tofile(flow, vis_flow_file)
231 |         else:
232 |             frame_utils.writeFlowKITTI(output_filename, flow)
--------------------------------------------------------------------------------
/data_utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | UNKNOWN_FLOW_THRESH = 1e7
5 | SMALLFLOW = 0.0
6 | LARGEFLOW = 1e8
7 | 
8 | 
9 | def make_color_wheel():
10 |     """
11 |     Generate color wheel according to the Middlebury color code
12 |     :return: Color wheel
13 |     """
14 |     RY = 15
15 |     YG = 6
16 |     GC = 4
17 |     CB = 11
18 |     BM = 13
19 |     MR = 6
20 | 
21 |     ncols = RY + YG + GC + CB + BM + MR
22 | 
23 |     colorwheel = np.zeros([ncols, 3])
24 | 
25 |     col = 0
26 | 
27 |     # RY
28 |     colorwheel[0:RY, 0] = 255
29 |     colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
30 |     col += RY
31 | 
32 |     # YG
33 |     colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
34 |     colorwheel[col:col + YG, 1] = 255
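    # `col` tracks the write position within the wheel; each hue segment
    # advances it by its own length before the next range is filled in.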
35 |     col += YG
36 | 
37 |     # GC
38 |     colorwheel[col:col + GC, 1] = 255
39 |     colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
40 |     col += GC
41 | 
42 |     # CB
43 |     colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
44 |     colorwheel[col:col + CB, 2] = 255
45 |     col += CB
46 | 
47 |     # BM
48 |     colorwheel[col:col + BM, 2] = 255
49 |     colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
50 |     col += BM
51 | 
52 |     # MR
53 |     colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
54 |     colorwheel[col:col + MR, 0] = 255
55 | 
56 |     return colorwheel
57 | 
58 | 
59 | def compute_color(u, v):
60 |     """
61 |     compute optical flow color map
62 |     :param u: optical flow horizontal map
63 |     :param v: optical flow vertical map
64 |     :return: optical flow in color code
65 |     """
66 |     [h, w] = u.shape
67 |     img = np.zeros([h, w, 3])
68 |     nanIdx = np.isnan(u) | np.isnan(v)
69 |     u[nanIdx] = 0
70 |     v[nanIdx] = 0
71 | 
72 |     colorwheel = make_color_wheel()
73 |     ncols = np.size(colorwheel, 0)
74 | 
75 |     rad = np.sqrt(u ** 2 + v ** 2)
76 | 
77 |     a = np.arctan2(-v, -u) / np.pi
78 | 
79 |     fk = (a + 1) / 2 * (ncols - 1) + 1
80 | 
81 |     k0 = np.floor(fk).astype(int)
82 | 
83 |     k1 = k0 + 1
84 |     k1[k1 == ncols + 1] = 1
85 |     f = fk - k0
86 | 
87 |     for i in range(0, np.size(colorwheel, 1)):
88 |         tmp = colorwheel[:, i]
89 |         col0 = tmp[k0 - 1] / 255
90 |         col1 = tmp[k1 - 1] / 255
91 |         col = (1 - f) * col0 + f * col1
92 | 
93 |         idx = rad <= 1
94 |         col[idx] = 1 - rad[idx] * (1 - col[idx])
95 |         notidx = np.logical_not(idx)
96 | 
97 |         col[notidx] *= 0.75
98 |         img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
99 | 
100 |     return img
101 | 
102 | 
103 | # from https://github.com/gengshan-y/VCN
104 | def flow_to_image(flow):
105 |     """
106 |     Convert flow into middlebury color code image
107 |     :param flow: optical flow map
108 |     :return: optical flow image in middlebury color
109 |     """
110 |     u = flow[:, :, 0]
111 |     v = flow[:, :, 1]
112 | 
113 |     maxu = -999.
114 |     maxv = -999.
115 |     minu = 999.
116 |     minv = 999.
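    # Zero out pixels flagged as unknown (magnitude above UNKNOWN_FLOW_THRESH)
    # so they do not distort the max-magnitude normalization below.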
117 | 118 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 119 | u[idxUnknow] = 0 120 | v[idxUnknow] = 0 121 | 122 | maxu = max(maxu, np.max(u)) 123 | minu = min(minu, np.min(u)) 124 | 125 | maxv = max(maxv, np.max(v)) 126 | minv = min(minv, np.min(v)) 127 | 128 | rad = np.sqrt(u ** 2 + v ** 2) 129 | maxrad = max(-1, np.max(rad)) 130 | 131 | u = u / (maxrad + np.finfo(float).eps) 132 | v = v / (maxrad + np.finfo(float).eps) 133 | 134 | img = compute_color(u, v) 135 | 136 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 137 | img[idx] = 0 138 | 139 | return np.uint8(img) 140 | 141 | 142 | def save_vis_flow_tofile(flow, output_path): 143 | vis_flow = flow_to_image(flow) 144 | from PIL import Image 145 | img = Image.fromarray(vis_flow) 146 | img.save(output_path) -------------------------------------------------------------------------------- /data_utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import splitext 4 | import re 5 | import cv2 6 | 7 | import torch.nn.functional as F 8 | 9 | 10 | class InputPadder: 11 | """ Pads images so their dimensions are divisible by padding_factor (default 8) """ 12 | 13 | def __init__(self, dims, mode='sintel', padding_factor=8): 14 | self.ht, self.wd = dims[-2:] 15 | pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor 16 | pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor 17 | if mode == 'sintel': 18 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] 19 | else: 20 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 21 | 22 | def pad(self, *inputs): 23 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 24 | 25 | def unpad(self, x): 26 | ht, wd = x.shape[-2:] 27 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 28 | return x[..., c[0]:c[1], c[2]:c[3]] 29 | 30 | 31 | def readFlow(fn): 32 | """ Read .flo file in Middlebury format""" 33 | # Code adapted from: 34 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 35 | 36 | # WARNING: this will work on little-endian architectures (e.g. Intel x86) only! 37 | # print 'fn = %s'%(fn) 38 | with open(fn, 'rb') as f: 39 | magic = np.fromfile(f, np.float32, count=1) 40 | if 202021.25 != magic: 41 | print('Magic number incorrect.
Invalid .flo file') 42 | return None 43 | else: 44 | w = np.fromfile(f, np.int32, count=1) 45 | h = np.fromfile(f, np.int32, count=1) 46 | # print 'Reading %d x %d flo file\n' % (w, h) 47 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 48 | # Reshape data into 3D array (columns, rows, bands) 49 | # The reshape here is for visualization, the original code is (w,h,2) 50 | return np.resize(data, (int(h), int(w), 2)) 51 | 52 | 53 | def readPFM(file): 54 | file = open(file, 'rb') 55 | 56 | color = None 57 | width = None 58 | height = None 59 | scale = None 60 | endian = None 61 | 62 | header = file.readline().rstrip() 63 | if header == b'PF': 64 | color = True 65 | elif header == b'Pf': 66 | color = False 67 | else: 68 | raise Exception('Not a PFM file.') 69 | 70 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 71 | if dim_match: 72 | width, height = map(int, dim_match.groups()) 73 | else: 74 | raise Exception('Malformed PFM header.') 75 | 76 | scale = float(file.readline().rstrip()) 77 | if scale < 0: # little-endian 78 | endian = '<' 79 | scale = -scale 80 | else: 81 | endian = '>' # big-endian 82 | 83 | data = np.fromfile(file, endian + 'f') 84 | shape = (height, width, 3) if color else (height, width) 85 | 86 | data = np.reshape(data, shape) 87 | data = np.flipud(data) 88 | return data 89 | 90 | 91 | def readFlowKITTI(filename): 92 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 93 | flow = flow[:, :, ::-1].astype(np.float32) 94 | flow, valid = flow[:, :, :2], flow[:, :, 2] 95 | flow = (flow - 2 ** 15) / 64.0 96 | return flow, valid 97 | 98 | 99 | def read_vkitti_png_flow(flow_fn): 100 | # Convert from .png to (h, w, 2) (flow_x, flow_y) float32 array 101 | # read png to bgr in 16 bit unsigned short 102 | bgr = cv2.imread(flow_fn, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 103 | h, w, _c = bgr.shape 104 | assert bgr.dtype == np.uint16 and _c == 3 105 | # b == invalid flow flag == 0 for sky or other invalid flow 106 | # invalid = bgr[..., 0] == 0 107 | # g,r == flow_y,x normalized by height,width and scaled to [0; 2**16 - 1] 108 | out_flow = 2.0 / (2**16 - 1.0) * bgr[..., 2:0:-1].astype('float32') - 1 109 | out_flow[..., 0] *= w - 1 110 | out_flow[..., 1] *= h - 1 111 | # out_flow[invalid] = 0 # or another value (e.g., np.nan) 112 | return out_flow, bgr[..., 0] != 0 113 | 114 | 115 | def writeFlowKITTI(filename, uv): 116 | uv = 64.0 * uv + 2 ** 15 117 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 118 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 119 | cv2.imwrite(filename, uv[..., ::-1]) 120 | 121 | 122 | def read_gen(file_name, pil=False): 123 | ext = splitext(file_name)[-1] 124 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 125 | return Image.open(file_name) 126 | elif ext == '.bin' or ext == '.raw': 127 | return np.load(file_name) 128 | elif ext == '.flo': 129 | return readFlow(file_name).astype(np.float32) 130 | elif ext == '.pfm': 131 | flow = readPFM(file_name).astype(np.float32) 132 | if len(flow.shape) == 2: 133 | return flow 134 | else: 135 | return flow[:, :, :-1] 136 | elif ext == '.npy': 137 | return np.load(file_name) 138 | return [] 139 | -------------------------------------------------------------------------------- /data_utils/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from PIL import Image 4 | from torchvision.transforms import ColorJitter 5 | 6 | 7 | class FlowAugmentor: 8 | def
__init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True, 9 | no_eraser_aug=True, 10 | ): 11 | # spatial augmentation params 12 | self.crop_size = crop_size 13 | self.min_scale = min_scale 14 | self.max_scale = max_scale 15 | self.spatial_aug_prob = 0.8 16 | self.stretch_prob = 0.8 17 | self.max_stretch = 0.2 18 | 19 | # flip augmentation params 20 | self.do_flip = do_flip 21 | self.h_flip_prob = 0.5 22 | self.v_flip_prob = 0.1 23 | 24 | # photometric augmentation params 25 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14) 26 | 27 | self.asymmetric_color_aug_prob = 0.2 28 | 29 | if no_eraser_aug: 30 | # we disable eraser aug since no obvious improvement is observed in our experiments 31 | self.eraser_aug_prob = -1 32 | else: 33 | self.eraser_aug_prob = 0.5 34 | 35 | def color_transform(self, img1, img2): 36 | """ Photometric augmentation """ 37 | 38 | # asymmetric 39 | if np.random.rand() < self.asymmetric_color_aug_prob: 40 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 41 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 42 | 43 | # symmetric 44 | else: 45 | image_stack = np.concatenate([img1, img2], axis=0) 46 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 47 | img1, img2 = np.split(image_stack, 2, axis=0) 48 | 49 | return img1, img2 50 | 51 | def eraser_transform(self, img1, img2, bounds=(50, 100)): 52 | """ Occlusion augmentation """ 53 | 54 | ht, wd = img1.shape[:2] 55 | if np.random.rand() < self.eraser_aug_prob: 56 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 57 | for _ in range(np.random.randint(1, 3)): 58 | x0 = np.random.randint(0, wd) 59 | y0 = np.random.randint(0, ht) 60 | dx = np.random.randint(bounds[0], bounds[1]) 61 | dy = np.random.randint(bounds[0], bounds[1]) 62 | img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color 63 | 64 | return img1, img2 65 | 66 | def spatial_transform(self, img1, img2, flow, occlusion=None): 67 | # randomly sample scale 68 | ht, wd = img1.shape[:2] 69 | 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if occlusion is not None: 92 | occlusion = cv2.resize(occlusion, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 93 | 94 | if self.do_flip: 95 | if np.random.rand() < self.h_flip_prob: # h-flip 96 | img1 = img1[:, ::-1] 97 | img2 = img2[:, ::-1] 98 | flow = flow[:, ::-1] * [-1.0, 1.0] 99 | 100 | if occlusion is not None: 101 | occlusion = occlusion[:, ::-1] 102 | 103 | if np.random.rand() < self.v_flip_prob: # v-flip 104 | img1 = img1[::-1, :] 105 | img2 = img2[::-1, :] 106 | flow = flow[::-1, :] * [1.0, -1.0] 107 | 108 | if
occlusion is not None: 109 | occlusion = occlusion[::-1, :] 110 | 111 | # random crop; fall back to offset 0 when the image is not larger than the crop size 112 | if img1.shape[0] - self.crop_size[0] > 0: 113 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 114 | else: 115 | y0 = 0 116 | if img1.shape[1] - self.crop_size[1] > 0: 117 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 118 | else: 119 | x0 = 0 120 | 121 | img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 122 | img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 123 | flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 124 | 125 | if occlusion is not None: 126 | occlusion = occlusion[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 127 | return img1, img2, flow, occlusion 128 | 129 | return img1, img2, flow 130 | 131 | def __call__(self, img1, img2, flow, occlusion=None): 132 | img1, img2 = self.color_transform(img1, img2) 133 | img1, img2 = self.eraser_transform(img1, img2) 134 | 135 | if occlusion is not None: 136 | img1, img2, flow, occlusion = self.spatial_transform( 137 | img1, img2, flow, occlusion) 138 | else: 139 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 140 | 141 | img1 = np.ascontiguousarray(img1) 142 | img2 = np.ascontiguousarray(img2) 143 | flow = np.ascontiguousarray(flow) 144 | 145 | if occlusion is not None: 146 | occlusion = np.ascontiguousarray(occlusion) 147 | return img1, img2, flow, occlusion 148 | 149 | return img1, img2, flow 150 | 151 | 152 | class SparseFlowAugmentor: 153 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, 154 | no_eraser_aug=True, 155 | ): 156 | # spatial augmentation params 157 | self.crop_size = crop_size 158 | self.min_scale = min_scale 159 | self.max_scale = max_scale 160 | self.spatial_aug_prob = 0.8 161 | self.stretch_prob = 0.8 162 | self.max_stretch = 0.2 163 | 164 | # flip augmentation params 165 | self.do_flip = do_flip 166 | self.h_flip_prob = 0.5 167 | self.v_flip_prob = 0.1 168 | 169 | # photometric augmentation params 170 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14) 171 | self.asymmetric_color_aug_prob = 0.2 172 | 173 | if no_eraser_aug: 174 | # we disable eraser aug since no obvious improvement is observed in our experiments 175 | self.eraser_aug_prob = -1 176 | else: 177 | self.eraser_aug_prob = 0.5 178 | 179 | def color_transform(self, img1, img2): 180 | image_stack = np.concatenate([img1, img2], axis=0) 181 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 182 | img1, img2 = np.split(image_stack, 2, axis=0) 183 | return img1, img2 184 | 185 | def eraser_transform(self, img1, img2): 186 | ht, wd = img1.shape[:2] 187 | if np.random.rand() < self.eraser_aug_prob: 188 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 189 | for _ in range(np.random.randint(1, 3)): 190 | x0 = np.random.randint(0, wd) 191 | y0 = np.random.randint(0, ht) 192 | dx = np.random.randint(50, 100) 193 | dy = np.random.randint(50, 100) 194 | img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color 195 | 196 | return img1, img2 197 | 198 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 199 | ht, wd = flow.shape[:2] 200 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 201 | coords = np.stack(coords, axis=-1) 202 | 203 | coords = coords.reshape(-1, 2).astype(np.float32) 204 | flow = flow.reshape(-1, 2).astype(np.float32) 205 | valid = valid.reshape(-1).astype(np.float32) 206 | 207 | coords0 = coords[valid >= 1] 208 | flow0 = flow[valid >= 1] 209 | 210
| ht1 = int(round(ht * fy)) 211 | wd1 = int(round(wd * fx)) 212 | 213 | coords1 = coords0 * [fx, fy] 214 | flow1 = flow0 * [fx, fy] 215 | 216 | xx = np.round(coords1[:, 0]).astype(np.int32) 217 | yy = np.round(coords1[:, 1]).astype(np.int32) 218 | 219 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 220 | xx = xx[v] 221 | yy = yy[v] 222 | flow1 = flow1[v] 223 | 224 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 225 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 226 | 227 | flow_img[yy, xx] = flow1 228 | valid_img[yy, xx] = 1 229 | 230 | return flow_img, valid_img 231 | 232 | def spatial_transform(self, img1, img2, flow, valid): 233 | # randomly sample scale 234 | 235 | ht, wd = img1.shape[:2] 236 | min_scale = np.maximum( 237 | (self.crop_size[0] + 1) / float(ht), 238 | (self.crop_size[1] + 1) / float(wd)) 239 | 240 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 241 | scale_x = np.clip(scale, min_scale, None) 242 | scale_y = np.clip(scale, min_scale, None) 243 | 244 | if np.random.rand() < self.spatial_aug_prob: 245 | # rescale the images 246 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 247 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 248 | 249 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 250 | 251 | if self.do_flip: 252 | if np.random.rand() < 0.5: # h-flip 253 | img1 = img1[:, ::-1] 254 | img2 = img2[:, ::-1] 255 | flow = flow[:, ::-1] * [-1.0, 1.0] 256 | valid = valid[:, ::-1] 257 | 258 | margin_y = 20 259 | margin_x = 50 260 | 261 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 262 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 263 | 264 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 265 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 266 | 267 | img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 268 | img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 269 | flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 270 | valid = valid[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 271 | return img1, img2, flow, valid 272 | 273 | def __call__(self, img1, img2, flow, valid): 274 | img1, img2 = self.color_transform(img1, img2) 275 | img1, img2 = self.eraser_transform(img1, img2) 276 | 277 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 278 | 279 | img1 = np.ascontiguousarray(img1) 280 | img2 = np.ascontiguousarray(img2) 281 | flow = np.ascontiguousarray(flow) 282 | valid = np.ascontiguousarray(valid) 283 | 284 | return img1, img2, flow, valid -------------------------------------------------------------------------------- /dataset_preview/dataset_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neufieldrobotics/NeuFlow/1881c3cd85b87c6155d1b0726b8a345b6a1bdb7e/dataset_preview/dataset_1.png -------------------------------------------------------------------------------- /dataset_preview/dataset_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neufieldrobotics/NeuFlow/1881c3cd85b87c6155d1b0726b8a345b6a1bdb7e/dataset_preview/dataset_2.png -------------------------------------------------------------------------------- /dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import
subprocess 3 | 4 | import torch 5 | import torch.multiprocessing as mp 6 | from torch import distributed as dist 7 | 8 | 9 | def init_dist(launcher, backend='nccl', **kwargs): 10 | if mp.get_start_method(allow_none=True) is None: 11 | mp.set_start_method('spawn') 12 | if launcher == 'pytorch': 13 | _init_dist_pytorch(backend, **kwargs) 14 | elif launcher == 'mpi': 15 | _init_dist_mpi(backend, **kwargs) 16 | elif launcher == 'slurm': 17 | _init_dist_slurm(backend, **kwargs) 18 | else: 19 | raise ValueError(f'Invalid launcher type: {launcher}') 20 | 21 | 22 | def _init_dist_pytorch(backend, **kwargs): 23 | # TODO: use local_rank instead of rank % num_gpus 24 | rank = int(os.environ['RANK']) 25 | num_gpus = torch.cuda.device_count() 26 | torch.cuda.set_device(rank % num_gpus) 27 | dist.init_process_group(backend=backend, **kwargs) 28 | 29 | 30 | def _init_dist_mpi(backend, **kwargs): 31 | rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 32 | num_gpus = torch.cuda.device_count() 33 | torch.cuda.set_device(rank % num_gpus) 34 | dist.init_process_group(backend=backend, **kwargs) 35 | 36 | 37 | def _init_dist_slurm(backend, port=None): 38 | """Initialize slurm distributed training environment. 39 | If ``port`` is not specified, the master port is read from the system 40 | environment variable ``MASTER_PORT``; if ``MASTER_PORT`` is not set, 41 | the default port ``29500`` is used. 42 | Args: 43 | backend (str): Backend of torch.distributed. 44 | port (int, optional): Master port. Defaults to None. 45 | """ 46 | proc_id = int(os.environ['SLURM_PROCID']) 47 | ntasks = int(os.environ['SLURM_NTASKS']) 48 | node_list = os.environ['SLURM_NODELIST'] 49 | num_gpus = torch.cuda.device_count() 50 | torch.cuda.set_device(proc_id % num_gpus) 51 | addr = subprocess.getoutput( 52 | f'scontrol show hostname {node_list} | head -n1') 53 | # specify master port 54 | if port is not None: 55 | os.environ['MASTER_PORT'] = str(port) 56 | elif 'MASTER_PORT' in os.environ: 57 | pass # use MASTER_PORT in the environment variable 58 | else: 59 | # 29500 is the torch.distributed default port 60 | os.environ['MASTER_PORT'] = '29500' 61 | # use MASTER_ADDR in the environment variable if it already exists 62 | if 'MASTER_ADDR' not in os.environ: 63 | os.environ['MASTER_ADDR'] = addr 64 | os.environ['WORLD_SIZE'] = str(ntasks) 65 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 66 | os.environ['RANK'] = str(proc_id) 67 | dist.init_process_group(backend=backend) 68 | 69 | 70 | def get_dist_info(): 71 | if dist.is_available(): 72 | initialized = dist.is_initialized() 73 | else: 74 | initialized = False 75 | if initialized: 76 | rank = dist.get_rank() 77 | world_size = dist.get_world_size() 78 | else: 79 | rank = 0 80 | world_size = 1 81 | return rank, world_size 82 | 83 | 84 | def setup_for_distributed(is_master): 85 | """ 86 | This function disables printing when not in the master process 87 | """ 88 | import builtins as __builtin__ 89 | builtin_print = __builtin__.print 90 | 91 | def print(*args, **kwargs): 92 | force = kwargs.pop('force', False) 93 | if is_master or force: 94 | builtin_print(*args, **kwargs) 95 | 96 | __builtin__.print = print -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import argparse 4 | 5 | from NeuFlow.neuflow import NeuFlow 6 | from data_utils.evaluate import validate_things, validate_sintel 7 | 8 | def get_args_parser():
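    """Command-line arguments for eval.py; --resume points at the checkpoint to evaluate."""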
9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('--resume', default=None, type=str, 12 | help='resume from a pretrained model for finetuning, or resume terminated training') 13 | 14 | return parser 15 | 16 | def main(args): 17 | torch.backends.cudnn.benchmark = True 18 | 19 | device = torch.device('cuda') 20 | 21 | model = NeuFlow().to(device) 22 | 23 | checkpoint = torch.load(args.resume, map_location='cuda') 24 | 25 | model.load_state_dict(checkpoint['model'], strict=True) 26 | 27 | num_params = sum(p.numel() for p in model.parameters()) 28 | print('Number of params:', num_params) 29 | 30 | validate_things(model, dstype='frames_cleanpass', validate_subset=False, max_val_flow=400) 31 | validate_things(model, dstype='frames_finalpass', validate_subset=False, max_val_flow=400) 32 | validate_sintel(model, dstype='clean') 33 | validate_sintel(model, dstype='final') 34 | 35 | if __name__ == '__main__': 36 | parser = get_args_parser() 37 | args = parser.parse_args() 38 | 39 | main(args) 40 | -------------------------------------------------------------------------------- /load_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | 4 | 5 | def my_load_weights(weight_path): 6 | 7 | print('Load checkpoint: %s' % weight_path) 8 | 9 | checkpoint = torch.load(weight_path, map_location='cuda') 10 | 11 | state_dict = {} 12 | 13 | for k, v in checkpoint['model'].items(): 14 | 15 | # if k.startswith('conv_s8.'): 16 | # continue 17 | # if k.startswith('upsample_s1.'): 18 | # continue 19 | 20 | state_dict[k] = v 21 | 22 | return state_dict 23 | 24 | 25 | def my_freeze_model(model): 26 | for name, param in model.named_parameters(): 27 | pass 28 | # if name.startswith('upsample_s1.'): 29 | # param.requires_grad = True 30 | # elif name.startswith('conv_s8.'): 31 | # param.requires_grad = True 32 | # else: 33 | # param.requires_grad = False -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def flow_loss_func(flow_preds, flow_gt, valid, 5 | max_flow=400 6 | ): 7 | n_predictions = len(flow_preds) 8 | flow_loss = 0.0 9 | 10 | # exclude invalid pixels and extremely large displacements 11 | mag = torch.sum(flow_gt ** 2, dim=1).sqrt() # [B, H, W] 12 | valid = (valid >= 0.5) & (mag < max_flow) 13 | 14 | weights = [0.2, 1] 15 | 16 | for i in range(n_predictions): 17 | 18 | i_loss = (flow_preds[i] - flow_gt).abs() 19 | flow_loss += weights[i] * (valid[:, None] * i_loss).mean() 20 | 21 | epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt() 22 | 23 | if valid.max() < 0.5: 24 | pass # no valid pixels in this batch; the epe mean below will be NaN 25 | 26 | epe = epe.view(-1)[valid.view(-1)] 27 | 28 | metrics = { 29 | 'epe': epe.mean().item(), 30 | 'mag': mag.mean().item() 31 | } 32 | 33 | return flow_loss, metrics -------------------------------------------------------------------------------- /neuflow_sintel.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neufieldrobotics/NeuFlow/1881c3cd85b87c6155d1b0726b8a345b6a1bdb7e/neuflow_sintel.pth -------------------------------------------------------------------------------- /neuflow_things.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neufieldrobotics/NeuFlow/1881c3cd85b87c6155d1b0726b8a345b6a1bdb7e/neuflow_things.pth
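The two checkpoints above load through the same path that eval.py uses. Below is a minimal inference sketch, assuming a CUDA device, a checkpoint that stores its weights under the 'model' key (as eval.py expects), and two hypothetical preprocessed frame tensors img1 and img2 of shape [1, 3, H, W]:

import torch
from NeuFlow.neuflow import NeuFlow
from data_utils import frame_utils

model = NeuFlow().to(torch.device('cuda'))
checkpoint = torch.load('neuflow_things.pth', map_location='cuda')
model.load_state_dict(checkpoint['model'], strict=True)
model.eval()

# img1, img2 are hypothetical [1, 3, H, W] float tensors already on the GPU
padder = frame_utils.InputPadder(img1.shape, padding_factor=16)
img1_pad, img2_pad = padder.pad(img1, img2)

# the model caches batch/height/width, so init_bhw must run before the forward pass
model.init_bhw(img1_pad.shape[0], img1_pad.shape[-2], img1_pad.shape[-1])

with torch.no_grad():
    flow = padder.unpad(model(img1_pad, img2_pad)[-1][0])  # [2, H, W] flow in pixels

The padding_factor of 16 mirrors the evaluation code in data_utils/evaluate.py.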
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | import argparse 5 | import os 6 | 7 | from data_utils.datasets import build_train_dataset 8 | from NeuFlow.neuflow import NeuFlow 9 | from loss import flow_loss_func 10 | from data_utils.evaluate import validate_things, validate_sintel, validate_kitti 11 | from load_model import my_load_weights, my_freeze_model 12 | from dist_utils import get_dist_info, init_dist, setup_for_distributed 13 | 14 | 15 | def get_args_parser(): 16 | parser = argparse.ArgumentParser() 17 | 18 | # dataset 19 | parser.add_argument('--checkpoint_dir', default=None, type=str) 20 | parser.add_argument('--dataset_dir', default=None, type=str) 21 | parser.add_argument('--stage', default='things', type=str) 22 | parser.add_argument('--val_dataset', default=['things', 'sintel'], type=str, nargs='+') 23 | 24 | # training 25 | parser.add_argument('--lr', default=1e-4, type=float) 26 | parser.add_argument('--batch_size', default=64, type=int) 27 | parser.add_argument('--num_workers', default=8, type=int) 28 | parser.add_argument('--weight_decay', default=1e-4, type=float) 29 | parser.add_argument('--val_freq', default=1000, type=int) 30 | parser.add_argument('--num_steps', default=1000000, type=int) 31 | 32 | parser.add_argument('--max_flow', default=400, type=int) 33 | 34 | # resume pretrained model or resume training 35 | parser.add_argument('--resume', default=None, type=str) 36 | parser.add_argument('--strict_resume', action='store_true') 37 | 38 | # distributed training 39 | parser.add_argument('--local-rank', default=0, type=int) 40 | parser.add_argument('--distributed', action='store_true') 41 | 42 | return parser 43 | 44 | 45 | def main(args): 46 | # torch.autograd.set_detect_anomaly(True) 47 | print('Use %d GPUs' % torch.cuda.device_count()) 48 | # seed = args.seed 49 | # torch.manual_seed(seed) 50 | # np.random.seed(seed) 51 | # torch.cuda.manual_seed_all(seed) 52 | # torch.backends.cudnn.deterministic = True 53 | 54 | torch.backends.cudnn.benchmark = True 55 | 56 | if args.distributed: 57 | # adjust batch size for each gpu 58 | assert args.batch_size % torch.cuda.device_count() == 0 59 | args.batch_size = args.batch_size // torch.cuda.device_count() 60 | 61 | dist_params = dict(backend='nccl') 62 | init_dist('pytorch', **dist_params) 63 | # reset gpu_ids for distributed training mode 64 | _, world_size = get_dist_info() 65 | args.gpu_ids = range(world_size) 66 | device = torch.device('cuda:{}'.format(args.local_rank)) 67 | 68 | setup_for_distributed(args.local_rank == 0) 69 | else: 70 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 71 | 72 | # model 73 | model = NeuFlow().to(device) 74 | 75 | if args.distributed: 76 | model = torch.nn.parallel.DistributedDataParallel( 77 | model, 78 | device_ids=[args.local_rank], 79 | output_device=args.local_rank) 80 | model = model.module 81 | 82 | num_params = sum(p.numel() for p in model.parameters()) 83 | print('Number of params:', num_params) 84 | 85 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, 86 | weight_decay=args.weight_decay) 87 | 88 | start_step = 0 89 | 90 | if args.resume: 91 | 92 | state_dict = my_load_weights(args.resume) 93 | 94 | model.load_state_dict(state_dict, strict=args.strict_resume) 95 | 96 | my_freeze_model(model) 97 | 98 | for name, param in model.named_parameters(): 99
| print(name, param.requires_grad) 100 | 101 | torch.save({ 102 | 'model': model.state_dict() 103 | }, os.path.join(args.checkpoint_dir, 'step_0.pth')) 104 | 105 | train_dataset = build_train_dataset(args.stage) 106 | print('Number of training images:', len(train_dataset)) 107 | 108 | if args.distributed: 109 | train_sampler = torch.utils.data.distributed.DistributedSampler( 110 | train_dataset, 111 | num_replicas=torch.cuda.device_count(), 112 | rank=args.local_rank) 113 | else: 114 | train_sampler = None 115 | 116 | shuffle = not args.distributed 117 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, 118 | shuffle=shuffle, num_workers=args.num_workers, 119 | pin_memory=True, drop_last=True, 120 | sampler=train_sampler) 121 | 122 | # lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( 123 | # optimizer, args.lr, 124 | # args.num_steps + 10, 125 | # pct_start=0.05, 126 | # cycle_momentum=False, 127 | # anneal_strategy='cos', 128 | # last_epoch=last_epoch, 129 | # ) 130 | 131 | total_steps = 0 132 | epoch = 0 133 | 134 | counter = 0 135 | 136 | while total_steps < args.num_steps: 137 | model.train() 138 | 139 | # manually change the random seed for shuffling every epoch 140 | if args.distributed: 141 | train_sampler.set_epoch(epoch) 142 | 143 | for i, sample in enumerate(train_loader): 144 | img1, img2, flow_gt, valid = [x.to(device) for x in sample] 145 | 146 | model.init_bhw(img1.shape[0], img1.shape[-2], img1.shape[-1]) 147 | 148 | flow_preds = model(img1, img2) 149 | 150 | loss, metrics = flow_loss_func(flow_preds, flow_gt, valid, args.max_flow) 151 | 152 | # more efficient zero_grad: set grads to None instead of zeroing them 153 | for param in model.parameters(): 154 | param.grad = None 155 | 156 | loss.backward() 157 | 158 | # Gradient clipping 159 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 160 | 161 | optimizer.step() 162 | 163 | # lr_scheduler.step() 164 | 165 | print(total_steps, metrics['epe'], metrics['mag'], optimizer.param_groups[-1]['lr']) 166 | 167 | total_steps += 1 168 | 169 | if total_steps % args.val_freq == 0: 170 | 171 | if args.local_rank == 0: 172 | checkpoint_path = os.path.join(args.checkpoint_dir, 'step_%06d.pth' % total_steps) 173 | torch.save({ 174 | 'model': model.state_dict() 175 | }, checkpoint_path) 176 | 177 | val_results = {} 178 | 179 | if 'things' in args.val_dataset: 180 | test_results_dict = validate_things(model, dstype='frames_cleanpass', validate_subset=True, max_val_flow=args.max_flow) 181 | if args.local_rank == 0: 182 | val_results.update(test_results_dict) 183 | 184 | if 'sintel' in args.val_dataset: 185 | test_results_dict = validate_sintel(model, dstype='final') 186 | if args.local_rank == 0: 187 | val_results.update(test_results_dict) 188 | 189 | if 'kitti' in args.val_dataset: 190 | test_results_dict = validate_kitti(model) 191 | if args.local_rank == 0: 192 | val_results.update(test_results_dict) 193 | 194 | if args.local_rank == 0: 195 | 196 | counter += 1 197 | 198 | if counter >= 20: 199 | 200 | for group in optimizer.param_groups: 201 | group['lr'] *= 0.7 202 | 203 | counter = 0 204 | 205 | # Save validation results 206 | val_file = os.path.join(args.checkpoint_dir, 'val_results.txt') 207 | with open(val_file, 'a') as f: 208 | f.write('step: %06d lr: %.6f\n' % (total_steps, optimizer.param_groups[-1]['lr'])) 209 | 210 | for k, v in val_results.items(): 211 | f.write("| %s: %.3f " % (k, v)) 212 | 213 | f.write('\n\n') 214 | 215 | model.train() 216 | 217 | epoch += 1 218 | 219 | 220 | if __name__ == '__main__': 221 |
parser = get_args_parser() 222 | args = parser.parse_args() 223 | 224 | if 'LOCAL_RANK' not in os.environ: 225 | os.environ['LOCAL_RANK'] = str(args.local_rank) 226 | 227 | main(args) 228 | -------------------------------------------------------------------------------- /write_occ.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import cv2 4 | import json 5 | import os 6 | 7 | from data_utils import frame_utils 8 | 9 | width = 960 10 | height = 540 11 | 12 | x_range = np.arange(width) 13 | y_range = np.arange(height) 14 | xs, ys = np.meshgrid(x_range, y_range) 15 | coords = np.float32(np.dstack([xs, ys])) 16 | 17 | root = 'datasets/FlyingThings3D/optical_flow/' 18 | 19 | fw_flow_dirs = sorted(glob.glob(root + '*/*/*/into_future/*/')) 20 | bw_flow_dirs = sorted(glob.glob(root + '*/*/*/into_past/*/')) 21 | 22 | print(len(fw_flow_dirs)) 23 | 24 | flow_mag_dict = {} 25 | index = 0 26 | 27 | for fw_flow_dir, bw_flow_dir in zip(fw_flow_dirs, bw_flow_dirs): 28 | 29 | fw_flows = sorted(glob.glob(fw_flow_dir + '*.pfm'))[:-1] 30 | bw_flows = sorted(glob.glob(bw_flow_dir + '*.pfm'))[1:] 31 | 32 | for fw_flow_path, bw_flow_path in zip(fw_flows+bw_flows, bw_flows+fw_flows): 33 | 34 | occlusion_file_path = os.path.splitext(fw_flow_path)[0]+'.png' 35 | 36 | fw_flow = frame_utils.read_gen(fw_flow_path) 37 | bw_flow = frame_utils.read_gen(bw_flow_path) 38 | 39 | warp_flow = cv2.remap(coords, coords + bw_flow, None, interpolation=cv2.INTER_LINEAR) 40 | warp_flow = cv2.remap(warp_flow, coords + fw_flow, None, interpolation=cv2.INTER_LINEAR) 41 | 42 | warp_flow -= coords 43 | 44 | occlusion = np.sum(warp_flow**2, axis=-1) < 0.01 45 | 46 | # cv2.imwrite(occlusion_file_path, occlusion*255) 47 | index += 1 48 | print(index) 49 | --------------------------------------------------------------------------------
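write_occ.py labels a pixel as consistently matched when chaining the forward and backward flows brings it (approximately) back to its starting coordinates. A self-contained sketch of that forward-backward check, assuming (H, W, 2) float32 flow arrays; the 0.01 squared-pixel threshold matches the script above, while the function name is illustrative:

import numpy as np
import cv2

def forward_backward_consistent(fw_flow, bw_flow, thresh=0.01):
    h, w = fw_flow.shape[:2]
    xs, ys = np.meshgrid(np.arange(w), np.arange(h))
    coords = np.float32(np.dstack([xs, ys]))
    # warp[p] = destination of p after following the forward flow and then
    # the backward flow; for non-occluded pixels this returns to p
    warp = cv2.remap(coords, coords + bw_flow, None, interpolation=cv2.INTER_LINEAR)
    warp = cv2.remap(warp, coords + fw_flow, None, interpolation=cv2.INTER_LINEAR)
    # True where the round trip is consistent, i.e. the pixel is not occluded
    return np.sum((warp - coords) ** 2, axis=-1) < thresh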