├── .gitignore ├── LICENSE ├── README.md ├── colab_notebooks ├── Convert_Video_to_SMPLpix_Dataset.ipynb └── SMPLpix_training.ipynb ├── setup.py └── smplpix ├── __init__.py ├── args.py ├── dataset.py ├── eval.py ├── train.py ├── training.py ├── unet.py ├── utils.py └── vgg.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 - now, Sergey Prokudin (sergey.prokudin@gmail.com) 2 | 3 | Redistribution and use of this software and associated documentation files (the "Software"), with or 4 | without modification, are permitted provided that the following conditions are met: 5 | 6 | * The above copyright notice and this permission notice shall be included in all copies or substantial 7 | portions of the Software. 8 | 9 | * Any use for commercial, pornographic, military, or surveillance, purposes is prohibited. 
The Software 10 | may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding 11 | analyses in peer-reviewed scientific research. 12 | 13 | For commercial uses of the Software, please send email to ps-license@tue.mpg.de. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT 16 | LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 17 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 18 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 19 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ![116595226-621dc600-a923-11eb-85d4-f52a5b9f265c](https://user-images.githubusercontent.com/8117267/133928409-d4576fae-e23d-4d10-9f3b-64855690bc6a.gif) 3 | 4 | _**Left**: [SMPL-X](https://smpl-x.is.tue.mpg.de/) human mesh registered with [SMPLify-X](https://smpl-x.is.tue.mpg.de/), **middle**: SMPLpix render, **right**: ground truth. [Video](https://user-images.githubusercontent.com/8117267/116540639-b9537480-a8ea-11eb-81ca-57d473147fbd.mp4)._ 5 | 6 | 7 | # SMPLpix: Neural Avatars from 3D Human Models 8 | 9 | *SMPLpix* neural rendering framework combines deformable 3D models such as [SMPL-X](https://smpl-x.is.tue.mpg.de/) 10 | with the power of image-to-image translation frameworks (aka [pix2pix](https://phillipi.github.io/pix2pix/) models). 11 | 12 | Please check our [WACV 2021 paper](https://arxiv.org/abs/2008.06872) or a [5-minute explanatory video](https://www.youtube.com/watch?v=JY9t4xUAouk) for more details on the framework. 13 | 14 | _**Important note**_: this repository is a re-implementation of the original framework, made by the same author after the end of internship. 15 | It **does not contain** the original Amazon multi-subject, multi-view training data and code, and uses full mesh rasterizations as inputs rather than point projections (as described [here](https://youtu.be/JY9t4xUAouk?t=241)). 16 | 17 | 18 | ## Demo 19 | 20 | | Description | Link | 21 | | ----------- | ----------- | 22 | | Process a video into a SMPLpix dataset| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github//sergeyprokudin/smplpix/blob/main/colab_notebooks/Convert_Video_to_SMPLpix_Dataset.ipynb)| 23 | | Train SMPLpix| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sergeyprokudin/smplpix/blob/main/colab_notebooks/SMPLpix_training.ipynb)| 24 | 25 | ### Prepare the data 26 | 27 | ![demo_openpose_simplifyx](https://user-images.githubusercontent.com/8117267/116876711-8defc300-ac25-11eb-8b7b-5eab8860602c.png) 28 | 29 | We provide the Colab notebook for preparing SMPLpix training dataset. This will allow you 30 | to create your own neural avatar given [monocular video of a human moving in front of the camera](https://www.dropbox.com/s/rjqwf894ovso218/smplpix_test_video_na.mp4?dl=0). 
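After processing, the notebook packs the extracted frames into paired folders: each `input` image is an SMPL-X mesh render and the corresponding `output` image is the ground truth video frame with the same file name. As a rough sketch of the result (exact file names depend on your video and the FPS chosen during frame extraction, e.g. `00001.png` from the default ffmpeg naming), the produced archive unpacks to:

```
smplpix_data/
├── train/
│   ├── input/      # SMPL-X renders, e.g. 00001.png
│   └── output/     # matching ground truth frames, e.g. 00001.png
├── validation/
│   ├── input/
│   └── output/
└── test/
    ├── input/
    └── output/
```

Input and output images are matched by file name, so every render in `input` needs a target frame with the same name in `output`.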
31 | 32 | ### Run demo training 33 | 34 | We provide some preprocessed data which allows you to run and test the training pipeline right away: 35 | 36 | ``` 37 | git clone https://github.com/sergeyprokudin/smplpix 38 | cd smplpix 39 | python setup.py install 40 | python smplpix/train.py --workdir='/content/smplpix_logs/' \ 41 | --data_url='https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0' 42 | ``` 43 | 44 | ### Train on your own data 45 | 46 | You can train SMPLpix on your own data by specifying the path to the root directory with data: 47 | 48 | ``` 49 | python smplpix/train.py --workdir='/content/smplpix_logs/' \ 50 | --data_dir='/path/to/data' 51 | ``` 52 | 53 | The directory should contain train, validation and test folders, each of which should contain input and output folders. Check the structure of [the demo dataset](https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0) for reference. 54 | 55 | You can also specify various parameters of training via command line. E.g., to reproduce the results of the demo video: 56 | 57 | ``` 58 | python smplpix/train.py --workdir='/content/smplpix_logs/' \ 59 | --data_url='https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0' \ 60 | --downsample_factor=2 \ 61 | --n_epochs=500 \ 62 | --sched_patience=2 \ 63 | --batch_size=4 \ 64 | --n_unet_blocks=5 \ 65 | --n_input_channels=3 \ 66 | --n_output_channels=3 \ 67 | --eval_every_nth_epoch=10 68 | ``` 69 | 70 | Check the [args.py](https://github.com/sergeyprokudin/smplpix/blob/main/smplpix/args.py) for the full list of parameters. 71 | 72 | ## More examples 73 | 74 | ### Animating with novel poses 75 | 76 | ![116546566-0edf4f80-a8f2-11eb-9fb2-a173c0018a4e](https://user-images.githubusercontent.com/8117267/134176955-cc2d75ed-07dc-43f1-adce-c4dfdcc0925f.gif) 77 | 78 | **Left**: poses from the test video sequence, **right**: SMPLpix renders. [Video](https://user-images.githubusercontent.com/8117267/116546566-0edf4f80-a8f2-11eb-9fb2-a173c0018a4e.mp4). 79 | 80 | 81 | ### Rendering faces 82 | 83 | ![116543423-23214d80-a8ee-11eb-9ded-86af17c56549](https://user-images.githubusercontent.com/8117267/134175773-32885d04-32f4-4ff6-a3bb-fddb96575ba4.gif) 84 | 85 | 86 | _**Left**: [FLAME](https://flame.is.tue.mpg.de/) face model inferred with [DECA](https://github.com/YadiraF/DECA), **middle**: ground truth test video, **right**: SMPLpix render. [Video](https://user-images.githubusercontent.com/8117267/116543423-23214d80-a8ee-11eb-9ded-86af17c56549.mp4)._ 87 | 88 | Thanks to [Maria Paola Forte](https://www.is.mpg.de/~Forte) for providing the sequence. 89 | 90 | ### Few-shot artistic neural style transfer 91 | 92 | ![116544826-e9514680-a8ef-11eb-8682-0ea8d19d0d5e](https://user-images.githubusercontent.com/8117267/134177512-22d52204-e3ae-48bd-a4d1-8b6fc1914fe5.gif) 93 | 94 | 95 | _**Left**: rendered [AMASS](https://amass.is.tue.mpg.de/) motion sequence, **right**: generated SMPLpix animations. [Full video](https://user-images.githubusercontent.com/8117267/116544826-e9514680-a8ef-11eb-8682-0ea8d19d0d5e.mp4). See [the explanatory video](https://youtu.be/JY9t4xUAouk?t=255) for details._ 96 | 97 | Credits to [Alexander Kabarov](mailto:blackocher@gmail.com) for providing the training sketches. 
98 | 99 | ## Citation 100 | 101 | If you find our work useful in your research, please consider citing: 102 | ``` 103 | @inproceedings{prokudin2021smplpix, 104 | title={SMPLpix: Neural Avatars from 3D Human Models}, 105 | author={Prokudin, Sergey and Black, Michael J and Romero, Javier}, 106 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 107 | pages={1810--1819}, 108 | year={2021} 109 | } 110 | ``` 111 | 112 | ## License 113 | 114 | See the [LICENSE](https://github.com/sergeyprokudin/smplpix/blob/main/LICENSE) file. 115 | 116 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=sergeyprokudin/smplpix) 117 | 118 | -------------------------------------------------------------------------------- /colab_notebooks/Convert_Video_to_SMPLpix_Dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Convert Video to SMPLpix Dataset.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "cbXoNhFF-D8Q" 10 | ], 11 | "toc_visible": true, 12 | "machine_shape": "hm" 13 | }, 14 | "kernelspec": { 15 | "name": "python3", 16 | "display_name": "Python 3" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "saNBv0dY-Eef" 25 | }, 26 | "source": [ 27 | "![](https://user-images.githubusercontent.com/8117267/116876711-8defc300-ac25-11eb-8b7b-5eab8860602c.png)\n", 28 | "\n", 29 | "# SMPLpix Dataset Preparation.\n", 30 | "\n", 31 | "**Author**: [Sergey Prokudin](https://ps.is.mpg.de/people/sprokudin). \n", 32 | "[[Project Page](https://sergeyprokudin.github.io/smplpix/)]\n", 33 | "[[Paper](https://arxiv.org/pdf/2008.06872.pdf)]\n", 34 | "[[Video](https://www.youtube.com/watch?v=JY9t4xUAouk)]\n", 35 | "[[GitHub](https://github.com/sergeyprokudin/smplpix)]\n", 36 | "\n", 37 | "This notebook contains an example workflow for converting a video file to a SMPLpix dataset. \n", 38 | "### Processing steps:\n", 39 | "\n", 40 | "1. Download the video of choice, extract frames;\n", 41 | "2. Extract **2D keypoints**: run [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose) on the extracted frames;\n", 42 | "3. Infer **3D human meshes**: run [SMPLify-x](https://github.com/vchoutas/smplify-x) on the extracted frames + keypoints;\n", 43 | "4. Form dataset **image pairs**, where *input* is SMPL-X mesh render, and *output* is the corresponding target ground truth video frame;\n", 44 | "5. **Split the data** into train, test and validation, zip and copy to Google Drive.\n", 45 | "\n", 46 | "### Instructions\n", 47 | "\n", 48 | "1. Convert a video into our dataset format using this notebook.\n", 49 | "2. Train a SMPLpix using the training notebook.\n", 50 | "\n", 51 | "\n", 52 | "### Notes\n", 53 | "* While this will work for small datasets in a Colab runtime, larger datasets will require more compute power;\n", 54 | "* If you would like to train a model on a serious dataset, you should consider copying this to your own workstation and running it there. 
Some minor modifications will be required, and you will have to install the dependencies separately;\n", 55 | "* Please report issues on the [GitHub issue tracker](https://github.com/sergeyprokudin/smplpix/issues).\n", 56 | "\n", 57 | "If you find this work useful, please consider citing:\n", 58 | "```bibtex\n", 59 | "@inproceedings{prokudin2021smplpix,\n", 60 | " title={SMPLpix: Neural Avatars from 3D Human Models},\n", 61 | " author={Prokudin, Sergey and Black, Michael J and Romero, Javier},\n", 62 | " booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},\n", 63 | " pages={1810--1819},\n", 64 | " year={2021}\n", 65 | "}\n", 66 | "```\n", 67 | "\n", 68 | "Many thanks [Keunhong Park](https://keunhong.com) for providing the [Nerfie dataset preparation template](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb)!\n" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": { 74 | "id": "cbXoNhFF-D8Q" 75 | }, 76 | "source": [ 77 | "## Upload the video and extract frames\n" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "metadata": { 83 | "id": "K0EmRUXf5nTr", 84 | "cellView": "form" 85 | }, 86 | "source": [ 87 | "# @title Upload a video file (.mp4, .mov, etc.) from your disk, Dropbox, Google Drive or YouTube\n", 88 | "\n", 89 | "# @markdown This will upload it to the local Colab working directory. You can use a demo video to test the pipeline. The background in the demo was removed with the [Unscreen](https://www.unscreen.com/) service. Alternatively, you can try [PointRend](https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend) segmentation for this purpose. \n", 90 | "\n", 91 | "import os\n", 92 | "from google.colab import files\n", 93 | "\n", 94 | "def download_youtube_video(img_url, save_path, resolution_id=-3):\n", 95 | "\n", 96 | " print(\"downloading the video: %s\" % img_url)\n", 97 | " res_path = YouTube(img_url).streams.order_by('resolution')[resolution_id].download(save_path)\n", 98 | "\n", 99 | " return res_path\n", 100 | "\n", 101 | "def download_dropbox_url(url, filepath, chunk_size=1024):\n", 102 | "\n", 103 | " import requests\n", 104 | " headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}\n", 105 | " r = requests.get(url, stream=True, headers=headers)\n", 106 | " with open(filepath, 'wb') as f:\n", 107 | " for chunk in r.iter_content(chunk_size=chunk_size):\n", 108 | " if chunk:\n", 109 | " f.write(chunk)\n", 110 | " return filepath\n", 111 | "\n", 112 | "!rm -rf /content/data\n", 113 | "\n", 114 | "from google.colab import drive\n", 115 | "drive.mount('/content/gdrive')\n", 116 | "\n", 117 | "VIDEO_SOURCE = 'dropbox' #@param [\"youtube\", \"upload\", \"google drive\", \"dropbox\"]\n", 118 | "\n", 119 | "if VIDEO_SOURCE == 'dropbox':\n", 120 | " DROPBOX_URL = 'https://www.dropbox.com/s/rjqwf894ovso218/smplpix_test_video_na.mp4?dl=0' #@param \n", 121 | " VIDEO_PATH = '/content/video.mp4' \n", 122 | " download_dropbox_url(DROPBOX_URL, VIDEO_PATH)\n", 123 | "elif VIDEO_SOURCE == 'upload':\n", 124 | " print(\"Please upload the video: \")\n", 125 | " uploaded = files.upload()\n", 126 | " VIDEO_PATH = os.path.join('/content', list(uploaded.keys())[0])\n", 127 | "elif VIDEO_SOURCE == 'youtube':\n", 128 | " !pip install pytube\n", 129 | " from pytube import YouTube\n", 130 | " YOTUBE_VIDEO_URL = '' #@param \n", 131 | " VIDEO_PATH = download_youtube_video(YOTUBE_VIDEO_URL, '/content/')\n", 132 | "elif VIDEO_SOURCE == 'google drive':\n", 133 | " from 
google.colab import drive\n", 134 | " drive.mount('/content/gdrive')\n", 135 | " GOOGLE_DRIVE_PATH = '' #@param \n", 136 | " VIDEO_PATH = GOOGLE_DRIVE_PATH\n", 137 | "\n", 138 | "\n", 139 | "print(\"video is uploaded to %s\" % VIDEO_PATH)" 140 | ], 141 | "execution_count": null, 142 | "outputs": [] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "metadata": { 147 | "id": "VImnPzFg9UNA", 148 | "cellView": "form" 149 | }, 150 | "source": [ 151 | "# @title Flatten the video into frames\n", 152 | "\n", 153 | "\n", 154 | "FPS = 1# @param {type:'number'}\n", 155 | "\n", 156 | "# @markdown _Note_: for longer videos, it might make sense to decrease the FPS as it will take 30-60 seconds for SMPLify-X framework to process every frame.\n", 157 | "\n", 158 | "\n", 159 | "\n", 160 | "RES_DIR = '/content/data'\n", 161 | "FRAMES_DIR = os.path.join(RES_DIR, 'images')\n", 162 | "!rm -rf $RES_DIR\n", 163 | "!mkdir $RES_DIR\n", 164 | "!mkdir $FRAMES_DIR\n", 165 | "!ffmpeg -i \"$VIDEO_PATH\" -vf fps=$FPS -qscale:v 2 '$FRAMES_DIR/%05d.png'\n", 166 | "\n", 167 | "from PIL import Image\n", 168 | "import numpy as np\n", 169 | "import matplotlib.pyplot as plt\n", 170 | "\n", 171 | "def load_img(img_path):\n", 172 | "\n", 173 | " return np.asarray(Image.open(img_path))/255\n", 174 | "\n", 175 | "test_img_path = os.path.join(FRAMES_DIR, os.listdir(FRAMES_DIR)[0])\n", 176 | "\n", 177 | "test_img = load_img(test_img_path)\n", 178 | "\n", 179 | "plt.figure(figsize=(5, 10))\n", 180 | "plt.title(\"extracted image example\")\n", 181 | "plt.imshow(test_img)" 182 | ], 183 | "execution_count": null, 184 | "outputs": [] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": { 189 | "id": "7Z-ASlgBUPXJ" 190 | }, 191 | "source": [ 192 | "## Extract 2D body keypoints with OpenPose\n", 193 | "\n" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "metadata": { 199 | "id": "1AL4QpsBUO9p", 200 | "cellView": "form" 201 | }, 202 | "source": [ 203 | "# @title Install OpenPose\n", 204 | "# @markdown This will take some time (~10 mins). 
The code is taken from this [OpenPose Colab notebook](https://colab.research.google.com/github/tugstugi/dl-colab-notebooks/blob/master/notebooks/OpenPose.ipynb).\n", 205 | "\n", 206 | "%cd /content\n", 207 | "import os\n", 208 | "from os.path import exists, join, basename, splitext\n", 209 | "\n", 210 | "git_repo_url = 'https://github.com/CMU-Perceptual-Computing-Lab/openpose.git'\n", 211 | "project_name = splitext(basename(git_repo_url))[0]\n", 212 | "if not exists(project_name):\n", 213 | " # see: https://github.com/CMU-Perceptual-Computing-Lab/openpose/issues/949\n", 214 | " # install new CMake becaue of CUDA10\n", 215 | " !wget -q https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.tar.gz\n", 216 | " !tar xfz cmake-3.13.0-Linux-x86_64.tar.gz --strip-components=1 -C /usr/local\n", 217 | "\n", 218 | " # clone openpose\n", 219 | " !git clone -q --depth 1 $git_repo_url\n", 220 | " !sed -i 's/execute_process(COMMAND git checkout master WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}\\/3rdparty\\/caffe)/execute_process(COMMAND git checkout f019d0dfe86f49d1140961f8c7dec22130c83154 WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}\\/3rdparty\\/caffe)/g' openpose/CMakeLists.txt\n", 221 | " # install system dependencies\n", 222 | " !apt-get -qq install -y libatlas-base-dev libprotobuf-dev libleveldb-dev libsnappy-dev libhdf5-serial-dev protobuf-compiler libgflags-dev libgoogle-glog-dev liblmdb-dev opencl-headers ocl-icd-opencl-dev libviennacl-dev\n", 223 | " # install python dependencies\n", 224 | " !pip install -q youtube-dl\n", 225 | " # build openpose\n", 226 | " !cd openpose && rm -rf build || true && mkdir build && cd build && cmake .. && make -j`nproc`\n" 227 | ], 228 | "execution_count": null, 229 | "outputs": [] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "metadata": { 234 | "id": "5NR5OGyeUOKU", 235 | "cellView": "form" 236 | }, 237 | "source": [ 238 | "# @title Run OpenPose on the extracted frames\n", 239 | "%cd /content\n", 240 | "KEYPOINTS_DIR = os.path.join(RES_DIR, 'keypoints')\n", 241 | "OPENPOSE_IMAGES_DIR = os.path.join(RES_DIR, 'openpose_images')\n", 242 | "!mkdir $KEYPOINTS_DIR\n", 243 | "!mkdir $OPENPOSE_IMAGES_DIR\n", 244 | "\n", 245 | "!cd openpose && ./build/examples/openpose/openpose.bin --image_dir $FRAMES_DIR --write_json $KEYPOINTS_DIR --face --hand --display 0 --write_images $OPENPOSE_IMAGES_DIR\n", 246 | "\n", 247 | "input_img_path = os.path.join(FRAMES_DIR, sorted(os.listdir(FRAMES_DIR))[0])\n", 248 | "openpose_img_path = os.path.join(OPENPOSE_IMAGES_DIR, sorted(os.listdir(OPENPOSE_IMAGES_DIR))[0])\n", 249 | "\n", 250 | "test_img = load_img(input_img_path)\n", 251 | "open_pose_img = load_img(openpose_img_path)\n", 252 | "\n", 253 | "plt.figure(figsize=(10, 10))\n", 254 | "plt.title(\"Input Frame + Openpose Prediction\")\n", 255 | "plt.imshow(np.concatenate([test_img, open_pose_img], 1))" 256 | ], 257 | "execution_count": null, 258 | "outputs": [] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": { 263 | "id": "to4QpKLFHf2s" 264 | }, 265 | "source": [ 266 | "\n", 267 | "## Infer 3D Human Model with [SMPLify-X](https://smpl-x.is.tue.mpg.de/)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "metadata": { 273 | "id": "SFzPpUoM99nd", 274 | "cellView": "form" 275 | }, 276 | "source": [ 277 | "# @title Install SMPLify-X and other dependencies\n", 278 | "\n", 279 | "%cd /content\n", 280 | "!pip install chumpy\n", 281 | "!pip install smplx\n", 282 | "!git clone https://github.com/vchoutas/smplx\n", 283 | "%cd smplx\n", 284 | "!python setup.py install\n", 285 | 
"\n", 286 | "#vposer\n", 287 | "!pip install git+https://github.com/nghorbani/configer\n", 288 | "!pip install git+https://github.com/sergeyprokudin/human_body_prior\n", 289 | "\n", 290 | "!pip install torch==1.1.0\n", 291 | "%cd /content\n", 292 | "!git clone https://github.com/sergeyprokudin/smplify-x\n", 293 | "%cd /content/smplify-x\n", 294 | "!pip install -r requirements.txt" 295 | ], 296 | "execution_count": null, 297 | "outputs": [] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "metadata": { 302 | "id": "aSPgjzkFURq-", 303 | "cellView": "form" 304 | }, 305 | "source": [ 306 | "# @title Upload the SMPL-X model files\n", 307 | "\n", 308 | "# @markdown Proceed to the [official website](https://smpl-x.is.tue.mpg.de/), register and download the zip files with SMPL-X (**models_smplx_v1_1.zip**, ~830MB) and VPoser (**vposer_v1_0.zip**, ~2.5MB) models from the **Downloads** section. \n", 309 | "# @markdown\n", 310 | "\n", 311 | "# @markdown Since uploading large zip files to Colab is relatively slow, we expect you to upload these files to Google Drive instead, link gdrive to the Colab file systems and modify **SMPLX_ZIP_PATH** and **VPOSER_ZIP_PATH** variables accordingly.\n", 312 | "\n", 313 | "%cd /content/\n", 314 | "from google.colab import drive\n", 315 | "drive.mount('/content/gdrive')\n", 316 | "\n", 317 | "SMPLX_ZIP_PATH = '/content/gdrive/MyDrive/datasets/models_smplx_v1_1.zip' # @param {type:\"string\"}\n", 318 | "VPOSER_ZIP_PATH = '/content/gdrive/MyDrive/datasets/vposer_v1_0.zip' # @param {type:\"string\"}\n", 319 | "\n", 320 | "SMPLX_MODEL_PATH = '/content/smplx'\n", 321 | "!mkdir $SMPLX_MODEL_PATH\n", 322 | "!unzip -n '$SMPLX_ZIP_PATH' -d $SMPLX_MODEL_PATH\n", 323 | "VPOSER_MODEL_PATH = '/content/vposer'\n", 324 | "!mkdir $VPOSER_MODEL_PATH\n", 325 | "!unzip -n '$VPOSER_ZIP_PATH' -d $VPOSER_MODEL_PATH" 326 | ], 327 | "execution_count": null, 328 | "outputs": [] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "metadata": { 333 | "id": "ODeGAyGrrIov", 334 | "cellView": "form" 335 | }, 336 | "source": [ 337 | "# @title Run SMPLify-X\n", 338 | "\n", 339 | "# @markdown Please select gender of the SMPL-X model:\n", 340 | "\n", 341 | "gender = 'male' #@param [\"neutral\", \"female\", \"male\"]\n", 342 | "\n", 343 | "# @markdown Please keep in mind that estimating 3D body with SMPLify-X framework will take ~30-60 secs, so processing long videos at high FPS might take a long time.\n", 344 | "\n", 345 | "!rm -rf /content/data/smplifyx_results\n", 346 | "%cd /content/smplify-x\n", 347 | "!git pull origin\n", 348 | "!python smplifyx/main.py --config cfg_files/fit_smplx.yaml \\\n", 349 | " --data_folder /content/data \\\n", 350 | " --output_folder /content/data/smplifyx_results \\\n", 351 | " --visualize=True \\\n", 352 | " --gender=$gender \\\n", 353 | " --model_folder /content/smplx/models \\\n", 354 | " --vposer_ckpt /content/vposer/vposer_v1_0 \\\n", 355 | " --part_segm_fn smplx_parts_segm.pkl " 356 | ], 357 | "execution_count": null, 358 | "outputs": [] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "metadata": { 363 | "id": "qGcmsijSQ5jd", 364 | "cellView": "form" 365 | }, 366 | "source": [ 367 | "# @title Make train-test-validation splits, copy data to final folders\n", 368 | "\n", 369 | "import shutil\n", 370 | "\n", 371 | "train_ratio = 0.9 #@param\n", 372 | " \n", 373 | "final_zip_path = '/content/gdrive/MyDrive/datasets/smplpix_data_test.zip' # @param {type:\"string\"}\n", 374 | "\n", 375 | "\n", 376 | "target_images_path = 
'/content/data/smplifyx_results/input_images'\n", 377 | "smplifyx_renders = '/content/data/smplifyx_results/rendered_smplifyx_meshes'\n", 378 | "\n", 379 | "smplpix_data_path = '/content/smplpix_data'\n", 380 | "\n", 381 | "train_input_dir = os.path.join(smplpix_data_path, 'train', 'input')\n", 382 | "train_output_dir = os.path.join(smplpix_data_path, 'train', 'output')\n", 383 | "val_input_dir = os.path.join(smplpix_data_path, 'validation', 'input')\n", 384 | "val_output_dir = os.path.join(smplpix_data_path, 'validation', 'output')\n", 385 | "test_input_dir = os.path.join(smplpix_data_path, 'test', 'input')\n", 386 | "test_output_dir = os.path.join(smplpix_data_path, 'test', 'output')\n", 387 | "\n", 388 | "!mkdir -p $train_input_dir\n", 389 | "!mkdir -p $train_output_dir\n", 390 | "!mkdir -p $val_input_dir\n", 391 | "!mkdir -p $val_output_dir\n", 392 | "!mkdir -p $test_input_dir\n", 393 | "!mkdir -p $test_output_dir\n", 394 | "\n", 395 | "img_names = sorted(os.listdir(target_images_path))\n", 396 | "n_images = len(img_names)\n", 397 | "n_train_images = int(n_images * train_ratio)\n", 398 | "n_val_images = int(n_images * (1-train_ratio) / 2)\n", 399 | "train_images = img_names[0:n_train_images]\n", 400 | "val_images = img_names[n_train_images:n_train_images+n_val_images]\n", 401 | "test_images = img_names[n_train_images:]\n", 402 | "\n", 403 | "for img in train_images:\n", 404 | " shutil.copy(os.path.join(smplifyx_renders, img), train_input_dir)\n", 405 | " shutil.copy(os.path.join(target_images_path, img), train_output_dir)\n", 406 | "\n", 407 | "for img in val_images:\n", 408 | " shutil.copy(os.path.join(smplifyx_renders, img), val_input_dir)\n", 409 | " shutil.copy(os.path.join(target_images_path, img), val_output_dir)\n", 410 | "\n", 411 | "for img in test_images:\n", 412 | " shutil.copy(os.path.join(smplifyx_renders, img), test_input_dir)\n", 413 | " shutil.copy(os.path.join(target_images_path, img), test_output_dir)\n", 414 | "\n", 415 | "\n", 416 | "%cd /content\n", 417 | "!zip -r $final_zip_path smplpix_data/" 418 | ], 419 | "execution_count": null, 420 | "outputs": [] 421 | } 422 | ] 423 | } 424 | -------------------------------------------------------------------------------- /colab_notebooks/SMPLpix_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "SMPLpix training.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "machine_shape": "hm" 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "name": "python3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "w3rYnZGzpAkn" 25 | }, 26 | "source": [ 27 | "![](https://user-images.githubusercontent.com/8117267/116595226-621dc600-a923-11eb-85d4-f52a5b9f265c.gif)\n", 28 | "###SMPLpix Training Demo.\n", 29 | "\n", 30 | "**Author**: [Sergey Prokudin](https://ps.is.mpg.de/people/sprokudin). \n", 31 | "[[Project Page](https://sergeyprokudin.github.io/smplpix/)]\n", 32 | "[[Paper](https://arxiv.org/pdf/2008.06872.pdf)]\n", 33 | "[[Video](https://www.youtube.com/watch?v=JY9t4xUAouk)]\n", 34 | "[[GitHub](https://github.com/sergeyprokudin/smplpix)]\n", 35 | "\n", 36 | "This notebook contains an example of training script for SMPLpix rendering module. 
\n", 37 | "\n", 38 | "To prepare the data with your own video, use the [SMPLpix dataset preparation notebook](https://colab.research.google.com/github//sergeyprokudin/smplpix/blob/main/colab_notebooks/Convert_Video_to_SMPLpix_Dataset.ipynb)." 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "2uwGj2Zfz-2S" 45 | }, 46 | "source": [ 47 | "### Install the SMPLpix framework:\n", 48 | "\n" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "metadata": { 54 | "id": "ZEiojUz_qaqp" 55 | }, 56 | "source": [ 57 | "!git clone https://github.com/sergeyprokudin/smplpix\n", 58 | "%cd /content/smplpix\n", 59 | "!python setup.py install" 60 | ], 61 | "execution_count": null, 62 | "outputs": [] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": { 67 | "id": "xsWo4DAdyIxZ" 68 | }, 69 | "source": [ 70 | "### To train the model on the provided [demo dataset](https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0), simply run:" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "metadata": { 76 | "id": "NdhVaGJzwnM2" 77 | }, 78 | "source": [ 79 | "!cd /content/smplpix\n", 80 | "!python smplpix/train.py \\\n", 81 | " --workdir='/content/smplpix_logs/' \\\n", 82 | " --resume_training=0 \\\n", 83 | " --data_url='https://www.dropbox.com/s/coapl05ahqalh09/smplpix_data_test_final.zip?dl=0' \\\n", 84 | " --downsample_factor=4 \\\n", 85 | " --n_epochs=200 \\\n", 86 | " --sched_patience=2 \\\n", 87 | " --eval_every_nth_epoch=10 \\\n", 88 | " --batch_size=4 \\\n", 89 | " --learning_rate=1.0e-3 \\\n", 90 | " --n_unet_blocks=5 \\\n", 91 | " --aug_prob=0.0" 92 | ], 93 | "execution_count": null, 94 | "outputs": [] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": { 99 | "id": "bO4OIRJJpDAC" 100 | }, 101 | "source": [ 102 | "### Optionally, we can continue training and finetune the network with lower learning rate:" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "metadata": { 108 | "id": "HoWDXi3SQByO" 109 | }, 110 | "source": [ 111 | "!python smplpix/train.py \\\n", 112 | " --workdir='/content/smplpix_logs/' \\\n", 113 | " --resume_training=1 \\\n", 114 | " --data_dir='/content/smplpix_logs/smplpix_data' \\\n", 115 | " --downsample_factor=4 \\\n", 116 | " --n_epochs=50 \\\n", 117 | " --sched_patience=2 \\\n", 118 | " --eval_every_nth_epoch=10 \\\n", 119 | " --batch_size=4 \\\n", 120 | " --learning_rate=1.0e-4 \\\n", 121 | " --n_unet_blocks=5" 122 | ], 123 | "execution_count": null, 124 | "outputs": [] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "id": "fhKhReRYM0Ri" 130 | }, 131 | "source": [ 132 | "### To evaluate the model on the test images, run:" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "metadata": { 138 | "id": "C1iKU3dXP7ea" 139 | }, 140 | "source": [ 141 | "!python smplpix/eval.py \\\n", 142 | " --workdir='/content/eval_logs/' \\\n", 143 | " --checkpoint_path='/content/smplpix_logs/network.h5' \\\n", 144 | " --data_dir='/content/smplpix_logs/smplpix_data/test' \\\n", 145 | " --downsample_factor=4 \\\n", 146 | " --save_target=1 " 147 | ], 148 | "execution_count": null, 149 | "outputs": [] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "metadata": { 154 | "cellView": "form", 155 | "id": "AYpNcVKiMzB_" 156 | }, 157 | "source": [ 158 | "# @markdown ###Play the generated test video\n", 159 | "from IPython.display import HTML\n", 160 | "from base64 import b64encode\n", 161 | "mp4 = open('/content/eval_logs/test_animation.mp4','rb').read()\n", 162 | "data_url = \"data:video/mp4;base64,\" 
+ b64encode(mp4).decode()\n", 163 | "HTML(\"\"\"\n", 164 | "\n", 167 | "\"\"\" % data_url)" 168 | ], 169 | "execution_count": null, 170 | "outputs": [] 171 | } 172 | ] 173 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup for the project.""" 2 | from setuptools import setup 3 | 4 | setup( 5 | name="smplpix", 6 | version=1.0, 7 | description="SMPLpix: Neural Avatars from Deformable 3D models", 8 | install_requires=["torch", "torchvision", "trimesh", "pyrender"], 9 | author="Sergey Prokudin", 10 | license="MIT", 11 | author_email="sergey.prokudin@gmail.com", 12 | packages=["smplpix"] 13 | ) -------------------------------------------------------------------------------- /smplpix/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sergeyprokudin/smplpix/4dd135303f1641c5f13fbf8b92b6fdaa38008b03/smplpix/__init__.py -------------------------------------------------------------------------------- /smplpix/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | def get_smplpix_arguments(): 5 | 6 | parser = argparse.ArgumentParser(description='SMPLpix argument parser') 7 | parser.add_argument('--workdir', 8 | dest='workdir', 9 | help='workdir to save data, checkpoints, renders, etc.', 10 | default=os.getcwd()) 11 | parser.add_argument('--data_dir', 12 | dest='data_dir', 13 | help='directory with training input and target images to the network, should contain' 14 | 'input and output subfolders', 15 | default=None) 16 | parser.add_argument('--resume_training', 17 | dest='resume_training', 18 | type=int, 19 | help='whether to continue training process given the checkpoint in workdir', 20 | default=False) 21 | parser.add_argument('--data_url', 22 | dest='data_url', 23 | help='Dropbox URL containing zipped dataset', 24 | default=None) 25 | parser.add_argument('--n_input_channels', 26 | dest='n_input_channels', 27 | type=int, 28 | help='number of channels in the input images', 29 | default=3) 30 | parser.add_argument('--n_output_channels', 31 | dest='n_output_channels', 32 | type=int, 33 | help='number of channels in the input images', 34 | default=3) 35 | parser.add_argument('--sigmoid_output', 36 | dest='sigmoid_output', 37 | type=int, 38 | help='whether to add sigmoid activation as a final layer', 39 | default=True) 40 | parser.add_argument('--n_unet_blocks', 41 | dest='n_unet_blocks', 42 | type=int, 43 | help='number of blocks in UNet rendering module', 44 | default=5) 45 | parser.add_argument('--batch_size', 46 | dest='batch_size', 47 | type=int, 48 | help='batch size to use during training', 49 | default=4) 50 | parser.add_argument('--device', 51 | dest='device', 52 | help='GPU device to use during training', 53 | default='cuda') 54 | parser.add_argument('--downsample_factor', 55 | dest='downsample_factor', 56 | type=int, 57 | help='image downsampling factor (for faster training)', 58 | default=4) 59 | parser.add_argument('--n_epochs', 60 | dest='n_epochs', 61 | type=int, 62 | help='number of epochs to train the network for', 63 | default=500) 64 | parser.add_argument('--learning_rate', 65 | dest='learning_rate', 66 | type=float, 67 | help='initial learning rate', 68 | default=1.0e-3) 69 | parser.add_argument('--eval_every_nth_epoch', 70 | dest='eval_every_nth_epoch', 71 | type=int, 72 | help='evaluate on 
validation data every nth epoch', 73 | default=10) 74 | parser.add_argument('--sched_patience', 75 | dest='sched_patience', 76 | type=int, 77 | help='amount of validation set evaluations with no improvement after which LR will be reduced', 78 | default=3) 79 | parser.add_argument('--aug_prob', 80 | dest='aug_prob', 81 | type=float, 82 | help='probability that the input sample will be rotated and rescaled - higher value is recommended for data scarse scenarios', 83 | default=0.8) 84 | parser.add_argument('--save_target', 85 | dest='save_target', 86 | type=int, 87 | help='whether to save target images during evaluation', 88 | default=1) 89 | parser.add_argument('--checkpoint_path', 90 | dest='checkpoint_path', 91 | help='path to checkpoint (for evaluation)', 92 | default=None) 93 | args = parser.parse_args() 94 | 95 | return args 96 | -------------------------------------------------------------------------------- /smplpix/dataset.py: -------------------------------------------------------------------------------- 1 | # SMPLpix basic dataset class used in all experiments 2 | # 3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021 4 | # 5 | 6 | import os 7 | import numpy as np 8 | from PIL import Image 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torchvision.transforms import functional as tvf 12 | 13 | class SMPLPixDataset(Dataset): 14 | 15 | def __init__(self, data_dir, 16 | n_input_channels=3, 17 | n_output_channels=3, 18 | downsample_factor=1, 19 | perform_augmentation=False, 20 | augmentation_probability=0.75, 21 | aug_scale_interval=None, 22 | aug_angle_interval=None, 23 | aug_translate_interval=None, 24 | input_fill_color=1, 25 | output_fill_color=1): 26 | 27 | if aug_translate_interval is None: 28 | aug_translate_interval = [-100, 100] 29 | if aug_scale_interval is None: 30 | aug_scale_interval = [0.5, 1.5] 31 | if aug_angle_interval is None: 32 | aug_angle_interval = [-60, 60] 33 | 34 | self.input_dir = os.path.join(data_dir, 'input') 35 | self.output_dir = os.path.join(data_dir, 'output') 36 | if not os.path.exists(self.output_dir): 37 | self.output_dir = self.input_dir 38 | 39 | self.n_input_channels = n_input_channels 40 | self.n_output_channels = n_output_channels 41 | self.samples = sorted(os.listdir(self.output_dir)) 42 | self.downsample_factor = downsample_factor 43 | self.perform_augmentation = perform_augmentation 44 | self.augmentation_probability = augmentation_probability 45 | self.aug_scale_interval = aug_scale_interval 46 | self.aug_angle_interval = aug_angle_interval 47 | self.aug_translate_interval = aug_translate_interval 48 | self.input_fill_color = input_fill_color 49 | self.output_fill_color = output_fill_color 50 | 51 | def __len__(self): 52 | return len(self.samples) 53 | 54 | def _get_augmentation_params(self): 55 | 56 | scale = np.random.uniform(low=self.aug_scale_interval[0], 57 | high=self.aug_scale_interval[1]) 58 | angle = np.random.uniform(self.aug_angle_interval[0], self.aug_angle_interval[1]) 59 | translate = [np.random.uniform(self.aug_translate_interval[0], 60 | self.aug_translate_interval[1]), 61 | np.random.uniform(self.aug_translate_interval[0], 62 | self.aug_translate_interval[1])] 63 | 64 | return scale, angle, translate 65 | 66 | def _augment_images(self, x, y): 67 | 68 | augment_instance = np.random.uniform() < self.augmentation_probability 69 | 70 | if augment_instance: 71 | scale, angle, translate = self._get_augmentation_params() 72 | 73 | x = tvf.affine(x, 74 | angle=angle, 75 | translate=translate, 76 | 
scale=scale, 77 | shear=0, fill=self.input_fill_color) 78 | 79 | y = tvf.affine(y, 80 | angle=angle, 81 | translate=translate, 82 | scale=scale, 83 | shear=0, fill=self.output_fill_color) 84 | 85 | return x, y 86 | 87 | def __getitem__(self, idx): 88 | 89 | img_name = self.samples[idx] 90 | x_path = os.path.join(self.input_dir, img_name) 91 | x = Image.open(x_path) 92 | y_path = os.path.join(self.output_dir, img_name) 93 | y = Image.open(y_path) 94 | 95 | if self.perform_augmentation: 96 | x, y = self._augment_images(x, y) 97 | 98 | x = torch.Tensor(np.asarray(x) / 255).transpose(0, 2) 99 | y = torch.Tensor(np.asarray(y) / 255).transpose(0, 2) 100 | x = x[0:self.n_input_channels, ::self.downsample_factor, ::self.downsample_factor] 101 | y = y[0:self.n_output_channels, ::self.downsample_factor, ::self.downsample_factor] 102 | 103 | return x, y, img_name 104 | -------------------------------------------------------------------------------- /smplpix/eval.py: -------------------------------------------------------------------------------- 1 | # Main SMPLpix Evaluation Script 2 | # 3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021 4 | # 5 | 6 | import os 7 | import shutil 8 | import pprint 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | from smplpix.args import get_smplpix_arguments 13 | from smplpix.utils import generate_mp4 14 | from smplpix.dataset import SMPLPixDataset 15 | from smplpix.unet import UNet 16 | from smplpix.training import train, evaluate 17 | from smplpix.utils import download_and_unzip 18 | 19 | def generate_eval_video(args, data_dir, unet, frame_rate=25, save_target=False, save_input=True): 20 | 21 | print("rendering SMPLpix predictions for %s..." % data_dir) 22 | data_part_name = os.path.split(data_dir)[-1] 23 | 24 | test_dataset = SMPLPixDataset(data_dir=data_dir, 25 | downsample_factor=args.downsample_factor, 26 | perform_augmentation=False, 27 | n_input_channels=args.n_input_channels, 28 | n_output_channels=args.n_output_channels, 29 | augmentation_probability=args.aug_prob) 30 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size) 31 | final_renders_path = os.path.join(args.workdir, 'renders_%s' % data_part_name) 32 | _ = evaluate(unet, test_dataloader, final_renders_path, args.device, save_target=save_target, save_input=save_input) 33 | 34 | print("generating video animation for data %s..." % data_dir) 35 | 36 | video_animation_path = os.path.join(args.workdir, '%s_animation.mp4' % data_part_name) 37 | _ = generate_mp4(final_renders_path, video_animation_path, frame_rate=frame_rate) 38 | print("saved animation video to %s" % video_animation_path) 39 | 40 | return 41 | 42 | def main(): 43 | 44 | print("******************************************************************************************\n"+ 45 | "****************************** SMPLpix Evaluation Loop **********************************\n"+ 46 | "******************************************************************************************\n"+ 47 | "******** Copyright (c) 2021 - now, Sergey Prokudin (sergey.prokudin@gmail.com) ***********") 48 | 49 | args = get_smplpix_arguments() 50 | print("ARGUMENTS:") 51 | pprint.pprint(args) 52 | 53 | if args.checkpoint_path is None: 54 | print("no model checkpoint was specified, looking in the log directory...") 55 | ckpt_path = os.path.join(args.workdir, 'network.h5') 56 | else: 57 | ckpt_path = args.checkpoint_path 58 | if not os.path.exists(ckpt_path): 59 | print("checkpoint %s not found!" 
% ckpt_path) 60 | return 61 | 62 | print("defining the neural renderer model (U-Net)...") 63 | unet = UNet(in_channels=args.n_input_channels, out_channels=args.n_output_channels, 64 | n_blocks=args.n_unet_blocks, dim=2, up_mode='resizeconv_linear').to(args.device) 65 | 66 | print("loading the model from checkpoint: %s" % ckpt_path) 67 | unet.load_state_dict(torch.load(ckpt_path)) 68 | unet.eval() 69 | generate_eval_video(args, args.data_dir, unet, save_target=args.save_target) 70 | 71 | return 72 | 73 | if __name__== '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /smplpix/train.py: -------------------------------------------------------------------------------- 1 | # Main SMPLpix Training Script 2 | # 3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021 4 | # 5 | 6 | import os 7 | import shutil 8 | import pprint 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | from smplpix.args import get_smplpix_arguments 13 | from smplpix.utils import generate_mp4 14 | from smplpix.dataset import SMPLPixDataset 15 | from smplpix.unet import UNet 16 | from smplpix.training import train, evaluate 17 | from smplpix.utils import download_and_unzip 18 | 19 | def generate_eval_video(args, data_dir, unet, frame_rate=25, save_target=False, save_input=True): 20 | 21 | print("rendering SMPLpix predictions for %s..." % data_dir) 22 | data_part_name = os.path.split(data_dir)[-1] 23 | 24 | test_dataset = SMPLPixDataset(data_dir=data_dir, 25 | downsample_factor=args.downsample_factor, 26 | perform_augmentation=False, 27 | n_input_channels=args.n_input_channels, 28 | n_output_channels=args.n_output_channels, 29 | augmentation_probability=args.aug_prob) 30 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size) 31 | final_renders_path = os.path.join(args.workdir, 'renders_%s' % data_part_name) 32 | _ = evaluate(unet, test_dataloader, final_renders_path, args.device, save_target=save_target, save_input=save_input) 33 | 34 | print("generating video animation for data %s..." 
% data_dir) 35 | 36 | video_animation_path = os.path.join(args.workdir, '%s_animation.mp4' % data_part_name) 37 | _ = generate_mp4(final_renders_path, video_animation_path, frame_rate=frame_rate) 38 | print("saved animation video to %s" % video_animation_path) 39 | 40 | return 41 | 42 | def main(): 43 | 44 | print("******************************************************************************************\n"+ 45 | "****************************** SMPLpix Training Loop ************************************\n"+ 46 | "******************************************************************************************\n"+ 47 | "******** Copyright (c) 2021 - now, Sergey Prokudin (sergey.prokudin@gmail.com) ***********\n"+ 48 | "****************************************************************************************+*\n\n") 49 | 50 | args = get_smplpix_arguments() 51 | print("ARGUMENTS:") 52 | pprint.pprint(args) 53 | 54 | if not os.path.exists(args.workdir): 55 | os.makedirs(args.workdir) 56 | 57 | log_dir = os.path.join(args.workdir, 'logs') 58 | os.makedirs(log_dir, exist_ok=True) 59 | 60 | if args.data_url is not None: 61 | download_and_unzip(args.data_url, args.workdir) 62 | args.data_dir = os.path.join(args.workdir, 'smplpix_data') 63 | 64 | train_dir = os.path.join(args.data_dir, 'train') 65 | val_dir = os.path.join(args.data_dir, 'validation') 66 | test_dir = os.path.join(args.data_dir, 'test') 67 | 68 | train_dataset = SMPLPixDataset(data_dir=train_dir, 69 | perform_augmentation=True, 70 | augmentation_probability=args.aug_prob, 71 | downsample_factor=args.downsample_factor, 72 | n_input_channels=args.n_input_channels, 73 | n_output_channels=args.n_output_channels) 74 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size) 75 | 76 | if os.path.exists(val_dir): 77 | val_dataset = SMPLPixDataset(data_dir=val_dir, 78 | perform_augmentation=False, 79 | augmentation_probability=args.aug_prob, 80 | downsample_factor=args.downsample_factor, 81 | n_input_channels=args.n_input_channels, 82 | n_output_channels=args.n_output_channels) 83 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size) 84 | else: 85 | print("no validation data was provided, will use train data for validation...") 86 | val_dataloader = train_dataloader 87 | 88 | print("defining the neural renderer model (U-Net)...") 89 | unet = UNet(in_channels=args.n_input_channels, 90 | out_channels=args.n_output_channels, 91 | sigmoid_output=args.sigmoid_output, 92 | n_blocks=args.n_unet_blocks, dim=2, up_mode='resizeconv_linear').to(args.device) 93 | 94 | if args.checkpoint_path is None: 95 | ckpt_path = os.path.join(args.workdir, 'network.h5') 96 | else: 97 | ckpt_path = args.checkpoint_path 98 | 99 | if args.resume_training and os.path.exists(ckpt_path): 100 | print("found checkpoint, resuming from: %s" % ckpt_path) 101 | unet.load_state_dict(torch.load(ckpt_path)) 102 | if not args.resume_training: 103 | print("starting training from scratch, cleaning the log dirs...") 104 | shutil.rmtree(log_dir) 105 | 106 | print("starting training...") 107 | finished = False 108 | try: 109 | train(model=unet, train_dataloader=train_dataloader, val_dataloader=val_dataloader, 110 | log_dir=log_dir, ckpt_path=ckpt_path, device=args.device, n_epochs=args.n_epochs, 111 | eval_every_nth_epoch=args.eval_every_nth_epoch, sched_patience=args.sched_patience, 112 | init_lr=args.learning_rate) 113 | finished = True 114 | 115 | except KeyboardInterrupt: 116 | print("training interrupted, generating final animations...") 117 | 
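# even if training is interrupted, render and save predictions for the train, validation and test splits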
generate_eval_video(args, train_dir, unet, save_target=True) 118 | generate_eval_video(args, val_dir, unet, save_target=True) 119 | generate_eval_video(args, test_dir, unet, save_target=True) 120 | 121 | if finished: 122 | generate_eval_video(args, train_dir, unet, save_target=True) 123 | generate_eval_video(args, val_dir, unet, save_target=True) 124 | generate_eval_video(args, test_dir, unet, save_target=True) 125 | 126 | return 127 | 128 | if __name__== '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /smplpix/training.py: -------------------------------------------------------------------------------- 1 | # SMPLpix training and evaluation loop functions 2 | # 3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021 4 | # 5 | 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch 10 | import torch.nn as nn 11 | from torchvision.utils import save_image 12 | from .vgg import Vgg16Features 13 | 14 | 15 | def train(model, train_dataloader, val_dataloader, log_dir, ckpt_path, device, 16 | n_epochs=1000, eval_every_nth_epoch=50, sched_patience=5, init_lr=1.0e-4): 17 | 18 | vgg = Vgg16Features(layers_weights = [1, 1/16, 1/8, 1/4, 1]).to(device) 19 | criterion_l1 = nn.L1Loss().to(device) 20 | optimizer = torch.optim.Adam(model.parameters(), lr=init_lr) 21 | sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 22 | patience=sched_patience, 23 | verbose=True) 24 | 25 | for epoch_id in tqdm(range(0, n_epochs)): 26 | 27 | model.train() 28 | torch.save(model.state_dict(), ckpt_path) 29 | 30 | for batch_idx, (x, ytrue, img_names) in enumerate(train_dataloader): 31 | x, ytrue = x.to(device), ytrue.to(device) 32 | ypred = model(x) 33 | vgg_loss = criterion_l1(vgg(ypred), vgg(ytrue)) 34 | optimizer.zero_grad() 35 | vgg_loss.backward() 36 | optimizer.step() 37 | 38 | if epoch_id % eval_every_nth_epoch == 0: 39 | print("\ncurrent epoch: %d" % epoch_id) 40 | eval_dir = os.path.join(log_dir, 'val_preds_%04d' % epoch_id) 41 | val_loss = evaluate(model, val_dataloader, eval_dir, device, vgg, show_progress=False) 42 | sched.step(val_loss) 43 | 44 | return 45 | 46 | 47 | def evaluate(model, data_loader, res_dir, device, 48 | vgg=None, report_loss=True, show_progress=True, 49 | save_input=True, save_target=True): 50 | 51 | model.eval() 52 | 53 | if not os.path.exists(res_dir): 54 | os.makedirs(res_dir) 55 | 56 | if vgg is None: 57 | vgg = Vgg16Features(layers_weights = [1, 1 / 16, 1 / 8, 1 / 4, 1]).to(device) 58 | criterion_l1 = nn.L1Loss().to(device) 59 | losses = [] 60 | 61 | if show_progress: 62 | data_seq = tqdm(enumerate(data_loader)) 63 | else: 64 | data_seq = enumerate(data_loader) 65 | 66 | for batch_idx, (x, ytrue, img_names) in data_seq: 67 | 68 | x, ytrue = x.to(device), ytrue.to(device) 69 | 70 | ypred = model(x).detach().to(device) 71 | losses.append(float(criterion_l1(vgg(ypred), vgg(ytrue)))) 72 | 73 | for fid in range(0, len(img_names)): 74 | if save_input: 75 | res_image = torch.cat([x[fid].transpose(1, 2), ypred[fid].transpose(1, 2)], dim=2) 76 | else: 77 | res_image = ypred[fid].transpose(1, 2) 78 | if save_target: 79 | res_image = torch.cat([res_image, ytrue[fid].transpose(1, 2)], dim=2) 80 | 81 | save_image(res_image, os.path.join(res_dir, '%s' % img_names[fid])) 82 | 83 | avg_loss = np.mean(losses) 84 | 85 | if report_loss: 86 | print("mean VGG loss: %f" % np.mean(avg_loss)) 87 | print("images saved at %s" % res_dir) 88 | 89 | return avg_loss 90 | 
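The training and evaluation utilities above cover the full pipeline via `train.py` and `eval.py`, but a trained renderer can also be used directly from Python. Below is a minimal, illustrative inference sketch (not a file from this repository): it loads a saved checkpoint, mirrors the preprocessing of `SMPLPixDataset` (scaling to [0, 1], `transpose(0, 2)`, channel selection, stride-based downsampling), and saves the predicted frame. The paths `network.h5` and `smplx_render.png` and the `downsample_factor=4` value are placeholders, assuming a model trained with the default arguments from `args.py`.

```python
# Minimal inference sketch (illustrative, not part of the repository).
# Assumes a network trained with the default settings and saved as network.h5.
import numpy as np
import torch
from PIL import Image
from smplpix.unet import UNet

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# same constructor arguments as used in smplpix/eval.py with the default args
unet = UNet(in_channels=3, out_channels=3, n_blocks=5,
            dim=2, up_mode='resizeconv_linear').to(device)
unet.load_state_dict(torch.load('network.h5', map_location=device))  # placeholder path
unet.eval()

# preprocessing mirrors SMPLPixDataset.__getitem__: [0, 1] range, HWC -> CWH,
# keep the first 3 channels, downsample by striding (downsample_factor=4 assumed)
x = torch.Tensor(np.asarray(Image.open('smplx_render.png')) / 255).transpose(0, 2)
x = x[0:3, ::4, ::4].unsqueeze(0).to(device)

with torch.no_grad():
    ypred = unet(x)[0]

# undo the CWH layout (as done in evaluate()) and write the predicted frame
out = (ypred.transpose(1, 2).clamp(0, 1).cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
Image.fromarray(out).save('smplpix_prediction.png')
```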
-------------------------------------------------------------------------------- /smplpix/unet.py: -------------------------------------------------------------------------------- 1 | # pytorch U-Net model (https://arxiv.org/abs/1505.04597) 2 | # original repository: 3 | # https://github.com/ELEKTRONN 4 | # https://github.com/ELEKTRONN/elektronn3/blob/master/elektronn3/models/unet.py 5 | 6 | # ELEKTRONN3 - Neural Network Toolkit 7 | # 8 | # Copyright (c) 2017 - now 9 | # Max Planck Institute of Neurobiology, Munich, Germany 10 | # Author: Martin Drawitsch 11 | 12 | # MIT License 13 | # 14 | # Copyright (c) 2017 - now, ELEKTRONN team 15 | # 16 | # Permission is hereby granted, free of charge, to any person obtaining a copy 17 | # of this software and associated documentation files (the "Software"), to deal 18 | # in the Software without restriction, including without limitation the rights 19 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | # copies of the Software, and to permit persons to whom the Software is 21 | # furnished to do so, subject to the following conditions: 22 | # 23 | # The above copyright notice and this permission notice shall be included in all 24 | # copies or substantial portions of the Software. 25 | # 26 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | # SOFTWARE. 33 | 34 | """ 35 | This is a modified version of the U-Net CNN architecture for biomedical 36 | image segmentation. U-Net was originally published in 37 | https://arxiv.org/abs/1505.04597 by Ronneberger et al. 38 | 39 | A pure-3D variant of U-Net has been proposed by Çiçek et al. 40 | in https://arxiv.org/abs/1606.06650, but the below implementation 41 | is based on the original U-Net paper, with several improvements. 42 | 43 | This code is based on https://github.com/jaxony/unet-pytorch 44 | (c) 2017 Jackson Huang, released under MIT License, 45 | which implements (2D) U-Net with user-defined network depth 46 | and a few other improvements of the original architecture. 47 | 48 | Major differences of this version from Huang's code: 49 | 50 | - Operates on 3D image data (5D tensors) instead of 2D data 51 | - Uses 3D convolution, 3D pooling etc. 
by default 52 | - planar_blocks architecture parameter for mixed 2D/3D convnets 53 | (see UNet class docstring for details) 54 | - Improved tests (see the bottom of the file) 55 | - Cleaned up parameter/variable names and formatting, changed default params 56 | - Updated for PyTorch 1.3 and Python 3.6 (earlier versions unsupported) 57 | - (Optional DEBUG mode for optional printing of debug information) 58 | - Extended documentation 59 | """ 60 | 61 | __all__ = ['UNet'] 62 | 63 | import copy 64 | import itertools 65 | 66 | from typing import Sequence, Union, Tuple, Optional 67 | 68 | import torch 69 | from torch import nn 70 | from torch.utils.checkpoint import checkpoint 71 | from torch.nn import functional as F 72 | 73 | 74 | def get_conv(dim=3): 75 | """Chooses an implementation for a convolution layer.""" 76 | if dim == 3: 77 | return nn.Conv3d 78 | elif dim == 2: 79 | return nn.Conv2d 80 | else: 81 | raise ValueError('dim has to be 2 or 3') 82 | 83 | 84 | def get_convtranspose(dim=3): 85 | """Chooses an implementation for a transposed convolution layer.""" 86 | if dim == 3: 87 | return nn.ConvTranspose3d 88 | elif dim == 2: 89 | return nn.ConvTranspose2d 90 | else: 91 | raise ValueError('dim has to be 2 or 3') 92 | 93 | 94 | def get_maxpool(dim=3): 95 | """Chooses an implementation for a max-pooling layer.""" 96 | if dim == 3: 97 | return nn.MaxPool3d 98 | elif dim == 2: 99 | return nn.MaxPool2d 100 | else: 101 | raise ValueError('dim has to be 2 or 3') 102 | 103 | 104 | def get_normalization(normtype: str, num_channels: int, dim: int = 3): 105 | """Chooses an implementation for a batch normalization layer.""" 106 | if normtype is None or normtype == 'none': 107 | return nn.Identity() 108 | elif normtype.startswith('group'): 109 | if normtype == 'group': 110 | num_groups = 8 111 | elif len(normtype) > len('group') and normtype[len('group'):].isdigit(): 112 | num_groups = int(normtype[len('group'):]) 113 | else: 114 | raise ValueError( 115 | f'normtype "{normtype}" not understood. It should be "group",' 116 | f' where is the number of groups.' 117 | ) 118 | return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels) 119 | elif normtype == 'instance': 120 | if dim == 3: 121 | return nn.InstanceNorm3d(num_channels) 122 | elif dim == 2: 123 | return nn.InstanceNorm2d(num_channels) 124 | else: 125 | raise ValueError('dim has to be 2 or 3') 126 | elif normtype == 'batch': 127 | if dim == 3: 128 | return nn.BatchNorm3d(num_channels) 129 | elif dim == 2: 130 | return nn.BatchNorm2d(num_channels) 131 | else: 132 | raise ValueError('dim has to be 2 or 3') 133 | else: 134 | raise ValueError( 135 | f'Unknown normalization type "{normtype}".\n' 136 | 'Valid choices are "batch", "instance", "group" or "group",' 137 | 'where is the number of groups.' 138 | ) 139 | 140 | 141 | def planar_kernel(x): 142 | """Returns a "planar" kernel shape (e.g. for 2D convolution in 3D space) 143 | that doesn't consider the first spatial dim (D).""" 144 | if isinstance(x, int): 145 | return (1, x, x) 146 | else: 147 | return x 148 | 149 | 150 | def planar_pad(x): 151 | """Returns a "planar" padding shape that doesn't pad along the first spatial dim (D).""" 152 | if isinstance(x, int): 153 | return (0, x, x) 154 | else: 155 | return x 156 | 157 | 158 | def conv3(in_channels, out_channels, kernel_size=3, stride=1, 159 | padding=1, bias=True, planar=False, dim=3): 160 | """Returns an appropriate spatial convolution layer, depending on args. 
161 | - dim=2: Conv2d with 3x3 kernel 162 | - dim=3 and planar=False: Conv3d with 3x3x3 kernel 163 | - dim=3 and planar=True: Conv3d with 1x3x3 kernel 164 | """ 165 | if planar: 166 | stride = planar_kernel(stride) 167 | padding = planar_pad(padding) 168 | kernel_size = planar_kernel(kernel_size) 169 | return get_conv(dim)( 170 | in_channels, 171 | out_channels, 172 | kernel_size=kernel_size, 173 | stride=stride, 174 | padding=padding, 175 | bias=bias 176 | ) 177 | 178 | 179 | def upconv2(in_channels, out_channels, mode='transpose', planar=False, dim=3): 180 | """Returns a learned upsampling operator depending on args.""" 181 | kernel_size = 2 182 | stride = 2 183 | if planar: 184 | kernel_size = planar_kernel(kernel_size) 185 | stride = planar_kernel(stride) 186 | if mode == 'transpose': 187 | return get_convtranspose(dim)( 188 | in_channels, 189 | out_channels, 190 | kernel_size=kernel_size, 191 | stride=stride 192 | ) 193 | elif 'resizeconv' in mode: 194 | if 'linear' in mode: 195 | upsampling_mode = 'trilinear' if dim == 3 else 'bilinear' 196 | else: 197 | upsampling_mode = 'nearest' 198 | rc_kernel_size = 1 if mode.endswith('1') else 3 199 | return ResizeConv( 200 | in_channels, out_channels, planar=planar, dim=dim, 201 | upsampling_mode=upsampling_mode, kernel_size=rc_kernel_size 202 | ) 203 | 204 | 205 | def conv1(in_channels, out_channels, dim=3): 206 | """Returns a 1x1 or 1x1x1 convolution, depending on dim""" 207 | return get_conv(dim)(in_channels, out_channels, kernel_size=1) 208 | 209 | 210 | def get_activation(activation): 211 | if isinstance(activation, str): 212 | if activation == 'relu': 213 | return nn.ReLU() 214 | elif activation == 'leaky': 215 | return nn.LeakyReLU(negative_slope=0.1) 216 | elif activation == 'prelu': 217 | return nn.PReLU(num_parameters=1) 218 | elif activation == 'rrelu': 219 | return nn.RReLU() 220 | elif activation == 'lin': 221 | return nn.Identity() 222 | else: 223 | # Deep copy is necessary in case of paremtrized activations 224 | return copy.deepcopy(activation) 225 | 226 | 227 | class DownConv(nn.Module): 228 | """ 229 | A helper Module that performs 2 convolutions and 1 MaxPool. 230 | A ReLU activation follows each convolution. 231 | """ 232 | def __init__(self, in_channels, out_channels, pooling=True, planar=False, activation='relu', 233 | normalization=None, full_norm=True, dim=3, conv_mode='same'): 234 | super().__init__() 235 | 236 | self.in_channels = in_channels 237 | self.out_channels = out_channels 238 | self.pooling = pooling 239 | self.normalization = normalization 240 | self.dim = dim 241 | padding = 1 if 'same' in conv_mode else 0 242 | 243 | self.conv1 = conv3( 244 | self.in_channels, self.out_channels, planar=planar, dim=dim, padding=padding 245 | ) 246 | self.conv2 = conv3( 247 | self.out_channels, self.out_channels, planar=planar, dim=dim, padding=padding 248 | ) 249 | 250 | if self.pooling: 251 | kernel_size = 2 252 | if planar: 253 | kernel_size = planar_kernel(kernel_size) 254 | self.pool = get_maxpool(dim)(kernel_size=kernel_size, ceil_mode=True) 255 | self.pool_ks = kernel_size 256 | else: 257 | self.pool = nn.Identity() 258 | self.pool_ks = -123 # Bogus value, will never be read. 
Only to satisfy TorchScript's static type system 259 | 260 | self.act1 = get_activation(activation) 261 | self.act2 = get_activation(activation) 262 | 263 | if full_norm: 264 | self.norm0 = get_normalization(normalization, self.out_channels, dim=dim) 265 | else: 266 | self.norm0 = nn.Identity() 267 | self.norm1 = get_normalization(normalization, self.out_channels, dim=dim) 268 | 269 | def forward(self, x): 270 | y = self.conv1(x) 271 | y = self.norm0(y) 272 | y = self.act1(y) 273 | y = self.conv2(y) 274 | y = self.norm1(y) 275 | y = self.act2(y) 276 | before_pool = y 277 | y = self.pool(y) 278 | return y, before_pool 279 | 280 | 281 | @torch.jit.script 282 | def autocrop(from_down: torch.Tensor, from_up: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 283 | """ 284 | Crops feature tensors from the encoder and decoder pathways so that they 285 | can be combined. 286 | 287 | - If inputs from the encoder pathway have shapes that are not divisible 288 | by 2, the use of ``nn.MaxPool(ceil_mode=True)`` leads to the 2x 289 | upconvolution results being too large by one element in each odd 290 | dimension, so they need to be cropped in these dimensions. 291 | 292 | - If VALID convolutions are used, feature tensors get smaller with each 293 | convolution, so we need to center-crop the larger feature tensors from 294 | the encoder pathway to make features combinable with the smaller 295 | decoder feautures. 296 | 297 | Args: 298 | from_down: Feature from encoder pathway (``DownConv``) 299 | from_up: Feature from decoder pathway (2x upsampled) 300 | 301 | Returns: 302 | 303 | """ 304 | ndim = from_down.dim() # .ndim is not supported by torch.jit 305 | 306 | if from_down.shape[2:] == from_up.shape[2:]: # No need to crop anything 307 | return from_down, from_up 308 | 309 | # Step 1: Handle odd shapes 310 | 311 | # Handle potentially odd input shapes from encoder 312 | # by cropping from_up by 1 in each dim that is odd in from_down and not 313 | # odd in from_up (that is, where the difference between them is odd). 314 | # The reason for looking at the shape difference and not just the shape 315 | # of from_down is that although decoder outputs mostly have even shape 316 | # because of the 2x upsampling, but if anisotropic pooling is used, the 317 | # decoder outputs can also be be oddly shaped in the z (D) dimension. 318 | # In these cases no cropping should be performed. 319 | ds = from_down.shape[2:] 320 | us = from_up.shape[2:] 321 | upcrop = [u - ((u - d) % 2) for d, u in zip(ds, us)] 322 | 323 | if ndim == 4: 324 | from_up = from_up[:, :, :upcrop[0], :upcrop[1]] 325 | if ndim == 5: 326 | from_up = from_up[:, :, :upcrop[0], :upcrop[1], :upcrop[2]] 327 | 328 | # Step 2: Handle center-crop resulting from valid convolutions 329 | ds = from_down.shape[2:] 330 | us = from_up.shape[2:] 331 | 332 | assert ds[0] >= us[0], f'{ds, us}' 333 | assert ds[1] >= us[1] 334 | if ndim == 4: 335 | from_down = from_down[ 336 | :, 337 | :, 338 | (ds[0] - us[0]) // 2:(ds[0] + us[0]) // 2, 339 | (ds[1] - us[1]) // 2:(ds[1] + us[1]) // 2 340 | ] 341 | elif ndim == 5: 342 | assert ds[2] >= us[2] 343 | from_down = from_down[ 344 | :, 345 | :, 346 | ((ds[0] - us[0]) // 2):((ds[0] + us[0]) // 2), 347 | ((ds[1] - us[1]) // 2):((ds[1] + us[1]) // 2), 348 | ((ds[2] - us[2]) // 2):((ds[2] + us[2]) // 2), 349 | ] 350 | return from_down, from_up 351 | 352 | 353 | class UpConv(nn.Module): 354 | """ 355 | A helper Module that performs 2 convolutions and 1 UpConvolution. 356 | A ReLU activation follows each convolution. 
357 | """ 358 | 359 | att: Optional[torch.Tensor] 360 | 361 | def __init__(self, in_channels, out_channels, 362 | merge_mode='concat', up_mode='transpose', planar=False, 363 | activation='relu', normalization=None, full_norm=True, dim=3, conv_mode='same', 364 | attention=False): 365 | super().__init__() 366 | 367 | self.in_channels = in_channels 368 | self.out_channels = out_channels 369 | self.merge_mode = merge_mode 370 | self.up_mode = up_mode 371 | self.normalization = normalization 372 | padding = 1 if 'same' in conv_mode else 0 373 | 374 | self.upconv = upconv2(self.in_channels, self.out_channels, 375 | mode=self.up_mode, planar=planar, dim=dim) 376 | 377 | if self.merge_mode == 'concat': 378 | self.conv1 = conv3( 379 | 2*self.out_channels, self.out_channels, planar=planar, dim=dim, padding=padding 380 | ) 381 | else: 382 | # num of input channels to conv2 is same 383 | self.conv1 = conv3( 384 | self.out_channels, self.out_channels, planar=planar, dim=dim, padding=padding 385 | ) 386 | self.conv2 = conv3( 387 | self.out_channels, self.out_channels, planar=planar, dim=dim, padding=padding 388 | ) 389 | 390 | self.act0 = get_activation(activation) 391 | self.act1 = get_activation(activation) 392 | self.act2 = get_activation(activation) 393 | 394 | if full_norm: 395 | self.norm0 = get_normalization(normalization, self.out_channels, dim=dim) 396 | self.norm1 = get_normalization(normalization, self.out_channels, dim=dim) 397 | else: 398 | self.norm0 = nn.Identity() 399 | self.norm1 = nn.Identity() 400 | self.norm2 = get_normalization(normalization, self.out_channels, dim=dim) 401 | if attention: 402 | self.attention = GridAttention( 403 | in_channels=in_channels // 2, gating_channels=in_channels, dim=dim 404 | ) 405 | else: 406 | self.attention = DummyAttention() 407 | self.att = None # Field to store attention mask for later analysis 408 | 409 | def forward(self, enc, dec): 410 | """ Forward pass 411 | Arguments: 412 | enc: Tensor from the encoder pathway 413 | dec: Tensor from the decoder pathway (to be upconv'd) 414 | """ 415 | 416 | updec = self.upconv(dec) 417 | enc, updec = autocrop(enc, updec) 418 | genc, att = self.attention(enc, dec) 419 | if not torch.jit.is_scripting(): 420 | self.att = att 421 | updec = self.norm0(updec) 422 | updec = self.act0(updec) 423 | if self.merge_mode == 'concat': 424 | mrg = torch.cat((updec, genc), 1) 425 | else: 426 | mrg = updec + genc 427 | y = self.conv1(mrg) 428 | y = self.norm1(y) 429 | y = self.act1(y) 430 | y = self.conv2(y) 431 | y = self.norm2(y) 432 | y = self.act2(y) 433 | return y 434 | 435 | 436 | class ResizeConv(nn.Module): 437 | """Upsamples by 2x and applies a convolution. 438 | 439 | This is meant as a replacement for transposed convolution to avoid 440 | checkerboard artifacts. 
See 441 | 442 | - https://distill.pub/2016/deconv-checkerboard/ 443 | - https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190 444 | """ 445 | def __init__(self, in_channels, out_channels, kernel_size=3, planar=False, dim=3, 446 | upsampling_mode='nearest'): 447 | super().__init__() 448 | self.upsampling_mode = upsampling_mode 449 | self.scale_factor = 2 450 | if dim == 3 and planar: # Only interpolate (H, W) dims, leave D as is 451 | self.scale_factor = planar_kernel(self.scale_factor) 452 | self.dim = dim 453 | self.upsample = nn.Upsample(scale_factor=self.scale_factor, mode=self.upsampling_mode) 454 | # TODO: Investigate if 3x3 or 1x1 conv makes more sense here and choose default accordingly 455 | # Preliminary notes: 456 | # - conv3 increases global parameter count by ~10%, compared to conv1 and is slower overall 457 | # - conv1 is the simplest way of aligning feature dimensions 458 | # - conv1 may be enough because in all common models later layers will apply conv3 459 | # eventually, which could learn to perform the same task... 460 | # But not exactly the same thing, because this layer operates on 461 | # higher-dimensional features, which subsequent layers can't access 462 | # (at least in U-Net out_channels == in_channels // 2). 463 | # --> Needs empirical evaluation 464 | if kernel_size == 3: 465 | self.conv = conv3( 466 | in_channels, out_channels, padding=1, planar=planar, dim=dim 467 | ) 468 | elif kernel_size == 1: 469 | self.conv = conv1(in_channels, out_channels, dim=dim) 470 | else: 471 | raise ValueError(f'kernel_size={kernel_size} is not supported. Choose 1 or 3.') 472 | 473 | def forward(self, x): 474 | return self.conv(self.upsample(x)) 475 | 476 | 477 | class GridAttention(nn.Module): 478 | """Based on https://github.com/ozan-oktay/Attention-Gated-Networks 479 | 480 | Published in https://arxiv.org/abs/1804.03999""" 481 | def __init__(self, in_channels, gating_channels, inter_channels=None, dim=3, sub_sample_factor=2): 482 | super().__init__() 483 | 484 | assert dim in [2, 3] 485 | 486 | # Downsampling rate for the input featuremap 487 | if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor 488 | elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) 489 | else: self.sub_sample_factor = tuple([sub_sample_factor]) * dim 490 | 491 | # Default parameter set 492 | self.dim = dim 493 | self.sub_sample_kernel_size = self.sub_sample_factor 494 | 495 | # Number of channels (pixel dimensions) 496 | self.in_channels = in_channels 497 | self.gating_channels = gating_channels 498 | self.inter_channels = inter_channels 499 | 500 | if self.inter_channels is None: 501 | self.inter_channels = in_channels // 2 502 | if self.inter_channels == 0: 503 | self.inter_channels = 1 504 | 505 | if dim == 3: 506 | conv_nd = nn.Conv3d 507 | bn = nn.BatchNorm3d 508 | self.upsample_mode = 'trilinear' 509 | elif dim == 2: 510 | conv_nd = nn.Conv2d 511 | bn = nn.BatchNorm2d 512 | self.upsample_mode = 'bilinear' 513 | else: 514 | raise NotImplementedError 515 | 516 | # Output transform 517 | self.w = nn.Sequential( 518 | conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1), 519 | bn(self.in_channels), 520 | ) 521 | # Theta^T * x_ij + Phi^T * gating_signal + bias 522 | self.theta = conv_nd( 523 | in_channels=self.in_channels, out_channels=self.inter_channels, 524 | kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, bias=False 525 | ) 526 | self.phi = conv_nd( 527 | 
in_channels=self.gating_channels, out_channels=self.inter_channels, 528 | kernel_size=1, stride=1, padding=0, bias=True 529 | ) 530 | self.psi = conv_nd( 531 | in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, bias=True 532 | ) 533 | 534 | self.init_weights() 535 | 536 | def forward(self, x, g): 537 | # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) 538 | # phi => (b, g_d) -> (b, i_c) 539 | theta_x = self.theta(x) 540 | 541 | # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') 542 | # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) 543 | phi_g = F.interpolate(self.phi(g), size=theta_x.shape[2:], mode=self.upsample_mode, align_corners=False) 544 | f = F.relu(theta_x + phi_g, inplace=True) 545 | 546 | # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) 547 | sigm_psi_f = torch.sigmoid(self.psi(f)) 548 | 549 | # upsample the attentions and multiply 550 | sigm_psi_f = F.interpolate(sigm_psi_f, size=x.shape[2:], mode=self.upsample_mode, align_corners=False) 551 | y = sigm_psi_f.expand_as(x) * x 552 | wy = self.w(y) 553 | 554 | return wy, sigm_psi_f 555 | 556 | def init_weights(self): 557 | def weight_init(m): 558 | classname = m.__class__.__name__ 559 | if classname.find('Conv') != -1: 560 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 561 | elif classname.find('Linear') != -1: 562 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 563 | elif classname.find('BatchNorm') != -1: 564 | nn.init.normal_(m.weight.data, 1.0, 0.02) 565 | nn.init.constant_(m.bias.data, 0.0) 566 | self.apply(weight_init) 567 | 568 | 569 | class DummyAttention(nn.Module): 570 | def forward(self, x, g): 571 | return x, None 572 | 573 | 574 | # TODO: Pre-calculate output sizes when using valid convolutions 575 | class UNet(nn.Module): 576 | """Modified version of U-Net, adapted for 3D biomedical image segmentation 577 | 578 | The U-Net is a convolutional encoder-decoder neural network. 579 | Contextual spatial information (from the decoding, expansive pathway) 580 | about an input tensor is merged with information representing the 581 | localization of details (from the encoding, compressive pathway). 582 | 583 | - Original paper: https://arxiv.org/abs/1505.04597 584 | - Base implementation: https://github.com/jaxony/unet-pytorch 585 | 586 | 587 | Modifications to the original paper (@jaxony): 588 | 589 | - Padding is used in size-3-convolutions to prevent loss 590 | of border pixels. 591 | - Merging outputs does not require cropping due to (1). 592 | - Residual connections can be used by specifying 593 | UNet(merge_mode='add'). 594 | - If non-parametric upsampling is used in the decoder 595 | pathway (specified by upmode='upsample'), then an 596 | additional 1x1 convolution occurs after upsampling 597 | to reduce channel dimensionality by a factor of 2. 598 | This channel halving happens with the convolution in 599 | the tranpose convolution (specified by upmode='transpose'). 600 | 601 | Additional modifications (@mdraw): 602 | 603 | - Operates on 3D image data (5D tensors) instead of 2D data 604 | - Uses 3D convolution, 3D pooling etc. by default 605 | - Each network block pair (the two corresponding submodules in the 606 | encoder and decoder pathways) can be configured to either work 607 | in 3D or 2D mode (3D/2D convolution, pooling etc.) 608 | with the `planar_blocks` parameter. 
609 | This is helpful for dealing with data anisotropy (commonly the 610 | depth axis has lower resolution in SBEM data sets, so it is not 611 | as important for convolution/pooling) and can reduce the complexity of 612 | models (parameter counts, speed, memory usage etc.). 613 | Note: If planar blocks are used, the input patch size should be 614 | adapted by reducing depth and increasing height and width of inputs. 615 | - Configurable activation function. 616 | - Optional normalization 617 | 618 | Gradient checkpointing can be used to reduce memory consumption while 619 | training. To make use of gradient checkpointing, just run the 620 | ``forward_gradcp()`` instead of the regular ``forward`` method. 621 | This makes the backward pass a bit slower, but the memory savings can be 622 | huge (usually around 20% - 50%, depending on hyperparameters). Checkpoints 623 | are made after each network *block*. 624 | See https://pytorch.org/docs/master/checkpoint.html and 625 | https://arxiv.org/abs/1604.06174 for more details. 626 | Gradient checkpointing is not supported in TorchScript mode. 627 | 628 | Args: 629 | in_channels: Number of input channels 630 | (e.g. 1 for single-grayscale inputs, 3 for RGB images) 631 | Default: 1 632 | out_channels: Number of output channels (in classification/semantic 633 | segmentation, this is the number of different classes). 634 | Default: 2 635 | n_blocks: Number of downsampling/convolution blocks (max-pooling) 636 | in the encoder pathway. The decoder (upsampling/upconvolution) 637 | pathway will consist of `n_blocks - 1` blocks. 638 | Increasing `n_blocks` has two major effects: 639 | 640 | - The network will be deeper 641 | (n + 1 -> 4 additional convolution layers) 642 | - Since each block causes one additional downsampling, more 643 | contextual information will be available for the network, 644 | enhancing the effective visual receptive field. 645 | (n + 1 -> receptive field is approximately doubled in each 646 | dimension, except in planar blocks, in which it is only 647 | doubled in the H and W image dimensions) 648 | 649 | **Important note**: Always make sure that the spatial shape of 650 | your input is divisible by the number of blocks, because 651 | else, concatenating downsampled features will fail. 652 | start_filts: Number of filters for the first convolution layer. 653 | Note: The filter counts of the later layers depend on the 654 | choice of `merge_mode`. 655 | up_mode: Upsampling method in the decoder pathway. 656 | Choices: 657 | 658 | - 'transpose' (default): Use transposed convolution 659 | ("Upconvolution") 660 | - 'resizeconv_nearest': Use resize-convolution with nearest- 661 | neighbor interpolation, as proposed in 662 | https://distill.pub/2016/deconv-checkerboard/ 663 | - 'resizeconv_linear: Same as above, but with (bi-/tri-)linear 664 | interpolation 665 | - 'resizeconv_nearest1': Like 'resizeconv_nearest', but using a 666 | light-weight 1x1 convolution layer instead of a spatial convolution 667 | - 'resizeconv_linear1': Like 'resizeconv_nearest', but using a 668 | light-weight 1x1-convolution layer instead of a spatial convolution 669 | merge_mode: How the features from the encoder pathway should 670 | be combined with the decoder features. 671 | Choices: 672 | 673 | - 'concat' (default): Concatenate feature maps along the 674 | `C` axis, doubling the number of filters each block. 675 | - 'add': Directly add feature maps (like in ResNets). 676 | The number of filters thus stays constant in each block. 
677 | 678 | Note: According to https://arxiv.org/abs/1701.03056, feature 679 | concatenation ('concat') generally leads to better model 680 | accuracy than 'add' in typical medical image segmentation 681 | tasks. 682 | planar_blocks: Each number i in this sequence leads to the i-th 683 | block being a "planar" block. This means that all image 684 | operations performed in the i-th block in the encoder pathway 685 | and its corresponding decoder counterpart disregard the depth 686 | (`D`) axis and only operate in 2D (`H`, `W`). 687 | This is helpful for dealing with data anisotropy (commonly the 688 | depth axis has lower resolution in SBEM data sets, so it is 689 | not as important for convolution/pooling) and can reduce the 690 | complexity of models (parameter counts, speed, memory usage 691 | etc.). 692 | Note: If planar blocks are used, the input patch size should 693 | be adapted by reducing depth and increasing height and 694 | width of inputs. 695 | activation: Name of the non-linear activation function that should be 696 | applied after each network layer. 697 | Choices (see https://arxiv.org/abs/1505.00853 for details): 698 | 699 | - 'relu' (default) 700 | - 'leaky': Leaky ReLU (slope 0.1) 701 | - 'prelu': Parametrized ReLU. Best for training accuracy, but 702 | tends to increase overfitting. 703 | - 'rrelu': Can improve generalization at the cost of training 704 | accuracy. 705 | - Or you can pass an nn.Module instance directly, e.g. 706 | ``activation=torch.nn.ReLU()`` 707 | normalization: Type of normalization that should be applied at the end 708 | of each block. Note that it is applied after the activated conv 709 | layers, not before the activation. This scheme differs from the 710 | original batch normalization paper and the BN scheme of 3D U-Net, 711 | but it delivers better results this way 712 | (see https://redd.it/67gonq). 713 | Choices: 714 | 715 | - 'group' for group normalization (G=8) 716 | - 'group<G>' for group normalization with <G> groups 717 | (e.g. 'group16' for G=16 groups) 718 | - 'instance' for instance normalization 719 | - 'batch' for batch normalization (default) 720 | - 'none' or ``None`` for no normalization 721 | attention: If ``True``, use grid attention in the decoding pathway, 722 | as proposed in https://arxiv.org/abs/1804.03999. 723 | Default: ``False``. 724 | sigmoid_output: If ``True``, add a sigmoid activation as the final layer. 725 | Default: ``True``. 726 | full_norm: If ``True`` (default), perform normalization after each 727 | (transposed) convolution in the network (which is what almost 728 | all published neural network architectures do). 729 | If ``False``, only normalize after the last convolution 730 | layer of each block, in order to save resources. This was also 731 | the default behavior before this option was introduced. 732 | dim: Spatial dimensionality of the network. Choices: 733 | 734 | - 3: 3D mode. Every block fully works in 3D unless 735 | it is excluded by the ``planar_blocks`` setting. 736 | The network expects and operates on 5D input tensors 737 | (N, C, D, H, W). 738 | - 2 (default): Every block and every operation works in 2D, expecting 739 | 4D input tensors (N, C, H, W). 740 | conv_mode: Padding mode of convolutions. Choices: 741 | 742 | - 'same' (default): Use SAME-convolutions in every layer: 743 | zero-padding inputs so that all convolutions preserve spatial 744 | shapes and don't produce an offset at the boundaries.
745 | - 'valid': Use VALID-convolutions in every layer: no padding is 746 | used, so every convolution layer reduces spatial shape by 2 in 747 | each dimension. Intermediate feature maps of the encoder pathway 748 | are automatically cropped to compatible shapes so they can be 749 | merged with decoder features. 750 | Advantages: 751 | 752 | - Less resource consumption than SAME because feature maps 753 | have reduced sizes especially in deeper layers. 754 | - No "fake" data (that is, the zeros from the SAME-padding) 755 | is fed into the network. The output regions that are influenced 756 | by zero-padding naturally have worse quality, so they should 757 | be removed in post-processing if possible (see 758 | ``overlap_shape`` in :py:mod:`elektronn3.inference`). 759 | Using VALID convolutions prevents the unnecessary computation 760 | of these regions that need to be cut away anyways for 761 | high-quality tiled inference. 762 | - Avoids the issues described in https://arxiv.org/abs/1811.11718. 763 | - Since the network will not receive zero-padded inputs, it is 764 | not required to learn a robustness against artificial zeros 765 | being in the border regions of inputs. This should reduce the 766 | complexity of the learning task and allow the network to 767 | specialize better on understanding the actual, unaltered 768 | inputs (effectively requiring less parameters to fit). 769 | 770 | Disadvantages: 771 | 772 | - Using this mode poses some additional constraints on input 773 | sizes and requires you to center-crop your targets, 774 | so it's harder to use in practice than the 'same' mode. 775 | - In some cases it might be preferable to get low-quality 776 | outputs at image borders as opposed to getting no outputs at 777 | the borders. Most notably this is the case if you do training 778 | and inference not on small patches, but on complete images in 779 | a single step. 780 | """ 781 | def __init__( 782 | self, 783 | in_channels: int = 1, 784 | out_channels: int = 2, 785 | n_blocks: int = 3, 786 | start_filts: int = 32, 787 | up_mode: str = 'resizeconv_linear', 788 | merge_mode: str = 'concat', 789 | planar_blocks: Sequence = (), 790 | batch_norm: str = 'unset', 791 | attention: bool = False, 792 | sigmoid_output: bool = True, 793 | activation: Union[str, nn.Module] = 'relu', 794 | normalization: str = 'batch', 795 | full_norm: bool = True, 796 | dim: int = 2, 797 | conv_mode: str = 'same', 798 | ): 799 | super().__init__() 800 | 801 | if n_blocks < 1: 802 | raise ValueError('n_blocks must be > 1.') 803 | 804 | if dim not in {2, 3}: 805 | raise ValueError('dim has to be 2 or 3') 806 | if dim == 2 and planar_blocks != (): 807 | raise ValueError( 808 | 'If dim=2, you can\'t use planar_blocks since everything will ' 809 | 'be planar (2-dimensional) anyways.\n' 810 | 'Either set dim=3 or set planar_blocks=().' 811 | ) 812 | if up_mode in ('transpose', 'upsample', 'resizeconv_nearest', 'resizeconv_linear', 813 | 'resizeconv_nearest1', 'resizeconv_linear1'): 814 | self.up_mode = up_mode 815 | else: 816 | raise ValueError("\"{}\" is not a valid mode for upsampling".format(up_mode)) 817 | 818 | if merge_mode in ('concat', 'add'): 819 | self.merge_mode = merge_mode 820 | else: 821 | raise ValueError("\"{}\" is not a valid mode for" 822 | "merging up and down paths. " 823 | "Only \"concat\" and " 824 | "\"add\" are allowed.".format(up_mode)) 825 | 826 | # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' 827 | # TODO: Remove merge_mode=add. 
It's just worse than concat 828 | if 'resizeconv' in self.up_mode and self.merge_mode == 'add': 829 | raise ValueError("up_mode \"resizeconv\" is incompatible " 830 | "with merge_mode \"add\" at the moment " 831 | "because it doesn't make sense to use " 832 | "nearest neighbour to reduce " 833 | "n_blocks channels (by half).") 834 | 835 | if len(planar_blocks) > n_blocks: 836 | raise ValueError('planar_blocks can\'t be longer than n_blocks.') 837 | if planar_blocks and (max(planar_blocks) >= n_blocks or min(planar_blocks) < 0): 838 | raise ValueError( 839 | 'planar_blocks has invalid value range. All values have to be' 840 | 'block indices, meaning integers between 0 and (n_blocks - 1).' 841 | ) 842 | 843 | self.out_channels = out_channels 844 | self.in_channels = in_channels 845 | self.sigmoid_output = sigmoid_output 846 | self.start_filts = start_filts 847 | self.n_blocks = n_blocks 848 | self.normalization = normalization 849 | self.attention = attention 850 | self.conv_mode = conv_mode 851 | self.activation = activation 852 | self.dim = dim 853 | 854 | self.down_convs = nn.ModuleList() 855 | self.up_convs = nn.ModuleList() 856 | 857 | if batch_norm != 'unset': 858 | raise RuntimeError( 859 | 'The `batch_norm` option has been replaced with the more general `normalization` option.\n' 860 | 'If you still want to use batch normalization, set `normalization=batch` instead.' 861 | ) 862 | 863 | # Indices of blocks that should operate in 2D instead of 3D mode, 864 | # to save resources 865 | self.planar_blocks = planar_blocks 866 | 867 | # create the encoder pathway and add to a list 868 | for i in range(n_blocks): 869 | ins = self.in_channels if i == 0 else outs 870 | outs = self.start_filts * (2**i) 871 | pooling = True if i < n_blocks - 1 else False 872 | planar = i in self.planar_blocks 873 | 874 | down_conv = DownConv( 875 | ins, 876 | outs, 877 | pooling=pooling, 878 | planar=planar, 879 | activation=activation, 880 | normalization=normalization, 881 | full_norm=full_norm, 882 | dim=dim, 883 | conv_mode=conv_mode, 884 | ) 885 | self.down_convs.append(down_conv) 886 | 887 | # create the decoder pathway and add to a list 888 | # - careful! 
decoding only requires n_blocks-1 blocks 889 | for i in range(n_blocks - 1): 890 | ins = outs 891 | outs = ins // 2 892 | planar = n_blocks - 2 - i in self.planar_blocks 893 | 894 | up_conv = UpConv( 895 | ins, 896 | outs, 897 | up_mode=up_mode, 898 | merge_mode=merge_mode, 899 | planar=planar, 900 | activation=activation, 901 | normalization=normalization, 902 | attention=attention, 903 | full_norm=full_norm, 904 | dim=dim, 905 | conv_mode=conv_mode, 906 | ) 907 | self.up_convs.append(up_conv) 908 | 909 | self.conv_final = conv1(outs, self.out_channels, dim=dim) 910 | 911 | self.apply(self.weight_init) 912 | 913 | @staticmethod 914 | def weight_init(m): 915 | if isinstance(m, GridAttention): 916 | return 917 | if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)): 918 | nn.init.xavier_normal_(m.weight) 919 | if getattr(m, 'bias') is not None: 920 | nn.init.constant_(m.bias, 0) 921 | 922 | def forward(self, x): 923 | encoder_outs = [] 924 | 925 | # Encoder pathway, save outputs for merging 926 | i = 0 # Can't enumerate because of https://github.com/pytorch/pytorch/issues/16123 927 | for module in self.down_convs: 928 | x, before_pool = module(x) 929 | encoder_outs.append(before_pool) 930 | i += 1 931 | 932 | # Decoding by UpConv and merging with saved outputs of encoder 933 | i = 0 934 | for module in self.up_convs: 935 | before_pool = encoder_outs[-(i+2)] 936 | x = module(before_pool, x) 937 | i += 1 938 | 939 | # No softmax is used, so you need to apply it in the loss. 940 | x = self.conv_final(x) 941 | # Uncomment the following line to temporarily store output for 942 | # receptive field estimation using fornoxai/receptivefield: 943 | # self.feature_maps = [x] # Currently disabled to save memory 944 | if self.sigmoid_output: 945 | x = torch.sigmoid(x) 946 | return x 947 | 948 | @torch.jit.unused 949 | def forward_gradcp(self, x): 950 | """``forward()`` implementation with gradient checkpointing enabled. 
951 | Apart from checkpointing, this behaves the same as ``forward()``.""" 952 | encoder_outs = [] 953 | i = 0 954 | for module in self.down_convs: 955 | x, before_pool = checkpoint(module, x) 956 | encoder_outs.append(before_pool) 957 | i += 1 958 | i = 0 959 | for module in self.up_convs: 960 | before_pool = encoder_outs[-(i+2)] 961 | x = checkpoint(module, before_pool, x) 962 | i += 1 963 | x = self.conv_final(x) 964 | # self.feature_maps = [x] # Currently disabled to save memory 965 | return x 966 | 967 | -------------------------------------------------------------------------------- /smplpix/utils.py: -------------------------------------------------------------------------------- 1 | # helper functions for downloading and preprocessing SMPLpix training data 2 | # 3 | # (c) Sergey Prokudin (sergey.prokudin@gmail.com), 2021 4 | # 5 | 6 | import os 7 | 8 | def download_dropbox_url(url, filepath, chunk_size=1024): 9 | 10 | import requests 11 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 12 | r = requests.get(url, stream=True, headers=headers) 13 | with open(filepath, 'wb') as f: 14 | for chunk in r.iter_content(chunk_size=chunk_size): 15 | if chunk: 16 | f.write(chunk) 17 | return 18 | 19 | def unzip(zip_path, target_dir, remove_zip=True): 20 | import zipfile 21 | with zipfile.ZipFile(zip_path, 'r') as zip_ref: 22 | zip_ref.extractall(target_dir) 23 | if remove_zip: 24 | os.remove(zip_path) 25 | return 26 | 27 | 28 | def download_and_unzip(dropbox_url, workdir): 29 | 30 | if not os.path.exists(workdir): 31 | print("creating workdir %s" % workdir) 32 | os.makedirs(workdir) 33 | 34 | data_zip_path = os.path.join(workdir, 'data.zip') 35 | print("downloading zip from dropbox link: %s" % dropbox_url) 36 | download_dropbox_url(dropbox_url, data_zip_path) 37 | print("unzipping %s" % data_zip_path) 38 | unzip(data_zip_path, workdir) 39 | 40 | return 41 | 42 | 43 | def generate_mp4(image_dir, video_path, frame_rate=25, img_ext=None): 44 | 45 | if img_ext is None: 46 | test_img = os.listdir(image_dir)[0] 47 | img_ext = os.path.splitext(test_img)[1] 48 | 49 | ffmpeg_cmd = "ffmpeg -framerate %d -pattern_type glob " \ 50 | "-i \'%s/*%s\' -vcodec h264 -an -b:v 1M -pix_fmt yuv420p -an \'%s\'" % \ 51 | (frame_rate, image_dir, img_ext, video_path) 52 | 53 | print("executing %s" % ffmpeg_cmd) 54 | exit_code = os.system(ffmpeg_cmd) 55 | 56 | if exit_code != 0: 57 | print("something went wrong during video generation. Make sure you have ffmpeg tool installed.") 58 | 59 | return exit_code -------------------------------------------------------------------------------- /smplpix/vgg.py: -------------------------------------------------------------------------------- 1 | # extract perceptual features from the pre-trained Vgg16 network 2 | # these features are used for the perceptual loss function (https://arxiv.org/abs/1603.08155) 3 | # 4 | # based on the code snippet of W. 
Falcon: 5 | # https://gist.github.com/williamFalcon/1ee773c159ff5d76d47518653369d890 6 | 7 | import torch 8 | from torchvision import models 9 | 10 | class Vgg16Features(torch.nn.Module): 11 | 12 | def __init__(self, 13 | requires_grad=False, 14 | layers_weights=None): 15 | super(Vgg16Features, self).__init__() 16 | if layers_weights is None: 17 | self.layers_weights = [1 / 32, 1 / 16, 1 / 8, 1 / 4, 1] 18 | else: 19 | self.layers_weights = layers_weights 20 | 21 | vgg_pretrained_features = models.vgg16(pretrained=True).features 22 | self.slice1 = torch.nn.Sequential() 23 | self.slice2 = torch.nn.Sequential() 24 | self.slice3 = torch.nn.Sequential() 25 | self.slice4 = torch.nn.Sequential() 26 | for x in range(4): 27 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 28 | for x in range(4, 9): 29 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 30 | for x in range(9, 16): 31 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 32 | for x in range(16, 23): 33 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 34 | if not requires_grad: 35 | for param in self.parameters(): 36 | param.requires_grad = False 37 | 38 | def forward(self, x): 39 | 40 | h_0 = x.flatten(start_dim=1) 41 | h = self.slice1(x) 42 | h_relu1_2 = h.flatten(start_dim=1) 43 | h = self.slice2(h) 44 | h_relu2_2 = h.flatten(start_dim=1) 45 | h = self.slice3(h) 46 | h_relu3_3 = h.flatten(start_dim=1) 47 | h = self.slice4(h) 48 | h_relu4_3 = h.flatten(start_dim=1) 49 | 50 | h = torch.cat([self.layers_weights[0] * h_0, 51 | self.layers_weights[1] * h_relu1_2, 52 | self.layers_weights[2] * h_relu2_2, 53 | self.layers_weights[3] * h_relu3_3, 54 | self.layers_weights[4] * h_relu4_3], 1) 55 | 56 | return h 57 | --------------------------------------------------------------------------------
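To make the `UNet` class above easier to follow, here is a minimal usage sketch for the 2D image-to-image setting that SMPLpix targets. The resolution, `n_blocks` and channel counts are illustrative assumptions only; the configuration actually used by the training scripts in this repository may differ.

```python
import torch
from smplpix.unet import UNet

# Illustrative settings only -- not the exact SMPLpix training configuration.
model = UNet(
    in_channels=3,                # e.g. an RGB rasterization of the SMPL-X mesh
    out_channels=3,               # RGB prediction of the target photo
    n_blocks=5,                   # 4 poolings -> keep H and W divisible by 16
    start_filts=32,
    up_mode='resizeconv_linear',  # resize-convolution upsampling to avoid checkerboard artifacts
    normalization='batch',
    dim=2,                        # 2D mode: (N, C, H, W) tensors
    conv_mode='same',
    sigmoid_output=True,          # predictions constrained to [0, 1]
)

x = torch.rand(2, 3, 256, 256)    # dummy batch of input renders
with torch.no_grad():
    y = model(x)                  # -> torch.Size([2, 3, 256, 256])
```

With `conv_mode='same'` the spatial shape is preserved, so the output has the same height and width as the input render.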
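The `Vgg16Features` module above is intended for a perceptual loss in the spirit of https://arxiv.org/abs/1603.08155. A rough sketch of how such a loss can be computed is shown below; the exact loss combination used for training is defined elsewhere in the repository and may differ. The L1 distance between feature vectors is an assumption for illustration.

```python
import torch
import torch.nn.functional as F
from smplpix.vgg import Vgg16Features

# Note: constructing Vgg16Features loads pretrained VGG16 weights via torchvision.
vgg = Vgg16Features(requires_grad=False).eval()

pred = torch.rand(1, 3, 256, 256)    # e.g. a UNet prediction in [0, 1]
target = torch.rand(1, 3, 256, 256)  # corresponding ground-truth frame

# forward() returns one flat vector per image: the weighted raw pixels concatenated
# with the relu1_2, relu2_2, relu3_3 and relu4_3 activations.
perceptual_loss = F.l1_loss(vgg(pred), vgg(target))
```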
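Finally, a small sketch of the helpers in smplpix/utils.py. The Dropbox URL and directory names below are placeholders, and `generate_mp4` requires the ffmpeg command-line tool to be installed and on the PATH.

```python
from smplpix.utils import download_and_unzip, generate_mp4

# Placeholder values -- substitute your own Dropbox share link and directories.
workdir = './smplpix_workdir'
dropbox_url = 'https://www.dropbox.com/s/<share_id>/data.zip?dl=0'

download_and_unzip(dropbox_url, workdir)   # downloads data.zip into workdir and extracts it
generate_mp4(image_dir='./renders',        # folder of rendered frames sharing one image extension
             video_path='./renders.mp4',
             frame_rate=25)
```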