├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── LICENSE.txt
├── README.md
├── __init__.py
├── examples
│   └── propainter-inpainting-workflow.json
├── model
│   ├── __init__.py
│   ├── canny
│   │   ├── canny_filter.py
│   │   ├── filter.py
│   │   ├── gaussian.py
│   │   ├── kernels.py
│   │   └── sobel.py
│   ├── misc.py
│   ├── modules
│   │   ├── RAFT
│   │   │   ├── __init__.py
│   │   │   ├── corr.py
│   │   │   ├── datasets.py
│   │   │   ├── demo.py
│   │   │   ├── extractor.py
│   │   │   ├── raft.py
│   │   │   ├── update.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── augmentor.py
│   │   │       ├── flow_viz.py
│   │   │       ├── flow_viz_pt.py
│   │   │       ├── frame_utils.py
│   │   │       └── utils.py
│   │   ├── base_module.py
│   │   ├── deformconv.py
│   │   ├── flow_comp_raft.py
│   │   ├── flow_loss_utils.py
│   │   ├── sparse_transformer.py
│   │   └── spectral_norm.py
│   ├── propainter.py
│   ├── recurrent_flow_completion.py
│   └── vgg_arch.py
├── propainter_inference.py
├── propainter_nodes.py
├── pyproject.toml
├── requirements.txt
└── utils
    ├── __init__.py
    ├── download_utils.py
    ├── image_utils.py
    └── model_utils.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Check out code 17 | uses: actions/checkout@v4 18 | - name: Publish Custom Node 19 | uses: Comfy-Org/publish-node-action@main 20 | with: 21 | ## Add your own personal access token to your Github Repository secrets and reference it here. 22 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | results/ 2 | weights/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | mypy.ini 149 | 150 | # ruff 151 | .ruff_cache/ 152 | ruff.toml 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProPainter Nodes for ComfyUI
2 | 
3 | [ComfyUI](https://github.com/comfyanonymous/ComfyUI) implementation of [ProPainter](https://github.com/sczhou/ProPainter) for video inpainting. ProPainter is a framework that utilizes flow-based propagation and a spatiotemporal transformer to enable advanced video frame editing for seamless inpainting tasks.
4 | 
5 | ## Features
6 | 
7 | #### 👨🏻‍🎨 Object Removal
8 | 
19 | #### 🎨 Video Completion
20 | 
30 | 
31 | ## Installation
32 | ### ComfyUI Manager:
33 | You can use [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) to install the nodes:
34 | 1. Search for `ComfyUI ProPainter Nodes` and author `daniabib`.
35 | 
36 | ### Manual Installation:
37 | 1. Clone this repository into `ComfyUI/custom_nodes`:
38 | ```bash
39 | git clone https://github.com/daniabib/ComfyUI_ProPainter_Nodes
40 | ```
41 | 
42 | 2. Install the required dependencies:
43 | ```bash
44 | pip install -r requirements.txt
45 | ```
46 | 
47 | Models will be automatically downloaded to the `weights` folder.
48 | 
49 | ## Examples
50 | **Basic Inpainting Workflow**
51 | 
52 | https://github.com/daniabib/ComfyUI_ProPainter_Nodes/assets/33937060/56244d09-fe89-4af2-916b-e8d903752f0d
53 | 
54 | https://github.com/daniabib/ComfyUI_ProPainter_Nodes/blob/main/examples/propainter-inpainting-workflow.json
55 | 
56 | ## Other suggested nodes
57 | * [VideoHelperSuite](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite) for loading and saving the video frames.
58 | * [ComfyUI-YoloWorld-EfficientSAM](https://github.com/ZHO-ZHO-ZHO/ComfyUI-YoloWorld-EfficientSAM) for masking images using prompts.
59 | 
60 | ## Nodes Reference
61 | **🚧 Section under construction**
62 | ### ProPainter Inpainting
63 | 
64 | #### Input Parameters:
65 | - `image`: The video frames to be inpainted.
66 | - `mask`: The mask indicating the regions to be inpainted. The mask must be the same size as the video frames.
67 | - `width`: Width of the output images (default: 640).
68 | - `height`: Height of the output images (default: 360).
69 | - `mask_dilates`: Dilation size for the mask (default: 5).
70 | - `flow_mask_dilates`: Dilation size for the flow mask (default: 8).
71 | - `ref_stride`: Stride for reference frames (default: 10).
72 | - `neighbor_length`: Length of the neighborhood for inpainting (default: 10).
73 | - `subvideo_length`: Length of subvideos for processing (default: 80).
74 | - `raft_iter`: Number of iterations for the RAFT model (default: 20).
75 | - `fp16`: Enable or disable FP16 precision (default: "enable").
76 | 
77 | #### Output:
78 | - `IMAGE`: The inpainted video frames.
79 | - `FLOW_MASK`: The flow mask used during inpainting.
80 | - `MASK_DILATE`: The dilated mask used during inpainting.
81 | 
82 | ### ProPainter Outpainting
83 | **Note**: The paper doesn't cover the outpainting task, but there is an option for it in the original code. The results aren't very good, but I decided to implement a node for it anyway.
84 | 
85 | #### Input Parameters:
86 | - `image`: The video frames to be outpainted.
87 | - `width`: Width of the video frames (default: 640).
88 | - `height`: Height of the video frames (default: 360).
89 | - `width_scale`: Scale factor for width expansion (default: 1.2).
90 | - `height_scale`: Scale factor for height expansion (default: 1.0).
91 | - `mask_dilates`: Dilation size for the mask (default: 5).
92 | - `flow_mask_dilates`: Dilation size for the flow mask (default: 8).
93 | - `ref_stride`: Stride for reference frames (default: 10).
94 | - `neighbor_length`: Length of the neighborhood for outpainting (default: 10).
95 | - `subvideo_length`: Length of subvideos for processing (default: 80).
96 | - `raft_iter`: Number of iterations for the RAFT model (default: 20).
97 | - `fp16`: Enable or disable FP16 precision (default: "disable").
98 | 
99 | #### Output:
100 | - `IMAGE`: The outpainted video frames.
101 | - `OUTPAINT_MASK`: The mask used during outpainting.
102 | - `output_width`: The width of the outpainted frames. 103 | - `output_height`: The height of the outpainted frames. 104 | 105 | ## License 106 | The ProPainter models and code are licensed under [NTU S-Lab License 1.0](https://github.com/sczhou/ProPainter/blob/main/LICENSE). 107 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .propainter_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 4 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/canny/canny_filter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from .gaussian import gaussian_blur2d 9 | from .kernels import get_canny_nms_kernel, get_hysteresis_kernel 10 | from .sobel import spatial_gradient 11 | 12 | 13 | def rgb_to_grayscale(image, rgb_weights=None): 14 | if len(image.shape) < 3 or image.shape[-3] != 3: 15 | raise ValueError( 16 | f"Input size must have a shape of (*, 3, H, W). Got {image.shape}" 17 | ) 18 | 19 | if rgb_weights is None: 20 | # 8 bit images 21 | if image.dtype == torch.uint8: 22 | rgb_weights = torch.tensor( 23 | [76, 150, 29], device=image.device, dtype=torch.uint8 24 | ) 25 | # floating point images 26 | elif image.dtype in (torch.float16, torch.float32, torch.float64): 27 | rgb_weights = torch.tensor( 28 | [0.299, 0.587, 0.114], device=image.device, dtype=image.dtype 29 | ) 30 | else: 31 | raise TypeError(f"Unknown data type: {image.dtype}") 32 | else: 33 | # is tensor that we make sure is in the same device/dtype 34 | rgb_weights = rgb_weights.to(image) 35 | 36 | # unpack the color image channels with RGB order 37 | r = image[..., 0:1, :, :] 38 | g = image[..., 1:2, :, :] 39 | b = image[..., 2:3, :, :] 40 | 41 | w_r, w_g, w_b = rgb_weights.unbind() 42 | return w_r * r + w_g * g + w_b * b 43 | 44 | 45 | def canny( 46 | input: torch.Tensor, 47 | low_threshold: float = 0.1, 48 | high_threshold: float = 0.2, 49 | kernel_size: Tuple[int, int] = (5, 5), 50 | sigma: Tuple[float, float] = (1, 1), 51 | hysteresis: bool = True, 52 | eps: float = 1e-6, 53 | ) -> Tuple[torch.Tensor, torch.Tensor]: 54 | r"""Find edges of the input image and filters them using the Canny algorithm. 55 | 56 | .. image:: _static/img/canny.png 57 | 58 | Args: 59 | input: input image tensor with shape :math:`(B,C,H,W)`. 60 | low_threshold: lower threshold for the hysteresis procedure. 61 | high_threshold: upper threshold for the hysteresis procedure. 62 | kernel_size: the size of the kernel for the gaussian blur. 63 | sigma: the standard deviation of the kernel for the gaussian blur. 64 | hysteresis: if True, applies the hysteresis edge tracking. 65 | Otherwise, the edges are divided between weak (0.5) and strong (1) edges. 66 | eps: regularization number to avoid NaN during backprop. 67 | 68 | Returns: 69 | - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. 70 | - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. 71 | 72 | .. note:: 73 | See a working example `here `__. 
75 | 76 | Example: 77 | >>> input = torch.rand(5, 3, 4, 4) 78 | >>> magnitude, edges = canny(input) # 5x3x4x4 79 | >>> magnitude.shape 80 | torch.Size([5, 1, 4, 4]) 81 | >>> edges.shape 82 | torch.Size([5, 1, 4, 4]) 83 | """ 84 | if not isinstance(input, torch.Tensor): 85 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") 86 | 87 | if not len(input.shape) == 4: 88 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") 89 | 90 | if low_threshold > high_threshold: 91 | raise ValueError( 92 | f"Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {low_threshold}>{high_threshold}" 93 | ) 94 | 95 | if low_threshold < 0 and low_threshold > 1: 96 | raise ValueError( 97 | f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}" 98 | ) 99 | 100 | if high_threshold < 0 and high_threshold > 1: 101 | raise ValueError( 102 | f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}" 103 | ) 104 | 105 | device: torch.device = input.device 106 | dtype: torch.dtype = input.dtype 107 | 108 | # To Grayscale 109 | if input.shape[1] == 3: 110 | input = rgb_to_grayscale(input) 111 | 112 | # Gaussian filter 113 | blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma) 114 | 115 | # Compute the gradients 116 | gradients: torch.Tensor = spatial_gradient(blurred, normalized=False) 117 | 118 | # Unpack the edges 119 | gx: torch.Tensor = gradients[:, :, 0] 120 | gy: torch.Tensor = gradients[:, :, 1] 121 | 122 | # Compute gradient magnitude and angle 123 | magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) 124 | angle: torch.Tensor = torch.atan2(gy, gx) 125 | 126 | # Radians to Degrees 127 | angle = 180.0 * angle / math.pi 128 | 129 | # Round angle to the nearest 45 degree 130 | angle = torch.round(angle / 45) * 45 131 | 132 | # Non-maximal suppression 133 | nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype) 134 | nms_magnitude: torch.Tensor = F.conv2d( 135 | magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2 136 | ) 137 | 138 | # Get the indices for both directions 139 | positive_idx: torch.Tensor = (angle / 45) % 8 140 | positive_idx = positive_idx.long() 141 | 142 | negative_idx: torch.Tensor = ((angle / 45) + 4) % 8 143 | negative_idx = negative_idx.long() 144 | 145 | # Apply the non-maximum suppression to the different directions 146 | channel_select_filtered_positive: torch.Tensor = torch.gather( 147 | nms_magnitude, 1, positive_idx 148 | ) 149 | channel_select_filtered_negative: torch.Tensor = torch.gather( 150 | nms_magnitude, 1, negative_idx 151 | ) 152 | 153 | channel_select_filtered: torch.Tensor = torch.stack( 154 | [channel_select_filtered_positive, channel_select_filtered_negative], 1 155 | ) 156 | 157 | is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0 158 | 159 | magnitude = magnitude * is_max 160 | 161 | # Threshold 162 | edges: torch.Tensor = F.threshold(magnitude, low_threshold, 0.0) 163 | 164 | low: torch.Tensor = magnitude > low_threshold 165 | high: torch.Tensor = magnitude > high_threshold 166 | 167 | edges = low * 0.5 + high * 0.5 168 | edges = edges.to(dtype) 169 | 170 | # Hysteresis 171 | if hysteresis: 172 | edges_old: torch.Tensor = -torch.ones( 173 | edges.shape, device=edges.device, dtype=dtype 174 | ) 175 | hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype) 176 | 177 | while ((edges_old - edges).abs() != 0).any(): 178 | weak: torch.Tensor = (edges == 
0.5).float() 179 | strong: torch.Tensor = (edges == 1).float() 180 | 181 | hysteresis_magnitude: torch.Tensor = F.conv2d( 182 | edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2 183 | ) 184 | hysteresis_magnitude = ( 185 | (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype) 186 | ) 187 | hysteresis_magnitude = hysteresis_magnitude * weak + strong 188 | 189 | edges_old = edges.clone() 190 | edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5 191 | 192 | edges = hysteresis_magnitude 193 | 194 | return magnitude, edges 195 | 196 | 197 | class Canny(nn.Module): 198 | r"""Module that finds edges of the input image and filters them using the Canny algorithm. 199 | 200 | Args: 201 | input: input image tensor with shape :math:`(B,C,H,W)`. 202 | low_threshold: lower threshold for the hysteresis procedure. 203 | high_threshold: upper threshold for the hysteresis procedure. 204 | kernel_size: the size of the kernel for the gaussian blur. 205 | sigma: the standard deviation of the kernel for the gaussian blur. 206 | hysteresis: if True, applies the hysteresis edge tracking. 207 | Otherwise, the edges are divided between weak (0.5) and strong (1) edges. 208 | eps: regularization number to avoid NaN during backprop. 209 | 210 | Returns: 211 | - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. 212 | - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. 213 | 214 | Example: 215 | >>> input = torch.rand(5, 3, 4, 4) 216 | >>> magnitude, edges = Canny()(input) # 5x3x4x4 217 | >>> magnitude.shape 218 | torch.Size([5, 1, 4, 4]) 219 | >>> edges.shape 220 | torch.Size([5, 1, 4, 4]) 221 | """ 222 | 223 | def __init__( 224 | self, 225 | low_threshold: float = 0.1, 226 | high_threshold: float = 0.2, 227 | kernel_size: Tuple[int, int] = (5, 5), 228 | sigma: Tuple[float, float] = (1, 1), 229 | hysteresis: bool = True, 230 | eps: float = 1e-6, 231 | ) -> None: 232 | super().__init__() 233 | 234 | if low_threshold > high_threshold: 235 | raise ValueError( 236 | f"Invalid input thresholds. low_threshold should be\ 237 | smaller than the high_threshold. Got: {low_threshold}>{high_threshold}" 238 | ) 239 | 240 | if low_threshold < 0 or low_threshold > 1: 241 | raise ValueError( 242 | f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}" 243 | ) 244 | 245 | if high_threshold < 0 or high_threshold > 1: 246 | raise ValueError( 247 | f"Invalid input threshold. high_threshold should be in range (0,1). 
Got: {high_threshold}" 248 | ) 249 | 250 | # Gaussian blur parameters 251 | self.kernel_size = kernel_size 252 | self.sigma = sigma 253 | 254 | # Double threshold 255 | self.low_threshold = low_threshold 256 | self.high_threshold = high_threshold 257 | 258 | # Hysteresis 259 | self.hysteresis = hysteresis 260 | 261 | self.eps: float = eps 262 | 263 | def __repr__(self) -> str: 264 | return "".join( 265 | ( 266 | f"{type(self).__name__}(", 267 | ", ".join( 268 | f"{name}={getattr(self, name)}" 269 | for name in sorted(self.__dict__) 270 | if not name.startswith("_") 271 | ), 272 | ")", 273 | ) 274 | ) 275 | 276 | def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 277 | return canny( 278 | input, 279 | self.low_threshold, 280 | self.high_threshold, 281 | self.kernel_size, 282 | self.sigma, 283 | self.hysteresis, 284 | self.eps, 285 | ) 286 | -------------------------------------------------------------------------------- /model/canny/filter.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .kernels import normalize_kernel2d 7 | 8 | 9 | def _compute_padding(kernel_size: List[int]) -> List[int]: 10 | """Compute padding tuple.""" 11 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 12 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 13 | if len(kernel_size) < 2: 14 | raise AssertionError(kernel_size) 15 | computed = [k - 1 for k in kernel_size] 16 | 17 | # for even kernels we need to do asymmetric padding :( 18 | out_padding = 2 * len(kernel_size) * [0] 19 | 20 | for i in range(len(kernel_size)): 21 | computed_tmp = computed[-(i + 1)] 22 | 23 | pad_front = computed_tmp // 2 24 | pad_rear = computed_tmp - pad_front 25 | 26 | out_padding[2 * i + 0] = pad_front 27 | out_padding[2 * i + 1] = pad_rear 28 | 29 | return out_padding 30 | 31 | 32 | def filter2d( 33 | input: torch.Tensor, 34 | kernel: torch.Tensor, 35 | border_type: str = "reflect", 36 | normalized: bool = False, 37 | padding: str = "same", 38 | ) -> torch.Tensor: 39 | r"""Convolve a tensor with a 2d kernel. 40 | 41 | The function applies a given kernel to a tensor. The kernel is applied 42 | independently at each depth channel of the tensor. Before applying the 43 | kernel, the function applies padding according to the specified mode so 44 | that the output remains in the same shape. 45 | 46 | Args: 47 | input: the input tensor with shape of 48 | :math:`(B, C, H, W)`. 49 | kernel: the kernel to be convolved with the input 50 | tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`. 51 | border_type: the padding mode to be applied before convolving. 52 | The expected modes are: ``'constant'``, ``'reflect'``, 53 | ``'replicate'`` or ``'circular'``. 54 | normalized: If True, kernel will be L1 normalized. 55 | padding: This defines the type of padding. 56 | 2 modes available ``'same'`` or ``'valid'``. 57 | 58 | Return: 59 | torch.Tensor: the convolved tensor of same size and numbers of channels 60 | as the input with shape :math:`(B, C, H, W)`. 61 | 62 | Example: 63 | >>> input = torch.tensor([[[ 64 | ... [0., 0., 0., 0., 0.], 65 | ... [0., 0., 0., 0., 0.], 66 | ... [0., 0., 5., 0., 0.], 67 | ... [0., 0., 0., 0., 0.], 68 | ... 
[0., 0., 0., 0., 0.],]]]) 69 | >>> kernel = torch.ones(1, 3, 3) 70 | >>> filter2d(input, kernel, padding='same') 71 | tensor([[[[0., 0., 0., 0., 0.], 72 | [0., 5., 5., 5., 0.], 73 | [0., 5., 5., 5., 0.], 74 | [0., 5., 5., 5., 0.], 75 | [0., 0., 0., 0., 0.]]]]) 76 | """ 77 | if not isinstance(input, torch.Tensor): 78 | raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}") 79 | 80 | if not isinstance(kernel, torch.Tensor): 81 | raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}") 82 | 83 | if not isinstance(border_type, str): 84 | raise TypeError(f"Input border_type is not string. Got {type(border_type)}") 85 | 86 | if border_type not in ["constant", "reflect", "replicate", "circular"]: 87 | raise ValueError( 88 | f"Invalid border type, we expect 'constant', \ 89 | 'reflect', 'replicate', 'circular'. Got:{border_type}" 90 | ) 91 | 92 | if not isinstance(padding, str): 93 | raise TypeError(f"Input padding is not string. Got {type(padding)}") 94 | 95 | if padding not in ["valid", "same"]: 96 | raise ValueError( 97 | f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}" 98 | ) 99 | 100 | if not len(input.shape) == 4: 101 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") 102 | 103 | if (not len(kernel.shape) == 3) and not ( 104 | (kernel.shape[0] == 0) or (kernel.shape[0] == input.shape[0]) 105 | ): 106 | raise ValueError( 107 | f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}" 108 | ) 109 | 110 | # prepare kernel 111 | b, c, h, w = input.shape 112 | tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input) 113 | 114 | if normalized: 115 | tmp_kernel = normalize_kernel2d(tmp_kernel) 116 | 117 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 118 | 119 | height, width = tmp_kernel.shape[-2:] 120 | 121 | # pad the input tensor 122 | if padding == "same": 123 | padding_shape: List[int] = _compute_padding([height, width]) 124 | input = F.pad(input, padding_shape, mode=border_type) 125 | 126 | # kernel and input tensor reshape to align element-wise or batch-wise params 127 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) 128 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) 129 | 130 | # convolve the tensor with the kernel. 131 | output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 132 | 133 | if padding == "same": 134 | out = output.view(b, c, h, w) 135 | else: 136 | out = output.view(b, c, h - height + 1, w - width + 1) 137 | 138 | return out 139 | 140 | 141 | def filter2d_separable( 142 | input: torch.Tensor, 143 | kernel_x: torch.Tensor, 144 | kernel_y: torch.Tensor, 145 | border_type: str = "reflect", 146 | normalized: bool = False, 147 | padding: str = "same", 148 | ) -> torch.Tensor: 149 | r"""Convolve a tensor with two 1d kernels, in x and y directions. 150 | 151 | The function applies a given kernel to a tensor. The kernel is applied 152 | independently at each depth channel of the tensor. Before applying the 153 | kernel, the function applies padding according to the specified mode so 154 | that the output remains in the same shape. 155 | 156 | Args: 157 | input: the input tensor with shape of 158 | :math:`(B, C, H, W)`. 159 | kernel_x: the kernel to be convolved with the input 160 | tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`. 161 | kernel_y: the kernel to be convolved with the input 162 | tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`. 
163 | border_type: the padding mode to be applied before convolving. 164 | The expected modes are: ``'constant'``, ``'reflect'``, 165 | ``'replicate'`` or ``'circular'``. 166 | normalized: If True, kernel will be L1 normalized. 167 | padding: This defines the type of padding. 168 | 2 modes available ``'same'`` or ``'valid'``. 169 | 170 | Return: 171 | torch.Tensor: the convolved tensor of same size and numbers of channels 172 | as the input with shape :math:`(B, C, H, W)`. 173 | 174 | Example: 175 | >>> input = torch.tensor([[[ 176 | ... [0., 0., 0., 0., 0.], 177 | ... [0., 0., 0., 0., 0.], 178 | ... [0., 0., 5., 0., 0.], 179 | ... [0., 0., 0., 0., 0.], 180 | ... [0., 0., 0., 0., 0.],]]]) 181 | >>> kernel = torch.ones(1, 3) 182 | 183 | >>> filter2d_separable(input, kernel, kernel, padding='same') 184 | tensor([[[[0., 0., 0., 0., 0.], 185 | [0., 5., 5., 5., 0.], 186 | [0., 5., 5., 5., 0.], 187 | [0., 5., 5., 5., 0.], 188 | [0., 0., 0., 0., 0.]]]]) 189 | """ 190 | out_x = filter2d(input, kernel_x.unsqueeze(0), border_type, normalized, padding) 191 | out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding) 192 | return out 193 | 194 | 195 | def filter3d( 196 | input: torch.Tensor, 197 | kernel: torch.Tensor, 198 | border_type: str = "replicate", 199 | normalized: bool = False, 200 | ) -> torch.Tensor: 201 | r"""Convolve a tensor with a 3d kernel. 202 | 203 | The function applies a given kernel to a tensor. The kernel is applied 204 | independently at each depth channel of the tensor. Before applying the 205 | kernel, the function applies padding according to the specified mode so 206 | that the output remains in the same shape. 207 | 208 | Args: 209 | input: the input tensor with shape of 210 | :math:`(B, C, D, H, W)`. 211 | kernel: the kernel to be convolved with the input 212 | tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`. 213 | border_type: the padding mode to be applied before convolving. 214 | The expected modes are: ``'constant'``, 215 | ``'replicate'`` or ``'circular'``. 216 | normalized: If True, kernel will be L1 normalized. 217 | 218 | Return: 219 | the convolved tensor of same size and numbers of channels 220 | as the input with shape :math:`(B, C, D, H, W)`. 221 | 222 | Example: 223 | >>> input = torch.tensor([[[ 224 | ... [[0., 0., 0., 0., 0.], 225 | ... [0., 0., 0., 0., 0.], 226 | ... [0., 0., 0., 0., 0.], 227 | ... [0., 0., 0., 0., 0.], 228 | ... [0., 0., 0., 0., 0.]], 229 | ... [[0., 0., 0., 0., 0.], 230 | ... [0., 0., 0., 0., 0.], 231 | ... [0., 0., 5., 0., 0.], 232 | ... [0., 0., 0., 0., 0.], 233 | ... [0., 0., 0., 0., 0.]], 234 | ... [[0., 0., 0., 0., 0.], 235 | ... [0., 0., 0., 0., 0.], 236 | ... [0., 0., 0., 0., 0.], 237 | ... [0., 0., 0., 0., 0.], 238 | ... [0., 0., 0., 0., 0.]] 239 | ... ]]]) 240 | >>> kernel = torch.ones(1, 3, 3, 3) 241 | >>> filter3d(input, kernel) 242 | tensor([[[[[0., 0., 0., 0., 0.], 243 | [0., 5., 5., 5., 0.], 244 | [0., 5., 5., 5., 0.], 245 | [0., 5., 5., 5., 0.], 246 | [0., 0., 0., 0., 0.]], 247 | 248 | [[0., 0., 0., 0., 0.], 249 | [0., 5., 5., 5., 0.], 250 | [0., 5., 5., 5., 0.], 251 | [0., 5., 5., 5., 0.], 252 | [0., 0., 0., 0., 0.]], 253 | 254 | [[0., 0., 0., 0., 0.], 255 | [0., 5., 5., 5., 0.], 256 | [0., 5., 5., 5., 0.], 257 | [0., 5., 5., 5., 0.], 258 | [0., 0., 0., 0., 0.]]]]]) 259 | """ 260 | if not isinstance(input, torch.Tensor): 261 | raise TypeError(f"Input border_type is not torch.Tensor. 
Got {type(input)}") 262 | 263 | if not isinstance(kernel, torch.Tensor): 264 | raise TypeError(f"Input border_type is not torch.Tensor. Got {type(kernel)}") 265 | 266 | if not isinstance(border_type, str): 267 | raise TypeError(f"Input border_type is not string. Got {type(kernel)}") 268 | 269 | if not len(input.shape) == 5: 270 | raise ValueError( 271 | f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}" 272 | ) 273 | 274 | if not len(kernel.shape) == 4 and kernel.shape[0] != 1: 275 | raise ValueError( 276 | f"Invalid kernel shape, we expect 1xDxHxW. Got: {kernel.shape}" 277 | ) 278 | 279 | # prepare kernel 280 | b, c, d, h, w = input.shape 281 | tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input) 282 | 283 | if normalized: 284 | bk, dk, hk, wk = kernel.shape 285 | tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as( 286 | tmp_kernel 287 | ) 288 | 289 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1) 290 | 291 | # pad the input tensor 292 | depth, height, width = tmp_kernel.shape[-3:] 293 | padding_shape: List[int] = _compute_padding([depth, height, width]) 294 | input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type) 295 | 296 | # kernel and input tensor reshape to align element-wise or batch-wise params 297 | tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width) 298 | input_pad = input_pad.view( 299 | -1, 300 | tmp_kernel.size(0), 301 | input_pad.size(-3), 302 | input_pad.size(-2), 303 | input_pad.size(-1), 304 | ) 305 | 306 | # convolve the tensor with the kernel. 307 | output = F.conv3d( 308 | input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1 309 | ) 310 | 311 | return output.view(b, c, d, h, w) 312 | -------------------------------------------------------------------------------- /model/canny/gaussian.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .filter import filter2d, filter2d_separable 7 | from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d 8 | 9 | 10 | def gaussian_blur2d( 11 | input: torch.Tensor, 12 | kernel_size: Tuple[int, int], 13 | sigma: Tuple[float, float], 14 | border_type: str = "reflect", 15 | separable: bool = True, 16 | ) -> torch.Tensor: 17 | r"""Create an operator that blurs a tensor using a Gaussian filter. 18 | 19 | .. image:: _static/img/gaussian_blur2d.png 20 | 21 | The operator smooths the given tensor with a gaussian kernel by convolving 22 | it to each channel. It supports batched operation. 23 | 24 | Arguments: 25 | input: the input tensor with shape :math:`(B,C,H,W)`. 26 | kernel_size: the size of the kernel. 27 | sigma: the standard deviation of the kernel. 28 | border_type: the padding mode to be applied before convolving. 29 | The expected modes are: ``'constant'``, ``'reflect'``, 30 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. 31 | separable: run as composition of two 1d-convolutions. 32 | 33 | Returns: 34 | the blurred tensor with shape :math:`(B, C, H, W)`. 35 | 36 | .. note:: 37 | See a working example `here `__. 
39 | 40 | Examples: 41 | >>> input = torch.rand(2, 4, 5, 5) 42 | >>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5)) 43 | >>> output.shape 44 | torch.Size([2, 4, 5, 5]) 45 | """ 46 | if separable: 47 | kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1]) 48 | kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0]) 49 | out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type) 50 | else: 51 | kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma) 52 | out = filter2d(input, kernel[None], border_type) 53 | return out 54 | 55 | 56 | class GaussianBlur2d(nn.Module): 57 | r"""Create an operator that blurs a tensor using a Gaussian filter. 58 | 59 | The operator smooths the given tensor with a gaussian kernel by convolving 60 | it to each channel. It supports batched operation. 61 | 62 | Arguments: 63 | kernel_size: the size of the kernel. 64 | sigma: the standard deviation of the kernel. 65 | border_type: the padding mode to be applied before convolving. 66 | The expected modes are: ``'constant'``, ``'reflect'``, 67 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. 68 | separable: run as composition of two 1d-convolutions. 69 | 70 | Returns: 71 | the blurred tensor. 72 | 73 | Shape: 74 | - Input: :math:`(B, C, H, W)` 75 | - Output: :math:`(B, C, H, W)` 76 | 77 | Examples:: 78 | 79 | >>> input = torch.rand(2, 4, 5, 5) 80 | >>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5)) 81 | >>> output = gauss(input) # 2x4x5x5 82 | >>> output.shape 83 | torch.Size([2, 4, 5, 5]) 84 | """ 85 | 86 | def __init__( 87 | self, 88 | kernel_size: Tuple[int, int], 89 | sigma: Tuple[float, float], 90 | border_type: str = "reflect", 91 | separable: bool = True, 92 | ) -> None: 93 | super().__init__() 94 | self.kernel_size: Tuple[int, int] = kernel_size 95 | self.sigma: Tuple[float, float] = sigma 96 | self.border_type = border_type 97 | self.separable = separable 98 | 99 | def __repr__(self) -> str: 100 | return ( 101 | self.__class__.__name__ 102 | + "(kernel_size=" 103 | + str(self.kernel_size) 104 | + ", " 105 | + "sigma=" 106 | + str(self.sigma) 107 | + ", " 108 | + "border_type=" 109 | + self.border_type 110 | + "separable=" 111 | + str(self.separable) 112 | + ")" 113 | ) 114 | 115 | def forward(self, input: torch.Tensor) -> torch.Tensor: 116 | return gaussian_blur2d( 117 | input, self.kernel_size, self.sigma, self.border_type, self.separable 118 | ) 119 | -------------------------------------------------------------------------------- /model/canny/sobel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from .kernels import ( 6 | get_spatial_gradient_kernel2d, 7 | get_spatial_gradient_kernel3d, 8 | normalize_kernel2d, 9 | ) 10 | 11 | 12 | def spatial_gradient( 13 | input: torch.Tensor, mode: str = "sobel", order: int = 1, normalized: bool = True 14 | ) -> torch.Tensor: 15 | r"""Compute the first order image derivative in both x and y using a Sobel operator. 16 | 17 | .. image:: _static/img/spatial_gradient.png 18 | 19 | Args: 20 | input: input image tensor with shape :math:`(B, C, H, W)`. 21 | mode: derivatives modality, can be: `sobel` or `diff`. 22 | order: the order of the derivatives. 23 | normalized: whether the output is normalized. 24 | 25 | Return: 26 | the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`. 27 | 28 | .. note:: 29 | See a working example `here `__. 
31 | 32 | Examples: 33 | >>> input = torch.rand(1, 3, 4, 4) 34 | >>> output = spatial_gradient(input) # 1x3x2x4x4 35 | >>> output.shape 36 | torch.Size([1, 3, 2, 4, 4]) 37 | """ 38 | if not isinstance(input, torch.Tensor): 39 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") 40 | 41 | if not len(input.shape) == 4: 42 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") 43 | # allocate kernel 44 | kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order) 45 | if normalized: 46 | kernel = normalize_kernel2d(kernel) 47 | 48 | # prepare kernel 49 | b, c, h, w = input.shape 50 | tmp_kernel: torch.Tensor = kernel.to(input).detach() 51 | tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1) 52 | 53 | # convolve input tensor with sobel kernel 54 | kernel_flip: torch.Tensor = tmp_kernel.flip(-3) 55 | 56 | # Pad with "replicate for spatial dims, but with zeros for channel 57 | spatial_pad = [ 58 | kernel.size(1) // 2, 59 | kernel.size(1) // 2, 60 | kernel.size(2) // 2, 61 | kernel.size(2) // 2, 62 | ] 63 | out_channels: int = 3 if order == 2 else 2 64 | padded_inp: torch.Tensor = F.pad( 65 | input.reshape(b * c, 1, h, w), spatial_pad, "replicate" 66 | )[:, :, None] 67 | 68 | return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w) 69 | 70 | 71 | def spatial_gradient3d( 72 | input: torch.Tensor, mode: str = "diff", order: int = 1 73 | ) -> torch.Tensor: 74 | r"""Compute the first and second order volume derivative in x, y and d using a diff operator. 75 | 76 | Args: 77 | input: input features tensor with shape :math:`(B, C, D, H, W)`. 78 | mode: derivatives modality, can be: `sobel` or `diff`. 79 | order: the order of the derivatives. 80 | 81 | Return: 82 | the spatial gradients of the input feature map with shape math:`(B, C, 3, D, H, W)` 83 | or :math:`(B, C, 6, D, H, W)`. 84 | 85 | Examples: 86 | >>> input = torch.rand(1, 4, 2, 4, 4) 87 | >>> output = spatial_gradient3d(input) 88 | >>> output.shape 89 | torch.Size([1, 4, 3, 2, 4, 4]) 90 | """ 91 | if not isinstance(input, torch.Tensor): 92 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") 93 | 94 | if not len(input.shape) == 5: 95 | raise ValueError( 96 | f"Invalid input shape, we expect BxCxDxHxW. 
Got: {input.shape}" 97 | ) 98 | b, c, d, h, w = input.shape 99 | dev = input.device 100 | dtype = input.dtype 101 | if (mode == "diff") and (order == 1): 102 | # we go for the special case implementation due to conv3d bad speed 103 | x: torch.Tensor = F.pad(input, 6 * [1], "replicate") 104 | center = slice(1, -1) 105 | left = slice(0, -2) 106 | right = slice(2, None) 107 | out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype) 108 | out[..., 0, :, :, :] = ( 109 | x[..., center, center, right] - x[..., center, center, left] 110 | ) 111 | out[..., 1, :, :, :] = ( 112 | x[..., center, right, center] - x[..., center, left, center] 113 | ) 114 | out[..., 2, :, :, :] = ( 115 | x[..., right, center, center] - x[..., left, center, center] 116 | ) 117 | out = 0.5 * out 118 | else: 119 | # prepare kernel 120 | # allocate kernel 121 | kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order) 122 | 123 | tmp_kernel: torch.Tensor = kernel.to(input).detach() 124 | tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1) 125 | 126 | # convolve input tensor with grad kernel 127 | kernel_flip: torch.Tensor = tmp_kernel.flip(-3) 128 | 129 | # Pad with "replicate for spatial dims, but with zeros for channel 130 | spatial_pad = [ 131 | kernel.size(2) // 2, 132 | kernel.size(2) // 2, 133 | kernel.size(3) // 2, 134 | kernel.size(3) // 2, 135 | kernel.size(4) // 2, 136 | kernel.size(4) // 2, 137 | ] 138 | out_ch: int = 6 if order == 2 else 3 139 | out = F.conv3d( 140 | F.pad(input, spatial_pad, "replicate"), kernel_flip, padding=0, groups=c 141 | ).view(b, c, out_ch, d, h, w) 142 | return out 143 | 144 | 145 | def sobel( 146 | input: torch.Tensor, normalized: bool = True, eps: float = 1e-6 147 | ) -> torch.Tensor: 148 | r"""Compute the Sobel operator and returns the magnitude per channel. 149 | 150 | .. image:: _static/img/sobel.png 151 | 152 | Args: 153 | input: the input image with shape :math:`(B,C,H,W)`. 154 | normalized: if True, L1 norm of the kernel is set to 1. 155 | eps: regularization number to avoid NaN during backprop. 156 | 157 | Return: 158 | the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`. 159 | 160 | .. note:: 161 | See a working example `here `__. 163 | 164 | Example: 165 | >>> input = torch.rand(1, 3, 4, 4) 166 | >>> output = sobel(input) # 1x3x4x4 167 | >>> output.shape 168 | torch.Size([1, 3, 4, 4]) 169 | """ 170 | if not isinstance(input, torch.Tensor): 171 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") 172 | 173 | if not len(input.shape) == 4: 174 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") 175 | 176 | # comput the x/y gradients 177 | edges: torch.Tensor = spatial_gradient(input, normalized=normalized) 178 | 179 | # unpack the edges 180 | gx: torch.Tensor = edges[:, :, 0] 181 | gy: torch.Tensor = edges[:, :, 1] 182 | 183 | # compute gradient maginitude 184 | magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) 185 | 186 | return magnitude 187 | 188 | 189 | class SpatialGradient(nn.Module): 190 | r"""Compute the first order image derivative in both x and y using a Sobel operator. 191 | 192 | Args: 193 | mode: derivatives modality, can be: `sobel` or `diff`. 194 | order: the order of the derivatives. 195 | normalized: whether the output is normalized. 196 | 197 | Return: 198 | the sobel edges of the input feature map. 
199 | 200 | Shape: 201 | - Input: :math:`(B, C, H, W)` 202 | - Output: :math:`(B, C, 2, H, W)` 203 | 204 | Examples: 205 | >>> input = torch.rand(1, 3, 4, 4) 206 | >>> output = SpatialGradient()(input) # 1x3x2x4x4 207 | """ 208 | 209 | def __init__( 210 | self, mode: str = "sobel", order: int = 1, normalized: bool = True 211 | ) -> None: 212 | super().__init__() 213 | self.normalized: bool = normalized 214 | self.order: int = order 215 | self.mode: str = mode 216 | 217 | def __repr__(self) -> str: 218 | return ( 219 | self.__class__.__name__ + "(" 220 | "order=" 221 | + str(self.order) 222 | + ", " 223 | + "normalized=" 224 | + str(self.normalized) 225 | + ", " 226 | + "mode=" 227 | + self.mode 228 | + ")" 229 | ) 230 | 231 | def forward(self, input: torch.Tensor) -> torch.Tensor: 232 | return spatial_gradient(input, self.mode, self.order, self.normalized) 233 | 234 | 235 | class SpatialGradient3d(nn.Module): 236 | r"""Compute the first and second order volume derivative in x, y and d using a diff operator. 237 | 238 | Args: 239 | mode: derivatives modality, can be: `sobel` or `diff`. 240 | order: the order of the derivatives. 241 | 242 | Return: 243 | the spatial gradients of the input feature map. 244 | 245 | Shape: 246 | - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them. 247 | - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)` 248 | 249 | Examples: 250 | >>> input = torch.rand(1, 4, 2, 4, 4) 251 | >>> output = SpatialGradient3d()(input) 252 | >>> output.shape 253 | torch.Size([1, 4, 3, 2, 4, 4]) 254 | """ 255 | 256 | def __init__(self, mode: str = "diff", order: int = 1) -> None: 257 | super().__init__() 258 | self.order: int = order 259 | self.mode: str = mode 260 | self.kernel = get_spatial_gradient_kernel3d(mode, order) 261 | 262 | def __repr__(self) -> str: 263 | return ( 264 | self.__class__.__name__ + "(" 265 | "order=" + str(self.order) + ", " + "mode=" + self.mode + ")" 266 | ) 267 | 268 | def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore 269 | return spatial_gradient3d(input, self.mode, self.order) 270 | 271 | 272 | class Sobel(nn.Module): 273 | r"""Compute the Sobel operator and returns the magnitude per channel. 274 | 275 | Args: 276 | normalized: if True, L1 norm of the kernel is set to 1. 277 | eps: regularization number to avoid NaN during backprop. 278 | 279 | Return: 280 | the sobel edge gradient magnitudes map. 
281 | 282 | Shape: 283 | - Input: :math:`(B, C, H, W)` 284 | - Output: :math:`(B, C, H, W)` 285 | 286 | Examples: 287 | >>> input = torch.rand(1, 3, 4, 4) 288 | >>> output = Sobel()(input) # 1x3x4x4 289 | """ 290 | 291 | def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None: 292 | super().__init__() 293 | self.normalized: bool = normalized 294 | self.eps: float = eps 295 | 296 | def __repr__(self) -> str: 297 | return self.__class__.__name__ + "(" "normalized=" + str(self.normalized) + ")" 298 | 299 | def forward(self, input: torch.Tensor) -> torch.Tensor: 300 | return sobel(input, self.normalized, self.eps) 301 | -------------------------------------------------------------------------------- /model/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import time 5 | import torch 6 | from torch import nn 7 | import logging 8 | import numpy as np 9 | from os import path as osp 10 | 11 | 12 | def constant_init(module, val, bias=0): 13 | if hasattr(module, "weight") and module.weight is not None: 14 | nn.init.constant_(module.weight, val) 15 | if hasattr(module, "bias") and module.bias is not None: 16 | nn.init.constant_(module.bias, bias) 17 | 18 | 19 | initialized_logger = {} 20 | 21 | 22 | def get_root_logger(logger_name="basicsr", log_level=logging.INFO, log_file=None): 23 | """Get the root logger. 24 | The logger will be initialized if it has not been initialized. By default a 25 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 26 | also be added. 27 | 28 | Args: 29 | logger_name (str): root logger name. Default: 'basicsr'. 30 | log_file (str | None): The log filename. If specified, a FileHandler 31 | will be added to the root logger. 32 | log_level (int): The root logger level. Note that only the process of 33 | rank 0 is affected, while other processes will set the level to 34 | "Error" and be silent most of the time. 35 | 36 | Returns: 37 | logging.Logger: The root logger. 
38 | """ 39 | logger = logging.getLogger(logger_name) 40 | # if the logger has been initialized, just return it 41 | if logger_name in initialized_logger: 42 | return logger 43 | 44 | format_str = "%(asctime)s %(levelname)s: %(message)s" 45 | stream_handler = logging.StreamHandler() 46 | stream_handler.setFormatter(logging.Formatter(format_str)) 47 | logger.addHandler(stream_handler) 48 | logger.propagate = False 49 | 50 | if log_file is not None: 51 | logger.setLevel(log_level) 52 | # add file handler 53 | # file_handler = logging.FileHandler(log_file, 'w') 54 | file_handler = logging.FileHandler( 55 | log_file, "a" 56 | ) # Shangchen: keep the previous log 57 | file_handler.setFormatter(logging.Formatter(format_str)) 58 | file_handler.setLevel(log_level) 59 | logger.addHandler(file_handler) 60 | initialized_logger[logger_name] = True 61 | return logger 62 | 63 | 64 | IS_HIGH_VERSION = [ 65 | int(m) 66 | for m in list( 67 | re.findall( 68 | r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$", 69 | torch.__version__, 70 | )[0][:3] 71 | ) 72 | ] >= [1, 12, 0] 73 | 74 | 75 | def gpu_is_available(): 76 | if IS_HIGH_VERSION: 77 | if torch.backends.mps.is_available(): 78 | return True 79 | return ( 80 | True 81 | if torch.cuda.is_available() and torch.backends.cudnn.is_available() 82 | else False 83 | ) 84 | 85 | 86 | def get_device(gpu_id=None): 87 | if gpu_id is None: 88 | gpu_str = "" 89 | elif isinstance(gpu_id, int): 90 | gpu_str = f":{gpu_id}" 91 | else: 92 | raise TypeError("Input should be int value.") 93 | 94 | if IS_HIGH_VERSION: 95 | if torch.backends.mps.is_available(): 96 | return torch.device("mps" + gpu_str) 97 | return torch.device( 98 | "cuda" + gpu_str 99 | if torch.cuda.is_available() and torch.backends.cudnn.is_available() 100 | else "cpu" 101 | ) 102 | 103 | 104 | def set_random_seed(seed): 105 | """Set random seeds.""" 106 | random.seed(seed) 107 | np.random.seed(seed) 108 | torch.manual_seed(seed) 109 | torch.cuda.manual_seed(seed) 110 | torch.cuda.manual_seed_all(seed) 111 | 112 | 113 | def get_time_str(): 114 | return time.strftime("%Y%m%d_%H%M%S", time.localtime()) 115 | 116 | 117 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 118 | """Scan a directory to find the interested files. 119 | 120 | Args: 121 | dir_path (str): Path of the directory. 122 | suffix (str | tuple(str), optional): File suffix that we are 123 | interested in. Default: None. 124 | recursive (bool, optional): If set to True, recursively scan the 125 | directory. Default: False. 126 | full_path (bool, optional): If set to True, include the dir_path. 127 | Default: False. 128 | 129 | Returns: 130 | A generator for all the interested files with relative pathes. 
131 | """ 132 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 133 | raise TypeError('"suffix" must be a string or tuple of strings') 134 | 135 | root = dir_path 136 | 137 | def _scandir(dir_path, suffix, recursive): 138 | for entry in os.scandir(dir_path): 139 | if not entry.name.startswith(".") and entry.is_file(): 140 | if full_path: 141 | return_path = entry.path 142 | else: 143 | return_path = osp.relpath(entry.path, root) 144 | 145 | if suffix is None or return_path.endswith(suffix): 146 | yield return_path 147 | elif recursive: 148 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 149 | else: 150 | continue 151 | 152 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 153 | -------------------------------------------------------------------------------- /model/modules/RAFT/__init__.py: -------------------------------------------------------------------------------- 1 | # from .demo import RAFT_infer 2 | from .raft import RAFT 3 | -------------------------------------------------------------------------------- /model/modules/RAFT/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils.utils import bilinear_sampler 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels - 1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2 * r + 1) 38 | dy = torch.linspace(-r, r, 2 * r + 1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht * wd) 56 | fmap2 = fmap2.view(batch, dim, ht * wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class CorrLayer(torch.autograd.Function): 64 | @staticmethod 65 | def forward(ctx, fmap1, fmap2, coords, r): 66 | fmap1 = fmap1.contiguous() 67 | fmap2 = fmap2.contiguous() 68 | coords = coords.contiguous() 69 | ctx.save_for_backward(fmap1, fmap2, coords) 70 | ctx.r = r 71 | (corr,) = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) 72 | return corr 73 | 74 | @staticmethod 75 | def backward(ctx, grad_corr): 76 | fmap1, fmap2, coords = ctx.saved_tensors 77 | grad_corr = 
grad_corr.contiguous() 78 | fmap1_grad, fmap2_grad, coords_grad = correlation_cudaz.backward( 79 | fmap1, fmap2, coords, grad_corr, ctx.r 80 | ) 81 | return fmap1_grad, fmap2_grad, coords_grad, None 82 | 83 | 84 | class AlternateCorrBlock: 85 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 86 | self.num_levels = num_levels 87 | self.radius = radius 88 | 89 | self.pyramid = [(fmap1, fmap2)] 90 | for i in range(self.num_levels): 91 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 92 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 93 | self.pyramid.append((fmap1, fmap2)) 94 | 95 | def __call__(self, coords): 96 | coords = coords.permute(0, 2, 3, 1) 97 | B, H, W, _ = coords.shape 98 | 99 | corr_list = [] 100 | for i in range(self.num_levels): 101 | r = self.radius 102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) 103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) 104 | 105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) 107 | corr_list.append(corr.squeeze(1)) 108 | 109 | corr = torch.stack(corr_list, dim=1) 110 | corr = corr.reshape(B, -1, H, W) 111 | return corr / 16.0 112 | -------------------------------------------------------------------------------- /model/modules/RAFT/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils import data 6 | 7 | import os 8 | import random 9 | from glob import glob 10 | import os.path as osp 11 | 12 | from utils import frame_utils 13 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 14 | 15 | 16 | class FlowDataset(data.Dataset): 17 | def __init__(self, aug_params=None, sparse=False): 18 | self.augmentor = None 19 | self.sparse = sparse 20 | if aug_params is not None: 21 | if sparse: 22 | self.augmentor = SparseFlowAugmentor(**aug_params) 23 | else: 24 | self.augmentor = FlowAugmentor(**aug_params) 25 | 26 | self.is_test = False 27 | self.init_seed = False 28 | self.flow_list = [] 29 | self.image_list = [] 30 | self.extra_info = [] 31 | 32 | def __getitem__(self, index): 33 | if self.is_test: 34 | img1 = frame_utils.read_gen(self.image_list[index][0]) 35 | img2 = frame_utils.read_gen(self.image_list[index][1]) 36 | img1 = np.array(img1).astype(np.uint8)[..., :3] 37 | img2 = np.array(img2).astype(np.uint8)[..., :3] 38 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 39 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 40 | return img1, img2, self.extra_info[index] 41 | 42 | if not self.init_seed: 43 | worker_info = torch.utils.data.get_worker_info() 44 | if worker_info is not None: 45 | torch.manual_seed(worker_info.id) 46 | np.random.seed(worker_info.id) 47 | random.seed(worker_info.id) 48 | self.init_seed = True 49 | 50 | index = index % len(self.image_list) 51 | valid = None 52 | if self.sparse: 53 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 54 | else: 55 | flow = frame_utils.read_gen(self.flow_list[index]) 56 | 57 | img1 = frame_utils.read_gen(self.image_list[index][0]) 58 | img2 = frame_utils.read_gen(self.image_list[index][1]) 59 | 60 | flow = np.array(flow).astype(np.float32) 61 | img1 = np.array(img1).astype(np.uint8) 62 | img2 = np.array(img2).astype(np.uint8) 63 | 64 | # grayscale images 65 | if len(img1.shape) == 2: 66 | img1 = np.tile(img1[..., None], (1, 1, 3)) 67 | img2 = np.tile(img2[..., None], (1, 1, 3)) 68 | else: 69 | img1 = img1[..., 
:3] 70 | img2 = img2[..., :3] 71 | 72 | if self.augmentor is not None: 73 | if self.sparse: 74 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 75 | else: 76 | img1, img2, flow = self.augmentor(img1, img2, flow) 77 | 78 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 79 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 80 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 81 | 82 | if valid is not None: 83 | valid = torch.from_numpy(valid) 84 | else: 85 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 86 | 87 | return img1, img2, flow, valid.float() 88 | 89 | def __rmul__(self, v): 90 | self.flow_list = v * self.flow_list 91 | self.image_list = v * self.image_list 92 | return self 93 | 94 | def __len__(self): 95 | return len(self.image_list) 96 | 97 | 98 | class MpiSintel(FlowDataset): 99 | def __init__( 100 | self, aug_params=None, split="training", root="datasets/Sintel", dstype="clean" 101 | ): 102 | super(MpiSintel, self).__init__(aug_params) 103 | flow_root = osp.join(root, split, "flow") 104 | image_root = osp.join(root, split, dstype) 105 | 106 | if split == "test": 107 | self.is_test = True 108 | 109 | for scene in os.listdir(image_root): 110 | image_list = sorted(glob(osp.join(image_root, scene, "*.png"))) 111 | for i in range(len(image_list) - 1): 112 | self.image_list += [[image_list[i], image_list[i + 1]]] 113 | self.extra_info += [(scene, i)] # scene and frame_id 114 | 115 | if split != "test": 116 | self.flow_list += sorted(glob(osp.join(flow_root, scene, "*.flo"))) 117 | 118 | 119 | class FlyingChairs(FlowDataset): 120 | def __init__( 121 | self, aug_params=None, split="train", root="datasets/FlyingChairs_release/data" 122 | ): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, "*.ppm"))) 126 | flows = sorted(glob(osp.join(root, "*.flo"))) 127 | assert len(images) // 2 == len(flows) 128 | 129 | split_list = np.loadtxt("chairs_split.txt", dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split == "training" and xid == 1) or ( 133 | split == "validation" and xid == 2 134 | ): 135 | self.flow_list += [flows[i]] 136 | self.image_list += [[images[2 * i], images[2 * i + 1]]] 137 | 138 | 139 | class FlyingThings3D(FlowDataset): 140 | def __init__( 141 | self, aug_params=None, root="datasets/FlyingThings3D", dstype="frames_cleanpass" 142 | ): 143 | super(FlyingThings3D, self).__init__(aug_params) 144 | 145 | for cam in ["left"]: 146 | for direction in ["into_future", "into_past"]: 147 | image_dirs = sorted(glob(osp.join(root, dstype, "TRAIN/*/*"))) 148 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 149 | 150 | flow_dirs = sorted(glob(osp.join(root, "optical_flow/TRAIN/*/*"))) 151 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 152 | 153 | for idir, fdir in zip(image_dirs, flow_dirs): 154 | images = sorted(glob(osp.join(idir, "*.png"))) 155 | flows = sorted(glob(osp.join(fdir, "*.pfm"))) 156 | for i in range(len(flows) - 1): 157 | if direction == "into_future": 158 | self.image_list += [[images[i], images[i + 1]]] 159 | self.flow_list += [flows[i]] 160 | elif direction == "into_past": 161 | self.image_list += [[images[i + 1], images[i]]] 162 | self.flow_list += [flows[i + 1]] 163 | 164 | 165 | class KITTI(FlowDataset): 166 | def __init__(self, aug_params=None, split="training", root="datasets/KITTI"): 167 | super(KITTI, self).__init__(aug_params, sparse=True) 168 | if split == "testing": 169 | self.is_test = True 
170 | 171 | root = osp.join(root, split) 172 | images1 = sorted(glob(osp.join(root, "image_2/*_10.png"))) 173 | images2 = sorted(glob(osp.join(root, "image_2/*_11.png"))) 174 | 175 | for img1, img2 in zip(images1, images2): 176 | frame_id = img1.split("/")[-1] 177 | self.extra_info += [[frame_id]] 178 | self.image_list += [[img1, img2]] 179 | 180 | if split == "training": 181 | self.flow_list = sorted(glob(osp.join(root, "flow_occ/*_10.png"))) 182 | 183 | 184 | class HD1K(FlowDataset): 185 | def __init__(self, aug_params=None, root="datasets/HD1k"): 186 | super(HD1K, self).__init__(aug_params, sparse=True) 187 | 188 | seq_ix = 0 189 | while 1: 190 | flows = sorted( 191 | glob(os.path.join(root, "hd1k_flow_gt", "flow_occ/%06d_*.png" % seq_ix)) 192 | ) 193 | images = sorted( 194 | glob(os.path.join(root, "hd1k_input", "image_2/%06d_*.png" % seq_ix)) 195 | ) 196 | 197 | if len(flows) == 0: 198 | break 199 | 200 | for i in range(len(flows) - 1): 201 | self.flow_list += [flows[i]] 202 | self.image_list += [[images[i], images[i + 1]]] 203 | 204 | seq_ix += 1 205 | 206 | 207 | def fetch_dataloader(args, TRAIN_DS="C+T+K+S+H"): 208 | """Create the data loader for the corresponding trainign set""" 209 | if args.stage == "chairs": 210 | aug_params = { 211 | "crop_size": args.image_size, 212 | "min_scale": -0.1, 213 | "max_scale": 1.0, 214 | "do_flip": True, 215 | } 216 | train_dataset = FlyingChairs(aug_params, split="training") 217 | 218 | elif args.stage == "things": 219 | aug_params = { 220 | "crop_size": args.image_size, 221 | "min_scale": -0.4, 222 | "max_scale": 0.8, 223 | "do_flip": True, 224 | } 225 | clean_dataset = FlyingThings3D(aug_params, dstype="frames_cleanpass") 226 | final_dataset = FlyingThings3D(aug_params, dstype="frames_finalpass") 227 | train_dataset = clean_dataset + final_dataset 228 | 229 | elif args.stage == "sintel": 230 | aug_params = { 231 | "crop_size": args.image_size, 232 | "min_scale": -0.2, 233 | "max_scale": 0.6, 234 | "do_flip": True, 235 | } 236 | things = FlyingThings3D(aug_params, dstype="frames_cleanpass") 237 | sintel_clean = MpiSintel(aug_params, split="training", dstype="clean") 238 | sintel_final = MpiSintel(aug_params, split="training", dstype="final") 239 | 240 | if TRAIN_DS == "C+T+K+S+H": 241 | kitti = KITTI( 242 | { 243 | "crop_size": args.image_size, 244 | "min_scale": -0.3, 245 | "max_scale": 0.5, 246 | "do_flip": True, 247 | } 248 | ) 249 | hd1k = HD1K( 250 | { 251 | "crop_size": args.image_size, 252 | "min_scale": -0.5, 253 | "max_scale": 0.2, 254 | "do_flip": True, 255 | } 256 | ) 257 | train_dataset = ( 258 | 100 * sintel_clean 259 | + 100 * sintel_final 260 | + 200 * kitti 261 | + 5 * hd1k 262 | + things 263 | ) 264 | 265 | elif TRAIN_DS == "C+T+K/S": 266 | train_dataset = 100 * sintel_clean + 100 * sintel_final + things 267 | 268 | elif args.stage == "kitti": 269 | aug_params = { 270 | "crop_size": args.image_size, 271 | "min_scale": -0.2, 272 | "max_scale": 0.4, 273 | "do_flip": False, 274 | } 275 | train_dataset = KITTI(aug_params, split="training") 276 | 277 | train_loader = data.DataLoader( 278 | train_dataset, 279 | batch_size=args.batch_size, 280 | pin_memory=False, 281 | shuffle=True, 282 | num_workers=4, 283 | drop_last=True, 284 | ) 285 | 286 | print("Training with %d image pairs" % len(train_dataset)) 287 | return train_loader 288 | -------------------------------------------------------------------------------- /model/modules/RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import os 
2 | import cv2 3 | import glob 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | from .raft import RAFT 9 | from .utils import flow_viz 10 | from .utils.utils import InputPadder 11 | 12 | 13 | DEVICE = "cuda" 14 | 15 | 16 | def load_image(imfile): 17 | img = np.array(Image.open(imfile)).astype(np.uint8) 18 | img = torch.from_numpy(img).permute(2, 0, 1).float() 19 | return img 20 | 21 | 22 | def load_image_list(image_files): 23 | images = [] 24 | for imfile in sorted(image_files): 25 | images.append(load_image(imfile)) 26 | 27 | images = torch.stack(images, dim=0) 28 | images = images.to(DEVICE) 29 | 30 | padder = InputPadder(images.shape) 31 | return padder.pad(images)[0] 32 | 33 | 34 | def viz(img, flo): 35 | img = img[0].permute(1, 2, 0).cpu().numpy() 36 | flo = flo[0].permute(1, 2, 0).cpu().numpy() 37 | 38 | # map flow to rgb image 39 | flo = flow_viz.flow_to_image(flo) 40 | # img_flo = np.concatenate([img, flo], axis=0) 41 | img_flo = flo 42 | 43 | cv2.imwrite("/home/chengao/test/flow.png", img_flo[:, :, [2, 1, 0]]) 44 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 45 | # cv2.waitKey() 46 | 47 | 48 | def demo(args): 49 | model = torch.nn.DataParallel(RAFT(args)) 50 | model.load_state_dict(torch.load(args.model)) 51 | 52 | model = model.module 53 | model.to(DEVICE) 54 | model.eval() 55 | 56 | with torch.no_grad(): 57 | images = glob.glob(os.path.join(args.path, "*.png")) + glob.glob( 58 | os.path.join(args.path, "*.jpg") 59 | ) 60 | 61 | images = load_image_list(images) 62 | for i in range(images.shape[0] - 1): 63 | image1 = images[i, None] 64 | image2 = images[i + 1, None] 65 | 66 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 67 | viz(image1, flow_up) 68 | 69 | 70 | def RAFT_infer(args): 71 | model = torch.nn.DataParallel(RAFT(args)) 72 | model.load_state_dict(torch.load(args.model)) 73 | 74 | model = model.module 75 | model.to(DEVICE) 76 | model.eval() 77 | 78 | return model 79 | -------------------------------------------------------------------------------- /model/modules/RAFT/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ResidualBlock(nn.Module): 6 | def __init__(self, in_planes, planes, norm_fn="group", stride=1): 7 | super(ResidualBlock, self).__init__() 8 | 9 | self.conv1 = nn.Conv2d( 10 | in_planes, planes, kernel_size=3, padding=1, stride=stride 11 | ) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | num_groups = planes // 8 16 | 17 | if norm_fn == "group": 18 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 20 | if stride != 1: 21 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 22 | 23 | elif norm_fn == "batch": 24 | self.norm1 = nn.BatchNorm2d(planes) 25 | self.norm2 = nn.BatchNorm2d(planes) 26 | if stride != 1: 27 | self.norm3 = nn.BatchNorm2d(planes) 28 | 29 | elif norm_fn == "instance": 30 | self.norm1 = nn.InstanceNorm2d(planes) 31 | self.norm2 = nn.InstanceNorm2d(planes) 32 | if stride != 1: 33 | self.norm3 = nn.InstanceNorm2d(planes) 34 | 35 | elif norm_fn == "none": 36 | self.norm1 = nn.Sequential() 37 | self.norm2 = nn.Sequential() 38 | if stride != 1: 39 | self.norm3 = nn.Sequential() 40 | 41 | if stride == 1: 42 | self.downsample = None 43 | 44 | else: 45 | self.downsample = nn.Sequential( 46 | nn.Conv2d(in_planes, 
planes, kernel_size=1, stride=stride), self.norm3 47 | ) 48 | 49 | def forward(self, x): 50 | y = x 51 | y = self.relu(self.norm1(self.conv1(y))) 52 | y = self.relu(self.norm2(self.conv2(y))) 53 | 54 | if self.downsample is not None: 55 | x = self.downsample(x) 56 | 57 | return self.relu(x + y) 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn="group", stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d( 66 | planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride 67 | ) 68 | self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) 69 | self.relu = nn.ReLU(inplace=True) 70 | 71 | num_groups = planes // 8 72 | 73 | if norm_fn == "group": 74 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) 75 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) 76 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | if stride != 1: 78 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 79 | 80 | elif norm_fn == "batch": 81 | self.norm1 = nn.BatchNorm2d(planes // 4) 82 | self.norm2 = nn.BatchNorm2d(planes // 4) 83 | self.norm3 = nn.BatchNorm2d(planes) 84 | if stride != 1: 85 | self.norm4 = nn.BatchNorm2d(planes) 86 | 87 | elif norm_fn == "instance": 88 | self.norm1 = nn.InstanceNorm2d(planes // 4) 89 | self.norm2 = nn.InstanceNorm2d(planes // 4) 90 | self.norm3 = nn.InstanceNorm2d(planes) 91 | if stride != 1: 92 | self.norm4 = nn.InstanceNorm2d(planes) 93 | 94 | elif norm_fn == "none": 95 | self.norm1 = nn.Sequential() 96 | self.norm2 = nn.Sequential() 97 | self.norm3 = nn.Sequential() 98 | if stride != 1: 99 | self.norm4 = nn.Sequential() 100 | 101 | if stride == 1: 102 | self.downsample = None 103 | 104 | else: 105 | self.downsample = nn.Sequential( 106 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 107 | ) 108 | 109 | def forward(self, x): 110 | y = x 111 | y = self.relu(self.norm1(self.conv1(y))) 112 | y = self.relu(self.norm2(self.conv2(y))) 113 | y = self.relu(self.norm3(self.conv3(y))) 114 | 115 | if self.downsample is not None: 116 | x = self.downsample(x) 117 | 118 | return self.relu(x + y) 119 | 120 | 121 | class BasicEncoder(nn.Module): 122 | def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): 123 | super(BasicEncoder, self).__init__() 124 | self.norm_fn = norm_fn 125 | 126 | if self.norm_fn == "group": 127 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 128 | 129 | elif self.norm_fn == "batch": 130 | self.norm1 = nn.BatchNorm2d(64) 131 | 132 | elif self.norm_fn == "instance": 133 | self.norm1 = nn.InstanceNorm2d(64) 134 | 135 | elif self.norm_fn == "none": 136 | self.norm1 = nn.Sequential() 137 | 138 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 139 | self.relu1 = nn.ReLU(inplace=True) 140 | 141 | self.in_planes = 64 142 | self.layer1 = self._make_layer(64, stride=1) 143 | self.layer2 = self._make_layer(96, stride=2) 144 | self.layer3 = self._make_layer(128, stride=2) 145 | 146 | # output convolution 147 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 148 | 149 | self.dropout = None 150 | if dropout > 0: 151 | self.dropout = nn.Dropout2d(p=dropout) 152 | 153 | for m in self.modules(): 154 | if isinstance(m, nn.Conv2d): 155 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 156 | elif isinstance(m, (nn.BatchNorm2d, 
nn.InstanceNorm2d, nn.GroupNorm)): 157 | if m.weight is not None: 158 | nn.init.constant_(m.weight, 1) 159 | if m.bias is not None: 160 | nn.init.constant_(m.bias, 0) 161 | 162 | def _make_layer(self, dim, stride=1): 163 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 164 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 165 | layers = (layer1, layer2) 166 | 167 | self.in_planes = dim 168 | return nn.Sequential(*layers) 169 | 170 | def forward(self, x): 171 | # if input is list, combine batch dimension 172 | is_list = isinstance(x, (list, tuple)) 173 | if is_list: 174 | batch_dim = x[0].shape[0] 175 | x = torch.cat(x, dim=0) 176 | 177 | x = self.conv1(x) 178 | x = self.norm1(x) 179 | x = self.relu1(x) 180 | 181 | x = self.layer1(x) 182 | x = self.layer2(x) 183 | x = self.layer3(x) 184 | 185 | x = self.conv2(x) 186 | 187 | if self.training and self.dropout is not None: 188 | x = self.dropout(x) 189 | 190 | if is_list: 191 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 192 | 193 | return x 194 | 195 | 196 | class SmallEncoder(nn.Module): 197 | def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): 198 | super(SmallEncoder, self).__init__() 199 | self.norm_fn = norm_fn 200 | 201 | if self.norm_fn == "group": 202 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 203 | 204 | elif self.norm_fn == "batch": 205 | self.norm1 = nn.BatchNorm2d(32) 206 | 207 | elif self.norm_fn == "instance": 208 | self.norm1 = nn.InstanceNorm2d(32) 209 | 210 | elif self.norm_fn == "none": 211 | self.norm1 = nn.Sequential() 212 | 213 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 214 | self.relu1 = nn.ReLU(inplace=True) 215 | 216 | self.in_planes = 32 217 | self.layer1 = self._make_layer(32, stride=1) 218 | self.layer2 = self._make_layer(64, stride=2) 219 | self.layer3 = self._make_layer(96, stride=2) 220 | 221 | self.dropout = None 222 | if dropout > 0: 223 | self.dropout = nn.Dropout2d(p=dropout) 224 | 225 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 226 | 227 | for m in self.modules(): 228 | if isinstance(m, nn.Conv2d): 229 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 230 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 231 | if m.weight is not None: 232 | nn.init.constant_(m.weight, 1) 233 | if m.bias is not None: 234 | nn.init.constant_(m.bias, 0) 235 | 236 | def _make_layer(self, dim, stride=1): 237 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 238 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 239 | layers = (layer1, layer2) 240 | 241 | self.in_planes = dim 242 | return nn.Sequential(*layers) 243 | 244 | def forward(self, x): 245 | # if input is list, combine batch dimension 246 | is_list = isinstance(x, (list, tuple)) 247 | if is_list: 248 | batch_dim = x[0].shape[0] 249 | x = torch.cat(x, dim=0) 250 | 251 | x = self.conv1(x) 252 | x = self.norm1(x) 253 | x = self.relu1(x) 254 | 255 | x = self.layer1(x) 256 | x = self.layer2(x) 257 | x = self.layer3(x) 258 | x = self.conv2(x) 259 | 260 | if self.training and self.dropout is not None: 261 | x = self.dropout(x) 262 | 263 | if is_list: 264 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 265 | 266 | return x 267 | -------------------------------------------------------------------------------- /model/modules/RAFT/raft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional 
as F 4 | 5 | from .update import BasicUpdateBlock, SmallUpdateBlock 6 | from .extractor import BasicEncoder, SmallEncoder 7 | from .corr import CorrBlock, AlternateCorrBlock 8 | from .utils.utils import coords_grid, upflow8 9 | 10 | try: 11 | autocast = torch.cuda.amp.autocast 12 | except: 13 | # dummy autocast for PyTorch < 1.6 14 | class autocast: 15 | def __init__(self, enabled): 16 | pass 17 | 18 | def __enter__(self): 19 | pass 20 | 21 | def __exit__(self, *args): 22 | pass 23 | 24 | 25 | class RAFT(nn.Module): 26 | def __init__(self, args): 27 | super(RAFT, self).__init__() 28 | self.args = args 29 | 30 | if args.small: 31 | self.hidden_dim = hdim = 96 32 | self.context_dim = cdim = 64 33 | args.corr_levels = 4 34 | args.corr_radius = 3 35 | 36 | else: 37 | self.hidden_dim = hdim = 128 38 | self.context_dim = cdim = 128 39 | args.corr_levels = 4 40 | args.corr_radius = 4 41 | 42 | if "dropout" not in args._get_kwargs(): 43 | args.dropout = 0 44 | 45 | if "alternate_corr" not in args._get_kwargs(): 46 | args.alternate_corr = False 47 | 48 | # feature network, context network, and update block 49 | if args.small: 50 | self.fnet = SmallEncoder( 51 | output_dim=128, norm_fn="instance", dropout=args.dropout 52 | ) 53 | self.cnet = SmallEncoder( 54 | output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout 55 | ) 56 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | else: 59 | self.fnet = BasicEncoder( 60 | output_dim=256, norm_fn="instance", dropout=args.dropout 61 | ) 62 | self.cnet = BasicEncoder( 63 | output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout 64 | ) 65 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 66 | 67 | def freeze_bn(self): 68 | for m in self.modules(): 69 | if isinstance(m, nn.BatchNorm2d): 70 | m.eval() 71 | 72 | def initialize_flow(self, img): 73 | """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 74 | N, C, H, W = img.shape 75 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device) 76 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 77 | 78 | # optical flow computed as difference: flow = coords1 - coords0 79 | return coords0, coords1 80 | 81 | def upsample_flow(self, flow, mask): 82 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" 83 | N, _, H, W = flow.shape 84 | mask = mask.view(N, 1, 9, 8, 8, H, W) 85 | mask = torch.softmax(mask, dim=2) 86 | 87 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 88 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 89 | 90 | up_flow = torch.sum(mask * up_flow, dim=2) 91 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 92 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 93 | 94 | def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True): 95 | """Estimate optical flow between pair of frames""" 96 | # image1 = 2 * (image1 / 255.0) - 1.0 97 | # image2 = 2 * (image2 / 255.0) - 1.0 98 | 99 | image1 = image1.contiguous() 100 | image2 = image2.contiguous() 101 | 102 | hdim = self.hidden_dim 103 | cdim = self.context_dim 104 | 105 | # run the feature network 106 | with autocast(enabled=self.args.mixed_precision): 107 | fmap1, fmap2 = self.fnet([image1, image2]) 108 | 109 | fmap1 = fmap1.float() 110 | fmap2 = fmap2.float() 111 | 112 | if self.args.alternate_corr: 113 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 114 | else: 115 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 116 | 117 | # run the context network 118 | with 
autocast(enabled=self.args.mixed_precision): 119 | cnet = self.cnet(image1) 120 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 121 | net = torch.tanh(net) 122 | inp = torch.relu(inp) 123 | 124 | coords0, coords1 = self.initialize_flow(image1) 125 | 126 | if flow_init is not None: 127 | coords1 = coords1 + flow_init 128 | 129 | flow_predictions = [] 130 | for itr in range(iters): 131 | coords1 = coords1.detach() 132 | corr = corr_fn(coords1) # index correlation volume 133 | 134 | flow = coords1 - coords0 135 | with autocast(enabled=self.args.mixed_precision): 136 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 137 | 138 | # F(t+1) = F(t) + \Delta(t) 139 | coords1 = coords1 + delta_flow 140 | 141 | # upsample predictions 142 | if up_mask is None: 143 | flow_up = upflow8(coords1 - coords0) 144 | else: 145 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 146 | 147 | flow_predictions.append(flow_up) 148 | 149 | if test_mode: 150 | return coords1 - coords0, flow_up 151 | 152 | return flow_predictions 153 | -------------------------------------------------------------------------------- /model/modules/RAFT/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | 17 | class ConvGRU(nn.Module): 18 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 19 | super(ConvGRU, self).__init__() 20 | self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 21 | self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 22 | self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 23 | 24 | def forward(self, h, x): 25 | hx = torch.cat([h, x], dim=1) 26 | 27 | z = torch.sigmoid(self.convz(hx)) 28 | r = torch.sigmoid(self.convr(hx)) 29 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 30 | 31 | h = (1 - z) * h + z * q 32 | return h 33 | 34 | 35 | class SepConvGRU(nn.Module): 36 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 37 | super(SepConvGRU, self).__init__() 38 | self.convz1 = nn.Conv2d( 39 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) 40 | ) 41 | self.convr1 = nn.Conv2d( 42 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) 43 | ) 44 | self.convq1 = nn.Conv2d( 45 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) 46 | ) 47 | 48 | self.convz2 = nn.Conv2d( 49 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) 50 | ) 51 | self.convr2 = nn.Conv2d( 52 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) 53 | ) 54 | self.convq2 = nn.Conv2d( 55 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) 56 | ) 57 | 58 | def forward(self, h, x): 59 | # horizontal 60 | hx = torch.cat([h, x], dim=1) 61 | z = torch.sigmoid(self.convz1(hx)) 62 | r = torch.sigmoid(self.convr1(hx)) 63 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 64 | h = (1 - z) * h + z * q 65 | 66 | # vertical 67 | hx = torch.cat([h, x], dim=1) 68 | z = torch.sigmoid(self.convz2(hx)) 69 | r = torch.sigmoid(self.convr2(hx)) 70 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 71 | h 
= (1 - z) * h + z * q 72 | 73 | return h 74 | 75 | 76 | class SmallMotionEncoder(nn.Module): 77 | def __init__(self, args): 78 | super(SmallMotionEncoder, self).__init__() 79 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 80 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 81 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 82 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 83 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 84 | 85 | def forward(self, flow, corr): 86 | cor = F.relu(self.convc1(corr)) 87 | flo = F.relu(self.convf1(flow)) 88 | flo = F.relu(self.convf2(flo)) 89 | cor_flo = torch.cat([cor, flo], dim=1) 90 | out = F.relu(self.conv(cor_flo)) 91 | return torch.cat([out, flow], dim=1) 92 | 93 | 94 | class BasicMotionEncoder(nn.Module): 95 | def __init__(self, args): 96 | super(BasicMotionEncoder, self).__init__() 97 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 98 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 99 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 100 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 101 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 102 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) 103 | 104 | def forward(self, flow, corr): 105 | cor = F.relu(self.convc1(corr)) 106 | cor = F.relu(self.convc2(cor)) 107 | flo = F.relu(self.convf1(flow)) 108 | flo = F.relu(self.convf2(flo)) 109 | 110 | cor_flo = torch.cat([cor, flo], dim=1) 111 | out = F.relu(self.conv(cor_flo)) 112 | return torch.cat([out, flow], dim=1) 113 | 114 | 115 | class SmallUpdateBlock(nn.Module): 116 | def __init__(self, args, hidden_dim=96): 117 | super(SmallUpdateBlock, self).__init__() 118 | self.encoder = SmallMotionEncoder(args) 119 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 121 | 122 | def forward(self, net, inp, corr, flow): 123 | motion_features = self.encoder(flow, corr) 124 | inp = torch.cat([inp, motion_features], dim=1) 125 | net = self.gru(net, inp) 126 | delta_flow = self.flow_head(net) 127 | 128 | return net, None, delta_flow 129 | 130 | 131 | class BasicUpdateBlock(nn.Module): 132 | def __init__(self, args, hidden_dim=128, input_dim=128): 133 | super(BasicUpdateBlock, self).__init__() 134 | self.args = args 135 | self.encoder = BasicMotionEncoder(args) 136 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) 137 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 138 | 139 | self.mask = nn.Sequential( 140 | nn.Conv2d(128, 256, 3, padding=1), 141 | nn.ReLU(inplace=True), 142 | nn.Conv2d(256, 64 * 9, 1, padding=0), 143 | ) 144 | 145 | def forward(self, net, inp, corr, flow, upsample=True): 146 | motion_features = self.encoder(flow, corr) 147 | inp = torch.cat([inp, motion_features], dim=1) 148 | 149 | net = self.gru(net, inp) 150 | delta_flow = self.flow_head(net) 151 | 152 | # scale mask to balence gradients 153 | mask = 0.25 * self.mask(net) 154 | return net, mask, delta_flow 155 | -------------------------------------------------------------------------------- /model/modules/RAFT/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow_viz import flow_to_image 2 | from .frame_utils import writeFlow 3 | -------------------------------------------------------------------------------- /model/modules/RAFT/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 
| 4 | import cv2 5 | 6 | cv2.setNumThreads(0) 7 | cv2.ocl.setUseOpenCL(False) 8 | 9 | from torchvision.transforms import ColorJitter 10 | 11 | 12 | class FlowAugmentor: 13 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 14 | # spatial augmentation params 15 | self.crop_size = crop_size 16 | self.min_scale = min_scale 17 | self.max_scale = max_scale 18 | self.spatial_aug_prob = 0.8 19 | self.stretch_prob = 0.8 20 | self.max_stretch = 0.2 21 | 22 | # flip augmentation params 23 | self.do_flip = do_flip 24 | self.h_flip_prob = 0.5 25 | self.v_flip_prob = 0.1 26 | 27 | # photometric augmentation params 28 | self.photo_aug = ColorJitter( 29 | brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14 30 | ) 31 | self.asymmetric_color_aug_prob = 0.2 32 | self.eraser_aug_prob = 0.5 33 | 34 | def color_transform(self, img1, img2): 35 | """Photometric augmentation""" 36 | # asymmetric 37 | if np.random.rand() < self.asymmetric_color_aug_prob: 38 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 39 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 40 | 41 | # symmetric 42 | else: 43 | image_stack = np.concatenate([img1, img2], axis=0) 44 | image_stack = np.array( 45 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 46 | ) 47 | img1, img2 = np.split(image_stack, 2, axis=0) 48 | 49 | return img1, img2 50 | 51 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 52 | """Occlusion augmentation""" 53 | ht, wd = img1.shape[:2] 54 | if np.random.rand() < self.eraser_aug_prob: 55 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 56 | for _ in range(np.random.randint(1, 3)): 57 | x0 = np.random.randint(0, wd) 58 | y0 = np.random.randint(0, ht) 59 | dx = np.random.randint(bounds[0], bounds[1]) 60 | dy = np.random.randint(bounds[0], bounds[1]) 61 | img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color 62 | 63 | return img1, img2 64 | 65 | def spatial_transform(self, img1, img2, flow): 66 | # randomly sample scale 67 | ht, wd = img1.shape[:2] 68 | min_scale = np.maximum( 69 | (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd) 70 | ) 71 | 72 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 73 | scale_x = scale 74 | scale_y = scale 75 | if np.random.rand() < self.stretch_prob: 76 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 77 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 78 | 79 | scale_x = np.clip(scale_x, min_scale, None) 80 | scale_y = np.clip(scale_y, min_scale, None) 81 | 82 | if np.random.rand() < self.spatial_aug_prob: 83 | # rescale the images 84 | img1 = cv2.resize( 85 | img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 86 | ) 87 | img2 = cv2.resize( 88 | img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 89 | ) 90 | flow = cv2.resize( 91 | flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 92 | ) 93 | flow = flow * [scale_x, scale_y] 94 | 95 | if self.do_flip: 96 | if np.random.rand() < self.h_flip_prob: # h-flip 97 | img1 = img1[:, ::-1] 98 | img2 = img2[:, ::-1] 99 | flow = flow[:, ::-1] * [-1.0, 1.0] 100 | 101 | if np.random.rand() < self.v_flip_prob: # v-flip 102 | img1 = img1[::-1, :] 103 | img2 = img2[::-1, :] 104 | flow = flow[::-1, :] * [1.0, -1.0] 105 | 106 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 107 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 108 | 109 | img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + 
self.crop_size[1]] 110 | img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 111 | flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 112 | 113 | return img1, img2, flow 114 | 115 | def __call__(self, img1, img2, flow): 116 | img1, img2 = self.color_transform(img1, img2) 117 | img1, img2 = self.eraser_transform(img1, img2) 118 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 119 | 120 | img1 = np.ascontiguousarray(img1) 121 | img2 = np.ascontiguousarray(img2) 122 | flow = np.ascontiguousarray(flow) 123 | 124 | return img1, img2, flow 125 | 126 | 127 | class SparseFlowAugmentor: 128 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 129 | # spatial augmentation params 130 | self.crop_size = crop_size 131 | self.min_scale = min_scale 132 | self.max_scale = max_scale 133 | self.spatial_aug_prob = 0.8 134 | self.stretch_prob = 0.8 135 | self.max_stretch = 0.2 136 | 137 | # flip augmentation params 138 | self.do_flip = do_flip 139 | self.h_flip_prob = 0.5 140 | self.v_flip_prob = 0.1 141 | 142 | # photometric augmentation params 143 | self.photo_aug = ColorJitter( 144 | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14 145 | ) 146 | self.asymmetric_color_aug_prob = 0.2 147 | self.eraser_aug_prob = 0.5 148 | 149 | def color_transform(self, img1, img2): 150 | image_stack = np.concatenate([img1, img2], axis=0) 151 | image_stack = np.array( 152 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 153 | ) 154 | img1, img2 = np.split(image_stack, 2, axis=0) 155 | return img1, img2 156 | 157 | def eraser_transform(self, img1, img2): 158 | ht, wd = img1.shape[:2] 159 | if np.random.rand() < self.eraser_aug_prob: 160 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 161 | for _ in range(np.random.randint(1, 3)): 162 | x0 = np.random.randint(0, wd) 163 | y0 = np.random.randint(0, ht) 164 | dx = np.random.randint(50, 100) 165 | dy = np.random.randint(50, 100) 166 | img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color 167 | 168 | return img1, img2 169 | 170 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 171 | ht, wd = flow.shape[:2] 172 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 173 | coords = np.stack(coords, axis=-1) 174 | 175 | coords = coords.reshape(-1, 2).astype(np.float32) 176 | flow = flow.reshape(-1, 2).astype(np.float32) 177 | valid = valid.reshape(-1).astype(np.float32) 178 | 179 | coords0 = coords[valid >= 1] 180 | flow0 = flow[valid >= 1] 181 | 182 | ht1 = int(round(ht * fy)) 183 | wd1 = int(round(wd * fx)) 184 | 185 | coords1 = coords0 * [fx, fy] 186 | flow1 = flow0 * [fx, fy] 187 | 188 | xx = np.round(coords1[:, 0]).astype(np.int32) 189 | yy = np.round(coords1[:, 1]).astype(np.int32) 190 | 191 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 192 | xx = xx[v] 193 | yy = yy[v] 194 | flow1 = flow1[v] 195 | 196 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 197 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 198 | 199 | flow_img[yy, xx] = flow1 200 | valid_img[yy, xx] = 1 201 | 202 | return flow_img, valid_img 203 | 204 | def spatial_transform(self, img1, img2, flow, valid): 205 | # randomly sample scale 206 | 207 | ht, wd = img1.shape[:2] 208 | min_scale = np.maximum( 209 | (self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd) 210 | ) 211 | 212 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 213 | scale_x = np.clip(scale, min_scale, None) 214 | scale_y = np.clip(scale, min_scale, None) 215 | 216 | if np.random.rand() < 
self.spatial_aug_prob: 217 | # rescale the images 218 | img1 = cv2.resize( 219 | img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 220 | ) 221 | img2 = cv2.resize( 222 | img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 223 | ) 224 | flow, valid = self.resize_sparse_flow_map( 225 | flow, valid, fx=scale_x, fy=scale_y 226 | ) 227 | 228 | if self.do_flip: 229 | if np.random.rand() < 0.5: # h-flip 230 | img1 = img1[:, ::-1] 231 | img2 = img2[:, ::-1] 232 | flow = flow[:, ::-1] * [-1.0, 1.0] 233 | valid = valid[:, ::-1] 234 | 235 | margin_y = 20 236 | margin_x = 50 237 | 238 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 239 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 240 | 241 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 242 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 243 | 244 | img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 245 | img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 246 | flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 247 | valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 248 | return img1, img2, flow, valid 249 | 250 | def __call__(self, img1, img2, flow, valid): 251 | img1, img2 = self.color_transform(img1, img2) 252 | img1, img2 = self.eraser_transform(img1, img2) 253 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 254 | 255 | img1 = np.ascontiguousarray(img1) 256 | img2 = np.ascontiguousarray(img2) 257 | flow = np.ascontiguousarray(flow) 258 | valid = np.ascontiguousarray(valid) 259 | 260 | return img1, img2, flow, valid 261 | -------------------------------------------------------------------------------- /model/modules/RAFT/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | 21 | def make_colorwheel(): 22 | """Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | RY = 15 33 | YG = 6 34 | GC = 4 35 | CB = 11 36 | BM = 13 37 | MR = 6 38 | 39 | ncols = RY + YG + GC + CB + BM + MR 40 | colorwheel = np.zeros((ncols, 3)) 41 | col = 0 42 | 43 | # RY 44 | colorwheel[0:RY, 0] = 255 45 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 46 | col = col + RY 47 | # YG 48 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 49 | colorwheel[col : col + YG, 1] = 255 50 | col = col + YG 51 | # GC 52 | colorwheel[col : col + GC, 1] = 255 53 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 54 | col = col + GC 55 | # CB 56 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 57 | colorwheel[col : col + CB, 2] = 255 58 | col = col + CB 59 | # BM 60 | colorwheel[col : col + BM, 2] = 255 61 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 62 | col = col + BM 63 | # MR 64 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 65 | colorwheel[col : col + MR, 0] = 255 66 | return colorwheel 67 | 68 | 69 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 70 | """Applies the flow color wheel to (possibly clipped) flow components u and v. 71 | 72 | According to the C++ source code of Daniel Scharstein 73 | According to the Matlab source code of Deqing Sun 74 | 75 | Args: 76 | u (np.ndarray): Input horizontal flow of shape [H,W] 77 | v (np.ndarray): Input vertical flow of shape [H,W] 78 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 79 | 80 | Returns: 81 | np.ndarray: Flow visualization image of shape [H,W,3] 82 | """ 83 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 84 | colorwheel = make_colorwheel() # shape [55x3] 85 | ncols = colorwheel.shape[0] 86 | rad = np.sqrt(np.square(u) + np.square(v)) 87 | a = np.arctan2(-v, -u) / np.pi 88 | fk = (a + 1) / 2 * (ncols - 1) 89 | k0 = np.floor(fk).astype(np.int32) 90 | k1 = k0 + 1 91 | k1[k1 == ncols] = 0 92 | f = fk - k0 93 | for i in range(colorwheel.shape[1]): 94 | tmp = colorwheel[:, i] 95 | col0 = tmp[k0] / 255.0 96 | col1 = tmp[k1] / 255.0 97 | col = (1 - f) * col0 + f * col1 98 | idx = rad <= 1 99 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 100 | col[~idx] = col[~idx] * 0.75 # out of range 101 | # Note the 2-i => BGR instead of RGB 102 | ch_idx = 2 - i if convert_to_bgr else i 103 | flow_image[:, :, ch_idx] = np.floor(255 * col) 104 | return flow_image 105 | 106 | 107 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 108 | """Expects a two dimensional flow image of shape. 109 | 110 | Args: 111 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 112 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 113 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
114 | 115 | Returns: 116 | np.ndarray: Flow visualization image of shape [H,W,3] 117 | """ 118 | assert flow_uv.ndim == 3, "input flow must have three dimensions" 119 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" 120 | if clip_flow is not None: 121 | flow_uv = np.clip(flow_uv, 0, clip_flow) 122 | u = flow_uv[:, :, 0] 123 | v = flow_uv[:, :, 1] 124 | rad = np.sqrt(np.square(u) + np.square(v)) 125 | rad_max = np.max(rad) 126 | epsilon = 1e-5 127 | u = u / (rad_max + epsilon) 128 | v = v / (rad_max + epsilon) 129 | return flow_uv_to_colors(u, v, convert_to_bgr) 130 | -------------------------------------------------------------------------------- /model/modules/RAFT/utils/flow_viz_pt.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization 2 | import torch 3 | 4 | torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 5 | 6 | 7 | @torch.no_grad() 8 | def flow_to_image(flow: torch.Tensor) -> torch.Tensor: 9 | """Converts a flow to an RGB image. 10 | 11 | Args: 12 | flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. 13 | 14 | Returns: 15 | img (Tensor): Image Tensor of dtype uint8 where each color corresponds 16 | to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. 17 | """ 18 | if flow.dtype != torch.float: 19 | raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") 20 | 21 | orig_shape = flow.shape 22 | if flow.ndim == 3: 23 | flow = flow[None] # Add batch dim 24 | 25 | if flow.ndim != 4 or flow.shape[1] != 2: 26 | raise ValueError( 27 | f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}." 28 | ) 29 | 30 | max_norm = torch.sum(flow**2, dim=1).sqrt().max() 31 | epsilon = torch.finfo((flow).dtype).eps 32 | normalized_flow = flow / (max_norm + epsilon) 33 | img = _normalized_flow_to_image(normalized_flow) 34 | 35 | if len(orig_shape) == 3: 36 | img = img[0] # Remove batch dim 37 | return img 38 | 39 | 40 | @torch.no_grad() 41 | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: 42 | """Converts a batch of normalized flow to an RGB image. 43 | 44 | Args: 45 | normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) 46 | 47 | Returns: 48 | img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 49 | """ 50 | N, _, H, W = normalized_flow.shape 51 | device = normalized_flow.device 52 | flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) 53 | colorwheel = _make_colorwheel().to(device) # shape [55x3] 54 | num_cols = colorwheel.shape[0] 55 | norm = torch.sum(normalized_flow**2, dim=1).sqrt() 56 | a = ( 57 | torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) 58 | / torch.pi 59 | ) 60 | fk = (a + 1) / 2 * (num_cols - 1) 61 | k0 = torch.floor(fk).to(torch.long) 62 | k1 = k0 + 1 63 | k1[k1 == num_cols] = 0 64 | f = fk - k0 65 | 66 | for c in range(colorwheel.shape[1]): 67 | tmp = colorwheel[:, c] 68 | col0 = tmp[k0] / 255.0 69 | col1 = tmp[k1] / 255.0 70 | col = (1 - f) * col0 + f * col1 71 | col = 1 - norm * (1 - col) 72 | flow_image[:, c, :, :] = torch.floor(255.0 * col) 73 | return flow_image 74 | 75 | 76 | @torch.no_grad() 77 | def _make_colorwheel() -> torch.Tensor: 78 | """Generates a color wheel for optical flow visualization as presented in: 79 | Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 80 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. 81 | 82 | Returns: 83 | colorwheel (Tensor[55, 3]): Colorwheel Tensor. 84 | """ 85 | RY = 15 86 | YG = 6 87 | GC = 4 88 | CB = 11 89 | BM = 13 90 | MR = 6 91 | 92 | ncols = RY + YG + GC + CB + BM + MR 93 | colorwheel = torch.zeros((ncols, 3)) 94 | col = 0 95 | 96 | # RY 97 | colorwheel[0:RY, 0] = 255 98 | colorwheel[0:RY, 1] = torch.floor(255.0 * torch.arange(0.0, RY) / RY) 99 | col = col + RY 100 | # YG 101 | colorwheel[col : col + YG, 0] = 255 - torch.floor( 102 | 255.0 * torch.arange(0.0, YG) / YG 103 | ) 104 | colorwheel[col : col + YG, 1] = 255 105 | col = col + YG 106 | # GC 107 | colorwheel[col : col + GC, 1] = 255 108 | colorwheel[col : col + GC, 2] = torch.floor(255.0 * torch.arange(0.0, GC) / GC) 109 | col = col + GC 110 | # CB 111 | colorwheel[col : col + CB, 1] = 255 - torch.floor(255.0 * torch.arange(CB) / CB) 112 | colorwheel[col : col + CB, 2] = 255 113 | col = col + CB 114 | # BM 115 | colorwheel[col : col + BM, 2] = 255 116 | colorwheel[col : col + BM, 0] = torch.floor(255.0 * torch.arange(0.0, BM) / BM) 117 | col = col + BM 118 | # MR 119 | colorwheel[col : col + MR, 2] = 255 - torch.floor(255.0 * torch.arange(MR) / MR) 120 | colorwheel[col : col + MR, 0] = 255 121 | return colorwheel 122 | -------------------------------------------------------------------------------- /model/modules/RAFT/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | 8 | cv2.setNumThreads(0) 9 | cv2.ocl.setUseOpenCL(False) 10 | 11 | TAG_CHAR = np.array([202021.25], np.float32) 12 | 13 | 14 | def readFlow(fn): 15 | """Read .flo file in Middlebury format""" 16 | # Code adapted from: 17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 18 | 19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 20 | # print 'fn = %s'%(fn) 21 | with open(fn, "rb") as f: 22 | magic = np.fromfile(f, np.float32, count=1) 23 | if magic != 202021.25: 24 | print("Magic number incorrect. 
Invalid .flo file") 25 | return None 26 | else: 27 | w = np.fromfile(f, np.int32, count=1) 28 | h = np.fromfile(f, np.int32, count=1) 29 | # print 'Reading %d x %d flo file\n' % (w, h) 30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 31 | # Reshape data into 3D array (columns, rows, bands) 32 | # The reshape here is for visualization, the original code is (w,h,2) 33 | return np.resize(data, (int(h), int(w), 2)) 34 | 35 | 36 | def readPFM(file): 37 | file = open(file, "rb") 38 | 39 | color = None 40 | width = None 41 | height = None 42 | scale = None 43 | endian = None 44 | 45 | header = file.readline().rstrip() 46 | if header == b"PF": 47 | color = True 48 | elif header == b"Pf": 49 | color = False 50 | else: 51 | raise Exception("Not a PFM file.") 52 | 53 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) 54 | if dim_match: 55 | width, height = map(int, dim_match.groups()) 56 | else: 57 | raise Exception("Malformed PFM header.") 58 | 59 | scale = float(file.readline().rstrip()) 60 | if scale < 0: # little-endian 61 | endian = "<" 62 | scale = -scale 63 | else: 64 | endian = ">" # big-endian 65 | 66 | data = np.fromfile(file, endian + "f") 67 | shape = (height, width, 3) if color else (height, width) 68 | 69 | data = np.reshape(data, shape) 70 | data = np.flipud(data) 71 | return data 72 | 73 | 74 | def writeFlow(filename, uv, v=None): 75 | """Write optical flow to file. 76 | 77 | If v is None, uv is assumed to contain both u and v channels, 78 | stacked in depth. 79 | Original code by Deqing Sun, adapted from Daniel Scharstein. 80 | """ 81 | nBands = 2 82 | 83 | if v is None: 84 | assert uv.ndim == 3 85 | assert uv.shape[2] == 2 86 | u = uv[:, :, 0] 87 | v = uv[:, :, 1] 88 | else: 89 | u = uv 90 | 91 | assert u.shape == v.shape 92 | height, width = u.shape 93 | f = open(filename, "wb") 94 | # write the header 95 | f.write(TAG_CHAR) 96 | np.array(width).astype(np.int32).tofile(f) 97 | np.array(height).astype(np.int32).tofile(f) 98 | # arrange into matrix form 99 | tmp = np.zeros((height, width * nBands)) 100 | tmp[:, np.arange(width) * 2] = u 101 | tmp[:, np.arange(width) * 2 + 1] = v 102 | tmp.astype(np.float32).tofile(f) 103 | f.close() 104 | 105 | 106 | def readFlowKITTI(filename): 107 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 108 | flow = flow[:, :, ::-1].astype(np.float32) 109 | flow, valid = flow[:, :, :2], flow[:, :, 2] 110 | flow = (flow - 2**15) / 64.0 111 | return flow, valid 112 | 113 | 114 | def readDispKITTI(filename): 115 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 116 | valid = disp > 0.0 117 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 118 | return flow, valid 119 | 120 | 121 | def writeFlowKITTI(filename, uv): 122 | uv = 64.0 * uv + 2**15 123 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 124 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 125 | cv2.imwrite(filename, uv[..., ::-1]) 126 | 127 | 128 | def read_gen(file_name, pil=False): 129 | ext = splitext(file_name)[-1] 130 | if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg": 131 | return Image.open(file_name) 132 | elif ext == ".bin" or ext == ".raw": 133 | return np.load(file_name) 134 | elif ext == ".flo": 135 | return readFlow(file_name).astype(np.float32) 136 | elif ext == ".pfm": 137 | flow = readPFM(file_name).astype(np.float32) 138 | if len(flow.shape) == 2: 139 | return flow 140 | else: 141 | return flow[:, :, :-1] 142 | return [] 143 | 
-------------------------------------------------------------------------------- /model/modules/RAFT/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """Pads images such that dimensions are divisible by 8""" 9 | 10 | def __init__(self, dims, mode="sintel"): 11 | self.ht, self.wd = dims[-2:] 12 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 13 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 14 | if mode == "sintel": 15 | self._pad = [ 16 | pad_wd // 2, 17 | pad_wd - pad_wd // 2, 18 | pad_ht // 2, 19 | pad_ht - pad_ht // 2, 20 | ] 21 | else: 22 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 23 | 24 | def pad(self, *inputs): 25 | return [F.pad(x, self._pad, mode="replicate") for x in inputs] 26 | 27 | def unpad(self, x): 28 | ht, wd = x.shape[-2:] 29 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 30 | return x[..., c[0] : c[1], c[2] : c[3]] 31 | 32 | 33 | def forward_interpolate(flow): 34 | flow = flow.detach().cpu().numpy() 35 | dx, dy = flow[0], flow[1] 36 | 37 | ht, wd = dx.shape 38 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 39 | 40 | x1 = x0 + dx 41 | y1 = y0 + dy 42 | 43 | x1 = x1.reshape(-1) 44 | y1 = y1.reshape(-1) 45 | dx = dx.reshape(-1) 46 | dy = dy.reshape(-1) 47 | 48 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 49 | x1 = x1[valid] 50 | y1 = y1[valid] 51 | dx = dx[valid] 52 | dy = dy[valid] 53 | 54 | flow_x = interpolate.griddata( 55 | (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 56 | ) 57 | 58 | flow_y = interpolate.griddata( 59 | (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 60 | ) 61 | 62 | flow = np.stack([flow_x, flow_y], axis=0) 63 | return torch.from_numpy(flow).float() 64 | 65 | 66 | def bilinear_sampler(img, coords, mode="bilinear", mask=False): 67 | """Wrapper for grid_sample, uses pixel coordinates""" 68 | H, W = img.shape[-2:] 69 | xgrid, ygrid = coords.split([1, 1], dim=-1) 70 | xgrid = 2 * xgrid / (W - 1) - 1 71 | ygrid = 2 * ygrid / (H - 1) - 1 72 | 73 | grid = torch.cat([xgrid, ygrid], dim=-1) 74 | img = F.grid_sample(img, grid, align_corners=True) 75 | 76 | if mask: 77 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 78 | return img, mask.float() 79 | 80 | return img 81 | 82 | 83 | def coords_grid(batch, ht, wd): 84 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 85 | coords = torch.stack(coords[::-1], dim=0).float() 86 | return coords[None].repeat(batch, 1, 1, 1) 87 | 88 | 89 | def upflow8(flow, mode="bilinear"): 90 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 91 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 92 | -------------------------------------------------------------------------------- /model/modules/base_module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | from functools import reduce 5 | 6 | 7 | class BaseNetwork(nn.Module): 8 | def __init__(self): 9 | super(BaseNetwork, self).__init__() 10 | 11 | def print_network(self): 12 | if isinstance(self, list): 13 | self = self[0] 14 | num_params = 0 15 | for param in self.parameters(): 16 | num_params += param.numel() 17 | print( 18 | "Network [%s] was created. Total number of parameters: %.1f million. 
" 19 | "" % (type(self).__name__, num_params / 1000000) 20 | ) 21 | 22 | def init_weights(self, init_type="normal", gain=0.02): 23 | """Initialize network's weights 24 | init_type: normal | xavier | kaiming | orthogonal 25 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 26 | """ 27 | 28 | def init_func(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("InstanceNorm2d") != -1: 31 | if hasattr(m, "weight") and m.weight is not None: 32 | nn.init.constant_(m.weight.data, 1.0) 33 | if hasattr(m, "bias") and m.bias is not None: 34 | nn.init.constant_(m.bias.data, 0.0) 35 | elif hasattr(m, "weight") and ( 36 | classname.find("Conv") != -1 or classname.find("Linear") != -1 37 | ): 38 | if init_type == "normal": 39 | nn.init.normal_(m.weight.data, 0.0, gain) 40 | elif init_type == "xavier": 41 | nn.init.xavier_normal_(m.weight.data, gain=gain) 42 | elif init_type == "xavier_uniform": 43 | nn.init.xavier_uniform_(m.weight.data, gain=1.0) 44 | elif init_type == "kaiming": 45 | nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 46 | elif init_type == "orthogonal": 47 | nn.init.orthogonal_(m.weight.data, gain=gain) 48 | elif init_type == "none": # uses pytorch's default init method 49 | m.reset_parameters() 50 | else: 51 | raise NotImplementedError( 52 | "initialization method [%s] is not implemented" % init_type 53 | ) 54 | if hasattr(m, "bias") and m.bias is not None: 55 | nn.init.constant_(m.bias.data, 0.0) 56 | 57 | self.apply(init_func) 58 | 59 | # propagate to children 60 | for m in self.children(): 61 | if hasattr(m, "init_weights"): 62 | m.init_weights(init_type, gain) 63 | 64 | 65 | class Vec2Feat(nn.Module): 66 | def __init__(self, channel, hidden, kernel_size, stride, padding): 67 | super(Vec2Feat, self).__init__() 68 | self.relu = nn.LeakyReLU(0.2, inplace=True) 69 | c_out = reduce((lambda x, y: x * y), kernel_size) * channel 70 | self.embedding = nn.Linear(hidden, c_out) 71 | self.kernel_size = kernel_size 72 | self.stride = stride 73 | self.padding = padding 74 | self.bias_conv = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 75 | 76 | def forward(self, x, t, output_size): 77 | b_, _, _, _, c_ = x.shape 78 | x = x.view(b_, -1, c_) 79 | feat = self.embedding(x) 80 | b, _, c = feat.size() 81 | feat = feat.view(b * t, -1, c).permute(0, 2, 1) 82 | feat = F.fold( 83 | feat, 84 | output_size=output_size, 85 | kernel_size=self.kernel_size, 86 | stride=self.stride, 87 | padding=self.padding, 88 | ) 89 | feat = self.bias_conv(feat) 90 | return feat 91 | 92 | 93 | class FusionFeedForward(nn.Module): 94 | def __init__(self, dim, hidden_dim=1960, t2t_params=None): 95 | super(FusionFeedForward, self).__init__() 96 | # We set hidden_dim as a default to 1960 97 | self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim)) 98 | self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim)) 99 | assert t2t_params is not None 100 | self.t2t_params = t2t_params 101 | self.kernel_shape = reduce( 102 | (lambda x, y: x * y), t2t_params["kernel_size"] 103 | ) # 49 104 | 105 | def forward(self, x, output_size): 106 | n_vecs = 1 107 | for i, d in enumerate(self.t2t_params["kernel_size"]): 108 | n_vecs *= int( 109 | (output_size[i] + 2 * self.t2t_params["padding"][i] - (d - 1) - 1) 110 | / self.t2t_params["stride"][i] 111 | + 1 112 | ) 113 | 114 | x = self.fc1(x) 115 | b, n, c = x.size() 116 | normalizer = ( 117 | x.new_ones(b, n, self.kernel_shape) 118 | .view(-1, n_vecs, self.kernel_shape) 119 | 
.permute(0, 2, 1) 120 | ) 121 | normalizer = F.fold( 122 | normalizer, 123 | output_size=output_size, 124 | kernel_size=self.t2t_params["kernel_size"], 125 | padding=self.t2t_params["padding"], 126 | stride=self.t2t_params["stride"], 127 | ) 128 | 129 | x = F.fold( 130 | x.view(-1, n_vecs, c).permute(0, 2, 1), 131 | output_size=output_size, 132 | kernel_size=self.t2t_params["kernel_size"], 133 | padding=self.t2t_params["padding"], 134 | stride=self.t2t_params["stride"], 135 | ) 136 | 137 | x = ( 138 | F.unfold( 139 | x / normalizer, 140 | kernel_size=self.t2t_params["kernel_size"], 141 | padding=self.t2t_params["padding"], 142 | stride=self.t2t_params["stride"], 143 | ) 144 | .permute(0, 2, 1) 145 | .contiguous() 146 | .view(b, n, c) 147 | ) 148 | x = self.fc2(x) 149 | return x 150 | -------------------------------------------------------------------------------- /model/modules/deformconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init as init 4 | from torch.nn.modules.utils import _pair, _single 5 | import math 6 | 7 | 8 | class ModulatedDeformConv2d(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels, 12 | out_channels, 13 | kernel_size, 14 | stride=1, 15 | padding=0, 16 | dilation=1, 17 | groups=1, 18 | deform_groups=1, 19 | bias=True, 20 | ): 21 | super(ModulatedDeformConv2d, self).__init__() 22 | 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | self.kernel_size = _pair(kernel_size) 26 | self.stride = stride 27 | self.padding = padding 28 | self.dilation = dilation 29 | self.groups = groups 30 | self.deform_groups = deform_groups 31 | self.with_bias = bias 32 | # enable compatibility with nn.Conv2d 33 | self.transposed = False 34 | self.output_padding = _single(0) 35 | 36 | self.weight = nn.Parameter( 37 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 38 | ) 39 | if bias: 40 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 41 | else: 42 | self.register_parameter("bias", None) 43 | self.init_weights() 44 | 45 | def init_weights(self): 46 | n = self.in_channels 47 | for k in self.kernel_size: 48 | n *= k 49 | stdv = 1.0 / math.sqrt(n) 50 | self.weight.data.uniform_(-stdv, stdv) 51 | if self.bias is not None: 52 | self.bias.data.zero_() 53 | 54 | if hasattr(self, "conv_offset"): 55 | self.conv_offset.weight.data.zero_() 56 | self.conv_offset.bias.data.zero_() 57 | 58 | def forward(self, x, offset, mask): 59 | pass 60 | -------------------------------------------------------------------------------- /model/modules/flow_comp_raft.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from .RAFT import RAFT 7 | from .flow_loss_utils import flow_warp, ternary_loss2 8 | 9 | 10 | def initialize_RAFT(model_path="weights/raft-things.pth", device="cuda"): 11 | """Initializes the RAFT model.""" 12 | args = argparse.ArgumentParser() 13 | args.raft_model = model_path 14 | args.small = False 15 | args.mixed_precision = False 16 | args.alternate_corr = False 17 | model = torch.nn.DataParallel(RAFT(args)) 18 | model.load_state_dict(torch.load(args.raft_model, map_location="cpu")) 19 | model = model.module 20 | 21 | model.to(device) 22 | 23 | return model 24 | 25 | 26 | class RAFT_bi(nn.Module): 27 | """Flow completion loss""" 28 | 29 | def __init__(self, model_path="weights/raft-things.pth", 
device="cuda"): 30 | super().__init__() 31 | self.fix_raft = initialize_RAFT(model_path, device=device) 32 | 33 | for p in self.fix_raft.parameters(): 34 | p.requires_grad = False 35 | 36 | self.l1_criterion = nn.L1Loss() 37 | self.eval() 38 | 39 | def forward(self, gt_local_frames, iters=20): 40 | b, l_t, c, h, w = gt_local_frames.size() 41 | # print(gt_local_frames.shape) 42 | 43 | with torch.no_grad(): 44 | gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(-1, c, h, w) 45 | gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(-1, c, h, w) 46 | # print(gtlf_1.shape) 47 | 48 | _, gt_flows_forward = self.fix_raft( 49 | gtlf_1, gtlf_2, iters=iters, test_mode=True 50 | ) 51 | _, gt_flows_backward = self.fix_raft( 52 | gtlf_2, gtlf_1, iters=iters, test_mode=True 53 | ) 54 | 55 | gt_flows_forward = gt_flows_forward.view(b, l_t - 1, 2, h, w) 56 | gt_flows_backward = gt_flows_backward.view(b, l_t - 1, 2, h, w) 57 | 58 | return gt_flows_forward, gt_flows_backward 59 | 60 | 61 | ################################################################################## 62 | def smoothness_loss(flow, cmask): 63 | delta_u, delta_v, mask = smoothness_deltas(flow) 64 | loss_u = charbonnier_loss(delta_u, cmask) 65 | loss_v = charbonnier_loss(delta_v, cmask) 66 | return loss_u + loss_v 67 | 68 | 69 | def smoothness_deltas(flow): 70 | """flow: [b, c, h, w]""" 71 | mask_x = create_mask(flow, [[0, 0], [0, 1]]) 72 | mask_y = create_mask(flow, [[0, 1], [0, 0]]) 73 | mask = torch.cat((mask_x, mask_y), dim=1) 74 | mask = mask.to(flow.device) 75 | filter_x = torch.tensor([[0, 0, 0.0], [0, 1, -1], [0, 0, 0]]) 76 | filter_y = torch.tensor([[0, 0, 0.0], [0, 1, 0], [0, -1, 0]]) 77 | weights = torch.ones([2, 1, 3, 3]) 78 | weights[0, 0] = filter_x 79 | weights[1, 0] = filter_y 80 | weights = weights.to(flow.device) 81 | 82 | flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1) 83 | delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) 84 | delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) 85 | return delta_u, delta_v, mask 86 | 87 | 88 | def second_order_loss(flow, cmask): 89 | delta_u, delta_v, mask = second_order_deltas(flow) 90 | loss_u = charbonnier_loss(delta_u, cmask) 91 | loss_v = charbonnier_loss(delta_v, cmask) 92 | return loss_u + loss_v 93 | 94 | 95 | def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001): 96 | """Compute the generalized charbonnier loss of the difference tensor x 97 | All positions where mask == 0 are not taken into account 98 | x: a tensor of shape [b, c, h, w] 99 | mask: a mask of shape [b, mc, h, w], where mask channels must be either 1 or the same as 100 | the number of channels of x. 
Entries should be 0 or 1 101 | return: loss 102 | """ 103 | b, c, h, w = x.shape 104 | norm = b * c * h * w 105 | error = torch.pow( 106 | torch.square(x * beta) + torch.square(torch.tensor(epsilon)), alpha 107 | ) 108 | if mask is not None: 109 | error = mask * error 110 | if truncate is not None: 111 | error = torch.min(error, truncate) 112 | return torch.sum(error) / norm 113 | 114 | 115 | def second_order_deltas(flow): 116 | """Consider the single flow first 117 | flow shape: [b, c, h, w] 118 | """ 119 | # create mask 120 | mask_x = create_mask(flow, [[0, 0], [1, 1]]) 121 | mask_y = create_mask(flow, [[1, 1], [0, 0]]) 122 | mask_diag = create_mask(flow, [[1, 1], [1, 1]]) 123 | mask = torch.cat((mask_x, mask_y, mask_diag, mask_diag), dim=1) 124 | mask = mask.to(flow.device) 125 | 126 | filter_x = torch.tensor([[0, 0, 0.0], [1, -2, 1], [0, 0, 0]]) 127 | filter_y = torch.tensor([[0, 1, 0.0], [0, -2, 0], [0, 1, 0]]) 128 | filter_diag1 = torch.tensor([[1, 0, 0.0], [0, -2, 0], [0, 0, 1]]) 129 | filter_diag2 = torch.tensor([[0, 0, 1.0], [0, -2, 0], [1, 0, 0]]) 130 | weights = torch.ones([4, 1, 3, 3]) 131 | weights[0] = filter_x 132 | weights[1] = filter_y 133 | weights[2] = filter_diag1 134 | weights[3] = filter_diag2 135 | weights = weights.to(flow.device) 136 | 137 | # split the flow into flow_u and flow_v, conv them with the weights 138 | flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1) 139 | delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) 140 | delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) 141 | return delta_u, delta_v, mask 142 | 143 | 144 | def create_mask(tensor, paddings): 145 | """Tensor shape: [b, c, h, w] 146 | paddings: [2 x 2] shape list, the first row indicates up and down paddings 147 | the second row indicates left and right paddings 148 | | | 149 | | x | 150 | | x * x | 151 | | x | 152 | | | 153 | """ 154 | shape = tensor.shape 155 | inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) 156 | inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) 157 | inner = torch.ones([inner_height, inner_width]) 158 | torch_paddings = [ 159 | paddings[1][0], 160 | paddings[1][1], 161 | paddings[0][0], 162 | paddings[0][1], 163 | ] # left, right, up and down 164 | mask2d = F.pad(inner, pad=torch_paddings) 165 | mask3d = mask2d.unsqueeze(0).repeat(shape[0], 1, 1) 166 | mask4d = mask3d.unsqueeze(1) 167 | return mask4d.detach() 168 | 169 | 170 | def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, scale_factor=1): 171 | if scale_factor != 1: 172 | current_frame = F.interpolate( 173 | current_frame, scale_factor=1 / scale_factor, mode="bilinear" 174 | ) 175 | shift_frame = F.interpolate( 176 | shift_frame, scale_factor=1 / scale_factor, mode="bilinear" 177 | ) 178 | warped_sc = flow_warp(shift_frame, flow_gt.permute(0, 2, 3, 1)) 179 | noc_mask = torch.exp( 180 | -50.0 * torch.sum(torch.abs(current_frame - warped_sc), dim=1).pow(2) 181 | ).unsqueeze(1) 182 | warped_comp_sc = flow_warp(shift_frame, flow_comp.permute(0, 2, 3, 1)) 183 | loss = ternary_loss2(current_frame, warped_comp_sc, noc_mask, mask) 184 | return loss 185 | 186 | 187 | class FlowLoss(nn.Module): 188 | def __init__(self): 189 | super().__init__() 190 | self.l1_criterion = nn.L1Loss() 191 | 192 | def forward(self, pred_flows, gt_flows, masks, frames): 193 | # pred_flows: b t-1 2 h w 194 | loss = 0 195 | warp_loss = 0 196 | h, w = pred_flows[0].shape[-2:] 197 | masks = [masks[:, :-1, ...].contiguous(), masks[:, 1:, ...].contiguous()] 198 | frames0 = frames[:, :-1, 
...] 199 | frames1 = frames[:, 1:, ...] 200 | current_frames = [frames0, frames1] 201 | next_frames = [frames1, frames0] 202 | for i in range(len(pred_flows)): 203 | # print(pred_flows[i].shape) 204 | combined_flow = pred_flows[i] * masks[i] + gt_flows[i] * (1 - masks[i]) 205 | l1_loss = self.l1_criterion( 206 | pred_flows[i] * masks[i], gt_flows[i] * masks[i] 207 | ) / torch.mean(masks[i]) 208 | l1_loss += self.l1_criterion( 209 | pred_flows[i] * (1 - masks[i]), gt_flows[i] * (1 - masks[i]) 210 | ) / torch.mean(1 - masks[i]) 211 | 212 | smooth_loss = smoothness_loss( 213 | combined_flow.reshape(-1, 2, h, w), masks[i].reshape(-1, 1, h, w) 214 | ) 215 | smooth_loss2 = second_order_loss( 216 | combined_flow.reshape(-1, 2, h, w), masks[i].reshape(-1, 1, h, w) 217 | ) 218 | 219 | warp_loss_i = ternary_loss( 220 | combined_flow.reshape(-1, 2, h, w), 221 | gt_flows[i].reshape(-1, 2, h, w), 222 | masks[i].reshape(-1, 1, h, w), 223 | current_frames[i].reshape(-1, 3, h, w), 224 | next_frames[i].reshape(-1, 3, h, w), 225 | ) 226 | 227 | loss += l1_loss + smooth_loss + smooth_loss2 228 | 229 | warp_loss += warp_loss_i 230 | 231 | return loss, warp_loss 232 | 233 | 234 | def edgeLoss(preds_edges, edges): 235 | """Args: 236 | preds_edges: with shape [b, c, h , w] 237 | edges: with shape [b, c, h, w] 238 | 239 | Returns: Edge losses 240 | 241 | """ 242 | mask = (edges > 0.5).float() 243 | b, c, h, w = mask.shape 244 | num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,]. 245 | num_neg = c * h * w - num_pos # Shape: [b,]. 246 | neg_weights = (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) 247 | pos_weights = (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) 248 | weight = neg_weights * mask + pos_weights * (1 - mask) # weight for debug 249 | losses = F.binary_cross_entropy_with_logits( 250 | preds_edges.float(), edges.float(), weight=weight, reduction="none" 251 | ) 252 | loss = torch.mean(losses) 253 | return loss 254 | 255 | 256 | class EdgeLoss(nn.Module): 257 | def __init__(self): 258 | super().__init__() 259 | 260 | def forward(self, pred_edges, gt_edges, masks): 261 | # pred_flows: b t-1 1 h w 262 | loss = 0 263 | h, w = pred_edges[0].shape[-2:] 264 | masks = [masks[:, :-1, ...].contiguous(), masks[:, 1:, ...].contiguous()] 265 | for i in range(len(pred_edges)): 266 | # print(f'edges_{i}', torch.sum(gt_edges[i])) # debug 267 | combined_edge = pred_edges[i] * masks[i] + gt_edges[i] * (1 - masks[i]) 268 | edge_loss = edgeLoss( 269 | pred_edges[i].reshape(-1, 1, h, w), gt_edges[i].reshape(-1, 1, h, w) 270 | ) + 5 * edgeLoss( 271 | combined_edge.reshape(-1, 1, h, w), gt_edges[i].reshape(-1, 1, h, w) 272 | ) 273 | loss += edge_loss 274 | 275 | return loss 276 | 277 | 278 | class FlowSimpleLoss(nn.Module): 279 | def __init__(self): 280 | super().__init__() 281 | self.l1_criterion = nn.L1Loss() 282 | 283 | def forward(self, pred_flows, gt_flows): 284 | # pred_flows: b t-1 2 h w 285 | loss = 0 286 | h, w = pred_flows[0].shape[-2:] 287 | h_orig, w_orig = gt_flows[0].shape[-2:] 288 | pred_flows = [f.view(-1, 2, h, w) for f in pred_flows] 289 | gt_flows = [f.view(-1, 2, h_orig, w_orig) for f in gt_flows] 290 | 291 | ds_factor = 1.0 * h / h_orig 292 | gt_flows = [ 293 | F.interpolate(f, scale_factor=ds_factor, mode="area") * ds_factor 294 | for f in gt_flows 295 | ] 296 | for i in range(len(pred_flows)): 297 | loss += self.l1_criterion(pred_flows[i], gt_flows[i]) 298 | 299 | return loss 300 | 
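# A minimal sketch (hypothetical tensor sizes) of the flow-rescaling convention
# used by FlowSimpleLoss above: when the ground-truth flow is resized down to the
# prediction resolution, the displacement values are multiplied by the same
# ds_factor, because a motion of k pixels at the original size corresponds to
# k * ds_factor pixels after resizing.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    gt_flow = torch.randn(1, 2, 64, 64)   # (b, 2, h_orig, w_orig): 64x64 ground truth
    ds_factor = 32 / 64                    # prediction assumed to be 32x32

    # Resize the field AND rescale the displacements, so a 10-pixel motion at
    # 64x64 becomes a 5-pixel motion at 32x32.
    gt_flow_small = F.interpolate(gt_flow, scale_factor=ds_factor, mode="area") * ds_factor
    print(gt_flow_small.shape)             # torch.Size([1, 2, 32, 32])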
-------------------------------------------------------------------------------- /model/modules/flow_loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | def flow_warp( 7 | x, flow, interpolation="bilinear", padding_mode="zeros", align_corners=True 8 | ): 9 | """Warp an image or a feature map with optical flow. 10 | 11 | Args: 12 | x (Tensor): Tensor with size (n, c, h, w). 13 | flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is 14 | a two-channel, denoting the width and height relative offsets. 15 | Note that the values are not normalized to [-1, 1]. 16 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. 17 | Default: 'bilinear'. 18 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. 19 | Default: 'zeros'. 20 | align_corners (bool): Whether align corners. Default: True. 21 | 22 | Returns: 23 | Tensor: Warped image or feature map. 24 | """ 25 | if x.size()[-2:] != flow.size()[1:3]: 26 | raise ValueError( 27 | f"The spatial sizes of input ({x.size()[-2:]}) and " 28 | f"flow ({flow.size()[1:3]}) are not the same." 29 | ) 30 | _, _, h, w = x.size() 31 | # create mesh grid 32 | device = flow.device 33 | grid_y, grid_x = torch.meshgrid( 34 | torch.arange(0, h, device=device), torch.arange(0, w, device=device) 35 | ) 36 | grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2) 37 | grid.requires_grad = False 38 | 39 | grid_flow = grid + flow 40 | # scale grid_flow to [-1,1] 41 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 42 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 43 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) 44 | output = F.grid_sample( 45 | x, 46 | grid_flow, 47 | mode=interpolation, 48 | padding_mode=padding_mode, 49 | align_corners=align_corners, 50 | ) 51 | return output 52 | 53 | 54 | # def image_warp(image, flow): 55 | # b, c, h, w = image.size() 56 | # device = image.device 57 | # flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) # normalize to [-1~1](from upper left to lower right 58 | # flow = flow.permute(0, 2, 3, 1) # if you wanna use grid_sample function, the channel(band) shape of show must be in the last dimension 59 | # x = np.linspace(-1, 1, w) 60 | # y = np.linspace(-1, 1, h) 61 | # X, Y = np.meshgrid(x, y) 62 | # grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3), 63 | # torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device) 64 | # output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros') 65 | # return output 66 | 67 | 68 | def length_sq(x): 69 | return torch.sum(torch.square(x), dim=1, keepdim=True) 70 | 71 | 72 | def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): 73 | flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x)) 74 | flow_fw_warped = flow_warp(flow_fw, flow_bw.permute(0, 2, 3, 1)) # wf(wb(x)) 75 | flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x)) 76 | flow_diff_bw = flow_bw + flow_fw_warped # wb + wf(wb(x)) 77 | 78 | mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))| 79 | mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped) # |wb| + |wf(wb(x))| 80 | occ_thresh_fw = alpha1 * mag_sq_fw + alpha2 81 | occ_thresh_bw = alpha1 * mag_sq_bw + alpha2 82 | 83 | fb_occ_fw = 
(length_sq(flow_diff_fw) > occ_thresh_fw).float() 84 | fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh_bw).float() 85 | 86 | return ( 87 | fb_occ_fw, 88 | fb_occ_bw, 89 | ) # fb_occ_fw -> frame2 area occluded by frame1, fb_occ_bw -> frame1 area occluded by frame2 90 | 91 | 92 | def rgb2gray(image): 93 | gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + 0.110 * image[:, 2] 94 | gray_image = gray_image.unsqueeze(1) 95 | return gray_image 96 | 97 | 98 | def ternary_transform(image, max_distance=1): 99 | device = image.device 100 | patch_size = 2 * max_distance + 1 101 | intensities = rgb2gray(image) * 255 102 | out_channels = patch_size * patch_size 103 | w = np.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size) 104 | weights = torch.from_numpy(w).float().to(device) 105 | patches = F.conv2d(intensities, weights, stride=1, padding=1) 106 | transf = patches - intensities 107 | transf_norm = transf / torch.sqrt(0.81 + torch.square(transf)) 108 | return transf_norm 109 | 110 | 111 | def hamming_distance(t1, t2): 112 | dist = torch.square(t1 - t2) 113 | dist_norm = dist / (0.1 + dist) 114 | dist_sum = torch.sum(dist_norm, dim=1, keepdim=True) 115 | return dist_sum 116 | 117 | 118 | def create_mask(mask, paddings): 119 | """padding: [[top, bottom], [left, right]]""" 120 | shape = mask.shape 121 | inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) 122 | inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) 123 | inner = torch.ones([inner_height, inner_width]) 124 | 125 | mask2d = F.pad( 126 | inner, pad=[paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]] 127 | ) 128 | mask3d = mask2d.unsqueeze(0) 129 | mask4d = mask3d.unsqueeze(0).repeat(shape[0], 1, 1, 1) 130 | return mask4d.detach() 131 | 132 | 133 | def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1): 134 | """Args: 135 | frame1: torch tensor, with shape [b * t, c, h, w] 136 | warp_frame21: torch tensor, with shape [b * t, c, h, w] 137 | confMask: confidence mask, with shape [b * t, c, h, w] 138 | masks: torch tensor, with shape [b * t, c, h, w] 139 | max_distance: maximum distance. 140 | 141 | Returns: ternary loss 142 | 143 | """ 144 | t1 = ternary_transform(frame1) 145 | t21 = ternary_transform(warp_frame21) 146 | dist = hamming_distance(t1, t21) 147 | loss = torch.mean(dist * confMask * masks) / torch.mean(masks) 148 | return loss 149 | -------------------------------------------------------------------------------- /model/modules/spectral_norm.py: -------------------------------------------------------------------------------- 1 | """Spectral Normalization from https://arxiv.org/abs/1802.05957""" 2 | 3 | import torch 4 | from torch.nn.functional import normalize 5 | 6 | 7 | class SpectralNorm: 8 | # Invariant before and after each forward call: 9 | # u = normalize(W @ v) 10 | # NB: At initialization, this invariant is not enforced 11 | 12 | _version = 1 13 | 14 | # At version 1: 15 | # made `W` not a buffer, 16 | # added `v` as a buffer, and 17 | # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. 
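    # In short: `SpectralNorm.apply(module, name, ...)` moves the original
    # parameter to `<name>_orig`, registers power-iteration buffers `<name>_u`
    # and `<name>_v`, and installs a forward pre-hook that re-assigns
    # `<name> = <name>_orig / sigma` before every forward call, with sigma
    # estimated as `u^T W v` in compute_weight() below.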
18 | 19 | def __init__(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12): 20 | self.name = name 21 | self.dim = dim 22 | if n_power_iterations <= 0: 23 | raise ValueError( 24 | "Expected n_power_iterations to be positive, but " 25 | f"got n_power_iterations={n_power_iterations}" 26 | ) 27 | self.n_power_iterations = n_power_iterations 28 | self.eps = eps 29 | 30 | def reshape_weight_to_matrix(self, weight): 31 | weight_mat = weight 32 | if self.dim != 0: 33 | # permute dim to front 34 | weight_mat = weight_mat.permute( 35 | self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim] 36 | ) 37 | height = weight_mat.size(0) 38 | return weight_mat.reshape(height, -1) 39 | 40 | def compute_weight(self, module, do_power_iteration): 41 | # NB: If `do_power_iteration` is set, the `u` and `v` vectors are 42 | # updated in power iteration **in-place**. This is very important 43 | # because in `DataParallel` forward, the vectors (being buffers) are 44 | # broadcast from the parallelized module to each module replica, 45 | # which is a new module object created on the fly. And each replica 46 | # runs its own spectral norm power iteration. So simply assigning 47 | # the updated vectors to the module this function runs on will cause 48 | # the update to be lost forever. And the next time the parallelized 49 | # module is replicated, the same randomly initialized vectors are 50 | # broadcast and used! 51 | # 52 | # Therefore, to make the change propagate back, we rely on two 53 | # important behaviors (also enforced via tests): 54 | # 1. `DataParallel` doesn't clone storage if the broadcast tensor 55 | # is already on correct device; and it makes sure that the 56 | # parallelized module is already on `device[0]`. 57 | # 2. If the out tensor in `out=` kwarg has correct shape, it will 58 | # just fill in the values. 59 | # Therefore, since the same power iteration is performed on all 60 | # devices, simply updating the tensors in-place will make sure that 61 | # the module replica on `device[0]` will update the _u vector on the 62 | # parallized module (by shared storage). 63 | # 64 | # However, after we update `u` and `v` in-place, we need to **clone** 65 | # them before using them to normalize the weight. This is to support 66 | # backproping through two forward passes, e.g., the common pattern in 67 | # GAN training: loss = D(real) - D(fake). Otherwise, engine will 68 | # complain that variables needed to do backward for the first forward 69 | # (i.e., the `u` and `v` vectors) are changed in the second forward. 70 | weight = getattr(module, self.name + "_orig") 71 | u = getattr(module, self.name + "_u") 72 | v = getattr(module, self.name + "_v") 73 | weight_mat = self.reshape_weight_to_matrix(weight) 74 | 75 | if do_power_iteration: 76 | with torch.no_grad(): 77 | for _ in range(self.n_power_iterations): 78 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v` 79 | # are the first left and right singular vectors. 80 | # This power iteration produces approximations of `u` and `v`. 
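                    # Worked illustration: for W = diag(3.0, 1.0), repeated
                    # iterations drive `u` and `v` toward the first basis
                    # vector, and sigma = u^T W v (computed below) approaches
                    # 3.0, the largest singular value of W.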
81 | v = normalize( 82 | torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v 83 | ) 84 | u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u) 85 | if self.n_power_iterations > 0: 86 | # See above on why we need to clone 87 | u = u.clone() 88 | v = v.clone() 89 | 90 | sigma = torch.dot(u, torch.mv(weight_mat, v)) 91 | weight = weight / sigma 92 | return weight 93 | 94 | def remove(self, module): 95 | with torch.no_grad(): 96 | weight = self.compute_weight(module, do_power_iteration=False) 97 | delattr(module, self.name) 98 | delattr(module, self.name + "_u") 99 | delattr(module, self.name + "_v") 100 | delattr(module, self.name + "_orig") 101 | module.register_parameter(self.name, torch.nn.Parameter(weight.detach())) 102 | 103 | def __call__(self, module, inputs): 104 | setattr( 105 | module, 106 | self.name, 107 | self.compute_weight(module, do_power_iteration=module.training), 108 | ) 109 | 110 | def _solve_v_and_rescale(self, weight_mat, u, target_sigma): 111 | # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` 112 | # (the invariant at top of this class) and `u @ W @ v = sigma`. 113 | # This uses pinverse in case W^T W is not invertible. 114 | v = torch.chain_matmul( 115 | weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1) 116 | ).squeeze(1) 117 | return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) 118 | 119 | @staticmethod 120 | def apply(module, name, n_power_iterations, dim, eps): 121 | for k, hook in module._forward_pre_hooks.items(): 122 | if isinstance(hook, SpectralNorm) and hook.name == name: 123 | raise RuntimeError( 124 | "Cannot register two spectral_norm hooks on " 125 | f"the same parameter {name}" 126 | ) 127 | 128 | fn = SpectralNorm(name, n_power_iterations, dim, eps) 129 | weight = module._parameters[name] 130 | 131 | with torch.no_grad(): 132 | weight_mat = fn.reshape_weight_to_matrix(weight) 133 | 134 | h, w = weight_mat.size() 135 | # randomly initialize `u` and `v` 136 | u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) 137 | v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) 138 | 139 | delattr(module, fn.name) 140 | module.register_parameter(fn.name + "_orig", weight) 141 | # We still need to assign weight back as fn.name because all sorts of 142 | # things may assume that it exists, e.g., when initializing weights. 143 | # However, we can't directly assign as it could be an nn.Parameter and 144 | # gets added as a parameter. Instead, we register weight.data as a plain 145 | # attribute. 146 | setattr(module, fn.name, weight.data) 147 | module.register_buffer(fn.name + "_u", u) 148 | module.register_buffer(fn.name + "_v", v) 149 | 150 | module.register_forward_pre_hook(fn) 151 | 152 | module._register_state_dict_hook(SpectralNormStateDictHook(fn)) 153 | module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) 154 | return fn 155 | 156 | 157 | # This is a top level class because Py2 pickle doesn't like inner class nor an 158 | # instancemethod. 159 | class SpectralNormLoadStateDictPreHook: 160 | # See docstring of SpectralNorm._version on the changes to spectral_norm. 
161 | def __init__(self, fn): 162 | self.fn = fn 163 | 164 | # For state_dict with version None, (assuming that it has gone through at 165 | # least one training forward), we have 166 | # 167 | # u = normalize(W_orig @ v) 168 | # W = W_orig / sigma, where sigma = u @ W_orig @ v 169 | # 170 | # To compute `v`, we solve `W_orig @ x = u`, and let 171 | # v = x / (u @ W_orig @ x) * (W / W_orig). 172 | def __call__( 173 | self, 174 | state_dict, 175 | prefix, 176 | local_metadata, 177 | strict, 178 | missing_keys, 179 | unexpected_keys, 180 | error_msgs, 181 | ): 182 | fn = self.fn 183 | version = local_metadata.get("spectral_norm", {}).get( 184 | fn.name + ".version", None 185 | ) 186 | if version is None or version < 1: 187 | with torch.no_grad(): 188 | weight_orig = state_dict[prefix + fn.name + "_orig"] 189 | # weight = state_dict.pop(prefix + fn.name) 190 | # sigma = (weight_orig / weight).mean() 191 | weight_mat = fn.reshape_weight_to_matrix(weight_orig) 192 | u = state_dict[prefix + fn.name + "_u"] 193 | # v = fn._solve_v_and_rescale(weight_mat, u, sigma) 194 | # state_dict[prefix + fn.name + '_v'] = v 195 | 196 | 197 | # This is a top level class because Py2 pickle doesn't like inner class nor an 198 | # instancemethod. 199 | class SpectralNormStateDictHook: 200 | # See docstring of SpectralNorm._version on the changes to spectral_norm. 201 | def __init__(self, fn): 202 | self.fn = fn 203 | 204 | def __call__(self, module, state_dict, prefix, local_metadata): 205 | if "spectral_norm" not in local_metadata: 206 | local_metadata["spectral_norm"] = {} 207 | key = self.fn.name + ".version" 208 | if key in local_metadata["spectral_norm"]: 209 | raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}") 210 | local_metadata["spectral_norm"][key] = self.fn._version 211 | 212 | 213 | def spectral_norm(module, name="weight", n_power_iterations=1, eps=1e-12, dim=None): 214 | r"""Applies spectral normalization to a parameter in the given module. 215 | 216 | .. math:: 217 | \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, 218 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} 219 | 220 | Spectral normalization stabilizes the training of discriminators (critics) 221 | in Generative Adversarial Networks (GANs) by rescaling the weight tensor 222 | with spectral norm :math:`\sigma` of the weight matrix calculated using 223 | power iteration method. If the dimension of the weight tensor is greater 224 | than 2, it is reshaped to 2D in power iteration method to get spectral 225 | norm. This is implemented via a hook that calculates spectral norm and 226 | rescales weight before every :meth:`~Module.forward` call. 227 | 228 | See `Spectral Normalization for Generative Adversarial Networks`_ . 229 | 230 | .. 
_`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 231 | 232 | Args: 233 | module (nn.Module): containing module 234 | name (str, optional): name of weight parameter 235 | n_power_iterations (int, optional): number of power iterations to 236 | calculate spectral norm 237 | eps (float, optional): epsilon for numerical stability in 238 | calculating norms 239 | dim (int, optional): dimension corresponding to number of outputs, 240 | the default is ``0``, except for modules that are instances of 241 | ConvTranspose{1,2,3}d, when it is ``1`` 242 | 243 | Returns: 244 | The original module with the spectral norm hook 245 | 246 | Example:: 247 | 248 | >>> m = spectral_norm(nn.Linear(20, 40)) 249 | >>> m 250 | Linear(in_features=20, out_features=40, bias=True) 251 | >>> m.weight_u.size() 252 | torch.Size([40]) 253 | 254 | """ 255 | if dim is None: 256 | if isinstance( 257 | module, 258 | ( 259 | torch.nn.ConvTranspose1d, 260 | torch.nn.ConvTranspose2d, 261 | torch.nn.ConvTranspose3d, 262 | ), 263 | ): 264 | dim = 1 265 | else: 266 | dim = 0 267 | SpectralNorm.apply(module, name, n_power_iterations, dim, eps) 268 | return module 269 | 270 | 271 | def remove_spectral_norm(module, name="weight"): 272 | r"""Removes the spectral normalization reparameterization from a module. 273 | 274 | Args: 275 | module (Module): containing module 276 | name (str, optional): name of weight parameter 277 | 278 | Example: 279 | >>> m = spectral_norm(nn.Linear(40, 10)) 280 | >>> remove_spectral_norm(m) 281 | """ 282 | for k, hook in module._forward_pre_hooks.items(): 283 | if isinstance(hook, SpectralNorm) and hook.name == name: 284 | hook.remove(module) 285 | del module._forward_pre_hooks[k] 286 | return module 287 | 288 | raise ValueError(f"spectral_norm of '{name}' not found in {module}") 289 | 290 | 291 | def use_spectral_norm(module, use_sn=False): 292 | if use_sn: 293 | return spectral_norm(module) 294 | return module 295 | -------------------------------------------------------------------------------- /model/recurrent_flow_completion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | from .modules.deformconv import ModulatedDeformConv2d 7 | from .misc import constant_init 8 | 9 | 10 | class SecondOrderDeformableAlignment(ModulatedDeformConv2d): 11 | """Second-order deformable alignment module.""" 12 | 13 | def __init__(self, *args, **kwargs): 14 | self.max_residue_magnitude = kwargs.pop("max_residue_magnitude", 5) 15 | 16 | super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) 17 | 18 | self.conv_offset = nn.Sequential( 19 | nn.Conv2d(3 * self.out_channels, self.out_channels, 3, 1, 1), 20 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 21 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), 22 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 23 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), 24 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 25 | nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), 26 | ) 27 | self.init_offset() 28 | 29 | def init_offset(self): 30 | constant_init(self.conv_offset[-1], val=0, bias=0) 31 | 32 | def forward(self, x, extra_feat): 33 | out = self.conv_offset(extra_feat) 34 | o1, o2, mask = torch.chunk(out, 3, dim=1) 35 | 36 | # offset 37 | offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) 38 | offset_1, offset_2 = 
torch.chunk(offset, 2, dim=1) 39 | offset = torch.cat([offset_1, offset_2], dim=1) 40 | 41 | # mask 42 | mask = torch.sigmoid(mask) 43 | 44 | return torchvision.ops.deform_conv2d( 45 | x, 46 | offset, 47 | self.weight, 48 | self.bias, 49 | self.stride, 50 | self.padding, 51 | self.dilation, 52 | mask, 53 | ) 54 | 55 | 56 | class BidirectionalPropagation(nn.Module): 57 | def __init__(self, channel): 58 | super(BidirectionalPropagation, self).__init__() 59 | modules = ["backward_", "forward_"] 60 | self.deform_align = nn.ModuleDict() 61 | self.backbone = nn.ModuleDict() 62 | self.channel = channel 63 | 64 | for i, module in enumerate(modules): 65 | self.deform_align[module] = SecondOrderDeformableAlignment( 66 | 2 * channel, channel, 3, padding=1, deform_groups=16 67 | ) 68 | 69 | self.backbone[module] = nn.Sequential( 70 | nn.Conv2d((2 + i) * channel, channel, 3, 1, 1), 71 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 72 | nn.Conv2d(channel, channel, 3, 1, 1), 73 | ) 74 | 75 | self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0) 76 | 77 | def forward(self, x): 78 | """X shape : [b, t, c, h, w] 79 | return [b, t, c, h, w] 80 | """ 81 | b, t, c, h, w = x.shape 82 | feats = {} 83 | feats["spatial"] = [x[:, i, :, :, :] for i in range(0, t)] 84 | 85 | for module_name in ["backward_", "forward_"]: 86 | feats[module_name] = [] 87 | 88 | frame_idx = range(0, t) 89 | mapping_idx = list(range(0, len(feats["spatial"]))) 90 | mapping_idx += mapping_idx[::-1] 91 | 92 | if "backward" in module_name: 93 | frame_idx = frame_idx[::-1] 94 | 95 | feat_prop = x.new_zeros(b, self.channel, h, w) 96 | for i, idx in enumerate(frame_idx): 97 | feat_current = feats["spatial"][mapping_idx[idx]] 98 | if i > 0: 99 | cond_n1 = feat_prop 100 | 101 | # initialize second-order features 102 | feat_n2 = torch.zeros_like(feat_prop) 103 | cond_n2 = torch.zeros_like(cond_n1) 104 | if i > 1: # second-order features 105 | feat_n2 = feats[module_name][-2] 106 | cond_n2 = feat_n2 107 | 108 | cond = torch.cat( 109 | [cond_n1, feat_current, cond_n2], dim=1 110 | ) # condition information, cond(flow warped 1st/2nd feature) 111 | feat_prop = torch.cat( 112 | [feat_prop, feat_n2], dim=1 113 | ) # two order feat_prop -1 & -2 114 | feat_prop = self.deform_align[module_name](feat_prop, cond) 115 | 116 | # fuse current features 117 | feat = ( 118 | [feat_current] 119 | + [ 120 | feats[k][idx] 121 | for k in feats 122 | if k not in ["spatial", module_name] 123 | ] 124 | + [feat_prop] 125 | ) 126 | 127 | feat = torch.cat(feat, dim=1) 128 | # embed current features 129 | feat_prop = feat_prop + self.backbone[module_name](feat) 130 | 131 | feats[module_name].append(feat_prop) 132 | 133 | # end for 134 | if "backward" in module_name: 135 | feats[module_name] = feats[module_name][::-1] 136 | 137 | outputs = [] 138 | for i in range(0, t): 139 | align_feats = [feats[k].pop(0) for k in feats if k != "spatial"] 140 | align_feats = torch.cat(align_feats, dim=1) 141 | outputs.append(self.fusion(align_feats)) 142 | 143 | return torch.stack(outputs, dim=1) + x 144 | 145 | 146 | class deconv(nn.Module): 147 | def __init__(self, input_channel, output_channel, kernel_size=3, padding=0): 148 | super().__init__() 149 | self.conv = nn.Conv2d( 150 | input_channel, 151 | output_channel, 152 | kernel_size=kernel_size, 153 | stride=1, 154 | padding=padding, 155 | ) 156 | 157 | def forward(self, x): 158 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) 159 | return self.conv(x) 160 | 161 | 162 | class P3DBlock(nn.Module): 163 | def 
__init__( 164 | self, 165 | in_channels, 166 | out_channels, 167 | kernel_size, 168 | stride, 169 | padding, 170 | use_residual=0, 171 | bias=True, 172 | ): 173 | super().__init__() 174 | self.conv1 = nn.Sequential( 175 | nn.Conv3d( 176 | in_channels, 177 | out_channels, 178 | kernel_size=(1, kernel_size, kernel_size), 179 | stride=(1, stride, stride), 180 | padding=(0, padding, padding), 181 | bias=bias, 182 | ), 183 | nn.LeakyReLU(0.2, inplace=True), 184 | ) 185 | self.conv2 = nn.Sequential( 186 | nn.Conv3d( 187 | out_channels, 188 | out_channels, 189 | kernel_size=(3, 1, 1), 190 | stride=(1, 1, 1), 191 | padding=(2, 0, 0), 192 | dilation=(2, 1, 1), 193 | bias=bias, 194 | ) 195 | ) 196 | self.use_residual = use_residual 197 | 198 | def forward(self, feats): 199 | feat1 = self.conv1(feats) 200 | feat2 = self.conv2(feat1) 201 | if self.use_residual: 202 | output = feats + feat2 203 | else: 204 | output = feat2 205 | return output 206 | 207 | 208 | class EdgeDetection(nn.Module): 209 | def __init__(self, in_ch=2, out_ch=1, mid_ch=16): 210 | super().__init__() 211 | self.projection = nn.Sequential( 212 | nn.Conv2d(in_ch, mid_ch, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True) 213 | ) 214 | 215 | self.mid_layer_1 = nn.Sequential( 216 | nn.Conv2d(mid_ch, mid_ch, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True) 217 | ) 218 | 219 | self.mid_layer_2 = nn.Sequential(nn.Conv2d(mid_ch, mid_ch, 3, 1, 1)) 220 | 221 | self.l_relu = nn.LeakyReLU(0.01, inplace=True) 222 | 223 | self.out_layer = nn.Conv2d(mid_ch, out_ch, 1, 1, 0) 224 | 225 | def forward(self, flow): 226 | flow = self.projection(flow) 227 | edge = self.mid_layer_1(flow) 228 | edge = self.mid_layer_2(edge) 229 | edge = self.l_relu(flow + edge) 230 | edge = self.out_layer(edge) 231 | edge = torch.sigmoid(edge) 232 | return edge 233 | 234 | 235 | class RecurrentFlowCompleteNet(nn.Module): 236 | def __init__(self, model_path=None): 237 | super().__init__() 238 | self.downsample = nn.Sequential( 239 | nn.Conv3d( 240 | 3, 241 | 32, 242 | kernel_size=(1, 5, 5), 243 | stride=(1, 2, 2), 244 | padding=(0, 2, 2), 245 | padding_mode="replicate", 246 | ), 247 | nn.LeakyReLU(0.2, inplace=True), 248 | ) 249 | 250 | self.encoder1 = nn.Sequential( 251 | P3DBlock(32, 32, 3, 1, 1), 252 | nn.LeakyReLU(0.2, inplace=True), 253 | P3DBlock(32, 64, 3, 2, 1), 254 | nn.LeakyReLU(0.2, inplace=True), 255 | ) # 4x 256 | 257 | self.encoder2 = nn.Sequential( 258 | P3DBlock(64, 64, 3, 1, 1), 259 | nn.LeakyReLU(0.2, inplace=True), 260 | P3DBlock(64, 128, 3, 2, 1), 261 | nn.LeakyReLU(0.2, inplace=True), 262 | ) # 8x 263 | 264 | self.mid_dilation = nn.Sequential( 265 | nn.Conv3d( 266 | 128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 3, 3), dilation=(1, 3, 3) 267 | ), # p = d*(k-1)/2 268 | nn.LeakyReLU(0.2, inplace=True), 269 | nn.Conv3d( 270 | 128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 2, 2), dilation=(1, 2, 2) 271 | ), 272 | nn.LeakyReLU(0.2, inplace=True), 273 | nn.Conv3d( 274 | 128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 1, 1), dilation=(1, 1, 1) 275 | ), 276 | nn.LeakyReLU(0.2, inplace=True), 277 | ) 278 | 279 | # feature propagation module 280 | self.feat_prop_module = BidirectionalPropagation(128) 281 | 282 | self.decoder2 = nn.Sequential( 283 | nn.Conv2d(128, 128, 3, 1, 1), 284 | nn.LeakyReLU(0.2, inplace=True), 285 | deconv(128, 64, 3, 1), 286 | nn.LeakyReLU(0.2, inplace=True), 287 | ) # 4x 288 | 289 | self.decoder1 = nn.Sequential( 290 | nn.Conv2d(64, 64, 3, 1, 1), 291 | nn.LeakyReLU(0.2, inplace=True), 292 | deconv(64, 32, 3, 1), 293 | nn.LeakyReLU(0.2, inplace=True), 294 | ) # 2x 
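        # Resolution sketch for a hypothetical (h, w) = (240, 432) input:
        #   downsample -> 120 x 216 (1/2), encoder1 -> 60 x 108 (1/4),
        #   encoder2   -> 30 x 54   (1/8), where bidirectional propagation runs;
        #   decoder2 -> 1/4, decoder1 -> 1/2, upsample -> full resolution.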
295 | 296 | self.upsample = nn.Sequential( 297 | nn.Conv2d(32, 32, 3, padding=1), 298 | nn.LeakyReLU(0.2, inplace=True), 299 | deconv(32, 2, 3, 1), 300 | ) 301 | 302 | # edge loss 303 | self.edgeDetector = EdgeDetection(in_ch=2, out_ch=1, mid_ch=16) 304 | 305 | # Need to initial the weights of MSDeformAttn specifically 306 | for m in self.modules(): 307 | if isinstance(m, SecondOrderDeformableAlignment): 308 | m.init_offset() 309 | 310 | if model_path is not None: 311 | print("Pretrained flow completion model has loaded...") 312 | ckpt = torch.load(model_path, map_location="cpu") 313 | self.load_state_dict(ckpt, strict=True) 314 | 315 | def forward(self, masked_flows, masks): 316 | # masked_flows: b t-1 2 h w 317 | # masks: b t-1 2 h w 318 | b, t, _, h, w = masked_flows.size() 319 | masked_flows = masked_flows.permute(0, 2, 1, 3, 4) 320 | masks = masks.permute(0, 2, 1, 3, 4) 321 | 322 | inputs = torch.cat((masked_flows, masks), dim=1) 323 | 324 | x = self.downsample(inputs) 325 | 326 | feat_e1 = self.encoder1(x) 327 | feat_e2 = self.encoder2(feat_e1) # b c t h w 328 | feat_mid = self.mid_dilation(feat_e2) # b c t h w 329 | feat_mid = feat_mid.permute(0, 2, 1, 3, 4) # b t c h w 330 | 331 | feat_prop = self.feat_prop_module(feat_mid) 332 | feat_prop = feat_prop.view(-1, 128, h // 8, w // 8) # b*t c h w 333 | 334 | _, c, _, h_f, w_f = feat_e1.shape 335 | feat_e1 = ( 336 | feat_e1.permute(0, 2, 1, 3, 4).contiguous().view(-1, c, h_f, w_f) 337 | ) # b*t c h w 338 | feat_d2 = self.decoder2(feat_prop) + feat_e1 339 | 340 | _, c, _, h_f, w_f = x.shape 341 | x = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, c, h_f, w_f) # b*t c h w 342 | 343 | feat_d1 = self.decoder1(feat_d2) 344 | 345 | flow = self.upsample(feat_d1) 346 | if self.training: 347 | edge = self.edgeDetector(flow) 348 | edge = edge.view(b, t, 1, h, w) 349 | else: 350 | edge = None 351 | 352 | flow = flow.view(b, t, 2, h, w) 353 | 354 | return flow, edge 355 | 356 | def forward_bidirect_flow(self, masked_flows_bi, masks): 357 | """Args: 358 | masked_flows_bi: [masked_flows_f, masked_flows_b] | (b t-1 2 h w), (b t-1 2 h w) 359 | masks: b t 1 h w 360 | """ 361 | masks_forward = masks[:, :-1, ...].contiguous() 362 | masks_backward = masks[:, 1:, ...].contiguous() 363 | 364 | # mask flow 365 | masked_flows_forward = masked_flows_bi[0] * (1 - masks_forward) 366 | masked_flows_backward = masked_flows_bi[1] * (1 - masks_backward) 367 | 368 | # -- completion -- 369 | # forward 370 | pred_flows_forward, pred_edges_forward = self.forward( 371 | masked_flows_forward, masks_forward 372 | ) 373 | 374 | # backward 375 | masked_flows_backward = torch.flip(masked_flows_backward, dims=[1]) 376 | masks_backward = torch.flip(masks_backward, dims=[1]) 377 | pred_flows_backward, pred_edges_backward = self.forward( 378 | masked_flows_backward, masks_backward 379 | ) 380 | pred_flows_backward = torch.flip(pred_flows_backward, dims=[1]) 381 | if self.training: 382 | pred_edges_backward = torch.flip(pred_edges_backward, dims=[1]) 383 | 384 | return [pred_flows_forward, pred_flows_backward], [ 385 | pred_edges_forward, 386 | pred_edges_backward, 387 | ] 388 | 389 | def combine_flow(self, masked_flows_bi, pred_flows_bi, masks): 390 | masks_forward = masks[:, :-1, ...].contiguous() 391 | masks_backward = masks[:, 1:, ...].contiguous() 392 | 393 | pred_flows_forward = pred_flows_bi[0] * masks_forward + masked_flows_bi[0] * ( 394 | 1 - masks_forward 395 | ) 396 | pred_flows_backward = pred_flows_bi[1] * masks_backward + masked_flows_bi[1] * ( 397 | 1 - 
masks_backward 398 | ) 399 | 400 | return pred_flows_forward, pred_flows_backward 401 | -------------------------------------------------------------------------------- /model/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn 5 | from torchvision.models import vgg 6 | 7 | VGG_PRETRAIN_PATH = "experiments/pretrained_models/vgg19-dcbb9e9d.pth" 8 | NAMES = { 9 | "vgg11": [ 10 | "conv1_1", 11 | "relu1_1", 12 | "pool1", 13 | "conv2_1", 14 | "relu2_1", 15 | "pool2", 16 | "conv3_1", 17 | "relu3_1", 18 | "conv3_2", 19 | "relu3_2", 20 | "pool3", 21 | "conv4_1", 22 | "relu4_1", 23 | "conv4_2", 24 | "relu4_2", 25 | "pool4", 26 | "conv5_1", 27 | "relu5_1", 28 | "conv5_2", 29 | "relu5_2", 30 | "pool5", 31 | ], 32 | "vgg13": [ 33 | "conv1_1", 34 | "relu1_1", 35 | "conv1_2", 36 | "relu1_2", 37 | "pool1", 38 | "conv2_1", 39 | "relu2_1", 40 | "conv2_2", 41 | "relu2_2", 42 | "pool2", 43 | "conv3_1", 44 | "relu3_1", 45 | "conv3_2", 46 | "relu3_2", 47 | "pool3", 48 | "conv4_1", 49 | "relu4_1", 50 | "conv4_2", 51 | "relu4_2", 52 | "pool4", 53 | "conv5_1", 54 | "relu5_1", 55 | "conv5_2", 56 | "relu5_2", 57 | "pool5", 58 | ], 59 | "vgg16": [ 60 | "conv1_1", 61 | "relu1_1", 62 | "conv1_2", 63 | "relu1_2", 64 | "pool1", 65 | "conv2_1", 66 | "relu2_1", 67 | "conv2_2", 68 | "relu2_2", 69 | "pool2", 70 | "conv3_1", 71 | "relu3_1", 72 | "conv3_2", 73 | "relu3_2", 74 | "conv3_3", 75 | "relu3_3", 76 | "pool3", 77 | "conv4_1", 78 | "relu4_1", 79 | "conv4_2", 80 | "relu4_2", 81 | "conv4_3", 82 | "relu4_3", 83 | "pool4", 84 | "conv5_1", 85 | "relu5_1", 86 | "conv5_2", 87 | "relu5_2", 88 | "conv5_3", 89 | "relu5_3", 90 | "pool5", 91 | ], 92 | "vgg19": [ 93 | "conv1_1", 94 | "relu1_1", 95 | "conv1_2", 96 | "relu1_2", 97 | "pool1", 98 | "conv2_1", 99 | "relu2_1", 100 | "conv2_2", 101 | "relu2_2", 102 | "pool2", 103 | "conv3_1", 104 | "relu3_1", 105 | "conv3_2", 106 | "relu3_2", 107 | "conv3_3", 108 | "relu3_3", 109 | "conv3_4", 110 | "relu3_4", 111 | "pool3", 112 | "conv4_1", 113 | "relu4_1", 114 | "conv4_2", 115 | "relu4_2", 116 | "conv4_3", 117 | "relu4_3", 118 | "conv4_4", 119 | "relu4_4", 120 | "pool4", 121 | "conv5_1", 122 | "relu5_1", 123 | "conv5_2", 124 | "relu5_2", 125 | "conv5_3", 126 | "relu5_3", 127 | "conv5_4", 128 | "relu5_4", 129 | "pool5", 130 | ], 131 | } 132 | 133 | 134 | def insert_bn(names): 135 | """Insert bn layer after each conv. 136 | 137 | Args: 138 | names (list): The list of layer names. 139 | 140 | Returns: 141 | list: The list of layer names with bn layers. 142 | """ 143 | names_bn = [] 144 | for name in names: 145 | names_bn.append(name) 146 | if "conv" in name: 147 | position = name.replace("conv", "") 148 | names_bn.append("bn" + position) 149 | return names_bn 150 | 151 | 152 | class VGGFeatureExtractor(nn.Module): 153 | """VGG network for feature extraction. 154 | 155 | In this implementation, we allow users to choose whether use normalization 156 | in the input feature and the type of vgg network. Note that the pretrained 157 | path must fit the vgg type. 158 | 159 | Args: 160 | layer_name_list (list[str]): Forward function returns the corresponding 161 | features according to the layer_name_list. 162 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 163 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 164 | use_input_norm (bool): If True, normalize the input image. Importantly, 165 | the input feature must in the range [0, 1]. Default: True. 
166 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 167 | Default: False. 168 | requires_grad (bool): If true, the parameters of VGG network will be 169 | optimized. Default: False. 170 | remove_pooling (bool): If true, the max pooling operations in VGG net 171 | will be removed. Default: False. 172 | pooling_stride (int): The stride of max pooling operation. Default: 2. 173 | """ 174 | 175 | def __init__( 176 | self, 177 | layer_name_list, 178 | vgg_type="vgg19", 179 | use_input_norm=True, 180 | range_norm=False, 181 | requires_grad=False, 182 | remove_pooling=False, 183 | pooling_stride=2, 184 | ): 185 | super(VGGFeatureExtractor, self).__init__() 186 | 187 | self.layer_name_list = layer_name_list 188 | self.use_input_norm = use_input_norm 189 | self.range_norm = range_norm 190 | 191 | self.names = NAMES[vgg_type.replace("_bn", "")] 192 | if "bn" in vgg_type: 193 | self.names = insert_bn(self.names) 194 | 195 | # only borrow layers that will be used to avoid unused params 196 | max_idx = 0 197 | for v in layer_name_list: 198 | idx = self.names.index(v) 199 | if idx > max_idx: 200 | max_idx = idx 201 | 202 | if os.path.exists(VGG_PRETRAIN_PATH): 203 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 204 | state_dict = torch.load( 205 | VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage 206 | ) 207 | vgg_net.load_state_dict(state_dict) 208 | else: 209 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 210 | 211 | features = vgg_net.features[: max_idx + 1] 212 | 213 | modified_net = OrderedDict() 214 | for k, v in zip(self.names, features): 215 | if "pool" in k: 216 | # if remove_pooling is true, pooling operation will be removed 217 | if remove_pooling: 218 | continue 219 | else: 220 | # in some cases, we may want to change the default stride 221 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 222 | else: 223 | modified_net[k] = v 224 | 225 | self.vgg_net = nn.Sequential(modified_net) 226 | 227 | if not requires_grad: 228 | self.vgg_net.eval() 229 | for param in self.parameters(): 230 | param.requires_grad = False 231 | else: 232 | self.vgg_net.train() 233 | for param in self.parameters(): 234 | param.requires_grad = True 235 | 236 | if self.use_input_norm: 237 | # the mean is for image with range [0, 1] 238 | self.register_buffer( 239 | "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 240 | ) 241 | # the std is for image with range [0, 1] 242 | self.register_buffer( 243 | "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 244 | ) 245 | 246 | def forward(self, x): 247 | """Forward function. 248 | 249 | Args: 250 | x (Tensor): Input tensor with shape (n, c, h, w). 251 | 252 | Returns: 253 | Tensor: Forward results. 
254 | """ 255 | if self.range_norm: 256 | x = (x + 1) / 2 257 | if self.use_input_norm: 258 | x = (x - self.mean) / self.std 259 | output = {} 260 | 261 | for key, layer in self.vgg_net._modules.items(): 262 | x = layer(x) 263 | if key in self.layer_name_list: 264 | output[key] = x.clone() 265 | 266 | return output 267 | -------------------------------------------------------------------------------- /propainter_inference.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import numpy as np 4 | from numpy.typing import NDArray 5 | 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from .model.modules.flow_comp_raft import RAFT_bi 10 | from .model.recurrent_flow_completion import ( 11 | RecurrentFlowCompleteNet, 12 | ) 13 | from .utils.model_utils import Models 14 | from .model.propainter import InpaintGenerator 15 | 16 | 17 | @dataclass 18 | class ProPainterConfig: 19 | ref_stride: int 20 | neighbor_length: int 21 | subvideo_length: int 22 | raft_iter: int 23 | fp16: str 24 | video_length: int 25 | device: torch.device 26 | process_size: tuple[int, int] 27 | use_half: bool = field(init=False) 28 | 29 | def __post_init__(self) -> None: 30 | """Initialize use-half.""" 31 | self.use_half = self.fp16 == "enable" 32 | if self.device == torch.device("cpu"): 33 | self.use_half = False 34 | 35 | 36 | def get_ref_index( 37 | mid_neighbor_id: int, 38 | neighbor_ids: list[int], 39 | config: ProPainterConfig, 40 | ref_num: int = -1, 41 | ) -> list[int]: 42 | """Calculate reference indices for frames based on the provided parameters.""" 43 | ref_index = [] 44 | if ref_num == -1: 45 | for i in range(0, config.video_length, config.ref_stride): 46 | if i not in neighbor_ids: 47 | ref_index.append(i) 48 | else: 49 | start_idx = max(0, mid_neighbor_id - config.ref_stride * (ref_num // 2)) 50 | end_idx = min( 51 | config.video_length, mid_neighbor_id + config.ref_stride * (ref_num // 2) 52 | ) 53 | for i in range(start_idx, end_idx, config.ref_stride): 54 | if i not in neighbor_ids: 55 | if len(ref_index) > ref_num: 56 | break 57 | ref_index.append(i) 58 | return ref_index 59 | 60 | 61 | def compute_flow( 62 | raft_model: RAFT_bi, frames: torch.Tensor, config: ProPainterConfig 63 | ) -> tuple[torch.Tensor, torch.Tensor]: 64 | """Compute forward and backward optical flows using the RAFT model.""" 65 | if frames.size(dim=-1) <= 640: 66 | short_clip_len = 12 67 | elif frames.size(dim=-1) <= 720: 68 | short_clip_len = 8 69 | elif frames.size(dim=-1) <= 1280: 70 | short_clip_len = 4 71 | else: 72 | short_clip_len = 2 73 | 74 | # use fp32 for RAFT 75 | if frames.size(dim=1) > short_clip_len: 76 | gt_flows_f_list, gt_flows_b_list = [], [] 77 | for chunck in range(0, config.video_length, short_clip_len): 78 | end_f = min(config.video_length, chunck + short_clip_len) 79 | if chunck == 0: 80 | flows_f, flows_b = raft_model( 81 | frames[:, chunck:end_f], iters=config.raft_iter 82 | ) 83 | else: 84 | flows_f, flows_b = raft_model( 85 | frames[:, chunck - 1 : end_f], iters=config.raft_iter 86 | ) 87 | 88 | gt_flows_f_list.append(flows_f) 89 | gt_flows_b_list.append(flows_b) 90 | torch.cuda.empty_cache() 91 | 92 | gt_flows_f = torch.cat(gt_flows_f_list, dim=1) 93 | gt_flows_b = torch.cat(gt_flows_b_list, dim=1) 94 | gt_flows_bi = (gt_flows_f, gt_flows_b) 95 | else: 96 | gt_flows_bi = raft_model(frames, iters=config.raft_iter) 97 | torch.cuda.empty_cache() 98 | 99 | return gt_flows_bi 100 | 101 | 102 | def complete_flow( 103 | 
recurrent_flow_model: RecurrentFlowCompleteNet, 104 | flows_tuple: tuple[torch.Tensor, torch.Tensor], 105 | flow_masks: torch.Tensor, 106 | subvideo_length: int, 107 | ) -> tuple[torch.Tensor, torch.Tensor]: 108 | """Complete and refine optical flows using a recurrent flow completion model. 109 | 110 | This function processes optical flows in chunks if the total length exceeds the specified 111 | subvideo length. It uses a recurrent model to complete and refine the flows, combining 112 | forward and backward flows into bidirectional flows. 113 | """ 114 | flow_length = flows_tuple[0].size(dim=1) 115 | if flow_length > subvideo_length: 116 | pred_flows_f_list, pred_flows_b_list = [], [] 117 | pad_len = 5 118 | for f in range(0, flow_length, subvideo_length): 119 | s_f = max(0, f - pad_len) 120 | e_f = min(flow_length, f + subvideo_length + pad_len) 121 | pad_len_s = max(0, f) - s_f 122 | pad_len_e = e_f - min(flow_length, f + subvideo_length) 123 | pred_flows_bi_sub, _ = recurrent_flow_model.forward_bidirect_flow( 124 | (flows_tuple[0][:, s_f:e_f], flows_tuple[1][:, s_f:e_f]), 125 | flow_masks[:, s_f : e_f + 1], 126 | ) 127 | pred_flows_bi_sub = recurrent_flow_model.combine_flow( 128 | (flows_tuple[0][:, s_f:e_f], flows_tuple[1][:, s_f:e_f]), 129 | pred_flows_bi_sub, 130 | flow_masks[:, s_f : e_f + 1], 131 | ) 132 | 133 | pred_flows_f_list.append( 134 | pred_flows_bi_sub[0][:, pad_len_s : e_f - s_f - pad_len_e] 135 | ) 136 | pred_flows_b_list.append( 137 | pred_flows_bi_sub[1][:, pad_len_s : e_f - s_f - pad_len_e] 138 | ) 139 | torch.cuda.empty_cache() 140 | 141 | pred_flows_f = torch.cat(pred_flows_f_list, dim=1) 142 | pred_flows_b = torch.cat(pred_flows_b_list, dim=1) 143 | 144 | pred_flows_bi = (pred_flows_f, pred_flows_b) 145 | 146 | else: 147 | pred_flows_bi, _ = recurrent_flow_model.forward_bidirect_flow( 148 | flows_tuple, flow_masks 149 | ) 150 | pred_flows_bi = recurrent_flow_model.combine_flow( 151 | flows_tuple, pred_flows_bi, flow_masks 152 | ) 153 | 154 | torch.cuda.empty_cache() 155 | 156 | return pred_flows_bi 157 | 158 | 159 | def image_propagation( 160 | inpaint_model: InpaintGenerator, 161 | frames: torch.Tensor, 162 | masks_dilated: torch.Tensor, 163 | prediction_flows: tuple[torch.Tensor, torch.Tensor], 164 | config: ProPainterConfig, 165 | ) -> tuple[torch.Tensor, torch.Tensor]: 166 | """Propagate inpainted images across video frames. 167 | 168 | If the video length exceeds a defined threshold, the process is segmented and handled in chunks. 
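
    Illustrative example (hypothetical numbers): with video_length = 150 and
    subvideo_length = 80, frames are processed in 80-frame windows with a
    10-frame pad on each side; the first window covers frames 0-89 and keeps
    0-79, the second covers frames 70-149 and keeps 80-149.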
169 | """ 170 | process_width, process_height = config.process_size 171 | masked_frames = frames * (1 - masks_dilated) 172 | subvideo_length_img_prop = min( 173 | 100, config.subvideo_length 174 | ) # ensure a minimum of 100 frames for image propagation 175 | if config.video_length > subvideo_length_img_prop: 176 | updated_frames_list, updated_masks_list = [], [] 177 | pad_len = 10 178 | for f in range(0, config.video_length, subvideo_length_img_prop): 179 | s_f = max(0, f - pad_len) 180 | e_f = min(config.video_length, f + subvideo_length_img_prop + pad_len) 181 | pad_len_s = max(0, f) - s_f 182 | pad_len_e = e_f - min(config.video_length, f + subvideo_length_img_prop) 183 | b, t, _, _, _ = masks_dilated[:, s_f:e_f].size() 184 | pred_flows_bi_sub = ( 185 | prediction_flows[0][:, s_f : e_f - 1], 186 | prediction_flows[1][:, s_f : e_f - 1], 187 | ) 188 | prop_imgs_sub, updated_local_masks_sub = inpaint_model.img_propagation( 189 | masked_frames[:, s_f:e_f], 190 | pred_flows_bi_sub, 191 | masks_dilated[:, s_f:e_f], 192 | "nearest", 193 | ) 194 | updated_frames_sub = ( 195 | frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) 196 | + prop_imgs_sub.view(b, t, 3, process_height, process_width) 197 | * masks_dilated[:, s_f:e_f] 198 | ) 199 | updated_masks_sub = updated_local_masks_sub.view( 200 | b, t, 1, process_height, process_width 201 | ) 202 | 203 | updated_frames_list.append( 204 | updated_frames_sub[:, pad_len_s : e_f - s_f - pad_len_e] 205 | ) 206 | updated_masks_list.append( 207 | updated_masks_sub[:, pad_len_s : e_f - s_f - pad_len_e] 208 | ) 209 | torch.cuda.empty_cache() 210 | 211 | updated_frames = torch.cat(updated_frames_list, dim=1) 212 | updated_masks = torch.cat(updated_masks_list, dim=1) 213 | else: 214 | b, t, _, _, _ = masks_dilated.size() 215 | prop_imgs, updated_local_masks = inpaint_model.img_propagation( 216 | masked_frames, prediction_flows, masks_dilated, "nearest" 217 | ) 218 | updated_frames = ( 219 | frames * (1 - masks_dilated) 220 | + prop_imgs.view(b, t, 3, process_height, process_width) * masks_dilated 221 | ) 222 | updated_masks = updated_local_masks.view(b, t, 1, process_height, process_width) 223 | torch.cuda.empty_cache() 224 | 225 | return updated_frames, updated_masks 226 | 227 | 228 | def feature_propagation( 229 | inpaint_model: InpaintGenerator, 230 | updated_frames: torch.Tensor, 231 | updated_masks: torch.Tensor, 232 | masks_dilated: torch.Tensor, 233 | prediction_flows: tuple[torch.Tensor, torch.Tensor], 234 | original_frames: list[NDArray], 235 | config: ProPainterConfig, 236 | ) -> list[NDArray]: 237 | """Propagate inpainted features across video frames. 238 | 239 | The process is segmented and handled in chunks if the video length exceeds a defined threshold. 
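
    Illustrative example (hypothetical numbers): with video_length = 30 (below
    subvideo_length), neighbor_length = 10 and ref_stride = 10, the window
    centred at f = 0 uses neighbor frames 0-5 plus reference frames 10 and 20,
    so distant frames still provide context for the local inpainting.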
240 | """ 241 | # TODO: Refactor function may be too 242 | process_width, process_height = config.process_size 243 | 244 | # TODO: Refacator how composed frames is initialized 245 | composed_frames: list[NDArray | None] = [None] * config.video_length 246 | 247 | neighbor_stride = config.neighbor_length // 2 248 | ref_num = ( 249 | config.subvideo_length // config.ref_stride 250 | if config.video_length > config.subvideo_length 251 | else -1 252 | ) 253 | 254 | for f in tqdm(range(0, config.video_length, neighbor_stride)): 255 | neighbor_ids = list( 256 | range( 257 | max(0, f - neighbor_stride), 258 | min(config.video_length, f + neighbor_stride + 1), 259 | ) 260 | ) 261 | ref_ids = get_ref_index(f, neighbor_ids, config, ref_num) 262 | selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :] 263 | selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :] 264 | if config.use_half: 265 | selected_masks = selected_masks.half() 266 | selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :] 267 | selected_pred_flows_bi = ( 268 | prediction_flows[0][:, neighbor_ids[:-1], :, :, :], 269 | prediction_flows[1][:, neighbor_ids[:-1], :, :, :], 270 | ) 271 | with torch.no_grad(): 272 | # 1.0 indicates mask 273 | l_t = len(neighbor_ids) 274 | 275 | pred_img = inpaint_model( 276 | selected_imgs, 277 | selected_pred_flows_bi, 278 | selected_masks, 279 | selected_update_masks, 280 | l_t, 281 | ) 282 | 283 | pred_img = pred_img.view(-1, 3, process_height, process_width) 284 | 285 | pred_img = (pred_img + 1) / 2 286 | pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 287 | binary_masks = ( 288 | masks_dilated[0, neighbor_ids, :, :, :] 289 | .cpu() 290 | .permute(0, 2, 3, 1) 291 | .numpy() 292 | .astype(np.uint8) 293 | ) 294 | for i, idx in enumerate(neighbor_ids): 295 | # idx = neighbor_ids[i] 296 | img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[ 297 | i 298 | ] + original_frames[idx] * (1 - binary_masks[i]) 299 | if composed_frames[idx] is None: 300 | composed_frames[idx] = img 301 | else: 302 | composed_frames[idx] = ( 303 | composed_frames[idx].astype(np.float32) * 0.5 304 | + img.astype(np.float32) * 0.5 305 | ) 306 | 307 | composed_frames[idx] = composed_frames[idx].astype(np.uint8) 308 | 309 | torch.cuda.empty_cache() 310 | 311 | return composed_frames 312 | 313 | 314 | def process_inpainting( 315 | models: Models, 316 | frames: torch.Tensor, 317 | flow_masks: torch.Tensor, 318 | masks_dilated: torch.Tensor, 319 | config: ProPainterConfig, 320 | ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: 321 | """Apply inpainting on video using recurrent flow and ProPainter model.""" 322 | with torch.no_grad(): 323 | gt_flows_bi = compute_flow(models.raft_model, frames, config) 324 | 325 | if config.use_half: 326 | frames, flow_masks, masks_dilated = ( 327 | frames.half(), 328 | flow_masks.half(), 329 | masks_dilated.half(), 330 | ) 331 | gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half()) 332 | 333 | pred_flows_bi = complete_flow( 334 | models.flow_model, gt_flows_bi, flow_masks, config.subvideo_length 335 | ) 336 | 337 | updated_frames, updated_masks = image_propagation( 338 | models.inpaint_model, frames, masks_dilated, pred_flows_bi, config 339 | ) 340 | 341 | return updated_frames, updated_masks, pred_flows_bi 342 | -------------------------------------------------------------------------------- /propainter_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| from comfy import model_management 3 | 4 | from .propainter_inference import ( 5 | ProPainterConfig, 6 | feature_propagation, 7 | process_inpainting, 8 | ) 9 | from .utils.image_utils import ( 10 | ImageConfig, 11 | ImageOutpaintConfig, 12 | convert_image_to_frames, 13 | handle_output, 14 | prepare_frames_and_masks, 15 | extrapolation, 16 | prepare_frames_and_masks_for_outpaint, 17 | ) 18 | from .utils.model_utils import initialize_models 19 | 20 | 21 | def check_inputs(frames: torch.Tensor, masks: torch.Tensor) -> None: 22 | if frames.size(dim=0) <= 1: 23 | raise Exception(f"""Image length must be greater than 1, but got: 24 | Image length: ({frames.size(dim=0)})""") 25 | if frames.size(dim=0) != masks.size(dim=0) and masks.size(dim=0) != 1: 26 | raise Exception(f"""Image and Mask must have the same length, or Mask must have length 1, but got: 27 | Image length: {frames.size(dim=0)} 28 | Mask length: {masks.size(dim=0)}""") 29 | 30 | if frames.size(dim=1) != masks.size(dim=1) or frames.size(dim=2) != masks.size( 31 | dim=2 32 | ): 33 | raise Exception(f"""Image and Mask must have the same dimensions, but got: 34 | Image: ({frames.size(dim=1)}, {frames.size(dim=2)}) 35 | Mask: ({masks.size(dim=1)}, {masks.size(dim=2)})""") 36 | 37 | 38 | class ProPainterInpaint: 39 | """ComfyUI Node for performing inpainting on video frames using ProPainter.""" 40 | 41 | def __init__(self): 42 | pass 43 | 44 | @classmethod 45 | def INPUT_TYPES(s): 46 | return { 47 | "required": { 48 | "image": ("IMAGE",), # --video 49 | "mask": ("MASK",), # --mask 50 | "width": ("INT", {"default": 640, "min": 0, "max": 2560}), # --width 51 | "height": ("INT", {"default": 360, "min": 0, "max": 2560}), # --height 52 | "mask_dilates": ( 53 | "INT", 54 | {"default": 5, "min": 0, "max": 100}, 55 | ), # --mask_dilates 56 | "flow_mask_dilates": ( 57 | "INT", 58 | {"default": 8, "min": 0, "max": 100}, 59 | ), # argument does not exist in the original code 60 | "ref_stride": ( 61 | "INT", 62 | {"default": 10, "min": 1, "max": 100}, 63 | ), # --ref_stride 64 | "neighbor_length": ( 65 | "INT", 66 | {"default": 10, "min": 2, "max": 300}, 67 | ), # --neighbor_length 68 | "subvideo_length": ( 69 | "INT", 70 | {"default": 80, "min": 1, "max": 300}, 71 | ), # --subvideo_length 72 | "raft_iter": ( 73 | "INT", 74 | {"default": 20, "min": 1, "max": 100}, 75 | ), # --raft_iter 76 | "fp16": (["enable", "disable"],), # --fp16 77 | }, 78 | } 79 | 80 | RETURN_TYPES = ( 81 | "IMAGE", 82 | "MASK", 83 | "MASK", 84 | ) 85 | RETURN_NAMES = ( 86 | "IMAGE", 87 | "FLOW_MASK", 88 | "MASK_DILATE", 89 | ) 90 | FUNCTION = "propainter_inpainting" 91 | CATEGORY = "ProPainter" 92 | 93 | def propainter_inpainting( 94 | self, 95 | image: torch.Tensor, 96 | mask: torch.Tensor, 97 | width: int, 98 | height: int, 99 | mask_dilates: int, 100 | flow_mask_dilates: int, 101 | ref_stride: int, 102 | neighbor_length: int, 103 | subvideo_length: int, 104 | raft_iter: int, 105 | fp16: str, 106 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 107 | """Perform inpainting on the input images using ProPainter model inference.""" 108 | check_inputs(image, mask) 109 | device = model_management.get_torch_device() 110 | # TODO: Check if this conversion from Torch to PIL is really necessary.
111 | frames = convert_image_to_frames(image) 112 | video_length = image.size(dim=0) 113 | input_size = frames[0].size 114 | 115 | image_config = ImageConfig( 116 | width, height, mask_dilates, flow_mask_dilates, input_size, video_length 117 | ) 118 | inpaint_config = ProPainterConfig( 119 | ref_stride, 120 | neighbor_length, 121 | subvideo_length, 122 | raft_iter, 123 | fp16, 124 | video_length, 125 | device, 126 | image_config.process_size, 127 | ) 128 | 129 | frames_tensor, flow_masks_tensor, masks_dilated_tensor, original_frames = ( 130 | prepare_frames_and_masks(frames, mask, image_config, device) 131 | ) 132 | 133 | models = initialize_models(device, inpaint_config.fp16) 134 | print(f"\nProcessing {inpaint_config.video_length} frames...") 135 | 136 | updated_frames, updated_masks, pred_flows_bi = process_inpainting( 137 | models, 138 | frames_tensor, 139 | flow_masks_tensor, 140 | masks_dilated_tensor, 141 | inpaint_config, 142 | ) 143 | 144 | composed_frames = feature_propagation( 145 | models.inpaint_model, 146 | updated_frames, 147 | updated_masks, 148 | masks_dilated_tensor, 149 | pred_flows_bi, 150 | original_frames, 151 | inpaint_config, 152 | ) 153 | 154 | return handle_output(composed_frames, flow_masks_tensor, masks_dilated_tensor) 155 | 156 | 157 | class ProPainterOutpaint: 158 | """ComfyUI Node for performing outpainting on video frames using ProPainter.""" 159 | 160 | def __init__(self): 161 | pass 162 | 163 | @classmethod 164 | def INPUT_TYPES(s): 165 | return { 166 | "required": { 167 | "image": ("IMAGE",), # --video 168 | "width": ("INT", {"default": 640, "min": 0, "max": 2560}), # --width 169 | "height": ("INT", {"default": 360, "min": 0, "max": 2560}), # --height 170 | "width_scale": ( 171 | "FLOAT", 172 | { 173 | "default": 1.2, 174 | "min": 0.0, 175 | "max": 10.0, 176 | "step": 0.01, 177 | }, 178 | ), 179 | "height_scale": ( 180 | "FLOAT", 181 | { 182 | "default": 1.0, 183 | "min": 0.0, 184 | "max": 10.0, 185 | "step": 0.01, 186 | }, 187 | ), 188 | "mask_dilates": ( 189 | "INT", 190 | {"default": 5, "min": 0, "max": 100}, 191 | ), # --mask_dilates 192 | "flow_mask_dilates": ( 193 | "INT", 194 | {"default": 8, "min": 0, "max": 100}, 195 | ), # argument does not exist in the original code 196 | "ref_stride": ( 197 | "INT", 198 | {"default": 10, "min": 1, "max": 100}, 199 | ), # --ref_stride 200 | "neighbor_length": ( 201 | "INT", 202 | {"default": 10, "min": 2, "max": 300}, 203 | ), # --neighbor_length 204 | "subvideo_length": ( 205 | "INT", 206 | {"default": 80, "min": 1, "max": 300}, 207 | ), # --subvideo_length 208 | "raft_iter": ( 209 | "INT", 210 | {"default": 20, "min": 1, "max": 100}, 211 | ), # --raft_iter 212 | "fp16": (["enable", "disable"],), # --fp16 213 | }, 214 | } 215 | 216 | RETURN_TYPES = ( 217 | "IMAGE", 218 | "MASK", 219 | "INT", 220 | "INT", 221 | ) 222 | RETURN_NAMES = ( 223 | "IMAGE", 224 | "OUTPAINT_MASK", 225 | "output_width", 226 | "output_height", 227 | ) 228 | FUNCTION = "propainter_outpainting" 229 | CATEGORY = "ProPainter" 230 | 231 | def propainter_outpainting( 232 | self, 233 | image: torch.Tensor, 234 | width: int, 235 | height: int, 236 | width_scale: float, 237 | height_scale: float, 238 | mask_dilates: int, 239 | flow_mask_dilates: int, 240 | ref_stride: int, 241 | neighbor_length: int, 242 | subvideo_length: int, 243 | raft_iter: int, 244 | fp16: str, 245 | ) -> tuple[torch.Tensor, torch.Tensor, int, int]: 246 | """Perform outpainting on the input images using ProPainter model inference.""" 247 | device = model_management.get_torch_device() 248
| # TODO: Check if this conversion from Torch to PIL is really necessary. 249 | frames = convert_image_to_frames(image) 250 | video_length = image.size(dim=0) 251 | input_size = frames[0].size 252 | 253 | image_config = ImageOutpaintConfig( 254 | width, 255 | height, 256 | mask_dilates, 257 | flow_mask_dilates, 258 | input_size, 259 | video_length, 260 | width_scale, 261 | height_scale, 262 | ) 263 | 264 | outpaint_config = ProPainterConfig( 265 | ref_stride, 266 | neighbor_length, 267 | subvideo_length, 268 | raft_iter, 269 | fp16, 270 | video_length, 271 | device, 272 | image_config.outpaint_size, 273 | ) 274 | 275 | paded_frames, paded_flow_masks, paded_masks_dilated = extrapolation( 276 | frames, image_config 277 | ) 278 | 279 | frames_tensor, flow_masks_tensor, masks_dilated_tensor, original_frames = ( 280 | prepare_frames_and_masks_for_outpaint( 281 | paded_frames, paded_flow_masks, paded_masks_dilated, device 282 | ) 283 | ) 284 | 285 | models = initialize_models(device, outpaint_config.fp16) 286 | print(f"\nProcessing {outpaint_config.video_length} frames...") 287 | 288 | updated_frames, updated_masks, pred_flows_bi = process_inpainting( 289 | models, 290 | frames_tensor, 291 | flow_masks_tensor, 292 | masks_dilated_tensor, 293 | outpaint_config, 294 | ) 295 | 296 | composed_frames = feature_propagation( 297 | models.inpaint_model, 298 | updated_frames, 299 | updated_masks, 300 | masks_dilated_tensor, 301 | pred_flows_bi, 302 | original_frames, 303 | outpaint_config, 304 | ) 305 | 306 | output_frames, output_masks, _ = handle_output( 307 | composed_frames, flow_masks_tensor, masks_dilated_tensor 308 | ) 309 | output_width, output_height = outpaint_config.process_size 310 | return output_frames, output_masks, output_width, output_height 311 | 312 | 313 | NODE_CLASS_MAPPINGS = { 314 | "ProPainterInpaint": ProPainterInpaint, 315 | "ProPainterOutpaint": ProPainterOutpaint, 316 | } 317 | 318 | NODE_DISPLAY_NAME_MAPPINGS = { 319 | "ProPainterInpaint": "ProPainter Inpainting", 320 | "ProPainterOutpaint": "ProPainter Outpainting", 321 | } 322 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_propainter_nodes" 3 | description = "ComfyUI custom node implementation of the [a/ProPainter](https://github.com/sczhou/ProPainter) framework for video inpainting."
4 | version = "1.0.0" 5 | license = "LICENSE" 6 | dependencies = ["opencv-python"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/daniabib/ComfyUI_ProPainter_Nodes" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "daniabib" 14 | DisplayName = "ComfyUI_ProPainter_Nodes" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daniabib/ComfyUI_ProPainter_Nodes/9c27d5a0a508bae3296a1886ad026d8d4139d66c/utils/__init__.py -------------------------------------------------------------------------------- /utils/download_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from urllib.parse import urljoin 3 | 4 | from torch.hub import download_url_to_file 5 | 6 | 7 | def load_file_from_url( 8 | url: str, 9 | model_dir: Path | None = None, 10 | progress: bool = True, 11 | file_name: str | None = None, 12 | ) -> Path: 13 | """Load a file from an HTTP URL, downloading it if it is not already cached.""" 14 | file_name = Path(file_name) 15 | model_dir.mkdir(exist_ok=True) 16 | cached_file = model_dir / file_name 17 | if not cached_file.exists(): 18 | print(f'Downloading: "{url}" to {cached_file}\n') 19 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 20 | return cached_file 21 | 22 | 23 | def download_model(model_url: str, model_name: str) -> Path: 24 | """Downloads a model from a URL and returns the local path to the downloaded model.""" 25 | base_dir = Path(__file__).parents[1].resolve() 26 | target_dir = base_dir / "weights" 27 | return load_file_from_url( 28 | url=urljoin(model_url, model_name), 29 | model_dir=target_dir, 30 | progress=True, 31 | file_name=model_name, 32 | ) 33 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from numpy.typing import NDArray 7 | from PIL import Image 8 | from torchvision import transforms 9 | from torchvision.transforms.functional import to_pil_image 10 | 11 | 12 | @dataclass 13 | class ImageConfig: 14 | width: int 15 | height: int 16 | mask_dilates: int 17 | flow_mask_dilates: int 18 | input_size: tuple[int, int] 19 | video_length: int 20 | process_size: tuple[int, int] = field(init=False) 21 | 22 | def __post_init__(self) -> None: 23 | """Initialize process size.""" 24 | self.process_size = ( 25 | self.width - self.width % 8, 26 | self.height - self.height % 8, 27 | ) 28 | 29 | 30 | @dataclass 31 | class ImageOutpaintConfig(ImageConfig): 32 | width_scale: float 33 | height_scale: float 34 | process_size: tuple[int, int] = field(init=False) 35 | outpaint_size: tuple[int, int] = field(init=False) 36 | 37 | # TODO: Refactor 38 | def __post_init__(self) -> None: 39 | """Initialize output size for outpainting.""" 40 | self.process_size = ( 41 | self.width - self.width % 8, 42 | self.height - self.height % 8, 43 | ) 44 | pad_image_width = int(self.width_scale * self.width) 45 | pad_image_height =
int(self.height_scale * self.height) 46 | self.outpaint_size = ( 47 | pad_image_width - pad_image_width % 8, 48 | pad_image_height - pad_image_height % 8, 49 | ) 50 | 51 | 52 | class Stack: 53 | """Stack images based on number of channels.""" 54 | 55 | def __init__(self, roll=False): 56 | self.roll = roll 57 | 58 | def __call__(self, img_group) -> NDArray: 59 | mode = img_group[0].mode 60 | if mode == "1": 61 | img_group = [img.convert("L") for img in img_group] 62 | mode = "L" 63 | if mode == "L": 64 | return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) 65 | if mode == "RGB": 66 | if self.roll: 67 | return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 68 | return np.stack(img_group, axis=2) 69 | raise NotImplementedError(f"Image mode {mode}") 70 | 71 | 72 | class ToTorchFormatTensor: 73 | """Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch FloatTensor of shape (C x H x W) in the range [0.0, 1.0].""" 74 | 75 | # TODO: Check if this function is necessary with the ComfyUI workflow. 76 | def __init__(self, div=True): 77 | self.div = div 78 | 79 | def __call__(self, pic) -> torch.Tensor: 80 | if isinstance(pic, np.ndarray): 81 | # numpy img: [L, C, H, W] 82 | img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() 83 | else: 84 | # handle PIL Image 85 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 86 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 87 | # put it from HWC to CHW format 88 | # yikes, this transpose takes 80% of the loading time/CPU 89 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 90 | img = img.float().div(255) if self.div else img.float() 91 | return img 92 | 93 | 94 | def to_tensors(): 95 | return transforms.Compose([Stack(), ToTorchFormatTensor()]) 96 | 97 | 98 | def resize_images(images: list[Image.Image], config: ImageConfig) -> list[Image.Image]: 99 | """Resizes each image in the list to a new size divisible by 8.""" 100 | if config.process_size != config.input_size: 101 | images = [f.resize(config.process_size) for f in images] 102 | 103 | return images 104 | 105 | 106 | def convert_image_to_frames(images: torch.Tensor) -> list[Image.Image]: 107 | """Convert a batch of PyTorch tensors into a list of PIL Image frames.""" 108 | frames = [] 109 | for image in images: 110 | torch_frame = image.detach().cpu() 111 | np_frame = torch_frame.numpy() 112 | np_frame = (np_frame * 255).clip(0, 255).astype(np.uint8) 113 | frame = Image.fromarray(np_frame) 114 | frames.append(frame) 115 | 116 | return frames 117 | 118 | 119 | def binary_mask(mask: np.ndarray, th: float = 0.1) -> np.ndarray: 120 | mask[mask > th] = 1 121 | mask[mask <= th] = 0 122 | 123 | return mask 124 | 125 | 126 | def convert_mask_to_frames(images: torch.Tensor) -> list[Image.Image]: 127 | """Convert a batch of PyTorch tensors into a list of PIL Image frames.""" 128 | frames = [] 129 | for image in images: 130 | image = image.detach().cpu() 131 | 132 | # Adjust the scaling based on the data type 133 | if image.dtype == torch.float32: 134 | image = (image * 255).clamp(0, 255).byte() 135 | 136 | frame: Image.Image = to_pil_image(image) 137 | frames.append(frame) 138 | 139 | return frames 140 | 141 | 142 | def read_masks( 143 | masks: torch.Tensor, config: ImageConfig 144 | ) -> tuple[list[Image.Image], list[Image.Image]]: 145 | """Build the flow masks and dilated inpainting masks from the input mask tensor.""" 146 | mask_images = convert_mask_to_frames(masks) 147 | mask_images = resize_images(mask_images, config) 148 | masks_dilated: list[Image.Image] = [] 149 |
flow_masks: list[Image.Image] = [] 150 | 151 | for mask_image in mask_images: 152 | mask_array = np.array(mask_image.convert("L")) 153 | 154 | # Dilate 8 pixels so that all known pixels are trustworthy 155 | if config.flow_mask_dilates > 0: 156 | flow_mask_img = scipy.ndimage.binary_dilation( 157 | mask_array, iterations=config.flow_mask_dilates 158 | ).astype(np.uint8) 159 | else: 160 | flow_mask_img = binary_mask(mask_array).astype(np.uint8) 161 | flow_masks.append(Image.fromarray(flow_mask_img * 255)) 162 | 163 | if config.mask_dilates > 0: 164 | mask_array = scipy.ndimage.binary_dilation( 165 | mask_array, iterations=config.mask_dilates 166 | ).astype(np.uint8) 167 | else: 168 | mask_array = binary_mask(mask_array).astype(np.uint8) 169 | masks_dilated.append(Image.fromarray(mask_array * 255)) 170 | 171 | if len(mask_images) == 1: 172 | flow_masks = flow_masks * config.video_length 173 | masks_dilated = masks_dilated * config.video_length 174 | 175 | return flow_masks, masks_dilated 176 | 177 | 178 | def prepare_frames_and_masks( 179 | frames: list[Image.Image], 180 | mask: torch.Tensor, 181 | config: ImageConfig, 182 | device: torch.device, 183 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[NDArray]]: 184 | frames = resize_images(frames, config) 185 | 186 | flow_masks, masks_dilated = read_masks(mask, config) 187 | 188 | original_frames = [np.array(f) for f in frames] 189 | frames_tensor: torch.Tensor = to_tensors()(frames).unsqueeze(0) * 2 - 1 190 | flow_masks_tensor: torch.Tensor = to_tensors()(flow_masks).unsqueeze(0) 191 | masks_dilated_tensor: torch.Tensor = to_tensors()(masks_dilated).unsqueeze(0) 192 | frames_tensor, flow_masks_tensor, masks_dilated_tensor = ( 193 | frames_tensor.to(device), 194 | flow_masks_tensor.to(device), 195 | masks_dilated_tensor.to(device), 196 | ) 197 | return frames_tensor, flow_masks_tensor, masks_dilated_tensor, original_frames 198 | 199 | 200 | def extrapolation( 201 | resized_frames: list[Image.Image], image_config: ImageOutpaintConfig 202 | ) -> tuple[list[Image.Image], list[Image.Image], list[Image.Image]]: 203 | """Prepares the data for video outpainting.""" 204 | resized_frames = resize_images(resized_frames, image_config) 205 | 206 | # input_width, input_height = image_config.input_size 207 | resized_width, resized_height = resized_frames[0].size 208 | pad_image_width, pad_image_height = image_config.outpaint_size 209 | 210 | # Defines new FOV. 211 | width_start = int((pad_image_width - resized_width) / 2) 212 | height_start = int((pad_image_height - resized_height) / 2) 213 | 214 | # Extrapolates the FOV for video. 215 | extrapolated_frames = [] 216 | for v in resized_frames: 217 | frame = np.zeros(((pad_image_height, pad_image_width, 3)), dtype=np.uint8) 218 | frame[ 219 | height_start : height_start + resized_height, 220 | width_start : width_start + resized_width, 221 | :, 222 | ] = v 223 | extrapolated_frames.append(Image.fromarray(frame)) 224 | 225 | # Generates the masks for the missing regions.
226 | masks_dilated = [] 227 | flow_masks = [] 228 | 229 | dilate_h = 4 if height_start > 10 else 0 230 | dilate_w = 4 if width_start > 10 else 0 231 | mask = np.ones(((pad_image_height, pad_image_width)), dtype=np.uint8) 232 | 233 | mask[ 234 | height_start + dilate_h : height_start + resized_height - dilate_h, 235 | width_start + dilate_w : width_start + resized_width - dilate_w, 236 | ] = 0 237 | flow_masks.append(Image.fromarray(mask * 255)) 238 | 239 | mask[ 240 | height_start : height_start + resized_height, 241 | width_start : width_start + resized_width, 242 | ] = 0 243 | masks_dilated.append(Image.fromarray(mask * 255)) 244 | 245 | flow_masks = flow_masks * image_config.video_length 246 | masks_dilated = masks_dilated * image_config.video_length 247 | 248 | return ( 249 | extrapolated_frames, 250 | flow_masks, 251 | masks_dilated, 252 | ) 253 | 254 | 255 | def prepare_frames_and_masks_for_outpaint( 256 | frames: list[Image.Image], 257 | flow_masks: list[Image.Image], 258 | masks_dilated: list[Image.Image], 259 | device: torch.device, 260 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[NDArray]]: 261 | # flow_masks, masks_dilated = read_masks(mask, config) 262 | 263 | # original_frames = [np.array(f).astype(np.uint8) for f in frames] 264 | original_frames = [np.array(f) for f in frames] 265 | frames_tensor: torch.Tensor = to_tensors()(frames).unsqueeze(0) * 2 - 1 266 | flow_masks_tensor: torch.Tensor = to_tensors()(flow_masks).unsqueeze(0) 267 | masks_dilated_tensor: torch.Tensor = to_tensors()(masks_dilated).unsqueeze(0) 268 | frames_tensor, flow_masks_tensor, masks_dilated_tensor = ( 269 | frames_tensor.to(device), 270 | flow_masks_tensor.to(device), 271 | masks_dilated_tensor.to(device), 272 | ) 273 | return frames_tensor, flow_masks_tensor, masks_dilated_tensor, original_frames 274 | 275 | 276 | def handle_output( 277 | composed_frames: list[NDArray], 278 | flow_masks: torch.Tensor, 279 | masks_dilated: torch.Tensor, 280 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 281 | output_frames = [ 282 | torch.from_numpy(frame.astype(np.float32) / 255.0) for frame in composed_frames 283 | ] 284 | 285 | output_images = torch.stack(output_frames) 286 | 287 | output_flow_masks = flow_masks.squeeze() 288 | output_masks_dilated = masks_dilated.squeeze() 289 | 290 | return output_images, output_flow_masks, output_masks_dilated 291 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | 5 | from ..model.modules.flow_comp_raft import RAFT_bi 6 | from ..model.propainter import InpaintGenerator 7 | from ..model.recurrent_flow_completion import ( 8 | RecurrentFlowCompleteNet, 9 | ) 10 | from ..utils.download_utils import download_model 11 | 12 | 13 | @dataclass 14 | class Models: 15 | raft_model: RAFT_bi 16 | flow_model: RecurrentFlowCompleteNet 17 | inpaint_model: InpaintGenerator 18 | 19 | 20 | PRETRAIN_MODEL_URL = "https://github.com/sczhou/ProPainter/releases/download/v0.1.0/" 21 | 22 | 23 | def load_raft_model(device: torch.device) -> RAFT_bi: 24 | """Loads the RAFT bi-directional model.""" 25 | model_path = download_model(PRETRAIN_MODEL_URL, "raft-things.pth") 26 | raft_model = RAFT_bi(model_path, device) 27 | return raft_model 28 | 29 | 30 | def load_recurrent_flow_model(device: torch.device) -> RecurrentFlowCompleteNet: 31 | """Loads the Recurrent Flow Completion Network 
model.""" 32 | model_path = download_model(PRETRAIN_MODEL_URL, "recurrent_flow_completion.pth") 33 | flow_model = RecurrentFlowCompleteNet(model_path) 34 | for p in flow_model.parameters(): 35 | p.requires_grad = False 36 | flow_model.to(device) 37 | flow_model.eval() 38 | return flow_model 39 | 40 | 41 | def load_inpaint_model(device: torch.device) -> InpaintGenerator: 42 | """Loads the Inpaint Generator model.""" 43 | model_path = download_model(PRETRAIN_MODEL_URL, "ProPainter.pth") 44 | inpaint_model = InpaintGenerator(model_path=model_path).to(device) 45 | inpaint_model.eval() 46 | return inpaint_model 47 | 48 | 49 | def initialize_models(device: torch.device, use_half: str) -> Models: 50 | """Return initialized inference models.""" 51 | raft_model = load_raft_model(device) 52 | flow_model = load_recurrent_flow_model(device) 53 | inpaint_model = load_inpaint_model(device) 54 | 55 | if use_half == "enable": 56 | # raft_model = raft_model.half() 57 | flow_model = flow_model.half() 58 | inpaint_model = inpaint_model.half() 59 | return Models(raft_model, flow_model, inpaint_model) 60 | --------------------------------------------------------------------------------