├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── README_v044.md
├── baselight
│   └── uk.co.andriy.mltimewarp.v1
│       ├── IFNet_HDv3.py
│       ├── LICENCE.md
│       ├── VERSION
│       ├── bin
│       │   ├── Activate.ps1
│       │   ├── activate
│       │   ├── activate.csh
│       │   ├── activate.fish
│       │   ├── easy_install
│       │   ├── easy_install-3.9
│       │   ├── pip
│       │   ├── pip3
│       │   ├── pip3.9
│       │   ├── python
│       │   ├── python3
│       │   └── python3.9
│       ├── flexi.py
│       ├── lib64
│       ├── pyvenv.cfg
│       ├── timewarp.py
│       └── uk.co.dirtylooks.mltimewarp.v1.flexi
├── find_scale.sh
├── fix-xattr.command
├── flameTimewarpML.py
├── flameTimewarpML_framework.py
├── fonts
│   └── DejaVuSansMono.ttf
├── models
│   └── flownet2_v004.pth
├── packages
│   ├── .miniconda
│   │   └── README.md
│   ├── README.md
│   └── clean_pycache.sh
├── presets
│   ├── openexr16bit.xml
│   ├── openexr32bit.xml
│   ├── source_export.xml
│   └── source_export32bit.xml
├── pyflame_lib_flameTimewarpML.py
├── pytorch
│   ├── check_consistent.py
│   ├── check_readable.py
│   ├── flameStabML_train.py
│   ├── flameTimewarpML_findscale.py
│   ├── flameTimewarpML_findscale_ml.py
│   ├── flameTimewarpML_findscale_scipy.py
│   ├── flameTimewarpML_finetune.py
│   ├── flameTimewarpML_inference.py
│   ├── flameTimewarpML_train.py
│   ├── flameTimewarpML_validate.py
│   ├── hub
│   │   └── checkpoints
│   │       └── README.md
│   ├── models
│   │   ├── archived
│   │   │   ├── flownet2_v004 copy.py
│   │   │   ├── flownet2_v004.py
│   │   │   ├── flownet4_v001a.py
│   │   │   ├── flownet4_v001aa.py
│   │   │   ├── flownet4_v001ab.py
│   │   │   ├── flownet4_v001b.py
│   │   │   ├── flownet4_v001c.py
│   │   │   ├── flownet4_v001d.py
│   │   │   ├── flownet4_v001e.py
│   │   │   ├── flownet4_v001ea.py
│   │   │   ├── flownet4_v001eb3x.py
│   │   │   ├── flownet4_v001eb3xx.py
│   │   │   ├── flownet4_v001eb5.py
│   │   │   ├── flownet4_v001eb8x_fibonacci5.py
│   │   │   ├── flownet4_v001eb_fibonacci5.py
│   │   │   ├── flownet4_v001ec.py
│   │   │   ├── flownet4_v001ed.py
│   │   │   ├── flownet4_v001ee.py
│   │   │   ├── flownet4_v001ef.py
│   │   │   ├── flownet4_v001eg.py
│   │   │   ├── flownet4_v002.py
│   │   │   ├── flownet4_v002a.py
│   │   │   ├── flownet4_v002b.py
│   │   │   ├── flownet4_v002c.py
│   │   │   ├── flownet4_v002d.py
│   │   │   ├── flownet4_v002e.py
│   │   │   ├── flownet4_v002f.py
│   │   │   ├── flownet4_v002g.py
│   │   │   ├── flownet4_v002h.py
│   │   │   ├── flownet4_v002i.py
│   │   │   ├── flownet4_v002j.py
│   │   │   ├── flownet4_v002k.py
│   │   │   ├── flownet4_v002l.py
│   │   │   ├── flownet4_v002la.py
│   │   │   ├── flownet4_v002lb.py
│   │   │   ├── flownet4_v002lc.py
│   │   │   ├── flownet4_v002ld.py
│   │   │   ├── flownet4_v002le.py
│   │   │   ├── flownet4_v002lf.py
│   │   │   ├── flownet4_v002lg.py
│   │   │   ├── flownet4_v002lh.py
│   │   │   ├── flownet4_v002li.py
│   │   │   ├── flownet4_v002lj.py
│   │   │   ├── flownet4_v002lk.py
│   │   │   ├── flownet4_v002ll.py
│   │   │   ├── flownet4_v002lm.py
│   │   │   ├── flownet4_v002ln.py
│   │   │   ├── flownet4_v002m.py
│   │   │   ├── flownet4_v003.py
│   │   │   ├── flownet4_v003a.py
│   │   │   ├── flownet4_v003b.py
│   │   │   ├── flownet4_v003c.py
│   │   │   ├── flownet4_v003d.py
│   │   │   ├── flownet4_v004.py
│   │   │   └── flownet4_v004_old01.py
│   │   ├── flownet4_v001.py
│   │   ├── flownet4_v001_baseline.py
│   │   ├── flownet4_v001_baseline_sst.py
│   │   ├── flownet4_v001a.py
│   │   ├── flownet4_v001b.py
│   │   ├── flownet4_v001b_v01.py
│   │   ├── flownet4_v001b_v02.py
│   │   ├── flownet4_v001b_v03.py
│   │   ├── flownet4_v001batt.py
│   │   ├── flownet4_v001beatles.py
│   │   ├── flownet4_v001c.py
│   │   ├── flownet4_v001d.py
│   │   ├── flownet4_v001e.py
│   │   ├── flownet4_v001ea.py
│   │   ├── flownet4_v001eb.py
│   │   ├── flownet4_v001eb_dv01.py
│   │   ├── flownet4_v001eb_dv02.py
│   │   ├── flownet4_v001eb_dv03.py
│   │   ├── flownet4_v001eb_dv04.py
│   │   ├── flownet4_v001eb_dv04a.py
│   │   ├── flownet4_v001eb_dv04b.py
│   │   ├── flownet4_v001eb_dv04c.py
│   │   ├── flownet4_v001eb_dv04d.py
│   │   ├── flownet4_v001eb_dv04e.py
│   │   ├── flownet4_v001eb_dv04ea.py
│   │   ├── flownet4_v001eb_dv04f.py
│   │   ├── flownet4_v001ec.py
│   │   ├── flownet4_v001f.py
│   │   ├── flownet4_v001f_v02.py
│   │   ├── flownet4_v001f_v03.py
│   │   ├── flownet4_v001f_v04.py
│   │   ├── flownet4_v001f_v04_001.py
│   │   ├── flownet4_v001f_v04_002.py
│   │   ├── flownet4_v001f_v04_003.py
│   │   ├── flownet4_v001fa_v04_001.py
│   │   ├── flownet4_v001fb_v04_001.py
│   │   ├── flownet4_v001fc_v04_001.py
│   │   ├── flownet4_v001fd_v04_001.py
│   │   ├── flownet4_v001fe_v04_001.py
│   │   ├── flownet4_v001fp4.py
│   │   ├── flownet4_v001hp5.py
│   │   ├── flownet4_v001hpass.py
│   │   ├── flownet4_v001orig.py
│   │   ├── flownet4_v002.py
│   │   ├── flownet4_v003.py
│   │   ├── flownet4_v004.py
│   │   ├── flownet5_v001.py
│   │   ├── flownet5_v002.py
│   │   ├── flownet5_v003.py
│   │   ├── flownet5_v004.py
│   │   ├── flownet5_v005.py
│   │   ├── flownet5_v006.py
│   │   ├── flownet5_v007.py
│   │   ├── flownet5_v008.py
│   │   ├── flownet5_v009.py
│   │   ├── flownet5_v010.py
│   │   ├── stabnet4_v000.py
│   │   ├── stabnet4_v001.py
│   │   ├── stabnet4_v001f_002.py
│   │   ├── stabnet4_v001fa_v04_001.py
│   │   ├── stabnet4_v001fb_v04_001.py
│   │   ├── stabnet4_v001fc_v04_001.py
│   │   ├── stabnet4_v001fd_v04_001.py
│   │   ├── stabnet4_v001fe_v04_001.py
│   │   ├── stabnet4_v002.py
│   │   ├── stabnet4_v003.py
│   │   ├── stabnet4_v004.py
│   │   ├── stabnet4_v005.py
│   │   ├── stabnet4_v006.py
│   │   ├── stabnet4_v007.py
│   │   ├── stabnet4_v008.py
│   │   ├── stabnet4_v009.py
│   │   └── stabnet4_v010.py
│   ├── move_to_folders.py
│   ├── recompress.py
│   ├── rename_keys.py
│   ├── resize_exrs_half.py
│   ├── resize_exrs_min.py
│   ├── resize_exrs_width.py
│   ├── retime.py
│   └── speedtest.py
├── requirements.txt
├── retime.sh
├── speedtest.sh
├── test.timewarp_node
├── train.sh
├── train_stab.sh
└── validate.sh
/.gitattributes:
--------------------------------------------------------------------------------
1 | # *.pth filter=lfs diff=lfs merge=lfs -text
2 | # *.exr filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | bundle/miniconda.package/Miniconda*
2 | bundle/miniconda.package/packages/*
3 | bundle/bin/*
4 | !bundle/miniconda.package/packages/README
5 | settings.json
6 | *.pyc
7 | *.exr
8 | *.png
9 | flameTimewarpML.package
10 | flameTimewarpMLenv.tar.bz2
11 | flameTimewarpML.bundle.tar
12 | test_*
13 | env
14 | appenv
15 | .DS_Store
16 | *.pth
17 | *.pkl
18 |
19 | !packages/README.md
20 | !packages/.miniconda/README.md
21 | !packages/.miniconda/appenv/README.md
22 |
23 | # Byte-compiled / optimized / DLL files
24 | __pycache__/
25 | *.py[cod]
26 | *$py.class
27 |
28 | # C extensions
29 | *.so
30 |
31 | # Distribution / packaging
32 | .Python
33 | build/
34 | develop-eggs/
35 | dist/
36 | downloads/
37 | eggs/
38 | .eggs/
39 | lib/
40 | lib64/
41 | parts/
42 | sdist/
43 | var/
44 | wheels/
45 | share/python-wheels/
46 | *.egg-info/
47 | .installed.cfg
48 | *.egg
49 | MANIFEST
50 |
51 | # PyInstaller
52 | # Usually these files are written by a python script from a template
53 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
54 | *.manifest
55 | *.spec
56 |
57 | # Installer logs
58 | pip-log.txt
59 | pip-delete-this-directory.txt
60 |
61 | # Unit test / coverage reports
62 | htmlcov/
63 | .tox/
64 | .nox/
65 | .coverage
66 | .coverage.*
67 | .cache
68 | nosetests.xml
69 | coverage.xml
70 | *.cover
71 | *.py,cover
72 | .hypothesis/
73 | .pytest_cache/
74 | cover/
75 |
76 | # Translations
77 | *.mo
78 | *.pot
79 |
80 | # Django stuff:
81 | *.log
82 | local_settings.py
83 | db.sqlite3
84 | db.sqlite3-journal
85 |
86 | # Flask stuff:
87 | instance/
88 | .webassets-cache
89 |
90 | # Scrapy stuff:
91 | .scrapy
92 |
93 | # Sphinx documentation
94 | docs/_build/
95 |
96 | # PyBuilder
97 | .pybuilder/
98 | target/
99 |
100 | # Jupyter Notebook
101 | .ipynb_checkpoints
102 |
103 | # IPython
104 | profile_default/
105 | ipython_config.py
106 |
107 | # pyenv
108 | # For a library or package, you might want to ignore these files since the code is
109 | # intended to run in multiple environments; otherwise, check them in:
110 | # .python-version
111 |
112 | # pipenv
113 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
114 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
115 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
116 | # install all needed dependencies.
117 | #Pipfile.lock
118 |
119 | # poetry
120 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
121 | # This is especially recommended for binary packages to ensure reproducibility, and is more
122 | # commonly ignored for libraries.
123 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
124 | #poetry.lock
125 |
126 | # pdm
127 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
128 | #pdm.lock
129 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
130 | # in version control.
131 | # https://pdm.fming.dev/#use-with-ide
132 | .pdm.toml
133 |
134 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
135 | __pypackages__/
136 |
137 | # Celery stuff
138 | celerybeat-schedule
139 | celerybeat.pid
140 |
141 | # SageMath parsed files
142 | *.sage.py
143 |
144 | # Environments
145 | .env
146 | .venv
147 | env/
148 | venv/
149 | ENV/
150 | env.bak/
151 | venv.bak/
152 |
153 | # Spyder project settings
154 | .spyderproject
155 | .spyproject
156 |
157 | # Rope project settings
158 | .ropeproject
159 |
160 | # mkdocs documentation
161 | /site
162 |
163 | # mypy
164 | .mypy_cache/
165 | .dmypy.json
166 | dmypy.json
167 |
168 | # Pyre type checker
169 | .pyre/
170 |
171 | # pytype static type analyzer
172 | .pytype/
173 |
174 | # Cython debug symbols
175 | cython_debug/
176 |
177 | # PyCharm
178 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
179 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
180 | # and can be added to the global gitignore or merged into this file. For a more nuclear
181 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
182 | #.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Andriy Toloshny
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## flameTimewarpML
2 |
3 | Previous versions: [Readme v0.4.4](https://github.com/talosh/flameTimewarpML/blob/main/README_v044.md)
4 |
5 | ## Table of Contents
6 | - [Status](#status)
7 | - [Installation](#installation)
8 | - [Training](#training)
9 |
10 | ### Installation
11 |
12 | #### Installing from package - Linux
13 | * Download the parts and run:
14 | ```bash
15 | cat flameTimewarpML_v0.4.5-dev003.Linux.tar.gz.part-* > flameTimewarpML_v0.4.5-dev003.Linux.tar.gz
16 | tar -xvf flameTimewarpML_v0.4.5-dev003.Linux.tar.gz
17 | ```
18 | * Place "flameTimewarpML" folder into Flame python hooks path.
19 |
20 | #### Installing from package - MacOS
21 | * Unpack and run fix-xattr.command
22 | * Place "flameTimewarpML" folder into Flame python hooks path.
23 |
24 | ### Installing and configuring python environment manually
25 |
26 | * A pre-configured miniconda environment should be placed into the hidden "packages/.miniconda" folder
27 | * The folder is hidden (its name starts with ".") in order to keep Flame from scanning it for python hooks
28 | * The pre-configured python environment is usually packed with the release tar file
29 |
30 | * Download Miniconda for Mac or Linux (python 3.11 is used for tests) from the Miniconda download page
31 |
32 |
33 | * Install the downloaded Miniconda python distribution, using "-p" to select the install location. For example:
34 |
35 | ```bash
36 | sh ~/Downloads/Miniconda3-py311_24.7.1-0-Linux-x86_64.sh -bfsm -p ~/miniconda3
37 | ```
38 |
39 | * Activate conda, then create and activate an environment named "appenv"
40 |
41 | ```bash
42 | eval "$(~/miniconda3/bin/conda shell.bash hook)"
43 | conda create --name appenv
44 | conda activate appenv
45 | ```
46 |
47 | * Install dependency libraries
48 |
49 | ```bash
50 | conda install python=3.11 conda-forge::openimageio conda-forge::py-openimageio
51 | conda install pyqt
52 | conda install conda-pack
53 | ```
54 |
55 | * Install PyTorch. Please look up the exact command for your OS and CUDA version
56 |
57 | * Linux example
58 | ```bash
59 | conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia
60 | ```
61 |
62 | * MacOS example:
63 |
64 | ```bash
65 | conda install pytorch::pytorch torchvision -c pytorch
66 | ```
67 |
68 | * Install the rest of the dependencies
69 | ```bash
70 | pip install -r {flameTimewarpML folder}/requirements.txt
71 | ```
72 |
73 | * Pack the "appenv" environment into a portable tar file
74 |
75 | ```bash
76 | conda pack --ignore-missing-files -n appenv
77 | ```
78 |
79 | * Unpack the environment into the flameTimewarpML folder
80 |
81 | ```bash
82 | mkdir {flameTimewarpML folder}/packages/.miniconda/appenv/
83 | tar xvf appenv.tar.gz -C {flameTimewarpML folder}/packages/.miniconda/appenv/
84 | ```
85 |
86 | * Remove the environment tarball
87 |
88 | ```bash
89 | rm appenv.tar.gz
90 | ```
91 |
92 | ### Training
93 |
94 |
95 |
96 | #### Finetune for a specific shot or set of shots
97 | The finetune option is available as a menu item starting from version 0.4.5 dev 003
98 |
99 | #### Finetune using command line script
100 |
101 | Export as a Linear ACEScg (AP1) Uncompressed OpenEXR sequence.
102 | Export your shots so that each shot is in a separate folder.
103 | Copy the pre-trained model file (large or lite) to a new file to train with.
104 | If motion is fast, place the whole shot, or the parts with fast motion, into a "fast" folder.
105 |
106 | ```bash
107 | cd {flameTimewarpML folder}
108 | ./train.sh --state_file {Path to copied state file}/MyShot001.pth --generalize 1 --lr 4e-6 --acescc 0 --onecycle 1000 {Path to shot}/{fast}/
109 | ```
110 |
111 | * Change the number after "--onecycle" to set the number of runs.
112 | * Use "--acescc 0" to train in Linear (retaining the input colourspace), or "--acescc 100" to convert all samples to Log.
113 | * Use "--frame_size" to modify the training sample size.
114 | * The last 9 training patches can be previewed in the "{Path to shot}/preview" folder.
115 | * Use "--preview" to modify how frequently preview files are saved.
116 |
117 | #### Train your own model
118 | ```bash
119 | cd {flameTimewarpML folder}
120 | ./train.sh --state_file {Path to MyModel}/MyModel001.pth --model flownet4_v004 --batch_size 4 {Path to Dataset}/
121 | ```
122 |
123 | #### Batch size and learning rate
124 | The batch size and learning rate are two crucial hyperparameters that significantly affect the training process, and empirical tuning is necessary here.
125 | When the batch size is increased, the learning rate can often be increased as well. A common heuristic is the linear scaling rule: when you multiply the batch size by a factor
126 | k, you can also multiply the learning rate by k. Another approach is the square root scaling rule: when you multiply the batch size by k, multiply the learning rate by sqrt(k).
127 |
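A minimal sketch of the two scaling rules (the numbers below are placeholders for illustration, not project defaults):

```python
import math

def scale_lr(base_lr, base_batch, new_batch, rule="linear"):
    """Scale a learning rate when the batch size changes."""
    k = new_batch / base_batch
    return base_lr * (k if rule == "linear" else math.sqrt(k))

# Going from batch_size 4 at lr 4e-6 to batch_size 16 (k = 4):
print(scale_lr(4e-6, 4, 16, "linear"))  # 1.6e-05 (linear rule)
print(scale_lr(4e-6, 4, 16, "sqrt"))    # 8e-06 (square root rule)
```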
128 |
129 | #### Dataset preparation
130 | Training script will scan all the folders under a given path and will compose training samples out of folders where .exr files are found.
131 | Only Uncompressed OpenEXR files are supported at the moment.
132 |
133 | Export your shots so each shot are in separate folder.
134 | There are two magic words in shot path: "fast" and "slow"
135 | When "fast" is in path 3-frame window will be used for those shots.
136 | When "slow" is in path 9-frame window will be used.
137 | Default window size is 5 frames.
138 | Put shots with fast motion in "fast" folder and shots where motion are slow and continious to "slow" folder to let the model learn more intermediae ratios.
139 |
140 | - Scene001/
141 | - slow/
142 | - Shot001/
143 | - Shot001.001.exr
144 | - Shot001.002.exr
145 | - ...
146 | - fast/
147 | - Shot002
148 | - Shot002.001.exr
149 | - Shot002.002.exr
150 | - ...
151 | - normal/
152 | - Shot003
153 | - Shot003.001.exr
154 | - Shot003.002.exr
155 | - ...
156 |
157 | #### Window size
158 | Each training sample consists of 3 frames and a ratio: the model is given the first and the last frame and tries to re-create the middle frame at the given ratio.
159 |
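A minimal sketch of how the window size maps to the ratios the model can learn, assuming evenly spaced frames (this helper is illustrative, not part of the training script):

```python
def window_ratios(window_size):
    """The first and last frames of the window are the model inputs;
    each in-between frame is a possible ground-truth target at ratio
    i / (window_size - 1)."""
    return [i / (window_size - 1) for i in range(1, window_size - 1)]

print(window_ratios(3))  # [0.5]
print(window_ratios(5))  # [0.25, 0.5, 0.75]
print(window_ratios(9))  # [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875]
```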
160 | [TODO] - add documentation
161 |
--------------------------------------------------------------------------------
/README_v044.md:
--------------------------------------------------------------------------------
1 | ## flameTimewarpML
2 |
3 | Machine Learning frame interpolation tool for Autodesk Flame.
4 |
5 | Based on arXiv2020-RIFE, original implementation:
6 |
7 | ```data
8 | @article{huang2020rife,
9 | title={RIFE: Real-Time Intermediate Flow Estimation for Video Frame Interpolation},
10 | author={Huang, Zhewei and Zhang, Tianyuan and Heng, Wen and Shi, Boxin and Zhou, Shuchang},
11 | journal={arXiv preprint arXiv:2011.06294},
12 | year={2020}
13 | }
14 | ```
15 |
16 | Flame's animation curve interpolation code is from Julik Tarkhanov.
17 |
18 |
19 | ## Installation
20 |
21 | * Creating Miniconda environment with dependencies:
22 |
23 | ```sh Miniconda3-py311_24.1.2-0-MacOSX-arm64.sh
24 | eval "$(~/miniconda3/bin/conda shell.zsh hook)"
25 | conda create --name twml --clone base
26 | conda install numpy
27 | pip install PySide6
28 | conda install pytorch::pytorch -c pytorch
29 | conda install conda-pack
30 | conda pack --ignore-missing-files -n twml
31 | tar xvf twml.tar.gz -C {flameTimewarpML folder}/packages/.miniconda/twml/
32 | ```
33 |
34 |
35 | ### Single workstation / Easy install
36 |
37 | * Download latest release from [Releases](https://github.com/talosh/flameTimewarpML/releases) page
38 | * All you need to do is put the .py file in /opt/Autodesk/shared/python and launch/relaunch flame. The first time it will unpack/install what it needs. It will give you a prompt to let you know when it’s finished.
39 | * It is possible to choose the installation folder. Default installation folders are:
40 |
41 | ```bash
42 | ~/Documents/flameTimewarpML # Mac
43 | ~/flameTimewarpML # Linux
44 | ```
45 |
46 | ### Preferences location
47 |
48 | * In case you'd need to reset preferences
49 |
50 | ```bash
51 | ~/Library/Preferences/flameTimewarpML # Mac
52 | ~/.config/flameTimewarpML # Linux
53 | ```
54 |
55 | ### Configuration via env variables
56 |
57 | * There is an option to set the working folder via an environment variable.
58 | Setting FLAMETWML_DEFAULT_WORK_FOLDER will set the default folder, and the user will still be allowed to change it
59 |
60 | ```bash
61 | export FLAMETWML_DEFAULT_WORK_FOLDER=/Volumes/projects/my_timewarps/
62 | ```
63 |
64 | * Setting FLAMETWML_WORK_FOLDER will block the user from changing it. This is read every time just before the tool is launched, so other pipeline tools can change it dynamically depending on context
65 |
66 | ```bash
67 | export FLAMETWML_WORK_FOLDER=/Volumes/projects/my_timewarps/
68 | ```
69 |
70 | * Setting FLAMETWML_HARDCOMMIT will cause the imported clip to be hard committed after import and all temporary files to be deleted
71 |
72 | ```bash
73 | export FLAMETWML_HARDCOMMIT=True
74 | ```
75 | ### Updating / Downgrading PyTorch to match CUDA version
76 |
77 | This example will change PyTorch to CUDA 11
78 |
79 | * Get the PyTorch build for Python 3.8 and CUDA 11:
80 | ```
81 | https://download.pytorch.org/whl/cu110/torch-1.7.1%2Bcu110-cp38-cp38-linux_x86_64.whl
82 | ```
83 | * Init miniconda3 environment:
84 | ```
85 | ~/flameTimewarpML/bundle/init_env
86 | ```
87 | * In new terminal window that opens check current pytorch cuda version:
88 | ```
89 | python -c "import torch; print(torch.version.cuda)"
90 | ```
91 | This should report 10.2
92 | * Update pytorch
93 | ```
94 | pip3 install --upgrade ~/Downloads/torch-1.7.1+cu110-cp38-cp38-linux_x86_64.whl
95 | ```
96 | * Check pytorch cuda version again:
97 | ```
98 | python -c "import torch; print(torch.version.cuda)"
99 | ```
100 | This should now report 11.0
101 |
102 |
103 | ### Centralised / Manual Install
104 |
105 | * It is possible to do an easy install on one of the workstations to a common location and then configure the others via environment variables (see the Centralised configuration section)
106 | * flameTimewarpML is made of three components:
107 | 1. A Miniconda3 isolated Python 3.8 environment with some additional dependency libraries (see requirements.txt)
108 | 2. A set of python scripts that are called from the command line and run within the Python 3.8 environment. These two parts do not depend on Flame and can be run as standalone tools.
109 | 3. A Flame script to be run inside Flame. It has to know the location of the two previous parts so it can initialize the Python 3.8 environment and then run the command line tools to process image sequences.
110 | * Get the source from [Releases](https://github.com/talosh/flameTimewarpML/releases) page or directly from repo:
111 |
112 | ```bash
113 | git clone https://github.com/talosh/flameTimewarpML.git
114 | ```
115 |
116 | * It is possible to share the Python3 code location between platforms, but the Miniconda3 python environment should be placed separately for Mac and Linux.
117 | This example will use:
118 |
119 | ```bash
120 | Python3 code location:
121 | /Volumes/software/flameTimewarpML # Mac
122 | /mnt/software/flameTimewarpML # Linux
123 |
124 | miniconda3 location:
125 | /Volumes/software/miniconda3 # Mac
126 | /mnt/software/miniconda3 # Linux
127 | ```
128 |
129 | * Do not place those folders inside Flame's python hooks folder.
130 | The only file that should be placed in the Flame hooks folder is flameTimewarpML.py
131 |
132 | * Create folders
133 |
134 | ```bash
135 | mkdir -p /Volumes/software/flameTimewarpML # Mac
136 | mkdir -p /Volumes/software/miniconda3 # Mac
137 |
138 | mkdir -p /mnt/software/flameTimewarpML # Linux
139 | mkdir -p /mnt/software/miniconda3 # Linux
140 | ```
141 |
142 | * Copy the contents of 'bundle' folder to the code location.
143 |
144 | ```bash
145 | cp -a bundle/* /Volumes/software/flameTimewarpML/ # Mac
146 | cp -a bundle/* /mnt/software/flameTimewarpML/ # Linux
147 | ```
148 |
149 | * Download Miniconda3 for Python 3.8 from the Miniconda download page.
150 | The shell installer is used for this example.
151 |
152 | * Install Miniconda3 to centralized location:
153 |
154 | ```bash
155 | sh Miniconda3-latest-MacOSX-x86_64.sh -bu -p /Volumes/software/miniconda3/ # Mac
156 | sh Miniconda3-latest-Linux-x86_64.sh -bu -p /mnt/software/miniconda3/ # Linux
157 | ```
158 |
159 | * In case you'd like to move the Miniconda3 installation later, you'll have to re-install it.
160 | Please refer to the Anaconda docs.
161 |
162 | * Init installed Miniconda3 environment:
163 |
164 | ```bash
165 | /Volumes/software/flameTimewarpML/init_env /Volumes/software/miniconda3/ # Mac
166 | /mnt/software/flameTimewarpML/init_env /mnt/software/miniconda3/ # Linux
167 | ```
168 |
169 | The script will open a new konsole window.
170 |
171 | * In case you do this remotely over ssh on Linux, edit the line at the very end of init_env, replacing the first line below with the second:
172 |
173 | ```bash
174 | cmd_prefix = 'konsole -e /bin/bash --rcfile ' + tmp_bash_rc_file
175 | ```
176 |
177 | ```bash
178 | cmd_prefix = '/bin/bash --rcfile ' + tmp_bash_rc_file
179 | ```
180 |
181 | * To check that the environment is properly initialized: there should be "(base)" before the shell prompt, and `python --version` should report Python 3.8.5 or greater
182 |
183 | * From this shell prompt install dependency libraries:
184 |
185 | ```bash
186 | pip3 install -r /Volumes/software/flameTimewarpML/requirements.txt # Mac
187 | pip3 install -r /mnt/software/flameTimewarpML/requirements.txt # Linux
188 | ```
189 |
190 | * In case pip fails to find a match for torch, install torch with the '-f' flag:
191 |
192 | ```bash
193 | pip3 install torch=="1.10.2+cu111" -f https://download.pytorch.org/whl/torch_stable.html
194 | ```
195 |
196 | * It is possible to download the packages and install them without an internet connection. Install Miniconda3 on a machine that is connected to the internet and download the dependency packages:
197 |
198 | ```bash
199 | pip3 download -r bundle/requirements.txt -d packages_folder
200 | ```
201 |
202 | Then it is possible to install the packages from the folder:
203 |
204 | ```bash
205 | pip3 install -r /Volumes/software/flameTimewarpML/requirements.txt --no-index --find-links /Volumes/software/flameTimewarpML/packages_folder
206 | ```
207 |
208 | * Copy flameTimewarpML.py to your Flame python scripts folder
209 |
210 | * Set the FLAMETWML_BUNDLE and FLAMETWML_MINICONDA environment variables to point to the code and miniconda locations.
211 |
212 | ```bash
213 | export FLAMETWML_BUNDLE=/Volumes/software/flameTimewarpML/ # Mac
214 | export FLAMETWML_MINICONDA=/Volumes/software/miniconda3/ # Mac
215 |
216 | export FLAMETWML_BUNDLE=/mnt/software/flameTimewarpML/ # Linux
217 | export FLAMETWML_MINICONDA=/mnt/software/miniconda3/ # Linux
218 | ```
219 |
220 | More granular per-platform settings are possible with
221 |
222 | ```bash
223 | export FLAMETWML_BUNDLE_MAC=/Volumes/software/flameTimewarpML/
224 | export FLAMETWML_BUNDLE_LINUX=/mnt/software/flameTimewarpML/
225 | export FLAMETWML_MINICONDA_MAC=/Volumes/software/miniconda3/
226 | export FLAMETWML_MINICONDA_LINUX=/mnt/software/miniconda3/
227 | ```
228 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/IFNet_HDv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from warplayer import warp
5 |
6 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
7 | return nn.Sequential(
8 |         torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding),
9 | nn.PReLU(out_planes)
10 | )
11 |
12 | def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
13 | return nn.Sequential(
14 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
15 | padding=padding, dilation=dilation, bias=True),
16 | )
17 |
18 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
19 | return nn.Sequential(
20 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
21 | padding=padding, dilation=dilation, bias=True),
22 | nn.PReLU(out_planes)
23 | )
24 |
25 | class IFBlock(nn.Module):
26 | def __init__(self, in_planes, c=64):
27 | super(IFBlock, self).__init__()
28 | self.conv0 = nn.Sequential(
29 | conv(in_planes, c, 3, 2, 1),
30 | conv(c, 2*c, 3, 2, 1),
31 | )
32 | self.convblock0 = nn.Sequential(
33 | conv(2*c, 2*c),
34 | conv(2*c, 2*c),
35 | )
36 | self.convblock1 = nn.Sequential(
37 | conv(2*c, 2*c),
38 | conv(2*c, 2*c),
39 | )
40 | self.convblock2 = nn.Sequential(
41 | conv(2*c, 2*c),
42 | conv(2*c, 2*c),
43 | )
44 | self.conv1 = nn.ConvTranspose2d(2*c, 4, 4, 2, 1)
45 |
46 | def forward(self, x, flow=None, scale=1):
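        # Run the block at 1/scale resolution; the predicted flow is
        # upsampled back at the end, with its vectors rescaled to match.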
47 |         x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False)
48 |         if flow is not None:
49 |             flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear", align_corners=False) * (1. / scale)
50 | x = torch.cat((x, flow), 1)
51 | x = self.conv0(x)
52 | x = self.convblock0(x) + x
53 | x = self.convblock1(x) + x
54 | x = self.convblock2(x) + x
55 | x = self.conv1(x)
56 | flow = x
57 | if scale != 1:
58 |             flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False) * scale
59 | return flow
60 |
61 | class IFNet(nn.Module):
62 | def __init__(self):
63 | super(IFNet, self).__init__()
64 | self.block0 = IFBlock(6, c=80)
65 | self.block1 = IFBlock(10, c=80)
66 | self.block2 = IFBlock(10, c=80)
67 |
68 | def set_device(self,idevice):
69 | self.device = idevice
70 |
71 | def forward(self, x, scale_list=[4,2,1], scale=1.0):
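        # Coarse-to-fine refinement: block0 estimates an initial flow at the
        # coarsest level; block1 and block2 refine the accumulated flow using
        # inputs warped by the current estimate.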
72 | x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
73 | flow0 = self.block0(x, scale=scale_list[0])
74 | F1 = flow0
75 | F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
76 | warped_img0 = warp(x[:, :3], F1_large[:, :2], device=self.device)
77 | warped_img1 = warp(x[:, 3:], F1_large[:, 2:4], device=self.device)
78 | flow1 = self.block1(torch.cat((warped_img0, warped_img1), 1), F1_large, scale=scale_list[1])
79 | F2 = (flow0 + flow1)
80 | F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
81 | warped_img0 = warp(x[:, :3], F2_large[:, :2], device=self.device)
82 | warped_img1 = warp(x[:, 3:], F2_large[:, 2:4], device=self.device)
83 | flow2 = self.block2(torch.cat((warped_img0, warped_img1), 1), F2_large, scale=scale_list[2])
84 | F3 = (flow0 + flow1 + flow2)
85 | if scale != 1.0:
86 | F3 = F.interpolate(F3, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale
87 | return F3, [F1, F2, F3]
88 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/LICENCE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Andriy Toloshny
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/VERSION:
--------------------------------------------------------------------------------
1 | 0.0.1 dev 001
2 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/Activate.ps1:
--------------------------------------------------------------------------------
1 | <#
2 | .Synopsis
3 | Activate a Python virtual environment for the current PowerShell session.
4 |
5 | .Description
6 | Pushes the python executable for a virtual environment to the front of the
7 | $Env:PATH environment variable and sets the prompt to signify that you are
8 | in a Python virtual environment. Makes use of the command line switches as
9 | well as the `pyvenv.cfg` file values present in the virtual environment.
10 |
11 | .Parameter VenvDir
12 | Path to the directory that contains the virtual environment to activate. The
13 | default value for this is the parent of the directory that the Activate.ps1
14 | script is located within.
15 |
16 | .Parameter Prompt
17 | The prompt prefix to display when this virtual environment is activated. By
18 | default, this prompt is the name of the virtual environment folder (VenvDir)
19 | surrounded by parentheses and followed by a single space (ie. '(.venv) ').
20 |
21 | .Example
22 | Activate.ps1
23 | Activates the Python virtual environment that contains the Activate.ps1 script.
24 |
25 | .Example
26 | Activate.ps1 -Verbose
27 | Activates the Python virtual environment that contains the Activate.ps1 script,
28 | and shows extra information about the activation as it executes.
29 |
30 | .Example
31 | Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
32 | Activates the Python virtual environment located in the specified location.
33 |
34 | .Example
35 | Activate.ps1 -Prompt "MyPython"
36 | Activates the Python virtual environment that contains the Activate.ps1 script,
37 | and prefixes the current prompt with the specified string (surrounded in
38 | parentheses) while the virtual environment is active.
39 |
40 | .Notes
41 | On Windows, it may be required to enable this Activate.ps1 script by setting the
42 | execution policy for the user. You can do this by issuing the following PowerShell
43 | command:
44 |
45 | PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
46 |
47 | For more information on Execution Policies:
48 | https://go.microsoft.com/fwlink/?LinkID=135170
49 |
50 | #>
51 | Param(
52 | [Parameter(Mandatory = $false)]
53 | [String]
54 | $VenvDir,
55 | [Parameter(Mandatory = $false)]
56 | [String]
57 | $Prompt
58 | )
59 |
60 | <# Function declarations --------------------------------------------------- #>
61 |
62 | <#
63 | .Synopsis
64 | Remove all shell session elements added by the Activate script, including the
65 | addition of the virtual environment's Python executable from the beginning of
66 | the PATH variable.
67 |
68 | .Parameter NonDestructive
69 | If present, do not remove this function from the global namespace for the
70 | session.
71 |
72 | #>
73 | function global:deactivate ([switch]$NonDestructive) {
74 | # Revert to original values
75 |
76 | # The prior prompt:
77 | if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
78 | Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
79 | Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
80 | }
81 |
82 | # The prior PYTHONHOME:
83 | if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
84 | Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
85 | Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
86 | }
87 |
88 | # The prior PATH:
89 | if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
90 | Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
91 | Remove-Item -Path Env:_OLD_VIRTUAL_PATH
92 | }
93 |
94 | # Just remove the VIRTUAL_ENV altogether:
95 | if (Test-Path -Path Env:VIRTUAL_ENV) {
96 | Remove-Item -Path env:VIRTUAL_ENV
97 | }
98 |
99 | # Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
100 | if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
101 | Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
102 | }
103 |
104 | # Leave deactivate function in the global namespace if requested:
105 | if (-not $NonDestructive) {
106 | Remove-Item -Path function:deactivate
107 | }
108 | }
109 |
110 | <#
111 | .Description
112 | Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
113 | given folder, and returns them in a map.
114 |
115 | For each line in the pyvenv.cfg file, if that line can be parsed into exactly
116 | two strings separated by `=` (with any amount of whitespace surrounding the =)
117 | then it is considered a `key = value` line. The left hand string is the key,
118 | the right hand is the value.
119 |
120 | If the value starts with a `'` or a `"` then the first and last character is
121 | stripped from the value before being captured.
122 |
123 | .Parameter ConfigDir
124 | Path to the directory that contains the `pyvenv.cfg` file.
125 | #>
126 | function Get-PyVenvConfig(
127 | [String]
128 | $ConfigDir
129 | ) {
130 | Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
131 |
132 | # Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
133 | $pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
134 |
135 | # An empty map will be returned if no config file is found.
136 | $pyvenvConfig = @{ }
137 |
138 | if ($pyvenvConfigPath) {
139 |
140 | Write-Verbose "File exists, parse `key = value` lines"
141 | $pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
142 |
143 | $pyvenvConfigContent | ForEach-Object {
144 | $keyval = $PSItem -split "\s*=\s*", 2
145 | if ($keyval[0] -and $keyval[1]) {
146 | $val = $keyval[1]
147 |
148 | # Remove extraneous quotations around a string value.
149 | if ("'""".Contains($val.Substring(0, 1))) {
150 | $val = $val.Substring(1, $val.Length - 2)
151 | }
152 |
153 | $pyvenvConfig[$keyval[0]] = $val
154 | Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
155 | }
156 | }
157 | }
158 | return $pyvenvConfig
159 | }
160 |
161 |
162 | <# Begin Activate script --------------------------------------------------- #>
163 |
164 | # Determine the containing directory of this script
165 | $VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
166 | $VenvExecDir = Get-Item -Path $VenvExecPath
167 |
168 | Write-Verbose "Activation script is located in path: '$VenvExecPath'"
169 | Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
170 | Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
171 |
172 | # Set values required in priority: CmdLine, ConfigFile, Default
173 | # First, get the location of the virtual environment, it might not be
174 | # VenvExecDir if specified on the command line.
175 | if ($VenvDir) {
176 | Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
177 | }
178 | else {
179 | Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
180 | $VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
181 | Write-Verbose "VenvDir=$VenvDir"
182 | }
183 |
184 | # Next, read the `pyvenv.cfg` file to determine any required value such
185 | # as `prompt`.
186 | $pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
187 |
188 | # Next, set the prompt from the command line, or the config file, or
189 | # just use the name of the virtual environment folder.
190 | if ($Prompt) {
191 | Write-Verbose "Prompt specified as argument, using '$Prompt'"
192 | }
193 | else {
194 | Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
195 | if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
196 | Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
197 | $Prompt = $pyvenvCfg['prompt'];
198 | }
199 | else {
200 |         Write-Verbose "  Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
201 | Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
202 | $Prompt = Split-Path -Path $venvDir -Leaf
203 | }
204 | }
205 |
206 | Write-Verbose "Prompt = '$Prompt'"
207 | Write-Verbose "VenvDir='$VenvDir'"
208 |
209 | # Deactivate any currently active virtual environment, but leave the
210 | # deactivate function in place.
211 | deactivate -nondestructive
212 |
213 | # Now set the environment variable VIRTUAL_ENV, used by many tools to determine
214 | # that there is an activated venv.
215 | $env:VIRTUAL_ENV = $VenvDir
216 |
217 | if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
218 |
219 | Write-Verbose "Setting prompt to '$Prompt'"
220 |
221 | # Set the prompt to include the env name
222 | # Make sure _OLD_VIRTUAL_PROMPT is global
223 | function global:_OLD_VIRTUAL_PROMPT { "" }
224 | Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
225 | New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
226 |
227 | function global:prompt {
228 | Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
229 | _OLD_VIRTUAL_PROMPT
230 | }
231 | }
232 |
233 | # Clear PYTHONHOME
234 | if (Test-Path -Path Env:PYTHONHOME) {
235 | Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
236 | Remove-Item -Path Env:PYTHONHOME
237 | }
238 |
239 | # Add the venv to the PATH
240 | Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
241 | $Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"
242 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/activate:
--------------------------------------------------------------------------------
1 | # This file must be used with "source bin/activate" *from bash*
2 | # you cannot run it directly
3 |
4 | deactivate () {
5 | # reset old environment variables
6 | if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then
7 | PATH="${_OLD_VIRTUAL_PATH:-}"
8 | export PATH
9 | unset _OLD_VIRTUAL_PATH
10 | fi
11 | if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then
12 | PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}"
13 | export PYTHONHOME
14 | unset _OLD_VIRTUAL_PYTHONHOME
15 | fi
16 |
17 | # This should detect bash and zsh, which have a hash command that must
18 | # be called to get it to forget past commands. Without forgetting
19 | # past commands the $PATH changes we made may not be respected
20 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
21 | hash -r 2> /dev/null
22 | fi
23 |
24 | if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then
25 | PS1="${_OLD_VIRTUAL_PS1:-}"
26 | export PS1
27 | unset _OLD_VIRTUAL_PS1
28 | fi
29 |
30 | unset VIRTUAL_ENV
31 | if [ ! "${1:-}" = "nondestructive" ] ; then
32 | # Self destruct!
33 | unset -f deactivate
34 | fi
35 | }
36 |
37 | # unset irrelevant variables
38 | deactivate nondestructive
39 |
40 | VIRTUAL_ENV="/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1"
41 | export VIRTUAL_ENV
42 |
43 | _OLD_VIRTUAL_PATH="$PATH"
44 | PATH="$VIRTUAL_ENV/bin:$PATH"
45 | export PATH
46 |
47 | # unset PYTHONHOME if set
48 | # this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
49 | # could use `if (set -u; : $PYTHONHOME) ;` in bash
50 | if [ -n "${PYTHONHOME:-}" ] ; then
51 | _OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}"
52 | unset PYTHONHOME
53 | fi
54 |
55 | if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then
56 | _OLD_VIRTUAL_PS1="${PS1:-}"
57 | PS1="(uk.ltd.filmlight.rife_retime.v1) ${PS1:-}"
58 | export PS1
59 | fi
60 |
61 | # This should detect bash and zsh, which have a hash command that must
62 | # be called to get it to forget past commands. Without forgetting
63 | # past commands the $PATH changes we made may not be respected
64 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
65 | hash -r 2> /dev/null
66 | fi
67 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/activate.csh:
--------------------------------------------------------------------------------
1 | # This file must be used with "source bin/activate.csh" *from csh*.
2 | # You cannot run it directly.
3 | # Created by Davide Di Blasi.
4 | # Ported to Python 3.3 venv by Andrew Svetlov
5 |
6 | alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; test "\!:*" != "nondestructive" && unalias deactivate'
7 |
8 | # Unset irrelevant variables.
9 | deactivate nondestructive
10 |
11 | setenv VIRTUAL_ENV "/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1"
12 |
13 | set _OLD_VIRTUAL_PATH="$PATH"
14 | setenv PATH "$VIRTUAL_ENV/bin:$PATH"
15 |
16 |
17 | set _OLD_VIRTUAL_PROMPT="$prompt"
18 |
19 | if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then
20 | set prompt = "(uk.ltd.filmlight.rife_retime.v1) $prompt"
21 | endif
22 |
23 | alias pydoc python -m pydoc
24 |
25 | rehash
26 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/activate.fish:
--------------------------------------------------------------------------------
1 | # This file must be used with "source /bin/activate.fish" *from fish*
2 | # (https://fishshell.com/); you cannot run it directly.
3 |
4 | function deactivate -d "Exit virtual environment and return to normal shell environment"
5 | # reset old environment variables
6 | if test -n "$_OLD_VIRTUAL_PATH"
7 | set -gx PATH $_OLD_VIRTUAL_PATH
8 | set -e _OLD_VIRTUAL_PATH
9 | end
10 | if test -n "$_OLD_VIRTUAL_PYTHONHOME"
11 | set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME
12 | set -e _OLD_VIRTUAL_PYTHONHOME
13 | end
14 |
15 | if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
16 | functions -e fish_prompt
17 | set -e _OLD_FISH_PROMPT_OVERRIDE
18 | functions -c _old_fish_prompt fish_prompt
19 | functions -e _old_fish_prompt
20 | end
21 |
22 | set -e VIRTUAL_ENV
23 | if test "$argv[1]" != "nondestructive"
24 | # Self-destruct!
25 | functions -e deactivate
26 | end
27 | end
28 |
29 | # Unset irrelevant variables.
30 | deactivate nondestructive
31 |
32 | set -gx VIRTUAL_ENV "/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1"
33 |
34 | set -gx _OLD_VIRTUAL_PATH $PATH
35 | set -gx PATH "$VIRTUAL_ENV/bin" $PATH
36 |
37 | # Unset PYTHONHOME if set.
38 | if set -q PYTHONHOME
39 | set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
40 | set -e PYTHONHOME
41 | end
42 |
43 | if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
44 | # fish uses a function instead of an env var to generate the prompt.
45 |
46 | # Save the current fish_prompt function as the function _old_fish_prompt.
47 | functions -c fish_prompt _old_fish_prompt
48 |
49 | # With the original prompt function renamed, we can override with our own.
50 | function fish_prompt
51 | # Save the return status of the last command.
52 | set -l old_status $status
53 |
54 | # Output the venv prompt; color taken from the blue of the Python logo.
55 | printf "%s%s%s" (set_color 4B8BBE) "(uk.ltd.filmlight.rife_retime.v1) " (set_color normal)
56 |
57 | # Restore the return status of the previous command.
58 | echo "exit $old_status" | .
59 | # Output the original/"old" prompt.
60 | _old_fish_prompt
61 | end
62 |
63 | set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
64 | end
65 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/easy_install:
--------------------------------------------------------------------------------
1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9
2 | # -*- coding: utf-8 -*-
3 | import re
4 | import sys
5 | from setuptools.command.easy_install import main
6 | if __name__ == '__main__':
7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
8 | sys.exit(main())
9 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/easy_install-3.9:
--------------------------------------------------------------------------------
1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9
2 | # -*- coding: utf-8 -*-
3 | import re
4 | import sys
5 | from setuptools.command.easy_install import main
6 | if __name__ == '__main__':
7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
8 | sys.exit(main())
9 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/pip:
--------------------------------------------------------------------------------
1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9
2 | # -*- coding: utf-8 -*-
3 | import re
4 | import sys
5 | from pip._internal.cli.main import main
6 | if __name__ == '__main__':
7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
8 | sys.exit(main())
9 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/pip3:
--------------------------------------------------------------------------------
1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9
2 | # -*- coding: utf-8 -*-
3 | import re
4 | import sys
5 | from pip._internal.cli.main import main
6 | if __name__ == '__main__':
7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
8 | sys.exit(main())
9 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/pip3.9:
--------------------------------------------------------------------------------
1 | #!/usr/fl/flexi/effects/uk.ltd.filmlight.rife_retime.v1/bin/python3.9
2 | # -*- coding: utf-8 -*-
3 | import re
4 | import sys
5 | from pip._internal.cli.main import main
6 | if __name__ == '__main__':
7 | sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
8 | sys.exit(main())
9 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/python:
--------------------------------------------------------------------------------
1 | python3.9
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/python3:
--------------------------------------------------------------------------------
1 | python3.9
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/bin/python3.9:
--------------------------------------------------------------------------------
1 | /usr/bin/python3.9
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/flexi.py:
--------------------------------------------------------------------------------
1 | # Flexi Python module
2 | #
3 | # Copyright (c) 2021 FilmLight. All rights reserved.
4 | # Unpublished - rights reserved under the copyright laws of the
5 | # United States. Use of a copyright notice is precautionary only
6 | # and does not imply publication or disclosure.
7 | # This software contains confidential information and trade secrets
8 | # of FilmLight Limited. Use, disclosure, or reproduction
9 | # is prohibited without the prior express written permission of
10 | # FilmLight Limited.
11 |
12 | import sys
13 | import json
14 | import websocket
15 | import traceback
16 | import os
17 |
18 | # Globals
19 | Flexi_registered_effects = {}
20 | Flexi_effect_for_id = {}
21 | Flexi_ws = None
22 |
23 | def register_id(id, effect):
24 | global Flexi_effect_for_id, Flexi_registered_effects
25 | Flexi_registered_effects[effect] = 1
26 | Flexi_effect_for_id[id] = effect
27 |
28 | def run():
29 | # Open connection to the application using the URL passed in the environment
30 | global Flexi_ws
31 |
32 |     if 'FLEXI_SERVER_URL' not in os.environ:
33 | print('flexi.run() called without websocket URL. Running in test mode.')
34 | from jsonschema import validate
35 | with open('flexi_reply.schema') as f:
36 | describe_effect_schema = json.load(f)
37 | class DummyWS:
38 | def send(self, s):
39 | self.reply = json.loads(s)
40 | # print(json.dumps(self.reply, indent=2))
41 | # if 'log' in self.reply:
42 | # print(json.dumps(reply['log'], indent=2))
43 | # if 'data' in self.reply:
44 | # print(json.dumps(reply['data'], indent=2))
45 | Flexi_ws = DummyWS()
46 | for id in Flexi_effect_for_id:
47 | print(f'Validating describe_effect for {id[0]} v{id[1]}')
48 | _process_message(
49 | {'effect' : [id[0], id[1], 'instance'],
50 | 'method' : 'describe_effect_v1',
51 | 'id' : 2,
52 | 'data' : {}})
53 | validate(Flexi_ws.reply, describe_effect_schema)
54 | return
55 |
56 | url = os.environ['FLEXI_SERVER_URL']
57 |
58 | Flexi_ws = websocket.create_connection(url)
59 |
60 | # print('Connected to {}'.format(url))
61 |
62 | # Processing loop
63 | while True:
64 | try:
65 | message = json.loads(Flexi_ws.recv())
66 | except websocket._exceptions.WebSocketConnectionClosedException:
67 | print('Connection closed')
68 | _shutdown()
69 | if isinstance(message, list):
70 | # TODO could consider batching replies
71 | for entry in message:
72 | _process_message(entry)
73 | else:
74 | _process_message(message)
75 |
76 | # Process one message from the application
77 | def _process_message(message):
78 | global Flexi_effect_for_id, Flexi_registered_effects
79 |
80 | if 'method' in message and message['method'] == 'shutdown_v1':
81 | _reply(message, {})
82 | _shutdown()
83 |
84 | try:
85 | id = (message['effect'][0], message['effect'][1])
86 |         if id not in Flexi_effect_for_id:
87 | return _reply(message, error('Unregistered effect', id[0]))
88 | effect = Flexi_effect_for_id[id]
89 | effect._handle_message(id, message)
90 |
91 | except SystemExit as e:
92 | sys.exit(e)
93 |
94 | except:
95 | print('Internal error', format(sys.exc_info()))
96 | traceback.print_exc()
97 | _reply(message, error('Internal error', format(sys.exc_info())))
98 |
99 | # Send a reply to a message
100 | def _reply(message, data):
101 | global Flexi_ws
102 |     if data is not None and '__LOG__' in data:
103 | reply = {'effect' : message['effect'],
104 | 'method' : message['method'].upper(),
105 | 'id' : message['id'],
106 | 'log' : data['__LOG__']
107 | }
108 | else:
109 | reply = {'effect' : message['effect'],
110 | 'method' : message['method'].upper(),
111 | 'id' : message['id'],
112 | 'data' : data
113 | }
114 | Flexi_ws.send(json.dumps(reply))
115 |
116 | # Make a reply with an error to a message
117 | def error(error, detail):
118 | return { '__LOG__' : [ { 'level' : 3, 'message' : error, 'detail' : detail } ] }
119 |
120 | def _shutdown():
121 | global Flexi_ws
122 | for id in Flexi_effect_for_id:
123 | effect = Flexi_effect_for_id[id]
124 | effect.shutdown()
125 |     if Flexi_ws is not None:
126 | Flexi_ws.close()
127 | print('Exiting')
128 | sys.exit(0)
129 |
130 | # Effect class - subclass this to define your effect
131 |
132 | class Effect:
133 |
134 | def init(self, ids):
135 | """Initialise the effect object.
136 | ids is a list of tuples ('identifier', version)
137 | """
138 | for id in ids:
139 | register_id(id, self)
140 |
141 | def _handle_message(self, id, message):
142 | """Internal method"""
143 | gpu = None
144 | if 'data' in message and 'cuda_gpu' in message['data']:
145 | cuda_gpu = message['data']['cuda_gpu']
146 | if cuda_gpu != '':
147 | gpu = cuda_gpu
148 | if 'data' in message and 'metal_device_index' in message['data']:
149 | metal_device_index = message['data']['metal_device_index']
150 | if metal_device_index != -1:
151 | gpu = metal_device_index
152 | if message['method'] == 'describe_effect_v1':
153 | _reply(message, self.describe_effect(id))
154 | self.setup_if_necessary(id, gpu)
155 | elif message['method'] == 'query_vram_requirement_v1':
156 | _reply(message, self.query_vram_requirement(id, message['data']))
157 | elif message['method'] == 'set_vram_limit_v1':
158 | res = self.set_vram_limit(id, message['data'])
159 | if(res == 'quit'):
160 | _reply(message, {})
161 | _shutdown()
162 | else:
163 | _reply(message, res)
164 | elif message['method'] == 'run_generate_v1':
165 | self.setup_if_necessary(id, gpu)
166 | _reply(message, self.run_generate(id, message['effect'][2], message['data']))
167 | else:
168 | _reply(message, error('Unknown method', message['method']))
169 |
170 |     def describe_effect(self, id, message=None):
171 | """Subclass must override this to QUICKLY return a description of a particular effect.
172 | id is a tuple ('identifier', version, 'instance')
173 | message will have members 'host_identifer', 'host_version' which can be used
174 | to change how the effect behaves, or indeed to fail to describe it
175 | """
176 | raise RuntimeError('Effect subclass must implement describe_effect() method')
177 |
178 | def setup_if_necessary(self, id, gpu):
179 | """Subclass may override this to do expensive setup.
180 | gpu is (on Linux) a string which is the CUDA UUID of the GPU to use,
181 | or (on macOS) an integer index into the result of MTLCopyAllDevices()"""
182 | return
183 |
184 | def query_vram_requirement(self, id, data):
185 | """Subclass must override this to QUICKLY return the VRAM requirement of a particular effect.
186 | id is a tuple ('identifier', version, 'instance')
187 | data contains params, width and height
188 | Reply must contain 'min_mb' and 'max_mb'
189 | """
190 | raise RuntimeError('Effect subclass must implement query_vram_requirement() method')
191 |
192 | def set_vram_limit(self, id, data):
193 | """Subclass must override this to IMMEDIATELY restrict its current and future
194 | VRAM usage to the supplied value data['limit_mb'] in MB.
195 |
196 | If limit_mb is zero, the effect must release ALL resources on the indicated device.
197 |
198 |         Method must not return until its GPU usage is at or below the supplied limit,
199 |         except in the special case where a GPU SDK is being used which does not support
200 |         releasing resources; then this method can return 'quit', which will cause the
201 |         process to exit.
202 |
203 | id is a tuple ('identifier', version, 'instance')
204 | """
205 | raise RuntimeError('Effect subclass must implement set_vram_limit() method')
206 |
207 | def run_generate(self, id, instance, data):
208 | """Subclass must override this to run an effect.
209 | id is a tuple ('identifier', version, 'instance')
210 | instance is a string that might remain constant for each instance of the effect
211 | in the app
212 | data contains the inputs, outputs, parameters etc
213 | """
214 | raise RuntimeError('Effect subclass must implement run_generate() method')
215 |
216 | def shutdown(self):
217 | """Subclass may override this to do a cleanup prior to shutdown"""
218 | return
219 |
220 | # Exports
221 | __all__ = ["Effect", "run"]
222 |
--------------------------------------------------------------------------------
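
For orientation, here is a minimal sketch of how the `Effect` base class above is meant to be subclassed. The identifier `com.example.noop`, the VRAM figures and the pass-through bodies are hypothetical, not part of the shipped plugin:

```python
# Hypothetical minimal Flexi effect, assuming the flexi module above is on the path.
import flexi

class NoOpEffect(flexi.Effect):
    def __init__(self):
        # Register one (identifier, version) pair with the dispatcher.
        self.init([('com.example.noop', 1)])

    def describe_effect(self, id):
        # Empty description: the effect is not exposed in the host UI.
        return {}

    def query_vram_requirement(self, id, data):
        # Must answer quickly; the figures here are placeholders.
        return {'min_mb': 128, 'max_mb': 256}

    def set_vram_limit(self, id, data):
        # Nothing is held on the GPU, so any limit is already satisfied.
        return {}

    def run_generate(self, id, instance, data):
        # A real effect would read data['inputs'] and fill data['output'].
        return {'metadata': {'NoOp': 'Success'}}

if __name__ == '__main__':
    NoOpEffect()   # registers itself via Effect.init()
    flexi.run()    # enter the websocket message loop
```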
/baselight/uk.co.andriy.mltimewarp.v1/lib64:
--------------------------------------------------------------------------------
1 | lib
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/pyvenv.cfg:
--------------------------------------------------------------------------------
1 | home = /usr/bin
2 | include-system-site-packages = false
3 | version = 3.9.6
4 |
--------------------------------------------------------------------------------
/baselight/uk.co.andriy.mltimewarp.v1/timewarp.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import flexi
3 | import json
4 | import numpy
5 | import platform
6 | if (platform.system() == 'Darwin'):
7 | from multiprocessing import shared_memory
8 | elif (platform.system() == 'Linux'):
9 | import cupy
10 | from numba import cuda
11 | import torch
12 | from torch.nn import functional as F
13 | import math, time
14 |
15 | def get_rife_model(req_device="cuda"):
16 | from RIFE_HDv3 import Model
17 | if (req_device == "mps" and not torch.backends.mps.is_available()):
18 | req_device = "cpu"
19 | if (req_device == "cuda" and not torch.cuda.is_available()):
20 | req_device = "cpu"
21 |
22 | device = torch.device(req_device)
23 |
24 | model = Model()
25 | model.load_model("RIFE/model/", device, -1)
26 | torch.set_grad_enabled(False)
27 |
 28 |     if req_device == "cuda":
29 | torch.backends.cudnn.enabled = True
30 | torch.backends.cudnn.benchmark = True
31 |
32 | model.eval()
33 | model.device(device)
34 |
35 | return model, req_device
36 |
37 | def rife_inference(model, device, src1, src2, scale, ratio, nbPass, width, height):
38 |
39 | img1 = torch.as_tensor(src1, device=device)
40 | img2 = torch.as_tensor(src2, device=device)
41 |
42 | img1 = torch.reshape(img1, (4, height, width)).unsqueeze(0)
43 | img2 = torch.reshape(img2, (4, height, width)).unsqueeze(0)
44 |
45 | n, c, h, w = img1.shape
46 |
47 | #determine the padding
48 | pad = max(int(32 / scale), 128)
49 | ph = ((h - 1) // pad + 1) * pad
50 | pw = ((w - 1) // pad + 1) * pad
51 | padding = (0, pw - w, 0, ph - h)
52 | img1 = F.pad(img1, padding)
53 | img2 = F.pad(img2, padding)
54 |
55 | lf = 0.
56 | rf = 1.
57 |
 58 |     scale_list = [4, 2, 1]  # appears unused here: model.inference receives the scalar 'scale'
59 | for i in range(nbPass):
60 | if ratio == lf:
61 | res = img1
62 | break
63 | elif ratio == rf:
64 | res = img2
65 | break
66 |
67 | f = (ratio - lf) / (rf - lf) if(i == nbPass - 1) else 0.5
68 |
69 | res = model.inference(img1, img2, scale, f)
70 |
 71 |         if i != nbPass - 1:
72 | if ratio <= (lf + rf) / 2.:
73 | img2 = res
74 | rf = (lf + rf) / 2.
75 | else:
76 | img1 = res
77 | lf = (lf + rf) / 2.
78 | _, _, h2, w2 = res.shape
79 |
80 | out = res[0]
81 | if (platform.system() == 'Linux'):
82 | return out[:, :h, :w].contiguous()
83 | else:
84 | return out[:, :h, :w].contiguous().cpu()
85 |
86 | ######
87 | #
88 | # RIFE Retime
89 | #
90 | class RIFERetime(flexi.Effect):
91 |
92 | def __init__(self):
93 | self.init([('uk.ltd.filmlight.rife_retime', 1)])
94 | self.model = None
95 | self.vram_amount = 0
96 |
97 | def describe_effect(self, id):
98 | # We don't expose the effect
99 | return {}
100 |
101 | def query_vram_requirement(self, id, data):
102 | width = data['width']
103 | height = data['height']
104 | params = data['params']
105 | amount = 3. * math.sqrt(width * height)
106 |         # The first instantiation requires more VRAM, so request extra headroom
107 |         # if the model hasn't been loaded yet
108 | return {'min_mb' : amount, 'max_mb' : 1.1 * amount}
109 |
110 | def set_vram_limit(self, id, data):
111 | limit = data['limit_mb']
112 |         # If the limit is not enough, we need to tear the effect down and quit
113 | #(don't bother if we haven't loaded the model)
114 | if(self.model is not None and limit < self.vram_amount):
115 | return 'quit'
116 | return {'RIFE' : 'Success'}
117 |
118 | def run_generate(self, id, instance, data):
119 | # No need to verify identifier or version, we only support one
120 | inputs = data['inputs']
121 | output = data['output']
122 | params = data['params']
123 |
124 | width = data['width']
125 | height = data['height']
126 | timestamp = params['timestamp']
127 | scale = params['scale']
128 | nbPass = int(params['pass'])
129 |
130 | #Account for VRAM usage
131 | amount = self.query_vram_requirement(id, data)['min_mb']
132 | if(self.vram_amount < amount):
133 | self.vram_amount = amount
134 |
135 |         # Instantiate the model if necessary with the correct backend
136 | if(self.model is None):
137 | device = "cpu"
138 | if (platform.system() == 'Darwin'):
139 | device = "mps"
140 | elif (platform.system() == 'Linux'):
141 | device = "cuda"
142 |
143 | self.model, self.device = get_rife_model(device)
144 |
145 |
146 | #start_t = time.time()
147 |
148 | if (platform.system() == 'Darwin'): #SHM
149 | out_mem = shared_memory.SharedMemory(name=output['shm'].lstrip('/'))
150 | src_mem = shared_memory.SharedMemory(name=inputs['src:0']['shm'].lstrip('/'))
151 |             src2_mem = shared_memory.SharedMemory(name=inputs.get('src:1', '-')['shm'].lstrip('/'))  # assumes 'src:1' is always supplied; the '-' fallback would raise
152 |
153 | #map the shared memory into ndarray objects
154 | src1_array = numpy.ndarray((4, height, width), numpy.float32, src_mem.buf)
155 | src2_array = numpy.ndarray((4, height, width), numpy.float32, src2_mem.buf)
156 |
157 | out_array = numpy.ndarray((4, height, width), numpy.float32, out_mem.buf)
158 | out_array[:,:,:] = rife_inference(self.model, self.device, src1_array, src2_array, scale, timestamp, nbPass, width, height)
159 | #close the shared memory
160 | src_mem.close()
161 | src2_mem.close()
162 | out_mem.close()
163 |
164 |         else: #CUDA IPC on Linux
165 | shape = height * width * 16
166 | dtype = numpy.dtype(numpy.float32)
167 |
168 |             # handles arrive hex-encoded; convert each to 64 raw bytes
169 | ihandle1 = int(inputs['src:0']['cuda_ipc'], base=16).to_bytes(64, byteorder='big')
170 | ihandle2 = int(inputs.get('src:1','-')['cuda_ipc'], base=16).to_bytes(64, byteorder='big')
171 | ohandle = int(output['cuda_ipc'], base=16).to_bytes(64, byteorder='big')
172 |
173 | with cuda.open_ipc_array(ihandle1, shape=shape // dtype.itemsize, dtype=dtype) as src1_array:
174 | with cuda.open_ipc_array(ihandle2, shape=shape // dtype.itemsize, dtype=dtype) as src2_array:
175 | with cuda.open_ipc_array(ohandle, shape=shape // dtype.itemsize, dtype=dtype) as out_array:
176 |
177 | res = rife_inference(self.model, self.device, src1_array, src2_array, scale, timestamp, nbPass, width, height)
178 |
179 |                         cupy.cuda.runtime.memcpy(out_array.gpu_data._mem.handle.value, res.__cuda_array_interface__['data'][0], height * width * 4 * 4, 4)  # RGBA float32 bytes; kind 4 = cudaMemcpyDefault
180 | cuda.synchronize()
181 | torch.cuda.empty_cache()
182 |
183 | return {'metadata' : {'RIFE' : 'Success'}}
184 |
185 | if __name__ == '__main__':
186 |
187 | # Register each effect
188 | RIFERetime()
189 |
190 | # and run
191 | flexi.run()
192 |
--------------------------------------------------------------------------------
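
The `nbPass` loop in `rife_inference` above is a binary search: every pass except the last interpolates at the midpoint of the current `[lf, rf]` interval and keeps the half that contains the requested `ratio`, and the final pass lands exactly on the fractional position. A torch-free trace of that interval arithmetic, showing which time positions each pass would evaluate:

```python
def bisect_ratio(ratio, nb_pass):
    """Trace the time positions the rife_inference loop evaluates."""
    lf, rf = 0.0, 1.0
    points = []
    for i in range(nb_pass):
        if ratio in (lf, rf):
            break  # an interval endpoint already matches: no inference needed
        # Final pass hits the exact requested ratio; earlier passes halve.
        f = (ratio - lf) / (rf - lf) if i == nb_pass - 1 else 0.5
        points.append(lf + f * (rf - lf))
        if i != nb_pass - 1:
            if ratio <= (lf + rf) / 2.0:
                rf = (lf + rf) / 2.0  # keep the left half
            else:
                lf = (lf + rf) / 2.0  # keep the right half
    return points

print(bisect_ratio(0.3, 3))  # [0.5, 0.25, 0.3]
```

More passes trade extra inference calls for interpolation positions that stay close to 0.5, where this kind of model is typically most reliable.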
/baselight/uk.co.andriy.mltimewarp.v1/uk.co.dirtylooks.mltimewarp.v1.flexi:
--------------------------------------------------------------------------------
1 | {
2 | "flexi_effects_definition_v1" : {
3 | "effects" : [
4 | {
5 | "identifier": "uk.co.dirtylooks.mltimewarp",
6 | "version": 1,
7 | "label": "DL ML Timewarp",
8 | "description": "DL ML Timewarp",
9 | "licence" : "LICENCE.md",
10 | "type" : "internal"
11 | }
12 | ],
13 | "launch_command": "./bin/python3 -W ignore ./timewarp.py",
14 | "install": {
15 | "all": {
16 | "flexi_platform": "pytorch-2.1",
17 | "sources": {
18 | "flexi.py": "../flexi.py",
19 | "timewarp.py": "timewarp.py",
20 | "LICENCE.md" : "LICENCE.md"
21 | },
22 | "weights" : {
23 | "Flownet001" : "flownet/flownet001.pth"
24 | },
25 | "run": []
26 | }
27 | }
28 | }
29 | }
30 |
31 |
--------------------------------------------------------------------------------
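
The `.flexi` manifest above is plain JSON, so it can be sanity-checked before deployment with the standard library. A small sketch; the required-key list is an assumption read off this file, not a published schema:

```python
import json

REQUIRED = ('identifier', 'version', 'label', 'type')  # assumed from the manifest above

with open('uk.co.dirtylooks.mltimewarp.v1.flexi') as f:
    manifest = json.load(f)['flexi_effects_definition_v1']

for effect in manifest['effects']:
    missing = [key for key in REQUIRED if key not in effect]
    if missing:
        raise ValueError(f"{effect.get('identifier', '?')}: missing {missing}")

print('manifest OK:', manifest['launch_command'])
```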
/find_scale.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Resolves the directory where the script is located
4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
5 |
6 | # Change to the script directory
7 | cd "$SCRIPT_DIR"
8 |
9 | # Define the path to the Python executable and the Python script
10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python"
11 | PYTHON_SCRIPT="./pytorch/flameTimewarpML_findscale.py"
12 |
13 | # Run the Python script with all arguments passed to this shell script
14 | "$PYTHON_CMD" "$PYTHON_SCRIPT" "$@"
15 |
--------------------------------------------------------------------------------
/fix-xattr.command:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define the target directory relative to the script location
4 | TARGET_DIR="$(dirname "$0")/packages/.miniconda"
5 |
6 | # Check if the target directory exists
7 | if [ ! -d "$TARGET_DIR" ]; then
8 | echo "Directory $TARGET_DIR does not exist."
9 | exit 1
10 | fi
11 |
12 | # Find all files in the target directory and clear extended attributes
13 | find "$TARGET_DIR" -type f -exec xattr -c {} \;
14 |
15 | echo "Extended attributes cleared from all files in $TARGET_DIR"
--------------------------------------------------------------------------------
/fonts/DejaVuSansMono.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/talosh/flameTimewarpML/cec31a5436d9c40113f635152c1c0da110aadebd/fonts/DejaVuSansMono.ttf
--------------------------------------------------------------------------------
/models/flownet2_v004.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:40245beb4a757320f67f714ee5bf3f33c613c478bd01f30ecbb57b1ce445e9ec
3 | size 339780714
4 |
--------------------------------------------------------------------------------
/packages/.miniconda/README.md:
--------------------------------------------------------------------------------
1 | # flameTimewarpML
2 | * the pre-configured miniconda environment is to be placed here, into the "appenv" folder
--------------------------------------------------------------------------------
/packages/README.md:
--------------------------------------------------------------------------------
1 | ## flameTimewarpML
2 |
 3 | * the pre-configured miniconda environment should be placed into the hidden "packages/.miniconda" folder
 4 | * the folder is hidden (starts with ".") in order to keep Flame from scanning it looking for python hooks
 5 | * the pre-configured python environment is usually packed with the release tar file
6 |
7 | ### Installing and configuring python environment manually
8 |
9 | * download Miniconda for Mac or Linux (I'm using python 3.11 for tests) from
10 |
11 |
12 | * install downloaded Miniconda python distribution, use "-p" to select install location. For example:
13 |
14 | ```bash
15 | sh ~/Downloads/Miniconda3-py311_24.1.2-0-Linux-x86_64.sh -bfsm -p ~/miniconda3
16 | ```
17 |
18 | * Activate and clone the default environment into another one named "appenv"
19 |
20 | ```bash
21 | eval "$(~/miniconda3/bin/conda shell.bash hook)"
22 | conda create --name appenv --clone base
23 | conda activate appenv
24 | ```
25 |
26 | * Install dependency libraries
27 |
28 | ```bash
29 | conda install pyqt
30 | conda install numpy
31 | conda install conda-pack
32 | ```
33 |
34 | * Install pytorch. Please look up the exact commands for your OS and CUDA version at <>
35 |
36 | * Linux example:
37 | ```bash
38 | conda install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia
39 | ```
40 |
41 | * macOS example:
42 |
43 | ```bash
44 | conda install pytorch::pytorch -c pytorch
45 | ```
46 |
47 | * Install the rest of the dependencies
48 | ```bash
49 | pip install -r requirements.txt
50 | ```
51 |
52 | * Pack the appenv environment into a portable tar file
53 |
54 | ```bash
55 | conda pack --ignore-missing-files -n appenv
56 | ```
57 |
58 | * Unpack the environment into the flameTimewarpML folder
59 |
60 | ```bash
61 | mkdir {flameTimewarpML folder}/packages/.miniconda/appenv/
62 | tar xvf appenv.tar.gz -C {flameTimewarpML folder}/packages/.miniconda/appenv/
63 | ```
64 |
65 | * Remove environment tarball
66 |
67 | ```bash
68 | rm appenv.tar.gz
69 | ```
70 |
--------------------------------------------------------------------------------
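
After unpacking, it is worth confirming that the relocated interpreter can resolve its own site-packages. A quick check along these lines; the interpreter path mirrors the one hard-coded in find_scale.sh, and the module list is an assumption based on the installs above:

```python
import subprocess

# Same interpreter path that find_scale.sh uses
python = './packages/.miniconda/appenv/bin/python'

for module in ('numpy', 'torch', 'PyQt5'):
    code = subprocess.call([python, '-c', f'import {module}'])
    print(module, 'OK' if code == 0 else 'FAILED')
```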
/packages/clean_pycache.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define the target directory. Default is the current directory.
4 | TARGET_DIR="${1:-.}"
5 |
6 | # Find all __pycache__ directories and remove them
7 | find "$TARGET_DIR" -type d -name "__pycache__" -exec rm -r {} +
8 |
9 | echo "All __pycache__ directories have been removed from $TARGET_DIR"
--------------------------------------------------------------------------------
/presets/openexr16bit.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | image
4 | Creates 16-bit fp OpenEXR file sequence (Uncompressed) numbered based on the timecode of the selected clip.
5 |
29 |
30 | 8
31 | 1
32 | 1
33 |
34 |
--------------------------------------------------------------------------------
/presets/openexr32bit.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | image
4 | Creates 32-bit fp OpenEXR file sequence (Uncompressed) numbered based on the timecode of the selected clip.
5 |
29 |
30 | 8
31 | 1
32 | 1
33 |
34 |
--------------------------------------------------------------------------------
/presets/source_export.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | sequence
4 | Source export 16-bit fp OpenEXR (Uncompressed)
5 |
6 | NONE
7 |
8 | <name>
9 | True
10 | True
11 |
12 | image
13 | Original
14 | NoChange
15 | True
16 | 99999
17 |
18 | True
19 | False
20 |
21 | audio
22 | FX
23 | FlattenTracks
24 | True
25 | 10
26 |
27 |
28 |
53 |
54 | 8
55 | 1
56 | 0
57 |
58 |
--------------------------------------------------------------------------------
/presets/source_export32bit.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | sequence
4 | Source export 32-bit fp OpenEXR (Uncompressed)
5 |
6 | NONE
7 |
8 | <name>
9 | True
10 | True
11 |
12 | image
13 | Original
14 | NoChange
15 | True
16 | 99999
17 |
18 | True
19 | False
20 |
21 | audio
22 | FX
23 | FlattenTracks
24 | True
25 | 10
26 |
27 |
28 |
53 |
54 | 8
55 | 1
56 | 0
57 |
58 |
--------------------------------------------------------------------------------
/pytorch/check_consistent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import importlib
5 | import platform
6 |
7 | os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
8 |
9 | try:
10 | import numpy as np
 11 | except ImportError:
12 | python_executable_path = sys.executable
13 | if '.miniconda' in python_executable_path:
14 | print (f'Using "{python_executable_path}" as python interpreter')
15 | print ('Unable to import Numpy')
16 | sys.exit()
17 |
18 | try:
19 | import OpenImageIO as oiio
 20 | except ImportError:
21 | python_executable_path = sys.executable
22 | if '.miniconda' in python_executable_path:
23 | print ('Unable to import OpenImageIO')
24 | print (f'Using "{python_executable_path}" as python interpreter')
25 | sys.exit()
26 |
27 | def read_image_file(file_path, header_only = False):
28 | result = {'spec': None, 'image_data': None}
29 | inp = oiio.ImageInput.open(file_path)
30 | if inp :
31 | spec = inp.spec()
32 | result['spec'] = spec
33 | if not header_only:
34 | channels = spec.nchannels
35 | result['image_data'] = inp.read_image(0, 0, 0, channels).transpose(1, 0, 2)
36 | inp.close()
37 | return result
38 |
39 | def find_folders_with_exr(path):
40 | """
41 | Find all folders under the given path that contain .exr files.
42 |
43 | Parameters:
44 | path (str): The root directory to start the search from.
45 |
46 | Returns:
 47 |     set: A set of directories containing .exr files.
48 | """
49 | directories_with_exr = set()
50 |
51 | # Walk through all directories and files in the given path
52 | for root, dirs, files in os.walk(path):
53 | if 'preview' in root:
54 | continue
55 | if 'eval' in root:
56 | continue
57 | for file in files:
58 | if file.endswith('.exr'):
59 | directories_with_exr.add(root)
60 | break # No need to check other files in the same directory
61 |
62 | return directories_with_exr
63 |
64 | def clear_lines(n=2):
65 | """Clears a specified number of lines in the terminal."""
66 | CURSOR_UP_ONE = '\x1b[1A'
67 | ERASE_LINE = '\x1b[2K'
68 | for _ in range(n):
69 | sys.stdout.write(CURSOR_UP_ONE)
70 | sys.stdout.write(ERASE_LINE)
71 |
72 | def main():
 73 |     parser = argparse.ArgumentParser(description='Check exr sequences for missing or non-consecutive frame numbers')
74 |
75 | # Required argument
76 | parser.add_argument('dataset_path', type=str, help='Path to folders with exrs')
77 | args = parser.parse_args()
78 |
79 | folders_with_exr = find_folders_with_exr(args.dataset_path)
80 |
81 | for folder_idx, folder_path in enumerate(sorted(folders_with_exr)):
82 | print (f'\rFolder [{folder_idx+1} / {len(folders_with_exr)}], {folder_path}', end='')
83 | folder_exr_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path) if file.endswith('.exr')]
84 | folder_exr_files.sort()
85 |
86 | frame_numbers = []
87 | for filename in folder_exr_files:
88 | parts = filename.split('.')
89 | try:
90 | frame_number = int(parts[-2])
91 | frame_numbers.append(frame_number)
92 | except ValueError:
93 | print (f'\rFormat error in {folder_path}: {filename}')
94 |
95 | frame_numbers.sort()
96 | for i in range(1, len(frame_numbers)):
97 | if frame_numbers[i] != frame_numbers[i - 1] + 1:
98 | print(f'\r{folder_path}: Missing or non-consecutive frame between {frame_numbers[i-1]} and {frame_numbers[i]}')
99 |
100 | print ('')
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
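
The consistency test above reduces to comparing each sorted frame number with its predecessor. A self-contained illustration of that check (the `find_gaps` helper is illustrative, not part of the script):

```python
def find_gaps(frame_numbers):
    """Return (previous, next) pairs wherever the sequence is not consecutive."""
    frames = sorted(frame_numbers)
    return [(frames[i - 1], frames[i])
            for i in range(1, len(frames))
            if frames[i] != frames[i - 1] + 1]

print(find_gaps([1001, 1002, 1005, 1006]))  # [(1002, 1005)]
```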
/pytorch/check_readable.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import importlib
5 | import platform
6 |
7 | os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
8 |
9 | try:
10 | import numpy as np
 11 | except ImportError:
12 | python_executable_path = sys.executable
13 | if '.miniconda' in python_executable_path:
14 | print (f'Using "{python_executable_path}" as python interpreter')
15 | print ('Unable to import Numpy')
16 | sys.exit()
17 |
18 | try:
19 | import OpenImageIO as oiio
 20 | except ImportError:
21 | python_executable_path = sys.executable
22 | if '.miniconda' in python_executable_path:
23 | print ('Unable to import OpenImageIO')
24 | print (f'Using "{python_executable_path}" as python interpreter')
25 | sys.exit()
26 |
27 | def read_image_file(file_path, header_only = False):
28 | result = {'spec': None, 'image_data': None}
29 | inp = oiio.ImageInput.open(file_path)
30 | if inp :
31 | spec = inp.spec()
32 | result['spec'] = spec
33 | if not header_only:
34 | channels = spec.nchannels
35 | result['image_data'] = inp.read_image(0, 0, 0, channels).transpose(1, 0, 2)
36 | inp.close()
37 | return result
38 |
39 | def find_folders_with_exr(path):
40 | """
41 | Find all folders under the given path that contain .exr files.
42 |
43 | Parameters:
44 | path (str): The root directory to start the search from.
45 |
46 | Returns:
 47 |     set: A set of directories containing .exr files.
48 | """
49 | directories_with_exr = set()
50 |
51 | # Walk through all directories and files in the given path
52 | for root, dirs, files in os.walk(path):
53 | if root.endswith('preview'):
54 | continue
55 | if root.endswith('eval'):
56 | continue
57 | for file in files:
58 | if file.endswith('.exr'):
59 | directories_with_exr.add(root)
60 | break # No need to check other files in the same directory
61 |
62 | return directories_with_exr
63 |
64 | def clear_lines(n=2):
65 | """Clears a specified number of lines in the terminal."""
66 | CURSOR_UP_ONE = '\x1b[1A'
67 | ERASE_LINE = '\x1b[2K'
68 | for _ in range(n):
69 | sys.stdout.write(CURSOR_UP_ONE)
70 | sys.stdout.write(ERASE_LINE)
71 |
72 | def main():
 73 |     parser = argparse.ArgumentParser(description='Check that all exr files under a path are readable')
74 |
75 | # Required argument
76 | parser.add_argument('dataset_path', type=str, help='Path to folders with exrs')
77 | args = parser.parse_args()
78 |
79 | folders_with_exr = find_folders_with_exr(args.dataset_path)
80 |
81 | exr_files = []
82 |
83 | for folder_path in sorted(folders_with_exr):
84 | folder_exr_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path) if file.endswith('.exr')]
85 | folder_exr_files.sort()
86 | exr_files.extend(folder_exr_files)
87 |
88 | idx = 0
89 | for exr_file_path in exr_files:
90 | # clear_lines(1)
91 | print (f'\rFile [{idx+1} / {len(exr_files)}], {os.path.basename(exr_file_path)}', end='')
92 | try:
93 | result = read_image_file(exr_file_path)
94 | w, h, c = result['image_data'].shape
95 | except Exception as e:
96 | print (f'\nError reading {exr_file_path}: {e}')
97 | idx += 1
98 | print ('')
99 |
100 | if __name__ == "__main__":
101 | main()
102 |
--------------------------------------------------------------------------------
/pytorch/hub/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | This folder is needed to hold alexnet-owt-7be5be79.pth (the AlexNet weights used by LPIPS)
--------------------------------------------------------------------------------
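
LPIPS pulls its AlexNet backbone through torch.hub, which by default caches checkpoints under ~/.cache/torch/hub/checkpoints. Keeping the file here presumably relies on the training scripts pointing the hub cache at this local folder, along these lines:

```python
import torch

# Assumption: scripts under ./pytorch redirect the hub cache to the local
# 'hub' folder, so downloads land in hub/checkpoints/ next to this README.
torch.hub.set_dir('pytorch/hub')
print(torch.hub.get_dir())  # .../pytorch/hub
```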
/pytorch/models/archived/flownet4_v001d.py:
--------------------------------------------------------------------------------
1 | # Orig v001 changed to v002 main flow and signatures
2 | # Compile test
3 |
4 | class Model:
5 | def __init__(self, status = dict(), torch = None):
6 | if torch is None:
7 | import torch
8 | Module = torch.nn.Module
9 | backwarp_tenGrid = {}
10 |
11 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
12 | return torch.nn.Sequential(
13 | torch.nn.Conv2d(
14 | in_planes,
15 | out_planes,
16 | kernel_size=kernel_size,
17 | stride=stride,
18 | padding=padding,
19 | dilation=dilation,
20 | padding_mode = 'reflect',
21 | bias=True
22 | ),
23 | torch.nn.LeakyReLU(0.2, True)
24 | # torch.nn.SELU(inplace = True)
25 | )
26 |
27 | def warp(tenInput, tenFlow):
28 | k = (str(tenFlow.device), str(tenFlow.size()))
29 | if k not in backwarp_tenGrid:
30 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
31 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
32 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
33 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
34 |
35 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
36 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='reflection', align_corners=True)
37 |
38 | class Head(Module):
39 | def __init__(self):
40 | super(Head, self).__init__()
41 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1)
42 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
43 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
44 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
45 | self.relu = torch.nn.LeakyReLU(0.2, True)
46 |
47 | torch.nn.init.kaiming_normal_(self.cnn0.weight, mode='fan_in', nonlinearity='relu')
48 | self.cnn0.weight.data *= 1e-2
49 | if self.cnn0.bias is not None:
50 | torch.nn.init.constant_(self.cnn0.bias, 0)
51 | torch.nn.init.kaiming_normal_(self.cnn1.weight, mode='fan_in', nonlinearity='relu')
52 | self.cnn1.weight.data *= 1e-2
53 | if self.cnn1.bias is not None:
54 | torch.nn.init.constant_(self.cnn1.bias, 0)
55 | torch.nn.init.kaiming_normal_(self.cnn2.weight, mode='fan_in', nonlinearity='relu')
56 | self.cnn2.weight.data *= 1e-2
57 | if self.cnn2.bias is not None:
58 | torch.nn.init.constant_(self.cnn2.bias, 0)
59 | torch.nn.init.kaiming_normal_(self.cnn3.weight, mode='fan_in', nonlinearity='relu')
60 | self.cnn3.weight.data *= 1e-2
61 | if self.cnn3.bias is not None:
62 | torch.nn.init.constant_(self.cnn3.bias, 0)
63 |
64 | def forward(self, x, feat=False):
65 | x0 = self.cnn0(x)
66 | x = self.relu(x0)
67 | x1 = self.cnn1(x)
68 | x = self.relu(x1)
69 | x2 = self.cnn2(x)
70 | x = self.relu(x2)
71 | x3 = self.cnn3(x)
72 | if feat:
73 | return [x0, x1, x2, x3]
74 | return x3
75 |
76 | class ResConv(Module):
77 | def __init__(self, c, dilation=1):
78 | super().__init__()
79 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True)
80 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
81 | self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True)
82 |
83 | torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu')
84 | self.conv.weight.data *= 1e-2
85 | if self.conv.bias is not None:
86 | torch.nn.init.constant_(self.conv.bias, 0)
87 |
88 | def forward(self, x):
89 | return self.relu(self.conv(x) * self.beta + x)
90 |
91 | class Flownet(Module):
92 | def __init__(self, in_planes, c=64):
93 | super().__init__()
94 | self.conv0 = torch.nn.Sequential(
95 | conv(in_planes, c//2, 3, 2, 1),
96 | conv(c//2, c, 3, 2, 1),
97 | )
98 | self.lastconv = torch.nn.Sequential(
99 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1),
100 | torch.nn.PixelShuffle(2)
101 | )
102 |
103 | def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1):
104 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep
105 | x = torch.cat((img0, img1, f0, f1, timestep), 1)
106 | if flow is not None:
107 | x = torch.cat((x, mask, flow), 1)
108 | feat = self.conv0(x)
109 | # feat = self.convblock(feat)
110 | tmp = self.lastconv(feat)
111 | flow = tmp[:, :4] * scale
112 | mask = tmp[:, 4:5]
113 | conf = tmp[:, 5:6]
114 | return flow, mask, conf
115 |
116 | class FlownetCas(Module):
117 | def __init__(self):
118 | super().__init__()
119 | self.block0 = Flownet(7+16, c=192)
120 | self.block1 = Flownet(8+4+16, c=128)
121 | self.block2 = Flownet(8+4+16, c=96)
122 | self.block3 = Flownet(8+4+16, c=64)
123 | self.encode = Head()
124 |
125 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1):
126 | img0 = img0
127 | img1 = img1
128 | f0 = self.encode(img0)
129 | f1 = self.encode(img1)
130 |
131 | flow_list = [None] * 4
132 | mask_list = [None] * 4
133 | merged = [None] * 4
134 |
135 | flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0])
136 |
137 | flow_list[3] = flow.clone()
138 | mask_list[3] = torch.sigmoid(mask.clone())
139 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3])
140 |
141 |                 return flow_list, mask_list, merged  # compile-test short-circuit: the refinement cascade below is unreachable
142 |
143 | for iteration in range(iterations):
144 | flow_d, mask, conf = self.block1(
145 | warp(img0, flow[:, :2]),
146 | warp(img1, flow[:, 2:4]),
147 | warp(f0, flow[:, :2]),
148 | warp(f1, flow[:, 2:4]),
149 | timestep,
150 | mask,
151 | flow,
152 | scale=scale[1]
153 | )
154 | flow = flow + flow_d
155 |
156 | flow_list[1] = flow.clone()
157 | mask_list[1] = torch.sigmoid(mask.clone())
158 | merged[1] = warp(img0, flow[:, :2]) * mask_list[1] + warp(img1, flow[:, 2:4]) * (1 - mask_list[1])
159 |
160 | for iteration in range(iterations):
161 | flow_d, mask, conf = self.block2(
162 | warp(img0, flow[:, :2]),
163 | warp(img1, flow[:, 2:4]),
164 | warp(f0, flow[:, :2]),
165 | warp(f1, flow[:, 2:4]),
166 | timestep,
167 | mask,
168 | flow,
169 | scale=scale[2]
170 | )
171 | flow = flow + flow_d
172 |
173 | flow_list[2] = flow.clone()
174 | mask_list[2] = torch.sigmoid(mask.clone())
175 | merged[2] = warp(img0, flow[:, :2]) * mask_list[2] + warp(img1, flow[:, 2:4]) * (1 - mask_list[2])
176 |
177 | for iteration in range(iterations):
178 | flow_d, mask, conf = self.block3(
179 | warp(img0, flow[:, :2]),
180 | warp(img1, flow[:, 2:4]),
181 | warp(f0, flow[:, :2]),
182 | warp(f1, flow[:, 2:4]),
183 | timestep,
184 | mask,
185 | flow,
186 | scale=scale[3]
187 | )
188 | flow = flow + flow_d
189 |
190 | flow_list[3] = flow
191 | mask_list[3] = torch.sigmoid(mask)
192 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3])
193 |
194 | return flow_list, mask_list, merged
195 |
196 | self.model = FlownetCas
197 | self.training_model = FlownetCas
198 |
199 | @staticmethod
200 | def get_info():
201 | info = {
202 | 'name': 'Flownet4_v001d',
203 | 'file': 'flownet4_v001d.py',
204 | 'ratio_support': True
205 | }
206 | return info
207 |
208 | @staticmethod
209 | def get_name():
210 | return 'TWML_Flownet_v001'
211 |
212 | @staticmethod
213 | def input_channels(model_state_dict):
214 | channels = 3
215 | try:
216 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
217 | except Exception as e:
218 | print (f'Unable to get model dict input channels - setting to 3, {e}')
219 | return channels
220 |
221 | @staticmethod
222 | def output_channels(model_state_dict):
223 | channels = 5
224 | try:
225 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
226 | except Exception as e:
227 |             print (f'Unable to get model dict output channels - setting to 5, {e}')
228 | return channels
229 |
230 | def get_model(self):
231 | import platform
232 | if platform.system() == 'Darwin':
233 | return self.training_model
234 | return self.model
235 |
236 | def get_training_model(self):
237 | return self.training_model
238 |
239 | def load_model(self, path, flownet, rank=0):
240 | import torch
241 | def convert(param):
242 | if rank == -1:
243 | return {
244 | k.replace("module.", ""): v
245 | for k, v in param.items()
246 | if "module." in k
247 | }
248 | else:
249 | return param
250 | if rank <= 0:
251 | if torch.cuda.is_available():
252 | flownet.load_state_dict(convert(torch.load(path)), False)
253 | else:
254 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
--------------------------------------------------------------------------------
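
The `warp` helper in the models above converts a per-pixel flow (in pixels) into a normalised sampling grid for `grid_sample`; with zero flow it must reproduce the input exactly. A standalone property check under that reading, using the same normalisation:

```python
import torch
import torch.nn.functional as F

def backwarp(ten_input, ten_flow):
    """Same normalisation as the model's warp(): pixel flow -> [-1, 1] grid."""
    n, _, h, w = ten_flow.shape
    xs = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, -1, h, -1)
    ys = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, -1, -1, w)
    grid = torch.cat([xs, ys], 1)  # identity grid in normalised coordinates
    flow = torch.cat([ten_flow[:, 0:1] / ((w - 1.0) / 2.0),
                      ten_flow[:, 1:2] / ((h - 1.0) / 2.0)], 1)
    g = (grid + flow).permute(0, 2, 3, 1)
    return F.grid_sample(ten_input, g, mode='bilinear',
                         padding_mode='reflection', align_corners=True)

img = torch.rand(1, 3, 8, 8)
zero_flow = torch.zeros(1, 2, 8, 8)
assert torch.allclose(backwarp(img, zero_flow), img, atol=1e-6)  # identity warp
```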
/pytorch/models/archived/flownet4_v001ea.py:
--------------------------------------------------------------------------------
1 | # Orig v001 changed to v002 main flow and signatures
2 | # Back from SiLU to LeakyReLU to test data flow
3 | # Warps moved to flownet forward
4 |
5 | class Model:
6 | def __init__(self, status = dict(), torch = None):
7 | if torch is None:
8 | import torch
9 | Module = torch.nn.Module
10 | backwarp_tenGrid = {}
11 |
12 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
13 | return torch.nn.Sequential(
14 | torch.nn.Conv2d(
15 | in_planes,
16 | out_planes,
17 | kernel_size=kernel_size,
18 | stride=stride,
19 | padding=padding,
20 | dilation=dilation,
21 | padding_mode = 'reflect',
22 | bias=True
23 | ),
24 | torch.nn.LeakyReLU(0.2, True)
25 | # torch.nn.SELU(inplace = True)
26 | )
27 |
28 | def warp(tenInput, tenFlow):
29 | k = (str(tenFlow.device), str(tenFlow.size()))
30 | if k not in backwarp_tenGrid:
31 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
32 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
33 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
34 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
35 |
36 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
37 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='reflection', align_corners=True)
38 |
39 | class Head(Module):
40 | def __init__(self):
41 | super(Head, self).__init__()
42 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1)
43 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
44 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
45 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
46 | self.relu = torch.nn.LeakyReLU(0.2, True)
47 |
48 | def forward(self, x, feat=False):
49 | x0 = self.cnn0(x)
50 | x = self.relu(x0)
51 | x1 = self.cnn1(x)
52 | x = self.relu(x1)
53 | x2 = self.cnn2(x)
54 | x = self.relu(x2)
55 | x3 = self.cnn3(x)
56 | if feat:
57 | return [x0, x1, x2, x3]
58 | return x3
59 |
60 | class ResConv(Module):
61 | def __init__(self, c, dilation=1):
62 | super().__init__()
63 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True)
64 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
65 | self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True)
66 |
67 | torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu')
68 | self.conv.weight.data *= 1e-2
69 | if self.conv.bias is not None:
70 | torch.nn.init.constant_(self.conv.bias, 0)
71 |
72 | def forward(self, x):
73 | return self.relu(self.conv(x) * self.beta + x)
74 |
75 | class Flownet(Module):
76 | def __init__(self, in_planes, c=64):
77 | super().__init__()
78 | self.conv0 = torch.nn.Sequential(
79 | conv(in_planes, c//2, 3, 2, 1),
80 | conv(c//2, c, 3, 2, 1),
81 | )
82 | self.convblock = torch.nn.Sequential(
83 | ResConv(c),
84 | ResConv(c),
85 | ResConv(c),
86 | ResConv(c),
87 | ResConv(c),
88 | ResConv(c),
89 | ResConv(c),
90 | ResConv(c),
91 | )
92 | self.lastconv = torch.nn.Sequential(
93 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1),
94 | torch.nn.PixelShuffle(2)
95 | )
96 |
97 | def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1):
98 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep
99 |
100 | if flow is None:
101 | x = torch.cat((img0, img1, f0, f1, timestep), 1)
102 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
103 | else:
104 | warped_img0 = warp(img0, flow[:, :2])
105 |                     warped_img1 = warp(img1, flow[:, 2:4])
106 | warped_f0 = warp(f0, flow[:, :2])
107 | warped_f1 = warp(f1, flow[:, 2:4])
108 |                     x = torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1)
109 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
110 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
111 | x = torch.cat((x, flow), 1)
112 |
113 | feat = self.conv0(x)
114 | feat = self.convblock(feat)
115 | tmp = self.lastconv(feat)
116 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
117 | flow = tmp[:, :4] * scale
118 | mask = tmp[:, 4:5]
119 | conf = tmp[:, 5:6]
120 | return flow, mask, conf
121 |
122 | class FlownetCas(Module):
123 | def __init__(self):
124 | super().__init__()
125 | self.block0 = Flownet(7+16, c=192)
126 | self.block1 = Flownet(8+4+16, c=128)
127 | self.block2 = Flownet(8+4+16, c=96)
128 | self.block3 = Flownet(8+4+16, c=64)
129 | self.encode = Head()
130 |
131 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1):
132 | img0 = img0
133 | img1 = img1
134 | f0 = self.encode(img0)
135 | f1 = self.encode(img1)
136 |
137 | flow_list = [None] * 4
138 | mask_list = [None] * 4
139 | merged = [None] * 4
140 |
141 | flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0])
142 |
143 | flow_list[0] = flow.clone()
144 | mask_list[0] = torch.sigmoid(mask.clone())
145 | merged[0] = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0])
146 |
147 | for iteration in range(iterations):
148 | flow_d, mask, conf = self.block1(
149 | img0,
150 | img1,
151 | f0,
152 | f1,
153 | timestep,
154 | mask,
155 | flow,
156 | scale=scale[1]
157 | )
158 | flow = flow + flow_d
159 |
160 | flow_list[1] = flow.clone()
161 | mask_list[1] = torch.sigmoid(mask.clone())
162 | merged[1] = warp(img0, flow[:, :2]) * mask_list[1] + warp(img1, flow[:, 2:4]) * (1 - mask_list[1])
163 |
164 | for iteration in range(iterations):
165 | flow_d, mask, conf = self.block2(
166 | img0,
167 | img1,
168 | f0,
169 | f1,
170 | timestep,
171 | mask,
172 | flow,
173 | scale=scale[2]
174 | )
175 | flow = flow + flow_d
176 |
177 | flow_list[2] = flow.clone()
178 | mask_list[2] = torch.sigmoid(mask.clone())
179 | merged[2] = warp(img0, flow[:, :2]) * mask_list[2] + warp(img1, flow[:, 2:4]) * (1 - mask_list[2])
180 |
181 | for iteration in range(iterations):
182 | flow_d, mask, conf = self.block3(
183 | img0,
184 | img1,
185 | f0,
186 | f1,
187 | timestep,
188 | mask,
189 | flow,
190 | scale=scale[3]
191 | )
192 | flow = flow + flow_d
193 |
194 | flow_list[3] = flow
195 | mask_list[3] = torch.sigmoid(mask)
196 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3])
197 |
198 | return flow_list, mask_list, merged
199 |
200 | self.model = FlownetCas
201 | self.training_model = FlownetCas
202 |
203 | @staticmethod
204 | def get_info():
205 | info = {
206 | 'name': 'Flownet4_v001ea',
207 | 'file': 'flownet4_v001ea.py',
208 | 'ratio_support': True
209 | }
210 | return info
211 |
212 | @staticmethod
213 | def get_name():
214 | return 'TWML_Flownet_v001ea'
215 |
216 | @staticmethod
217 | def input_channels(model_state_dict):
218 | channels = 3
219 | try:
220 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
221 | except Exception as e:
222 | print (f'Unable to get model dict input channels - setting to 3, {e}')
223 | return channels
224 |
225 | @staticmethod
226 | def output_channels(model_state_dict):
227 | channels = 5
228 | try:
229 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
230 | except Exception as e:
231 | print (f'Unable to get model dict output channels - setting to 3, {e}')
232 | return channels
233 |
234 | def get_model(self):
235 | import platform
236 | if platform.system() == 'Darwin':
237 | return self.training_model
238 | return self.model
239 |
240 | def get_training_model(self):
241 | return self.training_model
242 |
243 | def load_model(self, path, flownet, rank=0):
244 | import torch
245 | def convert(param):
246 | if rank == -1:
247 | return {
248 | k.replace("module.", ""): v
249 | for k, v in param.items()
250 | if "module." in k
251 | }
252 | else:
253 | return param
254 | if rank <= 0:
255 | if torch.cuda.is_available():
256 | flownet.load_state_dict(convert(torch.load(path)), False)
257 | else:
258 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
--------------------------------------------------------------------------------
/pytorch/models/archived/flownet4_v001eb3x.py:
--------------------------------------------------------------------------------
1 | # Orig v001 changed to v002 main flow and signatures
2 | # Back from SiLU to LeakyReLU to test data flow
3 | # Warps moved to flownet forward
4 | # Different Tail from flownet 2lh (ConvTr 6x6, conv 1x1, ConvTr 4x4, conv 1x1)
  5 | # Scale list is multiples of 3 instead of 2: [12, 6, 3, 1]
6 |
7 | class Model:
8 | def __init__(self, status = dict(), torch = None):
9 | if torch is None:
10 | import torch
11 | Module = torch.nn.Module
12 | backwarp_tenGrid = {}
13 |
14 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
15 | return torch.nn.Sequential(
16 | torch.nn.Conv2d(
17 | in_planes,
18 | out_planes,
19 | kernel_size=kernel_size,
20 | stride=stride,
21 | padding=padding,
22 | dilation=dilation,
23 | padding_mode = 'reflect',
24 | bias=True
25 | ),
26 | torch.nn.LeakyReLU(0.2, True)
27 | # torch.nn.SELU(inplace = True)
28 | )
29 |
30 | def warp(tenInput, tenFlow):
31 | k = (str(tenFlow.device), str(tenFlow.size()))
32 | if k not in backwarp_tenGrid:
33 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
34 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
35 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
36 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
37 |
38 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
39 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='reflection', align_corners=True)
40 |
41 | class Head(Module):
42 | def __init__(self):
43 | super(Head, self).__init__()
44 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1)
45 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
46 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
47 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
48 | self.relu = torch.nn.LeakyReLU(0.2, True)
49 |
50 | def forward(self, x, feat=False):
51 | x0 = self.cnn0(x)
52 | x = self.relu(x0)
53 | x1 = self.cnn1(x)
54 | x = self.relu(x1)
55 | x2 = self.cnn2(x)
56 | x = self.relu(x2)
57 | x3 = self.cnn3(x)
58 | if feat:
59 | return [x0, x1, x2, x3]
60 | return x3
61 |
62 | class ResConv(Module):
63 | def __init__(self, c, dilation=1):
64 | super().__init__()
65 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True)
66 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
67 | self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True)
68 |
69 | torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu')
70 | self.conv.weight.data *= 1e-2
71 | if self.conv.bias is not None:
72 | torch.nn.init.constant_(self.conv.bias, 0)
73 |
74 | def forward(self, x):
75 | return self.relu(self.conv(x) * self.beta + x)
76 |
77 | class Flownet(Module):
78 | def __init__(self, in_planes, c=64):
79 | super().__init__()
80 | self.conv0 = torch.nn.Sequential(
81 | conv(in_planes, c//2, 3, 2, 1),
82 | conv(c//2, c, 3, 2, 1),
83 | )
84 | self.convblock = torch.nn.Sequential(
85 | ResConv(c),
86 | ResConv(c),
87 | ResConv(c),
88 | ResConv(c),
89 | ResConv(c),
90 | ResConv(c),
91 | ResConv(c),
92 | ResConv(c),
93 | )
94 | self.lastconv = torch.nn.Sequential(
95 | torch.nn.ConvTranspose2d(c, c, 6, 2, 2),
96 | torch.nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0, bias=True),
97 | torch.nn.ConvTranspose2d(c, c, 4, 2, 1),
98 | torch.nn.Conv2d(c, 6, kernel_size=1, stride=1, padding=0, bias=True),
99 | )
100 |
101 | def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1):
102 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep
103 |
104 | if flow is None:
105 | x = torch.cat((img0, img1, f0, f1, timestep), 1)
106 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
107 | else:
108 | warped_img0 = warp(img0, flow[:, :2])
109 |                     warped_img1 = warp(img1, flow[:, 2:4])
110 | warped_f0 = warp(f0, flow[:, :2])
111 | warped_f1 = warp(f1, flow[:, 2:4])
112 | x = torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1)
113 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
114 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
115 | x = torch.cat((x, flow), 1)
116 |
117 | feat = self.conv0(x)
118 | feat = self.convblock(feat)
119 | tmp = self.lastconv(feat)
120 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
121 | flow = tmp[:, :4] * scale
122 | mask = tmp[:, 4:5]
123 | conf = tmp[:, 5:6]
124 | return flow, mask, conf
125 |
126 | class FlownetCas(Module):
127 | def __init__(self):
128 | super().__init__()
129 | self.block0 = Flownet(7+16, c=192)
130 | self.block1 = Flownet(8+4+16, c=128)
131 | self.block2 = Flownet(8+4+16, c=96)
132 | self.block3 = Flownet(8+4+16, c=64)
133 | self.encode = Head()
134 |
135 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1):
136 | scale = [12, 6, 3, 1] if scale == [8, 4, 2, 1] else scale
137 |
138 | img0 = img0
139 | img1 = img1
140 | f0 = self.encode(img0)
141 | f1 = self.encode(img1)
142 |
143 | flow_list = [None] * 4
144 | mask_list = [None] * 4
145 | merged = [None] * 4
146 |
147 | flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0])
148 |
149 | flow_list[0] = flow.clone()
150 | mask_list[0] = torch.sigmoid(mask.clone())
151 | merged[0] = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0])
152 |
153 | for iteration in range(iterations):
154 | flow_d, mask, conf = self.block1(
155 | img0,
156 | img1,
157 | f0,
158 | f1,
159 | timestep,
160 | mask,
161 | flow,
162 | scale=scale[1]
163 | )
164 | flow = flow + flow_d
165 |
166 | flow_list[1] = flow.clone()
167 | mask_list[1] = torch.sigmoid(mask.clone())
168 | merged[1] = warp(img0, flow[:, :2]) * mask_list[1] + warp(img1, flow[:, 2:4]) * (1 - mask_list[1])
169 |
170 | for iteration in range(iterations):
171 | flow_d, mask, conf = self.block2(
172 | img0,
173 | img1,
174 | f0,
175 | f1,
176 | timestep,
177 | mask,
178 | flow,
179 | scale=scale[2]
180 | )
181 | flow = flow + flow_d
182 |
183 | flow_list[2] = flow.clone()
184 | mask_list[2] = torch.sigmoid(mask.clone())
185 | merged[2] = warp(img0, flow[:, :2]) * mask_list[2] + warp(img1, flow[:, 2:4]) * (1 - mask_list[2])
186 |
187 | for iteration in range(iterations):
188 | flow_d, mask, conf = self.block3(
189 | img0,
190 | img1,
191 | f0,
192 | f1,
193 | timestep,
194 | mask,
195 | flow,
196 | scale=scale[3]
197 | )
198 | flow = flow + flow_d
199 |
200 | flow_list[3] = flow
201 | mask_list[3] = torch.sigmoid(mask)
202 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3])
203 |
204 | return flow_list, mask_list, merged
205 |
206 | self.model = FlownetCas
207 | self.training_model = FlownetCas
208 |
209 | @staticmethod
210 | def get_info():
211 | info = {
212 | 'name': 'Flownet4_v001eb3x',
213 | 'file': 'flownet4_v001eb3x.py',
214 | 'ratio_support': True,
215 | 'padding': 48
216 | }
217 | return info
218 |
219 | @staticmethod
220 | def get_name():
221 | return 'TWML_Flownet_v001eb3x'
222 |
223 | @staticmethod
224 | def input_channels(model_state_dict):
225 | channels = 3
226 | try:
227 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
228 | except Exception as e:
229 | print (f'Unable to get model dict input channels - setting to 3, {e}')
230 | return channels
231 |
232 | @staticmethod
233 | def output_channels(model_state_dict):
234 | channels = 5
235 | try:
236 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
237 | except Exception as e:
238 |             print (f'Unable to get model dict output channels - setting to 5, {e}')
239 | return channels
240 |
241 | def get_model(self):
242 | import platform
243 | if platform.system() == 'Darwin':
244 | return self.training_model
245 | return self.model
246 |
247 | def get_training_model(self):
248 | return self.training_model
249 |
250 | def load_model(self, path, flownet, rank=0):
251 | import torch
252 | def convert(param):
253 | if rank == -1:
254 | return {
255 | k.replace("module.", ""): v
256 | for k, v in param.items()
257 | if "module." in k
258 | }
259 | else:
260 | return param
261 | if rank <= 0:
262 | if torch.cuda.is_available():
263 | flownet.load_state_dict(convert(torch.load(path)), False)
264 | else:
265 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
--------------------------------------------------------------------------------
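
Because this variant swaps the power-of-two scale list for [12, 6, 3, 1], its info block declares 'padding': 48: the coarsest block downsamples by a further factor of 4 internally, and 12 * 4 = 48. Host-side padding would then follow the same round-up pattern as the `pad` logic in timewarp.py; a sketch, with the `padded_size` helper hypothetical and the 48 taken from the info dict above:

```python
def padded_size(h, w, pad=48):
    """Round spatial dims up to the next multiple of the model's padding."""
    ph = ((h - 1) // pad + 1) * pad
    pw = ((w - 1) // pad + 1) * pad
    return ph, pw

print(padded_size(1080, 1920))  # (1104, 1920)
print(padded_size(720, 1280))   # (720, 1296)
```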
/pytorch/models/flownet4_v001_baseline.py:
--------------------------------------------------------------------------------
1 | class Model:
2 |
3 | info = {
4 | 'name': 'Flownet4_v001_baseline',
5 | 'file': 'flownet4_v001_baseline.py',
6 | 'ratio_support': True
7 | }
8 |
9 | def __init__(self, status = dict(), torch = None):
10 | if torch is None:
11 | import torch
12 | Module = torch.nn.Module
13 | backwarp_tenGrid = {}
14 |
15 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
16 | return torch.nn.Sequential(
17 | torch.nn.Conv2d(
18 | in_planes,
19 | out_planes,
20 | kernel_size=kernel_size,
21 | stride=stride,
22 | padding=padding,
23 | dilation=dilation,
24 | padding_mode = 'reflect',
25 | bias=True
26 | ),
27 | torch.nn.LeakyReLU(0.2, True)
28 | # torch.nn.SELU(inplace = True)
29 | )
30 |
31 | def warp(tenInput, tenFlow):
32 | k = (str(tenFlow.device), str(tenFlow.size()))
33 | if k not in backwarp_tenGrid:
34 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
35 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
36 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
37 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
38 |
39 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
40 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
41 |
42 | class Head(Module):
43 | def __init__(self):
44 | super(Head, self).__init__()
45 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1)
46 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
47 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
48 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
49 | self.relu = torch.nn.LeakyReLU(0.2, True)
50 |
51 | def forward(self, x, feat=False):
52 | x0 = self.cnn0(x)
53 | x = self.relu(x0)
54 | x1 = self.cnn1(x)
55 | x = self.relu(x1)
56 | x2 = self.cnn2(x)
57 | x = self.relu(x2)
58 | x3 = self.cnn3(x)
59 | if feat:
60 | return [x0, x1, x2, x3]
61 | return x3
62 |
63 | class ResConv(Module):
64 | def __init__(self, c, dilation=1):
65 | super().__init__()
66 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True)
67 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
68 | self.relu = torch.nn.LeakyReLU(0.2, True)
69 |
70 | def forward(self, x):
71 | return self.relu(self.conv(x) * self.beta + x)
72 |
73 | class Flownet(Module):
74 | def __init__(self, in_planes, c=64):
75 | super().__init__()
76 | self.conv0 = torch.nn.Sequential(
77 | conv(in_planes, c//2, 3, 2, 1),
78 | conv(c//2, c, 3, 2, 1),
79 | )
80 | self.convblock = torch.nn.Sequential(
81 | ResConv(c),
82 | ResConv(c),
83 | ResConv(c),
84 | ResConv(c),
85 | ResConv(c),
86 | ResConv(c),
87 | ResConv(c),
88 | ResConv(c),
89 | )
90 | self.lastconv = torch.nn.Sequential(
91 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1),
92 | torch.nn.PixelShuffle(2)
93 | )
94 |
95 | def forward(self, x, flow, scale=1):
96 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
97 | if flow is not None:
98 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
99 | x = torch.cat((x, flow), 1)
100 | feat = self.conv0(x)
101 | feat = self.convblock(feat)
102 | tmp = self.lastconv(feat)
103 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
104 | flow = tmp[:, :4] * scale
105 | mask = tmp[:, 4:5]
106 | conf = tmp[:, 5:6]
107 | return flow, mask, conf
108 |
109 | class FlownetCas(Module):
110 | def __init__(self):
111 | super().__init__()
112 | self.block0 = Flownet(7+16, c=192)
113 | self.block1 = Flownet(8+4+16, c=128)
114 | self.block2 = Flownet(8+4+16, c=96)
115 | self.block3 = Flownet(8+4+16, c=64)
116 | self.encode = Head()
117 |
118 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1):
119 | gt = None
120 | # return self.encode(img0)
121 | img0 = img0
122 | img1 = img1
123 | f0 = self.encode(img0)
124 | f1 = self.encode(img1)
125 |
126 | if not torch.is_tensor(timestep):
127 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep
128 | else:
129 | timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
130 | flow_list = []
131 | merged = []
132 | mask_list = []
133 | conf_list = []
134 | teacher_list = []
135 | flow_list_teacher = []
136 | warped_img0 = img0
137 | warped_img1 = img1
138 | flow = None
139 | loss_cons = 0
140 | stu = [self.block0, self.block1, self.block2, self.block3]
141 | flow = None
142 |
143 | # single step
144 | flow, mask, conf = stu[0](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=1)
145 |
146 | flow_list = [flow] * 5
147 | mask_list = [torch.sigmoid(mask)] * 5
148 | conf_list = [torch.sigmoid(conf)] * 5
149 | merged_student = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0])
150 | merged = [merged_student] * 5
151 |                 return flow_list, mask_list, conf_list, merged  # single-step short-circuit: the cascade below is unreachable
152 |
153 | for i in range(4):
154 | if flow is not None:
155 | flow_d, mask, conf = stu[i](torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1), flow, scale=scale[i])
156 | flow = flow + flow_d
157 | else:
158 | flow, mask, conf = stu[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=scale[i])
159 |
160 | mask_list.append(mask)
161 | flow_list.append(flow)
162 | conf_list.append(conf)
163 | warped_img0 = warp(img0, flow[:, :2])
164 | warped_img1 = warp(img1, flow[:, 2:4])
165 | warped_f0 = warp(f0, flow[:, :2])
166 | warped_f1 = warp(f1, flow[:, 2:4])
167 | merged_student = (warped_img0, warped_img1)
168 | merged.append(merged_student)
169 | conf = torch.sigmoid(torch.cat(conf_list, 1))
170 | conf = conf / (conf.sum(1, True) + 1e-3)
171 | if gt is not None:
172 | flow_teacher = 0
173 | mask_teacher = 0
174 | for i in range(4):
175 | flow_teacher += conf[:, i:i+1] * flow_list[i]
176 | mask_teacher += conf[:, i:i+1] * mask_list[i]
177 | warped_img0_teacher = warp(img0, flow_teacher[:, :2])
178 | warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
179 | mask_teacher = torch.sigmoid(mask_teacher)
180 | merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
181 | teacher_list.append(merged_teacher)
182 | flow_list_teacher.append(flow_teacher)
183 |
184 | for i in range(4):
185 | mask_list[i] = torch.sigmoid(mask_list[i])
186 | merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
187 | if gt is not None:
188 | loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 1e-2).float().detach()
189 | loss_cons += (((flow_teacher.detach() - flow_list[i]) ** 2).sum(1, True) ** 0.5 * loss_mask).mean() * 0.001
190 |
191 | return flow_list, mask_list, merged
192 |
193 | self.model = FlownetCas
194 | self.training_model = FlownetCas
195 |
196 | @staticmethod
197 | def get_info():
198 | return Model.info
199 |
200 | @staticmethod
201 | def get_name():
202 | return Model.info.get('name')
203 |
204 | @staticmethod
205 | def input_channels(model_state_dict):
206 | channels = 3
207 | try:
208 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
209 | except Exception as e:
210 | print (f'Unable to get model dict input channels - setting to 3, {e}')
211 | return channels
212 |
213 | @staticmethod
214 | def output_channels(model_state_dict):
215 | channels = 5
216 | try:
217 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
218 | except Exception as e:
219 |             print (f'Unable to get model dict output channels - setting to 5, {e}')
220 | return channels
221 |
222 | def get_model(self):
223 | import platform
224 | if platform.system() == 'Darwin':
225 | return self.training_model
226 | return self.model
227 |
228 | def get_training_model(self):
229 | return self.training_model
230 |
231 | def load_model(self, path, flownet, rank=0):
232 | import torch
233 | def convert(param):
234 | if rank == -1:
235 | return {
236 | k.replace("module.", ""): v
237 | for k, v in param.items()
238 | if "module." in k
239 | }
240 | else:
241 | return param
242 | if rank <= 0:
243 | if torch.cuda.is_available():
244 | flownet.load_state_dict(convert(torch.load(path)), False)
245 | else:
246 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
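
A note on the timestep handling in forward above: it accepts either a python float or a per-sample tensor, and both paths produce the same constant [n, 1, h, w] map that gets concatenated with the image channels. A minimal sketch of the equivalence (illustrative shapes only, not repo code):

    import torch
    img0 = torch.rand(2, 3, 4, 4)
    t_map = (img0[:, :1].clone() * 0 + 1) * 0.25            # float path -> [2, 1, 4, 4] filled with 0.25
    t = torch.full((2, 1, 1, 1), 0.25)                      # tensor path: a [n, 1, 1, 1] timestep ...
    t_map2 = t.repeat(1, 1, img0.shape[2], img0.shape[3])   # ... tiled to full resolution
    assert torch.equal(t_map, t_map2)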
--------------------------------------------------------------------------------
/pytorch/models/flownet4_v001_baseline_sst.py:
--------------------------------------------------------------------------------
1 | class Model:
2 |
3 | info = {
4 |         'name': 'Flownet4_v001_baseline_sst',
5 |         'file': 'flownet4_v001_baseline_sst.py',
6 | 'ratio_support': True
7 | }
8 |
9 | def __init__(self, status = dict(), torch = None):
10 | if torch is None:
11 | import torch
12 | Module = torch.nn.Module
13 | backwarp_tenGrid = {}
14 |
15 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
16 | return torch.nn.Sequential(
17 | torch.nn.Conv2d(
18 | in_planes,
19 | out_planes,
20 | kernel_size=kernel_size,
21 | stride=stride,
22 | padding=padding,
23 | dilation=dilation,
24 | padding_mode = 'reflect',
25 | bias=True
26 | ),
27 | torch.nn.LeakyReLU(0.2, True)
28 | # torch.nn.SELU(inplace = True)
29 | )
30 |
31 | def warp(tenInput, tenFlow):
32 | k = (str(tenFlow.device), str(tenFlow.size()))
33 | if k not in backwarp_tenGrid:
34 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
35 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
36 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
37 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
38 |
39 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
40 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
41 |
42 | class Head(Module):
43 | def __init__(self):
44 | super(Head, self).__init__()
45 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1)
46 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
47 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
48 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
49 | self.relu = torch.nn.LeakyReLU(0.2, True)
50 |
51 | def forward(self, x, feat=False):
52 | x0 = self.cnn0(x)
53 | x = self.relu(x0)
54 | x1 = self.cnn1(x)
55 | x = self.relu(x1)
56 | x2 = self.cnn2(x)
57 | x = self.relu(x2)
58 | x3 = self.cnn3(x)
59 | if feat:
60 | return [x0, x1, x2, x3]
61 | return x3
62 |
63 | class ResConv(Module):
64 | def __init__(self, c, dilation=1):
65 | super().__init__()
66 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True)
67 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
68 | self.relu = torch.nn.LeakyReLU(0.2, True)
69 |
70 | def forward(self, x):
71 | return self.relu(self.conv(x) * self.beta + x)
72 |
73 | class Flownet(Module):
74 | def __init__(self, in_planes, c=64):
75 | super().__init__()
76 | self.conv0 = torch.nn.Sequential(
77 | conv(in_planes, c//2, 3, 2, 1),
78 | conv(c//2, c, 3, 2, 1),
79 | )
80 | self.convblock = torch.nn.Sequential(
81 | ResConv(c),
82 | ResConv(c),
83 | ResConv(c),
84 | ResConv(c),
85 | ResConv(c),
86 | ResConv(c),
87 | ResConv(c),
88 | ResConv(c),
89 | )
90 | self.lastconv = torch.nn.Sequential(
91 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1),
92 | torch.nn.PixelShuffle(2)
93 | )
94 |
95 | def forward(self, x, flow, scale=1):
96 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
97 | if flow is not None:
98 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
99 | x = torch.cat((x, flow), 1)
100 | feat = self.conv0(x)
101 | feat = self.convblock(feat)
102 | tmp = self.lastconv(feat)
103 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
104 | flow = tmp[:, :4] * scale
105 | mask = tmp[:, 4:5]
106 | conf = tmp[:, 5:6]
107 | return flow, mask, conf
108 |
109 | class FlownetCas(Module):
110 | def __init__(self):
111 | super().__init__()
112 | self.block0 = Flownet(7+16, c=192)
113 | self.block1 = Flownet(8+4+16, c=128)
114 | self.block2 = Flownet(8+4+16, c=96)
115 | self.block3 = Flownet(8+4+16, c=64)
116 | self.encode = Head()
117 |
118 | def forward(self, img0, img1, timestep=0.5, scale=[8, 4, 2, 1], iterations=1):
119 | gt = None
120 | # return self.encode(img0)
121 | img0 = img0
122 | img1 = img1
123 | f0 = self.encode(img0)
124 | f1 = self.encode(img1)
125 |
126 | if not torch.is_tensor(timestep):
127 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep
128 | else:
129 | timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
130 | flow_list = []
131 | merged = []
132 | mask_list = []
133 | conf_list = []
134 | teacher_list = []
135 | flow_list_teacher = []
136 | warped_img0 = img0
137 | warped_img1 = img1
138 | flow = None
139 | loss_cons = 0
140 | stu = [self.block0, self.block1, self.block2, self.block3]
141 | flow = None
142 |
143 |         # single step: the early return below short-circuits the cascade, so the refinement loop that follows is unreachable
144 | flow, mask, conf = stu[0](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=1)
145 |
146 | flow_list = [flow] * 5
147 | mask_list = [torch.sigmoid(mask)] * 5
148 | conf_list = [torch.sigmoid(conf)] * 5
149 | merged_student = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0])
150 | merged = [merged_student] * 5
151 | return flow_list, mask_list, conf_list, merged
152 |
153 | for i in range(4):
154 | if flow is not None:
155 | flow_d, mask, conf = stu[i](torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1), flow, scale=scale[i])
156 | flow = flow + flow_d
157 | else:
158 | flow, mask, conf = stu[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=scale[i])
159 |
160 | mask_list.append(mask)
161 | flow_list.append(flow)
162 | conf_list.append(conf)
163 | warped_img0 = warp(img0, flow[:, :2])
164 | warped_img1 = warp(img1, flow[:, 2:4])
165 | warped_f0 = warp(f0, flow[:, :2])
166 | warped_f1 = warp(f1, flow[:, 2:4])
167 | merged_student = (warped_img0, warped_img1)
168 | merged.append(merged_student)
169 | conf = torch.sigmoid(torch.cat(conf_list, 1))
170 | conf = conf / (conf.sum(1, True) + 1e-3)
171 | if gt is not None:
172 | flow_teacher = 0
173 | mask_teacher = 0
174 | for i in range(4):
175 | flow_teacher += conf[:, i:i+1] * flow_list[i]
176 | mask_teacher += conf[:, i:i+1] * mask_list[i]
177 | warped_img0_teacher = warp(img0, flow_teacher[:, :2])
178 | warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
179 | mask_teacher = torch.sigmoid(mask_teacher)
180 | merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
181 | teacher_list.append(merged_teacher)
182 | flow_list_teacher.append(flow_teacher)
183 |
184 | for i in range(4):
185 | mask_list[i] = torch.sigmoid(mask_list[i])
186 | merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
187 | if gt is not None:
188 | loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 1e-2).float().detach()
189 | loss_cons += (((flow_teacher.detach() - flow_list[i]) ** 2).sum(1, True) ** 0.5 * loss_mask).mean() * 0.001
190 |
191 | return flow_list, mask_list, merged
192 |
193 | self.model = FlownetCas
194 | self.training_model = FlownetCas
195 |
196 | @staticmethod
197 | def get_info():
198 | return Model.info
199 |
200 | @staticmethod
201 | def get_name():
202 | return Model.info.get('name')
203 |
204 | @staticmethod
205 | def input_channels(model_state_dict):
206 | channels = 3
207 | try:
208 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
209 | except Exception as e:
210 | print (f'Unable to get model dict input channels - setting to 3, {e}')
211 | return channels
212 |
213 | @staticmethod
214 | def output_channels(model_state_dict):
215 | channels = 5
216 | try:
217 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
218 | except Exception as e:
219 |             print (f'Unable to get model dict output channels - setting to 5, {e}')
220 | return channels
221 |
222 | def get_model(self):
223 | import platform
224 | if platform.system() == 'Darwin':
225 | return self.training_model
226 | return self.model
227 |
228 | def get_training_model(self):
229 | return self.training_model
230 |
231 | def load_model(self, path, flownet, rank=0):
232 | import torch
233 | def convert(param):
234 | if rank == -1:
235 | return {
236 | k.replace("module.", ""): v
237 | for k, v in param.items()
238 | if "module." in k
239 | }
240 | else:
241 | return param
242 | if rank <= 0:
243 | if torch.cuda.is_available():
244 | flownet.load_state_dict(convert(torch.load(path)), False)
245 | else:
246 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
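
For orientation, a minimal usage sketch of the Model wrapper above; the checkpoint path and frame sizes are illustrative, and load_model expects a file containing a raw state dict:

    import torch
    from models.flownet4_v001_baseline_sst import Model

    m = Model()
    FlownetCas = m.get_model()            # returns the class, not an instance
    flownet = FlownetCas().eval()
    # m.load_model('/path/to/state_dict.pth', flownet)   # hypothetical path

    img0 = torch.rand(1, 3, 256, 256)     # previous frame, RGB
    img1 = torch.rand(1, 3, 256, 256)     # next frame
    with torch.no_grad():
        flow_list, mask_list, conf_list, merged = flownet(img0, img1, timestep=0.5)
    frame_at_half = merged[0]             # blended intermediate frame at t = 0.5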
--------------------------------------------------------------------------------
/pytorch/models/flownet4_v001a.py:
--------------------------------------------------------------------------------
1 | # Orig v001 changed to v002 main flow and signatures
2 | # Back from SiLU to LeakyReLU to test data flow
3 | # Warps moved to flownet forward
4 | # Different Tail from flownet 2lh (ConvTr 6x6, conv 1x1, ConvTr 4x4, conv 1x1)
5 |
6 | class Model:
7 |
8 | info = {
9 |         'name': 'Flownet4_v001a',
10 |         'file': 'flownet4_v001a.py',
11 | 'ratio_support': True
12 | }
13 |
14 | def __init__(self, status = dict(), torch = None):
15 | if torch is None:
16 | import torch
17 | Module = torch.nn.Module
18 |         backwarp_tenGrid = {}  # was missing here; warp() below relies on this cache
19 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
20 | return torch.nn.Sequential(
21 | torch.nn.Conv2d(
22 | in_planes,
23 | out_planes,
24 | kernel_size=kernel_size,
25 | stride=stride,
26 | padding=padding,
27 | dilation=dilation,
28 | padding_mode = 'zeros',
29 | bias=True
30 | ),
31 | torch.nn.LeakyReLU(0.2, True)
32 | # torch.nn.SELU(inplace = True)
33 | )
34 |
35 | def warp(tenInput, tenFlow):
36 | k = (str(tenFlow.device), str(tenFlow.size()))
37 | if k not in backwarp_tenGrid:
38 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
39 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
40 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
41 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
42 |
43 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
44 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
45 |
46 | class Head(Module):
47 | def __init__(self):
48 | super(Head, self).__init__()
49 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1)
50 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
51 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
52 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
53 | self.relu = torch.nn.LeakyReLU(0.2, True)
54 |
55 | def forward(self, x, feat=False):
56 | x0 = self.cnn0(x)
57 | x = self.relu(x0)
58 | x1 = self.cnn1(x)
59 | x = self.relu(x1)
60 | x2 = self.cnn2(x)
61 | x = self.relu(x2)
62 | x3 = self.cnn3(x)
63 | if feat:
64 | return [x0, x1, x2, x3]
65 | return x3
66 |
67 | class ResConv(Module):
68 | def __init__(self, c, dilation=1):
69 | super().__init__()
70 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'zeros', bias=True)
71 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
72 | self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True)
73 | def forward(self, x):
74 | return self.relu(self.conv(x) * self.beta + x)
75 |
76 | class Flownet(Module):
77 | def __init__(self, in_planes, c=64):
78 | super().__init__()
79 | self.conv0 = torch.nn.Sequential(
80 | conv(in_planes, c//2, 3, 2, 1),
81 | conv(c//2, c, 3, 2, 1),
82 | )
83 | self.convblock = torch.nn.Sequential(
84 | ResConv(c),
85 | ResConv(c),
86 | ResConv(c),
87 | ResConv(c),
88 | ResConv(c),
89 | ResConv(c),
90 | ResConv(c),
91 | ResConv(c),
92 | )
93 | self.lastconv = torch.nn.Sequential(
94 | torch.nn.ConvTranspose2d(c, 4*6, 4, 2, 1),
95 | torch.nn.PixelShuffle(2)
96 | )
97 |
98 | def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1):
99 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep
100 |
101 | if flow is None:
102 | x = torch.cat((img0, img1, f0, f1, timestep), 1)
103 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
104 | else:
105 | warped_img0 = warp(img0, flow[:, :2])
106 | warped_img1 = warp(img1, flow[:, 2:4])
107 | warped_f0 = warp(f0, flow[:, :2])
108 | warped_f1 = warp(f1, flow[:, 2:4])
109 | x = torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1)
110 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
111 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
112 | x = torch.cat((x, flow), 1)
113 |
114 | feat = self.conv0(x)
115 | feat = self.convblock(feat)
116 | tmp = self.lastconv(feat)
117 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
118 | flow = tmp[:, :4] * scale
119 | mask = tmp[:, 4:5]
120 | conf = tmp[:, 5:6]
121 | return flow, mask, conf
122 |
123 | class FlownetCas(Module):
124 | def __init__(self):
125 | super().__init__()
126 | self.block0 = Flownet(7+16, c=192)
127 | self.block1 = Flownet(8+4+16, c=128)
128 | self.block2 = Flownet(8+4+16, c=96)
129 | self.block3 = Flownet(8+4+16, c=64)
130 | self.encode = Head()
131 |
132 | def forward(self, img0, img1, timestep=0.5, scale=[16, 8, 4, 1], iterations=1):
133 | img0 = img0
134 | img1 = img1
135 | f0 = self.encode(img0)
136 | f1 = self.encode(img1)
137 |
138 | flow_list = [None] * 4
139 | mask_list = [None] * 4
140 | conf_list = [None] * 4
141 | merged = [None] * 4
142 |
143 | flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0])
144 |
145 | flow_list[0] = flow.clone()
146 | conf_list[0] = torch.sigmoid(conf.clone())
147 | mask_list[0] = torch.sigmoid(mask.clone())
148 | merged[0] = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0])
149 |
150 | for iteration in range(iterations):
151 | flow_d, mask, conf = self.block1(
152 | img0,
153 | img1,
154 | f0,
155 | f1,
156 | timestep,
157 | mask,
158 | flow,
159 | scale=scale[1]
160 | )
161 | flow = flow + flow_d
162 |
163 | flow_list[1] = flow.clone()
164 | conf_list[1] = torch.sigmoid(conf.clone())
165 | mask_list[1] = torch.sigmoid(mask.clone())
166 | merged[1] = warp(img0, flow[:, :2]) * mask_list[1] + warp(img1, flow[:, 2:4]) * (1 - mask_list[1])
167 |
168 | for iteration in range(iterations):
169 | flow_d, mask, conf = self.block2(
170 | img0,
171 | img1,
172 | f0,
173 | f1,
174 | timestep,
175 | mask,
176 | flow,
177 | scale=scale[2]
178 | )
179 | flow = flow + flow_d
180 |
181 | flow_list[2] = flow.clone()
182 | conf_list[2] = torch.sigmoid(conf.clone())
183 | mask_list[2] = torch.sigmoid(mask.clone())
184 | merged[2] = warp(img0, flow[:, :2]) * mask_list[2] + warp(img1, flow[:, 2:4]) * (1 - mask_list[2])
185 |
186 | for iteration in range(iterations):
187 | flow_d, mask, conf = self.block3(
188 | img0,
189 | img1,
190 | f0,
191 | f1,
192 | timestep,
193 | mask,
194 | flow,
195 | scale=scale[3]
196 | )
197 | flow = flow + flow_d
198 |
199 | flow_list[3] = flow
200 | conf_list[3] = torch.sigmoid(conf)
201 | mask_list[3] = torch.sigmoid(mask)
202 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3])
203 |
204 | return flow_list, mask_list, conf_list, merged
205 |
206 | self.model = FlownetCas
207 | self.training_model = FlownetCas
208 |
209 | @staticmethod
210 | def get_info():
211 | return Model.info
212 |
213 | @staticmethod
214 | def get_name():
215 | return Model.info.get('name')
216 |
217 | @staticmethod
218 | def input_channels(model_state_dict):
219 | channels = 3
220 | try:
221 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
222 | except Exception as e:
223 | print (f'Unable to get model dict input channels - setting to 3, {e}')
224 | return channels
225 |
226 | @staticmethod
227 | def output_channels(model_state_dict):
228 | channels = 5
229 | try:
230 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
231 | except Exception as e:
232 |             print (f'Unable to get model dict output channels - setting to 5, {e}')
233 | return channels
234 |
235 | def get_model(self):
236 | import platform
237 | if platform.system() == 'Darwin':
238 | return self.training_model
239 | return self.model
240 |
241 | def get_training_model(self):
242 | return self.training_model
243 |
244 | def load_model(self, path, flownet, rank=0):
245 | import torch
246 | def convert(param):
247 | if rank == -1:
248 | return {
249 | k.replace("module.", ""): v
250 | for k, v in param.items()
251 | if "module." in k
252 | }
253 | else:
254 | return param
255 | if rank <= 0:
256 | if torch.cuda.is_available():
257 | flownet.load_state_dict(convert(torch.load(path)), False)
258 | else:
259 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
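
The cascade above works coarse-to-fine: on the way into each Flownet the inputs are downsampled by 1/scale and the incoming flow values are divided by scale, and on the way out the prediction is upsampled and multiplied by scale, so displacements are always expressed in pixels of the current working resolution. A small sanity sketch of that invariant:

    import torch
    flow = torch.full((1, 4, 64, 64), 8.0)   # a constant 8-pixel displacement field
    s = 4
    flow_small = torch.nn.functional.interpolate(
        flow, scale_factor=1. / s, mode='bilinear', align_corners=False) * (1. / s)
    assert torch.allclose(flow_small, torch.full((1, 4, 16, 16), 2.0))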
--------------------------------------------------------------------------------
/pytorch/models/flownet4_v001eb.py:
--------------------------------------------------------------------------------
1 | # Orig v001 changed to v002 main flow and signatures
2 | # Back from SiLU to LeakyReLU to test data flow
3 | # Warps moved to flownet forward
4 | # Different Tail from flownet 2lh (ConvTr 6x6, conv 1x1, ConvTr 4x4, conv 1x1)
5 |
6 | class Model:
7 |
8 | info = {
9 | 'name': 'Flownet4_v001eb',
10 | 'file': 'flownet4_v001eb.py',
11 | 'ratio_support': True
12 | }
13 |
14 | def __init__(self, status = dict(), torch = None):
15 | if torch is None:
16 | import torch
17 | Module = torch.nn.Module
18 | backwarp_tenGrid = {}
19 |
20 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
21 | return torch.nn.Sequential(
22 | torch.nn.Conv2d(
23 | in_planes,
24 | out_planes,
25 | kernel_size=kernel_size,
26 | stride=stride,
27 | padding=padding,
28 | dilation=dilation,
29 | padding_mode = 'reflect',
30 | bias=True
31 | ),
32 | torch.nn.LeakyReLU(0.2, True)
33 | # torch.nn.SELU(inplace = True)
34 | )
35 |
36 | def warp(tenInput, tenFlow):
37 | k = (str(tenFlow.device), str(tenFlow.size()))
38 | if k not in backwarp_tenGrid:
39 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
40 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
41 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
42 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
43 |
44 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
45 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='reflection', align_corners=True)
46 |
47 | class Head(Module):
48 | def __init__(self):
49 | super(Head, self).__init__()
50 | self.cnn0 = torch.nn.Conv2d(3, 32, 3, 2, 1)
51 | self.cnn1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
52 | self.cnn2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
53 | self.cnn3 = torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
54 | self.relu = torch.nn.LeakyReLU(0.2, True)
55 |
56 | def forward(self, x, feat=False):
57 | x0 = self.cnn0(x)
58 | x = self.relu(x0)
59 | x1 = self.cnn1(x)
60 | x = self.relu(x1)
61 | x2 = self.cnn2(x)
62 | x = self.relu(x2)
63 | x3 = self.cnn3(x)
64 | if feat:
65 | return [x0, x1, x2, x3]
66 | return x3
67 |
68 | class ResConv(Module):
69 | def __init__(self, c, dilation=1):
70 | super().__init__()
71 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'reflect', bias=True)
72 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
73 | self.relu = torch.nn.LeakyReLU(0.2, True) # torch.nn.SELU(inplace = True)
74 |
75 | torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='relu')
76 | self.conv.weight.data *= 1e-2
77 | if self.conv.bias is not None:
78 | torch.nn.init.constant_(self.conv.bias, 0)
79 |
80 | def forward(self, x):
81 | return self.relu(self.conv(x) * self.beta + x)
82 |
83 | class Flownet(Module):
84 | def __init__(self, in_planes, c=64):
85 | super().__init__()
86 | self.conv0 = torch.nn.Sequential(
87 | conv(in_planes, c//2, 3, 2, 1),
88 | conv(c//2, c, 3, 2, 1),
89 | )
90 | self.convblock = torch.nn.Sequential(
91 | ResConv(c),
92 | ResConv(c),
93 | ResConv(c),
94 | ResConv(c),
95 | ResConv(c),
96 | ResConv(c),
97 | ResConv(c),
98 | ResConv(c),
99 | )
100 | self.lastconv = torch.nn.Sequential(
101 | torch.nn.ConvTranspose2d(c, c, 6, 2, 2),
102 | torch.nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0, bias=True),
103 | torch.nn.ConvTranspose2d(c, c, 4, 2, 1),
104 | torch.nn.Conv2d(c, 6, kernel_size=1, stride=1, padding=0, bias=True),
105 | )
106 |
107 | def forward(self, img0, img1, f0, f1, timestep, mask, flow, scale=1):
108 | timestep = (img0[:, :1].clone() * 0 + 1) * timestep
109 |
110 | if flow is None:
111 | x = torch.cat((img0, img1, f0, f1, timestep), 1)
112 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
113 | else:
114 | warped_img0 = warp(img0, flow[:, :2])
115 | warped_img1 = warp(img1, flow[:, 2:4])
116 | warped_f0 = warp(f0, flow[:, :2])
117 | warped_f1 = warp(f1, flow[:, 2:4])
118 | x = torch.cat((warped_img0, warped_img1, warped_f0, warped_f1, timestep, mask), 1)
119 | x = torch.nn.functional.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
120 | flow = torch.nn.functional.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
121 | x = torch.cat((x, flow), 1)
122 |
123 | feat = self.conv0(x)
124 | feat = self.convblock(feat)
125 | tmp = self.lastconv(feat)
126 | tmp = torch.nn.functional.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
127 | flow = tmp[:, :4] * scale
128 | mask = tmp[:, 4:5]
129 | conf = tmp[:, 5:6]
130 | return flow, mask, conf
131 |
132 | class FlownetCas(Module):
133 | def __init__(self):
134 | super().__init__()
135 | self.block0 = Flownet(7+16, c=192)
136 | self.block1 = Flownet(8+4+16, c=128)
137 | self.block2 = Flownet(8+4+16, c=96)
138 | self.block3 = Flownet(8+4+16, c=64)
139 | self.encode = Head()
140 |
141 | def forward(self, img0, img1, timestep=0.5, scale=[16, 8, 4, 1], iterations=1):
142 | img0 = img0
143 | img1 = img1
144 | f0 = self.encode(img0)
145 | f1 = self.encode(img1)
146 |
147 | flow_list = [None] * 4
148 | mask_list = [None] * 4
149 | merged = [None] * 4
150 |
151 | flow, mask, conf = self.block0(img0, img1, f0, f1, timestep, None, None, scale=scale[0])
152 |
153 | flow_list[0] = flow.clone()
154 | mask_list[0] = torch.sigmoid(mask.clone())
155 | merged[0] = warp(img0, flow[:, :2]) * mask_list[0] + warp(img1, flow[:, 2:4]) * (1 - mask_list[0])
156 |
157 | for iteration in range(iterations):
158 | flow_d, mask, conf = self.block1(
159 | img0,
160 | img1,
161 | f0,
162 | f1,
163 | timestep,
164 | mask,
165 | flow,
166 | scale=scale[1]
167 | )
168 | flow = flow + flow_d
169 |
170 | flow_list[1] = flow.clone()
171 | mask_list[1] = torch.sigmoid(mask.clone())
172 | merged[1] = warp(img0, flow[:, :2]) * mask_list[1] + warp(img1, flow[:, 2:4]) * (1 - mask_list[1])
173 |
174 | for iteration in range(iterations):
175 | flow_d, mask, conf = self.block2(
176 | img0,
177 | img1,
178 | f0,
179 | f1,
180 | timestep,
181 | mask,
182 | flow,
183 | scale=scale[2]
184 | )
185 | flow = flow + flow_d
186 |
187 | flow_list[2] = flow.clone()
188 | mask_list[2] = torch.sigmoid(mask.clone())
189 | merged[2] = warp(img0, flow[:, :2]) * mask_list[2] + warp(img1, flow[:, 2:4]) * (1 - mask_list[2])
190 |
191 | for iteration in range(iterations):
192 | flow_d, mask, conf = self.block3(
193 | img0,
194 | img1,
195 | f0,
196 | f1,
197 | timestep,
198 | mask,
199 | flow,
200 | scale=scale[3]
201 | )
202 | flow = flow + flow_d
203 |
204 | flow_list[3] = flow
205 | mask_list[3] = torch.sigmoid(mask)
206 | merged[3] = warp(img0, flow[:, :2]) * mask_list[3] + warp(img1, flow[:, 2:4]) * (1 - mask_list[3])
207 |
208 | return flow_list, mask_list, merged
209 |
210 | self.model = FlownetCas
211 | self.training_model = FlownetCas
212 |
213 | @staticmethod
214 | def get_info():
215 | return Model.info
216 |
217 | @staticmethod
218 | def get_name():
219 | return Model.info.get('name')
220 |
221 | @staticmethod
222 | def input_channels(model_state_dict):
223 | channels = 3
224 | try:
225 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
226 | except Exception as e:
227 | print (f'Unable to get model dict input channels - setting to 3, {e}')
228 | return channels
229 |
230 | @staticmethod
231 | def output_channels(model_state_dict):
232 | channels = 5
233 | try:
234 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
235 | except Exception as e:
236 |             print (f'Unable to get model dict output channels - setting to 5, {e}')
237 | return channels
238 |
239 | def get_model(self):
240 | import platform
241 | if platform.system() == 'Darwin':
242 | return self.training_model
243 | return self.model
244 |
245 | def get_training_model(self):
246 | return self.training_model
247 |
248 | def load_model(self, path, flownet, rank=0):
249 | import torch
250 | def convert(param):
251 | if rank == -1:
252 | return {
253 | k.replace("module.", ""): v
254 | for k, v in param.items()
255 | if "module." in k
256 | }
257 | else:
258 | return param
259 | if rank <= 0:
260 | if torch.cuda.is_available():
261 | flownet.load_state_dict(convert(torch.load(path)), False)
262 | else:
263 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
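
The 1e-2 weight scaling in ResConv.__init__ above makes each residual block start close to an identity mapping: the convolution output is tiny at initialisation, so the skip connection dominates. A quick standalone check of that intent:

    import torch
    conv = torch.nn.Conv2d(8, 8, 3, 1, 1)
    torch.nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')
    conv.weight.data *= 1e-2
    torch.nn.init.constant_(conv.bias, 0)
    x = torch.randn(1, 8, 16, 16)
    print((conv(x).abs().mean() / x.abs().mean()).item())   # on the order of 1e-2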
--------------------------------------------------------------------------------
/pytorch/models/stabnet4_v000.py:
--------------------------------------------------------------------------------
1 | # vanilla RIFE block
2 |
3 | class Model:
4 |
5 | info = {
6 |         'name': 'Stabnet4_v000',
7 |         'file': 'stabnet4_v000.py',
8 | 'ratio_support': True
9 | }
10 |
11 | def __init__(self, status = dict(), torch = None):
12 | if torch is None:
13 | import torch
14 | Module = torch.nn.Module
15 | backwarp_tenGrid = {}
16 |
17 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
18 | return torch.nn.Sequential(
19 | torch.nn.Conv2d(
20 | in_planes,
21 | out_planes,
22 | kernel_size=kernel_size,
23 | stride=stride,
24 | padding=padding,
25 | dilation=dilation,
26 | padding_mode = 'reflect',
27 | bias=True
28 | ),
29 | torch.nn.LeakyReLU(0.2, True)
30 | # torch.nn.SELU(inplace = True)
31 | )
32 |
33 | def warp(tenInput, tenFlow):
34 | k = (str(tenFlow.device), str(tenFlow.size()))
35 | if k not in backwarp_tenGrid:
36 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
37 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
38 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
39 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
40 |
41 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
42 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bicubic', padding_mode='border', align_corners=True)
43 |
44 | def normalize(tensor, min_val, max_val):
45 | src_dtype = tensor.dtype
46 | tensor = tensor.float()
47 | t_min = tensor.min()
48 | t_max = tensor.max()
49 |
50 | if t_min == t_max:
51 | return torch.full_like(tensor, (min_val + max_val) / 2.0)
52 |
53 | tensor = ((tensor - t_min) / (t_max - t_min)) * (max_val - min_val) + min_val
54 | tensor = tensor.to(dtype = src_dtype)
55 | return tensor
56 |
57 | def hpass(img):
58 | src_dtype = img.dtype
59 | img = img.float()
60 | def gauss_kernel(size=5, channels=3):
61 | kernel = torch.tensor([[1., 4., 6., 4., 1],
62 | [4., 16., 24., 16., 4.],
63 | [6., 24., 36., 24., 6.],
64 | [4., 16., 24., 16., 4.],
65 | [1., 4., 6., 4., 1.]])
66 | kernel /= 256.
67 | kernel = kernel.repeat(channels, 1, 1, 1)
68 | return kernel
69 |
70 | def conv_gauss(img, kernel):
71 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
72 | out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
73 | return out
74 |
75 | gkernel = gauss_kernel()
76 | gkernel = gkernel.to(device=img.device, dtype=img.dtype)
77 | hp = img - conv_gauss(img, gkernel) + 0.5
78 | hp = torch.clamp(hp, 0.48, 0.52)
79 | hp = normalize(hp, 0, 1)
80 | hp = torch.max(hp, dim=1, keepdim=True).values
81 | hp = hp.to(dtype = src_dtype)
82 | return hp
83 |
84 | def compress(x):
85 | src_dtype = x.dtype
86 | x = x.float()
87 | scale = torch.tanh(torch.tensor(1.0))
88 | x = torch.where(
89 | (x >= -1) & (x <= 1), scale * x,
90 | torch.tanh(x)
91 | )
92 | x = (x + 1) / 2
93 | x = x.to(dtype = src_dtype)
94 | return x
95 |
96 | class Head(Module):
97 | def __init__(self):
98 | super(Head, self).__init__()
99 | self.encode = torch.nn.Sequential(
100 | torch.nn.Conv2d(4, 32, 3, 2, 1),
101 | torch.nn.LeakyReLU(0.2, True),
102 | torch.nn.Conv2d(32, 32, 3, 1, 1),
103 | torch.nn.LeakyReLU(0.2, True),
104 | torch.nn.Conv2d(32, 32, 3, 1, 1),
105 | torch.nn.LeakyReLU(0.2, True),
106 | torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
107 | )
108 | self.maxdepth = 2
109 |
110 | def forward(self, x):
111 | hp = hpass(x)
112 | x = torch.cat((x, hp), 1)
113 |
114 | n, c, h, w = x.shape
115 | ph = self.maxdepth - (h % self.maxdepth)
116 | pw = self.maxdepth - (w % self.maxdepth)
117 | padding = (0, pw, 0, ph)
118 | x = torch.nn.functional.pad(x, padding)
119 |
120 | return self.encode(x)[:, :, :h, :w]
121 |
122 | class ResConv(Module):
123 | def __init__(self, c, dilation=1):
124 | super().__init__()
125 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'zeros', bias=True)
126 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
127 | self.relu = torch.nn.PReLU(c, 0.2)
128 | def forward(self, x):
129 | return self.relu(self.conv(x) * self.beta + x)
130 |
131 | class Flownet(Module):
132 | def __init__(self, in_planes, c=64):
133 | super().__init__()
134 | self.conv0 = torch.nn.Sequential(
135 | conv(in_planes, c//2, 5, 2, 2),
136 | conv(c//2, c, 3, 2, 1),
137 | )
138 | self.convblock = torch.nn.Sequential(
139 | ResConv(c),
140 | ResConv(c),
141 | ResConv(c),
142 | ResConv(c),
143 | ResConv(c),
144 | ResConv(c),
145 | ResConv(c),
146 | ResConv(c),
147 | )
148 | self.lastconv = torch.nn.Sequential(
149 | torch.nn.ConvTranspose2d(c, 4*2, 4, 2, 1),
150 | torch.nn.PixelShuffle(2)
151 | )
152 | self.maxdepth = 4
153 |
154 | def forward(self, img0, img1, f0, f1, scale=1):
155 | n, c, h, w = img0.shape
156 | sh, sw = round(h * (1 / scale)), round(w * (1 / scale))
157 |
158 | ph = self.maxdepth - (sh % self.maxdepth)
159 | pw = self.maxdepth - (sw % self.maxdepth)
160 | padding = (0, pw, 0, ph)
161 |
162 | imgs = torch.cat((img0, img1), 1)
163 | imgs = normalize(imgs, 0, 1) * 2 - 1
164 | x = torch.cat((imgs, f0, f1), 1)
165 | x = torch.nn.functional.interpolate(x, size=(sh, sw), mode="bicubic", align_corners=False)
166 | x = torch.nn.functional.pad(x, padding)
167 |
168 | feat = self.conv0(x)
169 | feat = self.convblock(feat)
170 | feat = self.lastconv(feat)
171 | feat = torch.nn.functional.interpolate(feat[:, :, :sh, :sw], size=(h, w), mode="bilinear", align_corners=False)
172 | flow = feat * scale
173 | return flow
174 |
175 | class FlownetCas(Module):
176 | def __init__(self):
177 | super().__init__()
178 |                 self.block0 = Flownet(6+16, c=64) # two images (6) + encoder features (16)
179 | # self.block1 = FlownetDeepDualHead(6+3+20+1+1+2+1+4, 12+20+8+1, c=128) # FlownetDeepDualHead(9+30+1+1+4+1+2, 22+30+1, c=128) # images + feat + timestep + lingrid + mask + conf + flow
180 | # self.block2 = FlownetDeepDualHead(6+3+20+1+1+2+1+4, 12+20+8+1, c=96) # FlownetLT(6+2+1+1+1, c=48) # None # FlownetDeepDualHead(9+30+1+1+4+1+2, 22+30+1, c=112) # images + feat + timestep + lingrid + mask + conf + flow
181 | # self.block3 = FlownetLT(11, c=48)
182 | self.encode = Head()
183 |
184 | def forward(self, img0, img1, timestep=0.5, scale=[16, 8, 4, 1], iterations=4, gt=None):
185 | img0 = compress(img0 * 2 - 1)
186 | img1 = compress(img1 * 2 - 1)
187 |
188 | f0 = self.encode(img0)
189 | f1 = self.encode(img1)
190 |
191 | flow = self.block0(img0, img1, f0, f1, scale=1)
192 |
193 | result = {
194 | 'flow_list': [flow]
195 | }
196 |
197 | return result
198 |
199 | self.model = FlownetCas
200 | self.training_model = FlownetCas
201 |
202 | @staticmethod
203 | def get_info():
204 | return Model.info
205 |
206 | @staticmethod
207 | def get_name():
208 | return Model.info.get('name')
209 |
210 | @staticmethod
211 | def input_channels(model_state_dict):
212 | channels = 3
213 | try:
214 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
215 | except Exception as e:
216 | print (f'Unable to get model dict input channels - setting to 3, {e}')
217 | return channels
218 |
219 | @staticmethod
220 | def output_channels(model_state_dict):
221 | channels = 5
222 | try:
223 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
224 | except Exception as e:
225 |             print (f'Unable to get model dict output channels - setting to 5, {e}')
226 | return channels
227 |
228 | def get_model(self):
229 | import platform
230 | if platform.system() == 'Darwin':
231 | return self.training_model
232 | return self.model
233 |
234 | def get_training_model(self):
235 | return self.training_model
236 |
237 | def load_model(self, path, flownet, rank=0):
238 | import torch
239 | def convert(param):
240 | if rank == -1:
241 | return {
242 | k.replace("module.", ""): v
243 | for k, v in param.items()
244 | if "module." in k
245 | }
246 | else:
247 | return param
248 | if rank <= 0:
249 | if torch.cuda.is_available():
250 | flownet.load_state_dict(convert(torch.load(path)), False)
251 | else:
252 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
--------------------------------------------------------------------------------
/pytorch/models/stabnet4_v001.py:
--------------------------------------------------------------------------------
1 | # vanilla RIFE block + PReLU
2 |
3 | class Model:
4 |
5 | info = {
6 | 'name': 'Stabnet4_v001',
7 | 'file': 'stabnet4_v001.py',
8 | 'ratio_support': True
9 | }
10 |
11 | def __init__(self, status = dict(), torch = None):
12 | if torch is None:
13 | import torch
14 | Module = torch.nn.Module
15 | backwarp_tenGrid = {}
16 |
17 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
18 | return torch.nn.Sequential(
19 | torch.nn.Conv2d(
20 | in_planes,
21 | out_planes,
22 | kernel_size=kernel_size,
23 | stride=stride,
24 | padding=padding,
25 | dilation=dilation,
26 | padding_mode = 'reflect',
27 | bias=True
28 | ),
29 | torch.nn.PReLU(out_planes, 0.2)
30 | )
31 |
32 | def warp(tenInput, tenFlow):
33 | k = (str(tenFlow.device), str(tenFlow.size()))
34 | if k not in backwarp_tenGrid:
35 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
36 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
37 | backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device=tenInput.device, dtype=tenInput.dtype)
38 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
39 |
40 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
41 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bicubic', padding_mode='border', align_corners=True)
42 |
43 | def normalize(tensor, min_val, max_val):
44 | src_dtype = tensor.dtype
45 | tensor = tensor.float()
46 | t_min = tensor.min()
47 | t_max = tensor.max()
48 |
49 | if t_min == t_max:
50 | return torch.full_like(tensor, (min_val + max_val) / 2.0)
51 |
52 | tensor = ((tensor - t_min) / (t_max - t_min)) * (max_val - min_val) + min_val
53 | tensor = tensor.to(dtype = src_dtype)
54 | return tensor
55 |
56 | def hpass(img):
57 | src_dtype = img.dtype
58 | img = img.float()
59 | def gauss_kernel(size=5, channels=3):
60 | kernel = torch.tensor([[1., 4., 6., 4., 1],
61 | [4., 16., 24., 16., 4.],
62 | [6., 24., 36., 24., 6.],
63 | [4., 16., 24., 16., 4.],
64 | [1., 4., 6., 4., 1.]])
65 | kernel /= 256.
66 | kernel = kernel.repeat(channels, 1, 1, 1)
67 | return kernel
68 |
69 | def conv_gauss(img, kernel):
70 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
71 | out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
72 | return out
73 |
74 | gkernel = gauss_kernel()
75 | gkernel = gkernel.to(device=img.device, dtype=img.dtype)
76 | hp = img - conv_gauss(img, gkernel) + 0.5
77 | hp = torch.clamp(hp, 0.48, 0.52)
78 | hp = normalize(hp, 0, 1)
79 | hp = torch.max(hp, dim=1, keepdim=True).values
80 | hp = hp.to(dtype = src_dtype)
81 | return hp
82 |
83 | def compress(x):
84 | src_dtype = x.dtype
85 | x = x.float()
86 | scale = torch.tanh(torch.tensor(1.0))
87 | x = torch.where(
88 | (x >= -1) & (x <= 1), scale * x,
89 | torch.tanh(x)
90 | )
91 | x = (x + 1) / 2
92 | x = x.to(dtype = src_dtype)
93 | return x
94 |
95 | class Head(Module):
96 | def __init__(self):
97 | super(Head, self).__init__()
98 | self.encode = torch.nn.Sequential(
99 | torch.nn.Conv2d(4, 32, 3, 2, 1),
100 | torch.nn.PReLU(32, 0.2),
101 | torch.nn.Conv2d(32, 32, 3, 1, 1),
102 | torch.nn.PReLU(32, 0.2),
103 | torch.nn.Conv2d(32, 32, 3, 1, 1),
104 | torch.nn.PReLU(32, 0.2),
105 | torch.nn.ConvTranspose2d(32, 8, 4, 2, 1)
106 | )
107 | self.maxdepth = 2
108 |
109 | def forward(self, x):
110 | hp = hpass(x)
111 | x = torch.cat((x, hp), 1)
112 |
113 | n, c, h, w = x.shape
114 | ph = self.maxdepth - (h % self.maxdepth)
115 | pw = self.maxdepth - (w % self.maxdepth)
116 | padding = (0, pw, 0, ph)
117 | x = torch.nn.functional.pad(x, padding)
118 |
119 | return self.encode(x)[:, :, :h, :w]
120 |
121 | class ResConv(Module):
122 | def __init__(self, c, dilation=1):
123 | super().__init__()
124 | self.conv = torch.nn.Conv2d(c, c, 3, 1, dilation, dilation = dilation, groups = 1, padding_mode = 'zeros', bias=True)
125 | self.beta = torch.nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
126 | self.relu = torch.nn.LeakyReLU(0.2, True)
127 | def forward(self, x):
128 | return self.relu(self.conv(x) * self.beta + x)
129 |
130 | class Flownet(Module):
131 | def __init__(self, in_planes, c=64):
132 | super().__init__()
133 | self.conv0 = torch.nn.Sequential(
134 | conv(in_planes, c//2, 5, 2, 2),
135 | conv(c//2, c, 3, 2, 1),
136 | )
137 | self.convblock = torch.nn.Sequential(
138 | ResConv(c),
139 | ResConv(c),
140 | ResConv(c),
141 | ResConv(c),
142 | ResConv(c),
143 | ResConv(c),
144 | ResConv(c),
145 | ResConv(c),
146 | )
147 | self.lastconv = torch.nn.Sequential(
148 | torch.nn.ConvTranspose2d(c, 4*2, 4, 2, 1),
149 | torch.nn.PixelShuffle(2)
150 | )
151 | self.maxdepth = 4
152 |
153 | def forward(self, img0, img1, f0, f1, scale=1):
154 | n, c, h, w = img0.shape
155 | sh, sw = round(h * (1 / scale)), round(w * (1 / scale))
156 |
157 | ph = self.maxdepth - (sh % self.maxdepth)
158 | pw = self.maxdepth - (sw % self.maxdepth)
159 | padding = (0, pw, 0, ph)
160 |
161 | imgs = torch.cat((img0, img1), 1)
162 | imgs = normalize(imgs, 0, 1) * 2 - 1
163 | x = torch.cat((imgs, f0, f1), 1)
164 | x = torch.nn.functional.interpolate(x, size=(sh, sw), mode="bicubic", align_corners=False)
165 | x = torch.nn.functional.pad(x, padding)
166 |
167 | feat = self.conv0(x)
168 | feat = self.convblock(feat)
169 | feat = self.lastconv(feat)
170 | feat = torch.nn.functional.interpolate(feat[:, :, :sh, :sw], size=(h, w), mode="bilinear", align_corners=False)
171 | flow = feat * scale
172 | return flow
173 |
174 | class FlownetCas(Module):
175 | def __init__(self):
176 | super().__init__()
177 |                 self.block0 = Flownet(6+16, c=64) # two images (6) + encoder features (16)
178 | # self.block1 = FlownetDeepDualHead(6+3+20+1+1+2+1+4, 12+20+8+1, c=128) # FlownetDeepDualHead(9+30+1+1+4+1+2, 22+30+1, c=128) # images + feat + timestep + lingrid + mask + conf + flow
179 | # self.block2 = FlownetDeepDualHead(6+3+20+1+1+2+1+4, 12+20+8+1, c=96) # FlownetLT(6+2+1+1+1, c=48) # None # FlownetDeepDualHead(9+30+1+1+4+1+2, 22+30+1, c=112) # images + feat + timestep + lingrid + mask + conf + flow
180 | # self.block3 = FlownetLT(11, c=48)
181 | self.encode = Head()
182 |
183 | def forward(self, img0, img1, timestep=0.5, scale=[16, 8, 4, 1], iterations=4, gt=None):
184 | img0 = compress(img0 * 2 - 1)
185 | img1 = compress(img1 * 2 - 1)
186 |
187 | f0 = self.encode(img0)
188 | f1 = self.encode(img1)
189 |
190 | flow = self.block0(img0, img1, f0, f1, scale=1)
191 |
192 | result = {
193 | 'flow_list': [flow]
194 | }
195 |
196 | return result
197 |
198 | self.model = FlownetCas
199 | self.training_model = FlownetCas
200 |
201 | @staticmethod
202 | def get_info():
203 | return Model.info
204 |
205 | @staticmethod
206 | def get_name():
207 | return Model.info.get('name')
208 |
209 | @staticmethod
210 | def input_channels(model_state_dict):
211 | channels = 3
212 | try:
213 | channels = model_state_dict.get('multiresblock1.conv_3x3.conv1.weight').shape[1]
214 | except Exception as e:
215 | print (f'Unable to get model dict input channels - setting to 3, {e}')
216 | return channels
217 |
218 | @staticmethod
219 | def output_channels(model_state_dict):
220 | channels = 5
221 | try:
222 | channels = model_state_dict.get('conv_final.conv1.weight').shape[0]
223 | except Exception as e:
224 |             print (f'Unable to get model dict output channels - setting to 5, {e}')
225 | return channels
226 |
227 | def get_model(self):
228 | import platform
229 | if platform.system() == 'Darwin':
230 | return self.training_model
231 | return self.model
232 |
233 | def get_training_model(self):
234 | return self.training_model
235 |
236 | def load_model(self, path, flownet, rank=0):
237 | import torch
238 | def convert(param):
239 | if rank == -1:
240 | return {
241 | k.replace("module.", ""): v
242 | for k, v in param.items()
243 | if "module." in k
244 | }
245 | else:
246 | return param
247 | if rank <= 0:
248 | if torch.cuda.is_available():
249 | flownet.load_state_dict(convert(torch.load(path)), False)
250 | else:
251 | flownet.load_state_dict(convert(torch.load(path, map_location ='cpu')), False)
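
Two helpers shared by both stabnet variants are worth a note: hpass subtracts a 5x5 binomial (Gaussian) blur whose kernel sums to 256 before the /256 normalisation, i.e. a unit-gain low-pass, leaving only fine detail; compress is continuous at the branch point because both branches equal tanh(1) at x = ±1. A small standalone check of both claims, mirroring the constants above:

    import torch
    kernel = torch.tensor([[1., 4., 6., 4., 1.],
                           [4., 16., 24., 16., 4.],
                           [6., 24., 36., 24., 6.],
                           [4., 16., 24., 16., 4.],
                           [1., 4., 6., 4., 1.]])
    assert kernel.sum() == 256.                              # /256 gives a DC-preserving blur
    s = torch.tanh(torch.tensor(1.0))
    assert torch.isclose(s * 1.0, torch.tanh(torch.tensor(1.0)))   # branches agree at x = 1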
--------------------------------------------------------------------------------
/pytorch/move_to_folders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | def main():
6 | parser = argparse.ArgumentParser(description='Move exrs to corresponding folders')
7 |
8 | # Required argument
9 | parser.add_argument('path', type=str, help='Path to folder with exrs')
10 | args = parser.parse_args()
11 |
12 | files_and_dirs = os.listdir(args.path)
13 | exr_files = [file for file in files_and_dirs if file.endswith('.exr')]
14 |
15 | exr_file_names = set()
16 | for exr_file_name in exr_files:
17 | exr_file_names.add(exr_file_name.split('.')[0])
18 |
19 | for name in exr_file_names:
20 |         os.system(f'mkdir -p "{args.path}/{name}"')
21 |         os.system(f'mv "{args.path}/{name}".* "{args.path}/{name}/"')
22 |
23 | if __name__ == "__main__":
24 | main()
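
A shell-free sketch of the same move logic using only the standard library (equivalent for simple names; the glob mirrors the {name}.* pattern above):

    import glob, shutil
    for name in exr_file_names:
        os.makedirs(os.path.join(args.path, name), exist_ok=True)
        for f in glob.glob(os.path.join(args.path, f'{name}.*')):
            shutil.move(f, os.path.join(args.path, name))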
--------------------------------------------------------------------------------
/pytorch/recompress.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import importlib
5 | import platform
6 |
7 | os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
8 |
9 | try:
10 | import numpy as np
11 | except:
12 | python_executable_path = sys.executable
13 | if '.miniconda' in python_executable_path:
14 | print (f'Using "{python_executable_path}" as python interpreter')
15 | print ('Unable to import Numpy')
16 | sys.exit()
17 |
18 | try:
19 | import OpenImageIO as oiio
20 | except:
21 | python_executable_path = sys.executable
22 | if '.miniconda' in python_executable_path:
23 | print ('Unable to import OpenImageIO')
24 | print (f'Using "{python_executable_path}" as python interpreter')
25 | sys.exit()
26 |
27 | def read_image_file(file_path, header_only = False):
28 | result = {'spec': None, 'image_data': None}
29 | inp = oiio.ImageInput.open(file_path)
30 | if inp :
31 | spec = inp.spec()
32 | result['spec'] = spec
33 | if not header_only:
34 | channels = spec.nchannels
35 | result['image_data'] = inp.read_image(0, 0, 0, channels).transpose(1, 0, 2)
36 | inp.close()
37 | return result
38 |
39 | def find_folders_with_exr(path):
40 | """
41 | Find all folders under the given path that contain .exr files.
42 |
43 | Parameters:
44 | path (str): The root directory to start the search from.
45 |
46 | Returns:
47 |         set: A set of directories containing .exr files.
48 | """
49 | directories_with_exr = set()
50 |
51 | # Walk through all directories and files in the given path
52 | for root, dirs, files in os.walk(path):
53 | if root.endswith('preview'):
54 | continue
55 | if root.endswith('eval'):
56 | continue
57 | for file in files:
58 | if file.endswith('.exr'):
59 | directories_with_exr.add(root)
60 | break # No need to check other files in the same directory
61 |
62 | return directories_with_exr
63 |
64 | def clear_lines(n=2):
65 | """Clears a specified number of lines in the terminal."""
66 | CURSOR_UP_ONE = '\x1b[1A'
67 | ERASE_LINE = '\x1b[2K'
68 | for _ in range(n):
69 | sys.stdout.write(CURSOR_UP_ONE)
70 | sys.stdout.write(ERASE_LINE)
71 |
72 | def main():
73 | parser = argparse.ArgumentParser(description='In-place recompress exrs to half PIZ (WARNING - DESTRUCTIVE!)')
74 |
75 | # Required argument
76 | parser.add_argument('dataset_path', type=str, help='Path to folders with exrs')
77 | args = parser.parse_args()
78 |
79 | folders_with_exr = find_folders_with_exr(args.dataset_path)
80 |
81 | exr_files = []
82 |
83 | for folder_path in sorted(folders_with_exr):
84 | folder_exr_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path) if file.endswith('.exr')]
85 | folder_exr_files.sort()
86 | exr_files.extend(folder_exr_files)
87 |
88 | idx = 0
89 | for exr_file_path in exr_files:
90 | # clear_lines(1)
91 | print (f'\r{" "*120}', end='')
92 | print (f'\rFile [{idx+1} / {len(exr_files)}], {os.path.basename(exr_file_path)}', end='')
93 | try:
94 | result = read_image_file(exr_file_path)
95 | image_data = result['image_data']
96 | spec = result['spec']
97 |
98 | if image_data is None or spec is None:
99 | raise RuntimeError(f"Could not read image or header from {exr_file_path}")
100 |
101 |             # Transpose back to (H, W, C) scanline order for OIIO
102 | image_data_for_write = image_data.transpose(1, 0, 2)
103 |
104 | # Clone the spec and modify compression only
105 | new_spec = oiio.ImageSpec(spec)
106 | new_spec.set_format(oiio.TypeDesc.TypeHalf) # ensure half-precision
107 | new_spec.attribute("compression", "piz")
108 |
109 | out = oiio.ImageOutput.create(exr_file_path)
110 | if not out:
111 | raise RuntimeError(f"Could not create output for {exr_file_path}")
112 |
113 | if not out.open(exr_file_path, new_spec):
114 | raise RuntimeError(f"Failed to open output file {exr_file_path}")
115 |
116 | if not out.write_image(image_data_for_write):
117 | raise RuntimeError(f"Failed to write image to {exr_file_path}")
118 |
119 | out.close()
120 |
121 | except Exception as e:
122 |             print (f'\nError processing {exr_file_path}: {e}')
123 | idx += 1
124 | print ('')
125 |
126 | if __name__ == "__main__":
127 | main()
128 |
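
A quick header-only spot check after running the script, using the read_image_file helper above (the path is illustrative):

    result = read_image_file('/path/to/some.exr', header_only=True)
    spec = result['spec']
    print(spec.format, spec.getattribute('compression'))   # expect half and piz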
--------------------------------------------------------------------------------
/pytorch/rename_keys.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | try:
6 | import numpy as np
7 | import torch
8 | except:
9 | python_executable_path = sys.executable
10 | if '.miniconda' in python_executable_path:
11 | print ('Unable to import Numpy and PyTorch libraries')
12 | print (f'Using {python_executable_path} python interpreter')
13 | sys.exit()
14 |
15 | key_mapping_01 = {
16 | 'block0.encode01.0.weight': 'block0.encode01.downconv01.conv1.weight',
17 | 'block0.encode01.0.bias': 'block0.encode01.downconv01.conv1.bias',
18 | 'block0.encode01.2.weight': 'block0.encode01.conv01.conv1.weight',
19 | 'block0.encode01.2.bias': 'block0.encode01.conv01.conv1.bias',
20 | 'block0.encode01.4.weight': 'block0.encode01.conv02.conv1.weight',
21 | 'block0.encode01.4.bias': 'block0.encode01.conv02.conv1.bias',
22 | 'block0.encode01.6.weight': 'block0.encode01.upsample01.conv1.weight',
23 | 'block0.encode01.6.bias': 'block0.encode01.upsample01.conv1.bias',
24 |
25 | 'block1.encode01.0.weight': 'block1.encode01.downconv01.conv1.weight',
26 | 'block1.encode01.0.bias': 'block1.encode01.downconv01.conv1.bias',
27 | 'block1.encode01.2.weight': 'block1.encode01.conv01.conv1.weight',
28 | 'block1.encode01.2.bias': 'block1.encode01.conv01.conv1.bias',
29 | 'block1.encode01.4.weight': 'block1.encode01.conv02.conv1.weight',
30 | 'block1.encode01.4.bias': 'block1.encode01.conv02.conv1.bias',
31 | 'block1.encode01.6.weight': 'block1.encode01.upsample01.conv1.weight',
32 | 'block1.encode01.6.bias': 'block1.encode01.upsample01.conv1.bias',
33 |
34 | 'block2.encode01.0.weight': 'block2.encode01.downconv01.conv1.weight',
35 | 'block2.encode01.0.bias': 'block2.encode01.downconv01.conv1.bias',
36 | 'block2.encode01.2.weight': 'block2.encode01.conv01.conv1.weight',
37 | 'block2.encode01.2.bias': 'block2.encode01.conv01.conv1.bias',
38 | 'block2.encode01.4.weight': 'block2.encode01.conv02.conv1.weight',
39 | 'block2.encode01.4.bias': 'block2.encode01.conv02.conv1.bias',
40 | 'block2.encode01.6.weight': 'block2.encode01.upsample01.conv1.weight',
41 | 'block2.encode01.6.bias': 'block2.encode01.upsample01.conv1.bias',
42 |
43 | 'block3.encode01.0.weight': 'block3.encode01.downconv01.conv1.weight',
44 | 'block3.encode01.0.bias': 'block3.encode01.downconv01.conv1.bias',
45 | 'block3.encode01.2.weight': 'block3.encode01.conv01.conv1.weight',
46 | 'block3.encode01.2.bias': 'block3.encode01.conv01.conv1.bias',
47 | 'block3.encode01.4.weight': 'block3.encode01.conv02.conv1.weight',
48 | 'block3.encode01.4.bias': 'block3.encode01.conv02.conv1.bias',
49 | 'block3.encode01.6.weight': 'block3.encode01.upsample01.conv1.weight',
50 | 'block3.encode01.6.bias': 'block3.encode01.upsample01.conv1.bias',
51 | }
52 |
53 | def main():
54 | parser = argparse.ArgumentParser(description='Move exrs to corresponding folders')
55 |
56 | # Required argument
57 | parser.add_argument('source', type=str, help='Path to source state dict')
58 | parser.add_argument('dest', type=str, help='Path to dest state dict')
59 | args = parser.parse_args()
60 |
61 | checkpoint = torch.load(args.source)
62 | src_state_dict = checkpoint.get('flownet_state_dict')
63 |
64 | dest_state_dict = {}
65 | for key in src_state_dict:
66 | new_key = key_mapping_01.get(key, key)
67 | dest_state_dict[new_key] = src_state_dict[key]
68 |
69 | checkpoint['flownet_state_dict'] = dest_state_dict
70 | torch.save(checkpoint, args.dest)
71 |
72 | '''
73 | for key in src_state_dict.keys():
74 | print (key)
75 | '''
76 |
77 | if __name__ == "__main__":
78 | main()
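
The 32-entry mapping above is one rule applied per block and per layer index; an equivalent generated form (a sketch that produces exactly the entries listed):

    index_to_name = {'0': 'downconv01', '2': 'conv01', '4': 'conv02', '6': 'upsample01'}
    key_mapping_01 = {
        f'block{b}.encode01.{i}.{p}': f'block{b}.encode01.{n}.conv1.{p}'
        for b in range(4)
        for i, n in index_to_name.items()
        for p in ('weight', 'bias')
    }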
--------------------------------------------------------------------------------
/pytorch/resize_exrs_min.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import importlib
5 | import platform
6 |
7 | os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
8 |
9 | try:
10 | import numpy as np
11 | import torch
12 | except:
13 | python_executable_path = sys.executable
14 | if '.miniconda' in python_executable_path:
15 | print ('Unable to import Numpy and PyTorch libraries')
16 | print (f'Using {python_executable_path} python interpreter')
17 | sys.exit()
18 |
19 | try:
20 | import OpenImageIO as oiio
21 | except:
22 | python_executable_path = sys.executable
23 | if '.miniconda' in python_executable_path:
24 | print ('Unable to import OpenImageIO')
25 | print (f'Using "{python_executable_path}" as python interpreter')
26 | sys.exit()
27 |
28 | def read_image_file(file_path, header_only = False):
29 | result = {'spec': None, 'image_data': None}
30 | inp = oiio.ImageInput.open(file_path)
31 | if inp:
32 | spec = inp.spec()
33 | result['spec'] = spec
34 | if not header_only:
35 | channels = spec.nchannels
36 | result['image_data'] = inp.read_image(0, 0, 0, channels)
37 | # img_data = inp.read_image(0, 0, 0, channels) #.transpose(1, 0, 2)
38 | # result['image_data'] = np.ascontiguousarray(img_data)
39 | inp.close()
40 | # del inp
41 | return result
42 |
43 | def write_image_file(file_path, image_data, image_spec):
44 |     out = oiio.ImageOutput.create(file_path)
45 | if out:
46 | out.open(file_path, image_spec)
47 | out.write_image(image_data)
48 | out.close()
49 |
50 | def find_folders_with_exr(path):
51 | """
52 | Find all folders under the given path that contain .exr files.
53 |
54 | Parameters:
55 | path (str): The root directory to start the search from.
56 |
57 | Returns:
58 |         set: A set of directories containing .exr files.
59 | """
60 | directories_with_exr = set()
61 |
62 | # Walk through all directories and files in the given path
63 | for root, dirs, files in os.walk(path):
64 | if root.endswith('preview'):
65 | continue
66 | if root.endswith('eval'):
67 | continue
68 | for file in files:
69 | if file.endswith('.exr'):
70 | directories_with_exr.add(root)
71 | break # No need to check other files in the same directory
72 |
73 | return directories_with_exr
74 |
75 | def clear_lines(n=2):
76 | """Clears a specified number of lines in the terminal."""
77 | CURSOR_UP_ONE = '\x1b[1A'
78 | ERASE_LINE = '\x1b[2K'
79 | for _ in range(n):
80 | sys.stdout.write(CURSOR_UP_ONE)
81 | sys.stdout.write(ERASE_LINE)
82 |
83 | def resize_image(tensor, new_h, new_w):
84 | """
85 |     Resize the tensor of shape [h, w, c] to the given height and width
86 |     (callers pick new_h/new_w proportionally to retain the aspect ratio).
87 |
88 |     Parameters:
89 |         tensor (torch.Tensor): The input tensor with shape [h, w, c].
90 |         new_h (int), new_w (int): The target height and width.
91 |
92 | Returns:
93 | torch.Tensor: The resized tensor.
94 | """
95 | # Adjust tensor shape to [n, c, h, w]
96 | tensor = tensor.permute(2, 0, 1).unsqueeze(0)
97 |
98 | # Resize
99 | resized_tensor = torch.nn.functional.interpolate(tensor, size=(new_h, new_w), mode='bicubic', align_corners=False)
100 |
101 | # Adjust tensor shape back to [h, w, c]
102 | resized_tensor = resized_tensor.squeeze(0).permute(1, 2, 0)
103 |
104 | return resized_tensor
105 |
106 | def halve(exr_file_path, new_h):
107 |     device = torch.device('mps') if platform.system() == 'Darwin' else torch.device('cuda')
108 |     exr_data = read_image_file(exr_file_path)
109 |     spec = exr_data['spec']
110 |     h, w = spec.height, spec.width
111 |
112 |     aspect_ratio = w / h
113 |     new_w = int(new_h * aspect_ratio)
114 |
115 |     img0 = torch.from_numpy(exr_data['image_data'])
116 |     img0 = img0.to(device = device, dtype = torch.float32)
117 |     img0 = resize_image(img0, new_h, new_w)
118 |     img0 = img0.to(dtype=torch.half)
119 |     half_spec = oiio.ImageSpec(new_w, new_h, spec.nchannels, oiio.HALF)
120 |     write_image_file(exr_file_path, img0.cpu().detach().numpy(), half_spec)  # overwrites in place
121 |     del img0
122 |
123 | def main():
124 | parser = argparse.ArgumentParser(description='In-place resize exrs to half-res (WARNING - DESTRUCTIVE!)')
125 |
126 | # Required argument
127 | parser.add_argument('h', type=int, help='New height')
128 | parser.add_argument('dataset_path', type=str, help='Path to folders with exrs')
129 | args = parser.parse_args()
130 |
131 | folders_with_exr = find_folders_with_exr(args.dataset_path)
132 |
133 | exr_files = []
134 |
135 | for folder_path in sorted(folders_with_exr):
136 | folder_exr_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path) if file.endswith('.exr')]
137 | folder_exr_files.sort()
138 | exr_files.extend(folder_exr_files)
139 |
140 | idx = 0
141 | for exr_file_path in exr_files:
142 | clear_lines(1)
143 | print (f'\rFile [{idx+1} / {len(exr_files)}], {os.path.basename(exr_file_path)}')
144 | try:
145 | halve(exr_file_path, args.h)
146 | except Exception as e:
147 | print (f'\n\nError halving {exr_file_path}: {e}')
148 | idx += 1
149 | print ('')
150 |
151 | if __name__ == "__main__":
152 | main()
153 |
--------------------------------------------------------------------------------
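The heart of resize_exrs_min.py is the permute -> interpolate -> permute round-trip in resize_image(), plus the aspect-ratio arithmetic from halve(). A self-contained sketch of the same operation (the sizes and the function name are illustrative):

    import torch

    def resize_demo(h=1080, w=1920, new_h=540):
        img = torch.rand(h, w, 3)              # channel-last, as read via OIIO
        new_w = int(new_h * (w / h))           # keep aspect ratio, as in halve()
        x = img.permute(2, 0, 1).unsqueeze(0)  # [h, w, c] -> [1, c, h, w]
        x = torch.nn.functional.interpolate(
            x, size=(new_h, new_w), mode='bicubic', align_corners=False)
        return x.squeeze(0).permute(1, 2, 0)   # -> [new_h, new_w, c]

    print(resize_demo().shape)                 # torch.Size([540, 960, 3])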
/pytorch/speedtest.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import platform
5 | import torch
6 | import numpy as np
7 | import random
8 | import queue
9 | import threading
10 | import time
11 | import OpenImageIO as oiio
12 |
13 | from pprint import pprint
14 |
15 | def find_and_import_model(models_dir='models', base_name=None, model_name=None, model_file=None):
16 | """
17 | Dynamically imports the latest version of a model based on the base name,
18 | or a specific model if the model name/version is given, and returns the Model
19 | object named after the base model name.
20 |
21 | :param models_dir: Relative path to the models directory.
22 | :param base_name: Base name of the model to search for.
23 | :param model_name: Specific name/version of the model (optional).
24 | :return: Imported Model object or None if not found.
25 | """
26 |
27 | import os
28 | import re
29 | import importlib
30 |
31 | if model_file:
32 | module_name = model_file[:-3] # Remove '.py' from filename to get module name
33 |
34 | print (module_name)
35 |
36 | module_path = f"models.{module_name}"
37 | module = importlib.import_module(module_path)
38 | model_object = getattr(module, 'Model')
39 | return model_object
40 |
41 | # Resolve the absolute path of the models directory
42 | models_abs_path = os.path.abspath(
43 | os.path.join(
44 | os.path.dirname(__file__),
45 | models_dir
46 | )
47 | )
48 |
49 | # List all files in the models directory
50 | try:
51 | files = os.listdir(models_abs_path)
52 | except FileNotFoundError:
53 | print(f"Directory not found: {models_abs_path}")
54 | return None
55 |
56 | # Filter files based on base_name or model_name
57 | if model_name:
58 | # Look for a specific model version
59 | filtered_files = [f for f in files if f == f"{model_name.lower()}.py"]
60 | else:
61 | # Find all versions of the model and select the latest one
62 |         versions = [f for f in files if f.endswith('.py')]
63 |         filtered_files = []
64 |         if versions:
65 |             # Reverse lexicographic sort puts the latest version first
66 |             latest_version_file = sorted(versions, reverse=True)[0]
67 |             filtered_files = [latest_version_file]
68 |         # filtered_files stays empty when no .py files are found,
69 |         # so the check below reports "Model not found"
70 |
71 | # Import the module and return the Model object
72 | if filtered_files:
73 | module_name = filtered_files[0][:-3] # Remove '.py' from filename to get module name
74 | module_path = f"models.{module_name}"
75 | module = importlib.import_module(module_path)
76 | model_object = getattr(module, 'Model')
77 | return model_object
78 | else:
79 | print(f"Model not found: {base_name or model_name}")
80 | return None
81 |
82 |
83 |
84 | def main():
85 | parser = argparse.ArgumentParser(description='Speed test.')
86 |     parser.add_argument('--model', type=str, default=None, help='Model name (default: Flownet4_v001_baseline)')
87 | parser.add_argument('--frame_size', type=str, default=None, help='Frame size (default: 4096x1716)')
88 | parser.add_argument('--device', type=int, default=0, help='Graphics card index (default: 0)')
89 |
90 | args = parser.parse_args()
91 |
92 | device = torch.device("mps") if platform.system() == 'Darwin' else torch.device('cuda')
93 |
94 | if args.frame_size:
95 | w, h = args.frame_size.split('x')
96 | h, w = int(h), int(w)
97 | else:
98 | h, w = 1716, 4096
99 |
100 | shape = (1, 3, h, w)
101 | img = torch.randn(shape).to(device)
102 |
103 | model = args.model if args.model else 'Flownet4_v001_baseline'
104 | Net = find_and_import_model(model_name=model)
105 |
106 | print ('Net info:')
107 | pprint (Net.get_info())
108 |
109 | net = Net().get_training_model()().to(device)
110 |
111 | import signal
112 | def create_graceful_exit():
113 | def graceful_exit(signum, frame):
114 | '''
115 | print(f'\nSaving current state to {current_state_dict["trained_model_path"]}...')
116 | print (f'Epoch: {current_state_dict["epoch"] + 1}, Step: {current_state_dict["step"]:11}')
117 | torch.save(current_state_dict, current_state_dict['trained_model_path'])
118 | exit_event.set() # Signal threads to stop
119 | process_exit_event.set() # Signal processes to stop
120 | '''
121 | print('\n')
122 | exit(0)
123 | # signal.signal(signum, signal.SIG_DFL)
124 | # os.kill(os.getpid(), signal.SIGINT)
125 | return graceful_exit
126 | signal.signal(signal.SIGINT, create_graceful_exit())
127 |
128 | with torch.no_grad():
129 | while True:
130 | start_time = time.time()
131 |             result = net(
132 | img,
133 | img,
134 | # 0.5,
135 | scale=[8, 4, 2, 1],
136 | iterations = 1
137 | )
138 | # result = result[0].permute(1, 2, 0).numpy(force=True)
139 |             torch.cuda.synchronize() if device.type == 'cuda' else torch.mps.synchronize()
140 | end_time = time.time()
141 | elapsed_time = end_time - start_time
142 | elapsed_time = elapsed_time if elapsed_time > 0 else float('inf')
143 | fps = 1 / elapsed_time
144 | print (f'\rRunning at {fps:.2f} fps for {shape[3]}x{shape[2]}', end='')
145 |
146 | if __name__ == "__main__":
147 | main()
148 |
--------------------------------------------------------------------------------
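speedtest.py times one forward pass per loop iteration and synchronizes before reading the clock: CUDA kernel launches return asynchronously, so without the synchronize the loop would mostly measure launch overhead and report inflated fps. The same pattern as a generic helper, with warm-up iterations added (time_fn, warmup and iters are illustrative, not part of the script):

    import time
    import torch

    def time_fn(fn, warmup=3, iters=10):
        for _ in range(warmup):        # let lazy init and autotuning settle
            fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()   # drain queued kernels before stopping the clock
        return (time.time() - start) / iters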
/requirements.txt:
--------------------------------------------------------------------------------
1 | lpips
2 |
--------------------------------------------------------------------------------
/retime.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Resolves the directory where the script is located
4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
5 |
6 | # Change to the script directory
7 | cd "$SCRIPT_DIR"
8 |
9 | # Define the path to the Python executable and the Python script
10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python"
11 | PYTHON_SCRIPT="./pytorch/retime.py"
12 |
13 | # Run the Python script with all arguments passed to this shell script
14 | "$PYTHON_CMD" "$PYTHON_SCRIPT" "$@"
15 |
--------------------------------------------------------------------------------
/speedtest.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Resolves the directory where the script is located
4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
5 |
6 | # Change to the script directory
7 | cd "$SCRIPT_DIR"
8 |
9 | # Define the path to the Python executable and the Python script
10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python"
11 | PYTHON_SCRIPT="./pytorch/speedtest.py"
12 |
13 | # Run the Python script with all arguments passed to this shell script
14 | "$PYTHON_CMD" "$PYTHON_SCRIPT" "$@"
15 |
--------------------------------------------------------------------------------
/test.timewarp_node:
--------------------------------------------------------------------------------
1 | [tag-stripped Flame Timewarp node setup: the file's markup did not survive
2 |  this dump, leaving only bare channel values. Recoverable content: a
3 |  Timewarp node with hermite/bezier keyframe channels (cubic interpolation,
4 |  constant/linear/smooth extrapolation) spanning roughly frames 1-116, with
5 |  timing values including 21.02, 34.78, 45.83, 60, 99.11 and 118.5.]
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Resolves the directory where the script is located
4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
5 |
6 | # Change to the script directory
7 | cd "$SCRIPT_DIR"
8 |
9 | # Define the path to the Python executable and the Python script
10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python"
11 | PYTHON_SCRIPT="./pytorch/flameTimewarpML_train.py"
12 |
13 | # Run the Python script with all arguments passed to this shell script
14 | "$PYTHON_CMD" "$PYTHON_SCRIPT" "$@"
15 |
--------------------------------------------------------------------------------
/train_stab.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Resolves the directory where the script is located
4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
5 |
6 | # Change to the script directory
7 | cd "$SCRIPT_DIR"
8 |
9 | # Define the path to the Python executable and the Python script
10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python"
11 | PYTHON_SCRIPT="./pytorch/flameStabML_train.py"
12 |
13 | # Run the Python script with all arguments passed to this shell script
14 | "$PYTHON_CMD" "$PYTHON_SCRIPT" "$@"
15 |
--------------------------------------------------------------------------------
/validate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Resolves the directory where the script is located
4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
5 |
6 | # Change to the script directory
7 | cd "$SCRIPT_DIR"
8 |
9 | # Define the path to the Python executable and the Python script
10 | PYTHON_CMD="./packages/.miniconda/appenv/bin/python"
11 | PYTHON_SCRIPT="./pytorch/flameTimewarpML_validate.py"
12 |
13 | # Run the Python script with all arguments passed to this shell script
14 | "$PYTHON_CMD" "$PYTHON_SCRIPT" "$@"
15 |
--------------------------------------------------------------------------------