├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── README_v044.md ├── baselight └── uk.co.andriy.mltimewarp.v1 │ ├── IFNet_HDv3.py │ ├── LICENCE.md │ ├── VERSION │ ├── bin │ ├── Activate.ps1 │ ├── activate │ ├── activate.csh │ ├── activate.fish │ ├── easy_install │ ├── easy_install-3.9 │ ├── pip │ ├── pip3 │ ├── pip3.9 │ ├── python │ ├── python3 │ └── python3.9 │ ├── flexi.py │ ├── lib64 │ ├── pyvenv.cfg │ ├── timewarp.py │ └── uk.co.dirtylooks.mltimewarp.v1.flexi ├── find_scale.sh ├── fix-xattr.command ├── flameTimewarpML.py ├── flameTimewarpML_framework.py ├── fonts └── DejaVuSansMono.ttf ├── models └── flownet2_v004.pth ├── packages ├── .miniconda │ └── README.md ├── README.md └── clean_pycache.sh ├── presets ├── openexr16bit.xml ├── openexr32bit.xml ├── source_export.xml └── source_export32bit.xml ├── pyflame_lib_flameTimewarpML.py ├── pytorch ├── check_consistent.py ├── check_readable.py ├── flameStabML_train.py ├── flameTimewarpML_findscale.py ├── flameTimewarpML_findscale_ml.py ├── flameTimewarpML_findscale_scipy.py ├── flameTimewarpML_finetune.py ├── flameTimewarpML_inference.py ├── flameTimewarpML_train.py ├── flameTimewarpML_validate.py ├── hub │ └── checkpoints │ │ └── README.md ├── models │ ├── archived │ │ ├── flownet2_v004 copy.py │ │ ├── flownet2_v004.py │ │ ├── flownet4_v001a.py │ │ ├── flownet4_v001aa.py │ │ ├── flownet4_v001ab.py │ │ ├── flownet4_v001b.py │ │ ├── flownet4_v001c.py │ │ ├── flownet4_v001d.py │ │ ├── flownet4_v001e.py │ │ ├── flownet4_v001ea.py │ │ ├── flownet4_v001eb3x.py │ │ ├── flownet4_v001eb3xx.py │ │ ├── flownet4_v001eb5.py │ │ ├── flownet4_v001eb8x_fibonacci5.py │ │ ├── flownet4_v001eb_fibonacci5.py │ │ ├── flownet4_v001ec.py │ │ ├── flownet4_v001ed.py │ │ ├── flownet4_v001ee.py │ │ ├── flownet4_v001ef.py │ │ ├── flownet4_v001eg.py │ │ ├── flownet4_v002.py │ │ ├── flownet4_v002a.py │ │ ├── flownet4_v002b.py │ │ ├── flownet4_v002c.py │ │ ├── flownet4_v002d.py │ │ ├── flownet4_v002e.py │ │ ├── flownet4_v002f.py │ │ 
├── flownet4_v002g.py │ │ ├── flownet4_v002h.py │ │ ├── flownet4_v002i.py │ │ ├── flownet4_v002j.py │ │ ├── flownet4_v002k.py │ │ ├── flownet4_v002l.py │ │ ├── flownet4_v002la.py │ │ ├── flownet4_v002lb.py │ │ ├── flownet4_v002lc.py │ │ ├── flownet4_v002ld.py │ │ ├── flownet4_v002le.py │ │ ├── flownet4_v002lf.py │ │ ├── flownet4_v002lg.py │ │ ├── flownet4_v002lh.py │ │ ├── flownet4_v002li.py │ │ ├── flownet4_v002lj.py │ │ ├── flownet4_v002lk.py │ │ ├── flownet4_v002ll.py │ │ ├── flownet4_v002lm.py │ │ ├── flownet4_v002ln.py │ │ ├── flownet4_v002m.py │ │ ├── flownet4_v003.py │ │ ├── flownet4_v003a.py │ │ ├── flownet4_v003b.py │ │ ├── flownet4_v003c.py │ │ ├── flownet4_v003d.py │ │ ├── flownet4_v004.py │ │ └── flownet4_v004_old01.py │ ├── flownet4_v001.py │ ├── flownet4_v001_baseline.py │ ├── flownet4_v001_baseline_sst.py │ ├── flownet4_v001a.py │ ├── flownet4_v001b.py │ ├── flownet4_v001b_v01.py │ ├── flownet4_v001b_v02.py │ ├── flownet4_v001b_v03.py │ ├── flownet4_v001batt.py │ ├── flownet4_v001beatles.py │ ├── flownet4_v001c.py │ ├── flownet4_v001d.py │ ├── flownet4_v001e.py │ ├── flownet4_v001ea.py │ ├── flownet4_v001eb.py │ ├── flownet4_v001eb_dv01.py │ ├── flownet4_v001eb_dv02.py │ ├── flownet4_v001eb_dv03.py │ ├── flownet4_v001eb_dv04.py │ ├── flownet4_v001eb_dv04a.py │ ├── flownet4_v001eb_dv04b.py │ ├── flownet4_v001eb_dv04c.py │ ├── flownet4_v001eb_dv04d.py │ ├── flownet4_v001eb_dv04e.py │ ├── flownet4_v001eb_dv04ea.py │ ├── flownet4_v001eb_dv04f.py │ ├── flownet4_v001ec.py │ ├── flownet4_v001f.py │ ├── flownet4_v001f_v02.py │ ├── flownet4_v001f_v03.py │ ├── flownet4_v001f_v04.py │ ├── flownet4_v001f_v04_001.py │ ├── flownet4_v001f_v04_002.py │ ├── flownet4_v001f_v04_003.py │ ├── flownet4_v001fa_v04_001.py │ ├── flownet4_v001fb_v04_001.py │ ├── flownet4_v001fc_v04_001.py │ ├── flownet4_v001fd_v04_001.py │ ├── flownet4_v001fe_v04_001.py │ ├── flownet4_v001fp4.py │ ├── flownet4_v001hp5.py │ ├── flownet4_v001hpass.py │ ├── flownet4_v001orig.py │ ├── 
flownet4_v002.py │ ├── flownet4_v003.py │ ├── flownet4_v004.py │ ├── flownet5_v001.py │ ├── flownet5_v002.py │ ├── flownet5_v003.py │ ├── flownet5_v004.py │ ├── flownet5_v005.py │ ├── flownet5_v006.py │ ├── flownet5_v007.py │ ├── flownet5_v008.py │ ├── flownet5_v009.py │ ├── flownet5_v010.py │ ├── stabnet4_v000.py │ ├── stabnet4_v001.py │ ├── stabnet4_v001f_002.py │ ├── stabnet4_v001fa_v04_001.py │ ├── stabnet4_v001fb_v04_001.py │ ├── stabnet4_v001fc_v04_001.py │ ├── stabnet4_v001fd_v04_001.py │ ├── stabnet4_v001fe_v04_001.py │ ├── stabnet4_v002.py │ ├── stabnet4_v003.py │ ├── stabnet4_v004.py │ ├── stabnet4_v005.py │ ├── stabnet4_v006.py │ ├── stabnet4_v007.py │ ├── stabnet4_v008.py │ ├── stabnet4_v009.py │ └── stabnet4_v010.py ├── move_to_folders.py ├── recompress.py ├── rename_keys.py ├── resize_exrs_half.py ├── resize_exrs_min.py ├── resize_exrs_width.py ├── retime.py └── speedtest.py ├── requirements.txt ├── retime.sh ├── speedtest.sh ├── test.timewarp_node ├── train.sh ├── train_stab.sh └── validate.sh /.gitattributes: -------------------------------------------------------------------------------- 1 | # *.pth filter=lfs diff=lfs merge=lfs -text 2 | # *.exr filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bundle/miniconda.package/Miniconda* 2 | bundle/miniconda.package/packages/* 3 | bundle/bin/* 4 | !bundle/miniconda.package/packages/README 5 | settings.json 6 | *.pyc 7 | *.exr 8 | *.png 9 | flameTimewarpML.package 10 | flameTimewarpMLenv.tar.bz2 11 | flameTimewarpML.bundle.tar 12 | test_* 13 | env 14 | appenv 15 | .DS_Store 16 | *.pth 17 | *.pkl 18 | 19 | !packages/README.md 20 | !packages/.miniconda/README.md 21 | !packages/.miniconda/appenv/README.md 22 | 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 
31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | share/python-wheels/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | MANIFEST 50 | 51 | # PyInstaller 52 | # Usually these files are written by a python script from a template 53 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 54 | *.manifest 55 | *.spec 56 | 57 | # Installer logs 58 | pip-log.txt 59 | pip-delete-this-directory.txt 60 | 61 | # Unit test / coverage reports 62 | htmlcov/ 63 | .tox/ 64 | .nox/ 65 | .coverage 66 | .coverage.* 67 | .cache 68 | nosetests.xml 69 | coverage.xml 70 | *.cover 71 | *.py,cover 72 | .hypothesis/ 73 | .pytest_cache/ 74 | cover/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | db.sqlite3 84 | db.sqlite3-journal 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | .pybuilder/ 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | 107 | # pyenv 108 | # For a library or package, you might want to ignore these files since the code is 109 | # intended to run in multiple environments; otherwise, check them in: 110 | # .python-version 111 | 112 | # pipenv 113 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 114 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 115 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 116 | # install all needed dependencies. 117 | #Pipfile.lock 118 | 119 | # poetry 120 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
121 | # This is especially recommended for binary packages to ensure reproducibility, and is more 122 | # commonly ignored for libraries. 123 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 124 | #poetry.lock 125 | 126 | # pdm 127 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 128 | #pdm.lock 129 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 130 | # in version control. 131 | # https://pdm.fming.dev/#use-with-ide 132 | .pdm.toml 133 | 134 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 135 | __pypackages__/ 136 | 137 | # Celery stuff 138 | celerybeat-schedule 139 | celerybeat.pid 140 | 141 | # SageMath parsed files 142 | *.sage.py 143 | 144 | # Environments 145 | .env 146 | .venv 147 | env/ 148 | venv/ 149 | ENV/ 150 | env.bak/ 151 | venv.bak/ 152 | 153 | # Spyder project settings 154 | .spyderproject 155 | .spyproject 156 | 157 | # Rope project settings 158 | .ropeproject 159 | 160 | # mkdocs documentation 161 | /site 162 | 163 | # mypy 164 | .mypy_cache/ 165 | .dmypy.json 166 | dmypy.json 167 | 168 | # Pyre type checker 169 | .pyre/ 170 | 171 | # pytype static type analyzer 172 | .pytype/ 173 | 174 | # Cython debug symbols 175 | cython_debug/ 176 | 177 | # PyCharm 178 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 179 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 180 | # and can be added to the global gitignore or merged into this file. For a more nuclear 181 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
182 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Andriy Toloshny 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## flameTimewarpML 2 | 3 | Previous versions: [Readme v0.4.4](https://github.com/talosh/flameTimewarpML/blob/main/README_v044.md) 4 | 5 | ## Table of Contents 6 | - [Status](#status) 7 | - [Installation](#installation) 8 | - [Training](#training) 9 | 10 | ### Installation 11 | 12 | #### Installing from package - Linux 13 | * Download parts and run 14 | ```bash 15 | cat flameTimewarpML_v0.4.5-dev003.Linux.tar.gz.part-* > flameTimewarpML_v0.4.5-dev003.Linux.tar.gz 16 | tar -xvf flameTimewarpML_v0.4.5-dev003.Linux.tar.gz 17 | ``` 18 | * Place "flameTimewarpML" folder into Flame python hooks path. 19 | 20 | #### Installing from package - MacOS 21 | * Unpack and run fix-xattr.command 22 | * Place "flameTimewarpML" folder into Flame python hooks path. 23 | 24 | ### Installing and configuring python environment manually 25 | 26 | * pre-configured miniconda environment should be placed into hidden "packages/.miniconda" folder 27 | * the folder is hidden (starts with ".") in order to keep Flame from scanning it looking for python hooks 28 | * pre-configured python environment usually packed with release tar file 29 | 30 | * download Miniconda for Mac or Linux (I'm using python 3.11 for tests) from 31 | 32 | 33 | * install downloaded Miniconda python distribution, use "-p" to select install location. 
For example: 34 | 35 | ```bash 36 | sh ~/Downloads/Miniconda3-py311_24.7.1-0-Linux-x86_64.sh -bfsm -p ~/miniconda3 37 | ``` 38 | 39 | * Activate and clone default environment into another named "appenv" 40 | 41 | ```bash 42 | eval "$(~/miniconda3/bin/conda shell.bash hook)" 43 | conda create --name appenv 44 | conda activate appenv 45 | ``` 46 | 47 | * Install dependency libraries 48 | 49 | ```bash 50 | conda install python=3.11 conda-forge::openimageio conda-forge::py-openimageio 51 | conda install pyqt 52 | conda install conda-pack 53 | ``` 54 | 55 | * Install pytorch. Please look up exact commands depending on OS and Cuda versions 56 | 57 | * Linux example 58 | ```bash 59 | conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia 60 | ``` 61 | 62 | * MacOS example: 63 | 64 | ```bash 65 | conda install pytorch::pytorch torchvision -c pytorch 66 | ``` 67 | 68 | * Install rest of the dependencies 69 | ```bash 70 | pip install -r {flameTimewarpML folder}/requirements.txt 71 | ``` 72 | 73 | * Pack the "appenv" environment into a portable tar file 74 | 75 | ```bash 76 | conda pack --ignore-missing-files -n appenv 77 | ``` 78 | 79 | * Unpack environment to flameTimewarpML folder 80 | 81 | ```bash 82 | mkdir {flameTimewarpML folder}/packages/.miniconda/appenv/ 83 | tar xvf appenv.tar.gz -C {flameTimewarpML folder}/packages/.miniconda/appenv/ 84 | ``` 85 | 86 | * Remove environment tarball 87 | 88 | ```bash 89 | rm appenv.tar.gz 90 | ``` 91 | 92 | ### Training 93 | 94 | 95 | 96 | #### Finetune for specific shot or set of shots 97 | Finetune option is available as a menu item starting from 0.4.5 dev 003 98 | 99 | #### Finetune using command line script 100 | 101 | Export as Linear ACEScg (AP1) Uncompressed OpenEXR sequence. 102 | Export your shots so each shot is in a separate folder. 103 | Copy pre-trained model file (large or lite) to the file to train with. 104 | If motion is fast place the whole shot or its parts with fast motion to "fast" folder. 
105 | 106 | ```bash 107 | cd {flameTimewarpML folder} 108 | ./train.sh --state_file {Path to copied state file}/MyShot001.pth --generalize 1 --lr 4e-6 --acescc 0 --onecycle 1000 {Path to shot}/{fast}/ 109 | ``` 110 | 111 | * Change number after "--onecycle" to set number of runs. 112 | * Use "--acescc 0" to train in Linear or to retain input colourspace, "--acescc 100" to convert all samples to Log. 113 | * Use "--frame_size" to modify training samples size 114 | * Preview last 9 training patches in "{Path to shot}/preview" folder 115 | * Use "--preview" to modify how frequently preview files are saved 116 | 117 | #### Train your own model 118 | ```bash 119 | cd {flameTimewarpML folder} 120 | ./train.sh --state_file {Path to MyModel}/MyModel001.pth --model flownet4_v004 --batch_size 4 {Path to Dataset}/ 121 | ``` 122 | 123 | #### Batch size and learning rate 124 | The batch size and learning rate are two crucial hyperparameters that significantly affect the training process and empirical tuning is necessary here. 125 | When the batch size is increased, the learning rate can often be increased as well. A common heuristic is the linear scaling rule: when you multiply the batch size by a factor 126 | k, you can also multiply the learning rate by k. Another approach is the square root scaling rule: when you multiply the batch size by k multiply the learning rate by sqrt(k) 127 | 128 | 129 | #### Dataset preparation 130 | Training script will scan all the folders under a given path and will compose training samples out of folders where .exr files are found. 131 | Only Uncompressed OpenEXR files are supported at the moment. 132 | 133 | Export your shots so each shot is in a separate folder. 134 | There are two magic words in shot path: "fast" and "slow" 135 | When "fast" is in path 3-frame window will be used for those shots. 136 | When "slow" is in path 9-frame window will be used. 137 | Default window size is 5 frames. 
138 | Put shots with fast motion in "fast" folder and shots where motion is slow and continuous to "slow" folder to let the model learn more intermediate ratios. 139 | 140 | - Scene001/ 141 | - slow/ 142 | - Shot001/ 143 | - Shot001.001.exr 144 | - Shot001.002.exr 145 | - ... 146 | - fast/ 147 | - Shot002 148 | - Shot002.001.exr 149 | - Shot002.002.exr 150 | - ... 151 | - normal/ 152 | - Shot003 153 | - Shot003.001.exr 154 | - Shot003.002.exr 155 | - ... 156 | 157 | #### Window size 158 | Samples for training are 3 frames and a ratio. The model is given the first and the last frame and tries to re-create the middle frame at the given ratio. 159 | 160 | [TODO] - add documentation 161 | -------------------------------------------------------------------------------- /README_v044.md: -------------------------------------------------------------------------------- 1 | ## flameTimewarpML 2 | 3 | Machine Learning frame interpolation tool for Autodesk Flame. 4 | 5 | Based on arXiv2020-RIFE, original implementation: 6 | 7 | ```data 8 | @article{huang2020rife, 9 | title={RIFE: Real-Time Intermediate Flow Estimation for Video Frame Interpolation}, 10 | author={Huang, Zhewei and Zhang, Tianyuan and Heng, Wen and Shi, Boxin and Zhou, Shuchang}, 11 | journal={arXiv preprint arXiv:2011.06294}, 12 | year={2020} 13 | } 14 | ``` 15 | 16 | Flame's animation curve interpolation code from Julik Tarkhanov: 17 | 18 | 19 | ## Installation 20 | 21 | * Creating Miniconda environment with dependencies: 22 | 23 | ```sh Miniconda3-py311_24.1.2-0-MacOSX-arm64.sh 24 | eval "$(~/miniconda3/bin/conda shell.zsh hook)" 25 | conda create --name twml --clone base 26 | conda install numpy 27 | pip install PySide6 28 | conda install pytorch::pytorch -c pytorch 29 | conda install conda-pack 30 | conda pack --ignore-missing-files -n twml 31 | tar xvf twml.tar.gz -C {flameTimewarpML folder}/packages/.miniconda/twml/ 32 | ``` 33 | 34 | 35 | ### Single workstation / Easy install 36 | 37 | * Download latest release 
from [Releases](https://github.com/talosh/flameTimewarpML/releases) page 38 | * All you need to do is put the .py file in /opt/Autodesk/shared/python and launch/relaunch flame. The first time it will unpack/install what it needs. It will give you a prompt to let you know when it’s finished. 39 | * It is possible to choose the installation folder. Default installation folders are: 40 | 41 | ```bash 42 | ~/Documents/flameTimewarpML # Mac 43 | ~/flameTimewarpML # Linux 44 | ``` 45 | 46 | ### Preferences location 47 | 48 | * In case you'd need to reset preferences 49 | 50 | ```bash 51 | ~/Library/Preferences/flameTimewarpML # Mac 52 | ~/.config/flameTimewarpML # Linux 53 | ``` 54 | 55 | ### Configuration via env variables 56 | 57 | * There is an option to set the working folder via an env variable 58 | Setting FLAMETWML_DEFAULT_WORK_FOLDER will set default folder and user will still be allowed to change it 59 | 60 | ```bash 61 | export FLAMETWML_DEFAULT_WORK_FOLDER=/Volumes/projects/my_timewarps/ 62 | ``` 63 | 64 | * Setting FLAMETWML_WORK_FOLDER will block user from changing it. 
This is read every time just before the tool is launched so one can use it with other pipeline tools to change it dynamically depending on context 65 | 66 | ```bash 67 | export FLAMETWML_WORK_FOLDER=/Volumes/projects/my_timewarps/ 68 | ``` 69 | 70 | * Setting FLAMETWML_HARDCOMMIT will cause the imported clip to be hard committed after import and all temporary files to be deleted 71 | 72 | ```bash 73 | export FLAMETWML_HARDCOMMIT=True 74 | ``` 75 | ### Updating / Downgrading PyTorch to match CUDA version 76 | 77 | This example will change Pytorch to cuda11 78 | 79 | * Get pytorch build for python 3.8 and cuda11: 80 | ``` 81 | https://download.pytorch.org/whl/cu110/torch-1.7.1%2Bcu110-cp38-cp38-linux_x86_64.whl 82 | ``` 83 | * Init miniconda3 environment: 84 | ``` 85 | ~/flameTimewarpML/bundle/init_env 86 | ``` 87 | * In new terminal window that opens check current pytorch cuda version: 88 | ``` 89 | python -c "import torch; print(torch.version.cuda)" 90 | ``` 91 | this should report 10.2 92 | * Update pytorch 93 | ``` 94 | pip3 install --upgrade ~/Downloads/torch-1.7.1+cu110-cp38-cp38-linux_x86_64.whl 95 | ``` 96 | * Check pytorch cuda version again: 97 | ``` 98 | python -c "import torch; print(torch.version.cuda)" 99 | ``` 100 | this now should report 11.0 101 | 102 | 103 | ### Centralised / Manual Install 104 | 105 | * It is possible to do easy install on one of the workstations to common location and then configure other via env variables (see Centralised configuration section) 106 | * flameTimewarpML is made of three components. 107 | 1. Miniconda3 isolated Python 3.8 environment with some additional dependency libraries (see requirements.txt) 108 | 2. A set of python scripts that are called from command line and should be running within Python 3.8 environment. Those two parts are not dependent on Flame and can be run as standalone tools. 109 | 3. Flame script to be run inside Flame. 
It has to know the location of two previous parts to be able to initialize Python 3.8 environment and then run command line tools to process image sequences. 110 | * Get the source from [Releases](https://github.com/talosh/flameTimewarpML/releases) page or directly from repo: 111 | 112 | ```bash 113 | git clone https://github.com/talosh/flameTimewarpML.git 114 | ``` 115 | 116 | * It is possible to share Python3 code location between platforms, but Miniconda3 python environment should be placed separately for Mac and Linux. 117 | This example will use: 118 | 119 | ```bash 120 | Python3 code location: 121 | /Volumes/software/flameTimewarpML # Mac 122 | /mnt/software/flameTimewarpML # Linux 123 | 124 | miniconda3 location: 125 | /Volumes/software/miniconda3 # Mac 126 | /mnt/software/miniconda3 # Linux 127 | ``` 128 | 129 | * Do not place those folders inside Flame's python hooks folder. 130 | The only file that should be placed in Flame hooks folder is flameTimewarpML.py 131 | 132 | * Create folders 133 | 134 | ```bash 135 | mkdir -p /Volumes/software/flameTimewarpML # Mac 136 | mkdir -p /Volumes/software/miniconda3 # Mac 137 | 138 | mkdir -p /mnt/software/flameTimewarpML # Linux 139 | mkdir -p /mnt/software/miniconda3 # Linux 140 | ``` 141 | 142 | * Copy the contents of 'bundle' folder to the code location. 143 | 144 | ```bash 145 | cp -a bundle/* /Volumes/software/flameTimewarpML/ # Mac 146 | cp -a bundle/* /mnt/software/flameTimewarpML/ # Linux 147 | ``` 148 | 149 | * Download Miniconda3 for Python 3.8 from 150 | Shell installer is used for this example. 151 | 152 | * Install Miniconda3 to centralized location: 153 | 154 | ```bash 155 | sh Miniconda3-latest-MacOSX-x86_64.sh -bu -p /Volumes/software/miniconda3/ # Mac 156 | sh Miniconda3-latest-Linux-x86_64.sh -bu -p /mnt/software/miniconda3/ # Linux 157 | ``` 158 | 159 | * In case you'd like to move Miniconda3 installation later you'll have to re-install it again. 
160 | Please refer to Anaconda docs: 161 | 162 | * Init installed Miniconda3 environment: 163 | 164 | ```bash 165 | /Volumes/software/flameTimewarpML/init_env /Volumes/software/miniconda3/ # Mac 166 | /mnt/software/flameTimewarpML/init_env /mnt/software/miniconda3/ # Linux 167 | ``` 168 | 169 | The script will open new konsole window. 170 | 171 | * In case you do it over ssh remotely on Linux edit the line at the very end of init_env 172 | 173 | ```bash 174 | cmd_prefix = 'konsole -e /bin/bash --rcfile ' + tmp_bash_rc_file 175 | ``` 176 | 177 | ```bash 178 | cmd_prefix = '/bin/bash --rcfile ' + tmp_bash_rc_file 179 | ``` 180 | 181 | * To check if environment is properly initialized: there should be (base) before shell prompt. python --version should give Python 3.8.5 or greater 182 | 183 | * From this shell prompt install dependency libraries: 184 | 185 | ```bash 186 | pip3 install -r /Volumes/software/flameTimewarpML/requirements.txt # Mac 187 | pip3 install -r /mnt/software/flameTimewarpML/requirements.txt # Linux 188 | ``` 189 | 190 | * In case pip fails to find a match for torch, install torch with the '-f' flag: 191 | 192 | ```bash 193 | pip3 install torch=="1.10.2+cu111" -f https://download.pytorch.org/whl/torch_stable.html 194 | ``` 195 | 196 | * It is possible to download the packages and install it without internet connection. Install Miniconda3 on the machine that is connected to internet and download dependency packages: 197 | 198 | ```bash 199 | pip3 download -r bundle/requirements.txt -d packages_folder 200 | ``` 201 | 202 | then it is possible to install packages from folder: 203 | 204 | ```bash 205 | pip3 install -r /Volumes/software/flameTimewarpML/requirements.txt --no-index --find-links /Volumes/software/flameTimewarpML/packages_folder 206 | ``` 207 | 208 | * Copy flameTimewarpML.py to your Flame python scripts folder 209 | 210 | * set FLAMETWML_BUNDLE, FLAMETWML_MINICONDA environment variables to point code and miniconda locations. 
import torch
import torch.nn as nn
import torch.nn.functional as F


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Build a transposed convolution followed by a per-channel PReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels (also the PReLU parameter count).
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.

    Returns:
        ``nn.Sequential(ConvTranspose2d, PReLU)``.

    The original implementation accepted ``kernel_size``/``stride``/``padding``
    but always built the layer with 4/2/1. The arguments are now honoured; the
    defaults are unchanged, so existing callers get the same layer as before.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        ),
        nn.PReLU(out_planes),
    )


