├── __init__.py ├── README.md ├── .pre-commit-config.yaml ├── guidance └── mesh_guidance.py ├── renderer └── mesh_renderer.py ├── configs └── mesh-fitting.yaml ├── geometry └── obj_mesh.py ├── system └── mesh_fitting.py └── .gitignore /__init__.py: -------------------------------------------------------------------------------- 1 | import threestudio 2 | from packaging.version import Version 3 | 4 | if hasattr(threestudio, "__version__") and Version(threestudio.__version__) >= Version( 5 | "0.2.0" 6 | ): 7 | pass 8 | else: 9 | if hasattr(threestudio, "__version__"): 10 | print(f"[INFO] threestudio version: {threestudio.__version__}") 11 | raise ValueError( 12 | "threestudio version must be >= 0.2.0, please update threestudio by pulling the latest version from github" 13 | ) 14 | 15 | from .geometry import obj_mesh 16 | from .guidance import mesh_guidance 17 | from .renderer import mesh_renderer 18 | from .system import mesh_fitting 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # threestudio-meshfitting 2 | ![mesh-fitting](https://github.com/DSaurus/threestudio-meshfitting/assets/24589363/236bbad3-902b-4212-9521-99166de8e672) 3 | 4 | A simple extension of threestudio, which enables using neural representation to fit a 3D mesh. To use it, please install [threestudio](https://github.com/threestudio-project/threestudio) first and then install this extension in threestudio `custom` directory. 5 | 6 | # Installation 7 | ``` 8 | cd custom 9 | git clone https://github.com/DSaurus/threestudio-meshfitting.git 10 | ``` 11 | 12 | # Quick Start 13 | ``` 14 | python launch.py --config custom/threestudio-meshfitting/configs/mesh-fitting.yaml --train system.geometry.shape_init="mesh:custom/threestudio-meshfitting/assets/a_Bulbasaur.obj" system.guidance.geometry.shape_init="mesh:custom/threestudio-meshfitting/assets/a_Bulbasaur.obj" 15 | ``` 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.4.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: check-ast 10 | - id: check-merge-conflict 11 | - id: check-yaml 12 | - id: end-of-file-fixer 13 | - id: trailing-whitespace 14 | args: [--markdown-linebreak-ext=md] 15 | 16 | - repo: https://github.com/psf/black 17 | rev: 23.3.0 18 | hooks: 19 | - id: black 20 | language_version: python3 21 | 22 | - repo: https://github.com/pycqa/isort 23 | rev: 5.12.0 24 | hooks: 25 | - id: isort 26 | exclude: README.md 27 | args: ["--profile", "black"] 28 | 29 | # temporarily disable static type checking 30 | # - repo: https://github.com/pre-commit/mirrors-mypy 31 | # rev: v1.2.0 32 | # hooks: 33 | # - id: mypy 34 | # args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"] 35 | -------------------------------------------------------------------------------- /guidance/mesh_guidance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import threestudio 4 | import torch.nn.functional as F 5 | from threestudio.utils.base import BaseObject 6 | from threestudio.utils.typing import * 7 | 8 | 9 | @threestudio.register("mesh-fitting-guidance") 10 | class MeshGuidance(BaseObject): 11 | @dataclass 12 | class Config(BaseObject.Config): 13 | geometry_type: str = "" 14 | geometry: dict = field(default_factory=dict) 15 | renderer_type: str = "" 16 | renderer: dict = field(default_factory=dict) 17 | material_type: str = "" 18 | material: dict = field(default_factory=dict) 19 | background_type: str = "" 20 | background: dict = field(default_factory=dict) 21 | 22 | cfg: Config 23 | 24 | def configure(self) -> None: 25 | threestudio.info(f"Loading obj") 26 | geometry = threestudio.find(self.cfg.geometry_type)(self.cfg.geometry) 27 | material = threestudio.find(self.cfg.material_type)(self.cfg.material) 28 | background = threestudio.find(self.cfg.background_type)(self.cfg.background) 29 | self.renderer = threestudio.find(self.cfg.renderer_type)( 30 | self.cfg.renderer, 31 | geometry=geometry, 32 | material=material, 33 | background=background, 34 | ) 35 | threestudio.info(f"Loaded mesh!") 36 | 37 | def __call__( 38 | self, 39 | rgb: Float[Tensor, "B H W C"], 40 | elevation: Float[Tensor, "B"], 41 | azimuth: Float[Tensor, "B"], 42 | camera_distances: Float[Tensor, "B"], 43 | rgb_as_latents=False, 44 | guidance_eval=False, 45 | **kwargs, 46 | ): 47 | guide_rgb = self.renderer(**kwargs) 48 | 49 | guidance_out = {"loss_l1": F.l1_loss(rgb, guide_rgb["comp_rgb"])} 50 | return guidance_out 51 | -------------------------------------------------------------------------------- /renderer/mesh_renderer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import nerfacc 4 | import threestudio 5 | import torch 6 | import torch.nn.functional as F 7 | from threestudio.models.background.base import BaseBackground 8 | from threestudio.models.geometry.base import BaseImplicitGeometry 9 | from threestudio.models.materials.base import BaseMaterial 10 | from threestudio.models.renderers.base import Rasterizer, VolumeRenderer 11 | from threestudio.utils.misc import get_device 12 | from threestudio.utils.rasterize import NVDiffRasterizerContext 13 | from threestudio.utils.typing import * 14 | 15 | 16 | @threestudio.register("mesh-fitting-renderer") 17 | class MeshRenderer(Rasterizer): 18 | @dataclass 19 | class Config(VolumeRenderer.Config): 20 | context_type: str = "gl" 21 | 22 | cfg: Config 23 | 24 | def configure( 25 | self, 26 | geometry: BaseImplicitGeometry, 27 | material: BaseMaterial, 28 | background: BaseBackground, 29 | ) -> None: 30 | super().configure(geometry, material, background) 31 | self.ctx = NVDiffRasterizerContext(self.cfg.context_type, get_device()) 32 | 33 | def forward( 34 | self, 35 | mvp_mtx: Float[Tensor, "B 4 4"], 36 | camera_positions: Float[Tensor, "B 3"], 37 | light_positions: Float[Tensor, "B 3"], 38 | height: int, 39 | width: int, 40 | render_rgb: bool = True, 41 | **kwargs 42 | ) -> Dict[str, Any]: 43 | batch_size = mvp_mtx.shape[0] 44 | mesh = self.geometry.isosurface() 45 | 46 | v_pos_clip: Float[Tensor, "B Nv 4"] = self.ctx.vertex_transform( 47 | mesh.v_pos, mvp_mtx 48 | ) 49 | rast, _ = self.ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width)) 50 | mask = rast[..., 3:] > 0 51 | mask_aa = self.ctx.antialias(mask.float(), rast, v_pos_clip, mesh.t_pos_idx) 52 | 53 | out = {"opacity": mask_aa, "mesh": mesh} 54 | 55 | gb_rgb_fg, _ = self.ctx.interpolate_one(mesh.v_color, rast, mesh.t_pos_idx) 56 | gb_pos, _ = self.ctx.interpolate_one(mesh.v_pos, rast, mesh.t_pos_idx) 57 | gb_viewdirs = F.normalize(gb_pos - camera_positions[:, None, None, :], dim=-1) 58 | gb_rgb_bg = self.background(dirs=gb_viewdirs) 59 | gb_rgb = torch.lerp(gb_rgb_bg, gb_rgb_fg, mask.float()) 60 | gb_rgb_aa = self.ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx) 61 | 62 | out.update({"comp_rgb": gb_rgb_aa, "comp_rgb_bg": gb_rgb_bg}) 63 | 64 | return out 65 | -------------------------------------------------------------------------------- /configs/mesh-fitting.yaml: -------------------------------------------------------------------------------- 1 | name: "mesh-fitting" 2 | tag: "${rmspace:,_}" 3 | exp_root_dir: "outputs" 4 | seed: 0 5 | 6 | data_type: "random-camera-datamodule" 7 | data: 8 | width: 512 9 | height: 512 10 | camera_distance_range: [1.5, 2.0] 11 | elevation_range: [-10, 45] 12 | light_sample_strategy: "magic3d" 13 | fovy_range: [30, 45] 14 | eval_camera_distance: 2.0 15 | eval_fovy_deg: 70. 16 | 17 | system_type: "mesh-fitting-system" 18 | system: 19 | refinement: true 20 | geometry_convert_inherit_texture: true 21 | geometry_type: "custom-mesh" 22 | geometry: 23 | shape_init: ??? 24 | shape_init_params: 1.0 25 | radius: 1.0 # consistent with coarse 26 | pos_encoding_config: 27 | otype: HashGrid 28 | n_levels: 16 29 | n_features_per_level: 2 30 | log2_hashmap_size: 19 31 | base_resolution: 16 32 | per_level_scale: 1.4472692374403782 # max resolution 4096 33 | n_feature_dims: 8 # albedo3 + roughness1 + metallic1 + bump3 34 | shape_init_mesh_up: "+y" 35 | shape_init_mesh_front: "-z" 36 | 37 | material_type: "diffuse-with-point-light-material" 38 | material: 39 | ambient_only_steps: 0 40 | soft_shading: true 41 | 42 | background_type: "solid-color-background" 43 | background: 44 | n_output_dims: 3 45 | color: [0, 0, 0] 46 | 47 | renderer_type: "nvdiff-rasterizer" 48 | renderer: 49 | context_type: cuda 50 | 51 | guidance_type: "mesh-fitting-guidance" 52 | guidance: 53 | geometry_type: "mesh-fitting-obj-mesh" 54 | geometry: 55 | shape_init: ??? 56 | shape_init_params: 1.0 57 | shape_init_mesh_up: "+y" 58 | shape_init_mesh_front: "-z" 59 | material_type: "diffuse-with-point-light-material" 60 | material: 61 | ambient_only_steps: 0 62 | soft_shading: true 63 | background_type: "solid-color-background" 64 | background: 65 | n_output_dims: 3 66 | color: [0, 0, 0] 67 | renderer_type: "mesh-fitting-renderer" 68 | renderer: 69 | context_type: cuda 70 | 71 | loggers: 72 | wandb: 73 | enable: false 74 | project: "threestudio" 75 | name: None 76 | 77 | loss: 78 | lambda_l1: 1. 79 | 80 | optimizer: 81 | name: Adam 82 | args: 83 | lr: 0.01 84 | betas: [0.9, 0.99] 85 | eps: 1.e-15 86 | params: 87 | geometry: 88 | lr: 0.001 89 | 90 | trainer: 91 | max_steps: 5000 92 | log_every_n_steps: 1 93 | num_sanity_val_steps: 1 94 | val_check_interval: 100 95 | enable_progress_bar: true 96 | precision: 16-mixed 97 | 98 | checkpoint: 99 | save_last: true 100 | save_top_k: -1 101 | every_n_train_steps: ${trainer.max_steps} 102 | -------------------------------------------------------------------------------- /geometry/obj_mesh.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import numpy as np 5 | import threestudio 6 | import torch 7 | import trimesh 8 | from threestudio.models.geometry.base import BaseExplicitGeometry, contract_to_unisphere 9 | from threestudio.models.mesh import Mesh 10 | from threestudio.utils.typing import * 11 | 12 | 13 | @threestudio.register("mesh-fitting-obj-mesh") 14 | class OBJMesh(BaseExplicitGeometry): 15 | @dataclass 16 | class Config(BaseExplicitGeometry.Config): 17 | shape_init: str = "" 18 | shape_init_params: Optional[Any] = None 19 | shape_init_mesh_up: str = "+z" 20 | shape_init_mesh_front: str = "+x" 21 | 22 | cfg: Config 23 | 24 | def configure(self) -> None: 25 | super().configure() 26 | # Initialize custom mesh 27 | if self.cfg.shape_init.startswith("mesh:"): 28 | assert isinstance(self.cfg.shape_init_params, float) 29 | mesh_path = self.cfg.shape_init[5:] 30 | if not os.path.exists(mesh_path): 31 | raise ValueError(f"Mesh file {mesh_path} does not exist.") 32 | 33 | scene = trimesh.load(mesh_path) 34 | if isinstance(scene, trimesh.Trimesh): 35 | mesh = scene 36 | elif isinstance(scene, trimesh.scene.Scene): 37 | mesh = trimesh.Trimesh() 38 | for obj in scene.geometry.values(): 39 | mesh = trimesh.util.concatenate([mesh, obj]) 40 | else: 41 | raise ValueError(f"Unknown mesh type at {mesh_path}.") 42 | 43 | # move to center 44 | centroid = mesh.vertices.mean(0) 45 | mesh.vertices = mesh.vertices - centroid 46 | 47 | # align to up-z and front-x 48 | dirs = ["+x", "+y", "+z", "-x", "-y", "-z"] 49 | dir2vec = { 50 | "+x": np.array([1, 0, 0]), 51 | "+y": np.array([0, 1, 0]), 52 | "+z": np.array([0, 0, 1]), 53 | "-x": np.array([-1, 0, 0]), 54 | "-y": np.array([0, -1, 0]), 55 | "-z": np.array([0, 0, -1]), 56 | } 57 | if ( 58 | self.cfg.shape_init_mesh_up not in dirs 59 | or self.cfg.shape_init_mesh_front not in dirs 60 | ): 61 | raise ValueError( 62 | f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}." 63 | ) 64 | if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]: 65 | raise ValueError( 66 | "shape_init_mesh_up and shape_init_mesh_front must be orthogonal." 67 | ) 68 | z_, x_ = ( 69 | dir2vec[self.cfg.shape_init_mesh_up], 70 | dir2vec[self.cfg.shape_init_mesh_front], 71 | ) 72 | y_ = np.cross(z_, x_) 73 | std2mesh = np.stack([x_, y_, z_], axis=0).T 74 | mesh2std = np.linalg.inv(std2mesh) 75 | 76 | # scaling 77 | scale = np.abs(mesh.vertices).max() 78 | mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params 79 | mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T 80 | 81 | self.v_pos = torch.tensor(mesh.vertices, dtype=torch.float32).to( 82 | self.device 83 | ) 84 | self.v_color = torch.tensor( 85 | mesh.visual.vertex_colors[:, :3] / 255, dtype=torch.float32 86 | ).to(self.device) 87 | self.t_pos_idx = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device) 88 | 89 | else: 90 | raise ValueError( 91 | f"Unknown shape initialization type: {self.cfg.shape_init}" 92 | ) 93 | 94 | def isosurface(self): 95 | return self 96 | -------------------------------------------------------------------------------- /system/mesh_fitting.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import threestudio 4 | from threestudio.systems.base import BaseLift3DSystem 5 | from threestudio.utils.typing import * 6 | 7 | 8 | @threestudio.register("mesh-fitting-system") 9 | class MeshFittingSystem(BaseLift3DSystem): 10 | @dataclass 11 | class Config(BaseLift3DSystem.Config): 12 | refinement: bool = False 13 | 14 | cfg: Config 15 | 16 | def configure(self): 17 | # create geometry, material, background, renderer 18 | super().configure() 19 | 20 | def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: 21 | render_out = self.renderer(**batch) 22 | return { 23 | **render_out, 24 | } 25 | 26 | def on_fit_start(self) -> None: 27 | super().on_fit_start() 28 | self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance) 29 | 30 | def training_step(self, batch, batch_idx): 31 | out = self(batch) 32 | guidance_out = self.guidance(out["comp_rgb"], **batch, rgb_as_latents=False) 33 | loss = 0.0 34 | for name, value in guidance_out.items(): 35 | self.log(f"train/{name}", value) 36 | if name.startswith("loss_"): 37 | loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) 38 | 39 | for name, value in self.cfg.loss.items(): 40 | self.log(f"train_params/{name}", self.C(value)) 41 | 42 | return {"loss": loss} 43 | 44 | def validation_step(self, batch, batch_idx): 45 | out = self(batch) 46 | self.save_image_grid( 47 | f"it{self.true_global_step}-{batch['index'][0]}.png", 48 | [ 49 | { 50 | "type": "rgb", 51 | "img": out["comp_rgb"][0], 52 | "kwargs": {"data_format": "HWC"}, 53 | }, 54 | ] 55 | + ( 56 | [ 57 | { 58 | "type": "rgb", 59 | "img": out["comp_normal"][0], 60 | "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, 61 | } 62 | ] 63 | if "comp_normal" in out 64 | else [] 65 | ) 66 | + [ 67 | { 68 | "type": "grayscale", 69 | "img": out["opacity"][0, :, :, 0], 70 | "kwargs": {"cmap": None, "data_range": (0, 1)}, 71 | }, 72 | ], 73 | name="validation_step", 74 | step=self.true_global_step, 75 | ) 76 | 77 | def on_validation_epoch_end(self): 78 | pass 79 | 80 | def test_step(self, batch, batch_idx): 81 | out = self(batch) 82 | self.save_image_grid( 83 | f"it{self.true_global_step}-test/{batch['index'][0]}.png", 84 | [ 85 | { 86 | "type": "rgb", 87 | "img": out["comp_rgb"][0], 88 | "kwargs": {"data_format": "HWC"}, 89 | }, 90 | ] 91 | + ( 92 | [ 93 | { 94 | "type": "rgb", 95 | "img": out["comp_normal"][0], 96 | "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, 97 | } 98 | ] 99 | if "comp_normal" in out 100 | else [] 101 | ) 102 | + [ 103 | { 104 | "type": "grayscale", 105 | "img": out["opacity"][0, :, :, 0], 106 | "kwargs": {"cmap": None, "data_range": (0, 1)}, 107 | }, 108 | ], 109 | name="test_step", 110 | step=self.true_global_step, 111 | ) 112 | 113 | def on_test_epoch_end(self): 114 | self.save_img_sequence( 115 | f"it{self.true_global_step}-test", 116 | f"it{self.true_global_step}-test", 117 | "(\d+)\.png", 118 | save_format="mp4", 119 | fps=30, 120 | name="test", 121 | step=self.true_global_step, 122 | ) 123 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | 178 | .vscode/ 179 | .threestudio_cache/ 180 | outputs/ 181 | outputs-gradio/ 182 | 183 | # pretrained model weights 184 | *.ckpt 185 | *.pt 186 | *.pth 187 | 188 | # wandb 189 | wandb/ 190 | 191 | custom/* 192 | 193 | load/tets/256_tets.npz 194 | --------------------------------------------------------------------------------