├── utils ├── __init__.py ├── meter.py ├── tool_add_control.py ├── generate_3d_mask.py ├── visualize_voxel.py ├── debugger.py └── generate_2d_mask.py ├── images └── teaser.png ├── lib ├── __init__.py ├── solvers.py ├── visualize.py ├── distributed.py └── utils.py ├── data ├── npy_2_pth.py ├── visualize_df_gt.py └── sdf_2_npy.py ├── models ├── diffusion │ ├── common.py │ └── __init__.py ├── modules │ ├── fp16_util.py │ ├── nn.py │ ├── resample.py │ └── scheduler.py └── networks │ ├── __init__.py │ ├── controlnet.py │ └── resunet3d.py ├── ddp_main.py ├── .gitignore ├── datasets ├── epn_control.py ├── dataset.py ├── __init__.py ├── dataloader.py └── transforms.py ├── configs ├── default.yaml ├── epn_control_train.yaml └── epn_control_test.yaml ├── README.md ├── environment.yml ├── tools ├── test.py └── ddp_trainer.py └── LICENSE /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/DiffComplete/HEAD/images/teaser.png -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the MIT license found in the 2 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | class Meter(object): 5 | 6 | def __init__(self): 7 | self.measure_dicts = {} 8 | 9 | def add_attributes(self, new_attribute): 10 | self.measure_dicts[new_attribute] = [] 11 | 12 | def clear_data(self): 13 | for key, item in self.measure_dicts.items(): 14 | self.measure_dicts[key] = [] 15 | 16 | def add_data(self, attribute, data): 17 | 18 | assert attribute in self.measure_dicts 19 | self.measure_dicts[attribute].append(data) 20 | 21 | def return_avg_dict(self): 22 | 23 | return_dict = {} 24 | for key, item in self.measure_dicts.items(): 25 | return_dict[key] = f'{np.mean(self.measure_dicts[key]):.9f}' 26 | 27 | return return_dict -------------------------------------------------------------------------------- /data/npy_2_pth.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import sys 3 | import os 4 | import numpy as np 5 | import os.path as osp 6 | import torch 7 | 8 | if __name__ == '__main__': 9 | sdf_path = "/mnt/proj74/rhchu/dataset/3d_epn/shapenet_dim32_sdf_npy" 10 | df_path = "/mnt/proj74/rhchu/dataset/3d_epn/shapenet_dim32_df_npy" 11 | out_path = "/mnt/proj74/rhchu/dataset/3d_epn/control_data" 12 | clss = ['02933112', '04530566', '03636649', '02691156', '02958343', '04379243', '04256520', '03001627'] 13 | 14 | for cls in clss: 15 | sdfs = osp.join(sdf_path, cls) 16 | sdf_files = os.listdir(sdfs) 17 | for sdf_file in sdf_files: 18 | sdf_name = sdf_file[:-4] 19 | gt_file = sdf_name[:-3] + '0__.npy' 20 | sdf = np.load(osp.join(sdf_path, cls, sdf_file)) 21 | df = np.load(osp.join(df_path, cls, gt_file)) 22 | out_cls_path = osp.join(out_path, cls) 23 | os.makedirs(out_cls_path, exist_ok=True) 24 | out_file = osp.join(out_cls_path, sdf_name + '.pth') 25 | torch.save((sdf, df), out_file) 26 | 27 | 
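For context on the pairing above: each partial scan is stored as `<model_id>__<k>__.npy` for a view index `k`, and its ground-truth target is the corresponding `__0__` distance field. A minimal sketch of the filename logic (the model id is taken from the README's example tree; the view index is hypothetical):

```python
# Hypothetical example of the partial-scan -> ground-truth pairing performed above.
sdf_file = "10155655850468db78d106ce0a280f87__3__.npy"  # one partial view (index 3 is made up)
sdf_name = sdf_file[:-4]                                # strip the ".npy" extension
gt_file = sdf_name[:-3] + "0__.npy"                     # swap the view index for "0"
assert gt_file == "10155655850468db78d106ce0a280f87__0__.npy"
```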
--------------------------------------------------------------------------------
/data/visualize_df_gt.py:
--------------------------------------------------------------------------------
 1 | import glob
 2 | import sys
 3 | import os
 4 | import numpy as np
 5 | import os.path as osp
 6 | import mcubes
 7 | 
 8 | def get_shape_df(path):
 9 |     dims = np.fromfile(path, np.uint64, 3)
10 |     df = np.fromfile(path, np.float32, offset=3 * 8).reshape(dims)
11 |     return df
12 | 
13 | 
14 | if __name__ == '__main__':
15 | 
16 |     base_path = "/mnt/proj74/rhchu/dataset/3d-epn/shapenet_dim32_df"
17 |     cls = '03001627'
18 |     out_path = "/mnt/proj74/rhchu/dataset/3d-epn/vis_df_gt"
19 | 
20 |     df_path = osp.join(base_path, cls)
21 |     df_files = os.listdir(df_path)
22 |     for df_file in df_files:
23 |         df_name = df_file[:-3]
24 |         df = get_shape_df(osp.join(df_path, df_file))
25 |         out_cls_path = osp.join(out_path, cls)
26 |         os.makedirs(out_cls_path, exist_ok=True)
27 |         tdf = np.clip(df, 0, 3)
28 |         out_file = osp.join(out_cls_path, df_name + '.obj')
29 |         vertices, triangles = mcubes.marching_cubes(tdf, 0.5)
30 |         mcubes.export_obj(vertices, triangles, out_file)
31 |         print(f"Save {out_file}!")
32 | 
33 | 
34 | 
35 | 
36 | 
--------------------------------------------------------------------------------
/models/diffusion/common.py:
--------------------------------------------------------------------------------
 1 | import enum
 2 | 
 3 | class ModelMeanType(enum.Enum):
 4 |     """
 5 |     Which type of output the model predicts.
 6 |     """
 7 | 
 8 |     PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
 9 |     START_X = enum.auto()  # the model predicts x_0
10 |     EPSILON = enum.auto()  # the model predicts epsilon
11 | 
12 | 
13 | class ModelVarType(enum.Enum):
14 |     """
15 |     What is used as the model's output variance.
16 |     The LEARNED_RANGE option has been added to allow the model to predict
17 |     values between FIXED_SMALL and FIXED_LARGE, making its job easier.
18 |     """
19 | 
20 |     LEARNED = enum.auto()
21 |     FIXED_SMALL = enum.auto()
22 |     FIXED_LARGE = enum.auto()
23 |     LEARNED_RANGE = enum.auto()
24 | 
25 | 
26 | class LossType(enum.Enum):
27 |     MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
28 |     RESCALED_MSE = (
29 |         enum.auto()
30 |     )  # use raw MSE loss (with RESCALED_KL when learning variances)
31 |     KL = enum.auto()  # use the variational lower-bound
32 |     RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB
33 | 
34 |     def is_vb(self):
35 |         return self == LossType.KL or self == LossType.RESCALED_KL
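These enum members are selected from plain strings in the YAML configs; a minimal sketch of that mapping, mirroring the `ModelMeanTypeDict` / `ModelVarTypeDict` / `LossTypeDict` lookups in `models/diffusion/__init__.py` (assumes the repository root is on `PYTHONPATH`):

```python
from models.diffusion.common import ModelMeanType, ModelVarType, LossType

mean_type = ModelMeanType.EPSILON    # model_mean_type: EPSILON in configs/*.yaml
var_type = ModelVarType.FIXED_SMALL  # model_var_type: FIXED_SMALL
loss_type = LossType.MSE             # loss_type: MSE
assert not loss_type.is_vb()         # only KL / RESCALED_KL are variational-bound losses
```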
--------------------------------------------------------------------------------
/ddp_main.py:
--------------------------------------------------------------------------------
 1 | # This source code is licensed under the MIT license found in the
 2 | # LICENSE file in the root directory of this source tree.
 3 | 
 4 | import os
 5 | import os.path as osp
 6 | import sys
 7 | import torch
 8 | import hydra
 9 | from hydra.core.hydra_config import HydraConfig
10 | import numpy as np
11 | import argparse
12 | import shutil
13 | import time
14 | from tools.ddp_trainer import DiffusionTrainer
15 | from lib.distributed import multi_proc_run, ErrorHandler
16 | import random
17 | 
18 | 
19 | def single_proc_run(config):
20 |     if not torch.cuda.is_available():
21 |         raise Exception('No GPUs found.')
22 |     trainer = DiffusionTrainer(config)
23 |     if config.train.is_train:
24 |         trainer.train()
25 |     else:
26 |         trainer.test()
27 | 
28 | def get_args():
29 |     parser = argparse.ArgumentParser('DiffComplete')
30 |     parser.add_argument('--config', type=str, help='name of config file')
31 |     args = parser.parse_args()
32 |     return args
33 | 
34 | 
35 | @hydra.main(config_path='configs', config_name='epn_control_train.yaml')
36 | def main(config):
37 |     # fix seed
38 |     np.random.seed(config.misc.seed)
39 |     torch.manual_seed(config.misc.seed)
40 |     torch.cuda.manual_seed(config.misc.seed)
41 |     random.seed(config.misc.seed)
42 | 
43 |     # Launch a single process, or spawn one worker process per GPU
44 |     if config.exp.num_gpus > 1:
45 |         multi_proc_run(config.exp.num_gpus, fun=single_proc_run, fun_args=(config,))
46 |     else:
47 |         single_proc_run(config)
48 | 
49 | if __name__ == '__main__':
50 |     __spec__ = None
51 |     os.environ['MKL_THREADING_LAYER'] = 'GNU'
52 |     # os.environ["OMP_NUM_THREADS"] = "4"
53 |     main()
--------------------------------------------------------------------------------
/utils/tool_add_control.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | import os
 3 | 
 4 | assert len(sys.argv) == 3, 'Args are wrong.'
 5 | 
 6 | # Usage: python utils/tool_add_control.py <input_model_path> <output_model_path>
 7 | 
 8 | 
 9 | input_path = sys.argv[1]
10 | output_path = sys.argv[2]
11 | 
12 | assert os.path.exists(input_path), 'Input model does not exist.'
13 | assert not os.path.exists(output_path), 'Output filename already exists.'
14 | assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.'
15 | 
16 | import torch
17 | from share import *
18 | from cldm.model import create_model
19 | 
20 | 
21 | def get_node_name(name, parent_name):
22 |     if len(name) <= len(parent_name):
23 |         return False, ''
24 |     p = name[:len(parent_name)]
25 |     if p != parent_name:
26 |         return False, ''
27 |     return True, name[len(parent_name):]
28 | 
29 | 
30 | model = create_model(config_path='./models/cldm_v15.yaml')
31 | 
32 | 
33 | pretrained_weights = torch.load(input_path)
34 | if 'state_dict' in pretrained_weights:
35 |     pretrained_weights = pretrained_weights['state_dict']
36 | 
37 | scratch_dict = model.state_dict()
38 | 
39 | target_dict = {}
40 | for k in scratch_dict.keys():
41 |     is_control, name = get_node_name(k, 'control_')
42 |     if is_control:
43 |         copy_k = 'model.diffusion_' + name
44 |     else:
45 |         copy_k = k
46 |     if copy_k in pretrained_weights:
47 |         target_dict[k] = pretrained_weights[copy_k].clone()
48 |     else:
49 |         target_dict[k] = scratch_dict[k].clone()
50 |         print(f'These weights are newly added: {k}')
51 | 
52 | model.load_state_dict(target_dict, strict=True)
53 | torch.save(model.state_dict(), output_path)
54 | print('Done.')
55 | 
--------------------------------------------------------------------------------
/data/sdf_2_npy.py:
--------------------------------------------------------------------------------
 1 | import glob
 2 | import sys
 3 | import os
 4 | import numpy as np
 5 | import os.path as osp
 6 | 
 7 | def get_shape_df(path):
 8 |     dims = np.fromfile(path, np.uint64, 3)
 9 |     df = np.fromfile(path, np.float32, offset=3 * 8).reshape(dims)
10 |     return df
11 | 
12 | def get_shape_sdf(path):
13 |     dims = np.fromfile(path, np.uint64, 3)
14 |     sdf = np.fromfile(path, dtype=np.float32, offset=3 * 8).reshape(dims)
15 |     return sdf
16 | 
17 | if __name__ == '__main__':
18 | 
19 |     base_df_path = "/mnt/proj74/rhchu/dataset/3d_epn/shapenet_dim32_df"
20 |     out_df_path = "/mnt/proj74/rhchu/dataset/3d_epn/shapenet_dim32_df_npy"
21 |     base_sdf_path = "/mnt/proj74/rhchu/dataset/3d_epn/shapenet_dim32_sdf"
22 |     out_sdf_path = "/mnt/proj74/rhchu/dataset/3d_epn/shapenet_dim32_sdf_npy"
23 | 
24 |     clss = ['02933112', '04530566', '03636649', '02691156', '02958343', '04379243', '04256520', '03001627']
25 | 
26 |     # Process df files
27 |     for cls in clss:
28 |         df_path = osp.join(base_df_path, cls)
29 |         df_files = os.listdir(df_path)
30 |         for df_file in df_files:
31 |             df_name = df_file[:-3]
32 |             df = get_shape_df(osp.join(df_path, df_file))
33 |             out_cls_path = osp.join(out_df_path, cls)
34 |             os.makedirs(out_cls_path, exist_ok=True)
35 |             out_file = osp.join(out_cls_path, df_name + '.npy')
36 |             np.save(out_file, df)
37 | 
38 |     # Process sdf files
39 |     for cls in clss:
40 |         sdf_path = osp.join(base_sdf_path, cls)
41 |         sdf_files = os.listdir(sdf_path)
42 |         for sdf_file in sdf_files:
43 |             sdf_name = sdf_file[:-4]
44 |             sdf = get_shape_sdf(osp.join(sdf_path, sdf_file))
45 |             sdf = np.expand_dims(sdf, 0)
46 |             sdf = np.concatenate([np.fabs(sdf), np.sign(sdf)], axis=0)
47 |             out_cls_path = osp.join(out_sdf_path, cls)
48 |             os.makedirs(out_cls_path, exist_ok=True)
49 |             out_file = osp.join(out_cls_path, sdf_name + '.npy')
50 |             np.save(out_file, sdf)
51 | 
52 | 
53 | 
54 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | # Distribution / packaging
10 | .Python
11 | 
build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | -------------------------------------------------------------------------------- /models/modules/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 
31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() -------------------------------------------------------------------------------- /models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | import models.diffusion.gaussian_diffusion as gaussian_diffusion 2 | from models.diffusion.gaussian_diffusion import space_timesteps 3 | from models.diffusion.common import ModelMeanType, ModelVarType, LossType 4 | 5 | ModelMeanTypeDict= { 6 | 'PREVIOUS_X': ModelMeanType.PREVIOUS_X, 7 | 'START_X': ModelMeanType.START_X, 8 | 'EPSILON': ModelMeanType.EPSILON 9 | } 10 | 11 | ModelVarTypeDict= { 12 | 'LEARNED': ModelVarType.LEARNED, 13 | 'FIXED_SMALL': ModelVarType.FIXED_SMALL, 14 | 'FIXED_LARGE': ModelVarType.FIXED_LARGE, 15 | 'LEARNED_RANGE': ModelVarType.LEARNED_RANGE 16 | } 17 | 18 | LossTypeDict= { 19 | 'MSE': LossType.MSE, 20 | 'RESCALED_MSE': LossType.RESCALED_MSE, 21 | 'KL': LossType.KL, 22 | 'RESCALED_KL': LossType.RESCALED_KL 23 | } 24 | 25 | MODELS = [] 26 | 27 | def add_models(module): 28 | MODELS.extend([getattr(module, a) for a in dir(module) if 'Diffusion' in a]) 29 | 30 | add_models(gaussian_diffusion) 31 | 32 | 33 | def get_models(): 34 | '''Returns a tuple of sample models.''' 35 | return MODELS 36 | 37 | def load_diff_model(name): 38 | '''Creates and returns an instance of the model given its class name. 39 | ''' 40 | # Find the model class from its name 41 | all_models = get_models() 42 | mdict = {model.__name__: model for model in all_models} 43 | if name not in mdict: 44 | print('Invalid model index. 
Options are:') 45 | # Display a list of valid model names 46 | for model in all_models: 47 | print('\t* {}'.format(model.__name__)) 48 | return None 49 | NetClass = mdict[name] 50 | 51 | return NetClass 52 | 53 | def initialize_diff_model(DiffusionClass, betas, config): 54 | model_var_type = ModelVarTypeDict[config.diffusion.model_var_type] 55 | model_mean_type = ModelMeanTypeDict[config.diffusion.model_mean_type] 56 | loss_type = LossTypeDict[config.diffusion.loss_type] 57 | 58 | if not 'Spaced' in DiffusionClass.__name__: 59 | model = DiffusionClass(betas=betas, 60 | model_var_type=model_var_type, 61 | model_mean_type=model_mean_type, 62 | loss_type=loss_type, 63 | rescale_timesteps=config.diffusion.rescale_timestep 64 | if hasattr(config.diffusion, 'rescale_timestep') else False # False 65 | ) 66 | else: 67 | respacing = [config.diffusion.step // config.diffusion.respacing] 68 | 69 | model = DiffusionClass(use_timesteps=space_timesteps(config.diffusion.step, respacing), 70 | betas=betas, 71 | model_var_type=model_var_type, 72 | model_mean_type=model_mean_type, 73 | loss_type=loss_type 74 | ) 75 | 76 | return model -------------------------------------------------------------------------------- /lib/solvers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import logging 8 | 9 | from torch.optim import SGD, Adam 10 | from torch.optim.lr_scheduler import LambdaLR, StepLR 11 | 12 | 13 | class LambdaStepLR(LambdaLR): 14 | 15 | def __init__(self, optimizer, lr_lambda, last_step=-1): 16 | super(LambdaStepLR, self).__init__(optimizer, lr_lambda, last_step) 17 | 18 | @property 19 | def last_step(self): 20 | """Use last_epoch for the step counter""" 21 | return self.last_epoch 22 | 23 | @last_step.setter 24 | def last_step(self, v): 25 | self.last_epoch = v 26 | 27 | 28 | class PolyLR(LambdaStepLR): 29 | """DeepLab learning rate policy""" 30 | 31 | def __init__(self, optimizer, max_iter, power=0.9, last_step=-1): 32 | super(PolyLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**power, last_step) 33 | 34 | 35 | class SquaredLR(LambdaStepLR): 36 | """ Used for SGD Lars""" 37 | 38 | def __init__(self, optimizer, max_iter, last_step=-1): 39 | super(SquaredLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**2, last_step) 40 | 41 | 42 | class ExpLR(LambdaStepLR): 43 | 44 | def __init__(self, optimizer, step_size, gamma=0.9, last_step=-1): 45 | # (0.9 ** 21.854) = 0.1, (0.95 ** 44.8906) = 0.1 46 | # To get 0.1 every N using gamma 0.9, N * log(0.9)/log(0.1) = 0.04575749 N 47 | # To get 0.1 every N using gamma g, g ** N = 0.1 -> N * log(g) = log(0.1) -> g = np.exp(log(0.1) / N) 48 | super(ExpLR, self).__init__(optimizer, lambda s: gamma**(s / step_size), last_step) 49 | 50 | 51 | def initialize_optimizer(params, config): 52 | assert config.optimizer in ['SGD', 'Adagrad', 'Adam', 'RMSProp', 'Rprop', 'SGDLars'] 53 | 54 | if config.optimizer == 'SGD': 55 | return SGD( 56 | params, 57 | lr=config.lr, 58 | momentum=config.sgd_momentum, 59 | dampening=config.sgd_dampening, 60 | weight_decay=config.weight_decay) 61 | elif config.optimizer == 'Adam': 62 | return Adam( 63 | params, 64 | lr=config.lr, 65 | betas=(config.adam_beta1, config.adam_beta2), 66 | weight_decay=config.weight_decay) 67 | else: 68 | logging.error('Optimizer type not supported') 
69 |     raise ValueError('Optimizer type not supported')
70 | 
71 | 
72 | def initialize_scheduler(optimizer, config, last_step=-1):
73 |   if config.scheduler == 'StepLR':
74 |     return StepLR(
75 |         optimizer, step_size=config.step_size, gamma=config.step_gamma, last_epoch=last_step)
76 |   elif config.scheduler == 'PolyLR':
77 |     return PolyLR(optimizer, max_iter=config.max_iter, power=config.poly_power, last_step=last_step)
78 |   elif config.scheduler == 'SquaredLR':
79 |     return SquaredLR(optimizer, max_iter=config.max_iter, last_step=last_step)
80 |   elif config.scheduler == 'ExpLR':
81 |     return ExpLR(
82 |         optimizer, step_size=config.exp_step_size, gamma=config.exp_gamma, last_step=last_step)
83 |   else:
84 |     logging.error('Scheduler not supported')
85 |     raise ValueError('Scheduler not supported')
--------------------------------------------------------------------------------
/datasets/epn_control.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | import os.path as osp
 4 | from pathlib import Path
 5 | import logging
 6 | from torch.utils.data import Dataset
 7 | from datasets.dataset import DictDataset, DatasetPhase, str2datasetphase_type
 8 | from lib.utils import read_txt
 9 | 
10 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
11 | 
12 | class ControlledEPNDataset(DictDataset):
13 | 
14 |     DATA_PATH_FILE = {
15 |         DatasetPhase.Train: 'train',
16 |         DatasetPhase.Test: 'test',
17 |         DatasetPhase.Debug: 'test'
18 |     }
19 | 
20 |     def __init__(self,
21 |                  config,
22 |                  input_transform=None,
23 |                  target_transform=None,
24 |                  augment_data=False,
25 |                  cache=False,
26 |                  phase=DatasetPhase.Train):
27 | 
28 |         if isinstance(phase, str):
29 |             phase = str2datasetphase_type(phase)
30 | 
31 |         data_root = config.data.data_dir
32 |         if config.data.per_class:
33 |             data_paths = read_txt(osp.join(data_root, 'splits', self.DATA_PATH_FILE[phase]+'_'+config.data.class_id+'.txt'))
34 |         else:
35 |             data_paths = read_txt(osp.join(data_root, 'splits', self.DATA_PATH_FILE[phase]+'.txt'))
36 | 
37 |         self.config = config
38 |         self.representation = config.exp.representation
39 |         self.trunc_distance = config.data.trunc_distance
40 |         self.log_df = config.data.log_df
41 |         self.data_paths = data_paths
42 |         self.augment_data = augment_data
43 |         self.input_transform = input_transform
44 |         self.target_transform = target_transform
45 |         self.suffix = config.data.suffix
46 |         # logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase]))
47 | 
48 |         DictDataset.__init__(
49 |             self,
50 |             data_paths,
51 |             input_transform=input_transform,
52 |             target_transform=target_transform,
53 |             cache=cache,
54 |             data_root=data_root)
55 | 
56 | 
57 |     def get_output_id(self, iteration):
58 |         return '_'.join(Path(self.data_paths[iteration]).stem.split('_')[:2])
59 | 
60 |     def load(self, filename):
61 |         return torch.load(filename)
62 | 
63 |     def __len__(self):
64 |         return len(self.data_paths)
65 | 
66 |     def __getitem__(self, index: int):
67 |         filename = self.data_root / self.data_paths[index]
68 |         scan_id = osp.basename(filename).replace(self.suffix, '')
69 |         input_sdf, gt_df = self.load(filename)
70 | 
71 |         if self.representation == 'tsdf':
72 |             input_sdf = np.clip(input_sdf, -self.trunc_distance, self.trunc_distance)
73 |             gt_df = np.clip(gt_df, 0.0, self.trunc_distance)
74 | 
75 |         if self.log_df:
76 |             gt_df = np.log(gt_df + 1)
77 | 
78 |         # Transformation
79 |         if self.input_transform is not None:
80 |             input_sdf = self.input_transform(input_sdf)
81 | 
if self.target_transform is not None: 82 | gt_df = self.target_transform(gt_df) 83 | 84 | return scan_id, input_sdf, gt_df 85 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the MIT license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | diffusion: 5 | model: GaussianDiffusion 6 | test_model: RepaintSpacedDiffusion 7 | step: 1000 8 | model_var_type: FIXED_SMALL 9 | learn_sigma: False 10 | sampler: 'second-order' 11 | model_mean_type: EPSILON 12 | rescale_timestep: False 13 | loss_type: MSE 14 | beta_schedule: 'linear' 15 | scale_ratio: 1.0 16 | diffusion_learn_sigma: False 17 | respacing: 10 18 | # diffusion_model_var_type: LEARNED_RANGE 19 | # diffusion_loss_type: RESCALED_MSE 20 | 21 | # if diffusion_learn_sigma: 22 | # diffusion_model_var_type = ModelVarType.LEARNED_RANGE 23 | # diffusion_loss_type = LossType.RESCALED_MSE 24 | 25 | net: 26 | network: ResUNet 27 | in_channels: 1 28 | model_channels: 64 29 | num_res_blocks: 3 30 | channel_mult: 1,2,2,2 31 | attention_resolutions: 32 | unet_activation: 33 | weights: checkpoint_ResUNet_iter200000.pth 34 | 35 | 36 | optimizer: 37 | optimizer: Adam 38 | lr: 0.0001 39 | adam_beta1: 0.9 40 | adam_beta2: 0.999 41 | lr_decay: False 42 | weight_decay: 0 43 | 44 | # Scheduler 45 | scheduler: StepLR 46 | step_size: 500 47 | step_gamma: 0.1 48 | poly_power: 0.9 49 | exp_gamma: 0.95 50 | exp_step_size: 445 51 | 52 | data: 53 | dataset: ShapeNetDataset 54 | train_file: 55 | data_dir: 56 | collate_fn: 57 | input_transform: 58 | targer_transform: 59 | cache_data: False 60 | persistent_workers: False 61 | suffix: .npy 62 | 63 | train: 64 | max_iter: 270000 65 | is_train: True 66 | stat_freq: 50 67 | val_freq: 10000 68 | empty_cache_freq: 1 69 | train_phase: train 70 | overwrite_weights: False 71 | resume: True 72 | resume_optimizer: True 73 | eval_upsample: False 74 | lenient_weight_loading: False 75 | mix_precision: True 76 | use_gradient_clip: False 77 | gradient_clip_value: 1.0 78 | 79 | # Test 80 | test: 81 | partial_shape: 82 | test_cnt: 160 83 | clip_noise: False 84 | use_ddim: False 85 | ddim_eta: 1.0 86 | test_phase: test 87 | test_batch_size: 4 88 | 89 | 90 | # Misc 91 | misc: 92 | seed: 123 93 | 94 | exp: 95 | res: 32 96 | representation: sdf 97 | 98 | batch_size: 64 99 | num_gpus: 1 100 | num_workers: 0 101 | 102 | skip_validate: True 103 | log_dir: exps/default 104 | 105 | 106 | ################################################################################ 107 | # slurm parameters 108 | ################################################################################ 109 | defaults: 110 | - hydra/launcher: submitit_slurm 111 | - hydra/hydra_logging: colorlog 112 | 113 | hydra: 114 | run: 115 | dir: ${exp.log_dir} 116 | sweep: 117 | dir: ${exp.log_dir} 118 | launcher: 119 | partition: dev 120 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 121 | name: ${hydra.job.name} 122 | timeout_min: 4320 123 | cpus_per_task: 24 124 | gpus_per_node: ${exp.num_gpus} 125 | tasks_per_node: 1 126 | mem_gb: 256 127 | nodes: 1 128 | constraint: 129 | exclude: seti 130 | max_num_timeout: 3 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | 
  7 | from abc import ABC
  8 | from pathlib import Path
  9 | from collections import defaultdict
 10 | 
 11 | import random
 12 | import numpy as np
 13 | from enum import Enum
 14 | 
 15 | import torch
 16 | from torch.utils.data import Dataset, DataLoader
 17 | 
 18 | import datasets.transforms as t
 19 | from datasets.dataloader import InfSampler, DistributedInfSampler
 20 | from lib.distributed import get_world_size
 21 | 
 22 | 
 23 | class DatasetPhase(Enum):
 24 |   Train = 0
 25 |   Val = 1
 26 |   Test = 2
 27 |   Debug = 3
 28 | 
 29 | 
 30 | def datasetphase_2str(arg):
 31 |   if arg == DatasetPhase.Train:
 32 |     return 'train'
 33 |   elif arg == DatasetPhase.Val:
 34 |     return 'val'
 35 |   elif arg == DatasetPhase.Test:
 36 |     return 'test'
 37 |   elif arg == DatasetPhase.Debug:
 38 |     return 'debug'
 39 |   else:
 40 |     raise ValueError('phase must be one of dataset enum.')
 41 | 
 42 | 
 43 | def str2datasetphase_type(arg):
 44 |   if arg.upper() == 'TRAIN':
 45 |     return DatasetPhase.Train
 46 |   elif arg.upper() == 'VAL':
 47 |     return DatasetPhase.Val
 48 |   elif arg.upper() == 'TEST':
 49 |     return DatasetPhase.Test
 50 |   elif arg.upper() == 'DEBUG':
 51 |     return DatasetPhase.Debug
 52 |   else:
 53 |     raise ValueError('phase must be one of train/val/test/debug')
 54 | 
 55 | 
 56 | class DictDataset(Dataset, ABC):
 57 | 
 58 |   def __init__(self,
 59 |                data_paths,
 60 |                input_transform=None,
 61 |                target_transform=None,
 62 |                cache=False,
 63 |                data_root='/'):
 64 |     """
 65 |     data_paths: list of lists, [[str_path_to_input, str_path_to_label], [...]]
 66 |     """
 67 |     Dataset.__init__(self)
 68 | 
 69 |     # Allows easier path concatenation
 70 |     if not isinstance(data_root, Path):
 71 |       data_root = Path(data_root)
 72 | 
 73 |     self.data_root = data_root
 74 |     self.data_paths = sorted(data_paths)
 75 | 
 76 |     self.input_transform = input_transform
 77 |     self.target_transform = target_transform
 78 | 
 79 |     # dictionary of input
 80 |     self.data_loader_dict = {
 81 |         'input': (self.load_input, self.input_transform),
 82 |         'target': (self.load_target, self.target_transform)
 83 |     }
 84 | 
 85 |     # For large dataset, do not cache
 86 |     self.cache = cache
 87 |     self.cache_dict = defaultdict(dict)
 88 |     self.loading_key_order = ['input', 'target']
 89 | 
 90 |   def load_input(self, index):
 91 |     raise NotImplementedError
 92 | 
 93 |   def load_target(self, index):
 94 |     raise NotImplementedError
 95 | 
 96 |   def get_classnames(self):
 97 |     pass
 98 | 
 99 |   def reorder_result(self, result):
100 |     return result
101 | 
102 |   def __getitem__(self, index):
103 |     out_array = []
104 |     for k in self.loading_key_order:
105 |       loader, transformer = self.data_loader_dict[k]
106 |       v = loader(index)
107 |       if transformer:
108 |         v = transformer(v)
109 |       out_array.append(v)
110 |     return out_array
111 | 
112 |   def __len__(self):
113 |     return len(self.data_paths)
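A minimal usage sketch of the phase helpers above (assuming the repository root is on `PYTHONPATH`):

```python
from datasets.dataset import DatasetPhase, str2datasetphase_type, datasetphase_2str

phase = str2datasetphase_type('train')      # config strings -> enum members
assert phase == DatasetPhase.Train
assert datasetphase_2str(phase) == 'train'  # and back again, e.g. for log messages
```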
--------------------------------------------------------------------------------
/models/networks/__init__.py:
--------------------------------------------------------------------------------
 1 | import torch.nn as nn
 2 | from models.networks import resunet3d, controlnet
 3 | 
 4 | # Import network
 5 | MODELS = []
 6 | 
 7 | def add_models(module):
 8 |     MODELS.extend([getattr(module, a) for a in dir(module) if 'Net' in a])
 9 | 
10 | add_models(resunet3d)
11 | add_models(controlnet)
12 | 
13 | def get_models():
14 |     '''Returns a tuple of sample models.'''
15 |     return MODELS
16 | 
17 | def load_network(name):
18 |     '''Returns the network class given its class name.
19 |     '''
20 |     # Find the model class from its name
21 |     all_models = get_models()
22 |     mdict = {model.__name__: model for model in all_models}
23 |     if name not in mdict:
24 |         print('Invalid model name. Options are:')
25 |         # Display a list of valid model names
26 |         for model in all_models:
27 |             print('\t* {}'.format(model.__name__))
28 |         return None
29 |     NetClass = mdict[name]
30 | 
31 |     return NetClass
32 | 
33 | def initialize_network(NetClass, config):
34 |     if not isinstance(NetClass, type) or not issubclass(NetClass, nn.Module):
35 |         raise TypeError("network class must be a subclass of nn.Module")
36 | 
37 |     if isinstance(config.net.channel_mult, str):
38 |         config.net.channel_mult = list(map(int, config.net.channel_mult.split(',')))
39 | 
40 |     if not config.net.attention_resolutions:
41 |         config.net.attention_resolutions = []
42 |     else:
43 |         config.net.attention_resolutions = list(map(int, config.net.attention_resolutions.split(',')))
44 | 
45 |     model = NetClass(
46 |         in_channels=config.net.in_channels, model_channels=config.net.model_channels,
47 |         out_channels=2 if hasattr(config.diffusion, 'diffusion_learn_sigma')
48 |         and config.diffusion.diffusion_learn_sigma else 1,  # 1
49 |         num_res_blocks=config.net.num_res_blocks,  # 3
50 |         channel_mult=config.net.channel_mult,  # (1, 2, 2, 2)
51 |         attention_resolutions=config.net.attention_resolutions,  # []
52 |         dropout=0,
53 |         dims=3,
54 |         activation=config.net.unet_activation if hasattr(config.net, 'unet_activation') else None
55 |     )
56 | 
57 |     return model
58 | 
59 | def initialize_controlnet(NetClass, config):
60 |     if not isinstance(NetClass, type) or not issubclass(NetClass, nn.Module):
61 |         raise TypeError("network class must be a subclass of nn.Module")
62 | 
63 |     if isinstance(config.net.channel_mult, str):
64 |         config.net.channel_mult = list(map(int, config.net.channel_mult.split(',')))
65 | 
66 |     if not config.net.attention_resolutions:
67 |         config.net.attention_resolutions = []
68 |     else:
69 |         config.net.attention_resolutions = list(map(int, config.net.attention_resolutions.split(',')))
70 | 
71 |     model = NetClass(
72 |         in_channels=config.net.in_channels, model_channels=config.net.model_channels,
73 |         hint_channels=config.net.hint_channels,
74 |         out_channels=2 if hasattr(config.diffusion, 'diffusion_learn_sigma')
75 |         and config.diffusion.diffusion_learn_sigma else 1,  # 1
76 |         num_res_blocks=config.net.num_res_blocks,  # 3
77 |         channel_mult=config.net.channel_mult,  # (1, 2, 2, 2)
78 |         attention_resolutions=config.net.attention_resolutions,  # []
79 |         dropout=0,
80 |         dims=3,
81 |         activation=config.net.unet_activation if hasattr(config.net, 'unet_activation') else None
82 |     )
83 | 
84 |     return model
85 | 
--------------------------------------------------------------------------------
/utils/generate_3d_mask.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from PIL import Image, ImageDraw
 3 | import math
 4 | import random
 5 | import argparse
 6 | import os
 7 | from utils.visualize_voxel import visualize_data
 8 | 
 9 | def toU8(sample):
10 |     if sample is None:
11 |         return sample
12 | 
13 |     sample = np.clip(((sample + 1) * 127.5), 0, 255).astype(np.uint8)
14 |     sample = np.transpose(sample, (1, 2, 0))
15 |     return sample
16 | 
17 | def write_images(voxels, img_names, dir_path):
18 |     os.makedirs(dir_path, exist_ok=True)
19 | 
20 |     for image_name, voxel in zip(img_names, voxels):
21 |         out_path = os.path.join(dir_path, image_name)
22 |         visualize_data(voxel, 'voxels', out_path)
23 | 
24 | def RandomCrop(s, tries, hole_range=[0, 1]):
25 |     coef = min(hole_range[0] + hole_range[1], 1.0)
26 | 
27 |     while True:
28 |         mask = np.ones((s, s, s), np.uint8)
29 |         def Fill(max_size):
30 |             l, w, h = np.random.randint(max_size), np.random.randint(max_size), np.random.randint(max_size)
31 |             ll, ww, hh = l // 2, w // 2, h // 2
32 |             x, y, z = np.random.randint(-ll, s - l + ll), np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
33 |             mask[max(z, 0): min(z + h, s), max(y, 0): min(y + w, s), max(x, 0): min(x + l, s)] = 0
34 |         def MultiFill(max_tries, max_size):
35 |             for _ in range(max_tries):
36 |                 Fill(max_size)
37 |         MultiFill(int(tries * coef), s)
38 |         hole_ratio = 1 - np.mean(mask)
39 |         # Resample until the hole ratio falls inside the requested range.
40 |         if hole_range[0] <= hole_ratio <= hole_range[1]:
41 |             return mask[np.newaxis, ...].astype(np.float32)
42 | 
43 | 
44 | def HalfCrop(s):
45 | 
46 |     axis = np.random.randint(3)
47 |     mask = np.ones((s, s, s), np.uint8)
48 |     crop = np.random.randint(int(0.3 * s), int(0.7 * s))
49 |     if axis == 0:
50 |         if np.random.random() > 0.5:
51 |             mask[: crop, ...] = 0
52 |         else:
53 |             mask[crop:, ...] = 0
54 |     elif axis == 1:
55 |         if np.random.random() > 0.5:
56 |             mask[:, :crop, :] = 0
57 |         else:
58 |             mask[:, crop:, :] = 0
59 |     else:
60 |         if np.random.random() > 0.5:
61 |             mask[..., :crop] = 0
62 |         else:
63 |             mask[..., crop:] = 0
64 | 
65 |     return mask[np.newaxis, ...].astype(np.float32)
66 | 
67 | def BatchRandomMask(batch_size, s, tries=8, hole_range=[0, 1]):
68 |     return np.stack([RandomCrop(s, tries, hole_range=hole_range) for _ in range(batch_size)], axis=0)
69 | 
70 | if __name__ == '__main__':
71 | 
72 |     parser = argparse.ArgumentParser()
73 |     parser.add_argument('--res', type=int, required=False, default=32)
74 |     parser.add_argument('--tries', type=int, required=False, default=8)
75 |     parser.add_argument('--type', type=str, required=False, default="random")
76 |     args = parser.parse_args()
77 | 
78 |     cnt = 50
79 |     tot = 0
80 |     dir_path = "./output_mask_half" if args.type == "half" else "./output_mask_random"
81 | 
82 |     masks = []
83 |     names = []
84 |     for i in range(cnt):
85 |         mask = HalfCrop(s=args.res) if args.type == "half" else RandomCrop(s=args.res, tries=args.tries)
86 |         mask = np.squeeze(mask)
87 |         tot += mask.mean()
88 |         masks.append(mask)
89 |         names.append(f"{i}.jpg")
90 |     print(tot / cnt)
91 |     write_images(masks, names, dir_path)
92 | 
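A minimal usage sketch of the mask generators above (shapes follow the functions as written; the values are illustrative):

```python
import numpy as np  # masks are plain numpy arrays: 1 = voxel kept, 0 = voxel cropped out

m_random = RandomCrop(s=32, tries=8)         # union of up to `tries` random boxes, (1, 32, 32, 32) float32
m_half = HalfCrop(s=32)                      # axis-aligned cut at 30-70% of one side
batch = BatchRandomMask(batch_size=4, s=32)  # stacked masks, (4, 1, 32, 32, 32)
print(m_random.shape, 1 - m_random.mean())   # mask shape and its hole ratio
```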
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.data import DataLoader
 2 | from lib.distributed import get_world_size
 3 | from datasets.dataloader import InfSampler, DistributedInfSampler, SequentialDistributedSampler
 4 | import datasets.transforms as t
 5 | from datasets.dataset import str2datasetphase_type, DatasetPhase
 6 | 
 7 | from datasets import epn_control
 8 | 
 9 | DATASETS = []
10 | 
11 | def add_datasets(module):
12 |   DATASETS.extend([getattr(module, a) for a in dir(module) if 'Dataset' in a])
13 | 
14 | add_datasets(epn_control)
15 | 
16 | def load_dataset(name):
17 |   '''Returns the dataset class given its name.
18 |   '''
19 |   # Find the dataset class from its name
20 |   mdict = {dataset.__name__: dataset for dataset in DATASETS}
21 |   if name not in mdict:
22 |     print('Invalid dataset name. Options are:')
23 |     # Display a list of valid dataset names
24 |     for dataset in DATASETS:
25 |       print('\t* {}'.format(dataset.__name__))
26 |     raise ValueError(f'Dataset {name} not defined')
27 |   DatasetClass = mdict[name]
28 | 
29 |   return DatasetClass
30 | 
31 | 
32 | def initialize_data_loader(DatasetClass,
33 |                            config,
34 |                            phase,
35 |                            num_workers,
36 |                            shuffle,
37 |                            repeat,
38 |                            collate,
39 |                            augment_data,
40 |                            batch_size,
41 |                            input_transform=None,
42 |                            target_transform=None,
43 |                            persistent_workers=False):
44 |   if isinstance(phase, str):
45 |     phase = str2datasetphase_type(phase)
46 | 
47 |   # Transform: currently None
48 |   transform_train = []
49 |   if augment_data:
50 |     if input_transform is not None:
51 |       transform_train += input_transform
52 | 
53 |   if len(transform_train) > 0:
54 |     transforms = t.Compose(transform_train)
55 |   else:
56 |     transforms = None
57 | 
58 |   dataset = DatasetClass(
59 |       config,
60 |       input_transform=transforms,
61 |       target_transform=target_transform,
62 |       cache=config.data.cache_data,
63 |       augment_data=augment_data,
64 |       phase=phase)
65 | 
66 |   if collate:
67 |     collate_fn = t.collate_fn_factory()
68 | 
69 |     data_args = {
70 |         'dataset': dataset,
71 |         'num_workers': num_workers,
72 |         'batch_size': batch_size,
73 |         'collate_fn': collate_fn,
74 |         'persistent_workers': persistent_workers
75 |     }
76 |   else:
77 |     data_args = {
78 |         'dataset': dataset,
79 |         'num_workers': num_workers,
80 |         'batch_size': batch_size,
81 |         'persistent_workers': persistent_workers
82 |     }
83 | 
84 |   if repeat:
85 |     if get_world_size() > 1:
86 |       data_args['sampler'] = DistributedInfSampler(dataset, shuffle=shuffle)
87 |     else:
88 |       data_args['sampler'] = InfSampler(dataset, shuffle)
89 | 
90 |   else:
91 |     data_args['shuffle'] = shuffle
92 | 
93 |   if not config.train.is_train and config.test.partial_shape and get_world_size() > 1:
94 |     data_args['sampler'] = SequentialDistributedSampler(dataset, batch_size=batch_size)
95 | 
96 |   data_loader = DataLoader(**data_args)
97 | 
98 |   return data_loader
--------------------------------------------------------------------------------
/configs/epn_control_train.yaml:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the MIT license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | 4 | diffusion: 5 | model: GaussianDiffusion 6 | test_model: GaussianDiffusion 7 | step: 1000 8 | model_var_type: FIXED_SMALL 9 | learn_sigma: False 10 | sampler: 'second-order' 11 | model_mean_type: EPSILON 12 | rescale_timestep: False 13 | loss_type: MSE 14 | beta_schedule: 'linear' 15 | scale_ratio: 1.0 16 | diffusion_learn_sigma: False 17 | respacing: 10 18 | # diffusion_model_var_type: LEARNED_RANGE 19 | # diffusion_loss_type: RESCALED_MSE 20 | 21 | # if diffusion_learn_sigma: 22 | # diffusion_model_var_type = ModelVarType.LEARNED_RANGE 23 | # diffusion_loss_type = LossType.RESCALED_MSE 24 | 25 | net: 26 | network: ControlledUNet 27 | in_channels: 1 28 | model_channels: 64 29 | hint_channels: 2 30 | num_res_blocks: 3 31 | channel_mult: 1,2,2,2 32 | attention_resolutions: 33 | unet_activation: 34 | weights: 35 | controlnet: ControlNet 36 | control_weights: 37 | sd_locked: True 38 | 39 | optimizer: 40 | optimizer: Adam 41 | lr: 0.0001 42 | adam_beta1: 0.9 43 | adam_beta2: 0.999 44 | lr_decay: False 45 | weight_decay: 0 46 | 47 | # Scheduler 48 | scheduler: StepLR 49 | step_size: 500 50 | step_gamma: 0.1 51 | poly_power: 0.9 52 | exp_gamma: 0.95 53 | exp_step_size: 445 54 | 55 | data: 56 | per_class: True 57 | class_id: '03001627' 58 | dataset: ControlledEPNDataset 59 | train_file: 60 | data_dir: data/3d_epn 61 | collate_fn: 62 | input_transform: 63 | targer_transform: 64 | cache_data: False 65 | persistent_workers: True 66 | suffix: .pth 67 | log_df: False 68 | trunc_distance: 3.0 69 | 70 | train: 71 | train_phase: train 72 | debug: False 73 | max_iter: 300000 74 | is_train: True 75 | stat_freq: 50 76 | val_freq: 20000 77 | empty_cache_freq: 1 78 | overwrite_weights: False 79 | resume: True 80 | resume_optimizer: True 81 | eval_upsample: False 82 | lenient_weight_loading: False 83 | mix_precision: True 84 | use_gradient_clip: False 85 | gradient_clip_value: 1.0 86 | weighted_loss: False 87 | 88 | # Test 89 | test: 90 | partial_shape: True 91 | test_cnt: 160 92 | clip_noise: False 93 | use_ddim: False 94 | ddim_eta: 1.0 95 | test_phase: test 96 | test_batch_size: 4 97 | 98 | # Misc 99 | misc: 100 | seed: 123 101 | 102 | exp: 103 | res: 32 104 | representation: tsdf 105 | batch_size: 32 106 | num_gpus: 4 107 | num_workers: 32 108 | 109 | skip_validate: True 110 | log_dir: exps/epn_${data.class_id}_${exp.representation}_${exp.res} 111 | 112 | 113 | ################################################################################ 114 | # slurm parameters 115 | ################################################################################ 116 | defaults: 117 | - hydra/launcher: submitit_slurm 118 | - hydra/hydra_logging: colorlog 119 | 120 | hydra: 121 | run: 122 | dir: ${exp.log_dir} 123 | sweep: 124 | dir: ${exp.log_dir} 125 | launcher: 126 | partition: dev 127 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 128 | name: ${hydra.job.name} 129 | timeout_min: 4320 130 | cpus_per_task: 24 131 | gpus_per_node: ${exp.num_gpus} 132 | tasks_per_node: 1 133 | mem_gb: 256 134 | nodes: 1 135 | constraint: 136 | exclude: seti 137 | max_num_timeout: 3 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffComplete: Diffusion-based Generative 3D Shape Completion [NeurIPS 2023] 2 | 3 | ### [**Project Page**](https://ruihangchu.com/diffcomplete.html) | [**Paper**](https://arxiv.org/abs/2306.16329) | 
[**YouTube**](https://www.youtube.com/watch?v=aCBu5yZEvVI)
 4 | 
 5 | 🔥🔥🔥 DiffComplete is a novel diffusion-based approach that enables multimodal, realistic, and high-fidelity 3D shape completion.
 6 | 
 7 | ![teaser](images/teaser.png)
 8 | 
 9 | # Environments
10 | You can easily set up and activate a conda environment for this project with the following commands:
11 | ```bash
12 | conda env create -f environment.yml
13 | conda activate diffcom
14 | ```
15 | 
16 | # Data Construction
17 | For the 3D-EPN dataset, we download the original data from [3D-EPN](https://graphics.stanford.edu/projects/cnncomplete/data.html) for both training and evaluation.
18 | To run the default setting with a resolution of $32^3$, we download the data files [shapenet_dim32_df.zip](http://kaldir.vc.in.tum.de/adai/CNNComplete/shapenet_dim32_df.zip) and [shapenet_dim32_sdf.zip](http://kaldir.vc.in.tum.de/adai/CNNComplete/shapenet_dim32_sdf.zip) for the complete and partial shapes, respectively.
19 | 
20 | To prepare the data, first run ```data/sdf_2_npy.py``` to convert the files to ```.npy``` format for easier handling. Then run ```data/npy_2_pth.py``` to obtain the paired data of the eight object classes used for model training. (The full command sequence is sketched at the end of this README.)
21 | 
22 | The data structure should be organized as follows before training.
23 | 
24 | ```
25 | DiffComplete
26 | ├── data
27 | │   ├── 3d_epn
28 | │   │   ├── 02691156
29 | │   │   │   ├── 10155655850468db78d106ce0a280f87__0__.pth
30 | │   │   │   ├── ...
31 | │   │   ├── 02933112
32 | │   │   ├── 03001627
33 | │   │   ├── ...
34 | │   │   ├── splits
35 | │   │   │   ├── train_02691156.txt
36 | │   │   │   ├── train_02933112.txt
37 | │   │   │   ├── ...
38 | │   │   │   ├── test_02691156.txt
39 | │   │   │   ├── test_02933112.txt
40 | │   │   │   ├── ...
41 | ```
42 | 
43 | # Training and Inference
44 | Our training and inference processes rely primarily on the configuration files (```configs/epn_control_train.yaml``` and ```configs/epn_control_test.yaml```). You can adjust the number of GPUs by modifying ```exp/num_gpus``` in these ```yaml``` files. This setting trains a separate model for each object category; to switch categories, change ```data/class_id``` in the ```yaml``` file.
45 | 
46 | To train the diffusion model, run the following command:
47 | ```bash
48 | python ddp_main.py --config-name epn_control_train.yaml
49 | ```
50 | 
51 | To test the trained model, specify the paths to the pretrained models by filling in ```net/weights``` and ```net/control_weights``` in the ```yaml``` file, and then run the following command:
52 | ```bash
53 | python ddp_main.py --config-name epn_control_test.yaml train.is_train=False
54 | ```
55 | 
56 | 
57 | ## Citation
58 | If you find our work useful in your research, please consider citing:
59 | ```
60 | @article{chu2024diffcomplete,
61 |   title={Diffcomplete: Diffusion-based generative 3d shape completion},
62 |   author={Chu, Ruihang and Xie, Enze and Mo, Shentong and Li, Zhenguo and Nie{\ss}ner, Matthias and Fu, Chi-Wing and Jia, Jiaya},
63 |   journal={Advances in Neural Information Processing Systems},
64 |   year={2023}
65 | }
66 | ```
67 | 
68 | 
69 | ## Acknowledgement
70 | 
71 | We would like to thank the following repos for their great work:
72 | 
73 | - This work is inspired by [ControlNet](https://github.com/lllyasviel/ControlNet).
74 | - This work utilizes the 3D-UNet from [Wavelet-Generation](https://github.com/edward1997104/Wavelet-Generation).
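For convenience, the full preparation and launch sequence in one place — a minimal sketch, assuming you run from the repo root and have first edited the hard-coded dataset paths at the top of ```data/sdf_2_npy.py``` and ```data/npy_2_pth.py``` (they default to the authors' cluster paths):

```bash
python data/sdf_2_npy.py     # raw .df/.sdf volumes -> .npy
python data/npy_2_pth.py     # pair partial SDFs with complete DFs -> .pth
python ddp_main.py --config-name epn_control_train.yaml
python ddp_main.py --config-name epn_control_test.yaml train.is_train=False
```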
75 | -------------------------------------------------------------------------------- /configs/epn_control_test.yaml: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the MIT license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | diffusion: 5 | model: GaussianDiffusion 6 | test_model: SpacedDiffusion 7 | step: 1000 8 | model_var_type: FIXED_SMALL 9 | learn_sigma: False 10 | sampler: 'second-order' 11 | model_mean_type: EPSILON 12 | rescale_timestep: False 13 | loss_type: MSE 14 | beta_schedule: 'linear' 15 | scale_ratio: 1.0 16 | diffusion_learn_sigma: False 17 | respacing: 10 18 | # diffusion_model_var_type: LEARNED_RANGE 19 | # diffusion_loss_type: RESCALED_MSE 20 | 21 | # if diffusion_learn_sigma: 22 | # diffusion_model_var_type = ModelVarType.LEARNED_RANGE 23 | # diffusion_loss_type = LossType.RESCALED_MSE 24 | 25 | net: 26 | network: ControlledUNet 27 | in_channels: 1 28 | model_channels: 64 29 | hint_channels: 2 30 | num_res_blocks: 3 31 | channel_mult: 1,2,2,2 32 | attention_resolutions: 33 | unet_activation: 34 | weights: exps/epn_03001627_tsdf_32/weights/checkpoint_ControlledUNet_iter200000.pth 35 | controlnet: ControlNet 36 | control_weights: exps/epn_03001627_tsdf_32/weights/checkpoint_ControlNet_iter200000.pth 37 | sd_locked: True 38 | 39 | optimizer: 40 | optimizer: Adam 41 | lr: 0.0001 42 | adam_beta1: 0.9 43 | adam_beta2: 0.999 44 | lr_decay: False 45 | weight_decay: 0 46 | 47 | # Scheduler 48 | scheduler: StepLR 49 | step_size: 500 50 | step_gamma: 0.1 51 | poly_power: 0.9 52 | exp_gamma: 0.95 53 | exp_step_size: 445 54 | 55 | data: 56 | per_class: True 57 | class_id: '03001627' 58 | dataset: ControlledEPNDataset 59 | train_file: 60 | data_dir: data/3d_epn 61 | collate_fn: 62 | input_transform: 63 | targer_transform: 64 | cache_data: False 65 | persistent_workers: False 66 | suffix: .pth 67 | log_df: False 68 | trunc_distance: 3.0 69 | 70 | train: 71 | debug: True 72 | max_iter: 200000 73 | is_train: True 74 | stat_freq: 50 75 | val_freq: 10000 76 | empty_cache_freq: 1 77 | train_phase: train 78 | overwrite_weights: False 79 | resume: True 80 | resume_optimizer: True 81 | eval_upsample: False 82 | lenient_weight_loading: False 83 | mix_precision: True 84 | use_gradient_clip: False 85 | gradient_clip_value: 1.0 86 | fine_tune_encoder: False 87 | 88 | # Test 89 | test: 90 | partial_shape: True 91 | test_cnt: 160 92 | clip_noise: False 93 | use_ddim: False 94 | ddim_eta: 1.0 95 | test_phase: test 96 | test_batch_size: 32 97 | 98 | # Misc 99 | misc: 100 | seed: 123 101 | 102 | exp: 103 | res: 32 104 | representation: tsdf 105 | 106 | batch_size: 32 107 | num_gpus: 1 108 | num_workers: 0 109 | 110 | skip_validate: True 111 | log_dir: exps/epn_${data.class_id}_${exp.representation}_${exp.res} 112 | 113 | ################################################################################ 114 | # slurm parameters 115 | ################################################################################ 116 | defaults: 117 | - hydra/launcher: submitit_slurm 118 | - hydra/hydra_logging: colorlog 119 | 120 | hydra: 121 | run: 122 | dir: ${exp.log_dir} 123 | sweep: 124 | dir: ${exp.log_dir} 125 | launcher: 126 | partition: dev 127 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 128 | name: ${hydra.job.name} 129 | timeout_min: 4320 130 | cpus_per_task: 24 131 | gpus_per_node: ${exp.num_gpus} 132 | tasks_per_node: 1 133 | mem_gb: 256 134 | nodes: 1 135 | constraint: 136 | exclude: 
seti 137 | max_num_timeout: 3 -------------------------------------------------------------------------------- /utils/visualize_voxel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | from torchvision.utils import save_image 5 | # import im2mesh.common as common 6 | 7 | 8 | def visualize_data(data, data_type, out_file): 9 | r''' Visualizes the data with regard to its type. 10 | Args: 11 | data (tensor): batch of data 12 | data_type (string): data type (img, voxels or pointcloud) 13 | out_file (string): output file 14 | ''' 15 | if data_type == 'img': 16 | if data.dim() == 3: 17 | data = data.unsqueeze(0) 18 | save_image(data, out_file, nrow=4) 19 | elif data_type == 'voxels': 20 | visualize_voxels(data, out_file=out_file) 21 | elif data_type == 'pointcloud': 22 | visualize_pointcloud(data, out_file=out_file) 23 | elif data_type is None or data_type == 'idx': 24 | pass 25 | else: 26 | raise ValueError('Invalid data_type "%s"' % data_type) 27 | 28 | 29 | def visualize_voxels(voxels, out_file=None, show=False): 30 | r''' Visualizes voxel data. 31 | Args: 32 | voxels (tensor): voxel data 33 | out_file (string): output file 34 | show (bool): whether the plot should be shown 35 | ''' 36 | # Use numpy 37 | voxels = np.asarray(voxels) 38 | # Create plot 39 | fig = plt.figure() 40 | ax = fig.add_subplot(projection=Axes3D.name) 41 | voxels = voxels.transpose(2, 0, 1) 42 | ax.voxels(voxels, edgecolor='k') 43 | ax.set_xlabel('Z') 44 | ax.set_ylabel('X') 45 | ax.set_zlabel('Y') 46 | ax.view_init(elev=30, azim=45) 47 | if out_file is not None: 48 | plt.savefig(out_file) 49 | if show: 50 | plt.show() 51 | plt.close(fig) 52 | 53 | 54 | def visualize_pointcloud(points, normals=None, 55 | out_file=None, show=False): 56 | r''' Visualizes point cloud data. 57 | Args: 58 | points (tensor): point data 59 | normals (tensor): normal data (if existing) 60 | out_file (string): output file 61 | show (bool): whether the plot should be shown 62 | ''' 63 | # Use numpy 64 | points = np.asarray(points) 65 | # Create plot 66 | fig = plt.figure() 67 | ax = fig.add_subplot(projection=Axes3D.name) 68 | ax.scatter(points[:, 2], points[:, 0], points[:, 1]) 69 | if normals is not None: 70 | ax.quiver( 71 | points[:, 2], points[:, 0], points[:, 1], 72 | normals[:, 2], normals[:, 0], normals[:, 1], 73 | length=0.1, color='k' 74 | ) 75 | ax.set_xlabel('Z') 76 | ax.set_ylabel('X') 77 | ax.set_zlabel('Y') 78 | ax.set_xlim(-0.5, 0.5) 79 | ax.set_ylim(-0.5, 0.5) 80 | ax.set_zlim(-0.5, 0.5) 81 | ax.view_init(elev=30, azim=45) 82 | if out_file is not None: 83 | plt.savefig(out_file) 84 | if show: 85 | plt.show() 86 | plt.close(fig) 87 | 88 | # def visualise_projection( 89 | # self, points, world_mat, camera_mat, img, output_file='out.png'): 90 | # r''' Visualizes the transformation and projection to image plane. 91 | # The first points of the batch are transformed and projected to the 92 | # respective image. After performing the relevant transformations, the 93 | # visualization is saved in the provided output_file path. 
94 | # Arguments: 95 | # points (tensor): batch of point cloud points 96 | # world_mat (tensor): batch of matrices to rotate pc to camera-based 97 | # coordinates 98 | # camera_mat (tensor): batch of camera matrices to project to 2D image 99 | # plane 100 | # img (tensor): tensor of batch GT image files 101 | # output_file (string): where the output should be saved 102 | # ''' 103 | # points_transformed = common.transform_points(points, world_mat) 104 | # points_img = common.project_to_camera(points_transformed, camera_mat) 105 | # pimg2 = points_img[0].detach().cpu().numpy() 106 | # image = img[0].cpu().numpy() 107 | # plt.imshow(image.transpose(1, 2, 0)) 108 | # plt.plot( 109 | # (pimg2[:, 0] + 1)*image.shape[1]/2, 110 | # (pimg2[:, 1] + 1) * image.shape[2]/2, 'x') 111 | # plt.savefig(output_file) -------------------------------------------------------------------------------- /utils/debugger.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import os 4 | import sys 5 | import numpy as np 6 | import logging 7 | import random 8 | import configs.config as config 9 | from shutil import copyfile 10 | import torch 11 | from tensorboardX import SummaryWriter 12 | 13 | def get_root_logger(log_file=None, log_level=logging.INFO): 14 | logger = logging.getLogger('wavelet') 15 | # if the logger has been initialized, just return it 16 | if logger.hasHandlers(): 17 | return logger 18 | 19 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=log_level) 20 | if log_file is not None: 21 | file_handler = logging.FileHandler(log_file, 'w') 22 | file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) 23 | file_handler.setLevel(log_level) 24 | logger.addHandler(file_handler) 25 | 26 | return logger 27 | 28 | class MyDebugger(): 29 | pre_fix = config.debug_base_folder 30 | 31 | def __init__(self, exp_name: str, fix_rand_seed=None, is_save_print_to_file=True, 32 | config_path=os.path.join('configs', 'config.py')): 33 | if fix_rand_seed is not None: 34 | np.random.seed(seed=fix_rand_seed) 35 | random.seed(fix_rand_seed) 36 | torch.manual_seed(fix_rand_seed) 37 | if isinstance(exp_name, str): 38 | self.exp_name = exp_name 39 | else: 40 | self.exp_name = '_'.join(exp_name) 41 | # self._debug_dir_name = os.path.join(os.path.dirname(__file__), MyDebugger.pre_fix, 42 | # datetime.datetime.fromtimestamp(time.time()).strftime( 43 | # f'%Y-%m-%d_%H-%M-%S_{self.model_name}')) 44 | self._debug_dir_name = os.path.join(os.getcwd(), MyDebugger.pre_fix, exp_name) 45 | self._write_dir_name = os.path.join(self._debug_dir_name, 'tensorboard') 46 | # self._debug_dir_name = os.path.join(os.path.dirname(__file__), self._debug_dir_name) 47 | print("=================== Program Start ====================") 48 | print(f"Output directory: {self._debug_dir_name}") 49 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 50 | self._init_debug_dir() 51 | log_file = os.path.join(self._debug_dir_name, f'{timestamp}.log') 52 | self.logger = get_root_logger(log_file=log_file) 53 | self.logger.info("Directory {} established".format(self._debug_dir_name)) 54 | self.writer = SummaryWriter(log_dir=self._write_dir_name) 55 | 56 | ######## redirect the standard output 57 | # if is_save_print_to_file: 58 | # sys.stdout = open(self.file_path("print.log"), 'w') 59 | # 60 | # ######## print the dir again on the log 61 | # print("=================== Program Start ====================") 62 | # print(f"Output 
directory: {self._debug_dir_name}") 63 | 64 | ######## copy the config file to the output directory 65 | config_file_save_path = self.file_path(os.path.basename(config_path)) 66 | assert os.path.exists(config_path) 67 | copyfile(config_path, config_file_save_path) 68 | self.logger.info(f"config file created at {config_file_save_path}") 69 | 70 | def file_path(self, file_name): 71 | return os.path.join(self._debug_dir_name, file_name) 72 | 73 | def set_directory_name(self, name): 74 | self._debug_dir_name = name 75 | 76 | def _init_debug_dir(self): 77 | # init root debug dir 78 | if not os.path.exists(MyDebugger.pre_fix): 79 | os.mkdir(MyDebugger.pre_fix) 80 | if not os.path.exists(self._debug_dir_name): 81 | os.mkdir(self._debug_dir_name) 82 | if not os.path.exists(self._write_dir_name): 83 | os.mkdir(self._write_dir_name) 84 | 85 | def save_text(self, idx, save_type, filepath): 86 | self.logger.info(f"Epoch {idx} {save_type} saved in {filepath}") 87 | 88 | def add_scalar(self, *scalar): 89 | self.writer.add_scalar(*scalar) 90 | 91 | @staticmethod 92 | def get_save_text(save_type): 93 | return f"{save_type} saved in " 94 | 95 | 96 | if __name__ == '__main__': 97 | debugger = MyDebugger('testing') 98 | # files can be saved under the debug directory 99 | file_path = debugger.file_path('file_to_be_save.txt') 100 | -------------------------------------------------------------------------------- /datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import torch 8 | import torch.distributed as dist 9 | from torch.utils.data.sampler import Sampler 10 | 11 | 12 | class InfSampler(Sampler): 13 | """Samples elements randomly, without replacement.
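When the permutation is exhausted it is rebuilt on the fly, so iteration never raises StopIteration; this makes the sampler suitable for iteration-based (rather than epoch-based) training loops.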
14 | 15 | Arguments: 16 | data_source (Dataset): dataset to sample from 17 | """ 18 | 19 | def __init__(self, data_source, shuffle=False): 20 | self.data_source = data_source 21 | self.shuffle = shuffle 22 | self.reset_permutation() 23 | 24 | def reset_permutation(self): 25 | perm = len(self.data_source) 26 | # shuffled or sequential index; the original int-only branch crashed on .tolist() when shuffle=False 27 | perm = torch.randperm(perm) if self.shuffle else torch.arange(perm) 28 | self._perm = perm.tolist() 29 | 30 | def __iter__(self): 31 | return self 32 | 33 | def __next__(self): 34 | if len(self._perm) == 0: 35 | self.reset_permutation() 36 | 37 | return self._perm.pop() 38 | 39 | def __len__(self): 40 | return len(self.data_source) 41 | 42 | next = __next__ # Python 2 compatibility 43 | 44 | 45 | class DistributedInfSampler(InfSampler): 46 | def __init__(self, data_source, num_replicas=None, rank=None, shuffle=True): 47 | if num_replicas is None: 48 | if not dist.is_available(): 49 | raise RuntimeError("Requires distributed package to be available") 50 | num_replicas = dist.get_world_size() 51 | if rank is None: 52 | if not dist.is_available(): 53 | raise RuntimeError("Requires distributed package to be available") 54 | rank = dist.get_rank() 55 | 56 | self.data_source = data_source 57 | self.num_replicas = num_replicas 58 | self.rank = rank 59 | self.epoch = 0 60 | self.it = 0 61 | self.num_samples = int(math.ceil(len(self.data_source) * 1.0 / self.num_replicas)) 62 | self.total_size = self.num_samples * self.num_replicas 63 | self.shuffle = shuffle 64 | self.reset_permutation() 65 | 66 | def __next__(self): 67 | it = self.it * self.num_replicas + self.rank 68 | value = self._perm[it % len(self._perm)] 69 | self.it = self.it + 1 70 | 71 | if (self.it * self.num_replicas) >= len(self._perm): 72 | self.reset_permutation() 73 | self.it = 0 74 | return value 75 | 76 | def __len__(self): 77 | return self.num_samples 78 | 79 | 80 | # https://github.com/huggingface/transformers/blob/447808c85f0e6d6b0aeeb07214942bf1e578f9d2/src/transformers/trainer_pt_utils.py 81 | class SequentialDistributedSampler(Sampler): 82 | """ 83 | Distributed Sampler that subsamples indices sequentially, 84 | making it easier to collate all results at the end. 85 | Even though we only use this sampler for eval and predict (no training), 86 | which means that the model params won't have to be synced (i.e. will not hang 87 | for synchronization even if varied number of forward passes), we still add extra 88 | samples to the sampler to make it evenly divisible (like in `DistributedSampler`) 89 | to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
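Example (an illustrative sketch; `dataset` and an initialized process group are assumed):

    sampler = SequentialDistributedSampler(dataset, batch_size=8)
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)
    # after the eval loop, per-rank outputs can be stitched together with
    # distributed_concat(...) from tools/test.py, which truncates the padding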
90 | """ 91 | 92 | def __init__(self, dataset, batch_size, rank=None, num_replicas=None): 93 | if num_replicas is None: 94 | if not torch.distributed.is_available(): 95 | raise RuntimeError("Requires distributed package to be available") 96 | num_replicas = torch.distributed.get_world_size() 97 | if rank is None: 98 | if not torch.distributed.is_available(): 99 | raise RuntimeError("Requires distributed package to be available") 100 | rank = torch.distributed.get_rank() 101 | self.dataset = dataset 102 | self.num_replicas = num_replicas 103 | self.rank = rank 104 | self.batch_size = batch_size 105 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size 106 | self.total_size = self.num_samples * self.num_replicas 107 | 108 | def __iter__(self): 109 | indices = list(range(len(self.dataset))) 110 | # add extra samples to make it evenly divisible 111 | indices += [indices[-1]] * (self.total_size - len(indices)) 112 | # subsample 113 | indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] 114 | return iter(indices) -------------------------------------------------------------------------------- /utils/generate_2d_mask.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageDraw 3 | import math 4 | import random 5 | import argparse 6 | import os 7 | 8 | def toU8(sample): 9 | if sample is None: 10 | return sample 11 | 12 | sample = np.clip(((sample + 1) * 127.5), 0, 255).astype(np.uint8) 13 | sample = np.transpose(sample, (1, 2, 0)) 14 | return sample 15 | 16 | def write_images(imgs, img_names, dir_path): 17 | os.makedirs(dir_path, exist_ok=True) 18 | 19 | for image_name, image in zip(img_names, imgs): 20 | out_path = os.path.join(dir_path, image_name) 21 | imwrite(img=image, path=out_path) 22 | 23 | def imwrite(path=None, img=None): 24 | Image.fromarray(img).save(path) 25 | 26 | def RandomBrush( 27 | max_tries, 28 | s, 29 | min_num_vertex = 4, 30 | max_num_vertex = 18, 31 | mean_angle = 2*math.pi / 5, 32 | angle_range = 2*math.pi / 15, 33 | min_width = 12, 34 | max_width = 48): 35 | H, W = s, s 36 | average_radius = math.sqrt(H*H+W*W) / 8 37 | mask = Image.new('L', (W, H), 0) 38 | for _ in range(np.random.randint(max_tries)): 39 | num_vertex = np.random.randint(min_num_vertex, max_num_vertex) 40 | angle_min = mean_angle - np.random.uniform(0, angle_range) 41 | angle_max = mean_angle + np.random.uniform(0, angle_range) 42 | angles = [] 43 | vertex = [] 44 | for i in range(num_vertex): 45 | if i % 2 == 0: 46 | angles.append(2*math.pi - np.random.uniform(angle_min, angle_max)) 47 | else: 48 | angles.append(np.random.uniform(angle_min, angle_max)) 49 | 50 | h, w = mask.size 51 | vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h)))) 52 | for i in range(num_vertex): 53 | r = np.clip( 54 | np.random.normal(loc=average_radius, scale=average_radius//2), 55 | 0, 2*average_radius) 56 | new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w) 57 | new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h) 58 | vertex.append((int(new_x), int(new_y))) 59 | 60 | draw = ImageDraw.Draw(mask) 61 | width = int(np.random.uniform(min_width, max_width)) 62 | draw.line(vertex, fill=1, width=width) 63 | for v in vertex: 64 | draw.ellipse((v[0] - width//2, 65 | v[1] - width//2, 66 | v[0] + width//2, 67 | v[1] + width//2), 68 | fill=1) 69 | if np.random.random() > 0.5: 70 | mask.transpose(Image.FLIP_LEFT_RIGHT) 71 | 
if np.random.random() > 0.5: 72 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM) # same here: reassign the returned image 73 | mask = np.asarray(mask, np.uint8) 74 | if np.random.random() > 0.5: 75 | mask = np.flip(mask, 0) 76 | if np.random.random() > 0.5: 77 | mask = np.flip(mask, 1) 78 | return mask 79 | 80 | def RandomMask(s, hole_range=[0,1]): 81 | coef = min(hole_range[0] + hole_range[1], 1.0) 82 | 83 | while True: 84 | mask = np.ones((s, s), np.uint8) 85 | def Fill(max_size): 86 | w, h = np.random.randint(max_size), np.random.randint(max_size) 87 | ww, hh = w // 2, h // 2 88 | x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh) 89 | mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0 90 | def MultiFill(max_tries, max_size): 91 | for _ in range(np.random.randint(max_tries)): 92 | Fill(max_size) 93 | MultiFill(int(5 * coef), s // 2) 94 | MultiFill(int(3 * coef), s) 95 | mask = np.logical_and(mask, 1 - RandomBrush(int(9 * coef), s)) # holes denoted as 0, kept regions as 1 96 | hole_ratio = 1 - np.mean(mask) 97 | if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]): 98 | continue 99 | return mask[np.newaxis, ...].astype(np.float32) 100 | 101 | def BatchRandomMask(batch_size, s, hole_range=[0, 1]): 102 | return np.stack([RandomMask(s, hole_range=hole_range) for _ in range(batch_size)], axis=0) 103 | 104 | 105 | if __name__ == '__main__': 106 | 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument('--res', type=int, required=False, default=256) 109 | args = parser.parse_args() 110 | 111 | cnt = 50 112 | tot = 0 113 | dir_path = "./output_2d_masks" 114 | 115 | masks = [] 116 | names = [] 117 | for i in range(cnt): 118 | mask = RandomMask(s=args.res) 119 | tot += mask.mean() 120 | mask_save = np.squeeze(toU8(mask * 2 - 1)) 121 | masks.append(mask_save) 122 | names.append(f"{i}.jpg") 123 | print(tot / cnt) # average kept-area ratio over the generated masks 124 | write_images(masks, names, dir_path) 125 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: diffcom 2 | channels: 3 | - open3d-admin 4 | - pytorch 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=4.5=1_gnu 10 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 11 | - argon2-cffi-bindings=21.2.0=py38h7f8727e_0 12 | - asttokens=2.0.5=pyhd3eb1b0_0 13 | - attrs=21.4.0=pyhd3eb1b0_0 14 | - backcall=0.2.0=pyhd3eb1b0_0 15 | - beautifulsoup4=4.11.1=py38h06a4308_0 16 | - blas=1.0=mkl 17 | - bleach=4.1.0=pyhd3eb1b0_0 18 | - brotlipy=0.7.0=py38h27cfd23_1003 19 | - bzip2=1.0.8=h7b6447c_0 20 | - ca-certificates=2022.07.19=h06a4308_0 21 | - certifi=2022.6.15=py38h06a4308_0 22 | - cffi=1.15.0=py38hd667e15_1 23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 24 | - cryptography=36.0.0=py38h9ce1e76_0 25 | - cudatoolkit=11.3.1=h2bc3f7f_2 26 | - debugpy=1.5.1=py38h295c915_0 27 | - decorator=5.1.1=pyhd3eb1b0_0 28 | - defusedxml=0.7.1=pyhd3eb1b0_0 29 | - entrypoints=0.4=py38h06a4308_0 30 | - executing=0.8.3=pyhd3eb1b0_0 31 | - ffmpeg=4.3=hf484d3e_0 32 | - freetype=2.11.0=h70c0345_0 33 | - giflib=5.2.1=h7b6447c_0 34 | - gmp=6.2.1=h2531618_2 35 | - gnutls=3.6.15=he1e5248_0 36 | - idna=3.3=pyhd3eb1b0_0 37 | - importlib_resources=5.2.0=pyhd3eb1b0_1 38 | - intel-openmp=2021.4.0=h06a4308_3561 39 | - ipykernel=6.9.1=py38h06a4308_0 40 | - ipython=8.3.0=py38h06a4308_0 41 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 42 | - ipywidgets=7.6.5=pyhd3eb1b0_1 43 | - jedi=0.18.1=py38h06a4308_1 44 | - jinja2=3.0.3=pyhd3eb1b0_0 45 | - 
jpeg=9d=h7f8727e_0 46 | - jsonschema=4.4.0=py38h06a4308_0 47 | - jupyter_client=7.2.2=py38h06a4308_0 48 | - jupyter_core=4.10.0=py38h06a4308_0 49 | - jupyterlab_pygments=0.1.2=py_0 50 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 51 | - lame=3.100=h7b6447c_0 52 | - lcms2=2.12=h3be6417_0 53 | - ld_impl_linux-64=2.35.1=h7274673_9 54 | - libffi=3.3=he6710b0_2 55 | - libgcc-ng=9.3.0=h5101ec6_17 56 | - libgomp=9.3.0=h5101ec6_17 57 | - libiconv=1.15=h63c8f33_5 58 | - libidn2=2.3.2=h7f8727e_0 59 | - libpng=1.6.37=hbc83047_0 60 | - libsodium=1.0.18=h7b6447c_0 61 | - libstdcxx-ng=9.3.0=hd4cf53a_17 62 | - libtasn1=4.16.0=h27cfd23_0 63 | - libtiff=4.2.0=h85742a9_0 64 | - libunistring=0.9.10=h27cfd23_0 65 | - libuv=1.40.0=h7b6447c_0 66 | - libwebp=1.2.2=h55f646e_0 67 | - libwebp-base=1.2.2=h7f8727e_0 68 | - lz4-c=1.9.3=h295c915_1 69 | - markupsafe=2.0.1=py38h27cfd23_0 70 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 71 | - mistune=0.8.4=py38h7b6447c_1000 72 | - mkl=2021.4.0=h06a4308_640 73 | - mkl-service=2.4.0=py38h7f8727e_0 74 | - mkl_fft=1.3.1=py38hd3c417c_0 75 | - mkl_random=1.2.2=py38h51133e4_0 76 | - nbclient=0.5.13=py38h06a4308_0 77 | - nbconvert=6.4.4=py38h06a4308_0 78 | - nbformat=5.3.0=py38h06a4308_0 79 | - ncurses=6.3=h7f8727e_2 80 | - nest-asyncio=1.5.5=py38h06a4308_0 81 | - nettle=3.7.3=hbbd107a_1 82 | - notebook=6.4.11=py38h06a4308_0 83 | - numpy=1.21.2=py38h20f2e39_0 84 | - numpy-base=1.21.2=py38h79a1101_0 85 | - open3d=0.11.2=py38_0 86 | - openh264=2.1.1=h4ff587b_0 87 | - openssl=1.1.1q=h7f8727e_0 88 | - packaging=21.3=pyhd3eb1b0_0 89 | - pandocfilters=1.5.0=pyhd3eb1b0_0 90 | - parso=0.8.3=pyhd3eb1b0_0 91 | - pexpect=4.8.0=pyhd3eb1b0_3 92 | - pickleshare=0.7.5=pyhd3eb1b0_1003 93 | - pillow=9.0.1=py38h22f2fdc_0 94 | - pip=21.2.4=py38h06a4308_0 95 | - prometheus_client=0.13.1=pyhd3eb1b0_0 96 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 97 | - ptyprocess=0.7.0=pyhd3eb1b0_2 98 | - pure_eval=0.2.2=pyhd3eb1b0_0 99 | - pycparser=2.21=pyhd3eb1b0_0 100 | - pygments=2.11.2=pyhd3eb1b0_0 101 | - pyopenssl=22.0.0=pyhd3eb1b0_0 102 | - pyparsing=3.0.4=pyhd3eb1b0_0 103 | - pyrsistent=0.18.0=py38heee7806_0 104 | - pysocks=1.7.1=py38h06a4308_0 105 | - python=3.8.12=h12debd9_0 106 | - python-dateutil=2.8.2=pyhd3eb1b0_0 107 | - python-fastjsonschema=2.15.1=pyhd3eb1b0_0 108 | - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0 109 | - pytorch-mutex=1.0=cuda 110 | - pyzmq=22.3.0=py38h295c915_2 111 | - readline=8.1.2=h7f8727e_1 112 | - requests=2.27.1=pyhd3eb1b0_0 113 | - send2trash=1.8.0=pyhd3eb1b0_1 114 | - setuptools=58.0.4=py38h06a4308_0 115 | - six=1.16.0=pyhd3eb1b0_1 116 | - soupsieve=2.3.1=pyhd3eb1b0_0 117 | - sqlite=3.38.0=hc218d9a_0 118 | - stack_data=0.2.0=pyhd3eb1b0_0 119 | - terminado=0.13.1=py38h06a4308_0 120 | - testpath=0.5.0=pyhd3eb1b0_0 121 | - tk=8.6.11=h1ccaba5_0 122 | - torchaudio=0.11.0=py38_cu113 123 | - torchvision=0.12.0=py38_cu113 124 | - tornado=6.1=py38h27cfd23_0 125 | - traitlets=5.1.1=pyhd3eb1b0_0 126 | - typing-extensions=4.1.1=hd3eb1b0_0 127 | - typing_extensions=4.1.1=pyh06a4308_0 128 | - urllib3=1.26.8=pyhd3eb1b0_0 129 | - wcwidth=0.2.5=pyhd3eb1b0_0 130 | - webencodings=0.5.1=py38_1 131 | - wheel=0.37.1=pyhd3eb1b0_0 132 | - widgetsnbextension=3.5.2=py38h06a4308_0 133 | - xz=5.2.5=h7b6447c_0 134 | - zeromq=4.3.4=h2531618_0 135 | - zipp=3.8.0=py38h06a4308_0 136 | - zlib=1.2.11=h7f8727e_4 137 | - zstd=1.4.9=haebb681_0 138 | - pip: 139 | - ccimport==0.4.2 140 | - cumm-cu113==0.3.7 141 | - fire==0.4.0 142 | - lark==1.1.4 143 | - ninja==1.11.1 144 | - pccm==0.4.4 145 | - portalocker==2.6.0 146 | - 
pybind11==2.10.1 147 | - pywavelets==1.4.1 148 | - spconv-cu113==2.2.6 149 | - termcolor==2.1.1 150 | - tqdm==4.64.1 151 | - pymcubes==0.1.2 152 | - scipy==1.9.3 153 | - hydra-core==1.0.0 154 | - hydra-colorlog==1.0.0 155 | - hydra-submitit-launcher==1.1.0 156 | - tensorboardX==2.6 157 | - scikit-image==0.19.3 158 | - scikit-learn==1.2.1 159 | - matplotlib==3.7.0 160 | - addict==2.4.0 161 | - pandas==1.5.3 162 | - plyfile==0.7.4 163 | - sklearn==0.0.post1 164 | - k3d==2.15.2 165 | - trimesh==3.20.0 166 | -------------------------------------------------------------------------------- /models/modules/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | :param target_params: the target parameter sequence. 60 | :param source_params: the source parameter sequence. 61 | :param rate: the EMA rate (closer to 1 means slower). 62 | """ 63 | for targ, src in zip(target_params, source_params): 64 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def scale_module(module, scale): 77 | """ 78 | Scale the parameters of a module and return it. 79 | """ 80 | for p in module.parameters(): 81 | p.detach().mul_(scale) 82 | return module 83 | 84 | 85 | def mean_flat(tensor): 86 | """ 87 | Take the mean over all non-batch dimensions. 88 | """ 89 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 90 | 91 | 92 | ### REMARK: Change to 4 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | :param channels: number of input channels. 97 | :return: an nn.Module for normalization. 98 | """ 99 | return GroupNorm32(32, channels) 100 | 101 | 102 | def timestep_embedding(timesteps, dim, max_period=10000): 103 | """ 104 | Create sinusoidal timestep embeddings. 105 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 106 | These may be fractional. 
107 | :param dim: the dimension of the output. 108 | :param max_period: controls the minimum frequency of the embeddings. 109 | :return: an [N x dim] Tensor of positional embeddings. 110 | """ 111 | half = dim // 2 112 | freqs = th.exp( 113 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 114 | ).to(device=timesteps.device) 115 | args = timesteps[:, None].float() * freqs[None] 116 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 117 | if dim % 2: 118 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 119 | return embedding 120 | 121 | 122 | def checkpoint(func, inputs, params, flag): 123 | """ 124 | Evaluate a function without caching intermediate activations, allowing for 125 | reduced memory at the expense of extra compute in the backward pass. 126 | :param func: the function to evaluate. 127 | :param inputs: the argument sequence to pass to `func`. 128 | :param params: a sequence of parameters `func` depends on but does not 129 | explicitly take as arguments. 130 | :param flag: if False, disable gradient checkpointing. 131 | """ 132 | if flag: 133 | args = tuple(inputs) + tuple(params) 134 | return CheckpointFunction.apply(func, len(inputs), *args) 135 | else: 136 | return func(*inputs) 137 | 138 | 139 | class CheckpointFunction(th.autograd.Function): 140 | @staticmethod 141 | def forward(ctx, run_function, length, *args): 142 | ctx.run_function = run_function 143 | ctx.input_tensors = list(args[:length]) 144 | ctx.input_params = list(args[length:]) 145 | with th.no_grad(): 146 | output_tensors = ctx.run_function(*ctx.input_tensors) 147 | return output_tensors 148 | 149 | @staticmethod 150 | def backward(ctx, *output_grads): 151 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 152 | with th.enable_grad(): 153 | # Fixes a bug where the first op in run_function modifies the 154 | # Tensor storage in place, which is not allowed for detach()'d 155 | # Tensors. 156 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 157 | output_tensors = ctx.run_function(*shallow_copies) 158 | input_grads = th.autograd.grad( 159 | output_tensors, 160 | ctx.input_tensors + ctx.input_params, 161 | output_grads, 162 | allow_unused=True, 163 | ) 164 | del ctx.input_tensors 165 | del ctx.input_params 166 | del output_tensors 167 | return (None, None) + input_grads -------------------------------------------------------------------------------- /models/modules/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | :param name: the name of the sampler. 12 | :param diffusion: the diffusion object to sample for. 13 | """ 14 | if name == "uniform": 15 | return UniformSampler(diffusion) 16 | elif name == "loss-second-moment": 17 | return LossSecondMomentResampler(diffusion) 18 | else: 19 | raise NotImplementedError(f"unknown schedule sampler: {name}") 20 | 21 | 22 | class ScheduleSampler(ABC): 23 | """ 24 | A distribution over timesteps in the diffusion process, intended to reduce 25 | variance of the objective. 26 | By default, samplers perform unbiased importance sampling, in which the 27 | objective's mean is unchanged. 
28 | However, subclasses may override sample() to change how the resampled 29 | terms are reweighted, allowing for actual changes in the objective. 30 | """ 31 | 32 | @abstractmethod 33 | def weights(self): 34 | """ 35 | Get a numpy array of weights, one per diffusion step. 36 | The weights needn't be normalized, but must be positive. 37 | """ 38 | 39 | def sample(self, batch_size, device): 40 | """ 41 | Importance-sample timesteps for a batch. 42 | :param batch_size: the number of timesteps. 43 | :param device: the torch device to save to. 44 | :return: a tuple (timesteps, weights): 45 | - timesteps: a tensor of timestep indices. 46 | - weights: a tensor of weights to scale the resulting losses. 47 | """ 48 | w = self.weights() 49 | p = w / np.sum(w) 50 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 51 | indices = th.from_numpy(indices_np).long().to(device) 52 | weights_np = 1 / (len(p) * p[indices_np]) 53 | weights = th.from_numpy(weights_np).float().to(device) 54 | return indices, weights 55 | 56 | 57 | class UniformSampler(ScheduleSampler): 58 | def __init__(self, diffusion): 59 | self.diffusion = diffusion 60 | self._weights = np.ones([diffusion.num_timesteps]) 61 | 62 | def weights(self): 63 | return self._weights 64 | 65 | 66 | class LossAwareSampler(ScheduleSampler): 67 | def update_with_local_losses(self, local_ts, local_losses): 68 | """ 69 | Update the reweighting using losses from a model. 70 | Call this method from each rank with a batch of timesteps and the 71 | corresponding losses for each of those timesteps. 72 | This method will perform synchronization to make sure all of the ranks 73 | maintain the exact same reweighting. 74 | :param local_ts: an integer Tensor of timesteps. 75 | :param local_losses: a 1D Tensor of losses. 76 | """ 77 | batch_sizes = [ 78 | th.tensor([0], dtype=th.int32, device=local_ts.device) 79 | for _ in range(dist.get_world_size()) 80 | ] 81 | dist.all_gather( 82 | batch_sizes, 83 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 84 | ) 85 | 86 | # Pad all_gather batches to be the maximum batch size. 87 | batch_sizes = [x.item() for x in batch_sizes] 88 | max_bs = max(batch_sizes) 89 | 90 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 91 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 92 | dist.all_gather(timestep_batches, local_ts) 93 | dist.all_gather(loss_batches, local_losses) 94 | timesteps = [ 95 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 96 | ] 97 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 98 | self.update_with_all_losses(timesteps, losses) 99 | 100 | @abstractmethod 101 | def update_with_all_losses(self, ts, losses): 102 | """ 103 | Update the reweighting using losses from a model. 104 | Sub-classes should override this method to update the reweighting 105 | using losses from the model. 106 | This method directly updates the reweighting without synchronizing 107 | between workers. It is called by update_with_local_losses from all 108 | ranks with identical arguments. Thus, it should have deterministic 109 | behavior to maintain state across workers. 110 | :param ts: a list of int timesteps. 111 | :param losses: a list of float losses, one per timestep. 
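For example (illustrative), `update_with_all_losses([3, 7], [0.12, 0.80])` records a loss of 0.12 for timestep 3 and 0.80 for timestep 7; LossSecondMomentResampler below is the concrete implementation, keeping a fixed-length loss history per timestep.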
112 | """ 113 | 114 | 115 | class LossSecondMomentResampler(LossAwareSampler): 116 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 117 | self.diffusion = diffusion 118 | self.history_per_term = history_per_term 119 | self.uniform_prob = uniform_prob 120 | self._loss_history = np.zeros( 121 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 122 | ) 123 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 124 | 125 | def weights(self): 126 | if not self._warmed_up(): 127 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 128 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 129 | weights /= np.sum(weights) 130 | weights *= 1 - self.uniform_prob 131 | weights += self.uniform_prob / len(weights) 132 | return weights 133 | 134 | def update_with_all_losses(self, ts, losses): 135 | for t, loss in zip(ts, losses): 136 | if self._loss_counts[t] == self.history_per_term: 137 | # Shift out the oldest loss term. 138 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 139 | self._loss_history[t, -1] = loss 140 | else: 141 | self._loss_history[t, self._loss_counts[t]] = loss 142 | self._loss_counts[t] += 1 143 | 144 | def _warmed_up(self): 145 | return (self._loss_counts == self.history_per_term).all() -------------------------------------------------------------------------------- /models/modules/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license 16 | 17 | def get_schedule(t_T, t_0, n_sample, n_steplength, debug=0): 18 | if n_steplength > 1: 19 | if not n_sample > 1: 20 | raise RuntimeError('n_steplength has no effect if n_sample=1') 21 | 22 | t = t_T 23 | times = [t] 24 | while t >= 0: 25 | t = t - 1 26 | times.append(t) 27 | n_steplength_cur = min(n_steplength, t_T - t) 28 | 29 | for _ in range(n_sample - 1): 30 | 31 | for _ in range(n_steplength_cur): 32 | t = t + 1 33 | times.append(t) 34 | for _ in range(n_steplength_cur): 35 | t = t - 1 36 | times.append(t) 37 | 38 | _check_times(times, t_0, t_T) 39 | 40 | if debug == 2: 41 | for x in [list(range(0, 50)), list(range(-1, -50, -1))]: 42 | _plot_times(x=x, times=[times[i] for i in x]) 43 | 44 | return times 45 | 46 | 47 | def _check_times(times, t_0, t_T): 48 | # Check end 49 | assert times[0] > times[1], (times[0], times[1]) 50 | 51 | # Check beginning 52 | assert times[-1] == -1, times[-1] 53 | 54 | # Steplength = 1 55 | for t_last, t_cur in zip(times[:-1], times[1:]): 56 | assert abs(t_last - t_cur) == 1, (t_last, t_cur) 57 | 58 | # Value range 59 | for t in times: 60 | assert t >= t_0, (t, t_0) 61 | assert t <= t_T, (t, t_T) 62 | 63 | 64 | def _plot_times(x, times): 65 | import matplotlib.pyplot as plt 66 | plt.plot(x, times) 67 | plt.show() 68 | 69 | 70 | def get_schedule_jump(t_T, n_sample, jump_length, jump_n_sample, 71 | jump2_length=1, jump2_n_sample=1, 72 | jump3_length=1, jump3_n_sample=1, 73 | start_resampling=100000000): 74 | 75 | jumps = {} 76 | for j in range(0, t_T - jump_length, jump_length): 77 | jumps[j] = jump_n_sample - 1 78 | 79 | jumps2 = {} 80 | for j in range(0, t_T - jump2_length, jump2_length): 81 | jumps2[j] = jump2_n_sample - 1 82 | 83 | jumps3 = {} 84 | for j in range(0, t_T - jump3_length, jump3_length): 85 | jumps3[j] = jump3_n_sample - 1 86 | 87 | t = t_T 88 | ts = [] 89 | 90 | while t >= 1: 91 | t = t-1 92 | ts.append(t) 93 | 94 | if ( 95 | t + 1 < t_T - 1 and 96 | t <= start_resampling 97 | ): 98 | for _ in range(n_sample - 1): 99 | t = t + 1 100 | ts.append(t) 101 | 102 | if t >= 0: 103 | t = t - 1 104 | ts.append(t) 105 | 106 | if ( 107 | jumps3.get(t, 0) > 0 and 108 | t <= start_resampling - jump3_length 109 | ): 110 | jumps3[t] = jumps3[t] - 1 111 | for _ in range(jump3_length): 112 | t = t + 1 113 | ts.append(t) 114 | 115 | if ( 116 | jumps2.get(t, 0) > 0 and 117 | t <= start_resampling - jump2_length 118 | ): 119 | jumps2[t] = jumps2[t] - 1 120 | for _ in range(jump2_length): 121 | t = t + 1 122 | ts.append(t) 123 | jumps3 = {} 124 | for j in range(0, t_T - jump3_length, jump3_length): 125 | jumps3[j] = jump3_n_sample - 1 126 | 127 | if ( 128 | jumps.get(t, 0) > 0 and 129 | t <= start_resampling - jump_length 130 | ): 131 | jumps[t] = jumps[t] - 1 132 | for _ in range(jump_length): 133 | t = t + 1 134 | ts.append(t) 135 | jumps2 = {} 136 | for j in range(0, t_T - jump2_length, jump2_length): 137 | jumps2[j] = jump2_n_sample - 1 138 | 139 | jumps3 = {} 140 | for j in range(0, t_T - jump3_length, jump3_length): 141 | jumps3[j] = jump3_n_sample - 1 142 | 143 | ts.append(-1) 144 | 145 | _check_times(ts, -1, t_T) 146 | 147 | return ts 148 | 149 | 150 | def get_schedule_jump_paper(): 151 | t_T = 250 152 | jump_length = 10 153 | jump_n_sample = 10 154 | 155 | jumps = {} 156 | for j in range(0, t_T - jump_length, jump_length): 157 | jumps[j] = jump_n_sample - 1 158 | 159 | t = t_T 160 | ts = [] 161 | 162 | while t >= 1: 163 | t 
= t-1 164 | ts.append(t) 165 | 166 | if jumps.get(t, 0) > 0: 167 | jumps[t] = jumps[t] - 1 168 | for _ in range(jump_length): 169 | t = t + 1 170 | ts.append(t) 171 | 172 | ts.append(-1) 173 | 174 | _check_times(ts, -1, t_T) 175 | 176 | return ts 177 | 178 | 179 | def get_schedule_jump_test(to_supplement=False): 180 | ts = get_schedule_jump(t_T=250, n_sample=1, 181 | jump_length=10, jump_n_sample=10, 182 | jump2_length=1, jump2_n_sample=1, 183 | jump3_length=1, jump3_n_sample=1, 184 | start_resampling=250) 185 | 186 | import matplotlib.pyplot as plt 187 | SMALL_SIZE = 8*3 188 | MEDIUM_SIZE = 10*3 189 | BIGGER_SIZE = 12*3 190 | 191 | plt.rc('font', size=SMALL_SIZE) # controls default text sizes 192 | plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title 193 | plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels 194 | plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels 195 | plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels 196 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 197 | plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title 198 | 199 | plt.plot(ts) 200 | 201 | fig = plt.gcf() 202 | fig.set_size_inches(20, 10) 203 | 204 | ax = plt.gca() 205 | ax.set_xlabel('Number of Transitions') 206 | ax.set_ylabel('Diffusion time $t$') 207 | 208 | fig.tight_layout() 209 | 210 | if to_supplement: 211 | out_path = "./jump_sched.pdf" 212 | plt.savefig(out_path) 213 | 214 | out_path = "./schedule.png" 215 | plt.savefig(out_path) 216 | print(out_path) 217 | 218 | 219 | def main(): 220 | get_schedule_jump_test() 221 | 222 | 223 | if __name__ == "__main__": 224 | main() 225 | -------------------------------------------------------------------------------- /lib/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | from torchvision.utils import save_image 5 | # import im2mesh.common as common 6 | import numpy as np 7 | import k3d 8 | from matplotlib import cm, colors 9 | import trimesh 10 | from pathlib import Path 11 | 12 | 13 | def visualize_data(data, data_type, out_file): 14 | r''' Visualizes the data with regard to its type. 15 | Args: 16 | data (tensor): batch of data 17 | data_type (string): data type (img, voxels or pointcloud) 18 | out_file (string): output file 19 | ''' 20 | if data_type == 'img': 21 | if data.dim() == 3: 22 | data = data.unsqueeze(0) 23 | save_image(data, out_file, nrow=4) 24 | elif data_type == 'voxels': 25 | visualize_voxels(data, out_file=out_file) 26 | elif data_type == 'pointcloud': 27 | visualize_pointcloud(data, out_file=out_file) 28 | elif data_type is None or data_type == 'idx': 29 | pass 30 | else: 31 | raise ValueError('Invalid data_type "%s"' % data_type) 32 | 33 | 34 | def visualize_voxels(voxels, out_file=None, show=False): 35 | r''' Visualizes voxel data. 
36 | Args: 37 | voxels (tensor): voxel data 38 | out_file (string): output file 39 | show (bool): whether the plot should be shown 40 | ''' 41 | # Use numpy 42 | voxels = np.asarray(voxels) 43 | # Create plot 44 | fig = plt.figure() 45 | ax = fig.add_subplot(projection=Axes3D.name) 46 | voxels = voxels.transpose(2, 0, 1) 47 | ax.voxels(voxels, edgecolor='k') 48 | ax.set_xlabel('Z') 49 | ax.set_ylabel('X') 50 | ax.set_zlabel('Y') 51 | ax.view_init(elev=30, azim=45) 52 | if out_file is not None: 53 | plt.savefig(out_file) 54 | if show: 55 | plt.show() 56 | plt.close(fig) 57 | 58 | 59 | def visualize_pointcloud(points, normals=None, 60 | out_file=None, show=False): 61 | r''' Visualizes point cloud data. 62 | Args: 63 | points (tensor): point data 64 | normals (tensor): normal data (if existing) 65 | out_file (string): output file 66 | show (bool): whether the plot should be shown 67 | ''' 68 | # Use numpy 69 | points = np.asarray(points) 70 | # Create plot 71 | fig = plt.figure() 72 | ax = fig.add_subplot(projection=Axes3D.name) 73 | ax.scatter(points[:, 2], points[:, 0], points[:, 1]) 74 | if normals is not None: 75 | ax.quiver( 76 | points[:, 2], points[:, 0], points[:, 1], 77 | normals[:, 2], normals[:, 0], normals[:, 1], 78 | length=0.1, color='k' 79 | ) 80 | ax.set_xlabel('Z') 81 | ax.set_ylabel('X') 82 | ax.set_zlabel('Y') 83 | ax.set_xlim(-0.5, 0.5) 84 | ax.set_ylim(-0.5, 0.5) 85 | ax.set_zlim(-0.5, 0.5) 86 | ax.view_init(elev=30, azim=45) 87 | if out_file is not None: 88 | plt.savefig(out_file) 89 | if show: 90 | plt.show() 91 | plt.close(fig) 92 | 93 | def visualise_projection( 94 | points, world_mat, camera_mat, img, output_file='out.png'): 95 | r''' Visualizes the transformation and projection to image plane. 96 | The first points of the batch are transformed and projected to the 97 | respective image. After performing the relevant transformations, the 98 | visualization is saved in the provided output_file path. 99 | Arguments: 100 | points (tensor): batch of point cloud points 101 | world_mat (tensor): batch of matrices to rotate pc to camera-based 102 | coordinates 103 | camera_mat (tensor): batch of camera matrices to project to 2D image 104 | plane 105 | img (tensor): tensor of batch GT image files 106 | output_file (string): where the output should be saved 107 | ''' 108 | import im2mesh.common as common # requires the external im2mesh package (the top-level import is commented out) 109 | points_img = common.project_to_camera(common.transform_points(points, world_mat), camera_mat) 110 | pimg2 = points_img[0].detach().cpu().numpy() 111 | image = img[0].cpu().numpy() 112 | plt.imshow(image.transpose(1, 2, 0)) 113 | plt.plot( 114 | (pimg2[:, 0] + 1)*image.shape[1]/2, 115 | (pimg2[:, 1] + 1) * image.shape[2]/2, 'x') 116 | plt.savefig(output_file) 117 | 118 | def visualize_mesh(vertices, faces, file_name, flip_axes=False): 119 | vertices = np.array(vertices) 120 | plot = k3d.plot(name='mesh', grid_visible=False, grid=(-0.55, -0.55, -0.55, 0.55, 0.55, 0.55)) 121 | if flip_axes: 122 | rot_matrix = np.array([ 123 | [-1.0000000, 0.0000000, 0.0000000], 124 | [0.0000000, 0.0000000, 1.0000000], 125 | [0.0000000, 1.0000000, 0.0000000] 126 | ]) 127 | vertices = vertices @ rot_matrix 128 | plt_mesh = k3d.mesh(vertices.astype(np.float32), faces.astype(np.uint32), color=0xd0d0d0) 129 | plot += plt_mesh 130 | plt_mesh.shader = '3d' 131 | plot.display() 132 | 133 | 134 | def visualize_meshes(meshes, flip_axes=False): 135 | assert len(meshes) == 3 136 | plot = k3d.plot(name='meshes', grid_visible=False, grid=(-0.55, -0.55, -0.55, 0.55, 0.55, 0.55)) 137 | for mesh_idx, mesh in enumerate(meshes): 138 | vertices, faces = mesh[:2] 139 | if flip_axes: 140 | vertices[:, 2] = vertices[:, 2] * -1 141 | vertices[:, [0, 1, 2]] = vertices[:, [0, 2, 1]] 142 | vertices += [[-32, -32, 0], [0, -32, 0], [32, -32, 0]][mesh_idx] 143 | plt_mesh = k3d.mesh(vertices.astype(np.float32), faces.astype(np.uint32), color=0xd0d0d0) 144 | plot += plt_mesh 145 | plt_mesh.shader = '3d' 146 | plot.display() 147 | 148 | 149 | def visualize_sdf(sdf: np.ndarray, filename: Path) -> None: 150 | assert sdf.shape[0] == sdf.shape[1] == sdf.shape[2], "SDF grid has to be of cubic shape" 151 | print(f"Creating SDF visualization for {sdf.shape[0]}^3 grid ...") 152 | 153 | voxels = np.stack(np.meshgrid(range(sdf.shape[0]), range(sdf.shape[1]), range(sdf.shape[2]))).reshape(3, -1).T 154 | 155 | sdf[sdf < 0] /= np.abs(sdf[sdf < 0]).max() 156 | sdf[sdf > 0] /= sdf[sdf > 0].max() 157 | sdf /= 2.
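# sdf now lies in [-0.5, 0.5]: negative and positive values were normalized separately by their largest magnitudes, then halved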
158 | 159 | corners = np.array([ 160 | [-.25, -.25, -.25], 161 | [.25, -.25, -.25], 162 | [-.25, .25, -.25], 163 | [.25, .25, -.25], 164 | [-.25, -.25, .25], 165 | [.25, -.25, .25], 166 | [-.25, .25, .25], 167 | [.25, .25, .25] 168 | ])[np.newaxis, :].repeat(voxels.shape[0], axis=0).reshape(-1, 3) 169 | 170 | scale_factors = sdf[tuple(voxels.T)].repeat(8, axis=0) 171 | cube_vertex_colors = cm.get_cmap('seismic')(colors.Normalize(vmin=-1, vmax=1)(scale_factors))[:, :3] 172 | scale_factors[scale_factors < 0] *= .25 173 | cube_vertices = voxels.repeat(8, axis=0) + corners * scale_factors[:, np.newaxis] 174 | 175 | faces = np.array([ 176 | [1, 0, 2], [2, 3, 1], [5, 1, 3], [3, 7, 5], [4, 5, 7], [7, 6, 4], 177 | [0, 4, 6], [6, 2, 0], [3, 2, 6], [6, 7, 3], [5, 4, 0], [0, 1, 5] 178 | ])[np.newaxis, :].repeat(voxels.shape[0], axis=0).reshape(-1, 3) 179 | cube_faces = faces + (np.arange(0, voxels.shape[0]) * 8)[np.newaxis, :].repeat(12, axis=0).T.flatten()[:, np.newaxis] 180 | 181 | mesh = trimesh.Trimesh(vertices=cube_vertices, faces=cube_faces, vertex_colors=cube_vertex_colors, process=False) 182 | mesh.export(str(filename)) 183 | print(f"Exported to {filename}") -------------------------------------------------------------------------------- /models/networks/controlnet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from models.modules.fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from models.modules.nn import ( 12 | SiLU, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | checkpoint, 20 | ) 21 | from models.networks.resunet3d import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock 22 | 23 | class ControlNet(nn.Module): 24 | def __init__( 25 | self, 26 | in_channels, 27 | model_channels, 28 | hint_channels, 29 | out_channels, 30 | num_res_blocks, 31 | attention_resolutions, 32 | dropout=0, 33 | channel_mult=(1, 2, 4, 8), 34 | conv_resample=True, 35 | dims=2, 36 | num_classes=None, 37 | use_checkpoint=False, 38 | num_heads=1, 39 | num_heads_upsample=-1, 40 | use_scale_shift_norm=False, 41 | activation = None 42 | ): 43 | super().__init__() 44 | 45 | self.activation = activation if activation is not None else SiLU() 46 | if num_heads_upsample == -1: 47 | num_heads_upsample = num_heads 48 | 49 | # self.voxel_size = voxel_size 50 | self.in_channels = in_channels 51 | self.model_channels = model_channels 52 | self.out_channels = out_channels 53 | self.num_res_blocks = num_res_blocks 54 | self.attention_resolutions = attention_resolutions 55 | self.dropout = dropout 56 | self.channel_mult = channel_mult 57 | self.conv_resample = conv_resample 58 | self.num_classes = num_classes 59 | self.use_checkpoint = use_checkpoint 60 | self.num_heads = num_heads 61 | self.num_heads_upsample = num_heads_upsample 62 | self.dims = dims 63 | 64 | time_embed_dim = model_channels * 4 65 | self.time_embed = nn.Sequential( 66 | linear(model_channels, time_embed_dim), 67 | self.activation, 68 | linear(time_embed_dim, time_embed_dim), 69 | ) 70 | 71 | if self.num_classes is not None: 72 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 73 | 74 | # self.input_hint_block = ModuleList([self.make_zero_conv(hint_channels)]) 75 | 76 | self.input_hint_block = TimestepEmbedSequential( 77 | conv_nd(dims, 
hint_channels, 16, 3, padding=1), 78 | nn.SiLU(), 79 | zero_module(conv_nd(dims, 16, in_channels, 3, padding=1)) 80 | ) 81 | self.input_blocks = nn.ModuleList( 82 | [ 83 | TimestepEmbedSequential( 84 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 85 | ) 86 | ] 87 | ) 88 | 89 | self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) 90 | 91 | input_block_chans = [model_channels] 92 | ch = model_channels 93 | ds = 1 94 | for level, mult in enumerate(channel_mult): 95 | for _ in range(num_res_blocks): 96 | layers = [ 97 | ResBlock( 98 | ch, 99 | time_embed_dim, 100 | dropout, 101 | out_channels=mult * model_channels, 102 | dims=dims, 103 | use_checkpoint=use_checkpoint, 104 | use_scale_shift_norm=use_scale_shift_norm, 105 | activation = self.activation 106 | ) 107 | ] 108 | ch = mult * model_channels 109 | if ds in attention_resolutions: 110 | layers.append( 111 | AttentionBlock( 112 | ch, use_checkpoint=use_checkpoint, num_heads=num_heads 113 | ) 114 | ) 115 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 116 | self.zero_convs.append(self.make_zero_conv(ch)) 117 | input_block_chans.append(ch) 118 | if level != len(channel_mult) - 1: 119 | self.input_blocks.append( 120 | TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)) 121 | ) 122 | input_block_chans.append(ch) 123 | self.zero_convs.append(self.make_zero_conv(ch)) 124 | ds *= 2 125 | 126 | self.middle_block = TimestepEmbedSequential( 127 | ResBlock( 128 | ch, 129 | time_embed_dim, 130 | dropout, 131 | dims=dims, 132 | use_checkpoint=use_checkpoint, 133 | use_scale_shift_norm=use_scale_shift_norm, 134 | activation = self.activation 135 | ), 136 | AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads), 137 | ResBlock( 138 | ch, 139 | time_embed_dim, 140 | dropout, 141 | dims=dims, 142 | use_checkpoint=use_checkpoint, 143 | use_scale_shift_norm=use_scale_shift_norm, 144 | activation = self.activation 145 | ), 146 | ) 147 | self.middle_block_out = self.make_zero_conv(ch) 148 | 149 | def convert_to_fp16(self): 150 | """ 151 | Convert the torso of the model to float16. 152 | """ 153 | self.input_blocks.apply(convert_module_to_f16) 154 | self.middle_block.apply(convert_module_to_f16) 155 | self.zero_convs.apply(convert_module_to_f16) # ControlNet has no output_blocks; cast the zero convs instead 156 | 157 | def convert_to_fp32(self): 158 | """ 159 | Convert the torso of the model to float32. 160 | """ 161 | self.input_blocks.apply(convert_module_to_f32) 162 | self.middle_block.apply(convert_module_to_f32) 163 | self.zero_convs.apply(convert_module_to_f32) # ControlNet has no output_blocks; cast the zero convs instead 164 | 165 | @property 166 | def inner_dtype(self): 167 | """ 168 | Get the dtype used by the torso of the model. 169 | """ 170 | return torch.float32 # FIXED 171 | # return next(self.input_blocks.parameters()).dtype 172 | 173 | def make_zero_conv(self, channels): 174 | return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) 175 | 176 | def forward(self, x, hint, timesteps, y=None): 177 | """ 178 | Apply the model to an input batch. 179 | :param x: an [N x C x ...] Tensor of inputs. :param hint: the conditioning volume fed through input_hint_block. 180 | :param timesteps: a 1-D batch of timesteps. 181 | :param y: an [N] Tensor of labels, if class-conditional. 182 | :return: a list of control feature Tensors, one per zero conv plus the middle-block output. 183 | """ 184 | assert (y is not None) == ( 185 | self.num_classes is not None 186 | ), "must specify y if and only if the model is class-conditional" 187 | 188 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) # (16,
256) 189 | 190 | if self.num_classes is not None: 191 | assert y.shape == (x.shape[0],) 192 | emb = emb + self.label_emb(y) 193 | 194 | guided_hint = self.input_hint_block(hint, emb) # torch.Size([16, 1, 32, 32, 32]) 195 | 196 | outs = [] 197 | 198 | h = x.type(self.inner_dtype) 199 | for module, zero_conv in zip(self.input_blocks, self.zero_convs): 200 | if guided_hint is not None: 201 | h = module(h, emb) 202 | h += guided_hint 203 | guided_hint = None 204 | else: 205 | h = module(h, emb) 206 | outs.append(zero_conv(h,emb)) 207 | 208 | h = self.middle_block(h, emb) 209 | outs.append(self.middle_block_out(h, emb)) 210 | 211 | return outs 212 | -------------------------------------------------------------------------------- /lib/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import time 5 | import torch 6 | import signal 7 | import random 8 | import pickle 9 | import threading 10 | import functools 11 | import traceback 12 | import torch.nn as nn 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | 16 | """Multiprocessing error handler.""" 17 | 18 | 19 | class ChildException(Exception): 20 | """Wraps an exception from a child process.""" 21 | 22 | def __init__(self, child_trace): 23 | super(ChildException, self).__init__(child_trace) 24 | 25 | 26 | class ErrorHandler(object): 27 | """Multiprocessing error handler (based on fairseq's). 28 | 29 | Listens for errors in child processes and 30 | propagates the tracebacks to the parent process. 31 | """ 32 | 33 | def __init__(self, error_queue): 34 | # Shared error queue 35 | self.error_queue = error_queue 36 | # Children processes sharing the error queue 37 | self.children_pids = [] 38 | # Start a thread listening to errors 39 | self.error_listener = threading.Thread(target=self.listen, daemon=True) 40 | self.error_listener.start() 41 | # Register the signal handler 42 | signal.signal(signal.SIGUSR1, self.signal_handler) 43 | 44 | def add_child(self, pid): 45 | """Registers a child process.""" 46 | self.children_pids.append(pid) 47 | 48 | def listen(self): 49 | """Listens for errors in the error queue.""" 50 | # Wait until there is an error in the queue 51 | child_trace = self.error_queue.get() 52 | # Put the error back for the signal handler 53 | self.error_queue.put(child_trace) 54 | # Invoke the signal handler 55 | os.kill(os.getpid(), signal.SIGUSR1) 56 | 57 | def signal_handler(self, sig_num, stack_frame): 58 | """Signal handler.""" 59 | # Kill children processes 60 | for pid in self.children_pids: 61 | os.kill(pid, signal.SIGINT) 62 | # Propagate the error from the child process 63 | raise ChildException(self.error_queue.get()) 64 | 65 | 66 | """Multiprocessing helpers.""" 67 | 68 | 69 | def run(proc_rank, world_size, port, error_queue, fun, fun_args, fun_kwargs): 70 | """Runs a function from a child process.""" 71 | 72 | try: 73 | init_process_group(proc_rank, world_size, port) 74 | fun(*fun_args, **fun_kwargs) 75 | except: 76 | # Propagate exception to the parent process 77 | error_queue.put(traceback.format_exc()) 78 | finally: 79 | destroy_process_group() 80 | 81 | 82 | def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs={}): 83 | """Runs a function in a multi-proc setting.""" 84 | 85 | # Handle errors from training subprocesses 86 | error_queue = mp.SimpleQueue() 87 | error_handler = ErrorHandler(error_queue) 88 | 89 | # Run each training subprocess 90 | port = random.randint(10001, 20002) 91 | # port 
= random.randint(23023, 50400) 92 | mp.spawn(run, nprocs=num_proc, args=(num_proc, port, error_queue, fun, fun_args, fun_kwargs)) 93 | 94 | 95 | """Distributed helpers.""" 96 | 97 | 98 | def is_master_proc(num_gpus): 99 | """Determines if the current process is the master process. 100 | 101 | Master process is responsible for logging, writing and loading checkpoints. 102 | In the multi GPU setting, we assign the master role to the rank 0 process. 103 | When training using a single GPU, there is only one training processes 104 | which is considered the master processes. 105 | """ 106 | return num_gpus == 1 or torch.distributed.get_rank() == 0 107 | 108 | 109 | def get_world_size(): 110 | if not dist.is_available(): 111 | return 1 112 | if not dist.is_initialized(): 113 | return 1 114 | return dist.get_world_size() 115 | 116 | 117 | def get_rank(): 118 | if not dist.is_available(): 119 | return 0 120 | if not dist.is_initialized(): 121 | return 0 122 | return dist.get_rank() 123 | 124 | 125 | def synchronize(): 126 | """ 127 | Helper function to synchronize (barrier) among all processes when 128 | using distributed training 129 | """ 130 | if not dist.is_available(): 131 | return 132 | if not dist.is_initialized(): 133 | return 134 | world_size = dist.get_world_size() 135 | if world_size == 1: 136 | return 137 | dist.barrier() 138 | 139 | 140 | def all_gather_differentiable(tensor): 141 | """ 142 | Run differentiable gather function for SparseConv features with variable number of points. 143 | tensor: [num_points, feature_dim] 144 | """ 145 | world_size = get_world_size() 146 | if world_size == 1: 147 | return [tensor] 148 | 149 | num_points, f_dim = tensor.size() 150 | local_np = torch.LongTensor([num_points]).to("cuda") 151 | np_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 152 | dist.all_gather(np_list, local_np) 153 | np_list = [int(np.item()) for np in np_list] 154 | max_np = max(np_list) 155 | 156 | tensor_list = [] 157 | for _ in np_list: 158 | tensor_list.append(torch.FloatTensor(size=(max_np, f_dim)).to("cuda")) 159 | if local_np != max_np: 160 | padding = torch.zeros(size=(max_np - local_np, f_dim)).to("cuda").float() 161 | tensor = torch.cat((tensor, padding), dim=0) 162 | assert tensor.size() == (max_np, f_dim) 163 | 164 | dist.all_gather(tensor_list, tensor) 165 | 166 | data_list = [] 167 | for gather_np, gather_tensor in zip(np_list, tensor_list): 168 | gather_tensor = gather_tensor[:gather_np] 169 | assert gather_tensor.size() == (gather_np, f_dim) 170 | data_list.append(gather_tensor) 171 | return data_list 172 | 173 | 174 | def all_gather(data): 175 | """ 176 | Run all_gather on arbitrary picklable data (not necessarily tensors) 177 | Args: 178 | data: any picklable object 179 | Returns: 180 | list[data]: list of data gathered from each rank 181 | """ 182 | world_size = get_world_size() 183 | if world_size == 1: 184 | return [data] 185 | 186 | # serialized to a Tensor 187 | buffer = pickle.dumps(data) 188 | storage = torch.ByteStorage.from_buffer(buffer) 189 | tensor = torch.ByteTensor(storage).to("cuda") 190 | 191 | # obtain Tensor size of each rank 192 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 193 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 194 | dist.all_gather(size_list, local_size) 195 | size_list = [int(size.item()) for size in size_list] 196 | max_size = max(size_list) 197 | 198 | # receiving Tensor from all ranks 199 | # we pad the tensor because torch all_gather does not support 200 | # 
gathering tensors of different shapes 201 | tensor_list = [] 202 | for _ in size_list: 203 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 204 | if local_size != max_size: 205 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 206 | tensor = torch.cat((tensor, padding), dim=0) 207 | dist.all_gather(tensor_list, tensor) 208 | 209 | data_list = [] 210 | for size, tensor in zip(size_list, tensor_list): 211 | buffer = tensor.cpu().numpy().tobytes()[:size] 212 | data_list.append(pickle.loads(buffer)) 213 | 214 | return data_list 215 | 216 | 217 | def init_process_group(proc_rank, world_size, port): 218 | """Initializes the default process group.""" 219 | # Set the GPU to use 220 | print(proc_rank) 221 | 222 | torch.cuda.set_device(proc_rank) 223 | # Initialize the process group 224 | torch.distributed.init_process_group( 225 | backend="nccl", 226 | init_method="tcp://{}:{}".format("localhost", port), 227 | world_size=world_size, 228 | rank=proc_rank 229 | ) 230 | 231 | 232 | def destroy_process_group(): 233 | """Destroys the default process group.""" 234 | torch.distributed.destroy_process_group() -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import warnings 5 | import mcubes 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from lib.distributed import get_world_size, all_gather, is_master_proc 10 | from models.diffusion import load_diff_model, initialize_diff_model 11 | from models.diffusion.gaussian_diffusion import get_named_beta_schedule 12 | from skimage.measure import marching_cubes 13 | 14 | from lib.utils import Timer, AverageMeter 15 | from lib.visualize import visualize_mesh 16 | 17 | def distributed_concat(tensor, num_total_examples): 18 | output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] 19 | torch.distributed.all_gather(output_tensors, tensor) 20 | concat = torch.cat(output_tensors, dim=0) 21 | # truncate the dummy elements added by SequentialDistributedSampler 22 | return concat[:num_total_examples] 23 | 24 | def test(model, control_model, data_loader, config): 25 | 26 | 27 | is_master = is_master_proc(config.exp.num_gpus) if config.exp.num_gpus > 1 else True 28 | cur_device = torch.cuda.current_device() 29 | global_timer, iter_timer = Timer(), Timer() 30 | 31 | bs = config.test.test_batch_size // config.exp.num_gpus 32 | 33 | model.eval() 34 | control_model.eval() 35 | 36 | if is_master: 37 | logging.info('===> Start testing') 38 | global_timer.tic() 39 | 40 | # Clear cache (when run in test mode, cleanup training cache) 41 | torch.cuda.empty_cache() 42 | 43 | # Split test data into different gpus 44 | test_cnt = len(data_loader) // config.exp.num_gpus 45 | test_iter = int(config.net.control_weights[:-4].split('iter')[1]) 46 | 47 | cls = config.data.class_id 48 | save_folder = 'completion_results' 49 | os.makedirs(save_folder, exist_ok=True) 50 | save_folder = os.path.join(save_folder, str(cls), str(test_iter)) 51 | os.makedirs(save_folder, exist_ok=True) 52 | noise_folder = os.path.join(save_folder, 'noise') 53 | os.makedirs(noise_folder, exist_ok=True) 54 | 55 | npz_folder = os.path.join('completion_results_npz', str(cls), str(test_iter)) 56 | os.makedirs(npz_folder, exist_ok=True) 57 | 58 | # Setting of Diffusion Models 59 | clip_noise = config.test.clip_noise 60 | use_ddim = config.test.use_ddim 61 
| ddim_eta = config.test.ddim_eta 62 | betas = get_named_beta_schedule(config.diffusion.beta_schedule, 63 | config.diffusion.step, 64 | config.diffusion.scale_ratio) 65 | DiffusionClass = load_diff_model(config.diffusion.test_model) 66 | diffusion_model = initialize_diff_model(DiffusionClass, betas, config) 67 | 68 | data_iter = iter(data_loader) 69 | 70 | iter_timer.tic() 71 | 72 | 73 | if config.exp.num_gpus == 1: 74 | 75 | with torch.no_grad(): 76 | for m in range(test_cnt): 77 | scan_ids, observe, gt = next(data_iter) # the .next() method was removed from loader iterators; use the builtin next() 78 | sign = observe[:, 1].numpy() 79 | bs = observe.size(0) 80 | noise = None 81 | model_kwargs = { 82 | 'noise_save_path': os.path.join(noise_folder, f'{scan_ids[0]}noise.pt')} 83 | model_kwargs["hint"] = observe.to(cur_device) # torch.Size([1, 2, 32, 32, 32]) 84 | 85 | # # # Visualize range scans (by SDF) 86 | # for i in range(len(observe)): 87 | # single_observe = observe[i] 88 | # obs_sdf = single_observe[0].numpy() 89 | # scan_id = scan_ids[i] 90 | # sdf_vertices, sdf_triangles = mcubes.marching_cubes(obs_sdf, 0.5) 91 | # out_file = os.path.join(save_folder, f'{scan_id}input.obj') 92 | # mcubes.export_obj(sdf_vertices, sdf_triangles, out_file) 93 | # # print(f"Save {out_file}!") 94 | # 95 | # # Visualize GT DF 96 | # gt_df = gt.numpy() 97 | # for i in range(len(gt_df)): 98 | # gt_single = gt_df[i] 99 | # scan_id = scan_ids[i] 100 | # vertices, triangles = mcubes.marching_cubes(gt_single, 0.5) 101 | # # vertices = (vertices.astype(np.float32) - 0.5) / config.exp.res - 0.5 102 | # out_file = os.path.join(save_folder, f'{scan_id}gt.obj') 103 | # mcubes.export_obj(vertices, triangles, out_file) 104 | # # print(f"Save {out_file}!") 105 | 106 | if use_ddim: 107 | low_samples = diffusion_model.ddim_sample_loop(model=model, 108 | shape=[bs, 1] + [config.exp.res] * 3, 109 | device=cur_device, 110 | clip_denoised=clip_noise, progress=True, 111 | noise=noise, 112 | eta=ddim_eta, 113 | model_kwargs=model_kwargs).detach() 114 | else: 115 | low_samples = diffusion_model.p_sample_loop(model=model, 116 | control_model=control_model, 117 | shape=[bs, 1] + [config.exp.res] * 3, 118 | device=cur_device, 119 | clip_denoised=clip_noise, progress=True, noise=noise, 120 | model_kwargs=model_kwargs).detach() 121 | 122 | low_samples = low_samples.cpu().numpy()[:, 0] 123 | if config.data.log_df: 124 | low_samples = np.exp(low_samples) - 1 125 | low_samples = np.clip(low_samples, 0, config.data.trunc_distance) 126 | 127 | # Visualize predicted DF 128 | for i in range(len(low_samples)): 129 | low_sample = low_samples[i] 130 | scan_id = scan_ids[i] 131 | # You can choose more advanced surface extraction methods for TDF outputs 132 | vertices, triangles = mcubes.marching_cubes(low_sample, 0.5) 133 | out_file = os.path.join(save_folder, f'{scan_id}output.obj') 134 | mcubes.export_obj(vertices, triangles, out_file) 135 | out_npy_file = os.path.join(npz_folder, f'{scan_id}output.npy') 136 | np.save(out_npy_file, low_sample) 137 | 138 | else: 139 | with torch.no_grad(): 140 | for scan_id, observe, gt in data_loader: 141 | sign = observe[:, 1].numpy() 142 | noise = None 143 | model_kwargs = { 144 | 'noise_save_path': os.path.join(noise_folder, f'{scan_id[0]}noise.pt')} 145 | model_kwargs["hint"] = observe.to(cur_device) # torch.Size([1, 2, 32, 32, 32]) 146 | 147 | if use_ddim: 148 | low_samples = diffusion_model.ddim_sample_loop(model=model, 149 | shape=[bs, 1] + [config.exp.res] * 3, 150 | device=cur_device, 151 | clip_denoised=clip_noise, progress=True, 152 | noise=noise, 153 |
154 | model_kwargs=model_kwargs).detach() 155 | else: 156 | low_samples = diffusion_model.p_sample_loop(model=model, 157 | control_model=control_model, 158 | shape=[bs, 1] + [config.exp.res] * 3, 159 | device=cur_device, 160 | clip_denoised=clip_noise, progress=True, noise=noise, 161 | model_kwargs=model_kwargs).detach() 162 | 163 | low_samples = low_samples.cpu().numpy()[:, 0] 164 | if config.data.log_df: 165 | low_samples = np.exp(low_samples) - 1 166 | low_samples = np.clip(low_samples, 0, config.data.trunc_distance) 167 | 168 | # You can visualize the results here 169 | 170 | iter_time = iter_timer.toc(False) 171 | global_time = global_timer.toc(False) 172 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import json 8 | import logging 9 | import os 10 | import errno 11 | import time 12 | import torch 13 | import numpy as np 14 | from omegaconf import OmegaConf 15 | from lib.distributed import get_world_size 16 | 17 | def find_free_port(): 18 | import socket 19 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 20 | # Binding to port 0 will cause the OS to find an available port for us 21 | sock.bind(("", 0)) 22 | port = sock.getsockname()[1] 23 | sock.close() 24 | # NOTE: there is still a chance the port could be taken by other processes. 25 | return port 26 | 27 | def load_state_with_same_shape(model, weights): 28 | # weights['conv1.kernel'] = weights['conv1.kernel'].repeat([1,3,1])/3.0 29 | model_state = model.state_dict() 30 | if list(weights.keys())[0].startswith('module.'): 31 | logging.info("Loading multigpu weights with module. prefix...") 32 | weights = {k.partition('module.')[2]: weights[k] for k in weights.keys()} 33 | 34 | if list(weights.keys())[0].startswith('encoder.'): 35 | logging.info("Loading multigpu weights with encoder. prefix...")
prefix...") 36 | weights = {k.partition('encoder.')[2]: weights[k] for k in weights.keys()} 37 | 38 | # print(weights.items()) 39 | # print("===================") 40 | # print("===================") 41 | # print("===================") 42 | # print("===================") 43 | # print("===================") 44 | # print(model_state) 45 | 46 | filtered_weights = { 47 | k: v for k, v in weights.items() if k in model_state and v.size() == model_state[k].size() 48 | } 49 | logging.info("Loading weights:" + ', '.join(filtered_weights.keys())) 50 | return filtered_weights 51 | 52 | 53 | def checkpoint(model, optimizer, epoch, iteration, config, best, scaler=None, postfix=None): 54 | mkdir_p('weights') 55 | filename = f"checkpoint_{config.net.network}_iter{iteration}.pth" 56 | if config.train.overwrite_weights: 57 | filename = f"checkpoint_{config.net.network}.pth" 58 | if postfix is not None: 59 | filename = f"checkpoint_{config.net.network}_{postfix}.pth" 60 | checkpoint_file = 'weights/' + filename 61 | 62 | _model = model.module if get_world_size() > 1 else model 63 | state = { 64 | 'iteration': iteration, 65 | 'epoch': epoch, 66 | 'arch': config.net.network, 67 | 'state_dict': _model.state_dict(), 68 | 'optimizer': optimizer.state_dict() 69 | } 70 | 71 | if hasattr(config.train, 'mix_precision') and config.train.mix_precision: 72 | state['scalar'] = scaler.state_dict() 73 | 74 | if best is not None: 75 | state['best'] = best 76 | state['best_iter'] = iteration 77 | 78 | # Save config 79 | OmegaConf.save(config, 'config.yaml') 80 | 81 | torch.save(state, checkpoint_file) 82 | logging.info(f"Checkpoint saved to {checkpoint_file}") 83 | 84 | if postfix == None: 85 | # Delete symlink if it exists 86 | if os.path.exists('weights/weights.pth'): 87 | os.remove('weights/weights.pth') 88 | # Create symlink 89 | os.system('ln -s {} weights/weights.pth'.format(filename)) 90 | 91 | 92 | def checkpoint_control(model, optimizer, epoch, iteration, config, best, scaler=None, postfix=None): 93 | mkdir_p('weights') 94 | filename = f"checkpoint_{config.net.controlnet}_iter{iteration}.pth" 95 | if config.train.overwrite_weights: 96 | filename = f"checkpoint_{config.net.controlnet}.pth" 97 | if postfix is not None: 98 | filename = f"checkpoint_{config.net.controlnet}_{postfix}.pth" 99 | checkpoint_file = 'weights/' + filename 100 | 101 | _model = model.module if get_world_size() > 1 else model 102 | state = { 103 | 'iteration': iteration, 104 | 'epoch': epoch, 105 | 'arch': config.net.controlnet, 106 | 'state_dict': _model.state_dict(), 107 | 'optimizer': optimizer.state_dict() 108 | } 109 | 110 | if hasattr(config.train, 'mix_precision') and config.train.mix_precision: 111 | state['scalar'] = scaler.state_dict() 112 | 113 | if best is not None: 114 | state['best'] = best 115 | state['best_iter'] = iteration 116 | 117 | 118 | torch.save(state, checkpoint_file) 119 | logging.info(f"Checkpoint saved to {checkpoint_file}") 120 | 121 | if postfix == None: 122 | # Delete symlink if it exists 123 | if os.path.exists('weights/weights.pth'): 124 | os.remove('weights/weights.pth') 125 | # Create symlink 126 | os.system('ln -s {} weights/weights.pth'.format(filename)) 127 | 128 | def fast_hist(pred, label, n): 129 | k = (label >= 0) & (label < n) 130 | return np.bincount(n * label[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n) 131 | 132 | 133 | def per_class_iu(hist): 134 | with np.errstate(divide='ignore', invalid='ignore'): 135 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 136 | 137 | 138 
138 | class WithTimer(object): 139 | """Timer for with statement.""" 140 | 141 | def __init__(self, name=None): 142 | self.name = name 143 | 144 | def __enter__(self): 145 | self.tstart = time.time() 146 | 147 | def __exit__(self, type, value, traceback): 148 | out_str = 'Elapsed: %s' % (time.time() - self.tstart) 149 | if self.name: 150 | logging.info(f'[{self.name}]') 151 | logging.info(out_str) 152 | 153 | 154 | class Timer(object): 155 | """A simple timer.""" 156 | 157 | def __init__(self): 158 | self.total_time = 0. 159 | self.calls = 0 160 | self.start_time = 0. 161 | self.diff = 0. 162 | self.average_time = 0. 163 | 164 | def reset(self): 165 | self.total_time = 0 166 | self.calls = 0 167 | self.start_time = 0 168 | self.diff = 0 169 | self.average_time = 0 170 | 171 | def tic(self): 172 | # using time.time instead of time.clock because time.clock 173 | # does not normalize for multithreading 174 | self.start_time = time.time() 175 | 176 | def toc(self, average=True): 177 | self.diff = time.time() - self.start_time 178 | self.total_time += self.diff 179 | self.calls += 1 180 | self.average_time = self.total_time / self.calls 181 | if average: 182 | return self.average_time 183 | else: 184 | return self.diff 185 | 186 | 187 | class ExpTimer(Timer): 188 | """ Exponential Moving Average Timer """ 189 | 190 | def __init__(self, alpha=0.5): 191 | super(ExpTimer, self).__init__() 192 | self.alpha = alpha 193 | 194 | def toc(self): 195 | self.diff = time.time() - self.start_time 196 | self.average_time = self.alpha * self.diff + \ 197 | (1 - self.alpha) * self.average_time 198 | return self.average_time 199 | 200 | 201 | class AverageMeter(object): 202 | """Computes and stores the average and current value""" 203 | 204 | def __init__(self): 205 | self.reset() 206 | 207 | def reset(self): 208 | self.val = 0 209 | self.avg = 0 210 | self.sum = 0 211 | self.count = 0 212 | 213 | def update(self, val, n=1): 214 | self.val = val 215 | self.sum += val * n 216 | self.count += n 217 | self.avg = self.sum / self.count 218 | 219 | 220 | def mkdir_p(path): 221 | try: 222 | os.makedirs(path) 223 | except OSError as exc: 224 | if exc.errno == errno.EEXIST and os.path.isdir(path): 225 | pass 226 | else: 227 | raise 228 | 229 |
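# [Editor's example - a hedged sketch, not part of the original lib/utils.py]
# Typical tic/toc usage of Timer together with AverageMeter, mirroring how the
# training loop in tools/ddp_trainer.py consumes them; the sleep stands in for
# one training step and is illustrative.
def _timing_demo(n_iters=5):
    timer, meter = Timer(), AverageMeter()
    for _ in range(n_iters):
        timer.tic()
        time.sleep(0.01)                        # placeholder for real work
        meter.update(timer.toc(average=False))  # accumulate per-iteration time
    return meter.avg, timer.average_time        # both report the mean step time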
232 | """ 233 | with open(path) as f: 234 | lines = f.readlines() 235 | lines = [x.strip() for x in lines] 236 | return lines 237 | 238 | 239 | def debug_on(): 240 | import sys 241 | import pdb 242 | import functools 243 | import traceback 244 | 245 | def decorator(f): 246 | 247 | @functools.wraps(f) 248 | def wrapper(*args, **kwargs): 249 | try: 250 | return f(*args, **kwargs) 251 | except Exception: 252 | info = sys.exc_info() 253 | traceback.print_exception(*info) 254 | pdb.post_mortem(info[2]) 255 | 256 | return wrapper 257 | 258 | return decorator 259 | 260 | 261 | def count_parameters(model): 262 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 263 | 264 | 265 | def get_torch_device(is_cuda): 266 | return torch.device('cuda' if is_cuda else 'cpu') 267 | 268 | 269 | class HashTimeBatch(object): 270 | 271 | def __init__(self, prime=5279): 272 | self.prime = prime 273 | 274 | def __call__(self, time, batch): 275 | return self.hash(time, batch) 276 | 277 | def hash(self, time, batch): 278 | return self.prime * batch + time 279 | 280 | def dehash(self, key): 281 | time = key % self.prime 282 | batch = key / self.prime 283 | return time, batch 284 | 285 | 286 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import random 8 | 9 | import logging 10 | import numpy as np 11 | import scipy 12 | import scipy.ndimage 13 | import scipy.interpolate 14 | import torch 15 | 16 | 17 | # A sparse tensor consists of coordinates and associated features. 18 | # You must apply augmentation to both. 19 | # In 2D, flip, shear, scale, and rotation of images are coordinate transformation 20 | # color jitter, hue, etc., are feature transformations 21 | ############################## 22 | # Feature transformations 23 | ############################## 24 | class ChromaticTranslation(object): 25 | """Add random color to the image, input must be an array in [0,255] or a PIL image""" 26 | 27 | def __init__(self, trans_range_ratio=1e-1): 28 | """ 29 | trans_range_ratio: ratio of translation i.e. 255 * 2 * ratio * rand(-0.5, 0.5) 30 | """ 31 | self.trans_range_ratio = trans_range_ratio 32 | 33 | def __call__(self, coords, feats, labels, instances): 34 | if random.random() < 0.95: 35 | tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.trans_range_ratio 36 | feats[:, :3] = np.clip(tr + feats[:, :3], 0, 255) 37 | return coords, feats, labels, instances 38 | 39 | 40 | class ChromaticAutoContrast(object): 41 | 42 | def __init__(self, randomize_blend_factor=True, blend_factor=0.5): 43 | self.randomize_blend_factor = randomize_blend_factor 44 | self.blend_factor = blend_factor 45 | 46 | def __call__(self, coords, feats, labels, instances): 47 | if random.random() < 0.2: 48 | # mean = np.mean(feats, 0, keepdims=True) 49 | # std = np.std(feats, 0, keepdims=True) 50 | # lo = mean - std 51 | # hi = mean + std 52 | lo = feats[:, :3].min(0, keepdims=True) 53 | hi = feats[:, :3].max(0, keepdims=True) 54 | assert hi.max() > 1, f"invalid color value. 
-------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import random 8 | 9 | import logging 10 | import numpy as np 11 | import scipy 12 | import scipy.ndimage 13 | import scipy.interpolate 14 | import torch 15 | 16 | 17 | # A sparse tensor consists of coordinates and associated features. 18 | # You must apply augmentation to both. 19 | # In 2D, flip, shear, scale, and rotation of images are coordinate transformations; 20 | # color jitter, hue, etc., are feature transformations. 21 | ############################## 22 | # Feature transformations 23 | ############################## 24 | class ChromaticTranslation(object): 25 | """Add random color to the image, input must be an array in [0,255] or a PIL image""" 26 | 27 | def __init__(self, trans_range_ratio=1e-1): 28 | """ 29 | trans_range_ratio: ratio of translation i.e. 255 * 2 * ratio * rand(-0.5, 0.5) 30 | """ 31 | self.trans_range_ratio = trans_range_ratio 32 | 33 | def __call__(self, coords, feats, labels, instances): 34 | if random.random() < 0.95: 35 | tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.trans_range_ratio 36 | feats[:, :3] = np.clip(tr + feats[:, :3], 0, 255) 37 | return coords, feats, labels, instances 38 | 39 | 40 | class ChromaticAutoContrast(object): 41 | 42 | def __init__(self, randomize_blend_factor=True, blend_factor=0.5): 43 | self.randomize_blend_factor = randomize_blend_factor 44 | self.blend_factor = blend_factor 45 | 46 | def __call__(self, coords, feats, labels, instances): 47 | if random.random() < 0.2: 48 | # mean = np.mean(feats, 0, keepdims=True) 49 | # std = np.std(feats, 0, keepdims=True) 50 | # lo = mean - std 51 | # hi = mean + std 52 | lo = feats[:, :3].min(0, keepdims=True) 53 | hi = feats[:, :3].max(0, keepdims=True) 54 | assert hi.max() > 1, "invalid color value. Color is supposed to be [0-255]" 55 | 56 | scale = 255 / (hi - lo) 57 | 58 | contrast_feats = (feats[:, :3] - lo) * scale 59 | 60 | blend_factor = random.random() if self.randomize_blend_factor else self.blend_factor 61 | feats[:, :3] = (1 - blend_factor) * feats[:, :3] + blend_factor * contrast_feats 62 | return coords, feats, labels, instances 63 | 64 | 65 | class ChromaticJitter(object): 66 | 67 | def __init__(self, std=0.01): 68 | self.std = std 69 | 70 | def __call__(self, coords, feats, labels, instances): 71 | if random.random() < 0.95: 72 | noise = np.random.randn(feats.shape[0], 3) 73 | noise *= self.std * 255 74 | feats[:, :3] = np.clip(noise + feats[:, :3], 0, 255) 75 | return coords, feats, labels, instances 76 | 77 | 78 | class HueSaturationTranslation(object): 79 | 80 | @staticmethod 81 | def rgb_to_hsv(rgb): 82 | # Translated from source of colorsys.rgb_to_hsv 83 | # r,g,b should be numpy arrays with values between 0 and 255 84 | # rgb_to_hsv returns an array of floats between 0.0 and 1.0. 85 | rgb = rgb.astype('float') 86 | hsv = np.zeros_like(rgb) 87 | # in case an RGBA array was passed, just copy the A channel 88 | hsv[..., 3:] = rgb[..., 3:] 89 | r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] 90 | maxc = np.max(rgb[..., :3], axis=-1) 91 | minc = np.min(rgb[..., :3], axis=-1) 92 | hsv[..., 2] = maxc 93 | mask = maxc != minc 94 | hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask] 95 | rc = np.zeros_like(r) 96 | gc = np.zeros_like(g) 97 | bc = np.zeros_like(b) 98 | rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask] 99 | gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask] 100 | bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask] 101 | hsv[..., 0] = np.select([r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc) 102 | hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0 103 | return hsv 104 | 105 | @staticmethod 106 | def hsv_to_rgb(hsv): 107 | # Translated from source of colorsys.hsv_to_rgb 108 | # h,s should be numpy arrays with values between 0.0 and 1.0 109 | # v should be a numpy array with values between 0.0 and 255.0 110 | # hsv_to_rgb returns an array of uints between 0 and 255.
111 | rgb = np.empty_like(hsv) 112 | rgb[..., 3:] = hsv[..., 3:] 113 | h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] 114 | i = (h * 6.0).astype('uint8') 115 | f = (h * 6.0) - i 116 | p = v * (1.0 - s) 117 | q = v * (1.0 - s * f) 118 | t = v * (1.0 - s * (1.0 - f)) 119 | i = i % 6 120 | conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5] 121 | rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v) 122 | rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t) 123 | rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p) 124 | return rgb.astype('uint8') 125 | 126 | def __init__(self, hue_max, saturation_max): 127 | self.hue_max = hue_max 128 | self.saturation_max = saturation_max 129 | 130 | def __call__(self, coords, feats, labels, instances): 131 | # Assume feat[:, :3] is rgb 132 | hsv = HueSaturationTranslation.rgb_to_hsv(feats[:, :3]) 133 | hue_val = (random.random() - 0.5) * 2 * self.hue_max 134 | sat_ratio = 1 + (random.random() - 0.5) * 2 * self.saturation_max 135 | hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1) 136 | hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1) 137 | feats[:, :3] = np.clip(HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255) 138 | 139 | return coords, feats, labels, instances 140 | 141 | 142 | ############################## 143 | # Coordinate transformations 144 | ############################## 145 | class RandomDropout(object): 146 | 147 | def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5): 148 | """ 149 | upright_axis: axis index among x,y,z, i.e. 2 for z 150 | """ 151 | self.dropout_ratio = dropout_ratio 152 | self.dropout_application_ratio = dropout_application_ratio 153 | 154 | def __call__(self, coords, feats, labels, instances): 155 | if random.random() < self.dropout_ratio: 156 | N = len(coords) 157 | inds = np.random.choice(N, int(N * (1 - self.dropout_ratio)), replace=False) 158 | return coords[inds], feats[inds], labels[inds], instances[inds] 159 | return coords, feats, labels, instances 160 | 161 | 162 | class RandomHorizontalFlip(object): 163 | 164 | def __init__(self, upright_axis, is_temporal): 165 | """ 166 | upright_axis: axis index among x,y,z, i.e. 2 for z 167 | """ 168 | self.is_temporal = is_temporal 169 | self.D = 4 if is_temporal else 3 170 | self.upright_axis = {'x': 0, 'y': 1, 'z': 2}[upright_axis.lower()] 171 | # Use the rest of axes for flipping. 172 | self.horz_axes = set(range(self.D)) - set([self.upright_axis]) 173 | 174 | def __call__(self, coords, feats, labels, instances): 175 | if random.random() < 0.95: 176 | for curr_ax in self.horz_axes: 177 | if random.random() < 0.5: 178 | coord_max = np.max(coords[:, curr_ax]) 179 | coords[:, curr_ax] = coord_max - coords[:, curr_ax] 180 | return coords, feats, labels, instances 181 | 182 | 183 | class ElasticDistortion: 184 | 185 | def __init__(self, distortion_params): 186 | self.distortion_params = distortion_params 187 | 188 | def elastic_distortion(self, coords, feats, labels, granularity, magnitude): 189 | """Apply elastic distortion on sparse coordinate space. 
190 | 191 | pointcloud: numpy array of (number of points, at least 3 spatial dims) 192 | granularity: size of the noise grid (in same scale[m/cm] as the voxel grid) 193 | magnitude: noise multiplier 194 | """ 195 | blurx = np.ones((3, 1, 1, 1)).astype('float32') / 3 196 | blury = np.ones((1, 3, 1, 1)).astype('float32') / 3 197 | blurz = np.ones((1, 1, 3, 1)).astype('float32') / 3 198 | coords_min = coords.min(0) 199 | 200 | # Create Gaussian noise tensor of the size given by granularity. 201 | noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3 202 | noise = np.random.randn(*noise_dim, 3).astype(np.float32) 203 | 204 | # Smoothing. 205 | for _ in range(2): 206 | noise = scipy.ndimage.filters.convolve(noise, blurx, mode='constant', cval=0) 207 | noise = scipy.ndimage.filters.convolve(noise, blury, mode='constant', cval=0) 208 | noise = scipy.ndimage.filters.convolve(noise, blurz, mode='constant', cval=0) 209 | 210 | # Trilinear interpolate noise filters for each spatial dimensions. 211 | ax = [ 212 | np.linspace(d_min, d_max, d) 213 | for d_min, d_max, d in zip(coords_min - granularity, coords_min + granularity * 214 | (noise_dim - 2), noise_dim) 215 | ] 216 | interp = scipy.interpolate.RegularGridInterpolator(ax, noise, bounds_error=0, fill_value=0) 217 | coords += interp(coords) * magnitude 218 | return coords, feats, labels 219 | 220 | def __call__(self, coords, feats, labels): 221 | if self.distortion_params is not None: 222 | if random.random() < 0.95: 223 | for granularity, magnitude in self.distortion_params: 224 | coords, feats, labels = self.elastic_distortion(coords, feats, labels, granularity, 225 | magnitude) 226 | return coords, feats, labels 227 | 228 | 229 | class Compose(object): 230 | """Composes several transforms together.""" 231 | 232 | def __init__(self, transforms): 233 | self.transforms = transforms 234 | 235 | def __call__(self, *args): 236 | for t in self.transforms: 237 | args = t(*args) 238 | return args 239 | 240 | 241 | class collate_fn_factory: 242 | """Generates collate function for coords, feats, labels. 243 | 244 | Args: 245 | limit_numpoints: If 0 or False, does not alter batch size. If positive integer, limits batch 246 | size so that the number of input coordinates is below limit_numpoints. 247 | """ 248 | 249 | def __init__(self): 250 | pass 251 | 252 | def __call__(self, list_data): 253 | return list_data -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /tools/ddp_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import os 4 | import sys 5 | import torch 6 | import torch.nn.functional as F 7 | import tqdm 8 | 9 | from torch import nn 10 | from torch.serialization import default_restore_location 11 | from torch.cuda.amp import autocast, GradScaler 12 | from tensorboardX import SummaryWriter 13 | from omegaconf import OmegaConf 14 | 15 | from lib.distributed import get_world_size, all_gather, is_master_proc 16 | from lib.solvers import initialize_optimizer, initialize_scheduler 17 | from datasets import load_dataset, initialize_data_loader 18 | from lib.utils import checkpoint, checkpoint_control, Timer, AverageMeter, load_state_with_same_shape, count_parameters 19 | 20 | from models.networks import load_network, initialize_network, initialize_controlnet 21 | from models.diffusion import load_diff_model, initialize_diff_model 22 | from models.diffusion.gaussian_diffusion import get_named_beta_schedule 23 | from models.modules.resample import UniformSampler, LossSecondMomentResampler, LossAwareSampler 24 | 25 | from tools.test import test as test_ 26 | 27 | def exists(val): 28 | return val is not None 29 | 30 | def default(val, d): 31 | if exists(val): 32 | return val 33 | return d() if callable(d) else d 34 | 35 | class DiffusionTrainer: 36 | def __init__(self, config): 37 | 38 | self.is_master = is_master_proc(config.exp.num_gpus) if config.exp.num_gpus > 1 else True 39 | self.cur_device = torch.cuda.current_device() 40 | 41 | # Load the configurations 42 | self.setup_logging() 43 | 44 | # Use the previous configurations if the training breaks 45 | # if config.train.is_train == True and config.train.debug == False: 46 | # if os.path.exists('config.yaml'): 47 | # logging.info('===> Loading existing config file') 48 | # config = OmegaConf.load('config.yaml') 49 | # logging.info('===> Loaded existing config file') 50 | # logging.info('===> Configurations') 51 | # logging.info(config) 52 | 53 | # Dataloader 54 | DatasetClass = load_dataset(config.data.dataset) 55 | logging.info('===> Initializing dataloader') 56 | self.train_data_loader = initialize_data_loader( 57 | DatasetClass, config, phase=config.train.train_phase, 58 | num_workers=config.exp.num_workers, augment_data=False, 59 | shuffle=True, repeat=True, collate=config.data.collate_fn, 60 | batch_size=config.exp.batch_size // config.exp.num_gpus, 61 | persistent_workers=config.data.persistent_workers 62 | ) 63 | 64 | self.test_data_loader = initialize_data_loader( 65 | DatasetClass, config, phase=config.test.test_phase, 66 | num_workers=config.exp.num_workers, augment_data=False, 67 | shuffle=config.test.partial_shape is None, repeat=False, collate=config.data.collate_fn, 68 | batch_size=config.test.test_batch_size // config.exp.num_gpus, 69 | persistent_workers=config.data.persistent_workers 70 | ) 71 | 72 | # Main network initialization 73 | logging.info('===> Building model') 74 | NetClass = load_network(config.net.network) 75 | model = initialize_network(NetClass, config) 76 | 77 | # ControlNet initialization 78 | logging.info('===> Building ControlNet') 79 | ControlNet = load_network(config.net.controlnet) 80 | control_model = initialize_controlnet(ControlNet, config) 81 | 82 | logging.info('===> Number of trainable parameters: {}: {}'.format(NetClass.__name__, count_parameters(model))) 83 | logging.info('===> 
Number of trainable parameters: {}: {}'.format(ControlNet.__name__, count_parameters(control_model))) 84 | logging.info(model) 85 | logging.info(control_model) 86 | 87 | # Load weights for the main network and control network 88 | if config.net.weights: 89 | logging.info('===> Loading weights: ' + config.net.weights) 90 | state = torch.load(config.net.weights, map_location=lambda s, l: default_restore_location(s, 'cpu')) 91 | matched_weights = load_state_with_same_shape(model, state['state_dict']) 92 | model_dict = model.state_dict() 93 | model_dict.update(matched_weights) 94 | model.load_state_dict(model_dict) 95 | 96 | if config.net.control_weights: 97 | logging.info('===> Loading weights: ' + config.net.control_weights) 98 | config.net.control_weights = default(config.net.control_weights, config.net.weights) 99 | control_state = torch.load(config.net.control_weights, map_location=lambda s, l: default_restore_location(s, 'cpu')) 100 | 101 | control_matched_weights = load_state_with_same_shape(control_model, control_state['state_dict']) 102 | control_model_dict = control_model.state_dict() 103 | control_model_dict.update(control_matched_weights) 104 | control_model.load_state_dict(control_model_dict) 105 | 106 | model = model.cuda() 107 | if config.exp.num_gpus > 1: 108 | model = torch.nn.parallel.DistributedDataParallel( 109 | module=model, device_ids=[self.cur_device], 110 | output_device=self.cur_device, 111 | broadcast_buffers=False, 112 | # find_unused_parameters=True 113 | ) 114 | 115 | control_model = control_model.cuda() 116 | if config.exp.num_gpus > 1: 117 | control_model = torch.nn.parallel.DistributedDataParallel( 118 | module=control_model, device_ids=[self.cur_device], 119 | output_device=self.cur_device, 120 | broadcast_buffers=False, 121 | ) 122 | 123 | self.config = config 124 | self.skip_validate = config.exp.skip_validate 125 | self.model = model 126 | self.control_model = control_model 127 | 128 | # Diffusion model 129 | # linear, 1000, 1.0 130 | betas = get_named_beta_schedule(config.diffusion.beta_schedule, 131 | config.diffusion.step, 132 | config.diffusion.scale_ratio) 133 | DiffusionClass = load_diff_model(config.diffusion.model) 134 | self.diffusion_model = initialize_diff_model(DiffusionClass, betas, config) 135 | 136 | 137 | # Sample 138 | if config.diffusion.sampler == 'uniform': 139 | self.sampler = UniformSampler(self.diffusion_model) 140 | elif config.diffusion.sampler == 'second-order': 141 | self.sampler = LossSecondMomentResampler(self.diffusion_model) 142 | else: 143 | raise Exception("Unknown Sampler.....") 144 | 145 | if self.is_master: 146 | self.writer = SummaryWriter(log_dir='tensorboard') 147 | 148 | self.optimizer, self.scheduler = self.configure_optimizers(config) 149 | 150 | # # fix parameters for training 151 | # if config.train.fine_tune_encoder == True: 152 | # assert config.net.weights is not None, "Please specify the pre-trained weights for fine-tuning" 153 | # for name, p in self.model.named_parameters(): 154 | # if 'time_embed' in name or 'input_blocks' in name or 'middle_block' in name: 155 | # p.requires_grad = True 156 | # else: 157 | # p.requires_grad = False 158 | 159 | # Mixed precision training 160 | if hasattr(config.train, 'mix_precision') and self.config.train.mix_precision: 161 | self.scaler = GradScaler() 162 | else: 163 | self.scaler = None 164 | 165 | # Continue training from the checkpoint (TBD) 166 | if config.train.is_train: 167 | checkpoint_fn = 'weights/weights.pth' 168 | self.min_loss = 100 169 | self.best = -1 170 | 
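# [Editor's note - a hedged sketch, not part of the original file] The resume
# path above is marked TBD; one way to implement it, assuming the checkpoint
# dict written by checkpoint() in lib/utils.py ('iteration', 'epoch',
# 'state_dict', 'optimizer'), would be:
#   if os.path.exists(checkpoint_fn):
#       state = torch.load(checkpoint_fn, map_location='cpu')
#       self.load_state(state['state_dict'])
#       self.optimizer.load_state_dict(state['optimizer'])
#       self.curr_iter, self.epoch = state['iteration'] + 1, state['epoch']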
self.curr_iter, self.epoch, self.is_training = 1, 1, True 171 | 172 | def configure_optimizers(self, config): 173 | params = list(self.control_model.parameters()) 174 | params += list(self.model.parameters()) 175 | 176 | optimizer = initialize_optimizer(params, config.optimizer) 177 | if config.optimizer.lr_decay: # False 178 | scheduler = initialize_scheduler(optimizer, config.optimizer) 179 | else: 180 | scheduler = None 181 | return optimizer, scheduler 182 | 183 | def setup_logging(self): 184 | 185 | ch = logging.StreamHandler(sys.stdout) 186 | logging.getLogger().setLevel(logging.WARN) 187 | if self.is_master: 188 | logging.getLogger().setLevel(logging.INFO) 189 | logging.basicConfig( 190 | format=os.uname()[1].split('.')[0] + ' %(asctime)s %(message)s', 191 | datefmt='%m/%d %H:%M:%S', 192 | handlers=[ch]) 193 | 194 | def load_state(self, state): 195 | if get_world_size() > 1: 196 | _model = self.model.module 197 | else: 198 | _model = self.model 199 | _model.load_state_dict(state) 200 | 201 | def set_seed(self): 202 | # Set seed based on args.seed and the update number so that we get 203 | # reproducible results when resuming from checkpoints 204 | seed = self.config.misc.seed + self.curr_iter 205 | torch.manual_seed(seed) 206 | torch.cuda.manual_seed(seed) 207 | 208 | def test(self): 209 | return test_(self.model, self.control_model, self.test_data_loader, self.config) 210 | 211 | 212 | def validate(self): 213 | if not self.skip_validate: 214 | val_loss, val_score, _ = test_(self.model, self.val_data_loader, self.config) 215 | self.writer.add_scalar('val/loss', val_loss, self.curr_iter) 216 | self.writer.add_scalar('val/score', val_score, self.curr_iter) 217 | 218 | if val_score > self.best: 219 | self.best = val_score 220 | self.best_iter = self.curr_iter 221 | checkpoint(self.model, self.optimizer, self.epoch, self.curr_iter, self.config, 222 | self.best, self.scaler, postfix="best") 223 | logging.info("Current best score: {:.3f} at iter {}".format(self.best, self.best_iter)) 224 | 225 | checkpoint(self.model, self.optimizer, self.epoch, self.curr_iter, self.config, self.best, self.scaler) 226 | checkpoint_control(self.control_model, self.optimizer, self.epoch, self.curr_iter, self.config, self.best, self.scaler) 227 | 228 | def train(self): 229 | ## To be checked 230 | self.model.train() 231 | self.control_model.train() 232 | 233 | # Configuration 234 | data_timer, iter_timer = Timer(), Timer() 235 | fw_timer, bw_timer, ddp_timer = Timer(), Timer(), Timer() 236 | data_time_avg, iter_time_avg = AverageMeter(), AverageMeter() 237 | fw_time_avg, bw_time_avg, ddp_time_avg = AverageMeter(), AverageMeter(), AverageMeter() 238 | 239 | losses = { 240 | 'total_loss': AverageMeter(), 241 | 'mse_loss': AverageMeter() 242 | } 243 | 244 | # Train the network 245 | logging.info('===> Start training on {} GPUs, batch-size={}'.format( 246 | get_world_size(), self.config.exp.batch_size)) 247 | 248 | data_iter = iter(self.train_data_loader) # (distributed) infinite sampler 249 | while self.is_training: 250 | for _ in range(len(self.train_data_loader)): 251 | self.optimizer.zero_grad() 252 | data_time = 0 253 | batch_losses = {'total_loss': 0.0, 254 | 'mse_loss': 0.0} 255 | iter_timer.tic() 256 | 257 | # set random seed for every iteration for reproducibility 258 | self.set_seed() 259 | total_loss = 0.0 260 | mse_loss = 0.0 261 | # Get training data 262 | data_timer.tic() 263 | scan_id, input_sdf, gt_df = next(data_iter) 264 | shape_gt = gt_df.unsqueeze(1).to(self.cur_device) 265 | 266 
if self.config.data.dataset != 'ControlledEPNDataset': 267 | input_sdf = input_sdf.unsqueeze(1).to(self.cur_device) 268 | else: 269 | input_sdf = input_sdf.to(self.cur_device) 270 | 271 | data_time += data_timer.toc(False) 272 | # Feed forward 273 | fw_timer.tic() 274 | 275 | if hasattr(self.config.train, 'mix_precision') and self.config.train.mix_precision: 276 | with autocast(): 277 | t, t_weights = self.sampler.sample(shape_gt.size(0), device=self.cur_device) 278 | iterative_loss = self.diffusion_model.training_losses(model=self.model, 279 | control_model=self.control_model, 280 | x_start=shape_gt, 281 | hint=input_sdf, 282 | t=t, 283 | weighted_loss=self.config.train.weighted_loss) 284 | mse_loss += torch.mean(iterative_loss['loss'] * t_weights) 285 | else: 286 | t, t_weights = self.sampler.sample(shape_gt.size(0), device=self.cur_device) 287 | iterative_loss = self.diffusion_model.training_losses(model=self.model, 288 | control_model=self.control_model, 289 | x_start=shape_gt, 290 | hint=input_sdf, 291 | t=t, 292 | weighted_loss=self.config.train.weighted_loss) 293 | mse_loss += torch.mean(iterative_loss['loss'] * t_weights) 294 | 295 | # Compute and accumulate gradient 296 | total_loss += mse_loss 297 | 298 | # bp the loss 299 | fw_timer.toc(False) 300 | bw_timer.tic() 301 | 302 | if hasattr(self.config.train, 'mix_precision') and self.config.train.mix_precision: 303 | self.scaler.scale(total_loss).backward() 304 | else: 305 | total_loss.backward() 306 | if self.config.train.use_gradient_clip: 307 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.train.gradient_clip_value) 308 | 309 | bw_timer.toc(False) 310 | 311 | # gather information 312 | logging_output = {'total_loss': total_loss.item(), 'mse_loss': mse_loss.item()} 313 | 314 | ddp_timer.tic() 315 | if self.config.exp.num_gpus > 1: 316 | logging_output = all_gather(logging_output) 317 | logging_output = {w: np.mean([ 318 | a[w] for a in logging_output] 319 | ) for w in logging_output[0]} 320 | 321 | batch_losses['total_loss'] += logging_output['total_loss'] 322 | batch_losses['mse_loss'] += logging_output['mse_loss'] 323 | ddp_timer.toc(False) 324 | 325 | # Update number of steps 326 | if hasattr(self.config.train, 'mix_precision') and self.config.train.mix_precision: 327 | self.scaler.step(self.optimizer) 328 | self.scaler.update() 329 | else: 330 | self.optimizer.step() 331 | if self.scheduler is not None: 332 | self.scheduler.step() 333 | 334 | # print(self.model.state_dict()['input_blocks.14.0.in_layers.0.bias']) 335 | # print(self.model.state_dict()['output_blocks.15.0.in_layers.2.bias']) 336 | 337 | data_time_avg.update(data_time) 338 | iter_time_avg.update(iter_timer.toc(False)) 339 | fw_time_avg.update(fw_timer.diff) 340 | bw_time_avg.update(bw_timer.diff) 341 | ddp_time_avg.update(ddp_timer.diff) 342 | 343 | losses['total_loss'].update(batch_losses['total_loss'], shape_gt.size(0)) 344 | losses['mse_loss'].update(batch_losses['mse_loss'], shape_gt.size(0)) 345 | 346 | if self.curr_iter >= self.config.train.max_iter: 347 | self.is_training = False 348 | break 349 | 350 | last_lr = self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.state_dict()['param_groups'][0]['lr'] 351 | if self.curr_iter % self.config.train.stat_freq == 0 or self.curr_iter == 1: 352 | # lrs = ', '.join(['{:.3e}'.format(x) for x in last_lr]) 353 | # debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}, LR: {}\t".format( 354 | # self.epoch, self.curr_iter, len(self.train_data_loader), 355 | # losses['total_loss'].avg, 
lrs) 356 | lr = '{:.3e}'.format(last_lr) 357 | debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}, LR: {}\t".format( 358 | self.epoch, self.curr_iter, len(self.train_data_loader), 359 | losses['total_loss'].avg, lr) 360 | debug_str += "Data time: {:.4f}, Forward time: {:.4f}, Backward time: {:.4f}, DDP time: {:.4f}, Total iter time: {:.4f}".format( 361 | data_time_avg.avg, fw_time_avg.avg, bw_time_avg.avg, ddp_time_avg.avg, 362 | iter_time_avg.avg) 363 | logging.info(debug_str) 364 | # Reset timers 365 | data_time_avg.reset() 366 | iter_time_avg.reset() 367 | 368 | # Write logs 369 | if self.is_master: 370 | self.writer.add_scalar('train/loss', losses['total_loss'].avg, self.curr_iter) 371 | self.writer.add_scalar('train/learning_rate', last_lr, self.curr_iter) 372 | 373 | # clear loss 374 | losses['total_loss'].reset() 375 | losses['mse_loss'].reset() 376 | 377 | # Validation 378 | if self.curr_iter % self.config.train.val_freq == 0 and self.is_master: 379 | self.validate() 380 | if not self.skip_validate: 381 | self.model.train() 382 | self.control_model.train() 383 | 384 | 385 | if self.curr_iter % self.config.train.empty_cache_freq == 0: 386 | # Clear cache 387 | torch.cuda.empty_cache() 388 | 389 | # End of iteration 390 | self.curr_iter += 1 391 | 392 | # max_memory_allocated = torch.cuda.max_memory_allocated(self.cur_device) / (1024 ** 2) 393 | # logging.info(f"End of Epoch {self.epoch + 1}, Max memory allocated: {max_memory_allocated:.2f} MiB") 394 | 395 | self.epoch += 1 396 | 397 | # Explicit memory cleanup 398 | if hasattr(data_iter, 'cleanup'): 399 | data_iter.cleanup() 400 | 401 | # max_memory_allocated = torch.cuda.max_memory_allocated(self.cur_device) / (1024 ** 2) 402 | # print(f"End of training, Max memory allocated: {max_memory_allocated:.2f} MiB") 403 | 404 | # Save the final model 405 | if self.is_master: 406 | self.validate() -------------------------------------------------------------------------------- /models/networks/resunet3d.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from models.modules.fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from models.modules.nn import ( 12 | SiLU, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | checkpoint, 20 | ) 21 | 22 | class TimestepBlock(nn.Module): 23 | """ 24 | Any module where forward() takes timestep embeddings as a second argument. 25 | """ 26 | 27 | @abstractmethod 28 | def forward(self, x, emb): 29 | """ 30 | Apply the module to `x` given `emb` timestep embeddings. 31 | """ 32 | 33 | 34 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 35 | """ 36 | A sequential module that passes timestep embeddings to the children that 37 | support it as an extra input. 38 | """ 39 | 40 | def forward(self, x, emb): 41 | for layer in self: 42 | if isinstance(layer, TimestepBlock): 43 | x = layer(x, emb) 44 | else: 45 | x = layer(x) 46 | return x 47 | 48 | 49 | class Upsample(nn.Module): 50 | """ 51 | An upsampling layer with an optional convolution. 52 | :param channels: channels in the inputs and outputs. 53 | :param use_conv: a bool determining if a convolution is applied. 
54 | :param dims: determines if the signal is 1D, 2D, or 3D 55 | """ 56 | 57 | def __init__(self, channels, use_conv, dims=2): 58 | super().__init__() 59 | self.channels = channels 60 | self.use_conv = use_conv 61 | self.dims = dims 62 | if use_conv: 63 | self.conv = conv_nd(dims, channels, channels, 3, padding=1) 64 | 65 | def forward(self, x): 66 | assert x.shape[1] == self.channels 67 | x = F.interpolate(x, scale_factor=2, mode="nearest") 68 | if self.use_conv: 69 | x = self.conv(x) 70 | return x 71 | 72 | 73 | class Downsample(nn.Module): 74 | """ 75 | A downsampling layer with an optional convolution. 76 | :param channels: channels in the inputs and outputs. 77 | :param use_conv: a bool determining if a convolution is applied. 78 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 79 | downsampling occurs in the inner-two dimensions. 80 | """ 81 | 82 | def __init__(self, channels, use_conv, dims=2): 83 | super().__init__() 84 | self.channels = channels 85 | self.use_conv = use_conv 86 | self.dims = dims 87 | stride = 2 88 | if use_conv: 89 | self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1) 90 | else: 91 | self.op = avg_pool_nd(stride) 92 | 93 | def forward(self, x): 94 | assert x.shape[1] == self.channels 95 | return self.op(x) 96 | 97 | 98 | class ResBlock(TimestepBlock): 99 | """ 100 | A residual block that can optionally change the number of channels. 101 | :param channels: the number of input channels. 102 | :param emb_channels: the number of timestep embedding channels. 103 | :param dropout: the rate of dropout. 104 | :param out_channels: if specified, the number of out channels. 105 | :param use_conv: if True and out_channels is specified, use a spatial 106 | convolution instead of a smaller 1x1 convolution to change the 107 | channels in the skip connection. 108 | :param dims: determines if the signal is 1D, 2D, or 3D. 109 | :param use_checkpoint: if True, use gradient checkpointing on this module. 110 | """ 111 | 112 | def __init__( 113 | self, 114 | channels, 115 | emb_channels, 116 | dropout, 117 | out_channels=None, 118 | use_conv=False, 119 | use_scale_shift_norm=False, 120 | dims=2, 121 | use_checkpoint=False, 122 | activation = SiLU() 123 | ): 124 | super().__init__() 125 | self.channels = channels 126 | self.emb_channels = emb_channels 127 | self.dropout = dropout 128 | self.out_channels = out_channels or channels 129 | self.use_conv = use_conv 130 | self.use_checkpoint = use_checkpoint 131 | self.use_scale_shift_norm = use_scale_shift_norm 132 | 133 | self.in_layers = nn.Sequential( 134 | normalization(channels), 135 | activation, 136 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 137 | ) 138 | self.emb_layers = nn.Sequential( 139 | activation, 140 | linear( 141 | emb_channels, 142 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 143 | ), 144 | ) 145 | self.out_layers = nn.Sequential( 146 | normalization(self.out_channels), 147 | activation, 148 | nn.Dropout(p=dropout), 149 | zero_module( 150 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 151 | ), 152 | ) 153 | 154 | if self.out_channels == channels: 155 | self.skip_connection = nn.Identity() 156 | elif use_conv: 157 | self.skip_connection = conv_nd( 158 | dims, channels, self.out_channels, 3, padding=1 159 | ) 160 | else: 161 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 162 | 163 | def forward(self, x, emb): 164 | """ 165 | Apply the block to a Tensor, conditioned on a timestep embedding. 
166 | :param x: an [N x C x ...] Tensor of features. 167 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 168 | :return: an [N x C x ...] Tensor of outputs. 169 | """ 170 | return checkpoint( 171 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 172 | ) 173 | 174 | def _forward(self, x, emb): 175 | h = self.in_layers(x) 176 | emb_out = self.emb_layers(emb).type(h.dtype) 177 | while len(emb_out.shape) < len(h.shape): 178 | emb_out = emb_out[..., None] 179 | if self.use_scale_shift_norm: 180 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 181 | scale, shift = th.chunk(emb_out, 2, dim=1) 182 | h = out_norm(h) * (1 + scale) + shift 183 | h = out_rest(h) 184 | else: 185 | h = h + emb_out 186 | h = self.out_layers(h) 187 | return self.skip_connection(x) + h 188 | 189 | 190 | class AttentionBlock(nn.Module): 191 | """ 192 | An attention block that allows spatial positions to attend to each other. 193 | Originally ported from here, but adapted to the N-d case. 194 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 195 | """ 196 | 197 | def __init__(self, channels, num_heads=1, use_checkpoint=False): 198 | super().__init__() 199 | self.channels = channels 200 | self.num_heads = num_heads 201 | self.use_checkpoint = use_checkpoint 202 | 203 | self.norm = normalization(channels) 204 | self.qkv = conv_nd(1, channels, channels * 3, 1) 205 | self.attention = QKVAttention() 206 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 207 | 208 | def forward(self, x): 209 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 210 | 211 | def _forward(self, x): 212 | b, c, *spatial = x.shape 213 | x = x.reshape(b, c, -1) 214 | qkv = self.qkv(self.norm(x)) 215 | qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) 216 | h = self.attention(qkv) 217 | h = h.reshape(b, -1, h.shape[-1]) 218 | h = self.proj_out(h) 219 | return (x + h).reshape(b, c, *spatial) 220 | 221 | 222 | class QKVAttention(nn.Module): 223 | """ 224 | A module which performs QKV attention. 225 | """ 226 | 227 | def forward(self, qkv): 228 | """ 229 | Apply QKV attention. 230 | :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. 231 | :return: an [N x C x T] tensor after attention. 232 | """ 233 | ch = qkv.shape[1] // 3 234 | q, k, v = th.split(qkv, ch, dim=1) 235 | scale = 1 / math.sqrt(math.sqrt(ch)) 236 | weight = th.einsum( 237 | "bct,bcs->bts", q * scale, k * scale 238 | ) # More stable with f16 than dividing afterwards 239 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 240 | return th.einsum("bts,bcs->bct", weight, v) 241 | 242 | @staticmethod 243 | def count_flops(model, _x, y): 244 | """ 245 | A counter for the `thop` package to count the operations in an 246 | attention operation. 247 | Meant to be used like: 248 | macs, params = thop.profile( 249 | model, 250 | inputs=(inputs, timestamps), 251 | custom_ops={QKVAttention: QKVAttention.count_flops}, 252 | ) 253 | """ 254 | b, c, *spatial = y[0].shape 255 | num_spatial = int(np.prod(spatial)) 256 | # We perform two matmuls with the same number of ops. 257 | # The first computes the weight matrix, the second computes 258 | # the combination of the value vectors. 259 | matmul_ops = 2 * b * (num_spatial ** 2) * c 260 | model.total_ops += th.DoubleTensor([matmul_ops]) 261 | 262 | 263 | class ResUNet(nn.Module): 264 | """ 265 | The full UNet model with attention and timestep embedding. 
266 | :param in_channels: channels in the input Tensor. 267 | :param model_channels: base channel count for the model. 268 | :param out_channels: channels in the output Tensor. 269 | :param num_res_blocks: number of residual blocks per downsample. 270 | :param attention_resolutions: a collection of downsample rates at which 271 | attention will take place. May be a set, list, or tuple. 272 | For example, if this contains 4, then at 4x downsampling, attention 273 | will be used. 274 | :param dropout: the dropout probability. 275 | :param channel_mult: channel multiplier for each level of the UNet. 276 | :param conv_resample: if True, use learned convolutions for upsampling and 277 | downsampling. 278 | :param dims: determines if the signal is 1D, 2D, or 3D. 279 | :param num_classes: if specified (as an int), then this model will be 280 | class-conditional with `num_classes` classes. 281 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 282 | :param num_heads: the number of attention heads in each attention layer. 283 | """ 284 | 285 | def __init__( 286 | self, 287 | in_channels, 288 | model_channels, 289 | out_channels, 290 | num_res_blocks, 291 | attention_resolutions, 292 | dropout=0, 293 | channel_mult=(1, 2, 4, 8), 294 | conv_resample=True, 295 | dims=2, 296 | num_classes=None, 297 | use_checkpoint=False, 298 | num_heads=1, 299 | num_heads_upsample=-1, 300 | use_scale_shift_norm=False, 301 | activation = None, 302 | ): 303 | super().__init__() 304 | 305 | self.activation = activation if activation is not None else SiLU() 306 | if num_heads_upsample == -1: 307 | num_heads_upsample = num_heads 308 | 309 | self.in_channels = in_channels 310 | self.model_channels = model_channels 311 | self.out_channels = out_channels 312 | self.num_res_blocks = num_res_blocks 313 | self.attention_resolutions = attention_resolutions 314 | self.dropout = dropout 315 | self.channel_mult = channel_mult 316 | self.conv_resample = conv_resample 317 | self.num_classes = num_classes 318 | self.use_checkpoint = use_checkpoint 319 | self.num_heads = num_heads 320 | self.num_heads_upsample = num_heads_upsample 321 | time_embed_dim = model_channels * 4 322 | self.time_embed = nn.Sequential( 323 | linear(model_channels, time_embed_dim), 324 | self.activation, 325 | linear(time_embed_dim, time_embed_dim), 326 | ) 327 | 328 | if self.num_classes is not None: 329 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 330 | 331 | self.input_blocks = nn.ModuleList( 332 | [ 333 | TimestepEmbedSequential( 334 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 335 | ) 336 | ] 337 | ) 338 | input_block_chans = [model_channels] 339 | ch = model_channels 340 | ds = 1 341 | for level, mult in enumerate(channel_mult): 342 | for _ in range(num_res_blocks): 343 | layers = [ 344 | ResBlock( 345 | ch, 346 | time_embed_dim, 347 | dropout, 348 | out_channels=mult * model_channels, 349 | dims=dims, 350 | use_checkpoint=use_checkpoint, 351 | use_scale_shift_norm=use_scale_shift_norm, 352 | activation = self.activation 353 | ) 354 | ] 355 | ch = mult * model_channels 356 | if ds in attention_resolutions: 357 | layers.append( 358 | AttentionBlock( 359 | ch, use_checkpoint=use_checkpoint, num_heads=num_heads 360 | ) 361 | ) 362 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 363 | input_block_chans.append(ch) 364 | if level != len(channel_mult) - 1: 365 | self.input_blocks.append( 366 | TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)) 367 | ) 368 | 
input_block_chans.append(ch)
369 |                 ds *= 2
370 | 
371 |         self.middle_block = TimestepEmbedSequential(
372 |             ResBlock(
373 |                 ch,
374 |                 time_embed_dim,
375 |                 dropout,
376 |                 dims=dims,
377 |                 use_checkpoint=use_checkpoint,
378 |                 use_scale_shift_norm=use_scale_shift_norm,
379 |                 activation = self.activation
380 |             ),
381 |             AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads),
382 |             ResBlock(
383 |                 ch,
384 |                 time_embed_dim,
385 |                 dropout,
386 |                 dims=dims,
387 |                 use_checkpoint=use_checkpoint,
388 |                 use_scale_shift_norm=use_scale_shift_norm,
389 |                 activation = self.activation
390 |             ),
391 |         )
392 | 
393 |         self.output_blocks = nn.ModuleList([])
394 |         for level, mult in list(enumerate(channel_mult))[::-1]:
395 |             for i in range(num_res_blocks + 1):
396 |                 layers = [
397 |                     ResBlock(
398 |                         ch + input_block_chans.pop(),
399 |                         time_embed_dim,
400 |                         dropout,
401 |                         out_channels=model_channels * mult,
402 |                         dims=dims,
403 |                         use_checkpoint=use_checkpoint,
404 |                         use_scale_shift_norm=use_scale_shift_norm,
405 |                         activation = self.activation
406 |                     )
407 |                 ]
408 |                 ch = model_channels * mult
409 |                 if ds in attention_resolutions:
410 |                     layers.append(
411 |                         AttentionBlock(
412 |                             ch,
413 |                             use_checkpoint=use_checkpoint,
414 |                             num_heads=num_heads_upsample,
415 |                         )
416 |                     )
417 |                 if level and i == num_res_blocks:
418 |                     layers.append(Upsample(ch, conv_resample, dims=dims))
419 |                     ds //= 2
420 |                 self.output_blocks.append(TimestepEmbedSequential(*layers))
421 | 
422 |         self.out = nn.Sequential(
423 |             normalization(ch),
424 |             self.activation,
425 |             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
426 |         )
427 | 
428 |     def convert_to_fp16(self):
429 |         """
430 |         Convert the torso of the model to float16.
431 |         """
432 |         self.input_blocks.apply(convert_module_to_f16)
433 |         self.middle_block.apply(convert_module_to_f16)
434 |         self.output_blocks.apply(convert_module_to_f16)
435 | 
436 |     def convert_to_fp32(self):
437 |         """
438 |         Convert the torso of the model to float32.
439 |         """
440 |         self.input_blocks.apply(convert_module_to_f32)
441 |         self.middle_block.apply(convert_module_to_f32)
442 |         self.output_blocks.apply(convert_module_to_f32)
443 | 
444 |     @property
445 |     def inner_dtype(self):
446 |         """
447 |         Get the dtype used by the torso of the model.
448 |         """
449 |         return th.float32  # fixed to fp32; the original derived it from the parameters:
450 |         # return next(self.input_blocks.parameters()).dtype
451 | 
452 |     def forward(self, x, timesteps, y=None, low_cond = None):
453 |         """
454 |         Apply the model to an input batch.
455 |         :param x: an [N x C x ...] Tensor of inputs.
456 |         :param timesteps: a 1-D batch of timesteps.
457 |         :param y: an [N] Tensor of labels, if class-conditional.
458 |         :param low_cond: an [N x C x ...] Tensor of conditioning inputs, concatenated to x along the channel dimension.
459 |         :return: an [N x C x ...] Tensor of outputs.
460 | """ 461 | 462 | ## concat the condition 463 | if low_cond is not None: 464 | x = th.cat((x, low_cond), dim = 1) 465 | 466 | assert (y is not None) == ( 467 | self.num_classes is not None 468 | ), "must specify y if and only if the model is class-conditional" 469 | 470 | hs = [] 471 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 472 | 473 | if self.num_classes is not None: 474 | assert y.shape == (x.shape[0],) 475 | emb = emb + self.label_emb(y) 476 | 477 | h = x.type(self.inner_dtype) 478 | for module in self.input_blocks: 479 | h = module(h, emb) 480 | hs.append(h) 481 | h = self.middle_block(h, emb) 482 | for module in self.output_blocks: 483 | 484 | # handling for non-even inputs 485 | # h = F.interpolate(h, size= hs[-1].size()[-3:], mode='trilinear') 486 | if hs[-1].size(-1) < h.size(-1): 487 | h = h[..., :-1] 488 | if hs[-1].size(-2) < h.size(-2): 489 | h = h[..., :-1, :] 490 | if hs[-1].size(-3) < h.size(-3): 491 | h = h[..., :-1, :, :] 492 | 493 | cat_in = th.cat([h, hs.pop()], dim=1) 494 | h = module(cat_in, emb) 495 | h = h.type(x.dtype) 496 | return self.out(h) 497 | 498 | def get_feature_vectors(self, x, timesteps, y=None): 499 | """ 500 | Apply the model and return all of the intermediate tensors. 501 | :param x: an [N x C x ...] Tensor of inputs. 502 | :param timesteps: a 1-D batch of timesteps. 503 | :param y: an [N] Tensor of labels, if class-conditional. 504 | :return: a dict with the following keys: 505 | - 'down': a list of hidden state tensors from downsampling. 506 | - 'middle': the tensor of the output of the lowest-resolution 507 | block in the model. 508 | - 'up': a list of hidden state tensors from upsampling. 509 | """ 510 | hs = [] 511 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 512 | if self.num_classes is not None: 513 | assert y.shape == (x.shape[0],) 514 | emb = emb + self.label_emb(y) 515 | result = dict(down=[], up=[]) 516 | h = x.type(self.inner_dtype) 517 | for module in self.input_blocks: 518 | h = module(h, emb) 519 | hs.append(h) 520 | result["down"].append(h.type(x.dtype)) 521 | h = self.middle_block(h, emb) 522 | result["middle"] = h.type(x.dtype) 523 | for module in self.output_blocks: 524 | cat_in = th.cat([h, hs.pop()], dim=1) 525 | h = module(cat_in, emb) 526 | result["up"].append(h.type(x.dtype)) 527 | return result 528 | 529 | 530 | 531 | class ControlledUNet(ResUNet): 532 | def forward(self, x, timesteps=None, control=None, only_mid_control=False, **kwargs): 533 | hs = [] 534 | t_emb = timestep_embedding(timesteps, self.model_channels) 535 | emb = self.time_embed(t_emb) 536 | 537 | h = x.type(self.inner_dtype) 538 | for module in self.input_blocks: 539 | h = module(h, emb) 540 | hs.append(h) 541 | h = self.middle_block(h, emb) 542 | 543 | if control is not None: 544 | h += control.pop() 545 | 546 | for i, module in enumerate(self.output_blocks): 547 | if hs[-1].size(-1) < h.size(-1): 548 | h = h[..., :-1] 549 | if hs[-1].size(-2) < h.size(-2): 550 | h = h[..., :-1, :] 551 | if hs[-1].size(-3) < h.size(-3): 552 | h = h[..., :-1, :, :] 553 | 554 | if only_mid_control or control is None: 555 | h = torch.cat([h, hs.pop()], dim=1) 556 | else: 557 | h = torch.cat([h, hs.pop() + control.pop()], dim=1) 558 | h = module(h, emb) 559 | 560 | h = h.type(x.dtype) 561 | return self.out(h) 562 | 563 | --------------------------------------------------------------------------------