├── .editorconfig ├── .gitignore ├── Installation.md ├── LICENSE ├── README.md ├── README_jp.md ├── README_zh.md ├── app ├── __init__.py ├── all_models.py ├── custom_models │ ├── image2mvimage.yaml │ ├── image2normal.yaml │ ├── mvimg_prediction.py │ ├── normal_prediction.py │ └── utils.py ├── examples │ ├── Groot.png │ ├── aaa.png │ ├── abma.png │ ├── akun.png │ ├── anya.png │ ├── bag.png │ ├── ex1.png │ ├── ex2.png │ ├── ex3.jpg │ ├── ex4.png │ ├── generated_1715761545_frame0.png │ ├── generated_1715762357_frame0.png │ ├── generated_1715763329_frame0.png │ ├── hatsune_miku.png │ └── princess-large.png ├── gradio_3dgen.py ├── gradio_3dgen_steps.py ├── gradio_local.py └── utils.py ├── assets ├── teaser.jpg └── teaser_safe.jpg ├── custum_3d_diffusion ├── custum_modules │ ├── attention_processors.py │ └── unifield_processor.py ├── custum_pipeline │ ├── unifield_pipeline_img2img.py │ └── unifield_pipeline_img2mvimg.py ├── modules.py └── trainings │ ├── __init__.py │ ├── base.py │ ├── config_classes.py │ ├── image2image_trainer.py │ ├── image2mvimage_trainer.py │ └── utils.py ├── docker ├── Dockerfile └── README.md ├── gradio_app.py ├── install_windows_win_py311_cu121.bat ├── mesh_reconstruction ├── func.py ├── opt.py ├── recon.py ├── refine.py ├── remesh.py └── render.py ├── requirements-detail.txt ├── requirements-win-py311-cu121.txt ├── requirements.txt └── scripts ├── all_typing.py ├── load_onnx.py ├── mesh_init.py ├── multiview_inference.py ├── normal_to_height_map.py ├── project_mesh.py ├── refine_lr_to_sr.py ├── sd_model_zoo.py ├── upsampler.py └── utils.py /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*.py] 4 | charset = utf-8 5 | trim_trailing_whitespace = true 6 | end_of_line = lf 7 | insert_final_newline = true 8 | indent_style = space 9 | indent_size = 4 10 | 11 | [*.md] 12 | trim_trailing_whitespace = false 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | .idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | 178 | .vscode/ 179 | .threestudio_cache/ 180 | outputs 181 | outputs/ 182 | outputs-gradio 183 | outputs-gradio/ 184 | lightning_logs/ 185 | 186 | # pretrained model weights 187 | *.ckpt 188 | *.pt 189 | *.pth 190 | *.bin 191 | *.param 192 | 193 | # wandb 194 | wandb/ 195 | 196 | # obj results 197 | *.obj 198 | *.glb 199 | *.ply 200 | 201 | # ckpts 202 | ckpt/* 203 | *.pth 204 | *.pt 205 | 206 | # tensorrt 207 | *.engine 208 | *.profile 209 | 210 | # zipfiles 211 | *.zip 212 | *.tar 213 | *.tar.gz 214 | 215 | # others 216 | run_30.sh 217 | ckpt -------------------------------------------------------------------------------- /Installation.md: -------------------------------------------------------------------------------- 1 | # 官方安装指南 2 | 3 | * 在 requirements-detail.txt 里,我们提供了详细的各个库的版本,这个对应的环境是 `python3.10 + cuda12.2`。 4 | * 本项目依赖于几个重要的pypi包,这几个包安装起来会有一些困难。 5 | 6 | ### nvdiffrast 安装 7 | 8 | * nvdiffrast 会在第一次运行时,编译对应的torch插件,这一步需要 ninja 及 cudatoolkit的支持。 9 | * 因此需要先确保正确安装了 ninja 以及 cudatoolkit 并正确配置了 CUDA_HOME 环境变量。 10 | * cudatoolkit 安装可以参考 [linux-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html), [windows-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) 11 | * ninja 则可用直接 `pip install ninja` 12 | * 然后设置 CUDA_HOME 变量为 cudatoolkit 的安装目录,如 `/usr/local/cuda`。 13 | * 最后 `pip install nvdiffrast` 即可。 14 | * 如果无法在目标服务器上安装 cudatoolkit (如权限不够),可用使用我修改的[预编译版本 nvdiffrast](https://github.com/wukailu/nvdiffrast-torch) 在另一台拥有 cudatoolkit 且环境相似(python, torch, cuda版本相同)的服务器上预编译后安装。 15 | 16 | ### onnxruntime-gpu 安装 17 | 18 | * 注意,同时安装 `onnxruntime` 与 `onnxruntime-gpu` 可能导致最终程序无法运行在GPU,而运行在CPU,导致极慢的推理速度。 19 | * [onnxruntime 官方安装指南](https://onnxruntime.ai/docs/install/#python-installs) 20 | * TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ 21 | `. 
22 | * 进一步地,可以安装基于 tensorrt 的 onnxruntime,进一步加快推理速度。
23 | * 注意:如果没有安装基于 tensorrt 的 onnxruntime,建议将 `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` 中的 `TensorrtExecutionProvider` 删除。
24 | * 对于 cuda12.x,可以使用如下命令快速安装带有 tensorrt 的 onnxruntime(注意将 `/root/miniconda3/lib/python3.10/site-packages` 修改为你的 python 对应路径,将 `/root/.bashrc` 改为你的用户目录下的 `.bashrc` 路径)
25 | ```bash
26 | pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
27 | pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
28 | pip install tensorrt==8.6.0
29 | echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
30 | ```
31 | 
32 | ### pytorch3d 安装
33 | 
34 | * 根据 [pytorch3d 官方的安装建议](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux),建议使用预编译版本:
35 | ```python
36 | import sys
37 | import torch
38 | pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
39 | version_str="".join([
40 |     f"py3{sys.version_info.minor}_cu",
41 |     torch.version.cuda.replace(".",""),
42 |     f"_pyt{pyt_version_str}"
43 | ])
44 | !pip install fvcore iopath
45 | !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
46 | ```
47 | 
48 | ### torch_scatter 安装
49 | 
50 | * 按照 [torch_scatter 官方安装指南](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation),使用预编译的安装包快速安装。
51 | * 或者直接编译安装:`pip install git+https://github.com/rusty1s/pytorch_scatter.git`
52 | 
53 | ### 其他安装
54 | 
55 | * 其他依赖 `pip install -r requirements.txt` 即可。
56 | 
57 | -----
58 | 
59 | # Detailed Installation Guide
60 | 
61 | * In `requirements-detail.txt`, we provide detailed versions of all packages, which correspond to the environment of `python3.10 + cuda12.2`.
62 | * This project relies on several important PyPI packages, which may be difficult to install.
63 | 
64 | ### Installation of nvdiffrast
65 | 
66 | * nvdiffrast will compile the corresponding torch plugin the first time it runs, which requires support from ninja and cudatoolkit.
67 | * Therefore, it is necessary to ensure that ninja and cudatoolkit are correctly installed and that the CUDA_HOME environment variable is properly configured.
68 | * For the installation of cudatoolkit, you can refer to the [Linux CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) and the [Windows CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html).
69 | * Ninja can be installed directly with `pip install ninja`.
70 | * Then set the CUDA_HOME variable to the installation directory of cudatoolkit, such as `/usr/local/cuda`.
71 | * Finally, run `pip install nvdiffrast`.
72 | * If you cannot install cudatoolkit on the target machine (e.g., due to insufficient permissions), you can use my modified [pre-compiled version of nvdiffrast](https://github.com/wukailu/nvdiffrast-torch): pre-compile it on another machine that has cudatoolkit and a similar environment (same python, torch, and cuda versions), then install the resulting `.whl`.
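* Optional sanity check (a minimal sketch, not part of this repository): the snippet below verifies that the pieces nvdiffrast needs for its first-run compilation — `ninja`, `nvcc` and `CUDA_HOME` — are visible from Python, and then imports `nvdiffrast.torch`. The CUDA extension itself is built lazily the first time a rasterization context is created, so expect a one-off delay at that point.
```python
# Hedged sanity check -- not part of the Unique3D repo.
import os
import shutil

import torch

# ninja and nvcc must be discoverable, and CUDA_HOME should point at the toolkit.
print("ninja    :", shutil.which("ninja"))            # e.g. /opt/conda/bin/ninja
print("nvcc     :", shutil.which("nvcc"))             # e.g. /usr/local/cuda/bin/nvcc
print("CUDA_HOME:", os.environ.get("CUDA_HOME"))      # e.g. /usr/local/cuda
print("torch CUDA:", torch.version.cuda, "available:", torch.cuda.is_available())

# Importing the bindings is cheap; the torch extension is compiled lazily on
# first use (when a rasterization context is created), which is when ninja/nvcc are needed.
import nvdiffrast.torch as dr  # noqa: E402

print("nvdiffrast.torch imported from", dr.__file__)
```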
73 | 
74 | ### Installation of onnxruntime-gpu
75 | 
76 | * Note that installing both `onnxruntime` and `onnxruntime-gpu` may cause the program to run on the CPU instead of the GPU, leading to extremely slow inference.
77 | * [Official ONNX Runtime Installation Guide](https://onnxruntime.ai/docs/install/#python-installs)
78 | * TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`.
79 | * Furthermore, you can install a TensorRT-based onnxruntime to further increase inference speed.
80 | * Note: If you have not installed a TensorRT-based onnxruntime, it is recommended to remove `TensorrtExecutionProvider` from `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4`.
81 | * For cuda12.x, you can quickly install onnxruntime with TensorRT using the following commands (change `/root/miniconda3/lib/python3.10/site-packages` to the site-packages path of your python, and change `/root/.bashrc` to the `.bashrc` in your home directory):
82 | ```bash
83 | pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
84 | pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
85 | pip install tensorrt==8.6.0
86 | echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
87 | ```
88 | 
89 | ### Installation of pytorch3d
90 | 
91 | * According to the [official installation recommendations of pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux), it is recommended to use the pre-compiled version:
92 | ```python
93 | import sys
94 | import torch
95 | pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
96 | version_str="".join([
97 |     f"py3{sys.version_info.minor}_cu",
98 |     torch.version.cuda.replace(".",""),
99 |     f"_pyt{pyt_version_str}"
100 | ])
101 | !pip install fvcore iopath
102 | !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
103 | ```
104 | 
105 | ### Installation of torch_scatter
106 | 
107 | * Use the pre-compiled packages from the [official installation guide of torch_scatter](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation) for a quick installation.
108 | * Alternatively, you can compile and install directly with `pip install git+https://github.com/rusty1s/pytorch_scatter.git`.
109 | 
110 | ### Other Installations
111 | 
112 | * For other packages, simply `pip install -r requirements.txt`.
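* Once everything above is installed, a quick post-install check (a hedged sketch, not part of this repository) is to confirm that onnxruntime-gpu actually exposes the GPU execution providers and that the other hard-to-build dependencies import cleanly:
```python
# Hedged post-install check -- not part of the Unique3D repo.
import onnxruntime as ort
import pytorch3d
import torch_scatter

print("onnxruntime providers:", ort.get_available_providers())
# Expect 'CUDAExecutionProvider' (plus 'TensorrtExecutionProvider' if the
# TensorRT build is installed). Seeing only 'CPUExecutionProvider' usually
# means onnxruntime and onnxruntime-gpu ended up installed side by side.
print("pytorch3d    :", pytorch3d.__version__)
print("torch_scatter:", torch_scatter.__version__)
```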
113 | 114 | ----- 115 | 116 | # 官方インストールガイド 117 | 118 | * `requirements-detail.txt` には、各ライブラリのバージョンが詳細に提供されており、これは Python 3.10 + CUDA 12.2 に対応する環境です。 119 | * このプロジェクトは、いくつかの重要な PyPI パッケージに依存しており、これらのパッケージのインストールにはいくつかの困難が伴います。 120 | 121 | ### nvdiffrast のインストール 122 | 123 | * nvdiffrast は、最初に実行するときに、torch プラグインの対応バージョンをコンパイルします。このステップには、ninja および cudatoolkit のサポートが必要です。 124 | * したがって、ninja および cudatoolkit の正確なインストールと、CUDA_HOME 環境変数の正確な設定を確保する必要があります。 125 | * cudatoolkit のインストールについては、[Linux CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)、[Windows CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) を参照してください。 126 | * ninja は、直接 `pip install ninja` でインストールできます。 127 | * 次に、CUDA_HOME 変数を cudatoolkit のインストールディレクトリに設定します。例えば、`/usr/local/cuda` のように。 128 | * 最後に、`pip install nvdiffrast` を実行します。 129 | * 目標サーバーで cudatoolkit をインストールできない場合(例えば、権限が不足している場合)、私の修正した[事前コンパイル済みバージョンの nvdiffrast](https://github.com/wukailu/nvdiffrast-torch)を使用できます。これは、cudatoolkit があり、環境が似ている(Python、torch、cudaのバージョンが同じ)別のサーバーで事前コンパイルしてからインストールすることができます。 130 | 131 | ### onnxruntime-gpu のインストール 132 | 133 | * 注意:`onnxruntime` と `onnxruntime-gpu` を同時にインストールすると、最終的なプログラムが GPU 上で実行されず、CPU 上で実行される可能性があり、推論速度が非常に遅くなることがあります。 134 | * [onnxruntime 公式インストールガイド](https://onnxruntime.ai/docs/install/#python-installs) 135 | * TLDR: cuda11.x 用には、`pip install onnxruntime-gpu` を使用します。cuda12.x 用には、`pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/` を使用します。 136 | * さらに、TensorRT ベースの onnxruntime をインストールして、推論速度をさらに向上させることができます。 137 | * 注意:TensorRT ベースの onnxruntime がインストールされていない場合は、`https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` の `TensorrtExecutionProvider` を削除することをお勧めします。 138 | * cuda12.x の場合、次のコマンドを使用して迅速に TensorRT を備えた onnxruntime をインストールできます(`/root/miniconda3/lib/python3.10/site-packages` をあなたの Python に対応するパスに、`/root/.bashrc` をあなたのユーザーのパスの下の `.bashrc` に変更してください)。 139 | ```bash 140 | pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/ 141 | pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/ 142 | pip install tensorrt==8.6.0 143 | echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc 144 | ``` 145 | 146 | ### pytorch3d のインストール 147 | 148 | * [pytorch3d 公式のインストール提案](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux)に従い、事前コンパイル済みバージョンを使用することをお勧めします。 149 | ```python 150 | import sys 151 | import torch 152 | pyt_version_str=torch.__version__.split("+")[0].replace(".", "") 153 | version_str="".join([ 154 | f"py3{sys.version_info.minor}_cu", 155 | torch.version.cuda.replace(".",""), 156 | f"_pyt{pyt_version_str}" 157 | ]) 158 | !pip install fvcore iopath 159 | !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html 160 | ``` 161 | 162 | ### torch_scatter のインストール 163 | 164 | * [torch_scatter 公式インストールガイド](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation)に従い、事前コンパイル済みのインストールパッケージを使用して迅速インストールします。 165 | * または、直接コンパイルしてインストールする `pip install 
git+https://github.com/rusty1s/pytorch_scatter.git` も可能です。 166 | 167 | ### その他のインストール 168 | 169 | * その他のファイルについては、`pip install -r requirements.txt` を実行するだけです。 170 | 171 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AiuniAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **[中文版本](README_zh.md)** 2 | 3 | **[日本語版](README_jp.md)** 4 | 5 | # Unique3D 6 | Official implementation of Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image. 7 | 8 | [Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/) 9 | 10 | ## [Paper](https://arxiv.org/abs/2405.20343) | [Project page](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Online Demo](https://www.aiuni.ai/) 11 | 12 | * Demo inference speed: Gradio Demo > Huggingface Demo > Huggingface Demo2 > Online Demo 13 | 14 | **If the Gradio Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (get the registration invitation code Join Discord: https://discord.gg/aiuni). However, the Online Demo is slightly different from the Gradio Demo, in that the inference speed is slower, but the generation is much more stable.** 15 | 16 |

17 | 18 |

19 | 
20 | High-fidelity and diverse textured meshes generated by Unique3D from single-view wild images in 30 seconds.
21 | 
22 | ## More features
23 | 
24 | The repo is still under construction; thanks for your patience.
25 | - [x] Upload weights.
26 | - [x] Local gradio demo.
27 | - [x] Detailed tutorial.
28 | - [x] Huggingface demo.
29 | - [ ] Detailed local demo.
30 | - [x] ComfyUI support.
31 | - [x] Windows support.
32 | - [x] Docker support.
33 | - [ ] More stable reconstruction with normal.
34 | - [ ] Training code release.
35 | 
36 | ## Preparation for inference
37 | 
38 | * [Detailed linux installation guide](Installation.md).
39 | 
40 | ### Linux System Setup.
41 | 
42 | Adapted for Ubuntu 22.04.4 LTS and CUDA 12.1.
43 | ```bash
44 | conda create -n unique3d python=3.11
45 | conda activate unique3d
46 | 
47 | pip install ninja
48 | pip install diffusers==0.27.2
49 | 
50 | pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html
51 | 
52 | pip install -r requirements.txt
53 | ```
54 | 
55 | [oak-barry](https://github.com/oak-barry) provides another setup script for torch210+cu121 [here](https://github.com/oak-barry/Unique3D).
56 | 
57 | ### Windows Setup.
58 | 
59 | * Thank you very much `jtydhr88` for the Windows installation method! See [issues/15](https://github.com/AiuniAI/Unique3D/issues/15).
60 | 
61 | Based on [issues/15](https://github.com/AiuniAI/Unique3D/issues/15), a bat script is provided to run the commands, so you can:
62 | 1. Install Visual Studio Build Tools if required; you can find them at [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools).
63 | 2. Create a conda env and activate it:
64 |    1. `conda create -n unique3d-py311 python=3.11`
65 |    2. `conda activate unique3d-py311`
66 | 3. Download the [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl) for py311 and put it into this project.
67 | 4. Run **install_windows_win_py311_cu121.bat**.
68 | 5. Answer `y` when asked to uninstall onnxruntime and onnxruntime-gpu.
69 | 6. Create the output folder **tmp\gradio** under the drive root, such as F:\tmp\gradio.
70 | 7. Run `python app/gradio_local.py --port 7860`.
71 | 
72 | For more details, refer to [issues/15](https://github.com/AiuniAI/Unique3D/issues/15).
73 | 
74 | ### Interactive inference: run your local gradio demo.
75 | 
76 | 1. Download the weights from [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) or [Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/), and extract them to `ckpt/*`.
77 | ```
78 | Unique3D
79 | ├──ckpt
80 |     ├── controlnet-tile/
81 |     ├── image2normal/
82 |     ├── img2mvimg/
83 |     ├── realesrgan-x4.onnx
84 |     └── v1-inference.yaml
85 | ```
86 | 
87 | 2. Run the interactive inference locally.
88 | ```bash
89 | python app/gradio_local.py --port 7860
90 | ```
91 | 
92 | ## ComfyUI Support
93 | 
94 | Thanks for the [ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D) implementation from [jtydhr88](https://github.com/jtydhr88)!
95 | 
96 | ## Tips to get better results
97 | 
98 | **Important: Because the mesh is normalized by its longest edge along xyz during training, the input image should contain the object's longest edge at inference time, or you may get erroneously squashed results.**
99 | 1. Unique3D is sensitive to the facing direction of input images.
Due to the distribution of the training data, orthographic front-facing images with a rest pose always lead to good reconstructions. 100 | 2. Images with occlusions will cause worse reconstructions, since four views cannot cover the complete object. Images with fewer occlusions lead to better results. 101 | 3. Pass an image with as high a resolution as possible to the input when resolution is a factor. 102 | 103 | ## Acknowledgement 104 | 105 | We have intensively borrowed code from the following repositories. Many thanks to the authors for sharing their code. 106 | - [Stable Diffusion](https://github.com/CompVis/stable-diffusion) 107 | - [Wonder3d](https://github.com/xxlong0/Wonder3D) 108 | - [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus) 109 | - [Continues Remeshing](https://github.com/Profactor/continuous-remeshing) 110 | - [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals) 111 | 112 | ## Collaborations 113 | Our mission is to create a 4D generative model with 3D concepts. This is just our first step, and the road ahead is still long, but we are confident. We warmly invite you to join the discussion and explore potential collaborations in any capacity. **If you're interested in connecting or partnering with us, please don't hesitate to reach out via email (wkl22@mails.tsinghua.edu.cn)**. 114 | 115 | - Follow us on twitter for the latest updates: https://x.com/aiuni_ai 116 | - Join AIGC 3D/4D generation community on discord: https://discord.gg/aiuni 117 | - Research collaboration, please contact: ai@aiuni.ai 118 | 119 | ## Citation 120 | 121 | If you found Unique3D helpful, please cite our report: 122 | ```bibtex 123 | @misc{wu2024unique3d, 124 | title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image}, 125 | author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma}, 126 | year={2024}, 127 | eprint={2405.20343}, 128 | archivePrefix={arXiv}, 129 | primaryClass={cs.CV} 130 | } 131 | ``` 132 | -------------------------------------------------------------------------------- /README_jp.md: -------------------------------------------------------------------------------- 1 | **他の言語のバージョン [英語](README.md) [中国語](README_zh.md)** 2 | 3 | # Unique3D 4 | Unique3D: 単一画像からの高品質かつ効率的な3Dメッシュ生成の公式実装。 5 | 6 | [Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/) 7 | 8 | ## [論文](https://arxiv.org/abs/2405.20343) | [プロジェクトページ](https://wukailu.github.io/Unique3D/) | [Huggingfaceデモ](https://huggingface.co/spaces/Wuvin/Unique3D) | [オンラインデモ](https://www.aiuni.ai/) 9 | 10 | * デモ推論速度: Gradioデモ > Huggingfaceデモ > Huggingfaceデモ2 > オンラインデモ 11 | 12 | **Gradioデモが残念ながらハングアップしたり、非常に混雑している場合は、[aiuni.ai](https://www.aiuni.ai/)のオンラインデモを使用できます。これは無料で試すことができます(登録招待コードを取得するには、Discordに参加してください: https://discord.gg/aiuni)。ただし、オンラインデモはGradioデモとは少し異なり、推論速度が遅く、生成結果が安定していない可能性がありますが、素材の品質は良いです。** 13 | 14 |

15 | 16 |

17 | 18 | Unique3Dは、野生の単一画像から高忠実度および多様なテクスチャメッシュを30秒で生成します。 19 | 20 | ## より多くの機能 21 | 22 | リポジトリはまだ構築中です。ご理解いただきありがとうございます。 23 | - [x] 重みのアップロード。 24 | - [x] ローカルGradioデモ。 25 | - [ ] 詳細なチュートリアル。 26 | - [x] Huggingfaceデモ。 27 | - [ ] 詳細なローカルデモ。 28 | - [x] Comfyuiサポート。 29 | - [x] Windowsサポート。 30 | - [ ] Dockerサポート。 31 | - [ ] ノーマルでより安定した再構築。 32 | - [ ] トレーニングコードのリリース。 33 | 34 | ## 推論の準備 35 | 36 | ### Linuxシステムセットアップ 37 | 38 | Ubuntu 22.04.4 LTSおよびCUDA 12.1に適応。 39 | ```angular2html 40 | conda create -n unique3d python=3.11 41 | conda activate unique3d 42 | 43 | pip install ninja 44 | pip install diffusers==0.27.2 45 | 46 | pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html 47 | 48 | pip install -r requirements.txt 49 | ``` 50 | 51 | [oak-barry](https://github.com/oak-barry)は、[こちら](https://github.com/oak-barry/Unique3D)でtorch210+cu121の別のセットアップスクリプトを提供しています。 52 | 53 | ### Windowsセットアップ 54 | 55 | * `jtydhr88`によるWindowsインストール方法に非常に感謝します![issues/15](https://github.com/AiuniAI/Unique3D/issues/15)を参照してください。 56 | 57 | [issues/15](https://github.com/AiuniAI/Unique3D/issues/15)によると、コマンドを実行するバッチスクリプトを実装したので、以下の手順に従ってください。 58 | 1. [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools)からVisual Studio Build Toolsが必要になる場合があります。 59 | 2. conda envを作成し、アクティブにします。 60 | 1. `conda create -n unique3d-py311 python=3.11` 61 | 2. `conda activate unique3d-py311` 62 | 3. [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl)をダウンロードし、このプロジェクトに配置します。 63 | 4. **install_windows_win_py311_cu121.bat**を実行します。 64 | 5. onnxruntimeおよびonnxruntime-gpuのアンインストールを求められた場合は、yと回答します。 65 | 6. ドライバールートの下に**tmp\gradio**フォルダを作成します(例:F:\tmp\gradio)。 66 | 7. python app/gradio_local.py --port 7860 67 | 68 | 詳細は[issues/15](https://github.com/AiuniAI/Unique3D/issues/15)を参照してください。 69 | 70 | ### インタラクティブ推論:ローカルGradioデモを実行する 71 | 72 | 1. [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt)または[Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)から重みをダウンロードし、`ckpt/*`に抽出します。 73 | ``` 74 | Unique3D 75 | ├──ckpt 76 | ├── controlnet-tile/ 77 | ├── image2normal/ 78 | ├── img2mvimg/ 79 | ├── realesrgan-x4.onnx 80 | └── v1-inference.yaml 81 | ``` 82 | 83 | 2. インタラクティブ推論をローカルで実行します。 84 | ```bash 85 | python app/gradio_local.py --port 7860 86 | ``` 87 | 88 | ## ComfyUIサポート 89 | 90 | [jtydhr88](https://github.com/jtydhr88)からの[ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D)の実装に感謝します! 91 | 92 | ## より良い結果を得るためのヒント 93 | 94 | 1. Unique3Dは入力画像の向きに敏感です。トレーニングデータの分布により、正面を向いた直交画像は常に良い再構築につながります。 95 | 2. 遮蔽のある画像は、4つのビューがオブジェクトを完全にカバーできないため、再構築が悪化します。遮蔽の少ない画像は、より良い結果につながります。 96 | 3. 
可能な限り高解像度の画像を入力として使用してください。 97 | 98 | ## 謝辞 99 | 100 | 以下のリポジトリからコードを大量に借用しました。コードを共有してくれた著者に感謝します。 101 | - [Stable Diffusion](https://github.com/CompVis/stable-diffusion) 102 | - [Wonder3d](https://github.com/xxlong0/Wonder3D) 103 | - [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus) 104 | - [Continues Remeshing](https://github.com/Profactor/continuous-remeshing) 105 | - [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals) 106 | 107 | ## コラボレーション 108 | 私たちの使命は、3Dの概念を持つ4D生成モデルを作成することです。これは私たちの最初のステップであり、前途はまだ長いですが、私たちは自信を持っています。あらゆる形態の潜在的なコラボレーションを探求し、議論に参加することを心から歓迎します。**私たちと連絡を取りたい、またはパートナーシップを結びたい方は、メールでお気軽にお問い合わせください (wkl22@mails.tsinghua.edu.cn)**。 109 | 110 | - 最新情報を入手するには、Twitterをフォローしてください: https://x.com/aiuni_ai 111 | - DiscordでAIGC 3D/4D生成コミュニティに参加してください: https://discord.gg/aiuni 112 | - 研究協力については、ai@aiuni.aiまでご連絡ください。 113 | 114 | ## 引用 115 | 116 | Unique3Dが役立つと思われる場合は、私たちのレポートを引用してください: 117 | ```bibtex 118 | @misc{wu2024unique3d, 119 | title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image}, 120 | author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma}, 121 | year={2024}, 122 | eprint={2405.20343}, 123 | archivePrefix={arXiv}, 124 | primaryClass={cs.CV} 125 | } 126 | ``` 127 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | **其他语言版本 [English](README.md)** 2 | 3 | # Unique3D 4 | High-Quality and Efficient 3D Mesh Generation from a Single Image 5 | 6 | [Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/) 7 | 8 | ## [论文](https://arxiv.org/abs/2405.20343) | [项目页面](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradio Demo](http://unique3d.demo.avar.cn/) | [在线演示](https://www.aiuni.ai/) 9 | 10 | 11 | 12 |

13 | 14 |

15 | 16 | Unique3D从单视图图像生成高保真度和多样化纹理的网格,在4090上大约需要30秒。 17 | 18 | ### 推理准备 19 | 20 | #### Linux系统设置 21 | ```angular2html 22 | conda create -n unique3d 23 | conda activate unique3d 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | #### 交互式推理:运行您的本地gradio演示 28 | 29 | 1. 从 [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) 下载或者从[清华云盘](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)下载权重,并将其解压到`ckpt/*`。 30 | ``` 31 | Unique3D 32 | ├──ckpt 33 | ├── controlnet-tile/ 34 | ├── image2normal/ 35 | ├── img2mvimg/ 36 | ├── realesrgan-x4.onnx 37 | └── v1-inference.yaml 38 | ``` 39 | 40 | 2. 在本地运行交互式推理。 41 | ```bash 42 | python app/gradio_local.py --port 7860 43 | ``` 44 | 45 | ## 获取更好结果的提示 46 | 47 | 1. Unique3D对输入图像的朝向非常敏感。由于训练数据的分布,**正交正视图像**通常总是能带来良好的重建。对于人物而言,最好是 A-pose 或者 T-pose,因为目前训练数据很少含有其他类型姿态。 48 | 2. 有遮挡的图像会导致更差的重建,因为4个视图无法覆盖完整的对象。遮挡较少的图像会带来更好的结果。 49 | 3. 尽可能将高分辨率的图像用作输入。 50 | 51 | ## 致谢 52 | 53 | 我们借用了以下代码库的代码。非常感谢作者们分享他们的代码。 54 | - [Stable Diffusion](https://github.com/CompVis/stable-diffusion) 55 | - [Wonder3d](https://github.com/xxlong0/Wonder3D) 56 | - [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus) 57 | - [Continues Remeshing](https://github.com/Profactor/continuous-remeshing) 58 | - [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals) 59 | 60 | ## 合作 61 | 62 | 我们使命是创建一个具有3D概念的4D生成模型。这只是我们的第一步,前方的道路仍然很长,但我们有信心。我们热情邀请您加入讨论,并探索任何形式的潜在合作。**如果您有兴趣联系或与我们合作,欢迎通过电子邮件(wkl22@mails.tsinghua.edu.cn)与我们联系**。 63 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/__init__.py -------------------------------------------------------------------------------- /app/all_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scripts.sd_model_zoo import load_common_sd15_pipe 3 | from diffusers import StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline 4 | 5 | 6 | class MyModelZoo: 7 | _pipe_disney_controlnet_lineart_ipadapter_i2i: StableDiffusionControlNetImg2ImgPipeline = None 8 | 9 | base_model = "runwayml/stable-diffusion-v1-5" 10 | 11 | def __init__(self, base_model=None) -> None: 12 | if base_model is not None: 13 | self.base_model = base_model 14 | 15 | @property 16 | def pipe_disney_controlnet_tile_ipadapter_i2i(self): 17 | return self._pipe_disney_controlnet_lineart_ipadapter_i2i 18 | 19 | def init_models(self): 20 | self._pipe_disney_controlnet_lineart_ipadapter_i2i = load_common_sd15_pipe(base_model=self.base_model, ip_adapter=True, plus_model=False, controlnet="./ckpt/controlnet-tile", pipeline_class=StableDiffusionControlNetImg2ImgPipeline) 21 | 22 | model_zoo = MyModelZoo() 23 | -------------------------------------------------------------------------------- /app/custom_models/image2mvimage.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_name_or_path: "./ckpt/img2mvimg" 2 | mixed_precision: "bf16" 3 | 4 | init_config: 5 | # enable controls 6 | enable_cross_attn_lora: False 7 | enable_cross_attn_ip: False 8 | enable_self_attn_lora: False 9 | enable_self_attn_ref: False 10 | enable_multiview_attn: True 11 | 12 | # for cross attention 13 | init_cross_attn_lora: False 14 | init_cross_attn_ip: False 15 | cross_attn_lora_rank: 256 # 0 for not enabled 16 | 
cross_attn_lora_only_kv: False 17 | ipadapter_pretrained_name: "h94/IP-Adapter" 18 | ipadapter_subfolder_name: "models" 19 | ipadapter_weight_name: "ip-adapter_sd15.safetensors" 20 | ipadapter_effect_on: "all" # all, first 21 | 22 | # for self attention 23 | init_self_attn_lora: False 24 | self_attn_lora_rank: 256 25 | self_attn_lora_only_kv: False 26 | 27 | # for self attention ref 28 | init_self_attn_ref: False 29 | self_attn_ref_position: "attn1" 30 | self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers" 31 | self_attn_ref_pixel_wise_crosspond: False 32 | self_attn_ref_effect_on: "all" 33 | 34 | # for multiview attention 35 | init_multiview_attn: True 36 | multiview_attn_position: "attn1" 37 | use_mv_joint_attn: True 38 | num_modalities: 1 39 | 40 | # for unet 41 | init_unet_path: "${pretrained_model_name_or_path}" 42 | cat_condition: True # cat condition to input 43 | 44 | # for cls embedding 45 | init_num_cls_label: 8 # for initialize 46 | cls_labels: [0, 1, 2, 3] # for current task 47 | 48 | trainers: 49 | - trainer_type: "image2mvimage_trainer" 50 | trainer: 51 | pretrained_model_name_or_path: "${pretrained_model_name_or_path}" 52 | attn_config: 53 | cls_labels: [0, 1, 2, 3] # for current task 54 | enable_cross_attn_lora: False 55 | enable_cross_attn_ip: False 56 | enable_self_attn_lora: False 57 | enable_self_attn_ref: False 58 | enable_multiview_attn: True 59 | resolution: "256" 60 | condition_image_resolution: "256" 61 | normal_cls_offset: 4 62 | condition_image_column_name: "conditioning_image" 63 | image_column_name: "image" -------------------------------------------------------------------------------- /app/custom_models/image2normal.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_name_or_path: "lambdalabs/sd-image-variations-diffusers" 2 | mixed_precision: "bf16" 3 | 4 | init_config: 5 | # enable controls 6 | enable_cross_attn_lora: False 7 | enable_cross_attn_ip: False 8 | enable_self_attn_lora: False 9 | enable_self_attn_ref: True 10 | enable_multiview_attn: False 11 | 12 | # for cross attention 13 | init_cross_attn_lora: False 14 | init_cross_attn_ip: False 15 | cross_attn_lora_rank: 512 # 0 for not enabled 16 | cross_attn_lora_only_kv: False 17 | ipadapter_pretrained_name: "h94/IP-Adapter" 18 | ipadapter_subfolder_name: "models" 19 | ipadapter_weight_name: "ip-adapter_sd15.safetensors" 20 | ipadapter_effect_on: "all" # all, first 21 | 22 | # for self attention 23 | init_self_attn_lora: False 24 | self_attn_lora_rank: 512 25 | self_attn_lora_only_kv: False 26 | 27 | # for self attention ref 28 | init_self_attn_ref: True 29 | self_attn_ref_position: "attn1" 30 | self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers" 31 | self_attn_ref_pixel_wise_crosspond: True 32 | self_attn_ref_effect_on: "all" 33 | 34 | # for multiview attention 35 | init_multiview_attn: False 36 | multiview_attn_position: "attn1" 37 | num_modalities: 1 38 | 39 | # for unet 40 | init_unet_path: "${pretrained_model_name_or_path}" 41 | init_num_cls_label: 0 # for initialize 42 | cls_labels: [] # for current task 43 | 44 | trainers: 45 | - trainer_type: "image2image_trainer" 46 | trainer: 47 | pretrained_model_name_or_path: "${pretrained_model_name_or_path}" 48 | attn_config: 49 | cls_labels: [] # for current task 50 | enable_cross_attn_lora: False 51 | enable_cross_attn_ip: False 52 | enable_self_attn_lora: False 53 | enable_self_attn_ref: True 54 | enable_multiview_attn: False 55 | resolution: "512" 56 | 
condition_image_resolution: "512" 57 | condition_image_column_name: "conditioning_image" 58 | image_column_name: "image" 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /app/custom_models/mvimg_prediction.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import gradio as gr 4 | from PIL import Image 5 | import numpy as np 6 | from rembg import remove 7 | from app.utils import change_rgba_bg, rgba_to_rgb 8 | from app.custom_models.utils import load_pipeline 9 | from scripts.all_typing import * 10 | from scripts.utils import session, simple_preprocess 11 | 12 | training_config = "app/custom_models/image2mvimage.yaml" 13 | checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth" 14 | trainer, pipeline = load_pipeline(training_config, checkpoint_path) 15 | # pipeline.enable_model_cpu_offload() 16 | 17 | def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs): 18 | if isinstance(img_list, Image.Image): 19 | img_list = [img_list] 20 | img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list] 21 | ret = [] 22 | for img in img_list: 23 | images = trainer.pipeline_forward( 24 | pipeline=pipeline, 25 | image=img, 26 | guidance_scale=guidance_scale, 27 | **kwargs 28 | ).images 29 | ret.extend(images) 30 | return ret 31 | 32 | 33 | def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145): 34 | if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.: 35 | # still do remove using rembg, since simple_preprocess requires RGBA image 36 | print("RGB image not RGBA! still remove bg!") 37 | remove_bg = True 38 | 39 | if remove_bg: 40 | input_image = remove(input_image, session=session) 41 | 42 | # make front_pil RGBA with white bg 43 | input_image = change_rgba_bg(input_image, "white") 44 | single_image = simple_preprocess(input_image) 45 | 46 | generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed >= 0 else None 47 | 48 | rgb_pils = predict( 49 | single_image, 50 | generator=generator, 51 | guidance_scale=guidance_scale, 52 | width=256, 53 | height=256, 54 | num_inference_steps=30, 55 | ) 56 | 57 | return rgb_pils, single_image 58 | -------------------------------------------------------------------------------- /app/custom_models/normal_prediction.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from PIL import Image 3 | from app.utils import rgba_to_rgb, simple_remove 4 | from app.custom_models.utils import load_pipeline 5 | from scripts.utils import rotate_normals_torch 6 | from scripts.all_typing import * 7 | 8 | training_config = "app/custom_models/image2normal.yaml" 9 | checkpoint_path = "ckpt/image2normal/unet_state_dict.pth" 10 | trainer, pipeline = load_pipeline(training_config, checkpoint_path) 11 | # pipeline.enable_model_cpu_offload() 12 | 13 | def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs): 14 | img_list = image if isinstance(image, list) else [image] 15 | img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list] 16 | images = trainer.pipeline_forward( 17 | pipeline=pipeline, 18 | image=img_list, 19 | num_inference_steps=num_inference_steps, 20 | guidance_scale=guidance_scale, 21 | **kwargs 22 | ).images 23 | images = simple_remove(images) 24 | if do_rotate and len(images) > 1: 25 | images = rotate_normals_torch(images, 
return_types='pil') 26 | return images -------------------------------------------------------------------------------- /app/custom_models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List 3 | from dataclasses import dataclass 4 | from app.utils import rgba_to_rgb 5 | from custum_3d_diffusion.trainings.config_classes import ExprimentConfig, TrainerSubConfig 6 | from custum_3d_diffusion import modules 7 | from custum_3d_diffusion.custum_modules.unifield_processor import AttnConfig, ConfigurableUNet2DConditionModel 8 | from custum_3d_diffusion.trainings.base import BasicTrainer 9 | from custum_3d_diffusion.trainings.utils import load_config 10 | 11 | 12 | @dataclass 13 | class FakeAccelerator: 14 | device: torch.device = torch.device("cuda") 15 | 16 | 17 | def init_trainers(cfg_path: str, weight_dtype: torch.dtype, extras: dict): 18 | accelerator = FakeAccelerator() 19 | cfg: ExprimentConfig = load_config(ExprimentConfig, cfg_path, extras) 20 | init_config: AttnConfig = load_config(AttnConfig, cfg.init_config) 21 | configurable_unet = ConfigurableUNet2DConditionModel(init_config, weight_dtype) 22 | configurable_unet.enable_xformers_memory_efficient_attention() 23 | trainer_cfgs: List[TrainerSubConfig] = [load_config(TrainerSubConfig, trainer) for trainer in cfg.trainers] 24 | trainers: List[BasicTrainer] = [modules.find(trainer.trainer_type)(accelerator, None, configurable_unet, trainer.trainer, weight_dtype, i) for i, trainer in enumerate(trainer_cfgs)] 25 | return trainers, configurable_unet 26 | 27 | from app.utils import make_image_grid, split_image 28 | def process_image(function, img, guidance_scale=2., merged_image=False, remove_bg=True): 29 | from rembg import remove 30 | if remove_bg: 31 | img = remove(img) 32 | img = rgba_to_rgb(img) 33 | if merged_image: 34 | img = split_image(img, rows=2) 35 | images = function( 36 | image=img, 37 | guidance_scale=guidance_scale, 38 | ) 39 | if len(images) > 1: 40 | return make_image_grid(images, rows=2) 41 | else: 42 | return images[0] 43 | 44 | 45 | def process_text(trainer, pipeline, img, guidance_scale=2.): 46 | pipeline.cfg.validation_prompts = [img] 47 | titles, images = trainer.batched_validation_forward(pipeline, guidance_scale=[guidance_scale]) 48 | return images[0] 49 | 50 | 51 | def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight_dtype = torch.bfloat16): 52 | training_config = config_path 53 | load_from_checkpoint = ckpt_path 54 | extras = [] 55 | device = "cuda" 56 | trainers, configurable_unet = init_trainers(training_config, weight_dtype, extras) 57 | shared_modules = dict() 58 | for trainer in trainers: 59 | shared_modules = trainer.init_shared_modules(shared_modules) 60 | 61 | if load_from_checkpoint is not None: 62 | state_dict = torch.load(load_from_checkpoint) 63 | configurable_unet.unet.load_state_dict(state_dict, strict=False) 64 | # Move unet, vae and text_encoder to device and cast to weight_dtype 65 | configurable_unet.unet.to(device, dtype=weight_dtype) 66 | 67 | pipeline = None 68 | trainer_out = None 69 | for trainer in trainers: 70 | if pipeline_filter(trainer.cfg.trainer_name): 71 | pipeline = trainer.construct_pipeline(shared_modules, configurable_unet.unet) 72 | pipeline.set_progress_bar_config(disable=False) 73 | trainer_out = trainer 74 | pipeline = pipeline.to(device) 75 | return trainer_out, pipeline -------------------------------------------------------------------------------- 
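For reference, a minimal usage sketch of `load_pipeline` defined above — it mirrors how `app/custom_models/mvimg_prediction.py` calls it rather than introducing any new API; the yaml/checkpoint paths assume the `ckpt/` layout from the README, and the example image and output filename are illustrative only:

```python
# Minimal usage sketch (not one of the repository files): load the image-to-multiview
# pipeline via load_pipeline() and run one forward pass, as mvimg_prediction.py does.
import torch
from PIL import Image

from app.custom_models.utils import load_pipeline

torch.set_grad_enabled(False)

trainer, pipeline = load_pipeline(
    "app/custom_models/image2mvimage.yaml",   # training config shipped with the repo
    "ckpt/img2mvimg/unet_state_dict.pth",     # checkpoint path from the README layout
)

images = trainer.pipeline_forward(
    pipeline=pipeline,
    image=Image.open("app/examples/aaa.png").convert("RGB"),  # any bundled example image
    guidance_scale=1.5,
    width=256,
    height=256,
    num_inference_steps=30,
).images

images[0].save("mv_view0.png")  # illustrative output filename
```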
/app/examples/Groot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/Groot.png -------------------------------------------------------------------------------- /app/examples/aaa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/aaa.png -------------------------------------------------------------------------------- /app/examples/abma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/abma.png -------------------------------------------------------------------------------- /app/examples/akun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/akun.png -------------------------------------------------------------------------------- /app/examples/anya.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/anya.png -------------------------------------------------------------------------------- /app/examples/bag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/bag.png -------------------------------------------------------------------------------- /app/examples/ex1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/ex1.png -------------------------------------------------------------------------------- /app/examples/ex2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/ex2.png -------------------------------------------------------------------------------- /app/examples/ex3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/ex3.jpg -------------------------------------------------------------------------------- /app/examples/ex4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/ex4.png -------------------------------------------------------------------------------- /app/examples/generated_1715761545_frame0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/generated_1715761545_frame0.png -------------------------------------------------------------------------------- /app/examples/generated_1715762357_frame0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/generated_1715762357_frame0.png -------------------------------------------------------------------------------- /app/examples/generated_1715763329_frame0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/generated_1715763329_frame0.png -------------------------------------------------------------------------------- /app/examples/hatsune_miku.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/hatsune_miku.png -------------------------------------------------------------------------------- /app/examples/princess-large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/princess-large.png -------------------------------------------------------------------------------- /app/gradio_3dgen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gradio as gr 3 | from PIL import Image 4 | from pytorch3d.structures import Meshes 5 | from app.utils import clean_up 6 | from app.custom_models.mvimg_prediction import run_mvprediction 7 | from app.custom_models.normal_prediction import predict_normals 8 | from scripts.refine_lr_to_sr import run_sr_fast 9 | from scripts.utils import save_glb_and_video 10 | from scripts.multiview_inference import geo_reconstruct 11 | 12 | def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"): 13 | if preview_img is None: 14 | raise gr.Error("preview_img is none") 15 | if isinstance(preview_img, str): 16 | preview_img = Image.open(preview_img) 17 | 18 | if preview_img.size[0] <= 512: 19 | preview_img = run_sr_fast([preview_img])[0] 20 | rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s 21 | new_meshes = geo_reconstruct(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type) 22 | vertices = new_meshes.verts_packed() 23 | vertices = vertices / 2 * 1.35 24 | vertices[..., [0, 2]] = - vertices[..., [0, 2]] 25 | new_meshes = Meshes(verts=[vertices], faces=new_meshes.faces_list(), textures=new_meshes.textures) 26 | 27 | ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video) 28 | return ret_mesh, video 29 | 30 | ####################################### 31 | def create_ui(concurrency_id="wkl"): 32 | with gr.Row(): 33 | with gr.Column(scale=2): 34 | input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview') 35 | 36 | example_folder = os.path.join(os.path.dirname(__file__), "./examples") 37 | example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)]) 38 | gr.Examples( 39 | examples=example_fns, 40 | inputs=[input_image], 41 | cache_examples=False, 42 | label='Examples (click one of the images below to start)', 43 | examples_per_page=12 44 | ) 45 | 46 | 47 | with gr.Column(scale=3): 48 | # export mesh display 49 | output_mesh = gr.Model3D(value=None, 
label="Mesh Model", show_label=True, height=320) 50 | output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False) 51 | 52 | input_processing = gr.Checkbox( 53 | value=True, 54 | label='Remove Background', 55 | visible=True, 56 | ) 57 | do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False) 58 | expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False) 59 | init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False) 60 | setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed") 61 | render_video = gr.Checkbox(value=False, visible=False, label="generate video") 62 | fullrunv2_btn = gr.Button('Generate 3D', interactive=True) 63 | 64 | fullrunv2_btn.click( 65 | fn = generate3dv2, 66 | inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type], 67 | outputs=[output_mesh, output_video], 68 | concurrency_id=concurrency_id, 69 | api_name="generate3dv2", 70 | ).success(clean_up, api_name=False) 71 | return input_image 72 | -------------------------------------------------------------------------------- /app/gradio_3dgen_steps.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from PIL import Image 3 | 4 | from app.custom_models.mvimg_prediction import run_mvprediction 5 | from app.utils import make_image_grid, split_image 6 | from scripts.utils import save_glb_and_video 7 | 8 | def concept_to_multiview(preview_img, input_processing, seed, guidance=1.): 9 | seed = int(seed) 10 | if preview_img is None: 11 | raise gr.Error("preview_img is none.") 12 | if isinstance(preview_img, str): 13 | preview_img = Image.open(preview_img) 14 | 15 | rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=seed, guidance_scale=guidance) 16 | rgb_pil = make_image_grid(rgb_pils, rows=2) 17 | return rgb_pil, front_pil 18 | 19 | def concept_to_multiview_ui(concurrency_id="wkl"): 20 | with gr.Row(): 21 | with gr.Column(scale=2): 22 | preview_img = gr.Image(type='pil', image_mode='RGBA', label='Frontview') 23 | input_processing = gr.Checkbox( 24 | value=True, 25 | label='Remove Background', 26 | ) 27 | seed = gr.Slider(minimum=-1, maximum=1000000000, value=-1, step=1.0, label="seed") 28 | guidance = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, label="Guidance Scale", step=0.5) 29 | run_btn = gr.Button('Generate Multiview', interactive=True) 30 | with gr.Column(scale=3): 31 | # export mesh display 32 | output_rgb = gr.Image(type='pil', label="RGB", show_label=True) 33 | output_front = gr.Image(type='pil', image_mode='RGBA', label="Frontview", show_label=True) 34 | run_btn.click( 35 | fn = concept_to_multiview, 36 | inputs=[preview_img, input_processing, seed, guidance], 37 | outputs=[output_rgb, output_front], 38 | concurrency_id=concurrency_id, 39 | api_name=False, 40 | ) 41 | return output_rgb, output_front 42 | 43 | from app.custom_models.normal_prediction import predict_normals 44 | from scripts.multiview_inference import geo_reconstruct 45 | def multiview_to_mesh_v2(rgb_pil, normal_pil, front_pil, do_refine=False, expansion_weight=0.1, init_type="std"): 46 | rgb_pils = split_image(rgb_pil, rows=2) 47 | if normal_pil is not None: 48 | normal_pil = split_image(normal_pil, rows=2) 49 | if front_pil is None: 50 | front_pil = rgb_pils[0] 51 | new_meshes = geo_reconstruct(rgb_pils, 
normal_pil, front_pil, do_refine=do_refine, predict_normal=normal_pil is None, expansion_weight=expansion_weight, init_type=init_type) 52 | ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=False) 53 | return ret_mesh 54 | 55 | def new_multiview_to_mesh_ui(concurrency_id="wkl"): 56 | with gr.Row(): 57 | with gr.Column(scale=2): 58 | rgb_pil = gr.Image(type='pil', image_mode='RGB', label='RGB') 59 | front_pil = gr.Image(type='pil', image_mode='RGBA', label='Frontview (Optional)') 60 | normal_pil = gr.Image(type='pil', image_mode='RGBA', label='Normal (Optional)') 61 | do_refine = gr.Checkbox( 62 | value=False, 63 | label='Refine RGB', 64 | visible=False, 65 | ) 66 | expansion_weight = gr.Slider(minimum=-1.0, maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False) 67 | init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False) 68 | run_btn = gr.Button('Generate 3D', interactive=True) 69 | with gr.Column(scale=3): 70 | # export mesh display 71 | output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True) 72 | run_btn.click( 73 | fn = multiview_to_mesh_v2, 74 | inputs=[rgb_pil, normal_pil, front_pil, do_refine, expansion_weight, init_type], 75 | outputs=[output_mesh], 76 | concurrency_id=concurrency_id, 77 | api_name="multiview_to_mesh", 78 | ) 79 | return rgb_pil, front_pil, output_mesh 80 | 81 | 82 | ####################################### 83 | def create_step_ui(concurrency_id="wkl"): 84 | with gr.Tab(label="3D:concept_to_multiview"): 85 | concept_to_multiview_ui(concurrency_id) 86 | with gr.Tab(label="3D:new_multiview_to_mesh"): 87 | new_multiview_to_mesh_ui(concurrency_id) 88 | -------------------------------------------------------------------------------- /app/gradio_local.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import os 3 | import sys 4 | sys.path.append(os.curdir) 5 | if 'CUDA_VISIBLE_DEVICES' not in os.environ: 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 7 | os.environ['TRANSFORMERS_OFFLINE']='0' 8 | os.environ['DIFFUSERS_OFFLINE']='0' 9 | os.environ['HF_HUB_OFFLINE']='0' 10 | os.environ['GRADIO_ANALYTICS_ENABLED']='False' 11 | os.environ['HF_ENDPOINT']='https://hf-mirror.com' 12 | import torch 13 | torch.set_float32_matmul_precision('medium') 14 | torch.backends.cuda.matmul.allow_tf32 = True 15 | torch.set_grad_enabled(False) 16 | 17 | import gradio as gr 18 | import argparse 19 | 20 | from app.gradio_3dgen import create_ui as create_3d_ui 21 | # from app.gradio_3dgen_steps import create_step_ui 22 | from app.all_models import model_zoo 23 | 24 | 25 | _TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image''' 26 | _DESCRIPTION = ''' 27 | [Project page](https://wukailu.github.io/Unique3D/) 28 | 29 | * High-fidelity and diverse textured meshes generated by Unique3D from single-view images. 30 | 31 | **If the Gradio Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (to get a registration invitation code, join the Discord: https://discord.gg/aiuni).
However, the Online Demo is slightly different from the Gradio Demo, in that the inference speed is slower, but the generation is much more stable.** 32 | ''' 33 | 34 | def launch( 35 | port, 36 | listen=False, 37 | share=False, 38 | gradio_root="", 39 | ): 40 | model_zoo.init_models() 41 | 42 | with gr.Blocks( 43 | title=_TITLE, 44 | theme=gr.themes.Monochrome(), 45 | ) as demo: 46 | with gr.Row(): 47 | with gr.Column(scale=1): 48 | gr.Markdown('# ' + _TITLE) 49 | gr.Markdown(_DESCRIPTION) 50 | create_3d_ui("wkl") 51 | 52 | launch_args = {} 53 | if listen: 54 | launch_args["server_name"] = "0.0.0.0" 55 | 56 | demo.queue(default_concurrency_limit=1).launch( 57 | server_port=None if port == 0 else port, 58 | share=share, 59 | root_path=gradio_root if gradio_root != "" else None, # "/myapp" 60 | **launch_args, 61 | ) 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | args, extra = parser.parse_known_args() 66 | parser.add_argument("--listen", action="store_true") 67 | parser.add_argument("--port", type=int, default=0) 68 | parser.add_argument("--share", action="store_true") 69 | parser.add_argument("--gradio_root", default="") 70 | args = parser.parse_args() 71 | launch( 72 | args.port, 73 | listen=args.listen, 74 | share=args.share, 75 | gradio_root=args.gradio_root, 76 | ) -------------------------------------------------------------------------------- /app/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | import gc 5 | import numpy as np 6 | import numpy as np 7 | from PIL import Image 8 | from scripts.refine_lr_to_sr import run_sr_fast 9 | 10 | GRADIO_CACHE = "/tmp/gradio/" 11 | 12 | def clean_up(): 13 | torch.cuda.empty_cache() 14 | gc.collect() 15 | 16 | def remove_color(arr): 17 | if arr.shape[-1] == 4: 18 | arr = arr[..., :3] 19 | # calc diffs 20 | base = arr[0, 0] 21 | diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1) 22 | alpha = (diffs <= 80) 23 | 24 | arr[alpha] = 255 25 | alpha = ~alpha 26 | arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1) 27 | return arr 28 | 29 | def simple_remove(imgs, run_sr=True): 30 | """Only works for normal""" 31 | if not isinstance(imgs, list): 32 | imgs = [imgs] 33 | single_input = True 34 | else: 35 | single_input = False 36 | if run_sr: 37 | imgs = run_sr_fast(imgs) 38 | rets = [] 39 | for img in imgs: 40 | arr = np.array(img) 41 | arr = remove_color(arr) 42 | rets.append(Image.fromarray(arr.astype(np.uint8))) 43 | if single_input: 44 | return rets[0] 45 | return rets 46 | 47 | def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"): 48 | new_image = Image.new("RGBA", rgba.size, bkgd) 49 | new_image.paste(rgba, (0, 0), rgba) 50 | new_image = new_image.convert('RGB') 51 | return new_image 52 | 53 | def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"): 54 | rgb_white = rgba_to_rgb(rgba, bkgd) 55 | new_rgba = Image.fromarray(np.concatenate([np.array(rgb_white), np.array(rgba)[:, :, 3:4]], axis=-1)) 56 | return new_rgba 57 | 58 | def split_image(image, rows=None, cols=None): 59 | """ 60 | inverse function of make_image_grid 61 | """ 62 | # image is in square 63 | if rows is None and cols is None: 64 | # image.size [W, H] 65 | rows = 1 66 | cols = image.size[0] // image.size[1] 67 | assert cols * image.size[1] == image.size[0] 68 | subimg_size = image.size[1] 69 | elif rows is None: 70 | subimg_size = image.size[0] // cols 71 | rows = image.size[1] // subimg_size 72 | 
assert rows * subimg_size == image.size[1] 73 | elif cols is None: 74 | subimg_size = image.size[1] // rows 75 | cols = image.size[0] // subimg_size 76 | assert cols * subimg_size == image.size[0] 77 | else: 78 | subimg_size = image.size[1] // rows 79 | assert cols * subimg_size == image.size[0] 80 | subimgs = [] 81 | for i in range(rows): 82 | for j in range(cols): 83 | subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size)) 84 | subimgs.append(subimg) 85 | return subimgs 86 | 87 | def make_image_grid(images, rows=None, cols=None, resize=None): 88 | if rows is None and cols is None: 89 | rows = 1 90 | cols = len(images) 91 | if rows is None: 92 | rows = len(images) // cols 93 | if len(images) % cols != 0: 94 | rows += 1 95 | if cols is None: 96 | cols = len(images) // rows 97 | if len(images) % rows != 0: 98 | cols += 1 99 | total_imgs = rows * cols 100 | if total_imgs > len(images): 101 | images += [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))] 102 | 103 | if resize is not None: 104 | images = [img.resize((resize, resize)) for img in images] 105 | 106 | w, h = images[0].size 107 | grid = Image.new(images[0].mode, size=(cols * w, rows * h)) 108 | 109 | for i, img in enumerate(images): 110 | grid.paste(img, box=(i % cols * w, i // cols * h)) 111 | return grid 112 | 113 | -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/assets/teaser.jpg -------------------------------------------------------------------------------- /assets/teaser_safe.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/assets/teaser_safe.jpg -------------------------------------------------------------------------------- /custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # modified by Wuvin 15 | 16 | 17 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline 23 | from diffusers.schedulers import KarrasDiffusionSchedulers 24 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput 25 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel 26 | from PIL import Image 27 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 28 | 29 | 30 | 31 | class StableDiffusionImageCustomPipeline( 32 | StableDiffusionImageVariationPipeline 33 | ): 34 | def __init__( 35 | self, 36 | vae: AutoencoderKL, 37 | image_encoder: CLIPVisionModelWithProjection, 38 | unet: UNet2DConditionModel, 39 | scheduler: KarrasDiffusionSchedulers, 40 | safety_checker: StableDiffusionSafetyChecker, 41 | feature_extractor: CLIPImageProcessor, 42 | requires_safety_checker: bool = True, 43 | latents_offset=None, 44 | noisy_cond_latents=False, 45 | ): 46 | super().__init__( 47 | vae=vae, 48 | image_encoder=image_encoder, 49 | unet=unet, 50 | scheduler=scheduler, 51 | safety_checker=safety_checker, 52 | feature_extractor=feature_extractor, 53 | requires_safety_checker=requires_safety_checker 54 | ) 55 | latents_offset = tuple(latents_offset) if latents_offset is not None else None 56 | self.latents_offset = latents_offset 57 | if latents_offset is not None: 58 | self.register_to_config(latents_offset=latents_offset) 59 | self.noisy_cond_latents = noisy_cond_latents 60 | self.register_to_config(noisy_cond_latents=noisy_cond_latents) 61 | 62 | def encode_latents(self, image, device, dtype, height, width): 63 | # support batchsize > 1 64 | if isinstance(image, Image.Image): 65 | image = [image] 66 | image = [img.convert("RGB") for img in image] 67 | images = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype) 68 | latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor 69 | if self.latents_offset is not None: 70 | return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None] 71 | else: 72 | return latents 73 | 74 | def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): 75 | dtype = next(self.image_encoder.parameters()).dtype 76 | 77 | if not isinstance(image, torch.Tensor): 78 | image = self.feature_extractor(images=image, return_tensors="pt").pixel_values 79 | 80 | image = image.to(device=device, dtype=dtype) 81 | image_embeddings = self.image_encoder(image).image_embeds 82 | image_embeddings = image_embeddings.unsqueeze(1) 83 | 84 | # duplicate image embeddings for each generation per prompt, using mps friendly method 85 | bs_embed, seq_len, _ = image_embeddings.shape 86 | image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) 87 | image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 88 | 89 | if do_classifier_free_guidance: 90 | # NOTE: the same as original code 91 | negative_prompt_embeds = torch.zeros_like(image_embeddings) 92 | # For classifier free guidance, we need to do two forward passes. 
93 | # Here we concatenate the unconditional and text embeddings into a single batch 94 | # to avoid doing two forward passes 95 | image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) 96 | 97 | return image_embeddings 98 | 99 | @torch.no_grad() 100 | def __call__( 101 | self, 102 | image: Union[Image.Image, List[Image.Image], torch.FloatTensor], 103 | height: Optional[int] = 1024, 104 | width: Optional[int] = 1024, 105 | height_cond: Optional[int] = 512, 106 | width_cond: Optional[int] = 512, 107 | num_inference_steps: int = 50, 108 | guidance_scale: float = 7.5, 109 | num_images_per_prompt: Optional[int] = 1, 110 | eta: float = 0.0, 111 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 112 | latents: Optional[torch.FloatTensor] = None, 113 | output_type: Optional[str] = "pil", 114 | return_dict: bool = True, 115 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 116 | callback_steps: int = 1, 117 | upper_left_feature: bool = False, 118 | ): 119 | r""" 120 | The call function to the pipeline for generation. 121 | 122 | Args: 123 | image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`): 124 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with 125 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). 126 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 127 | The height in pixels of the generated image. 128 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 129 | The width in pixels of the generated image. 130 | num_inference_steps (`int`, *optional*, defaults to 50): 131 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 132 | expense of slower inference. This parameter is modulated by `strength`. 133 | guidance_scale (`float`, *optional*, defaults to 7.5): 134 | A higher guidance scale value encourages the model to generate images closely linked to the text 135 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 136 | num_images_per_prompt (`int`, *optional*, defaults to 1): 137 | The number of images to generate per prompt. 138 | eta (`float`, *optional*, defaults to 0.0): 139 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 140 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 141 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 142 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 143 | generation deterministic. 144 | latents (`torch.FloatTensor`, *optional*): 145 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 146 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 147 | tensor is generated by sampling using the supplied random `generator`. 148 | output_type (`str`, *optional*, defaults to `"pil"`): 149 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 150 | return_dict (`bool`, *optional*, defaults to `True`): 151 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 152 | plain tuple. 
153 | callback (`Callable`, *optional*): 154 | A function that calls every `callback_steps` steps during inference. The function is called with the 155 | following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 156 | callback_steps (`int`, *optional*, defaults to 1): 157 | The frequency at which the `callback` function is called. If not specified, the callback is called at 158 | every step. 159 | 160 | Returns: 161 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 162 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 163 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 164 | second element is a list of `bool`s indicating whether the corresponding generated image contains 165 | "not-safe-for-work" (nsfw) content. 166 | 167 | Examples: 168 | 169 | ```py 170 | from diffusers import StableDiffusionImageVariationPipeline 171 | from PIL import Image 172 | from io import BytesIO 173 | import requests 174 | 175 | pipe = StableDiffusionImageVariationPipeline.from_pretrained( 176 | "lambdalabs/sd-image-variations-diffusers", revision="v2.0" 177 | ) 178 | pipe = pipe.to("cuda") 179 | 180 | url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200" 181 | 182 | response = requests.get(url) 183 | image = Image.open(BytesIO(response.content)).convert("RGB") 184 | 185 | out = pipe(image, num_images_per_prompt=3, guidance_scale=15) 186 | out["images"][0].save("result.jpg") 187 | ``` 188 | """ 189 | # 0. Default height and width to unet 190 | height = height or self.unet.config.sample_size * self.vae_scale_factor 191 | width = width or self.unet.config.sample_size * self.vae_scale_factor 192 | 193 | # 1. Check inputs. Raise error if not correct 194 | self.check_inputs(image, height, width, callback_steps) 195 | 196 | # 2. Define call parameters 197 | if isinstance(image, Image.Image): 198 | batch_size = 1 199 | elif isinstance(image, list): 200 | batch_size = len(image) 201 | else: 202 | batch_size = image.shape[0] 203 | device = self._execution_device 204 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 205 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 206 | # corresponds to doing no classifier free guidance. 207 | do_classifier_free_guidance = guidance_scale > 1.0 208 | 209 | # 3. Encode input image 210 | if isinstance(image, Image.Image) and upper_left_feature: 211 | # only use the first one of four images 212 | emb_image = image.crop((0, 0, image.size[0] // 2, image.size[1] // 2)) 213 | else: 214 | emb_image = image 215 | 216 | image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance) 217 | cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond) 218 | 219 | # 4. Prepare timesteps 220 | self.scheduler.set_timesteps(num_inference_steps, device=device) 221 | timesteps = self.scheduler.timesteps 222 | 223 | # 5. Prepare latent variables 224 | num_channels_latents = self.unet.config.out_channels 225 | latents = self.prepare_latents( 226 | batch_size * num_images_per_prompt, 227 | num_channels_latents, 228 | height, 229 | width, 230 | image_embeddings.dtype, 231 | device, 232 | generator, 233 | latents, 234 | ) 235 | 236 | # 6. Prepare extra step kwargs. 
237 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 238 | 239 | # 7. Denoising loop 240 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 241 | with self.progress_bar(total=num_inference_steps) as progress_bar: 242 | for i, t in enumerate(timesteps): 243 | if self.noisy_cond_latents: 244 | raise ValueError("Noisy condition latents is not recommended.") 245 | else: 246 | noisy_cond_latents = cond_latents 247 | 248 | noisy_cond_latents = torch.cat([torch.zeros_like(noisy_cond_latents), noisy_cond_latents]) if do_classifier_free_guidance else noisy_cond_latents 249 | # expand the latents if we are doing classifier free guidance 250 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 251 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 252 | 253 | # predict the noise residual 254 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=noisy_cond_latents).sample 255 | 256 | # perform guidance 257 | if do_classifier_free_guidance: 258 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 259 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 260 | 261 | # compute the previous noisy sample x_t -> x_t-1 262 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 263 | 264 | # call the callback, if provided 265 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 266 | progress_bar.update() 267 | if callback is not None and i % callback_steps == 0: 268 | step_idx = i // getattr(self.scheduler, "order", 1) 269 | callback(step_idx, t, latents) 270 | 271 | self.maybe_free_model_hooks() 272 | 273 | if self.latents_offset is not None: 274 | latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None] 275 | 276 | if not output_type == "latent": 277 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 278 | image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) 279 | else: 280 | image = latents 281 | has_nsfw_concept = None 282 | 283 | if has_nsfw_concept is None: 284 | do_denormalize = [True] * image.shape[0] 285 | else: 286 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 287 | 288 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) 289 | 290 | self.maybe_free_model_hooks() 291 | 292 | if not return_dict: 293 | return (image, has_nsfw_concept) 294 | 295 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 296 | 297 | if __name__ == "__main__": 298 | pass 299 | -------------------------------------------------------------------------------- /custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # modified by Wuvin 15 | 16 | 17 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline 23 | from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler 24 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput 25 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel 26 | from PIL import Image 27 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 28 | 29 | 30 | 31 | class StableDiffusionImage2MVCustomPipeline( 32 | StableDiffusionImageVariationPipeline 33 | ): 34 | def __init__( 35 | self, 36 | vae: AutoencoderKL, 37 | image_encoder: CLIPVisionModelWithProjection, 38 | unet: UNet2DConditionModel, 39 | scheduler: KarrasDiffusionSchedulers, 40 | safety_checker: StableDiffusionSafetyChecker, 41 | feature_extractor: CLIPImageProcessor, 42 | requires_safety_checker: bool = True, 43 | latents_offset=None, 44 | noisy_cond_latents=False, 45 | condition_offset=True, 46 | ): 47 | super().__init__( 48 | vae=vae, 49 | image_encoder=image_encoder, 50 | unet=unet, 51 | scheduler=scheduler, 52 | safety_checker=safety_checker, 53 | feature_extractor=feature_extractor, 54 | requires_safety_checker=requires_safety_checker 55 | ) 56 | latents_offset = tuple(latents_offset) if latents_offset is not None else None 57 | self.latents_offset = latents_offset 58 | if latents_offset is not None: 59 | self.register_to_config(latents_offset=latents_offset) 60 | if noisy_cond_latents: 61 | raise NotImplementedError("Noisy condition latents not supported Now.") 62 | self.condition_offset = condition_offset 63 | self.register_to_config(condition_offset=condition_offset) 64 | 65 | def encode_latents(self, image: Image.Image, device, dtype, height, width): 66 | images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype) 67 | # NOTE: .mode() for condition 68 | latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor 69 | if self.latents_offset is not None and self.condition_offset: 70 | return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None] 71 | else: 72 | return latents 73 | 74 | def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): 75 | dtype = next(self.image_encoder.parameters()).dtype 76 | 77 | if not isinstance(image, torch.Tensor): 78 | image = self.feature_extractor(images=image, return_tensors="pt").pixel_values 79 | 80 | image = image.to(device=device, dtype=dtype) 81 | image_embeddings = self.image_encoder(image).image_embeds 82 | image_embeddings = image_embeddings.unsqueeze(1) 83 | 84 | # duplicate image embeddings for each generation per prompt, using mps friendly method 85 | bs_embed, seq_len, _ = image_embeddings.shape 86 | image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) 87 | image_embeddings = 
image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 88 | 89 | if do_classifier_free_guidance: 90 | # NOTE: the same as original code 91 | negative_prompt_embeds = torch.zeros_like(image_embeddings) 92 | # For classifier free guidance, we need to do two forward passes. 93 | # Here we concatenate the unconditional and text embeddings into a single batch 94 | # to avoid doing two forward passes 95 | image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) 96 | 97 | return image_embeddings 98 | 99 | @torch.no_grad() 100 | def __call__( 101 | self, 102 | image: Union[Image.Image, List[Image.Image], torch.FloatTensor], 103 | height: Optional[int] = 1024, 104 | width: Optional[int] = 1024, 105 | height_cond: Optional[int] = 512, 106 | width_cond: Optional[int] = 512, 107 | num_inference_steps: int = 50, 108 | guidance_scale: float = 7.5, 109 | num_images_per_prompt: Optional[int] = 1, 110 | eta: float = 0.0, 111 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 112 | latents: Optional[torch.FloatTensor] = None, 113 | output_type: Optional[str] = "pil", 114 | return_dict: bool = True, 115 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 116 | callback_steps: int = 1, 117 | ): 118 | r""" 119 | The call function to the pipeline for generation. 120 | 121 | Args: 122 | image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`): 123 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with 124 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). 125 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 126 | The height in pixels of the generated image. 127 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 128 | The width in pixels of the generated image. 129 | num_inference_steps (`int`, *optional*, defaults to 50): 130 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 131 | expense of slower inference. This parameter is modulated by `strength`. 132 | guidance_scale (`float`, *optional*, defaults to 7.5): 133 | A higher guidance scale value encourages the model to generate images closely linked to the text 134 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 135 | num_images_per_prompt (`int`, *optional*, defaults to 1): 136 | The number of images to generate per prompt. 137 | eta (`float`, *optional*, defaults to 0.0): 138 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 139 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 140 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 141 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 142 | generation deterministic. 143 | latents (`torch.FloatTensor`, *optional*): 144 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 145 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 146 | tensor is generated by sampling using the supplied random `generator`. 147 | output_type (`str`, *optional*, defaults to `"pil"`): 148 | The output format of the generated image. 
Choose between `PIL.Image` or `np.array`. 149 | return_dict (`bool`, *optional*, defaults to `True`): 150 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 151 | plain tuple. 152 | callback (`Callable`, *optional*): 153 | A function that calls every `callback_steps` steps during inference. The function is called with the 154 | following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 155 | callback_steps (`int`, *optional*, defaults to 1): 156 | The frequency at which the `callback` function is called. If not specified, the callback is called at 157 | every step. 158 | 159 | Returns: 160 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 161 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 162 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 163 | second element is a list of `bool`s indicating whether the corresponding generated image contains 164 | "not-safe-for-work" (nsfw) content. 165 | 166 | Examples: 167 | 168 | ```py 169 | from diffusers import StableDiffusionImageVariationPipeline 170 | from PIL import Image 171 | from io import BytesIO 172 | import requests 173 | 174 | pipe = StableDiffusionImageVariationPipeline.from_pretrained( 175 | "lambdalabs/sd-image-variations-diffusers", revision="v2.0" 176 | ) 177 | pipe = pipe.to("cuda") 178 | 179 | url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200" 180 | 181 | response = requests.get(url) 182 | image = Image.open(BytesIO(response.content)).convert("RGB") 183 | 184 | out = pipe(image, num_images_per_prompt=3, guidance_scale=15) 185 | out["images"][0].save("result.jpg") 186 | ``` 187 | """ 188 | # 0. Default height and width to unet 189 | height = height or self.unet.config.sample_size * self.vae_scale_factor 190 | width = width or self.unet.config.sample_size * self.vae_scale_factor 191 | 192 | # 1. Check inputs. Raise error if not correct 193 | self.check_inputs(image, height, width, callback_steps) 194 | 195 | # 2. Define call parameters 196 | if isinstance(image, Image.Image): 197 | batch_size = 1 198 | elif len(image) == 1: 199 | image = image[0] 200 | batch_size = 1 201 | else: 202 | raise NotImplementedError() 203 | # elif isinstance(image, list): 204 | # batch_size = len(image) 205 | # else: 206 | # batch_size = image.shape[0] 207 | device = self._execution_device 208 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 209 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 210 | # corresponds to doing no classifier free guidance. 211 | do_classifier_free_guidance = guidance_scale > 1.0 212 | 213 | # 3. Encode input image 214 | emb_image = image 215 | 216 | image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance) 217 | cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond) 218 | cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents 219 | image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values 220 | if do_classifier_free_guidance: 221 | image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0) 222 | 223 | # 4. 
Prepare timesteps 224 | self.scheduler.set_timesteps(num_inference_steps, device=device) 225 | timesteps = self.scheduler.timesteps 226 | 227 | # 5. Prepare latent variables 228 | num_channels_latents = self.unet.config.out_channels 229 | latents = self.prepare_latents( 230 | batch_size * num_images_per_prompt, 231 | num_channels_latents, 232 | height, 233 | width, 234 | image_embeddings.dtype, 235 | device, 236 | generator, 237 | latents, 238 | ) 239 | 240 | 241 | # 6. Prepare extra step kwargs. 242 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 243 | # 7. Denoising loop 244 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 245 | with self.progress_bar(total=num_inference_steps) as progress_bar: 246 | for i, t in enumerate(timesteps): 247 | # expand the latents if we are doing classifier free guidance 248 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 249 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 250 | 251 | # predict the noise residual 252 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=cond_latents, noisy_condition_input=False, cond_pixels_clip=image_pixels).sample 253 | 254 | # perform guidance 255 | if do_classifier_free_guidance: 256 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 257 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 258 | 259 | # compute the previous noisy sample x_t -> x_t-1 260 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 261 | 262 | # call the callback, if provided 263 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 264 | progress_bar.update() 265 | if callback is not None and i % callback_steps == 0: 266 | step_idx = i // getattr(self.scheduler, "order", 1) 267 | callback(step_idx, t, latents) 268 | 269 | self.maybe_free_model_hooks() 270 | 271 | if self.latents_offset is not None: 272 | latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None] 273 | 274 | if not output_type == "latent": 275 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 276 | image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) 277 | else: 278 | image = latents 279 | has_nsfw_concept = None 280 | 281 | if has_nsfw_concept is None: 282 | do_denormalize = [True] * image.shape[0] 283 | else: 284 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 285 | 286 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) 287 | 288 | self.maybe_free_model_hooks() 289 | 290 | if not return_dict: 291 | return (image, has_nsfw_concept) 292 | 293 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 294 | 295 | if __name__ == "__main__": 296 | pass 297 | -------------------------------------------------------------------------------- /custum_3d_diffusion/modules.py: -------------------------------------------------------------------------------- 1 | __modules__ = {} 2 | 3 | def register(name): 4 | def decorator(cls): 5 | __modules__[name] = cls 6 | return cls 7 | 8 | return decorator 9 | 10 | 11 | def find(name): 12 | return __modules__[name] 13 | 14 | from custum_3d_diffusion.trainings import base, image2mvimage_trainer, image2image_trainer 15 | 
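A minimal usage sketch of the registry defined in `modules.py` above; the lookup name comes from the `@register("image2mvimage_trainer")` call in the trainer module further below, while the rest of the snippet is illustrative only:

```python
# Minimal sketch of using the register/find registry (illustrative only).
# Importing custum_3d_diffusion.modules also imports the trainer modules at the
# bottom of that file, so their @register("...") decorators have already
# populated __modules__ by the time find() is called.
from custum_3d_diffusion.modules import find

TrainerCls = find("image2mvimage_trainer")  # registered in image2mvimage_trainer.py
print(TrainerCls.__name__)  # -> Image2MVImageTrainer

# Constructing a trainer additionally requires an accelerate.Accelerator, a logger,
# a ConfigurableUNet2DConditionModel, a config dict/path, a weight dtype and an
# index (see BasicTrainer.__init__ below), so only the lookup is shown here.
```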
-------------------------------------------------------------------------------- /custum_3d_diffusion/trainings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/custum_3d_diffusion/trainings/__init__.py -------------------------------------------------------------------------------- /custum_3d_diffusion/trainings/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from accelerate import Accelerator 3 | from accelerate.logging import MultiProcessAdapter 4 | from dataclasses import dataclass, field 5 | from typing import Optional, Union 6 | from datasets import load_dataset 7 | import json 8 | import abc 9 | from diffusers.utils import make_image_grid 10 | import numpy as np 11 | import wandb 12 | 13 | from custum_3d_diffusion.trainings.utils import load_config 14 | from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig 15 | 16 | class BasicTrainer(torch.nn.Module, abc.ABC): 17 | accelerator: Accelerator 18 | logger: MultiProcessAdapter 19 | unet: ConfigurableUNet2DConditionModel 20 | train_dataloader: torch.utils.data.DataLoader 21 | test_dataset: torch.utils.data.Dataset 22 | attn_config: AttnConfig 23 | 24 | @dataclass 25 | class TrainerConfig: 26 | trainer_name: str = "basic" 27 | pretrained_model_name_or_path: str = "" 28 | 29 | attn_config: dict = field(default_factory=dict) 30 | dataset_name: str = "" 31 | dataset_config_name: Optional[str] = None 32 | resolution: str = "1024" 33 | dataloader_num_workers: int = 4 34 | pair_sampler_group_size: int = 1 35 | num_views: int = 4 36 | 37 | max_train_steps: int = -1 # -1 means infinity, otherwise [0, max_train_steps) 38 | training_step_interval: int = 1 # train on step i*interval, stop at max_train_steps 39 | max_train_samples: Optional[int] = None 40 | seed: Optional[int] = None # For dataset related operations and validation stuff 41 | train_batch_size: int = 1 42 | 43 | validation_interval: int = 5000 44 | debug: bool = False 45 | 46 | cfg: TrainerConfig # only enable_xxx is used 47 | 48 | def __init__( 49 | self, 50 | accelerator: Accelerator, 51 | logger: MultiProcessAdapter, 52 | unet: ConfigurableUNet2DConditionModel, 53 | config: Union[dict, str], 54 | weight_dtype: torch.dtype, 55 | index: int, 56 | ): 57 | super().__init__() 58 | self.index = index # index in all trainers 59 | self.accelerator = accelerator 60 | self.logger = logger 61 | self.unet = unet 62 | self.weight_dtype = weight_dtype 63 | self.ext_logs = {} 64 | self.cfg = load_config(self.TrainerConfig, config) 65 | self.attn_config = load_config(AttnConfig, self.cfg.attn_config) 66 | self.test_dataset = None 67 | self.validate_trainer_config() 68 | self.configure() 69 | 70 | def get_HW(self): 71 | resolution = json.loads(self.cfg.resolution) 72 | if isinstance(resolution, int): 73 | H = W = resolution 74 | elif isinstance(resolution, list): 75 | H, W = resolution 76 | return H, W 77 | 78 | def unet_update(self): 79 | self.unet.update_config(self.attn_config) 80 | 81 | def validate_trainer_config(self): 82 | pass 83 | 84 | def is_train_finished(self, current_step): 85 | assert isinstance(self.cfg.max_train_steps, int) 86 | return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps 87 | 88 | def next_train_step(self, current_step): 89 | if self.is_train_finished(current_step): 90 | return None 91 | 
return current_step + self.cfg.training_step_interval 92 | 93 | @classmethod 94 | def make_image_into_grid(cls, all_imgs, rows=2, columns=2): 95 | catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)] 96 | return make_image_grid(catted, rows=1, cols=len(catted)) 97 | 98 | def configure(self) -> None: 99 | pass 100 | 101 | @abc.abstractmethod 102 | def init_shared_modules(self, shared_modules: dict) -> dict: 103 | pass 104 | 105 | def load_dataset(self): 106 | dataset = load_dataset( 107 | self.cfg.dataset_name, 108 | self.cfg.dataset_config_name, 109 | trust_remote_code=True 110 | ) 111 | return dataset 112 | 113 | @abc.abstractmethod 114 | def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader: 115 | """Both init train_dataloader and test_dataset, but returns train_dataloader only""" 116 | pass 117 | 118 | @abc.abstractmethod 119 | def forward_step( 120 | self, 121 | *args, 122 | **kwargs 123 | ) -> torch.Tensor: 124 | """ 125 | input a batch 126 | return a loss 127 | """ 128 | self.unet_update() 129 | pass 130 | 131 | @abc.abstractmethod 132 | def construct_pipeline(self, shared_modules, unet): 133 | pass 134 | 135 | @abc.abstractmethod 136 | def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: 137 | """ 138 | For inference time forward. 139 | """ 140 | pass 141 | 142 | @abc.abstractmethod 143 | def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: 144 | pass 145 | 146 | def do_validation( 147 | self, 148 | shared_modules, 149 | unet, 150 | global_step, 151 | ): 152 | self.unet_update() 153 | self.logger.info("Running validation... ") 154 | pipeline = self.construct_pipeline(shared_modules, unet) 155 | pipeline.set_progress_bar_config(disable=True) 156 | titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.]) 157 | for tracker in self.accelerator.trackers: 158 | if tracker.name == "tensorboard": 159 | np_images = np.stack([np.asarray(img) for img in images]) 160 | tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC") 161 | elif tracker.name == "wandb": 162 | [image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title] # inplace operation 163 | tracker.log({"validation": [ 164 | wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg") 165 | for i, image in enumerate(images)]}) 166 | else: 167 | self.logger.warn(f"image logging not implemented for {tracker.name}") 168 | del pipeline 169 | torch.cuda.empty_cache() 170 | return images 171 | 172 | 173 | @torch.no_grad() 174 | def log_validation( 175 | self, 176 | shared_modules, 177 | unet, 178 | global_step, 179 | force=False 180 | ): 181 | if self.accelerator.is_main_process: 182 | for tracker in self.accelerator.trackers: 183 | if tracker.name == "wandb": 184 | tracker.log(self.ext_logs) 185 | self.ext_logs = {} 186 | if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force: 187 | self.unet_update() 188 | if self.accelerator.is_main_process: 189 | self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step) 190 | 191 | def save_model(self, unwrap_unet, shared_modules, save_dir): 192 | if self.accelerator.is_main_process: 193 | pipeline = self.construct_pipeline(shared_modules, unwrap_unet) 194 | pipeline.save_pretrained(save_dir) 195 | self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}") 196 | 197 | 
def save_debug_info(self, save_name="debug", **kwargs): 198 | if self.cfg.debug: 199 | to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()} 200 | import pickle 201 | import os 202 | if os.path.exists(f"{save_name}.pkl"): 203 | for i in range(100): 204 | if not os.path.exists(f"{save_name}_v{i}.pkl"): 205 | save_name = f"{save_name}_v{i}" 206 | break 207 | with open(f"{save_name}.pkl", "wb") as f: 208 | pickle.dump(to_saves, f) -------------------------------------------------------------------------------- /custum_3d_diffusion/trainings/config_classes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Optional 3 | 4 | 5 | @dataclass 6 | class TrainerSubConfig: 7 | trainer_type: str = "" 8 | trainer: dict = field(default_factory=dict) 9 | 10 | 11 | @dataclass 12 | class ExprimentConfig: 13 | trainers: List[dict] = field(default_factory=lambda: []) 14 | init_config: dict = field(default_factory=dict) 15 | pretrained_model_name_or_path: str = "" 16 | pretrained_unet_state_dict_path: str = "" 17 | # experiment-related parameters 18 | linear_beta_schedule: bool = False 19 | zero_snr: bool = False 20 | prediction_type: Optional[str] = None 21 | seed: Optional[int] = None 22 | max_train_steps: int = 1000000 23 | gradient_accumulation_steps: int = 1 24 | learning_rate: float = 1e-4 25 | lr_scheduler: str = "constant" 26 | lr_warmup_steps: int = 500 27 | use_8bit_adam: bool = False 28 | adam_beta1: float = 0.9 29 | adam_beta2: float = 0.999 30 | adam_weight_decay: float = 1e-2 31 | adam_epsilon: float = 1e-08 32 | max_grad_norm: float = 1.0 33 | mixed_precision: Optional[str] = None # ["no", "fp16", "bf16", "fp8"] 34 | skip_training: bool = False 35 | debug: bool = False -------------------------------------------------------------------------------- /custum_3d_diffusion/trainings/image2image_trainer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from diffusers import EulerAncestralDiscreteScheduler, DDPMScheduler 4 | from dataclasses import dataclass 5 | 6 | from custum_3d_diffusion.modules import register 7 | from custum_3d_diffusion.trainings.image2mvimage_trainer import Image2MVImageTrainer 8 | from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline 9 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 10 | 11 | def get_HW(resolution): 12 | if isinstance(resolution, str): 13 | resolution = json.loads(resolution) 14 | if isinstance(resolution, int): 15 | H = W = resolution 16 | elif isinstance(resolution, list): 17 | H, W = resolution 18 | return H, W 19 | 20 | 21 | @register("image2image_trainer") 22 | class Image2ImageTrainer(Image2MVImageTrainer): 23 | """ 24 | Trainer for simple image to image generation.
25 | """ 26 | @dataclass 27 | class TrainerConfig(Image2MVImageTrainer.TrainerConfig): 28 | trainer_name: str = "image2image" 29 | 30 | cfg: TrainerConfig 31 | 32 | def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor: 33 | raise NotImplementedError() 34 | 35 | def construct_pipeline(self, shared_modules, unet, old_version=False): 36 | MyPipeline = StableDiffusionImageCustomPipeline 37 | pipeline = MyPipeline.from_pretrained( 38 | self.cfg.pretrained_model_name_or_path, 39 | vae=shared_modules['vae'], 40 | image_encoder=shared_modules['image_encoder'], 41 | feature_extractor=shared_modules['feature_extractor'], 42 | unet=unet, 43 | safety_checker=None, 44 | torch_dtype=self.weight_dtype, 45 | latents_offset=self.cfg.latents_offset, 46 | noisy_cond_latents=self.cfg.noisy_condition_input, 47 | ) 48 | pipeline.set_progress_bar_config(disable=True) 49 | scheduler_dict = {} 50 | if self.cfg.zero_snr: 51 | scheduler_dict.update(rescale_betas_zero_snr=True) 52 | if self.cfg.linear_beta_schedule: 53 | scheduler_dict.update(beta_schedule='linear') 54 | 55 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict) 56 | return pipeline 57 | 58 | def get_forward_args(self): 59 | if self.cfg.seed is None: 60 | generator = None 61 | else: 62 | generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed) 63 | 64 | H, W = get_HW(self.cfg.resolution) 65 | H_cond, W_cond = get_HW(self.cfg.condition_image_resolution) 66 | 67 | forward_args = dict( 68 | num_images_per_prompt=1, 69 | num_inference_steps=20, 70 | height=H, 71 | width=W, 72 | height_cond=H_cond, 73 | width_cond=W_cond, 74 | generator=generator, 75 | ) 76 | if self.cfg.zero_snr: 77 | forward_args.update(guidance_rescale=0.7) 78 | return forward_args 79 | 80 | def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput: 81 | forward_args = self.get_forward_args() 82 | forward_args.update(pipeline_call_kwargs) 83 | return pipeline(**forward_args) 84 | 85 | def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: 86 | raise NotImplementedError() -------------------------------------------------------------------------------- /custum_3d_diffusion/trainings/image2mvimage_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler 3 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature 4 | 5 | import json 6 | from dataclasses import dataclass 7 | from typing import List, Optional 8 | 9 | from custum_3d_diffusion.modules import register 10 | from custum_3d_diffusion.trainings.base import BasicTrainer 11 | from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline 12 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 13 | 14 | def get_HW(resolution): 15 | if isinstance(resolution, str): 16 | resolution = json.loads(resolution) 17 | if isinstance(resolution, int): 18 | H = W = resolution 19 | elif isinstance(resolution, list): 20 | H, W = resolution 21 | return H, W 22 | 23 | @register("image2mvimage_trainer") 24 | class Image2MVImageTrainer(BasicTrainer): 25 | """ 26 | Trainer for simple image to multiview images. 
27 | """ 28 | @dataclass 29 | class TrainerConfig(BasicTrainer.TrainerConfig): 30 | trainer_name: str = "image2mvimage" 31 | condition_image_column_name: str = "conditioning_image" 32 | image_column_name: str = "image" 33 | condition_dropout: float = 0. 34 | condition_image_resolution: str = "512" 35 | validation_images: Optional[List[str]] = None 36 | noise_offset: float = 0.1 37 | max_loss_drop: float = 0. 38 | snr_gamma: float = 5.0 39 | log_distribution: bool = False 40 | latents_offset: Optional[List[float]] = None 41 | input_perturbation: float = 0. 42 | noisy_condition_input: bool = False # whether to add noise for ref unet input 43 | normal_cls_offset: int = 0 44 | condition_offset: bool = True 45 | zero_snr: bool = False 46 | linear_beta_schedule: bool = False 47 | 48 | cfg: TrainerConfig 49 | 50 | def configure(self) -> None: 51 | return super().configure() 52 | 53 | def init_shared_modules(self, shared_modules: dict) -> dict: 54 | if 'vae' not in shared_modules: 55 | vae = AutoencoderKL.from_pretrained( 56 | self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype 57 | ) 58 | vae.requires_grad_(False) 59 | vae.to(self.accelerator.device, dtype=self.weight_dtype) 60 | shared_modules['vae'] = vae 61 | if 'image_encoder' not in shared_modules: 62 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( 63 | self.cfg.pretrained_model_name_or_path, subfolder="image_encoder" 64 | ) 65 | image_encoder.requires_grad_(False) 66 | image_encoder.to(self.accelerator.device, dtype=self.weight_dtype) 67 | shared_modules['image_encoder'] = image_encoder 68 | if 'feature_extractor' not in shared_modules: 69 | feature_extractor = CLIPImageProcessor.from_pretrained( 70 | self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor" 71 | ) 72 | shared_modules['feature_extractor'] = feature_extractor 73 | return shared_modules 74 | 75 | def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader: 76 | raise NotImplementedError() 77 | 78 | def loss_rescale(self, loss, timesteps=None): 79 | raise NotImplementedError() 80 | 81 | def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor: 82 | raise NotImplementedError() 83 | 84 | def construct_pipeline(self, shared_modules, unet, old_version=False): 85 | MyPipeline = StableDiffusionImage2MVCustomPipeline 86 | pipeline = MyPipeline.from_pretrained( 87 | self.cfg.pretrained_model_name_or_path, 88 | vae=shared_modules['vae'], 89 | image_encoder=shared_modules['image_encoder'], 90 | feature_extractor=shared_modules['feature_extractor'], 91 | unet=unet, 92 | safety_checker=None, 93 | torch_dtype=self.weight_dtype, 94 | latents_offset=self.cfg.latents_offset, 95 | noisy_cond_latents=self.cfg.noisy_condition_input, 96 | condition_offset=self.cfg.condition_offset, 97 | ) 98 | pipeline.set_progress_bar_config(disable=True) 99 | scheduler_dict = {} 100 | if self.cfg.zero_snr: 101 | scheduler_dict.update(rescale_betas_zero_snr=True) 102 | if self.cfg.linear_beta_schedule: 103 | scheduler_dict.update(beta_schedule='linear') 104 | 105 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict) 106 | return pipeline 107 | 108 | def get_forward_args(self): 109 | if self.cfg.seed is None: 110 | generator = None 111 | else: 112 | generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed) 113 | 114 | H, W = get_HW(self.cfg.resolution) 115 | H_cond, W_cond = 
get_HW(self.cfg.condition_image_resolution) 116 | 117 | sub_img_H = H // 2 118 | num_imgs = H // sub_img_H * W // sub_img_H 119 | 120 | forward_args = dict( 121 | num_images_per_prompt=num_imgs, 122 | num_inference_steps=50, 123 | height=sub_img_H, 124 | width=sub_img_H, 125 | height_cond=H_cond, 126 | width_cond=W_cond, 127 | generator=generator, 128 | ) 129 | if self.cfg.zero_snr: 130 | forward_args.update(guidance_rescale=0.7) 131 | return forward_args 132 | 133 | def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput: 134 | forward_args = self.get_forward_args() 135 | forward_args.update(pipeline_call_kwargs) 136 | return pipeline(**forward_args) 137 | 138 | def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: 139 | raise NotImplementedError() -------------------------------------------------------------------------------- /custum_3d_diffusion/trainings/utils.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig, OmegaConf 2 | 3 | 4 | def parse_structured(fields, cfg) -> DictConfig: 5 | scfg = OmegaConf.structured(fields(**cfg)) 6 | return scfg 7 | 8 | 9 | def load_config(fields, config, extras=None): 10 | if extras is not None: 11 | print("Warning! Extra parameters from the CLI are not verified and may cause errors.") 12 | if isinstance(config, str): 13 | cfg = OmegaConf.load(config) 14 | elif isinstance(config, dict): 15 | cfg = OmegaConf.create(config) 16 | elif isinstance(config, DictConfig): 17 | cfg = config 18 | else: 19 | raise NotImplementedError(f"Unsupported config type {type(config)}") 20 | if extras is not None: 21 | cli_conf = OmegaConf.from_cli(extras) 22 | cfg = OmegaConf.merge(cfg, cli_conf) 23 | OmegaConf.resolve(cfg) 24 | assert isinstance(cfg, DictConfig) 25 | return parse_structured(fields, cfg) -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # get the runtime image from nvidia cuda 12.1 2 | FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 3 | 4 | LABEL name="unique3d" maintainer="unique3d" 5 | 6 | # create workspace folder and set it as working directory 7 | RUN mkdir -p /workspace 8 | WORKDIR /workspace 9 | 10 | # update package lists and install build-essential, git, wget, vim, libegl1-mesa-dev, libglib2.0-0, unzip and git-lfs 11 | RUN apt-get update && apt-get install -y build-essential git wget vim libegl1-mesa-dev libglib2.0-0 unzip git-lfs 12 | 13 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev cmake curl mesa-utils-extra 14 | ENV PYTHONDONTWRITEBYTECODE=1 15 | ENV PYTHONUNBUFFERED=1 16 | ENV LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH 17 | ENV PYOPENGL_PLATFORM=egl 18 | 19 | # install conda 20 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 21 | chmod +x Miniconda3-latest-Linux-x86_64.sh && \ 22 | ./Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 && \ 23 | rm Miniconda3-latest-Linux-x86_64.sh 24 | 25 | # update PATH environment variable 26 | ENV PATH="/workspace/miniconda3/bin:${PATH}" 27 | 28 | # initialize conda 29 | RUN conda init bash 30 | 31 | # create and activate conda environment 32 | RUN conda create -n unique3d python=3.10 && echo "source activate unique3d" > ~/.bashrc 33 | ENV PATH
/workspace/miniconda3/envs/unique3d/bin:$PATH 34 | 35 | RUN conda install Ninja 36 | RUN conda install cuda -c nvidia/label/cuda-12.1.0 -y 37 | 38 | RUN pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 xformers triton --index-url https://download.pytorch.org/whl/cu121 39 | RUN pip install diffusers==0.27.2 40 | 41 | RUN git clone --depth 1 https://huggingface.co/spaces/Wuvin/Unique3D 42 | 43 | # change the working directory to the repository 44 | 45 | WORKDIR /workspace/Unique3D 46 | # other dependencies 47 | RUN pip install -r requirements.txt 48 | 49 | RUN pip install nvidia-pyindex 50 | 51 | RUN pip install --upgrade nvidia-tensorrt 52 | 53 | RUN pip install spaces 54 | 55 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker setup 2 | 3 | This docker setup is tested on Windows 10. 4 | 5 | Make sure you are in this directory: `yourworkspace/Unique3D/docker` 6 | 7 | Build the docker image: 8 | 9 | ``` 10 | docker build -t unique3d -f Dockerfile . 11 | ``` 12 | 13 | Run the docker image for the first time: 14 | 15 | ``` 16 | docker run -it --name unique3d -p 7860:7860 --gpus all unique3d python app.py 17 | ``` 18 | 19 | After the first time: 20 | ``` 21 | docker start unique3d 22 | docker exec unique3d python app.py 23 | ``` 24 | 25 | Stop the container: 26 | ``` 27 | docker stop unique3d 28 | ``` 29 | 30 | You can find the demo link in the terminal output, such as `https://94fc1ba77a08526e17.gradio.live/` or something similar (the link changes each time the container is restarted); open it to use the demo. 31 | 32 | Some notes: 33 | 1. This docker build clones the source from https://huggingface.co/spaces/Wuvin/Unique3D rather than from this repo. 34 | 2. The total build time may take more than one hour. 35 | 3. The total size of the built image will be more than 70GB. -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import os 3 | import sys 4 | sys.path.append(os.curdir) 5 | import torch 6 | torch.set_float32_matmul_precision('medium') 7 | torch.backends.cuda.matmul.allow_tf32 = True 8 | torch.set_grad_enabled(False) 9 | 10 | import fire 11 | import gradio as gr 12 | from app.gradio_3dgen import create_ui as create_3d_ui 13 | from app.all_models import model_zoo 14 | 15 | 16 | _TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image''' 17 | _DESCRIPTION = ''' 18 | [Project page](https://wukailu.github.io/Unique3D/) 19 | 20 | * High-fidelity and diverse textured meshes generated by Unique3D from single-view images. 21 | 22 | * The demo is still under construction, and more features are expected to be implemented soon. 
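As an editorial aside, the sketch below (not a file from this repository) shows how the `load_config` / `parse_structured` helpers in `custum_3d_diffusion/trainings/utils.py` above are typically used to turn a YAML file plus dotlist-style CLI overrides into a typed config. The `ExampleConfig` dataclass, its fields, and the `config.yaml` path are hypothetical, chosen only for illustration.

```python
from dataclasses import dataclass
from typing import Optional

from custum_3d_diffusion.trainings.utils import load_config


@dataclass
class ExampleConfig:
    # hypothetical fields, mirroring the TrainerConfig style shown earlier
    pretrained_model_name_or_path: str = "some/diffusion-model"
    resolution: str = "1024x512"
    seed: Optional[int] = None


# "config.yaml" and the override list are assumptions for this sketch
cfg = load_config(ExampleConfig, "config.yaml", extras=["seed=42"])
print(cfg.resolution, cfg.seed)  # cfg is a structured OmegaConf DictConfig
```

The overrides are merged via `OmegaConf.from_cli` and are only validated when the structured config is built, which is why `load_config` prints a warning whenever extras are passed.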
23 | ''' 24 | 25 | def launch(): 26 | model_zoo.init_models() 27 | 28 | with gr.Blocks( 29 | title=_TITLE, 30 | theme=gr.themes.Monochrome(), 31 | ) as demo: 32 | with gr.Row(): 33 | with gr.Column(scale=1): 34 | gr.Markdown('# ' + _TITLE) 35 | gr.Markdown(_DESCRIPTION) 36 | create_3d_ui("wkl") 37 | 38 | demo.queue().launch(share=True) 39 | 40 | if __name__ == '__main__': 41 | fire.Fire(launch) 42 | -------------------------------------------------------------------------------- /install_windows_win_py311_cu121.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set "triton_whl=%~dp0\triton-2.1.0-cp311-cp311-win_amd64.whl" 4 | 5 | echo Starting to install Unique3D... 6 | 7 | echo Installing torch, xformers, etc 8 | 9 | pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121 10 | 11 | echo Installing triton 12 | 13 | pip install "%triton_whl%" 14 | 15 | pip install Ninja 16 | 17 | pip install diffusers==0.27.2 18 | 19 | pip install grpcio werkzeug tensorboard-data-server 20 | 21 | pip install -r requirements-win-py311-cu121.txt 22 | 23 | echo Removing default onnxruntime and onnxruntime-gpu 24 | 25 | pip uninstall onnxruntime 26 | pip uninstall onnxruntime-gpu 27 | 28 | echo Installing correct version onnxruntime-gpu for cuda 12.1 29 | 30 | pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ 31 | 32 | echo Install Finished. Press any key to continue... 33 | 34 | pause -------------------------------------------------------------------------------- /mesh_reconstruction/func.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/Profactor/continuous-remeshing 2 | import torch 3 | import numpy as np 4 | import trimesh 5 | from typing import Tuple 6 | 7 | def to_numpy(*args): 8 | def convert(a): 9 | if isinstance(a,torch.Tensor): 10 | return a.detach().cpu().numpy() 11 | assert a is None or isinstance(a,np.ndarray) 12 | return a 13 | 14 | return convert(args[0]) if len(args)==1 else tuple(convert(a) for a in args) 15 | 16 | def laplacian( 17 | num_verts:int, 18 | edges: torch.Tensor #E,2 19 | ) -> torch.Tensor: #sparse V,V 20 | """create sparse Laplacian matrix""" 21 | V = num_verts 22 | E = edges.shape[0] 23 | 24 | #adjacency matrix, 25 | idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T # (2, 2*E) 26 | ones = torch.ones(2*E, dtype=torch.float32, device=edges.device) 27 | A = torch.sparse.FloatTensor(idx, ones, (V, V)) 28 | 29 | #degree matrix 30 | deg = torch.sparse.sum(A, dim=1).to_dense() 31 | idx = torch.arange(V, device=edges.device) 32 | idx = torch.stack([idx, idx], dim=0) 33 | D = torch.sparse.FloatTensor(idx, deg, (V, V)) 34 | 35 | return D - A 36 | 37 | def _translation(x, y, z, device): 38 | return torch.tensor([[1., 0, 0, x], 39 | [0, 1, 0, y], 40 | [0, 0, 1, z], 41 | [0, 0, 0, 1]],device=device) #4,4 42 | 43 | def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True): 44 | """ 45 | see https://blog.csdn.net/wodownload2/article/details/85069240/ 46 | """ 47 | if l is None: 48 | l = -r 49 | if t is None: 50 | t = r 51 | if b is None: 52 | b = -t 53 | p = torch.zeros([4,4],device=device) 54 | p[0,0] = 2*n/(r-l) 55 | p[0,2] = (r+l)/(r-l) 56 | p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1) 57 | p[1,2] = (t+b)/(t-b) 58 | p[2,2] = -(f+n)/(f-n) 59 | p[2,3] = -(2*f*n)/(f-n) 60 | p[3,2] = -1 61 | 
return p #4,4 62 | 63 | def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True): 64 | if l is None: 65 | l = -r 66 | if t is None: 67 | t = r 68 | if b is None: 69 | b = -t 70 | o = torch.zeros([4,4],device=device) 71 | o[0,0] = 2/(r-l) 72 | o[0,3] = -(r+l)/(r-l) 73 | o[1,1] = 2/(t-b) * (-1 if flip_y else 1) 74 | o[1,3] = -(t+b)/(t-b) 75 | o[2,2] = -2/(f-n) 76 | o[2,3] = -(f+n)/(f-n) 77 | o[3,3] = 1 78 | return o #4,4 79 | 80 | def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'): 81 | if r is None: 82 | r = 1/distance 83 | A = az_count 84 | P = pol_count 85 | C = A * P 86 | 87 | phi = torch.arange(0,A) * (2*torch.pi/A) 88 | phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone() 89 | phi_rot[:,0,2,2] = phi.cos() 90 | phi_rot[:,0,2,0] = -phi.sin() 91 | phi_rot[:,0,0,2] = phi.sin() 92 | phi_rot[:,0,0,0] = phi.cos() 93 | 94 | theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2 95 | theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone() 96 | theta_rot[0,:,1,1] = theta.cos() 97 | theta_rot[0,:,1,2] = -theta.sin() 98 | theta_rot[0,:,2,1] = theta.sin() 99 | theta_rot[0,:,2,2] = theta.cos() 100 | 101 | mv = torch.empty((C,4,4), device=device) 102 | mv[:] = torch.eye(4, device=device) 103 | mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3) 104 | mv = _translation(0, 0, -distance, device) @ mv 105 | 106 | return mv, _projection(r,device) 107 | 108 | def make_star_cameras_orthographic(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'): 109 | mv, _ = make_star_cameras(az_count,pol_count,distance,r,image_size,device) 110 | if r is None: 111 | r = 1 112 | return mv, _orthographic(r,device) 113 | 114 | def make_sphere(level:int=2,radius=1.,device='cuda') -> Tuple[torch.Tensor,torch.Tensor]: 115 | sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None) 116 | vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius 117 | faces = torch.tensor(sphere.faces, device=device, dtype=torch.long) 118 | return vertices,faces 119 | 120 | from pytorch3d.renderer import ( 121 | FoVOrthographicCameras, 122 | look_at_view_transform, 123 | ) 124 | 125 | def get_camera(R, T, focal_length=1 / (2**0.5)): 126 | focal_length = 1 / focal_length 127 | camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length) 128 | return camera 129 | 130 | def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1): 131 | R, T = look_at_view_transform(dist, 0, azim_list) 132 | focal_length = 1 / focal 133 | return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device) 134 | -------------------------------------------------------------------------------- /mesh_reconstruction/opt.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/Profactor/continuous-remeshing 2 | import time 3 | import torch 4 | import torch_scatter 5 | from typing import Tuple 6 | from mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges 7 | 8 | @torch.no_grad() 9 | def remesh( 10 | vertices_etc:torch.Tensor, #V,D 11 | faces:torch.Tensor, #F,3 long 12 | 
min_edgelen:torch.Tensor, #V 13 | max_edgelen:torch.Tensor, #V 14 | flip:bool, 15 | max_vertices=1e6 16 | ): 17 | 18 | # dummies 19 | vertices_etc,faces = prepend_dummies(vertices_etc,faces) 20 | vertices = vertices_etc[:,:3] #V,3 21 | nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device) 22 | min_edgelen = torch.concat((nan_tensor,min_edgelen)) 23 | max_edgelen = torch.concat((nan_tensor,max_edgelen)) 24 | 25 | # collapse 26 | edges,face_to_edge = calc_edges(faces) #E,2 F,3 27 | edge_length = calc_edge_length(vertices,edges) #E 28 | face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3 29 | vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3 30 | face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5) 31 | shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0 32 | priority = face_collapse.float() + shortness 33 | vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority) 34 | 35 | # split 36 | if vertices.shape[0]<max_vertices: 37 | edges,face_to_edge = calc_edges(faces) #E,2 F,3 38 | vertices = vertices_etc[:,:3] #V,3 39 | edge_length = calc_edge_length(vertices,edges) #E 40 | splits = edge_length > max_edgelen[edges].mean(dim=-1) 41 | vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False) 42 | 43 | vertices_etc,faces = pack(vertices_etc,faces) 44 | vertices = vertices_etc[:,:3] 45 | 46 | if flip: 47 | edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3 48 | flip_edges(vertices,faces,edges,edge_to_face,with_border=False) 49 | 50 | return remove_dummies(vertices_etc,faces) 51 | 52 | def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int): 53 | """lerp with adam's bias correction""" 54 | c_prev = 1-weight**(step-1) 55 | c = 1-weight**step 56 | a_weight = weight*c_prev/c 57 | b_weight = (1-weight)/c 58 | a.mul_(a_weight).add_(b, alpha=b_weight) 59 | 60 | 61 | class MeshOptimizer: 62 | """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh().""" 63 | 64 | def __init__(self, 65 | vertices:torch.Tensor, #V,3 66 | faces:torch.Tensor, #F,3 67 | lr=0.3, #learning rate 68 | betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu 69 | gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. 
smoothing) 70 | nu_ref=0.3, #reference velocity for edge length controller 71 | edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length 72 | edge_len_tol=.5, #edge length tolerance for split and collapse 73 | gain=.2, #gain value for edge length controller 74 | laplacian_weight=.02, #for laplacian smoothing/regularization 75 | ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0]) 76 | grad_lim=10., #gradients are clipped to m1.abs()*grad_lim 77 | remesh_interval=1, #larger intervals are faster but with worse mesh quality 78 | local_edgelen=True, #set to False to use a global scalar reference edge length instead 79 | ): 80 | self._vertices = vertices 81 | self._faces = faces 82 | self._lr = lr 83 | self._betas = betas 84 | self._gammas = gammas 85 | self._nu_ref = nu_ref 86 | self._edge_len_lims = edge_len_lims 87 | self._edge_len_tol = edge_len_tol 88 | self._gain = gain 89 | self._laplacian_weight = laplacian_weight 90 | self._ramp = ramp 91 | self._grad_lim = grad_lim 92 | self._remesh_interval = remesh_interval 93 | self._local_edgelen = local_edgelen 94 | self._step = 0 95 | 96 | V = self._vertices.shape[0] 97 | # prepare continuous tensor for all vertex-based data 98 | self._vertices_etc = torch.zeros([V,9],device=vertices.device) 99 | self._split_vertices_etc() 100 | self.vertices.copy_(vertices) #initialize vertices 101 | self._vertices.requires_grad_() 102 | self._ref_len.fill_(edge_len_lims[1]) 103 | 104 | @property 105 | def vertices(self): 106 | return self._vertices 107 | 108 | @property 109 | def faces(self): 110 | return self._faces 111 | 112 | def _split_vertices_etc(self): 113 | self._vertices = self._vertices_etc[:,:3] 114 | self._m2 = self._vertices_etc[:,3] 115 | self._nu = self._vertices_etc[:,4] 116 | self._m1 = self._vertices_etc[:,5:8] 117 | self._ref_len = self._vertices_etc[:,8] 118 | 119 | with_gammas = any(g!=0 for g in self._gammas) 120 | self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3] 121 | 122 | def zero_grad(self): 123 | self._vertices.grad = None 124 | 125 | @torch.no_grad() 126 | def step(self): 127 | 128 | eps = 1e-8 129 | 130 | self._step += 1 131 | 132 | # spatial smoothing 133 | edges,_ = calc_edges(self._faces) #E,2 134 | E = edges.shape[0] 135 | edge_smooth = self._smooth[edges] #E,2,S 136 | neighbor_smooth = torch.zeros_like(self._smooth) #V,S 137 | torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth) 138 | 139 | #apply optional smoothing of m1,m2,nu 140 | if self._gammas[0]: 141 | self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0]) 142 | if self._gammas[1]: 143 | self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1]) 144 | if self._gammas[2]: 145 | self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2]) 146 | 147 | #add laplace smoothing to gradients 148 | laplace = self._vertices - neighbor_smooth[:,:3] 149 | grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight) 150 | 151 | #gradient clipping 152 | if self._step>1: 153 | grad_lim = self._m1.abs().mul_(self._grad_lim) 154 | grad.clamp_(min=-grad_lim,max=grad_lim) 155 | 156 | # moment updates 157 | lerp_unbiased(self._m1, grad, self._betas[0], self._step) 158 | lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step) 159 | 160 | velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3 161 | speed = velocity.norm(dim=-1) #V 162 | 163 | if self._betas[2]: 164 | 
lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V 165 | else: 166 | self._nu.copy_(speed) #V 167 | 168 | # update vertices 169 | ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp) 170 | self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr) 171 | 172 | # update target edge length 173 | if self._step % self._remesh_interval == 0: 174 | if self._local_edgelen: 175 | len_change = (1 + (self._nu - self._nu_ref) * self._gain) 176 | else: 177 | len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain) 178 | self._ref_len *= len_change 179 | self._ref_len.clamp_(*self._edge_len_lims) 180 | 181 | def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]: 182 | min_edge_len = self._ref_len * (1 - self._edge_len_tol) 183 | max_edge_len = self._ref_len * (1 + self._edge_len_tol) 184 | 185 | self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e6) 186 | 187 | self._split_vertices_etc() 188 | self._vertices.requires_grad_() 189 | 190 | return self._vertices, self._faces 191 | -------------------------------------------------------------------------------- /mesh_reconstruction/recon.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | from typing import List 6 | from mesh_reconstruction.remesh import calc_vertex_normals 7 | from mesh_reconstruction.opt import MeshOptimizer 8 | from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d 9 | from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer 10 | from scripts.utils import to_py3d_mesh, init_target 11 | 12 | def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1): 13 | vertices, faces = vertices.to("cuda"), faces.to("cuda") 14 | assert len(pils) == 4 15 | mv,proj = make_star_cameras_orthographic(4, 1) 16 | renderer = NormalsRenderer(mv,proj,list(pils[0].size)) 17 | # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0) 18 | # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda") 19 | 20 | target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s 21 | # 1. no rotate 22 | target_images = target_images[[0, 3, 2, 1]] 23 | 24 | # 2. 
init from coarse mesh 25 | opt = MeshOptimizer(vertices,faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len)) 26 | 27 | vertices = opt.vertices 28 | 29 | mask = target_images[..., -1] < 0.5 30 | 31 | for i in tqdm(range(steps)): 32 | opt.zero_grad() 33 | opt._lr *= decay 34 | normals = calc_vertex_normals(vertices,faces) 35 | images = renderer.render(vertices,normals,faces) 36 | 37 | loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean() 38 | 39 | t_mask = images[..., -1] > 0.5 40 | loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean() 41 | loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean() 42 | 43 | loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight 44 | 45 | # out of box 46 | loss_oob = (vertices.abs() > 0.99).float().mean() * 10 47 | loss = loss + loss_oob 48 | 49 | loss.backward() 50 | opt.step() 51 | 52 | vertices,faces = opt.remesh(poisson=False) 53 | 54 | vertices, faces = vertices.detach(), faces.detach() 55 | 56 | if return_mesh: 57 | return to_py3d_mesh(vertices, faces) 58 | else: 59 | return vertices, faces 60 | -------------------------------------------------------------------------------- /mesh_reconstruction/refine.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from PIL import Image 3 | import torch 4 | from typing import List 5 | from mesh_reconstruction.remesh import calc_vertex_normals 6 | from mesh_reconstruction.opt import MeshOptimizer 7 | from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d 8 | from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer 9 | from scripts.project_mesh import multiview_color_projection, get_cameras_list 10 | from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target 11 | 12 | def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True): 13 | if process_inputs: 14 | vertices = vertices * 2 / 1.35 15 | vertices[..., [0, 2]] = - vertices[..., [0, 2]] 16 | 17 | poission_steps = [] 18 | 19 | assert len(pils) == 4 20 | mv,proj = make_star_cameras_orthographic(4, 1) 21 | renderer = NormalsRenderer(mv,proj,list(pils[0].size)) 22 | # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0) 23 | # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda") 24 | 25 | target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s 26 | # 1. no rotate 27 | target_images = target_images[[0, 3, 2, 1]] 28 | 29 | # 2. 
init from coarse mesh 30 | opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02) 31 | 32 | vertices = opt.vertices 33 | alpha_init = None 34 | 35 | mask = target_images[..., -1] < 0.5 36 | 37 | for i in tqdm(range(steps)): 38 | opt.zero_grad() 39 | opt._lr *= decay 40 | normals = calc_vertex_normals(vertices,faces) 41 | images = renderer.render(vertices,normals,faces) 42 | if alpha_init is None: 43 | alpha_init = images.detach() 44 | 45 | if i < update_warmup or i % update_normal_interval == 0: 46 | with torch.no_grad(): 47 | py3d_mesh = to_py3d_mesh(vertices, faces, normals) 48 | cameras = get_cameras_list(azim_list = [0, 90, 180, 270], device=vertices.device, focal=1.) 49 | _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=[2.0, 0.8, 1.0, 0.8], confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear')) 50 | target_normal = target_normal * 2 - 1 51 | target_normal = torch.nn.functional.normalize(target_normal, dim=-1) 52 | debug_images = renderer.render(vertices,target_normal,faces) 53 | 54 | d_mask = images[..., -1] > 0.5 55 | loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean() 56 | 57 | loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean() 58 | 59 | loss = loss_debug_l2 + loss_alpha_target_mask_l2 60 | 61 | # out of box 62 | loss_oob = (vertices.abs() > 0.99).float().mean() * 10 63 | loss = loss + loss_oob 64 | 65 | loss.backward() 66 | opt.step() 67 | 68 | vertices,faces = opt.remesh(poisson=(i in poission_steps)) 69 | 70 | vertices, faces = vertices.detach(), faces.detach() 71 | 72 | if process_outputs: 73 | vertices = vertices / 2 * 1.35 74 | vertices[..., [0, 2]] = - vertices[..., [0, 2]] 75 | 76 | if return_mesh: 77 | return to_py3d_mesh(vertices, faces) 78 | else: 79 | return vertices, faces 80 | -------------------------------------------------------------------------------- /mesh_reconstruction/remesh.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/Profactor/continuous-remeshing 2 | import torch 3 | import torch.nn.functional as tfunc 4 | import torch_scatter 5 | from typing import Tuple 6 | 7 | def prepend_dummies( 8 | vertices:torch.Tensor, #V,D 9 | faces:torch.Tensor, #F,3 long 10 | )->Tuple[torch.Tensor,torch.Tensor]: 11 | """prepend dummy elements to vertices and faces to enable "masked" scatter operations""" 12 | V,D = vertices.shape 13 | vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0) 14 | faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0) 15 | return vertices,faces 16 | 17 | def remove_dummies( 18 | vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced 19 | faces:torch.Tensor, #F,3 long - first face all zeros 20 | )->Tuple[torch.Tensor,torch.Tensor]: 21 | """remove dummy elements added with prepend_dummies()""" 22 | return vertices[1:],faces[1:]-1 23 | 24 | 25 | def calc_edges( 26 | faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros 27 | with_edge_to_face: bool = False 28 | ) -> Tuple[torch.Tensor, ...]: 29 | """ 30 | returns Tuple of 31 | - edges E,2 long, 0 for unused, lower vertex index first 32 | - face_to_edge F,3 long 33 | - (optional) edge_to_face shape=E,[left,right],[face,side] 
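# --- illustrative aside (not part of the original file): a tiny worked example of
# what calc_edges returns for two triangles [[0, 1, 2], [0, 2, 3]] sharing the edge
# (0, 2); the dummy-element convention used elsewhere in this file is omitted here:
#   edges        -> [[0, 1], [0, 2], [0, 3], [1, 2], [2, 3]]
#   face_to_edge -> [[0, 3, 1], [1, 4, 2]]   (row i lists the edge ids of face i,
#                    in the per-face edge order (v0,v1), (v1,v2), (v2,v0))
# --- end aside ---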
34 | 35 | o-<-----e1 e0,e1...edge, e0-o 41 | """ 42 | 43 | F = faces.shape[0] 44 | 45 | # make full edges, lower vertex index first 46 | face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2 47 | full_edges = face_edges.reshape(F*3,2) 48 | sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2 49 | 50 | # make unique edges 51 | edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3) 52 | E = edges.shape[0] 53 | face_to_edge = full_to_unique.reshape(F,3) #F,3 54 | 55 | if not with_edge_to_face: 56 | return edges, face_to_edge 57 | 58 | is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3 59 | edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2 60 | scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2 61 | edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2 62 | edge_to_face[0] = 0 63 | return edges, face_to_edge, edge_to_face 64 | 65 | def calc_edge_length( 66 | vertices:torch.Tensor, #V,3 first may be dummy 67 | edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused 68 | )->torch.Tensor: #E 69 | 70 | full_vertices = vertices[edges] #E,2,3 71 | a,b = full_vertices.unbind(dim=1) #E,3 72 | return torch.norm(a-b,p=2,dim=-1) 73 | 74 | def calc_face_normals( 75 | vertices:torch.Tensor, #V,3 first vertex may be unreferenced 76 | faces:torch.Tensor, #F,3 long, first face may be all zero 77 | normalize:bool=False, 78 | )->torch.Tensor: #F,3 79 | """ 80 | n 81 | | 82 | c0 corners ordered counterclockwise when 83 | / \ looking onto surface (in neg normal direction) 84 | c1---c2 85 | """ 86 | full_vertices = vertices[faces] #F,C=3,3 87 | v0,v1,v2 = full_vertices.unbind(dim=1) #F,3 88 | face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3 89 | if normalize: 90 | face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) 91 | return face_normals #F,3 92 | 93 | def calc_vertex_normals( 94 | vertices:torch.Tensor, #V,3 first vertex may be unreferenced 95 | faces:torch.Tensor, #F,3 long, first face may be all zero 96 | face_normals:torch.Tensor=None, #F,3, not normalized 97 | )->torch.Tensor: #F,3 98 | 99 | F = faces.shape[0] 100 | 101 | if face_normals is None: 102 | face_normals = calc_face_normals(vertices,faces) 103 | 104 | vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3 105 | vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3)) 106 | vertex_normals = vertex_normals.sum(dim=1) #V,3 107 | return tfunc.normalize(vertex_normals, eps=1e-6, dim=1) 108 | 109 | def calc_face_ref_normals( 110 | faces:torch.Tensor, #F,3 long, 0 for unused 111 | vertex_normals:torch.Tensor, #V,3 first unused 112 | normalize:bool=False, 113 | )->torch.Tensor: #F,3 114 | """calculate reference normals for face flip detection""" 115 | full_normals = vertex_normals[faces] #F,C=3,3 116 | ref_normals = full_normals.sum(dim=1) #F,3 117 | if normalize: 118 | ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1) 119 | return ref_normals 120 | 121 | def pack( 122 | vertices:torch.Tensor, #V,3 first unused and nan 123 | faces:torch.Tensor, #F,3 long, 0 for unused 124 | )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused 125 | """removes unused elements in vertices and faces""" 126 | V = vertices.shape[0] 127 | 128 | # remove unused faces 129 
| used_faces = faces[:,0]!=0 130 | used_faces[0] = True 131 | faces = faces[used_faces] #sync 132 | 133 | # remove unused vertices 134 | used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device) 135 | used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add') 136 | used_vertices = used_vertices.any(dim=1) 137 | used_vertices[0] = True 138 | vertices = vertices[used_vertices] #sync 139 | 140 | # update used faces 141 | ind = torch.zeros(V,dtype=torch.long,device=vertices.device) 142 | V1 = used_vertices.sum() 143 | ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync 144 | faces = ind[faces] 145 | 146 | return vertices,faces 147 | 148 | def split_edges( 149 | vertices:torch.Tensor, #V,3 first unused 150 | faces:torch.Tensor, #F,3 long, 0 for unused 151 | edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first 152 | face_to_edge:torch.Tensor, #F,3 long 0 for unused 153 | splits, #E bool 154 | pack_faces:bool=True, 155 | )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) 156 | 157 | # c2 c2 c...corners = faces 158 | # . . . . s...side_vert, 0 means no split 159 | # . . .N2 . S...shrunk_face 160 | # . . . . Ni...new_faces 161 | # s2 s1 s2|c2...s1|c1 162 | # . . . . . 163 | # . . . S . . 164 | # . . . . N1 . 165 | # c0...(s0=0)....c1 s0|c0...........c1 166 | # 167 | # pseudo-code: 168 | # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2] 169 | # split = side_vert!=0 example:[False,True,True] 170 | # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0] 171 | # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0] 172 | # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1] 173 | 174 | V = vertices.shape[0] 175 | F = faces.shape[0] 176 | S = splits.sum().item() #sync 177 | 178 | if S==0: 179 | return vertices,faces 180 | 181 | edge_vert = torch.zeros_like(splits, dtype=torch.long) #E 182 | edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync 183 | side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split 184 | split_edges = edges[splits] #S sync 185 | 186 | #vertices 187 | split_vertices = vertices[split_edges].mean(dim=1) #S,3 188 | vertices = torch.concat((vertices,split_vertices),dim=0) 189 | 190 | #faces 191 | side_split = side_vert!=0 #F,3 192 | shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split 193 | new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3 194 | faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3 195 | if pack_faces: 196 | mask = faces[:,0]!=0 197 | mask[0] = True 198 | faces = faces[mask] #F',3 sync 199 | 200 | return vertices,faces 201 | 202 | def collapse_edges( 203 | vertices:torch.Tensor, #V,3 first unused 204 | faces:torch.Tensor, #F,3 long 0 for unused 205 | edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first 206 | priorities:torch.Tensor, #E float 207 | stable:bool=False, #only for unit testing 208 | )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) 209 | 210 | V = vertices.shape[0] 211 | 212 | # check spacing 213 | _,order = priorities.sort(stable=stable) #E 214 | rank = torch.zeros_like(order) 215 | rank[order] = torch.arange(0,len(rank),device=rank.device) 216 | vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V 217 | edge_rank = rank #E 218 | for i in range(3): 219 | torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank) 220 | edge_rank,_ = vert_rank[edges].max(dim=-1) #E 221 | 
candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2 222 | 223 | # check connectivity 224 | vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V 225 | vert_connections[candidates[:,0]] = 1 #start 226 | edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start 227 | vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start 228 | vert_connections[candidates] = 0 #clear start and end 229 | edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start 230 | vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start 231 | collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end 232 | 233 | # mean vertices 234 | vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) 235 | 236 | # update faces 237 | dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V 238 | dest[collapses[:,1]] = dest[collapses[:,0]] 239 | faces = dest[faces] #F,3 240 | c0,c1,c2 = faces.unbind(dim=-1) 241 | collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2) 242 | faces[collapsed] = 0 243 | 244 | return vertices,faces 245 | 246 | def calc_face_collapses( 247 | vertices:torch.Tensor, #V,3 first unused 248 | faces:torch.Tensor, #F,3 long, 0 for unused 249 | edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first 250 | face_to_edge:torch.Tensor, #F,3 long 0 for unused 251 | edge_length:torch.Tensor, #E 252 | face_normals:torch.Tensor, #F,3 253 | vertex_normals:torch.Tensor, #V,3 first unused 254 | min_edge_length:torch.Tensor=None, #V 255 | area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio 256 | shortest_probability = 0.8 257 | )->torch.Tensor: #E edges to collapse 258 | 259 | E = edges.shape[0] 260 | F = faces.shape[0] 261 | 262 | # face flips 263 | ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3 264 | face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F 265 | 266 | # small faces 267 | if min_edge_length is not None: 268 | min_face_length = min_edge_length[faces].mean(dim=-1) #F 269 | min_area = min_face_length**2 * area_ratio #F 270 | face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F 271 | face_collapses[0] = False 272 | 273 | # faces to edges 274 | face_length = edge_length[face_to_edge] #F,3 275 | 276 | if shortest_probability<1: 277 | #select shortest edge with shortest_probability chance 278 | randlim = round(2/(1-shortest_probability)) 279 | rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face 280 | sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3 281 | local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None]) 282 | else: 283 | local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face 284 | 285 | edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index 286 | edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device) 287 | edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long()) 288 | 289 | return edge_collapses.bool() 290 | 291 | def flip_edges( 292 | vertices:torch.Tensor, #V,3 first unused 293 | faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused 294 | edges:torch.Tensor, #E,2 long, first must be 0, 
0 for unused, lower vertex index first 295 | edge_to_face:torch.Tensor, #E,[left,right],[face,side] 296 | with_border:bool=True, #handle border edges (D=4 instead of D=6) 297 | with_normal_check:bool=True, #check face normal flips 298 | stable:bool=False, #only for unit testing 299 | ): 300 | V = vertices.shape[0] 301 | E = edges.shape[0] 302 | device=vertices.device 303 | vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long 304 | vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add') 305 | neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner 306 | neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2 307 | edge_is_inside = neighbors.all(dim=-1) #E 308 | 309 | if with_border: 310 | # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices 311 | # need to use float for masks in order to use scatter(reduce='multiply') 312 | vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float 313 | src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float 314 | vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply') 315 | vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long 316 | vertex_degree -= 2 * vertex_is_inside #V long 317 | 318 | neighbor_degrees = vertex_degree[neighbors] #E,LR=2 319 | edge_degrees = vertex_degree[edges] #E,2 320 | # 321 | # loss = Sum_over_affected_vertices((new_degree-6)**2) 322 | # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2) 323 | # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2) 324 | # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree)) 325 | # 326 | loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E 327 | candidates = torch.logical_and(loss_change<0, edge_is_inside) #E 328 | loss_change = loss_change[candidates] #E' 329 | if loss_change.shape[0]==0: 330 | return 331 | 332 | edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4 333 | _,order = loss_change.sort(descending=True, stable=stable) #E' 334 | rank = torch.zeros_like(order) 335 | rank[order] = torch.arange(0,len(rank),device=rank.device) 336 | vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4 337 | torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank) 338 | vertex_rank,_ = vertex_rank.max(dim=-1) #V 339 | neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E' 340 | flip = rank==neighborhood_rank #E' 341 | 342 | if with_normal_check: 343 | # cl-<-----e1 e0,e1...edge, e0-cr 349 | v = vertices[edges_neighbors] #E",4,3 350 | v = v - v[:,0:1] #make relative to e0 351 | e1 = v[:,1] 352 | cl = v[:,2] 353 | cr = v[:,3] 354 | n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors 355 | flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face 356 | flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face 357 | 358 | flip_edges_neighbors = edges_neighbors[flip] #E",4 359 | flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2 360 | flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3 361 | faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3)) 362 | -------------------------------------------------------------------------------- /mesh_reconstruction/render.py: 
-------------------------------------------------------------------------------- 1 | # modified from https://github.com/Profactor/continuous-remeshing 2 | import nvdiffrast.torch as dr 3 | import torch 4 | from typing import Tuple 5 | 6 | def _warmup(glctx, device=None): 7 | device = 'cuda' if device is None else device 8 | #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59 9 | def tensor(*args, **kwargs): 10 | return torch.tensor(*args, device=device, **kwargs) 11 | pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32) 12 | tri = tensor([[0, 1, 2]], dtype=torch.int32) 13 | dr.rasterize(glctx, pos, tri, resolution=[256, 256]) 14 | 15 | glctx = dr.RasterizeGLContext(output_db=False, device="cuda") 16 | 17 | class NormalsRenderer: 18 | 19 | _glctx:dr.RasterizeGLContext = None 20 | 21 | def __init__( 22 | self, 23 | mv: torch.Tensor, #C,4,4 24 | proj: torch.Tensor, #C,4,4 25 | image_size: Tuple[int,int], 26 | mvp = None, 27 | device=None, 28 | ): 29 | if mvp is None: 30 | self._mvp = proj @ mv #C,4,4 31 | else: 32 | self._mvp = mvp 33 | self._image_size = image_size 34 | self._glctx = glctx 35 | _warmup(self._glctx, device) 36 | 37 | def render(self, 38 | vertices: torch.Tensor, #V,3 float 39 | normals: torch.Tensor, #V,3 float in [-1, 1] 40 | faces: torch.Tensor, #F,3 long 41 | ) ->torch.Tensor: #C,H,W,4 42 | 43 | V = vertices.shape[0] 44 | faces = faces.type(torch.int32) 45 | vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4 46 | vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4 47 | rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4 48 | vert_col = (normals+1)/2 #V,3 49 | col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3 50 | alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1 51 | col = torch.concat((col,alpha),dim=-1) #C,H,W,4 52 | col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4 53 | return col #C,H,W,4 54 | 55 | 56 | 57 | from pytorch3d.structures import Meshes 58 | from pytorch3d.renderer.mesh.shader import ShaderBase 59 | from pytorch3d.renderer import ( 60 | RasterizationSettings, 61 | MeshRendererWithFragments, 62 | TexturesVertex, 63 | MeshRasterizer, 64 | BlendParams, 65 | FoVOrthographicCameras, 66 | look_at_view_transform, 67 | hard_rgb_blend, 68 | ) 69 | 70 | class VertexColorShader(ShaderBase): 71 | def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: 72 | blend_params = kwargs.get("blend_params", self.blend_params) 73 | texels = meshes.sample_textures(fragments) 74 | return hard_rgb_blend(texels, fragments, blend_params) 75 | 76 | def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"): 77 | if len(mesh) != len(cameras): 78 | if len(cameras) % len(mesh) == 0: 79 | mesh = mesh.extend(len(cameras)) 80 | else: 81 | raise NotImplementedError() 82 | 83 | # render requires everything in float16 or float32 84 | input_dtype = dtype 85 | blend_params = BlendParams(1e-4, 1e-4, bkgd) 86 | 87 | # Define the settings for rasterization and shading 88 | raster_settings = RasterizationSettings( 89 | image_size=(H, W), 90 | blur_radius=blur_radius, 91 | faces_per_pixel=faces_per_pixel, 92 | clip_barycentric_coords=True, 93 | bin_size=None, 94 | max_faces_per_bin=None, 95 | ) 96 | 97 | # Create a renderer by composing a rasterizer and a shader 98 | # We simply render vertex colors through the 
custom VertexColorShader (no lighting, materials are used) 99 | renderer = MeshRendererWithFragments( 100 | rasterizer=MeshRasterizer( 101 | cameras=cameras, 102 | raster_settings=raster_settings 103 | ), 104 | shader=VertexColorShader( 105 | device=device, 106 | cameras=cameras, 107 | blend_params=blend_params 108 | ) 109 | ) 110 | 111 | # render RGB and depth, get mask 112 | with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type): 113 | images, _ = renderer(mesh) 114 | return images # BHW4 115 | 116 | class Pytorch3DNormalsRenderer: # 100 times slower!!! 117 | def __init__(self, cameras, image_size, device): 118 | self.cameras = cameras.to(device) 119 | self._image_size = image_size 120 | self.device = device 121 | 122 | def render(self, 123 | vertices: torch.Tensor, #V,3 float 124 | normals: torch.Tensor, #V,3 float in [-1, 1] 125 | faces: torch.Tensor, #F,3 long 126 | ) ->torch.Tensor: #C,H,W,4 127 | mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device) 128 | return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device) 129 | 130 | def save_tensor_to_img(tensor, save_dir): 131 | from PIL import Image 132 | import numpy as np 133 | for idx, img in enumerate(tensor): 134 | img = img[..., :3].cpu().numpy() 135 | img = (img * 255).astype(np.uint8) 136 | img = Image.fromarray(img) 137 | img.save(save_dir + f"{idx}.png") 138 | 139 | if __name__ == "__main__": 140 | import sys 141 | import os 142 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 143 | from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d 144 | cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0) 145 | mv,proj = make_star_cameras_orthographic(4, 1) 146 | resolution = 1024 147 | renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda") 148 | renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda") 149 | vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32) 150 | normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32) 151 | faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long) 152 | 153 | import time 154 | t0 = time.time() 155 | r1 = renderer1.render(vertices, normals, faces) 156 | print("time r1:", time.time() - t0) 157 | 158 | t0 = time.time() 159 | r2 = renderer2.render(vertices, normals, faces) 160 | print("time r2:", time.time() - t0) 161 | 162 | for i in range(4): 163 | print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean()) -------------------------------------------------------------------------------- /requirements-detail.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.29.2 2 | datasets==2.18.0 3 | diffusers==0.27.2 4 | fire==0.6.0 5 | gradio==4.32.0 6 | jaxtyping==0.2.29 7 | numba==0.59.1 8 | numpy==1.26.4 9 | nvdiffrast==0.3.1 10 | omegaconf==2.3.0 11 | onnxruntime_gpu==1.17.0 12 | opencv_python==4.9.0.80 13 | opencv_python_headless==4.9.0.80 14 | ort_nightly_gpu==1.17.0.dev20240118002 15 | peft==0.10.0 16 | Pillow==10.3.0 17 | pygltflib==1.16.2 18 | pymeshlab==2023.12.post1 19 | pytorch3d==0.7.5 20 | rembg==2.0.56 21 | torch==2.1.0+cu121 22 | torch_scatter==2.1.2 23 | tqdm==4.64.1 24 | transformers==4.39.3 25 | 
trimesh==4.3.0 26 | typeguard==2.13.3 27 | wandb==0.16.6 28 | -------------------------------------------------------------------------------- /requirements-win-py311-cu121.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | datasets 3 | fire 4 | gradio 5 | jaxtyping 6 | numba 7 | numpy 8 | git+https://github.com/NVlabs/nvdiffrast.git 9 | omegaconf>=2.3.0 10 | opencv_python 11 | opencv_python_headless 12 | ort_nightly_gpu 13 | peft 14 | Pillow 15 | pygltflib 16 | pymeshlab>=2023.12 17 | git+https://github.com/facebookresearch/pytorch3d.git@stable 18 | rembg[gpu] 19 | torch_scatter 20 | tqdm 21 | transformers 22 | trimesh 23 | typeguard 24 | wandb 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | datasets 3 | diffusers>=0.26.3 4 | fire 5 | gradio 6 | jaxtyping 7 | numba 8 | numpy 9 | git+https://github.com/NVlabs/nvdiffrast.git 10 | omegaconf>=2.3.0 11 | onnxruntime_gpu 12 | opencv_python 13 | opencv_python_headless 14 | ort_nightly_gpu 15 | peft 16 | Pillow 17 | pygltflib 18 | pymeshlab>=2023.12 19 | git+https://github.com/facebookresearch/pytorch3d.git@stable 20 | rembg[gpu] 21 | #torch>=2.0.1 22 | torch_scatter 23 | tqdm 24 | transformers 25 | trimesh 26 | typeguard 27 | wandb 28 | xformers 29 | -------------------------------------------------------------------------------- /scripts/all_typing.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/threestudio-project 2 | 3 | """ 4 | This module contains type annotations for the project, using 5 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 6 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 7 | 8 | Two types of typing checking can be used: 9 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 10 | 2. 
Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 11 | """ 12 | 13 | # Basic types 14 | from typing import ( 15 | Any, 16 | Callable, 17 | Dict, 18 | Iterable, 19 | List, 20 | Literal, 21 | NamedTuple, 22 | NewType, 23 | Optional, 24 | Sized, 25 | Tuple, 26 | Type, 27 | TypeVar, 28 | Union, 29 | ) 30 | 31 | # Tensor dtype 32 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 33 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 34 | 35 | # Config type 36 | from omegaconf import DictConfig 37 | 38 | # PyTorch Tensor type 39 | from torch import Tensor 40 | 41 | # Runtime type checking decorator 42 | from typeguard import typechecked as typechecker 43 | -------------------------------------------------------------------------------- /scripts/load_onnx.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | import torch 3 | 4 | providers = [ 5 | ('TensorrtExecutionProvider', { 6 | 'device_id': 0, 7 | 'trt_max_workspace_size': 8 * 1024 * 1024 * 1024, 8 | 'trt_fp16_enable': True, 9 | 'trt_engine_cache_enable': True, 10 | }), 11 | ('CUDAExecutionProvider', { 12 | 'device_id': 0, 13 | 'arena_extend_strategy': 'kSameAsRequested', 14 | 'gpu_mem_limit': 8 * 1024 * 1024 * 1024, 15 | 'cudnn_conv_algo_search': 'HEURISTIC', 16 | }) 17 | ] 18 | 19 | def load_onnx(file_path: str): 20 | assert file_path.endswith(".onnx") 21 | sess_opt = onnxruntime.SessionOptions() 22 | ort_session = onnxruntime.InferenceSession(file_path, sess_opt=sess_opt, providers=providers) 23 | return ort_session 24 | 25 | 26 | def load_onnx_caller(file_path: str, single_output=False): 27 | ort_session = load_onnx(file_path) 28 | def caller(*args): 29 | torch_input = isinstance(args[0], torch.Tensor) 30 | if torch_input: 31 | torch_input_dtype = args[0].dtype 32 | torch_input_device = args[0].device 33 | # check all are torch.Tensor and have same dtype and device 34 | assert all([isinstance(arg, torch.Tensor) for arg in args]), "All inputs should be torch.Tensor, if first input is torch.Tensor" 35 | assert all([arg.dtype == torch_input_dtype for arg in args]), "All inputs should have same dtype, if first input is torch.Tensor" 36 | assert all([arg.device == torch_input_device for arg in args]), "All inputs should have same device, if first input is torch.Tensor" 37 | args = [arg.cpu().float().numpy() for arg in args] 38 | 39 | ort_inputs = {ort_session.get_inputs()[idx].name: args[idx] for idx in range(len(args))} 40 | ort_outs = ort_session.run(None, ort_inputs) 41 | 42 | if torch_input: 43 | ort_outs = [torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device) for ort_out in ort_outs] 44 | 45 | if single_output: 46 | return ort_outs[0] 47 | return ort_outs 48 | return caller 49 | -------------------------------------------------------------------------------- /scripts/mesh_init.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | import numpy as np 4 | from pytorch3d.structures import Meshes 5 | from pytorch3d.renderer import TexturesVertex 6 | from scripts.utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh 7 | import pymeshlab 8 | 9 | _MAX_THREAD = 8 10 | 11 | # rgb and depth to mesh 12 | def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"): 13 | pixel_center = 0.5 if use_pixel_centers else 0 14 | i, j 
= np.meshgrid( 15 | np.arange(W, dtype=np.float32) + pixel_center, 16 | np.arange(H, dtype=np.float32) + pixel_center, 17 | indexing='xy' 18 | ) 19 | i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device) 20 | 21 | origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1) # W, H, 3 22 | directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3 23 | 24 | return origins, directions 25 | 26 | def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False): 27 | if valid_HWC is None: 28 | valid_HWC = torch.ones_like(pred_HWC).bool() 29 | H, W = rgb_BCHW.shape[-2:] 30 | rgb_BCHW = rgb_BCHW.flip(-2) 31 | pred_HWC = pred_HWC.flip(0) 32 | valid_HWC = valid_HWC.flip(0) 33 | rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device) 34 | verts = rays_o + rays_d * pred_HWC # [H, W, 3] 35 | verts = verts.reshape(-1, 3) # [V, 3] 36 | indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device) 37 | faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1) 38 | # faces1_valid = valid_HWC[:-1, :-1] | valid_HWC[:-1, 1:] | valid_HWC[1:, :-1] 39 | faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1] 40 | faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1) 41 | # faces2_valid = valid_HWC[1:, 1:] | valid_HWC[1:, :-1] | valid_HWC[:-1, 1:] 42 | faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:] 43 | faces = torch.cat([faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3), faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)], dim=0) # (F, 3) 44 | colors = (rgb_BCHW[0].permute((1,2,0)) / 2 + 0.5).reshape(-1, 3) # (V, 3) 45 | if is_back: 46 | verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device) 47 | 48 | used_verts = faces.unique() 49 | old_to_new_mapping = torch.zeros_like(verts[..., 0]).long() 50 | old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device) 51 | new_faces = old_to_new_mapping[faces] 52 | mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]])) 53 | return mesh 54 | 55 | def normalmap_to_depthmap(normal_np): 56 | from scripts.normal_to_height_map import estimate_height_map 57 | height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96) 58 | return height 59 | 60 | def transform_back_normal_to_front(normal_pil): 61 | arr = np.array(normal_pil) # in [0, 255] 62 | arr[..., 0] = 255-arr[..., 0] 63 | arr[..., 2] = 255-arr[..., 2] 64 | return Image.fromarray(arr.astype(np.uint8)) 65 | 66 | def calc_w_over_h(normal_pil): 67 | if isinstance(normal_pil, Image.Image): 68 | arr = np.array(normal_pil) 69 | else: 70 | assert isinstance(normal_pil, np.ndarray) 71 | arr = normal_pil 72 | if arr.shape[-1] == 4: 73 | alpha = arr[..., -1] / 255. 
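# --- illustrative aside (not part of the original file): calc_w_over_h measures the
# aspect ratio of the foreground region. `alpha` marks foreground pixels, taken from
# the alpha channel when present, otherwise from pixels that are not almost pure
# white (min channel < 250); the function then returns the width/height ratio of the
# foreground bounding box. For example, an object spanning rows 10..110 and
# columns 30..80 of the mask gives (80 - 30) / (110 - 10) = 0.5.
# --- end aside ---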
74 | alpha[alpha >= 0.5] = 1 75 | alpha[alpha < 0.5] = 0 76 | else: 77 | alpha = ~(arr.min(axis=-1) >= 250) 78 | h_min, w_min = np.min(np.where(alpha), axis=1) 79 | h_max, w_max = np.max(np.where(alpha), axis=1) 80 | return (w_max - w_min) / (h_max - h_min) 81 | 82 | def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0): 83 | if is_back: 84 | normal_pil = transform_back_normal_to_front(normal_pil) 85 | normal_img = np.array(normal_pil) 86 | rgb_img = np.array(rgb_pil) 87 | if normal_img.shape[-1] == 4: 88 | valid_HWC = normal_img[..., [3]] / 255 89 | elif rgb_img.shape[-1] == 4: 90 | valid_HWC = rgb_img[..., [3]] / 255 91 | else: 92 | raise ValueError("invalid input, either normal or rgb should have alpha channel") 93 | 94 | real_height_pix = np.max(np.where(valid_HWC>0.5)[0]) - np.min(np.where(valid_HWC>0.5)[0]) 95 | 96 | heights = normalmap_to_depthmap(normal_img) 97 | rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2,0,1))[None] 98 | valid_HWC[valid_HWC < 0.5] = 0 99 | valid_HWC[valid_HWC >= 0.5] = 1 100 | valid_HWC = torch.from_numpy(valid_HWC).bool() 101 | if init_type == "std": 102 | # accurate but not stable 103 | pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None] 104 | elif init_type == "thin": 105 | heights = heights - heights.min() 106 | heights = (heights / heights.max() * 0.2) 107 | pred_HWC = torch.from_numpy(heights * scale).float()[..., None] 108 | else: 109 | # stable but not accurate 110 | heights = heights - heights.min() 111 | heights = (heights / heights.max() * (1-offset)) + offset # to [0.2, 1] 112 | pred_HWC = torch.from_numpy(heights * scale).float()[..., None] 113 | 114 | # set the boarder pixels to 0 height 115 | import cv2 116 | # edge filter 117 | edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255) 118 | edge = torch.from_numpy(edge).bool()[..., None] 119 | pred_HWC[edge] = 0 120 | 121 | valid_HWC[pred_HWC < clamp_min] = False 122 | return depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back) 123 | 124 | def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0): 125 | ms = pymeshlab.MeshSet() 126 | ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh") 127 | if simplification > 0: 128 | ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True) 129 | ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads = 6, depth = poissson_depth, preclean = True) 130 | if simplification > 0: 131 | ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True) 132 | return meshlab_mesh_to_py3dmesh(ms.current_mesh()) 133 | -------------------------------------------------------------------------------- /scripts/multiview_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from scripts.mesh_init import build_mesh, calc_w_over_h, fix_border_with_pymeshlab_fast 4 | from scripts.project_mesh import multiview_color_projection 5 | from scripts.refine_lr_to_sr import run_sr_fast 6 | from scripts.utils import simple_clean_mesh 7 | from app.utils import simple_remove, split_image 8 | from app.custom_models.normal_prediction import predict_normals 9 | from mesh_reconstruction.recon import reconstruct_stage1 10 | from mesh_reconstruction.refine import 
run_mesh_refine 11 | from scripts.project_mesh import get_cameras_list 12 | from scripts.utils import from_py3d_mesh, to_pyml_mesh 13 | from pytorch3d.structures import Meshes, join_meshes_as_scene 14 | import numpy as np 15 | 16 | def fast_geo(front_normal: Image.Image, back_normal: Image.Image, side_normal: Image.Image, clamp=0., init_type="std"): 17 | import time 18 | if front_normal.mode == "RGB": 19 | front_normal = simple_remove(front_normal, run_sr=False) 20 | front_normal = front_normal.resize((192, 192)) 21 | if back_normal.mode == "RGB": 22 | back_normal = simple_remove(back_normal, run_sr=False) 23 | back_normal = back_normal.resize((192, 192)) 24 | if side_normal.mode == "RGB": 25 | side_normal = simple_remove(side_normal, run_sr=False) 26 | side_normal = side_normal.resize((192, 192)) 27 | 28 | # build mesh with front back projection # ~3s 29 | side_w_over_h = calc_w_over_h(side_normal) 30 | mesh_front = build_mesh(front_normal, front_normal, clamp_min=clamp, scale=side_w_over_h, init_type=init_type) 31 | mesh_back = build_mesh(back_normal, back_normal, is_back=True, clamp_min=clamp, scale=side_w_over_h, init_type=init_type) 32 | meshes = join_meshes_as_scene([mesh_front, mesh_back]) 33 | meshes = fix_border_with_pymeshlab_fast(meshes, poissson_depth=6, simplification=2000) 34 | return meshes 35 | 36 | def refine_rgb(rgb_pils, front_pil): 37 | from scripts.refine_lr_to_sr import refine_lr_with_sd 38 | from scripts.utils import NEG_PROMPT 39 | from app.utils import make_image_grid 40 | from app.all_models import model_zoo 41 | from app.utils import rgba_to_rgb 42 | rgb_pil = make_image_grid(rgb_pils, rows=2) 43 | prompt = "4views, multiview" 44 | neg_prompt = NEG_PROMPT 45 | control_image = rgb_pil.resize((1024, 1024)) 46 | refined_rgb = refine_lr_with_sd([rgb_pil], [rgba_to_rgb(front_pil)], [control_image], prompt_list=[prompt], neg_prompt_list=[neg_prompt], pipe=model_zoo.pipe_disney_controlnet_tile_ipadapter_i2i, strength=0.2, output_size=(1024, 1024))[0] 47 | refined_rgbs = split_image(refined_rgb, rows=2) 48 | return refined_rgbs 49 | 50 | def erode_alpha(img_list): 51 | out_img_list = [] 52 | for idx, img in enumerate(img_list): 53 | arr = np.array(img) 54 | alpha = (arr[:, :, 3] > 127).astype(np.uint8) 55 | # erode 1px 56 | import cv2 57 | alpha = cv2.erode(alpha, np.ones((3, 3), np.uint8), iterations=1) 58 | alpha = (alpha * 255).astype(np.uint8) 59 | img = Image.fromarray(np.concatenate([arr[:, :, :3], alpha[:, :, None]], axis=-1)) 60 | out_img_list.append(img) 61 | return out_img_list 62 | import time 63 | def geo_reconstruct(rgb_pils, normal_pils, front_pil, do_refine=False, predict_normal=True, expansion_weight=0.1, init_type="std"): 64 | if front_pil.size[0] <= 512: 65 | front_pil = run_sr_fast([front_pil])[0] 66 | if do_refine: 67 | refined_rgbs = refine_rgb(rgb_pils, front_pil) # 6s 68 | else: 69 | refined_rgbs = [rgb.resize((512, 512), resample=Image.LANCZOS) for rgb in rgb_pils] 70 | img_list = [front_pil] + run_sr_fast(refined_rgbs[1:]) 71 | 72 | if predict_normal: 73 | rm_normals = predict_normals([img.resize((512, 512), resample=Image.LANCZOS) for img in img_list], guidance_scale=1.5) 74 | else: 75 | rm_normals = simple_remove([img.resize((512, 512), resample=Image.LANCZOS) for img in normal_pils]) 76 | # transfer the alpha channel of rm_normals to img_list 77 | for idx, img in enumerate(rm_normals): 78 | if idx == 0 and img_list[0].mode == "RGBA": 79 | temp = img_list[0].resize((2048, 2048)) 80 | rm_normals[0] = 
Image.fromarray(np.concatenate([np.array(rm_normals[0])[:, :, :3], np.array(temp)[:, :, 3:4]], axis=-1)) 81 | continue 82 | img_list[idx] = Image.fromarray(np.concatenate([np.array(img_list[idx]), np.array(img)[:, :, 3:4]], axis=-1)) 83 | assert img_list[0].mode == "RGBA" 84 | assert np.mean(np.array(img_list[0])[..., 3]) < 250 85 | 86 | img_list = [img_list[0]] + erode_alpha(img_list[1:]) 87 | normal_stg1 = [img.resize((512, 512)) for img in rm_normals] 88 | if init_type in ["std", "thin"]: 89 | meshes = fast_geo(normal_stg1[0], normal_stg1[2], normal_stg1[1], init_type=init_type) 90 | _ = multiview_color_projection(meshes, rgb_pils, resolution=512, device="cuda", complete_unseen=False, confidence_threshold=0.1) # just check for validation, may throw error 91 | vertices, faces, _ = from_py3d_mesh(meshes) 92 | vertices, faces = reconstruct_stage1(normal_stg1, steps=200, vertices=vertices, faces=faces, start_edge_len=0.1, end_edge_len=0.02, gain=0.05, return_mesh=False, loss_expansion_weight=expansion_weight) 93 | elif init_type in ["ball"]: 94 | vertices, faces = reconstruct_stage1(normal_stg1, steps=200, end_edge_len=0.01, return_mesh=False, loss_expansion_weight=expansion_weight) 95 | vertices, faces = run_mesh_refine(vertices, faces, rm_normals, steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=20, update_warmup=5, return_mesh=False, process_inputs=False, process_outputs=False) 96 | meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=1, apply_sub_divide=True, sub_divide_threshold=0.25).to("cuda") 97 | new_meshes = multiview_color_projection(meshes, img_list, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2, cameras_list = get_cameras_list([0, 90, 180, 270], "cuda", focal=1)) 98 | return new_meshes 99 | -------------------------------------------------------------------------------- /scripts/normal_to_height_map.py: -------------------------------------------------------------------------------- 1 | # code modified from https://github.com/YertleTurtleGit/depth-from-normals 2 | import numpy as np 3 | import cv2 as cv 4 | from multiprocessing.pool import ThreadPool as Pool 5 | from multiprocessing import cpu_count 6 | from typing import Tuple, List, Union 7 | import numba 8 | 9 | 10 | def calculate_gradients( 11 | normals: np.ndarray, mask: np.ndarray 12 | ) -> Tuple[np.ndarray, np.ndarray]: 13 | horizontal_angle_map = np.arccos(np.clip(normals[:, :, 0], -1, 1)) 14 | left_gradients = np.zeros(normals.shape[:2]) 15 | left_gradients[mask != 0] = (1 - np.sin(horizontal_angle_map[mask != 0])) * np.sign( 16 | horizontal_angle_map[mask != 0] - np.pi / 2 17 | ) 18 | 19 | vertical_angle_map = np.arccos(np.clip(normals[:, :, 1], -1, 1)) 20 | top_gradients = np.zeros(normals.shape[:2]) 21 | top_gradients[mask != 0] = -(1 - np.sin(vertical_angle_map[mask != 0])) * np.sign( 22 | vertical_angle_map[mask != 0] - np.pi / 2 23 | ) 24 | 25 | return left_gradients, top_gradients 26 | 27 | 28 | @numba.jit(nopython=True) 29 | def integrate_gradient_field( 30 | gradient_field: np.ndarray, axis: int, mask: np.ndarray 31 | ) -> np.ndarray: 32 | heights = np.zeros(gradient_field.shape) 33 | 34 | for d1 in numba.prange(heights.shape[1 - axis]): 35 | sum_value = 0 36 | for d2 in range(heights.shape[axis]): 37 | coordinates = (d1, d2) if axis == 1 else (d2, d1) 38 | 39 | if mask[coordinates] != 0: 40 | sum_value = sum_value + gradient_field[coordinates] 41 | heights[coordinates] = sum_value 42 | else: 43 | 
sum_value = 0 44 | 45 | return heights 46 | 47 | 48 | def calculate_heights( 49 | left_gradients: np.ndarray, top_gradients, mask: np.ndarray 50 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 51 | left_heights = integrate_gradient_field(left_gradients, 1, mask) 52 | right_heights = np.fliplr( 53 | integrate_gradient_field(np.fliplr(-left_gradients), 1, np.fliplr(mask)) 54 | ) 55 | top_heights = integrate_gradient_field(top_gradients, 0, mask) 56 | bottom_heights = np.flipud( 57 | integrate_gradient_field(np.flipud(-top_gradients), 0, np.flipud(mask)) 58 | ) 59 | return left_heights, right_heights, top_heights, bottom_heights 60 | 61 | 62 | def combine_heights(*heights: np.ndarray) -> np.ndarray: 63 | return np.mean(np.stack(heights, axis=0), axis=0) 64 | 65 | 66 | def rotate(matrix: np.ndarray, angle: float) -> np.ndarray: 67 | h, w = matrix.shape[:2] 68 | center = (w / 2, h / 2) 69 | 70 | rotation_matrix = cv.getRotationMatrix2D(center, angle, 1.0) 71 | corners = cv.transform( 72 | np.array([[[0, 0], [w, 0], [w, h], [0, h]]]), rotation_matrix 73 | )[0] 74 | 75 | _, _, w, h = cv.boundingRect(corners) 76 | 77 | rotation_matrix[0, 2] += w / 2 - center[0] 78 | rotation_matrix[1, 2] += h / 2 - center[1] 79 | result = cv.warpAffine(matrix, rotation_matrix, (w, h), flags=cv.INTER_LINEAR) 80 | 81 | return result 82 | 83 | 84 | def rotate_vector_field_normals(normals: np.ndarray, angle: float) -> np.ndarray: 85 | angle = np.radians(angle) 86 | cos_angle = np.cos(angle) 87 | sin_angle = np.sin(angle) 88 | 89 | rotated_normals = np.empty_like(normals) 90 | rotated_normals[:, :, 0] = ( 91 | normals[:, :, 0] * cos_angle - normals[:, :, 1] * sin_angle 92 | ) 93 | rotated_normals[:, :, 1] = ( 94 | normals[:, :, 0] * sin_angle + normals[:, :, 1] * cos_angle 95 | ) 96 | 97 | return rotated_normals 98 | 99 | 100 | def centered_crop(image: np.ndarray, target_resolution: Tuple[int, int]) -> np.ndarray: 101 | return image[ 102 | (image.shape[0] - target_resolution[0]) 103 | // 2 : (image.shape[0] - target_resolution[0]) 104 | // 2 105 | + target_resolution[0], 106 | (image.shape[1] - target_resolution[1]) 107 | // 2 : (image.shape[1] - target_resolution[1]) 108 | // 2 109 | + target_resolution[1], 110 | ] 111 | 112 | 113 | def integrate_vector_field( 114 | vector_field: np.ndarray, 115 | mask: np.ndarray, 116 | target_iteration_count: int, 117 | thread_count: int, 118 | ) -> np.ndarray: 119 | shape = vector_field.shape[:2] 120 | angles = np.linspace(0, 90, target_iteration_count, endpoint=False) 121 | 122 | def integrate_vector_field_angles(angles: List[float]) -> np.ndarray: 123 | all_combined_heights = np.zeros(shape) 124 | 125 | for angle in angles: 126 | rotated_vector_field = rotate_vector_field_normals( 127 | rotate(vector_field, angle), angle 128 | ) 129 | rotated_mask = rotate(mask, angle) 130 | 131 | left_gradients, top_gradients = calculate_gradients( 132 | rotated_vector_field, rotated_mask 133 | ) 134 | ( 135 | left_heights, 136 | right_heights, 137 | top_heights, 138 | bottom_heights, 139 | ) = calculate_heights(left_gradients, top_gradients, rotated_mask) 140 | 141 | combined_heights = combine_heights( 142 | left_heights, right_heights, top_heights, bottom_heights 143 | ) 144 | combined_heights = centered_crop(rotate(combined_heights, -angle), shape) 145 | all_combined_heights += combined_heights / len(angles) 146 | 147 | return all_combined_heights 148 | 149 | with Pool(processes=thread_count) as pool: 150 | heights = pool.map( 151 | integrate_vector_field_angles, 152 | np.array( 
153 | np.array_split(angles, thread_count), 154 | dtype=object, 155 | ), 156 | ) 157 | pool.close() 158 | pool.join() 159 | 160 | isotropic_height = np.zeros(shape) 161 | for height in heights: 162 | isotropic_height += height / thread_count 163 | 164 | return isotropic_height 165 | 166 | 167 | def estimate_height_map( 168 | normal_map: np.ndarray, 169 | mask: Union[np.ndarray, None] = None, 170 | height_divisor: float = 1, 171 | target_iteration_count: int = 250, 172 | thread_count: int = cpu_count(), 173 | raw_values: bool = False, 174 | ) -> np.ndarray: 175 | if mask is None: 176 | if normal_map.shape[-1] == 4: 177 | mask = normal_map[:, :, 3] / 255 178 | mask[mask < 0.5] = 0 179 | mask[mask >= 0.5] = 1 180 | else: 181 | mask = np.ones(normal_map.shape[:2], dtype=np.uint8) 182 | 183 | normals = ((normal_map[:, :, :3].astype(np.float64) / 255) - 0.5) * 2 184 | heights = integrate_vector_field( 185 | normals, mask, target_iteration_count, thread_count 186 | ) 187 | 188 | if raw_values: 189 | return heights 190 | 191 | heights /= height_divisor 192 | heights[mask > 0] += 1 / 2 193 | heights[mask == 0] = 1 / 2 194 | 195 | heights *= 2**16 - 1 196 | 197 | if np.min(heights) < 0 or np.max(heights) > 2**16 - 1: 198 | raise OverflowError("Height values are clipping.") 199 | 200 | heights = np.clip(heights, 0, 2**16 - 1) 201 | heights = heights.astype(np.uint16) 202 | 203 | return heights 204 | -------------------------------------------------------------------------------- /scripts/refine_lr_to_sr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | import numpy as np 5 | from hashlib import md5 6 | def hash_img(img): 7 | return md5(np.array(img).tobytes()).hexdigest() 8 | def hash_any(obj): 9 | return md5(str(obj).encode()).hexdigest() 10 | 11 | def refine_lr_with_sd(pil_image_list, concept_img_list, control_image_list, prompt_list, pipe=None, strength=0.35, neg_prompt_list="", output_size=(512, 512), controlnet_conditioning_scale=1.): 12 | with torch.no_grad(): 13 | images = pipe( 14 | image=pil_image_list, 15 | ip_adapter_image=concept_img_list, 16 | prompt=prompt_list, 17 | neg_prompt=neg_prompt_list, 18 | num_inference_steps=50, 19 | strength=strength, 20 | height=output_size[0], 21 | width=output_size[1], 22 | control_image=control_image_list, 23 | guidance_scale=5.0, 24 | controlnet_conditioning_scale=controlnet_conditioning_scale, 25 | generator=torch.manual_seed(233), 26 | ).images 27 | return images 28 | 29 | SR_cache = None 30 | 31 | def run_sr_fast(source_pils, scale=4): 32 | from PIL import Image 33 | from scripts.upsampler import RealESRGANer 34 | import numpy as np 35 | global SR_cache 36 | if SR_cache is not None: 37 | upsampler = SR_cache 38 | else: 39 | upsampler = RealESRGANer( 40 | scale=4, 41 | onnx_path="ckpt/realesrgan-x4.onnx", 42 | tile=0, 43 | tile_pad=10, 44 | pre_pad=0, 45 | half=True, 46 | gpu_id=0, 47 | ) 48 | ret_pils = [] 49 | for idx, img_pils in enumerate(source_pils): 50 | np_in = isinstance(img_pils, np.ndarray) 51 | assert isinstance(img_pils, (Image.Image, np.ndarray)) 52 | img = np.array(img_pils) 53 | output, _ = upsampler.enhance(img, outscale=scale) 54 | if np_in: 55 | ret_pils.append(output) 56 | else: 57 | ret_pils.append(Image.fromarray(output)) 58 | if SR_cache is None: 59 | SR_cache = upsampler 60 | return ret_pils 61 | -------------------------------------------------------------------------------- /scripts/sd_model_zoo.py: 
-------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline 2 | from transformers import CLIPVisionModelWithProjection 3 | import torch 4 | from copy import deepcopy 5 | 6 | ENABLE_CPU_CACHE = False 7 | DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5" 8 | 9 | cached_models = {} # cache for models to avoid repeated loading, key is model name 10 | def cache_model(func): 11 | def wrapper(*args, **kwargs): 12 | if ENABLE_CPU_CACHE: 13 | model_name = func.__name__ + str(args) + str(kwargs) 14 | if model_name not in cached_models: 15 | cached_models[model_name] = func(*args, **kwargs) 16 | return cached_models[model_name] 17 | else: 18 | return func(*args, **kwargs) 19 | return wrapper 20 | 21 | def copied_cache_model(func): 22 | def wrapper(*args, **kwargs): 23 | if ENABLE_CPU_CACHE: 24 | model_name = func.__name__ + str(args) + str(kwargs) 25 | if model_name not in cached_models: 26 | cached_models[model_name] = func(*args, **kwargs) 27 | return deepcopy(cached_models[model_name]) 28 | else: 29 | return func(*args, **kwargs) 30 | return wrapper 31 | 32 | def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs): 33 | if ckpt_or_pretrained.endswith(".safetensors"): 34 | pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs) 35 | else: 36 | pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs) 37 | return pipe 38 | 39 | @copied_cache_model 40 | def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16): 41 | model_kwargs = dict( 42 | torch_dtype=torch_dtype, 43 | requires_safety_checker=False, 44 | safety_checker=None, 45 | ) 46 | pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained( 47 | base_model, 48 | StableDiffusionPipeline, 49 | **model_kwargs 50 | ) 51 | pipe.to("cpu") 52 | return pipe.components 53 | 54 | @cache_model 55 | def load_controlnet(controlnet_path, torch_dtype=torch.float16): 56 | controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype) 57 | return controlnet 58 | 59 | @cache_model 60 | def load_image_encoder(): 61 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( 62 | "h94/IP-Adapter", 63 | subfolder="models/image_encoder", 64 | torch_dtype=torch.float16, 65 | ) 66 | return image_encoder 67 | 68 | def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="auto", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs): 69 | model_kwargs = dict( 70 | torch_dtype=torch_dtype, 71 | device_map=device, 72 | requires_safety_checker=False, 73 | safety_checker=None, 74 | ) 75 | components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype) 76 | model_kwargs.update(components) 77 | model_kwargs.update(kwargs) 78 | 79 | if controlnet is not None: 80 | if isinstance(controlnet, list): 81 | controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet] 82 | else: 83 | controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype) 84 | 
model_kwargs.update(controlnet=controlnet) 85 | 86 | if pipeline_class is None: 87 | if controlnet is not None: 88 | pipeline_class = StableDiffusionControlNetPipeline 89 | else: 90 | pipeline_class = StableDiffusionPipeline 91 | 92 | pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained( 93 | base_model, 94 | pipeline_class, 95 | **model_kwargs 96 | ) 97 | 98 | if ip_adapter: 99 | image_encoder = load_image_encoder() 100 | pipe.image_encoder = image_encoder 101 | if plus_model: 102 | pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors") 103 | else: 104 | pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors") 105 | pipe.set_ip_adapter_scale(1.0) 106 | else: 107 | pipe.unload_ip_adapter() 108 | 109 | pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) 110 | 111 | if model_cpu_offload_seq is None: 112 | if isinstance(pipe, StableDiffusionControlNetPipeline): 113 | pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae" 114 | elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline): 115 | pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae" 116 | else: 117 | pipe.model_cpu_offload_seq = model_cpu_offload_seq 118 | 119 | if enable_sequential_cpu_offload: 120 | pipe.enable_sequential_cpu_offload() 121 | else: 122 | pipe = pipe.to("cuda") 123 | pass 124 | # pipe.enable_model_cpu_offload() 125 | if vae_slicing: 126 | pipe.enable_vae_slicing() 127 | 128 | import gc 129 | gc.collect() 130 | return pipe 131 | 132 | -------------------------------------------------------------------------------- /scripts/upsampler.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torch.nn import functional as F 7 | from scripts.load_onnx import load_onnx_caller 8 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | 10 | 11 | class RealESRGANer(): 12 | """A helper class for upsampling images with RealESRGAN. 13 | 14 | Args: 15 | scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. 16 | model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). 17 | model (nn.Module): The defined network. Default: None. 18 | tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop 19 | input images into tiles, and then process each of them. Finally, they will be merged into one image. 20 | 0 denotes for do not use tile. Default: 0. 21 | tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. 22 | pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. 23 | half (float): Whether to use half precision during inference. Default: False. 
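onnx_path (str): Path to the RealESRGAN ONNX model used for inference (this ONNX-based variant takes it in place of model_path/model).
device (torch.device): Device used for inference. If None, CUDA is used when available, otherwise CPU. Default: None.
gpu_id (int): Index of the CUDA device to use when device is None. Default: None.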
24 | """ 25 | 26 | def __init__(self, 27 | scale, 28 | onnx_path, 29 | tile=0, 30 | tile_pad=10, 31 | pre_pad=10, 32 | half=False, 33 | device=None, 34 | gpu_id=None): 35 | self.scale = scale 36 | self.tile_size = tile 37 | self.tile_pad = tile_pad 38 | self.pre_pad = pre_pad 39 | self.mod_scale = None 40 | self.half = half 41 | 42 | # initialize model 43 | if gpu_id: 44 | self.device = torch.device( 45 | f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device 46 | else: 47 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device 48 | self.model = load_onnx_caller(onnx_path, single_output=True) 49 | # warm up 50 | sample_input = torch.randn(1,3,512,512).cuda().float() 51 | self.model(sample_input) 52 | 53 | def pre_process(self, img): 54 | """Pre-process, such as pre-pad and mod pad, so that the images can be divisible 55 | """ 56 | img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() 57 | self.img = img.unsqueeze(0).to(self.device) 58 | if self.half: 59 | self.img = self.img.half() 60 | 61 | # pre_pad 62 | if self.pre_pad != 0: 63 | self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') 64 | # mod pad for divisible borders 65 | if self.scale == 2: 66 | self.mod_scale = 2 67 | elif self.scale == 1: 68 | self.mod_scale = 4 69 | if self.mod_scale is not None: 70 | self.mod_pad_h, self.mod_pad_w = 0, 0 71 | _, _, h, w = self.img.size() 72 | if (h % self.mod_scale != 0): 73 | self.mod_pad_h = (self.mod_scale - h % self.mod_scale) 74 | if (w % self.mod_scale != 0): 75 | self.mod_pad_w = (self.mod_scale - w % self.mod_scale) 76 | self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') 77 | 78 | def process(self): 79 | # model inference 80 | self.output = self.model(self.img) 81 | 82 | def tile_process(self): 83 | """It will first crop input images to tiles, and then process each tile. 84 | Finally, all the processed tiles are merged into one images. 
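Each tile is padded by tile_pad pixels before inference, and the padded margin is cropped away when the upscaled tile is written back, which suppresses seams at tile borders.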
85 | 86 | Modified from: https://github.com/ata4/esrgan-launcher 87 | """ 88 | batch, channel, height, width = self.img.shape 89 | output_height = height * self.scale 90 | output_width = width * self.scale 91 | output_shape = (batch, channel, output_height, output_width) 92 | 93 | # start with black image 94 | self.output = self.img.new_zeros(output_shape) 95 | tiles_x = math.ceil(width / self.tile_size) 96 | tiles_y = math.ceil(height / self.tile_size) 97 | 98 | # loop over all tiles 99 | for y in range(tiles_y): 100 | for x in range(tiles_x): 101 | # extract tile from input image 102 | ofs_x = x * self.tile_size 103 | ofs_y = y * self.tile_size 104 | # input tile area on total image 105 | input_start_x = ofs_x 106 | input_end_x = min(ofs_x + self.tile_size, width) 107 | input_start_y = ofs_y 108 | input_end_y = min(ofs_y + self.tile_size, height) 109 | 110 | # input tile area on total image with padding 111 | input_start_x_pad = max(input_start_x - self.tile_pad, 0) 112 | input_end_x_pad = min(input_end_x + self.tile_pad, width) 113 | input_start_y_pad = max(input_start_y - self.tile_pad, 0) 114 | input_end_y_pad = min(input_end_y + self.tile_pad, height) 115 | 116 | # input tile dimensions 117 | input_tile_width = input_end_x - input_start_x 118 | input_tile_height = input_end_y - input_start_y 119 | tile_idx = y * tiles_x + x + 1 120 | input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] 121 | 122 | # upscale tile 123 | try: 124 | with torch.no_grad(): 125 | output_tile = self.model(input_tile) 126 | except RuntimeError as error: 127 | print('Error', error) 128 | print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') 129 | 130 | # output tile area on total image 131 | output_start_x = input_start_x * self.scale 132 | output_end_x = input_end_x * self.scale 133 | output_start_y = input_start_y * self.scale 134 | output_end_y = input_end_y * self.scale 135 | 136 | # output tile area without padding 137 | output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale 138 | output_end_x_tile = output_start_x_tile + input_tile_width * self.scale 139 | output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale 140 | output_end_y_tile = output_start_y_tile + input_tile_height * self.scale 141 | 142 | # put tile into output image 143 | self.output[:, :, output_start_y:output_end_y, 144 | output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, 145 | output_start_x_tile:output_end_x_tile] 146 | 147 | def post_process(self): 148 | # remove extra pad 149 | if self.mod_scale is not None: 150 | _, _, h, w = self.output.size() 151 | self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] 152 | # remove prepad 153 | if self.pre_pad != 0: 154 | _, _, h, w = self.output.size() 155 | self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] 156 | return self.output 157 | 158 | @torch.no_grad() 159 | def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): 160 | h_input, w_input = img.shape[0:2] 161 | # img: numpy 162 | img = img.astype(np.float32) 163 | if np.max(img) > 256: # 16-bit image 164 | max_range = 65535 165 | print('\tInput is a 16-bit image') 166 | else: 167 | max_range = 255 168 | img = img / max_range 169 | if len(img.shape) == 2: # gray image 170 | img_mode = 'L' 171 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 172 | elif img.shape[2] == 4: # RGBA image with alpha channel 173 | img_mode = 'RGBA' 174 | 
alpha = img[:, :, 3] 175 | img = img[:, :, 0:3] 176 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 177 | if alpha_upsampler == 'realesrgan': 178 | alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) 179 | else: 180 | img_mode = 'RGB' 181 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 182 | 183 | # ------------------- process image (without the alpha channel) ------------------- # 184 | self.pre_process(img) 185 | if self.tile_size > 0: 186 | self.tile_process() 187 | else: 188 | self.process() 189 | output_img = self.post_process() 190 | output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() 191 | output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) 192 | if img_mode == 'L': 193 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) 194 | 195 | # ------------------- process the alpha channel if necessary ------------------- # 196 | if img_mode == 'RGBA': 197 | if alpha_upsampler == 'realesrgan': 198 | self.pre_process(alpha) 199 | if self.tile_size > 0: 200 | self.tile_process() 201 | else: 202 | self.process() 203 | output_alpha = self.post_process() 204 | output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() 205 | output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) 206 | output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) 207 | else: # use the cv2 resize for alpha channel 208 | h, w = alpha.shape[0:2] 209 | output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) 210 | 211 | # merge the alpha channel 212 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) 213 | output_img[:, :, 3] = output_alpha 214 | 215 | # ------------------------------ return ------------------------------ # 216 | if max_range == 65535: # 16-bit image 217 | output = (output_img * 65535.0).round().astype(np.uint16) 218 | else: 219 | output = (output_img * 255.0).round().astype(np.uint8) 220 | 221 | if outscale is not None and outscale != float(self.scale): 222 | output = cv2.resize( 223 | output, ( 224 | int(w_input * outscale), 225 | int(h_input * outscale), 226 | ), interpolation=cv2.INTER_LANCZOS4) 227 | 228 | return output, img_mode 229 | 230 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | import pymeshlab 5 | import pymeshlab as ml 6 | from pymeshlab import PercentageValue 7 | from pytorch3d.renderer import TexturesVertex 8 | from pytorch3d.structures import Meshes 9 | from rembg import new_session, remove 10 | import torch 11 | import torch.nn.functional as F 12 | from typing import List, Tuple 13 | from PIL import Image 14 | import trimesh 15 | 16 | providers = [ 17 | ('CUDAExecutionProvider', { 18 | 'device_id': 0, 19 | 'arena_extend_strategy': 'kSameAsRequested', 20 | 'gpu_mem_limit': 8 * 1024 * 1024 * 1024, 21 | 'cudnn_conv_algo_search': 'HEURISTIC', 22 | }) 23 | ] 24 | 25 | session = new_session(providers=providers) 26 | 27 | NEG_PROMPT="sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy,bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry,(worst quality:1.4),(low quality:1.4)" 28 | 29 | def load_mesh_with_trimesh(file_name, file_type=None): 30 | import trimesh 31 | mesh: trimesh.Trimesh = 
trimesh.load(file_name, file_type=file_type) 32 | if isinstance(mesh, trimesh.Scene): 33 | assert len(mesh.geometry) > 0 34 | # save to obj first and load again to avoid offset issue 35 | from io import BytesIO 36 | with BytesIO() as f: 37 | mesh.export(f, file_type="obj") 38 | f.seek(0) 39 | mesh = trimesh.load(f, file_type="obj") 40 | if isinstance(mesh, trimesh.Scene): 41 | # we lose texture information here 42 | mesh = trimesh.util.concatenate( 43 | tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces) 44 | for g in mesh.geometry.values())) 45 | assert isinstance(mesh, trimesh.Trimesh) 46 | 47 | vertices = torch.from_numpy(mesh.vertices).T 48 | faces = torch.from_numpy(mesh.faces).T 49 | colors = None 50 | if mesh.visual is not None: 51 | if hasattr(mesh.visual, 'vertex_colors'): 52 | colors = torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255. 53 | if colors is None: 54 | # print("Warning: no vertex color found in mesh! Filling it with gray.") 55 | colors = torch.ones_like(vertices) * 0.5 56 | return vertices, faces, colors 57 | 58 | def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes: 59 | verts = torch.from_numpy(mesh.vertex_matrix()).float() 60 | faces = torch.from_numpy(mesh.face_matrix()).long() 61 | colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float() 62 | textures = TexturesVertex(verts_features=[colors]) 63 | return Meshes(verts=[verts], faces=[faces], textures=textures) 64 | 65 | 66 | def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh: 67 | colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64) 68 | m1 = pymeshlab.Mesh( 69 | vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64), 70 | face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32), 71 | v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64), 72 | v_color_matrix=colors_in) 73 | return m1 74 | 75 | 76 | def to_pyml_mesh(vertices,faces): 77 | m1 = pymeshlab.Mesh( 78 | vertex_matrix=vertices.cpu().float().numpy().astype(np.float64), 79 | face_matrix=faces.cpu().long().numpy().astype(np.int32), 80 | ) 81 | return m1 82 | 83 | 84 | def to_py3d_mesh(vertices, faces, normals=None): 85 | from pytorch3d.structures import Meshes 86 | from pytorch3d.renderer.mesh.textures import TexturesVertex 87 | mesh = Meshes(verts=[vertices], faces=[faces], textures=None) 88 | if normals is None: 89 | normals = mesh.verts_normals_packed() 90 | # set normals as vertext colors 91 | mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5]) 92 | return mesh 93 | 94 | 95 | def from_py3d_mesh(mesh): 96 | return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed() 97 | 98 | def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float): 99 | """ 100 | rotate along y-axis 101 | normal_map: np.array, shape=(H, W, 3) in [-1, 1] 102 | angle: float, in degree 103 | """ 104 | angle = angle / 180 * np.pi 105 | R = np.array([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]]) 106 | return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape) 107 | 108 | # from view coord to front view world coord 109 | def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> np.ndarray: # [0, 255] 110 | n_views = len(normal_pils) 111 | ret = [] 112 | for idx, rgba_normal in enumerate(normal_pils): 113 | # rotate normal 114 | normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1] 115 
| alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1] 116 | normal_np = normal_np * 2 - 1 117 | normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views)) 118 | normal_np = (normal_np + 1) / 2 119 | normal_np = normal_np * alpha_np[..., None] # make bg black 120 | rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255] , axis=-1) 121 | if return_types == 'np': 122 | ret.append(rgba_normal_np) 123 | elif return_types == 'pil': 124 | ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8))) 125 | else: 126 | raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}") 127 | return ret 128 | 129 | 130 | def rotate_normalmap_by_angle_torch(normal_map, angle): 131 | """ 132 | rotate along y-axis 133 | normal_map: torch.Tensor, shape=(H, W, 3) in [-1, 1], device='cuda' 134 | angle: float, in degree 135 | """ 136 | angle = torch.tensor(angle / 180 * np.pi).to(normal_map) 137 | R = torch.tensor([[torch.cos(angle), 0, torch.sin(angle)], 138 | [0, 1, 0], 139 | [-torch.sin(angle), 0, torch.cos(angle)]]).to(normal_map) 140 | return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape) 141 | 142 | def do_rotate(rgba_normal, angle): 143 | rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255 144 | rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle) 145 | rotated_normal_tensor = (rotated_normal_tensor + 1) / 2 146 | rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black 147 | rgba_normal_np = torch.cat([rotated_normal_tensor * 255, rgba_normal[:, :, [3]] * 255], dim=-1).cpu().numpy() 148 | return rgba_normal_np 149 | 150 | def rotate_normals_torch(normal_pils, return_types='np', rotate_direction=1): 151 | n_views = len(normal_pils) 152 | ret = [] 153 | for idx, rgba_normal in enumerate(normal_pils): 154 | # rotate normal 155 | angle = rotate_direction * idx * (360 / n_views) 156 | rgba_normal_np = do_rotate(np.array(rgba_normal), angle) 157 | if return_types == 'np': 158 | ret.append(rgba_normal_np) 159 | elif return_types == 'pil': 160 | ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8))) 161 | else: 162 | raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}") 163 | return ret 164 | 165 | def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)): 166 | ret = [] 167 | new_bkgd = np.array(new_bkgd).reshape(1, 1, 3) 168 | for rgba_img in img_pils: 169 | img_np = np.array(rgba_img)[:, :, :3] / 255 170 | alpha_np = np.array(rgba_img)[:, :, 3] / 255 171 | ori_bkgd = img_np[:1, :1] 172 | # color = ori_color * alpha + bkgd * (1-alpha) 173 | # ori_color = (color - bkgd * (1-alpha)) / alpha 174 | alpha_np_clamp = np.clip(alpha_np, 1e-6, 1) # avoid divide by zero 175 | ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None] 176 | img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd) 177 | rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1) 178 | ret.append(Image.fromarray(rgba_img_np.astype(np.uint8))) 179 | return ret 180 | 181 | def change_bkgd_to_normal(normal_pils) -> List[Image.Image]: 182 | n_views = len(normal_pils) 183 | ret = [] 184 | for idx, rgba_normal in enumerate(normal_pils): 185 | # calcuate background normal 186 | target_bkgd = rotate_normalmap_by_angle(np.array([[[0., 0., 1.]]]), idx * (360 / n_views)) 187 | normal_np = np.array(rgba_normal)[:, :, :3] / 255 
# in [-1, 1] 188 | alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1] 189 | normal_np = normal_np * 2 - 1 190 | old_bkgd = normal_np[:1,:1] 191 | normal_np[alpha_np > 0.05] = (normal_np[alpha_np > 0.05] - old_bkgd * (1 - alpha_np[alpha_np > 0.05][..., None])) / alpha_np[alpha_np > 0.05][..., None] 192 | normal_np = normal_np * alpha_np[..., None] + target_bkgd * (1 - alpha_np[..., None]) 193 | normal_np = (normal_np + 1) / 2 194 | rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[..., None] * 255] , axis=-1) 195 | ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8))) 196 | return ret 197 | 198 | 199 | def fix_vert_color_glb(mesh_path): 200 | from pygltflib import GLTF2, Material, PbrMetallicRoughness 201 | obj1 = GLTF2().load(mesh_path) 202 | obj1.meshes[0].primitives[0].material = 0 203 | obj1.materials.append(Material( 204 | pbrMetallicRoughness = PbrMetallicRoughness( 205 | baseColorFactor = [1.0, 1.0, 1.0, 1.0], 206 | metallicFactor = 0., 207 | roughnessFactor = 1.0, 208 | ), 209 | emissiveFactor = [0.0, 0.0, 0.0], 210 | doubleSided = True, 211 | )) 212 | obj1.save(mesh_path) 213 | 214 | 215 | def srgb_to_linear(c_srgb): 216 | c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4) 217 | return c_linear.clip(0, 1.) 218 | 219 | 220 | def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True): 221 | # convert from pytorch3d meshes to trimesh mesh 222 | vertices = meshes.verts_packed().cpu().float().numpy() 223 | triangles = meshes.faces_packed().cpu().long().numpy() 224 | np_color = meshes.textures.verts_features_packed().cpu().float().numpy() 225 | if save_glb_path.endswith(".glb"): 226 | # rotate 180 along +Y 227 | vertices[:, [0, 2]] = -vertices[:, [0, 2]] 228 | 229 | if apply_sRGB_to_LinearRGB: 230 | np_color = srgb_to_linear(np_color) 231 | assert vertices.shape[0] == np_color.shape[0] 232 | assert np_color.shape[1] == 3 233 | assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}" 234 | mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color) 235 | mesh.remove_unreferenced_vertices() 236 | # save mesh 237 | mesh.export(save_glb_path) 238 | if save_glb_path.endswith(".glb"): 239 | fix_vert_color_glb(save_glb_path) 240 | print(f"saving to {save_glb_path}") 241 | 242 | 243 | def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]: 244 | import time 245 | if '.' 
in save_mesh_prefix: 246 | save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1]) 247 | if with_timestamp: 248 | save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}" 249 | ret_mesh = save_mesh_prefix + ".glb" 250 | # optimizied version 251 | save_py3dmesh_with_trimesh_fast(meshes, ret_mesh) 252 | return ret_mesh, None 253 | 254 | 255 | def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25): 256 | ms = ml.MeshSet() 257 | ms.add_mesh(pyml_mesh, "cube_mesh") 258 | 259 | if apply_smooth: 260 | ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False) 261 | if apply_sub_divide: # 5s, slow 262 | ms.apply_filter("meshing_repair_non_manifold_vertices") 263 | ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces') 264 | ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold)) 265 | return meshlab_mesh_to_py3dmesh(ms.current_mesh()) 266 | 267 | 268 | def expand2square(pil_img, background_color): 269 | width, height = pil_img.size 270 | if width == height: 271 | return pil_img 272 | elif width > height: 273 | result = Image.new(pil_img.mode, (width, width), background_color) 274 | result.paste(pil_img, (0, (width - height) // 2)) 275 | return result 276 | else: 277 | result = Image.new(pil_img.mode, (height, height), background_color) 278 | result.paste(pil_img, ((height - width) // 2, 0)) 279 | return result 280 | 281 | 282 | def simple_preprocess(input_image, rembg_session=session, background_color=255): 283 | RES = 2048 284 | input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS) 285 | if input_image.mode != 'RGBA': 286 | image_rem = input_image.convert('RGBA') 287 | input_image = remove(image_rem, alpha_matting=False, session=rembg_session) 288 | 289 | arr = np.asarray(input_image) 290 | alpha = np.asarray(input_image)[:, :, -1] 291 | x_nonzero = np.nonzero((alpha > 60).sum(axis=1)) 292 | y_nonzero = np.nonzero((alpha > 60).sum(axis=0)) 293 | x_min = int(x_nonzero[0].min()) 294 | y_min = int(y_nonzero[0].min()) 295 | x_max = int(x_nonzero[0].max()) 296 | y_max = int(y_nonzero[0].max()) 297 | arr = arr[x_min: x_max, y_min: y_max] 298 | input_image = Image.fromarray(arr) 299 | input_image = expand2square(input_image, (background_color, background_color, background_color, 0)) 300 | return input_image 301 | 302 | def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"): 303 | # Convert the background color to a PyTorch tensor 304 | new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device) 305 | 306 | # Convert all images to PyTorch tensors and process them 307 | imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255 308 | img_nps = imgs[..., :3] 309 | alpha_nps = imgs[..., 3] 310 | ori_bkgds = img_nps[:, :1, :1] 311 | 312 | # Avoid divide by zero and calculate the original image 313 | alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1) 314 | ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1) 315 | ori_img_nps = torch.clamp(ori_img_nps, 0, 1) 316 | img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd) 317 | 318 | rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1) 319 | return rgba_img_np 
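# Minimal usage sketch for the preprocessing helpers above, assuming one of the bundled
# example images as input (the paths below are placeholders): it strips the background,
# pads the image to a square RGBA canvas, and packs it into the background-normalized
# RGBA target tensor produced by init_target.
if __name__ == "__main__":
    img = Image.open("app/examples/Groot.png")
    rgba = simple_preprocess(img, background_color=255)  # square RGBA image with transparent padding
    rgba.save("preprocessed_input.png")
    targets = init_target([rgba], new_bkgd=(0., 0., 0.), device="cpu")  # (1, H, W, 4), values in [0, 1]
    print(targets.shape)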
--------------------------------------------------------------------------------
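A short sketch of how scripts/normal_to_height_map.py is typically driven (scripts/mesh_init.normalmap_to_depthmap calls it with raw_values=True and target_iteration_count=96); the input path below is an assumed placeholder for an RGBA normal map whose alpha channel serves as the integration mask.

import cv2
from scripts.normal_to_height_map import estimate_height_map

# read a 4-channel normal map; OpenCV loads PNGs as BGRA, so convert to RGBA
normal_rgba = cv2.cvtColor(cv2.imread("front_normal.png", cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
assert normal_rgba.shape[-1] == 4, "expects an RGBA normal map (alpha = mask)"

# raw_values=True returns the unnormalized float height field consumed by build_mesh;
# with raw_values=False a uint16 height map scaled to the [0, 2**16 - 1] range is returned instead
heights = estimate_height_map(normal_rgba, raw_values=True, target_iteration_count=96)
print(heights.shape, float(heights.min()), float(heights.max()))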