def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased 2D convolution without activation, wrapped in ``nn.Sequential``.

    The ``nn.Sequential`` wrapper is kept so parameter names in a state dict
    stay the same as in the original module.
    """
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
    )


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased 2D convolution followed by a per-channel PReLU."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes),
    )


class IFBlock(nn.Module):
    """One flow-estimation stage: strided encoder, three residual conv blocks, 4-channel head.

    ``conv0`` halves the resolution twice (two stride-2 convolutions) and
    ``conv1`` doubles it once, so the raw output is at half the resolution of
    the (rescaled) block input.
    """

    def __init__(self, in_planes, c=64):
        """
        Args:
            in_planes: Channels of the tensor fed to ``conv0`` (input, plus flow if given).
            c: Base width; the residual blocks operate on ``2*c`` channels.
        """
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c, 3, 2, 1),
            conv(c, 2*c, 3, 2, 1),
        )
        self.convblock0 = nn.Sequential(
            conv(2*c, 2*c),
            conv(2*c, 2*c),
        )
        self.convblock1 = nn.Sequential(
            conv(2*c, 2*c),
            conv(2*c, 2*c),
        )
        self.convblock2 = nn.Sequential(
            conv(2*c, 2*c),
            conv(2*c, 2*c),
        )
        self.conv1 = nn.ConvTranspose2d(2*c, 4, 4, 2, 1)

    def forward(self, x, flow=None, scale=1):
        """Estimate a 4-channel flow tensor from ``x`` (and an optional prior ``flow``).

        Args:
            x: Input tensor; it is bilinearly resized by ``1 / scale`` first.
            flow: Optional prior flow. It is resized like ``x``, its values are
                multiplied by ``1 / scale``, and it is concatenated to ``x``
                along the channel axis.
            scale: Working-resolution divisor.

        Returns:
            A 4-channel tensor. When ``scale != 1`` it is resized by ``scale``
            and its values are multiplied by ``scale``.
        """
        x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False)
        # Identity test: "!= None" on a tensor relies on tensor comparison overloads.
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear", align_corners=False) * (1. / scale)
            x = torch.cat((x, flow), 1)
        x = self.conv0(x)
        x = self.convblock0(x) + x
        x = self.convblock1(x) + x
        x = self.convblock2(x) + x
        flow = self.conv1(x)
        if scale != 1:
            flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False) * scale
        return flow


class IFNet(nn.Module):
    """Three-stage coarse-to-fine flow network built from ``IFBlock`` stages."""

    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(6, c=80)
        self.block1 = IFBlock(10, c=80)
        self.block2 = IFBlock(10, c=80)
        # Set through set_device(); while None, forward() falls back to the input's device.
        self.device = None

    def set_device(self, idevice):
        """Store the device that is passed on to ``warp`` in ``forward``."""
        self.device = idevice

    def forward(self, x, scale_list=(4, 2, 1), scale=1.0):
        """Estimate flow for a 6-channel input made of two stacked 3-channel images.

        Args:
            x: Tensor whose channels ``[:3]`` and ``[3:]`` are the two images.
            scale_list: Per-stage ``scale`` values for block0, block1 and block2.
                Only indexed, never mutated; the default is a tuple to avoid a
                shared mutable default argument.
            scale: Global resize factor applied to ``x`` before estimation.

        Returns:
            ``(F3, [F1, F2, F3])``: the final accumulated flow and the list of
            accumulated flows after each stage. When ``scale != 1.0`` the
            returned ``F3`` is resized by ``1 / scale`` and divided by ``scale``.
        """
        # Imported here so the layer builders and IFBlock can be imported
        # without the project-local warplayer module.
        from warplayer import warp

        # Previously this raised AttributeError when set_device() had not been called.
        # NOTE(review): assumes warp() accepts the input tensor's device - confirm against warplayer.
        device = self.device if self.device is not None else x.device

        x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
        flow0 = self.block0(x, scale=scale_list[0])
        F1 = flow0
        # Block output is at half resolution; upsample 2x and scale the values to match.
        F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
        warped_img0 = warp(x[:, :3], F1_large[:, :2], device=device)
        warped_img1 = warp(x[:, 3:], F1_large[:, 2:4], device=device)
        flow1 = self.block1(torch.cat((warped_img0, warped_img1), 1), F1_large, scale=scale_list[1])
        F2 = (flow0 + flow1)
        F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
        warped_img0 = warp(x[:, :3], F2_large[:, :2], device=device)
        warped_img1 = warp(x[:, 3:], F2_large[:, 2:4], device=device)
        flow2 = self.block2(torch.cat((warped_img0, warped_img1), 1), F2_large, scale=scale_list[2])
        F3 = (flow0 + flow1 + flow2)
        if scale != 1.0:
            F3 = F.interpolate(F3, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale
        return F3, [F1, F2, F3]
-------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/VERSION: -------------------------------------------------------------------------------- 1 | 0.0.1 dev 001 2 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/Activate.ps1: -------------------------------------------------------------------------------- 1 | <# 2 | .Synopsis 3 | Activate a Python virtual environment for the current PowerShell session. 4 | 5 | .Description 6 | Pushes the python executable for a virtual environment to the front of the 7 | $Env:PATH environment variable and sets the prompt to signify that you are 8 | in a Python virtual environment. Makes use of the command line switches as 9 | well as the `pyvenv.cfg` file values present in the virtual environment. 10 | 11 | .Parameter VenvDir 12 | Path to the directory that contains the virtual environment to activate. The 13 | default value for this is the parent of the directory that the Activate.ps1 14 | script is located within. 15 | 16 | .Parameter Prompt 17 | The prompt prefix to display when this virtual environment is activated. By 18 | default, this prompt is the name of the virtual environment folder (VenvDir) 19 | surrounded by parentheses and followed by a single space (ie. '(.venv) '). 20 | 21 | .Example 22 | Activate.ps1 23 | Activates the Python virtual environment that contains the Activate.ps1 script. 24 | 25 | .Example 26 | Activate.ps1 -Verbose 27 | Activates the Python virtual environment that contains the Activate.ps1 script, 28 | and shows extra information about the activation as it executes. 29 | 30 | .Example 31 | Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv 32 | Activates the Python virtual environment located in the specified location. 
33 | 34 | .Example 35 | Activate.ps1 -Prompt "MyPython" 36 | Activates the Python virtual environment that contains the Activate.ps1 script, 37 | and prefixes the current prompt with the specified string (surrounded in 38 | parentheses) while the virtual environment is active. 39 | 40 | .Notes 41 | On Windows, it may be required to enable this Activate.ps1 script by setting the 42 | execution policy for the user. You can do this by issuing the following PowerShell 43 | command: 44 | 45 | PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser 46 | 47 | For more information on Execution Policies: 48 | https://go.microsoft.com/fwlink/?LinkID=135170 49 | 50 | #> 51 | Param( 52 | [Parameter(Mandatory = $false)] 53 | [String] 54 | $VenvDir, 55 | [Parameter(Mandatory = $false)] 56 | [String] 57 | $Prompt 58 | ) 59 | 60 | <# Function declarations --------------------------------------------------- #> 61 | 62 | <# 63 | .Synopsis 64 | Remove all shell session elements added by the Activate script, including the 65 | addition of the virtual environment's Python executable from the beginning of 66 | the PATH variable. 67 | 68 | .Parameter NonDestructive 69 | If present, do not remove this function from the global namespace for the 70 | session. 
function global:deactivate ([switch]$NonDestructive) {
    # Undo everything the Activate script changed in this session.
    # Each step is guarded by a check, so calling this when nothing is
    # active (as the activation script itself does) is harmless.

    # Revert to original values

    # The prior prompt:
    # (the saved prompt function is copied back, then the backup is removed)
    if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
        Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
        Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
    }

    # The prior PYTHONHOME:
    if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
        Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
        Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
    }

    # The prior PATH:
    if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
        Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
        Remove-Item -Path Env:_OLD_VIRTUAL_PATH
    }

    # Just remove the VIRTUAL_ENV altogether:
    if (Test-Path -Path Env:VIRTUAL_ENV) {
        Remove-Item -Path env:VIRTUAL_ENV
    }

    # Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
    # (-Force is needed because the variable is created with -Option ReadOnly)
    if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
        Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
    }

    # Leave deactivate function in the global namespace if requested:
    # (without -NonDestructive the function removes itself)
    if (-not $NonDestructive) {
        Remove-Item -Path function:deactivate
    }
}
function Get-PyVenvConfig(
    [String]
    $ConfigDir
) {
    # Parse `key = value` lines of the pyvenv.cfg found in $ConfigDir and
    # return them as a hashtable. An empty hashtable is returned when the
    # file cannot be resolved.
    Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"

    # Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
    $pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue

    # An empty map will be returned if no config file is found.
    $pyvenvConfig = @{ }

    if ($pyvenvConfigPath) {

        Write-Verbose "File exists, parse `key = value` lines"
        $pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath

        $pyvenvConfigContent | ForEach-Object {
            # Split on the first `=` only, so values may themselves contain `=`.
            $keyval = $PSItem -split "\s*=\s*", 2
            if ($keyval[0] -and $keyval[1]) {
                $val = $keyval[1]

                # Remove extraneous quotations around a string value.
                # Only strip when the value is wrapped in a MATCHING pair of
                # quotes: a lone quote character or an unbalanced value is
                # kept verbatim instead of losing its last character (or
                # throwing from Substring on a one-character value).
                $first = $val.Substring(0, 1)
                if ($val.Length -ge 2 -and "'""".Contains($first) -and $val.EndsWith($first)) {
                    $val = $val.Substring(1, $val.Length - 2)
                }

                $pyvenvConfig[$keyval[0]] = $val
                Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
            }
        }
    }
    return $pyvenvConfig
}
175 | if ($VenvDir) { 176 | Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values" 177 | } 178 | else { 179 | Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir." 180 | $VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/") 181 | Write-Verbose "VenvDir=$VenvDir" 182 | } 183 | 184 | # Next, read the `pyvenv.cfg` file to determine any required value such 185 | # as `prompt`. 186 | $pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir 187 | 188 | # Next, set the prompt from the command line, or the config file, or 189 | # just use the name of the virtual environment folder. 190 | if ($Prompt) { 191 | Write-Verbose "Prompt specified as argument, using '$Prompt'" 192 | } 193 | else { 194 | Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value" 195 | if ($pyvenvCfg -and $pyvenvCfg['prompt']) { 196 | Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'" 197 | $Prompt = $pyvenvCfg['prompt']; 198 | } 199 | else { 200 | Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virutal environment)" 201 | Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'" 202 | $Prompt = Split-Path -Path $venvDir -Leaf 203 | } 204 | } 205 | 206 | Write-Verbose "Prompt = '$Prompt'" 207 | Write-Verbose "VenvDir='$VenvDir'" 208 | 209 | # Deactivate any currently active virtual environment, but leave the 210 | # deactivate function in place. 211 | deactivate -nondestructive 212 | 213 | # Now set the environment variable VIRTUAL_ENV, used by many tools to determine 214 | # that there is an activated venv. 
215 | $env:VIRTUAL_ENV = $VenvDir 216 | 217 | if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) { 218 | 219 | Write-Verbose "Setting prompt to '$Prompt'" 220 | 221 | # Set the prompt to include the env name 222 | # Make sure _OLD_VIRTUAL_PROMPT is global 223 | function global:_OLD_VIRTUAL_PROMPT { "" } 224 | Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT 225 | New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt 226 | 227 | function global:prompt { 228 | Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) " 229 | _OLD_VIRTUAL_PROMPT 230 | } 231 | } 232 | 233 | # Clear PYTHONHOME 234 | if (Test-Path -Path Env:PYTHONHOME) { 235 | Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME 236 | Remove-Item -Path Env:PYTHONHOME 237 | } 238 | 239 | # Add the venv to the PATH 240 | Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH 241 | $Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH" 242 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/activate: -------------------------------------------------------------------------------- 1 | # This file must be used with "source bin/activate" *from bash* 2 | # you cannot run it directly 3 | 4 | deactivate () { 5 | # reset old environment variables 6 | if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then 7 | PATH="${_OLD_VIRTUAL_PATH:-}" 8 | export PATH 9 | unset _OLD_VIRTUAL_PATH 10 | fi 11 | if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then 12 | PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}" 13 | export PYTHONHOME 14 | unset _OLD_VIRTUAL_PYTHONHOME 15 | fi 16 | 17 | # This should detect bash and zsh, which have a hash command that must 18 | # be called to get it to forget past commands. 
Without forgetting 19 | # past commands the $PATH changes we made may not be respected 20 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then 21 | hash -r 2> /dev/null 22 | fi 23 | 24 | if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then 25 | PS1="${_OLD_VIRTUAL_PS1:-}" 26 | export PS1 27 | unset _OLD_VIRTUAL_PS1 28 | fi 29 | 30 | unset VIRTUAL_ENV 31 | if [ ! "${1:-}" = "nondestructive" ] ; then 32 | # Self destruct! 33 | unset -f deactivate 34 | fi 35 | } 36 | 37 | # unset irrelevant variables 38 | deactivate nondestructive 39 | 40 | VIRTUAL_ENV="/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1" 41 | export VIRTUAL_ENV 42 | 43 | _OLD_VIRTUAL_PATH="$PATH" 44 | PATH="$VIRTUAL_ENV/bin:$PATH" 45 | export PATH 46 | 47 | # unset PYTHONHOME if set 48 | # this will fail if PYTHONHOME is set to the empty string (which is bad anyway) 49 | # could use `if (set -u; : $PYTHONHOME) ;` in bash 50 | if [ -n "${PYTHONHOME:-}" ] ; then 51 | _OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}" 52 | unset PYTHONHOME 53 | fi 54 | 55 | if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then 56 | _OLD_VIRTUAL_PS1="${PS1:-}" 57 | PS1="(uk.ltd.filmlight.rife_retime.v1) ${PS1:-}" 58 | export PS1 59 | fi 60 | 61 | # This should detect bash and zsh, which have a hash command that must 62 | # be called to get it to forget past commands. Without forgetting 63 | # past commands the $PATH changes we made may not be respected 64 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then 65 | hash -r 2> /dev/null 66 | fi 67 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/activate.csh: -------------------------------------------------------------------------------- 1 | # This file must be used with "source bin/activate.csh" *from csh*. 2 | # You cannot run it directly. 3 | # Created by Davide Di Blasi . 
4 | # Ported to Python 3.3 venv by Andrew Svetlov 5 | 6 | alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; test "\!:*" != "nondestructive" && unalias deactivate' 7 | 8 | # Unset irrelevant variables. 9 | deactivate nondestructive 10 | 11 | setenv VIRTUAL_ENV "/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1" 12 | 13 | set _OLD_VIRTUAL_PATH="$PATH" 14 | setenv PATH "$VIRTUAL_ENV/bin:$PATH" 15 | 16 | 17 | set _OLD_VIRTUAL_PROMPT="$prompt" 18 | 19 | if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then 20 | set prompt = "(uk.ltd.filmlight.rife_retime.v1) $prompt" 21 | endif 22 | 23 | alias pydoc python -m pydoc 24 | 25 | rehash 26 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/activate.fish: -------------------------------------------------------------------------------- 1 | # This file must be used with "source /bin/activate.fish" *from fish* 2 | # (https://fishshell.com/); you cannot run it directly. 3 | 4 | function deactivate -d "Exit virtual environment and return to normal shell environment" 5 | # reset old environment variables 6 | if test -n "$_OLD_VIRTUAL_PATH" 7 | set -gx PATH $_OLD_VIRTUAL_PATH 8 | set -e _OLD_VIRTUAL_PATH 9 | end 10 | if test -n "$_OLD_VIRTUAL_PYTHONHOME" 11 | set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME 12 | set -e _OLD_VIRTUAL_PYTHONHOME 13 | end 14 | 15 | if test -n "$_OLD_FISH_PROMPT_OVERRIDE" 16 | functions -e fish_prompt 17 | set -e _OLD_FISH_PROMPT_OVERRIDE 18 | functions -c _old_fish_prompt fish_prompt 19 | functions -e _old_fish_prompt 20 | end 21 | 22 | set -e VIRTUAL_ENV 23 | if test "$argv[1]" != "nondestructive" 24 | # Self-destruct! 25 | functions -e deactivate 26 | end 27 | end 28 | 29 | # Unset irrelevant variables. 
30 | deactivate nondestructive 31 | 32 | set -gx VIRTUAL_ENV "/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1" 33 | 34 | set -gx _OLD_VIRTUAL_PATH $PATH 35 | set -gx PATH "$VIRTUAL_ENV/bin" $PATH 36 | 37 | # Unset PYTHONHOME if set. 38 | if set -q PYTHONHOME 39 | set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME 40 | set -e PYTHONHOME 41 | end 42 | 43 | if test -z "$VIRTUAL_ENV_DISABLE_PROMPT" 44 | # fish uses a function instead of an env var to generate the prompt. 45 | 46 | # Save the current fish_prompt function as the function _old_fish_prompt. 47 | functions -c fish_prompt _old_fish_prompt 48 | 49 | # With the original prompt function renamed, we can override with our own. 50 | function fish_prompt 51 | # Save the return status of the last command. 52 | set -l old_status $status 53 | 54 | # Output the venv prompt; color taken from the blue of the Python logo. 55 | printf "%s%s%s" (set_color 4B8BBE) "(uk.ltd.filmlight.rife_retime.v1) " (set_color normal) 56 | 57 | # Restore the return status of the previous command. 58 | echo "exit $old_status" | . 59 | # Output the original/"old" prompt. 
60 | _old_fish_prompt 61 | end 62 | 63 | set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV" 64 | end 65 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/easy_install: -------------------------------------------------------------------------------- 1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import sys 5 | from setuptools.command.easy_install import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/easy_install-3.9: -------------------------------------------------------------------------------- 1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import sys 5 | from setuptools.command.easy_install import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/pip: -------------------------------------------------------------------------------- 1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import sys 5 | from pip._internal.cli.main import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/pip3: -------------------------------------------------------------------------------- 1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9 2 | # -*- coding: utf-8 -*- 3 | import re 4 
| import sys 5 | from pip._internal.cli.main import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/pip3.9: -------------------------------------------------------------------------------- 1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9 2 | # -*- coding: utf-8 -*- 3 | import re 4 | import sys 5 | from pip._internal.cli.main import main 6 | if __name__ == '__main__': 7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) 8 | sys.exit(main()) 9 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/python: -------------------------------------------------------------------------------- 1 | python3.9 -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/python3: -------------------------------------------------------------------------------- 1 | python3.9 -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/bin/python3.9: -------------------------------------------------------------------------------- 1 | /usr/bin/python3.9 -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/flexi.py: -------------------------------------------------------------------------------- 1 | # Flexi Python module 2 | # 3 | # Copyright (c) 2021 FilmLight. All rights reserved. 4 | # Unpublished - rights reserved under the copyright laws of the 5 | # United States. Use of a copyright notice is precautionary only 6 | # and does not imply publication or disclosure. 7 | # This software contains confidential information and trade secrets 8 | # of FilmLight Limited.
# Globals
Flexi_registered_effects = {}  # effect object -> 1 (one entry per distinct effect)
Flexi_effect_for_id = {}       # (identifier, version) -> effect object
Flexi_ws = None                # connection to the application (or test-mode stand-in)


def register_id(id, effect):
    """Register `effect` as the handler for `id`, a tuple ('identifier', version)."""
    global Flexi_effect_for_id, Flexi_registered_effects
    Flexi_registered_effects[effect] = 1
    Flexi_effect_for_id[id] = effect


def run():
    """Serve requests from the application until shutdown.

    The websocket URL is read from the FLEXI_SERVER_URL environment variable.
    When it is absent, run in test mode instead: send a describe_effect_v1
    request to every registered id and validate the reply against the JSON
    schema in 'flexi_reply.schema' (read from the current directory), then
    return.
    """
    # Open connection to the application using the URL passed in the environment
    global Flexi_ws

    if 'FLEXI_SERVER_URL' not in os.environ:
        print('flexi.run() called without websocket URL. Running in test mode.')
        from jsonschema import validate
        with open('flexi_reply.schema') as f:
            describe_effect_schema = json.load(f)

        class DummyWS:
            """Stand-in connection that keeps the last reply for validation."""

            def send(self, s):
                self.reply = json.loads(s)

            def close(self):
                # Nothing to release; present so _shutdown() can call it.
                pass

        Flexi_ws = DummyWS()
        for id in Flexi_effect_for_id:
            print(f'Validating describe_effect for {id[0]} v{id[1]}')
            _process_message(
                {'effect': [id[0], id[1], 'instance'],
                 'method': 'describe_effect_v1',
                 'id': 2,
                 'data': {}})
            validate(Flexi_ws.reply, describe_effect_schema)
        return

    url = os.environ['FLEXI_SERVER_URL']

    Flexi_ws = websocket.create_connection(url)

    # Processing loop
    while True:
        try:
            message = json.loads(Flexi_ws.recv())
        except websocket._exceptions.WebSocketConnectionClosedException:
            print('Connection closed')
            _shutdown()  # does not return: ends in sys.exit(0)
        if isinstance(message, list):
            # TODO could consider batching replies
            for entry in message:
                _process_message(entry)
        else:
            _process_message(message)


# Process one message from the application
def _process_message(message):
    """Dispatch one decoded message to the effect registered for its id.

    A 'shutdown_v1' message is acknowledged and then terminates the process.
    Any exception raised by the effect is printed and reported back to the
    application as an 'Internal error' log entry.
    """
    global Flexi_effect_for_id, Flexi_registered_effects

    if 'method' in message and message['method'] == 'shutdown_v1':
        _reply(message, {})
        _shutdown()

    try:
        id = (message['effect'][0], message['effect'][1])
        if id not in Flexi_effect_for_id:
            return _reply(message, error('Unregistered effect', id[0]))
        effect = Flexi_effect_for_id[id]
        effect._handle_message(id, message)

    except SystemExit:
        # Re-raise untouched. Wrapping it in sys.exit(e) would make the
        # exception object itself the exit "code", which the interpreter
        # prints to stderr and turns into exit status 1 - even for the clean
        # sys.exit(0) issued by _shutdown().
        raise

    except Exception:
        # KeyboardInterrupt is deliberately not caught here.
        print('Internal error', format(sys.exc_info()))
        traceback.print_exc()
        _reply(message, error('Internal error', format(sys.exc_info())))


# Send a reply to a message
def _reply(message, data):
    """Send `data` back for `message`.

    The reply echoes the message's 'effect' and 'id' and its 'method' in upper
    case. A dict carrying '__LOG__' (see error()) is sent under 'log';
    anything else is sent under 'data'.
    """
    global Flexi_ws
    if data is not None and '__LOG__' in data:
        reply = {'effect': message['effect'],
                 'method': message['method'].upper(),
                 'id': message['id'],
                 'log': data['__LOG__']
                 }
    else:
        reply = {'effect': message['effect'],
                 'method': message['method'].upper(),
                 'id': message['id'],
                 'data': data
                 }
    Flexi_ws.send(json.dumps(reply))


# Make a reply with an error to a message
def error(error, detail):
    """Build the payload that _reply() sends as a level-3 log entry."""
    return {'__LOG__': [{'level': 3, 'message': error, 'detail': detail}]}


def _shutdown():
    """Shut down every registered effect, close the connection and exit(0)."""
    global Flexi_ws
    # Iterate over the effect registry rather than the id map so that an
    # effect registered under several ids is shut down exactly once.
    for effect in list(Flexi_registered_effects):
        effect.shutdown()
    if Flexi_ws is not None:
        Flexi_ws.close()
    print('Exiting')
    sys.exit(0)
136 | ids is a list of tuples ('identifier', version) 137 | """ 138 | for id in ids: 139 | register_id(id, self) 140 | 141 | def _handle_message(self, id, message): 142 | """Internal method""" 143 | gpu = None 144 | if 'data' in message and 'cuda_gpu' in message['data']: 145 | cuda_gpu = message['data']['cuda_gpu'] 146 | if cuda_gpu != '': 147 | gpu = cuda_gpu 148 | if 'data' in message and 'metal_device_index' in message['data']: 149 | metal_device_index = message['data']['metal_device_index'] 150 | if metal_device_index != -1: 151 | gpu = metal_device_index 152 | if message['method'] == 'describe_effect_v1': 153 | _reply(message, self.describe_effect(id)) 154 | self.setup_if_necessary(id, gpu) 155 | elif message['method'] == 'query_vram_requirement_v1': 156 | _reply(message, self.query_vram_requirement(id, message['data'])) 157 | elif message['method'] == 'set_vram_limit_v1': 158 | res = self.set_vram_limit(id, message['data']) 159 | if(res == 'quit'): 160 | _reply(message, {}) 161 | _shutdown() 162 | else: 163 | _reply(message, res) 164 | elif message['method'] == 'run_generate_v1': 165 | self.setup_if_necessary(id, gpu) 166 | _reply(message, self.run_generate(id, message['effect'][2], message['data'])) 167 | else: 168 | _reply(message, error('Unknown method', message['method'])) 169 | 170 | def describe_effect(self, id, message): 171 | """Subclass must override this to QUICKLY return a description of a particular effect. 172 | id is a tuple ('identifier', version, 'instance') 173 | message will have members 'host_identifer', 'host_version' which can be used 174 | to change how the effect behaves, or indeed to fail to describe it 175 | """ 176 | raise RuntimeError('Effect subclass must implement describe_effect() method') 177 | 178 | def setup_if_necessary(self, id, gpu): 179 | """Subclass may override this to do expensive setup. 
    def query_vram_requirement(self, id, data):
        """Subclass must override this to QUICKLY return the VRAM requirement of a particular effect.
        id is a tuple ('identifier', version), as built by _process_message()
        data contains params, width and height
        Reply must contain 'min_mb' and 'max_mb'
        """
        raise RuntimeError('Effect subclass must implement query_vram_requirement() method')

    def set_vram_limit(self, id, data):
        """Subclass must override this to IMMEDIATELY restrict its current and future
        VRAM usage to the supplied value data['limit_mb'] in MB.

        If limit_mb is zero, the effect must release ALL resources on the indicated device.

        Method must not return until its GPU usage is at or below the supplied limit,
        except in the special case where a GPU SDK is being used which does not support
        releasing resources, then this method can return 'quit' (the value that
        _handle_message() checks for), which will cause the process to exit.
        Any other return value is sent back to the application as the reply.

        id is a tuple ('identifier', version), as built by _process_message()
        """
        raise RuntimeError('Effect subclass must implement set_vram_limit() method')
def get_rife_model(req_device="cuda"):
    """Load the RIFE model and return (model, device_name).

    req_device is "cuda", "mps" or "cpu"; "cuda" and "mps" fall back to "cpu"
    when torch reports the backend as unavailable. The returned device_name is
    the (possibly downgraded) string, not a torch.device.
    Gradients are disabled globally as a side effect.
    """
    from RIFE_HDv3 import Model
    if req_device == "mps" and not torch.backends.mps.is_available():
        req_device = "cpu"
    if req_device == "cuda" and not torch.cuda.is_available():
        req_device = "cpu"

    device = torch.device(req_device)

    model = Model()
    model.load_model("RIFE/model/", device, -1)
    torch.set_grad_enabled(False)

    # Compare the device *type*: a torch.device never compares equal to a
    # plain string, so `device == "cuda"` was always False and these flags
    # were never applied.
    if device.type == "cuda":
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

    model.eval()
    model.device(device)

    return model, req_device


def rife_inference(model, device, src1, src2, scale, ratio, nbPass, width, height):
    """Interpolate between two frames at position `ratio` and return one frame.

    src1 / src2 are array-likes holding 4 * height * width values each; they
    are moved to `device` and viewed as (4, height, width).
    ratio is the position between src1 (0.0) and src2 (1.0).
    nbPass is the number of passes: every pass but the last interpolates the
    midpoint of the current interval and keeps the half containing `ratio`;
    the last pass interpolates at the remaining relative position. The loop
    stops early when `ratio` lands exactly on an interval end.
    scale is forwarded to model.inference() and also sets the padding size.

    Returns a contiguous (4, height, width) tensor; on Linux it stays on
    `device`, on other platforms it is moved to the CPU.
    Raises ValueError when nbPass < 1 (there would be no result to return).
    """
    if nbPass < 1:
        raise ValueError("nbPass must be at least 1, got {}".format(nbPass))

    img1 = torch.as_tensor(src1, device=device)
    img2 = torch.as_tensor(src2, device=device)

    img1 = torch.reshape(img1, (4, height, width)).unsqueeze(0)
    img2 = torch.reshape(img2, (4, height, width)).unsqueeze(0)

    _, _, h, w = img1.shape

    # Pad on the right/bottom so both dimensions are multiples of `pad`
    pad = max(int(32 / scale), 128)
    ph = ((h - 1) // pad + 1) * pad
    pw = ((w - 1) // pad + 1) * pad
    padding = (0, pw - w, 0, ph - h)
    img1 = F.pad(img1, padding)
    img2 = F.pad(img2, padding)

    # [lf, rf] is the interval (in ratio units) currently spanned by img1..img2
    lf = 0.
    rf = 1.
    last_pass = nbPass - 1

    for i in range(nbPass):
        if ratio == lf:
            res = img1
            break
        elif ratio == rf:
            res = img2
            break

        f = (ratio - lf) / (rf - lf) if i == last_pass else 0.5

        res = model.inference(img1, img2, scale, f)

        # Value comparison, not identity: `is not` on ints only works by
        # accident of small-integer caching.
        if i != last_pass:
            mid = (lf + rf) / 2.
            if ratio <= mid:
                img2 = res
                rf = mid
            else:
                img1 = res
                lf = mid

    # Drop the batch dimension and crop the padding away
    out = res[0]
    if platform.system() == 'Linux':
        return out[:, :h, :w].contiguous()
    else:
        return out[:, :h, :w].contiguous().cpu()
* math.sqrt(width * height) 106 | #As the first instanciation will require more VRAM, make sure we request more 107 | # if the model hasn't been loaded 108 | return {'min_mb' : amount, 'max_mb' : 1.1 * amount} 109 | 110 | def set_vram_limit(self, id, data): 111 | limit = data['limit_mb'] 112 | #If the limit is not enough, we need to desinstanciate the effect 113 | #(don't bother if we haven't loaded the model) 114 | if(self.model is not None and limit < self.vram_amount): 115 | return 'quit' 116 | return {'RIFE' : 'Success'} 117 | 118 | def run_generate(self, id, instance, data): 119 | # No need to verify identifier or version, we only support one 120 | inputs = data['inputs'] 121 | output = data['output'] 122 | params = data['params'] 123 | 124 | width = data['width'] 125 | height = data['height'] 126 | timestamp = params['timestamp'] 127 | scale = params['scale'] 128 | nbPass = int(params['pass']) 129 | 130 | #Account for VRAM usage 131 | amount = self.query_vram_requirement(id, data)['min_mb'] 132 | if(self.vram_amount < amount): 133 | self.vram_amount = amount 134 | 135 | #instancidate the model if necessary with the correct backend 136 | if(self.model is None): 137 | device = "cpu" 138 | if (platform.system() == 'Darwin'): 139 | device = "mps" 140 | elif (platform.system() == 'Linux'): 141 | device = "cuda" 142 | 143 | self.model, self.device = get_rife_model(device) 144 | 145 | 146 | #start_t = time.time() 147 | 148 | if (platform.system() == 'Darwin'): #SHM 149 | out_mem = shared_memory.SharedMemory(name=output['shm'].lstrip('/')) 150 | src_mem = shared_memory.SharedMemory(name=inputs['src:0']['shm'].lstrip('/')) 151 | src2_mem = shared_memory.SharedMemory(name=inputs.get('src:1', '-')['shm'].lstrip('/')) 152 | 153 | #map the shared memory into ndarray objects 154 | src1_array = numpy.ndarray((4, height, width), numpy.float32, src_mem.buf) 155 | src2_array = numpy.ndarray((4, height, width), numpy.float32, src2_mem.buf) 156 | 157 | out_array = 
numpy.ndarray((4, height, width), numpy.float32, out_mem.buf) 158 | out_array[:,:,:] = rife_inference(self.model, self.device, src1_array, src2_array, scale, timestamp, nbPass, width, height) 159 | #close the shared memory 160 | src_mem.close() 161 | src2_mem.close() 162 | out_mem.close() 163 | 164 | else: #CUDA_IPC x Linux 165 | shape = height * width * 16 166 | dtype = numpy.dtype(numpy.float32) 167 | 168 | #standard_b64decode 169 | ihandle1 = int(inputs['src:0']['cuda_ipc'], base=16).to_bytes(64, byteorder='big') 170 | ihandle2 = int(inputs.get('src:1','-')['cuda_ipc'], base=16).to_bytes(64, byteorder='big') 171 | ohandle = int(output['cuda_ipc'], base=16).to_bytes(64, byteorder='big') 172 | 173 | with cuda.open_ipc_array(ihandle1, shape=shape // dtype.itemsize, dtype=dtype) as src1_array: 174 | with cuda.open_ipc_array(ihandle2, shape=shape // dtype.itemsize, dtype=dtype) as src2_array: 175 | with cuda.open_ipc_array(ohandle, shape=shape // dtype.itemsize, dtype=dtype) as out_array: 176 | 177 | res = rife_inference(self.model, self.device, src1_array, src2_array, scale, timestamp, nbPass, width, height) 178 | 179 | cupy.cuda.runtime.memcpy(out_array.gpu_data._mem.handle.value, res.__cuda_array_interface__['data'][0], height * width * 4 * 4, 4 ) 180 | cuda.synchronize() 181 | torch.cuda.empty_cache() 182 | 183 | return {'metadata' : {'RIFE' : 'Success'}} 184 | 185 | if __name__ == '__main__': 186 | 187 | # Register each effect 188 | RIFERetime() 189 | 190 | # and run 191 | flexi.run() 192 | -------------------------------------------------------------------------------- /baselight/uk.co.andriy.mltimewarp.v1/uk.co.dirtylooks.mltimewarp.v1.flexi: -------------------------------------------------------------------------------- 1 | { 2 | "flexi_effects_definition_v1" : { 3 | "effects" : [ 4 | { 5 | "identifier": "uk.co.dirtylooks.mltimewarp", 6 | "version": 1, 7 | "label": "DL ML Timewarp", 8 | "description": "DL ML Timewarp", 9 | "licence" : "LICENCE.md", 10 | 
"type" : "internal" 11 | } 12 | ], 13 | "launch_command": "./bin/python3 -W ignore ./timewarp.py", 14 | "install": { 15 | "all": { 16 | "flexi_platform": "pytorch-2.1", 17 | "sources": { 18 | "flexi.py": "../flexi.py", 19 | "timewarp.py": "timewarp.py", 20 | "LICENCE.md" : "LICENCE.md" 21 | }, 22 | "weights" : { 23 | "Flownet001" : "flownet/flownet001.pth" 24 | }, 25 | "run": [] 26 | } 27 | } 28 | } 29 | } 30 | 31 | -------------------------------------------------------------------------------- /find_scale.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Resolves the directory where the script is located 4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")") 5 | 6 | # Change to the script directory 7 | cd "$SCRIPT_DIR" 8 | 9 | # Define the path to the Python executable and the Python script 10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python" 11 | PYTHON_SCRIPT="./pytorch/flameTimewarpML_findscale.py" 12 | 13 | # Run the Python script with all arguments passed to this shell script 14 | $PYTHON_CMD $PYTHON_SCRIPT "$@" 15 | -------------------------------------------------------------------------------- /fix-xattr.command: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the target directory relative to the script location 4 | TARGET_DIR="$(dirname "$0")/packages/.miniconda" 5 | 6 | # Check if the target directory exists 7 | if [ ! -d "$TARGET_DIR" ]; then 8 | echo "Directory $TARGET_DIR does not exist." 
9 | exit 1 10 | fi 11 | 12 | # Find all files in the target directory and clear extended attributes 13 | find "$TARGET_DIR" -type f -exec xattr -c {} \; 14 | 15 | echo "Extended attributes cleared from all files in $TARGET_DIR" -------------------------------------------------------------------------------- /fonts/DejaVuSansMono.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/talosh/flameTimewarpML/cec31a5436d9c40113f635152c1c0da110aadebd/fonts/DejaVuSansMono.ttf -------------------------------------------------------------------------------- /models/flownet2_v004.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:40245beb4a757320f67f714ee5bf3f33c613c478bd01f30ecbb57b1ce445e9ec 3 | size 339780714 4 | -------------------------------------------------------------------------------- /packages/.miniconda/README.md: -------------------------------------------------------------------------------- 1 | # flameTimewarpML 2 | * pre-configured miniconda environment to be placed here into appenv folder -------------------------------------------------------------------------------- /packages/README.md: -------------------------------------------------------------------------------- 1 | ## flameTimewarpML 2 | 3 | * pre-configured miniconda environment should be placed into hidden "packages/.miniconda" folder 4 | * the folder is hidden (starts with ".") in order to keep Flame from scanning it looking for python hooks 5 | * pre-configured python environment usually packed with release tar file 6 | 7 | ### Installing and configuring python environment manually 8 | 9 | * download Miniconda for Mac or Linux (I'm using python 3.11 for tests) from 10 | 11 | 12 | * install downloaded Miniconda python distribution, use "-p" to select install location. 
For example: 13 | 14 | ```bash 15 | sh ~/Downloads/Miniconda3-py311_24.1.2-0-Linux-x86_64.sh -bfsm -p ~/miniconda3 16 | ``` 17 | 18 | * Activate and clone the default environment into another one named "appenv" 19 | 20 | ```bash 21 | eval "$(~/miniconda3/bin/conda shell.bash hook)" 22 | conda create --name appenv --clone base 23 | conda activate appenv 24 | ``` 25 | 26 | * Install dependency libraries 27 | 28 | ```bash 29 | conda install pyqt 30 | conda install numpy 31 | conda install conda-pack 32 | ``` 33 | 34 | * Install pytorch. Please look up the exact commands for your OS and CUDA version at <https://pytorch.org/get-started/locally/> 35 | 36 | * Linux example 37 | ```bash 38 | conda install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia 39 | ``` 40 | 41 | * MacOS example: 42 | 43 | ```bash 44 | conda install pytorch::pytorch -c pytorch 45 | ``` 46 | 47 | * Install the rest of the dependencies 48 | ```bash 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | * Pack the appenv environment into a portable tar file 53 | 54 | ```bash 55 | conda pack --ignore-missing-files -n appenv 56 | ``` 57 | 58 | * Unpack environment to flameTimewarpML folder 59 | 60 | ```bash 61 | mkdir {flameTimewarpML folder}/packages/.miniconda/appenv/ 62 | tar xvf appenv.tar.gz -C {flameTimewarpML folder}/packages/.miniconda/appenv/ 63 | ``` 64 | 65 | * Remove environment tarball 66 | 67 | ```bash 68 | rm appenv.tar.gz 69 | ``` 70 | -------------------------------------------------------------------------------- /packages/clean_pycache.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the target directory. Default is the current directory. 
4 | TARGET_DIR="${1:-.}" 5 | 6 | # Find all __pycache__ directories and remove them 7 | find "$TARGET_DIR" -type d -name "__pycache__" -exec rm -r {} + 8 | 9 | echo "All __pycache__ directories have been removed from $TARGET_DIR" -------------------------------------------------------------------------------- /presets/openexr16bit.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | image 4 | Creates 16-bit fp OpenEXR file sequence (Uncompressed) numbered based on the timecode of the selected clip. 5 | 29 | 30 | 8 31 | 1 32 | 1 33 | 34 | -------------------------------------------------------------------------------- /presets/openexr32bit.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | image 4 | Creates 32-bit fp OpenEXR file sequence (Uncompressed) numbered based on the timecode of the selected clip. 5 | 29 | 30 | 8 31 | 1 32 | 1 33 | 34 | -------------------------------------------------------------------------------- /presets/source_export.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | sequence 4 | Source export 16-bit fp OpenEXR (Uncompressed) 5 | 6 | NONE 7 | 8 | <name> 9 | True 10 | True 11 | 12 | image 13 | Original 14 | NoChange 15 | True 16 | 99999 17 | 18 | True 19 | False 20 | 21 | audio 22 | FX 23 | FlattenTracks 24 | True 25 | 10 26 | 27 | 28 | 53 | 54 | 8 55 | 1 56 | 0 57 | 58 | -------------------------------------------------------------------------------- /presets/source_export32bit.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | sequence 4 | Source export 32-bit fp OpenEXR (Uncompressed) 5 | 6 | NONE 7 | 8 | <name> 9 | True 10 | True 11 | 12 | image 13 | Original 14 | NoChange 15 | True 16 | 99999 17 | 18 | True 19 | False 20 | 21 | audio 22 | FX 23 | FlattenTracks 24 | True 25 | 10 26 | 27 | 28 | 53 | 54 | 8 55 | 1 56 | 0 57 | 58 | 
# pytorch/check_consistent.py
"""Report gaps in the frame numbering of .exr sequences under a dataset path.

Every folder below the given path that contains .exr files is listed, the
frame number is taken from the text between the last two dots of each file
path, and any pair of neighbouring frame numbers that is not consecutive is
printed.
"""

