├── .editorconfig
├── .gitignore
├── Installation.md
├── LICENSE
├── README.md
├── README_jp.md
├── README_zh.md
├── app
│   ├── __init__.py
│   ├── all_models.py
│   ├── custom_models
│   │   ├── image2mvimage.yaml
│   │   ├── image2normal.yaml
│   │   ├── mvimg_prediction.py
│   │   ├── normal_prediction.py
│   │   └── utils.py
│   ├── examples
│   │   ├── Groot.png
│   │   ├── aaa.png
│   │   ├── abma.png
│   │   ├── akun.png
│   │   ├── anya.png
│   │   ├── bag.png
│   │   ├── ex1.png
│   │   ├── ex2.png
│   │   ├── ex3.jpg
│   │   ├── ex4.png
│   │   ├── generated_1715761545_frame0.png
│   │   ├── generated_1715762357_frame0.png
│   │   ├── generated_1715763329_frame0.png
│   │   ├── hatsune_miku.png
│   │   └── princess-large.png
│   ├── gradio_3dgen.py
│   ├── gradio_3dgen_steps.py
│   ├── gradio_local.py
│   └── utils.py
├── assets
│   ├── teaser.jpg
│   └── teaser_safe.jpg
├── custum_3d_diffusion
│   ├── custum_modules
│   │   ├── attention_processors.py
│   │   └── unifield_processor.py
│   ├── custum_pipeline
│   │   ├── unifield_pipeline_img2img.py
│   │   └── unifield_pipeline_img2mvimg.py
│   ├── modules.py
│   └── trainings
│       ├── __init__.py
│       ├── base.py
│       ├── config_classes.py
│       ├── image2image_trainer.py
│       ├── image2mvimage_trainer.py
│       └── utils.py
├── docker
│   ├── Dockerfile
│   └── README.md
├── gradio_app.py
├── install_windows_win_py311_cu121.bat
├── mesh_reconstruction
│   ├── func.py
│   ├── opt.py
│   ├── recon.py
│   ├── refine.py
│   ├── remesh.py
│   └── render.py
├── requirements-detail.txt
├── requirements-win-py311-cu121.txt
├── requirements.txt
└── scripts
    ├── all_typing.py
    ├── load_onnx.py
    ├── mesh_init.py
    ├── multiview_inference.py
    ├── normal_to_height_map.py
    ├── project_mesh.py
    ├── refine_lr_to_sr.py
    ├── sd_model_zoo.py
    ├── upsampler.py
    └── utils.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*.py]
4 | charset = utf-8
5 | trim_trailing_whitespace = true
6 | end_of_line = lf
7 | insert_final_newline = true
8 | indent_style = space
9 | indent_size = 4
10 |
11 | [*.md]
12 | trim_trailing_whitespace = false
13 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/python
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
3 |
4 | ### Python ###
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/#use-with-ide
114 | .pdm.toml
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | .idea/
165 |
166 | ### Python Patch ###
167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
168 | poetry.toml
169 |
170 | # ruff
171 | .ruff_cache/
172 |
173 | # LSP config files
174 | pyrightconfig.json
175 |
176 | # End of https://www.toptal.com/developers/gitignore/api/python
177 |
178 | .vscode/
179 | .threestudio_cache/
180 | outputs
181 | outputs/
182 | outputs-gradio
183 | outputs-gradio/
184 | lightning_logs/
185 |
186 | # pretrained model weights
187 | *.ckpt
188 | *.pt
189 | *.pth
190 | *.bin
191 | *.param
192 |
193 | # wandb
194 | wandb/
195 |
196 | # obj results
197 | *.obj
198 | *.glb
199 | *.ply
200 |
201 | # ckpts
202 | ckpt/*
203 | *.pth
204 | *.pt
205 |
206 | # tensorrt
207 | *.engine
208 | *.profile
209 |
210 | # zipfiles
211 | *.zip
212 | *.tar
213 | *.tar.gz
214 |
215 | # others
216 | run_30.sh
217 | ckpt
--------------------------------------------------------------------------------
/Installation.md:
--------------------------------------------------------------------------------
1 | # 官方安装指南
2 |
3 | * 在 requirements-detail.txt 里,我们提供了详细的各个库的版本,这个对应的环境是 `python3.10 + cuda12.2`。
4 | * 本项目依赖于几个重要的pypi包,这几个包安装起来会有一些困难。
5 |
6 | ### nvdiffrast 安装
7 |
8 | * nvdiffrast 会在第一次运行时,编译对应的torch插件,这一步需要 ninja 及 cudatoolkit的支持。
9 | * 因此需要先确保正确安装了 ninja 以及 cudatoolkit 并正确配置了 CUDA_HOME 环境变量。
10 | * cudatoolkit 安装可以参考 [linux-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html), [windows-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html)
11 | * ninja 则可以直接 `pip install ninja`
12 | * 然后设置 CUDA_HOME 变量为 cudatoolkit 的安装目录,如 `/usr/local/cuda`。
13 | * 最后 `pip install nvdiffrast` 即可。
14 | * 如果无法在目标服务器上安装 cudatoolkit (如权限不够),可以使用我修改的[预编译版本 nvdiffrast](https://github.com/wukailu/nvdiffrast-torch) 在另一台拥有 cudatoolkit 且环境相似(python, torch, cuda版本相同)的服务器上预编译后安装。
15 |
16 | ### onnxruntime-gpu 安装
17 |
18 | * 注意,同时安装 `onnxruntime` 与 `onnxruntime-gpu` 可能导致程序最终运行在 CPU 而非 GPU 上,推理速度会极慢。
19 | * [onnxruntime 官方安装指南](https://onnxruntime.ai/docs/install/#python-installs)
20 | * TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`.
21 |
22 | * 进一步地,可以安装基于 tensorrt 的 onnxruntime,进一步加快推理速度。
23 | * 注意:如果没有安装基于 tensorrt 的 onnxruntime,建议将 `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` 中 `TensorrtExecutionProvider` 删除。
24 | * 对于 cuda12.x,可以使用如下命令快速安装带有 tensorrt 的 onnxruntime(注意将 `/root/miniconda3/lib/python3.10/site-packages` 修改为你的 python 对应路径,将 `/root/.bashrc` 改为你的用户目录下的 `.bashrc` 路径)
25 | ```
26 | pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
27 | pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
28 | pip install tensorrt==8.6.0
29 | echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
30 | ```
31 |
32 | ### pytorch3d 安装
33 |
34 | * 根据 [pytorch3d 官方的安装建议](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux),建议使用预编译版本
35 | ```
36 | import sys
37 | import torch
38 | pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
39 | version_str="".join([
40 | f"py3{sys.version_info.minor}_cu",
41 | torch.version.cuda.replace(".",""),
42 | f"_pyt{pyt_version_str}"
43 | ])
44 | !pip install fvcore iopath
45 | !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
46 | ```
47 |
48 | ### torch_scatter 安装
49 |
50 | * 按照 [torch_scatter 官方安装指南](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation),使用预编译的安装包快速安装。
51 | * 或者直接编译安装 `pip install git+https://github.com/rusty1s/pytorch_scatter.git`
52 |
53 | ### 其他安装
54 |
55 | * 其他依赖 `pip install -r requirements.txt` 即可。
56 |
57 | -----
58 |
59 | # Detailed Installation Guide
60 |
61 | * In `requirements-detail.txt`, we provide detailed versions of all packages, which correspond to the environment of `python3.10 + cuda12.2`.
62 | * This project relies on several important PyPI packages, which may be difficult to install.
63 |
64 | ### Installation of nvdiffrast
65 |
66 | * nvdiffrast will compile the corresponding torch plugin the first time it runs, which requires support from ninja and cudatoolkit.
67 | * Therefore, it is necessary to ensure that ninja and cudatoolkit are correctly installed and that the CUDA_HOME environment variable is properly configured.
68 | * For the installation of cudatoolkit, you can refer to the [Linux CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) and [Windows CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html).
69 | * Ninja can be directly installed with `pip install ninja`.
70 | * Then set the CUDA_HOME variable to the installation directory of cudatoolkit, such as `/usr/local/cuda`.
71 | * Finally, `pip install nvdiffrast`.
72 | * If you cannot install cudatoolkit on the computer (e.g., insufficient permissions), you can use my modified [pre-compiled version of nvdiffrast](https://github.com/wukailu/nvdiffrast-torch) to pre-compile on another computer that has cudatoolkit and a similar environment (same versions of python, torch, cuda) and then install the `.whl`.
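
The snippet below is a minimal pre-flight sketch (not part of the repo): it only checks that `ninja` is on the PATH and that `nvcc` exists under `CUDA_HOME` before you attempt `pip install nvdiffrast`.

```python
# Pre-flight sketch (assumption: cudatoolkit lives under CUDA_HOME, falling back to /usr/local/cuda).
import os
import shutil

# ninja is required to build the torch extension on first run.
assert shutil.which("ninja") is not None, "ninja not found; run `pip install ninja` first"

cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")
nvcc = os.path.join(cuda_home, "bin", "nvcc")
assert os.path.exists(nvcc), f"nvcc not found at {nvcc}; install cudatoolkit and set CUDA_HOME"

print("Environment looks ready for `pip install nvdiffrast`.")
```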
73 |
74 | ### Installation of onnxruntime-gpu
75 |
76 | * Note that installing both `onnxruntime` and `onnxruntime-gpu` may cause the program to run on the CPU instead of the GPU, leading to extremely slow inference.
77 | * [Official ONNX Runtime Installation Guide](https://onnxruntime.ai/docs/install/#python-installs)
78 | * TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`.
79 | * Furthermore, you can install a TensorRT-based build of onnxruntime to further increase inference speed.
80 | * Note: If you have not correctly installed the TensorRT-based onnxruntime, it is recommended to remove `TensorrtExecutionProvider` from `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4`.
81 | * For cuda12.x, you can quickly install onnxruntime with tensorrt using the following commands (replace `/root/miniconda3/lib/python3.10/site-packages` with your Python's site-packages path, and `/root/.bashrc` with the `.bashrc` in your home directory):
82 | ```
83 | pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
84 | pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
85 | pip install tensorrt==8.6.0
86 | echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
87 | ```
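
As a quick sanity check (a sketch, not from the repo), you can ask onnxruntime which execution providers it can actually use; if only `CPUExecutionProvider` is listed, the GPU build is not active.

```python
# List the execution providers visible to onnxruntime.
import onnxruntime as ort

print(ort.get_available_providers())
# Expect 'CUDAExecutionProvider' (and 'TensorrtExecutionProvider' with the TensorRT build).
# If only 'CPUExecutionProvider' appears, uninstall both onnxruntime and onnxruntime-gpu,
# then reinstall onnxruntime-gpu alone.
```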
88 |
89 | ### Installation of pytorch3d
90 |
91 | * According to the [official installation recommendations of pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux), it is recommended to use the pre-compiled version:
92 | ```
93 | import sys
94 | import torch
95 | pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
96 | version_str="".join([
97 | f"py3{sys.version_info.minor}_cu",
98 | torch.version.cuda.replace(".",""),
99 | f"_pyt{pyt_version_str}"
100 | ])
101 | !pip install fvcore iopath
102 | !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
103 | ```
104 |
105 | ### Installation of torch_scatter
106 |
107 | * Use the pre-compiled installation package according to the [official installation guide of torch_scatter](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation) for a quick installation.
108 | * Alternatively, you can directly compile and install with `pip install git+https://github.com/rusty1s/pytorch_scatter.git`.
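
A short import check (a sketch) to confirm that both compiled-extension packages load against the installed torch build:

```python
# Verify pytorch3d and torch_scatter were built for the installed torch/CUDA combination.
import torch
import pytorch3d
import torch_scatter

print("torch", torch.__version__, "cuda", torch.version.cuda)
print("pytorch3d", pytorch3d.__version__)
print("torch_scatter", torch_scatter.__version__)
```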
109 |
110 | ### Other Installations
111 |
112 | * For other packages, simply `pip install -r requirements.txt`.
113 |
114 | -----
115 |
116 | # 公式インストールガイド
117 |
118 | * `requirements-detail.txt` には、各ライブラリのバージョンが詳細に提供されており、これは Python 3.10 + CUDA 12.2 に対応する環境です。
119 | * このプロジェクトは、いくつかの重要な PyPI パッケージに依存しており、これらのパッケージのインストールにはいくつかの困難が伴います。
120 |
121 | ### nvdiffrast のインストール
122 |
123 | * nvdiffrast は、最初に実行するときに、torch プラグインの対応バージョンをコンパイルします。このステップには、ninja および cudatoolkit のサポートが必要です。
124 | * したがって、ninja および cudatoolkit の正確なインストールと、CUDA_HOME 環境変数の正確な設定を確保する必要があります。
125 | * cudatoolkit のインストールについては、[Linux CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)、[Windows CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) を参照してください。
126 | * ninja は、直接 `pip install ninja` でインストールできます。
127 | * 次に、CUDA_HOME 変数を cudatoolkit のインストールディレクトリに設定します。例えば、`/usr/local/cuda` のように。
128 | * 最後に、`pip install nvdiffrast` を実行します。
129 | * 目標サーバーで cudatoolkit をインストールできない場合(例えば、権限が不足している場合)、私の修正した[事前コンパイル済みバージョンの nvdiffrast](https://github.com/wukailu/nvdiffrast-torch)を使用できます。これは、cudatoolkit があり、環境が似ている(Python、torch、cudaのバージョンが同じ)別のサーバーで事前コンパイルしてからインストールすることができます。
130 |
131 | ### onnxruntime-gpu のインストール
132 |
133 | * 注意:`onnxruntime` と `onnxruntime-gpu` を同時にインストールすると、最終的なプログラムが GPU 上で実行されず、CPU 上で実行される可能性があり、推論速度が非常に遅くなることがあります。
134 | * [onnxruntime 公式インストールガイド](https://onnxruntime.ai/docs/install/#python-installs)
135 | * TLDR: cuda11.x 用には、`pip install onnxruntime-gpu` を使用します。cuda12.x 用には、`pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/` を使用します。
136 | * さらに、TensorRT ベースの onnxruntime をインストールして、推論速度をさらに向上させることができます。
137 | * 注意:TensorRT ベースの onnxruntime がインストールされていない場合は、`https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` の `TensorrtExecutionProvider` を削除することをお勧めします。
138 | * cuda12.x の場合、次のコマンドを使用して迅速に TensorRT を備えた onnxruntime をインストールできます(`/root/miniconda3/lib/python3.10/site-packages` をあなたの Python に対応するパスに、`/root/.bashrc` をあなたのユーザーのパスの下の `.bashrc` に変更してください)。
139 | ```bash
140 | pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
141 | pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
142 | pip install tensorrt==8.6.0
143 | echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
144 | ```
145 |
146 | ### pytorch3d のインストール
147 |
148 | * [pytorch3d 公式のインストール提案](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux)に従い、事前コンパイル済みバージョンを使用することをお勧めします。
149 | ```python
150 | import sys
151 | import torch
152 | pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
153 | version_str="".join([
154 | f"py3{sys.version_info.minor}_cu",
155 | torch.version.cuda.replace(".",""),
156 | f"_pyt{pyt_version_str}"
157 | ])
158 | !pip install fvcore iopath
159 | !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
160 | ```
161 |
162 | ### torch_scatter のインストール
163 |
164 | * [torch_scatter 公式インストールガイド](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation)に従い、事前コンパイル済みのインストールパッケージを使用して迅速にインストールします。
165 | * または、直接コンパイルしてインストールする `pip install git+https://github.com/rusty1s/pytorch_scatter.git` も可能です。
166 |
167 | ### その他のインストール
168 |
169 | * その他のパッケージについては、`pip install -r requirements.txt` を実行するだけです。
170 |
171 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 AiuniAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **[中文版本](README_zh.md)**
2 |
3 | **[日本語版](README_jp.md)**
4 |
5 | # Unique3D
6 | Official implementation of Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image.
7 |
8 | [Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/)
9 |
10 | ## [Paper](https://arxiv.org/abs/2405.20343) | [Project page](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Online Demo](https://www.aiuni.ai/)
11 |
12 | * Demo inference speed: Gradio Demo > Huggingface Demo > Huggingface Demo2 > Online Demo
13 |
14 | **If the Gradio Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (to get a registration invitation code, join Discord: https://discord.gg/aiuni). Note that the Online Demo differs slightly from the Gradio Demo: inference is slower, but generation is much more stable.**
15 |
16 |
17 |
18 |
19 |
20 | High-fidelity and diverse textured meshes generated by Unique3D from single-view wild images in 30 seconds.
21 |
22 | ## More features
23 |
24 | The repo is still under construction; thanks for your patience.
25 | - [x] Upload weights.
26 | - [x] Local gradio demo.
27 | - [x] Detailed tutorial.
28 | - [x] Huggingface demo.
29 | - [ ] Detailed local demo.
30 | - [x] Comfyui support.
31 | - [x] Windows support.
32 | - [x] Docker support.
33 | - [ ] More stable reconstruction with normal.
34 | - [ ] Training code release.
35 |
36 | ## Preparation for inference
37 |
38 | * [Detailed linux installation guide](Installation.md).
39 |
40 | ### Linux System Setup.
41 |
42 | Adapted for Ubuntu 22.04.4 LTS and CUDA 12.1.
43 | ```bash
44 | conda create -n unique3d python=3.11
45 | conda activate unique3d
46 |
47 | pip install ninja
48 | pip install diffusers==0.27.2
49 |
50 | pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html
51 |
52 | pip install -r requirements.txt
53 | ```
54 |
55 | [oak-barry](https://github.com/oak-barry) provides another setup script for torch210+cu121 [here](https://github.com/oak-barry/Unique3D).
56 |
57 | ### Windows Setup.
58 |
59 | * Thank you very much `jtydhr88` for the windows installation method! See [issues/15](https://github.com/AiuniAI/Unique3D/issues/15).
60 |
61 | According to [issues/15](https://github.com/AiuniAI/Unique3D/issues/15), `jtydhr88` implemented a bat script to run the commands, so you can:
62 | 1. You might still need Visual Studio Build Tools; you can find them at [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools).
63 | 2. Create conda env and activate it
64 | 1. `conda create -n unique3d-py311 python=3.11`
65 | 2. `conda activate unique3d-py311`
66 | 3. Download the [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl) for py311 and put it into this project folder.
67 | 4. Run **install_windows_win_py311_cu121.bat**.
68 | 5. Answer `y` when asked to uninstall onnxruntime and onnxruntime-gpu.
69 | 6. Create the output folder **tmp\gradio** under the drive root, e.g. F:\tmp\gradio.
70 | 7. Run `python app/gradio_local.py --port 7860`.
71 |
72 | For more details, refer to [issues/15](https://github.com/AiuniAI/Unique3D/issues/15).
73 |
74 | ### Interactive inference: run your local gradio demo.
75 |
76 | 1. Download the weights from [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) or [Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/), and extract them into `ckpt/` (a quick layout check is sketched after these steps).
77 | ```
78 | Unique3D
79 | ├──ckpt
80 | ├── controlnet-tile/
81 | ├── image2normal/
82 | ├── img2mvimg/
83 | ├── realesrgan-x4.onnx
84 | └── v1-inference.yaml
85 | ```
86 |
87 | 2. Run the interactive inference locally.
88 | ```bash
89 | python app/gradio_local.py --port 7860
90 | ```
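
Before launching, the optional sketch below (not part of the repo) verifies that the `ckpt/` layout from step 1 is in place; the expected names are taken from the tree above.

```python
# Check the extracted checkpoint layout before running app/gradio_local.py.
from pathlib import Path

ckpt = Path("ckpt")
expected = ["controlnet-tile", "image2normal", "img2mvimg", "realesrgan-x4.onnx", "v1-inference.yaml"]
for name in expected:
    print(f"{name}: {'OK' if (ckpt / name).exists() else 'MISSING'}")
```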
91 |
92 | ## ComfyUI Support
93 |
94 | Thanks for the [ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D) implementation from [jtydhr88](https://github.com/jtydhr88)!
95 |
96 | ## Tips to get better results
97 |
98 | **Important: Because the mesh is normalized by its longest xyz edge during training, the input image should contain the object's longest edge during inference, or you may get erroneously squashed results.**
99 | 1. Unique3D is sensitive to the facing direction of input images. Due to the distribution of the training data, orthographic front-facing images with a rest pose always lead to good reconstructions.
100 | 2. Images with occlusions will cause worse reconstructions, since four views cannot cover the complete object. Images with fewer occlusions lead to better results.
101 | 3. Pass an image with as high a resolution as possible to the input when resolution is a factor (see the sketch below).
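
For tip 3, here is a minimal pre-upscaling sketch; it mirrors the step used in `app/gradio_3dgen.py` (the example path is only an assumption):

```python
# Upscale small inputs before reconstruction, as generate3dv2 does for images <= 512 px wide.
from PIL import Image
from scripts.refine_lr_to_sr import run_sr_fast

img = Image.open("app/examples/Groot.png")  # assumed example input
if img.size[0] <= 512:
    img = run_sr_fast([img])[0]
```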
102 |
103 | ## Acknowledgement
104 |
105 | We have extensively borrowed code from the following repositories. Many thanks to the authors for sharing their code.
106 | - [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
107 | - [Wonder3d](https://github.com/xxlong0/Wonder3D)
108 | - [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
109 | - [Continuous Remeshing](https://github.com/Profactor/continuous-remeshing)
110 | - [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
111 |
112 | ## Collaborations
113 | Our mission is to create a 4D generative model with 3D concepts. This is just our first step, and the road ahead is still long, but we are confident. We warmly invite you to join the discussion and explore potential collaborations in any capacity. **If you're interested in connecting or partnering with us, please don't hesitate to reach out via email (wkl22@mails.tsinghua.edu.cn)**.
114 |
115 | - Follow us on twitter for the latest updates: https://x.com/aiuni_ai
116 | - Join AIGC 3D/4D generation community on discord: https://discord.gg/aiuni
117 | - Research collaboration, please contact: ai@aiuni.ai
118 |
119 | ## Citation
120 |
121 | If you found Unique3D helpful, please cite our report:
122 | ```bibtex
123 | @misc{wu2024unique3d,
124 | title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image},
125 | author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma},
126 | year={2024},
127 | eprint={2405.20343},
128 | archivePrefix={arXiv},
129 | primaryClass={cs.CV}
130 | }
131 | ```
132 |
--------------------------------------------------------------------------------
/README_jp.md:
--------------------------------------------------------------------------------
1 | **他の言語のバージョン [英語](README.md) [中国語](README_zh.md)**
2 |
3 | # Unique3D
4 | Unique3D: 単一画像からの高品質かつ効率的な3Dメッシュ生成の公式実装。
5 |
6 | [Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/)
7 |
8 | ## [論文](https://arxiv.org/abs/2405.20343) | [プロジェクトページ](https://wukailu.github.io/Unique3D/) | [Huggingfaceデモ](https://huggingface.co/spaces/Wuvin/Unique3D) | [オンラインデモ](https://www.aiuni.ai/)
9 |
10 | * デモ推論速度: Gradioデモ > Huggingfaceデモ > Huggingfaceデモ2 > オンラインデモ
11 |
12 | **Gradioデモが残念ながらハングアップしたり、非常に混雑している場合は、[aiuni.ai](https://www.aiuni.ai/)のオンラインデモを使用できます。これは無料で試すことができます(登録招待コードを取得するには、Discordに参加してください: https://discord.gg/aiuni)。ただし、オンラインデモはGradioデモとは少し異なり、推論速度が遅く、生成結果が安定していない可能性がありますが、素材の品質は良いです。**
13 |
14 |
15 |
16 |
17 |
18 | Unique3Dは、野生の単一画像から高忠実度および多様なテクスチャメッシュを30秒で生成します。
19 |
20 | ## より多くの機能
21 |
22 | リポジトリはまだ構築中です。ご理解いただきありがとうございます。
23 | - [x] 重みのアップロード。
24 | - [x] ローカルGradioデモ。
25 | - [ ] 詳細なチュートリアル。
26 | - [x] Huggingfaceデモ。
27 | - [ ] 詳細なローカルデモ。
28 | - [x] Comfyuiサポート。
29 | - [x] Windowsサポート。
30 | - [ ] Dockerサポート。
31 | - [ ] ノーマルでより安定した再構築。
32 | - [ ] トレーニングコードのリリース。
33 |
34 | ## 推論の準備
35 |
36 | ### Linuxシステムセットアップ
37 |
38 | Ubuntu 22.04.4 LTSおよびCUDA 12.1に適応。
39 | ```bash
40 | conda create -n unique3d python=3.11
41 | conda activate unique3d
42 |
43 | pip install ninja
44 | pip install diffusers==0.27.2
45 |
46 | pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html
47 |
48 | pip install -r requirements.txt
49 | ```
50 |
51 | [oak-barry](https://github.com/oak-barry)は、[こちら](https://github.com/oak-barry/Unique3D)でtorch210+cu121の別のセットアップスクリプトを提供しています。
52 |
53 | ### Windowsセットアップ
54 |
55 | * `jtydhr88` による Windows インストール方法に感謝します![issues/15](https://github.com/AiuniAI/Unique3D/issues/15) を参照してください。
56 |
57 | [issues/15](https://github.com/AiuniAI/Unique3D/issues/15)によると、コマンドを実行するバッチスクリプトを実装したので、以下の手順に従ってください。
58 | 1. [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools)からVisual Studio Build Toolsが必要になる場合があります。
59 | 2. conda envを作成し、アクティブにします。
60 | 1. `conda create -n unique3d-py311 python=3.11`
61 | 2. `conda activate unique3d-py311`
62 | 3. [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl)をダウンロードし、このプロジェクトに配置します。
63 | 4. **install_windows_win_py311_cu121.bat**を実行します。
64 | 5. onnxruntimeおよびonnxruntime-gpuのアンインストールを求められた場合は、yと回答します。
65 | 6. ドライバールートの下に**tmp\gradio**フォルダを作成します(例:F:\tmp\gradio)。
66 | 7. python app/gradio_local.py --port 7860
67 |
68 | 詳細は[issues/15](https://github.com/AiuniAI/Unique3D/issues/15)を参照してください。
69 |
70 | ### インタラクティブ推論:ローカルGradioデモを実行する
71 |
72 | 1. [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt)または[Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)から重みをダウンロードし、`ckpt/*`に抽出します。
73 | ```
74 | Unique3D
75 | ├──ckpt
76 | ├── controlnet-tile/
77 | ├── image2normal/
78 | ├── img2mvimg/
79 | ├── realesrgan-x4.onnx
80 | └── v1-inference.yaml
81 | ```
82 |
83 | 2. インタラクティブ推論をローカルで実行します。
84 | ```bash
85 | python app/gradio_local.py --port 7860
86 | ```
87 |
88 | ## ComfyUIサポート
89 |
90 | [jtydhr88](https://github.com/jtydhr88)からの[ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D)の実装に感謝します!
91 |
92 | ## より良い結果を得るためのヒント
93 |
94 | 1. Unique3Dは入力画像の向きに敏感です。トレーニングデータの分布により、正面を向いた直交画像は常に良い再構築につながります。
95 | 2. 遮蔽のある画像は、4つのビューがオブジェクトを完全にカバーできないため、再構築が悪化します。遮蔽の少ない画像は、より良い結果につながります。
96 | 3. 可能な限り高解像度の画像を入力として使用してください。
97 |
98 | ## 謝辞
99 |
100 | 以下のリポジトリからコードを大量に借用しました。コードを共有してくれた著者に感謝します。
101 | - [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
102 | - [Wonder3d](https://github.com/xxlong0/Wonder3D)
103 | - [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
104 | - [Continuous Remeshing](https://github.com/Profactor/continuous-remeshing)
105 | - [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
106 |
107 | ## コラボレーション
108 | 私たちの使命は、3Dの概念を持つ4D生成モデルを作成することです。これは私たちの最初のステップであり、前途はまだ長いですが、私たちは自信を持っています。あらゆる形態の潜在的なコラボレーションを探求し、議論に参加することを心から歓迎します。**私たちと連絡を取りたい、またはパートナーシップを結びたい方は、メールでお気軽にお問い合わせください (wkl22@mails.tsinghua.edu.cn)**。
109 |
110 | - 最新情報を入手するには、Twitterをフォローしてください: https://x.com/aiuni_ai
111 | - DiscordでAIGC 3D/4D生成コミュニティに参加してください: https://discord.gg/aiuni
112 | - 研究協力については、ai@aiuni.aiまでご連絡ください。
113 |
114 | ## 引用
115 |
116 | Unique3Dが役立つと思われる場合は、私たちのレポートを引用してください:
117 | ```bibtex
118 | @misc{wu2024unique3d,
119 | title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image},
120 | author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma},
121 | year={2024},
122 | eprint={2405.20343},
123 | archivePrefix={arXiv},
124 | primaryClass={cs.CV}
125 | }
126 | ```
127 |
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | **其他语言版本 [English](README.md)**
2 |
3 | # Unique3D
4 | High-Quality and Efficient 3D Mesh Generation from a Single Image
5 |
6 | [Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/)
7 |
8 | ## [论文](https://arxiv.org/abs/2405.20343) | [项目页面](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradio Demo](http://unique3d.demo.avar.cn/) | [在线演示](https://www.aiuni.ai/)
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | Unique3D从单视图图像生成高保真度和多样化纹理的网格,在4090上大约需要30秒。
17 |
18 | ### 推理准备
19 |
20 | #### Linux系统设置
21 | ```bash
22 | conda create -n unique3d
23 | conda activate unique3d
24 | pip install -r requirements.txt
25 | ```
26 |
27 | #### 交互式推理:运行您的本地gradio演示
28 |
29 | 1. 从 [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) 下载或者从[清华云盘](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)下载权重,并将其解压到`ckpt/*`。
30 | ```
31 | Unique3D
32 | ├──ckpt
33 | ├── controlnet-tile/
34 | ├── image2normal/
35 | ├── img2mvimg/
36 | ├── realesrgan-x4.onnx
37 | └── v1-inference.yaml
38 | ```
39 |
40 | 2. 在本地运行交互式推理。
41 | ```bash
42 | python app/gradio_local.py --port 7860
43 | ```
44 |
45 | ## 获取更好结果的提示
46 |
47 | 1. Unique3D对输入图像的朝向非常敏感。由于训练数据的分布,**正交正视图像**通常总是能带来良好的重建。对于人物而言,最好是 A-pose 或者 T-pose,因为目前训练数据很少含有其他类型姿态。
48 | 2. 有遮挡的图像会导致更差的重建,因为4个视图无法覆盖完整的对象。遮挡较少的图像会带来更好的结果。
49 | 3. 尽可能将高分辨率的图像用作输入。
50 |
51 | ## 致谢
52 |
53 | 我们借用了以下代码库的代码。非常感谢作者们分享他们的代码。
54 | - [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
55 | - [Wonder3d](https://github.com/xxlong0/Wonder3D)
56 | - [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
57 | - [Continuous Remeshing](https://github.com/Profactor/continuous-remeshing)
58 | - [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
59 |
60 | ## 合作
61 |
62 | 我们的使命是创建一个具有3D概念的4D生成模型。这只是我们的第一步,前方的道路仍然很长,但我们有信心。我们热情邀请您加入讨论,并探索任何形式的潜在合作。**如果您有兴趣联系或与我们合作,欢迎通过电子邮件(wkl22@mails.tsinghua.edu.cn)与我们联系**。
63 |
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/__init__.py
--------------------------------------------------------------------------------
/app/all_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scripts.sd_model_zoo import load_common_sd15_pipe
3 | from diffusers import StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
4 |
5 |
6 | class MyModelZoo:
7 | _pipe_disney_controlnet_lineart_ipadapter_i2i: StableDiffusionControlNetImg2ImgPipeline = None
8 |
9 | base_model = "runwayml/stable-diffusion-v1-5"
10 |
11 | def __init__(self, base_model=None) -> None:
12 | if base_model is not None:
13 | self.base_model = base_model
14 |
15 | @property
16 | def pipe_disney_controlnet_tile_ipadapter_i2i(self):
17 | return self._pipe_disney_controlnet_lineart_ipadapter_i2i
18 |
19 | def init_models(self):
20 | self._pipe_disney_controlnet_lineart_ipadapter_i2i = load_common_sd15_pipe(base_model=self.base_model, ip_adapter=True, plus_model=False, controlnet="./ckpt/controlnet-tile", pipeline_class=StableDiffusionControlNetImg2ImgPipeline)
21 |
22 | model_zoo = MyModelZoo()
23 |
--------------------------------------------------------------------------------
/app/custom_models/image2mvimage.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_name_or_path: "./ckpt/img2mvimg"
2 | mixed_precision: "bf16"
3 |
4 | init_config:
5 | # enable controls
6 | enable_cross_attn_lora: False
7 | enable_cross_attn_ip: False
8 | enable_self_attn_lora: False
9 | enable_self_attn_ref: False
10 | enable_multiview_attn: True
11 |
12 | # for cross attention
13 | init_cross_attn_lora: False
14 | init_cross_attn_ip: False
15 | cross_attn_lora_rank: 256 # 0 for not enabled
16 | cross_attn_lora_only_kv: False
17 | ipadapter_pretrained_name: "h94/IP-Adapter"
18 | ipadapter_subfolder_name: "models"
19 | ipadapter_weight_name: "ip-adapter_sd15.safetensors"
20 | ipadapter_effect_on: "all" # all, first
21 |
22 | # for self attention
23 | init_self_attn_lora: False
24 | self_attn_lora_rank: 256
25 | self_attn_lora_only_kv: False
26 |
27 | # for self attention ref
28 | init_self_attn_ref: False
29 | self_attn_ref_position: "attn1"
30 | self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
31 | self_attn_ref_pixel_wise_crosspond: False
32 | self_attn_ref_effect_on: "all"
33 |
34 | # for multiview attention
35 | init_multiview_attn: True
36 | multiview_attn_position: "attn1"
37 | use_mv_joint_attn: True
38 | num_modalities: 1
39 |
40 | # for unet
41 | init_unet_path: "${pretrained_model_name_or_path}"
42 | cat_condition: True # cat condition to input
43 |
44 | # for cls embedding
45 | init_num_cls_label: 8 # for initialize
46 | cls_labels: [0, 1, 2, 3] # for current task
47 |
48 | trainers:
49 | - trainer_type: "image2mvimage_trainer"
50 | trainer:
51 | pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
52 | attn_config:
53 | cls_labels: [0, 1, 2, 3] # for current task
54 | enable_cross_attn_lora: False
55 | enable_cross_attn_ip: False
56 | enable_self_attn_lora: False
57 | enable_self_attn_ref: False
58 | enable_multiview_attn: True
59 | resolution: "256"
60 | condition_image_resolution: "256"
61 | normal_cls_offset: 4
62 | condition_image_column_name: "conditioning_image"
63 | image_column_name: "image"
--------------------------------------------------------------------------------
/app/custom_models/image2normal.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_name_or_path: "lambdalabs/sd-image-variations-diffusers"
2 | mixed_precision: "bf16"
3 |
4 | init_config:
5 | # enable controls
6 | enable_cross_attn_lora: False
7 | enable_cross_attn_ip: False
8 | enable_self_attn_lora: False
9 | enable_self_attn_ref: True
10 | enable_multiview_attn: False
11 |
12 | # for cross attention
13 | init_cross_attn_lora: False
14 | init_cross_attn_ip: False
15 | cross_attn_lora_rank: 512 # 0 for not enabled
16 | cross_attn_lora_only_kv: False
17 | ipadapter_pretrained_name: "h94/IP-Adapter"
18 | ipadapter_subfolder_name: "models"
19 | ipadapter_weight_name: "ip-adapter_sd15.safetensors"
20 | ipadapter_effect_on: "all" # all, first
21 |
22 | # for self attention
23 | init_self_attn_lora: False
24 | self_attn_lora_rank: 512
25 | self_attn_lora_only_kv: False
26 |
27 | # for self attention ref
28 | init_self_attn_ref: True
29 | self_attn_ref_position: "attn1"
30 | self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
31 | self_attn_ref_pixel_wise_crosspond: True
32 | self_attn_ref_effect_on: "all"
33 |
34 | # for multiview attention
35 | init_multiview_attn: False
36 | multiview_attn_position: "attn1"
37 | num_modalities: 1
38 |
39 | # for unet
40 | init_unet_path: "${pretrained_model_name_or_path}"
41 | init_num_cls_label: 0 # for initialize
42 | cls_labels: [] # for current task
43 |
44 | trainers:
45 | - trainer_type: "image2image_trainer"
46 | trainer:
47 | pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
48 | attn_config:
49 | cls_labels: [] # for current task
50 | enable_cross_attn_lora: False
51 | enable_cross_attn_ip: False
52 | enable_self_attn_lora: False
53 | enable_self_attn_ref: True
54 | enable_multiview_attn: False
55 | resolution: "512"
56 | condition_image_resolution: "512"
57 | condition_image_column_name: "conditioning_image"
58 | image_column_name: "image"
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/app/custom_models/mvimg_prediction.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import gradio as gr
4 | from PIL import Image
5 | import numpy as np
6 | from rembg import remove
7 | from app.utils import change_rgba_bg, rgba_to_rgb
8 | from app.custom_models.utils import load_pipeline
9 | from scripts.all_typing import *
10 | from scripts.utils import session, simple_preprocess
11 |
12 | training_config = "app/custom_models/image2mvimage.yaml"
13 | checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"
14 | trainer, pipeline = load_pipeline(training_config, checkpoint_path)
15 | # pipeline.enable_model_cpu_offload()
16 |
17 | def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
18 | if isinstance(img_list, Image.Image):
19 | img_list = [img_list]
20 | img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
21 | ret = []
22 | for img in img_list:
23 | images = trainer.pipeline_forward(
24 | pipeline=pipeline,
25 | image=img,
26 | guidance_scale=guidance_scale,
27 | **kwargs
28 | ).images
29 | ret.extend(images)
30 | return ret
31 |
32 |
33 | def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145):
34 | if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.:
35 | # still do remove using rembg, since simple_preprocess requires RGBA image
36 | print("RGB image not RGBA! still remove bg!")
37 | remove_bg = True
38 |
39 | if remove_bg:
40 | input_image = remove(input_image, session=session)
41 |
42 | # make front_pil RGBA with white bg
43 | input_image = change_rgba_bg(input_image, "white")
44 | single_image = simple_preprocess(input_image)
45 |
46 | generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed >= 0 else None
47 |
48 | rgb_pils = predict(
49 | single_image,
50 | generator=generator,
51 | guidance_scale=guidance_scale,
52 | width=256,
53 | height=256,
54 | num_inference_steps=30,
55 | )
56 |
57 | return rgb_pils, single_image
58 |
--------------------------------------------------------------------------------
/app/custom_models/normal_prediction.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from PIL import Image
3 | from app.utils import rgba_to_rgb, simple_remove
4 | from app.custom_models.utils import load_pipeline
5 | from scripts.utils import rotate_normals_torch
6 | from scripts.all_typing import *
7 |
8 | training_config = "app/custom_models/image2normal.yaml"
9 | checkpoint_path = "ckpt/image2normal/unet_state_dict.pth"
10 | trainer, pipeline = load_pipeline(training_config, checkpoint_path)
11 | # pipeline.enable_model_cpu_offload()
12 |
13 | def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs):
14 | img_list = image if isinstance(image, list) else [image]
15 | img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
16 | images = trainer.pipeline_forward(
17 | pipeline=pipeline,
18 | image=img_list,
19 | num_inference_steps=num_inference_steps,
20 | guidance_scale=guidance_scale,
21 | **kwargs
22 | ).images
23 | images = simple_remove(images)
24 | if do_rotate and len(images) > 1:
25 | images = rotate_normals_torch(images, return_types='pil')
26 | return images
--------------------------------------------------------------------------------
/app/custom_models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List
3 | from dataclasses import dataclass
4 | from app.utils import rgba_to_rgb
5 | from custum_3d_diffusion.trainings.config_classes import ExprimentConfig, TrainerSubConfig
6 | from custum_3d_diffusion import modules
7 | from custum_3d_diffusion.custum_modules.unifield_processor import AttnConfig, ConfigurableUNet2DConditionModel
8 | from custum_3d_diffusion.trainings.base import BasicTrainer
9 | from custum_3d_diffusion.trainings.utils import load_config
10 |
11 |
12 | @dataclass
13 | class FakeAccelerator:
14 | device: torch.device = torch.device("cuda")
15 |
16 |
17 | def init_trainers(cfg_path: str, weight_dtype: torch.dtype, extras: dict):
18 | accelerator = FakeAccelerator()
19 | cfg: ExprimentConfig = load_config(ExprimentConfig, cfg_path, extras)
20 | init_config: AttnConfig = load_config(AttnConfig, cfg.init_config)
21 | configurable_unet = ConfigurableUNet2DConditionModel(init_config, weight_dtype)
22 | configurable_unet.enable_xformers_memory_efficient_attention()
23 | trainer_cfgs: List[TrainerSubConfig] = [load_config(TrainerSubConfig, trainer) for trainer in cfg.trainers]
24 | trainers: List[BasicTrainer] = [modules.find(trainer.trainer_type)(accelerator, None, configurable_unet, trainer.trainer, weight_dtype, i) for i, trainer in enumerate(trainer_cfgs)]
25 | return trainers, configurable_unet
26 |
27 | from app.utils import make_image_grid, split_image
28 | def process_image(function, img, guidance_scale=2., merged_image=False, remove_bg=True):
29 | from rembg import remove
30 | if remove_bg:
31 | img = remove(img)
32 | img = rgba_to_rgb(img)
33 | if merged_image:
34 | img = split_image(img, rows=2)
35 | images = function(
36 | image=img,
37 | guidance_scale=guidance_scale,
38 | )
39 | if len(images) > 1:
40 | return make_image_grid(images, rows=2)
41 | else:
42 | return images[0]
43 |
44 |
45 | def process_text(trainer, pipeline, img, guidance_scale=2.):
46 | pipeline.cfg.validation_prompts = [img]
47 | titles, images = trainer.batched_validation_forward(pipeline, guidance_scale=[guidance_scale])
48 | return images[0]
49 |
50 |
51 | def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight_dtype = torch.bfloat16):
52 | training_config = config_path
53 | load_from_checkpoint = ckpt_path
54 | extras = []
55 | device = "cuda"
56 | trainers, configurable_unet = init_trainers(training_config, weight_dtype, extras)
57 | shared_modules = dict()
58 | for trainer in trainers:
59 | shared_modules = trainer.init_shared_modules(shared_modules)
60 |
61 | if load_from_checkpoint is not None:
62 | state_dict = torch.load(load_from_checkpoint)
63 | configurable_unet.unet.load_state_dict(state_dict, strict=False)
64 | # Move unet, vae and text_encoder to device and cast to weight_dtype
65 | configurable_unet.unet.to(device, dtype=weight_dtype)
66 |
67 | pipeline = None
68 | trainer_out = None
69 | for trainer in trainers:
70 | if pipeline_filter(trainer.cfg.trainer_name):
71 | pipeline = trainer.construct_pipeline(shared_modules, configurable_unet.unet)
72 | pipeline.set_progress_bar_config(disable=False)
73 | trainer_out = trainer
74 | pipeline = pipeline.to(device)
75 | return trainer_out, pipeline
--------------------------------------------------------------------------------
/app/examples/Groot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/Groot.png
--------------------------------------------------------------------------------
/app/examples/aaa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/aaa.png
--------------------------------------------------------------------------------
/app/examples/abma.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/abma.png
--------------------------------------------------------------------------------
/app/examples/akun.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/akun.png
--------------------------------------------------------------------------------
/app/examples/anya.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/anya.png
--------------------------------------------------------------------------------
/app/examples/bag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/bag.png
--------------------------------------------------------------------------------
/app/examples/ex1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/ex1.png
--------------------------------------------------------------------------------
/app/examples/ex2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/ex2.png
--------------------------------------------------------------------------------
/app/examples/ex3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/ex3.jpg
--------------------------------------------------------------------------------
/app/examples/ex4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/ex4.png
--------------------------------------------------------------------------------
/app/examples/generated_1715761545_frame0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/generated_1715761545_frame0.png
--------------------------------------------------------------------------------
/app/examples/generated_1715762357_frame0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/generated_1715762357_frame0.png
--------------------------------------------------------------------------------
/app/examples/generated_1715763329_frame0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/generated_1715763329_frame0.png
--------------------------------------------------------------------------------
/app/examples/hatsune_miku.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/hatsune_miku.png
--------------------------------------------------------------------------------
/app/examples/princess-large.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/app/examples/princess-large.png
--------------------------------------------------------------------------------
/app/gradio_3dgen.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gradio as gr
3 | from PIL import Image
4 | from pytorch3d.structures import Meshes
5 | from app.utils import clean_up
6 | from app.custom_models.mvimg_prediction import run_mvprediction
7 | from app.custom_models.normal_prediction import predict_normals
8 | from scripts.refine_lr_to_sr import run_sr_fast
9 | from scripts.utils import save_glb_and_video
10 | from scripts.multiview_inference import geo_reconstruct
11 |
12 | def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
13 | if preview_img is None:
14 | raise gr.Error("preview_img is none")
15 | if isinstance(preview_img, str):
16 | preview_img = Image.open(preview_img)
17 |
18 | if preview_img.size[0] <= 512:
19 | preview_img = run_sr_fast([preview_img])[0]
20 | rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s
21 | new_meshes = geo_reconstruct(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)
22 | vertices = new_meshes.verts_packed()
23 | vertices = vertices / 2 * 1.35
24 | vertices[..., [0, 2]] = - vertices[..., [0, 2]]
25 | new_meshes = Meshes(verts=[vertices], faces=new_meshes.faces_list(), textures=new_meshes.textures)
26 |
27 | ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video)
28 | return ret_mesh, video
29 |
30 | #######################################
31 | def create_ui(concurrency_id="wkl"):
32 | with gr.Row():
33 | with gr.Column(scale=2):
34 | input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
35 |
36 | example_folder = os.path.join(os.path.dirname(__file__), "./examples")
37 | example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)])
38 | gr.Examples(
39 | examples=example_fns,
40 | inputs=[input_image],
41 | cache_examples=False,
42 | label='Examples (click one of the images below to start)',
43 | examples_per_page=12
44 | )
45 |
46 |
47 | with gr.Column(scale=3):
48 | # export mesh display
49 | output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320)
50 | output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)
51 |
52 | input_processing = gr.Checkbox(
53 | value=True,
54 | label='Remove Background',
55 | visible=True,
56 | )
57 | do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
58 | expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
59 | init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
60 | setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
61 | render_video = gr.Checkbox(value=False, visible=False, label="generate video")
62 | fullrunv2_btn = gr.Button('Generate 3D', interactive=True)
63 |
64 | fullrunv2_btn.click(
65 | fn = generate3dv2,
66 | inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type],
67 | outputs=[output_mesh, output_video],
68 | concurrency_id=concurrency_id,
69 | api_name="generate3dv2",
70 | ).success(clean_up, api_name=False)
71 | return input_image
72 |
--------------------------------------------------------------------------------
/app/gradio_3dgen_steps.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from PIL import Image
3 |
4 | from app.custom_models.mvimg_prediction import run_mvprediction
5 | from app.utils import make_image_grid, split_image
6 | from scripts.utils import save_glb_and_video
7 |
8 | def concept_to_multiview(preview_img, input_processing, seed, guidance=1.):
9 | seed = int(seed)
10 | if preview_img is None:
11 | raise gr.Error("preview_img is none.")
12 | if isinstance(preview_img, str):
13 | preview_img = Image.open(preview_img)
14 |
15 | rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=seed, guidance_scale=guidance)
16 | rgb_pil = make_image_grid(rgb_pils, rows=2)
17 | return rgb_pil, front_pil
18 |
19 | def concept_to_multiview_ui(concurrency_id="wkl"):
20 | with gr.Row():
21 | with gr.Column(scale=2):
22 | preview_img = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
23 | input_processing = gr.Checkbox(
24 | value=True,
25 | label='Remove Background',
26 | )
27 | seed = gr.Slider(minimum=-1, maximum=1000000000, value=-1, step=1.0, label="seed")
28 | guidance = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, label="Guidance Scale", step=0.5)
29 | run_btn = gr.Button('Generate Multiview', interactive=True)
30 | with gr.Column(scale=3):
31 | # export mesh display
32 | output_rgb = gr.Image(type='pil', label="RGB", show_label=True)
33 | output_front = gr.Image(type='pil', image_mode='RGBA', label="Frontview", show_label=True)
34 | run_btn.click(
35 | fn = concept_to_multiview,
36 | inputs=[preview_img, input_processing, seed, guidance],
37 | outputs=[output_rgb, output_front],
38 | concurrency_id=concurrency_id,
39 | api_name=False,
40 | )
41 | return output_rgb, output_front
42 |
43 | from app.custom_models.normal_prediction import predict_normals
44 | from scripts.multiview_inference import geo_reconstruct
45 | def multiview_to_mesh_v2(rgb_pil, normal_pil, front_pil, do_refine=False, expansion_weight=0.1, init_type="std"):
46 | rgb_pils = split_image(rgb_pil, rows=2)
47 | if normal_pil is not None:
48 | normal_pil = split_image(normal_pil, rows=2)
49 | if front_pil is None:
50 | front_pil = rgb_pils[0]
51 | new_meshes = geo_reconstruct(rgb_pils, normal_pil, front_pil, do_refine=do_refine, predict_normal=normal_pil is None, expansion_weight=expansion_weight, init_type=init_type)
52 | ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=False)
53 | return ret_mesh
54 |
55 | def new_multiview_to_mesh_ui(concurrency_id="wkl"):
56 | with gr.Row():
57 | with gr.Column(scale=2):
58 | rgb_pil = gr.Image(type='pil', image_mode='RGB', label='RGB')
59 |             front_pil = gr.Image(type='pil', image_mode='RGBA', label='Frontview (Optional)')
60 |             normal_pil = gr.Image(type='pil', image_mode='RGBA', label='Normal (Optional)')
61 | do_refine = gr.Checkbox(
62 | value=False,
63 | label='Refine rgb',
64 | visible=False,
65 | )
66 | expansion_weight = gr.Slider(minimum=-1.0, maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
67 | init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh initialization", value="std", visible=False)
68 | run_btn = gr.Button('Generate 3D', interactive=True)
69 | with gr.Column(scale=3):
70 | # export mesh display
71 | output_mesh = gr.Model3D(value=None, label="mesh model", show_label=True)
72 | run_btn.click(
73 | fn = multiview_to_mesh_v2,
74 | inputs=[rgb_pil, normal_pil, front_pil, do_refine, expansion_weight, init_type],
75 | outputs=[output_mesh],
76 | concurrency_id=concurrency_id,
77 | api_name="multiview_to_mesh",
78 | )
79 | return rgb_pil, front_pil, output_mesh
80 |
81 |
82 | #######################################
83 | def create_step_ui(concurrency_id="wkl"):
84 | with gr.Tab(label="3D:concept_to_multiview"):
85 | concept_to_multiview_ui(concurrency_id)
86 | with gr.Tab(label="3D:new_multiview_to_mesh"):
87 | new_multiview_to_mesh_ui(concurrency_id)
88 |
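89 | # Usage sketch (assumes it is called inside a gr.Blocks context, like create_ui from gradio_3dgen.py):
90 | #   with gr.Blocks() as demo:
91 | #       create_step_ui("wkl")
92 | #   demo.queue().launch()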
--------------------------------------------------------------------------------
/app/gradio_local.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import os
3 | import sys
4 | sys.path.append(os.curdir)
5 | if 'CUDA_VISIBLE_DEVICES' not in os.environ:
6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
7 | os.environ['TRANSFORMERS_OFFLINE']='0'
8 | os.environ['DIFFUSERS_OFFLINE']='0'
9 | os.environ['HF_HUB_OFFLINE']='0'
10 | os.environ['GRADIO_ANALYTICS_ENABLED']='False'
11 | os.environ['HF_ENDPOINT']='https://hf-mirror.com'
12 | import torch
13 | torch.set_float32_matmul_precision('medium')
14 | torch.backends.cuda.matmul.allow_tf32 = True
15 | torch.set_grad_enabled(False)
16 |
17 | import gradio as gr
18 | import argparse
19 |
20 | from app.gradio_3dgen import create_ui as create_3d_ui
21 | # from app.gradio_3dgen_steps import create_step_ui
22 | from app.all_models import model_zoo
23 |
24 |
25 | _TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
26 | _DESCRIPTION = '''
27 | [Project page](https://wukailu.github.io/Unique3D/)
28 |
29 | * High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
30 |
31 | **If the Gradio Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (get a registration invitation code by joining the Discord: https://discord.gg/aiuni). The Online Demo differs slightly from the Gradio Demo: inference is slower, but generation is much more stable.**
32 | '''
33 |
34 | def launch(
35 | port,
36 | listen=False,
37 | share=False,
38 | gradio_root="",
39 | ):
40 | model_zoo.init_models()
41 |
42 | with gr.Blocks(
43 | title=_TITLE,
44 | theme=gr.themes.Monochrome(),
45 | ) as demo:
46 | with gr.Row():
47 | with gr.Column(scale=1):
48 | gr.Markdown('# ' + _TITLE)
49 | gr.Markdown(_DESCRIPTION)
50 | create_3d_ui("wkl")
51 |
52 | launch_args = {}
53 | if listen:
54 | launch_args["server_name"] = "0.0.0.0"
55 |
56 | demo.queue(default_concurrency_limit=1).launch(
57 | server_port=None if port == 0 else port,
58 | share=share,
59 | root_path=gradio_root if gradio_root != "" else None, # "/myapp"
60 | **launch_args,
61 | )
62 |
63 | if __name__ == "__main__":
64 | parser = argparse.ArgumentParser()
65 |     # arguments are parsed below, once all options have been registered
66 | parser.add_argument("--listen", action="store_true")
67 | parser.add_argument("--port", type=int, default=0)
68 | parser.add_argument("--share", action="store_true")
69 | parser.add_argument("--gradio_root", default="")
70 | args = parser.parse_args()
71 | launch(
72 | args.port,
73 | listen=args.listen,
74 | share=args.share,
75 | gradio_root=args.gradio_root,
76 | )
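77 | 
78 | # Example invocation (run from the repository root; the port is illustrative):
79 | #   python app/gradio_local.py --listen --port 7860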
--------------------------------------------------------------------------------
/app/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | import gc
8 | from scripts.refine_lr_to_sr import run_sr_fast
9 |
10 | GRADIO_CACHE = "/tmp/gradio/"
11 |
12 | def clean_up():
13 | torch.cuda.empty_cache()
14 | gc.collect()
15 |
16 | def remove_color(arr):
17 | if arr.shape[-1] == 4:
18 | arr = arr[..., :3]
19 | # calc diffs
20 | base = arr[0, 0]
21 | diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1)
22 |     alpha = (diffs <= 80)  # pixels whose color is close to the corner color count as background
23 | 
24 |     arr[alpha] = 255  # paint the background white
25 |     alpha = ~alpha  # keep everything else as foreground
26 | arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1)
27 | return arr
28 |
29 | def simple_remove(imgs, run_sr=True):
30 | """Only works for normal"""
31 | if not isinstance(imgs, list):
32 | imgs = [imgs]
33 | single_input = True
34 | else:
35 | single_input = False
36 | if run_sr:
37 | imgs = run_sr_fast(imgs)
38 | rets = []
39 | for img in imgs:
40 | arr = np.array(img)
41 | arr = remove_color(arr)
42 | rets.append(Image.fromarray(arr.astype(np.uint8)))
43 | if single_input:
44 | return rets[0]
45 | return rets
46 |
47 | def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
48 | new_image = Image.new("RGBA", rgba.size, bkgd)
49 | new_image.paste(rgba, (0, 0), rgba)
50 | new_image = new_image.convert('RGB')
51 | return new_image
52 |
53 | def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
54 | rgb_white = rgba_to_rgb(rgba, bkgd)
55 | new_rgba = Image.fromarray(np.concatenate([np.array(rgb_white), np.array(rgba)[:, :, 3:4]], axis=-1))
56 | return new_rgba
57 |
58 | def split_image(image, rows=None, cols=None):
59 | """
60 | inverse function of make_image_grid
61 | """
62 |     # sub-images are assumed to be square
63 | if rows is None and cols is None:
64 | # image.size [W, H]
65 | rows = 1
66 | cols = image.size[0] // image.size[1]
67 | assert cols * image.size[1] == image.size[0]
68 | subimg_size = image.size[1]
69 | elif rows is None:
70 | subimg_size = image.size[0] // cols
71 | rows = image.size[1] // subimg_size
72 | assert rows * subimg_size == image.size[1]
73 | elif cols is None:
74 | subimg_size = image.size[1] // rows
75 | cols = image.size[0] // subimg_size
76 | assert cols * subimg_size == image.size[0]
77 | else:
78 | subimg_size = image.size[1] // rows
79 | assert cols * subimg_size == image.size[0]
80 | subimgs = []
81 | for i in range(rows):
82 | for j in range(cols):
83 | subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size))
84 | subimgs.append(subimg)
85 | return subimgs
86 |
87 | def make_image_grid(images, rows=None, cols=None, resize=None):
88 | if rows is None and cols is None:
89 | rows = 1
90 | cols = len(images)
91 | if rows is None:
92 | rows = len(images) // cols
93 | if len(images) % cols != 0:
94 | rows += 1
95 | if cols is None:
96 | cols = len(images) // rows
97 | if len(images) % rows != 0:
98 | cols += 1
99 | total_imgs = rows * cols
100 | if total_imgs > len(images):
101 | images += [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))]
102 |
103 | if resize is not None:
104 | images = [img.resize((resize, resize)) for img in images]
105 |
106 | w, h = images[0].size
107 | grid = Image.new(images[0].mode, size=(cols * w, rows * h))
108 |
109 | for i, img in enumerate(images):
110 | grid.paste(img, box=(i % cols * w, i // cols * h))
111 | return grid
112 |
113 |
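114 | # Round-trip sketch (four_views is a placeholder list of equally sized square PIL images):
115 | #   grid = make_image_grid(four_views, rows=2)   # e.g. four 512x512 images -> one 1024x1024 grid
116 | #   views = split_image(grid, rows=2)             # back to four 512x512 tiles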
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/assets/teaser.jpg
--------------------------------------------------------------------------------
/assets/teaser_safe.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/assets/teaser_safe.jpg
--------------------------------------------------------------------------------
/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # modified by Wuvin
15 |
16 |
17 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18 |
19 | import numpy as np
20 | import torch
21 |
22 | from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
23 | from diffusers.schedulers import KarrasDiffusionSchedulers
24 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
25 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
26 | from PIL import Image
27 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
28 |
29 |
30 |
31 | class StableDiffusionImageCustomPipeline(
32 | StableDiffusionImageVariationPipeline
33 | ):
34 | def __init__(
35 | self,
36 | vae: AutoencoderKL,
37 | image_encoder: CLIPVisionModelWithProjection,
38 | unet: UNet2DConditionModel,
39 | scheduler: KarrasDiffusionSchedulers,
40 | safety_checker: StableDiffusionSafetyChecker,
41 | feature_extractor: CLIPImageProcessor,
42 | requires_safety_checker: bool = True,
43 | latents_offset=None,
44 | noisy_cond_latents=False,
45 | ):
46 | super().__init__(
47 | vae=vae,
48 | image_encoder=image_encoder,
49 | unet=unet,
50 | scheduler=scheduler,
51 | safety_checker=safety_checker,
52 | feature_extractor=feature_extractor,
53 | requires_safety_checker=requires_safety_checker
54 | )
55 | latents_offset = tuple(latents_offset) if latents_offset is not None else None
56 | self.latents_offset = latents_offset
57 | if latents_offset is not None:
58 | self.register_to_config(latents_offset=latents_offset)
59 | self.noisy_cond_latents = noisy_cond_latents
60 | self.register_to_config(noisy_cond_latents=noisy_cond_latents)
61 |
62 | def encode_latents(self, image, device, dtype, height, width):
63 | # support batchsize > 1
64 | if isinstance(image, Image.Image):
65 | image = [image]
66 | image = [img.convert("RGB") for img in image]
67 | images = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype)
68 | latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
69 | if self.latents_offset is not None:
70 | return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
71 | else:
72 | return latents
73 |
74 | def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
75 | dtype = next(self.image_encoder.parameters()).dtype
76 |
77 | if not isinstance(image, torch.Tensor):
78 | image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
79 |
80 | image = image.to(device=device, dtype=dtype)
81 | image_embeddings = self.image_encoder(image).image_embeds
82 | image_embeddings = image_embeddings.unsqueeze(1)
83 |
84 | # duplicate image embeddings for each generation per prompt, using mps friendly method
85 | bs_embed, seq_len, _ = image_embeddings.shape
86 | image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
87 | image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
88 |
89 | if do_classifier_free_guidance:
90 | # NOTE: the same as original code
91 | negative_prompt_embeds = torch.zeros_like(image_embeddings)
92 | # For classifier free guidance, we need to do two forward passes.
93 | # Here we concatenate the unconditional and text embeddings into a single batch
94 | # to avoid doing two forward passes
95 | image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
96 |
97 | return image_embeddings
98 |
99 | @torch.no_grad()
100 | def __call__(
101 | self,
102 | image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
103 | height: Optional[int] = 1024,
104 | width: Optional[int] = 1024,
105 | height_cond: Optional[int] = 512,
106 | width_cond: Optional[int] = 512,
107 | num_inference_steps: int = 50,
108 | guidance_scale: float = 7.5,
109 | num_images_per_prompt: Optional[int] = 1,
110 | eta: float = 0.0,
111 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
112 | latents: Optional[torch.FloatTensor] = None,
113 | output_type: Optional[str] = "pil",
114 | return_dict: bool = True,
115 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
116 | callback_steps: int = 1,
117 | upper_left_feature: bool = False,
118 | ):
119 | r"""
120 | The call function to the pipeline for generation.
121 |
122 | Args:
123 | image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
124 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
125 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
126 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
127 | The height in pixels of the generated image.
128 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
129 | The width in pixels of the generated image.
130 | num_inference_steps (`int`, *optional*, defaults to 50):
131 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
132 | expense of slower inference. This parameter is modulated by `strength`.
133 | guidance_scale (`float`, *optional*, defaults to 7.5):
134 | A higher guidance scale value encourages the model to generate images closely linked to the text
135 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
136 | num_images_per_prompt (`int`, *optional*, defaults to 1):
137 | The number of images to generate per prompt.
138 | eta (`float`, *optional*, defaults to 0.0):
139 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
140 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
141 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
142 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
143 | generation deterministic.
144 | latents (`torch.FloatTensor`, *optional*):
145 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
146 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
147 | tensor is generated by sampling using the supplied random `generator`.
148 | output_type (`str`, *optional*, defaults to `"pil"`):
149 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
150 | return_dict (`bool`, *optional*, defaults to `True`):
151 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
152 | plain tuple.
153 | callback (`Callable`, *optional*):
154 | A function that calls every `callback_steps` steps during inference. The function is called with the
155 | following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
156 | callback_steps (`int`, *optional*, defaults to 1):
157 | The frequency at which the `callback` function is called. If not specified, the callback is called at
158 | every step.
159 |
160 | Returns:
161 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
162 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
163 | otherwise a `tuple` is returned where the first element is a list with the generated images and the
164 | second element is a list of `bool`s indicating whether the corresponding generated image contains
165 | "not-safe-for-work" (nsfw) content.
166 |
167 | Examples:
168 |
169 | ```py
170 | from diffusers import StableDiffusionImageVariationPipeline
171 | from PIL import Image
172 | from io import BytesIO
173 | import requests
174 |
175 | pipe = StableDiffusionImageVariationPipeline.from_pretrained(
176 | "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
177 | )
178 | pipe = pipe.to("cuda")
179 |
180 | url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
181 |
182 | response = requests.get(url)
183 | image = Image.open(BytesIO(response.content)).convert("RGB")
184 |
185 | out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
186 | out["images"][0].save("result.jpg")
187 | ```
188 | """
189 | # 0. Default height and width to unet
190 | height = height or self.unet.config.sample_size * self.vae_scale_factor
191 | width = width or self.unet.config.sample_size * self.vae_scale_factor
192 |
193 | # 1. Check inputs. Raise error if not correct
194 | self.check_inputs(image, height, width, callback_steps)
195 |
196 | # 2. Define call parameters
197 | if isinstance(image, Image.Image):
198 | batch_size = 1
199 | elif isinstance(image, list):
200 | batch_size = len(image)
201 | else:
202 | batch_size = image.shape[0]
203 | device = self._execution_device
204 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
205 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
206 | # corresponds to doing no classifier free guidance.
207 | do_classifier_free_guidance = guidance_scale > 1.0
208 |
209 | # 3. Encode input image
210 | if isinstance(image, Image.Image) and upper_left_feature:
211 |             # use only the upper-left quadrant (the first of the four tiled views) for the CLIP embedding
212 | emb_image = image.crop((0, 0, image.size[0] // 2, image.size[1] // 2))
213 | else:
214 | emb_image = image
215 |
216 | image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
217 | cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
218 |
219 | # 4. Prepare timesteps
220 | self.scheduler.set_timesteps(num_inference_steps, device=device)
221 | timesteps = self.scheduler.timesteps
222 |
223 | # 5. Prepare latent variables
224 | num_channels_latents = self.unet.config.out_channels
225 | latents = self.prepare_latents(
226 | batch_size * num_images_per_prompt,
227 | num_channels_latents,
228 | height,
229 | width,
230 | image_embeddings.dtype,
231 | device,
232 | generator,
233 | latents,
234 | )
235 |
236 | # 6. Prepare extra step kwargs.
237 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
238 |
239 | # 7. Denoising loop
240 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
241 | with self.progress_bar(total=num_inference_steps) as progress_bar:
242 | for i, t in enumerate(timesteps):
243 | if self.noisy_cond_latents:
244 |                     raise ValueError("Noisy condition latents are not supported in this pipeline.")
245 | else:
246 | noisy_cond_latents = cond_latents
247 |
248 | noisy_cond_latents = torch.cat([torch.zeros_like(noisy_cond_latents), noisy_cond_latents]) if do_classifier_free_guidance else noisy_cond_latents
249 | # expand the latents if we are doing classifier free guidance
250 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
251 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
252 |
253 | # predict the noise residual
254 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=noisy_cond_latents).sample
255 |
256 | # perform guidance
257 | if do_classifier_free_guidance:
258 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
259 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
260 |
261 | # compute the previous noisy sample x_t -> x_t-1
262 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
263 |
264 | # call the callback, if provided
265 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
266 | progress_bar.update()
267 | if callback is not None and i % callback_steps == 0:
268 | step_idx = i // getattr(self.scheduler, "order", 1)
269 | callback(step_idx, t, latents)
270 |
271 | self.maybe_free_model_hooks()
272 |
273 | if self.latents_offset is not None:
274 | latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
275 |
276 | if not output_type == "latent":
277 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
278 | image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
279 | else:
280 | image = latents
281 | has_nsfw_concept = None
282 |
283 | if has_nsfw_concept is None:
284 | do_denormalize = [True] * image.shape[0]
285 | else:
286 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
287 |
288 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
289 |
290 | self.maybe_free_model_hooks()
291 |
292 | if not return_dict:
293 | return (image, has_nsfw_concept)
294 |
295 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
296 |
297 | if __name__ == "__main__":
298 | pass
299 |
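300 | # Usage sketch (assumes a checkpoint exported for this custom pipeline and its condition-aware UNet):
301 | #   pipe = StableDiffusionImageCustomPipeline.from_pretrained("path/to/checkpoint").to("cuda")
302 | #   out = pipe(front_image, height=1024, width=1024, height_cond=512, width_cond=512, guidance_scale=1.5)
303 | #   out.images[0].save("refined.png")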
--------------------------------------------------------------------------------
/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # modified by Wuvin
15 |
16 |
17 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18 |
19 | import numpy as np
20 | import torch
21 |
22 | from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
23 | from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
24 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
25 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
26 | from PIL import Image
27 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
28 |
29 |
30 |
31 | class StableDiffusionImage2MVCustomPipeline(
32 | StableDiffusionImageVariationPipeline
33 | ):
34 | def __init__(
35 | self,
36 | vae: AutoencoderKL,
37 | image_encoder: CLIPVisionModelWithProjection,
38 | unet: UNet2DConditionModel,
39 | scheduler: KarrasDiffusionSchedulers,
40 | safety_checker: StableDiffusionSafetyChecker,
41 | feature_extractor: CLIPImageProcessor,
42 | requires_safety_checker: bool = True,
43 | latents_offset=None,
44 | noisy_cond_latents=False,
45 | condition_offset=True,
46 | ):
47 | super().__init__(
48 | vae=vae,
49 | image_encoder=image_encoder,
50 | unet=unet,
51 | scheduler=scheduler,
52 | safety_checker=safety_checker,
53 | feature_extractor=feature_extractor,
54 | requires_safety_checker=requires_safety_checker
55 | )
56 | latents_offset = tuple(latents_offset) if latents_offset is not None else None
57 | self.latents_offset = latents_offset
58 | if latents_offset is not None:
59 | self.register_to_config(latents_offset=latents_offset)
60 | if noisy_cond_latents:
61 |             raise NotImplementedError("Noisy condition latents are not supported yet.")
62 | self.condition_offset = condition_offset
63 | self.register_to_config(condition_offset=condition_offset)
64 |
65 | def encode_latents(self, image: Image.Image, device, dtype, height, width):
66 | images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
67 | # NOTE: .mode() for condition
68 | latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
69 | if self.latents_offset is not None and self.condition_offset:
70 | return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
71 | else:
72 | return latents
73 |
74 | def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
75 | dtype = next(self.image_encoder.parameters()).dtype
76 |
77 | if not isinstance(image, torch.Tensor):
78 | image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
79 |
80 | image = image.to(device=device, dtype=dtype)
81 | image_embeddings = self.image_encoder(image).image_embeds
82 | image_embeddings = image_embeddings.unsqueeze(1)
83 |
84 | # duplicate image embeddings for each generation per prompt, using mps friendly method
85 | bs_embed, seq_len, _ = image_embeddings.shape
86 | image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
87 | image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
88 |
89 | if do_classifier_free_guidance:
90 | # NOTE: the same as original code
91 | negative_prompt_embeds = torch.zeros_like(image_embeddings)
92 | # For classifier free guidance, we need to do two forward passes.
93 | # Here we concatenate the unconditional and text embeddings into a single batch
94 | # to avoid doing two forward passes
95 | image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
96 |
97 | return image_embeddings
98 |
99 | @torch.no_grad()
100 | def __call__(
101 | self,
102 | image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
103 | height: Optional[int] = 1024,
104 | width: Optional[int] = 1024,
105 | height_cond: Optional[int] = 512,
106 | width_cond: Optional[int] = 512,
107 | num_inference_steps: int = 50,
108 | guidance_scale: float = 7.5,
109 | num_images_per_prompt: Optional[int] = 1,
110 | eta: float = 0.0,
111 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
112 | latents: Optional[torch.FloatTensor] = None,
113 | output_type: Optional[str] = "pil",
114 | return_dict: bool = True,
115 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
116 | callback_steps: int = 1,
117 | ):
118 | r"""
119 | The call function to the pipeline for generation.
120 |
121 | Args:
122 | image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
123 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
124 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
125 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
126 | The height in pixels of the generated image.
127 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
128 | The width in pixels of the generated image.
129 | num_inference_steps (`int`, *optional*, defaults to 50):
130 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
131 | expense of slower inference. This parameter is modulated by `strength`.
132 | guidance_scale (`float`, *optional*, defaults to 7.5):
133 | A higher guidance scale value encourages the model to generate images closely linked to the text
134 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
135 | num_images_per_prompt (`int`, *optional*, defaults to 1):
136 | The number of images to generate per prompt.
137 | eta (`float`, *optional*, defaults to 0.0):
138 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
139 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
140 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
141 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
142 | generation deterministic.
143 | latents (`torch.FloatTensor`, *optional*):
144 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
145 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
146 | tensor is generated by sampling using the supplied random `generator`.
147 | output_type (`str`, *optional*, defaults to `"pil"`):
148 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
149 | return_dict (`bool`, *optional*, defaults to `True`):
150 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
151 | plain tuple.
152 | callback (`Callable`, *optional*):
153 | A function that calls every `callback_steps` steps during inference. The function is called with the
154 | following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
155 | callback_steps (`int`, *optional*, defaults to 1):
156 | The frequency at which the `callback` function is called. If not specified, the callback is called at
157 | every step.
158 |
159 | Returns:
160 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
161 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
162 | otherwise a `tuple` is returned where the first element is a list with the generated images and the
163 | second element is a list of `bool`s indicating whether the corresponding generated image contains
164 | "not-safe-for-work" (nsfw) content.
165 |
166 | Examples:
167 |
168 | ```py
169 | from diffusers import StableDiffusionImageVariationPipeline
170 | from PIL import Image
171 | from io import BytesIO
172 | import requests
173 |
174 | pipe = StableDiffusionImageVariationPipeline.from_pretrained(
175 | "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
176 | )
177 | pipe = pipe.to("cuda")
178 |
179 | url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
180 |
181 | response = requests.get(url)
182 | image = Image.open(BytesIO(response.content)).convert("RGB")
183 |
184 | out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
185 | out["images"][0].save("result.jpg")
186 | ```
187 | """
188 | # 0. Default height and width to unet
189 | height = height or self.unet.config.sample_size * self.vae_scale_factor
190 | width = width or self.unet.config.sample_size * self.vae_scale_factor
191 |
192 | # 1. Check inputs. Raise error if not correct
193 | self.check_inputs(image, height, width, callback_steps)
194 |
195 | # 2. Define call parameters
196 | if isinstance(image, Image.Image):
197 | batch_size = 1
198 | elif len(image) == 1:
199 | image = image[0]
200 | batch_size = 1
201 | else:
202 | raise NotImplementedError()
203 | # elif isinstance(image, list):
204 | # batch_size = len(image)
205 | # else:
206 | # batch_size = image.shape[0]
207 | device = self._execution_device
208 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
209 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
210 | # corresponds to doing no classifier free guidance.
211 | do_classifier_free_guidance = guidance_scale > 1.0
212 |
213 | # 3. Encode input image
214 | emb_image = image
215 |
216 | image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
217 | cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
218 | cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
219 | image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
220 | if do_classifier_free_guidance:
221 | image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)
222 |
223 | # 4. Prepare timesteps
224 | self.scheduler.set_timesteps(num_inference_steps, device=device)
225 | timesteps = self.scheduler.timesteps
226 |
227 | # 5. Prepare latent variables
228 | num_channels_latents = self.unet.config.out_channels
229 | latents = self.prepare_latents(
230 | batch_size * num_images_per_prompt,
231 | num_channels_latents,
232 | height,
233 | width,
234 | image_embeddings.dtype,
235 | device,
236 | generator,
237 | latents,
238 | )
239 |
240 |
241 | # 6. Prepare extra step kwargs.
242 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
243 | # 7. Denoising loop
244 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
245 | with self.progress_bar(total=num_inference_steps) as progress_bar:
246 | for i, t in enumerate(timesteps):
247 | # expand the latents if we are doing classifier free guidance
248 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
249 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
250 |
251 | # predict the noise residual
252 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=cond_latents, noisy_condition_input=False, cond_pixels_clip=image_pixels).sample
253 |
254 | # perform guidance
255 | if do_classifier_free_guidance:
256 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
257 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
258 |
259 | # compute the previous noisy sample x_t -> x_t-1
260 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
261 |
262 | # call the callback, if provided
263 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
264 | progress_bar.update()
265 | if callback is not None and i % callback_steps == 0:
266 | step_idx = i // getattr(self.scheduler, "order", 1)
267 | callback(step_idx, t, latents)
268 |
269 | self.maybe_free_model_hooks()
270 |
271 | if self.latents_offset is not None:
272 | latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
273 |
274 | if not output_type == "latent":
275 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
276 | image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
277 | else:
278 | image = latents
279 | has_nsfw_concept = None
280 |
281 | if has_nsfw_concept is None:
282 | do_denormalize = [True] * image.shape[0]
283 | else:
284 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
285 |
286 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
287 |
288 | self.maybe_free_model_hooks()
289 |
290 | if not return_dict:
291 | return (image, has_nsfw_concept)
292 |
293 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
294 |
295 | if __name__ == "__main__":
296 | pass
297 |
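298 | # Usage sketch (assumes a checkpoint exported for this multiview pipeline; sizes follow the trainer defaults):
299 | #   pipe = StableDiffusionImage2MVCustomPipeline.from_pretrained("path/to/checkpoint").to("cuda")
300 | #   views = pipe(front_image, num_images_per_prompt=4, height=512, width=512,
301 | #                height_cond=512, width_cond=512, guidance_scale=1.5).images  # four view images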
--------------------------------------------------------------------------------
/custum_3d_diffusion/modules.py:
--------------------------------------------------------------------------------
1 | __modules__ = {}
2 |
3 | def register(name):
4 | def decorator(cls):
5 | __modules__[name] = cls
6 | return cls
7 |
8 | return decorator
9 |
10 |
11 | def find(name):
12 | return __modules__[name]
13 |
14 | from custum_3d_diffusion.trainings import base, image2mvimage_trainer, image2image_trainer
15 |
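16 | # Registry sketch: the imports above trigger each trainer's @register(...) decorator, so e.g.
17 | #   TrainerCls = find("image2mvimage_trainer")
18 | # returns the Image2MVImageTrainer class.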
--------------------------------------------------------------------------------
/custum_3d_diffusion/trainings/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiuniAI/Unique3D/a4f4ad2a6d9e71f2b04a753b375022898581b2c8/custum_3d_diffusion/trainings/__init__.py
--------------------------------------------------------------------------------
/custum_3d_diffusion/trainings/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from accelerate import Accelerator
3 | from accelerate.logging import MultiProcessAdapter
4 | from dataclasses import dataclass, field
5 | from typing import Optional, Union
6 | from datasets import load_dataset
7 | import json
8 | import abc
9 | from diffusers.utils import make_image_grid
10 | import numpy as np
11 | import wandb
12 |
13 | from custum_3d_diffusion.trainings.utils import load_config
14 | from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig
15 |
16 | class BasicTrainer(torch.nn.Module, abc.ABC):
17 | accelerator: Accelerator
18 | logger: MultiProcessAdapter
19 | unet: ConfigurableUNet2DConditionModel
20 | train_dataloader: torch.utils.data.DataLoader
21 | test_dataset: torch.utils.data.Dataset
22 | attn_config: AttnConfig
23 |
24 | @dataclass
25 | class TrainerConfig:
26 | trainer_name: str = "basic"
27 | pretrained_model_name_or_path: str = ""
28 |
29 | attn_config: dict = field(default_factory=dict)
30 | dataset_name: str = ""
31 | dataset_config_name: Optional[str] = None
32 | resolution: str = "1024"
33 | dataloader_num_workers: int = 4
34 | pair_sampler_group_size: int = 1
35 | num_views: int = 4
36 |
37 | max_train_steps: int = -1 # -1 means infinity, otherwise [0, max_train_steps)
38 | training_step_interval: int = 1 # train on step i*interval, stop at max_train_steps
39 | max_train_samples: Optional[int] = None
40 |         seed: Optional[int] = None # for dataset-related operations and validation
41 | train_batch_size: int = 1
42 |
43 | validation_interval: int = 5000
44 | debug: bool = False
45 |
46 | cfg: TrainerConfig # only enable_xxx is used
47 |
48 | def __init__(
49 | self,
50 | accelerator: Accelerator,
51 | logger: MultiProcessAdapter,
52 | unet: ConfigurableUNet2DConditionModel,
53 | config: Union[dict, str],
54 | weight_dtype: torch.dtype,
55 | index: int,
56 | ):
57 | super().__init__()
58 | self.index = index # index in all trainers
59 | self.accelerator = accelerator
60 | self.logger = logger
61 | self.unet = unet
62 | self.weight_dtype = weight_dtype
63 | self.ext_logs = {}
64 | self.cfg = load_config(self.TrainerConfig, config)
65 | self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
66 | self.test_dataset = None
67 | self.validate_trainer_config()
68 | self.configure()
69 |
70 | def get_HW(self):
71 | resolution = json.loads(self.cfg.resolution)
72 | if isinstance(resolution, int):
73 | H = W = resolution
74 | elif isinstance(resolution, list):
75 | H, W = resolution
76 | return H, W
77 |
78 | def unet_update(self):
79 | self.unet.update_config(self.attn_config)
80 |
81 | def validate_trainer_config(self):
82 | pass
83 |
84 | def is_train_finished(self, current_step):
85 | assert isinstance(self.cfg.max_train_steps, int)
86 | return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps
87 |
88 | def next_train_step(self, current_step):
89 | if self.is_train_finished(current_step):
90 | return None
91 | return current_step + self.cfg.training_step_interval
92 |
93 | @classmethod
94 | def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
95 | catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)]
96 | return make_image_grid(catted, rows=1, cols=len(catted))
97 |
98 | def configure(self) -> None:
99 | pass
100 |
101 | @abc.abstractmethod
102 | def init_shared_modules(self, shared_modules: dict) -> dict:
103 | pass
104 |
105 | def load_dataset(self):
106 | dataset = load_dataset(
107 | self.cfg.dataset_name,
108 | self.cfg.dataset_config_name,
109 | trust_remote_code=True
110 | )
111 | return dataset
112 |
113 | @abc.abstractmethod
114 | def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
115 | """Both init train_dataloader and test_dataset, but returns train_dataloader only"""
116 | pass
117 |
118 | @abc.abstractmethod
119 | def forward_step(
120 | self,
121 | *args,
122 | **kwargs
123 | ) -> torch.Tensor:
124 | """
125 |         Takes a batch as input and
126 |         returns the training loss.
127 | """
128 | self.unet_update()
129 | pass
130 |
131 | @abc.abstractmethod
132 | def construct_pipeline(self, shared_modules, unet):
133 | pass
134 |
135 | @abc.abstractmethod
136 | def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
137 | """
138 | For inference time forward.
139 | """
140 | pass
141 |
142 | @abc.abstractmethod
143 | def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
144 | pass
145 |
146 | def do_validation(
147 | self,
148 | shared_modules,
149 | unet,
150 | global_step,
151 | ):
152 | self.unet_update()
153 | self.logger.info("Running validation... ")
154 | pipeline = self.construct_pipeline(shared_modules, unet)
155 | pipeline.set_progress_bar_config(disable=True)
156 | titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
157 | for tracker in self.accelerator.trackers:
158 | if tracker.name == "tensorboard":
159 | np_images = np.stack([np.asarray(img) for img in images])
160 | tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
161 | elif tracker.name == "wandb":
162 | [image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title] # inplace operation
163 | tracker.log({"validation": [
164 | wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
165 | for i, image in enumerate(images)]})
166 | else:
167 | self.logger.warn(f"image logging not implemented for {tracker.name}")
168 | del pipeline
169 | torch.cuda.empty_cache()
170 | return images
171 |
172 |
173 | @torch.no_grad()
174 | def log_validation(
175 | self,
176 | shared_modules,
177 | unet,
178 | global_step,
179 | force=False
180 | ):
181 | if self.accelerator.is_main_process:
182 | for tracker in self.accelerator.trackers:
183 | if tracker.name == "wandb":
184 | tracker.log(self.ext_logs)
185 | self.ext_logs = {}
186 | if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
187 | self.unet_update()
188 | if self.accelerator.is_main_process:
189 | self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)
190 |
191 | def save_model(self, unwrap_unet, shared_modules, save_dir):
192 | if self.accelerator.is_main_process:
193 | pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
194 | pipeline.save_pretrained(save_dir)
195 | self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")
196 |
197 | def save_debug_info(self, save_name="debug", **kwargs):
198 | if self.cfg.debug:
199 | to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()}
200 | import pickle
201 | import os
202 | if os.path.exists(f"{save_name}.pkl"):
203 | for i in range(100):
204 | if not os.path.exists(f"{save_name}_v{i}.pkl"):
205 | save_name = f"{save_name}_v{i}"
206 | break
207 | with open(f"{save_name}.pkl", "wb") as f:
208 | pickle.dump(to_saves, f)
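209 | 
210 | # Subclass sketch: a concrete trainer implements init_shared_modules, init_train_dataloader,
211 | # forward_step, construct_pipeline, pipeline_forward and batched_validation_forward; see
212 | # Image2MVImageTrainer and Image2ImageTrainer in this package for the inference-side pieces.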
--------------------------------------------------------------------------------
/custum_3d_diffusion/trainings/config_classes.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Optional
3 |
4 |
5 | @dataclass
6 | class TrainerSubConfig:
7 | trainer_type: str = ""
8 | trainer: dict = field(default_factory=dict)
9 |
10 |
11 | @dataclass
12 | class ExprimentConfig:
13 | trainers: List[dict] = field(default_factory=lambda: [])
14 | init_config: dict = field(default_factory=dict)
15 | pretrained_model_name_or_path: str = ""
16 | pretrained_unet_state_dict_path: str = ""
17 |     # experiment-related parameters
18 | linear_beta_schedule: bool = False
19 | zero_snr: bool = False
20 | prediction_type: Optional[str] = None
21 | seed: Optional[int] = None
22 | max_train_steps: int = 1000000
23 | gradient_accumulation_steps: int = 1
24 | learning_rate: float = 1e-4
25 | lr_scheduler: str = "constant"
26 | lr_warmup_steps: int = 500
27 | use_8bit_adam: bool = False
28 | adam_beta1: float = 0.9
29 | adam_beta2: float = 0.999
30 | adam_weight_decay: float = 1e-2
31 | adam_epsilon: float = 1e-08
32 | max_grad_norm: float = 1.0
33 | mixed_precision: Optional[str] = None # ["no", "fp16", "bf16", "fp8"]
34 | skip_training: bool = False
35 | debug: bool = False
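36 | 
37 | # Hypothetical OmegaConf/YAML layout matching ExprimentConfig (field names are real, values are examples):
38 | #   pretrained_model_name_or_path: lambdalabs/sd-image-variations-diffusers
39 | #   trainers:
40 | #     - trainer_type: image2mvimage_trainer
41 | #       trainer: {trainer_name: image2mvimage, resolution: "1024"}
42 | #   learning_rate: 1.0e-4
43 | #   mixed_precision: fp16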
--------------------------------------------------------------------------------
/custum_3d_diffusion/trainings/image2image_trainer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from diffusers import EulerAncestralDiscreteScheduler, DDPMScheduler
4 | from dataclasses import dataclass
5 |
6 | from custum_3d_diffusion.modules import register
7 | from custum_3d_diffusion.trainings.image2mvimage_trainer import Image2MVImageTrainer
8 | from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline
9 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
10 |
11 | def get_HW(resolution):
12 | if isinstance(resolution, str):
13 | resolution = json.loads(resolution)
14 | if isinstance(resolution, int):
15 | H = W = resolution
16 | elif isinstance(resolution, list):
17 | H, W = resolution
18 | return H, W
19 |
20 |
21 | @register("image2image_trainer")
22 | class Image2ImageTrainer(Image2MVImageTrainer):
23 | """
24 |     Trainer for simple image-to-image models.
25 | """
26 | @dataclass
27 | class TrainerConfig(Image2MVImageTrainer.TrainerConfig):
28 | trainer_name: str = "image2image"
29 |
30 | cfg: TrainerConfig
31 |
32 | def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
33 | raise NotImplementedError()
34 |
35 | def construct_pipeline(self, shared_modules, unet, old_version=False):
36 | MyPipeline = StableDiffusionImageCustomPipeline
37 | pipeline = MyPipeline.from_pretrained(
38 | self.cfg.pretrained_model_name_or_path,
39 | vae=shared_modules['vae'],
40 | image_encoder=shared_modules['image_encoder'],
41 | feature_extractor=shared_modules['feature_extractor'],
42 | unet=unet,
43 | safety_checker=None,
44 | torch_dtype=self.weight_dtype,
45 | latents_offset=self.cfg.latents_offset,
46 | noisy_cond_latents=self.cfg.noisy_condition_input,
47 | )
48 | pipeline.set_progress_bar_config(disable=True)
49 | scheduler_dict = {}
50 | if self.cfg.zero_snr:
51 | scheduler_dict.update(rescale_betas_zero_snr=True)
52 | if self.cfg.linear_beta_schedule:
53 | scheduler_dict.update(beta_schedule='linear')
54 |
55 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
56 | return pipeline
57 |
58 | def get_forward_args(self):
59 | if self.cfg.seed is None:
60 | generator = None
61 | else:
62 | generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
63 |
64 | H, W = get_HW(self.cfg.resolution)
65 | H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
66 |
67 | forward_args = dict(
68 | num_images_per_prompt=1,
69 | num_inference_steps=20,
70 | height=H,
71 | width=W,
72 | height_cond=H_cond,
73 | width_cond=W_cond,
74 | generator=generator,
75 | )
76 | if self.cfg.zero_snr:
77 | forward_args.update(guidance_rescale=0.7)
78 | return forward_args
79 |
80 | def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
81 | forward_args = self.get_forward_args()
82 | forward_args.update(pipeline_call_kwargs)
83 | return pipeline(**forward_args)
84 |
85 | def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
86 | raise NotImplementedError()
--------------------------------------------------------------------------------
/custum_3d_diffusion/trainings/image2mvimage_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler
3 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature
4 |
5 | import json
6 | from dataclasses import dataclass
7 | from typing import List, Optional
8 |
9 | from custum_3d_diffusion.modules import register
10 | from custum_3d_diffusion.trainings.base import BasicTrainer
11 | from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline
12 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
13 |
14 | def get_HW(resolution):
15 | if isinstance(resolution, str):
16 | resolution = json.loads(resolution)
17 | if isinstance(resolution, int):
18 | H = W = resolution
19 | elif isinstance(resolution, list):
20 | H, W = resolution
21 | return H, W
22 |
23 | @register("image2mvimage_trainer")
24 | class Image2MVImageTrainer(BasicTrainer):
25 | """
26 | Trainer for simple image to multiview images.
27 | """
28 | @dataclass
29 | class TrainerConfig(BasicTrainer.TrainerConfig):
30 | trainer_name: str = "image2mvimage"
31 | condition_image_column_name: str = "conditioning_image"
32 | image_column_name: str = "image"
33 | condition_dropout: float = 0.
34 | condition_image_resolution: str = "512"
35 | validation_images: Optional[List[str]] = None
36 | noise_offset: float = 0.1
37 | max_loss_drop: float = 0.
38 | snr_gamma: float = 5.0
39 | log_distribution: bool = False
40 | latents_offset: Optional[List[float]] = None
41 | input_perturbation: float = 0.
42 | noisy_condition_input: bool = False # whether to add noise for ref unet input
43 | normal_cls_offset: int = 0
44 | condition_offset: bool = True
45 | zero_snr: bool = False
46 | linear_beta_schedule: bool = False
47 |
48 | cfg: TrainerConfig
49 |
50 | def configure(self) -> None:
51 | return super().configure()
52 |
53 | def init_shared_modules(self, shared_modules: dict) -> dict:
54 | if 'vae' not in shared_modules:
55 | vae = AutoencoderKL.from_pretrained(
56 | self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
57 | )
58 | vae.requires_grad_(False)
59 | vae.to(self.accelerator.device, dtype=self.weight_dtype)
60 | shared_modules['vae'] = vae
61 | if 'image_encoder' not in shared_modules:
62 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(
63 | self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
64 | )
65 | image_encoder.requires_grad_(False)
66 | image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
67 | shared_modules['image_encoder'] = image_encoder
68 | if 'feature_extractor' not in shared_modules:
69 | feature_extractor = CLIPImageProcessor.from_pretrained(
70 | self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
71 | )
72 | shared_modules['feature_extractor'] = feature_extractor
73 | return shared_modules
74 |
75 | def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
76 | raise NotImplementedError()
77 |
78 | def loss_rescale(self, loss, timesteps=None):
79 | raise NotImplementedError()
80 |
81 | def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
82 | raise NotImplementedError()
83 |
84 | def construct_pipeline(self, shared_modules, unet, old_version=False):
85 | MyPipeline = StableDiffusionImage2MVCustomPipeline
86 | pipeline = MyPipeline.from_pretrained(
87 | self.cfg.pretrained_model_name_or_path,
88 | vae=shared_modules['vae'],
89 | image_encoder=shared_modules['image_encoder'],
90 | feature_extractor=shared_modules['feature_extractor'],
91 | unet=unet,
92 | safety_checker=None,
93 | torch_dtype=self.weight_dtype,
94 | latents_offset=self.cfg.latents_offset,
95 | noisy_cond_latents=self.cfg.noisy_condition_input,
96 | condition_offset=self.cfg.condition_offset,
97 | )
98 | pipeline.set_progress_bar_config(disable=True)
99 | scheduler_dict = {}
100 | if self.cfg.zero_snr:
101 | scheduler_dict.update(rescale_betas_zero_snr=True)
102 | if self.cfg.linear_beta_schedule:
103 | scheduler_dict.update(beta_schedule='linear')
104 |
105 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
106 | return pipeline
107 |
108 | def get_forward_args(self):
109 | if self.cfg.seed is None:
110 | generator = None
111 | else:
112 | generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
113 |
114 | H, W = get_HW(self.cfg.resolution)
115 | H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
116 |
117 |         sub_img_H = H // 2  # each view is one quadrant of the grid, e.g. 512 for resolution "1024"
118 |         num_imgs = H // sub_img_H * W // sub_img_H  # e.g. 2 * 2 = 4 views
119 |
120 | forward_args = dict(
121 | num_images_per_prompt=num_imgs,
122 | num_inference_steps=50,
123 | height=sub_img_H,
124 | width=sub_img_H,
125 | height_cond=H_cond,
126 | width_cond=W_cond,
127 | generator=generator,
128 | )
129 | if self.cfg.zero_snr:
130 | forward_args.update(guidance_rescale=0.7)
131 | return forward_args
132 |
133 | def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
134 | forward_args = self.get_forward_args()
135 | forward_args.update(pipeline_call_kwargs)
136 | return pipeline(**forward_args)
137 |
138 | def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
139 | raise NotImplementedError()
--------------------------------------------------------------------------------
/custum_3d_diffusion/trainings/utils.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig, OmegaConf
2 |
3 |
4 | def parse_structured(fields, cfg) -> DictConfig:
5 | scfg = OmegaConf.structured(fields(**cfg))
6 | return scfg
7 |
8 |
9 | def load_config(fields, config, extras=None):
10 | if extras is not None:
11 | print("Warning! extra parameter in cli is not verified, may cause erros.")
12 | if isinstance(config, str):
13 | cfg = OmegaConf.load(config)
14 | elif isinstance(config, dict):
15 | cfg = OmegaConf.create(config)
16 | elif isinstance(config, DictConfig):
17 | cfg = config
18 | else:
19 | raise NotImplementedError(f"Unsupported config type {type(config)}")
20 | if extras is not None:
21 | cli_conf = OmegaConf.from_cli(extras)
22 | cfg = OmegaConf.merge(cfg, cli_conf)
23 | OmegaConf.resolve(cfg)
24 | assert isinstance(cfg, DictConfig)
25 | return parse_structured(fields, cfg)
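26 | 
27 | # Usage sketch (the config path is hypothetical):
28 | #   from custum_3d_diffusion.trainings.config_classes import ExprimentConfig
29 | #   cfg = load_config(ExprimentConfig, "configs/train.yaml")
30 | #   cfg = load_config(ExprimentConfig, "configs/train.yaml", extras=["seed=42"])  # with CLI-style overrides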
--------------------------------------------------------------------------------
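`load_config` above accepts a YAML path, a plain dict, or a `DictConfig`, optionally merges `key=value` CLI extras, and then validates the result against a structured schema through `parse_structured`. A minimal sketch, assuming a hypothetical dataclass schema (`DemoConfig` is not part of the repo; the real trainers define their own config classes):

```python
from dataclasses import dataclass
from custum_3d_diffusion.trainings.utils import load_config

# Hypothetical schema for illustration only.
@dataclass
class DemoConfig:
    pretrained_model_name_or_path: str = ""
    resolution: int = 512
    zero_snr: bool = False

# dict input; extras follow OmegaConf dotlist syntax and override the base config
cfg = load_config(DemoConfig, {"resolution": 1024}, extras=["zero_snr=true"])
print(cfg.resolution, cfg.zero_snr)  # 1024 True
```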
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # use the NVIDIA CUDA 12.1 runtime image (Ubuntu 22.04) as the base
2 | FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
3 |
4 | LABEL name="unique3d" maintainer="unique3d"
5 |
6 | # create workspace folder and set it as working directory
7 | RUN mkdir -p /workspace
8 | WORKDIR /workspace
9 |
10 | # update package lists and install build-essential, git, wget, vim, libegl1-mesa-dev, libglib2.0-0, unzip, and git-lfs
11 | RUN apt-get update && apt-get install -y build-essential git wget vim libegl1-mesa-dev libglib2.0-0 unzip git-lfs
12 |
13 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev cmake curl mesa-utils-extra
14 | ENV PYTHONDONTWRITEBYTECODE=1
15 | ENV PYTHONUNBUFFERED=1
16 | ENV LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
17 | ENV PYOPENGL_PLATFORM=egl
18 |
19 | # install conda
20 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
21 | chmod +x Miniconda3-latest-Linux-x86_64.sh && \
22 | ./Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 && \
23 | rm Miniconda3-latest-Linux-x86_64.sh
24 |
25 | # update PATH environment variable
26 | ENV PATH="/workspace/miniconda3/bin:${PATH}"
27 |
28 | # initialize conda
29 | RUN conda init bash
30 |
31 | # create and activate conda environment
32 | RUN conda create -n unique3d python=3.10 && echo "source activate unique3d" > ~/.bashrc
33 | ENV PATH /workspace/miniconda3/envs/unique3d/bin:$PATH
34 |
35 | RUN conda install -y Ninja
36 | RUN conda install cuda -c nvidia/label/cuda-12.1.0 -y
37 |
38 | RUN pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 xformers triton --index-url https://download.pytorch.org/whl/cu121
39 | RUN pip install diffusers==0.27.2
40 |
41 | RUN git clone --depth 1 https://huggingface.co/spaces/Wuvin/Unique3D
42 |
43 | # change the working directory to the repository
44 |
45 | WORKDIR /workspace/Unique3D
46 | # other dependencies
47 | RUN pip install -r requirements.txt
48 |
49 | RUN pip install nvidia-pyindex
50 |
51 | RUN pip install --upgrade nvidia-tensorrt
52 |
53 | RUN pip install spaces
54 |
55 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Docker setup
2 |
3 | This docker setup is tested on Windows 10.
4 |
5 | Make sure you are in the `yourworkspace/Unique3D/docker` directory.
6 |
7 | Build the docker image:
8 |
9 | ```
10 | docker build -t unique3d -f Dockerfile .
11 | ```
12 |
13 | Run the docker image for the first time:
14 |
15 | ```
16 | docker run -it --name unique3d -p 7860:7860 --gpus all unique3d python app.py
17 | ```
18 |
19 | After the first run:
20 | ```
21 | docker start unique3d
22 | docker exec unique3d python app.py
23 | ```
24 |
25 | Stop the container:
26 | ```
27 | docker stop unique3d
28 | ```
29 |
30 | You can find the demo link in the terminal output, such as `https://94fc1ba77a08526e17.gradio.live/` or something similar (the link changes each time the container is restarted); open it to use the demo.
31 |
32 | Some notes:
33 | 1. This docker build clones the source from https://huggingface.co/spaces/Wuvin/Unique3D rather than from this repo.
34 | 2. The total build time may exceed one hour.
35 | 3. The built image is larger than 70 GB.
--------------------------------------------------------------------------------
/gradio_app.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import os
3 | import sys
4 | sys.path.append(os.curdir)
5 | import torch
6 | torch.set_float32_matmul_precision('medium')
7 | torch.backends.cuda.matmul.allow_tf32 = True
8 | torch.set_grad_enabled(False)
9 |
10 | import fire
11 | import gradio as gr
12 | from app.gradio_3dgen import create_ui as create_3d_ui
13 | from app.all_models import model_zoo
14 |
15 |
16 | _TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
17 | _DESCRIPTION = '''
18 | [Project page](https://wukailu.github.io/Unique3D/)
19 |
20 | * High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
21 |
22 | * The demo is still under construction, and more features are expected to be implemented soon.
23 | '''
24 |
25 | def launch():
26 | model_zoo.init_models()
27 |
28 | with gr.Blocks(
29 | title=_TITLE,
30 | theme=gr.themes.Monochrome(),
31 | ) as demo:
32 | with gr.Row():
33 | with gr.Column(scale=1):
34 | gr.Markdown('# ' + _TITLE)
35 | gr.Markdown(_DESCRIPTION)
36 | create_3d_ui("wkl")
37 |
38 | demo.queue().launch(share=True)
39 |
40 | if __name__ == '__main__':
41 | fire.Fire(launch)
42 |
--------------------------------------------------------------------------------
/install_windows_win_py311_cu121.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | set "triton_whl=%~dp0\triton-2.1.0-cp311-cp311-win_amd64.whl"
4 |
5 | echo Starting to install Unique3D...
6 |
7 | echo Installing torch, xformers, etc
8 |
9 | pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121
10 |
11 | echo Installing triton
12 |
13 | pip install "%triton_whl%"
14 |
15 | pip install Ninja
16 |
17 | pip install diffusers==0.27.2
18 |
19 | pip install grpcio werkzeug tensorboard-data-server
20 |
21 | pip install -r requirements-win-py311-cu121.txt
22 |
23 | echo Removing default onnxruntime and onnxruntime-gpu
24 |
25 | pip uninstall -y onnxruntime
26 | pip uninstall -y onnxruntime-gpu
27 |
29 | echo Installing the correct onnxruntime-gpu version for CUDA 12.1
29 |
30 | pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
31 |
32 | echo Install Finished. Press any key to continue...
33 |
34 | pause
--------------------------------------------------------------------------------
/mesh_reconstruction/func.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/Profactor/continuous-remeshing
2 | import torch
3 | import numpy as np
4 | import trimesh
5 | from typing import Tuple
6 |
7 | def to_numpy(*args):
8 | def convert(a):
9 | if isinstance(a,torch.Tensor):
10 | return a.detach().cpu().numpy()
11 | assert a is None or isinstance(a,np.ndarray)
12 | return a
13 |
14 | return convert(args[0]) if len(args)==1 else tuple(convert(a) for a in args)
15 |
16 | def laplacian(
17 | num_verts:int,
18 | edges: torch.Tensor #E,2
19 | ) -> torch.Tensor: #sparse V,V
20 | """create sparse Laplacian matrix"""
21 | V = num_verts
22 | E = edges.shape[0]
23 |
24 | #adjacency matrix,
25 | idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T # (2, 2*E)
26 | ones = torch.ones(2*E, dtype=torch.float32, device=edges.device)
27 | A = torch.sparse.FloatTensor(idx, ones, (V, V))
28 |
29 | #degree matrix
30 | deg = torch.sparse.sum(A, dim=1).to_dense()
31 | idx = torch.arange(V, device=edges.device)
32 | idx = torch.stack([idx, idx], dim=0)
33 | D = torch.sparse.FloatTensor(idx, deg, (V, V))
34 |
35 | return D - A
36 |
37 | def _translation(x, y, z, device):
38 | return torch.tensor([[1., 0, 0, x],
39 | [0, 1, 0, y],
40 | [0, 0, 1, z],
41 | [0, 0, 0, 1]],device=device) #4,4
42 |
43 | def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
44 | """
45 | see https://blog.csdn.net/wodownload2/article/details/85069240/
46 | """
47 | if l is None:
48 | l = -r
49 | if t is None:
50 | t = r
51 | if b is None:
52 | b = -t
53 | p = torch.zeros([4,4],device=device)
54 | p[0,0] = 2*n/(r-l)
55 | p[0,2] = (r+l)/(r-l)
56 | p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1)
57 | p[1,2] = (t+b)/(t-b)
58 | p[2,2] = -(f+n)/(f-n)
59 | p[2,3] = -(2*f*n)/(f-n)
60 | p[3,2] = -1
61 | return p #4,4
62 |
63 | def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
64 | if l is None:
65 | l = -r
66 | if t is None:
67 | t = r
68 | if b is None:
69 | b = -t
70 | o = torch.zeros([4,4],device=device)
71 | o[0,0] = 2/(r-l)
72 | o[0,3] = -(r+l)/(r-l)
73 | o[1,1] = 2/(t-b) * (-1 if flip_y else 1)
74 | o[1,3] = -(t+b)/(t-b)
75 | o[2,2] = -2/(f-n)
76 | o[2,3] = -(f+n)/(f-n)
77 | o[3,3] = 1
78 | return o #4,4
79 |
80 | def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
81 | if r is None:
82 | r = 1/distance
83 | A = az_count
84 | P = pol_count
85 | C = A * P
86 |
87 | phi = torch.arange(0,A) * (2*torch.pi/A)
88 | phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone()
89 | phi_rot[:,0,2,2] = phi.cos()
90 | phi_rot[:,0,2,0] = -phi.sin()
91 | phi_rot[:,0,0,2] = phi.sin()
92 | phi_rot[:,0,0,0] = phi.cos()
93 |
94 | theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2
95 | theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone()
96 | theta_rot[0,:,1,1] = theta.cos()
97 | theta_rot[0,:,1,2] = -theta.sin()
98 | theta_rot[0,:,2,1] = theta.sin()
99 | theta_rot[0,:,2,2] = theta.cos()
100 |
101 | mv = torch.empty((C,4,4), device=device)
102 | mv[:] = torch.eye(4, device=device)
103 | mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3)
104 | mv = _translation(0, 0, -distance, device) @ mv
105 |
106 | return mv, _projection(r,device)
107 |
108 | def make_star_cameras_orthographic(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
109 | mv, _ = make_star_cameras(az_count,pol_count,distance,r,image_size,device)
110 | if r is None:
111 | r = 1
112 | return mv, _orthographic(r,device)
113 |
114 | def make_sphere(level:int=2,radius=1.,device='cuda') -> Tuple[torch.Tensor,torch.Tensor]:
115 | sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None)
116 | vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius
117 | faces = torch.tensor(sphere.faces, device=device, dtype=torch.long)
118 | return vertices,faces
119 |
120 | from pytorch3d.renderer import (
121 | FoVOrthographicCameras,
122 | look_at_view_transform,
123 | )
124 |
125 | def get_camera(R, T, focal_length=1 / (2**0.5)):
126 | focal_length = 1 / focal_length
127 | camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
128 | return camera
129 |
130 | def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1):
131 | R, T = look_at_view_transform(dist, 0, azim_list)
132 | focal_length = 1 / focal
133 | return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device)
134 |
--------------------------------------------------------------------------------
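`make_star_cameras_orthographic` above produces one model-view matrix per azimuth/polar combination plus a shared orthographic projection; the reconstruction code calls it with 4 azimuths and 1 polar angle to get its four fixed views. A minimal shape-check sketch (assumes a CUDA device, matching the function defaults):

```python
import torch
from mesh_reconstruction.func import make_star_cameras_orthographic, make_sphere

mv, proj = make_star_cameras_orthographic(4, 1)    # 4 azimuths x 1 polar angle = 4 cameras
print(mv.shape, proj.shape)                        # torch.Size([4, 4, 4]) torch.Size([4, 4])

vertices, faces = make_sphere(level=2, radius=0.5) # icosphere used as a coarse init
mvp = proj @ mv                                    # per-camera clip-space transform, (4, 4, 4)
print(mvp.shape, vertices.shape, faces.shape)
```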
/mesh_reconstruction/opt.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/Profactor/continuous-remeshing
2 | import time
3 | import torch
4 | import torch_scatter
5 | from typing import Tuple
6 | from mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges
7 |
8 | @torch.no_grad()
9 | def remesh(
10 | vertices_etc:torch.Tensor, #V,D
11 | faces:torch.Tensor, #F,3 long
12 | min_edgelen:torch.Tensor, #V
13 | max_edgelen:torch.Tensor, #V
14 | flip:bool,
15 | max_vertices=1e6
16 | ):
17 |
18 | # dummies
19 | vertices_etc,faces = prepend_dummies(vertices_etc,faces)
20 | vertices = vertices_etc[:,:3] #V,3
21 | nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
22 | min_edgelen = torch.concat((nan_tensor,min_edgelen))
23 | max_edgelen = torch.concat((nan_tensor,max_edgelen))
24 |
25 | # collapse
26 | edges,face_to_edge = calc_edges(faces) #E,2 F,3
27 | edge_length = calc_edge_length(vertices,edges) #E
28 | face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
29 | vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
30 | face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
31 | shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
32 | priority = face_collapse.float() + shortness
33 | vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)
34 |
35 | # split
36 |     if vertices.shape[0]<max_vertices:
37 |         edges,face_to_edge = calc_edges(faces) #E,2 F,3
38 |         vertices = vertices_etc[:,:3] #V,3
39 |         edge_length = calc_edge_length(vertices,edges) #E
40 |         splits = edge_length > max_edgelen[edges].mean(dim=-1)
41 | vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)
42 |
43 | vertices_etc,faces = pack(vertices_etc,faces)
44 | vertices = vertices_etc[:,:3]
45 |
46 | if flip:
47 | edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
48 | flip_edges(vertices,faces,edges,edge_to_face,with_border=False)
49 |
50 | return remove_dummies(vertices_etc,faces)
51 |
52 | def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
53 | """lerp with adam's bias correction"""
54 | c_prev = 1-weight**(step-1)
55 | c = 1-weight**step
56 | a_weight = weight*c_prev/c
57 | b_weight = (1-weight)/c
58 | a.mul_(a_weight).add_(b, alpha=b_weight)
59 |
60 |
61 | class MeshOptimizer:
62 | """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""
63 |
64 | def __init__(self,
65 | vertices:torch.Tensor, #V,3
66 | faces:torch.Tensor, #F,3
67 | lr=0.3, #learning rate
68 | betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
69 | gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
70 | nu_ref=0.3, #reference velocity for edge length controller
71 | edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length
72 | edge_len_tol=.5, #edge length tolerance for split and collapse
73 | gain=.2, #gain value for edge length controller
74 | laplacian_weight=.02, #for laplacian smoothing/regularization
75 | ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0])
76 | grad_lim=10., #gradients are clipped to m1.abs()*grad_lim
77 | remesh_interval=1, #larger intervals are faster but with worse mesh quality
78 | local_edgelen=True, #set to False to use a global scalar reference edge length instead
79 | ):
80 | self._vertices = vertices
81 | self._faces = faces
82 | self._lr = lr
83 | self._betas = betas
84 | self._gammas = gammas
85 | self._nu_ref = nu_ref
86 | self._edge_len_lims = edge_len_lims
87 | self._edge_len_tol = edge_len_tol
88 | self._gain = gain
89 | self._laplacian_weight = laplacian_weight
90 | self._ramp = ramp
91 | self._grad_lim = grad_lim
92 | self._remesh_interval = remesh_interval
93 | self._local_edgelen = local_edgelen
94 | self._step = 0
95 |
96 | V = self._vertices.shape[0]
97 | # prepare continuous tensor for all vertex-based data
98 | self._vertices_etc = torch.zeros([V,9],device=vertices.device)
99 | self._split_vertices_etc()
100 | self.vertices.copy_(vertices) #initialize vertices
101 | self._vertices.requires_grad_()
102 | self._ref_len.fill_(edge_len_lims[1])
103 |
104 | @property
105 | def vertices(self):
106 | return self._vertices
107 |
108 | @property
109 | def faces(self):
110 | return self._faces
111 |
112 | def _split_vertices_etc(self):
113 | self._vertices = self._vertices_etc[:,:3]
114 | self._m2 = self._vertices_etc[:,3]
115 | self._nu = self._vertices_etc[:,4]
116 | self._m1 = self._vertices_etc[:,5:8]
117 | self._ref_len = self._vertices_etc[:,8]
118 |
119 | with_gammas = any(g!=0 for g in self._gammas)
120 | self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3]
121 |
122 | def zero_grad(self):
123 | self._vertices.grad = None
124 |
125 | @torch.no_grad()
126 | def step(self):
127 |
128 | eps = 1e-8
129 |
130 | self._step += 1
131 |
132 | # spatial smoothing
133 | edges,_ = calc_edges(self._faces) #E,2
134 | E = edges.shape[0]
135 | edge_smooth = self._smooth[edges] #E,2,S
136 | neighbor_smooth = torch.zeros_like(self._smooth) #V,S
137 | torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth)
138 |
139 | #apply optional smoothing of m1,m2,nu
140 | if self._gammas[0]:
141 | self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0])
142 | if self._gammas[1]:
143 | self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1])
144 | if self._gammas[2]:
145 | self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2])
146 |
147 | #add laplace smoothing to gradients
148 | laplace = self._vertices - neighbor_smooth[:,:3]
149 | grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight)
150 |
151 | #gradient clipping
152 | if self._step>1:
153 | grad_lim = self._m1.abs().mul_(self._grad_lim)
154 | grad.clamp_(min=-grad_lim,max=grad_lim)
155 |
156 | # moment updates
157 | lerp_unbiased(self._m1, grad, self._betas[0], self._step)
158 | lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)
159 |
160 | velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3
161 | speed = velocity.norm(dim=-1) #V
162 |
163 | if self._betas[2]:
164 | lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V
165 | else:
166 | self._nu.copy_(speed) #V
167 |
168 | # update vertices
169 | ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp)
170 | self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr)
171 |
172 | # update target edge length
173 | if self._step % self._remesh_interval == 0:
174 | if self._local_edgelen:
175 | len_change = (1 + (self._nu - self._nu_ref) * self._gain)
176 | else:
177 | len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
178 | self._ref_len *= len_change
179 | self._ref_len.clamp_(*self._edge_len_lims)
180 |
181 | def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]:
182 | min_edge_len = self._ref_len * (1 - self._edge_len_tol)
183 | max_edge_len = self._ref_len * (1 + self._edge_len_tol)
184 |
185 | self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e6)
186 |
187 | self._split_vertices_etc()
188 | self._vertices.requires_grad_()
189 |
190 | return self._vertices, self._faces
191 |
--------------------------------------------------------------------------------
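As the `MeshOptimizer` docstring above says, it is used like a PyTorch optimizer except that each iteration ends with `opt.remesh()`, which may change both the vertices and the topology. A minimal loop sketch with a placeholder objective (the real losses live in `mesh_reconstruction/recon.py` and `refine.py`); assumes a CUDA device:

```python
from mesh_reconstruction.func import make_sphere
from mesh_reconstruction.opt import MeshOptimizer

vertices, faces = make_sphere(level=2, radius=0.5)     # coarse init on CUDA
opt = MeshOptimizer(vertices, faces, local_edgelen=False)
vertices = opt.vertices                                # optimizer-managed tensor with grad

for step in range(100):
    opt.zero_grad()
    # placeholder objective: pull every vertex onto the unit sphere
    loss = (vertices.norm(dim=-1) - 1.0).pow(2).mean()
    loss.backward()
    opt.step()
    vertices, faces = opt.remesh()                     # topology may change every call
```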
/mesh_reconstruction/recon.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from PIL import Image
3 | import numpy as np
4 | import torch
5 | from typing import List
6 | from mesh_reconstruction.remesh import calc_vertex_normals
7 | from mesh_reconstruction.opt import MeshOptimizer
8 | from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
9 | from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
10 | from scripts.utils import to_py3d_mesh, init_target
11 |
12 | def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
13 | vertices, faces = vertices.to("cuda"), faces.to("cuda")
14 | assert len(pils) == 4
15 | mv,proj = make_star_cameras_orthographic(4, 1)
16 | renderer = NormalsRenderer(mv,proj,list(pils[0].size))
17 | # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
18 | # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda")
19 |
20 | target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
21 | # 1. no rotate
22 | target_images = target_images[[0, 3, 2, 1]]
23 |
24 | # 2. init from coarse mesh
25 | opt = MeshOptimizer(vertices,faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len))
26 |
27 | vertices = opt.vertices
28 |
29 | mask = target_images[..., -1] < 0.5
30 |
31 | for i in tqdm(range(steps)):
32 | opt.zero_grad()
33 | opt._lr *= decay
34 | normals = calc_vertex_normals(vertices,faces)
35 | images = renderer.render(vertices,normals,faces)
36 |
37 | loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean()
38 |
39 | t_mask = images[..., -1] > 0.5
40 | loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean()
41 | loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()
42 |
43 | loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight
44 |
45 | # out of box
46 | loss_oob = (vertices.abs() > 0.99).float().mean() * 10
47 | loss = loss + loss_oob
48 |
49 | loss.backward()
50 | opt.step()
51 |
52 | vertices,faces = opt.remesh(poisson=False)
53 |
54 | vertices, faces = vertices.detach(), faces.detach()
55 |
56 | if return_mesh:
57 | return to_py3d_mesh(vertices, faces)
58 | else:
59 | return vertices, faces
60 |
--------------------------------------------------------------------------------
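`reconstruct_stage1` above expects exactly four multi-view normal maps plus a coarse initial mesh, optimizes the mesh against them, and returns either a PyTorch3D mesh or raw tensors. A minimal call sketch, assuming `view0.png` ... `view3.png` are hypothetical RGBA normal maps in the order produced by the repo's multi-view predictor:

```python
from PIL import Image
from mesh_reconstruction.func import make_sphere
from mesh_reconstruction.recon import reconstruct_stage1

# hypothetical file names; the real pipeline passes images straight from the normal predictor
pils = [Image.open(f"view{i}.png") for i in range(4)]
vertices, faces = make_sphere(level=2, radius=0.5)

vertices, faces = reconstruct_stage1(
    pils, steps=100, vertices=vertices, faces=faces, return_mesh=False
)
print(vertices.shape, faces.shape)
```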
/mesh_reconstruction/refine.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from PIL import Image
3 | import torch
4 | from typing import List
5 | from mesh_reconstruction.remesh import calc_vertex_normals
6 | from mesh_reconstruction.opt import MeshOptimizer
7 | from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
8 | from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
9 | from scripts.project_mesh import multiview_color_projection, get_cameras_list
10 | from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target
11 |
12 | def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True):
13 | if process_inputs:
14 | vertices = vertices * 2 / 1.35
15 | vertices[..., [0, 2]] = - vertices[..., [0, 2]]
16 |
17 | poission_steps = []
18 |
19 | assert len(pils) == 4
20 | mv,proj = make_star_cameras_orthographic(4, 1)
21 | renderer = NormalsRenderer(mv,proj,list(pils[0].size))
22 | # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
23 | # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda")
24 |
25 | target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
26 | # 1. no rotate
27 | target_images = target_images[[0, 3, 2, 1]]
28 |
29 | # 2. init from coarse mesh
30 | opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02)
31 |
32 | vertices = opt.vertices
33 | alpha_init = None
34 |
35 | mask = target_images[..., -1] < 0.5
36 |
37 | for i in tqdm(range(steps)):
38 | opt.zero_grad()
39 | opt._lr *= decay
40 | normals = calc_vertex_normals(vertices,faces)
41 | images = renderer.render(vertices,normals,faces)
42 | if alpha_init is None:
43 | alpha_init = images.detach()
44 |
45 | if i < update_warmup or i % update_normal_interval == 0:
46 | with torch.no_grad():
47 | py3d_mesh = to_py3d_mesh(vertices, faces, normals)
48 | cameras = get_cameras_list(azim_list = [0, 90, 180, 270], device=vertices.device, focal=1.)
49 | _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=[2.0, 0.8, 1.0, 0.8], confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear'))
50 | target_normal = target_normal * 2 - 1
51 | target_normal = torch.nn.functional.normalize(target_normal, dim=-1)
52 | debug_images = renderer.render(vertices,target_normal,faces)
53 |
54 | d_mask = images[..., -1] > 0.5
55 | loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean()
56 |
57 | loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()
58 |
59 | loss = loss_debug_l2 + loss_alpha_target_mask_l2
60 |
61 | # out of box
62 | loss_oob = (vertices.abs() > 0.99).float().mean() * 10
63 | loss = loss + loss_oob
64 |
65 | loss.backward()
66 | opt.step()
67 |
68 | vertices,faces = opt.remesh(poisson=(i in poission_steps))
69 |
70 | vertices, faces = vertices.detach(), faces.detach()
71 |
72 | if process_outputs:
73 | vertices = vertices / 2 * 1.35
74 | vertices[..., [0, 2]] = - vertices[..., [0, 2]]
75 |
76 | if return_mesh:
77 | return to_py3d_mesh(vertices, faces)
78 | else:
79 | return vertices, faces
80 |
--------------------------------------------------------------------------------
/mesh_reconstruction/remesh.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/Profactor/continuous-remeshing
2 | import torch
3 | import torch.nn.functional as tfunc
4 | import torch_scatter
5 | from typing import Tuple
6 |
7 | def prepend_dummies(
8 | vertices:torch.Tensor, #V,D
9 | faces:torch.Tensor, #F,3 long
10 | )->Tuple[torch.Tensor,torch.Tensor]:
11 | """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
12 | V,D = vertices.shape
13 | vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
14 | faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
15 | return vertices,faces
16 |
17 | def remove_dummies(
18 | vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
19 | faces:torch.Tensor, #F,3 long - first face all zeros
20 | )->Tuple[torch.Tensor,torch.Tensor]:
21 | """remove dummy elements added with prepend_dummies()"""
22 | return vertices[1:],faces[1:]-1
23 |
24 |
25 | def calc_edges(
26 | faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros
27 | with_edge_to_face: bool = False
28 | ) -> Tuple[torch.Tensor, ...]:
29 | """
30 | returns Tuple of
31 | - edges E,2 long, 0 for unused, lower vertex index first
32 | - face_to_edge F,3 long
33 | - (optional) edge_to_face shape=E,[left,right],[face,side]
34 |
35 |     o-<-----e1     e0,e1...edge, e0<e1 (lower vertex index first)
36 |                    L,R...left and right face of the directed edge e0->e1
37 |                    (faces are wound counterclockwise, so the face whose winding
38 |                    traverses e0->e1 is the left one and the face traversing
39 |                    e1->e0 is the right one)
40 |     e0---->-o
41 | """
42 |
43 | F = faces.shape[0]
44 |
45 | # make full edges, lower vertex index first
46 | face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2
47 | full_edges = face_edges.reshape(F*3,2)
48 | sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2
49 |
50 | # make unique edges
51 | edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
52 | E = edges.shape[0]
53 | face_to_edge = full_to_unique.reshape(F,3) #F,3
54 |
55 | if not with_edge_to_face:
56 | return edges, face_to_edge
57 |
58 | is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
59 | edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
60 | scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
61 | edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
62 | edge_to_face[0] = 0
63 | return edges, face_to_edge, edge_to_face
64 |
65 | def calc_edge_length(
66 | vertices:torch.Tensor, #V,3 first may be dummy
67 | edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
68 | )->torch.Tensor: #E
69 |
70 | full_vertices = vertices[edges] #E,2,3
71 | a,b = full_vertices.unbind(dim=1) #E,3
72 | return torch.norm(a-b,p=2,dim=-1)
73 |
74 | def calc_face_normals(
75 | vertices:torch.Tensor, #V,3 first vertex may be unreferenced
76 | faces:torch.Tensor, #F,3 long, first face may be all zero
77 | normalize:bool=False,
78 | )->torch.Tensor: #F,3
79 | """
80 | n
81 | |
82 | c0 corners ordered counterclockwise when
83 | / \ looking onto surface (in neg normal direction)
84 | c1---c2
85 | """
86 | full_vertices = vertices[faces] #F,C=3,3
87 | v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
88 | face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
89 | if normalize:
90 | face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1)
91 | return face_normals #F,3
92 |
93 | def calc_vertex_normals(
94 | vertices:torch.Tensor, #V,3 first vertex may be unreferenced
95 | faces:torch.Tensor, #F,3 long, first face may be all zero
96 | face_normals:torch.Tensor=None, #F,3, not normalized
97 | )->torch.Tensor: #F,3
98 |
99 | F = faces.shape[0]
100 |
101 | if face_normals is None:
102 | face_normals = calc_face_normals(vertices,faces)
103 |
104 | vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
105 | vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3))
106 | vertex_normals = vertex_normals.sum(dim=1) #V,3
107 | return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)
108 |
109 | def calc_face_ref_normals(
110 | faces:torch.Tensor, #F,3 long, 0 for unused
111 | vertex_normals:torch.Tensor, #V,3 first unused
112 | normalize:bool=False,
113 | )->torch.Tensor: #F,3
114 | """calculate reference normals for face flip detection"""
115 | full_normals = vertex_normals[faces] #F,C=3,3
116 | ref_normals = full_normals.sum(dim=1) #F,3
117 | if normalize:
118 | ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
119 | return ref_normals
120 |
121 | def pack(
122 | vertices:torch.Tensor, #V,3 first unused and nan
123 | faces:torch.Tensor, #F,3 long, 0 for unused
124 | )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
125 | """removes unused elements in vertices and faces"""
126 | V = vertices.shape[0]
127 |
128 | # remove unused faces
129 | used_faces = faces[:,0]!=0
130 | used_faces[0] = True
131 | faces = faces[used_faces] #sync
132 |
133 | # remove unused vertices
134 | used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
135 | used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add')
136 | used_vertices = used_vertices.any(dim=1)
137 | used_vertices[0] = True
138 | vertices = vertices[used_vertices] #sync
139 |
140 | # update used faces
141 | ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
142 | V1 = used_vertices.sum()
143 | ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync
144 | faces = ind[faces]
145 |
146 | return vertices,faces
147 |
148 | def split_edges(
149 | vertices:torch.Tensor, #V,3 first unused
150 | faces:torch.Tensor, #F,3 long, 0 for unused
151 | edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
152 | face_to_edge:torch.Tensor, #F,3 long 0 for unused
153 | splits, #E bool
154 | pack_faces:bool=True,
155 | )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
156 |
157 | # c2 c2 c...corners = faces
158 | # . . . . s...side_vert, 0 means no split
159 | # . . .N2 . S...shrunk_face
160 | # . . . . Ni...new_faces
161 | # s2 s1 s2|c2...s1|c1
162 | # . . . . .
163 | # . . . S . .
164 | # . . . . N1 .
165 | # c0...(s0=0)....c1 s0|c0...........c1
166 | #
167 | # pseudo-code:
168 | # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2]
169 | # split = side_vert!=0 example:[False,True,True]
170 | # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0]
171 | # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0]
172 | # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1]
173 |
174 | V = vertices.shape[0]
175 | F = faces.shape[0]
176 | S = splits.sum().item() #sync
177 |
178 | if S==0:
179 | return vertices,faces
180 |
181 | edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
182 | edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
183 | side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
184 | split_edges = edges[splits] #S sync
185 |
186 | #vertices
187 | split_vertices = vertices[split_edges].mean(dim=1) #S,3
188 | vertices = torch.concat((vertices,split_vertices),dim=0)
189 |
190 | #faces
191 | side_split = side_vert!=0 #F,3
192 | shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
193 | new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
194 | faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
195 | if pack_faces:
196 | mask = faces[:,0]!=0
197 | mask[0] = True
198 | faces = faces[mask] #F',3 sync
199 |
200 | return vertices,faces
201 |
202 | def collapse_edges(
203 | vertices:torch.Tensor, #V,3 first unused
204 | faces:torch.Tensor, #F,3 long 0 for unused
205 | edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
206 | priorities:torch.Tensor, #E float
207 | stable:bool=False, #only for unit testing
208 | )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
209 |
210 | V = vertices.shape[0]
211 |
212 | # check spacing
213 | _,order = priorities.sort(stable=stable) #E
214 | rank = torch.zeros_like(order)
215 | rank[order] = torch.arange(0,len(rank),device=rank.device)
216 | vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
217 | edge_rank = rank #E
218 | for i in range(3):
219 | torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
220 | edge_rank,_ = vert_rank[edges].max(dim=-1) #E
221 | candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2
222 |
223 | # check connectivity
224 | vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
225 | vert_connections[candidates[:,0]] = 1 #start
226 | edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
227 | vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
228 | vert_connections[candidates] = 0 #clear start and end
229 | edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
230 | vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
231 | collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end
232 |
233 | # mean vertices
234 | vertices[collapses[:,0]] = vertices[collapses].mean(dim=1)
235 |
236 | # update faces
237 | dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
238 | dest[collapses[:,1]] = dest[collapses[:,0]]
239 | faces = dest[faces] #F,3
240 | c0,c1,c2 = faces.unbind(dim=-1)
241 | collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
242 | faces[collapsed] = 0
243 |
244 | return vertices,faces
245 |
246 | def calc_face_collapses(
247 | vertices:torch.Tensor, #V,3 first unused
248 | faces:torch.Tensor, #F,3 long, 0 for unused
249 | edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
250 | face_to_edge:torch.Tensor, #F,3 long 0 for unused
251 | edge_length:torch.Tensor, #E
252 | face_normals:torch.Tensor, #F,3
253 | vertex_normals:torch.Tensor, #V,3 first unused
254 | min_edge_length:torch.Tensor=None, #V
255 | area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
256 | shortest_probability = 0.8
257 | )->torch.Tensor: #E edges to collapse
258 |
259 | E = edges.shape[0]
260 | F = faces.shape[0]
261 |
262 | # face flips
263 | ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
264 | face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F
265 |
266 | # small faces
267 | if min_edge_length is not None:
268 | min_face_length = min_edge_length[faces].mean(dim=-1) #F
269 | min_area = min_face_length**2 * area_ratio #F
270 | face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
271 | face_collapses[0] = False
272 |
273 | # faces to edges
274 | face_length = edge_length[face_to_edge] #F,3
275 |
276 | if shortest_probability<1:
277 | #select shortest edge with shortest_probability chance
278 | randlim = round(2/(1-shortest_probability))
279 | rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
280 | sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
281 | local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
282 | else:
283 | local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face
284 |
285 | edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
286 | edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
287 | edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long())
288 |
289 | return edge_collapses.bool()
290 |
291 | def flip_edges(
292 | vertices:torch.Tensor, #V,3 first unused
293 | faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
294 | edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
295 | edge_to_face:torch.Tensor, #E,[left,right],[face,side]
296 | with_border:bool=True, #handle border edges (D=4 instead of D=6)
297 | with_normal_check:bool=True, #check face normal flips
298 | stable:bool=False, #only for unit testing
299 | ):
300 | V = vertices.shape[0]
301 | E = edges.shape[0]
302 | device=vertices.device
303 | vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
304 | vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
305 | neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
306 | neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
307 | edge_is_inside = neighbors.all(dim=-1) #E
308 |
309 | if with_border:
310 | # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
311 | # need to use float for masks in order to use scatter(reduce='multiply')
312 | vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
313 | src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
314 | vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
315 | vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
316 | vertex_degree -= 2 * vertex_is_inside #V long
317 |
318 | neighbor_degrees = vertex_degree[neighbors] #E,LR=2
319 | edge_degrees = vertex_degree[edges] #E,2
320 | #
321 | # loss = Sum_over_affected_vertices((new_degree-6)**2)
322 | # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
323 | # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
324 | # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
325 | #
326 | loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
327 | candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
328 | loss_change = loss_change[candidates] #E'
329 | if loss_change.shape[0]==0:
330 | return
331 |
332 | edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
333 | _,order = loss_change.sort(descending=True, stable=stable) #E'
334 | rank = torch.zeros_like(order)
335 | rank[order] = torch.arange(0,len(rank),device=rank.device)
336 | vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
337 | torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
338 | vertex_rank,_ = vertex_rank.max(dim=-1) #V
339 | neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
340 | flip = rank==neighborhood_rank #E'
341 |
342 | if with_normal_check:
343 | # cl-<-----e1 e0,e1...edge, e0-cr
349 | v = vertices[edges_neighbors] #E",4,3
350 | v = v - v[:,0:1] #make relative to e0
351 | e1 = v[:,1]
352 | cl = v[:,2]
353 | cr = v[:,3]
354 | n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors
355 | flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face
356 | flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face
357 |
358 | flip_edges_neighbors = edges_neighbors[flip] #E",4
359 | flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
360 | flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
361 | faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
362 |
--------------------------------------------------------------------------------
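The helpers above rely on the dummy-element trick: `prepend_dummies` adds a NaN vertex and an all-zero face so that index 0 can mean "unused" in the scatter operations, and `remove_dummies` strips them again. A minimal sketch on a two-triangle quad (pure PyTorch; the dummy edge (0,0) simply gets a NaN length):

```python
import torch
from mesh_reconstruction.remesh import prepend_dummies, remove_dummies, calc_edges, calc_edge_length

vertices = torch.tensor([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]])
faces = torch.tensor([[0, 1, 2], [0, 2, 3]], dtype=torch.long)

v, f = prepend_dummies(vertices, faces)   # v[0] is all-NaN, f[0] is all-zero
edges, face_to_edge = calc_edges(f)       # unique undirected edges, lower vertex index first
lengths = calc_edge_length(v, edges)      # lengths[0] is NaN (dummy edge)
print(edges.shape, face_to_edge.shape)    # torch.Size([6, 2]) torch.Size([3, 3])

v2, f2 = remove_dummies(v, f)
assert torch.equal(f2, faces) and torch.equal(v2, vertices)
```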
/mesh_reconstruction/render.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/Profactor/continuous-remeshing
2 | import nvdiffrast.torch as dr
3 | import torch
4 | from typing import Tuple
5 |
6 | def _warmup(glctx, device=None):
7 | device = 'cuda' if device is None else device
8 | #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
9 | def tensor(*args, **kwargs):
10 | return torch.tensor(*args, device=device, **kwargs)
11 | pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
12 | tri = tensor([[0, 1, 2]], dtype=torch.int32)
13 | dr.rasterize(glctx, pos, tri, resolution=[256, 256])
14 |
15 | glctx = dr.RasterizeGLContext(output_db=False, device="cuda")
16 |
17 | class NormalsRenderer:
18 |
19 | _glctx:dr.RasterizeGLContext = None
20 |
21 | def __init__(
22 | self,
23 | mv: torch.Tensor, #C,4,4
24 | proj: torch.Tensor, #C,4,4
25 | image_size: Tuple[int,int],
26 | mvp = None,
27 | device=None,
28 | ):
29 | if mvp is None:
30 | self._mvp = proj @ mv #C,4,4
31 | else:
32 | self._mvp = mvp
33 | self._image_size = image_size
34 | self._glctx = glctx
35 | _warmup(self._glctx, device)
36 |
37 | def render(self,
38 | vertices: torch.Tensor, #V,3 float
39 | normals: torch.Tensor, #V,3 float in [-1, 1]
40 | faces: torch.Tensor, #F,3 long
41 | ) ->torch.Tensor: #C,H,W,4
42 |
43 | V = vertices.shape[0]
44 | faces = faces.type(torch.int32)
45 | vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4
46 | vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4
47 | rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4
48 | vert_col = (normals+1)/2 #V,3
49 | col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3
50 | alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1
51 | col = torch.concat((col,alpha),dim=-1) #C,H,W,4
52 | col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
53 | return col #C,H,W,4
54 |
55 |
56 |
57 | from pytorch3d.structures import Meshes
58 | from pytorch3d.renderer.mesh.shader import ShaderBase
59 | from pytorch3d.renderer import (
60 | RasterizationSettings,
61 | MeshRendererWithFragments,
62 | TexturesVertex,
63 | MeshRasterizer,
64 | BlendParams,
65 | FoVOrthographicCameras,
66 | look_at_view_transform,
67 | hard_rgb_blend,
68 | )
69 |
70 | class VertexColorShader(ShaderBase):
71 | def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
72 | blend_params = kwargs.get("blend_params", self.blend_params)
73 | texels = meshes.sample_textures(fragments)
74 | return hard_rgb_blend(texels, fragments, blend_params)
75 |
76 | def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"):
77 | if len(mesh) != len(cameras):
78 | if len(cameras) % len(mesh) == 0:
79 | mesh = mesh.extend(len(cameras))
80 | else:
81 | raise NotImplementedError()
82 |
83 | # render requires everything in float16 or float32
84 | input_dtype = dtype
85 | blend_params = BlendParams(1e-4, 1e-4, bkgd)
86 |
87 | # Define the settings for rasterization and shading
88 | raster_settings = RasterizationSettings(
89 | image_size=(H, W),
90 | blur_radius=blur_radius,
91 | faces_per_pixel=faces_per_pixel,
92 | clip_barycentric_coords=True,
93 | bin_size=None,
94 | max_faces_per_bin=None,
95 | )
96 |
97 | # Create a renderer by composing a rasterizer and a shader
98 |     # We simply render vertex colors through the custom VertexColorShader (no lighting or materials are used)
99 | renderer = MeshRendererWithFragments(
100 | rasterizer=MeshRasterizer(
101 | cameras=cameras,
102 | raster_settings=raster_settings
103 | ),
104 | shader=VertexColorShader(
105 | device=device,
106 | cameras=cameras,
107 | blend_params=blend_params
108 | )
109 | )
110 |
111 | # render RGB and depth, get mask
112 | with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type):
113 | images, _ = renderer(mesh)
114 | return images # BHW4
115 |
116 | class Pytorch3DNormalsRenderer: # 100 times slower!!!
117 | def __init__(self, cameras, image_size, device):
118 | self.cameras = cameras.to(device)
119 | self._image_size = image_size
120 | self.device = device
121 |
122 | def render(self,
123 | vertices: torch.Tensor, #V,3 float
124 | normals: torch.Tensor, #V,3 float in [-1, 1]
125 | faces: torch.Tensor, #F,3 long
126 | ) ->torch.Tensor: #C,H,W,4
127 | mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device)
128 | return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device)
129 |
130 | def save_tensor_to_img(tensor, save_dir):
131 | from PIL import Image
132 | import numpy as np
133 | for idx, img in enumerate(tensor):
134 | img = img[..., :3].cpu().numpy()
135 | img = (img * 255).astype(np.uint8)
136 | img = Image.fromarray(img)
137 | img.save(save_dir + f"{idx}.png")
138 |
139 | if __name__ == "__main__":
140 | import sys
141 | import os
142 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
143 | from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
144 | cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
145 | mv,proj = make_star_cameras_orthographic(4, 1)
146 | resolution = 1024
147 | renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda")
148 | renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda")
149 | vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32)
150 | normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32)
151 | faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long)
152 |
153 | import time
154 | t0 = time.time()
155 | r1 = renderer1.render(vertices, normals, faces)
156 | print("time r1:", time.time() - t0)
157 |
158 | t0 = time.time()
159 | r2 = renderer2.render(vertices, normals, faces)
160 | print("time r2:", time.time() - t0)
161 |
162 | for i in range(4):
163 | print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean())
--------------------------------------------------------------------------------
/requirements-detail.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.29.2
2 | datasets==2.18.0
3 | diffusers==0.27.2
4 | fire==0.6.0
5 | gradio==4.32.0
6 | jaxtyping==0.2.29
7 | numba==0.59.1
8 | numpy==1.26.4
9 | nvdiffrast==0.3.1
10 | omegaconf==2.3.0
11 | onnxruntime_gpu==1.17.0
12 | opencv_python==4.9.0.80
13 | opencv_python_headless==4.9.0.80
14 | ort_nightly_gpu==1.17.0.dev20240118002
15 | peft==0.10.0
16 | Pillow==10.3.0
17 | pygltflib==1.16.2
18 | pymeshlab==2023.12.post1
19 | pytorch3d==0.7.5
20 | rembg==2.0.56
21 | torch==2.1.0+cu121
22 | torch_scatter==2.1.2
23 | tqdm==4.64.1
24 | transformers==4.39.3
25 | trimesh==4.3.0
26 | typeguard==2.13.3
27 | wandb==0.16.6
28 |
--------------------------------------------------------------------------------
/requirements-win-py311-cu121.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | datasets
3 | fire
4 | gradio
5 | jaxtyping
6 | numba
7 | numpy
8 | git+https://github.com/NVlabs/nvdiffrast.git
9 | omegaconf>=2.3.0
10 | opencv_python
11 | opencv_python_headless
12 | ort_nightly_gpu
13 | peft
14 | Pillow
15 | pygltflib
16 | pymeshlab>=2023.12
17 | git+https://github.com/facebookresearch/pytorch3d.git@stable
18 | rembg[gpu]
19 | torch_scatter
20 | tqdm
21 | transformers
22 | trimesh
23 | typeguard
24 | wandb
25 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | datasets
3 | diffusers>=0.26.3
4 | fire
5 | gradio
6 | jaxtyping
7 | numba
8 | numpy
9 | git+https://github.com/NVlabs/nvdiffrast.git
10 | omegaconf>=2.3.0
11 | onnxruntime_gpu
12 | opencv_python
13 | opencv_python_headless
14 | ort_nightly_gpu
15 | peft
16 | Pillow
17 | pygltflib
18 | pymeshlab>=2023.12
19 | git+https://github.com/facebookresearch/pytorch3d.git@stable
20 | rembg[gpu]
21 | #torch>=2.0.1
22 | torch_scatter
23 | tqdm
24 | transformers
25 | trimesh
26 | typeguard
27 | wandb
28 | xformers
29 |
--------------------------------------------------------------------------------
/scripts/all_typing.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/threestudio-project
2 |
3 | """
4 | This module contains type annotations for the project, using
5 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
6 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
7 |
8 | Two types of typing checking can be used:
9 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
10 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
11 | """
12 |
13 | # Basic types
14 | from typing import (
15 | Any,
16 | Callable,
17 | Dict,
18 | Iterable,
19 | List,
20 | Literal,
21 | NamedTuple,
22 | NewType,
23 | Optional,
24 | Sized,
25 | Tuple,
26 | Type,
27 | TypeVar,
28 | Union,
29 | )
30 |
31 | # Tensor dtype
32 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
33 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
34 |
35 | # Config type
36 | from omegaconf import DictConfig
37 |
38 | # PyTorch Tensor type
39 | from torch import Tensor
40 |
41 | # Runtime type checking decorator
42 | from typeguard import typechecked as typechecker
43 |
--------------------------------------------------------------------------------
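This typing hub is used throughout the repo together with jaxtyping shape annotations and typeguard for runtime checks. A minimal sketch of that pattern (the function is a made-up example; combining the two via jaxtyping's `jaxtyped` decorator with typeguard as the checker follows jaxtyping's documented usage):

```python
import torch
from jaxtyping import Float, jaxtyped
from torch import Tensor
from typeguard import typechecked as typechecker

@jaxtyped(typechecker=typechecker)
def normalize_points(points: Float[Tensor, "N 3"]) -> Float[Tensor, "N 3"]:
    """Scale a point cloud into the unit cube; shape and dtype are checked at call time."""
    return points / points.abs().max().clamp(min=1e-8)

print(normalize_points(torch.rand(8, 3)).shape)   # torch.Size([8, 3])
# normalize_points(torch.rand(8, 2))              # would fail the runtime shape check
```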
/scripts/load_onnx.py:
--------------------------------------------------------------------------------
1 | import onnxruntime
2 | import torch
3 |
4 | providers = [
5 | ('TensorrtExecutionProvider', {
6 | 'device_id': 0,
7 | 'trt_max_workspace_size': 8 * 1024 * 1024 * 1024,
8 | 'trt_fp16_enable': True,
9 | 'trt_engine_cache_enable': True,
10 | }),
11 | ('CUDAExecutionProvider', {
12 | 'device_id': 0,
13 | 'arena_extend_strategy': 'kSameAsRequested',
14 | 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
15 | 'cudnn_conv_algo_search': 'HEURISTIC',
16 | })
17 | ]
18 |
19 | def load_onnx(file_path: str):
20 | assert file_path.endswith(".onnx")
21 | sess_opt = onnxruntime.SessionOptions()
22 | ort_session = onnxruntime.InferenceSession(file_path, sess_opt=sess_opt, providers=providers)
23 | return ort_session
24 |
25 |
26 | def load_onnx_caller(file_path: str, single_output=False):
27 | ort_session = load_onnx(file_path)
28 | def caller(*args):
29 | torch_input = isinstance(args[0], torch.Tensor)
30 | if torch_input:
31 | torch_input_dtype = args[0].dtype
32 | torch_input_device = args[0].device
33 | # check all are torch.Tensor and have same dtype and device
34 | assert all([isinstance(arg, torch.Tensor) for arg in args]), "All inputs should be torch.Tensor, if first input is torch.Tensor"
35 | assert all([arg.dtype == torch_input_dtype for arg in args]), "All inputs should have same dtype, if first input is torch.Tensor"
36 | assert all([arg.device == torch_input_device for arg in args]), "All inputs should have same device, if first input is torch.Tensor"
37 | args = [arg.cpu().float().numpy() for arg in args]
38 |
39 | ort_inputs = {ort_session.get_inputs()[idx].name: args[idx] for idx in range(len(args))}
40 | ort_outs = ort_session.run(None, ort_inputs)
41 |
42 | if torch_input:
43 | ort_outs = [torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device) for ort_out in ort_outs]
44 |
45 | if single_output:
46 | return ort_outs[0]
47 | return ort_outs
48 | return caller
49 |
--------------------------------------------------------------------------------
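`load_onnx_caller` above wraps an ONNX Runtime session so it can be called with torch tensors directly: inputs are converted to CPU float numpy arrays and the outputs are moved back to the original dtype and device. A minimal usage sketch; `upsampler.onnx` is a hypothetical model path, and the TensorRT/CUDA providers configured above must be available:

```python
import torch
from scripts.load_onnx import load_onnx_caller

# hypothetical single-input, single-output model
run_model = load_onnx_caller("upsampler.onnx", single_output=True)

image = torch.rand(1, 3, 64, 64, device="cuda", dtype=torch.float16)
out = run_model(image)                  # returned as float16 on cuda, like the input
print(out.shape, out.dtype, out.device)
```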
/scripts/mesh_init.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | import numpy as np
4 | from pytorch3d.structures import Meshes
5 | from pytorch3d.renderer import TexturesVertex
6 | from scripts.utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh
7 | import pymeshlab
8 |
9 | _MAX_THREAD = 8
10 |
11 | # rgb and depth to mesh
12 | def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"):
13 | pixel_center = 0.5 if use_pixel_centers else 0
14 | i, j = np.meshgrid(
15 | np.arange(W, dtype=np.float32) + pixel_center,
16 | np.arange(H, dtype=np.float32) + pixel_center,
17 | indexing='xy'
18 | )
19 | i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device)
20 |
21 | origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1) # W, H, 3
22 | directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3
23 |
24 | return origins, directions
25 |
26 | def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False):
27 | if valid_HWC is None:
28 | valid_HWC = torch.ones_like(pred_HWC).bool()
29 | H, W = rgb_BCHW.shape[-2:]
30 | rgb_BCHW = rgb_BCHW.flip(-2)
31 | pred_HWC = pred_HWC.flip(0)
32 | valid_HWC = valid_HWC.flip(0)
33 | rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device)
34 | verts = rays_o + rays_d * pred_HWC # [H, W, 3]
35 | verts = verts.reshape(-1, 3) # [V, 3]
36 | indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device)
37 | faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1)
38 | # faces1_valid = valid_HWC[:-1, :-1] | valid_HWC[:-1, 1:] | valid_HWC[1:, :-1]
39 | faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1]
40 | faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1)
41 | # faces2_valid = valid_HWC[1:, 1:] | valid_HWC[1:, :-1] | valid_HWC[:-1, 1:]
42 | faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:]
43 | faces = torch.cat([faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3), faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)], dim=0) # (F, 3)
44 | colors = (rgb_BCHW[0].permute((1,2,0)) / 2 + 0.5).reshape(-1, 3) # (V, 3)
45 | if is_back:
46 | verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device)
47 |
48 | used_verts = faces.unique()
49 | old_to_new_mapping = torch.zeros_like(verts[..., 0]).long()
50 | old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device)
51 | new_faces = old_to_new_mapping[faces]
52 | mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]]))
53 | return mesh
54 |
55 | def normalmap_to_depthmap(normal_np):
56 | from scripts.normal_to_height_map import estimate_height_map
57 | height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96)
58 | return height
59 |
60 | def transform_back_normal_to_front(normal_pil):
61 | arr = np.array(normal_pil) # in [0, 255]
62 | arr[..., 0] = 255-arr[..., 0]
63 | arr[..., 2] = 255-arr[..., 2]
64 | return Image.fromarray(arr.astype(np.uint8))
65 |
66 | def calc_w_over_h(normal_pil):
67 | if isinstance(normal_pil, Image.Image):
68 | arr = np.array(normal_pil)
69 | else:
70 | assert isinstance(normal_pil, np.ndarray)
71 | arr = normal_pil
72 | if arr.shape[-1] == 4:
73 | alpha = arr[..., -1] / 255.
74 | alpha[alpha >= 0.5] = 1
75 | alpha[alpha < 0.5] = 0
76 | else:
77 | alpha = ~(arr.min(axis=-1) >= 250)
78 | h_min, w_min = np.min(np.where(alpha), axis=1)
79 | h_max, w_max = np.max(np.where(alpha), axis=1)
80 | return (w_max - w_min) / (h_max - h_min)
81 |
82 | def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0):
83 | if is_back:
84 | normal_pil = transform_back_normal_to_front(normal_pil)
85 | normal_img = np.array(normal_pil)
86 | rgb_img = np.array(rgb_pil)
87 | if normal_img.shape[-1] == 4:
88 | valid_HWC = normal_img[..., [3]] / 255
89 | elif rgb_img.shape[-1] == 4:
90 | valid_HWC = rgb_img[..., [3]] / 255
91 | else:
92 | raise ValueError("invalid input: either the normal map or the RGB image must have an alpha channel")
93 |
94 | real_height_pix = np.max(np.where(valid_HWC>0.5)[0]) - np.min(np.where(valid_HWC>0.5)[0])
95 |
96 | heights = normalmap_to_depthmap(normal_img)
97 | rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2,0,1))[None]
98 | valid_HWC[valid_HWC < 0.5] = 0
99 | valid_HWC[valid_HWC >= 0.5] = 1
100 | valid_HWC = torch.from_numpy(valid_HWC).bool()
101 | if init_type == "std":
102 | # accurate but not stable
103 | pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None]
104 | elif init_type == "thin":
105 | heights = heights - heights.min()
106 | heights = (heights / heights.max() * 0.2)
107 | pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
108 | else:
109 | # stable but not accurate
110 | heights = heights - heights.min()
111 | heights = (heights / heights.max() * (1-offset)) + offset # map to [offset, 1]
112 | pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
113 |
114 | # set the border pixels to 0 height
115 | import cv2
116 | # edge filter
117 | edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255)
118 | edge = torch.from_numpy(edge).bool()[..., None]
119 | pred_HWC[edge] = 0
120 |
121 | valid_HWC[pred_HWC < clamp_min] = False
122 | return depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back)
123 |
124 | def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0):
125 | ms = pymeshlab.MeshSet()
126 | ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh")
127 | if simplification > 0:
128 | ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
129 | ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads = 6, depth = poissson_depth, preclean = True)
130 | if simplification > 0:
131 | ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
132 | return meshlab_mesh_to_py3dmesh(ms.current_mesh())
133 |
--------------------------------------------------------------------------------
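For reference, `depth_and_color_to_mesh` above triangulates the per-pixel height field as a regular grid: every 2x2 block of pixels contributes two triangles, and triangles that touch an invalid (transparent) pixel are dropped before the vertex indices are compacted. A minimal, self-contained NumPy sketch of that grid indexing (illustration only, not repository code):

```python
import numpy as np

H, W = 3, 4                           # tiny grid for illustration
idx = np.arange(H * W).reshape(H, W)  # one vertex index per pixel
valid = np.ones((H, W), dtype=bool)   # per-pixel validity mask (from the alpha channel)

# upper-left triangle of each 2x2 block: (i, j), (i, j+1), (i+1, j)
faces1 = np.stack([idx[:-1, :-1], idx[:-1, 1:], idx[1:, :-1]], axis=-1)
keep1 = valid[:-1, :-1] & valid[:-1, 1:] & valid[1:, :-1]
# lower-right triangle of each block: (i+1, j+1), (i+1, j), (i, j+1)
faces2 = np.stack([idx[1:, 1:], idx[1:, :-1], idx[:-1, 1:]], axis=-1)
keep2 = valid[1:, 1:] & valid[1:, :-1] & valid[:-1, 1:]

faces = np.concatenate([faces1[keep1], faces2[keep2]], axis=0)  # (F, 3)
print(faces.shape)  # 2 * (H-1) * (W-1) faces when every pixel is valid -> (12, 3)
```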
/scripts/multiview_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from scripts.mesh_init import build_mesh, calc_w_over_h, fix_border_with_pymeshlab_fast
4 | from scripts.project_mesh import multiview_color_projection
5 | from scripts.refine_lr_to_sr import run_sr_fast
6 | from scripts.utils import simple_clean_mesh
7 | from app.utils import simple_remove, split_image
8 | from app.custom_models.normal_prediction import predict_normals
9 | from mesh_reconstruction.recon import reconstruct_stage1
10 | from mesh_reconstruction.refine import run_mesh_refine
11 | from scripts.project_mesh import get_cameras_list
12 | from scripts.utils import from_py3d_mesh, to_pyml_mesh
13 | from pytorch3d.structures import Meshes, join_meshes_as_scene
14 | import numpy as np
15 |
16 | def fast_geo(front_normal: Image.Image, back_normal: Image.Image, side_normal: Image.Image, clamp=0., init_type="std"):
17 | import time
18 | if front_normal.mode == "RGB":
19 | front_normal = simple_remove(front_normal, run_sr=False)
20 | front_normal = front_normal.resize((192, 192))
21 | if back_normal.mode == "RGB":
22 | back_normal = simple_remove(back_normal, run_sr=False)
23 | back_normal = back_normal.resize((192, 192))
24 | if side_normal.mode == "RGB":
25 | side_normal = simple_remove(side_normal, run_sr=False)
26 | side_normal = side_normal.resize((192, 192))
27 |
28 | # build a coarse mesh from the front/back projections (~3s)
29 | side_w_over_h = calc_w_over_h(side_normal)
30 | mesh_front = build_mesh(front_normal, front_normal, clamp_min=clamp, scale=side_w_over_h, init_type=init_type)
31 | mesh_back = build_mesh(back_normal, back_normal, is_back=True, clamp_min=clamp, scale=side_w_over_h, init_type=init_type)
32 | meshes = join_meshes_as_scene([mesh_front, mesh_back])
33 | meshes = fix_border_with_pymeshlab_fast(meshes, poissson_depth=6, simplification=2000)
34 | return meshes
35 |
36 | def refine_rgb(rgb_pils, front_pil):
37 | from scripts.refine_lr_to_sr import refine_lr_with_sd
38 | from scripts.utils import NEG_PROMPT
39 | from app.utils import make_image_grid
40 | from app.all_models import model_zoo
41 | from app.utils import rgba_to_rgb
42 | rgb_pil = make_image_grid(rgb_pils, rows=2)
43 | prompt = "4views, multiview"
44 | neg_prompt = NEG_PROMPT
45 | control_image = rgb_pil.resize((1024, 1024))
46 | refined_rgb = refine_lr_with_sd([rgb_pil], [rgba_to_rgb(front_pil)], [control_image], prompt_list=[prompt], neg_prompt_list=[neg_prompt], pipe=model_zoo.pipe_disney_controlnet_tile_ipadapter_i2i, strength=0.2, output_size=(1024, 1024))[0]
47 | refined_rgbs = split_image(refined_rgb, rows=2)
48 | return refined_rgbs
49 |
50 | def erode_alpha(img_list):
51 | out_img_list = []
52 | for idx, img in enumerate(img_list):
53 | arr = np.array(img)
54 | alpha = (arr[:, :, 3] > 127).astype(np.uint8)
55 | # erode 1px
56 | import cv2
57 | alpha = cv2.erode(alpha, np.ones((3, 3), np.uint8), iterations=1)
58 | alpha = (alpha * 255).astype(np.uint8)
59 | img = Image.fromarray(np.concatenate([arr[:, :, :3], alpha[:, :, None]], axis=-1))
60 | out_img_list.append(img)
61 | return out_img_list
62 | import time
63 | def geo_reconstruct(rgb_pils, normal_pils, front_pil, do_refine=False, predict_normal=True, expansion_weight=0.1, init_type="std"):
64 | if front_pil.size[0] <= 512:
65 | front_pil = run_sr_fast([front_pil])[0]
66 | if do_refine:
67 | refined_rgbs = refine_rgb(rgb_pils, front_pil) # 6s
68 | else:
69 | refined_rgbs = [rgb.resize((512, 512), resample=Image.LANCZOS) for rgb in rgb_pils]
70 | img_list = [front_pil] + run_sr_fast(refined_rgbs[1:])
71 |
72 | if predict_normal:
73 | rm_normals = predict_normals([img.resize((512, 512), resample=Image.LANCZOS) for img in img_list], guidance_scale=1.5)
74 | else:
75 | rm_normals = simple_remove([img.resize((512, 512), resample=Image.LANCZOS) for img in normal_pils])
76 | # transfer the alpha channel of rm_normals to img_list
77 | for idx, img in enumerate(rm_normals):
78 | if idx == 0 and img_list[0].mode == "RGBA":
79 | temp = img_list[0].resize((2048, 2048))
80 | rm_normals[0] = Image.fromarray(np.concatenate([np.array(rm_normals[0])[:, :, :3], np.array(temp)[:, :, 3:4]], axis=-1))
81 | continue
82 | img_list[idx] = Image.fromarray(np.concatenate([np.array(img_list[idx]), np.array(img)[:, :, 3:4]], axis=-1))
83 | assert img_list[0].mode == "RGBA"
84 | assert np.mean(np.array(img_list[0])[..., 3]) < 250
85 |
86 | img_list = [img_list[0]] + erode_alpha(img_list[1:])
87 | normal_stg1 = [img.resize((512, 512)) for img in rm_normals]
88 | if init_type in ["std", "thin"]:
89 | meshes = fast_geo(normal_stg1[0], normal_stg1[2], normal_stg1[1], init_type=init_type)
90 | _ = multiview_color_projection(meshes, rgb_pils, resolution=512, device="cuda", complete_unseen=False, confidence_threshold=0.1) # validation pass only; may raise an error
91 | vertices, faces, _ = from_py3d_mesh(meshes)
92 | vertices, faces = reconstruct_stage1(normal_stg1, steps=200, vertices=vertices, faces=faces, start_edge_len=0.1, end_edge_len=0.02, gain=0.05, return_mesh=False, loss_expansion_weight=expansion_weight)
93 | elif init_type in ["ball"]:
94 | vertices, faces = reconstruct_stage1(normal_stg1, steps=200, end_edge_len=0.01, return_mesh=False, loss_expansion_weight=expansion_weight)
95 | vertices, faces = run_mesh_refine(vertices, faces, rm_normals, steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=20, update_warmup=5, return_mesh=False, process_inputs=False, process_outputs=False)
96 | meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=1, apply_sub_divide=True, sub_divide_threshold=0.25).to("cuda")
97 | new_meshes = multiview_color_projection(meshes, img_list, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2, cameras_list = get_cameras_list([0, 90, 180, 270], "cuda", focal=1))
98 | return new_meshes
99 |
--------------------------------------------------------------------------------
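`geo_reconstruct` above is the end-to-end driver: it super-resolves or refines the multi-view RGB images, predicts per-view normals, builds a coarse mesh (`fast_geo` for the "std"/"thin" initializations, a sphere for "ball"), refines the geometry against the normals, and finally bakes vertex colors with `multiview_color_projection`. A rough usage sketch, assuming the repository's checkpoints are downloaded and a CUDA device is available; the file paths and view ordering below are illustrative assumptions, not part of the repository:

```python
from PIL import Image
from scripts.multiview_inference import geo_reconstruct
from scripts.utils import save_glb_and_video

# placeholder inputs: four RGBA views from the multi-view generator plus the front image
rgb_pils = [Image.open(f"tmp/view_{i}.png") for i in range(4)]
front_pil = Image.open("tmp/input_rgba.png")

meshes = geo_reconstruct(
    rgb_pils,
    normal_pils=None,          # unused when predict_normal=True
    front_pil=front_pil,
    do_refine=False,
    predict_normal=True,
    expansion_weight=0.1,
    init_type="std",
)
glb_path, _ = save_glb_and_video("outputs/result", meshes, with_timestamp=True, export_video=False)
print("mesh written to", glb_path)
```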
/scripts/normal_to_height_map.py:
--------------------------------------------------------------------------------
1 | # code modified from https://github.com/YertleTurtleGit/depth-from-normals
2 | import numpy as np
3 | import cv2 as cv
4 | from multiprocessing.pool import ThreadPool as Pool
5 | from multiprocessing import cpu_count
6 | from typing import Tuple, List, Union
7 | import numba
8 |
9 |
10 | def calculate_gradients(
11 | normals: np.ndarray, mask: np.ndarray
12 | ) -> Tuple[np.ndarray, np.ndarray]:
13 | horizontal_angle_map = np.arccos(np.clip(normals[:, :, 0], -1, 1))
14 | left_gradients = np.zeros(normals.shape[:2])
15 | left_gradients[mask != 0] = (1 - np.sin(horizontal_angle_map[mask != 0])) * np.sign(
16 | horizontal_angle_map[mask != 0] - np.pi / 2
17 | )
18 |
19 | vertical_angle_map = np.arccos(np.clip(normals[:, :, 1], -1, 1))
20 | top_gradients = np.zeros(normals.shape[:2])
21 | top_gradients[mask != 0] = -(1 - np.sin(vertical_angle_map[mask != 0])) * np.sign(
22 | vertical_angle_map[mask != 0] - np.pi / 2
23 | )
24 |
25 | return left_gradients, top_gradients
26 |
27 |
28 | @numba.jit(nopython=True)
29 | def integrate_gradient_field(
30 | gradient_field: np.ndarray, axis: int, mask: np.ndarray
31 | ) -> np.ndarray:
32 | heights = np.zeros(gradient_field.shape)
33 |
34 | for d1 in numba.prange(heights.shape[1 - axis]):
35 | sum_value = 0
36 | for d2 in range(heights.shape[axis]):
37 | coordinates = (d1, d2) if axis == 1 else (d2, d1)
38 |
39 | if mask[coordinates] != 0:
40 | sum_value = sum_value + gradient_field[coordinates]
41 | heights[coordinates] = sum_value
42 | else:
43 | sum_value = 0
44 |
45 | return heights
46 |
47 |
48 | def calculate_heights(
49 | left_gradients: np.ndarray, top_gradients, mask: np.ndarray
50 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
51 | left_heights = integrate_gradient_field(left_gradients, 1, mask)
52 | right_heights = np.fliplr(
53 | integrate_gradient_field(np.fliplr(-left_gradients), 1, np.fliplr(mask))
54 | )
55 | top_heights = integrate_gradient_field(top_gradients, 0, mask)
56 | bottom_heights = np.flipud(
57 | integrate_gradient_field(np.flipud(-top_gradients), 0, np.flipud(mask))
58 | )
59 | return left_heights, right_heights, top_heights, bottom_heights
60 |
61 |
62 | def combine_heights(*heights: np.ndarray) -> np.ndarray:
63 | return np.mean(np.stack(heights, axis=0), axis=0)
64 |
65 |
66 | def rotate(matrix: np.ndarray, angle: float) -> np.ndarray:
67 | h, w = matrix.shape[:2]
68 | center = (w / 2, h / 2)
69 |
70 | rotation_matrix = cv.getRotationMatrix2D(center, angle, 1.0)
71 | corners = cv.transform(
72 | np.array([[[0, 0], [w, 0], [w, h], [0, h]]]), rotation_matrix
73 | )[0]
74 |
75 | _, _, w, h = cv.boundingRect(corners)
76 |
77 | rotation_matrix[0, 2] += w / 2 - center[0]
78 | rotation_matrix[1, 2] += h / 2 - center[1]
79 | result = cv.warpAffine(matrix, rotation_matrix, (w, h), flags=cv.INTER_LINEAR)
80 |
81 | return result
82 |
83 |
84 | def rotate_vector_field_normals(normals: np.ndarray, angle: float) -> np.ndarray:
85 | angle = np.radians(angle)
86 | cos_angle = np.cos(angle)
87 | sin_angle = np.sin(angle)
88 |
89 | rotated_normals = np.empty_like(normals)
90 | rotated_normals[:, :, 0] = (
91 | normals[:, :, 0] * cos_angle - normals[:, :, 1] * sin_angle
92 | )
93 | rotated_normals[:, :, 1] = (
94 | normals[:, :, 0] * sin_angle + normals[:, :, 1] * cos_angle
95 | )
96 |
97 | return rotated_normals
98 |
99 |
100 | def centered_crop(image: np.ndarray, target_resolution: Tuple[int, int]) -> np.ndarray:
101 | return image[
102 | (image.shape[0] - target_resolution[0])
103 | // 2 : (image.shape[0] - target_resolution[0])
104 | // 2
105 | + target_resolution[0],
106 | (image.shape[1] - target_resolution[1])
107 | // 2 : (image.shape[1] - target_resolution[1])
108 | // 2
109 | + target_resolution[1],
110 | ]
111 |
112 |
113 | def integrate_vector_field(
114 | vector_field: np.ndarray,
115 | mask: np.ndarray,
116 | target_iteration_count: int,
117 | thread_count: int,
118 | ) -> np.ndarray:
119 | shape = vector_field.shape[:2]
120 | angles = np.linspace(0, 90, target_iteration_count, endpoint=False)
121 |
122 | def integrate_vector_field_angles(angles: List[float]) -> np.ndarray:
123 | all_combined_heights = np.zeros(shape)
124 |
125 | for angle in angles:
126 | rotated_vector_field = rotate_vector_field_normals(
127 | rotate(vector_field, angle), angle
128 | )
129 | rotated_mask = rotate(mask, angle)
130 |
131 | left_gradients, top_gradients = calculate_gradients(
132 | rotated_vector_field, rotated_mask
133 | )
134 | (
135 | left_heights,
136 | right_heights,
137 | top_heights,
138 | bottom_heights,
139 | ) = calculate_heights(left_gradients, top_gradients, rotated_mask)
140 |
141 | combined_heights = combine_heights(
142 | left_heights, right_heights, top_heights, bottom_heights
143 | )
144 | combined_heights = centered_crop(rotate(combined_heights, -angle), shape)
145 | all_combined_heights += combined_heights / len(angles)
146 |
147 | return all_combined_heights
148 |
149 | with Pool(processes=thread_count) as pool:
150 | heights = pool.map(
151 | integrate_vector_field_angles,
152 | np.array(
153 | np.array_split(angles, thread_count),
154 | dtype=object,
155 | ),
156 | )
157 | pool.close()
158 | pool.join()
159 |
160 | isotropic_height = np.zeros(shape)
161 | for height in heights:
162 | isotropic_height += height / thread_count
163 |
164 | return isotropic_height
165 |
166 |
167 | def estimate_height_map(
168 | normal_map: np.ndarray,
169 | mask: Union[np.ndarray, None] = None,
170 | height_divisor: float = 1,
171 | target_iteration_count: int = 250,
172 | thread_count: int = cpu_count(),
173 | raw_values: bool = False,
174 | ) -> np.ndarray:
175 | if mask is None:
176 | if normal_map.shape[-1] == 4:
177 | mask = normal_map[:, :, 3] / 255
178 | mask[mask < 0.5] = 0
179 | mask[mask >= 0.5] = 1
180 | else:
181 | mask = np.ones(normal_map.shape[:2], dtype=np.uint8)
182 |
183 | normals = ((normal_map[:, :, :3].astype(np.float64) / 255) - 0.5) * 2
184 | heights = integrate_vector_field(
185 | normals, mask, target_iteration_count, thread_count
186 | )
187 |
188 | if raw_values:
189 | return heights
190 |
191 | heights /= height_divisor
192 | heights[mask > 0] += 1 / 2
193 | heights[mask == 0] = 1 / 2
194 |
195 | heights *= 2**16 - 1
196 |
197 | if np.min(heights) < 0 or np.max(heights) > 2**16 - 1:
198 | raise OverflowError("Height values are clipping.")
199 |
200 | heights = np.clip(heights, 0, 2**16 - 1)
201 | heights = heights.astype(np.uint16)
202 |
203 | return heights
204 |
--------------------------------------------------------------------------------
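`estimate_height_map` converts screen-space normals into per-pixel gradients, integrates them along rows and columns from all four directions, and averages the result over many rotated copies of the field to suppress directional streaking; a thread pool splits the rotation angles across workers. A minimal smoke-test sketch on a synthetic normal map that points straight at the camera, which should integrate to a nearly constant height field:

```python
import numpy as np
from scripts.normal_to_height_map import estimate_height_map

# 64x64 RGBA normal map: normals ~ (0, 0, 1) everywhere, fully opaque alpha
normal_map = np.zeros((64, 64, 4), dtype=np.uint8)
normal_map[..., 0:2] = 128   # x and y components ~ 0
normal_map[..., 2] = 255     # z component ~ 1
normal_map[..., 3] = 255     # opaque

heights = estimate_height_map(normal_map, raw_values=True,
                              target_iteration_count=8, thread_count=4)
print(heights.shape, float(heights.min()), float(heights.max()))  # near-constant field
```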
/scripts/refine_lr_to_sr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | import numpy as np
5 | from hashlib import md5
6 | def hash_img(img):
7 | return md5(np.array(img).tobytes()).hexdigest()
8 | def hash_any(obj):
9 | return md5(str(obj).encode()).hexdigest()
10 |
11 | def refine_lr_with_sd(pil_image_list, concept_img_list, control_image_list, prompt_list, pipe=None, strength=0.35, neg_prompt_list="", output_size=(512, 512), controlnet_conditioning_scale=1.):
12 | with torch.no_grad():
13 | images = pipe(
14 | image=pil_image_list,
15 | ip_adapter_image=concept_img_list,
16 | prompt=prompt_list,
17 | neg_prompt=neg_prompt_list,
18 | num_inference_steps=50,
19 | strength=strength,
20 | height=output_size[0],
21 | width=output_size[1],
22 | control_image=control_image_list,
23 | guidance_scale=5.0,
24 | controlnet_conditioning_scale=controlnet_conditioning_scale,
25 | generator=torch.manual_seed(233),
26 | ).images
27 | return images
28 |
29 | SR_cache = None
30 |
31 | def run_sr_fast(source_pils, scale=4):
32 | from PIL import Image
33 | from scripts.upsampler import RealESRGANer
34 | import numpy as np
35 | global SR_cache
36 | if SR_cache is not None:
37 | upsampler = SR_cache
38 | else:
39 | upsampler = RealESRGANer(
40 | scale=4,
41 | onnx_path="ckpt/realesrgan-x4.onnx",
42 | tile=0,
43 | tile_pad=10,
44 | pre_pad=0,
45 | half=True,
46 | gpu_id=0,
47 | )
48 | ret_pils = []
49 | for idx, img_pils in enumerate(source_pils):
50 | np_in = isinstance(img_pils, np.ndarray)
51 | assert isinstance(img_pils, (Image.Image, np.ndarray))
52 | img = np.array(img_pils)
53 | output, _ = upsampler.enhance(img, outscale=scale)
54 | if np_in:
55 | ret_pils.append(output)
56 | else:
57 | ret_pils.append(Image.fromarray(output))
58 | if SR_cache is None:
59 | SR_cache = upsampler
60 | return ret_pils
61 |
--------------------------------------------------------------------------------
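`run_sr_fast` lazily constructs a `RealESRGANer` over the exported ONNX model and keeps it in the module-level `SR_cache`, so only the first call pays the model-loading cost; inputs may be PIL images or NumPy arrays, and each output comes back in the same type as its input. A usage sketch, assuming `ckpt/realesrgan-x4.onnx` is present and a CUDA device is available (the input paths are placeholders):

```python
from PIL import Image
from scripts.refine_lr_to_sr import run_sr_fast

low_res = Image.open("tmp/view_0.png").convert("RGB")
high_res, = run_sr_fast([low_res], scale=4)   # first call builds and caches the upsampler
print(low_res.size, "->", high_res.size)      # width and height grow by 4x

# later calls reuse the cached RealESRGANer, so upscaling a batch of views is cheap to set up
views = [Image.open(f"tmp/view_{i}.png").convert("RGB") for i in range(4)]
views_sr = run_sr_fast(views)
```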
/scripts/sd_model_zoo.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
2 | from transformers import CLIPVisionModelWithProjection
3 | import torch
4 | from copy import deepcopy
5 |
6 | ENABLE_CPU_CACHE = False
7 | DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"
8 |
9 | cached_models = {} # cache for models to avoid repeated loading, key is model name
10 | def cache_model(func):
11 | def wrapper(*args, **kwargs):
12 | if ENABLE_CPU_CACHE:
13 | model_name = func.__name__ + str(args) + str(kwargs)
14 | if model_name not in cached_models:
15 | cached_models[model_name] = func(*args, **kwargs)
16 | return cached_models[model_name]
17 | else:
18 | return func(*args, **kwargs)
19 | return wrapper
20 |
21 | def copied_cache_model(func):
22 | def wrapper(*args, **kwargs):
23 | if ENABLE_CPU_CACHE:
24 | model_name = func.__name__ + str(args) + str(kwargs)
25 | if model_name not in cached_models:
26 | cached_models[model_name] = func(*args, **kwargs)
27 | return deepcopy(cached_models[model_name])
28 | else:
29 | return func(*args, **kwargs)
30 | return wrapper
31 |
32 | def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
33 | if ckpt_or_pretrained.endswith(".safetensors"):
34 | pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs)
35 | else:
36 | pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)
37 | return pipe
38 |
39 | @copied_cache_model
40 | def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
41 | model_kwargs = dict(
42 | torch_dtype=torch_dtype,
43 | requires_safety_checker=False,
44 | safety_checker=None,
45 | )
46 | pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
47 | base_model,
48 | StableDiffusionPipeline,
49 | **model_kwargs
50 | )
51 | pipe.to("cpu")
52 | return pipe.components
53 |
54 | @cache_model
55 | def load_controlnet(controlnet_path, torch_dtype=torch.float16):
56 | controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
57 | return controlnet
58 |
59 | @cache_model
60 | def load_image_encoder():
61 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(
62 | "h94/IP-Adapter",
63 | subfolder="models/image_encoder",
64 | torch_dtype=torch.float16,
65 | )
66 | return image_encoder
67 |
68 | def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="auto", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
69 | model_kwargs = dict(
70 | torch_dtype=torch_dtype,
71 | device_map=device,
72 | requires_safety_checker=False,
73 | safety_checker=None,
74 | )
75 | components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype)
76 | model_kwargs.update(components)
77 | model_kwargs.update(kwargs)
78 |
79 | if controlnet is not None:
80 | if isinstance(controlnet, list):
81 | controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet]
82 | else:
83 | controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype)
84 | model_kwargs.update(controlnet=controlnet)
85 |
86 | if pipeline_class is None:
87 | if controlnet is not None:
88 | pipeline_class = StableDiffusionControlNetPipeline
89 | else:
90 | pipeline_class = StableDiffusionPipeline
91 |
92 | pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
93 | base_model,
94 | pipeline_class,
95 | **model_kwargs
96 | )
97 |
98 | if ip_adapter:
99 | image_encoder = load_image_encoder()
100 | pipe.image_encoder = image_encoder
101 | if plus_model:
102 | pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors")
103 | else:
104 | pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors")
105 | pipe.set_ip_adapter_scale(1.0)
106 | else:
107 | pipe.unload_ip_adapter()
108 |
109 | pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
110 |
111 | if model_cpu_offload_seq is None:
112 | if isinstance(pipe, StableDiffusionControlNetPipeline):
113 | pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
114 | elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
115 | pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
116 | else:
117 | pipe.model_cpu_offload_seq = model_cpu_offload_seq
118 |
119 | if enable_sequential_cpu_offload:
120 | pipe.enable_sequential_cpu_offload()
121 | else:
122 | pipe = pipe.to("cuda")
123 | pass
124 | # pipe.enable_model_cpu_offload()
125 | if vae_slicing:
126 | pipe.enable_vae_slicing()
127 |
128 | import gc
129 | gc.collect()
130 | return pipe
131 |
132 |
--------------------------------------------------------------------------------
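`load_common_sd15_pipe` assembles a Stable Diffusion 1.5 pipeline from cached base components and optionally attaches a ControlNet, the IP-Adapter image encoder and weights, and the Euler-Ancestral scheduler; the `cache_model`/`copied_cache_model` decorators only take effect when `ENABLE_CPU_CACHE` is set. A hedged usage sketch (the tile ControlNet id below is an illustrative choice, not necessarily the checkpoint this repository uses):

```python
import torch
from diffusers import StableDiffusionControlNetImg2ImgPipeline
from scripts.sd_model_zoo import load_common_sd15_pipe

pipe = load_common_sd15_pipe(
    base_model="runwayml/stable-diffusion-v1-5",
    controlnet="lllyasviel/control_v11f1e_sd15_tile",   # example ControlNet checkpoint
    ip_adapter=True,
    plus_model=True,                                    # loads ip-adapter-plus_sd15.safetensors
    pipeline_class=StableDiffusionControlNetImg2ImgPipeline,
    torch_dtype=torch.float16,
)
# the returned pipe is on "cuda" (unless enable_sequential_cpu_offload=True) and ready to call
```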
/scripts/upsampler.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import os
5 | import torch
6 | from torch.nn import functional as F
7 | from scripts.load_onnx import load_onnx_caller
8 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9 |
10 |
11 | class RealESRGANer():
12 | """A helper class for upsampling images with RealESRGAN.
13 |
14 | Args:
15 | scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
16 | model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
17 | model (nn.Module): The defined network. Default: None.
18 | tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
19 | input images into tiles, and then process each of them. Finally, they will be merged into one image.
20 | 0 denotes for do not use tile. Default: 0.
21 | tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
22 | pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
23 | half (float): Whether to use half precision during inference. Default: False.
24 | """
25 |
26 | def __init__(self,
27 | scale,
28 | onnx_path,
29 | tile=0,
30 | tile_pad=10,
31 | pre_pad=10,
32 | half=False,
33 | device=None,
34 | gpu_id=None):
35 | self.scale = scale
36 | self.tile_size = tile
37 | self.tile_pad = tile_pad
38 | self.pre_pad = pre_pad
39 | self.mod_scale = None
40 | self.half = half
41 |
42 | # initialize model
43 | if gpu_id:
44 | self.device = torch.device(
45 | f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
46 | else:
47 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
48 | self.model = load_onnx_caller(onnx_path, single_output=True)
49 | # warm up
50 | sample_input = torch.randn(1,3,512,512).cuda().float()
51 | self.model(sample_input)
52 |
53 | def pre_process(self, img):
54 | """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
55 | """
56 | img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
57 | self.img = img.unsqueeze(0).to(self.device)
58 | if self.half:
59 | self.img = self.img.half()
60 |
61 | # pre_pad
62 | if self.pre_pad != 0:
63 | self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
64 | # mod pad for divisible borders
65 | if self.scale == 2:
66 | self.mod_scale = 2
67 | elif self.scale == 1:
68 | self.mod_scale = 4
69 | if self.mod_scale is not None:
70 | self.mod_pad_h, self.mod_pad_w = 0, 0
71 | _, _, h, w = self.img.size()
72 | if (h % self.mod_scale != 0):
73 | self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
74 | if (w % self.mod_scale != 0):
75 | self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
76 | self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
77 |
78 | def process(self):
79 | # model inference
80 | self.output = self.model(self.img)
81 |
82 | def tile_process(self):
83 | """It will first crop input images to tiles, and then process each tile.
84 | Finally, all the processed tiles are merged into one images.
85 |
86 | Modified from: https://github.com/ata4/esrgan-launcher
87 | """
88 | batch, channel, height, width = self.img.shape
89 | output_height = height * self.scale
90 | output_width = width * self.scale
91 | output_shape = (batch, channel, output_height, output_width)
92 |
93 | # start with black image
94 | self.output = self.img.new_zeros(output_shape)
95 | tiles_x = math.ceil(width / self.tile_size)
96 | tiles_y = math.ceil(height / self.tile_size)
97 |
98 | # loop over all tiles
99 | for y in range(tiles_y):
100 | for x in range(tiles_x):
101 | # extract tile from input image
102 | ofs_x = x * self.tile_size
103 | ofs_y = y * self.tile_size
104 | # input tile area on total image
105 | input_start_x = ofs_x
106 | input_end_x = min(ofs_x + self.tile_size, width)
107 | input_start_y = ofs_y
108 | input_end_y = min(ofs_y + self.tile_size, height)
109 |
110 | # input tile area on total image with padding
111 | input_start_x_pad = max(input_start_x - self.tile_pad, 0)
112 | input_end_x_pad = min(input_end_x + self.tile_pad, width)
113 | input_start_y_pad = max(input_start_y - self.tile_pad, 0)
114 | input_end_y_pad = min(input_end_y + self.tile_pad, height)
115 |
116 | # input tile dimensions
117 | input_tile_width = input_end_x - input_start_x
118 | input_tile_height = input_end_y - input_start_y
119 | tile_idx = y * tiles_x + x + 1
120 | input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
121 |
122 | # upscale tile
123 | try:
124 | with torch.no_grad():
125 | output_tile = self.model(input_tile)
126 | except RuntimeError as error:
127 | print('Error', error)
128 | print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
129 |
130 | # output tile area on total image
131 | output_start_x = input_start_x * self.scale
132 | output_end_x = input_end_x * self.scale
133 | output_start_y = input_start_y * self.scale
134 | output_end_y = input_end_y * self.scale
135 |
136 | # output tile area without padding
137 | output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
138 | output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
139 | output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
140 | output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
141 |
142 | # put tile into output image
143 | self.output[:, :, output_start_y:output_end_y,
144 | output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
145 | output_start_x_tile:output_end_x_tile]
146 |
147 | def post_process(self):
148 | # remove extra pad
149 | if self.mod_scale is not None:
150 | _, _, h, w = self.output.size()
151 | self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
152 | # remove prepad
153 | if self.pre_pad != 0:
154 | _, _, h, w = self.output.size()
155 | self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
156 | return self.output
157 |
158 | @torch.no_grad()
159 | def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
160 | h_input, w_input = img.shape[0:2]
161 | # img: numpy
162 | img = img.astype(np.float32)
163 | if np.max(img) > 256: # 16-bit image
164 | max_range = 65535
165 | print('\tInput is a 16-bit image')
166 | else:
167 | max_range = 255
168 | img = img / max_range
169 | if len(img.shape) == 2: # gray image
170 | img_mode = 'L'
171 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
172 | elif img.shape[2] == 4: # RGBA image with alpha channel
173 | img_mode = 'RGBA'
174 | alpha = img[:, :, 3]
175 | img = img[:, :, 0:3]
176 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
177 | if alpha_upsampler == 'realesrgan':
178 | alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
179 | else:
180 | img_mode = 'RGB'
181 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
182 |
183 | # ------------------- process image (without the alpha channel) ------------------- #
184 | self.pre_process(img)
185 | if self.tile_size > 0:
186 | self.tile_process()
187 | else:
188 | self.process()
189 | output_img = self.post_process()
190 | output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
191 | output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
192 | if img_mode == 'L':
193 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
194 |
195 | # ------------------- process the alpha channel if necessary ------------------- #
196 | if img_mode == 'RGBA':
197 | if alpha_upsampler == 'realesrgan':
198 | self.pre_process(alpha)
199 | if self.tile_size > 0:
200 | self.tile_process()
201 | else:
202 | self.process()
203 | output_alpha = self.post_process()
204 | output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
205 | output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
206 | output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
207 | else: # use the cv2 resize for alpha channel
208 | h, w = alpha.shape[0:2]
209 | output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
210 |
211 | # merge the alpha channel
212 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
213 | output_img[:, :, 3] = output_alpha
214 |
215 | # ------------------------------ return ------------------------------ #
216 | if max_range == 65535: # 16-bit image
217 | output = (output_img * 65535.0).round().astype(np.uint16)
218 | else:
219 | output = (output_img * 255.0).round().astype(np.uint8)
220 |
221 | if outscale is not None and outscale != float(self.scale):
222 | output = cv2.resize(
223 | output, (
224 | int(w_input * outscale),
225 | int(h_input * outscale),
226 | ), interpolation=cv2.INTER_LANCZOS4)
227 |
228 | return output, img_mode
229 |
230 |
--------------------------------------------------------------------------------
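The tile bookkeeping in `tile_process` is easiest to see in isolation: each tile is read with `tile_pad` pixels of surrounding context, upscaled, and only the unpadded core is written back into the output canvas so padding seams never appear. A standalone sketch of just the coordinate arithmetic (pure Python, no model involved):

```python
import math

def tile_coords(width, height, tile_size, tile_pad, scale):
    """Yield (padded input box, output box, core box inside the upscaled padded tile) per tile."""
    for y in range(math.ceil(height / tile_size)):
        for x in range(math.ceil(width / tile_size)):
            x0, y0 = x * tile_size, y * tile_size
            x1, y1 = min(x0 + tile_size, width), min(y0 + tile_size, height)
            # read the tile with extra padding for context
            px0, py0 = max(x0 - tile_pad, 0), max(y0 - tile_pad, 0)
            px1, py1 = min(x1 + tile_pad, width), min(y1 + tile_pad, height)
            # where the unpadded core lands inside the upscaled padded tile
            cx0, cy0 = (x0 - px0) * scale, (y0 - py0) * scale
            cx1, cy1 = cx0 + (x1 - x0) * scale, cy0 + (y1 - y0) * scale
            yield (px0, py0, px1, py1), (x0 * scale, y0 * scale, x1 * scale, y1 * scale), (cx0, cy0, cx1, cy1)

for boxes in tile_coords(width=600, height=400, tile_size=256, tile_pad=10, scale=4):
    print(boxes)
```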
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | import pymeshlab
5 | import pymeshlab as ml
6 | from pymeshlab import PercentageValue
7 | from pytorch3d.renderer import TexturesVertex
8 | from pytorch3d.structures import Meshes
9 | from rembg import new_session, remove
10 | import torch
11 | import torch.nn.functional as F
12 | from typing import List, Tuple
13 | from PIL import Image
14 | import trimesh
15 |
16 | providers = [
17 | ('CUDAExecutionProvider', {
18 | 'device_id': 0,
19 | 'arena_extend_strategy': 'kSameAsRequested',
20 | 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
21 | 'cudnn_conv_algo_search': 'HEURISTIC',
22 | })
23 | ]
24 |
25 | session = new_session(providers=providers)
26 |
27 | NEG_PROMPT="sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy,bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry,(worst quality:1.4),(low quality:1.4)"
28 |
29 | def load_mesh_with_trimesh(file_name, file_type=None):
30 | import trimesh
31 | mesh: trimesh.Trimesh = trimesh.load(file_name, file_type=file_type)
32 | if isinstance(mesh, trimesh.Scene):
33 | assert len(mesh.geometry) > 0
34 | # save to obj first and load again to avoid offset issue
35 | from io import BytesIO
36 | with BytesIO() as f:
37 | mesh.export(f, file_type="obj")
38 | f.seek(0)
39 | mesh = trimesh.load(f, file_type="obj")
40 | if isinstance(mesh, trimesh.Scene):
41 | # we lose texture information here
42 | mesh = trimesh.util.concatenate(
43 | tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
44 | for g in mesh.geometry.values()))
45 | assert isinstance(mesh, trimesh.Trimesh)
46 |
47 | vertices = torch.from_numpy(mesh.vertices).T
48 | faces = torch.from_numpy(mesh.faces).T
49 | colors = None
50 | if mesh.visual is not None:
51 | if hasattr(mesh.visual, 'vertex_colors'):
52 | colors = torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255.
53 | if colors is None:
54 | # print("Warning: no vertex color found in mesh! Filling it with gray.")
55 | colors = torch.ones_like(vertices) * 0.5
56 | return vertices, faces, colors
57 |
58 | def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes:
59 | verts = torch.from_numpy(mesh.vertex_matrix()).float()
60 | faces = torch.from_numpy(mesh.face_matrix()).long()
61 | colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
62 | textures = TexturesVertex(verts_features=[colors])
63 | return Meshes(verts=[verts], faces=[faces], textures=textures)
64 |
65 |
66 | def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh:
67 | colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64)
68 | m1 = pymeshlab.Mesh(
69 | vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64),
70 | face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32),
71 | v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64),
72 | v_color_matrix=colors_in)
73 | return m1
74 |
75 |
76 | def to_pyml_mesh(vertices,faces):
77 | m1 = pymeshlab.Mesh(
78 | vertex_matrix=vertices.cpu().float().numpy().astype(np.float64),
79 | face_matrix=faces.cpu().long().numpy().astype(np.int32),
80 | )
81 | return m1
82 |
83 |
84 | def to_py3d_mesh(vertices, faces, normals=None):
85 | from pytorch3d.structures import Meshes
86 | from pytorch3d.renderer.mesh.textures import TexturesVertex
87 | mesh = Meshes(verts=[vertices], faces=[faces], textures=None)
88 | if normals is None:
89 | normals = mesh.verts_normals_packed()
90 | # set normals as vertex colors
91 | mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
92 | return mesh
93 |
94 |
95 | def from_py3d_mesh(mesh):
96 | return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed()
97 |
98 | def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float):
99 | """
100 | rotate along y-axis
101 | normal_map: np.array, shape=(H, W, 3) in [-1, 1]
102 | angle: float, in degree
103 | """
104 | angle = angle / 180 * np.pi
105 | R = np.array([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]])
106 | return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape)
107 |
108 | # from per-view camera coordinates to front-view world coordinates
109 | def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> np.ndarray: # [0, 255]
110 | n_views = len(normal_pils)
111 | ret = []
112 | for idx, rgba_normal in enumerate(normal_pils):
113 | # rotate normal
114 | normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [0, 1]
115 | alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
116 | normal_np = normal_np * 2 - 1
117 | normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views))
118 | normal_np = (normal_np + 1) / 2
119 | normal_np = normal_np * alpha_np[..., None] # make bg black
120 | rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255] , axis=-1)
121 | if return_types == 'np':
122 | ret.append(rgba_normal_np)
123 | elif return_types == 'pil':
124 | ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
125 | else:
126 | raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
127 | return ret
128 |
129 |
130 | def rotate_normalmap_by_angle_torch(normal_map, angle):
131 | """
132 | rotate along y-axis
133 | normal_map: torch.Tensor, shape=(H, W, 3) in [-1, 1], device='cuda'
134 | angle: float, in degree
135 | """
136 | angle = torch.tensor(angle / 180 * np.pi).to(normal_map)
137 | R = torch.tensor([[torch.cos(angle), 0, torch.sin(angle)],
138 | [0, 1, 0],
139 | [-torch.sin(angle), 0, torch.cos(angle)]]).to(normal_map)
140 | return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape)
141 |
142 | def do_rotate(rgba_normal, angle):
143 | rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255
144 | rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle)
145 | rotated_normal_tensor = (rotated_normal_tensor + 1) / 2
146 | rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black
147 | rgba_normal_np = torch.cat([rotated_normal_tensor * 255, rgba_normal[:, :, [3]] * 255], dim=-1).cpu().numpy()
148 | return rgba_normal_np
149 |
150 | def rotate_normals_torch(normal_pils, return_types='np', rotate_direction=1):
151 | n_views = len(normal_pils)
152 | ret = []
153 | for idx, rgba_normal in enumerate(normal_pils):
154 | # rotate normal
155 | angle = rotate_direction * idx * (360 / n_views)
156 | rgba_normal_np = do_rotate(np.array(rgba_normal), angle)
157 | if return_types == 'np':
158 | ret.append(rgba_normal_np)
159 | elif return_types == 'pil':
160 | ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
161 | else:
162 | raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
163 | return ret
164 |
165 | def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)):
166 | ret = []
167 | new_bkgd = np.array(new_bkgd).reshape(1, 1, 3)
168 | for rgba_img in img_pils:
169 | img_np = np.array(rgba_img)[:, :, :3] / 255
170 | alpha_np = np.array(rgba_img)[:, :, 3] / 255
171 | ori_bkgd = img_np[:1, :1]
172 | # color = ori_color * alpha + bkgd * (1-alpha)
173 | # ori_color = (color - bkgd * (1-alpha)) / alpha
174 | alpha_np_clamp = np.clip(alpha_np, 1e-6, 1) # avoid divide by zero
175 | ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None]
176 | img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd)
177 | rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1)
178 | ret.append(Image.fromarray(rgba_img_np.astype(np.uint8)))
179 | return ret
180 |
181 | def change_bkgd_to_normal(normal_pils) -> List[Image.Image]:
182 | n_views = len(normal_pils)
183 | ret = []
184 | for idx, rgba_normal in enumerate(normal_pils):
185 | # calculate background normal
186 | target_bkgd = rotate_normalmap_by_angle(np.array([[[0., 0., 1.]]]), idx * (360 / n_views))
187 | normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [0, 1]
188 | alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
189 | normal_np = normal_np * 2 - 1
190 | old_bkgd = normal_np[:1,:1]
191 | normal_np[alpha_np > 0.05] = (normal_np[alpha_np > 0.05] - old_bkgd * (1 - alpha_np[alpha_np > 0.05][..., None])) / alpha_np[alpha_np > 0.05][..., None]
192 | normal_np = normal_np * alpha_np[..., None] + target_bkgd * (1 - alpha_np[..., None])
193 | normal_np = (normal_np + 1) / 2
194 | rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[..., None] * 255] , axis=-1)
195 | ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
196 | return ret
197 |
198 |
199 | def fix_vert_color_glb(mesh_path):
200 | from pygltflib import GLTF2, Material, PbrMetallicRoughness
201 | obj1 = GLTF2().load(mesh_path)
202 | obj1.meshes[0].primitives[0].material = 0
203 | obj1.materials.append(Material(
204 | pbrMetallicRoughness = PbrMetallicRoughness(
205 | baseColorFactor = [1.0, 1.0, 1.0, 1.0],
206 | metallicFactor = 0.,
207 | roughnessFactor = 1.0,
208 | ),
209 | emissiveFactor = [0.0, 0.0, 0.0],
210 | doubleSided = True,
211 | ))
212 | obj1.save(mesh_path)
213 |
214 |
215 | def srgb_to_linear(c_srgb):
216 | c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
217 | return c_linear.clip(0, 1.)
218 |
219 |
220 | def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
221 | # convert from pytorch3d meshes to trimesh mesh
222 | vertices = meshes.verts_packed().cpu().float().numpy()
223 | triangles = meshes.faces_packed().cpu().long().numpy()
224 | np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
225 | if save_glb_path.endswith(".glb"):
226 | # rotate 180 along +Y
227 | vertices[:, [0, 2]] = -vertices[:, [0, 2]]
228 |
229 | if apply_sRGB_to_LinearRGB:
230 | np_color = srgb_to_linear(np_color)
231 | assert vertices.shape[0] == np_color.shape[0]
232 | assert np_color.shape[1] == 3
233 | assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
234 | mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
235 | mesh.remove_unreferenced_vertices()
236 | # save mesh
237 | mesh.export(save_glb_path)
238 | if save_glb_path.endswith(".glb"):
239 | fix_vert_color_glb(save_glb_path)
240 | print(f"saving to {save_glb_path}")
241 |
242 |
243 | def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]:
244 | import time
245 | if '.' in save_mesh_prefix:
246 | save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
247 | if with_timestamp:
248 | save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}"
249 | ret_mesh = save_mesh_prefix + ".glb"
250 | # optimized version
251 | save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
252 | return ret_mesh, None
253 |
254 |
255 | def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
256 | ms = ml.MeshSet()
257 | ms.add_mesh(pyml_mesh, "cube_mesh")
258 |
259 | if apply_smooth:
260 | ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False)
261 | if apply_sub_divide: # 5s, slow
262 | ms.apply_filter("meshing_repair_non_manifold_vertices")
263 | ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces')
264 | ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold))
265 | return meshlab_mesh_to_py3dmesh(ms.current_mesh())
266 |
267 |
268 | def expand2square(pil_img, background_color):
269 | width, height = pil_img.size
270 | if width == height:
271 | return pil_img
272 | elif width > height:
273 | result = Image.new(pil_img.mode, (width, width), background_color)
274 | result.paste(pil_img, (0, (width - height) // 2))
275 | return result
276 | else:
277 | result = Image.new(pil_img.mode, (height, height), background_color)
278 | result.paste(pil_img, ((height - width) // 2, 0))
279 | return result
280 |
281 |
282 | def simple_preprocess(input_image, rembg_session=session, background_color=255):
283 | RES = 2048
284 | input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
285 | if input_image.mode != 'RGBA':
286 | image_rem = input_image.convert('RGBA')
287 | input_image = remove(image_rem, alpha_matting=False, session=rembg_session)
288 |
289 | arr = np.asarray(input_image)
290 | alpha = np.asarray(input_image)[:, :, -1]
291 | x_nonzero = np.nonzero((alpha > 60).sum(axis=1))
292 | y_nonzero = np.nonzero((alpha > 60).sum(axis=0))
293 | x_min = int(x_nonzero[0].min())
294 | y_min = int(y_nonzero[0].min())
295 | x_max = int(x_nonzero[0].max())
296 | y_max = int(y_nonzero[0].max())
297 | arr = arr[x_min: x_max, y_min: y_max]
298 | input_image = Image.fromarray(arr)
299 | input_image = expand2square(input_image, (background_color, background_color, background_color, 0))
300 | return input_image
301 |
302 | def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
303 | # Convert the background color to a PyTorch tensor
304 | new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
305 |
306 | # Convert all images to PyTorch tensors and process them
307 | imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255
308 | img_nps = imgs[..., :3]
309 | alpha_nps = imgs[..., 3]
310 | ori_bkgds = img_nps[:, :1, :1]
311 |
312 | # Avoid divide by zero and calculate the original image
313 | alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1)
314 | ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1)
315 | ori_img_nps = torch.clamp(ori_img_nps, 0, 1)
316 | img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd)
317 |
318 | rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1)
319 | return rgba_img_np
--------------------------------------------------------------------------------
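`change_bkgd` and `init_target` both rely on inverting standard alpha compositing, recovering the un-premultiplied foreground color before compositing it over a new background. A small self-contained NumPy check of that identity (illustration only):

```python
import numpy as np

# composited = foreground * alpha + old_bkgd * (1 - alpha)
# => foreground = (composited - old_bkgd * (1 - alpha)) / alpha
rng = np.random.default_rng(0)
fg = rng.random((4, 4, 3))
alpha = rng.random((4, 4, 1)) * 0.9 + 0.1            # keep alpha away from zero
old_bkgd, new_bkgd = np.ones(3), np.zeros(3)

composited = fg * alpha + old_bkgd * (1 - alpha)
recovered = (composited - old_bkgd * (1 - alpha)) / np.clip(alpha, 1e-6, 1)
recomposited = recovered * alpha + new_bkgd * (1 - alpha)

assert np.allclose(recovered, fg)
assert np.allclose(recomposited, fg * alpha)          # black background: just the premultiplied color
```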