├── .gitignore
├── LICENSE
├── README.md
├── build_ebs-linux-cpu+cuda.sh
├── build_ebs-linux-cpu_only.sh
├── build_ebs-win64-cpu+cuda.bat
├── ebsynth_sha256.txt
├── examples
│   ├── facestyle
│   │   ├── source_Gapp.png
│   │   ├── source_Gpos.png
│   │   ├── source_Gseg.png
│   │   ├── source_painting.png
│   │   ├── target_Gapp.png
│   │   ├── target_Gpos.png
│   │   └── target_Gseg.png
│   ├── input
│   │   ├── 000.jpg
│   │   ├── 001.jpg
│   │   ├── 002.jpg
│   │   ├── 003.jpg
│   │   ├── 004.jpg
│   │   ├── 005.jpg
│   │   ├── 006.jpg
│   │   ├── 007.jpg
│   │   ├── 008.jpg
│   │   ├── 009.jpg
│   │   └── 010.jpg
│   ├── mask
│   │   ├── mask_feather
│   │   │   ├── mask_00000.png
│   │   │   ├── mask_00001.png
│   │   │   ├── mask_00002.png
│   │   │   ├── mask_00003.png
│   │   │   ├── mask_00004.png
│   │   │   ├── mask_00005.png
│   │   │   ├── mask_00006.png
│   │   │   ├── mask_00007.png
│   │   │   ├── mask_00008.png
│   │   │   ├── mask_00009.png
│   │   │   └── mask_00010.png
│   │   └── mask_nofeather
│   │       ├── mask_00000.png
│   │       ├── mask_00001.png
│   │       ├── mask_00002.png
│   │       ├── mask_00003.png
│   │       ├── mask_00004.png
│   │       ├── mask_00005.png
│   │       ├── mask_00006.png
│   │       ├── mask_00007.png
│   │       ├── mask_00008.png
│   │       ├── mask_00009.png
│   │       └── mask_00010.png
│   ├── styles
│   │   ├── style000.jpg
│   │   ├── style002.png
│   │   ├── style003.png
│   │   ├── style006.png
│   │   ├── style010.png
│   │   ├── style014.png
│   │   ├── style019.png
│   │   └── style099.jpg
│   ├── stylit
│   │   ├── source_dirdif.png
│   │   ├── source_dirspc.png
│   │   ├── source_fullgi.png
│   │   ├── source_indirb.png
│   │   ├── source_style.png
│   │   ├── target_dirdif.png
│   │   ├── target_dirspc.png
│   │   ├── target_fullgi.png
│   │   └── target_indirb.png
│   └── texbynum
│       ├── run.bat
│       ├── run.sh
│       ├── source_photo.png
│       ├── source_segment.png
│       └── target_segment.png
├── ezsynth
│   ├── aux_classes.py
│   ├── aux_computations.py
│   ├── aux_flow_viz.py
│   ├── aux_masker.py
│   ├── aux_run.py
│   ├── aux_utils.py
│   ├── constants.py
│   ├── edge_detection.py
│   ├── main_ez.py
│   ├── sequences.py
│   └── utils
│       ├── __init__.py
│       ├── _eb.py
│       ├── _ebsynth.py
│       ├── blend
│       │   ├── __init__.py
│       │   ├── blender.py
│       │   ├── cupy_accelerated.py
│       │   ├── histogram_blend.py
│       │   └── reconstruction.py
│       ├── ebsynth.dll
│       └── flow_utils
│           ├── OpticalFlow.py
│           ├── __init__.py
│           ├── alt_cuda_corr
│           │   ├── __init__.py
│           │   ├── correlation.cpp
│           │   ├── correlation_kernel.cu
│           │   └── setup.py
│           ├── core
│           │   ├── __init__.py
│           │   ├── corr.py
│           │   ├── datasets.py
│           │   ├── ef_raft.py
│           │   ├── extractor.py
│           │   ├── fd_corr.py
│           │   ├── fd_decoder.py
│           │   ├── fd_encoder.py
│           │   ├── flow_diffusion.py
│           │   ├── raft.py
│           │   ├── update.py
│           │   └── utils
│           │       ├── __init__.py
│           │       ├── __pycache__
│           │       │   ├── __init__.cpython-311.pyc
│           │       │   └── utils.cpython-311.pyc
│           │       ├── augmentor.py
│           │       ├── flow_viz.py
│           │       └── utils.py
│           ├── ef_raft_models
│           │   └── .gitkeep
│           ├── flow_diff
│           │   ├── fd_corr.py
│           │   ├── fd_decoder.py
│           │   ├── fd_encoder.py
│           │   └── flow_diffusion.py
│           ├── flow_diffusion_models
│           │   └── .gitkeep
│           ├── models
│           │   ├── raft-kitti.pth
│           │   ├── raft-sintel.pth
│           │   └── raft-small.pth
│           └── warp.py
├── output_synth
│   ├── facestyle_err.png
│   ├── facestyle_out.png
│   ├── retarget_err.png
│   ├── retarget_out.png
│   ├── stylit_err.png
│   └── stylit_out.png
├── requirements.txt
├── test_imgsynth.py
├── test_progress.txt
└── test_redux.py
/.gitignore:
--------------------------------------------------------------------------------
1 | #The Zipped Library
2 | *.zip
3 |
4 | #The Dist Folder from building
5 | /dist/
6 |
7 | #The Build Folder from building
8 | /build/
9 |
10 | /ezsynth.egg-info/
11 |
12 | /.vscode/
13 |
14 | ##
15 | __pycache__
16 | .mypy_cache
17 |
18 | /venv/
19 | *.so
20 |
21 | # Extra model files
22 | /ezsynth/utils/flow_utils/flow_diffusion_models/*.pth
23 | /ezsynth/utils/flow_utils/ef_raft_models/*.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ezsynth - Ebsynth Python Library
2 |
3 | Reworked version, courtesy of [FuouM](https://github.com/FuouM), with masking support and some visual bug fixes. Aims to be easy to use and maintain.
4 |
5 | Perform tasks like style transfer, color transfer, inpainting, superimposition, video stylization, and more!
6 | This implementation uses physics-based edge detection and RAFT optical flow, which leads to more accurate results during synthesis.
7 |
8 | :warning: **This is not intended to be used as an installable module.**
9 |
10 | Currently tested on:
11 | ```
12 | Windows 10 - Python 3.11 - RTX3060
13 | Ubuntu 24 - Python 3.12 - RTX4070(Laptop)
14 | ```
15 |
16 | ## Get started
17 |
18 | ### Windows
19 | ```cmd
20 | rem Clone this repo
21 | git clone https://github.com/Trentonom0r3/Ezsynth.git
22 | cd Ezsynth
23 |
24 | rem (Optional) create and activate venv
25 | python -m venv venv
26 | venv\Scripts\activate.bat
27 |
28 | rem Install requirements
29 | pip install -r requirements.txt
30 |
31 | rem A precompiled ebsynth.dll is included.
32 | rem If you don't want to rebuild it, you are ready to go and can skip the following steps.
33 |
34 | rem Clone ebsynth
35 | git clone https://github.com/Trentonom0r3/ebsynth.git
36 |
37 | rem build ebsynth as lib
38 | copy .\build_ebs-win64-cpu+cuda.bat .\ebsynth
39 | cd ebsynth && .\build_ebs-win64-cpu+cuda.bat
40 |
41 | rem copy lib
42 | copy .\bin\ebsynth.dll ..\ezsynth\utils\ebsynth.dll
43 |
44 | rem cleanup
45 | cd .. && rmdir /s /q .\ebsynth
46 | ```
47 |
48 | ### Linux
49 | ```bash
50 | # clone this repo
51 | git clone https://github.com/Trentonom0r3/Ezsynth.git
52 | cd Ezsynth
53 |
54 | # (optional) create and activate venv
55 | python -m venv venv
56 | source ./venv/bin/activate
57 |
58 | # install requirements
59 | pip install -r requirements.txt
60 |
61 | # clone ebsynth
62 | git clone https://github.com/Trentonom0r3/ebsynth.git
63 |
64 | # build ebsynth as lib
65 | cp ./build_ebs-linux-cpu+cuda.sh ./ebsynth
66 | cd ebsynth && ./build_ebs-linux-cpu+cuda.sh
67 |
68 | # copy lib
69 | cp ./bin/ebsynth.so ../ezsynth/utils/ebsynth.so
70 |
71 | # cleanup
72 | cd .. && rm -rf ./ebsynth
73 | ```
74 |
75 | ### All
76 | You may also install CuPy (which provides `cupyx`) to use the GPU for some additional operations.
77 |
78 | ## Examples
79 |
80 | * To get started, see `test_redux.py` for an example of generating a full video.
81 | * For image style transfer, see `test_imgsynth.py`, which reproduces all examples from the original `Ebsynth`.
82 |
83 | ## Example outputs
84 |
85 | | Face style | Stylit | Retarget |
86 | |:-:|:-:|:-:|
87 | | ![facestyle_out](output_synth/facestyle_out.png) | ![stylit_out](output_synth/stylit_out.png) | ![retarget_out](output_synth/retarget_out.png) |
88 |
89 | https://github.com/user-attachments/assets/aa3cd191-4eb2-4dc0-8213-2c763f1b3316
90 |
91 | https://github.com/user-attachments/assets/63e50272-aa5c-42a1-a5ec-46178cdf2981
92 |
93 | Comparison of Edge methods
94 |
95 | ## Notable things
96 |
97 | **Updates:**
98 | 1. [Ef-RAFT](https://github.com/n3slami/Ef-RAFT) is added
99 |
100 | To use, download models from [the original repo](https://github.com/n3slami/Ef-RAFT/tree/master/models) and place them in `/ezsynth/utils/flow_utils/ef_raft_models`
101 | ```
102 | .gitkeep
103 | 25000_ours-sintel.pth
104 | ours-things.pth
105 | ours_sintel.pth
106 | ```
107 |
108 | 2. [FlowDiffuser](https://github.com/LA30/FlowDiffuser) is added.
109 |
110 | To use, download the model from [the original repo](https://github.com/LA30/FlowDiffuser?tab=readme-ov-file#usage) and place it in `/ezsynth/utils/flow_utils/flow_diffusion_models/FlowDiffuser-things.pth`.
111 |
112 | You will also need to install PyTorch Image Models to run it: `pip install timm`. On first run, it will download two models (~470 MB total): `twins_svt_large` (378 MB) and `twins_svt_small` (92 MB).
113 |
114 | This significantly increases VRAM usage when run alongside the EbSynth pass (~15 GB peak, though it may not OOM; tested on 12 GB of VRAM).
115 |
116 | In that case, it may throw a `CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR` error; this shouldn't be fatal, but the run then takes ~3x as long.
117 |
118 | https://github.com/user-attachments/assets/7f43630f-c7c9-40d0-8745-58d1f7c84d4f
119 |
120 | Comparison of Optical Flow models
121 |
122 | Optical Flow directly affects Flow position warping and Style image warping, controlled by `pos_wgt` and `wrp_wgt` respectively.
123 |
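A sketch of tuning these weights via `RunConfig` (values here are illustrative, not tuned; the import path is assumed from the repo layout):

```python
from ezsynth.aux_classes import RunConfig

# Illustrative values: raising pos_wgt / wrp_wgt leans harder on the
# flow-derived guides, at the cost of tracking the input frames less closely.
cfg = RunConfig(
    pos_wgt=3.0,  # flow position guide weight (default 2.0)
    wrp_wgt=1.0,  # warped style guide weight (default 0.5)
)
```
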
124 | **Changes:**
125 | 1. Flow is calculated frame by frame, with correct time orientation, instead of pre-computing only a forward flow.
126 | 2. Padding is applied during edge detection and warping to remove visual distortion at the borders.
127 |
128 |
129 | **Observations:**
130 | 1. Edge detection models return NaN if the input tensor contains too many zeros (cause unclear).
131 | 2. Pre-masked inputs take about twice as long to run through Ebsynth.
132 |
133 | ## API Overview
134 |
135 | ### ImageSynth
136 | For image-to-image style transfer, via file paths: `test_imgsynth.py`
137 | ```python
138 | ezsynner = ImageSynth(
139 | style_path="source_style.png",
140 | src_path="source_fullgi.png",
141 | tgt_path="target_fullgi.png",
142 | cfg=RunConfig(img_wgt=0.66),
143 | )
144 |
145 | result = ezsynner.run(
146 | guides=[
147 | load_guide(
148 | "source_dirdif.png",
149 | "target_dirdif.png",
150 | 0.66,
151 | ),
152 | load_guide(
153 | "source_indirb.png",
154 | "target_indirb.png",
155 | 0.66,
156 | ),
157 | ]
158 | )
159 |
160 | save_to_folder(output_folder, "stylit_out.png", result[0]) # Styled image
161 | save_to_folder(output_folder, "stylit_err.png", result[1]) # Error image
162 | ```
163 |
164 | ### Ezsynth
165 |
166 | **edge_method**
167 |
168 | Edge detection method. Choose from `PST`, `Classic`, or `PAGE`.
169 | * `PST` (Phase Stretch Transform): Good overall structure, but not very detailed.
170 | * `Classic`: A good balance between structure and detail.
171 | * `PAGE` (Phase and Gradient Estimation): Great detail, great structure, but slow.
172 |
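A minimal sketch of computing a single edge map directly, assuming `EdgeDetector` lives in `ezsynth/edge_detection.py` and is constructed from the method name, as in `aux_computations.py`:

```python
import cv2

from ezsynth.edge_detection import EdgeDetector

detector = EdgeDetector("PAGE")  # one of "PST", "Classic", "PAGE"
frame = cv2.imread("examples/input/000.jpg")
edge_map = detector.compute_edge(frame)
```
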
173 | **video stylization**
174 |
175 | Via file paths (see `test_redux.py`):
176 |
177 | ```python
178 | style_paths = [
179 | "style000.png",
180 | "style006.png"
181 | ]
182 |
183 | ezrunner = Ezsynth(
184 | style_paths=style_paths,
185 | image_folder=image_folder,
186 | cfg=RunConfig(pre_mask=False, feather=5, return_masked_only=False),
187 | edge_method="PAGE",
188 | raft_flow_model_name="sintel",
189 | mask_folder=mask_folder,
190 | do_mask=True
191 | )
192 |
193 | only_mode = None
194 | stylized_frames, err_frames = ezrunner.run_sequences(only_mode)
195 |
196 | save_seq(stylized_frames, "output")
197 | ```
198 |
199 | Via Numpy ndarrays:
200 |
201 | ```python
202 | class EzsynthBase:
203 | def __init__(
204 | self,
205 | style_frs: list[np.ndarray],
206 | style_idxes: list[int],
207 | img_frs_seq: list[np.ndarray],
208 | cfg: RunConfig = RunConfig(),
209 | edge_method="Classic",
210 | raft_flow_model_name="sintel",
211 | do_mask=False,
212 | msk_frs_seq: list[np.ndarray] | None = None,
213 | ):
214 | pass
215 | ```
216 |
217 | ### RunConfig
218 | #### Ebsynth gen params
219 | * `uniformity (float)`: Uniformity weight for the style transfer. Reasonable values are between `500-15000`. Defaults to `3500.0`.
220 |
221 | * `patchsize (int)`: Size of the patches [NxN]. Must be an odd number `>= 3`. Defaults to `7`.
222 |
223 | * `pyramidlevels (int)`: Number of pyramid levels. Larger values are useful for things like color transfer. Defaults to `6`.
224 |
225 | * `searchvoteiters (int)`: Number of search/vote iterations. Defaults to `12`.
226 | * `patchmatchiters (int)`: Number of Patch-Match iterations. The larger, the longer it takes. Defaults to `6`.
227 |
228 | * `extrapass3x3 (bool)`: Perform additional polishing pass with 3x3 patches at the finest level. Defaults to `True`.
229 |
230 | #### Ebsynth guide weights params
231 | * `edg_wgt (float)`: Edge detect weights. Defaults to `1.0`.
232 | * `img_wgt (float)`: Original image weights. Defaults to `6.0`.
233 | * `pos_wgt (float)`: Flow position warping weights. Defaults to `2.0`.
234 | * `wrp_wgt (float)`: Warped style image weight. Defaults to `0.5`.
235 |
236 | #### Blending params
237 | * `use_gpu (bool)`: Use GPU for Histogram Blending (only affects blend mode). Faster than CPU. Defaults to `False`.
238 |
239 | * `use_lsqr (bool)`: Use LSQR (least-squares solver) instead of LSMR (iterative least-squares solver) for the Poisson blending step. LSQR often yields better results; switch to LSMR when speed matters. Defaults to `True`.
240 |
241 | * `use_poisson_cupy (bool)`: Use CuPy GPU acceleration for the Poisson blending step. Uses LSMR (overrides `use_lsqr`). May not be faster. Defaults to `False`.
242 |
243 | * `poisson_maxiter (int | None)`: Maximum number of iterations for the Poisson least-squares solve (only affects LSMR mode). Expects a positive integer. Defaults to `None`.
244 |
245 | * `only_mode (str)`: Skip blending and run only one pass per sequence (see the sketch below). Valid values:
246 |   * `MODE_FWD = "forward"` (runs only the forward pass if `sequence.mode` is blend)
247 |
248 |   * `MODE_REV = "reverse"` (runs only the reverse pass if `sequence.mode` is blend)
249 |
250 |   * Defaults to `MODE_NON = "none"`.
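
A configuration sketch for these blending options (illustrative values):

```python
from ezsynth.aux_classes import RunConfig

# Speed-oriented blending: LSMR with a capped iteration count.
cfg = RunConfig(use_lsqr=False, poisson_maxiter=150)

# Or skip blending entirely and keep only the forward pass:
cfg_fwd_only = RunConfig(only_mode="forward")
```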
251 |
252 | #### Masking params
253 | * `do_mask (bool)`: Whether to apply masking. Defaults to `False`.
254 |
255 | * `pre_mask (bool)`: Whether to mask the inputs and styles before running Ebsynth or after. Pre-masking takes ~2x the time per frame, possibly due to the Ebsynth.dll implementation. Defaults to `False`.
256 |
257 | * `feather (int)`: Feather Gaussian radius to apply on the mask results. Only applies if `return_masked_only == False`. Expects an integer. Defaults to `0`.
258 |
259 | ## Credits
260 |
261 | jamriska - https://github.com/jamriska/ebsynth
262 |
263 | ```
264 | @misc{Jamriska2018,
265 | author = {Jamriska, Ondrej},
266 | title = {Ebsynth: Fast Example-based Image Synthesis and Style Transfer},
267 | year = {2018},
268 | publisher = {GitHub},
269 | journal = {GitHub repository},
270 | howpublished = {\url{https://github.com/jamriska/ebsynth}},
271 | }
272 | ```
273 | ```
274 | Ondřej Jamriška, Šárka Sochorová, Ondřej Texler, Michal Lukáč, Jakub Fišer, Jingwan Lu, Eli Shechtman, and Daniel Sýkora. 2019. Stylizing Video by Example. ACM Trans. Graph. 38, 4, Article 107 (July 2019), 11 pages. https://doi.org/10.1145/3306346.3323006
275 | ```
276 |
277 | FuouM - https://github.com/FuouM
278 | pravdomil - https://github.com/pravdomil
279 | xy-gao - https://github.com/xy-gao
280 |
281 | https://github.com/princeton-vl/RAFT
282 |
283 | ```
284 | RAFT: Recurrent All Pairs Field Transforms for Optical Flow
285 | ECCV 2020
286 | Zachary Teed and Jia Deng
287 | ```
288 |
289 | https://github.com/n3slami/Ef-RAFT
290 |
291 | ```
292 | @inproceedings{eslami2024rethinking,
293 | title={Rethinking RAFT for efficient optical flow},
294 | author={Eslami, Navid and Arefi, Farnoosh and Mansourian, Amir M and Kasaei, Shohreh},
295 | booktitle={2024 13th Iranian/3rd International Machine Vision and Image Processing Conference (MVIP)},
296 | pages={1--7},
297 | year={2024},
298 | organization={IEEE}
299 | }
300 | ```
301 |
302 | https://github.com/LA30/FlowDiffuser
303 |
304 | ```
305 | @inproceedings{luo2024flowdiffuser,
306 | title={FlowDiffuser: Advancing Optical Flow Estimation with Diffusion Models},
307 | author={Luo, Ao and Li, Xin and Yang, Fan and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
308 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
309 | pages={19167--19176},
310 | year={2024}
311 | }
312 | ```
313 |
--------------------------------------------------------------------------------
/build_ebs-linux-cpu+cuda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | nvcc --shared src/ebsynth.cpp src/ebsynth_cpu.cpp src/ebsynth_cuda.cu -I"include" -DNDEBUG -D__CORRECT_ISO_CPP11_MATH_H_PROTO -O6 -std=c++14 -w -Xcompiler -fopenmp,-fPIC -o bin/ebsynth.so
3 |
--------------------------------------------------------------------------------
/build_ebs-linux-cpu_only.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | g++ src/ebsynth.cpp src/ebsynth_cpu.cpp src/ebsynth_nocuda.cpp -DNDEBUG -O6 -fopenmp -I"include" -std=c++11 -o bin/ebsynth
3 |
--------------------------------------------------------------------------------
/build_ebs-win64-cpu+cuda.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | setlocal ENABLEDELAYEDEXPANSION
3 |
4 | call "vcvarsall.bat" amd64
5 |
6 | nvcc -v -arch sm_86 src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_cuda.cu -DNDEBUG -O6 -I "include" -o "bin\ebsynth.dll" -Xcompiler "/openmp /fp:fast" -Xlinker "/IMPLIB:lib\ebsynth.lib" -shared -DEBSYNTH_API=__declspec(dllexport) -w || goto error
7 |
8 | del dummy.lib;dummy.exp 2> NUL
9 | goto :EOF
10 |
11 | :error
12 | echo FAILED
13 | @%COMSPEC% /C exit 1 >nul
--------------------------------------------------------------------------------
/ebsynth_sha256.txt:
--------------------------------------------------------------------------------
1 | SHA256 hash of ebsynth.dll:
2 | e3cfad210d445fcbfa6c7dcd2f9bdaaf36d550746c108c79a94d2d1ecce41369
3 | CertUtil: -hashfile command completed successfully.
4 |
--------------------------------------------------------------------------------
/examples/facestyle/source_Gapp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/source_Gapp.png
--------------------------------------------------------------------------------
/examples/facestyle/source_Gpos.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/source_Gpos.png
--------------------------------------------------------------------------------
/examples/facestyle/source_Gseg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/source_Gseg.png
--------------------------------------------------------------------------------
/examples/facestyle/source_painting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/source_painting.png
--------------------------------------------------------------------------------
/examples/facestyle/target_Gapp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/target_Gapp.png
--------------------------------------------------------------------------------
/examples/facestyle/target_Gpos.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/target_Gpos.png
--------------------------------------------------------------------------------
/examples/facestyle/target_Gseg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/target_Gseg.png
--------------------------------------------------------------------------------
/examples/input/000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/000.jpg
--------------------------------------------------------------------------------
/examples/input/001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/001.jpg
--------------------------------------------------------------------------------
/examples/input/002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/002.jpg
--------------------------------------------------------------------------------
/examples/input/003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/003.jpg
--------------------------------------------------------------------------------
/examples/input/004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/004.jpg
--------------------------------------------------------------------------------
/examples/input/005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/005.jpg
--------------------------------------------------------------------------------
/examples/input/006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/006.jpg
--------------------------------------------------------------------------------
/examples/input/007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/007.jpg
--------------------------------------------------------------------------------
/examples/input/008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/008.jpg
--------------------------------------------------------------------------------
/examples/input/009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/009.jpg
--------------------------------------------------------------------------------
/examples/input/010.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/010.jpg
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00000.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00001.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00002.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00003.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00004.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00005.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00006.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00007.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00008.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00009.png
--------------------------------------------------------------------------------
/examples/mask/mask_feather/mask_00010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00010.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00000.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00001.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00002.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00003.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00004.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00005.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00006.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00007.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00008.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00009.png
--------------------------------------------------------------------------------
/examples/mask/mask_nofeather/mask_00010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00010.png
--------------------------------------------------------------------------------
/examples/styles/style000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style000.jpg
--------------------------------------------------------------------------------
/examples/styles/style002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style002.png
--------------------------------------------------------------------------------
/examples/styles/style003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style003.png
--------------------------------------------------------------------------------
/examples/styles/style006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style006.png
--------------------------------------------------------------------------------
/examples/styles/style010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style010.png
--------------------------------------------------------------------------------
/examples/styles/style014.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style014.png
--------------------------------------------------------------------------------
/examples/styles/style019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style019.png
--------------------------------------------------------------------------------
/examples/styles/style099.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style099.jpg
--------------------------------------------------------------------------------
/examples/stylit/source_dirdif.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_dirdif.png
--------------------------------------------------------------------------------
/examples/stylit/source_dirspc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_dirspc.png
--------------------------------------------------------------------------------
/examples/stylit/source_fullgi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_fullgi.png
--------------------------------------------------------------------------------
/examples/stylit/source_indirb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_indirb.png
--------------------------------------------------------------------------------
/examples/stylit/source_style.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_style.png
--------------------------------------------------------------------------------
/examples/stylit/target_dirdif.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/target_dirdif.png
--------------------------------------------------------------------------------
/examples/stylit/target_dirspc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/target_dirspc.png
--------------------------------------------------------------------------------
/examples/stylit/target_fullgi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/target_fullgi.png
--------------------------------------------------------------------------------
/examples/stylit/target_indirb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/target_indirb.png
--------------------------------------------------------------------------------
/examples/texbynum/run.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | setlocal
3 | set PATH=..\..\bin;%PATH%
4 |
5 | ebsynth.exe -patchsize 3 -uniformity 1000 -style source_photo.png -guide source_segment.png target_segment.png -output output.png
6 |
--------------------------------------------------------------------------------
/examples/texbynum/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | export PATH=../../bin:$PATH
3 |
4 | ebsynth -patchsize 3 -uniformity 1000 -style source_photo.png -guide source_segment.png target_segment.png -output output.png
5 |
6 |
--------------------------------------------------------------------------------
/examples/texbynum/source_photo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/texbynum/source_photo.png
--------------------------------------------------------------------------------
/examples/texbynum/source_segment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/texbynum/source_segment.png
--------------------------------------------------------------------------------
/examples/texbynum/target_segment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/texbynum/target_segment.png
--------------------------------------------------------------------------------
/ezsynth/aux_classes.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | from .utils.flow_utils.warp import Warp
5 | from .sequences import EasySequence
6 |
7 |
8 | class RunConfig:
9 | """
10 | ### Ebsynth gen params
11 | `uniformity (float)`: Uniformity weight for the style transfer.
12 | Reasonable values are between `500-15000`.
13 | Defaults to `3500.0`.
14 |
15 | `patchsize (int)`: Size of the patches [NxN]. Must be an odd number `>= 3`.
16 | Defaults to `7`.
17 |
18 | `pyramidlevels (int)`: Number of pyramid levels. Larger values are useful for things like color transfer.
19 | Defaults to `6`.
20 |
21 | `searchvoteiters (int)`: Number of search/vote iterations. Defaults to `12`.
22 | `patchmatchiters (int)`: Number of Patch-Match iterations. The larger, the longer it takes.
23 | Defaults to `6`.
24 |
25 | `extrapass3x3 (bool)`: Perform additional polishing pass with 3x3 patches at the finest level.
26 | Defaults to `True`.
27 |
28 | ### Ebsynth guide weights params
29 | `edg_wgt (float)`: Edge detect weights. Defaults to `1.0`.
30 | `img_wgt (float)`: Original image weights. Defaults to `6.0`.
31 | `pos_wgt (float)`: Flow position warping weights. Defaults to `2.0`.
32 | `wrp_wgt (float)`: Warped style image weight. Defaults to `0.5`.
33 |
34 | ### Blending params
35 | `use_gpu (bool)`: Use GPU for Histogram Blending (only affects blend mode). Faster than CPU.
36 | Defaults to `False`.
37 |
38 | `use_lsqr (bool)`: Use LSQR (least-squares solver) instead of LSMR (iterative least-squares solver)
39 | for the Poisson blending step. LSQR often yields better results; switch to LSMR when speed matters.
40 | Defaults to `True`.
41 |
42 | `use_poisson_cupy (bool)`: Use Cupy GPU acceleration for Poisson blending step.
43 | Uses LSMR (overrides `use_lsqr`). May not yield better speed.
44 | Defaults to `False`.
45 |
46 | `poisson_maxiter (int | None)`: Maximum number of iterations for the Poisson least-squares solve (only affects LSMR mode).
47 | Expects a positive integer.
48 | Defaults to `None`.
49 |
50 | `only_mode (str)`: Skip blending, only run one pass per sequence.
51 | Valid values:
52 | `MODE_FWD = "forward"` (Will only run forward mode if `sequence.mode` is blend)
53 |
54 | `MODE_REV = "reverse"` (Will only run reverse mode if `sequence.mode` is blend)
55 |
56 | Defaults to `MODE_NON = "none"`.
57 |
58 | ### Masking params
59 | `do_mask (bool)`: Whether to apply mask. Defaults to `False`.
60 |
61 | `pre_mask (bool)`: Whether to mask the inputs and styles before running Ebsynth or after.
62 | Pre-masking takes ~2x the time per frame, possibly due to the Ebsynth.dll implementation.
63 | Defaults to `False`.
64 |
65 | `feather (int)`: Feather Gaussian radius to apply on the mask results. Only applies if `return_masked_only == False`.
66 | Expects an integer. Defaults to `0`.
67 | """
68 |
69 | def __init__(
70 | self,
71 | uniformity=3500.0,
72 | patchsize=7,
73 | pyramidlevels=6,
74 | searchvoteiters=12,
75 | patchmatchiters=6,
76 | extrapass3x3=True,
77 | edg_wgt=1.0,
78 | img_wgt=6.0,
79 | pos_wgt=2.0,
80 | wrp_wgt=0.5,
81 | use_gpu=False,
82 | use_lsqr=True,
83 | use_poisson_cupy=False,
84 | poisson_maxiter: int | None = None,
85 | only_mode=EasySequence.MODE_NON,
86 | do_mask=False,
87 | pre_mask=False,
88 | feather=0,
89 | ) -> None:
90 | # Ebsynth gen params
91 | self.uniformity = uniformity
92 | """Uniformity weight for the style transfer.
93 | Reasonable values are between `500-15000`.
94 |
95 | Defaults to `3500.0`."""
96 | self.patchsize = patchsize
97 | """Size of the patches [`NxN`]. Must be an odd number `>= 3`.
98 | Defaults to `7`"""
99 |
100 | self.pyramidlevels = pyramidlevels
101 | """Number of pyramid levels.
102 | Larger values are useful for things like color transfer.
103 |
104 | Defaults to 6."""
105 |
106 | self.searchvoteiters = searchvoteiters
107 | """Number of search/vote iterations.
108 | Defaults to `12`"""
109 |
110 | self.patchmatchiters = patchmatchiters
111 | """Number of Patch-Match iterations. The larger, the longer it takes.
112 | Defaults to `6`"""
113 |
114 | self.extrapass3x3 = extrapass3x3
115 | """Perform additional polishing pass with 3x3 patches at the finest level.
116 | Defaults to `True`"""
117 |
118 | # Weights
119 | self.edg_wgt = edg_wgt
120 | """Edge detect weights. Defaults to `1.0`"""
121 |
122 | self.img_wgt = img_wgt
123 | """Original image weights. Defaults to `6.0`"""
124 |
125 | self.pos_wgt = pos_wgt
126 | """Flow position warping weights. Defaults to `2.0`"""
127 |
128 | self.wrp_wgt = wrp_wgt
129 | """Warped style image weight. Defaults to `0.5`"""
130 |
131 | # Blend params
132 | self.use_gpu = use_gpu
133 | """Use GPU for Histogram Blending (Only affect Blend mode). Faster than CPU.
134 | Defaults to `False`"""
135 |
136 | self.use_lsqr = use_lsqr
137 | """Use LSQR (Least-squares solver) instead of LSMR (Iterative solver for least-squares)
138 | for the Poisson blending step. LSQR often yields better results.
139 |
140 | Switch to LSMR when speed matters.
141 | Defaults to `True`"""
142 |
143 | self.use_poisson_cupy = use_poisson_cupy
144 | """Use Cupy GPU acceleration for Poisson blending step.
145 | Uses LSMR (overrides `use_lsqr`). May not yield better speed.
146 |
147 | Defaults to `False`"""
148 |
149 | self.poisson_maxiter = poisson_maxiter
150 | """Max iteration to calculate Poisson Least-squares (only affect LSMR mode). Expect positive integers.
151 |
152 | Defaults to `None`"""
153 |
154 | # No blending mode
155 | self.only_mode = only_mode
156 | """Skip blending, only run one pass per sequence.
157 |
158 | Valid values:
159 | `MODE_FWD = "forward"` (Will only run forward mode if `sequence.mode` is blend)
160 |
161 | `MODE_REV = "reverse"` (Will only run reverse mode if `sequence.mode` is blend)
162 |
163 | Defaults to `MODE_NON = "none"`
164 | """
165 |
166 | # Skip adding last style frame if blending
167 | self.skip_blend_style_last = False
168 | """Skip adding last style frame if blending. Internal variable"""
169 |
170 | # Masking mode
171 | self.do_mask = do_mask
172 | """Whether to apply mask. Defaults to `False`"""
173 |
174 | self.pre_mask = pre_mask
175 | """Whether to mask the inputs and styles before `RUN` or after.
176 |
177 | Pre-masking takes ~2x the time per frame, possibly due to the Ebsynth.dll implementation.
178 |
179 | Defaults to `False`"""
180 |
181 | self.feather = feather
182 | """Feather Gaussian radius to apply on the mask results. Only affect if `return_masked_only == False`.
183 |
184 | Expects integers. Defaults to `0`"""
185 |
186 | def get_ebsynth_cfg(self):
187 | return {
188 | "uniformity": self.uniformity,
189 | "patchsize": self.patchsize,
190 | "pyramidlevels": self.pyramidlevels,
191 | "searchvoteiters": self.searchvoteiters,
192 | "patchmatchiters": self.patchmatchiters,
193 | "extrapass3x3": self.extrapass3x3,
194 | }
195 |
196 | def get_blender_cfg(self):
197 | return {
198 | "use_gpu": self.use_gpu,
199 | "use_lsqr": self.use_lsqr,
200 | "use_poisson_cupy": self.use_poisson_cupy,
201 | "poisson_maxiter": self.poisson_maxiter,
202 | }
203 |
204 |
205 | class PositionalGuide:
206 | def __init__(self) -> None:
207 | self.coord_map = None
208 |
209 | def get_coord_maps(self, warp: Warp):
210 | h, w = warp.H, warp.W
211 |
212 | # Create x and y coordinates
213 | x = np.linspace(0, 1, w)
214 | y = np.linspace(0, 1, h)
215 |
216 | # Use numpy's meshgrid to create 2D coordinate arrays
217 | xx, yy = np.meshgrid(x, y)
218 |
219 | # Stack the coordinates into a single 3D array
220 | self.coord_map = np.stack((xx, yy, np.zeros_like(xx)), axis=-1).astype(
221 | np.float32
222 | )
223 |
224 | def get_or_create_coord_maps(self, warp: Warp):
225 |         if self.coord_map is None:
226 | self.get_coord_maps(warp)
227 | return self.coord_map
228 |
229 | def create_from_flow(
230 | self, flow: np.ndarray, original_size: tuple[int, ...], warp: Warp
231 | ):
232 | coord_map = self.get_or_create_coord_maps(warp)
233 | coord_map_warped = warp.run_warping(coord_map, flow)
234 |
235 | coord_map_warped[..., :2] = coord_map_warped[..., :2] % 1
236 |
237 |         if coord_map_warped.shape[1::-1] != original_size:  # shape is (H, W, C); original_size is (W, H)
238 | coord_map_warped = cv2.resize(
239 | coord_map_warped, original_size, interpolation=cv2.INTER_LINEAR
240 | )
241 |
242 | g_pos = (coord_map_warped * 255).astype(np.uint8)
243 | self.coord_map = coord_map_warped.copy()
244 |
245 | return g_pos
246 |
247 |
248 | class EdgeConfig:
249 | # PST
250 | PST_S = 0.3
251 | PST_W = 15
252 | PST_SIG_LPF = 0.15
253 | PST_MIN = 0.05
254 | PST_MAX = 0.9
255 |
256 | # PAGE
257 | PAGE_M1 = 0
258 | PAGE_M2 = 0.35
259 | PAGE_SIG1 = 0.05
260 | PAGE_SIG2 = 0.8
261 | PAGE_S1 = 0.8
262 | PAGE_S2 = 0.8
263 | PAGE_SIG_LPF = 0.1
264 | PAGE_MIN = 0.0
265 | PAGE_MAX = 0.9
266 |
267 | MORPH_FLAG = 1
268 |
269 | def __init__(self, **kwargs):
270 | # PST attributes
271 | self.pst_s = kwargs.get("S", self.PST_S)
272 | self.pst_w = kwargs.get("W", self.PST_W)
273 | self.pst_sigma_lpf = kwargs.get("sigma_LPF", self.PST_SIG_LPF)
274 | self.pst_thresh_min = kwargs.get("thresh_min", self.PST_MIN)
275 | self.pst_thresh_max = kwargs.get("thresh_max", self.PST_MAX)
276 |
277 | # PAGE attributes
278 | self.page_mu_1 = kwargs.get("mu_1", self.PAGE_M1)
279 | self.page_mu_2 = kwargs.get("mu_2", self.PAGE_M2)
280 | self.page_sigma_1 = kwargs.get("sigma_1", self.PAGE_SIG1)
281 | self.page_sigma_2 = kwargs.get("sigma_2", self.PAGE_SIG2)
282 | self.page_s1 = kwargs.get("S1", self.PAGE_S1)
283 | self.page_s2 = kwargs.get("S2", self.PAGE_S2)
284 | self.page_sigma_lpf = kwargs.get("sigma_LPF", self.PAGE_SIG_LPF)
285 | self.page_thresh_min = kwargs.get("thresh_min", self.PAGE_MIN)
286 | self.page_thresh_max = kwargs.get("thresh_max", self.PAGE_MAX)
287 |
288 | self.morph_flag = kwargs.get("morph_flag", self.MORPH_FLAG)
289 |
290 | @classmethod
291 | def get_pst_default(cls) -> dict:
292 | return {
293 | "S": cls.PST_S,
294 | "W": cls.PST_W,
295 | "sigma_LPF": cls.PST_SIG_LPF,
296 | "thresh_min": cls.PST_MIN,
297 | "thresh_max": cls.PST_MAX,
298 | "morph_flag": cls.MORPH_FLAG,
299 | }
300 |
301 | @classmethod
302 | def get_page_default(cls) -> dict:
303 | return {
304 | "mu_1": cls.PAGE_M1,
305 | "mu_2": cls.PAGE_M2,
306 | "sigma_1": cls.PAGE_SIG1,
307 | "sigma_2": cls.PAGE_SIG2,
308 | "S1": cls.PAGE_S1,
309 | "S2": cls.PAGE_S2,
310 | "sigma_LPF": cls.PAGE_SIG_LPF,
311 | "thresh_min": cls.PAGE_MIN,
312 | "thresh_max": cls.PAGE_MAX,
313 | "morph_flag": cls.MORPH_FLAG,
314 | }
315 |
316 | def get_pst_current(self) -> dict:
317 | return {
318 | "S": self.pst_s,
319 | "W": self.pst_w,
320 | "sigma_LPF": self.pst_sigma_lpf,
321 | "thresh_min": self.pst_thresh_min,
322 | "thresh_max": self.pst_thresh_max,
323 | "morph_flag": self.morph_flag,
324 | }
325 |
326 | def get_page_current(self) -> dict:
327 | return {
328 | "mu_1": self.page_mu_1,
329 | "mu_2": self.page_mu_2,
330 | "sigma_1": self.page_sigma_1,
331 | "sigma_2": self.page_sigma_2,
332 | "S1": self.page_s1,
333 | "S2": self.page_s2,
334 | "sigma_LPF": self.page_sigma_lpf,
335 | "thresh_min": self.page_thresh_min,
336 | "thresh_max": self.page_thresh_max,
337 | "morph_flag": self.morph_flag,
338 | }
339 |
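340 | # Usage sketch: override selected PST parameters, keeping the rest at class defaults.
341 | # >>> edge_cfg = EdgeConfig(W=20, thresh_max=0.85)
342 | # >>> pst_params = edge_cfg.get_pst_current()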
--------------------------------------------------------------------------------
/ezsynth/aux_computations.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import numpy as np
3 | from .edge_detection import EdgeDetector
4 |
5 |
6 | def precompute_edge_guides(
7 | img_frs_seq: list[np.ndarray], edge_method: str
8 | ) -> list[np.ndarray]:
9 | edge_detector = EdgeDetector(edge_method)
10 | edge_maps = []
11 | for img_fr in tqdm.tqdm(img_frs_seq, desc="Calculating edge maps"):
12 | edge_maps.append(edge_detector.compute_edge(img_fr))
13 | return edge_maps
14 |
15 |
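16 | # Usage sketch (inputs assumed): one edge map per input frame, methods as listed in the README.
17 | # >>> edge_maps = precompute_edge_guides(img_frs_seq, "PAGE")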
--------------------------------------------------------------------------------
/ezsynth/aux_flow_viz.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def make_colorwheel():
4 | """
5 | Generates a color wheel for optical flow visualization.
6 | """
7 | RY, YG, GC, CB, BM, MR = 15, 6, 4, 11, 13, 6
8 | ncols = RY + YG + GC + CB + BM + MR
9 | colorwheel = np.zeros((ncols, 3), dtype=np.uint8)
10 | col = 0
11 |
12 | colorwheel[col : col + RY, 0] = 255
13 | colorwheel[col : col + RY, 1] = np.floor(255 * np.arange(0, RY) / RY).astype(
14 | np.uint8
15 | )
16 | col += RY
17 |
18 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG).astype(
19 | np.uint8
20 | )
21 | colorwheel[col : col + YG, 1] = 255
22 | col += YG
23 |
24 | colorwheel[col : col + GC, 1] = 255
25 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC).astype(
26 | np.uint8
27 | )
28 | col += GC
29 |
30 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(0, CB) / CB).astype(
31 | np.uint8
32 | )
33 | colorwheel[col : col + CB, 2] = 255
34 | col += CB
35 |
36 | colorwheel[col : col + BM, 2] = 255
37 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM).astype(
38 | np.uint8
39 | )
40 | col += BM
41 |
42 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(0, MR) / MR).astype(
43 | np.uint8
44 | )
45 | colorwheel[col : col + MR, 0] = 255
46 |
47 | return colorwheel
48 |
49 |
50 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
51 | """
52 | Applies the flow color wheel to (possibly clipped) flow components u and v.
53 | """
54 | flow_image = np.zeros((*u.shape, 3), dtype=np.uint8)
55 | colorwheel = make_colorwheel()
56 | ncols = colorwheel.shape[0]
57 |
58 | rad = np.sqrt(np.square(u) + np.square(v))
59 | a = np.arctan2(-v, -u) / np.pi
60 | fk = (a + 1) / 2 * (ncols - 1)
61 | k0 = np.floor(fk).astype(np.int32)
62 | k1 = (k0 + 1) % ncols
63 | f = fk - k0
64 |
65 | for i in range(3):
66 | tmp = colorwheel[:, i]
67 | col0 = tmp[k0] / 255.0
68 | col1 = tmp[k1] / 255.0
69 | col = (1 - f) * col0 + f * col1
70 |
71 | idx = rad <= 1
72 | col[idx] = 1 - rad[idx] * (1 - col[idx])
73 | col[~idx] *= 0.75
74 |
75 | ch_idx = 2 - i if convert_to_bgr else i
76 | flow_image[..., ch_idx] = (255 * col).astype(np.uint8)
77 |
78 | return flow_image
79 |
80 |
81 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
82 | """
83 | Converts a two-dimensional flow image to a color image for visualization.
84 | """
85 | assert (
86 | flow_uv.ndim == 3 and flow_uv.shape[2] == 2
87 | ), "Input flow must have shape [H,W,2]"
88 |
89 | if clip_flow is not None:
90 | flow_uv = np.clip(flow_uv, 0, clip_flow)
91 |
92 | u, v = flow_uv[..., 0], flow_uv[..., 1]
93 | rad = np.sqrt(np.square(u) + np.square(v))
94 | rad_max = np.max(rad)
95 |
96 | epsilon = 1e-5
97 | u = u / (rad_max + epsilon)
98 | v = v / (rad_max + epsilon)
99 |
100 | return flow_uv_to_colors(u, v, convert_to_bgr)
101 |
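102 | # Usage sketch (shapes assumed): visualize a synthetic flow field.
103 | # >>> flow = np.random.randn(480, 640, 2).astype(np.float32)
104 | # >>> img = flow_to_image(flow)  # uint8 color image, shape (480, 640, 3)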
--------------------------------------------------------------------------------
/ezsynth/aux_masker.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import tqdm
4 |
5 |
6 | def apply_mask(image: np.ndarray, mask: np.ndarray):
7 | masked_image = cv2.bitwise_and(image, image, mask=mask)
8 | return masked_image.astype(np.uint8)
9 |
10 |
11 | def apply_masks(images: list[np.ndarray], masks: list[np.ndarray]):
12 | len_img = len(images)
13 | len_msk = len(masks)
14 | if len_img != len_msk:
15 | raise ValueError(f"[{len_img=}], [{len_msk=}]")
16 |
17 | masked_images = []
18 | for i in range(len_img):
19 | masked_images.append(apply_mask(images[i], masks[i]))
20 |
21 | return masked_images
22 |
23 |
24 | def apply_masks_idxes(
25 | images: list[np.ndarray], masks: list[np.ndarray], img_idxes: list[int]
26 | ):
27 | masked_images = []
28 | for i, idx in enumerate(img_idxes):
29 | masked_images.append(apply_mask(images[i], masks[idx]))
30 | return masked_images
31 |
32 |
33 | def apply_masked_back(
34 | original: np.ndarray, processed: np.ndarray, mask: np.ndarray, feather_radius=0
35 | ):
36 | if feather_radius > 0:
37 |         mask_blurred = cv2.GaussianBlur(mask, (feather_radius | 1, feather_radius | 1), 0)  # GaussianBlur requires an odd kernel size
38 | mask_blurred = mask_blurred.astype(np.float32) / 255.0
39 |
40 | mask_inv_blurred = 1.0 - mask_blurred
41 |
42 | # Expand dimensions to match the number of channels in the original image
43 | mask_blurred_expanded = np.expand_dims(mask_blurred, axis=-1)
44 | mask_inv_blurred_expanded = np.expand_dims(mask_inv_blurred, axis=-1)
45 |
46 | background = original * mask_inv_blurred_expanded
47 | foreground = processed * mask_blurred_expanded
48 |
49 | # Combine the background and foreground
50 | result = background + foreground
51 | result = result.astype(np.uint8)
52 |
53 | else:
54 | mask = mask.astype(np.float32) / 255.0
55 | mask_inv = 1.0 - mask
56 | mask_expanded = np.expand_dims(mask, axis=-1)
57 | mask_inv_expanded = np.expand_dims(mask_inv, axis=-1)
58 | background = original * mask_inv_expanded
59 | foreground = processed * mask_expanded
60 | result = background + foreground
61 | result = result.astype(np.uint8)
62 |
63 | return result
64 |
65 |
66 | def apply_masked_back_seq(
67 | img_frs_seq: list[np.ndarray],
68 | styled_msk_frs: list[np.ndarray],
69 | mask_frs_seq: list[np.ndarray],
70 | feather=0,
71 | ):
72 | len_img = len(img_frs_seq)
73 | len_stl = len(styled_msk_frs)
74 | len_msk = len(mask_frs_seq)
75 |
76 |     if not (len_img == len_stl == len_msk):
77 |         raise ValueError(f"Lengths do not match. [{len_img=}, {len_stl=}, {len_msk=}]")
78 |
79 | backed_seq = []
80 |
81 | for i in tqdm.tqdm(range(len_img), desc="Adding masked back"):
82 | backed_seq.append(
83 | apply_masked_back(
84 | img_frs_seq[i], styled_msk_frs[i], mask_frs_seq[i], feather
85 | )
86 | )
87 |
88 | return backed_seq
89 |
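90 | # Usage sketch (arrays assumed): paste a styled frame back over the original,
91 | # feathering the mask edge with a 5-pixel Gaussian radius.
92 | # >>> out = apply_masked_back(original, styled, mask, feather_radius=5)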
--------------------------------------------------------------------------------
/ezsynth/aux_run.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import tqdm
4 |
5 | from .aux_classes import PositionalGuide, RunConfig
6 | from .utils._ebsynth import ebsynth
7 | from .utils.blend.blender import Blend
8 | from .utils.flow_utils.OpticalFlow import RAFT_flow
9 | from .utils.flow_utils.warp import Warp
10 | from .sequences import EasySequence
11 |
12 |
13 | def run_a_pass(
14 | seq: EasySequence,
15 | seq_mode: str,
16 | img_frs_seq: list[np.ndarray],
17 | style: np.ndarray,
18 | edge: list[np.ndarray],
19 | cfg: RunConfig,
20 | rafter: RAFT_flow,
21 | eb: ebsynth,
22 | ):
23 | stylized_frames: list[np.ndarray] = [style]
24 | err_list: list[np.ndarray] = []
25 | ORIGINAL_SIZE = img_frs_seq[0].shape[1::-1]
26 |
27 | start, end, step, is_forward = (
28 | get_forward(seq) if seq_mode == EasySequence.MODE_FWD else get_backward(seq)
29 | )
30 | warp = Warp(img_frs_seq[start])
31 | print(f"{'Forward' if is_forward else 'Reverse'} mode. {start=}, {end=}, {step=}")
32 | flows = []
33 | poses = []
34 | pos_guider = PositionalGuide()
35 |
36 | for i in tqdm.tqdm(range(start, end, step), "Generating"):
37 | flow = get_flow(img_frs_seq, rafter, step, is_forward, i)
38 | flows.append(flow)
39 |
40 | poster = pos_guider.create_from_flow(flow, ORIGINAL_SIZE, warp)
41 | poses.append(poster)
42 | warped_img = get_warped_img(stylized_frames, ORIGINAL_SIZE, step, warp, flow)
43 |
44 | stylized_img, err = eb.run(
45 | style,
46 | guides=[
47 | (edge[start], edge[i + step], cfg.edg_wgt), # Slower with premask
48 | (img_frs_seq[start], img_frs_seq[i + step], cfg.img_wgt),
49 | (poses[0], poster, cfg.pos_wgt),
50 | (style, warped_img, cfg.wrp_wgt), # Slower with premask
51 | ],
52 | )
53 | stylized_frames.append(stylized_img)
54 | err_list.append(err)
55 |
56 | if not is_forward:
57 | stylized_frames = stylized_frames[::-1]
58 | err_list = err_list[::-1]
59 | flows = flows[::-1]
60 |
61 | return stylized_frames, err_list, flows
62 |
63 |
64 | def get_warped_img(
65 | stylized_frames: list[np.ndarray], ORIGINAL_SIZE, step: int, warp: Warp, flow
66 | ):
67 | stylized_img = stylized_frames[-1] / 255.0
68 | warped_img = warp.run_warping(stylized_img, flow * (-step))
69 | warped_img = cv2.resize(warped_img, ORIGINAL_SIZE)
70 | return warped_img
71 |
72 |
73 | def get_flow(
74 | img_frs_seq: list[np.ndarray],
75 | rafter: RAFT_flow,
76 | step: int,
77 | is_forward: bool,
78 | i: int,
79 | ):
80 | if is_forward:
81 | flow = rafter._compute_flow(img_frs_seq[i], img_frs_seq[i + step])
82 | else:
83 | flow = rafter._compute_flow(img_frs_seq[i + step], img_frs_seq[i])
84 | return flow
85 |
86 |
87 | def run_scratch(
88 | seq: EasySequence,
89 | img_frs_seq: list[np.ndarray],
90 | style_frs: list[np.ndarray],
91 | edge: list[np.ndarray],
92 | cfg: RunConfig,
93 | rafter: RAFT_flow,
94 | eb: ebsynth,
95 | ):
96 | if seq.mode == EasySequence.MODE_BLN and cfg.only_mode != EasySequence.MODE_NON:
97 | print(f"{cfg.only_mode} Only")
98 | stylized_frames, err_list, flow = run_a_pass(
99 | seq,
100 | cfg.only_mode,
101 | img_frs_seq,
102 | style_frs[seq.style_idxs[0]]
103 | if cfg.only_mode == EasySequence.MODE_FWD
104 | else style_frs[seq.style_idxs[1]],
105 | edge,
106 | cfg,
107 | rafter,
108 | eb,
109 | )
110 | return stylized_frames, err_list, flow
111 |
112 | if seq.mode != EasySequence.MODE_BLN:
113 | stylized_frames, err_list, flow = run_a_pass(
114 | seq,
115 | seq.mode,
116 | img_frs_seq,
117 | style_frs[seq.style_idxs[0]],
118 | edge,
119 | cfg,
120 | rafter,
121 | eb,
122 | )
123 | return stylized_frames, err_list, flow
124 |
125 | print("Blending mode")
126 |
127 | style_fwd, err_fwd, flow_fwd = run_a_pass(
128 | seq,
129 | EasySequence.MODE_FWD,
130 | img_frs_seq,
131 | style_frs[seq.style_idxs[0]],
132 | edge,
133 | cfg,
134 | rafter,
135 | eb,
136 | )
137 |
138 | style_bwd, err_bwd, _ = run_a_pass(
139 | seq,
140 | EasySequence.MODE_REV,
141 | img_frs_seq,
142 | style_frs[seq.style_idxs[1]],
143 | edge,
144 | cfg,
145 | rafter,
146 | eb,
147 | )
148 |
149 | return run_blend(img_frs_seq, style_fwd, style_bwd, err_fwd, err_bwd, flow_fwd, cfg)
150 |
151 |
152 | def run_blend(
153 | img_frs_seq: list[np.ndarray],
154 | style_fwd: list[np.ndarray],
155 | style_bwd: list[np.ndarray],
156 | err_fwd: list[np.ndarray],
157 | err_bwd: list[np.ndarray],
158 | flow_fwd: list[np.ndarray],
159 | cfg: RunConfig,
160 | ):
161 | blender = Blend(**cfg.get_blender_cfg())
162 |
163 | err_masks = blender._create_selection_mask(err_fwd, err_bwd)
164 |
165 | warped_masks = blender._warping_masks(img_frs_seq[0], flow_fwd, err_masks)
166 |
167 | hist_blends = blender._hist_blend(style_fwd, style_bwd, warped_masks)
168 |
169 | blends = blender._reconstruct(style_fwd, style_bwd, warped_masks, hist_blends)
170 |
171 | if not cfg.skip_blend_style_last:
172 | blends.append(style_bwd[-1])
173 |
174 | return blends, warped_masks, flow_fwd
175 |
176 |
177 | def get_forward(seq: EasySequence):
178 | start = seq.fr_start_idx
179 | end = seq.fr_end_idx
180 | step = 1
181 | is_forward = True
182 | return start, end, step, is_forward
183 |
184 |
185 | def get_backward(seq: EasySequence):
186 | start = seq.fr_end_idx
187 | end = seq.fr_start_idx
188 | step = -1
189 | is_forward = False
190 | return start, end, step, is_forward
191 |
--------------------------------------------------------------------------------
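A small sketch of the index convention behind get_forward/get_backward: the loop in run_a_pass visits `range(start, end, step)` and synthesizes frame `i + step` on each iteration, so both directions generate every frame except the style frame itself. For a clip spanning frames 0..3 (hypothetical values):

    start, end, step = 0, 3, 1      # forward
    print([(i, i + step) for i in range(start, end, step)])  # [(0, 1), (1, 2), (2, 3)]
    start, end, step = 3, 0, -1     # reverse
    print([(i, i + step) for i in range(start, end, step)])  # [(3, 2), (2, 1), (1, 0)]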
/ezsynth/aux_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import tqdm
8 |
9 |
10 | def validate_option(option, values, default):
11 | return option if option in values else default
12 |
13 |
14 | def save_to_folder(
15 | output_folder: str, base_file_name: str, result_array: np.ndarray
16 | ) -> str:
17 | os.makedirs(output_folder, exist_ok=True)
18 | output_file_path = os.path.join(output_folder, base_file_name)
19 | cv2.imwrite(output_file_path, result_array)
20 | return output_file_path
21 |
22 |
23 | def validate_and_read_img(img: str | np.ndarray) -> np.ndarray:
24 | if isinstance(img, str):
25 | if os.path.isfile(img):
26 | img = cv2.imread(img)
27 | return img
28 | raise ValueError(f"Path does not exist: {img}")
29 |
30 | if isinstance(img, np.ndarray):
31 | if img.shape[-1] == 3:
32 | return img
33 |         raise ValueError(f"Expected a 3-channel image, got shape {img.shape}")
34 |     raise TypeError(f"Unsupported image input type: {type(img)}")
35 |
36 | def load_guide(src_path: str, tgt_path: str, weight=1.0):
37 | src_img = validate_and_read_img(src_path)
38 | tgt_img = validate_and_read_img(tgt_path)
39 | return (src_img, tgt_img, weight)
40 |
41 |
42 | def read_frames_from_paths(lst: list[str]) -> list[np.ndarray]:
43 | img_arr_seq: list[np.ndarray] = []
44 | err_frame = -1
45 |     try:
46 |         total = len(lst)
47 |         for err_frame, img_path in tqdm.tqdm(
48 |             enumerate(lst), desc="Reading images: ", total=total
49 |         ):
50 |             img_arr = validate_and_read_img(img_path)
51 |             img_arr_seq.append(img_arr)
52 |         print(f"Read {len(img_arr_seq)} frames successfully")
53 |         return img_arr_seq
54 |     except Exception as e:
55 |         # err_frame identifies the offending index when a read fails
56 |         raise ValueError(f"Error reading frame {err_frame}: {e}")
57 |
58 |
59 | def read_masks_from_paths(lst: list[str]) -> list[np.ndarray]:
60 | msk_arr_seq: list[np.ndarray] = []
61 | err_frame = -1
62 |     try:
63 |         total = len(lst)
64 |         for err_frame, img_path in tqdm.tqdm(
65 |             enumerate(lst), desc="Reading masks: ", total=total
66 |         ):
67 |             msk = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
68 |             msk_arr_seq.append(msk)
69 |         print(f"Read {len(msk_arr_seq)} masks successfully")
70 |         return msk_arr_seq
71 |     except Exception as e:
72 |         # err_frame identifies the offending index when a read fails
73 |         raise ValueError(f"Error reading mask frame {err_frame}: {e}")
74 |
75 |
76 | img_extensions = (".png", ".jpg", ".jpeg")
77 | img_path_pattern = re.compile(r"(\d+)(?=\.(jpg|jpeg|png)$)", re.IGNORECASE)
78 |
79 |
80 | def get_sequence_indices(seq_folder_path: str) -> list[str]:
81 | if not os.path.isdir(seq_folder_path):
82 | raise ValueError(f"Path does not exist: {seq_folder_path}")
83 | file_names = os.listdir(seq_folder_path)
84 | file_names = sorted(
85 | file_names,
86 | key=lambda x: [int(c) if c.isdigit() else c for c in re.split("([0-9]+)", x)],
87 | )
88 | img_file_paths = [
89 | os.path.join(seq_folder_path, file_name)
90 | for file_name in file_names
91 | if file_name.lower().endswith(img_extensions)
92 | ]
93 | if not img_file_paths:
94 | raise ValueError("No image files found in the directory.")
95 | return img_file_paths
96 |
97 |
98 | def extract_indices(lst: list[str]):
99 | return sorted(int(img_path_pattern.findall(img_name)[-1][0]) for img_name in lst)
100 |
101 |
102 | def is_valid_file_path(input_path: str | list[str]) -> bool:
103 | return isinstance(input_path, str) and os.path.isfile(input_path)
104 |
105 |
106 | def validate_file_or_folder_to_lst(
107 | input_paths: str | list[str], type_name=""
108 | ) -> list[str]:
109 | if is_valid_file_path(input_paths):
110 | return [input_paths] # type: ignore
111 | if isinstance(input_paths, list):
112 | valid_paths = [path for path in input_paths if is_valid_file_path(path)]
113 | if valid_paths:
114 | print(f"Received {len(valid_paths)} {type_name} files")
115 | return valid_paths
116 | raise FileNotFoundError(f"No valid {type_name} file(s) were found. {input_paths}")
117 |
118 |
119 | def setup_src_from_folder(
120 | seq_folder_path: str,
121 | ) -> tuple[list[str], list[int], list[np.ndarray]]:
122 | img_file_paths = get_sequence_indices(seq_folder_path)
123 | img_idxes = extract_indices(img_file_paths)
124 | img_frs_seq = read_frames_from_paths(img_file_paths)
125 | return img_file_paths, img_idxes, img_frs_seq
126 |
127 |
128 | def setup_masks_from_folder(mask_folder_path: str):
129 | msk_file_paths = get_sequence_indices(mask_folder_path)
130 | msk_idxes = extract_indices(msk_file_paths)
131 | msk_frs_seq = read_masks_from_paths(msk_file_paths)
132 | return msk_file_paths, msk_idxes, msk_frs_seq
133 |
134 |
135 | def setup_src_from_lst(paths: list[str], type_name=""):
136 | val_paths = validate_file_or_folder_to_lst(paths, type_name)
137 | val_idxes = extract_indices(val_paths)
138 | frs_seq = read_frames_from_paths(val_paths)
139 | return val_paths, val_idxes, frs_seq
140 |
141 |
142 | def save_seq(results: list, output_folder, base_name="output", extension=".png"):
143 | if not results:
144 | print("Error: No results to save.")
145 | return
146 |     # Frames are written as e.g. output000.png, output001.png, ...
147 |     for i in range(len(results)):
148 |         save_to_folder(
149 |             output_folder,
150 |             f"{base_name}{i:03}{extension}",
151 |             results[i],
152 |         )
153 |     print("All results saved successfully")
154 |     return
155 |
156 |
157 | def replace_zeros_tensor(image: torch.Tensor, replace_value: int = 1) -> torch.Tensor:
158 | zero_mask = image == 0
159 | replace_tensor = torch.full_like(image, replace_value)
160 | return torch.where(zero_mask, replace_tensor, image)
161 |
162 |
163 | def replace_zeros_np(image: np.ndarray, replace_value: int = 1) -> np.ndarray:
164 | zero_mask = image == 0
165 | replace_array = np.full_like(image, replace_value)
166 | return np.where(zero_mask, replace_array, image)
167 |
--------------------------------------------------------------------------------
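get_sequence_indices relies on a natural-sort key, so numeric runs in file names compare as integers rather than strings. A quick sketch of the behavior:

    import re

    key = lambda x: [int(c) if c.isdigit() else c for c in re.split("([0-9]+)", x)]
    print(sorted(["img10.png", "img2.png", "img1.png"], key=key))
    # ['img1.png', 'img2.png', 'img10.png']  (plain sorted() would put img10 before img2)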
/ezsynth/constants.py:
--------------------------------------------------------------------------------
1 | EDGE_METHODS = ["PAGE", "PST", "Classic"]
2 | DEFAULT_EDGE_METHOD = "Classic"
3 |
4 | FLOW_MODELS = ["sintel", "kitti"]
5 | DEFAULT_FLOW_MODEL = "sintel"
6 |
7 | FLOW_ARCHS = ["RAFT", "EF_RAFT", "FLOW_DIFF"]
8 | DEFAULT_FLOW_ARCH = "RAFT"
9 |
10 | EF_RAFT_MODELS = [
11 | "25000_ours-sintel",
12 | "ours_sintel",
13 | "ours-things",
14 | ]
15 | DEFAULT_EF_RAFT_MODEL = "25000_ours-sintel"
16 |
17 | FLOW_DIFF_MODEL = "FlowDiffuser-things"
18 |
--------------------------------------------------------------------------------
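These constants are consumed through validate_option (aux_utils.py), so an unrecognized choice falls back to the default instead of raising. For example:

    from ezsynth.aux_utils import validate_option
    from ezsynth.constants import DEFAULT_EDGE_METHOD, EDGE_METHODS

    print(validate_option("PST", EDGE_METHODS, DEFAULT_EDGE_METHOD))    # "PST"
    print(validate_option("Sobel", EDGE_METHODS, DEFAULT_EDGE_METHOD))  # "Classic"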
/ezsynth/edge_detection.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 | from phycv import PAGE_GPU, PST_GPU
6 |
7 | from .aux_classes import EdgeConfig
8 | from .aux_utils import replace_zeros_tensor
9 |
10 |
11 | class EdgeDetector:
12 | def __init__(self, method="PAGE"):
13 | """
14 | Initialize the edge detector.
15 |
16 | :param method: Edge detection method. Choose from 'PST', 'Classic', or 'PAGE'.
17 | :PST: Phase Stretch Transform (PST) edge detector. - Good overall structure,
18 | but not very detailed.
19 | :Classic: Classic edge detector. - A good balance between structure and detail.
20 |         :PAGE: Phase-Stretch Adaptive Gradient-field Extractor (PAGE) edge detector. -
21 | Great detail, great structure, but slow.
22 | """
23 | self.method = method
24 |         self.device = "cuda" if torch.cuda.is_available() else "cpu"
25 | if method == "PST":
26 | self.pst_gpu = PST_GPU(device=self.device)
27 | elif method == "PAGE":
28 | self.page_gpu = PAGE_GPU(direction_bins=10, device=self.device)
29 | elif method == "Classic":
30 | size, sigma = 5, 6.0
31 | self.kernel = self.create_gaussian_kernel(size, sigma)
32 | self.pad_size = 16
33 |
34 | @staticmethod
35 | def create_gaussian_kernel(size, sigma):
36 | x, y = np.mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1]
37 | g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
38 | return g / g.sum()
39 |
40 | def pad_image(self, img):
41 | return cv2.copyMakeBorder(
42 | img,
43 | self.pad_size,
44 | self.pad_size,
45 | self.pad_size,
46 | self.pad_size,
47 | cv2.BORDER_REFLECT,
48 | )
49 |
50 | def unpad_image(self, img):
51 | return img[self.pad_size : -self.pad_size, self.pad_size : -self.pad_size]
52 |
53 | def classic_preprocess(self, img):
54 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
55 | blurred = cv2.filter2D(gray, -1, self.kernel)
56 | edge_map = cv2.subtract(gray, blurred)
57 | edge_map = np.clip(edge_map + 128, 0, 255)
58 | return edge_map.astype(np.uint8)
59 |
60 | def pst_page_postprocess(self, edge_map: np.ndarray):
61 | edge_map = cv2.GaussianBlur(edge_map, (5, 5), 3)
62 | edge_map = edge_map * 255
63 | return edge_map.astype(np.uint8)
64 |
65 | def pst_run(
66 | self,
67 | input_data: np.ndarray,
68 | S,
69 | W,
70 | sigma_LPF,
71 | thresh_min,
72 | thresh_max,
73 | morph_flag,
74 | ):
75 | input_img = cv2.cvtColor(input_data, cv2.COLOR_BGR2GRAY)
76 |
77 | padded_img = self.pad_image(input_img)
78 |
79 | self.pst_gpu.h = padded_img.shape[0]
80 | self.pst_gpu.w = padded_img.shape[1]
81 |
82 | self.pst_gpu.img = torch.from_numpy(padded_img).to(self.pst_gpu.device)
83 | # If input has too many zeros the model returns NaNs for some reason
84 | self.pst_gpu.img = replace_zeros_tensor(self.pst_gpu.img, 1)
85 |
86 | self.pst_gpu.init_kernel(S, W)
87 | self.pst_gpu.apply_kernel(sigma_LPF, thresh_min, thresh_max, morph_flag)
88 |
89 | edge_map = self.pst_gpu.pst_output.cpu().numpy()
90 | edge_map = self.unpad_image(edge_map)
91 |
92 | return edge_map
93 |
94 | def page_run(
95 | self,
96 | input_data: np.ndarray,
97 | mu_1,
98 | mu_2,
99 | sigma_1,
100 | sigma_2,
101 | S1,
102 | S2,
103 | sigma_LPF,
104 | thresh_min,
105 | thresh_max,
106 | morph_flag,
107 | ):
108 | input_img = cv2.cvtColor(input_data, cv2.COLOR_BGR2GRAY)
109 | padded_img = self.pad_image(input_img)
110 |
111 | self.page_gpu.h = padded_img.shape[0]
112 | self.page_gpu.w = padded_img.shape[1]
113 |
114 | self.page_gpu.img = torch.from_numpy(padded_img).to(self.page_gpu.device)
115 | # If input has too many zeros the model returns NaNs for some reason
116 | self.page_gpu.img = replace_zeros_tensor(self.page_gpu.img, 1)
117 |
118 | self.page_gpu.init_kernel(mu_1, mu_2, sigma_1, sigma_2, S1, S2)
119 | self.page_gpu.apply_kernel(sigma_LPF, thresh_min, thresh_max, morph_flag)
120 | self.page_gpu.create_page_edge()
121 |
122 | edge_map = self.page_gpu.page_edge.cpu().numpy()
123 | edge_map = self.unpad_image(edge_map)
124 | return edge_map
125 |
126 | def compute_edge(self, input_data: np.ndarray):
127 | edge_map = None
128 | if self.method == "PST":
129 | edge_map = self.pst_run(input_data, **EdgeConfig.get_pst_default())
130 | edge_map = self.pst_page_postprocess(edge_map)
131 | return edge_map
132 |
133 | if self.method == "Classic":
134 | edge_map = self.classic_preprocess(input_data)
135 | return edge_map
136 |
137 | if self.method == "PAGE":
138 | edge_map = self.page_run(input_data, **EdgeConfig.get_page_default())
139 | edge_map = self.pst_page_postprocess(edge_map)
140 | return edge_map
141 | return edge_map
142 |
143 |
--------------------------------------------------------------------------------
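A minimal usage sketch; the "Classic" method runs entirely on the CPU, while "PST" and "PAGE" expect a CUDA device. The input path is a hypothetical placeholder:

    import cv2
    from ezsynth.edge_detection import EdgeDetector

    detector = EdgeDetector(method="Classic")
    frame = cv2.imread("frame_000.jpg")
    edges = detector.compute_edge(frame)  # uint8 edge map, same height/width as the input
    cv2.imwrite("edges_000.png", edges)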
/ezsynth/main_ez.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import tqdm
5 |
6 | from .aux_flow_viz import flow_to_image
7 |
8 | from .aux_classes import RunConfig
9 | from .aux_computations import precompute_edge_guides
10 | from .aux_masker import (
11 | apply_masked_back_seq,
12 | apply_masks,
13 | apply_masks_idxes,
14 | )
15 | from .aux_run import run_scratch
16 | from .aux_utils import (
17 | setup_masks_from_folder,
18 | setup_src_from_folder,
19 | setup_src_from_lst,
20 | validate_and_read_img,
21 | validate_option,
22 | )
23 | from .constants import (
24 | DEFAULT_EDGE_METHOD,
25 | DEFAULT_EF_RAFT_MODEL,
26 | DEFAULT_FLOW_ARCH,
27 | DEFAULT_FLOW_MODEL,
28 | EDGE_METHODS,
29 | EF_RAFT_MODELS,
30 | FLOW_ARCHS,
31 | FLOW_DIFF_MODEL,
32 | FLOW_MODELS,
33 | )
34 | from .utils._ebsynth import ebsynth
35 | from .utils.flow_utils.OpticalFlow import RAFT_flow
36 | from .sequences import EasySequence, SequenceManager
37 |
38 |
39 | class EzsynthBase:
40 | def __init__(
41 | self,
42 | style_frs: list[np.ndarray],
43 | style_idxes: list[int],
44 | img_frs_seq: list[np.ndarray],
45 | cfg: RunConfig = RunConfig(),
46 | edge_method="Classic",
47 | raft_flow_model_name="sintel",
48 | do_mask=False,
49 | msk_frs_seq: list[np.ndarray] | None = None,
50 | flow_arch="RAFT",
51 | do_compute_edge=True,
52 | ) -> None:
53 | st = time.time()
54 |
55 | self.style_frs = style_frs
56 | self.style_idxes = style_idxes
57 | self.img_frs_seq = img_frs_seq
58 | self.msk_frs_seq = msk_frs_seq or []
59 |
60 | self.len_img = len(self.img_frs_seq)
61 | self.len_msk = len(self.msk_frs_seq)
62 | self.len_stl = len(self.style_idxes)
63 |
64 | self.msk_frs_seq = self.msk_frs_seq[: self.len_img]
65 |
66 | self.cfg = cfg
67 | self.edge_method = validate_option(
68 | edge_method, EDGE_METHODS, DEFAULT_EDGE_METHOD
69 | )
70 |         # self.flow_model is assigned per-architecture below; flow_arch is
71 |         # validated first, so exactly one of the branches applies and sets
72 |         # the model name.
73 |
74 | self.flow_arch = validate_option(flow_arch, FLOW_ARCHS, DEFAULT_FLOW_ARCH)
75 |
76 | if self.flow_arch == "RAFT":
77 | self.flow_model = validate_option(
78 | raft_flow_model_name, FLOW_MODELS, DEFAULT_FLOW_MODEL
79 | )
80 | elif self.flow_arch == "EF_RAFT":
81 | self.flow_model = validate_option(
82 | raft_flow_model_name, EF_RAFT_MODELS, DEFAULT_EF_RAFT_MODEL
83 | )
84 | elif self.flow_arch == "FLOW_DIFF":
85 | self.flow_model = FLOW_DIFF_MODEL
86 |
87 | self.cfg.do_mask = do_mask and self.len_msk > 0
88 | print(f"Masking mode: {self.cfg.do_mask}")
89 |
90 | if self.cfg.do_mask and len(self.msk_frs_seq) != len(self.img_frs_seq):
91 | raise ValueError(
92 | f"Missing frames: Masks={self.len_msk}, Expected {self.len_img}"
93 | )
94 |
95 | self.style_masked_frs = None
96 | if self.cfg.do_mask and self.cfg.pre_mask:
97 | self.masked_frs_seq = apply_masks(self.img_frs_seq, self.msk_frs_seq)
98 | self.style_masked_frs = apply_masks_idxes(
99 | self.style_frs, self.msk_frs_seq, self.style_idxes
100 | )
101 |
102 | manager = SequenceManager(
103 | 0,
104 | self.len_img - 1,
105 | self.len_stl,
106 | self.style_idxes,
107 | list(range(0, self.len_img)),
108 | )
109 |
110 | self.sequences, self.atlas = manager.create_sequences()
111 | self.num_seqs = len(self.sequences)
112 |
113 | self.edge_guides = []
114 | if do_compute_edge:
115 | self.edge_guides = precompute_edge_guides(
116 | self.masked_frs_seq
117 | if (self.cfg.do_mask and self.cfg.pre_mask)
118 | else self.img_frs_seq,
119 | self.edge_method,
120 | )
121 | self.rafter = RAFT_flow(model_name=self.flow_model, arch=self.flow_arch)
122 |
123 | self.eb = ebsynth(**cfg.get_ebsynth_cfg())
124 | self.eb.runner.initialize_libebsynth()
125 |
126 | print(f"Init Ezsynth took: {time.time() - st:.4f} s")
127 |
128 | def run_sequences(self, cfg_only_mode: str | None = None):
129 | stylized_frames, err_frames, _ = self.run_sequences_full(
130 | cfg_only_mode=cfg_only_mode, return_flow=False
131 | )
132 | return stylized_frames, err_frames
133 |
134 | def run_sequences_full(self, cfg_only_mode: str | None = None, return_flow=False):
135 | st = time.time()
136 |
137 | if len(self.edge_guides) == 0:
138 | raise ValueError("Edge guides were not computed. ")
139 | if len(self.edge_guides) != self.len_img:
140 | raise ValueError(
141 | f"Missing edge guides: Got {len(self.edge_guides)}, expected {self.len_img}"
142 | )
143 |
144 | if (
145 | cfg_only_mode is not None
146 | and cfg_only_mode in EasySequence.get_valid_modes()
147 | ):
148 | self.cfg.only_mode = cfg_only_mode
149 |
150 | no_skip_rev = False
151 |
152 | stylized_frames = []
153 | err_frames = []
154 | flow_frames = []
155 |
156 | img_seq = (
157 | self.masked_frs_seq
158 | if (self.cfg.do_mask and self.cfg.pre_mask)
159 | else self.img_frs_seq
160 | )
161 | stl_seq = (
162 | self.style_masked_frs
163 | if (self.cfg.do_mask and self.cfg.pre_mask)
164 | else self.style_frs
165 | )
166 |
167 | for i, seq in enumerate(self.sequences):
168 | if self._should_skip_blend_style_last(i):
169 | self.cfg.skip_blend_style_last = True
170 | else:
171 | self.cfg.skip_blend_style_last = False
172 |
173 | if self._should_rev_move_fr(i):
174 | seq.fr_start_idx += 1
175 | no_skip_rev = True
176 |
177 | tmp_stylized_frames, tmp_err_frames, tmp_flow = run_scratch(
178 | seq,
179 | img_seq,
180 | stl_seq,
181 | self.edge_guides,
182 | self.cfg,
183 | self.rafter,
184 | self.eb,
185 | )
186 |
187 | if self._should_remove_first_fr(i, no_skip_rev):
188 | tmp_stylized_frames.pop(0)
189 | tmp_err_frames.pop(0)
190 | tmp_flow.pop(0)
191 |
192 | no_skip_rev = False
193 |
194 | stylized_frames.extend(tmp_stylized_frames)
195 | err_frames.extend(tmp_err_frames)
196 | flow_frames.extend(tmp_flow)
197 |
198 | print(f"Run took: {time.time() - st:.4f} s")
199 |
200 | if self.cfg.do_mask:
201 | stylized_frames = apply_masked_back_seq(
202 | self.img_frs_seq, stylized_frames, self.msk_frs_seq, self.cfg.feather
203 | )
204 |
205 | final_flows: list[np.ndarray] = []
206 | if return_flow:
207 | for flow in tqdm.tqdm(flow_frames, desc="Converting flows"):
208 | final_flows.append(flow_to_image(flow, convert_to_bgr=True))
209 |
210 | return stylized_frames, err_frames, final_flows
211 |
212 | def _should_skip_blend_style_last(self, i: int) -> bool:
213 | if (
214 | self.cfg.only_mode == EasySequence.MODE_NON
215 | and i < self.num_seqs - 1
216 | and (
217 | self.atlas[i] == EasySequence.MODE_BLN
218 | or self.atlas[i + 1] != EasySequence.MODE_FWD
219 | )
220 | ):
221 | return True
222 | return False
223 |
224 | def _should_rev_move_fr(self, i: int) -> bool:
225 | if (
226 | i > 0
227 | and self.cfg.only_mode == EasySequence.MODE_REV
228 | and self.atlas[i] == EasySequence.MODE_BLN
229 | ):
230 | return True
231 | return False
232 |
233 | def _should_remove_first_fr(self, i: int, no_skip_rev: bool) -> bool:
234 | if i > 0 and not no_skip_rev:
235 | if (self.atlas[i - 1] == EasySequence.MODE_REV) or (
236 | self.atlas[i - 1] == EasySequence.MODE_BLN
237 | and (
238 | self.cfg.only_mode == EasySequence.MODE_FWD
239 | or (
240 | self.cfg.only_mode == EasySequence.MODE_REV
241 | and self.atlas[i] == EasySequence.MODE_FWD
242 | )
243 | )
244 | ):
245 | return True
246 | return False
247 |
248 |
249 | class Ezsynth(EzsynthBase):
250 | def __init__(
251 | self,
252 | style_paths: list[str],
253 | image_folder: str,
254 | cfg: RunConfig = RunConfig(),
255 | edge_method="Classic",
256 | raft_flow_model_name="sintel",
257 | mask_folder: str | None = None,
258 | do_mask=False,
259 | flow_arch="RAFT",
260 | ) -> None:
261 | _, img_idxes, img_frs_seq = setup_src_from_folder(image_folder)
262 | _, style_idxes, style_frs = setup_src_from_lst(style_paths, "style")
263 |         msk_frs_seq = setup_masks_from_folder(mask_folder)[2] if do_mask and mask_folder else None
264 |
265 | if img_idxes[0] != 0:
266 | style_idxes = [idx - img_idxes[0] for idx in style_idxes]
267 |
268 | super().__init__(
269 | style_frs=style_frs,
270 | style_idxes=style_idxes,
271 | img_frs_seq=img_frs_seq,
272 | cfg=cfg,
273 | edge_method=edge_method,
274 | raft_flow_model_name=raft_flow_model_name,
275 | do_mask=do_mask,
276 | msk_frs_seq=msk_frs_seq,
277 | flow_arch=flow_arch,
278 | )
279 |
280 |
281 | class ImageSynthBase:
282 | def __init__(
283 | self,
284 | style_img: np.ndarray,
285 | src_img: np.ndarray,
286 | tgt_img: np.ndarray,
287 | cfg: RunConfig = RunConfig(),
288 | ) -> None:
289 | self.style_img = style_img
290 | self.src_img = src_img
291 | self.tgt_img = tgt_img
292 | self.cfg = cfg
293 |
294 | st = time.time()
295 |
296 | self.eb = ebsynth(**cfg.get_ebsynth_cfg())
297 | self.eb.runner.initialize_libebsynth()
298 |
299 | print(f"Init ImageSynth took: {time.time() - st:.4f} s")
300 |
301 |     def run(self, guides: list[tuple[np.ndarray, np.ndarray, float]] = []):
302 |         guides = guides + [(self.src_img, self.tgt_img, self.cfg.img_wgt)]  # copy; never mutate the shared default
303 |         return self.eb.run(self.style_img, guides=guides)
304 |
305 |
306 | class ImageSynth(ImageSynthBase):
307 | def __init__(
308 | self,
309 | style_path: str,
310 | src_path: str,
311 | tgt_path: str,
312 | cfg: RunConfig = RunConfig(),
313 | ) -> None:
314 | style_img = validate_and_read_img(style_path)
315 | src_img = validate_and_read_img(src_path)
316 | tgt_img = validate_and_read_img(tgt_path)
317 |
318 | super().__init__(style_img, src_img, tgt_img, cfg)
319 |
--------------------------------------------------------------------------------
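An end-to-end sketch of the high-level API, assuming two style keyframes and an image-sequence folder (all paths hypothetical); a CUDA-capable setup and downloaded RAFT weights are needed for the flow step:

    from ezsynth.aux_classes import RunConfig
    from ezsynth.aux_utils import save_seq
    from ezsynth.main_ez import Ezsynth

    ez = Ezsynth(
        style_paths=["styles/style000.jpg", "styles/style010.png"],
        image_folder="input",
        cfg=RunConfig(),
    )
    stylized_frames, err_frames = ez.run_sequences()
    save_seq(stylized_frames, "output")

For single-image stylization, ImageSynth wraps the same ebsynth backend:

    from ezsynth.main_ez import ImageSynth

    synth = ImageSynth("style.png", "source.png", "target.png")
    result_img, err = synth.run()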
/ezsynth/sequences.py:
--------------------------------------------------------------------------------
1 | class EasySequence:
2 | MODE_FWD = "forward"
3 | MODE_REV = "reverse"
4 | MODE_BLN = "blend"
5 | MODE_NON = "none"
6 |
7 | def __init__(
8 | self, fr_start_idx: int, fr_end_idx: int, mode: str, style_idxs: list[int]
9 | ) -> None:
10 | self.fr_start_idx = fr_start_idx
11 | self.fr_end_idx = fr_end_idx
12 | self.mode = mode
13 | self.style_idxs = style_idxs
14 |
15 | def __repr__(self) -> str:
16 | return f"[{self.fr_start_idx}, {self.fr_end_idx}] {self.mode} {self.style_idxs}"
17 |
18 | @classmethod
19 | def get_valid_modes(cls) -> tuple[str, str, str, str]:
20 | return (
21 | cls.MODE_FWD,
22 | cls.MODE_REV,
23 | cls.MODE_BLN,
24 | cls.MODE_NON,
25 | )
26 |
27 |
28 | class SequenceManager:
29 | def __init__(self, begin_fr_idx, end_fr_idx, num_style_frs, style_idxs, img_idxs):
30 | self.begin_fr_idx = begin_fr_idx
31 | self.end_fr_idx = end_fr_idx
32 | self.style_idxs = style_idxs
33 | self.img_idxs = img_idxs
34 | self.num_style_frs = num_style_frs
35 |
36 | def create_sequences(self) -> tuple[list[EasySequence], list[str]]:
37 | sequences = []
38 | atlas: list[str] = []
39 |
40 | # Handle sequence before the first style frame
41 | if self.begin_fr_idx < self.style_idxs[0]:
42 | sequences.append(
43 | EasySequence(
44 | fr_start_idx=self.begin_fr_idx,
45 | fr_end_idx=self.style_idxs[0],
46 | mode=EasySequence.MODE_REV,
47 | style_idxs=[0],
48 | )
49 | )
50 | atlas.append(EasySequence.MODE_REV)
51 |
52 | # Handle sequences between style frames
53 | for i in range(len(self.style_idxs) - 1):
54 | sequences.append(
55 | EasySequence(
56 | fr_start_idx=self.style_idxs[i],
57 | fr_end_idx=self.style_idxs[i + 1],
58 | mode=EasySequence.MODE_BLN,
59 | style_idxs=[i, i + 1],
60 | )
61 | )
62 | atlas.append(EasySequence.MODE_BLN)
63 |
64 | # Handle sequence after the last style frame
65 | if self.end_fr_idx > self.style_idxs[-1]:
66 | sequences.append(
67 | EasySequence(
68 | fr_start_idx=self.style_idxs[-1],
69 | fr_end_idx=self.end_fr_idx,
70 | mode=EasySequence.MODE_FWD,
71 | style_idxs=[self.num_style_frs - 1],
72 | )
73 | )
74 | atlas.append(EasySequence.MODE_FWD)
75 |
76 | for seq in sequences:
77 | print(f"{seq}")
78 | return sequences, atlas
79 |
--------------------------------------------------------------------------------
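A quick sketch of how keyframe placement maps to sequences: with style frames at indices 0 and 5 over an 11-frame clip, the manager emits one blend sequence between the keys and one forward sequence after the last key:

    from ezsynth.sequences import SequenceManager

    manager = SequenceManager(0, 10, 2, [0, 5], list(range(11)))
    sequences, atlas = manager.create_sequences()
    print(atlas)  # ['blend', 'forward']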
/ezsynth/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/__init__.py
--------------------------------------------------------------------------------
/ezsynth/utils/_eb.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from ctypes import (
3 | CDLL,
4 | POINTER,
5 | c_float,
6 | c_int,
7 | c_void_p,
8 | create_string_buffer,
9 | )
10 | from pathlib import Path
11 |
12 | import numpy as np
13 |
14 |
15 | class EbsynthRunner:
16 | EBSYNTH_BACKEND_CPU = 0x0001
17 | EBSYNTH_BACKEND_CUDA = 0x0002
18 | EBSYNTH_BACKEND_AUTO = 0x0000
19 | EBSYNTH_MAX_STYLE_CHANNELS = 8
20 | EBSYNTH_MAX_GUIDE_CHANNELS = 24
21 | EBSYNTH_VOTEMODE_PLAIN = 0x0001 # weight = 1
22 | EBSYNTH_VOTEMODE_WEIGHTED = 0x0002 # weight = 1/(1+error)
23 |
24 | def __init__(self):
25 | self.libebsynth = None
26 | self.cached_buffer = {}
27 | self.cached_err_buffer = {}
28 |
29 | def initialize_libebsynth(self):
30 | if self.libebsynth is None:
31 | if sys.platform[0:3] == "win":
32 | libebsynth_path = str(Path(__file__).parent / "ebsynth.dll")
33 | self.libebsynth = CDLL(libebsynth_path)
34 | # elif sys.platform == "darwin":
35 | # libebsynth_path = str(Path(__file__).parent / "ebsynth.so")
36 | # self.libebsynth = CDLL(libebsynth_path)
37 | elif sys.platform[0:5] == "linux":
38 | libebsynth_path = str(Path(__file__).parent / "ebsynth.so")
39 | self.libebsynth = CDLL(libebsynth_path)
40 | else:
41 | raise RuntimeError("Unsupported platform.")
42 |
43 | if self.libebsynth is not None:
44 | self.libebsynth.ebsynthRun.argtypes = (
45 | c_int,
46 | c_int,
47 | c_int,
48 | c_int,
49 | c_int,
50 | c_void_p,
51 | c_void_p,
52 | c_int,
53 | c_int,
54 | c_void_p,
55 | c_void_p,
56 | POINTER(c_float),
57 | POINTER(c_float),
58 | c_float,
59 | c_int,
60 | c_int,
61 | c_int,
62 | POINTER(c_int),
63 | POINTER(c_int),
64 | POINTER(c_int),
65 | c_int,
66 | c_void_p,
67 | c_void_p,
68 | c_void_p,
69 | )
70 |             # restype is left at the ctypes default (c_int)
71 |
72 | def get_or_create_buffer(self, key):
73 | buffer = self.cached_buffer.get(key, None)
74 | if buffer is None:
75 | buffer = create_string_buffer(key[0] * key[1] * key[2])
76 | self.cached_buffer[key] = buffer
77 | return buffer
78 |
79 | def get_or_create_err_buffer(self, key):
80 | errbuffer = self.cached_err_buffer.get(key, None)
81 | if errbuffer is None:
82 | errbuffer = (c_float * (key[0] * key[1]))()
83 | self.cached_err_buffer[key] = errbuffer
84 | return errbuffer
85 |
86 | # def _normalize_img_shape(self, img: np.ndarray) -> np.ndarray:
87 | # # with self.normalize_lock:
88 | # img_len = len(img.shape)
89 | # if img_len == 2:
90 | # sh, sw = img.shape
91 | # sc = 0
92 | # elif img_len == 3:
93 | # sh, sw, sc = img.shape
94 |
95 | # if sc == 0:
96 | # sc = 1
97 |
98 | # return img
99 |
100 | # def _normalize_img_shape(self, img: np.ndarray) -> np.ndarray:
101 | # if len(img.shape) == 2:
102 | # img = img[..., np.newaxis]
103 | # return img
104 |
105 | def _normalize_img_shape(self, img: np.ndarray) -> np.ndarray:
106 | return np.atleast_3d(img)
107 |
108 | def validate_inputs(self, patch_size: int, guides: list):
109 | # Validation checks
110 | if patch_size < 3:
111 | raise ValueError("patch_size is too small")
112 | if patch_size % 2 == 0:
113 | raise ValueError("patch_size must be an odd number")
114 | if len(guides) == 0:
115 | raise ValueError("at least one guide must be specified")
116 |
117 | def run(
118 | self,
119 | img_style,
120 | guides,
121 | patch_size=5,
122 | num_pyramid_levels=-1,
123 | num_search_vote_iters=6,
124 | num_patch_match_iters=4,
125 | stop_threshold=5,
126 | uniformity_weight=3500.0,
127 | extraPass3x3=False,
128 | ):
129 | self.validate_inputs(patch_size, guides)
130 |
131 |         # libebsynth must already be initialized by the caller
132 |         # (e.g. eb.runner.initialize_libebsynth() in main_ez)
133 |
134 | img_style = self._normalize_img_shape(img_style)
135 | sh, sw, sc = img_style.shape
136 | t_h, t_w, t_c = 0, 0, 0
137 |
138 | self.validate_style_channels(sc)
139 |
140 | guides_source = []
141 | guides_target = []
142 | guides_weights = []
143 |
144 | t_h, t_w = self.validate_guides(
145 | guides, sh, sw, t_c, guides_source, guides_target, guides_weights
146 | )
147 |
148 | guides_source = np.concatenate(guides_source, axis=-1)
149 | guides_target = np.concatenate(guides_target, axis=-1)
150 | guides_weights = (c_float * len(guides_weights))(*guides_weights)
151 |
152 | style_weights = [1.0 / sc for _ in range(sc)]
153 | style_weights = (c_float * sc)(*style_weights)
154 |
155 | maxPyramidLevels = self.get_max_pyramid_level(patch_size, sh, sw, t_h, t_w)
156 |
157 | (
158 | num_pyramid_levels,
159 | num_search_vote_iters_per_level,
160 | num_patch_match_iters_per_level,
161 | stop_threshold_per_level,
162 | ) = self.validate_per_levels(
163 | num_pyramid_levels,
164 | num_search_vote_iters,
165 | num_patch_match_iters,
166 | stop_threshold,
167 | maxPyramidLevels,
168 | )
169 |
170 | # Get or create buffers
171 | buffer = self.get_or_create_buffer((t_h, t_w, sc))
172 | errbuffer = self.get_or_create_err_buffer((t_h, t_w))
173 |
174 | self.libebsynth.ebsynthRun(
175 |             self.EBSYNTH_BACKEND_AUTO,  # backend (always AUTO here; the wrapper's backend choice is not forwarded)
176 | sc, # numStyleChannels
177 | guides_source.shape[-1], # numGuideChannels
178 | sw, # sourceWidth
179 | sh, # sourceHeight
180 | img_style.tobytes(), # sourceStyleData (width * height * numStyleChannels) bytes, scan-line order
181 | guides_source.tobytes(), # sourceGuideData (width * height * numGuideChannels) bytes, scan-line order
182 | t_w, # targetWidth
183 | t_h, # targetHeight
184 | guides_target.tobytes(), # targetGuideData (width * height * numGuideChannels) bytes, scan-line order
185 | None, # targetModulationData (width * height * numGuideChannels) bytes, scan-line order; pass NULL to switch off the modulation
186 | style_weights, # styleWeights (numStyleChannels) floats
187 | guides_weights, # guideWeights (numGuideChannels) floats
188 | uniformity_weight, # uniformityWeight reasonable values are between 500-15000, 3500 is a good default
189 | patch_size, # patchSize odd sizes only, use 5 for 5x5 patch, 7 for 7x7, etc.
190 | self.EBSYNTH_VOTEMODE_WEIGHTED, # voteMode use VOTEMODE_WEIGHTED for sharper result
191 | num_pyramid_levels, # numPyramidLevels
192 | num_search_vote_iters_per_level, # numSearchVoteItersPerLevel how many search/vote iters to perform at each level (array of ints, coarse first, fine last)
193 | num_patch_match_iters_per_level, # numPatchMatchItersPerLevel how many Patch-Match iters to perform at each level (array of ints, coarse first, fine last)
194 | stop_threshold_per_level, # stopThresholdPerLevel stop improving pixel when its change since last iteration falls under this threshold
195 | 1
196 | if extraPass3x3
197 | else 0, # extraPass3x3 perform additional polishing pass with 3x3 patches at the finest level, use 0 to disable
198 | None, # outputNnfData (width * height * 2) ints, scan-line order; pass NULL to ignore
199 | buffer, # outputImageData (width * height * numStyleChannels) bytes, scan-line order
200 | errbuffer, # outputErrorData (width * height) floats, scan-line order; pass NULL to ignore
201 | )
202 |
203 | img = np.frombuffer(buffer, dtype=np.uint8).reshape((t_h, t_w, sc)).copy()
204 | err = np.frombuffer(errbuffer, dtype=np.float32).reshape((t_h, t_w)).copy()
205 |
206 | return img, err
207 |
208 | def get_max_pyramid_level(self, patch_size, sh, sw, t_h, t_w):
209 | maxPyramidLevels = 0
210 | min_a = min(sh, t_h)
211 | min_b = min(sw, t_w)
212 | for level in range(32, -1, -1):
213 | pow_a = pow(2.0, -level)
214 | if min(min_a * pow_a, min_b * pow_a) >= (2 * patch_size + 1):
215 | maxPyramidLevels = level + 1
216 | break
217 | return maxPyramidLevels
218 |
219 | def validate_per_levels(
220 | self,
221 | num_pyramid_levels,
222 | num_search_vote_iters,
223 | num_patch_match_iters,
224 | stop_threshold,
225 | maxPyramidLevels,
226 | ):
227 | if num_pyramid_levels == -1:
228 | num_pyramid_levels = maxPyramidLevels
229 | num_pyramid_levels = min(num_pyramid_levels, maxPyramidLevels)
230 |
231 | num_search_vote_iters_per_level = (c_int * num_pyramid_levels)(
232 | *[num_search_vote_iters] * num_pyramid_levels
233 | )
234 | num_patch_match_iters_per_level = (c_int * num_pyramid_levels)(
235 | *[num_patch_match_iters] * num_pyramid_levels
236 | )
237 | stop_threshold_per_level = (c_int * num_pyramid_levels)(
238 | *[stop_threshold] * num_pyramid_levels
239 | )
240 |
241 | return (
242 | num_pyramid_levels,
243 | num_search_vote_iters_per_level,
244 | num_patch_match_iters_per_level,
245 | stop_threshold_per_level,
246 | )
247 |
248 | def validate_style_channels(self, sc):
249 | if sc > self.EBSYNTH_MAX_STYLE_CHANNELS:
250 | raise ValueError(
251 | f"error: too many style channels {sc}, maximum number is {self.EBSYNTH_MAX_STYLE_CHANNELS}"
252 | )
253 |
254 | def validate_guides(
255 | self, guides, sh, sw, t_c, guides_source, guides_target, guides_weights
256 | ):
257 | for i in range(len(guides)):
258 | source_guide, target_guide, guide_weight = guides[i]
259 | source_guide = self._normalize_img_shape(source_guide)
260 | target_guide = self._normalize_img_shape(target_guide)
261 | s_h, s_w, s_c = source_guide.shape
262 | nt_h, nt_w, nt_c = target_guide.shape
263 |
264 | if s_h != sh or s_w != sw:
265 |                 raise ValueError(
266 |                     "guide source resolution must match the style resolution."
267 |                 )
268 |
269 | if t_c == 0:
270 | t_h, t_w, t_c = nt_h, nt_w, nt_c
271 | elif nt_h != t_h or nt_w != t_w:
272 | raise ValueError("guides target resolutions must be equal")
273 |
274 | if s_c != nt_c:
275 | raise ValueError("guide source and target channels must match exactly.")
276 |
277 | guides_source.append(source_guide)
278 | guides_target.append(target_guide)
279 |
280 | guides_weights.extend([guide_weight / s_c] * s_c)
281 | return t_h, t_w
282 |
--------------------------------------------------------------------------------
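One detail worth noting in validate_guides: each guide pair's weight is split evenly across its channels, so a 3-channel color guide and a 1-channel edge guide with the same nominal weight contribute the same total. A tiny worked example:

    guide_weight, s_c = 6.0, 3
    per_channel = [guide_weight / s_c] * s_c
    print(per_channel, sum(per_channel))  # [2.0, 2.0, 2.0] 6.0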
/ezsynth/utils/_ebsynth.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from ._eb import EbsynthRunner
4 |
5 |
6 | class ebsynth:
7 | """
8 | EBSynth class provides a wrapper around the ebsynth style transfer method.
9 |
10 |     Usage:
11 |         eb = ebsynth()
12 |         result_img, err = eb.run(style_img, guides=[(src_img, tgt_img, 1.0)])
13 | """
14 |
15 | def __init__(
16 | self,
17 | uniformity=3500.0,
18 | patchsize=5,
19 | pyramidlevels=6,
20 | searchvoteiters=12,
21 | patchmatchiters=6,
22 | extrapass3x3=True,
23 | backend="auto",
24 | ):
25 | """
26 | Initialize the EBSynth wrapper.
27 |         Note: the style image and guides are not constructor arguments; they
28 |         are passed per call to `run(style, guides)`, where each guide is a
29 |         (source, target, weight) tuple.
30 | :param uniformity: uniformity weight for the style transfer. Defaults to 3500.0.
31 | :param patchsize: size of the patches. Must be an odd number. Defaults to 5. [5x5 patches]
32 |         :param pyramidlevels: number of pyramid levels. Larger values are useful for things like color transfer. Defaults to 6.
33 | :param searchvoteiters: number of search/vote iterations. Defaults to 12.
34 | :param patchmatchiters: number of Patch-Match iterations. Defaults to 6.
35 |         :param extrapass3x3: whether to perform an extra pass with 3x3 patches. Defaults to True.
36 | :param backend: backend to use ('cpu', 'cuda', or 'auto'). Defaults to 'auto'.
37 | """
38 |
39 | self.runner = EbsynthRunner()
40 | self.uniformity = uniformity
41 | self.patchsize = patchsize
42 | self.pyramidlevels = pyramidlevels
43 | self.searchvoteiters = searchvoteiters
44 | self.patchmatchiters = patchmatchiters
45 | self.extrapass3x3 = extrapass3x3
46 |
47 | # Define backend constants
48 | self.backends = {
49 | "cpu": EbsynthRunner.EBSYNTH_BACKEND_CPU,
50 | "cuda": EbsynthRunner.EBSYNTH_BACKEND_CUDA,
51 | "auto": EbsynthRunner.EBSYNTH_BACKEND_AUTO,
52 | }
53 | self.backend = self.backends[backend]
54 |
55 |     def run(self, style: np.ndarray, guides: list[tuple[np.ndarray, np.ndarray, float]]):
56 | # Call the run function with the provided arguments
57 | img, err = self.runner.run(
58 | style,
59 | guides,
60 | patch_size=self.patchsize,
61 | num_pyramid_levels=self.pyramidlevels,
62 | num_search_vote_iters=self.searchvoteiters,
63 | num_patch_match_iters=self.patchmatchiters,
64 | uniformity_weight=self.uniformity,
65 | extraPass3x3=self.extrapass3x3,
66 | )
67 |
68 | return img, err
69 |
--------------------------------------------------------------------------------
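A minimal sketch of driving the wrapper directly (image paths are hypothetical; the source guide must match the style image's resolution, and the output takes the target guide's resolution):

    import cv2
    from ezsynth.utils._ebsynth import ebsynth

    eb = ebsynth(backend="auto")
    eb.runner.initialize_libebsynth()  # loads ebsynth.dll / ebsynth.so once

    style = cv2.imread("source_style.png")
    src = cv2.imread("source_guide.png")
    tgt = cv2.imread("target_guide.png")
    img, err = eb.run(style, guides=[(src, tgt, 1.0)])
    cv2.imwrite("out.png", img)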
/ezsynth/utils/blend/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/blend/__init__.py
--------------------------------------------------------------------------------
/ezsynth/utils/blend/blender.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import tqdm
5 |
6 | from ..flow_utils.warp import Warp
7 | from .histogram_blend import hist_blender
8 | from .reconstruction import reconstructor
9 |
10 | try:
11 | from .cupy_accelerated import hist_blend_cupy
12 |
13 | USE_GPU = True
14 | except ImportError as e:
15 | print(f"Cupy is not installed. Revert to CPU. {e}")
16 | USE_GPU = False
17 |
18 |
19 | class Blend:
20 | def __init__(
21 | self,
22 | use_gpu=False,
23 | use_lsqr=True,
24 | use_poisson_cupy=False,
25 | poisson_maxiter=None,
26 | ):
27 | self.prev_mask = None
28 |
29 | self.use_gpu = use_gpu and USE_GPU
30 | self.use_lsqr = use_lsqr
31 | self.use_poisson_cupy = use_poisson_cupy
32 | self.poisson_maxiter = poisson_maxiter
33 |
34 | def _warping_masks(
35 | self,
36 | sample_fr: np.ndarray,
37 | flow_fwd: list[np.ndarray],
38 | err_masks: list[np.ndarray],
39 | ):
40 | # use err_masks with flow to create final err_masks
41 | warped_masks = []
42 | warp = Warp(sample_fr)
43 |
44 | for i in tqdm.tqdm(range(len(err_masks)), desc="Warping masks"):
45 | if self.prev_mask is None:
46 | self.prev_mask = np.zeros_like(err_masks[0])
47 | warped_mask = warp.run_warping(
48 | err_masks[i], flow_fwd[i] if i == 0 else flow_fwd[i - 1]
49 | )
50 |
51 | z_hat = warped_mask.copy()
52 | # If the shapes are not compatible, we can adjust the shape of self.prev_mask
53 | if self.prev_mask.shape != z_hat.shape:
54 | self.prev_mask = np.repeat(
55 | self.prev_mask[:, :, np.newaxis], z_hat.shape[2], axis=2
56 | )
57 |
58 | z_hat = np.where((self.prev_mask > 1) & (z_hat == 0), 1, z_hat)
59 |
60 | self.prev_mask = z_hat.copy()
61 | warped_masks.append(z_hat.copy())
62 | return warped_masks
63 |
64 | def _create_selection_mask(
65 | self, err_forward_lst: list[np.ndarray], err_backward_lst: list[np.ndarray]
66 | ) -> list[np.ndarray]:
67 | err_forward = np.array(err_forward_lst)
68 | err_backward = np.array(err_backward_lst)
69 |
70 | if err_forward.shape != err_backward.shape:
71 | print(f"Shape mismatch: {err_forward.shape=} vs {err_backward.shape=}")
72 | return []
73 |
74 | # Create a binary mask where the forward error metric
75 | # is less than the backward error metric
76 | selection_masks = np.where(err_forward < err_backward, 0, 1).astype(np.uint8)
77 |
78 | # Convert numpy array back to list
79 | selection_masks_lst = [
80 | selection_masks[i] for i in range(selection_masks.shape[0])
81 | ]
82 |
83 | return selection_masks_lst
84 |
85 | def _hist_blend(
86 | self,
87 | style_fwd: list[np.ndarray],
88 | style_bwd: list[np.ndarray],
89 | err_masks: list[np.ndarray],
90 | ) -> list[np.ndarray]:
91 | st = time.time()
92 | hist_blends: list[np.ndarray] = []
93 | for i in tqdm.tqdm(range(len(err_masks)), desc="Hist blending: "):
94 | if self.use_gpu:
95 | hist_blend = hist_blend_cupy(
96 | style_fwd[i],
97 | style_bwd[i],
98 | err_masks[i],
99 | )
100 | else:
101 | hist_blend = hist_blender(
102 | style_fwd[i],
103 | style_bwd[i],
104 | err_masks[i],
105 | )
106 | hist_blends.append(hist_blend)
107 | print(f"Hist Blend took {time.time() - st:.4f} s")
108 |         print(f"Produced {len(hist_blends)} histogram blends")
109 | return hist_blends
110 |
111 | def _reconstruct(
112 | self,
113 | style_fwd: list[np.ndarray],
114 | style_bwd: list[np.ndarray],
115 | err_masks: list[np.ndarray],
116 | hist_blends: list[np.ndarray],
117 | ):
118 | blends = reconstructor(
119 | hist_blends,
120 | style_fwd,
121 | style_bwd,
122 | err_masks,
123 | use_gpu=self.use_gpu,
124 | use_lsqr=self.use_lsqr,
125 | use_poisson_cupy=self.use_poisson_cupy,
126 | poisson_maxiter=self.poisson_maxiter,
127 | )
128 | final_blends = blends._create()
129 | final_blends = [blend for blend in final_blends if blend is not None]
130 | return final_blends
131 |
--------------------------------------------------------------------------------
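The selection-mask convention in _create_selection_mask: 0 selects the forward-pass pixel, 1 the backward-pass pixel. A two-pixel sketch:

    import numpy as np

    err_fwd = np.array([[0.1, 0.9]])
    err_bwd = np.array([[0.5, 0.2]])
    mask = np.where(err_fwd < err_bwd, 0, 1).astype(np.uint8)
    print(mask)  # [[0 1]] -> take forward on the left, backward on the right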
/ezsynth/utils/blend/cupy_accelerated.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import cupy as cp
4 | import cupyx
5 | import cupyx.scipy.sparse._csc
6 | import cupyx.scipy.sparse.linalg
7 | import cv2
8 | import numpy as np
9 | import scipy.sparse
10 |
11 |
12 | def assemble_min_error_img(a, b, error_mask):
13 | return cp.where(error_mask == 0, a, b)
14 |
15 |
16 | def mean_std(img):
17 | return cp.mean(img, axis=(0, 1)), cp.std(img, axis=(0, 1))
18 |
19 |
20 | def histogram_transform(img, means, stds, target_means, target_stds):
21 | return ((img - means) * target_stds / stds + target_means).astype(cp.float32)
22 |
23 |
24 | def hist_blend_cupy(
25 | a: np.ndarray, b: np.ndarray, error_mask: np.ndarray, weight1=0.5, weight2=0.5
26 | ):
27 | a = cp.asarray(a)
28 | b = cp.asarray(b)
29 | error_mask = cp.asarray(error_mask)
30 |
31 | # Ensure error_mask has 3 channels
32 | if len(error_mask.shape) == 2:
33 | error_mask = cp.repeat(error_mask[:, :, cp.newaxis], 3, axis=2)
34 |
35 | # Convert to Lab color space
36 | a_lab = cv2.cvtColor(cp.asnumpy(a), cv2.COLOR_BGR2Lab)
37 | b_lab = cv2.cvtColor(cp.asnumpy(b), cv2.COLOR_BGR2Lab)
38 | a_lab = cp.asarray(a_lab)
39 | b_lab = cp.asarray(b_lab)
40 |
41 | min_error_lab = assemble_min_error_img(a_lab, b_lab, error_mask)
42 |
43 | # Compute means and stds
44 | a_mean, a_std = mean_std(a_lab)
45 | b_mean, b_std = mean_std(b_lab)
46 | min_error_mean, min_error_std = mean_std(min_error_lab)
47 |
48 | # Histogram transformation constants
49 | t_mean = cp.full(3, 0.5 * 256, dtype=cp.float32)
50 | t_std = cp.full(3, (1 / 36) * 256, dtype=cp.float32)
51 |
52 | # Histogram transform
53 | a_lab = histogram_transform(a_lab, a_mean, a_std, t_mean, t_std)
54 | b_lab = histogram_transform(b_lab, b_mean, b_std, t_mean, t_std)
55 |
56 | # Blending
57 | ab_lab = (a_lab * weight1 + b_lab * weight2 - 128) / 0.5 + 128
58 | ab_mean, ab_std = mean_std(ab_lab)
59 |
60 | # Final histogram transform
61 | ab_lab = histogram_transform(ab_lab, ab_mean, ab_std, min_error_mean, min_error_std)
62 |
63 | ab_lab = cp.clip(cp.round(ab_lab), 0, 255).astype(cp.uint8)
64 | ab_lab = cp.asnumpy(ab_lab)
65 | ab = cv2.cvtColor(ab_lab, cv2.COLOR_Lab2BGR)
66 | return ab
67 |
68 |
69 | def construct_A_cupy(h: int, w: int, grad_weight: list[float], use_poisson_cupy=False):
70 | st = time.time()
71 | indgx_x = cp.zeros(2 * (h - 1) * w, dtype=int)
72 | indgx_y = cp.zeros(2 * (h - 1) * w, dtype=int)
73 | vdx = cp.ones(2 * (h - 1) * w)
74 |
75 | indgy_x = cp.zeros(2 * h * (w - 1), dtype=int)
76 | indgy_y = cp.zeros(2 * h * (w - 1), dtype=int)
77 | vdy = cp.ones(2 * h * (w - 1))
78 |
79 | indgx_x[::2] = cp.arange((h - 1) * w)
80 | indgx_y[::2] = indgx_x[::2]
81 | indgx_x[1::2] = indgx_x[::2]
82 | indgx_y[1::2] = indgx_x[::2] + w
83 |
84 | indgy_x[::2] = cp.arange(h * (w - 1))
85 | indgy_y[::2] = indgy_x[::2]
86 | indgy_x[1::2] = indgy_x[::2]
87 | indgy_y[1::2] = indgy_x[::2] + 1
88 |
89 | vdx[1::2] = -1
90 | vdy[1::2] = -1
91 |
92 | Ix = cupyx.scipy.sparse.eye(h * w, format="csc")
93 | Gx = cupyx.scipy.sparse.coo_matrix(
94 | (vdx, (indgx_x, indgx_y)), shape=(h * w, h * w)
95 | ).tocsc()
96 | Gy = cupyx.scipy.sparse.coo_matrix(
97 | (vdy, (indgy_x, indgy_y)), shape=(h * w, h * w)
98 | ).tocsc()
99 |
100 | As = [
101 | cupyx.scipy.sparse.vstack([Gx * weight, Gy * weight, Ix])
102 | for weight in grad_weight
103 | ]
104 | print(f"Constructing As took {time.time() - st:.4f} s")
105 | if not use_poisson_cupy:
106 | As_scipy = [
107 | scipy.sparse.vstack(
108 | [
109 | scipy.sparse.csr_matrix(Gx.get() * weight),
110 | scipy.sparse.csr_matrix(Gy.get() * weight),
111 | scipy.sparse.csr_matrix(Ix.get()),
112 | ]
113 | )
114 | for weight in grad_weight
115 | ]
116 | return As_scipy
117 | return As
118 |
119 |
120 | def poisson_fusion_cupy(
121 | blendI: np.ndarray,
122 | I1: np.ndarray,
123 | I2: np.ndarray,
124 | mask: np.ndarray,
125 | As: list[cupyx.scipy.sparse._csc.csc_matrix],
126 | poisson_maxiter=None,
127 | ):
128 | grad_weight = [2.5, 0.5, 0.5]
129 | Iab = cv2.cvtColor(blendI, cv2.COLOR_BGR2LAB).astype(float)
130 | Ia = cv2.cvtColor(I1, cv2.COLOR_BGR2LAB).astype(float)
131 | Ib = cv2.cvtColor(I2, cv2.COLOR_BGR2LAB).astype(float)
132 |
133 | Iab_cp = cp.asarray(Iab)
134 | Ia_cp = cp.asarray(Ia)
135 | Ib_cp = cp.asarray(Ib)
136 | mask_cp = cp.asarray(mask)
137 |
138 | m_cp = (mask_cp > 0).astype(float)[..., cp.newaxis]
139 | h, w, c = Iab.shape
140 |
141 | gx_cp = cp.zeros_like(Ia_cp)
142 | gy_cp = cp.zeros_like(Ia_cp)
143 |
144 | gx_cp[:-1] = (Ia_cp[:-1] - Ia_cp[1:]) * (1 - m_cp[:-1]) + (
145 | Ib_cp[:-1] - Ib_cp[1:]
146 | ) * m_cp[:-1]
147 | gy_cp[:, :-1] = (Ia_cp[:, :-1] - Ia_cp[:, 1:]) * (1 - m_cp[:, :-1]) + (
148 | Ib_cp[:, :-1] - Ib_cp[:, 1:]
149 | ) * m_cp[:, :-1]
150 |
151 | final_channels = [
152 | poisson_fusion_channel_cupy(
153 | Iab_cp, gx_cp, gy_cp, h, w, As, i, grad_weight, maxiter=poisson_maxiter
154 | )
155 | for i in range(3)
156 | ]
157 |
158 | final = np.clip(np.concatenate(final_channels, axis=2), 0, 255)
159 | return cv2.cvtColor(final.astype(np.uint8), cv2.COLOR_LAB2BGR)
160 |
161 |
162 | def poisson_fusion_channel_cupy(
163 | Iab: cp.ndarray,
164 | gx: cp.ndarray,
165 | gy: cp.ndarray,
166 | h: int,
167 | w: int,
168 | As: list[cupyx.scipy.sparse._csc.csc_matrix],
169 | channel: int,
170 | grad_weight: list[float],
171 | maxiter: int | None = None,
172 | ):
173 | cp.get_default_memory_pool().free_all_blocks()
174 | weight = grad_weight[channel]
175 | im_dx = cp.clip(gx[:, :, channel].reshape(h * w, 1), -100, 100)
176 | im_dy = cp.clip(gy[:, :, channel].reshape(h * w, 1), -100, 100)
177 | im = Iab[:, :, channel].reshape(h * w, 1)
178 | im_mean = im.mean()
179 | im = im - im_mean
180 | A = As[channel]
181 | b = cp.vstack([im_dx * weight, im_dy * weight, im])
182 | out = cupyx.scipy.sparse.linalg.lsmr(A, b, maxiter=maxiter)
183 | out_im = (out[0] + im_mean).reshape(h, w, 1)
184 |
185 | return cp.asnumpy(out_im)
186 |
--------------------------------------------------------------------------------
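The lsmr call in poisson_fusion_channel_cupy solves, per Lab channel, a least-squares problem built from the stacked operator in construct_A_cupy: the solution matches the target gradients while staying close to the mean-centered, histogram-blended channel. In LaTeX form (w is the channel's gradient weight, \bar{I}_{ab} the centered channel):

    \min_x \left\| \begin{bmatrix} w\,G_x \\ w\,G_y \\ I \end{bmatrix} x - \begin{bmatrix} w\,g_x \\ w\,g_y \\ \bar{I}_{ab} \end{bmatrix} \right\|_2^2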
/ezsynth/utils/blend/histogram_blend.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def hist_blender(
6 | a: np.ndarray,
7 | b: np.ndarray,
8 | error_mask: np.ndarray,
9 | weight1=0.5,
10 | weight2=0.5,
11 | ) -> np.ndarray:
12 | # Ensure error_mask has 3 channels
13 | if len(error_mask.shape) == 2:
14 | error_mask = np.repeat(error_mask[:, :, np.newaxis], 3, axis=2)
15 |
16 | # Convert to Lab color space
17 | a_lab = cv2.cvtColor(a, cv2.COLOR_BGR2Lab)
18 | b_lab = cv2.cvtColor(b, cv2.COLOR_BGR2Lab)
19 |
20 | # Generate min_error_img
21 | min_error_lab = assemble_min_error_img(a_lab, b_lab, error_mask)
22 |
23 | # Compute means and stds
24 | a_mean, a_std = mean_std(a_lab)
25 | b_mean, b_std = mean_std(b_lab)
26 | min_error_mean, min_error_std = mean_std(min_error_lab)
27 |
28 | # Histogram transformation constants
29 | t_mean = np.full(3, 0.5 * 256, dtype=np.float32)
30 | t_std = np.full(3, (1 / 36) * 256, dtype=np.float32)
31 |
32 | # Histogram transform
33 | a_lab = histogram_transform(a_lab, a_mean, a_std, t_mean, t_std)
34 | b_lab = histogram_transform(b_lab, b_mean, b_std, t_mean, t_std)
35 |
36 | # Blending
37 | ab_lab = (a_lab * weight1 + b_lab * weight2 - 128) / 0.5 + 128
38 | ab_mean, ab_std = mean_std(ab_lab)
39 |
40 | # Final histogram transform
41 | ab_lab = histogram_transform(ab_lab, ab_mean, ab_std, min_error_mean, min_error_std)
42 |
43 | ab_lab = np.clip(np.round(ab_lab), 0, 255).astype(np.uint8)
44 |
45 | # Convert back to BGR
46 | ab = cv2.cvtColor(ab_lab, cv2.COLOR_Lab2BGR)
47 |
48 | return ab
49 |
50 |
51 | def histogram_transform(img, means, stds, target_means, target_stds):
52 | return ((img - means) * target_stds / stds + target_means).astype(np.float32)
53 |
54 |
55 | def assemble_min_error_img(a, b, error_mask):
56 | return np.where(error_mask == 0, a, b)
57 |
58 |
59 | def mean_std(img):
60 | return np.mean(img, axis=(0, 1)), np.std(img, axis=(0, 1))
61 |
--------------------------------------------------------------------------------
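histogram_transform is a per-channel affine remap: subtract the source mean, rescale by the ratio of standard deviations, then add the target mean. A one-pixel sketch with made-up statistics:

    src_mean, src_std = 3.0, 1.0
    tgt_mean, tgt_std = 10.0, 2.0
    pixel = 4.0                                           # one src_std above src_mean
    out = (pixel - src_mean) * tgt_std / src_std + tgt_mean
    print(out)  # 12.0 -> one tgt_std above tgt_mean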
/ezsynth/utils/blend/reconstruction.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import cv2
4 | import numpy as np
5 | import scipy.sparse
6 | import scipy.sparse.linalg
7 | import tqdm
8 |
9 | try:
10 | from .cupy_accelerated import construct_A_cupy, poisson_fusion_cupy
11 |
12 | USE_GPU = True
13 | print("Cupy is installed. Can do Cupy GPU accelerations")
14 | except ImportError:
15 | USE_GPU = False
16 | print("Cupy is not installed. Revert to CPU")
17 |
18 |
19 | class reconstructor:
20 | def __init__(
21 | self,
22 | hist_blends: list[np.ndarray],
23 | style_fwd: list[np.ndarray],
24 | style_bwd: list[np.ndarray],
25 | err_masks: list[np.ndarray],
26 | use_gpu=False,
27 | use_lsqr=True,
28 | use_poisson_cupy=False,
29 | poisson_maxiter=None,
30 | ):
31 | self.hist_blends = hist_blends
32 | self.style_fwd = style_fwd
33 | self.style_bwd = style_bwd
34 | self.err_masks = err_masks
35 | self.blends = None
36 |
37 | self.use_gpu = use_gpu and USE_GPU
38 | self.use_lsqr = use_lsqr
39 | self.use_poisson_cupy = self.use_gpu and use_poisson_cupy
40 | self.poisson_maxiter = poisson_maxiter
41 |
42 | def _create(self):
43 | num_blends = len(self.hist_blends)
44 | h, w, c = self.hist_blends[0].shape
45 | self.blends = np.zeros((num_blends, h, w, c))
46 |
47 | a = construct_A(h, w, [2.5, 0.5, 0.5], self.use_gpu, self.use_poisson_cupy)
48 | for i in tqdm.tqdm(range(num_blends)):
49 | self.blends[i] = poisson_fusion(
50 | self.hist_blends[i],
51 | self.style_fwd[i],
52 | self.style_bwd[i],
53 | self.err_masks[i],
54 | a,
55 | self.use_gpu,
56 | self.use_lsqr,
57 | self.use_poisson_cupy,
58 | self.poisson_maxiter,
59 | )
60 |
61 | return self.blends
62 |
63 |
64 | def construct_A(
65 | h: int, w: int, grad_weight: list[float], use_gpu=False, use_poisson_cupy=False
66 | ):
67 | if use_gpu:
68 | return construct_A_cupy(h, w, grad_weight, use_poisson_cupy)
69 | return construct_A_cpu(h, w, grad_weight)
70 |
71 |
72 | def poisson_fusion(
73 | blendI: np.ndarray,
74 | I1: np.ndarray,
75 | I2: np.ndarray,
76 | mask: np.ndarray,
77 | As,
78 | use_gpu=False,
79 | use_lsqr=True,
80 | use_poisson_cupy=False,
81 | poisson_maxiter=None,
82 | ):
83 | if use_gpu and use_poisson_cupy:
84 | return poisson_fusion_cupy(blendI, I1, I2, mask, As, poisson_maxiter)
85 | return poisson_fusion_cpu_optimized(
86 | blendI, I1, I2, mask, As, use_lsqr, poisson_maxiter
87 | )
88 |
89 |
90 | def construct_A_cpu(h: int, w: int, grad_weight: list[float]):
91 | st = time.time()
92 | indgx_x = np.zeros(2 * (h - 1) * w, dtype=int)
93 | indgx_y = np.zeros(2 * (h - 1) * w, dtype=int)
94 | vdx = np.ones(2 * (h - 1) * w)
95 |
96 | indgy_x = np.zeros(2 * h * (w - 1), dtype=int)
97 | indgy_y = np.zeros(2 * h * (w - 1), dtype=int)
98 | vdy = np.ones(2 * h * (w - 1))
99 |
100 | indgx_x[::2] = np.arange((h - 1) * w)
101 | indgx_y[::2] = indgx_x[::2]
102 | indgx_x[1::2] = indgx_x[::2]
103 | indgx_y[1::2] = indgx_x[::2] + w
104 |
105 | indgy_x[::2] = np.arange(h * (w - 1))
106 | indgy_y[::2] = indgy_x[::2]
107 | indgy_x[1::2] = indgy_x[::2]
108 | indgy_y[1::2] = indgy_x[::2] + 1
109 |
110 | vdx[1::2] = -1
111 | vdy[1::2] = -1
112 |
113 | Ix = scipy.sparse.eye(h * w, format="csc")
114 | Gx = scipy.sparse.coo_matrix(
115 | (vdx, (indgx_x, indgx_y)), shape=(h * w, h * w)
116 | ).tocsc()
117 | Gy = scipy.sparse.coo_matrix(
118 | (vdy, (indgy_x, indgy_y)), shape=(h * w, h * w)
119 | ).tocsc()
120 |
121 | As = [scipy.sparse.vstack([Gx * weight, Gy * weight, Ix]) for weight in grad_weight]
122 | print(f"Constructing As took {time.time() - st:.4f} s")
123 | return As
124 |
125 |
126 | def poisson_fusion_cpu(
127 | blendI: np.ndarray,
128 | I1: np.ndarray,
129 | I2: np.ndarray,
130 | mask: np.ndarray,
131 | As: list[scipy.sparse._csc.csc_matrix],
132 | use_lsqr=True,
133 | ):
134 | grad_weight = [2.5, 0.5, 0.5]
135 | Iab = cv2.cvtColor(blendI, cv2.COLOR_BGR2LAB).astype(float)
136 | Ia = cv2.cvtColor(I1, cv2.COLOR_BGR2LAB).astype(float)
137 | Ib = cv2.cvtColor(I2, cv2.COLOR_BGR2LAB).astype(float)
138 |
139 | m = (mask > 0).astype(float)[..., np.newaxis]
140 | h, w, c = Iab.shape
141 |
142 | gx = np.zeros_like(Ia)
143 | gy = np.zeros_like(Ia)
144 |
145 | gx[:-1] = (Ia[:-1] - Ia[1:]) * (1 - m[:-1]) + (Ib[:-1] - Ib[1:]) * m[:-1]
146 | gy[:, :-1] = (Ia[:, :-1] - Ia[:, 1:]) * (1 - m[:, :-1]) + (
147 | Ib[:, :-1] - Ib[:, 1:]
148 | ) * m[:, :-1]
149 |
150 | final_channels = [
151 | poisson_fusion_channel_cpu(Iab, gx, gy, h, w, As, i, grad_weight, use_lsqr)
152 | for i in range(3)
153 | ]
154 |
155 | final = np.clip(np.concatenate(final_channels, axis=2), 0, 255)
156 | return cv2.cvtColor(final.astype(np.uint8), cv2.COLOR_LAB2BGR)
157 |
158 |
159 | def poisson_fusion_channel_cpu(
160 | Iab: np.ndarray,
161 | gx: np.ndarray,
162 | gy: np.ndarray,
163 | h: int,
164 | w: int,
165 | As: list[scipy.sparse._csc.csc_matrix],
166 | channel: int,
167 | grad_weight: list[float],
168 | use_lsqr=True,
169 | ):
170 | """Helper function to perform Poisson fusion on a single channel."""
171 | weight = grad_weight[channel]
172 | im_dx = np.clip(gx[:, :, channel].reshape(h * w, 1), -100, 100)
173 | im_dy = np.clip(gy[:, :, channel].reshape(h * w, 1), -100, 100)
174 | im = Iab[:, :, channel].reshape(h * w, 1)
175 | im_mean = im.mean()
176 | im = im - im_mean
177 |
178 | A = As[channel]
179 | b = np.vstack([im_dx * weight, im_dy * weight, im])
180 | if use_lsqr:
181 | out = scipy.sparse.linalg.lsqr(A, b)
182 | else:
183 | out = scipy.sparse.linalg.lsmr(A, b)
184 | out_im = (out[0] + im_mean).reshape(h, w, 1)
185 |
186 | return out_im
187 |
188 |
189 | def gradient_compute_python(Ia: np.ndarray, Ib: np.ndarray, m: np.ndarray):
190 | gx = np.zeros_like(Ia)
191 | gy = np.zeros_like(Ia)
192 |
193 | gx[:-1] = (Ia[:-1] - Ia[1:]) * (1 - m[:-1]) + (Ib[:-1] - Ib[1:]) * m[:-1]
194 | gy[:, :-1] = (Ia[:, :-1] - Ia[:, 1:]) * (1 - m[:, :-1]) + (
195 | Ib[:, :-1] - Ib[:, 1:]
196 | ) * m[:, :-1]
197 | return gx, gy
198 |
199 |
200 | def poisson_fusion_cpu_optimized(
201 | blendI: np.ndarray,
202 | I1: np.ndarray,
203 | I2: np.ndarray,
204 | mask: np.ndarray,
205 | As: list[scipy.sparse._csc.csc_matrix],
206 | use_lsqr=True,
207 | poisson_maxiter=None,
208 | ):
209 | # grad_weight = [2.5, 0.5, 0.5]
210 | grad_weight = np.array([2.5, 0.5, 0.5])
211 | Iab = cv2.cvtColor(blendI, cv2.COLOR_BGR2LAB).astype(float)
212 | Ia = cv2.cvtColor(I1, cv2.COLOR_BGR2LAB).astype(float)
213 | Ib = cv2.cvtColor(I2, cv2.COLOR_BGR2LAB).astype(float)
214 |
215 | m = (mask > 0).astype(float)[..., np.newaxis]
216 | h, w, c = Iab.shape
217 |
218 | gx, gy = gradient_compute_python(Ia, Ib, m)
219 |
220 | # Reshape and clip all channels at once
221 | gx_reshaped = np.clip(gx.reshape(h * w, c), -100, 100)
222 | gy_reshaped = np.clip(gy.reshape(h * w, c), -100, 100)
223 |
224 | Iab_reshaped = Iab.reshape(h * w, c)
225 | Iab_mean = np.mean(Iab_reshaped, axis=0)
226 | Iab_centered = Iab_reshaped - Iab_mean
227 |
228 | # Pre-allocate the output array
229 | out_all = np.zeros((h * w, c), dtype=np.float32)
230 |
231 | for channel in range(3):
232 | weight = grad_weight[channel]
233 | im_dx = gx_reshaped[:, channel : channel + 1]
234 | im_dy = gy_reshaped[:, channel : channel + 1]
235 | im = Iab_centered[:, channel : channel + 1]
236 |
237 | A = As[channel]
238 | b = np.vstack([im_dx * weight, im_dy * weight, im])
239 |
240 | if use_lsqr:
241 | out_all[:, channel] = scipy.sparse.linalg.lsqr(A, b)[0]
242 | else:
243 | out_all[:, channel] = scipy.sparse.linalg.lsmr(
244 | A, b, maxiter=poisson_maxiter
245 | )[0]
246 |
247 | # Add back the mean and reshape
248 | final = (out_all + Iab_mean).reshape(h, w, c)
249 | final = np.clip(final, 0, 255)
250 | return cv2.cvtColor(final.astype(np.uint8), cv2.COLOR_LAB2BGR)
251 |
--------------------------------------------------------------------------------
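The file above solves a screened Poisson problem per LAB channel: each sparse system in `As` stacks the weighted gradient operators `Gx` and `Gy` on top of an identity data term, and `lsqr`/`lsmr` solves for the least-squares image. A minimal end-to-end sketch of the CPU path, assuming four same-sized BGR images on disk (the file names here are purely illustrative):

```python
import cv2

blend = cv2.imread("hist_blend.png")      # histogram-blended initial guess
fwd = cv2.imread("style_fwd.png")         # forward-pass stylized frame
bwd = cv2.imread("style_bwd.png")         # backward-pass stylized frame
mask = cv2.imread("err_mask.png", cv2.IMREAD_GRAYSCALE)  # 0 -> fwd, >0 -> bwd

h, w = blend.shape[:2]
# One system per LAB channel, weighted as in the module defaults.
As = construct_A_cpu(h, w, [2.5, 0.5, 0.5])
fused = poisson_fusion_cpu_optimized(blend, fwd, bwd, mask, As)
cv2.imwrite("fused.png", fused)
```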
/ezsynth/utils/ebsynth.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/ebsynth.dll
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/OpticalFlow.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import tqdm
7 |
8 | from .core.utils.utils import InputPadder
9 |
10 | class RAFT_flow:
11 | def __init__(self, model_name="sintel", arch="RAFT"):
12 | self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 |
14 | self.arch = arch
15 |
16 | if self.arch == "RAFT":
17 | from .core.raft import RAFT
18 | model_name = f"raft-{model_name}.pth"
19 | model_path = os.path.join(os.path.dirname(__file__), "models", model_name)
20 |
21 | if not os.path.exists(model_path):
22 | raise ValueError(f"[ERROR] Model file '{model_path}' not found.")
23 |
24 | self.model = torch.nn.DataParallel(
25 | RAFT(args=self._instantiate_raft_model(model_name))
26 | )
27 |
28 | elif self.arch == "EF_RAFT":
29 | from .core.ef_raft import EF_RAFT
30 | model_name = f"{model_name}.pth"
31 | model_path = os.path.join(
32 | os.path.dirname(__file__), "ef_raft_models", model_name
33 | )
34 | if not os.path.exists(model_path):
35 | raise ValueError(f"[ERROR] Model file '{model_path}' not found.")
36 | self.model = torch.nn.DataParallel(
37 | EF_RAFT(args=self._instantiate_raft_model(model_name))
38 | )
39 |
40 | elif self.arch == "FLOW_DIFF":
41 | try:
42 | from .core.flow_diffusion import FlowDiffuser
43 | except ImportError as e:
44 | raise ImportError(f"Could not import FlowDiffuser. {e}")
45 | model_name = "FlowDiffuser-things.pth"
46 | model_path = os.path.join(os.path.dirname(__file__), "flow_diffusion_models", model_name)
47 | if not os.path.exists(model_path):
48 | raise ValueError(f"[ERROR] Model file '{model_path}' not found.")
49 | self.model = torch.nn.DataParallel(
50 | FlowDiffuser(args=self._instantiate_raft_model(model_name))
51 | )
52 |         else:
53 |             raise ValueError(f"[ERROR] Unknown arch '{self.arch}'. Use 'RAFT', 'EF_RAFT', or 'FLOW_DIFF'.")
54 | state_dict = torch.load(model_path, map_location=self.DEVICE)
55 | self.model.load_state_dict(state_dict)
56 |
57 | self.model.to(self.DEVICE)
58 | self.model.eval()
59 |
60 | def _instantiate_raft_model(self, model_name):
61 | from argparse import Namespace
62 |
63 | args = Namespace()
64 | args.model = model_name
65 | args.small = False
66 | args.mixed_precision = False
67 | return args
68 |
69 | def _load_tensor_from_numpy(self, np_array: np.ndarray):
70 | try:
71 | tensor = (
72 | torch.tensor(np_array, dtype=torch.float32)
73 | .permute(2, 0, 1)
74 | .unsqueeze(0)
75 | .to(self.DEVICE)
76 | )
77 | return tensor
78 | except Exception as e:
79 | print(f"[ERROR] Exception in load_tensor_from_numpy: {e}")
80 | raise e
81 |
82 | def _compute_flow(self, img1: np.ndarray, img2: np.ndarray):
83 | original_size = img1.shape[1::-1]
84 | with torch.no_grad():
85 | img1_tensor = self._load_tensor_from_numpy(img1)
86 | img2_tensor = self._load_tensor_from_numpy(img2)
87 | padder = InputPadder(img1_tensor.shape)
88 | images = padder.pad(img1_tensor, img2_tensor)
89 | _, flow_up = self.model(images[0], images[1], iters=20, test_mode=True)
90 | # flow_np = flow_up[0].permute(1, 2, 0).cpu().numpy()
91 | flow_np = padder.unpad(flow_up[0]).permute(1, 2, 0).cpu().numpy()
92 |             flow_np = cv2.resize(flow_np, original_size)
93 | return flow_np
94 |
95 | def compute_flow(self, img_frs_seq: list[np.ndarray]):
96 | optical_flow = []
97 | total_flows = len(img_frs_seq) - 1
98 | for i in tqdm.tqdm(range(total_flows), desc="Calculating Flow: "):
99 | optical_flow.append(self._compute_flow(img_frs_seq[i], img_frs_seq[i + 1]))
100 | self.optical_flow = optical_flow
101 | return self.optical_flow
102 |
--------------------------------------------------------------------------------
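`RAFT_flow` above hides three flow backbones behind one interface; `compute_flow` returns one H×W×2 field per consecutive frame pair. A minimal sketch, assuming the bundled `raft-sintel.pth` weights are present and an on-disk frame sequence exists (paths illustrative):

```python
import cv2

frames = [cv2.imread(f"input/{i:03d}.jpg") for i in range(3)]
flow_model = RAFT_flow(model_name="sintel", arch="RAFT")
flows = flow_model.compute_flow(frames)  # list of 2 arrays, each H x W x 2
```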
/ezsynth/utils/flow_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/__init__.py
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/alt_cuda_corr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/alt_cuda_corr/__init__.py
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 | 
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 |     torch::Tensor fmap1,
7 |     torch::Tensor fmap2,
8 |     torch::Tensor coords,
9 |     int radius);
10 | 
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 |     torch::Tensor fmap1,
13 |     torch::Tensor fmap2,
14 |     torch::Tensor coords,
15 |     torch::Tensor corr_grad,
16 |     int radius);
17 | 
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 | 
23 | std::vector<torch::Tensor> corr_forward(
24 |     torch::Tensor fmap1,
25 |     torch::Tensor fmap2,
26 |     torch::Tensor coords,
27 |     int radius) {
28 |   CHECK_INPUT(fmap1);
29 |   CHECK_INPUT(fmap2);
30 |   CHECK_INPUT(coords);
31 | 
32 |   return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 | 
35 | 
36 | std::vector<torch::Tensor> corr_backward(
37 |     torch::Tensor fmap1,
38 |     torch::Tensor fmap2,
39 |     torch::Tensor coords,
40 |     torch::Tensor corr_grad,
41 |     int radius) {
42 |   CHECK_INPUT(fmap1);
43 |   CHECK_INPUT(fmap2);
44 |   CHECK_INPUT(coords);
45 |   CHECK_INPUT(corr_grad);
46 | 
47 |   return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 | 
50 | 
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 |   m.def("forward", &corr_forward, "CORR forward");
53 |   m.def("backward", &corr_backward, "CORR backward");
54 | }
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name='correlation',
7 | ext_modules=[
8 | CUDAExtension('alt_cuda_corr',
9 | sources=['correlation.cpp', 'correlation_kernel.cu'],
10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
15 |
16 |
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/core/__init__.py
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from .utils.utils import bilinear_sampler
5 |
6 | try:
7 | import alt_cuda_corr
8 | except ImportError as e:
9 | print(f"alt_cuda_corr is not compiled. Not fatal. {e}")
10 | # alt_cuda_corr is not compiled
11 | pass
12 |
13 |
14 | class CorrBlock:
15 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
16 | self.num_levels = num_levels
17 | self.radius = radius
18 | self.corr_pyramid = []
19 |
20 | # all pairs correlation
21 | corr = CorrBlock.corr(fmap1, fmap2)
22 |
23 | batch, h1, w1, dim, h2, w2 = corr.shape
24 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
25 |
26 | self.corr_pyramid.append(corr)
27 | for i in range(self.num_levels - 1):
28 | corr = F.avg_pool2d(corr, 2, stride=2)
29 | self.corr_pyramid.append(corr)
30 |
31 | def __call__(self, coords):
32 | r = self.radius
33 | coords = coords.permute(0, 2, 3, 1)
34 | batch, h1, w1, _ = coords.shape
35 |
36 | out_pyramid = []
37 | for i in range(self.num_levels):
38 | corr = self.corr_pyramid[i]
39 | dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
40 | dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
41 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
42 |
43 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
44 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
45 | coords_lvl = centroid_lvl + delta_lvl
46 |
47 | corr = bilinear_sampler(corr, coords_lvl)
48 | corr = corr.view(batch, h1, w1, -1)
49 | out_pyramid.append(corr)
50 |
51 | out = torch.cat(out_pyramid, dim=-1)
52 | return out.permute(0, 3, 1, 2).contiguous().float()
53 |
54 | @staticmethod
55 | def corr(fmap1, fmap2):
56 | batch, dim, ht, wd = fmap1.shape
57 | fmap1 = fmap1.view(batch, dim, ht * wd)
58 | fmap2 = fmap2.view(batch, dim, ht * wd)
59 |
60 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
61 | corr = corr.view(batch, ht, wd, 1, ht, wd)
62 | return corr / torch.sqrt(torch.tensor(dim).float())
63 |
64 |
65 | class AlternateCorrBlock:
66 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
67 | self.num_levels = num_levels
68 | self.radius = radius
69 |
70 | self.pyramid = [(fmap1, fmap2)]
71 | for i in range(self.num_levels):
72 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
73 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
74 | self.pyramid.append((fmap1, fmap2))
75 |
76 | def __call__(self, coords):
77 | coords = coords.permute(0, 2, 3, 1)
78 | B, H, W, _ = coords.shape
79 | dim = self.pyramid[0][0].shape[1]
80 |
81 | corr_list = []
82 | for i in range(self.num_levels):
83 | r = self.radius
84 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
85 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
86 |
87 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
88 | (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
89 | corr_list.append(corr.squeeze(1))
90 |
91 | corr = torch.stack(corr_list, dim=1)
92 | corr = corr.reshape(B, -1, H, W)
93 | return corr / torch.sqrt(torch.tensor(dim).float())
94 |
95 | class EF_CorrBlock:
96 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
97 | self.num_levels = num_levels
98 | self.radius = radius
99 | self.corr_pyramid = []
100 |
101 | # all pairs correlation
102 | corr = EF_CorrBlock.corr(fmap1, fmap2)
103 |
104 | batch, h1, w1, dim, h2, w2 = corr.shape
105 | self.corr_map = corr.view(batch, h1 * w1, h2 * w2)
106 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
107 |
108 | self.corr_pyramid.append(corr)
109 | for i in range(self.num_levels-1):
110 | corr = F.avg_pool2d(corr, 2, stride=2)
111 | self.corr_pyramid.append(corr)
112 |
113 | def __call__(self, coords, scalers=None):
114 | r = self.radius
115 |
116 | if scalers is not None:
117 | assert(scalers.shape[-1] == 4 and scalers.shape[-2] == self.num_levels)
118 | scalers = scalers.view(-1, 1, scalers.shape[-2], 2, scalers.shape[-1] // 2)
119 |
120 | coords = coords.permute(0, 2, 3, 1)
121 | batch, h1, w1, _ = coords.shape
122 |
123 | out_pyramid = []
124 | for i in range(self.num_levels):
125 | corr = self.corr_pyramid[i]
126 | centroid_lvl = coords.reshape(batch, h1*w1, 1, 1, 2) / 2**i
127 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
128 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
129 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
130 | delta = delta.view(-1, 2)
131 | delta = delta.repeat((batch, 1, 1))
132 | if scalers is not None:
133 | delta[..., 0] *= scalers[..., i, 0, 0]
134 | delta[..., 1] *= scalers[..., i, 0, 1]
135 | delta[..., 0] += torch.sign(delta[..., 0]) * scalers[..., i, 1, 0] * r
136 | delta[..., 1] += torch.sign(delta[..., 1]) * scalers[..., i, 1, 1] * r
137 | delta_lvl = delta.view(batch, 1, 2*r+1, 2*r+1, 2)
138 | coords_lvl = centroid_lvl + delta_lvl
139 |
140 | coords_lvl = coords_lvl.reshape(-1, coords_lvl.shape[-3], coords_lvl.shape[-2], coords_lvl.shape[-1])
141 |
142 | corr = bilinear_sampler(corr, coords_lvl)
143 | corr = corr.view(batch, h1, w1, -1)
144 | out_pyramid.append(corr)
145 |
146 | out = torch.cat(out_pyramid, dim=-1)
147 | return out.permute(0, 3, 1, 2).contiguous().float()
148 |
149 | @staticmethod
150 | def corr(fmap1, fmap2):
151 | batch, dim, ht, wd = fmap1.shape
152 | fmap1 = fmap1.view(batch, dim, ht*wd)
153 | fmap2 = fmap2.view(batch, dim, ht*wd)
154 |
155 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
156 | corr = corr.view(batch, ht, wd, 1, ht, wd)
157 | return corr / torch.sqrt(torch.tensor(dim).float())
158 |
159 | class EF_AlternateCorrBlock:
160 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
161 | self.num_levels = num_levels
162 | self.radius = radius
163 |
164 | self.pyramid = [(fmap1, fmap2)]
165 | for i in range(self.num_levels):
166 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
167 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
168 | self.pyramid.append((fmap1, fmap2))
169 |
170 | def __call__(self, coords):
171 | coords = coords.permute(0, 2, 3, 1)
172 | B, H, W, _ = coords.shape
173 | dim = self.pyramid[0][0].shape[1]
174 |
175 | corr_list = []
176 | for i in range(self.num_levels):
177 | r = self.radius
178 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
179 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
180 |
181 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
182 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
183 | corr_list.append(corr.squeeze(1))
184 |
185 | corr = torch.stack(corr_list, dim=1)
186 | corr = corr.reshape(B, -1, H, W)
187 | return corr / torch.sqrt(torch.tensor(dim).float())
188 |
--------------------------------------------------------------------------------
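`CorrBlock` above pays the all-pairs correlation cost once at construction, then each call samples a (2r+1)² window per pyramid level around the queried coordinates. A small sketch with random 1/8-resolution feature maps (shapes illustrative; querying the identity grid corresponds to zero flow):

```python
import torch

B, D, H, W = 1, 256, 46, 62
block = CorrBlock(torch.randn(B, D, H, W), torch.randn(B, D, H, W),
                  num_levels=4, radius=4)

# Identity coordinate grid, x before y, i.e. zero flow.
ys, xs = torch.meshgrid(
    torch.arange(H).float(), torch.arange(W).float(), indexing="ij"
)
coords = torch.stack([xs, ys], dim=0)[None]  # (1, 2, H, W)

corr = block(coords)  # (1, 4 * 9 * 9, H, W) = (1, 324, H, W)
```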
/ezsynth/utils/flow_utils/core/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | # import math
4 | import os
5 | import os.path as osp
6 | import random
7 | from glob import glob
8 |
9 | import numpy as np
10 | import torch
11 | # import torch.nn.functional as F
12 | import torch.utils.data as data
13 |
14 | from .utils import frame_utils
15 | from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 | if self.is_test:
36 | img1 = frame_utils.read_gen(self.image_list[index][0])
37 | img2 = frame_utils.read_gen(self.image_list[index][1])
38 | img1 = np.array(img1).astype(np.uint8)[..., :3]
39 | img2 = np.array(img2).astype(np.uint8)[..., :3]
40 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
41 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
42 | return img1, img2, self.extra_info[index]
43 |
44 | if not self.init_seed:
45 | worker_info = torch.utils.data.get_worker_info()
46 | if worker_info is not None:
47 | torch.manual_seed(worker_info.id)
48 | np.random.seed(worker_info.id)
49 | random.seed(worker_info.id)
50 | self.init_seed = True
51 |
52 | index = index % len(self.image_list)
53 | valid = None
54 | if self.sparse:
55 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
56 | else:
57 | flow = frame_utils.read_gen(self.flow_list[index])
58 |
59 | img1 = frame_utils.read_gen(self.image_list[index][0])
60 | img2 = frame_utils.read_gen(self.image_list[index][1])
61 |
62 | flow = np.array(flow).astype(np.float32)
63 | img1 = np.array(img1).astype(np.uint8)
64 | img2 = np.array(img2).astype(np.uint8)
65 |
66 | # grayscale images
67 | if len(img1.shape) == 2:
68 | img1 = np.tile(img1[..., None], (1, 1, 3))
69 | img2 = np.tile(img2[..., None], (1, 1, 3))
70 | else:
71 | img1 = img1[..., :3]
72 | img2 = img2[..., :3]
73 |
74 | if self.augmentor is not None:
75 | if self.sparse:
76 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
77 | else:
78 | img1, img2, flow = self.augmentor(img1, img2, flow)
79 |
80 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
81 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
82 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
83 |
84 | if valid is not None:
85 | valid = torch.from_numpy(valid)
86 | else:
87 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
88 |
89 | return img1, img2, flow, valid.float()
90 |
91 | def __rmul__(self, v):
92 | self.flow_list = v * self.flow_list
93 | self.image_list = v * self.image_list
94 | return self
95 |
96 | def __len__(self):
97 | return len(self.image_list)
98 |
99 |
100 | class MpiSintel(FlowDataset):
101 | def __init__(
102 | self, aug_params=None, split="training", root="datasets/Sintel", dstype="clean"
103 | ):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, "flow")
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == "test":
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, "*.png")))
113 | for i in range(len(image_list) - 1):
114 | self.image_list += [[image_list[i], image_list[i + 1]]]
115 | self.extra_info += [(scene, i)] # scene and frame_id
116 |
117 | if split != "test":
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, "*.flo")))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(
123 | self, aug_params=None, split="train", root="datasets/FlyingChairs_release/data"
124 | ):
125 | super(FlyingChairs, self).__init__(aug_params)
126 |
127 | images = sorted(glob(osp.join(root, "*.ppm")))
128 | flows = sorted(glob(osp.join(root, "*.flo")))
129 | assert len(images) // 2 == len(flows)
130 |
131 | split_list = np.loadtxt("chairs_split.txt", dtype=np.int32)
132 | for i in range(len(flows)):
133 | xid = split_list[i]
134 | if (split == "training" and xid == 1) or (
135 | split == "validation" and xid == 2
136 | ):
137 | self.flow_list += [flows[i]]
138 | self.image_list += [[images[2 * i], images[2 * i + 1]]]
139 |
140 |
141 | class FlyingThings3D(FlowDataset):
142 | def __init__(
143 | self, aug_params=None, root="datasets/FlyingThings3D", dstype="frames_cleanpass"
144 | ):
145 | super(FlyingThings3D, self).__init__(aug_params)
146 |
147 | for cam in ["left"]:
148 | for direction in ["into_future", "into_past"]:
149 | image_dirs = sorted(glob(osp.join(root, dstype, "TRAIN/*/*")))
150 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
151 |
152 | flow_dirs = sorted(glob(osp.join(root, "optical_flow/TRAIN/*/*")))
153 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
154 |
155 | for idir, fdir in zip(image_dirs, flow_dirs):
156 | images = sorted(glob(osp.join(idir, "*.png")))
157 | flows = sorted(glob(osp.join(fdir, "*.pfm")))
158 | for i in range(len(flows) - 1):
159 | if direction == "into_future":
160 | self.image_list += [[images[i], images[i + 1]]]
161 | self.flow_list += [flows[i]]
162 | elif direction == "into_past":
163 | self.image_list += [[images[i + 1], images[i]]]
164 | self.flow_list += [flows[i + 1]]
165 |
166 |
167 | class KITTI(FlowDataset):
168 | def __init__(self, aug_params=None, split="training", root="datasets/KITTI"):
169 | super(KITTI, self).__init__(aug_params, sparse=True)
170 | if split == "testing":
171 | self.is_test = True
172 |
173 | root = osp.join(root, split)
174 | images1 = sorted(glob(osp.join(root, "image_2/*_10.png")))
175 | images2 = sorted(glob(osp.join(root, "image_2/*_11.png")))
176 |
177 | for img1, img2 in zip(images1, images2):
178 | frame_id = img1.split("/")[-1]
179 | self.extra_info += [[frame_id]]
180 | self.image_list += [[img1, img2]]
181 |
182 | if split == "training":
183 | self.flow_list = sorted(glob(osp.join(root, "flow_occ/*_10.png")))
184 |
185 |
186 | class HD1K(FlowDataset):
187 | def __init__(self, aug_params=None, root="datasets/HD1k"):
188 | super(HD1K, self).__init__(aug_params, sparse=True)
189 |
190 | seq_ix = 0
191 |         while True:
192 | flows = sorted(
193 | glob(os.path.join(root, "hd1k_flow_gt", "flow_occ/%06d_*.png" % seq_ix))
194 | )
195 | images = sorted(
196 | glob(os.path.join(root, "hd1k_input", "image_2/%06d_*.png" % seq_ix))
197 | )
198 |
199 | if len(flows) == 0:
200 | break
201 |
202 | for i in range(len(flows) - 1):
203 | self.flow_list += [flows[i]]
204 | self.image_list += [[images[i], images[i + 1]]]
205 |
206 | seq_ix += 1
207 |
208 |
209 | def fetch_dataloader(args, TRAIN_DS="C+T+K+S+H"):
210 |     """Create the data loader for the corresponding training set"""
211 |
212 | if args.stage == "chairs":
213 | aug_params = {
214 | "crop_size": args.image_size,
215 | "min_scale": -0.1,
216 | "max_scale": 1.0,
217 | "do_flip": True,
218 | }
219 | train_dataset = FlyingChairs(aug_params, split="training")
220 |
221 | elif args.stage == "things":
222 | aug_params = {
223 | "crop_size": args.image_size,
224 | "min_scale": -0.4,
225 | "max_scale": 0.8,
226 | "do_flip": True,
227 | }
228 | clean_dataset = FlyingThings3D(aug_params, dstype="frames_cleanpass")
229 | final_dataset = FlyingThings3D(aug_params, dstype="frames_finalpass")
230 | train_dataset = clean_dataset + final_dataset
231 |
232 | elif args.stage == "sintel":
233 | aug_params = {
234 | "crop_size": args.image_size,
235 | "min_scale": -0.2,
236 | "max_scale": 0.6,
237 | "do_flip": True,
238 | }
239 | things = FlyingThings3D(aug_params, dstype="frames_cleanpass")
240 | sintel_clean = MpiSintel(aug_params, split="training", dstype="clean")
241 | sintel_final = MpiSintel(aug_params, split="training", dstype="final")
242 |
243 | if TRAIN_DS == "C+T+K+S+H":
244 | kitti = KITTI(
245 | {
246 | "crop_size": args.image_size,
247 | "min_scale": -0.3,
248 | "max_scale": 0.5,
249 | "do_flip": True,
250 | }
251 | )
252 | hd1k = HD1K(
253 | {
254 | "crop_size": args.image_size,
255 | "min_scale": -0.5,
256 | "max_scale": 0.2,
257 | "do_flip": True,
258 | }
259 | )
260 | train_dataset = (
261 | 100 * sintel_clean
262 | + 100 * sintel_final
263 | + 200 * kitti
264 | + 5 * hd1k
265 | + things
266 | )
267 |
268 | elif TRAIN_DS == "C+T+K/S":
269 | train_dataset = 100 * sintel_clean + 100 * sintel_final + things
270 |
271 | elif args.stage == "kitti":
272 | aug_params = {
273 | "crop_size": args.image_size,
274 | "min_scale": -0.2,
275 | "max_scale": 0.4,
276 | "do_flip": False,
277 | }
278 | train_dataset = KITTI(aug_params, split="training")
279 |
280 | train_loader = data.DataLoader(
281 | train_dataset,
282 | batch_size=args.batch_size,
283 | pin_memory=False,
284 | shuffle=True,
285 | num_workers=4,
286 | drop_last=True,
287 | )
288 |
289 | print("Training with %d image pairs" % len(train_dataset))
290 | return train_loader
291 |
--------------------------------------------------------------------------------
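`fetch_dataloader` above needs only an args object carrying `stage`, `image_size`, and `batch_size`; the per-stage augmentation presets supply the rest. A sketch, assuming the FlyingChairs archive is unpacked under `datasets/FlyingChairs_release/data` and `chairs_split.txt` sits in the working directory:

```python
from argparse import Namespace

args = Namespace(stage="chairs", image_size=[368, 496], batch_size=4)
loader = fetch_dataloader(args)
for img1, img2, flow, valid in loader:
    print(img1.shape, flow.shape)  # (4, 3, 368, 496), (4, 2, 368, 496)
    break
```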
/ezsynth/utils/flow_utils/core/ef_raft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .corr import EF_AlternateCorrBlock, EF_CorrBlock
6 | from .extractor import BasicEncoder, SmallEncoder, CoordinateAttention
7 | from .update import BasicUpdateBlock, SmallUpdateBlock, LookupScaler
8 | from .utils.utils import coords_grid, upflow8
9 |
10 | try:
11 | autocast = torch.cuda.amp.autocast
12 | except Exception as e:
13 | # dummy autocast for PyTorch < 1.6
14 | print(e)
15 |
16 | class autocast:
17 | def __init__(self, enabled):
18 | pass
19 |
20 | def __enter__(self):
21 | pass
22 |
23 | def __exit__(self, *args):
24 | pass
25 |
26 |
27 | class EF_RAFT(nn.Module):
28 | def __init__(self, args):
29 | super(EF_RAFT, self).__init__()
30 | self.args = args
31 |
32 | if args.small:
33 | self.hidden_dim = hdim = 96
34 | self.context_dim = cdim = 64
35 | args.corr_levels = 4
36 | args.corr_radius = 3
37 |
38 | else:
39 | self.hidden_dim = hdim = 128
40 | self.context_dim = cdim = 128
41 | args.corr_levels = 4
42 | args.corr_radius = 4
43 |
44 | if "dropout" not in self.args:
45 | self.args.dropout = 0
46 |
47 | if "alternate_corr" not in self.args:
48 | self.args.alternate_corr = False
49 |
50 | # feature network, context network, and update block
51 | if args.small:
52 | self.fnet = SmallEncoder(
53 | output_dim=128, norm_fn="instance", dropout=args.dropout
54 | )
55 | self.coor_att = None
56 | self.cnet = SmallEncoder(
57 | output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout
58 | )
59 | self.lookup_scaler = None
60 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
61 |
62 | else:
63 | self.fnet = BasicEncoder(
64 | output_dim=256, norm_fn="instance", dropout=args.dropout
65 | )
66 | self.coor_att = CoordinateAttention(feature_size=256, enc_size=128) # New
67 | self.cnet = BasicEncoder(
68 | output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout
69 | )
70 | self.lookup_scaler = LookupScaler(
71 | input_dim=hdim, output_size=args.corr_levels # New
72 | )
73 | self.update_block = BasicUpdateBlock(
74 | self.args, hidden_dim=hdim, input_dim=cdim + args.corr_levels * 4
75 | ) # Updated
76 |
77 | def freeze_bn(self):
78 | for m in self.modules():
79 | if isinstance(m, nn.BatchNorm2d):
80 | m.eval()
81 |
82 | def initialize_flow(self, img):
83 | """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
84 | N, C, H, W = img.shape
85 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device)
86 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device)
87 |
88 | # optical flow computed as difference: flow = coords1 - coords0
89 | return coords0, coords1
90 |
91 | def upsample_flow(self, flow, mask):
92 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
93 | N, _, H, W = flow.shape
94 | mask = mask.view(N, 1, 9, 8, 8, H, W)
95 | mask = torch.softmax(mask, dim=2)
96 |
97 | up_flow = F.unfold(8 * flow, [3, 3], padding=1)
98 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
99 |
100 | up_flow = torch.sum(mask * up_flow, dim=2)
101 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
102 | return up_flow.reshape(N, 2, 8 * H, 8 * W)
103 |
104 | def forward(
105 | self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
106 | ):
107 | """Estimate optical flow between pair of frames"""
108 |
109 | image1 = 2 * (image1 / 255.0) - 1.0
110 | image2 = 2 * (image2 / 255.0) - 1.0
111 |
112 | image1 = image1.contiguous()
113 | image2 = image2.contiguous()
114 |
115 | hdim = self.hidden_dim
116 | cdim = self.context_dim
117 |
118 | # run the feature network
119 | with autocast(enabled=self.args.mixed_precision):
120 | fmap1, fmap2 = self.fnet([image1, image2])
121 | fmap1 = self.coor_att(fmap1)
122 | fmap2 = self.coor_att(fmap2)
123 |
124 | fmap1 = fmap1.float()
125 | fmap2 = fmap2.float()
126 | if self.args.alternate_corr:
127 | corr_fn = EF_AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
128 | else:
129 | corr_fn = EF_CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
130 |
131 | # run the context network
132 | with autocast(enabled=self.args.mixed_precision):
133 | cnet = self.cnet(image1)
134 | net, base_inp = torch.split(cnet, [hdim, cdim], dim=1)
135 | net = torch.tanh(net)
136 | base_inp = torch.relu(base_inp)
137 |
138 | coords0, coords1 = self.initialize_flow(image1)
139 |
140 | BATCH_N, fC, fH, fW = fmap1.shape
141 | corr_map = corr_fn.corr_map
142 | soft_corr_map = F.softmax(corr_map, dim=2) * F.softmax(corr_map, dim=1)
143 |
144 | if flow_init is not None:
145 | coords1 = coords1 + flow_init
146 | else: # Use the global matching idea as an initialization.
147 | match_f, match_f_ind = soft_corr_map.max(dim=2) # Forward matching.
148 | match_b, match_b_ind = soft_corr_map.max(dim=1) # Backward matching.
149 |
150 |             # Permute the backward softmax to match the forward one.
151 | for i in range(BATCH_N):
152 | match_b_tmp = match_b[i, ...]
153 | match_b[i, ...] = match_b_tmp[match_f_ind[i, ...]]
154 |
155 | # Replace the identity mapping with the found matches.
156 | matched = (match_f - match_b) == 0
157 | coords_index = (
158 | torch.arange(fH * fW)
159 | .unsqueeze(0)
160 | .repeat(BATCH_N, 1)
161 | .to(soft_corr_map.device)
162 | )
163 | coords_index[matched] = match_f_ind[matched]
164 |
165 | # Convert the 1D mapping to a 2D one.
166 | coords_index = coords_index.reshape(BATCH_N, fH, fW)
167 | coords_x = coords_index % fW
168 | coords_y = coords_index // fW
169 | coords1 = torch.stack([coords_x, coords_y], dim=1).float()
170 |
171 | # Iterative update
172 | flow_predictions = []
173 | for itr in range(iters):
174 | coords1 = coords1.detach()
175 |
176 | lookup_scalers = None
177 | if self.lookup_scaler is not None:
178 | lookup_scalers = self.lookup_scaler(base_inp, net)
179 | cat_lookup_scalers = lookup_scalers.view(
180 | -1, lookup_scalers.shape[-1] * lookup_scalers.shape[-2], 1, 1
181 | )
182 | cat_lookup_scalers = cat_lookup_scalers.expand(
183 | -1, -1, base_inp.shape[2], base_inp.shape[3]
184 | )
185 | inp = torch.cat([base_inp, cat_lookup_scalers], dim=1)
186 |
187 | corr = corr_fn(coords1, scalers=lookup_scalers) # index correlation volume
188 |
189 | flow = coords1 - coords0
190 | with autocast(enabled=self.args.mixed_precision):
191 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
192 |
193 | # F(t+1) = F(t) + \Delta(t)
194 | coords1 = coords1 + delta_flow
195 |
196 | # upsample predictions
197 | if up_mask is None:
198 | flow_up = upflow8(coords1 - coords0)
199 | else:
200 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
201 |
202 | flow_predictions.append(flow_up)
203 |
204 | if test_mode:
205 | return coords1 - coords0, flow_up
206 |
207 | return flow_predictions
208 |
--------------------------------------------------------------------------------
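Relative to the vanilla RAFT in core/raft.py, `EF_RAFT` above adds coordinate attention on the feature maps, a global-matching flow initialization, and learned per-level lookup scalers that stretch and shift the correlation sampling window. A sketch of the scaled lookup in isolation, with random tensors (shapes illustrative):

```python
import torch

B, D, H, W = 1, 256, 46, 62
block = EF_CorrBlock(torch.randn(B, D, H, W), torch.randn(B, D, H, W))

ys, xs = torch.meshgrid(
    torch.arange(H).float(), torch.arange(W).float(), indexing="ij"
)
coords = torch.stack([xs, ys], dim=0)[None]  # identity grid, zero flow

# One (scale_x, scale_y, shift_x, shift_y) row per pyramid level, as
# LookupScaler in core/update.py produces at every iteration.
scalers = torch.rand(B, 4, 4)
corr = block(coords, scalers=scalers)  # (1, 4 * 9 * 9, H, W)
```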
/ezsynth/utils/flow_utils/core/fd_corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from .utils.utils import bilinear_sampler
5 |
6 |
7 | class CorrBlock_FD_Sp4:
8 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4, coords_init=None, rad=1):
9 | self.num_levels = num_levels
10 | self.radius = radius
11 | self.corr_pyramid = []
12 |
13 | corr = CorrBlock_FD_Sp4.corr(fmap1, fmap2, coords_init, r=rad)
14 |
15 | batch, h1, w1, dim, h2, w2 = corr.shape
16 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
17 |
18 | self.corr_pyramid.append(corr)
19 | for i in range(self.num_levels-1):
20 | corr = F.avg_pool2d(corr, 2, stride=2)
21 | self.corr_pyramid.append(corr)
22 |
23 | def __call__(self, coords):
24 | r = self.radius
25 | coords = coords.permute(0, 2, 3, 1)
26 | batch, h1, w1, _ = coords.shape
27 |
28 | out_pyramid = []
29 | for i in range(self.num_levels):
30 | corr = self.corr_pyramid[i]
31 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
32 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
33 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
34 |
35 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
36 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
37 | coords_lvl = centroid_lvl + delta_lvl
38 |
39 | corr = bilinear_sampler(corr, coords_lvl)
40 | corr = corr.view(batch, h1, w1, -1)
41 | out_pyramid.append(corr)
42 |
43 | out = torch.cat(out_pyramid, dim=-1)
44 | return out.permute(0, 3, 1, 2).contiguous().float()
45 |
46 | @staticmethod
47 | def corr(fmap1, fmap2, coords_init, r):
48 | batch, dim, ht, wd = fmap1.shape
49 | fmap1 = fmap1.view(batch, dim, ht*wd)
50 | fmap2 = fmap2.view(batch, dim, ht*wd)
51 |
52 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
53 | corr = corr.view(batch, ht, wd, 1, ht, wd)
54 | # return corr / torch.sqrt(torch.tensor(dim).float())
55 |
56 | coords = coords_init.permute(0, 2, 3, 1).contiguous()
57 | batch, h1, w1, _ = coords.shape
58 |
59 | corr = corr.view(batch*h1*w1, 1, h1, w1)
60 |
61 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
62 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
63 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
64 |
65 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2)
66 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
67 | coords_lvl = centroid_lvl + delta_lvl
68 |
69 | corr = bilinear_sampler(corr, coords_lvl)
70 |
71 | corr = corr.view(batch, h1, w1, 1, 2*r+1, 2*r+1)
72 | return corr.permute(0, 1, 2, 3, 5, 4).contiguous() / torch.sqrt(torch.tensor(dim).float())
73 |
--------------------------------------------------------------------------------
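Unlike `CorrBlock`, `CorrBlock_FD_Sp4` above builds its pyramid from a local window sampled around `coords_init` rather than from the full volume, so the window radius `rad` has to survive `num_levels - 1` halvings; with the default `rad=1` the 3×3 base window cannot be pooled down four levels, so callers should pass a larger value. A sketch with `rad=4` and an identity grid (shapes illustrative):

```python
import torch

B, D, H, W = 1, 256, 46, 62
ys, xs = torch.meshgrid(
    torch.arange(H).float(), torch.arange(W).float(), indexing="ij"
)
coords_init = torch.stack([xs, ys], dim=0)[None]  # identity grid, x first

block = CorrBlock_FD_Sp4(
    torch.randn(B, D, H, W), torch.randn(B, D, H, W),
    num_levels=4, radius=4, coords_init=coords_init, rad=4,
)
corr = block(coords_init)  # (1, 4 * 9 * 9, H, W)
```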
/ezsynth/utils/flow_utils/core/fd_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # import torch.nn.functional as F
4 |
5 | import timm
6 | import numpy as np
7 |
8 |
9 | class twins_svt_large(nn.Module):
10 | def __init__(self, pretrained=True):
11 | super().__init__()
12 | self.svt = timm.create_model('twins_svt_large', pretrained=pretrained)
13 |
14 | del self.svt.head
15 |         del self.svt.patch_embeds[2]  # deleting index 2 twice drops the original stages 3 and 4
16 | del self.svt.patch_embeds[2]
17 | del self.svt.blocks[2]
18 | del self.svt.blocks[2]
19 | del self.svt.pos_block[2]
20 | del self.svt.pos_block[2]
21 |
22 | def forward(self, x, data=None, layer=2):
23 | B = x.shape[0]
24 | x_4 = None
25 | for i, (embed, drop, blocks, pos_blk) in enumerate(
26 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
27 |
28 | patch_size = embed.patch_size
29 | if i == layer - 1:
30 | embed.patch_size = (1, 1)
31 | embed.proj.stride = embed.patch_size
32 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0)
33 | x_4, size_4 = embed(x_4)
34 | size_4 = (size_4[0] - 1, size_4[1] - 1)
35 | x_4 = drop(x_4)
36 | for j, blk in enumerate(blocks):
37 | x_4 = blk(x_4, size_4)
38 | if j == 0:
39 | x_4 = pos_blk(x_4, size_4)
40 |
41 | if i < len(self.svt.depths) - 1:
42 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous()
43 |
44 | embed.patch_size = patch_size
45 | embed.proj.stride = patch_size
46 | x, size = embed(x)
47 | x = drop(x)
48 | for j, blk in enumerate(blocks):
49 | x = blk(x, size)
50 | if j==0:
51 | x = pos_blk(x, size)
52 | if i < len(self.svt.depths) - 1:
53 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
54 |
55 | if i == layer-1:
56 | break
57 |
58 | return x, x_4
59 |
60 | def compute_params(self, layer=2):
61 | num = 0
62 | for i, (embed, drop, blocks, pos_blk) in enumerate(
63 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
64 |
65 | for param in embed.parameters():
66 | num += np.prod(param.size())
67 |
68 | for param in drop.parameters():
69 | num += np.prod(param.size())
70 |
71 | for param in blocks.parameters():
72 | num += np.prod(param.size())
73 |
74 | for param in pos_blk.parameters():
75 | num += np.prod(param.size())
76 |
77 | if i == layer-1:
78 | break
79 |
80 | for param in self.svt.head.parameters():
81 | num += np.prod(param.size())
82 |
83 | return num
84 |
85 |
86 | class twins_svt_small_context(nn.Module):
87 | def __init__(self, pretrained=True):
88 | super().__init__()
89 | self.svt = timm.create_model('twins_svt_small', pretrained=pretrained)
90 |
91 | del self.svt.head
92 |         del self.svt.patch_embeds[2]  # deleting index 2 twice drops the original stages 3 and 4
93 | del self.svt.patch_embeds[2]
94 | del self.svt.blocks[2]
95 | del self.svt.blocks[2]
96 | del self.svt.pos_block[2]
97 | del self.svt.pos_block[2]
98 |
99 | def forward(self, x, data=None, layer=2):
100 | B = x.shape[0]
101 | x_4 = None
102 | for i, (embed, drop, blocks, pos_blk) in enumerate(
103 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
104 |
105 | patch_size = embed.patch_size
106 | if i == layer - 1:
107 | embed.patch_size = (1, 1)
108 | embed.proj.stride = embed.patch_size
109 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0)
110 | x_4, size_4 = embed(x_4)
111 | size_4 = (size_4[0] - 1, size_4[1] - 1)
112 | x_4 = drop(x_4)
113 | for j, blk in enumerate(blocks):
114 | x_4 = blk(x_4, size_4)
115 | if j == 0:
116 | x_4 = pos_blk(x_4, size_4)
117 |
118 | if i < len(self.svt.depths) - 1:
119 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous()
120 |
121 | embed.patch_size = patch_size
122 | embed.proj.stride = patch_size
123 | x, size = embed(x)
124 | x = drop(x)
125 | for j, blk in enumerate(blocks):
126 | x = blk(x, size)
127 | if j == 0:
128 | x = pos_blk(x, size)
129 | if i < len(self.svt.depths) - 1:
130 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
131 |
132 | if i == layer - 1:
133 | break
134 |
135 | return x, x_4
136 |
137 | def compute_params(self, layer=2):
138 | num = 0
139 | for i, (embed, drop, blocks, pos_blk) in enumerate(
140 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
141 |
142 | for param in embed.parameters():
143 | num += np.prod(param.size())
144 |
145 | for param in drop.parameters():
146 | num += np.prod(param.size())
147 |
148 | for param in blocks.parameters():
149 | num += np.prod(param.size())
150 |
151 | for param in pos_blk.parameters():
152 | num += np.prod(param.size())
153 |
154 | if i == layer - 1:
155 | break
156 |
157 | for param in self.svt.head.parameters():
158 | num += np.prod(param.size())
159 |
160 | return num
161 |
162 |
--------------------------------------------------------------------------------
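The Twins-SVT encoders above run the first two stages twice: once normally, giving 1/8-resolution features, and once with the stage-two patch embedding forced to stride 1 on a padded input, giving 1/4-resolution features. The indexing tricks depend on timm internals, so this sketch assumes a timm release compatible with this FlowDiffuser-era code (newer timm versions may have changed the Twins patch-embed API); `pretrained=False` avoids a weight download, at the cost of random outputs:

```python
import torch

enc = twins_svt_large(pretrained=False)
x = torch.randn(1, 3, 368, 496)
feat_8, feat_4 = enc(x, layer=2)  # 1/8- and 1/4-resolution feature maps
```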
/ezsynth/utils/flow_utils/core/raft.py:
--------------------------------------------------------------------------------
1 | # import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .corr import AlternateCorrBlock, CorrBlock
7 | from .extractor import BasicEncoder, SmallEncoder
8 | from .update import BasicUpdateBlock, SmallUpdateBlock
9 | from .utils.utils import coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except Exception as e:
14 | # dummy autocast for PyTorch < 1.6
15 | print(e)
16 | class autocast:
17 | def __init__(self, enabled):
18 | pass
19 |
20 | def __enter__(self):
21 | pass
22 |
23 | def __exit__(self, *args):
24 | pass
25 |
26 |
27 | class RAFT(nn.Module):
28 | def __init__(self, args):
29 | super(RAFT, self).__init__()
30 | self.args = args
31 |
32 | if args.small:
33 | self.hidden_dim = hdim = 96
34 | self.context_dim = cdim = 64
35 | args.corr_levels = 4
36 | args.corr_radius = 3
37 |
38 | else:
39 | self.hidden_dim = hdim = 128
40 | self.context_dim = cdim = 128
41 | args.corr_levels = 4
42 | args.corr_radius = 4
43 |
44 | if "dropout" not in self.args:
45 | self.args.dropout = 0
46 |
47 | if "alternate_corr" not in self.args:
48 | self.args.alternate_corr = False
49 |
50 | # feature network, context network, and update block
51 | if args.small:
52 | self.fnet = SmallEncoder(
53 | output_dim=128, norm_fn="instance", dropout=args.dropout
54 | )
55 | self.cnet = SmallEncoder(
56 | output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout
57 | )
58 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
59 |
60 | else:
61 | self.fnet = BasicEncoder(
62 | output_dim=256, norm_fn="instance", dropout=args.dropout
63 | )
64 | self.cnet = BasicEncoder(
65 | output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout
66 | )
67 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
68 |
69 | def freeze_bn(self):
70 | for m in self.modules():
71 | if isinstance(m, nn.BatchNorm2d):
72 | m.eval()
73 |
74 | def initialize_flow(self, img):
75 | """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
76 | N, C, H, W = img.shape
77 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device)
78 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device)
79 |
80 | # optical flow computed as difference: flow = coords1 - coords0
81 | return coords0, coords1
82 |
83 | def upsample_flow(self, flow, mask):
84 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
85 | N, _, H, W = flow.shape
86 | mask = mask.view(N, 1, 9, 8, 8, H, W)
87 | mask = torch.softmax(mask, dim=2)
88 |
89 | up_flow = F.unfold(8 * flow, [3, 3], padding=1)
90 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
91 |
92 | up_flow = torch.sum(mask * up_flow, dim=2)
93 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
94 | return up_flow.reshape(N, 2, 8 * H, 8 * W)
95 |
96 | def forward(
97 | self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
98 | ):
99 | """Estimate optical flow between pair of frames"""
100 |
101 | image1 = 2 * (image1 / 255.0) - 1.0
102 | image2 = 2 * (image2 / 255.0) - 1.0
103 |
104 | image1 = image1.contiguous()
105 | image2 = image2.contiguous()
106 |
107 | hdim = self.hidden_dim
108 | cdim = self.context_dim
109 |
110 | # run the feature network
111 | with autocast(enabled=self.args.mixed_precision):
112 | fmap1, fmap2 = self.fnet([image1, image2])
113 |
114 | fmap1 = fmap1.float()
115 | fmap2 = fmap2.float()
116 | if self.args.alternate_corr:
117 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
118 | else:
119 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
120 |
121 | # run the context network
122 | with autocast(enabled=self.args.mixed_precision):
123 | cnet = self.cnet(image1)
124 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
125 | net = torch.tanh(net)
126 | inp = torch.relu(inp)
127 |
128 | coords0, coords1 = self.initialize_flow(image1)
129 |
130 | if flow_init is not None:
131 | coords1 = coords1 + flow_init
132 |
133 | flow_predictions = []
134 | for itr in range(iters):
135 | coords1 = coords1.detach()
136 | corr = corr_fn(coords1) # index correlation volume
137 |
138 | flow = coords1 - coords0
139 | with autocast(enabled=self.args.mixed_precision):
140 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
141 |
142 | # F(t+1) = F(t) + \Delta(t)
143 | coords1 = coords1 + delta_flow
144 |
145 | # upsample predictions
146 | if up_mask is None:
147 | flow_up = upflow8(coords1 - coords0)
148 | else:
149 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
150 |
151 | flow_predictions.append(flow_up)
152 |
153 | if test_mode:
154 | return coords1 - coords0, flow_up
155 |
156 | return flow_predictions
157 |
--------------------------------------------------------------------------------
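In `test_mode`, `RAFT.forward` above returns both the 1/8-resolution flow and the convex-upsampled full-resolution flow. A shape-check sketch with randomly initialized weights (no checkpoint is loaded, so the values are meaningless); input sides must be divisible by 8:

```python
import torch
from argparse import Namespace

args = Namespace(small=False, mixed_precision=False)
model = RAFT(args).eval()

img1 = torch.randint(0, 255, (1, 3, 368, 496)).float()
img2 = torch.randint(0, 255, (1, 3, 368, 496)).float()
with torch.no_grad():
    flow_low, flow_up = model(img1, img2, iters=4, test_mode=True)
print(flow_low.shape, flow_up.shape)  # (1, 2, 46, 62), (1, 2, 368, 496)
```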
/ezsynth/utils/flow_utils/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 |
17 | class ConvGRU(nn.Module):
18 | def __init__(self, hidden_dim=128, input_dim=192 + 128):
19 | super(ConvGRU, self).__init__()
20 | self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
21 | self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
22 | self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
23 |
24 | def forward(self, h, x):
25 | hx = torch.cat([h, x], dim=1)
26 |
27 | z = torch.sigmoid(self.convz(hx))
28 | r = torch.sigmoid(self.convr(hx))
29 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
30 |
31 | h = (1 - z) * h + z * q
32 | return h
33 |
34 |
35 | class SepConvGRU(nn.Module):
36 | def __init__(self, hidden_dim=128, input_dim=192 + 128):
37 | super(SepConvGRU, self).__init__()
38 | self.convz1 = nn.Conv2d(
39 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
40 | )
41 | self.convr1 = nn.Conv2d(
42 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
43 | )
44 | self.convq1 = nn.Conv2d(
45 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
46 | )
47 |
48 | self.convz2 = nn.Conv2d(
49 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
50 | )
51 | self.convr2 = nn.Conv2d(
52 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
53 | )
54 | self.convq2 = nn.Conv2d(
55 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
56 | )
57 |
58 | def forward(self, h, x):
59 | # horizontal
60 | hx = torch.cat([h, x], dim=1)
61 | z = torch.sigmoid(self.convz1(hx))
62 | r = torch.sigmoid(self.convr1(hx))
63 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
64 | h = (1 - z) * h + z * q
65 |
66 | # vertical
67 | hx = torch.cat([h, x], dim=1)
68 | z = torch.sigmoid(self.convz2(hx))
69 | r = torch.sigmoid(self.convr2(hx))
70 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
71 | h = (1 - z) * h + z * q
72 |
73 | return h
74 |
75 |
76 | class SmallMotionEncoder(nn.Module):
77 | def __init__(self, args):
78 | super(SmallMotionEncoder, self).__init__()
79 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
80 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
81 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
82 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
83 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
84 |
85 | def forward(self, flow, corr):
86 | cor = F.relu(self.convc1(corr))
87 | flo = F.relu(self.convf1(flow))
88 | flo = F.relu(self.convf2(flo))
89 | cor_flo = torch.cat([cor, flo], dim=1)
90 | out = F.relu(self.conv(cor_flo))
91 | return torch.cat([out, flow], dim=1)
92 |
93 |
94 | class BasicMotionEncoder(nn.Module):
95 | def __init__(self, args):
96 | super(BasicMotionEncoder, self).__init__()
97 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
98 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
99 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
100 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
101 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
102 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
103 |
104 | def forward(self, flow, corr):
105 | cor = F.relu(self.convc1(corr))
106 | cor = F.relu(self.convc2(cor))
107 | flo = F.relu(self.convf1(flow))
108 | flo = F.relu(self.convf2(flo))
109 |
110 | cor_flo = torch.cat([cor, flo], dim=1)
111 | out = F.relu(self.conv(cor_flo))
112 | return torch.cat([out, flow], dim=1)
113 |
114 |
115 | class SmallUpdateBlock(nn.Module):
116 | def __init__(self, args, hidden_dim=96):
117 | super(SmallUpdateBlock, self).__init__()
118 | self.encoder = SmallMotionEncoder(args)
119 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
121 |
122 | def forward(self, net, inp, corr, flow):
123 | motion_features = self.encoder(flow, corr)
124 | inp = torch.cat([inp, motion_features], dim=1)
125 | net = self.gru(net, inp)
126 | delta_flow = self.flow_head(net)
127 |
128 | return net, None, delta_flow
129 |
130 |
131 | class BasicUpdateBlock(nn.Module):
132 | def __init__(self, args, hidden_dim=128, input_dim=128):
133 | super(BasicUpdateBlock, self).__init__()
134 | self.args = args
135 | self.encoder = BasicMotionEncoder(args)
136 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=input_dim+hidden_dim)
137 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
138 |
139 | self.mask = nn.Sequential(
140 | nn.Conv2d(128, 256, 3, padding=1),
141 | nn.ReLU(inplace=True),
142 | nn.Conv2d(256, 64 * 9, 1, padding=0),
143 | )
144 |
145 | def forward(self, net, inp, corr, flow, upsample=True):
146 | motion_features = self.encoder(flow, corr)
147 | inp = torch.cat([inp, motion_features], dim=1)
148 |
149 | net = self.gru(net, inp)
150 | delta_flow = self.flow_head(net)
151 |
152 |         # scale mask to balance gradients
153 | mask = 0.25 * self.mask(net)
154 | return net, mask, delta_flow
155 |
156 | class LookupScaler(nn.Module):
157 | def __init__(self, input_dim=128, output_size=4, output_dim=4, max_multiplier=2, max_translation=2):
158 | super(LookupScaler, self).__init__()
159 | self.input_dim = input_dim
160 | self.output_size = output_size
161 | self.output_dim = output_dim
162 | self.max_multiplier = max_multiplier
163 | self.max_translation = max_translation
164 | self.convert_conv = nn.Conv2d(2 * input_dim, 2 * input_dim, 1)
165 | self.model_scale = nn.Sequential(nn.Linear(4 * input_dim, (output_dim // 2) * output_size),
166 | nn.Sigmoid())
167 | self.model_add = nn.Sequential(nn.Linear(4 * input_dim, (output_dim // 2) * output_size),
168 | nn.Sigmoid())
169 |
170 | def forward(self, context_map, hidden_state):
171 | assert(context_map.shape[1] == self.input_dim)
172 | assert(hidden_state.shape[1] == self.input_dim)
173 | convert_map = self.convert_conv(torch.cat([context_map, hidden_state], dim=1).type(torch.float32))
174 | lookup_context = torch.cat([torch.amax(convert_map, dim=(2, 3)),
175 | torch.amin(convert_map, dim=(2, 3))], dim=-1).type(torch.float32)
176 | scale = self.model_scale(lookup_context).view(-1, self.output_size, self.output_dim // 2) * self.max_multiplier + 1
177 | add = self.model_add(lookup_context).view(-1, self.output_size, self.output_dim // 2) * self.max_translation
178 | return torch.cat([scale, add], dim=-1)
--------------------------------------------------------------------------------
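`BasicUpdateBlock` above consumes the hidden state, context features, correlation lookup, and current flow, and emits the updated hidden state, a 64·9-channel convex-upsampling mask, and a flow delta. A single-step sketch at 1/8 resolution, following the default shapes:

```python
import torch
from argparse import Namespace

args = Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128, input_dim=128)

net = torch.randn(1, 128, 46, 62)                     # hidden state
inp = torch.randn(1, 128, 46, 62)                     # context features
corr = torch.randn(1, 4 * (2 * 4 + 1) ** 2, 46, 62)   # correlation lookup
flow = torch.zeros(1, 2, 46, 62)

net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)  # (1, 576, 46, 62), (1, 2, 46, 62)
```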
/ezsynth/utils/flow_utils/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/core/utils/__init__.py
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/core/utils/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/core/utils/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/core/utils/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/core/utils/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/core/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | # import math
2 | # import random
3 |
4 | import cv2
5 | import numpy as np
6 | from PIL import Image
7 |
8 | # import torch
9 | # import torch.nn.functional as F
10 | from torchvision.transforms import ColorJitter
11 |
12 | cv2.setNumThreads(0)
13 | cv2.ocl.setUseOpenCL(False)
14 |
15 |
16 | class FlowAugmentor:
17 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(
33 | brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14
34 | )
35 | self.asymmetric_color_aug_prob = 0.2
36 | self.eraser_aug_prob = 0.5
37 |
38 | def color_transform(self, img1, img2):
39 | """Photometric augmentation"""
40 |
41 | # asymmetric
42 | if np.random.rand() < self.asymmetric_color_aug_prob:
43 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
44 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
45 |
46 | # symmetric
47 | else:
48 | image_stack = np.concatenate([img1, img2], axis=0)
49 | image_stack = np.array(
50 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
51 | )
52 | img1, img2 = np.split(image_stack, 2, axis=0)
53 |
54 | return img1, img2
55 |
56 | def eraser_transform(self, img1, img2, bounds=(50, 100)):
57 | """Occlusion augmentation"""
58 |
59 | ht, wd = img1.shape[:2]
60 | if np.random.rand() < self.eraser_aug_prob:
61 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
62 | for _ in range(np.random.randint(1, 3)):
63 | x0 = np.random.randint(0, wd)
64 | y0 = np.random.randint(0, ht)
65 | dx = np.random.randint(bounds[0], bounds[1])
66 | dy = np.random.randint(bounds[0], bounds[1])
67 | img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
68 |
69 | return img1, img2
70 |
71 | def spatial_transform(self, img1, img2, flow):
72 | # randomly sample scale
73 | ht, wd = img1.shape[:2]
74 | min_scale = np.maximum(
75 | (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
76 | )
77 |
78 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
79 | scale_x = scale
80 | scale_y = scale
81 | if np.random.rand() < self.stretch_prob:
82 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
83 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
84 |
85 | scale_x = np.clip(scale_x, min_scale, None)
86 | scale_y = np.clip(scale_y, min_scale, None)
87 |
88 | if np.random.rand() < self.spatial_aug_prob:
89 | # rescale the images
90 | img1 = cv2.resize(
91 | img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
92 | )
93 | img2 = cv2.resize(
94 | img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
95 | )
96 | flow = cv2.resize(
97 | flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
98 | )
99 | flow = flow * [scale_x, scale_y]
100 |
101 | if self.do_flip:
102 | if np.random.rand() < self.h_flip_prob: # h-flip
103 | img1 = img1[:, ::-1]
104 | img2 = img2[:, ::-1]
105 | flow = flow[:, ::-1] * [-1.0, 1.0]
106 |
107 | if np.random.rand() < self.v_flip_prob: # v-flip
108 | img1 = img1[::-1, :]
109 | img2 = img2[::-1, :]
110 | flow = flow[::-1, :] * [1.0, -1.0]
111 |
112 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
113 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
114 |
115 | img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
116 | img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
117 | flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
118 |
119 | return img1, img2, flow
120 |
121 | def __call__(self, img1, img2, flow):
122 | img1, img2 = self.color_transform(img1, img2)
123 | img1, img2 = self.eraser_transform(img1, img2)
124 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
125 |
126 | img1 = np.ascontiguousarray(img1)
127 | img2 = np.ascontiguousarray(img2)
128 | flow = np.ascontiguousarray(flow)
129 |
130 | return img1, img2, flow
131 |
132 |
133 | class SparseFlowAugmentor:
134 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
135 | # spatial augmentation params
136 | self.crop_size = crop_size
137 | self.min_scale = min_scale
138 | self.max_scale = max_scale
139 | self.spatial_aug_prob = 0.8
140 | self.stretch_prob = 0.8
141 | self.max_stretch = 0.2
142 |
143 | # flip augmentation params
144 | self.do_flip = do_flip
145 | self.h_flip_prob = 0.5
146 | self.v_flip_prob = 0.1
147 |
148 | # photometric augmentation params
149 | self.photo_aug = ColorJitter(
150 | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14
151 | )
152 | self.asymmetric_color_aug_prob = 0.2
153 | self.eraser_aug_prob = 0.5
154 |
155 | def color_transform(self, img1, img2):
156 | image_stack = np.concatenate([img1, img2], axis=0)
157 | image_stack = np.array(
158 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
159 | )
160 | img1, img2 = np.split(image_stack, 2, axis=0)
161 | return img1, img2
162 |
163 | def eraser_transform(self, img1, img2):
164 | ht, wd = img1.shape[:2]
165 | if np.random.rand() < self.eraser_aug_prob:
166 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
167 | for _ in range(np.random.randint(1, 3)):
168 | x0 = np.random.randint(0, wd)
169 | y0 = np.random.randint(0, ht)
170 | dx = np.random.randint(50, 100)
171 | dy = np.random.randint(50, 100)
172 | img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
173 |
174 | return img1, img2
175 |
176 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
177 | ht, wd = flow.shape[:2]
178 | coords = np.meshgrid(np.arange(wd), np.arange(ht), indexing="xy")
179 | coords = np.stack(coords, axis=-1)
180 |
181 | coords = coords.reshape(-1, 2).astype(np.float32)
182 | flow = flow.reshape(-1, 2).astype(np.float32)
183 | valid = valid.reshape(-1).astype(np.float32)
184 |
185 | coords0 = coords[valid >= 1]
186 | flow0 = flow[valid >= 1]
187 |
188 | ht1 = int(round(ht * fy))
189 | wd1 = int(round(wd * fx))
190 |
191 | coords1 = coords0 * [fx, fy]
192 | flow1 = flow0 * [fx, fy]
193 |
194 | xx = np.round(coords1[:, 0]).astype(np.int32)
195 | yy = np.round(coords1[:, 1]).astype(np.int32)
196 |
197 | v = (xx >= 0) & (xx < wd1) & (yy >= 0) & (yy < ht1)  # x = 0 and y = 0 are valid positions
198 | xx = xx[v]
199 | yy = yy[v]
200 | flow1 = flow1[v]
201 |
202 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
203 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
204 |
205 | flow_img[yy, xx] = flow1
206 | valid_img[yy, xx] = 1
207 |
208 | return flow_img, valid_img
209 |
210 | def spatial_transform(self, img1, img2, flow, valid):
211 | # randomly sample scale
212 |
213 | ht, wd = img1.shape[:2]
214 | min_scale = np.maximum(
215 | (self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd)
216 | )
217 |
218 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
219 | scale_x = np.clip(scale, min_scale, None)
220 | scale_y = np.clip(scale, min_scale, None)
221 |
222 | if np.random.rand() < self.spatial_aug_prob:
223 | # rescale the images
224 | img1 = cv2.resize(
225 | img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
226 | )
227 | img2 = cv2.resize(
228 | img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
229 | )
230 | flow, valid = self.resize_sparse_flow_map(
231 | flow, valid, fx=scale_x, fy=scale_y
232 | )
233 |
234 | if self.do_flip:
235 | if np.random.rand() < 0.5: # h-flip
236 | img1 = img1[:, ::-1]
237 | img2 = img2[:, ::-1]
238 | flow = flow[:, ::-1] * [-1.0, 1.0]
239 | valid = valid[:, ::-1]
240 |
241 | margin_y = 20
242 | margin_x = 50
243 |
244 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
245 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
246 |
247 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
248 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
249 |
250 | img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
251 | img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
252 | flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
253 | valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
254 | return img1, img2, flow, valid
255 |
256 | def __call__(self, img1, img2, flow, valid):
257 | img1, img2 = self.color_transform(img1, img2)
258 | img1, img2 = self.eraser_transform(img1, img2)
259 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
260 |
261 | img1 = np.ascontiguousarray(img1)
262 | img2 = np.ascontiguousarray(img2)
263 | flow = np.ascontiguousarray(flow)
264 | valid = np.ascontiguousarray(valid)
265 |
266 | return img1, img2, flow, valid
267 |
--------------------------------------------------------------------------------
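
A minimal usage sketch for the FlowAugmentor above (not part of the repository; the frame and crop sizes are illustrative assumptions). Inputs must be uint8 HxWx3 arrays and the flow a float HxWx2 array:

import numpy as np
from ezsynth.utils.flow_utils.core.utils.augmentor import FlowAugmentor

aug = FlowAugmentor(crop_size=(368, 496))  # assumed crop; anything smaller than the frames works

img1 = np.random.randint(0, 255, (436, 1024, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (436, 1024, 3), dtype=np.uint8)
flow = np.random.randn(436, 1024, 2).astype(np.float32)

# Applies photometric jitter, random erasing, random scale/stretch/flip, and a crop.
img1, img2, flow = aug(img1, img2, flow)
assert img1.shape == (368, 496, 3) and flow.shape == (368, 496, 2)
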
/ezsynth/utils/flow_utils/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 |
21 | def make_colorwheel():
22 | """
23 | Generates a color wheel for optical flow visualization as presented in:
24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
26 |
27 | Code follows the original C++ source code of Daniel Scharstein.
28 | Code follows the Matlab source code of Deqing Sun.
29 |
30 | Returns:
31 | np.ndarray: Color wheel
32 | """
33 |
34 | RY = 15
35 | YG = 6
36 | GC = 4
37 | CB = 11
38 | BM = 13
39 | MR = 6
40 |
41 | ncols = RY + YG + GC + CB + BM + MR
42 | colorwheel = np.zeros((ncols, 3))
43 | col = 0
44 |
45 | # RY
46 | colorwheel[0:RY, 0] = 255
47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
48 | col = col + RY
49 | # YG
50 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
51 | colorwheel[col : col + YG, 1] = 255
52 | col = col + YG
53 | # GC
54 | colorwheel[col : col + GC, 1] = 255
55 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
56 | col = col + GC
57 | # CB
58 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
59 | colorwheel[col : col + CB, 2] = 255
60 | col = col + CB
61 | # BM
62 | colorwheel[col : col + BM, 2] = 255
63 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
64 | col = col + BM
65 | # MR
66 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
67 | colorwheel[col : col + MR, 0] = 255
68 | return colorwheel
69 |
70 |
71 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
72 | """
73 | Applies the flow color wheel to (possibly clipped) flow components u and v.
74 |
75 | According to the C++ source code of Daniel Scharstein
76 | According to the Matlab source code of Deqing Sun
77 |
78 | Args:
79 | u (np.ndarray): Input horizontal flow of shape [H,W]
80 | v (np.ndarray): Input vertical flow of shape [H,W]
81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
82 |
83 | Returns:
84 | np.ndarray: Flow visualization image of shape [H,W,3]
85 | """
86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
87 | colorwheel = make_colorwheel() # shape [55x3]
88 | ncols = colorwheel.shape[0]
89 | rad = np.sqrt(np.square(u) + np.square(v))
90 | a = np.arctan2(-v, -u) / np.pi
91 | fk = (a + 1) / 2 * (ncols - 1)
92 | k0 = np.floor(fk).astype(np.int32)
93 | k1 = k0 + 1
94 | k1[k1 == ncols] = 0
95 | f = fk - k0
96 | for i in range(colorwheel.shape[1]):
97 | tmp = colorwheel[:, i]
98 | col0 = tmp[k0] / 255.0
99 | col1 = tmp[k1] / 255.0
100 | col = (1 - f) * col0 + f * col1
101 | idx = rad <= 1
102 | col[idx] = 1 - rad[idx] * (1 - col[idx])
103 | col[~idx] = col[~idx] * 0.75 # out of range
104 | # Note the 2-i => BGR instead of RGB
105 | ch_idx = 2 - i if convert_to_bgr else i
106 | flow_image[:, :, ch_idx] = np.floor(255 * col)
107 | return flow_image
108 |
109 |
110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
111 | """
112 | Expects a two-dimensional flow image of shape [H,W,2].
113 |
114 | Args:
115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
118 |
119 | Returns:
120 | np.ndarray: Flow visualization image of shape [H,W,3]
121 | """
122 | assert flow_uv.ndim == 3, "input flow must have three dimensions"
123 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
124 | if clip_flow is not None:
125 | flow_uv = np.clip(flow_uv, 0, clip_flow)
126 | u = flow_uv[:, :, 0]
127 | v = flow_uv[:, :, 1]
128 | rad = np.sqrt(np.square(u) + np.square(v))
129 | rad_max = np.max(rad)
130 | epsilon = 1e-5
131 | u = u / (rad_max + epsilon)
132 | v = v / (rad_max + epsilon)
133 | return flow_uv_to_colors(u, v, convert_to_bgr)
134 |
--------------------------------------------------------------------------------
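
A usage sketch for flow_to_image on synthetic data (hypothetical, for illustration): a flow field pointing at the frame center renders as a full sweep of the color wheel, with colors growing more saturated as flow magnitude increases:

import numpy as np
from ezsynth.utils.flow_utils.core.utils.flow_viz import flow_to_image

h, w = 128, 128  # assumed frame size
y, x = np.mgrid[0:h, 0:w].astype(np.float32)
flow = np.stack([w / 2 - x, h / 2 - y], axis=-1)  # (H, W, 2), points toward center

rgb = flow_to_image(flow)  # uint8 visualization, shape (128, 128, 3)
print(rgb.shape, rgb.dtype)
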
/ezsynth/utils/flow_utils/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """Pads images such that dimensions are divisible by 8"""
9 |
10 | def __init__(self, dims, mode="sintel"):
11 | self.ht, self.wd = dims[-2:]
12 | self._calculate_padding(mode)
13 |
14 | def _calculate_padding(self, mode):
15 | pad_ht = -self.ht % 8
16 | pad_wd = -self.wd % 8
17 |
18 | if mode == "sintel":
19 | self._pad = (
20 | pad_wd // 2,
21 | pad_wd - pad_wd // 2,
22 | pad_ht // 2,
23 | pad_ht - pad_ht // 2,
24 | )
25 | else:
26 | self._pad = (pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht)
27 |
28 | def pad(self, *inputs):
29 | return [F.pad(x, self._pad, mode="replicate") for x in inputs]
30 |
31 | def unpad(self, x):
32 | return x[
33 | ...,
34 | self._pad[2] : self.ht + self._pad[2],
35 | self._pad[0] : self.wd + self._pad[0],
36 | ]
37 |
38 |
39 | def forward_interpolate(flow_ts):
40 | with torch.no_grad():
41 | flow = flow_ts.cpu().numpy()
42 | dx, dy = flow[0], flow[1]
43 |
44 | ht, wd = dx.shape
45 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht), indexing="xy")
46 |
47 | x1 = x0 + dx
48 | y1 = y0 + dy
49 |
50 | x1 = x1.reshape(-1)
51 | y1 = y1.reshape(-1)
52 | dx = dx.reshape(-1)
53 | dy = dy.reshape(-1)
54 |
55 | valid = (x1 >= 0) & (x1 < wd) & (y1 >= 0) & (y1 < ht)
56 |
57 | x1_valid = x1[valid]
58 | y1_valid = y1[valid]
59 | dx_valid = dx[valid]
60 | dy_valid = dy[valid]
61 |
62 | flow_x = interpolate.griddata(
63 | (x1_valid, y1_valid), dx_valid, (x0, y0), method="nearest", fill_value=0
64 | )
65 |
66 | flow_y = interpolate.griddata(
67 | (x1_valid, y1_valid), dy_valid, (x0, y0), method="nearest", fill_value=0
68 | )
69 |
70 | flow = np.stack([flow_x, flow_y], axis=0)
71 | return torch.from_numpy(flow).float()
72 |
73 |
74 | def bilinear_sampler(img, coords, mode="bilinear", mask=False):
75 | """Wrapper for grid_sample, uses pixel coordinates"""
76 | H, W = img.shape[-2:]
77 | xgrid, ygrid = coords.split([1, 1], dim=-1)
78 | xgrid = 2 * xgrid / (W - 1) - 1
79 | ygrid = 2 * ygrid / (H - 1) - 1
80 |
81 | grid = torch.cat([xgrid, ygrid], dim=-1)
82 | img = F.grid_sample(img, grid, mode=mode, align_corners=True)  # pass mode through; it was previously ignored
83 |
84 | if mask:
85 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
86 | return img, mask.float()
87 |
88 | return img
89 |
90 |
91 | def coords_grid(batch, ht, wd, device):
92 | coords = torch.meshgrid(
93 | torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
94 | )
95 | coords = torch.stack(coords[::-1], dim=0).float()
96 | return coords[None].repeat(batch, 1, 1, 1)
97 |
98 |
99 | def upflow8(flow, mode="bilinear"):
100 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
101 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
102 |
--------------------------------------------------------------------------------
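
A usage sketch for InputPadder (dimensions are illustrative): a 436-row Sintel-sized frame is padded up to the next multiple of 8, then cropped back after inference:

import torch
from ezsynth.utils.flow_utils.core.utils.utils import InputPadder

img = torch.randn(1, 3, 436, 1024)  # 436 is not divisible by 8
padder = InputPadder(img.shape)     # "sintel" mode pads symmetrically: 436 -> 440
(padded,) = padder.pad(img)
assert padded.shape[-2:] == (440, 1024)

restored = padder.unpad(padded)     # crops the padding back off
assert restored.shape == img.shape
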
/ezsynth/utils/flow_utils/ef_raft_models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/ef_raft_models/.gitkeep
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/flow_diff/fd_corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from ezsynth.utils.flow_utils.core.utils.utils import bilinear_sampler
5 |
6 |
7 | class CorrBlock_FD_Sp4:
8 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4, coords_init=None, rad=1):
9 | self.num_levels = num_levels
10 | self.radius = radius
11 | self.corr_pyramid = []
12 |
13 | corr = CorrBlock_FD_Sp4.corr(fmap1, fmap2, coords_init, r=rad)
14 |
15 | batch, h1, w1, dim, h2, w2 = corr.shape
16 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
17 |
18 | self.corr_pyramid.append(corr)
19 | for i in range(self.num_levels-1):
20 | corr = F.avg_pool2d(corr, 2, stride=2)
21 | self.corr_pyramid.append(corr)
22 |
23 | def __call__(self, coords):
24 | r = self.radius
25 | coords = coords.permute(0, 2, 3, 1)
26 | batch, h1, w1, _ = coords.shape
27 |
28 | out_pyramid = []
29 | for i in range(self.num_levels):
30 | corr = self.corr_pyramid[i]
31 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
32 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
33 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
34 |
35 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
36 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
37 | coords_lvl = centroid_lvl + delta_lvl
38 |
39 | corr = bilinear_sampler(corr, coords_lvl)
40 | corr = corr.view(batch, h1, w1, -1)
41 | out_pyramid.append(corr)
42 |
43 | out = torch.cat(out_pyramid, dim=-1)
44 | return out.permute(0, 3, 1, 2).contiguous().float()
45 |
46 | @staticmethod
47 | def corr(fmap1, fmap2, coords_init, r):
48 | batch, dim, ht, wd = fmap1.shape
49 | fmap1 = fmap1.view(batch, dim, ht*wd)
50 | fmap2 = fmap2.view(batch, dim, ht*wd)
51 |
52 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
53 | corr = corr.view(batch, ht, wd, 1, ht, wd)
54 | # return corr / torch.sqrt(torch.tensor(dim).float())
55 |
56 | coords = coords_init.permute(0, 2, 3, 1).contiguous()
57 | batch, h1, w1, _ = coords.shape
58 |
59 | corr = corr.view(batch*h1*w1, 1, h1, w1)
60 |
61 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
62 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
63 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
64 |
65 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2)
66 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
67 | coords_lvl = centroid_lvl + delta_lvl
68 |
69 | corr = bilinear_sampler(corr, coords_lvl)
70 |
71 | corr = corr.view(batch, h1, w1, 1, 2*r+1, 2*r+1)
72 | return corr.permute(0, 1, 2, 3, 5, 4).contiguous() / torch.sqrt(torch.tensor(dim).float())
73 |
--------------------------------------------------------------------------------
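
A minimal sketch (with assumed shapes) of the all-pairs correlation at the heart of CorrBlock_FD_Sp4.corr: every feature vector of fmap1 is dotted against every feature vector of fmap2 and normalized by sqrt(dim), yielding a 6-D volume indexed by (source pixel, target pixel):

import torch

b, dim, ht, wd = 1, 256, 16, 24  # hypothetical 1/8-resolution feature map sizes
fmap1 = torch.randn(b, dim, ht, wd).view(b, dim, ht * wd)
fmap2 = torch.randn(b, dim, ht, wd).view(b, dim, ht * wd)

# (b, ht*wd, dim) @ (b, dim, ht*wd) -> one dot product per pixel pair
corr = torch.matmul(fmap1.transpose(1, 2), fmap2) / dim ** 0.5
corr = corr.view(b, ht, wd, 1, ht, wd)
print(corr.shape)  # torch.Size([1, 16, 24, 1, 16, 24])
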
/ezsynth/utils/flow_utils/flow_diff/fd_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # import torch.nn.functional as F
4 |
5 | import timm
6 | import numpy as np
7 |
8 |
9 | class twins_svt_large(nn.Module):
10 | def __init__(self, pretrained=True):
11 | super().__init__()
12 | self.svt = timm.create_model('twins_svt_large', pretrained=pretrained)
13 |
14 | del self.svt.head  # classification head is unused for feature extraction
15 | del self.svt.patch_embeds[2]  # deleting index 2 twice drops stages 3 and 4
16 | del self.svt.patch_embeds[2]
17 | del self.svt.blocks[2]
18 | del self.svt.blocks[2]
19 | del self.svt.pos_block[2]
20 | del self.svt.pos_block[2]
21 |
22 | def forward(self, x, data=None, layer=2):
23 | B = x.shape[0]
24 | x_4 = None
25 | for i, (embed, drop, blocks, pos_blk) in enumerate(
26 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
27 |
28 | patch_size = embed.patch_size
29 | if i == layer - 1:
30 | embed.patch_size = (1, 1)
31 | embed.proj.stride = embed.patch_size
32 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0)
33 | x_4, size_4 = embed(x_4)
34 | size_4 = (size_4[0] - 1, size_4[1] - 1)
35 | x_4 = drop(x_4)
36 | for j, blk in enumerate(blocks):
37 | x_4 = blk(x_4, size_4)
38 | if j == 0:
39 | x_4 = pos_blk(x_4, size_4)
40 |
41 | if i < len(self.svt.depths) - 1:
42 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous()
43 |
44 | embed.patch_size = patch_size
45 | embed.proj.stride = patch_size
46 | x, size = embed(x)
47 | x = drop(x)
48 | for j, blk in enumerate(blocks):
49 | x = blk(x, size)
50 | if j==0:
51 | x = pos_blk(x, size)
52 | if i < len(self.svt.depths) - 1:
53 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
54 |
55 | if i == layer-1:
56 | break
57 |
58 | return x, x_4
59 |
60 | def compute_params(self, layer=2):
61 | num = 0
62 | for i, (embed, drop, blocks, pos_blk) in enumerate(
63 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
64 |
65 | for param in embed.parameters():
66 | num += np.prod(param.size())
67 |
68 | for param in drop.parameters():
69 | num += np.prod(param.size())
70 |
71 | for param in blocks.parameters():
72 | num += np.prod(param.size())
73 |
74 | for param in pos_blk.parameters():
75 | num += np.prod(param.size())
76 |
77 | if i == layer-1:
78 | break
79 |
80 | # NOTE: self.svt.head is deleted in __init__, so counting its parameters
81 | # here (as the upstream code did) would raise an AttributeError.
82 |
83 | return num
84 |
85 |
86 | class twins_svt_small_context(nn.Module):
87 | def __init__(self, pretrained=True):
88 | super().__init__()
89 | self.svt = timm.create_model('twins_svt_small', pretrained=pretrained)
90 |
91 | del self.svt.head
92 | del self.svt.patch_embeds[2]
93 | del self.svt.patch_embeds[2]
94 | del self.svt.blocks[2]
95 | del self.svt.blocks[2]
96 | del self.svt.pos_block[2]
97 | del self.svt.pos_block[2]
98 |
99 | def forward(self, x, data=None, layer=2):
100 | B = x.shape[0]
101 | x_4 = None
102 | for i, (embed, drop, blocks, pos_blk) in enumerate(
103 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
104 |
105 | patch_size = embed.patch_size
106 | if i == layer - 1:
107 | embed.patch_size = (1, 1)
108 | embed.proj.stride = embed.patch_size
109 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0)
110 | x_4, size_4 = embed(x_4)
111 | size_4 = (size_4[0] - 1, size_4[1] - 1)
112 | x_4 = drop(x_4)
113 | for j, blk in enumerate(blocks):
114 | x_4 = blk(x_4, size_4)
115 | if j == 0:
116 | x_4 = pos_blk(x_4, size_4)
117 |
118 | if i < len(self.svt.depths) - 1:
119 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous()
120 |
121 | embed.patch_size = patch_size
122 | embed.proj.stride = patch_size
123 | x, size = embed(x)
124 | x = drop(x)
125 | for j, blk in enumerate(blocks):
126 | x = blk(x, size)
127 | if j == 0:
128 | x = pos_blk(x, size)
129 | if i < len(self.svt.depths) - 1:
130 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
131 |
132 | if i == layer - 1:
133 | break
134 |
135 | return x, x_4
136 |
137 | def compute_params(self, layer=2):
138 | num = 0
139 | for i, (embed, drop, blocks, pos_blk) in enumerate(
140 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
141 |
142 | for param in embed.parameters():
143 | num += np.prod(param.size())
144 |
145 | for param in drop.parameters():
146 | num += np.prod(param.size())
147 |
148 | for param in blocks.parameters():
149 | num += np.prod(param.size())
150 |
151 | for param in pos_blk.parameters():
152 | num += np.prod(param.size())
153 |
154 | if i == layer - 1:
155 | break
156 |
157 | # NOTE: self.svt.head is deleted in __init__, so counting its parameters
158 | # here (as the upstream code did) would raise an AttributeError.
159 |
160 | return num
161 |
162 |
--------------------------------------------------------------------------------
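
A self-contained sketch of why the stage deletions in both encoders above appear twice: deleting index 2 of a 4-stage ModuleList twice removes the last two Twins stages (original indices 2 and 3), keeping only the layers needed for 1/4- and 1/8-resolution features:

import torch.nn as nn

stages = nn.ModuleList([nn.Identity() for _ in range(4)])  # stand-in for the 4 Twins stages
del stages[2]  # removes original stage 2; old stage 3 shifts into index 2
del stages[2]  # removes what was stage 3
print(len(stages))  # 2
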
/ezsynth/utils/flow_utils/flow_diffusion_models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/flow_diffusion_models/.gitkeep
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/models/raft-kitti.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/models/raft-kitti.pth
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/models/raft-sintel.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/models/raft-sintel.pth
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/models/raft-small.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/models/raft-small.pth
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/warp.py:
--------------------------------------------------------------------------------
1 | # import warnings
2 |
3 | import cv2
4 | import numpy as np
5 |
6 | # warnings.filterwarnings("ignore", category=UserWarning)
7 |
8 |
9 | class Warp:
10 | def __init__(self, sample_fr: np.ndarray):
11 | H, W, _ = sample_fr.shape
12 | self.H = H
13 | self.W = W
14 | self.grid = self._create_grid(H, W)
15 |
16 | def _create_grid(self, H: int, W: int):
17 | x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
18 | return np.stack((x, y), axis=-1).astype(np.float32)
19 |
20 | def _warp(self, img: np.ndarray, flo: np.ndarray):
21 | flo_resized = cv2.resize(flo, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
22 | map_x = self.grid[..., 0] + flo_resized[..., 0]
23 | map_y = self.grid[..., 1] + flo_resized[..., 1]
24 | warped_img = cv2.remap(
25 | img,
26 | map_x,
27 | map_y,
28 | interpolation=cv2.INTER_LINEAR,
29 | borderMode=cv2.BORDER_REFLECT,
30 | )
31 | return warped_img
32 |
33 | def run_warping(self, img: np.ndarray, flow: np.ndarray):
34 | img = img.astype(np.float32)
35 | flow = flow.astype(np.float32)
36 |
37 | try:
38 | warped_img = self._warp(img, flow)
39 | warped_image = (warped_img * 255).astype(np.uint8)  # assumes img arrives scaled to [0, 1]
40 | return warped_image
41 | except Exception as e:
42 | print(f"[ERROR] Exception in run_warping: {e}")
43 | return None
44 |
--------------------------------------------------------------------------------
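
A usage sketch for Warp on synthetic data (hypothetical sizes): a constant flow of +5 px along x shifts the frame horizontally. Per the note in run_warping, the input image is assumed pre-scaled to [0, 1]:

import numpy as np
from ezsynth.utils.flow_utils.warp import Warp

frame = np.random.rand(120, 160, 3).astype(np.float32)  # values in [0, 1]
flow = np.zeros((120, 160, 2), dtype=np.float32)
flow[..., 0] = 5.0  # move every pixel 5 px along x

warper = Warp(frame)  # the sampling grid is built from the sample frame's size
warped = warper.run_warping(frame, flow)
print(warped.shape, warped.dtype)  # (120, 160, 3) uint8
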
/output_synth/facestyle_err.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/facestyle_err.png
--------------------------------------------------------------------------------
/output_synth/facestyle_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/facestyle_out.png
--------------------------------------------------------------------------------
/output_synth/retarget_err.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/retarget_err.png
--------------------------------------------------------------------------------
/output_synth/retarget_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/retarget_out.png
--------------------------------------------------------------------------------
/output_synth/stylit_err.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/stylit_err.png
--------------------------------------------------------------------------------
/output_synth/stylit_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/stylit_out.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | torch
4 | torchvision
5 | phycv
6 | scipy
7 | Pillow
8 | tqdm
--------------------------------------------------------------------------------
/test_imgsynth.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import sys
4 | import time
5 |
6 | import torch
7 |
8 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9 |
10 | from ezsynth.aux_classes import RunConfig
11 | from ezsynth.aux_utils import save_to_folder, load_guide
12 | from ezsynth.main_ez import ImageSynth
13 |
14 | st = time.time()
15 |
16 | output_folder = "J:/AI/Ezsynth/output_synth"
17 |
18 | # Examples from the Ebsynth repository
19 |
20 | ## Segment retargeting
21 |
22 | ezsynner = ImageSynth(
23 | style_path="J:/AI/Ezsynth/examples/texbynum/source_photo.png",
24 | src_path="J:/AI/Ezsynth/examples/texbynum/source_segment.png",
25 | tgt_path="J:/AI/Ezsynth/examples/texbynum/target_segment.png",
26 | cfg=RunConfig(img_wgt=1.0),
27 | )
28 |
29 | result = ezsynner.run()
30 |
31 | save_to_folder(output_folder, "retarget_out.png", result[0])
32 | save_to_folder(output_folder, "retarget_err.png", result[1])
33 |
34 | ## Stylit
35 |
36 | ezsynner = ImageSynth(
37 | style_path="J:/AI/Ezsynth/examples/stylit/source_style.png",
38 | src_path="J:/AI/Ezsynth/examples/stylit/source_fullgi.png",
39 | tgt_path="J:/AI/Ezsynth/examples/stylit/target_fullgi.png",
40 | cfg=RunConfig(img_wgt=0.66),
41 | )
42 |
43 | result = ezsynner.run(
44 | guides=[
45 | load_guide(
46 | "J:/AI/Ezsynth/examples/stylit/source_dirdif.png",
47 | "J:/AI/Ezsynth/examples/stylit/target_dirdif.png",
48 | 0.66,
49 | ),
50 | load_guide(
51 | "J:/AI/Ezsynth/examples/stylit/source_indirb.png",
52 | "J:/AI/Ezsynth/examples/stylit/target_indirb.png",
53 | 0.66,
54 | ),
55 | ]
56 | )
57 |
58 | save_to_folder(output_folder, "stylit_out.png", result[0])
59 | save_to_folder(output_folder, "stylit_err.png", result[1])
60 |
61 | ## Face style
62 |
63 | ezsynner = ImageSynth(
64 | style_path="J:/AI/Ezsynth/examples/facestyle/source_painting.png",
65 | src_path="J:/AI/Ezsynth/examples/facestyle/source_Gapp.png",
66 | tgt_path="J:/AI/Ezsynth/examples/facestyle/target_Gapp.png",
67 | cfg=RunConfig(img_wgt=2.0),
68 | )
69 |
70 | result = ezsynner.run(
71 | guides=[
72 | load_guide(
73 | "J:/AI/Ezsynth/examples/facestyle/source_Gseg.png",
74 | "J:/AI/Ezsynth/examples/facestyle/target_Gseg.png",
75 | 1.5,
76 | ),
77 | load_guide(
78 | "J:/AI/Ezsynth/examples/facestyle/source_Gpos.png",
79 | "J:/AI/Ezsynth/examples/facestyle/target_Gpos.png",
80 | 1.5,
81 | ),
82 | ]
83 | )
84 |
85 | save_to_folder(output_folder, "facestyle_out.png", result[0])
86 | save_to_folder(output_folder, "facestyle_err.png", result[1])
87 |
88 | gc.collect()
89 | torch.cuda.empty_cache()
90 |
91 | print(f"Time taken: {time.time() - st:.4f} s")
92 |
--------------------------------------------------------------------------------
/test_progress.txt:
--------------------------------------------------------------------------------
1 | Style 1
2 | First frame only: Passed (Passed) (Passed)
3 | Last frame only: Passed (Passed) (Passed)
4 | Middle frame only: Passed (Passed) (Passed)
5 |
6 | Style 2 with Blend (Regression - Fixed)
7 | First and Last: Passed (Passed) (Passed)
8 | First and Mid: Passed (Passed) (Passed)
9 | Mid and Last: Passed (Passed) (Passed)
10 | Mid and Mid: Passed (Passed) (Passed)
11 |
12 | Style 2 without Blend - Forward only
13 | First and Last: Passed (Passed) (Passed)
14 | First and Mid: Passed (Passed) (Passed)
15 | Mid and Last: Passed (Passed) (Passed)
16 | Mid and Mid: Passed (Passed) (Passed)
17 |
18 | Style 2 without Blend - Reverse only (Regression - Fixed)
19 | First and Last: Passed (Passed) (Passed)
20 | First and Mid: Passed (Passed) (Passed)
21 | Mid and Last: Passed (Passed) (Passed)
22 | Mid and Mid: Passed (Passed) (Passed)
23 |
24 | Style 3:
25 | Blend Mid Mid Mid: Passed (Passed)
26 | Blend First Mid Last: Passed (Passed)
27 | Forward Mid Mid Mid: Passed (Passed)
28 | Forward First Mid Last: Passed (Passed)
29 | Reverse Mid Mid Mid: Passed (Passed)
30 | Reverse First Mid Last: Passed (Passed)
31 |
--------------------------------------------------------------------------------
/test_redux.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import sys
4 | import time
5 |
6 | import torch
7 |
8 |
9 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
10 |
11 | from ezsynth.sequences import EasySequence
12 | from ezsynth.aux_classes import RunConfig
13 | from ezsynth.aux_utils import save_seq
14 | from ezsynth.main_ez import Ezsynth
15 |
16 | st = time.time()
17 |
18 | style_paths = [
19 | "J:/AI/Ezsynth/examples/styles/style000.jpg",
20 | # "J:/AI/Ezsynth/examples/styles/style002.png",
21 | # "J:/AI/Ezsynth/examples/styles/style003.png",
22 | # "J:/AI/Ezsynth/examples/styles/style006.png",
23 | "J:/AI/Ezsynth/examples/styles/style010.png",
24 | # "J:/AI/Ezsynth/examples/styles/style014.png",
25 | # "J:/AI/Ezsynth/examples/styles/style019.png",
26 | # "J:/AI/Ezsynth/examples/styles/style099.jpg",
27 | ]
28 |
29 | image_folder = "J:/AI/Ezsynth/examples/input"
30 | mask_folder = "J:/AI/Ezsynth/examples/mask/mask_feather"
31 | output_folder = "J:/AI/Ezsynth/output"
32 |
33 | # edge_method="Classic"
34 | edge_method = "PAGE"
35 | # edge_method="PST"
36 |
37 | # flow_arch = "RAFT"
38 | # flow_model = "sintel"
39 |
40 | flow_arch = "EF_RAFT"
41 | flow_model = "25000_ours-sintel"
42 |
43 | # flow_arch = "FLOW_DIFF"
44 | # flow_model = "FlowDiffuser-things"
45 |
46 |
47 | ezrunner = Ezsynth(
48 | style_paths=style_paths,
49 | image_folder=image_folder,
50 | cfg=RunConfig(pre_mask=False, feather=5),
51 | edge_method=edge_method,
52 | raft_flow_model_name=flow_model,
53 | mask_folder=mask_folder,
54 | # do_mask=True,
55 | do_mask=False,
56 | flow_arch=flow_arch
57 | )
58 |
59 |
60 | # only_mode = EasySequence.MODE_FWD
61 | # only_mode = EasySequence.MODE_REV
62 | only_mode = None
63 |
64 | stylized_frames, err_frames, flow_frames = ezrunner.run_sequences_full(
65 | only_mode, return_flow=True
66 | )
67 | # stylized_frames, err_frames = ezrunner.run_sequences(only_mode)
68 |
69 | save_seq(stylized_frames, "J:/AI/Ezsynth/output_57_efraft")
70 | save_seq(flow_frames, "J:/AI/Ezsynth/flow_output_57_efraft")
71 | # save_seq(err_frames, "J:/AI/Ezsynth/output_51err")
72 |
73 | gc.collect()
74 | torch.cuda.empty_cache()
75 |
76 | print(f"Time taken: {time.time() - st:.4f} s")
77 |
--------------------------------------------------------------------------------