import os
import sys
import argparse
import importlib
import platform

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Neither numpy nor OpenImageIO is used by the file-name based check in
# main(), so a failed import is reported as before but no longer aborts.
try:
    import numpy as np
except Exception:
    np = None
    python_executable_path = sys.executable
    if '.miniconda' in python_executable_path:
        print (f'Using "{python_executable_path}" as python interpreter')
        print ('Unable to import Numpy')

try:
    import OpenImageIO as oiio
except Exception:
    oiio = None
    python_executable_path = sys.executable
    if '.miniconda' in python_executable_path:
        print ('Unable to import OpenImageIO')
        print (f'Using "{python_executable_path}" as python interpreter')

def read_image_file(file_path, header_only = False):
    """
    Read an image file through OpenImageIO.

    Parameters:
    file_path (str): Path of the image to open.
    header_only (bool): When True only the image spec is returned.

    Returns:
    dict: {'spec': ..., 'image_data': ...}; both values stay None when the
    file cannot be opened. The pixel array is returned with its first two
    axes swapped.

    Raises:
    RuntimeError: If OpenImageIO could not be imported.
    """
    if oiio is None:
        raise RuntimeError('OpenImageIO is required to read image files')
    result = {'spec': None, 'image_data': None}
    inp = oiio.ImageInput.open(file_path)
    if inp :
        try:
            spec = inp.spec()
            result['spec'] = spec
            if not header_only:
                channels = spec.nchannels
                result['image_data'] = inp.read_image(0, 0, 0, channels).transpose(1, 0, 2)
        finally:
            # Release the file handle even when reading raises.
            inp.close()
    return result

