├── .gitignore
├── LICENSE
├── README.md
├── colab_notebooks
│   ├── Convert_Video_to_SMPLpix_Dataset.ipynb
│   └── SMPLpix_training.ipynb
├── setup.py
└── smplpix
    ├── __init__.py
    ├── args.py
    ├── dataset.py
    ├── eval.py
    ├── train.py
    ├── training.py
    ├── unet.py
    ├── utils.py
    └── vgg.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021 - now, Sergey Prokudin (sergey.prokudin@gmail.com)
2 |
3 | Redistribution and use of this software and associated documentation files (the "Software"), with or
4 | without modification, are permitted provided that the following conditions are met:
5 |
6 | * The above copyright notice and this permission notice shall be included in all copies or substantial
7 | portions of the Software.
8 |
9 | * Any use for commercial, pornographic, military, or surveillance purposes is prohibited. The Software
10 | may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding
11 | analyses in peer-reviewed scientific research.
12 |
13 | For commercial uses of the Software, please send email to ps-license@tue.mpg.de.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
16 | LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
17 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
18 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
19 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | 
3 |
4 | _**Left**: [SMPL-X](https://smpl-x.is.tue.mpg.de/) human mesh registered with [SMPLify-X](https://smpl-x.is.tue.mpg.de/), **middle**: SMPLpix render, **right**: ground truth. [Video](https://user-images.githubusercontent.com/8117267/116540639-b9537480-a8ea-11eb-81ca-57d473147fbd.mp4)._
5 |
6 |
7 | # SMPLpix: Neural Avatars from 3D Human Models
8 |
9 | The *SMPLpix* neural rendering framework combines deformable 3D models such as [SMPL-X](https://smpl-x.is.tue.mpg.de/)
10 | with the power of image-to-image translation frameworks (aka [pix2pix](https://phillipi.github.io/pix2pix/) models).
11 |
12 | Please check our [WACV 2021 paper](https://arxiv.org/abs/2008.06872) or a [5-minute explanatory video](https://www.youtube.com/watch?v=JY9t4xUAouk) for more details on the framework.
13 |
14 | _**Important note**_: this repository is a re-implementation of the original framework, made by the same author after the end of the internship.
15 | It **does not contain** the original Amazon multi-subject, multi-view training data and code, and uses full mesh rasterizations as inputs rather than point projections (as described [here](https://youtu.be/JY9t4xUAouk?t=241)).
16 |
17 |
18 | ## Demo
19 |
20 | | Description | Link |
21 | | ----------- | ----------- |
22 | | Process a video into a SMPLpix dataset | [Open in Colab](https://colab.research.google.com/github//sergeyprokudin/smplpix/blob/main/colab_notebooks/Convert_Video_to_SMPLpix_Dataset.ipynb) |
23 | | Train SMPLpix | [Open in Colab](https://colab.research.google.com/github/sergeyprokudin/smplpix/blob/main/colab_notebooks/SMPLpix_training.ipynb) |
24 |
25 | ### Prepare the data
26 |
27 | 
28 |
29 | We provide a Colab notebook for preparing a SMPLpix training dataset. This allows you
30 | to create your own neural avatar from a [monocular video of a human moving in front of the camera](https://www.dropbox.com/s/rjqwf894ovso218/smplpix_test_video_na.mp4?dl=0).
31 |
32 | ### Run demo training
33 |
34 | We provide some preprocessed data which allows you to run and test the training pipeline right away:
35 |
36 | ```
37 | git clone https://github.com/sergeyprokudin/smplpix
38 | cd smplpix
39 | python setup.py install
40 | python smplpix/train.py --workdir='/content/smplpix_logs/' \
41 | --data_url='https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0'
42 | ```
43 |
44 | ### Train on your own data
45 |
46 | You can train SMPLpix on your own data by specifying the path to the root data directory:
47 |
48 | ```
49 | python smplpix/train.py --workdir='/content/smplpix_logs/' \
50 | --data_dir='/path/to/data'
51 | ```
52 |
53 | The directory should contain train, validation, and test folders, each of which should contain input and output subfolders. Check the structure of [the demo dataset](https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0) for reference.
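
The expected layout looks roughly as follows (a sketch; file names are placeholders, and each input render is paired with the ground-truth frame of the same name):

```
/path/to/data
├── train
│   ├── input     # SMPL-X mesh renders, e.g. 00001.png
│   └── output    # ground-truth target frames with matching names
├── validation
│   ├── input
│   └── output
└── test
    ├── input
    └── output
```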
54 |
55 | You can also specify various training parameters via the command line. For example, to reproduce the results from the demo video:
56 |
57 | ```
58 | python smplpix/train.py --workdir='/content/smplpix_logs/' \
59 | --data_url='https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0' \
60 | --downsample_factor=2 \
61 | --n_epochs=500 \
62 | --sched_patience=2 \
63 | --batch_size=4 \
64 | --n_unet_blocks=5 \
65 | --n_input_channels=3 \
66 | --n_output_channels=3 \
67 | --eval_every_nth_epoch=10
68 | ```
69 |
70 | Check [args.py](https://github.com/sergeyprokudin/smplpix/blob/main/smplpix/args.py) for the full list of parameters.
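
Since the options are defined with `argparse`, you can also list them directly from the command line:

```
python smplpix/train.py --help
```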
71 |
72 | ## More examples
73 |
74 | ### Animating with novel poses
75 |
76 | 
77 |
78 | **Left**: poses from the test video sequence, **right**: SMPLpix renders. [Video](https://user-images.githubusercontent.com/8117267/116546566-0edf4f80-a8f2-11eb-9fb2-a173c0018a4e.mp4).
79 |
80 |
81 | ### Rendering faces
82 |
83 | 
84 |
85 |
86 | _**Left**: [FLAME](https://flame.is.tue.mpg.de/) face model inferred with [DECA](https://github.com/YadiraF/DECA), **middle**: ground truth test video, **right**: SMPLpix render. [Video](https://user-images.githubusercontent.com/8117267/116543423-23214d80-a8ee-11eb-9ded-86af17c56549.mp4)._
87 |
88 | Thanks to [Maria Paola Forte](https://www.is.mpg.de/~Forte) for providing the sequence.
89 |
90 | ### Few-shot artistic neural style transfer
91 |
92 | 
93 |
94 |
95 | _**Left**: rendered [AMASS](https://amass.is.tue.mpg.de/) motion sequence, **right**: generated SMPLpix animations. [Full video](https://user-images.githubusercontent.com/8117267/116544826-e9514680-a8ef-11eb-8682-0ea8d19d0d5e.mp4). See [the explanatory video](https://youtu.be/JY9t4xUAouk?t=255) for details._
96 |
97 | Credits to [Alexander Kabarov](mailto:blackocher@gmail.com) for providing the training sketches.
98 |
99 | ## Citation
100 |
101 | If you find our work useful in your research, please consider citing:
102 | ```
103 | @inproceedings{prokudin2021smplpix,
104 | title={SMPLpix: Neural Avatars from 3D Human Models},
105 | author={Prokudin, Sergey and Black, Michael J and Romero, Javier},
106 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
107 | pages={1810--1819},
108 | year={2021}
109 | }
110 | ```
111 |
112 | ## License
113 |
114 | See the [LICENSE](https://github.com/sergeyprokudin/smplpix/blob/main/LICENSE) file.
115 |
116 | 
117 |
118 |
--------------------------------------------------------------------------------
/colab_notebooks/Convert_Video_to_SMPLpix_Dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Convert Video to SMPLpix Dataset.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [
9 | "cbXoNhFF-D8Q"
10 | ],
11 | "toc_visible": true,
12 | "machine_shape": "hm"
13 | },
14 | "kernelspec": {
15 | "name": "python3",
16 | "display_name": "Python 3"
17 | },
18 | "accelerator": "GPU"
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "saNBv0dY-Eef"
25 | },
26 | "source": [
27 | "\n",
28 | "\n",
29 | "# SMPLpix Dataset Preparation.\n",
30 | "\n",
31 | "**Author**: [Sergey Prokudin](https://ps.is.mpg.de/people/sprokudin). \n",
32 | "[[Project Page](https://sergeyprokudin.github.io/smplpix/)]\n",
33 | "[[Paper](https://arxiv.org/pdf/2008.06872.pdf)]\n",
34 | "[[Video](https://www.youtube.com/watch?v=JY9t4xUAouk)]\n",
35 | "[[GitHub](https://github.com/sergeyprokudin/smplpix)]\n",
36 | "\n",
37 | "This notebook contains an example workflow for converting a video file to a SMPLpix dataset. \n",
38 | "### Processing steps:\n",
39 | "\n",
40 | "1. Download the video of choice, extract frames;\n",
41 | "2. Extract **2D keypoints**: run [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose) on the extracted frames;\n",
42 | "3. Infer **3D human meshes**: run [SMPLify-x](https://github.com/vchoutas/smplify-x) on the extracted frames + keypoints;\n",
43 | "4. Form dataset **image pairs**, where *input* is SMPL-X mesh render, and *output* is the corresponding target ground truth video frame;\n",
44 | "5. **Split the data** into train, test and validation, zip and copy to Google Drive.\n",
45 | "\n",
46 | "### Instructions\n",
47 | "\n",
48 | "1. Convert a video into our dataset format using this notebook.\n",
49 | "2. Train a SMPLpix using the training notebook.\n",
50 | "\n",
51 | "\n",
52 | "### Notes\n",
53 | "* While this will work for small datasets in a Colab runtime, larger datasets will require more compute power;\n",
54 | "* If you would like to train a model on a serious dataset, you should consider copying this to your own workstation and running it there. Some minor modifications will be required, and you will have to install the dependencies separately;\n",
55 | "* Please report issues on the [GitHub issue tracker](https://github.com/sergeyprokudin/smplpix/issues).\n",
56 | "\n",
57 | "If you find this work useful, please consider citing:\n",
58 | "```bibtex\n",
59 | "@inproceedings{prokudin2021smplpix,\n",
60 | " title={SMPLpix: Neural Avatars from 3D Human Models},\n",
61 | " author={Prokudin, Sergey and Black, Michael J and Romero, Javier},\n",
62 | " booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},\n",
63 | " pages={1810--1819},\n",
64 | " year={2021}\n",
65 | "}\n",
66 | "```\n",
67 | "\n",
68 | "Many thanks [Keunhong Park](https://keunhong.com) for providing the [Nerfie dataset preparation template](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb)!\n"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {
74 | "id": "cbXoNhFF-D8Q"
75 | },
76 | "source": [
77 | "## Upload the video and extract frames\n"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "metadata": {
83 | "id": "K0EmRUXf5nTr",
84 | "cellView": "form"
85 | },
86 | "source": [
87 | "# @title Upload a video file (.mp4, .mov, etc.) from your disk, Dropbox, Google Drive or YouTube\n",
88 | "\n",
89 | "# @markdown This will upload it to the local Colab working directory. You can use a demo video to test the pipeline. The background in the demo was removed with the [Unscreen](https://www.unscreen.com/) service. Alternatively, you can try [PointRend](https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend) segmentation for this purpose. \n",
90 | "\n",
91 | "import os\n",
92 | "from google.colab import files\n",
93 | "\n",
94 | "def download_youtube_video(img_url, save_path, resolution_id=-3):\n",
95 | "\n",
96 | " print(\"downloading the video: %s\" % img_url)\n",
97 | " res_path = YouTube(img_url).streams.order_by('resolution')[resolution_id].download(save_path)\n",
98 | "\n",
99 | " return res_path\n",
100 | "\n",
101 | "def download_dropbox_url(url, filepath, chunk_size=1024):\n",
102 | "\n",
103 | " import requests\n",
104 | " headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}\n",
105 | " r = requests.get(url, stream=True, headers=headers)\n",
106 | " with open(filepath, 'wb') as f:\n",
107 | " for chunk in r.iter_content(chunk_size=chunk_size):\n",
108 | " if chunk:\n",
109 | " f.write(chunk)\n",
110 | " return filepath\n",
111 | "\n",
112 | "!rm -rf /content/data\n",
113 | "\n",
114 | "from google.colab import drive\n",
115 | "drive.mount('/content/gdrive')\n",
116 | "\n",
117 | "VIDEO_SOURCE = 'dropbox' #@param [\"youtube\", \"upload\", \"google drive\", \"dropbox\"]\n",
118 | "\n",
119 | "if VIDEO_SOURCE == 'dropbox':\n",
120 | " DROPBOX_URL = 'https://www.dropbox.com/s/rjqwf894ovso218/smplpix_test_video_na.mp4?dl=0' #@param \n",
121 | " VIDEO_PATH = '/content/video.mp4' \n",
122 | " download_dropbox_url(DROPBOX_URL, VIDEO_PATH)\n",
123 | "elif VIDEO_SOURCE == 'upload':\n",
124 | " print(\"Please upload the video: \")\n",
125 | " uploaded = files.upload()\n",
126 | " VIDEO_PATH = os.path.join('/content', list(uploaded.keys())[0])\n",
127 | "elif VIDEO_SOURCE == 'youtube':\n",
128 | " !pip install pytube\n",
129 | " from pytube import YouTube\n",
130 | " YOTUBE_VIDEO_URL = '' #@param \n",
131 | " VIDEO_PATH = download_youtube_video(YOTUBE_VIDEO_URL, '/content/')\n",
132 | "elif VIDEO_SOURCE == 'google drive':\n",
133 | " from google.colab import drive\n",
134 | " drive.mount('/content/gdrive')\n",
135 | " GOOGLE_DRIVE_PATH = '' #@param \n",
136 | " VIDEO_PATH = GOOGLE_DRIVE_PATH\n",
137 | "\n",
138 | "\n",
139 | "print(\"video is uploaded to %s\" % VIDEO_PATH)"
140 | ],
141 | "execution_count": null,
142 | "outputs": []
143 | },
144 | {
145 | "cell_type": "code",
146 | "metadata": {
147 | "id": "VImnPzFg9UNA",
148 | "cellView": "form"
149 | },
150 | "source": [
151 | "# @title Flatten the video into frames\n",
152 | "\n",
153 | "\n",
154 | "FPS = 1# @param {type:'number'}\n",
155 | "\n",
156 | "# @markdown _Note_: for longer videos, it might make sense to decrease the FPS as it will take 30-60 seconds for SMPLify-X framework to process every frame.\n",
157 | "\n",
158 | "\n",
159 | "\n",
160 | "RES_DIR = '/content/data'\n",
161 | "FRAMES_DIR = os.path.join(RES_DIR, 'images')\n",
162 | "!rm -rf $RES_DIR\n",
163 | "!mkdir $RES_DIR\n",
164 | "!mkdir $FRAMES_DIR\n",
165 | "!ffmpeg -i \"$VIDEO_PATH\" -vf fps=$FPS -qscale:v 2 '$FRAMES_DIR/%05d.png'\n",
166 | "\n",
167 | "from PIL import Image\n",
168 | "import numpy as np\n",
169 | "import matplotlib.pyplot as plt\n",
170 | "\n",
171 | "def load_img(img_path):\n",
172 | "\n",
173 | " return np.asarray(Image.open(img_path))/255\n",
174 | "\n",
175 | "test_img_path = os.path.join(FRAMES_DIR, os.listdir(FRAMES_DIR)[0])\n",
176 | "\n",
177 | "test_img = load_img(test_img_path)\n",
178 | "\n",
179 | "plt.figure(figsize=(5, 10))\n",
180 | "plt.title(\"extracted image example\")\n",
181 | "plt.imshow(test_img)"
182 | ],
183 | "execution_count": null,
184 | "outputs": []
185 | },
186 | {
187 | "cell_type": "markdown",
188 | "metadata": {
189 | "id": "7Z-ASlgBUPXJ"
190 | },
191 | "source": [
192 | "## Extract 2D body keypoints with OpenPose\n",
193 | "\n"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "metadata": {
199 | "id": "1AL4QpsBUO9p",
200 | "cellView": "form"
201 | },
202 | "source": [
203 | "# @title Install OpenPose\n",
204 | "# @markdown This will take some time (~10 mins). The code is taken from this [OpenPose Colab notebook](https://colab.research.google.com/github/tugstugi/dl-colab-notebooks/blob/master/notebooks/OpenPose.ipynb).\n",
205 | "\n",
206 | "%cd /content\n",
207 | "import os\n",
208 | "from os.path import exists, join, basename, splitext\n",
209 | "\n",
210 | "git_repo_url = 'https://github.com/CMU-Perceptual-Computing-Lab/openpose.git'\n",
211 | "project_name = splitext(basename(git_repo_url))[0]\n",
212 | "if not exists(project_name):\n",
213 | " # see: https://github.com/CMU-Perceptual-Computing-Lab/openpose/issues/949\n",
214 | " # install new CMake becaue of CUDA10\n",
215 | " !wget -q https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.tar.gz\n",
216 | " !tar xfz cmake-3.13.0-Linux-x86_64.tar.gz --strip-components=1 -C /usr/local\n",
217 | "\n",
218 | " # clone openpose\n",
219 | " !git clone -q --depth 1 $git_repo_url\n",
220 | " !sed -i 's/execute_process(COMMAND git checkout master WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}\\/3rdparty\\/caffe)/execute_process(COMMAND git checkout f019d0dfe86f49d1140961f8c7dec22130c83154 WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}\\/3rdparty\\/caffe)/g' openpose/CMakeLists.txt\n",
221 | " # install system dependencies\n",
222 | " !apt-get -qq install -y libatlas-base-dev libprotobuf-dev libleveldb-dev libsnappy-dev libhdf5-serial-dev protobuf-compiler libgflags-dev libgoogle-glog-dev liblmdb-dev opencl-headers ocl-icd-opencl-dev libviennacl-dev\n",
223 | " # install python dependencies\n",
224 | " !pip install -q youtube-dl\n",
225 | " # build openpose\n",
226 | " !cd openpose && rm -rf build || true && mkdir build && cd build && cmake .. && make -j`nproc`\n"
227 | ],
228 | "execution_count": null,
229 | "outputs": []
230 | },
231 | {
232 | "cell_type": "code",
233 | "metadata": {
234 | "id": "5NR5OGyeUOKU",
235 | "cellView": "form"
236 | },
237 | "source": [
238 | "# @title Run OpenPose on the extracted frames\n",
239 | "%cd /content\n",
240 | "KEYPOINTS_DIR = os.path.join(RES_DIR, 'keypoints')\n",
241 | "OPENPOSE_IMAGES_DIR = os.path.join(RES_DIR, 'openpose_images')\n",
242 | "!mkdir $KEYPOINTS_DIR\n",
243 | "!mkdir $OPENPOSE_IMAGES_DIR\n",
244 | "\n",
245 | "!cd openpose && ./build/examples/openpose/openpose.bin --image_dir $FRAMES_DIR --write_json $KEYPOINTS_DIR --face --hand --display 0 --write_images $OPENPOSE_IMAGES_DIR\n",
246 | "\n",
247 | "input_img_path = os.path.join(FRAMES_DIR, sorted(os.listdir(FRAMES_DIR))[0])\n",
248 | "openpose_img_path = os.path.join(OPENPOSE_IMAGES_DIR, sorted(os.listdir(OPENPOSE_IMAGES_DIR))[0])\n",
249 | "\n",
250 | "test_img = load_img(input_img_path)\n",
251 | "open_pose_img = load_img(openpose_img_path)\n",
252 | "\n",
253 | "plt.figure(figsize=(10, 10))\n",
254 | "plt.title(\"Input Frame + Openpose Prediction\")\n",
255 | "plt.imshow(np.concatenate([test_img, open_pose_img], 1))"
256 | ],
257 | "execution_count": null,
258 | "outputs": []
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {
263 | "id": "to4QpKLFHf2s"
264 | },
265 | "source": [
266 | "\n",
267 | "## Infer 3D Human Model with [SMPLify-X](https://smpl-x.is.tue.mpg.de/)"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "metadata": {
273 | "id": "SFzPpUoM99nd",
274 | "cellView": "form"
275 | },
276 | "source": [
277 | "# @title Install SMPLify-X and other dependencies\n",
278 | "\n",
279 | "%cd /content\n",
280 | "!pip install chumpy\n",
281 | "!pip install smplx\n",
282 | "!git clone https://github.com/vchoutas/smplx\n",
283 | "%cd smplx\n",
284 | "!python setup.py install\n",
285 | "\n",
286 | "#vposer\n",
287 | "!pip install git+https://github.com/nghorbani/configer\n",
288 | "!pip install git+https://github.com/sergeyprokudin/human_body_prior\n",
289 | "\n",
290 | "!pip install torch==1.1.0\n",
291 | "%cd /content\n",
292 | "!git clone https://github.com/sergeyprokudin/smplify-x\n",
293 | "%cd /content/smplify-x\n",
294 | "!pip install -r requirements.txt"
295 | ],
296 | "execution_count": null,
297 | "outputs": []
298 | },
299 | {
300 | "cell_type": "code",
301 | "metadata": {
302 | "id": "aSPgjzkFURq-",
303 | "cellView": "form"
304 | },
305 | "source": [
306 | "# @title Upload the SMPL-X model files\n",
307 | "\n",
308 | "# @markdown Proceed to the [official website](https://smpl-x.is.tue.mpg.de/), register and download the zip files with SMPL-X (**models_smplx_v1_1.zip**, ~830MB) and VPoser (**vposer_v1_0.zip**, ~2.5MB) models from the **Downloads** section. \n",
309 | "# @markdown\n",
310 | "\n",
311 | "# @markdown Since uploading large zip files to Colab is relatively slow, we expect you to upload these files to Google Drive instead, link gdrive to the Colab file systems and modify **SMPLX_ZIP_PATH** and **VPOSER_ZIP_PATH** variables accordingly.\n",
312 | "\n",
313 | "%cd /content/\n",
314 | "from google.colab import drive\n",
315 | "drive.mount('/content/gdrive')\n",
316 | "\n",
317 | "SMPLX_ZIP_PATH = '/content/gdrive/MyDrive/datasets/models_smplx_v1_1.zip' # @param {type:\"string\"}\n",
318 | "VPOSER_ZIP_PATH = '/content/gdrive/MyDrive/datasets/vposer_v1_0.zip' # @param {type:\"string\"}\n",
319 | "\n",
320 | "SMPLX_MODEL_PATH = '/content/smplx'\n",
321 | "!mkdir $SMPLX_MODEL_PATH\n",
322 | "!unzip -n '$SMPLX_ZIP_PATH' -d $SMPLX_MODEL_PATH\n",
323 | "VPOSER_MODEL_PATH = '/content/vposer'\n",
324 | "!mkdir $VPOSER_MODEL_PATH\n",
325 | "!unzip -n '$VPOSER_ZIP_PATH' -d $VPOSER_MODEL_PATH"
326 | ],
327 | "execution_count": null,
328 | "outputs": []
329 | },
330 | {
331 | "cell_type": "code",
332 | "metadata": {
333 | "id": "ODeGAyGrrIov",
334 | "cellView": "form"
335 | },
336 | "source": [
337 | "# @title Run SMPLify-X\n",
338 | "\n",
339 | "# @markdown Please select gender of the SMPL-X model:\n",
340 | "\n",
341 | "gender = 'male' #@param [\"neutral\", \"female\", \"male\"]\n",
342 | "\n",
343 | "# @markdown Please keep in mind that estimating 3D body with SMPLify-X framework will take ~30-60 secs, so processing long videos at high FPS might take a long time.\n",
344 | "\n",
345 | "!rm -rf /content/data/smplifyx_results\n",
346 | "%cd /content/smplify-x\n",
347 | "!git pull origin\n",
348 | "!python smplifyx/main.py --config cfg_files/fit_smplx.yaml \\\n",
349 | " --data_folder /content/data \\\n",
350 | " --output_folder /content/data/smplifyx_results \\\n",
351 | " --visualize=True \\\n",
352 | " --gender=$gender \\\n",
353 | " --model_folder /content/smplx/models \\\n",
354 | " --vposer_ckpt /content/vposer/vposer_v1_0 \\\n",
355 | " --part_segm_fn smplx_parts_segm.pkl "
356 | ],
357 | "execution_count": null,
358 | "outputs": []
359 | },
360 | {
361 | "cell_type": "code",
362 | "metadata": {
363 | "id": "qGcmsijSQ5jd",
364 | "cellView": "form"
365 | },
366 | "source": [
367 | "# @title Make train-test-validation splits, copy data to final folders\n",
368 | "\n",
369 | "import shutil\n",
370 | "\n",
371 | "train_ratio = 0.9 #@param\n",
372 | " \n",
373 | "final_zip_path = '/content/gdrive/MyDrive/datasets/smplpix_data_test.zip' # @param {type:\"string\"}\n",
374 | "\n",
375 | "\n",
376 | "target_images_path = '/content/data/smplifyx_results/input_images'\n",
377 | "smplifyx_renders = '/content/data/smplifyx_results/rendered_smplifyx_meshes'\n",
378 | "\n",
379 | "smplpix_data_path = '/content/smplpix_data'\n",
380 | "\n",
381 | "train_input_dir = os.path.join(smplpix_data_path, 'train', 'input')\n",
382 | "train_output_dir = os.path.join(smplpix_data_path, 'train', 'output')\n",
383 | "val_input_dir = os.path.join(smplpix_data_path, 'validation', 'input')\n",
384 | "val_output_dir = os.path.join(smplpix_data_path, 'validation', 'output')\n",
385 | "test_input_dir = os.path.join(smplpix_data_path, 'test', 'input')\n",
386 | "test_output_dir = os.path.join(smplpix_data_path, 'test', 'output')\n",
387 | "\n",
388 | "!mkdir -p $train_input_dir\n",
389 | "!mkdir -p $train_output_dir\n",
390 | "!mkdir -p $val_input_dir\n",
391 | "!mkdir -p $val_output_dir\n",
392 | "!mkdir -p $test_input_dir\n",
393 | "!mkdir -p $test_output_dir\n",
394 | "\n",
395 | "img_names = sorted(os.listdir(target_images_path))\n",
396 | "n_images = len(img_names)\n",
397 | "n_train_images = int(n_images * train_ratio)\n",
398 | "n_val_images = int(n_images * (1-train_ratio) / 2)\n",
399 | "train_images = img_names[0:n_train_images]\n",
400 | "val_images = img_names[n_train_images:n_train_images+n_val_images]\n",
401 | "test_images = img_names[n_train_images:]\n",
402 | "\n",
403 | "for img in train_images:\n",
404 | " shutil.copy(os.path.join(smplifyx_renders, img), train_input_dir)\n",
405 | " shutil.copy(os.path.join(target_images_path, img), train_output_dir)\n",
406 | "\n",
407 | "for img in val_images:\n",
408 | " shutil.copy(os.path.join(smplifyx_renders, img), val_input_dir)\n",
409 | " shutil.copy(os.path.join(target_images_path, img), val_output_dir)\n",
410 | "\n",
411 | "for img in test_images:\n",
412 | " shutil.copy(os.path.join(smplifyx_renders, img), test_input_dir)\n",
413 | " shutil.copy(os.path.join(target_images_path, img), test_output_dir)\n",
414 | "\n",
415 | "\n",
416 | "%cd /content\n",
417 | "!zip -r $final_zip_path smplpix_data/"
418 | ],
419 | "execution_count": null,
420 | "outputs": []
421 | }
422 | ]
423 | }
424 |
--------------------------------------------------------------------------------
/colab_notebooks/SMPLpix_training.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "SMPLpix training.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "machine_shape": "hm"
11 | },
12 | "kernelspec": {
13 | "display_name": "Python 3",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "name": "python"
18 | }
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "w3rYnZGzpAkn"
25 | },
26 | "source": [
27 | "\n",
28 | "###SMPLpix Training Demo.\n",
29 | "\n",
30 | "**Author**: [Sergey Prokudin](https://ps.is.mpg.de/people/sprokudin). \n",
31 | "[[Project Page](https://sergeyprokudin.github.io/smplpix/)]\n",
32 | "[[Paper](https://arxiv.org/pdf/2008.06872.pdf)]\n",
33 | "[[Video](https://www.youtube.com/watch?v=JY9t4xUAouk)]\n",
34 | "[[GitHub](https://github.com/sergeyprokudin/smplpix)]\n",
35 | "\n",
36 | "This notebook contains an example of training script for SMPLpix rendering module. \n",
37 | "\n",
38 | "To prepare the data with your own video, use the [SMPLpix dataset preparation notebook](https://colab.research.google.com/github//sergeyprokudin/smplpix/blob/main/colab_notebooks/Convert_Video_to_SMPLpix_Dataset.ipynb)."
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "2uwGj2Zfz-2S"
45 | },
46 | "source": [
47 | "### Install the SMPLpix framework:\n",
48 | "\n"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "metadata": {
54 | "id": "ZEiojUz_qaqp"
55 | },
56 | "source": [
57 | "!git clone https://github.com/sergeyprokudin/smplpix\n",
58 | "%cd /content/smplpix\n",
59 | "!python setup.py install"
60 | ],
61 | "execution_count": null,
62 | "outputs": []
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {
67 | "id": "xsWo4DAdyIxZ"
68 | },
69 | "source": [
70 | "### To train the model on the provided [demo dataset](https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0), simply run:"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "metadata": {
76 | "id": "NdhVaGJzwnM2"
77 | },
78 | "source": [
79 | "!cd /content/smplpix\n",
80 | "!python smplpix/train.py \\\n",
81 | " --workdir='/content/smplpix_logs/' \\\n",
82 | " --resume_training=0 \\\n",
83 | " --data_url='https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0' \\\n",
84 | " --downsample_factor=4 \\\n",
85 | " --n_epochs=200 \\\n",
86 | " --sched_patience=2 \\\n",
87 | " --eval_every_nth_epoch=10 \\\n",
88 | " --batch_size=4 \\\n",
89 | " --learning_rate=1.0e-3 \\\n",
90 | " --n_unet_blocks=5 \\\n",
91 | " --aug_prob=0.0"
92 | ],
93 | "execution_count": null,
94 | "outputs": []
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {
99 | "id": "bO4OIRJJpDAC"
100 | },
101 | "source": [
102 | "### Optionally, we can continue training and finetune the network with lower learning rate:"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "metadata": {
108 | "id": "HoWDXi3SQByO"
109 | },
110 | "source": [
111 | "!python smplpix/train.py \\\n",
112 | " --workdir='/content/smplpix_logs/' \\\n",
113 | " --resume_training=1 \\\n",
114 | " --data_dir='/content/smplpix_logs/smplpix_data' \\\n",
115 | " --downsample_factor=4 \\\n",
116 | " --n_epochs=50 \\\n",
117 | " --sched_patience=2 \\\n",
118 | " --eval_every_nth_epoch=10 \\\n",
119 | " --batch_size=4 \\\n",
120 | " --learning_rate=1.0e-4 \\\n",
121 | " --n_unet_blocks=5"
122 | ],
123 | "execution_count": null,
124 | "outputs": []
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "metadata": {
129 | "id": "fhKhReRYM0Ri"
130 | },
131 | "source": [
132 | "### To evaluate the model on the test images, run:"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "metadata": {
138 | "id": "C1iKU3dXP7ea"
139 | },
140 | "source": [
141 | "!python smplpix/eval.py \\\n",
142 | " --workdir='/content/eval_logs/' \\\n",
143 | " --checkpoint_path='/content/smplpix_logs/network.h5' \\\n",
144 | " --data_dir='/content/smplpix_logs/smplpix_data/test' \\\n",
145 | " --downsample_factor=4 \\\n",
146 | " --save_target=1 "
147 | ],
148 | "execution_count": null,
149 | "outputs": []
150 | },
151 | {
152 | "cell_type": "code",
153 | "metadata": {
154 | "cellView": "form",
155 | "id": "AYpNcVKiMzB_"
156 | },
157 | "source": [
158 | "# @markdown ###Play the generated test video\n",
159 | "from IPython.display import HTML\n",
160 | "from base64 import b64encode\n",
161 | "mp4 = open('/content/eval_logs/test_animation.mp4','rb').read()\n",
162 | "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
163 | "HTML(\"\"\"\n",
164 | "\n",
167 | "\"\"\" % data_url)"
168 | ],
169 | "execution_count": null,
170 | "outputs": []
171 | }
172 | ]
173 | }
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Setup for the project."""
2 | from setuptools import setup
3 |
4 | setup(
5 | name="smplpix",
6 |     version="1.0",
7 | description="SMPLpix: Neural Avatars from Deformable 3D models",
8 | install_requires=["torch", "torchvision", "trimesh", "pyrender"],
9 | author="Sergey Prokudin",
10 | license="MIT",
11 | author_email="sergey.prokudin@gmail.com",
12 | packages=["smplpix"]
13 | )
--------------------------------------------------------------------------------
/smplpix/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sergeyprokudin/smplpix/4dd135303f1641c5f13fbf8b92b6fdaa38008b03/smplpix/__init__.py
--------------------------------------------------------------------------------
/smplpix/args.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | def get_smplpix_arguments():
5 |
6 | parser = argparse.ArgumentParser(description='SMPLpix argument parser')
7 | parser.add_argument('--workdir',
8 | dest='workdir',
9 | help='workdir to save data, checkpoints, renders, etc.',
10 | default=os.getcwd())
11 | parser.add_argument('--data_dir',
12 | dest='data_dir',
13 | help='directory with training input and target images to the network, should contain'
14 | 'input and output subfolders',
15 | default=None)
16 | parser.add_argument('--resume_training',
17 | dest='resume_training',
18 | type=int,
19 | help='whether to continue training process given the checkpoint in workdir',
20 | default=False)
21 | parser.add_argument('--data_url',
22 | dest='data_url',
23 | help='Dropbox URL containing zipped dataset',
24 | default=None)
25 | parser.add_argument('--n_input_channels',
26 | dest='n_input_channels',
27 | type=int,
28 | help='number of channels in the input images',
29 | default=3)
30 | parser.add_argument('--n_output_channels',
31 | dest='n_output_channels',
32 | type=int,
33 |                         help='number of channels in the output images',
34 | default=3)
35 | parser.add_argument('--sigmoid_output',
36 | dest='sigmoid_output',
37 | type=int,
38 | help='whether to add sigmoid activation as a final layer',
39 | default=True)
40 | parser.add_argument('--n_unet_blocks',
41 | dest='n_unet_blocks',
42 | type=int,
43 | help='number of blocks in UNet rendering module',
44 | default=5)
45 | parser.add_argument('--batch_size',
46 | dest='batch_size',
47 | type=int,
48 | help='batch size to use during training',
49 | default=4)
50 | parser.add_argument('--device',
51 | dest='device',
52 | help='GPU device to use during training',
53 | default='cuda')
54 | parser.add_argument('--downsample_factor',
55 | dest='downsample_factor',
56 | type=int,
57 | help='image downsampling factor (for faster training)',
58 | default=4)
59 | parser.add_argument('--n_epochs',
60 | dest='n_epochs',
61 | type=int,
62 | help='number of epochs to train the network for',
63 | default=500)
64 | parser.add_argument('--learning_rate',
65 | dest='learning_rate',
66 | type=float,
67 | help='initial learning rate',
68 | default=1.0e-3)
69 | parser.add_argument('--eval_every_nth_epoch',
70 | dest='eval_every_nth_epoch',
71 | type=int,
72 | help='evaluate on validation data every nth epoch',
73 | default=10)
74 | parser.add_argument('--sched_patience',
75 | dest='sched_patience',
76 | type=int,
77 |                         help='number of validation set evaluations with no improvement after which the LR will be reduced',
78 | default=3)
79 | parser.add_argument('--aug_prob',
80 | dest='aug_prob',
81 | type=float,
82 |                         help='probability that the input sample will be rotated and rescaled - a higher value is recommended for data-scarce scenarios',
83 | default=0.8)
84 | parser.add_argument('--save_target',
85 | dest='save_target',
86 | type=int,
87 | help='whether to save target images during evaluation',
88 | default=1)
89 | parser.add_argument('--checkpoint_path',
90 | dest='checkpoint_path',
91 | help='path to checkpoint (for evaluation)',
92 | default=None)
93 | args = parser.parse_args()
94 |
95 | return args
96 |
--------------------------------------------------------------------------------
/smplpix/dataset.py:
--------------------------------------------------------------------------------
1 | # SMPLpix basic dataset class used in all experiments
2 | #
3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021
4 | #
5 |
6 | import os
7 | import numpy as np
8 | from PIL import Image
9 | import torch
10 | from torch.utils.data import Dataset
11 | from torchvision.transforms import functional as tvf
12 |
13 | class SMPLPixDataset(Dataset):
14 |
15 | def __init__(self, data_dir,
16 | n_input_channels=3,
17 | n_output_channels=3,
18 | downsample_factor=1,
19 | perform_augmentation=False,
20 | augmentation_probability=0.75,
21 | aug_scale_interval=None,
22 | aug_angle_interval=None,
23 | aug_translate_interval=None,
24 | input_fill_color=1,
25 | output_fill_color=1):
26 |
27 | if aug_translate_interval is None:
28 | aug_translate_interval = [-100, 100]
29 | if aug_scale_interval is None:
30 | aug_scale_interval = [0.5, 1.5]
31 | if aug_angle_interval is None:
32 | aug_angle_interval = [-60, 60]
33 |
34 | self.input_dir = os.path.join(data_dir, 'input')
35 | self.output_dir = os.path.join(data_dir, 'output')
36 | if not os.path.exists(self.output_dir):
37 | self.output_dir = self.input_dir
38 |
39 | self.n_input_channels = n_input_channels
40 | self.n_output_channels = n_output_channels
41 | self.samples = sorted(os.listdir(self.output_dir))
42 | self.downsample_factor = downsample_factor
43 | self.perform_augmentation = perform_augmentation
44 | self.augmentation_probability = augmentation_probability
45 | self.aug_scale_interval = aug_scale_interval
46 | self.aug_angle_interval = aug_angle_interval
47 | self.aug_translate_interval = aug_translate_interval
48 | self.input_fill_color = input_fill_color
49 | self.output_fill_color = output_fill_color
50 |
51 | def __len__(self):
52 | return len(self.samples)
53 |
54 | def _get_augmentation_params(self):
55 |
56 | scale = np.random.uniform(low=self.aug_scale_interval[0],
57 | high=self.aug_scale_interval[1])
58 | angle = np.random.uniform(self.aug_angle_interval[0], self.aug_angle_interval[1])
59 | translate = [np.random.uniform(self.aug_translate_interval[0],
60 | self.aug_translate_interval[1]),
61 | np.random.uniform(self.aug_translate_interval[0],
62 | self.aug_translate_interval[1])]
63 |
64 | return scale, angle, translate
65 |
66 | def _augment_images(self, x, y):
67 |
68 | augment_instance = np.random.uniform() < self.augmentation_probability
69 |
70 | if augment_instance:
71 | scale, angle, translate = self._get_augmentation_params()
72 |
73 | x = tvf.affine(x,
74 | angle=angle,
75 | translate=translate,
76 | scale=scale,
77 | shear=0, fill=self.input_fill_color)
78 |
79 | y = tvf.affine(y,
80 | angle=angle,
81 | translate=translate,
82 | scale=scale,
83 | shear=0, fill=self.output_fill_color)
84 |
85 | return x, y
86 |
87 | def __getitem__(self, idx):
88 |
89 | img_name = self.samples[idx]
90 | x_path = os.path.join(self.input_dir, img_name)
91 | x = Image.open(x_path)
92 | y_path = os.path.join(self.output_dir, img_name)
93 | y = Image.open(y_path)
94 |
95 | if self.perform_augmentation:
96 | x, y = self._augment_images(x, y)
97 |
98 | x = torch.Tensor(np.asarray(x) / 255).transpose(0, 2)
99 | y = torch.Tensor(np.asarray(y) / 255).transpose(0, 2)
100 | x = x[0:self.n_input_channels, ::self.downsample_factor, ::self.downsample_factor]
101 | y = y[0:self.n_output_channels, ::self.downsample_factor, ::self.downsample_factor]
102 |
103 | return x, y, img_name
104 |
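105 | 
106 | # -----------------------------------------------------------------------------
107 | # Minimal usage sketch (added for illustration, not part of the original module).
108 | # It assumes a data directory laid out as described in the README, i.e.
109 | # <data_dir>/input and <data_dir>/output holding identically named image pairs;
110 | # the path below is a placeholder.
111 | if __name__ == '__main__':
112 | 
113 |     from torch.utils.data import DataLoader
114 | 
115 |     demo_data_dir = '/path/to/data/train'  # placeholder, point this at your data
116 |     dataset = SMPLPixDataset(
117 |         data_dir=demo_data_dir,
118 |         downsample_factor=4,
119 |         perform_augmentation=True,
120 |         augmentation_probability=0.8,
121 |     )
122 |     loader = DataLoader(dataset, batch_size=4)
123 |     x, y, img_names = next(iter(loader))
124 |     print(x.shape, y.shape)  # channels-first image tensors scaled to [0, 1]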
--------------------------------------------------------------------------------
/smplpix/eval.py:
--------------------------------------------------------------------------------
1 | # Main SMPLpix Evaluation Script
2 | #
3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021
4 | #
5 |
6 | import os
7 | import shutil
8 | import pprint
9 | import torch
10 | from torch.utils.data import DataLoader
11 |
12 | from smplpix.args import get_smplpix_arguments
13 | from smplpix.utils import generate_mp4
14 | from smplpix.dataset import SMPLPixDataset
15 | from smplpix.unet import UNet
16 | from smplpix.training import train, evaluate
17 | from smplpix.utils import download_and_unzip
18 |
19 | def generate_eval_video(args, data_dir, unet, frame_rate=25, save_target=False, save_input=True):
20 |
21 | print("rendering SMPLpix predictions for %s..." % data_dir)
22 | data_part_name = os.path.split(data_dir)[-1]
23 |
24 | test_dataset = SMPLPixDataset(data_dir=data_dir,
25 | downsample_factor=args.downsample_factor,
26 | perform_augmentation=False,
27 | n_input_channels=args.n_input_channels,
28 | n_output_channels=args.n_output_channels,
29 | augmentation_probability=args.aug_prob)
30 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size)
31 | final_renders_path = os.path.join(args.workdir, 'renders_%s' % data_part_name)
32 | _ = evaluate(unet, test_dataloader, final_renders_path, args.device, save_target=save_target, save_input=save_input)
33 |
34 | print("generating video animation for data %s..." % data_dir)
35 |
36 | video_animation_path = os.path.join(args.workdir, '%s_animation.mp4' % data_part_name)
37 | _ = generate_mp4(final_renders_path, video_animation_path, frame_rate=frame_rate)
38 | print("saved animation video to %s" % video_animation_path)
39 |
40 | return
41 |
42 | def main():
43 |
44 | print("******************************************************************************************\n"+
45 | "****************************** SMPLpix Evaluation Loop **********************************\n"+
46 | "******************************************************************************************\n"+
47 | "******** Copyright (c) 2021 - now, Sergey Prokudin (sergey.prokudin@gmail.com) ***********")
48 |
49 | args = get_smplpix_arguments()
50 | print("ARGUMENTS:")
51 | pprint.pprint(args)
52 |
53 | if args.checkpoint_path is None:
54 | print("no model checkpoint was specified, looking in the log directory...")
55 | ckpt_path = os.path.join(args.workdir, 'network.h5')
56 | else:
57 | ckpt_path = args.checkpoint_path
58 | if not os.path.exists(ckpt_path):
59 | print("checkpoint %s not found!" % ckpt_path)
60 | return
61 |
62 | print("defining the neural renderer model (U-Net)...")
63 | unet = UNet(in_channels=args.n_input_channels, out_channels=args.n_output_channels,
64 | n_blocks=args.n_unet_blocks, dim=2, up_mode='resizeconv_linear').to(args.device)
65 |
66 | print("loading the model from checkpoint: %s" % ckpt_path)
67 | unet.load_state_dict(torch.load(ckpt_path))
68 | unet.eval()
69 | generate_eval_video(args, args.data_dir, unet, save_target=args.save_target)
70 |
71 | return
72 |
73 | if __name__== '__main__':
74 | main()
75 |
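76 | # Example invocation (a sketch; the paths are placeholders mirroring the Colab
77 | # training notebook):
78 | #
79 | #   python smplpix/eval.py --workdir='/content/eval_logs/' \
80 | #                          --checkpoint_path='/content/smplpix_logs/network.h5' \
81 | #                          --data_dir='/content/smplpix_logs/smplpix_data/test' \
82 | #                          --downsample_factor=4 \
83 | #                          --save_target=1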
--------------------------------------------------------------------------------
/smplpix/train.py:
--------------------------------------------------------------------------------
1 | # Main SMPLpix Training Script
2 | #
3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021
4 | #
5 |
6 | import os
7 | import shutil
8 | import pprint
9 | import torch
10 | from torch.utils.data import DataLoader
11 |
12 | from smplpix.args import get_smplpix_arguments
13 | from smplpix.utils import generate_mp4
14 | from smplpix.dataset import SMPLPixDataset
15 | from smplpix.unet import UNet
16 | from smplpix.training import train, evaluate
17 | from smplpix.utils import download_and_unzip
18 |
19 | def generate_eval_video(args, data_dir, unet, frame_rate=25, save_target=False, save_input=True):
20 |
21 | print("rendering SMPLpix predictions for %s..." % data_dir)
22 | data_part_name = os.path.split(data_dir)[-1]
23 |
24 | test_dataset = SMPLPixDataset(data_dir=data_dir,
25 | downsample_factor=args.downsample_factor,
26 | perform_augmentation=False,
27 | n_input_channels=args.n_input_channels,
28 | n_output_channels=args.n_output_channels,
29 | augmentation_probability=args.aug_prob)
30 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size)
31 | final_renders_path = os.path.join(args.workdir, 'renders_%s' % data_part_name)
32 | _ = evaluate(unet, test_dataloader, final_renders_path, args.device, save_target=save_target, save_input=save_input)
33 |
34 | print("generating video animation for data %s..." % data_dir)
35 |
36 | video_animation_path = os.path.join(args.workdir, '%s_animation.mp4' % data_part_name)
37 | _ = generate_mp4(final_renders_path, video_animation_path, frame_rate=frame_rate)
38 | print("saved animation video to %s" % video_animation_path)
39 |
40 | return
41 |
42 | def main():
43 |
44 | print("******************************************************************************************\n"+
45 | "****************************** SMPLpix Training Loop ************************************\n"+
46 | "******************************************************************************************\n"+
47 | "******** Copyright (c) 2021 - now, Sergey Prokudin (sergey.prokudin@gmail.com) ***********\n"+
48 | "****************************************************************************************+*\n\n")
49 |
50 | args = get_smplpix_arguments()
51 | print("ARGUMENTS:")
52 | pprint.pprint(args)
53 |
54 | if not os.path.exists(args.workdir):
55 | os.makedirs(args.workdir)
56 |
57 | log_dir = os.path.join(args.workdir, 'logs')
58 | os.makedirs(log_dir, exist_ok=True)
59 |
60 | if args.data_url is not None:
61 | download_and_unzip(args.data_url, args.workdir)
62 | args.data_dir = os.path.join(args.workdir, 'smplpix_data')
63 |
64 | train_dir = os.path.join(args.data_dir, 'train')
65 | val_dir = os.path.join(args.data_dir, 'validation')
66 | test_dir = os.path.join(args.data_dir, 'test')
67 |
68 | train_dataset = SMPLPixDataset(data_dir=train_dir,
69 | perform_augmentation=True,
70 | augmentation_probability=args.aug_prob,
71 | downsample_factor=args.downsample_factor,
72 | n_input_channels=args.n_input_channels,
73 | n_output_channels=args.n_output_channels)
74 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size)
75 |
76 | if os.path.exists(val_dir):
77 | val_dataset = SMPLPixDataset(data_dir=val_dir,
78 | perform_augmentation=False,
79 | augmentation_probability=args.aug_prob,
80 | downsample_factor=args.downsample_factor,
81 | n_input_channels=args.n_input_channels,
82 | n_output_channels=args.n_output_channels)
83 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size)
84 | else:
85 | print("no validation data was provided, will use train data for validation...")
86 | val_dataloader = train_dataloader
87 |
88 | print("defining the neural renderer model (U-Net)...")
89 | unet = UNet(in_channels=args.n_input_channels,
90 | out_channels=args.n_output_channels,
91 | sigmoid_output=args.sigmoid_output,
92 | n_blocks=args.n_unet_blocks, dim=2, up_mode='resizeconv_linear').to(args.device)
93 |
94 | if args.checkpoint_path is None:
95 | ckpt_path = os.path.join(args.workdir, 'network.h5')
96 | else:
97 | ckpt_path = args.checkpoint_path
98 |
99 | if args.resume_training and os.path.exists(ckpt_path):
100 | print("found checkpoint, resuming from: %s" % ckpt_path)
101 | unet.load_state_dict(torch.load(ckpt_path))
102 | if not args.resume_training:
103 | print("starting training from scratch, cleaning the log dirs...")
104 | shutil.rmtree(log_dir)
105 |
106 | print("starting training...")
107 | finished = False
108 | try:
109 | train(model=unet, train_dataloader=train_dataloader, val_dataloader=val_dataloader,
110 | log_dir=log_dir, ckpt_path=ckpt_path, device=args.device, n_epochs=args.n_epochs,
111 | eval_every_nth_epoch=args.eval_every_nth_epoch, sched_patience=args.sched_patience,
112 | init_lr=args.learning_rate)
113 | finished = True
114 |
115 | except KeyboardInterrupt:
116 | print("training interrupted, generating final animations...")
117 | generate_eval_video(args, train_dir, unet, save_target=True)
118 | generate_eval_video(args, val_dir, unet, save_target=True)
119 | generate_eval_video(args, test_dir, unet, save_target=True)
120 |
121 | if finished:
122 | generate_eval_video(args, train_dir, unet, save_target=True)
123 | generate_eval_video(args, val_dir, unet, save_target=True)
124 | generate_eval_video(args, test_dir, unet, save_target=True)
125 |
126 | return
127 |
128 | if __name__== '__main__':
129 | main()
130 |
--------------------------------------------------------------------------------
/smplpix/training.py:
--------------------------------------------------------------------------------
1 | # SMPLpix training and evaluation loop functions
2 | #
3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021
4 | #
5 |
6 | import os
7 | import numpy as np
8 | from tqdm import tqdm
9 | import torch
10 | import torch.nn as nn
11 | from torchvision.utils import save_image
12 | from .vgg import Vgg16Features
13 |
14 |
15 | def train(model, train_dataloader, val_dataloader, log_dir, ckpt_path, device,
16 | n_epochs=1000, eval_every_nth_epoch=50, sched_patience=5, init_lr=1.0e-4):
17 |
18 | vgg = Vgg16Features(layers_weights = [1, 1/16, 1/8, 1/4, 1]).to(device)
19 | criterion_l1 = nn.L1Loss().to(device)
20 | optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)
21 | sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
22 | patience=sched_patience,
23 | verbose=True)
24 |
25 | for epoch_id in tqdm(range(0, n_epochs)):
26 |
27 | model.train()
28 | torch.save(model.state_dict(), ckpt_path)
29 |
30 | for batch_idx, (x, ytrue, img_names) in enumerate(train_dataloader):
31 | x, ytrue = x.to(device), ytrue.to(device)
32 | ypred = model(x)
33 | vgg_loss = criterion_l1(vgg(ypred), vgg(ytrue))
34 | optimizer.zero_grad()
35 | vgg_loss.backward()
36 | optimizer.step()
37 |
38 | if epoch_id % eval_every_nth_epoch == 0:
39 | print("\ncurrent epoch: %d" % epoch_id)
40 | eval_dir = os.path.join(log_dir, 'val_preds_%04d' % epoch_id)
41 | val_loss = evaluate(model, val_dataloader, eval_dir, device, vgg, show_progress=False)
42 | sched.step(val_loss)
43 |
44 | return
45 |
46 |
47 | def evaluate(model, data_loader, res_dir, device,
48 | vgg=None, report_loss=True, show_progress=True,
49 | save_input=True, save_target=True):
50 |
51 | model.eval()
52 |
53 | if not os.path.exists(res_dir):
54 | os.makedirs(res_dir)
55 |
56 | if vgg is None:
57 | vgg = Vgg16Features(layers_weights = [1, 1 / 16, 1 / 8, 1 / 4, 1]).to(device)
58 | criterion_l1 = nn.L1Loss().to(device)
59 | losses = []
60 |
61 | if show_progress:
62 | data_seq = tqdm(enumerate(data_loader))
63 | else:
64 | data_seq = enumerate(data_loader)
65 |
66 | for batch_idx, (x, ytrue, img_names) in data_seq:
67 |
68 | x, ytrue = x.to(device), ytrue.to(device)
69 |
70 | ypred = model(x).detach().to(device)
71 | losses.append(float(criterion_l1(vgg(ypred), vgg(ytrue))))
72 |
73 | for fid in range(0, len(img_names)):
74 | if save_input:
75 | res_image = torch.cat([x[fid].transpose(1, 2), ypred[fid].transpose(1, 2)], dim=2)
76 | else:
77 | res_image = ypred[fid].transpose(1, 2)
78 | if save_target:
79 | res_image = torch.cat([res_image, ytrue[fid].transpose(1, 2)], dim=2)
80 |
81 | save_image(res_image, os.path.join(res_dir, '%s' % img_names[fid]))
82 |
83 | avg_loss = np.mean(losses)
84 |
85 | if report_loss:
86 |         print("mean VGG loss: %f" % avg_loss)
87 | print("images saved at %s" % res_dir)
88 |
89 | return avg_loss
90 |
--------------------------------------------------------------------------------
/smplpix/unet.py:
--------------------------------------------------------------------------------
1 | # pytorch U-Net model (https://arxiv.org/abs/1505.04597)
2 | # original repository:
3 | # https://github.com/ELEKTRONN
4 | # https://github.com/ELEKTRONN/elektronn3/blob/master/elektronn3/models/unet.py
5 |
6 | # ELEKTRONN3 - Neural Network Toolkit
7 | #
8 | # Copyright (c) 2017 - now
9 | # Max Planck Institute of Neurobiology, Munich, Germany
10 | # Author: Martin Drawitsch
11 |
12 | # MIT License
13 | #
14 | # Copyright (c) 2017 - now, ELEKTRONN team
15 | #
16 | # Permission is hereby granted, free of charge, to any person obtaining a copy
17 | # of this software and associated documentation files (the "Software"), to deal
18 | # in the Software without restriction, including without limitation the rights
19 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20 | # copies of the Software, and to permit persons to whom the Software is
21 | # furnished to do so, subject to the following conditions:
22 | #
23 | # The above copyright notice and this permission notice shall be included in all
24 | # copies or substantial portions of the Software.
25 | #
26 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32 | # SOFTWARE.
33 |
34 | """
35 | This is a modified version of the U-Net CNN architecture for biomedical
36 | image segmentation. U-Net was originally published in
37 | https://arxiv.org/abs/1505.04597 by Ronneberger et al.
38 |
39 | A pure-3D variant of U-Net has been proposed by Çiçek et al.
40 | in https://arxiv.org/abs/1606.06650, but the below implementation
41 | is based on the original U-Net paper, with several improvements.
42 |
43 | This code is based on https://github.com/jaxony/unet-pytorch
44 | (c) 2017 Jackson Huang, released under MIT License,
45 | which implements (2D) U-Net with user-defined network depth
46 | and a few other improvements of the original architecture.
47 |
48 | Major differences of this version from Huang's code:
49 |
50 | - Operates on 3D image data (5D tensors) instead of 2D data
51 | - Uses 3D convolution, 3D pooling etc. by default
52 | - planar_blocks architecture parameter for mixed 2D/3D convnets
53 | (see UNet class docstring for details)
54 | - Improved tests (see the bottom of the file)
55 | - Cleaned up parameter/variable names and formatting, changed default params
56 | - Updated for PyTorch 1.3 and Python 3.6 (earlier versions unsupported)
57 | - (Optional DEBUG mode for optional printing of debug information)
58 | - Extended documentation
59 | """
60 |
61 | __all__ = ['UNet']
62 |
63 | import copy
64 | import itertools
65 |
66 | from typing import Sequence, Union, Tuple, Optional
67 |
68 | import torch
69 | from torch import nn
70 | from torch.utils.checkpoint import checkpoint
71 | from torch.nn import functional as F
72 |
73 |
74 | def get_conv(dim=3):
75 | """Chooses an implementation for a convolution layer."""
76 | if dim == 3:
77 | return nn.Conv3d
78 | elif dim == 2:
79 | return nn.Conv2d
80 | else:
81 | raise ValueError('dim has to be 2 or 3')
82 |
83 |
84 | def get_convtranspose(dim=3):
85 | """Chooses an implementation for a transposed convolution layer."""
86 | if dim == 3:
87 | return nn.ConvTranspose3d
88 | elif dim == 2:
89 | return nn.ConvTranspose2d
90 | else:
91 | raise ValueError('dim has to be 2 or 3')
92 |
93 |
94 | def get_maxpool(dim=3):
95 | """Chooses an implementation for a max-pooling layer."""
96 | if dim == 3:
97 | return nn.MaxPool3d
98 | elif dim == 2:
99 | return nn.MaxPool2d
100 | else:
101 | raise ValueError('dim has to be 2 or 3')
102 |
103 |
104 | def get_normalization(normtype: str, num_channels: int, dim: int = 3):
105 | """Chooses an implementation for a batch normalization layer."""
106 | if normtype is None or normtype == 'none':
107 | return nn.Identity()
108 | elif normtype.startswith('group'):
109 | if normtype == 'group':
110 | num_groups = 8
111 | elif len(normtype) > len('group') and normtype[len('group'):].isdigit():
112 | num_groups = int(normtype[len('group'):])
113 | else:
114 | raise ValueError(
115 |                 f'normtype "{normtype}" not understood. It should be "group<g>",'
116 |                 f' where <g> is the number of groups.'
117 | )
118 | return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
119 | elif normtype == 'instance':
120 | if dim == 3:
121 | return nn.InstanceNorm3d(num_channels)
122 | elif dim == 2:
123 | return nn.InstanceNorm2d(num_channels)
124 | else:
125 | raise ValueError('dim has to be 2 or 3')
126 | elif normtype == 'batch':
127 | if dim == 3:
128 | return nn.BatchNorm3d(num_channels)
129 | elif dim == 2:
130 | return nn.BatchNorm2d(num_channels)
131 | else:
132 | raise ValueError('dim has to be 2 or 3')
133 | else:
134 | raise ValueError(
135 | f'Unknown normalization type "{normtype}".\n'
136 |             'Valid choices are "batch", "instance", "group" or "group<g>",'
137 |             'where <g> is the number of groups.'
138 | )
139 |
140 |
141 | def planar_kernel(x):
142 | """Returns a "planar" kernel shape (e.g. for 2D convolution in 3D space)
143 | that doesn't consider the first spatial dim (D)."""
144 | if isinstance(x, int):
145 | return (1, x, x)
146 | else:
147 | return x
148 |
149 |
150 | def planar_pad(x):
151 | """Returns a "planar" padding shape that doesn't pad along the first spatial dim (D)."""
152 | if isinstance(x, int):
153 | return (0, x, x)
154 | else:
155 | return x
156 |
157 |
158 | def conv3(in_channels, out_channels, kernel_size=3, stride=1,
159 | padding=1, bias=True, planar=False, dim=3):
160 | """Returns an appropriate spatial convolution layer, depending on args.
161 | - dim=2: Conv2d with 3x3 kernel
162 | - dim=3 and planar=False: Conv3d with 3x3x3 kernel
163 | - dim=3 and planar=True: Conv3d with 1x3x3 kernel
164 | """
165 | if planar:
166 | stride = planar_kernel(stride)
167 | padding = planar_pad(padding)
168 | kernel_size = planar_kernel(kernel_size)
169 | return get_conv(dim)(
170 | in_channels,
171 | out_channels,
172 | kernel_size=kernel_size,
173 | stride=stride,
174 | padding=padding,
175 | bias=bias
176 | )
177 |
178 |
179 | def upconv2(in_channels, out_channels, mode='transpose', planar=False, dim=3):
180 | """Returns a learned upsampling operator depending on args."""
181 | kernel_size = 2
182 | stride = 2
183 | if planar:
184 | kernel_size = planar_kernel(kernel_size)
185 | stride = planar_kernel(stride)
186 | if mode == 'transpose':
187 | return get_convtranspose(dim)(
188 | in_channels,
189 | out_channels,
190 | kernel_size=kernel_size,
191 | stride=stride
192 | )
193 | elif 'resizeconv' in mode:
194 | if 'linear' in mode:
195 | upsampling_mode = 'trilinear' if dim == 3 else 'bilinear'
196 | else:
197 | upsampling_mode = 'nearest'
198 | rc_kernel_size = 1 if mode.endswith('1') else 3
199 | return ResizeConv(
200 | in_channels, out_channels, planar=planar, dim=dim,
201 | upsampling_mode=upsampling_mode, kernel_size=rc_kernel_size
202 | )
203 |
204 |
205 | def conv1(in_channels, out_channels, dim=3):
206 | """Returns a 1x1 or 1x1x1 convolution, depending on dim"""
207 | return get_conv(dim)(in_channels, out_channels, kernel_size=1)
208 |
209 |
210 | def get_activation(activation):
211 | if isinstance(activation, str):
212 | if activation == 'relu':
213 | return nn.ReLU()
214 | elif activation == 'leaky':
215 | return nn.LeakyReLU(negative_slope=0.1)
216 | elif activation == 'prelu':
217 | return nn.PReLU(num_parameters=1)
218 | elif activation == 'rrelu':
219 | return nn.RReLU()
220 | elif activation == 'lin':
221 | return nn.Identity()
222 | else:
223 |         # Deep copy is necessary in case of parametrized activations
224 | return copy.deepcopy(activation)
225 |
226 |
227 | class DownConv(nn.Module):
228 | """
229 | A helper Module that performs 2 convolutions and 1 MaxPool.
230 | A ReLU activation follows each convolution.
231 | """
232 | def __init__(self, in_channels, out_channels, pooling=True, planar=False, activation='relu',
233 | normalization=None, full_norm=True, dim=3, conv_mode='same'):
234 | super().__init__()
235 |
236 | self.in_channels = in_channels
237 | self.out_channels = out_channels
238 | self.pooling = pooling
239 | self.normalization = normalization
240 | self.dim = dim
241 | padding = 1 if 'same' in conv_mode else 0
242 |
243 | self.conv1 = conv3(
244 | self.in_channels, self.out_channels, planar=planar, dim=dim, padding=padding
245 | )
246 | self.conv2 = conv3(
247 | self.out_channels, self.out_channels, planar=planar, dim=dim, padding=padding
248 | )
249 |
250 | if self.pooling:
251 | kernel_size = 2
252 | if planar:
253 | kernel_size = planar_kernel(kernel_size)
254 | self.pool = get_maxpool(dim)(kernel_size=kernel_size, ceil_mode=True)
255 | self.pool_ks = kernel_size
256 | else:
257 | self.pool = nn.Identity()
258 | self.pool_ks = -123 # Bogus value, will never be read. Only to satisfy TorchScript's static type system
259 |
260 | self.act1 = get_activation(activation)
261 | self.act2 = get_activation(activation)
262 |
263 | if full_norm:
264 | self.norm0 = get_normalization(normalization, self.out_channels, dim=dim)
265 | else:
266 | self.norm0 = nn.Identity()
267 | self.norm1 = get_normalization(normalization, self.out_channels, dim=dim)
268 |
269 | def forward(self, x):
270 | y = self.conv1(x)
271 | y = self.norm0(y)
272 | y = self.act1(y)
273 | y = self.conv2(y)
274 | y = self.norm1(y)
275 | y = self.act2(y)
276 | before_pool = y
277 | y = self.pool(y)
278 | return y, before_pool
279 |
280 |
281 | @torch.jit.script
282 | def autocrop(from_down: torch.Tensor, from_up: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
283 | """
284 | Crops feature tensors from the encoder and decoder pathways so that they
285 | can be combined.
286 |
287 | - If inputs from the encoder pathway have shapes that are not divisible
288 | by 2, the use of ``nn.MaxPool(ceil_mode=True)`` leads to the 2x
289 | upconvolution results being too large by one element in each odd
290 | dimension, so they need to be cropped in these dimensions.
291 |
292 | - If VALID convolutions are used, feature tensors get smaller with each
293 | convolution, so we need to center-crop the larger feature tensors from
294 | the encoder pathway to make features combinable with the smaller
295 |       decoder features.
296 |
297 | Args:
298 | from_down: Feature from encoder pathway (``DownConv``)
299 | from_up: Feature from decoder pathway (2x upsampled)
300 |
301 | Returns:
302 |         The (possibly cropped) ``from_down`` and ``from_up`` feature tensors.
303 | """
304 | ndim = from_down.dim() # .ndim is not supported by torch.jit
305 |
306 | if from_down.shape[2:] == from_up.shape[2:]: # No need to crop anything
307 | return from_down, from_up
308 |
309 | # Step 1: Handle odd shapes
310 |
311 | # Handle potentially odd input shapes from encoder
312 | # by cropping from_up by 1 in each dim that is odd in from_down and not
313 | # odd in from_up (that is, where the difference between them is odd).
314 | # The reason for looking at the shape difference and not just the shape
315 |     # of from_down is that decoder outputs mostly have even shape
316 |     # because of the 2x upsampling, but if anisotropic pooling is used, the
317 |     # decoder outputs can also be oddly shaped in the z (D) dimension.
318 | # In these cases no cropping should be performed.
319 | ds = from_down.shape[2:]
320 | us = from_up.shape[2:]
321 | upcrop = [u - ((u - d) % 2) for d, u in zip(ds, us)]
322 |
323 | if ndim == 4:
324 | from_up = from_up[:, :, :upcrop[0], :upcrop[1]]
325 | if ndim == 5:
326 | from_up = from_up[:, :, :upcrop[0], :upcrop[1], :upcrop[2]]
327 |
328 | # Step 2: Handle center-crop resulting from valid convolutions
329 | ds = from_down.shape[2:]
330 | us = from_up.shape[2:]
331 |
332 | assert ds[0] >= us[0], f'{ds, us}'
333 | assert ds[1] >= us[1]
334 | if ndim == 4:
335 | from_down = from_down[
336 | :,
337 | :,
338 | (ds[0] - us[0]) // 2:(ds[0] + us[0]) // 2,
339 | (ds[1] - us[1]) // 2:(ds[1] + us[1]) // 2
340 | ]
341 | elif ndim == 5:
342 | assert ds[2] >= us[2]
343 | from_down = from_down[
344 | :,
345 | :,
346 | ((ds[0] - us[0]) // 2):((ds[0] + us[0]) // 2),
347 | ((ds[1] - us[1]) // 2):((ds[1] + us[1]) // 2),
348 | ((ds[2] - us[2]) // 2):((ds[2] + us[2]) // 2),
349 | ]
350 | return from_down, from_up
351 |
352 |
353 | class UpConv(nn.Module):
354 | """
355 | A helper Module that performs 2 convolutions and 1 UpConvolution.
356 | A ReLU activation follows each convolution.
357 | """
358 |
359 | att: Optional[torch.Tensor]
360 |
361 | def __init__(self, in_channels, out_channels,
362 | merge_mode='concat', up_mode='transpose', planar=False,
363 | activation='relu', normalization=None, full_norm=True, dim=3, conv_mode='same',
364 | attention=False):
365 | super().__init__()
366 |
367 | self.in_channels = in_channels
368 | self.out_channels = out_channels
369 | self.merge_mode = merge_mode
370 | self.up_mode = up_mode
371 | self.normalization = normalization
372 | padding = 1 if 'same' in conv_mode else 0
373 |
374 | self.upconv = upconv2(self.in_channels, self.out_channels,
375 | mode=self.up_mode, planar=planar, dim=dim)
376 |
377 | if self.merge_mode == 'concat':
378 | self.conv1 = conv3(
379 | 2*self.out_channels, self.out_channels, planar=planar, dim=dim, padding=padding
380 | )
381 | else:
382 |             # with merge_mode='add', conv1 takes the same number of input channels as conv2
383 | self.conv1 = conv3(
384 | self.out_channels, self.out_channels, planar=planar, dim=dim, padding=padding
385 | )
386 | self.conv2 = conv3(
387 | self.out_channels, self.out_channels, planar=planar, dim=dim, padding=padding
388 | )
389 |
390 | self.act0 = get_activation(activation)
391 | self.act1 = get_activation(activation)
392 | self.act2 = get_activation(activation)
393 |
394 | if full_norm:
395 | self.norm0 = get_normalization(normalization, self.out_channels, dim=dim)
396 | self.norm1 = get_normalization(normalization, self.out_channels, dim=dim)
397 | else:
398 | self.norm0 = nn.Identity()
399 | self.norm1 = nn.Identity()
400 | self.norm2 = get_normalization(normalization, self.out_channels, dim=dim)
401 | if attention:
402 | self.attention = GridAttention(
403 | in_channels=in_channels // 2, gating_channels=in_channels, dim=dim
404 | )
405 | else:
406 | self.attention = DummyAttention()
407 | self.att = None # Field to store attention mask for later analysis
408 |
409 | def forward(self, enc, dec):
410 | """ Forward pass
411 | Arguments:
412 | enc: Tensor from the encoder pathway
413 | dec: Tensor from the decoder pathway (to be upconv'd)
414 | """
415 |
416 | updec = self.upconv(dec)
417 | enc, updec = autocrop(enc, updec)
418 | genc, att = self.attention(enc, dec)
419 | if not torch.jit.is_scripting():
420 | self.att = att
421 | updec = self.norm0(updec)
422 | updec = self.act0(updec)
423 | if self.merge_mode == 'concat':
424 | mrg = torch.cat((updec, genc), 1)
425 | else:
426 | mrg = updec + genc
427 | y = self.conv1(mrg)
428 | y = self.norm1(y)
429 | y = self.act1(y)
430 | y = self.conv2(y)
431 | y = self.norm2(y)
432 | y = self.act2(y)
433 | return y
434 |
435 |
436 | class ResizeConv(nn.Module):
437 | """Upsamples by 2x and applies a convolution.
438 |
439 | This is meant as a replacement for transposed convolution to avoid
440 | checkerboard artifacts. See
441 |
442 | - https://distill.pub/2016/deconv-checkerboard/
443 | - https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190
444 | """
445 | def __init__(self, in_channels, out_channels, kernel_size=3, planar=False, dim=3,
446 | upsampling_mode='nearest'):
447 | super().__init__()
448 | self.upsampling_mode = upsampling_mode
449 | self.scale_factor = 2
450 | if dim == 3 and planar: # Only interpolate (H, W) dims, leave D as is
451 | self.scale_factor = planar_kernel(self.scale_factor)
452 | self.dim = dim
453 | self.upsample = nn.Upsample(scale_factor=self.scale_factor, mode=self.upsampling_mode)
454 | # TODO: Investigate if 3x3 or 1x1 conv makes more sense here and choose default accordingly
455 | # Preliminary notes:
456 | # - conv3 increases global parameter count by ~10%, compared to conv1 and is slower overall
457 | # - conv1 is the simplest way of aligning feature dimensions
458 | # - conv1 may be enough because in all common models later layers will apply conv3
459 | # eventually, which could learn to perform the same task...
460 | # But not exactly the same thing, because this layer operates on
461 | # higher-dimensional features, which subsequent layers can't access
462 | # (at least in U-Net out_channels == in_channels // 2).
463 | # --> Needs empirical evaluation
464 | if kernel_size == 3:
465 | self.conv = conv3(
466 | in_channels, out_channels, padding=1, planar=planar, dim=dim
467 | )
468 | elif kernel_size == 1:
469 | self.conv = conv1(in_channels, out_channels, dim=dim)
470 | else:
471 | raise ValueError(f'kernel_size={kernel_size} is not supported. Choose 1 or 3.')
472 |
473 | def forward(self, x):
474 | return self.conv(self.upsample(x))
475 |
476 |
477 | class GridAttention(nn.Module):
478 | """Based on https://github.com/ozan-oktay/Attention-Gated-Networks
479 |
480 | Published in https://arxiv.org/abs/1804.03999"""
481 | def __init__(self, in_channels, gating_channels, inter_channels=None, dim=3, sub_sample_factor=2):
482 | super().__init__()
483 |
484 | assert dim in [2, 3]
485 |
486 | # Downsampling rate for the input featuremap
487 | if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor
488 | elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor)
489 | else: self.sub_sample_factor = tuple([sub_sample_factor]) * dim
490 |
491 | # Default parameter set
492 | self.dim = dim
493 | self.sub_sample_kernel_size = self.sub_sample_factor
494 |
495 | # Number of channels (pixel dimensions)
496 | self.in_channels = in_channels
497 | self.gating_channels = gating_channels
498 | self.inter_channels = inter_channels
499 |
500 | if self.inter_channels is None:
501 | self.inter_channels = in_channels // 2
502 | if self.inter_channels == 0:
503 | self.inter_channels = 1
504 |
505 | if dim == 3:
506 | conv_nd = nn.Conv3d
507 | bn = nn.BatchNorm3d
508 | self.upsample_mode = 'trilinear'
509 | elif dim == 2:
510 | conv_nd = nn.Conv2d
511 | bn = nn.BatchNorm2d
512 | self.upsample_mode = 'bilinear'
513 | else:
514 | raise NotImplementedError
515 |
516 | # Output transform
517 | self.w = nn.Sequential(
518 | conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1),
519 | bn(self.in_channels),
520 | )
521 | # Theta^T * x_ij + Phi^T * gating_signal + bias
522 | self.theta = conv_nd(
523 | in_channels=self.in_channels, out_channels=self.inter_channels,
524 | kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, bias=False
525 | )
526 | self.phi = conv_nd(
527 | in_channels=self.gating_channels, out_channels=self.inter_channels,
528 | kernel_size=1, stride=1, padding=0, bias=True
529 | )
530 | self.psi = conv_nd(
531 | in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, bias=True
532 | )
533 |
534 | self.init_weights()
535 |
536 | def forward(self, x, g):
537 | # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)
538 | # phi => (b, g_d) -> (b, i_c)
539 | theta_x = self.theta(x)
540 |
541 | # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
542 | # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)
543 | phi_g = F.interpolate(self.phi(g), size=theta_x.shape[2:], mode=self.upsample_mode, align_corners=False)
544 | f = F.relu(theta_x + phi_g, inplace=True)
545 |
546 | # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
547 | sigm_psi_f = torch.sigmoid(self.psi(f))
548 |
549 | # upsample the attentions and multiply
550 | sigm_psi_f = F.interpolate(sigm_psi_f, size=x.shape[2:], mode=self.upsample_mode, align_corners=False)
551 | y = sigm_psi_f.expand_as(x) * x
552 | wy = self.w(y)
553 |
554 | return wy, sigm_psi_f
555 |
556 | def init_weights(self):
557 | def weight_init(m):
558 | classname = m.__class__.__name__
559 | if classname.find('Conv') != -1:
560 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
561 | elif classname.find('Linear') != -1:
562 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
563 | elif classname.find('BatchNorm') != -1:
564 | nn.init.normal_(m.weight.data, 1.0, 0.02)
565 | nn.init.constant_(m.bias.data, 0.0)
566 | self.apply(weight_init)
567 |
568 |
569 | class DummyAttention(nn.Module):
570 | def forward(self, x, g):
571 | return x, None
572 |
573 |
574 | # TODO: Pre-calculate output sizes when using valid convolutions
575 | class UNet(nn.Module):
576 | """Modified version of U-Net, adapted for 3D biomedical image segmentation
577 |
578 | The U-Net is a convolutional encoder-decoder neural network.
579 | Contextual spatial information (from the decoding, expansive pathway)
580 | about an input tensor is merged with information representing the
581 | localization of details (from the encoding, compressive pathway).
582 |
583 | - Original paper: https://arxiv.org/abs/1505.04597
584 | - Base implementation: https://github.com/jaxony/unet-pytorch
585 |
586 |
587 | Modifications to the original paper (@jaxony):
588 |
589 | - Padding is used in size-3-convolutions to prevent loss
590 | of border pixels.
591 |     - Merging outputs does not require cropping due to the padding above.
592 | - Residual connections can be used by specifying
593 | UNet(merge_mode='add').
594 | - If non-parametric upsampling is used in the decoder
595 | pathway (specified by upmode='upsample'), then an
596 | additional 1x1 convolution occurs after upsampling
597 | to reduce channel dimensionality by a factor of 2.
598 | This channel halving happens with the convolution in
599 |       the transpose convolution (specified by upmode='transpose').
600 |
601 | Additional modifications (@mdraw):
602 |
603 | - Operates on 3D image data (5D tensors) instead of 2D data
604 | - Uses 3D convolution, 3D pooling etc. by default
605 | - Each network block pair (the two corresponding submodules in the
606 | encoder and decoder pathways) can be configured to either work
607 | in 3D or 2D mode (3D/2D convolution, pooling etc.)
608 | with the `planar_blocks` parameter.
609 | This is helpful for dealing with data anisotropy (commonly the
610 | depth axis has lower resolution in SBEM data sets, so it is not
611 | as important for convolution/pooling) and can reduce the complexity of
612 | models (parameter counts, speed, memory usage etc.).
613 | Note: If planar blocks are used, the input patch size should be
614 | adapted by reducing depth and increasing height and width of inputs.
615 | - Configurable activation function.
616 | - Optional normalization
617 |
618 | Gradient checkpointing can be used to reduce memory consumption while
619 | training. To make use of gradient checkpointing, just run the
620 | ``forward_gradcp()`` instead of the regular ``forward`` method.
621 | This makes the backward pass a bit slower, but the memory savings can be
622 | huge (usually around 20% - 50%, depending on hyperparameters). Checkpoints
623 | are made after each network *block*.
624 | See https://pytorch.org/docs/master/checkpoint.html and
625 | https://arxiv.org/abs/1604.06174 for more details.
626 | Gradient checkpointing is not supported in TorchScript mode.
627 |
628 | Args:
629 | in_channels: Number of input channels
630 | (e.g. 1 for single-grayscale inputs, 3 for RGB images)
631 | Default: 1
632 | out_channels: Number of output channels (in classification/semantic
633 | segmentation, this is the number of different classes).
634 | Default: 2
635 | n_blocks: Number of downsampling/convolution blocks (max-pooling)
636 | in the encoder pathway. The decoder (upsampling/upconvolution)
637 | pathway will consist of `n_blocks - 1` blocks.
638 | Increasing `n_blocks` has two major effects:
639 |
640 | - The network will be deeper
641 | (n + 1 -> 4 additional convolution layers)
642 | - Since each block causes one additional downsampling, more
643 | contextual information will be available for the network,
644 | enhancing the effective visual receptive field.
645 | (n + 1 -> receptive field is approximately doubled in each
646 | dimension, except in planar blocks, in which it is only
647 | doubled in the H and W image dimensions)
648 |
649 |             **Important note**: Always make sure that the spatial shape of
650 |             your input is divisible by 2**(n_blocks - 1) in each dimension,
651 |             because otherwise concatenating the downsampled features can fail.
652 | start_filts: Number of filters for the first convolution layer.
653 | Note: The filter counts of the later layers depend on the
654 | choice of `merge_mode`.
655 | up_mode: Upsampling method in the decoder pathway.
656 | Choices:
657 |
658 |             - 'transpose': Use transposed convolution
659 | ("Upconvolution")
660 | - 'resizeconv_nearest': Use resize-convolution with nearest-
661 | neighbor interpolation, as proposed in
662 | https://distill.pub/2016/deconv-checkerboard/
663 |             - 'resizeconv_linear' (default): Same as above, but with (bi-/tri-)linear
664 | interpolation
665 | - 'resizeconv_nearest1': Like 'resizeconv_nearest', but using a
666 | light-weight 1x1 convolution layer instead of a spatial convolution
667 |             - 'resizeconv_linear1': Like 'resizeconv_linear', but using a
668 | light-weight 1x1-convolution layer instead of a spatial convolution
669 | merge_mode: How the features from the encoder pathway should
670 | be combined with the decoder features.
671 | Choices:
672 |
673 | - 'concat' (default): Concatenate feature maps along the
674 | `C` axis, doubling the number of filters each block.
675 | - 'add': Directly add feature maps (like in ResNets).
676 | The number of filters thus stays constant in each block.
677 |
678 | Note: According to https://arxiv.org/abs/1701.03056, feature
679 | concatenation ('concat') generally leads to better model
680 | accuracy than 'add' in typical medical image segmentation
681 | tasks.
682 | planar_blocks: Each number i in this sequence leads to the i-th
683 | block being a "planar" block. This means that all image
684 | operations performed in the i-th block in the encoder pathway
685 | and its corresponding decoder counterpart disregard the depth
686 | (`D`) axis and only operate in 2D (`H`, `W`).
687 | This is helpful for dealing with data anisotropy (commonly the
688 | depth axis has lower resolution in SBEM data sets, so it is
689 | not as important for convolution/pooling) and can reduce the
690 | complexity of models (parameter counts, speed, memory usage
691 | etc.).
692 | Note: If planar blocks are used, the input patch size should
693 | be adapted by reducing depth and increasing height and
694 | width of inputs.
695 | activation: Name of the non-linear activation function that should be
696 | applied after each network layer.
697 | Choices (see https://arxiv.org/abs/1505.00853 for details):
698 |
699 | - 'relu' (default)
700 | - 'leaky': Leaky ReLU (slope 0.1)
701 | - 'prelu': Parametrized ReLU. Best for training accuracy, but
702 | tends to increase overfitting.
703 | - 'rrelu': Can improve generalization at the cost of training
704 | accuracy.
705 | - Or you can pass an nn.Module instance directly, e.g.
706 | ``activation=torch.nn.ReLU()``
707 | normalization: Type of normalization that should be applied at the end
708 | of each block. Note that it is applied after the activated conv
709 | layers, not before the activation. This scheme differs from the
710 | original batch normalization paper and the BN scheme of 3D U-Net,
711 | but it delivers better results this way
712 | (see https://redd.it/67gonq).
713 | Choices:
714 |
715 |             - 'group' for group normalization with the default of 8 groups (G=8)
716 |             - 'group<G>' for group normalization with <G> groups
717 |               (e.g. 'group16' for G=16)
718 | - 'instance' for instance normalization
719 | - 'batch' for batch normalization (default)
720 | - 'none' or ``None`` for no normalization
721 | attention: If ``True``, use grid attention in the decoding pathway,
722 | as proposed in https://arxiv.org/abs/1804.03999.
723 | Default: ``False``.
724 | sigmoid_output: If ``True``, add sigmoid activation as final layer
725 | Default: ``True``.
726 | full_norm: If ``True`` (default), perform normalization after each
727 | (transposed) convolution in the network (which is what almost
728 | all published neural network architectures do).
729 | If ``False``, only normalize after the last convolution
730 | layer of each block, in order to save resources. This was also
731 | the default behavior before this option was introduced.
732 | dim: Spatial dimensionality of the network. Choices:
733 |
734 |             - 3: 3D mode. Every block fully works in 3D unless
735 | it is excluded by the ``planar_blocks`` setting.
736 | The network expects and operates on 5D input tensors
737 | (N, C, D, H, W).
738 |             - 2 (default in this version): Every block and every operation works in 2D, expecting
739 | 4D input tensors (N, C, H, W).
740 | conv_mode: Padding mode of convolutions. Choices:
741 |
742 | - 'same' (default): Use SAME-convolutions in every layer:
743 | zero-padding inputs so that all convolutions preserve spatial
744 | shapes and don't produce an offset at the boundaries.
745 | - 'valid': Use VALID-convolutions in every layer: no padding is
746 | used, so every convolution layer reduces spatial shape by 2 in
747 | each dimension. Intermediate feature maps of the encoder pathway
748 | are automatically cropped to compatible shapes so they can be
749 | merged with decoder features.
750 | Advantages:
751 |
752 | - Less resource consumption than SAME because feature maps
753 | have reduced sizes especially in deeper layers.
754 | - No "fake" data (that is, the zeros from the SAME-padding)
755 | is fed into the network. The output regions that are influenced
756 | by zero-padding naturally have worse quality, so they should
757 | be removed in post-processing if possible (see
758 | ``overlap_shape`` in :py:mod:`elektronn3.inference`).
759 | Using VALID convolutions prevents the unnecessary computation
760 | of these regions that need to be cut away anyways for
761 | high-quality tiled inference.
762 | - Avoids the issues described in https://arxiv.org/abs/1811.11718.
763 | - Since the network will not receive zero-padded inputs, it is
764 | not required to learn a robustness against artificial zeros
765 | being in the border regions of inputs. This should reduce the
766 | complexity of the learning task and allow the network to
767 | specialize better on understanding the actual, unaltered
768 |               inputs (effectively requiring fewer parameters to fit).
769 |
770 | Disadvantages:
771 |
772 | - Using this mode poses some additional constraints on input
773 | sizes and requires you to center-crop your targets,
774 | so it's harder to use in practice than the 'same' mode.
775 | - In some cases it might be preferable to get low-quality
776 | outputs at image borders as opposed to getting no outputs at
777 | the borders. Most notably this is the case if you do training
778 | and inference not on small patches, but on complete images in
779 | a single step.
780 | """
781 | def __init__(
782 | self,
783 | in_channels: int = 1,
784 | out_channels: int = 2,
785 | n_blocks: int = 3,
786 | start_filts: int = 32,
787 | up_mode: str = 'resizeconv_linear',
788 | merge_mode: str = 'concat',
789 | planar_blocks: Sequence = (),
790 | batch_norm: str = 'unset',
791 | attention: bool = False,
792 | sigmoid_output: bool = True,
793 | activation: Union[str, nn.Module] = 'relu',
794 | normalization: str = 'batch',
795 | full_norm: bool = True,
796 | dim: int = 2,
797 | conv_mode: str = 'same',
798 | ):
799 | super().__init__()
800 |
801 | if n_blocks < 1:
802 |             raise ValueError('n_blocks must be >= 1.')
803 |
804 | if dim not in {2, 3}:
805 | raise ValueError('dim has to be 2 or 3')
806 | if dim == 2 and planar_blocks != ():
807 | raise ValueError(
808 | 'If dim=2, you can\'t use planar_blocks since everything will '
809 | 'be planar (2-dimensional) anyways.\n'
810 | 'Either set dim=3 or set planar_blocks=().'
811 | )
812 | if up_mode in ('transpose', 'upsample', 'resizeconv_nearest', 'resizeconv_linear',
813 | 'resizeconv_nearest1', 'resizeconv_linear1'):
814 | self.up_mode = up_mode
815 | else:
816 | raise ValueError("\"{}\" is not a valid mode for upsampling".format(up_mode))
817 |
818 | if merge_mode in ('concat', 'add'):
819 | self.merge_mode = merge_mode
820 | else:
821 |             raise ValueError("\"{}\" is not a valid mode for "
822 |                              "merging up and down paths. "
823 |                              "Only \"concat\" and "
824 |                              "\"add\" are allowed.".format(merge_mode))
825 |
826 | # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
827 | # TODO: Remove merge_mode=add. It's just worse than concat
828 | if 'resizeconv' in self.up_mode and self.merge_mode == 'add':
829 | raise ValueError("up_mode \"resizeconv\" is incompatible "
830 | "with merge_mode \"add\" at the moment "
831 | "because it doesn't make sense to use "
832 | "nearest neighbour to reduce "
833 |                              "the number of channels (by half).")
834 |
835 | if len(planar_blocks) > n_blocks:
836 | raise ValueError('planar_blocks can\'t be longer than n_blocks.')
837 | if planar_blocks and (max(planar_blocks) >= n_blocks or min(planar_blocks) < 0):
838 | raise ValueError(
839 |                 'planar_blocks has invalid value range. All values have to be '
840 | 'block indices, meaning integers between 0 and (n_blocks - 1).'
841 | )
842 |
843 | self.out_channels = out_channels
844 | self.in_channels = in_channels
845 | self.sigmoid_output = sigmoid_output
846 | self.start_filts = start_filts
847 | self.n_blocks = n_blocks
848 | self.normalization = normalization
849 | self.attention = attention
850 | self.conv_mode = conv_mode
851 | self.activation = activation
852 | self.dim = dim
853 |
854 | self.down_convs = nn.ModuleList()
855 | self.up_convs = nn.ModuleList()
856 |
857 | if batch_norm != 'unset':
858 | raise RuntimeError(
859 | 'The `batch_norm` option has been replaced with the more general `normalization` option.\n'
860 | 'If you still want to use batch normalization, set `normalization=batch` instead.'
861 | )
862 |
863 | # Indices of blocks that should operate in 2D instead of 3D mode,
864 | # to save resources
865 | self.planar_blocks = planar_blocks
866 |
867 | # create the encoder pathway and add to a list
868 | for i in range(n_blocks):
869 | ins = self.in_channels if i == 0 else outs
870 | outs = self.start_filts * (2**i)
871 |             pooling = i < n_blocks - 1
872 | planar = i in self.planar_blocks
873 |
874 | down_conv = DownConv(
875 | ins,
876 | outs,
877 | pooling=pooling,
878 | planar=planar,
879 | activation=activation,
880 | normalization=normalization,
881 | full_norm=full_norm,
882 | dim=dim,
883 | conv_mode=conv_mode,
884 | )
885 | self.down_convs.append(down_conv)
886 |
887 | # create the decoder pathway and add to a list
888 | # - careful! decoding only requires n_blocks-1 blocks
889 | for i in range(n_blocks - 1):
890 | ins = outs
891 | outs = ins // 2
892 | planar = n_blocks - 2 - i in self.planar_blocks
893 |
894 | up_conv = UpConv(
895 | ins,
896 | outs,
897 | up_mode=up_mode,
898 | merge_mode=merge_mode,
899 | planar=planar,
900 | activation=activation,
901 | normalization=normalization,
902 | attention=attention,
903 | full_norm=full_norm,
904 | dim=dim,
905 | conv_mode=conv_mode,
906 | )
907 | self.up_convs.append(up_conv)
908 |
909 | self.conv_final = conv1(outs, self.out_channels, dim=dim)
910 |
911 | self.apply(self.weight_init)
912 |
913 | @staticmethod
914 | def weight_init(m):
915 | if isinstance(m, GridAttention):
916 | return
917 | if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):
918 | nn.init.xavier_normal_(m.weight)
919 | if getattr(m, 'bias') is not None:
920 | nn.init.constant_(m.bias, 0)
921 |
922 | def forward(self, x):
923 | encoder_outs = []
924 |
925 | # Encoder pathway, save outputs for merging
926 | i = 0 # Can't enumerate because of https://github.com/pytorch/pytorch/issues/16123
927 | for module in self.down_convs:
928 | x, before_pool = module(x)
929 | encoder_outs.append(before_pool)
930 | i += 1
931 |
932 | # Decoding by UpConv and merging with saved outputs of encoder
933 | i = 0
934 | for module in self.up_convs:
935 | before_pool = encoder_outs[-(i+2)]
936 | x = module(before_pool, x)
937 | i += 1
938 |
939 | # No softmax is used, so you need to apply it in the loss.
940 | x = self.conv_final(x)
941 | # Uncomment the following line to temporarily store output for
942 | # receptive field estimation using fornoxai/receptivefield:
943 | # self.feature_maps = [x] # Currently disabled to save memory
944 | if self.sigmoid_output:
945 | x = torch.sigmoid(x)
946 | return x
947 |
948 | @torch.jit.unused
949 | def forward_gradcp(self, x):
950 | """``forward()`` implementation with gradient checkpointing enabled.
951 | Apart from checkpointing, this behaves the same as ``forward()``."""
952 | encoder_outs = []
953 | i = 0
954 | for module in self.down_convs:
955 | x, before_pool = checkpoint(module, x)
956 | encoder_outs.append(before_pool)
957 | i += 1
958 | i = 0
959 | for module in self.up_convs:
960 | before_pool = encoder_outs[-(i+2)]
961 | x = checkpoint(module, before_pool, x)
962 | i += 1
963 | x = self.conv_final(x)
964 | # self.feature_maps = [x] # Currently disabled to save memory
965 | return x
966 |
967 |
--------------------------------------------------------------------------------
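A minimal usage sketch for the UNet defined above. The parameter values (channel counts, input resolution) are illustrative assumptions, not necessarily the settings used elsewhere in this repository; they only demonstrate the 2D mode and the gradient-checkpointed forward pass described in the class docstring.

import torch
from smplpix.unet import UNet

# Hypothetical 2D RGB-to-RGB setup (illustrative values).
model = UNet(in_channels=3, out_channels=3, n_blocks=5, start_filts=32,
             up_mode='resizeconv_linear', merge_mode='concat',
             normalization='batch', dim=2, sigmoid_output=True)

x = torch.randn(1, 3, 256, 256)   # (N, C, H, W) because dim=2; 256 is divisible by 2**(n_blocks - 1)
y = model(x)                      # -> (1, 3, 256, 256), values in (0, 1) due to the final sigmoid
print(y.shape)

# Memory-saving alternative during training (see the class docstring):
# y = model.forward_gradcp(x)
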
/smplpix/utils.py:
--------------------------------------------------------------------------------
1 | # helper functions for downloading and preprocessing SMPLpix training data
2 | #
3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021
4 | #
5 |
6 | import os
7 |
8 | def download_dropbox_url(url, filepath, chunk_size=1024):
9 |
10 | import requests
11 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
12 | r = requests.get(url, stream=True, headers=headers)
13 | with open(filepath, 'wb') as f:
14 | for chunk in r.iter_content(chunk_size=chunk_size):
15 | if chunk:
16 | f.write(chunk)
17 | return
18 |
19 | def unzip(zip_path, target_dir, remove_zip=True):
20 | import zipfile
21 | with zipfile.ZipFile(zip_path, 'r') as zip_ref:
22 | zip_ref.extractall(target_dir)
23 | if remove_zip:
24 | os.remove(zip_path)
25 | return
26 |
27 |
28 | def download_and_unzip(dropbox_url, workdir):
29 |
30 | if not os.path.exists(workdir):
31 | print("creating workdir %s" % workdir)
32 | os.makedirs(workdir)
33 |
34 | data_zip_path = os.path.join(workdir, 'data.zip')
35 | print("downloading zip from dropbox link: %s" % dropbox_url)
36 | download_dropbox_url(dropbox_url, data_zip_path)
37 | print("unzipping %s" % data_zip_path)
38 | unzip(data_zip_path, workdir)
39 |
40 | return
41 |
42 |
43 | def generate_mp4(image_dir, video_path, frame_rate=25, img_ext=None):
44 |
45 | if img_ext is None:
46 | test_img = os.listdir(image_dir)[0]
47 | img_ext = os.path.splitext(test_img)[1]
48 |
49 | ffmpeg_cmd = "ffmpeg -framerate %d -pattern_type glob " \
50 |                  "-i \'%s/*%s\' -vcodec h264 -an -b:v 1M -pix_fmt yuv420p \'%s\'" % \
51 | (frame_rate, image_dir, img_ext, video_path)
52 |
53 | print("executing %s" % ffmpeg_cmd)
54 | exit_code = os.system(ffmpeg_cmd)
55 |
56 | if exit_code != 0:
57 |         print("Something went wrong during video generation. Make sure you have the ffmpeg tool installed.")
58 |
59 | return exit_code
--------------------------------------------------------------------------------
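A short usage sketch for the helpers above. The Dropbox URL, directories, and file names are placeholders for illustration only, and generate_mp4 additionally requires the ffmpeg command-line tool to be installed.

from smplpix.utils import download_and_unzip, generate_mp4

workdir = './smplpix_data'                                       # hypothetical working directory
dropbox_url = 'https://www.dropbox.com/s/EXAMPLE/data.zip?dl=1'  # placeholder link

# download the archive to workdir/data.zip and extract it there
download_and_unzip(dropbox_url, workdir)

# encode a directory of rendered frames into an mp4 at 25 fps
exit_code = generate_mp4(image_dir='./renders', video_path='./renders.mp4', frame_rate=25)
if exit_code != 0:
    print("ffmpeg failed; check that it is installed and on the PATH")
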
/smplpix/vgg.py:
--------------------------------------------------------------------------------
1 | # extract perceptual features from the pre-trained Vgg16 network
2 | # these features are used for the perceptual loss function (https://arxiv.org/abs/1603.08155)
3 | #
4 | # based on the code snippet of W. Falcon:
5 | # https://gist.github.com/williamFalcon/1ee773c159ff5d76d47518653369d890
6 |
7 | import torch
8 | from torchvision import models
9 |
10 | class Vgg16Features(torch.nn.Module):
11 |
12 | def __init__(self,
13 | requires_grad=False,
14 | layers_weights=None):
15 | super(Vgg16Features, self).__init__()
16 | if layers_weights is None:
17 | self.layers_weights = [1 / 32, 1 / 16, 1 / 8, 1 / 4, 1]
18 | else:
19 | self.layers_weights = layers_weights
20 |
21 | vgg_pretrained_features = models.vgg16(pretrained=True).features
22 | self.slice1 = torch.nn.Sequential()
23 | self.slice2 = torch.nn.Sequential()
24 | self.slice3 = torch.nn.Sequential()
25 | self.slice4 = torch.nn.Sequential()
26 | for x in range(4):
27 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
28 | for x in range(4, 9):
29 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
30 | for x in range(9, 16):
31 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
32 | for x in range(16, 23):
33 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
34 | if not requires_grad:
35 | for param in self.parameters():
36 | param.requires_grad = False
37 |
38 | def forward(self, x):
39 |
40 | h_0 = x.flatten(start_dim=1)
41 | h = self.slice1(x)
42 | h_relu1_2 = h.flatten(start_dim=1)
43 | h = self.slice2(h)
44 | h_relu2_2 = h.flatten(start_dim=1)
45 | h = self.slice3(h)
46 | h_relu3_3 = h.flatten(start_dim=1)
47 | h = self.slice4(h)
48 | h_relu4_3 = h.flatten(start_dim=1)
49 |
50 | h = torch.cat([self.layers_weights[0] * h_0,
51 | self.layers_weights[1] * h_relu1_2,
52 | self.layers_weights[2] * h_relu2_2,
53 | self.layers_weights[3] * h_relu3_3,
54 | self.layers_weights[4] * h_relu4_3], 1)
55 |
56 | return h
57 |
--------------------------------------------------------------------------------
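A minimal sketch of how Vgg16Features could drive a perceptual (feature-matching) loss in the spirit of https://arxiv.org/abs/1603.08155. The loss function and tensors below are illustrative assumptions, not taken from this repository's training code.

import torch
import torch.nn.functional as F
from smplpix.vgg import Vgg16Features

vgg = Vgg16Features(requires_grad=False).eval()  # loads ImageNet-pretrained VGG16 weights

def perceptual_l1_loss(pred, target):
    # pred, target: (N, 3, H, W) image batches; Vgg16Features.forward already weights and
    # concatenates all layer activations, so one L1 over its output compares every layer.
    return F.l1_loss(vgg(pred), vgg(target))

pred = torch.rand(2, 3, 224, 224, requires_grad=True)   # e.g. generator output
target = torch.rand(2, 3, 224, 224)                     # ground-truth frames
loss = perceptual_l1_loss(pred, target)
loss.backward()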