├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── LICENSE.txt
├── README.md
├── __init__.py
├── examples
│   └── propainter-inpainting-workflow.json
├── model
│   ├── __init__.py
│   ├── canny
│   │   ├── canny_filter.py
│   │   ├── filter.py
│   │   ├── gaussian.py
│   │   ├── kernels.py
│   │   └── sobel.py
│   ├── misc.py
│   ├── modules
│   │   ├── RAFT
│   │   │   ├── __init__.py
│   │   │   ├── corr.py
│   │   │   ├── datasets.py
│   │   │   ├── demo.py
│   │   │   ├── extractor.py
│   │   │   ├── raft.py
│   │   │   ├── update.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── augmentor.py
│   │   │       ├── flow_viz.py
│   │   │       ├── flow_viz_pt.py
│   │   │       ├── frame_utils.py
│   │   │       └── utils.py
│   │   ├── base_module.py
│   │   ├── deformconv.py
│   │   ├── flow_comp_raft.py
│   │   ├── flow_loss_utils.py
│   │   ├── sparse_transformer.py
│   │   └── spectral_norm.py
│   ├── propainter.py
│   ├── recurrent_flow_completion.py
│   └── vgg_arch.py
├── propainter_inference.py
├── propainter_nodes.py
├── pyproject.toml
├── requirements.txt
└── utils
    ├── __init__.py
    ├── download_utils.py
    ├── image_utils.py
    └── model_utils.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | paths:
9 | - "pyproject.toml"
10 |
11 | jobs:
12 | publish-node:
13 | name: Publish Custom Node to registry
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Check out code
17 | uses: actions/checkout@v4
18 | - name: Publish Custom Node
19 | uses: Comfy-Org/publish-node-action@main
20 | with:
21 | ## Add your own personal access token to your Github Repository secrets and reference it here.
22 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | results/
2 | weights/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 | mypy.ini
149 |
150 | # ruff
151 | .ruff_cache/
152 | ruff.toml
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProPainter Nodes for ComfyUI
2 |
3 | [ComfyUI](https://github.com/comfyanonymous/ComfyUI) implementation of [ProPainter](https://github.com/sczhou/ProPainter) for video inpainting. ProPainter is a framework that utilizes flow-based propagation and a spatiotemporal transformer to enable advanced video frame editing for seamless inpainting tasks.
4 |
5 | ## Features
6 |
7 | #### 👨🏻‍🎨 Object Removal
8 |
9 | *(demo videos)*
10 |
19 | #### 🎨 Video Completion
20 |
21 | *(demo videos)*
22 |
31 | ## Installation
32 | ### ComfyUI Manager:
33 | You can use [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) to install the nodes:
34 | 1. Search for `ComfyUI ProPainter Nodes` and author `daniabib`.
35 |
36 | ### Manual Installation:
37 | 1. Clone this repository to `ComfyUI/custom_nodes`:
38 | ```bash
39 | git clone https://github.com/daniabib/ComfyUI_ProPainter_Nodes
40 | ```
41 |
42 | 2. Install the required dependencies:
43 | ```bash
44 | pip install -r requirements.txt
45 | ```
46 |
47 | Models will be automatically downloaded to the `weights` folder.
48 |
49 | ## Examples
50 | **Basic Inpainting Workflow**
51 |
52 | https://github.com/daniabib/ComfyUI_ProPainter_Nodes/assets/33937060/56244d09-fe89-4af2-916b-e8d903752f0d
53 |
54 | https://github.com/daniabib/ComfyUI_ProPainter_Nodes/blob/main/examples/propainter-inpainting-workflow.json
55 |
56 | ## Other suggested nodes
57 | * [VideoHelperSuite](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite) for loading and saving the video frames.
58 | * [ComfyUI-YoloWorld-EfficientSAM](https://github.com/ZHO-ZHO-ZHO/ComfyUI-YoloWorld-EfficientSAM) for masking images using text prompts.
59 |
60 | ## Nodes Reference
61 | **🚧 Section under construction**
62 | ### ProPainter Inpainting
63 |
64 | #### Input Parameters:
65 | - `image`: The video frames to be inpainted.
66 | - `mask`: The mask indicating the regions to be inpainted. The mask must have the same size as the video frames.
67 | - `width`: Width of the output images (default: 640).
68 | - `height`: Height of the output images (default: 360).
69 | - `mask_dilates`: Dilation size for the mask (default: 5).
70 | - `flow_mask_dilates`: Dilation size for the flow mask (default: 8).
71 | - `ref_stride`: Stride for reference frames (default: 10).
72 | - `neighbor_length`: Length of the neighborhood for inpainting (default: 10).
73 | - `subvideo_length`: Length of subvideos for processing (default: 80).
74 | - `raft_iter`: Number of iterations for RAFT model (default: 20).
75 | - `fp16`: Enable or disable FP16 precision (default: "enable").
76 |
77 | #### Output:
78 | - `IMAGE`: The inpainted video frames.
79 | - `FLOW_MASK`: The flow mask used during inpainting.
80 | - `MASK_DILATE`: The dilated mask used during inpainting.
81 |
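For intuition, the sketch below shows roughly how `neighbor_length` and `ref_stride` interact at inference time: each frame is restored from a short window of neighboring frames plus a sparse set of reference frames sampled across the clip. This is a simplified illustration, not the node's exact code (see `propainter_inference.py` for the real implementation).

```python
# Simplified sketch of frame selection for inpainting frame `f` of a clip with `video_length`
# frames. Function and variable names are illustrative, not taken from the node's source.
def select_frames(f: int, video_length: int, neighbor_length: int = 10, ref_stride: int = 10):
    # local temporal window centered on the current frame
    neighbor_ids = list(range(max(0, f - neighbor_length // 2),
                              min(video_length, f + neighbor_length // 2 + 1)))
    # sparse "global" reference frames sampled every `ref_stride` frames
    ref_ids = [i for i in range(0, video_length, ref_stride) if i not in neighbor_ids]
    return neighbor_ids, ref_ids

# select_frames(40, 100) -> neighbors 35..45, references 0, 10, 20, 30, 50, 60, 70, 80, 90
```

Roughly speaking, a larger `neighbor_length` gives more temporal context at the cost of memory, while a larger `ref_stride` means fewer, more widely spaced reference frames.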
82 | ### ProPainter Outpainting
83 | **Note**: The authors of the paper didn't mention the outpainting task for their framework, but there is an option for it in the original code. The results aren't very good but I decided to implement a node for it anyway.
84 |
85 | #### Input Parameters:
86 | - `image`: The video frames to be outpainted.
87 | - `width`: Width of the video frames (default: 640).
88 | - `height`: Height of the video frames (default: 360).
89 | - `width_scale`: Scale factor for width expansion (default: 1.2).
90 | - `height_scale`: Scale factor for height expansion (default: 1.0).
91 | - `mask_dilates`: Dilation size for the mask (default: 5).
92 | - `flow_mask_dilates`: Dilation size for the flow mask (default: 8).
93 | - `ref_stride`: Stride for reference frames (default: 10).
94 | - `neighbor_length`: Length of the neighborhood for outpainting (default: 10).
95 | - `subvideo_length`: Length of subvideos for processing (default: 80).
96 | - `raft_iter`: Number of iterations for RAFT model (default: 20).
97 | - `fp16`: Enable or disable FP16 precision (default: "disable").
98 |
99 | #### Output:
100 | - `IMAGE`: The outpainted video frames.
101 | - `OUTPAINT_MASK`: The mask used during outpainting.
102 | - `output_width`: The width of the outpainted frames.
103 | - `output_height`: The height of the outpainted frames.
104 |
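As a rough mental model (again a sketch under assumptions, not the node's actual code), outpainting pads each frame onto the scaled canvas and builds a mask that covers the newly added border; `output_width` and `output_height` correspond to that scaled canvas:

```python
import torch
import torch.nn.functional as F

# Sketch only: pad frames to the scaled canvas and mark the new border as the region to fill.
def prepare_outpaint(frames: torch.Tensor, width_scale: float = 1.2, height_scale: float = 1.0):
    """frames: (T, C, H, W) float tensor. Returns padded frames and the outpaint mask."""
    t, c, h, w = frames.shape
    out_h, out_w = int(h * height_scale), int(w * width_scale)
    pad_top, pad_left = (out_h - h) // 2, (out_w - w) // 2
    pad = (pad_left, out_w - w - pad_left, pad_top, out_h - h - pad_top)  # (left, right, top, bottom)
    padded = F.pad(frames, pad)  # original content centered on the larger canvas
    mask = torch.ones(t, 1, out_h, out_w)
    mask[:, :, pad_top:pad_top + h, pad_left:pad_left + w] = 0  # 1 = border to synthesize
    return padded, mask
```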
105 | ## License
106 | The ProPainter models and code are licensed under [NTU S-Lab License 1.0](https://github.com/sczhou/ProPainter/blob/main/LICENSE).
107 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .propainter_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2 |
3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
4 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/canny/canny_filter.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 | from .gaussian import gaussian_blur2d
9 | from .kernels import get_canny_nms_kernel, get_hysteresis_kernel
10 | from .sobel import spatial_gradient
11 |
12 |
13 | def rgb_to_grayscale(image, rgb_weights=None):
14 | if len(image.shape) < 3 or image.shape[-3] != 3:
15 | raise ValueError(
16 | f"Input size must have a shape of (*, 3, H, W). Got {image.shape}"
17 | )
18 |
19 | if rgb_weights is None:
20 | # 8 bit images
21 | if image.dtype == torch.uint8:
22 | rgb_weights = torch.tensor(
23 | [76, 150, 29], device=image.device, dtype=torch.uint8
24 | )
25 | # floating point images
26 | elif image.dtype in (torch.float16, torch.float32, torch.float64):
27 | rgb_weights = torch.tensor(
28 | [0.299, 0.587, 0.114], device=image.device, dtype=image.dtype
29 | )
30 | else:
31 | raise TypeError(f"Unknown data type: {image.dtype}")
32 | else:
33 | # is tensor that we make sure is in the same device/dtype
34 | rgb_weights = rgb_weights.to(image)
35 |
36 | # unpack the color image channels with RGB order
37 | r = image[..., 0:1, :, :]
38 | g = image[..., 1:2, :, :]
39 | b = image[..., 2:3, :, :]
40 |
41 | w_r, w_g, w_b = rgb_weights.unbind()
42 | return w_r * r + w_g * g + w_b * b
43 |
44 |
45 | def canny(
46 | input: torch.Tensor,
47 | low_threshold: float = 0.1,
48 | high_threshold: float = 0.2,
49 | kernel_size: Tuple[int, int] = (5, 5),
50 | sigma: Tuple[float, float] = (1, 1),
51 | hysteresis: bool = True,
52 | eps: float = 1e-6,
53 | ) -> Tuple[torch.Tensor, torch.Tensor]:
54 | r"""Find edges of the input image and filters them using the Canny algorithm.
55 |
56 | .. image:: _static/img/canny.png
57 |
58 | Args:
59 | input: input image tensor with shape :math:`(B,C,H,W)`.
60 | low_threshold: lower threshold for the hysteresis procedure.
61 | high_threshold: upper threshold for the hysteresis procedure.
62 | kernel_size: the size of the kernel for the gaussian blur.
63 | sigma: the standard deviation of the kernel for the gaussian blur.
64 | hysteresis: if True, applies the hysteresis edge tracking.
65 | Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
66 | eps: regularization number to avoid NaN during backprop.
67 |
68 | Returns:
69 | - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
70 | - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
71 |
72 | .. note::
73 | See a working example `here `__.
75 |
76 | Example:
77 | >>> input = torch.rand(5, 3, 4, 4)
78 | >>> magnitude, edges = canny(input) # 5x3x4x4
79 | >>> magnitude.shape
80 | torch.Size([5, 1, 4, 4])
81 | >>> edges.shape
82 | torch.Size([5, 1, 4, 4])
83 | """
84 | if not isinstance(input, torch.Tensor):
85 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
86 |
87 | if not len(input.shape) == 4:
88 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
89 |
90 | if low_threshold > high_threshold:
91 | raise ValueError(
92 | f"Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {low_threshold}>{high_threshold}"
93 | )
94 |
95 | if low_threshold < 0 or low_threshold > 1:
96 | raise ValueError(
97 | f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}"
98 | )
99 |
100 | if high_threshold < 0 or high_threshold > 1:
101 | raise ValueError(
102 | f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}"
103 | )
104 |
105 | device: torch.device = input.device
106 | dtype: torch.dtype = input.dtype
107 |
108 | # To Grayscale
109 | if input.shape[1] == 3:
110 | input = rgb_to_grayscale(input)
111 |
112 | # Gaussian filter
113 | blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma)
114 |
115 | # Compute the gradients
116 | gradients: torch.Tensor = spatial_gradient(blurred, normalized=False)
117 |
118 | # Unpack the edges
119 | gx: torch.Tensor = gradients[:, :, 0]
120 | gy: torch.Tensor = gradients[:, :, 1]
121 |
122 | # Compute gradient magnitude and angle
123 | magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
124 | angle: torch.Tensor = torch.atan2(gy, gx)
125 |
126 | # Radians to Degrees
127 | angle = 180.0 * angle / math.pi
128 |
129 | # Round angle to the nearest 45 degree
130 | angle = torch.round(angle / 45) * 45
131 |
132 | # Non-maximal suppression
133 | nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype)
134 | nms_magnitude: torch.Tensor = F.conv2d(
135 | magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2
136 | )
137 |
138 | # Get the indices for both directions
139 | positive_idx: torch.Tensor = (angle / 45) % 8
140 | positive_idx = positive_idx.long()
141 |
142 | negative_idx: torch.Tensor = ((angle / 45) + 4) % 8
143 | negative_idx = negative_idx.long()
144 |
145 | # Apply the non-maximum suppression to the different directions
146 | channel_select_filtered_positive: torch.Tensor = torch.gather(
147 | nms_magnitude, 1, positive_idx
148 | )
149 | channel_select_filtered_negative: torch.Tensor = torch.gather(
150 | nms_magnitude, 1, negative_idx
151 | )
152 |
153 | channel_select_filtered: torch.Tensor = torch.stack(
154 | [channel_select_filtered_positive, channel_select_filtered_negative], 1
155 | )
156 |
157 | is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0
158 |
159 | magnitude = magnitude * is_max
160 |
161 | # Threshold
162 | edges: torch.Tensor = F.threshold(magnitude, low_threshold, 0.0)
163 |
164 | low: torch.Tensor = magnitude > low_threshold
165 | high: torch.Tensor = magnitude > high_threshold
166 |
167 | edges = low * 0.5 + high * 0.5
168 | edges = edges.to(dtype)
169 |
170 | # Hysteresis
171 | if hysteresis:
172 | edges_old: torch.Tensor = -torch.ones(
173 | edges.shape, device=edges.device, dtype=dtype
174 | )
175 | hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype)
176 |
177 | while ((edges_old - edges).abs() != 0).any():
178 | weak: torch.Tensor = (edges == 0.5).float()
179 | strong: torch.Tensor = (edges == 1).float()
180 |
181 | hysteresis_magnitude: torch.Tensor = F.conv2d(
182 | edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
183 | )
184 | hysteresis_magnitude = (
185 | (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
186 | )
187 | hysteresis_magnitude = hysteresis_magnitude * weak + strong
188 |
189 | edges_old = edges.clone()
190 | edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5
191 |
192 | edges = hysteresis_magnitude
193 |
194 | return magnitude, edges
195 |
196 |
197 | class Canny(nn.Module):
198 | r"""Module that finds edges of the input image and filters them using the Canny algorithm.
199 |
200 | Args:
201 | input: input image tensor with shape :math:`(B,C,H,W)`.
202 | low_threshold: lower threshold for the hysteresis procedure.
203 | high_threshold: upper threshold for the hysteresis procedure.
204 | kernel_size: the size of the kernel for the gaussian blur.
205 | sigma: the standard deviation of the kernel for the gaussian blur.
206 | hysteresis: if True, applies the hysteresis edge tracking.
207 | Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
208 | eps: regularization number to avoid NaN during backprop.
209 |
210 | Returns:
211 | - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
212 | - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
213 |
214 | Example:
215 | >>> input = torch.rand(5, 3, 4, 4)
216 | >>> magnitude, edges = Canny()(input) # 5x3x4x4
217 | >>> magnitude.shape
218 | torch.Size([5, 1, 4, 4])
219 | >>> edges.shape
220 | torch.Size([5, 1, 4, 4])
221 | """
222 |
223 | def __init__(
224 | self,
225 | low_threshold: float = 0.1,
226 | high_threshold: float = 0.2,
227 | kernel_size: Tuple[int, int] = (5, 5),
228 | sigma: Tuple[float, float] = (1, 1),
229 | hysteresis: bool = True,
230 | eps: float = 1e-6,
231 | ) -> None:
232 | super().__init__()
233 |
234 | if low_threshold > high_threshold:
235 | raise ValueError(
236 | f"Invalid input thresholds. low_threshold should be\
237 | smaller than the high_threshold. Got: {low_threshold}>{high_threshold}"
238 | )
239 |
240 | if low_threshold < 0 or low_threshold > 1:
241 | raise ValueError(
242 | f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}"
243 | )
244 |
245 | if high_threshold < 0 or high_threshold > 1:
246 | raise ValueError(
247 | f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}"
248 | )
249 |
250 | # Gaussian blur parameters
251 | self.kernel_size = kernel_size
252 | self.sigma = sigma
253 |
254 | # Double threshold
255 | self.low_threshold = low_threshold
256 | self.high_threshold = high_threshold
257 |
258 | # Hysteresis
259 | self.hysteresis = hysteresis
260 |
261 | self.eps: float = eps
262 |
263 | def __repr__(self) -> str:
264 | return "".join(
265 | (
266 | f"{type(self).__name__}(",
267 | ", ".join(
268 | f"{name}={getattr(self, name)}"
269 | for name in sorted(self.__dict__)
270 | if not name.startswith("_")
271 | ),
272 | ")",
273 | )
274 | )
275 |
276 | def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
277 | return canny(
278 | input,
279 | self.low_threshold,
280 | self.high_threshold,
281 | self.kernel_size,
282 | self.sigma,
283 | self.hysteresis,
284 | self.eps,
285 | )
286 |
--------------------------------------------------------------------------------
/model/canny/filter.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from .kernels import normalize_kernel2d
7 |
8 |
9 | def _compute_padding(kernel_size: List[int]) -> List[int]:
10 | """Compute padding tuple."""
11 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
12 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
13 | if len(kernel_size) < 2:
14 | raise AssertionError(kernel_size)
15 | computed = [k - 1 for k in kernel_size]
16 |
17 | # for even kernels we need to do asymmetric padding :(
18 | out_padding = 2 * len(kernel_size) * [0]
19 |
20 | for i in range(len(kernel_size)):
21 | computed_tmp = computed[-(i + 1)]
22 |
23 | pad_front = computed_tmp // 2
24 | pad_rear = computed_tmp - pad_front
25 |
26 | out_padding[2 * i + 0] = pad_front
27 | out_padding[2 * i + 1] = pad_rear
28 |
29 | return out_padding
30 |
31 |
32 | def filter2d(
33 | input: torch.Tensor,
34 | kernel: torch.Tensor,
35 | border_type: str = "reflect",
36 | normalized: bool = False,
37 | padding: str = "same",
38 | ) -> torch.Tensor:
39 | r"""Convolve a tensor with a 2d kernel.
40 |
41 | The function applies a given kernel to a tensor. The kernel is applied
42 | independently at each depth channel of the tensor. Before applying the
43 | kernel, the function applies padding according to the specified mode so
44 | that the output remains in the same shape.
45 |
46 | Args:
47 | input: the input tensor with shape of
48 | :math:`(B, C, H, W)`.
49 | kernel: the kernel to be convolved with the input
50 | tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`.
51 | border_type: the padding mode to be applied before convolving.
52 | The expected modes are: ``'constant'``, ``'reflect'``,
53 | ``'replicate'`` or ``'circular'``.
54 | normalized: If True, kernel will be L1 normalized.
55 | padding: This defines the type of padding.
56 | 2 modes available ``'same'`` or ``'valid'``.
57 |
58 | Return:
59 | torch.Tensor: the convolved tensor of same size and numbers of channels
60 | as the input with shape :math:`(B, C, H, W)`.
61 |
62 | Example:
63 | >>> input = torch.tensor([[[
64 | ... [0., 0., 0., 0., 0.],
65 | ... [0., 0., 0., 0., 0.],
66 | ... [0., 0., 5., 0., 0.],
67 | ... [0., 0., 0., 0., 0.],
68 | ... [0., 0., 0., 0., 0.],]]])
69 | >>> kernel = torch.ones(1, 3, 3)
70 | >>> filter2d(input, kernel, padding='same')
71 | tensor([[[[0., 0., 0., 0., 0.],
72 | [0., 5., 5., 5., 0.],
73 | [0., 5., 5., 5., 0.],
74 | [0., 5., 5., 5., 0.],
75 | [0., 0., 0., 0., 0.]]]])
76 | """
77 | if not isinstance(input, torch.Tensor):
78 | raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")
79 |
80 | if not isinstance(kernel, torch.Tensor):
81 | raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")
82 |
83 | if not isinstance(border_type, str):
84 | raise TypeError(f"Input border_type is not string. Got {type(border_type)}")
85 |
86 | if border_type not in ["constant", "reflect", "replicate", "circular"]:
87 | raise ValueError(
88 | f"Invalid border type, we expect 'constant', \
89 | 'reflect', 'replicate', 'circular'. Got:{border_type}"
90 | )
91 |
92 | if not isinstance(padding, str):
93 | raise TypeError(f"Input padding is not string. Got {type(padding)}")
94 |
95 | if padding not in ["valid", "same"]:
96 | raise ValueError(
97 | f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}"
98 | )
99 |
100 | if not len(input.shape) == 4:
101 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
102 |
103 | if (not len(kernel.shape) == 3) and not (
104 | (kernel.shape[0] == 0) or (kernel.shape[0] == input.shape[0])
105 | ):
106 | raise ValueError(
107 | f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}"
108 | )
109 |
110 | # prepare kernel
111 | b, c, h, w = input.shape
112 | tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
113 |
114 | if normalized:
115 | tmp_kernel = normalize_kernel2d(tmp_kernel)
116 |
117 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
118 |
119 | height, width = tmp_kernel.shape[-2:]
120 |
121 | # pad the input tensor
122 | if padding == "same":
123 | padding_shape: List[int] = _compute_padding([height, width])
124 | input = F.pad(input, padding_shape, mode=border_type)
125 |
126 | # kernel and input tensor reshape to align element-wise or batch-wise params
127 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
128 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
129 |
130 | # convolve the tensor with the kernel.
131 | output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
132 |
133 | if padding == "same":
134 | out = output.view(b, c, h, w)
135 | else:
136 | out = output.view(b, c, h - height + 1, w - width + 1)
137 |
138 | return out
139 |
140 |
141 | def filter2d_separable(
142 | input: torch.Tensor,
143 | kernel_x: torch.Tensor,
144 | kernel_y: torch.Tensor,
145 | border_type: str = "reflect",
146 | normalized: bool = False,
147 | padding: str = "same",
148 | ) -> torch.Tensor:
149 | r"""Convolve a tensor with two 1d kernels, in x and y directions.
150 |
151 | The function applies a given kernel to a tensor. The kernel is applied
152 | independently at each depth channel of the tensor. Before applying the
153 | kernel, the function applies padding according to the specified mode so
154 | that the output remains in the same shape.
155 |
156 | Args:
157 | input: the input tensor with shape of
158 | :math:`(B, C, H, W)`.
159 | kernel_x: the kernel to be convolved with the input
160 | tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`.
161 | kernel_y: the kernel to be convolved with the input
162 | tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`.
163 | border_type: the padding mode to be applied before convolving.
164 | The expected modes are: ``'constant'``, ``'reflect'``,
165 | ``'replicate'`` or ``'circular'``.
166 | normalized: If True, kernel will be L1 normalized.
167 | padding: This defines the type of padding.
168 | 2 modes available ``'same'`` or ``'valid'``.
169 |
170 | Return:
171 | torch.Tensor: the convolved tensor of same size and numbers of channels
172 | as the input with shape :math:`(B, C, H, W)`.
173 |
174 | Example:
175 | >>> input = torch.tensor([[[
176 | ... [0., 0., 0., 0., 0.],
177 | ... [0., 0., 0., 0., 0.],
178 | ... [0., 0., 5., 0., 0.],
179 | ... [0., 0., 0., 0., 0.],
180 | ... [0., 0., 0., 0., 0.],]]])
181 | >>> kernel = torch.ones(1, 3)
182 |
183 | >>> filter2d_separable(input, kernel, kernel, padding='same')
184 | tensor([[[[0., 0., 0., 0., 0.],
185 | [0., 5., 5., 5., 0.],
186 | [0., 5., 5., 5., 0.],
187 | [0., 5., 5., 5., 0.],
188 | [0., 0., 0., 0., 0.]]]])
189 | """
190 | out_x = filter2d(input, kernel_x.unsqueeze(0), border_type, normalized, padding)
191 | out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding)
192 | return out
193 |
194 |
195 | def filter3d(
196 | input: torch.Tensor,
197 | kernel: torch.Tensor,
198 | border_type: str = "replicate",
199 | normalized: bool = False,
200 | ) -> torch.Tensor:
201 | r"""Convolve a tensor with a 3d kernel.
202 |
203 | The function applies a given kernel to a tensor. The kernel is applied
204 | independently at each depth channel of the tensor. Before applying the
205 | kernel, the function applies padding according to the specified mode so
206 | that the output remains in the same shape.
207 |
208 | Args:
209 | input: the input tensor with shape of
210 | :math:`(B, C, D, H, W)`.
211 | kernel: the kernel to be convolved with the input
212 | tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`.
213 | border_type: the padding mode to be applied before convolving.
214 | The expected modes are: ``'constant'``,
215 | ``'replicate'`` or ``'circular'``.
216 | normalized: If True, kernel will be L1 normalized.
217 |
218 | Return:
219 | the convolved tensor of same size and numbers of channels
220 | as the input with shape :math:`(B, C, D, H, W)`.
221 |
222 | Example:
223 | >>> input = torch.tensor([[[
224 | ... [[0., 0., 0., 0., 0.],
225 | ... [0., 0., 0., 0., 0.],
226 | ... [0., 0., 0., 0., 0.],
227 | ... [0., 0., 0., 0., 0.],
228 | ... [0., 0., 0., 0., 0.]],
229 | ... [[0., 0., 0., 0., 0.],
230 | ... [0., 0., 0., 0., 0.],
231 | ... [0., 0., 5., 0., 0.],
232 | ... [0., 0., 0., 0., 0.],
233 | ... [0., 0., 0., 0., 0.]],
234 | ... [[0., 0., 0., 0., 0.],
235 | ... [0., 0., 0., 0., 0.],
236 | ... [0., 0., 0., 0., 0.],
237 | ... [0., 0., 0., 0., 0.],
238 | ... [0., 0., 0., 0., 0.]]
239 | ... ]]])
240 | >>> kernel = torch.ones(1, 3, 3, 3)
241 | >>> filter3d(input, kernel)
242 | tensor([[[[[0., 0., 0., 0., 0.],
243 | [0., 5., 5., 5., 0.],
244 | [0., 5., 5., 5., 0.],
245 | [0., 5., 5., 5., 0.],
246 | [0., 0., 0., 0., 0.]],
247 |
248 | [[0., 0., 0., 0., 0.],
249 | [0., 5., 5., 5., 0.],
250 | [0., 5., 5., 5., 0.],
251 | [0., 5., 5., 5., 0.],
252 | [0., 0., 0., 0., 0.]],
253 |
254 | [[0., 0., 0., 0., 0.],
255 | [0., 5., 5., 5., 0.],
256 | [0., 5., 5., 5., 0.],
257 | [0., 5., 5., 5., 0.],
258 | [0., 0., 0., 0., 0.]]]]])
259 | """
260 | if not isinstance(input, torch.Tensor):
261 | raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")
262 |
263 | if not isinstance(kernel, torch.Tensor):
264 | raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")
265 |
266 | if not isinstance(border_type, str):
267 | raise TypeError(f"Input border_type is not string. Got {type(border_type)}")
268 |
269 | if not len(input.shape) == 5:
270 | raise ValueError(
271 | f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}"
272 | )
273 |
274 | if not len(kernel.shape) == 4 and kernel.shape[0] != 1:
275 | raise ValueError(
276 | f"Invalid kernel shape, we expect 1xDxHxW. Got: {kernel.shape}"
277 | )
278 |
279 | # prepare kernel
280 | b, c, d, h, w = input.shape
281 | tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
282 |
283 | if normalized:
284 | bk, dk, hk, wk = kernel.shape
285 | tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(
286 | tmp_kernel
287 | )
288 |
289 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)
290 |
291 | # pad the input tensor
292 | depth, height, width = tmp_kernel.shape[-3:]
293 | padding_shape: List[int] = _compute_padding([depth, height, width])
294 | input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)
295 |
296 | # kernel and input tensor reshape to align element-wise or batch-wise params
297 | tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
298 | input_pad = input_pad.view(
299 | -1,
300 | tmp_kernel.size(0),
301 | input_pad.size(-3),
302 | input_pad.size(-2),
303 | input_pad.size(-1),
304 | )
305 |
306 | # convolve the tensor with the kernel.
307 | output = F.conv3d(
308 | input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1
309 | )
310 |
311 | return output.view(b, c, d, h, w)
312 |
--------------------------------------------------------------------------------
/model/canny/gaussian.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from .filter import filter2d, filter2d_separable
7 | from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d
8 |
9 |
10 | def gaussian_blur2d(
11 | input: torch.Tensor,
12 | kernel_size: Tuple[int, int],
13 | sigma: Tuple[float, float],
14 | border_type: str = "reflect",
15 | separable: bool = True,
16 | ) -> torch.Tensor:
17 | r"""Create an operator that blurs a tensor using a Gaussian filter.
18 |
19 | .. image:: _static/img/gaussian_blur2d.png
20 |
21 | The operator smooths the given tensor with a gaussian kernel by convolving
22 | it to each channel. It supports batched operation.
23 |
24 | Arguments:
25 | input: the input tensor with shape :math:`(B,C,H,W)`.
26 | kernel_size: the size of the kernel.
27 | sigma: the standard deviation of the kernel.
28 | border_type: the padding mode to be applied before convolving.
29 | The expected modes are: ``'constant'``, ``'reflect'``,
30 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
31 | separable: run as composition of two 1d-convolutions.
32 |
33 | Returns:
34 | the blurred tensor with shape :math:`(B, C, H, W)`.
35 |
36 | .. note::
37 | See a working example `here `__.
39 |
40 | Examples:
41 | >>> input = torch.rand(2, 4, 5, 5)
42 | >>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5))
43 | >>> output.shape
44 | torch.Size([2, 4, 5, 5])
45 | """
46 | if separable:
47 | kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1])
48 | kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0])
49 | out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type)
50 | else:
51 | kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma)
52 | out = filter2d(input, kernel[None], border_type)
53 | return out
54 |
55 |
56 | class GaussianBlur2d(nn.Module):
57 | r"""Create an operator that blurs a tensor using a Gaussian filter.
58 |
59 | The operator smooths the given tensor with a gaussian kernel by convolving
60 | it to each channel. It supports batched operation.
61 |
62 | Arguments:
63 | kernel_size: the size of the kernel.
64 | sigma: the standard deviation of the kernel.
65 | border_type: the padding mode to be applied before convolving.
66 | The expected modes are: ``'constant'``, ``'reflect'``,
67 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
68 | separable: run as composition of two 1d-convolutions.
69 |
70 | Returns:
71 | the blurred tensor.
72 |
73 | Shape:
74 | - Input: :math:`(B, C, H, W)`
75 | - Output: :math:`(B, C, H, W)`
76 |
77 | Examples::
78 |
79 | >>> input = torch.rand(2, 4, 5, 5)
80 | >>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5))
81 | >>> output = gauss(input) # 2x4x5x5
82 | >>> output.shape
83 | torch.Size([2, 4, 5, 5])
84 | """
85 |
86 | def __init__(
87 | self,
88 | kernel_size: Tuple[int, int],
89 | sigma: Tuple[float, float],
90 | border_type: str = "reflect",
91 | separable: bool = True,
92 | ) -> None:
93 | super().__init__()
94 | self.kernel_size: Tuple[int, int] = kernel_size
95 | self.sigma: Tuple[float, float] = sigma
96 | self.border_type = border_type
97 | self.separable = separable
98 |
99 | def __repr__(self) -> str:
100 | return (
101 | self.__class__.__name__
102 | + "(kernel_size="
103 | + str(self.kernel_size)
104 | + ", "
105 | + "sigma="
106 | + str(self.sigma)
107 | + ", "
108 | + "border_type="
109 | + self.border_type
110 | + ", separable="
111 | + str(self.separable)
112 | + ")"
113 | )
114 |
115 | def forward(self, input: torch.Tensor) -> torch.Tensor:
116 | return gaussian_blur2d(
117 | input, self.kernel_size, self.sigma, self.border_type, self.separable
118 | )
119 |
--------------------------------------------------------------------------------
/model/canny/sobel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from .kernels import (
6 | get_spatial_gradient_kernel2d,
7 | get_spatial_gradient_kernel3d,
8 | normalize_kernel2d,
9 | )
10 |
11 |
12 | def spatial_gradient(
13 | input: torch.Tensor, mode: str = "sobel", order: int = 1, normalized: bool = True
14 | ) -> torch.Tensor:
15 | r"""Compute the first order image derivative in both x and y using a Sobel operator.
16 |
17 | .. image:: _static/img/spatial_gradient.png
18 |
19 | Args:
20 | input: input image tensor with shape :math:`(B, C, H, W)`.
21 | mode: derivatives modality, can be: `sobel` or `diff`.
22 | order: the order of the derivatives.
23 | normalized: whether the output is normalized.
24 |
25 | Return:
26 | the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`.
27 |
28 | .. note::
29 | See a working example `here `__.
31 |
32 | Examples:
33 | >>> input = torch.rand(1, 3, 4, 4)
34 | >>> output = spatial_gradient(input) # 1x3x2x4x4
35 | >>> output.shape
36 | torch.Size([1, 3, 2, 4, 4])
37 | """
38 | if not isinstance(input, torch.Tensor):
39 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
40 |
41 | if not len(input.shape) == 4:
42 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
43 | # allocate kernel
44 | kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order)
45 | if normalized:
46 | kernel = normalize_kernel2d(kernel)
47 |
48 | # prepare kernel
49 | b, c, h, w = input.shape
50 | tmp_kernel: torch.Tensor = kernel.to(input).detach()
51 | tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
52 |
53 | # convolve input tensor with sobel kernel
54 | kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
55 |
56 | # Pad with "replicate" for spatial dims, but with zeros for channel
57 | spatial_pad = [
58 | kernel.size(1) // 2,
59 | kernel.size(1) // 2,
60 | kernel.size(2) // 2,
61 | kernel.size(2) // 2,
62 | ]
63 | out_channels: int = 3 if order == 2 else 2
64 | padded_inp: torch.Tensor = F.pad(
65 | input.reshape(b * c, 1, h, w), spatial_pad, "replicate"
66 | )[:, :, None]
67 |
68 | return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w)
69 |
70 |
71 | def spatial_gradient3d(
72 | input: torch.Tensor, mode: str = "diff", order: int = 1
73 | ) -> torch.Tensor:
74 | r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
75 |
76 | Args:
77 | input: input features tensor with shape :math:`(B, C, D, H, W)`.
78 | mode: derivatives modality, can be: `sobel` or `diff`.
79 | order: the order of the derivatives.
80 |
81 | Return:
82 | the spatial gradients of the input feature map with shape math:`(B, C, 3, D, H, W)`
83 | or :math:`(B, C, 6, D, H, W)`.
84 |
85 | Examples:
86 | >>> input = torch.rand(1, 4, 2, 4, 4)
87 | >>> output = spatial_gradient3d(input)
88 | >>> output.shape
89 | torch.Size([1, 4, 3, 2, 4, 4])
90 | """
91 | if not isinstance(input, torch.Tensor):
92 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
93 |
94 | if not len(input.shape) == 5:
95 | raise ValueError(
96 | f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}"
97 | )
98 | b, c, d, h, w = input.shape
99 | dev = input.device
100 | dtype = input.dtype
101 | if (mode == "diff") and (order == 1):
102 | # we go for the special case implementation due to conv3d bad speed
103 | x: torch.Tensor = F.pad(input, 6 * [1], "replicate")
104 | center = slice(1, -1)
105 | left = slice(0, -2)
106 | right = slice(2, None)
107 | out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype)
108 | out[..., 0, :, :, :] = (
109 | x[..., center, center, right] - x[..., center, center, left]
110 | )
111 | out[..., 1, :, :, :] = (
112 | x[..., center, right, center] - x[..., center, left, center]
113 | )
114 | out[..., 2, :, :, :] = (
115 | x[..., right, center, center] - x[..., left, center, center]
116 | )
117 | out = 0.5 * out
118 | else:
119 | # prepare kernel
120 | # allocate kernel
121 | kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order)
122 |
123 | tmp_kernel: torch.Tensor = kernel.to(input).detach()
124 | tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1)
125 |
126 | # convolve input tensor with grad kernel
127 | kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
128 |
129 | # Pad with "replicate" for spatial dims, but with zeros for channel
130 | spatial_pad = [
131 | kernel.size(2) // 2,
132 | kernel.size(2) // 2,
133 | kernel.size(3) // 2,
134 | kernel.size(3) // 2,
135 | kernel.size(4) // 2,
136 | kernel.size(4) // 2,
137 | ]
138 | out_ch: int = 6 if order == 2 else 3
139 | out = F.conv3d(
140 | F.pad(input, spatial_pad, "replicate"), kernel_flip, padding=0, groups=c
141 | ).view(b, c, out_ch, d, h, w)
142 | return out
143 |
144 |
145 | def sobel(
146 | input: torch.Tensor, normalized: bool = True, eps: float = 1e-6
147 | ) -> torch.Tensor:
148 | r"""Compute the Sobel operator and returns the magnitude per channel.
149 |
150 | .. image:: _static/img/sobel.png
151 |
152 | Args:
153 | input: the input image with shape :math:`(B,C,H,W)`.
154 | normalized: if True, L1 norm of the kernel is set to 1.
155 | eps: regularization number to avoid NaN during backprop.
156 |
157 | Return:
158 | the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`.
159 |
160 | .. note::
161 | See a working example `here `__.
163 |
164 | Example:
165 | >>> input = torch.rand(1, 3, 4, 4)
166 | >>> output = sobel(input) # 1x3x4x4
167 | >>> output.shape
168 | torch.Size([1, 3, 4, 4])
169 | """
170 | if not isinstance(input, torch.Tensor):
171 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
172 |
173 | if not len(input.shape) == 4:
174 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
175 |
176 | # compute the x/y gradients
177 | edges: torch.Tensor = spatial_gradient(input, normalized=normalized)
178 |
179 | # unpack the edges
180 | gx: torch.Tensor = edges[:, :, 0]
181 | gy: torch.Tensor = edges[:, :, 1]
182 |
183 | # compute gradient magnitude
184 | magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
185 |
186 | return magnitude
187 |
188 |
189 | class SpatialGradient(nn.Module):
190 | r"""Compute the first order image derivative in both x and y using a Sobel operator.
191 |
192 | Args:
193 | mode: derivatives modality, can be: `sobel` or `diff`.
194 | order: the order of the derivatives.
195 | normalized: whether the output is normalized.
196 |
197 | Return:
198 | the sobel edges of the input feature map.
199 |
200 | Shape:
201 | - Input: :math:`(B, C, H, W)`
202 | - Output: :math:`(B, C, 2, H, W)`
203 |
204 | Examples:
205 | >>> input = torch.rand(1, 3, 4, 4)
206 | >>> output = SpatialGradient()(input) # 1x3x2x4x4
207 | """
208 |
209 | def __init__(
210 | self, mode: str = "sobel", order: int = 1, normalized: bool = True
211 | ) -> None:
212 | super().__init__()
213 | self.normalized: bool = normalized
214 | self.order: int = order
215 | self.mode: str = mode
216 |
217 | def __repr__(self) -> str:
218 | return (
219 | self.__class__.__name__ + "("
220 | "order="
221 | + str(self.order)
222 | + ", "
223 | + "normalized="
224 | + str(self.normalized)
225 | + ", "
226 | + "mode="
227 | + self.mode
228 | + ")"
229 | )
230 |
231 | def forward(self, input: torch.Tensor) -> torch.Tensor:
232 | return spatial_gradient(input, self.mode, self.order, self.normalized)
233 |
234 |
235 | class SpatialGradient3d(nn.Module):
236 | r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
237 |
238 | Args:
239 | mode: derivatives modality, can be: `sobel` or `diff`.
240 | order: the order of the derivatives.
241 |
242 | Return:
243 | the spatial gradients of the input feature map.
244 |
245 | Shape:
246 | - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them.
247 | - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)`
248 |
249 | Examples:
250 | >>> input = torch.rand(1, 4, 2, 4, 4)
251 | >>> output = SpatialGradient3d()(input)
252 | >>> output.shape
253 | torch.Size([1, 4, 3, 2, 4, 4])
254 | """
255 |
256 | def __init__(self, mode: str = "diff", order: int = 1) -> None:
257 | super().__init__()
258 | self.order: int = order
259 | self.mode: str = mode
260 | self.kernel = get_spatial_gradient_kernel3d(mode, order)
261 |
262 | def __repr__(self) -> str:
263 | return (
264 | self.__class__.__name__ + "("
265 | "order=" + str(self.order) + ", " + "mode=" + self.mode + ")"
266 | )
267 |
268 | def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
269 | return spatial_gradient3d(input, self.mode, self.order)
270 |
271 |
272 | class Sobel(nn.Module):
273 | r"""Compute the Sobel operator and returns the magnitude per channel.
274 |
275 | Args:
276 | normalized: if True, L1 norm of the kernel is set to 1.
277 | eps: regularization number to avoid NaN during backprop.
278 |
279 | Return:
280 | the sobel edge gradient magnitudes map.
281 |
282 | Shape:
283 | - Input: :math:`(B, C, H, W)`
284 | - Output: :math:`(B, C, H, W)`
285 |
286 | Examples:
287 | >>> input = torch.rand(1, 3, 4, 4)
288 | >>> output = Sobel()(input) # 1x3x4x4
289 | """
290 |
291 | def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None:
292 | super().__init__()
293 | self.normalized: bool = normalized
294 | self.eps: float = eps
295 |
296 | def __repr__(self) -> str:
297 | return self.__class__.__name__ + "(" "normalized=" + str(self.normalized) + ")"
298 |
299 | def forward(self, input: torch.Tensor) -> torch.Tensor:
300 | return sobel(input, self.normalized, self.eps)
301 |
--------------------------------------------------------------------------------
/model/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import random
4 | import time
5 | import torch
6 | from torch import nn
7 | import logging
8 | import numpy as np
9 | from os import path as osp
10 |
11 |
12 | def constant_init(module, val, bias=0):
13 | if hasattr(module, "weight") and module.weight is not None:
14 | nn.init.constant_(module.weight, val)
15 | if hasattr(module, "bias") and module.bias is not None:
16 | nn.init.constant_(module.bias, bias)
17 |
18 |
19 | initialized_logger = {}
20 |
21 |
22 | def get_root_logger(logger_name="basicsr", log_level=logging.INFO, log_file=None):
23 | """Get the root logger.
24 | The logger will be initialized if it has not been initialized. By default a
25 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
26 | also be added.
27 |
28 | Args:
29 | logger_name (str): root logger name. Default: 'basicsr'.
30 | log_file (str | None): The log filename. If specified, a FileHandler
31 | will be added to the root logger.
32 | log_level (int): The root logger level. Note that only the process of
33 | rank 0 is affected, while other processes will set the level to
34 | "Error" and be silent most of the time.
35 |
36 | Returns:
37 | logging.Logger: The root logger.
38 | """
39 | logger = logging.getLogger(logger_name)
40 | # if the logger has been initialized, just return it
41 | if logger_name in initialized_logger:
42 | return logger
43 |
44 | format_str = "%(asctime)s %(levelname)s: %(message)s"
45 | stream_handler = logging.StreamHandler()
46 | stream_handler.setFormatter(logging.Formatter(format_str))
47 | logger.addHandler(stream_handler)
48 | logger.propagate = False
49 |
50 | if log_file is not None:
51 | logger.setLevel(log_level)
52 | # add file handler
53 | # file_handler = logging.FileHandler(log_file, 'w')
54 | file_handler = logging.FileHandler(
55 | log_file, "a"
56 | ) # Shangchen: keep the previous log
57 | file_handler.setFormatter(logging.Formatter(format_str))
58 | file_handler.setLevel(log_level)
59 | logger.addHandler(file_handler)
60 | initialized_logger[logger_name] = True
61 | return logger
62 |
63 |
64 | IS_HIGH_VERSION = [
65 | int(m)
66 | for m in list(
67 | re.findall(
68 | r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",
69 | torch.__version__,
70 | )[0][:3]
71 | )
72 | ] >= [1, 12, 0]
73 |
74 |
75 | def gpu_is_available():
76 | if IS_HIGH_VERSION:
77 | if torch.backends.mps.is_available():
78 | return True
79 | return (
80 | True
81 | if torch.cuda.is_available() and torch.backends.cudnn.is_available()
82 | else False
83 | )
84 |
85 |
86 | def get_device(gpu_id=None):
87 | if gpu_id is None:
88 | gpu_str = ""
89 | elif isinstance(gpu_id, int):
90 | gpu_str = f":{gpu_id}"
91 | else:
92 | raise TypeError("Input should be int value.")
93 |
94 | if IS_HIGH_VERSION:
95 | if torch.backends.mps.is_available():
96 | return torch.device("mps" + gpu_str)
97 | return torch.device(
98 | "cuda" + gpu_str
99 | if torch.cuda.is_available() and torch.backends.cudnn.is_available()
100 | else "cpu"
101 | )
102 |
103 |
104 | def set_random_seed(seed):
105 | """Set random seeds."""
106 | random.seed(seed)
107 | np.random.seed(seed)
108 | torch.manual_seed(seed)
109 | torch.cuda.manual_seed(seed)
110 | torch.cuda.manual_seed_all(seed)
111 |
112 |
113 | def get_time_str():
114 | return time.strftime("%Y%m%d_%H%M%S", time.localtime())
115 |
116 |
117 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
118 | """Scan a directory to find the files of interest.
119 |
120 | Args:
121 | dir_path (str): Path of the directory.
122 | suffix (str | tuple(str), optional): File suffix that we are
123 | interested in. Default: None.
124 | recursive (bool, optional): If set to True, recursively scan the
125 | directory. Default: False.
126 | full_path (bool, optional): If set to True, include the dir_path.
127 | Default: False.
128 |
129 | Returns:
130 | A generator for all the files of interest, with relative paths.
131 | """
132 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
133 | raise TypeError('"suffix" must be a string or tuple of strings')
134 |
135 | root = dir_path
136 |
137 | def _scandir(dir_path, suffix, recursive):
138 | for entry in os.scandir(dir_path):
139 | if not entry.name.startswith(".") and entry.is_file():
140 | if full_path:
141 | return_path = entry.path
142 | else:
143 | return_path = osp.relpath(entry.path, root)
144 |
145 | if suffix is None or return_path.endswith(suffix):
146 | yield return_path
147 | elif recursive:
148 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
149 | else:
150 | continue
151 |
152 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
153 |
--------------------------------------------------------------------------------
/model/modules/RAFT/__init__.py:
--------------------------------------------------------------------------------
1 | # from .demo import RAFT_infer
2 | from .raft import RAFT
3 |
--------------------------------------------------------------------------------
/model/modules/RAFT/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .utils.utils import bilinear_sampler
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels - 1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2 * r + 1)
38 | dy = torch.linspace(-r, r, 2 * r + 1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht * wd)
56 | fmap2 = fmap2.view(batch, dim, ht * wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class CorrLayer(torch.autograd.Function):
64 | @staticmethod
65 | def forward(ctx, fmap1, fmap2, coords, r):
66 | fmap1 = fmap1.contiguous()
67 | fmap2 = fmap2.contiguous()
68 | coords = coords.contiguous()
69 | ctx.save_for_backward(fmap1, fmap2, coords)
70 | ctx.r = r
71 | (corr,) = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
72 | return corr
73 |
74 | @staticmethod
75 | def backward(ctx, grad_corr):
76 | fmap1, fmap2, coords = ctx.saved_tensors
77 | grad_corr = grad_corr.contiguous()
78 | fmap1_grad, fmap2_grad, coords_grad = correlation_cudaz.backward(
79 | fmap1, fmap2, coords, grad_corr, ctx.r
80 | )
81 | return fmap1_grad, fmap2_grad, coords_grad, None
82 |
83 |
84 | class AlternateCorrBlock:
85 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
86 | self.num_levels = num_levels
87 | self.radius = radius
88 |
89 | self.pyramid = [(fmap1, fmap2)]
90 | for i in range(self.num_levels):
91 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
92 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
93 | self.pyramid.append((fmap1, fmap2))
94 |
95 | def __call__(self, coords):
96 | coords = coords.permute(0, 2, 3, 1)
97 | B, H, W, _ = coords.shape
98 |
99 | corr_list = []
100 | for i in range(self.num_levels):
101 | r = self.radius
102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
104 |
105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
107 | corr_list.append(corr.squeeze(1))
108 |
109 | corr = torch.stack(corr_list, dim=1)
110 | corr = corr.reshape(B, -1, H, W)
111 | return corr / 16.0
112 |
--------------------------------------------------------------------------------
/model/modules/RAFT/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils import data
6 |
7 | import os
8 | import random
9 | from glob import glob
10 | import os.path as osp
11 |
12 | from utils import frame_utils
13 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
14 |
15 |
16 | class FlowDataset(data.Dataset):
17 | def __init__(self, aug_params=None, sparse=False):
18 | self.augmentor = None
19 | self.sparse = sparse
20 | if aug_params is not None:
21 | if sparse:
22 | self.augmentor = SparseFlowAugmentor(**aug_params)
23 | else:
24 | self.augmentor = FlowAugmentor(**aug_params)
25 |
26 | self.is_test = False
27 | self.init_seed = False
28 | self.flow_list = []
29 | self.image_list = []
30 | self.extra_info = []
31 |
32 | def __getitem__(self, index):
33 | if self.is_test:
34 | img1 = frame_utils.read_gen(self.image_list[index][0])
35 | img2 = frame_utils.read_gen(self.image_list[index][1])
36 | img1 = np.array(img1).astype(np.uint8)[..., :3]
37 | img2 = np.array(img2).astype(np.uint8)[..., :3]
38 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
39 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
40 | return img1, img2, self.extra_info[index]
41 |
42 | if not self.init_seed:
43 | worker_info = torch.utils.data.get_worker_info()
44 | if worker_info is not None:
45 | torch.manual_seed(worker_info.id)
46 | np.random.seed(worker_info.id)
47 | random.seed(worker_info.id)
48 | self.init_seed = True
49 |
50 | index = index % len(self.image_list)
51 | valid = None
52 | if self.sparse:
53 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
54 | else:
55 | flow = frame_utils.read_gen(self.flow_list[index])
56 |
57 | img1 = frame_utils.read_gen(self.image_list[index][0])
58 | img2 = frame_utils.read_gen(self.image_list[index][1])
59 |
60 | flow = np.array(flow).astype(np.float32)
61 | img1 = np.array(img1).astype(np.uint8)
62 | img2 = np.array(img2).astype(np.uint8)
63 |
64 | # grayscale images
65 | if len(img1.shape) == 2:
66 | img1 = np.tile(img1[..., None], (1, 1, 3))
67 | img2 = np.tile(img2[..., None], (1, 1, 3))
68 | else:
69 | img1 = img1[..., :3]
70 | img2 = img2[..., :3]
71 |
72 | if self.augmentor is not None:
73 | if self.sparse:
74 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
75 | else:
76 | img1, img2, flow = self.augmentor(img1, img2, flow)
77 |
78 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
79 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
80 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
81 |
82 | if valid is not None:
83 | valid = torch.from_numpy(valid)
84 | else:
85 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
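   |             # For dense ground truth, pixels whose flow exceeds 1000 px in either
   |             # component are treated as invalid.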
86 |
87 | return img1, img2, flow, valid.float()
88 |
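   |     # Multiplying a dataset by an integer (e.g. `100 * sintel_clean` in
   |     # fetch_dataloader below) repeats its sample lists, i.e. oversamples it.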
89 | def __rmul__(self, v):
90 | self.flow_list = v * self.flow_list
91 | self.image_list = v * self.image_list
92 | return self
93 |
94 | def __len__(self):
95 | return len(self.image_list)
96 |
97 |
98 | class MpiSintel(FlowDataset):
99 | def __init__(
100 | self, aug_params=None, split="training", root="datasets/Sintel", dstype="clean"
101 | ):
102 | super(MpiSintel, self).__init__(aug_params)
103 | flow_root = osp.join(root, split, "flow")
104 | image_root = osp.join(root, split, dstype)
105 |
106 | if split == "test":
107 | self.is_test = True
108 |
109 | for scene in os.listdir(image_root):
110 | image_list = sorted(glob(osp.join(image_root, scene, "*.png")))
111 | for i in range(len(image_list) - 1):
112 | self.image_list += [[image_list[i], image_list[i + 1]]]
113 | self.extra_info += [(scene, i)] # scene and frame_id
114 |
115 | if split != "test":
116 | self.flow_list += sorted(glob(osp.join(flow_root, scene, "*.flo")))
117 |
118 |
119 | class FlyingChairs(FlowDataset):
120 | def __init__(
121 | self, aug_params=None, split="train", root="datasets/FlyingChairs_release/data"
122 | ):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, "*.ppm")))
126 | flows = sorted(glob(osp.join(root, "*.flo")))
127 | assert len(images) // 2 == len(flows)
128 |
129 | split_list = np.loadtxt("chairs_split.txt", dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split == "training" and xid == 1) or (
133 | split == "validation" and xid == 2
134 | ):
135 | self.flow_list += [flows[i]]
136 | self.image_list += [[images[2 * i], images[2 * i + 1]]]
137 |
138 |
139 | class FlyingThings3D(FlowDataset):
140 | def __init__(
141 | self, aug_params=None, root="datasets/FlyingThings3D", dstype="frames_cleanpass"
142 | ):
143 | super(FlyingThings3D, self).__init__(aug_params)
144 |
145 | for cam in ["left"]:
146 | for direction in ["into_future", "into_past"]:
147 | image_dirs = sorted(glob(osp.join(root, dstype, "TRAIN/*/*")))
148 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
149 |
150 | flow_dirs = sorted(glob(osp.join(root, "optical_flow/TRAIN/*/*")))
151 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
152 |
153 | for idir, fdir in zip(image_dirs, flow_dirs):
154 | images = sorted(glob(osp.join(idir, "*.png")))
155 | flows = sorted(glob(osp.join(fdir, "*.pfm")))
156 | for i in range(len(flows) - 1):
157 | if direction == "into_future":
158 | self.image_list += [[images[i], images[i + 1]]]
159 | self.flow_list += [flows[i]]
160 | elif direction == "into_past":
161 | self.image_list += [[images[i + 1], images[i]]]
162 | self.flow_list += [flows[i + 1]]
163 |
164 |
165 | class KITTI(FlowDataset):
166 | def __init__(self, aug_params=None, split="training", root="datasets/KITTI"):
167 | super(KITTI, self).__init__(aug_params, sparse=True)
168 | if split == "testing":
169 | self.is_test = True
170 |
171 | root = osp.join(root, split)
172 | images1 = sorted(glob(osp.join(root, "image_2/*_10.png")))
173 | images2 = sorted(glob(osp.join(root, "image_2/*_11.png")))
174 |
175 | for img1, img2 in zip(images1, images2):
176 | frame_id = img1.split("/")[-1]
177 | self.extra_info += [[frame_id]]
178 | self.image_list += [[img1, img2]]
179 |
180 | if split == "training":
181 | self.flow_list = sorted(glob(osp.join(root, "flow_occ/*_10.png")))
182 |
183 |
184 | class HD1K(FlowDataset):
185 | def __init__(self, aug_params=None, root="datasets/HD1k"):
186 | super(HD1K, self).__init__(aug_params, sparse=True)
187 |
188 | seq_ix = 0
189 | while 1:
190 | flows = sorted(
191 | glob(os.path.join(root, "hd1k_flow_gt", "flow_occ/%06d_*.png" % seq_ix))
192 | )
193 | images = sorted(
194 | glob(os.path.join(root, "hd1k_input", "image_2/%06d_*.png" % seq_ix))
195 | )
196 |
197 | if len(flows) == 0:
198 | break
199 |
200 | for i in range(len(flows) - 1):
201 | self.flow_list += [flows[i]]
202 | self.image_list += [[images[i], images[i + 1]]]
203 |
204 | seq_ix += 1
205 |
206 |
207 | def fetch_dataloader(args, TRAIN_DS="C+T+K+S+H"):
208 |     """Create the data loader for the corresponding training set"""
209 | if args.stage == "chairs":
210 | aug_params = {
211 | "crop_size": args.image_size,
212 | "min_scale": -0.1,
213 | "max_scale": 1.0,
214 | "do_flip": True,
215 | }
216 | train_dataset = FlyingChairs(aug_params, split="training")
217 |
218 | elif args.stage == "things":
219 | aug_params = {
220 | "crop_size": args.image_size,
221 | "min_scale": -0.4,
222 | "max_scale": 0.8,
223 | "do_flip": True,
224 | }
225 | clean_dataset = FlyingThings3D(aug_params, dstype="frames_cleanpass")
226 | final_dataset = FlyingThings3D(aug_params, dstype="frames_finalpass")
227 | train_dataset = clean_dataset + final_dataset
228 |
229 | elif args.stage == "sintel":
230 | aug_params = {
231 | "crop_size": args.image_size,
232 | "min_scale": -0.2,
233 | "max_scale": 0.6,
234 | "do_flip": True,
235 | }
236 | things = FlyingThings3D(aug_params, dstype="frames_cleanpass")
237 | sintel_clean = MpiSintel(aug_params, split="training", dstype="clean")
238 | sintel_final = MpiSintel(aug_params, split="training", dstype="final")
239 |
240 | if TRAIN_DS == "C+T+K+S+H":
241 | kitti = KITTI(
242 | {
243 | "crop_size": args.image_size,
244 | "min_scale": -0.3,
245 | "max_scale": 0.5,
246 | "do_flip": True,
247 | }
248 | )
249 | hd1k = HD1K(
250 | {
251 | "crop_size": args.image_size,
252 | "min_scale": -0.5,
253 | "max_scale": 0.2,
254 | "do_flip": True,
255 | }
256 | )
257 | train_dataset = (
258 | 100 * sintel_clean
259 | + 100 * sintel_final
260 | + 200 * kitti
261 | + 5 * hd1k
262 | + things
263 | )
264 |
265 | elif TRAIN_DS == "C+T+K/S":
266 | train_dataset = 100 * sintel_clean + 100 * sintel_final + things
267 |
268 | elif args.stage == "kitti":
269 | aug_params = {
270 | "crop_size": args.image_size,
271 | "min_scale": -0.2,
272 | "max_scale": 0.4,
273 | "do_flip": False,
274 | }
275 | train_dataset = KITTI(aug_params, split="training")
276 |
277 | train_loader = data.DataLoader(
278 | train_dataset,
279 | batch_size=args.batch_size,
280 | pin_memory=False,
281 | shuffle=True,
282 | num_workers=4,
283 | drop_last=True,
284 | )
285 |
286 | print("Training with %d image pairs" % len(train_dataset))
287 | return train_loader
288 |
--------------------------------------------------------------------------------
/model/modules/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 |
8 | from .raft import RAFT
9 | from .utils import flow_viz
10 | from .utils.utils import InputPadder
11 |
12 |
13 | DEVICE = "cuda"
14 |
15 |
16 | def load_image(imfile):
17 | img = np.array(Image.open(imfile)).astype(np.uint8)
18 | img = torch.from_numpy(img).permute(2, 0, 1).float()
19 | return img
20 |
21 |
22 | def load_image_list(image_files):
23 | images = []
24 | for imfile in sorted(image_files):
25 | images.append(load_image(imfile))
26 |
27 | images = torch.stack(images, dim=0)
28 | images = images.to(DEVICE)
29 |
30 | padder = InputPadder(images.shape)
31 | return padder.pad(images)[0]
32 |
33 |
34 | def viz(img, flo):
35 | img = img[0].permute(1, 2, 0).cpu().numpy()
36 | flo = flo[0].permute(1, 2, 0).cpu().numpy()
37 |
38 | # map flow to rgb image
39 | flo = flow_viz.flow_to_image(flo)
40 | # img_flo = np.concatenate([img, flo], axis=0)
41 | img_flo = flo
42 |
43 | cv2.imwrite("/home/chengao/test/flow.png", img_flo[:, :, [2, 1, 0]])
44 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
45 | # cv2.waitKey()
46 |
47 |
48 | def demo(args):
49 | model = torch.nn.DataParallel(RAFT(args))
50 | model.load_state_dict(torch.load(args.model))
51 |
52 | model = model.module
53 | model.to(DEVICE)
54 | model.eval()
55 |
56 | with torch.no_grad():
57 | images = glob.glob(os.path.join(args.path, "*.png")) + glob.glob(
58 | os.path.join(args.path, "*.jpg")
59 | )
60 |
61 | images = load_image_list(images)
62 | for i in range(images.shape[0] - 1):
63 | image1 = images[i, None]
64 | image2 = images[i + 1, None]
65 |
66 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
67 | viz(image1, flow_up)
68 |
69 |
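   | # Note (sketch): `args` is expected to be an argparse.Namespace providing at least
   | # `model` (checkpoint path), `small`, and `mixed_precision`; `dropout` and
   | # `alternate_corr` are optional (see RAFT.__init__ in raft.py).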
70 | def RAFT_infer(args):
71 | model = torch.nn.DataParallel(RAFT(args))
72 | model.load_state_dict(torch.load(args.model))
73 |
74 | model = model.module
75 | model.to(DEVICE)
76 | model.eval()
77 |
78 | return model
79 |
--------------------------------------------------------------------------------
/model/modules/RAFT/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class ResidualBlock(nn.Module):
6 | def __init__(self, in_planes, planes, norm_fn="group", stride=1):
7 | super(ResidualBlock, self).__init__()
8 |
9 | self.conv1 = nn.Conv2d(
10 | in_planes, planes, kernel_size=3, padding=1, stride=stride
11 | )
12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
13 | self.relu = nn.ReLU(inplace=True)
14 |
15 | num_groups = planes // 8
16 |
17 | if norm_fn == "group":
18 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
20 | if stride != 1:
21 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
22 |
23 | elif norm_fn == "batch":
24 | self.norm1 = nn.BatchNorm2d(planes)
25 | self.norm2 = nn.BatchNorm2d(planes)
26 | if stride != 1:
27 | self.norm3 = nn.BatchNorm2d(planes)
28 |
29 | elif norm_fn == "instance":
30 | self.norm1 = nn.InstanceNorm2d(planes)
31 | self.norm2 = nn.InstanceNorm2d(planes)
32 | if stride != 1:
33 | self.norm3 = nn.InstanceNorm2d(planes)
34 |
35 | elif norm_fn == "none":
36 | self.norm1 = nn.Sequential()
37 | self.norm2 = nn.Sequential()
38 | if stride != 1:
39 | self.norm3 = nn.Sequential()
40 |
41 | if stride == 1:
42 | self.downsample = None
43 |
44 | else:
45 | self.downsample = nn.Sequential(
46 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
47 | )
48 |
49 | def forward(self, x):
50 | y = x
51 | y = self.relu(self.norm1(self.conv1(y)))
52 | y = self.relu(self.norm2(self.conv2(y)))
53 |
54 | if self.downsample is not None:
55 | x = self.downsample(x)
56 |
57 | return self.relu(x + y)
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn="group", stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(
66 | planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
67 | )
68 | self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
69 | self.relu = nn.ReLU(inplace=True)
70 |
71 | num_groups = planes // 8
72 |
73 | if norm_fn == "group":
74 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
75 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
76 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 | if stride != 1:
78 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
79 |
80 | elif norm_fn == "batch":
81 | self.norm1 = nn.BatchNorm2d(planes // 4)
82 | self.norm2 = nn.BatchNorm2d(planes // 4)
83 | self.norm3 = nn.BatchNorm2d(planes)
84 | if stride != 1:
85 | self.norm4 = nn.BatchNorm2d(planes)
86 |
87 | elif norm_fn == "instance":
88 | self.norm1 = nn.InstanceNorm2d(planes // 4)
89 | self.norm2 = nn.InstanceNorm2d(planes // 4)
90 | self.norm3 = nn.InstanceNorm2d(planes)
91 | if stride != 1:
92 | self.norm4 = nn.InstanceNorm2d(planes)
93 |
94 | elif norm_fn == "none":
95 | self.norm1 = nn.Sequential()
96 | self.norm2 = nn.Sequential()
97 | self.norm3 = nn.Sequential()
98 | if stride != 1:
99 | self.norm4 = nn.Sequential()
100 |
101 | if stride == 1:
102 | self.downsample = None
103 |
104 | else:
105 | self.downsample = nn.Sequential(
106 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
107 | )
108 |
109 | def forward(self, x):
110 | y = x
111 | y = self.relu(self.norm1(self.conv1(y)))
112 | y = self.relu(self.norm2(self.conv2(y)))
113 | y = self.relu(self.norm3(self.conv3(y)))
114 |
115 | if self.downsample is not None:
116 | x = self.downsample(x)
117 |
118 | return self.relu(x + y)
119 |
120 |
121 | class BasicEncoder(nn.Module):
122 | def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
123 | super(BasicEncoder, self).__init__()
124 | self.norm_fn = norm_fn
125 |
126 | if self.norm_fn == "group":
127 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
128 |
129 | elif self.norm_fn == "batch":
130 | self.norm1 = nn.BatchNorm2d(64)
131 |
132 | elif self.norm_fn == "instance":
133 | self.norm1 = nn.InstanceNorm2d(64)
134 |
135 | elif self.norm_fn == "none":
136 | self.norm1 = nn.Sequential()
137 |
138 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
139 | self.relu1 = nn.ReLU(inplace=True)
140 |
141 | self.in_planes = 64
142 | self.layer1 = self._make_layer(64, stride=1)
143 | self.layer2 = self._make_layer(96, stride=2)
144 | self.layer3 = self._make_layer(128, stride=2)
145 |
146 | # output convolution
147 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
148 |
149 | self.dropout = None
150 | if dropout > 0:
151 | self.dropout = nn.Dropout2d(p=dropout)
152 |
153 | for m in self.modules():
154 | if isinstance(m, nn.Conv2d):
155 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
156 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
157 | if m.weight is not None:
158 | nn.init.constant_(m.weight, 1)
159 | if m.bias is not None:
160 | nn.init.constant_(m.bias, 0)
161 |
162 | def _make_layer(self, dim, stride=1):
163 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
164 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
165 | layers = (layer1, layer2)
166 |
167 | self.in_planes = dim
168 | return nn.Sequential(*layers)
169 |
170 | def forward(self, x):
171 | # if input is list, combine batch dimension
172 | is_list = isinstance(x, (list, tuple))
173 | if is_list:
174 | batch_dim = x[0].shape[0]
175 | x = torch.cat(x, dim=0)
176 |
177 | x = self.conv1(x)
178 | x = self.norm1(x)
179 | x = self.relu1(x)
180 |
181 | x = self.layer1(x)
182 | x = self.layer2(x)
183 | x = self.layer3(x)
184 |
185 | x = self.conv2(x)
186 |
187 | if self.training and self.dropout is not None:
188 | x = self.dropout(x)
189 |
190 | if is_list:
191 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
192 |
193 | return x
194 |
195 |
196 | class SmallEncoder(nn.Module):
197 | def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
198 | super(SmallEncoder, self).__init__()
199 | self.norm_fn = norm_fn
200 |
201 | if self.norm_fn == "group":
202 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
203 |
204 | elif self.norm_fn == "batch":
205 | self.norm1 = nn.BatchNorm2d(32)
206 |
207 | elif self.norm_fn == "instance":
208 | self.norm1 = nn.InstanceNorm2d(32)
209 |
210 | elif self.norm_fn == "none":
211 | self.norm1 = nn.Sequential()
212 |
213 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
214 | self.relu1 = nn.ReLU(inplace=True)
215 |
216 | self.in_planes = 32
217 | self.layer1 = self._make_layer(32, stride=1)
218 | self.layer2 = self._make_layer(64, stride=2)
219 | self.layer3 = self._make_layer(96, stride=2)
220 |
221 | self.dropout = None
222 | if dropout > 0:
223 | self.dropout = nn.Dropout2d(p=dropout)
224 |
225 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
226 |
227 | for m in self.modules():
228 | if isinstance(m, nn.Conv2d):
229 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
230 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
231 | if m.weight is not None:
232 | nn.init.constant_(m.weight, 1)
233 | if m.bias is not None:
234 | nn.init.constant_(m.bias, 0)
235 |
236 | def _make_layer(self, dim, stride=1):
237 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
238 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
239 | layers = (layer1, layer2)
240 |
241 | self.in_planes = dim
242 | return nn.Sequential(*layers)
243 |
244 | def forward(self, x):
245 | # if input is list, combine batch dimension
246 | is_list = isinstance(x, (list, tuple))
247 | if is_list:
248 | batch_dim = x[0].shape[0]
249 | x = torch.cat(x, dim=0)
250 |
251 | x = self.conv1(x)
252 | x = self.norm1(x)
253 | x = self.relu1(x)
254 |
255 | x = self.layer1(x)
256 | x = self.layer2(x)
257 | x = self.layer3(x)
258 | x = self.conv2(x)
259 |
260 | if self.training and self.dropout is not None:
261 | x = self.dropout(x)
262 |
263 | if is_list:
264 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
265 |
266 | return x
267 |
--------------------------------------------------------------------------------
/model/modules/RAFT/raft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from .update import BasicUpdateBlock, SmallUpdateBlock
6 | from .extractor import BasicEncoder, SmallEncoder
7 | from .corr import CorrBlock, AlternateCorrBlock
8 | from .utils.utils import coords_grid, upflow8
9 |
10 | try:
11 | autocast = torch.cuda.amp.autocast
12 | except:
13 | # dummy autocast for PyTorch < 1.6
14 | class autocast:
15 | def __init__(self, enabled):
16 | pass
17 |
18 | def __enter__(self):
19 | pass
20 |
21 | def __exit__(self, *args):
22 | pass
23 |
24 |
25 | class RAFT(nn.Module):
26 | def __init__(self, args):
27 | super(RAFT, self).__init__()
28 | self.args = args
29 |
30 | if args.small:
31 | self.hidden_dim = hdim = 96
32 | self.context_dim = cdim = 64
33 | args.corr_levels = 4
34 | args.corr_radius = 3
35 |
36 | else:
37 | self.hidden_dim = hdim = 128
38 | self.context_dim = cdim = 128
39 | args.corr_levels = 4
40 | args.corr_radius = 4
41 |
42 | if "dropout" not in args._get_kwargs():
43 | args.dropout = 0
44 |
45 | if "alternate_corr" not in args._get_kwargs():
46 | args.alternate_corr = False
47 |
48 | # feature network, context network, and update block
49 | if args.small:
50 | self.fnet = SmallEncoder(
51 | output_dim=128, norm_fn="instance", dropout=args.dropout
52 | )
53 | self.cnet = SmallEncoder(
54 | output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout
55 | )
56 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 | else:
59 | self.fnet = BasicEncoder(
60 | output_dim=256, norm_fn="instance", dropout=args.dropout
61 | )
62 | self.cnet = BasicEncoder(
63 | output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout
64 | )
65 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
66 |
67 | def freeze_bn(self):
68 | for m in self.modules():
69 | if isinstance(m, nn.BatchNorm2d):
70 | m.eval()
71 |
72 | def initialize_flow(self, img):
73 |         """Flow is represented as the difference between two coordinate grids: flow = coords1 - coords0"""
74 | N, C, H, W = img.shape
75 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
76 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
77 |
78 | # optical flow computed as difference: flow = coords1 - coords0
79 | return coords0, coords1
80 |
81 | def upsample_flow(self, flow, mask):
82 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
83 | N, _, H, W = flow.shape
84 | mask = mask.view(N, 1, 9, 8, 8, H, W)
85 | mask = torch.softmax(mask, dim=2)
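   |         # The softmax over the 9 neighbour weights makes each upsampled pixel a
   |         # convex combination of the coarse flow values in its 3x3 neighbourhood.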
86 |
87 | up_flow = F.unfold(8 * flow, [3, 3], padding=1)
88 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
89 |
90 | up_flow = torch.sum(mask * up_flow, dim=2)
91 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
92 | return up_flow.reshape(N, 2, 8 * H, 8 * W)
93 |
94 | def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True):
95 |         """Estimate optical flow between a pair of frames"""
96 | # image1 = 2 * (image1 / 255.0) - 1.0
97 | # image2 = 2 * (image2 / 255.0) - 1.0
98 |
99 | image1 = image1.contiguous()
100 | image2 = image2.contiguous()
101 |
102 | hdim = self.hidden_dim
103 | cdim = self.context_dim
104 |
105 | # run the feature network
106 | with autocast(enabled=self.args.mixed_precision):
107 | fmap1, fmap2 = self.fnet([image1, image2])
108 |
109 | fmap1 = fmap1.float()
110 | fmap2 = fmap2.float()
111 |
112 | if self.args.alternate_corr:
113 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
114 | else:
115 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
116 |
117 | # run the context network
118 | with autocast(enabled=self.args.mixed_precision):
119 | cnet = self.cnet(image1)
120 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
121 | net = torch.tanh(net)
122 | inp = torch.relu(inp)
123 |
124 | coords0, coords1 = self.initialize_flow(image1)
125 |
126 | if flow_init is not None:
127 | coords1 = coords1 + flow_init
128 |
129 | flow_predictions = []
130 | for itr in range(iters):
131 | coords1 = coords1.detach()
132 | corr = corr_fn(coords1) # index correlation volume
133 |
134 | flow = coords1 - coords0
135 | with autocast(enabled=self.args.mixed_precision):
136 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
137 |
138 | # F(t+1) = F(t) + \Delta(t)
139 | coords1 = coords1 + delta_flow
140 |
141 | # upsample predictions
142 | if up_mask is None:
143 | flow_up = upflow8(coords1 - coords0)
144 | else:
145 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
146 |
147 | flow_predictions.append(flow_up)
148 |
149 | if test_mode:
150 | return coords1 - coords0, flow_up
151 |
152 | return flow_predictions
153 |
--------------------------------------------------------------------------------
/model/modules/RAFT/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 |
17 | class ConvGRU(nn.Module):
18 | def __init__(self, hidden_dim=128, input_dim=192 + 128):
19 | super(ConvGRU, self).__init__()
20 | self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
21 | self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
22 | self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
23 |
24 | def forward(self, h, x):
25 | hx = torch.cat([h, x], dim=1)
26 |
27 | z = torch.sigmoid(self.convz(hx))
28 | r = torch.sigmoid(self.convr(hx))
29 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
30 |
31 | h = (1 - z) * h + z * q
32 | return h
33 |
34 |
35 | class SepConvGRU(nn.Module):
36 | def __init__(self, hidden_dim=128, input_dim=192 + 128):
37 | super(SepConvGRU, self).__init__()
38 | self.convz1 = nn.Conv2d(
39 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
40 | )
41 | self.convr1 = nn.Conv2d(
42 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
43 | )
44 | self.convq1 = nn.Conv2d(
45 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
46 | )
47 |
48 | self.convz2 = nn.Conv2d(
49 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
50 | )
51 | self.convr2 = nn.Conv2d(
52 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
53 | )
54 | self.convq2 = nn.Conv2d(
55 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
56 | )
57 |
58 | def forward(self, h, x):
59 | # horizontal
60 | hx = torch.cat([h, x], dim=1)
61 | z = torch.sigmoid(self.convz1(hx))
62 | r = torch.sigmoid(self.convr1(hx))
63 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
64 | h = (1 - z) * h + z * q
65 |
66 | # vertical
67 | hx = torch.cat([h, x], dim=1)
68 | z = torch.sigmoid(self.convz2(hx))
69 | r = torch.sigmoid(self.convr2(hx))
70 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
71 | h = (1 - z) * h + z * q
72 |
73 | return h
74 |
75 |
76 | class SmallMotionEncoder(nn.Module):
77 | def __init__(self, args):
78 | super(SmallMotionEncoder, self).__init__()
79 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
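   |         # Each of the corr_levels pyramid levels contributes a (2r+1) x (2r+1)
   |         # lookup window per pixel, hence corr_levels * (2r+1)**2 input channels.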
80 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
81 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
82 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
83 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
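   |         # forward() appends the 2 raw flow channels to these 80 features, giving
   |         # the 82 motion channels expected by SmallUpdateBlock's ConvGRU (82 + 64).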
84 |
85 | def forward(self, flow, corr):
86 | cor = F.relu(self.convc1(corr))
87 | flo = F.relu(self.convf1(flow))
88 | flo = F.relu(self.convf2(flo))
89 | cor_flo = torch.cat([cor, flo], dim=1)
90 | out = F.relu(self.conv(cor_flo))
91 | return torch.cat([out, flow], dim=1)
92 |
93 |
94 | class BasicMotionEncoder(nn.Module):
95 | def __init__(self, args):
96 | super(BasicMotionEncoder, self).__init__()
97 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
98 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
99 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
100 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
101 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
102 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
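    |         # 126 output channels; forward() appends the 2 raw flow channels, giving the
    |         # 128 motion features SepConvGRU expects (input_dim = 128 + hidden_dim).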
103 |
104 | def forward(self, flow, corr):
105 | cor = F.relu(self.convc1(corr))
106 | cor = F.relu(self.convc2(cor))
107 | flo = F.relu(self.convf1(flow))
108 | flo = F.relu(self.convf2(flo))
109 |
110 | cor_flo = torch.cat([cor, flo], dim=1)
111 | out = F.relu(self.conv(cor_flo))
112 | return torch.cat([out, flow], dim=1)
113 |
114 |
115 | class SmallUpdateBlock(nn.Module):
116 | def __init__(self, args, hidden_dim=96):
117 | super(SmallUpdateBlock, self).__init__()
118 | self.encoder = SmallMotionEncoder(args)
119 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
121 |
122 | def forward(self, net, inp, corr, flow):
123 | motion_features = self.encoder(flow, corr)
124 | inp = torch.cat([inp, motion_features], dim=1)
125 | net = self.gru(net, inp)
126 | delta_flow = self.flow_head(net)
127 |
128 | return net, None, delta_flow
129 |
130 |
131 | class BasicUpdateBlock(nn.Module):
132 | def __init__(self, args, hidden_dim=128, input_dim=128):
133 | super(BasicUpdateBlock, self).__init__()
134 | self.args = args
135 | self.encoder = BasicMotionEncoder(args)
136 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
137 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
138 |
139 | self.mask = nn.Sequential(
140 | nn.Conv2d(128, 256, 3, padding=1),
141 | nn.ReLU(inplace=True),
142 | nn.Conv2d(256, 64 * 9, 1, padding=0),
143 | )
144 |
145 | def forward(self, net, inp, corr, flow, upsample=True):
146 | motion_features = self.encoder(flow, corr)
147 | inp = torch.cat([inp, motion_features], dim=1)
148 |
149 | net = self.gru(net, inp)
150 | delta_flow = self.flow_head(net)
151 |
152 |         # scale mask to balance gradients
153 | mask = 0.25 * self.mask(net)
154 | return net, mask, delta_flow
155 |
--------------------------------------------------------------------------------
/model/modules/RAFT/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .flow_viz import flow_to_image
2 | from .frame_utils import writeFlow
3 |
--------------------------------------------------------------------------------
/model/modules/RAFT/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | import cv2
5 |
6 | cv2.setNumThreads(0)
7 | cv2.ocl.setUseOpenCL(False)
8 |
9 | from torchvision.transforms import ColorJitter
10 |
11 |
12 | class FlowAugmentor:
13 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
14 | # spatial augmentation params
15 | self.crop_size = crop_size
16 | self.min_scale = min_scale
17 | self.max_scale = max_scale
18 | self.spatial_aug_prob = 0.8
19 | self.stretch_prob = 0.8
20 | self.max_stretch = 0.2
21 |
22 | # flip augmentation params
23 | self.do_flip = do_flip
24 | self.h_flip_prob = 0.5
25 | self.v_flip_prob = 0.1
26 |
27 | # photometric augmentation params
28 | self.photo_aug = ColorJitter(
29 | brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14
30 | )
31 | self.asymmetric_color_aug_prob = 0.2
32 | self.eraser_aug_prob = 0.5
33 |
34 | def color_transform(self, img1, img2):
35 | """Photometric augmentation"""
36 | # asymmetric
37 | if np.random.rand() < self.asymmetric_color_aug_prob:
38 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
39 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
40 |
41 | # symmetric
42 | else:
43 | image_stack = np.concatenate([img1, img2], axis=0)
44 | image_stack = np.array(
45 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
46 | )
47 | img1, img2 = np.split(image_stack, 2, axis=0)
48 |
49 | return img1, img2
50 |
51 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
52 | """Occlusion augmentation"""
53 | ht, wd = img1.shape[:2]
54 | if np.random.rand() < self.eraser_aug_prob:
55 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
56 | for _ in range(np.random.randint(1, 3)):
57 | x0 = np.random.randint(0, wd)
58 | y0 = np.random.randint(0, ht)
59 | dx = np.random.randint(bounds[0], bounds[1])
60 | dy = np.random.randint(bounds[0], bounds[1])
61 | img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
62 |
63 | return img1, img2
64 |
65 | def spatial_transform(self, img1, img2, flow):
66 | # randomly sample scale
67 | ht, wd = img1.shape[:2]
68 | min_scale = np.maximum(
69 | (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
70 | )
71 |
72 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
73 | scale_x = scale
74 | scale_y = scale
75 | if np.random.rand() < self.stretch_prob:
76 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
77 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
78 |
79 | scale_x = np.clip(scale_x, min_scale, None)
80 | scale_y = np.clip(scale_y, min_scale, None)
81 |
82 | if np.random.rand() < self.spatial_aug_prob:
83 | # rescale the images
84 | img1 = cv2.resize(
85 | img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
86 | )
87 | img2 = cv2.resize(
88 | img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
89 | )
90 | flow = cv2.resize(
91 | flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
92 | )
93 | flow = flow * [scale_x, scale_y]
94 |
95 | if self.do_flip:
96 | if np.random.rand() < self.h_flip_prob: # h-flip
97 | img1 = img1[:, ::-1]
98 | img2 = img2[:, ::-1]
99 | flow = flow[:, ::-1] * [-1.0, 1.0]
100 |
101 | if np.random.rand() < self.v_flip_prob: # v-flip
102 | img1 = img1[::-1, :]
103 | img2 = img2[::-1, :]
104 | flow = flow[::-1, :] * [1.0, -1.0]
105 |
106 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
107 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
108 |
109 | img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
110 | img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
111 | flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
112 |
113 | return img1, img2, flow
114 |
115 | def __call__(self, img1, img2, flow):
116 | img1, img2 = self.color_transform(img1, img2)
117 | img1, img2 = self.eraser_transform(img1, img2)
118 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
119 |
120 | img1 = np.ascontiguousarray(img1)
121 | img2 = np.ascontiguousarray(img2)
122 | flow = np.ascontiguousarray(flow)
123 |
124 | return img1, img2, flow
125 |
126 |
127 | class SparseFlowAugmentor:
128 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
129 | # spatial augmentation params
130 | self.crop_size = crop_size
131 | self.min_scale = min_scale
132 | self.max_scale = max_scale
133 | self.spatial_aug_prob = 0.8
134 | self.stretch_prob = 0.8
135 | self.max_stretch = 0.2
136 |
137 | # flip augmentation params
138 | self.do_flip = do_flip
139 | self.h_flip_prob = 0.5
140 | self.v_flip_prob = 0.1
141 |
142 | # photometric augmentation params
143 | self.photo_aug = ColorJitter(
144 | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14
145 | )
146 | self.asymmetric_color_aug_prob = 0.2
147 | self.eraser_aug_prob = 0.5
148 |
149 | def color_transform(self, img1, img2):
150 | image_stack = np.concatenate([img1, img2], axis=0)
151 | image_stack = np.array(
152 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
153 | )
154 | img1, img2 = np.split(image_stack, 2, axis=0)
155 | return img1, img2
156 |
157 | def eraser_transform(self, img1, img2):
158 | ht, wd = img1.shape[:2]
159 | if np.random.rand() < self.eraser_aug_prob:
160 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
161 | for _ in range(np.random.randint(1, 3)):
162 | x0 = np.random.randint(0, wd)
163 | y0 = np.random.randint(0, ht)
164 | dx = np.random.randint(50, 100)
165 | dy = np.random.randint(50, 100)
166 | img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
167 |
168 | return img1, img2
169 |
170 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
171 | ht, wd = flow.shape[:2]
172 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
173 | coords = np.stack(coords, axis=-1)
174 |
175 | coords = coords.reshape(-1, 2).astype(np.float32)
176 | flow = flow.reshape(-1, 2).astype(np.float32)
177 | valid = valid.reshape(-1).astype(np.float32)
178 |
179 | coords0 = coords[valid >= 1]
180 | flow0 = flow[valid >= 1]
181 |
182 | ht1 = int(round(ht * fy))
183 | wd1 = int(round(wd * fx))
184 |
185 | coords1 = coords0 * [fx, fy]
186 | flow1 = flow0 * [fx, fy]
187 |
188 | xx = np.round(coords1[:, 0]).astype(np.int32)
189 | yy = np.round(coords1[:, 1]).astype(np.int32)
190 |
191 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
192 | xx = xx[v]
193 | yy = yy[v]
194 | flow1 = flow1[v]
195 |
196 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
197 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
198 |
199 | flow_img[yy, xx] = flow1
200 | valid_img[yy, xx] = 1
201 |
202 | return flow_img, valid_img
203 |
204 | def spatial_transform(self, img1, img2, flow, valid):
205 | # randomly sample scale
206 |
207 | ht, wd = img1.shape[:2]
208 | min_scale = np.maximum(
209 | (self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd)
210 | )
211 |
212 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
213 | scale_x = np.clip(scale, min_scale, None)
214 | scale_y = np.clip(scale, min_scale, None)
215 |
216 | if np.random.rand() < self.spatial_aug_prob:
217 | # rescale the images
218 | img1 = cv2.resize(
219 | img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
220 | )
221 | img2 = cv2.resize(
222 | img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
223 | )
224 | flow, valid = self.resize_sparse_flow_map(
225 | flow, valid, fx=scale_x, fy=scale_y
226 | )
227 |
228 | if self.do_flip:
229 | if np.random.rand() < 0.5: # h-flip
230 | img1 = img1[:, ::-1]
231 | img2 = img2[:, ::-1]
232 | flow = flow[:, ::-1] * [-1.0, 1.0]
233 | valid = valid[:, ::-1]
234 |
235 | margin_y = 20
236 | margin_x = 50
237 |
238 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
239 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
240 |
241 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
242 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
243 |
244 | img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
245 | img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
246 | flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
247 | valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
248 | return img1, img2, flow, valid
249 |
250 | def __call__(self, img1, img2, flow, valid):
251 | img1, img2 = self.color_transform(img1, img2)
252 | img1, img2 = self.eraser_transform(img1, img2)
253 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
254 |
255 | img1 = np.ascontiguousarray(img1)
256 | img2 = np.ascontiguousarray(img2)
257 | flow = np.ascontiguousarray(flow)
258 | valid = np.ascontiguousarray(valid)
259 |
260 | return img1, img2, flow, valid
261 |
--------------------------------------------------------------------------------
/model/modules/RAFT/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 |
21 | def make_colorwheel():
22 | """Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 | RY = 15
33 | YG = 6
34 | GC = 4
35 | CB = 11
36 | BM = 13
37 | MR = 6
38 |
39 | ncols = RY + YG + GC + CB + BM + MR
40 | colorwheel = np.zeros((ncols, 3))
41 | col = 0
42 |
43 | # RY
44 | colorwheel[0:RY, 0] = 255
45 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
46 | col = col + RY
47 | # YG
48 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
49 | colorwheel[col : col + YG, 1] = 255
50 | col = col + YG
51 | # GC
52 | colorwheel[col : col + GC, 1] = 255
53 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
54 | col = col + GC
55 | # CB
56 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
57 | colorwheel[col : col + CB, 2] = 255
58 | col = col + CB
59 | # BM
60 | colorwheel[col : col + BM, 2] = 255
61 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
62 | col = col + BM
63 | # MR
64 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
65 | colorwheel[col : col + MR, 0] = 255
66 | return colorwheel
67 |
68 |
69 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
70 | """Applies the flow color wheel to (possibly clipped) flow components u and v.
71 |
72 | According to the C++ source code of Daniel Scharstein
73 | According to the Matlab source code of Deqing Sun
74 |
75 | Args:
76 | u (np.ndarray): Input horizontal flow of shape [H,W]
77 | v (np.ndarray): Input vertical flow of shape [H,W]
78 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
79 |
80 | Returns:
81 | np.ndarray: Flow visualization image of shape [H,W,3]
82 | """
83 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
84 | colorwheel = make_colorwheel() # shape [55x3]
85 | ncols = colorwheel.shape[0]
86 | rad = np.sqrt(np.square(u) + np.square(v))
87 | a = np.arctan2(-v, -u) / np.pi
88 | fk = (a + 1) / 2 * (ncols - 1)
89 | k0 = np.floor(fk).astype(np.int32)
90 | k1 = k0 + 1
91 | k1[k1 == ncols] = 0
92 | f = fk - k0
93 | for i in range(colorwheel.shape[1]):
94 | tmp = colorwheel[:, i]
95 | col0 = tmp[k0] / 255.0
96 | col1 = tmp[k1] / 255.0
97 | col = (1 - f) * col0 + f * col1
98 | idx = rad <= 1
99 | col[idx] = 1 - rad[idx] * (1 - col[idx])
100 | col[~idx] = col[~idx] * 0.75 # out of range
101 | # Note the 2-i => BGR instead of RGB
102 | ch_idx = 2 - i if convert_to_bgr else i
103 | flow_image[:, :, ch_idx] = np.floor(255 * col)
104 | return flow_image
105 |
106 |
107 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
108 |     """Expects a two-dimensional flow image of shape [H,W,2].
109 |
110 | Args:
111 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
112 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
113 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
114 |
115 | Returns:
116 | np.ndarray: Flow visualization image of shape [H,W,3]
117 | """
118 | assert flow_uv.ndim == 3, "input flow must have three dimensions"
119 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
120 | if clip_flow is not None:
121 | flow_uv = np.clip(flow_uv, 0, clip_flow)
122 | u = flow_uv[:, :, 0]
123 | v = flow_uv[:, :, 1]
124 | rad = np.sqrt(np.square(u) + np.square(v))
125 | rad_max = np.max(rad)
126 | epsilon = 1e-5
127 | u = u / (rad_max + epsilon)
128 | v = v / (rad_max + epsilon)
129 | return flow_uv_to_colors(u, v, convert_to_bgr)
130 |
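    | # Example (sketch): visualize a flow field `flow` of shape [H, W, 2] (float32):
    | #   rgb = flow_to_image(flow)   # uint8 image of shape [H, W, 3]
    | # Pass convert_to_bgr=True if the result will be written with cv2.imwrite.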
--------------------------------------------------------------------------------
/model/modules/RAFT/utils/flow_viz_pt.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
2 | import torch
3 |
4 | torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
5 |
6 |
7 | @torch.no_grad()
8 | def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
9 | """Converts a flow to an RGB image.
10 |
11 | Args:
12 | flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
13 |
14 | Returns:
15 | img (Tensor): Image Tensor of dtype uint8 where each color corresponds
16 | to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
17 | """
18 | if flow.dtype != torch.float:
19 | raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
20 |
21 | orig_shape = flow.shape
22 | if flow.ndim == 3:
23 | flow = flow[None] # Add batch dim
24 |
25 | if flow.ndim != 4 or flow.shape[1] != 2:
26 | raise ValueError(
27 | f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}."
28 | )
29 |
30 | max_norm = torch.sum(flow**2, dim=1).sqrt().max()
31 |     epsilon = torch.finfo(flow.dtype).eps
32 | normalized_flow = flow / (max_norm + epsilon)
33 | img = _normalized_flow_to_image(normalized_flow)
34 |
35 | if len(orig_shape) == 3:
36 | img = img[0] # Remove batch dim
37 | return img
38 |
39 |
40 | @torch.no_grad()
41 | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
42 | """Converts a batch of normalized flow to an RGB image.
43 |
44 | Args:
45 | normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
46 |
47 | Returns:
48 | img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
49 | """
50 | N, _, H, W = normalized_flow.shape
51 | device = normalized_flow.device
52 | flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
53 | colorwheel = _make_colorwheel().to(device) # shape [55x3]
54 | num_cols = colorwheel.shape[0]
55 | norm = torch.sum(normalized_flow**2, dim=1).sqrt()
56 | a = (
57 | torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :])
58 | / torch.pi
59 | )
60 | fk = (a + 1) / 2 * (num_cols - 1)
61 | k0 = torch.floor(fk).to(torch.long)
62 | k1 = k0 + 1
63 | k1[k1 == num_cols] = 0
64 | f = fk - k0
65 |
66 | for c in range(colorwheel.shape[1]):
67 | tmp = colorwheel[:, c]
68 | col0 = tmp[k0] / 255.0
69 | col1 = tmp[k1] / 255.0
70 | col = (1 - f) * col0 + f * col1
71 | col = 1 - norm * (1 - col)
72 | flow_image[:, c, :, :] = torch.floor(255.0 * col)
73 | return flow_image
74 |
75 |
76 | @torch.no_grad()
77 | def _make_colorwheel() -> torch.Tensor:
78 | """Generates a color wheel for optical flow visualization as presented in:
79 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
80 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
81 |
82 | Returns:
83 | colorwheel (Tensor[55, 3]): Colorwheel Tensor.
84 | """
85 | RY = 15
86 | YG = 6
87 | GC = 4
88 | CB = 11
89 | BM = 13
90 | MR = 6
91 |
92 | ncols = RY + YG + GC + CB + BM + MR
93 | colorwheel = torch.zeros((ncols, 3))
94 | col = 0
95 |
96 | # RY
97 | colorwheel[0:RY, 0] = 255
98 | colorwheel[0:RY, 1] = torch.floor(255.0 * torch.arange(0.0, RY) / RY)
99 | col = col + RY
100 | # YG
101 | colorwheel[col : col + YG, 0] = 255 - torch.floor(
102 | 255.0 * torch.arange(0.0, YG) / YG
103 | )
104 | colorwheel[col : col + YG, 1] = 255
105 | col = col + YG
106 | # GC
107 | colorwheel[col : col + GC, 1] = 255
108 | colorwheel[col : col + GC, 2] = torch.floor(255.0 * torch.arange(0.0, GC) / GC)
109 | col = col + GC
110 | # CB
111 | colorwheel[col : col + CB, 1] = 255 - torch.floor(255.0 * torch.arange(CB) / CB)
112 | colorwheel[col : col + CB, 2] = 255
113 | col = col + CB
114 | # BM
115 | colorwheel[col : col + BM, 2] = 255
116 | colorwheel[col : col + BM, 0] = torch.floor(255.0 * torch.arange(0.0, BM) / BM)
117 | col = col + BM
118 | # MR
119 | colorwheel[col : col + MR, 2] = 255 - torch.floor(255.0 * torch.arange(MR) / MR)
120 | colorwheel[col : col + MR, 0] = 255
121 | return colorwheel
122 |
--------------------------------------------------------------------------------
/model/modules/RAFT/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 |
8 | cv2.setNumThreads(0)
9 | cv2.ocl.setUseOpenCL(False)
10 |
11 | TAG_CHAR = np.array([202021.25], np.float32)
12 |
13 |
14 | def readFlow(fn):
15 | """Read .flo file in Middlebury format"""
16 | # Code adapted from:
17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
18 |
19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
20 | # print 'fn = %s'%(fn)
21 | with open(fn, "rb") as f:
22 | magic = np.fromfile(f, np.float32, count=1)
23 | if magic != 202021.25:
24 | print("Magic number incorrect. Invalid .flo file")
25 | return None
26 | else:
27 | w = np.fromfile(f, np.int32, count=1)
28 | h = np.fromfile(f, np.int32, count=1)
29 | # print 'Reading %d x %d flo file\n' % (w, h)
30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
31 | # Reshape data into 3D array (columns, rows, bands)
32 | # The reshape here is for visualization, the original code is (w,h,2)
33 | return np.resize(data, (int(h), int(w), 2))
34 |
35 |
36 | def readPFM(file):
37 | file = open(file, "rb")
38 |
39 | color = None
40 | width = None
41 | height = None
42 | scale = None
43 | endian = None
44 |
45 | header = file.readline().rstrip()
46 | if header == b"PF":
47 | color = True
48 | elif header == b"Pf":
49 | color = False
50 | else:
51 | raise Exception("Not a PFM file.")
52 |
53 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
54 | if dim_match:
55 | width, height = map(int, dim_match.groups())
56 | else:
57 | raise Exception("Malformed PFM header.")
58 |
59 | scale = float(file.readline().rstrip())
60 | if scale < 0: # little-endian
61 | endian = "<"
62 | scale = -scale
63 | else:
64 | endian = ">" # big-endian
65 |
66 | data = np.fromfile(file, endian + "f")
67 | shape = (height, width, 3) if color else (height, width)
68 |
69 | data = np.reshape(data, shape)
70 | data = np.flipud(data)
71 | return data
72 |
73 |
74 | def writeFlow(filename, uv, v=None):
75 | """Write optical flow to file.
76 |
77 | If v is None, uv is assumed to contain both u and v channels,
78 | stacked in depth.
79 | Original code by Deqing Sun, adapted from Daniel Scharstein.
80 | """
81 | nBands = 2
82 |
83 | if v is None:
84 | assert uv.ndim == 3
85 | assert uv.shape[2] == 2
86 | u = uv[:, :, 0]
87 | v = uv[:, :, 1]
88 | else:
89 | u = uv
90 |
91 | assert u.shape == v.shape
92 | height, width = u.shape
93 | f = open(filename, "wb")
94 | # write the header
95 | f.write(TAG_CHAR)
96 | np.array(width).astype(np.int32).tofile(f)
97 | np.array(height).astype(np.int32).tofile(f)
98 | # arrange into matrix form
99 | tmp = np.zeros((height, width * nBands))
100 | tmp[:, np.arange(width) * 2] = u
101 | tmp[:, np.arange(width) * 2 + 1] = v
102 | tmp.astype(np.float32).tofile(f)
103 | f.close()
104 |
105 |
106 | def readFlowKITTI(filename):
107 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
108 | flow = flow[:, :, ::-1].astype(np.float32)
109 | flow, valid = flow[:, :, :2], flow[:, :, 2]
110 | flow = (flow - 2**15) / 64.0
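    |     # KITTI stores flow in 16-bit PNGs as value = 64 * flow + 2**15 (the inverse
    |     # of writeFlowKITTI below); the remaining channel is the validity mask.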
111 | return flow, valid
112 |
113 |
114 | def readDispKITTI(filename):
115 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
116 | valid = disp > 0.0
117 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
118 | return flow, valid
119 |
120 |
121 | def writeFlowKITTI(filename, uv):
122 | uv = 64.0 * uv + 2**15
123 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
124 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
125 | cv2.imwrite(filename, uv[..., ::-1])
126 |
127 |
128 | def read_gen(file_name, pil=False):
129 | ext = splitext(file_name)[-1]
130 | if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg":
131 | return Image.open(file_name)
132 | elif ext == ".bin" or ext == ".raw":
133 | return np.load(file_name)
134 | elif ext == ".flo":
135 | return readFlow(file_name).astype(np.float32)
136 | elif ext == ".pfm":
137 | flow = readPFM(file_name).astype(np.float32)
138 | if len(flow.shape) == 2:
139 | return flow
140 | else:
141 | return flow[:, :, :-1]
142 | return []
143 |
--------------------------------------------------------------------------------
/model/modules/RAFT/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """Pads images such that dimensions are divisible by 8"""
9 |
10 | def __init__(self, dims, mode="sintel"):
11 | self.ht, self.wd = dims[-2:]
12 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
13 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
14 | if mode == "sintel":
15 | self._pad = [
16 | pad_wd // 2,
17 | pad_wd - pad_wd // 2,
18 | pad_ht // 2,
19 | pad_ht - pad_ht // 2,
20 | ]
21 | else:
22 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
23 |
24 | def pad(self, *inputs):
25 | return [F.pad(x, self._pad, mode="replicate") for x in inputs]
26 |
27 | def unpad(self, x):
28 | ht, wd = x.shape[-2:]
29 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
30 | return x[..., c[0] : c[1], c[2] : c[3]]
31 |
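   | # Example (sketch) for InputPadder: pad two (N, 3, H, W) frames to a multiple of 8
   | # before RAFT, then crop the predicted flow back to the original resolution:
   | #   padder = InputPadder(image1.shape)
   | #   image1, image2 = padder.pad(image1, image2)
   | #   ...
   | #   flow = padder.unpad(flow_up)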
32 |
33 | def forward_interpolate(flow):
34 | flow = flow.detach().cpu().numpy()
35 | dx, dy = flow[0], flow[1]
36 |
37 | ht, wd = dx.shape
38 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
39 |
40 | x1 = x0 + dx
41 | y1 = y0 + dy
42 |
43 | x1 = x1.reshape(-1)
44 | y1 = y1.reshape(-1)
45 | dx = dx.reshape(-1)
46 | dy = dy.reshape(-1)
47 |
48 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
49 | x1 = x1[valid]
50 | y1 = y1[valid]
51 | dx = dx[valid]
52 | dy = dy[valid]
53 |
54 | flow_x = interpolate.griddata(
55 | (x1, y1), dx, (x0, y0), method="nearest", fill_value=0
56 | )
57 |
58 | flow_y = interpolate.griddata(
59 | (x1, y1), dy, (x0, y0), method="nearest", fill_value=0
60 | )
61 |
62 | flow = np.stack([flow_x, flow_y], axis=0)
63 | return torch.from_numpy(flow).float()
64 |
65 |
66 | def bilinear_sampler(img, coords, mode="bilinear", mask=False):
67 | """Wrapper for grid_sample, uses pixel coordinates"""
68 | H, W = img.shape[-2:]
69 | xgrid, ygrid = coords.split([1, 1], dim=-1)
70 | xgrid = 2 * xgrid / (W - 1) - 1
71 | ygrid = 2 * ygrid / (H - 1) - 1
72 |
73 | grid = torch.cat([xgrid, ygrid], dim=-1)
74 | img = F.grid_sample(img, grid, align_corners=True)
75 |
76 | if mask:
77 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
78 | return img, mask.float()
79 |
80 | return img
81 |
82 |
83 | def coords_grid(batch, ht, wd):
84 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
85 | coords = torch.stack(coords[::-1], dim=0).float()
86 | return coords[None].repeat(batch, 1, 1, 1)
87 |
88 |
89 | def upflow8(flow, mode="bilinear"):
90 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
91 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
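   |     # Both the spatial size and the flow values are scaled by 8, since flow vectors
   |     # are measured in pixels of the upsampled grid.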
92 |
--------------------------------------------------------------------------------
/model/modules/base_module.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 |
4 | from functools import reduce
5 |
6 |
7 | class BaseNetwork(nn.Module):
8 | def __init__(self):
9 | super(BaseNetwork, self).__init__()
10 |
11 | def print_network(self):
12 | if isinstance(self, list):
13 | self = self[0]
14 | num_params = 0
15 | for param in self.parameters():
16 | num_params += param.numel()
17 | print(
18 |             "Network [%s] was created. Total number of parameters: %.1f million."
19 |             % (type(self).__name__, num_params / 1000000)
20 | )
21 |
22 | def init_weights(self, init_type="normal", gain=0.02):
23 | """Initialize network's weights
24 | init_type: normal | xavier | kaiming | orthogonal
25 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
26 | """
27 |
28 | def init_func(m):
29 | classname = m.__class__.__name__
30 | if classname.find("InstanceNorm2d") != -1:
31 | if hasattr(m, "weight") and m.weight is not None:
32 | nn.init.constant_(m.weight.data, 1.0)
33 | if hasattr(m, "bias") and m.bias is not None:
34 | nn.init.constant_(m.bias.data, 0.0)
35 | elif hasattr(m, "weight") and (
36 | classname.find("Conv") != -1 or classname.find("Linear") != -1
37 | ):
38 | if init_type == "normal":
39 | nn.init.normal_(m.weight.data, 0.0, gain)
40 | elif init_type == "xavier":
41 | nn.init.xavier_normal_(m.weight.data, gain=gain)
42 | elif init_type == "xavier_uniform":
43 | nn.init.xavier_uniform_(m.weight.data, gain=1.0)
44 | elif init_type == "kaiming":
45 | nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
46 | elif init_type == "orthogonal":
47 | nn.init.orthogonal_(m.weight.data, gain=gain)
48 | elif init_type == "none": # uses pytorch's default init method
49 | m.reset_parameters()
50 | else:
51 | raise NotImplementedError(
52 | "initialization method [%s] is not implemented" % init_type
53 | )
54 | if hasattr(m, "bias") and m.bias is not None:
55 | nn.init.constant_(m.bias.data, 0.0)
56 |
57 | self.apply(init_func)
58 |
59 | # propagate to children
60 | for m in self.children():
61 | if hasattr(m, "init_weights"):
62 | m.init_weights(init_type, gain)
63 |
64 |
65 | class Vec2Feat(nn.Module):
66 | def __init__(self, channel, hidden, kernel_size, stride, padding):
67 | super(Vec2Feat, self).__init__()
68 | self.relu = nn.LeakyReLU(0.2, inplace=True)
69 | c_out = reduce((lambda x, y: x * y), kernel_size) * channel
70 | self.embedding = nn.Linear(hidden, c_out)
71 | self.kernel_size = kernel_size
72 | self.stride = stride
73 | self.padding = padding
74 | self.bias_conv = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
75 |
76 | def forward(self, x, t, output_size):
77 | b_, _, _, _, c_ = x.shape
78 | x = x.view(b_, -1, c_)
79 | feat = self.embedding(x)
80 | b, _, c = feat.size()
81 | feat = feat.view(b * t, -1, c).permute(0, 2, 1)
82 | feat = F.fold(
83 | feat,
84 | output_size=output_size,
85 | kernel_size=self.kernel_size,
86 | stride=self.stride,
87 | padding=self.padding,
88 | )
89 | feat = self.bias_conv(feat)
90 | return feat
91 |
92 |
93 | class FusionFeedForward(nn.Module):
94 | def __init__(self, dim, hidden_dim=1960, t2t_params=None):
95 | super(FusionFeedForward, self).__init__()
96 |         # hidden_dim defaults to 1960
97 | self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim))
98 | self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim))
99 | assert t2t_params is not None
100 | self.t2t_params = t2t_params
101 | self.kernel_shape = reduce(
102 | (lambda x, y: x * y), t2t_params["kernel_size"]
103 | ) # 49
104 |
105 | def forward(self, x, output_size):
106 | n_vecs = 1
107 | for i, d in enumerate(self.t2t_params["kernel_size"]):
108 | n_vecs *= int(
109 | (output_size[i] + 2 * self.t2t_params["padding"][i] - (d - 1) - 1)
110 | / self.t2t_params["stride"][i]
111 | + 1
112 | )
113 |
114 | x = self.fc1(x)
115 | b, n, c = x.size()
116 | normalizer = (
117 | x.new_ones(b, n, self.kernel_shape)
118 | .view(-1, n_vecs, self.kernel_shape)
119 | .permute(0, 2, 1)
120 | )
121 | normalizer = F.fold(
122 | normalizer,
123 | output_size=output_size,
124 | kernel_size=self.t2t_params["kernel_size"],
125 | padding=self.t2t_params["padding"],
126 | stride=self.t2t_params["stride"],
127 | )
128 |
129 | x = F.fold(
130 | x.view(-1, n_vecs, c).permute(0, 2, 1),
131 | output_size=output_size,
132 | kernel_size=self.t2t_params["kernel_size"],
133 | padding=self.t2t_params["padding"],
134 | stride=self.t2t_params["stride"],
135 | )
136 |
137 | x = (
138 | F.unfold(
139 | x / normalizer,
140 | kernel_size=self.t2t_params["kernel_size"],
141 | padding=self.t2t_params["padding"],
142 | stride=self.t2t_params["stride"],
143 | )
144 | .permute(0, 2, 1)
145 | .contiguous()
146 | .view(b, n, c)
147 | )
148 | x = self.fc2(x)
149 | return x
150 |
--------------------------------------------------------------------------------
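FusionFeedForward above reassembles token vectors into a feature map with F.fold and immediately re-extracts them with F.unfold; because the soft-split patches overlap, it divides by a folded all-ones tensor so overlapping contributions are averaged rather than summed. A minimal sketch of that normalization idea (illustrative only, not a file in this repository; parameters are arbitrary examples):

import torch
import torch.nn.functional as F

# Illustrative sketch of the fold/unfold normalization used by FusionFeedForward.
x = torch.randn(1, 3, 8, 8)                      # (b, c, h, w)
kernel_size, stride, padding = (3, 3), (1, 1), (1, 1)

patches = F.unfold(x, kernel_size=kernel_size, stride=stride, padding=padding)
summed = F.fold(patches, output_size=(8, 8), kernel_size=kernel_size,
                stride=stride, padding=padding)  # overlapping patches are summed

ones = F.unfold(torch.ones_like(x), kernel_size=kernel_size,
                stride=stride, padding=padding)
normalizer = F.fold(ones, output_size=(8, 8), kernel_size=kernel_size,
                    stride=stride, padding=padding)

assert torch.allclose(summed / normalizer, x, atol=1e-6)  # dividing restores x
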
/model/modules/deformconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import init as init
4 | from torch.nn.modules.utils import _pair, _single
5 | import math
6 |
7 |
8 | class ModulatedDeformConv2d(nn.Module):
9 | def __init__(
10 | self,
11 | in_channels,
12 | out_channels,
13 | kernel_size,
14 | stride=1,
15 | padding=0,
16 | dilation=1,
17 | groups=1,
18 | deform_groups=1,
19 | bias=True,
20 | ):
21 | super(ModulatedDeformConv2d, self).__init__()
22 |
23 | self.in_channels = in_channels
24 | self.out_channels = out_channels
25 | self.kernel_size = _pair(kernel_size)
26 | self.stride = stride
27 | self.padding = padding
28 | self.dilation = dilation
29 | self.groups = groups
30 | self.deform_groups = deform_groups
31 | self.with_bias = bias
32 | # enable compatibility with nn.Conv2d
33 | self.transposed = False
34 | self.output_padding = _single(0)
35 |
36 | self.weight = nn.Parameter(
37 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
38 | )
39 | if bias:
40 | self.bias = nn.Parameter(torch.Tensor(out_channels))
41 | else:
42 | self.register_parameter("bias", None)
43 | self.init_weights()
44 |
45 | def init_weights(self):
46 | n = self.in_channels
47 | for k in self.kernel_size:
48 | n *= k
49 | stdv = 1.0 / math.sqrt(n)
50 | self.weight.data.uniform_(-stdv, stdv)
51 | if self.bias is not None:
52 | self.bias.data.zero_()
53 |
54 | if hasattr(self, "conv_offset"):
55 | self.conv_offset.weight.data.zero_()
56 | self.conv_offset.bias.data.zero_()
57 |
58 | def forward(self, x, offset, mask):
59 | pass
60 |
--------------------------------------------------------------------------------
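ModulatedDeformConv2d above only stores the convolution parameters; its forward is a stub that subclasses override (see SecondOrderDeformableAlignment in recurrent_flow_completion.py, which applies the stored weight through torchvision.ops.deform_conv2d). A hedged sketch of that call with the expected offset and mask shapes (illustrative only, not a file in this repository; a 3x3 kernel with a single deformable group is assumed):

import torch
import torchvision

# Illustrative sketch: zero offsets and an all-ones mask reduce the deformable
# convolution to a plain convolution. Shapes are example values, not repository code.
b, c_in, c_out, h, w, k = 1, 4, 8, 16, 16, 3
x = torch.randn(b, c_in, h, w)
weight = torch.randn(c_out, c_in, k, k)
bias = torch.zeros(c_out)
offset = torch.zeros(b, 2 * k * k, h, w)  # per-location 2D offsets (zero = no deformation)
mask = torch.ones(b, k * k, h, w)         # modulation scalars in [0, 1]

out = torchvision.ops.deform_conv2d(
    x, offset, weight, bias, stride=(1, 1), padding=(1, 1), mask=mask
)
print(out.shape)  # torch.Size([1, 8, 16, 16])
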
/model/modules/flow_comp_raft.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | from .RAFT import RAFT
7 | from .flow_loss_utils import flow_warp, ternary_loss2
8 |
9 |
10 | def initialize_RAFT(model_path="weights/raft-things.pth", device="cuda"):
11 | """Initializes the RAFT model."""
12 | args = argparse.ArgumentParser()
13 | args.raft_model = model_path
14 | args.small = False
15 | args.mixed_precision = False
16 | args.alternate_corr = False
17 | model = torch.nn.DataParallel(RAFT(args))
18 | model.load_state_dict(torch.load(args.raft_model, map_location="cpu"))
19 | model = model.module
20 |
21 | model.to(device)
22 |
23 | return model
24 |
25 |
26 | class RAFT_bi(nn.Module):
27 | """Bidirectional flow estimation with a frozen RAFT model."""
28 |
29 | def __init__(self, model_path="weights/raft-things.pth", device="cuda"):
30 | super().__init__()
31 | self.fix_raft = initialize_RAFT(model_path, device=device)
32 |
33 | for p in self.fix_raft.parameters():
34 | p.requires_grad = False
35 |
36 | self.l1_criterion = nn.L1Loss()
37 | self.eval()
38 |
39 | def forward(self, gt_local_frames, iters=20):
40 | b, l_t, c, h, w = gt_local_frames.size()
41 | # print(gt_local_frames.shape)
42 |
43 | with torch.no_grad():
44 | gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(-1, c, h, w)
45 | gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(-1, c, h, w)
46 | # print(gtlf_1.shape)
47 |
48 | _, gt_flows_forward = self.fix_raft(
49 | gtlf_1, gtlf_2, iters=iters, test_mode=True
50 | )
51 | _, gt_flows_backward = self.fix_raft(
52 | gtlf_2, gtlf_1, iters=iters, test_mode=True
53 | )
54 |
55 | gt_flows_forward = gt_flows_forward.view(b, l_t - 1, 2, h, w)
56 | gt_flows_backward = gt_flows_backward.view(b, l_t - 1, 2, h, w)
57 |
58 | return gt_flows_forward, gt_flows_backward
59 |
60 |
61 | ##################################################################################
62 | def smoothness_loss(flow, cmask):
63 | delta_u, delta_v, mask = smoothness_deltas(flow)
64 | loss_u = charbonnier_loss(delta_u, cmask)
65 | loss_v = charbonnier_loss(delta_v, cmask)
66 | return loss_u + loss_v
67 |
68 |
69 | def smoothness_deltas(flow):
70 | """flow: [b, c, h, w]"""
71 | mask_x = create_mask(flow, [[0, 0], [0, 1]])
72 | mask_y = create_mask(flow, [[0, 1], [0, 0]])
73 | mask = torch.cat((mask_x, mask_y), dim=1)
74 | mask = mask.to(flow.device)
75 | filter_x = torch.tensor([[0, 0, 0.0], [0, 1, -1], [0, 0, 0]])
76 | filter_y = torch.tensor([[0, 0, 0.0], [0, 1, 0], [0, -1, 0]])
77 | weights = torch.ones([2, 1, 3, 3])
78 | weights[0, 0] = filter_x
79 | weights[1, 0] = filter_y
80 | weights = weights.to(flow.device)
81 |
82 | flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
83 | delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
84 | delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
85 | return delta_u, delta_v, mask
86 |
87 |
88 | def second_order_loss(flow, cmask):
89 | delta_u, delta_v, mask = second_order_deltas(flow)
90 | loss_u = charbonnier_loss(delta_u, cmask)
91 | loss_v = charbonnier_loss(delta_v, cmask)
92 | return loss_u + loss_v
93 |
94 |
95 | def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001):
96 | """Compute the generalized Charbonnier loss of the difference tensor x.
97 | All positions where mask == 0 are not taken into account
98 | x: a tensor of shape [b, c, h, w]
99 | mask: a mask of shape [b, mc, h, w], where mask channels must be either 1 or the same as
100 | the number of channels of x. Entries should be 0 or 1
101 | return: loss
102 | """
103 | b, c, h, w = x.shape
104 | norm = b * c * h * w
105 | error = torch.pow(
106 | torch.square(x * beta) + torch.square(torch.tensor(epsilon)), alpha
107 | )
108 | if mask is not None:
109 | error = mask * error
110 | if truncate is not None:
111 | error = torch.min(error, truncate)
112 | return torch.sum(error) / norm
113 |
114 |
115 | def second_order_deltas(flow):
116 | """Consider the single flow first
117 | flow shape: [b, c, h, w]
118 | """
119 | # create mask
120 | mask_x = create_mask(flow, [[0, 0], [1, 1]])
121 | mask_y = create_mask(flow, [[1, 1], [0, 0]])
122 | mask_diag = create_mask(flow, [[1, 1], [1, 1]])
123 | mask = torch.cat((mask_x, mask_y, mask_diag, mask_diag), dim=1)
124 | mask = mask.to(flow.device)
125 |
126 | filter_x = torch.tensor([[0, 0, 0.0], [1, -2, 1], [0, 0, 0]])
127 | filter_y = torch.tensor([[0, 1, 0.0], [0, -2, 0], [0, 1, 0]])
128 | filter_diag1 = torch.tensor([[1, 0, 0.0], [0, -2, 0], [0, 0, 1]])
129 | filter_diag2 = torch.tensor([[0, 0, 1.0], [0, -2, 0], [1, 0, 0]])
130 | weights = torch.ones([4, 1, 3, 3])
131 | weights[0] = filter_x
132 | weights[1] = filter_y
133 | weights[2] = filter_diag1
134 | weights[3] = filter_diag2
135 | weights = weights.to(flow.device)
136 |
137 | # split the flow into flow_u and flow_v, conv them with the weights
138 | flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
139 | delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
140 | delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
141 | return delta_u, delta_v, mask
142 |
143 |
144 | def create_mask(tensor, paddings):
145 | """Tensor shape: [b, c, h, w]
146 | paddings: [2 x 2] shape list, the first row indicates up and down paddings
147 | the second row indicates left and right paddings
148 | | |
149 | | x |
150 | | x * x |
151 | | x |
152 | | |
153 | """
154 | shape = tensor.shape
155 | inner_height = shape[2] - (paddings[0][0] + paddings[0][1])
156 | inner_width = shape[3] - (paddings[1][0] + paddings[1][1])
157 | inner = torch.ones([inner_height, inner_width])
158 | torch_paddings = [
159 | paddings[1][0],
160 | paddings[1][1],
161 | paddings[0][0],
162 | paddings[0][1],
163 | ] # left, right, up and down
164 | mask2d = F.pad(inner, pad=torch_paddings)
165 | mask3d = mask2d.unsqueeze(0).repeat(shape[0], 1, 1)
166 | mask4d = mask3d.unsqueeze(1)
167 | return mask4d.detach()
168 |
169 |
170 | def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, scale_factor=1):
171 | if scale_factor != 1:
172 | current_frame = F.interpolate(
173 | current_frame, scale_factor=1 / scale_factor, mode="bilinear"
174 | )
175 | shift_frame = F.interpolate(
176 | shift_frame, scale_factor=1 / scale_factor, mode="bilinear"
177 | )
178 | warped_sc = flow_warp(shift_frame, flow_gt.permute(0, 2, 3, 1))
179 | noc_mask = torch.exp(
180 | -50.0 * torch.sum(torch.abs(current_frame - warped_sc), dim=1).pow(2)
181 | ).unsqueeze(1)
182 | warped_comp_sc = flow_warp(shift_frame, flow_comp.permute(0, 2, 3, 1))
183 | loss = ternary_loss2(current_frame, warped_comp_sc, noc_mask, mask)
184 | return loss
185 |
186 |
187 | class FlowLoss(nn.Module):
188 | def __init__(self):
189 | super().__init__()
190 | self.l1_criterion = nn.L1Loss()
191 |
192 | def forward(self, pred_flows, gt_flows, masks, frames):
193 | # pred_flows: b t-1 2 h w
194 | loss = 0
195 | warp_loss = 0
196 | h, w = pred_flows[0].shape[-2:]
197 | masks = [masks[:, :-1, ...].contiguous(), masks[:, 1:, ...].contiguous()]
198 | frames0 = frames[:, :-1, ...]
199 | frames1 = frames[:, 1:, ...]
200 | current_frames = [frames0, frames1]
201 | next_frames = [frames1, frames0]
202 | for i in range(len(pred_flows)):
203 | # print(pred_flows[i].shape)
204 | combined_flow = pred_flows[i] * masks[i] + gt_flows[i] * (1 - masks[i])
205 | l1_loss = self.l1_criterion(
206 | pred_flows[i] * masks[i], gt_flows[i] * masks[i]
207 | ) / torch.mean(masks[i])
208 | l1_loss += self.l1_criterion(
209 | pred_flows[i] * (1 - masks[i]), gt_flows[i] * (1 - masks[i])
210 | ) / torch.mean(1 - masks[i])
211 |
212 | smooth_loss = smoothness_loss(
213 | combined_flow.reshape(-1, 2, h, w), masks[i].reshape(-1, 1, h, w)
214 | )
215 | smooth_loss2 = second_order_loss(
216 | combined_flow.reshape(-1, 2, h, w), masks[i].reshape(-1, 1, h, w)
217 | )
218 |
219 | warp_loss_i = ternary_loss(
220 | combined_flow.reshape(-1, 2, h, w),
221 | gt_flows[i].reshape(-1, 2, h, w),
222 | masks[i].reshape(-1, 1, h, w),
223 | current_frames[i].reshape(-1, 3, h, w),
224 | next_frames[i].reshape(-1, 3, h, w),
225 | )
226 |
227 | loss += l1_loss + smooth_loss + smooth_loss2
228 |
229 | warp_loss += warp_loss_i
230 |
231 | return loss, warp_loss
232 |
233 |
234 | def edgeLoss(preds_edges, edges):
235 | """Args:
236 | preds_edges: with shape [b, c, h, w]
237 | edges: with shape [b, c, h, w]
238 |
239 | Returns: Edge losses
240 |
241 | """
242 | mask = (edges > 0.5).float()
243 | b, c, h, w = mask.shape
244 | num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,].
245 | num_neg = c * h * w - num_pos # Shape: [b,].
246 | neg_weights = (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3)
247 | pos_weights = (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3)
248 | weight = neg_weights * mask + pos_weights * (1 - mask)  # class-balancing weights
249 | losses = F.binary_cross_entropy_with_logits(
250 | preds_edges.float(), edges.float(), weight=weight, reduction="none"
251 | )
252 | loss = torch.mean(losses)
253 | return loss
254 |
255 |
256 | class EdgeLoss(nn.Module):
257 | def __init__(self):
258 | super().__init__()
259 |
260 | def forward(self, pred_edges, gt_edges, masks):
261 | # pred_flows: b t-1 1 h w
262 | loss = 0
263 | h, w = pred_edges[0].shape[-2:]
264 | masks = [masks[:, :-1, ...].contiguous(), masks[:, 1:, ...].contiguous()]
265 | for i in range(len(pred_edges)):
266 | # print(f'edges_{i}', torch.sum(gt_edges[i])) # debug
267 | combined_edge = pred_edges[i] * masks[i] + gt_edges[i] * (1 - masks[i])
268 | edge_loss = edgeLoss(
269 | pred_edges[i].reshape(-1, 1, h, w), gt_edges[i].reshape(-1, 1, h, w)
270 | ) + 5 * edgeLoss(
271 | combined_edge.reshape(-1, 1, h, w), gt_edges[i].reshape(-1, 1, h, w)
272 | )
273 | loss += edge_loss
274 |
275 | return loss
276 |
277 |
278 | class FlowSimpleLoss(nn.Module):
279 | def __init__(self):
280 | super().__init__()
281 | self.l1_criterion = nn.L1Loss()
282 |
283 | def forward(self, pred_flows, gt_flows):
284 | # pred_flows: b t-1 2 h w
285 | loss = 0
286 | h, w = pred_flows[0].shape[-2:]
287 | h_orig, w_orig = gt_flows[0].shape[-2:]
288 | pred_flows = [f.view(-1, 2, h, w) for f in pred_flows]
289 | gt_flows = [f.view(-1, 2, h_orig, w_orig) for f in gt_flows]
290 |
291 | ds_factor = 1.0 * h / h_orig
292 | gt_flows = [
293 | F.interpolate(f, scale_factor=ds_factor, mode="area") * ds_factor
294 | for f in gt_flows
295 | ]
296 | for i in range(len(pred_flows)):
297 | loss += self.l1_criterion(pred_flows[i], gt_flows[i])
298 |
299 | return loss
300 |
--------------------------------------------------------------------------------
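A hedged usage sketch for the RAFT_bi wrapper above (illustrative only, not a file in this repository; it assumes the raft-things checkpoint has been downloaded to weights/raft-things.pth, that a CUDA device is available, and that the package is importable from the repository root):

import torch
from model.modules.flow_comp_raft import RAFT_bi  # import path is an assumption

raft = RAFT_bi(model_path="weights/raft-things.pth", device="cuda")
# Dummy clip; RAFT expects spatial dimensions divisible by 8.
frames = torch.rand(1, 5, 3, 240, 432, device="cuda")  # (b, t, c, h, w)
with torch.no_grad():
    flows_fwd, flows_bwd = raft(frames, iters=20)
print(flows_fwd.shape, flows_bwd.shape)  # both (1, 4, 2, 240, 432)
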
/model/modules/flow_loss_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 |
5 |
6 | def flow_warp(
7 | x, flow, interpolation="bilinear", padding_mode="zeros", align_corners=True
8 | ):
9 | """Warp an image or a feature map with optical flow.
10 |
11 | Args:
12 | x (Tensor): Tensor with size (n, c, h, w).
13 | flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
14 | a two-channel offset map denoting the width and height relative offsets.
15 | Note that the values are not normalized to [-1, 1].
16 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
17 | Default: 'bilinear'.
18 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
19 | Default: 'zeros'.
20 | align_corners (bool): Whether align corners. Default: True.
21 |
22 | Returns:
23 | Tensor: Warped image or feature map.
24 | """
25 | if x.size()[-2:] != flow.size()[1:3]:
26 | raise ValueError(
27 | f"The spatial sizes of input ({x.size()[-2:]}) and "
28 | f"flow ({flow.size()[1:3]}) are not the same."
29 | )
30 | _, _, h, w = x.size()
31 | # create mesh grid
32 | device = flow.device
33 | grid_y, grid_x = torch.meshgrid(
34 | torch.arange(0, h, device=device), torch.arange(0, w, device=device)
35 | )
36 | grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2)
37 | grid.requires_grad = False
38 |
39 | grid_flow = grid + flow
40 | # scale grid_flow to [-1,1]
41 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
42 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
43 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
44 | output = F.grid_sample(
45 | x,
46 | grid_flow,
47 | mode=interpolation,
48 | padding_mode=padding_mode,
49 | align_corners=align_corners,
50 | )
51 | return output
52 |
53 |
54 | # def image_warp(image, flow):
55 | # b, c, h, w = image.size()
56 | # device = image.device
57 | # flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) # normalize to [-1~1](from upper left to lower right
58 | # flow = flow.permute(0, 2, 3, 1) # if you wanna use grid_sample function, the channel(band) shape of show must be in the last dimension
59 | # x = np.linspace(-1, 1, w)
60 | # y = np.linspace(-1, 1, h)
61 | # X, Y = np.meshgrid(x, y)
62 | # grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3),
63 | # torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device)
64 | # output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros')
65 | # return output
66 |
67 |
68 | def length_sq(x):
69 | return torch.sum(torch.square(x), dim=1, keepdim=True)
70 |
71 |
72 | def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
73 | flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x))
74 | flow_fw_warped = flow_warp(flow_fw, flow_bw.permute(0, 2, 3, 1)) # wf(wb(x))
75 | flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x))
76 | flow_diff_bw = flow_bw + flow_fw_warped # wb + wf(wb(x))
77 |
78 | mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))|
79 | mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped) # |wb| + |wf(wb(x))|
80 | occ_thresh_fw = alpha1 * mag_sq_fw + alpha2
81 | occ_thresh_bw = alpha1 * mag_sq_bw + alpha2
82 |
83 | fb_occ_fw = (length_sq(flow_diff_fw) > occ_thresh_fw).float()
84 | fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh_bw).float()
85 |
86 | return (
87 | fb_occ_fw,
88 | fb_occ_bw,
89 | ) # fb_occ_fw -> frame2 area occluded by frame1, fb_occ_bw -> frame1 area occluded by frame2
90 |
91 |
92 | def rgb2gray(image):
93 | gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + 0.110 * image[:, 2]
94 | gray_image = gray_image.unsqueeze(1)
95 | return gray_image
96 |
97 |
98 | def ternary_transform(image, max_distance=1):
99 | device = image.device
100 | patch_size = 2 * max_distance + 1
101 | intensities = rgb2gray(image) * 255
102 | out_channels = patch_size * patch_size
103 | w = np.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size)
104 | weights = torch.from_numpy(w).float().to(device)
105 | patches = F.conv2d(intensities, weights, stride=1, padding=1)
106 | transf = patches - intensities
107 | transf_norm = transf / torch.sqrt(0.81 + torch.square(transf))
108 | return transf_norm
109 |
110 |
111 | def hamming_distance(t1, t2):
112 | dist = torch.square(t1 - t2)
113 | dist_norm = dist / (0.1 + dist)
114 | dist_sum = torch.sum(dist_norm, dim=1, keepdim=True)
115 | return dist_sum
116 |
117 |
118 | def create_mask(mask, paddings):
119 | """padding: [[top, bottom], [left, right]]"""
120 | shape = mask.shape
121 | inner_height = shape[2] - (paddings[0][0] + paddings[0][1])
122 | inner_width = shape[3] - (paddings[1][0] + paddings[1][1])
123 | inner = torch.ones([inner_height, inner_width])
124 |
125 | mask2d = F.pad(
126 | inner, pad=[paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]]
127 | )
128 | mask3d = mask2d.unsqueeze(0)
129 | mask4d = mask3d.unsqueeze(0).repeat(shape[0], 1, 1, 1)
130 | return mask4d.detach()
131 |
132 |
133 | def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1):
134 | """Args:
135 | frame1: torch tensor, with shape [b * t, c, h, w]
136 | warp_frame21: torch tensor, with shape [b * t, c, h, w]
137 | confMask: confidence mask, with shape [b * t, c, h, w]
138 | masks: torch tensor, with shape [b * t, c, h, w]
139 | max_distance: maximum distance.
140 |
141 | Returns: ternary loss
142 |
143 | """
144 | t1 = ternary_transform(frame1)
145 | t21 = ternary_transform(warp_frame21)
146 | dist = hamming_distance(t1, t21)
147 | loss = torch.mean(dist * confMask * masks) / torch.mean(masks)
148 | return loss
149 |
--------------------------------------------------------------------------------
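A quick sanity check for flow_warp above (illustrative only, not a file in this repository; the import path is an assumption): a zero flow field is an identity warp, so with align_corners=True the output matches the input.

import torch
from model.modules.flow_loss_utils import flow_warp  # import path is an assumption

x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
zero_flow = torch.zeros(2, 4, 5, 2)  # (n, h, w, 2) per-pixel offsets
warped = flow_warp(x, zero_flow)
assert torch.allclose(warped, x, atol=1e-5)  # zero flow leaves the input unchanged
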
/model/modules/spectral_norm.py:
--------------------------------------------------------------------------------
1 | """Spectral Normalization from https://arxiv.org/abs/1802.05957"""
2 |
3 | import torch
4 | from torch.nn.functional import normalize
5 |
6 |
7 | class SpectralNorm:
8 | # Invariant before and after each forward call:
9 | # u = normalize(W @ v)
10 | # NB: At initialization, this invariant is not enforced
11 |
12 | _version = 1
13 |
14 | # At version 1:
15 | # made `W` not a buffer,
16 | # added `v` as a buffer, and
17 | # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
18 |
19 | def __init__(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12):
20 | self.name = name
21 | self.dim = dim
22 | if n_power_iterations <= 0:
23 | raise ValueError(
24 | "Expected n_power_iterations to be positive, but "
25 | f"got n_power_iterations={n_power_iterations}"
26 | )
27 | self.n_power_iterations = n_power_iterations
28 | self.eps = eps
29 |
30 | def reshape_weight_to_matrix(self, weight):
31 | weight_mat = weight
32 | if self.dim != 0:
33 | # permute dim to front
34 | weight_mat = weight_mat.permute(
35 | self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]
36 | )
37 | height = weight_mat.size(0)
38 | return weight_mat.reshape(height, -1)
39 |
40 | def compute_weight(self, module, do_power_iteration):
41 | # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
42 | # updated in power iteration **in-place**. This is very important
43 | # because in `DataParallel` forward, the vectors (being buffers) are
44 | # broadcast from the parallelized module to each module replica,
45 | # which is a new module object created on the fly. And each replica
46 | # runs its own spectral norm power iteration. So simply assigning
47 | # the updated vectors to the module this function runs on will cause
48 | # the update to be lost forever. And the next time the parallelized
49 | # module is replicated, the same randomly initialized vectors are
50 | # broadcast and used!
51 | #
52 | # Therefore, to make the change propagate back, we rely on two
53 | # important behaviors (also enforced via tests):
54 | # 1. `DataParallel` doesn't clone storage if the broadcast tensor
55 | # is already on correct device; and it makes sure that the
56 | # parallelized module is already on `device[0]`.
57 | # 2. If the out tensor in `out=` kwarg has correct shape, it will
58 | # just fill in the values.
59 | # Therefore, since the same power iteration is performed on all
60 | # devices, simply updating the tensors in-place will make sure that
61 | # the module replica on `device[0]` will update the _u vector on the
62 | # parallelized module (by shared storage).
63 | #
64 | # However, after we update `u` and `v` in-place, we need to **clone**
65 | # them before using them to normalize the weight. This is to support
66 | # backproping through two forward passes, e.g., the common pattern in
67 | # GAN training: loss = D(real) - D(fake). Otherwise, engine will
68 | # complain that variables needed to do backward for the first forward
69 | # (i.e., the `u` and `v` vectors) are changed in the second forward.
70 | weight = getattr(module, self.name + "_orig")
71 | u = getattr(module, self.name + "_u")
72 | v = getattr(module, self.name + "_v")
73 | weight_mat = self.reshape_weight_to_matrix(weight)
74 |
75 | if do_power_iteration:
76 | with torch.no_grad():
77 | for _ in range(self.n_power_iterations):
78 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
79 | # are the first left and right singular vectors.
80 | # This power iteration produces approximations of `u` and `v`.
81 | v = normalize(
82 | torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v
83 | )
84 | u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
85 | if self.n_power_iterations > 0:
86 | # See above on why we need to clone
87 | u = u.clone()
88 | v = v.clone()
89 |
90 | sigma = torch.dot(u, torch.mv(weight_mat, v))
91 | weight = weight / sigma
92 | return weight
93 |
94 | def remove(self, module):
95 | with torch.no_grad():
96 | weight = self.compute_weight(module, do_power_iteration=False)
97 | delattr(module, self.name)
98 | delattr(module, self.name + "_u")
99 | delattr(module, self.name + "_v")
100 | delattr(module, self.name + "_orig")
101 | module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))
102 |
103 | def __call__(self, module, inputs):
104 | setattr(
105 | module,
106 | self.name,
107 | self.compute_weight(module, do_power_iteration=module.training),
108 | )
109 |
110 | def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
111 | # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
112 | # (the invariant at top of this class) and `u @ W @ v = sigma`.
113 | # This uses pinverse in case W^T W is not invertible.
114 | v = torch.chain_matmul(
115 | weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)
116 | ).squeeze(1)
117 | return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
118 |
119 | @staticmethod
120 | def apply(module, name, n_power_iterations, dim, eps):
121 | for k, hook in module._forward_pre_hooks.items():
122 | if isinstance(hook, SpectralNorm) and hook.name == name:
123 | raise RuntimeError(
124 | "Cannot register two spectral_norm hooks on "
125 | f"the same parameter {name}"
126 | )
127 |
128 | fn = SpectralNorm(name, n_power_iterations, dim, eps)
129 | weight = module._parameters[name]
130 |
131 | with torch.no_grad():
132 | weight_mat = fn.reshape_weight_to_matrix(weight)
133 |
134 | h, w = weight_mat.size()
135 | # randomly initialize `u` and `v`
136 | u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
137 | v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
138 |
139 | delattr(module, fn.name)
140 | module.register_parameter(fn.name + "_orig", weight)
141 | # We still need to assign weight back as fn.name because all sorts of
142 | # things may assume that it exists, e.g., when initializing weights.
143 | # However, we can't directly assign as it could be an nn.Parameter and
144 | # gets added as a parameter. Instead, we register weight.data as a plain
145 | # attribute.
146 | setattr(module, fn.name, weight.data)
147 | module.register_buffer(fn.name + "_u", u)
148 | module.register_buffer(fn.name + "_v", v)
149 |
150 | module.register_forward_pre_hook(fn)
151 |
152 | module._register_state_dict_hook(SpectralNormStateDictHook(fn))
153 | module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
154 | return fn
155 |
156 |
157 | # This is a top level class because Py2 pickle doesn't like inner class nor an
158 | # instancemethod.
159 | class SpectralNormLoadStateDictPreHook:
160 | # See docstring of SpectralNorm._version on the changes to spectral_norm.
161 | def __init__(self, fn):
162 | self.fn = fn
163 |
164 | # For state_dict with version None, (assuming that it has gone through at
165 | # least one training forward), we have
166 | #
167 | # u = normalize(W_orig @ v)
168 | # W = W_orig / sigma, where sigma = u @ W_orig @ v
169 | #
170 | # To compute `v`, we solve `W_orig @ x = u`, and let
171 | # v = x / (u @ W_orig @ x) * (W / W_orig).
172 | def __call__(
173 | self,
174 | state_dict,
175 | prefix,
176 | local_metadata,
177 | strict,
178 | missing_keys,
179 | unexpected_keys,
180 | error_msgs,
181 | ):
182 | fn = self.fn
183 | version = local_metadata.get("spectral_norm", {}).get(
184 | fn.name + ".version", None
185 | )
186 | if version is None or version < 1:
187 | with torch.no_grad():
188 | weight_orig = state_dict[prefix + fn.name + "_orig"]
189 | # weight = state_dict.pop(prefix + fn.name)
190 | # sigma = (weight_orig / weight).mean()
191 | weight_mat = fn.reshape_weight_to_matrix(weight_orig)
192 | u = state_dict[prefix + fn.name + "_u"]
193 | # v = fn._solve_v_and_rescale(weight_mat, u, sigma)
194 | # state_dict[prefix + fn.name + '_v'] = v
195 |
196 |
197 | # This is a top level class because Py2 pickle doesn't like inner class nor an
198 | # instancemethod.
199 | class SpectralNormStateDictHook:
200 | # See docstring of SpectralNorm._version on the changes to spectral_norm.
201 | def __init__(self, fn):
202 | self.fn = fn
203 |
204 | def __call__(self, module, state_dict, prefix, local_metadata):
205 | if "spectral_norm" not in local_metadata:
206 | local_metadata["spectral_norm"] = {}
207 | key = self.fn.name + ".version"
208 | if key in local_metadata["spectral_norm"]:
209 | raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}")
210 | local_metadata["spectral_norm"][key] = self.fn._version
211 |
212 |
213 | def spectral_norm(module, name="weight", n_power_iterations=1, eps=1e-12, dim=None):
214 | r"""Applies spectral normalization to a parameter in the given module.
215 |
216 | .. math::
217 | \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
218 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
219 |
220 | Spectral normalization stabilizes the training of discriminators (critics)
221 | in Generative Adversarial Networks (GANs) by rescaling the weight tensor
222 | with spectral norm :math:`\sigma` of the weight matrix calculated using
223 | power iteration method. If the dimension of the weight tensor is greater
224 | than 2, it is reshaped to 2D in power iteration method to get spectral
225 | norm. This is implemented via a hook that calculates spectral norm and
226 | rescales weight before every :meth:`~Module.forward` call.
227 |
228 | See `Spectral Normalization for Generative Adversarial Networks`_ .
229 |
230 | .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
231 |
232 | Args:
233 | module (nn.Module): containing module
234 | name (str, optional): name of weight parameter
235 | n_power_iterations (int, optional): number of power iterations to
236 | calculate spectral norm
237 | eps (float, optional): epsilon for numerical stability in
238 | calculating norms
239 | dim (int, optional): dimension corresponding to number of outputs,
240 | the default is ``0``, except for modules that are instances of
241 | ConvTranspose{1,2,3}d, when it is ``1``
242 |
243 | Returns:
244 | The original module with the spectral norm hook
245 |
246 | Example::
247 |
248 | >>> m = spectral_norm(nn.Linear(20, 40))
249 | >>> m
250 | Linear(in_features=20, out_features=40, bias=True)
251 | >>> m.weight_u.size()
252 | torch.Size([40])
253 |
254 | """
255 | if dim is None:
256 | if isinstance(
257 | module,
258 | (
259 | torch.nn.ConvTranspose1d,
260 | torch.nn.ConvTranspose2d,
261 | torch.nn.ConvTranspose3d,
262 | ),
263 | ):
264 | dim = 1
265 | else:
266 | dim = 0
267 | SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
268 | return module
269 |
270 |
271 | def remove_spectral_norm(module, name="weight"):
272 | r"""Removes the spectral normalization reparameterization from a module.
273 |
274 | Args:
275 | module (Module): containing module
276 | name (str, optional): name of weight parameter
277 |
278 | Example:
279 | >>> m = spectral_norm(nn.Linear(40, 10))
280 | >>> remove_spectral_norm(m)
281 | """
282 | for k, hook in module._forward_pre_hooks.items():
283 | if isinstance(hook, SpectralNorm) and hook.name == name:
284 | hook.remove(module)
285 | del module._forward_pre_hooks[k]
286 | return module
287 |
288 | raise ValueError(f"spectral_norm of '{name}' not found in {module}")
289 |
290 |
291 | def use_spectral_norm(module, use_sn=False):
292 | if use_sn:
293 | return spectral_norm(module)
294 | return module
295 |
--------------------------------------------------------------------------------
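A small check that the SpectralNorm hook above really normalizes the weight (illustrative only, not a file in this repository; the import path is an assumption): after enough power-iteration steps, the largest singular value of module.weight approaches 1.

import torch
from model.modules.spectral_norm import spectral_norm  # import path is an assumption

m = spectral_norm(torch.nn.Linear(20, 40))
x = torch.randn(8, 20)
for _ in range(50):
    m(x)  # each training-mode forward runs one power iteration
print(torch.linalg.svdvals(m.weight.detach()).max())  # close to 1.0 once u, v converge
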
/model/recurrent_flow_completion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 | from .modules.deformconv import ModulatedDeformConv2d
7 | from .misc import constant_init
8 |
9 |
10 | class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
11 | """Second-order deformable alignment module."""
12 |
13 | def __init__(self, *args, **kwargs):
14 | self.max_residue_magnitude = kwargs.pop("max_residue_magnitude", 5)
15 |
16 | super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
17 |
18 | self.conv_offset = nn.Sequential(
19 | nn.Conv2d(3 * self.out_channels, self.out_channels, 3, 1, 1),
20 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
21 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
22 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
23 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
24 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
25 | nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
26 | )
27 | self.init_offset()
28 |
29 | def init_offset(self):
30 | constant_init(self.conv_offset[-1], val=0, bias=0)
31 |
32 | def forward(self, x, extra_feat):
33 | out = self.conv_offset(extra_feat)
34 | o1, o2, mask = torch.chunk(out, 3, dim=1)
35 |
36 | # offset
37 | offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
38 | offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
39 | offset = torch.cat([offset_1, offset_2], dim=1)
40 |
41 | # mask
42 | mask = torch.sigmoid(mask)
43 |
44 | return torchvision.ops.deform_conv2d(
45 | x,
46 | offset,
47 | self.weight,
48 | self.bias,
49 | self.stride,
50 | self.padding,
51 | self.dilation,
52 | mask,
53 | )
54 |
55 |
56 | class BidirectionalPropagation(nn.Module):
57 | def __init__(self, channel):
58 | super(BidirectionalPropagation, self).__init__()
59 | modules = ["backward_", "forward_"]
60 | self.deform_align = nn.ModuleDict()
61 | self.backbone = nn.ModuleDict()
62 | self.channel = channel
63 |
64 | for i, module in enumerate(modules):
65 | self.deform_align[module] = SecondOrderDeformableAlignment(
66 | 2 * channel, channel, 3, padding=1, deform_groups=16
67 | )
68 |
69 | self.backbone[module] = nn.Sequential(
70 | nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
71 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
72 | nn.Conv2d(channel, channel, 3, 1, 1),
73 | )
74 |
75 | self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
76 |
77 | def forward(self, x):
78 | """X shape : [b, t, c, h, w]
79 | return [b, t, c, h, w]
80 | """
81 | b, t, c, h, w = x.shape
82 | feats = {}
83 | feats["spatial"] = [x[:, i, :, :, :] for i in range(0, t)]
84 |
85 | for module_name in ["backward_", "forward_"]:
86 | feats[module_name] = []
87 |
88 | frame_idx = range(0, t)
89 | mapping_idx = list(range(0, len(feats["spatial"])))
90 | mapping_idx += mapping_idx[::-1]
91 |
92 | if "backward" in module_name:
93 | frame_idx = frame_idx[::-1]
94 |
95 | feat_prop = x.new_zeros(b, self.channel, h, w)
96 | for i, idx in enumerate(frame_idx):
97 | feat_current = feats["spatial"][mapping_idx[idx]]
98 | if i > 0:
99 | cond_n1 = feat_prop
100 |
101 | # initialize second-order features
102 | feat_n2 = torch.zeros_like(feat_prop)
103 | cond_n2 = torch.zeros_like(cond_n1)
104 | if i > 1: # second-order features
105 | feat_n2 = feats[module_name][-2]
106 | cond_n2 = feat_n2
107 |
108 | cond = torch.cat(
109 | [cond_n1, feat_current, cond_n2], dim=1
110 | ) # condition information, cond(flow warped 1st/2nd feature)
111 | feat_prop = torch.cat(
112 | [feat_prop, feat_n2], dim=1
113 | ) # two order feat_prop -1 & -2
114 | feat_prop = self.deform_align[module_name](feat_prop, cond)
115 |
116 | # fuse current features
117 | feat = (
118 | [feat_current]
119 | + [
120 | feats[k][idx]
121 | for k in feats
122 | if k not in ["spatial", module_name]
123 | ]
124 | + [feat_prop]
125 | )
126 |
127 | feat = torch.cat(feat, dim=1)
128 | # embed current features
129 | feat_prop = feat_prop + self.backbone[module_name](feat)
130 |
131 | feats[module_name].append(feat_prop)
132 |
133 | # end for
134 | if "backward" in module_name:
135 | feats[module_name] = feats[module_name][::-1]
136 |
137 | outputs = []
138 | for i in range(0, t):
139 | align_feats = [feats[k].pop(0) for k in feats if k != "spatial"]
140 | align_feats = torch.cat(align_feats, dim=1)
141 | outputs.append(self.fusion(align_feats))
142 |
143 | return torch.stack(outputs, dim=1) + x
144 |
145 |
146 | class deconv(nn.Module):
147 | def __init__(self, input_channel, output_channel, kernel_size=3, padding=0):
148 | super().__init__()
149 | self.conv = nn.Conv2d(
150 | input_channel,
151 | output_channel,
152 | kernel_size=kernel_size,
153 | stride=1,
154 | padding=padding,
155 | )
156 |
157 | def forward(self, x):
158 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
159 | return self.conv(x)
160 |
161 |
162 | class P3DBlock(nn.Module):
163 | def __init__(
164 | self,
165 | in_channels,
166 | out_channels,
167 | kernel_size,
168 | stride,
169 | padding,
170 | use_residual=0,
171 | bias=True,
172 | ):
173 | super().__init__()
174 | self.conv1 = nn.Sequential(
175 | nn.Conv3d(
176 | in_channels,
177 | out_channels,
178 | kernel_size=(1, kernel_size, kernel_size),
179 | stride=(1, stride, stride),
180 | padding=(0, padding, padding),
181 | bias=bias,
182 | ),
183 | nn.LeakyReLU(0.2, inplace=True),
184 | )
185 | self.conv2 = nn.Sequential(
186 | nn.Conv3d(
187 | out_channels,
188 | out_channels,
189 | kernel_size=(3, 1, 1),
190 | stride=(1, 1, 1),
191 | padding=(2, 0, 0),
192 | dilation=(2, 1, 1),
193 | bias=bias,
194 | )
195 | )
196 | self.use_residual = use_residual
197 |
198 | def forward(self, feats):
199 | feat1 = self.conv1(feats)
200 | feat2 = self.conv2(feat1)
201 | if self.use_residual:
202 | output = feats + feat2
203 | else:
204 | output = feat2
205 | return output
206 |
207 |
208 | class EdgeDetection(nn.Module):
209 | def __init__(self, in_ch=2, out_ch=1, mid_ch=16):
210 | super().__init__()
211 | self.projection = nn.Sequential(
212 | nn.Conv2d(in_ch, mid_ch, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True)
213 | )
214 |
215 | self.mid_layer_1 = nn.Sequential(
216 | nn.Conv2d(mid_ch, mid_ch, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True)
217 | )
218 |
219 | self.mid_layer_2 = nn.Sequential(nn.Conv2d(mid_ch, mid_ch, 3, 1, 1))
220 |
221 | self.l_relu = nn.LeakyReLU(0.01, inplace=True)
222 |
223 | self.out_layer = nn.Conv2d(mid_ch, out_ch, 1, 1, 0)
224 |
225 | def forward(self, flow):
226 | flow = self.projection(flow)
227 | edge = self.mid_layer_1(flow)
228 | edge = self.mid_layer_2(edge)
229 | edge = self.l_relu(flow + edge)
230 | edge = self.out_layer(edge)
231 | edge = torch.sigmoid(edge)
232 | return edge
233 |
234 |
235 | class RecurrentFlowCompleteNet(nn.Module):
236 | def __init__(self, model_path=None):
237 | super().__init__()
238 | self.downsample = nn.Sequential(
239 | nn.Conv3d(
240 | 3,
241 | 32,
242 | kernel_size=(1, 5, 5),
243 | stride=(1, 2, 2),
244 | padding=(0, 2, 2),
245 | padding_mode="replicate",
246 | ),
247 | nn.LeakyReLU(0.2, inplace=True),
248 | )
249 |
250 | self.encoder1 = nn.Sequential(
251 | P3DBlock(32, 32, 3, 1, 1),
252 | nn.LeakyReLU(0.2, inplace=True),
253 | P3DBlock(32, 64, 3, 2, 1),
254 | nn.LeakyReLU(0.2, inplace=True),
255 | ) # 4x
256 |
257 | self.encoder2 = nn.Sequential(
258 | P3DBlock(64, 64, 3, 1, 1),
259 | nn.LeakyReLU(0.2, inplace=True),
260 | P3DBlock(64, 128, 3, 2, 1),
261 | nn.LeakyReLU(0.2, inplace=True),
262 | ) # 8x
263 |
264 | self.mid_dilation = nn.Sequential(
265 | nn.Conv3d(
266 | 128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 3, 3), dilation=(1, 3, 3)
267 | ), # p = d*(k-1)/2
268 | nn.LeakyReLU(0.2, inplace=True),
269 | nn.Conv3d(
270 | 128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 2, 2), dilation=(1, 2, 2)
271 | ),
272 | nn.LeakyReLU(0.2, inplace=True),
273 | nn.Conv3d(
274 | 128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 1, 1), dilation=(1, 1, 1)
275 | ),
276 | nn.LeakyReLU(0.2, inplace=True),
277 | )
278 |
279 | # feature propagation module
280 | self.feat_prop_module = BidirectionalPropagation(128)
281 |
282 | self.decoder2 = nn.Sequential(
283 | nn.Conv2d(128, 128, 3, 1, 1),
284 | nn.LeakyReLU(0.2, inplace=True),
285 | deconv(128, 64, 3, 1),
286 | nn.LeakyReLU(0.2, inplace=True),
287 | ) # 4x
288 |
289 | self.decoder1 = nn.Sequential(
290 | nn.Conv2d(64, 64, 3, 1, 1),
291 | nn.LeakyReLU(0.2, inplace=True),
292 | deconv(64, 32, 3, 1),
293 | nn.LeakyReLU(0.2, inplace=True),
294 | ) # 2x
295 |
296 | self.upsample = nn.Sequential(
297 | nn.Conv2d(32, 32, 3, padding=1),
298 | nn.LeakyReLU(0.2, inplace=True),
299 | deconv(32, 2, 3, 1),
300 | )
301 |
302 | # edge loss
303 | self.edgeDetector = EdgeDetection(in_ch=2, out_ch=1, mid_ch=16)
304 |
305 | # Explicitly re-initialize the offsets of the deformable alignment modules
306 | for m in self.modules():
307 | if isinstance(m, SecondOrderDeformableAlignment):
308 | m.init_offset()
309 |
310 | if model_path is not None:
311 | print("Loading pretrained flow completion model...")
312 | ckpt = torch.load(model_path, map_location="cpu")
313 | self.load_state_dict(ckpt, strict=True)
314 |
315 | def forward(self, masked_flows, masks):
316 | # masked_flows: b t-1 2 h w
317 | # masks: b t-1 2 h w
318 | b, t, _, h, w = masked_flows.size()
319 | masked_flows = masked_flows.permute(0, 2, 1, 3, 4)
320 | masks = masks.permute(0, 2, 1, 3, 4)
321 |
322 | inputs = torch.cat((masked_flows, masks), dim=1)
323 |
324 | x = self.downsample(inputs)
325 |
326 | feat_e1 = self.encoder1(x)
327 | feat_e2 = self.encoder2(feat_e1) # b c t h w
328 | feat_mid = self.mid_dilation(feat_e2) # b c t h w
329 | feat_mid = feat_mid.permute(0, 2, 1, 3, 4) # b t c h w
330 |
331 | feat_prop = self.feat_prop_module(feat_mid)
332 | feat_prop = feat_prop.view(-1, 128, h // 8, w // 8) # b*t c h w
333 |
334 | _, c, _, h_f, w_f = feat_e1.shape
335 | feat_e1 = (
336 | feat_e1.permute(0, 2, 1, 3, 4).contiguous().view(-1, c, h_f, w_f)
337 | ) # b*t c h w
338 | feat_d2 = self.decoder2(feat_prop) + feat_e1
339 |
340 | _, c, _, h_f, w_f = x.shape
341 | x = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, c, h_f, w_f) # b*t c h w
342 |
343 | feat_d1 = self.decoder1(feat_d2)
344 |
345 | flow = self.upsample(feat_d1)
346 | if self.training:
347 | edge = self.edgeDetector(flow)
348 | edge = edge.view(b, t, 1, h, w)
349 | else:
350 | edge = None
351 |
352 | flow = flow.view(b, t, 2, h, w)
353 |
354 | return flow, edge
355 |
356 | def forward_bidirect_flow(self, masked_flows_bi, masks):
357 | """Args:
358 | masked_flows_bi: [masked_flows_f, masked_flows_b] | (b t-1 2 h w), (b t-1 2 h w)
359 | masks: b t 1 h w
360 | """
361 | masks_forward = masks[:, :-1, ...].contiguous()
362 | masks_backward = masks[:, 1:, ...].contiguous()
363 |
364 | # mask flow
365 | masked_flows_forward = masked_flows_bi[0] * (1 - masks_forward)
366 | masked_flows_backward = masked_flows_bi[1] * (1 - masks_backward)
367 |
368 | # -- completion --
369 | # forward
370 | pred_flows_forward, pred_edges_forward = self.forward(
371 | masked_flows_forward, masks_forward
372 | )
373 |
374 | # backward
375 | masked_flows_backward = torch.flip(masked_flows_backward, dims=[1])
376 | masks_backward = torch.flip(masks_backward, dims=[1])
377 | pred_flows_backward, pred_edges_backward = self.forward(
378 | masked_flows_backward, masks_backward
379 | )
380 | pred_flows_backward = torch.flip(pred_flows_backward, dims=[1])
381 | if self.training:
382 | pred_edges_backward = torch.flip(pred_edges_backward, dims=[1])
383 |
384 | return [pred_flows_forward, pred_flows_backward], [
385 | pred_edges_forward,
386 | pred_edges_backward,
387 | ]
388 |
389 | def combine_flow(self, masked_flows_bi, pred_flows_bi, masks):
390 | masks_forward = masks[:, :-1, ...].contiguous()
391 | masks_backward = masks[:, 1:, ...].contiguous()
392 |
393 | pred_flows_forward = pred_flows_bi[0] * masks_forward + masked_flows_bi[0] * (
394 | 1 - masks_forward
395 | )
396 | pred_flows_backward = pred_flows_bi[1] * masks_backward + masked_flows_bi[1] * (
397 | 1 - masks_backward
398 | )
399 |
400 | return pred_flows_forward, pred_flows_backward
401 |
--------------------------------------------------------------------------------
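A hedged shape check for RecurrentFlowCompleteNet above (illustrative only, not a file in this repository; weights are left randomly initialized, the import path is an assumption, and spatial dimensions must be divisible by 8):

import torch
from model.recurrent_flow_completion import RecurrentFlowCompleteNet  # path assumed

net = RecurrentFlowCompleteNet().eval()
b, t, h, w = 1, 6, 64, 64
flows_f = torch.randn(b, t - 1, 2, h, w)   # forward flows
flows_b = torch.randn(b, t - 1, 2, h, w)   # backward flows
masks = torch.randint(0, 2, (b, t, 1, h, w)).float()

with torch.no_grad():
    pred_bi, _ = net.forward_bidirect_flow((flows_f, flows_b), masks)
    comp_f, comp_b = net.combine_flow((flows_f, flows_b), pred_bi, masks)
print(comp_f.shape, comp_b.shape)  # both (1, 5, 2, 64, 64)
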
/model/vgg_arch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from torch import nn
5 | from torchvision.models import vgg
6 |
7 | VGG_PRETRAIN_PATH = "experiments/pretrained_models/vgg19-dcbb9e9d.pth"
8 | NAMES = {
9 | "vgg11": [
10 | "conv1_1",
11 | "relu1_1",
12 | "pool1",
13 | "conv2_1",
14 | "relu2_1",
15 | "pool2",
16 | "conv3_1",
17 | "relu3_1",
18 | "conv3_2",
19 | "relu3_2",
20 | "pool3",
21 | "conv4_1",
22 | "relu4_1",
23 | "conv4_2",
24 | "relu4_2",
25 | "pool4",
26 | "conv5_1",
27 | "relu5_1",
28 | "conv5_2",
29 | "relu5_2",
30 | "pool5",
31 | ],
32 | "vgg13": [
33 | "conv1_1",
34 | "relu1_1",
35 | "conv1_2",
36 | "relu1_2",
37 | "pool1",
38 | "conv2_1",
39 | "relu2_1",
40 | "conv2_2",
41 | "relu2_2",
42 | "pool2",
43 | "conv3_1",
44 | "relu3_1",
45 | "conv3_2",
46 | "relu3_2",
47 | "pool3",
48 | "conv4_1",
49 | "relu4_1",
50 | "conv4_2",
51 | "relu4_2",
52 | "pool4",
53 | "conv5_1",
54 | "relu5_1",
55 | "conv5_2",
56 | "relu5_2",
57 | "pool5",
58 | ],
59 | "vgg16": [
60 | "conv1_1",
61 | "relu1_1",
62 | "conv1_2",
63 | "relu1_2",
64 | "pool1",
65 | "conv2_1",
66 | "relu2_1",
67 | "conv2_2",
68 | "relu2_2",
69 | "pool2",
70 | "conv3_1",
71 | "relu3_1",
72 | "conv3_2",
73 | "relu3_2",
74 | "conv3_3",
75 | "relu3_3",
76 | "pool3",
77 | "conv4_1",
78 | "relu4_1",
79 | "conv4_2",
80 | "relu4_2",
81 | "conv4_3",
82 | "relu4_3",
83 | "pool4",
84 | "conv5_1",
85 | "relu5_1",
86 | "conv5_2",
87 | "relu5_2",
88 | "conv5_3",
89 | "relu5_3",
90 | "pool5",
91 | ],
92 | "vgg19": [
93 | "conv1_1",
94 | "relu1_1",
95 | "conv1_2",
96 | "relu1_2",
97 | "pool1",
98 | "conv2_1",
99 | "relu2_1",
100 | "conv2_2",
101 | "relu2_2",
102 | "pool2",
103 | "conv3_1",
104 | "relu3_1",
105 | "conv3_2",
106 | "relu3_2",
107 | "conv3_3",
108 | "relu3_3",
109 | "conv3_4",
110 | "relu3_4",
111 | "pool3",
112 | "conv4_1",
113 | "relu4_1",
114 | "conv4_2",
115 | "relu4_2",
116 | "conv4_3",
117 | "relu4_3",
118 | "conv4_4",
119 | "relu4_4",
120 | "pool4",
121 | "conv5_1",
122 | "relu5_1",
123 | "conv5_2",
124 | "relu5_2",
125 | "conv5_3",
126 | "relu5_3",
127 | "conv5_4",
128 | "relu5_4",
129 | "pool5",
130 | ],
131 | }
132 |
133 |
134 | def insert_bn(names):
135 | """Insert bn layer after each conv.
136 |
137 | Args:
138 | names (list): The list of layer names.
139 |
140 | Returns:
141 | list: The list of layer names with bn layers.
142 | """
143 | names_bn = []
144 | for name in names:
145 | names_bn.append(name)
146 | if "conv" in name:
147 | position = name.replace("conv", "")
148 | names_bn.append("bn" + position)
149 | return names_bn
150 |
151 |
152 | class VGGFeatureExtractor(nn.Module):
153 | """VGG network for feature extraction.
154 |
155 | In this implementation, we allow users to choose whether to normalize the
156 | input features and which type of VGG network to use. Note that the
157 | pretrained path must match the vgg type.
158 |
159 | Args:
160 | layer_name_list (list[str]): Forward function returns the corresponding
161 | features according to the layer_name_list.
162 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
163 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
164 | use_input_norm (bool): If True, normalize the input image. Importantly,
165 | the input feature must be in the range [0, 1]. Default: True.
166 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
167 | Default: False.
168 | requires_grad (bool): If true, the parameters of VGG network will be
169 | optimized. Default: False.
170 | remove_pooling (bool): If true, the max pooling operations in VGG net
171 | will be removed. Default: False.
172 | pooling_stride (int): The stride of max pooling operation. Default: 2.
173 | """
174 |
175 | def __init__(
176 | self,
177 | layer_name_list,
178 | vgg_type="vgg19",
179 | use_input_norm=True,
180 | range_norm=False,
181 | requires_grad=False,
182 | remove_pooling=False,
183 | pooling_stride=2,
184 | ):
185 | super(VGGFeatureExtractor, self).__init__()
186 |
187 | self.layer_name_list = layer_name_list
188 | self.use_input_norm = use_input_norm
189 | self.range_norm = range_norm
190 |
191 | self.names = NAMES[vgg_type.replace("_bn", "")]
192 | if "bn" in vgg_type:
193 | self.names = insert_bn(self.names)
194 |
195 | # only borrow layers that will be used to avoid unused params
196 | max_idx = 0
197 | for v in layer_name_list:
198 | idx = self.names.index(v)
199 | if idx > max_idx:
200 | max_idx = idx
201 |
202 | if os.path.exists(VGG_PRETRAIN_PATH):
203 | vgg_net = getattr(vgg, vgg_type)(pretrained=False)
204 | state_dict = torch.load(
205 | VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage
206 | )
207 | vgg_net.load_state_dict(state_dict)
208 | else:
209 | vgg_net = getattr(vgg, vgg_type)(pretrained=True)
210 |
211 | features = vgg_net.features[: max_idx + 1]
212 |
213 | modified_net = OrderedDict()
214 | for k, v in zip(self.names, features):
215 | if "pool" in k:
216 | # if remove_pooling is true, pooling operation will be removed
217 | if remove_pooling:
218 | continue
219 | else:
220 | # in some cases, we may want to change the default stride
221 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
222 | else:
223 | modified_net[k] = v
224 |
225 | self.vgg_net = nn.Sequential(modified_net)
226 |
227 | if not requires_grad:
228 | self.vgg_net.eval()
229 | for param in self.parameters():
230 | param.requires_grad = False
231 | else:
232 | self.vgg_net.train()
233 | for param in self.parameters():
234 | param.requires_grad = True
235 |
236 | if self.use_input_norm:
237 | # the mean is for image with range [0, 1]
238 | self.register_buffer(
239 | "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
240 | )
241 | # the std is for image with range [0, 1]
242 | self.register_buffer(
243 | "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
244 | )
245 |
246 | def forward(self, x):
247 | """Forward function.
248 |
249 | Args:
250 | x (Tensor): Input tensor with shape (n, c, h, w).
251 |
252 | Returns:
253 | Tensor: Forward results.
254 | """
255 | if self.range_norm:
256 | x = (x + 1) / 2
257 | if self.use_input_norm:
258 | x = (x - self.mean) / self.std
259 | output = {}
260 |
261 | for key, layer in self.vgg_net._modules.items():
262 | x = layer(x)
263 | if key in self.layer_name_list:
264 | output[key] = x.clone()
265 |
266 | return output
267 |
--------------------------------------------------------------------------------
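A hedged usage sketch for VGGFeatureExtractor above (illustrative only, not a file in this repository; the import path is an assumption, and torchvision's pretrained VGG19 weights are downloaded unless VGG_PRETRAIN_PATH already exists):

import torch
from model.vgg_arch import VGGFeatureExtractor  # import path is an assumption

extractor = VGGFeatureExtractor(layer_name_list=["relu1_1", "relu2_1", "relu3_1"])
x = torch.rand(1, 3, 224, 224)  # RGB batch in [0, 1] (use_input_norm expects this range)
with torch.no_grad():
    feats = extractor(x)
print({k: tuple(v.shape) for k, v in feats.items()})
# {'relu1_1': (1, 64, 224, 224), 'relu2_1': (1, 128, 112, 112), 'relu3_1': (1, 256, 56, 56)}
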
/propainter_inference.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | import numpy as np
4 | from numpy.typing import NDArray
5 |
6 | import torch
7 | from tqdm import tqdm
8 |
9 | from .model.modules.flow_comp_raft import RAFT_bi
10 | from .model.recurrent_flow_completion import (
11 | RecurrentFlowCompleteNet,
12 | )
13 | from .utils.model_utils import Models
14 | from .model.propainter import InpaintGenerator
15 |
16 |
17 | @dataclass
18 | class ProPainterConfig:
19 | ref_stride: int
20 | neighbor_length: int
21 | subvideo_length: int
22 | raft_iter: int
23 | fp16: str
24 | video_length: int
25 | device: torch.device
26 | process_size: tuple[int, int]
27 | use_half: bool = field(init=False)
28 |
29 | def __post_init__(self) -> None:
30 | """Derive use_half from the fp16 setting and the device."""
31 | self.use_half = self.fp16 == "enable"
32 | if self.device == torch.device("cpu"):
33 | self.use_half = False
34 |
35 |
36 | def get_ref_index(
37 | mid_neighbor_id: int,
38 | neighbor_ids: list[int],
39 | config: ProPainterConfig,
40 | ref_num: int = -1,
41 | ) -> list[int]:
42 | """Calculate reference indices for frames based on the provided parameters."""
43 | ref_index = []
44 | if ref_num == -1:
45 | for i in range(0, config.video_length, config.ref_stride):
46 | if i not in neighbor_ids:
47 | ref_index.append(i)
48 | else:
49 | start_idx = max(0, mid_neighbor_id - config.ref_stride * (ref_num // 2))
50 | end_idx = min(
51 | config.video_length, mid_neighbor_id + config.ref_stride * (ref_num // 2)
52 | )
53 | for i in range(start_idx, end_idx, config.ref_stride):
54 | if i not in neighbor_ids:
55 | if len(ref_index) > ref_num:
56 | break
57 | ref_index.append(i)
58 | return ref_index
59 |
60 |
61 | def compute_flow(
62 | raft_model: RAFT_bi, frames: torch.Tensor, config: ProPainterConfig
63 | ) -> tuple[torch.Tensor, torch.Tensor]:
64 | """Compute forward and backward optical flows using the RAFT model."""
65 | if frames.size(dim=-1) <= 640:
66 | short_clip_len = 12
67 | elif frames.size(dim=-1) <= 720:
68 | short_clip_len = 8
69 | elif frames.size(dim=-1) <= 1280:
70 | short_clip_len = 4
71 | else:
72 | short_clip_len = 2
73 |
74 | # use fp32 for RAFT
75 | if frames.size(dim=1) > short_clip_len:
76 | gt_flows_f_list, gt_flows_b_list = [], []
77 | for chunk in range(0, config.video_length, short_clip_len):
78 | end_f = min(config.video_length, chunk + short_clip_len)
79 | if chunk == 0:
80 | flows_f, flows_b = raft_model(
81 | frames[:, chunk:end_f], iters=config.raft_iter
82 | )
83 | else:
84 | flows_f, flows_b = raft_model(
85 | frames[:, chunk - 1 : end_f], iters=config.raft_iter
86 | )
87 |
88 | gt_flows_f_list.append(flows_f)
89 | gt_flows_b_list.append(flows_b)
90 | torch.cuda.empty_cache()
91 |
92 | gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
93 | gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
94 | gt_flows_bi = (gt_flows_f, gt_flows_b)
95 | else:
96 | gt_flows_bi = raft_model(frames, iters=config.raft_iter)
97 | torch.cuda.empty_cache()
98 |
99 | return gt_flows_bi
100 |
101 |
102 | def complete_flow(
103 | recurrent_flow_model: RecurrentFlowCompleteNet,
104 | flows_tuple: tuple[torch.Tensor, torch.Tensor],
105 | flow_masks: torch.Tensor,
106 | subvideo_length: int,
107 | ) -> tuple[torch.Tensor, torch.Tensor]:
108 | """Complete and refine optical flows using a recurrent flow completion model.
109 |
110 | This function processes optical flows in chunks if the total length exceeds the specified
111 | subvideo length. It uses a recurrent model to complete and refine the flows, combining
112 | forward and backward flows into bidirectional flows.
113 | """
114 | flow_length = flows_tuple[0].size(dim=1)
115 | if flow_length > subvideo_length:
116 | pred_flows_f_list, pred_flows_b_list = [], []
117 | pad_len = 5
118 | for f in range(0, flow_length, subvideo_length):
119 | s_f = max(0, f - pad_len)
120 | e_f = min(flow_length, f + subvideo_length + pad_len)
121 | pad_len_s = max(0, f) - s_f
122 | pad_len_e = e_f - min(flow_length, f + subvideo_length)
123 | pred_flows_bi_sub, _ = recurrent_flow_model.forward_bidirect_flow(
124 | (flows_tuple[0][:, s_f:e_f], flows_tuple[1][:, s_f:e_f]),
125 | flow_masks[:, s_f : e_f + 1],
126 | )
127 | pred_flows_bi_sub = recurrent_flow_model.combine_flow(
128 | (flows_tuple[0][:, s_f:e_f], flows_tuple[1][:, s_f:e_f]),
129 | pred_flows_bi_sub,
130 | flow_masks[:, s_f : e_f + 1],
131 | )
132 |
133 | pred_flows_f_list.append(
134 | pred_flows_bi_sub[0][:, pad_len_s : e_f - s_f - pad_len_e]
135 | )
136 | pred_flows_b_list.append(
137 | pred_flows_bi_sub[1][:, pad_len_s : e_f - s_f - pad_len_e]
138 | )
139 | torch.cuda.empty_cache()
140 |
141 | pred_flows_f = torch.cat(pred_flows_f_list, dim=1)
142 | pred_flows_b = torch.cat(pred_flows_b_list, dim=1)
143 |
144 | pred_flows_bi = (pred_flows_f, pred_flows_b)
145 |
146 | else:
147 | pred_flows_bi, _ = recurrent_flow_model.forward_bidirect_flow(
148 | flows_tuple, flow_masks
149 | )
150 | pred_flows_bi = recurrent_flow_model.combine_flow(
151 | flows_tuple, pred_flows_bi, flow_masks
152 | )
153 |
154 | torch.cuda.empty_cache()
155 |
156 | return pred_flows_bi
157 |
158 |
159 | def image_propagation(
160 | inpaint_model: InpaintGenerator,
161 | frames: torch.Tensor,
162 | masks_dilated: torch.Tensor,
163 | prediction_flows: tuple[torch.Tensor, torch.Tensor],
164 | config: ProPainterConfig,
165 | ) -> tuple[torch.Tensor, torch.Tensor]:
166 | """Propagate inpainted images across video frames.
167 |
168 | If the video length exceeds a defined threshold, the process is segmented and handled in chunks.
169 | """
170 | process_width, process_height = config.process_size
171 | masked_frames = frames * (1 - masks_dilated)
172 | subvideo_length_img_prop = min(
173 | 100, config.subvideo_length
174 | ) # cap the sub-video length for image propagation at 100 frames
175 | if config.video_length > subvideo_length_img_prop:
176 | updated_frames_list, updated_masks_list = [], []
177 | pad_len = 10
178 | for f in range(0, config.video_length, subvideo_length_img_prop):
179 | s_f = max(0, f - pad_len)
180 | e_f = min(config.video_length, f + subvideo_length_img_prop + pad_len)
181 | pad_len_s = max(0, f) - s_f
182 | pad_len_e = e_f - min(config.video_length, f + subvideo_length_img_prop)
183 | b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
184 | pred_flows_bi_sub = (
185 | prediction_flows[0][:, s_f : e_f - 1],
186 | prediction_flows[1][:, s_f : e_f - 1],
187 | )
188 | prop_imgs_sub, updated_local_masks_sub = inpaint_model.img_propagation(
189 | masked_frames[:, s_f:e_f],
190 | pred_flows_bi_sub,
191 | masks_dilated[:, s_f:e_f],
192 | "nearest",
193 | )
194 | updated_frames_sub = (
195 | frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f])
196 | + prop_imgs_sub.view(b, t, 3, process_height, process_width)
197 | * masks_dilated[:, s_f:e_f]
198 | )
199 | updated_masks_sub = updated_local_masks_sub.view(
200 | b, t, 1, process_height, process_width
201 | )
202 |
203 | updated_frames_list.append(
204 | updated_frames_sub[:, pad_len_s : e_f - s_f - pad_len_e]
205 | )
206 | updated_masks_list.append(
207 | updated_masks_sub[:, pad_len_s : e_f - s_f - pad_len_e]
208 | )
209 | torch.cuda.empty_cache()
210 |
211 | updated_frames = torch.cat(updated_frames_list, dim=1)
212 | updated_masks = torch.cat(updated_masks_list, dim=1)
213 | else:
214 | b, t, _, _, _ = masks_dilated.size()
215 | prop_imgs, updated_local_masks = inpaint_model.img_propagation(
216 | masked_frames, prediction_flows, masks_dilated, "nearest"
217 | )
218 | updated_frames = (
219 | frames * (1 - masks_dilated)
220 | + prop_imgs.view(b, t, 3, process_height, process_width) * masks_dilated
221 | )
222 | updated_masks = updated_local_masks.view(b, t, 1, process_height, process_width)
223 | torch.cuda.empty_cache()
224 |
225 | return updated_frames, updated_masks
226 |
227 |
228 | def feature_propagation(
229 | inpaint_model: InpaintGenerator,
230 | updated_frames: torch.Tensor,
231 | updated_masks: torch.Tensor,
232 | masks_dilated: torch.Tensor,
233 | prediction_flows: tuple[torch.Tensor, torch.Tensor],
234 | original_frames: list[NDArray],
235 | config: ProPainterConfig,
236 | ) -> list[NDArray]:
237 | """Propagate inpainted features across video frames.
238 |
239 | The process is segmented and handled in chunks if the video length exceeds a defined threshold.
240 | """
241 |     # TODO: Refactor; this function may be doing too much.
242 | process_width, process_height = config.process_size
243 |
244 |     # TODO: Refactor how composed_frames is initialized.
245 | composed_frames: list[NDArray | None] = [None] * config.video_length
246 |
247 | neighbor_stride = config.neighbor_length // 2
248 | ref_num = (
249 | config.subvideo_length // config.ref_stride
250 | if config.video_length > config.subvideo_length
251 | else -1
252 | )
253 |
254 | for f in tqdm(range(0, config.video_length, neighbor_stride)):
255 | neighbor_ids = list(
256 | range(
257 | max(0, f - neighbor_stride),
258 | min(config.video_length, f + neighbor_stride + 1),
259 | )
260 | )
261 | ref_ids = get_ref_index(f, neighbor_ids, config, ref_num)
262 | selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
263 | selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
264 | if config.use_half:
265 | selected_masks = selected_masks.half()
266 | selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
267 | selected_pred_flows_bi = (
268 | prediction_flows[0][:, neighbor_ids[:-1], :, :, :],
269 | prediction_flows[1][:, neighbor_ids[:-1], :, :, :],
270 | )
271 | with torch.no_grad():
272 | # 1.0 indicates mask
273 | l_t = len(neighbor_ids)
274 |
275 | pred_img = inpaint_model(
276 | selected_imgs,
277 | selected_pred_flows_bi,
278 | selected_masks,
279 | selected_update_masks,
280 | l_t,
281 | )
282 |
283 | pred_img = pred_img.view(-1, 3, process_height, process_width)
284 |
285 | pred_img = (pred_img + 1) / 2
286 | pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
287 | binary_masks = (
288 | masks_dilated[0, neighbor_ids, :, :, :]
289 | .cpu()
290 | .permute(0, 2, 3, 1)
291 | .numpy()
292 | .astype(np.uint8)
293 | )
294 | for i, idx in enumerate(neighbor_ids):
295 | # idx = neighbor_ids[i]
296 |                 img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] + (
297 |                     original_frames[idx] * (1 - binary_masks[i])
298 |                 )
299 | if composed_frames[idx] is None:
300 | composed_frames[idx] = img
301 | else:
302 | composed_frames[idx] = (
303 | composed_frames[idx].astype(np.float32) * 0.5
304 | + img.astype(np.float32) * 0.5
305 | )
306 |
307 | composed_frames[idx] = composed_frames[idx].astype(np.uint8)
308 |
309 | torch.cuda.empty_cache()
310 |
311 | return composed_frames
312 |
313 |
314 | def process_inpainting(
315 | models: Models,
316 | frames: torch.Tensor,
317 | flow_masks: torch.Tensor,
318 | masks_dilated: torch.Tensor,
319 | config: ProPainterConfig,
320 | ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
321 | """Apply inpainting on video using recurrent flow and ProPainter model."""
322 | with torch.no_grad():
323 | gt_flows_bi = compute_flow(models.raft_model, frames, config)
324 |
325 | if config.use_half:
326 | frames, flow_masks, masks_dilated = (
327 | frames.half(),
328 | flow_masks.half(),
329 | masks_dilated.half(),
330 | )
331 | gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
332 |
333 | pred_flows_bi = complete_flow(
334 | models.flow_model, gt_flows_bi, flow_masks, config.subvideo_length
335 | )
336 |
337 | updated_frames, updated_masks = image_propagation(
338 | models.inpaint_model, frames, masks_dilated, pred_flows_bi, config
339 | )
340 |
341 | return updated_frames, updated_masks, pred_flows_bi
342 |
--------------------------------------------------------------------------------
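
A note on the chunking pattern used by complete_flow, image_propagation, and feature_propagation above: long videos are split into sub-videos of roughly subvideo_length frames, each chunk is extended by a few padding frames on either side so the models see temporal context at the seams, and the padded frames are trimmed again before the per-chunk results are concatenated. Below is a minimal, self-contained sketch of that index arithmetic only; the helper chunk_indices is hypothetical and not part of this repository.

    import torch

    def chunk_indices(length: int, chunk: int, pad: int) -> list[tuple[int, int, int, int]]:
        """Yield (start, end, trim_start, trim_end) for overlapping chunks of a sequence."""
        out = []
        for f in range(0, length, chunk):
            s = max(0, f - pad)                  # padded start of the sub-video
            e = min(length, f + chunk + pad)     # padded end of the sub-video
            trim_s = f - s                       # padding frames to drop at the front
            trim_e = e - min(length, f + chunk)  # padding frames to drop at the back
            out.append((s, e, trim_s, trim_e))
        return out

    # Example: a 12-frame sequence processed in chunks of 5 with 2 frames of padding.
    flows = torch.randn(1, 12, 2, 64, 64)        # dummy (B, T, 2, H, W) flow tensor
    pieces = []
    for s, e, trim_s, trim_e in chunk_indices(12, 5, 2):
        sub = flows[:, s:e]                      # padded sub-video
        pieces.append(sub[:, trim_s : sub.size(1) - trim_e])
    assert torch.cat(pieces, dim=1).shape == flows.shape

The same start/end/trim values are what the s_f, e_f, pad_len_s, and pad_len_e variables compute inside the functions above.
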
/propainter_nodes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from comfy import model_management
3 |
4 | from .propainter_inference import (
5 | ProPainterConfig,
6 | feature_propagation,
7 | process_inpainting,
8 | )
9 | from .utils.image_utils import (
10 | ImageConfig,
11 | ImageOutpaintConfig,
12 | convert_image_to_frames,
13 | handle_output,
14 | prepare_frames_and_masks,
15 | extrapolation,
16 | prepare_frames_and_masks_for_outpaint,
17 | )
18 | from .utils.model_utils import initialize_models
19 |
20 |
21 | def check_inputs(frames: torch.Tensor, masks: torch.Tensor) -> None:
22 | if frames.size(dim=0) <= 1:
23 | raise Exception(f"""Image length must be greater than 1, but got:
24 | Image length: ({frames.size(dim=0)})""")
25 | if frames.size(dim=0) != masks.size(dim=0) and masks.size(dim=0) != 1:
26 |         raise Exception(f"""Image and Mask must have the same length, or Mask must have length 1, but got:
27 | Image length: {frames.size(dim=0)}
28 | Mask length: {masks.size(dim=0)}""")
29 |
30 | if frames.size(dim=1) != masks.size(dim=1) or frames.size(dim=2) != masks.size(
31 | dim=2
32 | ):
33 | raise Exception(f"""Image and Mask must have the same dimensions, but got:
34 | Image: ({frames.size(dim=1)}, {frames.size(dim=2)})
35 | Mask: ({masks.size(dim=1)}, {masks.size(dim=2)})""")
36 |
37 |
38 | class ProPainterInpaint:
39 | """ComfyUI Node for performing inpainting on video frames using ProPainter."""
40 |
41 | def __init__(self):
42 | pass
43 |
44 | @classmethod
45 | def INPUT_TYPES(s):
46 | return {
47 | "required": {
48 | "image": ("IMAGE",), # --video
49 | "mask": ("MASK",), # --mask
50 | "width": ("INT", {"default": 640, "min": 0, "max": 2560}), # --width
51 | "height": ("INT", {"default": 360, "min": 0, "max": 2560}), # --height
52 | "mask_dilates": (
53 | "INT",
54 | {"default": 5, "min": 0, "max": 100},
55 | ), # --mask_dilates
56 | "flow_mask_dilates": (
57 | "INT",
58 | {"default": 8, "min": 0, "max": 100},
59 |                 ),  # argument does not exist in the original ProPainter code
60 | "ref_stride": (
61 | "INT",
62 | {"default": 10, "min": 1, "max": 100},
63 | ), # --ref_stride
64 | "neighbor_length": (
65 | "INT",
66 | {"default": 10, "min": 2, "max": 300},
67 | ), # --neighbor_length
68 | "subvideo_length": (
69 | "INT",
70 | {"default": 80, "min": 1, "max": 300},
71 | ), # --subvideo_length
72 | "raft_iter": (
73 | "INT",
74 | {"default": 20, "min": 1, "max": 100},
75 | ), # --raft_iter
76 | "fp16": (["enable", "disable"],), # --fp16
77 | },
78 | }
79 |
80 | RETURN_TYPES = (
81 | "IMAGE",
82 | "MASK",
83 | "MASK",
84 | )
85 | RETURN_NAMES = (
86 | "IMAGE",
87 | "FLOW_MASK",
88 | "MASK_DILATE",
89 | )
90 | FUNCTION = "propainter_inpainting"
91 | CATEGORY = "ProPainter"
92 |
93 | def propainter_inpainting(
94 | self,
95 | image: torch.Tensor,
96 | mask: torch.Tensor,
97 | width: int,
98 | height: int,
99 | mask_dilates: int,
100 | flow_mask_dilates: int,
101 | ref_stride: int,
102 | neighbor_length: int,
103 | subvideo_length: int,
104 | raft_iter: int,
105 | fp16: str,
106 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
107 |         """Perform inpainting on the input images using ProPainter model inference."""
108 | check_inputs(image, mask)
109 | device = model_management.get_torch_device()
110 |         # TODO: Check if this conversion from Torch to PIL is really necessary.
111 | frames = convert_image_to_frames(image)
112 | video_length = image.size(dim=0)
113 | input_size = frames[0].size
114 |
115 | image_config = ImageConfig(
116 | width, height, mask_dilates, flow_mask_dilates, input_size, video_length
117 | )
118 | inpaint_config = ProPainterConfig(
119 | ref_stride,
120 | neighbor_length,
121 | subvideo_length,
122 | raft_iter,
123 | fp16,
124 | video_length,
125 | device,
126 | image_config.process_size,
127 | )
128 |
129 | frames_tensor, flow_masks_tensor, masks_dilated_tensor, original_frames = (
130 | prepare_frames_and_masks(frames, mask, image_config, device)
131 | )
132 |
133 | models = initialize_models(device, inpaint_config.fp16)
134 | print(f"\nProcessing {inpaint_config.video_length} frames...")
135 |
136 | updated_frames, updated_masks, pred_flows_bi = process_inpainting(
137 | models,
138 | frames_tensor,
139 | flow_masks_tensor,
140 | masks_dilated_tensor,
141 | inpaint_config,
142 | )
143 |
144 | composed_frames = feature_propagation(
145 | models.inpaint_model,
146 | updated_frames,
147 | updated_masks,
148 | masks_dilated_tensor,
149 | pred_flows_bi,
150 | original_frames,
151 | inpaint_config,
152 | )
153 |
154 | return handle_output(composed_frames, flow_masks_tensor, masks_dilated_tensor)
155 |
156 |
157 | class ProPainterOutpaint:
158 | """ComfyUI Node for performing outpainting on video frames using ProPainter."""
159 |
160 | def __init__(self):
161 | pass
162 |
163 | @classmethod
164 | def INPUT_TYPES(s):
165 | return {
166 | "required": {
167 | "image": ("IMAGE",), # --video
168 | "width": ("INT", {"default": 640, "min": 0, "max": 2560}), # --width
169 | "height": ("INT", {"default": 360, "min": 0, "max": 2560}), # --height
170 | "width_scale": (
171 | "FLOAT",
172 | {
173 | "default": 1.2,
174 | "min": 0.0,
175 | "max": 10.0,
176 | "step": 0.01,
177 | },
178 | ),
179 | "height_scale": (
180 | "FLOAT",
181 | {
182 | "default": 1.0,
183 | "min": 0.0,
184 | "max": 10.0,
185 | "step": 0.01,
186 | },
187 | ),
188 | "mask_dilates": (
189 | "INT",
190 | {"default": 5, "min": 0, "max": 100},
191 | ), # --mask_dilates
192 | "flow_mask_dilates": (
193 | "INT",
194 | {"default": 8, "min": 0, "max": 100},
195 |                 ),  # argument does not exist in the original ProPainter code
196 | "ref_stride": (
197 | "INT",
198 | {"default": 10, "min": 1, "max": 100},
199 | ), # --ref_stride
200 | "neighbor_length": (
201 | "INT",
202 | {"default": 10, "min": 2, "max": 300},
203 | ), # --neighbor_length
204 | "subvideo_length": (
205 | "INT",
206 | {"default": 80, "min": 1, "max": 300},
207 | ), # --subvideo_length
208 | "raft_iter": (
209 | "INT",
210 | {"default": 20, "min": 1, "max": 100},
211 | ), # --raft_iter
212 | "fp16": (["enable", "disable"],), # --fp16
213 | },
214 | }
215 |
216 | RETURN_TYPES = (
217 | "IMAGE",
218 | "MASK",
219 | "INT",
220 | "INT",
221 | )
222 | RETURN_NAMES = (
223 | "IMAGE",
224 | "OUTPAINT_MASK",
225 | "output_width",
226 | "output_height",
227 | )
228 | FUNCTION = "propainter_outpainting"
229 | CATEGORY = "ProPainter"
230 |
231 | def propainter_outpainting(
232 | self,
233 | image: torch.Tensor,
234 | width: int,
235 | height: int,
236 | width_scale: float,
237 | height_scale: float,
238 | mask_dilates: int,
239 | flow_mask_dilates: int,
240 | ref_stride: int,
241 | neighbor_length: int,
242 | subvideo_length: int,
243 | raft_iter: int,
244 | fp16: str,
245 | ) -> tuple[torch.Tensor, torch.Tensor, int, int]:
246 |         """Perform outpainting on the input images using ProPainter model inference."""
247 | device = model_management.get_torch_device()
248 |         # TODO: Check if this conversion from Torch to PIL is really necessary.
249 | frames = convert_image_to_frames(image)
250 | video_length = image.size(dim=0)
251 | input_size = frames[0].size
252 |
253 | image_config = ImageOutpaintConfig(
254 | width,
255 | height,
256 | mask_dilates,
257 | flow_mask_dilates,
258 | input_size,
259 | video_length,
260 | width_scale,
261 | height_scale,
262 | )
263 |
264 | outpaint_config = ProPainterConfig(
265 | ref_stride,
266 | neighbor_length,
267 | subvideo_length,
268 | raft_iter,
269 | fp16,
270 | video_length,
271 | device,
272 | image_config.outpaint_size,
273 | )
274 |
275 |         padded_frames, padded_flow_masks, padded_masks_dilated = extrapolation(
276 | frames, image_config
277 | )
278 |
279 | frames_tensor, flow_masks_tensor, masks_dilated_tensor, original_frames = (
280 | prepare_frames_and_masks_for_outpaint(
281 |                 padded_frames, padded_flow_masks, padded_masks_dilated, device
282 | )
283 | )
284 |
285 | models = initialize_models(device, outpaint_config.fp16)
286 | print(f"\nProcessing {outpaint_config.video_length} frames...")
287 |
288 | updated_frames, updated_masks, pred_flows_bi = process_inpainting(
289 | models,
290 | frames_tensor,
291 | flow_masks_tensor,
292 | masks_dilated_tensor,
293 | outpaint_config,
294 | )
295 |
296 | composed_frames = feature_propagation(
297 | models.inpaint_model,
298 | updated_frames,
299 | updated_masks,
300 | masks_dilated_tensor,
301 | pred_flows_bi,
302 | original_frames,
303 | outpaint_config,
304 | )
305 |
306 | output_frames, output_masks, _ = handle_output(
307 | composed_frames, flow_masks_tensor, masks_dilated_tensor
308 | )
309 | output_width, output_height = outpaint_config.process_size
310 | return output_frames, output_masks, output_width, output_height
311 |
312 |
313 | NODE_CLASS_MAPPINGS = {
314 | "ProPainterInpaint": ProPainterInpaint,
315 | "ProPainterOutpaint": ProPainterOutpaint,
316 | }
317 |
318 | NODE_DISPLAY_NAME_MAPPINGS = {
319 | "ProPainterInpaint": "ProPainter Inpainting",
320 | "ProPainterOutpaint": "ProPainter Outpainting",
321 | }
322 |
--------------------------------------------------------------------------------
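
For reference, the inpainting node above can also be exercised outside of a ComfyUI graph by calling its FUNCTION method directly. The sketch below uses synthetic data and assumes the usual ComfyUI tensor conventions (IMAGE as a float tensor of shape (B, H, W, C) in [0, 1], MASK as (B, H, W)); it still needs the ComfyUI runtime importable for comfy.model_management, downloads the ProPainter weights on first use, and the import path depends on how the custom node package is installed.

    import torch
    from ComfyUI_ProPainter_Nodes.propainter_nodes import ProPainterInpaint  # adjust to your install

    frames = torch.rand(16, 360, 640, 3)   # 16 dummy RGB frames in [0, 1]
    masks = torch.zeros(16, 360, 640)
    masks[:, 100:200, 200:400] = 1.0       # region to remove in every frame

    node = ProPainterInpaint()
    images, flow_mask, mask_dilate = node.propainter_inpainting(
        image=frames,
        mask=masks,
        width=640,
        height=360,
        mask_dilates=5,
        flow_mask_dilates=8,
        ref_stride=10,
        neighbor_length=10,
        subvideo_length=80,
        raft_iter=20,
        fp16="disable",
    )
    print(images.shape)  # torch.Size([16, 360, 640, 3]) of inpainted frames

The three returned tensors correspond to the node's IMAGE, FLOW_MASK, and MASK_DILATE outputs.
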
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui_propainter_nodes"
3 | description = "ComfyUI custom node implementation of [a/ProPainter](https://github.com/sczhou/ProPainter) framework for video inpainting."
4 | version = "1.0.0"
5 | license = { file = "LICENSE.txt" }
6 | dependencies = ["opencv-python"]
7 |
8 | [project.urls]
9 | Repository = "https://github.com/daniabib/ComfyUI_ProPainter_Nodes"
10 | # Used by Comfy Registry https://comfyregistry.org
11 |
12 | [tool.comfy]
13 | PublisherId = "daniabib"
14 | DisplayName = "ComfyUI_ProPainter_Nodes"
15 | Icon = ""
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daniabib/ComfyUI_ProPainter_Nodes/9c27d5a0a508bae3296a1886ad026d8d4139d66c/utils/__init__.py
--------------------------------------------------------------------------------
/utils/download_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from urllib.parse import urljoin
3 |
4 | from torch.hub import download_url_to_file
5 |
6 |
7 | def load_file_from_url(
8 | url: str,
9 | model_dir: Path | None = None,
10 | progress: bool = True,
11 | file_name: str | None = None,
12 | ) -> Path:
13 |     """Load a file from an HTTP URL, downloading it first if it is not already cached locally."""
14 | file_name = Path(file_name)
15 | model_dir.mkdir(exist_ok=True)
16 | cached_file = model_dir / file_name
17 | if not cached_file.exists():
18 | print(f'Downloading: "{url}" to {cached_file}\n')
19 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
20 | return cached_file
21 |
22 |
23 | def download_model(model_url: str, model_name: str) -> Path:
24 | """Downloads a model from a URL and returns the local path to the downloaded model."""
25 | base_dir = Path(__file__).parents[1].resolve()
26 | target_dir = base_dir / "weights"
27 | return load_file_from_url(
28 | url=urljoin(model_url, model_name),
29 | model_dir=target_dir,
30 | progress=True,
31 | file_name=model_name,
32 | )
33 |
--------------------------------------------------------------------------------
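
A quick sketch of how these helpers are used. The URL and file name below are the same constants that utils/model_utils.py passes in; running it really downloads the RAFT checkpoint into a weights/ directory next to the package, and the exact import path depends on how the repository is installed.

    from utils.download_utils import download_model  # import path may differ in your setup

    raft_path = download_model(
        "https://github.com/sczhou/ProPainter/releases/download/v0.1.0/",
        "raft-things.pth",
    )
    print(raft_path)  # .../weights/raft-things.pth, reused from the cache on later calls
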
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | import numpy as np
4 | import scipy
5 | import torch
6 | from numpy.typing import NDArray
7 | from PIL import Image
8 | from torchvision import transforms
9 | from torchvision.transforms.functional import to_pil_image
10 |
11 |
12 | @dataclass
13 | class ImageConfig:
14 | width: int
15 | height: int
16 | mask_dilates: int
17 | flow_mask_dilates: int
18 | input_size: tuple[int, int]
19 | video_length: int
20 | process_size: tuple[int, int] = field(init=False)
21 |
22 | def __post_init__(self) -> None:
23 | """Initialize process size."""
24 | self.process_size = (
25 | self.width - self.width % 8,
26 | self.height - self.height % 8,
27 | )
28 |
29 |
30 | @dataclass
31 | class ImageOutpaintConfig(ImageConfig):
32 | width_scale: float
33 | height_scale: float
34 | process_size: tuple[int, int] = field(init=False)
35 | outpaint_size: tuple[int, int] = field(init=False)
36 |
37 | # TODO: Refactor
38 | def __post_init__(self) -> None:
39 | """Initialize output size for outpainting."""
40 | self.process_size = (
41 | self.width - self.width % 8,
42 | self.height - self.height % 8,
43 | )
44 | pad_image_width = int(self.width_scale * self.width)
45 | pad_image_height = int(self.height_scale * self.height)
46 | self.outpaint_size = (
47 | pad_image_width - pad_image_width % 8,
48 | pad_image_height - pad_image_height % 8,
49 | )
50 |
51 |
52 | class Stack:
53 | """Stack images based on number of channels."""
54 |
55 | def __init__(self, roll=False):
56 | self.roll = roll
57 |
58 | def __call__(self, img_group) -> NDArray:
59 | mode = img_group[0].mode
60 | if mode == "1":
61 | img_group = [img.convert("L") for img in img_group]
62 | mode = "L"
63 | if mode == "L":
64 | return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
65 | if mode == "RGB":
66 | if self.roll:
67 | return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
68 | return np.stack(img_group, axis=2)
69 | raise NotImplementedError(f"Image mode {mode}")
70 |
71 |
72 | class ToTorchFormatTensor:
73 | """Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch FloatTensor of shape (C x H x W) in the range [0.0, 1.0]."""
74 |
75 | # TODO: Check if this function is necessary with comfyUI workflow.
76 | def __init__(self, div=True):
77 | self.div = div
78 |
79 | def __call__(self, pic) -> torch.Tensor:
80 | if isinstance(pic, np.ndarray):
81 | # numpy img: [L, C, H, W]
82 | img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
83 | else:
84 | # handle PIL Image
85 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
86 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
87 | # put it from HWC to CHW format
88 | # yikes, this transpose takes 80% of the loading time/CPU
89 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
90 | img = img.float().div(255) if self.div else img.float()
91 | return img
92 |
93 |
94 | def to_tensors():
95 | return transforms.Compose([Stack(), ToTorchFormatTensor()])
96 |
97 |
98 | def resize_images(images: list[Image.Image], config: ImageConfig) -> list[Image.Image]:
99 | """Resizes each image in the list to a new size divisible by 8."""
100 | if config.process_size != config.input_size:
101 | images = [f.resize(config.process_size) for f in images]
102 |
103 | return images
104 |
105 |
106 | def convert_image_to_frames(images: torch.Tensor) -> list[Image.Image]:
107 | """Convert a batch of PyTorch tensors into a list of PIL Image frames."""
108 | frames = []
109 | for image in images:
110 | torch_frame = image.detach().cpu()
111 | np_frame = torch_frame.numpy()
112 | np_frame = (np_frame * 255).clip(0, 255).astype(np.uint8)
113 | frame = Image.fromarray(np_frame)
114 | frames.append(frame)
115 |
116 | return frames
117 |
118 |
119 | def binary_mask(mask: np.ndarray, th: float = 0.1) -> np.ndarray:
120 | mask[mask > th] = 1
121 | mask[mask <= th] = 0
122 |
123 | return mask
124 |
125 |
126 | def convert_mask_to_frames(images: torch.Tensor) -> list[Image.Image]:
127 | """Convert a batch of PyTorch tensors into a list of PIL Image frames."""
128 | frames = []
129 | for image in images:
130 | image = image.detach().cpu()
131 |
132 | # Adjust the scaling based on the data type
133 | if image.dtype == torch.float32:
134 | image = (image * 255).clamp(0, 255).byte()
135 |
136 | frame: Image.Image = to_pil_image(image)
137 | frames.append(frame)
138 |
139 | return frames
140 |
141 |
142 | def read_masks(
143 | masks: torch.Tensor, config: ImageConfig
144 | ) -> tuple[list[Image.Image], list[Image.Image]]:
145 |     """Convert the mask tensor into per-frame flow masks and dilated inpainting masks as PIL images."""
146 | mask_images = convert_mask_to_frames(masks)
147 | mask_images = resize_images(mask_images, config)
148 | masks_dilated: list[Image.Image] = []
149 | flow_masks: list[Image.Image] = []
150 |
151 | for mask_image in mask_images:
152 | mask_array = np.array(mask_image.convert("L"))
153 |
154 |         # Dilate the flow mask by flow_mask_dilates pixels (8 by default) so that all known pixels are trustworthy.
155 | if config.flow_mask_dilates > 0:
156 | flow_mask_img = scipy.ndimage.binary_dilation(
157 | mask_array, iterations=config.flow_mask_dilates
158 | ).astype(np.uint8)
159 | else:
160 | flow_mask_img = binary_mask(mask_array).astype(np.uint8)
161 | flow_masks.append(Image.fromarray(flow_mask_img * 255))
162 |
163 | if config.mask_dilates > 0:
164 | mask_array = scipy.ndimage.binary_dilation(
165 | mask_array, iterations=config.mask_dilates
166 | ).astype(np.uint8)
167 | else:
168 | mask_array = binary_mask(mask_array).astype(np.uint8)
169 | masks_dilated.append(Image.fromarray(mask_array * 255))
170 |
171 | if len(mask_images) == 1:
172 | flow_masks = flow_masks * config.video_length
173 | masks_dilated = masks_dilated * config.video_length
174 |
175 | return flow_masks, masks_dilated
176 |
177 |
178 | def prepare_frames_and_masks(
179 | frames: list[Image.Image],
180 | mask: torch.Tensor,
181 | config: ImageConfig,
182 | device: torch.device,
183 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[NDArray]]:
184 | frames = resize_images(frames, config)
185 |
186 | flow_masks, masks_dilated = read_masks(mask, config)
187 |
188 | original_frames = [np.array(f) for f in frames]
189 | frames_tensor: torch.Tensor = to_tensors()(frames).unsqueeze(0) * 2 - 1
190 | flow_masks_tensor: torch.Tensor = to_tensors()(flow_masks).unsqueeze(0)
191 | masks_dilated_tensor: torch.Tensor = to_tensors()(masks_dilated).unsqueeze(0)
192 | frames_tensor, flow_masks_tensor, masks_dilated_tensor = (
193 | frames_tensor.to(device),
194 | flow_masks_tensor.to(device),
195 | masks_dilated_tensor.to(device),
196 | )
197 | return frames_tensor, flow_masks_tensor, masks_dilated_tensor, original_frames
198 |
199 |
200 | def extrapolation(
201 | resized_frames: list[Image.Image], image_config: ImageOutpaintConfig
202 | ) -> tuple[list[Image.Image], list[Image.Image], list[Image.Image]]:
203 | """Prepares the data for video outpainting."""
204 | resized_frames = resize_images(resized_frames, image_config)
205 |
206 | # input_width, input_height = image_config.input_size
207 | resized_width, resized_height = resized_frames[0].size
208 | pad_image_width, pad_image_height = image_config.outpaint_size
209 |
210 | # Defines new FOV.
211 | width_start = int((pad_image_width - resized_width) / 2)
212 | height_start = int((pad_image_height - resized_height) / 2)
213 |
214 | # Extrapolates the FOV for video.
215 | extrapolated_frames = []
216 | for v in resized_frames:
217 |         frame = np.zeros((pad_image_height, pad_image_width, 3), dtype=np.uint8)
218 | frame[
219 | height_start : height_start + resized_height,
220 | width_start : width_start + resized_width,
221 | :,
222 | ] = v
223 | extrapolated_frames.append(Image.fromarray(frame))
224 |
225 | # Generates the mask for missing region.
226 | masks_dilated = []
227 | flow_masks = []
228 |
229 | dilate_h = 4 if height_start > 10 else 0
230 | dilate_w = 4 if width_start > 10 else 0
231 |     mask = np.ones((pad_image_height, pad_image_width), dtype=np.uint8)
232 |
233 | mask[
234 | height_start + dilate_h : height_start + resized_height - dilate_h,
235 | width_start + dilate_w : width_start + resized_width - dilate_w,
236 | ] = 0
237 | flow_masks.append(Image.fromarray(mask * 255))
238 |
239 | mask[
240 | height_start : height_start + resized_height,
241 | width_start : width_start + resized_width,
242 | ] = 0
243 | masks_dilated.append(Image.fromarray(mask * 255))
244 |
245 | flow_masks = flow_masks * image_config.video_length
246 | masks_dilated = masks_dilated * image_config.video_length
247 |
248 | return (
249 | extrapolated_frames,
250 | flow_masks,
251 | masks_dilated,
252 | )
253 |
254 |
255 | def prepare_frames_and_masks_for_outpaint(
256 | frames: list[Image.Image],
257 | flow_masks: list[Image.Image],
258 | masks_dilated: list[Image.Image],
259 | device: torch.device,
260 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[NDArray]]:
261 | # flow_masks, masks_dilated = read_masks(mask, config)
262 |
263 | # original_frames = [np.array(f).astype(np.uint8) for f in frames]
264 | original_frames = [np.array(f) for f in frames]
265 | frames_tensor: torch.Tensor = to_tensors()(frames).unsqueeze(0) * 2 - 1
266 | flow_masks_tensor: torch.Tensor = to_tensors()(flow_masks).unsqueeze(0)
267 | masks_dilated_tensor: torch.Tensor = to_tensors()(masks_dilated).unsqueeze(0)
268 | frames_tensor, flow_masks_tensor, masks_dilated_tensor = (
269 | frames_tensor.to(device),
270 | flow_masks_tensor.to(device),
271 | masks_dilated_tensor.to(device),
272 | )
273 | return frames_tensor, flow_masks_tensor, masks_dilated_tensor, original_frames
274 |
275 |
276 | def handle_output(
277 | composed_frames: list[NDArray],
278 | flow_masks: torch.Tensor,
279 | masks_dilated: torch.Tensor,
280 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
281 | output_frames = [
282 | torch.from_numpy(frame.astype(np.float32) / 255.0) for frame in composed_frames
283 | ]
284 |
285 | output_images = torch.stack(output_frames)
286 |
287 | output_flow_masks = flow_masks.squeeze()
288 | output_masks_dilated = masks_dilated.squeeze()
289 |
290 | return output_images, output_flow_masks, output_masks_dilated
291 |
--------------------------------------------------------------------------------
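
To make the shape conventions above concrete: process_size rounds the requested resolution down to the nearest multiple of 8, and to_tensors() turns a list of PIL frames into a (1, T, C, H, W) float tensor in [0, 1]. A small sketch with synthetic frames (not part of the repository; the import assumes the repository root is on sys.path):

    import numpy as np
    from PIL import Image
    from utils.image_utils import ImageConfig, to_tensors  # import path may differ in your setup

    config = ImageConfig(
        width=642, height=363,
        mask_dilates=5, flow_mask_dilates=8,
        input_size=(642, 363), video_length=4,
    )
    print(config.process_size)  # (640, 360): both sides rounded down to a multiple of 8

    frames = [Image.fromarray(np.zeros((360, 640, 3), dtype=np.uint8)) for _ in range(4)]
    batch = to_tensors()(frames).unsqueeze(0)  # Stack -> ToTorchFormatTensor
    print(batch.shape)  # torch.Size([1, 4, 3, 360, 640]), values in [0, 1]
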
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 |
5 | from ..model.modules.flow_comp_raft import RAFT_bi
6 | from ..model.propainter import InpaintGenerator
7 | from ..model.recurrent_flow_completion import (
8 | RecurrentFlowCompleteNet,
9 | )
10 | from ..utils.download_utils import download_model
11 |
12 |
13 | @dataclass
14 | class Models:
15 | raft_model: RAFT_bi
16 | flow_model: RecurrentFlowCompleteNet
17 | inpaint_model: InpaintGenerator
18 |
19 |
20 | PRETRAIN_MODEL_URL = "https://github.com/sczhou/ProPainter/releases/download/v0.1.0/"
21 |
22 |
23 | def load_raft_model(device: torch.device) -> RAFT_bi:
24 | """Loads the RAFT bi-directional model."""
25 | model_path = download_model(PRETRAIN_MODEL_URL, "raft-things.pth")
26 | raft_model = RAFT_bi(model_path, device)
27 | return raft_model
28 |
29 |
30 | def load_recurrent_flow_model(device: torch.device) -> RecurrentFlowCompleteNet:
31 | """Loads the Recurrent Flow Completion Network model."""
32 | model_path = download_model(PRETRAIN_MODEL_URL, "recurrent_flow_completion.pth")
33 | flow_model = RecurrentFlowCompleteNet(model_path)
34 | for p in flow_model.parameters():
35 | p.requires_grad = False
36 | flow_model.to(device)
37 | flow_model.eval()
38 | return flow_model
39 |
40 |
41 | def load_inpaint_model(device: torch.device) -> InpaintGenerator:
42 | """Loads the Inpaint Generator model."""
43 | model_path = download_model(PRETRAIN_MODEL_URL, "ProPainter.pth")
44 | inpaint_model = InpaintGenerator(model_path=model_path).to(device)
45 | inpaint_model.eval()
46 | return inpaint_model
47 |
48 |
49 | def initialize_models(device: torch.device, use_half: str) -> Models:
50 | """Return initialized inference models."""
51 | raft_model = load_raft_model(device)
52 | flow_model = load_recurrent_flow_model(device)
53 | inpaint_model = load_inpaint_model(device)
54 |
55 | if use_half == "enable":
56 | # raft_model = raft_model.half()
57 | flow_model = flow_model.half()
58 | inpaint_model = inpaint_model.half()
59 | return Models(raft_model, flow_model, inpaint_model)
60 |
--------------------------------------------------------------------------------
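
Putting the loaders together, this is roughly how the nodes in propainter_nodes.py assemble the inference models. The sketch downloads the three checkpoints on first run; the import path again depends on how the package is installed (inside ComfyUI it is imported relative to the custom node package).

    import torch
    from ComfyUI_ProPainter_Nodes.utils.model_utils import initialize_models  # adjust to your install

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models = initialize_models(device, use_half="disable")

    # models is a dataclass bundling the three networks consumed by process_inpainting():
    print(type(models.raft_model).__name__)     # RAFT_bi
    print(type(models.flow_model).__name__)     # RecurrentFlowCompleteNet
    print(type(models.inpaint_model).__name__)  # InpaintGenerator

Note that when use_half is "enable", only the flow completion and inpainting networks are cast to half precision; the RAFT model is kept in full precision, as the commented-out line in initialize_models suggests.
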