├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── LICENSE.md
├── README.md
├── blender
    ├── guide.py
    ├── histogram_blend.py
    ├── poisson_fusion.py
    └── video_sequence.py
├── config
    ├── real2sculpture.json
    ├── real2sculpture_freeu.json
    ├── real2sculpture_loose_cfattn.json
    ├── van_gogh_man.json
    ├── van_gogh_man_dynamic_resolution.json
    └── woman.json
├── environment.yml
├── flow
    └── flow_utils.py
├── inference_playground.ipynb
├── install.py
├── requirements.txt
├── rerender.py
├── sd_model_cfg.py
├── src
    ├── config.py
    ├── controller.py
    ├── ddim_v_hacked.py
    ├── freeu.py
    ├── img_util.py
    ├── import_util.py
    └── video_util.py
├── video_blend.py
├── videos
    ├── pexels-antoni-shkraba-8048492-540x960-25fps.mp4
    ├── pexels-cottonbro-studio-6649832-960x506-25fps.mp4
    └── pexels-koolshooters-7322716.mp4
└── webUI.py


/.gitignore:
--------------------------------------------------------------------------------
  1 | videos
  2 | models
  3 | result
  4 | # Byte-compiled / optimized / DLL files
  5 | __pycache__/
  6 | *.py[cod]
  7 | *$py.class
  8 | **/*.pyc
  9 | 
 10 | # C extensions
 11 | *.so
 12 | 
 13 | # Distribution / packaging
 14 | .Python
 15 | build/
 16 | develop-eggs/
 17 | dist/
 18 | downloads/
 19 | eggs/
 20 | .eggs/
 21 | lib/
 22 | lib64/
 23 | parts/
 24 | sdist/
 25 | var/
 26 | wheels/
 27 | *.egg-info/
 28 | .installed.cfg
 29 | *.egg
 30 | MANIFEST
 31 | 
 32 | # PyInstaller
 33 | #  Usually these files are written by a python script from a template
 34 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 35 | *.manifest
 36 | *.spec
 37 | 
 38 | # Installer logs
 39 | pip-log.txt
 40 | pip-delete-this-directory.txt
 41 | 
 42 | # Unit test / coverage reports
 43 | htmlcov/
 44 | .tox/
 45 | .coverage
 46 | .coverage.*
 47 | .cache
 48 | nosetests.xml
 49 | coverage.xml
 50 | *.cover
 51 | .hypothesis/
 52 | .pytest_cache/
 53 | 
 54 | # Translations
 55 | *.mo
 56 | *.pot
 57 | 
 58 | # Django stuff:
 59 | *.log
 60 | local_settings.py
 61 | db.sqlite3
 62 | 
 63 | # Flask stuff:
 64 | instance/
 65 | .webassets-cache
 66 | 
 67 | # Scrapy stuff:
 68 | .scrapy
 69 | 
 70 | # Sphinx documentation
 71 | docs/en/_build/
 72 | docs/zh_cn/_build/
 73 | 
 74 | # PyBuilder
 75 | target/
 76 | 
 77 | # Jupyter Notebook
 78 | .ipynb_checkpoints
 79 | 
 80 | # pyenv
 81 | .python-version
 82 | 
 83 | # celery beat schedule file
 84 | celerybeat-schedule
 85 | 
 86 | # SageMath parsed files
 87 | *.sage.py
 88 | 
 89 | # Environments
 90 | .env
 91 | .venv
 92 | env/
 93 | venv/
 94 | ENV/
 95 | env.bak/
 96 | venv.bak/
 97 | 
 98 | # Spyder project settings
 99 | .spyderproject
100 | .spyproject
101 | 
102 | # Rope project settings
103 | .ropeproject
104 | 
105 | # mkdocs documentation
106 | /site
107 | 
108 | # mypy
109 | .mypy_cache/
110 | 
111 | # custom
112 | .vscode
113 | .idea
114 | *.pkl
115 | *.pkl.json
116 | *.log.json
117 | work_dirs/
118 | 
119 | # Pytorch
120 | *.pth
121 | 
122 | # onnx and tensorrt
123 | *.onnx
124 | *.trt
125 | 
126 | # local history
127 | .history/**
128 | 
129 | # Pytorch Server
130 | *.mar
131 | .DS_Store
132 | 
133 | /data/
134 | /data
135 | data
136 | .vector_cache
137 | 
138 | __pycache
139 | 


--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
 1 | [submodule "deps/gmflow"]
 2 | 	path = deps/gmflow
 3 | 	url = https://github.com/haofeixu/gmflow.git
 4 | [submodule "deps/ControlNet"]
 5 | 	path = deps/ControlNet
 6 | 	url = https://github.com/lllyasviel/ControlNet.git
 7 | [submodule "deps/ebsynth"]
 8 | 	path = deps/ebsynth
 9 | 	url = https://github.com/SingleZombie/ebsynth.git
10 | 


--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
 1 | repos:
 2 |   - repo: https://github.com/PyCQA/flake8
 3 |     rev: 4.0.1
 4 |     hooks:
 5 |       - id: flake8
 6 |   - repo: https://github.com/PyCQA/isort
 7 |     rev: 5.11.5
 8 |     hooks:
 9 |       - id: isort
10 |   - repo: https://github.com/pre-commit/mirrors-yapf
11 |     rev: v0.32.0
12 |     hooks:
13 |       - id: yapf
14 |   - repo: https://github.com/pre-commit/pre-commit-hooks
15 |     rev: v4.2.0
16 |     hooks:
17 |       - id: trailing-whitespace
18 |       - id: check-yaml
19 |       - id: end-of-file-fixer
20 |       - id: requirements-txt-fixer
21 |       - id: double-quote-string-fixer
22 |       - id: check-merge-conflict
23 |       - id: fix-encoding-pragma
24 |         args: ["--remove"]
25 |       - id: mixed-line-ending
26 |         args: ["--fix=lf"]
27 | 


--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
 1 | # S-Lab License 1.0
 2 | 
 3 | Copyright 2023 S-Lab
 4 | 
 5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
 8 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\
 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
10 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
11 | 
12 | 
13 | ---
14 | For the commercial use of the code, please consult Prof. Chen Change Loy (ccloy@ntu.edu.sg)
15 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Rerender A Video - Official PyTorch Implementation
  2 | 
  3 | ![teaser](https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/aa7dc164-dab7-43f4-a46b-758b34911f16)
  4 | 
  5 | <!--https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/82c35efb-e86b-4376-bfbe-6b69159b8879-->
  6 | 
  7 | 
  8 | **Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation**<br>
  9 | [Shuai Yang](https://williamyang1991.github.io/), [Yifan Zhou](https://zhouyifan.net/), [Ziwei Liu](https://liuziwei7.github.io/) and [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)<br>
 10 | in SIGGRAPH Asia 2023 Conference Proceedings <br>
 11 | [**Project Page**](https://www.mmlab-ntu.com/project/rerender/) | [**Paper**](https://arxiv.org/abs/2306.07954) | [**Supplementary Video**](https://youtu.be/cxfxdepKVaM) | [**Input Data and Video Results**](https://drive.google.com/file/d/1HkxG5eiLM_TQbbMZYOwjDbd5gWisOy4m/view?usp=sharing) <br>
 12 | 
 13 | <a href="https://huggingface.co/spaces/Anonymous-sub/Rerender"><img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm-dark.svg" alt="Web Demo"></a> ![visitors](https://visitor-badge.laobi.icu/badge?page_id=williamyang1991/Rerender_A_Video)
 14 | 
 15 | > **Abstract:** *Large text-to-image diffusion models have exhibited impressive proficiency in generating high-quality images. However, when applying these models to video domain, ensuring temporal consistency across video frames remains a formidable challenge. This paper proposes a novel zero-shot text-guided video-to-video translation framework to adapt image models to videos. The framework includes two parts: key frame translation and full video translation. The first part uses an adapted diffusion model to generate key frames, with hierarchical cross-frame constraints applied to enforce coherence in shapes, textures and colors. The second part propagates the key frames to other frames with temporal-aware patch matching and frame blending. Our framework achieves global style and local texture temporal consistency at a low cost (without re-training or optimization). The adaptation is compatible with existing image diffusion techniques, allowing our framework to take advantage of them, such as customizing a specific subject with LoRA, and introducing extra spatial guidance with ControlNet. Extensive experimental results demonstrate the effectiveness of our proposed framework over existing methods in rendering high-quality and temporally-coherent videos.*
 16 | 
 17 | **Features**:<br>
 18 | - **Temporal consistency**: cross-frame constraints for low-level temporal consistency.
 19 | - **Zero-shot**: no training or fine-tuning required.
 20 | - **Flexibility**: compatible with off-the-shelf models (e.g., [ControlNet](https://github.com/lllyasviel/ControlNet), [LoRA](https://civitai.com/)) for customized translation.
 21 | 
 22 | https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/811fdea3-f0da-49c9-92b8-2d2ad360f0d6
 23 | 
 24 | ## Updates
 25 | - [12/2023] The Diffusers pipeline is available: [Rerender_A_Video Community Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#Rerender_A_Video)
 26 | - [10/2023] New features: [Loose cross-frame attention](#loose-cross-frame-attention) and [FreeU](#freeu).
 27 | - [09/2023] Code is released.
 28 | - [09/2023] Accepted to SIGGRAPH Asia 2023 Conference Proceedings!
 29 | - [06/2023] Integrated to 🤗 [Hugging Face](https://huggingface.co/spaces/Anonymous-sub/Rerender). Enjoy the web demo!
 30 | - [05/2023] This website is created.
 31 | 
 32 | ### TODO
 33 | - [x] ~~Integrate into Diffusers.~~
 34 | - [x] ~~Integrate [FreeU](https://github.com/ChenyangSi/FreeU) into Rerender~~
 35 | - [x] ~~Add Inference instructions in README.md.~~
 36 | - [x] ~~Add Examples to webUI.~~
 37 | - [x] ~~Add optional poisson fusion to the pipeline.~~
 38 | - [x] ~~Add Installation instructions for Windows~~
 39 | 
 40 | ## Installation
 41 | 
 42 | *Please make sure your installation path contains only English letters or _*
 43 | 
 44 | 1. Clone the repository. (Don't forget --recursive. Otherwise, please run `git submodule update --init --recursive`)
 45 | 
 46 | ```shell
 47 | git clone git@github.com:williamyang1991/Rerender_A_Video.git --recursive
 48 | cd Rerender_A_Video
 49 | ```
 50 | 
 51 | 2. If you have already installed PyTorch with CUDA support, you can simply set up the environment with pip.
 52 | 
 53 | ```shell
 54 | pip install -r requirements.txt
 55 | ```
 56 | 
 57 | You can also create a new conda environment from scratch.
 58 | 
 59 | ```shell
 60 | conda env create -f environment.yml
 61 | conda activate rerender
 62 | ```
 63 | 24GB VRAM is required. Please refer to https://github.com/williamyang1991/Rerender_A_Video/pull/23#issue-1900789461 to reduce memory consumption.
 64 | 
 65 | 3. Run the installation script. The required models will be downloaded in `./models`.
 66 | 
 67 | ```shell
 68 | python install.py
 69 | ```
 70 | 
 71 | 4. You can run the demo with `rerender.py`
 72 | 
 73 | ```shell
 74 | python rerender.py --cfg config/real2sculpture.json
 75 | ```
 76 | 
 77 | <details>
 78 | <summary>Installation on Windows</summary>
 79 | 
 80 |   Before running steps 1-4 above, you need to prepare:
 81 | 1. Install [CUDA](https://developer.nvidia.com/cuda-toolkit-archive)
 82 | 2. Install [git](https://git-scm.com/download/win)
 83 | 3. Install [VS](https://visualstudio.microsoft.com/) with Windows 10/11 SDK (for building deps/ebsynth/bin/ebsynth.exe)
 84 | 4. [Here](https://github.com/williamyang1991/Rerender_A_Video/issues/18#issuecomment-1752712233) is more information. If building ebsynth fails, we provide our compiled [ebsynth](https://drive.google.com/drive/folders/1oSB3imKwZGz69q2unBUfcgmQpzwccoyD?usp=sharing).
 85 | </details>
 86 | 
 87 | <details id="issues">
 88 | <summary>🔥🔥🔥 <b>Installation or Running Fails?</b> 🔥🔥🔥</summary>
 89 | 
 90 | 1. In case building ebsynth fails, we provide our compiled [ebsynth](https://drive.google.com/drive/folders/1oSB3imKwZGz69q2unBUfcgmQpzwccoyD?usp=sharing)
 91 | 2. `FileNotFoundError: [Errno 2] No such file or directory: 'xxxx.bin' or 'xxxx.jpg'`:
 92 |     - make sure your path only contains English letters or _ (https://github.com/williamyang1991/Rerender_A_Video/issues/18#issuecomment-1723361433)
 93 |     - find the code `python video_blend.py ...` in the error log and use it to manually run the ebsynth part, which is more stable than WebUI.
 94 |     - if some non-keyframes are generated but some are not (rather than all non-keyframes missing in '/out_xx/'), you may refer to https://github.com/williamyang1991/Rerender_A_Video/issues/38#issuecomment-1730668991
 95 |     - Enable the execute permission of `deps/ebsynth/bin/ebsynth`
 96 |     - Enable the debug log to find more information: https://github.com/williamyang1991/Rerender_A_Video/blob/d32b1d6b6c1305ddd06e66868c5dcf4fb7aa048c/video_blend.py#L22
 97 | 3. `KeyError: 'dataset'`: upgrade Gradio to the latest version (https://github.com/williamyang1991/Rerender_A_Video/issues/14#issuecomment-1722778672, https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/11855)
 98 | 4. Error when processing videos: manually install ffmpeg (https://github.com/williamyang1991/Rerender_A_Video/issues/19#issuecomment-1723685825, https://github.com/williamyang1991/Rerender_A_Video/issues/29#issuecomment-1726091112)
 99 | 5. `ERR_ADDRESS_INVALID` (cannot open the WebUI in the browser): replace 0.0.0.0 with 127.0.0.1 in webUI.py (https://github.com/williamyang1991/Rerender_A_Video/issues/19#issuecomment-1723685825)
100 | 6. `CUDA out of memory`:
101 |      - Use xformers (https://github.com/williamyang1991/Rerender_A_Video/pull/23#issue-1900789461)
102 |      - Set `"use_limit_device_resolution"` to `true` in the config to resize the video according to your VRAM (https://github.com/williamyang1991/Rerender_A_Video/issues/79). An example config `config/van_gogh_man_dynamic_resolution.json` is provided.
103 | 7. `AttributeError: module 'keras.backend' has no attribute 'is_tensor'`: update einops (https://github.com/williamyang1991/Rerender_A_Video/issues/26#issuecomment-1726682446)
104 | 8. `IndexError: list index out of range`: use the original DDIM steps of 20 (https://github.com/williamyang1991/Rerender_A_Video/issues/30#issuecomment-1729039779)
105 | 9. One-click installation: https://github.com/williamyang1991/Rerender_A_Video/issues/99
106 | 
107 | </details>
108 | 
109 | 
110 | ## (1) Inference
111 | 
112 | ### WebUI (recommended)
113 | 
114 | ```
115 | python webUI.py
116 | ```
117 | The Gradio app also allows you to flexibly change the inference options; just try it out to explore them. (For the WebUI, you need to download [revAnimated_v11](https://civitai.com/models/7371/rev-animated?modelVersionId=19575) and [realisticVisionV20_v20](https://civitai.com/models/4201?modelVersionId=29460) to `./models/` after installation.)
118 | 
119 | Upload your video, input the prompt, select the seed, and hit:
120 | - **Run 1st Key Frame**: only translate the first frame, so you can adjust the prompts/models/parameters to find your ideal output appearance before running the whole video.
121 | - **Run Key Frames**: translate all the key frames based on the settings of the first frame, so you can adjust the temporal-related parameters for better temporal consistency before running the whole video.
122 | - **Run Propagation**: propagate the key frames to other frames for full video translation
123 | - **Run All**: **Run 1st Key Frame**, **Run Key Frames** and **Run Propagation**
124 | 
125 | ![UI](https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/eb4e1ddc-11a3-42dd-baa4-622eecef04c7)
126 | 
127 | 
128 | We provide abundant advanced options to play with.
129 | 
130 | <details id="option0">
131 | <summary> <b>Using customized models</b></summary>
132 | 
133 | - Using LoRA/Dreambooth/Finetuned/Mixed SD models
134 |   - Modify `sd_model_cfg.py` to add paths to the saved SD models (a hypothetical sketch of such an entry is given after this box)
135 |   - How to use LoRA: https://github.com/williamyang1991/Rerender_A_Video/issues/39#issuecomment-1730678296
136 | - Using other controls from ControlNet (e.g., Depth, Pose)
137 |   - Add more options like `control_type = gr.Dropdown(['HED', 'canny', 'depth'])` here https://github.com/williamyang1991/Rerender_A_Video/blob/b6cafb5d80a79a3ef831c689ffad92ec095f2794/webUI.py#L690
138 |   - Add model loading options like `elif control_type == 'depth':` following https://github.com/williamyang1991/Rerender_A_Video/blob/b6cafb5d80a79a3ef831c689ffad92ec095f2794/webUI.py#L88
139 |   - Add model detectors like `elif control_type == 'depth':` following https://github.com/williamyang1991/Rerender_A_Video/blob/b6cafb5d80a79a3ef831c689ffad92ec095f2794/webUI.py#L122
140 |   - One example is given [here](https://huggingface.co/spaces/Anonymous-sub/Rerender/discussions/10/files)
141 | 
142 | </details>
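As a companion to the customized-model notes above, the sketch below shows the kind of name-to-path mapping that `sd_model_cfg.py` is meant to hold. The dictionary name and the exact entries are illustrative assumptions; mirror whatever structure the file in this repo actually uses when registering your own model.

```python
# Hypothetical sketch of a model registry in the style of sd_model_cfg.py.
# The variable name and keys are assumptions for illustration only; copy the
# structure of the real file when adding a new checkpoint.
model_dict = {
    'Stable Diffusion 1.5': '',  # e.g. an empty path could mean the stock SD 1.5 weights
    'revAnimated_v11': 'models/revAnimated_v11.safetensors',
    'realisticVisionV20_v20': 'models/realisticVisionV20_v20.safetensors',
    # 'my_custom_model': 'models/my_custom_model.safetensors',  # your own checkpoint
}
```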
143 | 
144 | <details id="option1">
145 | <summary> <b>Advanced options for the 1st frame translation</b></summary>
146 | 
147 | 1. Resolution related (**Frame resolution**, **left/top/right/bottom crop length**): crop the frame and resize its short side to 512.
148 | 2. ControlNet related:
149 |    - **ControlNet strength**: how well the output matches the input control edges
150 |    - **Control type**: HED edge or Canny edge
151 |    - **Canny low/high threshold**: low values for more edge details
152 | 3. SDEdit related:
153 |    - **Denoising strength**: repaint degree (low value to make the output look more like the original video)
154 |    - **Preserve color**: preserve the color of the original video
155 | 4. SD related:
156 |    - **Steps**: the number of denoising steps
157 |    - **CFG scale**: how well the output matches the prompt
158 |    - **Base model**: base Stable Diffusion model (SD 1.5)
159 |      - Stable Diffusion 1.5: official model
160 |      - [revAnimated_v11](https://civitai.com/models/7371/rev-animated?modelVersionId=19575): a semi-realistic (2.5D) model
161 |      - [realisticVisionV20_v20](https://civitai.com/models/4201?modelVersionId=29460): a photo-realistic model
162 |    - **Added prompt/Negative prompt**: supplementary prompts
163 | 5. FreeU related:
164 |    - **FreeU first/second-stage backbone factor**: =1 do nothing; >1 enhance output color and details
165 |    - **FreeU first/second-stage skip factor**: =1 do nothing; <1 enhance output color and details
166 | 
167 | </details>
168 | 
169 | <details id="option2">
170 | <summary> <b>Advanced options for the key frame translation</b></summary>
171 | 
172 | 1. Key frame related
173 |    - **Key frame frequency (K)**: Uniformly sample a key frame every K frames. Use a small value for videos with large or fast motion.
174 |    - **Number of key frames (M)**: The final output video will have K*M+1 frames with M+1 key frames (a worked example follows this box).
175 | 2. Temporal consistency related
176 |    - Cross-frame attention:
177 |      - **Cross-frame attention start/end**: When to apply cross-frame attention for global style consistency
178 |      - **Cross-frame attention update frequency (N)**: Update the reference style frame every N key frames. Should be large for long videos to avoid error accumulation.
179 |      - **Loose Cross-frame attention**: Use cross-frame attention in fewer layers to better match the input video (for videos with large motion)
180 |    - **Shape-aware fusion**: Check to use this feature
181 |      - **Shape-aware fusion start/end**: When to apply shape-aware fusion for local shape consistency
182 |    - **Pixel-aware fusion**: Check to use this feature
183 |      - **Pixel-aware fusion start/end**: When to apply pixel-aware fusion for pixel-level temporal consistency
184 |      - **Pixel-aware fusion strength**: The strength to preserve the non-inpainting region. Small to avoid error accumulation. Large to avoid blurry textures.
185 |      - **Pixel-aware fusion detail level**: The strength to sharpen the inpainting region. Small to avoid error accumulation. Large to avoid blurry textures.
186 |      - **Smooth fusion boundary**: Check to smooth the inpainting boundary (avoids error accumulation).
187 |    - **Color-aware AdaIN**: Check to use this feature
188 |      - **Color-aware AdaIN start/end**: When to apply AdaIN to make the video color consistent with the first frame
189 | 
190 | </details>
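To make the K and M bookkeeping above concrete, here is a small worked example (plain Python, not part of this repo), assuming frames are indexed from 1 as in the Ebsynth example further below and using the interval-10 setup of the sample configs:

```python
# Worked example (not part of the repo): which frame indices become key frames
# for key frame frequency K and number of key frames M. The output has K*M+1
# frames in total, of which M+1 (frames 1, 1+K, ..., 1+K*M) are key frames.
def key_frame_indices(K, M, first_frame=1):
    return [first_frame + K * i for i in range(M + 1)]

print(key_frame_indices(K=10, M=10))
# -> [1, 11, 21, 31, 41, 51, 61, 71, 81, 91, 101]: 101 = 10*10+1 frames, 11 key frames
```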
191 | 
192 | <details id="option3">
193 | <summary> <b>Advanced options for the full video translation</b></summary>
194 | 
195 | 1. **Gradient blending**: apply Poisson blending to reduce ghosting artifacts. It may slow down the process and increase flickering.
196 | 2. **Number of parallel processes**: use multiprocessing to speed up the process. A large value (8) is recommended.
197 | </details>
198 | 
199 | ![options](https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/ffebac15-e7e0-4cd4-a8fe-60f243450172)
200 | 
201 | 
202 | ### Command Line
203 | 
204 | We also provide a flexible script `rerender.py` to run our method.
205 | 
206 | #### Simple mode
207 | 
208 | Set the options via command line. For example,
209 | 
210 | ```shell
211 | python rerender.py --input videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4 --output result/man/man.mp4 --prompt "a handsome man in van gogh painting"
212 | ```
213 | 
214 | The script will run the full pipeline. A work directory will be created at `result/man` and the result video will be saved as `result/man/man.mp4`
215 | 
216 | #### Advanced mode
217 | 
218 | Set the options via a config file. For example,
219 | 
220 | ```shell
221 | python rerender.py --cfg config/van_gogh_man.json
222 | ```
223 | 
224 | The script will run the full pipeline.
225 | We provide some example configs in the `config` directory.
226 | Most options in the config are the same as those in the WebUI.
227 | Please check the explanations in the WebUI section.
228 | 
229 | Specify customized models by setting `sd_model` in the config. For example:
230 | ```json
231 | {
232 |   "sd_model": "models/realisticVisionV20_v20.safetensors",
233 | }
234 | ```
235 | 
236 | #### Customize the pipeline
237 | 
238 | Similar to the WebUI, we provide a three-step workflow: rerender the first key frame, then rerender the full key frames, and finally rerender the full video with propagation. To run only a single step, combine the options `-one`, `-nb` and `-nr` as follows:
239 | 
240 | 1. Rerender the first key frame
241 | ```shell
242 | python rerender.py --cfg config/van_gogh_man.json -one -nb
243 | ```
244 | 2. Rerender the full key frames
245 | ```shell
246 | python rerender.py --cfg config/van_gogh_man.json -nb
247 | ```
248 | 3. Rerender the full video with propagation
249 | ```shell
250 | python rerender.py --cfg config/van_gogh_man.json -nr
251 | ```
252 | 
253 | #### Our Ebsynth implementation
254 | 
255 | We provide a separate Ebsynth Python script `video_blend.py` with the temporal blending algorithm introduced in
256 | [Stylizing Video by Example](https://dcgi.fel.cvut.cz/home/sykorad/ebsynth.html) for interpolating style between key frames.
257 | It can work on your own stylized key frames independently of our Rerender algorithm.
258 | 
259 | Usage:
260 | ```shell
261 | video_blend.py [-h] [--output OUTPUT] [--fps FPS] [--beg BEG] [--end END] [--itv ITV] [--key KEY]
262 |                       [--n_proc N_PROC] [-ps] [-ne] [-tmp]
263 |                       name
264 | 
265 | positional arguments:
266 |   name             Path to input video
267 | 
268 | optional arguments:
269 |   -h, --help       show this help message and exit
270 |   --output OUTPUT  Path to output video
271 |   --fps FPS        The FPS of output video
272 |   --beg BEG        The index of the first frame to be stylized
273 |   --end END        The index of the last frame to be stylized
274 |   --itv ITV        The interval of key frame
275 |   --key KEY        The subfolder name of stylized key frames
276 |   --n_proc N_PROC  The max process count
277 |   -ps              Use poisson gradient blending
278 |   -ne              Do not run ebsynth (use previous ebsynth output)
279 |   -tmp             Keep temporary output
280 | ```
281 | For example, to run Ebsynth on video `man.mp4`,
282 | 1. Put the stylized key frames (one every 10 frames, named `0001.png`, `0011.png`, ...) in `videos/man/keys`.
283 | 2. Put the original video frames in `videos/man/video` (named `0001.png`, `0002.png`, ...); a frame-extraction sketch is given after the command below.
284 | 3. Run Ebsynth on the first 101 frames of the video with Poisson gradient blending and save the result to `videos/man/blend.mp4` at 25 FPS with the following command:
285 | ```shell
286 | python video_blend.py videos/man \
287 |   --beg 1 \
288 |   --end 101 \
289 |   --itv 10 \
290 |   --key keys \
291 |   --output videos/man/blend.mp4 \
292 |   --fps 25.0 \
293 |   -ps
294 | ```
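If you do not yet have the frame folder from step 2, a minimal extraction sketch with OpenCV could look like the following. This is an assumption-level helper (it is not the repo's own `video_util.py`), and the source path `videos/man.mp4` is only a placeholder for your clip:

```python
# Minimal sketch (plain OpenCV, not the repo's video_util.py helper) that dumps
# a clip into numbered frames 0001.png, 0002.png, ... as expected by video_blend.py.
import os

import cv2

src = 'videos/man.mp4'        # placeholder path to your input clip
out_dir = 'videos/man/video'  # frame folder expected in step 2
os.makedirs(out_dir, exist_ok=True)

cap = cv2.VideoCapture(src)
idx = 1
while True:
    ok, frame = cap.read()
    if not ok:
        break
    cv2.imwrite(os.path.join(out_dir, f'{idx:04d}.png'), frame)
    idx += 1
cap.release()
```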
295 | 
296 | ## (2) Results
297 | 
298 | ### Key frame translation
299 | 
300 | 
301 | <table class="center">
302 | <tr>
303 |   <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/18666871-f273-44b2-ae67-7be85d43e2f6" raw=true></td>
304 |   <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/61f59540-f06e-4e5a-86b6-1d7cb8ed6300" raw=true></td>
305 |   <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/8e8ad51a-6a71-4b34-8633-382192d0f17c" raw=true></td>
306 |   <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/b03cd35f-5d90-471a-9aa9-5c7773d7ac39" raw=true></td>
307 | </tr>
308 | <tr>
309 |   <td width=27.5% align="center">white ancient Greek sculpture, Venus de Milo, light pink and blue background</td>
310 |   <td width=27.5% align="center">a handsome Greek man</td>
311 |   <td width=21.5% align="center">a traditional mountain in chinese ink wash painting</td>
312 |   <td width=23.5% align="center">a cartoon tiger</td>
313 | </tr>
314 | </table>
315 | 
316 | <table class="center">
317 | <tr>
318 |   <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/649a789e-0c41-41cf-94a4-0d524dcfb282" raw=true></td>
319 |   <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/73590c16-916f-4ee6-881a-44a201dd85dd" raw=true></td>
320 |   <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/fbdc0b8e-6046-414f-a37e-3cd9dd0adf5d" raw=true></td>
321 |   <td><img src="https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/eb11d807-2afa-4609-a074-34300b67e6aa" raw=true></td>
322 | </tr>
323 | <tr>
324 |   <td width=26.0% align="center">a swan in chinese ink wash painting, monochrome</td>
325 |   <td width=29.0% align="center">a beautiful woman in CG style</td>
326 |   <td width=21.5% align="center">a clean simple white jade sculpture</td>
327 |   <td width=24.0% align="center">a fluorescent jellyfish in the deep dark blue sea</td>
328 | </tr>
329 | </table>
330 | 
331 | ### Full video translation
332 | 
333 | Text-guided virtual character generation.
334 | 
335 | 
336 | https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/1405b257-e59a-427f-890d-7652e6bed0a4
337 | 
338 | 
339 | https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/efee8cc6-9708-4124-bf6a-49baf91349fc
340 | 
341 | 
342 | Video stylization and video editing.
343 | 
344 | 
345 | https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/1b72585c-99c0-401d-b240-5b8016df7a3f
346 | 
347 | ## New Features
348 | 
349 | Compared to the conference version, we keep adding new features.
350 | 
351 | ![new_feature](https://github.com/williamyang1991/Rerender_A_Video/assets/18130694/98f39f3d-3dfe-4de4-a1b6-99a3c78b5336)
352 | 
353 | #### Loose cross-frame attention
354 | By using cross-frame attention in fewer layers, our results better match the input video, which reduces ghosting artifacts caused by inconsistencies. This feature can be activated by checking `Loose Cross-frame attention` in the <a href="#option2">Advanced options for the key frame translation</a> in the WebUI, or by setting `loose_cfattn` for the script (see `config/real2sculpture_loose_cfattn.json`).
355 | 
356 | #### FreeU
357 | [FreeU](https://github.com/ChenyangSi/FreeU) is a method that improves diffusion model sample quality at no extra cost. We find that with FreeU, our results have higher contrast and saturation, richer details, and more vivid colors. This feature can be used by setting the FreeU backbone factors and skip factors in the <a href="#option1">Advanced options for the 1st frame translation</a> in the WebUI, or by setting `freeu_args` for the script (see `config/real2sculpture_freeu.json`).
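For reference, the sample config `config/real2sculpture_freeu.json` in this repo sets `freeu_args` to `[1.1, 1.2, 1.0, 0.2]`. Assuming the order follows the usual FreeU convention of backbone factors first and skip factors second (an assumption; check `src/freeu.py` to confirm), the four numbers unpack as follows:

```python
# Assumption: freeu_args is ordered (b1, b2, s1, s2) -- the two backbone factors
# followed by the two skip factors; verify against src/freeu.py before relying on it.
b1, b2, s1, s2 = [1.1, 1.2, 1.0, 0.2]  # values from config/real2sculpture_freeu.json
# Backbone factors: 1 means no change, >1 enhances output color and details.
# Skip factors: 1 means no change, <1 enhances output color and details.
print(f'backbone factors: {b1}, {b2}; skip factors: {s1}, {s2}')
```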
358 | 
359 | ## Citation
360 | 
361 | If you find this work useful for your research, please consider citing our paper:
362 | 
363 | ```bibtex
364 | @inproceedings{yang2023rerender,
365 |  title = {Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation},
366 |  author = {Yang, Shuai and Zhou, Yifan and Liu, Ziwei and Loy, Chen Change},
367 |  booktitle = {ACM SIGGRAPH Asia Conference Proceedings},
368 |  year = {2023},
369 | }
370 | ```
371 | 
372 | ## Acknowledgments
373 | 
374 | The code is mainly developed based on [ControlNet](https://github.com/lllyasviel/ControlNet), [Stable Diffusion](https://github.com/Stability-AI/stablediffusion), [GMFlow](https://github.com/haofeixu/gmflow) and [Ebsynth](https://github.com/jamriska/ebsynth).
375 | 


--------------------------------------------------------------------------------
/blender/guide.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | 
  3 | import cv2
  4 | import numpy as np
  5 | 
  6 | from flow.flow_utils import flow_calc, read_flow, read_mask
  7 | 
  8 | 
  9 | class BaseGuide:
 10 | 
 11 |     def __init__(self):
 12 |         ...
 13 | 
 14 |     def get_cmd(self, i, weight) -> str:
 15 |         return (f'-guide {os.path.abspath(self.imgs[0])} '
 16 |                 f'{os.path.abspath(self.imgs[i])} -weight {weight}')
 17 | 
 18 | 
 19 | class ColorGuide(BaseGuide):
 20 | 
 21 |     def __init__(self, imgs):
 22 |         super().__init__()
 23 |         self.imgs = imgs
 24 | 
 25 | 
 26 | class PositionalGuide(BaseGuide):
 27 | 
 28 |     def __init__(self, flow_paths, save_paths):
 29 |         super().__init__()
 30 |         flows = [read_flow(f) for f in flow_paths]
 31 |         masks = [read_mask(f) for f in flow_paths]
 32 |         # TODO: modify the format of flow to numpy
 33 |         H, W = flows[0].shape[2:]
 34 |         first_img = PositionalGuide.__generate_first_img(H, W)
 35 |         prev_img = first_img
 36 |         imgs = [first_img]
 37 |         cid = 0
 38 |         for flow, mask in zip(flows, masks):
 39 |             cur_img = flow_calc.warp(prev_img, flow,
 40 |                                      'nearest').astype(np.uint8)
 41 |             cur_img = cv2.inpaint(cur_img, mask, 30, cv2.INPAINT_TELEA)
 42 |             prev_img = cur_img
 43 |             imgs.append(cur_img)
 44 |             cid += 1
 45 |             cv2.imwrite(f'guide/{cid}.jpg', mask)
 46 | 
 47 |         for path, img in zip(save_paths, imgs):
 48 |             cv2.imwrite(path, img)
 49 |         self.imgs = save_paths
 50 | 
 51 |     @staticmethod
 52 |     def __generate_first_img(H, W):
 53 |         Hs = np.linspace(0, 1, H)
 54 |         Ws = np.linspace(0, 1, W)
 55 |         i, j = np.meshgrid(Hs, Ws, indexing='ij')
 56 |         r = (i * 255).astype(np.uint8)
 57 |         g = (j * 255).astype(np.uint8)
 58 |         b = np.zeros(r.shape)
 59 |         res = np.stack((b, g, r), 2)
 60 |         return res
 61 | 
 62 | 
 63 | class EdgeGuide(BaseGuide):
 64 | 
 65 |     def __init__(self, imgs, save_paths):
 66 |         super().__init__()
 67 |         edges = [EdgeGuide.__generate_edge(cv2.imread(img)) for img in imgs]
 68 |         for path, img in zip(save_paths, edges):
 69 |             cv2.imwrite(path, img)
 70 |         self.imgs = save_paths
 71 | 
 72 |     @staticmethod
 73 |     def __generate_edge(img):
 74 |         filter = np.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]])
 75 |         res = cv2.filter2D(img, -1, filter)
 76 |         return res
 77 | 
 78 | 
 79 | class TemporalGuide(BaseGuide):
 80 | 
 81 |     def __init__(self, key_img, stylized_imgs, flow_paths, save_paths):
 82 |         super().__init__()
 83 |         self.flows = [read_flow(f) for f in flow_paths]
 84 |         self.masks = [read_mask(f) for f in flow_paths]
 85 |         self.stylized_imgs = stylized_imgs
 86 |         self.imgs = save_paths
 87 | 
 88 |         first_img = cv2.imread(key_img)
 89 |         cv2.imwrite(self.imgs[0], first_img)
 90 | 
 91 |     def get_cmd(self, i, weight) -> str:
 92 |         if i == 0:
 93 |             warped_img = self.stylized_imgs[0]
 94 |         else:
 95 |             prev_img = cv2.imread(self.stylized_imgs[i - 1])
 96 |             warped_img = flow_calc.warp(prev_img, self.flows[i - 1],
 97 |                                         'nearest').astype(np.uint8)
 98 | 
 99 |             warped_img = cv2.inpaint(warped_img, self.masks[i - 1], 30,
100 |                                      cv2.INPAINT_TELEA)
101 | 
102 |             cv2.imwrite(self.imgs[i], warped_img)
103 | 
104 |         return super().get_cmd(i, weight)
105 | 


--------------------------------------------------------------------------------
/blender/histogram_blend.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | import numpy as np
 3 | 
 4 | 
 5 | def histogram_transform(img: np.ndarray, means: np.ndarray, stds: np.ndarray,
 6 |                         target_means: np.ndarray, target_stds: np.ndarray):
 7 |     means = means.reshape((1, 1, 3))
 8 |     stds = stds.reshape((1, 1, 3))
 9 |     target_means = target_means.reshape((1, 1, 3))
10 |     target_stds = target_stds.reshape((1, 1, 3))
11 |     x = img.astype(np.float32)
12 |     x = (x - means) * target_stds / stds + target_means
13 |     # x = np.round(x)
14 |     # x = np.clip(x, 0, 255)
15 |     # x = x.astype(np.uint8)
16 |     return x
17 | 
18 | 
19 | def blend(a: np.ndarray,
20 |           b: np.ndarray,
21 |           min_error: np.ndarray,
22 |           weight1=0.5,
23 |           weight2=0.5):
24 |     a = cv2.cvtColor(a, cv2.COLOR_BGR2Lab)
25 |     b = cv2.cvtColor(b, cv2.COLOR_BGR2Lab)
26 |     min_error = cv2.cvtColor(min_error, cv2.COLOR_BGR2Lab)
27 |     a_mean = np.mean(a, axis=(0, 1))
28 |     a_std = np.std(a, axis=(0, 1))
29 |     b_mean = np.mean(b, axis=(0, 1))
30 |     b_std = np.std(b, axis=(0, 1))
31 |     min_error_mean = np.mean(min_error, axis=(0, 1))
32 |     min_error_std = np.std(min_error, axis=(0, 1))
33 | 
34 |     t_mean_val = 0.5 * 256
35 |     t_std_val = (1 / 36) * 256
36 |     t_mean = np.ones([3], dtype=np.float32) * t_mean_val
37 |     t_std = np.ones([3], dtype=np.float32) * t_std_val
38 |     a = histogram_transform(a, a_mean, a_std, t_mean, t_std)
39 | 
40 |     b = histogram_transform(b, b_mean, b_std, t_mean, t_std)
41 |     ab = (a * weight1 + b * weight2 - t_mean_val) / 0.5 + t_mean_val
42 |     ab_mean = np.mean(ab, axis=(0, 1))
43 |     ab_std = np.std(ab, axis=(0, 1))
44 |     ab = histogram_transform(ab, ab_mean, ab_std, min_error_mean,
45 |                              min_error_std)
46 |     ab = np.round(ab)
47 |     ab = np.clip(ab, 0, 255)
48 |     ab = ab.astype(np.uint8)
49 |     ab = cv2.cvtColor(ab, cv2.COLOR_Lab2BGR)
50 |     return ab
51 | 


--------------------------------------------------------------------------------
/blender/poisson_fusion.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | import numpy as np
 3 | import scipy
 4 | 
 5 | As = None
 6 | prev_states = None
 7 | 
 8 | 
 9 | def construct_A(h, w, grad_weight):
10 |     indgx_x = []
11 |     indgx_y = []
12 |     indgy_x = []
13 |     indgy_y = []
14 |     vdx = []
15 |     vdy = []
16 |     for i in range(h):
17 |         for j in range(w):
18 |             if i < h - 1:
19 |                 indgx_x += [i * w + j]
20 |                 indgx_y += [i * w + j]
21 |                 vdx += [1]
22 |                 indgx_x += [i * w + j]
23 |                 indgx_y += [(i + 1) * w + j]
24 |                 vdx += [-1]
25 |             if j < w - 1:
26 |                 indgy_x += [i * w + j]
27 |                 indgy_y += [i * w + j]
28 |                 vdy += [1]
29 |                 indgy_x += [i * w + j]
30 |                 indgy_y += [i * w + j + 1]
31 |                 vdy += [-1]
32 |     Ix = scipy.sparse.coo_array(
33 |         (np.ones(h * w), (np.arange(h * w), np.arange(h * w))),
34 |         shape=(h * w, h * w)).tocsc()
35 |     Gx = scipy.sparse.coo_array(
36 |         (np.array(vdx), (np.array(indgx_x), np.array(indgx_y))),
37 |         shape=(h * w, h * w)).tocsc()
38 |     Gy = scipy.sparse.coo_array(
39 |         (np.array(vdy), (np.array(indgy_x), np.array(indgy_y))),
40 |         shape=(h * w, h * w)).tocsc()
41 |     As = []
42 |     for i in range(3):
43 |         As += [
44 |             scipy.sparse.vstack([Gx * grad_weight[i], Gy * grad_weight[i], Ix])
45 |         ]
46 |     return As
47 | 
48 | 
 49 | # blendI, I1, I2, mask should be OpenCV-style (BGR) uint8 images
 50 | # return poisson fusion result (BGR uint8 type)
 51 | # I1 and I2: propagated results from previous and subsequent key frames
 52 | # mask: pixel selection mask
 53 | # blendI: contrast-preserving blending result of I1 and I2
54 | def poisson_fusion(blendI, I1, I2, mask, grad_weight=[2.5, 0.5, 0.5]):
55 |     global As
56 |     global prev_states
57 | 
58 |     Iab = cv2.cvtColor(blendI, cv2.COLOR_BGR2LAB).astype(float)
59 |     Ia = cv2.cvtColor(I1, cv2.COLOR_BGR2LAB).astype(float)
60 |     Ib = cv2.cvtColor(I2, cv2.COLOR_BGR2LAB).astype(float)
61 |     m = (mask > 0).astype(float)[:, :, np.newaxis]
62 |     h, w, c = Iab.shape
63 | 
64 |     # fuse the gradient of I1 and I2 with mask
65 |     gx = np.zeros_like(Ia)
66 |     gy = np.zeros_like(Ia)
67 |     gx[:-1, :, :] = (Ia[:-1, :, :] - Ia[1:, :, :]) * (1 - m[:-1, :, :]) + (
68 |         Ib[:-1, :, :] - Ib[1:, :, :]) * m[:-1, :, :]
69 |     gy[:, :-1, :] = (Ia[:, :-1, :] - Ia[:, 1:, :]) * (1 - m[:, :-1, :]) + (
70 |         Ib[:, :-1, :] - Ib[:, 1:, :]) * m[:, :-1, :]
71 | 
72 |     # construct A for solving Ax=b
73 |     crt_states = (h, w, grad_weight)
74 |     if As is None or crt_states != prev_states:
75 |         As = construct_A(*crt_states)
76 |         prev_states = crt_states
77 | 
78 |     final = []
79 |     for i in range(3):
80 |         weight = grad_weight[i]
81 |         im_dx = np.clip(gx[:, :, i].reshape(h * w, 1), -100, 100)
82 |         im_dy = np.clip(gy[:, :, i].reshape(h * w, 1), -100, 100)
83 |         im = Iab[:, :, i].reshape(h * w, 1)
84 |         im_mean = im.mean()
85 |         im = im - im_mean
86 |         A = As[i]
87 |         b = np.vstack([im_dx * weight, im_dy * weight, im])
88 |         out = scipy.sparse.linalg.lsqr(A, b)
89 |         out_im = (out[0] + im_mean).reshape(h, w, 1)
90 |         final += [out_im]
91 | 
92 |     final = np.clip(np.concatenate(final, axis=2), 0, 255)
93 |     return cv2.cvtColor(final.astype(np.uint8), cv2.COLOR_LAB2BGR)
94 | 


--------------------------------------------------------------------------------
/blender/video_sequence.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import shutil
  3 | 
  4 | 
  5 | class VideoSequence:
  6 | 
  7 |     def __init__(self,
  8 |                  base_dir,
  9 |                  beg_frame,
 10 |                  end_frame,
 11 |                  interval,
 12 |                  input_subdir='videos',
 13 |                  key_subdir='keys0',
 14 |                  tmp_subdir='tmp',
 15 |                  input_format='frame%04d.jpg',
 16 |                  key_format='%04d.jpg',
 17 |                  out_subdir_format='out_%d',
 18 |                  blending_out_subdir='blend',
 19 |                  output_format='%04d.jpg'):
 20 |         if (end_frame - beg_frame) % interval != 0:
 21 |             end_frame -= (end_frame - beg_frame) % interval
 22 | 
 23 |         self.__base_dir = base_dir
 24 |         self.__input_dir = os.path.join(base_dir, input_subdir)
 25 |         self.__key_dir = os.path.join(base_dir, key_subdir)
 26 |         self.__tmp_dir = os.path.join(base_dir, tmp_subdir)
 27 |         self.__input_format = input_format
 28 |         self.__blending_out_dir = os.path.join(base_dir, blending_out_subdir)
 29 |         self.__key_format = key_format
 30 |         self.__out_subdir_format = out_subdir_format
 31 |         self.__output_format = output_format
 32 |         self.__beg_frame = beg_frame
 33 |         self.__end_frame = end_frame
 34 |         self.__interval = interval
 35 |         self.__n_seq = (end_frame - beg_frame) // interval
 36 |         self.__make_out_dirs()
 37 |         os.makedirs(self.__tmp_dir, exist_ok=True)
 38 | 
 39 |     @property
 40 |     def beg_frame(self):
 41 |         return self.__beg_frame
 42 | 
 43 |     @property
 44 |     def end_frame(self):
 45 |         return self.__end_frame
 46 | 
 47 |     @property
 48 |     def n_seq(self):
 49 |         return self.__n_seq
 50 | 
 51 |     @property
 52 |     def interval(self):
 53 |         return self.__interval
 54 | 
 55 |     @property
 56 |     def blending_dir(self):
 57 |         return os.path.abspath(self.__blending_out_dir)
 58 | 
 59 |     def remove_out_and_tmp(self):
 60 |         for i in range(self.n_seq + 1):
 61 |             out_dir = self.__get_out_subdir(i)
 62 |             shutil.rmtree(out_dir)
 63 |         shutil.rmtree(self.__tmp_dir)
 64 | 
 65 |     def get_input_sequence(self, i, is_forward=True):
 66 |         beg_id = self.get_sequence_beg_id(i)
 67 |         end_id = self.get_sequence_beg_id(i + 1)
 68 |         if is_forward:
 69 |             id_list = list(range(beg_id, end_id))
 70 |         else:
 71 |             id_list = list(range(end_id, beg_id, -1))
 72 |         path_dir = [
 73 |             os.path.join(self.__input_dir, self.__input_format % id)
 74 |             for id in id_list
 75 |         ]
 76 |         return path_dir
 77 | 
 78 |     def get_output_sequence(self, i, is_forward=True):
 79 |         beg_id = self.get_sequence_beg_id(i)
 80 |         end_id = self.get_sequence_beg_id(i + 1)
 81 |         if is_forward:
 82 |             id_list = list(range(beg_id, end_id))
 83 |         else:
 84 |             i += 1
 85 |             id_list = list(range(end_id, beg_id, -1))
 86 |         out_subdir = self.__get_out_subdir(i)
 87 |         path_dir = [
 88 |             os.path.join(out_subdir, self.__output_format % id)
 89 |             for id in id_list
 90 |         ]
 91 |         return path_dir
 92 | 
 93 |     def get_temporal_sequence(self, i, is_forward=True):
 94 |         beg_id = self.get_sequence_beg_id(i)
 95 |         end_id = self.get_sequence_beg_id(i + 1)
 96 |         if is_forward:
 97 |             id_list = list(range(beg_id, end_id))
 98 |         else:
 99 |             i += 1
100 |             id_list = list(range(end_id, beg_id, -1))
101 |         tmp_dir = self.__get_tmp_out_subdir(i)
102 |         path_dir = [
103 |             os.path.join(tmp_dir, 'temporal_' + self.__output_format % id)
104 |             for id in id_list
105 |         ]
106 |         return path_dir
107 | 
108 |     def get_edge_sequence(self, i, is_forward=True):
109 |         beg_id = self.get_sequence_beg_id(i)
110 |         end_id = self.get_sequence_beg_id(i + 1)
111 |         if is_forward:
112 |             id_list = list(range(beg_id, end_id))
113 |         else:
114 |             i += 1
115 |             id_list = list(range(end_id, beg_id, -1))
116 |         tmp_dir = self.__get_tmp_out_subdir(i)
117 |         path_dir = [
118 |             os.path.join(tmp_dir, 'edge_' + self.__output_format % id)
119 |             for id in id_list
120 |         ]
121 |         return path_dir
122 | 
123 |     def get_pos_sequence(self, i, is_forward=True):
124 |         beg_id = self.get_sequence_beg_id(i)
125 |         end_id = self.get_sequence_beg_id(i + 1)
126 |         if is_forward:
127 |             id_list = list(range(beg_id, end_id))
128 |         else:
129 |             i += 1
130 |             id_list = list(range(end_id, beg_id, -1))
131 |         tmp_dir = self.__get_tmp_out_subdir(i)
132 |         path_dir = [
133 |             os.path.join(tmp_dir, 'pos_' + self.__output_format % id)
134 |             for id in id_list
135 |         ]
136 |         return path_dir
137 | 
138 |     def get_flow_sequence(self, i, is_forward=True):
139 |         beg_id = self.get_sequence_beg_id(i)
140 |         end_id = self.get_sequence_beg_id(i + 1)
141 |         if is_forward:
142 |             id_list = list(range(beg_id, end_id - 1))
143 |             path_dir = [
144 |                 os.path.join(self.__tmp_dir, 'flow_f_%04d.npy' % id)
145 |                 for id in id_list
146 |             ]
147 |         else:
148 |             id_list = list(range(end_id, beg_id + 1, -1))
149 |             path_dir = [
150 |                 os.path.join(self.__tmp_dir, 'flow_b_%04d.npy' % id)
151 |                 for id in id_list
152 |             ]
153 | 
154 |         return path_dir
155 | 
156 |     def get_input_img(self, i):
157 |         return os.path.join(self.__input_dir, self.__input_format % i)
158 | 
159 |     def get_key_img(self, i):
160 |         sequence_beg_id = self.get_sequence_beg_id(i)
161 |         return os.path.join(self.__key_dir,
162 |                             self.__key_format % sequence_beg_id)
163 | 
164 |     def get_blending_img(self, i):
165 |         return os.path.join(self.__blending_out_dir, self.__output_format % i)
166 | 
167 |     def get_sequence_beg_id(self, i):
168 |         return i * self.__interval + self.__beg_frame
169 | 
170 |     def __get_out_subdir(self, i):
171 |         dir_id = self.get_sequence_beg_id(i)
172 |         out_subdir = os.path.join(self.__base_dir,
173 |                                   self.__out_subdir_format % dir_id)
174 |         return out_subdir
175 | 
176 |     def __get_tmp_out_subdir(self, i):
177 |         dir_id = self.get_sequence_beg_id(i)
178 |         tmp_out_subdir = os.path.join(self.__tmp_dir,
179 |                                       self.__out_subdir_format % dir_id)
180 |         return tmp_out_subdir
181 | 
182 |     def __make_out_dirs(self):
183 |         os.makedirs(self.__base_dir, exist_ok=True)
184 |         os.makedirs(self.__blending_out_dir, exist_ok=True)
185 |         for i in range(self.__n_seq + 1):
186 |             out_subdir = self.__get_out_subdir(i)
187 |             tmp_subdir = self.__get_tmp_out_subdir(i)
188 |             os.makedirs(out_subdir, exist_ok=True)
189 |             os.makedirs(tmp_subdir, exist_ok=True)
190 | 


--------------------------------------------------------------------------------
/config/real2sculpture.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "input": "videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4",
 3 |     "output": "videos/pexels-cottonbro-studio-6649832-960x506-25fps/blend.mp4",
 4 |     "work_dir": "videos/pexels-cottonbro-studio-6649832-960x506-25fps",
 5 |     "key_subdir": "keys",
 6 |     "sd_model": "models/realisticVisionV20_v20.safetensors",
 7 |     "frame_count": 102,
 8 |     "interval": 10,
 9 |     "crop": [
10 |         0,
11 |         180,
12 |         0,
13 |         0
14 |     ],
15 |     "prompt": "white ancient Greek sculpture, Venus de Milo, light pink and blue background",
16 |     "a_prompt": "RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
17 |     "n_prompt": "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation",
18 |     "x0_strength": 0.95,
19 |     "control_type": "canny",
20 |     "canny_low": 50,
21 |     "canny_high": 100,
22 |     "control_strength": 0.7,
23 |     "seed": 0,
24 |     "warp_period": [
25 |         0,
26 |         0.1
27 |     ],
28 |     "ada_period": [
29 |         0.8,
30 |         1
31 |     ]
32 | }


--------------------------------------------------------------------------------
/config/real2sculpture_freeu.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "input": "videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4",
 3 |     "output": "videos/pexels-cottonbro-studio-6649832-960x506-25fps/blend.mp4",
 4 |     "work_dir": "videos/pexels-cottonbro-studio-6649832-960x506-25fps",
 5 |     "key_subdir": "keys",
 6 |     "sd_model": "models/realisticVisionV20_v20.safetensors",
 7 |     "frame_count": 102,
 8 |     "interval": 10,
 9 |     "crop": [
10 |         0,
11 |         180,
12 |         0,
13 |         0
14 |     ],
15 |     "prompt": "white ancient Greek sculpture, Venus de Milo, light pink and blue background",
16 |     "a_prompt": "RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
17 |     "n_prompt": "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation",
18 |     "x0_strength": 0.95,
19 |     "control_type": "canny",
20 |     "canny_low": 50,
21 |     "canny_high": 100,
22 |     "control_strength": 0.7,
23 |     "seed": 0,
24 |     "warp_period": [
25 |         0,
26 |         0.1
27 |     ],
28 |     "ada_period": [
29 |         0.8,
30 |         1
31 |     ],
32 |     "freeu_args": [
33 |         1.1,
34 |         1.2,
35 |         1.0,
36 |         0.2
37 |     ]
38 | }


--------------------------------------------------------------------------------
/config/real2sculpture_loose_cfattn.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "input": "videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4",
 3 |     "output": "videos/pexels-cottonbro-studio-6649832-960x506-25fps/blend.mp4",
 4 |     "work_dir": "videos/pexels-cottonbro-studio-6649832-960x506-25fps",
 5 |     "key_subdir": "keys",
 6 |     "sd_model": "models/realisticVisionV20_v20.safetensors",
 7 |     "frame_count": 102,
 8 |     "interval": 10,
 9 |     "crop": [
10 |         0,
11 |         180,
12 |         0,
13 |         0
14 |     ],
15 |     "prompt": "white ancient Greek sculpture, Venus de Milo, light pink and blue background",
16 |     "a_prompt": "RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
17 |     "n_prompt": "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation",
18 |     "x0_strength": 0.95,
19 |     "control_type": "canny",
20 |     "canny_low": 50,
21 |     "canny_high": 100,
22 |     "control_strength": 0.7,
23 |     "seed": 0,
24 |     "warp_period": [
25 |         0,
26 |         0.1
27 |     ],
28 |     "ada_period": [
29 |         0.8,
30 |         1
31 |     ],
32 |     "loose_cfattn": true
33 | }


--------------------------------------------------------------------------------
/config/van_gogh_man.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "input": "videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4",
 3 |     "output": "videos/pexels-antoni-shkraba-8048492-540x960-25fps/blend.mp4",
 4 |     "work_dir": "videos/pexels-antoni-shkraba-8048492-540x960-25fps",
 5 |     "key_subdir": "keys",
 6 |     "frame_count": 102,
 7 |     "interval": 10,
 8 |     "crop": [
 9 |         0,
10 |         0,
11 |         0,
12 |         280
13 |     ],
14 |     "prompt": "a handsome man in van gogh painting",
15 |     "a_prompt": "best quality, extremely detailed",
16 |     "n_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
17 |     "x0_strength": 1.05,
18 |     "control_type": "canny",
19 |     "canny_low": 50,
20 |     "canny_high": 100,
21 |     "control_strength": 0.7,
22 |     "seed": 0,
23 |     "warp_period": [
24 |         0,
25 |         0.1
26 |     ],
27 |     "ada_period": [
28 |         0.8,
29 |         1
30 |     ],
31 |     "image_resolution": 512
32 | }
33 | 


--------------------------------------------------------------------------------
/config/van_gogh_man_dynamic_resolution.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "input": "videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4",
 3 |     "output": "videos/pexels-antoni-shkraba-8048492-540x960-25fps/blend.mp4",
 4 |     "work_dir": "videos/pexels-antoni-shkraba-8048492-540x960-25fps",
 5 |     "key_subdir": "keys",
 6 |     "frame_count": 102,
 7 |     "interval": 10,
 8 |     "crop": [
 9 |         0,
10 |         0,
11 |         0,
12 |         280
13 |     ],
14 |     "prompt": "a handsome man in van gogh painting",
15 |     "a_prompt": "best quality, extremely detailed",
16 |     "n_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
17 |     "x0_strength": 1.05,
18 |     "control_type": "canny",
19 |     "canny_low": 50,
20 |     "canny_high": 100,
21 |     "control_strength": 0.7,
22 |     "seed": 0,
23 |     "warp_period": [
24 |         0,
25 |         0.1
26 |     ],
27 |     "ada_period": [
28 |         0.8,
29 |         1
30 |     ],
31 |     "image_resolution": 512,
32 |     "use_limit_device_resolution": true
33 | }


--------------------------------------------------------------------------------
/config/woman.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "input": "videos/pexels-koolshooters-7322716.mp4",
 3 |     "output": "videos/pexels-koolshooters-7322716/blend.mp4",
 4 |     "work_dir": "videos/pexels-koolshooters-7322716",
 5 |     "key_subdir": "keys",
 6 |     "frame_count": 102,
 7 |     "interval": 10,
 8 |     "sd_model": "models/revAnimated_v11.safetensors",
 9 |     "prompt": "a beautiful woman in CG style",
10 |     "a_prompt": "best quality, extremely detailed",
11 |     "n_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
12 |     "x0_strength": 0.75,
13 |     "control_type": "canny",
14 |     "canny_low": 50,
15 |     "canny_high": 100,
16 |     "control_strength": 0.7,
17 |     "seed": 0,
18 |     "warp_period": [
19 |         0,
20 |         0.1
21 |     ],
22 |     "ada_period": [
23 |         1,
24 |         1
25 |     ]
26 | }


--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
 1 | name: rerender
 2 | channels:
 3 |   - pytorch
 4 |   - defaults
 5 | dependencies:
 6 |   - python=3.8.5
 7 |   - pip=20.3
 8 |   - cudatoolkit=11.3
 9 |   - pytorch=1.12.1
10 |   - torchvision=0.13.1
11 |   - numpy=1.23.1
12 |   - pip:
13 |       - gradio==3.44.4
14 |       - albumentations==1.3.0
15 |       - opencv-contrib-python==4.3.0.36
16 |       - imageio==2.9.0
17 |       - imageio-ffmpeg==0.4.2
18 |       - pytorch-lightning==1.5.0
19 |       - omegaconf==2.1.1
20 |       - test-tube>=0.7.5
21 |       - streamlit==1.12.1
22 |       - einops==0.3.0
23 |       - transformers==4.19.2
24 |       - webdataset==0.2.5
25 |       - kornia==0.6
26 |       - open_clip_torch==2.0.2
27 |       - invisible-watermark>=0.1.5
28 |       - streamlit-drawable-canvas==0.8.0
29 |       - torchmetrics==0.6.0
30 |       - timm==0.6.12
31 |       - addict==2.4.0
32 |       - yapf==0.32.0
33 |       - prettytable==3.6.0
34 |       - safetensors==0.2.7
35 |       - basicsr==1.4.2
36 |       - blendmodes
37 |       - numba==0.57.0
38 | 
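
The pins above (python 3.8, pytorch 1.12.1 with cudatoolkit 11.3, torchvision 0.13.1) are what the rest of the code assumes, since every model is moved to 'cuda'. After conda env create -f environment.yml and conda activate rerender, a quick sanity check might look like this (a sketch; the exact version strings can carry a build suffix):

    import torch
    import torchvision

    print(torch.__version__, torchvision.__version__)  # expected to match the 1.12.1 / 0.13.1 pins
    print(torch.cuda.is_available())                    # the pipeline requires a CUDA device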


--------------------------------------------------------------------------------
/flow/flow_utils.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import sys
  3 | 
  4 | import cv2
  5 | import numpy as np
  6 | import torch
  7 | import torch.nn.functional as F
  8 | 
  9 | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 10 | gmflow_dir = os.path.join(parent_dir, 'deps/gmflow')
 11 | sys.path.insert(0, gmflow_dir)
 12 | 
 13 | from gmflow.gmflow import GMFlow  # noqa: E702 E402 F401
 14 | from utils.utils import InputPadder  # noqa: E702 E402
 15 | 
 16 | 
 17 | def coords_grid(b, h, w, homogeneous=False, device=None):
 18 |     y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]
 19 | 
 20 |     stacks = [x, y]
 21 | 
 22 |     if homogeneous:
 23 |         ones = torch.ones_like(x)  # [H, W]
 24 |         stacks.append(ones)
 25 | 
 26 |     grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]
 27 | 
 28 |     grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]
 29 | 
 30 |     if device is not None:
 31 |         grid = grid.to(device)
 32 | 
 33 |     return grid
 34 | 
 35 | 
 36 | def bilinear_sample(img,
 37 |                     sample_coords,
 38 |                     mode='bilinear',
 39 |                     padding_mode='zeros',
 40 |                     return_mask=False):
 41 |     # img: [B, C, H, W]
 42 |     # sample_coords: [B, 2, H, W] in image scale
 43 |     if sample_coords.size(1) != 2:  # [B, H, W, 2]
 44 |         sample_coords = sample_coords.permute(0, 3, 1, 2)
 45 | 
 46 |     b, _, h, w = sample_coords.shape
 47 | 
 48 |     # Normalize to [-1, 1]
 49 |     x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
 50 |     y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
 51 | 
 52 |     grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]
 53 | 
 54 |     img = F.grid_sample(img,
 55 |                         grid,
 56 |                         mode=mode,
 57 |                         padding_mode=padding_mode,
 58 |                         align_corners=True)
 59 | 
 60 |     if return_mask:
 61 |         mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (
 62 |             y_grid <= 1)  # [B, H, W]
 63 | 
 64 |         return img, mask
 65 | 
 66 |     return img
 67 | 
 68 | 
 69 | def flow_warp(feature,
 70 |               flow,
 71 |               mask=False,
 72 |               mode='bilinear',
 73 |               padding_mode='zeros'):
 74 |     b, c, h, w = feature.size()
 75 |     assert flow.size(1) == 2
 76 | 
 77 |     grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]
 78 | 
 79 |     return bilinear_sample(feature,
 80 |                            grid,
 81 |                            mode=mode,
 82 |                            padding_mode=padding_mode,
 83 |                            return_mask=mask)
 84 | 
 85 | 
 86 | def forward_backward_consistency_check(fwd_flow,
 87 |                                        bwd_flow,
 88 |                                        alpha=0.01,
 89 |                                        beta=0.5):
 90 |     # fwd_flow, bwd_flow: [B, 2, H, W]
 91 |     # alpha and beta values are following UnFlow
 92 |     # (https://arxiv.org/abs/1711.07837)
 93 |     assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
 94 |     assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
 95 |     flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow,
 96 |                                                         dim=1)  # [B, H, W]
 97 | 
 98 |     warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
 99 |     warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]
100 | 
101 |     diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
102 |     diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
103 | 
104 |     threshold = alpha * flow_mag + beta
105 | 
106 |     fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
107 |     bwd_occ = (diff_bwd > threshold).float()
108 | 
109 |     return fwd_occ, bwd_occ
110 | 
111 | 
112 | @torch.no_grad()
113 | def get_warped_and_mask(flow_model,
114 |                         image1,
115 |                         image2,
116 |                         image3=None,
117 |                         pixel_consistency=False):
118 |     if image3 is None:
119 |         image3 = image1
120 |     padder = InputPadder(image1.shape, padding_factor=8)
121 |     image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
122 |     results_dict = flow_model(image1,
123 |                               image2,
124 |                               attn_splits_list=[2],
125 |                               corr_radius_list=[-1],
126 |                               prop_radius_list=[-1],
127 |                               pred_bidir_flow=True)
128 |     flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]
129 |     fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]
130 |     bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
131 |     fwd_occ, bwd_occ = forward_backward_consistency_check(
132 |         fwd_flow, bwd_flow)  # [1, H, W] float
133 |     if pixel_consistency:
134 |         warped_image1 = flow_warp(image1, bwd_flow)
135 |         bwd_occ = torch.clamp(
136 |             bwd_occ +
137 |             (abs(image2 - warped_image1).mean(dim=1) > 255 * 0.25).float(), 0,
138 |             1).unsqueeze(0)
139 |     warped_results = flow_warp(image3, bwd_flow)
140 |     return warped_results, bwd_occ, bwd_flow
141 | 
142 | 
143 | class FlowCalc():
144 | 
145 |     def __init__(self, model_path='./models/gmflow_sintel-0c07dcb3.pth'):
146 |         flow_model = GMFlow(
147 |             feature_channels=128,
148 |             num_scales=1,
149 |             upsample_factor=8,
150 |             num_head=1,
151 |             attention_type='swin',
152 |             ffn_dim_expansion=4,
153 |             num_transformer_layers=6,
154 |         ).to('cuda')
155 | 
156 |         checkpoint = torch.load(model_path,
157 |                                 map_location=lambda storage, loc: storage)
158 |         weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
159 |         flow_model.load_state_dict(weights, strict=False)
160 |         flow_model.eval()
161 |         self.model = flow_model
162 | 
163 |     @torch.no_grad()
164 |     def get_flow(self, image1, image2, save_path=None):
165 | 
166 |         if save_path is not None and os.path.exists(save_path):
167 |             bwd_flow = read_flow(save_path)
168 |             return bwd_flow
169 | 
170 |         image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
171 |         image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
172 |         padder = InputPadder(image1.shape, padding_factor=8)
173 |         image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
174 |         results_dict = self.model(image1,
175 |                                   image2,
176 |                                   attn_splits_list=[2],
177 |                                   corr_radius_list=[-1],
178 |                                   prop_radius_list=[-1],
179 |                                   pred_bidir_flow=True)
180 |         flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]
181 |         fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]
182 |         bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
183 |         fwd_occ, bwd_occ = forward_backward_consistency_check(
184 |             fwd_flow, bwd_flow)  # [1, H, W] float
185 |         if save_path is not None:
186 |             flow_np = bwd_flow.cpu().numpy()
187 |             np.save(save_path, flow_np)
188 |             mask_path = os.path.splitext(save_path)[0] + '.png'
189 |             bwd_occ = bwd_occ.cpu().permute(1, 2, 0).to(
190 |                 torch.long).numpy() * 255
191 |             cv2.imwrite(mask_path, bwd_occ)
192 | 
193 |         return bwd_flow
194 | 
195 |     @torch.no_grad()
196 |     def get_mask(self, image1, image2, save_path=None):
197 | 
198 |         if save_path is not None:
199 |             mask_path = os.path.splitext(save_path)[0] + '.png'
200 |             if os.path.exists(mask_path):
201 |                 return read_mask(mask_path)
202 | 
203 |         image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
204 |         image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
205 |         padder = InputPadder(image1.shape, padding_factor=8)
206 |         image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
207 |         results_dict = self.model(image1,
208 |                                   image2,
209 |                                   attn_splits_list=[2],
210 |                                   corr_radius_list=[-1],
211 |                                   prop_radius_list=[-1],
212 |                                   pred_bidir_flow=True)
213 |         flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]
214 |         fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]
215 |         bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
216 |         fwd_occ, bwd_occ = forward_backward_consistency_check(
217 |             fwd_flow, bwd_flow)  # [1, H, W] float
218 |         if save_path is not None:
219 |             flow_np = bwd_flow.cpu().numpy()
220 |             np.save(save_path, flow_np)
221 |             mask_path = os.path.splitext(save_path)[0] + '.png'
222 |             bwd_occ = bwd_occ.cpu().permute(1, 2, 0).to(
223 |                 torch.long).numpy() * 255
224 |             cv2.imwrite(mask_path, bwd_occ)
225 | 
226 |         return bwd_occ
227 | 
228 |     def warp(self, img, flow, mode='bilinear'):
229 |         expand = False
230 |         if len(img.shape) == 2:
231 |             expand = True
232 |             img = np.expand_dims(img, 2)
233 | 
234 |         img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
235 |         dtype = img.dtype
236 |         img = img.to(torch.float)
237 |         res = flow_warp(img, flow, mode=mode)
238 |         res = res.to(dtype)
239 |         res = res[0].cpu().permute(1, 2, 0).numpy()
240 |         if expand:
241 |             res = res[:, :, 0]
242 |         return res
243 | 
244 | 
245 | def read_flow(save_path):
246 |     flow_np = np.load(save_path)
247 |     bwd_flow = torch.from_numpy(flow_np)
248 |     return bwd_flow
249 | 
250 | 
251 | def read_mask(save_path):
252 |     mask_path = os.path.splitext(save_path)[0] + '.png'
253 |     mask = cv2.imread(mask_path)
254 |     mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
255 |     return mask
256 | 
257 | 
258 | flow_calc = FlowCalc()
259 | 
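
A minimal sketch of using the module-level flow_calc wrapper on two same-sized RGB frames (the file names are placeholders). Importing flow.flow_utils already instantiates FlowCalc, so a CUDA device and ./models/gmflow_sintel-0c07dcb3.pth must be available; a freshly computed flow lives on the GPU, so it is moved to the CPU here before warp(), which operates on NumPy images:

    import cv2

    from flow.flow_utils import flow_calc

    # two consecutive frames of the same size, as H x W x 3 uint8 RGB arrays
    frame1 = cv2.cvtColor(cv2.imread('frame_0001.png'), cv2.COLOR_BGR2RGB)
    frame2 = cv2.cvtColor(cv2.imread('frame_0002.png'), cv2.COLOR_BGR2RGB)

    bwd_flow = flow_calc.get_flow(frame1, frame2)    # [1, 2, H, W] backward flow
    occ = flow_calc.get_mask(frame1, frame2)         # [1, H, W] occlusion mask from the consistency check
    warped = flow_calc.warp(frame1, bwd_flow.cpu())  # frame1 resampled into alignment with frame2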


--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import platform
 3 | 
 4 | import requests
 5 | 
 6 | 
 7 | def build_ebsynth():
 8 |     if os.path.exists('deps/ebsynth/bin/ebsynth'):
 9 |         print('Ebsynth has been built.')
10 |         return
11 | 
12 |     os_str = platform.system()
13 | 
14 |     if os_str == 'Windows':
15 |         print('Building Ebsynth for Windows (64 bit).',
16 |               'If you want to build for 32 bit, please modify install.py.')
17 |         cmd = '.\\build-win64-cpu+cuda.bat'
18 |         exe_file = 'deps/ebsynth/bin/ebsynth.exe'
19 |     elif os_str == 'Linux':
20 |         cmd = 'bash build-linux-cpu+cuda.sh'
21 |         exe_file = 'deps/ebsynth/bin/ebsynth'
22 |     elif os_str == 'Darwin':
23 |         cmd = 'sh build-macos-cpu_only.sh'
24 |         exe_file = 'deps/ebsynth/bin/ebsynth.app'
25 |     else:
26 |         print('Cannot recognize OS. Ebsynth installation stopped.')
27 |         return
28 | 
29 |     os.chdir('deps/ebsynth')
30 |     print(cmd)
31 |     os.system(cmd)
32 |     os.chdir('../..')
33 |     if os.path.exists(exe_file):
34 |         print('Ebsynth installed successfully.')
35 |     else:
36 |         print('Failed to install Ebsynth.')
37 | 
38 | 
39 | def download(url, dir, name=None):
40 |     os.makedirs(dir, exist_ok=True)
41 |     if name is None:
42 |         name = url.split('/')[-1]
43 |     path = os.path.join(dir, name)
44 |     if not os.path.exists(path):
45 |         print(f'Downloading {name} ...')
46 |         open(path, 'wb').write(requests.get(url).content)
47 |         print(f'{name} downloaded successfully.')
48 | 
49 | 
50 | def download_gmflow_ckpt():
51 |     url = ('https://huggingface.co/PKUWilliamYang/Rerender/'
52 |            'resolve/main/models/gmflow_sintel-0c07dcb3.pth')
53 |     download(url, 'models')
54 | 
55 | 
56 | def download_controlnet_canny():
57 |     url = ('https://huggingface.co/lllyasviel/ControlNet/'
58 |            'resolve/main/models/control_sd15_canny.pth')
59 |     download(url, 'models')
60 | 
61 | 
62 | def download_controlnet_hed():
63 |     url = ('https://huggingface.co/lllyasviel/ControlNet/'
64 |            'resolve/main/models/control_sd15_hed.pth')
65 |     download(url, 'models')
66 | 
67 | 
68 | def download_vae():
69 |     url = ('https://huggingface.co/stabilityai/sd-vae-ft-mse-original'
70 |            '/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt')
71 |     download(url, 'models')
72 | 
73 | 
74 | build_ebsynth()
75 | download_gmflow_ckpt()
76 | download_controlnet_canny()
77 | download_controlnet_hed()
78 | download_vae()
79 | 
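
Running python install.py builds the Ebsynth binary under deps/ebsynth/bin and downloads the four checkpoints into models/. A small sketch to verify the downloads before launching rerender.py or webUI.py:

    import os

    expected = [
        'models/gmflow_sintel-0c07dcb3.pth',
        'models/control_sd15_canny.pth',
        'models/control_sd15_hed.pth',
        'models/vae-ft-mse-840000-ema-pruned.ckpt',
    ]
    missing = [p for p in expected if not os.path.exists(p)]
    print('All model files present.' if not missing else f'Missing: {missing}')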


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | addict==2.4.0
 2 | albumentations==1.3.0
 3 | basicsr==1.4.2
 4 | blendmodes
 5 | einops==0.3.0
 6 | gradio==3.44.4
 7 | imageio==2.9.0
 8 | imageio-ffmpeg==0.4.2
 9 | invisible-watermark==0.1.5
10 | kornia==0.6
11 | numba==0.57.0
12 | omegaconf==2.1.1
13 | open_clip_torch==2.0.2
14 | prettytable==3.6.0
15 | pytorch-lightning==1.5.0
16 | safetensors==0.2.7
17 | streamlit==1.12.1
18 | streamlit-drawable-canvas==0.8.0
19 | test-tube==0.7.5
20 | timm==0.6.12
21 | torchmetrics==0.6.0
22 | transformers==4.19.2
23 | webdataset==0.2.5
24 | yapf==0.32.0
25 | 


--------------------------------------------------------------------------------
/rerender.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import os
  3 | import random
  4 | 
  5 | import cv2
  6 | import einops
  7 | import numpy as np
  8 | import torch
  9 | import torch.nn.functional as F
 10 | import torchvision.transforms as T
 11 | from blendmodes.blend import BlendType, blendLayers
 12 | from PIL import Image
 13 | from pytorch_lightning import seed_everything
 14 | from safetensors.torch import load_file
 15 | from skimage import exposure
 16 | 
 17 | import src.import_util  # noqa: F401
 18 | from deps.ControlNet.annotator.canny import CannyDetector
 19 | from deps.ControlNet.annotator.hed import HEDdetector
 20 | from deps.ControlNet.annotator.util import HWC3
 21 | from deps.ControlNet.cldm.cldm import ControlLDM
 22 | from deps.ControlNet.cldm.model import create_model, load_state_dict
 23 | from deps.gmflow.gmflow.gmflow import GMFlow
 24 | from flow.flow_utils import get_warped_and_mask
 25 | from src.config import RerenderConfig
 26 | from src.controller import AttentionControl
 27 | from src.ddim_v_hacked import DDIMVSampler
 28 | from src.freeu import freeu_forward
 29 | from src.img_util import find_flat_region, numpy2tensor
 30 | from src.video_util import frame_to_video, get_fps, prepare_frames
 31 | 
 32 | blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
 33 | totensor = T.PILToTensor()
 34 | 
 35 | 
 36 | def setup_color_correction(image):
 37 |     correction_target = cv2.cvtColor(np.asarray(image.copy()),
 38 |                                      cv2.COLOR_RGB2LAB)
 39 |     return correction_target
 40 | 
 41 | 
 42 | def apply_color_correction(correction, original_image):
 43 |     image = Image.fromarray(
 44 |         cv2.cvtColor(
 45 |             exposure.match_histograms(cv2.cvtColor(np.asarray(original_image),
 46 |                                                    cv2.COLOR_RGB2LAB),
 47 |                                       correction,
 48 |                                       channel_axis=2),
 49 |             cv2.COLOR_LAB2RGB).astype('uint8'))
 50 | 
 51 |     image = blendLayers(image, original_image, BlendType.LUMINOSITY)
 52 | 
 53 |     return image
 54 | 
 55 | 
 56 | def rerender(cfg: RerenderConfig, first_img_only: bool, key_video_path: str):
 57 | 
 58 |     # Preprocess input
 59 |     prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution, cfg.crop, cfg.use_limit_device_resolution)
 60 | 
 61 |     # Load models
 62 |     if cfg.control_type == 'HED':
 63 |         detector = HEDdetector()
 64 |     elif cfg.control_type == 'canny':
 65 |         canny_detector = CannyDetector()
 66 |         low_threshold = cfg.canny_low
 67 |         high_threshold = cfg.canny_high
 68 | 
 69 |         def apply_canny(x):
 70 |             return canny_detector(x, low_threshold, high_threshold)
 71 | 
 72 |         detector = apply_canny
 73 | 
 74 |     model: ControlLDM = create_model(
 75 |         './deps/ControlNet/models/cldm_v15.yaml').cpu()
 76 |     if cfg.control_type == 'HED':
 77 |         model.load_state_dict(
 78 |             load_state_dict('./models/control_sd15_hed.pth', location='cuda'))
 79 |     elif cfg.control_type == 'canny':
 80 |         model.load_state_dict(
 81 |             load_state_dict('./models/control_sd15_canny.pth',
 82 |                             location='cuda'))
 83 |     model = model.cuda()
 84 |     model.control_scales = [cfg.control_strength] * 13
 85 | 
 86 |     if cfg.sd_model is not None:
 87 |         model_ext = os.path.splitext(cfg.sd_model)[1]
 88 |         if model_ext == '.safetensors':
 89 |             model.load_state_dict(load_file(cfg.sd_model), strict=False)
 90 |         elif model_ext == '.ckpt' or model_ext == '.pth':
 91 |             model.load_state_dict(torch.load(cfg.sd_model)['state_dict'],
 92 |                                   strict=False)
 93 | 
 94 |     try:
 95 |         model.first_stage_model.load_state_dict(torch.load(
 96 |             './models/vae-ft-mse-840000-ema-pruned.ckpt')['state_dict'],
 97 |                                                 strict=False)
 98 |     except Exception:
 99 |         print('Warning: We suggest you download the fine-tuned VAE;',
100 |               'otherwise, the generation quality will be degraded.')
101 | 
102 |     model.model.diffusion_model.forward = \
103 |         freeu_forward(model.model.diffusion_model, *cfg.freeu_args)
104 |     ddim_v_sampler = DDIMVSampler(model)
105 | 
106 |     flow_model = GMFlow(
107 |         feature_channels=128,
108 |         num_scales=1,
109 |         upsample_factor=8,
110 |         num_head=1,
111 |         attention_type='swin',
112 |         ffn_dim_expansion=4,
113 |         num_transformer_layers=6,
114 |     ).to('cuda')
115 | 
116 |     checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
117 |                             map_location=lambda storage, loc: storage)
118 |     weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
119 |     flow_model.load_state_dict(weights, strict=False)
120 |     flow_model.eval()
121 | 
122 |     num_samples = 1
123 |     ddim_steps = 20
124 |     scale = 7.5
125 | 
126 |     seed = cfg.seed
127 |     if seed == -1:
128 |         seed = random.randint(0, 65535)
129 |     eta = 0.0
130 | 
131 |     prompt = cfg.prompt
132 |     a_prompt = cfg.a_prompt
133 |     n_prompt = cfg.n_prompt
134 |     prompt = prompt + ', ' + a_prompt
135 | 
136 |     style_update_freq = cfg.style_update_freq
137 |     pixelfusion = True
138 |     color_preserve = cfg.color_preserve
139 | 
140 |     x0_strength = 1 - cfg.x0_strength
141 |     mask_period = cfg.mask_period
142 |     firstx0 = True
143 |     controller = AttentionControl(cfg.inner_strength, cfg.mask_period,
144 |                                   cfg.cross_period, cfg.ada_period,
145 |                                   cfg.warp_period, cfg.loose_cfattn)
146 | 
147 |     imgs = sorted(os.listdir(cfg.input_dir))
148 |     imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
149 |     if cfg.frame_count >= 0:
150 |         imgs = imgs[:cfg.frame_count]
151 | 
152 |     with torch.no_grad():
153 |         frame = cv2.imread(imgs[0])
154 |         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
155 |         img = HWC3(frame)
156 |         H, W, C = img.shape
157 | 
158 |         img_ = numpy2tensor(img)
159 |         # if color_preserve:
160 |         #     img_ = numpy2tensor(img)
161 |         # else:
162 |         #     img_ = apply_color_correction(color_corrections,
163 |         #                                   Image.fromarray(img))
164 |         #     img_ = totensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
165 |         encoder_posterior = model.encode_first_stage(img_.cuda())
166 |         x0 = model.get_first_stage_encoding(encoder_posterior).detach()
167 | 
168 |         detected_map = detector(img)
169 |         detected_map = HWC3(detected_map)
170 |         # For visualization
171 |         detected_img = 255 - detected_map
172 | 
173 |         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
174 |         control = torch.stack([control for _ in range(num_samples)], dim=0)
175 |         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
176 |         cond = {
177 |             'c_concat': [control],
178 |             'c_crossattn':
179 |             [model.get_learned_conditioning([prompt] * num_samples)]
180 |         }
181 |         un_cond = {
182 |             'c_concat': [control],
183 |             'c_crossattn':
184 |             [model.get_learned_conditioning([n_prompt] * num_samples)]
185 |         }
186 |         shape = (4, H // 8, W // 8)
187 | 
188 |         controller.set_task('initfirst')
189 |         seed_everything(seed)
190 |         samples, _ = ddim_v_sampler.sample(ddim_steps,
191 |                                            num_samples,
192 |                                            shape,
193 |                                            cond,
194 |                                            verbose=False,
195 |                                            eta=eta,
196 |                                            unconditional_guidance_scale=scale,
197 |                                            unconditional_conditioning=un_cond,
198 |                                            controller=controller,
199 |                                            x0=x0,
200 |                                            strength=x0_strength)
201 |         x_samples = model.decode_first_stage(samples)
202 |         pre_result = x_samples
203 |         pre_img = img
204 |         first_result = pre_result
205 |         first_img = pre_img
206 | 
207 |         x_samples = (
208 |             einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
209 |             127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
210 |     color_corrections = setup_color_correction(Image.fromarray(x_samples[0]))
211 |     Image.fromarray(x_samples[0]).save(os.path.join(cfg.first_dir,
212 |                                                     'first.jpg'))
213 |     cv2.imwrite(os.path.join(cfg.first_dir, 'first_edge.jpg'), detected_img)
214 | 
215 |     if first_img_only:
216 |         exit(0)
217 | 
218 |     for i in range(0, min(len(imgs), cfg.frame_count) - 1, cfg.interval):
219 |         cid = i + 1
220 |         print(cid)
221 |         if cid <= (len(imgs) - 1):
222 |             frame = cv2.imread(imgs[cid])
223 |         else:
224 |             frame = cv2.imread(imgs[len(imgs) - 1])
225 |         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
226 |         img = HWC3(frame)
227 | 
228 |         if color_preserve:
229 |             img_ = numpy2tensor(img)
230 |         else:
231 |             img_ = apply_color_correction(color_corrections,
232 |                                           Image.fromarray(img))
233 |             img_ = totensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
234 |         encoder_posterior = model.encode_first_stage(img_.cuda())
235 |         x0 = model.get_first_stage_encoding(encoder_posterior).detach()
236 | 
237 |         detected_map = detector(img)
238 |         detected_map = HWC3(detected_map)
239 | 
240 |         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
241 |         control = torch.stack([control for _ in range(num_samples)], dim=0)
242 |         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
243 |         cond['c_concat'] = [control]
244 |         un_cond['c_concat'] = [control]
245 | 
246 |         image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float()
247 |         image2 = torch.from_numpy(img).permute(2, 0, 1).float()
248 |         warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
249 |             flow_model, image1, image2, pre_result, False)
250 |         blend_mask_pre = blur(
251 |             F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
252 |         blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)
253 | 
254 |         image1 = torch.from_numpy(first_img).permute(2, 0, 1).float()
255 |         warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
256 |             flow_model, image1, image2, first_result, False)
257 |         blend_mask_0 = blur(
258 |             F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
259 |         blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
260 | 
261 |         if firstx0:
262 |             mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)
263 |             controller.set_warp(
264 |                 F.interpolate(bwd_flow_0 / 8.0,
265 |                               scale_factor=1. / 8,
266 |                               mode='bilinear'), mask)
267 |         else:
268 |             mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8)
269 |             controller.set_warp(
270 |                 F.interpolate(bwd_flow_pre / 8.0,
271 |                               scale_factor=1. / 8,
272 |                               mode='bilinear'), mask)
273 | 
274 |         controller.set_task('keepx0, keepstyle')
275 |         seed_everything(seed)
276 |         samples, intermediates = ddim_v_sampler.sample(
277 |             ddim_steps,
278 |             num_samples,
279 |             shape,
280 |             cond,
281 |             verbose=False,
282 |             eta=eta,
283 |             unconditional_guidance_scale=scale,
284 |             unconditional_conditioning=un_cond,
285 |             controller=controller,
286 |             x0=x0,
287 |             strength=x0_strength)
288 |         direct_result = model.decode_first_stage(samples)
289 | 
290 |         if not pixelfusion:
291 |             pre_result = direct_result
292 |             pre_img = img
293 |             viz = (
294 |                 einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 +
295 |                 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
296 | 
297 |         else:
298 | 
299 |             blend_results = (1 - blend_mask_pre
300 |                              ) * warped_pre + blend_mask_pre * direct_result
301 |             blend_results = (
302 |                 1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results
303 | 
304 |             bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
305 |             blend_mask = blur(
306 |                 F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
307 |             blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)
308 | 
309 |             encoder_posterior = model.encode_first_stage(blend_results)
310 |             xtrg = model.get_first_stage_encoding(
311 |                 encoder_posterior).detach()  # * mask
312 |             blend_results_rec = model.decode_first_stage(xtrg)
313 |             encoder_posterior = model.encode_first_stage(blend_results_rec)
314 |             xtrg_rec = model.get_first_stage_encoding(
315 |                 encoder_posterior).detach()
316 |             xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec))  # * mask
317 |             blend_results_rec_new = model.decode_first_stage(xtrg_)
318 |             tmp = (abs(blend_results_rec_new - blend_results).mean(
319 |                 dim=1, keepdims=True) > 0.25).float()
320 |             mask_x = F.max_pool2d((F.interpolate(
321 |                 tmp, scale_factor=1 / 8., mode='bilinear') > 0).float(),
322 |                                   kernel_size=3,
323 |                                   stride=1,
324 |                                   padding=1)
325 | 
326 |             mask = (1 - F.max_pool2d(1 - blend_mask, kernel_size=8)
327 |                     )  # * (1-mask_x)
328 | 
329 |             if cfg.smooth_boundary:
330 |                 noise_rescale = find_flat_region(mask)
331 |             else:
332 |                 noise_rescale = torch.ones_like(mask)
333 |             masks = []
334 |             for j in range(ddim_steps):
335 |                 if j <= ddim_steps * mask_period[
336 |                         0] or j >= ddim_steps * mask_period[1]:
337 |                     masks += [None]
338 |                 else:
339 |                     masks += [mask * cfg.mask_strength]
340 | 
341 |             # mask 3
342 |             # xtrg = ((1-mask_x) *
343 |             #         (xtrg + xtrg - xtrg_rec) + mask_x * samples) * mask
344 |             # mask 2
345 |             # xtrg = (xtrg + 1 * (xtrg - xtrg_rec)) * mask
346 |             xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask  # mask 1
347 | 
348 |             tasks = 'keepstyle, keepx0'
349 |             if not firstx0:
350 |                 tasks += ', updatex0'
351 |             if i % style_update_freq == 0:
352 |                 tasks += ', updatestyle'
353 |             controller.set_task(tasks, 1.0)
354 | 
355 |             seed_everything(seed)
356 |             samples, _ = ddim_v_sampler.sample(
357 |                 ddim_steps,
358 |                 num_samples,
359 |                 shape,
360 |                 cond,
361 |                 verbose=False,
362 |                 eta=eta,
363 |                 unconditional_guidance_scale=scale,
364 |                 unconditional_conditioning=un_cond,
365 |                 controller=controller,
366 |                 x0=x0,
367 |                 strength=x0_strength,
368 |                 xtrg=xtrg,
369 |                 mask=masks,
370 |                 noise_rescale=noise_rescale)
371 |             x_samples = model.decode_first_stage(samples)
372 |             pre_result = x_samples
373 |             pre_img = img
374 | 
375 |             viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
376 |                    127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
377 | 
378 |         Image.fromarray(viz[0]).save(
379 |             os.path.join(cfg.key_dir, f'{cid:04d}.png'))
380 |     if key_video_path is not None:
381 |         fps = get_fps(cfg.input_path)
382 |         fps //= cfg.interval
383 |         frame_to_video(key_video_path, cfg.key_dir, fps, False)
384 | 
385 | 
386 | def postprocess(cfg: RerenderConfig, ne: bool, max_process: int, tmp: bool,
387 |                 ps: bool):
388 |     video_base_dir = cfg.work_dir
389 |     o_video = cfg.output_path
390 |     fps = get_fps(cfg.input_path)
391 | 
392 |     end_frame = cfg.frame_count - 1
393 |     interval = cfg.interval
394 |     key_dir = os.path.split(cfg.key_dir)[-1]
395 |     use_e = '-ne' if ne else ''
396 |     use_tmp = '-tmp' if tmp else ''
397 |     use_ps = '-ps' if ps else ''
398 |     o_video_cmd = f'--output {o_video}'
399 | 
400 |     cmd = (
401 |         f'python video_blend.py {video_base_dir} --beg 1 --end {end_frame} '
402 |         f'--itv {interval} --key {key_dir} {use_e} {o_video_cmd} --fps {fps} '
403 |         f'--n_proc {max_process} {use_tmp} {use_ps}')
404 |     print(cmd)
405 |     os.system(cmd)
406 | 
407 | 
408 | if __name__ == '__main__':
409 |     parser = argparse.ArgumentParser()
410 |     parser.add_argument('--cfg', type=str, default=None)
411 |     parser.add_argument('--input',
412 |                         type=str,
413 |                         default=None,
414 |                         help='The input path to video.')
415 |     parser.add_argument('--output', type=str, default=None)
416 |     parser.add_argument('--prompt', type=str, default=None)
417 |     parser.add_argument('--key_video_path', type=str, default=None)
418 |     parser.add_argument('-one',
419 |                         action='store_true',
420 |                         help='Run the first frame with ControlNet only')
421 |     parser.add_argument('-nr',
422 |                         action='store_true',
423 |                         help='Do not run rerender and do postprocessing only')
424 |     parser.add_argument('-nb',
425 |                         action='store_true',
426 |                         help='Do not run postprocessing and run rerender only')
427 |     parser.add_argument(
428 |         '-ne',
429 |         action='store_true',
430 |         help='Do not run ebsynth (use previous ebsynth temporary output)')
431 |     parser.add_argument('-nps',
432 |                         action='store_true',
433 |                         help='Do not run poisson gradient blending')
434 |     parser.add_argument('--n_proc',
435 |                         type=int,
436 |                         default=4,
437 |                         help='The max process count')
438 |     parser.add_argument('--tmp',
439 |                         action='store_true',
440 |                         help='Keep ebsynth temporary output')
441 | 
442 |     args = parser.parse_args()
443 | 
444 |     cfg = RerenderConfig()
445 |     if args.cfg is not None:
446 |         cfg.create_from_path(args.cfg)
447 |         if args.input is not None:
448 |             print('Config has been loaded. --input is ignored.')
449 |         if args.output is not None:
450 |             print('Config has been loaded. --output is ignored.')
451 |         if args.prompt is not None:
452 |             print('Config has been loaded. --prompt is ignored.')
453 |     else:
454 |         if args.input is None:
455 |             print('Config not found. --input is required.')
456 |             exit(0)
457 |         if args.output is None:
458 |             print('Config not found. --output is required.')
459 |             exit(0)
460 |         if args.prompt is None:
461 |             print('Config not found. --prompt is required.')
462 |             exit(0)
463 |         cfg.create_from_parameters(args.input, args.output, args.prompt)
464 | 
465 |     if not args.nr:
466 |         rerender(cfg, args.one, args.key_video_path)
467 |         torch.cuda.empty_cache()
468 |     if not args.nb:
469 |         postprocess(cfg, args.ne, args.n_proc, args.tmp, not args.nps)
470 | 
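
The two stages that python rerender.py --cfg config/real2sculpture.json runs back to back (and that the -nr/-nb flags toggle individually) can also be driven from Python; a rough sketch, assuming install.py has been run and a CUDA GPU is available:

    import torch

    from rerender import postprocess, rerender
    from src.config import RerenderConfig

    cfg = RerenderConfig()
    cfg.create_from_path('config/real2sculpture.json')

    # stage 1: translate the key frames with ControlNet under the cross-frame constraints
    rerender(cfg, first_img_only=False, key_video_path=None)
    torch.cuda.empty_cache()

    # stage 2: propagate the key frames to the full video via video_blend.py (Ebsynth + blending)
    postprocess(cfg, ne=False, max_process=4, tmp=False, ps=True)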


--------------------------------------------------------------------------------
/sd_model_cfg.py:
--------------------------------------------------------------------------------
1 | # The model dict is used for webUI only
2 | 
3 | model_dict = {
4 |     'Stable Diffusion 1.5': '',
5 |     'revAnimated_v11': 'models/revAnimated_v11.safetensors',
6 |     'realisticVisionV20_v20': 'models/realisticVisionV20_v20.safetensors'
7 | }
8 | 
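
Command-line runs pick a personalized checkpoint through the "sd_model" key of the JSON config instead (see config/woman.json); the dict above only tells webUI.py which checkpoints it may offer. A hedged sketch of registering one more entry, assuming the file has been placed under models/ (the model name and file are hypothetical):

    model_dict = {
        'Stable Diffusion 1.5': '',
        'revAnimated_v11': 'models/revAnimated_v11.safetensors',
        'realisticVisionV20_v20': 'models/realisticVisionV20_v20.safetensors',
        'myCustomModel_v1': 'models/myCustomModel_v1.safetensors',  # hypothetical new entry
    }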


--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
  1 | import json
  2 | import os
  3 | from typing import Optional, Sequence, Tuple
  4 | 
  5 | from src.video_util import get_frame_count
  6 | 
  7 | 
  8 | class RerenderConfig:
  9 | 
 10 |     def __init__(self):
 11 |         ...
 12 | 
 13 |     def create_from_parameters(self,
 14 |                                input_path: str,
 15 |                                output_path: str,
 16 |                                prompt: str,
 17 |                                work_dir: Optional[str] = None,
 18 |                                key_subdir: str = 'keys',
 19 |                                frame_count: Optional[int] = None,
 20 |                                interval: int = 10,
 21 |                                crop: Sequence[int] = (0, 0, 0, 0),
 22 |                                sd_model: Optional[str] = None,
 23 |                                a_prompt: str = '',
 24 |                                n_prompt: str = '',
 25 |                                ddim_steps=20,
 26 |                                scale=7.5,
 27 |                                control_type: str = 'HED',
 28 |                                control_strength=1,
 29 |                                seed: int = -1,
 30 |                                image_resolution: int = 512,
 31 |                                use_limit_device_resolution: bool = False,
 32 |                                x0_strength: float = -1,
 33 |                                style_update_freq: int = 10,
 34 |                                cross_period: Tuple[float, float] = (0, 1),
 35 |                                warp_period: Tuple[float, float] = (0, 0.1),
 36 |                                mask_period: Tuple[float, float] = (0.5, 0.8),
 37 |                                ada_period: Tuple[float, float] = (1.0, 1.0),
 38 |                                mask_strength: float = 0.5,
 39 |                                inner_strength: float = 0.9,
 40 |                                smooth_boundary: bool = True,
 41 |                                color_preserve: bool = True,
 42 |                                loose_cfattn: bool = False,
 43 |                                freeu_args: Tuple[float, float, float, float] = (1, 1, 1, 1),
 44 |                                **kwargs):
 45 |         self.input_path = input_path
 46 |         self.output_path = output_path
 47 |         self.prompt = prompt
 48 |         self.work_dir = work_dir
 49 |         if work_dir is None:
 50 |             self.work_dir = os.path.dirname(output_path)
 51 |         self.key_dir = os.path.join(self.work_dir, key_subdir)
 52 |         self.first_dir = os.path.join(self.work_dir, 'first')
 53 | 
 54 |         # Split video into frames
 55 |         if not os.path.isfile(input_path):
 56 |             raise FileNotFoundError(f'Cannot find video file {input_path}')
 57 |         self.input_dir = os.path.join(self.work_dir, 'video')
 58 | 
 59 |         self.frame_count = frame_count
 60 |         if frame_count is None:
 61 |             self.frame_count = get_frame_count(self.input_path)
 62 |         self.interval = interval
 63 |         self.crop = crop
 64 |         self.sd_model = sd_model
 65 |         self.a_prompt = a_prompt
 66 |         self.n_prompt = n_prompt
 67 |         self.ddim_steps = ddim_steps
 68 |         self.scale = scale
 69 |         self.control_type = control_type
 70 |         if self.control_type == 'canny':
 71 |             self.canny_low = kwargs.get('canny_low', 100)
 72 |             self.canny_high = kwargs.get('canny_high', 200)
 73 |         else:
 74 |             self.canny_low = None
 75 |             self.canny_high = None
 76 |         self.control_strength = control_strength
 77 |         self.seed = seed
 78 |         self.image_resolution = image_resolution
 79 |         self.use_limit_device_resolution = use_limit_device_resolution
 80 |         self.x0_strength = x0_strength
 81 |         self.style_update_freq = style_update_freq
 82 |         self.cross_period = cross_period
 83 |         self.mask_period = mask_period
 84 |         self.warp_period = warp_period
 85 |         self.ada_period = ada_period
 86 |         self.mask_strength = mask_strength
 87 |         self.inner_strength = inner_strength
 88 |         self.smooth_boundary = smooth_boundary
 89 |         self.color_preserve = color_preserve
 90 |         self.loose_cfattn = loose_cfattn
 91 |         self.freeu_args = freeu_args
 92 | 
 93 |         os.makedirs(self.input_dir, exist_ok=True)
 94 |         os.makedirs(self.work_dir, exist_ok=True)
 95 |         os.makedirs(self.key_dir, exist_ok=True)
 96 |         os.makedirs(self.first_dir, exist_ok=True)
 97 | 
 98 |     def create_from_path(self, cfg_path: str):
 99 |         with open(cfg_path, 'r') as fp:
100 |             cfg = json.load(fp)
101 |         kwargs = dict()
102 | 
103 |         def append_if_not_none(key):
104 |             value = cfg.get(key, None)
105 |             if value is not None:
106 |                 kwargs[key] = value
107 | 
108 |         kwargs['input_path'] = cfg['input']
109 |         kwargs['output_path'] = cfg['output']
110 |         kwargs['prompt'] = cfg['prompt']
111 |         append_if_not_none('work_dir')
112 |         append_if_not_none('key_subdir')
113 |         append_if_not_none('frame_count')
114 |         append_if_not_none('interval')
115 |         append_if_not_none('crop')
116 |         append_if_not_none('sd_model')
117 |         append_if_not_none('a_prompt')
118 |         append_if_not_none('n_prompt')
119 |         append_if_not_none('ddim_steps')
120 |         append_if_not_none('scale')
121 |         append_if_not_none('control_type')
122 |         if kwargs.get('control_type', '') == 'canny':
123 |             append_if_not_none('canny_low')
124 |             append_if_not_none('canny_high')
125 |         append_if_not_none('control_strength')
126 |         append_if_not_none('seed')
127 |         append_if_not_none('image_resolution')
128 |         append_if_not_none('use_limit_device_resolution')
129 |         append_if_not_none('x0_strength')
130 |         append_if_not_none('style_update_freq')
131 |         append_if_not_none('cross_period')
132 |         append_if_not_none('warp_period')
133 |         append_if_not_none('mask_period')
134 |         append_if_not_none('ada_period')
135 |         append_if_not_none('mask_strength')
136 |         append_if_not_none('inner_strength')
137 |         append_if_not_none('smooth_boundary')
138 |         append_if_not_none('color_preserve')
139 |         append_if_not_none('loose_cfattn')
140 |         append_if_not_none('freeu_args')
141 |         self.create_from_parameters(**kwargs)
142 | 
143 |     @property
144 |     def use_warp(self):
145 |         return self.warp_period[0] <= self.warp_period[1]
146 | 
147 |     @property
148 |     def use_mask(self):
149 |         return self.mask_period[0] <= self.mask_period[1]
150 | 
151 |     @property
152 |     def use_ada(self):
153 |         return self.ada_period[0] <= self.ada_period[1]
154 | 
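
RerenderConfig can also be filled in directly from keyword arguments, which is what rerender.py does when no --cfg is given; a minimal sketch, assuming the sample clip in videos/ exists (create_from_parameters validates input_path and creates the work_dir/keys/first folders as a side effect):

    from src.config import RerenderConfig

    cfg = RerenderConfig()
    cfg.create_from_parameters(
        input_path='videos/pexels-koolshooters-7322716.mp4',
        output_path='videos/pexels-koolshooters-7322716/blend.mp4',
        prompt='a beautiful woman in CG style',
        frame_count=102,
        interval=10,
        control_type='canny',
        x0_strength=0.75,
        canny_low=50,    # read from **kwargs because control_type == 'canny'
        canny_high=100,
    )
    print(cfg.key_dir)                              # videos/pexels-koolshooters-7322716/keys
    print(cfg.use_warp, cfg.use_mask, cfg.use_ada)  # True True True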


--------------------------------------------------------------------------------
/src/controller.py:
--------------------------------------------------------------------------------
  1 | import gc
  2 | 
  3 | import torch
  4 | import torch.nn.functional as F
  5 | 
  6 | from flow.flow_utils import flow_warp
  7 | 
  8 | # AdaIn
  9 | 
 10 | 
 11 | def calc_mean_std(feat, eps=1e-5):
 12 |     # eps is a small value added to the variance to avoid divide-by-zero.
 13 |     size = feat.size()
 14 |     assert (len(size) == 4)
 15 |     N, C = size[:2]
 16 |     feat_var = feat.view(N, C, -1).var(dim=2) + eps
 17 |     feat_std = feat_var.sqrt().view(N, C, 1, 1)
 18 |     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
 19 |     return feat_mean, feat_std
 20 | 
 21 | 
 22 | class AttentionControl():
 23 | 
 24 |     def __init__(self,
 25 |                  inner_strength,
 26 |                  mask_period,
 27 |                  cross_period,
 28 |                  ada_period,
 29 |                  warp_period,
 30 |                  loose_cfattn=False):
 31 |         self.step_store = self.get_empty_store()
 32 |         self.cur_step = 0
 33 |         self.total_step = 0
 34 |         self.cur_index = 0
 35 |         self.init_store = False
 36 |         self.restore = False
 37 |         self.update = False
 38 |         self.flow = None
 39 |         self.mask = None
 40 |         self.restorex0 = False
 41 |         self.updatex0 = False
 42 |         self.inner_strength = inner_strength
 43 |         self.cross_period = cross_period
 44 |         self.mask_period = mask_period
 45 |         self.ada_period = ada_period
 46 |         self.warp_period = warp_period
 47 |         self.up_resolution = 1280 if loose_cfattn else 1281
 48 | 
 49 |     @staticmethod
 50 |     def get_empty_store():
 51 |         return {
 52 |             'first': [],
 53 |             'previous': [],
 54 |             'x0_previous': [],
 55 |             'first_ada': []
 56 |         }
 57 | 
 58 |     def forward(self, context, is_cross: bool, place_in_unet: str):
 59 |         cross_period = (self.total_step * self.cross_period[0],
 60 |                         self.total_step * self.cross_period[1])
 61 |         if not is_cross and place_in_unet == 'up' and context.shape[
 62 |                 2] < self.up_resolution:
 63 |             if self.init_store:
 64 |                 self.step_store['first'].append(context.detach())
 65 |                 self.step_store['previous'].append(context.detach())
 66 |             if self.update:
 67 |                 tmp = context.clone().detach()
 68 |             if self.restore and self.cur_step >= cross_period[0] and \
 69 |                     self.cur_step <= cross_period[1]:
 70 |                 context = torch.cat(
 71 |                     (self.step_store['first'][self.cur_index],
 72 |                      self.step_store['previous'][self.cur_index]),
 73 |                     dim=1).clone()
 74 |             if self.update:
 75 |                 self.step_store['previous'][self.cur_index] = tmp
 76 |             self.cur_index += 1
 77 |         return context
 78 | 
 79 |     def update_x0(self, x0):
 80 |         if self.init_store:
 81 |             self.step_store['x0_previous'].append(x0.detach())
 82 |             style_mean, style_std = calc_mean_std(x0.detach())
 83 |             self.step_store['first_ada'].append(style_mean.detach())
 84 |             self.step_store['first_ada'].append(style_std.detach())
 85 |         if self.updatex0:
 86 |             tmp = x0.clone().detach()
 87 |         if self.restorex0:
 88 |             if self.cur_step >= self.total_step * self.ada_period[
 89 |                     0] and self.cur_step <= self.total_step * self.ada_period[
 90 |                         1]:
 91 |                 x0 = F.instance_norm(x0) * self.step_store['first_ada'][
 92 |                     2 * self.cur_step +
 93 |                     1] + self.step_store['first_ada'][2 * self.cur_step]
 94 |             if self.cur_step >= self.total_step * self.warp_period[
 95 |                     0] and self.cur_step <= self.total_step * self.warp_period[
 96 |                         1]:
 97 |                 pre = self.step_store['x0_previous'][self.cur_step]
 98 |                 x0 = flow_warp(pre, self.flow, mode='nearest') * self.mask + (
 99 |                     1 - self.mask) * x0
100 |         if self.updatex0:
101 |             self.step_store['x0_previous'][self.cur_step] = tmp
102 |         return x0
103 | 
104 |     def set_warp(self, flow, mask):
105 |         self.flow = flow.clone()
106 |         self.mask = mask.clone()
107 | 
108 |     def __call__(self, context, is_cross: bool, place_in_unet: str):
109 |         context = self.forward(context, is_cross, place_in_unet)
110 |         return context
111 | 
112 |     def set_step(self, step):
113 |         self.cur_step = step
114 | 
115 |     def set_total_step(self, total_step):
116 |         self.total_step = total_step
117 |         self.cur_index = 0
118 | 
119 |     def clear_store(self):
120 |         del self.step_store
121 |         torch.cuda.empty_cache()
122 |         gc.collect()
123 |         self.step_store = self.get_empty_store()
124 | 
125 |     def set_task(self, task, restore_step=1.0):
126 |         self.init_store = False
127 |         self.restore = False
128 |         self.update = False
129 |         self.cur_index = 0
130 |         self.restore_step = restore_step
131 |         self.updatex0 = False
132 |         self.restorex0 = False
133 |         if 'initfirst' in task:
134 |             self.init_store = True
135 |             self.clear_store()
136 |         if 'updatestyle' in task:
137 |             self.update = True
138 |         if 'keepstyle' in task:
139 |             self.restore = True
140 |         if 'updatex0' in task:
141 |             self.updatex0 = True
142 |         if 'keepx0' in task:
143 |             self.restorex0 = True
144 | 
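
A rough sketch of how rerender.py drives this controller: 'initfirst' clears and repopulates the feature stores on the first key frame, while later key frames combine 'keepstyle'/'keepx0' (optionally 'updatestyle'/'updatex0') after set_warp() has received the downscaled backward flow and its validity mask. The sampler it is plugged into is expected to advance the step counters through set_step()/set_total_step(); note that importing this module pulls in flow.flow_utils, which loads the GMFlow checkpoint onto the GPU:

    from src.controller import AttentionControl

    controller = AttentionControl(inner_strength=0.9,
                                  mask_period=(0.5, 0.8),
                                  cross_period=(0, 1),
                                  ada_period=(0.8, 1),
                                  warp_period=(0, 0.1))

    # first key frame: store the 'first'/'previous' attention features and x0 statistics
    controller.set_task('initfirst')
    # samples, _ = ddim_v_sampler.sample(..., controller=controller, ...)

    # later key frames: reuse the stored style and warp the stored x0 with optical flow
    # (flow and mask come from get_warped_and_mask, downscaled by 8, as in rerender.py)
    # controller.set_warp(flow, mask)
    controller.set_task('keepx0, keepstyle')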


--------------------------------------------------------------------------------
/src/ddim_v_hacked.py:
--------------------------------------------------------------------------------
  1 | """SAMPLING ONLY."""
  2 | 
  3 | # CrossAttn precision handling
  4 | import os
  5 | 
  6 | import einops
  7 | import numpy as np
  8 | import torch
  9 | from tqdm import tqdm
 10 | 
 11 | from deps.ControlNet.ldm.modules.diffusionmodules.util import (
 12 |     extract_into_tensor, make_ddim_sampling_parameters, make_ddim_timesteps,
 13 |     noise_like)
 14 | 
 15 | _ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
 16 | 
 17 | 
 18 | def register_attention_control(model, controller=None):
 19 | 
 20 |     def ca_forward(self, place_in_unet):
 21 | 
 22 |         def forward(x, context=None, mask=None):
 23 |             h = self.heads
 24 | 
 25 |             q = self.to_q(x)
 26 |             is_cross = context is not None
 27 |             context = context if is_cross else x
 28 |             context = controller(context, is_cross, place_in_unet)
 29 | 
 30 |             k = self.to_k(context)
 31 |             v = self.to_v(context)
 32 | 
 33 |             q, k, v = map(
 34 |                 lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h),
 35 |                 (q, k, v))
 36 | 
 37 |             # force cast to fp32 to avoid overflowing
 38 |             if _ATTN_PRECISION == 'fp32':
 39 |                 with torch.autocast(enabled=False, device_type='cuda'):
 40 |                     q, k = q.float(), k.float()
 41 |                     sim = torch.einsum('b i d, b j d -> b i j', q,
 42 |                                        k) * self.scale
 43 |             else:
 44 |                 sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
 45 | 
 46 |             del q, k
 47 | 
 48 |             if mask is not None:
 49 |                 mask = einops.rearrange(mask, 'b ... -> b (...)')
 50 |                 max_neg_value = -torch.finfo(sim.dtype).max
 51 |                 mask = einops.repeat(mask, 'b j -> (b h) () j', h=h)
 52 |                 sim.masked_fill_(~mask, max_neg_value)
 53 | 
 54 |             # attention, what we cannot get enough of
 55 |             sim = sim.softmax(dim=-1)
 56 | 
 57 |             out = torch.einsum('b i j, b j d -> b i d', sim, v)
 58 |             out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=h)
 59 |             return self.to_out(out)
 60 | 
 61 |         return forward
 62 | 
 63 |     class DummyController:
 64 | 
 65 |         def __call__(self, *args):
 66 |             return args[0]
 67 | 
 68 |         def __init__(self):
 69 |             self.cur_step = 0
 70 | 
 71 |     if controller is None:
 72 |         controller = DummyController()
 73 | 
 74 |     def register_recr(net_, place_in_unet):
 75 |         if net_.__class__.__name__ == 'CrossAttention':
 76 |             net_.forward = ca_forward(net_, place_in_unet)
 77 |         elif hasattr(net_, 'children'):
 78 |             for net__ in net_.children():
 79 |                 register_recr(net__, place_in_unet)
 80 | 
 81 |     sub_nets = model.named_children()
 82 |     for net in sub_nets:
 83 |         if 'input_blocks' in net[0]:
 84 |             register_recr(net[1], 'down')
 85 |         elif 'output_blocks' in net[0]:
 86 |             register_recr(net[1], 'up')
 87 |         elif 'middle_block' in net[0]:
 88 |             register_recr(net[1], 'mid')
 89 | 
 90 | 
 91 | class DDIMVSampler(object):
 92 | 
 93 |     def __init__(self, model, schedule='linear', **kwargs):
 94 |         super().__init__()
 95 |         self.model = model
 96 |         self.ddpm_num_timesteps = model.num_timesteps
 97 |         self.schedule = schedule
 98 | 
 99 |     def register_buffer(self, name, attr):
100 |         if isinstance(attr, torch.Tensor):
101 |             if attr.device != torch.device('cuda'):
102 |                 attr = attr.to(torch.device('cuda'))
103 |         setattr(self, name, attr)
104 | 
105 |     def make_schedule(self,
106 |                       ddim_num_steps,
107 |                       ddim_discretize='uniform',
108 |                       ddim_eta=0.,
109 |                       verbose=True):
110 |         self.ddim_timesteps = make_ddim_timesteps(
111 |             ddim_discr_method=ddim_discretize,
112 |             num_ddim_timesteps=ddim_num_steps,
113 |             num_ddpm_timesteps=self.ddpm_num_timesteps,
114 |             verbose=verbose)
115 |         alphas_cumprod = self.model.alphas_cumprod
116 |         assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, \
117 |             'alphas have to be defined for each timestep'
118 | 
119 |         def to_torch(x):
120 |             return x.clone().detach().to(torch.float32).to(self.model.device)
121 | 
122 |         self.register_buffer('betas', to_torch(self.model.betas))
123 |         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
124 |         self.register_buffer('alphas_cumprod_prev',
125 |                              to_torch(self.model.alphas_cumprod_prev))
126 | 
127 |         # calculations for diffusion q(x_t | x_{t-1}) and others
128 |         self.register_buffer('sqrt_alphas_cumprod',
129 |                              to_torch(np.sqrt(alphas_cumprod.cpu())))
130 |         self.register_buffer('sqrt_one_minus_alphas_cumprod',
131 |                              to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
132 |         self.register_buffer('log_one_minus_alphas_cumprod',
133 |                              to_torch(np.log(1. - alphas_cumprod.cpu())))
134 |         self.register_buffer('sqrt_recip_alphas_cumprod',
135 |                              to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
136 |         self.register_buffer('sqrt_recipm1_alphas_cumprod',
137 |                              to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
138 | 
139 |         # ddim sampling parameters
140 |         ddim_sigmas, ddim_alphas, ddim_alphas_prev = \
141 |             make_ddim_sampling_parameters(
142 |                 alphacums=alphas_cumprod.cpu(),
143 |                 ddim_timesteps=self.ddim_timesteps,
144 |                 eta=ddim_eta,
145 |                 verbose=verbose)
146 |         self.register_buffer('ddim_sigmas', ddim_sigmas)
147 |         self.register_buffer('ddim_alphas', ddim_alphas)
148 |         self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
149 |         self.register_buffer('ddim_sqrt_one_minus_alphas',
150 |                              np.sqrt(1. - ddim_alphas))
151 |         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
152 |             (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
153 |             (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
154 |         self.register_buffer('ddim_sigmas_for_original_num_steps',
155 |                              sigmas_for_original_sampling_steps)
156 | 
157 |     @torch.no_grad()
158 |     def sample(self,
159 |                S,
160 |                batch_size,
161 |                shape,
162 |                conditioning=None,
163 |                callback=None,
164 |                img_callback=None,
165 |                quantize_x0=False,
166 |                eta=0.,
167 |                mask=None,
168 |                x0=None,
169 |                xtrg=None,
170 |                noise_rescale=None,
171 |                temperature=1.,
172 |                noise_dropout=0.,
173 |                score_corrector=None,
174 |                corrector_kwargs=None,
175 |                verbose=True,
176 |                x_T=None,
177 |                log_every_t=100,
178 |                unconditional_guidance_scale=1.,
179 |                unconditional_conditioning=None,
180 |                dynamic_threshold=None,
181 |                ucg_schedule=None,
182 |                controller=None,
183 |                strength=0.0,
184 |                **kwargs):
185 |         if conditioning is not None:
186 |             if isinstance(conditioning, dict):
187 |                 ctmp = conditioning[list(conditioning.keys())[0]]
188 |                 while isinstance(ctmp, list):
189 |                     ctmp = ctmp[0]
190 |                 cbs = ctmp.shape[0]
191 |                 if cbs != batch_size:
192 |                     print(f'Warning: Got {cbs} conditionings '
193 |                           f'but batch-size is {batch_size}')
194 | 
195 |             elif isinstance(conditioning, list):
196 |                 for ctmp in conditioning:
197 |                     if ctmp.shape[0] != batch_size:
198 |                         print(f'Warning: Got {ctmp.shape[0]} '
199 |                               f'conditionings but batch-size is {batch_size}')
200 | 
201 |             else:
202 |                 if conditioning.shape[0] != batch_size:
203 |                     print(f'Warning: Got {conditioning.shape[0]} '
204 |                           f'conditionings but batch-size is {batch_size}')
205 | 
206 |         self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
207 |         # sampling
208 |         C, H, W = shape
209 |         size = (batch_size, C, H, W)
210 |         print(f'Data shape for DDIM sampling is {size}, eta {eta}')
211 | 
212 |         samples, intermediates = self.ddim_sampling(
213 |             conditioning,
214 |             size,
215 |             callback=callback,
216 |             img_callback=img_callback,
217 |             quantize_denoised=quantize_x0,
218 |             mask=mask,
219 |             x0=x0,
220 |             xtrg=xtrg,
221 |             noise_rescale=noise_rescale,
222 |             ddim_use_original_steps=False,
223 |             noise_dropout=noise_dropout,
224 |             temperature=temperature,
225 |             score_corrector=score_corrector,
226 |             corrector_kwargs=corrector_kwargs,
227 |             x_T=x_T,
228 |             log_every_t=log_every_t,
229 |             unconditional_guidance_scale=unconditional_guidance_scale,
230 |             unconditional_conditioning=unconditional_conditioning,
231 |             dynamic_threshold=dynamic_threshold,
232 |             ucg_schedule=ucg_schedule,
233 |             controller=controller,
234 |             strength=strength,
235 |         )
236 |         return samples, intermediates
237 | 
238 |     @torch.no_grad()
239 |     def ddim_sampling(self,
240 |                       cond,
241 |                       shape,
242 |                       x_T=None,
243 |                       ddim_use_original_steps=False,
244 |                       callback=None,
245 |                       timesteps=None,
246 |                       quantize_denoised=False,
247 |                       mask=None,
248 |                       x0=None,
249 |                       xtrg=None,
250 |                       noise_rescale=None,
251 |                       img_callback=None,
252 |                       log_every_t=100,
253 |                       temperature=1.,
254 |                       noise_dropout=0.,
255 |                       score_corrector=None,
256 |                       corrector_kwargs=None,
257 |                       unconditional_guidance_scale=1.,
258 |                       unconditional_conditioning=None,
259 |                       dynamic_threshold=None,
260 |                       ucg_schedule=None,
261 |                       controller=None,
262 |                       strength=0.0):
263 | 
264 |         if strength == 1 and x0 is not None:
265 |             return x0, None
266 | 
267 |         register_attention_control(self.model.model.diffusion_model,
268 |                                    controller)
269 | 
270 |         device = self.model.betas.device
271 |         b = shape[0]
272 |         if x_T is None:
273 |             img = torch.randn(shape, device=device)
274 |         else:
275 |             img = x_T
276 | 
277 |         if timesteps is None:
278 |             timesteps = self.ddpm_num_timesteps if ddim_use_original_steps \
279 |                 else self.ddim_timesteps
280 |         elif timesteps is not None and not ddim_use_original_steps:
281 |             subset_end = int(
282 |                 min(timesteps / self.ddim_timesteps.shape[0], 1) *
283 |                 self.ddim_timesteps.shape[0]) - 1
284 |             timesteps = self.ddim_timesteps[:subset_end]
285 | 
286 |         intermediates = {'x_inter': [img], 'pred_x0': [img]}
287 |         time_range = reversed(range(
288 |             0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
289 |         total_steps = timesteps if ddim_use_original_steps \
290 |             else timesteps.shape[0]
291 |         print(f'Running DDIM Sampling with {total_steps} timesteps')
292 | 
293 |         iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
294 |         if controller is not None:
295 |             controller.set_total_step(total_steps)
296 |         if mask is None:
297 |             mask = [None] * total_steps
298 | 
299 |         dir_xt = 0
300 |         for i, step in enumerate(iterator):
301 |             if controller is not None:
302 |                 controller.set_step(i)
303 |             index = total_steps - i - 1
304 |             ts = torch.full((b, ), step, device=device, dtype=torch.long)
305 | 
306 |             if strength >= 0 and i == int(
307 |                     total_steps * strength) and x0 is not None:
308 |                 img = self.model.q_sample(x0, ts)
309 |             if mask is not None and xtrg is not None:
310 |                 # TODO: deterministic forward pass?
311 |                 if type(mask) == list:
312 |                     weight = mask[i]
313 |                 else:
314 |                     weight = mask
315 |                 if weight is not None:
316 |                     rescale = torch.maximum(1. - weight, (1 - weight**2)**0.5 *
317 |                                             controller.inner_strength)
318 |                     if noise_rescale is not None:
319 |                         rescale = (1. - weight) * (
320 |                             1 - noise_rescale) + rescale * noise_rescale
321 |                     img_ref = self.model.q_sample(xtrg, ts)
322 |                     img = img_ref * weight + (1. - weight) * (
323 |                         img - dir_xt) + rescale * dir_xt
324 | 
325 |             if ucg_schedule is not None:
326 |                 assert len(ucg_schedule) == len(time_range)
327 |                 unconditional_guidance_scale = ucg_schedule[i]
328 | 
329 |             outs = self.p_sample_ddim(
330 |                 img,
331 |                 cond,
332 |                 ts,
333 |                 index=index,
334 |                 use_original_steps=ddim_use_original_steps,
335 |                 quantize_denoised=quantize_denoised,
336 |                 temperature=temperature,
337 |                 noise_dropout=noise_dropout,
338 |                 score_corrector=score_corrector,
339 |                 corrector_kwargs=corrector_kwargs,
340 |                 unconditional_guidance_scale=unconditional_guidance_scale,
341 |                 unconditional_conditioning=unconditional_conditioning,
342 |                 dynamic_threshold=dynamic_threshold,
343 |                 controller=controller,
344 |                 return_dir=True)
345 |             img, pred_x0, dir_xt = outs
346 |             if callback:
347 |                 callback(i)
348 |             if img_callback:
349 |                 img_callback(pred_x0, i)
350 | 
351 |             if index % log_every_t == 0 or index == total_steps - 1:
352 |                 intermediates['x_inter'].append(img)
353 |                 intermediates['pred_x0'].append(pred_x0)
354 | 
355 |         return img, intermediates
356 | 
357 |     @torch.no_grad()
358 |     def p_sample_ddim(self,
359 |                       x,
360 |                       c,
361 |                       t,
362 |                       index,
363 |                       repeat_noise=False,
364 |                       use_original_steps=False,
365 |                       quantize_denoised=False,
366 |                       temperature=1.,
367 |                       noise_dropout=0.,
368 |                       score_corrector=None,
369 |                       corrector_kwargs=None,
370 |                       unconditional_guidance_scale=1.,
371 |                       unconditional_conditioning=None,
372 |                       dynamic_threshold=None,
373 |                       controller=None,
374 |                       return_dir=False):
375 |         b, *_, device = *x.shape, x.device
376 | 
377 |         if unconditional_conditioning is None or \
378 |                 unconditional_guidance_scale == 1.:
379 |             model_output = self.model.apply_model(x, t, c)
380 |         else:
381 |             model_t = self.model.apply_model(x, t, c)
382 |             model_uncond = self.model.apply_model(x, t,
383 |                                                   unconditional_conditioning)
384 |             model_output = model_uncond + unconditional_guidance_scale * (
385 |                 model_t - model_uncond)
386 | 
387 |         if self.model.parameterization == 'v':
388 |             e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
389 |         else:
390 |             e_t = model_output
391 | 
392 |         if score_corrector is not None:
393 |             assert self.model.parameterization == 'eps', 'not implemented'
394 |             e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
395 |                                                **corrector_kwargs)
396 | 
397 |         if use_original_steps:
398 |             alphas = self.model.alphas_cumprod
399 |             alphas_prev = self.model.alphas_cumprod_prev
400 |             sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod
401 |             sigmas = self.model.ddim_sigmas_for_original_num_steps
402 |         else:
403 |             alphas = self.ddim_alphas
404 |             alphas_prev = self.ddim_alphas_prev
405 |             sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
406 |             sigmas = self.ddim_sigmas
407 | 
408 |         # select parameters corresponding to the currently considered timestep
409 |         a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
410 |         a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
411 |         sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
412 |         sqrt_one_minus_at = torch.full((b, 1, 1, 1),
413 |                                        sqrt_one_minus_alphas[index],
414 |                                        device=device)
415 | 
416 |         # current prediction for x_0
417 |         if self.model.parameterization != 'v':
418 |             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
419 |         else:
420 |             pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
421 | 
422 |         if quantize_denoised:
423 |             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
424 | 
425 |         if dynamic_threshold is not None:
426 |             raise NotImplementedError()
427 |         '''
428 |         if mask is not None and xtrg is not None:
429 |             pred_x0 = xtrg * mask + (1. - mask) * pred_x0
430 |         '''
431 | 
432 |         if controller is not None:
433 |             pred_x0 = controller.update_x0(pred_x0)
434 | 
435 |         # direction pointing to x_t
436 |         dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
437 |         noise = sigma_t * noise_like(x.shape, device,
438 |                                      repeat_noise) * temperature
439 |         if noise_dropout > 0.:
440 |             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
441 |         x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
442 | 
443 |         if return_dir:
444 |             return x_prev, pred_x0, dir_xt
445 |         return x_prev, pred_x0
446 | 
447 |     @torch.no_grad()
448 |     def encode(self,
449 |                x0,
450 |                c,
451 |                t_enc,
452 |                use_original_steps=False,
453 |                return_intermediates=None,
454 |                unconditional_guidance_scale=1.0,
455 |                unconditional_conditioning=None,
456 |                callback=None):
457 |         timesteps = np.arange(self.ddpm_num_timesteps
458 |                               ) if use_original_steps else self.ddim_timesteps
459 |         num_reference_steps = timesteps.shape[0]
460 | 
461 |         assert t_enc <= num_reference_steps
462 |         num_steps = t_enc
463 | 
464 |         if use_original_steps:
465 |             alphas_next = self.alphas_cumprod[:num_steps]
466 |             alphas = self.alphas_cumprod_prev[:num_steps]
467 |         else:
468 |             alphas_next = self.ddim_alphas[:num_steps]
469 |             alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
470 | 
471 |         x_next = x0
472 |         intermediates = []
473 |         inter_steps = []
474 |         for i in tqdm(range(num_steps), desc='Encoding Image'):
475 |             t = torch.full((x0.shape[0], ),
476 |                            timesteps[i],
477 |                            device=self.model.device,
478 |                            dtype=torch.long)
479 |             if unconditional_guidance_scale == 1.:
480 |                 noise_pred = self.model.apply_model(x_next, t, c)
481 |             else:
482 |                 assert unconditional_conditioning is not None
483 |                 e_t_uncond, noise_pred = torch.chunk(
484 |                     self.model.apply_model(
485 |                         torch.cat((x_next, x_next)), torch.cat((t, t)),
486 |                         torch.cat((unconditional_conditioning, c))), 2)
487 |                 noise_pred = e_t_uncond + unconditional_guidance_scale * (
488 |                     noise_pred - e_t_uncond)
489 |             xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
490 |             weighted_noise_pred = alphas_next[i].sqrt() * (
491 |                 (1 / alphas_next[i] - 1).sqrt() -
492 |                 (1 / alphas[i] - 1).sqrt()) * noise_pred
493 |             x_next = xt_weighted + weighted_noise_pred
494 |             if return_intermediates and i % (num_steps // return_intermediates
495 |                                              ) == 0 and i < num_steps - 1:
496 |                 intermediates.append(x_next)
497 |                 inter_steps.append(i)
498 |             elif return_intermediates and i >= num_steps - 2:
499 |                 intermediates.append(x_next)
500 |                 inter_steps.append(i)
501 |             if callback:
502 |                 callback(i)
503 | 
504 |         out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
505 |         if return_intermediates:
506 |             out.update({'intermediates': intermediates})
507 |         return x_next, out
508 | 
509 |     @torch.no_grad()
510 |     def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
511 |         # fast, but does not allow for exact reconstruction
512 |         # t serves as an index to gather the correct alphas
513 |         if use_original_steps:
514 |             sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
515 |             sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
516 |         else:
517 |             sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
518 |             sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
519 | 
520 |         if noise is None:
521 |             noise = torch.randn_like(x0)
522 |         if t >= len(sqrt_alphas_cumprod):
523 |             return noise
524 |         return (
525 |             extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
526 |             extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) *
527 |             noise)
528 | 
529 |     @torch.no_grad()
530 |     def decode(self,
531 |                x_latent,
532 |                cond,
533 |                t_start,
534 |                unconditional_guidance_scale=1.0,
535 |                unconditional_conditioning=None,
536 |                use_original_steps=False,
537 |                callback=None):
538 | 
539 |         timesteps = np.arange(self.ddpm_num_timesteps
540 |                               ) if use_original_steps else self.ddim_timesteps
541 |         timesteps = timesteps[:t_start]
542 | 
543 |         time_range = np.flip(timesteps)
544 |         total_steps = timesteps.shape[0]
545 |         print(f'Running DDIM Sampling with {total_steps} timesteps')
546 | 
547 |         iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
548 |         x_dec = x_latent
549 |         for i, step in enumerate(iterator):
550 |             index = total_steps - i - 1
551 |             ts = torch.full((x_latent.shape[0], ),
552 |                             step,
553 |                             device=x_latent.device,
554 |                             dtype=torch.long)
555 |             x_dec, _ = self.p_sample_ddim(
556 |                 x_dec,
557 |                 cond,
558 |                 ts,
559 |                 index=index,
560 |                 use_original_steps=use_original_steps,
561 |                 unconditional_guidance_scale=unconditional_guidance_scale,
562 |                 unconditional_conditioning=unconditional_conditioning)
563 |             if callback:
564 |                 callback(i)
565 |         return x_dec
566 | 
567 | 
568 | def calc_mean_std(feat, eps=1e-5):
569 |     # eps is a small value added to the variance to avoid divide-by-zero.
570 |     size = feat.size()
571 |     assert (len(size) == 4)
572 |     N, C = size[:2]
573 |     feat_var = feat.view(N, C, -1).var(dim=2) + eps
574 |     feat_std = feat_var.sqrt().view(N, C, 1, 1)
575 |     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
576 |     return feat_mean, feat_std
577 | 
578 | 
579 | def adaptive_instance_normalization(content_feat, style_feat):
580 |     assert (content_feat.size()[:2] == style_feat.size()[:2])
581 |     size = content_feat.size()
582 |     style_mean, style_std = calc_mean_std(style_feat)
583 |     content_mean, content_std = calc_mean_std(content_feat)
584 | 
585 |     normalized_feat = (content_feat -
586 |                        content_mean.expand(size)) / content_std.expand(size)
587 |     return normalized_feat * style_std.expand(size) + style_mean.expand(size)
588 | 
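
The AdaIN helpers above are self-contained, so a minimal behavioral sketch may help; it assumes the repository (with its deps/ submodules) is importable so that src.ddim_v_hacked can be loaded, and it uses random tensors as stand-ins for real UNet feature maps.

# Illustrative sketch only; random tensors replace real UNet feature maps.
import torch

from src.ddim_v_hacked import adaptive_instance_normalization, calc_mean_std

content = torch.randn(1, 320, 64, 64)            # features of the current frame
style = torch.randn(1, 320, 64, 64) * 2.0 + 1.0  # features of the reference frame

out = adaptive_instance_normalization(content, style)

# The result keeps the content's spatial structure but adopts the style's
# per-channel mean and standard deviation.
style_mean, style_std = calc_mean_std(style)
out_mean, out_std = calc_mean_std(out)
print(torch.allclose(out_mean, style_mean, atol=1e-4))  # True
print(torch.allclose(out_std, style_std, atol=1e-3))    # True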


--------------------------------------------------------------------------------
/src/freeu.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.fft as fft
 3 | 
 4 | 
 5 | def Fourier_filter(x, threshold, scale):
 6 | 
 7 |     x_freq = fft.fftn(x, dim=(-2, -1))
 8 |     x_freq = fft.fftshift(x_freq, dim=(-2, -1))
 9 | 
10 |     B, C, H, W = x_freq.shape
11 |     mask = torch.ones((B, C, H, W)).cuda()
12 | 
13 |     crow, ccol = H // 2, W // 2
14 |     mask[..., crow - threshold:crow + threshold,
15 |          ccol - threshold:ccol + threshold] = scale
16 |     x_freq = x_freq * mask
17 | 
18 |     x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
19 | 
20 |     x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
21 | 
22 |     return x_filtered
23 | 
24 | from deps.ControlNet.ldm.modules.diffusionmodules.util import \
25 |     timestep_embedding  # noqa:E501
26 | 
27 | 
28 | # backbone_scale1=1.1, backbone_scale2=1.2, skip_scale1=1.0, skip_scale2=0.2
29 | def freeu_forward(self,
30 |                   backbone_scale1=1.,
31 |                   backbone_scale2=1.,
32 |                   skip_scale1=1.,
33 |                   skip_scale2=1.):
34 | 
35 |     def forward(x,
36 |                 timesteps=None,
37 |                 context=None,
38 |                 control=None,
39 |                 only_mid_control=False,
40 |                 **kwargs):
41 |         hs = []
42 |         with torch.no_grad():
43 |             t_emb = timestep_embedding(timesteps,
44 |                                        self.model_channels,
45 |                                        repeat_only=False)
46 |             emb = self.time_embed(t_emb)
47 |             h = x.type(self.dtype)
48 |             for module in self.input_blocks:
49 |                 h = module(h, emb, context)
50 |                 hs.append(h)
51 |             h = self.middle_block(h, emb, context)
52 | 
53 |         if control is not None:
54 |             h += control.pop()
55 |         '''
56 |         for i, module in enumerate(self.output_blocks):
57 |             if only_mid_control or control is None:
58 |                 h = torch.cat([h, hs.pop()], dim=1)
59 |             else:
60 |                 h = torch.cat([h, hs.pop() + control.pop()], dim=1)
61 |             h = module(h, emb, context)
62 |         '''
63 |         for i, module in enumerate(self.output_blocks):
64 |             hs_ = hs.pop()
65 | 
66 |             if h.shape[1] == 1280:
67 |                 hidden_mean = h.mean(1).unsqueeze(1)
68 |                 B = hidden_mean.shape[0]
69 |                 hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
70 |                 hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
71 |                 hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
72 |                 h[:, :640] = h[:, :640] * ((backbone_scale1 - 1) * hidden_mean + 1)        
73 |                 # h[:, :640] = h[:, :640] * backbone_scale1
74 |                 hs_ = Fourier_filter(hs_, threshold=1, scale=skip_scale1)
75 |             if h.shape[1] == 640:
76 |                 hidden_mean = h.mean(1).unsqueeze(1)
77 |                 B = hidden_mean.shape[0]
78 |                 hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
79 |                 hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
80 |                 hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) 
81 |                 h[:, :320] = h[:, :320] * ((backbone_scale2 - 1) * hidden_mean + 1)
82 |                 # h[:, :320] = h[:, :320] * backbone_scale2
83 |                 hs_ = Fourier_filter(hs_, threshold=1, scale=skip_scale2)
84 | 
85 |             if only_mid_control or control is None:
86 |                 h = torch.cat([h, hs_], dim=1)
87 |             else:
88 |                 h = torch.cat([h, hs_ + control.pop()], dim=1)
89 |             h = module(h, emb, context)
90 | 
91 |         h = h.type(x.dtype)
92 |         return self.out(h)
93 | 
94 |     return forward
95 | 
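
A short, hedged sanity check of Fourier_filter follows; it needs a CUDA device because the frequency mask is allocated with .cuda(), and the tensor shape is arbitrary.

# Illustrative sketch only; requires a CUDA device.
import torch

from src.freeu import Fourier_filter

h = torch.randn(1, 640, 32, 32, device='cuda')

# scale=1.0 is a no-op up to FFT round-off; scale=0.2 damps the
# low-frequency band around the spectrum center (threshold=1).
identity = Fourier_filter(h, threshold=1, scale=1.0)
damped = Fourier_filter(h, threshold=1, scale=0.2)

print(torch.allclose(identity, h, atol=1e-4))  # expected: True
print((damped - h).abs().mean().item() > 0)    # expected: True

In webUI.py, this module is enabled by assigning freeu_forward(model.model.diffusion_model, *freeu_args) to model.model.diffusion_model.forward, which is where these filters actually run.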


--------------------------------------------------------------------------------
/src/img_util.py:
--------------------------------------------------------------------------------
 1 | import einops
 2 | import torch
 3 | import torch.nn.functional as F
 4 | 
 5 | 
 6 | @torch.no_grad()
 7 | def find_flat_region(mask):
 8 |     device = mask.device
 9 |     kernel_x = torch.Tensor([[-1, 0, 1], [-1, 0, 1],
10 |                              [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).to(device)
11 |     kernel_y = torch.Tensor([[-1, -1, -1], [0, 0, 0],
12 |                              [1, 1, 1]]).unsqueeze(0).unsqueeze(0).to(device)
13 |     mask_ = F.pad(mask.unsqueeze(0), (1, 1, 1, 1), mode='replicate')
14 | 
15 |     grad_x = torch.nn.functional.conv2d(mask_, kernel_x)
16 |     grad_y = torch.nn.functional.conv2d(mask_, kernel_y)
17 |     return ((abs(grad_x) + abs(grad_y)) == 0).float()[0]
18 | 
19 | 
20 | def numpy2tensor(img):
21 |     x0 = torch.from_numpy(img.copy()).float().cuda() / 255.0 * 2.0 - 1.
22 |     x0 = torch.stack([x0], dim=0)
23 |     return einops.rearrange(x0, 'b h w c -> b c h w').clone()
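
A small, hedged usage sketch for these helpers; it assumes a CUDA device (numpy2tensor moves data to the GPU) and substitutes a synthetic frame and mask for real video data.

# Illustrative sketch only; assumes a CUDA device.
import numpy as np
import torch

from src.img_util import find_flat_region, numpy2tensor

# An HxWx3 uint8 frame becomes a 1x3xHxW float tensor in [-1, 1].
frame = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
x0 = numpy2tensor(frame)
print(x0.shape)  # torch.Size([1, 3, 64, 64])

# find_flat_region marks pixels whose 3x3 neighborhood is constant,
# i.e. the interior of flat regions of a (1, H, W) mask.
mask = torch.zeros(1, 64, 64, device='cuda')
mask[:, 16:48, 16:48] = 1.0
flat = find_flat_region(mask)
print(flat.shape, int(flat.sum().item()))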
24 | 


--------------------------------------------------------------------------------
/src/import_util.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | 
 4 | cur_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 5 | gmflow_dir = os.path.join(cur_dir, 'deps/gmflow')
 6 | controlnet_dir = os.path.join(cur_dir, 'deps/ControlNet')
 7 | sys.path.insert(0, gmflow_dir)
 8 | sys.path.insert(0, controlnet_dir)
 9 | 
10 | import deps.ControlNet.share  # noqa: F401 E402
11 | 
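
This module exists purely for its side effect of putting deps/gmflow and deps/ControlNet on sys.path, so the only usage pattern is to import it before anything under deps/, as webUI.py does. A minimal, hedged sketch:

# Illustrative sketch only: the first import is needed purely for its
# sys.path side effect, so it must precede any deps.* import.
import src.import_util  # noqa: F401

from deps.ControlNet.annotator.util import HWC3  # now resolvable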


--------------------------------------------------------------------------------
/src/video_util.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | 
  3 | import cv2
  4 | import torch
  5 | import imageio
  6 | import numpy as np
  7 | 
  8 | 
  9 | def video_to_frame(video_path: str,
 10 |                    frame_dir: str,
 11 |                    filename_pattern: str = 'frame%03d.jpg',
 12 |                    log: bool = True,
 13 |                    frame_edit_func=None):
 14 |     os.makedirs(frame_dir, exist_ok=True)
 15 | 
 16 |     vidcap = cv2.VideoCapture(video_path)
 17 |     success, image = vidcap.read()
 18 | 
 19 |     if log and success:
 20 |         print('img shape: ', image.shape[0:2])
 21 | 
 22 |     count = 0
 23 |     while success:
 24 |         if frame_edit_func is not None:
 25 |             image = frame_edit_func(image)
 26 | 
 27 |         cv2.imwrite(os.path.join(frame_dir, filename_pattern % count), image)
 28 |         success, image = vidcap.read()
 29 |         if log:
 30 |             print('Read a new frame: ', success, count)
 31 |         count += 1
 32 | 
 33 |     vidcap.release()
 34 | 
 35 | 
 36 | def frame_to_video(video_path: str, frame_dir: str, fps=30, log=True):
 37 | 
 38 |     first_img = True
 39 |     writer = imageio.get_writer(video_path, fps=fps)
 40 | 
 41 |     file_list = sorted(os.listdir(frame_dir))
 42 |     for file_name in file_list:
 43 |         if not (file_name.endswith('jpg') or file_name.endswith('png')):
 44 |             continue
 45 | 
 46 |         fn = os.path.join(frame_dir, file_name)
 47 |         curImg = imageio.imread(fn)
 48 | 
 49 |         if first_img:
 50 |             H, W = curImg.shape[0:2]
 51 |             if log:
 52 |                 print('img shape', (H, W))
 53 |             first_img = False
 54 | 
 55 |         writer.append_data(curImg)
 56 | 
 57 |     writer.close()
 58 | 
 59 | 
 60 | def get_fps(video_path: str):
 61 |     video = cv2.VideoCapture(video_path)
 62 |     fps = video.get(cv2.CAP_PROP_FPS)
 63 |     video.release()
 64 |     return fps
 65 | 
 66 | 
 67 | def get_frame_count(video_path: str):
 68 |     video = cv2.VideoCapture(video_path)
 69 |     frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
 70 |     video.release()
 71 |     return frame_count
 72 | 
 73 | 
 74 | def resize_image(input_image, resolution):
 75 |     H, W, C = input_image.shape
 76 |     H = float(H)
 77 |     W = float(W)
 78 |     aspect_ratio = W / H
 79 |     k = float(resolution) / min(H, W)
 80 |     H *= k
 81 |     W *= k
 82 |     if H < W:
 83 |         W = resolution
 84 |         H = int(resolution / aspect_ratio)
 85 |     else:
 86 |         H = resolution
 87 |         W = int(aspect_ratio * resolution)
 88 |     H = int(np.round(H / 64.0)) * 64
 89 |     W = int(np.round(W / 64.0)) * 64
 90 |     img = cv2.resize(
 91 |         input_image, (W, H),
 92 |         interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
 93 |     return img
 94 | 
 95 | 
 96 | def prepare_frames(input_path: str, output_dir: str, resolution: int, crop, use_limit_device_resolution=False):
 97 |     l, r, t, b = crop
 98 | 
 99 |     if use_limit_device_resolution:
100 |         resolution = vram_limit_device_resolution(resolution)
101 | 
102 |     def crop_func(frame):
103 |         H, W, C = frame.shape
104 |         left = np.clip(l, 0, W)
105 |         right = np.clip(W - r, left, W)
106 |         top = np.clip(t, 0, H)
107 |         bottom = np.clip(H - b, top, H)
108 |         frame = frame[top:bottom, left:right]
109 |         return resize_image(frame, resolution)
110 | 
111 |     video_to_frame(input_path, output_dir, '%04d.png', False, crop_func)
112 | 
113 | 
114 | def vram_limit_device_resolution(resolution, device="cuda"):
115 |     # total GPU memory in GB
116 |     gpu_vram = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
117 |     # lookup table: minimum VRAM (GB) -> maximum supported resolution
118 |     gpu_table = {24: 1280, 18: 1024, 14: 768, 10: 640, 8: 576, 7: 512, 6: 448, 5: 320, 4: 192, 0: 0}
119 |     # pick the largest resolution allowed for the detected VRAM
120 |     device_resolution = max(val for key, val in gpu_table.items() if key <= gpu_vram)
121 |     print(f"VRAM: {gpu_vram:.1f} GB, resolution limit: {device_resolution}.")
122 |     if gpu_vram < 4:
123 |         print("VRAM is too small; the configured resolution will be used.")
124 |         return resolution
125 |     if resolution < device_resolution:
126 |         print("Video will not be resized.")
127 |         return resolution
128 |     return device_resolution
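
A hedged usage sketch for the frame utilities above; the output directory is an assumption, and the bundled Pexels clip is only a convenient example input (it must first be downloaded into videos/).

# Illustrative sketch only; output paths are assumptions.
from src.video_util import get_fps, get_frame_count, prepare_frames

video = 'videos/pexels-koolshooters-7322716.mp4'
print(get_fps(video), get_frame_count(video))

# Extract frames, apply the (left, right, top, bottom) crop, and resize
# them with resize_image (sides rounded to multiples of 64), writing
# %04d.png files into the output directory.
prepare_frames(video, 'result/example/video', 512, crop=(0, 0, 0, 0))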
128 | 


--------------------------------------------------------------------------------
/video_blend.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import os
  3 | import platform
  4 | import struct
  5 | import subprocess
  6 | import time
  7 | from typing import List
  8 | 
  9 | import cv2
 10 | import numpy as np
 11 | import torch.multiprocessing as mp
 12 | from numba import njit
 13 | 
 14 | import blender.histogram_blend as histogram_blend
 15 | from blender.guide import (BaseGuide, ColorGuide, EdgeGuide, PositionalGuide,
 16 |                            TemporalGuide)
 17 | from blender.poisson_fusion import poisson_fusion
 18 | from blender.video_sequence import VideoSequence
 19 | from flow.flow_utils import flow_calc
 20 | from src.video_util import frame_to_video
 21 | 
 22 | OPEN_EBSYNTH_LOG = False
 23 | MAX_PROCESS = 8
 24 | 
 25 | os_str = platform.system()
 26 | 
 27 | if os_str == 'Windows':
 28 |     ebsynth_bin = '.\\deps\\ebsynth\\bin\\ebsynth.exe'
 29 | elif os_str == 'Linux':
 30 |     ebsynth_bin = './deps/ebsynth/bin/ebsynth'
 31 | elif os_str == 'Darwin':
 32 |     ebsynth_bin = './deps/ebsynth/bin/ebsynth.app'
 33 | else:
 34 |     print('Unrecognized OS. Cannot run Ebsynth.')
 35 |     exit(1)
 36 | 
 37 | 
 38 | @njit
 39 | def g_error_mask_loop(H, W, dist1, dist2, output, weight1, weight2):
 40 |     for i in range(H):
 41 |         for j in range(W):
 42 |             if weight1 * dist1[i, j] < weight2 * dist2[i, j]:
 43 |                 output[i, j] = 0
 44 |             else:
 45 |                 output[i, j] = 1
 46 |             if weight1 == 0:
 47 |                 output[i, j] = 0
 48 |             elif weight2 == 0:
 49 |                 output[i, j] = 1
 50 | 
 51 | 
 52 | def g_error_mask(dist1, dist2, weight1=1, weight2=1):
 53 |     H, W = dist1.shape
 54 |     output = np.empty_like(dist1, dtype=np.byte)
 55 |     g_error_mask_loop(H, W, dist1, dist2, output, weight1, weight2)
 56 |     return output
 57 | 
 58 | 
 59 | def create_sequence(base_dir, beg, end, interval, key_dir):
 60 |     sequence = VideoSequence(base_dir, beg, end, interval, 'video', key_dir,
 61 |                              'tmp', '%04d.png', '%04d.png')
 62 |     return sequence
 63 | 
 64 | 
 65 | def process_one_sequence(i, video_sequence: VideoSequence):
 66 |     interval = video_sequence.interval
 67 |     for is_forward in [True, False]:
 68 |         input_seq = video_sequence.get_input_sequence(i, is_forward)
 69 |         output_seq = video_sequence.get_output_sequence(i, is_forward)
 70 |         flow_seq = video_sequence.get_flow_sequence(i, is_forward)
 71 |         key_img_id = i if is_forward else i + 1
 72 |         key_img = video_sequence.get_key_img(key_img_id)
 73 |         for j in range(interval - 1):
 74 |             i1 = cv2.imread(input_seq[j])
 75 |             i2 = cv2.imread(input_seq[j + 1])
 76 |             flow_calc.get_flow(i1, i2, flow_seq[j])
 77 | 
 78 |         guides: List[BaseGuide] = [
 79 |             ColorGuide(input_seq),
 80 |             EdgeGuide(input_seq,
 81 |                       video_sequence.get_edge_sequence(i, is_forward)),
 82 |             TemporalGuide(key_img, output_seq, flow_seq,
 83 |                           video_sequence.get_temporal_sequence(i, is_forward)),
 84 |             PositionalGuide(flow_seq,
 85 |                             video_sequence.get_pos_sequence(i, is_forward))
 86 |         ]
 87 |         weights = [6, 0.5, 0.5, 2]
 88 |         for j in range(interval):
 89 |             # key frame
 90 |             if j == 0:
 91 |                 img = cv2.imread(key_img)
 92 |                 cv2.imwrite(output_seq[0], img)
 93 |             else:
 94 |                 cmd = f'{ebsynth_bin} -style {os.path.abspath(key_img)}'
 95 |                 for g, w in zip(guides, weights):
 96 |                     cmd += ' ' + g.get_cmd(j, w)
 97 | 
 98 |                 cmd += (f' -output {os.path.abspath(output_seq[j])}'
 99 |                         ' -searchvoteiters 12 -patchmatchiters 6')
100 |                 if OPEN_EBSYNTH_LOG:
101 |                     print(cmd)
102 |                 subprocess.run(cmd,
103 |                                shell=True,
104 |                                capture_output=not OPEN_EBSYNTH_LOG)
105 | 
106 | 
107 | def process_sequences(i_arr, video_sequence: VideoSequence):
108 |     for i in i_arr:
109 |         process_one_sequence(i, video_sequence)
110 | 
111 | 
112 | def run_ebsynth(video_sequence: VideoSequence):
113 | 
114 |     beg = time.time()
115 | 
116 |     processes = []
117 |     mp.set_start_method('spawn')
118 | 
119 |     n_process = min(MAX_PROCESS, video_sequence.n_seq)
120 |     cnt = video_sequence.n_seq // n_process
121 |     remainder = video_sequence.n_seq % n_process
122 | 
123 |     prev_idx = 0
124 | 
125 |     for i in range(n_process):
126 |         task_cnt = cnt + 1 if i < remainder else cnt
127 |         i_arr = list(range(prev_idx, prev_idx + task_cnt))
128 |         prev_idx += task_cnt
129 |         p = mp.Process(target=process_sequences, args=(i_arr, video_sequence))
130 |         p.start()
131 |         processes.append(p)
132 |     for p in processes:
133 |         p.join()
134 | 
135 |     end = time.time()
136 | 
137 |     print(f'ebsynth: {end-beg}')
138 | 
139 | 
140 | @njit
141 | def assemble_min_error_img_loop(H, W, a, b, error_mask, out):
142 |     for i in range(H):
143 |         for j in range(W):
144 |             if error_mask[i, j] == 0:
145 |                 out[i, j] = a[i, j]
146 |             else:
147 |                 out[i, j] = b[i, j]
148 | 
149 | 
150 | def assemble_min_error_img(a, b, error_mask):
151 |     H, W = a.shape[0:2]
152 |     out = np.empty_like(a)
153 |     assemble_min_error_img_loop(H, W, a, b, error_mask, out)
154 |     return out
155 | 
156 | 
157 | def load_error(bin_path, img_shape):
158 |     img_size = img_shape[0] * img_shape[1]
159 |     with open(bin_path, 'rb') as fp:
160 |         bytes = fp.read()
161 | 
162 |     read_size = struct.unpack('q', bytes[:8])
163 |     assert read_size[0] == img_size
164 |     float_res = struct.unpack('f' * img_size, bytes[8:])
165 |     res = np.array(float_res,
166 |                    dtype=np.float32).reshape(img_shape[0], img_shape[1])
167 |     return res
168 | 
169 | 
170 | def process_seq(video_sequence: VideoSequence,
171 |                 i,
172 |                 blend_histogram=True,
173 |                 blend_gradient=True):
174 | 
175 |     key1_img = cv2.imread(video_sequence.get_key_img(i))
176 |     img_shape = key1_img.shape
177 |     interval = video_sequence.interval
178 |     beg_id = video_sequence.get_sequence_beg_id(i)
179 | 
180 |     oas = video_sequence.get_output_sequence(i)
181 |     obs = video_sequence.get_output_sequence(i, False)
182 | 
183 |     binas = [x.replace('jpg', 'bin') for x in oas]
184 |     binbs = [x.replace('jpg', 'bin') for x in obs]
185 | 
186 |     obs = [obs[0]] + list(reversed(obs[1:]))
187 |     inputs = video_sequence.get_input_sequence(i)
188 |     oas = [cv2.imread(x) for x in oas]
189 |     obs = [cv2.imread(x) for x in obs]
190 |     inputs = [cv2.imread(x) for x in inputs]
191 |     flow_seq = video_sequence.get_flow_sequence(i)
192 | 
193 |     dist1s = []
194 |     dist2s = []
195 |     for i in range(interval - 1):
196 |         bin_a = binas[i + 1]
197 |         bin_b = binbs[i + 1]
198 |         dist1s.append(load_error(bin_a, img_shape))
199 |         dist2s.append(load_error(bin_b, img_shape))
200 | 
201 |     lb = 0
202 |     ub = 1
203 |     beg = time.time()
204 |     p_mask = None
205 | 
206 |     # write key img
207 |     blend_out_path = video_sequence.get_blending_img(beg_id)
208 |     cv2.imwrite(blend_out_path, key1_img)
209 | 
210 |     for i in range(interval - 1):
211 |         c_id = beg_id + i + 1
212 |         blend_out_path = video_sequence.get_blending_img(c_id)
213 | 
214 |         dist1 = dist1s[i]
215 |         dist2 = dist2s[i]
216 |         oa = oas[i + 1]
217 |         ob = obs[i + 1]
218 |         weight1 = i / (interval - 1) * (ub - lb) + lb
219 |         weight2 = 1 - weight1
220 |         mask = g_error_mask(dist1, dist2, weight1, weight2)
221 |         if p_mask is not None:
222 |             flow_path = flow_seq[i]
223 |             flow = flow_calc.get_flow(inputs[i], inputs[i + 1], flow_path)
224 |             p_mask = flow_calc.warp(p_mask, flow, 'nearest')
225 |             mask = p_mask | mask
226 |         p_mask = mask
227 | 
228 |         # Save tmp mask
229 |         # out_mask = np.expand_dims(mask, 2)
230 |         # cv2.imwrite(f'mask/mask_{c_id:04d}.jpg', out_mask * 255)
231 | 
232 |         min_error_img = assemble_min_error_img(oa, ob, mask)
233 |         if blend_histogram:
234 |             hb_res = histogram_blend.blend(oa, ob, min_error_img,
235 |                                            (1 - weight1), (1 - weight2))
236 | 
237 |         else:
238 |             # hb_res = min_error_img
239 |             tmpa = oa.astype(np.float32)
240 |             tmpb = ob.astype(np.float32)
241 |             hb_res = (1 - weight1) * tmpa + (1 - weight2) * tmpb
242 | 
243 |         # cv2.imwrite(blend_out_path, hb_res)
244 | 
245 |         # gradient blend
246 |         if blend_gradient:
247 |             res = poisson_fusion(hb_res, oa, ob, mask)
248 |         else:
249 |             res = hb_res
250 | 
251 |         cv2.imwrite(blend_out_path, res)
252 |     end = time.time()
253 |     print('others:', end - beg)
254 | 
255 | 
256 | def main(args):
257 |     global MAX_PROCESS
258 |     MAX_PROCESS = args.n_proc
259 | 
260 |     video_sequence = create_sequence(f'{args.name}', args.beg, args.end,
261 |                                      args.itv, args.key)
262 |     if not args.ne:
263 |         run_ebsynth(video_sequence)
264 |     blend_histogram = True
265 |     blend_gradient = args.ps
266 |     for i in range(video_sequence.n_seq):
267 |         process_seq(video_sequence, i, blend_histogram, blend_gradient)
268 |     if args.output:
269 |         frame_to_video(args.output, video_sequence.blending_dir, args.fps,
270 |                        False)
271 |     if not args.tmp:
272 |         video_sequence.remove_out_and_tmp()
273 | 
274 | 
275 | if __name__ == '__main__':
276 |     parser = argparse.ArgumentParser()
277 |     parser.add_argument('name', type=str, help='Path to the video directory')
278 |     parser.add_argument('--output',
279 |                         type=str,
280 |                         default=None,
281 |                         help='Path to output video')
282 |     parser.add_argument('--fps',
283 |                         type=float,
284 |                         default=30,
285 |                         help='The FPS of output video')
286 |     parser.add_argument('--beg',
287 |                         type=int,
288 |                         default=1,
289 |                         help='The index of the first frame to be stylized')
290 |     parser.add_argument('--end',
291 |                         type=int,
292 |                         default=101,
293 |                         help='The index of the last frame to be stylized')
294 |     parser.add_argument('--itv',
295 |                         type=int,
296 |                         default=10,
297 |                         help='The interval between key frames')
298 |     parser.add_argument('--key',
299 |                         type=str,
300 |                         default='keys0',
301 |                         help='The subfolder name of stylized key frames')
302 |     parser.add_argument('--n_proc',
303 |                         type=int,
304 |                         default=8,
305 |                         help='The max process count')
306 |     parser.add_argument('-ps',
307 |                         action='store_true',
308 |                         help='Use poisson gradient blending')
309 |     parser.add_argument(
310 |         '-ne',
311 |         action='store_true',
312 |         help='Do not run ebsynth (use previous ebsynth output)')
313 |     parser.add_argument('-tmp',
314 |                         action='store_true',
315 |                         help='Keep temporary output')
316 | 
317 |     args = parser.parse_args()
318 |     main(args)
319 | 
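
For reference, a hedged sketch of how this script is typically driven; the result/example layout (a base directory holding the extracted frames and the stylized key-frame subfolder that VideoSequence expects) is an assumption.

# Illustrative sketch only; the paths are assumptions.
import subprocess

subprocess.run([
    'python', 'video_blend.py', 'result/example',
    '--beg', '1', '--end', '101', '--itv', '10',
    '--key', 'keys0',
    '--output', 'result/example/blend.mp4',
    '--fps', '25',
    '-ps',  # enable Poisson gradient blending
], check=True)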


--------------------------------------------------------------------------------
/videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamyang1991/Rerender_A_Video/dfaf9d8825f226a2f0a0b731ab2adc84a3f2ebd2/videos/pexels-antoni-shkraba-8048492-540x960-25fps.mp4


--------------------------------------------------------------------------------
/videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamyang1991/Rerender_A_Video/dfaf9d8825f226a2f0a0b731ab2adc84a3f2ebd2/videos/pexels-cottonbro-studio-6649832-960x506-25fps.mp4


--------------------------------------------------------------------------------
/videos/pexels-koolshooters-7322716.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamyang1991/Rerender_A_Video/dfaf9d8825f226a2f0a0b731ab2adc84a3f2ebd2/videos/pexels-koolshooters-7322716.mp4


--------------------------------------------------------------------------------
/webUI.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import shutil
  3 | from enum import Enum
  4 | 
  5 | import cv2
  6 | import einops
  7 | import gradio as gr
  8 | import numpy as np
  9 | import torch
 10 | import torch.nn.functional as F
 11 | import torchvision.transforms as T
 12 | from blendmodes.blend import BlendType, blendLayers
 13 | from PIL import Image
 14 | from pytorch_lightning import seed_everything
 15 | from safetensors.torch import load_file
 16 | from skimage import exposure
 17 | 
 18 | import src.import_util  # noqa: F401
 19 | from deps.ControlNet.annotator.canny import CannyDetector
 20 | from deps.ControlNet.annotator.hed import HEDdetector
 21 | from deps.ControlNet.annotator.util import HWC3
 22 | from deps.ControlNet.cldm.model import create_model, load_state_dict
 23 | from deps.gmflow.gmflow.gmflow import GMFlow
 24 | from flow.flow_utils import get_warped_and_mask
 25 | from sd_model_cfg import model_dict
 26 | from src.config import RerenderConfig
 27 | from src.controller import AttentionControl
 28 | from src.ddim_v_hacked import DDIMVSampler
 29 | from src.freeu import freeu_forward
 30 | from src.img_util import find_flat_region, numpy2tensor
 31 | from src.video_util import (frame_to_video, get_fps, get_frame_count,
 32 |                             prepare_frames)
 33 | 
 34 | inversed_model_dict = dict()
 35 | for k, v in model_dict.items():
 36 |     inversed_model_dict[v] = k
 37 | 
 38 | to_tensor = T.PILToTensor()
 39 | blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
 40 | 
 41 | 
 42 | class ProcessingState(Enum):
 43 |     NULL = 0
 44 |     FIRST_IMG = 1
 45 |     KEY_IMGS = 2
 46 | 
 47 | 
 48 | class GlobalState:
 49 | 
 50 |     def __init__(self):
 51 |         self.sd_model = None
 52 |         self.ddim_v_sampler = None
 53 |         self.detector_type = None
 54 |         self.detector = None
 55 |         self.controller = None
 56 |         self.processing_state = ProcessingState.NULL
 57 |         flow_model = GMFlow(
 58 |             feature_channels=128,
 59 |             num_scales=1,
 60 |             upsample_factor=8,
 61 |             num_head=1,
 62 |             attention_type='swin',
 63 |             ffn_dim_expansion=4,
 64 |             num_transformer_layers=6,
 65 |         ).to('cuda')
 66 | 
 67 |         checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
 68 |                                 map_location=lambda storage, loc: storage)
 69 |         weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
 70 |         flow_model.load_state_dict(weights, strict=False)
 71 |         flow_model.eval()
 72 |         self.flow_model = flow_model
 73 | 
 74 |     def update_controller(self, inner_strength, mask_period, cross_period,
 75 |                           ada_period, warp_period, loose_cfattn):
 76 |         self.controller = AttentionControl(inner_strength,
 77 |                                            mask_period,
 78 |                                            cross_period,
 79 |                                            ada_period,
 80 |                                            warp_period,
 81 |                                            loose_cfatnn=loose_cfattn)
 82 | 
 83 |     def update_sd_model(self, sd_model, control_type, freeu_args):
 84 |         if sd_model == self.sd_model:
 85 |             return
 86 |         self.sd_model = sd_model
 87 |         model = create_model('./deps/ControlNet/models/cldm_v15.yaml').cpu()
 88 |         if control_type == 'HED':
 89 |             model.load_state_dict(
 90 |                 load_state_dict('./models/control_sd15_hed.pth',
 91 |                                 location='cuda'))
 92 |         elif control_type == 'canny':
 93 |             model.load_state_dict(
 94 |                 load_state_dict('./models/control_sd15_canny.pth',
 95 |                                 location='cuda'))
 96 |         model = model.cuda()
 97 |         sd_model_path = model_dict[sd_model]
 98 |         if len(sd_model_path) > 0:
 99 |             model_ext = os.path.splitext(sd_model_path)[1]
100 |             if model_ext == '.safetensors':
101 |                 model.load_state_dict(load_file(sd_model_path), strict=False)
102 |             elif model_ext == '.ckpt' or model_ext == '.pth':
103 |                 model.load_state_dict(torch.load(sd_model_path)['state_dict'],
104 |                                       strict=False)
105 | 
106 |         try:
107 |             model.first_stage_model.load_state_dict(torch.load(
108 |                 './models/vae-ft-mse-840000-ema-pruned.ckpt')['state_dict'],
109 |                                                     strict=False)
110 |         except Exception:
111 |             print('Warning: we suggest downloading the fine-tuned VAE;',
112 |                   'otherwise the generation quality will be degraded.')
113 | 
114 |         model.model.diffusion_model.forward = freeu_forward(
115 |             model.model.diffusion_model, *freeu_args)
116 |         self.ddim_v_sampler = DDIMVSampler(model)
117 | 
118 |     def clear_sd_model(self):
119 |         self.sd_model = None
120 |         self.ddim_v_sampler = None
121 |         torch.cuda.empty_cache()
122 | 
123 |     def update_detector(self, control_type, canny_low=100, canny_high=200):
124 |         if self.detector_type == control_type:
125 |             return
126 |         if control_type == 'HED':
127 |             self.detector = HEDdetector()
128 |         elif control_type == 'canny':
129 |             canny_detector = CannyDetector()
130 |             low_threshold = canny_low
131 |             high_threshold = canny_high
132 | 
133 |             def apply_canny(x):
134 |                 return canny_detector(x, low_threshold, high_threshold)
135 | 
136 |             self.detector = apply_canny
137 | 
138 | 
139 | global_state = GlobalState()
140 | global_video_path = None
141 | video_frame_count = None
142 | 
143 | 
144 | def create_cfg(input_path, prompt, image_resolution, control_strength,
145 |                color_preserve, left_crop, right_crop, top_crop, bottom_crop,
146 |                control_type, low_threshold, high_threshold, ddim_steps, scale,
147 |                seed, sd_model, a_prompt, n_prompt, interval, keyframe_count,
148 |                x0_strength, use_constraints, cross_start, cross_end,
149 |                style_update_freq, warp_start, warp_end, mask_start, mask_end,
150 |                ada_start, ada_end, mask_strength, inner_strength,
151 |                smooth_boundary, loose_cfattn, b1, b2, s1, s2):
152 |     use_warp = 'shape-aware fusion' in use_constraints
153 |     use_mask = 'pixel-aware fusion' in use_constraints
154 |     use_ada = 'color-aware AdaIN' in use_constraints
155 | 
156 |     if not use_warp:
157 |         warp_start = 1
158 |         warp_end = 0
159 | 
160 |     if not use_mask:
161 |         mask_start = 1
162 |         mask_end = 0
163 | 
164 |     if not use_ada:
165 |         ada_start = 1
166 |         ada_end = 0
167 | 
168 |     input_name = os.path.split(input_path)[-1].split('.')[0]
169 |     frame_count = 2 + keyframe_count * interval
170 |     cfg = RerenderConfig()
171 |     cfg.create_from_parameters(
172 |         input_path,
173 |         os.path.join('result', input_name, 'blend.mp4'),
174 |         prompt,
175 |         a_prompt=a_prompt,
176 |         n_prompt=n_prompt,
177 |         frame_count=frame_count,
178 |         interval=interval,
179 |         crop=[left_crop, right_crop, top_crop, bottom_crop],
180 |         sd_model=sd_model,
181 |         ddim_steps=ddim_steps,
182 |         scale=scale,
183 |         control_type=control_type,
184 |         control_strength=control_strength,
185 |         canny_low=low_threshold,
186 |         canny_high=high_threshold,
187 |         seed=seed,
188 |         image_resolution=image_resolution,
189 |         x0_strength=x0_strength,
190 |         style_update_freq=style_update_freq,
191 |         cross_period=(cross_start, cross_end),
192 |         warp_period=(warp_start, warp_end),
193 |         mask_period=(mask_start, mask_end),
194 |         ada_period=(ada_start, ada_end),
195 |         mask_strength=mask_strength,
196 |         inner_strength=inner_strength,
197 |         smooth_boundary=smooth_boundary,
198 |         color_preserve=color_preserve,
199 |         loose_cfattn=loose_cfattn,
200 |         freeu_args=[b1, b2, s1, s2])
201 |     return cfg
202 | 
203 | 
204 | def cfg_to_input(filename):
205 | 
206 |     cfg = RerenderConfig()
207 |     cfg.create_from_path(filename)
208 |     keyframe_count = (cfg.frame_count - 2) // cfg.interval
209 |     use_constraints = [
210 |         'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN'
211 |     ]
212 | 
213 |     sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5')
214 | 
215 |     args = [
216 |         cfg.input_path, cfg.prompt, cfg.image_resolution, cfg.control_strength,
217 |         cfg.color_preserve, *cfg.crop, cfg.control_type, cfg.canny_low,
218 |         cfg.canny_high, cfg.ddim_steps, cfg.scale, cfg.seed, sd_model,
219 |         cfg.a_prompt, cfg.n_prompt, cfg.interval, keyframe_count,
220 |         cfg.x0_strength, use_constraints, *cfg.cross_period,
221 |         cfg.style_update_freq, *cfg.warp_period, *cfg.mask_period,
222 |         *cfg.ada_period, cfg.mask_strength, cfg.inner_strength,
223 |         cfg.smooth_boundary, cfg.loose_cfattn, *cfg.freeu_args
224 |     ]
225 |     return args
226 | 
227 | 
228 | def setup_color_correction(image):
229 |     correction_target = cv2.cvtColor(np.asarray(image.copy()),
230 |                                      cv2.COLOR_RGB2LAB)
231 |     return correction_target
232 | 
233 | 
234 | def apply_color_correction(correction, original_image):
235 |     image = Image.fromarray(
236 |         cv2.cvtColor(
237 |             exposure.match_histograms(cv2.cvtColor(np.asarray(original_image),
238 |                                                    cv2.COLOR_RGB2LAB),
239 |                                       correction,
240 |                                       channel_axis=2),
241 |             cv2.COLOR_LAB2RGB).astype('uint8'))
242 | 
243 |     image = blendLayers(image, original_image, BlendType.LUMINOSITY)
244 | 
245 |     return image
246 | 
247 | 
248 | @torch.no_grad()
249 | def process(*args):
250 |     args_wo_process3 = args[:-2]
251 |     first_frame = process1(*args_wo_process3)
252 | 
253 |     keypath = process2(*args_wo_process3)
254 | 
255 |     fullpath = process3(*args)
256 | 
257 |     return first_frame, keypath, fullpath
258 | 
259 | 
260 | @torch.no_grad()
261 | def process1(*args):
262 | 
263 |     global global_video_path
264 |     cfg = create_cfg(global_video_path, *args)
265 |     global global_state
266 |     global_state.update_sd_model(cfg.sd_model, cfg.control_type,
267 |                                  cfg.freeu_args)
268 |     global_state.update_controller(cfg.inner_strength, cfg.mask_period,
269 |                                    cfg.cross_period, cfg.ada_period,
270 |                                    cfg.warp_period, cfg.loose_cfattn)
271 |     global_state.update_detector(cfg.control_type, cfg.canny_low,
272 |                                  cfg.canny_high)
273 |     global_state.processing_state = ProcessingState.FIRST_IMG
274 | 
275 |     prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution, cfg.crop, cfg.use_limit_device_resolution)
276 | 
277 |     ddim_v_sampler = global_state.ddim_v_sampler
278 |     model = ddim_v_sampler.model
279 |     detector = global_state.detector
280 |     controller = global_state.controller
281 |     model.control_scales = [cfg.control_strength] * 13
282 | 
283 |     num_samples = 1
284 |     eta = 0.0
285 |     imgs = sorted(os.listdir(cfg.input_dir))
286 |     imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
287 | 
288 |     with torch.no_grad():
289 |         frame = cv2.imread(imgs[0])
290 |         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
291 |         img = HWC3(frame)
292 |         H, W, C = img.shape
293 | 
294 |         img_ = numpy2tensor(img)
295 | 
296 |         def generate_first_img(img_, strength):
297 |             encoder_posterior = model.encode_first_stage(img_.cuda())
298 |             x0 = model.get_first_stage_encoding(encoder_posterior).detach()
299 | 
300 |             detected_map = detector(img)
301 |             detected_map = HWC3(detected_map)
302 | 
303 |             control = torch.from_numpy(
304 |                 detected_map.copy()).float().cuda() / 255.0
305 |             control = torch.stack([control for _ in range(num_samples)], dim=0)
306 |             control = einops.rearrange(control, 'b h w c -> b c h w').clone()
307 |             cond = {
308 |                 'c_concat': [control],
309 |                 'c_crossattn': [
310 |                     model.get_learned_conditioning(
311 |                         [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
312 |                 ]
313 |             }
314 |             un_cond = {
315 |                 'c_concat': [control],
316 |                 'c_crossattn':
317 |                 [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
318 |             }
319 |             shape = (4, H // 8, W // 8)
320 | 
321 |             controller.set_task('initfirst')
322 |             seed_everything(cfg.seed)
323 | 
324 |             samples, _ = ddim_v_sampler.sample(
325 |                 cfg.ddim_steps,
326 |                 num_samples,
327 |                 shape,
328 |                 cond,
329 |                 verbose=False,
330 |                 eta=eta,
331 |                 unconditional_guidance_scale=cfg.scale,
332 |                 unconditional_conditioning=un_cond,
333 |                 controller=controller,
334 |                 x0=x0,
335 |                 strength=strength)
336 |             x_samples = model.decode_first_stage(samples)
337 |             x_samples_np = (
338 |                 einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
339 |                 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
340 |             return x_samples, x_samples_np
341 | 
342 |         # When color is not preserved, draw a freely stylized frame first,
343 |         # then recolor the input with its colors and redraw the first frame.
344 |         if not cfg.color_preserve:
345 |             first_strength = -1
346 |         else:
347 |             first_strength = 1 - cfg.x0_strength
348 | 
349 |         x_samples, x_samples_np = generate_first_img(img_, first_strength)
350 | 
351 |         if not cfg.color_preserve:
352 |             color_corrections = setup_color_correction(
353 |                 Image.fromarray(x_samples_np[0]))
354 |             global_state.color_corrections = color_corrections
355 |             img_ = apply_color_correction(color_corrections,
356 |                                           Image.fromarray(img))
357 |             img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
358 |             x_samples, x_samples_np = generate_first_img(
359 |                 img_, 1 - cfg.x0_strength)
360 | 
361 |         global_state.first_result = x_samples
362 |         global_state.first_img = img
363 | 
364 |     Image.fromarray(x_samples_np[0]).save(
365 |         os.path.join(cfg.first_dir, 'first.jpg'))
366 | 
367 |     return x_samples_np[0]
368 | 
369 | 
370 | @torch.no_grad()
371 | def process2(*args):
372 |     global global_state
373 |     global global_video_path
374 | 
375 |     if global_state.processing_state != ProcessingState.FIRST_IMG:
376 |         raise gr.Error('Please generate the first key image before generating'
377 |                        ' all key images')
378 | 
379 |     cfg = create_cfg(global_video_path, *args)
380 |     global_state.update_sd_model(cfg.sd_model, cfg.control_type,
381 |                                  cfg.freeu_args)
382 |     global_state.update_detector(cfg.control_type, cfg.canny_low,
383 |                                  cfg.canny_high)
384 |     global_state.processing_state = ProcessingState.KEY_IMGS
385 | 
386 |     # reset key dir
387 |     shutil.rmtree(cfg.key_dir, ignore_errors=True)
388 |     os.makedirs(cfg.key_dir, exist_ok=True)
389 | 
390 |     ddim_v_sampler = global_state.ddim_v_sampler
391 |     model = ddim_v_sampler.model
392 |     detector = global_state.detector
393 |     controller = global_state.controller
394 |     flow_model = global_state.flow_model
395 |     model.control_scales = [cfg.control_strength] * 13
396 | 
397 |     num_samples = 1
398 |     eta = 0.0
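    |     # firstx0: warp/anchor against the first key frame rather than the
    |     # previous one; pixelfusion: enable pixel-aware fusion (use_mask).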
399 |     firstx0 = True
400 |     pixelfusion = cfg.use_mask
401 |     imgs = sorted(os.listdir(cfg.input_dir))
402 |     imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
403 | 
404 |     first_result = global_state.first_result
405 |     first_img = global_state.first_img
406 |     pre_result = first_result
407 |     pre_img = first_img
408 | 
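    |     # Translate the key frames sequentially; each one is constrained by
    |     # both the previous key frame and the first key frame via optical flow.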
409 |     for i in range(0, min(len(imgs), cfg.frame_count) - 1, cfg.interval):
410 |         cid = i + 1
411 |         print(cid)
412 |         if cid <= (len(imgs) - 1):
413 |             frame = cv2.imread(imgs[cid])
414 |         else:
415 |             frame = cv2.imread(imgs[len(imgs) - 1])
416 |         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
417 |         img = HWC3(frame)
418 |         H, W, C = img.shape
419 | 
420 |         if cfg.color_preserve or global_state.color_corrections is None:
421 |             img_ = numpy2tensor(img)
422 |         else:
423 |             img_ = apply_color_correction(global_state.color_corrections,
424 |                                           Image.fromarray(img))
425 |             img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
426 |         encoder_posterior = model.encode_first_stage(img_.cuda())
427 |         x0 = model.get_first_stage_encoding(encoder_posterior).detach()
428 | 
429 |         detected_map = detector(img)
430 |         detected_map = HWC3(detected_map)
431 | 
432 |         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
433 |         control = torch.stack([control for _ in range(num_samples)], dim=0)
434 |         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
435 |         cond = {
436 |             'c_concat': [control],
437 |             'c_crossattn': [
438 |                 model.get_learned_conditioning(
439 |                     [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
440 |             ]
441 |         }
442 |         un_cond = {
443 |             'c_concat': [control],
444 |             'c_crossattn':
445 |             [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
446 |         }
447 |         shape = (4, H // 8, W // 8)
448 | 
449 |         cond['c_concat'] = [control]
450 |         un_cond['c_concat'] = [control]
451 | 
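    |         # Estimate backward optical flow from the previous and the first
    |         # frames to the current frame, warp their stylized results, and
    |         # build occlusion-based blend masks.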
452 |         image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float()
453 |         image2 = torch.from_numpy(img).permute(2, 0, 1).float()
454 |         warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
455 |             flow_model, image1, image2, pre_result, False)
456 |         blend_mask_pre = blur(
457 |             F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
458 |         blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)
459 | 
460 |         image1 = torch.from_numpy(first_img).permute(2, 0, 1).float()
461 |         warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
462 |             flow_model, image1, image2, first_result, False)
463 |         blend_mask_0 = blur(
464 |             F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
465 |         blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
466 | 
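    |         # Pass the flow and occlusion mask to the controller, downsampled
    |         # to the 1/8 latent resolution, for shape-aware fusion in sampling.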
467 |         if firstx0:
468 |             mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)
469 |             controller.set_warp(
470 |                 F.interpolate(bwd_flow_0 / 8.0,
471 |                               scale_factor=1. / 8,
472 |                               mode='bilinear'), mask)
473 |         else:
474 |             mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8)
475 |             controller.set_warp(
476 |                 F.interpolate(bwd_flow_pre / 8.0,
477 |                               scale_factor=1. / 8,
478 |                               mode='bilinear'), mask)
479 | 
480 |         controller.set_task('keepx0, keepstyle')
481 |         seed_everything(cfg.seed)
482 |         samples, intermediates = ddim_v_sampler.sample(
483 |             cfg.ddim_steps,
484 |             num_samples,
485 |             shape,
486 |             cond,
487 |             verbose=False,
488 |             eta=eta,
489 |             unconditional_guidance_scale=cfg.scale,
490 |             unconditional_conditioning=un_cond,
491 |             controller=controller,
492 |             x0=x0,
493 |             strength=1 - cfg.x0_strength)
494 |         direct_result = model.decode_first_stage(samples)
495 | 
496 |         if not pixelfusion:
497 |             pre_result = direct_result
498 |             pre_img = img
499 |             viz = (
500 |                 einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 +
501 |                 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
502 | 
503 |         else:
504 | 
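    |             # Pixel-aware fusion: combine the warped previous and first
    |             # results with the directly generated frame, weighted by the
    |             # occlusion masks.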
505 |             blend_results = (1 - blend_mask_pre
506 |                              ) * warped_pre + blend_mask_pre * direct_result
507 |             blend_results = (
508 |                 1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results
509 | 
510 |             bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
511 |             blend_mask = blur(
512 |                 F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
513 |             blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)
514 | 
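    |             # Fidelity-oriented encoding: re-encode the blended image and
    |             # add back the reconstruction residual (xtrg - xtrg_rec) to
    |             # offset autoencoder loss; mask_x flags regions where this
    |             # compensation overshoots.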
515 |             encoder_posterior = model.encode_first_stage(blend_results)
516 |             xtrg = model.get_first_stage_encoding(
517 |                 encoder_posterior).detach()  # * mask
518 |             blend_results_rec = model.decode_first_stage(xtrg)
519 |             encoder_posterior = model.encode_first_stage(blend_results_rec)
520 |             xtrg_rec = model.get_first_stage_encoding(
521 |                 encoder_posterior).detach()
522 |             xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec))  # * mask
523 |             blend_results_rec_new = model.decode_first_stage(xtrg_)
524 |             tmp = (abs(blend_results_rec_new - blend_results).mean(
525 |                 dim=1, keepdims=True) > 0.25).float()
526 |             mask_x = F.max_pool2d((F.interpolate(
527 |                 tmp, scale_factor=1 / 8., mode='bilinear') > 0).float(),
528 |                                   kernel_size=3,
529 |                                   stride=1,
530 |                                   padding=1)
531 | 
532 |             mask = (1 - F.max_pool2d(1 - blend_mask, kernel_size=8)
533 |                     )  # * (1-mask_x)
534 | 
535 |             if cfg.smooth_boundary:
536 |                 noise_rescale = find_flat_region(mask)
537 |             else:
538 |                 noise_rescale = torch.ones_like(mask)
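    |             # Only apply the pixel-aware fusion mask within the mask_period
    |             # fraction of DDIM steps; the other steps receive None.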
539 |             masks = []
540 |             for j in range(cfg.ddim_steps):
541 |                 if j <= cfg.ddim_steps * cfg.mask_period[
542 |                         0] or j >= cfg.ddim_steps * cfg.mask_period[1]:
543 |                     masks += [None]
544 |                 else:
545 |                     masks += [mask * cfg.mask_strength]
546 | 
547 |             # mask 3
548 |             # xtrg = ((1-mask_x) *
549 |             #         (xtrg + xtrg - xtrg_rec) + mask_x * samples) * mask
550 |             # mask 2
551 |             # xtrg = (xtrg + 1 * (xtrg - xtrg_rec)) * mask
552 |             xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask  # mask 1
553 | 
554 |             tasks = 'keepstyle, keepx0'
555 |             if not firstx0:
556 |                 tasks += ', updatex0'
557 |             if i % cfg.style_update_freq == 0:
558 |                 tasks += ', updatestyle'
559 |             controller.set_task(tasks, 1.0)
560 | 
561 |             seed_everything(cfg.seed)
562 |             samples, _ = ddim_v_sampler.sample(
563 |                 cfg.ddim_steps,
564 |                 num_samples,
565 |                 shape,
566 |                 cond,
567 |                 verbose=False,
568 |                 eta=eta,
569 |                 unconditional_guidance_scale=cfg.scale,
570 |                 unconditional_conditioning=un_cond,
571 |                 controller=controller,
572 |                 x0=x0,
573 |                 strength=1 - cfg.x0_strength,
574 |                 xtrg=xtrg,
575 |                 mask=masks,
576 |                 noise_rescale=noise_rescale)
577 |             x_samples = model.decode_first_stage(samples)
578 |             pre_result = x_samples
579 |             pre_img = img
580 | 
581 |             viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
582 |                    127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
583 | 
584 |         Image.fromarray(viz[0]).save(
585 |             os.path.join(cfg.key_dir, f'{cid:04d}.png'))
586 | 
587 |     key_video_path = os.path.join(cfg.work_dir, 'key.mp4')
588 |     fps = get_fps(cfg.input_path)
589 |     fps //= cfg.interval
590 |     frame_to_video(key_video_path, cfg.key_dir, fps, False)
591 | 
592 |     return key_video_path
593 | 
594 | 
595 | @torch.no_grad()
596 | def process3(*args):
597 |     max_process = args[-2]
598 |     use_poisson = args[-1]
599 |     args = args[:-2]
600 |     global global_video_path
601 |     global global_state
602 |     if global_state.processing_state != ProcessingState.KEY_IMGS:
603 |         raise gr.Error('Please generate key images before propagation')
604 | 
605 |     global_state.clear_sd_model()
606 | 
607 |     cfg = create_cfg(global_video_path, *args)
608 | 
609 |     # reset blend dir
610 |     blend_dir = os.path.join(cfg.work_dir, 'blend')
611 |     if os.path.exists(blend_dir):
612 |         shutil.rmtree(blend_dir)
613 |     os.makedirs(blend_dir, exist_ok=True)
614 | 
615 |     video_base_dir = cfg.work_dir
616 |     o_video = cfg.output_path
617 |     fps = get_fps(cfg.input_path)
618 | 
619 |     end_frame = cfg.frame_count - 1
620 |     interval = cfg.interval
621 |     key_dir = os.path.split(cfg.key_dir)[-1]
622 |     o_video_cmd = f'--output {o_video}'
623 |     ps = '-ps' if use_poisson else ''
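    |     # Delegate frame propagation between key frames to video_blend.py as a
    |     # subprocess; '-ps' turns on Poisson (gradient-domain) blending.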
624 |     cmd = (f'python video_blend.py {video_base_dir} --beg 1 --end {end_frame} '
625 |            f'--itv {interval} --key {key_dir} {o_video_cmd} --fps {fps} '
626 |            f'--n_proc {max_process} {ps}')
627 |     print(cmd)
628 |     os.system(cmd)
629 | 
630 |     return o_video
631 | 
632 | 
633 | block = gr.Blocks().queue()
634 | with block:
635 |     with gr.Row():
636 |         gr.Markdown('## Rerender A Video')
637 |     with gr.Row():
638 |         with gr.Column():
639 |             input_path = gr.Video(label='Input Video',
640 |                                   source='upload',
641 |                                   format='mp4',
642 |                                   visible=True)
643 |             prompt = gr.Textbox(label='Prompt')
644 |             seed = gr.Slider(label='Seed',
645 |                              minimum=0,
646 |                              maximum=2147483647,
647 |                              step=1,
648 |                              value=0,
649 |                              randomize=True)
650 |             run_button = gr.Button(value='Run All')
651 |             with gr.Row():
652 |                 run_button1 = gr.Button(value='Run 1st Key Frame')
653 |                 run_button2 = gr.Button(value='Run Key Frames')
654 |                 run_button3 = gr.Button(value='Run Propagation')
655 |             with gr.Accordion('Advanced options for the 1st frame translation',
656 |                               open=False):
657 |                 image_resolution = gr.Slider(label='Frame resolution',
658 |                                              minimum=256,
659 |                                              maximum=768,
660 |                                              value=512,
661 |                                              step=64)
662 |                 control_strength = gr.Slider(label='ControlNet strength',
663 |                                              minimum=0.0,
664 |                                              maximum=2.0,
665 |                                              value=1.0,
666 |                                              step=0.01)
667 |                 x0_strength = gr.Slider(
668 |                     label='Denoising strength',
669 |                     minimum=0.00,
670 |                     maximum=1.05,
671 |                     value=0.75,
672 |                     step=0.05,
673 |                     info=('0: fully recover the input. '
674 |                           '1.05: fully rerender the input.'))
675 |                 color_preserve = gr.Checkbox(
676 |                     label='Preserve color',
677 |                     value=True,
678 |                     info='Keep the color of the input video')
679 |                 with gr.Row():
680 |                     left_crop = gr.Slider(label='Left crop length',
681 |                                           minimum=0,
682 |                                           maximum=512,
683 |                                           value=0,
684 |                                           step=1)
685 |                     right_crop = gr.Slider(label='Right crop length',
686 |                                            minimum=0,
687 |                                            maximum=512,
688 |                                            value=0,
689 |                                            step=1)
690 |                 with gr.Row():
691 |                     top_crop = gr.Slider(label='Top crop length',
692 |                                          minimum=0,
693 |                                          maximum=512,
694 |                                          value=0,
695 |                                          step=1)
696 |                     bottom_crop = gr.Slider(label='Bottom crop length',
697 |                                             minimum=0,
698 |                                             maximum=512,
699 |                                             value=0,
700 |                                             step=1)
701 |                 with gr.Row():
702 |                     control_type = gr.Dropdown(['HED', 'canny'],
703 |                                                label='Control type',
704 |                                                value='HED')
705 |                     low_threshold = gr.Slider(label='Canny low threshold',
706 |                                               minimum=1,
707 |                                               maximum=255,
708 |                                               value=100,
709 |                                               step=1)
710 |                     high_threshold = gr.Slider(label='Canny high threshold',
711 |                                                minimum=1,
712 |                                                maximum=255,
713 |                                                value=200,
714 |                                                step=1)
715 |                 ddim_steps = gr.Slider(label='Steps',
716 |                                        minimum=20,
717 |                                        maximum=100,
718 |                                        value=20,
719 |                                        step=20)
720 |                 scale = gr.Slider(label='CFG scale',
721 |                                   minimum=0.1,
722 |                                   maximum=30.0,
723 |                                   value=7.5,
724 |                                   step=0.1)
725 |                 sd_model_list = list(model_dict.keys())
726 |                 sd_model = gr.Dropdown(sd_model_list,
727 |                                        label='Base model',
728 |                                        value='Stable Diffusion 1.5')
729 |                 a_prompt = gr.Textbox(label='Added prompt',
730 |                                       value='best quality, extremely detailed')
731 |                 n_prompt = gr.Textbox(
732 |                     label='Negative prompt',
733 |                     value=('longbody, lowres, bad anatomy, bad hands, '
734 |                            'missing fingers, extra digit, fewer digits, '
735 |                            'cropped, worst quality, low quality'))
736 |                 with gr.Row():
737 |                     b1 = gr.Slider(label='FreeU first-stage backbone factor',
738 |                                    minimum=1,
739 |                                    maximum=1.6,
740 |                                    value=1,
741 |                                    step=0.01,
742 |                                    info='FreeU to enhance texture and color')
743 |                     b2 = gr.Slider(label='FreeU second-stage backbone factor',
744 |                                    minimum=1,
745 |                                    maximum=1.6,
746 |                                    value=1,
747 |                                    step=0.01)
748 |                 with gr.Row():
749 |                     s1 = gr.Slider(label='FreeU first-stage skip factor',
750 |                                    minimum=0,
751 |                                    maximum=1,
752 |                                    value=1,
753 |                                    step=0.01)
754 |                     s2 = gr.Slider(label='FreeU second-stage skip factor',
755 |                                    minimum=0,
756 |                                    maximum=1,
757 |                                    value=1,
758 |                                    step=0.01)
759 |             with gr.Accordion('Advanced options for the key frame translation',
760 |                               open=False):
761 |                 interval = gr.Slider(
762 |                     label='Key frame frequency (K)',
763 |                     minimum=1,
764 |                     maximum=1,
765 |                     value=1,
766 |                     step=1,
767 |                     info='Uniformly sample the key frames every K frames')
768 |                 keyframe_count = gr.Slider(label='Number of key frames',
769 |                                            minimum=1,
770 |                                            maximum=1,
771 |                                            value=1,
772 |                                            step=1)
773 | 
774 |                 use_constraints = gr.CheckboxGroup(
775 |                     [
776 |                         'shape-aware fusion', 'pixel-aware fusion',
777 |                         'color-aware AdaIN'
778 |                     ],
779 |                     label='Select the cross-frame constraints to be used',
780 |                     value=[
781 |                         'shape-aware fusion', 'pixel-aware fusion',
782 |                         'color-aware AdaIN'
783 |                     ])
784 |                 with gr.Row():
785 |                     cross_start = gr.Slider(
786 |                         label='Cross-frame attention start',
787 |                         minimum=0,
788 |                         maximum=1,
789 |                         value=0,
790 |                         step=0.05)
791 |                     cross_end = gr.Slider(label='Cross-frame attention end',
792 |                                           minimum=0,
793 |                                           maximum=1,
794 |                                           value=1,
795 |                                           step=0.05)
796 |                 style_update_freq = gr.Slider(
797 |                     label='Cross-frame attention update frequency',
798 |                     minimum=1,
799 |                     maximum=100,
800 |                     value=1,
801 |                     step=1,
802 |                     info=('Update the key and value for '
803 |                           'cross-frame attention every N key frames'))
804 |                 loose_cfattn = gr.Checkbox(
805 |                     label='Loose Cross-frame attention',
806 |                     value=True,
807 |                     info='Select to make the output better match the input')
808 |                 with gr.Row():
809 |                     warp_start = gr.Slider(label='Shape-aware fusion start',
810 |                                            minimum=0,
811 |                                            maximum=1,
812 |                                            value=0,
813 |                                            step=0.05)
814 |                     warp_end = gr.Slider(label='Shape-aware fusion end',
815 |                                          minimum=0,
816 |                                          maximum=1,
817 |                                          value=0.1,
818 |                                          step=0.05)
819 |                 with gr.Row():
820 |                     mask_start = gr.Slider(label='Pixel-aware fusion start',
821 |                                            minimum=0,
822 |                                            maximum=1,
823 |                                            value=0.5,
824 |                                            step=0.05)
825 |                     mask_end = gr.Slider(label='Pixel-aware fusion end',
826 |                                          minimum=0,
827 |                                          maximum=1,
828 |                                          value=0.8,
829 |                                          step=0.05)
830 |                 with gr.Row():
831 |                     ada_start = gr.Slider(label='Color-aware AdaIN start',
832 |                                           minimum=0,
833 |                                           maximum=1,
834 |                                           value=0.8,
835 |                                           step=0.05)
836 |                     ada_end = gr.Slider(label='Color-aware AdaIN end',
837 |                                         minimum=0,
838 |                                         maximum=1,
839 |                                         value=1,
840 |                                         step=0.05)
841 |                 mask_strength = gr.Slider(label='Pixel-aware fusion strength',
842 |                                           minimum=0,
843 |                                           maximum=1,
844 |                                           value=0.5,
845 |                                           step=0.01)
846 |                 inner_strength = gr.Slider(
847 |                     label='Pixel-aware fusion detail level',
848 |                     minimum=0.5,
849 |                     maximum=1,
850 |                     value=0.9,
851 |                     step=0.01,
852 |                     info='Use a low value to prevent artifacts')
853 |                 smooth_boundary = gr.Checkbox(
854 |                     label='Smooth fusion boundary',
855 |                     value=True,
856 |                     info='Select to prevent artifacts at boundary')
857 |             with gr.Accordion(
858 |                     'Advanced options for the full video translation',
859 |                     open=False):
860 |                 use_poisson = gr.Checkbox(
861 |                     label='Gradient blending',
862 |                     value=True,
863 |                     info=('Blend the output in the gradient domain to reduce'
864 |                           ' ghosting artifacts (may increase flickering)'))
865 |                 max_process = gr.Slider(label='Number of parallel processes',
866 |                                         minimum=1,
867 |                                         maximum=16,
868 |                                         value=4,
869 |                                         step=1)
870 | 
871 |             with gr.Accordion('Example configs', open=True):
872 |                 config_dir = 'config'
873 |                 config_list = [
874 |                     'real2sculpture.json', 'van_gogh_man.json', 'woman.json'
875 |                 ]
876 |                 args_list = []
877 |                 for config in config_list:
878 |                     try:
879 |                         config_path = os.path.join(config_dir, config)
880 |                         args = cfg_to_input(config_path)
881 |                         args_list.append(args)
882 |                     except FileNotFoundError:
883 |                         # The referenced video does not exist; skip it
884 |                         pass
885 | 
886 |                 ips = [
887 |                     prompt, image_resolution, control_strength, color_preserve,
888 |                     left_crop, right_crop, top_crop, bottom_crop, control_type,
889 |                     low_threshold, high_threshold, ddim_steps, scale, seed,
890 |                     sd_model, a_prompt, n_prompt, interval, keyframe_count,
891 |                     x0_strength, use_constraints, cross_start, cross_end,
892 |                     style_update_freq, warp_start, warp_end, mask_start,
893 |                     mask_end, ada_start, ada_end, mask_strength,
894 |                     inner_strength, smooth_boundary, loose_cfattn, b1, b2, s1,
895 |                     s2
896 |                 ]
897 | 
898 |                 gr.Examples(
899 |                     examples=args_list,
900 |                     inputs=[input_path, *ips],
901 |                 )
902 | 
903 |         with gr.Column():
904 |             result_image = gr.Image(label='Output first frame',
905 |                                     type='numpy',
906 |                                     interactive=False)
907 |             result_keyframe = gr.Video(label='Output key frame video',
908 |                                        format='mp4',
909 |                                        interactive=False)
910 |             result_video = gr.Video(label='Output full video',
911 |                                     format='mp4',
912 |                                     interactive=False)
913 | 
914 |     def input_uploaded(path):
915 |         frame_count = get_frame_count(path)
916 |         if frame_count <= 2:
917 |             raise gr.Error('The input video is too short! '
918 |                            'Please input another video.')
919 | 
920 |         default_interval = min(10, frame_count - 2)
921 |         max_keyframe = (frame_count - 2) // default_interval
922 | 
923 |         global video_frame_count
924 |         video_frame_count = frame_count
925 |         global global_video_path
926 |         global_video_path = path
927 | 
928 |         return gr.Slider.update(value=default_interval,
929 |                                 maximum=max_keyframe), gr.Slider.update(
930 |                                     value=max_keyframe, maximum=max_keyframe)
931 | 
932 |     def input_changed(path):
933 |         frame_count = get_frame_count(path)
934 |         if frame_count <= 2:
935 |             return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)
936 | 
937 |         default_interval = min(10, frame_count - 2)
938 |         max_keyframe = (frame_count - 2) // default_interval
939 | 
940 |         global video_frame_count
941 |         video_frame_count = frame_count
942 |         global global_video_path
943 |         global_video_path = path
944 | 
945 |         return gr.Slider.update(maximum=max_keyframe), \
946 |             gr.Slider.update(maximum=max_keyframe)
947 | 
948 |     def interval_changed(interval):
949 |         global video_frame_count
950 |         if video_frame_count is None:
951 |             return gr.Slider.update()
952 | 
953 |         max_keyframe = (video_frame_count - 2) // interval
954 | 
955 |         return gr.Slider.update(value=max_keyframe, maximum=max_keyframe)
956 | 
957 |     input_path.change(input_changed, input_path, [interval, keyframe_count])
958 |     input_path.upload(input_uploaded, input_path, [interval, keyframe_count])
959 |     interval.change(interval_changed, interval, keyframe_count)
960 | 
961 |     ips_process3 = [*ips, max_process, use_poisson]
962 |     run_button.click(fn=process,
963 |                      inputs=ips_process3,
964 |                      outputs=[result_image, result_keyframe, result_video])
965 |     run_button1.click(fn=process1, inputs=ips, outputs=[result_image])
966 |     run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])
967 |     run_button3.click(fn=process3, inputs=ips_process3, outputs=[result_video])
968 | 
969 | block.launch(server_name='localhost')
970 | 


--------------------------------------------------------------------------------