├── README.md ├── dataset ├── 3dfront │ ├── ali3dfront.py │ ├── latent_transforms.py │ ├── train_512_512_128.txt │ └── tsdf_transforms.py ├── __init__.py └── dataset_base.py ├── factory ├── dataset_factory.py ├── model_factory.py └── trainer_factory.py ├── main ├── test.py └── train.py ├── model ├── __init__.py ├── auto_encoder │ ├── ms_vqgan │ │ ├── blocks.py │ │ ├── discriminator.py │ │ ├── encoder_decoder.py │ │ ├── loss.py │ │ ├── lr_scheduler.py │ │ ├── ms_tsdf_pvqgan_new.py │ │ └── quantize.py │ └── sketch_vae │ │ └── sketch_VAE.py ├── diffuser │ ├── pipelines │ │ ├── pipeline_ddim.py │ │ └── pipeline_ddpm.py │ ├── schedulers │ │ ├── noise_schedule.py │ │ ├── scheduling_ddim.py │ │ └── scheduling_ddpm.py │ └── utils.py ├── model_base.py ├── ms_ldm │ ├── blocks │ │ ├── blk_modules.py │ │ ├── blk_wrapper.py │ │ ├── model_utils.py │ │ └── sparse_blk_modules.py │ ├── multiscale_latent_diffusion.py │ ├── sketch_encoder.py │ ├── spatial_transformer │ │ ├── spatial_transformer_2d.py │ │ ├── spatial_transformer_3d.py │ │ ├── spatial_transformer_3d_sparse.py │ │ └── spatial_transformer_bev_sparse.py │ └── unet_model.py └── utils │ ├── ema.py │ ├── global_mapper.py │ └── torch_sparse_utils.py ├── readme ├── method.png └── teaser.png ├── requirements.txt ├── sketch_samples └── sketch_1.png ├── trainer ├── __init__.py ├── cascaded_ldm_trainer.py ├── sketchVAE_trainner.py ├── trainer_base.py └── tsdf_pvqgan_trainer.py └── utils ├── .DS_Store ├── __init__.py ├── config ├── Configuration.py └── samples │ ├── cascaded_ldm_sketch_cond │ ├── dataset │ │ └── ali3dfront.yaml │ ├── model │ │ ├── pyramid_occ_denoiser.yaml │ │ ├── sketch_vae.yaml │ │ └── tsdf_vqgan_retrain.yaml │ ├── readme.md │ └── root_config.yaml │ ├── cascaded_ldm_ucond │ ├── dataset │ │ └── ali3dfront.yaml │ ├── model │ │ ├── pyramid_occ_denoiser.yaml │ │ └── tsdf_vqgan_retrain.yaml │ ├── readme.md │ └── root_config.yaml │ ├── sketch_VAE │ ├── dataset │ │ └── ali3dfront.yaml │ ├── model │ │ └── sketch_vae.yaml │ ├── readme.md │ └── root_config.yaml │ └── tsdf_gumbel_ms_vqgan │ ├── dataset │ └── ali3dfront.yaml │ ├── model │ └── tsdf_vqgan.yaml │ ├── readme.md │ └── root_config.yaml ├── diffusion_monitor.py ├── graphics_utils.py ├── logger ├── basic_logger.py └── dummy_logger.py ├── torch_distributed_config.py └── visualize_occ.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Updates 3 | [2024-10-05] Checkpoints added. 4 | 5 | [2024-10-05] More concrete instructions added. 6 | 7 | # Overview 8 | This is the official implementation of our paper: 9 | 10 | **[CVPR2024] DiffInDScene: Diffusion-based High-Quality 3D Indoor Scene Generation**\ 11 | *Xiaoliang Ju\*, Zhaoyang Huang\*, Yijin Li, Guofeng Zhang, Yu Qiao, Hongsheng Li* 12 | 13 | [[paper]](https://openaccess.thecvf.com/content/CVPR2024/papers/Ju_DiffInDScene_Diffusion-based_High-Quality_3D_Indoor_Scene_Generation_CVPR_2024_paper.pdf)[[sup]](https://openaccess.thecvf.com/content/CVPR2024/supplemental/Ju_DiffInDScene_Diffusion-based_High-Quality_CVPR_2024_supplemental.pdf)[[arXiv]](https://arxiv.org/abs/2306.00519)[[project page]](https://akirahero.github.io/diffindscene/) 14 | 15 | 16 | DiffInDScene generates large indoor scene with a coarse-to-fine fashion: 17 | ![teaser](readme/teaser.png) 18 | which consists of a multi-scale PatchVQGAN for occupancy encoding and a cascaded sparse diffusion model. 
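To make the coarse-to-fine idea concrete, the snippet below is a minimal, purely illustrative sketch of a cascaded sampling loop. All names and shapes here are hypothetical and are not this repo's API; the actual pipeline lives in `model/ms_ldm` and is driven by the configs under `utils/config/samples/`.

```python
# Illustrative toy cascade: denoise a coarse 3D latent grid, upsample, refine.
# Nothing here mirrors the repo's real classes; it only sketches the control flow.
import torch
import torch.nn.functional as F

def toy_denoise(latent: torch.Tensor, steps: int = 4) -> torch.Tensor:
    # Stand-in for a full diffusion sampling loop at one resolution level.
    for _ in range(steps):
        latent = latent - 0.1 * torch.randn_like(latent)
    return latent

latent = torch.randn(1, 8, 16, 16, 16)   # coarsest-level 3D latent (toy shape)
for level in range(3):                   # three cascade levels, coarse to fine
    latent = toy_denoise(latent)
    if level < 2:                         # upsample before the next, finer level
        latent = F.interpolate(latent, scale_factor=2, mode="nearest")
# The finest latent would finally be decoded back to a TSDF/occupancy volume
# by the first-stage (PatchVQGAN) decoder.
```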
19 | ![method](readme/method.png) 20 | 21 | This repo provides or will provide: 22 | 23 | - [x] code for data processing 24 | - [x] code for inference 25 | - [x] code for training 26 | - [x] checkpoints 27 | - [ ] extension to other datasets 28 | 29 | 30 | 31 | # Dependency 32 | Our sparse diffusion is implemented based on [TorchSparse](https://github.com/mit-han-lab/torchsparse). Since it is still under rapid development, we provide the commit hash of the version we used: ``1a10fda15098f3bf4fa2d01f8bee53e85762abcf``. 33 | 34 | The main codebases our framework builds on include 35 | [VQGAN](https://github.com/CompVis/taming-transformers), 36 | [VQ-VAE-2](https://github.com/rosinality/vq-vae-2-pytorch), and 37 | [Diffusers](https://github.com/huggingface/diffusers); we only merged the necessary parts into our repo to avoid extra code dependencies. 38 | 39 | We employ [DreamSpace](https://ybbbbt.com/publication/dreamspace/) to texture the generated meshes. You can also substitute it with other similar texturing tools. 40 | 41 | 42 | # Environment Setup 43 | ## Step 1. Create a conda environment 44 | ```shell 45 | conda create -n diffindscene python=3.9 46 | conda activate diffindscene 47 | ``` 48 | ## Step 2. Install dependencies via pip 49 | ``` 50 | pip install -r requirements.txt 51 | ``` 52 | ## Step 3. Set up the torchsparse library 53 | ```shell 54 | # for now, we do not support the newest version of torchsparse 55 | # please compile it from source. 56 | 57 | git clone git@github.com:mit-han-lab/torchsparse.git 58 | cd torchsparse 59 | git checkout 1a10fda15098f3bf4fa2d01f8bee53e85762abcf 60 | python setup.py install 61 | 62 | ``` 63 | 64 | # Inference with Checkpoints 65 | 66 | ## Step 1: Download checkpoints 67 | Download the checkpoints [here](https://huggingface.co/akirahero/diffindscene/tree/main/ckpt). Put all checkpoints in the folder `ckpt`. 68 | 69 | ## Step 2: Run the inference script 70 | 71 | ### For unconditional generation 72 | ``` 73 | conda activate diffindscene 74 | export PYTHONPATH=${PATH_TO_DIFFINDSCENE}:${PYTHONPATH} 75 | 76 | # unconditional generation 77 | python main/test.py --cfg_dir utils/config/samples/cascaded_ldm_ucond 78 | 79 | ``` 80 | The results will be saved in the `output` folder. 81 | 82 | ### For sketch-conditioned generation 83 | ``` 84 | conda activate diffindscene 85 | export PYTHONPATH=${PATH_TO_DIFFINDSCENE}:${PYTHONPATH} 86 | 87 | # sketch-conditioned generation 88 | python main/test.py --cfg_dir utils/config/samples/cascaded_ldm_sketch_cond 89 | ``` 90 | The results will be saved in the `output` folder. 91 | 92 | More sketch images can be downloaded from [here](https://huggingface.co/akirahero/diffindscene/blob/main/sketch_samples.tar.gz). 93 | 94 | 95 | # Prepare the Dataset 96 | 97 | We mainly use [3D-FRONT](https://tianchi.aliyun.com/specials/promotion/alibaba-3d-scene-dataset) as our dataset. 98 | 99 | The [code for data processing](https://github.com/AkiraHero/3dfront_proc) is developed based on the repos [BlenderProc-3DFront](https://github.com/yinyunie/BlenderProc-3DFront) and [SDFGen](https://github.com/christopherbatty/SDFGen). 100 | 101 | The pipeline mainly consists of the following steps: 102 | * Extract resources from the original dataset and join them into a scene. 103 | * Use Blender to remesh the scene into a watertight mesh. 104 | * Generate the SDF of the scene.
105 | * Compress *.sdf to *.npz 106 | 107 | Example scripts: 108 | ```shell 109 | # generate watertight meshes 110 | blenderproc run examples/datasets/front_3d_with_improved_mat/process_3dfront.py ${PATH-TO-3D-FUTURE-model} ${PATH-TO-3D-FRONT-texture} ${MESH_OUT_FOLDER} 111 | 112 | # generate SDF for every mesh 113 | sh examples/datasets/front_3d_with_improved_mat/sdf_gen.sh ${MESH_OUT_FOLDER} ${PATH-TO-SDFGen} 114 | 115 | # compress *.sdf to *.npz 116 | python examples/datasets/front_3d_with_improved_mat/npz_tsdf.py ${MESH_OUT_FOLDER} ${NPZ_OUT_DIR} 117 | ``` 118 | 119 | 120 | 121 | # Training from Scratch 122 | Every part of our model corresponds to an individual configuration folder located in `utils/config/samples/`, with an instruction file `readme.md`. 123 | 124 | 125 | 126 | ## The first-stage model: PatchVQGAN 127 | Training script: 128 | ``` 129 | python main/train.py utils/config/samples/tsdf_gumbel_ms_vqgan 130 | ``` 131 | 132 | 133 | 134 | Testing script: 135 | ``` 136 | python main/test.py utils/config/samples/tsdf_gumbel_ms_vqgan 137 | ``` 138 | and the latents will be saved in your designated output path. 139 | 140 | 141 | ## [Optional] Sketch VAE for conditional generation 142 | ``` 143 | python main/train.py utils/config/samples/sketch_VAE 144 | ``` 145 | 146 | ## Cascaded Latent Diffusion 147 | The cascaded diffusion consists of three levels, as described in our paper; each level can be trained individually by setting the "level" variable in `config/samples/cascaded_ldm/model/pyramid_occ_denoiser.yaml`. 148 | 149 | The training script is 150 | 151 | ``` 152 | python main/train.py --cfg_dir utils/config/samples/cascaded_ldm 153 | ``` 154 | and the inference script is 155 | ``` 156 | python main/test.py --cfg_dir utils/config/samples/cascaded_ldm 157 | ``` 158 | 159 | # Citation 160 | ``` 161 | @inproceedings{ju2024diffindscene, 162 | title={DiffInDScene: Diffusion-based High-Quality 3D Indoor Scene Generation}, 163 | author={Ju, Xiaoliang and Huang, Zhaoyang and Li, Yijin and Zhang, Guofeng and Qiao, Yu and Li, Hongsheng}, 164 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 165 | pages={4526--4535}, 166 | year={2024} 167 | } 168 | ``` -------------------------------------------------------------------------------- /dataset/3dfront/ali3dfront.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | import logging 5 | import torch 6 | 7 | from torch.utils.data import DataLoader 8 | from collections import OrderedDict 9 | 10 | from dataset.dataset_base import DatasetBase 11 | from utils.config.Configuration import default 12 | 13 | from .tsdf_transforms import ToTensor, RandomTransformSpace, Compose 14 | from .latent_transforms import RandomTransform, SimpleCrop 15 | 16 | 17 | class Ali3DFront(DatasetBase): 18 | support_data_contents = ["tsdf", "latent", "sketch"] 19 | 20 | def __init__(self, config): 21 | super(Ali3DFront, self).__init__() 22 | self._batch_size = config.paras.batch_size 23 | self._shuffle = config.paras.shuffle 24 | self._num_workers = config.paras.num_workers 25 | self.data_root = config.paras.data_root 26 | self.data_split_file = config.paras.data_split_file 27 | 28 | self.version = default(config.paras, "version", "") 29 | self.load_content = default(config.paras, "load_content", ["tsdf"]) 30 | self.latent_dir = default(config.paras, "latent_dir", None) 31 | self.latent_scale = default(config.paras, "latent_scale", None) 32 |
self.level_config = default(config.paras, "level_config", None) 33 | self.voxel_dim = default(config.paras, "voxel_dim", [256, 256, 256]) 34 | 35 | self.batch_collate_function_name = default( 36 | config.paras, "batch_collate_func", "batch_collate_fn" 37 | ) 38 | self.designated_batch_collate_function = self.__getattribute__( 39 | self.batch_collate_function_name 40 | ) 41 | 42 | self.gen_sketch = False 43 | self.config_load_content() 44 | 45 | self.tsdf_cashe = OrderedDict() 46 | self.max_cashe = 1000 47 | 48 | self.transform_list = default(config.paras, "transform", []) 49 | self.set_transform() 50 | 51 | self.mode = "train" 52 | if "mode" in config.paras: 53 | self.set_mode(config.paras.mode) 54 | 55 | self.mode_indices = {"train": [], "test": [], "val": []} 56 | self.scene_id_list, self.scene_path_dict = self.load_scene_id_list( 57 | self.data_root 58 | ) 59 | self.get_train_val_split() 60 | 61 | def config_load_content(self): 62 | for i in self.load_content: 63 | assert i in self.support_data_contents 64 | if "latent" in self.load_content: 65 | latent_files = os.listdir(self.latent_dir) 66 | self.latent_file_paths = {} 67 | for i in latent_files: 68 | file_path = os.path.join(self.latent_dir, i) 69 | scene_id = i.split(".")[0] 70 | self.latent_file_paths[scene_id] = file_path 71 | if "sketch" in self.load_content: 72 | self.gen_sketch = True 73 | 74 | def set_transform(self): 75 | if hasattr(self, "transforms"): 76 | del self.transforms 77 | # set transform for data producing/augmentation 78 | transform = [] 79 | # highest resolution 80 | voxel_dim = self.voxel_dim 81 | voxel_size = 0.04 # max voxel 82 | 83 | random_rotation = True 84 | random_translation = True 85 | if "train" != self.mode: 86 | random_rotation = False 87 | random_translation = False 88 | paddingXY = 0.12 89 | paddingZ = 0.12 90 | epochs = 999 91 | 92 | transform += [ 93 | ToTensor(), 94 | ] 95 | 96 | logging.info("[Dataset]transform list:" + str(self.transform_list)) 97 | if len(self.transform_list) == 0: 98 | logging.warning("[Dataset]No transform added for dataset!!") 99 | 100 | elif self.transform_list[0] == "randomcrop": 101 | transform += [ 102 | RandomTransformSpace( 103 | voxel_dim, 104 | voxel_size, 105 | random_rotation, 106 | random_translation, 107 | paddingXY, 108 | paddingZ, 109 | max_epoch=epochs, 110 | using_camera_pose=False, 111 | random_trans_method="occ_center", 112 | random_rot_method="right-angle", 113 | ), 114 | ] 115 | elif self.transform_list[0] == "simpletrans": 116 | transform += [ 117 | RandomTransform( 118 | voxel_dim, 119 | voxel_size, 120 | random_rotation, 121 | random_translation, 122 | paddingXY, 123 | paddingZ, 124 | max_epoch=epochs, 125 | gen_bev_sketch=self.gen_sketch, 126 | ) 127 | ] 128 | else: 129 | raise NotImplementedError 130 | 131 | if "simplecrop" in self.transform_list: 132 | transform += [SimpleCrop(voxel_dim)] 133 | 134 | self.transforms = Compose(transform) 135 | 136 | def set_level(self, level): 137 | logging.info( 138 | "[DataSet]Ali3DFront changed model level related config:{}".format(level) 139 | ) 140 | assert level in ["first", "second", "third"] 141 | assert self.level_config is not None 142 | level_config = self.level_config[level] 143 | self.load_content = level_config["load_content"] 144 | self.transform_list = level_config["transform"] 145 | self.config_load_content() 146 | self.set_transform() 147 | 148 | def read_scene_list(self, file): 149 | with open(file, "r") as f: 150 | scenes = f.readlines() 151 | scenes = [i.strip("\n") for i in scenes] 152 | 
return scenes 153 | 154 | def get_train_val_split(self): 155 | cur_dir = os.path.dirname(__file__) 156 | train_list = self.read_scene_list( 157 | os.path.join(cur_dir, self.data_split_file["train"]) 158 | ) 159 | 160 | # no need for generative model 161 | test_list = [] 162 | val_list = [] 163 | 164 | scene_mode_dict = {} 165 | for i in train_list: 166 | scene_mode_dict[i] = "train" 167 | for i in val_list: 168 | scene_mode_dict[i] = "val" 169 | for i in test_list: 170 | scene_mode_dict[i] = "test" 171 | 172 | for inx, i in enumerate(self.scene_id_list): 173 | scene = i 174 | if scene in self.scene_path_dict and scene in scene_mode_dict: 175 | mode = scene_mode_dict[scene] 176 | self.mode_indices[mode].append(inx) 177 | 178 | def set_mode(self, mode): 179 | self.tsdf_cashe.clear() 180 | assert mode in ["train", "val", "test"] 181 | self.mode = mode 182 | 183 | def map_idx(self, idx): 184 | mode_str = self.get_mode_indices_label() 185 | return self.mode_indices[mode_str][idx] 186 | 187 | def get_mode_indices_label(self): 188 | mode_str = self.mode 189 | if self.mode == "test_scene": 190 | mode_str = "test" 191 | return mode_str 192 | 193 | def load_scene_id_list(self, data_dir): 194 | scene_id_list = [ 195 | i for i in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, i)) 196 | ] 197 | # scene path 198 | scene_path_list = {} 199 | for scene in scene_id_list: 200 | p = os.path.join(data_dir, scene) 201 | scene_path_list[scene] = p 202 | return scene_id_list, scene_path_list 203 | 204 | def read_latent(self, scene): 205 | lat_file = self.latent_file_paths[scene] 206 | with open(lat_file, "rb") as f: 207 | lat_dict = pickle.load(f) 208 | return lat_dict 209 | 210 | def read_scene_volumes(self, scene): 211 | def switch_dim(full_tsdf_dict): 212 | # switch dims to x,y,z 213 | tsdf = full_tsdf_dict["tsdf"] 214 | tsdf = tsdf.transpose(0, 2, 1) 215 | origin = np.array( 216 | [ 217 | full_tsdf_dict["origin"][0], 218 | full_tsdf_dict["origin"][2], 219 | full_tsdf_dict["origin"][1], 220 | ] 221 | ) 222 | voxel_size = full_tsdf_dict["voxel_size"] 223 | return {"tsdf": tsdf, "origin": origin, "voxel_size": voxel_size} 224 | 225 | if scene not in self.tsdf_cashe.keys(): 226 | if len(self.tsdf_cashe) > self.max_cashe: 227 | self.tsdf_cashe.popitem(last=False) 228 | scene_path = self.scene_path_dict[scene] 229 | 230 | self.tsdf_cashe[scene] = {} 231 | 232 | if self.version == "v2": 233 | filename = "tsdf_v2.npz" 234 | else: 235 | filename = "tsdf.npz" 236 | full_tsdf_dict = np.load(os.path.join(scene_path, filename)) 237 | full_tsdf_dict = switch_dim(full_tsdf_dict) 238 | self.tsdf_cashe[scene]["gt"] = full_tsdf_dict["tsdf"] 239 | self.tsdf_cashe[scene]["gt_origin"] = full_tsdf_dict["origin"] 240 | self.tsdf_cashe[scene]["voxel_size"] = full_tsdf_dict["voxel_size"] 241 | self.tsdf_cashe[scene]["crop_dim"] = torch.tensor(self.voxel_dim) 242 | self.tsdf_cashe[scene]["data_type"] = "tsdf" 243 | return self.tsdf_cashe[scene] 244 | 245 | def __len__(self): 246 | return len(self.mode_indices[self.get_mode_indices_label()]) 247 | 248 | def __getitem__(self, idx): 249 | return self.get_general_item(idx) 250 | 251 | def get_general_item(self, idx): 252 | idx = self.map_idx(idx) 253 | scene = self.scene_id_list[idx] 254 | 255 | items = { 256 | "scene": scene, 257 | } 258 | 259 | if "tsdf" in self.load_content: 260 | tsdf = self.read_scene_volumes(scene) 261 | items.update({"tsdf": tsdf, "vol_type": tsdf["data_type"]}) 262 | 263 | if "latent" in self.load_content: 264 | latent = self.read_latent(scene) 265 
| items.update( 266 | {"latent": latent, "latent_scale": torch.tensor(self.latent_scale)} 267 | ) 268 | try: 269 | if self.transforms is not None: 270 | items = self.transforms(items) 271 | except: 272 | logging.error("[Dataset] Transform failure on scene:{}".format(scene)) 273 | raise RuntimeError 274 | return items 275 | 276 | def get_data_loader(self, distributed=False): 277 | if distributed: 278 | sampler = torch.utils.data.distributed.DistributedSampler(self) 279 | else: 280 | sampler = None 281 | 282 | bfn = self.designated_batch_collate_function 283 | 284 | data_loader = DataLoader( 285 | dataset=self, 286 | batch_size=self._batch_size, 287 | shuffle=(sampler is None) and self._shuffle, 288 | num_workers=self._num_workers, 289 | collate_fn=bfn, 290 | pin_memory=True, 291 | drop_last=False, 292 | sampler=sampler, 293 | ) 294 | return data_loader 295 | 296 | @staticmethod 297 | def batch_collate_fn_for_lat(batch_list, _unused=False): 298 | assert len(batch_list) == 1 # for they hv diff size 299 | batch_dict = {} 300 | batch_dict["scene"] = [] 301 | batch_dict["gt_tsdf"] = [] 302 | batch_dict["gt_origin"] = [] 303 | batch_dict["voxel_size"] = [] 304 | for i in batch_list: 305 | batch_dict["scene"].append(i["scene"]) 306 | batch_dict["gt_tsdf"].append(i["tsdf"]["gt"].unsqueeze(0)) 307 | batch_dict["gt_origin"].append(i["tsdf"]["gt_origin"]) 308 | batch_dict["voxel_size"].append(i["tsdf"]["voxel_size"]) 309 | for k in ["gt_tsdf", "gt_origin", "voxel_size"]: 310 | batch_dict[k] = torch.stack(batch_dict[k]) 311 | return batch_dict 312 | 313 | @staticmethod 314 | def batch_collate_latent_code(batch_list, _unused=False): 315 | batch_dict = {} 316 | keys = batch_list[0].keys() 317 | keep_list_key = ["scene", "latent_scale"] 318 | for k in keys: 319 | batch_dict[k] = [] 320 | for i in batch_dict: 321 | for j in batch_list: 322 | batch_dict[i].append(j[i]) 323 | for k in batch_dict: 324 | if k not in keep_list_key: 325 | batch_dict[k] = torch.cat(batch_dict[k]) 326 | return batch_dict 327 | 328 | @staticmethod 329 | def batch_collate_fn(batch_list, _unused=False): 330 | batch_dict = {} 331 | batch_dict["scene"] = [] 332 | batch_dict["gt_tsdf"] = [] 333 | batch_dict["vol_origin_partial"] = [] 334 | batch_dict["voxel_size"] = batch_list[0]["tsdf"]["voxel_size"] 335 | 336 | for i in batch_list: 337 | batch_dict["scene"].append(i["scene"]) 338 | batch_dict["gt_tsdf"].append(i["partial_tsdf"]["gt"]) 339 | batch_dict["vol_origin_partial"].append(i["vol_origin_partial"]) 340 | 341 | batch_dict["gt_tsdf"] = torch.stack(batch_dict["gt_tsdf"]).unsqueeze( 342 | 1 343 | ) # chn = 1 344 | batch_dict["vol_type"] = batch_list[0]["vol_type"] 345 | batch_dict["vol_origin_partial"] = torch.stack(batch_dict["vol_origin_partial"]) 346 | return batch_dict 347 | 348 | @staticmethod 349 | def load_data_to_gpu(batch_dict, device=None): 350 | batch_dict_gpu = {} 351 | for i in batch_dict: 352 | if isinstance(batch_dict[i], torch.Tensor): 353 | batch_dict_gpu[i] = batch_dict[i].to(device) 354 | else: 355 | batch_dict_gpu[i] = batch_dict[i] 356 | return batch_dict_gpu 357 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import traceback 4 | 5 | path = os.path.dirname(os.path.abspath(__file__)) 6 | py_list = [] 7 | for root, dirs, files in os.walk(path, topdown=False): 8 | for name in files: 9 | if name.endswith(".py") and not 
name.endswith("__init__.py"): 10 | rel_dir = os.path.relpath(root, path) 11 | if rel_dir != ".": 12 | rel_file = os.path.join(rel_dir, name) 13 | else: 14 | rel_file = name 15 | py_list.append(rel_file) 16 | for py in py_list: 17 | mod_name = ".".join([__name__, *(py.split("/"))]) 18 | mod_name = mod_name[:-3] 19 | try: 20 | mod = __import__(mod_name, fromlist=[mod_name]) 21 | except ModuleNotFoundError as e: 22 | logging.debug(traceback.format_exc()) 23 | logging.debug("Fail to import submodule:{}".format(mod_name)) 24 | continue 25 | classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)] 26 | for cls in classes: 27 | if "dataset" in str(cls): 28 | globals()[cls.__name__] = cls 29 | -------------------------------------------------------------------------------- /dataset/dataset_base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class DatasetBase(Dataset): 5 | def __init__(self): 6 | super(DatasetBase, self).__init__() 7 | self.mode ='train' 8 | 9 | def __getitem__(self, index): 10 | raise NotImplementedError 11 | 12 | def __len__(self): 13 | raise NotImplementedError 14 | 15 | def get_data_loader(self, distributed=False): 16 | raise NotImplementedError 17 | 18 | def set_mode(self, mode): 19 | assert mode in ['train', 'val', 'test'] 20 | self.mode = mode 21 | 22 | def get_mode(self): 23 | return self.mode 24 | -------------------------------------------------------------------------------- /factory/dataset_factory.py: -------------------------------------------------------------------------------- 1 | from dataset import * 2 | 3 | 4 | class DatasetFactory: 5 | singleton_dataset = None 6 | def __init__(self): 7 | pass 8 | 9 | @staticmethod 10 | def get_dataset(data_config): 11 | class_name = data_config['dataset_class'] 12 | all_classes = DatasetBase.__subclasses__() 13 | for cls in all_classes: 14 | if cls.__name__ == class_name: 15 | return cls(data_config['config_file']['expanded']) 16 | raise TypeError(f'no class named \'{class_name}\' found in dataset folder') 17 | 18 | @classmethod 19 | def get_singleton_dataset(cls, data_config=None): 20 | if data_config is None: 21 | return cls.singleton_dataset 22 | if cls.singleton_dataset is None: 23 | cls.singleton_dataset = cls.get_dataset(data_config) 24 | return cls.singleton_dataset 25 | -------------------------------------------------------------------------------- /factory/model_factory.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from utils.config.Configuration import Configuration 3 | 4 | class ModelFactory: 5 | def __init__(self): 6 | pass 7 | 8 | @staticmethod 9 | def get_model(model_config): 10 | class_name, paras = Configuration.find_dict_node(model_config, 'model_class') 11 | all_classes = ModelBase.__subclasses__() 12 | for cls in all_classes: 13 | if cls.__name__ == class_name: 14 | return cls(model_config['config_file']['expanded']) # todo not perfect 15 | raise TypeError(f'no class named \'{class_name}\' found in model folder') 16 | 17 | 18 | -------------------------------------------------------------------------------- /factory/trainer_factory.py: -------------------------------------------------------------------------------- 1 | from trainer import * 2 | 3 | 4 | class TrainerFactory: 5 | def __init__(self): 6 | pass 7 | 8 | @staticmethod 9 | def get_trainer(trainer_config): 10 | class_name = trainer_config['trainer_class'] 11 | 
all_classes = TrainerBase.__subclasses__() 12 | for cls in all_classes: 13 | if cls.__name__ == class_name: 14 | return cls(trainer_config) 15 | raise TypeError(f'no class named \'{class_name}\' found in trainer folder') -------------------------------------------------------------------------------- /main/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import traceback 3 | import subprocess 4 | from utils.config.Configuration import Configuration 5 | from factory.model_factory import ModelFactory 6 | from factory.dataset_factory import DatasetFactory 7 | from factory.trainer_factory import TrainerFactory 8 | 9 | 10 | 11 | 12 | 13 | if __name__ == '__main__': 14 | git_version = subprocess.check_output(["git", 'rev-parse', 'HEAD']).strip().decode() 15 | logging.info(f'Your program version is {git_version}') 16 | try: 17 | # manage config 18 | logging_logger = logging.getLogger() 19 | logging_logger.setLevel(logging.NOTSET) 20 | ch = logging.StreamHandler() 21 | formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s') 22 | ch.setFormatter(formatter) 23 | 24 | config = Configuration() 25 | args = config.get_shell_args_train() 26 | config.load_config(args.cfg_dir) 27 | config.overwrite_config_by_shell_args(args) 28 | 29 | # instantiating all modules by non-singleton factory 30 | test_dataset_config = config.dataset_config 31 | test_dataset_config['config_file']['expanded'].paras.shuffle = False 32 | dataset = DatasetFactory.get_dataset(test_dataset_config) 33 | 34 | model_config = config.model_config 35 | model_config.config_file.expanded.update({'mode': "testing"}) 36 | model = ModelFactory.get_model(model_config) 37 | trainer = TrainerFactory.get_trainer(config.testing_config) 38 | if config.extra_config['distributed']: 39 | logging.info("using distributed training......") 40 | trainer.config_distributed_computing(launcher=config.extra_config['launcher'], 41 | tcp_port=config.extra_config['tcp_port'], 42 | local_rank=config.extra_config['local_rank']) 43 | 44 | trainer.set_model(model) 45 | trainer.set_test_dataset(dataset) 46 | 47 | # trainer.set_logger(logger) 48 | 49 | # load checkpoint 50 | if args.check_point_file is not None: 51 | trainer.load_state(args.check_point_file) 52 | elif "ckpt" in config.testing_config: 53 | trainer.load_state(config.testing_config.ckpt) 54 | else: 55 | logging.warning("No Checkpoint provided for this test!") 56 | 57 | logging.info("Preparation done! 
Trainer run!") 58 | trainer.run_test() 59 | 60 | except Exception as e: 61 | logging.exception(traceback.format_exc()) 62 | exit(-1) 63 | -------------------------------------------------------------------------------- /main/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import traceback 4 | import subprocess 5 | import copy 6 | from utils.config.Configuration import Configuration 7 | from factory.model_factory import ModelFactory 8 | from factory.dataset_factory import DatasetFactory 9 | from factory.trainer_factory import TrainerFactory 10 | from utils.logger.basic_logger import BasicLogger 11 | from utils.logger.dummy_logger import DummyLogger 12 | 13 | 14 | 15 | 16 | 17 | if __name__ == '__main__': 18 | git_version = subprocess.check_output(["git", 'rev-parse', 'HEAD']).strip().decode() 19 | logging.info(f'Your program version is {git_version}') 20 | try: 21 | # manage config 22 | logging_logger = logging.getLogger() 23 | logging_logger.setLevel(logging.NOTSET) 24 | ch = logging.StreamHandler() 25 | formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s') 26 | ch.setFormatter(formatter) 27 | 28 | config = Configuration() 29 | args = config.get_shell_args_train() 30 | config.load_config(args.cfg_dir) 31 | config.overwrite_config_by_shell_args(args) 32 | 33 | # instantiating all modules by non-singleton factory 34 | dataset = DatasetFactory.get_dataset(config.dataset_config) 35 | 36 | 37 | if config.training_config.enable_val: 38 | val_dataset_config = copy.deepcopy(config.dataset_config) 39 | 40 | val_dataset_config['config_file']['expanded'].paras.for_train = False 41 | val_dataset_config['config_file']['expanded'].paras.shuffle = False 42 | val_dataset_config['config_file']['expanded'].paras.mode = 'val' 43 | val_dataset = DatasetFactory.get_dataset(val_dataset_config) 44 | else: 45 | val_dataset = None 46 | 47 | 48 | trainer = TrainerFactory.get_trainer(config.training_config) 49 | 50 | if config.extra_config['distributed']: 51 | logging.info("using distributed training......") 52 | trainer.config_distributed_computing(launcher=config.extra_config['launcher'], 53 | tcp_port=config.extra_config['tcp_port'], 54 | local_rank=config.extra_config['local_rank']) 55 | model = ModelFactory.get_model(config.model_config) 56 | 57 | logger = None 58 | if args.log_dir is not None : 59 | config._logging_config['log_dir'] = args.log_dir 60 | 61 | if (not config.extra_config['distributed']) or (os.environ['RANK'] == str(0)): 62 | logger = BasicLogger.get_logger(config) 63 | logger.log_config(config) 64 | else: 65 | logger = DummyLogger.get_logger(config) 66 | trainer.set_model(model) 67 | trainer.set_dataset(dataset) 68 | if config.training_config.enable_val: 69 | trainer.set_val_dataset(val_dataset) 70 | trainer.set_logger(logger) 71 | if args.check_point_file is not None: 72 | trainer.load_state(args.check_point_file) 73 | elif "ckpt" in config.training_config: 74 | trainer.load_state(config.training_config.ckpt) 75 | logging.info("Preparation done! 
Trainer run!") 76 | trainer.run() 77 | if args.screen_log is not None: 78 | logger.log_model_params(model, force=True) 79 | logger.copy_screen_log(args.screen_log) 80 | except Exception as e: 81 | if logger is not None: 82 | logger.log_model_params(model, force=True) 83 | logging.exception(traceback.format_exc()) 84 | exit(-1) 85 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import traceback 4 | 5 | path = os.path.dirname(os.path.abspath(__file__)) 6 | py_list = [] 7 | for root, dirs, files in os.walk(path, topdown=False): 8 | for name in files: 9 | if name.endswith(".py") and not name.endswith("__init__.py"): 10 | rel_dir = os.path.relpath(root, path) 11 | if rel_dir != ".": 12 | rel_file = os.path.join(rel_dir, name) 13 | else: 14 | rel_file = name 15 | py_list.append(rel_file) 16 | for py in py_list: 17 | mod_name = ".".join([__name__, *(py.split("/"))]) 18 | mod_name = mod_name[:-3] 19 | try: 20 | mod = __import__(mod_name, fromlist=[mod_name]) 21 | except ModuleNotFoundError as e: 22 | logging.error(traceback.format_exc()) 23 | logging.debug("Fail to import submodule:".format(mod_name)) 24 | continue 25 | classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)] 26 | for cls in classes: 27 | if "model" in str(cls): 28 | globals()[cls.__name__] = cls 29 | -------------------------------------------------------------------------------- /model/auto_encoder/ms_vqgan/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def nonlinearity(x): 6 | # swish 7 | return x * torch.sigmoid(x) 8 | 9 | 10 | def Normalize(in_channels): 11 | return torch.nn.GroupNorm( 12 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 13 | ) 14 | 15 | 16 | class Upsample(nn.Module): 17 | def __init__(self, in_channels, with_conv): 18 | super().__init__() 19 | self.with_conv = with_conv 20 | if self.with_conv: 21 | self.conv = torch.nn.Conv3d( 22 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 23 | ) 24 | 25 | def forward(self, x): 26 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 27 | if self.with_conv: 28 | x = self.conv(x) 29 | return x 30 | 31 | 32 | class Downsample(nn.Module): 33 | def __init__(self, in_channels, with_conv): 34 | super().__init__() 35 | self.with_conv = with_conv 36 | if self.with_conv: 37 | # no asymmetric padding in torch conv, must do it ourselves 38 | self.conv = torch.nn.Conv3d( 39 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 40 | ) 41 | 42 | def forward(self, x): 43 | if self.with_conv: 44 | pad = (1, 1, 1, 1, 1, 1) 45 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 46 | x = self.conv(x) 47 | else: 48 | x = torch.nn.functional.avg_pool3d(x, kernel_size=2, stride=2) 49 | return x 50 | 51 | 52 | class ResnetBlock(nn.Module): 53 | def __init__( 54 | self, 55 | *, 56 | in_channels, 57 | out_channels=None, 58 | conv_shortcut=False, 59 | dropout, 60 | temb_channels=512 61 | ): 62 | super().__init__() 63 | self.in_channels = in_channels 64 | out_channels = in_channels if out_channels is None else out_channels 65 | self.out_channels = out_channels 66 | self.use_conv_shortcut = conv_shortcut 67 | 68 | self.norm1 = Normalize(in_channels) 69 | self.conv1 = torch.nn.Conv3d( 70 | in_channels, out_channels, kernel_size=3, 
stride=1, padding=1 71 | ) 72 | if temb_channels > 0: 73 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 74 | self.norm2 = Normalize(out_channels) 75 | self.dropout = torch.nn.Dropout(dropout) 76 | self.conv2 = torch.nn.Conv3d( 77 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 78 | ) 79 | if self.in_channels != self.out_channels: 80 | if self.use_conv_shortcut: 81 | self.conv_shortcut = torch.nn.Conv3d( 82 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 83 | ) 84 | else: 85 | self.nin_shortcut = torch.nn.Conv3d( 86 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 87 | ) 88 | 89 | def forward(self, x, temb): 90 | h = x 91 | h = self.norm1(h) 92 | h = nonlinearity(h) 93 | h = self.conv1(h) 94 | 95 | if temb is not None: 96 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 97 | 98 | h = self.norm2(h) 99 | h = nonlinearity(h) 100 | h = self.dropout(h) 101 | h = self.conv2(h) 102 | 103 | if self.in_channels != self.out_channels: 104 | if self.use_conv_shortcut: 105 | x = self.conv_shortcut(x) 106 | else: 107 | x = self.nin_shortcut(x) 108 | 109 | return x + h 110 | 111 | 112 | class AttnBlock(nn.Module): 113 | def __init__(self, in_channels): 114 | super().__init__() 115 | self.in_channels = in_channels 116 | 117 | self.norm = Normalize(in_channels) 118 | self.q = torch.nn.Conv3d( 119 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 120 | ) 121 | self.k = torch.nn.Conv3d( 122 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 123 | ) 124 | self.v = torch.nn.Conv3d( 125 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 126 | ) 127 | self.proj_out = torch.nn.Conv3d( 128 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 129 | ) 130 | 131 | def forward(self, x): 132 | h_ = x 133 | h_ = self.norm(h_) 134 | q = self.q(h_) 135 | k = self.k(h_) 136 | v = self.v(h_) 137 | 138 | # compute attention 139 | b, c, h, w, l = q.shape 140 | q = q.reshape(b, c, h * w * l) 141 | q = q.permute(0, 2, 1) # b,hw,c 142 | k = k.reshape(b, c, h * w * l) # b,c,hw 143 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 144 | w_ = w_ * (int(c) ** (-0.5)) 145 | w_ = torch.nn.functional.softmax(w_, dim=2) 146 | 147 | # attend to values 148 | v = v.reshape(b, c, h * w * l) 149 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 150 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 151 | h_ = h_.reshape(b, c, h, w, l) 152 | 153 | h_ = self.proj_out(h_) 154 | 155 | return x + h_ 156 | -------------------------------------------------------------------------------- /model/auto_encoder/ms_vqgan/discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | def weights_init(m): 7 | classname = m.__class__.__name__ 8 | if classname.find("Conv") != -1: 9 | nn.init.normal_(m.weight.data, 0.0, 0.02) 10 | elif classname.find("BatchNorm") != -1: 11 | nn.init.normal_(m.weight.data, 1.0, 0.02) 12 | nn.init.constant_(m.bias.data, 0) 13 | 14 | 15 | class NLayerDiscriminator(nn.Module): 16 | """Defines a PatchGAN discriminator as in Pix2Pix 17 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 18 | """ 19 | 20 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 21 | """Construct a PatchGAN discriminator 22 | Parameters: 23 | input_nc (int) -- the number of channels in input images 
24 | ndf (int) -- the number of filters in the last conv layer 25 | n_layers (int) -- the number of conv layers in the discriminator 26 | norm_layer -- normalization layer 27 | """ 28 | super(NLayerDiscriminator, self).__init__() 29 | if not use_actnorm: 30 | norm_layer = nn.BatchNorm3d 31 | else: 32 | norm_layer = ActNorm 33 | if ( 34 | type(norm_layer) == functools.partial 35 | ): # no need to use bias as BatchNorm2d has affine parameters 36 | use_bias = norm_layer.func != nn.BatchNorm3d 37 | else: 38 | use_bias = norm_layer != nn.BatchNorm3d 39 | 40 | kw = 4 41 | padw = 1 42 | sequence = [ 43 | nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 44 | nn.LeakyReLU(0.2, True), 45 | ] 46 | nf_mult = 1 47 | nf_mult_prev = 1 48 | for n in range(1, n_layers): # gradually increase the number of filters 49 | nf_mult_prev = nf_mult 50 | nf_mult = min(2**n, 8) 51 | sequence += [ 52 | nn.Conv3d( 53 | ndf * nf_mult_prev, 54 | ndf * nf_mult, 55 | kernel_size=kw, 56 | stride=2, 57 | padding=padw, 58 | bias=use_bias, 59 | ), 60 | norm_layer(ndf * nf_mult), 61 | nn.LeakyReLU(0.2, True), 62 | ] 63 | 64 | nf_mult_prev = nf_mult 65 | nf_mult = min(2**n_layers, 8) 66 | sequence += [ 67 | nn.Conv3d( 68 | ndf * nf_mult_prev, 69 | ndf * nf_mult, 70 | kernel_size=kw, 71 | stride=1, 72 | padding=padw, 73 | bias=use_bias, 74 | ), 75 | norm_layer(ndf * nf_mult), 76 | nn.LeakyReLU(0.2, True), 77 | ] 78 | 79 | sequence += [ 80 | nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 81 | ] # output 1 channel prediction map 82 | self.main = nn.Sequential(*sequence) 83 | 84 | def forward(self, input): 85 | """Standard forward.""" 86 | return self.main(input) 87 | 88 | 89 | class ActNorm(nn.Module): 90 | def __init__( 91 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 92 | ): 93 | assert affine 94 | super().__init__() 95 | self.logdet = logdet 96 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 97 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 98 | self.allow_reverse_init = allow_reverse_init 99 | 100 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 101 | 102 | def initialize(self, input): 103 | with torch.no_grad(): 104 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 105 | mean = ( 106 | flatten.mean(1) 107 | .unsqueeze(1) 108 | .unsqueeze(2) 109 | .unsqueeze(3) 110 | .permute(1, 0, 2, 3) 111 | ) 112 | std = ( 113 | flatten.std(1) 114 | .unsqueeze(1) 115 | .unsqueeze(2) 116 | .unsqueeze(3) 117 | .permute(1, 0, 2, 3) 118 | ) 119 | 120 | self.loc.data.copy_(-mean) 121 | self.scale.data.copy_(1 / (std + 1e-6)) 122 | 123 | def forward(self, input, reverse=False): 124 | if reverse: 125 | return self.reverse(input) 126 | if len(input.shape) == 2: 127 | input = input[:, :, None, None] 128 | squeeze = True 129 | else: 130 | squeeze = False 131 | 132 | _, _, height, width = input.shape 133 | 134 | if self.training and self.initialized.item() == 0: 135 | self.initialize(input) 136 | self.initialized.fill_(1) 137 | 138 | h = self.scale * (input + self.loc) 139 | 140 | if squeeze: 141 | h = h.squeeze(-1).squeeze(-1) 142 | 143 | if self.logdet: 144 | log_abs = torch.log(torch.abs(self.scale)) 145 | logdet = height * width * torch.sum(log_abs) 146 | logdet = logdet * torch.ones(input.shape[0]).to(input) 147 | return h, logdet 148 | 149 | return h 150 | 151 | def reverse(self, output): 152 | if self.training and self.initialized.item() == 0: 153 | if not self.allow_reverse_init: 154 | raise 
RuntimeError( 155 | "Initializing ActNorm in reverse direction is " 156 | "disabled by default. Use allow_reverse_init=True to enable." 157 | ) 158 | else: 159 | self.initialize(output) 160 | self.initialized.fill_(1) 161 | 162 | if len(output.shape) == 2: 163 | output = output[:, :, None, None] 164 | squeeze = True 165 | else: 166 | squeeze = False 167 | 168 | h = output / self.scale - self.loc 169 | 170 | if squeeze: 171 | h = h.squeeze(-1).squeeze(-1) 172 | return h 173 | -------------------------------------------------------------------------------- /model/auto_encoder/ms_vqgan/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | from model.auto_encoder.ms_vqgan.blocks import ( 6 | ResnetBlock, 7 | AttnBlock, 8 | Downsample, 9 | Upsample, 10 | Normalize, 11 | nonlinearity, 12 | ) 13 | 14 | 15 | class Encoder(nn.Module): 16 | def __init__( 17 | self, 18 | *, 19 | ch, 20 | out_ch, 21 | ch_mult=(1, 2, 4, 8), 22 | num_res_blocks, 23 | attn_resolutions, 24 | dropout=0.0, 25 | resamp_with_conv=True, 26 | in_channels, 27 | resolution, 28 | z_channels, 29 | double_z=True, 30 | **ignore_kwargs 31 | ): 32 | super().__init__() 33 | self.ch = ch 34 | self.temb_ch = 0 35 | self.num_resolutions = len(ch_mult) 36 | self.num_res_blocks = num_res_blocks 37 | self.resolution = resolution 38 | self.in_channels = in_channels 39 | 40 | # downsampling 41 | self.conv_in = torch.nn.Conv3d( 42 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 43 | ) 44 | 45 | curr_res = resolution 46 | in_ch_mult = (1,) + tuple(ch_mult) 47 | self.down = nn.ModuleList() 48 | for i_level in range(self.num_resolutions): 49 | block = nn.ModuleList() 50 | attn = nn.ModuleList() 51 | block_in = ch * in_ch_mult[i_level] 52 | block_out = ch * ch_mult[i_level] 53 | for i_block in range(self.num_res_blocks): 54 | block.append( 55 | ResnetBlock( 56 | in_channels=block_in, 57 | out_channels=block_out, 58 | temb_channels=self.temb_ch, 59 | dropout=dropout, 60 | ) 61 | ) 62 | block_in = block_out 63 | if curr_res in attn_resolutions: 64 | attn.append(AttnBlock(block_in)) 65 | down = nn.Module() 66 | down.block = block 67 | down.attn = attn 68 | if i_level != self.num_resolutions - 1: 69 | down.downsample = Downsample(block_in, resamp_with_conv) 70 | curr_res = curr_res // 2 71 | self.down.append(down) 72 | 73 | # middle 74 | self.mid = nn.Module() 75 | self.mid.block_1 = ResnetBlock( 76 | in_channels=block_in, 77 | out_channels=block_in, 78 | temb_channels=self.temb_ch, 79 | dropout=dropout, 80 | ) 81 | self.mid.attn_1 = AttnBlock(block_in) 82 | self.mid.block_2 = ResnetBlock( 83 | in_channels=block_in, 84 | out_channels=block_in, 85 | temb_channels=self.temb_ch, 86 | dropout=dropout, 87 | ) 88 | 89 | # end 90 | self.norm_out = Normalize(block_in) 91 | self.conv_out = torch.nn.Conv3d( 92 | block_in, 93 | 2 * z_channels if double_z else z_channels, 94 | kernel_size=3, 95 | stride=1, 96 | padding=1, 97 | ) 98 | 99 | def forward(self, x): 100 | # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) 101 | 102 | # timestep embedding 103 | temb = None 104 | 105 | # downsampling 106 | hs = [self.conv_in(x)] 107 | for i_level in range(self.num_resolutions): 108 | for i_block in range(self.num_res_blocks): 109 | h = self.down[i_level].block[i_block](hs[-1], temb) 110 | if len(self.down[i_level].attn) > 0: 111 | h = self.down[i_level].attn[i_block](h) 112 | 
hs.append(h) 113 | if i_level != self.num_resolutions - 1: 114 | hs.append(self.down[i_level].downsample(hs[-1])) 115 | 116 | # middle 117 | h = hs[-1] 118 | h = self.mid.block_1(h, temb) 119 | h = self.mid.attn_1(h) 120 | h = self.mid.block_2(h, temb) 121 | 122 | # end 123 | h = self.norm_out(h) 124 | h = nonlinearity(h) 125 | h = self.conv_out(h) 126 | 127 | return h 128 | 129 | 130 | class Decoder_occ(nn.Module): 131 | def __init__( 132 | self, 133 | *, 134 | ch, 135 | out_ch, 136 | ch_mult=(1, 2, 4, 8), 137 | num_res_blocks, 138 | attn_resolutions, 139 | dropout=0.0, 140 | resamp_with_conv=True, 141 | in_channels, 142 | resolution, 143 | z_channels, 144 | give_pre_end=False, 145 | **ignorekwargs 146 | ): 147 | super().__init__() 148 | self.ch = ch 149 | self.temb_ch = 0 150 | self.num_resolutions = len(ch_mult) 151 | self.num_res_blocks = num_res_blocks 152 | self.resolution = resolution 153 | self.in_channels = in_channels 154 | self.give_pre_end = give_pre_end 155 | 156 | # compute in_ch_mult, block_in and curr_res at lowest res 157 | in_ch_mult = (1,) + tuple(ch_mult) 158 | block_in = ch * ch_mult[self.num_resolutions - 1] 159 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 160 | self.z_channel = z_channels 161 | 162 | # z to block_inz 163 | self.conv_in = torch.nn.Conv3d( 164 | z_channels, block_in, kernel_size=3, stride=1, padding=1 165 | ) 166 | 167 | # middle 168 | self.mid = nn.Module() 169 | self.mid.block_1 = ResnetBlock( 170 | in_channels=block_in, 171 | out_channels=block_in, 172 | temb_channels=self.temb_ch, 173 | dropout=dropout, 174 | ) 175 | self.mid.attn_1 = AttnBlock(block_in) 176 | self.mid.block_2 = ResnetBlock( 177 | in_channels=block_in, 178 | out_channels=block_in, 179 | temb_channels=self.temb_ch, 180 | dropout=dropout, 181 | ) 182 | 183 | # upsampling 184 | self.up = nn.ModuleList() 185 | for i_level in reversed(range(self.num_resolutions)): 186 | block = nn.ModuleList() 187 | attn = nn.ModuleList() 188 | block_out = ch * ch_mult[i_level] 189 | for i_block in range(self.num_res_blocks + 1): 190 | block.append( 191 | ResnetBlock( 192 | in_channels=block_in, 193 | out_channels=block_out, 194 | temb_channels=self.temb_ch, 195 | dropout=dropout, 196 | ) 197 | ) 198 | block_in = block_out 199 | if curr_res in attn_resolutions: 200 | attn.append(AttnBlock(block_in)) 201 | up = nn.Module() 202 | up.block = block 203 | up.attn = attn 204 | if i_level != 0: 205 | up.upsample = Upsample(block_in, resamp_with_conv) 206 | curr_res = curr_res * 2 207 | self.up.insert(0, up) # prepend to get consistent order 208 | 209 | # end 210 | self.norm_out = Normalize(block_in) 211 | self.conv_out = torch.nn.Conv3d( 212 | block_in, out_ch, kernel_size=3, stride=1, padding=1 213 | ) 214 | self.occ_conv_out = nn.Sequential( 215 | torch.nn.Conv3d(block_in, 128, kernel_size=3, stride=1, padding=1), 216 | torch.nn.SiLU(), 217 | torch.nn.Conv3d(128, 1, kernel_size=3, stride=1, padding=1), 218 | ) 219 | 220 | def forward(self, z): 221 | # assert z.shape[1:] == self.z_shape[1:] 222 | self.last_z_shape = z.shape 223 | 224 | # timestep embedding 225 | temb = None 226 | 227 | # z to block_in 228 | h = self.conv_in(z) 229 | 230 | # middle 231 | h = self.mid.block_1(h, temb) 232 | h = self.mid.attn_1(h) 233 | h = self.mid.block_2(h, temb) 234 | 235 | # upsampling 236 | for i_level in reversed(range(self.num_resolutions)): 237 | for i_block in range(self.num_res_blocks + 1): 238 | h = self.up[i_level].block[i_block](h, temb) 239 | if len(self.up[i_level].attn) > 0: 240 | h = 
self.up[i_level].attn[i_block](h) 241 | if i_level != 0: 242 | h = self.up[i_level].upsample(h) 243 | 244 | # end 245 | if self.give_pre_end: 246 | return h 247 | 248 | h = self.norm_out(h) 249 | h = nonlinearity(h) 250 | h1 = self.conv_out(h) 251 | h2 = self.occ_conv_out(h) 252 | return h1, h2 253 | -------------------------------------------------------------------------------- /model/auto_encoder/ms_vqgan/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n): 48 | return self.schedule(n) 49 | -------------------------------------------------------------------------------- /model/auto_encoder/ms_vqgan/ms_tsdf_pvqgan_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import init 3 | import torch.nn as nn 4 | 5 | from einops import rearrange 6 | 7 | from model.model_base import ModelBase 8 | 9 | from model.auto_encoder.ms_vqgan.encoder_decoder import ( 10 | Encoder, 11 | Decoder_occ, 12 | nonlinearity, 13 | Upsample, 14 | ) 15 | 16 | from model.auto_encoder.ms_vqgan.quantize import GumbelQuantize 17 | from model.auto_encoder.ms_vqgan.loss import VQLossWithDiscriminator 18 | from model.auto_encoder.ms_vqgan.lr_scheduler import LambdaWarmUpCosineScheduler 19 | 20 | 21 | def init_weights(net, init_type="normal", gain=0.01): 22 | def init_func(m): 23 | classname = m.__class__.__name__ 24 | if classname.find("BatchNorm2d") != -1: 25 | if hasattr(m, "weight") and m.weight is not None: 26 | init.normal_(m.weight.data, 1.0, gain) 27 | if hasattr(m, "bias") and m.bias is not None: 28 | init.constant_(m.bias.data, 0.0) 29 | elif hasattr(m, "weight") and ( 30 | classname.find("Conv") != -1 or classname.find("Linear") != -1 31 | ): 32 | if init_type == "normal": 33 | init.normal_(m.weight.data, 0.0, gain) 34 | elif init_type == "xavier": 35 | init.xavier_normal_(m.weight.data, gain=gain) 36 | elif init_type == "xavier_uniform": 37 | init.xavier_uniform_(m.weight.data, gain=1.0) 38 | elif init_type == "kaiming": 39 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 40 | elif init_type == "orthogonal": 41 | init.orthogonal_(m.weight.data, gain=gain) 42 | elif init_type == "none": # uses pytorch's default init method 43 | m.reset_parameters() 44 | else: 45 | raise NotImplementedError( 46 | "initialization method [%s] is not 
implemented" % init_type 47 | ) 48 | if hasattr(m, "bias") and m.bias is not None: 49 | init.constant_(m.bias.data, 0.0) 50 | 51 | net.apply(init_func) 52 | 53 | # propagate to children 54 | for m in net.children(): 55 | m.apply(init_func) 56 | 57 | 58 | class Upsampler4x(nn.Module): 59 | def __init__(self, in_chn) -> None: 60 | super().__init__() 61 | self.up1 = Upsample(in_chn, True) 62 | self.up2 = Upsample(in_chn, True) 63 | 64 | def forward(self, x): 65 | x = self.up1(x) 66 | x = nonlinearity(x) 67 | x = self.up2(x) 68 | return x 69 | 70 | 71 | class Upsampler2x(nn.Module): 72 | def __init__(self, in_chn) -> None: 73 | super().__init__() 74 | self.up1 = Upsample(in_chn, True) 75 | 76 | def forward(self, x): 77 | x = self.up1(x) 78 | return x 79 | 80 | 81 | class MSTSDFPVQGANNew(ModelBase): 82 | def __init__(self, config): 83 | super().__init__() 84 | lossconfig = config.paras.lossconfig 85 | n_embed = config.paras.n_embed 86 | embed_dim = config.paras.embed_dim 87 | self.patch_mode = True 88 | self.encoder1 = Encoder(**config.paras.ddconfig1) 89 | self.encoder2 = Encoder(**config.paras.ddconfig2) 90 | 91 | self.decoder1 = Decoder_occ(**config.paras.ddconfig1) 92 | self.decoder2 = Decoder_occ(**config.paras.ddconfig2) 93 | 94 | lv2_downsample_num = len(config.paras.ddconfig2.ch_mult) - 1 95 | if lv2_downsample_num == 1: 96 | self.upsampler = Upsampler2x(config.paras.ddconfig1["z_channels"]) 97 | elif lv2_downsample_num == 2: 98 | self.upsampler = Upsampler4x(config.paras.ddconfig1["z_channels"]) 99 | 100 | self.loss = VQLossWithDiscriminator(**lossconfig.params) 101 | 102 | self.cube_size, self.stride = 4, 4 103 | if "patch_cube_size" in config.paras: 104 | self.cube_size, self.stride = ( 105 | config.paras["patch_cube_size"], 106 | config.paras["patch_cube_size"], 107 | ) 108 | 109 | z_channels = config.paras.ddconfig1["z_channels"] 110 | self.quantize1 = GumbelQuantize( 111 | z_channels, 112 | embed_dim, 113 | n_embed=n_embed, 114 | kl_weight=1e-8, 115 | temp_init=1.0, 116 | remap=None, 117 | ) 118 | 119 | self.quantize2 = GumbelQuantize( 120 | z_channels, 121 | embed_dim, 122 | n_embed=n_embed, 123 | kl_weight=1e-8, 124 | temp_init=1.0, 125 | remap=None, 126 | ) 127 | self.temperature_scheduler = LambdaWarmUpCosineScheduler( 128 | **config.paras.temperature_scheduler_config 129 | ) 130 | 131 | q_dim = embed_dim 132 | self.quant_conv1 = torch.nn.Conv3d( 133 | config.paras.ddconfig1["z_channels"] + config.paras.ddconfig2["z_channels"], 134 | q_dim, 135 | 1, 136 | ) 137 | self.post_quant_conv1 = torch.nn.Conv3d( 138 | embed_dim * 2, config.paras.ddconfig1["z_channels"], 1 139 | ) 140 | 141 | self.quant_conv2 = torch.nn.Conv3d( 142 | config.paras.ddconfig2["z_channels"], q_dim, 1 143 | ) 144 | self.post_quant_conv2 = torch.nn.Conv3d( 145 | embed_dim, config.paras.ddconfig2["z_channels"], 1 146 | ) 147 | 148 | if self.training: 149 | self.init_w() 150 | 151 | def init_w(self): 152 | init_weights(self.encoder1, "normal", 0.02) 153 | init_weights(self.encoder2, "normal", 0.02) 154 | init_weights(self.decoder1, "normal", 0.02) 155 | init_weights(self.decoder2, "normal", 0.02) 156 | init_weights(self.upsampler, "normal", 0.02) 157 | 158 | init_weights(self.quant_conv1, "normal", 0.02) 159 | init_weights(self.post_quant_conv1, "normal", 0.02) 160 | init_weights(self.quant_conv2, "normal", 0.02) 161 | init_weights(self.post_quant_conv2, "normal", 0.02) 162 | 163 | @staticmethod 164 | # def unfold_to_cubes(self, x, cube_size=8, stride=8): 165 | def unfold_to_cubes(x, cube_size=8, stride=8): 166 
| """ 167 | assume x.shape: b, c, d, h, w 168 | return: x_cubes: (b cubes) 169 | """ 170 | x_cubes = ( 171 | x.unfold(2, cube_size, stride) 172 | .unfold(3, cube_size, stride) 173 | .unfold(4, cube_size, stride) 174 | ) 175 | x_cubes = rearrange(x_cubes, "b c p1 p2 p3 d h w -> b c (p1 p2 p3) d h w") 176 | x_cubes = rearrange(x_cubes, "b c p d h w -> (b p) c d h w") 177 | return x_cubes 178 | 179 | @staticmethod 180 | # def fold_to_voxels(self, x_cubes, batch_size, ncubes_per_dim): 181 | def fold_to_voxels( 182 | x_cubes, batch_size, ncubes_per_dim, ncubes_per_dim2, ncubes_per_dim3 183 | ): 184 | x = rearrange(x_cubes, "(b p) c d h w -> b p c d h w", b=batch_size) 185 | x = rearrange( 186 | x, 187 | "b (p1 p2 p3) c d h w -> b c (p1 d) (p2 h) (p3 w)", 188 | p1=ncubes_per_dim, 189 | p2=ncubes_per_dim2, 190 | p3=ncubes_per_dim3, 191 | ) 192 | return x 193 | 194 | def temperature_scheduling(self, global_step): 195 | self.quantize1.temperature = self.temperature_scheduler(global_step) 196 | self.quantize2.temperature = self.temperature_scheduler(global_step) 197 | 198 | def decode(self, quant_t, quant_b, occlv1=None, sparse_decode=True): 199 | upsample_t = self.upsampler(quant_t) 200 | quant = torch.cat([upsample_t, quant_b], 1) 201 | if sparse_decode: 202 | chn = quant.shape[1] 203 | occ = torch.sigmoid(occlv1) 204 | occ_mask = occ > 0.5 205 | occ[occ_mask] = 1 206 | occ[~occ_mask] = 0 207 | occ = occ.to(torch.bool) 208 | occ = occ.repeat([1, chn, 1, 1, 1]) 209 | quant[~occ] *= 0 210 | quant = self.post_quant_conv1(quant) 211 | dec, occ_l2 = self.decoder1(quant) 212 | return dec, occ_l2 213 | 214 | def patch_encode(self, input_data, before_quant=False): 215 | cur_bs = input_data.shape[0] 216 | ncubes_per_dim = [i // self.cube_size for i in input_data.shape[2:]] 217 | x_cubes1 = self.unfold_to_cubes(input_data, self.cube_size, self.stride) 218 | h1 = self.encoder1(x_cubes1) 219 | h1_voxel = self.fold_to_voxels(h1, cur_bs, *ncubes_per_dim) 220 | cur_bs = h1_voxel.shape[0] 221 | ncubes_per_dim = [i // self.cube_size for i in h1_voxel.shape[2:]] 222 | 223 | x_cubes2 = self.unfold_to_cubes(h1_voxel, self.cube_size, self.stride) 224 | h2 = self.encoder2(x_cubes2) 225 | h2_voxel = self.fold_to_voxels(h2, cur_bs, *ncubes_per_dim) 226 | 227 | quant2_ = self.quant_conv2(h2) 228 | 229 | if before_quant: 230 | quant2_voxel_ = self.fold_to_voxels(quant2_, cur_bs, *ncubes_per_dim) 231 | 232 | quant2, diff_t, id_t = self.quantize2(quant2_) 233 | quant2_voxel = self.fold_to_voxels(quant2, cur_bs, *ncubes_per_dim) 234 | 235 | dec_t, occ_l1 = self.decoder2(self.post_quant_conv2(quant2_voxel)) 236 | dec_t_cubes = self.unfold_to_cubes(dec_t, 1, 1) 237 | enc_b = torch.cat([dec_t_cubes, h1], 1) 238 | 239 | quant1_ = self.quant_conv1(enc_b) 240 | cur_bs = input_data.shape[0] 241 | ncubes_per_dim = [i // self.cube_size for i in input_data.shape[2:]] 242 | 243 | if before_quant: 244 | quant1_voxel_ = self.fold_to_voxels(quant1_, cur_bs, *ncubes_per_dim) 245 | 246 | quant1, diff_b, id_b = self.quantize1(quant1_) 247 | quant1_voxel = self.fold_to_voxels(quant1, cur_bs, *ncubes_per_dim) 248 | if before_quant: 249 | return quant2_voxel_, quant1_voxel_, quant2_voxel, quant1_voxel, occ_l1 250 | else: 251 | return quant2_voxel, quant1_voxel, diff_t + diff_b, id_t, id_b, occ_l1 252 | 253 | def forward(self, data_dict): 254 | self.temperature_scheduling(data_dict["global_step"]) 255 | 256 | input_data = data_dict["input"] 257 | quant1, quant2, diff, _, _, occ_l1 = self.patch_encode(input_data) 258 | dec, occ_l2 = 
self.decode(quant1, quant2, occ_l1) 259 | recon = { 260 | "tsdf": dec, 261 | "occ_l1": occ_l1, 262 | "occ_l2": occ_l2, 263 | } 264 | return recon, diff 265 | 266 | def get_code(self, vol): 267 | return self.patch_encode(vol, before_quant=True) 268 | 269 | def get_last_layer(self): 270 | return self.decoder1.conv_out.weight 271 | 272 | def get_ae_paras(self): 273 | paras = ( 274 | list(self.encoder1.parameters()) 275 | + list(self.encoder2.parameters()) 276 | + list(self.decoder1.parameters()) 277 | + list(self.decoder2.parameters()) 278 | + list(self.upsampler.parameters()) 279 | + list(self.quantize1.parameters()) 280 | + list(self.quant_conv1.parameters()) 281 | + list(self.post_quant_conv1.parameters()) 282 | + list(self.quantize2.parameters()) 283 | + list(self.quant_conv2.parameters()) 284 | + list(self.post_quant_conv2.parameters()) 285 | ) 286 | return paras 287 | -------------------------------------------------------------------------------- /model/auto_encoder/ms_vqgan/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from einops import rearrange 6 | from torch import einsum 7 | 8 | 9 | class GumbelQuantize(nn.Module): 10 | """ 11 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 12 | Gumbel Softmax trick quantizer 13 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 14 | https://arxiv.org/abs/1611.01144 15 | """ 16 | 17 | def __init__( 18 | self, 19 | num_hiddens, 20 | embedding_dim, 21 | n_embed, 22 | straight_through=True, 23 | kl_weight=1e-8, 24 | temp_init=1.0, 25 | use_vqinterface=True, 26 | remap=None, 27 | unknown_index="random", 28 | use_3d=True, 29 | ): 30 | super().__init__() 31 | 32 | self.embedding_dim = embedding_dim 33 | self.n_embed = n_embed 34 | 35 | self.straight_through = straight_through 36 | self.temperature = temp_init 37 | self.kl_weight = kl_weight 38 | self.use_3d = use_3d 39 | if use_3d: 40 | self.proj = nn.Conv3d(num_hiddens, n_embed, 1) 41 | else: 42 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 43 | self.embed = nn.Embedding(n_embed, embedding_dim) 44 | 45 | self.use_vqinterface = use_vqinterface 46 | 47 | self.remap = remap 48 | if self.remap is not None: 49 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 50 | self.re_embed = self.used.shape[0] 51 | self.unknown_index = unknown_index # "random" or "extra" or integer 52 | if self.unknown_index == "extra": 53 | self.unknown_index = self.re_embed 54 | self.re_embed = self.re_embed + 1 55 | print( 56 | f"Remapping {self.n_embed} indices to {self.re_embed} indices. " 57 | f"Using {self.unknown_index} for unknown indices." 
58 | ) 59 | else: 60 | self.re_embed = n_embed 61 | 62 | def remap_to_used(self, inds): 63 | ishape = inds.shape 64 | assert len(ishape) > 1 65 | inds = inds.reshape(ishape[0], -1) 66 | used = self.used.to(inds) 67 | match = (inds[:, :, None] == used[None, None, ...]).long() 68 | new = match.argmax(-1) 69 | unknown = match.sum(2) < 1 70 | if self.unknown_index == "random": 71 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( 72 | device=new.device 73 | ) 74 | else: 75 | new[unknown] = self.unknown_index 76 | return new.reshape(ishape) 77 | 78 | def unmap_to_all(self, inds): 79 | ishape = inds.shape 80 | assert len(ishape) > 1 81 | inds = inds.reshape(ishape[0], -1) 82 | used = self.used.to(inds) 83 | if self.re_embed > self.used.shape[0]: # extra token 84 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 85 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 86 | return back.reshape(ishape) 87 | 88 | def forward(self, z, temp=None, return_logits=False): 89 | # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work 90 | hard = self.straight_through if self.training else True 91 | temp = self.temperature if temp is None else temp 92 | 93 | if self.use_3d: 94 | assert len(z.shape) == 5 95 | 96 | logits = self.proj(z) 97 | 98 | if self.remap is not None: 99 | # continue only with used logits 100 | full_zeros = torch.zeros_like(logits) 101 | logits = logits[:, self.used, ...] 102 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 103 | 104 | if self.remap is not None: 105 | # go back to all entries but unused set to zero 106 | full_zeros[:, self.used, ...] = soft_one_hot 107 | soft_one_hot = full_zeros 108 | if self.use_3d: 109 | z_q = einsum("b n h w l, n d -> b d h w l", soft_one_hot, self.embed.weight) 110 | else: 111 | z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) 112 | qy = F.softmax(logits, dim=1) 113 | diff = ( 114 | self.kl_weight 115 | * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 116 | ) 117 | 118 | ind = soft_one_hot.argmax(dim=1) 119 | if self.remap is not None: 120 | ind = self.remap_to_used(ind) 121 | if self.use_vqinterface: 122 | if return_logits: 123 | return z_q, diff, (None, None, ind), logits 124 | return z_q, diff, (None, None, ind) 125 | return z_q, diff, ind 126 | 127 | def get_codebook_entry(self, indices, shape): 128 | if len(shape) == 4: 129 | b, h, w, c = shape 130 | assert b * h * w == indices.shape[0] 131 | indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) 132 | elif len(shape) == 5: 133 | b, h, w, l, c = shape 134 | assert b * h * w * l == indices.shape[0] 135 | indices = rearrange(indices, "(b h w l) -> b h w l", b=b, h=h, w=w, l=l) 136 | if self.remap is not None: 137 | indices = self.unmap_to_all(indices) 138 | if len(shape) == 4: 139 | one_hot = ( 140 | F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() 141 | ) 142 | z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) 143 | elif len(shape) == 5: 144 | one_hot = ( 145 | F.one_hot(indices, num_classes=self.n_embed) 146 | .permute(0, 4, 1, 2, 3) 147 | .float() 148 | ) 149 | z_q = einsum("b n h w l, n d -> b d h w l", one_hot, self.embed.weight) 150 | return z_q 151 | -------------------------------------------------------------------------------- /model/diffuser/pipelines/pipeline_ddim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The 
HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, Optional, Union 16 | from tqdm import tqdm 17 | import torch 18 | from utils.config.Configuration import default 19 | 20 | 21 | class DDIMPipeline: 22 | 23 | def __init__(self, unet, scheduler): 24 | super().__init__() 25 | self.unet = unet 26 | self.scheduler = scheduler 27 | 28 | @torch.no_grad() 29 | def __call__( 30 | self, 31 | y_cond, 32 | y_t, 33 | mask, 34 | eta: float = 0.0, 35 | num_inference_steps: int = 50, 36 | use_clipped_model_output: Optional[bool] = None, 37 | start_time_step_inx=0, 38 | **args 39 | ): 40 | 41 | image = y_t 42 | txt_cond = default(args, "txt_cond", None) 43 | class_label = default(args, "class_label", None) 44 | 45 | # set step values 46 | self.scheduler.set_timesteps(num_inference_steps) 47 | bs = y_t.shape[0] 48 | 49 | feat_chn = image.shape[1] 50 | mask_adaptive = mask.repeat([1, feat_chn] + (len(mask.shape) - 2) * [1]) 51 | 52 | for t in tqdm( 53 | self.scheduler.timesteps[start_time_step_inx:], 54 | desc="sampling loop time step", 55 | total=num_inference_steps - start_time_step_inx, 56 | ): 57 | # for t in self.scheduler.timesteps: 58 | # 1. predict noise model_output 59 | if y_cond is not None: 60 | model_input = torch.cat([y_cond, image], dim=1) # mask done in unet... 61 | else: 62 | model_input = image 63 | alphas_cumprod = self.scheduler.alphas_cumprod.to(image.device) 64 | timestep_encoding = alphas_cumprod[t].repeat(bs, 1).to(y_t.device) 65 | 66 | model_output = self.unet( 67 | model_input, 68 | timestep_encoding, 69 | mask=mask, 70 | context=txt_cond, 71 | y=class_label, 72 | ) # .sample 73 | 74 | image = self.scheduler.step( 75 | model_output, 76 | t, 77 | image, 78 | eta=eta, 79 | use_clipped_model_output=use_clipped_model_output, 80 | ).prev_sample 81 | 82 | image[~mask_adaptive] = 0.0 83 | return image 84 | -------------------------------------------------------------------------------- /model/diffuser/pipelines/pipeline_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
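# ---------------------------------------------------------------------------
# Editor's note (illustration only, not part of the original file): a minimal
# sketch of how the mask-conditioned DDPM pipeline defined below is typically
# driven. The `unet` and `scheduler` instances and all tensor shapes here are
# assumptions chosen for illustration; see DDPMPipeline.__call__ for the
# actual contract.
#
#     pipeline = DDPMPipeline(unet=unet, scheduler=scheduler)
#     y_t  = torch.randn(1, 8, 32, 32, 32)                    # initial latent noise
#     mask = torch.ones(1, 1, 32, 32, 32, dtype=torch.bool)   # valid-voxel mask
#     out  = pipeline(y_cond=None, y_t=y_t, mask=mask, num_inference_steps=50)
#
# Each loop iteration predicts the noise with `unet`, steps the scheduler from
# x_t to x_{t-1}, and zeroes every voxel outside `mask`.
# ---------------------------------------------------------------------------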
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | from tqdm import tqdm 18 | 19 | import torch 20 | from utils.config.Configuration import default 21 | from ..utils import randn_tensor 22 | 23 | # from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput 24 | 25 | 26 | class DDPMPipeline: 27 | r""" 28 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 29 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 30 | 31 | Parameters: 32 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. 33 | scheduler ([`SchedulerMixin`]): 34 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 35 | [`DDPMScheduler`], or [`DDIMScheduler`]. 36 | """ 37 | 38 | def __init__(self, unet, scheduler): 39 | super().__init__() 40 | # self.register_modules(unet=unet, scheduler=scheduler) 41 | self.unet = unet 42 | self.scheduler = scheduler 43 | 44 | @torch.no_grad() 45 | def __call__(self, y_cond, y_t, mask, num_inference_steps: int = 50, **args): 46 | 47 | image = y_t 48 | txt_cond = default(args, "txt_cond", None) 49 | 50 | # set step values 51 | self.scheduler.set_timesteps(num_inference_steps) 52 | bs = y_t.shape[0] 53 | 54 | feat_chn = image.shape[1] 55 | mask_adaptive = mask.repeat([1, feat_chn] + (len(mask.shape) - 2) * [1]) 56 | 57 | for t in tqdm( 58 | self.scheduler.timesteps, 59 | desc="sampling loop time step", 60 | total=num_inference_steps, 61 | ): 62 | # for t in self.scheduler.timesteps: 63 | # 1. predict noise model_output 64 | if y_cond is not None: 65 | model_input = torch.cat([y_cond, image], dim=1) # mask done in unet... 66 | else: 67 | model_input = image 68 | alphas_cumprod = self.scheduler.alphas_cumprod.to(image.device) 69 | timestep_encoding = alphas_cumprod[t].repeat(bs, 1).to(y_t.device) 70 | 71 | model_output = self.unet( 72 | model_input, 73 | timestep_encoding, 74 | mask=mask, 75 | context=txt_cond, 76 | ) # .sample 77 | 78 | # 2. predict previous mean of image x_t-1 and add variance depending on eta 79 | # eta corresponds to η in paper and should be between [0, 1] 80 | # do x_t -> x_t-1 81 | image = self.scheduler.step( 82 | model_output, 83 | t, 84 | image, 85 | ).prev_sample 86 | 87 | image[~mask_adaptive] = 0.0 88 | return image 89 | -------------------------------------------------------------------------------- /model/diffuser/schedulers/noise_schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def betas_cosine(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: 6 | """ 7 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 8 | (1-beta) over time from t = [0,1]. 9 | 10 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 11 | to that part of the diffusion process. 12 | 13 | 14 | Args: 15 | num_diffusion_timesteps (`int`): the number of betas to produce. 16 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 17 | prevent singularities. 
18 | 19 | Returns: 20 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 21 | """ 22 | 23 | def alpha_bar(time_step): 24 | return math.cos((time_step + 0.25) / 1.25 * math.pi / 2) ** 2 25 | 26 | betas = [] 27 | for i in range(num_diffusion_timesteps): 28 | t1 = i / num_diffusion_timesteps 29 | t2 = (i + 1) / num_diffusion_timesteps 30 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 31 | return torch.tensor(betas, dtype=torch.float32) 32 | 33 | 34 | def betas_straight(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: 35 | """ 36 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 37 | (1-beta) over time from t = [0,1]. 38 | 39 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 40 | to that part of the diffusion process. 41 | 42 | 43 | Args: 44 | num_diffusion_timesteps (`int`): the number of betas to produce. 45 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 46 | prevent singularities. 47 | 48 | Returns: 49 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 50 | """ 51 | 52 | def alpha_bar(time_step): 53 | return 1 - time_step 54 | 55 | betas = [] 56 | for i in range(num_diffusion_timesteps): 57 | t1 = i / num_diffusion_timesteps 58 | t2 = (i + 1) / num_diffusion_timesteps 59 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 60 | return torch.tensor(betas, dtype=torch.float32) 61 | 62 | 63 | def betas_arccos(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: 64 | """ 65 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 66 | (1-beta) over time from t = [0,1]. 67 | 68 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 69 | to that part of the diffusion process. 70 | 71 | 72 | Args: 73 | num_diffusion_timesteps (`int`): the number of betas to produce. 74 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 75 | prevent singularities. 76 | 77 | Returns: 78 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 79 | """ 80 | 81 | def alpha_bar(time_step): 82 | return (math.acos(math.sqrt(time_step)) / (math.pi / 2.0)) * 1.25 - 0.25 83 | 84 | betas = [] 85 | for i in range(num_diffusion_timesteps): 86 | t1 = i / num_diffusion_timesteps 87 | t2 = (i + 1) / num_diffusion_timesteps 88 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 89 | return torch.tensor(betas, dtype=torch.float32) 90 | -------------------------------------------------------------------------------- /model/diffuser/utils.py: -------------------------------------------------------------------------------- 1 | # you may not use this file except in compliance with the License. 2 | # You may obtain a copy of the License at 3 | # 4 | # http://www.apache.org/licenses/LICENSE-2.0 5 | # 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
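# ---------------------------------------------------------------------------
# Editor's note (illustration only, not part of the original file): the
# `register_to_config` decorator defined below expects to wrap the __init__ of
# a class that provides a `register_to_config(**kwargs)` method, in the spirit
# of diffusers' ConfigMixin. A hypothetical sketch of its effect:
#
#     class MyScheduler:                           # hypothetical example class
#         def register_to_config(self, **kwargs):
#             self.config = kwargs                 # captured init arguments
#
#         @register_to_config
#         def __init__(self, num_train_timesteps=1000, beta_start=1e-4):
#             ...
#
#     s = MyScheduler(num_train_timesteps=500)
#     # s.config == {"num_train_timesteps": 500, "beta_start": 1e-4}
#
# Defaults from the signature fill in unspecified arguments, and kwargs whose
# names start with an underscore are forwarded to register_to_config but not
# to the wrapped __init__.
# ---------------------------------------------------------------------------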
11 | """ ConfigMixin base class and utilities.""" 12 | import functools 13 | import inspect 14 | 15 | from collections import OrderedDict 16 | from typing import Any, Dict, Tuple, Union 17 | from dataclasses import fields 18 | 19 | import logging 20 | from typing import List, Optional, Tuple, Union 21 | import torch 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def register_to_config(init): 27 | r""" 28 | Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are 29 | automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that 30 | shouldn't be registered in the config, use the `ignore_for_config` class variable 31 | 32 | Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! 33 | """ 34 | 35 | @functools.wraps(init) 36 | def inner_init(self, *args, **kwargs): 37 | # Ignore private kwargs in the init. 38 | init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} 39 | config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} 40 | # if not isinstance(self, ConfigMixin): 41 | # raise RuntimeError( 42 | # f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " 43 | # "not inherit from `ConfigMixin`." 44 | # ) 45 | 46 | ignore = getattr(self, "ignore_for_config", []) 47 | # Get positional arguments aligned with kwargs 48 | new_kwargs = {} 49 | signature = inspect.signature(init) 50 | parameters = { 51 | name: p.default 52 | for i, (name, p) in enumerate(signature.parameters.items()) 53 | if i > 0 and name not in ignore 54 | } 55 | for arg, name in zip(args, parameters.keys()): 56 | new_kwargs[name] = arg 57 | 58 | # Then add all kwargs 59 | new_kwargs.update( 60 | { 61 | k: init_kwargs.get(k, default) 62 | for k, default in parameters.items() 63 | if k not in ignore and k not in new_kwargs 64 | } 65 | ) 66 | new_kwargs = {**config_init_kwargs, **new_kwargs} 67 | getattr(self, "register_to_config")(**new_kwargs) 68 | init(self, *args, **init_kwargs) 69 | 70 | return inner_init 71 | 72 | 73 | def randn_tensor( 74 | shape: Union[Tuple, List], 75 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 76 | device: Optional["torch.device"] = None, 77 | dtype: Optional["torch.dtype"] = None, 78 | layout: Optional["torch.layout"] = None, 79 | ): 80 | """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When 81 | passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor 82 | will always be created on CPU. 83 | """ 84 | # device on which tensor is created defaults to device 85 | rand_device = device 86 | batch_size = shape[0] 87 | 88 | layout = layout or torch.strided 89 | device = device or torch.device("cpu") 90 | 91 | if generator is not None: 92 | gen_device_type = ( 93 | generator.device.type 94 | if not isinstance(generator, list) 95 | else generator[0].device.type 96 | ) 97 | if gen_device_type != device.type and gen_device_type == "cpu": 98 | rand_device = "cpu" 99 | if device != "mps": 100 | logger.info( 101 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." 102 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" 103 | f" slighly speed up this function by passing a generator that was created on the {device} device." 
104 | ) 105 | elif gen_device_type != device.type and gen_device_type == "cuda": 106 | raise ValueError( 107 | f"Cannot generate a {device} tensor from a generator of type {gen_device_type}." 108 | ) 109 | 110 | if isinstance(generator, list): 111 | shape = (1,) + shape[1:] 112 | latents = [ 113 | torch.randn( 114 | shape, 115 | generator=generator[i], 116 | device=rand_device, 117 | dtype=dtype, 118 | layout=layout, 119 | ) 120 | for i in range(batch_size) 121 | ] 122 | latents = torch.cat(latents, dim=0).to(device) 123 | else: 124 | latents = torch.randn( 125 | shape, generator=generator, device=rand_device, dtype=dtype, layout=layout 126 | ).to(device) 127 | 128 | return latents 129 | 130 | 131 | class BaseOutput(OrderedDict): 132 | """ 133 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a 134 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular 135 | python dictionary. 136 | 137 | 138 | 139 | You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple 140 | before. 141 | 142 | 143 | """ 144 | 145 | def __post_init__(self): 146 | class_fields = fields(self) 147 | 148 | # Safety and consistency checks 149 | if not len(class_fields): 150 | raise ValueError(f"{self.__class__.__name__} has no fields.") 151 | 152 | first_field = getattr(self, class_fields[0].name) 153 | other_fields_are_none = all( 154 | getattr(self, field.name) is None for field in class_fields[1:] 155 | ) 156 | 157 | if other_fields_are_none and isinstance(first_field, dict): 158 | for key, value in first_field.items(): 159 | self[key] = value 160 | else: 161 | for field in class_fields: 162 | v = getattr(self, field.name) 163 | if v is not None: 164 | self[field.name] = v 165 | 166 | def __delitem__(self, *args, **kwargs): 167 | raise Exception( 168 | f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." 169 | ) 170 | 171 | def setdefault(self, *args, **kwargs): 172 | raise Exception( 173 | f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." 174 | ) 175 | 176 | def pop(self, *args, **kwargs): 177 | raise Exception( 178 | f"You cannot use ``pop`` on a {self.__class__.__name__} instance." 179 | ) 180 | 181 | def update(self, *args, **kwargs): 182 | raise Exception( 183 | f"You cannot use ``update`` on a {self.__class__.__name__} instance." 184 | ) 185 | 186 | def __getitem__(self, k): 187 | if isinstance(k, str): 188 | inner_dict = {k: v for (k, v) in self.items()} 189 | return inner_dict[k] 190 | else: 191 | return self.to_tuple()[k] 192 | 193 | def __setattr__(self, name, value): 194 | if name in self.keys() and value is not None: 195 | # Don't call self.__setitem__ to avoid recursion errors 196 | super().__setitem__(name, value) 197 | super().__setattr__(name, value) 198 | 199 | def __setitem__(self, key, value): 200 | # Will raise a KeyException if needed 201 | super().__setitem__(key, value) 202 | # Don't call self.__setattr__ to avoid recursion errors 203 | super().__setattr__(key, value) 204 | 205 | def to_tuple(self) -> Tuple[Any]: 206 | """ 207 | Convert self to a tuple containing all the attributes/keys that are not `None`. 
208 | """ 209 | return tuple(self[k] for k in self.keys()) 210 | -------------------------------------------------------------------------------- /model/model_base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ModelBase(nn.Module): 9 | module_class_list = {} 10 | 11 | def __init__(self): 12 | super(ModelBase, self).__init__() 13 | self.output_shape = None 14 | self.device = None 15 | self.mod_dict = nn.ModuleDict() 16 | 17 | def get_output_shape(self): 18 | if not self.output_shape: 19 | raise NotImplementedError 20 | return self.output_shape 21 | 22 | def set_device(self, device): 23 | self.device = device 24 | for k, v in self._modules.items(): 25 | if "set_device" in v.__dir__(): 26 | v.set_device(device) 27 | self.to(self.device) 28 | 29 | def set_eval(self): 30 | self.eval() 31 | for k, v in self.mod_dict.items(): 32 | if "eval" in v.__dir__(): 33 | v.eval() 34 | 35 | def set_attr(self, attr, value): 36 | pass 37 | 38 | def load_model_paras(self, params): 39 | if params is not None: 40 | super(ModelBase, self).load_state_dict(params["model_paras"]) 41 | else: 42 | raise AssertionError("Fail to load params for model.") 43 | 44 | def load_model_paras_from_file(self, para_file): 45 | params = None 46 | try: 47 | params = torch.load(para_file, map_location=lambda storage, loc: storage) 48 | except Exception as e: 49 | with open(para_file, "rb") as f: 50 | params = pickle.load(f) 51 | self.load_model_paras(params) 52 | logging.info(f"loaded model params:{para_file}") 53 | 54 | def load_state_dict(self, file): 55 | raise AssertionError( 56 | "The load_state_dict function has been forbidden in this model system. " 57 | "Please use load_model_paras instead." 58 | ) 59 | 60 | @staticmethod 61 | def check_config(config): 62 | required_paras = ["name", "paras"] 63 | ModelBase.check_config_dict(required_paras, config) 64 | 65 | @staticmethod 66 | def check_config_dict(required, config): 67 | assert isinstance(config, dict) 68 | for i in required: 69 | if i not in config.keys(): 70 | err = f"Required config {i} does not exist." 
71 | raise KeyError(err) 72 | -------------------------------------------------------------------------------- /model/ms_ldm/blocks/blk_wrapper.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import torchsparse 3 | import torch.nn as nn 4 | 5 | # public 6 | __all__ = [ 7 | "model_block_paras", 8 | "linear", 9 | "TimestepEmbedSequential", 10 | "timestep_embedding", 11 | "zero_module", 12 | "conv_nd", 13 | "normalization", 14 | "ResBlock", 15 | "Downsample", 16 | "Upsample", 17 | "AttentionBlock", 18 | "to_sparse", 19 | "to_dense", 20 | "torchsparse", 21 | "nonlinear", 22 | "SpatialTransformer", 23 | ] 24 | 25 | model_block_paras = EasyDict(dict(use_sparse=False, use_bev=False)) 26 | 27 | from model.ms_ldm.blocks.blk_modules import ( 28 | linear, 29 | TimestepEmbedSequential, 30 | zero_module, 31 | timestep_embedding, 32 | ) 33 | 34 | # dense 35 | from model.ms_ldm.blocks.blk_modules import conv_nd as conv_nd_dense 36 | from model.ms_ldm.blocks.blk_modules import ResBlock as ResBlockDense 37 | from model.ms_ldm.blocks.blk_modules import Downsample as DownsampleDense 38 | from model.ms_ldm.blocks.blk_modules import Upsample as UpsampleDense 39 | from model.ms_ldm.blocks.blk_modules import normalization as normalization_dense 40 | from model.ms_ldm.blocks.blk_modules import AttentionBlock as AttentionBlockDense 41 | 42 | # sparse 43 | from model.utils.torch_sparse_utils import to_sparse, to_dense 44 | from model.ms_ldm.blocks.sparse_blk_modules import sparse_conv3d, ResBlock3DSparse 45 | from model.ms_ldm.blocks.sparse_blk_modules import Downsample as DownsampleSparse 46 | from model.ms_ldm.blocks.sparse_blk_modules import Upsample as UpsampleSparse 47 | from model.ms_ldm.blocks.sparse_blk_modules import normalization as normalization_sparse 48 | from model.ms_ldm.blocks.sparse_blk_modules import AttentionBlock as AttentionBlockSparse 49 | from model.ms_ldm.blocks.sparse_blk_modules import SiLU as SparseSiLU 50 | from model.ms_ldm.spatial_transformer.spatial_transformer_3d_sparse import ( 51 | SpatialTransformer as SpatialTransformerSparse, 52 | ) 53 | from model.ms_ldm.spatial_transformer.spatial_transformer_3d import ( 54 | SpatialTransformer as SpatialTransformerDense, 55 | ) 56 | from model.ms_ldm.spatial_transformer.spatial_transformer_2d import ( 57 | SpatialTransformer as SpatialTransformer2D, 58 | ) 59 | from model.ms_ldm.spatial_transformer.spatial_transformer_bev_sparse import ( 60 | SpatialTransformer as SpatialTransformerSparseBEV, 61 | ) 62 | 63 | 64 | def nonlinear(): 65 | if model_block_paras.use_sparse: 66 | return SparseSiLU() 67 | else: 68 | return nn.SiLU() 69 | 70 | 71 | def conv_nd(dims, *args, **kwargs): 72 | if model_block_paras.use_sparse: 73 | if dims != 3: 74 | raise NotImplementedError 75 | else: 76 | return sparse_conv3d(*args, **kwargs) 77 | else: 78 | return conv_nd_dense(dims, *args, **kwargs) 79 | 80 | 81 | def normalization(*args, **kwargs): 82 | if model_block_paras.use_sparse: 83 | return normalization_sparse(*args, **kwargs) 84 | else: 85 | return normalization_dense(*args, **kwargs) 86 | 87 | 88 | def ResBlock(*args, **kwargs): 89 | if model_block_paras.use_sparse: 90 | if kwargs["dims"] != 3: 91 | raise NotImplementedError 92 | return ResBlock3DSparse(*args, **kwargs) 93 | else: 94 | return ResBlockDense(*args, **kwargs) 95 | 96 | 97 | def Downsample(*args, **kwargs): 98 | if model_block_paras.use_sparse: 99 | if kwargs["dims"] != 3: 100 | raise NotImplementedError 101 | return 
DownsampleSparse(*args, **kwargs) 102 | else: 103 | return DownsampleDense(*args, **kwargs) 104 | 105 | 106 | def Upsample(*args, **kwargs): 107 | if model_block_paras.use_sparse: 108 | if kwargs["dims"] != 3: 109 | raise NotImplementedError 110 | return UpsampleSparse(*args, **kwargs) 111 | else: 112 | return UpsampleDense(*args, **kwargs) 113 | 114 | 115 | def AttentionBlock(*args, **kwargs): 116 | if model_block_paras.use_sparse: 117 | return AttentionBlockSparse(*args, **kwargs) 118 | else: 119 | return AttentionBlockDense(*args, **kwargs) 120 | 121 | 122 | def SpatialTransformer(*args, **kwargs): 123 | if model_block_paras.use_sparse: 124 | if model_block_paras.use_bev: 125 | return SpatialTransformerSparseBEV(*args, **kwargs) 126 | else: 127 | return SpatialTransformerSparse(*args, **kwargs) 128 | else: 129 | if kwargs["dims"] == 2: 130 | return SpatialTransformer2D(*args, **kwargs) 131 | elif kwargs["dims"] == 3: 132 | return SpatialTransformerDense(*args, **kwargs) 133 | else: 134 | raise NotImplementedError 135 | -------------------------------------------------------------------------------- /model/ms_ldm/blocks/model_utils.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | 19 | def checkpoint(func, inputs, params, flag): 20 | """ 21 | Evaluate a function without caching intermediate activations, allowing for 22 | reduced memory at the expense of extra compute in the backward pass. 23 | :param func: the function to evaluate. 24 | :param inputs: the argument sequence to pass to `func`. 25 | :param params: a sequence of parameters `func` depends on but does not 26 | explicitly take as arguments. 27 | :param flag: if False, disable gradient checkpointing. 28 | """ 29 | if flag: 30 | args = tuple(inputs) + tuple(params) 31 | return CheckpointFunction.apply(func, len(inputs), *args) 32 | else: 33 | return func(*inputs) 34 | 35 | 36 | class CheckpointFunction(torch.autograd.Function): 37 | @staticmethod 38 | def forward(ctx, run_function, length, *args): 39 | ctx.run_function = run_function 40 | ctx.input_tensors = list(args[:length]) 41 | ctx.input_params = list(args[length:]) 42 | 43 | with torch.no_grad(): 44 | output_tensors = ctx.run_function(*ctx.input_tensors) 45 | return output_tensors 46 | 47 | @staticmethod 48 | def backward(ctx, *output_grads): 49 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 50 | with torch.enable_grad(): 51 | # Fixes a bug where the first op in run_function modifies the 52 | # Tensor storage in place, which is not allowed for detach()'d 53 | # Tensors. 
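        # (view_as(x) creates non-leaf aliases that share storage with the
        # detached inputs, so an in-place first op inside run_function no
        # longer acts directly on a leaf tensor that requires grad.)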
54 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 55 | output_tensors = ctx.run_function(*shallow_copies) 56 | input_grads = torch.autograd.grad( 57 | output_tensors, 58 | ctx.input_tensors + ctx.input_params, 59 | output_grads, 60 | allow_unused=True, 61 | ) 62 | del ctx.input_tensors 63 | del ctx.input_params 64 | del output_tensors 65 | return (None, None) + input_grads 66 | 67 | 68 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 69 | """ 70 | Create sinusoidal timestep embeddings. 71 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 72 | These may be fractional. 73 | :param dim: the dimension of the output. 74 | :param max_period: controls the minimum frequency of the embeddings. 75 | :return: an [N x dim] Tensor of positional embeddings. 76 | """ 77 | if not repeat_only: 78 | half = dim // 2 79 | freqs = torch.exp( 80 | -math.log(max_period) 81 | * torch.arange(start=0, end=half, dtype=torch.float32) 82 | / half 83 | ).to(device=timesteps.device) 84 | args = timesteps[:, None].float() * freqs[None] 85 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 86 | if dim % 2: 87 | embedding = torch.cat( 88 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 89 | ) 90 | else: 91 | embedding = repeat(timesteps, "b -> b d", d=dim) 92 | return embedding 93 | 94 | 95 | def zero_module(module): 96 | """ 97 | Zero out the parameters of a module and return it. 98 | """ 99 | for p in module.parameters(): 100 | p.detach().zero_() 101 | return module 102 | 103 | 104 | def scale_module(module, scale): 105 | """ 106 | Scale the parameters of a module and return it. 107 | """ 108 | for p in module.parameters(): 109 | p.detach().mul_(scale) 110 | return module 111 | 112 | 113 | def mean_flat(tensor): 114 | """ 115 | Take the mean over all non-batch dimensions. 116 | """ 117 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 118 | 119 | 120 | def normalization(channels): 121 | """ 122 | Make a standard normalization layer. 123 | :param channels: number of input channels. 124 | :return: an nn.Module for normalization. 125 | """ 126 | return GroupNorm32(32, channels) 127 | 128 | 129 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 130 | class SiLU(nn.Module): 131 | def forward(self, x): 132 | return x * torch.sigmoid(x) 133 | 134 | 135 | class GroupNorm32(nn.GroupNorm): 136 | def forward(self, x): 137 | return super().forward(x.float()).type(x.dtype) 138 | 139 | 140 | def conv_nd(dims, *args, **kwargs): 141 | """ 142 | Create a 1D, 2D, or 3D convolution module. 143 | """ 144 | if dims == 1: 145 | return nn.Conv1d(*args, **kwargs) 146 | elif dims == 2: 147 | return nn.Conv2d(*args, **kwargs) 148 | elif dims == 3: 149 | return nn.Conv3d(*args, **kwargs) 150 | raise ValueError(f"unsupported dimensions: {dims}") 151 | 152 | 153 | def linear(*args, **kwargs): 154 | """ 155 | Create a linear module. 156 | """ 157 | return nn.Linear(*args, **kwargs) 158 | 159 | 160 | def avg_pool_nd(dims, *args, **kwargs): 161 | """ 162 | Create a 1D, 2D, or 3D average pooling module. 
163 | """ 164 | if dims == 1: 165 | return nn.AvgPool1d(*args, **kwargs) 166 | elif dims == 2: 167 | return nn.AvgPool2d(*args, **kwargs) 168 | elif dims == 3: 169 | return nn.AvgPool3d(*args, **kwargs) 170 | raise ValueError(f"unsupported dimensions: {dims}") 171 | 172 | -------------------------------------------------------------------------------- /model/ms_ldm/sketch_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SketchEncoder(nn.Module): 6 | def __init__(self, out_chn=768) -> None: 7 | super().__init__() 8 | self.out_chn = out_chn 9 | self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1) 10 | self.relu1 = nn.ReLU() 11 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 12 | self.conv2 = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1) 13 | self.relu2 = nn.ReLU() 14 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 15 | self.conv3 = nn.Conv2d(256, out_chn, kernel_size=2, stride=2, padding=0) 16 | self.relu3 = nn.ReLU() 17 | 18 | def forward(self, x): 19 | bs = x.shape[0] 20 | x = self.conv1(x) 21 | x = self.relu1(x) 22 | x = self.pool1(x) 23 | x = self.conv2(x) 24 | x = self.relu2(x) 25 | x = self.pool2(x) 26 | x = self.conv3(x) 27 | x = self.relu3(x) 28 | x = x.view(bs, self.out_chn, -1).permute(0, 2, 1) 29 | return x 30 | -------------------------------------------------------------------------------- /model/ms_ldm/spatial_transformer/spatial_transformer_2d.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from model.ms_ldm.blocks.model_utils import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def default(val, d): 16 | if exists(val): 17 | return val 18 | return d() if isfunction(d) else d 19 | 20 | 21 | # feedforward 22 | class GEGLU(nn.Module): 23 | def __init__(self, dim_in, dim_out): 24 | super().__init__() 25 | self.proj = nn.Linear(dim_in, dim_out * 2) 26 | 27 | def forward(self, x): 28 | x, gate = self.proj(x).chunk(2, dim=-1) 29 | return x * F.gelu(gate) 30 | 31 | 32 | class FeedForward(nn.Module): 33 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 34 | super().__init__() 35 | inner_dim = int(dim * mult) 36 | dim_out = default(dim_out, dim) 37 | project_in = ( 38 | nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) 39 | if not glu 40 | else GEGLU(dim, inner_dim) 41 | ) 42 | 43 | self.net = nn.Sequential( 44 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 45 | ) 46 | 47 | def forward(self, x): 48 | return self.net(x) 49 | 50 | 51 | def zero_module(module): 52 | """ 53 | Zero out the parameters of a module and return it. 
54 | """ 55 | for p in module.parameters(): 56 | p.detach().zero_() 57 | return module 58 | 59 | 60 | def Normalize(in_channels): 61 | return torch.nn.GroupNorm( 62 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 63 | ) 64 | 65 | 66 | class CrossAttention(nn.Module): 67 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 68 | super().__init__() 69 | inner_dim = dim_head * heads 70 | context_dim = default(context_dim, query_dim) 71 | 72 | self.scale = dim_head**-0.5 73 | self.heads = heads 74 | 75 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 76 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 77 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 78 | 79 | self.to_out = nn.Sequential( 80 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 81 | ) 82 | 83 | def forward(self, x, context=None, mask=None): 84 | h = self.heads 85 | 86 | q = self.to_q(x) 87 | context = default(context, x) 88 | k = self.to_k(context) 89 | v = self.to_v(context) 90 | 91 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 92 | 93 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 94 | 95 | if exists(mask): 96 | mask = rearrange(mask, "b ... -> b (...)") 97 | max_neg_value = -torch.finfo(sim.dtype).max 98 | mask = repeat(mask, "b j -> (b h) () j", h=h) 99 | sim.masked_fill_(~mask, max_neg_value) 100 | 101 | # attention, what we cannot get enough of 102 | attn = sim.softmax(dim=-1) 103 | 104 | out = einsum("b i j, b j d -> b i d", attn, v) 105 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 106 | return self.to_out(out) 107 | 108 | 109 | class BasicTransformerBlock(nn.Module): 110 | def __init__( 111 | self, 112 | dim, 113 | n_heads, 114 | d_head, 115 | dropout=0.0, 116 | context_dim=None, 117 | gated_ff=True, 118 | checkpoint=True, 119 | ): 120 | super().__init__() 121 | self.attn1 = CrossAttention( 122 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout 123 | ) # is a self-attention 124 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 125 | self.attn2 = CrossAttention( 126 | query_dim=dim, 127 | context_dim=context_dim, 128 | heads=n_heads, 129 | dim_head=d_head, 130 | dropout=dropout, 131 | ) # is self-attn if context is none 132 | self.norm1 = nn.LayerNorm(dim) 133 | self.norm2 = nn.LayerNorm(dim) 134 | self.norm3 = nn.LayerNorm(dim) 135 | self.checkpoint = checkpoint 136 | 137 | def forward(self, x, context=None): 138 | return checkpoint( 139 | self._forward, (x, context), self.parameters(), self.checkpoint 140 | ) 141 | 142 | def _forward(self, x, context=None): 143 | x = self.attn1(self.norm1(x)) + x 144 | x = self.attn2(self.norm2(x), context=context) + x 145 | x = self.ff(self.norm3(x)) + x 146 | return x 147 | 148 | 149 | class SpatialTransformer(nn.Module): 150 | """ 151 | Transformer block for image-like data. 152 | First, project the input (aka embedding) 153 | and reshape to b, t, d. 154 | Then apply standard transformer action. 
155 | Finally, reshape to image 156 | """ 157 | 158 | def __init__( 159 | self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None 160 | ): 161 | super().__init__() 162 | self.in_channels = in_channels 163 | inner_dim = n_heads * d_head 164 | self.norm = Normalize(in_channels) 165 | 166 | self.proj_in = nn.Conv2d( 167 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 168 | ) 169 | 170 | self.transformer_blocks = nn.ModuleList( 171 | [ 172 | BasicTransformerBlock( 173 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 174 | ) 175 | for d in range(depth) 176 | ] 177 | ) 178 | 179 | self.proj_out = zero_module( 180 | nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 181 | ) 182 | 183 | def forward(self, x, context=None): 184 | # note: if no context is given, cross-attention defaults to self-attention 185 | b, c, h, w = x.shape 186 | x_in = x 187 | x = self.norm(x) 188 | x = self.proj_in(x) 189 | x = rearrange(x, "b c h w -> b (h w) c") 190 | for block in self.transformer_blocks: 191 | x = block(x, context=context) 192 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 193 | x = self.proj_out(x) 194 | return x + x_in 195 | -------------------------------------------------------------------------------- /model/ms_ldm/spatial_transformer/spatial_transformer_3d.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from einops import rearrange, repeat 6 | 7 | from model.ms_ldm.blocks.model_utils import checkpoint 8 | 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | 14 | def default(val, d): 15 | if exists(val): 16 | return val 17 | return d() if isfunction(d) else d 18 | 19 | 20 | # feedforward 21 | class GEGLU(nn.Module): 22 | def __init__(self, dim_in, dim_out): 23 | super().__init__() 24 | self.proj = nn.Linear(dim_in, dim_out * 2) 25 | 26 | def forward(self, x): 27 | x, gate = self.proj(x).chunk(2, dim=-1) 28 | return x * F.gelu(gate) 29 | 30 | 31 | class FeedForward(nn.Module): 32 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 33 | super().__init__() 34 | inner_dim = int(dim * mult) 35 | dim_out = default(dim_out, dim) 36 | project_in = ( 37 | nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) 38 | if not glu 39 | else GEGLU(dim, inner_dim) 40 | ) 41 | 42 | self.net = nn.Sequential( 43 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 44 | ) 45 | 46 | def forward(self, x): 47 | return self.net(x) 48 | 49 | 50 | def zero_module(module): 51 | """ 52 | Zero out the parameters of a module and return it. 
53 | """ 54 | for p in module.parameters(): 55 | p.detach().zero_() 56 | return module 57 | 58 | 59 | def Normalize(in_channels): 60 | return torch.nn.GroupNorm( 61 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 62 | ) 63 | 64 | 65 | class CrossAttention(nn.Module): 66 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 67 | super().__init__() 68 | inner_dim = dim_head * heads 69 | context_dim = default(context_dim, query_dim) 70 | 71 | self.scale = dim_head**-0.5 72 | self.heads = heads 73 | 74 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 75 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 76 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 77 | 78 | self.to_out = nn.Sequential( 79 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 80 | ) 81 | 82 | def forward(self, x, context=None, mask=None): 83 | h = self.heads 84 | 85 | q = self.to_q(x) 86 | context = default(context, x) 87 | k = self.to_k(context) 88 | v = self.to_v(context) 89 | 90 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 91 | 92 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 93 | 94 | if exists(mask): 95 | mask = rearrange(mask, "b ... -> b (...)") 96 | max_neg_value = -torch.finfo(sim.dtype).max 97 | mask = repeat(mask, "b j -> (b h) () j", h=h) 98 | sim.masked_fill_(~mask, max_neg_value) 99 | 100 | # attention, what we cannot get enough of 101 | attn = sim.softmax(dim=-1) 102 | 103 | out = einsum("b i j, b j d -> b i d", attn, v) 104 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 105 | return self.to_out(out) 106 | 107 | 108 | class BasicTransformerBlock(nn.Module): 109 | def __init__( 110 | self, 111 | dim, 112 | n_heads, 113 | d_head, 114 | dropout=0.0, 115 | context_dim=None, 116 | gated_ff=True, 117 | checkpoint=True, 118 | ): 119 | super().__init__() 120 | self.attn1 = CrossAttention( 121 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout 122 | ) # is a self-attention 123 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 124 | self.attn2 = CrossAttention( 125 | query_dim=dim, 126 | context_dim=context_dim, 127 | heads=n_heads, 128 | dim_head=d_head, 129 | dropout=dropout, 130 | ) # is self-attn if context is none 131 | self.norm1 = nn.LayerNorm(dim) 132 | self.norm2 = nn.LayerNorm(dim) 133 | self.norm3 = nn.LayerNorm(dim) 134 | self.checkpoint = checkpoint 135 | 136 | def forward(self, x, context=None): 137 | return checkpoint( 138 | self._forward, (x, context), self.parameters(), self.checkpoint 139 | ) 140 | 141 | def _forward(self, x, context=None): 142 | x = self.attn1(self.norm1(x)) + x 143 | x = self.attn2(self.norm2(x), context=context) + x 144 | x = self.ff(self.norm3(x)) + x 145 | return x 146 | 147 | 148 | class SpatialTransformer(nn.Module): 149 | """ 150 | Transformer block for image-like data. 151 | First, project the input (aka embedding) 152 | and reshape to b, t, d. 153 | Then apply standard transformer action. 
154 | Finally, reshape to image 155 | """ 156 | 157 | def __init__( 158 | self, 159 | in_channels, 160 | n_heads, 161 | d_head, 162 | depth=1, 163 | dropout=0.0, 164 | context_dim=None, 165 | **args 166 | ): 167 | super().__init__() 168 | self.in_channels = in_channels 169 | inner_dim = n_heads * d_head 170 | self.norm = Normalize(in_channels) 171 | 172 | self.proj_in = nn.Conv3d( 173 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 174 | ) 175 | 176 | self.transformer_blocks = nn.ModuleList( 177 | [ 178 | BasicTransformerBlock( 179 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 180 | ) 181 | for d in range(depth) 182 | ] 183 | ) 184 | 185 | self.proj_out = zero_module( 186 | nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 187 | ) 188 | 189 | def forward(self, x, context=None): 190 | # note: if no context is given, cross-attention defaults to self-attention 191 | b, c, h, w, l = x.shape 192 | x_in = x 193 | x = self.norm(x) 194 | x = self.proj_in(x) 195 | x = rearrange(x, "b c h w l -> b (h w l) c") 196 | for block in self.transformer_blocks: 197 | x = block(x, context=context) 198 | x = rearrange(x, "b (h w l) c -> b c h w l", h=h, w=w, l=l) 199 | x = self.proj_out(x) 200 | return x + x_in 201 | -------------------------------------------------------------------------------- /model/ms_ldm/spatial_transformer/spatial_transformer_3d_sparse.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from einops import rearrange, repeat 6 | 7 | import torchsparse.nn as spnn 8 | from torchsparse.tensor import SparseTensor 9 | 10 | from model.utils.torch_sparse_utils import GroupNorm 11 | from model.ms_ldm.blocks.model_utils import checkpoint 12 | from model.ms_ldm.blocks.sparse_blk_modules import LayerNorm, FeedForward 13 | 14 | # feedforward 15 | class GEGLU(nn.Module): 16 | def __init__(self, dim_in, dim_out): 17 | super().__init__() 18 | self.proj = nn.Linear(dim_in, dim_out * 2) 19 | 20 | def forward(self, x): 21 | x, gate = self.proj(x).chunk(2, dim=-1) 22 | return x * F.gelu(gate) 23 | 24 | 25 | def exists(val): 26 | return val is not None 27 | 28 | 29 | def uniq(arr): 30 | return {el: True for el in arr}.keys() 31 | 32 | 33 | def default(val, d): 34 | if exists(val): 35 | return val 36 | return d() if isfunction(d) else d 37 | 38 | 39 | def zero_module(module): 40 | """ 41 | Zero out the parameters of a module and return it. 
42 | """ 43 | for p in module.parameters(): 44 | p.detach().zero_() 45 | return module 46 | 47 | 48 | def Normalize(in_channels): 49 | return GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 50 | 51 | 52 | # class LayerNorm(nn.LayerNorm): 53 | # def forward(self, input): 54 | # coords, feats, stride = input.coords, input.feats, input.stride 55 | 56 | # batch_size = torch.max(coords[:, -1]).item() + 1 57 | # num_channels = feats.shape[1] 58 | 59 | # nfeats = torch.zeros_like(feats) 60 | # for k in range(batch_size): 61 | # indices = coords[:, -1] == k 62 | # bfeats = feats[indices] 63 | # # bfeats = bfeats.transpose(0, 1).reshape(1, num_channels, -1) 64 | # bfeats = super().forward(bfeats) 65 | # # bfeats = bfeats.reshape(num_channels, -1).transpose(0, 1) 66 | # nfeats[indices] = bfeats 67 | 68 | # output = SparseTensor(coords=coords, feats=nfeats, stride=stride) 69 | # try: 70 | # output.cmaps = input.cmaps 71 | # output.kmaps = input.kmaps 72 | # except: 73 | # output._caches.cmaps = input._caches.cmaps 74 | # output._caches.kmaps = input._caches.kmaps 75 | # return output 76 | 77 | 78 | # class FeedForward(nn.Module): 79 | # def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 80 | # super().__init__() 81 | # inner_dim = int(dim * mult) 82 | # dim_out = default(dim_out, dim) 83 | # project_in = ( 84 | # nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) 85 | # if not glu 86 | # else GEGLU(dim, inner_dim) 87 | # ) 88 | 89 | # self.net = nn.Sequential( 90 | # project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 91 | # ) 92 | 93 | # def forward(self, x): 94 | # nfeats = torch.zeros_like(x.F) 95 | # coords = x.C 96 | # # cmaps = x.cmaps 97 | # try: 98 | # cmaps = x.cmaps 99 | # kmaps = x.kmaps 100 | # except: 101 | # cmaps = x._caches.cmaps 102 | # kmaps = x._caches.kmaps 103 | # stride = x.stride 104 | # batch_inx = x.C[:, -1].unique() 105 | 106 | # for i in batch_inx: 107 | # indices = x.C[:, -1] == i 108 | # f = x.F[indices] 109 | # out = self.net(f) 110 | # nfeats[indices] = out.squeeze(0) 111 | 112 | # output = SparseTensor(coords=coords, feats=nfeats, stride=stride) 113 | # try: 114 | # output.cmaps = cmaps 115 | # output.kmaps = kmaps 116 | # except: 117 | # output._caches.cmaps = cmaps 118 | # output._caches.kmaps = kmaps 119 | # return output 120 | 121 | 122 | class CrossAttention(nn.Module): 123 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 124 | super().__init__() 125 | inner_dim = dim_head * heads 126 | context_dim = default(context_dim, query_dim) 127 | 128 | self.scale = dim_head**-0.5 129 | self.heads = heads 130 | 131 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 132 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 133 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 134 | 135 | self.to_out = nn.Sequential( 136 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 137 | ) 138 | 139 | def forward(self, x, context=None, mask=None): 140 | nfeats = torch.zeros_like(x.F) 141 | coords = x.C 142 | # cmaps = x.cmaps 143 | try: 144 | cmaps = x.cmaps 145 | kmaps = x.kmaps 146 | except: 147 | cmaps = x._caches.cmaps 148 | kmaps = x._caches.kmaps 149 | stride = x.stride 150 | 151 | h = self.heads 152 | 153 | # loop over batch 154 | batch_inx = x.C[:, -1].unique() 155 | for i in batch_inx: 156 | context_i = context[i : i + 1] if context is not None else None 157 | indices = x.C[:, -1] == i 158 | f = x.F[indices] 159 | f = f.unsqueeze(0) 160 | 161 | q = 
self.to_q(f) 162 | context_i = default(context_i, f) 163 | k = self.to_k(context_i) 164 | v = self.to_v(context_i) 165 | 166 | q, k, v = map( 167 | lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v) 168 | ) 169 | 170 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 171 | 172 | if exists(mask): 173 | mask = rearrange(mask, "b ... -> b (...)") 174 | max_neg_value = -torch.finfo(sim.dtype).max 175 | mask = repeat(mask, "b j -> (b h) () j", h=h) 176 | sim.masked_fill_(~mask, max_neg_value) 177 | 178 | # attention, what we cannot get enough of 179 | attn = sim.softmax(dim=-1) 180 | 181 | out = einsum("b i j, b j d -> b i d", attn, v) 182 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 183 | out = self.to_out(out) 184 | 185 | nfeats[indices] = out.squeeze(0) 186 | 187 | output = SparseTensor(coords=coords, feats=nfeats, stride=stride) 188 | try: 189 | output.cmaps = cmaps 190 | output.kmaps = kmaps 191 | except: 192 | output._caches.cmaps = cmaps 193 | output._caches.kmaps = kmaps 194 | return output 195 | 196 | 197 | class BasicTransformerBlock(nn.Module): 198 | def __init__( 199 | self, 200 | dim, 201 | n_heads, 202 | d_head, 203 | dropout=0.0, 204 | context_dim=None, 205 | gated_ff=True, 206 | checkpoint=True, 207 | ): 208 | super().__init__() 209 | self.attn1 = CrossAttention( 210 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout 211 | ) # is a self-attention 212 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 213 | self.attn2 = CrossAttention( 214 | query_dim=dim, 215 | context_dim=context_dim, 216 | heads=n_heads, 217 | dim_head=d_head, 218 | dropout=dropout, 219 | ) # is self-attn if context is none 220 | self.norm1 = LayerNorm(dim) 221 | self.norm2 = LayerNorm(dim) 222 | self.norm3 = LayerNorm(dim) 223 | self.checkpoint = checkpoint 224 | 225 | def forward(self, x, context=None): 226 | return checkpoint( 227 | self._forward, (x, context), self.parameters(), self.checkpoint 228 | ) 229 | 230 | def _forward(self, x, context=None): 231 | x = self.attn1(self.norm1(x)) + x 232 | x = self.attn2(self.norm2(x), context=context) + x 233 | x = self.ff(self.norm3(x)) + x 234 | return x 235 | 236 | 237 | class SpatialTransformer(nn.Module): 238 | """ 239 | Transformer block for image-like data. 240 | First, project the input (aka embedding) 241 | and reshape to b, t, d. 242 | Then apply standard transformer action. 
243 | Finally, reshape to image 244 | """ 245 | 246 | def __init__( 247 | self, 248 | in_channels, 249 | n_heads, 250 | d_head, 251 | depth=1, 252 | dropout=0.0, 253 | context_dim=None, 254 | positional_encoding=False, 255 | spatial_dim=None, 256 | **args 257 | ): 258 | super().__init__() 259 | self.in_channels = in_channels 260 | inner_dim = n_heads * d_head 261 | self.norm = Normalize(in_channels) 262 | 263 | self.proj_in = spnn.Conv3d(in_channels, inner_dim, kernel_size=1, stride=1) 264 | 265 | self.transformer_blocks = nn.ModuleList( 266 | [ 267 | BasicTransformerBlock( 268 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 269 | ) 270 | for d in range(depth) 271 | ] 272 | ) 273 | 274 | self.proj_out = zero_module( 275 | spnn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1) 276 | ) 277 | self.use_position_encoding = positional_encoding 278 | if positional_encoding: 279 | self.positional_embedding = nn.Parameter( 280 | torch.randn(*spatial_dim, inner_dim) / inner_dim**0.5 281 | ) 282 | 283 | def forward(self, x, context=None): 284 | # note: if no context is given, cross-attention defaults to self-attention 285 | # b, c, h, w, l = x.shape 286 | x_in = x 287 | x = self.norm(x) 288 | x = self.proj_in(x) 289 | # x = rearrange(x, 'b c h w l -> b (h w l) c') 290 | if self.use_position_encoding: 291 | pos_emb = self.positional_embedding[ 292 | x.C[:, 1].to(torch.long), 293 | x.C[:, 2].to(torch.long), 294 | x.C[:, 3].to(torch.long), 295 | :, 296 | ] 297 | x.F = x.F + pos_emb 298 | for block in self.transformer_blocks: 299 | x = block(x, context=context) 300 | # x = rearrange(x, 'b (h w l) c -> b c h w l', h=h, w=w, l=l) 301 | x = self.proj_out(x) 302 | return x + x_in 303 | -------------------------------------------------------------------------------- /model/utils/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | ( 16 | torch.tensor(0, dtype=torch.int) 17 | if use_num_upates 18 | else torch.tensor(-1, dtype=torch.int) 19 | ), 20 | ) 21 | 22 | for name, p in model.named_parameters(): 23 | if p.requires_grad: 24 | s_name = name.replace(".", "") 25 | self.m_name2s_name.update({name: s_name}) 26 | self.register_buffer(s_name, p.clone().detach().data) 27 | 28 | self.collected_params = [] 29 | 30 | def forward(self, model): 31 | decay = self.decay 32 | 33 | if self.num_updates >= 0: 34 | self.num_updates += 1 35 | decay = min( 36 | self.decay, 37 | torch.true_divide((1 + self.num_updates), (10 + self.num_updates)), 38 | ) 39 | # decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 40 | 41 | one_minus_decay = 1.0 - decay 42 | 43 | with torch.no_grad(): 44 | m_param = dict(model.named_parameters()) 45 | shadow_params = dict(self.named_buffers()) 46 | 47 | for key in m_param: 48 | if m_param[key].requires_grad: 49 | sname = self.m_name2s_name[key] 50 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 51 | shadow_params[sname].sub_( 52 | one_minus_decay * (shadow_params[sname] - m_param[key]) 53 | ) 54 | else: 55 | assert not key in self.m_name2s_name 56 | 57 | def copy_to(self, model): 58 
| m_param = dict(model.named_parameters()) 59 | shadow_params = dict(self.named_buffers()) 60 | for key in m_param: 61 | if m_param[key].requires_grad: 62 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 63 | else: 64 | assert not key in self.m_name2s_name 65 | 66 | def store(self, parameters): 67 | """ 68 | Save the current parameters for restoring later. 69 | Args: 70 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 71 | temporarily stored. 72 | """ 73 | self.collected_params = [param.clone() for param in parameters] 74 | 75 | def restore(self, parameters): 76 | """ 77 | Restore the parameters stored with the `store` method. 78 | Useful to validate the model with EMA parameters without affecting the 79 | original optimization process. Store the parameters before the 80 | `copy_to` method. After validation (or model saving), use this to 81 | restore the former parameters. 82 | Args: 83 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 84 | updated with the stored parameters. 85 | """ 86 | for c_param, param in zip(self.collected_params, parameters): 87 | param.data.copy_(c_param.data) 88 | -------------------------------------------------------------------------------- /model/utils/global_mapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class GlobalMapper: 5 | def __init__( 6 | self, 7 | old_origin, 8 | dim_size, 9 | voxel_size=0.04, 10 | default_value=1.0, 11 | device=torch.device("cpu"), 12 | ) -> None: 13 | self.voxel_default_value = default_value 14 | self.scene_map = torch.ones(dim_size).to(device) * self.voxel_default_value 15 | self.dim = dim_size 16 | self.origin = ( 17 | torch.div(old_origin, voxel_size, rounding_mode="floor") 18 | ) * voxel_size 19 | self.voxel_size = voxel_size 20 | self.part_id_info = {} 21 | self.device = device 22 | 23 | def to(self, device): 24 | self.scene_map = self.scene_map.to(device) 25 | self.device = device 26 | 27 | def get_overlap_index(self, get_origin, get_dim): 28 | get_origin = torch.tensor([*get_origin], dtype=torch.int32) 29 | get_dim = torch.tensor([*get_dim], dtype=torch.int32) 30 | self_dim = torch.tensor([*self.dim], dtype=torch.int32) 31 | 32 | if ((get_origin - self_dim) >= 0).any(): 33 | return None 34 | if ((get_origin + get_dim) <= 0).any(): 35 | return None 36 | inx_end = torch.min(get_origin + get_dim, self_dim) 37 | inx_st = torch.max(get_origin, torch.tensor((0, 0, 0))) 38 | get_inx_st = torch.max(-get_origin, torch.tensor((0, 0, 0))) 39 | get_inx_ed = torch.min(self_dim - get_origin, get_dim) 40 | assert (inx_end - inx_st == get_inx_ed - get_inx_st).any() 41 | return (inx_st, inx_end), (get_inx_st, get_inx_ed) 42 | 43 | def get_scene_map(self): 44 | return { 45 | "map": self.scene_map, 46 | "origin": self.origin, 47 | "voxel_size": self.voxel_size, 48 | } 49 | 50 | def update(self, part_id, partial_vol, partial_origin, mode="assign"): 51 | voxel_origin = (partial_origin - self.origin) / self.voxel_size 52 | # print(voxel_origin) 53 | voxel_origin = torch.round(voxel_origin).to(torch.int32) 54 | part_info = { 55 | "part_voxel_origin": voxel_origin, 56 | "part_origin": voxel_origin * self.voxel_size + self.origin, 57 | "part_dim": partial_vol.shape, 58 | } 59 | if part_id not in self.part_id_info: 60 | self.part_id_info[part_id] = part_info 61 | else: 62 | old_part_info = self.part_id_info[part_id] 63 | for k, v in old_part_info: 64 | assert part_info[k] == v 65 | 66 | indices = 
self.get_overlap_index(voxel_origin, partial_vol.shape) 67 | assert indices is not None 68 | (a_st, a_ed), (b_st, b_ed) = indices 69 | if mode == "assign": 70 | self.scene_map[a_st[0] : a_ed[0], a_st[1] : a_ed[1], a_st[2] : a_ed[2]] = ( 71 | partial_vol[b_st[0] : b_ed[0], b_st[1] : b_ed[1], b_st[2] : b_ed[2]] 72 | ) 73 | elif mode == "add": 74 | self.scene_map[ 75 | a_st[0] : a_ed[0], a_st[1] : a_ed[1], a_st[2] : a_ed[2] 76 | ] += partial_vol[b_st[0] : b_ed[0], b_st[1] : b_ed[1], b_st[2] : b_ed[2]] 77 | elif mode == "random": 78 | a = self.scene_map[a_st[0] : a_ed[0], a_st[1] : a_ed[1], a_st[2] : a_ed[2]] 79 | b = partial_vol[b_st[0] : b_ed[0], b_st[1] : b_ed[1], b_st[2] : b_ed[2]] 80 | c = torch.rand_like(a) 81 | d = (a != self.voxel_default_value) * c 82 | # this option works than the previous average!!!! 83 | 84 | self.scene_map[a_st[0] : a_ed[0], a_st[1] : a_ed[1], a_st[2] : a_ed[2]] = ( 85 | a * d + (1 - d) * b 86 | ) 87 | 88 | else: 89 | raise NotImplementedError 90 | 91 | def get(self, part_id, default_value=0): 92 | assert part_id in self.part_id_info 93 | part_info = self.part_id_info[part_id] 94 | voxel_origin = part_info["part_voxel_origin"] 95 | part_dim = part_info["part_dim"] 96 | indices = self.get_overlap_index(voxel_origin, part_dim) 97 | assert indices is not None 98 | (a_st, a_ed), (b_st, b_ed) = indices 99 | target = torch.ones(part_dim).to(self.device) * default_value 100 | target[b_st[0] : b_ed[0], b_st[1] : b_ed[1], b_st[2] : b_ed[2]] = ( 101 | self.scene_map[a_st[0] : a_ed[0], a_st[1] : a_ed[1], a_st[2] : a_ed[2]] 102 | ) 103 | return target 104 | 105 | def get_voxel_part(self, voxel_origin, voxel_dim, default_value): 106 | voxel_origin = voxel_origin 107 | part_dim = voxel_dim 108 | indices = self.get_overlap_index(voxel_origin, part_dim) 109 | assert indices is not None 110 | (a_st, a_ed), (b_st, b_ed) = indices 111 | target = torch.ones(part_dim).to(self.device) * default_value 112 | target[b_st[0] : b_ed[0], b_st[1] : b_ed[1], b_st[2] : b_ed[2]] = ( 113 | self.scene_map[a_st[0] : a_ed[0], a_st[1] : a_ed[1], a_st[2] : a_ed[2]] 114 | ) 115 | return target 116 | -------------------------------------------------------------------------------- /model/utils/torch_sparse_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torchsparse 5 | from torchsparse import SparseTensor 6 | 7 | 8 | def sparse_cat(T1, T2, dim): 9 | assert T1.stride == T2.stride 10 | t1, m1 = to_dense(T1) 11 | t2, m2 = to_dense(T2) 12 | max_size, _ = torch.stack( 13 | [torch.tensor(t1.shape[2:]), torch.tensor(t2.shape[2:])] 14 | ).max(dim=0) 15 | 16 | t1_ = torch.zeros(list(t1.shape[:2]) + list(max_size), device=t1.device) 17 | t2_ = torch.zeros(list(t2.shape[:2]) + list(max_size), device=t2.device) 18 | m1_ = torch.zeros(list(m1.shape[:2]) + list(max_size), device=m1.device) 19 | m2_ = torch.zeros(list(m2.shape[:2]) + list(max_size), device=m2.device) 20 | 21 | t1_[:, :, : t1.shape[2], : t1.shape[3], : t1.shape[4]] += t1 22 | t2_[:, :, : t2.shape[2], : t2.shape[3], : t2.shape[4]] += t2 23 | m1_[:, :, : m1.shape[2], : m1.shape[3], : m1.shape[4]] += m1 24 | m2_[:, :, : m2.shape[2], : m2.shape[3], : m2.shape[4]] += m2 25 | 26 | t = torch.cat([t1_, t2_], dim=dim) 27 | m1_ = m1_.to(torch.bool) 28 | m2_ = m2_.to(torch.bool) 29 | m = m1_ | m2_ 30 | out = to_sparse(t, stride=T1.stride, mask=m, spatial_range=T1.spatial_range) 31 | 32 | return out 33 | 34 | 35 | def to_dense(input: SparseTensor): 36 | coords, 
feats, stride = input.coords, input.feats, input.stride 37 | coords = coords.t().long() 38 | if torchsparse.__version__ == "2.1.0": 39 | pass 40 | else: 41 | coords[:3] = ( 42 | coords[:3] / torch.tensor(stride).reshape(-1, 1).long().to(coords) 43 | ).long() 44 | 45 | coalesce = torch.sparse_coo_tensor(coords, feats).coalesce() 46 | output = coalesce.to_dense() 47 | indices = coalesce.indices().t() 48 | if torchsparse.__version__ == "2.1.0": 49 | # B * W * H * L * C -> B * C * W * H * L 50 | output = output.permute(0, 4, 1, 2, 3).contiguous() 51 | b, w, h, l = indices[:, 0], indices[:, 1], indices[:, 2], indices[:, 3] 52 | else: 53 | # W * H * L * B * C -> B * C * W * H * L 54 | output = output.permute(3, 4, 0, 1, 2).contiguous() 55 | w, h, l, b = indices[:, 0], indices[:, 1], indices[:, 2], indices[:, 3] 56 | 57 | B, C, W, H, L = output.shape 58 | mask = torch.zeros([B, 1, W, H, L], device=feats.device) 59 | mask[b, :, w, h, l] = 1 60 | return output, mask 61 | 62 | 63 | def to_sparse(x, stride=None, mask=None, spatial_range=None): 64 | if stride is None: 65 | stride = (1, 1, 1) 66 | if mask is None: 67 | C = x.sum(dim=1).nonzero() 68 | else: 69 | C = mask.sum(dim=1).nonzero() 70 | b, w, h, l = C[:, 0], C[:, 1], C[:, 2], C[:, 3] 71 | F = x[b, :, w, h, l] 72 | if torchsparse.__version__ == "2.1.0": 73 | C = torch.stack([b, w, h, l]).t().int() 74 | else: 75 | C = torch.stack([w * stride[0], h * stride[1], l * stride[2], b]).t().int() 76 | 77 | if torchsparse.__version__ == "2.1.0": 78 | if spatial_range is None: 79 | spatial_range = [x.shape[0]] + list(x.shape[2:]) 80 | out = SparseTensor(F, C, stride=stride, spatial_range=spatial_range) 81 | else: 82 | out = SparseTensor(F, C, stride=stride) 83 | return out 84 | 85 | 86 | def inherit_sparse_tensor(x, coord, feat): 87 | if torchsparse.__version__ == "2.1.0": 88 | output = SparseTensor( 89 | coords=coord, feats=feat, stride=x.stride, spatial_range=x.spatial_range 90 | ) 91 | cmaps = x._caches.cmaps 92 | kmaps = x._caches.kmaps 93 | output._caches.cmaps = cmaps 94 | output._caches.kmaps = kmaps 95 | else: 96 | output = SparseTensor(coords=coord, feats=feat, stride=x.stride) 97 | cmaps = x.cmaps 98 | kmaps = x.kmaps 99 | output.cmaps = cmaps 100 | output.kmaps = kmaps 101 | return output 102 | 103 | 104 | def get_batch_dim(): 105 | if torchsparse.__version__ == "2.1.0": 106 | return 0 107 | else: 108 | return -1 109 | 110 | 111 | class GroupNorm(nn.GroupNorm): 112 | 113 | def forward(self, input: SparseTensor) -> SparseTensor: 114 | coords, feats = input.coords, input.feats 115 | 116 | batch_dim = get_batch_dim() 117 | batch_size = torch.max(coords[:, batch_dim]).item() + 1 118 | 119 | num_channels = feats.shape[1] 120 | 121 | # PyTorch's GroupNorm function expects the input to be in (N, C, *) 122 | # format where N is batch size, and C is number of channels. "feats" 123 | # is not in that format. So, we extract the feats corresponding to 124 | # each sample, bring it to the format expected by PyTorch's GroupNorm 125 | # function, and invoke it. 
126 | nfeats = torch.zeros_like(feats) 127 | for k in range(batch_size): 128 | indices = coords[:, batch_dim] == k 129 | bfeats = feats[indices] 130 | bfeats = bfeats.transpose(0, 1).reshape(1, num_channels, -1) 131 | bfeats = super().forward(bfeats) 132 | bfeats = bfeats.reshape(num_channels, -1).transpose(0, 1) 133 | nfeats[indices] = bfeats 134 | 135 | output = inherit_sparse_tensor(input, coords, nfeats) 136 | 137 | return output 138 | -------------------------------------------------------------------------------- /readme/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkiraHero/diffindscene/0ec501b648f627567a1933ad0b518e44334bbf4e/readme/method.png -------------------------------------------------------------------------------- /readme/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkiraHero/diffindscene/0ec501b648f627567a1933ad0b518e44334bbf4e/readme/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml==6.0 2 | easydict==1.10 3 | einops==0.6.1 4 | imageio==2.26.0 5 | torch==1.13.1 6 | numpy==1.23.5 7 | torchvision==0.14.1 8 | tensorboardx==2.6 9 | pandas==1.5.3 10 | scikit-image==0.20.0 11 | trimesh==3.20.2 12 | opencv-python==4.7.0.72 13 | plyfile==0.9 14 | tqdm==4.65.0 15 | # git+https://github.com/mit-han-lab/torchsparse.git 16 | 17 | -------------------------------------------------------------------------------- /sketch_samples/sketch_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkiraHero/diffindscene/0ec501b648f627567a1933ad0b518e44334bbf4e/sketch_samples/sketch_1.png -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | path = os.path.dirname(os.path.abspath(__file__)) 4 | py_list = [] 5 | for root, dirs, files in os.walk(path, topdown=False): 6 | for name in files: 7 | if name.endswith(".py") and not name.endswith("__init__.py"): 8 | rel_dir = os.path.relpath(root, path) 9 | if rel_dir != ".": 10 | rel_file = os.path.join(rel_dir, name) 11 | else: 12 | rel_file = name 13 | py_list.append(rel_file) 14 | for py in py_list: 15 | mod_name = ".".join([__name__, *(py.split("/"))]) 16 | mod_name = mod_name[:-3] 17 | mod = __import__(mod_name, fromlist=[mod_name]) 18 | classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)] 19 | for cls in classes: 20 | if "trainer" in str(cls): 21 | globals()[cls.__name__] = cls 22 | -------------------------------------------------------------------------------- /trainer/sketchVAE_trainner.py: -------------------------------------------------------------------------------- 1 | from trainer.trainer_base import TrainerBase 2 | import torch 3 | import logging 4 | from utils.logger.basic_logger import LogTracker 5 | import os 6 | import pickle 7 | 8 | 9 | class SKetchVAETrainer(TrainerBase): 10 | def __init__(self, config): 11 | super().__init__() 12 | assert config.config_type in ["training", "testing"] 13 | if config.config_type == "training": 14 | self.optimizer_config = config["optimizer"] 15 | self.max_epoch = config["epoch"] 16 | self.enable_val = config.enable_val 17 | self.val_interval = 
config.val_interval 18 | self.train_metrics = LogTracker("total_loss", phase="train") 19 | self.train_log_dir = None 20 | elif config.config_type == "testing": 21 | self.test_config = config 22 | self.test_log_dir = config.test_log_dir 23 | if not os.path.exists(self.test_log_dir): 24 | os.makedirs(self.test_log_dir) 25 | 26 | if not self.distributed: 27 | self.device = torch.device(config["device"]) 28 | 29 | def set_optimizer(self, optimizer_config): 30 | model = self.model 31 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): 32 | model = self.model.module 33 | optimizer_ref = torch.optim.__dict__[self.optimizer_config[0]["type"]] 34 | self.optimizer = optimizer_ref( 35 | model.get_trainable_parameters(), **optimizer_config[0]["paras"] 36 | ) 37 | logging.info("[Optimizer Paras]" + str(optimizer_config[0]["paras"])) 38 | 39 | def run(self): 40 | # torch.autograd.set_detect_anomaly(True) 41 | if not self.check_ready(): 42 | raise ModuleNotFoundError( 43 | "The trainer not ready. Plz set model/dataset first" 44 | ) 45 | super().run() 46 | model = self.model 47 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): 48 | model = self.model.module 49 | self.dataset.set_level("first") 50 | 51 | for epoch in range(self.max_epoch): 52 | if self.enable_val: 53 | if not self.distributed or self.rank == 0: 54 | if self.train_log_dir is None: 55 | self.train_log_dir = self.logger.get_log_dir() 56 | if epoch > 0 and epoch % self.val_interval == 0: 57 | self.is_val = True 58 | logging.info( 59 | "\n\n\n------------------------------Validation Start------------------------------" 60 | ) 61 | self.val_step() 62 | 63 | self.is_val = False 64 | self.epoch = epoch 65 | model.train() 66 | self.train_metrics.reset() 67 | self.data_loader.dataset.epoch = epoch 68 | for step, data in enumerate(self.data_loader): 69 | self.data_loader.dataset.load_data_to_gpu(data, self.device) 70 | 71 | self.optimizer.zero_grad() 72 | 73 | sketch = data["bev_sketch"].squeeze(-1) 74 | dec, posterior = model(sketch) 75 | loss = model.get_loss(sketch, dec, posterior) 76 | total_loss = loss["total"] 77 | total_loss.backward() 78 | 79 | self.optimizer.step() 80 | 81 | # print current status and logging 82 | if not self.distributed or self.rank == 0: 83 | logging.info( 84 | f"[loss] Epoch={epoch}/{self.max_epoch}, step={step}/{len(self.data_loader)}\t" 85 | f"global_step={self.global_step}\t" 86 | f"loss={total_loss:.6f}\t" 87 | # f'tsdf_l1={tsdf_l1:.6f}\t' 88 | ) 89 | self.logger.log_data("loss", total_loss.item(), True) 90 | if step == 0: 91 | self.logger.log_image("gt", sketch[0]) 92 | self.logger.log_image("rec", dec[0]) 93 | 94 | self.step = step 95 | self.global_step += 1 96 | 97 | if not self.distributed or self.rank == 0: 98 | self.logger.log_model_params(self.model, optimizers=self.optimizer) 99 | 100 | def sum_val_dict(self, d): 101 | summary = {} 102 | for i in d: 103 | for k in i: 104 | if k not in summary: 105 | summary[k] = [] 106 | summary[k] += [i[k]] 107 | return summary 108 | 109 | def val_step(self): 110 | raise NotImplementedError 111 | 112 | def run_test(self): 113 | raise NotImplementedError 114 | 115 | def load_state(self, log_file): 116 | if not os.path.exists(log_file): 117 | raise FileNotFoundError(f"file not exist:{log_file}") 118 | params = None 119 | try: 120 | params = torch.load(log_file, map_location=lambda storage, loc: storage) 121 | except Exception as e: 122 | with open(log_file, "rb") as f: 123 | params = pickle.load(f) 124 | if params is not None: 125 | if 
self.model is not None: 126 | self.model.load_model_paras(params) 127 | else: 128 | raise AssertionError("model does not exist.") 129 | logging.info(f"loaded model params:{log_file}") 130 | # todo: retrive all status including: optimizer epoch log folder... 131 | status = params["status"] 132 | self.epoch = status["epoch"] 133 | self.global_step = status["global_step"] 134 | # if 'opt_paras' in params: 135 | # for opt, opt_paras in zip(self.optimizer, params['opt_paras']): 136 | # opt.load_state_dict(opt_paras) 137 | else: 138 | raise AssertionError("Fail to load params for model.") 139 | -------------------------------------------------------------------------------- /trainer/trainer_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import logging 4 | from model.model_base import ModelBase 5 | from dataset.dataset_base import DatasetBase 6 | from utils.logger.basic_logger import BasicLogger 7 | from utils.torch_distributed_config import init_distributed_device 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class TrainerBase: 13 | def __init__(self): 14 | self.model = None 15 | self.dataset = None 16 | self.data_loader = None 17 | self.val_dataset = None 18 | self.val_data_loader = None 19 | self.test_dataset = None 20 | self.test_data_loader = None 21 | 22 | self.enable_val = False 23 | self.device = None 24 | self.optimizer = None 25 | self.optimizer_config = None 26 | self.max_epoch = 0 27 | self.epoch = 0 28 | self.step = 0 29 | self.global_step = 0 30 | self.logger = None 31 | self.distributed = False 32 | self.total_gpus = 0 33 | self.rank = None 34 | self.sync_bn = False 35 | self.launcher = "none" 36 | self.is_val = False 37 | 38 | def get_training_status(self): 39 | training_status = { 40 | "max_epoch": self.max_epoch, 41 | "epoch": self.epoch, 42 | "step": self.step, 43 | "global_step": self.global_step, 44 | } 45 | return training_status 46 | 47 | # to be completed, accoring to running phase: train/test/val 48 | def check_ready(self): 49 | if self.model is None: 50 | return False 51 | if self.dataset is None: 52 | return False 53 | return True 54 | 55 | def run(self): 56 | if not self.check_ready(): 57 | raise ModuleNotFoundError( 58 | "The trainer not ready. Plz set model/dataset first" 59 | ) 60 | if self.optimizer is None: 61 | self.set_optimizer(self.optimizer_config) 62 | if self.distributed and self.sync_bn: 63 | self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model) 64 | self.model.set_device(self.device) 65 | self.data_loader = self.dataset.get_data_loader(distributed=self.distributed) 66 | if self.enable_val: 67 | self.val_data_loader = self.val_dataset.get_data_loader(distributed=False) 68 | if self.distributed: 69 | if not any((p.requires_grad for p in self.model.parameters())): 70 | logging.warning( 71 | "DistributedDataParallel [will not be used] when a module doesn't have any parameter that requires a gradient." 72 | ) 73 | else: 74 | self.model = nn.parallel.DistributedDataParallel( 75 | self.model, device_ids=[self.rank % torch.cuda.device_count()] 76 | ) 77 | # why not set self.model = model.module? 78 | # In case the need to use model(data) directly. 
79 | # [Instruction] add code in sub class and using super to run this function for general preparation 80 | 81 | def run_test(self): 82 | self.model.set_device(self.device) 83 | self.test_data_loader = self.test_dataset.get_data_loader( 84 | distributed=self.distributed 85 | ) 86 | # if self.distributed: 87 | # self.model = nn.parallel.DistributedDataParallel(self.model, 88 | # device_ids=[self.rank % torch.cuda.device_count()]) 89 | 90 | def set_model(self, model): 91 | if not isinstance(model, ModelBase): 92 | raise TypeError 93 | self.model = model 94 | if self.optimizer is None: 95 | if self.optimizer_config is not None: 96 | self.set_optimizer(self.optimizer_config) 97 | 98 | def set_dataset(self, dataset): 99 | if not isinstance(dataset, DatasetBase): 100 | raise TypeError 101 | self.dataset = dataset 102 | 103 | def set_val_dataset(self, dataset): 104 | if not isinstance(dataset, DatasetBase): 105 | raise TypeError 106 | self.val_dataset = dataset 107 | 108 | def set_test_dataset(self, dataset): 109 | if not isinstance(dataset, DatasetBase): 110 | raise TypeError 111 | self.test_dataset = dataset 112 | 113 | def set_optimizer(self, optimizer_config): 114 | raise NotImplementedError 115 | 116 | def load_state(self, log_file): 117 | if not os.path.exists(log_file): 118 | raise FileNotFoundError(f"file not exist:{log_file}") 119 | params = None 120 | try: 121 | params = torch.load(log_file, map_location=lambda storage, loc: storage) 122 | except Exception as e: 123 | with open(log_file, "rb") as f: 124 | params = pickle.load(f) 125 | if params is not None: 126 | if self.model is not None: 127 | self.model.load_model_paras(params) 128 | else: 129 | raise AssertionError("model does not exist.") 130 | logging.info(f"loaded model params:{log_file}") 131 | # todo: retrive all status including: optimizer epoch log folder... 
132 | status = params["status"] 133 | else: 134 | raise AssertionError("Fail to load params for model.") 135 | 136 | def set_logger(self, logger): 137 | if not isinstance(logger, BasicLogger): 138 | raise TypeError("logger must be with the type: BasicLogger") 139 | self.logger = logger 140 | self.logger.register_status_hook(self.get_training_status) 141 | 142 | def config_distributed_computing(self, launcher, tcp_port=None, local_rank=None): 143 | self.launcher = launcher 144 | if self.launcher == "none": 145 | self.distributed = False 146 | self.total_gpus = 1 147 | else: 148 | self.total_gpus, self.rank = init_distributed_device( 149 | self.launcher, tcp_port, local_rank, backend="nccl" 150 | ) 151 | self.distributed = True 152 | device_id = self.rank % torch.cuda.device_count() 153 | self.device = torch.device(device_id) 154 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkiraHero/diffindscene/0ec501b648f627567a1933ad0b518e44334bbf4e/utils/.DS_Store -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkiraHero/diffindscene/0ec501b648f627567a1933ad0b518e44334bbf4e/utils/__init__.py -------------------------------------------------------------------------------- /utils/config/Configuration.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import yaml 4 | import shutil 5 | import datetime 6 | from easydict import EasyDict 7 | 8 | """ 9 | treat the configuration as a tree 10 | """ 11 | 12 | 13 | def default(config, attr, default_value): 14 | if attr in config: 15 | return config[attr] 16 | else: 17 | return default_value 18 | 19 | 20 | # todo: control the access of members 21 | class Configuration: 22 | def __init__(self): 23 | self.config_root_dir = None 24 | self.root_config = None 25 | self.expanded_config = None 26 | self.all_related_config_files = [] 27 | self.dir_checked = False 28 | 29 | self._dataset_config = None 30 | self._training_config = None 31 | self._testing_config = None 32 | self._logging_config = None 33 | self._model_config = None 34 | self._extra_config = {"config_type": "extra"} 35 | 36 | @property 37 | def dataset_config(self): 38 | return EasyDict(self._dataset_config) 39 | 40 | @property 41 | def training_config(self): 42 | return EasyDict(self._training_config) 43 | 44 | @property 45 | def testing_config(self): 46 | return EasyDict(self._testing_config) 47 | 48 | @property 49 | def logging_config(self): 50 | return EasyDict(self._logging_config) 51 | 52 | @property 53 | def model_config(self): 54 | return EasyDict(self._model_config) 55 | 56 | @property 57 | def extra_config(self): 58 | return EasyDict(self._extra_config) 59 | 60 | def check_config_dir(self, config_dir): 61 | if not os.path.isdir(config_dir): 62 | raise IsADirectoryError(f"{config_dir} is not a valid directory.") 63 | # check subconfig dir 64 | if not os.path.isdir(os.path.join(config_dir, "dataset")): 65 | raise IsADirectoryError(f"{config_dir}/dataset is not a valid directory.") 66 | if not os.path.isdir(os.path.join(config_dir, "model")): 67 | raise IsADirectoryError(f"{config_dir}/model is not a valid directory.") 68 | if not os.path.isfile(os.path.join(config_dir, "root_config.yaml")): 69 | 
raise IsADirectoryError( 70 | f"{config_dir}/root_config.yaml is not a valid/existing file." 71 | ) 72 | self.dir_checked = True 73 | 74 | def load_config(self, config_dir): 75 | self.config_root_dir = config_dir 76 | self.check_config_dir(self.config_root_dir) 77 | if self.dir_checked: 78 | self._load_root_config_file("root_config.yaml") 79 | 80 | def get_complete_config(self): 81 | if self.expanded_config is not None: 82 | return self.expanded_config.copy() 83 | raise TypeError("no complete config found!") 84 | 85 | def get_shell_args_train(self): 86 | parser = argparse.ArgumentParser(description="arg parser") 87 | parser.add_argument( 88 | "--cfg_dir", 89 | required=True, 90 | type=str, 91 | default=None, 92 | help="specify the config for training", 93 | ) 94 | parser.add_argument( 95 | "--batch_size", 96 | type=int, 97 | default=None, 98 | required=False, 99 | help="batch size for training (in each process in distributed training)", 100 | ) 101 | parser.add_argument( 102 | "--epoch", 103 | type=int, 104 | default=None, 105 | required=False, 106 | help="number of epochs to train for", 107 | ) 108 | parser.add_argument( 109 | "--distributed", 110 | action="store_true", 111 | default=False, 112 | help="using distributed training", 113 | ) 114 | parser.add_argument( 115 | "--local_rank", 116 | type=int, 117 | default=0, 118 | help="local rank for distributed training", 119 | ) 120 | parser.add_argument( 121 | "--tcp_port", 122 | type=int, 123 | default=18888, 124 | help="tcp port for distributed training", 125 | ) 126 | parser.add_argument( 127 | "--launcher", 128 | choices=["none", "pytorch", "slurm"], 129 | default="none", 130 | help="select distributed training launcher", 131 | ) 132 | parser.add_argument( 133 | "--screen_log", 134 | type=str, 135 | default="scree_log", 136 | required=False, 137 | help="the file shell redirects to", 138 | ) 139 | parser.add_argument( 140 | "--log_dir", required=False, type=str, default=None, help="log dir" 141 | ) 142 | parser.add_argument( 143 | "--check_point_file", 144 | type=str, 145 | default=None, 146 | help="model checkpoint for pre-loading before training", 147 | ) 148 | args = parser.parse_args() 149 | return args 150 | 151 | def get_shell_args_test(self): 152 | parser = argparse.ArgumentParser(description="arg parser") 153 | parser.add_argument( 154 | "--cfg_dir", 155 | required=True, 156 | type=str, 157 | default=None, 158 | help="specify the config for training", 159 | ) 160 | parser.add_argument( 161 | "--batch_size", 162 | type=int, 163 | default=None, 164 | required=False, 165 | help="batch size for training (in each process in distributed training)", 166 | ) 167 | parser.add_argument( 168 | "--distributed", 169 | action="store_true", 170 | default=False, 171 | help="using distributed testing", 172 | ) 173 | parser.add_argument( 174 | "--local_rank", 175 | type=int, 176 | default=0, 177 | help="local rank for distributed testing", 178 | ) 179 | parser.add_argument( 180 | "--tcp_port", 181 | type=int, 182 | default=18888, 183 | help="tcp port for distributed testing", 184 | ) 185 | parser.add_argument( 186 | "--launcher", 187 | choices=["none", "pytorch", "slurm"], 188 | default="none", 189 | help="select distributed testing launcher", 190 | ) 191 | parser.add_argument( 192 | "--screen_log", type=str, default=None, help="the file shell redirects to" 193 | ) 194 | parser.add_argument( 195 | "--check_point_file", 196 | type=str, 197 | default=None, 198 | help="model checkpoint for pre-loading before testing", 199 | ) 200 | args = 
parser.parse_args() 201 | return args 202 | 203 | def _load_yaml(self, file): 204 | abs_path = os.path.join(self.config_root_dir, file) 205 | with open(abs_path, "r") as f: 206 | return yaml.safe_load(f) 207 | 208 | def _load_root_config_file(self, config_file): 209 | self.root_config = self._load_yaml(config_file) 210 | self.expanded_config = self.root_config.copy() 211 | self.all_related_config_files.append(config_file) 212 | self._expand_config(self.expanded_config) 213 | # set corresponding config 214 | if "model" in self.expanded_config.keys(): 215 | self._model_config = self.expanded_config["model"] 216 | self._model_config.update({"config_type": "model"}) 217 | if "dataset" in self.expanded_config.keys(): 218 | self._dataset_config = self.expanded_config["dataset"] 219 | self._dataset_config.update({"config_type": "dataset"}) 220 | if "training" in self.expanded_config.keys(): 221 | self._training_config = self.expanded_config["training"] 222 | self._training_config.update({"config_type": "training"}) 223 | if "testing" in self.expanded_config.keys(): 224 | self._testing_config = self.expanded_config["testing"] 225 | self._testing_config.update({"config_type": "testing"}) 226 | if "logging" in self.expanded_config.keys(): 227 | self._logging_config = self.expanded_config["logging"] 228 | self._logging_config.update({"config_type": "logging"}) 229 | 230 | def _expand_config(self, config_dict): 231 | if not self._expand_cur_config(config_dict): 232 | if isinstance(config_dict, dict): 233 | for i in config_dict.keys(): 234 | sub_config = config_dict[i] 235 | self._expand_config(sub_config) 236 | 237 | def _expand_cur_config(self, config_dict): 238 | if not isinstance(config_dict, dict): 239 | return False 240 | if "config_file" in config_dict.keys() and isinstance( 241 | config_dict["config_file"], str 242 | ): 243 | file_name = config_dict["config_file"] 244 | expanded = self._load_yaml(file_name) 245 | self._expand_config(expanded) 246 | self.all_related_config_files.append(file_name) 247 | config_dict["config_file"] = {"file_name": file_name, "expanded": expanded} 248 | return True 249 | return False 250 | 251 | def pack_configurations(self, _path): 252 | # all config file should be located in utils/config?? 
no 253 | # todo: pack config using expanded config 254 | shutil.copytree(self.config_root_dir, os.path.join(_path, "config")) 255 | 256 | @staticmethod 257 | def find_dict_node(target_dict, node_name): 258 | if not isinstance(target_dict, dict): 259 | raise TypeError 260 | res_parents = [] 261 | res = Configuration._find_node_subtree(target_dict, node_name, res_parents) 262 | 263 | def flat_parents_list(parents, output): 264 | if len(parents) > 1: 265 | output.append(parents[0]) 266 | else: 267 | return 268 | flat_parents_list(parents[1], output) 269 | 270 | output_parents = [] 271 | flat_parents_list(res_parents, output_parents) 272 | return res, output_parents 273 | 274 | def find_node(self, node_name): 275 | return Configuration.find_dict_node(self.expanded_config, node_name) 276 | 277 | @staticmethod 278 | def _find_node_subtree(cur_node, keyword, parents_log=None): 279 | if isinstance(parents_log, list): 280 | parents_log.append(keyword) 281 | if not isinstance(cur_node, dict): 282 | return None 283 | res = Configuration._find_node_cur(cur_node, keyword) 284 | if res is None: 285 | for i in cur_node.keys(): 286 | parents_log.clear() 287 | if isinstance(parents_log, list): 288 | parents_log.append(i) 289 | new_parents_log = [] 290 | parents_log.append(new_parents_log) 291 | res = Configuration._find_node_subtree( 292 | cur_node[i], keyword, new_parents_log 293 | ) 294 | if res is not None: 295 | return res 296 | return res 297 | 298 | @staticmethod 299 | def _find_node_cur(cur_node, keyword): 300 | if not isinstance(cur_node, dict): 301 | return None 302 | for i in cur_node.keys(): 303 | if i == keyword: 304 | return cur_node[i] 305 | return None 306 | 307 | def overwrite_value_by_keywords( 308 | self, parents_keywords_list, cur_keywords, new_value 309 | ): 310 | if not isinstance(self.expanded_config, dict): 311 | raise TypeError 312 | sub_dict_ref = self.expanded_config 313 | for key in parents_keywords_list: 314 | sub_dict_ref = sub_dict_ref[key] 315 | sub_dict_ref[cur_keywords] = new_value 316 | 317 | # only overwrite the first-found one on condition of equal keys 318 | def overwrite_config_by_shell_args(self, args): 319 | for name, value in args._get_kwargs(): 320 | if value is not None: 321 | node, parents = self.find_node(name) 322 | if node is not None: 323 | self.overwrite_value_by_keywords(parents, name, value) 324 | else: 325 | self._extra_config[name] = value 326 | -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_sketch_cond/dataset/ali3dfront.yaml: -------------------------------------------------------------------------------- 1 | dataset_class: Ali3DFront 2 | paras: 3 | mode: 'train' 4 | num_workers: 0 5 | shuffle: false 6 | version: 'v2' 7 | # data_root: '/home/xlju/3dfront' 8 | data_root: '/home/xlju/front3d_ini/new_npz' 9 | 10 | #################### 11 | ## useful config - vary with training level 12 | 13 | batch_size: 1 14 | 15 | # first level 16 | level_config: 17 | first: 18 | load_content: ['latent', 'tsdf', 'sketch'] 19 | transform: ['simpletrans'] 20 | second: 21 | load_content: ['latent'] 22 | transform: ['simpletrans'] 23 | third: 24 | load_content: ['tsdf','latent'] 25 | transform: ['simpletrans', 'simplecrop'] 26 | 27 | voxel_dim: [128, 128, 128] # only effect in third level 28 | batch_collate_func: 'batch_collate_latent_code' 29 | 30 | latent_dir: '/home/xlju/pro/diffs/script/data_latent_new_dataset' 31 | # modify latent_scale according to latents 32 | latent_scale: [0.0290, 0.1550] 33 | 
data_split_file: 34 | train: train_512_512_128.txt 35 | val: val.txt 36 | test: test.txt 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_sketch_cond/model/pyramid_occ_denoiser.yaml: -------------------------------------------------------------------------------- 1 | name: MSLDM 2 | paras: 3 | unet_model: "sparse" 4 | use_ema: True 5 | multi_restore_batch_size: 4 6 | 7 | ############## 8 | # first second third 9 | level: 'first' 10 | ############## 11 | 12 | add_noise_da: true 13 | classifier_free_guidance: true 14 | 15 | 16 | use_sketch_condition: True 17 | use_sketch_attention: True 18 | sketch_condition_manner: 'cat' 19 | 20 | 21 | # noise schedule 22 | noise_schedule: 23 | first: 24 | num_train_timesteps: 1000 25 | clip_sample: false 26 | beta_schedule: "cos" 27 | second: 28 | num_train_timesteps: 1000 29 | clip_sample: false 30 | beta_schedule: "cos" 31 | third: 32 | num_train_timesteps: 1000 33 | clip_sample: True 34 | beta_schedule: "cos" 35 | clip_sample_range: 3.0 36 | 37 | sketch_embedder: 38 | model_class: SketchVAE 39 | config_file: model/sketch_vae.yaml 40 | ckpt: "/home/xlju/pro/diffs/output/2023-11-29-16-06-47-sk_vae_finetune/model_paras_log/model_ckpt-epoth20-globalstep310794-2023-11-29-17-24-27.pt" 41 | 42 | # latent encoding 43 | first_stage_model: 44 | model_class: MSTSDFPVQGANNew 45 | config_file: model/tsdf_vqgan_retrain.yaml 46 | ckpt: "model_ckpt-epoth844-globalstep399056-2023-10-21-20-19-53.pt" 47 | ############################################################################ 48 | # diffusion network 49 | unet_model: 50 | first: 51 | ckpt: "model_ckpt-epoth979-globalstep258060val_ema-2023-12-08-15-37-28.pt" 52 | model_type: UNetModel 53 | model_args: 54 | dims: 3 55 | is_sparse: True 56 | use_bev: True 57 | in_channels: 8 58 | out_channels: 4 59 | model_channels: 64 60 | attention_resolutions: [1,2] #for text 61 | num_res_blocks: 2 62 | channel_mult: 63 | - 1 64 | - 2 65 | - 4 66 | use_spatial_transformer: True 67 | 68 | transformer_depth: 1 69 | context_dim: 256 70 | 71 | num_heads: 8 72 | use_checkpoint: True 73 | legacy: False 74 | use_position_encoding: True 75 | second: 76 | ckpt: "model_ckpt-epoth494-globalstep368775val_ema-2023-10-27-09-39-59.pt" 77 | model_type: UNetModel 78 | model_args: 79 | dims: 3 80 | is_sparse: True 81 | use_bev: False 82 | in_channels: 8 # with condition 83 | out_channels: 4 84 | model_channels: 64 85 | use_spatial_transformer: False 86 | attention_resolutions: [] 87 | num_res_blocks: 2 88 | channel_mult: 89 | - 1 90 | - 2 91 | - 4 92 | - 8 93 | num_head_channels: 32 94 | third: 95 | ckpt: "model_ckpt-epoth104-globalstep156450val_ema-2023-10-27-15-35-58.pt" 96 | model_type: UNetModel 97 | model_args: 98 | dims: 3 99 | is_sparse: True 100 | use_bev: False 101 | in_channels: 9 102 | out_channels: 1 103 | model_channels: 64 104 | use_spatial_transformer: False 105 | attention_resolutions: [] 106 | num_res_blocks: 2 107 | channel_mult: 108 | - 1 109 | - 2 110 | - 4 111 | - 8 112 | num_head_channels: 32 113 | -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_sketch_cond/model/sketch_vae.yaml: -------------------------------------------------------------------------------- 1 | name: SketchVAE 2 | paras: 3 | embed_dim: 256 4 | n_embed: 2048 5 | kl_weight: 0.0000045 6 | ddconfig: 7 | double_z: True 8 | z_channels: 16 9 | resolution: 256 10 | in_channels: 1 11 | out_ch: 1 12 | ch: 128 13 | 
ch_mult: [ 1,2,4,4] # num_down = len(ch_mult)-1 14 | num_res_blocks: 2 15 | attn_resolutions: [] 16 | dropout: 0.0 17 | -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_sketch_cond/model/tsdf_vqgan_retrain.yaml: -------------------------------------------------------------------------------- 1 | name: MSTSDFPVQGANNew 2 | paras: 3 | embed_dim: 4 4 | n_embed: 8192 5 | ddconfig1: 6 | double_z: False 7 | z_channels: 4 8 | resolution: 512 9 | in_channels: 1 10 | out_ch: 1 11 | ch: 32 12 | ch_mult: [1,2,4] # num_down = len(ch_mult)-1 13 | num_res_blocks: 2 14 | attn_resolutions: [] 15 | dropout: 0. 16 | ddconfig2: 17 | use_bev: False 18 | double_z: False 19 | z_channels: 4 20 | resolution: 512 21 | in_channels: 4 22 | out_ch: 4 23 | ch: 256 24 | ch_mult: [1,1] # num_down = len(ch_mult)-1 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0. 28 | 29 | temperature_scheduler_config: 30 | warm_up_steps: 0 31 | max_decay_steps: 1000001 32 | lr_start: 0.9 33 | lr_max: 0.9 34 | lr_min: 1.0e-06 35 | lossconfig: 36 | params: 37 | disc_conditional: False 38 | disc_in_channels: 1 39 | disc_start: 6000 40 | disc_weight: 0.2 41 | codebook_weight: 1.0 42 | rec_loss: 'l1' 43 | -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_sketch_cond/readme.md: -------------------------------------------------------------------------------- 1 | 2 | * Step 1: set `data_root` in `datadet/ali3dfront.yaml`, which points to the directory of processed `*.npz` files (TSDF). 3 | * Step 2: set `latent_dir` as the latent encoding of initial TSDF data from PatchVQGAN, and the `latent_scale` as the reciprocal of latents STD. 4 | * Step 3: set `level` in `model/pyramid_occ_denoiser.yaml` to set the training stage of the cascaded diffusion; set `use_sketch_condition` to use conditional / unconditional diffusion in the 1st stage; set `sketch_embedder/ckpt` as the checkpoint of `SketchVAE` model if `use_sketch_condition=True`; set `first_stage_model/ckpt` as the checkpoint of the PatchVQGAN model. 5 | * [For inference] set `unet_model/LEVEL/ckpt` as the checkpoints of unet models of different diffusion levels. 6 | * [For Training] set `ckpt` variable in `training` section of `root_config.yaml` for continuous training. 
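* [Optional] For a quick sanity check, the sample directory can also be loaded programmatically through the `Configuration` helper in `utils/config/Configuration.py`. The snippet below is only an illustration of that API (it is not part of the repo); it assumes the repository root is the working directory and is on `PYTHONPATH`, and the printed values simply echo the `testing` section of this sample's `root_config.yaml`.

```python
# Illustration only: inspect this sample config with utils/config/Configuration.py
from utils.config.Configuration import Configuration

config = Configuration()
# load_config expects dataset/, model/ and root_config.yaml inside the directory
config.load_config("utils/config/samples/cascaded_ldm_sketch_cond")

model_cfg = config.model_config    # 'model' section, with model/*.yaml expanded under 'config_file'
test_cfg = config.testing_config   # 'testing' section of root_config.yaml
print(test_cfg.trainer_class)      # CascadedLDMTrainer
print(test_cfg.level_seq)          # ['first', 'second', 'third']
print(test_cfg.sketch_image)       # sketch_samples/sketch_1.png
```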
-------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_sketch_cond/root_config.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | config_file: dataset/ali3dfront.yaml 3 | dataset_class: Ali3DFront 4 | logging: 5 | ckpt_eph_interval: 20 6 | path: ../output 7 | suffix: '1st-ema' 8 | model: 9 | config_file: model/pyramid_occ_denoiser.yaml 10 | model_class: MSLDM 11 | training: 12 | device: 'cuda:0' 13 | epoch: 5000 14 | optimizer: 15 | - name: opt 16 | paras: 17 | lr: 0.0001 18 | type: Adam 19 | trainer_class: CascadedLDMTrainer 20 | 21 | # just to save params 22 | enable_val: true 23 | val_interval: 20 24 | 25 | 26 | testing: 27 | device: 'cuda:0' 28 | mode: "sketch_cond" 29 | trainer_class: CascadedLDMTrainer 30 | save_mesh: true 31 | test_log_dir: "output/sketch_cond" 32 | operating_size: [64, 64, 16] 33 | 34 | # tips: the third stage may be very time-consuming when operating_size=[64, 64, 16], corresponding to a generation of 512*512*128 35 | # so it is better to check the first stage result, then use same seed to run all stages 36 | # level_seq: ['first'] 37 | 38 | level_seq: ['first', 'second', 'third'] 39 | seed: 28699 40 | sketch_image: "sketch_samples/sketch_1.png" 41 | 42 | -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_ucond/dataset/ali3dfront.yaml: -------------------------------------------------------------------------------- 1 | dataset_class: Ali3DFront 2 | paras: 3 | mode: 'train' 4 | num_workers: 0 5 | shuffle: false 6 | version: 'v2' 7 | # data_root: '/home/xlju/3dfront' 8 | data_root: '/home/xlju/front3d_ini/new_npz' 9 | 10 | #################### 11 | ## useful config - vary with training level 12 | 13 | batch_size: 1 14 | 15 | # first level 16 | level_config: 17 | first: 18 | load_content: ['latent', 'tsdf', 'sketch'] 19 | transform: ['simpletrans'] 20 | second: 21 | load_content: ['latent'] 22 | transform: ['simpletrans'] 23 | third: 24 | load_content: ['tsdf','latent'] 25 | transform: ['simpletrans', 'simplecrop'] 26 | 27 | voxel_dim: [128, 128, 128] # only effect in third level 28 | batch_collate_func: 'batch_collate_latent_code' 29 | 30 | latent_dir: '/home/xlju/pro/diffs/script/data_latent_new_dataset' 31 | # modify latent_scale according to latents 32 | latent_scale: [0.0290, 0.1550] 33 | data_split_file: 34 | train: train_512_512_128.txt 35 | val: val.txt 36 | test: test.txt 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_ucond/model/pyramid_occ_denoiser.yaml: -------------------------------------------------------------------------------- 1 | name: MSLDM 2 | paras: 3 | unet_model: "sparse" 4 | 5 | use_ema: True 6 | restoration_method: "general" 7 | multi_restore_batch_size: 16 8 | method: 'generative' 9 | ############## 10 | # first second third 11 | level: 'first' 12 | ############## 13 | add_noise_da: true 14 | classifier_free_guidance: true 15 | use_text_condition: false 16 | 17 | # noise schedule 18 | noise_schedule: 19 | first: 20 | num_train_timesteps: 1000 21 | clip_sample: false 22 | # cos straight arccos scaled_linear 23 | beta_schedule: "cos" 24 | second: 25 | num_train_timesteps: 1000 26 | clip_sample: false 27 | # cos straight arccos scaled_linear 28 | beta_schedule: "cos" 29 | third: 30 | num_train_timesteps: 1000 31 | clip_sample: True 32 | # cos straight arccos scaled_linear 33 
| beta_schedule: "cos" 34 | clip_sample_range: 3.0 35 | 36 | 37 | # latent encoding 38 | first_stage_model: 39 | model_class: MSTSDFPVQGANNew 40 | config_file: model/tsdf_vqgan_retrain.yaml 41 | 42 | ckpt: "model_ckpt-epoth844-globalstep399056-2023-10-21-20-19-53.pt" 43 | ############################################################################ 44 | # diffusion network 45 | unet_model: 46 | first: 47 | # consider: /home/xlju/pro/diffs/output/2023-11-08-05-41-51-m_attd_finetune/model_paras_log/model_ckpt-epoth449-globalstep284176val_ema-2023-11-09-00-43-31.pt 48 | # ckpt: "/home/xlju/pro/diffs/output/2023-10-13-18-02-21-sparse1st_resume_lr2/model_paras_log/model_ckpt-epoth495-globalstep531062-2023-10-14-09-30-32.pt" 49 | ckpt: "model_ckpt-epoth494-globalstep171826val_ema-2023-10-27-07-04-32.pt" 50 | model_type: UNetModel 51 | model_args: 52 | dims: 3 53 | is_sparse: True 54 | in_channels: 4 55 | out_channels: 4 56 | model_channels: 64 57 | attention_resolutions: [] 58 | num_res_blocks: 2 59 | channel_mult: 60 | - 1 61 | - 2 62 | - 4 63 | num_head_channels: 32 64 | use_spatial_transformer: True 65 | transformer_depth: 1 66 | context_dim: 256 67 | use_position_encoding: False 68 | second: 69 | ckpt: "model_ckpt-epoth494-globalstep368775val_ema-2023-10-27-09-39-59.pt" 70 | model_type: UNetModel 71 | model_args: 72 | dims: 3 73 | is_sparse: True 74 | in_channels: 8 # with condition 75 | out_channels: 4 76 | model_channels: 64 77 | use_spatial_transformer: False 78 | attention_resolutions: [] 79 | num_res_blocks: 2 80 | channel_mult: 81 | - 1 82 | - 2 83 | - 4 84 | - 8 85 | num_head_channels: 32 86 | third: 87 | ckpt: "model_ckpt-epoth104-globalstep156450val_ema-2023-10-27-15-35-58.pt" 88 | model_type: UNetModel 89 | model_args: 90 | dims: 3 91 | is_sparse: True 92 | in_channels: 9 93 | out_channels: 1 94 | model_channels: 64 95 | use_spatial_transformer: False 96 | attention_resolutions: [] 97 | num_res_blocks: 2 98 | channel_mult: 99 | - 1 100 | - 2 101 | - 4 102 | - 8 103 | num_head_channels: 32 104 | -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_ucond/model/tsdf_vqgan_retrain.yaml: -------------------------------------------------------------------------------- 1 | name: MSTSDFPVQGANNew 2 | paras: 3 | embed_dim: 4 4 | n_embed: 8192 5 | ddconfig1: 6 | double_z: False 7 | z_channels: 4 8 | resolution: 512 9 | in_channels: 1 10 | out_ch: 1 11 | ch: 32 12 | ch_mult: [1,2,4] # num_down = len(ch_mult)-1 13 | num_res_blocks: 2 14 | attn_resolutions: [] 15 | dropout: 0. 16 | ddconfig2: 17 | use_bev: False 18 | double_z: False 19 | z_channels: 4 20 | resolution: 512 21 | in_channels: 4 22 | out_ch: 4 23 | ch: 256 24 | ch_mult: [1,1] # num_down = len(ch_mult)-1 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0. 28 | 29 | temperature_scheduler_config: 30 | warm_up_steps: 0 31 | max_decay_steps: 1000001 32 | lr_start: 0.9 33 | lr_max: 0.9 34 | lr_min: 1.0e-06 35 | lossconfig: 36 | params: 37 | disc_conditional: False 38 | disc_in_channels: 1 39 | disc_start: 6000 40 | disc_weight: 0.2 41 | codebook_weight: 1.0 42 | rec_loss: 'l1' 43 | -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_ucond/readme.md: -------------------------------------------------------------------------------- 1 | 2 | * Step 1: set `data_root` in `datadet/ali3dfront.yaml`, which points to the directory of processed `*.npz` files (TSDF). 
3 | * Step 2: set `latent_dir` as the latent encoding of initial TSDF data from PatchVQGAN, and the `latent_scale` as the reciprocal of latents STD. 4 | * Step 3: set `level` in `model/pyramid_occ_denoiser.yaml` to set the training stage of the cascaded diffusion; set `use_sketch_condition` to use conditional / unconditional diffusion in the 1st stage; set `sketch_embedder/ckpt` as the checkpoint of `SketchVAE` model if `use_sketch_condition=True`; set `first_stage_model/ckpt` as the checkpoint of the PatchVQGAN model. 5 | * [For inference] set `unet_model/LEVEL/ckpt` as the checkpoints of unet models of different diffusion levels. 6 | * [For Training] set `ckpt` variable in `training` section of `root_config.yaml` for continuous training. -------------------------------------------------------------------------------- /utils/config/samples/cascaded_ldm_ucond/root_config.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | config_file: dataset/ali3dfront.yaml 3 | dataset_class: Ali3DFront 4 | logging: 5 | ckpt_eph_interval: 20 6 | path: ../output 7 | suffix: '1st-ema' 8 | model: 9 | config_file: model/pyramid_occ_denoiser.yaml 10 | model_class: MSLDM 11 | training: 12 | device: 'cuda:0' 13 | epoch: 5000 14 | optimizer: 15 | - name: opt 16 | paras: 17 | lr: 0.0001 18 | type: Adam 19 | trainer_class: CascadedLDMTrainer 20 | 21 | # just to save params 22 | enable_val: true 23 | val_interval: 20 24 | 25 | 26 | testing: 27 | device: 'cuda:0' 28 | mode: "uncond" 29 | seed: 63266 30 | trainer_class: CascadedLDMTrainer 31 | save_mesh: true 32 | test_log_dir: "output/uncond" 33 | operating_size: [32, 32, 16] 34 | level_seq: ['first', 'second', 'third'] 35 | 36 | -------------------------------------------------------------------------------- /utils/config/samples/sketch_VAE/dataset/ali3dfront.yaml: -------------------------------------------------------------------------------- 1 | dataset_class: Ali3DFront 2 | paras: 3 | mode: 'train' 4 | num_workers: 4 5 | shuffle: true 6 | version: 'v2' 7 | # data_root: '/home/xlju/3dfront' 8 | data_root: '/home/xlju/front3d_ini/new_npz' 9 | 10 | #################### 11 | ## useful config - vary with training level 12 | 13 | batch_size: 4 14 | 15 | # first level 16 | level_config: 17 | first: 18 | load_content: ['latent', 'tsdf', 'sketch'] 19 | transform: ['simpletrans'] 20 | second: 21 | load_content: ['latent'] 22 | transform: ['simpletrans'] 23 | third: 24 | load_content: ['tsdf','latent'] 25 | transform: ['simpletrans', 'simplecrop'] 26 | 27 | voxel_dim: [128, 128, 128] # only effect in third level 28 | batch_collate_func: 'batch_collate_latent_code' 29 | #################### 30 | 31 | latent_dir: '/home/xlju/pro/diffs/script/data_latent_new_dataset' 32 | latent_scale: [0.0290, 0.1550] 33 | data_split_file: 34 | train: train_512_512_128.txt 35 | val: val.txt 36 | test: test.txt 37 | 38 | 39 | -------------------------------------------------------------------------------- /utils/config/samples/sketch_VAE/model/sketch_vae.yaml: -------------------------------------------------------------------------------- 1 | name: SketchVAE 2 | paras: 3 | embed_dim: 256 4 | n_embed: 2048 5 | kl_weight: 0.0000045 6 | ddconfig: 7 | double_z: True 8 | z_channels: 16 9 | resolution: 256 10 | in_channels: 1 11 | out_ch: 1 12 | ch: 128 13 | ch_mult: [ 1,2,4,4] # num_down = len(ch_mult)-1 14 | num_res_blocks: 2 15 | attn_resolutions: [] 16 | dropout: 0.0 17 | 
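# Added note (not in the original file): with the downsampling convention stated above
# (num_down = len(ch_mult) - 1 = 3), a 256x256 sketch is encoded to a 32x32 latent grid.
# double_z: True is assumed here to follow the usual VAE convention of predicting both a
# mean and a log-variance (2 * z_channels output channels) for the KL posterior.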
-------------------------------------------------------------------------------- /utils/config/samples/sketch_VAE/readme.md: -------------------------------------------------------------------------------- 1 | 2 | * Step 1: set `data_root` in `datadet/ali3dfront.yaml`, which points to the directory of processed `*.npz` files (TSDF). 3 | * Step 2: set `latent_dir` as the latent encoding of initial TSDF data from PatchVQGAN, and the `latent_scale` as the reciprocal of latents STD. 4 | * Step 3: set `ckpt` variable in `training` or `testing` section of `root_config.yaml` to load the checkpoint for training or inference. -------------------------------------------------------------------------------- /utils/config/samples/sketch_VAE/root_config.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | config_file: dataset/ali3dfront.yaml 3 | dataset_class: Ali3DFront 4 | logging: 5 | ckpt_eph_interval: 20 6 | path: ../output 7 | suffix: 'none' 8 | model: 9 | config_file: model/sketch_vae.yaml 10 | model_class: SketchVAE 11 | training: 12 | device: 'cuda:0' 13 | epoch: 5000 14 | optimizer: 15 | - name: opt 16 | paras: 17 | lr: 0.00001 18 | type: Adam 19 | trainer_class: SKetchVAETrainer 20 | 21 | # just to save params 22 | enable_val: false 23 | val_interval: 20 24 | 25 | 26 | testing: 27 | device: 'cuda:0' 28 | trainer_class: SKetchVAETrainer 29 | test_log_dir: "/home/xlju/tst" 30 | 31 | 32 | -------------------------------------------------------------------------------- /utils/config/samples/tsdf_gumbel_ms_vqgan/dataset/ali3dfront.yaml: -------------------------------------------------------------------------------- 1 | dataset_class: Ali3DFront 2 | paras: 3 | version: 'v2' 4 | batch_size: 1 5 | mode: 'train' 6 | num_workers: 4 7 | shuffle: true 8 | data_root: '/home/xlju/front3d_ini/new_npz' 9 | data_split_file: 10 | train: train_512_512_128.txt 11 | test: train_512_512_128.txt 12 | voxel_dim: [96, 96, 96] 13 | 14 | # for training 15 | transform: ['randomcrop'] 16 | 17 | # for test 18 | # transform: [] 19 | # batch_collate_func: 'batch_collate_fn_for_lat' 20 | 21 | -------------------------------------------------------------------------------- /utils/config/samples/tsdf_gumbel_ms_vqgan/model/tsdf_vqgan.yaml: -------------------------------------------------------------------------------- 1 | name: MSTSDFPVQGANNew 2 | paras: 3 | embed_dim: 4 4 | n_embed: 8192 5 | ddconfig1: 6 | double_z: False 7 | z_channels: 4 8 | resolution: 512 9 | in_channels: 1 10 | out_ch: 1 11 | ch: 32 12 | ch_mult: [1,2,4] # num_down = len(ch_mult)-1 13 | num_res_blocks: 2 14 | attn_resolutions: [] 15 | dropout: 0. 16 | ddconfig2: 17 | use_bev: False 18 | double_z: False 19 | z_channels: 4 20 | resolution: 512 21 | in_channels: 4 22 | out_ch: 4 23 | ch: 256 24 | ch_mult: [1,1] # num_down = len(ch_mult)-1 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0. 
28 | 29 | temperature_scheduler_config: 30 | warm_up_steps: 0 31 | max_decay_steps: 1000001 32 | lr_start: 0.9 33 | lr_max: 0.9 34 | lr_min: 1.0e-06 35 | lossconfig: 36 | params: 37 | disc_conditional: False 38 | disc_in_channels: 1 39 | disc_start: 6000 40 | disc_weight: 0.2 41 | codebook_weight: 1.0 42 | rec_loss: 'l1' -------------------------------------------------------------------------------- /utils/config/samples/tsdf_gumbel_ms_vqgan/readme.md: -------------------------------------------------------------------------------- 1 | 2 | * Step 1: set `data_root` in `datadet/ali3dfront.yaml`, which points to the directory of processed `*.npz` files (TSDF). 3 | * Step 2: enable the `# for training` section or the `# for test` section in `datadet/ali3dfront.yaml` for different data transform in training and inference. 4 | * Step 3: set `ckpt` variable in `training` or `testing` section of `root_config.yaml` to load the checkpoint for training or inference. -------------------------------------------------------------------------------- /utils/config/samples/tsdf_gumbel_ms_vqgan/root_config.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | config_file: dataset/ali3dfront.yaml 3 | dataset_class: Ali3DFront 4 | logging: 5 | ckpt_eph_interval: 2 6 | path: ../output 7 | suffix: 'msvqgan' 8 | model: 9 | config_file: model/tsdf_vqgan.yaml 10 | model_class: MSTSDFPVQGANNew 11 | training: 12 | device: 'cuda:0' 13 | epoch: 500 14 | optimizer: 15 | paras: 16 | lr: 0.00001 17 | type: Adam 18 | trainer_class: TSDFPVQGANTrainer 19 | enable_val: false 20 | val_interval: 10 21 | log_ckpt_interval: 2 22 | 23 | 24 | testing: 25 | device: 'cuda:0' 26 | trainer_class: TSDFPVQGANTrainer 27 | save_mesh: false 28 | output_dir: "tst_release" 29 | ckpt: "" -------------------------------------------------------------------------------- /utils/diffusion_monitor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LossSlidingWindow: 5 | def __init__(self, max_size=1000) -> None: 6 | self.window_max_size = max_size 7 | self.data_slot = [] 8 | 9 | def push(self, data): 10 | if len(self.data_slot) > self.window_max_size: 11 | self.data_slot = self.data_slot[1:] 12 | self.data_slot.append(data) 13 | 14 | 15 | class DiffusionMonitor: 16 | def __init__(self) -> None: 17 | self.scheduler = None 18 | self.timestep_groups = 10 19 | self.snr_groups = 10 20 | self.gamma_groups = 10 21 | self.loss_bin_dict = {} 22 | self.loss_data_window = LossSlidingWindow(max_size=1000) 23 | 24 | def set_scheduler(self, scheduler): 25 | self.scheduler = scheduler 26 | self.calculate_stats_x() 27 | 28 | def calculate_stats_x(self): 29 | self.gamma = self.scheduler.alphas_cumprod 30 | self.snr = self.gamma / (1 - self.gamma + 1e-12) 31 | self.timesteps = [i + 1 for i in range(len(self.gamma))] 32 | 33 | def update_loss(self, loss_dict): 34 | mse_loss_batch_mat = loss_dict["mse_loss_mat"] 35 | loss_mask = loss_dict["loss_mask"] 36 | timestep = loss_dict["timestep"] 37 | gamma = loss_dict["gammas"] 38 | 39 | bs = mse_loss_batch_mat.shape[0] 40 | for i in range(bs): 41 | loss_instance = mse_loss_batch_mat[i][loss_mask[i]].mean() 42 | t = timestep[i] 43 | self.loss_data_window.push( 44 | dict(timestep=t, loss=loss_instance, gamma=gamma[i]) 45 | ) 46 | 47 | def gen_segments(self, x_min, x_max, box_num): 48 | segs = [] 49 | interval = (x_max - x_min) / box_num 50 | st = x_min 51 | for i in range(box_num): 52 | seg = [st, st + 
interval] 53 | segs += [seg] 54 | st += interval 55 | return segs 56 | 57 | def get_bins(self, x, x_segments, y): 58 | seg_num = len(x_segments) 59 | boxes = [] 60 | for i in range(seg_num): 61 | boxes.append([]) 62 | for x_, y_ in zip(x, y): 63 | for seg_inx, seg in enumerate(x_segments): 64 | if seg[0] <= x_ < seg[1]: 65 | boxes[seg_inx] += [y_] 66 | return boxes 67 | 68 | def get_gamma_loss_dist(self): 69 | # gen data boxes 70 | gamma_segs = [[0, 0.2], [0.2, 0.4], [0.4, 0.6], [0.6, 0.8], [0.8, 1.0]] 71 | loss_data_dict = self.loss_data_window.data_slot 72 | loss_value = [i["loss"] for i in loss_data_dict] 73 | time_value = [i["timestep"] for i in loss_data_dict] 74 | gamma_value = [i["gamma"] for i in loss_data_dict] 75 | 76 | gamma_bins = self.get_bins(gamma_value, gamma_segs, loss_value) 77 | bin_average = [] 78 | for bin in gamma_bins: 79 | if len(bin): 80 | aver = sum(bin) / len(bin) 81 | else: 82 | aver = torch.tensor(-1.0) 83 | bin_average += [aver] 84 | return gamma_segs, bin_average 85 | -------------------------------------------------------------------------------- /utils/logger/basic_logger.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import os 3 | import datetime 4 | import logging 5 | import pickle 6 | import shutil 7 | import torch 8 | import subprocess 9 | from model.model_base import ModelBase 10 | from utils.config.Configuration import Configuration 11 | from tensorboardX import SummaryWriter 12 | import pandas as pd 13 | import torch 14 | 15 | # todo: add lock while used in multiprocessing... 16 | 17 | 18 | class LogTracker: 19 | """ 20 | record training numerical indicators. 21 | """ 22 | 23 | def __init__(self, *keys, phase="train"): 24 | self.phase = phase 25 | self._data = pd.DataFrame(index=keys, columns=["total", "counts", "average"]) 26 | self.reset() 27 | 28 | def reset(self): 29 | for col in self._data.columns: 30 | self._data[col].values[:] = 0 31 | 32 | def update(self, key, value, n=1): 33 | self._data.total[key] += value * n 34 | self._data.counts[key] += n 35 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 36 | 37 | def avg(self, key): 38 | return self._data.average[key] 39 | 40 | def result(self): 41 | return { 42 | "{}/{}".format(self.phase, k): v 43 | for k, v in dict(self._data.average).items() 44 | } 45 | 46 | 47 | class BasicLogger: 48 | logger = None 49 | 50 | def __new__(cls, *args, **kwargs): 51 | if cls.logger is None: 52 | cls.logger = super(BasicLogger, cls).__new__(cls) 53 | cls.logger.__initialized = False 54 | return cls.logger 55 | 56 | def __init__(self, config): 57 | if self.__initialized: 58 | return 59 | if not isinstance(config, Configuration): 60 | raise TypeError("input must be the Configuration type!") 61 | config_dict = config.get_complete_config() 62 | if "logging" not in config_dict.keys(): 63 | raise KeyError("Not config on logger has been found!") 64 | self._program_version = None 65 | self._monitor_dict = {} 66 | self._status_hook = None 67 | self.root_log_dir = config_dict["logging"]["path"] 68 | self.log_suffix = config_dict["logging"]["suffix"] 69 | self._ckpt_eph_interval = config_dict["logging"]["ckpt_eph_interval"] 70 | if not isinstance(self._ckpt_eph_interval, int): 71 | self._ckpt_eph_interval = 0 72 | 73 | date_time_str = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 74 | if self.log_suffix is None or len(self.log_suffix) == 0: 75 | self._cur_instance_root_log_dir = date_time_str 76 | else: 77 | 
self._cur_instance_root_log_dir = "-".join([date_time_str, self.log_suffix]) 78 | 79 | # overwrite by logdir 80 | if "log_dir" in config_dict["logging"]: 81 | self.root_log_dir = "" 82 | self._cur_instance_root_log_dir = config_dict["logging"]["log_dir"] 83 | 84 | self.complete_instance_dir = os.path.join( 85 | self.root_log_dir, self._cur_instance_root_log_dir 86 | ) 87 | 88 | self._tensor_board_log_dir = os.path.join( 89 | self.complete_instance_dir, "tensor_board" 90 | ) 91 | self._data_log_dir = os.path.join(self.complete_instance_dir, "data_log") 92 | self._model_para_log_dir = os.path.join( 93 | self.complete_instance_dir, "model_paras_log" 94 | ) 95 | 96 | if not os.path.exists(self.complete_instance_dir): 97 | os.makedirs(self.complete_instance_dir) 98 | if not os.path.exists(self._tensor_board_log_dir): 99 | os.makedirs(self._tensor_board_log_dir) 100 | if not os.path.exists(self._data_log_dir): 101 | os.makedirs(self._data_log_dir) 102 | if not os.path.exists(self._model_para_log_dir): 103 | os.makedirs(self._model_para_log_dir) 104 | 105 | # add version file 106 | version_file = os.path.join(self.complete_instance_dir, "version.txt") 107 | self.get_program_version() 108 | self.log_version_info(version_file) 109 | self._tensor_board_writer = SummaryWriter(self._tensor_board_log_dir) 110 | self._tensor_board_writer_lf = None 111 | self._data_pickle_file = os.path.join(self._data_log_dir, "data_bin.pkl") 112 | self.__initialized = True 113 | 114 | def get_program_version(self): 115 | git_version = None 116 | try: 117 | git_version = ( 118 | subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode() 119 | ) 120 | if self._program_version is None: 121 | self._program_version = git_version 122 | except: 123 | pass 124 | return git_version 125 | 126 | def log_version_info(self, file_name): 127 | with open(file_name, "w") as f: 128 | f.write(f"Current version:{self._program_version}") 129 | 130 | def log_config(self, config): 131 | if not isinstance(config, Configuration): 132 | raise TypeError("Please input a valid Configuration instance or reference") 133 | config.pack_configurations(self.complete_instance_dir) 134 | 135 | def log_data( 136 | self, data_name, data_content, add_to_tensorboard=False, step_key="global_step" 137 | ): 138 | status = self._status_hook() 139 | if isinstance(data_content, builtins.float): 140 | self._log_scalar( 141 | status, data_name, data_content, add_to_tensorboard, step_key 142 | ) 143 | elif isinstance(data_content, builtins.int): 144 | self._log_scalar( 145 | status, data_name, data_content, add_to_tensorboard, step_key 146 | ) 147 | else: 148 | raise NotImplementedError 149 | 150 | def log_mesh(self, data_name, verts, faces): 151 | if self._tensor_board_writer_lf is None: 152 | new_p = os.path.join(self._tensor_board_log_dir, "large_file_log") 153 | os.makedirs(new_p) 154 | self._tensor_board_writer_lf = SummaryWriter(new_p) 155 | self._tensor_board_writer_lf.add_mesh( 156 | data_name, 157 | vertices=verts, 158 | faces=faces, 159 | global_step=self._status_hook()["global_step"], 160 | ) 161 | 162 | def log_image(self, data_name, image): 163 | if self._tensor_board_writer_lf is None: 164 | new_p = os.path.join(self._tensor_board_log_dir, "large_file_log") 165 | os.makedirs(new_p) 166 | self._tensor_board_writer_lf = SummaryWriter(new_p) 167 | self._tensor_board_writer_lf.add_image( 168 | data_name, image, global_step=self._status_hook()["global_step"] 169 | ) 170 | 171 | def _add_to_pickle(self, status, data_name, data_content): 172 | with 
open(self._data_pickle_file, "ab") as f: 173 | pickle.dump( 174 | {"status": status, "name": data_name, "content": data_content}, f 175 | ) 176 | 177 | def log_to_pickle(self, data_name, data_content): 178 | status = self._status_hook() 179 | self._add_to_pickle(status, data_name, data_content) 180 | 181 | def save_binary(self, sub_dir, file_name, data_content): 182 | sub_dir = os.path.join(self._data_log_dir, sub_dir) 183 | if not os.path.exists(sub_dir): 184 | os.makedirs(sub_dir) 185 | file_name = os.path.join(sub_dir, file_name) 186 | with open(file_name, "wb") as f: 187 | pickle.dump(data_content, f) 188 | 189 | def get_log_dir(self): 190 | return self._data_log_dir 191 | 192 | def _log_scalar( 193 | self, 194 | status, 195 | data_name, 196 | data_content, 197 | add_to_tensorboard=False, 198 | step_key="global_step", 199 | ): 200 | if data_name not in self._monitor_dict.keys(): 201 | self._monitor_dict[data_name] = [] 202 | self._monitor_dict[data_name].append((status, data_content)) 203 | if add_to_tensorboard: 204 | self._tensor_board_writer.add_scalar( 205 | data_name, data_content, global_step=self._status_hook()[step_key] 206 | ) 207 | self._add_to_pickle(status, data_name, data_content) 208 | 209 | def log_model_params(self, model, optimizers=None, force=False, suffix=""): 210 | # skip when ckpt_eph_interval is set 211 | status = self._status_hook() 212 | epoch = status["epoch"] 213 | global_step = status["global_step"] 214 | 215 | if (not force) and (epoch % self._ckpt_eph_interval != 0): 216 | return False 217 | if model is not None: 218 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 219 | model = model.module 220 | 221 | if not isinstance(model, ModelBase): 222 | raise TypeError("input type must have class attribute of ModelBase!") 223 | para_dict = {"status": status, "model_paras": model.state_dict()} 224 | if optimizers is not None: 225 | if not isinstance(optimizers, list): 226 | # assert isinstance(optimizers, torch.Optimizer) 227 | optimizers = [optimizers] 228 | opt_state_dict_list = [] 229 | for i in optimizers: 230 | opt_state_dict_list += [i.state_dict()] 231 | para_dict.update({"opt_paras": opt_state_dict_list}) 232 | date_time_str = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 233 | pickle_name = ( 234 | "-".join( 235 | [ 236 | f"model_ckpt-epoth{epoch}-globalstep{global_step}{suffix}", 237 | date_time_str, 238 | ] 239 | ) 240 | + ".pt" 241 | ) 242 | with open(os.path.join(self._model_para_log_dir, pickle_name), "wb") as f: 243 | # pickle.dump(para_dict, f) 244 | torch.save(para_dict, f) 245 | logging.info(f"Log model state dict as: {pickle_name}") 246 | return True 247 | 248 | def register_status_hook(self, fn): 249 | if not callable(fn): 250 | raise TypeError(f"input must be a function!") 251 | self._status_hook = fn 252 | 253 | @classmethod 254 | def get_logger(cls, config=None): 255 | if cls.logger is not None: 256 | if config is not None: 257 | logging.warning("input config for logger will be ignored") 258 | return cls.logger 259 | if config is None: 260 | raise ValueError("config must be set") 261 | else: 262 | cls.logger = BasicLogger(config) 263 | return cls.logger 264 | 265 | def copy_screen_log(self, file_path): 266 | try: 267 | shutil.copy( 268 | file_path, 269 | os.path.join(self.complete_instance_dir, os.path.basename(file_path)), 270 | ) 271 | except: 272 | logging.error("Fail to copy screen log...") 273 | -------------------------------------------------------------------------------- /utils/logger/dummy_logger.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from utils.logger.basic_logger import BasicLogger 3 | 4 | 5 | class DummyLogger(BasicLogger): 6 | logger = None 7 | 8 | def __init__(self, config): 9 | return 10 | 11 | def log_config(self, config): 12 | return 13 | 14 | def log_data(self, data_name, data_content, add_to_tensorboard=False): 15 | return 16 | 17 | def _add_to_pickle(self, status, data_name, data_content): 18 | return 19 | 20 | def _log_scalar(self, status, data_name, data_content, add_to_tensorboard=False): 21 | return 22 | 23 | def log_model_params(self, *args, **argv): 24 | return 25 | 26 | def register_status_hook(self, fn): 27 | return 28 | 29 | @classmethod 30 | def get_logger(cls, config=None): 31 | if cls.logger is not None: 32 | if config is not None: 33 | logging.warning("input config for logger will be ignored") 34 | return cls.logger 35 | if config is None: 36 | raise ValueError("config must be set") 37 | else: 38 | cls.logger = DummyLogger(config) 39 | return cls.logger 40 | 41 | def copy_screen_log(self, file_path): 42 | return 43 | -------------------------------------------------------------------------------- /utils/torch_distributed_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import random 5 | import shutil 6 | import subprocess 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | 12 | import datetime 13 | 14 | from torch._utils import _flatten_dense_tensors, _take_tensors, _unflatten_dense_tensors 15 | from collections import OrderedDict 16 | 17 | 18 | def init_distributed_device(launcher, tcp_port, local_rank=None, backend="nccl"): 19 | """ 20 | modified from https://github.com/open-mmlab/mmdetection 21 | Args: 22 | tcp_port: 23 | backend: 24 | 25 | Returns: 26 | 27 | """ 28 | if launcher == "slurm": 29 | logging.info(f"config distributed training with launcher: {launcher}") 30 | proc_id = int(os.environ["SLURM_PROCID"]) 31 | ntasks = int(os.environ["SLURM_NTASKS"]) 32 | node_list = os.environ["SLURM_NODELIST"] 33 | num_gpus = torch.cuda.device_count() 34 | torch.cuda.set_device(proc_id % num_gpus) 35 | addr = subprocess.getoutput( 36 | "scontrol show hostname {} | head -n1".format(node_list) 37 | ) 38 | os.environ["MASTER_PORT"] = str(tcp_port) 39 | os.environ["MASTER_ADDR"] = addr 40 | os.environ["WORLD_SIZE"] = str(ntasks) 41 | os.environ["RANK"] = str(proc_id) 42 | dist.init_process_group(backend=backend) 43 | 44 | total_gpus = dist.get_world_size() 45 | rank = dist.get_rank() 46 | return total_gpus, rank 47 | elif launcher == "pytorch": 48 | logging.info(f"config distributed training with launcher: {launcher}") 49 | assert local_rank is not None 50 | if mp.get_start_method(allow_none=True) is None: 51 | mp.set_start_method("spawn") 52 | 53 | num_gpus = torch.cuda.device_count() 54 | logging.info("Available GPUs:{}".format(num_gpus)) 55 | logging.info("Using TCP Port:{}".format(tcp_port)) 56 | 57 | local_rank = int(os.environ["LOCAL_RANK"]) 58 | torch.cuda.set_device(local_rank % num_gpus) 59 | logging.info("local_rank:{}".format(local_rank)) 60 | dist.init_process_group( 61 | backend=backend, 62 | # init_method='tcp://localhost:%d' % tcp_port, 63 | rank=local_rank, 64 | world_size=num_gpus, 65 | timeout=datetime.timedelta(seconds=60), 66 | ) 67 | 68 | rank = dist.get_rank() 69 | os.environ["WORLD_SIZE"] = str(num_gpus) 70 | os.environ["RANK"] = 
str(local_rank) 71 | return num_gpus, rank 72 | else: 73 | raise NotImplementedError 74 | 75 | 76 | def get_dist_info(): 77 | if torch.__version__ < "1.0": 78 | initialized = dist._initialized 79 | else: 80 | if dist.is_available(): 81 | initialized = dist.is_initialized() 82 | else: 83 | initialized = False 84 | if initialized: 85 | rank = dist.get_rank() 86 | world_size = dist.get_world_size() 87 | else: 88 | rank = 0 89 | world_size = 1 90 | return rank, world_size 91 | 92 | 93 | def merge_results_dist(result_part, size, tmpdir): 94 | rank, world_size = get_dist_info() 95 | os.makedirs(tmpdir, exist_ok=True) 96 | 97 | dist.barrier() 98 | pickle.dump( 99 | result_part, open(os.path.join(tmpdir, "result_part_{}.pkl".format(rank)), "wb") 100 | ) 101 | dist.barrier() 102 | 103 | if rank != 0: 104 | return None 105 | 106 | part_list = [] 107 | for i in range(world_size): 108 | part_file = os.path.join(tmpdir, "result_part_{}.pkl".format(i)) 109 | part_list.append(pickle.load(open(part_file, "rb"))) 110 | 111 | ordered_results = [] 112 | for res in zip(*part_list): 113 | ordered_results.extend(list(res)) 114 | ordered_results = ordered_results[:size] 115 | shutil.rmtree(tmpdir) 116 | return ordered_results 117 | 118 | 119 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): 120 | if bucket_size_mb > 0: 121 | bucket_size_bytes = bucket_size_mb * 1024 * 1024 122 | buckets = _take_tensors(tensors, bucket_size_bytes) 123 | else: 124 | buckets = OrderedDict() 125 | for tensor in tensors: 126 | tp = tensor.type() 127 | if tp not in buckets: 128 | buckets[tp] = [] 129 | buckets[tp].append(tensor) 130 | buckets = buckets.values() 131 | 132 | for bucket in buckets: 133 | flat_tensors = _flatten_dense_tensors(bucket) 134 | dist.all_reduce(flat_tensors) 135 | flat_tensors.div_(world_size) 136 | for tensor, synced in zip( 137 | bucket, _unflatten_dense_tensors(flat_tensors, bucket) 138 | ): 139 | tensor.copy_(synced) 140 | 141 | 142 | def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): 143 | """Allreduce gradients. 144 | 145 | Args: 146 | params (list[torch.Parameters]): List of parameters of a model 147 | coalesce (bool, optional): Whether allreduce parameters as a whole. 148 | Defaults to True. 149 | bucket_size_mb (int, optional): Size of bucket, the unit is MB. 150 | Defaults to -1. 151 | """ 152 | grads = [ 153 | param.grad.data 154 | for param in params 155 | if param.requires_grad and param.grad is not None 156 | ] 157 | world_size = dist.get_world_size() 158 | if coalesce: 159 | _allreduce_coalesced(grads, world_size, bucket_size_mb) 160 | else: 161 | for tensor in grads: 162 | dist.all_reduce(tensor.div_(world_size)) 163 | -------------------------------------------------------------------------------- /utils/visualize_occ.py: -------------------------------------------------------------------------------- 1 | ''' 2 | part of this file is modified from pytorch3d 3 | 4 | ''' 5 | import numpy as np 6 | import torch 7 | 8 | from utils.graphics_utils import cubify 9 | from io import BytesIO 10 | from typing import Optional 11 | import logging 12 | import sys 13 | 14 | def _write_ply_header( 15 | f, 16 | *, 17 | verts: torch.Tensor, 18 | faces: Optional[torch.LongTensor], 19 | verts_normals: Optional[torch.Tensor], 20 | verts_colors: Optional[torch.Tensor], 21 | ascii: bool, 22 | colors_as_uint8: bool, 23 | ) -> None: 24 | """ 25 | Internal implementation for writing header when saving to a .ply file. 
26 | 27 | Args: 28 | f: File object to which the 3D data should be written. 29 | verts: FloatTensor of shape (V, 3) giving vertex coordinates. 30 | faces: LongTensor of shape (F, 3) giving faces. 31 | verts_normals: FloatTensor of shape (V, 3) giving vertex normals. 32 | verts_colors: FloatTensor of shape (V, 3) giving vertex colors. 33 | ascii: (bool) whether to use the ascii ply format. 34 | colors_as_uint8: Whether to save colors as numbers in the range 35 | [0, 255] instead of float32. 36 | """ 37 | assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) 38 | assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) 39 | assert verts_normals is None or ( 40 | verts_normals.dim() == 2 and verts_normals.size(1) == 3 41 | ) 42 | assert verts_colors is None or ( 43 | verts_colors.dim() == 2 and verts_colors.size(1) == 3 44 | ) 45 | 46 | if ascii: 47 | f.write(b"ply\nformat ascii 1.0\n") 48 | elif sys.byteorder == "big": 49 | f.write(b"ply\nformat binary_big_endian 1.0\n") 50 | else: 51 | f.write(b"ply\nformat binary_little_endian 1.0\n") 52 | f.write(f"element vertex {verts.shape[0]}\n".encode("ascii")) 53 | f.write(b"property float x\n") 54 | f.write(b"property float y\n") 55 | f.write(b"property float z\n") 56 | if verts_normals is not None: 57 | f.write(b"property float nx\n") 58 | f.write(b"property float ny\n") 59 | f.write(b"property float nz\n") 60 | if verts_colors is not None: 61 | color_ply_type = b"uchar" if colors_as_uint8 else b"float" 62 | for color in (b"red", b"green", b"blue"): 63 | f.write(b"property " + color_ply_type + b" " + color + b"\n") 64 | if len(verts) and faces is not None: 65 | f.write(f"element face {faces.shape[0]}\n".encode("ascii")) 66 | f.write(b"property list uchar int vertex_index\n") 67 | f.write(b"end_header\n") 68 | 69 | def _check_faces_indices( 70 | faces_indices: torch.Tensor, max_index: int, pad_value: Optional[int] = None 71 | ) -> torch.Tensor: 72 | if pad_value is None: 73 | mask = torch.ones(faces_indices.shape[:-1]).bool() # Keep all faces 74 | else: 75 | mask = faces_indices.ne(pad_value).any(dim=-1) 76 | if torch.any(faces_indices[mask] >= max_index) or torch.any( 77 | faces_indices[mask] < 0 78 | ): 79 | logging.warn("Faces have invalid indices") 80 | return faces_indices 81 | 82 | 83 | def _save_ply( 84 | f, 85 | *, 86 | verts: torch.Tensor, 87 | faces: Optional[torch.LongTensor], 88 | verts_normals: Optional[torch.Tensor] = None, 89 | verts_colors: Optional[torch.Tensor] = None, 90 | ascii: bool = False, 91 | decimal_places: Optional[int] = None, 92 | colors_as_uint8: bool = False, 93 | ) -> None: 94 | """ 95 | Internal implementation for saving 3D data to a .ply file. 96 | 97 | Args: 98 | f: File object to which the 3D data should be written. 99 | verts: FloatTensor of shape (V, 3) giving vertex coordinates. 100 | faces: LongTensor of shape (F, 3) giving faces. 101 | verts_normals: FloatTensor of shape (V, 3) giving vertex normals. 102 | verts_colors: FloatTensor of shape (V, 3) giving vertex colors. 103 | ascii: (bool) whether to use the ascii ply format. 104 | decimal_places: Number of decimal places for saving if ascii=True. 105 | colors_as_uint8: Whether to save colors as numbers in the range 106 | [0, 255] instead of float32. 
107 | """ 108 | _write_ply_header( 109 | f, 110 | verts=verts, 111 | faces=faces, 112 | verts_normals=verts_normals, 113 | verts_colors=verts_colors, 114 | ascii=ascii, 115 | colors_as_uint8=colors_as_uint8, 116 | ) 117 | 118 | if not (len(verts)): 119 | logging.warn("Empty 'verts' provided") 120 | return 121 | 122 | color_np_type = np.ubyte if colors_as_uint8 else np.float32 123 | verts_dtype = [("verts", np.float32, 3)] 124 | if verts_normals is not None: 125 | verts_dtype.append(("normals", np.float32, 3)) 126 | if verts_colors is not None: 127 | verts_dtype.append(("colors", color_np_type, 3)) 128 | 129 | vert_data = np.zeros(verts.shape[0], dtype=verts_dtype) 130 | vert_data["verts"] = verts.detach().cpu().numpy() 131 | if verts_normals is not None: 132 | vert_data["normals"] = verts_normals.detach().cpu().numpy() 133 | if verts_colors is not None: 134 | color_data = verts_colors.detach().cpu().numpy() 135 | if colors_as_uint8: 136 | vert_data["colors"] = np.rint(color_data * 255) 137 | else: 138 | vert_data["colors"] = color_data 139 | 140 | if ascii: 141 | if decimal_places is None: 142 | float_str = b"%f" 143 | else: 144 | float_str = b"%" + b".%df" % decimal_places 145 | float_group_str = (float_str + b" ") * 3 146 | formats = [float_group_str] 147 | if verts_normals is not None: 148 | formats.append(float_group_str) 149 | if verts_colors is not None: 150 | formats.append(b"%d %d %d " if colors_as_uint8 else float_group_str) 151 | formats[-1] = formats[-1][:-1] + b"\n" 152 | for line_data in vert_data: 153 | for data, format in zip(line_data, formats): 154 | f.write(format % tuple(data)) 155 | else: 156 | if isinstance(f, BytesIO): 157 | # tofile only works with real files, but is faster than this. 158 | f.write(vert_data.tobytes()) 159 | else: 160 | vert_data.tofile(f) 161 | 162 | if faces is not None: 163 | faces_array = faces.detach().cpu().numpy() 164 | 165 | _check_faces_indices(faces, max_index=verts.shape[0]) 166 | 167 | if len(faces_array): 168 | if ascii: 169 | np.savetxt(f, faces_array, "3 %d %d %d") 170 | else: 171 | faces_recs = np.zeros( 172 | len(faces_array), 173 | dtype=[("count", np.uint8), ("vertex_indices", np.uint32, 3)], 174 | ) 175 | faces_recs["count"] = 3 176 | faces_recs["vertex_indices"] = faces_array 177 | faces_uints = faces_recs.view(np.uint8) 178 | 179 | if isinstance(f, BytesIO): 180 | f.write(faces_uints.tobytes()) 181 | else: 182 | faces_uints.tofile(f) 183 | 184 | 185 | 186 | def occ2mesh(occ_vol, thres=0.9, voxel_size=0.04, origin=torch.tensor((0,0,0))): 187 | if len(occ_vol.shape) < 4: 188 | occ_vol = occ_vol.unsqueeze(0) 189 | meshvert, face = cubify(occ_vol, thres, align='center') 190 | verts=meshvert[0] 191 | faces=face[0] 192 | if not verts.shape[0]: 193 | return None, None 194 | 195 | origin = origin.to(verts.device) 196 | verts = verts * voxel_size + origin 197 | 198 | # verts[:, [0,1]] = verts[:, [1,0]] 199 | # faces[:, [0,1,2]] = faces[:, [0,2,1]] 200 | 201 | verts[:, [0,2]] = verts[:, [2,0]] 202 | faces[:, [0,1,2]] = faces[:, [2,1,0]] 203 | 204 | return verts, faces 205 | 206 | 207 | def mesh2ply(verts, faces, filename): 208 | with open(filename, 'wb') as f: 209 | _save_ply(f, verts=verts, faces=faces) 210 | 211 | 212 | --------------------------------------------------------------------------------
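
A minimal usage sketch (not part of the repository) for the mesh-export helpers in `utils/visualize_occ.py`; the occupancy volume and output path below are placeholders, and the `thres`/`voxel_size` values simply restate the defaults of the `occ2mesh` signature above.

```python
import torch

from utils.visualize_occ import occ2mesh, mesh2ply

# Hypothetical binary occupancy volume of shape (D, H, W); in practice this
# would come from the cascaded diffusion sampler rather than random noise.
occ_vol = (torch.rand(128, 128, 64) > 0.5).float()

# Cubify the occupancy grid into a triangle mesh; occ2mesh returns (None, None)
# when the volume contains no occupied voxels above the threshold.
verts, faces = occ2mesh(occ_vol, thres=0.9, voxel_size=0.04)

if verts is not None:
    # Write the mesh to a binary PLY file via the internal _save_ply helper.
    mesh2ply(verts, faces, "generated_scene.ply")
```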