def find_folders_with_exr(path):
    """
    Find all folders under the given path that contain .exr files.

    Folders whose path below the root contains 'preview' or 'eval' are
    skipped. The match is done on the path relative to `path`, so the
    location of the dataset itself cannot exclude everything.

    Parameters:
    path (str): The root directory to start the search from.

    Returns:
    set: The directories containing .exr files.
    """
    directories_with_exr = set()

    # Walk through all directories and files in the given path
    for root, dirs, files in os.walk(path):
        relative_root = os.path.relpath(root, path)
        if 'preview' in relative_root or 'eval' in relative_root:
            continue
        if any(file.endswith('.exr') for file in files):
            directories_with_exr.add(root)

    return directories_with_exr

def clear_lines(n=2):
    """Clears a specified number of lines in the terminal."""
    CURSOR_UP_ONE = '\x1b[1A'
    ERASE_LINE = '\x1b[2K'
    for _ in range(n):
        sys.stdout.write(CURSOR_UP_ONE)
        sys.stdout.write(ERASE_LINE)

def _find_frame_gaps(frame_numbers):
    """Return (previous, current) pairs of sorted neighbours that are not consecutive."""
    ordered = sorted(frame_numbers)
    return [
        (previous, current)
        for previous, current in zip(ordered, ordered[1:])
        if current != previous + 1
    ]

def main():
    """Parse the command line and print every frame-numbering gap found."""
    parser = argparse.ArgumentParser(description='Check exr sequences for missing or non-consecutive frame numbers')

    # Required argument
    parser.add_argument('dataset_path', type=str, help='Path to folders with exrs')
    args = parser.parse_args()

    folders_with_exr = find_folders_with_exr(args.dataset_path)

    for folder_idx, folder_path in enumerate(sorted(folders_with_exr)):
        print (f'\rFolder [{folder_idx+1} / {len(folders_with_exr)}], {folder_path}', end='')
        folder_exr_files = sorted(
            os.path.join(folder_path, file)
            for file in os.listdir(folder_path)
            if file.endswith('.exr')
        )

        frame_numbers = []
        for filename in folder_exr_files:
            # The frame number is the field right before the extension.
            parts = filename.split('.')
            try:
                frame_numbers.append(int(parts[-2]))
            except ValueError:
                print (f'\rFormat error in {folder_path}: {filename}')

        for previous, current in _find_frame_gaps(frame_numbers):
            print(f'\r{folder_path}: Missing or non-consecutive frame between {previous} and {current}')

    print ('')

if __name__ == "__main__":
    main()
# pytorch/check_readable.py
"""Try to read every .exr file under a dataset path and report the ones that fail."""

import os
import sys
import argparse
import importlib
import platform

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

try:
    import numpy as np
except Exception:
    python_executable_path = sys.executable
    if '.miniconda' in python_executable_path:
        print (f'Using "{python_executable_path}" as python interpreter')
        print ('Unable to import Numpy')
    sys.exit()

try:
    import OpenImageIO as oiio
except Exception:
    python_executable_path = sys.executable
    if '.miniconda' in python_executable_path:
        print ('Unable to import OpenImageIO')
        print (f'Using "{python_executable_path}" as python interpreter')
    sys.exit()

def read_image_file(file_path, header_only = False):
    """
    Read an image file through OpenImageIO.

    Parameters:
    file_path (str): Path of the image to open.
    header_only (bool): When True only the image spec is returned.

    Returns:
    dict: {'spec': ..., 'image_data': ...}; both values stay None when the
    file cannot be opened. The pixel array is returned with its first two
    axes swapped.
    """
    result = {'spec': None, 'image_data': None}
    inp = oiio.ImageInput.open(file_path)
    if inp :
        try:
            spec = inp.spec()
            result['spec'] = spec
            if not header_only:
                channels = spec.nchannels
                result['image_data'] = inp.read_image(0, 0, 0, channels).transpose(1, 0, 2)
        finally:
            # Release the file handle even when reading raises.
            inp.close()
    return result

def find_folders_with_exr(path):
    """
    Find all folders under the given path that contain .exr files.

    Folders whose path ends with 'preview' or 'eval' are skipped.

    Parameters:
    path (str): The root directory to start the search from.

    Returns:
    set: The directories containing .exr files.
    """
    directories_with_exr = set()

    # Walk through all directories and files in the given path
    for root, dirs, files in os.walk(path):
        if root.endswith('preview') or root.endswith('eval'):
            continue
        if any(file.endswith('.exr') for file in files):
            directories_with_exr.add(root)

    return directories_with_exr

def clear_lines(n=2):
    """Clears a specified number of lines in the terminal."""
    CURSOR_UP_ONE = '\x1b[1A'
    ERASE_LINE = '\x1b[2K'
    for _ in range(n):
        sys.stdout.write(CURSOR_UP_ONE)
        sys.stdout.write(ERASE_LINE)

def main():
    """Parse the command line, read every exr found and print the read errors."""
    parser = argparse.ArgumentParser(description='Check that every exr file under the dataset path can be read')

    # Required argument
    parser.add_argument('dataset_path', type=str, help='Path to folders with exrs')
    args = parser.parse_args()

    folders_with_exr = find_folders_with_exr(args.dataset_path)

    exr_files = []

    for folder_path in sorted(folders_with_exr):
        exr_files.extend(sorted(
            os.path.join(folder_path, file)
            for file in os.listdir(folder_path)
            if file.endswith('.exr')
        ))

    for idx, exr_file_path in enumerate(exr_files):
        print (f'\rFile [{idx+1} / {len(exr_files)}], {os.path.basename(exr_file_path)}', end='')
        try:
            result = read_image_file(exr_file_path)
            if result['image_data'] is None:
                # The file could not be opened: surface OpenImageIO's own
                # message instead of an AttributeError on None.
                raise RuntimeError(oiio.geterror() or 'file could not be opened')
            w, h, c = result['image_data'].shape
        except Exception as e:
            print (f'\nError reading {exr_file_path}: {e}')
    print ('')

