├── .gitignore
├── LICENSE
├── README.md
├── build_ebs-linux-cpu+cuda.sh
├── build_ebs-linux-cpu_only.sh
├── build_ebs-win64-cpu+cuda.bat
├── ebsynth_sha256.txt
├── examples
│   ├── facestyle
│   │   ├── source_Gapp.png
│   │   ├── source_Gpos.png
│   │   ├── source_Gseg.png
│   │   ├── source_painting.png
│   │   ├── target_Gapp.png
│   │   ├── target_Gpos.png
│   │   └── target_Gseg.png
│   ├── input
│   │   ├── 000.jpg
│   │   ├── 001.jpg
│   │   ├── 002.jpg
│   │   ├── 003.jpg
│   │   ├── 004.jpg
│   │   ├── 005.jpg
│   │   ├── 006.jpg
│   │   ├── 007.jpg
│   │   ├── 008.jpg
│   │   ├── 009.jpg
│   │   └── 010.jpg
│   ├── mask
│   │   ├── mask_feather
│   │   │   ├── mask_00000.png
│   │   │   ├── mask_00001.png
│   │   │   ├── mask_00002.png
│   │   │   ├── mask_00003.png
│   │   │   ├── mask_00004.png
│   │   │   ├── mask_00005.png
│   │   │   ├── mask_00006.png
│   │   │   ├── mask_00007.png
│   │   │   ├── mask_00008.png
│   │   │   ├── mask_00009.png
│   │   │   └── mask_00010.png
│   │   └── mask_nofeather
│   │       ├── mask_00000.png
│   │       ├── mask_00001.png
│   │       ├── mask_00002.png
│   │       ├── mask_00003.png
│   │       ├── mask_00004.png
│   │       ├── mask_00005.png
│   │       ├── mask_00006.png
│   │       ├── mask_00007.png
│   │       ├── mask_00008.png
│   │       ├── mask_00009.png
│   │       └── mask_00010.png
│   ├── styles
│   │   ├── style000.jpg
│   │   ├── style002.png
│   │   ├── style003.png
│   │   ├── style006.png
│   │   ├── style010.png
│   │   ├── style014.png
│   │   ├── style019.png
│   │   └── style099.jpg
│   ├── stylit
│   │   ├── source_dirdif.png
│   │   ├── source_dirspc.png
│   │   ├── source_fullgi.png
│   │   ├── source_indirb.png
│   │   ├── source_style.png
│   │   ├── target_dirdif.png
│   │   ├── target_dirspc.png
│   │   ├── target_fullgi.png
│   │   └── target_indirb.png
│   └── texbynum
│       ├── run.bat
│       ├── run.sh
│       ├── source_photo.png
│       ├── source_segment.png
│       └── target_segment.png
├── ezsynth
│   ├── aux_classes.py
│   ├── aux_computations.py
│   ├── aux_flow_viz.py
│   ├── aux_masker.py
│   ├── aux_run.py
│   ├── aux_utils.py
│   ├── constants.py
│   ├── edge_detection.py
│   ├── main_ez.py
│   ├── sequences.py
│   └── utils
│       ├── __init__.py
│       ├── _eb.py
│       ├── _ebsynth.py
│       ├── blend
│       │   ├── __init__.py
│       │   ├── blender.py
│       │   ├── cupy_accelerated.py
│       │   ├── histogram_blend.py
│       │   └── reconstruction.py
│       ├── ebsynth.dll
│       └── flow_utils
│           ├── OpticalFlow.py
│           ├── __init__.py
│           ├── alt_cuda_corr
│           │   ├── __init__.py
│           │   ├── correlation.cpp
│           │   ├── correlation_kernel.cu
│           │   └── setup.py
│           ├── core
│           │   ├── __init__.py
│           │   ├── corr.py
│           │   ├── datasets.py
│           │   ├── ef_raft.py
│           │   ├── extractor.py
│           │   ├── fd_corr.py
│           │   ├── fd_decoder.py
│           │   ├── fd_encoder.py
│           │   ├── flow_diffusion.py
│           │   ├── raft.py
│           │   ├── update.py
│           │   └── utils
│           │       ├── __init__.py
│           │       ├── __pycache__
│           │       │   ├── __init__.cpython-311.pyc
│           │       │   └── utils.cpython-311.pyc
│           │       ├── augmentor.py
│           │       ├── flow_viz.py
│           │       └── utils.py
│           ├── ef_raft_models
│           │   └── .gitkeep
│           ├── flow_diff
│           │   ├── fd_corr.py
│           │   ├── fd_decoder.py
│           │   ├── fd_encoder.py
│           │   └── flow_diffusion.py
│           ├── flow_diffusion_models
│           │   └── .gitkeep
│           ├── models
│           │   ├── raft-kitti.pth
│           │   ├── raft-sintel.pth
│           │   └── raft-small.pth
│           └── warp.py
├── output_synth
│   ├── facestyle_err.png
│   ├── facestyle_out.png
│   ├── retarget_err.png
│   ├── retarget_out.png
│   ├── stylit_err.png
│   └── stylit_out.png
├── requirements.txt
├── test_imgsynth.py
├── test_progress.txt
└── test_redux.py

-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
1 | #The Zipped Library
2 | *.zip
3 | 
4 | #The Dist Folder from building
5 | /dist/
6 | 
7 | #The Build Folder from building
8 | /build/
9 | 
10 | /ezsynth.egg-info/
11 | 
12 | /.vscode/
13 | 
14 | ##
15 | __pycache__
16 | .mypy_cache
17 | 
18 | /venv/
19 | *.so
20 | 
21 | # Extra model files
22 | /ezsynth/utils/flow_utils/flow_diffusion_models/*.pth
23 | /ezsynth/utils/flow_utils/ef_raft_models/*.pth
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Ezsynth - Ebsynth Python Library
2 | 
3 | Reworked version, courtesy of [FuouM](https://github.com/FuouM), with masking support and some visual bug fixes. Aims to be easy to use and maintain.
4 | 
5 | Perform things like style transfer, color transfer, inpainting, superimposition, video stylization and more!
6 | This implementation uses physics-based edge detection and RAFT optical flow, which leads to more accurate results during synthesis.
7 | 
8 | :warning: **This is not intended to be used as an installable module.**
9 | 
10 | Currently tested on:
11 | ```
12 | Windows 10 - Python 3.11 - RTX 3060
13 | Ubuntu 24 - Python 3.12 - RTX 4070 (Laptop)
14 | ```
15 | 
16 | ## Get started
17 | 
18 | ### Windows
19 | ```cmd
20 | rem Clone this repo
21 | git clone https://github.com/Trentonom0r3/Ezsynth.git
22 | cd Ezsynth
23 | 
24 | rem (Optional) create and activate venv
25 | python -m venv venv
26 | venv\Scripts\activate.bat
27 | 
28 | rem Install requirements
29 | pip install -r requirements.txt
30 | 
31 | rem A precompiled ebsynth.dll is included.
32 | rem If you don't want to rebuild, you are ready to go and can skip the following steps.
33 | 
34 | rem Clone ebsynth
35 | git clone https://github.com/Trentonom0r3/ebsynth.git
36 | 
37 | rem Build ebsynth as a lib
38 | copy .\build_ebs-win64-cpu+cuda.bat .\ebsynth
39 | cd ebsynth && .\build_ebs-win64-cpu+cuda.bat
40 | 
41 | rem Copy the lib
42 | copy .\bin\ebsynth.dll ..\ezsynth\utils\ebsynth.dll
43 | 
44 | rem Cleanup
45 | cd .. && rmdir /s /q .\ebsynth
46 | ```
47 | 
48 | ### Linux
49 | ```bash
50 | # clone this repo
51 | git clone https://github.com/Trentonom0r3/Ezsynth.git
52 | cd Ezsynth
53 | 
54 | # (optional) create and activate venv
55 | python -m venv venv
56 | source ./venv/bin/activate
57 | 
58 | # install requirements
59 | pip install -r requirements.txt
60 | 
61 | # clone ebsynth
62 | git clone https://github.com/Trentonom0r3/ebsynth.git
63 | 
64 | # build ebsynth as lib
65 | cp ./build_ebs-linux-cpu+cuda.sh ./ebsynth
66 | cd ebsynth && ./build_ebs-linux-cpu+cuda.sh
67 | 
68 | # copy lib
69 | cp ./bin/ebsynth.so ../ezsynth/utils/ebsynth.so
70 | 
71 | # cleanup
72 | cd .. && rm -rf ./ebsynth
73 | ```
74 | 
75 | ### All
76 | You may also install CuPy (which ships `cupyx`) to run some additional operations on the GPU.
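For example (a hedged sketch; CuPy wheels are CUDA-version specific, so pick the one matching your toolkit):

```bash
# Assumes a CUDA 12.x toolkit; use cupy-cuda11x for CUDA 11.
pip install cupy-cuda12x
```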
77 | 
78 | ## Examples
79 | 
80 | * To get started, see `test_redux.py` for an example of generating a full video.
81 | * To generate image style transfer, see `test_imgsynth.py` for all examples from the original `Ebsynth`.
82 | 
83 | ## Example outputs
84 | 
85 | | Face style | Stylit | Retarget |
86 | |:-:|:-:|:-:|
87 | | | | |
88 | 
89 | https://github.com/user-attachments/assets/aa3cd191-4eb2-4dc0-8213-2c763f1b3316
90 | 
91 | https://github.com/user-attachments/assets/63e50272-aa5c-42a1-a5ec-46178cdf2981
92 | 
93 | Comparison of Edge methods
94 | 
95 | ## Notable things
96 | 
97 | **Updates:**
98 | 1. [Ef-RAFT](https://github.com/n3slami/Ef-RAFT) has been added.
99 | 
100 | To use it, download the models from [the original repo](https://github.com/n3slami/Ef-RAFT/tree/master/models) and place them in `/ezsynth/utils/flow_utils/ef_raft_models`:
101 | ```
102 | .gitkeep
103 | 25000_ours-sintel.pth
104 | ours-things.pth
105 | ours_sintel.pth
106 | ```
107 | 
108 | 2. [FlowDiffuser](https://github.com/LA30/FlowDiffuser) has been added.
109 | 
110 | To use it, download the model from [the original repo](https://github.com/LA30/FlowDiffuser?tab=readme-ov-file#usage) and place it at `/ezsynth/utils/flow_utils/flow_diffusion_models/FlowDiffuser-things.pth`.
111 | 
112 | You will also need to install PyTorch Image Models to run it: `pip install timm`. On first run, it will download two backbone models (~470 MB total): `twins_svt_large` (378 MB) and `twins_svt_small` (92 MB).
113 | 
114 | This increases VRAM usage significantly when run alongside the EbSynth run (~15 GB, though it may not OOM; tested on 12 GB of VRAM).
115 | 
116 | In that case, it may throw a `CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR` error. This shouldn't be fatal, but the run then takes ~3x as long.
117 | 
118 | https://github.com/user-attachments/assets/7f43630f-c7c9-40d0-8745-58d1f7c84d4f
119 | 
120 | Comparison of Optical Flow models
121 | 
122 | Optical flow directly affects flow position warping and style image warping, controlled by `pos_wgt` and `wrp_wgt` respectively.
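Both weights are plain `RunConfig` fields; a minimal, hedged sketch (values are illustrative, not tuned recommendations):

```python
from ezsynth.aux_classes import RunConfig

# Defaults are pos_wgt=2.0 and wrp_wgt=0.5. Raising wrp_wgt leans harder on
# the previous stylized frame warped by the estimated flow; raising pos_wgt
# leans harder on the positional guide derived from the same flow.
cfg = RunConfig(pos_wgt=2.0, wrp_wgt=1.0)
```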
123 | 
124 | **Changes:**
125 | 1. Flow is calculated on a frame-by-frame basis, with correct time orientation, instead of pre-computing only a forward flow.
126 | 2. Padding is applied to edge detection and warping to remove visual distortion at the borders.
127 | 
128 | 
129 | **Observations:**
130 | 1. Edge detection models may return NaN if the input tensor contains too many zeros.
131 | 2. Pre-masked inputs take about twice as long to run through Ebsynth.
132 | 
133 | ## API Overview
134 | 
135 | ### ImageSynth
136 | For image-to-image style transfer via file paths (see `test_imgsynth.py`):
137 | ```python
138 | ezsynner = ImageSynth(
139 |     style_path="source_style.png",
140 |     src_path="source_fullgi.png",
141 |     tgt_path="target_fullgi.png",
142 |     cfg=RunConfig(img_wgt=0.66),
143 | )
144 | 
145 | result = ezsynner.run(
146 |     guides=[
147 |         load_guide(
148 |             "source_dirdif.png",
149 |             "target_dirdif.png",
150 |             0.66,
151 |         ),
152 |         load_guide(
153 |             "source_indirb.png",
154 |             "target_indirb.png",
155 |             0.66,
156 |         ),
157 |     ]
158 | )
159 | 
160 | save_to_folder(output_folder, "stylit_out.png", result[0])  # Styled image
161 | save_to_folder(output_folder, "stylit_err.png", result[1])  # Error image
162 | ```
163 | 
164 | ### Ezsynth
165 | 
166 | **edge_method**
167 | 
168 | Edge detection method. Choose from `PST`, `Classic`, or `PAGE`.
169 | * `PST` (Phase Stretch Transform): Good overall structure, but not very detailed.
170 | * `Classic`: A good balance between structure and detail.
171 | * `PAGE` (Phase and Gradient Estimation): Great detail and great structure, but slow.
172 | 
173 | **video stylization**
174 | 
175 | Via file paths (see `test_redux.py`):
176 | 
177 | ```python
178 | style_paths = [
179 |     "style000.jpg",
180 |     "style006.png",
181 | ]
182 | 
183 | ezrunner = Ezsynth(
184 |     style_paths=style_paths,
185 |     image_folder=image_folder,
186 |     cfg=RunConfig(pre_mask=False, feather=5),
187 |     edge_method="PAGE",
188 |     raft_flow_model_name="sintel",
189 |     mask_folder=mask_folder,
190 |     do_mask=True,
191 | )
192 | 
193 | only_mode = None
194 | stylized_frames, err_frames = ezrunner.run_sequences(only_mode)
195 | 
196 | save_seq(stylized_frames, "output")
197 | ```
198 | 
199 | Via NumPy ndarrays:
200 | 
201 | ```python
202 | class EzsynthBase:
203 |     def __init__(
204 |         self,
205 |         style_frs: list[np.ndarray],
206 |         style_idxes: list[int],
207 |         img_frs_seq: list[np.ndarray],
208 |         cfg: RunConfig = RunConfig(),
209 |         edge_method="Classic",
210 |         raft_flow_model_name="sintel",
211 |         do_mask=False,
212 |         msk_frs_seq: list[np.ndarray] | None = None,
213 |     ):
214 |         pass
215 | ```
216 | 
217 | ### RunConfig
218 | #### Ebsynth gen params
219 | * `uniformity (float)`: Uniformity weight for the style transfer. Reasonable values are between `500-15000`. Defaults to `3500.0`.
220 | 
221 | * `patchsize (int)`: Size of the patches [NxN]. Must be an odd number `>= 3`. Defaults to `7`.
222 | 
223 | * `pyramidlevels (int)`: Number of pyramid levels. Larger values are useful for things like color transfer. Defaults to `6`.
224 | 
225 | * `searchvoteiters (int)`: Number of search/vote iterations. Defaults to `12`.
226 | * `patchmatchiters (int)`: Number of Patch-Match iterations. The larger the value, the longer it takes. Defaults to `6`.
227 | 
228 | * `extrapass3x3 (bool)`: Perform an additional polishing pass with 3x3 patches at the finest level. Defaults to `True`.
229 | 
230 | #### Ebsynth guide weights params
231 | * `edg_wgt (float)`: Edge detection guide weight. Defaults to `1.0`.
232 | * `img_wgt (float)`: Original image guide weight. Defaults to `6.0`.
233 | * `pos_wgt (float)`: Flow position warping guide weight. Defaults to `2.0`.
234 | * `wrp_wgt (float)`: Warped style image guide weight. Defaults to `0.5`.
235 | 
236 | #### Blending params
237 | * `use_gpu (bool)`: Use the GPU for histogram blending (only affects Blend mode). Faster than CPU. Defaults to `False`.
238 | 
239 | * `use_lsqr (bool)`: Use LSQR (least-squares solver) instead of LSMR (iterative least-squares solver) for the Poisson blending step. LSQR often yields better results; you may switch to LSMR for speed (it depends). Defaults to `True`.
240 | 
241 | * `use_poisson_cupy (bool)`: Use CuPy GPU acceleration for the Poisson blending step. Uses LSMR (overrides `use_lsqr`). May not yield better speed. Defaults to `False`.
242 | 
243 | * `poisson_maxiter (int | None)`: Maximum number of iterations for the Poisson least-squares solve (only affects LSMR mode). Expects a positive integer. Defaults to `None`.
244 | 
245 | * `only_mode (str)`: Skip blending, only run one pass per sequence. Valid values:
246 |     * `MODE_FWD = "forward"` (Will only run forward mode if `sequence.mode` is blend)
247 | 
248 |     * `MODE_REV = "reverse"` (Will only run reverse mode if `sequence.mode` is blend)
249 | 
250 |     * Defaults to `MODE_NON = "none"`.
251 | 
252 | #### Masking params
253 | * `do_mask (bool)`: Whether to apply masks. Defaults to `False`.
254 | 
255 | * `pre_mask (bool)`: Whether to mask the inputs and styles before the run or after. Pre-masking takes ~2x the time per frame, possibly due to the Ebsynth.dll implementation. Defaults to `False`.
256 | 
257 | * `feather (int)`: Gaussian feather radius applied to the mask results. Only takes effect if `return_masked_only == False`. Expects an integer. Defaults to `0`.
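A hedged `RunConfig` sketch combining the parameter groups above (values are illustrative, not recommendations):

```python
from ezsynth.aux_classes import RunConfig

cfg = RunConfig(
    # Ebsynth gen params
    uniformity=3500.0,
    patchsize=7,
    # Blending params
    use_gpu=True,    # GPU histogram blending (Blend mode only)
    use_lsqr=True,   # LSQR for the Poisson step
    # Masking params
    do_mask=True,
    pre_mask=False,  # mask after the run (faster)
    feather=5,       # Gaussian feather radius on the masked result
)
```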
258 | 
259 | ## Credits
260 | 
261 | jamriska - https://github.com/jamriska/ebsynth
262 | 
263 | ```
264 | @misc{Jamriska2018,
265 |   author = {Jamriska, Ondrej},
266 |   title = {Ebsynth: Fast Example-based Image Synthesis and Style Transfer},
267 |   year = {2018},
268 |   publisher = {GitHub},
269 |   journal = {GitHub repository},
270 |   howpublished = {\url{https://github.com/jamriska/ebsynth}},
271 | }
272 | ```
273 | ```
274 | Ondřej Jamriška, Šárka Sochorová, Ondřej Texler, Michal Lukáč, Jakub Fišer, Jingwan Lu, Eli Shechtman, and Daniel Sýkora. 2019. Stylizing Video by Example. ACM Trans. Graph. 38, 4, Article 107 (July 2019), 11 pages. https://doi.org/10.1145/3306346.3323006
275 | ```
276 | 
277 | FuouM - https://github.com/FuouM
278 | pravdomil - https://github.com/pravdomil
279 | xy-gao - https://github.com/xy-gao
280 | 
281 | https://github.com/princeton-vl/RAFT
282 | 
283 | ```
284 | RAFT: Recurrent All-Pairs Field Transforms for Optical Flow
285 | ECCV 2020
286 | Zachary Teed and Jia Deng
287 | ```
288 | 
289 | https://github.com/n3slami/Ef-RAFT
290 | 
291 | ```
292 | @inproceedings{eslami2024rethinking,
293 |   title={Rethinking RAFT for efficient optical flow},
294 |   author={Eslami, Navid and Arefi, Farnoosh and Mansourian, Amir M and Kasaei, Shohreh},
295 |   booktitle={2024 13th Iranian/3rd International Machine Vision and Image Processing Conference (MVIP)},
296 |   pages={1--7},
297 |   year={2024},
298 |   organization={IEEE}
299 | }
300 | ```
301 | 
302 | https://github.com/LA30/FlowDiffuser
303 | 
304 | ```
305 | @inproceedings{luo2024flowdiffuser,
306 |   title={FlowDiffuser: Advancing Optical Flow Estimation with Diffusion Models},
307 |   author={Luo, Ao and Li, Xin and Yang, Fan and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
308 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
309 |   pages={19167--19176},
310 |   year={2024}
311 | }
312 | ```
313 | 
-------------------------------------------------------------------------------- /build_ebs-linux-cpu+cuda.sh: --------------------------------------------------------------------------------
1 | #!/bin/sh
2 | nvcc --shared src/ebsynth.cpp src/ebsynth_cpu.cpp src/ebsynth_cuda.cu -I"include" -DNDEBUG -D__CORRECT_ISO_CPP11_MATH_H_PROTO -O6 -std=c++14 -w -Xcompiler -fopenmp,-fPIC -o bin/ebsynth.so
3 | 
-------------------------------------------------------------------------------- /build_ebs-linux-cpu_only.sh: --------------------------------------------------------------------------------
1 | #!/bin/sh
2 | g++ src/ebsynth.cpp src/ebsynth_cpu.cpp src/ebsynth_nocuda.cpp -DNDEBUG -O6 -fopenmp -I"include" -std=c++11 -o bin/ebsynth
3 | 
-------------------------------------------------------------------------------- /build_ebs-win64-cpu+cuda.bat: --------------------------------------------------------------------------------
1 | @echo off
2 | setlocal ENABLEDELAYEDEXPANSION
3 | 
4 | call "vcvarsall.bat" amd64
5 | 
6 | nvcc -v -arch sm_86 src\ebsynth.cpp src\ebsynth_cpu.cpp src\ebsynth_cuda.cu -DNDEBUG -O6 -I "include" -o "bin\ebsynth.dll" -Xcompiler "/openmp /fp:fast" -Xlinker "/IMPLIB:lib\ebsynth.lib" -shared -DEBSYNTH_API=__declspec(dllexport) -w || goto error
7 | 
8 | del dummy.lib;dummy.exp 2> NUL
9 | goto :EOF
10 | 
11 | :error
12 | echo FAILED
13 | @%COMSPEC% /C exit 1 >nul
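The `ebsynth_sha256.txt` file that follows records the SHA-256 of the bundled `ebsynth.dll`. A hedged one-liner to re-check it on Windows (CertUtil ships with the OS; the path assumes the repo root as the working directory):

```cmd
CertUtil -hashfile ezsynth\utils\ebsynth.dll SHA256
```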
-------------------------------------------------------------------------------- /ebsynth_sha256.txt: -------------------------------------------------------------------------------- 1 | SHA256 hash of ebsynth.dll: 2 | e3cfad210d445fcbfa6c7dcd2f9bdaaf36d550746c108c79a94d2d1ecce41369 3 | CertUtil: -hashfile command completed successfully. 4 | -------------------------------------------------------------------------------- /examples/facestyle/source_Gapp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/source_Gapp.png -------------------------------------------------------------------------------- /examples/facestyle/source_Gpos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/source_Gpos.png -------------------------------------------------------------------------------- /examples/facestyle/source_Gseg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/source_Gseg.png -------------------------------------------------------------------------------- /examples/facestyle/source_painting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/source_painting.png -------------------------------------------------------------------------------- /examples/facestyle/target_Gapp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/target_Gapp.png -------------------------------------------------------------------------------- /examples/facestyle/target_Gpos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/target_Gpos.png -------------------------------------------------------------------------------- /examples/facestyle/target_Gseg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/facestyle/target_Gseg.png -------------------------------------------------------------------------------- /examples/input/000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/000.jpg -------------------------------------------------------------------------------- /examples/input/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/001.jpg -------------------------------------------------------------------------------- /examples/input/002.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/002.jpg -------------------------------------------------------------------------------- /examples/input/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/003.jpg -------------------------------------------------------------------------------- /examples/input/004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/004.jpg -------------------------------------------------------------------------------- /examples/input/005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/005.jpg -------------------------------------------------------------------------------- /examples/input/006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/006.jpg -------------------------------------------------------------------------------- /examples/input/007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/007.jpg -------------------------------------------------------------------------------- /examples/input/008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/008.jpg -------------------------------------------------------------------------------- /examples/input/009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/009.jpg -------------------------------------------------------------------------------- /examples/input/010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/input/010.jpg -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00000.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00001.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00002.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00002.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00003.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00004.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00005.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00006.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00007.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00008.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00009.png -------------------------------------------------------------------------------- /examples/mask/mask_feather/mask_00010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_feather/mask_00010.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00000.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00001.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00002.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00003.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00004.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00005.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00006.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00007.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00008.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00009.png -------------------------------------------------------------------------------- /examples/mask/mask_nofeather/mask_00010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/mask/mask_nofeather/mask_00010.png -------------------------------------------------------------------------------- /examples/styles/style000.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style000.jpg -------------------------------------------------------------------------------- /examples/styles/style002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style002.png -------------------------------------------------------------------------------- /examples/styles/style003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style003.png -------------------------------------------------------------------------------- /examples/styles/style006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style006.png -------------------------------------------------------------------------------- /examples/styles/style010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style010.png -------------------------------------------------------------------------------- /examples/styles/style014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style014.png -------------------------------------------------------------------------------- /examples/styles/style019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style019.png -------------------------------------------------------------------------------- /examples/styles/style099.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/styles/style099.jpg -------------------------------------------------------------------------------- /examples/stylit/source_dirdif.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_dirdif.png -------------------------------------------------------------------------------- /examples/stylit/source_dirspc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_dirspc.png -------------------------------------------------------------------------------- /examples/stylit/source_fullgi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_fullgi.png -------------------------------------------------------------------------------- /examples/stylit/source_indirb.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_indirb.png -------------------------------------------------------------------------------- /examples/stylit/source_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/source_style.png -------------------------------------------------------------------------------- /examples/stylit/target_dirdif.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/target_dirdif.png -------------------------------------------------------------------------------- /examples/stylit/target_dirspc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/target_dirspc.png -------------------------------------------------------------------------------- /examples/stylit/target_fullgi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/target_fullgi.png -------------------------------------------------------------------------------- /examples/stylit/target_indirb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/stylit/target_indirb.png -------------------------------------------------------------------------------- /examples/texbynum/run.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | setlocal 3 | set PATH=..\..\bin;%PATH% 4 | 5 | ebsynth.exe -patchsize 3 -uniformity 1000 -style source_photo.png -guide source_segment.png target_segment.png -output output.png 6 | -------------------------------------------------------------------------------- /examples/texbynum/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | export PATH=../../bin:$PATH 3 | 4 | ebsynth -patchsize 3 -uniformity 1000 -style source_photo.png -guide source_segment.png target_segment.png -output output.png 5 | 6 | -------------------------------------------------------------------------------- /examples/texbynum/source_photo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/texbynum/source_photo.png -------------------------------------------------------------------------------- /examples/texbynum/source_segment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/texbynum/source_segment.png -------------------------------------------------------------------------------- /examples/texbynum/target_segment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/examples/texbynum/target_segment.png 
-------------------------------------------------------------------------------- /ezsynth/aux_classes.py: --------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | 
4 | from .utils.flow_utils.warp import Warp
5 | from .sequences import EasySequence
6 | 
7 | 
8 | class RunConfig:
9 |     """
10 |     ### Ebsynth gen params
11 |     `uniformity (float)`: Uniformity weight for the style transfer.
12 |     Reasonable values are between `500-15000`.
13 |     Defaults to `3500.0`.
14 | 
15 |     `patchsize (int)`: Size of the patches [NxN]. Must be an odd number `>= 3`.
16 |     Defaults to `7`.
17 | 
18 |     `pyramidlevels (int)`: Number of pyramid levels. Larger values are useful for things like color transfer.
19 |     Defaults to `6`.
20 | 
21 |     `searchvoteiters (int)`: Number of search/vote iterations. Defaults to `12`.
22 |     `patchmatchiters (int)`: Number of Patch-Match iterations. The larger the value, the longer it takes.
23 |     Defaults to `6`.
24 | 
25 |     `extrapass3x3 (bool)`: Perform an additional polishing pass with 3x3 patches at the finest level.
26 |     Defaults to `True`.
27 | 
28 |     ### Ebsynth guide weights params
29 |     `edg_wgt (float)`: Edge detection guide weight. Defaults to `1.0`.
30 |     `img_wgt (float)`: Original image guide weight. Defaults to `6.0`.
31 |     `pos_wgt (float)`: Flow position warping guide weight. Defaults to `2.0`.
32 |     `wrp_wgt (float)`: Warped style image guide weight. Defaults to `0.5`.
33 | 
34 |     ### Blending params
35 |     `use_gpu (bool)`: Use the GPU for histogram blending (only affects Blend mode). Faster than CPU.
36 |     Defaults to `False`.
37 | 
38 |     `use_lsqr (bool)`: Use LSQR (least-squares solver) instead of LSMR (iterative least-squares solver)
39 |     for the Poisson blending step. LSQR often yields better results; you may switch to LSMR for speed (it depends).
40 |     Defaults to `True`.
41 | 
42 |     `use_poisson_cupy (bool)`: Use CuPy GPU acceleration for the Poisson blending step.
43 |     Uses LSMR (overrides `use_lsqr`). May not yield better speed.
44 |     Defaults to `False`.
45 | 
46 |     `poisson_maxiter (int | None)`: Maximum number of iterations for the Poisson least-squares solve (only affects LSMR mode).
47 |     Expects a positive integer.
48 |     Defaults to `None`.
49 | 
50 |     `only_mode (str)`: Skip blending, only run one pass per sequence.
51 |     Valid values:
52 |     `MODE_FWD = "forward"` (Will only run forward mode if `sequence.mode` is blend)
53 | 
54 |     `MODE_REV = "reverse"` (Will only run reverse mode if `sequence.mode` is blend)
55 | 
56 |     Defaults to `MODE_NON = "none"`.
57 | 
58 |     ### Masking params
59 |     `do_mask (bool)`: Whether to apply masks. Defaults to `False`.
60 | 
61 |     `pre_mask (bool)`: Whether to mask the inputs and styles before the run or after.
62 |     Pre-masking takes ~2x the time per frame, possibly due to the Ebsynth.dll implementation.
63 |     Defaults to `False`.
64 | 
65 |     `feather (int)`: Gaussian feather radius applied to the mask results. Only takes effect if `return_masked_only == False`.
66 |     Expects an integer. Defaults to `0`.
67 |     """
68 | 
69 |     def __init__(
70 |         self,
71 |         uniformity=3500.0,
72 |         patchsize=7,
73 |         pyramidlevels=6,
74 |         searchvoteiters=12,
75 |         patchmatchiters=6,
76 |         extrapass3x3=True,
77 |         edg_wgt=1.0,
78 |         img_wgt=6.0,
79 |         pos_wgt=2.0,
80 |         wrp_wgt=0.5,
81 |         use_gpu=False,
82 |         use_lsqr=True,
83 |         use_poisson_cupy=False,
84 |         poisson_maxiter: int | None = None,
85 |         only_mode=EasySequence.MODE_NON,
86 |         do_mask=False,
87 |         pre_mask=False,
88 |         feather=0,
89 |     ) -> None:
90 |         # Ebsynth gen params
91 |         self.uniformity = uniformity
92 |         """Uniformity weight for the style transfer.
93 |         Reasonable values are between `500-15000`.
94 | 
95 |         Defaults to `3500.0`."""
96 |         self.patchsize = patchsize
97 |         """Size of the patches [`NxN`]. Must be an odd number `>= 3`.
98 |         Defaults to `7`"""
99 | 
100 |         self.pyramidlevels = pyramidlevels
101 |         """Number of pyramid levels.
102 |         Larger values are useful for things like color transfer.
103 | 
104 |         Defaults to `6`."""
105 | 
106 |         self.searchvoteiters = searchvoteiters
107 |         """Number of search/vote iterations.
108 |         Defaults to `12`"""
109 | 
110 |         self.patchmatchiters = patchmatchiters
111 |         """Number of Patch-Match iterations. The larger the value, the longer it takes.
112 |         Defaults to `6`"""
113 | 
114 |         self.extrapass3x3 = extrapass3x3
115 |         """Perform an additional polishing pass with 3x3 patches at the finest level.
116 |         Defaults to `True`"""
117 | 
118 |         # Weights
119 |         self.edg_wgt = edg_wgt
120 |         """Edge detection guide weight. Defaults to `1.0`"""
121 | 
122 |         self.img_wgt = img_wgt
123 |         """Original image guide weight. Defaults to `6.0`"""
124 | 
125 |         self.pos_wgt = pos_wgt
126 |         """Flow position warping guide weight. Defaults to `2.0`"""
127 | 
128 |         self.wrp_wgt = wrp_wgt
129 |         """Warped style image guide weight. Defaults to `0.5`"""
130 | 
131 |         # Blend params
132 |         self.use_gpu = use_gpu
133 |         """Use the GPU for histogram blending (only affects Blend mode). Faster than CPU.
134 |         Defaults to `False`"""
135 | 
136 |         self.use_lsqr = use_lsqr
137 |         """Use LSQR (least-squares solver) instead of LSMR (iterative least-squares solver)
138 |         for the Poisson blending step. LSQR often yields better results.
139 | 
140 |         You may switch to LSMR for speed (it depends).
141 |         Defaults to `True`"""
142 | 
143 |         self.use_poisson_cupy = use_poisson_cupy
144 |         """Use CuPy GPU acceleration for the Poisson blending step.
145 |         Uses LSMR (overrides `use_lsqr`). May not yield better speed.
146 | 
147 |         Defaults to `False`"""
148 | 
149 |         self.poisson_maxiter = poisson_maxiter
150 |         """Maximum number of iterations for the Poisson least-squares solve (only affects LSMR mode). Expects a positive integer.
151 | 
152 |         Defaults to `None`"""
153 | 
154 |         # No blending mode
155 |         self.only_mode = only_mode
156 |         """Skip blending, only run one pass per sequence.
157 | 
158 |         Valid values:
159 |         `MODE_FWD = "forward"` (Will only run forward mode if `sequence.mode` is blend)
160 | 
161 |         `MODE_REV = "reverse"` (Will only run reverse mode if `sequence.mode` is blend)
162 | 
163 |         Defaults to `MODE_NON = "none"`
164 |         """
165 | 
166 |         # Skip adding last style frame if blending
167 |         self.skip_blend_style_last = False
168 |         """Skip adding the last style frame if blending. Internal variable"""
169 | 
170 |         # Masking mode
171 |         self.do_mask = do_mask
172 |         """Whether to apply masks. Defaults to `False`"""
173 | 
174 |         self.pre_mask = pre_mask
175 |         """Whether to mask the inputs and styles before the run or after.
176 | 
177 |         Pre-masking takes ~2x the time per frame, possibly due to the Ebsynth.dll implementation.
178 | 
179 |         Defaults to `False`"""
180 | 
181 |         self.feather = feather
182 |         """Gaussian feather radius applied to the mask results. Only takes effect if `return_masked_only == False`.
183 | 
184 |         Expects an integer. Defaults to `0`"""
185 | 
186 |     def get_ebsynth_cfg(self):
187 |         return {
188 |             "uniformity": self.uniformity,
189 |             "patchsize": self.patchsize,
190 |             "pyramidlevels": self.pyramidlevels,
191 |             "searchvoteiters": self.searchvoteiters,
192 |             "patchmatchiters": self.patchmatchiters,
193 |             "extrapass3x3": self.extrapass3x3,
194 |         }
195 | 
196 |     def get_blender_cfg(self):
197 |         return {
198 |             "use_gpu": self.use_gpu,
199 |             "use_lsqr": self.use_lsqr,
200 |             "use_poisson_cupy": self.use_poisson_cupy,
201 |             "poisson_maxiter": self.poisson_maxiter,
202 |         }
203 | 
204 | 
205 | class PositionalGuide:
206 |     def __init__(self) -> None:
207 |         self.coord_map = None
208 | 
209 |     def get_coord_maps(self, warp: Warp):
210 |         h, w = warp.H, warp.W
211 | 
212 |         # Create x and y coordinates
213 |         x = np.linspace(0, 1, w)
214 |         y = np.linspace(0, 1, h)
215 | 
216 |         # Use numpy's meshgrid to create 2D coordinate arrays
217 |         xx, yy = np.meshgrid(x, y)
218 | 
219 |         # Stack the coordinates into a single 3D array
220 |         self.coord_map = np.stack((xx, yy, np.zeros_like(xx)), axis=-1).astype(
221 |             np.float32
222 |         )
223 | 
224 |     def get_or_create_coord_maps(self, warp: Warp):
225 |         if self.coord_map is None:
226 |             self.get_coord_maps(warp)
227 |         return self.coord_map
228 | 
229 |     def create_from_flow(
230 |         self, flow: np.ndarray, original_size: tuple[int, ...], warp: Warp
231 |     ):
232 |         coord_map = self.get_or_create_coord_maps(warp)
233 |         coord_map_warped = warp.run_warping(coord_map, flow)
234 | 
235 |         coord_map_warped[..., :2] = coord_map_warped[..., :2] % 1
236 | 
237 |         if coord_map_warped.shape[1::-1] != original_size:  # compare as (w, h)
238 |             coord_map_warped = cv2.resize(
239 |                 coord_map_warped, original_size, interpolation=cv2.INTER_LINEAR
240 |             )
241 | 
242 |         g_pos = (coord_map_warped * 255).astype(np.uint8)
243 |         self.coord_map = coord_map_warped.copy()
244 | 
245 |         return g_pos
246 | 
247 | 
248 | class EdgeConfig:
249 |     # PST
250 |     PST_S = 0.3
251 |     PST_W = 15
252 |     PST_SIG_LPF = 0.15
253 |     PST_MIN = 0.05
254 |     PST_MAX = 0.9
255 | 
256 |     # PAGE
257 |     PAGE_M1 = 0
258 |     PAGE_M2 = 0.35
259 |     PAGE_SIG1 = 0.05
260 |     PAGE_SIG2 = 0.8
261 |     PAGE_S1 = 0.8
262 |     PAGE_S2 = 0.8
263 |     PAGE_SIG_LPF = 0.1
264 |     PAGE_MIN = 0.0
265 |     PAGE_MAX = 0.9
266 | 
267 |     MORPH_FLAG = 1
268 | 
269 |     def __init__(self, **kwargs):
270 |         # PST attributes
271 |         self.pst_s = kwargs.get("S", self.PST_S)
272 |         self.pst_w = kwargs.get("W", self.PST_W)
273 |         self.pst_sigma_lpf = kwargs.get("sigma_LPF", self.PST_SIG_LPF)
274 |         self.pst_thresh_min = kwargs.get("thresh_min", self.PST_MIN)
275 |         self.pst_thresh_max = kwargs.get("thresh_max", self.PST_MAX)
276 | 
277 |         # PAGE attributes
278 |         self.page_mu_1 = kwargs.get("mu_1", self.PAGE_M1)
279 |         self.page_mu_2 = kwargs.get("mu_2", self.PAGE_M2)
280 |         self.page_sigma_1 = kwargs.get("sigma_1", self.PAGE_SIG1)
281 |         self.page_sigma_2 = kwargs.get("sigma_2", self.PAGE_SIG2)
282 |         self.page_s1 = kwargs.get("S1", self.PAGE_S1)
283 |         self.page_s2 = kwargs.get("S2", self.PAGE_S2)
284 |         self.page_sigma_lpf = kwargs.get("sigma_LPF", self.PAGE_SIG_LPF)
285 |         self.page_thresh_min = kwargs.get("thresh_min", self.PAGE_MIN)
286 |         self.page_thresh_max = kwargs.get("thresh_max", self.PAGE_MAX)
287 | 
288 |         self.morph_flag = kwargs.get("morph_flag", self.MORPH_FLAG)
289 | 
290 |     @classmethod
291 |     def get_pst_default(cls) -> dict:
292 |         return {
293 |             "S": cls.PST_S,
294 |             "W": cls.PST_W,
295 |             "sigma_LPF": cls.PST_SIG_LPF,
296 |             "thresh_min": cls.PST_MIN,
297 |             "thresh_max": cls.PST_MAX,
298 |             "morph_flag": cls.MORPH_FLAG,
299 |         }
300 | 
301 |
@classmethod 302 | def get_page_default(cls) -> dict: 303 | return { 304 | "mu_1": cls.PAGE_M1, 305 | "mu_2": cls.PAGE_M2, 306 | "sigma_1": cls.PAGE_SIG1, 307 | "sigma_2": cls.PAGE_SIG2, 308 | "S1": cls.PAGE_S1, 309 | "S2": cls.PAGE_S2, 310 | "sigma_LPF": cls.PAGE_SIG_LPF, 311 | "thresh_min": cls.PAGE_MIN, 312 | "thresh_max": cls.PAGE_MAX, 313 | "morph_flag": cls.MORPH_FLAG, 314 | } 315 | 316 | def get_pst_current(self) -> dict: 317 | return { 318 | "S": self.pst_s, 319 | "W": self.pst_w, 320 | "sigma_LPF": self.pst_sigma_lpf, 321 | "thresh_min": self.pst_thresh_min, 322 | "thresh_max": self.pst_thresh_max, 323 | "morph_flag": self.morph_flag, 324 | } 325 | 326 | def get_page_current(self) -> dict: 327 | return { 328 | "mu_1": self.page_mu_1, 329 | "mu_2": self.page_mu_2, 330 | "sigma_1": self.page_sigma_1, 331 | "sigma_2": self.page_sigma_2, 332 | "S1": self.page_s1, 333 | "S2": self.page_s2, 334 | "sigma_LPF": self.page_sigma_lpf, 335 | "thresh_min": self.page_thresh_min, 336 | "thresh_max": self.page_thresh_max, 337 | "morph_flag": self.morph_flag, 338 | } 339 | -------------------------------------------------------------------------------- /ezsynth/aux_computations.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import numpy as np 3 | from .edge_detection import EdgeDetector 4 | 5 | 6 | def precompute_edge_guides( 7 | img_frs_seq: list[np.ndarray], edge_method: str 8 | ) -> list[np.ndarray]: 9 | edge_detector = EdgeDetector(edge_method) 10 | edge_maps = [] 11 | for img_fr in tqdm.tqdm(img_frs_seq, desc="Calculating edge maps"): 12 | edge_maps.append(edge_detector.compute_edge(img_fr)) 13 | return edge_maps 14 | 15 | -------------------------------------------------------------------------------- /ezsynth/aux_flow_viz.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def make_colorwheel(): 4 | """ 5 | Generates a color wheel for optical flow visualization. 6 | """ 7 | RY, YG, GC, CB, BM, MR = 15, 6, 4, 11, 13, 6 8 | ncols = RY + YG + GC + CB + BM + MR 9 | colorwheel = np.zeros((ncols, 3), dtype=np.uint8) 10 | col = 0 11 | 12 | colorwheel[col : col + RY, 0] = 255 13 | colorwheel[col : col + RY, 1] = np.floor(255 * np.arange(0, RY) / RY).astype( 14 | np.uint8 15 | ) 16 | col += RY 17 | 18 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG).astype( 19 | np.uint8 20 | ) 21 | colorwheel[col : col + YG, 1] = 255 22 | col += YG 23 | 24 | colorwheel[col : col + GC, 1] = 255 25 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC).astype( 26 | np.uint8 27 | ) 28 | col += GC 29 | 30 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(0, CB) / CB).astype( 31 | np.uint8 32 | ) 33 | colorwheel[col : col + CB, 2] = 255 34 | col += CB 35 | 36 | colorwheel[col : col + BM, 2] = 255 37 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM).astype( 38 | np.uint8 39 | ) 40 | col += BM 41 | 42 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(0, MR) / MR).astype( 43 | np.uint8 44 | ) 45 | colorwheel[col : col + MR, 0] = 255 46 | 47 | return colorwheel 48 | 49 | 50 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 51 | """ 52 | Applies the flow color wheel to (possibly clipped) flow components u and v. 
53 | """ 54 | flow_image = np.zeros((*u.shape, 3), dtype=np.uint8) 55 | colorwheel = make_colorwheel() 56 | ncols = colorwheel.shape[0] 57 | 58 | rad = np.sqrt(np.square(u) + np.square(v)) 59 | a = np.arctan2(-v, -u) / np.pi 60 | fk = (a + 1) / 2 * (ncols - 1) 61 | k0 = np.floor(fk).astype(np.int32) 62 | k1 = (k0 + 1) % ncols 63 | f = fk - k0 64 | 65 | for i in range(3): 66 | tmp = colorwheel[:, i] 67 | col0 = tmp[k0] / 255.0 68 | col1 = tmp[k1] / 255.0 69 | col = (1 - f) * col0 + f * col1 70 | 71 | idx = rad <= 1 72 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 73 | col[~idx] *= 0.75 74 | 75 | ch_idx = 2 - i if convert_to_bgr else i 76 | flow_image[..., ch_idx] = (255 * col).astype(np.uint8) 77 | 78 | return flow_image 79 | 80 | 81 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 82 | """ 83 | Converts a two-dimensional flow image to a color image for visualization. 84 | """ 85 | assert ( 86 | flow_uv.ndim == 3 and flow_uv.shape[2] == 2 87 | ), "Input flow must have shape [H,W,2]" 88 | 89 | if clip_flow is not None: 90 | flow_uv = np.clip(flow_uv, 0, clip_flow) 91 | 92 | u, v = flow_uv[..., 0], flow_uv[..., 1] 93 | rad = np.sqrt(np.square(u) + np.square(v)) 94 | rad_max = np.max(rad) 95 | 96 | epsilon = 1e-5 97 | u = u / (rad_max + epsilon) 98 | v = v / (rad_max + epsilon) 99 | 100 | return flow_uv_to_colors(u, v, convert_to_bgr) 101 | -------------------------------------------------------------------------------- /ezsynth/aux_masker.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import tqdm 4 | 5 | 6 | def apply_mask(image: np.ndarray, mask: np.ndarray): 7 | masked_image = cv2.bitwise_and(image, image, mask=mask) 8 | return masked_image.astype(np.uint8) 9 | 10 | 11 | def apply_masks(images: list[np.ndarray], masks: list[np.ndarray]): 12 | len_img = len(images) 13 | len_msk = len(masks) 14 | if len_img != len_msk: 15 | raise ValueError(f"[{len_img=}], [{len_msk=}]") 16 | 17 | masked_images = [] 18 | for i in range(len_img): 19 | masked_images.append(apply_mask(images[i], masks[i])) 20 | 21 | return masked_images 22 | 23 | 24 | def apply_masks_idxes( 25 | images: list[np.ndarray], masks: list[np.ndarray], img_idxes: list[int] 26 | ): 27 | masked_images = [] 28 | for i, idx in enumerate(img_idxes): 29 | masked_images.append(apply_mask(images[i], masks[idx])) 30 | return masked_images 31 | 32 | 33 | def apply_masked_back( 34 | original: np.ndarray, processed: np.ndarray, mask: np.ndarray, feather_radius=0 35 | ): 36 | if feather_radius > 0: 37 | mask_blurred = cv2.GaussianBlur(mask, (feather_radius, feather_radius), 0) 38 | mask_blurred = mask_blurred.astype(np.float32) / 255.0 39 | 40 | mask_inv_blurred = 1.0 - mask_blurred 41 | 42 | # Expand dimensions to match the number of channels in the original image 43 | mask_blurred_expanded = np.expand_dims(mask_blurred, axis=-1) 44 | mask_inv_blurred_expanded = np.expand_dims(mask_inv_blurred, axis=-1) 45 | 46 | background = original * mask_inv_blurred_expanded 47 | foreground = processed * mask_blurred_expanded 48 | 49 | # Combine the background and foreground 50 | result = background + foreground 51 | result = result.astype(np.uint8) 52 | 53 | else: 54 | mask = mask.astype(np.float32) / 255.0 55 | mask_inv = 1.0 - mask 56 | mask_expanded = np.expand_dims(mask, axis=-1) 57 | mask_inv_expanded = np.expand_dims(mask_inv, axis=-1) 58 | background = original * mask_inv_expanded 59 | foreground = processed * mask_expanded 60 | result = background + 
foreground
61 |         result = result.astype(np.uint8)
62 | 
63 |     return result
64 | 
65 | 
66 | def apply_masked_back_seq(
67 |     img_frs_seq: list[np.ndarray],
68 |     styled_msk_frs: list[np.ndarray],
69 |     mask_frs_seq: list[np.ndarray],
70 |     feather=0,
71 | ):
72 |     len_img = len(img_frs_seq)
73 |     len_stl = len(styled_msk_frs)
74 |     len_msk = len(mask_frs_seq)
75 | 
76 |     if not (len_img == len_stl == len_msk):
77 |         raise ValueError(f"Lengths do not match. [{len_img=}, {len_stl=}, {len_msk=}]")
78 | 
79 |     backed_seq = []
80 | 
81 |     for i in tqdm.tqdm(range(len_img), desc="Adding masked back"):
82 |         backed_seq.append(
83 |             apply_masked_back(
84 |                 img_frs_seq[i], styled_msk_frs[i], mask_frs_seq[i], feather
85 |             )
86 |         )
87 | 
88 |     return backed_seq
89 | 
-------------------------------------------------------------------------------- /ezsynth/aux_run.py: --------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import tqdm
4 | 
5 | from .aux_classes import PositionalGuide, RunConfig
6 | from .utils._ebsynth import ebsynth
7 | from .utils.blend.blender import Blend
8 | from .utils.flow_utils.OpticalFlow import RAFT_flow
9 | from .utils.flow_utils.warp import Warp
10 | from .sequences import EasySequence
11 | 
12 | 
13 | def run_a_pass(
14 |     seq: EasySequence,
15 |     seq_mode: str,
16 |     img_frs_seq: list[np.ndarray],
17 |     style: np.ndarray,
18 |     edge: list[np.ndarray],
19 |     cfg: RunConfig,
20 |     rafter: RAFT_flow,
21 |     eb: ebsynth,
22 | ):
23 |     stylized_frames: list[np.ndarray] = [style]
24 |     err_list: list[np.ndarray] = []
25 |     ORIGINAL_SIZE = img_frs_seq[0].shape[1::-1]
26 | 
27 |     start, end, step, is_forward = (
28 |         get_forward(seq) if seq_mode == EasySequence.MODE_FWD else get_backward(seq)
29 |     )
30 |     warp = Warp(img_frs_seq[start])
31 |     print(f"{'Forward' if is_forward else 'Reverse'} mode.
{start=}, {end=}, {step=}") 32 | flows = [] 33 | poses = [] 34 | pos_guider = PositionalGuide() 35 | 36 | for i in tqdm.tqdm(range(start, end, step), "Generating"): 37 | flow = get_flow(img_frs_seq, rafter, step, is_forward, i) 38 | flows.append(flow) 39 | 40 | poster = pos_guider.create_from_flow(flow, ORIGINAL_SIZE, warp) 41 | poses.append(poster) 42 | warped_img = get_warped_img(stylized_frames, ORIGINAL_SIZE, step, warp, flow) 43 | 44 | stylized_img, err = eb.run( 45 | style, 46 | guides=[ 47 | (edge[start], edge[i + step], cfg.edg_wgt), # Slower with premask 48 | (img_frs_seq[start], img_frs_seq[i + step], cfg.img_wgt), 49 | (poses[0], poster, cfg.pos_wgt), 50 | (style, warped_img, cfg.wrp_wgt), # Slower with premask 51 | ], 52 | ) 53 | stylized_frames.append(stylized_img) 54 | err_list.append(err) 55 | 56 | if not is_forward: 57 | stylized_frames = stylized_frames[::-1] 58 | err_list = err_list[::-1] 59 | flows = flows[::-1] 60 | 61 | return stylized_frames, err_list, flows 62 | 63 | 64 | def get_warped_img( 65 | stylized_frames: list[np.ndarray], ORIGINAL_SIZE, step: int, warp: Warp, flow 66 | ): 67 | stylized_img = stylized_frames[-1] / 255.0 68 | warped_img = warp.run_warping(stylized_img, flow * (-step)) 69 | warped_img = cv2.resize(warped_img, ORIGINAL_SIZE) 70 | return warped_img 71 | 72 | 73 | def get_flow( 74 | img_frs_seq: list[np.ndarray], 75 | rafter: RAFT_flow, 76 | step: int, 77 | is_forward: bool, 78 | i: int, 79 | ): 80 | if is_forward: 81 | flow = rafter._compute_flow(img_frs_seq[i], img_frs_seq[i + step]) 82 | else: 83 | flow = rafter._compute_flow(img_frs_seq[i + step], img_frs_seq[i]) 84 | return flow 85 | 86 | 87 | def run_scratch( 88 | seq: EasySequence, 89 | img_frs_seq: list[np.ndarray], 90 | style_frs: list[np.ndarray], 91 | edge: list[np.ndarray], 92 | cfg: RunConfig, 93 | rafter: RAFT_flow, 94 | eb: ebsynth, 95 | ): 96 | if seq.mode == EasySequence.MODE_BLN and cfg.only_mode != EasySequence.MODE_NON: 97 | print(f"{cfg.only_mode} Only") 98 | stylized_frames, err_list, flow = run_a_pass( 99 | seq, 100 | cfg.only_mode, 101 | img_frs_seq, 102 | style_frs[seq.style_idxs[0]] 103 | if cfg.only_mode == EasySequence.MODE_FWD 104 | else style_frs[seq.style_idxs[1]], 105 | edge, 106 | cfg, 107 | rafter, 108 | eb, 109 | ) 110 | return stylized_frames, err_list, flow 111 | 112 | if seq.mode != EasySequence.MODE_BLN: 113 | stylized_frames, err_list, flow = run_a_pass( 114 | seq, 115 | seq.mode, 116 | img_frs_seq, 117 | style_frs[seq.style_idxs[0]], 118 | edge, 119 | cfg, 120 | rafter, 121 | eb, 122 | ) 123 | return stylized_frames, err_list, flow 124 | 125 | print("Blending mode") 126 | 127 | style_fwd, err_fwd, flow_fwd = run_a_pass( 128 | seq, 129 | EasySequence.MODE_FWD, 130 | img_frs_seq, 131 | style_frs[seq.style_idxs[0]], 132 | edge, 133 | cfg, 134 | rafter, 135 | eb, 136 | ) 137 | 138 | style_bwd, err_bwd, _ = run_a_pass( 139 | seq, 140 | EasySequence.MODE_REV, 141 | img_frs_seq, 142 | style_frs[seq.style_idxs[1]], 143 | edge, 144 | cfg, 145 | rafter, 146 | eb, 147 | ) 148 | 149 | return run_blend(img_frs_seq, style_fwd, style_bwd, err_fwd, err_bwd, flow_fwd, cfg) 150 | 151 | 152 | def run_blend( 153 | img_frs_seq: list[np.ndarray], 154 | style_fwd: list[np.ndarray], 155 | style_bwd: list[np.ndarray], 156 | err_fwd: list[np.ndarray], 157 | err_bwd: list[np.ndarray], 158 | flow_fwd: list[np.ndarray], 159 | cfg: RunConfig, 160 | ): 161 | blender = Blend(**cfg.get_blender_cfg()) 162 | 163 | err_masks = blender._create_selection_mask(err_fwd, err_bwd) 164 | 165 | 
warped_masks = blender._warping_masks(img_frs_seq[0], flow_fwd, err_masks) 166 | 167 | hist_blends = blender._hist_blend(style_fwd, style_bwd, warped_masks) 168 | 169 | blends = blender._reconstruct(style_fwd, style_bwd, warped_masks, hist_blends) 170 | 171 | if not cfg.skip_blend_style_last: 172 | blends.append(style_bwd[-1]) 173 | 174 | return blends, warped_masks, flow_fwd 175 | 176 | 177 | def get_forward(seq: EasySequence): 178 | start = seq.fr_start_idx 179 | end = seq.fr_end_idx 180 | step = 1 181 | is_forward = True 182 | return start, end, step, is_forward 183 | 184 | 185 | def get_backward(seq: EasySequence): 186 | start = seq.fr_end_idx 187 | end = seq.fr_start_idx 188 | step = -1 189 | is_forward = False 190 | return start, end, step, is_forward 191 | -------------------------------------------------------------------------------- /ezsynth/aux_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import tqdm 8 | 9 | 10 | def validate_option(option, values, default): 11 | return option if option in values else default 12 | 13 | 14 | def save_to_folder( 15 | output_folder: str, base_file_name: str, result_array: np.ndarray 16 | ) -> str: 17 | os.makedirs(output_folder, exist_ok=True) 18 | output_file_path = os.path.join(output_folder, base_file_name) 19 | cv2.imwrite(output_file_path, result_array) 20 | return output_file_path 21 | 22 | 23 | def validate_and_read_img(img: str | np.ndarray) -> np.ndarray: 24 | if isinstance(img, str): 25 | if os.path.isfile(img): 26 | img = cv2.imread(img) 27 | return img 28 | raise ValueError(f"Path does not exist: {img}") 29 | 30 | if isinstance(img, np.ndarray): 31 | if img.shape[-1] == 3: 32 | return img 33 | raise ValueError(f"Expected 3 channels image. 
Style shape is {img.shape}") 34 | 35 | 36 | def load_guide(src_path: str, tgt_path: str, weight=1.0): 37 | src_img = validate_and_read_img(src_path) 38 | tgt_img = validate_and_read_img(tgt_path) 39 | return (src_img, tgt_img, weight) 40 | 41 | 42 | def read_frames_from_paths(lst: list[str]) -> list[np.ndarray]: 43 | img_arr_seq: list[np.ndarray] = [] 44 | err_frame = -1 45 | try: 46 | total = len(lst) 47 | for err_frame, img_path in tqdm.tqdm( 48 | enumerate(lst), desc="Reading images: ", total=total 49 | ): 50 | img_arr = validate_and_read_img(img_path) 51 | img_arr_seq.append(img_arr) 52 | else: 53 | print(f"Read {len(img_arr_seq)} frames successfully") 54 | return img_arr_seq 55 | except Exception as e: 56 | raise ValueError(f"Error reading frame {err_frame}: {e}") 57 | 58 | 59 | def read_masks_from_paths(lst: list[str]) -> list[np.ndarray]: 60 | msk_arr_seq: list[np.ndarray] = [] 61 | err_frame = -1 62 | try: 63 | total = len(lst) 64 | for err_frame, img_path in tqdm.tqdm( 65 | enumerate(lst), desc="Reading masks: ", total=total 66 | ): 67 | msk = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 68 | msk_arr_seq.append(msk) 69 | else: 70 | print(f"Read {len(msk_arr_seq)} frames successfully") 71 | return msk_arr_seq 72 | except Exception as e: 73 | raise ValueError(f"Error reading mask frame {err_frame}: {e}") 74 | 75 | 76 | img_extensions = (".png", ".jpg", ".jpeg") 77 | img_path_pattern = re.compile(r"(\d+)(?=\.(jpg|jpeg|png)$)") 78 | 79 | 80 | def get_sequence_indices(seq_folder_path: str) -> list[str]: 81 | if not os.path.isdir(seq_folder_path): 82 | raise ValueError(f"Path does not exist: {seq_folder_path}") 83 | file_names = os.listdir(seq_folder_path) 84 | file_names = sorted( 85 | file_names, 86 | key=lambda x: [int(c) if c.isdigit() else c for c in re.split("([0-9]+)", x)], 87 | ) 88 | img_file_paths = [ 89 | os.path.join(seq_folder_path, file_name) 90 | for file_name in file_names 91 | if file_name.lower().endswith(img_extensions) 92 | ] 93 | if not img_file_paths: 94 | raise ValueError("No image files found in the directory.") 95 | return img_file_paths 96 | 97 | 98 | def extract_indices(lst: list[str]): 99 | return sorted(int(img_path_pattern.findall(img_name)[-1][0]) for img_name in lst) 100 | 101 | 102 | def is_valid_file_path(input_path: str | list[str]) -> bool: 103 | return isinstance(input_path, str) and os.path.isfile(input_path) 104 | 105 | 106 | def validate_file_or_folder_to_lst( 107 | input_paths: str | list[str], type_name="" 108 | ) -> list[str]: 109 | if is_valid_file_path(input_paths): 110 | return [input_paths] # type: ignore 111 | if isinstance(input_paths, list): 112 | valid_paths = [path for path in input_paths if is_valid_file_path(path)] 113 | if valid_paths: 114 | print(f"Received {len(valid_paths)} {type_name} files") 115 | return valid_paths 116 | raise FileNotFoundError(f"No valid {type_name} file(s) were found. 
{input_paths}") 117 | 118 | 119 | def setup_src_from_folder( 120 | seq_folder_path: str, 121 | ) -> tuple[list[str], list[int], list[np.ndarray]]: 122 | img_file_paths = get_sequence_indices(seq_folder_path) 123 | img_idxes = extract_indices(img_file_paths) 124 | img_frs_seq = read_frames_from_paths(img_file_paths) 125 | return img_file_paths, img_idxes, img_frs_seq 126 | 127 | 128 | def setup_masks_from_folder(mask_folder_path: str): 129 | msk_file_paths = get_sequence_indices(mask_folder_path) 130 | msk_idxes = extract_indices(msk_file_paths) 131 | msk_frs_seq = read_masks_from_paths(msk_file_paths) 132 | return msk_file_paths, msk_idxes, msk_frs_seq 133 | 134 | 135 | def setup_src_from_lst(paths: list[str], type_name=""): 136 | val_paths = validate_file_or_folder_to_lst(paths, type_name) 137 | val_idxes = extract_indices(val_paths) 138 | frs_seq = read_frames_from_paths(val_paths) 139 | return val_paths, val_idxes, frs_seq 140 | 141 | 142 | def save_seq(results: list, output_folder, base_name="output", extension=".png"): 143 | if not results: 144 | print("Error: No results to save.") 145 | return 146 | for i in range(len(results)): 147 | save_to_folder( 148 | output_folder, 149 | f"{base_name}{i:03}{extension}", 150 | results[i], 151 | ) 152 | else: 153 | print("All results saved successfully") 154 | return 155 | 156 | 157 | def replace_zeros_tensor(image: torch.Tensor, replace_value: int = 1) -> torch.Tensor: 158 | zero_mask = image == 0 159 | replace_tensor = torch.full_like(image, replace_value) 160 | return torch.where(zero_mask, replace_tensor, image) 161 | 162 | 163 | def replace_zeros_np(image: np.ndarray, replace_value: int = 1) -> np.ndarray: 164 | zero_mask = image == 0 165 | replace_array = np.full_like(image, replace_value) 166 | return np.where(zero_mask, replace_array, image) 167 | -------------------------------------------------------------------------------- /ezsynth/constants.py: -------------------------------------------------------------------------------- 1 | EDGE_METHODS = ["PAGE", "PST", "Classic"] 2 | DEFAULT_EDGE_METHOD = "Classic" 3 | 4 | FLOW_MODELS = ["sintel", "kitti"] 5 | DEFAULT_FLOW_MODEL = "sintel" 6 | 7 | FLOW_ARCHS = ["RAFT", "EF_RAFT", "FLOW_DIFF"] 8 | DEFAULT_FLOW_ARCH = "RAFT" 9 | 10 | EF_RAFT_MODELS = [ 11 | "25000_ours-sintel", 12 | "ours_sintel", 13 | "ours-things", 14 | ] 15 | DEFAULT_EF_RAFT_MODEL = "25000_ours-sintel" 16 | 17 | FLOW_DIFF_MODEL = "FlowDiffuser-things" 18 | -------------------------------------------------------------------------------- /ezsynth/edge_detection.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | from phycv import PAGE_GPU, PST_GPU 6 | 7 | from .aux_classes import EdgeConfig 8 | from .aux_utils import replace_zeros_tensor 9 | 10 | 11 | class EdgeDetector: 12 | def __init__(self, method="PAGE"): 13 | """ 14 | Initialize the edge detector. 15 | 16 | :param method: Edge detection method. Choose from 'PST', 'Classic', or 'PAGE'. 17 | :PST: Phase Stretch Transform (PST) edge detector. - Good overall structure, 18 | but not very detailed. 19 | :Classic: Classic edge detector. - A good balance between structure and detail. 20 | :PAGE: Phase and Gradient Estimation (PAGE) edge detector. - 21 | Great detail, great structure, but slow. 
22 | """ 23 | self.method = method 24 | self.device = "cuda" 25 | if method == "PST": 26 | self.pst_gpu = PST_GPU(device=self.device) 27 | elif method == "PAGE": 28 | self.page_gpu = PAGE_GPU(direction_bins=10, device=self.device) 29 | elif method == "Classic": 30 | size, sigma = 5, 6.0 31 | self.kernel = self.create_gaussian_kernel(size, sigma) 32 | self.pad_size = 16 33 | 34 | @staticmethod 35 | def create_gaussian_kernel(size, sigma): 36 | x, y = np.mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1] 37 | g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2))) 38 | return g / g.sum() 39 | 40 | def pad_image(self, img): 41 | return cv2.copyMakeBorder( 42 | img, 43 | self.pad_size, 44 | self.pad_size, 45 | self.pad_size, 46 | self.pad_size, 47 | cv2.BORDER_REFLECT, 48 | ) 49 | 50 | def unpad_image(self, img): 51 | return img[self.pad_size : -self.pad_size, self.pad_size : -self.pad_size] 52 | 53 | def classic_preprocess(self, img): 54 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 55 | blurred = cv2.filter2D(gray, -1, self.kernel) 56 | edge_map = cv2.subtract(gray, blurred) 57 | edge_map = np.clip(edge_map + 128, 0, 255) 58 | return edge_map.astype(np.uint8) 59 | 60 | def pst_page_postprocess(self, edge_map: np.ndarray): 61 | edge_map = cv2.GaussianBlur(edge_map, (5, 5), 3) 62 | edge_map = edge_map * 255 63 | return edge_map.astype(np.uint8) 64 | 65 | def pst_run( 66 | self, 67 | input_data: np.ndarray, 68 | S, 69 | W, 70 | sigma_LPF, 71 | thresh_min, 72 | thresh_max, 73 | morph_flag, 74 | ): 75 | input_img = cv2.cvtColor(input_data, cv2.COLOR_BGR2GRAY) 76 | 77 | padded_img = self.pad_image(input_img) 78 | 79 | self.pst_gpu.h = padded_img.shape[0] 80 | self.pst_gpu.w = padded_img.shape[1] 81 | 82 | self.pst_gpu.img = torch.from_numpy(padded_img).to(self.pst_gpu.device) 83 | # If input has too many zeros the model returns NaNs for some reason 84 | self.pst_gpu.img = replace_zeros_tensor(self.pst_gpu.img, 1) 85 | 86 | self.pst_gpu.init_kernel(S, W) 87 | self.pst_gpu.apply_kernel(sigma_LPF, thresh_min, thresh_max, morph_flag) 88 | 89 | edge_map = self.pst_gpu.pst_output.cpu().numpy() 90 | edge_map = self.unpad_image(edge_map) 91 | 92 | return edge_map 93 | 94 | def page_run( 95 | self, 96 | input_data: np.ndarray, 97 | mu_1, 98 | mu_2, 99 | sigma_1, 100 | sigma_2, 101 | S1, 102 | S2, 103 | sigma_LPF, 104 | thresh_min, 105 | thresh_max, 106 | morph_flag, 107 | ): 108 | input_img = cv2.cvtColor(input_data, cv2.COLOR_BGR2GRAY) 109 | padded_img = self.pad_image(input_img) 110 | 111 | self.page_gpu.h = padded_img.shape[0] 112 | self.page_gpu.w = padded_img.shape[1] 113 | 114 | self.page_gpu.img = torch.from_numpy(padded_img).to(self.page_gpu.device) 115 | # If input has too many zeros the model returns NaNs for some reason 116 | self.page_gpu.img = replace_zeros_tensor(self.page_gpu.img, 1) 117 | 118 | self.page_gpu.init_kernel(mu_1, mu_2, sigma_1, sigma_2, S1, S2) 119 | self.page_gpu.apply_kernel(sigma_LPF, thresh_min, thresh_max, morph_flag) 120 | self.page_gpu.create_page_edge() 121 | 122 | edge_map = self.page_gpu.page_edge.cpu().numpy() 123 | edge_map = self.unpad_image(edge_map) 124 | return edge_map 125 | 126 | def compute_edge(self, input_data: np.ndarray): 127 | edge_map = None 128 | if self.method == "PST": 129 | edge_map = self.pst_run(input_data, **EdgeConfig.get_pst_default()) 130 | edge_map = self.pst_page_postprocess(edge_map) 131 | return edge_map 132 | 133 | if self.method == "Classic": 134 | edge_map = self.classic_preprocess(input_data) 135 | return edge_map 136 | 137 
| if self.method == "PAGE": 138 | edge_map = self.page_run(input_data, **EdgeConfig.get_page_default()) 139 | edge_map = self.pst_page_postprocess(edge_map) 140 | return edge_map 141 | return edge_map 142 | 143 | -------------------------------------------------------------------------------- /ezsynth/main_ez.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import tqdm 5 | 6 | from .aux_flow_viz import flow_to_image 7 | 8 | from .aux_classes import RunConfig 9 | from .aux_computations import precompute_edge_guides 10 | from .aux_masker import ( 11 | apply_masked_back_seq, 12 | apply_masks, 13 | apply_masks_idxes, 14 | ) 15 | from .aux_run import run_scratch 16 | from .aux_utils import ( 17 | setup_masks_from_folder, 18 | setup_src_from_folder, 19 | setup_src_from_lst, 20 | validate_and_read_img, 21 | validate_option, 22 | ) 23 | from .constants import ( 24 | DEFAULT_EDGE_METHOD, 25 | DEFAULT_EF_RAFT_MODEL, 26 | DEFAULT_FLOW_ARCH, 27 | DEFAULT_FLOW_MODEL, 28 | EDGE_METHODS, 29 | EF_RAFT_MODELS, 30 | FLOW_ARCHS, 31 | FLOW_DIFF_MODEL, 32 | FLOW_MODELS, 33 | ) 34 | from .utils._ebsynth import ebsynth 35 | from .utils.flow_utils.OpticalFlow import RAFT_flow 36 | from .sequences import EasySequence, SequenceManager 37 | 38 | 39 | class EzsynthBase: 40 | def __init__( 41 | self, 42 | style_frs: list[np.ndarray], 43 | style_idxes: list[int], 44 | img_frs_seq: list[np.ndarray], 45 | cfg: RunConfig = RunConfig(), 46 | edge_method="Classic", 47 | raft_flow_model_name="sintel", 48 | do_mask=False, 49 | msk_frs_seq: list[np.ndarray] | None = None, 50 | flow_arch="RAFT", 51 | do_compute_edge=True, 52 | ) -> None: 53 | st = time.time() 54 | 55 | self.style_frs = style_frs 56 | self.style_idxes = style_idxes 57 | self.img_frs_seq = img_frs_seq 58 | self.msk_frs_seq = msk_frs_seq or [] 59 | 60 | self.len_img = len(self.img_frs_seq) 61 | self.len_msk = len(self.msk_frs_seq) 62 | self.len_stl = len(self.style_idxes) 63 | 64 | self.msk_frs_seq = self.msk_frs_seq[: self.len_img] 65 | 66 | self.cfg = cfg 67 | self.edge_method = validate_option( 68 | edge_method, EDGE_METHODS, DEFAULT_EDGE_METHOD 69 | ) 70 | self.flow_model = validate_option( 71 | raft_flow_model_name, FLOW_MODELS, DEFAULT_FLOW_MODEL 72 | ) 73 | 74 | self.flow_arch = validate_option(flow_arch, FLOW_ARCHS, DEFAULT_FLOW_ARCH) 75 | 76 | if self.flow_arch == "RAFT": 77 | self.flow_model = validate_option( 78 | raft_flow_model_name, FLOW_MODELS, DEFAULT_FLOW_MODEL 79 | ) 80 | elif self.flow_arch == "EF_RAFT": 81 | self.flow_model = validate_option( 82 | raft_flow_model_name, EF_RAFT_MODELS, DEFAULT_EF_RAFT_MODEL 83 | ) 84 | elif self.flow_arch == "FLOW_DIFF": 85 | self.flow_model = FLOW_DIFF_MODEL 86 | 87 | self.cfg.do_mask = do_mask and self.len_msk > 0 88 | print(f"Masking mode: {self.cfg.do_mask}") 89 | 90 | if self.cfg.do_mask and len(self.msk_frs_seq) != len(self.img_frs_seq): 91 | raise ValueError( 92 | f"Missing frames: Masks={self.len_msk}, Expected {self.len_img}" 93 | ) 94 | 95 | self.style_masked_frs = None 96 | if self.cfg.do_mask and self.cfg.pre_mask: 97 | self.masked_frs_seq = apply_masks(self.img_frs_seq, self.msk_frs_seq) 98 | self.style_masked_frs = apply_masks_idxes( 99 | self.style_frs, self.msk_frs_seq, self.style_idxes 100 | ) 101 | 102 | manager = SequenceManager( 103 | 0, 104 | self.len_img - 1, 105 | self.len_stl, 106 | self.style_idxes, 107 | list(range(0, self.len_img)), 108 | ) 109 | 110 | self.sequences, self.atlas = 
manager.create_sequences() 111 | self.num_seqs = len(self.sequences) 112 | 113 | self.edge_guides = [] 114 | if do_compute_edge: 115 | self.edge_guides = precompute_edge_guides( 116 | self.masked_frs_seq 117 | if (self.cfg.do_mask and self.cfg.pre_mask) 118 | else self.img_frs_seq, 119 | self.edge_method, 120 | ) 121 | self.rafter = RAFT_flow(model_name=self.flow_model, arch=self.flow_arch) 122 | 123 | self.eb = ebsynth(**cfg.get_ebsynth_cfg()) 124 | self.eb.runner.initialize_libebsynth() 125 | 126 | print(f"Init Ezsynth took: {time.time() - st:.4f} s") 127 | 128 | def run_sequences(self, cfg_only_mode: str | None = None): 129 | stylized_frames, err_frames, _ = self.run_sequences_full( 130 | cfg_only_mode=cfg_only_mode, return_flow=False 131 | ) 132 | return stylized_frames, err_frames 133 | 134 | def run_sequences_full(self, cfg_only_mode: str | None = None, return_flow=False): 135 | st = time.time() 136 | 137 | if len(self.edge_guides) == 0: 138 | raise ValueError("Edge guides were not computed. ") 139 | if len(self.edge_guides) != self.len_img: 140 | raise ValueError( 141 | f"Missing edge guides: Got {len(self.edge_guides)}, expected {self.len_img}" 142 | ) 143 | 144 | if ( 145 | cfg_only_mode is not None 146 | and cfg_only_mode in EasySequence.get_valid_modes() 147 | ): 148 | self.cfg.only_mode = cfg_only_mode 149 | 150 | no_skip_rev = False 151 | 152 | stylized_frames = [] 153 | err_frames = [] 154 | flow_frames = [] 155 | 156 | img_seq = ( 157 | self.masked_frs_seq 158 | if (self.cfg.do_mask and self.cfg.pre_mask) 159 | else self.img_frs_seq 160 | ) 161 | stl_seq = ( 162 | self.style_masked_frs 163 | if (self.cfg.do_mask and self.cfg.pre_mask) 164 | else self.style_frs 165 | ) 166 | 167 | for i, seq in enumerate(self.sequences): 168 | if self._should_skip_blend_style_last(i): 169 | self.cfg.skip_blend_style_last = True 170 | else: 171 | self.cfg.skip_blend_style_last = False 172 | 173 | if self._should_rev_move_fr(i): 174 | seq.fr_start_idx += 1 175 | no_skip_rev = True 176 | 177 | tmp_stylized_frames, tmp_err_frames, tmp_flow = run_scratch( 178 | seq, 179 | img_seq, 180 | stl_seq, 181 | self.edge_guides, 182 | self.cfg, 183 | self.rafter, 184 | self.eb, 185 | ) 186 | 187 | if self._should_remove_first_fr(i, no_skip_rev): 188 | tmp_stylized_frames.pop(0) 189 | tmp_err_frames.pop(0) 190 | tmp_flow.pop(0) 191 | 192 | no_skip_rev = False 193 | 194 | stylized_frames.extend(tmp_stylized_frames) 195 | err_frames.extend(tmp_err_frames) 196 | flow_frames.extend(tmp_flow) 197 | 198 | print(f"Run took: {time.time() - st:.4f} s") 199 | 200 | if self.cfg.do_mask: 201 | stylized_frames = apply_masked_back_seq( 202 | self.img_frs_seq, stylized_frames, self.msk_frs_seq, self.cfg.feather 203 | ) 204 | 205 | final_flows: list[np.ndarray] = [] 206 | if return_flow: 207 | for flow in tqdm.tqdm(flow_frames, desc="Converting flows"): 208 | final_flows.append(flow_to_image(flow, convert_to_bgr=True)) 209 | 210 | return stylized_frames, err_frames, final_flows 211 | 212 | def _should_skip_blend_style_last(self, i: int) -> bool: 213 | if ( 214 | self.cfg.only_mode == EasySequence.MODE_NON 215 | and i < self.num_seqs - 1 216 | and ( 217 | self.atlas[i] == EasySequence.MODE_BLN 218 | or self.atlas[i + 1] != EasySequence.MODE_FWD 219 | ) 220 | ): 221 | return True 222 | return False 223 | 224 | def _should_rev_move_fr(self, i: int) -> bool: 225 | if ( 226 | i > 0 227 | and self.cfg.only_mode == EasySequence.MODE_REV 228 | and self.atlas[i] == EasySequence.MODE_BLN 229 | ): 230 | return True 231 | return False 
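# --- Editor's note (added; not part of the source file) ---------------------
# A worked example of the boundary bookkeeping done by the helpers around this
# point, assuming style keyframes at indices 0 and 5 over frames 0..10:
# SequenceManager yields the atlas [blend, forward], and both sequences touch
# frame 5. In the default mode (only_mode == MODE_NON),
# _should_skip_blend_style_last makes the blend pass drop its trailing style
# frame, so frame 5 is emitted exactly once by the forward pass. In the
# forward-/reverse-only modes, _should_remove_first_fr (below) instead pops
# the duplicated first frame of the later sequence in run_sequences_full,
# keeping 11 output frames either way.
# -----------------------------------------------------------------------------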
232 | 233 | def _should_remove_first_fr(self, i: int, no_skip_rev: bool) -> bool: 234 | if i > 0 and not no_skip_rev: 235 | if (self.atlas[i - 1] == EasySequence.MODE_REV) or ( 236 | self.atlas[i - 1] == EasySequence.MODE_BLN 237 | and ( 238 | self.cfg.only_mode == EasySequence.MODE_FWD 239 | or ( 240 | self.cfg.only_mode == EasySequence.MODE_REV 241 | and self.atlas[i] == EasySequence.MODE_FWD 242 | ) 243 | ) 244 | ): 245 | return True 246 | return False 247 | 248 | 249 | class Ezsynth(EzsynthBase): 250 | def __init__( 251 | self, 252 | style_paths: list[str], 253 | image_folder: str, 254 | cfg: RunConfig = RunConfig(), 255 | edge_method="Classic", 256 | raft_flow_model_name="sintel", 257 | mask_folder: str | None = None, 258 | do_mask=False, 259 | flow_arch="RAFT", 260 | ) -> None: 261 | _, img_idxes, img_frs_seq = setup_src_from_folder(image_folder) 262 | _, style_idxes, style_frs = setup_src_from_lst(style_paths, "style") 263 | msk_frs_seq = setup_masks_from_folder(mask_folder)[2] if do_mask else None 264 | 265 | if img_idxes[0] != 0: 266 | style_idxes = [idx - img_idxes[0] for idx in style_idxes] 267 | 268 | super().__init__( 269 | style_frs=style_frs, 270 | style_idxes=style_idxes, 271 | img_frs_seq=img_frs_seq, 272 | cfg=cfg, 273 | edge_method=edge_method, 274 | raft_flow_model_name=raft_flow_model_name, 275 | do_mask=do_mask, 276 | msk_frs_seq=msk_frs_seq, 277 | flow_arch=flow_arch, 278 | ) 279 | 280 | 281 | class ImageSynthBase: 282 | def __init__( 283 | self, 284 | style_img: np.ndarray, 285 | src_img: np.ndarray, 286 | tgt_img: np.ndarray, 287 | cfg: RunConfig = RunConfig(), 288 | ) -> None: 289 | self.style_img = style_img 290 | self.src_img = src_img 291 | self.tgt_img = tgt_img 292 | self.cfg = cfg 293 | 294 | st = time.time() 295 | 296 | self.eb = ebsynth(**cfg.get_ebsynth_cfg()) 297 | self.eb.runner.initialize_libebsynth() 298 | 299 | print(f"Init ImageSynth took: {time.time() - st:.4f} s") 300 | 301 | def run(self, guides: list[tuple[np.ndarray, np.ndarray, float]] | None = None): 302 | guides = (guides or []) + [(self.src_img, self.tgt_img, self.cfg.img_wgt)]  # build a fresh list; the old mutable default ([]) accumulated guides across calls 303 | return self.eb.run(self.style_img, guides=guides) 304 | 305 | 306 | class ImageSynth(ImageSynthBase): 307 | def __init__( 308 | self, 309 | style_path: str, 310 | src_path: str, 311 | tgt_path: str, 312 | cfg: RunConfig = RunConfig(), 313 | ) -> None: 314 | style_img = validate_and_read_img(style_path) 315 | src_img = validate_and_read_img(src_path) 316 | tgt_img = validate_and_read_img(tgt_path) 317 | 318 | super().__init__(style_img, src_img, tgt_img, cfg) 319 | -------------------------------------------------------------------------------- /ezsynth/sequences.py: -------------------------------------------------------------------------------- 1 | class EasySequence: 2 | MODE_FWD = "forward" 3 | MODE_REV = "reverse" 4 | MODE_BLN = "blend" 5 | MODE_NON = "none" 6 | 7 | def __init__( 8 | self, fr_start_idx: int, fr_end_idx: int, mode: str, style_idxs: list[int] 9 | ) -> None: 10 | self.fr_start_idx = fr_start_idx 11 | self.fr_end_idx = fr_end_idx 12 | self.mode = mode 13 | self.style_idxs = style_idxs 14 | 15 | def __repr__(self) -> str: 16 | return f"[{self.fr_start_idx}, {self.fr_end_idx}] {self.mode} {self.style_idxs}" 17 | 18 | @classmethod 19 | def get_valid_modes(cls) -> tuple[str, str, str, str]: 20 | return ( 21 | cls.MODE_FWD, 22 | cls.MODE_REV, 23 | cls.MODE_BLN, 24 | cls.MODE_NON, 25 | ) 26 | 27 | 28 | class SequenceManager: 29 | def __init__(self, begin_fr_idx, end_fr_idx, num_style_frs, style_idxs, img_idxs): 30 | 
self.begin_fr_idx = begin_fr_idx 31 | self.end_fr_idx = end_fr_idx 32 | self.style_idxs = style_idxs 33 | self.img_idxs = img_idxs 34 | self.num_style_frs = num_style_frs 35 | 36 | def create_sequences(self) -> tuple[list[EasySequence], list[str]]: 37 | sequences = [] 38 | atlas: list[str] = [] 39 | 40 | # Handle sequence before the first style frame 41 | if self.begin_fr_idx < self.style_idxs[0]: 42 | sequences.append( 43 | EasySequence( 44 | fr_start_idx=self.begin_fr_idx, 45 | fr_end_idx=self.style_idxs[0], 46 | mode=EasySequence.MODE_REV, 47 | style_idxs=[0], 48 | ) 49 | ) 50 | atlas.append(EasySequence.MODE_REV) 51 | 52 | # Handle sequences between style frames 53 | for i in range(len(self.style_idxs) - 1): 54 | sequences.append( 55 | EasySequence( 56 | fr_start_idx=self.style_idxs[i], 57 | fr_end_idx=self.style_idxs[i + 1], 58 | mode=EasySequence.MODE_BLN, 59 | style_idxs=[i, i + 1], 60 | ) 61 | ) 62 | atlas.append(EasySequence.MODE_BLN) 63 | 64 | # Handle sequence after the last style frame 65 | if self.end_fr_idx > self.style_idxs[-1]: 66 | sequences.append( 67 | EasySequence( 68 | fr_start_idx=self.style_idxs[-1], 69 | fr_end_idx=self.end_fr_idx, 70 | mode=EasySequence.MODE_FWD, 71 | style_idxs=[self.num_style_frs - 1], 72 | ) 73 | ) 74 | atlas.append(EasySequence.MODE_FWD) 75 | 76 | for seq in sequences: 77 | print(f"{seq}") 78 | return sequences, atlas 79 | -------------------------------------------------------------------------------- /ezsynth/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/__init__.py -------------------------------------------------------------------------------- /ezsynth/utils/_eb.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from ctypes import ( 3 | CDLL, 4 | POINTER, 5 | c_float, 6 | c_int, 7 | c_void_p, 8 | create_string_buffer, 9 | ) 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | 14 | 15 | class EbsynthRunner: 16 | EBSYNTH_BACKEND_CPU = 0x0001 17 | EBSYNTH_BACKEND_CUDA = 0x0002 18 | EBSYNTH_BACKEND_AUTO = 0x0000 19 | EBSYNTH_MAX_STYLE_CHANNELS = 8 20 | EBSYNTH_MAX_GUIDE_CHANNELS = 24 21 | EBSYNTH_VOTEMODE_PLAIN = 0x0001 # weight = 1 22 | EBSYNTH_VOTEMODE_WEIGHTED = 0x0002 # weight = 1/(1+error) 23 | 24 | def __init__(self): 25 | self.libebsynth = None 26 | self.cached_buffer = {} 27 | self.cached_err_buffer = {} 28 | 29 | def initialize_libebsynth(self): 30 | if self.libebsynth is None: 31 | if sys.platform[0:3] == "win": 32 | libebsynth_path = str(Path(__file__).parent / "ebsynth.dll") 33 | self.libebsynth = CDLL(libebsynth_path) 34 | # elif sys.platform == "darwin": 35 | # libebsynth_path = str(Path(__file__).parent / "ebsynth.so") 36 | # self.libebsynth = CDLL(libebsynth_path) 37 | elif sys.platform[0:5] == "linux": 38 | libebsynth_path = str(Path(__file__).parent / "ebsynth.so") 39 | self.libebsynth = CDLL(libebsynth_path) 40 | else: 41 | raise RuntimeError("Unsupported platform.") 42 | 43 | if self.libebsynth is not None: 44 | self.libebsynth.ebsynthRun.argtypes = ( 45 | c_int, 46 | c_int, 47 | c_int, 48 | c_int, 49 | c_int, 50 | c_void_p, 51 | c_void_p, 52 | c_int, 53 | c_int, 54 | c_void_p, 55 | c_void_p, 56 | POINTER(c_float), 57 | POINTER(c_float), 58 | c_float, 59 | c_int, 60 | c_int, 61 | c_int, 62 | POINTER(c_int), 63 | POINTER(c_int), 64 | POINTER(c_int), 65 | c_int, 66 | c_void_p, 67 | c_void_p, 68 | 
c_void_p, 69 | ) 70 | pass 71 | 72 | def get_or_create_buffer(self, key): 73 | buffer = self.cached_buffer.get(key, None) 74 | if buffer is None: 75 | buffer = create_string_buffer(key[0] * key[1] * key[2]) 76 | self.cached_buffer[key] = buffer 77 | return buffer 78 | 79 | def get_or_create_err_buffer(self, key): 80 | errbuffer = self.cached_err_buffer.get(key, None) 81 | if errbuffer is None: 82 | errbuffer = (c_float * (key[0] * key[1]))() 83 | self.cached_err_buffer[key] = errbuffer 84 | return errbuffer 85 | 86 | # def _normalize_img_shape(self, img: np.ndarray) -> np.ndarray: 87 | # # with self.normalize_lock: 88 | # img_len = len(img.shape) 89 | # if img_len == 2: 90 | # sh, sw = img.shape 91 | # sc = 0 92 | # elif img_len == 3: 93 | # sh, sw, sc = img.shape 94 | 95 | # if sc == 0: 96 | # sc = 1 97 | 98 | # return img 99 | 100 | # def _normalize_img_shape(self, img: np.ndarray) -> np.ndarray: 101 | # if len(img.shape) == 2: 102 | # img = img[..., np.newaxis] 103 | # return img 104 | 105 | def _normalize_img_shape(self, img: np.ndarray) -> np.ndarray: 106 | return np.atleast_3d(img) 107 | 108 | def validate_inputs(self, patch_size: int, guides: list): 109 | # Validation checks 110 | if patch_size < 3: 111 | raise ValueError("patch_size is too small") 112 | if patch_size % 2 == 0: 113 | raise ValueError("patch_size must be an odd number") 114 | if len(guides) == 0: 115 | raise ValueError("at least one guide must be specified") 116 | 117 | def run( 118 | self, 119 | img_style, 120 | guides, 121 | patch_size=5, 122 | num_pyramid_levels=-1, 123 | num_search_vote_iters=6, 124 | num_patch_match_iters=4, 125 | stop_threshold=5, 126 | uniformity_weight=3500.0, 127 | extraPass3x3=False, 128 | ): 129 | self.validate_inputs(patch_size, guides) 130 | 131 | # Initialize libebsynth if not already done 132 | # self.initialize_libebsynth() 133 | 134 | img_style = self._normalize_img_shape(img_style) 135 | sh, sw, sc = img_style.shape 136 | t_h, t_w, t_c = 0, 0, 0 137 | 138 | self.validate_style_channels(sc) 139 | 140 | guides_source = [] 141 | guides_target = [] 142 | guides_weights = [] 143 | 144 | t_h, t_w = self.validate_guides( 145 | guides, sh, sw, t_c, guides_source, guides_target, guides_weights 146 | ) 147 | 148 | guides_source = np.concatenate(guides_source, axis=-1) 149 | guides_target = np.concatenate(guides_target, axis=-1) 150 | guides_weights = (c_float * len(guides_weights))(*guides_weights) 151 | 152 | style_weights = [1.0 / sc for _ in range(sc)] 153 | style_weights = (c_float * sc)(*style_weights) 154 | 155 | maxPyramidLevels = self.get_max_pyramid_level(patch_size, sh, sw, t_h, t_w) 156 | 157 | ( 158 | num_pyramid_levels, 159 | num_search_vote_iters_per_level, 160 | num_patch_match_iters_per_level, 161 | stop_threshold_per_level, 162 | ) = self.validate_per_levels( 163 | num_pyramid_levels, 164 | num_search_vote_iters, 165 | num_patch_match_iters, 166 | stop_threshold, 167 | maxPyramidLevels, 168 | ) 169 | 170 | # Get or create buffers 171 | buffer = self.get_or_create_buffer((t_h, t_w, sc)) 172 | errbuffer = self.get_or_create_err_buffer((t_h, t_w)) 173 | 174 | self.libebsynth.ebsynthRun( 175 | self.EBSYNTH_BACKEND_AUTO, # backend 176 | sc, # numStyleChannels 177 | guides_source.shape[-1], # numGuideChannels 178 | sw, # sourceWidth 179 | sh, # sourceHeight 180 | img_style.tobytes(), # sourceStyleData (width * height * numStyleChannels) bytes, scan-line order 181 | guides_source.tobytes(), # sourceGuideData (width * height * numGuideChannels) bytes, scan-line order 182 | t_w, # 
targetWidth 183 | t_h, # targetHeight 184 | guides_target.tobytes(), # targetGuideData (width * height * numGuideChannels) bytes, scan-line order 185 | None, # targetModulationData (width * height * numGuideChannels) bytes, scan-line order; pass NULL to switch off the modulation 186 | style_weights, # styleWeights (numStyleChannels) floats 187 | guides_weights, # guideWeights (numGuideChannels) floats 188 | uniformity_weight, # uniformityWeight reasonable values are between 500-15000, 3500 is a good default 189 | patch_size, # patchSize odd sizes only, use 5 for 5x5 patch, 7 for 7x7, etc. 190 | self.EBSYNTH_VOTEMODE_WEIGHTED, # voteMode use VOTEMODE_WEIGHTED for sharper result 191 | num_pyramid_levels, # numPyramidLevels 192 | num_search_vote_iters_per_level, # numSearchVoteItersPerLevel how many search/vote iters to perform at each level (array of ints, coarse first, fine last) 193 | num_patch_match_iters_per_level, # numPatchMatchItersPerLevel how many Patch-Match iters to perform at each level (array of ints, coarse first, fine last) 194 | stop_threshold_per_level, # stopThresholdPerLevel stop improving pixel when its change since last iteration falls under this threshold 195 | 1 196 | if extraPass3x3 197 | else 0, # extraPass3x3 perform additional polishing pass with 3x3 patches at the finest level, use 0 to disable 198 | None, # outputNnfData (width * height * 2) ints, scan-line order; pass NULL to ignore 199 | buffer, # outputImageData (width * height * numStyleChannels) bytes, scan-line order 200 | errbuffer, # outputErrorData (width * height) floats, scan-line order; pass NULL to ignore 201 | ) 202 | 203 | img = np.frombuffer(buffer, dtype=np.uint8).reshape((t_h, t_w, sc)).copy() 204 | err = np.frombuffer(errbuffer, dtype=np.float32).reshape((t_h, t_w)).copy() 205 | 206 | return img, err 207 | 208 | def get_max_pyramid_level(self, patch_size, sh, sw, t_h, t_w): 209 | maxPyramidLevels = 0 210 | min_a = min(sh, t_h) 211 | min_b = min(sw, t_w) 212 | for level in range(32, -1, -1): 213 | pow_a = pow(2.0, -level) 214 | if min(min_a * pow_a, min_b * pow_a) >= (2 * patch_size + 1): 215 | maxPyramidLevels = level + 1 216 | break 217 | return maxPyramidLevels 218 | 219 | def validate_per_levels( 220 | self, 221 | num_pyramid_levels, 222 | num_search_vote_iters, 223 | num_patch_match_iters, 224 | stop_threshold, 225 | maxPyramidLevels, 226 | ): 227 | if num_pyramid_levels == -1: 228 | num_pyramid_levels = maxPyramidLevels 229 | num_pyramid_levels = min(num_pyramid_levels, maxPyramidLevels) 230 | 231 | num_search_vote_iters_per_level = (c_int * num_pyramid_levels)( 232 | *[num_search_vote_iters] * num_pyramid_levels 233 | ) 234 | num_patch_match_iters_per_level = (c_int * num_pyramid_levels)( 235 | *[num_patch_match_iters] * num_pyramid_levels 236 | ) 237 | stop_threshold_per_level = (c_int * num_pyramid_levels)( 238 | *[stop_threshold] * num_pyramid_levels 239 | ) 240 | 241 | return ( 242 | num_pyramid_levels, 243 | num_search_vote_iters_per_level, 244 | num_patch_match_iters_per_level, 245 | stop_threshold_per_level, 246 | ) 247 | 248 | def validate_style_channels(self, sc): 249 | if sc > self.EBSYNTH_MAX_STYLE_CHANNELS: 250 | raise ValueError( 251 | f"error: too many style channels {sc}, maximum number is {self.EBSYNTH_MAX_STYLE_CHANNELS}" 252 | ) 253 | 254 | def validate_guides( 255 | self, guides, sh, sw, t_c, guides_source, guides_target, guides_weights 256 | ): 257 | for i in range(len(guides)): 258 | source_guide, target_guide, guide_weight = guides[i] 259 | source_guide = 
self._normalize_img_shape(source_guide) 260 | target_guide = self._normalize_img_shape(target_guide) 261 | s_h, s_w, s_c = source_guide.shape 262 | nt_h, nt_w, nt_c = target_guide.shape 263 | 264 | if s_h != sh or s_w != sw: 265 | raise ValueError( 266 | "guide source resolution must match the style resolution." 267 | ) 268 | 269 | if t_c == 0: 270 | t_h, t_w, t_c = nt_h, nt_w, nt_c 271 | elif nt_h != t_h or nt_w != t_w: 272 | raise ValueError("guide target resolutions must all be equal.") 273 | 274 | if s_c != nt_c: 275 | raise ValueError("guide source and target channels must match exactly.") 276 | 277 | guides_source.append(source_guide) 278 | guides_target.append(target_guide) 279 | 280 | guides_weights.extend([guide_weight / s_c] * s_c) 281 | return t_h, t_w 282 | -------------------------------------------------------------------------------- /ezsynth/utils/_ebsynth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ._eb import EbsynthRunner 4 | 5 | 6 | class ebsynth: 7 | """ 8 | EBSynth class provides a wrapper around the ebsynth style transfer method. 9 | 10 | Usage: 11 | eb = ebsynth() 12 | result_img, err = eb.run(style_img, guides=[(source_img, target_img, 1.0)]) 13 | """ 14 | 15 | def __init__( 16 | self, 17 | uniformity=3500.0, 18 | patchsize=5, 19 | pyramidlevels=6, 20 | searchvoteiters=12, 21 | patchmatchiters=6, 22 | extrapass3x3=True, 23 | backend="auto", 24 | ): 25 | """ 26 | Initialize the EBSynth wrapper. 27 | :param style: the style image (numpy array); passed to run(), not to the constructor. 28 | :param guides: list of (source, target, weight) guide tuples; also passed to run(). 29 | :param weight: weight for a guide pair, carried as the third element of its tuple. Defaults to 1.0. 30 | :param uniformity: uniformity weight for the style transfer. Defaults to 3500.0. 31 | :param patchsize: size of the patches. Must be an odd number. Defaults to 5. [5x5 patches] 32 | :param pyramidlevels: number of pyramid levels. Larger values are useful for things like color transfer. Defaults to 6. 33 | :param searchvoteiters: number of search/vote iterations. Defaults to 12. 34 | :param patchmatchiters: number of Patch-Match iterations. Defaults to 6. 35 | :param extrapass3x3: whether to perform an extra pass with 3x3 patches. Defaults to True. 36 | :param backend: backend to use ('cpu', 'cuda', or 'auto'). Defaults to 'auto'. 
37 | """ 38 | 39 | self.runner = EbsynthRunner() 40 | self.uniformity = uniformity 41 | self.patchsize = patchsize 42 | self.pyramidlevels = pyramidlevels 43 | self.searchvoteiters = searchvoteiters 44 | self.patchmatchiters = patchmatchiters 45 | self.extrapass3x3 = extrapass3x3 46 | 47 | # Define backend constants 48 | self.backends = { 49 | "cpu": EbsynthRunner.EBSYNTH_BACKEND_CPU, 50 | "cuda": EbsynthRunner.EBSYNTH_BACKEND_CUDA, 51 | "auto": EbsynthRunner.EBSYNTH_BACKEND_AUTO, 52 | } 53 | self.backend = self.backends[backend] 54 | 55 | def run(self, style: np.ndarray, guides: list[tuple[np.ndarray, np.ndarray, np.ndarray]]): 56 | # Call the run function with the provided arguments 57 | img, err = self.runner.run( 58 | style, 59 | guides, 60 | patch_size=self.patchsize, 61 | num_pyramid_levels=self.pyramidlevels, 62 | num_search_vote_iters=self.searchvoteiters, 63 | num_patch_match_iters=self.patchmatchiters, 64 | uniformity_weight=self.uniformity, 65 | extraPass3x3=self.extrapass3x3, 66 | ) 67 | 68 | return img, err 69 | -------------------------------------------------------------------------------- /ezsynth/utils/blend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/blend/__init__.py -------------------------------------------------------------------------------- /ezsynth/utils/blend/blender.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import tqdm 5 | 6 | from ..flow_utils.warp import Warp 7 | from .histogram_blend import hist_blender 8 | from .reconstruction import reconstructor 9 | 10 | try: 11 | from .cupy_accelerated import hist_blend_cupy 12 | 13 | USE_GPU = True 14 | except ImportError as e: 15 | print(f"Cupy is not installed. Revert to CPU. 
{e}") 16 | USE_GPU = False 17 | 18 | 19 | class Blend: 20 | def __init__( 21 | self, 22 | use_gpu=False, 23 | use_lsqr=True, 24 | use_poisson_cupy=False, 25 | poisson_maxiter=None, 26 | ): 27 | self.prev_mask = None 28 | 29 | self.use_gpu = use_gpu and USE_GPU 30 | self.use_lsqr = use_lsqr 31 | self.use_poisson_cupy = use_poisson_cupy 32 | self.poisson_maxiter = poisson_maxiter 33 | 34 | def _warping_masks( 35 | self, 36 | sample_fr: np.ndarray, 37 | flow_fwd: list[np.ndarray], 38 | err_masks: list[np.ndarray], 39 | ): 40 | # use err_masks with flow to create final err_masks 41 | warped_masks = [] 42 | warp = Warp(sample_fr) 43 | 44 | for i in tqdm.tqdm(range(len(err_masks)), desc="Warping masks"): 45 | if self.prev_mask is None: 46 | self.prev_mask = np.zeros_like(err_masks[0]) 47 | warped_mask = warp.run_warping( 48 | err_masks[i], flow_fwd[i] if i == 0 else flow_fwd[i - 1] 49 | ) 50 | 51 | z_hat = warped_mask.copy() 52 | # If the shapes are not compatible, we can adjust the shape of self.prev_mask 53 | if self.prev_mask.shape != z_hat.shape: 54 | self.prev_mask = np.repeat( 55 | self.prev_mask[:, :, np.newaxis], z_hat.shape[2], axis=2 56 | ) 57 | 58 | z_hat = np.where((self.prev_mask > 1) & (z_hat == 0), 1, z_hat) 59 | 60 | self.prev_mask = z_hat.copy() 61 | warped_masks.append(z_hat.copy()) 62 | return warped_masks 63 | 64 | def _create_selection_mask( 65 | self, err_forward_lst: list[np.ndarray], err_backward_lst: list[np.ndarray] 66 | ) -> list[np.ndarray]: 67 | err_forward = np.array(err_forward_lst) 68 | err_backward = np.array(err_backward_lst) 69 | 70 | if err_forward.shape != err_backward.shape: 71 | print(f"Shape mismatch: {err_forward.shape=} vs {err_backward.shape=}") 72 | return [] 73 | 74 | # Create a binary mask where the forward error metric 75 | # is less than the backward error metric 76 | selection_masks = np.where(err_forward < err_backward, 0, 1).astype(np.uint8) 77 | 78 | # Convert numpy array back to list 79 | selection_masks_lst = [ 80 | selection_masks[i] for i in range(selection_masks.shape[0]) 81 | ] 82 | 83 | return selection_masks_lst 84 | 85 | def _hist_blend( 86 | self, 87 | style_fwd: list[np.ndarray], 88 | style_bwd: list[np.ndarray], 89 | err_masks: list[np.ndarray], 90 | ) -> list[np.ndarray]: 91 | st = time.time() 92 | hist_blends: list[np.ndarray] = [] 93 | for i in tqdm.tqdm(range(len(err_masks)), desc="Hist blending: "): 94 | if self.use_gpu: 95 | hist_blend = hist_blend_cupy( 96 | style_fwd[i], 97 | style_bwd[i], 98 | err_masks[i], 99 | ) 100 | else: 101 | hist_blend = hist_blender( 102 | style_fwd[i], 103 | style_bwd[i], 104 | err_masks[i], 105 | ) 106 | hist_blends.append(hist_blend) 107 | print(f"Hist Blend took {time.time() - st:.4f} s") 108 | print(len(hist_blends)) 109 | return hist_blends 110 | 111 | def _reconstruct( 112 | self, 113 | style_fwd: list[np.ndarray], 114 | style_bwd: list[np.ndarray], 115 | err_masks: list[np.ndarray], 116 | hist_blends: list[np.ndarray], 117 | ): 118 | blends = reconstructor( 119 | hist_blends, 120 | style_fwd, 121 | style_bwd, 122 | err_masks, 123 | use_gpu=self.use_gpu, 124 | use_lsqr=self.use_lsqr, 125 | use_poisson_cupy=self.use_poisson_cupy, 126 | poisson_maxiter=self.poisson_maxiter, 127 | ) 128 | final_blends = blends._create() 129 | final_blends = [blend for blend in final_blends if blend is not None] 130 | return final_blends 131 | -------------------------------------------------------------------------------- /ezsynth/utils/blend/cupy_accelerated.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cupy as cp 4 | import cupyx 5 | import cupyx.scipy.sparse._csc 6 | import cupyx.scipy.sparse.linalg 7 | import cv2 8 | import numpy as np 9 | import scipy.sparse 10 | 11 | 12 | def assemble_min_error_img(a, b, error_mask): 13 | return cp.where(error_mask == 0, a, b) 14 | 15 | 16 | def mean_std(img): 17 | return cp.mean(img, axis=(0, 1)), cp.std(img, axis=(0, 1)) 18 | 19 | 20 | def histogram_transform(img, means, stds, target_means, target_stds): 21 | return ((img - means) * target_stds / stds + target_means).astype(cp.float32) 22 | 23 | 24 | def hist_blend_cupy( 25 | a: np.ndarray, b: np.ndarray, error_mask: np.ndarray, weight1=0.5, weight2=0.5 26 | ): 27 | a = cp.asarray(a) 28 | b = cp.asarray(b) 29 | error_mask = cp.asarray(error_mask) 30 | 31 | # Ensure error_mask has 3 channels 32 | if len(error_mask.shape) == 2: 33 | error_mask = cp.repeat(error_mask[:, :, cp.newaxis], 3, axis=2) 34 | 35 | # Convert to Lab color space 36 | a_lab = cv2.cvtColor(cp.asnumpy(a), cv2.COLOR_BGR2Lab) 37 | b_lab = cv2.cvtColor(cp.asnumpy(b), cv2.COLOR_BGR2Lab) 38 | a_lab = cp.asarray(a_lab) 39 | b_lab = cp.asarray(b_lab) 40 | 41 | min_error_lab = assemble_min_error_img(a_lab, b_lab, error_mask) 42 | 43 | # Compute means and stds 44 | a_mean, a_std = mean_std(a_lab) 45 | b_mean, b_std = mean_std(b_lab) 46 | min_error_mean, min_error_std = mean_std(min_error_lab) 47 | 48 | # Histogram transformation constants 49 | t_mean = cp.full(3, 0.5 * 256, dtype=cp.float32) 50 | t_std = cp.full(3, (1 / 36) * 256, dtype=cp.float32) 51 | 52 | # Histogram transform 53 | a_lab = histogram_transform(a_lab, a_mean, a_std, t_mean, t_std) 54 | b_lab = histogram_transform(b_lab, b_mean, b_std, t_mean, t_std) 55 | 56 | # Blending 57 | ab_lab = (a_lab * weight1 + b_lab * weight2 - 128) / 0.5 + 128 58 | ab_mean, ab_std = mean_std(ab_lab) 59 | 60 | # Final histogram transform 61 | ab_lab = histogram_transform(ab_lab, ab_mean, ab_std, min_error_mean, min_error_std) 62 | 63 | ab_lab = cp.clip(cp.round(ab_lab), 0, 255).astype(cp.uint8) 64 | ab_lab = cp.asnumpy(ab_lab) 65 | ab = cv2.cvtColor(ab_lab, cv2.COLOR_Lab2BGR) 66 | return ab 67 | 68 | 69 | def construct_A_cupy(h: int, w: int, grad_weight: list[float], use_poisson_cupy=False): 70 | st = time.time() 71 | indgx_x = cp.zeros(2 * (h - 1) * w, dtype=int) 72 | indgx_y = cp.zeros(2 * (h - 1) * w, dtype=int) 73 | vdx = cp.ones(2 * (h - 1) * w) 74 | 75 | indgy_x = cp.zeros(2 * h * (w - 1), dtype=int) 76 | indgy_y = cp.zeros(2 * h * (w - 1), dtype=int) 77 | vdy = cp.ones(2 * h * (w - 1)) 78 | 79 | indgx_x[::2] = cp.arange((h - 1) * w) 80 | indgx_y[::2] = indgx_x[::2] 81 | indgx_x[1::2] = indgx_x[::2] 82 | indgx_y[1::2] = indgx_x[::2] + w 83 | 84 | indgy_x[::2] = cp.arange(h * (w - 1)) 85 | indgy_y[::2] = indgy_x[::2] 86 | indgy_x[1::2] = indgy_x[::2] 87 | indgy_y[1::2] = indgy_x[::2] + 1 88 | 89 | vdx[1::2] = -1 90 | vdy[1::2] = -1 91 | 92 | Ix = cupyx.scipy.sparse.eye(h * w, format="csc") 93 | Gx = cupyx.scipy.sparse.coo_matrix( 94 | (vdx, (indgx_x, indgx_y)), shape=(h * w, h * w) 95 | ).tocsc() 96 | Gy = cupyx.scipy.sparse.coo_matrix( 97 | (vdy, (indgy_x, indgy_y)), shape=(h * w, h * w) 98 | ).tocsc() 99 | 100 | As = [ 101 | cupyx.scipy.sparse.vstack([Gx * weight, Gy * weight, Ix]) 102 | for weight in grad_weight 103 | ] 104 | print(f"Constructing As took {time.time() - st:.4f} s") 105 | if not use_poisson_cupy: 106 | As_scipy = [ 107 | scipy.sparse.vstack( 108 | [ 109 | 
scipy.sparse.csr_matrix(Gx.get() * weight), 110 | scipy.sparse.csr_matrix(Gy.get() * weight), 111 | scipy.sparse.csr_matrix(Ix.get()), 112 | ] 113 | ) 114 | for weight in grad_weight 115 | ] 116 | return As_scipy 117 | return As 118 | 119 | 120 | def poisson_fusion_cupy( 121 | blendI: np.ndarray, 122 | I1: np.ndarray, 123 | I2: np.ndarray, 124 | mask: np.ndarray, 125 | As: list[cupyx.scipy.sparse._csc.csc_matrix], 126 | poisson_maxiter=None, 127 | ): 128 | grad_weight = [2.5, 0.5, 0.5] 129 | Iab = cv2.cvtColor(blendI, cv2.COLOR_BGR2LAB).astype(float) 130 | Ia = cv2.cvtColor(I1, cv2.COLOR_BGR2LAB).astype(float) 131 | Ib = cv2.cvtColor(I2, cv2.COLOR_BGR2LAB).astype(float) 132 | 133 | Iab_cp = cp.asarray(Iab) 134 | Ia_cp = cp.asarray(Ia) 135 | Ib_cp = cp.asarray(Ib) 136 | mask_cp = cp.asarray(mask) 137 | 138 | m_cp = (mask_cp > 0).astype(float)[..., cp.newaxis] 139 | h, w, c = Iab.shape 140 | 141 | gx_cp = cp.zeros_like(Ia_cp) 142 | gy_cp = cp.zeros_like(Ia_cp) 143 | 144 | gx_cp[:-1] = (Ia_cp[:-1] - Ia_cp[1:]) * (1 - m_cp[:-1]) + ( 145 | Ib_cp[:-1] - Ib_cp[1:] 146 | ) * m_cp[:-1] 147 | gy_cp[:, :-1] = (Ia_cp[:, :-1] - Ia_cp[:, 1:]) * (1 - m_cp[:, :-1]) + ( 148 | Ib_cp[:, :-1] - Ib_cp[:, 1:] 149 | ) * m_cp[:, :-1] 150 | 151 | final_channels = [ 152 | poisson_fusion_channel_cupy( 153 | Iab_cp, gx_cp, gy_cp, h, w, As, i, grad_weight, maxiter=poisson_maxiter 154 | ) 155 | for i in range(3) 156 | ] 157 | 158 | final = np.clip(np.concatenate(final_channels, axis=2), 0, 255) 159 | return cv2.cvtColor(final.astype(np.uint8), cv2.COLOR_LAB2BGR) 160 | 161 | 162 | def poisson_fusion_channel_cupy( 163 | Iab: cp.ndarray, 164 | gx: cp.ndarray, 165 | gy: cp.ndarray, 166 | h: int, 167 | w: int, 168 | As: list[cupyx.scipy.sparse._csc.csc_matrix], 169 | channel: int, 170 | grad_weight: list[float], 171 | maxiter: int | None = None, 172 | ): 173 | cp.get_default_memory_pool().free_all_blocks() 174 | weight = grad_weight[channel] 175 | im_dx = cp.clip(gx[:, :, channel].reshape(h * w, 1), -100, 100) 176 | im_dy = cp.clip(gy[:, :, channel].reshape(h * w, 1), -100, 100) 177 | im = Iab[:, :, channel].reshape(h * w, 1) 178 | im_mean = im.mean() 179 | im = im - im_mean 180 | A = As[channel] 181 | b = cp.vstack([im_dx * weight, im_dy * weight, im]) 182 | out = cupyx.scipy.sparse.linalg.lsmr(A, b, maxiter=maxiter) 183 | out_im = (out[0] + im_mean).reshape(h, w, 1) 184 | 185 | return cp.asnumpy(out_im) 186 | -------------------------------------------------------------------------------- /ezsynth/utils/blend/histogram_blend.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def hist_blender( 6 | a: np.ndarray, 7 | b: np.ndarray, 8 | error_mask: np.ndarray, 9 | weight1=0.5, 10 | weight2=0.5, 11 | ) -> np.ndarray: 12 | # Ensure error_mask has 3 channels 13 | if len(error_mask.shape) == 2: 14 | error_mask = np.repeat(error_mask[:, :, np.newaxis], 3, axis=2) 15 | 16 | # Convert to Lab color space 17 | a_lab = cv2.cvtColor(a, cv2.COLOR_BGR2Lab) 18 | b_lab = cv2.cvtColor(b, cv2.COLOR_BGR2Lab) 19 | 20 | # Generate min_error_img 21 | min_error_lab = assemble_min_error_img(a_lab, b_lab, error_mask) 22 | 23 | # Compute means and stds 24 | a_mean, a_std = mean_std(a_lab) 25 | b_mean, b_std = mean_std(b_lab) 26 | min_error_mean, min_error_std = mean_std(min_error_lab) 27 | 28 | # Histogram transformation constants 29 | t_mean = np.full(3, 0.5 * 256, dtype=np.float32) 30 | t_std = np.full(3, (1 / 36) * 256, dtype=np.float32) 31 | 32 | # Histogram 
transform 33 | a_lab = histogram_transform(a_lab, a_mean, a_std, t_mean, t_std) 34 | b_lab = histogram_transform(b_lab, b_mean, b_std, t_mean, t_std) 35 | 36 | # Blending 37 | ab_lab = (a_lab * weight1 + b_lab * weight2 - 128) / 0.5 + 128 38 | ab_mean, ab_std = mean_std(ab_lab) 39 | 40 | # Final histogram transform 41 | ab_lab = histogram_transform(ab_lab, ab_mean, ab_std, min_error_mean, min_error_std) 42 | 43 | ab_lab = np.clip(np.round(ab_lab), 0, 255).astype(np.uint8) 44 | 45 | # Convert back to BGR 46 | ab = cv2.cvtColor(ab_lab, cv2.COLOR_Lab2BGR) 47 | 48 | return ab 49 | 50 | 51 | def histogram_transform(img, means, stds, target_means, target_stds): 52 | return ((img - means) * target_stds / stds + target_means).astype(np.float32) 53 | 54 | 55 | def assemble_min_error_img(a, b, error_mask): 56 | return np.where(error_mask == 0, a, b) 57 | 58 | 59 | def mean_std(img): 60 | return np.mean(img, axis=(0, 1)), np.std(img, axis=(0, 1)) 61 | -------------------------------------------------------------------------------- /ezsynth/utils/blend/reconstruction.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cv2 4 | import numpy as np 5 | import scipy.sparse 6 | import scipy.sparse.linalg 7 | import tqdm 8 | 9 | try: 10 | from .cupy_accelerated import construct_A_cupy, poisson_fusion_cupy 11 | 12 | USE_GPU = True 13 | print("Cupy is installed. Can do Cupy GPU accelerations") 14 | except ImportError: 15 | USE_GPU = False 16 | print("Cupy is not installed. Revert to CPU") 17 | 18 | 19 | class reconstructor: 20 | def __init__( 21 | self, 22 | hist_blends: list[np.ndarray], 23 | style_fwd: list[np.ndarray], 24 | style_bwd: list[np.ndarray], 25 | err_masks: list[np.ndarray], 26 | use_gpu=False, 27 | use_lsqr=True, 28 | use_poisson_cupy=False, 29 | poisson_maxiter=None, 30 | ): 31 | self.hist_blends = hist_blends 32 | self.style_fwd = style_fwd 33 | self.style_bwd = style_bwd 34 | self.err_masks = err_masks 35 | self.blends = None 36 | 37 | self.use_gpu = use_gpu and USE_GPU 38 | self.use_lsqr = use_lsqr 39 | self.use_poisson_cupy = self.use_gpu and use_poisson_cupy 40 | self.poisson_maxiter = poisson_maxiter 41 | 42 | def _create(self): 43 | num_blends = len(self.hist_blends) 44 | h, w, c = self.hist_blends[0].shape 45 | self.blends = np.zeros((num_blends, h, w, c)) 46 | 47 | a = construct_A(h, w, [2.5, 0.5, 0.5], self.use_gpu, self.use_poisson_cupy) 48 | for i in tqdm.tqdm(range(num_blends)): 49 | self.blends[i] = poisson_fusion( 50 | self.hist_blends[i], 51 | self.style_fwd[i], 52 | self.style_bwd[i], 53 | self.err_masks[i], 54 | a, 55 | self.use_gpu, 56 | self.use_lsqr, 57 | self.use_poisson_cupy, 58 | self.poisson_maxiter, 59 | ) 60 | 61 | return self.blends 62 | 63 | 64 | def construct_A( 65 | h: int, w: int, grad_weight: list[float], use_gpu=False, use_poisson_cupy=False 66 | ): 67 | if use_gpu: 68 | return construct_A_cupy(h, w, grad_weight, use_poisson_cupy) 69 | return construct_A_cpu(h, w, grad_weight) 70 | 71 | 72 | def poisson_fusion( 73 | blendI: np.ndarray, 74 | I1: np.ndarray, 75 | I2: np.ndarray, 76 | mask: np.ndarray, 77 | As, 78 | use_gpu=False, 79 | use_lsqr=True, 80 | use_poisson_cupy=False, 81 | poisson_maxiter=None, 82 | ): 83 | if use_gpu and use_poisson_cupy: 84 | return poisson_fusion_cupy(blendI, I1, I2, mask, As, poisson_maxiter) 85 | return poisson_fusion_cpu_optimized( 86 | blendI, I1, I2, mask, As, use_lsqr, poisson_maxiter 87 | ) 88 | 89 | 90 | def construct_A_cpu(h: int, w: int, grad_weight: 
list[float]): 91 | st = time.time() 92 | indgx_x = np.zeros(2 * (h - 1) * w, dtype=int) 93 | indgx_y = np.zeros(2 * (h - 1) * w, dtype=int) 94 | vdx = np.ones(2 * (h - 1) * w) 95 | 96 | indgy_x = np.zeros(2 * h * (w - 1), dtype=int) 97 | indgy_y = np.zeros(2 * h * (w - 1), dtype=int) 98 | vdy = np.ones(2 * h * (w - 1)) 99 | 100 | indgx_x[::2] = np.arange((h - 1) * w) 101 | indgx_y[::2] = indgx_x[::2] 102 | indgx_x[1::2] = indgx_x[::2] 103 | indgx_y[1::2] = indgx_x[::2] + w 104 | 105 | indgy_x[::2] = np.arange(h * (w - 1)) 106 | indgy_y[::2] = indgy_x[::2] 107 | indgy_x[1::2] = indgy_x[::2] 108 | indgy_y[1::2] = indgy_x[::2] + 1 109 | 110 | vdx[1::2] = -1 111 | vdy[1::2] = -1 112 | 113 | Ix = scipy.sparse.eye(h * w, format="csc") 114 | Gx = scipy.sparse.coo_matrix( 115 | (vdx, (indgx_x, indgx_y)), shape=(h * w, h * w) 116 | ).tocsc() 117 | Gy = scipy.sparse.coo_matrix( 118 | (vdy, (indgy_x, indgy_y)), shape=(h * w, h * w) 119 | ).tocsc() 120 | 121 | As = [scipy.sparse.vstack([Gx * weight, Gy * weight, Ix]) for weight in grad_weight] 122 | print(f"Constructing As took {time.time() - st:.4f} s") 123 | return As 124 | 125 | 126 | def poisson_fusion_cpu( 127 | blendI: np.ndarray, 128 | I1: np.ndarray, 129 | I2: np.ndarray, 130 | mask: np.ndarray, 131 | As: list[scipy.sparse._csc.csc_matrix], 132 | use_lsqr=True, 133 | ): 134 | grad_weight = [2.5, 0.5, 0.5] 135 | Iab = cv2.cvtColor(blendI, cv2.COLOR_BGR2LAB).astype(float) 136 | Ia = cv2.cvtColor(I1, cv2.COLOR_BGR2LAB).astype(float) 137 | Ib = cv2.cvtColor(I2, cv2.COLOR_BGR2LAB).astype(float) 138 | 139 | m = (mask > 0).astype(float)[..., np.newaxis] 140 | h, w, c = Iab.shape 141 | 142 | gx = np.zeros_like(Ia) 143 | gy = np.zeros_like(Ia) 144 | 145 | gx[:-1] = (Ia[:-1] - Ia[1:]) * (1 - m[:-1]) + (Ib[:-1] - Ib[1:]) * m[:-1] 146 | gy[:, :-1] = (Ia[:, :-1] - Ia[:, 1:]) * (1 - m[:, :-1]) + ( 147 | Ib[:, :-1] - Ib[:, 1:] 148 | ) * m[:, :-1] 149 | 150 | final_channels = [ 151 | poisson_fusion_channel_cpu(Iab, gx, gy, h, w, As, i, grad_weight, use_lsqr) 152 | for i in range(3) 153 | ] 154 | 155 | final = np.clip(np.concatenate(final_channels, axis=2), 0, 255) 156 | return cv2.cvtColor(final.astype(np.uint8), cv2.COLOR_LAB2BGR) 157 | 158 | 159 | def poisson_fusion_channel_cpu( 160 | Iab: np.ndarray, 161 | gx: np.ndarray, 162 | gy: np.ndarray, 163 | h: int, 164 | w: int, 165 | As: list[scipy.sparse._csc.csc_matrix], 166 | channel: int, 167 | grad_weight: list[float], 168 | use_lsqr=True, 169 | ): 170 | """Helper function to perform Poisson fusion on a single channel.""" 171 | weight = grad_weight[channel] 172 | im_dx = np.clip(gx[:, :, channel].reshape(h * w, 1), -100, 100) 173 | im_dy = np.clip(gy[:, :, channel].reshape(h * w, 1), -100, 100) 174 | im = Iab[:, :, channel].reshape(h * w, 1) 175 | im_mean = im.mean() 176 | im = im - im_mean 177 | 178 | A = As[channel] 179 | b = np.vstack([im_dx * weight, im_dy * weight, im]) 180 | if use_lsqr: 181 | out = scipy.sparse.linalg.lsqr(A, b) 182 | else: 183 | out = scipy.sparse.linalg.lsmr(A, b) 184 | out_im = (out[0] + im_mean).reshape(h, w, 1) 185 | 186 | return out_im 187 | 188 | 189 | def gradient_compute_python(Ia: np.ndarray, Ib: np.ndarray, m: np.ndarray): 190 | gx = np.zeros_like(Ia) 191 | gy = np.zeros_like(Ia) 192 | 193 | gx[:-1] = (Ia[:-1] - Ia[1:]) * (1 - m[:-1]) + (Ib[:-1] - Ib[1:]) * m[:-1] 194 | gy[:, :-1] = (Ia[:, :-1] - Ia[:, 1:]) * (1 - m[:, :-1]) + ( 195 | Ib[:, :-1] - Ib[:, 1:] 196 | ) * m[:, :-1] 197 | return gx, gy 198 | 199 | 200 | def poisson_fusion_cpu_optimized( 201 | blendI: np.ndarray, 
202 | I1: np.ndarray, 203 | I2: np.ndarray, 204 | mask: np.ndarray, 205 | As: list[scipy.sparse._csc.csc_matrix], 206 | use_lsqr=True, 207 | poisson_maxiter=None, 208 | ): 209 | # grad_weight = [2.5, 0.5, 0.5] 210 | grad_weight = np.array([2.5, 0.5, 0.5]) 211 | Iab = cv2.cvtColor(blendI, cv2.COLOR_BGR2LAB).astype(float) 212 | Ia = cv2.cvtColor(I1, cv2.COLOR_BGR2LAB).astype(float) 213 | Ib = cv2.cvtColor(I2, cv2.COLOR_BGR2LAB).astype(float) 214 | 215 | m = (mask > 0).astype(float)[..., np.newaxis] 216 | h, w, c = Iab.shape 217 | 218 | gx, gy = gradient_compute_python(Ia, Ib, m) 219 | 220 | # Reshape and clip all channels at once 221 | gx_reshaped = np.clip(gx.reshape(h * w, c), -100, 100) 222 | gy_reshaped = np.clip(gy.reshape(h * w, c), -100, 100) 223 | 224 | Iab_reshaped = Iab.reshape(h * w, c) 225 | Iab_mean = np.mean(Iab_reshaped, axis=0) 226 | Iab_centered = Iab_reshaped - Iab_mean 227 | 228 | # Pre-allocate the output array 229 | out_all = np.zeros((h * w, c), dtype=np.float32) 230 | 231 | for channel in range(3): 232 | weight = grad_weight[channel] 233 | im_dx = gx_reshaped[:, channel : channel + 1] 234 | im_dy = gy_reshaped[:, channel : channel + 1] 235 | im = Iab_centered[:, channel : channel + 1] 236 | 237 | A = As[channel] 238 | b = np.vstack([im_dx * weight, im_dy * weight, im]) 239 | 240 | if use_lsqr: 241 | out_all[:, channel] = scipy.sparse.linalg.lsqr(A, b)[0] 242 | else: 243 | out_all[:, channel] = scipy.sparse.linalg.lsmr( 244 | A, b, maxiter=poisson_maxiter 245 | )[0] 246 | 247 | # Add back the mean and reshape 248 | final = (out_all + Iab_mean).reshape(h, w, c) 249 | final = np.clip(final, 0, 255) 250 | return cv2.cvtColor(final.astype(np.uint8), cv2.COLOR_LAB2BGR) 251 | -------------------------------------------------------------------------------- /ezsynth/utils/ebsynth.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/ebsynth.dll -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/OpticalFlow.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import tqdm 7 | 8 | from .core.utils.utils import InputPadder 9 | 10 | class RAFT_flow: 11 | def __init__(self, model_name="sintel", arch="RAFT"): 12 | self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | self.arch = arch 15 | 16 | if self.arch == "RAFT": 17 | from .core.raft import RAFT 18 | model_name = f"raft-{model_name}.pth" 19 | model_path = os.path.join(os.path.dirname(__file__), "models", model_name) 20 | 21 | if not os.path.exists(model_path): 22 | raise ValueError(f"[ERROR] Model file '{model_path}' not found.") 23 | 24 | self.model = torch.nn.DataParallel( 25 | RAFT(args=self._instantiate_raft_model(model_name)) 26 | ) 27 | 28 | elif self.arch == "EF_RAFT": 29 | from .core.ef_raft import EF_RAFT 30 | model_name = f"{model_name}.pth" 31 | model_path = os.path.join( 32 | os.path.dirname(__file__), "ef_raft_models", model_name 33 | ) 34 | if not os.path.exists(model_path): 35 | raise ValueError(f"[ERROR] Model file '{model_path}' not found.") 36 | self.model = torch.nn.DataParallel( 37 | EF_RAFT(args=self._instantiate_raft_model(model_name)) 38 | ) 39 | 40 | elif self.arch == "FLOW_DIFF": 41 | try: 42 | from .core.flow_diffusion import FlowDiffuser 43 | except ImportError as e: 44 | 
-------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/OpticalFlow.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import tqdm 7 | 8 | from .core.utils.utils import InputPadder 9 | 10 | class RAFT_flow: 11 | def __init__(self, model_name="sintel", arch="RAFT"): 12 | self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | self.arch = arch 15 | 16 | if self.arch == "RAFT": 17 | from .core.raft import RAFT 18 | model_name = f"raft-{model_name}.pth" 19 | model_path = os.path.join(os.path.dirname(__file__), "models", model_name) 20 | 21 | if not os.path.exists(model_path): 22 | raise ValueError(f"[ERROR] Model file '{model_path}' not found.") 23 | 24 | self.model = torch.nn.DataParallel( 25 | RAFT(args=self._instantiate_raft_model(model_name)) 26 | ) 27 | 28 | elif self.arch == "EF_RAFT": 29 | from .core.ef_raft import EF_RAFT 30 | model_name = f"{model_name}.pth" 31 | model_path = os.path.join( 32 | os.path.dirname(__file__), "ef_raft_models", model_name 33 | ) 34 | if not os.path.exists(model_path): 35 | raise ValueError(f"[ERROR] Model file '{model_path}' not found.") 36 | self.model = torch.nn.DataParallel( 37 | EF_RAFT(args=self._instantiate_raft_model(model_name)) 38 | ) 39 | 40 | elif self.arch == "FLOW_DIFF": 41 | try: 42 | from .core.flow_diffusion import FlowDiffuser 43 | except ImportError as e: 44 | raise ImportError(f"Could not import FlowDiffuser. {e}") 45 | model_name = "FlowDiffuser-things.pth" 46 | model_path = os.path.join(os.path.dirname(__file__), "flow_diffusion_models", model_name) 47 | if not os.path.exists(model_path): 48 | raise ValueError(f"[ERROR] Model file '{model_path}' not found.") 49 | self.model = torch.nn.DataParallel( 50 | FlowDiffuser(args=self._instantiate_raft_model(model_name)) 51 | ) 52 | 53 | else: 54 | # unknown arch values would otherwise leave model_path undefined below 55 | raise ValueError(f"[ERROR] Unknown arch '{self.arch}'. Expected 'RAFT', 'EF_RAFT' or 'FLOW_DIFF'.") 56 | 57 | state_dict = torch.load(model_path, map_location=self.DEVICE) 58 | self.model.load_state_dict(state_dict) 59 | 60 | self.model.to(self.DEVICE) 61 | self.model.eval() 62 | 63 | def _instantiate_raft_model(self, model_name): 64 | from argparse import Namespace 65 | 66 | args = Namespace() 67 | args.model = model_name 68 | args.small = False 69 | args.mixed_precision = False 70 | return args 71 | 72 | def _load_tensor_from_numpy(self, np_array: np.ndarray): 73 | try: 74 | tensor = ( 75 | torch.tensor(np_array, dtype=torch.float32) 76 | .permute(2, 0, 1) 77 | .unsqueeze(0) 78 | .to(self.DEVICE) 79 | ) 80 | return tensor 81 | except Exception as e: 82 | print(f"[ERROR] Exception in load_tensor_from_numpy: {e}") 83 | raise e 84 | 85 | def _compute_flow(self, img1: np.ndarray, img2: np.ndarray): 86 | original_size = img1.shape[1::-1] 87 | with torch.no_grad(): 88 | img1_tensor = self._load_tensor_from_numpy(img1) 89 | img2_tensor = self._load_tensor_from_numpy(img2) 90 | padder = InputPadder(img1_tensor.shape) 91 | images = padder.pad(img1_tensor, img2_tensor) 92 | _, flow_up = self.model(images[0], images[1], iters=20, test_mode=True) 93 | flow_np = padder.unpad(flow_up[0]).permute(1, 2, 0).cpu().numpy() 94 | # cv2.resize is not in-place; keep the result so the flow matches the input size 95 | flow_np = cv2.resize(flow_np, original_size) 96 | return flow_np 97 | 98 | def compute_flow(self, img_frs_seq: list[np.ndarray]): 99 | optical_flow = [] 100 | total_flows = len(img_frs_seq) - 1 101 | for i in tqdm.tqdm(range(total_flows), desc="Calculating Flow: "): 102 | optical_flow.append(self._compute_flow(img_frs_seq[i], img_frs_seq[i + 1])) 103 | self.optical_flow = optical_flow 104 | return self.optical_flow 105 |
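For orientation, a minimal usage sketch (not part of the repo) that drives `RAFT_flow` directly, reusing the bundled example frames and the `raft-sintel.pth` weights listed under `models/`:

```python
# Usage sketch, assuming it is run from the repository root.
import cv2
from ezsynth.utils.flow_utils.OpticalFlow import RAFT_flow

frames = [cv2.imread(f"examples/input/{i:03d}.jpg") for i in range(3)]  # BGR uint8
flow_est = RAFT_flow(model_name="sintel", arch="RAFT")  # loads models/raft-sintel.pth
flows = flow_est.compute_flow(frames)  # len(frames) - 1 arrays of shape (H, W, 2)
```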
-------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/__init__.py -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/alt_cuda_corr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/alt_cuda_corr/__init__.py -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | // CUDA forward declarations 5 | std::vector<torch::Tensor> corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector<torch::Tensor> corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector<torch::Tensor> corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector<torch::Tensor> corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/core/__init__.py -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/corr.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .utils.utils import bilinear_sampler 5 | 6 | try: 7 | import alt_cuda_corr 8 | except ImportError as e: 9 | print(f"alt_cuda_corr is not compiled. Not fatal. {e}") 10 | # alt_cuda_corr is not compiled 11 | pass 12 | 13 | 14 | class CorrBlock: 15 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 16 | self.num_levels = num_levels 17 | self.radius = radius 18 | self.corr_pyramid = [] 19 | 20 | # all pairs correlation 21 | corr = CorrBlock.corr(fmap1, fmap2) 22 | 23 | batch, h1, w1, dim, h2, w2 = corr.shape 24 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 25 | 26 | self.corr_pyramid.append(corr) 27 | for i in range(self.num_levels - 1): 28 | corr = F.avg_pool2d(corr, 2, stride=2) 29 | self.corr_pyramid.append(corr) 30 | 31 | def __call__(self, coords): 32 | r = self.radius 33 | coords = coords.permute(0, 2, 3, 1) 34 | batch, h1, w1, _ = coords.shape 35 | 36 | out_pyramid = [] 37 | for i in range(self.num_levels): 38 | corr = self.corr_pyramid[i] 39 | dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 40 | dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 41 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1) 42 | 43 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i 44 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 45 | coords_lvl = centroid_lvl + delta_lvl 46 | 47 | corr = bilinear_sampler(corr, coords_lvl) 48 | corr = corr.view(batch, h1, w1, -1) 49 | out_pyramid.append(corr) 50 | 51 | out = torch.cat(out_pyramid, dim=-1) 52 | return out.permute(0, 3, 1, 2).contiguous().float() 53 | 54 | @staticmethod 55 | def corr(fmap1, fmap2): 56 | batch, dim, ht, wd = fmap1.shape 57 | fmap1 = fmap1.view(batch, dim, ht * wd) 58 | fmap2 = fmap2.view(batch, dim, ht * wd) 59 | 60 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 61 | corr = corr.view(batch, ht, wd, 1, ht, wd) 62 | return corr / torch.sqrt(torch.tensor(dim).float()) 63 | 64 | 65 | class AlternateCorrBlock: 66 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 67 | self.num_levels = num_levels 68 | self.radius = radius 69 | 70 | self.pyramid = [(fmap1, fmap2)] 71 | for i in range(self.num_levels): 72 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 73 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 74 | self.pyramid.append((fmap1, fmap2)) 75 | 76 | def __call__(self, coords): 77 | coords = coords.permute(0, 2, 3, 1) 78 | B, H, W, _ = coords.shape 79 | dim = self.pyramid[0][0].shape[1] 80 | 81 | corr_list = [] 82 | for i in range(self.num_levels): 83 | r = self.radius 84 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 85 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 86 | 87 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 88 | (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 89 | corr_list.append(corr.squeeze(1)) 90 | 91 | corr = torch.stack(corr_list, dim=1) 92 | corr = corr.reshape(B, -1, H, W) 93 | return corr / torch.sqrt(torch.tensor(dim).float()) 94 | 95 | class EF_CorrBlock: 96 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 97 | self.num_levels = num_levels 98 | self.radius = radius 99 | self.corr_pyramid = [] 100 | 101 | # all pairs correlation 102 | corr = EF_CorrBlock.corr(fmap1, fmap2) 103 | 104 | batch, h1, w1, dim, h2, w2 = corr.shape 105 | self.corr_map = corr.view(batch, h1 * w1, h2 * w2) 106 | corr = corr.reshape(batch*h1*w1, dim, h2, 
w2) 107 | 108 | self.corr_pyramid.append(corr) 109 | for i in range(self.num_levels-1): 110 | corr = F.avg_pool2d(corr, 2, stride=2) 111 | self.corr_pyramid.append(corr) 112 | 113 | def __call__(self, coords, scalers=None): 114 | r = self.radius 115 | 116 | if scalers is not None: 117 | assert(scalers.shape[-1] == 4 and scalers.shape[-2] == self.num_levels) 118 | scalers = scalers.view(-1, 1, scalers.shape[-2], 2, scalers.shape[-1] // 2) 119 | 120 | coords = coords.permute(0, 2, 3, 1) 121 | batch, h1, w1, _ = coords.shape 122 | 123 | out_pyramid = [] 124 | for i in range(self.num_levels): 125 | corr = self.corr_pyramid[i] 126 | centroid_lvl = coords.reshape(batch, h1*w1, 1, 1, 2) / 2**i 127 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 128 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 129 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1) 130 | delta = delta.view(-1, 2) 131 | delta = delta.repeat((batch, 1, 1)) 132 | if scalers is not None: 133 | delta[..., 0] *= scalers[..., i, 0, 0] 134 | delta[..., 1] *= scalers[..., i, 0, 1] 135 | delta[..., 0] += torch.sign(delta[..., 0]) * scalers[..., i, 1, 0] * r 136 | delta[..., 1] += torch.sign(delta[..., 1]) * scalers[..., i, 1, 1] * r 137 | delta_lvl = delta.view(batch, 1, 2*r+1, 2*r+1, 2) 138 | coords_lvl = centroid_lvl + delta_lvl 139 | 140 | coords_lvl = coords_lvl.reshape(-1, coords_lvl.shape[-3], coords_lvl.shape[-2], coords_lvl.shape[-1]) 141 | 142 | corr = bilinear_sampler(corr, coords_lvl) 143 | corr = corr.view(batch, h1, w1, -1) 144 | out_pyramid.append(corr) 145 | 146 | out = torch.cat(out_pyramid, dim=-1) 147 | return out.permute(0, 3, 1, 2).contiguous().float() 148 | 149 | @staticmethod 150 | def corr(fmap1, fmap2): 151 | batch, dim, ht, wd = fmap1.shape 152 | fmap1 = fmap1.view(batch, dim, ht*wd) 153 | fmap2 = fmap2.view(batch, dim, ht*wd) 154 | 155 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 156 | corr = corr.view(batch, ht, wd, 1, ht, wd) 157 | return corr / torch.sqrt(torch.tensor(dim).float()) 158 | 159 | class EF_AlternateCorrBlock: 160 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 161 | self.num_levels = num_levels 162 | self.radius = radius 163 | 164 | self.pyramid = [(fmap1, fmap2)] 165 | for i in range(self.num_levels): 166 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 167 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 168 | self.pyramid.append((fmap1, fmap2)) 169 | 170 | def __call__(self, coords): 171 | coords = coords.permute(0, 2, 3, 1) 172 | B, H, W, _ = coords.shape 173 | dim = self.pyramid[0][0].shape[1] 174 | 175 | corr_list = [] 176 | for i in range(self.num_levels): 177 | r = self.radius 178 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 179 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 180 | 181 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 182 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 183 | corr_list.append(corr.squeeze(1)) 184 | 185 | corr = torch.stack(corr_list, dim=1) 186 | corr = corr.reshape(B, -1, H, W) 187 | return corr / torch.sqrt(torch.tensor(dim).float()) 188 | -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | # import math 4 | import os 5 | import os.path as osp 6 | import random 7 | from glob import glob 8 | 9 | import numpy as np 10 | 
import torch 11 | # import torch.nn.functional as F 12 | import torch.utils.data as data 13 | 14 | from .utils import frame_utils 15 | from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | if self.is_test: 36 | img1 = frame_utils.read_gen(self.image_list[index][0]) 37 | img2 = frame_utils.read_gen(self.image_list[index][1]) 38 | img1 = np.array(img1).astype(np.uint8)[..., :3] 39 | img2 = np.array(img2).astype(np.uint8)[..., :3] 40 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 41 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 42 | return img1, img2, self.extra_info[index] 43 | 44 | if not self.init_seed: 45 | worker_info = torch.utils.data.get_worker_info() 46 | if worker_info is not None: 47 | torch.manual_seed(worker_info.id) 48 | np.random.seed(worker_info.id) 49 | random.seed(worker_info.id) 50 | self.init_seed = True 51 | 52 | index = index % len(self.image_list) 53 | valid = None 54 | if self.sparse: 55 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 56 | else: 57 | flow = frame_utils.read_gen(self.flow_list[index]) 58 | 59 | img1 = frame_utils.read_gen(self.image_list[index][0]) 60 | img2 = frame_utils.read_gen(self.image_list[index][1]) 61 | 62 | flow = np.array(flow).astype(np.float32) 63 | img1 = np.array(img1).astype(np.uint8) 64 | img2 = np.array(img2).astype(np.uint8) 65 | 66 | # grayscale images 67 | if len(img1.shape) == 2: 68 | img1 = np.tile(img1[..., None], (1, 1, 3)) 69 | img2 = np.tile(img2[..., None], (1, 1, 3)) 70 | else: 71 | img1 = img1[..., :3] 72 | img2 = img2[..., :3] 73 | 74 | if self.augmentor is not None: 75 | if self.sparse: 76 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 77 | else: 78 | img1, img2, flow = self.augmentor(img1, img2, flow) 79 | 80 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 81 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 82 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 83 | 84 | if valid is not None: 85 | valid = torch.from_numpy(valid) 86 | else: 87 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 88 | 89 | return img1, img2, flow, valid.float() 90 | 91 | def __rmul__(self, v): 92 | self.flow_list = v * self.flow_list 93 | self.image_list = v * self.image_list 94 | return self 95 | 96 | def __len__(self): 97 | return len(self.image_list) 98 | 99 | 100 | class MpiSintel(FlowDataset): 101 | def __init__( 102 | self, aug_params=None, split="training", root="datasets/Sintel", dstype="clean" 103 | ): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, "flow") 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == "test": 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, "*.png"))) 113 | for i in range(len(image_list) - 1): 114 | self.image_list += [[image_list[i], image_list[i + 1]]] 115 | self.extra_info += [(scene, i)] # scene and frame_id 116 | 117 | if split != "test": 118 | self.flow_list += 
sorted(glob(osp.join(flow_root, scene, "*.flo"))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__( 123 | self, aug_params=None, split="train", root="datasets/FlyingChairs_release/data" 124 | ): 125 | super(FlyingChairs, self).__init__(aug_params) 126 | 127 | images = sorted(glob(osp.join(root, "*.ppm"))) 128 | flows = sorted(glob(osp.join(root, "*.flo"))) 129 | assert len(images) // 2 == len(flows) 130 | 131 | split_list = np.loadtxt("chairs_split.txt", dtype=np.int32) 132 | for i in range(len(flows)): 133 | xid = split_list[i] 134 | if (split == "training" and xid == 1) or ( 135 | split == "validation" and xid == 2 136 | ): 137 | self.flow_list += [flows[i]] 138 | self.image_list += [[images[2 * i], images[2 * i + 1]]] 139 | 140 | 141 | class FlyingThings3D(FlowDataset): 142 | def __init__( 143 | self, aug_params=None, root="datasets/FlyingThings3D", dstype="frames_cleanpass" 144 | ): 145 | super(FlyingThings3D, self).__init__(aug_params) 146 | 147 | for cam in ["left"]: 148 | for direction in ["into_future", "into_past"]: 149 | image_dirs = sorted(glob(osp.join(root, dstype, "TRAIN/*/*"))) 150 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 151 | 152 | flow_dirs = sorted(glob(osp.join(root, "optical_flow/TRAIN/*/*"))) 153 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 154 | 155 | for idir, fdir in zip(image_dirs, flow_dirs): 156 | images = sorted(glob(osp.join(idir, "*.png"))) 157 | flows = sorted(glob(osp.join(fdir, "*.pfm"))) 158 | for i in range(len(flows) - 1): 159 | if direction == "into_future": 160 | self.image_list += [[images[i], images[i + 1]]] 161 | self.flow_list += [flows[i]] 162 | elif direction == "into_past": 163 | self.image_list += [[images[i + 1], images[i]]] 164 | self.flow_list += [flows[i + 1]] 165 | 166 | 167 | class KITTI(FlowDataset): 168 | def __init__(self, aug_params=None, split="training", root="datasets/KITTI"): 169 | super(KITTI, self).__init__(aug_params, sparse=True) 170 | if split == "testing": 171 | self.is_test = True 172 | 173 | root = osp.join(root, split) 174 | images1 = sorted(glob(osp.join(root, "image_2/*_10.png"))) 175 | images2 = sorted(glob(osp.join(root, "image_2/*_11.png"))) 176 | 177 | for img1, img2 in zip(images1, images2): 178 | frame_id = img1.split("/")[-1] 179 | self.extra_info += [[frame_id]] 180 | self.image_list += [[img1, img2]] 181 | 182 | if split == "training": 183 | self.flow_list = sorted(glob(osp.join(root, "flow_occ/*_10.png"))) 184 | 185 | 186 | class HD1K(FlowDataset): 187 | def __init__(self, aug_params=None, root="datasets/HD1k"): 188 | super(HD1K, self).__init__(aug_params, sparse=True) 189 | 190 | seq_ix = 0 191 | while 1: 192 | flows = sorted( 193 | glob(os.path.join(root, "hd1k_flow_gt", "flow_occ/%06d_*.png" % seq_ix)) 194 | ) 195 | images = sorted( 196 | glob(os.path.join(root, "hd1k_input", "image_2/%06d_*.png" % seq_ix)) 197 | ) 198 | 199 | if len(flows) == 0: 200 | break 201 | 202 | for i in range(len(flows) - 1): 203 | self.flow_list += [flows[i]] 204 | self.image_list += [[images[i], images[i + 1]]] 205 | 206 | seq_ix += 1 207 | 208 | 209 | def fetch_dataloader(args, TRAIN_DS="C+T+K+S+H"): 210 | """Create the data loader for the corresponding training set""" 211 | 212 | if args.stage == "chairs": 213 | aug_params = { 214 | "crop_size": args.image_size, 215 | "min_scale": -0.1, 216 | "max_scale": 1.0, 217 | "do_flip": True, 218 | } 219 | train_dataset = FlyingChairs(aug_params, split="training") 220 | 221 | elif args.stage == "things": 222 |
aug_params = { 223 | "crop_size": args.image_size, 224 | "min_scale": -0.4, 225 | "max_scale": 0.8, 226 | "do_flip": True, 227 | } 228 | clean_dataset = FlyingThings3D(aug_params, dstype="frames_cleanpass") 229 | final_dataset = FlyingThings3D(aug_params, dstype="frames_finalpass") 230 | train_dataset = clean_dataset + final_dataset 231 | 232 | elif args.stage == "sintel": 233 | aug_params = { 234 | "crop_size": args.image_size, 235 | "min_scale": -0.2, 236 | "max_scale": 0.6, 237 | "do_flip": True, 238 | } 239 | things = FlyingThings3D(aug_params, dstype="frames_cleanpass") 240 | sintel_clean = MpiSintel(aug_params, split="training", dstype="clean") 241 | sintel_final = MpiSintel(aug_params, split="training", dstype="final") 242 | 243 | if TRAIN_DS == "C+T+K+S+H": 244 | kitti = KITTI( 245 | { 246 | "crop_size": args.image_size, 247 | "min_scale": -0.3, 248 | "max_scale": 0.5, 249 | "do_flip": True, 250 | } 251 | ) 252 | hd1k = HD1K( 253 | { 254 | "crop_size": args.image_size, 255 | "min_scale": -0.5, 256 | "max_scale": 0.2, 257 | "do_flip": True, 258 | } 259 | ) 260 | train_dataset = ( 261 | 100 * sintel_clean 262 | + 100 * sintel_final 263 | + 200 * kitti 264 | + 5 * hd1k 265 | + things 266 | ) 267 | 268 | elif TRAIN_DS == "C+T+K/S": 269 | train_dataset = 100 * sintel_clean + 100 * sintel_final + things 270 | 271 | elif args.stage == "kitti": 272 | aug_params = { 273 | "crop_size": args.image_size, 274 | "min_scale": -0.2, 275 | "max_scale": 0.4, 276 | "do_flip": False, 277 | } 278 | train_dataset = KITTI(aug_params, split="training") 279 | 280 | train_loader = data.DataLoader( 281 | train_dataset, 282 | batch_size=args.batch_size, 283 | pin_memory=False, 284 | shuffle=True, 285 | num_workers=4, 286 | drop_last=True, 287 | ) 288 | 289 | print("Training with %d image pairs" % len(train_dataset)) 290 | return train_loader 291 | -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/ef_raft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .corr import EF_AlternateCorrBlock, EF_CorrBlock 6 | from .extractor import BasicEncoder, SmallEncoder, CoordinateAttention 7 | from .update import BasicUpdateBlock, SmallUpdateBlock, LookupScaler 8 | from .utils.utils import coords_grid, upflow8 9 | 10 | try: 11 | autocast = torch.cuda.amp.autocast 12 | except Exception as e: 13 | # dummy autocast for PyTorch < 1.6 14 | print(e) 15 | 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | 20 | def __enter__(self): 21 | pass 22 | 23 | def __exit__(self, *args): 24 | pass 25 | 26 | 27 | class EF_RAFT(nn.Module): 28 | def __init__(self, args): 29 | super(EF_RAFT, self).__init__() 30 | self.args = args 31 | 32 | if args.small: 33 | self.hidden_dim = hdim = 96 34 | self.context_dim = cdim = 64 35 | args.corr_levels = 4 36 | args.corr_radius = 3 37 | 38 | else: 39 | self.hidden_dim = hdim = 128 40 | self.context_dim = cdim = 128 41 | args.corr_levels = 4 42 | args.corr_radius = 4 43 | 44 | if "dropout" not in self.args: 45 | self.args.dropout = 0 46 | 47 | if "alternate_corr" not in self.args: 48 | self.args.alternate_corr = False 49 | 50 | # feature network, context network, and update block 51 | if args.small: 52 | self.fnet = SmallEncoder( 53 | output_dim=128, norm_fn="instance", dropout=args.dropout 54 | ) 55 | self.coor_att = None 56 | self.cnet = SmallEncoder( 57 | output_dim=hdim + cdim, norm_fn="none", 
dropout=args.dropout 58 | ) 59 | self.lookup_scaler = None 60 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 61 | 62 | else: 63 | self.fnet = BasicEncoder( 64 | output_dim=256, norm_fn="instance", dropout=args.dropout 65 | ) 66 | self.coor_att = CoordinateAttention(feature_size=256, enc_size=128) # New 67 | self.cnet = BasicEncoder( 68 | output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout 69 | ) 70 | self.lookup_scaler = LookupScaler( 71 | input_dim=hdim, output_size=args.corr_levels # New 72 | ) 73 | self.update_block = BasicUpdateBlock( 74 | self.args, hidden_dim=hdim, input_dim=cdim + args.corr_levels * 4 75 | ) # Updated 76 | 77 | def freeze_bn(self): 78 | for m in self.modules(): 79 | if isinstance(m, nn.BatchNorm2d): 80 | m.eval() 81 | 82 | def initialize_flow(self, img): 83 | """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 84 | N, C, H, W = img.shape 85 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device) 86 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device) 87 | 88 | # optical flow computed as difference: flow = coords1 - coords0 89 | return coords0, coords1 90 | 91 | def upsample_flow(self, flow, mask): 92 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" 93 | N, _, H, W = flow.shape 94 | mask = mask.view(N, 1, 9, 8, 8, H, W) 95 | mask = torch.softmax(mask, dim=2) 96 | 97 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 98 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 99 | 100 | up_flow = torch.sum(mask * up_flow, dim=2) 101 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 102 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 103 | 104 | def forward( 105 | self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False 106 | ): 107 | """Estimate optical flow between pair of frames""" 108 | 109 | image1 = 2 * (image1 / 255.0) - 1.0 110 | image2 = 2 * (image2 / 255.0) - 1.0 111 | 112 | image1 = image1.contiguous() 113 | image2 = image2.contiguous() 114 | 115 | hdim = self.hidden_dim 116 | cdim = self.context_dim 117 | 118 | # run the feature network 119 | with autocast(enabled=self.args.mixed_precision): 120 | fmap1, fmap2 = self.fnet([image1, image2]) 121 | fmap1 = self.coor_att(fmap1) 122 | fmap2 = self.coor_att(fmap2) 123 | 124 | fmap1 = fmap1.float() 125 | fmap2 = fmap2.float() 126 | if self.args.alternate_corr: 127 | corr_fn = EF_AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 128 | else: 129 | corr_fn = EF_CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 130 | 131 | # run the context network 132 | with autocast(enabled=self.args.mixed_precision): 133 | cnet = self.cnet(image1) 134 | net, base_inp = torch.split(cnet, [hdim, cdim], dim=1) 135 | net = torch.tanh(net) 136 | base_inp = torch.relu(base_inp) 137 | 138 | coords0, coords1 = self.initialize_flow(image1) 139 | 140 | BATCH_N, fC, fH, fW = fmap1.shape 141 | corr_map = corr_fn.corr_map 142 | soft_corr_map = F.softmax(corr_map, dim=2) * F.softmax(corr_map, dim=1) 143 | 144 | if flow_init is not None: 145 | coords1 = coords1 + flow_init 146 | else: # Use the global matching idea as an initialization. 147 | match_f, match_f_ind = soft_corr_map.max(dim=2) # Forward matching. 148 | match_b, match_b_ind = soft_corr_map.max(dim=1) # Backward matching. 149 | 150 | # Permute the backward softmax to match the forward. 151 | for i in range(BATCH_N): 152 | match_b_tmp = match_b[i, ...] 153 | match_b[i, ...]
= match_b_tmp[match_f_ind[i, ...]] 154 | 155 | # Replace the identity mapping with the found matches. 156 | matched = (match_f - match_b) == 0 157 | coords_index = ( 158 | torch.arange(fH * fW) 159 | .unsqueeze(0) 160 | .repeat(BATCH_N, 1) 161 | .to(soft_corr_map.device) 162 | ) 163 | coords_index[matched] = match_f_ind[matched] 164 | 165 | # Convert the 1D mapping to a 2D one. 166 | coords_index = coords_index.reshape(BATCH_N, fH, fW) 167 | coords_x = coords_index % fW 168 | coords_y = coords_index // fW 169 | coords1 = torch.stack([coords_x, coords_y], dim=1).float() 170 | 171 | # Iterative update 172 | flow_predictions = [] 173 | for itr in range(iters): 174 | coords1 = coords1.detach() 175 | 176 | lookup_scalers = None 177 | if self.lookup_scaler is not None: 178 | lookup_scalers = self.lookup_scaler(base_inp, net) 179 | cat_lookup_scalers = lookup_scalers.view( 180 | -1, lookup_scalers.shape[-1] * lookup_scalers.shape[-2], 1, 1 181 | ) 182 | cat_lookup_scalers = cat_lookup_scalers.expand( 183 | -1, -1, base_inp.shape[2], base_inp.shape[3] 184 | ) 185 | inp = torch.cat([base_inp, cat_lookup_scalers], dim=1) 186 | 187 | corr = corr_fn(coords1, scalers=lookup_scalers) # index correlation volume 188 | 189 | flow = coords1 - coords0 190 | with autocast(enabled=self.args.mixed_precision): 191 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 192 | 193 | # F(t+1) = F(t) + \Delta(t) 194 | coords1 = coords1 + delta_flow 195 | 196 | # upsample predictions 197 | if up_mask is None: 198 | flow_up = upflow8(coords1 - coords0) 199 | else: 200 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 201 | 202 | flow_predictions.append(flow_up) 203 | 204 | if test_mode: 205 | return coords1 - coords0, flow_up 206 | 207 | return flow_predictions 208 | -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/fd_corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .utils.utils import bilinear_sampler 5 | 6 | 7 | class CorrBlock_FD_Sp4: 8 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4, coords_init=None, rad=1): 9 | self.num_levels = num_levels 10 | self.radius = radius 11 | self.corr_pyramid = [] 12 | 13 | corr = CorrBlock_FD_Sp4.corr(fmap1, fmap2, coords_init, r=rad) 14 | 15 | batch, h1, w1, dim, h2, w2 = corr.shape 16 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 17 | 18 | self.corr_pyramid.append(corr) 19 | for i in range(self.num_levels-1): 20 | corr = F.avg_pool2d(corr, 2, stride=2) 21 | self.corr_pyramid.append(corr) 22 | 23 | def __call__(self, coords): 24 | r = self.radius 25 | coords = coords.permute(0, 2, 3, 1) 26 | batch, h1, w1, _ = coords.shape 27 | 28 | out_pyramid = [] 29 | for i in range(self.num_levels): 30 | corr = self.corr_pyramid[i] 31 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 32 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 33 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1) 34 | 35 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 36 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 37 | coords_lvl = centroid_lvl + delta_lvl 38 | 39 | corr = bilinear_sampler(corr, coords_lvl) 40 | corr = corr.view(batch, h1, w1, -1) 41 | out_pyramid.append(corr) 42 | 43 | out = torch.cat(out_pyramid, dim=-1) 44 | return out.permute(0, 3, 1, 2).contiguous().float() 45 | 46 | @staticmethod 47 | def corr(fmap1, fmap2, coords_init, r): 48 | batch, dim, ht, wd 
= fmap1.shape 49 | fmap1 = fmap1.view(batch, dim, ht*wd) 50 | fmap2 = fmap2.view(batch, dim, ht*wd) 51 | 52 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 53 | corr = corr.view(batch, ht, wd, 1, ht, wd) 54 | # return corr / torch.sqrt(torch.tensor(dim).float()) 55 | 56 | coords = coords_init.permute(0, 2, 3, 1).contiguous() 57 | batch, h1, w1, _ = coords.shape 58 | 59 | corr = corr.view(batch*h1*w1, 1, h1, w1) 60 | 61 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 62 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 63 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1) 64 | 65 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) 66 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 67 | coords_lvl = centroid_lvl + delta_lvl 68 | 69 | corr = bilinear_sampler(corr, coords_lvl) 70 | 71 | corr = corr.view(batch, h1, w1, 1, 2*r+1, 2*r+1) 72 | return corr.permute(0, 1, 2, 3, 5, 4).contiguous() / torch.sqrt(torch.tensor(dim).float()) 73 | -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/fd_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # import torch.nn.functional as F 4 | 5 | import timm 6 | import numpy as np 7 | 8 | 9 | class twins_svt_large(nn.Module): 10 | def __init__(self, pretrained=True): 11 | super().__init__() 12 | self.svt = timm.create_model('twins_svt_large', pretrained=pretrained) 13 | 14 | del self.svt.head 15 | del self.svt.patch_embeds[2] 16 | del self.svt.patch_embeds[2] 17 | del self.svt.blocks[2] 18 | del self.svt.blocks[2] 19 | del self.svt.pos_block[2] 20 | del self.svt.pos_block[2] 21 | 22 | def forward(self, x, data=None, layer=2): 23 | B = x.shape[0] 24 | x_4 = None 25 | for i, (embed, drop, blocks, pos_blk) in enumerate( 26 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 27 | 28 | patch_size = embed.patch_size 29 | if i == layer - 1: 30 | embed.patch_size = (1, 1) 31 | embed.proj.stride = embed.patch_size 32 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0) 33 | x_4, size_4 = embed(x_4) 34 | size_4 = (size_4[0] - 1, size_4[1] - 1) 35 | x_4 = drop(x_4) 36 | for j, blk in enumerate(blocks): 37 | x_4 = blk(x_4, size_4) 38 | if j == 0: 39 | x_4 = pos_blk(x_4, size_4) 40 | 41 | if i < len(self.svt.depths) - 1: 42 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous() 43 | 44 | embed.patch_size = patch_size 45 | embed.proj.stride = patch_size 46 | x, size = embed(x) 47 | x = drop(x) 48 | for j, blk in enumerate(blocks): 49 | x = blk(x, size) 50 | if j==0: 51 | x = pos_blk(x, size) 52 | if i < len(self.svt.depths) - 1: 53 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() 54 | 55 | if i == layer-1: 56 | break 57 | 58 | return x, x_4 59 | 60 | def compute_params(self, layer=2): 61 | num = 0 62 | for i, (embed, drop, blocks, pos_blk) in enumerate( 63 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 64 | 65 | for param in embed.parameters(): 66 | num += np.prod(param.size()) 67 | 68 | for param in drop.parameters(): 69 | num += np.prod(param.size()) 70 | 71 | for param in blocks.parameters(): 72 | num += np.prod(param.size()) 73 | 74 | for param in pos_blk.parameters(): 75 | num += np.prod(param.size()) 76 | 77 | if i == layer-1: 78 | break 79 | 80 | for param in self.svt.head.parameters(): 81 | num += np.prod(param.size()) 82 | 83 | return num 84 | 85 | 86 | 
class twins_svt_small_context(nn.Module): 87 | def __init__(self, pretrained=True): 88 | super().__init__() 89 | self.svt = timm.create_model('twins_svt_small', pretrained=pretrained) 90 | 91 | del self.svt.head 92 | del self.svt.patch_embeds[2] 93 | del self.svt.patch_embeds[2] 94 | del self.svt.blocks[2] 95 | del self.svt.blocks[2] 96 | del self.svt.pos_block[2] 97 | del self.svt.pos_block[2] 98 | 99 | def forward(self, x, data=None, layer=2): 100 | B = x.shape[0] 101 | x_4 = None 102 | for i, (embed, drop, blocks, pos_blk) in enumerate( 103 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 104 | 105 | patch_size = embed.patch_size 106 | if i == layer - 1: 107 | embed.patch_size = (1, 1) 108 | embed.proj.stride = embed.patch_size 109 | x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0) 110 | x_4, size_4 = embed(x_4) 111 | size_4 = (size_4[0] - 1, size_4[1] - 1) 112 | x_4 = drop(x_4) 113 | for j, blk in enumerate(blocks): 114 | x_4 = blk(x_4, size_4) 115 | if j == 0: 116 | x_4 = pos_blk(x_4, size_4) 117 | 118 | if i < len(self.svt.depths) - 1: 119 | x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous() 120 | 121 | embed.patch_size = patch_size 122 | embed.proj.stride = patch_size 123 | x, size = embed(x) 124 | x = drop(x) 125 | for j, blk in enumerate(blocks): 126 | x = blk(x, size) 127 | if j == 0: 128 | x = pos_blk(x, size) 129 | if i < len(self.svt.depths) - 1: 130 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() 131 | 132 | if i == layer - 1: 133 | break 134 | 135 | return x, x_4 136 | 137 | def compute_params(self, layer=2): 138 | num = 0 139 | for i, (embed, drop, blocks, pos_blk) in enumerate( 140 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 141 | 142 | for param in embed.parameters(): 143 | num += np.prod(param.size()) 144 | 145 | for param in drop.parameters(): 146 | num += np.prod(param.size()) 147 | 148 | for param in blocks.parameters(): 149 | num += np.prod(param.size()) 150 | 151 | for param in pos_blk.parameters(): 152 | num += np.prod(param.size()) 153 | 154 | if i == layer - 1: 155 | break 156 | 157 | for param in self.svt.head.parameters(): 158 | num += np.prod(param.size()) 159 | 160 | return num 161 | 162 | -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/raft.py: -------------------------------------------------------------------------------- 1 | # import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .corr import AlternateCorrBlock, CorrBlock 7 | from .extractor import BasicEncoder, SmallEncoder 8 | from .update import BasicUpdateBlock, SmallUpdateBlock 9 | from .utils.utils import coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except Exception as e: 14 | # dummy autocast for PyTorch < 1.6 15 | print(e) 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | 20 | def __enter__(self): 21 | pass 22 | 23 | def __exit__(self, *args): 24 | pass 25 | 26 | 27 | class RAFT(nn.Module): 28 | def __init__(self, args): 29 | super(RAFT, self).__init__() 30 | self.args = args 31 | 32 | if args.small: 33 | self.hidden_dim = hdim = 96 34 | self.context_dim = cdim = 64 35 | args.corr_levels = 4 36 | args.corr_radius = 3 37 | 38 | else: 39 | self.hidden_dim = hdim = 128 40 | self.context_dim = cdim = 128 41 | args.corr_levels = 4 42 | args.corr_radius = 4 43 | 44 | if "dropout" not in 
self.args: 45 | self.args.dropout = 0 46 | 47 | if "alternate_corr" not in self.args: 48 | self.args.alternate_corr = False 49 | 50 | # feature network, context network, and update block 51 | if args.small: 52 | self.fnet = SmallEncoder( 53 | output_dim=128, norm_fn="instance", dropout=args.dropout 54 | ) 55 | self.cnet = SmallEncoder( 56 | output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout 57 | ) 58 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 59 | 60 | else: 61 | self.fnet = BasicEncoder( 62 | output_dim=256, norm_fn="instance", dropout=args.dropout 63 | ) 64 | self.cnet = BasicEncoder( 65 | output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout 66 | ) 67 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 68 | 69 | def freeze_bn(self): 70 | for m in self.modules(): 71 | if isinstance(m, nn.BatchNorm2d): 72 | m.eval() 73 | 74 | def initialize_flow(self, img): 75 | """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 76 | N, C, H, W = img.shape 77 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device) 78 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device) 79 | 80 | # optical flow computed as difference: flow = coords1 - coords0 81 | return coords0, coords1 82 | 83 | def upsample_flow(self, flow, mask): 84 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" 85 | N, _, H, W = flow.shape 86 | mask = mask.view(N, 1, 9, 8, 8, H, W) 87 | mask = torch.softmax(mask, dim=2) 88 | 89 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 90 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 91 | 92 | up_flow = torch.sum(mask * up_flow, dim=2) 93 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 94 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 95 | 96 | def forward( 97 | self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False 98 | ): 99 | """Estimate optical flow between pair of frames""" 100 | 101 | image1 = 2 * (image1 / 255.0) - 1.0 102 | image2 = 2 * (image2 / 255.0) - 1.0 103 | 104 | image1 = image1.contiguous() 105 | image2 = image2.contiguous() 106 | 107 | hdim = self.hidden_dim 108 | cdim = self.context_dim 109 | 110 | # run the feature network 111 | with autocast(enabled=self.args.mixed_precision): 112 | fmap1, fmap2 = self.fnet([image1, image2]) 113 | 114 | fmap1 = fmap1.float() 115 | fmap2 = fmap2.float() 116 | if self.args.alternate_corr: 117 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 118 | else: 119 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 120 | 121 | # run the context network 122 | with autocast(enabled=self.args.mixed_precision): 123 | cnet = self.cnet(image1) 124 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 125 | net = torch.tanh(net) 126 | inp = torch.relu(inp) 127 | 128 | coords0, coords1 = self.initialize_flow(image1) 129 | 130 | if flow_init is not None: 131 | coords1 = coords1 + flow_init 132 | 133 | flow_predictions = [] 134 | for itr in range(iters): 135 | coords1 = coords1.detach() 136 | corr = corr_fn(coords1) # index correlation volume 137 | 138 | flow = coords1 - coords0 139 | with autocast(enabled=self.args.mixed_precision): 140 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 141 | 142 | # F(t+1) = F(t) + \Delta(t) 143 | coords1 = coords1 + delta_flow 144 | 145 | # upsample predictions 146 | if up_mask is None: 147 | flow_up = upflow8(coords1 - coords0) 148 | else: 149 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 150 | 151 | 
flow_predictions.append(flow_up) 152 | 153 | if test_mode: 154 | return coords1 - coords0, flow_up 155 | 156 | return flow_predictions 157 | -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | 17 | class ConvGRU(nn.Module): 18 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 19 | super(ConvGRU, self).__init__() 20 | self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 21 | self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 22 | self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 23 | 24 | def forward(self, h, x): 25 | hx = torch.cat([h, x], dim=1) 26 | 27 | z = torch.sigmoid(self.convz(hx)) 28 | r = torch.sigmoid(self.convr(hx)) 29 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 30 | 31 | h = (1 - z) * h + z * q 32 | return h 33 | 34 | 35 | class SepConvGRU(nn.Module): 36 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 37 | super(SepConvGRU, self).__init__() 38 | self.convz1 = nn.Conv2d( 39 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) 40 | ) 41 | self.convr1 = nn.Conv2d( 42 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) 43 | ) 44 | self.convq1 = nn.Conv2d( 45 | hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) 46 | ) 47 | 48 | self.convz2 = nn.Conv2d( 49 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) 50 | ) 51 | self.convr2 = nn.Conv2d( 52 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) 53 | ) 54 | self.convq2 = nn.Conv2d( 55 | hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) 56 | ) 57 | 58 | def forward(self, h, x): 59 | # horizontal 60 | hx = torch.cat([h, x], dim=1) 61 | z = torch.sigmoid(self.convz1(hx)) 62 | r = torch.sigmoid(self.convr1(hx)) 63 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 64 | h = (1 - z) * h + z * q 65 | 66 | # vertical 67 | hx = torch.cat([h, x], dim=1) 68 | z = torch.sigmoid(self.convz2(hx)) 69 | r = torch.sigmoid(self.convr2(hx)) 70 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 71 | h = (1 - z) * h + z * q 72 | 73 | return h 74 | 75 | 76 | class SmallMotionEncoder(nn.Module): 77 | def __init__(self, args): 78 | super(SmallMotionEncoder, self).__init__() 79 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 80 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 81 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 82 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 83 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 84 | 85 | def forward(self, flow, corr): 86 | cor = F.relu(self.convc1(corr)) 87 | flo = F.relu(self.convf1(flow)) 88 | flo = F.relu(self.convf2(flo)) 89 | cor_flo = torch.cat([cor, flo], dim=1) 90 | out = F.relu(self.conv(cor_flo)) 91 | return torch.cat([out, flow], dim=1) 92 | 93 | 94 | class BasicMotionEncoder(nn.Module): 95 | def __init__(self, args): 96 | super(BasicMotionEncoder, self).__init__() 97 | cor_planes = args.corr_levels * (2 * 
args.corr_radius + 1) ** 2 98 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 99 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 100 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 101 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 102 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) 103 | 104 | def forward(self, flow, corr): 105 | cor = F.relu(self.convc1(corr)) 106 | cor = F.relu(self.convc2(cor)) 107 | flo = F.relu(self.convf1(flow)) 108 | flo = F.relu(self.convf2(flo)) 109 | 110 | cor_flo = torch.cat([cor, flo], dim=1) 111 | out = F.relu(self.conv(cor_flo)) 112 | return torch.cat([out, flow], dim=1) 113 | 114 | 115 | class SmallUpdateBlock(nn.Module): 116 | def __init__(self, args, hidden_dim=96): 117 | super(SmallUpdateBlock, self).__init__() 118 | self.encoder = SmallMotionEncoder(args) 119 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 121 | 122 | def forward(self, net, inp, corr, flow): 123 | motion_features = self.encoder(flow, corr) 124 | inp = torch.cat([inp, motion_features], dim=1) 125 | net = self.gru(net, inp) 126 | delta_flow = self.flow_head(net) 127 | 128 | return net, None, delta_flow 129 | 130 | 131 | class BasicUpdateBlock(nn.Module): 132 | def __init__(self, args, hidden_dim=128, input_dim=128): 133 | super(BasicUpdateBlock, self).__init__() 134 | self.args = args 135 | self.encoder = BasicMotionEncoder(args) 136 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=input_dim+hidden_dim) 137 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 138 | 139 | self.mask = nn.Sequential( 140 | nn.Conv2d(128, 256, 3, padding=1), 141 | nn.ReLU(inplace=True), 142 | nn.Conv2d(256, 64 * 9, 1, padding=0), 143 | ) 144 | 145 | def forward(self, net, inp, corr, flow, upsample=True): 146 | motion_features = self.encoder(flow, corr) 147 | inp = torch.cat([inp, motion_features], dim=1) 148 | 149 | net = self.gru(net, inp) 150 | delta_flow = self.flow_head(net) 151 | 152 | # scale mask to balance gradients 153 | mask = 0.25 * self.mask(net) 154 | return net, mask, delta_flow 155 | 156 | class LookupScaler(nn.Module): 157 | def __init__(self, input_dim=128, output_size=4, output_dim=4, max_multiplier=2, max_translation=2): 158 | super(LookupScaler, self).__init__() 159 | self.input_dim = input_dim 160 | self.output_size = output_size 161 | self.output_dim = output_dim 162 | self.max_multiplier = max_multiplier 163 | self.max_translation = max_translation 164 | self.convert_conv = nn.Conv2d(2 * input_dim, 2 * input_dim, 1) 165 | self.model_scale = nn.Sequential(nn.Linear(4 * input_dim, (output_dim // 2) * output_size), 166 | nn.Sigmoid()) 167 | self.model_add = nn.Sequential(nn.Linear(4 * input_dim, (output_dim // 2) * output_size), 168 | nn.Sigmoid()) 169 | 170 | def forward(self, context_map, hidden_state): 171 | assert(context_map.shape[1] == self.input_dim) 172 | assert(hidden_state.shape[1] == self.input_dim) 173 | convert_map = self.convert_conv(torch.cat([context_map, hidden_state], dim=1).type(torch.float32)) 174 | lookup_context = torch.cat([torch.amax(convert_map, dim=(2, 3)), 175 | torch.amin(convert_map, dim=(2, 3))], dim=-1).type(torch.float32) 176 | scale = self.model_scale(lookup_context).view(-1, self.output_size, self.output_dim // 2) * self.max_multiplier + 1 177 | add = self.model_add(lookup_context).view(-1, self.output_size, self.output_dim // 2) * self.max_translation 178 | return torch.cat([scale, add], dim=-1)
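To make the tensor contract of the update block concrete, a shape-only sketch (not part of the repo) of a single `BasicUpdateBlock` step at 1/8 resolution, with `args` mirroring the non-small configuration that `raft.py` sets (`corr_levels=4`, `corr_radius=4`):

```python
# Shape sketch for one recurrent update step; all tensors are zeros because
# only the channel bookkeeping is being demonstrated here.
from argparse import Namespace
import torch
from ezsynth.utils.flow_utils.core.update import BasicUpdateBlock

args = Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128, input_dim=128)

n, h, w = 1, 46, 80                     # 1/8-resolution grid of a 368x640 frame
net = torch.zeros(n, 128, h, w)         # GRU hidden state
inp = torch.zeros(n, 128, h, w)         # context features
corr = torch.zeros(n, 4 * 9 * 9, h, w)  # corr_levels * (2 * corr_radius + 1)**2
flow = torch.zeros(n, 2, h, w)

net, up_mask, delta_flow = block(net, inp, corr, flow)
# delta_flow: (n, 2, h, w) flow residual; up_mask: (n, 64 * 9, h, w) convex
# upsampling weights consumed by upsample_flow() in raft.py
```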
-------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/core/utils/__init__.py -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/utils/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/core/utils/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/utils/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/core/utils/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | # import math 2 | # import random 3 | 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | 8 | # import torch 9 | # import torch.nn.functional as F 10 | from torchvision.transforms import ColorJitter 11 | 12 | cv2.setNumThreads(0) 13 | cv2.ocl.setUseOpenCL(False) 14 | 15 | 16 | class FlowAugmentor: 17 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter( 33 | brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14 34 | ) 35 | self.asymmetric_color_aug_prob = 0.2 36 | self.eraser_aug_prob = 0.5 37 | 38 | def color_transform(self, img1, img2): 39 | """Photometric augmentation""" 40 | 41 | # asymmetric 42 | if np.random.rand() < self.asymmetric_color_aug_prob: 43 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 44 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 45 | 46 | # symmetric 47 | else: 48 | image_stack = np.concatenate([img1, img2], axis=0) 49 | image_stack = np.array( 50 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 51 | ) 52 | img1, img2 = np.split(image_stack, 2, axis=0) 53 | 54 | return img1, img2 55 | 56 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 57 | """Occlusion augmentation""" 58 | 59 | ht, wd = img1.shape[:2] 60 | if np.random.rand() < self.eraser_aug_prob: 61 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 62 | for _ in range(np.random.randint(1, 3)): 63 | x0 = np.random.randint(0, wd) 64 | y0 = np.random.randint(0, ht) 65 | dx = np.random.randint(bounds[0], bounds[1]) 66 | dy = np.random.randint(bounds[0], bounds[1]) 67 | img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color 68 | 69 | return img1, img2 70 | 71 | def spatial_transform(self, img1, img2, flow): 72 | # randomly sample 
scale 73 | ht, wd = img1.shape[:2] 74 | min_scale = np.maximum( 75 | (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd) 76 | ) 77 | 78 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 79 | scale_x = scale 80 | scale_y = scale 81 | if np.random.rand() < self.stretch_prob: 82 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 83 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 84 | 85 | scale_x = np.clip(scale_x, min_scale, None) 86 | scale_y = np.clip(scale_y, min_scale, None) 87 | 88 | if np.random.rand() < self.spatial_aug_prob: 89 | # rescale the images 90 | img1 = cv2.resize( 91 | img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 92 | ) 93 | img2 = cv2.resize( 94 | img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 95 | ) 96 | flow = cv2.resize( 97 | flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 98 | ) 99 | flow = flow * [scale_x, scale_y] 100 | 101 | if self.do_flip: 102 | if np.random.rand() < self.h_flip_prob: # h-flip 103 | img1 = img1[:, ::-1] 104 | img2 = img2[:, ::-1] 105 | flow = flow[:, ::-1] * [-1.0, 1.0] 106 | 107 | if np.random.rand() < self.v_flip_prob: # v-flip 108 | img1 = img1[::-1, :] 109 | img2 = img2[::-1, :] 110 | flow = flow[::-1, :] * [1.0, -1.0] 111 | 112 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 113 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 114 | 115 | img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 116 | img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 117 | flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 118 | 119 | return img1, img2, flow 120 | 121 | def __call__(self, img1, img2, flow): 122 | img1, img2 = self.color_transform(img1, img2) 123 | img1, img2 = self.eraser_transform(img1, img2) 124 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 125 | 126 | img1 = np.ascontiguousarray(img1) 127 | img2 = np.ascontiguousarray(img2) 128 | flow = np.ascontiguousarray(flow) 129 | 130 | return img1, img2, flow 131 | 132 | 133 | class SparseFlowAugmentor: 134 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 135 | # spatial augmentation params 136 | self.crop_size = crop_size 137 | self.min_scale = min_scale 138 | self.max_scale = max_scale 139 | self.spatial_aug_prob = 0.8 140 | self.stretch_prob = 0.8 141 | self.max_stretch = 0.2 142 | 143 | # flip augmentation params 144 | self.do_flip = do_flip 145 | self.h_flip_prob = 0.5 146 | self.v_flip_prob = 0.1 147 | 148 | # photometric augmentation params 149 | self.photo_aug = ColorJitter( 150 | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14 151 | ) 152 | self.asymmetric_color_aug_prob = 0.2 153 | self.eraser_aug_prob = 0.5 154 | 155 | def color_transform(self, img1, img2): 156 | image_stack = np.concatenate([img1, img2], axis=0) 157 | image_stack = np.array( 158 | self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 159 | ) 160 | img1, img2 = np.split(image_stack, 2, axis=0) 161 | return img1, img2 162 | 163 | def eraser_transform(self, img1, img2): 164 | ht, wd = img1.shape[:2] 165 | if np.random.rand() < self.eraser_aug_prob: 166 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 167 | for _ in range(np.random.randint(1, 3)): 168 | x0 = np.random.randint(0, wd) 169 | y0 = np.random.randint(0, ht) 170 | dx = np.random.randint(50, 100) 171 | dy = np.random.randint(50, 100) 172 | img2[y0 : y0 + dy, x0 : x0 + dx, :] = 
mean_color 173 | 174 | return img1, img2 175 | 176 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 177 | ht, wd = flow.shape[:2] 178 | coords = np.meshgrid(np.arange(wd), np.arange(ht), indexing="xy") 179 | coords = np.stack(coords, axis=-1) 180 | 181 | coords = coords.reshape(-1, 2).astype(np.float32) 182 | flow = flow.reshape(-1, 2).astype(np.float32) 183 | valid = valid.reshape(-1).astype(np.float32) 184 | 185 | coords0 = coords[valid >= 1] 186 | flow0 = flow[valid >= 1] 187 | 188 | ht1 = int(round(ht * fy)) 189 | wd1 = int(round(wd * fx)) 190 | 191 | coords1 = coords0 * [fx, fy] 192 | flow1 = flow0 * [fx, fy] 193 | 194 | xx = np.round(coords1[:, 0]).astype(np.int32) 195 | yy = np.round(coords1[:, 1]).astype(np.int32) 196 | 197 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 198 | xx = xx[v] 199 | yy = yy[v] 200 | flow1 = flow1[v] 201 | 202 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 203 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 204 | 205 | flow_img[yy, xx] = flow1 206 | valid_img[yy, xx] = 1 207 | 208 | return flow_img, valid_img 209 | 210 | def spatial_transform(self, img1, img2, flow, valid): 211 | # randomly sample scale 212 | 213 | ht, wd = img1.shape[:2] 214 | min_scale = np.maximum( 215 | (self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd) 216 | ) 217 | 218 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 219 | scale_x = np.clip(scale, min_scale, None) 220 | scale_y = np.clip(scale, min_scale, None) 221 | 222 | if np.random.rand() < self.spatial_aug_prob: 223 | # rescale the images 224 | img1 = cv2.resize( 225 | img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 226 | ) 227 | img2 = cv2.resize( 228 | img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR 229 | ) 230 | flow, valid = self.resize_sparse_flow_map( 231 | flow, valid, fx=scale_x, fy=scale_y 232 | ) 233 | 234 | if self.do_flip: 235 | if np.random.rand() < 0.5: # h-flip 236 | img1 = img1[:, ::-1] 237 | img2 = img2[:, ::-1] 238 | flow = flow[:, ::-1] * [-1.0, 1.0] 239 | valid = valid[:, ::-1] 240 | 241 | margin_y = 20 242 | margin_x = 50 243 | 244 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 245 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 246 | 247 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 248 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 249 | 250 | img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 251 | img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 252 | flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 253 | valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] 254 | return img1, img2, flow, valid 255 | 256 | def __call__(self, img1, img2, flow, valid): 257 | img1, img2 = self.color_transform(img1, img2) 258 | img1, img2 = self.eraser_transform(img1, img2) 259 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 260 | 261 | img1 = np.ascontiguousarray(img1) 262 | img2 = np.ascontiguousarray(img2) 263 | flow = np.ascontiguousarray(flow) 264 | valid = np.ascontiguousarray(valid) 265 | 266 | return img1, img2, flow, valid 267 | -------------------------------------------------------------------------------- /ezsynth/utils/flow_utils/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from 
https://github.com/tomrunia/OpticalFlow_Visualization
2 | 
3 | 
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 | 
18 | import numpy as np
19 | 
20 | 
21 | def make_colorwheel():
22 |     """
23 |     Generates a color wheel for optical flow visualization as presented in:
24 |         Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
25 |         URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
26 | 
27 |     Code follows the original C++ source code of Daniel Scharstein.
28 |     Code follows the Matlab source code of Deqing Sun.
29 | 
30 |     Returns:
31 |         np.ndarray: Color wheel
32 |     """
33 | 
34 |     RY = 15
35 |     YG = 6
36 |     GC = 4
37 |     CB = 11
38 |     BM = 13
39 |     MR = 6
40 | 
41 |     ncols = RY + YG + GC + CB + BM + MR
42 |     colorwheel = np.zeros((ncols, 3))
43 |     col = 0
44 | 
45 |     # RY
46 |     colorwheel[0:RY, 0] = 255
47 |     colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
48 |     col = col + RY
49 |     # YG
50 |     colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
51 |     colorwheel[col : col + YG, 1] = 255
52 |     col = col + YG
53 |     # GC
54 |     colorwheel[col : col + GC, 1] = 255
55 |     colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
56 |     col = col + GC
57 |     # CB
58 |     colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
59 |     colorwheel[col : col + CB, 2] = 255
60 |     col = col + CB
61 |     # BM
62 |     colorwheel[col : col + BM, 2] = 255
63 |     colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
64 |     col = col + BM
65 |     # MR
66 |     colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
67 |     colorwheel[col : col + MR, 0] = 255
68 |     return colorwheel
69 | 
70 | 
71 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
72 |     """
73 |     Applies the flow color wheel to (possibly clipped) flow components u and v.
74 | 
75 |     According to the C++ source code of Daniel Scharstein
76 |     According to the Matlab source code of Deqing Sun
77 | 
78 |     Args:
79 |         u (np.ndarray): Input horizontal flow of shape [H,W]
80 |         v (np.ndarray): Input vertical flow of shape [H,W]
81 |         convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
82 | 
83 |     Returns:
84 |         np.ndarray: Flow visualization image of shape [H,W,3]
85 |     """
86 |     flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
87 |     colorwheel = make_colorwheel()  # shape [55x3]
88 |     ncols = colorwheel.shape[0]
89 |     rad = np.sqrt(np.square(u) + np.square(v))
90 |     a = np.arctan2(-v, -u) / np.pi
91 |     fk = (a + 1) / 2 * (ncols - 1)
92 |     k0 = np.floor(fk).astype(np.int32)
93 |     k1 = k0 + 1
94 |     k1[k1 == ncols] = 0
95 |     f = fk - k0
96 |     for i in range(colorwheel.shape[1]):
97 |         tmp = colorwheel[:, i]
98 |         col0 = tmp[k0] / 255.0
99 |         col1 = tmp[k1] / 255.0
100 |         col = (1 - f) * col0 + f * col1
101 |         idx = rad <= 1
102 |         col[idx] = 1 - rad[idx] * (1 - col[idx])
103 |         col[~idx] = col[~idx] * 0.75  # out of range
104 |         # Note the 2-i => BGR instead of RGB
105 |         ch_idx = 2 - i if convert_to_bgr else i
106 |         flow_image[:, :, ch_idx] = np.floor(255 * col)
107 |     return flow_image
108 | 
109 | 
110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
111 |     """
112 |     Expects a two-dimensional flow field of shape [H,W,2].
113 | 
114 |     Args:
115 |         flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
116 |         clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
117 |         convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
118 | 
119 |     Returns:
120 |         np.ndarray: Flow visualization image of shape [H,W,3]
121 |     """
122 |     assert flow_uv.ndim == 3, "input flow must have three dimensions"
123 |     assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
124 |     if clip_flow is not None:
125 |         flow_uv = np.clip(flow_uv, 0, clip_flow)
126 |     u = flow_uv[:, :, 0]
127 |     v = flow_uv[:, :, 1]
128 |     rad = np.sqrt(np.square(u) + np.square(v))
129 |     rad_max = np.max(rad)
130 |     epsilon = 1e-5
131 |     u = u / (rad_max + epsilon)
132 |     v = v / (rad_max + epsilon)
133 |     return flow_uv_to_colors(u, v, convert_to_bgr)
134 | 
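A minimal standalone sketch of the visualizer above; the random flow field is purely illustrative, and `flow_vis.png` is a placeholder output path:

```python
import cv2
import numpy as np

from ezsynth.utils.flow_utils.core.utils.flow_viz import flow_to_image

# Illustrative flow field of shape [H, W, 2]: channel 0 = u (horizontal), channel 1 = v (vertical).
flow = np.random.randn(240, 320, 2).astype(np.float32) * 5.0

rgb = flow_to_image(flow)                       # [H, W, 3] uint8, RGB channel order
bgr = flow_to_image(flow, convert_to_bgr=True)  # BGR channel order, ready for cv2.imwrite
cv2.imwrite("flow_vis.png", bgr)
```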
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from scipy import interpolate
5 | 
6 | 
7 | class InputPadder:
8 |     """Pads images such that dimensions are divisible by 8"""
9 | 
10 |     def __init__(self, dims, mode="sintel"):
11 |         self.ht, self.wd = dims[-2:]
12 |         self._calculate_padding(mode)
13 | 
14 |     def _calculate_padding(self, mode):
15 |         pad_ht = -self.ht % 8
16 |         pad_wd = -self.wd % 8
17 | 
18 |         if mode == "sintel":
19 |             self._pad = (
20 |                 pad_wd // 2,
21 |                 pad_wd - pad_wd // 2,
22 |                 pad_ht // 2,
23 |                 pad_ht - pad_ht // 2,
24 |             )
25 |         else:
26 |             self._pad = (pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht)
27 | 
28 |     def pad(self, *inputs):
29 |         return [F.pad(x, self._pad, mode="replicate") for x in inputs]
30 | 
31 |     def unpad(self, x):
32 |         return x[
33 |             ...,
34 |             self._pad[2] : self.ht + self._pad[2],
35 |             self._pad[0] : self.wd + self._pad[0],
36 |         ]
37 | 
38 | 
39 | def forward_interpolate(flow_ts):
40 |     with torch.no_grad():
41 |         flow = flow_ts.cpu().numpy()
42 |         dx, dy = flow[0], flow[1]
43 | 
44 |         ht, wd = dx.shape
45 |         x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht), indexing="xy")
46 | 
47 |         x1 = x0 + dx
48 |         y1 = y0 + dy
49 | 
50 |         x1 = x1.reshape(-1)
51 |         y1 = y1.reshape(-1)
52 |         dx = dx.reshape(-1)
53 |         dy = dy.reshape(-1)
54 | 
55 |         valid = (x1 >= 0) & (x1 < wd) & (y1 >= 0) & (y1 < ht)
56 | 
57 |         x1_valid = x1[valid]
58 |         y1_valid = y1[valid]
59 |         dx_valid = dx[valid]
60 |         dy_valid = dy[valid]
61 | 
62 |         flow_x = interpolate.griddata(
63 |             (x1_valid, y1_valid), dx_valid, (x0, y0), method="nearest", fill_value=0
64 |         )
65 | 
66 |         flow_y = interpolate.griddata(
67 |             (x1_valid, y1_valid), dy_valid, (x0, y0), method="nearest", fill_value=0
68 |         )
69 | 
70 |         flow = np.stack([flow_x, flow_y], axis=0)
71 |         return torch.from_numpy(flow).float()
72 | 
73 | 
74 | def bilinear_sampler(img, coords, mode="bilinear", mask=False):
75 |     """Wrapper for grid_sample, uses pixel coordinates"""
76 |     H, W = img.shape[-2:]
77 |     xgrid, ygrid = coords.split([1, 1], dim=-1)
78 |     xgrid = 2 * xgrid / (W - 1) - 1
79 |     ygrid = 2 * ygrid / (H - 1) - 1
80 | 
81 |     grid = torch.cat([xgrid, ygrid], dim=-1)
82 |     img = F.grid_sample(img, grid, align_corners=True)
83 | 
84 |     if mask:
85 |         mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
86 |         return img, mask.float()
87 | 
88 |     return img
89 | 
90 | 
91 | def coords_grid(batch, ht, wd, device):
92 |     coords = torch.meshgrid(
93 |         torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
94 |     )
95 |     coords = torch.stack(coords[::-1], dim=0).float()
96 |     return coords[None].repeat(batch, 1, 1, 1)
97 | 
98 | 
99 | def upflow8(flow, mode="bilinear"):
100 |     new_size = (8 * flow.shape[2], 8 * flow.shape[3])
101 |     return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
102 | 
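`InputPadder` above encodes the usual RAFT inference pattern: pad both frames so height and width are divisible by 8, run the network on the padded pair, then crop the predicted flow back. A minimal sketch, with a zero tensor standing in for the (hypothetical) model call:

```python
import torch

from ezsynth.utils.flow_utils.core.utils.utils import InputPadder

# Two illustrative frames as [B, 3, H, W] float tensors; H and W need not be multiples of 8.
img1 = torch.rand(1, 3, 375, 640)
img2 = torch.rand(1, 3, 375, 640)

padder = InputPadder(img1.shape, mode="sintel")
img1_p, img2_p = padder.pad(img1, img2)  # both padded to 376 x 640

# flow = model(img1_p, img2_p)  # hypothetical RAFT-style call returning [B, 2, H', W']
flow = torch.zeros(1, 2, *img1_p.shape[-2:])  # zero stand-in so the sketch runs
flow = padder.unpad(flow)  # cropped back to the original 375 x 640
```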
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/ef_raft_models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/ef_raft_models/.gitkeep
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/flow_diff/fd_corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | from ezsynth.utils.flow_utils.core.utils.utils import bilinear_sampler
5 | 
6 | 
7 | class CorrBlock_FD_Sp4:
8 |     def __init__(self, fmap1, fmap2, num_levels=4, radius=4, coords_init=None, rad=1):
9 |         self.num_levels = num_levels
10 |         self.radius = radius
11 |         self.corr_pyramid = []
12 | 
13 |         corr = CorrBlock_FD_Sp4.corr(fmap1, fmap2, coords_init, r=rad)
14 | 
15 |         batch, h1, w1, dim, h2, w2 = corr.shape
16 |         corr = corr.reshape(batch*h1*w1, dim, h2, w2)
17 | 
18 |         self.corr_pyramid.append(corr)
19 |         for i in range(self.num_levels-1):
20 |             corr = F.avg_pool2d(corr, 2, stride=2)
21 |             self.corr_pyramid.append(corr)
22 | 
23 |     def __call__(self, coords):
24 |         r = self.radius
25 |         coords = coords.permute(0, 2, 3, 1)
26 |         batch, h1, w1, _ = coords.shape
27 | 
28 |         out_pyramid = []
29 |         for i in range(self.num_levels):
30 |             corr = self.corr_pyramid[i]
31 |             dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
32 |             dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
33 |             delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
34 | 
35 |             centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
36 |             delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
37 |             coords_lvl = centroid_lvl + delta_lvl
38 | 
39 |             corr = bilinear_sampler(corr, coords_lvl)
40 |             corr = corr.view(batch, h1, w1, -1)
41 |             out_pyramid.append(corr)
42 | 
43 |         out = torch.cat(out_pyramid, dim=-1)
44 |         return out.permute(0, 3, 1, 2).contiguous().float()
45 | 
46 |     @staticmethod
47 |     def corr(fmap1, fmap2, coords_init, r):
48 |         batch, dim, ht, wd = fmap1.shape
49 |         fmap1 = fmap1.view(batch, dim, ht*wd)
50 |         fmap2 = fmap2.view(batch, dim, ht*wd)
51 | 
52 |         corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
53 |         corr = corr.view(batch, ht, wd, 1, ht, wd)
54 |         # return corr / torch.sqrt(torch.tensor(dim).float())
55 | 
56 |         coords = coords_init.permute(0, 2, 3, 1).contiguous()
57 |         batch, h1, w1, _ = coords.shape
58 | 
59 |         corr = corr.view(batch*h1*w1, 1, h1, w1)
60 | 
61 |         dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
62 |         dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
63 |         delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
64 | 
65 |         centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2)
66 |         delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
67 |         coords_lvl = centroid_lvl + delta_lvl
68 | 
69 |         corr = bilinear_sampler(corr, coords_lvl)
70 | 
71 |         corr = corr.view(batch, h1, w1, 1, 2*r+1, 2*r+1)
72 |         return corr.permute(0, 1, 2, 3, 5, 4).contiguous() / torch.sqrt(torch.tensor(dim).float())
73 | 
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/flow_diff/fd_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # import torch.nn.functional as F
4 | 
5 | import timm
6 | import numpy as np
7 | 
8 | 
9 | class twins_svt_large(nn.Module):
10 |     def __init__(self, pretrained=True):
11 |         super().__init__()
12 |         self.svt = timm.create_model('twins_svt_large', pretrained=pretrained)
13 | 
14 |         del self.svt.head
15 |         del self.svt.patch_embeds[2]
16 |         del self.svt.patch_embeds[2]
17 |         del self.svt.blocks[2]
18 |         del self.svt.blocks[2]
19 |         del self.svt.pos_block[2]
20 |         del self.svt.pos_block[2]
21 | 
22 |     def forward(self, x, data=None, layer=2):
23 |         B = x.shape[0]
24 |         x_4 = None
25 |         for i, (embed, drop, blocks, pos_blk) in enumerate(
26 |                 zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
27 | 
28 |             patch_size = embed.patch_size
29 |             if i == layer - 1:
30 |                 embed.patch_size = (1, 1)
31 |                 embed.proj.stride = embed.patch_size
32 |                 x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0)
33 |                 x_4, size_4 = embed(x_4)
34 |                 size_4 = (size_4[0] - 1, size_4[1] - 1)
35 |                 x_4 = drop(x_4)
36 |                 for j, blk in enumerate(blocks):
37 |                     x_4 = blk(x_4, size_4)
38 |                     if j == 0:
39 |                         x_4 = pos_blk(x_4, size_4)
40 | 
41 |                 if i < len(self.svt.depths) - 1:
42 |                     x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous()
43 | 
44 |                 embed.patch_size = patch_size
45 |                 embed.proj.stride = patch_size
46 |             x, size = embed(x)
47 |             x = drop(x)
48 |             for j, blk in enumerate(blocks):
49 |                 x = blk(x, size)
50 |                 if j == 0:
51 |                     x = pos_blk(x, size)
52 |             if i < len(self.svt.depths) - 1:
53 |                 x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
54 | 
55 |             if i == layer-1:
56 |                 break
57 | 
58 |         return x, x_4
59 | 
60 |     def compute_params(self, layer=2):
61 |         num = 0
62 |         for i, (embed, drop, blocks, pos_blk) in enumerate(
63 |                 zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
64 | 
65 |             for param in embed.parameters():
66 |                 num += np.prod(param.size())
67 | 
68 |             for param in drop.parameters():
69 |                 num += np.prod(param.size())
70 | 
71 |             for param in blocks.parameters():
72 |                 num += np.prod(param.size())
73 | 
74 |             for param in pos_blk.parameters():
75 |                 num += np.prod(param.size())
76 | 
77 |             if i == layer-1:
78 |                 break
79 | 
80 |         if hasattr(self.svt, "head"):  # head is deleted in __init__; guard avoids an AttributeError
81 |             for param in self.svt.head.parameters():
82 |                 num += np.prod(param.size())
83 |         return num
84 | 
85 | 
86 | class twins_svt_small_context(nn.Module):
87 |     def __init__(self, pretrained=True):
88 |         super().__init__()
89 |         self.svt = timm.create_model('twins_svt_small', pretrained=pretrained)
90 | 
91 |         del self.svt.head
92 |         del self.svt.patch_embeds[2]
93 |         del self.svt.patch_embeds[2]
94 |         del self.svt.blocks[2]
95 |         del self.svt.blocks[2]
96 |         del self.svt.pos_block[2]
97 |         del self.svt.pos_block[2]
98 | 
99 |     def forward(self, x, data=None, layer=2):
100 |         B = x.shape[0]
101 |         x_4 = None
102 |         for i, (embed, drop, blocks, pos_blk) in enumerate(
103 |                 zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
104 | 
105 |             patch_size = embed.patch_size
106 |             if i == layer - 1:
107 |                 embed.patch_size = (1, 1)
108 |                 embed.proj.stride = embed.patch_size
109 |                 x_4 = torch.nn.functional.pad(x, [1, 0, 1, 0], mode='constant', value=0)
110 |                 x_4, size_4 = embed(x_4)
111 |                 size_4 = (size_4[0] - 1, size_4[1] - 1)
112 |                 x_4 = drop(x_4)
113 |                 for j, blk in enumerate(blocks):
114 |                     x_4 = blk(x_4, size_4)
115 |                     if j == 0:
116 |                         x_4 = pos_blk(x_4, size_4)
117 | 
118 |                 if i < len(self.svt.depths) - 1:
119 |                     x_4 = x_4.reshape(B, *size_4, -1).permute(0, 3, 1, 2).contiguous()
120 | 
121 |                 embed.patch_size = patch_size
122 |                 embed.proj.stride = patch_size
123 |             x, size = embed(x)
124 |             x = drop(x)
125 |             for j, blk in enumerate(blocks):
126 |                 x = blk(x, size)
127 |                 if j == 0:
128 |                     x = pos_blk(x, size)
129 |             if i < len(self.svt.depths) - 1:
130 |                 x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
131 | 
132 |             if i == layer - 1:
133 |                 break
134 | 
135 |         return x, x_4
136 | 
137 |     def compute_params(self, layer=2):
138 |         num = 0
139 |         for i, (embed, drop, blocks, pos_blk) in enumerate(
140 |                 zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)):
141 | 
142 |             for param in embed.parameters():
143 |                 num += np.prod(param.size())
144 | 
145 |             for param in drop.parameters():
146 |                 num += np.prod(param.size())
147 | 
148 |             for param in blocks.parameters():
149 |                 num += np.prod(param.size())
150 | 
151 |             for param in pos_blk.parameters():
152 |                 num += np.prod(param.size())
153 | 
154 |             if i == layer - 1:
155 |                 break
156 | 
157 |         if hasattr(self.svt, "head"):  # head is deleted in __init__; guard avoids an AttributeError
158 |             for param in self.svt.head.parameters():
159 |                 num += np.prod(param.size())
160 |         return num
161 | 
162 | 
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/flow_diffusion_models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/flow_diffusion_models/.gitkeep
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/models/raft-kitti.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/models/raft-kitti.pth
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/models/raft-sintel.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/models/raft-sintel.pth
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/models/raft-small.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/ezsynth/utils/flow_utils/models/raft-small.pth
--------------------------------------------------------------------------------
/ezsynth/utils/flow_utils/warp.py:
--------------------------------------------------------------------------------
1 | # import warnings
2 | 
3 | import cv2
4 | import numpy as np
5 | 
6 | # warnings.filterwarnings("ignore", category=UserWarning)
7 | 
8 | 
9 | class Warp:
10 |     def __init__(self, sample_fr: np.ndarray):
11 |         H, W, _ = sample_fr.shape
12 |         self.H = H
13 |         self.W = W
14 |         self.grid = self._create_grid(H, W)
15 | 
16 |     def _create_grid(self, H: int, W: int):
17 |         x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
18 |         return np.stack((x, y), axis=-1).astype(np.float32)
19 | 
20 |     def _warp(self, img: np.ndarray, flo: np.ndarray):
21 |         flo_resized = cv2.resize(flo, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
22 |         map_x = self.grid[..., 0] + flo_resized[..., 0]
23 |         map_y = self.grid[..., 1] + flo_resized[..., 1]
24 |         warped_img = cv2.remap(
25 |             img,
26 |             map_x,
27 |             map_y,
28 |             interpolation=cv2.INTER_LINEAR,
29 |             borderMode=cv2.BORDER_REFLECT,
30 |         )
31 |         return warped_img
32 | 
33 |     def run_warping(self, img: np.ndarray, flow: np.ndarray):
34 |         img = img.astype(np.float32)
35 |         flow = flow.astype(np.float32)
36 | 
37 |         try:
38 |             warped_img = self._warp(img, flow)
39 |             warped_image = (warped_img * 255).astype(np.uint8)
40 |             return warped_image
41 |         except Exception as e:
42 |             print(f"[ERROR] Exception in run_warping: {e}")
43 |             return None
44 | 
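`Warp` above is the remap-based helper that advects a frame along a flow field. Because `run_warping` multiplies its result by 255 before casting to `uint8`, the input image is evidently expected in `[0, 1]`. A minimal sketch; the file paths are placeholders:

```python
import cv2
import numpy as np

from ezsynth.utils.flow_utils.warp import Warp

frame = cv2.imread("frame_000.png")  # placeholder path; H x W x 3 uint8
flow = np.load("flow_000_001.npy")   # placeholder path; H x W x 2 float32 flow field

warper = Warp(frame)  # the sampling grid is sized from this sample frame
warped = warper.run_warping(frame / 255.0, flow)  # input scaled to [0, 1]
if warped is not None:  # run_warping returns None on failure
    cv2.imwrite("warped_000.png", warped)
```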
--------------------------------------------------------------------------------
/output_synth/facestyle_err.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/facestyle_err.png
--------------------------------------------------------------------------------
/output_synth/facestyle_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/facestyle_out.png
--------------------------------------------------------------------------------
/output_synth/retarget_err.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/retarget_err.png
--------------------------------------------------------------------------------
/output_synth/retarget_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/retarget_out.png
--------------------------------------------------------------------------------
/output_synth/stylit_err.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/stylit_err.png
--------------------------------------------------------------------------------
/output_synth/stylit_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trentonom0r3/Ezsynth/b198f2d7051eee542c4efc51c2d43dc442630bbf/output_synth/stylit_out.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | torch
4 | torchvision
5 | phycv
6 | scipy
7 | Pillow
8 | tqdm
--------------------------------------------------------------------------------
/test_imgsynth.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import sys
4 | import time
5 | 
6 | import torch
7 | 
8 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9 | 
10 | from ezsynth.aux_classes import RunConfig
11 | from ezsynth.aux_utils import save_to_folder, load_guide
12 | from ezsynth.main_ez import ImageSynth
13 | 
14 | st = time.time()
15 | 
16 | output_folder = "J:/AI/Ezsynth/output_synth"
17 | 
18 | # Examples from the Ebsynth repository
19 | 
20 | ## Segment retargeting
21 | 
22 | ezsynner = ImageSynth(
23 |     style_path="J:/AI/Ezsynth/examples/texbynum/source_photo.png",
24 |     src_path="J:/AI/Ezsynth/examples/texbynum/source_segment.png",
25 |     tgt_path="J:/AI/Ezsynth/examples/texbynum/target_segment.png",
26 |     cfg=RunConfig(img_wgt=1.0),
27 | )
28 | 
29 | result = ezsynner.run()
30 | 
31 | save_to_folder(output_folder, "retarget_out.png", result[0])
32 | save_to_folder(output_folder, "retarget_err.png", result[1])
33 | 
34 | ## Stylit
35 | 
36 | ezsynner = ImageSynth(
37 |     style_path="J:/AI/Ezsynth/examples/stylit/source_style.png",
38 |     src_path="J:/AI/Ezsynth/examples/stylit/source_fullgi.png",
39 |     tgt_path="J:/AI/Ezsynth/examples/stylit/target_fullgi.png",
40 |     cfg=RunConfig(img_wgt=0.66),
41 | )
42 | 
43 | result = ezsynner.run(
44 |     guides=[
45 |         load_guide(
46 |             "J:/AI/Ezsynth/examples/stylit/source_dirdif.png",
47 |             "J:/AI/Ezsynth/examples/stylit/target_dirdif.png",
48 |             0.66,
49 |         ),
50 |         load_guide(
51 |             "J:/AI/Ezsynth/examples/stylit/source_indirb.png",
52 |             "J:/AI/Ezsynth/examples/stylit/target_indirb.png",
53 |             0.66,
54 |         ),
55 |     ]
56 | )
57 | 
58 | save_to_folder(output_folder, "stylit_out.png", result[0])
59 | save_to_folder(output_folder, "stylit_err.png", result[1])
60 | 
61 | ## Face style
62 | 
63 | ezsynner = ImageSynth(
64 |     style_path="J:/AI/Ezsynth/examples/facestyle/source_painting.png",
65 |     src_path="J:/AI/Ezsynth/examples/facestyle/source_Gapp.png",
66 |     tgt_path="J:/AI/Ezsynth/examples/facestyle/target_Gapp.png",
67 |     cfg=RunConfig(img_wgt=2.0),
68 | )
69 | 
70 | result = ezsynner.run(
71 |     guides=[
72 |         load_guide(
73 |             "J:/AI/Ezsynth/examples/facestyle/source_Gseg.png",
74 |             "J:/AI/Ezsynth/examples/facestyle/target_Gseg.png",
75 |             1.5,
76 |         ),
77 |         load_guide(
78 |             "J:/AI/Ezsynth/examples/facestyle/source_Gpos.png",
79 |             "J:/AI/Ezsynth/examples/facestyle/target_Gpos.png",
80 |             1.5,
81 |         ),
82 |     ]
83 | )
84 | 
85 | save_to_folder(output_folder, "facestyle_out.png", result[0])
86 | save_to_folder(output_folder, "facestyle_err.png", result[1])
87 | 
88 | gc.collect()
89 | torch.cuda.empty_cache()
90 | 
91 | print(f"Time taken: {time.time() - st:.4f} s")
92 | 
--------------------------------------------------------------------------------
/test_progress.txt:
--------------------------------------------------------------------------------
1 | Style 1
2 | First frame only: Passed (Passed) (Passed)
3 | Last frame only: Passed (Passed) (Passed)
4 | Middle frame only: Passed (Passed) (Passed)
5 | 
6 | Style 2 with Blend 
(Regression - Fixed) 7 | First and Last: Passed (Passed) (Passed) 8 | First and Mid: Passed (Passed) (Passed) 9 | Mid and Last: Passed (Passed) (Passed) 10 | Mid and Mid: Passed (Passed) (Passed) 11 | 12 | Style 2 without Blend - Forward only 13 | First and Last: Passed (Passed) (Passed) 14 | First and Mid: Passed (Passed) (Passed) 15 | Mid and Last: Passed (Passed) (Passed) 16 | Mid and Mid: Passed (Passed) (Passed) 17 | 18 | Style 2 without Blend - Reverse only (Regression - Fixed) 19 | First and Last: Passed (Passed) (Passed) 20 | First and Mid: Passed (Passed) (Passed) 21 | Mid and Last: Passed (Passed) (Passed) 22 | Mid and Mid: Passed (Passed) (Passed) 23 | 24 | Style 3: 25 | Blend Mid Mid Mid: Passed (Passed) 26 | Blend First Mid Last: Passed (Passed) 27 | Forward Mid Mid Mid: Passed (Passed) 28 | Forward First Mid Last: Passed (Passed) 29 | Reverse Mid Mid Mid: Passed (Passed) 30 | Reverse First Mid Last: Passed (Passed) 31 | -------------------------------------------------------------------------------- /test_redux.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import sys 4 | import time 5 | 6 | import torch 7 | 8 | 9 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 10 | 11 | from ezsynth.sequences import EasySequence 12 | from ezsynth.aux_classes import RunConfig 13 | from ezsynth.aux_utils import save_seq 14 | from ezsynth.main_ez import Ezsynth 15 | 16 | st = time.time() 17 | 18 | style_paths = [ 19 | "J:/AI/Ezsynth/examples/styles/style000.jpg", 20 | # "J:/AI/Ezsynth/examples/styles/style002.png", 21 | # "J:/AI/Ezsynth/examples/styles/style003.png", 22 | # "J:/AI/Ezsynth/examples/styles/style006.png", 23 | "J:/AI/Ezsynth/examples/styles/style010.png", 24 | # "J:/AI/Ezsynth/examples/styles/style014.png", 25 | # "J:/AI/Ezsynth/examples/styles/style019.png", 26 | # "J:/AI/Ezsynth/examples/styles/style099.jpg", 27 | ] 28 | 29 | image_folder = "J:/AI/Ezsynth/examples/input" 30 | mask_folder = "J:/AI/Ezsynth/examples/mask/mask_feather" 31 | output_folder = "J:/AI/Ezsynth/output" 32 | 33 | # edge_method="Classic" 34 | edge_method = "PAGE" 35 | # edge_method="PST" 36 | 37 | # flow_arch = "RAFT" 38 | # flow_model = "sintel" 39 | 40 | flow_arch = "EF_RAFT" 41 | flow_model = "25000_ours-sintel" 42 | 43 | # flow_arch = "FLOW_DIFF" 44 | # flow_model = "FlowDiffuser-things" 45 | 46 | 47 | ezrunner = Ezsynth( 48 | style_paths=style_paths, 49 | image_folder=image_folder, 50 | cfg=RunConfig(pre_mask=False, feather=5), 51 | edge_method=edge_method, 52 | raft_flow_model_name=flow_model, 53 | mask_folder=mask_folder, 54 | # do_mask=True, 55 | do_mask=False, 56 | flow_arch=flow_arch 57 | ) 58 | 59 | 60 | # only_mode = EasySequence.MODE_FWD 61 | # only_mode = EasySequence.MODE_REV 62 | only_mode = None 63 | 64 | stylized_frames, err_frames, flow_frames = ezrunner.run_sequences_full( 65 | only_mode, return_flow=True 66 | ) 67 | # stylized_frames, err_frames = ezrunner.run_sequences(only_mode) 68 | 69 | save_seq(stylized_frames, "J:/AI/Ezsynth/output_57_efraft") 70 | save_seq(flow_frames, "J:/AI/Ezsynth/flow_output_57_efraft") 71 | # save_seq(err_frames, "J:/AI/Ezsynth/output_51err") 72 | 73 | gc.collect() 74 | torch.cuda.empty_cache() 75 | 76 | print(f"Time taken: {time.time() - st:.4f} s") 77 | --------------------------------------------------------------------------------
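The test scripts above hard-code absolute `J:/AI/Ezsynth` paths. A portable sketch of the segment-retargeting example, assuming it is run from the repository root with the bundled `examples` assets:

```python
import os

from ezsynth.aux_classes import RunConfig
from ezsynth.aux_utils import save_to_folder
from ezsynth.main_ez import ImageSynth

root = os.path.dirname(os.path.abspath(__file__))  # repository root

ezsynner = ImageSynth(
    style_path=os.path.join(root, "examples", "texbynum", "source_photo.png"),
    src_path=os.path.join(root, "examples", "texbynum", "source_segment.png"),
    tgt_path=os.path.join(root, "examples", "texbynum", "target_segment.png"),
    cfg=RunConfig(img_wgt=1.0),
)

result = ezsynner.run()  # (stylized image, error image), as in test_imgsynth.py
save_to_folder(os.path.join(root, "output_synth"), "retarget_out.png", result[0])
save_to_folder(os.path.join(root, "output_synth"), "retarget_err.png", result[1])
```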