├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── VMM_learning_based_demo.ipynb ├── callbacks.py ├── config.py ├── data.py ├── losses.py ├── magnet.py ├── main.py ├── make_frameACB.py ├── materials ├── Fig2-a.png ├── baby_comp.gif ├── dogs.png ├── guitar_comp.gif └── myself_comp.gif ├── requirements.txt ├── run.sh └── test_video.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | .vscode 3 | weights* 4 | results* 5 | 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/pengzheng/miniconda3/envs/mm/bin/python" 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 ZhengPeng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Motion_magnification_learning-based 2 | This is an unofficial implementation of "[Learning-based Video Motion Magnification](https://arxiv.org/abs/1804.02684)" in PyTorch (1.8.1~2.0). 3 | [Here is the official implementation in Tensorflow==1.8.0](https://github.com/12dmodel/deep_motion_mag). 4 | 5 | #### Highly recommended: my friends' latest works, come and try them! 6 | + Event-Based Motion Magnification: [[paper](https://arxiv.org/pdf/2402.11957.pdf)] [[codes](https://github.com/OpenImagingLab/emm)] [[project](https://openimaginglab.github.io/emm/)] 7 | + Frequency Decoupling for Motion Magnification via Multi-Level Isomorphic Architecture: [[paper](https://arxiv.org/pdf/2403.07347.pdf)] [[codes](https://github.com/Jiafei127/FD4MM)] 8 | 9 | # Update 10 | **(2023/11/05) Added a notebook demo for offline inference. Feel free to email me or open an issue if you need any help I can offer.** 11 | 12 | **(2023/04/07) Since a few of you are still interested in this old repo, I have made a Colab demo for easy inference. I am sorry about my clumsy code from years ago -- it was painful to reuse it for the Colab demo, and some of it still remains 😂 If you have any trouble with it, feel free to open an issue or send me an e-mail.** 13 | 14 | Besides, as tested, this repo is compatible with **PyTorch 2.x**. 15 | 16 | *Give it a video, and it will be amplified with only one click that runs all the steps:* 17 | 18 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1inOucehJXUAVBlRhZvo650SoOPLKQFNv#scrollTo=BjgKRohk7Q5M) 19 | 20 |

21 | VMM_colab_demo 22 |

23 | 24 | # Env 25 | ``` 26 | conda create -n vmm python=3.10 -y && conda activate vmm 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | # Data preparation 31 | 32 | 0. For the synthetic dataset used for **training**, please refer to the official repository mentioned above or download it [here](https://drive.google.com/drive/folders/19K09QLouiV5N84wZiTPUMdoH9-UYqZrX?usp=sharing). 33 | 34 | 1. For the video datasets used for **validation**, you can also download the preprocessed frames [here](https://drive.google.com/drive/folders/19K09QLouiV5N84wZiTPUMdoH9-UYqZrX?usp=sharing), packaged as train_vid_frames.zip. 35 | 36 | 2. Check the validation-directory settings in **config.py** and modify them if necessary. 37 | 38 | 3. To convert a **validation** video into frames: 39 | 40 | `mkdir VIDEO_NAME && ffmpeg -i VIDEO_NAME.mp4 -f image2 VIDEO_NAME/%06d.png` 41 | 42 | > Tip: ffmpeg can also be installed via conda. 43 | 44 | 4. Arrange the frames into **frameA/frameB/frameC**: 45 | 46 | `python make_frameACB.py VIDEO_NAME` (remember to adapt the 'if' filter at the beginning of the script to select your videos.) 47 | 48 | # Little differences from the official codes 49 | 50 | 1. **Poisson noise** is not used here because I was a bit confused about it in the official code, although I did implement it in data.py and checked on examples that it behaves exactly the same as the official code. 51 | 2. For the **optimizer**, we keep it the same as in the original paper -- Adam(lr=1e-4, betas=(0.9, 0.999)) with no weight decay, which is different from the official codes. 52 | 3. For the **regularization weight** in the loss, we also adhere to the original paper -- it is set to 0.1, which is different from the official codes. 53 | 4. The **temporal filter** is still a bit confusing to me, so I haven't implemented testing with the temporal filter yet, sorry for that :(... 54 | 55 | # One thing **important** 56 | 57 | If you check Fig. 2-a in the original paper, you will find that the predicted magnified frame is decoded from the texture representation of the second input frame (the perturbed frameB) together with the manipulated motion representation, although this is theoretically the same as decoding from the texture of the first frame (frameA) when both frames carry the same texture. 58 | 59 | ![Fig2-a](materials/Fig2-a.png) 60 | 61 | However, what makes this matter is that the authors used perturbation for regularization, and each sample in the given dataset has 4 parts: 62 | 63 | 1. frameA: the first input frame, unperturbed; 64 | 2. frameB: a perturbed copy of frameC, i.e. the perturbed second input frame in the paper; 65 | 3. frameC: the real second input frame, unperturbed; 66 | 4. **amplified**: the magnified target frame, perturbed (it carries the same perturbation as frameB). 67 | 68 | Here is the first training sample, where you can see clearly that there is **no perturbation** between **A-C** nor between **B-amp**, and no motion between B-C: 69 | 70 | ![dogs](materials/dogs.png) 71 | 72 | Given that, we don't have the unperturbed amplified frame, so **we can only use the former formula** (decode with the texture of frameB). Besides, if you check the **loss** in the original paper, you will find that it refers to the unperturbed magnified frame -- but where is that frame in the dataset?... I also referred to some third-party reproductions regarding this problem, which confused me a lot, and none of them solved it. Some of them just set the problematic term to 0 manually, so I think they noticed this problem too but didn't manage to understand it. 73 | 74 | Here are some links to the issues about this problem in the official repository, [issue-1](https://github.com/12dmodel/deep_motion_mag/issues/3), [issue-2](https://github.com/12dmodel/deep_motion_mag/issues/5), [issue-3](https://github.com/12dmodel/deep_motion_mag/issues/4), if you want to check them. 75 | 76 | # Run 77 | `bash run.sh` to train and test. 78 | 79 | It took me around 20 hours to train for 12 epochs on a single TITAN-Xp.
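For reference, writing $\hat{Y}$ for the predicted magnified frame, $Y'$ for the perturbed amplified frame given in the dataset, $V$ for texture representations and $M$ for motion representations (subscripts refer to frameA/frameB/frameC), the total training loss assembled in `main.py` and `losses.py` is $L = L_1(\hat{Y}, Y') + 0.1\,\big(L_1(V_A, V_C) + L_1(V_B, V_{Y'}) + L_1(M_B, M_C)\big)$, so every term only uses frames that are actually provided in the dataset.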
80 | 81 | If you don't want to use all the 100,000 groups to train, you can modify `frames_train='coco100000'` in config.py to coco30000 or some other number. 82 | 83 | You can **download the weights** (epoch 12) from [the release](https://github.com/ZhengPeng7/motion_magnification_learning-based/releases/tag/v1.0), and run `python test_video.py baby-guitar-yourself-...` to do the test. 84 | 85 | # Results 86 | 87 | Here are some results generated from the model trained on the whole synthetic dataset for **12** epochs. 88 | 89 | Baby, amplification factor = 50 90 | 91 | ![baby](materials/baby_comp.gif) 92 | 93 | Guitar, amplification factor = 20 94 | 95 | ![guitar](materials/guitar_comp.gif) 96 | 97 | I also took a video of my own face with amplification factor 20, which nicely illustrates the Chinese idiom '夺眶而出' (literally, bursting out of the eye sockets) 😂. 98 | 99 | ![myself](materials/myself_comp.gif) 100 | 101 | > Any questions are welcome :) 102 | -------------------------------------------------------------------------------- /VMM_learning_based_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "bf5RrhNaDejC" 7 | }, 8 | "source": [ 9 | "# An offline demo for the \"Learning-based Video Motion Magnification\" (ECCV 2018)\n", 10 | "\n", 11 | "# Official repo: https://github.com/12dmodel/deep_motion_mag\n", 12 | "# My repo: https://github.com/ZhengPeng7/motion_magnification_learning-based" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "## Make sure to set the Python kernel for this notebook." 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "Z7Ghq0laIGq4" 26 | }, 27 | "source": [ 28 | "## Preparations:\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "pWvGFo9yECoa" 35 | }, 36 | "source": [ 37 | "### Install python packages" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "nI-Hq9WK5xNg", 48 | "outputId": "4c55ba46-ee51-4752-8dd1-deab4ffe704e" 49 | }, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", 56 | "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", 57 | "\u001b[0m" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "!pip install --quiet -r requirements.txt\n", 63 | "!pip install --quiet gdown mediapy" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": { 69 | "id": "nhA7-odZEHQW" 70 | }, 71 | "source": [ 72 | "### Download and load the well-trained weights:" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 2, 78 | "metadata": { 79 | "colab": { 80 | "base_uri": "https://localhost:8080/" 81 | }, 82 | "id": "FUX3pb77Axr0", 83 | "outputId": "11164c11-73dc-4418-8248-dfb9ab66d535" 84 | }, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "Will not apply HSTS. 
The HSTS database must be a regular and non-world-writable file.\n", 91 | "ERROR: could not open HSTS store at '/root/.wget-hsts'. HSTS will be disabled.\n", 92 | "--2023-11-05 14:03:48-- https://github.com/ZhengPeng7/motion_magnification_learning-based/releases/download/v1.0/magnet_epoch12_loss7.28e-02.pth\n", 93 | "Resolving github.com (github.com)... 47.93.241.142, 39.106.83.44\n", 94 | "Connecting to github.com (github.com)|47.93.241.142|:443... connected.\n", 95 | "HTTP request sent, awaiting response... 302 Found\n", 96 | "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/223970988/4e613b00-aa97-11ea-87b4-bef43fcc6281?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231105%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231105T060349Z&X-Amz-Expires=300&X-Amz-Signature=054d40cd1599fce311358e5952e84073f9d390e242c658cc5a3de30142350113&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=223970988&response-content-disposition=attachment%3B%20filename%3Dmagnet_epoch12_loss7.28e-02.pth&response-content-type=application%2Foctet-stream [following]\n", 97 | "--2023-11-05 14:03:49-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/223970988/4e613b00-aa97-11ea-87b4-bef43fcc6281?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231105%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231105T060349Z&X-Amz-Expires=300&X-Amz-Signature=054d40cd1599fce311358e5952e84073f9d390e242c658cc5a3de30142350113&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=223970988&response-content-disposition=attachment%3B%20filename%3Dmagnet_epoch12_loss7.28e-02.pth&response-content-type=application%2Foctet-stream\n", 98 | "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...\n", 99 | "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.109.133|:443... connected.\n", 100 | "HTTP request sent, awaiting response... 
200 OK\n", 101 | "Length: 3528201 (3.4M) [application/octet-stream]\n", 102 | "Saving to: ‘magnet_epoch12_loss7.28e-02.pth’\n", 103 | "\n", 104 | "magnet_epoch12_loss 100%[===================>] 3.36M 15.3KB/s in 3m 17s \n", 105 | "\n", 106 | "2023-11-05 14:07:09 (17.5 KB/s) - ‘magnet_epoch12_loss7.28e-02.pth’ saved [3528201/3528201]\n", 107 | "\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "!wget https://github.com/ZhengPeng7/motion_magnification_learning-based/releases/download/v1.0/magnet_epoch12_loss7.28e-02.pth" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 3, 118 | "metadata": { 119 | "colab": { 120 | "base_uri": "https://localhost:8080/" 121 | }, 122 | "id": "quAf0VARHAPM", 123 | "outputId": "7eb23fde-56e9-4797-dd0a-091b559aaec8" 124 | }, 125 | "outputs": [ 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "Please load train_mf.txt if you want to do training.\n", 131 | "Loading weights: magnet_epoch12_loss7.28e-02.pth\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "from magnet import MagNet\n", 137 | "from callbacks import gen_state_dict\n", 138 | "from config import Config\n", 139 | "\n", 140 | "\n", 141 | "# config\n", 142 | "config = Config()\n", 143 | "# Load weights\n", 144 | "weights_path = 'magnet_epoch12_loss7.28e-02.pth'\n", 145 | "ep = int(weights_path.split('epoch')[-1].split('_')[0])\n", 146 | "state_dict = gen_state_dict(weights_path)\n", 147 | "\n", 148 | "model_test = MagNet().cuda()\n", 149 | "model_test.load_state_dict(state_dict)\n", 150 | "model_test.eval()\n", 151 | "print(\"Loading weights:\", weights_path)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "id": "_AM4n9P-iaNG" 158 | }, 159 | "source": [ 160 | "# Preprocess\n", 161 | "\n", 162 | "Make the video to frameAs/frameBs/frameCs.\n", 163 | "\n", 164 | "Let's take the guitar.mp4 as an example." 
165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "colab": { 172 | "base_uri": "https://localhost:8080/" 173 | }, 174 | "id": "tnJSrX43vlgy", 175 | "outputId": "bb3fec89-a7f1-4246-a4c7-7f08548ca905" 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "# Download some example video from my google-drive or upload your own video.\n", 180 | "# !gdown 1hNZ02vnSO04FYS9jkx2OjProYHIvwdkB # guitar.avi\n", 181 | "!gdown 1XGC2y4Lshd9aBiBxwkTuT_IA79n-WNST # baby.avi\n", 182 | "# !gdown 1QGOWuR0swF7_eHharTztlkEDz0hlfmU4 # zhiyin.mp4" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": { 188 | "id": "pd9lV4mCHU9v" 189 | }, 190 | "source": [ 191 | "# Set VIDEO_NAME here, e.g., guitar, baby" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 5, 197 | "metadata": { 198 | "colab": { 199 | "base_uri": "https://localhost:8080/" 200 | }, 201 | "id": "I0v0Uacfmtib", 202 | "outputId": "01383531-8a69-42f5-8b7c-a280acff41c8" 203 | }, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers\n", 210 | " built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)\n", 211 | " configuration: --prefix=/opt/conda/conda-bld/ffmpeg_1597178665428/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame\n", 212 | " libavutil 56. 51.100 / 56. 51.100\n", 213 | " libavcodec 58. 91.100 / 58. 91.100\n", 214 | " libavformat 58. 45.100 / 58. 45.100\n", 215 | " libavdevice 58. 10.100 / 58. 10.100\n", 216 | " libavfilter 7. 85.100 / 7. 85.100\n", 217 | " libavresample 4. 0. 0 / 4. 0. 0\n", 218 | " libswscale 5. 7.100 / 5. 7.100\n", 219 | " libswresample 3. 7.100 / 3. 7.100\n", 220 | "Input #0, avi, from 'baby.avi':\n", 221 | " Duration: 00:00:10.03, start: 0.000000, bitrate: 15411 kb/s\n", 222 | " Stream #0:0: Video: mjpeg (Baseline) (MJPG / 0x47504A4D), yuvj420p(pc, bt470bg/unknown/unknown), 960x544 [SAR 1:1 DAR 30:17], 15399 kb/s, 30 fps, 30 tbr, 30 tbn, 30 tbc\n", 223 | "Stream mapping:\n", 224 | " Stream #0:0 -> #0:0 (mjpeg (native) -> png (native))\n", 225 | "Press [q] to stop, [?] 
for help\n", 226 | "\u001b[1;34m[swscaler @ 0x561bdb17b380] \u001b[0m\u001b[0;33mdeprecated pixel format used, make sure you did set range correctly\n", 227 | "\u001b[0mOutput #0, image2, to 'baby/%06d.png':\n", 228 | " Metadata:\n", 229 | " encoder : Lavf58.45.100\n", 230 | " Stream #0:0: Video: png, rgb24, 960x544 [SAR 1:1 DAR 30:17], q=2-31, 200 kb/s, 30 fps, 30 tbn, 30 tbc\n", 231 | " Metadata:\n", 232 | " encoder : Lavc58.91.100 png\n", 233 | "frame= 301 fps= 48 q=-0.0 Lsize=N/A time=00:00:10.03 bitrate=N/A speed= 1.6x \n", 234 | "video:187776kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: unknown\n", 235 | "ACB-Processing on baby\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "# Turn the video into frames and make them into frame_ACB format.\n", 241 | "file_to_be_maged = 'baby.avi' # 'guitar.avi'\n", 242 | "video_name = file_to_be_maged.split('.')[0]\n", 243 | "video_format = '.' + file_to_be_maged.split('.')[-1]\n", 244 | "\n", 245 | "\n", 246 | "sh_file = 'VIDEO_NAME={}\\nVIDEO_FORMAT={}'.format(video_name, video_format) + \"\"\"\n", 247 | "\n", 248 | "\n", 249 | "mkdir ${VIDEO_NAME}\n", 250 | "ffmpeg -i ${VIDEO_NAME}${VIDEO_FORMAT} -f image2 ${VIDEO_NAME}/%06d.png\n", 251 | "python make_frameACB.py ${VIDEO_NAME}\n", 252 | "mkdir test_dir\n", 253 | "mv ${VIDEO_NAME} test_dir\n", 254 | "\"\"\"\n", 255 | "with open('test_preproc.sh', 'w') as file:\n", 256 | " file.write(sh_file)\n", 257 | "\n", 258 | "!bash test_preproc.sh" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": { 264 | "id": "OeOYfifiGUPq" 265 | }, 266 | "source": [ 267 | "## Test\n" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 6, 273 | "metadata": { 274 | "colab": { 275 | "base_uri": "https://localhost:8080/" 276 | }, 277 | "id": "RXz3mYanz07-", 278 | "outputId": "9d526b43-6f76-404c-fa91-235ea6b5ee72" 279 | }, 280 | "outputs": [ 281 | { 282 | "name": "stdout", 283 | "output_type": "stream", 284 | "text": [ 285 | "Number of test image couples: 300\n", 286 | "100, 200, 300, res_baby/baby/baby_amp10.avi has been done.\n", 287 | "100, 200, 300, res_baby/baby/baby_amp25.avi has been done.\n", 288 | "100, 200, 300, res_baby/baby/baby_amp50.avi has been done.\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "import os\n", 294 | "import sys\n", 295 | "import cv2\n", 296 | "import torch\n", 297 | "import numpy as np\n", 298 | "from data import get_gen_ABC, unit_postprocessing, numpy2cuda, resize2d\n", 299 | "\n", 300 | "\n", 301 | "# testsets = '+'.join(['baby', 'guitar', 'gun', 'drone', 'cattoy', 'water'][:])\n", 302 | "for testset in [video_name]:\n", 303 | " dir_results = 'res_' + testset\n", 304 | " if not os.path.exists(dir_results):\n", 305 | " os.makedirs(dir_results)\n", 306 | "\n", 307 | " config.data_dir = 'test_dir'\n", 308 | " data_loader = get_gen_ABC(config, mode='test_on_'+testset)\n", 309 | " print('Number of test image couples:', data_loader.data_len)\n", 310 | " vid_size = cv2.imread(data_loader.paths[0]).shape[:2][::-1]\n", 311 | "\n", 312 | " # Test\n", 313 | " for amp in [10, 25, 50]:\n", 314 | " frames = []\n", 315 | " data_loader = get_gen_ABC(config, mode='test_on_'+testset)\n", 316 | " for idx_load in range(0, data_loader.data_len, data_loader.batch_size):\n", 317 | " if (idx_load+1) % 100 == 0:\n", 318 | " print('{}'.format(idx_load+1), end=', ')\n", 319 | " batch_A, batch_B = data_loader.gen_test()\n", 320 | " amp_factor = numpy2cuda(amp)\n", 321 | " for _ in range(len(batch_A.shape) - 
len(amp_factor.shape)):\n", 322 | " amp_factor = amp_factor.unsqueeze(-1)\n", 323 | " with torch.no_grad():\n", 324 | " y_hats = model_test(batch_A, batch_B, 0, 0, amp_factor, mode='evaluate')\n", 325 | " for y_hat in y_hats:\n", 326 | " y_hat = unit_postprocessing(y_hat, vid_size=vid_size)\n", 327 | " frames.append(y_hat)\n", 328 | " if len(frames) >= data_loader.data_len:\n", 329 | " break\n", 330 | " if len(frames) >= data_loader.data_len:\n", 331 | " break\n", 332 | " data_loader = get_gen_ABC(config, mode='test_on_'+testset)\n", 333 | " frames = [unit_postprocessing(data_loader.gen_test()[0], vid_size=vid_size)] + frames\n", 334 | "\n", 335 | " # Make videos of framesMag\n", 336 | " video_dir = os.path.join(dir_results, testset)\n", 337 | " if not os.path.exists(video_dir):\n", 338 | " os.makedirs(video_dir)\n", 339 | " FPS = 30\n", 340 | " video_save_path = os.path.join(video_dir, '{}_amp{}{}'.format(testset, amp, video_format))\n", 341 | " out = cv2.VideoWriter(\n", 342 | " video_save_path,\n", 343 | " cv2.VideoWriter_fourcc(*'DIVX'),\n", 344 | " FPS, frames[0].shape[-2::-1]\n", 345 | " )\n", 346 | " for frame in frames:\n", 347 | " frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n", 348 | " cv2.putText(frame, 'amp_factor={}'.format(amp), (7, 37),\n", 349 | " fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), thickness=2)\n", 350 | " out.write(frame)\n", 351 | " out.release()\n", 352 | " print('{} has been done.'.format(video_save_path))" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "# Play the amplified video here\n", 362 | "from glob import glob\n", 363 | "import mediapy\n", 364 | "\n", 365 | "\n", 366 | "video_save_paths = [file_to_be_maged] + sorted(glob(os.path.join(dir_results, testset, '*')), key=lambda x: int(x.split('amp')[-1].split('.')[0]))\n", 367 | "\n", 368 | "video_dict = {}\n", 369 | "for video_save_path in video_save_paths[:]:\n", 370 | " video_dict[video_save_path.split('/')[-1]] = mediapy.read_video(video_save_path)\n", 371 | "mediapy.show_videos(video_dict, fps=FPS, width=250, codec='gif')" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [] 380 | } 381 | ], 382 | "metadata": { 383 | "accelerator": "GPU", 384 | "colab": { 385 | "provenance": [] 386 | }, 387 | "kernelspec": { 388 | "display_name": "Python 3 (ipykernel)", 389 | "language": "python", 390 | "name": "python3" 391 | }, 392 | "language_info": { 393 | "codemirror_mode": { 394 | "name": "ipython", 395 | "version": 3 396 | }, 397 | "file_extension": ".py", 398 | "mimetype": "text/x-python", 399 | "name": "python", 400 | "nbconvert_exporter": "python", 401 | "pygments_lexer": "ipython3", 402 | "version": "3.10.13" 403 | } 404 | }, 405 | "nbformat": 4, 406 | "nbformat_minor": 4 407 | } 408 | -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def save_model(weights, losses, weights_dir, epoch): 7 | loss = np.mean(losses) 8 | path_ckpt = os.path.join( 9 | weights_dir, 'magnet_epoch{}_loss{:.2e}.pth'.format(epoch, loss) 10 | ) 11 | torch.save(weights, path_ckpt) 12 | 13 | 14 | def gen_state_dict(weights_path): 15 | st = torch.load(weights_path) 16 | st_ks = list(st.keys()) 17 | st_vs = list(st.values()) 18 | 
state_dict = {} 19 | for st_k, st_v in zip(st_ks, st_vs): 20 | state_dict[st_k.replace('module.', '')] = st_v 21 | return state_dict 22 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | 6 | class Config(object): 7 | def __init__(self): 8 | 9 | # General 10 | self.epochs = 12 11 | # self.GPUs = '0' 12 | self.batch_size = 6 # * torch.cuda.device_count() # len(self.GPUs.split(',')) 13 | self.date = '0510' 14 | 15 | # Data 16 | self.data_dir = '../../../datasets/mm' 17 | self.dir_train = os.path.join(self.data_dir, 'train') 18 | self.dir_test = os.path.join(self.data_dir, 'test') 19 | self.dir_water = os.path.join(self.data_dir, 'train/train_vid_frames/val_water') 20 | self.dir_baby = os.path.join(self.data_dir, 'train/train_vid_frames/val_baby') 21 | self.dir_gun = os.path.join(self.data_dir, 'train/train_vid_frames/val_gun') 22 | self.dir_drone = os.path.join(self.data_dir, 'train/train_vid_frames/val_drone') 23 | self.dir_guitar = os.path.join(self.data_dir, 'train/train_vid_frames/val_guitar') 24 | self.dir_cattoy = os.path.join(self.data_dir, 'train/train_vid_frames/val_cattoy') 25 | self.dir_myself = os.path.join(self.data_dir, 'train/train_vid_frames/myself') 26 | self.frames_train = 'coco100000' # you can adapt 100000 to a smaller number to train 27 | self.cursor_end = int(self.frames_train.split('coco')[-1]) 28 | if os.path.exists(os.path.join(self.dir_train, 'train_mf.txt')): 29 | self.coco_amp_lst = np.loadtxt(os.path.join(self.dir_train, 'train_mf.txt'))[:self.cursor_end] 30 | else: 31 | print('Please load train_mf.txt if you want to do training.') 32 | self.coco_amp_lst = None 33 | self.videos_train = [] 34 | self.load_all = False # Don't turn it on, unless you have such a big mem. 35 | # On coco dataset, 100, 000 sets -> 850G 36 | 37 | # Training 38 | self.lr = 1e-4 39 | self.betas = (0.9, 0.999) 40 | self.batch_size_test = 1 41 | self.preproc = ['poisson'] # ['resize', ] 42 | self.pretrained_weights = '' 43 | 44 | # Callbacks 45 | self.num_val_per_epoch = 10 46 | self.save_dir = 'weights_date{}'.format(self.date) 47 | self.time_st = time.time() 48 | self.losses = [] 49 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sklearn.utils import shuffle 3 | import cv2 4 | from skimage.io import imread 5 | from skimage.util import random_noise 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from PIL import Image, ImageFile 11 | 12 | 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | Tensor = torch.cuda.FloatTensor 15 | 16 | 17 | def gen_poisson_noise(unit): 18 | n = np.random.randn(*unit.shape) 19 | 20 | # Strange here, unit has been in range of (-1, 1), 21 | # but made example to checked to be same as the official codes. 
22 | n_str = np.sqrt(unit + 1.0) / np.sqrt(127.5) 23 | poisson_noise = np.multiply(n, n_str) 24 | return poisson_noise 25 | 26 | 27 | def load_unit(path, inf_size=(0, 0)): 28 | # Load 29 | file_suffix = os.path.splitext(path)[1].lower() 30 | if file_suffix in ['.jpg', '.png']: 31 | try: 32 | image = imread(path).astype(np.uint8) 33 | except Exception as e: 34 | print('{} load exception:\n'.format(path), e) 35 | image = np.array(Image.open(path).convert('RGB')) 36 | if inf_size != (0, 0): 37 | image = cv2.resize(image, inf_size, interpolation=cv2.INTER_LANCZOS4) 38 | unit = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 39 | return unit 40 | else: 41 | print('Unsupported file type.') 42 | return None 43 | 44 | def unit_preprocessing(unit, preproc=[], is_test=False): 45 | # Preprocessing 46 | if 'BF' in preproc and is_test: 47 | unit = cv2.bilateralFilter(unit, 9, 75, 75) 48 | if 'resize' in preproc: 49 | unit = cv2.resize(unit, (384, 384), interpolation=cv2.INTER_LANCZOS4) 50 | elif 'downsample' in preproc: 51 | unit = cv2.resize(unit, unit.shape[1]//2, unit.shape[0]//2, interpolation=cv2.INTER_LANCZOS4) 52 | 53 | unit = cv2.cvtColor(unit, cv2.COLOR_BGR2RGB) 54 | try: 55 | if 'poisson' in preproc: 56 | # Use poisson noise from official repo or skimage? 57 | 58 | # unit = unit + gen_poisson_noise(unit) * np.random.uniform(0, 0.3) 59 | 60 | unit = random_noise(unit, mode='poisson') # unit: 0 ~ 1 61 | unit = unit * 255 62 | except Exception as e: 63 | print('EX:', e, unit.shape, unit.dtype) 64 | 65 | unit = unit / 127.5 - 1.0 66 | 67 | unit = np.transpose(unit, (2, 0, 1)) 68 | return unit 69 | 70 | 71 | def unit_postprocessing(unit, vid_size=None): 72 | unit = unit.squeeze() 73 | unit = unit.cpu().detach().numpy() 74 | unit = np.clip(unit, -1, 1) 75 | unit = np.round((np.transpose(unit, (1, 2, 0)) + 1.0) * 127.5).astype(np.uint8) 76 | if unit.shape[:2][::-1] != vid_size and vid_size is not None: 77 | unit = cv2.resize(unit, vid_size, interpolation=cv2.INTER_CUBIC) 78 | return unit 79 | 80 | 81 | def get_paths_ABC(config, mode): 82 | if mode in ('train', 'test_on_trainset'): 83 | dir_root = config.dir_train 84 | elif mode == 'test_on_testset': 85 | dir_root = config.dir_test 86 | else: 87 | val_vid = '_'.join(mode.split('_')[2:]) 88 | try: 89 | dir_root = eval('config.dir_{}'.format(val_vid)) 90 | if not os.path.exists(dir_root): 91 | dir_root = os.path.join(config.data_dir, val_vid) 92 | except: 93 | dir_root = os.path.join(config.data_dir, val_vid) 94 | if not os.path.exists(dir_root): 95 | print('Cannot find data at {}.\nExiting the program...'.format(dir_root)) 96 | exit() 97 | paths_A, paths_C, paths_skip_intermediate, paths_skip, paths_mag = [], [], [], [], [] 98 | if config.cursor_end > 0 or 'test' in mode: 99 | dir_A = os.path.join(dir_root, 'frameA') 100 | files_A = sorted(os.listdir(dir_A), key=lambda x: int(x.split('.')[0])) 101 | paths_A = [os.path.join(dir_A, file_A) for file_A in files_A] 102 | if mode == 'train' and isinstance(config.cursor_end, int): 103 | paths_A = paths_A[:config.cursor_end] 104 | paths_C = [p.replace('frameA', 'frameC') for p in paths_A] 105 | paths_mag = [p.replace('frameA', 'amplified') for p in paths_A] 106 | else: 107 | paths_A, paths_C, paths_skip_intermediate, paths_skip = [], [], [], [] 108 | if 'test' not in mode: 109 | path_vids = os.path.join(config.dir_train, 'train_vid_frames') 110 | dirs_vid = [os.path.join(path_vids, p, 'frameA') for p in config.videos_train] 111 | for dir_vid in dirs_vid[:len(config.videos_train)]: 112 | vid_frames = [ 113 | 
os.path.join(dir_vid, p) for p in sorted( 114 | os.listdir(dir_vid), key=lambda x: int(x.split('.')[0]) 115 | )] 116 | if config.skip < 0: 117 | lst = [p.replace('frameA', 'frameC') for p in vid_frames] 118 | for idx, _ in enumerate(lst): 119 | skip_rand = np.random.randint(min(-config.skip, 2), -config.skip+1) 120 | idx_skip = min(idx + skip_rand, len(lst) - 1) 121 | paths_skip.append(lst[idx_skip]) 122 | paths_skip_intermediate.append(lst[idx_skip//2]) 123 | paths_A += vid_frames 124 | paths_C = [p.replace('frameA', 'frameC') for p in paths_A] 125 | paths_B = [p.replace('frameC', 'frameB') for p in paths_C] 126 | return paths_A, paths_B, paths_C, paths_skip, paths_skip_intermediate, paths_mag 127 | 128 | 129 | class DataGen(): 130 | def __init__(self, paths, config, mode): 131 | self.is_train = 'test' not in mode 132 | self.anchor = 0 133 | self.paths = paths 134 | self.batch_size = config.batch_size if self.is_train else config.batch_size_test 135 | self.data_len = len(paths) 136 | self.load_all = config.load_all 137 | self.data = [] 138 | self.preproc = config.preproc 139 | self.coco_amp_lst = config.coco_amp_lst 140 | 141 | if self.is_train and self.load_all: 142 | self.units_A, self.units_C, self.units_M, self.units_B = [], [], [], [] 143 | for idx_data in range(self.data_len): 144 | if idx_data % 500 == 0: 145 | print('Processing {} / {}.'.format(idx_data, self.data_len)) 146 | unit_A = load_unit(self.paths[idx_data]) 147 | unit_C = load_unit(self.paths[idx_data].replace('frameA', 'frameC')) 148 | unit_M = load_unit(self.paths[idx_data].replace('frameA', 'amplified')) 149 | unit_B = load_unit(self.paths[idx_data].replace('frameA', 'frameB')) 150 | unit_A = unit_preprocessing(unit_A, preproc=self.preproc) 151 | unit_C = unit_preprocessing(unit_C, preproc=self.preproc) 152 | unit_M = unit_preprocessing(unit_M, preproc=[]) 153 | unit_B = unit_preprocessing(unit_B, preproc=self.preproc) 154 | self.units_A.append(unit_A) 155 | self.units_C.append(unit_C) 156 | self.units_M.append(unit_M) 157 | self.units_B.append(unit_B) 158 | 159 | def gen(self, anchor=None): 160 | batch_A = [] 161 | batch_C = [] 162 | batch_M = [] 163 | batch_B = [] 164 | batch_amp = [] 165 | if anchor is None: 166 | anchor = self.anchor 167 | 168 | for _ in range(self.batch_size): 169 | if not self.load_all: 170 | unit_A = load_unit(self.paths[anchor]) 171 | unit_C = load_unit(self.paths[anchor].replace('frameA', 'frameC')) 172 | unit_M = load_unit(self.paths[anchor].replace('frameA', 'amplified')) 173 | unit_B = load_unit(self.paths[anchor].replace('frameA', 'frameB')) 174 | unit_A = unit_preprocessing(unit_A, preproc=self.preproc) 175 | unit_C = unit_preprocessing(unit_C, preproc=self.preproc) 176 | unit_M = unit_preprocessing(unit_M, preproc=[]) 177 | unit_B = unit_preprocessing(unit_B, preproc=self.preproc) 178 | else: 179 | unit_A = self.units_A[anchor] 180 | unit_C = self.units_C[anchor] 181 | unit_M = self.units_M[anchor] 182 | unit_B = self.units_B[anchor] 183 | unit_amp = self.coco_amp_lst[anchor] 184 | 185 | batch_A.append(unit_A) 186 | batch_C.append(unit_C) 187 | batch_M.append(unit_M) 188 | batch_B.append(unit_B) 189 | batch_amp.append(unit_amp) 190 | 191 | self.anchor = (self.anchor + 1) % self.data_len 192 | 193 | batch_A = numpy2cuda(batch_A) 194 | batch_C = numpy2cuda(batch_C) 195 | batch_M = numpy2cuda(batch_M) 196 | batch_B = numpy2cuda(batch_B) 197 | batch_amp = numpy2cuda(batch_amp).reshape(self.batch_size, 1, 1, 1) 198 | return batch_A, batch_B, batch_C, batch_M, batch_amp 199 | 200 | def 
gen_test(self, anchor=None, inf_size=(0, 0)): 201 | batch_A = [] 202 | batch_C = [] 203 | if anchor is None: 204 | anchor = self.anchor 205 | 206 | for _ in range(self.batch_size): 207 | unit_A = load_unit(self.paths[anchor], inf_size=inf_size) 208 | unit_C = load_unit(self.paths[anchor].replace('frameA', 'frameC'), inf_size=inf_size) 209 | unit_A = unit_preprocessing(unit_A, preproc=[], is_test=True) 210 | unit_C = unit_preprocessing(unit_C, preproc=[], is_test=True) 211 | batch_A.append(unit_A) 212 | batch_C.append(unit_C) 213 | 214 | self.anchor = (self.anchor + 1) % self.data_len 215 | 216 | batch_A = numpy2cuda(batch_A) 217 | batch_C = numpy2cuda(batch_C) 218 | return batch_A, batch_C 219 | 220 | 221 | def get_gen_ABC(config, mode='train'): 222 | paths_A = get_paths_ABC(config, mode)[0] 223 | gen_train_A = DataGen(paths_A, config, mode) 224 | return gen_train_A 225 | 226 | 227 | def cuda2numpy(tensor): 228 | array = tensor.detach().cpu().squeeze().numpy() 229 | return array 230 | 231 | 232 | def numpy2cuda(array): 233 | tensor = torch.from_numpy(np.asarray(array)).float().cuda() 234 | return tensor 235 | 236 | 237 | def resize2d(img, size): 238 | with torch.no_grad(): 239 | img_resized = (F.adaptive_avg_pool2d(Variable(img, volatile=True), size)).data 240 | return img_resized 241 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.autograd as ag 5 | 6 | 7 | def criterion_mag(y, batch_M, texture_AC, texture_BM, motion_BC, criterion): 8 | # One thing deserves mentioning is that the amplified frames given in the dataset are actually perturbed Y(Y'), which I used M to represent. 
9 | loss_y = criterion(y, batch_M) 10 | loss_texture_AC = criterion(*texture_AC) 11 | loss_texture_BM = criterion(*texture_BM) 12 | loss_motion_BC = criterion(*motion_BC) 13 | return loss_y, loss_texture_AC, loss_texture_BM, loss_motion_BC 14 | -------------------------------------------------------------------------------- /magnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from data import numpy2cuda 4 | 5 | 6 | def truncated_normal_(tensor, mean=0, std=1): 7 | size = tensor.shape 8 | tmp = tensor.new_empty(size + (4,)).normal_() 9 | valid = (tmp < 2) & (tmp > -2) 10 | ind = valid.max(-1, keepdim=True)[1] 11 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 12 | tensor.data.mul_(std).add_(mean) 13 | return tensor 14 | 15 | 16 | class Conv2D_activa(nn.Module): 17 | def __init__( 18 | self, in_channels, out_channels, kernel_size, stride, 19 | padding=0, dilation=1, activation='relu' 20 | ): 21 | super(Conv2D_activa, self).__init__() 22 | self.padding = padding 23 | if self.padding: 24 | self.pad = nn.ReflectionPad2d(padding) 25 | self.conv2d = nn.Conv2d( 26 | in_channels, out_channels, kernel_size, stride, 27 | dilation=dilation, bias=None 28 | ) 29 | self.activation = activation 30 | if activation == 'relu': 31 | self.activation = nn.ReLU() 32 | 33 | def forward(self, x): 34 | if self.padding: 35 | x = self.pad(x) 36 | x = self.conv2d(x) 37 | if self.activation: 38 | x = self.activation(x) 39 | return x 40 | 41 | class ResBlk(nn.Module): 42 | def __init__(self, dim_in, dim_out, dim_intermediate=32, ks=3, s=1): 43 | super(ResBlk, self).__init__() 44 | p = (ks - 1) // 2 45 | self.cba_1 = Conv2D_activa(dim_in, dim_intermediate, ks, s, p, activation='relu') 46 | self.cba_2 = Conv2D_activa(dim_intermediate, dim_out, ks, s, p, activation=None) 47 | 48 | def forward(self, x): 49 | y = self.cba_1(x) 50 | y = self.cba_2(y) 51 | return y + x 52 | 53 | 54 | def _repeat_blocks(block, dim_in, dim_out, num_blocks, dim_intermediate=32, ks=3, s=1): 55 | blocks = [] 56 | for idx_block in range(num_blocks): 57 | if idx_block == 0: 58 | blocks.append(block(dim_in, dim_out, dim_intermediate=dim_intermediate, ks=ks, s=s)) 59 | else: 60 | blocks.append(block(dim_out, dim_out, dim_intermediate=dim_intermediate, ks=ks, s=s)) 61 | return nn.Sequential(*blocks) 62 | 63 | 64 | class Encoder(nn.Module): 65 | def __init__( 66 | self, dim_in=3, dim_out=32, num_resblk=3, 67 | use_texture_conv=True, use_motion_conv=True, texture_downsample=True, 68 | num_resblk_texture=2, num_resblk_motion=2 69 | ): 70 | super(Encoder, self).__init__() 71 | self.use_texture_conv, self.use_motion_conv = use_texture_conv, use_motion_conv 72 | 73 | self.cba_1 = Conv2D_activa(dim_in, 16, 7, 1, 3, activation='relu') 74 | self.cba_2 = Conv2D_activa(16, 32, 3, 2, 1, activation='relu') 75 | 76 | self.resblks = _repeat_blocks(ResBlk, 32, 32, num_resblk) 77 | 78 | # texture representation 79 | if self.use_texture_conv: 80 | self.texture_cba = Conv2D_activa( 81 | 32, 32, 3, (2 if texture_downsample else 1), 1, 82 | activation='relu' 83 | ) 84 | self.texture_resblks = _repeat_blocks(ResBlk, 32, dim_out, num_resblk_texture) 85 | 86 | # motion representation 87 | if self.use_motion_conv: 88 | self.motion_cba = Conv2D_activa(32, 32, 3, 1, 1, activation='relu') 89 | self.motion_resblks = _repeat_blocks(ResBlk, 32, dim_out, num_resblk_motion) 90 | 91 | def forward(self, x): 92 | x = self.cba_1(x) 93 | x = self.cba_2(x) 94 | x = self.resblks(x) 95 | 96 | if 
self.use_texture_conv: 97 | texture = self.texture_cba(x) 98 | texture = self.texture_resblks(texture) 99 | else: 100 | texture = self.texture_resblks(x) 101 | 102 | if self.use_motion_conv: 103 | motion = self.motion_cba(x) 104 | motion = self.motion_resblks(motion) 105 | else: 106 | motion = self.motion_resblks(x) 107 | 108 | return texture, motion 109 | 110 | 111 | class Decoder(nn.Module): 112 | def __init__(self, dim_in=32, dim_out=3, num_resblk=9, texture_downsample=True): 113 | super(Decoder, self).__init__() 114 | self.texture_downsample = texture_downsample 115 | 116 | if self.texture_downsample: 117 | self.texture_up = nn.UpsamplingNearest2d(scale_factor=2) 118 | # self.texture_cba = Conv2D_activa(dim_in, 32, 3, 1, 1, activation='relu') 119 | 120 | self.resblks = _repeat_blocks(ResBlk, 64, 64, num_resblk, dim_intermediate=64) 121 | self.up = nn.UpsamplingNearest2d(scale_factor=2) 122 | self.cba_1 = Conv2D_activa(64, 32, 3, 1, 1, activation='relu') 123 | self.cba_2 = Conv2D_activa(32, dim_out, 7, 1, 3, activation=None) 124 | 125 | def forward(self, texture, motion): 126 | if self.texture_downsample: 127 | texture = self.texture_up(texture) 128 | if motion.shape != texture.shape: 129 | texture = nn.functional.interpolate(texture, size=motion.shape[-2:]) 130 | x = torch.cat([texture, motion], 1) 131 | 132 | x = self.resblks(x) 133 | 134 | x = self.up(x) 135 | x = self.cba_1(x) 136 | x = self.cba_2(x) 137 | 138 | return x 139 | 140 | 141 | class Manipulator(nn.Module): 142 | def __init__(self): 143 | super(Manipulator, self).__init__() 144 | self.g = Conv2D_activa(32, 32, 3, 1, 1, activation='relu') 145 | self.h_conv = Conv2D_activa(32, 32, 3, 1, 1, activation=None) 146 | self.h_resblk = ResBlk(32, 32) 147 | 148 | def forward(self, motion_A, motion_B, amp_factor): 149 | motion = motion_B - motion_A 150 | motion_delta = self.g(motion) * amp_factor 151 | motion_delta = self.h_conv(motion_delta) 152 | motion_delta = self.h_resblk(motion_delta) 153 | motion_mag = motion_B + motion_delta 154 | return motion_mag 155 | 156 | 157 | class MagNet(nn.Module): 158 | def __init__(self): 159 | super(MagNet, self).__init__() 160 | self.encoder = Encoder(dim_in=3*1) 161 | self.manipulator = Manipulator() 162 | self.decoder = Decoder(dim_out=3*1) 163 | 164 | def forward(self, batch_A, batch_B, batch_C, batch_M, amp_factor, mode='train'): 165 | if mode == 'train': 166 | texture_A, motion_A = self.encoder(batch_A) 167 | texture_B, motion_B = self.encoder(batch_B) 168 | texture_C, motion_C = self.encoder(batch_C) 169 | texture_M, motion_M = self.encoder(batch_M) 170 | motion_mag = self.manipulator(motion_A, motion_B, amp_factor) 171 | y_hat = self.decoder(texture_B, motion_mag) 172 | texture_AC = [texture_A, texture_C] 173 | motion_BC = [motion_B, motion_C] 174 | texture_BM = [texture_B, texture_M] 175 | return y_hat, texture_AC, texture_BM, motion_BC 176 | elif mode == 'evaluate': 177 | texture_A, motion_A = self.encoder(batch_A) 178 | texture_B, motion_B = self.encoder(batch_B) 179 | motion_mag = self.manipulator(motion_A, motion_B, amp_factor) 180 | y_hat = self.decoder(texture_B, motion_mag) 181 | return y_hat 182 | 183 | 184 | def main(): 185 | model = MagNet() 186 | print('model:\n', model) 187 | 188 | 189 | if __name__ == '__main__': 190 | main() 191 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | 
import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | 10 | from config import Config 11 | from magnet import MagNet 12 | from data import get_gen_ABC, cuda2numpy 13 | from callbacks import save_model, gen_state_dict 14 | from losses import criterion_mag 15 | 16 | 17 | # Configurations 18 | config = Config() 19 | cudnn.benchmark = True 20 | 21 | magnet = MagNet().cuda() 22 | if config.pretrained_weights: 23 | magnet.load_state_dict(gen_state_dict(config.pretrained_weights)) 24 | if torch.cuda.device_count() > 1: 25 | magnet = nn.DataParallel(magnet) 26 | criterion = nn.L1Loss().cuda() 27 | 28 | optimizer = optim.Adam(magnet.parameters(), lr=config.lr, betas=config.betas) 29 | 30 | if not os.path.exists(config.save_dir): 31 | os.makedirs(config.save_dir) 32 | print('Save_dir:', config.save_dir) 33 | 34 | # Data generator 35 | data_loader = get_gen_ABC(config, mode='train') 36 | print('Number of training image couples:', data_loader.data_len) 37 | 38 | # Training 39 | for epoch in range(1, config.epochs+1): 40 | print('epoch:', epoch) 41 | losses, losses_y, losses_texture_AC, losses_texture_BM, losses_motion_BC = [], [], [], [], [] 42 | for idx_load in range(0, data_loader.data_len, data_loader.batch_size): 43 | 44 | # Data Loading 45 | batch_A, batch_B, batch_C, batch_M, batch_amp = data_loader.gen() 46 | 47 | # G Train 48 | optimizer.zero_grad() 49 | y_hat, texture_AC, texture_BM, motion_BC = magnet(batch_A, batch_B, batch_C, batch_M, batch_amp, mode='train') 50 | loss_y, loss_texture_AC, loss_texture_BM, loss_motion_BC = criterion_mag(y_hat, batch_M, texture_AC, texture_BM, motion_BC, criterion) 51 | loss = loss_y + (loss_texture_AC + loss_texture_BM + loss_motion_BC) * 0.1 52 | loss.backward() 53 | optimizer.step() 54 | 55 | # Callbacks 56 | losses.append(loss.item()) 57 | losses_y.append(loss_y.item()) 58 | losses_texture_AC.append(loss_texture_AC.item()) 59 | losses_texture_BM.append(loss_texture_BM.item()) 60 | losses_motion_BC.append(loss_motion_BC.item()) 61 | if ( 62 | idx_load > 0 and 63 | ((idx_load // data_loader.batch_size) % 64 | (data_loader.data_len // data_loader.batch_size // config.num_val_per_epoch)) == 0 65 | ): 66 | print(', {}%'.format(idx_load * 100 // data_loader.data_len), end='') 67 | 68 | # Collections 69 | save_model(magnet.state_dict(), losses, config.save_dir, epoch) 70 | print('\ntime: {}m, ep: {}, loss: {:.3e}, y: {:.3e}, tex_AC: {:.3e}, tex_BM: {:.3e}, mot_BC: {:.3e}'.format( 71 | int((time.time()-config.time_st)/60), epoch, np.mean(losses), np.mean(losses_y), np.mean(losses_texture_AC), np.mean(losses_texture_BM), np.mean(losses_motion_BC) 72 | )) 73 | -------------------------------------------------------------------------------- /make_frameACB.py: -------------------------------------------------------------------------------- 1 | """ 2 | Put it into the corresponding datasets directory, e.g. `/datasets/motion_mag_data/train/train_vid_frames` for me. 
3 | Make the original frames into frameAs, frameBs, frameCs(same as frameBs here) 4 | """ 5 | import os 6 | import sys 7 | 8 | 9 | # Choose the dir you want 10 | dirs = sorted([i for i in os.listdir('.') if i in 11 | sys.argv[1].split('+') 12 | # and int(i.split('_')[-1].split('.')[0]) > 0 13 | ] 14 | # , key=lambda x: int(x.split('_')[-1]) 15 | )[:] 16 | 17 | image_format_name='png' 18 | 19 | for d in dirs: 20 | print('ACB-Processing on', d) 21 | os.chdir(d) 22 | os.mkdir('frameA') 23 | os.mkdir('frameC') 24 | files = sorted([f for f in os.listdir('.') if os.path.splitext(f)[1] == '.{}'.format(image_format_name)], key=lambda x: int(x.split('.')[0])) 25 | os.system('cp ./*{} frameA && cp ./*{} frameC'.format(image_format_name, image_format_name)) 26 | os.remove(os.path.join('frameA', files[-1])) 27 | os.remove(os.path.join('frameC', files[0])) 28 | for f in sorted(os.listdir('frameC'), key=lambda x: int(x.split('.')[0])): 29 | f_new = os.path.join('frameC', '%06d' % (int(f.split('.')[0])-1) + '.{}'.format(image_format_name)) 30 | f = os.path.join('frameC', f) 31 | os.rename(f, f_new) 32 | os.system('cp -r frameC frameB') 33 | os.system('rm ./*.{}'.format(image_format_name)) 34 | os.chdir('..') 35 | -------------------------------------------------------------------------------- /materials/Fig2-a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/Fig2-a.png -------------------------------------------------------------------------------- /materials/baby_comp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/baby_comp.gif -------------------------------------------------------------------------------- /materials/dogs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/dogs.png -------------------------------------------------------------------------------- /materials/guitar_comp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/guitar_comp.gif -------------------------------------------------------------------------------- /materials/myself_comp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/myself_comp.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.7.0 2 | torchvision 3 | opencv-python 4 | scikit-image 5 | scikit-learn 6 | ipykernel 7 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python main.py 2 | python test_video.py baby-guitar-water-gun-drone-cattoy 3 | -------------------------------------------------------------------------------- /test_video.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import torch 5 | import numpy as np 6 | from config import Config 7 | from magnet import MagNet 8 | from data import get_gen_ABC, unit_postprocessing, numpy2cuda, resize2d 9 | from callbacks import gen_state_dict 10 | 11 | 12 | # config 13 | config = Config() 14 | # Load weights 15 | ep = '' 16 | weights_file = sorted( 17 | [p for p in os.listdir(config.save_dir) if '_loss' in p and '_epoch{}'.format(ep) in p and 'D' not in p], 18 | key=lambda x: float(x.rstrip('.pth').split('_loss')[-1]) 19 | )[0] 20 | weights_path = os.path.join(config.save_dir, weights_file) 21 | ep = int(weights_path.split('epoch')[-1].split('_')[0]) 22 | state_dict = gen_state_dict(weights_path) 23 | 24 | model_test = MagNet().cuda() 25 | model_test.load_state_dict(state_dict) 26 | model_test.eval() 27 | print("Loading weights:", weights_file) 28 | 29 | if len(sys.argv) == 1: 30 | testsets = 'baby-guitar-gun-drone-cattoy-water' 31 | else: 32 | testsets = sys.argv[-1] 33 | testsets = testsets.split('-') 34 | dir_results = config.save_dir.replace('weights', 'results') 35 | for testset in testsets: 36 | if not os.path.exists(dir_results): 37 | os.makedirs(dir_results) 38 | 39 | data_loader = get_gen_ABC(config, mode='test_on_'+testset) 40 | print('Number of test image couples:', data_loader.data_len) 41 | vid_size = cv2.imread(data_loader.paths[0]).shape[:2][::-1] 42 | 43 | # Test 44 | for amp in [5, 10, 30, 50]: 45 | frames = [] 46 | data_loader = get_gen_ABC(config, mode='test_on_'+testset) 47 | for idx_load in range(0, data_loader.data_len, data_loader.batch_size): 48 | if (idx_load+1) % 100 == 0: 49 | print('{}'.format(idx_load+1), end=', ') 50 | batch_A, batch_B = data_loader.gen_test() 51 | amp_factor = numpy2cuda(amp) 52 | for _ in range(len(batch_A.shape) - len(amp_factor.shape)): 53 | amp_factor = amp_factor.unsqueeze(-1) 54 | with torch.no_grad(): 55 | y_hats = model_test(batch_A, batch_B, 0, 0, amp_factor, mode='evaluate') 56 | for y_hat in y_hats: 57 | y_hat = unit_postprocessing(y_hat, vid_size=vid_size) 58 | frames.append(y_hat) 59 | if len(frames) >= data_loader.data_len: 60 | break 61 | if len(frames) >= data_loader.data_len: 62 | break 63 | data_loader = get_gen_ABC(config, mode='test_on_'+testset) 64 | frames = [unit_postprocessing(data_loader.gen_test()[0], vid_size=vid_size)] + frames 65 | 66 | # Make videos of framesMag 67 | video_dir = os.path.join(dir_results, testset) 68 | if not os.path.exists(video_dir): 69 | os.makedirs(video_dir) 70 | FPS = 30 71 | out = cv2.VideoWriter( 72 | os.path.join(video_dir, '{}_amp{}.avi'.format(testset, amp)), 73 | cv2.VideoWriter_fourcc(*'DIVX'), 74 | FPS, frames[0].shape[-2::-1] 75 | ) 76 | for frame in frames: 77 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 78 | cv2.putText(frame, 'amp_factor={}'.format(amp), (7, 37), 79 | fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), thickness=2) 80 | out.write(frame) 81 | out.release() 82 | print('{} has been done.'.format(os.path.join(video_dir, '{}_amp{}.avi'.format(testset, amp)))) 83 | 84 | --------------------------------------------------------------------------------