if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /pytorch/hub/checkpoints/README.md:
# --------------------------------------------------------------------------------
# This folder is needed to put
alexnet-owt-7be5be79.pth for LPIPS weights -------------------------------------------------------------------------------- /pytorch/models/archived/flownet4_v001d.py: -------------------------------------------------------------------------------- 1 | # Orig v001 changed to v002 main flow and signatures 2 | # Compile test 3 | 4 | class Model: 5 | def __init__(self, status = dict(), torch = None): 6 | if torch is None: 7 | import torch 8 | Module = torch.nn.Module 9 | backwarp_tenGrid = {} 10 | 11 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 12 | return torch.nn.Sequential( 13 | torch.nn.Conv2d( 14 | in_planes, 15 | out_planes, 16 | kernel_size=kernel_size, 17 | stride=stride, 18 | padding=padding, 19 | dilation=dilation, 20 | padding_mode = 'reflect', 21 | bias=True 22 | ), 23 | torch.nn.LeakyReLU(0.2, True) 24 | # torch.nn.SELU(inplace = True) 25 | ) 26 | 27 | def warp(tenInput, tenFlow): 28 | k = (str(tenFlow.device), str(tenFlow.size())) 29 | if k not in backwarp_tenGrid: 30 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 31 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 32 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype) 33 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) 34 | 35 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 36 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='reflection', align_corners=True) 37 | 38 | class Head(Module): 39 | def __init__(self): 40 | super(Head, self).__init__() 41 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1) 42 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1) 43 | self.cnn2 = 
torch.nn.Conv2d(32, 32, 3, 1, 1) 44 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1) 45 | self.relu = torch.nn.LeakyReLU(0.2, True) 46 | 47 | torch.nn.init.kaiming_normal_(self.cnn0.weight, mode='fan_in', nonlinearity='relu') 48 | self.cnn0.weight.data *= 1e-2 49 | if self.cnn0.bias is not None: 50 | torch.nn.init.constant_(self.cnn0.bias, 0) 51 | torch.nn.init.kaiming_normal_(self.cnn1.weight, mode='fan_in', nonlinearity='relu') 52 | self.cnn1.weight.data *= 1e-2 53 | if self.cnn1.bias is not None: 54 | torch.nn.init.constant_(self.cnn1.bias, 0) 55 | torch.nn.init.kaiming_normal_(self.cnn2.weight, mode='fan_in', nonlinearity='relu') 56 | self.cnn2.weight.data *= 1e-2 57 | if self.cnn2.bias is not None: 58 | torch.nn.init.constant_(self.cnn2.bias, 0) 59 | torch.nn.init.kaiming_normal_(self.cnn3.weight, mode='fan_in', nonlinearity='relu') 60 | self.cnn3.weight.data *= 1e-2 61 | if self.cnn3.bias is not None: 62 | torch.nn.init.constant_(self.cnn3.bias, 0) 63 | 64 | def forward(self, x, feat=False): 65 | x0 = self.cnn0(x) 66 | x = self.relu(x0) 67 | x1 = self.cnn1(x) 68 | x = self.relu(x1) 69 | x2 = self.cnn2(x) 70 | x = self.relu(x2) 71 | x3 = self.cnn3(x) 72 | if feat: 73 | return [x0, x1, x2, x3] 74 | return x3 75 | 76 | class ResConv(Module): 77 | def __init__(self, c, dilation=1): 78 | super().__init__() 79 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True) 80 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) 81 | self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True) 82 | 83 | torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu') 84 | self.conv.weight.data *= 1e-2 85 | if self.conv.bias is not None: 86 | torch.nn.init.constant_(self.conv.bias, 0) 87 | 88 | def forward(self, x): 89 | return self.relu(self.conv(x) * self.beta + x) 90 | 91 | class Flownet(Module): 92 | def __init__(self, in_planes, c=64): 93 | 
super().__init__() 94 | self.conv0 = torch.nn.Sequential( 95 | conv(in_planes, c//2, 3, 2, 1), 96 | conv(c//2, c, 3, 2, 1), 97 | ) 98 | self.lastconv = torch.nn.Sequential( 99 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1), 100 | torch.nn.PixelShuffle(2) 101 | ) 102 | 103 | def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1): 104 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep 105 | x = torch.cat((img0, img1, f0, f1, timestep), 1) 106 | if flow is not None: 107 | x = torch.cat((x, mask, flow), 1) 108 | feat = self.conv0(x) 109 | # feat = self.convblock(feat) 110 | tmp = self.lastconv(feat) 111 | flow = tmp[:, :4] * scale 112 | mask = tmp[:, 4:5] 113 | conf = tmp[:, 5:6] 114 | return flow, mask, conf 115 | 116 | class FlownetCas(Module): 117 | def __init__(self): 118 | super().__init__() 119 | self.block0 = Flownet(7+16, c=192) 120 | self.block1 = Flownet(8+4+16, c=128) 121 | self.block2 = Flownet(8+4+16, c=96) 122 | self.block3 = Flownet(8+4+16, c=64) 123 | self.encode = Head() 124 | 125 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1): 126 | img0 = img0 127 | img1 = img1 128 | f0 = self.encode(img0) 129 | f1 = self.encode(img1) 130 | 131 | flow_list = [None] * 4 132 | mask_list = [None] * 4 133 | merged = [None] * 4 134 | 135 | flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0]) 136 | 137 | flow_list[3] = flow.clone() 138 | mask_list[3] = torch.sigmoid(mask.clone()) 139 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3]) 140 | 141 | return flow_list, mask_list, merged 142 | 143 | for iteration in range(iterations): 144 | flow_d, mask, conf = self.block1( 145 | warp(img0, flow[:, :2]), 146 | warp(img1, flow[:, 2:4]), 147 | warp(f0, flow[:, :2]), 148 | warp(f1, flow[:, 2:4]), 149 | timestep, 150 | mask, 151 | flow, 152 | scale=scale[1] 153 | ) 154 | flow = flow + flow_d 155 | 156 | flow_list[1] = flow.clone() 157 | mask_list[1] = 
torch.sigmoid(mask.clone()) 158 | merged[1] = warp(img0, flow[:, :2]) * mask_list[1] + warp(img1, flow[:, 2:4]) * (1 - mask_list[1]) 159 | 160 | for iteration in range(iterations): 161 | flow_d, mask, conf = self.block2( 162 | warp(img0, flow[:, :2]), 163 | warp(img1, flow[:, 2:4]), 164 | warp(f0, flow[:, :2]), 165 | warp(f1, flow[:, 2:4]), 166 | timestep, 167 | mask, 168 | flow, 169 | scale=scale[2] 170 | ) 171 | flow = flow + flow_d 172 | 173 | flow_list[2] = flow.clone() 174 | mask_list[2] = torch.sigmoid(mask.clone()) 175 | merged[2] = warp(img0, flow[:, :2]) * mask_list[2] + warp(img1, flow[:, 2:4]) * (1 - mask_list[2]) 176 | 177 | for iteration in range(iterations): 178 | flow_d, mask, conf = self.block3( 179 | warp(img0, flow[:, :2]), 180 | warp(img1, flow[:, 2:4]), 181 | warp(f0, flow[:, :2]), 182 | warp(f1, flow[:, 2:4]), 183 | timestep, 184 | mask, 185 | flow, 186 | scale=scale[3] 187 | ) 188 | flow = flow + flow_d 189 | 190 | flow_list[3] = flow 191 | mask_list[3] = torch.sigmoid(mask) 192 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3]) 193 | 194 | return flow_list, mask_list, merged 195 | 196 | self.model = FlownetCas 197 | self.training_model = FlownetCas 198 | 199 | @staticmethod 200 | def get_info(): 201 | info = { 202 | 'name': 'Flownet4_v001d', 203 | 'file': 'flownet4_v001d.py', 204 | 'ratio_support': True 205 | } 206 | return info 207 | 208 | @staticmethod 209 | def get_name(): 210 | return 'TWML_Flownet_v001' 211 | 212 | @staticmethod 213 | def input_channels(model_state_dict): 214 | channels = 3 215 | try: 216 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1] 217 | except Exception as e: 218 | print (f'Unable to get model dict input channels - setting to 3, {e}') 219 | return channels 220 | 221 | @staticmethod 222 | def output_channels(model_state_dict): 223 | channels = 5 224 | try: 225 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0] 226 
| except Exception as e: 227 | print (f'Unable to get model dict output channels - setting to 3, {e}') 228 | return channels 229 | 230 | def get_model(self): 231 | import platform 232 | if platform.system() == 'Darwin': 233 | return self.training_model 234 | return self.model 235 | 236 | def get_training_model(self): 237 | return self.training_model 238 | 239 | def load_model(self, path, flownet, rank=0): 240 | import torch 241 | def convert(param): 242 | if rank == -1: 243 | return { 244 | k.replace("module.", ""): v 245 | for k, v in param.items() 246 | if "module." in k 247 | } 248 | else: 249 | return param 250 | if rank <= 0: 251 | if torch.cuda.is_available(): 252 | flownet.load_state_dict(convert(torch.load(path)), False) 253 | else: 254 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False) -------------------------------------------------------------------------------- /pytorch/models/archived/flownet4_v001ea.py: -------------------------------------------------------------------------------- 1 | # Orig v001 changed to v002 main flow and signatures 2 | # Back from SiLU to LeakyReLU to test data flow 3 | # Warps moved to flownet forward 4 | 5 | class Model: 6 | def __init__(self, status = dict(), torch = None): 7 | if torch is None: 8 | import torch 9 | Module = torch.nn.Module 10 | backwarp_tenGrid = {} 11 | 12 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 13 | return torch.nn.Sequential( 14 | torch.nn.Conv2d( 15 | in_planes, 16 | out_planes, 17 | kernel_size=kernel_size, 18 | stride=stride, 19 | padding=padding, 20 | dilation=dilation, 21 | padding_mode = 'reflect', 22 | bias=True 23 | ), 24 | torch.nn.LeakyReLU(0.2, True) 25 | # torch.nn.SELU(inplace = True) 26 | ) 27 | 28 | def warp(tenInput, tenFlow): 29 | k = (str(tenFlow.device), str(tenFlow.size())) 30 | if k not in backwarp_tenGrid: 31 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, 
tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 32 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 33 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype) 34 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) 35 | 36 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 37 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='reflection', align_corners=True) 38 | 39 | class Head(Module): 40 | def __init__(self): 41 | super(Head, self).__init__() 42 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1) 43 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1) 44 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1) 45 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1) 46 | self.relu = torch.nn.LeakyReLU(0.2, True) 47 | 48 | def forward(self, x, feat=False): 49 | x0 = self.cnn0(x) 50 | x = self.relu(x0) 51 | x1 = self.cnn1(x) 52 | x = self.relu(x1) 53 | x2 = self.cnn2(x) 54 | x = self.relu(x2) 55 | x3 = self.cnn3(x) 56 | if feat: 57 | return [x0, x1, x2, x3] 58 | return x3 59 | 60 | class ResConv(Module): 61 | def __init__(self, c, dilation=1): 62 | super().__init__() 63 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True) 64 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) 65 | self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True) 66 | 67 | torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu') 68 | self.conv.weight.data *= 1e-2 69 | if self.conv.bias is not None: 70 | torch.nn.init.constant_(self.conv.bias, 0) 71 | 72 | def forward(self, x): 73 | return self.relu(self.conv(x) * self.beta + x) 74 | 75 | class 
Flownet(Module): 76 | def __init__(self, in_planes, c=64): 77 | super().__init__() 78 | self.conv0 = torch.nn.Sequential( 79 | conv(in_planes, c//2, 3, 2, 1), 80 | conv(c//2, c, 3, 2, 1), 81 | ) 82 | self.convblock = torch.nn.Sequential( 83 | ResConv(c), 84 | ResConv(c), 85 | ResConv(c), 86 | ResConv(c), 87 | ResConv(c), 88 | ResConv(c), 89 | ResConv(c), 90 | ResConv(c), 91 | ) 92 | self.lastconv = torch.nn.Sequential( 93 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1), 94 | torch.nn.PixelShuffle(2) 95 | ) 96 | 97 | def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1): 98 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep 99 | 100 | if flow is None: 101 | x = torch.cat((img0, img1, f0, f1, timestep), 1) 102 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) 103 | else: 104 | warped_img0 = warp(img0, flow[:, :2]) 105 | warped_img1 = warp(img0, flow[:, :2]) 106 | warped_f0 = warp(f0, flow[:, :2]) 107 | warped_f1 = warp(f1, flow[:, 2:4]) 108 | x = x = torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1) 109 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) 110 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. 
/ scale 111 | x = torch.cat((x, flow), 1) 112 | 113 | feat = self.conv0(x) 114 | feat = self.convblock(feat) 115 | tmp = self.lastconv(feat) 116 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) 117 | flow = tmp[:, :4] * scale 118 | mask = tmp[:, 4:5] 119 | conf = tmp[:, 5:6] 120 | return flow, mask, conf 121 | 122 | class FlownetCas(Module): 123 | def __init__(self): 124 | super().__init__() 125 | self.block0 = Flownet(7+16, c=192) 126 | self.block1 = Flownet(8+4+16, c=128) 127 | self.block2 = Flownet(8+4+16, c=96) 128 | self.block3 = Flownet(8+4+16, c=64) 129 | self.encode = Head() 130 | 131 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1): 132 | img0 = img0 133 | img1 = img1 134 | f0 = self.encode(img0) 135 | f1 = self.encode(img1) 136 | 137 | flow_list = [None] * 4 138 | mask_list = [None] * 4 139 | merged = [None] * 4 140 | 141 | flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0]) 142 | 143 | flow_list[0] = flow.clone() 144 | mask_list[0] = torch.sigmoid(mask.clone()) 145 | merged[0] = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0]) 146 | 147 | for iteration in range(iterations): 148 | flow_d, mask, conf = self.block1( 149 | img0, 150 | img1, 151 | f0, 152 | f1, 153 | timestep, 154 | mask, 155 | flow, 156 | scale=scale[1] 157 | ) 158 | flow = flow + flow_d 159 | 160 | flow_list[1] = flow.clone() 161 | mask_list[1] = torch.sigmoid(mask.clone()) 162 | merged[1] = warp(img0, flow[:, :2]) * mask_list[1] + warp(img1, flow[:, 2:4]) * (1 - mask_list[1]) 163 | 164 | for iteration in range(iterations): 165 | flow_d, mask, conf = self.block2( 166 | img0, 167 | img1, 168 | f0, 169 | f1, 170 | timestep, 171 | mask, 172 | flow, 173 | scale=scale[2] 174 | ) 175 | flow = flow + flow_d 176 | 177 | flow_list[2] = flow.clone() 178 | mask_list[2] = torch.sigmoid(mask.clone()) 179 | merged[2] = warp(img0, flow[:, :2]) * 
mask_list[2] + warp(img1, flow[:, 2:4]) * (1 - mask_list[2]) 180 | 181 | for iteration in range(iterations): 182 | flow_d, mask, conf = self.block3( 183 | img0, 184 | img1, 185 | f0, 186 | f1, 187 | timestep, 188 | mask, 189 | flow, 190 | scale=scale[3] 191 | ) 192 | flow = flow + flow_d 193 | 194 | flow_list[3] = flow 195 | mask_list[3] = torch.sigmoid(mask) 196 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3]) 197 | 198 | return flow_list, mask_list, merged 199 | 200 | self.model = FlownetCas 201 | self.training_model = FlownetCas 202 | 203 | @staticmethod 204 | def get_info(): 205 | info = { 206 | 'name': 'Flownet4_v001ea', 207 | 'file': 'flownet4_v001ea.py', 208 | 'ratio_support': True 209 | } 210 | return info 211 | 212 | @staticmethod 213 | def get_name(): 214 | return 'TWML_Flownet_v001ea' 215 | 216 | @staticmethod 217 | def input_channels(model_state_dict): 218 | channels = 3 219 | try: 220 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1] 221 | except Exception as e: 222 | print (f'Unable to get model dict input channels - setting to 3, {e}') 223 | return channels 224 | 225 | @staticmethod 226 | def output_channels(model_state_dict): 227 | channels = 5 228 | try: 229 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0] 230 | except Exception as e: 231 | print (f'Unable to get model dict output channels - setting to 3, {e}') 232 | return channels 233 | 234 | def get_model(self): 235 | import platform 236 | if platform.system() == 'Darwin': 237 | return self.training_model 238 | return self.model 239 | 240 | def get_training_model(self): 241 | return self.training_model 242 | 243 | def load_model(self, path, flownet, rank=0): 244 | import torch 245 | def convert(param): 246 | if rank == -1: 247 | return { 248 | k.replace("module.", ""): v 249 | for k, v in param.items() 250 | if "module." 
in k 251 | } 252 | else: 253 | return param 254 | if rank <= 0: 255 | if torch.cuda.is_available(): 256 | flownet.load_state_dict(convert(torch.load(path)), False) 257 | else: 258 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False) -------------------------------------------------------------------------------- /pytorch/models/archived/flownet4_v001eb3x.py: -------------------------------------------------------------------------------- 1 | # Orig v001 changed to v002 main flow and signatures 2 | # Back from SiLU to LeakyReLU to test data flow 3 | # Warps moved to flownet forward 4 | # Different Tail from flownet 2lh (ConvTr 6x6, conv 1x1, ConvTr 4x4, conv 1x1) 5 | # Scale list are multiplies of 3 instead of 2 [12, 6, 3, 1] 6 | 7 | class Model: 8 | def __init__(self, status = dict(), torch = None): 9 | if torch is None: 10 | import torch 11 | Module = torch.nn.Module 12 | backwarp_tenGrid = {} 13 | 14 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 15 | return torch.nn.Sequential( 16 | torch.nn.Conv2d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=kernel_size, 20 | stride=stride, 21 | padding=padding, 22 | dilation=dilation, 23 | padding_mode = 'reflect', 24 | bias=True 25 | ), 26 | torch.nn.LeakyReLU(0.2, True) 27 | # torch.nn.SELU(inplace = True) 28 | ) 29 | 30 | def warp(tenInput, tenFlow): 31 | k = (str(tenFlow.device), str(tenFlow.size())) 32 | if k not in backwarp_tenGrid: 33 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 34 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 35 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype) 36 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / 
((tenInput.shape[2] - 1.0) / 2.0) ], 1) 37 | 38 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 39 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='reflection', align_corners=True) 40 | 41 | class Head(Module): 42 | def __init__(self): 43 | super(Head, self).__init__() 44 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1) 45 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1) 46 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1) 47 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1) 48 | self.relu = torch.nn.LeakyReLU(0.2, True) 49 | 50 | def forward(self, x, feat=False): 51 | x0 = self.cnn0(x) 52 | x = self.relu(x0) 53 | x1 = self.cnn1(x) 54 | x = self.relu(x1) 55 | x2 = self.cnn2(x) 56 | x = self.relu(x2) 57 | x3 = self.cnn3(x) 58 | if feat: 59 | return [x0, x1, x2, x3] 60 | return x3 61 | 62 | class ResConv(Module): 63 | def __init__(self, c, dilation=1): 64 | super().__init__() 65 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True) 66 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) 67 | self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True) 68 | 69 | torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu') 70 | self.conv.weight.data *= 1e-2 71 | if self.conv.bias is not None: 72 | torch.nn.init.constant_(self.conv.bias, 0) 73 | 74 | def forward(self, x): 75 | return self.relu(self.conv(x) * self.beta + x) 76 | 77 | class Flownet(Module): 78 | def __init__(self, in_planes, c=64): 79 | super().__init__() 80 | self.conv0 = torch.nn.Sequential( 81 | conv(in_planes, c//2, 3, 2, 1), 82 | conv(c//2, c, 3, 2, 1), 83 | ) 84 | self.convblock = torch.nn.Sequential( 85 | ResConv(c), 86 | ResConv(c), 87 | ResConv(c), 88 | ResConv(c), 89 | ResConv(c), 90 | ResConv(c), 91 | ResConv(c), 92 | ResConv(c), 93 | ) 94 | self.lastconv = torch.nn.Sequential( 95 | torch.nn.ConvTranspose2d(c, c, 
6, 2, 2), 96 | torch.nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0, bias=True), 97 | torch.nn.ConvTranspose2d(c, c, 4, 2, 1), 98 | torch.nn.Conv2d(c, 6, kernel_size=1, stride=1, padding=0, bias=True), 99 | ) 100 | 101 | def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1): 102 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep 103 | 104 | if flow is None: 105 | x = torch.cat((img0, img1, f0, f1, timestep), 1) 106 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) 107 | else: 108 | warped_img0 = warp(img0, flow[:, :2]) 109 | warped_img1 = warp(img0, flow[:, :2]) 110 | warped_f0 = warp(f0, flow[:, :2]) 111 | warped_f1 = warp(f1, flow[:, 2:4]) 112 | x = torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1) 113 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) 114 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. 
/ scale 115 | x = torch.cat((x, flow), 1) 116 | 117 | feat = self.conv0(x) 118 | feat = self.convblock(feat) 119 | tmp = self.lastconv(feat) 120 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) 121 | flow = tmp[:, :4] * scale 122 | mask = tmp[:, 4:5] 123 | conf = tmp[:, 5:6] 124 | return flow, mask, conf 125 | 126 | class FlownetCas(Module): 127 | def __init__(self): 128 | super().__init__() 129 | self.block0 = Flownet(7+16, c=192) 130 | self.block1 = Flownet(8+4+16, c=128) 131 | self.block2 = Flownet(8+4+16, c=96) 132 | self.block3 = Flownet(8+4+16, c=64) 133 | self.encode = Head() 134 | 135 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1): 136 | scale = [12, 6, 3, 1] if scale == [8, 4, 2, 1] else scale 137 | 138 | img0 = img0 139 | img1 = img1 140 | f0 = self.encode(img0) 141 | f1 = self.encode(img1) 142 | 143 | flow_list = [None] * 4 144 | mask_list = [None] * 4 145 | merged = [None] * 4 146 | 147 | flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0]) 148 | 149 | flow_list[0] = flow.clone() 150 | mask_list[0] = torch.sigmoid(mask.clone()) 151 | merged[0] = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0]) 152 | 153 | for iteration in range(iterations): 154 | flow_d, mask, conf = self.block1( 155 | img0, 156 | img1, 157 | f0, 158 | f1, 159 | timestep, 160 | mask, 161 | flow, 162 | scale=scale[1] 163 | ) 164 | flow = flow + flow_d 165 | 166 | flow_list[1] = flow.clone() 167 | mask_list[1] = torch.sigmoid(mask.clone()) 168 | merged[1] = warp(img0, flow[:, :2]) * mask_list[1] + warp(img1, flow[:, 2:4]) * (1 - mask_list[1]) 169 | 170 | for iteration in range(iterations): 171 | flow_d, mask, conf = self.block2( 172 | img0, 173 | img1, 174 | f0, 175 | f1, 176 | timestep, 177 | mask, 178 | flow, 179 | scale=scale[2] 180 | ) 181 | flow = flow + flow_d 182 | 183 | flow_list[2] = flow.clone() 184 | mask_list[2] = 
torch.sigmoid(mask.clone()) 185 | merged[2] = warp(img0, flow[:, :2]) * mask_list[2] + warp(img1, flow[:, 2:4]) * (1 - mask_list[2]) 186 | 187 | for iteration in range(iterations): 188 | flow_d, mask, conf = self.block3( 189 | img0, 190 | img1, 191 | f0, 192 | f1, 193 | timestep, 194 | mask, 195 | flow, 196 | scale=scale[3] 197 | ) 198 | flow = flow + flow_d 199 | 200 | flow_list[3] = flow 201 | mask_list[3] = torch.sigmoid(mask) 202 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3]) 203 | 204 | return flow_list, mask_list, merged 205 | 206 | self.model = FlownetCas 207 | self.training_model = FlownetCas 208 | 209 | @staticmethod 210 | def get_info(): 211 | info = { 212 | 'name': 'Flownet4_v001eb3x', 213 | 'file': 'flownet4_v001eb3x.py', 214 | 'ratio_support': True, 215 | 'padding': 48 216 | } 217 | return info 218 | 219 | @staticmethod 220 | def get_name(): 221 | return 'TWML_Flownet_v001eb3x' 222 | 223 | @staticmethod 224 | def input_channels(model_state_dict): 225 | channels = 3 226 | try: 227 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1] 228 | except Exception as e: 229 | print (f'Unable to get model dict input channels - setting to 3, {e}') 230 | return channels 231 | 232 | @staticmethod 233 | def output_channels(model_state_dict): 234 | channels = 5 235 | try: 236 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0] 237 | except Exception as e: 238 | print (f'Unable to get model dict output channels - setting to 3, {e}') 239 | return channels 240 | 241 | def get_model(self): 242 | import platform 243 | if platform.system() == 'Darwin': 244 | return self.training_model 245 | return self.model 246 | 247 | def get_training_model(self): 248 | return self.training_model 249 | 250 | def load_model(self, path, flownet, rank=0): 251 | import torch 252 | def convert(param): 253 | if rank == -1: 254 | return { 255 | k.replace("module.", ""): v 256 | for k, v in 
param.items() 257 | if "module." in k 258 | } 259 | else: 260 | return param 261 | if rank <= 0: 262 | if torch.cuda.is_available(): 263 | flownet.load_state_dict(convert(torch.load(path)), False) 264 | else: 265 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False) -------------------------------------------------------------------------------- /pytorch/models/flownet4_v001_baseline.py: -------------------------------------------------------------------------------- 1 | class Model: 2 | 3 | info = { 4 | 'name': 'Flownet4_v001_baseline', 5 | 'file': 'flownet4_v001_baseline.py', 6 | 'ratio_support': True 7 | } 8 | 9 | def __init__(self, status = dict(), torch = None): 10 | if torch is None: 11 | import torch 12 | Module = torch.nn.Module 13 | backwarp_tenGrid = {} 14 | 15 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 16 | return torch.nn.Sequential( 17 | torch.nn.Conv2d( 18 | in_planes, 19 | out_planes, 20 | kernel_size=kernel_size, 21 | stride=stride, 22 | padding=padding, 23 | dilation=dilation, 24 | padding_mode = 'reflect', 25 | bias=True 26 | ), 27 | torch.nn.LeakyReLU(0.2, True) 28 | # torch.nn.SELU(inplace = True) 29 | ) 30 | 31 | def warp(tenInput, tenFlow): 32 | k = (str(tenFlow.device), str(tenFlow.size())) 33 | if k not in backwarp_tenGrid: 34 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 35 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 36 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype) 37 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) 38 | 39 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 40 | return 
torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 41 | 42 | class Head(Module): 43 | def __init__(self): 44 | super(Head, self).__init__() 45 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1) 46 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1) 47 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1) 48 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1) 49 | self.relu = torch.nn.LeakyReLU(0.2, True) 50 | 51 | def forward(self, x, feat=False): 52 | x0 = self.cnn0(x) 53 | x = self.relu(x0) 54 | x1 = self.cnn1(x) 55 | x = self.relu(x1) 56 | x2 = self.cnn2(x) 57 | x = self.relu(x2) 58 | x3 = self.cnn3(x) 59 | if feat: 60 | return [x0, x1, x2, x3] 61 | return x3 62 | 63 | class ResConv(Module): 64 | def __init__(self, c, dilation=1): 65 | super().__init__() 66 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True) 67 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) 68 | self.relu = torch.nn.LeakyReLU(0.2, True) 69 | 70 | def forward(self, x): 71 | return self.relu(self.conv(x) * self.beta + x) 72 | 73 | class Flownet(Module): 74 | def __init__(self, in_planes, c=64): 75 | super().__init__() 76 | self.conv0 = torch.nn.Sequential( 77 | conv(in_planes, c//2, 3, 2, 1), 78 | conv(c//2, c, 3, 2, 1), 79 | ) 80 | self.convblock = torch.nn.Sequential( 81 | ResConv(c), 82 | ResConv(c), 83 | ResConv(c), 84 | ResConv(c), 85 | ResConv(c), 86 | ResConv(c), 87 | ResConv(c), 88 | ResConv(c), 89 | ) 90 | self.lastconv = torch.nn.Sequential( 91 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1), 92 | torch.nn.PixelShuffle(2) 93 | ) 94 | 95 | def forward(self, x, flow, scale=1): 96 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) 97 | if flow is not None: 98 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. 
/ scale 99 | x = torch.cat((x, flow), 1) 100 | feat = self.conv0(x) 101 | feat = self.convblock(feat) 102 | tmp = self.lastconv(feat) 103 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) 104 | flow = tmp[:, :4] * scale 105 | mask = tmp[:, 4:5] 106 | conf = tmp[:, 5:6] 107 | return flow, mask, conf 108 | 109 | class FlownetCas(Module): 110 | def __init__(self): 111 | super().__init__() 112 | self.block0 = Flownet(7+16, c=192) 113 | self.block1 = Flownet(8+4+16, c=128) 114 | self.block2 = Flownet(8+4+16, c=96) 115 | self.block3 = Flownet(8+4+16, c=64) 116 | self.encode = Head() 117 | 118 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1): 119 | gt = None 120 | # return self.encode(img0) 121 | img0 = img0 122 | img1 = img1 123 | f0 = self.encode(img0) 124 | f1 = self.encode(img1) 125 | 126 | if not torch.is_tensor(timestep): 127 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep 128 | else: 129 | timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) 130 | flow_list = [] 131 | merged = [] 132 | mask_list = [] 133 | conf_list = [] 134 | teacher_list = [] 135 | flow_list_teacher = [] 136 | warped_img0 = img0 137 | warped_img1 = img1 138 | flow = None 139 | loss_cons = 0 140 | stu = [self.block0, self.block1, self.block2, self.block3] 141 | flow = None 142 | 143 | # single step 144 | flow, mask, conf = stu[0](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=1) 145 | 146 | flow_list = [flow] * 5 147 | mask_list = [torch.sigmoid(mask)] * 5 148 | conf_list = [torch.sigmoid(conf)] * 5 149 | merged_student = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0]) 150 | merged = [merged_student] * 5 151 | return flow_list, mask_list, conf_list, merged 152 | 153 | for i in range(4): 154 | if flow is not None: 155 | flow_d, mask, conf = stu[i](torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1), flow, scale=scale[i]) 156 
| flow = flow + flow_d 157 | else: 158 | flow, mask, conf = stu[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=scale[i]) 159 | 160 | mask_list.append(mask) 161 | flow_list.append(flow) 162 | conf_list.append(conf) 163 | warped_img0 = warp(img0, flow[:, :2]) 164 | warped_img1 = warp(img1, flow[:, 2:4]) 165 | warped_f0 = warp(f0, flow[:, :2]) 166 | warped_f1 = warp(f1, flow[:, 2:4]) 167 | merged_student = (warped_img0, warped_img1) 168 | merged.append(merged_student) 169 | conf = torch.sigmoid(torch.cat(conf_list, 1)) 170 | conf = conf / (conf.sum(1, True) + 1e-3) 171 | if gt is not None: 172 | flow_teacher = 0 173 | mask_teacher = 0 174 | for i in range(4): 175 | flow_teacher += conf[:, i:i+1] * flow_list[i] 176 | mask_teacher += conf[:, i:i+1] * mask_list[i] 177 | warped_img0_teacher = warp(img0, flow_teacher[:, :2]) 178 | warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) 179 | mask_teacher = torch.sigmoid(mask_teacher) 180 | merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) 181 | teacher_list.append(merged_teacher) 182 | flow_list_teacher.append(flow_teacher) 183 | 184 | for i in range(4): 185 | mask_list[i] = torch.sigmoid(mask_list[i]) 186 | merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) 187 | if gt is not None: 188 | loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 1e-2).float().detach() 189 | loss_cons += (((flow_teacher.detach() - flow_list[i]) ** 2).sum(1, True) ** 0.5 * loss_mask).mean() * 0.001 190 | 191 | return flow_list, mask_list, merged 192 | 193 | self.model = FlownetCas 194 | self.training_model = FlownetCas 195 | 196 | @staticmethod 197 | def get_info(): 198 | return Model.info 199 | 200 | @staticmethod 201 | def get_name(): 202 | return Model.info.get('name') 203 | 204 | @staticmethod 205 | def input_channels(model_state_dict): 206 | channels = 3 207 | try: 208 | channels = 
model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1] 209 | except Exception as e: 210 | print (f'Unable to get model dict input channels - setting to 3, {e}') 211 | return channels 212 | 213 | @staticmethod 214 | def output_channels(model_state_dict): 215 | channels = 5 216 | try: 217 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0] 218 | except Exception as e: 219 | print (f'Unable to get model dict output channels - setting to 3, {e}') 220 | return channels 221 | 222 | def get_model(self): 223 | import platform 224 | if platform.system() == 'Darwin': 225 | return self.training_model 226 | return self.model 227 | 228 | def get_training_model(self): 229 | return self.training_model 230 | 231 | def load_model(self, path, flownet, rank=0): 232 | import torch 233 | def convert(param): 234 | if rank == -1: 235 | return { 236 | k.replace("module.", ""): v 237 | for k, v in param.items() 238 | if "module." in k 239 | } 240 | else: 241 | return param 242 | if rank <= 0: 243 | if torch.cuda.is_available(): 244 | flownet.load_state_dict(convert(torch.load(path)), False) 245 | else: 246 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False) -------------------------------------------------------------------------------- /pytorch/models/flownet4_v001_baseline_sst.py: -------------------------------------------------------------------------------- 1 | class Model: 2 | 3 | info = { 4 | 'name': 'Flownet4_v001_baseline', 5 | 'file': 'flownet4_v001_baseline.py', 6 | 'ratio_support': True 7 | } 8 | 9 | def __init__(self, status = dict(), torch = None): 10 | if torch is None: 11 | import torch 12 | Module = torch.nn.Module 13 | backwarp_tenGrid = {} 14 | 15 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 16 | return torch.nn.Sequential( 17 | torch.nn.Conv2d( 18 | in_planes, 19 | out_planes, 20 | kernel_size=kernel_size, 21 | stride=stride, 22 | padding=padding, 23 | dilation=dilation, 
24 | padding_mode = 'reflect', 25 | bias=True 26 | ), 27 | torch.nn.LeakyReLU(0.2, True) 28 | # torch.nn.SELU(inplace = True) 29 | ) 30 | 31 | def warp(tenInput, tenFlow): 32 | k = (str(tenFlow.device), str(tenFlow.size())) 33 | if k not in backwarp_tenGrid: 34 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 35 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 36 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype) 37 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) 38 | 39 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 40 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 41 | 42 | class Head(Module): 43 | def __init__(self): 44 | super(Head, self).__init__() 45 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1) 46 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1) 47 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1) 48 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1) 49 | self.relu = torch.nn.LeakyReLU(0.2, True) 50 | 51 | def forward(self, x, feat=False): 52 | x0 = self.cnn0(x) 53 | x = self.relu(x0) 54 | x1 = self.cnn1(x) 55 | x = self.relu(x1) 56 | x2 = self.cnn2(x) 57 | x = self.relu(x2) 58 | x3 = self.cnn3(x) 59 | if feat: 60 | return [x0, x1, x2, x3] 61 | return x3 62 | 63 | class ResConv(Module): 64 | def __init__(self, c, dilation=1): 65 | super().__init__() 66 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True) 67 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) 68 | self.relu = torch.nn.LeakyReLU(0.2, True) 69 | 70 | def 
forward(self, x): 71 | return self.relu(self.conv(x) * self.beta + x) 72 | 73 | class Flownet(Module): 74 | def __init__(self, in_planes, c=64): 75 | super().__init__() 76 | self.conv0 = torch.nn.Sequential( 77 | conv(in_planes, c//2, 3, 2, 1), 78 | conv(c//2, c, 3, 2, 1), 79 | ) 80 | self.convblock = torch.nn.Sequential( 81 | ResConv(c), 82 | ResConv(c), 83 | ResConv(c), 84 | ResConv(c), 85 | ResConv(c), 86 | ResConv(c), 87 | ResConv(c), 88 | ResConv(c), 89 | ) 90 | self.lastconv = torch.nn.Sequential( 91 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1), 92 | torch.nn.PixelShuffle(2) 93 | ) 94 | 95 | def forward(self, x, flow, scale=1): 96 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) 97 | if flow is not None: 98 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale 99 | x = torch.cat((x, flow), 1) 100 | feat = self.conv0(x) 101 | feat = self.convblock(feat) 102 | tmp = self.lastconv(feat) 103 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) 104 | flow = tmp[:, :4] * scale 105 | mask = tmp[:, 4:5] 106 | conf = tmp[:, 5:6] 107 | return flow, mask, conf 108 | 109 | class FlownetCas(Module): 110 | def __init__(self): 111 | super().__init__() 112 | self.block0 = Flownet(7+16, c=192) 113 | self.block1 = Flownet(8+4+16, c=128) 114 | self.block2 = Flownet(8+4+16, c=96) 115 | self.block3 = Flownet(8+4+16, c=64) 116 | self.encode = Head() 117 | 118 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1): 119 | gt = None 120 | # return self.encode(img0) 121 | img0 = img0 122 | img1 = img1 123 | f0 = self.encode(img0) 124 | f1 = self.encode(img1) 125 | 126 | if not torch.is_tensor(timestep): 127 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep 128 | else: 129 | timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) 130 | flow_list = [] 131 | merged = [] 132 | 
mask_list = [] 133 | conf_list = [] 134 | teacher_list = [] 135 | flow_list_teacher = [] 136 | warped_img0 = img0 137 | warped_img1 = img1 138 | flow = None 139 | loss_cons = 0 140 | stu = [self.block0, self.block1, self.block2, self.block3] 141 | flow = None 142 | 143 | # single step 144 | flow, mask, conf = stu[0](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=1) 145 | 146 | flow_list = [flow] * 5 147 | mask_list = [torch.sigmoid(mask)] * 5 148 | conf_list = [torch.sigmoid(conf)] * 5 149 | merged_student = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0]) 150 | merged = [merged_student] * 5 151 | return flow_list, mask_list, conf_list, merged 152 | 153 | for i in range(4): 154 | if flow is not None: 155 | flow_d, mask, conf = stu[i](torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1), flow, scale=scale[i]) 156 | flow = flow + flow_d 157 | else: 158 | flow, mask, conf = stu[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=scale[i]) 159 | 160 | mask_list.append(mask) 161 | flow_list.append(flow) 162 | conf_list.append(conf) 163 | warped_img0 = warp(img0, flow[:, :2]) 164 | warped_img1 = warp(img1, flow[:, 2:4]) 165 | warped_f0 = warp(f0, flow[:, :2]) 166 | warped_f1 = warp(f1, flow[:, 2:4]) 167 | merged_student = (warped_img0, warped_img1) 168 | merged.append(merged_student) 169 | conf = torch.sigmoid(torch.cat(conf_list, 1)) 170 | conf = conf / (conf.sum(1, True) + 1e-3) 171 | if gt is not None: 172 | flow_teacher = 0 173 | mask_teacher = 0 174 | for i in range(4): 175 | flow_teacher += conf[:, i:i+1] * flow_list[i] 176 | mask_teacher += conf[:, i:i+1] * mask_list[i] 177 | warped_img0_teacher = warp(img0, flow_teacher[:, :2]) 178 | warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) 179 | mask_teacher = torch.sigmoid(mask_teacher) 180 | merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) 181 | teacher_list.append(merged_teacher) 182 
| flow_list_teacher.append(flow_teacher) 183 | 184 | for i in range(4): 185 | mask_list[i] = torch.sigmoid(mask_list[i]) 186 | merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) 187 | if gt is not None: 188 | loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 1e-2).float().detach() 189 | loss_cons += (((flow_teacher.detach() - flow_list[i]) ** 2).sum(1, True) ** 0.5 * loss_mask).mean() * 0.001 190 | 191 | return flow_list, mask_list, merged 192 | 193 | self.model = FlownetCas 194 | self.training_model = FlownetCas 195 | 196 | @staticmethod 197 | def get_info(): 198 | return Model.info 199 | 200 | @staticmethod 201 | def get_name(): 202 | return Model.info.get('name') 203 | 204 | @staticmethod 205 | def input_channels(model_state_dict): 206 | channels = 3 207 | try: 208 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1] 209 | except Exception as e: 210 | print (f'Unable to get model dict input channels - setting to 3, {e}') 211 | return channels 212 | 213 | @staticmethod 214 | def output_channels(model_state_dict): 215 | channels = 5 216 | try: 217 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0] 218 | except Exception as e: 219 | print (f'Unable to get model dict output channels - setting to 3, {e}') 220 | return channels 221 | 222 | def get_model(self): 223 | import platform 224 | if platform.system() == 'Darwin': 225 | return self.training_model 226 | return self.model 227 | 228 | def get_training_model(self): 229 | return self.training_model 230 | 231 | def load_model(self, path, flownet, rank=0): 232 | import torch 233 | def convert(param): 234 | if rank == -1: 235 | return { 236 | k.replace("module.", ""): v 237 | for k, v in param.items() 238 | if "module." 
in k 239 | } 240 | else: 241 | return param 242 | if rank <= 0: 243 | if torch.cuda.is_available(): 244 | flownet.load_state_dict(convert(torch.load(path)), False) 245 | else: 246 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False) -------------------------------------------------------------------------------- /pytorch/models/flownet4_v001a.py: -------------------------------------------------------------------------------- 1 | # Orig v001 changed to v002 main flow and signatures 2 | # Back from SiLU to LeakyReLU to test data flow 3 | # Warps moved to flownet forward 4 | # Different Tail from flownet 2lh (ConvTr 6x6, conv 1x1, ConvTr 4x4, conv 1x1) 5 | 6 | class Model: 7 | 8 | info = { 9 | 'name': 'Flownet4_v001', 10 | 'file': 'flownet4_v001.py', 11 | 'ratio_support': True 12 | } 13 | 14 | def __init__(self, status = dict(), torch = None): 15 | if torch is None: 16 | import torch 17 | Module = torch.nn.Module 18 | 19 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 20 | return torch.nn.Sequential( 21 | torch.nn.Conv2d( 22 | in_planes, 23 | out_planes, 24 | kernel_size=kernel_size, 25 | stride=stride, 26 | padding=padding, 27 | dilation=dilation, 28 | padding_mode = 'zeros', 29 | bias=True 30 | ), 31 | torch.nn.LeakyReLU(0.2, True) 32 | # torch.nn.SELU(inplace = True) 33 | ) 34 | 35 | def warp(tenInput, tenFlow): 36 | k = (str(tenFlow.device), str(tenFlow.size())) 37 | if k not in backwarp_tenGrid: 38 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 39 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 40 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype) 41 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 
1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) 42 | 43 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 44 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 45 | 46 | class Head(Module): 47 | def __init__(self): 48 | super(Head, self).__init__() 49 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1) 50 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1) 51 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1) 52 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1) 53 | self.relu = torch.nn.LeakyReLU(0.2, True) 54 | 55 | def forward(self, x, feat=False): 56 | x0 = self.cnn0(x) 57 | x = self.relu(x0) 58 | x1 = self.cnn1(x) 59 | x = self.relu(x1) 60 | x2 = self.cnn2(x) 61 | x = self.relu(x2) 62 | x3 = self.cnn3(x) 63 | if feat: 64 | return [x0, x1, x2, x3] 65 | return x3 66 | 67 | class ResConv(Module): 68 | def __init__(self, c, dilation=1): 69 | super().__init__() 70 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'zeros', bias=True) 71 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) 72 | self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True) 73 | def forward(self, x): 74 | return self.relu(self.conv(x) * self.beta + x) 75 | 76 | class Flownet(Module): 77 | def __init__(self, in_planes, c=64): 78 | super().__init__() 79 | self.conv0 = torch.nn.Sequential( 80 | conv(in_planes, c//2, 3, 2, 1), 81 | conv(c//2, c, 3, 2, 1), 82 | ) 83 | self.convblock = torch.nn.Sequential( 84 | ResConv(c), 85 | ResConv(c), 86 | ResConv(c), 87 | ResConv(c), 88 | ResConv(c), 89 | ResConv(c), 90 | ResConv(c), 91 | ResConv(c), 92 | ) 93 | self.lastconv = torch.nn.Sequential( 94 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1), 95 | torch.nn.PixelShuffle(2) 96 | ) 97 | 98 | def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1): 99 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep 100 | 101 | if flow is 
None: 102 | x = torch.cat((img0, img1, f0, f1, timestep), 1) 103 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) 104 | else: 105 | warped_img0 = warp(img0, flow[:, :2]) 106 | warped_img1 = warp(img1, flow[:, 2:4]) 107 | warped_f0 = warp(f0, flow[:, :2]) 108 | warped_f1 = warp(f1, flow[:, 2:4]) 109 | x = torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1) 110 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) 111 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale 112 | x = torch.cat((x, flow), 1) 113 | 114 | feat = self.conv0(x) 115 | feat = self.convblock(feat) 116 | tmp = self.lastconv(feat) 117 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) 118 | flow = tmp[:, :4] * scale 119 | mask = tmp[:, 4:5] 120 | conf = tmp[:, 5:6] 121 | return flow, mask, conf 122 | 123 | class FlownetCas(Module): 124 | def __init__(self): 125 | super().__init__() 126 | self.block0 = Flownet(7+16, c=192) 127 | self.block1 = Flownet(8+4+16, c=128) 128 | self.block2 = Flownet(8+4+16, c=96) 129 | self.block3 = Flownet(8+4+16, c=64) 130 | self.encode = Head() 131 | 132 | def forward(self, img0, img1, timestep=0.5, scale=[16, 8, 4, 1], iterations=1): 133 | img0 = img0 134 | img1 = img1 135 | f0 = self.encode(img0) 136 | f1 = self.encode(img1) 137 | 138 | flow_list = [None] * 4 139 | mask_list = [None] * 4 140 | conf_list = [None] * 4 141 | merged = [None] * 4 142 | 143 | flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0]) 144 | 145 | flow_list[0] = flow.clone() 146 | conf_list[0] = torch.sigmoid(conf.clone()) 147 | mask_list[0] = torch.sigmoid(mask.clone()) 148 | merged[0] = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0]) 149 | 150 | for iteration in 
range(iterations): 151 | flow_d, mask, conf = self.block1( 152 | img0, 153 | img1, 154 | f0, 155 | f1, 156 | timestep, 157 | mask, 158 | flow, 159 | scale=scale[1] 160 | ) 161 | flow = flow + flow_d 162 | 163 | flow_list[1] = flow.clone() 164 | conf_list[1] = torch.sigmoid(conf.clone()) 165 | mask_list[1] = torch.sigmoid(mask.clone()) 166 | merged[1] = warp(img0, flow[:, :2]) * mask_list[1] + warp(img1, flow[:, 2:4]) * (1 - mask_list[1]) 167 | 168 | for iteration in range(iterations): 169 | flow_d, mask, conf = self.block2( 170 | img0, 171 | img1, 172 | f0, 173 | f1, 174 | timestep, 175 | mask, 176 | flow, 177 | scale=scale[2] 178 | ) 179 | flow = flow + flow_d 180 | 181 | flow_list[2] = flow.clone() 182 | conf_list[2] = torch.sigmoid(conf.clone()) 183 | mask_list[2] = torch.sigmoid(mask.clone()) 184 | merged[2] = warp(img0, flow[:, :2]) * mask_list[2] + warp(img1, flow[:, 2:4]) * (1 - mask_list[2]) 185 | 186 | for iteration in range(iterations): 187 | flow_d, mask, conf = self.block3( 188 | img0, 189 | img1, 190 | f0, 191 | f1, 192 | timestep, 193 | mask, 194 | flow, 195 | scale=scale[3] 196 | ) 197 | flow = flow + flow_d 198 | 199 | flow_list[3] = flow 200 | conf_list[3] = torch.sigmoid(conf) 201 | mask_list[3] = torch.sigmoid(mask) 202 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3]) 203 | 204 | return flow_list, mask_list, conf_list, merged 205 | 206 | self.model = FlownetCas 207 | self.training_model = FlownetCas 208 | 209 | @staticmethod 210 | def get_info(): 211 | return Model.info 212 | 213 | @staticmethod 214 | def get_name(): 215 | return Model.info.get('name') 216 | 217 | @staticmethod 218 | def input_channels(model_state_dict): 219 | channels = 3 220 | try: 221 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1] 222 | except Exception as e: 223 | print (f'Unable to get model dict input channels - setting to 3, {e}') 224 | return channels 225 | 226 | @staticmethod 227 | 
# Orig v001 changed to v002 main flow and signatures
# Back from SiLU to LeakyReLU to test data flow
# Warps moved to flownet forward
# Different Tail from flownet 2lh (ConvTr 6x6, conv 1x1, ConvTr 4x4, conv 1x1)

