├── .gitignore
├── ACKNOWLEDGMENTS
├── DPT
│   └── dpt
│       ├── __init__.py
│       ├── base_model.py
│       ├── blocks.py
│       ├── midas_net.py
│       ├── models.py
│       ├── transforms.py
│       └── vit.py
├── LICENSE
├── RAFT
│   ├── alt_cuda_corr
│   │   ├── __init__.py
│   │   ├── correlation.cpp
│   │   ├── correlation_kernel.cu
│   │   └── setup.py
│   └── core
│       ├── __init__.py
│       ├── corr.py
│       ├── datasets.py
│       ├── extractor.py
│       ├── raft.py
│       ├── update.py
│       └── utils
│           ├── __init__.py
│           ├── __pycache__
│           │   ├── __init__.cpython-38.pyc
│           │   ├── augmentor.cpython-38.pyc
│           │   ├── frame_utils.cpython-38.pyc
│           │   └── utils.cpython-38.pyc
│           ├── augmentor.py
│           ├── flow_viz.py
│           ├── frame_utils.py
│           └── utils.py
├── README.md
├── RFdata
│   ├── flow
│   │   └── RFflow.flo
│   ├── img1
│   │   └── test_1-1.png
│   └── img2
│       └── RFtest_1-2.png
├── RealFLow.py
├── dataset_download.sh
├── demo.py
├── sample
│   ├── test_1-1.png
│   └── test_1-2.png
├── softmax_splatting
│   └── softsplat.py
└── utils
    ├── tools.py
    └── video2img.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.idea 3 | __pycache__ 4 | .idea --------------------------------------------------------------------------------
/ACKNOWLEDGMENTS: -------------------------------------------------------------------------------- 1 | RealFlow datasets are licensed under the CC BY-NC-SA3.0, except for the third-party components listed below. 2 | 3 | ********************************************************************************************************************************* 4 | Dataset Licensed under the BSD 3-Clause License: 5 | -------------------------------------------------------------------- 6 | 7 | BDD100K 8 | Copyright (c) 2018, Fisher Yu 9 | All rights reserved. 10 | 11 | Terms of the BSD 3-Clause License: 12 | -------------------------------------------------------------------- 13 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 14 | 15 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 16 | 17 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
22 | 23 | ********************************************************************************************************************************* 24 | 25 | 26 | ********************************************************************************************************************************* 27 | Dataset Licensed under the CC BY-NC-SA 3.0 License: 28 | -------------------------------------------------------------------- 29 | 30 | KITTI 31 | 32 | Terms of the CC BY-NC-SA 3.0 License: 33 | -------------------------------------------------------------------- 34 | 35 | THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED. 36 | 37 | BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS. 38 | 39 | 1. Definitions 40 | "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License. 41 | "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(f) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined below) for the purposes of this License. 42 | "Creative Commons Compatible License" means a license that is listed at http://creativecommons.org/compatiblelicenses that has been approved by Creative Commons as being essentially equivalent to this License, including, at a minimum, because that license: (i) contains terms that have the same purpose, meaning and effect as the License Elements of this License; and, (ii) explicitly permits the relicensing of adaptations of works made available under that license under this License or a Creative Commons jurisdiction license with the same License Elements as this License. 43 | "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership. 
44 | "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, ShareAlike. 45 | "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License. 46 | "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast. 47 | "Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work. 48 | "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation. 49 | "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images. 50 | "Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium. 51 | 52 | 2. 
Fair Dealing Rights 53 | Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws. 54 | 55 | 3. License Grant 56 | Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below: 57 | to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections; 58 | to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified."; to Distribute and Publicly Perform the Work including as incorporated in Collections; and, to Distribute and Publicly Perform Adaptations. 59 | For the avoidance of doubt: 60 | Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; 61 | Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor waives the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; and, 62 | Voluntary License Schemes. The Licensor waives the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License. 63 | The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved. 64 | 65 | 4. Restrictions 66 | The license granted in Section 3 above is expressly made subject to and limited by the following restrictions: 67 | You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. 
When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(c), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(c), as requested. 68 | You may Distribute or Publicly Perform an Adaptation only under the terms of: (i) this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-ShareAlike 3.0 US)); (iv) a Creative Commons Compatible License. If you license the Adaptation under one of the licenses mentioned in (iv), you must comply with the terms of that license. If you license the Adaptation under the terms of any of the licenses mentioned in (i), (ii) or (iii) (the "Applicable License"), you must comply with the terms of the Applicable License generally and the following provisions: (I) You must include a copy of, or the URI for, the Applicable License with every copy of each Adaptation You Distribute or Publicly Perform; (II) You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License; (III) You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform; (IV) when You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License. 
69 | If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and (iv) , consistent with Section 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(c) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties. 70 | Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise. 71 | 72 | 5. Representations, Warranties and Disclaimer 73 | UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. 
SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU. 74 | 75 | 6. Limitation on Liability 76 | EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 77 | 78 | 7. Termination 79 | This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License. 80 | Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above. 81 | 82 | 8. Miscellaneous 83 | Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License. 84 | Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License. 85 | If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. 86 | No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent. 87 | This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You. 88 | The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). 
These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law. 89 | 90 | ********************************************************************************************************************************* 91 | -------------------------------------------------------------------------------- /DPT/dpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/DPT/dpt/__init__.py -------------------------------------------------------------------------------- /DPT/dpt/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device("cpu")) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /DPT/dpt/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | 12 | def _make_encoder( 13 | backbone, 14 | features, 15 | use_pretrained, 16 | groups=1, 17 | expand=False, 18 | exportable=True, 19 | hooks=None, 20 | use_vit_only=False, 21 | use_readout="ignore", 22 | enable_attention_hooks=False, 23 | ): 24 | if backbone == "vitl16_384": 25 | pretrained = _make_pretrained_vitl16_384( 26 | use_pretrained, 27 | hooks=hooks, 28 | use_readout=use_readout, 29 | enable_attention_hooks=enable_attention_hooks, 30 | ) 31 | scratch = _make_scratch( 32 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 33 | ) # ViT-L/16 - 85.0% Top1 (backbone) 34 | elif backbone == "vitb_rn50_384": 35 | pretrained = _make_pretrained_vitb_rn50_384( 36 | use_pretrained, 37 | hooks=hooks, 38 | use_vit_only=use_vit_only, 39 | use_readout=use_readout, 40 | enable_attention_hooks=enable_attention_hooks, 41 | ) 42 | scratch = _make_scratch( 43 | [256, 512, 768, 768], features, groups=groups, expand=expand 44 | ) # ViT-H/16 - 85.0% Top1 (backbone) 45 | elif backbone == "vitb16_384": 46 | pretrained = _make_pretrained_vitb16_384( 47 | use_pretrained, 48 | hooks=hooks, 49 | use_readout=use_readout, 50 | enable_attention_hooks=enable_attention_hooks, 51 | ) 52 | scratch = _make_scratch( 53 | [96, 192, 384, 768], features, groups=groups, expand=expand 54 | ) # ViT-B/16 - 84.6% Top1 (backbone) 55 | elif backbone == "resnext101_wsl": 56 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 57 | scratch = _make_scratch( 58 | [256, 512, 1024, 2048], features, groups=groups, expand=expand 59 | ) # efficientnet_lite3 60 | else: 61 | print(f"Backbone '{backbone}' not implemented") 62 | assert 
False 63 | 64 | return pretrained, scratch 65 | 66 | 67 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 68 | scratch = nn.Module() 69 | 70 | out_shape1 = out_shape 71 | out_shape2 = out_shape 72 | out_shape3 = out_shape 73 | out_shape4 = out_shape 74 | if expand == True: 75 | out_shape1 = out_shape 76 | out_shape2 = out_shape * 2 77 | out_shape3 = out_shape * 4 78 | out_shape4 = out_shape * 8 79 | 80 | scratch.layer1_rn = nn.Conv2d( 81 | in_shape[0], 82 | out_shape1, 83 | kernel_size=3, 84 | stride=1, 85 | padding=1, 86 | bias=False, 87 | groups=groups, 88 | ) 89 | scratch.layer2_rn = nn.Conv2d( 90 | in_shape[1], 91 | out_shape2, 92 | kernel_size=3, 93 | stride=1, 94 | padding=1, 95 | bias=False, 96 | groups=groups, 97 | ) 98 | scratch.layer3_rn = nn.Conv2d( 99 | in_shape[2], 100 | out_shape3, 101 | kernel_size=3, 102 | stride=1, 103 | padding=1, 104 | bias=False, 105 | groups=groups, 106 | ) 107 | scratch.layer4_rn = nn.Conv2d( 108 | in_shape[3], 109 | out_shape4, 110 | kernel_size=3, 111 | stride=1, 112 | padding=1, 113 | bias=False, 114 | groups=groups, 115 | ) 116 | 117 | return scratch 118 | 119 | 120 | def _make_resnet_backbone(resnet): 121 | pretrained = nn.Module() 122 | pretrained.layer1 = nn.Sequential( 123 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 124 | ) 125 | 126 | pretrained.layer2 = resnet.layer2 127 | pretrained.layer3 = resnet.layer3 128 | pretrained.layer4 = resnet.layer4 129 | 130 | return pretrained 131 | 132 | 133 | def _make_pretrained_resnext101_wsl(use_pretrained): 134 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 135 | return _make_resnet_backbone(resnet) 136 | 137 | 138 | class Interpolate(nn.Module): 139 | """Interpolation module.""" 140 | 141 | def __init__(self, scale_factor, mode, align_corners=False): 142 | """Init. 143 | 144 | Args: 145 | scale_factor (float): scaling 146 | mode (str): interpolation mode 147 | """ 148 | super(Interpolate, self).__init__() 149 | 150 | self.interp = nn.functional.interpolate 151 | self.scale_factor = scale_factor 152 | self.mode = mode 153 | self.align_corners = align_corners 154 | 155 | def forward(self, x): 156 | """Forward pass. 157 | 158 | Args: 159 | x (tensor): input 160 | 161 | Returns: 162 | tensor: interpolated data 163 | """ 164 | 165 | x = self.interp( 166 | x, 167 | scale_factor=self.scale_factor, 168 | mode=self.mode, 169 | align_corners=self.align_corners, 170 | ) 171 | 172 | return x 173 | 174 | 175 | class ResidualConvUnit(nn.Module): 176 | """Residual convolution module.""" 177 | 178 | def __init__(self, features): 179 | """Init. 180 | 181 | Args: 182 | features (int): number of features 183 | """ 184 | super().__init__() 185 | 186 | self.conv1 = nn.Conv2d( 187 | features, features, kernel_size=3, stride=1, padding=1, bias=True 188 | ) 189 | 190 | self.conv2 = nn.Conv2d( 191 | features, features, kernel_size=3, stride=1, padding=1, bias=True 192 | ) 193 | 194 | self.relu = nn.ReLU(inplace=True) 195 | 196 | def forward(self, x): 197 | """Forward pass. 198 | 199 | Args: 200 | x (tensor): input 201 | 202 | Returns: 203 | tensor: output 204 | """ 205 | out = self.relu(x) 206 | out = self.conv1(out) 207 | out = self.relu(out) 208 | out = self.conv2(out) 209 | 210 | return out + x 211 | 212 | 213 | class FeatureFusionBlock(nn.Module): 214 | """Feature fusion block.""" 215 | 216 | def __init__(self, features): 217 | """Init. 
218 | 219 | Args: 220 | features (int): number of features 221 | """ 222 | super(FeatureFusionBlock, self).__init__() 223 | 224 | self.resConfUnit1 = ResidualConvUnit(features) 225 | self.resConfUnit2 = ResidualConvUnit(features) 226 | 227 | def forward(self, *xs): 228 | """Forward pass. 229 | 230 | Returns: 231 | tensor: output 232 | """ 233 | output = xs[0] 234 | 235 | if len(xs) == 2: 236 | output += self.resConfUnit1(xs[1]) 237 | 238 | output = self.resConfUnit2(output) 239 | 240 | output = nn.functional.interpolate( 241 | output, scale_factor=2, mode="bilinear", align_corners=True 242 | ) 243 | 244 | return output 245 | 246 | 247 | class ResidualConvUnit_custom(nn.Module): 248 | """Residual convolution module.""" 249 | 250 | def __init__(self, features, activation, bn): 251 | """Init. 252 | 253 | Args: 254 | features (int): number of features 255 | """ 256 | super().__init__() 257 | 258 | self.bn = bn 259 | 260 | self.groups = 1 261 | 262 | self.conv1 = nn.Conv2d( 263 | features, 264 | features, 265 | kernel_size=3, 266 | stride=1, 267 | padding=1, 268 | bias=not self.bn, 269 | groups=self.groups, 270 | ) 271 | 272 | self.conv2 = nn.Conv2d( 273 | features, 274 | features, 275 | kernel_size=3, 276 | stride=1, 277 | padding=1, 278 | bias=not self.bn, 279 | groups=self.groups, 280 | ) 281 | 282 | if self.bn == True: 283 | self.bn1 = nn.BatchNorm2d(features) 284 | self.bn2 = nn.BatchNorm2d(features) 285 | 286 | self.activation = activation 287 | 288 | self.skip_add = nn.quantized.FloatFunctional() 289 | 290 | def forward(self, x): 291 | """Forward pass. 292 | 293 | Args: 294 | x (tensor): input 295 | 296 | Returns: 297 | tensor: output 298 | """ 299 | 300 | out = self.activation(x) 301 | out = self.conv1(out) 302 | if self.bn == True: 303 | out = self.bn1(out) 304 | 305 | out = self.activation(out) 306 | out = self.conv2(out) 307 | if self.bn == True: 308 | out = self.bn2(out) 309 | 310 | if self.groups > 1: 311 | out = self.conv_merge(out) 312 | 313 | return self.skip_add.add(out, x) 314 | 315 | # return out + x 316 | 317 | 318 | class FeatureFusionBlock_custom(nn.Module): 319 | """Feature fusion block.""" 320 | 321 | def __init__( 322 | self, 323 | features, 324 | activation, 325 | deconv=False, 326 | bn=False, 327 | expand=False, 328 | align_corners=True, 329 | ): 330 | """Init. 331 | 332 | Args: 333 | features (int): number of features 334 | """ 335 | super(FeatureFusionBlock_custom, self).__init__() 336 | 337 | self.deconv = deconv 338 | self.align_corners = align_corners 339 | 340 | self.groups = 1 341 | 342 | self.expand = expand 343 | out_features = features 344 | if self.expand == True: 345 | out_features = features // 2 346 | 347 | self.out_conv = nn.Conv2d( 348 | features, 349 | out_features, 350 | kernel_size=1, 351 | stride=1, 352 | padding=0, 353 | bias=True, 354 | groups=1, 355 | ) 356 | 357 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 358 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 359 | 360 | self.skip_add = nn.quantized.FloatFunctional() 361 | 362 | def forward(self, *xs): 363 | """Forward pass. 
364 | 365 | Returns: 366 | tensor: output 367 | """ 368 | output = xs[0] 369 | 370 | if len(xs) == 2: 371 | res = self.resConfUnit1(xs[1]) 372 | output = self.skip_add.add(output, res) 373 | # output += res 374 | 375 | output = self.resConfUnit2(output) 376 | 377 | output = nn.functional.interpolate( 378 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 379 | ) 380 | 381 | output = self.out_conv(output) 382 | 383 | return output 384 | -------------------------------------------------------------------------------- /DPT/dpt/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidasNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_large(BaseModel): 13 | """Network for monocular depth estimation.""" 14 | 15 | def __init__(self, path=None, features=256, non_negative=True): 16 | """Init. 17 | 18 | Args: 19 | path (str, optional): Path to saved model. Defaults to None. 20 | features (int, optional): Number of features. Defaults to 256. 21 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 22 | """ 23 | print("Loading weights: ", path) 24 | 25 | super(MidasNet_large, self).__init__() 26 | 27 | use_pretrained = False if path is None else True 28 | 29 | self.pretrained, self.scratch = _make_encoder( 30 | backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained 31 | ) 32 | 33 | self.scratch.refinenet4 = FeatureFusionBlock(features) 34 | self.scratch.refinenet3 = FeatureFusionBlock(features) 35 | self.scratch.refinenet2 = FeatureFusionBlock(features) 36 | self.scratch.refinenet1 = FeatureFusionBlock(features) 37 | 38 | self.scratch.output_conv = nn.Sequential( 39 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 40 | Interpolate(scale_factor=2, mode="bilinear"), 41 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 42 | nn.ReLU(True), 43 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 44 | nn.ReLU(True) if non_negative else nn.Identity(), 45 | ) 46 | 47 | if path: 48 | self.load(path) 49 | 50 | def forward(self, x): 51 | """Forward pass.
52 | 53 | Args: 54 | x (tensor): input data (image) 55 | 56 | Returns: 57 | tensor: depth 58 | """ 59 | 60 | layer_1 = self.pretrained.layer1(x) 61 | layer_2 = self.pretrained.layer2(layer_1) 62 | layer_3 = self.pretrained.layer3(layer_2) 63 | layer_4 = self.pretrained.layer4(layer_3) 64 | 65 | layer_1_rn = self.scratch.layer1_rn(layer_1) 66 | layer_2_rn = self.scratch.layer2_rn(layer_2) 67 | layer_3_rn = self.scratch.layer3_rn(layer_3) 68 | layer_4_rn = self.scratch.layer4_rn(layer_4) 69 | 70 | path_4 = self.scratch.refinenet4(layer_4_rn) 71 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 72 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 73 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 74 | 75 | out = self.scratch.output_conv(path_1) 76 | 77 | return torch.squeeze(out, dim=1) 78 | -------------------------------------------------------------------------------- /DPT/dpt/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | enable_attention_hooks=False, 36 | ): 37 | 38 | super(DPT, self).__init__() 39 | 40 | self.channels_last = channels_last 41 | 42 | hooks = { 43 | "vitb_rn50_384": [0, 1, 8, 11], 44 | "vitb16_384": [2, 5, 8, 11], 45 | "vitl16_384": [5, 11, 17, 23], 46 | } 47 | 48 | # Instantiate backbone and reassemble blocks 49 | self.pretrained, self.scratch = _make_encoder( 50 | backbone, 51 | features, 52 | False, # Set to true if you want to train from scratch, uses ImageNet weights 53 | groups=1, 54 | expand=False, 55 | exportable=False, 56 | hooks=hooks[backbone], 57 | use_readout=readout, 58 | enable_attention_hooks=enable_attention_hooks, 59 | ) 60 | 61 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 63 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 64 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 65 | 66 | self.scratch.output_conv = head 67 | 68 | def forward(self, x): 69 | if self.channels_last == True: 70 | x.contiguous(memory_format=torch.channels_last) 71 | 72 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 73 | 74 | layer_1_rn = self.scratch.layer1_rn(layer_1) 75 | layer_2_rn = self.scratch.layer2_rn(layer_2) 76 | layer_3_rn = self.scratch.layer3_rn(layer_3) 77 | layer_4_rn = self.scratch.layer4_rn(layer_4) 78 | 79 | path_4 = self.scratch.refinenet4(layer_4_rn) 80 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 81 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 82 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 83 | 84 | out = self.scratch.output_conv(path_1) 85 | 86 | return out 87 | 88 | 89 | class DPTDepthModel(DPT): 90 | def __init__( 91 | self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs 92 | ): 93 | features = kwargs["features"] if "features"
in kwargs else 256 94 | 95 | self.scale = scale 96 | self.shift = shift 97 | self.invert = invert 98 | 99 | head = nn.Sequential( 100 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 101 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 102 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 103 | nn.ReLU(True), 104 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 105 | nn.ReLU(True) if non_negative else nn.Identity(), 106 | nn.Identity(), 107 | ) 108 | 109 | super().__init__(head, **kwargs) 110 | 111 | if path is not None: 112 | self.load(path) 113 | 114 | def forward(self, x): 115 | inv_depth = super().forward(x).squeeze(dim=1) 116 | 117 | if self.invert: 118 | depth = self.scale * inv_depth + self.shift 119 | depth[depth < 1e-8] = 1e-8 120 | depth = 1.0 / depth 121 | return depth 122 | else: 123 | return inv_depth 124 | 125 | 126 | class DPTSegmentationModel(DPT): 127 | def __init__(self, num_classes, path=None, **kwargs): 128 | 129 | features = kwargs["features"] if "features" in kwargs else 256 130 | 131 | kwargs["use_bn"] = True 132 | 133 | head = nn.Sequential( 134 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), 135 | nn.BatchNorm2d(features), 136 | nn.ReLU(True), 137 | nn.Dropout(0.1, False), 138 | nn.Conv2d(features, num_classes, kernel_size=1), 139 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 140 | ) 141 | 142 | super().__init__(head, **kwargs) 143 | 144 | self.auxlayer = nn.Sequential( 145 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), 146 | nn.BatchNorm2d(features), 147 | nn.ReLU(True), 148 | nn.Dropout(0.1, False), 149 | nn.Conv2d(features, num_classes, kernel_size=1), 150 | ) 151 | 152 | if path is not None: 153 | self.load(path) 154 | -------------------------------------------------------------------------------- /DPT/dpt/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height).""" 50 | 51 | def __init__( 52 | self, 53 | width, 54 | height, 55 | resize_target=True, 56 | keep_aspect_ratio=False, 57 | ensure_multiple_of=1, 58 | resize_method="lower_bound", 59 | image_interpolation_method=cv2.INTER_AREA, 60 | ): 61 | """Init. 
62 | 63 | Args: 64 | width (int): desired output width 65 | height (int): desired output height 66 | resize_target (bool, optional): 67 | True: Resize the full sample (image, mask, target). 68 | False: Resize image only. 69 | Defaults to True. 70 | keep_aspect_ratio (bool, optional): 71 | True: Keep the aspect ratio of the input sample. 72 | Output sample might not have the given width and height, and 73 | resize behaviour depends on the parameter 'resize_method'. 74 | Defaults to False. 75 | ensure_multiple_of (int, optional): 76 | Output width and height is constrained to be multiple of this parameter. 77 | Defaults to 1. 78 | resize_method (str, optional): 79 | "lower_bound": Output will be at least as large as the given size. 80 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 81 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 82 | Defaults to "lower_bound". 83 | """ 84 | self.__width = width 85 | self.__height = height 86 | 87 | self.__resize_target = resize_target 88 | self.__keep_aspect_ratio = keep_aspect_ratio 89 | self.__multiple_of = ensure_multiple_of 90 | self.__resize_method = resize_method 91 | self.__image_interpolation_method = image_interpolation_method 92 | 93 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 94 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 95 | 96 | if max_val is not None and y > max_val: 97 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 98 | 99 | if y < min_val: 100 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 101 | 102 | return y 103 | 104 | def get_size(self, width, height): 105 | # determine new height and width 106 | scale_height = self.__height / height 107 | scale_width = self.__width / width 108 | 109 | if self.__keep_aspect_ratio: 110 | if self.__resize_method == "lower_bound": 111 | # scale such that output size is lower bound 112 | if scale_width > scale_height: 113 | # fit width 114 | scale_height = scale_width 115 | else: 116 | # fit height 117 | scale_width = scale_height 118 | elif self.__resize_method == "upper_bound": 119 | # scale such that output size is upper bound 120 | if scale_width < scale_height: 121 | # fit width 122 | scale_height = scale_width 123 | else: 124 | # fit height 125 | scale_width = scale_height 126 | elif self.__resize_method == "minimal": 127 | # scale as least as possbile 128 | if abs(1 - scale_width) < abs(1 - scale_height): 129 | # fit width 130 | scale_height = scale_width 131 | else: 132 | # fit height 133 | scale_width = scale_height 134 | else: 135 | raise ValueError( 136 | f"resize_method {self.__resize_method} not implemented" 137 | ) 138 | 139 | if self.__resize_method == "lower_bound": 140 | new_height = self.constrain_to_multiple_of( 141 | scale_height * height, min_val=self.__height 142 | ) 143 | new_width = self.constrain_to_multiple_of( 144 | scale_width * width, min_val=self.__width 145 | ) 146 | elif self.__resize_method == "upper_bound": 147 | new_height = self.constrain_to_multiple_of( 148 | scale_height * height, max_val=self.__height 149 | ) 150 | new_width = self.constrain_to_multiple_of( 151 | scale_width * width, max_val=self.__width 152 | ) 153 | elif self.__resize_method == "minimal": 154 | new_height = self.constrain_to_multiple_of(scale_height * height) 155 | new_width = self.constrain_to_multiple_of(scale_width * width) 156 | else: 157 | raise ValueError(f"resize_method 
{self.__resize_method} not implemented") 158 | 159 | return (new_width, new_height) 160 | 161 | def __call__(self, sample): 162 | width, height = self.get_size( 163 | sample["image"].shape[1], sample["image"].shape[0] 164 | ) 165 | 166 | # resize sample 167 | sample["image"] = cv2.resize( 168 | sample["image"], 169 | (width, height), 170 | interpolation=self.__image_interpolation_method, 171 | ) 172 | 173 | if self.__resize_target: 174 | if "disparity" in sample: 175 | sample["disparity"] = cv2.resize( 176 | sample["disparity"], 177 | (width, height), 178 | interpolation=cv2.INTER_NEAREST, 179 | ) 180 | 181 | if "depth" in sample: 182 | sample["depth"] = cv2.resize( 183 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 184 | ) 185 | 186 | sample["mask"] = cv2.resize( 187 | sample["mask"].astype(np.float32), 188 | (width, height), 189 | interpolation=cv2.INTER_NEAREST, 190 | ) 191 | sample["mask"] = sample["mask"].astype(bool) 192 | 193 | return sample 194 | 195 | 196 | class NormalizeImage(object): 197 | """Normlize image by given mean and std.""" 198 | 199 | def __init__(self, mean, std): 200 | self.__mean = mean 201 | self.__std = std 202 | 203 | def __call__(self, sample): 204 | sample["image"] = (sample["image"] - self.__mean) / self.__std 205 | 206 | return sample 207 | 208 | 209 | class PrepareForNet(object): 210 | """Prepare sample for usage as network input.""" 211 | 212 | def __init__(self): 213 | pass 214 | 215 | def __call__(self, sample): 216 | image = np.transpose(sample["image"], (2, 0, 1)) 217 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 218 | 219 | if "mask" in sample: 220 | sample["mask"] = sample["mask"].astype(np.float32) 221 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 222 | 223 | if "disparity" in sample: 224 | disparity = sample["disparity"].astype(np.float32) 225 | sample["disparity"] = np.ascontiguousarray(disparity) 226 | 227 | if "depth" in sample: 228 | depth = sample["depth"].astype(np.float32) 229 | sample["depth"] = np.ascontiguousarray(depth) 230 | 231 | return sample 232 | -------------------------------------------------------------------------------- /DPT/dpt/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | activations = {} 10 | 11 | 12 | def get_activation(name): 13 | def hook(model, input, output): 14 | activations[name] = output 15 | 16 | return hook 17 | 18 | 19 | attention = {} 20 | 21 | 22 | def get_attention(name): 23 | def hook(module, input, output): 24 | x = input[0] 25 | B, N, C = x.shape 26 | qkv = ( 27 | module.qkv(x) 28 | .reshape(B, N, 3, module.num_heads, C // module.num_heads) 29 | .permute(2, 0, 3, 1, 4) 30 | ) 31 | q, k, v = ( 32 | qkv[0], 33 | qkv[1], 34 | qkv[2], 35 | ) # make torchscript happy (cannot use tensor as tuple) 36 | 37 | attn = (q @ k.transpose(-2, -1)) * module.scale 38 | 39 | attn = attn.softmax(dim=-1) # [:,:,1,1:] 40 | attention[name] = attn 41 | 42 | return hook 43 | 44 | 45 | def get_mean_attention_map(attn, token, shape): 46 | attn = attn[:, :, token, 1:] 47 | attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float() 48 | attn = torch.nn.functional.interpolate( 49 | attn, size=shape[2:], mode="bicubic", align_corners=False 50 | ).squeeze(0) 51 | 52 | all_attn = torch.mean(attn, 0) 53 | 54 | return all_attn 55 | 56 | 57 | class Slice(nn.Module): 58 | def __init__(self, 
start_index=1): 59 | super(Slice, self).__init__() 60 | self.start_index = start_index 61 | 62 | def forward(self, x): 63 | return x[:, self.start_index :] 64 | 65 | 66 | class AddReadout(nn.Module): 67 | def __init__(self, start_index=1): 68 | super(AddReadout, self).__init__() 69 | self.start_index = start_index 70 | 71 | def forward(self, x): 72 | if self.start_index == 2: 73 | readout = (x[:, 0] + x[:, 1]) / 2 74 | else: 75 | readout = x[:, 0] 76 | return x[:, self.start_index :] + readout.unsqueeze(1) 77 | 78 | 79 | class ProjectReadout(nn.Module): 80 | def __init__(self, in_features, start_index=1): 81 | super(ProjectReadout, self).__init__() 82 | self.start_index = start_index 83 | 84 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 85 | 86 | def forward(self, x): 87 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 88 | features = torch.cat((x[:, self.start_index :], readout), -1) 89 | 90 | return self.project(features) 91 | 92 | 93 | class Transpose(nn.Module): 94 | def __init__(self, dim0, dim1): 95 | super(Transpose, self).__init__() 96 | self.dim0 = dim0 97 | self.dim1 = dim1 98 | 99 | def forward(self, x): 100 | x = x.transpose(self.dim0, self.dim1) 101 | return x 102 | 103 | 104 | def forward_vit(pretrained, x): 105 | b, c, h, w = x.shape 106 | 107 | glob = pretrained.model.forward_flex(x) 108 | 109 | layer_1 = pretrained.activations["1"] 110 | layer_2 = pretrained.activations["2"] 111 | layer_3 = pretrained.activations["3"] 112 | layer_4 = pretrained.activations["4"] 113 | 114 | layer_1 = pretrained.act_postprocess1[0:2](layer_1) 115 | layer_2 = pretrained.act_postprocess2[0:2](layer_2) 116 | layer_3 = pretrained.act_postprocess3[0:2](layer_3) 117 | layer_4 = pretrained.act_postprocess4[0:2](layer_4) 118 | 119 | unflatten = nn.Sequential( 120 | nn.Unflatten( 121 | 2, 122 | torch.Size( 123 | [ 124 | h // pretrained.model.patch_size[1], 125 | w // pretrained.model.patch_size[0], 126 | ] 127 | ), 128 | ) 129 | ) 130 | 131 | if layer_1.ndim == 3: 132 | layer_1 = unflatten(layer_1) 133 | if layer_2.ndim == 3: 134 | layer_2 = unflatten(layer_2) 135 | if layer_3.ndim == 3: 136 | layer_3 = unflatten(layer_3) 137 | if layer_4.ndim == 3: 138 | layer_4 = unflatten(layer_4) 139 | 140 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) 141 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) 142 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) 143 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) 144 | 145 | return layer_1, layer_2, layer_3, layer_4 146 | 147 | 148 | def _resize_pos_embed(self, posemb, gs_h, gs_w): 149 | posemb_tok, posemb_grid = ( 150 | posemb[:, : self.start_index], 151 | posemb[0, self.start_index :], 152 | ) 153 | 154 | gs_old = int(math.sqrt(len(posemb_grid))) 155 | 156 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 157 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 158 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 159 | 160 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 161 | 162 | return posemb 163 | 164 | 165 | def forward_flex(self, x): 166 | b, c, h, w = x.shape 167 | 168 | pos_embed = self._resize_pos_embed( 169 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] 170 | ) 171 | 172 | B = x.shape[0] 173 | 174 | if hasattr(self.patch_embed, "backbone"): 
175 | x = self.patch_embed.backbone(x) 176 | if isinstance(x, (list, tuple)): 177 | x = x[-1] # last feature if backbone outputs list/tuple of features 178 | 179 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) 180 | 181 | if getattr(self, "dist_token", None) is not None: 182 | cls_tokens = self.cls_token.expand( 183 | B, -1, -1 184 | ) # stole cls_tokens impl from Phil Wang, thanks 185 | dist_token = self.dist_token.expand(B, -1, -1) 186 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 187 | else: 188 | cls_tokens = self.cls_token.expand( 189 | B, -1, -1 190 | ) # stole cls_tokens impl from Phil Wang, thanks 191 | x = torch.cat((cls_tokens, x), dim=1) 192 | 193 | x = x + pos_embed 194 | x = self.pos_drop(x) 195 | 196 | for blk in self.blocks: 197 | x = blk(x) 198 | 199 | x = self.norm(x) 200 | 201 | return x 202 | 203 | 204 | def get_readout_oper(vit_features, features, use_readout, start_index=1): 205 | if use_readout == "ignore": 206 | readout_oper = [Slice(start_index)] * len(features) 207 | elif use_readout == "add": 208 | readout_oper = [AddReadout(start_index)] * len(features) 209 | elif use_readout == "project": 210 | readout_oper = [ 211 | ProjectReadout(vit_features, start_index) for out_feat in features 212 | ] 213 | else: 214 | assert ( 215 | False 216 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" 217 | 218 | return readout_oper 219 | 220 | 221 | def _make_vit_b16_backbone( 222 | model, 223 | features=[96, 192, 384, 768], 224 | size=[384, 384], 225 | hooks=[2, 5, 8, 11], 226 | vit_features=768, 227 | use_readout="ignore", 228 | start_index=1, 229 | enable_attention_hooks=False, 230 | ): 231 | pretrained = nn.Module() 232 | 233 | pretrained.model = model 234 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 235 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 236 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 237 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 238 | 239 | pretrained.activations = activations 240 | 241 | if enable_attention_hooks: 242 | pretrained.model.blocks[hooks[0]].attn.register_forward_hook( 243 | get_attention("attn_1") 244 | ) 245 | pretrained.model.blocks[hooks[1]].attn.register_forward_hook( 246 | get_attention("attn_2") 247 | ) 248 | pretrained.model.blocks[hooks[2]].attn.register_forward_hook( 249 | get_attention("attn_3") 250 | ) 251 | pretrained.model.blocks[hooks[3]].attn.register_forward_hook( 252 | get_attention("attn_4") 253 | ) 254 | pretrained.attention = attention 255 | 256 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 257 | 258 | # 32, 48, 136, 384 259 | pretrained.act_postprocess1 = nn.Sequential( 260 | readout_oper[0], 261 | Transpose(1, 2), 262 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 263 | nn.Conv2d( 264 | in_channels=vit_features, 265 | out_channels=features[0], 266 | kernel_size=1, 267 | stride=1, 268 | padding=0, 269 | ), 270 | nn.ConvTranspose2d( 271 | in_channels=features[0], 272 | out_channels=features[0], 273 | kernel_size=4, 274 | stride=4, 275 | padding=0, 276 | bias=True, 277 | dilation=1, 278 | groups=1, 279 | ), 280 | ) 281 | 282 | pretrained.act_postprocess2 = nn.Sequential( 283 | readout_oper[1], 284 | Transpose(1, 2), 285 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 286 | nn.Conv2d( 287 | in_channels=vit_features, 288 | out_channels=features[1], 289 | kernel_size=1, 290 
| stride=1, 291 | padding=0, 292 | ), 293 | nn.ConvTranspose2d( 294 | in_channels=features[1], 295 | out_channels=features[1], 296 | kernel_size=2, 297 | stride=2, 298 | padding=0, 299 | bias=True, 300 | dilation=1, 301 | groups=1, 302 | ), 303 | ) 304 | 305 | pretrained.act_postprocess3 = nn.Sequential( 306 | readout_oper[2], 307 | Transpose(1, 2), 308 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 309 | nn.Conv2d( 310 | in_channels=vit_features, 311 | out_channels=features[2], 312 | kernel_size=1, 313 | stride=1, 314 | padding=0, 315 | ), 316 | ) 317 | 318 | pretrained.act_postprocess4 = nn.Sequential( 319 | readout_oper[3], 320 | Transpose(1, 2), 321 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 322 | nn.Conv2d( 323 | in_channels=vit_features, 324 | out_channels=features[3], 325 | kernel_size=1, 326 | stride=1, 327 | padding=0, 328 | ), 329 | nn.Conv2d( 330 | in_channels=features[3], 331 | out_channels=features[3], 332 | kernel_size=3, 333 | stride=2, 334 | padding=1, 335 | ), 336 | ) 337 | 338 | pretrained.model.start_index = start_index 339 | pretrained.model.patch_size = [16, 16] 340 | 341 | # We inject this function into the VisionTransformer instances so that 342 | # we can use it with interpolated position embeddings without modifying the library source. 343 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 344 | pretrained.model._resize_pos_embed = types.MethodType( 345 | _resize_pos_embed, pretrained.model 346 | ) 347 | 348 | return pretrained 349 | 350 | 351 | def _make_vit_b_rn50_backbone( 352 | model, 353 | features=[256, 512, 768, 768], 354 | size=[384, 384], 355 | hooks=[0, 1, 8, 11], 356 | vit_features=768, 357 | use_vit_only=False, 358 | use_readout="ignore", 359 | start_index=1, 360 | enable_attention_hooks=False, 361 | ): 362 | pretrained = nn.Module() 363 | 364 | pretrained.model = model 365 | 366 | if use_vit_only == True: 367 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 368 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 369 | else: 370 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( 371 | get_activation("1") 372 | ) 373 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( 374 | get_activation("2") 375 | ) 376 | 377 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 378 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 379 | 380 | if enable_attention_hooks: 381 | pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1")) 382 | pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2")) 383 | pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3")) 384 | pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4")) 385 | pretrained.attention = attention 386 | 387 | pretrained.activations = activations 388 | 389 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 390 | 391 | if use_vit_only == True: 392 | pretrained.act_postprocess1 = nn.Sequential( 393 | readout_oper[0], 394 | Transpose(1, 2), 395 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 396 | nn.Conv2d( 397 | in_channels=vit_features, 398 | out_channels=features[0], 399 | kernel_size=1, 400 | stride=1, 401 | padding=0, 402 | ), 403 | nn.ConvTranspose2d( 404 | in_channels=features[0], 405 | out_channels=features[0], 406 | kernel_size=4, 407 | stride=4, 408 | 
padding=0, 409 | bias=True, 410 | dilation=1, 411 | groups=1, 412 | ), 413 | ) 414 | 415 | pretrained.act_postprocess2 = nn.Sequential( 416 | readout_oper[1], 417 | Transpose(1, 2), 418 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 419 | nn.Conv2d( 420 | in_channels=vit_features, 421 | out_channels=features[1], 422 | kernel_size=1, 423 | stride=1, 424 | padding=0, 425 | ), 426 | nn.ConvTranspose2d( 427 | in_channels=features[1], 428 | out_channels=features[1], 429 | kernel_size=2, 430 | stride=2, 431 | padding=0, 432 | bias=True, 433 | dilation=1, 434 | groups=1, 435 | ), 436 | ) 437 | else: 438 | pretrained.act_postprocess1 = nn.Sequential( 439 | nn.Identity(), nn.Identity(), nn.Identity() 440 | ) 441 | pretrained.act_postprocess2 = nn.Sequential( 442 | nn.Identity(), nn.Identity(), nn.Identity() 443 | ) 444 | 445 | pretrained.act_postprocess3 = nn.Sequential( 446 | readout_oper[2], 447 | Transpose(1, 2), 448 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 449 | nn.Conv2d( 450 | in_channels=vit_features, 451 | out_channels=features[2], 452 | kernel_size=1, 453 | stride=1, 454 | padding=0, 455 | ), 456 | ) 457 | 458 | pretrained.act_postprocess4 = nn.Sequential( 459 | readout_oper[3], 460 | Transpose(1, 2), 461 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 462 | nn.Conv2d( 463 | in_channels=vit_features, 464 | out_channels=features[3], 465 | kernel_size=1, 466 | stride=1, 467 | padding=0, 468 | ), 469 | nn.Conv2d( 470 | in_channels=features[3], 471 | out_channels=features[3], 472 | kernel_size=3, 473 | stride=2, 474 | padding=1, 475 | ), 476 | ) 477 | 478 | pretrained.model.start_index = start_index 479 | pretrained.model.patch_size = [16, 16] 480 | 481 | # We inject this function into the VisionTransformer instances so that 482 | # we can use it with interpolated position embeddings without modifying the library source. 483 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 484 | 485 | # We inject this function into the VisionTransformer instances so that 486 | # we can use it with interpolated position embeddings without modifying the library source. 
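    # A minimal sketch of the same injection pattern, for clarity (hypothetical
    # names, not part of DPT): types.MethodType binds a plain function to one
    # specific instance, so only that object gains the new method and the timm
    # VisionTransformer class itself is left unmodified.
    #
    #   import types
    #
    #   class Holder:                          # hypothetical stand-in class
    #       pass
    #
    #   def describe(self):                    # hypothetical helper function
    #       return "I am %r" % (self,)
    #
    #   obj = Holder()
    #   obj.describe = types.MethodType(describe, obj)  # bound to this instance only
    #   obj.describe()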
487 | pretrained.model._resize_pos_embed = types.MethodType( 488 | _resize_pos_embed, pretrained.model 489 | ) 490 | 491 | return pretrained 492 | 493 | 494 | def _make_pretrained_vitb_rn50_384( 495 | pretrained, 496 | use_readout="ignore", 497 | hooks=None, 498 | use_vit_only=False, 499 | enable_attention_hooks=False, 500 | ): 501 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) 502 | 503 | hooks = [0, 1, 8, 11] if hooks == None else hooks 504 | return _make_vit_b_rn50_backbone( 505 | model, 506 | features=[256, 512, 768, 768], 507 | size=[384, 384], 508 | hooks=hooks, 509 | use_vit_only=use_vit_only, 510 | use_readout=use_readout, 511 | enable_attention_hooks=enable_attention_hooks, 512 | ) 513 | 514 | 515 | def _make_pretrained_vitl16_384( 516 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 517 | ): 518 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) 519 | 520 | hooks = [5, 11, 17, 23] if hooks == None else hooks 521 | return _make_vit_b16_backbone( 522 | model, 523 | features=[256, 512, 1024, 1024], 524 | hooks=hooks, 525 | vit_features=1024, 526 | use_readout=use_readout, 527 | enable_attention_hooks=enable_attention_hooks, 528 | ) 529 | 530 | 531 | def _make_pretrained_vitb16_384( 532 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 533 | ): 534 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) 535 | 536 | hooks = [2, 5, 8, 11] if hooks == None else hooks 537 | return _make_vit_b16_backbone( 538 | model, 539 | features=[96, 192, 384, 768], 540 | hooks=hooks, 541 | use_readout=use_readout, 542 | enable_attention_hooks=enable_attention_hooks, 543 | ) 544 | 545 | 546 | def _make_pretrained_deitb16_384( 547 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 548 | ): 549 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) 550 | 551 | hooks = [2, 5, 8, 11] if hooks == None else hooks 552 | return _make_vit_b16_backbone( 553 | model, 554 | features=[96, 192, 384, 768], 555 | hooks=hooks, 556 | use_readout=use_readout, 557 | enable_attention_hooks=enable_attention_hooks, 558 | ) 559 | 560 | 561 | def _make_pretrained_deitb16_distil_384( 562 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 563 | ): 564 | model = timm.create_model( 565 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained 566 | ) 567 | 568 | hooks = [2, 5, 8, 11] if hooks == None else hooks 569 | return _make_vit_b16_backbone( 570 | model, 571 | features=[96, 192, 384, 768], 572 | hooks=hooks, 573 | use_readout=use_readout, 574 | start_index=2, 575 | enable_attention_hooks=enable_attention_hooks, 576 | ) 577 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 3.0 Unported (CC BY-NC-SA 3.0) 2 | 3 | THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED. 4 | 5 | BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. 
TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS. 6 | 7 | 1. Definitions 8 | "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License. 9 | "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(f) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined below) for the purposes of this License. 10 | "Creative Commons Compatible License" means a license that is listed at http://creativecommons.org/compatiblelicenses that has been approved by Creative Commons as being essentially equivalent to this License, including, at a minimum, because that license: (i) contains terms that have the same purpose, meaning and effect as the License Elements of this License; and, (ii) explicitly permits the relicensing of adaptations of works made available under that license under this License or a Creative Commons jurisdiction license with the same License Elements as this License. 11 | "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership. 12 | "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, ShareAlike. 13 | "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License. 14 | "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast. 
15 | "Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work. 16 | "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation. 17 | "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images. 18 | "Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium. 19 | 20 | 2. Fair Dealing Rights 21 | Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws. 22 | 23 | 3. License Grant 24 | Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below: 25 | to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections; 26 | to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. 
For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified."; to Distribute and Publicly Perform the Work including as incorporated in Collections; and, to Distribute and Publicly Perform Adaptations. 27 | For the avoidance of doubt: 28 | Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; 29 | Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor waives the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; and, 30 | Voluntary License Schemes. The Licensor waives the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License. 31 | The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved. 32 | 33 | 4. Restrictions 34 | The license granted in Section 3 above is expressly made subject to and limited by the following restrictions: 35 | You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(c), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(c), as requested. 
36 | You may Distribute or Publicly Perform an Adaptation only under the terms of: (i) this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-ShareAlike 3.0 US)); (iv) a Creative Commons Compatible License. If you license the Adaptation under one of the licenses mentioned in (iv), you must comply with the terms of that license. If you license the Adaptation under the terms of any of the licenses mentioned in (i), (ii) or (iii) (the "Applicable License"), you must comply with the terms of the Applicable License generally and the following provisions: (I) You must include a copy of, or the URI for, the Applicable License with every copy of each Adaptation You Distribute or Publicly Perform; (II) You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License; (III) You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform; (IV) when You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License. 37 | If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and (iv) , consistent with Section 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(c) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. 
For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties. 38 | Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise. 39 | 40 | 5. Representations, Warranties and Disclaimer 41 | UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU. 42 | 43 | 6. Limitation on Liability 44 | EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 45 | 46 | 7. Termination 47 | This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License. 48 | Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above. 49 | 50 | 8. 
Miscellaneous 51 | Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License. 52 | Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License. 53 | If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. 54 | No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent. 55 | This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You. 56 | The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law. 
57 | -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RAFT/alt_cuda_corr/__init__.py -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | std::vector corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/correlation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | #define BLOCK_H 4 8 | #define BLOCK_W 8 9 | #define BLOCK_HW BLOCK_H * BLOCK_W 10 | #define CHANNEL_STRIDE 32 11 | 12 | 13 | __forceinline__ __device__ 14 | bool within_bounds(int h, int w, int H, int W) { 15 | return h >= 0 && h < H && w >= 0 && w < W; 16 | } 17 | 18 | template 19 | __global__ void corr_forward_kernel( 20 | const torch::PackedTensorAccessor32 fmap1, 21 | const torch::PackedTensorAccessor32 fmap2, 22 | const torch::PackedTensorAccessor32 coords, 23 | torch::PackedTensorAccessor32 corr, 24 | int r) 25 | { 26 | const int b = blockIdx.x; 27 | const int h0 = blockIdx.y * blockDim.x; 28 | const int w0 = blockIdx.z * blockDim.y; 29 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 30 | 31 | const int H1 = fmap1.size(1); 32 | const int W1 = fmap1.size(2); 33 | const int H2 = fmap2.size(1); 34 | const int W2 = fmap2.size(2); 35 | const int N = coords.size(1); 36 | const int C = fmap1.size(3); 37 | 38 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 39 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 40 | __shared__ scalar_t x2s[BLOCK_HW]; 41 | __shared__ scalar_t y2s[BLOCK_HW]; 42 | 43 | for (int c=0; c(floor(y2s[k1]))-r+iy; 76 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 77 | int c2 = tid % CHANNEL_STRIDE; 78 | 79 | auto fptr = fmap2[b][h2][w2]; 80 | if 
(within_bounds(h2, w2, H2, W2)) 81 | f2[c2][k1] = fptr[c+c2]; 82 | else 83 | f2[c2][k1] = 0.0; 84 | } 85 | 86 | __syncthreads(); 87 | 88 | scalar_t s = 0.0; 89 | for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 105 | *(corr_ptr + ix_nw) += nw; 106 | 107 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) 108 | *(corr_ptr + ix_ne) += ne; 109 | 110 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 111 | *(corr_ptr + ix_sw) += sw; 112 | 113 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) 114 | *(corr_ptr + ix_se) += se; 115 | } 116 | } 117 | } 118 | } 119 | } 120 | 121 | 122 | template 123 | __global__ void corr_backward_kernel( 124 | const torch::PackedTensorAccessor32 fmap1, 125 | const torch::PackedTensorAccessor32 fmap2, 126 | const torch::PackedTensorAccessor32 coords, 127 | const torch::PackedTensorAccessor32 corr_grad, 128 | torch::PackedTensorAccessor32 fmap1_grad, 129 | torch::PackedTensorAccessor32 fmap2_grad, 130 | torch::PackedTensorAccessor32 coords_grad, 131 | int r) 132 | { 133 | 134 | const int b = blockIdx.x; 135 | const int h0 = blockIdx.y * blockDim.x; 136 | const int w0 = blockIdx.z * blockDim.y; 137 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 138 | 139 | const int H1 = fmap1.size(1); 140 | const int W1 = fmap1.size(2); 141 | const int H2 = fmap2.size(1); 142 | const int W2 = fmap2.size(2); 143 | const int N = coords.size(1); 144 | const int C = fmap1.size(3); 145 | 146 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 147 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 148 | 149 | __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 150 | __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 151 | 152 | __shared__ scalar_t x2s[BLOCK_HW]; 153 | __shared__ scalar_t y2s[BLOCK_HW]; 154 | 155 | for (int c=0; c(floor(y2s[k1]))-r+iy; 190 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 191 | int c2 = tid % CHANNEL_STRIDE; 192 | 193 | auto fptr = fmap2[b][h2][w2]; 194 | if (within_bounds(h2, w2, H2, W2)) 195 | f2[c2][k1] = fptr[c+c2]; 196 | else 197 | f2[c2][k1] = 0.0; 198 | 199 | f2_grad[c2][k1] = 0.0; 200 | } 201 | 202 | __syncthreads(); 203 | 204 | const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1]; 205 | scalar_t g = 0.0; 206 | 207 | int ix_nw = H1*W1*((iy-1) + rd*(ix-1)); 208 | int ix_ne = H1*W1*((iy-1) + rd*ix); 209 | int ix_sw = H1*W1*(iy + rd*(ix-1)); 210 | int ix_se = H1*W1*(iy + rd*ix); 211 | 212 | if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 213 | g += *(grad_ptr + ix_nw) * dy * dx; 214 | 215 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) 216 | g += *(grad_ptr + ix_ne) * dy * (1-dx); 217 | 218 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 219 | g += *(grad_ptr + ix_sw) * (1-dy) * dx; 220 | 221 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) 222 | g += *(grad_ptr + ix_se) * (1-dy) * (1-dx); 223 | 224 | for (int k=0; k(floor(y2s[k1]))-r+iy; 232 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 233 | int c2 = tid % CHANNEL_STRIDE; 234 | 235 | scalar_t* fptr = &fmap2_grad[b][h2][w2][0]; 236 | if (within_bounds(h2, w2, H2, W2)) 237 | atomicAdd(fptr+c+c2, f2_grad[c2][k1]); 238 | } 239 | } 240 | } 241 | } 242 | __syncthreads(); 243 | 244 | 245 | for (int k=0; k corr_cuda_forward( 261 | torch::Tensor fmap1, 262 | torch::Tensor fmap2, 263 | torch::Tensor coords, 264 | int radius) 265 | { 266 | const auto B = coords.size(0); 267 | const auto N = coords.size(1); 268 | const auto H = coords.size(2); 269 | const auto W = coords.size(3); 270 | 271 | const auto rd = 2 * radius 
+ 1; 272 | auto opts = fmap1.options(); 273 | auto corr = torch::zeros({B, N, rd*rd, H, W}, opts); 274 | 275 | const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W); 276 | const dim3 threads(BLOCK_H, BLOCK_W); 277 | 278 | corr_forward_kernel<<>>( 279 | fmap1.packed_accessor32(), 280 | fmap2.packed_accessor32(), 281 | coords.packed_accessor32(), 282 | corr.packed_accessor32(), 283 | radius); 284 | 285 | return {corr}; 286 | } 287 | 288 | std::vector corr_cuda_backward( 289 | torch::Tensor fmap1, 290 | torch::Tensor fmap2, 291 | torch::Tensor coords, 292 | torch::Tensor corr_grad, 293 | int radius) 294 | { 295 | const auto B = coords.size(0); 296 | const auto N = coords.size(1); 297 | 298 | const auto H1 = fmap1.size(1); 299 | const auto W1 = fmap1.size(2); 300 | const auto H2 = fmap2.size(1); 301 | const auto W2 = fmap2.size(2); 302 | const auto C = fmap1.size(3); 303 | 304 | auto opts = fmap1.options(); 305 | auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts); 306 | auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts); 307 | auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts); 308 | 309 | const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W); 310 | const dim3 threads(BLOCK_H, BLOCK_W); 311 | 312 | 313 | corr_backward_kernel<<>>( 314 | fmap1.packed_accessor32(), 315 | fmap2.packed_accessor32(), 316 | coords.packed_accessor32(), 317 | corr_grad.packed_accessor32(), 318 | fmap1_grad.packed_accessor32(), 319 | fmap2_grad.packed_accessor32(), 320 | coords_grad.packed_accessor32(), 321 | radius); 322 | 323 | return {fmap1_grad, fmap2_grad, coords_grad}; 324 | } -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /RAFT/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RAFT/core/__init__.py -------------------------------------------------------------------------------- /RAFT/core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from RAFT.core.utils.utils import bilinear_sampler, coords_grid 4 | from torch import nn 5 | 6 | try: 7 | # import alt_cuda_corr 8 | from RAFT.alt_cuda_corr import alt_cuda_corr 9 | except: 10 | # alt_cuda_corr is not compiled 11 | # raise ValueError('alt_cuda_corr is not compiled') 12 | print('alt_cuda_corr is not compiled') 13 | pass 14 | 15 | 16 | class CorrBlock: 17 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 18 | self.num_levels = num_levels 19 | self.radius = radius 20 | self.corr_pyramid = [] 21 | 22 | # all pairs correlation 23 | corr = CorrBlock.corr(fmap1, fmap2) 24 | 25 | batch, h1, w1, dim, h2, w2 = corr.shape 26 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 27 | 28 | self.corr_pyramid.append(corr) 29 | for i in range(self.num_levels - 1): 30 | corr = 
F.avg_pool2d(corr, 2, stride=2) 31 | self.corr_pyramid.append(corr) 32 | 33 | def __call__(self, coords): 34 | r = self.radius 35 | coords = coords.permute(0, 2, 3, 1) 36 | batch, h1, w1, _ = coords.shape 37 | 38 | out_pyramid = [] 39 | for i in range(self.num_levels): 40 | corr = self.corr_pyramid[i] 41 | dx = torch.linspace(-r, r, 2 * r + 1) 42 | dy = torch.linspace(-r, r, 2 * r + 1) 43 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 44 | 45 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 46 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 47 | coords_lvl = centroid_lvl + delta_lvl 48 | 49 | corr = bilinear_sampler(corr, coords_lvl) 50 | corr = corr.view(batch, h1, w1, -1) 51 | out_pyramid.append(corr) 52 | 53 | out = torch.cat(out_pyramid, dim=-1) 54 | return out.permute(0, 3, 1, 2).contiguous().float() 55 | 56 | @staticmethod 57 | def corr(fmap1, fmap2): 58 | batch, dim, ht, wd = fmap1.shape 59 | fmap1 = fmap1.view(batch, dim, ht * wd) 60 | fmap2 = fmap2.view(batch, dim, ht * wd) 61 | 62 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 63 | corr = corr.view(batch, ht, wd, 1, ht, wd) 64 | return corr / torch.sqrt(torch.tensor(dim).float()) 65 | 66 | 67 | class AlternateCorrBlock: 68 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 69 | self.num_levels = num_levels 70 | self.radius = radius 71 | 72 | self.pyramid = [(fmap1, fmap2)] 73 | for i in range(self.num_levels): 74 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 75 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 76 | self.pyramid.append((fmap1, fmap2)) 77 | 78 | def __call__(self, coords): 79 | coords = coords.permute(0, 2, 3, 1) 80 | B, H, W, _ = coords.shape 81 | dim = self.pyramid[0][0].shape[1] 82 | 83 | corr_list = [] 84 | for i in range(self.num_levels): 85 | r = self.radius 86 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 87 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 88 | 89 | coords_i = (coords / 2 ** i).reshape(B, 1, H, W, 2).contiguous() 90 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 91 | corr_list.append(corr.squeeze(1)) 92 | 93 | corr = torch.stack(corr_list, dim=1) 94 | corr = corr.reshape(B, -1, H, W) 95 | return corr / torch.sqrt(torch.tensor(dim).float()) 96 | 97 | 98 | class CorrBlock_pyramid: 99 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 100 | self.num_levels = num_levels 101 | self.radius = radius 102 | self.corr_pyramid = [] 103 | 104 | fmap1 = fmap1[::-1] 105 | fmap2 = fmap2[::-1] 106 | 107 | corr = CorrBlock.corr(fmap1[3], fmap2[3]) 108 | batch, h1, w1, dim, h2, w2 = corr.shape 109 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 110 | self.corr_pyramid.append(corr) 111 | 112 | # all pairs correlation 113 | for i in range(self.num_levels-1): 114 | corr = CorrBlock.corr(fmap1[i + 4], fmap2[i + 4]) 115 | batch, h1, w1, dim, h2, w2 = corr.shape 116 | corr = corr.permute(0, 3, 4, 5, 1, 2).contiguous() 117 | corr = corr.reshape(batch, dim*h2*w2, h1, w1) 118 | corr = nn.Upsample(scale_factor=2**(i+1), mode='bilinear')(corr).reshape(batch, dim, h2, w2, h1*2**(i+1), w1*2**(i+1)) 119 | corr = corr.permute(0, 4, 5, 1, 2, 3).contiguous() 120 | corr = corr.reshape(batch * h1 * w1 * 4**(i+1), dim, h2, w2) 121 | self.corr_pyramid.append(corr) 122 | 123 | def __call__(self, coords): 124 | r = self.radius 125 | coords = coords.permute(0, 2, 3, 1) 126 | batch, h1, w1, _ = coords.shape 127 | 128 | out_pyramid = [] 129 | for i in range(self.num_levels): 130 | corr = self.corr_pyramid[i] 131 | dx = 
torch.linspace(-r, r, 2 * r + 1) 132 | dy = torch.linspace(-r, r, 2 * r + 1) 133 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 134 | 135 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 136 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 137 | coords_lvl = centroid_lvl + delta_lvl 138 | 139 | corr = bilinear_sampler(corr, coords_lvl) 140 | corr = corr.view(batch, h1, w1, -1) 141 | out_pyramid.append(corr) 142 | 143 | out = torch.cat(out_pyramid, dim=-1) 144 | return out.permute(0, 3, 1, 2).contiguous().float() 145 | 146 | @staticmethod 147 | def corr(fmap1, fmap2): 148 | batch, dim, ht, wd = fmap1.shape 149 | fmap1 = fmap1.view(batch, dim, ht * wd) 150 | fmap2 = fmap2.view(batch, dim, ht * wd) 151 | 152 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 153 | corr = corr.view(batch, ht, wd, 1, ht, wd) 154 | return corr / torch.sqrt(torch.tensor(dim).float()) -------------------------------------------------------------------------------- /RAFT/core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import os 7 | import random 8 | from glob import glob 9 | import os.path as osp 10 | import cv2 11 | from RAFT.core.utils import frame_utils 12 | from RAFT.core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor 13 | import math 14 | 15 | 16 | class FlowDataset(data.Dataset): 17 | def __init__(self, aug_params=None, sparse=False): 18 | self.augmentor = None 19 | self.sparse = sparse 20 | if aug_params is not None: 21 | if sparse: 22 | self.augmentor = SparseFlowAugmentor(**aug_params) 23 | else: 24 | self.augmentor = FlowAugmentor(**aug_params) 25 | 26 | self.is_test = False 27 | self.init_seed = False 28 | self.flow_list = [] 29 | self.image_list = [] 30 | self.extra_info = [] 31 | 32 | self.dataclean = None 33 | self.type = 'chairs' 34 | self.depth = False 35 | 36 | def __getitem__(self, index): 37 | 38 | if self.is_test: 39 | img1 = frame_utils.read_gen(self.image_list[index][0]) 40 | img2 = frame_utils.read_gen(self.image_list[index][1]) 41 | img1 = np.array(img1).astype(np.uint8)[..., :3] 42 | img2 = np.array(img2).astype(np.uint8)[..., :3] 43 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 44 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 45 | return img1, img2, self.extra_info[index] 46 | 47 | if not self.init_seed: 48 | worker_info = torch.utils.data.get_worker_info() 49 | if worker_info is not None: 50 | torch.manual_seed(worker_info.id) 51 | np.random.seed(worker_info.id) 52 | random.seed(worker_info.id) 53 | self.init_seed = True 54 | 55 | # things, 0706 56 | if isinstance(self.image_list[index], dict): 57 | sample = self.dataclean.getiterm(self.image_list[index]) 58 | img1, img2, flow, valid = sample['im1'], sample['im2'], sample['flow'], sample['valid'] 59 | # img1, img2 = sample['im1'], sample['im2'] 60 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 61 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 62 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 63 | valid = torch.from_numpy(valid).permute(2, 0, 1).float() 64 | return img1, img2, flow, valid 65 | # RAFT original 66 | else: 67 | index = index % len(self.image_list) 68 | valid = None 69 | if self.sparse: 70 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 71 | else: 72 | flow = 
frame_utils.read_gen(self.flow_list[index]) 73 | if self.depth == True: 74 | depth = frame_utils.depth_read(self.depth_list[index]) 75 | depth = np.array(depth).astype(np.float32) 76 | depth = torch.from_numpy(depth).unsqueeze(0).float() 77 | 78 | img1 = frame_utils.read_gen(self.image_list[index][0]) 79 | img2 = frame_utils.read_gen(self.image_list[index][1]) 80 | 81 | flow = np.array(flow).astype(np.float32) 82 | img1 = np.array(img1).astype(np.uint8) 83 | img2 = np.array(img2).astype(np.uint8) 84 | 85 | # grayscale images 86 | if len(img1.shape) == 2: 87 | img1 = np.tile(img1[..., None], (1, 1, 3)) 88 | img2 = np.tile(img2[..., None], (1, 1, 3)) 89 | else: 90 | img1 = img1[..., :3] 91 | img2 = img2[..., :3] 92 | 93 | if self.augmentor is not None: 94 | if self.sparse: 95 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 96 | else: 97 | img1, img2, flow = self.augmentor(img1, img2, flow) 98 | 99 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 100 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 101 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 102 | 103 | if valid is not None: 104 | valid = torch.from_numpy(valid) 105 | else: 106 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 107 | 108 | return img1, img2, flow, valid.float() 109 | 110 | def __rmul__(self, v): 111 | self.flow_list = v * self.flow_list 112 | self.image_list = v * self.image_list 113 | return self 114 | 115 | def __len__(self): 116 | return len(self.image_list) 117 | 118 | 119 | class MpiSintel(FlowDataset): 120 | def __init__(self, aug_params=None, split='training', root='/data/Sintel', dstype='clean'): 121 | super(MpiSintel, self).__init__(aug_params) 122 | flow_root = osp.join(root, split, 'flow') 123 | image_root = osp.join(root, split, dstype) 124 | 125 | if split == 'test': 126 | self.is_test = True 127 | for scene in sorted(os.listdir(image_root)): 128 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 129 | for i in range(len(image_list) - 1): 130 | self.image_list += [[image_list[i], image_list[i + 1]]] 131 | self.extra_info += [(scene, i)] # scene and frame_id 132 | 133 | if split != 'test': 134 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 135 | depth_root = osp.join('/data/Sintel_depth', split, 'depth') 136 | self.depth_list = [] 137 | for scene in os.listdir(image_root): 138 | self.depth_list += sorted(glob(osp.join(depth_root, scene, '*.dpt')))[:-1] 139 | self.depth = True 140 | 141 | 142 | class FlyingChairs(FlowDataset): 143 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 144 | super(FlyingChairs, self).__init__(aug_params) 145 | 146 | images = sorted(glob(osp.join(root, '*.ppm'))) 147 | flows = sorted(glob(osp.join(root, '*.flo'))) 148 | assert (len(images) // 2 == len(flows)) 149 | 150 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 151 | for i in range(len(flows)): 152 | xid = split_list[i] 153 | if (split == 'training' and xid == 1) or (split == 'validation' and xid == 2): 154 | self.flow_list += [flows[i]] 155 | self.image_list += [[images[2 * i], images[2 * i + 1]]] 156 | 157 | 158 | class FlyingThings3D(FlowDataset): 159 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): 160 | super(FlyingThings3D, self).__init__(aug_params) 161 | 162 | for cam in ['left']: 163 | for direction in ['into_future', 'into_past']: 164 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 165 | image_dirs = 
sorted([osp.join(f, cam) for f in image_dirs]) 166 | 167 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 168 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 169 | 170 | for idir, fdir in zip(image_dirs, flow_dirs): 171 | images = sorted(glob(osp.join(idir, '*.png'))) 172 | flows = sorted(glob(osp.join(fdir, '*.pfm'))) 173 | for i in range(len(flows) - 1): 174 | if direction == 'into_future': 175 | self.image_list += [[images[i], images[i + 1]]] 176 | self.flow_list += [flows[i]] 177 | elif direction == 'into_past': 178 | self.image_list += [[images[i + 1], images[i]]] 179 | self.flow_list += [flows[i + 1]] 180 | 181 | 182 | class KITTI(FlowDataset): 183 | def __init__(self, aug_params=None, split='training', root='/data/Optical_Flow_all/datasets/KITTI_data/data_scene_flow/'): 184 | super(KITTI, self).__init__(aug_params, sparse=True) 185 | if split == 'testing': 186 | self.is_test = True 187 | 188 | root = osp.join(root, split) 189 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 190 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 191 | 192 | for img1, img2 in zip(images1, images2): 193 | frame_id = img1.split('/')[-1] 194 | self.extra_info += [[frame_id]] 195 | self.image_list += [[img1, img2]] 196 | 197 | if split == 'training': 198 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 199 | 200 | 201 | class HD1K(FlowDataset): 202 | def __init__(self, aug_params=None, root='datasets/HD1k'): 203 | super(HD1K, self).__init__(aug_params, sparse=True) 204 | 205 | seq_ix = 0 206 | while 1: 207 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 208 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 209 | 210 | if len(flows) == 0: 211 | break 212 | 213 | for i in range(len(flows) - 1): 214 | self.flow_list += [flows[i]] 215 | self.image_list += [[images[i], images[i + 1]]] 216 | 217 | seq_ix += 1 218 | 219 | 220 | def fetch_dataloader(args, TRAIN_DS='C+T+K/S'): 221 | """ Create the data loader for the corresponding trainign set """ 222 | 223 | if args.stage == 'chairs': 224 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 225 | # train_dataset = FlyingChairs(aug_params, split='training') 226 | train_dataset = FlyingChairs_Nori(aug_params) 227 | elif args.stage == 'things': 228 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 229 | # clean_dataset = FlyingThings3D_Nori(aug_params, dstype='frames_cleanpass') 230 | final_dataset = FlyingThings3D_Nori(aug_params, dstype='frames_finalpass') 231 | train_dataset = final_dataset 232 | elif args.stage == 'sintel': 233 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 234 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 235 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 236 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 237 | 238 | if TRAIN_DS == 'C+T+K+S+H': 239 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 240 | train_dataset = 100 * sintel_clean + 100 * sintel_final + 200 * kitti + things 241 | 242 | elif TRAIN_DS == 'C+T+K/S': 243 | train_dataset = 100 * sintel_clean + 100 * sintel_final + things 244 | elif args.stage == 'kitti': 245 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 
'do_flip': False} 246 | train_dataset = KITTI(aug_params, split='training') 247 | else: 248 | raise ValueError('') 249 | 250 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 251 | pin_memory=True, shuffle=True, num_workers=8, drop_last=True) 252 | 253 | print('Training with %d image pairs' % len(train_dataset)) 254 | return train_loader 255 | 256 | 257 | 258 | -------------------------------------------------------------------------------- /RAFT/core/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | elif norm_fn == 'domain': 41 | self.norm1 = DomainNorm(planes) 42 | self.norm2 = DomainNorm(planes) 43 | if not stride == 1: 44 | self.norm3 = DomainNorm(planes) 45 | 46 | if stride == 1: 47 | self.downsample = None 48 | 49 | else: 50 | self.downsample = nn.Sequential( 51 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 52 | 53 | def forward(self, x): 54 | y = x 55 | y = self.relu(self.norm1(self.conv1(y))) 56 | y = self.relu(self.norm2(self.conv2(y))) 57 | 58 | if self.downsample is not None: 59 | x = self.downsample(x) 60 | 61 | return self.relu(x + y) 62 | 63 | 64 | class BottleneckBlock(nn.Module): 65 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 66 | super(BottleneckBlock, self).__init__() 67 | 68 | self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) 69 | self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) 70 | self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) 71 | self.relu = nn.ReLU(inplace=True) 72 | 73 | num_groups = planes // 8 74 | 75 | if norm_fn == 'group': 76 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) 77 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) 78 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 79 | if not stride == 1: 80 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 81 | 82 | elif norm_fn == 'batch': 83 | self.norm1 = nn.BatchNorm2d(planes // 4) 84 | self.norm2 = nn.BatchNorm2d(planes // 4) 85 | self.norm3 = nn.BatchNorm2d(planes) 86 | if not stride == 1: 87 | self.norm4 = 
nn.BatchNorm2d(planes) 88 | 89 | elif norm_fn == 'instance': 90 | self.norm1 = nn.InstanceNorm2d(planes // 4) 91 | self.norm2 = nn.InstanceNorm2d(planes // 4) 92 | self.norm3 = nn.InstanceNorm2d(planes) 93 | if not stride == 1: 94 | self.norm4 = nn.InstanceNorm2d(planes) 95 | 96 | elif norm_fn == 'none': 97 | self.norm1 = nn.Sequential() 98 | self.norm2 = nn.Sequential() 99 | self.norm3 = nn.Sequential() 100 | if not stride == 1: 101 | self.norm4 = nn.Sequential() 102 | 103 | if stride == 1: 104 | self.downsample = None 105 | 106 | else: 107 | self.downsample = nn.Sequential( 108 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 109 | 110 | def forward(self, x): 111 | y = x 112 | y = self.relu(self.norm1(self.conv1(y))) 113 | y = self.relu(self.norm2(self.conv2(y))) 114 | y = self.relu(self.norm3(self.conv3(y))) 115 | 116 | if self.downsample is not None: 117 | x = self.downsample(x) 118 | 119 | return self.relu(x + y) 120 | 121 | 122 | class BasicEncoder(nn.Module): 123 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, input_dim=3): 124 | super(BasicEncoder, self).__init__() 125 | self.norm_fn = norm_fn 126 | 127 | if self.norm_fn == 'group': 128 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 129 | 130 | elif self.norm_fn == 'batch': 131 | self.norm1 = nn.BatchNorm2d(64) 132 | 133 | elif self.norm_fn == 'instance': 134 | self.norm1 = nn.InstanceNorm2d(64) 135 | 136 | elif self.norm_fn == 'none': 137 | self.norm1 = nn.Sequential() 138 | 139 | elif self.norm_fn == 'domain': 140 | self.norm1 = DomainNorm(64) 141 | 142 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3) 143 | self.relu1 = nn.ReLU(inplace=True) 144 | 145 | self.in_planes = 64 146 | self.layer1 = self._make_layer(64, stride=1) 147 | self.layer2 = self._make_layer(96, stride=2) 148 | self.layer3 = self._make_layer(128, stride=2) 149 | 150 | # output convolution 151 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 152 | 153 | self.dropout = None 154 | if dropout > 0: 155 | self.dropout = nn.Dropout2d(p=dropout) 156 | 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 160 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 161 | if m.weight is not None: 162 | nn.init.constant_(m.weight, 1) 163 | if m.bias is not None: 164 | nn.init.constant_(m.bias, 0) 165 | 166 | def _make_layer(self, dim, stride=1): 167 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 168 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 169 | layers = (layer1, layer2) 170 | 171 | self.in_planes = dim 172 | return nn.Sequential(*layers) 173 | 174 | def forward(self, x): 175 | 176 | # if input is list, combine batch dimension 177 | is_list = isinstance(x, tuple) or isinstance(x, list) 178 | if is_list: 179 | batch_dim = x[0].shape[0] 180 | x = torch.cat(x, dim=0) 181 | 182 | x = self.conv1(x) 183 | x = self.norm1(x) 184 | x = self.relu1(x) 185 | 186 | x = self.layer1(x) 187 | x = self.layer2(x) 188 | x = self.layer3(x) 189 | 190 | x = self.conv2(x) 191 | 192 | if self.training and self.dropout is not None: 193 | x = self.dropout(x) 194 | 195 | if is_list: 196 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 197 | 198 | return x 199 | 200 | 201 | class SmallEncoder(nn.Module): 202 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 203 | super(SmallEncoder, self).__init__() 204 | self.norm_fn = norm_fn 205 | 206 | if 
self.norm_fn == 'group': 207 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 208 | 209 | elif self.norm_fn == 'batch': 210 | self.norm1 = nn.BatchNorm2d(32) 211 | 212 | elif self.norm_fn == 'instance': 213 | self.norm1 = nn.InstanceNorm2d(32) 214 | 215 | elif self.norm_fn == 'none': 216 | self.norm1 = nn.Sequential() 217 | 218 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 219 | self.relu1 = nn.ReLU(inplace=True) 220 | 221 | self.in_planes = 32 222 | self.layer1 = self._make_layer(32, stride=1) 223 | self.layer2 = self._make_layer(64, stride=2) 224 | self.layer3 = self._make_layer(96, stride=2) 225 | 226 | self.dropout = None 227 | if dropout > 0: 228 | self.dropout = nn.Dropout2d(p=dropout) 229 | 230 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 231 | 232 | for m in self.modules(): 233 | if isinstance(m, nn.Conv2d): 234 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 235 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 236 | if m.weight is not None: 237 | nn.init.constant_(m.weight, 1) 238 | if m.bias is not None: 239 | nn.init.constant_(m.bias, 0) 240 | 241 | def _make_layer(self, dim, stride=1): 242 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 243 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 244 | layers = (layer1, layer2) 245 | 246 | self.in_planes = dim 247 | return nn.Sequential(*layers) 248 | 249 | def forward(self, x): 250 | 251 | # if input is list, combine batch dimension 252 | is_list = isinstance(x, tuple) or isinstance(x, list) 253 | if is_list: 254 | batch_dim = x[0].shape[0] 255 | x = torch.cat(x, dim=0) 256 | 257 | x = self.conv1(x) 258 | x = self.norm1(x) 259 | x = self.relu1(x) 260 | 261 | x = self.layer1(x) 262 | x = self.layer2(x) 263 | x = self.layer3(x) 264 | x = self.conv2(x) 265 | 266 | if self.training and self.dropout is not None: 267 | x = self.dropout(x) 268 | 269 | if is_list: 270 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 271 | 272 | return x 273 | 274 | -------------------------------------------------------------------------------- /RAFT/core/raft.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from RAFT.core.update import BasicUpdateBlock, SmallUpdateBlock 8 | from RAFT.core.extractor import BasicEncoder, SmallEncoder 9 | from RAFT.core.corr import CorrBlock, AlternateCorrBlock, CorrBlock_pyramid 10 | from RAFT.core.utils.utils import bilinear_sampler, coords_grid, upflow8 11 | 12 | try: 13 | autocast = torch.cuda.amp.autocast 14 | except: 15 | # dummy autocast for PyTorch < 1.6 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | 20 | def __enter__(self): 21 | pass 22 | 23 | def __exit__(self, *args): 24 | pass 25 | 26 | 27 | class RAFT(nn.Module): 28 | def __init__(self, args): 29 | super(RAFT, self).__init__() 30 | self.args = args 31 | 32 | if args.small: 33 | self.hidden_dim = hdim = 96 34 | self.context_dim = cdim = 64 35 | args.corr_levels = 4 36 | args.corr_radius = 3 37 | 38 | else: 39 | self.hidden_dim = hdim = 128 40 | self.context_dim = cdim = 128 41 | args.corr_levels = 4 42 | args.corr_radius = 4 43 | 44 | if 'dropout' not in self.args: 45 | self.args.dropout = 0 46 | 47 | if 'alternate_corr' not in self.args: 48 | self.args.alternate_corr = False 49 | 50 | # feature network, context network, and update block 51 | if 
args.small: 52 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 53 | self.cnet = SmallEncoder(output_dim=hdim + cdim, norm_fn='none', dropout=args.dropout) 54 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 55 | 56 | else: 57 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 58 | self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='batch', dropout=args.dropout) 59 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 60 | 61 | def freeze_bn(self): 62 | for m in self.modules(): 63 | if isinstance(m, nn.BatchNorm2d): 64 | m.eval() 65 | 66 | def initialize_flow(self, img): 67 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 68 | N, C, H, W = img.shape 69 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device) 70 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 71 | 72 | # optical flow computed as difference: flow = coords1 - coords0 73 | return coords0, coords1 74 | 75 | def upsample_flow(self, flow, mask): 76 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 77 | N, _, H, W = flow.shape 78 | mask = mask.view(N, 1, 9, 8, 8, H, W) 79 | mask = torch.softmax(mask, dim=2) 80 | 81 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 82 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 83 | 84 | up_flow = torch.sum(mask * up_flow, dim=2) 85 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 86 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 87 | 88 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 89 | """ Estimate optical flow between pair of frames """ 90 | 91 | image1 = 2 * (image1 / 255.0) - 1.0 92 | image2 = 2 * (image2 / 255.0) - 1.0 93 | 94 | image1 = image1.contiguous() 95 | image2 = image2.contiguous() 96 | 97 | hdim = self.hidden_dim 98 | cdim = self.context_dim 99 | 100 | # run the feature network 101 | with autocast(enabled=self.args.mixed_precision): 102 | fmap1, fmap2 = self.fnet([image1, image2]) 103 | 104 | fmap1 = fmap1.float() 105 | fmap2 = fmap2.float() 106 | if self.args.alternate_corr: 107 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | else: 109 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 110 | 111 | # run the context network 112 | with autocast(enabled=self.args.mixed_precision): 113 | cnet = self.cnet(image1) 114 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 115 | net = torch.tanh(net) 116 | inp = torch.relu(inp) 117 | 118 | coords0, coords1 = self.initialize_flow(image1) 119 | 120 | if flow_init is not None: 121 | coords1 = coords1 + flow_init 122 | 123 | flow_predictions = [] 124 | for itr in range(iters): 125 | coords1 = coords1.detach() 126 | corr = corr_fn(coords1) # index correlation volume 127 | 128 | flow = coords1 - coords0 129 | with autocast(enabled=self.args.mixed_precision): 130 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 131 | 132 | # F(t+1) = F(t) + \Delta(t) 133 | coords1 = coords1 + delta_flow 134 | 135 | # upsample predictions 136 | if up_mask is None: 137 | flow_up = upflow8(coords1 - coords0) 138 | else: 139 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 140 | 141 | flow_predictions.append(flow_up) 142 | 143 | if test_mode: 144 | return coords1 - coords0, flow_up 145 | 146 | return flow_predictions 147 | 148 | -------------------------------------------------------------------------------- /RAFT/core/update.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = 
F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /RAFT/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RAFT/core/utils/__init__.py -------------------------------------------------------------------------------- /RAFT/core/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RAFT/core/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /RAFT/core/utils/__pycache__/augmentor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RAFT/core/utils/__pycache__/augmentor.cpython-38.pyc -------------------------------------------------------------------------------- /RAFT/core/utils/__pycache__/frame_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RAFT/core/utils/__pycache__/frame_utils.cpython-38.pyc -------------------------------------------------------------------------------- /RAFT/core/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RAFT/core/utils/__pycache__/utils.cpython-38.pyc 
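Before moving on to the augmentation utilities, here is a minimal sketch of how the update block defined above is driven during a single refinement iteration. The tensor sizes are hypothetical (a 1/8-resolution feature grid), and it assumes the repository root is on `PYTHONPATH` so that `RAFT.core.update` can be imported:

```python
import torch
from argparse import Namespace
from RAFT.core.update import BasicUpdateBlock

# corr_levels / corr_radius follow the non-small RAFT configuration used in raft.py
args = Namespace(corr_levels=4, corr_radius=4)
update_block = BasicUpdateBlock(args, hidden_dim=128)

N, H, W = 1, 47, 156                                  # hypothetical 1/8-resolution grid
net = torch.randn(N, 128, H, W)                       # hidden state (tanh half of the context features)
inp = torch.randn(N, 128, H, W)                       # context input (relu half of the context features)
corr = torch.randn(N, args.corr_levels * (2 * args.corr_radius + 1) ** 2, H, W)
flow = torch.zeros(N, 2, H, W)                        # current flow estimate (coords1 - coords0)

net, up_mask, delta_flow = update_block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)                # (1, 576, 47, 156) and (1, 2, 47, 156)
```

In `raft.py`, `delta_flow` is added to `coords1` and `up_mask` is passed to `upsample_flow`, which performs the convex-combination upsampling back to full resolution.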
-------------------------------------------------------------------------------- /RAFT/core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 
= img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | 102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 104 | 105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 108 | 109 | return img1, img2, flow 110 | 111 | def __call__(self, img1, img2, flow): 112 | img1, img2 = self.color_transform(img1, img2) 113 | img1, img2 = self.eraser_transform(img1, img2) 114 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 115 | 116 | img1 = np.ascontiguousarray(img1) 117 | img2 = np.ascontiguousarray(img2) 118 | flow = np.ascontiguousarray(flow) 119 | 120 | return img1, img2, flow 121 | 122 | class SparseFlowAugmentor: 123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 124 | # spatial augmentation params 125 | self.crop_size = crop_size 126 | self.min_scale = min_scale 127 | self.max_scale = max_scale 128 | self.spatial_aug_prob = 0.8 129 | self.stretch_prob = 0.8 130 | self.max_stretch = 0.2 131 | 132 | # flip augmentation params 133 | self.do_flip = do_flip 134 | self.h_flip_prob = 0.5 135 | self.v_flip_prob = 0.1 136 | 137 | # photometric augmentation params 138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 139 | self.asymmetric_color_aug_prob = 0.2 140 | self.eraser_aug_prob = 0.5 141 | 142 | def color_transform(self, img1, img2): 143 | image_stack = np.concatenate([img1, img2], axis=0) 144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 145 | img1, img2 = np.split(image_stack, 2, axis=0) 146 | return img1, img2 147 | 148 | def eraser_transform(self, img1, img2): 149 | ht, wd = img1.shape[:2] 150 | if np.random.rand() < self.eraser_aug_prob: 151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 152 | for _ in range(np.random.randint(1, 3)): 153 | x0 = np.random.randint(0, wd) 154 | y0 = np.random.randint(0, ht) 155 | dx = np.random.randint(50, 100) 156 | dy = np.random.randint(50, 100) 157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 158 | 159 | return img1, img2 160 | 161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 162 | ht, wd = flow.shape[:2] 163 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 164 | coords = np.stack(coords, axis=-1) 165 | 166 | coords = coords.reshape(-1, 2).astype(np.float32) 167 | flow = flow.reshape(-1, 2).astype(np.float32) 168 | valid = valid.reshape(-1).astype(np.float32) 169 | 170 | coords0 = coords[valid>=1] 171 | flow0 = flow[valid>=1] 172 | 173 | ht1 = int(round(ht * fy)) 174 | wd1 = int(round(wd * fx)) 175 | 176 | coords1 = coords0 * [fx, fy] 177 | flow1 = flow0 * [fx, fy] 178 | 179 | xx = np.round(coords1[:,0]).astype(np.int32) 180 | yy = np.round(coords1[:,1]).astype(np.int32) 181 | 182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 183 | xx = xx[v] 184 | yy = yy[v] 185 | flow1 = flow1[v] 186 | 187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 189 | 190 | flow_img[yy, xx] = flow1 191 | valid_img[yy, xx] = 1 192 | 193 | return flow_img, valid_img 194 | 195 | def spatial_transform(self, img1, img2, flow, valid): 196 | # randomly sample scale 197 | 198 | ht, wd = img1.shape[:2] 199 | min_scale = np.maximum( 200 | (self.crop_size[0] + 1) / float(ht), 201 | (self.crop_size[1] + 1) / float(wd)) 202 | 203 | scale = 2 
** np.random.uniform(self.min_scale, self.max_scale) 204 | scale_x = np.clip(scale, min_scale, None) 205 | scale_y = np.clip(scale, min_scale, None) 206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | def __call__(self, img1, img2, flow, valid): 236 | img1, img2 = self.color_transform(img1, img2) 237 | img1, img2 = self.eraser_transform(img1, img2) 238 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 239 | 240 | img1 = np.ascontiguousarray(img1) 241 | img2 = np.ascontiguousarray(img2) 242 | flow = np.ascontiguousarray(flow) 243 | valid = np.ascontiguousarray(valid) 244 | 245 | return img1, img2, flow, valid 246 | -------------------------------------------------------------------------------- /RAFT/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /RAFT/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | 8 | cv2.setNumThreads(0) 9 | cv2.ocl.setUseOpenCL(False) 10 | 11 | TAG_CHAR = np.array([202021.25], np.float32) 12 | 13 | 14 | def readFlow(fn): 15 | """ Read .flo file in Middlebury format""" 16 | # Code adapted from: 17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 18 | 19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 20 | # print 'fn = %s'%(fn) 21 | with open(fn, 'rb') as f: 22 | magic = np.fromfile(f, np.float32, count=1) 23 | if 202021.25 != magic: 24 | print('Magic number incorrect. Invalid .flo file') 25 | return None 26 | else: 27 | w = np.fromfile(f, np.int32, count=1) 28 | h = np.fromfile(f, np.int32, count=1) 29 | # print 'Reading %d x %d flo file\n' % (w, h) 30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 31 | # Reshape data into 3D array (columns, rows, bands) 32 | # The reshape here is for visualization, the original code is (w,h,2) 33 | return np.resize(data, (int(h), int(w), 2)) 34 | 35 | 36 | def readPFM(file): 37 | file = open(file, 'rb') 38 | 39 | color = None 40 | width = None 41 | height = None 42 | scale = None 43 | endian = None 44 | 45 | header = file.readline().rstrip() 46 | if header == b'PF': 47 | color = True 48 | elif header == b'Pf': 49 | color = False 50 | else: 51 | raise Exception('Not a PFM file.') 52 | 53 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 54 | if dim_match: 55 | width, height = map(int, dim_match.groups()) 56 | else: 57 | raise Exception('Malformed PFM header.') 58 | 59 | scale = float(file.readline().rstrip()) 60 | if scale < 0: # little-endian 61 | endian = '<' 62 | scale = -scale 63 | else: 64 | endian = '>' # big-endian 65 | 66 | data = np.fromfile(file, endian + 'f') 67 | shape = (height, width, 3) if color else (height, width) 68 | 69 | data = np.reshape(data, shape) 70 | data = np.flipud(data) 71 | return data 72 | 73 | 74 | def writeFlow(filename, uv, v=None): 75 | """ Write optical flow to file. 76 | 77 | If v is None, uv is assumed to contain both u and v channels, 78 | stacked in depth. 79 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
80 | """ 81 | nBands = 2 82 | 83 | if v is None: 84 | assert (uv.ndim == 3) 85 | assert (uv.shape[2] == 2) 86 | u = uv[:, :, 0] 87 | v = uv[:, :, 1] 88 | else: 89 | u = uv 90 | 91 | assert (u.shape == v.shape) 92 | height, width = u.shape 93 | f = open(filename, 'wb') 94 | # write the header 95 | f.write(TAG_CHAR) 96 | np.array(width).astype(np.int32).tofile(f) 97 | np.array(height).astype(np.int32).tofile(f) 98 | # arrange into matrix form 99 | tmp = np.zeros((height, width * nBands)) 100 | tmp[:, np.arange(width) * 2] = u 101 | tmp[:, np.arange(width) * 2 + 1] = v 102 | tmp.astype(np.float32).tofile(f) 103 | f.close() 104 | 105 | 106 | def readFlowKITTI(filename): 107 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 108 | flow = flow[:, :, ::-1].astype(np.float32) 109 | flow, valid = flow[:, :, :2], flow[:, :, 2] 110 | flow = (flow - 2 ** 15) / 64.0 111 | return flow, valid 112 | 113 | 114 | def readDispKITTI(filename): 115 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 116 | valid = disp > 0.0 117 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 118 | return flow, valid 119 | 120 | 121 | def writeFlowKITTI(filename, uv): 122 | uv = 64.0 * uv + 2 ** 15 123 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 124 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 125 | cv2.imwrite(filename, uv[..., ::-1]) 126 | 127 | 128 | def read_gen(file_name, pil=False): 129 | ext = splitext(file_name)[-1] 130 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 131 | return Image.open(file_name) 132 | elif ext == '.bin' or ext == '.raw': 133 | return np.load(file_name) 134 | elif ext == '.flo': 135 | return readFlow(file_name).astype(np.float32) 136 | elif ext == '.pfm': 137 | flow = readPFM(file_name).astype(np.float32) 138 | if len(flow.shape) == 2: 139 | return flow 140 | else: 141 | return flow[:, :, :-1] 142 | else: 143 | raise ValueError('wrong file type: %s' % ext) 144 | 145 | TAG_FLOAT = 202021.25 146 | def depth_read(filename): 147 | """ Read depth data from file, return as numpy array. """ 148 | f = open(filename,'rb') 149 | check = np.fromfile(f,dtype=np.float32,count=1)[0] 150 | assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? 
'.format(TAG_FLOAT,check) 151 | width = np.fromfile(f,dtype=np.int32,count=1)[0] 152 | height = np.fromfile(f,dtype=np.int32,count=1)[0] 153 | size = width*height 154 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) 155 | depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width)) 156 | return depth -------------------------------------------------------------------------------- /RAFT/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel', divisible=8): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // divisible) + 1) * divisible - self.ht) % divisible 12 | pad_wd = (((self.wd // divisible) + 1) * divisible - self.wd) % divisible 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs, mode='replicate'): 19 | return [F.pad(x, self._pad, mode=mode) for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ECCV 2022 Oral] RealFlow: EM-based Realistic Optical Flow Dataset Generation from Videos ([Paper](https://arxiv.org/pdf/2207.11075.pdf)) 2 | 
3 | Yunhui Han1, Kunming Luo2, Ao Luo2, Jiangyu Liu2, Haoqiang Fan2, Guiming Luo1, Shuaicheng Liu3,2*
4 | 1. Tsinghua University, 2. Megvii Research
5 | 
3. University of Electronic Science and Technology of China 6 | 7 | 8 | ## Abstract 9 | Obtaining the ground truth labels from a video is challenging since the manual annotation of pixel-wise flow labels is prohibitively expensive and laborious. Besides, existing approaches try to adapt the trained model on synthetic datasets to authentic videos, which inevitably suffers from domain discrepancy and hinders the performance for realworld applications. To solve these problems, we propose RealFlow, an Expectation-Maximization based framework that can create large-scale optical flow datasets directly from any unlabeled realistic videos. Specifically, we first estimate optical flow between a pair of video frames, and then synthesize a new image from this pair based on the predicted flow. Thus the new image pairs and their corresponding flows can be regarded as a new training set. Besides, we design a Realistic Image Pair Rendering (RIPR) module that adopts softmax splatting and bi-directional hole filling techniques to alleviate the artifacts of the image synthesis. In the E-step, RIPR renders new images to create a large quantity of training data. In the M-step, we utilize the generated training data to train an optical flow network, which can be used to estimate optical flows in the next E-step. During the iterative learning steps, the capability of the flow network is gradually improved, so is the accuracy of the flow, as well as the quality of the synthesized dataset. Experimental results show that RealFlow outperforms previous dataset generation methods by a considerably large margin. Moreover, based on the generated dataset, our approach achieves state-of-the-art performance on two standard benchmarks compared with both supervised and unsupervised optical flow methods 10 | 11 | ## Motivation 12 | ![motivation](https://user-images.githubusercontent.com/1344482/180913272-d8e1af87-b305-4beb-b067-ff29ce53a56d.JPG) 13 | 14 | Top: previous methods use synthetic motion to produce training pairs. Bottom: we propose to construct training pairs with realistic motion labels from the real-world video sequence. We estimate optical flow between two frames as the training label and synthesize a ‘New Image 2’. Both the new view and flow labels are refined iteratively in the EM-based framework for mutual improvements. 15 | 16 | ## Requirements 17 | - torch>=1.8.1 18 | - torchvision>=0.9.1 19 | - opencv-python>=4.5.2 20 | - timm>=0.4.5 21 | - cupy>=5.0.0 22 | - numpy>=1.15.0 23 | 24 | ## Rendered Datasets 25 | ![results](https://user-images.githubusercontent.com/1344482/180913871-cbbce758-8b03-46b5-b3a4-b07f0b229f82.JPG) 26 | 27 | #### Download 28 | 29 | You can download all the generated datasets and pretrained models in our paper: 30 | 31 | - Download the generated datasets using shell scripts `dataset_download.sh` 32 | ```shell 33 | sh dataset_download.sh 34 | ``` 35 | the dataset will be downloaded in `./RF_dataset` 36 | 37 | - Download the pretrained models using this link: [pretrained_models](https://data.megengine.org.cn/research/realflow/models.zip). 38 | 39 | 40 | ## Render New Data 41 | Download the pretrained DPT model from [here](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt) and pretrained RAFT C+T model (raft-things.pth) from [here](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) 42 | 43 | Download [KITTI multi-view](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php) Datasets. 
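Each rendered training sample is an `(img1, img2, flow)` triplet like the one bundled under `RFdata/`. As a quick sanity check, such a sample can be loaded as follows (a minimal sketch, assumed to be run from the repository root, using `readFlow` from `RAFT/core/utils/frame_utils.py`):

```python
import imageio
from RAFT.core.utils.frame_utils import readFlow

img1 = imageio.imread("RFdata/img1/test_1-1.png")    # original first frame
img2 = imageio.imread("RFdata/img2/RFtest_1-2.png")  # synthesized "New Image 2"
flow = readFlow("RFdata/flow/RFflow.flo")            # forward flow label, shape (H, W, 2)
print(img1.shape, img2.shape, flow.shape)
```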
44 | You can run the following command to render RF-Ktrain: 45 | ```shell 46 | python RealFlow.py 47 | ``` 48 | You can also download ALOV and BDD100k from their official website to render RF-AB. Using utils/video2img.py to capture pictures. 49 | 50 | 51 | You can simply render a new pair using: 52 | ```shell 53 | python demo.py 54 | ``` 55 | 56 | ## Citation 57 | If you find this work useful for your research, please cite: 58 | ``` 59 | @inproceedings{han2022realflow, 60 | title={RealFlow: EM-Based Realistic Optical Flow Dataset Generation from Videos}, 61 | author={Han, Yunhui and Luo, Kunming and Luo, Ao and Liu, Jiangyu and Fan, Haoqiang and Luo, Guiming and Liu, Shuaicheng}, 62 | booktitle={European Conference on Computer Vision}, 63 | pages={288--305}, 64 | year={2022} 65 | } 66 | 67 | ``` 68 | 69 | ## Acknowledgements 70 | Part of the code is adapted from previous works: 71 | - [RAFT](https://github.com/princeton-vl/RAFT) 72 | - [DPT](https://github.com/isl-org/DPT) 73 | - [Softmax Splatting](https://github.com/sniklaus/softmax-splatting) 74 | 75 | Our datasets are generated from [KITTI](http://www.cvlibs.net/datasets/kitti/index.php), [Sintel](http://sintel.is.tue.mpg.de/), [BDD100k](https://github.com/bdd100k/bdd100k), [DAVIS](https://davischallenge.org/), and [ALOV](http://crcv.ucf.edu/data/ALOV++/). 76 | 77 | We thank all the authors for their contributions. 78 | -------------------------------------------------------------------------------- /RFdata/flow/RFflow.flo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RFdata/flow/RFflow.flo -------------------------------------------------------------------------------- /RFdata/img1/test_1-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RFdata/img1/test_1-1.png -------------------------------------------------------------------------------- /RFdata/img2/RFtest_1-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/RFdata/img2/RFtest_1-2.png -------------------------------------------------------------------------------- /RealFLow.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import RAFT.core.datasets as datasets 6 | from RAFT.core.utils.frame_utils import writeFlow 7 | from RAFT.core.raft import RAFT 8 | from RAFT.core.utils.utils import InputPadder 9 | from torchvision.utils import save_image 10 | from tqdm import tqdm 11 | from utils.tools import FlowReversal 12 | from softmax_splatting import softsplat 13 | from DPT.dpt.models import DPTDepthModel 14 | import imageio 15 | import time 16 | import cv2 17 | 18 | 19 | @torch.no_grad() 20 | def render_local(flow_net, dataset, save_path, iters=24): 21 | 22 | #load DPT depth model, using pretrain DPT model 23 | depth_model_path = "DPT/model/dpt_large-midas-2f21e586.pt" 24 | DPT = DPTDepthModel( 25 | path=depth_model_path, 26 | backbone="vitl16_384", 27 | non_negative=True, 28 | enable_attention_hooks=False, 29 | ) 30 | DPT.cuda() 31 | DPT.eval() 32 | 33 | if not os.path.exists(save_path): 34 | os.makedirs('{:s}/img1'.format(save_path)) 35 | 
os.makedirs('{:s}/img2'.format(save_path)) 36 | os.makedirs('{:s}/flow'.format(save_path)) 37 | 38 | 39 | for val_id in tqdm(range(0, len(dataset))): 40 | image1, image2, _,_ = dataset[val_id] 41 | image1 = image1[None].cuda() 42 | image2 = image2[None].cuda() 43 | 44 | padder = InputPadder(image1.shape, 8) 45 | image1, image2 = padder.pad(image1, image2) 46 | 47 | # estimate bi-directional flow 48 | with torch.no_grad(): 49 | _, flow_forward = flow_net(image1, image2, iters=iters, test_mode=True) 50 | _, flow_back = flow_net(image2, image1, iters=iters, test_mode=True) 51 | 52 | flow_fw = padder.unpad(flow_forward) 53 | image1 = padder.unpad(image1).contiguous() 54 | image2 = padder.unpad(image2) 55 | flow_bw = padder.unpad(flow_back) 56 | 57 | # setting alpha 58 | linspace = torch.rand(1).cuda()*2 59 | flow_fw = flow_fw * linspace 60 | flow_bw = flow_bw * (1 - linspace) 61 | 62 | # occ check 63 | with torch.no_grad(): 64 | fw = FlowReversal() 65 | _, occ = fw.forward(image1, flow_fw) 66 | occ = torch.clamp(occ, 0, 1) 67 | 68 | # dilated occ mask 69 | occ = occ.squeeze(0).permute(1,2,0).cpu().numpy() 70 | occ = (1-occ)*255 71 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) 72 | dilated = cv2.dilate(occ, kernel)/255 73 | occ = 1-torch.from_numpy(dilated).permute(2,0,1).unsqueeze(0).cuda() 74 | 75 | 76 | padder = InputPadder(image1.shape, mode='sintel', divisible=32) 77 | input, input2, flow_fw, flow_bw = padder.pad(image1 / 255, image2 / 255, flow_fw, flow_bw) 78 | 79 | # estimate depth and splatting 80 | with torch.no_grad(): 81 | 82 | # estimate depth and normalize 83 | tenMetric = DPT(input.cuda()) 84 | tenMetric = (tenMetric - tenMetric.min()) / (tenMetric.max() - tenMetric.min()) 85 | 86 | # splatting can choose: softmax, max, summation 87 | output1 = softsplat.FunctionSoftsplat(tenInput=input, tenFlow=flow_fw, 88 | tenMetric=tenMetric.unsqueeze(0), 89 | strType='softmax') 90 | 91 | tenMetric2 = DPT(input2.cuda()) 92 | tenMetric2 = (tenMetric2 - tenMetric2.min()) / (tenMetric2.max() - tenMetric2.min()) 93 | output2 = softsplat.FunctionSoftsplat(tenInput=input2, tenFlow=flow_bw, 94 | tenMetric=tenMetric2.unsqueeze(0), 95 | strType='softmax') 96 | # fuse the result 97 | output = padder.unpad(output1) * occ + (1 - occ) * padder.unpad(output2) 98 | input = padder.unpad(input) 99 | flow = padder.unpad(flow_fw).squeeze(0).permute(1, 2, 0).cpu().numpy() 100 | save_image(input, save_path+'/img1/img1_{}.png'.format(val_id)) 101 | save_image(output, save_path+'/img2/img2_{}.png'.format(val_id)) 102 | writeFlow(save_path+'/flow/flow_{}.flo'.format(val_id),flow) 103 | 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser() 107 | # RAFT parameteqqrs 108 | parser.add_argument('--model', help="restore checkpoint") 109 | parser.add_argument('--small', action='store_true', help='use small model') 110 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 111 | parser.add_argument('--save_location', help="save the results in local or oss") 112 | parser.add_argument('--save_path', help=" local path to save the result") 113 | parser.add_argument('--iter', help=" kitti 24, sintel 32") 114 | args = parser.parse_args() 115 | 116 | 117 | # load RAFT model 118 | model = torch.nn.DataParallel(RAFT(args)) 119 | model.load_state_dict(torch.load(args.model)) 120 | 121 | model.cuda() 122 | model.eval() 123 | 124 | # choose your dataset here 125 | dataset = datasets.KITTI() 126 | 127 | with torch.no_grad(): 128 | render_local(model, dataset, 
args.save_path, iters= int(args.iter)) 129 | -------------------------------------------------------------------------------- /dataset_download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir RF_dataset 3 | cd RF_dataset 4 | 5 | # RF-Ktrain 6 | mkdir RF-Ktrain 7 | cd RF-Ktrain 8 | wget https://data.megengine.org.cn/research/realflow/RF-Ktrain-flow.zip 9 | wget https://data.megengine.org.cn/research/realflow/RF-Ktrain-img.zip 10 | wget https://data.megengine.org.cn/research/realflow/RF-Ktrain-flo.zip 11 | cd .. 12 | 13 | # RF-KTest 14 | mkdir RF-KTest 15 | cd RF-KTest 16 | wget https://data.megengine.org.cn/research/realflow/RF-KTest-flow.zip 17 | wget https://data.megengine.org.cn/research/realflow/RF-KTest-img.zip 18 | wget https://data.megengine.org.cn/research/realflow/RF-KTest-flo.zip 19 | cd .. 20 | 21 | # RF-Sintel 22 | mkdir RF-Sintel 23 | wget https://data.megengine.org.cn/research/realflow/RFAB-sintel-flow.zip 24 | wget https://data.megengine.org.cn/research/realflow/RFAB-sintel-img.zip 25 | cd RF-Sintel 26 | 27 | # RF-DAVIS 28 | mkdir RF-DAVIS 29 | cd RF-DAVIS 30 | wget https://data.megengine.org.cn/research/realflow/RF-Davis-flow.zip 31 | wget https://data.megengine.org.cn/research/realflow/RF-Davis-img.zip 32 | wget https://data.megengine.org.cn/research/realflow/RF-Davis-flo.zip 33 | cd .. 34 | 35 | # RF-AB 36 | mkdir RF-AB 37 | cd RF-AB 38 | wget https://data.megengine.org.cn/research/realflow/RFAB-flow-Apart0.zip 39 | wget https://data.megengine.org.cn/research/realflow/RFAB-flow-Apart1.zip 40 | wget https://data.megengine.org.cn/research/realflow/RFAB-flow-Apart2.zip 41 | wget https://data.megengine.org.cn/research/realflow/RFAB-flow-Apart3.zip 42 | wget https://data.megengine.org.cn/research/realflow/RFAB-flow-Bpart0.zip 43 | wget https://data.megengine.org.cn/research/realflow/RFAB-flow-Bpart1.zip 44 | wget https://data.megengine.org.cn/research/realflow/RFAB-flow-Bpart2.zip 45 | wget https://data.megengine.org.cn/research/realflow/RFAB-flow-Bpart3.zip 46 | wget https://data.megengine.org.cn/research/realflow/RFAB-img-Apart0.zip 47 | wget https://data.megengine.org.cn/research/realflow/RFAB-img-Apart1.zip 48 | wget https://data.megengine.org.cn/research/realflow/RFAB-img-Apart2.zip 49 | wget https://data.megengine.org.cn/research/realflow/RFAB-img-Apart3.zip 50 | wget https://data.megengine.org.cn/research/realflow/RFAB-img-Bpart0.zip 51 | wget https://data.megengine.org.cn/research/realflow/RFAB-img-Bpart1.zip 52 | wget https://data.megengine.org.cn/research/realflow/RFAB-img-Bpart2.zip 53 | wget https://data.megengine.org.cn/research/realflow/RFAB-img-Bpart3.zip 54 | wget https://data.megengine.org.cn/research/realflow/RFAB-flo-Apart0.zip 55 | wget https://data.megengine.org.cn/research/realflow/RFAB-flo-Apart1.zip 56 | wget https://data.megengine.org.cn/research/realflow/RFAB-flo-Apart2.zip 57 | wget https://data.megengine.org.cn/research/realflow/RFAB-flo-Apart3.zip 58 | wget https://data.megengine.org.cn/research/realflow/RFAB-flo-Bpart0.zip 59 | wget https://data.megengine.org.cn/research/realflow/RFAB-flo-Bpart1.zip 60 | wget https://data.megengine.org.cn/research/realflow/RFAB-flo-Bpart2.zip 61 | wget https://data.megengine.org.cn/research/realflow/RFAB-flo-Bpart3.zip 62 | cd .. 
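Note that in the RF-Sintel block above, the two `wget` commands run before `cd RF-Sintel`, and the block ends with `cd RF-Sintel` rather than `cd ..`. As written, the Sintel archives are downloaded into `RF_dataset/` itself, and the later `RF-DAVIS` and `RF-AB` directories are created inside `RF-Sintel/`. A corrected version of that block (a suggested fix, not part of the repository) would be:

```shell
# RF-Sintel: enter the directory before downloading, then return to RF_dataset
mkdir RF-Sintel
cd RF-Sintel
wget https://data.megengine.org.cn/research/realflow/RFAB-sintel-flow.zip
wget https://data.megengine.org.cn/research/realflow/RFAB-sintel-img.zip
cd ..
```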
63 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import RAFT.core.datasets as datasets 6 | from RAFT.core.utils.frame_utils import writeFlow 7 | from RAFT.core.raft import RAFT 8 | from RAFT.core.utils.utils import InputPadder 9 | from torchvision.utils import save_image 10 | from tqdm import tqdm 11 | from utils.tools import FlowReversal 12 | from softmax_splatting import softsplat 13 | from DPT.dpt.models import DPTDepthModel 14 | import imageio 15 | import time 16 | import cv2 17 | 18 | @torch.no_grad() 19 | def render_local(flow_net, sample, save_path, alpha, splatting, iters=24): 20 | 21 | #load DPT depth model, using pretrain DPT model 22 | depth_model_path = "DPT/model/dpt_large-midas-2f21e586.pt" 23 | DPT = DPTDepthModel( 24 | path=depth_model_path, 25 | backbone="vitl16_384", 26 | non_negative=True, 27 | enable_attention_hooks=False, 28 | ) 29 | DPT.cuda() 30 | DPT.eval() 31 | 32 | if not os.path.exists(save_path): 33 | os.makedirs('{:s}/img1'.format(save_path)) 34 | os.makedirs('{:s}/img2'.format(save_path)) 35 | os.makedirs('{:s}/flow'.format(save_path)) 36 | 37 | 38 | image1, image2 = sample[0], sample[1] 39 | image1 = torch.from_numpy(image1).cuda().unsqueeze(0).permute(0,3,1,2).float() 40 | image2 = torch.from_numpy(image2).cuda().unsqueeze(0).permute(0,3,1,2).float() 41 | 42 | padder = InputPadder(image1.shape, 8) 43 | image1, image2 = padder.pad(image1, image2) 44 | 45 | # estimate bi-directional flow 46 | with torch.no_grad(): 47 | _, flow_forward = flow_net(image1, image2, iters=iters, test_mode=True) 48 | _, flow_back = flow_net(image2, image1, iters=iters, test_mode=True) 49 | 50 | flow_fw = padder.unpad(flow_forward) 51 | image1 = padder.unpad(image1).contiguous() 52 | image2 = padder.unpad(image2) 53 | flow_bw = padder.unpad(flow_back) 54 | 55 | # setting alpha 56 | linspace = alpha 57 | flow_fw = flow_fw * linspace 58 | flow_bw = flow_bw * (1 - linspace) 59 | 60 | # occ check 61 | with torch.no_grad(): 62 | fw = FlowReversal() 63 | _, occ = fw.forward(image1, flow_fw) 64 | occ = torch.clamp(occ, 0, 1) 65 | 66 | # dilated occ mask 67 | occ = occ.squeeze(0).permute(1,2,0).cpu().numpy() 68 | occ = (1-occ)*255 69 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) 70 | dilated = cv2.dilate(occ, kernel)/255 71 | occ = 1-torch.from_numpy(dilated).permute(2,0,1).unsqueeze(0).cuda() 72 | 73 | 74 | padder = InputPadder(image1.shape, mode='sintel', divisible=32) 75 | input, input2, flow_fw, flow_bw = padder.pad(image1 / 255, image2 / 255, flow_fw, flow_bw) 76 | 77 | # estimate depth and splatting 78 | with torch.no_grad(): 79 | 80 | # estimate depth and normalize 81 | tenMetric = DPT(input.cuda()) 82 | tenMetric = (tenMetric - tenMetric.min()) / (tenMetric.max() - tenMetric.min()) 83 | 84 | # splatting can choose: softmax, max, summation 85 | output1 = softsplat.FunctionSoftsplat(tenInput=input, tenFlow=flow_fw, 86 | tenMetric=tenMetric.unsqueeze(0), 87 | strType=splatting) 88 | 89 | tenMetric2 = DPT(input2.cuda()) 90 | tenMetric2 = (tenMetric2 - tenMetric2.min()) / (tenMetric2.max() - tenMetric2.min()) 91 | output2 = softsplat.FunctionSoftsplat(tenInput=input2, tenFlow=flow_bw, 92 | tenMetric=tenMetric2.unsqueeze(0), 93 | strType=splatting) 94 | # fuse the result 95 | output = padder.unpad(output1) * occ + (1 - occ) * padder.unpad(output2) 96 | input = 
padder.unpad(input) 97 | flow = padder.unpad(flow_fw).squeeze(0).permute(1, 2, 0).cpu().numpy() 98 | save_image(input, save_path+'/img1/test_1-1.png') 99 | save_image(output, save_path+'/img2/RFtest_1-2.png') 100 | writeFlow(save_path+'/flow/RFflow.flo',flow) 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | # RAFT parameteqqrs 106 | parser.add_argument('--model', help="restore checkpoint") 107 | parser.add_argument('--small', action='store_true', help='use small model') 108 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 109 | parser.add_argument('--save_location', help="save the results in local or oss") 110 | parser.add_argument('--save_path', help=" local path to save the result") 111 | parser.add_argument('--iter', help=" kitti 24, sintel 32") 112 | parser.add_argument('--alpha', default=0.75) 113 | parser.add_argument('--splatting', help=" max or softmax") 114 | args = parser.parse_args() 115 | 116 | 117 | # load RAFT model 118 | model = torch.nn.DataParallel(RAFT(args)) 119 | model.load_state_dict(torch.load(args.model)) 120 | 121 | model.cuda() 122 | model.eval() 123 | 124 | img1 = imageio.imread("sample/test_1-1.png").astype('uint8') 125 | img2 = imageio.imread("sample/test_1-2.png").astype('uint8') 126 | sample = [img1, img2] 127 | 128 | with torch.no_grad(): 129 | render_local(model, sample, args.save_path, float(args.alpha), str(args.splatting), iters= int(args.iter)) 130 | -------------------------------------------------------------------------------- /sample/test_1-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/sample/test_1-1.png -------------------------------------------------------------------------------- /sample/test_1-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/RealFlow/52ea80d416d55f454f63650eeba85687055285c2/sample/test_1-2.png -------------------------------------------------------------------------------- /softmax_splatting/softsplat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | import cupy 6 | import re 7 | 8 | kernel_Maxsplat_updateOutput = ''' 9 | extern "C" __global__ void kernel_Softsplat_updateOutput( 10 | const int n, 11 | const float* input, 12 | const float* flow, 13 | float* output 14 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += gridDim.x) { 15 | const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); 16 | const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); 17 | const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); 18 | const int intX = ( intIndex ) % SIZE_3(output); 19 | 20 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 21 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 22 | 23 | int intNorthwestX = (int) (floor(fltOutputX)); 24 | int intNorthwestY = (int) (floor(fltOutputY)); 25 | int intNortheastX = intNorthwestX + 1; 26 | int intNortheastY = intNorthwestY; 27 | int intSouthwestX = intNorthwestX; 28 | int intSouthwestY = intNorthwestY + 1; 29 | int intSoutheastX = intNorthwestX + 1; 30 | int intSoutheastY = intNorthwestY + 1; 
31 | 32 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY); 33 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY); 34 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY)); 35 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); 36 | 37 | 38 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output)) 39 | & (VALUE_4(input, intN, 3, intY, intX) >= output[OFFSET_4(output, intN, 3, intNorthwestY, intNorthwestX)]) 40 | ) { 41 | atomicExch(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX)); 42 | atomicExch(&output[OFFSET_4(output, intN, 3, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, 3, intY, intX)); 43 | } 44 | 45 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output)) 46 | & (VALUE_4(input, intN, 3, intY, intX) >= output[OFFSET_4(output, intN, 3, intNortheastY, intNortheastX)]) 47 | ) { 48 | atomicExch(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX)); 49 | atomicExch(&output[OFFSET_4(output, intN, 3, intNortheastY, intNortheastX)], VALUE_4(input, intN, 3, intY, intX)); 50 | } 51 | 52 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output)) 53 | & (VALUE_4(input, intN, 3, intY, intX) >= output[OFFSET_4(output, intN, 3, intSouthwestY, intSouthwestX)]) 54 | ){ 55 | atomicExch(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX)); 56 | atomicExch(&output[OFFSET_4(output, intN, 3, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, 3, intY, intX)); 57 | } 58 | 59 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output)) 60 | & (VALUE_4(input, intN, 3, intY, intX) >= output[OFFSET_4(output, intN, 3, intSoutheastY, intSoutheastX)]) 61 | ){ 62 | atomicExch(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX)); 63 | atomicExch(&output[OFFSET_4(output, intN, 3, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, 3, intY, intX)); 64 | } 65 | } } 66 | ''' 67 | 68 | kernel_Softsplat_updateOutput = ''' 69 | extern "C" __global__ void kernel_Softsplat_updateOutput( 70 | const int n, 71 | const float* input, 72 | const float* flow, 73 | float* output 74 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 75 | const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); 76 | const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); 77 | const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); 78 | const int intX = ( intIndex ) % SIZE_3(output); 79 | 80 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 81 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 82 | 83 | int intNorthwestX = (int) (floor(fltOutputX)); 84 | int intNorthwestY = (int) (floor(fltOutputY)); 85 | int intNortheastX = intNorthwestX + 1; 86 | int intNortheastY = intNorthwestY; 87 | int intSouthwestX = intNorthwestX; 88 | int intSouthwestY = intNorthwestY + 1; 89 | int intSoutheastX = 
intNorthwestX + 1; 90 | int intSoutheastY = intNorthwestY + 1; 91 | 92 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY); 93 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY); 94 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY)); 95 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); 96 | 97 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { 98 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); 99 | } 100 | 101 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { 102 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); 103 | } 104 | 105 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { 106 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); 107 | } 108 | 109 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { 110 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); 111 | } 112 | } } 113 | ''' 114 | 115 | kernel_Softsplat_updateGradInput = ''' 116 | extern "C" __global__ void kernel_Softsplat_updateGradInput( 117 | const int n, 118 | const float* input, 119 | const float* flow, 120 | const float* gradOutput, 121 | float* gradInput, 122 | float* gradFlow 123 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 124 | const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); 125 | const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); 126 | const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); 127 | const int intX = ( intIndex ) % SIZE_3(gradInput); 128 | 129 | float fltGradInput = 0.0; 130 | 131 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 132 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 133 | 134 | int intNorthwestX = (int) (floor(fltOutputX)); 135 | int intNorthwestY = (int) (floor(fltOutputY)); 136 | int intNortheastX = intNorthwestX + 1; 137 | int intNortheastY = intNorthwestY; 138 | int intSouthwestX = intNorthwestX; 139 | int intSouthwestY = intNorthwestY + 1; 140 | int intSoutheastX = intNorthwestX + 1; 141 | int intSoutheastY = intNorthwestY + 1; 142 | 143 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY); 144 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY); 145 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY)); 146 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); 147 | 148 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { 149 | fltGradInput += 
VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; 150 | } 151 | 152 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { 153 | fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; 154 | } 155 | 156 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { 157 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; 158 | } 159 | 160 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { 161 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; 162 | } 163 | 164 | gradInput[intIndex] = fltGradInput; 165 | } } 166 | ''' 167 | 168 | kernel_Softsplat_updateGradFlow = ''' 169 | extern "C" __global__ void kernel_Softsplat_updateGradFlow( 170 | const int n, 171 | const float* input, 172 | const float* flow, 173 | const float* gradOutput, 174 | float* gradInput, 175 | float* gradFlow 176 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 177 | float fltGradFlow = 0.0; 178 | 179 | const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); 180 | const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); 181 | const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); 182 | const int intX = ( intIndex ) % SIZE_3(gradFlow); 183 | 184 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 185 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 186 | 187 | int intNorthwestX = (int) (floor(fltOutputX)); 188 | int intNorthwestY = (int) (floor(fltOutputY)); 189 | int intNortheastX = intNorthwestX + 1; 190 | int intNortheastY = intNorthwestY; 191 | int intSouthwestX = intNorthwestX; 192 | int intSouthwestY = intNorthwestY + 1; 193 | int intSoutheastX = intNorthwestX + 1; 194 | int intSoutheastY = intNorthwestY + 1; 195 | 196 | float fltNorthwest = 0.0; 197 | float fltNortheast = 0.0; 198 | float fltSouthwest = 0.0; 199 | float fltSoutheast = 0.0; 200 | 201 | if (intC == 0) { 202 | fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY); 203 | fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY); 204 | fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); 205 | fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); 206 | 207 | } else if (intC == 1) { 208 | fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (-1.0)); 209 | fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); 210 | fltSouthwest = ((float) (intNortheastX) - fltOutputX) * ((float) (+1.0)); 211 | fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); 212 | 213 | } 214 | 215 | for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { 216 | float fltInput = VALUE_4(input, intN, intChannel, intY, intX); 217 | 218 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { 219 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; 220 | } 221 | 222 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & 
(intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { 223 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; 224 | } 225 | 226 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { 227 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; 228 | } 229 | 230 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { 231 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; 232 | } 233 | } 234 | 235 | gradFlow[intIndex] = fltGradFlow; 236 | } } 237 | ''' 238 | 239 | def cupy_kernel(strFunction, objVariables): 240 | strKernel = globals()[strFunction] 241 | 242 | while True: 243 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 244 | 245 | 246 | if objMatch is None: 247 | break 248 | # end 249 | 250 | intArg = int(objMatch.group(2)) 251 | 252 | strTensor = objMatch.group(4) 253 | 254 | intSizes = objVariables[strTensor].size() 255 | 256 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) 257 | 258 | # end 259 | 260 | while True: 261 | objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) 262 | 263 | if objMatch is None: 264 | break 265 | # end 266 | 267 | 268 | intArgs = int(objMatch.group(2)) 269 | 270 | strArgs = objMatch.group(4).split(',') 271 | 272 | strTensor = strArgs[0] 273 | 274 | intStrides = objVariables[strTensor].stride() 275 | 276 | 277 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 278 | 279 | strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') 280 | 281 | # end 282 | 283 | while True: 284 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 285 | 286 | if objMatch is None: 287 | break 288 | # end 289 | 290 | intArgs = int(objMatch.group(2)) 291 | strArgs = objMatch.group(4).split(',') 292 | 293 | strTensor = strArgs[0] 294 | intStrides = objVariables[strTensor].stride() 295 | 296 | 297 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 298 | 299 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 300 | # end 301 | 302 | return strKernel 303 | # end 304 | 305 | @cupy.memoize(for_each_device=True) 306 | def cupy_launch(strFunction, strKernel): 307 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 308 | # end 309 | 310 | class _FunctionSoftsplat(torch.autograd.Function): 311 | @staticmethod 312 | def forward(self, input, flow, strType): 313 | self.save_for_backward(input, flow) 314 | 315 | intSamples = input.shape[0] 316 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] 317 | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] 318 | 319 | assert(intFlowDepth == 2) 320 | assert(intInputHeight == intFlowHeight) 321 | assert(intInputWidth == intFlowWidth) 322 | 323 | input = input.contiguous(); assert(input.is_cuda == True) 324 | flow = flow.contiguous(); assert(flow.is_cuda == True) 325 | 326 | output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) 327 | 328 | 
if input.is_cuda == True: 329 | n = output.nelement() 330 | if strType != 'max': 331 | cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { 332 | 'input': input, 333 | 'flow': flow, 334 | 'output': output 335 | }))( 336 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 337 | block=tuple([ 512, 1, 1 ]), 338 | args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), output.data_ptr() ] 339 | ) 340 | elif strType == 'max': 341 | cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Maxsplat_updateOutput', { 342 | 'input': input, 343 | 'flow': flow, 344 | 'output': output 345 | }))( 346 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 347 | block=tuple([ 512, 1, 1 ]), 348 | args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), output.data_ptr() ] 349 | ) 350 | elif input.is_cuda == False: 351 | raise NotImplementedError() 352 | 353 | # end 354 | return output 355 | # end 356 | 357 | @staticmethod 358 | def backward(self, gradOutput): 359 | input, flow = self.saved_tensors 360 | 361 | intSamples = input.shape[0] 362 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] 363 | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] 364 | 365 | assert(intFlowDepth == 2) 366 | assert(intInputHeight == intFlowHeight) 367 | assert(intInputWidth == intFlowWidth) 368 | 369 | gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True) 370 | 371 | gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) if self.needs_input_grad[0] == True else None 372 | gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ]) if self.needs_input_grad[1] == True else None 373 | 374 | if input.is_cuda == True: 375 | if gradInput is not None: 376 | n = gradInput.nelement() 377 | cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { 378 | 'input': input, 379 | 'flow': flow, 380 | 'gradOutput': gradOutput, 381 | 'gradInput': gradInput, 382 | 'gradFlow': gradFlow 383 | }))( 384 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 385 | block=tuple([ 512, 1, 1 ]), 386 | args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] 387 | ) 388 | # end 389 | 390 | # if gradFlow is not None: 391 | # n = gradFlow.nelement() 392 | # cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { 393 | # 'input': input, 394 | # 'flow': flow, 395 | # 'gradOutput': gradOutput, 396 | # 'gradInput': gradInput, 397 | # 'gradFlow': gradFlow 398 | # }))( 399 | # grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 400 | # block=tuple([ 512, 1, 1 ]), 401 | # args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] 402 | # ) 403 | # end 404 | 405 | elif input.is_cuda == False: 406 | raise NotImplementedError() 407 | 408 | # end 409 | 410 | return gradInput, gradFlow 411 | # end 412 | # end 413 | 414 | def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 415 | assert(tenMetric is None or tenMetric.shape[1] == 1) 416 | assert(strType in ['summation', 'average', 'linear', 'softmax', 'gumbel_softmax', 'max']) 417 | 418 | if strType == 'average': 419 | tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) 420 | 421 | elif strType == 'linear': 422 | tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) 423 | 
424 | elif strType == 'softmax':
425 | tenInput = torch.cat([ tenInput * tenMetric.exp(), tenMetric.exp() ], 1)
426 | 
427 | elif strType == 'gumbel_softmax':
428 | tenMetric = gumbel_softmax_sample(tenMetric, 0.1)
429 | tenInput = torch.cat([tenInput * tenMetric.exp(), tenMetric.exp()], 1)
430 | 
431 | elif strType == 'max':
432 | tenInput = torch.cat([tenInput, tenMetric], 1)
433 | 
434 | # end
435 | 
436 | tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow, strType)
437 | 
438 | if strType != 'summation' and strType != 'max':
439 | tenNormalize = tenOutput[:, -1:, :, :]
440 | 
441 | tenNormalize[tenNormalize == 0.0] = 1.0
442 | 
443 | tenOutput = tenOutput[:, :-1, :, :] / tenNormalize
444 | # end
445 | elif strType == 'max':
446 | tenOutput = tenOutput[:, :-1, :, :]
447 | 
448 | elif strType == 'summation':
449 | tenOutput = tenOutput
450 | 
451 | return tenOutput
452 | # end
453 | 
454 | 
455 | def gumbel_softmax_sample(logits, temperature):
456 | eps = 1e-20
457 | U = torch.rand(logits.size()).cuda()
458 | sample_gumbel = -torch.log(-torch.log(U + eps) + eps)
459 | y = logits + sample_gumbel
460 | return y/temperature
461 | 
462 | 
463 | class ModuleSoftsplat(torch.nn.Module):
464 | def __init__(self, strType):
465 | super().__init__()
466 | 
467 | self.strType = strType
468 | # end
469 | 
470 | def forward(self, tenInput, tenFlow, tenMetric):
471 | return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType)
472 | # end
473 | # end
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class FlowReversal(nn.Module):
6 | """Flow reversal layer: forward-warps an image with its optical flow and returns the splatted image plus the accumulated weights."""
7 | 
8 | def __init__(self, ):
9 | super(FlowReversal, self).__init__()
10 | 
11 | def forward(self, img, flo):
12 | """
13 | -img: image (N, C, H, W)
14 | -flo: optical flow (N, 2, H, W)
15 | elements of flo are in [0, H] and [0, W] for dx, dy
16 | 
17 | """
18 | 
19 | # (x1, y1) (x1, y2)
20 | # +---------------+
21 | # | |
22 | # | o(x, y) |
23 | # | |
24 | # | |
25 | # | |
26 | # | |
27 | # +---------------+
28 | # (x2, y1) (x2, y2)
29 | 
30 | N, C, _, _ = img.size()
31 | 
32 | # translate start-point optical flow to end-point optical flow
33 | y = flo[:, 0:1, :, :]
34 | x = flo[:, 1:2, :, :]
35 | 
36 | x = x.repeat(1, C, 1, 1)
37 | y = y.repeat(1, C, 1, 1)
38 | 
39 | # ---------------------up left up right down left down right
40 | # Four points of the square: (x1, y1), (x1, y2), (x2, y1), (x2, y2)
41 | x1 = torch.floor(x)
42 | x2 = x1 + 1
43 | y1 = torch.floor(y)
44 | y2 = y1 + 1
45 | 
46 | # firstly, get gaussian weights
47 | w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)
48 | # change to bilinear weights?
49 | # w11, w12, w21, w22 = self.get_bilinear_weights(x, y, x1, x2, y1, y2)
50 | 
51 | # secondly, sample each weighted corner
52 | img11, o11 = self.sample_one(img, x1, y1, w11)
53 | img12, o12 = self.sample_one(img, x1, y2, w12)
54 | img21, o21 = self.sample_one(img, x2, y1, w21)
55 | img22, o22 = self.sample_one(img, x2, y2, w22)
56 | 
57 | imgw = img11 + img12 + img21 + img22
58 | o = o11 + o12 + o21 + o22
59 | 
60 | return imgw, o
61 | 
62 | def get_gaussian_weights(self, x, y, x1, x2, y1, y2):
63 | w11 = torch.exp(-((x - x1) ** 2 + (y - y1) ** 2))
64 | w12 = torch.exp(-((x - x1) ** 2 + (y - y2) ** 2))
65 | w21 = torch.exp(-((x - x2) ** 2 + (y - y1) ** 2))
66 | w22 = torch.exp(-((x - x2) ** 2 + (y - y2) ** 2))
67 | 
68 | return w11, w12, w21, w22
69 | 
70 | def get_bilinear_weights(self, x, y, x1, x2, y1, y2):
71 | w11 = torch.abs((x - x1) * (y - y1))
72 | w12 = torch.abs((x - x1) * (y - y2))
73 | w21 = torch.abs(((x - x2) * (y - y1)))
74 | w22 = torch.abs((x - x2) * (y - y2))
75 | return w22, w21, w12, w11
76 | # return w11, w12, w21, w22
77 | 
78 | def sample_one(self, img, shiftx, shifty, weight):
79 | """
80 | Input:
81 | -img (N, C, H, W)
82 | -shiftx, shifty (N, c, H, W)
83 | """
84 | is_cuda = img.is_cuda
85 | N, C, H, W = img.size()
86 | # flatten all (all stored as Tensors)
87 | flat_shiftx = shiftx.view(-1)
88 | flat_shifty = shifty.view(-1)
89 | if is_cuda:
90 | flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].cuda().long().repeat(N, C, 1, W).view(-1)
91 | flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].cuda().long().repeat(N, C, H, 1).view(-1)
92 | else:
93 | flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].long().repeat(N, C, 1, W).view(-1)
94 | flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].long().repeat(N, C, H, 1).view(-1)
95 | flat_weight = weight.view(-1)
96 | flat_img = img.view(-1)
97 | 
98 | # The corresponding positions in I1
99 | if is_cuda:
100 | idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).long().cuda().repeat(1, C, H, W).view(-1)
101 | idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).long().cuda().repeat(N, 1, H, W).view(-1)
102 | else:
103 | idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).long().repeat(1, C, H, W).view(-1)
104 | idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).long().repeat(N, 1, H, W).view(-1)
105 | # ttype = flat_basex.type()
106 | idxx = flat_shiftx.long() + flat_basex
107 | idxy = flat_shifty.long() + flat_basey
108 | 
109 | # record which shifted positions fall inside the image
110 | mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)
111 | 
112 | # Mask off points out of boundaries
113 | ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy)
114 | if is_cuda:
115 | ids_mask = torch.masked_select(ids, mask).clone().cuda()
116 | 
117 | # Note: the accumulate flag must be True for proper backprop
118 | img_warp = torch.zeros([N * C * H * W, ]).cuda()
119 | one_warp = torch.zeros([N * C * H * W, ]).cuda()
120 | else:
121 | ids_mask = torch.masked_select(ids, mask).clone()
122 | 
123 | # Note: the accumulate flag must be True for proper backprop
124 | img_warp = torch.zeros([N * C * H * W, ])
125 | one_warp = torch.zeros([N * C * H * W, ])
126 | img_warp.put_(ids_mask, torch.masked_select(flat_img * flat_weight, mask), accumulate=True)
127 | one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)
128 | return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W)
129 | 
130 | def forward_occ(self, img, flo, occ):
131 | """
132 | -img: image (N, C, H, W)
133 | -flo: optical flow (N, 2, H, W)
134 | elements of flo are in [0, H] and [0, W] for dx, dy
135 | 
136 | """
137 | 
138 | # (x1, y1) (x1, y2)
139 | # +---------------+
140 | # | |
141 | # | o(x, y) |
142 | # | |
143 | # | |
144 | # | |
145 | # | |
146 | # +---------------+
147 | # (x2, y1) (x2, y2)
148 | 
149 | N, C, _, _ = img.size()
150 | 
151 | # translate start-point optical flow to end-point optical flow
152 | y = flo[:, 0:1, :, :]
153 | x = flo[:, 1:2, :, :]
154 | 
155 | x = x.repeat(1, C, 1, 1)
156 | y = y.repeat(1, C, 1, 1)
157 | 
158 | # ---------------------up left up right down left down right
159 | # Four points of the square: (x1, y1), (x1, y2), (x2, y1), (x2, y2)
160 | x1 = torch.floor(x)
161 | x2 = x1 + 1
162 | y1 = torch.floor(y)
163 | y2 = y1 + 1
164 | 
165 | # firstly, get gaussian weights
166 | # w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)
167 | # change to bilinear weights?
168 | w11, w12, w21, w22 = self.get_bilinear_weights(x, y, x1, x2, y1, y2)
169 | 
170 | # secondly, sample each weighted corner
171 | img11, o11 = self.sample_one_occ(img, x1, y1, w11 * occ)
172 | img12, o12 = self.sample_one_occ(img, x1, y2, w12 * occ)
173 | img21, o21 = self.sample_one_occ(img, x2, y1, w21 * occ)
174 | img22, o22 = self.sample_one_occ(img, x2, y2, w22 * occ)
175 | 
176 | imgw = img11 + img12 + img21 + img22
177 | o = o11 + o12 + o21 + o22
178 | 
179 | return imgw, o
180 | 
181 | def sample_one_occ(self, img, shiftx, shifty, weight):
182 | """
183 | Input:
184 | -img (N, C, H, W)
185 | -shiftx, shifty (N, c, H, W)
186 | """
187 | is_cuda = img.is_cuda
188 | N, C, H, W = img.size()
189 | # flatten all (all stored as Tensors)
190 | flat_shiftx = shiftx.view(-1)
191 | flat_shifty = shifty.view(-1)
192 | if is_cuda:
193 | flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].cuda().long().repeat(N, C, 1, W).view(-1)
194 | flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].cuda().long().repeat(N, C, H, 1).view(-1)
195 | # The corresponding positions in I1
196 | idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).long().cuda().repeat(1, C, H, W).view(-1)
197 | idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).long().cuda().repeat(N, 1, H, W).view(-1)
198 | else:
199 | flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].long().repeat(N, C, 1, W).view(-1)
200 | flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].long().repeat(N, C, H, 1).view(-1)
201 | # The corresponding positions in I1
202 | idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).long().repeat(1, C, H, W).view(-1)
203 | idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).long().repeat(N, 1, H, W).view(-1)
204 | flat_weight = weight.view(-1)
205 | flat_img = img.view(-1)
206 | # ttype = flat_basex.type()
207 | idxx = flat_shiftx.long() + flat_basex
208 | idxy = flat_shifty.long() + flat_basey
209 | 
210 | # record which shifted positions fall inside the image
211 | mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)
idxy.lt(W) 212 | 213 | # Mask off points out of boundaries 214 | ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy) 215 | if is_cuda: 216 | ids_mask = torch.masked_select(ids, mask).clone().cuda() 217 | # Note here! accmulate fla must be true for proper bp 218 | img_warp = torch.zeros([N * C * H * W, ]).cuda() 219 | one_warp = torch.zeros([N * C * H * W, ]).cuda() 220 | else: 221 | ids_mask = torch.masked_select(ids, mask).clone() 222 | # Note here! accmulate fla must be true for proper bp 223 | img_warp = torch.zeros([N * C * H * W, ]) 224 | one_warp = torch.zeros([N * C * H * W, ]) 225 | 226 | img_warp.put_(ids_mask, torch.masked_select(flat_img * flat_weight, mask), accumulate=True) 227 | one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True) 228 | return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W) 229 | -------------------------------------------------------------------------------- /utils/video2img.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | import cv2 4 | import os 5 | from utils_luo.tools import file_tools 6 | 7 | def save_img(): 8 | video_path = '/data/BDD100k/bdd100k/videos/train/' 9 | img_path = '/data/BDD100k/bdd100k/imgs/train/' 10 | if not os.path.exists(img_path): 11 | os.makedirs(img_path) 12 | videos = os.listdir(video_path) 13 | for video_name in videos: 14 | file_name = video_name.split('.')[0] 15 | folder_name = img_path + file_name 16 | if not os.path.exists(folder_name): 17 | os.makedirs(folder_name) 18 | vc = cv2.VideoCapture(video_path+video_name) #读入视频文件 19 | c = 0 20 | i = 0 21 | rval=vc.isOpened() 22 | 23 | while rval: #循环读取视频帧 24 | c = c + 1 25 | rval, frame = vc.read() 26 | pic_path = folder_name+'/' 27 | if rval and c % 5 == 0: 28 | i = i+1 29 | cv2.imwrite(pic_path + file_name + '_' + str(i) + '.jpg', frame[:512, :960, :]) 30 | cv2.waitKey(1) 31 | elif not rval: 32 | break 33 | vc.release() 34 | print('save_success') 35 | print(folder_name) 36 | 37 | save_img() --------------------------------------------------------------------------------