├── .editorconfig ├── .github └── workflows │ ├── anchore.yml │ ├── codacy.yml │ ├── devskim.yml │ ├── pylint.yml │ ├── pysa.yml │ └── semgrep.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE-CODE ├── LICENSE-MODEL ├── README.md ├── assets ├── diagram-1.png ├── logo.png ├── mesh_meshlab.jpg ├── repo_demo_0.mp4 ├── repo_demo_01.mp4 ├── repo_demo_02.mp4 ├── repo_static_v2.png └── result_mushroom.mp4 ├── configs ├── dreamcraft3d-coarse-nerf.yaml ├── dreamcraft3d-coarse-neus.yaml ├── dreamcraft3d-geometry.yaml └── dreamcraft3d-texture.yaml ├── docker ├── Dockerfile └── compose.yaml ├── docs └── installation.md ├── extern ├── __init__.py ├── ldm_zero123 │ ├── extras.py │ ├── guidance.py │ ├── lr_scheduler.py │ ├── models │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ ├── classifier.py │ │ │ ├── ddim.py │ │ │ ├── ddpm.py │ │ │ ├── plms.py │ │ │ └── sampling_util.py │ ├── modules │ │ ├── attention.py │ │ ├── attention_ori.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── evaluate │ │ │ ├── adm_evaluator.py │ │ │ ├── evaluate_perceptualsim.py │ │ │ ├── frechet_video_distance.py │ │ │ ├── ssim.py │ │ │ └── torch_frechet_video_distance.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ └── x_transformer.py │ ├── thirdp │ │ └── psp │ │ │ ├── helpers.py │ │ │ ├── id_loss.py │ │ │ └── model_irse.py │ └── util.py └── zero123.py ├── gradio_app.py ├── launch.py ├── load ├── images │ ├── a DSLR photo of a puffin standing on a rock_depth.png │ ├── a DSLR photo of a puffin standing on a rock_normal.png │ ├── a DSLR photo of a puffin standing on a rock_rgba.png │ ├── a figurine of a frog holding a birthday cake_depth.png │ ├── a figurine of a frog holding a birthday cake_normal.png │ ├── a figurine of a frog holding a birthday cake_rgba.png │ ├── a kingfisher sitting on top of a piece of wood_depth.png │ ├── a kingfisher sitting on top of a piece of wood_normal.png │ ├── a kingfisher sitting on top of a piece of wood_rgba.png │ ├── a rubber duck dressed as a nurse_depth.png │ ├── a rubber duck dressed as a nurse_normal.png │ ├── a rubber duck dressed as a nurse_rgba.png │ ├── a white bowl of multiple fruits_depth.png │ ├── a white bowl of multiple fruits_normal.png │ ├── a white bowl of multiple fruits_rgba.png │ ├── groot_caption.txt │ ├── groot_depth.png │ ├── groot_normal.png │ ├── groot_rgba.png │ ├── jay-basket_caption.txt │ ├── jay-basket_depth.png │ ├── jay-basket_normal.png │ ├── jay-basket_rgba.png │ ├── mushroom_log_caption.txt │ ├── mushroom_log_depth.png │ ├── mushroom_log_normal.png │ ├── mushroom_log_rgba.png │ ├── tiger dressed as a nurse_depth.png │ ├── tiger dressed as a nurse_normal.png │ └── tiger dressed as a nurse_rgba.png ├── lights │ ├── LICENSE.txt │ ├── bsdf_256_256.bin │ └── mud_road_puresky_1k.hdr ├── make_prompt_library.py ├── prompt_library.json ├── tets │ ├── 128_tets.npz │ ├── 32_tets.npz │ ├── 64_tets.npz │ └── generate_tets.py └── zero123 │ ├── download.sh │ └── sd-objaverse-finetune-c_concat-256.yaml ├── metric_utils.py ├── preprocess_image.py ├── requirements.txt └── threestudio ├── __init__.py ├── data ├── __init__.py ├── image.py 
├── images.py └── uncond.py ├── models ├── __init__.py ├── background │ ├── __init__.py │ ├── base.py │ ├── neural_environment_map_background.py │ ├── solid_color_background.py │ └── textured_background.py ├── estimators.py ├── exporters │ ├── __init__.py │ ├── base.py │ └── mesh_exporter.py ├── geometry │ ├── __init__.py │ ├── base.py │ ├── custom_mesh.py │ ├── implicit_sdf.py │ ├── implicit_volume.py │ ├── tetrahedra_sdf_grid.py │ └── volume_grid.py ├── guidance │ ├── __init__.py │ ├── clip_guidance.py │ ├── controlnet_guidance.py │ ├── controlnet_reg_guidance.py │ ├── deep_floyd_guidance.py │ ├── stable_diffusion_bsd_guidance.py │ ├── stable_diffusion_guidance.py │ ├── stable_diffusion_unified_guidance.py │ ├── stable_diffusion_vsd_guidance.py │ ├── stable_zero123_guidance.py │ ├── zero123_guidance.py │ └── zero123_unified_guidance.py ├── isosurface.py ├── materials │ ├── __init__.py │ ├── base.py │ ├── diffuse_with_point_light_material.py │ ├── hybrid_rgb_latent_material.py │ ├── neural_radiance_material.py │ ├── no_material.py │ ├── pbr_material.py │ └── sd_latent_adapter_material.py ├── mesh.py ├── networks.py ├── prompt_processors │ ├── __init__.py │ ├── base.py │ ├── clip_prompt_processor.py │ ├── deepfloyd_prompt_processor.py │ ├── dummy_prompt_processor.py │ └── stable_diffusion_prompt_processor.py └── renderers │ ├── __init__.py │ ├── base.py │ ├── deferred_volume_renderer.py │ ├── gan_volume_renderer.py │ ├── nerf_volume_renderer.py │ ├── neus_volume_renderer.py │ ├── nvdiff_rasterizer.py │ └── patch_renderer.py ├── scripts ├── convert_zero123_to_diffusers.py ├── dreamcraft3d_dreambooth.py ├── generate_images_if.py ├── generate_images_if_prompt_library.py ├── generate_mv_datasets.py ├── img_to_mv.py ├── make_training_vid.py ├── metric_utils.py ├── run_gaussian.sh ├── run_zero123.py ├── run_zero123_comparison.sh ├── run_zero123_demo.sh ├── run_zero123_phase.sh ├── run_zero123_phase2.sh ├── test_dreambooth.py ├── test_dreambooth_lora.py ├── train_dreambooth.py ├── train_dreambooth_lora.py └── train_text_to_image_lora.py ├── systems ├── __init__.py ├── base.py ├── dreamcraft3d.py ├── utils.py └── zero123.py └── utils ├── GAN ├── attention.py ├── discriminator.py ├── distribution.py ├── loss.py ├── mobilenet.py ├── network_util.py ├── normalunet.py ├── util.py └── vae.py ├── __init__.py ├── base.py ├── callbacks.py ├── config.py ├── dpt.py ├── lpips ├── __init__.py ├── lpips.py └── utils.py ├── misc.py ├── ops.py ├── perceptual ├── __init__.py ├── perceptual.py └── utils.py ├── rasterize.py ├── saving.py └── typing.py /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*.py] 4 | charset = utf-8 5 | trim_trailing_whitespace = true 6 | end_of_line = lf 7 | insert_final_newline = true 8 | indent_style = space 9 | indent_size = 4 10 | 11 | [*.md] 12 | trim_trailing_whitespace = false 13 | -------------------------------------------------------------------------------- /.github/workflows/anchore.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 5 | 6 | # This workflow checks out code, builds an image, performs a container image 7 | # vulnerability scan with Anchore's Grype tool, and integrates the results with GitHub Advanced Security 8 | # code scanning feature. 
For more information on the Anchore scan action usage 9 | # and parameters, see https://github.com/anchore/scan-action. For more 10 | # information on Anchore's container image scanning tool Grype, see 11 | # https://github.com/anchore/grype 12 | name: Anchore Grype vulnerability scan 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '42 8 * * 2' 22 | 23 | permissions: 24 | contents: read 25 | 26 | jobs: 27 | Anchore-Build-Scan: 28 | permissions: 29 | contents: read # for actions/checkout to fetch code 30 | security-events: write # for github/codeql-action/upload-sarif to upload SARIF results 31 | actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status 32 | runs-on: ubuntu-latest 33 | steps: 34 | - name: Check out the code 35 | uses: actions/checkout@v3 36 | - name: Build the Docker image 37 | run: docker build . --file Dockerfile --tag localbuild/testimage:latest 38 | - name: Run the Anchore Grype scan action 39 | uses: anchore/scan-action@d5aa5b6cb9414b0c7771438046ff5bcfa2854ed7 40 | id: scan 41 | with: 42 | image: "localbuild/testimage:latest" 43 | fail-build: true 44 | severity-cutoff: critical 45 | - name: Upload vulnerability report 46 | uses: github/codeql-action/upload-sarif@v2 47 | with: 48 | sarif_file: ${{ steps.scan.outputs.sarif }} 49 | -------------------------------------------------------------------------------- /.github/workflows/codacy.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 5 | 6 | # This workflow checks out code, performs a Codacy security scan 7 | # and integrates the results with the 8 | # GitHub Advanced Security code scanning feature. For more information on 9 | # the Codacy security scan action usage and parameters, see 10 | # https://github.com/codacy/codacy-analysis-cli-action. 11 | # For more information on Codacy Analysis CLI in general, see 12 | # https://github.com/codacy/codacy-analysis-cli. 
13 | 14 | name: Codacy Security Scan 15 | 16 | on: 17 | push: 18 | branches: [ "main" ] 19 | pull_request: 20 | # The branches below must be a subset of the branches above 21 | branches: [ "main" ] 22 | schedule: 23 | - cron: '25 4 * * 1' 24 | 25 | permissions: 26 | contents: read 27 | 28 | jobs: 29 | codacy-security-scan: 30 | permissions: 31 | contents: read # for actions/checkout to fetch code 32 | security-events: write # for github/codeql-action/upload-sarif to upload SARIF results 33 | actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status 34 | name: Codacy Security Scan 35 | runs-on: ubuntu-latest 36 | steps: 37 | # Checkout the repository to the GitHub Actions runner 38 | - name: Checkout code 39 | uses: actions/checkout@v3 40 | 41 | # Execute Codacy Analysis CLI and generate a SARIF output with the security issues identified during the analysis 42 | - name: Run Codacy Analysis CLI 43 | uses: codacy/codacy-analysis-cli-action@d840f886c4bd4edc059706d09c6a1586111c540b 44 | with: 45 | # Check https://github.com/codacy/codacy-analysis-cli#project-token to get your project token from your Codacy repository 46 | # You can also omit the token and run the tools that support default configurations 47 | project-token: ${{ secrets.CODACY_PROJECT_TOKEN }} 48 | verbose: true 49 | output: results.sarif 50 | format: sarif 51 | # Adjust severity of non-security issues 52 | gh-code-scanning-compat: true 53 | # Force 0 exit code to allow SARIF file generation 54 | # This will handover control about PR rejection to the GitHub side 55 | max-allowed-issues: 2147483647 56 | 57 | # Upload the SARIF file generated in the previous step 58 | - name: Upload SARIF results file 59 | uses: github/codeql-action/upload-sarif@v2 60 | with: 61 | sarif_file: results.sarif 62 | -------------------------------------------------------------------------------- /.github/workflows/devskim.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 
5 | 6 | name: DevSkim 7 | 8 | on: 9 | push: 10 | branches: [ "main" ] 11 | pull_request: 12 | branches: [ "main" ] 13 | schedule: 14 | - cron: '24 11 * * 0' 15 | 16 | jobs: 17 | lint: 18 | name: DevSkim 19 | runs-on: ubuntu-20.04 20 | permissions: 21 | actions: read 22 | contents: read 23 | security-events: write 24 | steps: 25 | - name: Checkout code 26 | uses: actions/checkout@v3 27 | 28 | - name: Run DevSkim scanner 29 | uses: microsoft/DevSkim-Action@v1 30 | 31 | - name: Upload DevSkim scan results to GitHub Security tab 32 | uses: github/codeql-action/upload-sarif@v2 33 | with: 34 | sarif_file: devskim-results.sarif 35 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.8", "3.9", "3.10"] 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install pylint 21 | - name: Analysing the code with pylint 22 | run: | 23 | pylint $(git ls-files '*.py') 24 | -------------------------------------------------------------------------------- /.github/workflows/pysa.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 5 | 6 | # This workflow integrates Python Static Analyzer (Pysa) with 7 | # GitHub's Code Scanning feature. 8 | # 9 | # Python Static Analyzer (Pysa) is a security-focused static 10 | # analysis tool that tracks flows of data from where they 11 | # originate to where they terminate in a dangerous location. 12 | # 13 | # See https://pyre-check.org/docs/pysa-basics/ 14 | 15 | name: Pysa 16 | 17 | on: 18 | workflow_dispatch: 19 | push: 20 | branches: [ "main" ] 21 | pull_request: 22 | branches: [ "main" ] 23 | schedule: 24 | - cron: '39 15 * * 0' 25 | 26 | permissions: 27 | contents: read 28 | 29 | jobs: 30 | pysa: 31 | permissions: 32 | actions: read 33 | contents: read 34 | security-events: write 35 | 36 | runs-on: ubuntu-latest 37 | steps: 38 | - uses: actions/checkout@v3 39 | with: 40 | submodules: true 41 | 42 | - name: Run Pysa 43 | uses: facebook/pysa-action@f46a63777e59268613bd6e2ff4e29f144ca9e88b 44 | with: 45 | # To customize these inputs: 46 | # See https://github.com/facebook/pysa-action#inputs 47 | repo-directory: './' 48 | requirements-path: 'requirements.txt' 49 | infer-types: true 50 | include-default-sapp-filters: true 51 | -------------------------------------------------------------------------------- /.github/workflows/semgrep.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 5 | 6 | # This workflow file requires a free account on Semgrep.dev to 7 | # manage rules, file ignores, notifications, and more. 
8 | # 9 | # See https://semgrep.dev/docs 10 | 11 | name: Semgrep 12 | 13 | on: 14 | push: 15 | branches: [ "main" ] 16 | pull_request: 17 | # The branches below must be a subset of the branches above 18 | branches: [ "main" ] 19 | schedule: 20 | - cron: '22 9 * * 3' 21 | 22 | permissions: 23 | contents: read 24 | 25 | jobs: 26 | semgrep: 27 | permissions: 28 | contents: read # for actions/checkout to fetch code 29 | security-events: write # for github/codeql-action/upload-sarif to upload SARIF results 30 | actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status 31 | name: Scan 32 | runs-on: ubuntu-latest 33 | steps: 34 | # Checkout project source 35 | - uses: actions/checkout@v3 36 | 37 | # Scan code using project's configuration on https://semgrep.dev/manage 38 | - uses: returntocorp/semgrep-action@fcd5ab7459e8d91cb1777481980d1b18b4fc6735 39 | with: 40 | publishToken: ${{ secrets.SEMGREP_APP_TOKEN }} 41 | publishDeployment: ${{ secrets.SEMGREP_DEPLOYMENT_ID }} 42 | generateSarif: "1" 43 | 44 | # Upload SARIF file generated in previous step 45 | - name: Upload SARIF file 46 | uses: github/codeql-action/upload-sarif@v2 47 | with: 48 | sarif_file: semgrep.sarif 49 | if: always() 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | 178 | .vscode/ 179 | .threestudio_cache/ 180 | outputs/ 181 | outputs-gradio/ 182 | 183 | # pretrained model weights 184 | *.ckpt 185 | *.pt 186 | *.pth 187 | 188 | # wandb 189 | wandb/ 190 | 191 | load/tets/256_tets.npz 192 | 193 | # dataset 194 | dataset/ 195 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: check-ast 10 | - id: check-merge-conflict 11 | - id: check-yaml 12 | - id: end-of-file-fixer 13 | - id: trailing-whitespace 14 | args: [--markdown-linebreak-ext=md] 15 | 16 | - repo: https://github.com/psf/black 17 | rev: 23.3.0 18 | hooks: 19 | - id: black 20 | language_version: python3 21 | 22 | - repo: https://github.com/pycqa/isort 23 | rev: 5.12.0 24 | hooks: 25 | - id: isort 26 | exclude: README.md 27 | args: ["--profile", "black"] 28 | 29 | # temporarily disable static type checking 30 | # - repo: https://github.com/pre-commit/mirrors-mypy 31 | # rev: v1.2.0 32 | # hooks: 33 | # - id: mypy 34 | # args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"] -------------------------------------------------------------------------------- /LICENSE-CODE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DeepSeek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/diagram-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/assets/diagram-1.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/assets/logo.png -------------------------------------------------------------------------------- /assets/mesh_meshlab.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/assets/mesh_meshlab.jpg -------------------------------------------------------------------------------- /assets/repo_demo_0.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/assets/repo_demo_0.mp4 -------------------------------------------------------------------------------- /assets/repo_demo_01.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/assets/repo_demo_01.mp4 -------------------------------------------------------------------------------- /assets/repo_demo_02.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/assets/repo_demo_02.mp4 -------------------------------------------------------------------------------- /assets/repo_static_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/assets/repo_static_v2.png -------------------------------------------------------------------------------- /assets/result_mushroom.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/assets/result_mushroom.mp4 -------------------------------------------------------------------------------- /configs/dreamcraft3d-coarse-nerf.yaml: -------------------------------------------------------------------------------- 1 | name: "dreamcraft3d-coarse-nerf" 2 | tag: "${rmspace:${system.prompt_processor.prompt},_}" 3 | exp_root_dir: "outputs" 4 | seed: 0 5 | 6 | data_type: "single-image-datamodule" 7 | data: 8 | image_path: ./load/images/hamburger_rgba.png 9 | height: [128, 384] 10 | width: [128, 384] 11 | resolution_milestones: [3000] 12 | default_elevation_deg: 0.0 13 | default_azimuth_deg: 0.0 14 | default_camera_distance: 3.8 15 | default_fovy_deg: 20.0 16 | requires_depth: true 17 | requires_normal: ${cmaxgt0:${system.loss.lambda_normal}} 18 | random_camera: 19 | height: [128, 384] 20 | width: [128, 384] 21 | batch_size: [1, 1] 22 | resolution_milestones: [3000] 23 | eval_height: 512 24 | eval_width: 512 25 | eval_batch_size: 1 26 | elevation_range: [-10, 45] 27 | azimuth_range: [-180, 180] 28 | camera_distance_range: [3.8, 3.8] 29 | fovy_range: [20.0, 20.0] # 
Zero123 has fixed fovy 30 | progressive_until: 200 31 | camera_perturb: 0.0 32 | center_perturb: 0.0 33 | up_perturb: 0.0 34 | eval_elevation_deg: ${data.default_elevation_deg} 35 | eval_camera_distance: ${data.default_camera_distance} 36 | eval_fovy_deg: ${data.default_fovy_deg} 37 | batch_uniform_azimuth: false 38 | n_val_views: 40 39 | n_test_views: 120 40 | 41 | system_type: "dreamcraft3d-system" 42 | system: 43 | stage: coarse 44 | geometry_type: "implicit-volume" 45 | geometry: 46 | radius: 2.0 47 | normal_type: "finite_difference" 48 | 49 | # the density initialization proposed in the DreamFusion paper 50 | # does not work very well 51 | # density_bias: "blob_dreamfusion" 52 | # density_activation: exp 53 | # density_blob_scale: 5. 54 | # density_blob_std: 0.2 55 | 56 | # use Magic3D density initialization instead 57 | density_bias: "blob_magic3d" 58 | density_activation: softplus 59 | density_blob_scale: 10. 60 | density_blob_std: 0.5 61 | 62 | # coarse to fine hash grid encoding 63 | # to ensure smooth analytic normals 64 | pos_encoding_config: 65 | otype: ProgressiveBandHashGrid 66 | n_levels: 16 67 | n_features_per_level: 2 68 | log2_hashmap_size: 19 69 | base_resolution: 16 70 | per_level_scale: 1.447269237440378 # max resolution 4096 71 | start_level: 8 # resolution ~200 72 | start_step: 2000 73 | update_steps: 500 74 | 75 | material_type: "no-material" 76 | material: 77 | requires_normal: true 78 | 79 | background_type: "solid-color-background" 80 | 81 | renderer_type: "nerf-volume-renderer" 82 | renderer: 83 | radius: ${system.geometry.radius} 84 | num_samples_per_ray: 512 85 | return_normal_perturb: true 86 | return_comp_normal: ${cmaxgt0:${system.loss.lambda_normal_smooth}} 87 | 88 | prompt_processor_type: "deep-floyd-prompt-processor" 89 | prompt_processor: 90 | pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" 91 | prompt: ??? 
92 | use_perp_neg: true 93 | 94 | guidance_type: "deep-floyd-guidance" 95 | guidance: 96 | pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" 97 | guidance_scale: 20 98 | min_step_percent: [0, 0.7, 0.2, 200] 99 | max_step_percent: [0, 0.85, 0.5, 200] 100 | 101 | guidance_3d_type: "stable-zero123-guidance" 102 | guidance_3d: 103 | pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt" 104 | pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml" 105 | cond_image_path: ${data.image_path} 106 | cond_elevation_deg: ${data.default_elevation_deg} 107 | cond_azimuth_deg: ${data.default_azimuth_deg} 108 | cond_camera_distance: ${data.default_camera_distance} 109 | guidance_scale: 5.0 110 | min_step_percent: [0, 0.7, 0.2, 200] # (start_iter, start_val, end_val, end_iter) 111 | max_step_percent: [0, 0.85, 0.5, 200] 112 | 113 | freq: 114 | n_ref: 2 115 | ref_only_steps: 0 116 | ref_or_guidance: "alternate" 117 | no_diff_steps: 0 118 | guidance_eval: 0 119 | 120 | loggers: 121 | wandb: 122 | enable: false 123 | project: "threestudio" 124 | 125 | loss: 126 | lambda_sd: 0.1 127 | lambda_3d_sd: 0.1 128 | lambda_rgb: 1000.0 129 | lambda_mask: 100.0 130 | lambda_mask_binary: 0.0 131 | lambda_depth: 0.0 132 | lambda_depth_rel: 0.05 133 | lambda_normal: 0.0 134 | lambda_normal_smooth: 1.0 135 | lambda_3d_normal_smooth: [2000, 5., 1., 2001] 136 | lambda_orient: [2000, 1., 10., 2001] 137 | lambda_sparsity: [2000, 0.1, 10., 2001] 138 | lambda_opaque: [2000, 0.1, 10., 2001] 139 | lambda_clip: 0.0 140 | 141 | optimizer: 142 | name: Adam 143 | args: 144 | lr: 0.01 145 | betas: [0.9, 0.99] 146 | eps: 1.e-8 147 | 148 | trainer: 149 | max_steps: 5000 150 | log_every_n_steps: 1 151 | num_sanity_val_steps: 0 152 | val_check_interval: 200 153 | enable_progress_bar: true 154 | precision: 16-mixed 155 | 156 | checkpoint: 157 | save_last: true 158 | save_top_k: -1 159 | every_n_train_steps: ${trainer.max_steps} -------------------------------------------------------------------------------- /configs/dreamcraft3d-coarse-neus.yaml: -------------------------------------------------------------------------------- 1 | name: "dreamcraft3d-coarse-neus" 2 | tag: "${rmspace:${system.prompt_processor.prompt},_}" 3 | exp_root_dir: "outputs" 4 | seed: 0 5 | 6 | data_type: "single-image-datamodule" 7 | data: 8 | image_path: ./load/images/hamburger_rgba.png 9 | height: 256 10 | width: 256 11 | default_elevation_deg: 0.0 12 | default_azimuth_deg: 0.0 13 | default_camera_distance: 3.8 14 | default_fovy_deg: 20.0 15 | requires_depth: true 16 | requires_normal: ${cmaxgt0:${system.loss.lambda_normal}} 17 | random_camera: 18 | height: 256 19 | width: 256 20 | batch_size: 1 21 | eval_height: 512 22 | eval_width: 512 23 | eval_batch_size: 1 24 | elevation_range: [-10, 45] 25 | azimuth_range: [-180, 180] 26 | camera_distance_range: [3.8, 3.8] 27 | fovy_range: [20.0, 20.0] # Zero123 has fixed fovy 28 | progressive_until: 0 29 | camera_perturb: 0.0 30 | center_perturb: 0.0 31 | up_perturb: 0.0 32 | eval_elevation_deg: ${data.default_elevation_deg} 33 | eval_camera_distance: ${data.default_camera_distance} 34 | eval_fovy_deg: ${data.default_fovy_deg} 35 | batch_uniform_azimuth: false 36 | n_val_views: 40 37 | n_test_views: 120 38 | 39 | system_type: "dreamcraft3d-system" 40 | system: 41 | stage: coarse 42 | geometry_type: "implicit-sdf" 43 | geometry: 44 | radius: 2.0 45 | normal_type: "finite_difference" 46 | 47 | sdf_bias: sphere 48 | sdf_bias_params: 0.5 49 | 50 | # coarse to fine hash grid encoding 51 | 
pos_encoding_config: 52 | otype: HashGrid 53 | n_levels: 16 54 | n_features_per_level: 2 55 | log2_hashmap_size: 19 56 | base_resolution: 16 57 | per_level_scale: 1.447269237440378 # max resolution 4096 58 | start_level: 8 # resolution ~200 59 | start_step: 2000 60 | update_steps: 500 61 | 62 | material_type: "no-material" 63 | material: 64 | requires_normal: true 65 | 66 | background_type: "solid-color-background" 67 | 68 | renderer_type: "neus-volume-renderer" 69 | renderer: 70 | radius: ${system.geometry.radius} 71 | num_samples_per_ray: 512 72 | cos_anneal_end_steps: ${trainer.max_steps} 73 | eval_chunk_size: 8192 74 | 75 | prompt_processor_type: "deep-floyd-prompt-processor" 76 | prompt_processor: 77 | pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" 78 | prompt: ??? 79 | use_perp_neg: true 80 | 81 | guidance_type: "deep-floyd-guidance" 82 | guidance: 83 | pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" 84 | guidance_scale: 20 85 | min_step_percent: 0.2 86 | max_step_percent: 0.5 87 | 88 | guidance_3d_type: "stable-zero123-guidance" 89 | guidance_3d: 90 | pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt" 91 | pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml" 92 | cond_image_path: ${data.image_path} 93 | cond_elevation_deg: ${data.default_elevation_deg} 94 | cond_azimuth_deg: ${data.default_azimuth_deg} 95 | cond_camera_distance: ${data.default_camera_distance} 96 | guidance_scale: 5.0 97 | min_step_percent: 0.2 98 | max_step_percent: 0.5 99 | 100 | freq: 101 | n_ref: 2 102 | ref_only_steps: 0 103 | ref_or_guidance: "alternate" 104 | no_diff_steps: 0 105 | guidance_eval: 0 106 | 107 | loggers: 108 | wandb: 109 | enable: false 110 | project: "threestudio" 111 | 112 | loss: 113 | lambda_sd: 0.1 114 | lambda_3d_sd: 0.1 115 | lambda_rgb: 1000.0 116 | lambda_mask: 100.0 117 | lambda_mask_binary: 0.0 118 | lambda_depth: 0.0 119 | lambda_depth_rel: 0.05 120 | lambda_normal: 0.0 121 | lambda_normal_smooth: 0.0 122 | lambda_3d_normal_smooth: 0.0 123 | lambda_orient: 10.0 124 | lambda_sparsity: 0.1 125 | lambda_opaque: 0.1 126 | lambda_clip: 0.0 127 | lambda_eikonal: 0.0 128 | 129 | optimizer: 130 | name: Adam 131 | args: 132 | betas: [0.9, 0.99] 133 | eps: 1.e-15 134 | params: 135 | geometry.encoding: 136 | lr: 0.01 137 | geometry.sdf_network: 138 | lr: 0.001 139 | geometry.feature_network: 140 | lr: 0.001 141 | renderer: 142 | lr: 0.001 143 | 144 | trainer: 145 | max_steps: 5000 146 | log_every_n_steps: 1 147 | num_sanity_val_steps: 0 148 | val_check_interval: 200 149 | enable_progress_bar: true 150 | precision: 16-mixed 151 | 152 | checkpoint: 153 | save_last: true 154 | save_top_k: -1 155 | every_n_train_steps: ${trainer.max_steps} -------------------------------------------------------------------------------- /configs/dreamcraft3d-geometry.yaml: -------------------------------------------------------------------------------- 1 | name: "dreamcraft3d-geometry" 2 | tag: "${rmspace:${system.prompt_processor.prompt},_}" 3 | exp_root_dir: "outputs" 4 | seed: 0 5 | 6 | data_type: "single-image-datamodule" 7 | data: 8 | image_path: ./load/images/hamburger_rgba.png 9 | height: 1024 10 | width: 1024 11 | default_elevation_deg: 0.0 12 | default_azimuth_deg: 0.0 13 | default_camera_distance: 3.8 14 | default_fovy_deg: 20.0 15 | requires_depth: ${cmaxgt0orcmaxgt0:${system.loss.lambda_depth},${system.loss.lambda_depth_rel}} 16 | requires_normal: ${cmaxgt0:${system.loss.lambda_normal}} 17 | use_mixed_camera_config: false 18 | random_camera: 19 | 
height: 1024 20 | width: 1024 21 | batch_size: 1 22 | eval_height: 1024 23 | eval_width: 1024 24 | eval_batch_size: 1 25 | elevation_range: [-10, 45] 26 | azimuth_range: [-180, 180] 27 | camera_distance_range: [3.8, 3.8] 28 | fovy_range: [20.0, 20.0] # Zero123 has fixed fovy 29 | progressive_until: 0 30 | camera_perturb: 0.0 31 | center_perturb: 0.0 32 | up_perturb: 0.0 33 | eval_elevation_deg: ${data.default_elevation_deg} 34 | eval_camera_distance: ${data.default_camera_distance} 35 | eval_fovy_deg: ${data.default_fovy_deg} 36 | batch_uniform_azimuth: false 37 | n_val_views: 40 38 | n_test_views: 120 39 | 40 | system_type: "dreamcraft3d-system" 41 | system: 42 | stage: geometry 43 | use_mixed_camera_config: ${data.use_mixed_camera_config} 44 | geometry_convert_from: ??? 45 | geometry_convert_inherit_texture: true 46 | geometry_type: "tetrahedra-sdf-grid" 47 | geometry: 48 | radius: 2.0 # consistent with coarse 49 | isosurface_resolution: 128 50 | isosurface_deformable_grid: true 51 | 52 | material_type: "no-material" 53 | material: 54 | n_output_dims: 3 55 | 56 | background_type: "solid-color-background" 57 | 58 | renderer_type: "nvdiff-rasterizer" 59 | renderer: 60 | context_type: cuda 61 | 62 | prompt_processor_type: "deep-floyd-prompt-processor" 63 | prompt_processor: 64 | pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" 65 | prompt: ??? 66 | use_perp_neg: true 67 | 68 | guidance_type: "deep-floyd-guidance" 69 | guidance: 70 | pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0" 71 | guidance_scale: 20 72 | min_step_percent: 0.02 73 | max_step_percent: 0.5 74 | 75 | guidance_3d_type: "stable-zero123-guidance" 76 | guidance_3d: 77 | pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt" 78 | pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml" 79 | cond_image_path: ${data.image_path} 80 | cond_elevation_deg: ${data.default_elevation_deg} 81 | cond_azimuth_deg: ${data.default_azimuth_deg} 82 | cond_camera_distance: ${data.default_camera_distance} 83 | guidance_scale: 5.0 84 | min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter) 85 | max_step_percent: 0.5 86 | 87 | freq: 88 | n_ref: 2 89 | ref_only_steps: 0 90 | ref_or_guidance: "accumulate" 91 | no_diff_steps: 0 92 | guidance_eval: 0 93 | n_rgb: 4 94 | 95 | loggers: 96 | wandb: 97 | enable: false 98 | project: "threestudio" 99 | 100 | loss: 101 | lambda_sd: 0.1 102 | lambda_3d_sd: 0.1 103 | lambda_rgb: 1000.0 104 | lambda_mask: 100.0 105 | lambda_mask_binary: 0.0 106 | lambda_depth: 0.0 107 | lambda_depth_rel: 0.0 108 | lambda_normal: 0.0 109 | lambda_normal_smooth: 0. 110 | lambda_3d_normal_smooth: 0. 
111 | lambda_normal_consistency: [1000,10.0,1,2000] 112 | lambda_laplacian_smoothness: 0.0 113 | 114 | optimizer: 115 | name: Adam 116 | args: 117 | lr: 0.005 118 | betas: [0.9, 0.99] 119 | eps: 1.e-15 120 | 121 | trainer: 122 | max_steps: 5000 123 | log_every_n_steps: 1 124 | num_sanity_val_steps: 0 125 | val_check_interval: 200 126 | enable_progress_bar: true 127 | precision: 32 128 | strategy: "ddp_find_unused_parameters_true" 129 | 130 | checkpoint: 131 | save_last: true 132 | save_top_k: -1 133 | every_n_train_steps: ${trainer.max_steps} -------------------------------------------------------------------------------- /configs/dreamcraft3d-texture.yaml: -------------------------------------------------------------------------------- 1 | name: "dreamcraft3d-texture" 2 | tag: "${rmspace:${system.prompt_processor.prompt},_}" 3 | exp_root_dir: "outputs" 4 | seed: 0 5 | 6 | data_type: "single-image-datamodule" 7 | data: 8 | image_path: ./load/images/hamburger_rgba.png 9 | height: 1024 10 | width: 1024 11 | default_elevation_deg: 0.0 12 | default_azimuth_deg: 0.0 13 | default_camera_distance: 3.8 14 | default_fovy_deg: 20.0 15 | requires_depth: false 16 | requires_normal: false 17 | use_mixed_camera_config: false 18 | random_camera: 19 | height: 1024 20 | width: 1024 21 | batch_size: 1 22 | eval_height: 1024 23 | eval_width: 1024 24 | eval_batch_size: 1 25 | elevation_range: [-10, 45] 26 | azimuth_range: [-180, 180] 27 | camera_distance_range: [3.8, 3.8] 28 | fovy_range: [20.0, 20.0] # Zero123 has fixed fovy 29 | progressive_until: 0 30 | camera_perturb: 0.0 31 | center_perturb: 0.0 32 | up_perturb: 0.0 33 | eval_elevation_deg: ${data.default_elevation_deg} 34 | eval_camera_distance: ${data.default_camera_distance} 35 | eval_fovy_deg: ${data.default_fovy_deg} 36 | batch_uniform_azimuth: false 37 | n_val_views: 40 38 | n_test_views: 120 39 | 40 | system_type: "dreamcraft3d-system" 41 | system: 42 | stage: texture 43 | use_mixed_camera_config: ${data.use_mixed_camera_config} 44 | geometry_convert_from: ??? 45 | geometry_convert_inherit_texture: true 46 | geometry_type: "tetrahedra-sdf-grid" 47 | geometry: 48 | radius: 2.0 # consistent with coarse 49 | isosurface_resolution: 128 50 | isosurface_deformable_grid: true 51 | isosurface_remove_outliers: true 52 | pos_encoding_config: 53 | otype: HashGrid 54 | n_levels: 16 55 | n_features_per_level: 2 56 | log2_hashmap_size: 19 57 | base_resolution: 16 58 | per_level_scale: 1.447269237440378 # max resolution 4096 59 | fix_geometry: true 60 | 61 | material_type: "no-material" 62 | material: 63 | n_output_dims: 3 64 | 65 | background_type: "solid-color-background" 66 | 67 | renderer_type: "nvdiff-rasterizer" 68 | renderer: 69 | context_type: cuda 70 | 71 | prompt_processor_type: "stable-diffusion-prompt-processor" 72 | prompt_processor: 73 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 74 | prompt: ??? 75 | front_threshold: 30. 76 | back_threshold: 30. 
77 | 78 | guidance_type: "stable-diffusion-bsd-guidance" 79 | guidance: 80 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 81 | pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1-base" 82 | # pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1" 83 | guidance_scale: 2.0 84 | min_step_percent: 0.05 85 | max_step_percent: [0, 0.5, 0.2, 5000] 86 | only_pretrain_step: 1000 87 | 88 | # guidance_3d_type: "stable-zero123-guidance" 89 | # guidance_3d: 90 | # pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt" 91 | # pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml" 92 | # cond_image_path: ${data.image_path} 93 | # cond_elevation_deg: ${data.default_elevation_deg} 94 | # cond_azimuth_deg: ${data.default_azimuth_deg} 95 | # cond_camera_distance: ${data.default_camera_distance} 96 | # guidance_scale: 5.0 97 | # min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter) 98 | # max_step_percent: 0.5 99 | 100 | # control_guidance_type: "stable-diffusion-controlnet-reg-guidance" 101 | # control_guidance: 102 | # min_step_percent: 0.1 103 | # max_step_percent: 0.5 104 | # control_prompt_processor_type: "stable-diffusion-prompt-processor" 105 | # control_prompt_processor: 106 | # pretrained_model_name_or_path: "SG161222/Realistic_Vision_V2.0" 107 | # prompt: ${system.prompt_processor.prompt} 108 | # front_threshold: 30. 109 | # back_threshold: 30. 110 | 111 | freq: 112 | n_ref: 2 113 | ref_only_steps: 0 114 | ref_or_guidance: "alternate" 115 | no_diff_steps: -1 116 | guidance_eval: 0 117 | 118 | loggers: 119 | wandb: 120 | enable: false 121 | project: "threestudio" 122 | 123 | loss: 124 | lambda_sd: 0.01 125 | lambda_lora: 0.1 126 | lambda_pretrain: 0.1 127 | lambda_3d_sd: 0.0 128 | lambda_rgb: 1000. 129 | lambda_mask: 100. 
130 | lambda_mask_binary: 0.0 131 | lambda_depth: 0.0 132 | lambda_depth_rel: 0.0 133 | lambda_normal: 0.0 134 | lambda_normal_smooth: 0.0 135 | lambda_3d_normal_smooth: 0.0 136 | lambda_z_variance: 0.0 137 | lambda_reg: 0.0 138 | 139 | optimizer: 140 | name: AdamW 141 | args: 142 | betas: [0.9, 0.99] 143 | eps: 1.e-4 144 | params: 145 | geometry.encoding: 146 | lr: 0.01 147 | geometry.feature_network: 148 | lr: 0.001 149 | guidance.train_unet: 150 | lr: 0.00001 151 | guidance.train_unet_lora: 152 | lr: 0.00001 153 | 154 | trainer: 155 | max_steps: 5000 156 | log_every_n_steps: 1 157 | num_sanity_val_steps: 0 158 | val_check_interval: 200 159 | enable_progress_bar: true 160 | precision: 32 161 | strategy: "ddp_find_unused_parameters_true" 162 | 163 | checkpoint: 164 | save_last: true 165 | save_top_k: -1 166 | every_n_train_steps: ${trainer.max_steps} -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Reference: 2 | # https://github.com/cvpaperchallenge/Ascender 3 | # https://github.com/nerfstudio-project/nerfstudio 4 | 5 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 6 | 7 | ARG USER_NAME=dreamer 8 | ARG GROUP_NAME=dreamers 9 | ARG UID=1000 10 | ARG GID=1000 11 | 12 | # Set compute capability for nerfacc and tiny-cuda-nn 13 | # See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build 14 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX" 15 | ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60 16 | # Speed-up build for RTX 30xx 17 | # ENV TORCH_CUDA_ARCH_LIST="8.6" 18 | # ENV TCNN_CUDA_ARCHITECTURES=86 19 | # Speed-up build for RTX 40xx 20 | # ENV TORCH_CUDA_ARCH_LIST="8.9" 21 | # ENV TCNN_CUDA_ARCHITECTURES=89 22 | 23 | ENV CUDA_HOME=/usr/local/cuda 24 | ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH} 25 | ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 26 | ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH} 27 | 28 | # apt install by root user 29 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 30 | build-essential \ 31 | curl \ 32 | git \ 33 | libegl1-mesa-dev \ 34 | libgl1-mesa-dev \ 35 | libgles2-mesa-dev \ 36 | libglib2.0-0 \ 37 | libsm6 \ 38 | libxext6 \ 39 | libxrender1 \ 40 | python-is-python3 \ 41 | python3.10-dev \ 42 | python3-pip \ 43 | wget \ 44 | && rm -rf /var/lib/apt/lists/* 45 | 46 | # Change user to non-root user 47 | RUN groupadd -g ${GID} ${GROUP_NAME} \ 48 | && useradd -ms /bin/sh -u ${UID} -g ${GID} ${USER_NAME} 49 | USER ${USER_NAME} 50 | 51 | RUN pip install --upgrade pip setuptools ninja 52 | RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 53 | # Install nerfacc and tiny-cuda-nn before installing requirements.txt 54 | # because these two installations are time consuming and error prone 55 | RUN pip install git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2 56 | RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch 57 | 58 | COPY requirements.txt /tmp 59 | RUN cd /tmp && pip install -r requirements.txt 60 | WORKDIR /home/${USER_NAME}/threestudio 61 | -------------------------------------------------------------------------------- /docker/compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | threestudio: 3 | build: 4 | context: ../ 5 | dockerfile: 
docker/Dockerfile 6 | args: 7 | # you can set environment variables, otherwise default values will be used 8 | USER_NAME: ${HOST_USER_NAME:-dreamer} # export HOST_USER_NAME=$USER 9 | GROUP_NAME: ${HOST_GROUP_NAME:-dreamers} 10 | UID: ${HOST_UID:-1000} # export HOST_UID=$(id -u) 11 | GID: ${HOST_GID:-1000} # export HOST_GID=$(id -g) 12 | shm_size: '4gb' 13 | environment: 14 | NVIDIA_DISABLE_REQUIRE: 1 # avoid wrong `nvidia-container-cli: requirement error` 15 | tty: true 16 | volumes: 17 | - ../:/home/${HOST_USER_NAME:-dreamer}/threestudio 18 | deploy: 19 | resources: 20 | reservations: 21 | devices: 22 | - driver: nvidia 23 | capabilities: [gpu] 24 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Prerequisite 4 | 5 | - NVIDIA GPU with at least 6GB VRAM. The more memory you have, the more methods and higher resolutions you can try. 6 | - [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx) whose version is higher than the [Minimum Required Driver Version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html) of CUDA Toolkit you want to use. 7 | 8 | ## Install CUDA Toolkit 9 | 10 | You can skip this step if you have installed sufficiently new version or you use Docker. 11 | 12 | Install [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive). 13 | 14 | - Example for Ubuntu 22.04: 15 | - Run [command for CUDA 11.8 Ubuntu 22.04](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=deb_local) 16 | - Example for Ubuntu on WSL2: 17 | - `sudo apt-key del 7fa2af80` 18 | - Run [command for CUDA 11.8 WSL-Ubuntu](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=deb_local) 19 | 20 | ## Git Clone 21 | 22 | ```bash 23 | git clone https://github.com/threestudio-project/threestudio.git 24 | cd threestudio/ 25 | ``` 26 | 27 | ## Install threestudio via Docker 28 | 29 | 1. [Install Docker Engine](https://docs.docker.com/engine/install/). 30 | This document assumes you [install Docker Engine on Ubuntu](https://docs.docker.com/engine/install/ubuntu/). 31 | 2. [Create `docker` group](https://docs.docker.com/engine/install/linux-postinstall/). 32 | Otherwise, you need to type `sudo docker` instead of `docker`. 33 | 3. [Install NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html#setting-up-nvidia-container-toolkit). 34 | 4. If you use WSL2, [enable systemd](https://learn.microsoft.com/en-us/windows/wsl/wsl-config#systemd-support). 35 | 5. Edit [Dockerfile](../docker/Dockerfile) for your GPU to speed-up build. 36 | The default Dockerfile takes into account many types of GPUs. 37 | 6. Run Docker via `docker compose`. 38 | 39 | ```bash 40 | cd docker/ 41 | docker compose build # build Docker image 42 | docker compose up -d # create and start a container in background 43 | docker compose exec threestudio bash # run bash in the container 44 | 45 | # Enjoy threestudio! 46 | 47 | exit # or Ctrl+D 48 | docker compose stop # stop the container 49 | docker compose start # start the container 50 | docker compose down # stop and remove the container 51 | ``` 52 | 53 | Note: The current Dockerfile will cause errors when using the OpenGL-based rasterizer of nvdiffrast. 
54 | You can use the CUDA-based rasterizer by adding commands or editing configs. 55 | 56 | - `system.renderer.context_type=cuda` for training 57 | - `system.exporter.context_type=cuda` for exporting meshes 58 | 59 | [This comment by the nvdiffrast author](https://github.com/NVlabs/nvdiffrast/issues/94#issuecomment-1288566038) could be a guide to resolve this limitation. 60 | -------------------------------------------------------------------------------- /extern/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/extern/__init__.py -------------------------------------------------------------------------------- /extern/ldm_zero123/extras.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import contextmanager 3 | from pathlib import Path 4 | 5 | import torch 6 | from omegaconf import OmegaConf 7 | 8 | from extern.ldm_zero123.util import instantiate_from_config 9 | 10 | 11 | @contextmanager 12 | def all_logging_disabled(highest_level=logging.CRITICAL): 13 | """ 14 | A context manager that will prevent any logging messages 15 | triggered during the body from being processed. 16 | 17 | :param highest_level: the maximum logging level in use. 18 | This would only need to be changed if a custom level greater than CRITICAL 19 | is defined. 20 | 21 | https://gist.github.com/simon-weber/7853144 22 | """ 23 | # two kind-of hacks here: 24 | # * can't get the highest logging level in effect => delegate to the user 25 | # * can't get the current module-level override => use an undocumented 26 | # (but non-private!) interface 27 | 28 | previous_level = logging.root.manager.disable 29 | 30 | logging.disable(highest_level) 31 | 32 | try: 33 | yield 34 | finally: 35 | logging.disable(previous_level) 36 | 37 | 38 | def load_training_dir(train_dir, device, epoch="last"): 39 | """Load a checkpoint and config from training directory""" 40 | train_dir = Path(train_dir) 41 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) 42 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" 43 | config = list(train_dir.rglob(f"*-project.yaml")) 44 | assert len(ckpt) > 0, f"didn't find any config in {train_dir}" 45 | if len(config) > 1: 46 | print(f"found {len(config)} matching config files") 47 | config = sorted(config)[-1] 48 | print(f"selecting {config}") 49 | else: 50 | config = config[0] 51 | 52 | config = OmegaConf.load(config) 53 | return load_model_from_config(config, ckpt[0], device) 54 | 55 | 56 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 57 | """Loads a model from config and a ckpt 58 | if config is a path will use omegaconf to load 59 | """ 60 | if isinstance(config, (str, Path)): 61 | config = OmegaConf.load(config) 62 | 63 | with all_logging_disabled(): 64 | print(f"Loading model from {ckpt}") 65 | pl_sd = torch.load(ckpt, map_location="cpu") 66 | global_step = pl_sd["global_step"] 67 | sd = pl_sd["state_dict"] 68 | model = instantiate_from_config(config.model) 69 | m, u = model.load_state_dict(sd, strict=False) 70 | if len(m) > 0 and verbose: 71 | print("missing keys:") 72 | print(m) 73 | if len(u) > 0 and verbose: 74 | print("unexpected keys:") 75 | model.to(device) 76 | model.eval() 77 | model.cond_stage_model.device = device 78 | return model 79 | -------------------------------------------------------------------------------- 
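A minimal usage sketch for `load_model_from_config` defined in `extern/ldm_zero123/extras.py` above. This call site is an assumption and not code from the repository; only the function signature comes from `extras.py`, and the config/checkpoint paths are the ones referenced by the DreamCraft3D configs (`./load/zero123/sd-objaverse-finetune-c_concat-256.yaml`, `./load/zero123/stable_zero123.ckpt`).

```python
# Hypothetical usage sketch (not part of the repo): load the Zero123 model
# used as 3D guidance, via the loader defined in extern/ldm_zero123/extras.py.
import torch

from extern.ldm_zero123.extras import load_model_from_config

# Paths taken from configs/dreamcraft3d-*.yaml; adjust if your checkpoints live elsewhere.
config_path = "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
ckpt_path = "./load/zero123/stable_zero123.ckpt"

device = "cuda" if torch.cuda.is_available() else "cpu"

# `config` may be a path (loaded internally via OmegaConf) or an already-loaded
# OmegaConf object; the function returns the model in eval mode on `device`.
model = load_model_from_config(config_path, ckpt_path, device=device, verbose=True)
```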
/extern/ldm_zero123/guidance.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import List, Tuple 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from IPython.display import clear_output 8 | from scipy import interpolate 9 | 10 | 11 | class GuideModel(torch.nn.Module, abc.ABC): 12 | def __init__(self) -> None: 13 | super().__init__() 14 | 15 | @abc.abstractmethod 16 | def preprocess(self, x_img): 17 | pass 18 | 19 | @abc.abstractmethod 20 | def compute_loss(self, inp): 21 | pass 22 | 23 | 24 | class Guider(torch.nn.Module): 25 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False): 26 | """Apply classifier guidance 27 | 28 | Specify a guidance scale as either a scalar 29 | Or a schedule as a list of tuples t = 0->1 and scale, e.g. 30 | [(0, 10), (0.5, 20), (1, 50)] 31 | """ 32 | super().__init__() 33 | self.sampler = sampler 34 | self.index = 0 35 | self.show = verbose 36 | self.guide_model = guide_model 37 | self.history = [] 38 | 39 | if isinstance(scale, (Tuple, List)): 40 | times = np.array([x[0] for x in scale]) 41 | values = np.array([x[1] for x in scale]) 42 | self.scale_schedule = {"times": times, "values": values} 43 | else: 44 | self.scale_schedule = float(scale) 45 | 46 | self.ddim_timesteps = sampler.ddim_timesteps 47 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps 48 | 49 | def get_scales(self): 50 | if isinstance(self.scale_schedule, float): 51 | return len(self.ddim_timesteps) * [self.scale_schedule] 52 | 53 | interpolater = interpolate.interp1d( 54 | self.scale_schedule["times"], self.scale_schedule["values"] 55 | ) 56 | fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps 57 | return interpolater(fractional_steps) 58 | 59 | def modify_score(self, model, e_t, x, t, c): 60 | # TODO look up index by t 61 | scale = self.get_scales()[self.index] 62 | 63 | if scale == 0: 64 | return e_t 65 | 66 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device) 67 | with torch.enable_grad(): 68 | x_in = x.detach().requires_grad_(True) 69 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t) 70 | x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0) 71 | 72 | inp = self.guide_model.preprocess(x_img) 73 | loss = self.guide_model.compute_loss(inp) 74 | grads = torch.autograd.grad(loss.sum(), x_in)[0] 75 | correction = grads * scale 76 | 77 | if self.show: 78 | clear_output(wait=True) 79 | print( 80 | loss.item(), 81 | scale, 82 | correction.abs().max().item(), 83 | e_t.abs().max().item(), 84 | ) 85 | self.history.append( 86 | [ 87 | loss.item(), 88 | scale, 89 | correction.min().item(), 90 | correction.max().item(), 91 | ] 92 | ) 93 | plt.imshow( 94 | (inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2 95 | ) 96 | plt.axis("off") 97 | plt.show() 98 | plt.imshow(correction[0][0].detach().cpu()) 99 | plt.axis("off") 100 | plt.show() 101 | 102 | e_t_mod = e_t - sqrt_1ma * correction 103 | if self.show: 104 | fig, axs = plt.subplots(1, 3) 105 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2) 106 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2) 107 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2) 108 | plt.show() 109 | self.index += 1 110 | return e_t_mod 111 | -------------------------------------------------------------------------------- /extern/ldm_zero123/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | 
import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n, **kwargs): 48 | return self.schedule(n, **kwargs) 49 | 50 | 51 | class LambdaWarmUpCosineScheduler2: 52 | """ 53 | supports repeated iterations, configurable via lists 54 | note: use with a base_lr of 1.0. 55 | """ 56 | 57 | def __init__( 58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 59 | ): 60 | assert ( 61 | len(warm_up_steps) 62 | == len(f_min) 63 | == len(f_max) 64 | == len(f_start) 65 | == len(cycle_lengths) 66 | ) 67 | self.lr_warm_up_steps = warm_up_steps 68 | self.f_start = f_start 69 | self.f_min = f_min 70 | self.f_max = f_max 71 | self.cycle_lengths = cycle_lengths 72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 73 | self.last_f = 0.0 74 | self.verbosity_interval = verbosity_interval 75 | 76 | def find_in_interval(self, n): 77 | interval = 0 78 | for cl in self.cum_cycles[1:]: 79 | if n <= cl: 80 | return interval 81 | interval += 1 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: 88 | print( 89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 90 | f"current cycle {cycle}" 91 | ) 92 | if n < self.lr_warm_up_steps[cycle]: 93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 94 | cycle 95 | ] * n + self.f_start[cycle] 96 | self.last_f = f 97 | return f 98 | else: 99 | t = (n - self.lr_warm_up_steps[cycle]) / ( 100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] 101 | ) 102 | t = min(t, 1.0) 103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 104 | 1 + np.cos(t * np.pi) 105 | ) 106 | self.last_f = f 107 | return f 108 | 109 | def __call__(self, n, **kwargs): 110 | return self.schedule(n, **kwargs) 111 | 112 | 113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 114 | def schedule(self, n, **kwargs): 115 | cycle = self.find_in_interval(n) 116 | n = n - self.cum_cycles[cycle] 117 | if self.verbosity_interval > 0: 118 | if n % self.verbosity_interval == 0: 119 | print( 120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 121 | f"current cycle {cycle}" 122 | ) 123 | 124 | if n < self.lr_warm_up_steps[cycle]: 125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 126 | cycle 127 | ] * n + self.f_start[cycle] 128 | 
self.last_f = f 129 | return f 130 | else: 131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 132 | self.cycle_lengths[cycle] - n 133 | ) / (self.cycle_lengths[cycle]) 134 | self.last_f = f 135 | return f 136 | -------------------------------------------------------------------------------- /extern/ldm_zero123/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/extern/ldm_zero123/models/diffusion/__init__.py -------------------------------------------------------------------------------- /extern/ldm_zero123/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError( 11 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 12 | ) 13 | return x[(...,) + (None,) * dims_to_append] 14 | 15 | 16 | def renorm_thresholding(x0, value): 17 | # renorm 18 | pred_max = x0.max() 19 | pred_min = x0.min() 20 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 21 | pred_x0 = 2 * pred_x0 - 1.0 # -1 ... 1 22 | 23 | s = torch.quantile(rearrange(pred_x0, "b ... -> b (...)").abs(), value, dim=-1) 24 | s.clamp_(min=1.0) 25 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) 26 | 27 | # clip by threshold 28 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max 29 | 30 | # temporary hack: numpy on cpu 31 | pred_x0 = ( 32 | np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) 33 | / s.cpu().numpy() 34 | ) 35 | pred_x0 = torch.tensor(pred_x0).to(self.model.device) 36 | 37 | # re.renorm 38 | pred_x0 = (pred_x0 + 1.0) / 2.0 # 0 ... 
1 39 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range 40 | return pred_x0 41 | 42 | 43 | def norm_thresholding(x0, value): 44 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 45 | return x0 * (value / s) 46 | 47 | 48 | def spatial_norm_thresholding(x0, value): 49 | # b c h w 50 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 51 | return x0 * (value / s) 52 | -------------------------------------------------------------------------------- /extern/ldm_zero123/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/extern/ldm_zero123/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /extern/ldm_zero123/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/extern/ldm_zero123/modules/distributions/__init__.py -------------------------------------------------------------------------------- /extern/ldm_zero123/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 
79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /extern/ldm_zero123/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_( 47 | one_minus_decay * (shadow_params[sname] - m_param[key]) 48 | ) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 
77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /extern/ldm_zero123/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/extern/ldm_zero123/modules/encoders/__init__.py -------------------------------------------------------------------------------- /extern/ldm_zero123/modules/evaluate/ssim.py: -------------------------------------------------------------------------------- 1 | # MIT Licence 2 | 3 | # Methods to predict the SSIM, taken from 4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 5 | 6 | from math import exp 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | 13 | def gaussian(window_size, sigma): 14 | gauss = torch.Tensor( 15 | [ 16 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 17 | for x in range(window_size) 18 | ] 19 | ) 20 | return gauss / gauss.sum() 21 | 22 | 23 | def create_window(window_size, channel): 24 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 25 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 26 | window = Variable( 27 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 28 | ) 29 | return window 30 | 31 | 32 | def _ssim(img1, img2, window, window_size, channel, mask=None, size_average=True): 33 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 34 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 35 | 36 | mu1_sq = mu1.pow(2) 37 | mu2_sq = mu2.pow(2) 38 | mu1_mu2 = mu1 * mu2 39 | 40 | sigma1_sq = ( 41 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 42 | ) 43 | sigma2_sq = ( 44 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 45 | ) 46 | sigma12 = ( 47 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 48 | - mu1_mu2 49 | ) 50 | 51 | C1 = (0.01) ** 2 52 | C2 = (0.03) ** 2 53 | 54 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 55 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 56 | ) 57 | 58 | if not (mask is None): 59 | b = mask.size(0) 60 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask 61 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(dim=1).clamp( 62 | min=1 63 | ) 64 | return ssim_map 65 | 66 | import pdb 67 | 68 | pdb.set_trace 69 | 70 | if size_average: 71 | return ssim_map.mean() 72 | else: 73 | return ssim_map.mean(1).mean(1).mean(1) 74 | 75 | 76 | class SSIM(torch.nn.Module): 77 | def __init__(self, window_size=11, size_average=True): 78 | super(SSIM, self).__init__() 79 | self.window_size = window_size 80 | self.size_average = size_average 81 | self.channel = 1 82 | self.window = create_window(window_size, self.channel) 83 | 84 | def forward(self, img1, img2, mask=None): 85 | (_, channel, _, _) = img1.size() 86 | 87 | if channel == self.channel and self.window.data.type() == img1.data.type(): 88 | window = self.window 89 | else: 90 | window = create_window(self.window_size, channel) 91 | 92 | if img1.is_cuda: 93 | window = window.cuda(img1.get_device()) 94 | window = window.type_as(img1) 
95 | 96 | self.window = window 97 | self.channel = channel 98 | 99 | return _ssim( 100 | img1, 101 | img2, 102 | window, 103 | self.window_size, 104 | channel, 105 | mask, 106 | self.size_average, 107 | ) 108 | 109 | 110 | def ssim(img1, img2, window_size=11, mask=None, size_average=True): 111 | (_, channel, _, _) = img1.size() 112 | window = create_window(window_size, channel) 113 | 114 | if img1.is_cuda: 115 | window = window.cuda(img1.get_device()) 116 | window = window.type_as(img1) 117 | 118 | return _ssim(img1, img2, window, window_size, channel, mask, size_average) 119 | -------------------------------------------------------------------------------- /extern/ldm_zero123/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from extern.ldm_zero123.modules.image_degradation.bsrgan import ( 2 | degradation_bsrgan_variant as degradation_fn_bsr, 3 | ) 4 | from extern.ldm_zero123.modules.image_degradation.bsrgan_light import ( 5 | degradation_bsrgan_variant as degradation_fn_bsr_light, 6 | ) 7 | -------------------------------------------------------------------------------- /extern/ldm_zero123/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/extern/ldm_zero123/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /extern/ldm_zero123/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from extern.ldm_zero123.modules.losses.contperceptual import LPIPSWithDiscriminator 2 | -------------------------------------------------------------------------------- /extern/ldm_zero123/thirdp/psp/helpers.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch.nn import ( 7 | AdaptiveAvgPool2d, 8 | BatchNorm2d, 9 | Conv2d, 10 | MaxPool2d, 11 | Module, 12 | PReLU, 13 | ReLU, 14 | Sequential, 15 | Sigmoid, 16 | ) 17 | 18 | """ 19 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 20 | """ 21 | 22 | 23 | class Flatten(Module): 24 | def forward(self, input): 25 | return input.view(input.size(0), -1) 26 | 27 | 28 | def l2_norm(input, axis=1): 29 | norm = torch.norm(input, 2, axis, True) 30 | output = torch.div(input, norm) 31 | return output 32 | 33 | 34 | class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])): 35 | """A named tuple describing a ResNet block.""" 36 | 37 | 38 | def get_block(in_channel, depth, num_units, stride=2): 39 | return [Bottleneck(in_channel, depth, stride)] + [ 40 | Bottleneck(depth, depth, 1) for i in range(num_units - 1) 41 | ] 42 | 43 | 44 | def get_blocks(num_layers): 45 | if num_layers == 50: 46 | blocks = [ 47 | get_block(in_channel=64, depth=64, num_units=3), 48 | get_block(in_channel=64, depth=128, num_units=4), 49 | get_block(in_channel=128, depth=256, num_units=14), 50 | get_block(in_channel=256, depth=512, num_units=3), 51 | ] 52 | elif num_layers == 100: 53 | blocks = [ 54 | get_block(in_channel=64, depth=64, num_units=3), 55 | get_block(in_channel=64, depth=128, num_units=13), 56 | get_block(in_channel=128, depth=256, num_units=30), 57 | get_block(in_channel=256, depth=512, num_units=3), 58 | ] 
59 | elif num_layers == 152: 60 | blocks = [ 61 | get_block(in_channel=64, depth=64, num_units=3), 62 | get_block(in_channel=64, depth=128, num_units=8), 63 | get_block(in_channel=128, depth=256, num_units=36), 64 | get_block(in_channel=256, depth=512, num_units=3), 65 | ] 66 | else: 67 | raise ValueError( 68 | "Invalid number of layers: {}. Must be one of [50, 100, 152]".format( 69 | num_layers 70 | ) 71 | ) 72 | return blocks 73 | 74 | 75 | class SEModule(Module): 76 | def __init__(self, channels, reduction): 77 | super(SEModule, self).__init__() 78 | self.avg_pool = AdaptiveAvgPool2d(1) 79 | self.fc1 = Conv2d( 80 | channels, channels // reduction, kernel_size=1, padding=0, bias=False 81 | ) 82 | self.relu = ReLU(inplace=True) 83 | self.fc2 = Conv2d( 84 | channels // reduction, channels, kernel_size=1, padding=0, bias=False 85 | ) 86 | self.sigmoid = Sigmoid() 87 | 88 | def forward(self, x): 89 | module_input = x 90 | x = self.avg_pool(x) 91 | x = self.fc1(x) 92 | x = self.relu(x) 93 | x = self.fc2(x) 94 | x = self.sigmoid(x) 95 | return module_input * x 96 | 97 | 98 | class bottleneck_IR(Module): 99 | def __init__(self, in_channel, depth, stride): 100 | super(bottleneck_IR, self).__init__() 101 | if in_channel == depth: 102 | self.shortcut_layer = MaxPool2d(1, stride) 103 | else: 104 | self.shortcut_layer = Sequential( 105 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 106 | BatchNorm2d(depth), 107 | ) 108 | self.res_layer = Sequential( 109 | BatchNorm2d(in_channel), 110 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 111 | PReLU(depth), 112 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 113 | BatchNorm2d(depth), 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | 121 | 122 | class bottleneck_IR_SE(Module): 123 | def __init__(self, in_channel, depth, stride): 124 | super(bottleneck_IR_SE, self).__init__() 125 | if in_channel == depth: 126 | self.shortcut_layer = MaxPool2d(1, stride) 127 | else: 128 | self.shortcut_layer = Sequential( 129 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 130 | BatchNorm2d(depth), 131 | ) 132 | self.res_layer = Sequential( 133 | BatchNorm2d(in_channel), 134 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 135 | PReLU(depth), 136 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 137 | BatchNorm2d(depth), 138 | SEModule(depth, 16), 139 | ) 140 | 141 | def forward(self, x): 142 | shortcut = self.shortcut_layer(x) 143 | res = self.res_layer(x) 144 | return res + shortcut 145 | -------------------------------------------------------------------------------- /extern/ldm_zero123/thirdp/psp/id_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | import torch 3 | from torch import nn 4 | 5 | from extern.ldm_zero123.thirdp.psp.model_irse import Backbone 6 | 7 | 8 | class IDFeatures(nn.Module): 9 | def __init__(self, model_path): 10 | super(IDFeatures, self).__init__() 11 | print("Loading ResNet ArcFace") 12 | self.facenet = Backbone( 13 | input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se" 14 | ) 15 | self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) 16 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 17 | self.facenet.eval() 18 | 19 | def forward(self, x, crop=False): 20 | # Not sure of the image range here 21 | if crop: 22 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area") 
23 | x = x[:, :, 35:223, 32:220] 24 | x = self.face_pool(x) 25 | x_feats = self.facenet(x) 26 | return x_feats 27 | -------------------------------------------------------------------------------- /extern/ldm_zero123/thirdp/psp/model_irse.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from torch.nn import ( 4 | BatchNorm1d, 5 | BatchNorm2d, 6 | Conv2d, 7 | Dropout, 8 | Linear, 9 | Module, 10 | PReLU, 11 | Sequential, 12 | ) 13 | 14 | from extern.ldm_zero123.thirdp.psp.helpers import ( 15 | Flatten, 16 | bottleneck_IR, 17 | bottleneck_IR_SE, 18 | get_blocks, 19 | l2_norm, 20 | ) 21 | 22 | """ 23 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 24 | """ 25 | 26 | 27 | class Backbone(Module): 28 | def __init__(self, input_size, num_layers, mode="ir", drop_ratio=0.4, affine=True): 29 | super(Backbone, self).__init__() 30 | assert input_size in [112, 224], "input_size should be 112 or 224" 31 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 32 | assert mode in ["ir", "ir_se"], "mode should be ir or ir_se" 33 | blocks = get_blocks(num_layers) 34 | if mode == "ir": 35 | unit_module = bottleneck_IR 36 | elif mode == "ir_se": 37 | unit_module = bottleneck_IR_SE 38 | self.input_layer = Sequential( 39 | Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64) 40 | ) 41 | if input_size == 112: 42 | self.output_layer = Sequential( 43 | BatchNorm2d(512), 44 | Dropout(drop_ratio), 45 | Flatten(), 46 | Linear(512 * 7 * 7, 512), 47 | BatchNorm1d(512, affine=affine), 48 | ) 49 | else: 50 | self.output_layer = Sequential( 51 | BatchNorm2d(512), 52 | Dropout(drop_ratio), 53 | Flatten(), 54 | Linear(512 * 14 * 14, 512), 55 | BatchNorm1d(512, affine=affine), 56 | ) 57 | 58 | modules = [] 59 | for block in blocks: 60 | for bottleneck in block: 61 | modules.append( 62 | unit_module( 63 | bottleneck.in_channel, bottleneck.depth, bottleneck.stride 64 | ) 65 | ) 66 | self.body = Sequential(*modules) 67 | 68 | def forward(self, x): 69 | x = self.input_layer(x) 70 | x = self.body(x) 71 | x = self.output_layer(x) 72 | return l2_norm(x) 73 | 74 | 75 | def IR_50(input_size): 76 | """Constructs a ir-50 model.""" 77 | model = Backbone(input_size, num_layers=50, mode="ir", drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_101(input_size): 82 | """Constructs a ir-101 model.""" 83 | model = Backbone( 84 | input_size, num_layers=100, mode="ir", drop_ratio=0.4, affine=False 85 | ) 86 | return model 87 | 88 | 89 | def IR_152(input_size): 90 | """Constructs a ir-152 model.""" 91 | model = Backbone( 92 | input_size, num_layers=152, mode="ir", drop_ratio=0.4, affine=False 93 | ) 94 | return model 95 | 96 | 97 | def IR_SE_50(input_size): 98 | """Constructs a ir_se-50 model.""" 99 | model = Backbone( 100 | input_size, num_layers=50, mode="ir_se", drop_ratio=0.4, affine=False 101 | ) 102 | return model 103 | 104 | 105 | def IR_SE_101(input_size): 106 | """Constructs a ir_se-101 model.""" 107 | model = Backbone( 108 | input_size, num_layers=100, mode="ir_se", drop_ratio=0.4, affine=False 109 | ) 110 | return model 111 | 112 | 113 | def IR_SE_152(input_size): 114 | """Constructs a ir_se-152 model.""" 115 | model = Backbone( 116 | input_size, num_layers=152, mode="ir_se", drop_ratio=0.4, affine=False 117 | ) 118 | return model 119 | -------------------------------------------------------------------------------- /load/images/a DSLR 
photo of a puffin standing on a rock_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a DSLR photo of a puffin standing on a rock_depth.png -------------------------------------------------------------------------------- /load/images/a DSLR photo of a puffin standing on a rock_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a DSLR photo of a puffin standing on a rock_normal.png -------------------------------------------------------------------------------- /load/images/a DSLR photo of a puffin standing on a rock_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a DSLR photo of a puffin standing on a rock_rgba.png -------------------------------------------------------------------------------- /load/images/a figurine of a frog holding a birthday cake_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a figurine of a frog holding a birthday cake_depth.png -------------------------------------------------------------------------------- /load/images/a figurine of a frog holding a birthday cake_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a figurine of a frog holding a birthday cake_normal.png -------------------------------------------------------------------------------- /load/images/a figurine of a frog holding a birthday cake_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a figurine of a frog holding a birthday cake_rgba.png -------------------------------------------------------------------------------- /load/images/a kingfisher sitting on top of a piece of wood_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a kingfisher sitting on top of a piece of wood_depth.png -------------------------------------------------------------------------------- /load/images/a kingfisher sitting on top of a piece of wood_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a kingfisher sitting on top of a piece of wood_normal.png -------------------------------------------------------------------------------- /load/images/a kingfisher sitting on top of a piece of wood_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a kingfisher sitting on top of a piece of wood_rgba.png 
-------------------------------------------------------------------------------- /load/images/a rubber duck dressed as a nurse_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a rubber duck dressed as a nurse_depth.png -------------------------------------------------------------------------------- /load/images/a rubber duck dressed as a nurse_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a rubber duck dressed as a nurse_normal.png -------------------------------------------------------------------------------- /load/images/a rubber duck dressed as a nurse_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a rubber duck dressed as a nurse_rgba.png -------------------------------------------------------------------------------- /load/images/a white bowl of multiple fruits_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a white bowl of multiple fruits_depth.png -------------------------------------------------------------------------------- /load/images/a white bowl of multiple fruits_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a white bowl of multiple fruits_normal.png -------------------------------------------------------------------------------- /load/images/a white bowl of multiple fruits_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/a white bowl of multiple fruits_rgba.png -------------------------------------------------------------------------------- /load/images/groot_caption.txt: -------------------------------------------------------------------------------- 1 | cat head anthropomorphic humanoid body, movie poster, marvel little cute Groot character, high detail, hyper realistic, octane rendering. 
-------------------------------------------------------------------------------- /load/images/groot_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/groot_depth.png -------------------------------------------------------------------------------- /load/images/groot_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/groot_normal.png -------------------------------------------------------------------------------- /load/images/groot_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/groot_rgba.png -------------------------------------------------------------------------------- /load/images/jay-basket_caption.txt: -------------------------------------------------------------------------------- 1 | a DSLR photo of a blue jay standing on a large basket of rainbow macarons -------------------------------------------------------------------------------- /load/images/jay-basket_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/jay-basket_depth.png -------------------------------------------------------------------------------- /load/images/jay-basket_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/jay-basket_normal.png -------------------------------------------------------------------------------- /load/images/jay-basket_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/jay-basket_rgba.png -------------------------------------------------------------------------------- /load/images/mushroom_log_caption.txt: -------------------------------------------------------------------------------- 1 | a brightly colored mushroom growing on a log -------------------------------------------------------------------------------- /load/images/mushroom_log_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/mushroom_log_depth.png -------------------------------------------------------------------------------- /load/images/mushroom_log_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/mushroom_log_normal.png -------------------------------------------------------------------------------- /load/images/mushroom_log_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/mushroom_log_rgba.png 
-------------------------------------------------------------------------------- /load/images/tiger dressed as a nurse_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/tiger dressed as a nurse_depth.png -------------------------------------------------------------------------------- /load/images/tiger dressed as a nurse_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/tiger dressed as a nurse_normal.png -------------------------------------------------------------------------------- /load/images/tiger dressed as a nurse_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/images/tiger dressed as a nurse_rgba.png -------------------------------------------------------------------------------- /load/lights/LICENSE.txt: -------------------------------------------------------------------------------- 1 | The mud_road_puresky.hdr HDR probe is from https://polyhaven.com/a/mud_road_puresky 2 | CC0 License. 3 | -------------------------------------------------------------------------------- /load/lights/bsdf_256_256.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/lights/bsdf_256_256.bin -------------------------------------------------------------------------------- /load/lights/mud_road_puresky_1k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/lights/mud_road_puresky_1k.hdr -------------------------------------------------------------------------------- /load/tets/128_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/tets/128_tets.npz -------------------------------------------------------------------------------- /load/tets/32_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/tets/32_tets.npz -------------------------------------------------------------------------------- /load/tets/64_tets.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DreamCraft3D/5829ef116d36c871ce2b9e54a6153dd3856a1561/load/tets/64_tets.npz -------------------------------------------------------------------------------- /load/tets/generate_tets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. 
Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import os 11 | 12 | import numpy as np 13 | 14 | """ 15 | This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet, 16 | to generate a tet grid 17 | 1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet` 18 | 2) Run the function below to generate a file `cube_32_tet.tet` 19 | """ 20 | 21 | 22 | def generate_tetrahedron_grid_file(res=32, root=".."): 23 | frac = 1.0 / res 24 | command = f"cd {root}; ./quartet meshes/cube.obj {frac} meshes/cube_{res}_tet.tet -s meshes/cube_boundary_{res}.obj" 25 | os.system(command) 26 | 27 | 28 | """ 29 | This code segment shows how to convert from a quartet .tet file to compressed npz file 30 | """ 31 | 32 | 33 | def convert_from_quartet_to_npz(quartetfile="cube_32_tet.tet", npzfile="32_tets"): 34 | file1 = open(quartetfile, "r") 35 | header = file1.readline() 36 | numvertices = int(header.split(" ")[1]) 37 | numtets = int(header.split(" ")[2]) 38 | print(numvertices, numtets) 39 | 40 | # load vertices 41 | vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices) 42 | print(vertices.shape) 43 | 44 | # load indices 45 | indices = np.loadtxt( 46 | quartetfile, dtype=int, skiprows=1 + numvertices, max_rows=numtets 47 | ) 48 | print(indices.shape) 49 | 50 | np.savez_compressed(npzfile, vertices=vertices, indices=indices) 51 | 52 | 53 | root = "/home/gyc/quartet" 54 | for res in [300, 350, 400]: 55 | generate_tetrahedron_grid_file(res, root) 56 | convert_from_quartet_to_npz( 57 | os.path.join(root, f"meshes/cube_{res}_tet.tet"), npzfile=f"{res}_tets" 58 | ) 59 | -------------------------------------------------------------------------------- /load/zero123/download.sh: -------------------------------------------------------------------------------- 1 | # wget https://huggingface.co/cvlab/zero123-weights/resolve/main/105000.ckpt 2 | # mv 105000.ckpt zero123-original.ckpt 3 | wget https://zero123.cs.columbia.edu/assets/zero123-xl.ckpt 4 | # Download stable_zero123.ckpt from https://huggingface.co/stabilityai/stable-zero123 -------------------------------------------------------------------------------- /load/zero123/sd-objaverse-finetune-c_concat-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: extern.ldm_zero123.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "image_cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | 19 | scheduler_config: # 10000 warmup steps 20 | target: extern.ldm_zero123.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: [ 100 ] 23 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 24 | f_start: [ 1.e-6 ] 25 | f_max: [ 1. ] 26 | f_min: [ 1. 
] 27 | 28 | unet_config: 29 | target: extern.ldm_zero123.modules.diffusionmodules.openaimodel.UNetModel 30 | params: 31 | image_size: 32 # unused 32 | in_channels: 8 33 | out_channels: 4 34 | model_channels: 320 35 | attention_resolutions: [ 4, 2, 1 ] 36 | num_res_blocks: 2 37 | channel_mult: [ 1, 2, 4, 4 ] 38 | num_heads: 8 39 | use_spatial_transformer: True 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: True 43 | legacy: False 44 | 45 | first_stage_config: 46 | target: extern.ldm_zero123.models.autoencoder.AutoencoderKL 47 | params: 48 | embed_dim: 4 49 | monitor: val/rec_loss 50 | ddconfig: 51 | double_z: true 52 | z_channels: 4 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: extern.ldm_zero123.modules.encoders.modules.FrozenCLIPImageEmbedder 70 | 71 | 72 | # data: 73 | # target: extern.ldm_zero123.data.simple.ObjaverseDataModuleFromConfig 74 | # params: 75 | # root_dir: 'views_whole_sphere' 76 | # batch_size: 192 77 | # num_workers: 16 78 | # total_view: 4 79 | # train: 80 | # validation: False 81 | # image_transforms: 82 | # size: 256 83 | 84 | # validation: 85 | # validation: True 86 | # image_transforms: 87 | # size: 256 88 | 89 | 90 | # lightning: 91 | # find_unused_parameters: false 92 | # metrics_over_trainsteps_checkpoint: True 93 | # modelcheckpoint: 94 | # params: 95 | # every_n_train_steps: 5000 96 | # callbacks: 97 | # image_logger: 98 | # target: main.ImageLogger 99 | # params: 100 | # batch_frequency: 500 101 | # max_images: 32 102 | # increase_log_steps: False 103 | # log_first_step: True 104 | # log_images_kwargs: 105 | # use_ema_scope: False 106 | # inpaint: False 107 | # plot_progressive_rows: False 108 | # plot_diffusion_rows: False 109 | # N: 32 110 | # unconditional_scale: 3.0 111 | # unconditional_label: [""] 112 | 113 | # trainer: 114 | # benchmark: True 115 | # val_check_interval: 5000000 # really sorry 116 | # num_sanity_val_steps: 0 117 | # accumulate_grad_batches: 1 118 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lightning==2.0.0 2 | omegaconf==2.3.0 3 | jaxtyping 4 | typeguard 5 | diffusers<=0.23.0 6 | transformers 7 | accelerate 8 | opencv-python 9 | tensorboard 10 | matplotlib 11 | imageio>=2.28.0 12 | imageio[ffmpeg] 13 | libigl 14 | xatlas 15 | trimesh[easy] 16 | networkx 17 | pysdf 18 | PyMCubes 19 | wandb 20 | gradio 21 | 22 | # deepfloyd 23 | xformers 24 | bitsandbytes 25 | sentencepiece 26 | safetensors 27 | huggingface_hub 28 | 29 | # for zero123 30 | einops 31 | kornia 32 | taming-transformers-rom1504 33 | 34 | #controlnet 35 | controlnet_aux 36 | numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability 37 | -------------------------------------------------------------------------------- /threestudio/__init__.py: -------------------------------------------------------------------------------- 1 | __modules__ = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | __modules__[name] = cls 7 | return cls 8 | 9 | return decorator 10 | 11 | 12 | def find(name): 13 | return __modules__[name] 14 | 15 | 16 | ### grammar sugar for logging utilities ### 17 | import logging 18 | 19 | logger = logging.getLogger("pytorch_lightning") 20 | 21 | 
from pytorch_lightning.utilities.rank_zero import ( 22 | rank_zero_debug, 23 | rank_zero_info, 24 | rank_zero_only, 25 | ) 26 | 27 | debug = rank_zero_debug 28 | info = rank_zero_info 29 | 30 | 31 | @rank_zero_only 32 | def warn(*args, **kwargs): 33 | logger.warn(*args, **kwargs) 34 | 35 | 36 | from . import data, models, systems 37 | -------------------------------------------------------------------------------- /threestudio/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import image, uncond -------------------------------------------------------------------------------- /threestudio/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | background, 3 | exporters, 4 | geometry, 5 | guidance, 6 | materials, 7 | prompt_processors, 8 | renderers, 9 | ) 10 | -------------------------------------------------------------------------------- /threestudio/models/background/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | neural_environment_map_background, 4 | solid_color_background, 5 | textured_background, 6 | ) 7 | -------------------------------------------------------------------------------- /threestudio/models/background/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.utils.base import BaseModule 10 | from threestudio.utils.typing import * 11 | 12 | 13 | class BaseBackground(BaseModule): 14 | @dataclass 15 | class Config(BaseModule.Config): 16 | pass 17 | 18 | cfg: Config 19 | 20 | def configure(self): 21 | pass 22 | 23 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 24 | raise NotImplementedError -------------------------------------------------------------------------------- /threestudio/models/background/neural_environment_map_background.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.background.base import BaseBackground 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("neural-environment-map-background") 16 | class NeuralEnvironmentMapBackground(BaseBackground): 17 | @dataclass 18 | class Config(BaseBackground.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | dir_encoding_config: dict = field( 22 | default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} 23 | ) 24 | mlp_network_config: dict = field( 25 | default_factory=lambda: { 26 | "otype": "VanillaMLP", 27 | "activation": "ReLU", 28 | "n_neurons": 16, 29 | "n_hidden_layers": 2, 30 | } 31 | ) 32 | random_aug: bool = False 33 | random_aug_prob: float = 0.5 34 | eval_color: Optional[Tuple[float, float, float]] = None 35 | 36 | # multi-view diffusion 37 | share_aug_bg: bool = False 38 | 39 | cfg: Config 40 | 41 | def configure(self) -> None: 42 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config) 43 | self.network = get_mlp( 
44 | self.encoding.n_output_dims, 45 | self.cfg.n_output_dims, 46 | self.cfg.mlp_network_config, 47 | ) 48 | 49 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 50 | if not self.training and self.cfg.eval_color is not None: 51 | return torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( 52 | dirs 53 | ) * torch.as_tensor(self.cfg.eval_color).to(dirs) 54 | # viewdirs must be normalized before passing to this function 55 | dirs = (dirs + 1.0) / 2.0 # (-1, 1) => (0, 1) 56 | dirs_embd = self.encoding(dirs.view(-1, 3)) 57 | color = self.network(dirs_embd).view(*dirs.shape[:-1], self.cfg.n_output_dims) 58 | color = get_activation(self.cfg.color_activation)(color) 59 | if ( 60 | self.training 61 | and self.cfg.random_aug 62 | and random.random() < self.cfg.random_aug_prob 63 | ): 64 | # use random background color with probability random_aug_prob 65 | n_color = 1 if self.cfg.share_aug_bg else dirs.shape[0] 66 | color = color * 0 + ( # prevent checking for unused parameters in DDP 67 | torch.rand(n_color, 1, 1, self.cfg.n_output_dims) 68 | .to(dirs) 69 | .expand(*dirs.shape[:-1], -1) 70 | ) 71 | return color -------------------------------------------------------------------------------- /threestudio/models/background/solid_color_background.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.background.base import BaseBackground 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("solid-color-background") 14 | class SolidColorBackground(BaseBackground): 15 | @dataclass 16 | class Config(BaseBackground.Config): 17 | n_output_dims: int = 3 18 | color: Tuple = (1.0, 1.0, 1.0) 19 | learned: bool = False 20 | random_aug: bool = False 21 | random_aug_prob: float = 0.5 22 | 23 | cfg: Config 24 | 25 | def configure(self) -> None: 26 | self.env_color: Float[Tensor, "Nc"] 27 | if self.cfg.learned: 28 | self.env_color = nn.Parameter( 29 | torch.as_tensor(self.cfg.color, dtype=torch.float32) 30 | ) 31 | else: 32 | self.register_buffer( 33 | "env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32) 34 | ) 35 | 36 | def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]: 37 | color = torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to( 38 | dirs 39 | ) * self.env_color.to(dirs) 40 | if ( 41 | self.training 42 | and self.cfg.random_aug 43 | and random.random() < self.cfg.random_aug_prob 44 | ): 45 | # use random background color with probability random_aug_prob 46 | color = color * 0 + ( # prevent checking for unused parameters in DDP 47 | torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims) 48 | .to(dirs) 49 | .expand(*dirs.shape[:-1], -1) 50 | ) 51 | return color -------------------------------------------------------------------------------- /threestudio/models/background/textured_background.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.utils.ops import get_activation 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("textured-background") 14 | class 
TexturedBackground(BaseBackground): 15 | @dataclass 16 | class Config(BaseBackground.Config): 17 | n_output_dims: int = 3 18 | height: int = 64 19 | width: int = 64 20 | color_activation: str = "sigmoid" 21 | 22 | cfg: Config 23 | 24 | def configure(self) -> None: 25 | self.texture = nn.Parameter( 26 | torch.randn((1, self.cfg.n_output_dims, self.cfg.height, self.cfg.width)) 27 | ) 28 | 29 | def spherical_xyz_to_uv(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 2"]: 30 | x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2] 31 | xy = (x**2 + y**2) ** 0.5 32 | u = torch.atan2(xy, z) / torch.pi 33 | v = torch.atan2(y, x) / (torch.pi * 2) + 0.5 34 | uv = torch.stack([u, v], -1) 35 | return uv 36 | 37 | def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]: 38 | dirs_shape = dirs.shape[:-1] 39 | uv = self.spherical_xyz_to_uv(dirs.reshape(-1, dirs.shape[-1])) 40 | uv = 2 * uv - 1 # rescale to [-1, 1] for grid_sample 41 | uv = uv.reshape(1, -1, 1, 2) 42 | color = ( 43 | F.grid_sample( 44 | self.texture, 45 | uv, 46 | mode="bilinear", 47 | padding_mode="reflection", 48 | align_corners=False, 49 | ) 50 | .reshape(self.cfg.n_output_dims, -1) 51 | .T.reshape(*dirs_shape, self.cfg.n_output_dims) 52 | ) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color -------------------------------------------------------------------------------- /threestudio/models/estimators.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Tuple 2 | 3 | try: 4 | from typing import Literal 5 | except ImportError: 6 | from typing_extensions import Literal 7 | 8 | import torch 9 | from nerfacc.data_specs import RayIntervals 10 | from nerfacc.estimators.base import AbstractEstimator 11 | from nerfacc.pdf import importance_sampling, searchsorted 12 | from nerfacc.volrend import render_transmittance_from_density 13 | from torch import Tensor 14 | 15 | 16 | class ImportanceEstimator(AbstractEstimator): 17 | def __init__( 18 | self, 19 | ) -> None: 20 | super().__init__() 21 | 22 | @torch.no_grad() 23 | def sampling( 24 | self, 25 | prop_sigma_fns: List[Callable], 26 | prop_samples: List[int], 27 | num_samples: int, 28 | # rendering options 29 | n_rays: int, 30 | near_plane: float, 31 | far_plane: float, 32 | sampling_type: Literal["uniform", "lindisp"] = "uniform", 33 | # training options 34 | stratified: bool = False, 35 | requires_grad: bool = False, 36 | ) -> Tuple[Tensor, Tensor]: 37 | """Sampling with CDFs from proposal networks. 38 | 39 | Args: 40 | prop_sigma_fns: Proposal network evaluate functions. It should be a list 41 | of functions that take in samples {t_starts (n_rays, n_samples), 42 | t_ends (n_rays, n_samples)} and returns the post-activation densities 43 | (n_rays, n_samples). 44 | prop_samples: Number of samples to draw from each proposal network. Should 45 | be the same length as `prop_sigma_fns`. 46 | num_samples: Number of samples to draw in the end. 47 | n_rays: Number of rays. 48 | near_plane: Near plane. 49 | far_plane: Far plane. 50 | sampling_type: Sampling type. Either "uniform" or "lindisp". Default to 51 | "lindisp". 52 | stratified: Whether to use stratified sampling. Default to `False`. 53 | 54 | Returns: 55 | A tuple of {Tensor, Tensor}: 56 | 57 | - **t_starts**: The starts of the samples. Shape (n_rays, num_samples). 58 | - **t_ends**: The ends of the samples. Shape (n_rays, num_samples). 
59 | 60 | """ 61 | assert len(prop_sigma_fns) == len(prop_samples), ( 62 | "The number of proposal networks and the number of samples " 63 | "should be the same." 64 | ) 65 | cdfs = torch.cat( 66 | [ 67 | torch.zeros((n_rays, 1), device=self.device), 68 | torch.ones((n_rays, 1), device=self.device), 69 | ], 70 | dim=-1, 71 | ) 72 | intervals = RayIntervals(vals=cdfs) 73 | 74 | for level_fn, level_samples in zip(prop_sigma_fns, prop_samples): 75 | intervals, _ = importance_sampling( 76 | intervals, cdfs, level_samples, stratified 77 | ) 78 | t_vals = _transform_stot( 79 | sampling_type, intervals.vals, near_plane, far_plane 80 | ) 81 | t_starts = t_vals[..., :-1] 82 | t_ends = t_vals[..., 1:] 83 | 84 | with torch.set_grad_enabled(requires_grad): 85 | sigmas = level_fn(t_starts, t_ends) 86 | assert sigmas.shape == t_starts.shape 87 | trans, _ = render_transmittance_from_density(t_starts, t_ends, sigmas) 88 | cdfs = 1.0 - torch.cat([trans, torch.zeros_like(trans[:, :1])], dim=-1) 89 | 90 | intervals, _ = importance_sampling(intervals, cdfs, num_samples, stratified) 91 | t_vals_fine = _transform_stot( 92 | sampling_type, intervals.vals, near_plane, far_plane 93 | ) 94 | 95 | t_vals = torch.cat([t_vals, t_vals_fine], dim=-1) 96 | t_vals, _ = torch.sort(t_vals, dim=-1) 97 | 98 | t_starts_ = t_vals[..., :-1] 99 | t_ends_ = t_vals[..., 1:] 100 | 101 | return t_starts_, t_ends_ 102 | 103 | 104 | def _transform_stot( 105 | transform_type: Literal["uniform", "lindisp"], 106 | s_vals: torch.Tensor, 107 | t_min: torch.Tensor, 108 | t_max: torch.Tensor, 109 | ) -> torch.Tensor: 110 | if transform_type == "uniform": 111 | _contract_fn, _icontract_fn = lambda x: x, lambda x: x 112 | elif transform_type == "lindisp": 113 | _contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x 114 | else: 115 | raise ValueError(f"Unknown transform_type: {transform_type}") 116 | s_min, s_max = _contract_fn(t_min), _contract_fn(t_max) 117 | icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min) 118 | return icontract_fn(s_vals) -------------------------------------------------------------------------------- /threestudio/models/exporters/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import base, mesh_exporter 2 | -------------------------------------------------------------------------------- /threestudio/models/exporters/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import threestudio 4 | from threestudio.models.background.base import BaseBackground 5 | from threestudio.models.geometry.base import BaseImplicitGeometry 6 | from threestudio.models.materials.base import BaseMaterial 7 | from threestudio.utils.base import BaseObject 8 | from threestudio.utils.typing import * 9 | 10 | 11 | @dataclass 12 | class ExporterOutput: 13 | save_name: str 14 | save_type: str 15 | params: Dict[str, Any] 16 | 17 | 18 | class Exporter(BaseObject): 19 | @dataclass 20 | class Config(BaseObject.Config): 21 | save_video: bool = False 22 | 23 | cfg: Config 24 | 25 | def configure( 26 | self, 27 | geometry: BaseImplicitGeometry, 28 | material: BaseMaterial, 29 | background: BaseBackground, 30 | ) -> None: 31 | @dataclass 32 | class SubModules: 33 | geometry: BaseImplicitGeometry 34 | material: BaseMaterial 35 | background: BaseBackground 36 | 37 | self.sub_modules = SubModules(geometry, material, background) 38 | 39 | @property 40 | def geometry(self) -> BaseImplicitGeometry: 41 | return self.sub_modules.geometry 42 | 43 | @property 44 | def material(self) -> BaseMaterial: 45 | return self.sub_modules.material 46 | 47 | @property 48 | def background(self) -> BaseBackground: 49 | return self.sub_modules.background 50 | 51 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]: 52 | raise NotImplementedError 53 | 54 | 55 | @threestudio.register("dummy-exporter") 56 | class DummyExporter(Exporter): 57 | def __call__(self, *args, **kwargs) -> List[ExporterOutput]: 58 | # DummyExporter does not export anything 59 | return [] -------------------------------------------------------------------------------- /threestudio/models/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | custom_mesh, 4 | implicit_sdf, 5 | implicit_volume, 6 | tetrahedra_sdf_grid, 7 | volume_grid, 8 | ) -------------------------------------------------------------------------------- /threestudio/models/guidance/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | controlnet_guidance, 3 | controlnet_reg_guidance, 4 | deep_floyd_guidance, 5 | stable_diffusion_guidance, 6 | stable_diffusion_unified_guidance, 7 | stable_diffusion_vsd_guidance, 8 | stable_diffusion_bsd_guidance, 9 | stable_zero123_guidance, 10 | zero123_guidance, 11 | zero123_unified_guidance, 12 | clip_guidance, 13 | ) 14 | -------------------------------------------------------------------------------- /threestudio/models/guidance/clip_guidance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | import torch.nn.functional as F 4 | import torchvision.transforms as T 5 | import clip 6 | 7 | import threestudio 8 | from threestudio.utils.base import BaseObject 9 | from threestudio.models.prompt_processors.base import PromptProcessorOutput 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("clip-guidance") 14 | class CLIPGuidance(BaseObject): 15 | @dataclass 16 | class Config(BaseObject.Config): 17 | cache_dir: Optional[str] = None 18 | pretrained_model_name_or_path: str = "ViT-B/16" 19 | view_dependent_prompting: bool = True 20 | 21 | cfg: Config 22 | 23 | def configure(self) -> None: 24 | threestudio.info(f"Loading CLIP ...") 25 | self.clip_model, self.clip_preprocess = clip.load( 26 | self.cfg.pretrained_model_name_or_path, 27 | device=self.device, 28 | jit=False, 29 | download_root=self.cfg.cache_dir 30 | ) 31 | 32 | self.aug = T.Compose([ 33 | T.Resize((224, 224)), 34 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 35 | ]) 36 | 37 | threestudio.info(f"Loaded CLIP!") 38 | 39 | @torch.cuda.amp.autocast(enabled=False) 40 | def get_embedding(self, input_value, is_text=True): 41 | if is_text: 42 | value = clip.tokenize(input_value).to(self.device) 43 | z = self.clip_model.encode_text(value) 44 | else: 45 | input_value = self.aug(input_value) 46 | z = self.clip_model.encode_image(input_value) 47 | 48 | return z / z.norm(dim=-1, keepdim=True) 49 | 50 | def get_loss(self, image_z, clip_z, loss_type='similarity_score', use_mean=True): 51 | if loss_type == 'similarity_score': 52 | loss = -((image_z * clip_z).sum(-1)) 53 | elif loss_type == 'spherical_dist': 54 | image_z, clip_z = F.normalize(image_z, dim=-1), F.normalize(clip_z, dim=-1) 55 | loss = ((image_z - clip_z).norm(dim=-1).div(2).arcsin().pow(2).mul(2)) 56 | else: 57 | raise NotImplementedError 58 | 59 | return loss.mean() if use_mean else loss 60 | 61 | def __call__( 62 | self, 63 | pred_rgb: Float[Tensor, "B H W C"], 64 | gt_rgb: Float[Tensor, "B H W C"], 65 | prompt_utils: PromptProcessorOutput, 66 | elevation: Float[Tensor, "B"], 67 | azimuth: Float[Tensor, "B"], 68 | camera_distances: Float[Tensor, "B"], 69 | embedding_type: str = 'both', 70 | loss_type: Optional[str] = 'similarity_score', 71 | **kwargs, 72 | ): 73 | clip_text_loss, clip_img_loss = 0, 0 74 | 75 | if embedding_type in ('both', 'text'): 76 | text_embeddings = prompt_utils.get_text_embeddings( 77 | elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting 78 | ).chunk(2)[0] 79 | clip_text_loss = self.get_loss(self.get_embedding(pred_rgb, is_text=False), text_embeddings, loss_type=loss_type) 80 | 81 | if embedding_type in ('both', 'img'): 82 | clip_img_loss = self.get_loss(self.get_embedding(pred_rgb, is_text=False), self.get_embedding(gt_rgb, is_text=False), loss_type=loss_type) 83 | 84 | return clip_text_loss + clip_img_loss 85 | 
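A minimal standalone sketch of the two loss types computed in CLIPGuidance.get_loss above, using plain torch and dummy embeddings (the shapes and variable names here are assumptions for illustration only, not part of the repository code):

import torch
import torch.nn.functional as F

# dummy stand-ins for CLIP image/text embeddings, already unit-normalized
image_z = F.normalize(torch.randn(4, 512), dim=-1)
text_z = F.normalize(torch.randn(4, 512), dim=-1)

# 'similarity_score': negative cosine similarity (embeddings have unit norm)
sim_loss = -(image_z * text_z).sum(-1).mean()

# 'spherical_dist': squared great-circle distance on the unit sphere,
# recovered from the chord length ||a - b|| = 2 * sin(theta / 2)
sph_loss = (image_z - text_z).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()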
-------------------------------------------------------------------------------- /threestudio/models/materials/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | diffuse_with_point_light_material, 4 | hybrid_rgb_latent_material, 5 | neural_radiance_material, 6 | no_material, 7 | pbr_material, 8 | sd_latent_adapter_material, 9 | ) 10 | -------------------------------------------------------------------------------- /threestudio/models/materials/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.utils.base import BaseModule 10 | from threestudio.utils.typing import * 11 | 12 | 13 | class BaseMaterial(BaseModule): 14 | @dataclass 15 | class Config(BaseModule.Config): 16 | pass 17 | 18 | cfg: Config 19 | requires_normal: bool = False 20 | requires_tangent: bool = False 21 | 22 | def configure(self): 23 | pass 24 | 25 | def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]: 26 | raise NotImplementedError 27 | 28 | def export(self, *args, **kwargs) -> Dict[str, Any]: 29 | return {} 30 | -------------------------------------------------------------------------------- /threestudio/models/materials/diffuse_with_point_light_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.utils.ops import dot, get_activation 11 | from threestudio.utils.typing import * 12 | 13 | 14 | @threestudio.register("diffuse-with-point-light-material") 15 | class DiffuseWithPointLightMaterial(BaseMaterial): 16 | @dataclass 17 | class Config(BaseMaterial.Config): 18 | ambient_light_color: Tuple[float, float, float] = (0.1, 0.1, 0.1) 19 | diffuse_light_color: Tuple[float, float, float] = (0.9, 0.9, 0.9) 20 | ambient_only_steps: int = 1000 21 | diffuse_prob: float = 0.75 22 | textureless_prob: float = 0.5 23 | albedo_activation: str = "sigmoid" 24 | soft_shading: bool = False 25 | 26 | cfg: Config 27 | 28 | def configure(self) -> None: 29 | self.requires_normal = True 30 | 31 | self.ambient_light_color: Float[Tensor, "3"] 32 | self.register_buffer( 33 | "ambient_light_color", 34 | torch.as_tensor(self.cfg.ambient_light_color, dtype=torch.float32), 35 | ) 36 | self.diffuse_light_color: Float[Tensor, "3"] 37 | self.register_buffer( 38 | "diffuse_light_color", 39 | torch.as_tensor(self.cfg.diffuse_light_color, dtype=torch.float32), 40 | ) 41 | self.ambient_only = False 42 | 43 | def forward( 44 | self, 45 | features: Float[Tensor, "B ... Nf"], 46 | positions: Float[Tensor, "B ... 3"], 47 | shading_normal: Float[Tensor, "B ... 3"], 48 | light_positions: Float[Tensor, "B ... 3"], 49 | ambient_ratio: Optional[float] = None, 50 | shading: Optional[str] = None, 51 | **kwargs, 52 | ) -> Float[Tensor, "B ... 
3"]: 53 | albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]) 54 | 55 | if ambient_ratio is not None: 56 | # if ambient ratio is specified, use it 57 | diffuse_light_color = (1 - ambient_ratio) * torch.ones_like( 58 | self.diffuse_light_color 59 | ) 60 | ambient_light_color = ambient_ratio * torch.ones_like( 61 | self.ambient_light_color 62 | ) 63 | elif self.training and self.cfg.soft_shading: 64 | # otherwise if in training and soft shading is enabled, random a ambient ratio 65 | diffuse_light_color = torch.full_like( 66 | self.diffuse_light_color, random.random() 67 | ) 68 | ambient_light_color = 1.0 - diffuse_light_color 69 | else: 70 | # otherwise use the default fixed values 71 | diffuse_light_color = self.diffuse_light_color 72 | ambient_light_color = self.ambient_light_color 73 | 74 | light_directions: Float[Tensor, "B ... 3"] = F.normalize( 75 | light_positions - positions, dim=-1 76 | ) 77 | diffuse_light: Float[Tensor, "B ... 3"] = ( 78 | dot(shading_normal, light_directions).clamp(min=0.0) * diffuse_light_color 79 | ) 80 | textureless_color = diffuse_light + ambient_light_color 81 | # clamp albedo to [0, 1] to compute shading 82 | color = albedo.clamp(0.0, 1.0) * textureless_color 83 | 84 | if shading is None: 85 | if self.training: 86 | # adopt the same type of augmentation for the whole batch 87 | if self.ambient_only or random.random() > self.cfg.diffuse_prob: 88 | shading = "albedo" 89 | elif random.random() < self.cfg.textureless_prob: 90 | shading = "textureless" 91 | else: 92 | shading = "diffuse" 93 | else: 94 | if self.ambient_only: 95 | shading = "albedo" 96 | else: 97 | # return shaded color by default in evaluation 98 | shading = "diffuse" 99 | 100 | # multiply by 0 to prevent checking for unused parameters in DDP 101 | if shading == "albedo": 102 | return albedo + textureless_color * 0 103 | elif shading == "textureless": 104 | return albedo * 0 + textureless_color 105 | elif shading == "diffuse": 106 | return color 107 | else: 108 | raise ValueError(f"Unknown shading type {shading}") 109 | 110 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 111 | if global_step < self.cfg.ambient_only_steps: 112 | self.ambient_only = True 113 | else: 114 | self.ambient_only = False 115 | 116 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 117 | albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]).clamp( 118 | 0.0, 1.0 119 | ) 120 | return {"albedo": albedo} 121 | -------------------------------------------------------------------------------- /threestudio/models/materials/hybrid_rgb_latent_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("hybrid-rgb-latent-material") 16 | class HybridRGBLatentMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | requires_normal: bool = True 22 | 23 | cfg: Config 24 | 25 | def configure(self) -> None: 26 | self.requires_normal = self.cfg.requires_normal 27 | 28 | def 
forward( 29 | self, features: Float[Tensor, "B ... Nf"], **kwargs 30 | ) -> Float[Tensor, "B ... Nc"]: 31 | assert ( 32 | features.shape[-1] == self.cfg.n_output_dims 33 | ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." 34 | color = features 35 | color[..., :3] = get_activation(self.cfg.color_activation)(color[..., :3]) 36 | return color 37 | -------------------------------------------------------------------------------- /threestudio/models/materials/neural_radiance_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("neural-radiance-material") 16 | class NeuralRadianceMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | input_feature_dims: int = 8 20 | color_activation: str = "sigmoid" 21 | dir_encoding_config: dict = field( 22 | default_factory=lambda: {"otype": "SphericalHarmonics", "degree": 3} 23 | ) 24 | mlp_network_config: dict = field( 25 | default_factory=lambda: { 26 | "otype": "FullyFusedMLP", 27 | "activation": "ReLU", 28 | "n_neurons": 16, 29 | "n_hidden_layers": 2, 30 | } 31 | ) 32 | 33 | cfg: Config 34 | 35 | def configure(self) -> None: 36 | self.encoding = get_encoding(3, self.cfg.dir_encoding_config) 37 | self.n_input_dims = self.cfg.input_feature_dims + self.encoding.n_output_dims # type: ignore 38 | self.network = get_mlp(self.n_input_dims, 3, self.cfg.mlp_network_config) 39 | 40 | def forward( 41 | self, 42 | features: Float[Tensor, "*B Nf"], 43 | viewdirs: Float[Tensor, "*B 3"], 44 | **kwargs, 45 | ) -> Float[Tensor, "*B 3"]: 46 | # viewdirs and normals must be normalized before passing to this function 47 | viewdirs = (viewdirs + 1.0) / 2.0 # (-1, 1) => (0, 1) 48 | viewdirs_embd = self.encoding(viewdirs.view(-1, 3)) 49 | network_inp = torch.cat( 50 | [features.view(-1, features.shape[-1]), viewdirs_embd], dim=-1 51 | ) 52 | color = self.network(network_inp).view(*features.shape[:-1], 3) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | -------------------------------------------------------------------------------- /threestudio/models/materials/no_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.networks import get_encoding, get_mlp 11 | from threestudio.utils.ops import dot, get_activation 12 | from threestudio.utils.typing import * 13 | 14 | 15 | @threestudio.register("no-material") 16 | class NoMaterial(BaseMaterial): 17 | @dataclass 18 | class Config(BaseMaterial.Config): 19 | n_output_dims: int = 3 20 | color_activation: str = "sigmoid" 21 | input_feature_dims: Optional[int] = None 22 | mlp_network_config: Optional[dict] = None 23 | requires_normal: bool = False 24 | 25 | cfg: Config 26 | 27 | def configure(self) -> None: 28 | self.use_network = False 29 | if ( 30 | 
self.cfg.input_feature_dims is not None 31 | and self.cfg.mlp_network_config is not None 32 | ): 33 | self.network = get_mlp( 34 | self.cfg.input_feature_dims, 35 | self.cfg.n_output_dims, 36 | self.cfg.mlp_network_config, 37 | ) 38 | self.use_network = True 39 | self.requires_normal = self.cfg.requires_normal 40 | 41 | def forward( 42 | self, features: Float[Tensor, "B ... Nf"], **kwargs 43 | ) -> Float[Tensor, "B ... Nc"]: 44 | if not self.use_network: 45 | assert ( 46 | features.shape[-1] == self.cfg.n_output_dims 47 | ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." 48 | color = get_activation(self.cfg.color_activation)(features) 49 | else: 50 | color = self.network(features.view(-1, features.shape[-1])).view( 51 | *features.shape[:-1], self.cfg.n_output_dims 52 | ) 53 | color = get_activation(self.cfg.color_activation)(color) 54 | return color 55 | 56 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 57 | color = self(features, **kwargs).clamp(0, 1) 58 | assert color.shape[-1] >= 3, "Output color must have at least 3 channels" 59 | if color.shape[-1] > 3: 60 | threestudio.warn( 61 | "Output color has >3 channels, treating the first 3 as RGB" 62 | ) 63 | return {"albedo": color[..., :3]} 64 | -------------------------------------------------------------------------------- /threestudio/models/materials/pbr_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import envlight 5 | import numpy as np 6 | import nvdiffrast.torch as dr 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import threestudio 12 | from threestudio.models.materials.base import BaseMaterial 13 | from threestudio.utils.ops import get_activation 14 | from threestudio.utils.typing import * 15 | 16 | 17 | @threestudio.register("pbr-material") 18 | class PBRMaterial(BaseMaterial): 19 | @dataclass 20 | class Config(BaseMaterial.Config): 21 | material_activation: str = "sigmoid" 22 | environment_texture: str = "load/lights/mud_road_puresky_1k.hdr" 23 | environment_scale: float = 2.0 24 | min_metallic: float = 0.0 25 | max_metallic: float = 0.9 26 | min_roughness: float = 0.08 27 | max_roughness: float = 0.9 28 | use_bump: bool = True 29 | 30 | cfg: Config 31 | 32 | def configure(self) -> None: 33 | self.requires_normal = True 34 | self.requires_tangent = self.cfg.use_bump 35 | 36 | self.light = envlight.EnvLight( 37 | self.cfg.environment_texture, scale=self.cfg.environment_scale 38 | ) 39 | 40 | FG_LUT = torch.from_numpy( 41 | np.fromfile("load/lights/bsdf_256_256.bin", dtype=np.float32).reshape( 42 | 1, 256, 256, 2 43 | ) 44 | ) 45 | self.register_buffer("FG_LUT", FG_LUT) 46 | 47 | def forward( 48 | self, 49 | features: Float[Tensor, "*B Nf"], 50 | viewdirs: Float[Tensor, "*B 3"], 51 | shading_normal: Float[Tensor, "B ... 3"], 52 | tangent: Optional[Float[Tensor, "B ... 
3"]] = None, 53 | **kwargs, 54 | ) -> Float[Tensor, "*B 3"]: 55 | prefix_shape = features.shape[:-1] 56 | 57 | material: Float[Tensor, "*B Nf"] = get_activation(self.cfg.material_activation)( 58 | features 59 | ) 60 | albedo = material[..., :3] 61 | metallic = ( 62 | material[..., 3:4] * (self.cfg.max_metallic - self.cfg.min_metallic) 63 | + self.cfg.min_metallic 64 | ) 65 | roughness = ( 66 | material[..., 4:5] * (self.cfg.max_roughness - self.cfg.min_roughness) 67 | + self.cfg.min_roughness 68 | ) 69 | 70 | if self.cfg.use_bump: 71 | assert tangent is not None 72 | # perturb_normal is a delta to the initialization [0, 0, 1] 73 | perturb_normal = (material[..., 5:8] * 2 - 1) + torch.tensor( 74 | [0, 0, 1], dtype=material.dtype, device=material.device 75 | ) 76 | perturb_normal = F.normalize(perturb_normal.clamp(-1, 1), dim=-1) 77 | 78 | # apply normal perturbation in tangent space 79 | bitangent = F.normalize(torch.cross(tangent, shading_normal), dim=-1) 80 | shading_normal = ( 81 | tangent * perturb_normal[..., 0:1] 82 | - bitangent * perturb_normal[..., 1:2] 83 | + shading_normal * perturb_normal[..., 2:3] 84 | ) 85 | shading_normal = F.normalize(shading_normal, dim=-1) 86 | 87 | v = -viewdirs 88 | n_dot_v = (shading_normal * v).sum(-1, keepdim=True) 89 | reflective = n_dot_v * shading_normal * 2 - v 90 | 91 | diffuse_albedo = (1 - metallic) * albedo 92 | 93 | fg_uv = torch.cat([n_dot_v, roughness], -1).clamp(0, 1) 94 | fg = dr.texture( 95 | self.FG_LUT, 96 | fg_uv.reshape(1, -1, 1, 2).contiguous(), 97 | filter_mode="linear", 98 | boundary_mode="clamp", 99 | ).reshape(*prefix_shape, 2) 100 | F0 = (1 - metallic) * 0.04 + metallic * albedo 101 | specular_albedo = F0 * fg[:, 0:1] + fg[:, 1:2] 102 | 103 | diffuse_light = self.light(shading_normal) 104 | specular_light = self.light(reflective, roughness) 105 | 106 | color = diffuse_albedo * diffuse_light + specular_albedo * specular_light 107 | color = color.clamp(0.0, 1.0) 108 | 109 | return color 110 | 111 | def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: 112 | material: Float[Tensor, "*N Nf"] = get_activation(self.cfg.material_activation)( 113 | features 114 | ) 115 | albedo = material[..., :3] 116 | metallic = ( 117 | material[..., 3:4] * (self.cfg.max_metallic - self.cfg.min_metallic) 118 | + self.cfg.min_metallic 119 | ) 120 | roughness = ( 121 | material[..., 4:5] * (self.cfg.max_roughness - self.cfg.min_roughness) 122 | + self.cfg.min_roughness 123 | ) 124 | 125 | out = { 126 | "albedo": albedo, 127 | "metallic": metallic, 128 | "roughness": roughness, 129 | } 130 | 131 | if self.cfg.use_bump: 132 | perturb_normal = (material[..., 5:8] * 2 - 1) + torch.tensor( 133 | [0, 0, 1], dtype=material.dtype, device=material.device 134 | ) 135 | perturb_normal = F.normalize(perturb_normal.clamp(-1, 1), dim=-1) 136 | perturb_normal = (perturb_normal + 1) / 2 137 | out.update( 138 | { 139 | "bump": perturb_normal, 140 | } 141 | ) 142 | 143 | return out 144 | -------------------------------------------------------------------------------- /threestudio/models/materials/sd_latent_adapter_material.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import threestudio 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.utils.typing import * 11 | 12 | 13 | 
@threestudio.register("sd-latent-adapter-material") 14 | class StableDiffusionLatentAdapterMaterial(BaseMaterial): 15 | @dataclass 16 | class Config(BaseMaterial.Config): 17 | pass 18 | 19 | cfg: Config 20 | 21 | def configure(self) -> None: 22 | adapter = nn.Parameter( 23 | torch.as_tensor( 24 | [ 25 | # R G B 26 | [0.298, 0.207, 0.208], # L1 27 | [0.187, 0.286, 0.173], # L2 28 | [-0.158, 0.189, 0.264], # L3 29 | [-0.184, -0.271, -0.473], # L4 30 | ] 31 | ) 32 | ) 33 | self.register_parameter("adapter", adapter) 34 | 35 | def forward( 36 | self, features: Float[Tensor, "B ... 4"], **kwargs 37 | ) -> Float[Tensor, "B ... 3"]: 38 | assert features.shape[-1] == 4 39 | color = features @ self.adapter 40 | color = (color + 1) / 2 41 | color = color.clamp(0.0, 1.0) 42 | return color 43 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | base, 3 | deepfloyd_prompt_processor, 4 | dummy_prompt_processor, 5 | stable_diffusion_prompt_processor, 6 | clip_prompt_processor, 7 | ) -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/clip_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import clip 6 | import torch 7 | import torch 8 | import torch.nn as nn 9 | 10 | import threestudio 11 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 12 | from threestudio.utils.misc import cleanup 13 | from threestudio.utils.typing import * 14 | 15 | 16 | @threestudio.register("clip-prompt-processor") 17 | class ClipPromptProcessor(PromptProcessor): 18 | @dataclass 19 | class Config(PromptProcessor.Config): 20 | pass 21 | 22 | cfg: Config 23 | 24 | @staticmethod 25 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device): 26 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 27 | clip_model, _ = clip.load(pretrained_model_name_or_path, jit=False) 28 | with torch.no_grad(): 29 | tokens = clip.tokenize( 30 | prompts, 31 | ).to(device) 32 | text_embeddings = clip_model.encode_text(tokens) 33 | text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) 34 | 35 | for prompt, embedding in zip(prompts, text_embeddings): 36 | torch.save( 37 | embedding, 38 | os.path.join( 39 | cache_dir, 40 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 41 | ), 42 | ) 43 | 44 | del clip_model 45 | -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/deepfloyd_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import torch.nn as nn 7 | from diffusers import IFPipeline 8 | from transformers import T5EncoderModel, T5Tokenizer 9 | 10 | import threestudio 11 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 12 | from threestudio.utils.misc import cleanup 13 | from threestudio.utils.typing import * 14 | 15 | 16 | @threestudio.register("deep-floyd-prompt-processor") 17 | class DeepFloydPromptProcessor(PromptProcessor): 18 | @dataclass 19 | class Config(PromptProcessor.Config): 20 | pretrained_model_name_or_path: str = "DeepFloyd/IF-I-XL-v1.0" 21 | 
22 | cfg: Config 23 | 24 | ### these functions are unused, kept for debugging ### 25 | def configure_text_encoder(self) -> None: 26 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 27 | self.text_encoder = T5EncoderModel.from_pretrained( 28 | self.cfg.pretrained_model_name_or_path, 29 | subfolder="text_encoder", 30 | load_in_8bit=True, 31 | variant="8bit", 32 | device_map="auto", 33 | ) # FIXME: behavior of auto device map in multi-GPU training 34 | self.pipe = IFPipeline.from_pretrained( 35 | self.cfg.pretrained_model_name_or_path, 36 | text_encoder=self.text_encoder, # pass the previously instantiated 8bit text encoder 37 | unet=None, 38 | ) 39 | 40 | def destroy_text_encoder(self) -> None: 41 | del self.text_encoder 42 | del self.pipe 43 | cleanup() 44 | 45 | def get_text_embeddings( 46 | self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] 47 | ) -> Tuple[Float[Tensor, "B 77 4096"], Float[Tensor, "B 77 4096"]]: 48 | text_embeddings, uncond_text_embeddings = self.pipe.encode_prompt( 49 | prompt=prompt, negative_prompt=negative_prompt, device=self.device 50 | ) 51 | return text_embeddings, uncond_text_embeddings 52 | 53 | ### 54 | 55 | @staticmethod 56 | def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device): 57 | max_length = 77 58 | tokenizer = T5Tokenizer.from_pretrained( 59 | pretrained_model_name_or_path, 60 | subfolder="tokenizer", 61 | ) 62 | text_encoder = T5EncoderModel.from_pretrained( 63 | pretrained_model_name_or_path, 64 | subfolder="text_encoder", 65 | torch_dtype=torch.float16, # suppress warning 66 | load_in_8bit=True, 67 | variant="8bit", 68 | device_map="auto", 69 | ) 70 | with torch.no_grad(): 71 | text_inputs = tokenizer( 72 | prompts, 73 | padding="max_length", 74 | max_length=max_length, 75 | truncation=True, 76 | add_special_tokens=True, 77 | return_tensors="pt", 78 | ) 79 | text_input_ids = text_inputs.input_ids 80 | attention_mask = text_inputs.attention_mask 81 | text_embeddings = text_encoder( 82 | text_input_ids.to(text_encoder.device), 83 | attention_mask=attention_mask.to(text_encoder.device), 84 | ) 85 | text_embeddings = text_embeddings[0] 86 | 87 | for prompt, embedding in zip(prompts, text_embeddings): 88 | torch.save( 89 | embedding, 90 | os.path.join( 91 | cache_dir, 92 | f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt", 93 | ), 94 | ) 95 | 96 | del text_encoder -------------------------------------------------------------------------------- /threestudio/models/prompt_processors/dummy_prompt_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass 4 | 5 | import threestudio 6 | from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt 7 | from threestudio.utils.misc import cleanup 8 | from threestudio.utils.typing import * 9 | 10 | 11 | @threestudio.register("dummy-prompt-processor") 12 | class DummyPromptProcessor(PromptProcessor): 13 | @dataclass 14 | class Config(PromptProcessor.Config): 15 | pretrained_model_name_or_path: str = "" 16 | prompt: str = "" 17 | 18 | cfg: Config -------------------------------------------------------------------------------- /threestudio/models/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | base, 3 | deferred_volume_renderer, 4 | gan_volume_renderer, 5 | nerf_volume_renderer, 6 | neus_volume_renderer, 7 | nvdiff_rasterizer, 8 | patch_renderer, 9 | ) 10 | -------------------------------------------------------------------------------- /threestudio/models/renderers/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import threestudio 8 | from threestudio.models.background.base import BaseBackground 9 | from threestudio.models.geometry.base import BaseImplicitGeometry 10 | from threestudio.models.materials.base import BaseMaterial 11 | from threestudio.utils.base import BaseModule 12 | from threestudio.utils.typing import * 13 | 14 | 15 | class Renderer(BaseModule): 16 | @dataclass 17 | class Config(BaseModule.Config): 18 | radius: float = 1.0 19 | 20 | cfg: Config 21 | 22 | def configure( 23 | self, 24 | geometry: BaseImplicitGeometry, 25 | material: BaseMaterial, 26 | background: BaseBackground, 27 | ) -> None: 28 | # keep references to submodules using namedtuple, avoid being registered as modules 29 | @dataclass 30 | class SubModules: 31 | geometry: BaseImplicitGeometry 32 | material: BaseMaterial 33 | background: BaseBackground 34 | 35 | self.sub_modules = SubModules(geometry, material, background) 36 | 37 | # set up bounding box 38 | self.bbox: Float[Tensor, "2 3"] 39 | self.register_buffer( 40 | "bbox", 41 | torch.as_tensor( 42 | [ 43 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 44 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 45 | ], 46 | dtype=torch.float32, 47 | ), 48 | ) 49 | 50 | def forward(self, *args, **kwargs) -> Dict[str, Any]: 51 | raise NotImplementedError 52 | 53 | @property 54 | def geometry(self) -> BaseImplicitGeometry: 55 | return self.sub_modules.geometry 56 | 57 | @property 58 | def material(self) -> BaseMaterial: 59 | return self.sub_modules.material 60 | 61 | @property 62 | def background(self) -> BaseBackground: 63 | return self.sub_modules.background 64 | 65 | def set_geometry(self, geometry: BaseImplicitGeometry) -> None: 66 | self.sub_modules.geometry = geometry 67 | 68 | def set_material(self, material: BaseMaterial) -> None: 69 | self.sub_modules.material = material 70 | 71 | def set_background(self, background: BaseBackground) -> None: 72 | self.sub_modules.background = background 73 | 74 | 75 | class VolumeRenderer(Renderer): 76 | pass 77 | 78 | 79 | class Rasterizer(Renderer): 80 | pass -------------------------------------------------------------------------------- /threestudio/models/renderers/deferred_volume_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.renderers.base import VolumeRenderer 8 | 9 | 10 | class DeferredVolumeRenderer(VolumeRenderer): 11 | pass 12 | -------------------------------------------------------------------------------- /threestudio/models/renderers/patch_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import threestudio 7 | from threestudio.models.background.base import BaseBackground 8 | from threestudio.models.geometry.base import BaseImplicitGeometry 9 | from 
threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.renderers.base import VolumeRenderer 11 | from threestudio.utils.typing import * 12 | 13 | 14 | @threestudio.register("patch-renderer") 15 | class PatchRenderer(VolumeRenderer): 16 | @dataclass 17 | class Config(VolumeRenderer.Config): 18 | patch_size: int = 128 19 | base_renderer_type: str = "" 20 | base_renderer: Optional[VolumeRenderer.Config] = None 21 | global_detach: bool = False 22 | global_downsample: int = 4 23 | 24 | cfg: Config 25 | 26 | def configure( 27 | self, 28 | geometry: BaseImplicitGeometry, 29 | material: BaseMaterial, 30 | background: BaseBackground, 31 | ) -> None: 32 | self.base_renderer = threestudio.find(self.cfg.base_renderer_type)( 33 | self.cfg.base_renderer, 34 | geometry=geometry, 35 | material=material, 36 | background=background, 37 | ) 38 | 39 | def forward( 40 | self, 41 | rays_o: Float[Tensor, "B H W 3"], 42 | rays_d: Float[Tensor, "B H W 3"], 43 | light_positions: Float[Tensor, "B 3"], 44 | bg_color: Optional[Tensor] = None, 45 | **kwargs 46 | ) -> Dict[str, Float[Tensor, "..."]]: 47 | B, H, W, _ = rays_o.shape 48 | 49 | if self.base_renderer.training: 50 | downsample = self.cfg.global_downsample 51 | global_rays_o = torch.nn.functional.interpolate( 52 | rays_o.permute(0, 3, 1, 2), 53 | (H // downsample, W // downsample), 54 | mode="bilinear", 55 | ).permute(0, 2, 3, 1) 56 | global_rays_d = torch.nn.functional.interpolate( 57 | rays_d.permute(0, 3, 1, 2), 58 | (H // downsample, W // downsample), 59 | mode="bilinear", 60 | ).permute(0, 2, 3, 1) 61 | out_global = self.base_renderer( 62 | global_rays_o, global_rays_d, light_positions, bg_color, **kwargs 63 | ) 64 | 65 | PS = self.cfg.patch_size 66 | patch_x = torch.randint(0, W - PS, (1,)).item() 67 | patch_y = torch.randint(0, H - PS, (1,)).item() 68 | patch_rays_o = rays_o[:, patch_y : patch_y + PS, patch_x : patch_x + PS] 69 | patch_rays_d = rays_d[:, patch_y : patch_y + PS, patch_x : patch_x + PS] 70 | out = self.base_renderer( 71 | patch_rays_o, patch_rays_d, light_positions, bg_color, **kwargs 72 | ) 73 | 74 | valid_patch_key = [] 75 | for key in out: 76 | if torch.is_tensor(out[key]): 77 | if len(out[key].shape) == len(out["comp_rgb"].shape): 78 | if out[key][..., 0].shape == out["comp_rgb"][..., 0].shape: 79 | valid_patch_key.append(key) 80 | for key in valid_patch_key: 81 | out_global[key] = F.interpolate( 82 | out_global[key].permute(0, 3, 1, 2), (H, W), mode="bilinear" 83 | ).permute(0, 2, 3, 1) 84 | if self.cfg.global_detach: 85 | out_global[key] = out_global[key].detach() 86 | out_global[key][ 87 | :, patch_y : patch_y + PS, patch_x : patch_x + PS 88 | ] = out[key] 89 | out = out_global 90 | else: 91 | out = self.base_renderer( 92 | rays_o, rays_d, light_positions, bg_color, **kwargs 93 | ) 94 | 95 | return out 96 | 97 | def update_step( 98 | self, epoch: int, global_step: int, on_load_weights: bool = False 99 | ) -> None: 100 | self.base_renderer.update_step(epoch, global_step, on_load_weights) 101 | 102 | def train(self, mode=True): 103 | return self.base_renderer.train(mode) 104 | 105 | def eval(self): 106 | return self.base_renderer.eval() -------------------------------------------------------------------------------- /threestudio/scripts/dreamcraft3d_dreambooth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from subprocess import run, CalledProcessError 4 | 5 | import cv2 6 | import glob 7 | import numpy as np 8 | import 
pytorch_lightning as pl 9 | import torch 10 | from tqdm import tqdm 11 | from torchvision.utils import save_image 12 | 13 | from threestudio.scripts.generate_mv_datasets import generate_mv_dataset 14 | from threestudio.utils.config import load_config 15 | import threestudio 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--config", required=True, help="path to config file") 21 | parser.add_argument("--action", default="both", help="action to perform", choices=["gen_data", "dreambooth", "both"]) 22 | args, extras = parser.parse_known_args() 23 | return args, extras 24 | 25 | 26 | def main(args, extras): 27 | cfg = load_config(args.config, cli_args=extras, n_gpus=1) 28 | 29 | if args.action == "gen_data" or args.action == "both": 30 | # Generate multi-view dataset 31 | generate_mv_dataset(cfg) 32 | 33 | if args.action == "dreambooth" or args.action == "both": 34 | # Run DreamBooth. 35 | command = f'accelerate launch threestudio/scripts/train_dreambooth.py \ 36 | --pretrained_model_name_or_path="{cfg.custom_import.dreambooth.model_name}" \ 37 | --instance_data_dir="{cfg.custom_import.dreambooth.instance_dir}" \ 38 | --output_dir="{cfg.custom_import.dreambooth.output_dir}" \ 39 | --instance_prompt="{cfg.custom_import.dreambooth.prompt_dreambooth}" \ 40 | --resolution=512 \ 41 | --train_batch_size=2 \ 42 | --gradient_accumulation_steps=1 \ 43 | --learning_rate=1e-6 \ 44 | --lr_scheduler="constant" \ 45 | --lr_warmup_steps=0 \ 46 | --max_train_steps=1000' 47 | 48 | os.system(command) 49 | 50 | 51 | if __name__ == "__main__": 52 | args, extras = parse_args() 53 | main(args, extras) 54 | -------------------------------------------------------------------------------- /threestudio/scripts/generate_images_if.py: -------------------------------------------------------------------------------- 1 | from diffusers import DiffusionPipeline 2 | from diffusers.utils import pt_to_pil 3 | import torch 4 | 5 | import os 6 | import glob 7 | import json 8 | import argparse 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | 13 | SAVE_FOLDER = "./load/images_dreamfusion" 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--rank", default=0, type=int, help="# of GPU") 18 | parser.add_argument("--prompt", required=True, type=str) 19 | 20 | args = parser.parse_args() 21 | 22 | # stage 1 23 | stage_1 = DiffusionPipeline.from_pretrained( 24 | "DeepFloyd/IF-I-XL-v1.0", 25 | variant="fp16", 26 | torch_dtype=torch.float16, 27 | local_files_only=True 28 | ) 29 | stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 30 | stage_1.enable_model_cpu_offload() 31 | 32 | # stage 2 33 | stage_2 = DiffusionPipeline.from_pretrained( 34 | "DeepFloyd/IF-II-L-v1.0", 35 | text_encoder=None, 36 | variant="fp16", 37 | torch_dtype=torch.float16, 38 | local_files_only=True 39 | ) 40 | # stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 41 | stage_2.enable_model_cpu_offload() 42 | 43 | # stage 3 44 | # safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker} 45 | safety_modules = None 46 | stage_3 = DiffusionPipeline.from_pretrained( 47 | "stabilityai/stable-diffusion-x4-upscaler", 48 | torch_dtype=torch.float16, 49 | local_files_only=True 50 | ) 51 | stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 52 |
stage_3.enable_model_cpu_offload() 53 | 54 | # # load prompt library 55 | # with open(os.path.join("load/prompt_library.json"), "r") as f: 56 | # prompt_library = json.load(f) 57 | 58 | # n_prompts = len(prompt_library["dreamfusion"]) 59 | # n_prompts_per_rank = int(np.ceil(n_prompts / 8)) 60 | 61 | # for prompt in tqdm(prompt_library["dreamfusion"][args.rank * n_prompts_per_rank : (args.rank + 1) * n_prompts_per_rank]): 62 | 63 | prompt = args.prompt 64 | print("Prompt:", prompt) 65 | 66 | save_folder = os.path.join(SAVE_FOLDER, prompt) 67 | os.makedirs(save_folder, exist_ok=True) 68 | 69 | # if len(glob.glob(f"{save_folder}/*.png")) >= 30: 70 | # continue 71 | 72 | # enhance prompt 73 | prompt = prompt + ", 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, hyperrealistic, intricate details, ultra-realistic, award-winning" 74 | 75 | prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) 76 | for _ in tqdm(range(30)): 77 | seed = np.random.randint(low=0, high=10000000, size=1)[0] 78 | generator = torch.manual_seed(seed) 79 | 80 | ### Stage 1 81 | image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images 82 | # pt_to_pil(image)[0].save("./if_stage_I.png") 83 | 84 | ### Stage 2 85 | image = stage_2( 86 | image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt" 87 | ).images 88 | # pt_to_pil(image)[0].save("./if_stage_II.png") 89 | 90 | ### Stage 3 91 | image = stage_3(prompt=prompt, image=(image.float() * 0.5 + 0.5), generator=generator, noise_level=100).images 92 | image[0].save(f"{save_folder}/img_{seed:08d}.png") -------------------------------------------------------------------------------- /threestudio/scripts/generate_images_if_prompt_library.py: -------------------------------------------------------------------------------- 1 | from diffusers import DiffusionPipeline 2 | from diffusers.utils import pt_to_pil 3 | import torch 4 | 5 | import os 6 | import glob 7 | import json 8 | import argparse 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | 13 | SAVE_FOLDER = "./load/images_dreamfusion" 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--rank", default=0, type=int, help="# of GPU") 18 | 19 | args = parser.parse_args() 20 | 21 | # stage 1 22 | stage_1 = DiffusionPipeline.from_pretrained( 23 | "DeepFloyd/IF-I-XL-v1.0", 24 | variant="fp16", 25 | torch_dtype=torch.float16, 26 | local_files_only=True 27 | ) 28 | stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 29 | stage_1.enable_model_cpu_offload() 30 | 31 | # stage 2 32 | stage_2 = DiffusionPipeline.from_pretrained( 33 | "DeepFloyd/IF-II-L-v1.0", 34 | text_encoder=None, 35 | variant="fp16", 36 | torch_dtype=torch.float16, 37 | local_files_only=True 38 | ) 39 | # stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 40 | stage_2.enable_model_cpu_offload() 41 | 42 | # stage 3 43 | # safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker} 44 | safety_modules = None 45 | stage_3 = DiffusionPipeline.from_pretrained( 46 | "stabilityai/stable-diffusion-x4-upscaler", 47 | torch_dtype=torch.float16, 48 | local_files_only=True 49 | ) 50 | stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 51 | 
stage_3.enable_model_cpu_offload() 52 | 53 | # load prompt library 54 | with open(os.path.join("load/prompt_library.json"), "r") as f: 55 | prompt_library = json.load(f) 56 | 57 | n_prompts = len(prompt_library["dreamfusion"]) 58 | n_prompts_per_rank = int(np.ceil(n_prompts / 8)) 59 | 60 | for prompt in tqdm(prompt_library["dreamfusion"][args.rank * n_prompts_per_rank : (args.rank + 1) * n_prompts_per_rank]): 61 | 62 | print("Prompt:", prompt) 63 | 64 | save_folder = os.path.join(SAVE_FOLDER, prompt) 65 | os.makedirs(save_folder, exist_ok=True) 66 | 67 | if len(glob.glob(f"{save_folder}/*.png")) >= 30: 68 | continue 69 | 70 | # enhance prompt 71 | prompt = prompt + ", 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, hyperrealistic, intricate details, ultra-realistic, award-winning" 72 | 73 | prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) 74 | for _ in tqdm(range(30)): 75 | seed = np.random.randint(low=0, high=10000000, size=1)[0] 76 | generator = torch.manual_seed(seed) 77 | 78 | ### Stage 1 79 | image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images 80 | # pt_to_pil(image)[0].save("./if_stage_I.png") 81 | 82 | ### Stage 2 83 | image = stage_2( 84 | image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt" 85 | ).images 86 | # pt_to_pil(image)[0].save("./if_stage_II.png") 87 | 88 | ### Stage 3 89 | image = stage_3(prompt=prompt, image=(image.float() * 0.5 + 0.5), generator=generator, noise_level=100).images 90 | image[0].save(f"{save_folder}/img_{seed:08d}.png") -------------------------------------------------------------------------------- /threestudio/scripts/generate_mv_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import torch 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | import pytorch_lightning as pl 9 | from torchvision.utils import save_image 10 | from subprocess import run, CalledProcessError 11 | from threestudio.utils.config import load_config 12 | import threestudio 13 | 14 | # Constants 15 | AZIMUTH_FACTOR = 360 16 | IMAGE_SIZE = (512, 512) 17 | 18 | 19 | def copy_file(source, destination): 20 | try: 21 | command = ['cp', source, destination] 22 | result = run(command, capture_output=True, text=True) 23 | result.check_returncode() 24 | except CalledProcessError as e: 25 | print(f'Error: {e.output}') 26 | 27 | 28 | def prepare_images(cfg): 29 | rgb_list = sorted(glob.glob(os.path.join(cfg.data.render_image_path, "*.png"))) 30 | rgb_list.sort(key=lambda file: int(os.path.splitext(os.path.basename(file))[0])) 31 | n_rgbs = len(rgb_list) 32 | n_samples = cfg.data.n_samples 33 | 34 | os.makedirs(cfg.data.save_path, exist_ok=True) 35 | 36 | copy_file(cfg.data.ref_image_path, f"{cfg.data.save_path}/ref_0.0.png") 37 | 38 | sampled_indices = np.linspace(0, len(rgb_list)-1, n_samples, dtype=int) 39 | rgb_samples = [rgb_list[index] for index in sampled_indices] 40 | 41 | return rgb_samples 42 | 43 | 44 | def process_images(rgb_samples, cfg, guidance, prompt_utils): 45 | n_rgbs = 120 46 | for rgb_name in tqdm(rgb_samples): 47 | rgb_idx = int(os.path.basename(rgb_name).split(".")[0]) 48 | rgb = cv2.imread(rgb_name)[:, :, :3][:, :, ::-1].copy() / 255.0 49 | H, W = rgb.shape[0:2] 50 | rgb_image, mask_image = rgb[:, :H], rgb[:, -H:, :1] 51 | rgb_image = cv2.resize(rgb_image, IMAGE_SIZE) 52 | 
rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device) 53 | 54 | mask_image = cv2.resize(mask_image, IMAGE_SIZE).reshape(IMAGE_SIZE[0], IMAGE_SIZE[1], 1) 55 | mask_image = torch.FloatTensor(mask_image).unsqueeze(0).to(guidance.device) 56 | 57 | temp = torch.zeros(1).to(guidance.device) 58 | azimuth = torch.tensor([rgb_idx/n_rgbs * AZIMUTH_FACTOR]).to(guidance.device) 59 | camera_distance = torch.tensor([cfg.data.default_camera_distance]).to(guidance.device) 60 | 61 | if cfg.data.view_dependent_noise: 62 | guidance.min_step_percent = 0. + (rgb_idx/n_rgbs) * (cfg.system.guidance.min_step_percent) 63 | guidance.max_step_percent = 0. + (rgb_idx/n_rgbs) * (cfg.system.guidance.max_step_percent) 64 | 65 | denoised_image = process_guidance(cfg, guidance, prompt_utils, rgb_image, azimuth, temp, camera_distance, mask_image) 66 | 67 | save_image(denoised_image.permute(0,3,1,2), f"{cfg.data.save_path}/img_{azimuth[0]}.png", normalize=True, value_range=(0, 1)) 68 | 69 | copy_file(rgb_name.replace("png", "npy"), f"{cfg.data.save_path}/img_{azimuth[0]}.npy") 70 | 71 | if rgb_idx == 0: 72 | copy_file(rgb_name.replace("png", "npy"), f"{cfg.data.save_path}/ref_{azimuth[0]}.npy") 73 | 74 | 75 | def process_guidance(cfg, guidance, prompt_utils, rgb_image, azimuth, temp, camera_distance, mask_image): 76 | if cfg.data.azimuth_range[0] < azimuth < cfg.data.azimuth_range[1]: 77 | return guidance.sample_img2img( 78 | rgb_image, prompt_utils, temp, 79 | azimuth, camera_distance, seed=0, mask=mask_image 80 | )["edit_image"] 81 | else: 82 | return rgb_image 83 | 84 | 85 | def generate_mv_dataset(cfg): 86 | 87 | guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance) 88 | prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(cfg.system.prompt_processor) 89 | prompt_utils = prompt_processor() 90 | 91 | guidance.update_step(epoch=0, global_step=0) 92 | rgb_samples = prepare_images(cfg) 93 | print(rgb_samples) 94 | process_images(rgb_samples, cfg, guidance, prompt_utils) 95 | 96 | -------------------------------------------------------------------------------- /threestudio/scripts/img_to_mv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from PIL import Image 4 | import torch 5 | from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, StableDiffusionUpscalePipeline 6 | 7 | 8 | def load_model(superres): 9 | mv_model = DiffusionPipeline.from_pretrained( 10 | "sudo-ai/zero123plus-v1.1", custom_pipeline="sudo-ai/zero123plus-pipeline", 11 | torch_dtype=torch.float16, cache_dir="load/checkpoints/huggingface/hub", local_files_only=True, 12 | ) 13 | mv_model.scheduler = EulerAncestralDiscreteScheduler.from_config( 14 | mv_model.scheduler.config, timestep_spacing='trailing', cache_dir="load/checkpoints/huggingface/hub", local_files_only=True, 15 | ) 16 | 17 | if superres: 18 | superres_model = StableDiffusionUpscalePipeline.from_pretrained( 19 | "stabilityai/stable-diffusion-x4-upscaler", revision="fp16", 20 | torch_dtype=torch.float16, cache_dir="load/checkpoints/huggingface/hub", local_files_only=True, 21 | ) 22 | else: 23 | superres_model = None 24 | 25 | return mv_model, superres_model 26 | 27 | 28 | def superres_4x(image, model, prompt): 29 | low_res_img = image.resize((256, 256)) 30 | model.to('cuda:1') 31 | result = model(prompt=prompt, image=low_res_img).images[0] 32 | return result 33 | 34 | 35 | def img_to_mv(image_path, model): 36 | cond = Image.open(image_path) 37 | 
model.to('cuda:1') 38 | result = model(cond, num_inference_steps=75).images[0] 39 | return result 40 | 41 | 42 | def crop_save_image_to_2x3_grid(image, args, model): 43 | save_path = args.save_path 44 | width, height = image.size 45 | grid_width = width//2 46 | grid_height = height//3 47 | 48 | images = [] 49 | for i in range(3): 50 | for j in range(2): 51 | left = j * grid_width 52 | upper = i * grid_height 53 | right = (j+1) * grid_width 54 | lower = (i+1) * grid_height 55 | 56 | cropped_image = image.crop((left, upper, right, lower)) 57 | if args.superres: 58 | cropped_image = superres_4x(cropped_image, model, args.prompt) 59 | images.append(cropped_image) 60 | 61 | for idx, img in enumerate(images): 62 | img.save(os.path.join(save_path, f'cropped_{idx}.jpg')) 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--image_path', type=str, help="path to image (png, jpeg, etc.)") 68 | parser.add_argument('--save_path', type=str, help="path to save output images") 69 | parser.add_argument('--prompt', type=str, help="prompt to use for superres") 70 | parser.add_argument('--superres', action='store_true', help="whether to use superres") 71 | args = parser.parse_args() 72 | 73 | print(args.superres) 74 | 75 | os.makedirs(args.save_path, exist_ok=True) 76 | os.system(f"cp '{args.image_path}' '{args.save_path}'") 77 | 78 | mv_model, superres_model = load_model(args.superres) 79 | images = img_to_mv(args.image_path, mv_model) 80 | crop_save_image_to_2x3_grid(images, args, superres_model) 81 | 82 | 83 | # Example usage: 84 | # python threestudio/scripts/img_to_mv.py --image_path 'mushroom.png' --save_path '.cache/temp' --prompt 'a photo of mushroom' --superres -------------------------------------------------------------------------------- /threestudio/scripts/make_training_vid.py: -------------------------------------------------------------------------------- 1 | # make_training_vid("outputs/zero123/64_teddy_rgba.png@20230627-195615", frames_per_vid=30, fps=20, max_iters=200) 2 | import argparse 3 | import glob 4 | import os 5 | 6 | import imageio 7 | import numpy as np 8 | from PIL import Image, ImageDraw 9 | from tqdm import tqdm 10 | 11 | 12 | def draw_text_in_image(img, texts): 13 | img = Image.fromarray(img) 14 | draw = ImageDraw.Draw(img) 15 | black, white = (0, 0, 0), (255, 255, 255) 16 | for i, text in enumerate(texts): 17 | draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) 18 | draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white) 19 | draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) 20 | draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white) 21 | draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black) 22 | return np.asarray(img) 23 | 24 | 25 | def make_training_vid(exp, frames_per_vid=1, fps=3, max_iters=None, max_vids=None): 26 | # exp = "/admin/home-vikram/git/threestudio/outputs/zero123/64_teddy_rgba.png@20230627-195615" 27 | files = glob.glob(os.path.join(exp, "save", "*.mp4")) 28 | if os.path.join(exp, "save", "training_vid.mp4") in files: 29 | files.remove(os.path.join(exp, "save", "training_vid.mp4")) 30 | its = [int(os.path.basename(file).split("-")[0].split("it")[-1]) for file in files] 31 | it_sort = np.argsort(its) 32 | files = list(np.array(files)[it_sort]) 33 | its = list(np.array(its)[it_sort]) 34 | max_vids = max_iters // its[0] if max_iters is not None else max_vids 35 | files, its = files[:max_vids], its[:max_vids] 36 | frames, i = [], 
0 37 | for it, file in tqdm(zip(its, files), total=len(files)): 38 | vid = imageio.mimread(file) 39 | for _ in range(frames_per_vid): 40 | frame = vid[i % len(vid)] 41 | frame = draw_text_in_image(frame, [str(it)]) 42 | frames.append(frame) 43 | i += 1 44 | # Save 45 | imageio.mimwrite(os.path.join(exp, "save", "training_vid.mp4"), frames, fps=fps) 46 | 47 | 48 | def join(file1, file2, name): 49 | # file1 = "/admin/home-vikram/git/threestudio/outputs/zero123/OLD_64_dragon2_rgba.png@20230629-023028/save/it200-val.mp4" 50 | # file2 = "/admin/home-vikram/git/threestudio/outputs/zero123/64_dragon2_rgba.png@20230628-152734/save/it200-val.mp4" 51 | vid1 = imageio.mimread(file1) 52 | vid2 = imageio.mimread(file2) 53 | frames = [] 54 | for f1, f2 in zip(vid1, vid2): 55 | frames.append( 56 | np.concatenate([f1[:, : f1.shape[0]], f2[:, : f2.shape[0]]], axis=1) 57 | ) 58 | imageio.mimwrite(name, frames) 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("--exp", help="directory of experiment") 64 | parser.add_argument( 65 | "--frames_per_vid", type=int, default=1, help="# of frames from each val vid" 66 | ) 67 | parser.add_argument("--fps", type=int, help="max # of iters to save") 68 | parser.add_argument("--max_iters", type=int, help="max # of iters to save") 69 | parser.add_argument( 70 | "--max_vids", 71 | type=int, 72 | help="max # of val videos to save. Will be overridden by max_iters", 73 | ) 74 | args = parser.parse_args() 75 | make_training_vid( 76 | args.exp, args.frames_per_vid, args.fps, args.max_iters, args.max_vids 77 | ) -------------------------------------------------------------------------------- /threestudio/scripts/run_gaussian.sh: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | prompt_list = [ 4 | "a delicious hamburger", 5 | "A DSLR photo of a roast turkey on a platter", 6 | "A high quality photo of a dragon", 7 | "A DSLR photo of a bald eagle", 8 | "A bunch of blue rose, highly detailed", 9 | "A 3D model of an adorable cottage with a thatched roof", 10 | "A high quality photo of a furry corgi", 11 | "A DSLR photo of a panda", 12 | "a DSLR photo of a cat lying on its side batting at a ball of yarn", 13 | "a beautiful dress made out of fruit, on a mannequin. 
Studio lighting, high quality, high resolution", 14 | "a DSLR photo of a corgi wearing a beret and holding a baguette, standing up on two hind legs", 15 | "a zoomed out DSLR photo of a stack of pancakes", 16 | "a zoomed out DSLR photo of a baby bunny sitting on top of a stack of pancakes", 17 | ] 18 | negative_prompt = "oversaturated color, ugly, tiling, low quality, noise, ugly pattern" 19 | 20 | gpu_id = 0 21 | max_steps = 10 22 | val_check = 1 23 | out_name = "gsgen_baseline" 24 | for prompt in prompt_list: 25 | print(f"Running model on device {gpu_id}: ", prompt) 26 | command = [ 27 | "python", "launch.py", 28 | "--config", "configs/gaussian_splatting.yaml", 29 | "--train", 30 | f"system.prompt_processor.prompt={prompt}", 31 | f"system.prompt_processor.negative_prompt={negative_prompt}", 32 | f"name={out_name}", 33 | "--gpu", f"{gpu_id}" 34 | ] 35 | subprocess.run(command) 36 | -------------------------------------------------------------------------------- /threestudio/scripts/run_zero123.py: -------------------------------------------------------------------------------- 1 | NAME="dragon2" 2 | 3 | # Phase 1 - 64x64 4 | python launch.py --config configs/zero123.yaml --train --gpu 7 data.image_path=./load/images/${NAME}_rgba.png use_timestamp=False name=${NAME} tag=Phase1 # system.freq.guidance_eval=0 system.loggers.wandb.enable=false system.loggers.wandb.project="zero123" system.loggers.wandb.name=${NAME}_Phase1 5 | 6 | # Phase 1.5 - 512 refine 7 | python launch.py --config configs/zero123-geometry.yaml --train --gpu 4 data.image_path=./load/images/${NAME}_rgba.png system.geometry_convert_from=./outputs/${NAME}/Phase1/ckpts/last.ckpt use_timestamp=False name=${NAME} tag=Phase1p5 # system.freq.guidance_eval=0 system.loggers.wandb.enable=false system.loggers.wandb.project="zero123" system.loggers.wandb.name=${NAME}_Phase1p5 8 | 9 | # Phase 2 - dreamfusion 10 | python launch.py --config configs/experimental/imagecondition_zero123nerf.yaml --train --gpu 5 data.image_path=./load/images/${NAME}_rgba.png system.prompt_processor.prompt="A 3D model of a friendly dragon" system.weights="/admin/home-vikram/git/threestudio/outputs/${NAME}/Phase1/ckpts/last.ckpt" name=${NAME} tag=Phase2 # system.freq.guidance_eval=0 system.loggers.wandb.enable=false system.loggers.wandb.project="zero123" system.loggers.wandb.name=${NAME}_Phase2 11 | 12 | # Phase 2 - SDF + dreamfusion 13 | python launch.py --config configs/experimental/imagecondition_zero123nerf_refine.yaml --train --gpu 5 data.image_path=./load/images/${NAME}_rgba.png system.prompt_processor.prompt="A 3D model of a friendly dragon" system.geometry_convert_from="/admin/home-vikram/git/threestudio/outputs/${NAME}/Phase1/ckpts/last.ckpt" name=${NAME} tag=Phase2_refine # system.freq.guidance_eval=0 system.loggers.wandb.enable=false system.loggers.wandb.project="zero123" system.loggers.wandb.name=${NAME}_Phase2_refine -------------------------------------------------------------------------------- /threestudio/scripts/run_zero123_comparison.sh: -------------------------------------------------------------------------------- 1 | # with standard zero123 2 | threestudio/scripts/run_zero123_phase.sh 6 anya_front 105000 0 3 | 4 | # with zero123XL (not released yet!) 
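# usage (see threestudio/scripts/run_zero123_phase.sh further below for the full argument handling):
#   run_zero123_phase.sh <gpu_id> <image_prefix> <zero123_ckpt_prefix> <elevation_deg> [extra config overrides]
#   <image_prefix> resolves to ./load/images/<image_prefix>_rgba.png, <zero123_ckpt_prefix> to ./load/zero123/<zero123_ckpt_prefix>.ckpt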
5 | threestudio/scripts/run_zero123_phase.sh 1 anya_front XL_20230604 0 6 | threestudio/scripts/run_zero123_phase.sh 2 baby_phoenix_on_ice XL_20230604 20 7 | threestudio/scripts/run_zero123_phase.sh 3 beach_house_1 XL_20230604 50 8 | threestudio/scripts/run_zero123_phase.sh 4 bollywood_actress XL_20230604 0 9 | threestudio/scripts/run_zero123_phase.sh 5 beach_house_2 XL_20230604 30 10 | threestudio/scripts/run_zero123_phase.sh 6 hamburger XL_20230604 10 11 | threestudio/scripts/run_zero123_phase.sh 7 cactus XL_20230604 8 12 | threestudio/scripts/run_zero123_phase.sh 0 catstatue XL_20230604 50 13 | threestudio/scripts/run_zero123_phase.sh 1 church_ruins XL_20230604 0 14 | threestudio/scripts/run_zero123_phase.sh 2 firekeeper XL_20230604 10 15 | threestudio/scripts/run_zero123_phase.sh 3 futuristic_car XL_20230604 20 16 | threestudio/scripts/run_zero123_phase.sh 4 mona_lisa XL_20230604 10 17 | threestudio/scripts/run_zero123_phase.sh 5 teddy XL_20230604 20 18 | 19 | # set guidance_eval to 0, to greatly speed up training 20 | threestudio/scripts/run_zero123_phase.sh 7 anya_front XL_20230604 0 system.freq.guidance_eval=0 21 | 22 | # disable wandb for faster training (or if you don't want to use it) 23 | threestudio/scripts/run_zero123_phase.sh 7 anya_front XL_20230604 0 system.loggers.wandb.enable=false system.freq.guidance_eval=0 24 | -------------------------------------------------------------------------------- /threestudio/scripts/run_zero123_demo.sh: -------------------------------------------------------------------------------- 1 | NAME="dragon2" 2 | 3 | # Phase 1 - 64x64 4 | python launch.py --config configs/zero123_64.yaml --train --gpu 7 system.loggers.wandb.enable=false system.loggers.wandb.project="voletiv-anya-new" system.loggers.wandb.name=${NAME} data.image_path=./load/images/${NAME}_rgba.png system.freq.guidance_eval=0 system.guidance.pretrained_model_name_or_path="./load/zero123/XL_20230604.ckpt" use_timestamp=False name=${NAME} tag="Phase1_64" 5 | 6 | # python threestudio/scripts/make_training_vid.py --exp /admin/home-vikram/git/threestudio/outputs/zero123/64_dragon2_rgba.png@20230628-152734 --frames_per_vid 30 --fps 20 --max_iters 200 7 | 8 | # # Phase 1.5 - 512 9 | # python launch.py --config configs/zero123_512.yaml --train --gpu 5 system.loggers.wandb.enable=true system.loggers.wandb.project="voletiv-zero123XL-demo" system.loggers.wandb.name="robot_512_drel_n_XL_SAMEgeom" data.image_path=./load/images/robot_rgba.png system.freq.guidance_eval=0 system.guidance.pretrained_model_name_or_path="./load/zero123/XL_20230604.ckpt" tag='${data.random_camera.height}_${rmspace:${basename:${data.image_path}},_}_XL_SAMEgeom' system.weights="/admin/home-vikram/git/threestudio/outputs/zero123/[64, 128]_robot_rgba.png_OLD@20230630-052314/ckpts/last.ckpt" 10 | 11 | # Phase 1.5 - 512 refine 12 | python launch.py --config configs/zero123-geometry.yaml --train --gpu 4 system.loggers.wandb.enable=false system.loggers.wandb.project="voletiv-zero123XL-demo" system.loggers.wandb.name="robot_512_drel_n_XL_SAMEg" system.freq.guidance_eval=0 data.image_path=./load/images/${NAME}_rgba.png system.geometry_convert_from=./outputs/${NAME}/Phase1_64/ckpts/last.ckpt use_timestamp=False name=${NAME} tag="Phase2_512geom" 13 | 14 | # Phase 2 - dreamfusion 15 | python launch.py --config configs/experimental/imagecondition_zero123nerf.yaml --train --gpu 5 system.loggers.wandb.enable=false system.loggers.wandb.project="voletiv-zero123XL-demo" system.loggers.wandb.name="robot_512_drel_n_XL_SAMEw" 
tag='${data.random_camera.height}_${rmspace:${basename:${data.image_path}},_}_XL_Phase2' system.freq.guidance_eval=0 data.image_path=./load/images/robot_rgba.png system.prompt_processor.prompt="A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" system.weights="/admin/home-vikram/git/threestudio/outputs/zero123/[64, 128]_robot_rgba.png_OLD@20230630-052314/ckpts/last.ckpt" 16 | 17 | python launch.py --config configs/experimental/imagecondition_zero123nerf_refine.yaml --train --gpu 5 system.loggers.wandb.enable=false system.loggers.wandb.project="voletiv-zero123XL-demo" system.loggers.wandb.name="robot_512_drel_n_XL_SAMEw" tag='${data.random_camera.height}_${rmspace:${basename:${data.image_path}},_}_XL_Phase2_refine' system.freq.guidance_eval=0 data.image_path=./load/images/robot_rgba.png system.prompt_processor.prompt="A 3D model of a friendly dragon" system.geometry_convert_from="/admin/home-vikram/git/threestudio/outputs/zero123/[64, 128, 256]_dragon2_rgba.png_XL_REPEAT@20230705-023531/ckpts/last.ckpt" 18 | 19 | # A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" 20 | # "/admin/home-vikram/git/threestudio/outputs/zero123/[64, 128]_robot_rgba.png_OLD@20230630-052314/ckpts/last.ckpt" 21 | 22 | # Adds zero123_512-refine.yaml 23 | # Adds resolution_milestones to image.py 24 | # guidance_eval gets max batch_size 4 25 | # Introduces random_bg in solid_color_bg -------------------------------------------------------------------------------- /threestudio/scripts/run_zero123_phase.sh: -------------------------------------------------------------------------------- 1 | 2 | GPU_ID=$1 # e.g. 0 3 | IMAGE_PREFIX=$2 # e.g. "anya_front" 4 | ZERO123_PREFIX=$3 # e.g. "XL_20230604" 5 | ELEVATION=$4 # e.g. 0 6 | REST=${@:5:99} # e.g. "system.guidance.min_step_percent=0.1 system.guidance.max_step_percent=0.9" 7 | 8 | # change this config if you don't use wandb or want to speed up training 9 | python launch.py --config configs/zero123.yaml --train --gpu $GPU_ID system.loggers.wandb.enable=true system.loggers.wandb.project="claforte-noise_atten" \ 10 | system.loggers.wandb.name="${IMAGE_PREFIX}_zero123_${ZERO123_PREFIX}...fov20_${REST}" \ 11 | data.image_path=./load/images/${IMAGE_PREFIX}_rgba.png system.freq.guidance_eval=37 \ 12 | system.guidance.pretrained_model_name_or_path="./load/zero123/${ZERO123_PREFIX}.ckpt" \ 13 | system.guidance.cond_elevation_deg=$ELEVATION \ 14 | ${REST} 15 | -------------------------------------------------------------------------------- /threestudio/scripts/run_zero123_phase2.sh: -------------------------------------------------------------------------------- 1 | # Reconstruct Anya using latest Zero123XL, in <2000 steps. 
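# Note: the first command below runs Phase 1 (a NeRF reconstructed from the single RGBA image with image-conditioned Zero123 guidance);
# the second switches to text-conditioned guidance (imagecondition_zero123nerf.yaml) initialized from the Phase 1 checkpoint via system.weights.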
2 | python launch.py --config configs/zero123.yaml --train --gpu 0 system.loggers.wandb.enable=true system.loggers.wandb.project="voletiv-anya-new" system.loggers.wandb.name="claforte_params" data.image_path=./load/images/anya_front_rgba.png system.freq.ref_or_zero123="accumulate" system.freq.guidance_eval=13 system.guidance.pretrained_model_name_or_path="./load/zero123/XL_20230604.ckpt" 3 | 4 | # PHASE 2 5 | python launch.py --config configs/experimental/imagecondition_zero123nerf.yaml --train --gpu 0 system.prompt_processor.prompt="A DSLR 3D photo of a cute anime schoolgirl stands proudly with her arms in the air, pink hair ( unreal engine 5 trending on Artstation Ghibli 4k )" system.weights=outputs/zero123/128_anya_front_rgba.png@20230623-145711/ckpts/last.ckpt system.freq.guidance_eval=13 system.loggers.wandb.enable=true system.loggers.wandb.project="voletiv-anya-new" data.image_path=./load/images/anya_front_rgba.png system.loggers.wandb.name="anya" data.random_camera.progressive_until=500 -------------------------------------------------------------------------------- /threestudio/scripts/test_dreambooth.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline, DDIMScheduler 2 | import torch 3 | 4 | # model_id = "load/checkpoints/sd_21_base_mushroom_vd_prompt" 5 | # model_id = "load/checkpoints/sd_base_mushroom" 6 | model_id = ".cache/checkpoints/sd_21_base_rabbit" 7 | # scheduler = DDIMScheduler() 8 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") 9 | guidance_scale = 7.5 10 | 11 | prompt = "a sks rabbit, front view" 12 | image = pipe(prompt, num_inference_steps=50, guidance_scale=guidance_scale).images[0] 13 | 14 | image.save("debug.png") 15 | 16 | 17 | # import os 18 | # import cv2 19 | # import glob 20 | # import torch 21 | # import argparse 22 | # import numpy as np 23 | # from tqdm import tqdm 24 | # import pytorch_lightning as pl 25 | # from torchvision.utils import save_image 26 | 27 | # import threestudio 28 | # from threestudio.utils.config import load_config 29 | 30 | 31 | # if __name__ == "__main__": 32 | # parser = argparse.ArgumentParser() 33 | # parser.add_argument("--config", required=True, help="path to config file") 34 | # parser.add_argument("--view_dependent_noise", action="store_true", help="use view depdendent noise strength") 35 | 36 | # args, extras = parser.parse_known_args() 37 | 38 | # cfg = load_config(args.config, cli_args=extras, n_gpus=1) 39 | # guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance) 40 | # prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(cfg.system.prompt_processor) 41 | # prompt_utils = prompt_processor() 42 | 43 | # guidance.update_step(epoch=0, global_step=0) 44 | # elevation, azimuth = torch.zeros(1).cuda(), torch.zeros(1).cuda() 45 | # camera_distances = torch.tensor([3.0]).cuda() 46 | # c2w = torch.zeros(4,4).cuda() 47 | # a = guidance.sample(prompt_utils, elevation, azimuth, camera_distances) # sample_lora 48 | # from torchvision.utils import save_image 49 | # save_image(a.permute(0,3,1,2), "debug.png", normalize=True, value_range=(0,1)) 50 | 51 | 52 | 53 | # python threestudio/scripts/test_dreambooth.py --config configs/experimental/stablediffusion.yaml system.prompt_processor.prompt="a sks mushroom growing on a log" \ 54 | # system.guidance.pretrained_model_name_or_path_lora="load/checkpoints/sd_21_base_mushroom_camera_condition" 
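# A minimal sketch (not part of the original script): probing the same DreamBooth checkpoint
# with several view-dependent prompts to eyeball identity consistency across views.
# The checkpoint path, prompt suffixes and output file names below are illustrative assumptions.
import torch
from diffusers import StableDiffusionPipeline

model_id = ".cache/checkpoints/sd_21_base_rabbit"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
for view in ["front view", "side view", "back view", "overhead view"]:
    image = pipe(f"a sks rabbit, {view}", num_inference_steps=50, guidance_scale=7.5).images[0]
    image.save(f"debug_{view.replace(' ', '_')}.png")  # e.g. debug_front_view.png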
-------------------------------------------------------------------------------- /threestudio/scripts/test_dreambooth_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler 3 | 4 | 5 | # model_base = "stabilityai/stable-diffusion-2-1-base" 6 | 7 | # pipe = DiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16, cache_dir=CACHE_DIR, local_files_only=True) 8 | # pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, cache_dir=CACHE_DIR, local_files_only=True) 9 | # lora_model_path = "load/checkpoints/sd_21_base_bear_dreambooth_lora" 10 | # pipe.unet.load_attn_procs(lora_model_path) 11 | 12 | # pipe.to("cuda") 13 | 14 | 15 | # image = pipe("A picture of a sks bear in the sky", num_inference_steps=50, guidance_scale=7.5).images[0] 16 | # image.save("bear_dreambooth_lora.png") 17 | 18 | 19 | pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", local_files_only=True, safety_checker=None) 20 | pipe.load_lora_weights("if_dreambooth_mushroom") 21 | pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small") 22 | pipe.to("cuda:7") 23 | 24 | image = pipe("A photo of a sks mushroom, front view", num_inference_steps=50, guidance_scale=7.5).images[0] 25 | image.save("mushroom_dreambooth_lora.png") -------------------------------------------------------------------------------- /threestudio/systems/__init__.py: -------------------------------------------------------------------------------- 1 | from . import dreamcraft3d, zero123 2 | -------------------------------------------------------------------------------- /threestudio/systems/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | from bisect import bisect_right 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim import lr_scheduler 8 | 9 | import threestudio 10 | 11 | 12 | def get_scheduler(name): 13 | if hasattr(lr_scheduler, name): 14 | return getattr(lr_scheduler, name) 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | def getattr_recursive(m, attr): 20 | for name in attr.split("."): 21 | m = getattr(m, name) 22 | return m 23 | 24 | 25 | def get_parameters(model, name): 26 | module = getattr_recursive(model, name) 27 | if isinstance(module, nn.Module): 28 | return module.parameters() 29 | elif isinstance(module, nn.Parameter): 30 | return module 31 | return [] 32 | 33 | 34 | def parse_optimizer(config, model): 35 | if hasattr(config, "params"): 36 | params = [ 37 | {"params": get_parameters(model, name), "name": name, **args} 38 | for name, args in config.params.items() 39 | ] 40 | threestudio.debug(f"Specify optimizer params: {config.params}") 41 | else: 42 | params = model.parameters() 43 | if config.name in ["FusedAdam"]: 44 | import apex 45 | 46 | optim = getattr(apex.optimizers, config.name)(params, **config.args) 47 | elif config.name in ["Adan"]: 48 | from threestudio.systems import optimizers 49 | 50 | optim = getattr(optimizers, config.name)(params, **config.args) 51 | else: 52 | optim = getattr(torch.optim, config.name)(params, **config.args) 53 | return optim 54 | 55 | 56 | def parse_scheduler_to_instance(config, optimizer): 57 | if config.name == "ChainedScheduler": 58 | schedulers = [ 59 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 60 | ] 61 | scheduler = 
lr_scheduler.ChainedScheduler(schedulers) 62 | elif config.name == "Sequential": 63 | schedulers = [ 64 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 65 | ] 66 | scheduler = lr_scheduler.SequentialLR( 67 | optimizer, schedulers, milestones=config.milestones 68 | ) 69 | else: 70 | scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args) 71 | return scheduler 72 | 73 | 74 | def parse_scheduler(config, optimizer): 75 | interval = config.get("interval", "epoch") 76 | assert interval in ["epoch", "step"] 77 | if config.name == "SequentialLR": 78 | scheduler = { 79 | "scheduler": lr_scheduler.SequentialLR( 80 | optimizer, 81 | [ 82 | parse_scheduler(conf, optimizer)["scheduler"] 83 | for conf in config.schedulers 84 | ], 85 | milestones=config.milestones, 86 | ), 87 | "interval": interval, 88 | } 89 | elif config.name == "ChainedScheduler": 90 | scheduler = { 91 | "scheduler": lr_scheduler.ChainedScheduler( 92 | [ 93 | parse_scheduler(conf, optimizer)["scheduler"] 94 | for conf in config.schedulers 95 | ] 96 | ), 97 | "interval": interval, 98 | } 99 | else: 100 | scheduler = { 101 | "scheduler": get_scheduler(config.name)(optimizer, **config.args), 102 | "interval": interval, 103 | } 104 | return scheduler -------------------------------------------------------------------------------- /threestudio/utils/GAN/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence 
between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) -------------------------------------------------------------------------------- /threestudio/utils/GAN/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def generator_loss(discriminator, inputs, reconstructions, cond=None): 6 | if cond is None: 7 | logits_fake = discriminator(reconstructions.contiguous()) 8 | else: 9 | logits_fake = discriminator( 10 | torch.cat((reconstructions.contiguous(), cond), dim=1) 11 | ) 12 | g_loss = -torch.mean(logits_fake) 13 | return g_loss 14 | 15 | 16 | def hinge_d_loss(logits_real, logits_fake): 17 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 18 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 19 | d_loss = 0.5 * (loss_real + loss_fake) 20 | return d_loss 21 | 22 | 23 | def discriminator_loss(discriminator, inputs, reconstructions, cond=None): 24 | if cond is None: 25 | logits_real = discriminator(inputs.contiguous().detach()) 26 | logits_fake = discriminator(reconstructions.contiguous().detach()) 27 | else: 28 | logits_real = discriminator( 29 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 30 | ) 31 | logits_fake = discriminator( 32 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 33 | ) 34 | d_loss = hinge_d_loss(logits_real, logits_fake).mean() 35 | return d_loss -------------------------------------------------------------------------------- /threestudio/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base 2 | -------------------------------------------------------------------------------- /threestudio/utils/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from threestudio.utils.config import parse_structured 7 | from threestudio.utils.misc import get_device, load_module_weights 8 | from threestudio.utils.typing import * 9 | 10 | 11 | class Configurable: 12 | @dataclass 13 | class Config: 14 | pass 15 | 16 | def __init__(self, cfg: Optional[dict] = None) -> None: 17 | super().__init__() 18 | self.cfg = parse_structured(self.Config, cfg) 19 | 20 | 21 | class Updateable: 22 | def do_update_step( 23 | self, epoch: int, global_step: int, on_load_weights: bool = False 24 | ): 25 | for attr in self.__dir__(): 26 | if attr.startswith("_"): 27 | continue 28 | try: 29 | module = getattr(self, attr) 30 | except: 31 | continue # ignore attributes like property, which can't be retrived using getattr? 
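# (accessing an attribute via getattr can execute property code and raise, hence the broad except above)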
32 | if isinstance(module, Updateable): 33 | module.do_update_step( 34 | epoch, global_step, on_load_weights=on_load_weights 35 | ) 36 | self.update_step(epoch, global_step, on_load_weights=on_load_weights) 37 | 38 | def do_update_step_end(self, epoch: int, global_step: int): 39 | for attr in self.__dir__(): 40 | if attr.startswith("_"): 41 | continue 42 | try: 43 | module = getattr(self, attr) 44 | except: 45 | continue # ignore attributes like property, which can't be retrived using getattr? 46 | if isinstance(module, Updateable): 47 | module.do_update_step_end(epoch, global_step) 48 | self.update_step_end(epoch, global_step) 49 | 50 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 51 | # override this method to implement custom update logic 52 | # if on_load_weights is True, you should be careful doing things related to model evaluations, 53 | # as the models and tensors are not guarenteed to be on the same device 54 | pass 55 | 56 | def update_step_end(self, epoch: int, global_step: int): 57 | pass 58 | 59 | 60 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None: 61 | if isinstance(module, Updateable): 62 | module.do_update_step(epoch, global_step) 63 | 64 | 65 | def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: 66 | if isinstance(module, Updateable): 67 | module.do_update_step_end(epoch, global_step) 68 | 69 | 70 | class BaseObject(Updateable): 71 | @dataclass 72 | class Config: 73 | pass 74 | 75 | cfg: Config # add this to every subclass of BaseObject to enable static type checking 76 | 77 | def __init__( 78 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 79 | ) -> None: 80 | super().__init__() 81 | self.cfg = parse_structured(self.Config, cfg) 82 | self.device = get_device() 83 | self.configure(*args, **kwargs) 84 | 85 | def configure(self, *args, **kwargs) -> None: 86 | pass 87 | 88 | 89 | class BaseModule(nn.Module, Updateable): 90 | @dataclass 91 | class Config: 92 | weights: Optional[str] = None 93 | 94 | cfg: Config # add this to every subclass of BaseModule to enable static type checking 95 | 96 | def __init__( 97 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 98 | ) -> None: 99 | super().__init__() 100 | self.cfg = parse_structured(self.Config, cfg) 101 | self.device = get_device() 102 | self.configure(*args, **kwargs) 103 | if self.cfg.weights is not None: 104 | # format: path/to/weights:module_name 105 | weights_path, module_name = self.cfg.weights.split(":") 106 | state_dict, epoch, global_step = load_module_weights( 107 | weights_path, module_name=module_name, map_location="cpu" 108 | ) 109 | self.load_state_dict(state_dict) 110 | self.do_update_step( 111 | epoch, global_step, on_load_weights=True 112 | ) # restore states 113 | # dummy tensor to indicate model state 114 | self._dummy: Float[Tensor, "..."] 115 | self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False) 116 | 117 | def configure(self, *args, **kwargs) -> None: 118 | pass 119 | -------------------------------------------------------------------------------- /threestudio/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from datetime import datetime 4 | 5 | from omegaconf import OmegaConf 6 | 7 | import threestudio 8 | from threestudio.utils.typing import * 9 | 10 | # ============ Register OmegaConf Recolvers ============= # 11 | 
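# Example: the demo scripts above use tag='${rmspace:${basename:${data.image_path}},_}',
# which resolves to the image file name with spaces replaced by underscores via the
# 'basename' and 'rmspace' resolvers registered below.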
OmegaConf.register_new_resolver( 12 | "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) 13 | ) 14 | OmegaConf.register_new_resolver("add", lambda a, b: a + b) 15 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b) 16 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b) 17 | OmegaConf.register_new_resolver("div", lambda a, b: a / b) 18 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) 19 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) 20 | OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) 21 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) 22 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0) 23 | OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0) 24 | OmegaConf.register_new_resolver("not", lambda s: not s) 25 | OmegaConf.register_new_resolver( 26 | "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0 27 | ) 28 | # ======================================================= # 29 | 30 | 31 | def C_max(value: Any) -> float: 32 | if isinstance(value, int) or isinstance(value, float): 33 | pass 34 | else: 35 | value = config_to_primitive(value) 36 | if not isinstance(value, list): 37 | raise TypeError("Scalar specification only supports list, got", type(value)) 38 | if len(value) >= 6: 39 | max_value = value[2] 40 | for i in range(4, len(value), 2): 41 | max_value = max(max_value, value[i]) 42 | value = [value[0], value[1], max_value, value[3]] 43 | if len(value) == 3: 44 | value = [0] + value 45 | assert len(value) == 4 46 | start_step, start_value, end_value, end_step = value 47 | value = max(start_value, end_value) 48 | return value 49 | 50 | 51 | @dataclass 52 | class ExperimentConfig: 53 | name: str = "default" 54 | description: str = "" 55 | tag: str = "" 56 | seed: int = 0 57 | use_timestamp: bool = True 58 | timestamp: Optional[str] = None 59 | exp_root_dir: str = "outputs" 60 | 61 | # import custom extension 62 | custom_import: Tuple[str] = () 63 | 64 | ### these shouldn't be set manually 65 | exp_dir: str = "outputs/default" 66 | trial_name: str = "exp" 67 | trial_dir: str = "outputs/default/exp" 68 | n_gpus: int = 1 69 | ### 70 | 71 | resume: Optional[str] = None 72 | 73 | data_type: str = "" 74 | data: dict = field(default_factory=dict) 75 | 76 | system_type: str = "" 77 | system: dict = field(default_factory=dict) 78 | 79 | # accept pytorch-lightning trainer parameters 80 | # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api 81 | trainer: dict = field(default_factory=dict) 82 | 83 | # accept pytorch-lightning checkpoint callback parameters 84 | # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint 85 | checkpoint: dict = field(default_factory=dict) 86 | 87 | def __post_init__(self): 88 | if not self.tag and not self.use_timestamp: 89 | raise ValueError("Either tag is specified or use_timestamp is True.") 90 | self.trial_name = self.tag 91 | # if resume from an existing config, self.timestamp should not be None 92 | if self.timestamp is None: 93 | self.timestamp = "" 94 | if self.use_timestamp: 95 | if self.n_gpus > 1: 96 | threestudio.warn( 97 | "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." 
98 | ) 99 | else: 100 | self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") 101 | self.trial_name += self.timestamp 102 | self.exp_dir = os.path.join(self.exp_root_dir, self.name) 103 | self.trial_dir = os.path.join(self.exp_dir, self.trial_name) 104 | os.makedirs(self.trial_dir, exist_ok=True) 105 | 106 | 107 | def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: 108 | if from_string: 109 | yaml_confs = [OmegaConf.create(s) for s in yamls] 110 | else: 111 | yaml_confs = [OmegaConf.load(f) for f in yamls] 112 | cli_conf = OmegaConf.from_cli(cli_args) 113 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) 114 | OmegaConf.resolve(cfg) 115 | assert isinstance(cfg, DictConfig) 116 | scfg = parse_structured(ExperimentConfig, cfg) 117 | return scfg 118 | 119 | 120 | def config_to_primitive(config, resolve: bool = True) -> Any: 121 | return OmegaConf.to_container(config, resolve=resolve) 122 | 123 | 124 | def dump_config(path: str, config) -> None: 125 | with open(path, "w") as fp: 126 | OmegaConf.save(config=config, f=fp) 127 | 128 | 129 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 130 | scfg = OmegaConf.structured(fields(**cfg)) 131 | return scfg -------------------------------------------------------------------------------- /threestudio/utils/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | from .lpips import LPIPS -------------------------------------------------------------------------------- /threestudio/utils/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from threestudio.utils.lpips.utils import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "threestudio/utils/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, 
self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /threestudio/utils/perceptual/__init__.py: -------------------------------------------------------------------------------- 1 | from .perceptual import PerceptualLoss 2 | -------------------------------------------------------------------------------- /threestudio/utils/perceptual/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | from tqdm import tqdm 6 | 7 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 8 | 9 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 
10 | 11 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 12 | 13 | 14 | def download(url, local_path, chunk_size=1024): 15 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 16 | with requests.get(url, stream=True) as r: 17 | total_size = int(r.headers.get("content-length", 0)) 18 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 19 | with open(local_path, "wb") as f: 20 | for data in r.iter_content(chunk_size=chunk_size): 21 | if data: 22 | f.write(data) 23 | pbar.update(chunk_size) 24 | 25 | 26 | def md5_hash(path): 27 | with open(path, "rb") as f: 28 | content = f.read() 29 | return hashlib.md5(content).hexdigest() 30 | 31 | 32 | def get_ckpt_path(name, root, check=False): 33 | assert name in URL_MAP 34 | path = os.path.join(root, CKPT_MAP[name]) 35 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 36 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 37 | download(URL_MAP[name], path) 38 | md5 = md5_hash(path) 39 | assert md5 == MD5_MAP[name], md5 40 | return path 41 | 42 | 43 | class KeyNotFoundError(Exception): 44 | def __init__(self, cause, keys=None, visited=None): 45 | self.cause = cause 46 | self.keys = keys 47 | self.visited = visited 48 | messages = list() 49 | if keys is not None: 50 | messages.append("Key not found: {}".format(keys)) 51 | if visited is not None: 52 | messages.append("Visited: {}".format(visited)) 53 | messages.append("Cause:\n{}".format(cause)) 54 | message = "\n".join(messages) 55 | super().__init__(message) 56 | 57 | 58 | def retrieve( 59 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 60 | ): 61 | """Given a nested list or dict return the desired value at key expanding 62 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 63 | is done in-place. 64 | 65 | Parameters 66 | ---------- 67 | list_or_dict : list or dict 68 | Possibly nested list or dictionary. 69 | key : str 70 | key/to/value, path like string describing all keys necessary to 71 | consider to get to the desired value. List indices can also be 72 | passed here. 73 | splitval : str 74 | String that defines the delimiter between keys of the 75 | different depth levels in `key`. 76 | default : obj 77 | Value returned if :attr:`key` is not found. 78 | expand : bool 79 | Whether to expand callable nodes on the path or not. 80 | 81 | Returns 82 | ------- 83 | The desired value or if :attr:`default` is not ``None`` and the 84 | :attr:`key` is not found returns ``default``. 85 | 86 | Raises 87 | ------ 88 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 89 | ``None``. 90 | """ 91 | 92 | keys = key.split(splitval) 93 | 94 | success = True 95 | try: 96 | visited = [] 97 | parent = None 98 | last_key = None 99 | for key in keys: 100 | if callable(list_or_dict): 101 | if not expand: 102 | raise KeyNotFoundError( 103 | ValueError( 104 | "Trying to get past callable node with expand=False." 
105 | ), 106 | keys=keys, 107 | visited=visited, 108 | ) 109 | list_or_dict = list_or_dict() 110 | parent[last_key] = list_or_dict 111 | 112 | last_key = key 113 | parent = list_or_dict 114 | 115 | try: 116 | if isinstance(list_or_dict, dict): 117 | list_or_dict = list_or_dict[key] 118 | else: 119 | list_or_dict = list_or_dict[int(key)] 120 | except (KeyError, IndexError, ValueError) as e: 121 | raise KeyNotFoundError(e, keys=keys, visited=visited) 122 | 123 | visited += [key] 124 | # final expansion of retrieved value 125 | if expand and callable(list_or_dict): 126 | list_or_dict = list_or_dict() 127 | parent[last_key] = list_or_dict 128 | except KeyNotFoundError as e: 129 | if default is None: 130 | raise e 131 | else: 132 | list_or_dict = default 133 | success = False 134 | 135 | if not pass_success: 136 | return list_or_dict 137 | else: 138 | return list_or_dict, success 139 | 140 | 141 | if __name__ == "__main__": 142 | config = { 143 | "keya": "a", 144 | "keyb": "b", 145 | "keyc": { 146 | "cc1": 1, 147 | "cc2": 2, 148 | }, 149 | } 150 | from omegaconf import OmegaConf 151 | 152 | config = OmegaConf.create(config) 153 | print(config) 154 | retrieve(config, "keya") -------------------------------------------------------------------------------- /threestudio/utils/rasterize.py: -------------------------------------------------------------------------------- 1 | import nvdiffrast.torch as dr 2 | import torch 3 | 4 | from threestudio.utils.typing import * 5 | 6 | 7 | class NVDiffRasterizerContext: 8 | def __init__(self, context_type: str, device: torch.device) -> None: 9 | self.device = device 10 | self.ctx = self.initialize_context(context_type, device) 11 | 12 | def initialize_context( 13 | self, context_type: str, device: torch.device 14 | ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]: 15 | if context_type == "gl": 16 | return dr.RasterizeGLContext(device=device) 17 | elif context_type == "cuda": 18 | return dr.RasterizeCudaContext(device=device) 19 | else: 20 | raise ValueError(f"Unknown rasterizer context type: {context_type}") 21 | 22 | def vertex_transform( 23 | self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"] 24 | ) -> Float[Tensor, "B Nv 4"]: 25 | verts_homo = torch.cat( 26 | [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1 27 | ) 28 | return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1)) 29 | 30 | def rasterize( 31 | self, 32 | pos: Float[Tensor, "B Nv 4"], 33 | tri: Integer[Tensor, "Nf 3"], 34 | resolution: Union[int, Tuple[int, int]], 35 | ): 36 | # rasterize in instance mode (single topology) 37 | return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True) 38 | 39 | def rasterize_one( 40 | self, 41 | pos: Float[Tensor, "Nv 4"], 42 | tri: Integer[Tensor, "Nf 3"], 43 | resolution: Union[int, Tuple[int, int]], 44 | ): 45 | # rasterize one single mesh under a single viewpoint 46 | rast, rast_db = self.rasterize(pos[None, ...], tri, resolution) 47 | return rast[0], rast_db[0] 48 | 49 | def antialias( 50 | self, 51 | color: Float[Tensor, "B H W C"], 52 | rast: Float[Tensor, "B H W 4"], 53 | pos: Float[Tensor, "B Nv 4"], 54 | tri: Integer[Tensor, "Nf 3"], 55 | ) -> Float[Tensor, "B H W C"]: 56 | return dr.antialias(color.float(), rast, pos.float(), tri.int()) 57 | 58 | def interpolate( 59 | self, 60 | attr: Float[Tensor, "B Nv C"], 61 | rast: Float[Tensor, "B H W 4"], 62 | tri: Integer[Tensor, "Nf 3"], 63 | rast_db=None, 64 | diff_attrs=None, 65 | ) -> Float[Tensor, "B H W C"]: 66 | return dr.interpolate( 67 | 
attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs 68 | ) 69 | 70 | def interpolate_one( 71 | self, 72 | attr: Float[Tensor, "Nv C"], 73 | rast: Float[Tensor, "B H W 4"], 74 | tri: Integer[Tensor, "Nf 3"], 75 | rast_db=None, 76 | diff_attrs=None, 77 | ) -> Float[Tensor, "B H W C"]: 78 | return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs) 79 | -------------------------------------------------------------------------------- /threestudio/utils/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains type annotations for the project, using 3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 5 | 6 | Two types of typing checking can be used: 7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | ) 28 | 29 | # Tensor dtype 30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 32 | 33 | # Config type 34 | from omegaconf import DictConfig 35 | 36 | # PyTorch Tensor type 37 | from torch import Tensor 38 | 39 | # Runtime type checking decorator 40 | from typeguard import typechecked as typechecker 41 | --------------------------------------------------------------------------------
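As a quick illustration of the typing conventions documented in threestudio/utils/typing.py above (a sketch, not code from the repository; the function below and the exact decorator spelling are assumptions and may vary across jaxtyping/typeguard versions):

import torch
from jaxtyping import Float, jaxtyped
from torch import Tensor
from typeguard import typechecked as typechecker


@jaxtyped
@typechecker
def normalize_points(points: Float[Tensor, "N 3"]) -> Float[Tensor, "N 3"]:
    # the annotated dtype and shape are verified at call time
    return points / points.norm(dim=-1, keepdim=True)


normalize_points(torch.rand(8, 3))    # passes the runtime check
# normalize_points(torch.rand(8, 2))  # would fail: last dimension must be 3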