class Model:
    """Factory wrapper around the Flownet4_v001eb network classes.

    The network is defined inside ``__init__`` so that it closes over the
    ``torch`` module given by the caller. ``get_model`` and
    ``get_training_model`` return the network *class*, not an instance.
    """

    info = {
        'name': 'Flownet4_v001eb',
        'file': 'flownet4_v001eb.py',
        'ratio_support': True
    }

    def __init__(self, status = dict(), torch = None):
        """Build the network classes.

        Args:
            status: accepted for interface compatibility; not read here.
            torch: the torch module to use; imported locally when None.
        """
        if torch is None:
            import torch
        Module = torch.nn.Module
        # Cache of base sampling grids for warp(), keyed by flow device and size.
        backwarp_tenGrid = {}

        def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
            # Reflect-padded convolution followed by LeakyReLU(0.2).
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    padding_mode = 'reflect',
                    bias=True
                ),
                torch.nn.LeakyReLU(0.2, True)
                # torch.nn.SELU(inplace = True)
            )

        def warp(tenInput, tenFlow):
            # Sample tenInput at positions displaced by tenFlow (flow is in pixels).
            k = (str(tenFlow.device), str(tenFlow.size()))
            if k not in backwarp_tenGrid:
                tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
                tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
                backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
            # Convert the pixel displacement to the [-1, 1] grid units grid_sample expects.
            tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)

            g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
            return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='reflection', align_corners=True)

        class Head(Module):
            """Small encoder: 3-channel image in, 8-channel feature map out."""

            def __init__(self):
                super(Head, self).__init__()
                self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1)
                self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
                self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
                self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
                self.relu = torch.nn.LeakyReLU(0.2, True)

            def forward(self, x, feat=False):
                # With feat=True every pre-activation layer output is returned
                # instead of only the last one.
                x0 = self.cnn0(x)
                x = self.relu(x0)
                x1 = self.cnn1(x)
                x = self.relu(x1)
                x2 = self.cnn2(x)
                x = self.relu(x2)
                x3 = self.cnn3(x)
                if feat:
                    return [x0, x1, x2, x3]
                return x3

        class ResConv(Module):
            """Residual 3x3 convolution with a learnable per-channel scale ``beta``."""

            def __init__(self, c, dilation=1):
                super().__init__()
                self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True)
                self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
                self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True)

                # Kaiming init scaled down by 1e-2, zero bias.
                torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu')
                self.conv.weight.data *= 1e-2
                if self.conv.bias is not None:
                    torch.nn.init.constant_(self.conv.bias, 0)

            def forward(self, x):
                return self.relu(self.conv(x) * self.beta + x)

        class Flownet(Module):
            """One cascade stage: predicts 4 flow channels, a mask and a confidence map."""

            def __init__(self, in_planes, c=64):
                super().__init__()
                # Two stride-2 convolutions: the body runs at 1/4 of the scaled input.
                self.conv0 = torch.nn.Sequential(
                    conv(in_planes, c//2, 3, 2, 1),
                    conv(c//2, c, 3, 2, 1),
                )
                self.convblock = torch.nn.Sequential(
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                )
                # Two stride-2 transposed convolutions undo the 4x reduction; 6 output channels.
                self.lastconv = torch.nn.Sequential(
                    torch.nn.ConvTranspose2d(c, c, 6, 2, 2),
                    torch.nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0, bias=True),
                    torch.nn.ConvTranspose2d(c, c, 4, 2, 1),
                    torch.nn.Conv2d(c, 6, kernel_size=1, stride=1, padding=0, bias=True),
                )

            def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1):
                # Broadcast the scalar timestep into a one-channel map shaped like the image.
                timestep = (img0[:, :1].clone() * 0 + 1) * timestep

                if flow is None:
                    # First stage: no previous flow, feed the unwarped inputs.
                    x = torch.cat((img0, img1, f0, f1, timestep), 1)
                    x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
                else:
                    # Later stages: warp images and features with the current flow
                    # (channels 0:2 for img0, 2:4 for img1) and append the mask and flow.
                    warped_img0 = warp(img0, flow[:, :2])
                    warped_img1 = warp(img1, flow[:, 2:4])
                    warped_f0 = warp(f0, flow[:, :2])
                    warped_f1 = warp(f1, flow[:, 2:4])
                    x = torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1)
                    x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
                    # Flow magnitudes are rescaled together with the resolution.
                    flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
                    x = torch.cat((x, flow), 1)

                feat = self.conv0(x)
                feat = self.convblock(feat)
                tmp = self.lastconv(feat)
                tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
                flow = tmp[:, :4] * scale
                mask = tmp[:, 4:5]
                conf = tmp[:, 5:6]
                return flow, mask, conf

        class FlownetCas(Module):
            """Four-stage coarse-to-fine cascade of ``Flownet`` blocks."""

            def __init__(self):
                super().__init__()
                # Attribute names define the checkpoint state_dict keys.
                self.block0 = Flownet(7+16, c=192)
                self.block1 = Flownet(8+4+16, c=128)
                self.block2 = Flownet(8+4+16, c=96)
                self.block3 = Flownet(8+4+16, c=64)
                self.encode = Head()

            def forward(self, img0, img1, timestep=0.5, scale=[16, 8, 4, 1], iterations=1):
                """Return ``(flow_list, mask_list, merged)``, one entry per stage.

                ``mask_list`` holds sigmoid-activated masks; ``merged`` is the
                mask-weighted blend of the two warped inputs.
                """
                f0 = self.encode(img0)
                f1 = self.encode(img1)

                flow_list = [None] * 4
                mask_list = [None] * 4
                merged = [None] * 4

                flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0])

                # Stage 0 only records block0's output; stages 1..3 add a flow delta first.
                refiners = (None, self.block1, self.block2, self.block3)
                for level, refine in enumerate(refiners):
                    if refine is not None:
                        for _ in range(iterations):
                            flow_d, mask, conf = refine(
                                img0,
                                img1,
                                f0,
                                f1,
                                timestep,
                                mask,
                                flow,
                                scale=scale[level]
                            )
                            flow = flow + flow_d

                    # Intermediate stages store copies, the final stage the tensors themselves.
                    is_last = level == 3
                    flow_list[level] = flow if is_last else flow.clone()
                    mask_list[level] = torch.sigmoid(mask if is_last else mask.clone())
                    merged[level] = warp(img0, flow[:, :2]) * mask_list[level] + warp(img1, flow[:, 2:4]) * (1 - mask_list[level])

                return flow_list, mask_list, merged

        self.model = FlownetCas
        self.training_model = FlownetCas

    @staticmethod
    def get_info():
        """Return the class-level ``info`` dictionary."""
        return Model.info

    @staticmethod
    def get_name():
        """Return the model name stored in ``info``."""
        return Model.info.get('name')

    @staticmethod
    def input_channels(model_state_dict):
        """Read the input channel count from a state dict, falling back to 3."""
        channels = 3
        try:
            channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
        except Exception as e:
            print (f'Unable to get model dict input channels - setting to 3, {e}')
        return channels

    @staticmethod
    def output_channels(model_state_dict):
        """Read the output channel count from a state dict, falling back to 5."""
        channels = 5
        try:
            channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
        except Exception as e:
            # The message reports the value actually returned (it used to say 3).
            print (f'Unable to get model dict output channels - setting to {channels}, {e}')
        return channels

    def get_model(self):
        """Return the network class (both platform branches currently return the same class)."""
        import platform
        if platform.system() == 'Darwin':
            return self.training_model
        return self.model

    def get_training_model(self):
        """Return the network class used for training."""
        return self.training_model

    def load_model(self, path, flownet, rank=0):
        """Load weights from ``path`` into ``flownet`` (non-strict).

        Nothing is loaded when ``rank`` is greater than 0. With ``rank == -1``
        only keys containing ``"module."`` are kept, with that prefix removed.
        """
        import torch
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            else:
                return param
        if rank <= 0:
            if torch.cuda.is_available():
                flownet.load_state_dict(convert(torch.load(path)), False)
            else:
                flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
