├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE.md ├── README.md ├── blender ├── guide.py ├── histogram_blend.py ├── poisson_fusion.py └── video_sequence.py ├── config ├── real2sculpture.json ├── real2sculpture_freeu.json ├── real2sculpture_loose_cfattn.json ├── van_gogh_man.json ├── van_gogh_man_dynamic_resolution.json └── woman.json ├── environment.yml ├── flow └── flow_utils.py ├── inference_playground.ipynb ├── install.py ├── requirements.txt ├── rerender.py ├── sd_model_cfg.py ├── src ├── config.py ├── controller.py ├── ddim_v_hacked.py ├── freeu.py ├── img_util.py ├── import_util.py └── video_util.py ├── video_blend.py ├── videos ├── pexels-antoni-shkraba-8048492-540x960-25fps.mp4 ├── pexels-cottonbro-studio-6649832-960x506-25fps.mp4 └── pexels-koolshooters-7322716.mp4 └── webUI.py /.gitignore: -------------------------------------------------------------------------------- 1 | videos 2 | models 3 | result 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | **/*.pyc 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/en/_build/ 72 | docs/zh_cn/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | # custom 112 | .vscode 113 | .idea 114 | *.pkl 115 | *.pkl.json 116 | *.log.json 117 | work_dirs/ 118 | 119 | # Pytorch 120 | *.pth 121 | 122 | # onnx and tensorrt 123 | *.onnx 124 | *.trt 125 | 126 | # local history 127 | .history/** 128 | 129 | # Pytorch Server 130 | *.mar 131 | .DS_Store 132 | 133 | /data/ 134 | /data 135 | data 136 | .vector_cache 137 | 138 | __pycache 139 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "deps/gmflow"] 2 | path = deps/gmflow 3 | url = https://github.com/haofeixu/gmflow.git 4 | [submodule "deps/ControlNet"] 5 | path = deps/ControlNet 6 | url = https://github.com/lllyasviel/ControlNet.git 7 | [submodule "deps/ebsynth"] 8 | path = deps/ebsynth 9 | url = 
https://github.com/SingleZombie/ebsynth.git 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/PyCQA/flake8 3 | rev: 4.0.1 4 | hooks: 5 | - id: flake8 6 | - repo: https://github.com/PyCQA/isort 7 | rev: 5.11.5 8 | hooks: 9 | - id: isort 10 | - repo: https://github.com/pre-commit/mirrors-yapf 11 | rev: v0.32.0 12 | hooks: 13 | - id: yapf 14 | - repo: https://github.com/pre-commit/pre-commit-hooks 15 | rev: v4.2.0 16 | hooks: 17 | - id: trailing-whitespace 18 | - id: check-yaml 19 | - id: end-of-file-fixer 20 | - id: requirements-txt-fixer 21 | - id: double-quote-string-fixer 22 | - id: check-merge-conflict 23 | - id: fix-encoding-pragma 24 | args: ["--remove"] 25 | - id: mixed-line-ending 26 | args: ["--fix=lf"] 27 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\ 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. 11 | 12 | 13 | --- 14 | For the commercial use of the code, please consult Prof. 
Chen Change Loy (ccloy@ntu.edu.sg) 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rerender A Video - Official PyTorch Implementation 2 | 3 |  4 | 5 | <!--https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/82c35efb-e86b-4376-bfbe-6b69159b8879--> 6 | 7 | 8 | **Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation**<br> 9 | [Shuai Yang](https://williamyang1991.github.io/), [Yifan Zhou](https://zhouyifan.net/), [Ziwei Liu](https://liuziwei7.github.io/) and [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)<br> 10 | in SIGGRAPH Asia 2023 Conference Proceedings <br> 11 | [**Project Page**](https://www.mmlab-ntu.com/project/rerender/) | [**Paper**](https://arxiv.org/abs/2306.07954) | [**Supplementary Video**](https://youtu.be/cxfxdepKVaM) | [**Input Data and Video Results**](https://drive.google.com/file/d/1HkxG5eiLM_TQbbMZYOwjDbd5gWisOy4m/view?usp=sharing) <br> 12 | 13 | <a href="https://huggingface.co/spaces/Anonymous-sub/Rerender"><img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm-dark.svg" alt="Web Demo"></a>  14 | 15 | > **Abstract:** *Large text-to-image diffusion models have exhibited impressive proficiency in generating high-quality images. However, when applying these models to video domain, ensuring temporal consistency across video frames remains a formidable challenge. This paper proposes a novel zero-shot text-guided video-to-video translation framework to adapt image models to videos. The framework includes two parts: key frame translation and full video translation. The first part uses an adapted diffusion model to generate key frames, with hierarchical cross-frame constraints applied to enforce coherence in shapes, textures and colors. The second part propagates the key frames to other frames with temporal-aware patch matching and frame blending. Our framework achieves global style and local texture temporal consistency at a low cost (without re-training or optimization). The adaptation is compatible with existing image diffusion techniques, allowing our framework to take advantage of them, such as customizing a specific subject with LoRA, and introducing extra spatial guidance with ControlNet. Extensive experimental results demonstrate the effectiveness of our proposed framework over existing methods in rendering high-quality and temporally-coherent videos.* 16 | 17 | **Features**:<br> 18 | - **Temporal consistency**: cross-frame constraints for low-level temporal consistency. 19 | - **Zero-shot**: no training or fine-tuning required. 20 | - **Flexibility**: compatible with off-the-shelf models (e.g., [ControlNet](https://github.com/lllyasviel/ControlNet), [LoRA](https://civitai.com/)) for customized translation. 21 | 22 | https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/811fdea3-f0da-49c9-92b8-2d2ad360f0d6 23 | 24 | ## Updates 25 | - [12/2023] The Diffusers pipeline is available: [Rerender_A_Video Community Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#Rerender_A_Video) 26 | - [10/2023] New features: [Loose cross-frame attention](#loose-cross-frame-attention) and [FreeU](#freeu). 27 | - [09/2023] Code is released. 28 | - [09/2023] Accepted to SIGGRAPH Asia 2023 Conference Proceedings! 29 | - [06/2023] Integrated to 🤗 [Hugging Face](https://huggingface.co/spaces/Anonymous-sub/Rerender). Enjoy the web demo! 
30 | - [05/2023] This website is created.
31 | 
32 | ### TODO
33 | - [x] ~~Integrate into Diffusers.~~
34 | - [x] ~~Integrate [FreeU](https://github.com/ChenyangSi/FreeU) into Rerender~~
35 | - [x] ~~Add Inference instructions in README.md.~~
36 | - [x] ~~Add Examples to webUI.~~
37 | - [x] ~~Add optional poisson fusion to the pipeline.~~
38 | - [x] ~~Add Installation instructions for Windows~~
39 | 
40 | ## Installation
41 | 
42 | *Please make sure your installation path contains only English letters or _*
43 | 
44 | 1. Clone the repository. (Don't forget --recursive. Otherwise, please run `git submodule update --init --recursive`)
45 | 
46 | ```shell
47 | git clone git@github.com:williamyang1991/Rerender_A_Video.git --recursive
48 | cd Rerender_A_Video
49 | ```
50 | 
51 | 2. If you have already installed PyTorch with CUDA, you can simply set up the environment with pip.
52 | 
53 | ```shell
54 | pip install -r requirements.txt
55 | ```
56 | 
57 | You can also create a new conda environment from scratch.
58 | 
59 | ```shell
60 | conda env create -f environment.yml
61 | conda activate rerender
62 | ```
63 | 24GB VRAM is required. Please refer to https://github.com/williamyang1991/Rerender_A_Video/pull/23#issue-1900789461 to reduce memory consumption.
64 | 
65 | 3. Run the installation script. The required models will be downloaded to `./models`.
66 | 
67 | ```shell
68 | python install.py
69 | ```
70 | 
71 | 4. You can run the demo with `rerender.py`:
72 | 
73 | ```shell
74 | python rerender.py --cfg config/real2sculpture.json
75 | ```
76 | 
77 | <details>
78 | <summary>Installation on Windows</summary>
79 | 
80 | Before running the above steps 1-4, you need to prepare:
81 | 1. Install [CUDA](https://developer.nvidia.com/cuda-toolkit-archive)
82 | 2. Install [git](https://git-scm.com/download/win)
83 | 3. Install [VS](https://visualstudio.microsoft.com/) with Windows 10/11 SDK (for building deps/ebsynth/bin/ebsynth.exe)
84 | 4. [Here](https://github.com/williamyang1991/Rerender_A_Video/issues/18#issuecomment-1752712233) is more information. If building ebsynth fails, we provide our compiled [ebsynth](https://drive.google.com/drive/folders/1oSB3imKwZGz69q2unBUfcgmQpzwccoyD?usp=sharing).
85 | </details>
86 | 
87 | <details id="issues">
88 | <summary>🔥🔥🔥 <b>Installation or Running Fails?</b> 🔥🔥🔥</summary>
89 | 
90 | 1. In case building ebsynth fails, we provide our compiled [ebsynth](https://drive.google.com/drive/folders/1oSB3imKwZGz69q2unBUfcgmQpzwccoyD?usp=sharing)
91 | 2. `FileNotFoundError: [Errno 2] No such file or directory: 'xxxx.bin' or 'xxxx.jpg'`:
92 | - make sure your path only contains English letters or _ (https://github.com/williamyang1991/Rerender_A_Video/issues/18#issuecomment-1723361433)
93 | - find the code `python video_blend.py ...` in the error log and use it to manually run the ebsynth part, which is more stable than WebUI.
94 | - if some non-keyframes are generated but some are not (rather than all non-keyframes missing in '/out_xx/'), you may refer to https://github.com/williamyang1991/Rerender_A_Video/issues/38#issuecomment-1730668991
95 | - Enable the Execute permission of deps/ebsynth/bin/ebsynth
96 | - Enable the debug log to find more information https://github.com/williamyang1991/Rerender_A_Video/blob/d32b1d6b6c1305ddd06e66868c5dcf4fb7aa048c/video_blend.py#L22
97 | 5. `KeyError: 'dataset'`: upgrade Gradio to the latest version (https://github.com/williamyang1991/Rerender_A_Video/issues/14#issuecomment-1722778672, https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/11855)
98 | 6.
Error when processing videos: manually install ffmpeg (https://github.com/williamyang1991/Rerender_A_Video/issues/19#issuecomment-1723685825, https://github.com/williamyang1991/Rerender_A_Video/issues/29#issuecomment-1726091112) 99 | 7. `ERR_ADDRESS_INVALID` Cannot open the webUI in browser: replace 0.0.0.0 with 127.0.0.1 in webUI.py (https://github.com/williamyang1991/Rerender_A_Video/issues/19#issuecomment-1723685825) 100 | 8. `CUDA out of memory`: 101 | - Using xformers (https://github.com/williamyang1991/Rerender_A_Video/pull/23#issue-1900789461) 102 | - Set `"use_limit_device_resolution"` to `true` in the config to resize the video according to your VRAM (https://github.com/williamyang1991/Rerender_A_Video/issues/79). An example config `config/van_gogh_man_dynamic_resolution.json` is provided. 103 | 10. `AttributeError: module 'keras.backend' has no attribute 'is_tensor'`: update einops (https://github.com/williamyang1991/Rerender_A_Video/issues/26#issuecomment-1726682446) 104 | 11. `IndexError: list index out of range`: use the original DDIM steps of 20 (https://github.com/williamyang1991/Rerender_A_Video/issues/30#issuecomment-1729039779) 105 | 12. One-click installation https://github.com/williamyang1991/Rerender_A_Video/issues/99 106 | 107 | </details> 108 | 109 | 110 | ## (1) Inference 111 | 112 | ### WebUI (recommended) 113 | 114 | ``` 115 | python webUI.py 116 | ``` 117 | The Gradio app also allows you to flexibly change the inference options. Just try it for more details. (For WebUI, you need to download [revAnimated_v11](https://civitai.com/models/7371/rev-animated?modelVersionId=19575) and [realisticVisionV20_v20](https://civitai.com/models/4201?modelVersionId=29460) to `./models/` after Installation) 118 | 119 | Upload your video, input the prompt, select the seed, and hit: 120 | - **Run 1st Key Frame**: only translate the first frame, so you can adjust the prompts/models/parameters to find your ideal output appearance before running the whole video. 121 | - **Run Key Frames**: translate all the key frames based on the settings of the first frame, so you can adjust the temporal-related parameters for better temporal consistency before running the whole video. 
122 | - **Run Propagation**: propagate the key frames to other frames for full video translation 123 | - **Run All**: **Run 1st Key Frame**, **Run Key Frames** and **Run Propagation** 124 | 125 |  126 | 127 | 128 | We provide abundant advanced options to play with 129 | 130 | <details id="option0"> 131 | <summary> <b>Using customized models</b></summary> 132 | 133 | - Using LoRA/Dreambooth/Finetuned/Mixed SD models 134 | - Modify `sd_model_cfg.py` to add paths to the saved SD models 135 | - How to use LoRA: https://github.com/williamyang1991/Rerender_A_Video/issues/39#issuecomment-1730678296 136 | - Using other controls from ControlNet (e.g., Depth, Pose) 137 | - Add more options like `control_type = gr.Dropdown(['HED', 'canny', 'depth']` here https://github.com/williamyang1991/Rerender_A_Video/blob/b6cafb5d80a79a3ef831c689ffad92ec095f2794/webUI.py#L690 138 | - Add model loading options like `elif control_type == 'depth':` following https://github.com/williamyang1991/Rerender_A_Video/blob/b6cafb5d80a79a3ef831c689ffad92ec095f2794/webUI.py#L88 139 | - Add model detectors like `elif control_type == 'depth':` following https://github.com/williamyang1991/Rerender_A_Video/blob/b6cafb5d80a79a3ef831c689ffad92ec095f2794/webUI.py#L122 140 | - One example is given [here](https://huggingface.co/spaces/Anonymous-sub/Rerender/discussions/10/files) 141 | 142 | </details> 143 | 144 | <details id="option1"> 145 | <summary> <b>Advanced options for the 1st frame translation</b></summary> 146 | 147 | 1. Resolution related (**Frame resolution**, **left/top/right/bottom crop length**): crop the frame and resize its short side to 512. 148 | 2. ControlNet related: 149 | - **ControlNet strength**: how well the output matches the input control edges 150 | - **Control type**: HED edge or Canny edge 151 | - **Canny low/high threshold**: low values for more edge details 152 | 3. SDEdit related: 153 | - **Denoising strength**: repaint degree (low value to make the output look more like the original video) 154 | - **Preserve color**: preserve the color of the original video 155 | 4. SD related: 156 | - **Steps**: denoising step 157 | - **CFG scale**: how well the output matches the prompt 158 | - **Base model**: base Stable Diffusion model (SD 1.5) 159 | - Stable Diffusion 1.5: official model 160 | - [revAnimated_v11](https://civitai.com/models/7371/rev-animated?modelVersionId=19575): a semi-realistic (2.5D) model 161 | - [realisticVisionV20_v20](https://civitai.com/models/4201?modelVersionId=29460): a photo-realistic model 162 | - **Added prompt/Negative prompt**: supplementary prompts 163 | 5. FreeU related: 164 | - **FreeU first/second-stage backbone factor**: =1 do nothing; >1 enhance output color and details 165 | - **FreeU first/second-stage skip factor**: =1 do nothing; <1 enhance output color and details 166 | 167 | </details> 168 | 169 | <details id="option2"> 170 | <summary> <b>Advanced options for the key frame translation</b></summary> 171 | 172 | 1. Key frame related 173 | - **Key frame frequency (K)**: Uniformly sample the key frame every K frames. Small value for large or fast motions. 174 | - **Number of key frames (M)**: The final output video will have K*M+1 frames with M+1 key frames. 175 | 2. Temporal consistency related 176 | - Cross-frame attention: 177 | - **Cross-frame attention start/end**: When applying cross-frame attention for global style consistency 178 | - **Cross-frame attention update frequency (N)**: Update the reference style frame every N key frames. 
Should be large for long videos to avoid error accumulation.
179 | - **Loose Cross-frame attention**: Use cross-frame attention in fewer layers to better match the input video (for videos with large motions)
180 | - **Shape-aware fusion**: Check to use this feature
181 | - **Shape-aware fusion start/end**: When applying shape-aware fusion for local shape consistency
182 | - **Pixel-aware fusion**: Check to use this feature
183 | - **Pixel-aware fusion start/end**: When applying pixel-aware fusion for pixel-level temporal consistency
184 | - **Pixel-aware fusion strength**: The strength to preserve the non-inpainting region. Small to avoid error accumulation. Large to avoid blurry textures.
185 | - **Pixel-aware fusion detail level**: The strength to sharpen the inpainting region. Small to avoid error accumulation. Large to avoid blurry textures.
186 | - **Smooth fusion boundary**: Check to smooth the inpainting boundary (avoid error accumulation).
187 | - **Color-aware AdaIN**: Check to use this feature
188 | - **Color-aware AdaIN start/end**: When applying AdaIN to make the video color consistent with the first frame
189 | 
190 | </details>
191 | 
192 | <details id="option3">
193 | <summary> <b>Advanced options for the full video translation</b></summary>
194 | 
195 | 1. **Gradient blending**: apply Poisson Blending to reduce ghosting artifacts. May slow the process and increase flickers.
196 | 2. **Number of parallel processes**: multiprocessing to speed up the process. A large value (8) is recommended.
197 | </details>
198 | 
199 | 
200 | 
201 | 
202 | ### Command Line
203 | 
204 | We also provide a flexible script `rerender.py` to run our method.
205 | 
206 | #### Simple mode
207 | 
208 | Set the options via the command line. For example,
209 | 
210 | ```shell
211 | python rerender.py --input videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4 --output result/man/man.mp4 --prompt "a handsome man in van gogh painting"
212 | ```
213 | 
214 | The script will run the full pipeline. A work directory will be created at `result/man` and the result video will be saved as `result/man/man.mp4`.
215 | 
216 | #### Advanced mode
217 | 
218 | Set the options via a config file. For example,
219 | 
220 | ```shell
221 | python rerender.py --cfg config/van_gogh_man.json
222 | ```
223 | 
224 | The script will run the full pipeline.
225 | We provide some examples of the config in the `config` directory.
226 | Most options in the config are the same as those in WebUI.
227 | Please check the explanations in the WebUI section.
228 | 
229 | Specify customized models by setting `sd_model` in the config. For example:
230 | ```json
231 | {
232 | "sd_model": "models/realisticVisionV20_v20.safetensors",
233 | }
234 | ```
235 | 
236 | #### Customize the pipeline
237 | 
238 | Similar to WebUI, we provide a three-step workflow: rerender the first key frame, then rerender the full key frames, and finally rerender the full video with propagation. To run only a single step, specify the options `-one`, `-nb` and `-nr`:
239 | 
240 | 1. Rerender the first key frame
241 | ```shell
242 | python rerender.py --cfg config/van_gogh_man.json -one -nb
243 | ```
244 | 2. Rerender the full key frames
245 | ```shell
246 | python rerender.py --cfg config/van_gogh_man.json -nb
247 | ```
248 | 3.
Rerender the full video with propagation 249 | ```shell 250 | python rerender.py --cfg config/van_gogh_man.json -nr 251 | ``` 252 | 253 | #### Our Ebsynth implementation 254 | 255 | We provide a separate Ebsynth python script `video_blend.py` with the temporal blending algorithm introduced in 256 | [Stylizing Video by Example](https://dcgi.fel.cvut.cz/home/sykorad/ebsynth.html) for interpolating style between key frames. 257 | It can work on your own stylized key frames independently of our Rerender algorithm. 258 | 259 | Usage: 260 | ```shell 261 | video_blend.py [-h] [--output OUTPUT] [--fps FPS] [--beg BEG] [--end END] [--itv ITV] [--key KEY] 262 | [--n_proc N_PROC] [-ps] [-ne] [-tmp] 263 | name 264 | 265 | positional arguments: 266 | name Path to input video 267 | 268 | optional arguments: 269 | -h, --help show this help message and exit 270 | --output OUTPUT Path to output video 271 | --fps FPS The FPS of output video 272 | --beg BEG The index of the first frame to be stylized 273 | --end END The index of the last frame to be stylized 274 | --itv ITV The interval of key frame 275 | --key KEY The subfolder name of stylized key frames 276 | --n_proc N_PROC The max process count 277 | -ps Use poisson gradient blending 278 | -ne Do not run ebsynth (use previous ebsynth output) 279 | -tmp Keep temporary output 280 | ``` 281 | For example, to run Ebsynth on video `man.mp4`, 282 | 1. Put the stylized key frames to `videos/man/keys` for every 10 frames (named as `0001.png`, `0011.png`, ...) 283 | 2. Put the original video frames in `videos/man/video` (named as `0001.png`, `0002.png`, ...). 284 | 3. Run Ebsynth on the first 101 frames of the video with poisson gradient blending and save the result to `videos/man/blend.mp4` under FPS 25 with the following command: 285 | ```shell 286 | python video_blend.py videos/man \ 287 | --beg 1 \ 288 | --end 101 \ 289 | --itv 10 \ 290 | --key keys \ 291 | --output videos/man/blend.mp4 \ 292 | --fps 25.0 \ 293 | -ps 294 | ``` 295 | 296 | ## (2) Results 297 | 298 | ### Key frame translation 299 | 300 | 301 | <table class="center"> 302 | <tr> 303 | <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/18666871-f273-44b2-ae67-7be85d43e2f6" raw=true></td> 304 | <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/61f59540-f06e-4e5a-86b6-1d7cb8ed6300" raw=true></td> 305 | <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/8e8ad51a-6a71-4b34-8633-382192d0f17c" raw=true></td> 306 | <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/b03cd35f-5d90-471a-9aa9-5c7773d7ac39" raw=true></td> 307 | </tr> 308 | <tr> 309 | <td width=27.5% align="center">white ancient Greek sculpture, Venus de Milo, light pink and blue background</td> 310 | <td width=27.5% align="center">a handsome Greek man</td> 311 | <td width=21.5% align="center">a traditional mountain in chinese ink wash painting</td> 312 | <td width=23.5% align="center">a cartoon tiger</td> 313 | </tr> 314 | </table> 315 | 316 | <table class="center"> 317 | <tr> 318 | <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/649a789e-0c41-41cf-94a4-0d524dcfb282" raw=true></td> 319 | <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/73590c16-916f-4ee6-881a-44a201dd85dd" raw=true></td> 320 | <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/fbdc0b8e-6046-414f-a37e-3cd9dd0adf5d" raw=true></td> 321 | <td><img 
src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/eb11d807-2afa-4609-a074-34300b67e6aa" raw=true></td> 322 | </tr> 323 | <tr> 324 | <td width=26.0% align="center">a swan in chinese ink wash painting, monochrome</td> 325 | <td width=29.0% align="center">a beautiful woman in CG style</td> 326 | <td width=21.5% align="center">a clean simple white jade sculpture</td> 327 | <td width=24.0% align="center">a fluorescent jellyfish in the deep dark blue sea</td> 328 | </tr> 329 | </table> 330 | 331 | ### Full video translation 332 | 333 | Text-guided virtual character generation. 334 | 335 | 336 | https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/1405b257-e59a-427f-890d-7652e6bed0a4 337 | 338 | 339 | https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/efee8cc6-9708-4124-bf6a-49baf91349fc 340 | 341 | 342 | Video stylization and video editing. 343 | 344 | 345 | https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/1b72585c-99c0-401d-b240-5b8016df7a3f 346 | 347 | ## New Features 348 | 349 | Compared to the conference version, we are keeping adding new features. 350 | 351 |  352 | 353 | #### Loose cross-frame attention 354 | By using cross-frame attention in less layers, our results will better match the input video, thus reducing ghosting artifacts caused by inconsistencies. This feature can be activated by checking `Loose Cross-frame attention` in the <a href="#option2">Advanced options for the key frame translation</a> for WebUI or setting `loose_cfattn` for script (see `config/real2sculpture_loose_cfattn.json`). 355 | 356 | #### FreeU 357 | [FreeU](https://github.com/ChenyangSi/FreeU) is a method that improves diffusion model sample quality at no costs. We find featured with FreeU, our results will have higher contrast and saturation, richer details, and more vivid colors. This feature can be used by setting FreeU backbone factors and skip factors in the <a href="#option1">Advanced options for the 1st frame translation</a> for WebUI or setting `freeu_args` for script (see `config/real2sculpture_freeu.json`). 358 | 359 | ## Citation 360 | 361 | If you find this work useful for your research, please consider citing our paper: 362 | 363 | ```bibtex 364 | @inproceedings{yang2023rerender, 365 | title = {Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation}, 366 | author = {Yang, Shuai and Zhou, Yifan and Liu, Ziwei and and Loy, Chen Change}, 367 | booktitle = {ACM SIGGRAPH Asia Conference Proceedings}, 368 | year = {2023}, 369 | } 370 | ``` 371 | 372 | ## Acknowledgments 373 | 374 | The code is mainly developed based on [ControlNet](https://github.com/lllyasviel/ControlNet), [Stable Diffusion](https://github.com/Stability-AI/stablediffusion), [GMFlow](https://github.com/haofeixu/gmflow) and [Ebsynth](https://github.com/jamriska/ebsynth). 375 | -------------------------------------------------------------------------------- /blender/guide.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from flow.flow_utils import flow_calc, read_flow, read_mask 7 | 8 | 9 | class BaseGuide: 10 | 11 | def __init__(self): 12 | ... 
13 | 14 | def get_cmd(self, i, weight) -> str: 15 | return (f'-guide {os.path.abspath(self.imgs[0])} ' 16 | f'{os.path.abspath(self.imgs[i])} -weight {weight}') 17 | 18 | 19 | class ColorGuide(BaseGuide): 20 | 21 | def __init__(self, imgs): 22 | super().__init__() 23 | self.imgs = imgs 24 | 25 | 26 | class PositionalGuide(BaseGuide): 27 | 28 | def __init__(self, flow_paths, save_paths): 29 | super().__init__() 30 | flows = [read_flow(f) for f in flow_paths] 31 | masks = [read_mask(f) for f in flow_paths] 32 | # TODO: modify the format of flow to numpy 33 | H, W = flows[0].shape[2:] 34 | first_img = PositionalGuide.__generate_first_img(H, W) 35 | prev_img = first_img 36 | imgs = [first_img] 37 | cid = 0 38 | for flow, mask in zip(flows, masks): 39 | cur_img = flow_calc.warp(prev_img, flow, 40 | 'nearest').astype(np.uint8) 41 | cur_img = cv2.inpaint(cur_img, mask, 30, cv2.INPAINT_TELEA) 42 | prev_img = cur_img 43 | imgs.append(cur_img) 44 | cid += 1 45 | cv2.imwrite(f'guide/{cid}.jpg', mask) 46 | 47 | for path, img in zip(save_paths, imgs): 48 | cv2.imwrite(path, img) 49 | self.imgs = save_paths 50 | 51 | @staticmethod 52 | def __generate_first_img(H, W): 53 | Hs = np.linspace(0, 1, H) 54 | Ws = np.linspace(0, 1, W) 55 | i, j = np.meshgrid(Hs, Ws, indexing='ij') 56 | r = (i * 255).astype(np.uint8) 57 | g = (j * 255).astype(np.uint8) 58 | b = np.zeros(r.shape) 59 | res = np.stack((b, g, r), 2) 60 | return res 61 | 62 | 63 | class EdgeGuide(BaseGuide): 64 | 65 | def __init__(self, imgs, save_paths): 66 | super().__init__() 67 | edges = [EdgeGuide.__generate_edge(cv2.imread(img)) for img in imgs] 68 | for path, img in zip(save_paths, edges): 69 | cv2.imwrite(path, img) 70 | self.imgs = save_paths 71 | 72 | @staticmethod 73 | def __generate_edge(img): 74 | filter = np.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]]) 75 | res = cv2.filter2D(img, -1, filter) 76 | return res 77 | 78 | 79 | class TemporalGuide(BaseGuide): 80 | 81 | def __init__(self, key_img, stylized_imgs, flow_paths, save_paths): 82 | super().__init__() 83 | self.flows = [read_flow(f) for f in flow_paths] 84 | self.masks = [read_mask(f) for f in flow_paths] 85 | self.stylized_imgs = stylized_imgs 86 | self.imgs = save_paths 87 | 88 | first_img = cv2.imread(key_img) 89 | cv2.imwrite(self.imgs[0], first_img) 90 | 91 | def get_cmd(self, i, weight) -> str: 92 | if i == 0: 93 | warped_img = self.stylized_imgs[0] 94 | else: 95 | prev_img = cv2.imread(self.stylized_imgs[i - 1]) 96 | warped_img = flow_calc.warp(prev_img, self.flows[i - 1], 97 | 'nearest').astype(np.uint8) 98 | 99 | warped_img = cv2.inpaint(warped_img, self.masks[i - 1], 30, 100 | cv2.INPAINT_TELEA) 101 | 102 | cv2.imwrite(self.imgs[i], warped_img) 103 | 104 | return super().get_cmd(i, weight) 105 | -------------------------------------------------------------------------------- /blender/histogram_blend.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def histogram_transform(img: np.ndarray, means: np.ndarray, stds: np.ndarray, 6 | target_means: np.ndarray, target_stds: np.ndarray): 7 | means = means.reshape((1, 1, 3)) 8 | stds = stds.reshape((1, 1, 3)) 9 | target_means = target_means.reshape((1, 1, 3)) 10 | target_stds = target_stds.reshape((1, 1, 3)) 11 | x = img.astype(np.float32) 12 | x = (x - means) * target_stds / stds + target_means 13 | # x = np.round(x) 14 | # x = np.clip(x, 0, 255) 15 | # x = x.astype(np.uint8) 16 | return x 17 | 18 | 19 | def blend(a: np.ndarray, 20 | b: 
np.ndarray, 21 | min_error: np.ndarray, 22 | weight1=0.5, 23 | weight2=0.5): 24 | a = cv2.cvtColor(a, cv2.COLOR_BGR2Lab) 25 | b = cv2.cvtColor(b, cv2.COLOR_BGR2Lab) 26 | min_error = cv2.cvtColor(min_error, cv2.COLOR_BGR2Lab) 27 | a_mean = np.mean(a, axis=(0, 1)) 28 | a_std = np.std(a, axis=(0, 1)) 29 | b_mean = np.mean(b, axis=(0, 1)) 30 | b_std = np.std(b, axis=(0, 1)) 31 | min_error_mean = np.mean(min_error, axis=(0, 1)) 32 | min_error_std = np.std(min_error, axis=(0, 1)) 33 | 34 | t_mean_val = 0.5 * 256 35 | t_std_val = (1 / 36) * 256 36 | t_mean = np.ones([3], dtype=np.float32) * t_mean_val 37 | t_std = np.ones([3], dtype=np.float32) * t_std_val 38 | a = histogram_transform(a, a_mean, a_std, t_mean, t_std) 39 | 40 | b = histogram_transform(b, b_mean, b_std, t_mean, t_std) 41 | ab = (a * weight1 + b * weight2 - t_mean_val) / 0.5 + t_mean_val 42 | ab_mean = np.mean(ab, axis=(0, 1)) 43 | ab_std = np.std(ab, axis=(0, 1)) 44 | ab = histogram_transform(ab, ab_mean, ab_std, min_error_mean, 45 | min_error_std) 46 | ab = np.round(ab) 47 | ab = np.clip(ab, 0, 255) 48 | ab = ab.astype(np.uint8) 49 | ab = cv2.cvtColor(ab, cv2.COLOR_Lab2BGR) 50 | return ab 51 | -------------------------------------------------------------------------------- /blender/poisson_fusion.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import scipy 4 | 5 | As = None 6 | prev_states = None 7 | 8 | 9 | def construct_A(h, w, grad_weight): 10 | indgx_x = [] 11 | indgx_y = [] 12 | indgy_x = [] 13 | indgy_y = [] 14 | vdx = [] 15 | vdy = [] 16 | for i in range(h): 17 | for j in range(w): 18 | if i < h - 1: 19 | indgx_x += [i * w + j] 20 | indgx_y += [i * w + j] 21 | vdx += [1] 22 | indgx_x += [i * w + j] 23 | indgx_y += [(i + 1) * w + j] 24 | vdx += [-1] 25 | if j < w - 1: 26 | indgy_x += [i * w + j] 27 | indgy_y += [i * w + j] 28 | vdy += [1] 29 | indgy_x += [i * w + j] 30 | indgy_y += [i * w + j + 1] 31 | vdy += [-1] 32 | Ix = scipy.sparse.coo_array( 33 | (np.ones(h * w), (np.arange(h * w), np.arange(h * w))), 34 | shape=(h * w, h * w)).tocsc() 35 | Gx = scipy.sparse.coo_array( 36 | (np.array(vdx), (np.array(indgx_x), np.array(indgx_y))), 37 | shape=(h * w, h * w)).tocsc() 38 | Gy = scipy.sparse.coo_array( 39 | (np.array(vdy), (np.array(indgy_x), np.array(indgy_y))), 40 | shape=(h * w, h * w)).tocsc() 41 | As = [] 42 | for i in range(3): 43 | As += [ 44 | scipy.sparse.vstack([Gx * grad_weight[i], Gy * grad_weight[i], Ix]) 45 | ] 46 | return As 47 | 48 | 49 | # blendI, I1, I2, mask should be RGB unit8 type 50 | # return poissson fusion result (RGB unit8 type) 51 | # I1 and I2: propagated results from previous and subsequent key frames 52 | # mask: pixel selection mask 53 | # blendI: contrastive-preserving blending results of I1 and I2 54 | def poisson_fusion(blendI, I1, I2, mask, grad_weight=[2.5, 0.5, 0.5]): 55 | global As 56 | global prev_states 57 | 58 | Iab = cv2.cvtColor(blendI, cv2.COLOR_BGR2LAB).astype(float) 59 | Ia = cv2.cvtColor(I1, cv2.COLOR_BGR2LAB).astype(float) 60 | Ib = cv2.cvtColor(I2, cv2.COLOR_BGR2LAB).astype(float) 61 | m = (mask > 0).astype(float)[:, :, np.newaxis] 62 | h, w, c = Iab.shape 63 | 64 | # fuse the gradient of I1 and I2 with mask 65 | gx = np.zeros_like(Ia) 66 | gy = np.zeros_like(Ia) 67 | gx[:-1, :, :] = (Ia[:-1, :, :] - Ia[1:, :, :]) * (1 - m[:-1, :, :]) + ( 68 | Ib[:-1, :, :] - Ib[1:, :, :]) * m[:-1, :, :] 69 | gy[:, :-1, :] = (Ia[:, :-1, :] - Ia[:, 1:, :]) * (1 - m[:, :-1, :]) + ( 70 | Ib[:, :-1, :] - Ib[:, 1:, :]) 
* m[:, :-1, :] 71 | 72 | # construct A for solving Ax=b 73 | crt_states = (h, w, grad_weight) 74 | if As is None or crt_states != prev_states: 75 | As = construct_A(*crt_states) 76 | prev_states = crt_states 77 | 78 | final = [] 79 | for i in range(3): 80 | weight = grad_weight[i] 81 | im_dx = np.clip(gx[:, :, i].reshape(h * w, 1), -100, 100) 82 | im_dy = np.clip(gy[:, :, i].reshape(h * w, 1), -100, 100) 83 | im = Iab[:, :, i].reshape(h * w, 1) 84 | im_mean = im.mean() 85 | im = im - im_mean 86 | A = As[i] 87 | b = np.vstack([im_dx * weight, im_dy * weight, im]) 88 | out = scipy.sparse.linalg.lsqr(A, b) 89 | out_im = (out[0] + im_mean).reshape(h, w, 1) 90 | final += [out_im] 91 | 92 | final = np.clip(np.concatenate(final, axis=2), 0, 255) 93 | return cv2.cvtColor(final.astype(np.uint8), cv2.COLOR_LAB2BGR) 94 | -------------------------------------------------------------------------------- /blender/video_sequence.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class VideoSequence: 6 | 7 | def __init__(self, 8 | base_dir, 9 | beg_frame, 10 | end_frame, 11 | interval, 12 | input_subdir='videos', 13 | key_subdir='keys0', 14 | tmp_subdir='tmp', 15 | input_format='frame%04d.jpg', 16 | key_format='%04d.jpg', 17 | out_subdir_format='out_%d', 18 | blending_out_subdir='blend', 19 | output_format='%04d.jpg'): 20 | if (end_frame - beg_frame) % interval != 0: 21 | end_frame -= (end_frame - beg_frame) % interval 22 | 23 | self.__base_dir = base_dir 24 | self.__input_dir = os.path.join(base_dir, input_subdir) 25 | self.__key_dir = os.path.join(base_dir, key_subdir) 26 | self.__tmp_dir = os.path.join(base_dir, tmp_subdir) 27 | self.__input_format = input_format 28 | self.__blending_out_dir = os.path.join(base_dir, blending_out_subdir) 29 | self.__key_format = key_format 30 | self.__out_subdir_format = out_subdir_format 31 | self.__output_format = output_format 32 | self.__beg_frame = beg_frame 33 | self.__end_frame = end_frame 34 | self.__interval = interval 35 | self.__n_seq = (end_frame - beg_frame) // interval 36 | self.__make_out_dirs() 37 | os.makedirs(self.__tmp_dir, exist_ok=True) 38 | 39 | @property 40 | def beg_frame(self): 41 | return self.__beg_frame 42 | 43 | @property 44 | def end_frame(self): 45 | return self.__end_frame 46 | 47 | @property 48 | def n_seq(self): 49 | return self.__n_seq 50 | 51 | @property 52 | def interval(self): 53 | return self.__interval 54 | 55 | @property 56 | def blending_dir(self): 57 | return os.path.abspath(self.__blending_out_dir) 58 | 59 | def remove_out_and_tmp(self): 60 | for i in range(self.n_seq + 1): 61 | out_dir = self.__get_out_subdir(i) 62 | shutil.rmtree(out_dir) 63 | shutil.rmtree(self.__tmp_dir) 64 | 65 | def get_input_sequence(self, i, is_forward=True): 66 | beg_id = self.get_sequence_beg_id(i) 67 | end_id = self.get_sequence_beg_id(i + 1) 68 | if is_forward: 69 | id_list = list(range(beg_id, end_id)) 70 | else: 71 | id_list = list(range(end_id, beg_id, -1)) 72 | path_dir = [ 73 | os.path.join(self.__input_dir, self.__input_format % id) 74 | for id in id_list 75 | ] 76 | return path_dir 77 | 78 | def get_output_sequence(self, i, is_forward=True): 79 | beg_id = self.get_sequence_beg_id(i) 80 | end_id = self.get_sequence_beg_id(i + 1) 81 | if is_forward: 82 | id_list = list(range(beg_id, end_id)) 83 | else: 84 | i += 1 85 | id_list = list(range(end_id, beg_id, -1)) 86 | out_subdir = self.__get_out_subdir(i) 87 | path_dir = [ 88 | os.path.join(out_subdir, 
self.__output_format % id) 89 | for id in id_list 90 | ] 91 | return path_dir 92 | 93 | def get_temporal_sequence(self, i, is_forward=True): 94 | beg_id = self.get_sequence_beg_id(i) 95 | end_id = self.get_sequence_beg_id(i + 1) 96 | if is_forward: 97 | id_list = list(range(beg_id, end_id)) 98 | else: 99 | i += 1 100 | id_list = list(range(end_id, beg_id, -1)) 101 | tmp_dir = self.__get_tmp_out_subdir(i) 102 | path_dir = [ 103 | os.path.join(tmp_dir, 'temporal_' + self.__output_format % id) 104 | for id in id_list 105 | ] 106 | return path_dir 107 | 108 | def get_edge_sequence(self, i, is_forward=True): 109 | beg_id = self.get_sequence_beg_id(i) 110 | end_id = self.get_sequence_beg_id(i + 1) 111 | if is_forward: 112 | id_list = list(range(beg_id, end_id)) 113 | else: 114 | i += 1 115 | id_list = list(range(end_id, beg_id, -1)) 116 | tmp_dir = self.__get_tmp_out_subdir(i) 117 | path_dir = [ 118 | os.path.join(tmp_dir, 'edge_' + self.__output_format % id) 119 | for id in id_list 120 | ] 121 | return path_dir 122 | 123 | def get_pos_sequence(self, i, is_forward=True): 124 | beg_id = self.get_sequence_beg_id(i) 125 | end_id = self.get_sequence_beg_id(i + 1) 126 | if is_forward: 127 | id_list = list(range(beg_id, end_id)) 128 | else: 129 | i += 1 130 | id_list = list(range(end_id, beg_id, -1)) 131 | tmp_dir = self.__get_tmp_out_subdir(i) 132 | path_dir = [ 133 | os.path.join(tmp_dir, 'pos_' + self.__output_format % id) 134 | for id in id_list 135 | ] 136 | return path_dir 137 | 138 | def get_flow_sequence(self, i, is_forward=True): 139 | beg_id = self.get_sequence_beg_id(i) 140 | end_id = self.get_sequence_beg_id(i + 1) 141 | if is_forward: 142 | id_list = list(range(beg_id, end_id - 1)) 143 | path_dir = [ 144 | os.path.join(self.__tmp_dir, 'flow_f_%04d.npy' % id) 145 | for id in id_list 146 | ] 147 | else: 148 | id_list = list(range(end_id, beg_id + 1, -1)) 149 | path_dir = [ 150 | os.path.join(self.__tmp_dir, 'flow_b_%04d.npy' % id) 151 | for id in id_list 152 | ] 153 | 154 | return path_dir 155 | 156 | def get_input_img(self, i): 157 | return os.path.join(self.__input_dir, self.__input_format % i) 158 | 159 | def get_key_img(self, i): 160 | sequence_beg_id = self.get_sequence_beg_id(i) 161 | return os.path.join(self.__key_dir, 162 | self.__key_format % sequence_beg_id) 163 | 164 | def get_blending_img(self, i): 165 | return os.path.join(self.__blending_out_dir, self.__output_format % i) 166 | 167 | def get_sequence_beg_id(self, i): 168 | return i * self.__interval + self.__beg_frame 169 | 170 | def __get_out_subdir(self, i): 171 | dir_id = self.get_sequence_beg_id(i) 172 | out_subdir = os.path.join(self.__base_dir, 173 | self.__out_subdir_format % dir_id) 174 | return out_subdir 175 | 176 | def __get_tmp_out_subdir(self, i): 177 | dir_id = self.get_sequence_beg_id(i) 178 | tmp_out_subdir = os.path.join(self.__tmp_dir, 179 | self.__out_subdir_format % dir_id) 180 | return tmp_out_subdir 181 | 182 | def __make_out_dirs(self): 183 | os.makedirs(self.__base_dir, exist_ok=True) 184 | os.makedirs(self.__blending_out_dir, exist_ok=True) 185 | for i in range(self.__n_seq + 1): 186 | out_subdir = self.__get_out_subdir(i) 187 | tmp_subdir = self.__get_tmp_out_subdir(i) 188 | os.makedirs(out_subdir, exist_ok=True) 189 | os.makedirs(tmp_subdir, exist_ok=True) 190 | -------------------------------------------------------------------------------- /config/real2sculpture.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": 
"videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4", 3 | "output": "videos/pexels-cottonbro-studio-6649832-960x506-25fps/blend.mp4", 4 | "work_dir": "videos/pexels-cottonbro-studio-6649832-960x506-25fps", 5 | "key_subdir": "keys", 6 | "sd_model": "models/realisticVisionV20_v20.safetensors", 7 | "frame_count": 102, 8 | "interval": 10, 9 | "crop": [ 10 | 0, 11 | 180, 12 | 0, 13 | 0 14 | ], 15 | "prompt": "white ancient Greek sculpture, Venus de Milo, light pink and blue background", 16 | "a_prompt": "RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3", 17 | "n_prompt": "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation", 18 | "x0_strength": 0.95, 19 | "control_type": "canny", 20 | "canny_low": 50, 21 | "canny_high": 100, 22 | "control_strength": 0.7, 23 | "seed": 0, 24 | "warp_period": [ 25 | 0, 26 | 0.1 27 | ], 28 | "ada_period": [ 29 | 0.8, 30 | 1 31 | ] 32 | } -------------------------------------------------------------------------------- /config/real2sculpture_freeu.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": "videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4", 3 | "output": "videos/pexels-cottonbro-studio-6649832-960x506-25fps/blend.mp4", 4 | "work_dir": "videos/pexels-cottonbro-studio-6649832-960x506-25fps", 5 | "key_subdir": "keys", 6 | "sd_model": "models/realisticVisionV20_v20.safetensors", 7 | "frame_count": 102, 8 | "interval": 10, 9 | "crop": [ 10 | 0, 11 | 180, 12 | 0, 13 | 0 14 | ], 15 | "prompt": "white ancient Greek sculpture, Venus de Milo, light pink and blue background", 16 | "a_prompt": "RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3", 17 | "n_prompt": "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation", 18 | "x0_strength": 0.95, 19 | "control_type": "canny", 20 | "canny_low": 50, 21 | "canny_high": 100, 22 | "control_strength": 0.7, 23 | "seed": 0, 24 | "warp_period": [ 25 | 0, 26 | 0.1 27 | ], 28 | "ada_period": [ 29 | 0.8, 30 | 1 31 | ], 32 | "freeu_args": [ 33 | 1.1, 34 | 1.2, 35 | 1.0, 36 | 0.2 37 | ] 38 | } -------------------------------------------------------------------------------- /config/real2sculpture_loose_cfattn.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": "videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4", 3 | "output": "videos/pexels-cottonbro-studio-6649832-960x506-25fps/blend.mp4", 4 | "work_dir": "videos/pexels-cottonbro-studio-6649832-960x506-25fps", 5 | "key_subdir": "keys", 6 | "sd_model": "models/realisticVisionV20_v20.safetensors", 7 | "frame_count": 102, 8 | "interval": 10, 9 | "crop": [ 10 | 0, 11 | 180, 12 | 0, 13 | 0 14 | ], 15 | "prompt": "white ancient Greek sculpture, Venus de Milo, light pink and blue background", 16 | "a_prompt": "RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, 
Fujifilm XT3", 17 | "n_prompt": "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation", 18 | "x0_strength": 0.95, 19 | "control_type": "canny", 20 | "canny_low": 50, 21 | "canny_high": 100, 22 | "control_strength": 0.7, 23 | "seed": 0, 24 | "warp_period": [ 25 | 0, 26 | 0.1 27 | ], 28 | "ada_period": [ 29 | 0.8, 30 | 1 31 | ], 32 | "loose_cfattn": true 33 | } -------------------------------------------------------------------------------- /config/van_gogh_man.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": "videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4", 3 | "output": "videos/pexels-antoni-shkraba-8048492-540x960-25fps/blend.mp4", 4 | "work_dir": "videos/pexels-antoni-shkraba-8048492-540x960-25fps", 5 | "key_subdir": "keys", 6 | "frame_count": 102, 7 | "interval": 10, 8 | "crop": [ 9 | 0, 10 | 0, 11 | 0, 12 | 280 13 | ], 14 | "prompt": "a handsome man in van gogh painting", 15 | "a_prompt": "best quality, extremely detailed", 16 | "n_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 17 | "x0_strength": 1.05, 18 | "control_type": "canny", 19 | "canny_low": 50, 20 | "canny_high": 100, 21 | "control_strength": 0.7, 22 | "seed": 0, 23 | "warp_period": [ 24 | 0, 25 | 0.1 26 | ], 27 | "ada_period": [ 28 | 0.8, 29 | 1 30 | ], 31 | "image_resolution": 512 32 | } 33 | -------------------------------------------------------------------------------- /config/van_gogh_man_dynamic_resolution.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": "videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4", 3 | "output": "videos/pexels-antoni-shkraba-8048492-540x960-25fps/blend.mp4", 4 | "work_dir": "videos/pexels-antoni-shkraba-8048492-540x960-25fps", 5 | "key_subdir": "keys", 6 | "frame_count": 102, 7 | "interval": 10, 8 | "crop": [ 9 | 0, 10 | 0, 11 | 0, 12 | 280 13 | ], 14 | "prompt": "a handsome man in van gogh painting", 15 | "a_prompt": "best quality, extremely detailed", 16 | "n_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 17 | "x0_strength": 1.05, 18 | "control_type": "canny", 19 | "canny_low": 50, 20 | "canny_high": 100, 21 | "control_strength": 0.7, 22 | "seed": 0, 23 | "warp_period": [ 24 | 0, 25 | 0.1 26 | ], 27 | "ada_period": [ 28 | 0.8, 29 | 1 30 | ], 31 | "image_resolution": 512, 32 | "use_limit_device_resolution": true 33 | } -------------------------------------------------------------------------------- /config/woman.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": "videos/pexels-koolshooters-7322716.mp4", 3 | "output": "videos/pexels-koolshooters-7322716/blend.mp4", 4 | "work_dir": "videos/pexels-koolshooters-7322716", 5 | "key_subdir": "keys", 6 | "frame_count": 102, 7 | "interval": 10, 8 | "sd_model": "models/revAnimated_v11.safetensors", 9 | "prompt": "a beautiful woman in CG style", 10 | "a_prompt": "best quality, extremely detailed", 11 | "n_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 12 | 
"x0_strength": 0.75, 13 | "control_type": "canny", 14 | "canny_low": 50, 15 | "canny_high": 100, 16 | "control_strength": 0.7, 17 | "seed": 0, 18 | "warp_period": [ 19 | 0, 20 | 0.1 21 | ], 22 | "ada_period": [ 23 | 1, 24 | 1 25 | ] 26 | } -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: rerender 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - gradio==3.44.4 14 | - albumentations==1.3.0 15 | - opencv-contrib-python==4.3.0.36 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.5.0 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit==1.12.1 22 | - einops==0.3.0 23 | - transformers==4.19.2 24 | - webdataset==0.2.5 25 | - kornia==0.6 26 | - open_clip_torch==2.0.2 27 | - invisible-watermark>=0.1.5 28 | - streamlit-drawable-canvas==0.8.0 29 | - torchmetrics==0.6.0 30 | - timm==0.6.12 31 | - addict==2.4.0 32 | - yapf==0.32.0 33 | - prettytable==3.6.0 34 | - safetensors==0.2.7 35 | - basicsr==1.4.2 36 | - blendmodes 37 | - numba==0.57.0 38 | -------------------------------------------------------------------------------- /flow/flow_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | gmflow_dir = os.path.join(parent_dir, 'deps/gmflow') 11 | sys.path.insert(0, gmflow_dir) 12 | 13 | from gmflow.gmflow import GMFlow # noqa: E702 E402 F401 14 | from utils.utils import InputPadder # noqa: E702 E402 15 | 16 | 17 | def coords_grid(b, h, w, homogeneous=False, device=None): 18 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] 19 | 20 | stacks = [x, y] 21 | 22 | if homogeneous: 23 | ones = torch.ones_like(x) # [H, W] 24 | stacks.append(ones) 25 | 26 | grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] 27 | 28 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] 29 | 30 | if device is not None: 31 | grid = grid.to(device) 32 | 33 | return grid 34 | 35 | 36 | def bilinear_sample(img, 37 | sample_coords, 38 | mode='bilinear', 39 | padding_mode='zeros', 40 | return_mask=False): 41 | # img: [B, C, H, W] 42 | # sample_coords: [B, 2, H, W] in image scale 43 | if sample_coords.size(1) != 2: # [B, H, W, 2] 44 | sample_coords = sample_coords.permute(0, 3, 1, 2) 45 | 46 | b, _, h, w = sample_coords.shape 47 | 48 | # Normalize to [-1, 1] 49 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 50 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 51 | 52 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] 53 | 54 | img = F.grid_sample(img, 55 | grid, 56 | mode=mode, 57 | padding_mode=padding_mode, 58 | align_corners=True) 59 | 60 | if return_mask: 61 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & ( 62 | y_grid <= 1) # [B, H, W] 63 | 64 | return img, mask 65 | 66 | return img 67 | 68 | 69 | def flow_warp(feature, 70 | flow, 71 | mask=False, 72 | mode='bilinear', 73 | padding_mode='zeros'): 74 | b, c, h, w = feature.size() 75 | assert flow.size(1) == 2 76 | 77 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 78 | 79 | return bilinear_sample(feature, 80 | grid, 81 | 
mode=mode, 82 | padding_mode=padding_mode, 83 | return_mask=mask) 84 | 85 | 86 | def forward_backward_consistency_check(fwd_flow, 87 | bwd_flow, 88 | alpha=0.01, 89 | beta=0.5): 90 | # fwd_flow, bwd_flow: [B, 2, H, W] 91 | # alpha and beta values are following UnFlow 92 | # (https://arxiv.org/abs/1711.07837) 93 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 94 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 95 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, 96 | dim=1) # [B, H, W] 97 | 98 | warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] 99 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] 100 | 101 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] 102 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) 103 | 104 | threshold = alpha * flow_mag + beta 105 | 106 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W] 107 | bwd_occ = (diff_bwd > threshold).float() 108 | 109 | return fwd_occ, bwd_occ 110 | 111 | 112 | @torch.no_grad() 113 | def get_warped_and_mask(flow_model, 114 | image1, 115 | image2, 116 | image3=None, 117 | pixel_consistency=False): 118 | if image3 is None: 119 | image3 = image1 120 | padder = InputPadder(image1.shape, padding_factor=8) 121 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 122 | results_dict = flow_model(image1, 123 | image2, 124 | attn_splits_list=[2], 125 | corr_radius_list=[-1], 126 | prop_radius_list=[-1], 127 | pred_bidir_flow=True) 128 | flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] 129 | fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] 130 | bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] 131 | fwd_occ, bwd_occ = forward_backward_consistency_check( 132 | fwd_flow, bwd_flow) # [1, H, W] float 133 | if pixel_consistency: 134 | warped_image1 = flow_warp(image1, bwd_flow) 135 | bwd_occ = torch.clamp( 136 | bwd_occ + 137 | (abs(image2 - warped_image1).mean(dim=1) > 255 * 0.25).float(), 0, 138 | 1).unsqueeze(0) 139 | warped_results = flow_warp(image3, bwd_flow) 140 | return warped_results, bwd_occ, bwd_flow 141 | 142 | 143 | class FlowCalc(): 144 | 145 | def __init__(self, model_path='./models/gmflow_sintel-0c07dcb3.pth'): 146 | flow_model = GMFlow( 147 | feature_channels=128, 148 | num_scales=1, 149 | upsample_factor=8, 150 | num_head=1, 151 | attention_type='swin', 152 | ffn_dim_expansion=4, 153 | num_transformer_layers=6, 154 | ).to('cuda') 155 | 156 | checkpoint = torch.load(model_path, 157 | map_location=lambda storage, loc: storage) 158 | weights = checkpoint['model'] if 'model' in checkpoint else checkpoint 159 | flow_model.load_state_dict(weights, strict=False) 160 | flow_model.eval() 161 | self.model = flow_model 162 | 163 | @torch.no_grad() 164 | def get_flow(self, image1, image2, save_path=None): 165 | 166 | if save_path is not None and os.path.exists(save_path): 167 | bwd_flow = read_flow(save_path) 168 | return bwd_flow 169 | 170 | image1 = torch.from_numpy(image1).permute(2, 0, 1).float() 171 | image2 = torch.from_numpy(image2).permute(2, 0, 1).float() 172 | padder = InputPadder(image1.shape, padding_factor=8) 173 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 174 | results_dict = self.model(image1, 175 | image2, 176 | attn_splits_list=[2], 177 | corr_radius_list=[-1], 178 | prop_radius_list=[-1], 179 | pred_bidir_flow=True) 180 | flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] 181 | fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] 182 | bwd_flow 
= padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] 183 | fwd_occ, bwd_occ = forward_backward_consistency_check( 184 | fwd_flow, bwd_flow) # [1, H, W] float 185 | if save_path is not None: 186 | flow_np = bwd_flow.cpu().numpy() 187 | np.save(save_path, flow_np) 188 | mask_path = os.path.splitext(save_path)[0] + '.png' 189 | bwd_occ = bwd_occ.cpu().permute(1, 2, 0).to( 190 | torch.long).numpy() * 255 191 | cv2.imwrite(mask_path, bwd_occ) 192 | 193 | return bwd_flow 194 | 195 | @torch.no_grad() 196 | def get_mask(self, image1, image2, save_path=None): 197 | 198 | if save_path is not None: 199 | mask_path = os.path.splitext(save_path)[0] + '.png' 200 | if os.path.exists(mask_path): 201 | return read_mask(mask_path) 202 | 203 | image1 = torch.from_numpy(image1).permute(2, 0, 1).float() 204 | image2 = torch.from_numpy(image2).permute(2, 0, 1).float() 205 | padder = InputPadder(image1.shape, padding_factor=8) 206 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 207 | results_dict = self.model(image1, 208 | image2, 209 | attn_splits_list=[2], 210 | corr_radius_list=[-1], 211 | prop_radius_list=[-1], 212 | pred_bidir_flow=True) 213 | flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] 214 | fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] 215 | bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] 216 | fwd_occ, bwd_occ = forward_backward_consistency_check( 217 | fwd_flow, bwd_flow) # [1, H, W] float 218 | if save_path is not None: 219 | flow_np = bwd_flow.cpu().numpy() 220 | np.save(save_path, flow_np) 221 | mask_path = os.path.splitext(save_path)[0] + '.png' 222 | bwd_occ = bwd_occ.cpu().permute(1, 2, 0).to( 223 | torch.long).numpy() * 255 224 | cv2.imwrite(mask_path, bwd_occ) 225 | 226 | return bwd_occ 227 | 228 | def warp(self, img, flow, mode='bilinear'): 229 | expand = False 230 | if len(img.shape) == 2: 231 | expand = True 232 | img = np.expand_dims(img, 2) 233 | 234 | img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) 235 | dtype = img.dtype 236 | img = img.to(torch.float) 237 | res = flow_warp(img, flow, mode=mode) 238 | res = res.to(dtype) 239 | res = res[0].cpu().permute(1, 2, 0).numpy() 240 | if expand: 241 | res = res[:, :, 0] 242 | return res 243 | 244 | 245 | def read_flow(save_path): 246 | flow_np = np.load(save_path) 247 | bwd_flow = torch.from_numpy(flow_np) 248 | return bwd_flow 249 | 250 | 251 | def read_mask(save_path): 252 | mask_path = os.path.splitext(save_path)[0] + '.png' 253 | mask = cv2.imread(mask_path) 254 | mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) 255 | return mask 256 | 257 | 258 | flow_calc = FlowCalc() 259 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | 4 | import requests 5 | 6 | 7 | def build_ebsynth(): 8 | if os.path.exists('deps/ebsynth/bin/ebsynth'): 9 | print('Ebsynth has been built.') 10 | return 11 | 12 | os_str = platform.system() 13 | 14 | if os_str == 'Windows': 15 | print('Build Ebsynth Windows 64 bit.', 16 | 'If you want to build for 32 bit, please modify install.py.') 17 | cmd = '.\\build-win64-cpu+cuda.bat' 18 | exe_file = 'deps/ebsynth/bin/ebsynth.exe' 19 | elif os_str == 'Linux': 20 | cmd = 'bash build-linux-cpu+cuda.sh' 21 | exe_file = 'deps/ebsynth/bin/ebsynth' 22 | elif os_str == 'Darwin': 23 | cmd = 'sh build-macos-cpu_only.sh' 24 | exe_file = 'deps/ebsynth/bin/ebsynth.app' 25 | else: 26 | print('Cannot recognize 
OS. Ebsynth installation stopped.') 27 | return 28 | 29 | os.chdir('deps/ebsynth') 30 | print(cmd) 31 | os.system(cmd) 32 | os.chdir('../..') 33 | if os.path.exists(exe_file): 34 | print('Ebsynth installed successfully.') 35 | else: 36 | print('Failed to install Ebsynth.') 37 | 38 | 39 | def download(url, dir, name=None): 40 | os.makedirs(dir, exist_ok=True) 41 | if name is None: 42 | name = url.split('/')[-1] 43 | path = os.path.join(dir, name) 44 | if not os.path.exists(path): 45 | print(f'Install {name} ...') 46 | open(path, 'wb').write(requests.get(url).content) 47 | print('Install successfully.') 48 | 49 | 50 | def download_gmflow_ckpt(): 51 | url = ('https://huggingface.co/PKUWilliamYang/Rerender/' 52 | 'resolve/main/models/gmflow_sintel-0c07dcb3.pth') 53 | download(url, 'models') 54 | 55 | 56 | def download_controlnet_canny(): 57 | url = ('https://huggingface.co/lllyasviel/ControlNet/' 58 | 'resolve/main/models/control_sd15_canny.pth') 59 | download(url, 'models') 60 | 61 | 62 | def download_controlnet_hed(): 63 | url = ('https://huggingface.co/lllyasviel/ControlNet/' 64 | 'resolve/main/models/control_sd15_hed.pth') 65 | download(url, 'models') 66 | 67 | 68 | def download_vae(): 69 | url = ('https://huggingface.co/stabilityai/sd-vae-ft-mse-original' 70 | '/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt') 71 | download(url, 'models') 72 | 73 | 74 | build_ebsynth() 75 | download_gmflow_ckpt() 76 | download_controlnet_canny() 77 | download_controlnet_hed() 78 | download_vae() 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | albumentations==1.3.0 3 | basicsr==1.4.2 4 | blendmodes 5 | einops==0.3.0 6 | gradio==3.44.4 7 | imageio==2.9.0 8 | imageio-ffmpeg==0.4.2 9 | invisible-watermark==0.1.5 10 | kornia==0.6 11 | numba==0.57.0 12 | omegaconf==2.1.1 13 | open_clip_torch==2.0.2 14 | prettytable==3.6.0 15 | pytorch-lightning==1.5.0 16 | safetensors==0.2.7 17 | streamlit==1.12.1 18 | streamlit-drawable-canvas==0.8.0 19 | test-tube==0.7.5 20 | timm==0.6.12 21 | torchmetrics==0.6.0 22 | transformers==4.19.2 23 | webdataset==0.2.5 24 | yapf==0.32.0 25 | -------------------------------------------------------------------------------- /rerender.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import cv2 6 | import einops 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import torchvision.transforms as T 11 | from blendmodes.blend import BlendType, blendLayers 12 | from PIL import Image 13 | from pytorch_lightning import seed_everything 14 | from safetensors.torch import load_file 15 | from skimage import exposure 16 | 17 | import src.import_util # noqa: F401 18 | from deps.ControlNet.annotator.canny import CannyDetector 19 | from deps.ControlNet.annotator.hed import HEDdetector 20 | from deps.ControlNet.annotator.util import HWC3 21 | from deps.ControlNet.cldm.cldm import ControlLDM 22 | from deps.ControlNet.cldm.model import create_model, load_state_dict 23 | from deps.gmflow.gmflow.gmflow import GMFlow 24 | from flow.flow_utils import get_warped_and_mask 25 | from src.config import RerenderConfig 26 | from src.controller import AttentionControl 27 | from src.ddim_v_hacked import DDIMVSampler 28 | from src.freeu import freeu_forward 29 | from src.img_util import find_flat_region, numpy2tensor 30 | from 
src.video_util import frame_to_video, get_fps, prepare_frames 31 | 32 | blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18)) 33 | totensor = T.PILToTensor() 34 | 35 | 36 | def setup_color_correction(image): 37 | correction_target = cv2.cvtColor(np.asarray(image.copy()), 38 | cv2.COLOR_RGB2LAB) 39 | return correction_target 40 | 41 | 42 | def apply_color_correction(correction, original_image): 43 | image = Image.fromarray( 44 | cv2.cvtColor( 45 | exposure.match_histograms(cv2.cvtColor(np.asarray(original_image), 46 | cv2.COLOR_RGB2LAB), 47 | correction, 48 | channel_axis=2), 49 | cv2.COLOR_LAB2RGB).astype('uint8')) 50 | 51 | image = blendLayers(image, original_image, BlendType.LUMINOSITY) 52 | 53 | return image 54 | 55 | 56 | def rerender(cfg: RerenderConfig, first_img_only: bool, key_video_path: str): 57 | 58 | # Preprocess input 59 | prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution, cfg.crop, cfg.use_limit_device_resolution) 60 | 61 | # Load models 62 | if cfg.control_type == 'HED': 63 | detector = HEDdetector() 64 | elif cfg.control_type == 'canny': 65 | canny_detector = CannyDetector() 66 | low_threshold = cfg.canny_low 67 | high_threshold = cfg.canny_high 68 | 69 | def apply_canny(x): 70 | return canny_detector(x, low_threshold, high_threshold) 71 | 72 | detector = apply_canny 73 | 74 | model: ControlLDM = create_model( 75 | './deps/ControlNet/models/cldm_v15.yaml').cpu() 76 | if cfg.control_type == 'HED': 77 | model.load_state_dict( 78 | load_state_dict('./models/control_sd15_hed.pth', location='cuda')) 79 | elif cfg.control_type == 'canny': 80 | model.load_state_dict( 81 | load_state_dict('./models/control_sd15_canny.pth', 82 | location='cuda')) 83 | model = model.cuda() 84 | model.control_scales = [cfg.control_strength] * 13 85 | 86 | if cfg.sd_model is not None: 87 | model_ext = os.path.splitext(cfg.sd_model)[1] 88 | if model_ext == '.safetensors': 89 | model.load_state_dict(load_file(cfg.sd_model), strict=False) 90 | elif model_ext == '.ckpt' or model_ext == '.pth': 91 | model.load_state_dict(torch.load(cfg.sd_model)['state_dict'], 92 | strict=False) 93 | 94 | try: 95 | model.first_stage_model.load_state_dict(torch.load( 96 | './models/vae-ft-mse-840000-ema-pruned.ckpt')['state_dict'], 97 | strict=False) 98 | except Exception: 99 | print('Warning: We suggest you download the fine-tuned VAE', 100 | 'otherwise the generation quality will be degraded') 101 | 102 | model.model.diffusion_model.forward = \ 103 | freeu_forward(model.model.diffusion_model, *cfg.freeu_args) 104 | ddim_v_sampler = DDIMVSampler(model) 105 | 106 | flow_model = GMFlow( 107 | feature_channels=128, 108 | num_scales=1, 109 | upsample_factor=8, 110 | num_head=1, 111 | attention_type='swin', 112 | ffn_dim_expansion=4, 113 | num_transformer_layers=6, 114 | ).to('cuda') 115 | 116 | checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth', 117 | map_location=lambda storage, loc: storage) 118 | weights = checkpoint['model'] if 'model' in checkpoint else checkpoint 119 | flow_model.load_state_dict(weights, strict=False) 120 | flow_model.eval() 121 | 122 | num_samples = 1 123 | ddim_steps = 20 124 | scale = 7.5 125 | 126 | seed = cfg.seed 127 | if seed == -1: 128 | seed = random.randint(0, 65535) 129 | eta = 0.0 130 | 131 | prompt = cfg.prompt 132 | a_prompt = cfg.a_prompt 133 | n_prompt = cfg.n_prompt 134 | prompt = prompt + ', ' + a_prompt 135 | 136 | style_update_freq = cfg.style_update_freq 137 | pixelfusion = True 138 | color_preserve = cfg.color_preserve 139 | 140 | x0_strength = 1 - 
cfg.x0_strength 141 | mask_period = cfg.mask_period 142 | firstx0 = True 143 | controller = AttentionControl(cfg.inner_strength, cfg.mask_period, 144 | cfg.cross_period, cfg.ada_period, 145 | cfg.warp_period, cfg.loose_cfattn) 146 | 147 | imgs = sorted(os.listdir(cfg.input_dir)) 148 | imgs = [os.path.join(cfg.input_dir, img) for img in imgs] 149 | if cfg.frame_count >= 0: 150 | imgs = imgs[:cfg.frame_count] 151 | 152 | with torch.no_grad(): 153 | frame = cv2.imread(imgs[0]) 154 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 155 | img = HWC3(frame) 156 | H, W, C = img.shape 157 | 158 | img_ = numpy2tensor(img) 159 | # if color_preserve: 160 | # img_ = numpy2tensor(img) 161 | # else: 162 | # img_ = apply_color_correction(color_corrections, 163 | # Image.fromarray(img)) 164 | # img_ = totensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1 165 | encoder_posterior = model.encode_first_stage(img_.cuda()) 166 | x0 = model.get_first_stage_encoding(encoder_posterior).detach() 167 | 168 | detected_map = detector(img) 169 | detected_map = HWC3(detected_map) 170 | # For visualization 171 | detected_img = 255 - detected_map 172 | 173 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 174 | control = torch.stack([control for _ in range(num_samples)], dim=0) 175 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 176 | cond = { 177 | 'c_concat': [control], 178 | 'c_crossattn': 179 | [model.get_learned_conditioning([prompt] * num_samples)] 180 | } 181 | un_cond = { 182 | 'c_concat': [control], 183 | 'c_crossattn': 184 | [model.get_learned_conditioning([n_prompt] * num_samples)] 185 | } 186 | shape = (4, H // 8, W // 8) 187 | 188 | controller.set_task('initfirst') 189 | seed_everything(seed) 190 | samples, _ = ddim_v_sampler.sample(ddim_steps, 191 | num_samples, 192 | shape, 193 | cond, 194 | verbose=False, 195 | eta=eta, 196 | unconditional_guidance_scale=scale, 197 | unconditional_conditioning=un_cond, 198 | controller=controller, 199 | x0=x0, 200 | strength=x0_strength) 201 | x_samples = model.decode_first_stage(samples) 202 | pre_result = x_samples 203 | pre_img = img 204 | first_result = pre_result 205 | first_img = pre_img 206 | 207 | x_samples = ( 208 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 209 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 210 | color_corrections = setup_color_correction(Image.fromarray(x_samples[0])) 211 | Image.fromarray(x_samples[0]).save(os.path.join(cfg.first_dir, 212 | 'first.jpg')) 213 | cv2.imwrite(os.path.join(cfg.first_dir, 'first_edge.jpg'), detected_img) 214 | 215 | if first_img_only: 216 | exit(0) 217 | 218 | for i in range(0, min(len(imgs), cfg.frame_count) - 1, cfg.interval): 219 | cid = i + 1 220 | print(cid) 221 | if cid <= (len(imgs) - 1): 222 | frame = cv2.imread(imgs[cid]) 223 | else: 224 | frame = cv2.imread(imgs[len(imgs) - 1]) 225 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 226 | img = HWC3(frame) 227 | 228 | if color_preserve: 229 | img_ = numpy2tensor(img) 230 | else: 231 | img_ = apply_color_correction(color_corrections, 232 | Image.fromarray(img)) 233 | img_ = totensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1 234 | encoder_posterior = model.encode_first_stage(img_.cuda()) 235 | x0 = model.get_first_stage_encoding(encoder_posterior).detach() 236 | 237 | detected_map = detector(img) 238 | detected_map = HWC3(detected_map) 239 | 240 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 241 | control = torch.stack([control for _ in range(num_samples)], dim=0) 242 | control = 
einops.rearrange(control, 'b h w c -> b c h w').clone() 243 | cond['c_concat'] = [control] 244 | un_cond['c_concat'] = [control] 245 | 246 | image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float() 247 | image2 = torch.from_numpy(img).permute(2, 0, 1).float() 248 | warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask( 249 | flow_model, image1, image2, pre_result, False) 250 | blend_mask_pre = blur( 251 | F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4)) 252 | blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1) 253 | 254 | image1 = torch.from_numpy(first_img).permute(2, 0, 1).float() 255 | warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask( 256 | flow_model, image1, image2, first_result, False) 257 | blend_mask_0 = blur( 258 | F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4)) 259 | blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1) 260 | 261 | if firstx0: 262 | mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8) 263 | controller.set_warp( 264 | F.interpolate(bwd_flow_0 / 8.0, 265 | scale_factor=1. / 8, 266 | mode='bilinear'), mask) 267 | else: 268 | mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8) 269 | controller.set_warp( 270 | F.interpolate(bwd_flow_pre / 8.0, 271 | scale_factor=1. / 8, 272 | mode='bilinear'), mask) 273 | 274 | controller.set_task('keepx0, keepstyle') 275 | seed_everything(seed) 276 | samples, intermediates = ddim_v_sampler.sample( 277 | ddim_steps, 278 | num_samples, 279 | shape, 280 | cond, 281 | verbose=False, 282 | eta=eta, 283 | unconditional_guidance_scale=scale, 284 | unconditional_conditioning=un_cond, 285 | controller=controller, 286 | x0=x0, 287 | strength=x0_strength) 288 | direct_result = model.decode_first_stage(samples) 289 | 290 | if not pixelfusion: 291 | pre_result = direct_result 292 | pre_img = img 293 | viz = ( 294 | einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 + 295 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 296 | 297 | else: 298 | 299 | blend_results = (1 - blend_mask_pre 300 | ) * warped_pre + blend_mask_pre * direct_result 301 | blend_results = ( 302 | 1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results 303 | 304 | bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1) 305 | blend_mask = blur( 306 | F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4)) 307 | blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1) 308 | 309 | encoder_posterior = model.encode_first_stage(blend_results) 310 | xtrg = model.get_first_stage_encoding( 311 | encoder_posterior).detach() # * mask 312 | blend_results_rec = model.decode_first_stage(xtrg) 313 | encoder_posterior = model.encode_first_stage(blend_results_rec) 314 | xtrg_rec = model.get_first_stage_encoding( 315 | encoder_posterior).detach() 316 | xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec)) # * mask 317 | blend_results_rec_new = model.decode_first_stage(xtrg_) 318 | tmp = (abs(blend_results_rec_new - blend_results).mean( 319 | dim=1, keepdims=True) > 0.25).float() 320 | mask_x = F.max_pool2d((F.interpolate( 321 | tmp, scale_factor=1 / 8., mode='bilinear') > 0).float(), 322 | kernel_size=3, 323 | stride=1, 324 | padding=1) 325 | 326 | mask = (1 - F.max_pool2d(1 - blend_mask, kernel_size=8) 327 | ) # * (1-mask_x) 328 | 329 | if cfg.smooth_boundary: 330 | noise_rescale = find_flat_region(mask) 331 | else: 332 | noise_rescale = torch.ones_like(mask) 333 | masks = [] 334 | for j in range(ddim_steps): 335 | if j <= ddim_steps * mask_period[ 336 | 0] or j >= ddim_steps * mask_period[1]: 337 | masks += [None] 
338 | else: 339 | masks += [mask * cfg.mask_strength] 340 | 341 | # mask 3 342 | # xtrg = ((1-mask_x) * 343 | # (xtrg + xtrg - xtrg_rec) + mask_x * samples) * mask 344 | # mask 2 345 | # xtrg = (xtrg + 1 * (xtrg - xtrg_rec)) * mask 346 | xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask # mask 1 347 | 348 | tasks = 'keepstyle, keepx0' 349 | if not firstx0: 350 | tasks += ', updatex0' 351 | if i % style_update_freq == 0: 352 | tasks += ', updatestyle' 353 | controller.set_task(tasks, 1.0) 354 | 355 | seed_everything(seed) 356 | samples, _ = ddim_v_sampler.sample( 357 | ddim_steps, 358 | num_samples, 359 | shape, 360 | cond, 361 | verbose=False, 362 | eta=eta, 363 | unconditional_guidance_scale=scale, 364 | unconditional_conditioning=un_cond, 365 | controller=controller, 366 | x0=x0, 367 | strength=x0_strength, 368 | xtrg=xtrg, 369 | mask=masks, 370 | noise_rescale=noise_rescale) 371 | x_samples = model.decode_first_stage(samples) 372 | pre_result = x_samples 373 | pre_img = img 374 | 375 | viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 376 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 377 | 378 | Image.fromarray(viz[0]).save( 379 | os.path.join(cfg.key_dir, f'{cid:04d}.png')) 380 | if key_video_path is not None: 381 | fps = get_fps(cfg.input_path) 382 | fps //= cfg.interval 383 | frame_to_video(key_video_path, cfg.key_dir, fps, False) 384 | 385 | 386 | def postprocess(cfg: RerenderConfig, ne: bool, max_process: int, tmp: bool, 387 | ps: bool): 388 | video_base_dir = cfg.work_dir 389 | o_video = cfg.output_path 390 | fps = get_fps(cfg.input_path) 391 | 392 | end_frame = cfg.frame_count - 1 393 | interval = cfg.interval 394 | key_dir = os.path.split(cfg.key_dir)[-1] 395 | use_e = '-ne' if ne else '' 396 | use_tmp = '-tmp' if tmp else '' 397 | use_ps = '-ps' if ps else '' 398 | o_video_cmd = f'--output {o_video}' 399 | 400 | cmd = ( 401 | f'python video_blend.py {video_base_dir} --beg 1 --end {end_frame} ' 402 | f'--itv {interval} --key {key_dir} {use_e} {o_video_cmd} --fps {fps} ' 403 | f'--n_proc {max_process} {use_tmp} {use_ps}') 404 | print(cmd) 405 | os.system(cmd) 406 | 407 | 408 | if __name__ == '__main__': 409 | parser = argparse.ArgumentParser() 410 | parser.add_argument('--cfg', type=str, default=None) 411 | parser.add_argument('--input', 412 | type=str, 413 | default=None, 414 | help='The input path to video.') 415 | parser.add_argument('--output', type=str, default=None) 416 | parser.add_argument('--prompt', type=str, default=None) 417 | parser.add_argument('--key_video_path', type=str, default=None) 418 | parser.add_argument('-one', 419 | action='store_true', 420 | help='Run the first frame with ControlNet only') 421 | parser.add_argument('-nr', 422 | action='store_true', 423 | help='Do not run rerender and do postprocessing only') 424 | parser.add_argument('-nb', 425 | action='store_true', 426 | help='Do not run postprocessing and run rerender only') 427 | parser.add_argument( 428 | '-ne', 429 | action='store_true', 430 | help='Do not run ebsynth (use previous ebsynth temporary output)') 431 | parser.add_argument('-nps', 432 | action='store_true', 433 | help='Do not run poisson gradient blending') 434 | parser.add_argument('--n_proc', 435 | type=int, 436 | default=4, 437 | help='The max process count') 438 | parser.add_argument('--tmp', 439 | action='store_true', 440 | help='Keep ebsynth temporary output') 441 | 442 | args = parser.parse_args() 443 | 444 | cfg = RerenderConfig() 445 | if args.cfg is not None: 446 | cfg.create_from_path(args.cfg) 
447 | if args.input is not None: 448 | print('Config has been loaded. --input is ignored.') 449 | if args.output is not None: 450 | print('Config has been loaded. --output is ignored.') 451 | if args.prompt is not None: 452 | print('Config has been loaded. --prompt is ignored.') 453 | else: 454 | if args.input is None: 455 | print('Config not found. --input is required.') 456 | exit(0) 457 | if args.output is None: 458 | print('Config not found. --output is required.') 459 | exit(0) 460 | if args.prompt is None: 461 | print('Config not found. --prompt is required.') 462 | exit(0) 463 | cfg.create_from_parameters(args.input, args.output, args.prompt) 464 | 465 | if not args.nr: 466 | rerender(cfg, args.one, args.key_video_path) 467 | torch.cuda.empty_cache() 468 | if not args.nb: 469 | postprocess(cfg, args.ne, args.n_proc, args.tmp, not args.nps) 470 | -------------------------------------------------------------------------------- /sd_model_cfg.py: -------------------------------------------------------------------------------- 1 | # The model dict is used for webUI only 2 | 3 | model_dict = { 4 | 'Stable Diffusion 1.5': '', 5 | 'revAnimated_v11': 'models/revAnimated_v11.safetensors', 6 | 'realisticVisionV20_v20': 'models/realisticVisionV20_v20.safetensors' 7 | } 8 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Optional, Sequence, Tuple 4 | 5 | from src.video_util import get_frame_count 6 | 7 | 8 | class RerenderConfig: 9 | 10 | def __init__(self): 11 | ... 12 | 13 | def create_from_parameters(self, 14 | input_path: str, 15 | output_path: str, 16 | prompt: str, 17 | work_dir: Optional[str] = None, 18 | key_subdir: str = 'keys', 19 | frame_count: Optional[int] = None, 20 | interval: int = 10, 21 | crop: Sequence[int] = (0, 0, 0, 0), 22 | sd_model: Optional[str] = None, 23 | a_prompt: str = '', 24 | n_prompt: str = '', 25 | ddim_steps=20, 26 | scale=7.5, 27 | control_type: str = 'HED', 28 | control_strength=1, 29 | seed: int = -1, 30 | image_resolution: int = 512, 31 | use_limit_device_resolution: bool = False, 32 | x0_strength: float = -1, 33 | style_update_freq: int = 10, 34 | cross_period: Tuple[float, float] = (0, 1), 35 | warp_period: Tuple[float, float] = (0, 0.1), 36 | mask_period: Tuple[float, float] = (0.5, 0.8), 37 | ada_period: Tuple[float, float] = (1.0, 1.0), 38 | mask_strength: float = 0.5, 39 | inner_strength: float = 0.9, 40 | smooth_boundary: bool = True, 41 | color_preserve: bool = True, 42 | loose_cfattn: bool = False, 43 | freeu_args: Tuple[int] = (1, 1, 1, 1), 44 | **kwargs): 45 | self.input_path = input_path 46 | self.output_path = output_path 47 | self.prompt = prompt 48 | self.work_dir = work_dir 49 | if work_dir is None: 50 | self.work_dir = os.path.dirname(output_path) 51 | self.key_dir = os.path.join(self.work_dir, key_subdir) 52 | self.first_dir = os.path.join(self.work_dir, 'first') 53 | 54 | # Split video into frames 55 | if not os.path.isfile(input_path): 56 | raise FileNotFoundError(f'Cannot find video file {input_path}') 57 | self.input_dir = os.path.join(self.work_dir, 'video') 58 | 59 | self.frame_count = frame_count 60 | if frame_count is None: 61 | self.frame_count = get_frame_count(self.input_path) 62 | self.interval = interval 63 | self.crop = crop 64 | self.sd_model = sd_model 65 | self.a_prompt = a_prompt 66 | self.n_prompt = n_prompt 67 | self.ddim_steps = ddim_steps 
68 | self.scale = scale 69 | self.control_type = control_type 70 | if self.control_type == 'canny': 71 | self.canny_low = kwargs.get('canny_low', 100) 72 | self.canny_high = kwargs.get('canny_high', 200) 73 | else: 74 | self.canny_low = None 75 | self.canny_high = None 76 | self.control_strength = control_strength 77 | self.seed = seed 78 | self.image_resolution = image_resolution 79 | self.use_limit_device_resolution = use_limit_device_resolution 80 | self.x0_strength = x0_strength 81 | self.style_update_freq = style_update_freq 82 | self.cross_period = cross_period 83 | self.mask_period = mask_period 84 | self.warp_period = warp_period 85 | self.ada_period = ada_period 86 | self.mask_strength = mask_strength 87 | self.inner_strength = inner_strength 88 | self.smooth_boundary = smooth_boundary 89 | self.color_preserve = color_preserve 90 | self.loose_cfattn = loose_cfattn 91 | self.freeu_args = freeu_args 92 | 93 | os.makedirs(self.input_dir, exist_ok=True) 94 | os.makedirs(self.work_dir, exist_ok=True) 95 | os.makedirs(self.key_dir, exist_ok=True) 96 | os.makedirs(self.first_dir, exist_ok=True) 97 | 98 | def create_from_path(self, cfg_path: str): 99 | with open(cfg_path, 'r') as fp: 100 | cfg = json.load(fp) 101 | kwargs = dict() 102 | 103 | def append_if_not_none(key): 104 | value = cfg.get(key, None) 105 | if value is not None: 106 | kwargs[key] = value 107 | 108 | kwargs['input_path'] = cfg['input'] 109 | kwargs['output_path'] = cfg['output'] 110 | kwargs['prompt'] = cfg['prompt'] 111 | append_if_not_none('work_dir') 112 | append_if_not_none('key_subdir') 113 | append_if_not_none('frame_count') 114 | append_if_not_none('interval') 115 | append_if_not_none('crop') 116 | append_if_not_none('sd_model') 117 | append_if_not_none('a_prompt') 118 | append_if_not_none('n_prompt') 119 | append_if_not_none('ddim_steps') 120 | append_if_not_none('scale') 121 | append_if_not_none('control_type') 122 | if kwargs.get('control_type', '') == 'canny': 123 | append_if_not_none('canny_low') 124 | append_if_not_none('canny_high') 125 | append_if_not_none('control_strength') 126 | append_if_not_none('seed') 127 | append_if_not_none('image_resolution') 128 | append_if_not_none('use_limit_device_resolution') 129 | append_if_not_none('x0_strength') 130 | append_if_not_none('style_update_freq') 131 | append_if_not_none('cross_period') 132 | append_if_not_none('warp_period') 133 | append_if_not_none('mask_period') 134 | append_if_not_none('ada_period') 135 | append_if_not_none('mask_strength') 136 | append_if_not_none('inner_strength') 137 | append_if_not_none('smooth_boundary') 138 | append_if_not_none('color_preserve') 139 | append_if_not_none('loose_cfattn') 140 | append_if_not_none('freeu_args') 141 | self.create_from_parameters(**kwargs) 142 | 143 | @property 144 | def use_warp(self): 145 | return self.warp_period[0] <= self.warp_period[1] 146 | 147 | @property 148 | def use_mask(self): 149 | return self.mask_period[0] <= self.mask_period[1] 150 | 151 | @property 152 | def use_ada(self): 153 | return self.ada_period[0] <= self.ada_period[1] 154 | -------------------------------------------------------------------------------- /src/controller.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from flow.flow_utils import flow_warp 7 | 8 | # AdaIn 9 | 10 | 11 | def calc_mean_std(feat, eps=1e-5): 12 | # eps is a small value added to the variance to avoid divide-by-zero.
13 | size = feat.size() 14 | assert (len(size) == 4) 15 | N, C = size[:2] 16 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 17 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 18 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 19 | return feat_mean, feat_std 20 | 21 | 22 | class AttentionControl(): 23 | 24 | def __init__(self, 25 | inner_strength, 26 | mask_period, 27 | cross_period, 28 | ada_period, 29 | warp_period, 30 | loose_cfatnn=False): 31 | self.step_store = self.get_empty_store() 32 | self.cur_step = 0 33 | self.total_step = 0 34 | self.cur_index = 0 35 | self.init_store = False 36 | self.restore = False 37 | self.update = False 38 | self.flow = None 39 | self.mask = None 40 | self.restorex0 = False 41 | self.updatex0 = False 42 | self.inner_strength = inner_strength 43 | self.cross_period = cross_period 44 | self.mask_period = mask_period 45 | self.ada_period = ada_period 46 | self.warp_period = warp_period 47 | self.up_resolution = 1280 if loose_cfatnn else 1281 48 | 49 | @staticmethod 50 | def get_empty_store(): 51 | return { 52 | 'first': [], 53 | 'previous': [], 54 | 'x0_previous': [], 55 | 'first_ada': [] 56 | } 57 | 58 | def forward(self, context, is_cross: bool, place_in_unet: str): 59 | cross_period = (self.total_step * self.cross_period[0], 60 | self.total_step * self.cross_period[1]) 61 | if not is_cross and place_in_unet == 'up' and context.shape[ 62 | 2] < self.up_resolution: 63 | if self.init_store: 64 | self.step_store['first'].append(context.detach()) 65 | self.step_store['previous'].append(context.detach()) 66 | if self.update: 67 | tmp = context.clone().detach() 68 | if self.restore and self.cur_step >= cross_period[0] and \ 69 | self.cur_step <= cross_period[1]: 70 | context = torch.cat( 71 | (self.step_store['first'][self.cur_index], 72 | self.step_store['previous'][self.cur_index]), 73 | dim=1).clone() 74 | if self.update: 75 | self.step_store['previous'][self.cur_index] = tmp 76 | self.cur_index += 1 77 | return context 78 | 79 | def update_x0(self, x0): 80 | if self.init_store: 81 | self.step_store['x0_previous'].append(x0.detach()) 82 | style_mean, style_std = calc_mean_std(x0.detach()) 83 | self.step_store['first_ada'].append(style_mean.detach()) 84 | self.step_store['first_ada'].append(style_std.detach()) 85 | if self.updatex0: 86 | tmp = x0.clone().detach() 87 | if self.restorex0: 88 | if self.cur_step >= self.total_step * self.ada_period[ 89 | 0] and self.cur_step <= self.total_step * self.ada_period[ 90 | 1]: 91 | x0 = F.instance_norm(x0) * self.step_store['first_ada'][ 92 | 2 * self.cur_step + 93 | 1] + self.step_store['first_ada'][2 * self.cur_step] 94 | if self.cur_step >= self.total_step * self.warp_period[ 95 | 0] and self.cur_step <= self.total_step * self.warp_period[ 96 | 1]: 97 | pre = self.step_store['x0_previous'][self.cur_step] 98 | x0 = flow_warp(pre, self.flow, mode='nearest') * self.mask + ( 99 | 1 - self.mask) * x0 100 | if self.updatex0: 101 | self.step_store['x0_previous'][self.cur_step] = tmp 102 | return x0 103 | 104 | def set_warp(self, flow, mask): 105 | self.flow = flow.clone() 106 | self.mask = mask.clone() 107 | 108 | def __call__(self, context, is_cross: bool, place_in_unet: str): 109 | context = self.forward(context, is_cross, place_in_unet) 110 | return context 111 | 112 | def set_step(self, step): 113 | self.cur_step = step 114 | 115 | def set_total_step(self, total_step): 116 | self.total_step = total_step 117 | self.cur_index = 0 118 | 119 | def clear_store(self): 120 | del self.step_store 121 | 
torch.cuda.empty_cache() 122 | gc.collect() 123 | self.step_store = self.get_empty_store() 124 | 125 | def set_task(self, task, restore_step=1.0): 126 | self.init_store = False 127 | self.restore = False 128 | self.update = False 129 | self.cur_index = 0 130 | self.restore_step = restore_step 131 | self.updatex0 = False 132 | self.restorex0 = False 133 | if 'initfirst' in task: 134 | self.init_store = True 135 | self.clear_store() 136 | if 'updatestyle' in task: 137 | self.update = True 138 | if 'keepstyle' in task: 139 | self.restore = True 140 | if 'updatex0' in task: 141 | self.updatex0 = True 142 | if 'keepx0' in task: 143 | self.restorex0 = True 144 | -------------------------------------------------------------------------------- /src/ddim_v_hacked.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | # CrossAttn precision handling 4 | import os 5 | 6 | import einops 7 | import numpy as np 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from deps.ControlNet.ldm.modules.diffusionmodules.util import ( 12 | extract_into_tensor, make_ddim_sampling_parameters, make_ddim_timesteps, 13 | noise_like) 14 | 15 | _ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32') 16 | 17 | 18 | def register_attention_control(model, controller=None): 19 | 20 | def ca_forward(self, place_in_unet): 21 | 22 | def forward(x, context=None, mask=None): 23 | h = self.heads 24 | 25 | q = self.to_q(x) 26 | is_cross = context is not None 27 | context = context if is_cross else x 28 | context = controller(context, is_cross, place_in_unet) 29 | 30 | k = self.to_k(context) 31 | v = self.to_v(context) 32 | 33 | q, k, v = map( 34 | lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), 35 | (q, k, v)) 36 | 37 | # force cast to fp32 to avoid overflowing 38 | if _ATTN_PRECISION == 'fp32': 39 | with torch.autocast(enabled=False, device_type='cuda'): 40 | q, k = q.float(), k.float() 41 | sim = torch.einsum('b i d, b j d -> b i j', q, 42 | k) * self.scale 43 | else: 44 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale 45 | 46 | del q, k 47 | 48 | if mask is not None: 49 | mask = einops.rearrange(mask, 'b ... 
-> b (...)') 50 | max_neg_value = -torch.finfo(sim.dtype).max 51 | mask = einops.repeat(mask, 'b j -> (b h) () j', h=h) 52 | sim.masked_fill_(~mask, max_neg_value) 53 | 54 | # attention, what we cannot get enough of 55 | sim = sim.softmax(dim=-1) 56 | 57 | out = torch.einsum('b i j, b j d -> b i d', sim, v) 58 | out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=h) 59 | return self.to_out(out) 60 | 61 | return forward 62 | 63 | class DummyController: 64 | 65 | def __call__(self, *args): 66 | return args[0] 67 | 68 | def __init__(self): 69 | self.cur_step = 0 70 | 71 | if controller is None: 72 | controller = DummyController() 73 | 74 | def register_recr(net_, place_in_unet): 75 | if net_.__class__.__name__ == 'CrossAttention': 76 | net_.forward = ca_forward(net_, place_in_unet) 77 | elif hasattr(net_, 'children'): 78 | for net__ in net_.children(): 79 | register_recr(net__, place_in_unet) 80 | 81 | sub_nets = model.named_children() 82 | for net in sub_nets: 83 | if 'input_blocks' in net[0]: 84 | register_recr(net[1], 'down') 85 | elif 'output_blocks' in net[0]: 86 | register_recr(net[1], 'up') 87 | elif 'middle_block' in net[0]: 88 | register_recr(net[1], 'mid') 89 | 90 | 91 | class DDIMVSampler(object): 92 | 93 | def __init__(self, model, schedule='linear', **kwargs): 94 | super().__init__() 95 | self.model = model 96 | self.ddpm_num_timesteps = model.num_timesteps 97 | self.schedule = schedule 98 | 99 | def register_buffer(self, name, attr): 100 | if type(attr) == torch.Tensor: 101 | if attr.device != torch.device('cuda'): 102 | attr = attr.to(torch.device('cuda')) 103 | setattr(self, name, attr) 104 | 105 | def make_schedule(self, 106 | ddim_num_steps, 107 | ddim_discretize='uniform', 108 | ddim_eta=0., 109 | verbose=True): 110 | self.ddim_timesteps = make_ddim_timesteps( 111 | ddim_discr_method=ddim_discretize, 112 | num_ddim_timesteps=ddim_num_steps, 113 | num_ddpm_timesteps=self.ddpm_num_timesteps, 114 | verbose=verbose) 115 | alphas_cumprod = self.model.alphas_cumprod 116 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, \ 117 | 'alphas have to be defined for each timestep' 118 | 119 | def to_torch(x): 120 | return x.clone().detach().to(torch.float32).to(self.model.device) 121 | 122 | self.register_buffer('betas', to_torch(self.model.betas)) 123 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 124 | self.register_buffer('alphas_cumprod_prev', 125 | to_torch(self.model.alphas_cumprod_prev)) 126 | 127 | # calculations for diffusion q(x_t | x_{t-1}) and others 128 | self.register_buffer('sqrt_alphas_cumprod', 129 | to_torch(np.sqrt(alphas_cumprod.cpu()))) 130 | self.register_buffer('sqrt_one_minus_alphas_cumprod', 131 | to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 132 | self.register_buffer('log_one_minus_alphas_cumprod', 133 | to_torch(np.log(1. - alphas_cumprod.cpu()))) 134 | self.register_buffer('sqrt_recip_alphas_cumprod', 135 | to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 136 | self.register_buffer('sqrt_recipm1_alphas_cumprod', 137 | to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) 138 | 139 | # ddim sampling parameters 140 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = \ 141 | make_ddim_sampling_parameters( 142 | alphacums=alphas_cumprod.cpu(), 143 | ddim_timesteps=self.ddim_timesteps, 144 | eta=ddim_eta, 145 | verbose=verbose) 146 | self.register_buffer('ddim_sigmas', ddim_sigmas) 147 | self.register_buffer('ddim_alphas', ddim_alphas) 148 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 149 | self.register_buffer('ddim_sqrt_one_minus_alphas', 150 | np.sqrt(1. - ddim_alphas)) 151 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 152 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * 153 | (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 154 | self.register_buffer('ddim_sigmas_for_original_num_steps', 155 | sigmas_for_original_sampling_steps) 156 | 157 | @torch.no_grad() 158 | def sample(self, 159 | S, 160 | batch_size, 161 | shape, 162 | conditioning=None, 163 | callback=None, 164 | img_callback=None, 165 | quantize_x0=False, 166 | eta=0., 167 | mask=None, 168 | x0=None, 169 | xtrg=None, 170 | noise_rescale=None, 171 | temperature=1., 172 | noise_dropout=0., 173 | score_corrector=None, 174 | corrector_kwargs=None, 175 | verbose=True, 176 | x_T=None, 177 | log_every_t=100, 178 | unconditional_guidance_scale=1., 179 | unconditional_conditioning=None, 180 | dynamic_threshold=None, 181 | ucg_schedule=None, 182 | controller=None, 183 | strength=0.0, 184 | **kwargs): 185 | if conditioning is not None: 186 | if isinstance(conditioning, dict): 187 | ctmp = conditioning[list(conditioning.keys())[0]] 188 | while isinstance(ctmp, list): 189 | ctmp = ctmp[0] 190 | cbs = ctmp.shape[0] 191 | if cbs != batch_size: 192 | print(f'Warning: Got {cbs} conditionings' 193 | f'but batch-size is {batch_size}') 194 | 195 | elif isinstance(conditioning, list): 196 | for ctmp in conditioning: 197 | if ctmp.shape[0] != batch_size: 198 | print(f'Warning: Got {cbs} conditionings' 199 | f'but batch-size is {batch_size}') 200 | 201 | else: 202 | if conditioning.shape[0] != batch_size: 203 | print(f'Warning: Got {conditioning.shape[0]}' 204 | f'conditionings but batch-size is {batch_size}') 205 | 206 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 207 | # sampling 208 | C, H, W = shape 209 | size = (batch_size, C, H, W) 210 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 211 | 212 | samples, intermediates = self.ddim_sampling( 213 | conditioning, 214 | size, 215 | callback=callback, 216 | img_callback=img_callback, 217 | quantize_denoised=quantize_x0, 218 | mask=mask, 219 | x0=x0, 220 | xtrg=xtrg, 221 | noise_rescale=noise_rescale, 222 | ddim_use_original_steps=False, 223 | noise_dropout=noise_dropout, 224 | temperature=temperature, 225 | score_corrector=score_corrector, 226 | corrector_kwargs=corrector_kwargs, 227 | x_T=x_T, 228 | log_every_t=log_every_t, 229 | unconditional_guidance_scale=unconditional_guidance_scale, 230 | unconditional_conditioning=unconditional_conditioning, 231 | dynamic_threshold=dynamic_threshold, 232 | ucg_schedule=ucg_schedule, 233 | controller=controller, 234 | strength=strength, 235 | ) 236 | return samples, intermediates 237 | 238 | @torch.no_grad() 239 | def ddim_sampling(self, 240 | cond, 241 | shape, 242 | x_T=None, 243 | ddim_use_original_steps=False, 244 | callback=None, 245 | timesteps=None, 246 | quantize_denoised=False, 247 | mask=None, 248 | x0=None, 249 | xtrg=None, 250 | noise_rescale=None, 251 | img_callback=None, 252 | log_every_t=100, 253 | 
temperature=1., 254 | noise_dropout=0., 255 | score_corrector=None, 256 | corrector_kwargs=None, 257 | unconditional_guidance_scale=1., 258 | unconditional_conditioning=None, 259 | dynamic_threshold=None, 260 | ucg_schedule=None, 261 | controller=None, 262 | strength=0.0): 263 | 264 | if strength == 1 and x0 is not None: 265 | return x0, None 266 | 267 | register_attention_control(self.model.model.diffusion_model, 268 | controller) 269 | 270 | device = self.model.betas.device 271 | b = shape[0] 272 | if x_T is None: 273 | img = torch.randn(shape, device=device) 274 | else: 275 | img = x_T 276 | 277 | if timesteps is None: 278 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps \ 279 | else self.ddim_timesteps 280 | elif timesteps is not None and not ddim_use_original_steps: 281 | subset_end = int( 282 | min(timesteps / self.ddim_timesteps.shape[0], 1) * 283 | self.ddim_timesteps.shape[0]) - 1 284 | timesteps = self.ddim_timesteps[:subset_end] 285 | 286 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 287 | time_range = reversed(range( 288 | 0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) 289 | total_steps = timesteps if ddim_use_original_steps \ 290 | else timesteps.shape[0] 291 | print(f'Running DDIM Sampling with {total_steps} timesteps') 292 | 293 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 294 | if controller is not None: 295 | controller.set_total_step(total_steps) 296 | if mask is None: 297 | mask = [None] * total_steps 298 | 299 | dir_xt = 0 300 | for i, step in enumerate(iterator): 301 | if controller is not None: 302 | controller.set_step(i) 303 | index = total_steps - i - 1 304 | ts = torch.full((b, ), step, device=device, dtype=torch.long) 305 | 306 | if strength >= 0 and i == int( 307 | total_steps * strength) and x0 is not None: 308 | img = self.model.q_sample(x0, ts) 309 | if mask is not None and xtrg is not None: 310 | # TODO: deterministic forward pass? 311 | if type(mask) == list: 312 | weight = mask[i] 313 | else: 314 | weight = mask 315 | if weight is not None: 316 | rescale = torch.maximum(1. - weight, (1 - weight**2)**0.5 * 317 | controller.inner_strength) 318 | if noise_rescale is not None: 319 | rescale = (1. - weight) * ( 320 | 1 - noise_rescale) + rescale * noise_rescale 321 | img_ref = self.model.q_sample(xtrg, ts) 322 | img = img_ref * weight + (1. 
- weight) * ( 323 | img - dir_xt) + rescale * dir_xt 324 | 325 | if ucg_schedule is not None: 326 | assert len(ucg_schedule) == len(time_range) 327 | unconditional_guidance_scale = ucg_schedule[i] 328 | 329 | outs = self.p_sample_ddim( 330 | img, 331 | cond, 332 | ts, 333 | index=index, 334 | use_original_steps=ddim_use_original_steps, 335 | quantize_denoised=quantize_denoised, 336 | temperature=temperature, 337 | noise_dropout=noise_dropout, 338 | score_corrector=score_corrector, 339 | corrector_kwargs=corrector_kwargs, 340 | unconditional_guidance_scale=unconditional_guidance_scale, 341 | unconditional_conditioning=unconditional_conditioning, 342 | dynamic_threshold=dynamic_threshold, 343 | controller=controller, 344 | return_dir=True) 345 | img, pred_x0, dir_xt = outs 346 | if callback: 347 | callback(i) 348 | if img_callback: 349 | img_callback(pred_x0, i) 350 | 351 | if index % log_every_t == 0 or index == total_steps - 1: 352 | intermediates['x_inter'].append(img) 353 | intermediates['pred_x0'].append(pred_x0) 354 | 355 | return img, intermediates 356 | 357 | @torch.no_grad() 358 | def p_sample_ddim(self, 359 | x, 360 | c, 361 | t, 362 | index, 363 | repeat_noise=False, 364 | use_original_steps=False, 365 | quantize_denoised=False, 366 | temperature=1., 367 | noise_dropout=0., 368 | score_corrector=None, 369 | corrector_kwargs=None, 370 | unconditional_guidance_scale=1., 371 | unconditional_conditioning=None, 372 | dynamic_threshold=None, 373 | controller=None, 374 | return_dir=False): 375 | b, *_, device = *x.shape, x.device 376 | 377 | if unconditional_conditioning is None or \ 378 | unconditional_guidance_scale == 1.: 379 | model_output = self.model.apply_model(x, t, c) 380 | else: 381 | model_t = self.model.apply_model(x, t, c) 382 | model_uncond = self.model.apply_model(x, t, 383 | unconditional_conditioning) 384 | model_output = model_uncond + unconditional_guidance_scale * ( 385 | model_t - model_uncond) 386 | 387 | if self.model.parameterization == 'v': 388 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 389 | else: 390 | e_t = model_output 391 | 392 | if score_corrector is not None: 393 | assert self.model.parameterization == 'eps', 'not implemented' 394 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, 395 | **corrector_kwargs) 396 | 397 | if use_original_steps: 398 | alphas = self.model.alphas_cumprod 399 | alphas_prev = self.model.alphas_cumprod_prev 400 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod 401 | sigmas = self.model.ddim_sigmas_for_original_num_steps 402 | else: 403 | alphas = self.ddim_alphas 404 | alphas_prev = self.ddim_alphas_prev 405 | sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas 406 | sigmas = self.ddim_sigmas 407 | 408 | # select parameters corresponding to the currently considered timestep 409 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 410 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 411 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 412 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), 413 | sqrt_one_minus_alphas[index], 414 | device=device) 415 | 416 | # current prediction for x_0 417 | if self.model.parameterization != 'v': 418 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 419 | else: 420 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 421 | 422 | if quantize_denoised: 423 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 424 | 425 | if dynamic_threshold is not None: 426 | raise 
NotImplementedError() 427 | ''' 428 | if mask is not None and xtrg is not None: 429 | pred_x0 = xtrg * mask + (1. - mask) * pred_x0 430 | ''' 431 | 432 | if controller is not None: 433 | pred_x0 = controller.update_x0(pred_x0) 434 | 435 | # direction pointing to x_t 436 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 437 | noise = sigma_t * noise_like(x.shape, device, 438 | repeat_noise) * temperature 439 | if noise_dropout > 0.: 440 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 441 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 442 | 443 | if return_dir: 444 | return x_prev, pred_x0, dir_xt 445 | return x_prev, pred_x0 446 | 447 | @torch.no_grad() 448 | def encode(self, 449 | x0, 450 | c, 451 | t_enc, 452 | use_original_steps=False, 453 | return_intermediates=None, 454 | unconditional_guidance_scale=1.0, 455 | unconditional_conditioning=None, 456 | callback=None): 457 | timesteps = np.arange(self.ddpm_num_timesteps 458 | ) if use_original_steps else self.ddim_timesteps 459 | num_reference_steps = timesteps.shape[0] 460 | 461 | assert t_enc <= num_reference_steps 462 | num_steps = t_enc 463 | 464 | if use_original_steps: 465 | alphas_next = self.alphas_cumprod[:num_steps] 466 | alphas = self.alphas_cumprod_prev[:num_steps] 467 | else: 468 | alphas_next = self.ddim_alphas[:num_steps] 469 | alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) 470 | 471 | x_next = x0 472 | intermediates = [] 473 | inter_steps = [] 474 | for i in tqdm(range(num_steps), desc='Encoding Image'): 475 | t = torch.full((x0.shape[0], ), 476 | timesteps[i], 477 | device=self.model.device, 478 | dtype=torch.long) 479 | if unconditional_guidance_scale == 1.: 480 | noise_pred = self.model.apply_model(x_next, t, c) 481 | else: 482 | assert unconditional_conditioning is not None 483 | e_t_uncond, noise_pred = torch.chunk( 484 | self.model.apply_model( 485 | torch.cat((x_next, x_next)), torch.cat((t, t)), 486 | torch.cat((unconditional_conditioning, c))), 2) 487 | noise_pred = e_t_uncond + unconditional_guidance_scale * ( 488 | noise_pred - e_t_uncond) 489 | xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next 490 | weighted_noise_pred = alphas_next[i].sqrt() * ( 491 | (1 / alphas_next[i] - 1).sqrt() - 492 | (1 / alphas[i] - 1).sqrt()) * noise_pred 493 | x_next = xt_weighted + weighted_noise_pred 494 | if return_intermediates and i % (num_steps // return_intermediates 495 | ) == 0 and i < num_steps - 1: 496 | intermediates.append(x_next) 497 | inter_steps.append(i) 498 | elif return_intermediates and i >= num_steps - 2: 499 | intermediates.append(x_next) 500 | inter_steps.append(i) 501 | if callback: 502 | callback(i) 503 | 504 | out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} 505 | if return_intermediates: 506 | out.update({'intermediates': intermediates}) 507 | return x_next, out 508 | 509 | @torch.no_grad() 510 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 511 | # fast, but does not allow for exact reconstruction 512 | # t serves as an index to gather the correct alphas 513 | if use_original_steps: 514 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 515 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 516 | else: 517 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 518 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 519 | 520 | if noise is None: 521 | noise = torch.randn_like(x0) 522 | if t >= len(sqrt_alphas_cumprod): 523 | return noise 524 | return ( 525 | extract_into_tensor(sqrt_alphas_cumprod, t, 
x0.shape) * x0 + 526 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * 527 | noise) 528 | 529 | @torch.no_grad() 530 | def decode(self, 531 | x_latent, 532 | cond, 533 | t_start, 534 | unconditional_guidance_scale=1.0, 535 | unconditional_conditioning=None, 536 | use_original_steps=False, 537 | callback=None): 538 | 539 | timesteps = np.arange(self.ddpm_num_timesteps 540 | ) if use_original_steps else self.ddim_timesteps 541 | timesteps = timesteps[:t_start] 542 | 543 | time_range = np.flip(timesteps) 544 | total_steps = timesteps.shape[0] 545 | print(f'Running DDIM Sampling with {total_steps} timesteps') 546 | 547 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 548 | x_dec = x_latent 549 | for i, step in enumerate(iterator): 550 | index = total_steps - i - 1 551 | ts = torch.full((x_latent.shape[0], ), 552 | step, 553 | device=x_latent.device, 554 | dtype=torch.long) 555 | x_dec, _ = self.p_sample_ddim( 556 | x_dec, 557 | cond, 558 | ts, 559 | index=index, 560 | use_original_steps=use_original_steps, 561 | unconditional_guidance_scale=unconditional_guidance_scale, 562 | unconditional_conditioning=unconditional_conditioning) 563 | if callback: 564 | callback(i) 565 | return x_dec 566 | 567 | 568 | def calc_mean_std(feat, eps=1e-5): 569 | # eps is a small value added to the variance to avoid divide-by-zero. 570 | size = feat.size() 571 | assert (len(size) == 4) 572 | N, C = size[:2] 573 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 574 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 575 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 576 | return feat_mean, feat_std 577 | 578 | 579 | def adaptive_instance_normalization(content_feat, style_feat): 580 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 581 | size = content_feat.size() 582 | style_mean, style_std = calc_mean_std(style_feat) 583 | content_mean, content_std = calc_mean_std(content_feat) 584 | 585 | normalized_feat = (content_feat - 586 | content_mean.expand(size)) / content_std.expand(size) 587 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 588 | -------------------------------------------------------------------------------- /src/freeu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.fft as fft 3 | 4 | 5 | def Fourier_filter(x, threshold, scale): 6 | 7 | x_freq = fft.fftn(x, dim=(-2, -1)) 8 | x_freq = fft.fftshift(x_freq, dim=(-2, -1)) 9 | 10 | B, C, H, W = x_freq.shape 11 | mask = torch.ones((B, C, H, W)).cuda() 12 | 13 | crow, ccol = H // 2, W // 2 14 | mask[..., crow - threshold:crow + threshold, 15 | ccol - threshold:ccol + threshold] = scale 16 | x_freq = x_freq * mask 17 | 18 | x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) 19 | 20 | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real 21 | 22 | return x_filtered 23 | 24 | from deps.ControlNet.ldm.modules.diffusionmodules.util import \ 25 | timestep_embedding # noqa:E501 26 | 27 | 28 | # backbone_scale1=1.1, backbone_scale2=1.2, skip_scale1=1.0, skip_scale2=0.2 29 | def freeu_forward(self, 30 | backbone_scale1=1., 31 | backbone_scale2=1., 32 | skip_scale1=1., 33 | skip_scale2=1.): 34 | 35 | def forward(x, 36 | timesteps=None, 37 | context=None, 38 | control=None, 39 | only_mid_control=False, 40 | **kwargs): 41 | hs = [] 42 | with torch.no_grad(): 43 | t_emb = timestep_embedding(timesteps, 44 | self.model_channels, 45 | repeat_only=False) 46 | emb = self.time_embed(t_emb) 47 | h = x.type(self.dtype) 48 | for module 
in self.input_blocks: 49 | h = module(h, emb, context) 50 | hs.append(h) 51 | h = self.middle_block(h, emb, context) 52 | 53 | if control is not None: 54 | h += control.pop() 55 | ''' 56 | for i, module in enumerate(self.output_blocks): 57 | if only_mid_control or control is None: 58 | h = torch.cat([h, hs.pop()], dim=1) 59 | else: 60 | h = torch.cat([h, hs.pop() + control.pop()], dim=1) 61 | h = module(h, emb, context) 62 | ''' 63 | for i, module in enumerate(self.output_blocks): 64 | hs_ = hs.pop() 65 | 66 | if h.shape[1] == 1280: 67 | hidden_mean = h.mean(1).unsqueeze(1) 68 | B = hidden_mean.shape[0] 69 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 70 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) 71 | hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) 72 | h[:, :640] = h[:, :640] * ((backbone_scale1 - 1) * hidden_mean + 1) 73 | # h[:, :640] = h[:, :640] * backbone_scale1 74 | hs_ = Fourier_filter(hs_, threshold=1, scale=skip_scale1) 75 | if h.shape[1] == 640: 76 | hidden_mean = h.mean(1).unsqueeze(1) 77 | B = hidden_mean.shape[0] 78 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 79 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) 80 | hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) 81 | h[:, :320] = h[:, :320] * ((backbone_scale2 - 1) * hidden_mean + 1) 82 | # h[:, :320] = h[:, :320] * backbone_scale2 83 | hs_ = Fourier_filter(hs_, threshold=1, scale=skip_scale2) 84 | 85 | if only_mid_control or control is None: 86 | h = torch.cat([h, hs_], dim=1) 87 | else: 88 | h = torch.cat([h, hs_ + control.pop()], dim=1) 89 | h = module(h, emb, context) 90 | 91 | h = h.type(x.dtype) 92 | return self.out(h) 93 | 94 | return forward 95 | -------------------------------------------------------------------------------- /src/img_util.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | @torch.no_grad() 7 | def find_flat_region(mask): 8 | device = mask.device 9 | kernel_x = torch.Tensor([[-1, 0, 1], [-1, 0, 1], 10 | [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).to(device) 11 | kernel_y = torch.Tensor([[-1, -1, -1], [0, 0, 0], 12 | [1, 1, 1]]).unsqueeze(0).unsqueeze(0).to(device) 13 | mask_ = F.pad(mask.unsqueeze(0), (1, 1, 1, 1), mode='replicate') 14 | 15 | grad_x = torch.nn.functional.conv2d(mask_, kernel_x) 16 | grad_y = torch.nn.functional.conv2d(mask_, kernel_y) 17 | return ((abs(grad_x) + abs(grad_y)) == 0).float()[0] 18 | 19 | 20 | def numpy2tensor(img): 21 | x0 = torch.from_numpy(img.copy()).float().cuda() / 255.0 * 2.0 - 1. 
22 | x0 = torch.stack([x0], dim=0) 23 | return einops.rearrange(x0, 'b h w c -> b c h w').clone() 24 | -------------------------------------------------------------------------------- /src/import_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | cur_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 5 | gmflow_dir = os.path.join(cur_dir, 'deps/gmflow') 6 | controlnet_dir = os.path.join(cur_dir, 'deps/ControlNet') 7 | sys.path.insert(0, gmflow_dir) 8 | sys.path.insert(0, controlnet_dir) 9 | 10 | import deps.ControlNet.share # noqa: F401 E402 11 | -------------------------------------------------------------------------------- /src/video_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import torch 5 | import imageio 6 | import numpy as np 7 | 8 | 9 | def video_to_frame(video_path: str, 10 | frame_dir: str, 11 | filename_pattern: str = 'frame%03d.jpg', 12 | log: bool = True, 13 | frame_edit_func=None): 14 | os.makedirs(frame_dir, exist_ok=True) 15 | 16 | vidcap = cv2.VideoCapture(video_path) 17 | success, image = vidcap.read() 18 | 19 | if log: 20 | print('img shape: ', image.shape[0:2]) 21 | 22 | count = 0 23 | while success: 24 | if frame_edit_func is not None: 25 | image = frame_edit_func(image) 26 | 27 | cv2.imwrite(os.path.join(frame_dir, filename_pattern % count), image) 28 | success, image = vidcap.read() 29 | if log: 30 | print('Read a new frame: ', success, count) 31 | count += 1 32 | 33 | vidcap.release() 34 | 35 | 36 | def frame_to_video(video_path: str, frame_dir: str, fps=30, log=True): 37 | 38 | first_img = True 39 | writer = imageio.get_writer(video_path, fps=fps) 40 | 41 | file_list = sorted(os.listdir(frame_dir)) 42 | for file_name in file_list: 43 | if not (file_name.endswith('jpg') or file_name.endswith('png')): 44 | continue 45 | 46 | fn = os.path.join(frame_dir, file_name) 47 | curImg = imageio.imread(fn) 48 | 49 | if first_img: 50 | H, W = curImg.shape[0:2] 51 | if log: 52 | print('img shape', (H, W)) 53 | first_img = False 54 | 55 | writer.append_data(curImg) 56 | 57 | writer.close() 58 | 59 | 60 | def get_fps(video_path: str): 61 | video = cv2.VideoCapture(video_path) 62 | fps = video.get(cv2.CAP_PROP_FPS) 63 | video.release() 64 | return fps 65 | 66 | 67 | def get_frame_count(video_path: str): 68 | video = cv2.VideoCapture(video_path) 69 | frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 70 | video.release() 71 | return frame_count 72 | 73 | 74 | def resize_image(input_image, resolution): 75 | H, W, C = input_image.shape 76 | H = float(H) 77 | W = float(W) 78 | aspect_ratio = W / H 79 | k = float(resolution) / min(H, W) 80 | H *= k 81 | W *= k 82 | if H < W: 83 | W = resolution 84 | H = int(resolution / aspect_ratio) 85 | else: 86 | H = resolution 87 | W = int(aspect_ratio * resolution) 88 | H = int(np.round(H / 64.0)) * 64 89 | W = int(np.round(W / 64.0)) * 64 90 | img = cv2.resize( 91 | input_image, (W, H), 92 | interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 93 | return img 94 | 95 | 96 | def prepare_frames(input_path: str, output_dir: str, resolution: int, crop, use_limit_device_resolution=False): 97 | l, r, t, b = crop 98 | 99 | if use_limit_device_resolution: 100 | resolution = vram_limit_device_resolution(resolution) 101 | 102 | def crop_func(frame): 103 | H, W, C = frame.shape 104 | left = np.clip(l, 0, W) 105 | right = np.clip(W - r, left, W) 106 | top = np.clip(t, 0, H) 
107 | bottom = np.clip(H - b, top, H) 108 | frame = frame[top:bottom, left:right] 109 | return resize_image(frame, resolution) 110 | 111 | video_to_frame(input_path, output_dir, '%04d.png', False, crop_func) 112 | 113 | 114 | def vram_limit_device_resolution(resolution, device='cuda'): 115 | # Query the total VRAM of the target device (in GB). 116 | gpu_vram = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) 117 | # Map of minimum VRAM (GB) to the maximum supported frame resolution. 118 | gpu_table = {24: 1280, 18: 1024, 14: 768, 10: 640, 8: 576, 7: 512, 6: 448, 5: 320, 4: 192, 0: 0} 119 | # Pick the largest resolution whose VRAM requirement fits the device. 120 | device_resolution = max(val for key, val in gpu_table.items() if key <= gpu_vram) 121 | print(f'Detected {gpu_vram:.1f} GB of VRAM; the resolution limit is {device_resolution}.') 122 | if gpu_vram < 4: 123 | print('VRAM is too small to apply a limit; the configured resolution will be used.') 124 | return resolution 125 | if resolution < device_resolution: 126 | print('The video will not be resized.') 127 | return resolution 128 | return device_resolution 129 | -------------------------------------------------------------------------------- /video_blend.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import platform 4 | import struct 5 | import subprocess 6 | import time 7 | from typing import List 8 | 9 | import cv2 10 | import numpy as np 11 | import torch.multiprocessing as mp 12 | from numba import njit 13 | 14 | import blender.histogram_blend as histogram_blend 15 | from blender.guide import (BaseGuide, ColorGuide, EdgeGuide, PositionalGuide, 16 | TemporalGuide) 17 | from blender.poisson_fusion import poisson_fusion 18 | from blender.video_sequence import VideoSequence 19 | from flow.flow_utils import flow_calc 20 | from src.video_util import frame_to_video 21 | 22 | OPEN_EBSYNTH_LOG = False 23 | MAX_PROCESS = 8 24 | 25 | os_str = platform.system() 26 | 27 | if os_str == 'Windows': 28 | ebsynth_bin = '.\\deps\\ebsynth\\bin\\ebsynth.exe' 29 | elif os_str == 'Linux': 30 | ebsynth_bin = './deps/ebsynth/bin/ebsynth' 31 | elif os_str == 'Darwin': 32 | ebsynth_bin = './deps/ebsynth/bin/ebsynth.app' 33 | else: 34 | print('Cannot recognize OS. 
Run Ebsynth failed.') 35 | exit(0) 36 | 37 | 38 | @njit 39 | def g_error_mask_loop(H, W, dist1, dist2, output, weight1, weight2): 40 | for i in range(H): 41 | for j in range(W): 42 | if weight1 * dist1[i, j] < weight2 * dist2[i, j]: 43 | output[i, j] = 0 44 | else: 45 | output[i, j] = 1 46 | if weight1 == 0: 47 | output[i, j] = 0 48 | elif weight2 == 0: 49 | output[i, j] = 1 50 | 51 | 52 | def g_error_mask(dist1, dist2, weight1=1, weight2=1): 53 | H, W = dist1.shape 54 | output = np.empty_like(dist1, dtype=np.byte) 55 | g_error_mask_loop(H, W, dist1, dist2, output, weight1, weight2) 56 | return output 57 | 58 | 59 | def create_sequence(base_dir, beg, end, interval, key_dir): 60 | sequence = VideoSequence(base_dir, beg, end, interval, 'video', key_dir, 61 | 'tmp', '%04d.png', '%04d.png') 62 | return sequence 63 | 64 | 65 | def process_one_sequence(i, video_sequence: VideoSequence): 66 | interval = video_sequence.interval 67 | for is_forward in [True, False]: 68 | input_seq = video_sequence.get_input_sequence(i, is_forward) 69 | output_seq = video_sequence.get_output_sequence(i, is_forward) 70 | flow_seq = video_sequence.get_flow_sequence(i, is_forward) 71 | key_img_id = i if is_forward else i + 1 72 | key_img = video_sequence.get_key_img(key_img_id) 73 | for j in range(interval - 1): 74 | i1 = cv2.imread(input_seq[j]) 75 | i2 = cv2.imread(input_seq[j + 1]) 76 | flow_calc.get_flow(i1, i2, flow_seq[j]) 77 | 78 | guides: List[BaseGuide] = [ 79 | ColorGuide(input_seq), 80 | EdgeGuide(input_seq, 81 | video_sequence.get_edge_sequence(i, is_forward)), 82 | TemporalGuide(key_img, output_seq, flow_seq, 83 | video_sequence.get_temporal_sequence(i, is_forward)), 84 | PositionalGuide(flow_seq, 85 | video_sequence.get_pos_sequence(i, is_forward)) 86 | ] 87 | weights = [6, 0.5, 0.5, 2] 88 | for j in range(interval): 89 | # key frame 90 | if j == 0: 91 | img = cv2.imread(key_img) 92 | cv2.imwrite(output_seq[0], img) 93 | else: 94 | cmd = f'{ebsynth_bin} -style {os.path.abspath(key_img)}' 95 | for g, w in zip(guides, weights): 96 | cmd += ' ' + g.get_cmd(j, w) 97 | 98 | cmd += (f' -output {os.path.abspath(output_seq[j])}' 99 | ' -searchvoteiters 12 -patchmatchiters 6') 100 | if OPEN_EBSYNTH_LOG: 101 | print(cmd) 102 | subprocess.run(cmd, 103 | shell=True, 104 | capture_output=not OPEN_EBSYNTH_LOG) 105 | 106 | 107 | def process_sequences(i_arr, video_sequence: VideoSequence): 108 | for i in i_arr: 109 | process_one_sequence(i, video_sequence) 110 | 111 | 112 | def run_ebsynth(video_sequence: VideoSequence): 113 | 114 | beg = time.time() 115 | 116 | processes = [] 117 | mp.set_start_method('spawn') 118 | 119 | n_process = min(MAX_PROCESS, video_sequence.n_seq) 120 | cnt = video_sequence.n_seq // n_process 121 | remainder = video_sequence.n_seq % n_process 122 | 123 | prev_idx = 0 124 | 125 | for i in range(n_process): 126 | task_cnt = cnt + 1 if i < remainder else cnt 127 | i_arr = list(range(prev_idx, prev_idx + task_cnt)) 128 | prev_idx += task_cnt 129 | p = mp.Process(target=process_sequences, args=(i_arr, video_sequence)) 130 | p.start() 131 | processes.append(p) 132 | for p in processes: 133 | p.join() 134 | 135 | end = time.time() 136 | 137 | print(f'ebsynth: {end-beg}') 138 | 139 | 140 | @njit 141 | def assemble_min_error_img_loop(H, W, a, b, error_mask, out): 142 | for i in range(H): 143 | for j in range(W): 144 | if error_mask[i, j] == 0: 145 | out[i, j] = a[i, j] 146 | else: 147 | out[i, j] = b[i, j] 148 | 149 | 150 | def assemble_min_error_img(a, b, error_mask): 151 | H, W = a.shape[0:2] 152 | out 
= np.empty_like(a) 153 | assemble_min_error_img_loop(H, W, a, b, error_mask, out) 154 | return out 155 | 156 | 157 | def load_error(bin_path, img_shape): 158 | img_size = img_shape[0] * img_shape[1] 159 | with open(bin_path, 'rb') as fp: 160 | bytes = fp.read() 161 | 162 | read_size = struct.unpack('q', bytes[:8]) 163 | assert read_size[0] == img_size 164 | float_res = struct.unpack('f' * img_size, bytes[8:]) 165 | res = np.array(float_res, 166 | dtype=np.float32).reshape(img_shape[0], img_shape[1]) 167 | return res 168 | 169 | 170 | def process_seq(video_sequence: VideoSequence, 171 | i, 172 | blend_histogram=True, 173 | blend_gradient=True): 174 | 175 | key1_img = cv2.imread(video_sequence.get_key_img(i)) 176 | img_shape = key1_img.shape 177 | interval = video_sequence.interval 178 | beg_id = video_sequence.get_sequence_beg_id(i) 179 | 180 | oas = video_sequence.get_output_sequence(i) 181 | obs = video_sequence.get_output_sequence(i, False) 182 | 183 | binas = [x.replace('jpg', 'bin') for x in oas] 184 | binbs = [x.replace('jpg', 'bin') for x in obs] 185 | 186 | obs = [obs[0]] + list(reversed(obs[1:])) 187 | inputs = video_sequence.get_input_sequence(i) 188 | oas = [cv2.imread(x) for x in oas] 189 | obs = [cv2.imread(x) for x in obs] 190 | inputs = [cv2.imread(x) for x in inputs] 191 | flow_seq = video_sequence.get_flow_sequence(i) 192 | 193 | dist1s = [] 194 | dist2s = [] 195 | for i in range(interval - 1): 196 | bin_a = binas[i + 1] 197 | bin_b = binbs[i + 1] 198 | dist1s.append(load_error(bin_a, img_shape)) 199 | dist2s.append(load_error(bin_b, img_shape)) 200 | 201 | lb = 0 202 | ub = 1 203 | beg = time.time() 204 | p_mask = None 205 | 206 | # write key img 207 | blend_out_path = video_sequence.get_blending_img(beg_id) 208 | cv2.imwrite(blend_out_path, key1_img) 209 | 210 | for i in range(interval - 1): 211 | c_id = beg_id + i + 1 212 | blend_out_path = video_sequence.get_blending_img(c_id) 213 | 214 | dist1 = dist1s[i] 215 | dist2 = dist2s[i] 216 | oa = oas[i + 1] 217 | ob = obs[i + 1] 218 | weight1 = i / (interval - 1) * (ub - lb) + lb 219 | weight2 = 1 - weight1 220 | mask = g_error_mask(dist1, dist2, weight1, weight2) 221 | if p_mask is not None: 222 | flow_path = flow_seq[i] 223 | flow = flow_calc.get_flow(inputs[i], inputs[i + 1], flow_path) 224 | p_mask = flow_calc.warp(p_mask, flow, 'nearest') 225 | mask = p_mask | mask 226 | p_mask = mask 227 | 228 | # Save tmp mask 229 | # out_mask = np.expand_dims(mask, 2) 230 | # cv2.imwrite(f'mask/mask_{c_id:04d}.jpg', out_mask * 255) 231 | 232 | min_error_img = assemble_min_error_img(oa, ob, mask) 233 | if blend_histogram: 234 | hb_res = histogram_blend.blend(oa, ob, min_error_img, 235 | (1 - weight1), (1 - weight2)) 236 | 237 | else: 238 | # hb_res = min_error_img 239 | tmpa = oa.astype(np.float32) 240 | tmpb = ob.astype(np.float32) 241 | hb_res = (1 - weight1) * tmpa + (1 - weight2) * tmpb 242 | 243 | # cv2.imwrite(blend_out_path, hb_res) 244 | 245 | # gradient blend 246 | if blend_gradient: 247 | res = poisson_fusion(hb_res, oa, ob, mask) 248 | else: 249 | res = hb_res 250 | 251 | cv2.imwrite(blend_out_path, res) 252 | end = time.time() 253 | print('others:', end - beg) 254 | 255 | 256 | def main(args): 257 | global MAX_PROCESS 258 | MAX_PROCESS = args.n_proc 259 | 260 | video_sequence = create_sequence(f'{args.name}', args.beg, args.end, 261 | args.itv, args.key) 262 | if not args.ne: 263 | run_ebsynth(video_sequence) 264 | blend_histogram = True 265 | blend_gradient = args.ps 266 | for i in range(video_sequence.n_seq): 267 | 
process_seq(video_sequence, i, blend_histogram, blend_gradient) 268 | if args.output: 269 | frame_to_video(args.output, video_sequence.blending_dir, args.fps, 270 | False) 271 | if not args.tmp: 272 | video_sequence.remove_out_and_tmp() 273 | 274 | 275 | if __name__ == '__main__': 276 | parser = argparse.ArgumentParser() 277 | parser.add_argument('name', type=str, help='Path to the video directory') 278 | parser.add_argument('--output', 279 | type=str, 280 | default=None, 281 | help='Path to the output video') 282 | parser.add_argument('--fps', 283 | type=float, 284 | default=30, 285 | help='The FPS of the output video') 286 | parser.add_argument('--beg', 287 | type=int, 288 | default=1, 289 | help='The index of the first frame to be stylized') 290 | parser.add_argument('--end', 291 | type=int, 292 | default=101, 293 | help='The index of the last frame to be stylized') 294 | parser.add_argument('--itv', 295 | type=int, 296 | default=10, 297 | help='The interval between key frames') 298 | parser.add_argument('--key', 299 | type=str, 300 | default='keys0', 301 | help='The subfolder name of stylized key frames') 302 | parser.add_argument('--n_proc', 303 | type=int, 304 | default=8, 305 | help='The maximum number of parallel processes') 306 | parser.add_argument('-ps', 307 | action='store_true', 308 | help='Use Poisson gradient blending') 309 | parser.add_argument( 310 | '-ne', 311 | action='store_true', 312 | help='Do not run Ebsynth (reuse the previous Ebsynth output)') 313 | parser.add_argument('-tmp', 314 | action='store_true', 315 | help='Keep temporary output') 316 | 317 | args = parser.parse_args() 318 | main(args) 319 | -------------------------------------------------------------------------------- /videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamyang1991/Rerender_A_Video/dfaf9d8825f226a2f0a0b731ab2adc84a3f2ebd2/videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4 -------------------------------------------------------------------------------- /videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamyang1991/Rerender_A_Video/dfaf9d8825f226a2f0a0b731ab2adc84a3f2ebd2/videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4 -------------------------------------------------------------------------------- /videos/pexels-koolshooters-7322716.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamyang1991/Rerender_A_Video/dfaf9d8825f226a2f0a0b731ab2adc84a3f2ebd2/videos/pexels-koolshooters-7322716.mp4 -------------------------------------------------------------------------------- /webUI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from enum import Enum 4 | 5 | import cv2 6 | import einops 7 | import gradio as gr 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import torchvision.transforms as T 12 | from blendmodes.blend import BlendType, blendLayers 13 | from PIL import Image 14 | from pytorch_lightning import seed_everything 15 | from safetensors.torch import load_file 16 | from skimage import exposure 17 | 18 | import src.import_util # noqa: F401 19 | from deps.ControlNet.annotator.canny import CannyDetector 20 | from deps.ControlNet.annotator.hed import HEDdetector 21 | from deps.ControlNet.annotator.util 
import HWC3 22 | from deps.ControlNet.cldm.model import create_model, load_state_dict 23 | from deps.gmflow.gmflow.gmflow import GMFlow 24 | from flow.flow_utils import get_warped_and_mask 25 | from sd_model_cfg import model_dict 26 | from src.config import RerenderConfig 27 | from src.controller import AttentionControl 28 | from src.ddim_v_hacked import DDIMVSampler 29 | from src.freeu import freeu_forward 30 | from src.img_util import find_flat_region, numpy2tensor 31 | from src.video_util import (frame_to_video, get_fps, get_frame_count, 32 | prepare_frames) 33 | 34 | inversed_model_dict = dict() 35 | for k, v in model_dict.items(): 36 | inversed_model_dict[v] = k 37 | 38 | to_tensor = T.PILToTensor() 39 | blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18)) 40 | 41 | 42 | class ProcessingState(Enum): 43 | NULL = 0 44 | FIRST_IMG = 1 45 | KEY_IMGS = 2 46 | 47 | 48 | class GlobalState: 49 | 50 | def __init__(self): 51 | self.sd_model = None 52 | self.ddim_v_sampler = None 53 | self.detector_type = None 54 | self.detector = None 55 | self.controller = None 56 | self.processing_state = ProcessingState.NULL 57 | flow_model = GMFlow( 58 | feature_channels=128, 59 | num_scales=1, 60 | upsample_factor=8, 61 | num_head=1, 62 | attention_type='swin', 63 | ffn_dim_expansion=4, 64 | num_transformer_layers=6, 65 | ).to('cuda') 66 | 67 | checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth', 68 | map_location=lambda storage, loc: storage) 69 | weights = checkpoint['model'] if 'model' in checkpoint else checkpoint 70 | flow_model.load_state_dict(weights, strict=False) 71 | flow_model.eval() 72 | self.flow_model = flow_model 73 | 74 | def update_controller(self, inner_strength, mask_period, cross_period, 75 | ada_period, warp_period, loose_cfattn): 76 | self.controller = AttentionControl(inner_strength, 77 | mask_period, 78 | cross_period, 79 | ada_period, 80 | warp_period, 81 | loose_cfatnn=loose_cfattn) 82 | 83 | def update_sd_model(self, sd_model, control_type, freeu_args): 84 | if sd_model == self.sd_model: 85 | return 86 | self.sd_model = sd_model 87 | model = create_model('./deps/ControlNet/models/cldm_v15.yaml').cpu() 88 | if control_type == 'HED': 89 | model.load_state_dict( 90 | load_state_dict('./models/control_sd15_hed.pth', 91 | location='cuda')) 92 | elif control_type == 'canny': 93 | model.load_state_dict( 94 | load_state_dict('./models/control_sd15_canny.pth', 95 | location='cuda')) 96 | model = model.cuda() 97 | sd_model_path = model_dict[sd_model] 98 | if len(sd_model_path) > 0: 99 | model_ext = os.path.splitext(sd_model_path)[1] 100 | if model_ext == '.safetensors': 101 | model.load_state_dict(load_file(sd_model_path), strict=False) 102 | elif model_ext == '.ckpt' or model_ext == '.pth': 103 | model.load_state_dict(torch.load(sd_model_path)['state_dict'], 104 | strict=False) 105 | 106 | try: 107 | model.first_stage_model.load_state_dict(torch.load( 108 | './models/vae-ft-mse-840000-ema-pruned.ckpt')['state_dict'], 109 | strict=False) 110 | except Exception: 111 | print('Warning: We suggest you download the fine-tuned VAE', 112 | 'otherwise the generation quality will be degraded') 113 | 114 | model.model.diffusion_model.forward = freeu_forward( 115 | model.model.diffusion_model, *freeu_args) 116 | self.ddim_v_sampler = DDIMVSampler(model) 117 | 118 | def clear_sd_model(self): 119 | self.sd_model = None 120 | self.ddim_v_sampler = None 121 | torch.cuda.empty_cache() 122 | 123 | def update_detector(self, control_type, canny_low=100, canny_high=200): 124 | if 
self.detector_type == control_type: 125 | return 126 | if control_type == 'HED': 127 | self.detector = HEDdetector() 128 | elif control_type == 'canny': 129 | canny_detector = CannyDetector() 130 | low_threshold = canny_low 131 | high_threshold = canny_high 132 | 133 | def apply_canny(x): 134 | return canny_detector(x, low_threshold, high_threshold) 135 | 136 | self.detector = apply_canny 137 | 138 | 139 | global_state = GlobalState() 140 | global_video_path = None 141 | video_frame_count = None 142 | 143 | 144 | def create_cfg(input_path, prompt, image_resolution, control_strength, 145 | color_preserve, left_crop, right_crop, top_crop, bottom_crop, 146 | control_type, low_threshold, high_threshold, ddim_steps, scale, 147 | seed, sd_model, a_prompt, n_prompt, interval, keyframe_count, 148 | x0_strength, use_constraints, cross_start, cross_end, 149 | style_update_freq, warp_start, warp_end, mask_start, mask_end, 150 | ada_start, ada_end, mask_strength, inner_strength, 151 | smooth_boundary, loose_cfattn, b1, b2, s1, s2): 152 | use_warp = 'shape-aware fusion' in use_constraints 153 | use_mask = 'pixel-aware fusion' in use_constraints 154 | use_ada = 'color-aware AdaIN' in use_constraints 155 | 156 | if not use_warp: 157 | warp_start = 1 158 | warp_end = 0 159 | 160 | if not use_mask: 161 | mask_start = 1 162 | mask_end = 0 163 | 164 | if not use_ada: 165 | ada_start = 1 166 | ada_end = 0 167 | 168 | input_name = os.path.split(input_path)[-1].split('.')[0] 169 | frame_count = 2 + keyframe_count * interval 170 | cfg = RerenderConfig() 171 | cfg.create_from_parameters( 172 | input_path, 173 | os.path.join('result', input_name, 'blend.mp4'), 174 | prompt, 175 | a_prompt=a_prompt, 176 | n_prompt=n_prompt, 177 | frame_count=frame_count, 178 | interval=interval, 179 | crop=[left_crop, right_crop, top_crop, bottom_crop], 180 | sd_model=sd_model, 181 | ddim_steps=ddim_steps, 182 | scale=scale, 183 | control_type=control_type, 184 | control_strength=control_strength, 185 | canny_low=low_threshold, 186 | canny_high=high_threshold, 187 | seed=seed, 188 | image_resolution=image_resolution, 189 | x0_strength=x0_strength, 190 | style_update_freq=style_update_freq, 191 | cross_period=(cross_start, cross_end), 192 | warp_period=(warp_start, warp_end), 193 | mask_period=(mask_start, mask_end), 194 | ada_period=(ada_start, ada_end), 195 | mask_strength=mask_strength, 196 | inner_strength=inner_strength, 197 | smooth_boundary=smooth_boundary, 198 | color_preserve=color_preserve, 199 | loose_cfattn=loose_cfattn, 200 | freeu_args=[b1, b2, s1, s2]) 201 | return cfg 202 | 203 | 204 | def cfg_to_input(filename): 205 | 206 | cfg = RerenderConfig() 207 | cfg.create_from_path(filename) 208 | keyframe_count = (cfg.frame_count - 2) // cfg.interval 209 | use_constraints = [ 210 | 'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN' 211 | ] 212 | 213 | sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5') 214 | 215 | args = [ 216 | cfg.input_path, cfg.prompt, cfg.image_resolution, cfg.control_strength, 217 | cfg.color_preserve, *cfg.crop, cfg.control_type, cfg.canny_low, 218 | cfg.canny_high, cfg.ddim_steps, cfg.scale, cfg.seed, sd_model, 219 | cfg.a_prompt, cfg.n_prompt, cfg.interval, keyframe_count, 220 | cfg.x0_strength, use_constraints, *cfg.cross_period, 221 | cfg.style_update_freq, *cfg.warp_period, *cfg.mask_period, 222 | *cfg.ada_period, cfg.mask_strength, cfg.inner_strength, 223 | cfg.smooth_boundary, cfg.loose_cfattn, *cfg.freeu_args 224 | ] 225 | return args 226 | 227 | 228 | def 
setup_color_correction(image): 229 | correction_target = cv2.cvtColor(np.asarray(image.copy()), 230 | cv2.COLOR_RGB2LAB) 231 | return correction_target 232 | 233 | 234 | def apply_color_correction(correction, original_image): 235 | image = Image.fromarray( 236 | cv2.cvtColor( 237 | exposure.match_histograms(cv2.cvtColor(np.asarray(original_image), 238 | cv2.COLOR_RGB2LAB), 239 | correction, 240 | channel_axis=2), 241 | cv2.COLOR_LAB2RGB).astype('uint8')) 242 | 243 | image = blendLayers(image, original_image, BlendType.LUMINOSITY) 244 | 245 | return image 246 | 247 | 248 | @torch.no_grad() 249 | def process(*args): 250 | args_wo_process3 = args[:-2] 251 | first_frame = process1(*args_wo_process3) 252 | 253 | keypath = process2(*args_wo_process3) 254 | 255 | fullpath = process3(*args) 256 | 257 | return first_frame, keypath, fullpath 258 | 259 | 260 | @torch.no_grad() 261 | def process1(*args): 262 | 263 | global global_video_path 264 | cfg = create_cfg(global_video_path, *args) 265 | global global_state 266 | global_state.update_sd_model(cfg.sd_model, cfg.control_type, 267 | cfg.freeu_args) 268 | global_state.update_controller(cfg.inner_strength, cfg.mask_period, 269 | cfg.cross_period, cfg.ada_period, 270 | cfg.warp_period, cfg.loose_cfattn) 271 | global_state.update_detector(cfg.control_type, cfg.canny_low, 272 | cfg.canny_high) 273 | global_state.processing_state = ProcessingState.FIRST_IMG 274 | 275 | prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution, cfg.crop, cfg.use_limit_device_resolution) 276 | 277 | ddim_v_sampler = global_state.ddim_v_sampler 278 | model = ddim_v_sampler.model 279 | detector = global_state.detector 280 | controller = global_state.controller 281 | model.control_scales = [cfg.control_strength] * 13 282 | 283 | num_samples = 1 284 | eta = 0.0 285 | imgs = sorted(os.listdir(cfg.input_dir)) 286 | imgs = [os.path.join(cfg.input_dir, img) for img in imgs] 287 | 288 | with torch.no_grad(): 289 | frame = cv2.imread(imgs[0]) 290 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 291 | img = HWC3(frame) 292 | H, W, C = img.shape 293 | 294 | img_ = numpy2tensor(img) 295 | 296 | def generate_first_img(img_, strength): 297 | encoder_posterior = model.encode_first_stage(img_.cuda()) 298 | x0 = model.get_first_stage_encoding(encoder_posterior).detach() 299 | 300 | detected_map = detector(img) 301 | detected_map = HWC3(detected_map) 302 | 303 | control = torch.from_numpy( 304 | detected_map.copy()).float().cuda() / 255.0 305 | control = torch.stack([control for _ in range(num_samples)], dim=0) 306 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 307 | cond = { 308 | 'c_concat': [control], 309 | 'c_crossattn': [ 310 | model.get_learned_conditioning( 311 | [cfg.prompt + ', ' + cfg.a_prompt] * num_samples) 312 | ] 313 | } 314 | un_cond = { 315 | 'c_concat': [control], 316 | 'c_crossattn': 317 | [model.get_learned_conditioning([cfg.n_prompt] * num_samples)] 318 | } 319 | shape = (4, H // 8, W // 8) 320 | 321 | controller.set_task('initfirst') 322 | seed_everything(cfg.seed) 323 | 324 | samples, _ = ddim_v_sampler.sample( 325 | cfg.ddim_steps, 326 | num_samples, 327 | shape, 328 | cond, 329 | verbose=False, 330 | eta=eta, 331 | unconditional_guidance_scale=cfg.scale, 332 | unconditional_conditioning=un_cond, 333 | controller=controller, 334 | x0=x0, 335 | strength=strength) 336 | x_samples = model.decode_first_stage(samples) 337 | x_samples_np = ( 338 | einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 339 | 127.5).cpu().numpy().clip(0, 
255).astype(np.uint8) 340 | return x_samples, x_samples_np 341 | 342 | # When not preserve color, draw a different frame at first and use its 343 | # color to redraw the first frame. 344 | if not cfg.color_preserve: 345 | first_strength = -1 346 | else: 347 | first_strength = 1 - cfg.x0_strength 348 | 349 | x_samples, x_samples_np = generate_first_img(img_, first_strength) 350 | 351 | if not cfg.color_preserve: 352 | color_corrections = setup_color_correction( 353 | Image.fromarray(x_samples_np[0])) 354 | global_state.color_corrections = color_corrections 355 | img_ = apply_color_correction(color_corrections, 356 | Image.fromarray(img)) 357 | img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1 358 | x_samples, x_samples_np = generate_first_img( 359 | img_, 1 - cfg.x0_strength) 360 | 361 | global_state.first_result = x_samples 362 | global_state.first_img = img 363 | 364 | Image.fromarray(x_samples_np[0]).save( 365 | os.path.join(cfg.first_dir, 'first.jpg')) 366 | 367 | return x_samples_np[0] 368 | 369 | 370 | @torch.no_grad() 371 | def process2(*args): 372 | global global_state 373 | global global_video_path 374 | 375 | if global_state.processing_state != ProcessingState.FIRST_IMG: 376 | raise gr.Error('Please generate the first key image before generating' 377 | ' all key images') 378 | 379 | cfg = create_cfg(global_video_path, *args) 380 | global_state.update_sd_model(cfg.sd_model, cfg.control_type, 381 | cfg.freeu_args) 382 | global_state.update_detector(cfg.control_type, cfg.canny_low, 383 | cfg.canny_high) 384 | global_state.processing_state = ProcessingState.KEY_IMGS 385 | 386 | # reset key dir 387 | shutil.rmtree(cfg.key_dir) 388 | os.makedirs(cfg.key_dir, exist_ok=True) 389 | 390 | ddim_v_sampler = global_state.ddim_v_sampler 391 | model = ddim_v_sampler.model 392 | detector = global_state.detector 393 | controller = global_state.controller 394 | flow_model = global_state.flow_model 395 | model.control_scales = [cfg.control_strength] * 13 396 | 397 | num_samples = 1 398 | eta = 0.0 399 | firstx0 = True 400 | pixelfusion = cfg.use_mask 401 | imgs = sorted(os.listdir(cfg.input_dir)) 402 | imgs = [os.path.join(cfg.input_dir, img) for img in imgs] 403 | 404 | first_result = global_state.first_result 405 | first_img = global_state.first_img 406 | pre_result = first_result 407 | pre_img = first_img 408 | 409 | for i in range(0, min(len(imgs), cfg.frame_count) - 1, cfg.interval): 410 | cid = i + 1 411 | print(cid) 412 | if cid <= (len(imgs) - 1): 413 | frame = cv2.imread(imgs[cid]) 414 | else: 415 | frame = cv2.imread(imgs[len(imgs) - 1]) 416 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 417 | img = HWC3(frame) 418 | H, W, C = img.shape 419 | 420 | if cfg.color_preserve or global_state.color_corrections is None: 421 | img_ = numpy2tensor(img) 422 | else: 423 | img_ = apply_color_correction(global_state.color_corrections, 424 | Image.fromarray(img)) 425 | img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1 426 | encoder_posterior = model.encode_first_stage(img_.cuda()) 427 | x0 = model.get_first_stage_encoding(encoder_posterior).detach() 428 | 429 | detected_map = detector(img) 430 | detected_map = HWC3(detected_map) 431 | 432 | control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 433 | control = torch.stack([control for _ in range(num_samples)], dim=0) 434 | control = einops.rearrange(control, 'b h w c -> b c h w').clone() 435 | cond = { 436 | 'c_concat': [control], 437 | 'c_crossattn': [ 438 | model.get_learned_conditioning( 439 | [cfg.prompt + ', ' + 
cfg.a_prompt] * num_samples) 440 | ] 441 | } 442 | un_cond = { 443 | 'c_concat': [control], 444 | 'c_crossattn': 445 | [model.get_learned_conditioning([cfg.n_prompt] * num_samples)] 446 | } 447 | shape = (4, H // 8, W // 8) 448 | 449 | cond['c_concat'] = [control] 450 | un_cond['c_concat'] = [control] 451 | 452 | image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float() 453 | image2 = torch.from_numpy(img).permute(2, 0, 1).float() 454 | warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask( 455 | flow_model, image1, image2, pre_result, False) 456 | blend_mask_pre = blur( 457 | F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4)) 458 | blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1) 459 | 460 | image1 = torch.from_numpy(first_img).permute(2, 0, 1).float() 461 | warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask( 462 | flow_model, image1, image2, first_result, False) 463 | blend_mask_0 = blur( 464 | F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4)) 465 | blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1) 466 | 467 | if firstx0: 468 | mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8) 469 | controller.set_warp( 470 | F.interpolate(bwd_flow_0 / 8.0, 471 | scale_factor=1. / 8, 472 | mode='bilinear'), mask) 473 | else: 474 | mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8) 475 | controller.set_warp( 476 | F.interpolate(bwd_flow_pre / 8.0, 477 | scale_factor=1. / 8, 478 | mode='bilinear'), mask) 479 | 480 | controller.set_task('keepx0, keepstyle') 481 | seed_everything(cfg.seed) 482 | samples, intermediates = ddim_v_sampler.sample( 483 | cfg.ddim_steps, 484 | num_samples, 485 | shape, 486 | cond, 487 | verbose=False, 488 | eta=eta, 489 | unconditional_guidance_scale=cfg.scale, 490 | unconditional_conditioning=un_cond, 491 | controller=controller, 492 | x0=x0, 493 | strength=1 - cfg.x0_strength) 494 | direct_result = model.decode_first_stage(samples) 495 | 496 | if not pixelfusion: 497 | pre_result = direct_result 498 | pre_img = img 499 | viz = ( 500 | einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 + 501 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 502 | 503 | else: 504 | 505 | blend_results = (1 - blend_mask_pre 506 | ) * warped_pre + blend_mask_pre * direct_result 507 | blend_results = ( 508 | 1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results 509 | 510 | bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1) 511 | blend_mask = blur( 512 | F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4)) 513 | blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1) 514 | 515 | encoder_posterior = model.encode_first_stage(blend_results) 516 | xtrg = model.get_first_stage_encoding( 517 | encoder_posterior).detach() # * mask 518 | blend_results_rec = model.decode_first_stage(xtrg) 519 | encoder_posterior = model.encode_first_stage(blend_results_rec) 520 | xtrg_rec = model.get_first_stage_encoding( 521 | encoder_posterior).detach() 522 | xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec)) # * mask 523 | blend_results_rec_new = model.decode_first_stage(xtrg_) 524 | tmp = (abs(blend_results_rec_new - blend_results).mean( 525 | dim=1, keepdims=True) > 0.25).float() 526 | mask_x = F.max_pool2d((F.interpolate( 527 | tmp, scale_factor=1 / 8., mode='bilinear') > 0).float(), 528 | kernel_size=3, 529 | stride=1, 530 | padding=1) 531 | 532 | mask = (1 - F.max_pool2d(1 - blend_mask, kernel_size=8) 533 | ) # * (1-mask_x) 534 | 535 | if cfg.smooth_boundary: 536 | noise_rescale = find_flat_region(mask) 537 | else: 538 | 
noise_rescale = torch.ones_like(mask) 539 | masks = [] 540 | for j in range(cfg.ddim_steps): 541 | if j <= cfg.ddim_steps * cfg.mask_period[ 542 | 0] or j >= cfg.ddim_steps * cfg.mask_period[1]: 543 | masks += [None] 544 | else: 545 | masks += [mask * cfg.mask_strength] 546 | 547 | # mask 3 548 | # xtrg = ((1-mask_x) * 549 | # (xtrg + xtrg - xtrg_rec) + mask_x * samples) * mask 550 | # mask 2 551 | # xtrg = (xtrg + 1 * (xtrg - xtrg_rec)) * mask 552 | xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask # mask 1 553 | 554 | tasks = 'keepstyle, keepx0' 555 | if not firstx0: 556 | tasks += ', updatex0' 557 | if i % cfg.style_update_freq == 0: 558 | tasks += ', updatestyle' 559 | controller.set_task(tasks, 1.0) 560 | 561 | seed_everything(cfg.seed) 562 | samples, _ = ddim_v_sampler.sample( 563 | cfg.ddim_steps, 564 | num_samples, 565 | shape, 566 | cond, 567 | verbose=False, 568 | eta=eta, 569 | unconditional_guidance_scale=cfg.scale, 570 | unconditional_conditioning=un_cond, 571 | controller=controller, 572 | x0=x0, 573 | strength=1 - cfg.x0_strength, 574 | xtrg=xtrg, 575 | mask=masks, 576 | noise_rescale=noise_rescale) 577 | x_samples = model.decode_first_stage(samples) 578 | pre_result = x_samples 579 | pre_img = img 580 | 581 | viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 582 | 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 583 | 584 | Image.fromarray(viz[0]).save( 585 | os.path.join(cfg.key_dir, f'{cid:04d}.png')) 586 | 587 | key_video_path = os.path.join(cfg.work_dir, 'key.mp4') 588 | fps = get_fps(cfg.input_path) 589 | fps //= cfg.interval 590 | frame_to_video(key_video_path, cfg.key_dir, fps, False) 591 | 592 | return key_video_path 593 | 594 | 595 | @torch.no_grad() 596 | def process3(*args): 597 | max_process = args[-2] 598 | use_poisson = args[-1] 599 | args = args[:-2] 600 | global global_video_path 601 | global global_state 602 | if global_state.processing_state != ProcessingState.KEY_IMGS: 603 | raise gr.Error('Please generate key images before propagation') 604 | 605 | global_state.clear_sd_model() 606 | 607 | cfg = create_cfg(global_video_path, *args) 608 | 609 | # reset blend dir 610 | blend_dir = os.path.join(cfg.work_dir, 'blend') 611 | if os.path.exists(blend_dir): 612 | shutil.rmtree(blend_dir) 613 | os.makedirs(blend_dir, exist_ok=True) 614 | 615 | video_base_dir = cfg.work_dir 616 | o_video = cfg.output_path 617 | fps = get_fps(cfg.input_path) 618 | 619 | end_frame = cfg.frame_count - 1 620 | interval = cfg.interval 621 | key_dir = os.path.split(cfg.key_dir)[-1] 622 | o_video_cmd = f'--output {o_video}' 623 | ps = '-ps' if use_poisson else '' 624 | cmd = (f'python video_blend.py {video_base_dir} --beg 1 --end {end_frame} ' 625 | f'--itv {interval} --key {key_dir} {o_video_cmd} --fps {fps} ' 626 | f'--n_proc {max_process} {ps}') 627 | print(cmd) 628 | os.system(cmd) 629 | 630 | return o_video 631 | 632 | 633 | block = gr.Blocks().queue() 634 | with block: 635 | with gr.Row(): 636 | gr.Markdown('## Rerender A Video') 637 | with gr.Row(): 638 | with gr.Column(): 639 | input_path = gr.Video(label='Input Video', 640 | source='upload', 641 | format='mp4', 642 | visible=True) 643 | prompt = gr.Textbox(label='Prompt') 644 | seed = gr.Slider(label='Seed', 645 | minimum=0, 646 | maximum=2147483647, 647 | step=1, 648 | value=0, 649 | randomize=True) 650 | run_button = gr.Button(value='Run All') 651 | with gr.Row(): 652 | run_button1 = gr.Button(value='Run 1st Key Frame') 653 | run_button2 = gr.Button(value='Run Key Frames') 654 | run_button3 = 
gr.Button(value='Run Propagation') 655 | with gr.Accordion('Advanced options for the 1st frame translation', 656 | open=False): 657 | image_resolution = gr.Slider(label='Frame resolution', 658 | minimum=256, 659 | maximum=768, 660 | value=512, 661 | step=64) 662 | control_strength = gr.Slider(label='ControlNet strength', 663 | minimum=0.0, 664 | maximum=2.0, 665 | value=1.0, 666 | step=0.01) 667 | x0_strength = gr.Slider( 668 | label='Denoising strength', 669 | minimum=0.00, 670 | maximum=1.05, 671 | value=0.75, 672 | step=0.05, 673 | info=('0: fully recover the input. ' 674 | '1.05: fully rerender the input.')) 675 | color_preserve = gr.Checkbox( 676 | label='Preserve color', 677 | value=True, 678 | info='Keep the color of the input video') 679 | with gr.Row(): 680 | left_crop = gr.Slider(label='Left crop length', 681 | minimum=0, 682 | maximum=512, 683 | value=0, 684 | step=1) 685 | right_crop = gr.Slider(label='Right crop length', 686 | minimum=0, 687 | maximum=512, 688 | value=0, 689 | step=1) 690 | with gr.Row(): 691 | top_crop = gr.Slider(label='Top crop length', 692 | minimum=0, 693 | maximum=512, 694 | value=0, 695 | step=1) 696 | bottom_crop = gr.Slider(label='Bottom crop length', 697 | minimum=0, 698 | maximum=512, 699 | value=0, 700 | step=1) 701 | with gr.Row(): 702 | control_type = gr.Dropdown(['HED', 'canny'], 703 | label='Control type', 704 | value='HED') 705 | low_threshold = gr.Slider(label='Canny low threshold', 706 | minimum=1, 707 | maximum=255, 708 | value=100, 709 | step=1) 710 | high_threshold = gr.Slider(label='Canny high threshold', 711 | minimum=1, 712 | maximum=255, 713 | value=200, 714 | step=1) 715 | ddim_steps = gr.Slider(label='Steps', 716 | minimum=20, 717 | maximum=100, 718 | value=20, 719 | step=20) 720 | scale = gr.Slider(label='CFG scale', 721 | minimum=0.1, 722 | maximum=30.0, 723 | value=7.5, 724 | step=0.1) 725 | sd_model_list = list(model_dict.keys()) 726 | sd_model = gr.Dropdown(sd_model_list, 727 | label='Base model', 728 | value='Stable Diffusion 1.5') 729 | a_prompt = gr.Textbox(label='Added prompt', 730 | value='best quality, extremely detailed') 731 | n_prompt = gr.Textbox( 732 | label='Negative prompt', 733 | value=('longbody, lowres, bad anatomy, bad hands, ' 734 | 'missing fingers, extra digit, fewer digits, ' 735 | 'cropped, worst quality, low quality')) 736 | with gr.Row(): 737 | b1 = gr.Slider(label='FreeU first-stage backbone factor', 738 | minimum=1, 739 | maximum=1.6, 740 | value=1, 741 | step=0.01, 742 | info='FreeU to enhance texture and color') 743 | b2 = gr.Slider(label='FreeU second-stage backbone factor', 744 | minimum=1, 745 | maximum=1.6, 746 | value=1, 747 | step=0.01) 748 | with gr.Row(): 749 | s1 = gr.Slider(label='FreeU first-stage skip factor', 750 | minimum=0, 751 | maximum=1, 752 | value=1, 753 | step=0.01) 754 | s2 = gr.Slider(label='FreeU second-stage skip factor', 755 | minimum=0, 756 | maximum=1, 757 | value=1, 758 | step=0.01) 759 | with gr.Accordion('Advanced options for the key frame translation', 760 | open=False): 761 | interval = gr.Slider( 762 | label='Key frame frequency (K)', 763 | minimum=1, 764 | maximum=1, 765 | value=1, 766 | step=1, 767 | info='Uniformly sample the key frames every K frames') 768 | keyframe_count = gr.Slider(label='Number of key frames', 769 | minimum=1, 770 | maximum=1, 771 | value=1, 772 | step=1) 773 | 774 | use_constraints = gr.CheckboxGroup( 775 | [ 776 | 'shape-aware fusion', 'pixel-aware fusion', 777 | 'color-aware AdaIN' 778 | ], 779 | label='Select the cross-frame constraints to 
be used', 780 | value=[ 781 | 'shape-aware fusion', 'pixel-aware fusion', 782 | 'color-aware AdaIN' 783 | ]), 784 | with gr.Row(): 785 | cross_start = gr.Slider( 786 | label='Cross-frame attention start', 787 | minimum=0, 788 | maximum=1, 789 | value=0, 790 | step=0.05) 791 | cross_end = gr.Slider(label='Cross-frame attention end', 792 | minimum=0, 793 | maximum=1, 794 | value=1, 795 | step=0.05) 796 | style_update_freq = gr.Slider( 797 | label='Cross-frame attention update frequency', 798 | minimum=1, 799 | maximum=100, 800 | value=1, 801 | step=1, 802 | info=('Update the key and value for ' 803 | 'cross-frame attention every N key frames')) 804 | loose_cfattn = gr.Checkbox( 805 | label='Loose Cross-frame attention', 806 | value=True, 807 | info='Select to make output better match the input video') 808 | with gr.Row(): 809 | warp_start = gr.Slider(label='Shape-aware fusion start', 810 | minimum=0, 811 | maximum=1, 812 | value=0, 813 | step=0.05) 814 | warp_end = gr.Slider(label='Shape-aware fusion end', 815 | minimum=0, 816 | maximum=1, 817 | value=0.1, 818 | step=0.05) 819 | with gr.Row(): 820 | mask_start = gr.Slider(label='Pixel-aware fusion start', 821 | minimum=0, 822 | maximum=1, 823 | value=0.5, 824 | step=0.05) 825 | mask_end = gr.Slider(label='Pixel-aware fusion end', 826 | minimum=0, 827 | maximum=1, 828 | value=0.8, 829 | step=0.05) 830 | with gr.Row(): 831 | ada_start = gr.Slider(label='Color-aware AdaIN start', 832 | minimum=0, 833 | maximum=1, 834 | value=0.8, 835 | step=0.05) 836 | ada_end = gr.Slider(label='Color-aware AdaIN end', 837 | minimum=0, 838 | maximum=1, 839 | value=1, 840 | step=0.05) 841 | mask_strength = gr.Slider(label='Pixel-aware fusion strength', 842 | minimum=0, 843 | maximum=1, 844 | value=0.5, 845 | step=0.01) 846 | inner_strength = gr.Slider( 847 | label='Pixel-aware fusion detail level', 848 | minimum=0.5, 849 | maximum=1, 850 | value=0.9, 851 | step=0.01, 852 | info='Use a low value to prevent artifacts') 853 | smooth_boundary = gr.Checkbox( 854 | label='Smooth fusion boundary', 855 | value=True, 856 | info='Select to prevent artifacts at boundary') 857 | with gr.Accordion( 858 | 'Advanced options for the full video translation', 859 | open=False): 860 | use_poisson = gr.Checkbox( 861 | label='Gradient blending', 862 | value=True, 863 | info=('Blend the output video in gradient, to reduce' 864 | ' ghosting artifacts (but may increase flickers)')) 865 | max_process = gr.Slider(label='Number of parallel processes', 866 | minimum=1, 867 | maximum=16, 868 | value=4, 869 | step=1) 870 | 871 | with gr.Accordion('Example configs', open=True): 872 | config_dir = 'config' 873 | config_list = [ 874 | 'real2sculpture.json', 'van_gogh_man.json', 'woman.json' 875 | ] 876 | args_list = [] 877 | for config in config_list: 878 | try: 879 | config_path = os.path.join(config_dir, config) 880 | args = cfg_to_input(config_path) 881 | args_list.append(args) 882 | except FileNotFoundError: 883 | # The video file does not exist, skipped 884 | pass 885 | 886 | ips = [ 887 | prompt, image_resolution, control_strength, color_preserve, 888 | left_crop, right_crop, top_crop, bottom_crop, control_type, 889 | low_threshold, high_threshold, ddim_steps, scale, seed, 890 | sd_model, a_prompt, n_prompt, interval, keyframe_count, 891 | x0_strength, use_constraints[0], cross_start, cross_end, 892 | style_update_freq, warp_start, warp_end, mask_start, 893 | mask_end, ada_start, ada_end, mask_strength, 894 | inner_strength, smooth_boundary, loose_cfattn, b1, b2, s1, 895 | s2 896 | ] 897 
| 898 | gr.Examples( 899 | examples=args_list, 900 | inputs=[input_path, *ips], 901 | ) 902 | 903 | with gr.Column(): 904 | result_image = gr.Image(label='Output first frame', 905 | type='numpy', 906 | interactive=False) 907 | result_keyframe = gr.Video(label='Output key frame video', 908 | format='mp4', 909 | interactive=False) 910 | result_video = gr.Video(label='Output full video', 911 | format='mp4', 912 | interactive=False) 913 | 914 | def input_uploaded(path): 915 | frame_count = get_frame_count(path) 916 | if frame_count <= 2: 917 | raise gr.Error('The input video is too short!' 918 | 'Please input another video.') 919 | 920 | default_interval = min(10, frame_count - 2) 921 | max_keyframe = (frame_count - 2) // default_interval 922 | 923 | global video_frame_count 924 | video_frame_count = frame_count 925 | global global_video_path 926 | global_video_path = path 927 | 928 | return gr.Slider.update(value=default_interval, 929 | maximum=max_keyframe), gr.Slider.update( 930 | value=max_keyframe, maximum=max_keyframe) 931 | 932 | def input_changed(path): 933 | frame_count = get_frame_count(path) 934 | if frame_count <= 2: 935 | return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1) 936 | 937 | default_interval = min(10, frame_count - 2) 938 | max_keyframe = (frame_count - 2) // default_interval 939 | 940 | global video_frame_count 941 | video_frame_count = frame_count 942 | global global_video_path 943 | global_video_path = path 944 | 945 | return gr.Slider.update(maximum=max_keyframe), \ 946 | gr.Slider.update(maximum=max_keyframe) 947 | 948 | def interval_changed(interval): 949 | global video_frame_count 950 | if video_frame_count is None: 951 | return gr.Slider.update() 952 | 953 | max_keyframe = (video_frame_count - 2) // interval 954 | 955 | return gr.Slider.update(value=max_keyframe, maximum=max_keyframe) 956 | 957 | input_path.change(input_changed, input_path, [interval, keyframe_count]) 958 | input_path.upload(input_uploaded, input_path, [interval, keyframe_count]) 959 | interval.change(interval_changed, interval, keyframe_count) 960 | 961 | ips_process3 = [*ips, max_process, use_poisson] 962 | run_button.click(fn=process, 963 | inputs=ips_process3, 964 | outputs=[result_image, result_keyframe, result_video]) 965 | run_button1.click(fn=process1, inputs=ips, outputs=[result_image]) 966 | run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe]) 967 | run_button3.click(fn=process3, inputs=ips_process3, outputs=[result_video]) 968 | 969 | block.launch(server_name='localhost') 970 | --------------------------------------------------------------------------------
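Usage note (an addition to the dump above, not a file in the repository): the full-video propagation stage in webUI.py's process3 simply shells out to video_blend.py, so the same step can be driven by hand. Below is a minimal sketch of that call; the working directory name, frame range, FPS, and key-frame subfolder are illustrative assumptions and should match whatever the earlier stages actually produced.

import os
import subprocess

# Hypothetical working directory created by the key-frame translation stage;
# it is expected to contain the extracted frames (video/) and the stylized
# key frames (keys0/).
work_dir = 'result/pexels-koolshooters-7322716'
output_video = os.path.join(work_dir, 'blend.mp4')

# Mirrors the command assembled in process3(): blend frames 1..101 with a key
# frame every 10 frames, 4 parallel Ebsynth workers, Poisson blending (-ps).
cmd = (f'python video_blend.py {work_dir} --beg 1 --end 101 --itv 10 '
       f'--key keys0 --output {output_video} --fps 25 --n_proc 4 -ps')
subprocess.run(cmd, shell=True, check=True)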