├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── VMM_learning_based_demo.ipynb
├── callbacks.py
├── config.py
├── data.py
├── losses.py
├── magnet.py
├── main.py
├── make_frameACB.py
├── materials
│   ├── Fig2-a.png
│   ├── baby_comp.gif
│   ├── dogs.png
│   ├── guitar_comp.gif
│   └── myself_comp.gif
├── requirements.txt
├── run.sh
└── test_video.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom
2 | .vscode
3 | weights*
4 | results*
5 |
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # pyenv
82 | .python-version
83 |
84 | # celery beat schedule file
85 | celerybeat-schedule
86 |
87 | # SageMath parsed files
88 | *.sage.py
89 |
90 | # Environments
91 | .env
92 | .venv
93 | env/
94 | venv/
95 | ENV/
96 | env.bak/
97 | venv.bak/
98 |
99 | # Spyder project settings
100 | .spyderproject
101 | .spyproject
102 |
103 | # Rope project settings
104 | .ropeproject
105 |
106 | # mkdocs documentation
107 | /site
108 |
109 | # mypy
110 | .mypy_cache/
111 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.pythonPath": "/home/pengzheng/miniconda3/envs/mm/bin/python"
3 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 ZhengPeng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Motion_magnification_learning-based
2 | This is an unofficial implementation of "[Learning-based Video Motion Magnification](https://arxiv.org/abs/1804.02684)" in PyTorch (1.8.1~2.0).
3 | [Here is the official implementation in TensorFlow==1.8.0](https://github.com/12dmodel/deep_motion_mag).
4 |
5 | #### Highly recommended: my friends' latest works -- come and try them!
6 | + Event-Based Motion Magnification: [[paper](https://arxiv.org/pdf/2402.11957.pdf)] [[codes](https://github.com/OpenImagingLab/emm)] [[project](https://openimaginglab.github.io/emm/)]
7 | + Frequency Decoupling for Motion Magnification via Multi-Level Isomorphic Architecture: [[paper](https://arxiv.org/pdf/2403.07347.pdf)] [[codes](https://github.com/Jiafei127/FD4MM)]
8 |
9 | # Update
10 | **(2023/11/05) Added a notebook demo for offline inference. Feel free to email me or open an issue if you need any help I can offer.**
11 |
12 | **(2023/04/07) I found that a few friends like you still have interest in this old repo, so I made a Colab demo for easy inference if you want it. And I'm sorry for my clumsy code from years ago -- it was painful to reuse for the Colab demo... and you know, some of it still exists 😂 If you have any trouble with it, feel free to open an issue or send me an e-mail.**
13 |
14 | Besides, as tested, this repo is also compatible with **PyTorch 2.x**.
15 |
16 | *Give it a video, and amplify it with only one click for all steps:*
17 |
18 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1inOucehJXUAVBlRhZvo650SoOPLKQFNv#scrollTo=BjgKRohk7Q5M)
19 |
20 |
21 |
22 |
23 |
24 | # Env
25 | ```
26 | conda create -n vmm python=3.10 -y && conda activate vmm
27 | pip install -r requirements.txt
28 | ```
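
Optionally, here is a quick sanity check that PyTorch and CUDA are visible (the training and test scripts call `.cuda()`, so a GPU is assumed):

```python
# Optional sanity check: training/testing in this repo expects a CUDA-capable GPU.
import torch

print('PyTorch:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
```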
29 |
30 | # Data preparation
31 |
32 | 0. About the synthetic dataset for **training**, please refer to the official repository mentioned above or download [here](https://drive.google.com/drive/folders/19K09QLouiV5N84wZiTPUMdoH9-UYqZrX?usp=sharing).
33 |
34 | 1. About the video datasets for **validation**, you can also download the preprocessed frames [here](https://drive.google.com/drive/folders/19K09QLouiV5N84wZiTPUMdoH9-UYqZrX?usp=sharing); the file is named train_vid_frames.zip.
35 |
36 | 2. Check the validation directory settings (`data_dir` and the `dir_*` paths) in **config.py** and modify them if necessary.
37 |
38 | 3. To convert the **validation** video into frames:
39 |
40 | `mkdir VIDEO_NAME && ffmpeg -i VIDEO_NAME.mp4 -f image2 VIDEO_NAME/%06d.png`
41 |
42 | > Tips: ffmpeg can also be installed by conda.
43 |
44 | 4. Modify the frames into **frameA/frameB/frameC**:
45 |
46 | `python make_frameACB.py VIDEO_NAME` (multiple video folder names can be joined with '+', e.g. `VIDEO_NAME_1+VIDEO_NAME_2`; adapt the directory selection at the beginning of the script if needed.)
47 |
48 | # Little differences from the official codes
49 |
50 | 1. **Poisson noise** is not used here because I was a bit confused by it in the official code, although I did implement it in data.py and checked on examples that it behaves exactly the same as the official code.
51 | 2. About the **optimizer**, we keep it the same as in the original paper -- Adam(lr=1e-4, betas=(0.9, 0.999)) with no weight decay, which is different from the official codes.
52 | 3. About the regularization weight in the **loss**, we also adhere to the original paper -- it is set to 0.1, which is different from the official codes (see the sketch after this list).
53 | 4. The **temporal filter** is still a bit confusing to me, so I haven't implemented testing with the temporal filter yet, sorry about that :(...
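
As a minimal sketch of points 2 and 3 above (mirroring main.py; the placeholder loss tensors below are for illustration only):

```python
import torch
import torch.nn as nn
import torch.optim as optim

from magnet import MagNet

# Adam as in the original paper: lr=1e-4, betas=(0.9, 0.999), no weight decay.
magnet = MagNet()
optimizer = optim.Adam(magnet.parameters(), lr=1e-4, betas=(0.9, 0.999))
criterion = nn.L1Loss()

# Placeholder loss terms, only to show the 0.1 weighting of the regularization losses.
loss_y = torch.tensor(1.0)
loss_texture_AC = loss_texture_BM = loss_motion_BC = torch.tensor(0.5)
loss = loss_y + (loss_texture_AC + loss_texture_BM + loss_motion_BC) * 0.1
```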
54 |
55 | # One thing **important**
56 |
57 | If you check the Fig.2-a in the original paper, you will find that the predicted magnified frame is actually decoded from the texture of frame B together with the magnified motion, i.e. decoder(texture_B, motion_mag), although the former is theoretically the same as decoder(texture_A, motion_mag) with the same magnified motion motion_mag.
58 |
59 | ![Fig.2-a from the original paper](materials/Fig2-a.png)
60 |
61 | However, what makes it matter is that the authors used perturbation for regularization, and the images in the given dataset have 4 parts:
62 |
63 | 1. frameA: the first frame X_a, unperturbed;
64 | 2. frameB: the perturbed frameC, which is actually X_b' in the paper;
65 | 3. frameC: the real X_b, unperturbed;
66 | 4. **amplified**: represents both the magnified frame Y and its perturbed version Y', and it is perturbed.
67 |
68 | Here is the first training sample, where you can see clearly that there is **no perturbation** between **A-C** nor between **B-amp**, and no motion between B-C:
69 |
70 | ![The first training sample](materials/dogs.png)
71 |
72 | Given that, we don't have the unperturbed amplified frame, so **we can only use the former formula** (with texture_B from the perturbed frame). Besides, if you check the **loss** in the original paper, you will find it is computed against the unperturbed magnified frame Y -- but where is that unperturbed Y?... I also referred to some third-party reproductions on this problem, which confused me a lot, but none of them solved it. And some just set the perturbation to 0 manually, so I think they noticed this problem too but didn't manage to work it out.
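
Here is a minimal sketch of the two decoding variants discussed above, with random tensors standing in for real frames (shapes are only illustrative; the module names come from magnet.py):

```python
import torch

from magnet import MagNet

# Random stand-ins for frameA and frameB (the perturbed frameC); real inputs are in [-1, 1].
model = MagNet()
batch_A = torch.randn(1, 3, 128, 128)
batch_B = torch.randn(1, 3, 128, 128)
amp_factor = torch.tensor(10.0).reshape(1, 1, 1, 1)

texture_A, motion_A = model.encoder(batch_A)
texture_B, motion_B = model.encoder(batch_B)
motion_mag = model.manipulator(motion_A, motion_B, amp_factor)

y_hat_used = model.decoder(texture_B, motion_mag)  # what this repo uses: texture from the perturbed B
y_hat_alt = model.decoder(texture_A, motion_mag)   # the theoretically equivalent variant with texture from A
```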
73 |
74 | Here are some links to the issues about this problem in the official repository, [issue-1](https://github.com/12dmodel/deep_motion_mag/issues/3), [issue-2](https://github.com/12dmodel/deep_motion_mag/issues/5), [issue-3](https://github.com/12dmodel/deep_motion_mag/issues/4), if you want to check them.
75 |
76 | # Run
77 | `bash run.sh` to train and test.
78 |
79 | It took me around 20 hours to train for 12 epochs on a single TITAN-Xp.
80 |
81 | If you don't want to use all the 100,000 groups to train, you can modify the `frames_train='coco100000'` in config.py to coco30000 or some other number.
82 |
83 | You can **download the weights**-ep12 from [the release](https://github.com/ZhengPeng7/motion_magnification_learning-based/releases/tag/v1.0), and run `python test_video.py baby-guitar-yourself-...` to do the test.
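
If you prefer to load the released weights in your own script, here is a minimal sketch of the same steps the notebook uses (GPU assumed, as in test_video.py):

```python
from magnet import MagNet
from callbacks import gen_state_dict

# Load the epoch-12 release checkpoint and switch to eval mode.
weights_path = 'magnet_epoch12_loss7.28e-02.pth'
model_test = MagNet().cuda()
model_test.load_state_dict(gen_state_dict(weights_path))
model_test.eval()
```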
84 |
85 | # Results
86 |
87 | Here are some results generated from the model trained on the whole synthetic dataset for **12** epochs.
88 |
89 | Baby, amplification factor = 50
90 |
91 | ![Baby comparison](materials/baby_comp.gif)
92 |
93 | Guitar, amplification factor = 20
94 |
95 | ![Guitar comparison](materials/guitar_comp.gif)
96 |
97 | And I also took a video of my own face with amplification factor 20, which illustrates a Chinese idiom, '夺眶而出' (literally, "bursting out of the eye sockets") 😂.
98 |
99 | ![Myself comparison](materials/myself_comp.gif)
100 |
101 | > Any questions are welcome :)
102 |
--------------------------------------------------------------------------------
/VMM_learning_based_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "bf5RrhNaDejC"
7 | },
8 | "source": [
9 | "# An offline demo for the \"Learning-based Video Motion Magnification\" (ECCV 2018)\n",
10 | "\n",
11 | "# Official repo: https://github.com/12dmodel/deep_motion_mag\n",
12 | "# My repo: https://github.com/ZhengPeng7/motion_magnification_learning-based"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {},
18 | "source": [
19 |         "## Make sure to set the Python kernel for this notebook."
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {
25 | "id": "Z7Ghq0laIGq4"
26 | },
27 | "source": [
28 | "## Preparations:\n"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {
34 | "id": "pWvGFo9yECoa"
35 | },
36 | "source": [
37 | "### Install python packages"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 1,
43 | "metadata": {
44 | "colab": {
45 | "base_uri": "https://localhost:8080/"
46 | },
47 | "id": "nI-Hq9WK5xNg",
48 | "outputId": "4c55ba46-ee51-4752-8dd1-deab4ffe704e"
49 | },
50 | "outputs": [
51 | {
52 | "name": "stdout",
53 | "output_type": "stream",
54 | "text": [
55 | "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
56 | "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
57 | "\u001b[0m"
58 | ]
59 | }
60 | ],
61 | "source": [
62 | "!pip install --quiet -r requirements.txt\n",
63 | "!pip install --quiet gdown mediapy"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {
69 | "id": "nhA7-odZEHQW"
70 | },
71 | "source": [
72 | "### Download and load the well-trained weights:"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 2,
78 | "metadata": {
79 | "colab": {
80 | "base_uri": "https://localhost:8080/"
81 | },
82 | "id": "FUX3pb77Axr0",
83 | "outputId": "11164c11-73dc-4418-8248-dfb9ab66d535"
84 | },
85 | "outputs": [
86 | {
87 | "name": "stdout",
88 | "output_type": "stream",
89 | "text": [
90 | "Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.\n",
91 | "ERROR: could not open HSTS store at '/root/.wget-hsts'. HSTS will be disabled.\n",
92 | "--2023-11-05 14:03:48-- https://github.com/ZhengPeng7/motion_magnification_learning-based/releases/download/v1.0/magnet_epoch12_loss7.28e-02.pth\n",
93 | "Resolving github.com (github.com)... 47.93.241.142, 39.106.83.44\n",
94 | "Connecting to github.com (github.com)|47.93.241.142|:443... connected.\n",
95 | "HTTP request sent, awaiting response... 302 Found\n",
96 | "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/223970988/4e613b00-aa97-11ea-87b4-bef43fcc6281?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231105%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231105T060349Z&X-Amz-Expires=300&X-Amz-Signature=054d40cd1599fce311358e5952e84073f9d390e242c658cc5a3de30142350113&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=223970988&response-content-disposition=attachment%3B%20filename%3Dmagnet_epoch12_loss7.28e-02.pth&response-content-type=application%2Foctet-stream [following]\n",
97 | "--2023-11-05 14:03:49-- https://objects.githubusercontent.com/github-production-release-asset-2e65be/223970988/4e613b00-aa97-11ea-87b4-bef43fcc6281?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231105%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231105T060349Z&X-Amz-Expires=300&X-Amz-Signature=054d40cd1599fce311358e5952e84073f9d390e242c658cc5a3de30142350113&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=223970988&response-content-disposition=attachment%3B%20filename%3Dmagnet_epoch12_loss7.28e-02.pth&response-content-type=application%2Foctet-stream\n",
98 | "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...\n",
99 | "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.109.133|:443... connected.\n",
100 | "HTTP request sent, awaiting response... 200 OK\n",
101 | "Length: 3528201 (3.4M) [application/octet-stream]\n",
102 | "Saving to: ‘magnet_epoch12_loss7.28e-02.pth’\n",
103 | "\n",
104 | "magnet_epoch12_loss 100%[===================>] 3.36M 15.3KB/s in 3m 17s \n",
105 | "\n",
106 | "2023-11-05 14:07:09 (17.5 KB/s) - ‘magnet_epoch12_loss7.28e-02.pth’ saved [3528201/3528201]\n",
107 | "\n"
108 | ]
109 | }
110 | ],
111 | "source": [
112 | "!wget https://github.com/ZhengPeng7/motion_magnification_learning-based/releases/download/v1.0/magnet_epoch12_loss7.28e-02.pth"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 3,
118 | "metadata": {
119 | "colab": {
120 | "base_uri": "https://localhost:8080/"
121 | },
122 | "id": "quAf0VARHAPM",
123 | "outputId": "7eb23fde-56e9-4797-dd0a-091b559aaec8"
124 | },
125 | "outputs": [
126 | {
127 | "name": "stdout",
128 | "output_type": "stream",
129 | "text": [
130 | "Please load train_mf.txt if you want to do training.\n",
131 | "Loading weights: magnet_epoch12_loss7.28e-02.pth\n"
132 | ]
133 | }
134 | ],
135 | "source": [
136 | "from magnet import MagNet\n",
137 | "from callbacks import gen_state_dict\n",
138 | "from config import Config\n",
139 | "\n",
140 | "\n",
141 | "# config\n",
142 | "config = Config()\n",
143 | "# Load weights\n",
144 | "weights_path = 'magnet_epoch12_loss7.28e-02.pth'\n",
145 | "ep = int(weights_path.split('epoch')[-1].split('_')[0])\n",
146 | "state_dict = gen_state_dict(weights_path)\n",
147 | "\n",
148 | "model_test = MagNet().cuda()\n",
149 | "model_test.load_state_dict(state_dict)\n",
150 | "model_test.eval()\n",
151 | "print(\"Loading weights:\", weights_path)"
152 | ]
153 | },
154 | {
155 | "cell_type": "markdown",
156 | "metadata": {
157 | "id": "_AM4n9P-iaNG"
158 | },
159 | "source": [
160 | "# Preprocess\n",
161 | "\n",
162 | "Make the video to frameAs/frameBs/frameCs.\n",
163 | "\n",
164 | "Let's take the guitar.mp4 as an example."
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "metadata": {
171 | "colab": {
172 | "base_uri": "https://localhost:8080/"
173 | },
174 | "id": "tnJSrX43vlgy",
175 | "outputId": "bb3fec89-a7f1-4246-a4c7-7f08548ca905"
176 | },
177 | "outputs": [],
178 | "source": [
179 | "# Download some example video from my google-drive or upload your own video.\n",
180 | "# !gdown 1hNZ02vnSO04FYS9jkx2OjProYHIvwdkB # guitar.avi\n",
181 | "!gdown 1XGC2y4Lshd9aBiBxwkTuT_IA79n-WNST # baby.avi\n",
182 | "# !gdown 1QGOWuR0swF7_eHharTztlkEDz0hlfmU4 # zhiyin.mp4"
183 | ]
184 | },
185 | {
186 | "cell_type": "markdown",
187 | "metadata": {
188 | "id": "pd9lV4mCHU9v"
189 | },
190 | "source": [
191 | "# Set VIDEO_NAME here, e.g., guitar, baby"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 5,
197 | "metadata": {
198 | "colab": {
199 | "base_uri": "https://localhost:8080/"
200 | },
201 | "id": "I0v0Uacfmtib",
202 | "outputId": "01383531-8a69-42f5-8b7c-a280acff41c8"
203 | },
204 | "outputs": [
205 | {
206 | "name": "stdout",
207 | "output_type": "stream",
208 | "text": [
209 | "ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers\n",
210 | " built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)\n",
211 | " configuration: --prefix=/opt/conda/conda-bld/ffmpeg_1597178665428/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame\n",
212 | " libavutil 56. 51.100 / 56. 51.100\n",
213 | " libavcodec 58. 91.100 / 58. 91.100\n",
214 | " libavformat 58. 45.100 / 58. 45.100\n",
215 | " libavdevice 58. 10.100 / 58. 10.100\n",
216 | " libavfilter 7. 85.100 / 7. 85.100\n",
217 | " libavresample 4. 0. 0 / 4. 0. 0\n",
218 | " libswscale 5. 7.100 / 5. 7.100\n",
219 | " libswresample 3. 7.100 / 3. 7.100\n",
220 | "Input #0, avi, from 'baby.avi':\n",
221 | " Duration: 00:00:10.03, start: 0.000000, bitrate: 15411 kb/s\n",
222 | " Stream #0:0: Video: mjpeg (Baseline) (MJPG / 0x47504A4D), yuvj420p(pc, bt470bg/unknown/unknown), 960x544 [SAR 1:1 DAR 30:17], 15399 kb/s, 30 fps, 30 tbr, 30 tbn, 30 tbc\n",
223 | "Stream mapping:\n",
224 | " Stream #0:0 -> #0:0 (mjpeg (native) -> png (native))\n",
225 | "Press [q] to stop, [?] for help\n",
226 | "\u001b[1;34m[swscaler @ 0x561bdb17b380] \u001b[0m\u001b[0;33mdeprecated pixel format used, make sure you did set range correctly\n",
227 | "\u001b[0mOutput #0, image2, to 'baby/%06d.png':\n",
228 | " Metadata:\n",
229 | " encoder : Lavf58.45.100\n",
230 | " Stream #0:0: Video: png, rgb24, 960x544 [SAR 1:1 DAR 30:17], q=2-31, 200 kb/s, 30 fps, 30 tbn, 30 tbc\n",
231 | " Metadata:\n",
232 | " encoder : Lavc58.91.100 png\n",
233 | "frame= 301 fps= 48 q=-0.0 Lsize=N/A time=00:00:10.03 bitrate=N/A speed= 1.6x \n",
234 | "video:187776kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: unknown\n",
235 | "ACB-Processing on baby\n"
236 | ]
237 | }
238 | ],
239 | "source": [
240 | "# Turn the video into frames and make them into frame_ACB format.\n",
241 | "file_to_be_maged = 'baby.avi' # 'guitar.avi'\n",
242 | "video_name = file_to_be_maged.split('.')[0]\n",
243 | "video_format = '.' + file_to_be_maged.split('.')[-1]\n",
244 | "\n",
245 | "\n",
246 | "sh_file = 'VIDEO_NAME={}\\nVIDEO_FORMAT={}'.format(video_name, video_format) + \"\"\"\n",
247 | "\n",
248 | "\n",
249 | "mkdir ${VIDEO_NAME}\n",
250 | "ffmpeg -i ${VIDEO_NAME}${VIDEO_FORMAT} -f image2 ${VIDEO_NAME}/%06d.png\n",
251 | "python make_frameACB.py ${VIDEO_NAME}\n",
252 | "mkdir test_dir\n",
253 | "mv ${VIDEO_NAME} test_dir\n",
254 | "\"\"\"\n",
255 | "with open('test_preproc.sh', 'w') as file:\n",
256 | " file.write(sh_file)\n",
257 | "\n",
258 | "!bash test_preproc.sh"
259 | ]
260 | },
261 | {
262 | "cell_type": "markdown",
263 | "metadata": {
264 | "id": "OeOYfifiGUPq"
265 | },
266 | "source": [
267 | "## Test\n"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 6,
273 | "metadata": {
274 | "colab": {
275 | "base_uri": "https://localhost:8080/"
276 | },
277 | "id": "RXz3mYanz07-",
278 | "outputId": "9d526b43-6f76-404c-fa91-235ea6b5ee72"
279 | },
280 | "outputs": [
281 | {
282 | "name": "stdout",
283 | "output_type": "stream",
284 | "text": [
285 | "Number of test image couples: 300\n",
286 | "100, 200, 300, res_baby/baby/baby_amp10.avi has been done.\n",
287 | "100, 200, 300, res_baby/baby/baby_amp25.avi has been done.\n",
288 | "100, 200, 300, res_baby/baby/baby_amp50.avi has been done.\n"
289 | ]
290 | }
291 | ],
292 | "source": [
293 | "import os\n",
294 | "import sys\n",
295 | "import cv2\n",
296 | "import torch\n",
297 | "import numpy as np\n",
298 | "from data import get_gen_ABC, unit_postprocessing, numpy2cuda, resize2d\n",
299 | "\n",
300 | "\n",
301 | "# testsets = '+'.join(['baby', 'guitar', 'gun', 'drone', 'cattoy', 'water'][:])\n",
302 | "for testset in [video_name]:\n",
303 | " dir_results = 'res_' + testset\n",
304 | " if not os.path.exists(dir_results):\n",
305 | " os.makedirs(dir_results)\n",
306 | "\n",
307 | " config.data_dir = 'test_dir'\n",
308 | " data_loader = get_gen_ABC(config, mode='test_on_'+testset)\n",
309 | " print('Number of test image couples:', data_loader.data_len)\n",
310 | " vid_size = cv2.imread(data_loader.paths[0]).shape[:2][::-1]\n",
311 | "\n",
312 | " # Test\n",
313 | " for amp in [10, 25, 50]:\n",
314 | " frames = []\n",
315 | " data_loader = get_gen_ABC(config, mode='test_on_'+testset)\n",
316 | " for idx_load in range(0, data_loader.data_len, data_loader.batch_size):\n",
317 | " if (idx_load+1) % 100 == 0:\n",
318 | " print('{}'.format(idx_load+1), end=', ')\n",
319 | " batch_A, batch_B = data_loader.gen_test()\n",
320 | " amp_factor = numpy2cuda(amp)\n",
321 | " for _ in range(len(batch_A.shape) - len(amp_factor.shape)):\n",
322 | " amp_factor = amp_factor.unsqueeze(-1)\n",
323 | " with torch.no_grad():\n",
324 | " y_hats = model_test(batch_A, batch_B, 0, 0, amp_factor, mode='evaluate')\n",
325 | " for y_hat in y_hats:\n",
326 | " y_hat = unit_postprocessing(y_hat, vid_size=vid_size)\n",
327 | " frames.append(y_hat)\n",
328 | " if len(frames) >= data_loader.data_len:\n",
329 | " break\n",
330 | " if len(frames) >= data_loader.data_len:\n",
331 | " break\n",
332 | " data_loader = get_gen_ABC(config, mode='test_on_'+testset)\n",
333 | " frames = [unit_postprocessing(data_loader.gen_test()[0], vid_size=vid_size)] + frames\n",
334 | "\n",
335 | " # Make videos of framesMag\n",
336 | " video_dir = os.path.join(dir_results, testset)\n",
337 | " if not os.path.exists(video_dir):\n",
338 | " os.makedirs(video_dir)\n",
339 | " FPS = 30\n",
340 | " video_save_path = os.path.join(video_dir, '{}_amp{}{}'.format(testset, amp, video_format))\n",
341 | " out = cv2.VideoWriter(\n",
342 | " video_save_path,\n",
343 | " cv2.VideoWriter_fourcc(*'DIVX'),\n",
344 | " FPS, frames[0].shape[-2::-1]\n",
345 | " )\n",
346 | " for frame in frames:\n",
347 | " frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n",
348 | " cv2.putText(frame, 'amp_factor={}'.format(amp), (7, 37),\n",
349 | " fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), thickness=2)\n",
350 | " out.write(frame)\n",
351 | " out.release()\n",
352 | " print('{} has been done.'.format(video_save_path))"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": null,
358 | "metadata": {},
359 | "outputs": [],
360 | "source": [
361 | "# Play the amplified video here\n",
362 | "from glob import glob\n",
363 | "import mediapy\n",
364 | "\n",
365 | "\n",
366 | "video_save_paths = [file_to_be_maged] + sorted(glob(os.path.join(dir_results, testset, '*')), key=lambda x: int(x.split('amp')[-1].split('.')[0]))\n",
367 | "\n",
368 | "video_dict = {}\n",
369 | "for video_save_path in video_save_paths[:]:\n",
370 | " video_dict[video_save_path.split('/')[-1]] = mediapy.read_video(video_save_path)\n",
371 | "mediapy.show_videos(video_dict, fps=FPS, width=250, codec='gif')"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "metadata": {},
378 | "outputs": [],
379 | "source": []
380 | }
381 | ],
382 | "metadata": {
383 | "accelerator": "GPU",
384 | "colab": {
385 | "provenance": []
386 | },
387 | "kernelspec": {
388 | "display_name": "Python 3 (ipykernel)",
389 | "language": "python",
390 | "name": "python3"
391 | },
392 | "language_info": {
393 | "codemirror_mode": {
394 | "name": "ipython",
395 | "version": 3
396 | },
397 | "file_extension": ".py",
398 | "mimetype": "text/x-python",
399 | "name": "python",
400 | "nbconvert_exporter": "python",
401 | "pygments_lexer": "ipython3",
402 | "version": "3.10.13"
403 | }
404 | },
405 | "nbformat": 4,
406 | "nbformat_minor": 4
407 | }
408 |
--------------------------------------------------------------------------------
/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def save_model(weights, losses, weights_dir, epoch):
7 | loss = np.mean(losses)
8 | path_ckpt = os.path.join(
9 | weights_dir, 'magnet_epoch{}_loss{:.2e}.pth'.format(epoch, loss)
10 | )
11 | torch.save(weights, path_ckpt)
12 |
13 |
14 | def gen_state_dict(weights_path):
15 | st = torch.load(weights_path)
16 | st_ks = list(st.keys())
17 | st_vs = list(st.values())
18 | state_dict = {}
19 | for st_k, st_v in zip(st_ks, st_vs):
20 | state_dict[st_k.replace('module.', '')] = st_v
21 | return state_dict
22 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import torch
5 |
6 | class Config(object):
7 | def __init__(self):
8 |
9 | # General
10 | self.epochs = 12
11 | # self.GPUs = '0'
12 | self.batch_size = 6 # * torch.cuda.device_count() # len(self.GPUs.split(','))
13 | self.date = '0510'
14 |
15 | # Data
16 | self.data_dir = '../../../datasets/mm'
17 | self.dir_train = os.path.join(self.data_dir, 'train')
18 | self.dir_test = os.path.join(self.data_dir, 'test')
19 | self.dir_water = os.path.join(self.data_dir, 'train/train_vid_frames/val_water')
20 | self.dir_baby = os.path.join(self.data_dir, 'train/train_vid_frames/val_baby')
21 | self.dir_gun = os.path.join(self.data_dir, 'train/train_vid_frames/val_gun')
22 | self.dir_drone = os.path.join(self.data_dir, 'train/train_vid_frames/val_drone')
23 | self.dir_guitar = os.path.join(self.data_dir, 'train/train_vid_frames/val_guitar')
24 | self.dir_cattoy = os.path.join(self.data_dir, 'train/train_vid_frames/val_cattoy')
25 | self.dir_myself = os.path.join(self.data_dir, 'train/train_vid_frames/myself')
26 | self.frames_train = 'coco100000' # you can adapt 100000 to a smaller number to train
27 | self.cursor_end = int(self.frames_train.split('coco')[-1])
28 | if os.path.exists(os.path.join(self.dir_train, 'train_mf.txt')):
29 | self.coco_amp_lst = np.loadtxt(os.path.join(self.dir_train, 'train_mf.txt'))[:self.cursor_end]
30 | else:
31 | print('Please load train_mf.txt if you want to do training.')
32 | self.coco_amp_lst = None
33 | self.videos_train = []
34 |         self.load_all = False  # Don't turn it on unless you have a really large memory:
35 |                                # on the COCO dataset, 100,000 sets take ~850 GB.
36 |
37 | # Training
38 | self.lr = 1e-4
39 | self.betas = (0.9, 0.999)
40 | self.batch_size_test = 1
41 | self.preproc = ['poisson'] # ['resize', ]
42 | self.pretrained_weights = ''
43 |
44 | # Callbacks
45 | self.num_val_per_epoch = 10
46 | self.save_dir = 'weights_date{}'.format(self.date)
47 | self.time_st = time.time()
48 | self.losses = []
49 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from sklearn.utils import shuffle
3 | import cv2
4 | from skimage.io import imread
5 | from skimage.util import random_noise
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 | from PIL import Image, ImageFile
11 |
12 |
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 | Tensor = torch.cuda.FloatTensor
15 |
16 |
17 | def gen_poisson_noise(unit):
18 | n = np.random.randn(*unit.shape)
19 |
20 |     # Strange here: unit is already in the range (-1, 1),
21 |     # but this was checked on examples to match the official codes.
22 | n_str = np.sqrt(unit + 1.0) / np.sqrt(127.5)
23 | poisson_noise = np.multiply(n, n_str)
24 | return poisson_noise
25 |
26 |
27 | def load_unit(path, inf_size=(0, 0)):
28 | # Load
29 | file_suffix = os.path.splitext(path)[1].lower()
30 | if file_suffix in ['.jpg', '.png']:
31 | try:
32 | image = imread(path).astype(np.uint8)
33 | except Exception as e:
34 | print('{} load exception:\n'.format(path), e)
35 | image = np.array(Image.open(path).convert('RGB'))
36 | if inf_size != (0, 0):
37 | image = cv2.resize(image, inf_size, interpolation=cv2.INTER_LANCZOS4)
38 | unit = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
39 | return unit
40 | else:
41 | print('Unsupported file type.')
42 | return None
43 |
44 | def unit_preprocessing(unit, preproc=[], is_test=False):
45 | # Preprocessing
46 | if 'BF' in preproc and is_test:
47 | unit = cv2.bilateralFilter(unit, 9, 75, 75)
48 | if 'resize' in preproc:
49 | unit = cv2.resize(unit, (384, 384), interpolation=cv2.INTER_LANCZOS4)
50 | elif 'downsample' in preproc:
51 |         unit = cv2.resize(unit, (unit.shape[1]//2, unit.shape[0]//2), interpolation=cv2.INTER_LANCZOS4)
52 |
53 | unit = cv2.cvtColor(unit, cv2.COLOR_BGR2RGB)
54 | try:
55 | if 'poisson' in preproc:
56 | # Use poisson noise from official repo or skimage?
57 |
58 | # unit = unit + gen_poisson_noise(unit) * np.random.uniform(0, 0.3)
59 |
60 | unit = random_noise(unit, mode='poisson') # unit: 0 ~ 1
61 | unit = unit * 255
62 | except Exception as e:
63 | print('EX:', e, unit.shape, unit.dtype)
64 |
65 | unit = unit / 127.5 - 1.0
66 |
67 | unit = np.transpose(unit, (2, 0, 1))
68 | return unit
69 |
70 |
71 | def unit_postprocessing(unit, vid_size=None):
72 | unit = unit.squeeze()
73 | unit = unit.cpu().detach().numpy()
74 | unit = np.clip(unit, -1, 1)
75 | unit = np.round((np.transpose(unit, (1, 2, 0)) + 1.0) * 127.5).astype(np.uint8)
76 |     if vid_size is not None and unit.shape[:2][::-1] != vid_size:
77 | unit = cv2.resize(unit, vid_size, interpolation=cv2.INTER_CUBIC)
78 | return unit
79 |
80 |
81 | def get_paths_ABC(config, mode):
82 | if mode in ('train', 'test_on_trainset'):
83 | dir_root = config.dir_train
84 | elif mode == 'test_on_testset':
85 | dir_root = config.dir_test
86 | else:
87 | val_vid = '_'.join(mode.split('_')[2:])
88 | try:
89 |             dir_root = getattr(config, 'dir_{}'.format(val_vid))
90 | if not os.path.exists(dir_root):
91 | dir_root = os.path.join(config.data_dir, val_vid)
92 | except:
93 | dir_root = os.path.join(config.data_dir, val_vid)
94 | if not os.path.exists(dir_root):
95 | print('Cannot find data at {}.\nExiting the program...'.format(dir_root))
96 | exit()
97 | paths_A, paths_C, paths_skip_intermediate, paths_skip, paths_mag = [], [], [], [], []
98 | if config.cursor_end > 0 or 'test' in mode:
99 | dir_A = os.path.join(dir_root, 'frameA')
100 | files_A = sorted(os.listdir(dir_A), key=lambda x: int(x.split('.')[0]))
101 | paths_A = [os.path.join(dir_A, file_A) for file_A in files_A]
102 | if mode == 'train' and isinstance(config.cursor_end, int):
103 | paths_A = paths_A[:config.cursor_end]
104 | paths_C = [p.replace('frameA', 'frameC') for p in paths_A]
105 | paths_mag = [p.replace('frameA', 'amplified') for p in paths_A]
106 | else:
107 | paths_A, paths_C, paths_skip_intermediate, paths_skip = [], [], [], []
108 | if 'test' not in mode:
109 | path_vids = os.path.join(config.dir_train, 'train_vid_frames')
110 | dirs_vid = [os.path.join(path_vids, p, 'frameA') for p in config.videos_train]
111 | for dir_vid in dirs_vid[:len(config.videos_train)]:
112 | vid_frames = [
113 | os.path.join(dir_vid, p) for p in sorted(
114 | os.listdir(dir_vid), key=lambda x: int(x.split('.')[0])
115 | )]
116 | if config.skip < 0:
117 | lst = [p.replace('frameA', 'frameC') for p in vid_frames]
118 | for idx, _ in enumerate(lst):
119 | skip_rand = np.random.randint(min(-config.skip, 2), -config.skip+1)
120 | idx_skip = min(idx + skip_rand, len(lst) - 1)
121 | paths_skip.append(lst[idx_skip])
122 | paths_skip_intermediate.append(lst[idx_skip//2])
123 | paths_A += vid_frames
124 | paths_C = [p.replace('frameA', 'frameC') for p in paths_A]
125 | paths_B = [p.replace('frameC', 'frameB') for p in paths_C]
126 | return paths_A, paths_B, paths_C, paths_skip, paths_skip_intermediate, paths_mag
127 |
128 |
129 | class DataGen():
130 | def __init__(self, paths, config, mode):
131 | self.is_train = 'test' not in mode
132 | self.anchor = 0
133 | self.paths = paths
134 | self.batch_size = config.batch_size if self.is_train else config.batch_size_test
135 | self.data_len = len(paths)
136 | self.load_all = config.load_all
137 | self.data = []
138 | self.preproc = config.preproc
139 | self.coco_amp_lst = config.coco_amp_lst
140 |
141 | if self.is_train and self.load_all:
142 | self.units_A, self.units_C, self.units_M, self.units_B = [], [], [], []
143 | for idx_data in range(self.data_len):
144 | if idx_data % 500 == 0:
145 | print('Processing {} / {}.'.format(idx_data, self.data_len))
146 | unit_A = load_unit(self.paths[idx_data])
147 | unit_C = load_unit(self.paths[idx_data].replace('frameA', 'frameC'))
148 | unit_M = load_unit(self.paths[idx_data].replace('frameA', 'amplified'))
149 | unit_B = load_unit(self.paths[idx_data].replace('frameA', 'frameB'))
150 | unit_A = unit_preprocessing(unit_A, preproc=self.preproc)
151 | unit_C = unit_preprocessing(unit_C, preproc=self.preproc)
152 | unit_M = unit_preprocessing(unit_M, preproc=[])
153 | unit_B = unit_preprocessing(unit_B, preproc=self.preproc)
154 | self.units_A.append(unit_A)
155 | self.units_C.append(unit_C)
156 | self.units_M.append(unit_M)
157 | self.units_B.append(unit_B)
158 |
159 | def gen(self, anchor=None):
160 | batch_A = []
161 | batch_C = []
162 | batch_M = []
163 | batch_B = []
164 | batch_amp = []
165 | if anchor is None:
166 | anchor = self.anchor
167 |
168 | for _ in range(self.batch_size):
169 | if not self.load_all:
170 | unit_A = load_unit(self.paths[anchor])
171 | unit_C = load_unit(self.paths[anchor].replace('frameA', 'frameC'))
172 | unit_M = load_unit(self.paths[anchor].replace('frameA', 'amplified'))
173 | unit_B = load_unit(self.paths[anchor].replace('frameA', 'frameB'))
174 | unit_A = unit_preprocessing(unit_A, preproc=self.preproc)
175 | unit_C = unit_preprocessing(unit_C, preproc=self.preproc)
176 | unit_M = unit_preprocessing(unit_M, preproc=[])
177 | unit_B = unit_preprocessing(unit_B, preproc=self.preproc)
178 | else:
179 | unit_A = self.units_A[anchor]
180 | unit_C = self.units_C[anchor]
181 | unit_M = self.units_M[anchor]
182 | unit_B = self.units_B[anchor]
183 | unit_amp = self.coco_amp_lst[anchor]
184 |
185 | batch_A.append(unit_A)
186 | batch_C.append(unit_C)
187 | batch_M.append(unit_M)
188 | batch_B.append(unit_B)
189 | batch_amp.append(unit_amp)
190 |
191 | self.anchor = (self.anchor + 1) % self.data_len
192 |
193 | batch_A = numpy2cuda(batch_A)
194 | batch_C = numpy2cuda(batch_C)
195 | batch_M = numpy2cuda(batch_M)
196 | batch_B = numpy2cuda(batch_B)
197 | batch_amp = numpy2cuda(batch_amp).reshape(self.batch_size, 1, 1, 1)
198 | return batch_A, batch_B, batch_C, batch_M, batch_amp
199 |
200 | def gen_test(self, anchor=None, inf_size=(0, 0)):
201 | batch_A = []
202 | batch_C = []
203 | if anchor is None:
204 | anchor = self.anchor
205 |
206 | for _ in range(self.batch_size):
207 | unit_A = load_unit(self.paths[anchor], inf_size=inf_size)
208 | unit_C = load_unit(self.paths[anchor].replace('frameA', 'frameC'), inf_size=inf_size)
209 | unit_A = unit_preprocessing(unit_A, preproc=[], is_test=True)
210 | unit_C = unit_preprocessing(unit_C, preproc=[], is_test=True)
211 | batch_A.append(unit_A)
212 | batch_C.append(unit_C)
213 |
214 | self.anchor = (self.anchor + 1) % self.data_len
215 |
216 | batch_A = numpy2cuda(batch_A)
217 | batch_C = numpy2cuda(batch_C)
218 | return batch_A, batch_C
219 |
220 |
221 | def get_gen_ABC(config, mode='train'):
222 | paths_A = get_paths_ABC(config, mode)[0]
223 | gen_train_A = DataGen(paths_A, config, mode)
224 | return gen_train_A
225 |
226 |
227 | def cuda2numpy(tensor):
228 | array = tensor.detach().cpu().squeeze().numpy()
229 | return array
230 |
231 |
232 | def numpy2cuda(array):
233 | tensor = torch.from_numpy(np.asarray(array)).float().cuda()
234 | return tensor
235 |
236 |
237 | def resize2d(img, size):
238 |     with torch.no_grad():
239 |         img_resized = F.adaptive_avg_pool2d(img, size)
240 |     return img_resized
241 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.autograd as ag
5 |
6 |
7 | def criterion_mag(y, batch_M, texture_AC, texture_BM, motion_BC, criterion):
8 |     # Worth mentioning: the amplified frames given in the dataset are actually the perturbed Y (i.e. Y'), which I denote by M here.
9 | loss_y = criterion(y, batch_M)
10 | loss_texture_AC = criterion(*texture_AC)
11 | loss_texture_BM = criterion(*texture_BM)
12 | loss_motion_BC = criterion(*motion_BC)
13 | return loss_y, loss_texture_AC, loss_texture_BM, loss_motion_BC
14 |
--------------------------------------------------------------------------------
/magnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from data import numpy2cuda
4 |
5 |
6 | def truncated_normal_(tensor, mean=0, std=1):
7 | size = tensor.shape
8 | tmp = tensor.new_empty(size + (4,)).normal_()
9 | valid = (tmp < 2) & (tmp > -2)
10 | ind = valid.max(-1, keepdim=True)[1]
11 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
12 | tensor.data.mul_(std).add_(mean)
13 | return tensor
14 |
15 |
16 | class Conv2D_activa(nn.Module):
17 | def __init__(
18 | self, in_channels, out_channels, kernel_size, stride,
19 | padding=0, dilation=1, activation='relu'
20 | ):
21 | super(Conv2D_activa, self).__init__()
22 | self.padding = padding
23 | if self.padding:
24 | self.pad = nn.ReflectionPad2d(padding)
25 | self.conv2d = nn.Conv2d(
26 | in_channels, out_channels, kernel_size, stride,
27 |             dilation=dilation, bias=False
28 | )
29 | self.activation = activation
30 | if activation == 'relu':
31 | self.activation = nn.ReLU()
32 |
33 | def forward(self, x):
34 | if self.padding:
35 | x = self.pad(x)
36 | x = self.conv2d(x)
37 | if self.activation:
38 | x = self.activation(x)
39 | return x
40 |
41 | class ResBlk(nn.Module):
42 | def __init__(self, dim_in, dim_out, dim_intermediate=32, ks=3, s=1):
43 | super(ResBlk, self).__init__()
44 | p = (ks - 1) // 2
45 | self.cba_1 = Conv2D_activa(dim_in, dim_intermediate, ks, s, p, activation='relu')
46 | self.cba_2 = Conv2D_activa(dim_intermediate, dim_out, ks, s, p, activation=None)
47 |
48 | def forward(self, x):
49 | y = self.cba_1(x)
50 | y = self.cba_2(y)
51 | return y + x
52 |
53 |
54 | def _repeat_blocks(block, dim_in, dim_out, num_blocks, dim_intermediate=32, ks=3, s=1):
55 | blocks = []
56 | for idx_block in range(num_blocks):
57 | if idx_block == 0:
58 | blocks.append(block(dim_in, dim_out, dim_intermediate=dim_intermediate, ks=ks, s=s))
59 | else:
60 | blocks.append(block(dim_out, dim_out, dim_intermediate=dim_intermediate, ks=ks, s=s))
61 | return nn.Sequential(*blocks)
62 |
63 |
64 | class Encoder(nn.Module):
65 | def __init__(
66 | self, dim_in=3, dim_out=32, num_resblk=3,
67 | use_texture_conv=True, use_motion_conv=True, texture_downsample=True,
68 | num_resblk_texture=2, num_resblk_motion=2
69 | ):
70 | super(Encoder, self).__init__()
71 | self.use_texture_conv, self.use_motion_conv = use_texture_conv, use_motion_conv
72 |
73 | self.cba_1 = Conv2D_activa(dim_in, 16, 7, 1, 3, activation='relu')
74 | self.cba_2 = Conv2D_activa(16, 32, 3, 2, 1, activation='relu')
75 |
76 | self.resblks = _repeat_blocks(ResBlk, 32, 32, num_resblk)
77 |
78 | # texture representation
79 | if self.use_texture_conv:
80 | self.texture_cba = Conv2D_activa(
81 | 32, 32, 3, (2 if texture_downsample else 1), 1,
82 | activation='relu'
83 | )
84 | self.texture_resblks = _repeat_blocks(ResBlk, 32, dim_out, num_resblk_texture)
85 |
86 | # motion representation
87 | if self.use_motion_conv:
88 | self.motion_cba = Conv2D_activa(32, 32, 3, 1, 1, activation='relu')
89 | self.motion_resblks = _repeat_blocks(ResBlk, 32, dim_out, num_resblk_motion)
90 |
91 | def forward(self, x):
92 | x = self.cba_1(x)
93 | x = self.cba_2(x)
94 | x = self.resblks(x)
95 |
96 | if self.use_texture_conv:
97 | texture = self.texture_cba(x)
98 | texture = self.texture_resblks(texture)
99 | else:
100 | texture = self.texture_resblks(x)
101 |
102 | if self.use_motion_conv:
103 | motion = self.motion_cba(x)
104 | motion = self.motion_resblks(motion)
105 | else:
106 | motion = self.motion_resblks(x)
107 |
108 | return texture, motion
109 |
110 |
111 | class Decoder(nn.Module):
112 | def __init__(self, dim_in=32, dim_out=3, num_resblk=9, texture_downsample=True):
113 | super(Decoder, self).__init__()
114 | self.texture_downsample = texture_downsample
115 |
116 | if self.texture_downsample:
117 | self.texture_up = nn.UpsamplingNearest2d(scale_factor=2)
118 | # self.texture_cba = Conv2D_activa(dim_in, 32, 3, 1, 1, activation='relu')
119 |
120 | self.resblks = _repeat_blocks(ResBlk, 64, 64, num_resblk, dim_intermediate=64)
121 | self.up = nn.UpsamplingNearest2d(scale_factor=2)
122 | self.cba_1 = Conv2D_activa(64, 32, 3, 1, 1, activation='relu')
123 | self.cba_2 = Conv2D_activa(32, dim_out, 7, 1, 3, activation=None)
124 |
125 | def forward(self, texture, motion):
126 | if self.texture_downsample:
127 | texture = self.texture_up(texture)
128 | if motion.shape != texture.shape:
129 | texture = nn.functional.interpolate(texture, size=motion.shape[-2:])
130 | x = torch.cat([texture, motion], 1)
131 |
132 | x = self.resblks(x)
133 |
134 | x = self.up(x)
135 | x = self.cba_1(x)
136 | x = self.cba_2(x)
137 |
138 | return x
139 |
140 |
141 | class Manipulator(nn.Module):
142 | def __init__(self):
143 | super(Manipulator, self).__init__()
144 | self.g = Conv2D_activa(32, 32, 3, 1, 1, activation='relu')
145 | self.h_conv = Conv2D_activa(32, 32, 3, 1, 1, activation=None)
146 | self.h_resblk = ResBlk(32, 32)
147 |
148 | def forward(self, motion_A, motion_B, amp_factor):
149 | motion = motion_B - motion_A
150 | motion_delta = self.g(motion) * amp_factor
151 | motion_delta = self.h_conv(motion_delta)
152 | motion_delta = self.h_resblk(motion_delta)
153 | motion_mag = motion_B + motion_delta
154 | return motion_mag
155 |
156 |
157 | class MagNet(nn.Module):
158 | def __init__(self):
159 | super(MagNet, self).__init__()
160 | self.encoder = Encoder(dim_in=3*1)
161 | self.manipulator = Manipulator()
162 | self.decoder = Decoder(dim_out=3*1)
163 |
164 | def forward(self, batch_A, batch_B, batch_C, batch_M, amp_factor, mode='train'):
165 | if mode == 'train':
166 | texture_A, motion_A = self.encoder(batch_A)
167 | texture_B, motion_B = self.encoder(batch_B)
168 | texture_C, motion_C = self.encoder(batch_C)
169 | texture_M, motion_M = self.encoder(batch_M)
170 | motion_mag = self.manipulator(motion_A, motion_B, amp_factor)
171 | y_hat = self.decoder(texture_B, motion_mag)
172 | texture_AC = [texture_A, texture_C]
173 | motion_BC = [motion_B, motion_C]
174 | texture_BM = [texture_B, texture_M]
175 | return y_hat, texture_AC, texture_BM, motion_BC
176 | elif mode == 'evaluate':
177 | texture_A, motion_A = self.encoder(batch_A)
178 | texture_B, motion_B = self.encoder(batch_B)
179 | motion_mag = self.manipulator(motion_A, motion_B, amp_factor)
180 | y_hat = self.decoder(texture_B, motion_mag)
181 | return y_hat
182 |
183 |
184 | def main():
185 | model = MagNet()
186 | print('model:\n', model)
187 |
188 |
189 | if __name__ == '__main__':
190 | main()
191 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.backends.cudnn as cudnn
9 |
10 | from config import Config
11 | from magnet import MagNet
12 | from data import get_gen_ABC, cuda2numpy
13 | from callbacks import save_model, gen_state_dict
14 | from losses import criterion_mag
15 |
16 |
17 | # Configurations
18 | config = Config()
19 | cudnn.benchmark = True
20 |
21 | magnet = MagNet().cuda()
22 | if config.pretrained_weights:
23 | magnet.load_state_dict(gen_state_dict(config.pretrained_weights))
24 | if torch.cuda.device_count() > 1:
25 | magnet = nn.DataParallel(magnet)
26 | criterion = nn.L1Loss().cuda()
27 |
28 | optimizer = optim.Adam(magnet.parameters(), lr=config.lr, betas=config.betas)
29 |
30 | if not os.path.exists(config.save_dir):
31 | os.makedirs(config.save_dir)
32 | print('Save_dir:', config.save_dir)
33 |
34 | # Data generator
35 | data_loader = get_gen_ABC(config, mode='train')
36 | print('Number of training image couples:', data_loader.data_len)
37 |
38 | # Training
39 | for epoch in range(1, config.epochs+1):
40 | print('epoch:', epoch)
41 | losses, losses_y, losses_texture_AC, losses_texture_BM, losses_motion_BC = [], [], [], [], []
42 | for idx_load in range(0, data_loader.data_len, data_loader.batch_size):
43 |
44 | # Data Loading
45 | batch_A, batch_B, batch_C, batch_M, batch_amp = data_loader.gen()
46 |
47 | # G Train
48 | optimizer.zero_grad()
49 | y_hat, texture_AC, texture_BM, motion_BC = magnet(batch_A, batch_B, batch_C, batch_M, batch_amp, mode='train')
50 | loss_y, loss_texture_AC, loss_texture_BM, loss_motion_BC = criterion_mag(y_hat, batch_M, texture_AC, texture_BM, motion_BC, criterion)
51 | loss = loss_y + (loss_texture_AC + loss_texture_BM + loss_motion_BC) * 0.1
52 | loss.backward()
53 | optimizer.step()
54 |
55 | # Callbacks
56 | losses.append(loss.item())
57 | losses_y.append(loss_y.item())
58 | losses_texture_AC.append(loss_texture_AC.item())
59 | losses_texture_BM.append(loss_texture_BM.item())
60 | losses_motion_BC.append(loss_motion_BC.item())
61 | if (
62 | idx_load > 0 and
63 | ((idx_load // data_loader.batch_size) %
64 | (data_loader.data_len // data_loader.batch_size // config.num_val_per_epoch)) == 0
65 | ):
66 | print(', {}%'.format(idx_load * 100 // data_loader.data_len), end='')
67 |
68 | # Collections
69 | save_model(magnet.state_dict(), losses, config.save_dir, epoch)
70 | print('\ntime: {}m, ep: {}, loss: {:.3e}, y: {:.3e}, tex_AC: {:.3e}, tex_BM: {:.3e}, mot_BC: {:.3e}'.format(
71 | int((time.time()-config.time_st)/60), epoch, np.mean(losses), np.mean(losses_y), np.mean(losses_texture_AC), np.mean(losses_texture_BM), np.mean(losses_motion_BC)
72 | ))
73 |
--------------------------------------------------------------------------------
/make_frameACB.py:
--------------------------------------------------------------------------------
1 | """
2 | Put it into the corresponding datasets directory, e.g. `/datasets/motion_mag_data/train/train_vid_frames` for me.
3 | Make the original frames into frameAs, frameBs, frameCs(same as frameBs here)
4 | """
5 | import os
6 | import sys
7 |
8 |
9 | # Choose the dir you want
10 | dirs = sorted([i for i in os.listdir('.') if i in
11 | sys.argv[1].split('+')
12 | # and int(i.split('_')[-1].split('.')[0]) > 0
13 | ]
14 | # , key=lambda x: int(x.split('_')[-1])
15 | )[:]
16 |
17 | image_format_name='png'
18 |
19 | for d in dirs:
20 | print('ACB-Processing on', d)
21 | os.chdir(d)
22 | os.mkdir('frameA')
23 | os.mkdir('frameC')
24 | files = sorted([f for f in os.listdir('.') if os.path.splitext(f)[1] == '.{}'.format(image_format_name)], key=lambda x: int(x.split('.')[0]))
25 | os.system('cp ./*{} frameA && cp ./*{} frameC'.format(image_format_name, image_format_name))
26 | os.remove(os.path.join('frameA', files[-1]))
27 | os.remove(os.path.join('frameC', files[0]))
28 | for f in sorted(os.listdir('frameC'), key=lambda x: int(x.split('.')[0])):
29 | f_new = os.path.join('frameC', '%06d' % (int(f.split('.')[0])-1) + '.{}'.format(image_format_name))
30 | f = os.path.join('frameC', f)
31 | os.rename(f, f_new)
32 | os.system('cp -r frameC frameB')
33 | os.system('rm ./*.{}'.format(image_format_name))
34 | os.chdir('..')
35 |
--------------------------------------------------------------------------------
/materials/Fig2-a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/Fig2-a.png
--------------------------------------------------------------------------------
/materials/baby_comp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/baby_comp.gif
--------------------------------------------------------------------------------
/materials/dogs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/dogs.png
--------------------------------------------------------------------------------
/materials/guitar_comp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/guitar_comp.gif
--------------------------------------------------------------------------------
/materials/myself_comp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhengPeng7/motion_magnification_learning-based/6093402312f1c4c69cae1e6bac888407598ec414/materials/myself_comp.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.7.0
2 | torchvision
3 | opencv-python
4 | scikit-image
5 | scikit-learn
6 | ipykernel
7 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | python main.py
2 | python test_video.py baby-guitar-water-gun-drone-cattoy
3 |
--------------------------------------------------------------------------------
/test_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import torch
5 | import numpy as np
6 | from config import Config
7 | from magnet import MagNet
8 | from data import get_gen_ABC, unit_postprocessing, numpy2cuda, resize2d
9 | from callbacks import gen_state_dict
10 |
11 |
12 | # config
13 | config = Config()
14 | # Load weights
15 | ep = ''
16 | weights_file = sorted(
17 | [p for p in os.listdir(config.save_dir) if '_loss' in p and '_epoch{}'.format(ep) in p and 'D' not in p],
18 | key=lambda x: float(x.rstrip('.pth').split('_loss')[-1])
19 | )[0]
20 | weights_path = os.path.join(config.save_dir, weights_file)
21 | ep = int(weights_path.split('epoch')[-1].split('_')[0])
22 | state_dict = gen_state_dict(weights_path)
23 |
24 | model_test = MagNet().cuda()
25 | model_test.load_state_dict(state_dict)
26 | model_test.eval()
27 | print("Loading weights:", weights_file)
28 |
29 | if len(sys.argv) == 1:
30 | testsets = 'baby-guitar-gun-drone-cattoy-water'
31 | else:
32 | testsets = sys.argv[-1]
33 | testsets = testsets.split('-')
34 | dir_results = config.save_dir.replace('weights', 'results')
35 | for testset in testsets:
36 | if not os.path.exists(dir_results):
37 | os.makedirs(dir_results)
38 |
39 | data_loader = get_gen_ABC(config, mode='test_on_'+testset)
40 | print('Number of test image couples:', data_loader.data_len)
41 | vid_size = cv2.imread(data_loader.paths[0]).shape[:2][::-1]
42 |
43 | # Test
44 | for amp in [5, 10, 30, 50]:
45 | frames = []
46 | data_loader = get_gen_ABC(config, mode='test_on_'+testset)
47 | for idx_load in range(0, data_loader.data_len, data_loader.batch_size):
48 | if (idx_load+1) % 100 == 0:
49 | print('{}'.format(idx_load+1), end=', ')
50 | batch_A, batch_B = data_loader.gen_test()
51 | amp_factor = numpy2cuda(amp)
52 | for _ in range(len(batch_A.shape) - len(amp_factor.shape)):
53 | amp_factor = amp_factor.unsqueeze(-1)
54 | with torch.no_grad():
55 | y_hats = model_test(batch_A, batch_B, 0, 0, amp_factor, mode='evaluate')
56 | for y_hat in y_hats:
57 | y_hat = unit_postprocessing(y_hat, vid_size=vid_size)
58 | frames.append(y_hat)
59 | if len(frames) >= data_loader.data_len:
60 | break
61 | if len(frames) >= data_loader.data_len:
62 | break
63 | data_loader = get_gen_ABC(config, mode='test_on_'+testset)
64 | frames = [unit_postprocessing(data_loader.gen_test()[0], vid_size=vid_size)] + frames
65 |
66 | # Make videos of framesMag
67 | video_dir = os.path.join(dir_results, testset)
68 | if not os.path.exists(video_dir):
69 | os.makedirs(video_dir)
70 | FPS = 30
71 | out = cv2.VideoWriter(
72 | os.path.join(video_dir, '{}_amp{}.avi'.format(testset, amp)),
73 | cv2.VideoWriter_fourcc(*'DIVX'),
74 | FPS, frames[0].shape[-2::-1]
75 | )
76 | for frame in frames:
77 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
78 | cv2.putText(frame, 'amp_factor={}'.format(amp), (7, 37),
79 | fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), thickness=2)
80 | out.write(frame)
81 | out.release()
82 | print('{} has been done.'.format(os.path.join(video_dir, '{}_amp{}.avi'.format(testset, amp))))
83 |
84 |
--------------------------------------------------------------------------------