# vanilla RIFE block

class Model:
    """Factory wrapper around the stabnet4_v000 network classes.

    The network is defined inside ``__init__`` so that it closes over the
    ``torch`` module given by the caller. ``get_model`` and
    ``get_training_model`` return the network *class*, not an instance.
    """

    # NOTE(review): this file is stabnet4_v000.py but reports the v001 name and
    # file. Left unchanged because the name may be used to match checkpoints -
    # confirm and correct if it is a copy/paste leftover.
    info = {
        'name': 'Stabnet4_v001',
        'file': 'stabnet4_v001.py',
        'ratio_support': True
    }

    def __init__(self, status = dict(), torch = None):
        """Build the network classes.

        Args:
            status: accepted for interface compatibility; not read here.
            torch: the torch module to use; imported locally when None.
        """
        if torch is None:
            import torch
        Module = torch.nn.Module
        # Cache of base sampling grids for warp(), keyed by flow device and size.
        backwarp_tenGrid = {}

        def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
            # Reflect-padded convolution followed by LeakyReLU(0.2).
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    padding_mode = 'reflect',
                    bias=True
                ),
                torch.nn.LeakyReLU(0.2, True)
                # torch.nn.SELU(inplace = True)
            )

        def warp(tenInput, tenFlow):
            # Sample tenInput at positions displaced by tenFlow (flow is in pixels).
            k = (str(tenFlow.device), str(tenFlow.size()))
            if k not in backwarp_tenGrid:
                tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
                tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
                backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
            # Convert the pixel displacement to the [-1, 1] grid units grid_sample expects.
            tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)

            g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
            return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bicubic', padding_mode='border', align_corners=True)

        def normalize(tensor, min_val, max_val):
            # Linearly rescale the whole tensor to [min_val, max_val]; computed
            # in float32 and returned in the input dtype.
            src_dtype = tensor.dtype
            tensor = tensor.float()
            t_min = tensor.min()
            t_max = tensor.max()

            if t_min == t_max:
                # Constant input: return the midpoint, in the input dtype like
                # the regular path (this branch used to return float32).
                return torch.full_like(tensor, (min_val + max_val) / 2.0, dtype=src_dtype)

            tensor = ((tensor - t_min) / (t_max - t_min)) * (max_val - min_val) + min_val
            tensor = tensor.to(dtype = src_dtype)
            return tensor

        def hpass(img):
            # High-pass detail map: image minus its 5x5 Gaussian blur, clamped
            # to a narrow band around 0.5, rescaled to [0, 1] and reduced to one
            # channel by a channel-wise max.
            src_dtype = img.dtype
            img = img.float()
            def gauss_kernel(size=5, channels=3):
                kernel = torch.tensor([[1., 4., 6., 4., 1],
                                    [4., 16., 24., 16., 4.],
                                    [6., 24., 36., 24., 6.],
                                    [4., 16., 24., 16., 4.],
                                    [1., 4., 6., 4., 1.]])
                kernel /= 256.
                kernel = kernel.repeat(channels, 1, 1, 1)
                return kernel

            def conv_gauss(img, kernel):
                img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
                out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
                return out

            # The grouped convolution needs one kernel per input channel, so the
            # kernel is sized from the image instead of assuming 3 channels.
            gkernel = gauss_kernel(channels=img.shape[1])
            gkernel = gkernel.to(device=img.device, dtype=img.dtype)
            hp = img - conv_gauss(img, gkernel) + 0.5
            hp = torch.clamp(hp, 0.48, 0.52)
            hp = normalize(hp, 0, 1)
            hp = torch.max(hp, dim=1, keepdim=True).values
            hp = hp.to(dtype = src_dtype)
            return hp

        def compress(x):
            # Soft range compression: linear (scaled by tanh(1)) inside [-1, 1],
            # tanh outside, then shifted from (-1, 1) to (0, 1).
            src_dtype = x.dtype
            x = x.float()
            scale = torch.tanh(torch.tensor(1.0))
            x = torch.where(
                (x >= -1) & (x <= 1), scale * x,
                torch.tanh(x)
            )
            x = (x + 1) / 2
            x = x.to(dtype = src_dtype)
            return x

        class Head(Module):
            """Encoder: image plus its high-pass channel in, 8-channel features out."""

            def __init__(self):
                super(Head, self).__init__()
                self.encode = torch.nn.Sequential(
                    torch.nn.Conv2d(4, 32, 3, 2, 1),
                    torch.nn.LeakyReLU(0.2, True),
                    torch.nn.Conv2d(32, 32, 3, 1, 1),
                    torch.nn.LeakyReLU(0.2, True),
                    torch.nn.Conv2d(32, 32, 3, 1, 1),
                    torch.nn.LeakyReLU(0.2, True),
                    torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
                )
                self.maxdepth = 2

            def forward(self, x):
                hp = hpass(x)
                x = torch.cat((x, hp), 1)

                # Pad right/bottom up to the next multiple of maxdepth (a full
                # extra maxdepth when already aligned); cropped again below.
                n, c, h, w = x.shape
                ph = self.maxdepth - (h % self.maxdepth)
                pw = self.maxdepth - (w % self.maxdepth)
                padding = (0, pw, 0, ph)
                x = torch.nn.functional.pad(x, padding)

                return self.encode(x)[:, :, :h, :w]

        class ResConv(Module):
            """Residual 3x3 convolution with a learnable per-channel scale and PReLU."""

            def __init__(self, c, dilation=1):
                super().__init__()
                self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'zeros', bias=True)
                self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
                self.relu = torch.nn.PReLU(c, 0.2)
            def forward(self, x):
                return self.relu(self.conv(x) * self.beta + x)

        class Flownet(Module):
            """Single flow-estimation block producing a 2-channel flow."""

            def __init__(self, in_planes, c=64):
                super().__init__()
                # Two stride-2 convolutions: the body runs at 1/4 resolution.
                self.conv0 = torch.nn.Sequential(
                    conv(in_planes, c//2, 5, 2, 2),
                    conv(c//2, c, 3, 2, 1),
                )
                self.convblock = torch.nn.Sequential(
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                )
                # Transposed conv (x2) to 8 channels, then PixelShuffle (x2) down to 2 channels.
                self.lastconv = torch.nn.Sequential(
                    torch.nn.ConvTranspose2d(c, 4*2, 4, 2, 1),
                    torch.nn.PixelShuffle(2)
                )
                self.maxdepth = 4

            def forward(self, img0, img1, f0, f1, scale=1):
                n, c, h, w = img0.shape
                sh, sw = round(h * (1 / scale)), round(w * (1 / scale))

                # Right/bottom padding so the scaled size is a multiple of maxdepth.
                ph = self.maxdepth - (sh % self.maxdepth)
                pw = self.maxdepth - (sw % self.maxdepth)
                padding = (0, pw, 0, ph)

                imgs = torch.cat((img0, img1), 1)
                imgs = normalize(imgs, 0, 1) * 2 - 1
                x = torch.cat((imgs, f0, f1), 1)
                x = torch.nn.functional.interpolate(x, size=(sh, sw), mode="bicubic", align_corners=False)
                x = torch.nn.functional.pad(x, padding)

                feat = self.conv0(x)
                feat = self.convblock(feat)
                feat = self.lastconv(feat)
                # Drop the padding, then resize back to the input resolution.
                feat = torch.nn.functional.interpolate(feat[:, :, :sh, :sw], size=(h, w), mode="bilinear", align_corners=False)
                flow = feat * scale
                return flow

        class FlownetCas(Module):
            """Single-block network; only ``block0`` and the encoder are active."""

            def __init__(self):
                super().__init__()
                self.block0 = Flownet(6+16, c=64) # images + feat + timestep + lingrid
                # self.block1 = FlownetDeepDualHead(6+3+20+1+1+2+1+4, 12+20+8+1, c=128) # FlownetDeepDualHead(9+30+1+1+4+1+2, 22+30+1, c=128) # images + feat + timestep + lingrid + mask + conf + flow
                # self.block2 = FlownetDeepDualHead(6+3+20+1+1+2+1+4, 12+20+8+1, c=96) # FlownetLT(6+2+1+1+1, c=48) # None # FlownetDeepDualHead(9+30+1+1+4+1+2, 22+30+1, c=112) # images + feat + timestep + lingrid + mask + conf + flow
                # self.block3 = FlownetLT(11, c=48)
                self.encode = Head()

            def forward(self, img0, img1, timestep=0.5, scale=[16, 8, 4, 1], iterations=4, gt=None):
                """Return ``{'flow_list': [flow]}``.

                ``timestep``, ``scale``, ``iterations`` and ``gt`` are accepted
                but not used; ``block0`` always runs at scale 1.
                """
                img0 = compress(img0 * 2 - 1)
                img1 = compress(img1 * 2 - 1)

                f0 = self.encode(img0)
                f1 = self.encode(img1)

                flow = self.block0(img0, img1, f0, f1, scale=1)

                result = {
                    'flow_list': [flow]
                }

                return result

        self.model = FlownetCas
        self.training_model = FlownetCas

    @staticmethod
    def get_info():
        """Return the class-level ``info`` dictionary."""
        return Model.info

    @staticmethod
    def get_name():
        """Return the model name stored in ``info``."""
        return Model.info.get('name')

    @staticmethod
    def input_channels(model_state_dict):
        """Read the input channel count from a state dict, falling back to 3."""
        channels = 3
        try:
            channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
        except Exception as e:
            print (f'Unable to get model dict input channels - setting to 3, {e}')
        return channels

    @staticmethod
    def output_channels(model_state_dict):
        """Read the output channel count from a state dict, falling back to 5."""
        channels = 5
        try:
            channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
        except Exception as e:
            # The message reports the value actually returned (it used to say 3).
            print (f'Unable to get model dict output channels - setting to {channels}, {e}')
        return channels

    def get_model(self):
        """Return the network class (both platform branches currently return the same class)."""
        import platform
        if platform.system() == 'Darwin':
            return self.training_model
        return self.model

    def get_training_model(self):
        """Return the network class used for training."""
        return self.training_model

    def load_model(self, path, flownet, rank=0):
        """Load weights from ``path`` into ``flownet`` (non-strict).

        Nothing is loaded when ``rank`` is greater than 0. With ``rank == -1``
        only keys containing ``"module."`` are kept, with that prefix removed.
        """
        import torch
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            else:
                return param
        if rank <= 0:
            if torch.cuda.is_available():
                flownet.load_state_dict(convert(torch.load(path)), False)
            else:
                flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
# vanilla RIFE block + PReLU

class Model:
    """Factory wrapper around the Stabnet4_v001 network classes.

    The network is defined inside ``__init__`` so that it closes over the
    ``torch`` module given by the caller. ``get_model`` and
    ``get_training_model`` return the network *class*, not an instance.
    """

    info = {
        'name': 'Stabnet4_v001',
        'file': 'stabnet4_v001.py',
        'ratio_support': True
    }

    def __init__(self, status = dict(), torch = None):
        """Build the network classes.

        Args:
            status: accepted for interface compatibility; not read here.
            torch: the torch module to use; imported locally when None.
        """
        if torch is None:
            import torch
        Module = torch.nn.Module
        # Cache of base sampling grids for warp(), keyed by flow device and size.
        backwarp_tenGrid = {}

        def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
            # Reflect-padded convolution followed by a per-channel PReLU (init 0.2).
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    padding_mode = 'reflect',
                    bias=True
                ),
                torch.nn.PReLU(out_planes, 0.2)
            )

        def warp(tenInput, tenFlow):
            # Sample tenInput at positions displaced by tenFlow (flow is in pixels).
            k = (str(tenFlow.device), str(tenFlow.size()))
            if k not in backwarp_tenGrid:
                tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
                tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
                backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
            # Convert the pixel displacement to the [-1, 1] grid units grid_sample expects.
            tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)

            g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
            return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bicubic', padding_mode='border', align_corners=True)

        def normalize(tensor, min_val, max_val):
            # Linearly rescale the whole tensor to [min_val, max_val]; computed
            # in float32 and returned in the input dtype.
            src_dtype = tensor.dtype
            tensor = tensor.float()
            t_min = tensor.min()
            t_max = tensor.max()

            if t_min == t_max:
                # Constant input: return the midpoint, in the input dtype like
                # the regular path (this branch used to return float32).
                return torch.full_like(tensor, (min_val + max_val) / 2.0, dtype=src_dtype)

            tensor = ((tensor - t_min) / (t_max - t_min)) * (max_val - min_val) + min_val
            tensor = tensor.to(dtype = src_dtype)
            return tensor

        def hpass(img):
            # High-pass detail map: image minus its 5x5 Gaussian blur, clamped
            # to a narrow band around 0.5, rescaled to [0, 1] and reduced to one
            # channel by a channel-wise max.
            src_dtype = img.dtype
            img = img.float()
            def gauss_kernel(size=5, channels=3):
                kernel = torch.tensor([[1., 4., 6., 4., 1],
                                    [4., 16., 24., 16., 4.],
                                    [6., 24., 36., 24., 6.],
                                    [4., 16., 24., 16., 4.],
                                    [1., 4., 6., 4., 1.]])
                kernel /= 256.
                kernel = kernel.repeat(channels, 1, 1, 1)
                return kernel

            def conv_gauss(img, kernel):
                img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
                out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
                return out

            # The grouped convolution needs one kernel per input channel, so the
            # kernel is sized from the image instead of assuming 3 channels.
            gkernel = gauss_kernel(channels=img.shape[1])
            gkernel = gkernel.to(device=img.device, dtype=img.dtype)
            hp = img - conv_gauss(img, gkernel) + 0.5
            hp = torch.clamp(hp, 0.48, 0.52)
            hp = normalize(hp, 0, 1)
            hp = torch.max(hp, dim=1, keepdim=True).values
            hp = hp.to(dtype = src_dtype)
            return hp

        def compress(x):
            # Soft range compression: linear (scaled by tanh(1)) inside [-1, 1],
            # tanh outside, then shifted from (-1, 1) to (0, 1).
            src_dtype = x.dtype
            x = x.float()
            scale = torch.tanh(torch.tensor(1.0))
            x = torch.where(
                (x >= -1) & (x <= 1), scale * x,
                torch.tanh(x)
            )
            x = (x + 1) / 2
            x = x.to(dtype = src_dtype)
            return x

        class Head(Module):
            """Encoder: image plus its high-pass channel in, 8-channel features out."""

            def __init__(self):
                super(Head, self).__init__()
                self.encode = torch.nn.Sequential(
                    torch.nn.Conv2d(4, 32, 3, 2, 1),
                    torch.nn.PReLU(32, 0.2),
                    torch.nn.Conv2d(32, 32, 3, 1, 1),
                    torch.nn.PReLU(32, 0.2),
                    torch.nn.Conv2d(32, 32, 3, 1, 1),
                    torch.nn.PReLU(32, 0.2),
                    torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
                )
                self.maxdepth = 2

            def forward(self, x):
                hp = hpass(x)
                x = torch.cat((x, hp), 1)

                # Pad right/bottom up to the next multiple of maxdepth (a full
                # extra maxdepth when already aligned); cropped again below.
                n, c, h, w = x.shape
                ph = self.maxdepth - (h % self.maxdepth)
                pw = self.maxdepth - (w % self.maxdepth)
                padding = (0, pw, 0, ph)
                x = torch.nn.functional.pad(x, padding)

                return self.encode(x)[:, :, :h, :w]

        class ResConv(Module):
            """Residual 3x3 convolution with a learnable per-channel scale and LeakyReLU."""

            def __init__(self, c, dilation=1):
                super().__init__()
                self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'zeros', bias=True)
                self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
                self.relu = torch.nn.LeakyReLU(0.2, True)
            def forward(self, x):
                return self.relu(self.conv(x) * self.beta + x)

        class Flownet(Module):
            """Single flow-estimation block producing a 2-channel flow."""

            def __init__(self, in_planes, c=64):
                super().__init__()
                # Two stride-2 convolutions: the body runs at 1/4 resolution.
                self.conv0 = torch.nn.Sequential(
                    conv(in_planes, c//2, 5, 2, 2),
                    conv(c//2, c, 3, 2, 1),
                )
                self.convblock = torch.nn.Sequential(
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                    ResConv(c),
                )
                # Transposed conv (x2) to 8 channels, then PixelShuffle (x2) down to 2 channels.
                self.lastconv = torch.nn.Sequential(
                    torch.nn.ConvTranspose2d(c, 4*2, 4, 2, 1),
                    torch.nn.PixelShuffle(2)
                )
                self.maxdepth = 4

            def forward(self, img0, img1, f0, f1, scale=1):
                n, c, h, w = img0.shape
                sh, sw = round(h * (1 / scale)), round(w * (1 / scale))

                # Right/bottom padding so the scaled size is a multiple of maxdepth.
                ph = self.maxdepth - (sh % self.maxdepth)
                pw = self.maxdepth - (sw % self.maxdepth)
                padding = (0, pw, 0, ph)

                imgs = torch.cat((img0, img1), 1)
                imgs = normalize(imgs, 0, 1) * 2 - 1
                x = torch.cat((imgs, f0, f1), 1)
                x = torch.nn.functional.interpolate(x, size=(sh, sw), mode="bicubic", align_corners=False)
                x = torch.nn.functional.pad(x, padding)

                feat = self.conv0(x)
                feat = self.convblock(feat)
                feat = self.lastconv(feat)
                # Drop the padding, then resize back to the input resolution.
                feat = torch.nn.functional.interpolate(feat[:, :, :sh, :sw], size=(h, w), mode="bilinear", align_corners=False)
                flow = feat * scale
                return flow

        class FlownetCas(Module):
            """Single-block network; only ``block0`` and the encoder are active."""

            def __init__(self):
                super().__init__()
                self.block0 = Flownet(6+16, c=64) # images + feat + timestep + lingrid
                # self.block1 = FlownetDeepDualHead(6+3+20+1+1+2+1+4, 12+20+8+1, c=128) # FlownetDeepDualHead(9+30+1+1+4+1+2, 22+30+1, c=128) # images + feat + timestep + lingrid + mask + conf + flow
                # self.block2 = FlownetDeepDualHead(6+3+20+1+1+2+1+4, 12+20+8+1, c=96) # FlownetLT(6+2+1+1+1, c=48) # None # FlownetDeepDualHead(9+30+1+1+4+1+2, 22+30+1, c=112) # images + feat + timestep + lingrid + mask + conf + flow
                # self.block3 = FlownetLT(11, c=48)
                self.encode = Head()

            def forward(self, img0, img1, timestep=0.5, scale=[16, 8, 4, 1], iterations=4, gt=None):
                """Return ``{'flow_list': [flow]}``.

                ``timestep``, ``scale``, ``iterations`` and ``gt`` are accepted
                but not used; ``block0`` always runs at scale 1.
                """
                img0 = compress(img0 * 2 - 1)
                img1 = compress(img1 * 2 - 1)

                f0 = self.encode(img0)
                f1 = self.encode(img1)

                flow = self.block0(img0, img1, f0, f1, scale=1)

                result = {
                    'flow_list': [flow]
                }

                return result

        self.model = FlownetCas
        self.training_model = FlownetCas

    @staticmethod
    def get_info():
        """Return the class-level ``info`` dictionary."""
        return Model.info

    @staticmethod
    def get_name():
        """Return the model name stored in ``info``."""
        return Model.info.get('name')

    @staticmethod
    def input_channels(model_state_dict):
        """Read the input channel count from a state dict, falling back to 3."""
        channels = 3
        try:
            channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
        except Exception as e:
            print (f'Unable to get model dict input channels - setting to 3, {e}')
        return channels

    @staticmethod
    def output_channels(model_state_dict):
        """Read the output channel count from a state dict, falling back to 5."""
        channels = 5
        try:
            channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
        except Exception as e:
            # The message reports the value actually returned (it used to say 3).
            print (f'Unable to get model dict output channels - setting to {channels}, {e}')
        return channels

    def get_model(self):
        """Return the network class (both platform branches currently return the same class)."""
        import platform
        if platform.system() == 'Darwin':
            return self.training_model
        return self.model

    def get_training_model(self):
        """Return the network class used for training."""
        return self.training_model

    def load_model(self, path, flownet, rank=0):
        """Load weights from ``path`` into ``flownet`` (non-strict).

        Nothing is loaded when ``rank`` is greater than 0. With ``rank == -1``
        only keys containing ``"module."`` are kept, with that prefix removed.
        """
        import torch
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            else:
                return param
        if rank <= 0:
            if torch.cuda.is_available():
                flownet.load_state_dict(convert(torch.load(path)), False)
            else:
                flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
def main():
    """Sort EXR sequences into per-sequence sub-folders.

    For every ``*.exr`` entry in the given folder, the text before the first
    ``.`` is taken as the sequence name. A folder with that name is created
    inside the source folder and every entry whose name starts with
    ``<name>.`` (whatever its extension) is moved into it.

    The command line takes a single positional argument: the folder path.
    """
    # Local import: shutil is only needed here and is not among the module imports.
    import shutil

    parser = argparse.ArgumentParser(description='Move exrs to corresponding folders')

    # Required argument
    parser.add_argument('path', type=str, help='Path to folder with exrs')
    args = parser.parse_args()

    entries = os.listdir(args.path)
    sequence_names = {
        entry.split('.')[0]
        for entry in entries
        if entry.endswith('.exr')
    }
    # A file literally called ".exr" yields an empty name, which would target
    # the source folder itself - skip it.
    sequence_names.discard('')

    # Filesystem calls replace the former shell `mkdir -p` / `mv` commands so
    # that spaces and shell metacharacters in paths are handled safely.
    for name in sorted(sequence_names):
        target_dir = os.path.join(args.path, name)
        os.makedirs(target_dir, exist_ok=True)

        prefix = f'{name}.'
        for entry in entries:
            if entry.startswith(prefix):
                shutil.move(
                    os.path.join(args.path, entry),
                    os.path.join(target_dir, entry)
                )
# Suffix of the scratch file written next to each source before it is swapped in.
# It ends in ".exr" because the output format is selected from the file name.
_TMP_SUFFIX = '.recompress_tmp.exr'


def read_image_file(file_path, header_only = False):
    """Read an image (or only its header) through OpenImageIO.

    Returns:
        dict with ``'spec'`` (the image spec, or None if the file could not be
        opened) and ``'image_data'`` (the pixel array with its first two axes
        swapped, or None when ``header_only`` is set or reading failed).
    """
    result = {'spec': None, 'image_data': None}
    inp = oiio.ImageInput.open(file_path)
    if inp :
        try:
            spec = inp.spec()
            result['spec'] = spec
            if not header_only:
                channels = spec.nchannels
                pixels = inp.read_image(0, 0, 0, channels)
                # A failed read yields no array; leave image_data as None so
                # the caller can report it instead of crashing on .transpose.
                if pixels is not None:
                    result['image_data'] = pixels.transpose(1, 0, 2)
        finally:
            # Always release the input, even if reading raised.
            inp.close()
    return result

def find_folders_with_exr(path):
    """
    Find all folders under the given path that contain .exr files.

    Folders whose path ends with 'preview' or 'eval' are not reported
    (their sub-folders are still visited).

    Parameters:
    path (str): The root directory to start the search from.

    Returns:
    set: The directories containing at least one .exr file.
    """
    directories_with_exr = set()

    # Walk through all directories and files in the given path
    for root, dirs, files in os.walk(path):
        if root.endswith('preview') or root.endswith('eval'):
            continue
        if any(file.endswith('.exr') for file in files):
            directories_with_exr.add(root)

    return directories_with_exr

def clear_lines(n=2):
    """Clears a specified number of lines in the terminal."""
    CURSOR_UP_ONE = '\x1b[1A'
    ERASE_LINE = '\x1b[2K'
    for _ in range(n):
        sys.stdout.write(CURSOR_UP_ONE)
        sys.stdout.write(ERASE_LINE)

def _recompress_file(exr_file_path):
    """Rewrite one EXR as half-float PIZ; raises RuntimeError on any failure.

    The new image is written to a sibling scratch file and only swapped over
    the original once it has been written and closed, so a failed write leaves
    the source untouched.
    """
    result = read_image_file(exr_file_path)
    image_data = result['image_data']
    spec = result['spec']

    if image_data is None or spec is None:
        raise RuntimeError(f"Could not read image or header from {exr_file_path}")

    # read_image_file swapped the first two axes; swap them back for writing
    image_data_for_write = image_data.transpose(1, 0, 2)

    # Clone the spec and modify data format and compression only
    new_spec = oiio.ImageSpec(spec)
    new_spec.set_format(oiio.TypeDesc.TypeHalf)  # ensure half-precision
    new_spec.attribute("compression", "piz")

    tmp_path = os.path.splitext(exr_file_path)[0] + _TMP_SUFFIX
    try:
        out = oiio.ImageOutput.create(tmp_path)
        if not out:
            raise RuntimeError(f"Could not create output for {exr_file_path}")

        if not out.open(tmp_path, new_spec):
            raise RuntimeError(f"Failed to open output file {exr_file_path}")

        try:
            if not out.write_image(image_data_for_write):
                raise RuntimeError(f"Failed to write image to {exr_file_path}")
        finally:
            out.close()

        os.replace(tmp_path, exr_file_path)
    finally:
        # Only present if something failed before the swap.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def main():
    """Recompress every EXR below ``dataset_path`` in place to half-float PIZ.

    Failures are reported per file and do not stop the run.
    """
    parser = argparse.ArgumentParser(description='In-place recompress exrs to half PIZ (WARNING - DESTRUCTIVE!)')

    # Required argument
    parser.add_argument('dataset_path', type=str, help='Path to folders with exrs')
    args = parser.parse_args()

    folders_with_exr = find_folders_with_exr(args.dataset_path)

    exr_files = []

    for folder_path in sorted(folders_with_exr):
        folder_exr_files = sorted(
            os.path.join(folder_path, file)
            for file in os.listdir(folder_path)
            # Skip scratch files left behind by an interrupted earlier run.
            if file.endswith('.exr') and not file.endswith(_TMP_SUFFIX)
        )
        exr_files.extend(folder_exr_files)

    for idx, exr_file_path in enumerate(exr_files):
        # clear_lines(1)
        print (f'\r{" "*120}', end='')
        print (f'\rFile [{idx+1} / {len(exr_files)}], {os.path.basename(exr_file_path)}', end='')
        try:
            _recompress_file(exr_file_path)
        except Exception as e:
            print (f'\nError recompressing {exr_file_path}: {e}')
    print ('')
def _build_key_mapping_01(num_blocks=4):
    """Build the old-key -> new-key table for the 'encode01' sub-modules.

    For every block index below num_blocks, the numerically indexed entries
    0, 2, 4 and 6 of 'encode01' are mapped to the named sub-modules
    downconv01, conv01, conv02 and upsample01, for both 'weight' and 'bias'.
    """
    layer_names = ((0, 'downconv01'), (2, 'conv01'), (4, 'conv02'), (6, 'upsample01'))
    mapping = {}
    for block in range(num_blocks):
        prefix = f'block{block}.encode01'
        for index, layer in layer_names:
            for param in ('weight', 'bias'):
                mapping[f'{prefix}.{index}.{param}'] = f'{prefix}.{layer}.conv1.{param}'
    return mapping

# Old state-dict key -> new state-dict key, for block0 .. block3.
key_mapping_01 = _build_key_mapping_01()

def rename_state_dict_keys(state_dict, key_mapping):
    """
    Return a new dict with the keys of state_dict renamed through key_mapping.

    Keys without an entry in key_mapping are kept unchanged. Values are
    carried over as-is (no copy) and the original iteration order is kept.

    Raises:
    ValueError: when two source keys end up with the same destination key,
                which would silently drop one of the values.
    """
    renamed = {}
    for key, value in state_dict.items():
        new_key = key_mapping.get(key, key)
        if new_key in renamed:
            raise ValueError(f'Key collision while renaming: "{key}" -> "{new_key}"')
        renamed[new_key] = value
    return renamed

def main():
    """
    Command-line entry point: load a checkpoint, rename the keys of its
    'flownet_state_dict' entry through key_mapping_01 and save the result.
    All other checkpoint entries are written back untouched.
    """
    parser = argparse.ArgumentParser(description='Rename keys of the flownet state dict stored in a checkpoint')

    # Required argument
    parser.add_argument('source', type=str, help='Path to source state dict')
    parser.add_argument('dest', type=str, help='Path to dest state dict')
    args = parser.parse_args()

    checkpoint = torch.load(args.source)
    src_state_dict = checkpoint.get('flownet_state_dict')
    if src_state_dict is None:
        # Fail with a clear message instead of a TypeError on iterating None.
        sys.exit(f'No "flownet_state_dict" entry found in {args.source}')

    checkpoint['flownet_state_dict'] = rename_state_dict_keys(src_state_dict, key_mapping_01)
    torch.save(checkpoint, args.dest)

if __name__ == "__main__":
    main()
def read_image_file(file_path, header_only = False):
    """
    Read an image (or only its header) through OpenImageIO.

    Parameters:
    file_path (str): Path of the image to open.
    header_only (bool): When True only the ImageSpec is fetched, no pixels are read.

    Returns:
    dict: {'spec': ImageSpec or None, 'image_data': array or None}.
          Both entries stay None when the file cannot be opened.
    """
    result = {'spec': None, 'image_data': None}
    inp = oiio.ImageInput.open(file_path)
    if not inp:
        return result
    try:
        spec = inp.spec()
        result['spec'] = spec
        if not header_only:
            result['image_data'] = inp.read_image(0, 0, 0, spec.nchannels)
    finally:
        # Always release the input, even if the read above raised.
        inp.close()
    return result

def write_image_file(file_path, image_data, image_spec):
    """
    Write image_data to file_path using the given ImageSpec.

    Raises:
    RuntimeError: when the output cannot be created, opened or written.
    """
    out = oiio.ImageOutput.create(file_path)
    if not out:
        raise RuntimeError(f'Could not create output for {file_path}')
    if not out.open(file_path, image_spec):
        raise RuntimeError(f'Failed to open output file {file_path}')
    try:
        if not out.write_image(image_data):
            raise RuntimeError(f'Failed to write image to {file_path}')
    finally:
        out.close()

def find_folders_with_exr(path):
    """
    Find all folders under the given path that contain .exr files.

    Folders whose path ends with 'preview' or 'eval' are not reported. Their
    subfolders are still visited, because the walk itself is not pruned.

    Parameters:
    path (str): The root directory to start the search from.

    Returns:
    set: Paths of the directories that directly contain at least one .exr file.
    """
    directories_with_exr = set()

    # Walk through all directories and files in the given path
    for root, dirs, files in os.walk(path):
        if root.endswith(('preview', 'eval')):
            continue
        if any(file.endswith('.exr') for file in files):
            directories_with_exr.add(root)

    return directories_with_exr

def clear_lines(n=2):
    """Clears a specified number of lines in the terminal."""
    CURSOR_UP_ONE = '\x1b[1A'
    ERASE_LINE = '\x1b[2K'
    for _ in range(n):
        sys.stdout.write(CURSOR_UP_ONE)
        sys.stdout.write(ERASE_LINE)

def resize_image(tensor, new_h, new_w):
    """
    Resize a tensor of shape [h, w, c] to [new_h, new_w, c] using bicubic
    interpolation.

    Parameters:
    tensor (torch.Tensor): The input tensor with shape [h, w, c].
    new_h (int): Target height.
    new_w (int): Target width.

    Returns:
    torch.Tensor: The resized tensor with shape [new_h, new_w, c].
    """
    # Adjust tensor shape to [n, c, h, w]
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)

    # Resize
    resized_tensor = torch.nn.functional.interpolate(tensor, size=(new_h, new_w), mode='bicubic', align_corners=False)

    # Adjust tensor shape back to [h, w, c]
    resized_tensor = resized_tensor.squeeze(0).permute(1, 2, 0)

    return resized_tensor

def aspect_preserving_width(h, w, new_h):
    """
    Width that keeps the w / h aspect ratio at height new_h.

    The result is rounded to the nearest integer (never below 1). Plain
    truncation of the float product can end up one pixel short.
    """
    return max(1, round(new_h * w / h))

def _select_device():
    """Pick CUDA when available, then MPS, otherwise fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    mps = getattr(torch.backends, 'mps', None)
    if mps is not None and mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

def halve(exr_file_path, new_h):
    """
    Resize one EXR file in place to height new_h, keeping the aspect ratio,
    and store it as half-float.

    Parameters:
    exr_file_path (str): File to resize; it is overwritten.
    new_h (int): Target height in pixels.

    Raises:
    RuntimeError: when the file cannot be read or written.
    """
    device = _select_device()
    exr_data = read_image_file(exr_file_path)
    spec = exr_data['spec']
    img0 = exr_data['image_data']
    if spec is None or img0 is None:
        raise RuntimeError(f'Could not read image or header from {exr_file_path}')

    # assumes read_image() returned a (height, width, channels) array,
    # which is also what resize_image() expects - TODO confirm
    h = img0.shape[0]
    w = img0.shape[1]
    new_w = aspect_preserving_width(h, w, new_h)

    img0 = torch.from_numpy(img0)
    img0 = img0.to(device = device, dtype = torch.float32)
    img0 = resize_image(img0, new_h, new_w)
    img0 = img0.to(dtype=torch.half)

    # Keep the source metadata; update only the resolution and pixel format.
    # NOTE(review): the display window is reset to the resized data window.
    new_spec = oiio.ImageSpec(spec)
    new_spec.width = new_w
    new_spec.height = new_h
    new_spec.full_width = new_w
    new_spec.full_height = new_h
    new_spec.set_format(oiio.TypeDesc('half'))

    write_image_file(exr_file_path, img0.cpu().detach().numpy(), new_spec)
    del img0

def main():
    """
    Command-line entry point: resize every .exr found under dataset_path to
    the requested height, overwriting the files in place.

    A failure on one file is reported and processing continues with the next.
    """
    parser = argparse.ArgumentParser(description='In-place resize exrs to the given height (WARNING - DESTRUCTIVE!)')

    # Required argument
    parser.add_argument('h', type=int, help='New height')
    parser.add_argument('dataset_path', type=str, help='Path to folders with exrs')
    args = parser.parse_args()

    exr_files = []
    for folder_path in sorted(find_folders_with_exr(args.dataset_path)):
        exr_files.extend(sorted(
            os.path.join(folder_path, file)
            for file in os.listdir(folder_path)
            if file.endswith('.exr')
        ))

    total = len(exr_files)
    # True only when the line above the cursor is a progress line of ours.
    # Otherwise clear_lines() would wipe an error message or, on the first
    # iteration, whatever was on the terminal before the script started.
    erase_previous = False
    for idx, exr_file_path in enumerate(exr_files):
        if erase_previous:
            clear_lines(1)
        print (f'\rFile [{idx+1} / {total}], {os.path.basename(exr_file_path)}')
        erase_previous = True
        try:
            halve(exr_file_path, args.h)
        except Exception as e:
            print (f'\n\nError halving {exr_file_path}: {e}')
            erase_previous = False
    print ('')

if __name__ == "__main__":
    main()
def find_and_import_model(models_dir='models', base_name=None, model_name=None, model_file=None):
    """
    Dynamically import a model module from the 'models' package and return
    its 'Model' attribute.

    Selection order:
    1. model_file - the module named after this file name is imported
       directly, without looking at the models directory.
    2. model_name - the file '<model_name lower-cased>.py' is looked up in
       the models directory.
    3. otherwise the '.py' file that sorts last by name in the models
       directory is used.

    NOTE(review): models_dir is only used to list candidate files; the import
    itself always goes through the 'models' package.

    :param models_dir: Path to the models directory, relative to this file.
    :param base_name: Only used in the "not found" message.
    :param model_name: Specific name/version of the model (optional).
    :param model_file: File name of the model module, e.g. 'foo.py' (optional).
    :return: Imported Model object, or None if the directory or model is not found.
    """

    import os
    import importlib

    def import_model(module_name):
        module = importlib.import_module(f"models.{module_name}")
        return getattr(module, 'Model')

    if model_file:
        # Strip directory and extension to get the module name
        module_name = os.path.splitext(os.path.basename(model_file))[0]
        print (module_name)
        return import_model(module_name)

    # Resolve the absolute path of the models directory
    models_abs_path = os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            models_dir
        )
    )

    # List all files in the models directory
    try:
        files = os.listdir(models_abs_path)
    except FileNotFoundError:
        print(f"Directory not found: {models_abs_path}")
        return None

    if model_name:
        # Look for a specific model version
        filtered_files = [f for f in files if f == f"{model_name.lower()}.py"]
    else:
        # Take the module that sorts last by file name. The slice keeps
        # filtered_files defined (and empty) when there are no candidates.
        filtered_files = sorted((f for f in files if f.endswith('.py')), reverse=True)[:1]

    if not filtered_files:
        print(f"Model not found: {base_name or model_name}")
        return None

    # Remove '.py' from filename to get module name
    return import_model(filtered_files[0][:-3])

def _parse_frame_size(frame_size, default=(1716, 4096)):
    """Turn a 'WIDTHxHEIGHT' string into (height, width); default when empty."""
    if not frame_size:
        return default
    w, h = frame_size.lower().split('x')
    return int(h), int(w)

def _synchronize(device):
    """Wait for queued work on the device so wall-clock timing is meaningful."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    elif device.type == 'mps' and hasattr(torch, 'mps'):
        torch.mps.synchronize()

def main():
    """
    Command-line entry point: run the selected model in an endless loop on a
    random frame of the requested size and print the measured frames per
    second. Stop with Ctrl+C.
    """
    parser = argparse.ArgumentParser(description='Speed test.')
    parser.add_argument('--model', type=str, default=None, help='Model name (default: Flownet4_v001_baseline.py)')
    parser.add_argument('--frame_size', type=str, default=None, help='Frame size (default: 4096x1716)')
    parser.add_argument('--device', type=int, default=0, help='Graphics card index (default: 0)')

    args = parser.parse_args()

    if platform.system() == 'Darwin':
        device = torch.device('mps')
    else:
        # Honour --device; the default index 0 matches the previous behaviour.
        device = torch.device(f'cuda:{args.device}')

    h, w = _parse_frame_size(args.frame_size)

    shape = (1, 3, h, w)
    img = torch.randn(shape).to(device)

    model = args.model if args.model else 'Flownet4_v001_baseline'
    Net = find_and_import_model(model_name=model)
    if Net is None:
        # find_and_import_model() has already printed the reason.
        sys.exit(1)

    print ('Net info:')
    pprint (Net.get_info())

    net = Net().get_training_model()().to(device)

    import signal
    def graceful_exit(signum, frame):
        # End the '\r'-style progress line before leaving.
        print('\n')
        sys.exit(0)
    signal.signal(signal.SIGINT, graceful_exit)

    with torch.no_grad():
        while True:
            start_time = time.perf_counter()
            net(
                img,
                img,
                scale=[8, 4, 2, 1],
                iterations = 1
            )
            # Device execution is asynchronous; wait before stopping the clock.
            _synchronize(device)
            elapsed_time = time.perf_counter() - start_time
            fps = 1 / elapsed_time if elapsed_time > 0 else 0.0
            print (f'\rRunning at {fps:.2f} fps for {shape[3]}x{shape[2]}', end='')

if __name__ == "__main__":
    main()
0 8 | 0 9 | 1 10 | 4 11 | 0 12 | False 13 | False 14 | 15 | 1 16 | 0 17 | False 18 | 19 | 20 | 21 | 2 22 | False 23 | 0 24 | False 25 | 26 | 1 27 | 0 28 | 0 29 | 1 30 | 31 | 2 32 | 0 33 | True 34 | 1 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | constant 48 | 100 49 | 1 50 | 2 51 | 52 | 53 | 1 54 | 100 55 | 0.250000 56 | 0 57 | -0.250000 58 | 0 59 | hermite 60 | linear 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | linear 69 | 1 70 | 1 71 | 2 72 | 73 | 74 | 1 75 | 31 76 | 0.250000 77 | 0.250000 78 | -0.250000 79 | -0.250000 80 | hermite 81 | quadratic 82 | False 83 | False 84 | manual 85 | 86 | 87 | 88 | 89 | 90 | 91 | linear 92 | 11 93 | 6 94 | 2 95 | 96 | 97 | 1 98 | 11 99 | 0.250000 100 | 0.250000 101 | -0.250000 102 | -0.250000 103 | hermite 104 | cubic 105 | False 106 | False 107 | manual 108 | 109 | 110 | 21 111 | 34.7803268 112 | 3 113 | 3.603181 114 | -6.666667 115 | -8.007070 116 | hermite 117 | cubic 118 | smooth 119 | 120 | 121 | 30 122 | 45.8307533 123 | 2.333333 124 | 3.677869 125 | -3 126 | -4.728689 127 | hermite 128 | cubic 129 | smooth 130 | 131 | 132 | 37 133 | 60 134 | 10.666667 135 | 14.571893 136 | -2.333333 137 | -3.187602 138 | hermite 139 | cubic 140 | smooth 141 | 142 | 143 | 69 144 | 99.1092377 145 | 5.333333 146 | 6.500000 147 | -10.666667 148 | -13 149 | hermite 150 | cubic 151 | smooth 152 | 153 | 154 | 85 155 | 118.5 156 | 0.250000 157 | 0.320622 158 | -0.250000 159 | -0.320622 160 | hermite 161 | cubic 162 | smooth 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 1 214 | 30 215 | 1 216 | 0 217 | 1 218 | False 219 | 0 220 | 221 | 222 | linear 223 | 86 224 | 2 225 | 2 226 | 227 | 228 | 1 229 | 31 230 | 0.250000 231 | 0.250000 232 | -0.250000 233 | -0.250000 234 | bezier 
235 | cubic 236 | False 237 | False 238 | manual 239 | 240 | 241 | 86 242 | 116 243 | 0.250000 244 | 0.250000 245 | -0.250000 246 | -0.250000 247 | bezier 248 | cubic 249 | False 250 | False 251 | manual 252 | 253 | 254 | 255 | 256 | 0 257 | 1 258 | 259 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Resolves the directory where the script is located 4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")") 5 | 6 | # Change to the script directory 7 | cd "$SCRIPT_DIR" 8 | 9 | # Define the path to the Python executable and the Python script 10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python" 11 | PYTHON_SCRIPT="./pytorch/flameTimewarpML_train.py" 12 | 13 | # Run the Python script with all arguments passed to this shell script 14 | $PYTHON_CMD $PYTHON_SCRIPT "$@" 15 | -------------------------------------------------------------------------------- /train_stab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Resolves the directory where the script is located 4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")") 5 | 6 | # Change to the script directory 7 | cd "$SCRIPT_DIR" 8 | 9 | # Define the path to the Python executable and the Python script 10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python" 11 | PYTHON_SCRIPT="./pytorch/flameStabML_train.py" 12 | 13 | # Run the Python script with all arguments passed to this shell script 14 | $PYTHON_CMD $PYTHON_SCRIPT "$@" 15 | -------------------------------------------------------------------------------- /validate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Resolves the directory where the script is located 4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")") 5 | 6 | # Change to the script directory 7 | cd "$SCRIPT_DIR" 8 | 9 | # Define the path to the Python 
executable and the Python script 10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python" 11 | PYTHON_SCRIPT="./pytorch/flameTimewarpML_validate.py" 12 | 13 | # Run the Python script with all arguments passed to this shell script 14 | $PYTHON_CMD $PYTHON_SCRIPT "$@" 15 | --------------------------------------------------------------------------------