├── data
│   └── .gitkeep
├── gen_slices
│   ├── ldm
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── lsun.py
│   │   │   ├── custom_sin_img.py
│   │   │   └── objaverse.py
│   │   ├── models
│   │   │   └── diffusion
│   │   │       └── __init__.py
│   │   ├── modules
│   │   │   ├── encoders
│   │   │   │   └── __init__.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── diffusionmodules
│   │   │   │   └── __init__.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   └── contperceptual.py
│   │   │   ├── image_degradation
│   │   │   │   ├── utils
│   │   │   │   │   └── test.png
│   │   │   │   └── __init__.py
│   │   │   └── ema.py
│   │   ├── .DS_Store
│   │   └── lr_scheduler.py
│   ├── assets
│   │   └── .DS_Store
│   ├── configs
│   │   ├── .DS_Store
│   │   ├── autoencoder
│   │   │   ├── autoencoder_kl_32x32x4.yaml
│   │   │   ├── autoencoder_kl_64x64x3.yaml
│   │   │   ├── autoencoder_kl_8x8x64.yaml
│   │   │   ├── autoencoder_kl_16x16x16.yaml
│   │   │   └── autoencoder_kl_f8_infer.yaml
│   │   ├── latent-diffusion
│   │   │   ├── cin256-v2.yaml
│   │   │   ├── txt2img-1p4B-eval.yaml
│   │   │   ├── ffhq-ldm-vq-4.yaml
│   │   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   │   ├── celebahq-ldm-vq-4.yaml
│   │   │   ├── objaverse-ldm-kl-8.yaml
│   │   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   │   ├── objaverse-ldm-kl-8-infer.yaml
│   │   │   ├── custom-sin-img-ldm-kl-8-infer.yaml
│   │   │   └── cin-ldm-vq-f8.yaml
│   │   └── retrieval-augmented-diffusion
│   │       └── 768x768.yaml
│   ├── models
│   │   ├── .DS_Store
│   │   ├── first_stage_models
│   │   │   ├── kl-f4
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f8
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f16
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f32
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4-noattn
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f8-n256
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f16
│   │   │   │   └── config.yaml
│   │   │   └── vq-f8
│   │   │       └── config.yaml
│   │   └── ldm
│   │       ├── semantic_synthesis256
│   │       │   └── config.yaml
│   │       ├── ffhq256
│   │       │   └── config.yaml
│   │       ├── celeba256
│   │       │   └── config.yaml
│   │       ├── lsun_beds256
│   │       │   └── config.yaml
│   │       ├── inpainting_big
│   │       │   └── config.yaml
│   │       ├── text2img256
│   │       │   └── config.yaml
│   │       ├── semantic_synthesis512
│   │       │   └── config.yaml
│   │       ├── cin256
│   │       │   └── config.yaml
│   │       ├── bsr_sr
│   │       │   └── config.yaml
│   │       ├── layout2img-openimages256
│   │       │   └── config.yaml
│   │       └── lsun_churches256
│   │           └── config.yaml
│   ├── setup.py
│   ├── environment.yaml
│   ├── scripts
│   │   ├── download_first_stages.sh
│   │   ├── download_models.sh
│   │   ├── inpaint.py
│   │   └── txt2img.py
│   └── re_org_slices.py
├── reg_slices
│   ├── src_convonet
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── libvoxelize
│   │       │   ├── __init__.py
│   │       │   ├── .gitignore
│   │       │   └── voxelize.pyx
│   │       ├── libkdtree
│   │       │   ├── .gitignore
│   │       │   ├── pykdtree
│   │       │   │   ├── __init__.py
│   │       │   │   └── render_template.py
│   │       │   ├── setup.cfg
│   │       │   ├── MANIFEST.in
│   │       │   └── __init__.py
│   │       ├── libmcubes
│   │       │   ├── .gitignore
│   │       │   ├── pyarray_symbol.h
│   │       │   ├── __init__.py
│   │       │   ├── pywrapper.h
│   │       │   ├── LICENSE
│   │       │   ├── mcubes.pyx
│   │       │   ├── exporter.py
│   │       │   ├── README.rst
│   │       │   └── pyarraymodule.h
│   │       ├── libmesh
│   │       │   ├── .gitignore
│   │       │   ├── __init__.py
│   │       │   └── triangle_hash.pyx
│   │       ├── libmise
│   │       │   ├── .gitignore
│   │       │   ├── __init__.py
│   │       │   └── test.py
│   │       ├── libsimplify
│   │       │   ├── test.py
│   │       │   ├── __init__.py
│   │       │   └── simplify_mesh.pyx
│   │       ├── visualize.py
│   │       ├── io.py
│   │       └── icp.py
│   ├── src
│   │   ├── vgg16bn_feats.py
│   │   ├── vgg16bn_feats_for_disn.py
│   │   ├── unet_custom.py
│   │   ├── unet_parts.py
│   │   ├── utils_eval.py
│   │   ├── vgg_perceptual_loss.py
│   │   ├── models.py
│   │   ├── model_disn.py
│   │   ├── model_gt.py
│   │   └── datasets_cam.py
│   ├── setup.py
│   ├── options.py
│   ├── test_projection.py
│   └── reconstruct_slices.py
├── imgs
│   ├── demo
│   │   └── input.png
│   └── teaser
│       └── slice3d.jpg
├── LICENSE.txt
├── render_slices
│   ├── gen_input.py
│   └── gen_slices.py
├── create_dataset_sin_img.py
└── .gitignore

/data/.gitkeep:
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gen_slices/ldm/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gen_slices/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gen_slices/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gen_slices/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gen_slices/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libvoxelize/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libkdtree/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libkdtree/pykdtree/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmcubes/.gitignore: -------------------------------------------------------------------------------- 1 | PyMCubes.egg-info 2 | build 3 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmesh/.gitignore: -------------------------------------------------------------------------------- 1 | triangle_hash.cpp 2 | build 3 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmise/.gitignore: -------------------------------------------------------------------------------- 1 | mise.c 2 | mise.cpp 3 | mise.html 4 | -------------------------------------------------------------------------------- /imgs/demo/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhiwang96/Slice3D/HEAD/imgs/demo/input.png -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libvoxelize/.gitignore: -------------------------------------------------------------------------------- 1 | voxelize.c 2 | voxelize.html 3 | build 4 | -------------------------------------------------------------------------------- /gen_slices/ldm/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yizhiwang96/Slice3D/HEAD/gen_slices/ldm/.DS_Store -------------------------------------------------------------------------------- /imgs/teaser/slice3d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhiwang96/Slice3D/HEAD/imgs/teaser/slice3d.jpg -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libkdtree/setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_rpm] 2 | requires=numpy 3 | release=1 4 | 5 | 6 | -------------------------------------------------------------------------------- /gen_slices/assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhiwang96/Slice3D/HEAD/gen_slices/assets/.DS_Store -------------------------------------------------------------------------------- /gen_slices/configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhiwang96/Slice3D/HEAD/gen_slices/configs/.DS_Store -------------------------------------------------------------------------------- /gen_slices/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /gen_slices/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhiwang96/Slice3D/HEAD/gen_slices/models/.DS_Store -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libkdtree/MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude pykdtree/render_template.py 2 | include LICENSE.txt 3 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmcubes/pyarray_symbol.h: -------------------------------------------------------------------------------- 1 | 2 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API 3 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmise/__init__.py: -------------------------------------------------------------------------------- 1 | from .mise import MISE 2 | 3 | __all__ = [ 4 | MISE 5 | ] 6 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libkdtree/__init__.py: -------------------------------------------------------------------------------- 1 | from .pykdtree.kdtree import KDTree 2 | 3 | 4 | __all__ = [ 5 | KDTree 6 | ] 7 | -------------------------------------------------------------------------------- /gen_slices/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizhiwang96/Slice3D/HEAD/gen_slices/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libsimplify/test.py: -------------------------------------------------------------------------------- 1 | from simplify_mesh import mesh_simplify 2 | import numpy as np 3 | 4 | v = np.random.rand(100, 3) 5 | f = 
np.random.choice(range(100), (50, 3)) 6 | 7 | mesh_simplify(v, f, 50) -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmesh/__init__.py: -------------------------------------------------------------------------------- 1 | from .inside_mesh import ( 2 | check_mesh_contains, MeshIntersector, TriangleIntersector2d 3 | ) 4 | 5 | 6 | __all__ = [ 7 | check_mesh_contains, MeshIntersector, TriangleIntersector2d 8 | ] 9 | -------------------------------------------------------------------------------- /gen_slices/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libkdtree/pykdtree/render_template.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from mako.template import Template 4 | 5 | mytemplate = Template(filename='_kdtree_core.c.mako') 6 | with open('_kdtree_core.c', 'w') as fp: 7 | fp.write(mytemplate.render()) 8 | -------------------------------------------------------------------------------- /gen_slices/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmcubes/__init__.py: -------------------------------------------------------------------------------- 1 | from src_convonet.utils.libmcubes.mcubes import ( 2 | marching_cubes, marching_cubes_func 3 | ) 4 | from src_convonet.utils.libmcubes.exporter import ( 5 | export_mesh, export_obj, export_off 6 | ) 7 | 8 | 9 | __all__ = [ 10 | marching_cubes, marching_cubes_func, 11 | export_mesh, export_obj, export_off 12 | ] 13 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libsimplify/__init__.py: -------------------------------------------------------------------------------- 1 | from .simplify_mesh import ( 2 | mesh_simplify 3 | ) 4 | import trimesh 5 | 6 | 7 | def simplify_mesh(mesh, f_target=10000, agressiveness=7.): 8 | vertices = mesh.vertices 9 | faces = mesh.faces 10 | 11 | vertices, faces = mesh_simplify(vertices, faces, f_target, agressiveness) 12 | 13 | mesh_simplified = trimesh.Trimesh(vertices, faces, process=False) 14 | 15 | return mesh_simplified 16 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmcubes/pywrapper.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _PYWRAPPER_H 3 | #define _PYWRAPPER_H 4 | 5 | #include 6 | #include "pyarraymodule.h" 7 | 8 | #include 9 | 10 | PyObject* marching_cubes(PyArrayObject* arr, double isovalue); 11 | PyObject* marching_cubes2(PyArrayObject* arr, double isovalue); 12 | PyObject* marching_cubes3(PyArrayObject* arr, double isovalue); 13 | PyObject* marching_cubes_func(PyObject* lower, PyObject* 
upper, 14 | int numx, int numy, int numz, PyObject* f, double isovalue); 15 | 16 | #endif // _PYWRAPPER_H 17 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmise/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mise import MISE 3 | import time 4 | 5 | t0 = time.time() 6 | extractor = MISE(1, 2, 0.) 7 | 8 | p = extractor.query() 9 | i = 0 10 | 11 | while p.shape[0] != 0: 12 | print(i) 13 | print(p) 14 | v = 2 * (p.sum(axis=-1) > 2).astype(np.float64) - 1 15 | extractor.update(p, v) 16 | p = extractor.query() 17 | i += 1 18 | if (i >= 8): 19 | break 20 | 21 | print(extractor.to_dense()) 22 | # p, v = extractor.get_points() 23 | # print(p) 24 | # print(v) 25 | print('Total time: %f' % (time.time() - t0)) 26 | -------------------------------------------------------------------------------- /gen_slices/environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.0 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.4.2 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - torch-fidelity==0.3.0 24 | - transformers==4.3.1 25 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 26 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 27 | - -e . -------------------------------------------------------------------------------- /gen_slices/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /gen_slices/models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | 
params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /gen_slices/models/first_stage_models/kl-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yizhi Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /gen_slices/models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | -------------------------------------------------------------------------------- /gen_slices/models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /gen_slices/models/first_stage_models/vq-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | -------------------------------------------------------------------------------- 
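Note on the config files above and below: each model block pairs a dotted "target" class path with the constructor keyword arguments under "params". The following is a minimal sketch of how such a block is typically turned into a model object; it assumes OmegaConf (pinned in gen_slices/environment.yaml), the helper shown is a simplified stand-in for the instantiate_from_config utility that the latent-diffusion codebase provides (that utility is not included in this excerpt), and the config path is only an example, resolved relative to gen_slices/.

# Minimal sketch; not a file in this repository.
import importlib

from omegaconf import OmegaConf


def instantiate_from_config(config):
    # "target" is a dotted import path; "params" holds the constructor kwargs.
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", dict()))


cfg = OmegaConf.load("models/first_stage_models/kl-f8/config.yaml")  # example path
autoencoder = instantiate_from_config(cfg.model)  # builds ldm.models.autoencoder.AutoencoderKL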
/gen_slices/models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /gen_slices/models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /gen_slices/models/first_stage_models/vq-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | 
crop_size: 256 50 | -------------------------------------------------------------------------------- /gen_slices/models/first_stage_models/vq-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /gen_slices/configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /gen_slices/configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: 
ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /gen_slices/configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /gen_slices/configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /gen_slices/scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/first_stage_models/kl-f4/model.zip 
https://ommer-lab.com/files/latent-diffusion/kl-f4.zip 3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip 5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip 6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip 8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip 10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip 11 | 12 | 13 | 14 | cd models/first_stage_models/kl-f4 15 | unzip -o model.zip 16 | 17 | cd ../kl-f8 18 | unzip -o model.zip 19 | 20 | cd ../kl-f16 21 | unzip -o model.zip 22 | 23 | cd ../kl-f32 24 | unzip -o model.zip 25 | 26 | cd ../vq-f4 27 | unzip -o model.zip 28 | 29 | cd ../vq-f4-noattn 30 | unzip -o model.zip 31 | 32 | cd ../vq-f8 33 | unzip -o model.zip 34 | 35 | cd ../vq-f8-n256 36 | unzip -o model.zip 37 | 38 | cd ../vq-f16 39 | unzip -o model.zip 40 | 41 | cd ../.. -------------------------------------------------------------------------------- /gen_slices/configs/autoencoder/autoencoder_kl_f8_infer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | ckpt_path: "logs/autoencoder_kl_f8/checkpoints/model.ckpt" 7 | embed_dim: 4 8 | lossconfig: 9 | target: ldm.modules.losses.LPIPSWithDiscriminator 10 | params: 11 | disc_start: 50001 12 | kl_weight: 0.000001 13 | disc_weight: 0.5 14 | 15 | ddconfig: 16 | double_z: True 17 | z_channels: 4 18 | resolution: 256 19 | in_channels: 3 20 | out_ch: 3 21 | ch: 128 22 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 23 | num_res_blocks: 2 24 | attn_resolutions: [ ] 25 | dropout: 0.0 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 8 31 | num_workers: 5 32 | wrap: False 33 | train: 34 | target: ldm.data.objaverse.ObjaverseTrain 35 | params: 36 | size: 128 37 | validation: 38 | target: ldm.data.objaverse.ObjaverseValidation 39 | params: 40 | size: 128 41 | test: 42 | target: ldm.data.objaverse.ObjaverseTrainValRec 43 | params: 44 | size: 128 45 | 46 | lightning: 47 | callbacks: 48 | image_logger: 49 | target: main.ImageLogger 50 | params: 51 | batch_frequency: 1000 52 | max_images: 8 53 | increase_log_steps: True 54 | 55 | trainer: 56 | benchmark: True 57 | accumulate_grad_batches: 2 58 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmcubes/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2015, P. M. Neila 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 
9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /gen_slices/models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | lossconfig: 53 | target: torch.nn.Identity 54 | cond_stage_config: 55 | target: ldm.modules.encoders.modules.SpatialRescaler 56 | params: 57 | n_stages: 2 58 | in_channels: 182 59 | out_channels: 3 60 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmcubes/mcubes.pyx: -------------------------------------------------------------------------------- 1 | 2 | # distutils: language = c++ 3 | # cython: embedsignature = True 4 | 5 | # from libcpp.vector cimport vector 6 | import numpy as np 7 | 8 | # Define PY_ARRAY_UNIQUE_SYMBOL 9 | cdef extern from "pyarray_symbol.h": 10 | pass 11 | 12 | cimport numpy as np 13 | 14 | np.import_array() 15 | 16 | cdef extern from "pywrapper.h": 17 | cdef object c_marching_cubes "marching_cubes"(np.ndarray, double) except + 18 | cdef object c_marching_cubes2 "marching_cubes2"(np.ndarray, double) except + 19 | cdef object c_marching_cubes3 "marching_cubes3"(np.ndarray, double) except + 20 | cdef object 
c_marching_cubes_func "marching_cubes_func"(tuple, tuple, int, int, int, object, double) except + 21 | 22 | def marching_cubes(np.ndarray volume, float isovalue): 23 | 24 | verts, faces = c_marching_cubes(volume, isovalue) 25 | verts.shape = (-1, 3) 26 | faces.shape = (-1, 3) 27 | return verts, faces 28 | 29 | def marching_cubes2(np.ndarray volume, float isovalue): 30 | 31 | verts, faces = c_marching_cubes2(volume, isovalue) 32 | verts.shape = (-1, 3) 33 | faces.shape = (-1, 3) 34 | return verts, faces 35 | 36 | def marching_cubes3(np.ndarray volume, float isovalue): 37 | 38 | verts, faces = c_marching_cubes3(volume, isovalue) 39 | verts.shape = (-1, 3) 40 | faces.shape = (-1, 3) 41 | return verts, faces 42 | 43 | def marching_cubes_func(tuple lower, tuple upper, int numx, int numy, int numz, object f, double isovalue): 44 | 45 | verts, faces = c_marching_cubes_func(lower, upper, numx, numy, numz, f, isovalue) 46 | verts.shape = (-1, 3) 47 | faces.shape = (-1, 3) 48 | return verts, faces 49 | -------------------------------------------------------------------------------- /render_slices/gen_input.py: -------------------------------------------------------------------------------- 1 | import os 2 | from joblib import Parallel, delayed 3 | from csv import reader 4 | import json 5 | 6 | path_root_objaverse = '/data/wangyz/03_datasets/' 7 | path_output = '../data/objaverse/00_img_input' 8 | path_blender = '/data/wangyz/04_blender_renderer/blender-3.6.0-linux-x64' 9 | 10 | def gen_input(sid, spath): 11 | 12 | try: 13 | cmd = f'{path_blender}/blender -b -P blender_script_input.py -- \ 14 | --object_path {path_root_objaverse}/{spath} \ 15 | --output_dir ./{path_output} \ 16 | --engine CYCLES \ 17 | --num_images 12 ' 18 | 19 | os.system(cmd) 20 | except: 21 | f = open(f'./processed_objaverse_input_lvis/dataset_input/failed/{sid}.txt', 'w') 22 | f.close() 23 | return 24 | 25 | def get_shape_ids_and_paths(): 26 | shape_ids = [] 27 | sid2spath = {} 28 | shape_paths = [] 29 | # Open the JSON file 30 | with open('../data/objaverse/input_models_path-lvis.json', 'r') as file: 31 | data = json.load(file) 32 | for index, item in enumerate(data): 33 | if os.path.exists(f'{path_root_objaverse}/{item}'): 34 | shape_ids.append(item.split('/')[-1].split('.')[0]) 35 | shape_paths.append(item) 36 | sid2spath[item.split('/')[-1].split('.')[0]] = item 37 | return shape_ids, sid2spath 38 | 39 | def main(): 40 | 41 | shape_ids, sid2spath = get_shape_ids_and_paths() 42 | shape_ids = ['0a0c7e40a66d4fd090f549599f2f2c9d'] # this is an example, delete this line when creating a dataset 43 | with Parallel(n_jobs=8) as p: 44 | p(delayed(gen_input)(sid=sid, spath=sid2spath[sid]) for idx, sid in enumerate(shape_ids)) 45 | 46 | main() 47 | -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 
3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /gen_slices/scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip 4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip 5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip 7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip 8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip 9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip 10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip 11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip 12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip 13 | 14 | 15 | 16 | cd models/ldm/celeba256 17 | unzip -o celeba-256.zip 18 | 19 | cd ../ffhq256 20 | unzip -o ffhq-256.zip 21 | 22 | cd ../lsun_churches256 23 | unzip -o lsun_churches-256.zip 24 | 25 | cd ../lsun_beds256 26 | unzip -o lsun_beds-256.zip 27 | 28 | cd ../text2img256 29 | unzip -o model.zip 30 | 31 | cd ../cin256 32 | unzip -o model.zip 33 | 34 | cd ../semantic_synthesis512 35 | unzip -o model.zip 36 | 37 | cd ../semantic_synthesis256 38 | unzip -o model.zip 39 | 40 | cd ../bsr_sr 41 | unzip -o model.zip 42 | 43 | cd ../layout2img-openimages256 44 | unzip -o model.zip 45 | 46 | cd ../inpainting_big 47 | unzip -o model.zip 48 | 49 | cd ../.. 
50 | -------------------------------------------------------------------------------- /gen_slices/models/ldm/ffhq256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 42 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /gen_slices/models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /gen_slices/models/ldm/lsun_beds256/config.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.lsun.LSUNBedroomsTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.lsun.LSUNBedroomsValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /gen_slices/models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 | first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- /gen_slices/configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | 
target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.015 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmcubes/exporter.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | def export_obj(vertices, triangles, filename): 6 | """ 7 | Exports a mesh in the (.obj) format. 
8 | """ 9 | 10 | with open(filename, 'w') as fh: 11 | 12 | for v in vertices: 13 | fh.write("v {} {} {}\n".format(*v)) 14 | 15 | for f in triangles: 16 | fh.write("f {} {} {}\n".format(*(f + 1))) 17 | 18 | 19 | def export_off(vertices, triangles, filename): 20 | """ 21 | Exports a mesh in the (.off) format. 22 | """ 23 | 24 | with open(filename, 'w') as fh: 25 | fh.write('OFF\n') 26 | fh.write('{} {} 0\n'.format(len(vertices), len(triangles))) 27 | 28 | for v in vertices: 29 | fh.write("{} {} {}\n".format(*v)) 30 | 31 | for f in triangles: 32 | fh.write("3 {} {} {}\n".format(*f)) 33 | 34 | 35 | def export_mesh(vertices, triangles, filename, mesh_name="mcubes_mesh"): 36 | """ 37 | Exports a mesh in the COLLADA (.dae) format. 38 | 39 | Needs PyCollada (https://github.com/pycollada/pycollada). 40 | """ 41 | 42 | import collada 43 | 44 | mesh = collada.Collada() 45 | 46 | vert_src = collada.source.FloatSource("verts-array", vertices, ('X','Y','Z')) 47 | geom = collada.geometry.Geometry(mesh, "geometry0", mesh_name, [vert_src]) 48 | 49 | input_list = collada.source.InputList() 50 | input_list.addInput(0, 'VERTEX', "#verts-array") 51 | 52 | triset = geom.createTriangleSet(np.copy(triangles), input_list, "") 53 | geom.primitives.append(triset) 54 | mesh.geometries.append(geom) 55 | 56 | geomnode = collada.scene.GeometryNode(geom, []) 57 | node = collada.scene.Node(mesh_name, children=[geomnode]) 58 | 59 | myscene = collada.scene.Scene("mcubes_scene", [node]) 60 | mesh.scenes.append(myscene) 61 | mesh.scene = myscene 62 | 63 | mesh.write(filename) 64 | -------------------------------------------------------------------------------- /reg_slices/src/vgg16bn_feats.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | 5 | class VGG16BNFeats(torch.nn.Module): 6 | # "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], 7 | # conv1_1: conv, bn, relu 3 8 | # conv1_2: conv, bn, relu 6 9 | # "M": maxpool 7 10 | # conv2_1: conv, bn, relu 10 11 | # conv2_2: conv, bn, relu 13 12 | # "M": maxpool 14 13 | # conv3_1: conv, bn, relu 17 14 | # conv3_2: conv, bn, relu 20 15 | # conv3_3: conv, bn, relu 23 16 | # "M": maxpool 24 17 | # conv4_1: conv, bn, relu 27 18 | # conv4_2: conv, bn, relu 30 19 | # conv4_3: conv, bn, relu 33 20 | # "M": maxpool 34 21 | # conv5_1: conv, bn, relu 37 22 | # conv5_2: conv, bn, relu 40 23 | # conv5_3: conv, bn, relu 43 24 | # "M": maxpool 44 25 | 26 | def __init__(self, requires_grad=True): 27 | super(VGG16BNFeats, self).__init__() 28 | vgg = torchvision.models.vgg16_bn(pretrained=True) 29 | 30 | vgg_features = vgg.features 31 | self.conv1_2 = vgg_features[:4] 32 | self.conv2_2 = vgg_features[4:11] 33 | self.conv3_3 = vgg_features[11:21] 34 | self.conv4_3 = vgg_features[21:31] 35 | self.conv5_3 = vgg_features[31:41] 36 | self.conv_last = vgg_features[41:44] 37 | # self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 38 | self.classifier = nn.Linear(512 * 4 * 4, 128) 39 | 40 | def forward(self, img): 41 | # conv_feats = [] 42 | conv1_2 = self.conv1_2(img) 43 | conv2_2 = self.conv2_2(conv1_2) 44 | conv3_3 = self.conv3_3(conv2_2) 45 | conv4_3 = self.conv4_3(conv3_3) 46 | conv5_3 = self.conv5_3(conv4_3) 47 | conv_last = self.conv_last(conv5_3) 48 | # conv_ret = self.slice4(conv3_3) 49 | # ret = self.slice6(conv5_2) 50 | feat_global = conv_last 51 | # feat_global = self.avgpool(conv_last) 52 | feat_global = torch.flatten(feat_global, 1) 53 | feat_global = 
self.classifier(feat_global) 54 | return [conv1_2, conv2_2, conv3_3, conv4_3, conv5_3], feat_global -------------------------------------------------------------------------------- /gen_slices/models/ldm/text2img256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 192 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 5 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 640 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 640 63 | n_layer: 32 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 28 68 | num_workers: 10 69 | wrap: false 70 | train: 71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /gen_slices/models/ldm/semantic_synthesis512/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 128 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 128 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.SpatialRescaler 57 | params: 58 | n_stages: 2 59 | in_channels: 182 60 | out_channels: 3 61 | data: 62 | target: 
main.DataModuleFromConfig 63 | params: 64 | batch_size: 8 65 | wrap: false 66 | num_workers: 10 67 | train: 68 | target: ldm.data.landscapes.RFWTrain 69 | params: 70 | size: 768 71 | crop_size: 512 72 | segmentation_to_float32: true 73 | validation: 74 | target: ldm.data.landscapes.RFWValidation 75 | params: 76 | size: 768 77 | crop_size: 512 78 | segmentation_to_float32: true 79 | -------------------------------------------------------------------------------- /gen_slices/models/ldm/cin256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 4 41 | n_embed: 16384 42 | ddconfig: 43 | double_z: false 44 | z_channels: 4 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 32 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.ClassEmbedder 62 | params: 63 | embed_dim: 512 64 | key: class_label 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 64 69 | num_workers: 12 70 | wrap: false 71 | train: 72 | target: ldm.data.imagenet.ImageNetTrain 73 | params: 74 | config: 75 | size: 256 76 | validation: 77 | target: ldm.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 256 81 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmcubes/README.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | PyMCubes 3 | ======== 4 | 5 | PyMCubes is an implementation of the marching cubes algorithm to extract 6 | isosurfaces from volumetric data. The volumetric data can be given as a 7 | three-dimensional NumPy array or as a Python function ``f(x, y, z)``. The first 8 | option is much faster, but it requires more memory and becomes unfeasible for 9 | very large volumes. 10 | 11 | PyMCubes also provides a function to export the results of the marching cubes as 12 | COLLADA ``(.dae)`` files. This requires the 13 | `PyCollada `_ library. 
14 | 15 | Installation 16 | ============ 17 | 18 | Just as any standard Python package, clone or download the project 19 | and run:: 20 | 21 | $ cd path/to/PyMCubes 22 | $ python setup.py build 23 | $ python setup.py install 24 | 25 | If you do not have write permission on the directory of Python packages, 26 | install with the ``--user`` option:: 27 | 28 | $ python setup.py install --user 29 | 30 | Example 31 | ======= 32 | 33 | The following example creates a data volume with spherical isosurfaces and 34 | extracts one of them (i.e., a sphere) with PyMCubes. The result is exported as 35 | ``sphere.dae``:: 36 | 37 | >>> import numpy as np 38 | >>> import mcubes 39 | 40 | # Create a data volume (30 x 30 x 30) 41 | >>> X, Y, Z = np.mgrid[:30, :30, :30] 42 | >>> u = (X-15)**2 + (Y-15)**2 + (Z-15)**2 - 8**2 43 | 44 | # Extract the 0-isosurface 45 | >>> vertices, triangles = mcubes.marching_cubes(u, 0) 46 | 47 | # Export the result to sphere.dae 48 | >>> mcubes.export_mesh(vertices, triangles, "sphere.dae", "MySphere") 49 | 50 | The second example is very similar to the first one, but it uses a function 51 | to represent the volume instead of a NumPy array:: 52 | 53 | >>> import numpy as np 54 | >>> import mcubes 55 | 56 | # Create the volume 57 | >>> f = lambda x, y, z: x**2 + y**2 + z**2 58 | 59 | # Extract the 16-isosurface 60 | >>> vertices, triangles = mcubes.marching_cubes_func((-10,-10,-10), (10,10,10), 61 | ... 100, 100, 100, f, 16) 62 | 63 | # Export the result to sphere2.dae 64 | >>> mcubes.export_mesh(vertices, triangles, "sphere2.dae", "MySphere") 65 | -------------------------------------------------------------------------------- /gen_slices/models/ldm/bsr_sr/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: image 11 | cond_stage_key: LR_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: false 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 160 23 | attention_resolutions: 24 | - 16 25 | - 8 26 | num_res_blocks: 2 27 | channel_mult: 28 | - 1 29 | - 2 30 | - 2 31 | - 4 32 | num_head_channels: 32 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: torch.nn.Identity 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 64 61 | wrap: false 62 | num_workers: 12 63 | train: 64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain 65 | params: 66 | size: 256 67 | degradation: bsrgan_light 68 | downscale_f: 4 69 | min_crop_f: 0.5 70 | max_crop_f: 1.0 71 | random_crop: true 72 | validation: 73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation 74 | params: 75 | size: 256 76 | degradation: bsrgan_light 77 | downscale_f: 4 78 | min_crop_f: 0.5 79 | max_crop_f: 1.0 80 | random_crop: true 
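# Reading of the super-resolution conditioning above, offered as a hedged sketch rather than a statement of the original authors' intent:
# with concat_mode: true and a torch.nn.Identity cond_stage, the LR_image conditioning appears to be concatenated channel-wise with the
# 3-channel first-stage latent before entering the UNet, which is why unet_config.params.in_channels is 6 while out_channels stays 3.
# Roughly, in PyTorch terms (hypothetical variable names):
#   x_in = torch.cat([z_noisy, lr_cond], dim=1)  # 3 latent channels + 3 LR-image channels = 6
# where z_noisy is the noised latent and lr_cond is the degraded low-resolution view (size 256 with downscale_f: 4, matching the 64x64 latent grid).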
81 | -------------------------------------------------------------------------------- /gen_slices/models/ldm/layout2img-openimages256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: coordinates_bbox 12 | image_size: 64 13 | channels: 3 14 | conditioning_key: crossattn 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 3 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 8 25 | - 4 26 | - 2 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 2 31 | - 3 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 3 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | monitor: val/rec_loss 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 512 63 | n_layer: 16 64 | vocab_size: 8192 65 | max_seq_len: 92 66 | use_tokenizer: false 67 | monitor: val/loss_simple_ema 68 | data: 69 | target: main.DataModuleFromConfig 70 | params: 71 | batch_size: 24 72 | wrap: false 73 | num_workers: 10 74 | train: 75 | target: ldm.data.openimages.OpenImagesBBoxTrain 76 | params: 77 | size: 256 78 | validation: 79 | target: ldm.data.openimages.OpenImagesBBoxValidation 80 | params: 81 | size: 256 82 | -------------------------------------------------------------------------------- /reg_slices/src/vgg16bn_feats_for_disn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | 5 | class VGG16BNFeats(torch.nn.Module): 6 | # "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], 7 | # conv1_1: conv, bn, relu 3 8 | # conv1_2: conv, bn, relu 6 9 | # "M": maxpool 7 10 | # conv2_1: conv, bn, relu 10 11 | # conv2_2: conv, bn, relu 13 12 | # "M": maxpool 14 13 | # conv3_1: conv, bn, relu 17 14 | # conv3_2: conv, bn, relu 20 15 | # conv3_3: conv, bn, relu 23 16 | # "M": maxpool 24 17 | # conv4_1: conv, bn, relu 27 18 | # conv4_2: conv, bn, relu 30 19 | # conv4_3: conv, bn, relu 33 20 | # "M": maxpool 34 21 | # conv5_1: conv, bn, relu 37 22 | # conv5_2: conv, bn, relu 40 23 | # conv5_3: conv, bn, relu 43 24 | # "M": maxpool 44 25 | 26 | def __init__(self, requires_grad=True): 27 | super(VGG16BNFeats, self).__init__() 28 | vgg = torchvision.models.vgg16_bn(pretrained=True) 29 | 30 | vgg_features = vgg.features 31 | self.conv1_2 = vgg_features[:4] 32 | self.conv2_2 = vgg_features[4:11] 33 | self.conv3_3 = vgg_features[11:21] 34 | self.conv4_3 = vgg_features[21:31] 35 | self.conv5_3 = vgg_features[31:41] 36 | self.conv_last = vgg_features[41:44] 37 | self.avgpool = nn.AdaptiveAvgPool2d((4, 4)) 38 | self.classifier = nn.Sequential( 39 | nn.Linear(512 * 4 * 4, 1024), 40 | 
nn.ReLU(True), 41 | nn.Dropout(p=0.5), 42 | nn.Linear(1024, 1024), 43 | nn.ReLU(True), 44 | nn.Dropout(p=0.5), 45 | nn.Linear(1024, 128), 46 | ) 47 | def forward(self, img): 48 | # conv_feats = [] 49 | conv1_2 = self.conv1_2(img) 50 | conv2_2 = self.conv2_2(conv1_2) 51 | conv3_3 = self.conv3_3(conv2_2) 52 | conv4_3 = self.conv4_3(conv3_3) 53 | conv5_3 = self.conv5_3(conv4_3) 54 | conv_last = self.conv_last(conv5_3) 55 | feat_global = conv_last 56 | feat_global = self.avgpool(conv_last) 57 | feat_global = torch.flatten(feat_global, 1) 58 | feat_global = self.classifier(feat_global) 59 | return [conv1_2, conv2_2, conv3_3, conv4_3, conv5_3], feat_global -------------------------------------------------------------------------------- /reg_slices/setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup 3 | except ImportError: 4 | from distutils.core import setup 5 | from distutils.extension import Extension 6 | from Cython.Build import cythonize 7 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 8 | import numpy 9 | 10 | 11 | # Get the numpy include directory. 12 | numpy_include_dir = numpy.get_include() 13 | 14 | # Extensions 15 | 16 | # mcubes (marching cubes algorithm) 17 | mcubes_module = Extension( 18 | 'src_convonet.utils.libmcubes.mcubes', 19 | sources=[ 20 | 'src_convonet/utils/libmcubes/mcubes.pyx', 21 | 'src_convonet/utils/libmcubes/pywrapper.cpp', 22 | 'src_convonet/utils/libmcubes/marchingcubes.cpp' 23 | ], 24 | language='c++', 25 | extra_compile_args=['-std=c++11'], 26 | include_dirs=[numpy_include_dir] 27 | ) 28 | 29 | # triangle hash (efficient mesh intersection) 30 | triangle_hash_module = Extension( 31 | 'src_convonet.utils.libmesh.triangle_hash', 32 | sources=[ 33 | 'src_convonet/utils/libmesh/triangle_hash.pyx' 34 | ], 35 | libraries=['m'], # Unix-like specific 36 | include_dirs=[numpy_include_dir] 37 | ) 38 | 39 | # mise (efficient mesh extraction) 40 | mise_module = Extension( 41 | 'src_convonet.utils.libmise.mise', 42 | sources=[ 43 | 'src_convonet/utils/libmise/mise.pyx' 44 | ], 45 | ) 46 | 47 | # simplify (efficient mesh simplification) 48 | simplify_mesh_module = Extension( 49 | 'src_convonet.utils.libsimplify.simplify_mesh', 50 | sources=[ 51 | 'src_convonet/utils/libsimplify/simplify_mesh.pyx' 52 | ], 53 | include_dirs=[numpy_include_dir] 54 | ) 55 | 56 | # voxelization (efficient mesh voxelization) 57 | voxelize_module = Extension( 58 | 'src_convonet.utils.libvoxelize.voxelize', 59 | sources=[ 60 | 'src_convonet/utils/libvoxelize/voxelize.pyx' 61 | ], 62 | libraries=['m'] # Unix-like specific 63 | ) 64 | 65 | # Gather all extension modules 66 | ext_modules = [ 67 | # pykdtree, 68 | mcubes_module, 69 | triangle_hash_module, 70 | mise_module, 71 | simplify_mesh_module, 72 | voxelize_module, 73 | ] 74 | 75 | setup( 76 | ext_modules=cythonize(ext_modules), 77 | cmdclass={ 78 | 'build_ext': BuildExtension 79 | } 80 | ) 81 | -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 
unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn't actually the resolution but 23 | # the downsampling factor, i.e. this corresponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial resolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn't actually the resolution but 23 | # the downsampling factor, i.e. 
this corresponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial resolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn't actually the resolution but 24 | # the downsampling factor, i.e. 
this corresponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial resolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /gen_slices/models/ldm/lsun_churches256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: false 16 | concat_mode: false 17 | scale_by_std: true 18 | monitor: val/loss_simple_ema 19 | scheduler_config: 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: 23 | - 10000 24 | cycle_lengths: 25 | - 10000000000000 26 | f_start: 27 | - 1.0e-06 28 | f_max: 29 | - 1.0 30 | f_min: 31 | - 1.0 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 192 39 | attention_resolutions: 40 | - 1 41 | - 2 42 | - 4 43 | - 8 44 | num_res_blocks: 2 45 | channel_mult: 46 | - 1 47 | - 2 48 | - 2 49 | - 4 50 | - 4 51 | num_heads: 8 52 | use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: '__is_unconditional__' 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 96 83 | num_workers: 5 84 | wrap: false 85 | train: 86 | target: ldm.data.lsun.LSUNChurchesTrain 87 | params: 88 | size: 256 89 | validation: 90 | target: ldm.data.lsun.LSUNChurchesValidation 91 | params: 92 | size: 256 93 | -------------------------------------------------------------------------------- /render_slices/gen_slices.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from joblib import Parallel, delayed 3 | from csv import reader 4 | import json 5 | 6 | path_root_objaverse = '/data/wangyz/03_datasets/' 7 | path_output = '../data/objaverse/01_img_slices' 8 | path_view = '../data/objaverse/00_img_input' 9 | path_blender = '/data/wangyz/04_blender_renderer/blender-3.6.0-linux-x64' 10 | 11 | def is_file_bigger_than_100MB(file_path): 12 | # Get the size of the file in bytes 13 | file_size = os.path.getsize(file_path) 14 | # Convert bytes to megabytes 15 | file_size_MB = file_size / (1024 * 1024) 16 | return file_size_MB > 100 17 | 18 | def gen_slices(sid, spath): 19 | # check if the mesh file is bigger than 100MB, large files tend to make Blender crash when slicing 20 | object_path = f'{path_root_objaverse}/{spath}' 21 | object_uid = os.path.basename(object_path).split(".")[0] 22 | if is_file_bigger_than_100MB(object_path): return 23 | if os.path.exists(f'{path_output}/{object_uid}/011/Z_4.png'): return 24 | 25 | try: 26 | cmd = f'{path_blender}/blender -b -P blender_script_slices.py -- \ 27 | --object_path {object_path} \ 28 | --output_dir ./{path_output} \ 29 | --engine CYCLES \ 30 | --view_path {path_view} \ 31 | --slice_direction camera \ 32 | --num_images 12 ' 33 | 34 | os.system(cmd) 35 | except: 36 | f = open(f'./logs/failed/{sid}.txt', 'w') 37 | f.close() 38 | return 39 | 40 | def get_shape_ids_and_paths(): 41 | 42 | shape_ids = [] 43 | sid2spath = {} 44 | shape_paths = [] 45 | # Open the JSON file 46 | 47 | with open('../data/objaverse/input_models_path-lvis.json', 'r') as file: 48 | data = json.load(file) 49 | for index, item in enumerate(data): 50 | if os.path.exists(f'{path_root_objaverse}/{item}'): 51 | shape_id = item.split('/')[-1].split('.')[0] 52 | shape_ids.append(shape_id) 53 | shape_paths.append(item) 54 | sid2spath[shape_id] = item 55 | 56 | return shape_ids, sid2spath 57 | 58 | def main(): 59 | shape_ids, sid2spath = get_shape_ids_and_paths() 60 | shape_ids = ['0a0c7e40a66d4fd090f549599f2f2c9d'] # this is an example, delete this line when creating a dataset 61 | with Parallel(n_jobs=8) as p: 62 | p(delayed(gen_slices)(sid=sid, spath=sid2spath[sid]) for idx, sid in enumerate(shape_ids)) 63 | 64 | main() 65 | -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/objaverse-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: img_ipt_view 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: true 16 | conditioning_key: concat 17 | # concat_mode: true 18 | scale_by_std: True 19 | monitor: 'val/loss_simple_ema' 20 | 21 | 22 | unet_config: 23 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 24 | params: 25 | image_size: 64 26 | in_channels: 8 27 | out_channels: 4 28 | model_channels: 192 29 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 30 | num_res_blocks: 2 31 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 32 | num_heads: 8 33 | use_scale_shift_norm: True 34 | resblock_updown: True 35 | 36 | first_stage_config: 37 | target: ldm.models.autoencoder.AutoencoderKL 
38 | params: 39 | embed_dim: 4 40 | monitor: "val/rec_loss" 41 | ckpt_path: "logs/autoencoder_kl_f8/checkpoints/model.ckpt" 42 | ddconfig: 43 | double_z: True 44 | z_channels: 4 45 | resolution: 512 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: [1,2,4,4] # num_down = len(ch_mult)-1 50 | num_res_blocks: 2 51 | attn_resolutions: [ ] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | 56 | cond_stage_config: 57 | target: ldm.modules.encoders.modules.ImageEncoderVGG16BN 58 | 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 8 63 | num_workers: 5 64 | wrap: False 65 | train: 66 | target: ldm.data.objaverse.ObjaverseTrain 67 | params: 68 | size: 128 69 | validation: 70 | target: ldm.data.objaverse.ObjaverseValidation 71 | params: 72 | size: 128 73 | test: 74 | target: ldm.data.objaverse.ObjaverseTest 75 | params: 76 | size: 128 77 | lightning: 78 | callbacks: 79 | image_logger: 80 | target: main.ImageLogger 81 | params: 82 | batch_frequency: 2000 83 | max_images: 8 84 | increase_log_steps: False 85 | 86 | 87 | trainer: 88 | benchmark: True -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
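# Hedged sketch of how the scheduler values above behave, assuming the LambdaLinearScheduler implemented in ldm/lr_scheduler.py:
# it returns a multiplier on base_learning_rate, ramping linearly from f_start to f_max over warm_up_steps and then interpolating
# from f_max toward f_min over cycle_lengths. With these settings that is approximately:
#   lr(step) = 5.0e-5 * (1.e-6 + (1.0 - 1.e-6) * step / 10000)  # while step < 10000
#   lr(step) = 5.0e-5 * 1.0                                     # afterwards, since f_max == f_min == 1.0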
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/objaverse-ldm-kl-8-infer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | ckpt_path: "logs/2024-04-23T02-11-33_objaverse-ldm-kl-8/checkpoints/epoch=000103.ckpt" 6 | linear_start: 0.0015 7 | linear_end: 0.0155 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | loss_type: l1 12 | first_stage_key: image 13 | cond_stage_key: img_ipt_view 14 | image_size: 64 15 | channels: 4 16 | cond_stage_trainable: true 17 | conditioning_key: concat 18 | # concat_mode: true 19 | scale_by_std: True 20 | monitor: 'val/loss_simple_ema' 21 | 22 | 23 | unet_config: 24 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 25 | params: 26 | image_size: 64 27 | in_channels: 8 28 | out_channels: 4 29 | model_channels: 192 30 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 31 | num_res_blocks: 2 32 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 33 | num_heads: 8 34 | use_scale_shift_norm: True 35 | resblock_updown: True 36 | 37 | first_stage_config: 38 | target: ldm.models.autoencoder.AutoencoderKL 39 | params: 40 | embed_dim: 4 41 | monitor: "val/rec_loss" 42 | ckpt_path: "logs/autoencoder_kl_f8/checkpoints/model.ckpt" 43 | ddconfig: 44 | double_z: True 45 | z_channels: 4 46 | resolution: 512 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: [1,2,4,4] # num_down = len(ch_mult)-1 51 | num_res_blocks: 2 52 | attn_resolutions: [ ] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | 57 | cond_stage_config: 58 | target: ldm.modules.encoders.modules.ImageEncoderVGG16BN 59 | 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 8 64 | num_workers: 5 65 | wrap: False 66 | train: 67 | target: ldm.data.objaverse.ObjaverseTrain 68 | params: 69 | size: 128 70 | validation: 
71 | target: ldm.data.objaverse.ObjaverseValidation 72 | params: 73 | size: 128 74 | test: 75 | target: ldm.data.objaverse.ObjaverseTest 76 | params: 77 | size: 128 78 | lightning: 79 | callbacks: 80 | image_logger: 81 | target: main.ImageLogger 82 | params: 83 | batch_frequency: 2000 84 | max_images: 8 85 | increase_log_steps: False 86 | 87 | 88 | trainer: 89 | benchmark: True -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libvoxelize/voxelize.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from libc.math cimport floor, ceil 3 | from cython.view cimport array as cvarray 4 | 5 | cdef extern from "tribox2.h": 6 | int triBoxOverlap(float boxcenter[3], float boxhalfsize[3], 7 | float tri0[3], float tri1[3], float tri2[3]) 8 | 9 | 10 | @cython.boundscheck(False) # Deactivate bounds checking 11 | @cython.wraparound(False) # Deactivate negative indexing. 12 | cpdef int voxelize_mesh_(bint[:, :, :] occ, float[:, :, ::1] faces): 13 | assert(faces.shape[1] == 3) 14 | assert(faces.shape[2] == 3) 15 | 16 | n_faces = faces.shape[0] 17 | cdef int i 18 | for i in range(n_faces): 19 | voxelize_triangle_(occ, faces[i]) 20 | 21 | 22 | @cython.boundscheck(False) # Deactivate bounds checking 23 | @cython.wraparound(False) # Deactivate negative indexing. 24 | cpdef int voxelize_triangle_(bint[:, :, :] occupancies, float[:, ::1] triverts): 25 | cdef int bbox_min[3] 26 | cdef int bbox_max[3] 27 | cdef int i, j, k 28 | cdef float boxhalfsize[3] 29 | cdef float boxcenter[3] 30 | cdef bint intersection 31 | 32 | boxhalfsize[:] = (0.5, 0.5, 0.5) 33 | 34 | for i in range(3): 35 | bbox_min[i] = ( 36 | min(triverts[0, i], triverts[1, i], triverts[2, i]) 37 | ) 38 | bbox_min[i] = min(max(bbox_min[i], 0), occupancies.shape[i] - 1) 39 | 40 | for i in range(3): 41 | bbox_max[i] = ( 42 | max(triverts[0, i], triverts[1, i], triverts[2, i]) 43 | ) 44 | bbox_max[i] = min(max(bbox_max[i], 0), occupancies.shape[i] - 1) 45 | 46 | for i in range(bbox_min[0], bbox_max[0] + 1): 47 | for j in range(bbox_min[1], bbox_max[1] + 1): 48 | for k in range(bbox_min[2], bbox_max[2] + 1): 49 | boxcenter[:] = (i + 0.5, j + 0.5, k + 0.5) 50 | intersection = triBoxOverlap(&boxcenter[0], &boxhalfsize[0], 51 | &triverts[0, 0], &triverts[1, 0], &triverts[2, 0]) 52 | occupancies[i, j, k] |= intersection 53 | 54 | 55 | @cython.boundscheck(False) # Deactivate bounds checking 56 | @cython.wraparound(False) # Deactivate negative indexing. 
57 | cdef int test_triangle_aabb(float[::1] boxcenter, float[::1] boxhalfsize, float[:, ::1] triverts): 58 | assert(boxcenter.shape[0] == 3) 59 | assert(boxhalfsize.shape[0] == 3) 60 | assert(triverts.shape[0] == triverts.shape[1] == 3) 61 | 62 | # print(triverts) 63 | # Call functions 64 | cdef int result = triBoxOverlap(&boxcenter[0], &boxhalfsize[0], 65 | &triverts[0, 0], &triverts[1, 0], &triverts[2, 0]) 66 | return result 67 | -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/custom-sin-img-ldm-kl-8-infer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | ckpt_path: "logs/2024-04-23T02-11-33_objaverse-ldm-kl-8/checkpoints/epoch=000103.ckpt" 6 | linear_start: 0.0015 7 | linear_end: 0.0155 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | loss_type: l1 12 | first_stage_key: image 13 | cond_stage_key: img_ipt_view 14 | image_size: 64 15 | channels: 4 16 | cond_stage_trainable: true 17 | conditioning_key: concat 18 | # concat_mode: true 19 | scale_by_std: True 20 | monitor: 'val/loss_simple_ema' 21 | 22 | 23 | unet_config: 24 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 25 | params: 26 | image_size: 64 27 | in_channels: 8 28 | out_channels: 4 29 | model_channels: 192 30 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 31 | num_res_blocks: 2 32 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 33 | num_heads: 8 34 | use_scale_shift_norm: True 35 | resblock_updown: True 36 | 37 | first_stage_config: 38 | target: ldm.models.autoencoder.AutoencoderKL 39 | params: 40 | embed_dim: 4 41 | monitor: "val/rec_loss" 42 | ckpt_path: "logs/autoencoder_kl_f8/checkpoints/model.ckpt" 43 | ddconfig: 44 | double_z: True 45 | z_channels: 4 46 | resolution: 512 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: [1,2,4,4] # num_down = len(ch_mult)-1 51 | num_res_blocks: 2 52 | attn_resolutions: [ ] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | 57 | cond_stage_config: 58 | target: ldm.modules.encoders.modules.ImageEncoderVGG16BN 59 | 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 1 64 | num_workers: 4 65 | wrap: False 66 | train: 67 | target: ldm.data.custom_sin_img.CustomSinImgTrain 68 | params: 69 | size: 128 70 | validation: 71 | target: ldm.data.custom_sin_img.CustomSinImgValidation 72 | params: 73 | size: 128 74 | test: 75 | target: ldm.data.custom_sin_img.CustomSinImgTest 76 | params: 77 | size: 128 78 | lightning: 79 | callbacks: 80 | image_logger: 81 | target: main.ImageLogger 82 | params: 83 | batch_frequency: 2000 84 | max_images: 8 85 | increase_log_steps: False 86 | 87 | 88 | trainer: 89 | benchmark: True -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import src.common as common 5 | 6 | 7 | def visualize_data(data, data_type, out_file): 8 | r''' Visualizes the data with regard to its type. 
9 | 10 | Args: 11 | data (tensor): batch of data 12 | data_type (string): data type (img, voxels or pointcloud) 13 | out_file (string): output file 14 | ''' 15 | if data_type == 'voxels': 16 | visualize_voxels(data, out_file=out_file) 17 | elif data_type == 'pointcloud': 18 | visualize_pointcloud(data, out_file=out_file) 19 | elif data_type is None or data_type == 'idx': 20 | pass 21 | else: 22 | raise ValueError('Invalid data_type "%s"' % data_type) 23 | 24 | 25 | def visualize_voxels(voxels, out_file=None, show=False): 26 | r''' Visualizes voxel data. 27 | 28 | Args: 29 | voxels (tensor): voxel data 30 | out_file (string): output file 31 | show (bool): whether the plot should be shown 32 | ''' 33 | # Use numpy 34 | voxels = np.asarray(voxels) 35 | # Create plot 36 | fig = plt.figure() 37 | ax = fig.gca(projection=Axes3D.name) 38 | voxels = voxels.transpose(2, 0, 1) 39 | ax.voxels(voxels, edgecolor='k') 40 | ax.set_xlabel('Z') 41 | ax.set_ylabel('X') 42 | ax.set_zlabel('Y') 43 | ax.view_init(elev=30, azim=45) 44 | if out_file is not None: 45 | plt.savefig(out_file) 46 | if show: 47 | plt.show() 48 | plt.close(fig) 49 | 50 | 51 | def visualize_pointcloud(points, normals=None, 52 | out_file=None, show=False): 53 | r''' Visualizes point cloud data. 54 | 55 | Args: 56 | points (tensor): point data 57 | normals (tensor): normal data (if existing) 58 | out_file (string): output file 59 | show (bool): whether the plot should be shown 60 | ''' 61 | # Use numpy 62 | points = np.asarray(points) 63 | # Create plot 64 | fig = plt.figure() 65 | ax = fig.gca(projection=Axes3D.name) 66 | ax.scatter(points[:, 2], points[:, 0], points[:, 1]) 67 | if normals is not None: 68 | ax.quiver( 69 | points[:, 2], points[:, 0], points[:, 1], 70 | normals[:, 2], normals[:, 0], normals[:, 1], 71 | length=0.1, color='k' 72 | ) 73 | ax.set_xlabel('Z') 74 | ax.set_ylabel('X') 75 | ax.set_zlabel('Y') 76 | ax.set_xlim(-0.5, 0.5) 77 | ax.set_ylim(-0.5, 0.5) 78 | ax.set_zlim(-0.5, 0.5) 79 | ax.view_init(elev=30, azim=45) 80 | if out_file is not None: 81 | plt.savefig(out_file) 82 | if show: 83 | plt.show() 84 | plt.close(fig) 85 | 86 | -------------------------------------------------------------------------------- /gen_slices/re_org_slices.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from PIL import Image 4 | import numpy as np 5 | import argparse 6 | 7 | def get_parser(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--dir_slices", type=str, default='logs/2024-04-23T02-11-33_objaverse-ldm-kl-8/images_testing_sampled') 10 | parser.add_argument("--type_slices", type=str, default='gen', choices=['gen', 'rec']) 11 | parser.add_argument("--name_dataset", type=str, default='objaverse') 12 | parser.add_argument("--img_size", type=int, default=128) 13 | parser.add_argument("--n_bs", type=int, default=8) 14 | parser.add_argument("--n_views", type=int, default=12) 15 | args = parser.parse_args() 16 | return args 17 | 18 | def crop_slices(args): 19 | dir_slices = args.dir_slices 20 | 21 | if args.type_slices == 'gen': 22 | dir_tgt = f'../data/{args.name_dataset}/04_img_slices_gen' 23 | shape_uids = open(f'../data/{args.name_dataset}/03_splits/test.lst', 'r').read().split('\n') 24 | else: 25 | dir_tgt = f'../data/{args.name_dataset}/05_img_slices_rec' 26 | shape_uids_ = open(f'../data/{args.name_dataset}/03_splits/trainval.lst', 'r').read().split('\n') 27 | shape_uids_num = len(shape_uids_) 28 | shape_uids = shape_uids_ * args.n_views 29 | 30 | 
axis_list = ['X', 'Z', 'Y'] 31 | part_list = ['1', '2', '3', '4'] 32 | part_list_ = ['4', '3', '2', '1'] 33 | 34 | img_size = args.img_size 35 | n_bs = args.n_bs 36 | 37 | for idx, shape_uid in enumerate(shape_uids): 38 | if idx % 1000 == 0: print(idx) 39 | batch_id = idx // n_bs 40 | case_id = idx % n_bs 41 | if args.type_slices == 'gen': 42 | view_id = '004' 43 | else: 44 | view_id = "%03d"%(idx // shape_uids_num) 45 | 46 | if not os.path.exists(f'{dir_slices}/{batch_id}_{case_id}.png'): continue 47 | img = Image.open(f'{dir_slices}/{batch_id}_{case_id}.png') 48 | os.makedirs(f'{dir_tgt}/{shape_uid}/{view_id}', exist_ok=True) 49 | for idx_i in range(3): 50 | for idx_j in range(4): 51 | axis = axis_list[idx_i] 52 | if idx_i != 1: 53 | part_name = part_list[idx_j] 54 | else: 55 | part_name = part_list_[idx_j] 56 | dir_save = f'{dir_tgt}/{shape_uid}/{view_id}/{axis}_{part_name}.png' 57 | if args.type_slices == 'rec' and os.path.exists(dir_save): continue 58 | crop_area = (idx_j * img_size, idx_i * img_size, (idx_j + 1) * img_size, (idx_i + 1) * img_size) 59 | img_slice = img.crop(crop_area) 60 | img_slice.save(dir_save) 61 | 62 | if __name__ == "__main__": 63 | args = get_parser() 64 | crop_slices(args) -------------------------------------------------------------------------------- /gen_slices/configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn't actually the resolution but 26 | # the downsampling factor, i.e. 
this corresponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial resolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /reg_slices/src/unet_custom.py: -------------------------------------------------------------------------------- 1 | from .unet_parts import * 2 | import torchvision 3 | 4 | class UNet(nn.Module): 5 | def __init__(self, n_channels=3): 6 | super(UNet, self).__init__() 7 | self.n_channels = n_channels 8 | bilinear = False 9 | self.n_slices = 12 10 | self.dim_embed = 128 11 | 12 | vgg = torchvision.models.vgg16_bn(pretrained=True) 13 | 14 | vgg_features = vgg.features 15 | self.down1 = vgg_features[:4] 16 | self.down2 = vgg_features[4:11] 17 | self.down3 = vgg_features[11:21] 18 | self.down4 = vgg_features[21:31] 19 | self.down5 = vgg_features[31:41] 20 | self.down5_ = vgg_features[41:44] 21 | 22 | self.trans_c = nn.Conv2d(512 + self.dim_embed, 512, 1) 23 | self.up1 = (Up(512, 256, bilinear)) 24 | self.trans_up1 = nn.Conv2d(512, 256, 1) 25 | self.up2 = (Up(256, 128, bilinear)) 26 | self.trans_up2 = nn.Conv2d(256, 128, 1) 27 | self.up3 = (Up(128, 64, bilinear)) 28 | self.trans_up3 = nn.Conv2d(128, 64, 1) 29 | self.up4 = (Up(64, 32, bilinear)) 30 | self.trans_up4 = nn.Conv2d(64, 32, 1) 31 | self.outc = (OutConv(32, 3)) 32 | self.emds = torch.nn.Embedding(self.n_slices, self.dim_embed) 33 | 34 | 35 | def expand_bs(self, x): 36 | n_bs, n_c, n_w, n_h = x.shape 37 | x_tile = x.view(n_bs, 1, n_c, n_w, n_h).expand(-1, self.n_slices, -1, -1, -1).reshape(n_bs * self.n_slices, n_c, n_w, n_h) 38 | return x_tile 39 | 40 | def forward(self, x): 41 | feats = [] 42 | 43 | x1 = self.down1(x) # 64, img_size, img_size 44 | x2 = self.down2(x1) # 128, img_size // 2, img_size // 2 45 | x3 = self.down3(x2) # 256, img_size // 4, img_size // 4 46 | x4 = self.down4(x3) # 512, img_size // 8, img_size // 8 47 | x5 = self.down5(x4) # 512, img_size // 16, img_size // 16 48 | x5_ = self.down5_(x5) # 512, img_size // 32, img_size // 32 49 | 50 | n_bs, n_c, n_w, n_h = x5.shape 51 | 52 | embs_tile = self.emds.weight.view(1, self.n_slices, self.dim_embed, 1, 1).expand(n_bs, 
self.n_slices, self.dim_embed, n_w, n_h).reshape(n_bs * self.n_slices, self.dim_embed, n_w, n_h) 53 | 54 | x5_tile = self.expand_bs(x5) 55 | 56 | latent = torch.cat([x5_tile, embs_tile], 1) 57 | latent = self.trans_c(latent) 58 | feats.append(latent) 59 | 60 | x = self.up1(latent, self.trans_up1(self.expand_bs(x4))) 61 | feats.append(x) 62 | x = self.up2(x, self.trans_up2(self.expand_bs(x3))) 63 | feats.append(x) 64 | x = self.up3(x, self.trans_up3(self.expand_bs(x2))) 65 | feats.append(x) 66 | x = self.up4(x, self.trans_up4(self.expand_bs(x1))) 67 | feats.append(x) 68 | out = self.outc(x) 69 | return feats, out 70 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libsimplify/simplify_mesh.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | from libcpp.vector cimport vector 3 | import numpy as np 4 | cimport numpy as np 5 | 6 | 7 | cdef extern from "Simplify.h": 8 | cdef struct vec3f: 9 | double x, y, z 10 | 11 | cdef cppclass SymetricMatrix: 12 | SymetricMatrix() except + 13 | 14 | 15 | cdef extern from "Simplify.h" namespace "Simplify": 16 | cdef struct Triangle: 17 | int v[3] 18 | double err[4] 19 | int deleted, dirty, attr 20 | vec3f uvs[3] 21 | int material 22 | 23 | cdef struct Vertex: 24 | vec3f p 25 | int tstart, tcount 26 | SymetricMatrix q 27 | int border 28 | 29 | cdef vector[Triangle] triangles 30 | cdef vector[Vertex] vertices 31 | cdef void simplify_mesh(int, double) 32 | 33 | 34 | cpdef mesh_simplify(double[:, ::1] vertices_in, long[:, ::1] triangles_in, 35 | int f_target, double agressiveness=7.) except +: 36 | vertices.clear() 37 | triangles.clear() 38 | 39 | # Read in vertices and triangles 40 | cdef Vertex v 41 | for iv in range(vertices_in.shape[0]): 42 | v = Vertex() 43 | v.p.x = vertices_in[iv, 0] 44 | v.p.y = vertices_in[iv, 1] 45 | v.p.z = vertices_in[iv, 2] 46 | vertices.push_back(v) 47 | 48 | cdef Triangle t 49 | for it in range(triangles_in.shape[0]): 50 | t = Triangle() 51 | t.v[0] = triangles_in[it, 0] 52 | t.v[1] = triangles_in[it, 1] 53 | t.v[2] = triangles_in[it, 2] 54 | triangles.push_back(t) 55 | 56 | # Simplify 57 | # print('Simplify...') 58 | simplify_mesh(f_target, agressiveness) 59 | 60 | # Only use triangles that are not deleted 61 | cdef vector[Triangle] triangles_notdel 62 | triangles_notdel.reserve(triangles.size()) 63 | 64 | for t in triangles: 65 | if not t.deleted: 66 | triangles_notdel.push_back(t) 67 | 68 | # Read out triangles 69 | vertices_out = np.empty((vertices.size(), 3), dtype=np.float64) 70 | triangles_out = np.empty((triangles_notdel.size(), 3), dtype=np.int64) 71 | 72 | cdef double[:, :] vertices_out_view = vertices_out 73 | cdef long[:, :] triangles_out_view = triangles_out 74 | 75 | for iv in range(vertices.size()): 76 | vertices_out_view[iv, 0] = vertices[iv].p.x 77 | vertices_out_view[iv, 1] = vertices[iv].p.y 78 | vertices_out_view[iv, 2] = vertices[iv].p.z 79 | 80 | for it in range(triangles_notdel.size()): 81 | triangles_out_view[it, 0] = triangles_notdel[it].v[0] 82 | triangles_out_view[it, 1] = triangles_notdel[it].v[1] 83 | triangles_out_view[it, 2] = triangles_notdel[it].v[2] 84 | 85 | # Clear vertices and triangles 86 | vertices.clear() 87 | triangles.clear() 88 | 89 | return vertices_out, triangles_out -------------------------------------------------------------------------------- /reg_slices/src/unet_parts.py: 
-------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | def forward(self, x1, x2): 57 | # print('x1_0', x1.shape) 58 | x1 = self.up(x1) 59 | # print('x1', x1.shape) 60 | # print('x2', x2.shape) 61 | # input is CHW 62 | diffY = x2.size()[2] - x1.size()[2] 63 | diffX = x2.size()[3] - x1.size()[3] 64 | 65 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 66 | diffY // 2, diffY - diffY // 2]) 67 | # if you have padding issues, see 68 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 69 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 70 | # print(x1.shape) 71 | # print(x2.shape) 72 | # input() 73 | x = torch.cat([x2, x1], dim=1) 74 | 75 | return self.conv(x) 76 | 77 | 78 | class OutConv(nn.Module): 79 | def __init__(self, in_channels, out_channels): 80 | super(OutConv, self).__init__() 81 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 82 | self.act = nn.Tanh() 83 | def forward(self, x): 84 | return self.act(self.conv(x)) -------------------------------------------------------------------------------- /reg_slices/src/utils_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | from scipy.spatial import cKDTree, distance 5 | from src_convonet.utils.libmesh import check_mesh_contains 6 | 7 | def compute_iou(occ1, occ2): 8 | ''' Computes the Intersection over Union (IoU) value for two sets of 9 | occupancy values. 
10 | 11 | Args: 12 | occ1 (tensor): first set of occupancy values 13 | occ2 (tensor): second set of occupancy values 14 | ''' 15 | occ1 = np.asarray(occ1) 16 | occ2 = np.asarray(occ2) 17 | 18 | # Put all data in second dimension 19 | # Also works for 1-dimensional data 20 | if occ1.ndim >= 2: 21 | occ1 = occ1.reshape(occ1.shape[0], -1) 22 | if occ2.ndim >= 2: 23 | occ2 = occ2.reshape(occ2.shape[0], -1) 24 | 25 | # Convert to boolean values 26 | occ1 = (occ1 >= 0.5) 27 | occ2 = (occ2 >= 0.5) 28 | 29 | # Compute IOU 30 | area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1) 31 | area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1) 32 | 33 | iou = (area_intersect / area_union) 34 | 35 | return iou 36 | 37 | def eval_iou(mesh, qry, occ_tgt): 38 | 39 | if len(mesh.vertices) != 0 and len(mesh.faces) != 0: 40 | occ = check_mesh_contains(mesh, qry) 41 | iou = compute_iou(occ, occ_tgt) 42 | else: 43 | iou = 0.0 44 | return iou 45 | 46 | 47 | def points_dist(p1, p2, k=1, return_ind=False): 48 | '''distance from p1 to p2''' 49 | tree = cKDTree(p2) 50 | dist, ind = tree.query(p1, k=k) 51 | if return_ind: 52 | return dist, ind 53 | else: 54 | return dist 55 | 56 | def chamfer_dist(p1, p2): 57 | d1 = points_dist(p1, p2) ** 2 58 | d2 = points_dist(p2, p1) ** 2 59 | return d1, d2 60 | 61 | def np2th(array, device='cuda'): 62 | tensor = array 63 | if type(array) is not torch.Tensor: 64 | tensor = torch.tensor(array).float() 65 | if type(tensor) is torch.Tensor: 66 | if device=='cuda': 67 | return tensor.cuda() 68 | return tensor.cpu() 69 | else: 70 | return array 71 | 72 | def eval_chamfer(p1, p2, f_thresh=0.01): 73 | """ p1: reconstructed points 74 | p2: reference points 75 | shapes: (N, 3) 76 | """ 77 | d1, d2 = chamfer_dist(p1, p2) 78 | 79 | d1sqrt, d2sqrt = (d1**.5), (d2**.5) 80 | chamfer_L1 = 0.5 * (d1sqrt.mean(axis=-1) + d2sqrt.mean(axis=-1)) 81 | chamfer_L2 = 0.5 * (d1.mean(axis=-1) + d2.mean(axis=-1)) 82 | 83 | precision = (d1sqrt < f_thresh).sum(axis=-1) / p1.shape[0] 84 | recall = (d2sqrt < f_thresh).sum(axis=-1) / p2.shape[0] 85 | fscore = 2 * precision * recall / (precision + recall) # harmonic mean of precision and recall (F-score) 86 | 87 | return [chamfer_L1, chamfer_L2, fscore, precision, recall] 88 | 89 | def eval_hausdoff(p1, p2): 90 | """ p1: reconstructed points 91 | p2: reference points 92 | shapes: (N, 3) 93 | """ 94 | dist_rec2ref, _, _ = distance.directed_hausdorff(p1, p2) 95 | dist_ref2rec, _, _ = distance.directed_hausdorff(p2, p1) 96 | dist = max(dist_rec2ref, dist_ref2rec) 97 | return dist_rec2ref, dist_ref2rec, dist 98 | 99 | -------------------------------------------------------------------------------- /reg_slices/src/vgg_perceptual_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | 6 | class VGG19Feats(torch.nn.Module): 7 | def __init__(self, requires_grad=False): 8 | super(VGG19Feats, self).__init__() 9 | vgg = torchvision.models.vgg19(pretrained=True).to(device) #.cuda() 10 | # vgg.eval() 11 | vgg_pretrained_features = vgg.features.eval() 12 | self.requires_grad = requires_grad 13 | self.slice1 = torch.nn.Sequential() 14 | self.slice2 = torch.nn.Sequential() 15 | self.slice3 = torch.nn.Sequential() 16 | self.slice4 = torch.nn.Sequential() 17 | self.slice5 = torch.nn.Sequential() 18 | for x in range(3): 19 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(3, 8): 21 | self.slice2.add_module(str(x),
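# (editor's note) with torchvision's vgg19().features indexing, slice1 ends at
# conv1_2 (idx 2), slice2 at conv2_2 (idx 7), slice3 at conv3_2 (idx 12),
# slice4 at conv4_2 (idx 21) and slice5 at conv5_2 (idx 30), which is why the
# outputs in forward() are named conv1_2 ... conv5_2 (taken before the ReLU).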
vgg_pretrained_features[x]) 22 | for x in range(8, 13): 23 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(13, 22): 25 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 26 | for x in range(22, 31): 27 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 28 | if not self.requires_grad: 29 | for param in self.parameters(): 30 | param.requires_grad = False 31 | 32 | def forward(self, img): 33 | conv1_2 = self.slice1(img) 34 | conv2_2 = self.slice2(conv1_2) 35 | conv3_2 = self.slice3(conv2_2) 36 | conv4_2 = self.slice4(conv3_2) 37 | conv5_2 = self.slice5(conv4_2) 38 | out = [conv1_2, conv2_2, conv3_2, conv4_2, conv5_2] 39 | return out 40 | 41 | 42 | class VGGPerceptualLoss(torch.nn.Module): 43 | def __init__(self): 44 | super(VGGPerceptualLoss, self).__init__() 45 | self.vgg = VGG19Feats().to(device) 46 | self.criterion = torch.nn.functional.l1_loss 47 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 48 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 49 | self.weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 1.0*10/1.5] 50 | 51 | def forward(self, input_img, target_img): 52 | input_img = (input_img + 1) / 2. 53 | target_img = (target_img + 1) / 2. 54 | 55 | if input_img.shape[1] != 3: 56 | input_img = input_img.repeat(1, 3, 1, 1) 57 | target_img = target_img.repeat(1, 3, 1, 1) 58 | input_img = (input_img - self.mean) / self.std 59 | target_img = (target_img - self.mean) / self.std 60 | 61 | x_vgg, y_vgg = self.vgg(input_img), self.vgg(target_img) 62 | 63 | loss = {} 64 | loss['pt_c_loss'] = self.weights[0] * self.criterion(x_vgg[0], y_vgg[0])+\ 65 | self.weights[1] * self.criterion(x_vgg[1], y_vgg[1])+\ 66 | self.weights[2] * self.criterion(x_vgg[2], y_vgg[2])+\ 67 | self.weights[3] * self.criterion(x_vgg[3], y_vgg[3])+\ 68 | self.weights[4] * self.criterion(x_vgg[4], y_vgg[4]) 69 | loss['pt_s_loss'] = 0.0 70 | 71 | return loss -------------------------------------------------------------------------------- /gen_slices/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | 
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /gen_slices/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | 
Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmesh/triangle_hash.pyx: -------------------------------------------------------------------------------- 1 | 2 | # distutils: language=c++ 3 | import numpy as np 4 | cimport numpy as np 5 | cimport cython 6 | from libcpp.vector cimport vector 7 | from libc.math cimport floor, ceil 8 | 9 | cdef class TriangleHash: 10 | cdef vector[vector[int]] spatial_hash 11 | cdef int resolution 12 | 13 | def __cinit__(self, double[:, :, :] triangles, int resolution): 14 | self.spatial_hash.resize(resolution * resolution) 15 | self.resolution = resolution 16 | self._build_hash(triangles) 17 | 18 | @cython.boundscheck(False) # Deactivate bounds checking 19 | @cython.wraparound(False) # Deactivate negative indexing. 20 | cdef int _build_hash(self, double[:, :, :] triangles): 21 | assert(triangles.shape[1] == 3) 22 | assert(triangles.shape[2] == 2) 23 | 24 | cdef int n_tri = triangles.shape[0] 25 | cdef int bbox_min[2] 26 | cdef int bbox_max[2] 27 | 28 | cdef int i_tri, j, x, y 29 | cdef int spatial_idx 30 | 31 | for i_tri in range(n_tri): 32 | # Compute bounding box 33 | for j in range(2): 34 | bbox_min[j] = min( 35 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] 36 | ) 37 | bbox_max[j] = max( 38 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] 39 | ) 40 | bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1) 41 | bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1) 42 | 43 | # Find all voxels where bounding box intersects 44 | for x in range(bbox_min[0], bbox_max[0] + 1): 45 | for y in range(bbox_min[1], bbox_max[1] + 1): 46 | spatial_idx = self.resolution * x + y 47 | self.spatial_hash[spatial_idx].push_back(i_tri) 48 | 49 | @cython.boundscheck(False) # Deactivate bounds checking 50 | @cython.wraparound(False) # Deactivate negative indexing. 
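# (editor's note, usage sketch) TriangleHash is built from triangles that have
# already been projected onto a 2D grid of `resolution` x `resolution` cells;
# _build_hash above registers each triangle in every cell its bounding box
# touches, and query() below returns two parallel int32 arrays
# (point_indices, tri_indices) listing, for every query point, the candidate
# triangles of its cell, to be narrowed down by an exact test downstream.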
51 | cpdef query(self, double[:, :] points): 52 | assert(points.shape[1] == 2) 53 | cdef int n_points = points.shape[0] 54 | 55 | cdef vector[int] points_indices 56 | cdef vector[int] tri_indices 57 | # cdef int[:] points_indices_np 58 | # cdef int[:] tri_indices_np 59 | 60 | cdef int i_point, k, x, y 61 | cdef int spatial_idx 62 | 63 | for i_point in range(n_points): 64 | x = int(points[i_point, 0]) 65 | y = int(points[i_point, 1]) 66 | if not (0 <= x < self.resolution and 0 <= y < self.resolution): 67 | continue 68 | 69 | spatial_idx = self.resolution * x + y 70 | for i_tri in self.spatial_hash[spatial_idx]: 71 | points_indices.push_back(i_point) 72 | tri_indices.push_back(i_tri) 73 | 74 | points_indices_np = np.zeros(points_indices.size(), dtype=np.int32) 75 | tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32) 76 | 77 | cdef int[:] points_indices_view = points_indices_np 78 | cdef int[:] tri_indices_view = tri_indices_np 79 | 80 | for k in range(points_indices.size()): 81 | points_indices_view[k] = points_indices[k] 82 | 83 | for k in range(tri_indices.size()): 84 | tri_indices_view[k] = tri_indices[k] 85 | 86 | return points_indices_np, tri_indices_np 87 | -------------------------------------------------------------------------------- /create_dataset_sin_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import argparse 5 | import pickle 6 | from PIL import Image 7 | 8 | def save_pickle(data, pkl_path): 9 | with open(pkl_path, 'wb') as f: 10 | pickle.dump(data, f) 11 | 12 | def get_parser(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--img_path", type=str, default='./imgs/demo/input.png') 15 | parser.add_argument("--name_dataset", type=str, default='custom_sin_img') 16 | parser.add_argument("--img_size", type=int, default=256) 17 | parser.add_argument("--resize_img", type=bool, default=False) 18 | parser.add_argument("--center_obj", type=bool, default=True) 19 | args = parser.parse_args() 20 | return args 21 | 22 | def create_dataset(args): 23 | dir_tgt = f'./data/{args.name_dataset}' 24 | os.makedirs(dir_tgt, exist_ok=True) 25 | object_uid = '00000' 26 | dir_names = ['00_img_input', '01_img_slices', '02_sdfs', '03_splits'] 27 | for dir_name in dir_names: 28 | os.makedirs(f'{dir_tgt}/{dir_name}', exist_ok=True) 29 | 30 | # save image 31 | img = Image.open(args.img_path) 32 | os.makedirs(f'{dir_tgt}/00_img_input/{object_uid}', exist_ok=True) 33 | img_path_new = f'{dir_tgt}/00_img_input/{object_uid}/004.png' 34 | assert (img.mode == 'RGBA') 35 | if args.center_obj: 36 | # we assumed the camera points to the centre of the object 37 | # centre the 2D bbox could be helpful, but does not guarantee that object 3D centre is aligned 38 | alpha = img.split()[3] 39 | bbox = alpha.getbbox() 40 | width, height = img.size 41 | object_width = bbox[2] - bbox[0] 42 | object_height = bbox[3] - bbox[1] 43 | offset_x = (width - object_width) // 2 - bbox[0] 44 | offset_y = (height - object_height) // 2 - bbox[1] 45 | new_image = Image.new('RGBA', (width, height), (0, 0, 0, 0)) 46 | new_image.paste(img, (offset_x, offset_y), mask=alpha) 47 | img = new_image 48 | if args.resize_img: 49 | img_rs = img.resize((args.img_size, args.img_size), Image.ANTIALIAS) 50 | img_rs.save(img_path_new) 51 | else: 52 | # shutil.copy(args.img_path, img_path_new) 53 | img.save(img_path_new, 'PNG') 54 | # write meta pkl 55 | K = np.zeros((3, 3)) 56 | azimuths = np.zeros(12) 57 | elevations = 
np.zeros(12) 58 | distances = np.ones(12) * 1.2 59 | cam_poses = np.zeros((12, 3, 4)) 60 | scale_rand = 1.0 61 | offset_rand = np.zeros(3) 62 | save_pickle([K, azimuths, elevations, distances, cam_poses, scale_rand, offset_rand], f'{dir_tgt}/00_img_input/{object_uid}/meta.pkl') 63 | 64 | # 01_img_slices 65 | os.makedirs(f'{dir_tgt}/01_img_slices/{object_uid}/004', exist_ok=True) 66 | for axis in ['X', 'Y', 'Z']: 67 | for part in ['1', '2', '3', '4']: 68 | img_slice = Image.new("RGBA", (args.img_size, args.img_size)) 69 | img_slice.save(f'{dir_tgt}/01_img_slices/{object_uid}/004/{axis}_{part}.png') 70 | 71 | # 02_sdfs 72 | arr_sdf = np.zeros((16384, 4)) 73 | np.save(f'{dir_tgt}/02_sdfs/{object_uid}.npy', arr_sdf) 74 | 75 | # 03_splits 76 | os.makedirs(f'{dir_tgt}/03_splits', exist_ok=True) 77 | for split in ['train', 'val', 'test']: 78 | fout = open(f'{dir_tgt}/03_splits/{split}.lst', 'w') 79 | # create 00_img_input 80 | fout.write(object_uid) 81 | fout.close() 82 | 83 | 84 | if __name__ == "__main__": 85 | args = get_parser() 86 | create_dataset(args) -------------------------------------------------------------------------------- /gen_slices/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def 
__init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from plyfile import PlyElement, PlyData 3 | import numpy as np 4 | 5 | 6 | def export_pointcloud(vertices, out_file, as_text=True): 7 | assert(vertices.shape[1] == 3) 8 | vertices = vertices.astype(np.float32) 9 | vertices = np.ascontiguousarray(vertices) 10 | vector_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 11 | vertices = vertices.view(dtype=vector_dtype).flatten() 12 | plyel = PlyElement.describe(vertices, 'vertex') 13 | plydata = PlyData([plyel], text=as_text) 14 | plydata.write(out_file) 15 | 16 | 17 | def load_pointcloud(in_file): 18 | plydata = PlyData.read(in_file) 19 | vertices = np.stack([ 20 | plydata['vertex']['x'], 21 | plydata['vertex']['y'], 22 | plydata['vertex']['z'] 23 | ], axis=1) 24 | return vertices 25 | 26 | 27 | def read_off(file): 28 | """ 29 | Reads vertices and faces from an off file. 30 | 31 | :param file: path to file to read 32 | :type file: str 33 | :return: vertices and faces as lists of tuples 34 | :rtype: [(float)], [(int)] 35 | """ 36 | 37 | assert os.path.exists(file), 'file %s not found' % file 38 | 39 | with open(file, 'r') as fp: 40 | lines = fp.readlines() 41 | lines = [line.strip() for line in lines] 42 | 43 | # Fix for ModelNet bug were 'OFF' and the number of vertices and faces 44 | # are all in the first line. 45 | if len(lines[0]) > 3: 46 | assert lines[0][:3] == 'OFF' or lines[0][:3] == 'off', \ 47 | 'invalid OFF file %s' % file 48 | 49 | parts = lines[0][3:].split(' ') 50 | assert len(parts) == 3 51 | 52 | num_vertices = int(parts[0]) 53 | assert num_vertices > 0 54 | 55 | num_faces = int(parts[1]) 56 | assert num_faces > 0 57 | 58 | start_index = 1 59 | # This is the regular case! 
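# (editor's note) the branch above handles buggy ModelNet files where the
# counts are glued onto the keyword, e.g. a first line of "OFF4322 8640 0"
# (the numbers here are only illustrative); a well-formed OFF file keeps the
# keyword and the "num_vertices num_faces num_edges" counts on separate
# lines, which is handled by the else branch below.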
60 | else: 61 | assert lines[0] == 'OFF' or lines[0] == 'off', \ 62 | 'invalid OFF file %s' % file 63 | 64 | parts = lines[1].split(' ') 65 | assert len(parts) == 3 66 | 67 | num_vertices = int(parts[0]) 68 | assert num_vertices > 0 69 | 70 | num_faces = int(parts[1]) 71 | assert num_faces > 0 72 | 73 | start_index = 2 74 | 75 | vertices = [] 76 | for i in range(num_vertices): 77 | vertex = lines[start_index + i].split(' ') 78 | vertex = [float(point.strip()) for point in vertex if point != ''] 79 | assert len(vertex) == 3 80 | 81 | vertices.append(vertex) 82 | 83 | faces = [] 84 | for i in range(num_faces): 85 | face = lines[start_index + num_vertices + i].split(' ') 86 | face = [index.strip() for index in face if index != ''] 87 | 88 | # check to be sure 89 | for index in face: 90 | assert index != '', \ 91 | 'found empty vertex index: %s (%s)' \ 92 | % (lines[start_index + num_vertices + i], file) 93 | 94 | face = [int(index) for index in face] 95 | 96 | assert face[0] == len(face) - 1, \ 97 | 'face should have %d vertices but as %d (%s)' \ 98 | % (face[0], len(face) - 1, file) 99 | assert face[0] == 3, \ 100 | 'only triangular meshes supported (%s)' % file 101 | for index in face: 102 | assert index >= 0 and index < num_vertices, \ 103 | 'vertex %d (of %d vertices) does not exist (%s)' \ 104 | % (index, num_vertices, file) 105 | 106 | assert len(face) > 1 107 | 108 | faces.append(face) 109 | 110 | return vertices, faces 111 | 112 | assert False, 'could not open %s' % file 113 | -------------------------------------------------------------------------------- /gen_slices/scripts/inpaint.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from main import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | 11 | def make_batch(image, mask, device): 12 | image = np.array(Image.open(image).convert("RGB")) 13 | image = image.astype(np.float32)/255.0 14 | image = image[None].transpose(0,3,1,2) 15 | image = torch.from_numpy(image) 16 | 17 | mask = np.array(Image.open(mask).convert("L")) 18 | mask = mask.astype(np.float32)/255.0 19 | mask = mask[None,None] 20 | mask[mask < 0.5] = 0 21 | mask[mask >= 0.5] = 1 22 | mask = torch.from_numpy(mask) 23 | 24 | masked_image = (1-mask)*image 25 | 26 | batch = {"image": image, "mask": mask, "masked_image": masked_image} 27 | for k in batch: 28 | batch[k] = batch[k].to(device=device) 29 | batch[k] = batch[k]*2.0-1.0 30 | return batch 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--indir", 37 | type=str, 38 | nargs="?", 39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", 40 | ) 41 | parser.add_argument( 42 | "--outdir", 43 | type=str, 44 | nargs="?", 45 | help="dir to write results to", 46 | ) 47 | parser.add_argument( 48 | "--steps", 49 | type=int, 50 | default=50, 51 | help="number of ddim sampling steps", 52 | ) 53 | opt = parser.parse_args() 54 | 55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) 56 | images = [x.replace("_mask.png", ".png") for x in masks] 57 | print(f"Found {len(masks)} inputs.") 58 | 59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") 60 | model = instantiate_from_config(config.model) 61 | 
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], 62 | strict=False) 63 | 64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 65 | model = model.to(device) 66 | sampler = DDIMSampler(model) 67 | 68 | os.makedirs(opt.outdir, exist_ok=True) 69 | with torch.no_grad(): 70 | with model.ema_scope(): 71 | for image, mask in tqdm(zip(images, masks)): 72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1]) 73 | batch = make_batch(image, mask, device=device) 74 | 75 | # encode masked image and concat downsampled mask 76 | c = model.cond_stage_model.encode(batch["masked_image"]) 77 | cc = torch.nn.functional.interpolate(batch["mask"], 78 | size=c.shape[-2:]) 79 | c = torch.cat((c, cc), dim=1) 80 | 81 | shape = (c.shape[1]-1,)+c.shape[2:] 82 | samples_ddim, _ = sampler.sample(S=opt.steps, 83 | conditioning=c, 84 | batch_size=c.shape[0], 85 | shape=shape, 86 | verbose=False) 87 | x_samples_ddim = model.decode_first_stage(samples_ddim) 88 | 89 | image = torch.clamp((batch["image"]+1.0)/2.0, 90 | min=0.0, max=1.0) 91 | mask = torch.clamp((batch["mask"]+1.0)/2.0, 92 | min=0.0, max=1.0) 93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, 94 | min=0.0, max=1.0) 95 | 96 | inpainted = (1-mask)*image+mask*predicted_image 97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath) 99 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/icp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.neighbors import NearestNeighbors 3 | 4 | 5 | def best_fit_transform(A, B): 6 | ''' 7 | Calculates the least-squares best-fit transform that maps corresponding 8 | points A to B in m spatial dimensions 9 | Input: 10 | A: Nxm numpy array of corresponding points 11 | B: Nxm numpy array of corresponding points 12 | Returns: 13 | T: (m+1)x(m+1) homogeneous transformation matrix that maps A on to B 14 | R: mxm rotation matrix 15 | t: mx1 translation vector 16 | ''' 17 | 18 | assert A.shape == B.shape 19 | 20 | # get number of dimensions 21 | m = A.shape[1] 22 | 23 | # translate points to their centroids 24 | centroid_A = np.mean(A, axis=0) 25 | centroid_B = np.mean(B, axis=0) 26 | AA = A - centroid_A 27 | BB = B - centroid_B 28 | 29 | # rotation matrix 30 | H = np.dot(AA.T, BB) 31 | U, S, Vt = np.linalg.svd(H) 32 | R = np.dot(Vt.T, U.T) 33 | 34 | # special reflection case 35 | if np.linalg.det(R) < 0: 36 | Vt[m-1,:] *= -1 37 | R = np.dot(Vt.T, U.T) 38 | 39 | # translation 40 | t = centroid_B.T - np.dot(R,centroid_A.T) 41 | 42 | # homogeneous transformation 43 | T = np.identity(m+1) 44 | T[:m, :m] = R 45 | T[:m, m] = t 46 | 47 | return T, R, t 48 | 49 | 50 | def nearest_neighbor(src, dst): 51 | ''' 52 | Find the nearest (Euclidean) neighbor in dst for each point in src 53 | Input: 54 | src: Nxm array of points 55 | dst: Nxm array of points 56 | Output: 57 | distances: Euclidean distances of the nearest neighbor 58 | indices: dst indices of the nearest neighbor 59 | ''' 60 | 61 | assert src.shape == dst.shape 62 | 63 | neigh = NearestNeighbors(n_neighbors=1) 64 | neigh.fit(dst) 65 | distances, indices = neigh.kneighbors(src, return_distance=True) 66 | return distances.ravel(), indices.ravel() 67 | 68 | 69 | def icp(A, B, init_pose=None, max_iterations=20, tolerance=0.001): 70 | ''' 71 | The Iterative Closest Point method: finds best-fit 
transform that maps 72 | points A on to points B 73 | Input: 74 | A: Nxm numpy array of source mD points 75 | B: Nxm numpy array of destination mD point 76 | init_pose: (m+1)x(m+1) homogeneous transformation 77 | max_iterations: exit algorithm after max_iterations 78 | tolerance: convergence criteria 79 | Output: 80 | T: final homogeneous transformation that maps A on to B 81 | distances: Euclidean distances (errors) of the nearest neighbor 82 | i: number of iterations to converge 83 | ''' 84 | 85 | assert A.shape == B.shape 86 | 87 | # get number of dimensions 88 | m = A.shape[1] 89 | 90 | # make points homogeneous, copy them to maintain the originals 91 | src = np.ones((m+1,A.shape[0])) 92 | dst = np.ones((m+1,B.shape[0])) 93 | src[:m,:] = np.copy(A.T) 94 | dst[:m,:] = np.copy(B.T) 95 | 96 | # apply the initial pose estimation 97 | if init_pose is not None: 98 | src = np.dot(init_pose, src) 99 | 100 | prev_error = 0 101 | 102 | for i in range(max_iterations): 103 | # find the nearest neighbors between the current source and destination points 104 | distances, indices = nearest_neighbor(src[:m,:].T, dst[:m,:].T) 105 | 106 | # compute the transformation between the current source and nearest destination points 107 | T,_,_ = best_fit_transform(src[:m,:].T, dst[:m,indices].T) 108 | 109 | # update the current source 110 | src = np.dot(T, src) 111 | 112 | # check error 113 | mean_error = np.mean(distances) 114 | if np.abs(prev_error - mean_error) < tolerance: 115 | break 116 | prev_error = mean_error 117 | 118 | # calculate final transformation 119 | T,_,_ = best_fit_transform(A, src[:m,:].T) 120 | 121 | return T, distances, i 122 | -------------------------------------------------------------------------------- /gen_slices/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 
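example (illustrative values): with warm_up_steps=[1000], f_start=[1.e-6],
f_max=[1.], f_min=[0.1] and cycle_lengths=[10000], the multiplier ramps
roughly linearly from 1e-6 to 1.0 over the first 1000 steps and then follows
a cosine from 1.0 down to 0.1 at step 10000.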
40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /gen_slices/ldm/data/custom_sin_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import random 8 | 9 | class CustomSinImgBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | split, 14 | size=None, 15 | interpolation="bilinear", 16 | flip_p=0.5 17 | ): 18 | self.data_paths = txt_file 19 | self.data_root = data_root 20 | self.split = split 21 | self.n_views = 12 22 | with open(self.data_paths, "r") as f: 23 | self.image_ids = f.read().splitlines() 24 | 25 | self._length = len(self.image_ids) 26 | 27 | self.labels = { 28 | "file_path_": [l for l in self.image_ids], 29 | } 30 | 31 | self.size = size 32 | self.interpolation = {"linear": PIL.Image.LINEAR, 33 | "bilinear": PIL.Image.BILINEAR, 34 | "bicubic": PIL.Image.BICUBIC, 35 | "lanczos": PIL.Image.LANCZOS, 36 | }[interpolation] 37 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 38 | 39 | def __len__(self): 40 | return self._length 41 | 42 | def png_2_whitebg(self, img): 43 | img_rgb = np.array(img)[:, :, 0:3] 44 | img_alpha = 
np.array(img)[:, :, 3:4] 45 | img_alpha_ = (img_alpha == 0).astype(np.float32) 46 | img_ret = np.ones(img_rgb.shape) * 255 * img_alpha_ + img_rgb * (1 - img_alpha_) 47 | ret = Image.fromarray(img_ret.astype(np.uint8)) 48 | return ret 49 | 50 | 51 | def __getitem__(self, i): 52 | example = dict((k, self.labels[k][i]) for k in self.labels) 53 | 54 | if self.split == 'train': 55 | cmr_angle_idx = random.randint(0, self.n_views - 1) 56 | else: 57 | cmr_angle_idx = 4 58 | 59 | cmr_angle = "%03d"%cmr_angle_idx 60 | 61 | img_list = [] 62 | for axis in ['X', 'Z', 'Y']: 63 | if axis == 'Z': 64 | slice_list = ['4', '3', '2', '1'] 65 | else: 66 | slice_list = ['1', '2', '3', '4'] 67 | 68 | for part in slice_list: 69 | img_slice = Image.open(f'{self.data_root}/01_img_slices/{example["file_path_"]}/{cmr_angle}/{axis}_{part}.png') 70 | img_slice = self.png_2_whitebg(img_slice) 71 | img_slice = img_slice.resize((self.size, self.size), resample=self.interpolation) 72 | img_slice = np.array(img_slice)[:, :, :] 73 | img_list.append(img_slice) 74 | 75 | img_ipt_view = Image.open(f'{self.data_root}/00_img_input/{example["file_path_"]}/{cmr_angle}.png') 76 | img_ipt_view = self.png_2_whitebg(img_ipt_view) 77 | img_ipt_view_rs = img_ipt_view.resize((self.size, self.size), resample=self.interpolation) 78 | img_ipt_view = np.array(img_ipt_view_rs) 79 | 80 | img_list.append(img_ipt_view) 81 | image = np.concatenate(img_list, -1) 82 | 83 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 84 | 85 | example["img_ipt_view"] = (img_ipt_view / 127.5 - 1.0).astype(np.float32) 86 | 87 | example["segmentation"] = np.zeros((3, 32, 32)).astype(np.float32) 88 | 89 | return example 90 | 91 | 92 | class CustomSinImgTrain(CustomSinImgBase): 93 | def __init__(self, **kwargs): 94 | super().__init__(txt_file="../data/custom_sin_img/03_splits/train.lst", data_root="../data/custom_sin_img", split='train', **kwargs) 95 | 96 | 97 | class CustomSinImgValidation(CustomSinImgBase): 98 | def __init__(self, flip_p=0., **kwargs): 99 | super().__init__(txt_file="../data/custom_sin_img/03_splits/val.lst", data_root="../data/custom_sin_img", split='val', 100 | flip_p=flip_p, **kwargs) 101 | 102 | class CustomSinImgTest(CustomSinImgBase): 103 | def __init__(self, flip_p=0., **kwargs): 104 | super().__init__(txt_file="../data/custom_sin_img/03_splits/test.lst", data_root="../data/custom_sin_img", split='test', 105 | flip_p=flip_p, **kwargs) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | data/objaverse 3 | data/custom_sin_img 4 | data/shapenet 5 | reg_slices/experiments 6 | gen_slices/logs 7 | gen_slices/models/first_stage_models/kl-f8/model.ckpt 8 | gen_slices/taming/modules/autoencoder/lpips/vgg.pth 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | reg_slices/src_convonet/utils/libmcubes/mcubes.cpython-39-x86_64-linux-gnu.so 170 | reg_slices/src_convonet/utils/libmesh/triangle_hash.cpython-39-x86_64-linux-gnu.so 171 | reg_slices/src_convonet/utils/libmise/mise.cpython-39-x86_64-linux-gnu.so 172 | reg_slices/src_convonet/utils/libsimplify/simplify_mesh.cpython-39-x86_64-linux-gnu.so 173 | reg_slices/src_convonet/utils/libvoxelize/voxelize.cpython-39-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /reg_slices/src/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .unet_custom import UNet 5 | import numpy as np 6 | import torchvision 7 | torch.set_printoptions(precision=8) 8 | from torchvision import models 9 | from .vgg_perceptual_loss import VGGPerceptualLoss 10 | from .vgg16bn_feats import VGG16BNFeats 11 | 12 | class Slices3DRegModel(nn.Module): 13 | def __init__(self, img_size=128, n_slices=12, mode='train'): 14 | super().__init__() 15 | self.mode = mode 16 | self.slices_generator = UNet(n_channels=3) 17 | self.img_size = img_size 18 | self.att_layer = nn.TransformerEncoderLayer(d_model=128, nhead=4, batch_first=True) 19 | self.att_decoder = nn.TransformerEncoder(self.att_layer, num_layers=3) 20 | self.fc_p = nn.Linear(3, 128) 21 | self.fc_s = nn.Linear(992, 128) 22 | self.fc_out = nn.Sequential( 23 | nn.Linear(128, 1), 24 | ) 25 | self.vggptlossfunc = VGGPerceptualLoss() 26 | self.n_slices = n_slices 27 | 28 | def project_coord(self, coordinates, trans_mat_wo_rot_tp): 29 | n_bs, n_qry = coordinates.shape[0], coordinates.shape[1] 30 | size_lst = coordinates.shape 31 | homo_pc = torch.cat((coordinates, torch.ones((size_lst[0], size_lst[1], 1)).cuda()), axis=-1) 32 | pc_xyz = torch.bmm(homo_pc, trans_mat_wo_rot_tp) 33 | pc_xy = torch.divide(pc_xyz[:, :, :2], pc_xyz[:, :, 2:]) 34 | ret = 2 * (pc_xy - 0.5) # from [0, 1] to [-1, 1] 35 | ret = torch.clamp(ret, min=-1, max=1) 36 | return ret 37 | 38 | def sample_from_planes(self, plane_features, projected_coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): 39 | # assert padding_mode == 'zeros' 40 | n_planes = 1 41 | N, C, H, W = plane_features.shape 42 | _, M, _ = projected_coordinates.shape 43 | plane_features = plane_features.view(N, C, H, W) 44 | projected_coordinates = projected_coordinates.unsqueeze(1) 45 | output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=True).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) 46 | return output_features 47 | 48 | def forward(self, feed_dict): 49 | 50 | img_input = feed_dict['img_input'] 51 | n_bs, _, n_w, n_h = img_input.shape 52 | 53 | if self.mode == 'test': 54 | qry_no_rot = feed_dict['qry_norot'] 55 | qry_no_rot[:, :, 1:] *= -1 56 | qry_rot = qry_no_rot 57 | else: 58 | qry_no_rot = feed_dict['qry_norot'] 59 | obj_rot_mat = feed_dict['obj_rot_mat'] 60 | qry_rot = torch.bmm(qry_no_rot, obj_rot_mat) 61 | 62 | qry = qry_rot 63 | _, n_qry, _ = qry_no_rot.shape 64 | 65 | feat_list, slices_rec = self.slices_generator(img_input) # n_bs * n_slices, 3, w, h 66 | slices_rec_ = slices_rec.view(n_bs, self.n_slices, 3, n_w, n_h).view(n_bs, self.n_slices * 3, n_w, n_h) 67 | 68 | feat_interp = [] 69 | img_pts = self.project_coord(qry_rot, feed_dict['trans_mat_wo_rot_tp']) 70 | img_pts = img_pts.view(n_bs, 1, n_qry, 2).expand(-1, self.n_slices, -1, 
-1).reshape(n_bs * self.n_slices, n_qry, 2) 71 | for idx in range(len(feat_list)): 72 | n_bs_, n_c, n_h, n_w = feat_list[idx].shape 73 | feat_planes = feat_list[idx].view(n_bs_, n_c, n_h, n_w) 74 | feats_out = self.sample_from_planes(feat_planes, img_pts) 75 | feat_interp.append(feats_out.squeeze(1)) 76 | 77 | feat_local_aggregated = torch.cat(feat_interp, dim=2) 78 | feat_local_aggregated_ = feat_local_aggregated.view(n_bs, self.n_slices, n_qry, 992).permute(0, 2, 1, 3).reshape(n_bs * n_qry, self.n_slices, 992) 79 | feat_qry = self.fc_p(qry_rot) 80 | feat_slice = self.fc_s(feat_local_aggregated_) # n_bs * n_qry, n_slices, 128 81 | 82 | feat_input = torch.cat([feat_qry.view(n_bs * n_qry, 1, 128), feat_slice], 1) 83 | feat_attened = self.att_decoder(feat_input).view(n_bs, n_qry, self.n_slices + 1, 128)[:, :, 0, :] 84 | sdf_pred = self.fc_out(feat_attened).squeeze(-1) 85 | 86 | ret_dict = {} 87 | ret_dict['sdf_pred'] = sdf_pred 88 | ret_dict['slices_rec'] = slices_rec_ 89 | 90 | img_slices_ = feed_dict['img_slices'].view(n_bs, self.n_slices, 3, n_w, n_h).view(n_bs * self.n_slices, 3, n_w, n_h) 91 | 92 | ret_dict['vgg_loss'] = self.vggptlossfunc(slices_rec, img_slices_)['pt_c_loss'] * 0.001 93 | 94 | return ret_dict 95 | 96 | -------------------------------------------------------------------------------- /reg_slices/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_parser(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--name_model', type=str, default='slicenet', choices=['slicenet', 'disn', 'gtslice']) 6 | # dataset related 7 | parser.add_argument('--dir_data', type=str, default='../data') 8 | parser.add_argument('--name_dataset', type=str, default='shapenet', choices=['objaverse', 'shapenet', 'custom', 'custom_sin_img']) 9 | parser.add_argument('--name_single', type=str, default='fertility', help='name of the single shape') 10 | parser.add_argument('--n_wk', type=int, default=16, help='number of workers in dataloader') 11 | parser.add_argument('--categories_train', type=str, default='objaverse,', help='the training and validation categories of objects for ShapeNet datasets') 12 | parser.add_argument('--categories_test', type=str, default='objaverse,', help='the testing categories of objects for ShapeNet datasets') 13 | parser.add_argument('--add_noise', type=float, default=0, help='the std of noise added to the point clouds') 14 | parser.add_argument('--gt_source', type=str, default='imnet', choices=['imnet', 'occnet'], help='using which query-occ groundtruth when training on Shapenet') 15 | 16 | parser.add_argument('--img_size', type=int, default=128, help='img_size') 17 | parser.add_argument('--n_qry', type=int, default=256, help='the number of query points for per shape when training') 18 | parser.add_argument('--n_slices', type=int, default=12, help='the number of slices for each shape') 19 | parser.add_argument('--n_views', type=int, default=12, help='the number of views for each shape') 20 | parser.add_argument('--pred_type', type=str, default='sdf', choices=['occ', 'sdf'], help='predict occupancy (occ) or signed distance field (sdf), sdf only works for abc dataset') 21 | 22 | # common hyper-parameters 23 | parser.add_argument('--name_exp', type=str, default='2023_07_04_chairs_vggptloss') 24 | parser.add_argument('--name_exp_cam', type=str, default='2023_1107_airplanes_est') # 2023_07_010_camera_pose_est_ or 2023_07_010_camera_pose_est_rot_shift 2023_1107_airplanes_est 25 | 
parser.add_argument('--mode', type=str, default='train', choices=['train', 'val', 'test']) 26 | parser.add_argument('--n_bs', type=int, default=16, help='batch size') 27 | parser.add_argument('--n_epochs', type=int, default=600, help='number of epochs') 28 | parser.add_argument('--lr', type=float, default=3e-4, help='init learning rate') 29 | parser.add_argument('--n_dim', type=int, default=128, help='the dimension of hidden layer features') 30 | parser.add_argument('--multi_gpu', type=bool, default=False) 31 | parser.add_argument('--freq_ckpt', type=int, default=4, help='frequency of epoch saving checkpoint') 32 | parser.add_argument('--freq_log', type=int, default=200, help='frequency of outputing training logs') 33 | parser.add_argument('--freq_decay', type=int, default=100, help='decaying the lr evey freq_decay epochs') 34 | parser.add_argument('--weight_decay', type=float, default=0.5, help='weight decay') 35 | parser.add_argument('--tboard', type=bool, default=True, help='whether use tensorboard to visualize loss') 36 | parser.add_argument('--resume', action=argparse.BooleanOptionalAction, help='resume training') 37 | parser.add_argument('--est_campose', action=argparse.BooleanOptionalAction, help='whether to use gt camera poses') 38 | 39 | parser.add_argument('--back_bone_cam_est', type=str, default='vgg16_bn', choices=['vgg16_bn', 'resnet50']) 40 | 41 | # training related 42 | parser.add_argument('--use_white_bg', action=argparse.BooleanOptionalAction, help='whether to use gt camera poses') 43 | 44 | # Marching Cube realted 45 | parser.add_argument('--mc_chunk_size', type=int, default=3000, help='the number of query points in a chunk when doing marching cube, set it according to your GPU memory') 46 | parser.add_argument('--mc_res0', type=int, default=64, help='start resolution for MISE') 47 | parser.add_argument('--mc_up_steps', type=int, default=2, help='number of upsampling steps') 48 | parser.add_argument('--mc_threshold', type=float, default=0.5, help='the threshold for network output values') 49 | # testing related 50 | parser.add_argument('--name_ckpt', type=str, default='10_5511_0.0876_0.9612.ckpt') 51 | parser.add_argument('--name_ckpt_cam', type=str, default='570_225545_1.969e-05.ckpt') # 480_276575_0.0003403.ckpt 500_288075_0.0009608.ckpt 570_225545_1.969e-05.ckpt 52 | parser.add_argument('--from_which_slices', type=str, default='gt', choices=['gt', 'gt_rec', 'gen'], help='using which kind of slices') 53 | parser.add_argument('--overwrite_res', action=argparse.BooleanOptionalAction, help='whether to overwrite the results') 54 | return parser -------------------------------------------------------------------------------- /reg_slices/test_projection.py: -------------------------------------------------------------------------------- 1 | from src.utils import load_mesh, get_img_cam, getBlenderProj, get_rotate_matrix, get_norm_matrix, write_pointcloud, read_params, get_W2O_mat 2 | import cv2 3 | import numpy as np 4 | import h5py 5 | import trimesh 6 | import pickle 7 | 8 | def get_img_points(sample_pc, trans_mat_right, obj_rot_mat, K): 9 | # sample_pc N * 3 10 | size_lst = sample_pc.shape 11 | 12 | homo_pc = np.concatenate((sample_pc, np.ones((size_lst[0], 1))), axis=-1) 13 | pc_xyz = np.matmul(homo_pc, trans_mat_right) 14 | 15 | print(pc_xyz[0:5]) 16 | 17 | pc_xyz_v2_tmp = np.matmul(sample_pc, obj_rot_mat) 18 | pc_xyz_v2_tmp = np.concatenate((pc_xyz_v2_tmp, np.ones((size_lst[0], 1))), axis=-1) 19 | pc_xyz_v2 = np.matmul(pc_xyz_v2_tmp, K) 20 | print(pc_xyz_v2[0:5]) 21 | 22 
| pc_xy = np.divide(pc_xyz[:,:2], pc_xyz[:,2:]) 23 | 24 | pc_xy = pc_xy * 256 25 | mintensor = np.array([0.0,0.0]) 26 | maxtensor = np.array([256,256.0]) 27 | return np.minimum(maxtensor, np.maximum(mintensor, pc_xy)) 28 | 29 | def read_sdf_file(path_sdf, scale, offset): 30 | 31 | 32 | sdf_npy = np.load(path_sdf) 33 | 34 | sample_pt = sdf_npy[:, :3] 35 | sample_sdf_val = sdf_npy[:, 3] 36 | 37 | # (camX, camY, camZ) -> (camX, -camZ, camY) 38 | print(offset[0], offset[1], offset[2]) 39 | offset_ = np.array([offset[0], offset[2], -offset[1]]) 40 | 41 | sample_pt = sample_pt * scale + offset_ 42 | sample_sdf_val = (sample_sdf_val) * scale 43 | 44 | return sample_pt, sample_sdf_val 45 | 46 | def test_img_h5(dir_img, shape_id, angle_int): 47 | angle_str = "%02d"%angle_int 48 | img_arr, trans_mat, obj_rot_mat, K, sample_pt = get_img(dir_img, shape_id, angle_int) 49 | # march_obj_fl = f'/localhome/ywa439/Documents/datasets/ShapeNet/DISN_SDF_MC_Meshes/03001627/{shape_id}/isosurf.obj' 50 | sample_pt_ = sample_pt[0:100] 51 | 52 | pc_xy = get_img_points(sample_pt[0:100], trans_mat, obj_rot_mat, K) 53 | 54 | for j in range(pc_xy.shape[0]): 55 | y = int(pc_xy[j, 1]) 56 | x = int(pc_xy[j, 0]) 57 | 58 | cv2.circle(img_arr, (x, y), 3, (255, 0, 0), -1) 59 | 60 | # rot_pc = np.dot(new_pts, obj_rot_mat) 61 | 62 | cv2.imwrite(f"./{shape_id}_{angle_str}_proj_samplept_{str(angle_int)}.png", img_arr) 63 | 64 | 65 | def get_points(obj_fl): 66 | sample_pc = np.zeros((0,3), dtype=np.float32) 67 | mesh_lst = trimesh.load_mesh(obj_fl, process=False) 68 | if not isinstance(mesh_lst, list): 69 | mesh_lst = [mesh_lst] 70 | for mesh in mesh_lst: 71 | choice = np.random.randint(mesh.vertices.shape[0], size=1000) 72 | sample_pc = np.concatenate((sample_pc, mesh.vertices[choice,...]), axis=0) #[choice,...] 
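# (editor's note) np.random.randint samples the 1000 vertex indices with
# replacement, so meshes with fewer than 1000 vertices are still handled.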
73 | # color = [[255,0,0], [0,255,0], [0,0,255], [255, 0, 255]] 74 | color = 255*np.ones_like(sample_pc, dtype=np.uint8) 75 | color[:,0] = 0 76 | color[:,1] = 0 77 | return sample_pc, np.asarray(color, dtype=np.uint8) 78 | 79 | def get_img(dir_img, shape_id, angle_int): 80 | rot90y = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]], dtype=np.float32) 81 | angle_str = "%03d"%(angle_int) 82 | dir_sfd = '/data/wangyz/11_slice3d/data/objaverse/02_sdfs' 83 | img_arr = cv2.imread(f'{dir_img}/{shape_id}/{angle_str}.png', cv2.IMREAD_UNCHANGED).astype(np.uint8) 84 | 85 | path_cmr_info = f'{dir_img}/{shape_id}/meta.pkl' 86 | # Open the file in binary mode 87 | with open(path_cmr_info, 'rb') as f: 88 | # Load the data 89 | data = pickle.load(f) 90 | # trans_mat = np.transpose(data[4][angle_int]) 91 | az = -data[1][angle_int] 92 | el = data[2][angle_int] 93 | print(az, el) 94 | distance = data[3][angle_int] 95 | scale = data[5] 96 | offset = data[6] 97 | print('offset', offset) 98 | K, RT = getBlenderProj(az, el, distance, img_w=1, img_h=1) 99 | 100 | sample_pt, sample_sdf_val = read_sdf_file(f'{dir_sfd}/{shape_id}.npy', scale, offset) 101 | 102 | rot_mat = get_rotate_matrix(-np.pi / 2) 103 | W2O_mat = get_W2O_mat((0, 0, 0)) 104 | trans_mat = np.linalg.multi_dot([K, RT, rot_mat, W2O_mat]) 105 | trans_mat = np.transpose(trans_mat) 106 | 107 | rot_full = np.linalg.multi_dot([RT, rot_mat]) 108 | tmp = np.concatenate((np.eye(3), rot_full[:, 3:4]), axis=1) 109 | obj_rot_mat = np.transpose(rot_full)[:3, :] 110 | 111 | K = np.transpose(np.linalg.multi_dot([K, tmp, W2O_mat])) 112 | return img_arr[:, :, :3].copy(), trans_mat, obj_rot_mat, K, sample_pt[(sample_sdf_val) > 0] 113 | 114 | 115 | angle_int = 0 116 | dir_img = '/data/wangyz/11_slice3d/data/objaverse/00_img_input' 117 | shape_id = 'f598dfee0d22404983dc9d9c3f307202' # 00a1a602456f4eb188b522d7ef19e81b 0a5652c16e1a4575903dfc1696382502_00_proj_samplept_0 118 | test_img_h5(dir_img, shape_id, angle_int) 119 | -------------------------------------------------------------------------------- /gen_slices/ldm/data/objaverse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import random 8 | 9 | class ObjaverseBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | split, 14 | size=None, 15 | interpolation="bilinear", 16 | flip_p=0.5 17 | ): 18 | self.data_paths = txt_file 19 | self.data_root = data_root 20 | self.split = split 21 | self.n_views = 12 22 | self.img_size = 128 23 | with open(self.data_paths, "r") as f: 24 | self.image_ids = f.read().splitlines() 25 | if self.split == 'trainval_rec': 26 | self._length_ = len(self.image_ids) 27 | self.image_ids = self.image_ids * self.n_views 28 | self._length = len(self.image_ids) 29 | 30 | self.labels = { 31 | "file_path_": [l for l in self.image_ids], 32 | } 33 | 34 | self.size = size 35 | self.interpolation = {"linear": PIL.Image.LINEAR, 36 | "bilinear": PIL.Image.BILINEAR, 37 | "bicubic": PIL.Image.BICUBIC, 38 | "lanczos": PIL.Image.LANCZOS, 39 | }[interpolation] 40 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def png_2_whitebg(self, img): 46 | img_rgb = np.array(img)[:, :, 0:3] 47 | img_alpha = np.array(img)[:, :, 3:4] 48 | img_alpha_ = (img_alpha == 0).astype(np.float32) 49 | img_ret = np.ones(img_rgb.shape) * 255 * img_alpha_ + 
img_rgb * (1 - img_alpha_) 50 | ret = Image.fromarray(img_ret.astype(np.uint8)) 51 | return ret 52 | 53 | 54 | def __getitem__(self, i): 55 | example = dict((k, self.labels[k][i]) for k in self.labels) 56 | 57 | if self.split == 'train': 58 | cmr_angle_idx = random.randint(0, self.n_views - 1) 59 | elif self.split in ['val', 'test']: 60 | cmr_angle_idx = 4 61 | else: 62 | cmr_angle_idx = i // self._length_ 63 | 64 | cmr_angle = "%03d"%cmr_angle_idx 65 | 66 | img_list = [] 67 | for axis in ['X', 'Z', 'Y']: 68 | if axis == 'Z': 69 | slice_list = ['4', '3', '2', '1'] 70 | else: 71 | slice_list = ['1', '2', '3', '4'] 72 | 73 | for part in slice_list: 74 | img_slice = Image.open(f'{self.data_root}/01_img_slices/{example["file_path_"]}/{cmr_angle}/{axis}_{part}.png') 75 | img_slice = self.png_2_whitebg(img_slice) 76 | img_slice = img_slice.resize((self.size, self.size), resample=self.interpolation) 77 | img_slice = np.array(img_slice)[:, :, :] 78 | img_list.append(img_slice) 79 | 80 | img_ipt_view = Image.open(f'{self.data_root}/00_img_input/{example["file_path_"]}/{cmr_angle}.png') 81 | img_ipt_view = self.png_2_whitebg(img_ipt_view) 82 | img_ipt_view_rs = img_ipt_view.resize((self.size, self.size), resample=self.interpolation) 83 | img_ipt_view = np.array(img_ipt_view_rs) 84 | 85 | img_list.append(img_ipt_view) 86 | image = np.concatenate(img_list, -1) 87 | 88 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 89 | 90 | example["img_ipt_view"] = (img_ipt_view / 127.5 - 1.0).astype(np.float32) 91 | 92 | example["segmentation"] = np.zeros((3, 32, 32)).astype(np.float32) 93 | 94 | return example 95 | 96 | 97 | class ObjaverseTrain(ObjaverseBase): 98 | def __init__(self, **kwargs): 99 | super().__init__(txt_file="../data/objaverse/03_splits/train.lst", data_root="../data/objaverse", split='train', **kwargs) 100 | 101 | 102 | class ObjaverseValidation(ObjaverseBase): 103 | def __init__(self, flip_p=0., **kwargs): 104 | super().__init__(txt_file="../data/objaverse/03_splits/val.lst", data_root="../data/objaverse", split='val', 105 | flip_p=flip_p, **kwargs) 106 | 107 | class ObjaverseTest(ObjaverseBase): 108 | def __init__(self, flip_p=0., **kwargs): 109 | super().__init__(txt_file="../data/objaverse/03_splits/test.lst", data_root="../data/objaverse", split='test', 110 | flip_p=flip_p, **kwargs) 111 | 112 | class ObjaverseTrainValRec(ObjaverseBase): 113 | def __init__(self, flip_p=0., **kwargs): 114 | super().__init__(txt_file="../data/objaverse/03_splits/trainval.lst", data_root="../data/objaverse", split='trainval_rec', 115 | flip_p=flip_p, **kwargs) -------------------------------------------------------------------------------- /reg_slices/src/model_disn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision 6 | torch.set_printoptions(precision=8) 7 | from torchvision import models 8 | from .vgg16bn_feats import VGG16BNFeats 9 | 10 | class DISNModel(nn.Module): 11 | def __init__(self, img_size=224, mode='train'): 12 | super().__init__() 13 | self.mode = mode 14 | self.img_encoder = VGG16BNFeats() 15 | self.img_size = img_size 16 | 17 | self.pts_feat_extractor = nn.Sequential( 18 | nn.Linear(3, 64), 19 | nn.ReLU(), 20 | nn.Linear(64, 256), 21 | nn.ReLU(), 22 | nn.Linear(256, 512), 23 | nn.ReLU(), 24 | ) 25 | 26 | self.fc_local = nn.Sequential( 27 | nn.Linear(1472 + 512, 512), 28 | nn.ReLU(), 29 | nn.Linear(512, 256), 30 | nn.ReLU(), 
31 | nn.Linear(256, 1), 32 | ) 33 | 34 | 35 | self.fc_global = nn.Sequential( 36 | nn.Linear(1000 + 512, 512), 37 | nn.ReLU(), 38 | nn.Linear(512, 256), 39 | nn.ReLU(), 40 | nn.Linear(256, 1), 41 | ) 42 | 43 | 44 | def project_coord(self, coordinates, trans_mat_right): 45 | n_bs, n_qry = coordinates.shape[0], coordinates.shape[1] 46 | # trans_mat_right = trans_mat_right.unsqueeze(1).expand(-1, n_qry, -1, -1) 47 | # coordinates = coordinates * 2 48 | size_lst = coordinates.shape 49 | homo_pc = torch.cat((coordinates, torch.ones((size_lst[0], size_lst[1], 1)).cuda()), axis=-1) 50 | pc_xyz = torch.bmm(homo_pc, trans_mat_right) 51 | pc_xy = torch.divide(pc_xyz[:, :, :2], pc_xyz[:, :, 2:]) 52 | ret = 2 * (pc_xy - 0.5) # from [0, 1] to [-1, 1] 53 | ret = torch.clamp(ret, min=-1, max=1) 54 | return ret 55 | 56 | def sample_from_planes(self, plane_features, projected_coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): 57 | # assert padding_mode == 'zeros' 58 | n_planes = 1 59 | N, C, H, W = plane_features.shape 60 | _, M, _ = projected_coordinates.shape 61 | plane_features = plane_features.view(N, C, H, W) 62 | projected_coordinates = projected_coordinates.unsqueeze(1) 63 | output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=True).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) 64 | return output_features 65 | 66 | 67 | 68 | def forward(self, feed_dict): 69 | 70 | 71 | img_input = feed_dict['img_input'] 72 | trans_mat_right = feed_dict['trans_mat_right'] 73 | n_bs, _, n_w, n_h = img_input.shape 74 | qry_no_rot = feed_dict['qry_norot'] 75 | 76 | obj_rot_mat = feed_dict['obj_rot_mat'] 77 | qry_rot = torch.bmm(qry_no_rot, obj_rot_mat) 78 | qry = qry_rot 79 | _, n_qry, _ = qry_no_rot.shape 80 | 81 | feat_list, feats_global = self.img_encoder(img_input) # n * 12, 3, w, h 82 | 83 | feat_interp = [] 84 | 85 | img_pts = self.project_coord(qry_no_rot, trans_mat_right) 86 | # img_pts = img_pts.view(n_bs, 1, n_qry, 2).expand(-1, 12, -1, -1).reshape(n_bs * 12, n_qry, 2) 87 | 88 | for idx in range(len(feat_list)): 89 | 90 | n_bs_, n_c, n_h, n_w = feat_list[idx].shape 91 | feat_planes = feat_list[idx].view(n_bs_, n_c, n_h, n_w) 92 | 93 | feats_out = self.sample_from_planes(feat_planes, img_pts) 94 | 95 | feat_interp.append(feats_out.squeeze(1)) 96 | 97 | feat_local_aggregated = torch.cat(feat_interp, dim=2) # n_bs, n_qry, 1472 98 | 99 | feats_global = feats_global.unsqueeze(1).expand(-1, n_qry, -1) # n_bs, n_qry, 1000 100 | 101 | # print(feat_local_aggregated.shape) 102 | # print(feats_global.shape) 103 | # input() 104 | # feat_local_aggregated_ = feat_local_aggregated.view(n_bs, n_qry, 1472).reshape(n_bs * n_qry, 1472) 105 | 106 | feat_qry = self.pts_feat_extractor(qry_rot) 107 | 108 | feat_local_cat_q = torch.cat([feat_local_aggregated, feat_qry], 2) 109 | feat_global_cat_q = torch.cat([feats_global, feat_qry], 2) 110 | 111 | 112 | 113 | sdf_pred = self.fc_local(feat_local_cat_q) + self.fc_global(feat_global_cat_q) 114 | 115 | 116 | ret_dict = {} 117 | ret_dict['sdf_pred'] = sdf_pred.squeeze(-1) 118 | # ret_dict['slices_rec'] = slices_rec_ 119 | 120 | # img_slices_ = feed_dict['img_slices'].view(n_bs, 12, 3, n_w, n_h).view(n_bs * 12, 3, n_w, n_h) 121 | # print(slices_rec.shape) 122 | #print(img_slices_.shape) 123 | #input() 124 | # ret_dict['vgg_loss'] = self.vggptlossfunc(slices_rec, img_slices_)['pt_c_loss'] * 0.001 125 | 126 | return ret_dict 127 | 128 | 
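# ---------------------------------------------------------------------------
# Hedged usage sketch (added for illustration; not part of the original
# model_disn.py). It shows the feed_dict keys and tensor shapes that
# DISNModel.forward() above consumes. The batch size, query count, 224x224
# resolution and the random tensors are assumptions made only for this sketch;
# project_coord() allocates tensors with .cuda(), so a CUDA device is assumed.
# Because of the relative import at the top of the file, run it as a module
# from reg_slices/, e.g. `python -m src.model_disn`.
if __name__ == '__main__':
    model = DISNModel(img_size=224, mode='test').cuda().eval()
    n_bs, n_qry = 2, 1024
    feed_dict = {
        'img_input': torch.randn(n_bs, 3, 224, 224).cuda(),    # input RGB view
        'trans_mat_right': torch.randn(n_bs, 4, 3).cuda(),     # transposed 3x4 projection (K @ RT @ rot @ W2O)
        'obj_rot_mat': torch.randn(n_bs, 3, 3).cuda(),         # object-to-camera rotation
        'qry_norot': torch.rand(n_bs, n_qry, 3).cuda() - 0.5,  # query points in object space
    }
    with torch.no_grad():
        sdf_pred = model(feed_dict)['sdf_pred']                # (n_bs, n_qry) predicted SDF values
    print(sdf_pred.shape)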
-------------------------------------------------------------------------------- /reg_slices/src/model_gt.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .unet_custom import UNet 5 | import numpy as np 6 | import torchvision 7 | torch.set_printoptions(precision=8) 8 | from torchvision import models 9 | from .vgg_perceptual_loss import VGGPerceptualLoss 10 | from .vgg16bn_feats import VGG16BNFeats 11 | 12 | class Slices3DGTModel(nn.Module): 13 | def __init__(self, img_size=128, n_slices=12, mode='train'): 14 | super().__init__() 15 | self.mode = mode 16 | self.img_encoder = VGG16BNFeats() 17 | self.img_size = img_size 18 | self.n_slices = n_slices 19 | self.att_layer = nn.TransformerEncoderLayer(d_model=128, nhead=4, batch_first=True) 20 | self.att_decoder = nn.TransformerEncoder(self.att_layer, num_layers=3) 21 | self.fc_out = nn.Sequential( 22 | nn.Linear(128, 1), 23 | ) 24 | self.pts_feat_extractor = nn.Sequential( 25 | nn.Linear(3, 32), 26 | nn.ReLU(), 27 | nn.Linear(32, 64), 28 | nn.ReLU(), 29 | nn.Linear(64, 128), 30 | nn.ReLU(), 31 | ) 32 | 33 | self.fc_local = nn.Sequential( 34 | nn.Linear(1472, 128), 35 | nn.ReLU(), 36 | nn.Linear(128, 128), 37 | nn.ReLU(), 38 | ) 39 | 40 | self.fc_global = nn.Sequential( 41 | nn.Linear(128 + 128, 128), 42 | nn.ReLU(), 43 | nn.Linear(128, 128), 44 | nn.ReLU(), 45 | ) 46 | 47 | 48 | def project_coord(self, coordinates, trans_mat_wo_rot_tp): 49 | n_bs, n_qry = coordinates.shape[0], coordinates.shape[1] 50 | 51 | size_lst = coordinates.shape 52 | homo_pc = torch.cat((coordinates, torch.ones((size_lst[0], size_lst[1], 1)).cuda()), axis=-1) 53 | pc_xyz = torch.bmm(homo_pc, trans_mat_wo_rot_tp) 54 | pc_xy = torch.divide(pc_xyz[:, :, :2], pc_xyz[:, :, 2:]) 55 | ret = 2 * (pc_xy - 0.5) # from [0, 1] to [-1, 1] 56 | ret = torch.clamp(ret, min=-1, max=1) 57 | return ret 58 | 59 | def sample_from_planes(self, plane_features, projected_coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): 60 | # assert padding_mode == 'zeros' 61 | n_planes = 1 62 | N, C, H, W = plane_features.shape 63 | _, M, _ = projected_coordinates.shape 64 | plane_features = plane_features.view(N, C, H, W) 65 | projected_coordinates = projected_coordinates.unsqueeze(1) 66 | output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=True).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) 67 | return output_features 68 | 69 | def forward(self, feed_dict): 70 | 71 | img_input = feed_dict['img_input'] 72 | n_bs, _, n_w, n_h = img_input.shape 73 | 74 | if self.mode == 'test': 75 | qry_no_rot = feed_dict['qry_norot'] 76 | qry_no_rot[:, :, 1:] *= -1 77 | qry_rot = qry_no_rot 78 | else: 79 | qry_no_rot = feed_dict['qry_norot'] 80 | obj_rot_mat = feed_dict['obj_rot_mat'] 81 | qry_rot = torch.bmm(qry_no_rot, obj_rot_mat) 82 | 83 | qry = qry_rot 84 | _, n_qry, _ = qry_no_rot.shape 85 | 86 | img_slices = feed_dict['img_slices'] 87 | img_slices = img_slices.view(n_bs, self.n_slices, 3, n_w, n_h).view(n_bs * self.n_slices, 3, n_w, n_h) 88 | img_inpt_and_slices = torch.cat([img_input, img_slices], 0) # n_bs * n_slices 89 | feat_list, feats_global = self.img_encoder(img_slices) # n * n_slices, 3, w, h 90 | 91 | feat_interp = [] 92 | img_pts = self.project_coord(qry_rot, feed_dict['trans_mat_wo_rot_tp']) 93 | img_pts = img_pts.view(n_bs, 1, n_qry, 2).expand(-1, self.n_slices, -1, 
-1).reshape(n_bs * (self.n_slices), n_qry, 2) 94 | for idx in range(len(feat_list)): 95 | feat_planes = feat_list[idx] 96 | feats_out = self.sample_from_planes(feat_planes, img_pts) 97 | feat_interp.append(feats_out.squeeze(1)) 98 | feat_local_aggregated = torch.cat(feat_interp, dim=2) # n_bs * n_slices, n_qry, 1472 99 | feat_local_aggregated = feat_local_aggregated.view(n_bs, self.n_slices, n_qry, 1472).permute(0, 2, 1, 3).reshape(n_bs, n_qry, self.n_slices, 1472) 100 | 101 | feat_qry = self.pts_feat_extractor(qry_rot) 102 | feat_local_aggregated_ = self.fc_local(feat_local_aggregated) 103 | feat_slice = (feat_local_aggregated_).view(n_bs * n_qry, self.n_slices, 128) 104 | feat_input = torch.cat([feat_qry.view(n_bs * n_qry, 1, 128), feat_slice], 1) 105 | feat_attened = self.att_decoder(feat_input).view(n_bs, n_qry, self.n_slices + 1, 128)[:, :, 0, :] 106 | sdf_pred = self.fc_out(feat_attened).squeeze(-1) 107 | 108 | ret_dict = {} 109 | ret_dict['sdf_pred'] = sdf_pred 110 | 111 | return ret_dict 112 | 113 | -------------------------------------------------------------------------------- /reg_slices/src/datasets_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | import os 5 | import trimesh 6 | from scipy.spatial.transform import Rotation 7 | from .utils import load_mesh, get_img_cam, getBlenderProj, get_rotate_matrix, get_norm_matrix, read_params, get_W2O_mat 8 | from PIL import Image 9 | import torchvision.transforms as T 10 | import random 11 | import h5py 12 | 13 | class CamEstDataset(Dataset): 14 | def __init__(self, split, args) -> None: 15 | self.split = split 16 | self.n_qry = args.n_qry 17 | self.dir_dataset = os.path.join(args.dir_data, args.name_dataset) 18 | self.name_dataset = args.name_dataset 19 | self.img_size = args.img_size 20 | self.files = [] 21 | if self.name_dataset == 'shapenet': 22 | if self.split in {'train', 'val'}: 23 | categories = args.categories_train.split(',')[:-1] 24 | else: 25 | categories = args.categories_test.split(',')[:-1] 26 | self.fext_mesh = 'obj' 27 | else: 28 | categories = [''] 29 | self.fext_mesh = 'ply' 30 | for category in categories: 31 | id_shapes = open(f'{self.dir_dataset}/03_splits/{category}/{split}.lst').read().split() 32 | for shape_id in id_shapes: 33 | self.files.append((category, shape_id)) 34 | 35 | self.preprocess = T.Compose([T.Resize((128, 128)), T.ToTensor(), T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) 36 | 37 | self.dir_img = f'/localhome/ywa439/Documents/03_blender_bisect/processed_{category}_input_addtrans_xrot/dataset_input' 38 | self.camera_metainfo = f'/localhome/ywa439/Documents/03_blender_bisect/processed_{category}_input_addtrans_xrot/metainfo_input' 39 | self.rot90y = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]], dtype=np.float32) 40 | self.dir_pcd = f'/localhome/ywa439/Documents/datasets/ShapeNet/v1/{category}_pc' 41 | 42 | # self.K_ = cal_K((1., 1.)) 43 | 44 | def __len__(self): return len(self.files) 45 | 46 | def rotate(): 47 | batch_sdf_pt_rot[cnt, ...] 
= np.dot(sample_pt[choice, :], obj_rot_mat) 48 | return 49 | 50 | def get_sdf_h5(self, sdf_h5_file, cat_id, obj): 51 | h5_f = h5py.File(sdf_h5_file, 'r') 52 | try: 53 | if ('pc_sdf_original' in h5_f.keys() 54 | and 'pc_sdf_sample' in h5_f.keys() 55 | and 'norm_params' in h5_f.keys()): 56 | ori_sdf = h5_f['pc_sdf_original'][:].astype(np.float32) 57 | sample_sdf = h5_f['pc_sdf_sample'][:].astype(np.float32) 58 | ori_pt = ori_sdf[:,:3]#, ori_sdf[:,3] 59 | ori_sdf_val = None 60 | sample_pt, sample_sdf_val = sample_sdf[:,:3], sample_sdf[:,3] 61 | norm_params = h5_f['norm_params'][:] 62 | sdf_params = h5_f['sdf_params'][:] 63 | else: 64 | raise Exception(cat_id, obj, "no sdf and sample") 65 | finally: 66 | h5_f.close() 67 | return ori_pt, ori_sdf_val, sample_pt, sample_sdf_val, norm_params, sdf_params 68 | 69 | def __getitem__(self, index): 70 | category, shape_id = self.files[index] 71 | 72 | # cmr_angle = 0 73 | if self.split in {'train', 'val'}: 74 | cmr_angle = str(random.randint(2, 7)) # "%02d"%cmr_angle 75 | else: 76 | cmr_angle = str(4) # "%02d"%cmr_angle 77 | 78 | img = Image.open(f'{self.dir_img}/train/{shape_id}/{cmr_angle}.png').convert('RGB') 79 | img = self.preprocess(img) 80 | 81 | pcd = np.load(f'{self.dir_pcd}/{shape_id}/2048.npy') 82 | 83 | rot_mat = get_rotate_matrix(-np.pi / 2) 84 | 85 | norm_params = [0, 0, 0, 1] 86 | 87 | norm_mat = get_norm_matrix(norm_params) 88 | 89 | with open(f"{self.camera_metainfo}/{shape_id}/{cmr_angle}.txt", 'r') as f: 90 | lines = f.read().splitlines() 91 | param_lst = read_params(lines) # (f"{self.dir_img}/{shape_id}/rendering/rendering_metadata.txt") 92 | camR, _ = get_img_cam(param_lst[0]) 93 | obj_rot_mat = np.dot(self.rot90y, camR) 94 | az, el, distance_ratio = param_lst[0][0], param_lst[0][1], param_lst[0][3] 95 | K, RT = getBlenderProj(az, el, distance_ratio, img_w=1, img_h=1) 96 | W2O_mat = get_W2O_mat((param_lst[0][-3], param_lst[0][-1], -param_lst[0][-2])) 97 | trans_mat = np.linalg.multi_dot([K, RT, rot_mat, W2O_mat, norm_mat]) 98 | trans_mat_right = np.transpose(trans_mat) 99 | regress_mat = np.transpose(np.linalg.multi_dot([RT, rot_mat, W2O_mat, norm_mat])) 100 | 101 | 102 | feed_dict = { 103 | 'img_input': img, 104 | 'obj_rot_mat': torch.tensor(obj_rot_mat).float(), 105 | 'trans_mat_right': torch.tensor(trans_mat_right).float(), 106 | 'pcd': torch.tensor(pcd).float(), 107 | 'regress_mat': torch.tensor(regress_mat).float(), 108 | 'norm_mat': torch.tensor(norm_mat).float(), 109 | 'K': torch.tensor(K).float(), 110 | } 111 | 112 | return feed_dict 113 | -------------------------------------------------------------------------------- /reg_slices/src_convonet/utils/libmcubes/pyarraymodule.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _EXTMODULE_H 3 | #define _EXTMODULE_H 4 | 5 | #include 6 | #include 7 | 8 | // #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 9 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API 10 | #define NO_IMPORT_ARRAY 11 | #include "numpy/arrayobject.h" 12 | 13 | #include 14 | 15 | template 16 | struct numpy_typemap; 17 | 18 | #define define_numpy_type(ctype, dtype) \ 19 | template<> \ 20 | struct numpy_typemap \ 21 | {static const int type = dtype;}; 22 | 23 | define_numpy_type(bool, NPY_BOOL); 24 | define_numpy_type(char, NPY_BYTE); 25 | define_numpy_type(short, NPY_SHORT); 26 | define_numpy_type(int, NPY_INT); 27 | define_numpy_type(long, NPY_LONG); 28 | define_numpy_type(long long, NPY_LONGLONG); 29 | define_numpy_type(unsigned char, NPY_UBYTE); 30 | 
define_numpy_type(unsigned short, NPY_USHORT); 31 | define_numpy_type(unsigned int, NPY_UINT); 32 | define_numpy_type(unsigned long, NPY_ULONG); 33 | define_numpy_type(unsigned long long, NPY_ULONGLONG); 34 | define_numpy_type(float, NPY_FLOAT); 35 | define_numpy_type(double, NPY_DOUBLE); 36 | define_numpy_type(long double, NPY_LONGDOUBLE); 37 | define_numpy_type(std::complex, NPY_CFLOAT); 38 | define_numpy_type(std::complex, NPY_CDOUBLE); 39 | define_numpy_type(std::complex, NPY_CLONGDOUBLE); 40 | 41 | template 42 | T PyArray_SafeGet(const PyArrayObject* aobj, const npy_intp* indaux) 43 | { 44 | // HORROR. 45 | npy_intp* ind = const_cast(indaux); 46 | void* ptr = PyArray_GetPtr(const_cast(aobj), ind); 47 | switch(PyArray_TYPE(aobj)) 48 | { 49 | case NPY_BOOL: 50 | return static_cast(*reinterpret_cast(ptr)); 51 | case NPY_BYTE: 52 | return static_cast(*reinterpret_cast(ptr)); 53 | case NPY_SHORT: 54 | return static_cast(*reinterpret_cast(ptr)); 55 | case NPY_INT: 56 | return static_cast(*reinterpret_cast(ptr)); 57 | case NPY_LONG: 58 | return static_cast(*reinterpret_cast(ptr)); 59 | case NPY_LONGLONG: 60 | return static_cast(*reinterpret_cast(ptr)); 61 | case NPY_UBYTE: 62 | return static_cast(*reinterpret_cast(ptr)); 63 | case NPY_USHORT: 64 | return static_cast(*reinterpret_cast(ptr)); 65 | case NPY_UINT: 66 | return static_cast(*reinterpret_cast(ptr)); 67 | case NPY_ULONG: 68 | return static_cast(*reinterpret_cast(ptr)); 69 | case NPY_ULONGLONG: 70 | return static_cast(*reinterpret_cast(ptr)); 71 | case NPY_FLOAT: 72 | return static_cast(*reinterpret_cast(ptr)); 73 | case NPY_DOUBLE: 74 | return static_cast(*reinterpret_cast(ptr)); 75 | case NPY_LONGDOUBLE: 76 | return static_cast(*reinterpret_cast(ptr)); 77 | default: 78 | throw std::runtime_error("data type not supported"); 79 | } 80 | } 81 | 82 | template 83 | T PyArray_SafeSet(PyArrayObject* aobj, const npy_intp* indaux, const T& value) 84 | { 85 | // HORROR. 
86 | npy_intp* ind = const_cast(indaux); 87 | void* ptr = PyArray_GetPtr(aobj, ind); 88 | switch(PyArray_TYPE(aobj)) 89 | { 90 | case NPY_BOOL: 91 | *reinterpret_cast(ptr) = static_cast(value); 92 | break; 93 | case NPY_BYTE: 94 | *reinterpret_cast(ptr) = static_cast(value); 95 | break; 96 | case NPY_SHORT: 97 | *reinterpret_cast(ptr) = static_cast(value); 98 | break; 99 | case NPY_INT: 100 | *reinterpret_cast(ptr) = static_cast(value); 101 | break; 102 | case NPY_LONG: 103 | *reinterpret_cast(ptr) = static_cast(value); 104 | break; 105 | case NPY_LONGLONG: 106 | *reinterpret_cast(ptr) = static_cast(value); 107 | break; 108 | case NPY_UBYTE: 109 | *reinterpret_cast(ptr) = static_cast(value); 110 | break; 111 | case NPY_USHORT: 112 | *reinterpret_cast(ptr) = static_cast(value); 113 | break; 114 | case NPY_UINT: 115 | *reinterpret_cast(ptr) = static_cast(value); 116 | break; 117 | case NPY_ULONG: 118 | *reinterpret_cast(ptr) = static_cast(value); 119 | break; 120 | case NPY_ULONGLONG: 121 | *reinterpret_cast(ptr) = static_cast(value); 122 | break; 123 | case NPY_FLOAT: 124 | *reinterpret_cast(ptr) = static_cast(value); 125 | break; 126 | case NPY_DOUBLE: 127 | *reinterpret_cast(ptr) = static_cast(value); 128 | break; 129 | case NPY_LONGDOUBLE: 130 | *reinterpret_cast(ptr) = static_cast(value); 131 | break; 132 | default: 133 | throw std::runtime_error("data type not supported"); 134 | } 135 | } 136 | 137 | #endif 138 | -------------------------------------------------------------------------------- /gen_slices/scripts/txt2img.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import torch 3 | import numpy as np 4 | from omegaconf import OmegaConf 5 | from PIL import Image 6 | from tqdm import tqdm, trange 7 | from einops import rearrange 8 | from torchvision.utils import make_grid 9 | 10 | from ldm.util import instantiate_from_config 11 | from ldm.models.diffusion.ddim import DDIMSampler 12 | from ldm.models.diffusion.plms import PLMSSampler 13 | 14 | 15 | def load_model_from_config(config, ckpt, verbose=False): 16 | print(f"Loading model from {ckpt}") 17 | pl_sd = torch.load(ckpt, map_location="cpu") 18 | sd = pl_sd["state_dict"] 19 | model = instantiate_from_config(config.model) 20 | m, u = model.load_state_dict(sd, strict=False) 21 | if len(m) > 0 and verbose: 22 | print("missing keys:") 23 | print(m) 24 | if len(u) > 0 and verbose: 25 | print("unexpected keys:") 26 | print(u) 27 | 28 | model.cuda() 29 | model.eval() 30 | return model 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument( 37 | "--prompt", 38 | type=str, 39 | nargs="?", 40 | default="a painting of a virus monster playing guitar", 41 | help="the prompt to render" 42 | ) 43 | 44 | parser.add_argument( 45 | "--outdir", 46 | type=str, 47 | nargs="?", 48 | help="dir to write results to", 49 | default="outputs/txt2img-samples" 50 | ) 51 | parser.add_argument( 52 | "--ddim_steps", 53 | type=int, 54 | default=200, 55 | help="number of ddim sampling steps", 56 | ) 57 | 58 | parser.add_argument( 59 | "--plms", 60 | action='store_true', 61 | help="use plms sampling", 62 | ) 63 | 64 | parser.add_argument( 65 | "--ddim_eta", 66 | type=float, 67 | default=0.0, 68 | help="ddim eta (eta=0.0 corresponds to deterministic sampling", 69 | ) 70 | parser.add_argument( 71 | "--n_iter", 72 | type=int, 73 | default=1, 74 | help="sample this often", 75 | ) 76 | 77 | parser.add_argument( 78 | "--H", 79 | type=int, 80 | 
default=256, 81 | help="image height, in pixel space", 82 | ) 83 | 84 | parser.add_argument( 85 | "--W", 86 | type=int, 87 | default=256, 88 | help="image width, in pixel space", 89 | ) 90 | 91 | parser.add_argument( 92 | "--n_samples", 93 | type=int, 94 | default=4, 95 | help="how many samples to produce for the given prompt", 96 | ) 97 | 98 | parser.add_argument( 99 | "--scale", 100 | type=float, 101 | default=5.0, 102 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 103 | ) 104 | opt = parser.parse_args() 105 | 106 | 107 | config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml") # TODO: Optionally download from same location as ckpt and chnage this logic 108 | model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt") # TODO: check path 109 | 110 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 111 | model = model.to(device) 112 | 113 | if opt.plms: 114 | sampler = PLMSSampler(model) 115 | else: 116 | sampler = DDIMSampler(model) 117 | 118 | os.makedirs(opt.outdir, exist_ok=True) 119 | outpath = opt.outdir 120 | 121 | prompt = opt.prompt 122 | 123 | 124 | sample_path = os.path.join(outpath, "samples") 125 | os.makedirs(sample_path, exist_ok=True) 126 | base_count = len(os.listdir(sample_path)) 127 | 128 | all_samples=list() 129 | with torch.no_grad(): 130 | with model.ema_scope(): 131 | uc = None 132 | if opt.scale != 1.0: 133 | uc = model.get_learned_conditioning(opt.n_samples * [""]) 134 | for n in trange(opt.n_iter, desc="Sampling"): 135 | c = model.get_learned_conditioning(opt.n_samples * [prompt]) 136 | shape = [4, opt.H//8, opt.W//8] 137 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps, 138 | conditioning=c, 139 | batch_size=opt.n_samples, 140 | shape=shape, 141 | verbose=False, 142 | unconditional_guidance_scale=opt.scale, 143 | unconditional_conditioning=uc, 144 | eta=opt.ddim_eta) 145 | 146 | x_samples_ddim = model.decode_first_stage(samples_ddim) 147 | x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) 148 | 149 | for x_sample in x_samples_ddim: 150 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 151 | Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png")) 152 | base_count += 1 153 | all_samples.append(x_samples_ddim) 154 | 155 | 156 | # additionally, save as grid 157 | grid = torch.stack(all_samples, 0) 158 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 159 | grid = make_grid(grid, nrow=opt.n_samples) 160 | 161 | # to image 162 | grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() 163 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png')) 164 | 165 | print(f"Your samples are ready and waiting four you here: \n{outpath} \nEnjoy.") 166 | -------------------------------------------------------------------------------- /reg_slices/reconstruct_slices.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch import autograd 4 | import numpy as np 5 | import cv2 6 | from tqdm import trange, tqdm 7 | import trimesh 8 | from src_convonet.utils import libmcubes 9 | from src_convonet.common import make_3d_grid, normalize_coord, add_key, coord2index 10 | from src_convonet.utils.libsimplify import simplify_mesh 11 | from src_convonet.utils.libmise import MISE 12 | import time 13 | import math 14 | from src import datasets 15 | from src.models import Slices3DRegModel 16 | from src.model_cam_est import CameraNet 17 | from src.model_disn import DISNModel 18 | from src.model_gt import Slices3DGTModel 19 | from src.datasets import Slice3DDataset 20 | from src.datasets_cam import CamEstDataset 21 | import os 22 | from options import get_parser 23 | import copy 24 | from src.utils import RGB2BGR, tensor2numpy, denorm 25 | 26 | def show_imgs(img_input, img_slices_rec, dir_output, batch_id, shape_id): 27 | n_bs = 1 28 | res_img = np.zeros((128 * 5, 0, 3)) 29 | dir_tgt = f'{dir_output}/{shape_id}' 30 | os.makedirs(dir_tgt, exist_ok=True) 31 | 32 | # X 33 | for idx in range(n_bs): 34 | for slices_idx in range(4): 35 | slice_img = RGB2BGR(tensor2numpy(denorm(img_slices_rec[idx, 3*slices_idx:3*(slices_idx+1), ...]))) 36 | slice_img = cv2.resize(slice_img, (256, 256)) 37 | cv2.imwrite(os.path.join(dir_tgt, f'X_{str(slices_idx + 1)}.png'), slice_img * 255.0) 38 | 39 | 40 | # Z 41 | for idx in range(n_bs): 42 | for slices_idx in range(4, 8): 43 | slice_img = RGB2BGR(tensor2numpy(denorm(img_slices_rec[idx, 3*slices_idx:3*(slices_idx+1), ...]))) 44 | slice_img = cv2.resize(slice_img, (256, 256)) 45 | cv2.imwrite(os.path.join(dir_tgt, f'Z_{str(8 - slices_idx)}.png'), slice_img * 255.0) 46 | 47 | # Y 48 | for idx in range(n_bs): 49 | for slices_idx in range(8, 12): 50 | slice_img = RGB2BGR(tensor2numpy(denorm(img_slices_rec[idx, 3*slices_idx:3*(slices_idx+1), ...]))) 51 | slice_img = cv2.resize(slice_img, (256, 256)) 52 | cv2.imwrite(os.path.join(dir_tgt, f'Y_{str(slices_idx - 7)}.png'), slice_img * 255.0) 53 | 54 | 55 | if __name__ == '__main__': 56 | args = get_parser().parse_args() 57 | if args.name_model == 'slicenet': 58 | model = Slices3DRegModel(img_size=args.img_size, n_slices=args.n_slices, mode=args.mode) 59 | elif args.name_model == 'disn': 60 | model = DISNModel(img_size=args.img_size, mode=args.mode) 61 | else: 62 | model = Slices3DGTModel(img_size=args.img_size, n_slices=args.n_slices, mode=args.mode) 63 | path_ckpt = os.path.join('experiments', args.name_exp, 'ckpt', args.name_ckpt) 64 | model.load_state_dict(torch.load(path_ckpt)['model']) 65 | model = model.cuda() 66 | model = model.eval() 67 | 68 | if args.est_campose: 69 | 70 | model_cam = CameraNet() 71 | path_ckpt_cam = os.path.join('experiments', args.name_exp_cam, 'ckpt', args.name_ckpt_cam) 72 | model_cam.load_state_dict(torch.load(path_ckpt_cam)['model']) 73 | model_cam = model_cam.cuda() 74 | model_cam = model_cam.eval() 75 | 76 | path_res = os.path.join('experiments', args.name_exp, 'results', args.name_dataset) 77 | if not 
os.path.exists(path_res): 78 | os.makedirs(path_res) 79 | 80 | dataset = Slice3DDataset(split='test', args=args) 81 | 82 | dataset_cam = CamEstDataset(split='test', args=args) 83 | 84 | dir_dataset = os.path.join(args.dir_data, args.name_dataset) 85 | if args.name_dataset == 'shapenet': 86 | categories = args.categories_test.split(',')[:-1] 87 | id_shapes = [] 88 | for category in categories: 89 | id_shapes_ = open(f'{dir_dataset}/04_splits/{category}/test.lst').read().split('\n') 90 | id_shapes += id_shapes_ 91 | 92 | else: 93 | id_shapes = open(f'{dir_dataset}/04_splits/test.lst').read().split('\n') 94 | 95 | dir_output = os.path.join('experiments', args.name_exp, 'img_slices') 96 | os.makedirs(dir_output, exist_ok=True) 97 | 98 | with torch.no_grad(): 99 | 100 | for idx in tqdm(range(len(dataset))): 101 | 102 | shape_id = id_shapes[idx] 103 | data = dataset[idx] 104 | path_mesh = os.path.join(path_res, '%s.obj'%shape_id) 105 | 106 | # if os.path.exists(path_mesh): continue 107 | for key in data: 108 | data[key] = data[key].unsqueeze(0).cuda() 109 | if args.est_campose: 110 | data_cam = dataset_cam[idx] 111 | for key in data_cam: 112 | data_cam[key] = data_cam[key].unsqueeze(0).cuda() 113 | # if use estimated came 114 | if args.est_campose: 115 | print('using predicted pose') 116 | dict_ret_cam = model_cam(data_cam) 117 | 118 | cam_rot_pred = dict_ret_cam['pred_rotation_mat_inv'] 119 | cam_rot_pred[0][0][1] *= -1. 120 | cam_rot_pred[0][0][2] *= -1. 121 | 122 | cam_rot_pred[0][2][1] *= -1. 123 | cam_rot_pred[0][2][2] *= -1. 124 | 125 | cam_rot_pred[0][1][0] *= -1. 126 | 127 | tmp = copy.deepcopy(cam_rot_pred[0][2][:]) 128 | cam_rot_pred[0][2][:] = cam_rot_pred[0][1][:] 129 | cam_rot_pred[0][1][:] = tmp 130 | 131 | data['obj_rot_mat'] = cam_rot_pred 132 | 133 | data['trans_mat_right'] = dict_ret_cam['pred_trans_mat'] 134 | 135 | ret_dict = model(data) 136 | img_slices_rec = ret_dict['slices_rec'] 137 | img_input = data['img_input'] 138 | 139 | show_imgs(img_input, img_slices_rec, dir_output, idx, shape_id) 140 | -------------------------------------------------------------------------------- /gen_slices/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
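# Note (added comment): the star import above is relied on to provide LPIPS,
# NLayerDiscriminator, weights_init, hinge_d_loss, vanilla_d_loss and
# adopt_weight from the external taming-transformers package; they are used
# below without explicit imports.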
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | --------------------------------------------------------------------------------