├── .gitignore ├── README.md ├── engine_ae.py ├── main_ae.py ├── models_ae.py ├── sample_class_cond.py └── util ├── datasets.py ├── lr_decay.py ├── lr_sched.py ├── misc.py └── shapenet.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 main_ae.py --accum_iter=1 --model Diffusion --output_dir output --log_dir output --num_workers 16 --point_cloud_size 2048 --batch_size 16 --epochs 1000 --warmup_epochs 1 --data_path /data/path --blr 5e-5 2 | -------------------------------------------------------------------------------- /engine_ae.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # MAE: https://github.com/facebookresearch/mae 4 | # DeiT: https://github.com/facebookresearch/deit 5 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | import sys 10 | from typing import Iterable 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | import util.misc as misc 16 | import util.lr_sched as lr_sched 17 | 18 | 19 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 20 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 21 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 22 | log_writer=None, args=None): 23 | model.train(True) 24 | metric_logger = misc.MetricLogger(delimiter=" ") 25 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 26 | header = 'Epoch: [{}]'.format(epoch) 27 | print_freq = 20 28 | 29 | accum_iter = args.accum_iter 30 | 31 | optimizer.zero_grad() 32 | 33 | if log_writer is not None: 34 | print('log_dir: {}'.format(log_writer.log_dir)) 35 | 36 | for data_iter_step, (query_points, query_values, context_points, context_values, surface, categories) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 37 | 38 | # we use a per iteration (instead of per epoch) lr scheduler 39 | if data_iter_step % accum_iter == 0: 40 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 41 | 42 | query_points = query_points.to(device, non_blocking=True) 43 | query_values = query_values.to(device, non_blocking=True) 44 | context_points = context_points.to(device, non_blocking=True) 45 | context_values = context_values.to(device, non_blocking=True) 46 | surface = surface.to(device, non_blocking=True) 47 | categories = categories.to(device, non_blocking=True) 48 | 49 | 50 | with torch.cuda.amp.autocast(enabled=False): 51 | 52 | loss_diff = model( 53 | context_points, 54 | context_values, 55 | query_points, 56 | query_values, 57 | surface, 58 | categories, 59 | ) 60 | 61 | loss = loss_diff# 62 | 63 | loss_value = loss.item() 64 | 65 | if not math.isfinite(loss_value): 66 | print("Loss is {}, stopping training".format(loss_value)) 67 | sys.exit(1) 68 | 69 | loss /= accum_iter 70 | loss_scaler(loss, optimizer, clip_grad=max_norm, 71 | parameters=model.parameters(), create_graph=False, 72 | update_grad=(data_iter_step + 1) % accum_iter == 0) 73 | if (data_iter_step + 1) % accum_iter == 0: 74 | optimizer.zero_grad() 75 | 76 | torch.cuda.synchronize() 77 | 78 | metric_logger.update(loss=loss_value) 79 | # metric_logger.update(loss_recon=loss_recon.item()) 80 | # metric_logger.update(weight=weight.item()) 81 | # metric_logger.update(loss_bce_vol=loss_bce_vol.item()) 82 | # metric_logger.update(loss_bce_near=loss_bce_near.item()) 83 | 84 | min_lr = 10. 85 | max_lr = 0. 86 | for group in optimizer.param_groups: 87 | min_lr = min(min_lr, group["lr"]) 88 | max_lr = max(max_lr, group["lr"]) 89 | 90 | metric_logger.update(lr=max_lr) 91 | 92 | loss_value_reduce = misc.all_reduce_mean(loss_value) 93 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 94 | """ We use epoch_1000x as the x-axis in tensorboard. 95 | This calibrates different curves when batch size changes. 96 | """ 97 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 98 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 99 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 100 | 101 | # gather the stats from all processes 102 | metric_logger.synchronize_between_processes() 103 | print("Averaged stats:", metric_logger) 104 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 105 | 106 | 107 | @torch.no_grad() 108 | def evaluate(data_loader, model, device): 109 | criterion = torch.nn.BCEWithLogitsLoss() 110 | 111 | metric_logger = misc.MetricLogger(delimiter=" ") 112 | header = 'Test:' 113 | 114 | # switch to evaluation mode 115 | model.eval() 116 | 117 | for query_points, query_values, context_points, context_values, surface, categories in metric_logger.log_every(data_loader, 50, header): 118 | 119 | query_points = query_points.to(device, non_blocking=True) 120 | query_values = query_values.to(device, non_blocking=True) 121 | context_points = context_points.to(device, non_blocking=True) 122 | context_values = context_values.to(device, non_blocking=True) 123 | surface = surface.to(device, non_blocking=True) 124 | categories = categories.to(device, non_blocking=True) 125 | 126 | # compute output 127 | with torch.cuda.amp.autocast(enabled=False): 128 | 129 | loss_diff, loss_recon, weight, loss_bce_vol, loss_bce_near = model( 130 | context_points, 131 | context_values, 132 | query_points, 133 | query_values, 134 | surface, 135 | categories, 136 | ) 137 | 138 | # if loss_kl is not None: 139 | # loss = loss_vol + 0.1 * loss_near + kl_weight * loss_kl 140 | # else: 141 | # loss = loss_vol + 0.1 * loss_near 142 | loss = loss_diff# 143 | 144 | 145 | batch_size = surface.shape[0] 146 | metric_logger.update(loss=loss.item()) 147 | metric_logger.update(loss_recon=loss_recon.item()) 148 | metric_logger.update(weight=weight.item()) 149 | metric_logger.update(loss_bce_vol=loss_bce_vol.item()) 150 | metric_logger.update(loss_bce_near=loss_bce_near.item()) 151 | 152 | # metric_logger.meters['iou'].update(iou.item(), n=batch_size) 153 | 154 | # if loss_kl is not None: 155 | # metric_logger.update(loss_kl=loss_kl.item()) 156 | 157 | # gather the stats from all processes 158 | metric_logger.synchronize_between_processes() 159 | print('* loss {losses.global_avg:.3f}' 160 | .format(losses=metric_logger.loss)) 161 | 162 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /main_ae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import numpy as np 5 | import os 6 | import time 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | torch.set_num_threads(8) 14 | 15 | import util.lr_decay as lrd 16 | import util.misc as misc 17 | from util.datasets import build_shape_surface_occupancy_dataset 18 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 19 | 20 | import models_ae as models_ae 21 | 22 | from engine_ae import train_one_epoch, evaluate 23 | 24 | def get_args_parser(): 25 | parser = argparse.ArgumentParser('Autoencoder', add_help=False) 26 | parser.add_argument('--batch_size', default=64, type=int, 27 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 28 | parser.add_argument('--epochs', default=800, type=int) 29 | parser.add_argument('--accum_iter', default=1, type=int, 30 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 31 | 32 | # Model parameters 33 | parser.add_argument('--model', default='ae_blob64', type=str, metavar='MODEL', 34 | help='Name of model to train') 35 | 36 | parser.add_argument('--point_cloud_size', default=2048, type=int, 37 | help='input size') 38 | 39 | # Optimizer parameters 40 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 41 | help='Clip gradient norm (default: None, no clipping)') 42 | parser.add_argument('--weight_decay', type=float, default=0.05, 43 | help='weight decay (default: 0.05)') 44 | 45 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 46 | help='learning rate (absolute lr)') 47 | parser.add_argument('--blr', type=float, default=1e-4, metavar='LR', 48 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 49 | parser.add_argument('--layer_decay', type=float, default=0.75, 50 | help='layer-wise lr decay from ELECTRA/BEiT') 51 | 52 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 53 | help='lower lr bound for cyclic schedulers that hit 0') 54 | 55 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 56 | help='epochs to warmup LR') 57 | 58 | 59 | # Dataset parameters 60 | parser.add_argument('--data_path', default='/ibex/scratch/projects/c2168/diffusion-shapes/datasets', type=str, 61 | help='dataset path') 62 | 63 | parser.add_argument('--output_dir', default='./output/', 64 | help='path where to save, empty for no saving') 65 | parser.add_argument('--log_dir', default='./output/', 66 | help='path where to tensorboard log') 67 | parser.add_argument('--device', default='cuda', 68 | help='device to use for training / testing') 69 | parser.add_argument('--seed', default=0, type=int) 70 | parser.add_argument('--resume', default='', 71 | help='resume from checkpoint') 72 | 73 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 74 | help='start epoch') 75 | parser.add_argument('--eval', action='store_true', 76 | help='Perform evaluation only') 77 | parser.add_argument('--dist_eval', action='store_true', default=False, 78 | help='Enabling distributed evaluation (recommended during training for faster monitor') 79 | parser.add_argument('--num_workers', default=60, type=int) 80 | parser.add_argument('--pin_mem', action='store_true', 81 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 82 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 83 | parser.set_defaults(pin_mem=False) 84 | 85 | # distributed training parameters 86 | parser.add_argument('--world_size', default=1, type=int, 87 | help='number of distributed processes') 88 | parser.add_argument('--local_rank', default=-1, type=int) 89 | parser.add_argument('--dist_on_itp', action='store_true') 90 | parser.add_argument('--dist_url', default='env://', 91 | help='url used to set up distributed training') 92 | 93 | return parser 94 | 95 | def main(args): 96 | misc.init_distributed_mode(args) 97 | 98 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 99 | print("{}".format(args).replace(', ', ',\n')) 100 | 101 | device = torch.device(args.device) 102 | 103 | # fix the seed for reproducibility 104 | seed = args.seed + misc.get_rank() 105 | torch.manual_seed(seed) 106 | np.random.seed(seed) 107 | 108 | cudnn.benchmark = True 109 | 110 | dataset_train = build_shape_surface_occupancy_dataset('train', args=args) 111 | dataset_val = build_shape_surface_occupancy_dataset('val', args=args) 112 | 113 | if True: # args.distributed: 114 | num_tasks = misc.get_world_size() 115 | global_rank = misc.get_rank() 116 | sampler_train = torch.utils.data.DistributedSampler( 117 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 118 | ) 119 | print("Sampler_train = %s" % str(sampler_train)) 120 | if args.dist_eval: 121 | if len(dataset_val) % num_tasks != 0: 122 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 123 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 124 | 'equal num of samples per-process.') 125 | sampler_val = torch.utils.data.DistributedSampler( 126 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 127 | else: 128 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 129 | else: 130 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 131 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 132 | 133 | if global_rank == 0 and args.log_dir is not None and not args.eval: 134 | os.makedirs(args.log_dir, exist_ok=True) 135 | log_writer = SummaryWriter(log_dir=args.log_dir) 136 | else: 137 | log_writer = None 138 | 139 | data_loader_train = torch.utils.data.DataLoader( 140 | dataset_train, sampler=sampler_train, 141 | batch_size=args.batch_size, 142 | num_workers=args.num_workers, 143 | pin_memory=args.pin_mem, 144 | drop_last=True, 145 | prefetch_factor=2, 146 | ) 147 | 148 | data_loader_val = torch.utils.data.DataLoader( 149 | dataset_val, sampler=sampler_val, 150 | # batch_size=args.batch_size, 151 | batch_size=1, 152 | # num_workers=args.num_workers, 153 | num_workers=1, 154 | pin_memory=args.pin_mem, 155 | drop_last=False 156 | ) 157 | 158 | model = models_ae.__dict__[args.model](N=args.point_cloud_size) 159 | 160 | model.to(device) 161 | 162 | model_without_ddp = model 163 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 164 | 165 | print("Model = %s" % str(model_without_ddp)) 166 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 167 | 168 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 169 | 170 | if args.lr is None: # only base_lr is specified 171 | args.lr = args.blr * eff_batch_size / 256 172 | 173 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 174 | print("actual lr: %.2e" % args.lr) 175 | 176 | print("accumulate grad iterations: %d" % args.accum_iter) 177 | print("effective batch size: %d" % eff_batch_size) 178 | 179 | if args.distributed: 180 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 181 | model_without_ddp = model.module 182 | 183 | # # build optimizer with layer-wise lr decay (lrd) 184 | # param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 185 | # no_weight_decay_list=model_without_ddp.no_weight_decay(), 186 | # layer_decay=args.layer_decay 187 | # ) 188 | optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr) 189 | loss_scaler = NativeScaler() 190 | 191 | criterion = torch.nn.BCEWithLogitsLoss() 192 | 193 | print("criterion = %s" % str(criterion)) 194 | 195 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 196 | 197 | if args.eval: 198 | test_stats = evaluate(data_loader_val, model, device) 199 | print(f"iou of the network on the {len(dataset_val)} test images: {test_stats['iou']:.3f}") 200 | exit(0) 201 | 202 | print(f"Start training for {args.epochs} epochs") 203 | start_time = time.time() 204 | max_iou = 0.0 205 | for epoch in range(args.start_epoch, args.epochs): 206 | if args.distributed: 207 | data_loader_train.sampler.set_epoch(epoch) 208 | train_stats = train_one_epoch( 209 | model, criterion, data_loader_train, 210 | optimizer, device, epoch, loss_scaler, 211 | args.clip_grad, 212 | log_writer=log_writer, 213 | args=args 214 | ) 215 | if args.output_dir and (epoch % 5 == 0 or epoch + 1 == args.epochs): 216 | misc.save_model( 217 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 218 | loss_scaler=loss_scaler, epoch=epoch) 219 | 220 | if epoch % 1 == 0 or epoch + 1 == args.epochs: 221 | 222 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 223 | 'epoch': epoch, 224 | 'n_parameters': n_parameters} 225 | else: 226 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 227 | 'epoch': epoch, 228 | 'n_parameters': n_parameters} 229 | 230 | if args.output_dir and misc.is_main_process(): 231 | if log_writer is not None: 232 | log_writer.flush() 233 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 234 | f.write(json.dumps(log_stats) + "\n") 235 | 236 | total_time = time.time() - start_time 237 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 238 | print('Training time {}'.format(total_time_str)) 239 | 240 | if __name__ == '__main__': 241 | args = get_args_parser() 242 | args = args.parse_args() 243 | if args.output_dir: 244 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 245 | main(args) -------------------------------------------------------------------------------- /models_ae.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | import numpy as np 4 | 5 | import math 6 | 7 | import torch 8 | from torch import nn, einsum 9 | import torch.nn.functional as F 10 | 11 | from einops import rearrange, repeat 12 | 13 | from torch_cluster import fps 14 | 15 | from timm.models.layers import DropPath 16 | 17 | def cdist2(x, y): 18 | # |x_i - y_j|_2^2 = = + - 2* 19 | x_sq_norm = x.pow(2).sum(dim=-1, keepdim=True) 20 | y_sq_norm = y.pow(2).sum(dim=-1) 21 | x_dot_y = x @ y.transpose(-1,-2) 22 | sq_dist = x_sq_norm + y_sq_norm.unsqueeze(dim=-2) - 2*x_dot_y 23 | # For numerical issues 24 | sq_dist.clamp_(min=0.0) 25 | return torch.sqrt(sq_dist) 26 | 27 | def exists(val): 28 | return val is not None 29 | 30 | def default(val, d): 31 | return val if exists(val) else d 32 | 33 | def cache_fn(f): 34 | cache = None 35 | @wraps(f) 36 | def cached_fn(*args, _cache = True, **kwargs): 37 | if not _cache: 38 | return f(*args, **kwargs) 39 | nonlocal cache 40 | if cache is not None: 41 | return cache 42 | cache = f(*args, **kwargs) 43 | return cache 44 | return cached_fn 45 | 46 | class PreNorm(nn.Module): 47 | def __init__(self, dim, fn, context_dim = None, modulated=False): 48 | super().__init__() 49 | self.fn = fn 50 | self.norm = nn.LayerNorm(dim) 51 | self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None 52 | 53 | self.modulated = modulated 54 | if self.modulated: 55 | self.gamma = nn.Linear(dim, dim, bias=False) 56 | self.beta = nn.Linear(dim, dim, bias=False) 57 | 58 | def forward(self, x, **kwargs): 59 | x = self.norm(x) 60 | 61 | if self.modulated: 62 | label = kwargs.pop('label') 63 | gamma = self.gamma(label) # b 1 c 64 | beta = self.beta(label) # b 1 c 65 | # print('layernorm', x.shape, beta.shape) 66 | x = gamma * x + beta 67 | # print('layernorm', x.shape, beta.shape) 68 | 69 | if exists(self.norm_context): 70 | context = kwargs['context'] 71 | normed_context = self.norm_context(context) 72 | kwargs.update(context = normed_context) 73 | 74 | return self.fn(x, **kwargs) 75 | 76 | class GEGLU(nn.Module): 77 | def forward(self, x): 78 | x, gates = x.chunk(2, dim = -1) 79 | return x * F.gelu(gates) 80 | 81 | class FeedForward(nn.Module): 82 | def __init__(self, dim, mult = 4, drop_path_rate = 0.0): 83 | super().__init__() 84 | self.net = nn.Sequential( 85 | nn.Linear(dim, dim * mult * 2), 86 | GEGLU(), 87 | nn.Linear(dim * mult, dim) 88 | ) 89 | 90 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() 91 | 92 | def forward(self, x): 93 | return self.drop_path(self.net(x)) 94 | 95 | class Attention(nn.Module): 96 | def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0): 97 | super().__init__() 98 | inner_dim = dim_head * heads 99 | context_dim = default(context_dim, query_dim) 100 | self.scale = dim_head ** -0.5 101 | self.heads = heads 102 | 103 | self.to_q = nn.Linear(query_dim, inner_dim, bias = False) 104 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) 105 | self.to_out = nn.Linear(inner_dim, query_dim) 106 | 107 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() 108 | 109 | def forward(self, x, context = None, mask = None, attn_mask=None): 110 | h = self.heads 111 | 112 | q = self.to_q(x) 113 | context = default(context, x) 114 | k, v = self.to_kv(context).chunk(2, dim = -1) 115 | # print(q.shape, k.shape, v.shape) 116 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v)) 117 | # print(q.shape, k.shape, v.shape) 118 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 119 | 120 | if exists(mask): 121 | mask = rearrange(mask, 'b ... -> b (...)') 122 | max_neg_value = -torch.finfo(sim.dtype).max 123 | mask = repeat(mask, 'b j -> (b h) () j', h = h) 124 | sim.masked_fill_(~mask, max_neg_value) 125 | 126 | if exists(attn_mask): 127 | # attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j') 128 | attn_mask = repeat(attn_mask, 'i j -> (b h) i j', b=x.shape[0], h = h) 129 | # print(attn_mask) 130 | sim.masked_fill_(~attn_mask, -torch.finfo(sim.dtype).max) 131 | 132 | 133 | # attention, what we cannot get enough of 134 | attn = sim.softmax(dim = -1) 135 | 136 | out = einsum('b i j, b j d -> b i d', attn, v) 137 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h) 138 | return self.drop_path(self.to_out(out)) 139 | 140 | 141 | class PointEmbed(nn.Module): 142 | def __init__(self, hidden_dim=48, dim=128): 143 | super().__init__() 144 | 145 | assert hidden_dim % 6 == 0 146 | 147 | self.embedding_dim = hidden_dim 148 | e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi 149 | e = torch.stack([ 150 | torch.cat([e, torch.zeros(self.embedding_dim // 6), 151 | torch.zeros(self.embedding_dim // 6)]), 152 | torch.cat([torch.zeros(self.embedding_dim // 6), e, 153 | torch.zeros(self.embedding_dim // 6)]), 154 | torch.cat([torch.zeros(self.embedding_dim // 6), 155 | torch.zeros(self.embedding_dim // 6), e]), 156 | ]) 157 | self.register_buffer('basis', e) # 3 x 16 158 | 159 | self.mlp = nn.Linear(self.embedding_dim+3, dim) 160 | 161 | @staticmethod 162 | def embed(input, basis): 163 | projections = torch.einsum( 164 | 'bnd,de->bne', input, basis) 165 | embeddings = torch.cat([projections.sin(), projections.cos()], dim=2) 166 | return embeddings 167 | 168 | def forward(self, input): 169 | # input: B x N x 3 170 | embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C 171 | return embed 172 | 173 | 174 | class DiagonalGaussianDistribution(object): 175 | def __init__(self, mean, logvar, deterministic=False): 176 | self.mean = mean 177 | self.logvar = logvar 178 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 179 | self.deterministic = deterministic 180 | self.std = torch.exp(0.5 * self.logvar) 181 | self.var = torch.exp(self.logvar) 182 | if self.deterministic: 183 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.mean.device) 184 | 185 | def sample(self): 186 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.mean.device) 187 | return x 188 | 189 | def kl(self, other=None): 190 | if self.deterministic: 191 | return torch.Tensor([0.]) 192 | else: 193 | if other is None: 194 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 195 | + self.var - 1.0 - self.logvar, 196 | dim=[1, 2]) 197 | else: 198 | return 0.5 * torch.mean( 199 | torch.pow(self.mean - other.mean, 2) / other.var 200 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 201 | dim=[1, 2, 3]) 202 | 203 | def nll(self, sample, dims=[1,2,3]): 204 | if self.deterministic: 205 | return torch.Tensor([0.]) 206 | logtwopi = np.log(2.0 * np.pi) 207 | return 0.5 * torch.sum( 208 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 209 | dim=dims) 210 | 211 | def mode(self): 212 | return self.mean 213 | 214 | class FourierEmbedding(torch.nn.Module): 215 | def __init__(self, num_channels, scale=16): 216 | super().__init__() 217 | self.register_buffer('freqs', torch.randn(num_channels // 2) * scale) 218 | 219 | def forward(self, x): 220 | # print(x.shape, self.freqs.shape) 221 | # x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) 222 | x = torch.einsum('..., n->... n', x, 2 * np.pi * self.freqs) 223 | x = torch.cat([x.cos(), x.sin()], dim=-1) 224 | return x 225 | 226 | class Network(nn.Module): 227 | def __init__( 228 | self, 229 | *, 230 | dim=512, 231 | depth=4, 232 | heads=8, 233 | dim_head=64, 234 | function_dim=1, 235 | ): 236 | super().__init__() 237 | 238 | heads = dim // dim_head 239 | 240 | self.map_noise = FourierEmbedding(num_channels=dim) 241 | 242 | self.layers = nn.ModuleList([]) 243 | 244 | get_latent_attn = lambda: PreNorm(dim, Attention(dim, dim, heads = heads, dim_head = dim_head), context_dim = dim) 245 | get_latent_ff = lambda: PreNorm(dim, FeedForward(dim)) 246 | get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff)) 247 | 248 | self.depth = depth 249 | 250 | for _ in range(depth): 251 | self.layers.append(nn.ModuleList([ 252 | PreNorm(dim, Attention(dim, dim, heads = heads, dim_head = dim_head), context_dim = dim), 253 | PreNorm(dim, FeedForward(dim)), 254 | nn.ModuleList([ 255 | nn.ModuleList([ 256 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head), modulated=True), 257 | PreNorm(dim, FeedForward(dim), modulated=True), 258 | ]) for _ in range(1) 259 | ]) 260 | ])) 261 | 262 | 263 | self.cross_attend_blocks = nn.ModuleList([ 264 | PreNorm(dim, Attention(dim, dim, heads = 1, dim_head = dim), context_dim = dim), 265 | None 266 | ]) 267 | 268 | self.to_output = PreNorm(dim, nn.Linear(dim, 1), modulated=True) 269 | 270 | 271 | self.latent = nn.Embedding(512, dim) 272 | 273 | self.point_embed = PointEmbed(dim=dim) 274 | 275 | self.value_embed = nn.Linear(function_dim, dim) 276 | 277 | self.cond_token = nn.Embedding(1, dim) 278 | 279 | def forward_features(self, context_points, context_values, alpha_embeddings, cond): 280 | # context_points: b n 3 281 | # context_values: b n 1 282 | # alpha: b 283 | 284 | context = self.point_embed(context_points) + self.value_embed(context_values)#.squeeze(-1)) 285 | 286 | context = torch.split(context, context.shape[1]//self.depth, dim=1) 287 | 288 | x = repeat(self.latent.weight, 'n c -> b n c', b=alpha_embeddings.shape[0]) 289 | 290 | for i, (cross_attn, cross_ff, layers) in enumerate(self.layers): 291 | 292 | if cond is None: 293 | c = context[i] 294 | else: 295 | c = torch.cat([context[i], cond + self.cond_token.weight[None]], dim=1) 296 | 297 | x = cross_attn(x, context=c) + x 298 | x = cross_ff(x) + x 299 | 300 | 301 | for self_attn, self_ff in layers: 302 | x = self_attn(x, label=alpha_embeddings) + x 303 | x = self_ff(x, label=alpha_embeddings) + x 304 | return x 305 | 306 | def decode(self, queries, x, alpha_embeddings): 307 | queries = self.point_embed(queries) 308 | 309 | #### 310 | cross_attn, cross_ff = self.cross_attend_blocks 311 | 312 | o = cross_attn(queries, context = x, mask = None)# + queries_embeddings 313 | o = self.to_output(o, label=alpha_embeddings)#.squeeze(-1) 314 | return o 315 | 316 | def forward(self, context_points, context_values, queries, alpha, cond): 317 | 318 | alpha_embeddings = self.map_noise(alpha)[:, None] 319 | 320 | x = self.forward_features(context_points, context_values, alpha_embeddings, cond) 321 | return self.decode(queries, x, alpha_embeddings) 322 | 323 | 324 | class Diffusion(nn.Module): 325 | def __init__(self, N=0): 326 | super().__init__() 327 | 328 | self.model = Network(depth=24, dim=768) 329 | self.condition = nn.Embedding(55, 768) 330 | 331 | self.logvar_fourier = FourierEmbedding(num_channels=768) 332 | self.logvar_linear = nn.Linear(768, 1) 333 | 334 | def forward(self, context_points, context_values, query_points, query_sdf, pc, categories): 335 | B, _, _, device = *context_points.shape, context_points.device 336 | 337 | cond = self.condition(categories)[:, None] 338 | 339 | rnd_normal = torch.randn([pc.shape[0]], device=pc.device) 340 | t = (rnd_normal * 1.2 - 1.2).exp() 341 | 342 | M = 2048 343 | x_i = torch.rand(B, M, 3).to(device) * 2 - 1 344 | 345 | s_i = torch.randn(B, 1, 64, 64, 64).to(device) 346 | 347 | 348 | f_t = context_values + t[:, None] * self.init(x_i, s_i, context_points, blocks=4096) 349 | 350 | query_points = torch.cat([pc, query_points], dim=1) 351 | query_sdf = torch.cat([torch.zeros_like(pc[:, :, 0]), query_sdf], dim=1) 352 | 353 | denominator = torch.sqrt(1 + t**2) 354 | 355 | d = self.model(context_points, f_t[:, :, None] / denominator[:, None, None], query_points, t.log() / 4, cond).squeeze(-1) 356 | 357 | 358 | logvar = self.logvar_linear(self.logvar_fourier(t / denominator)) 359 | 360 | loss_recon = (d - query_sdf)**2 361 | 362 | loss = 1/ logvar.exp() * loss_recon + logvar 363 | loss = torch.sum(loss) / d.shape[0] 364 | 365 | return loss 366 | 367 | @torch.no_grad() 368 | def init(self, x_i, grid, queries, blocks=1): 369 | 370 | return F.grid_sample(grid, queries[:, :, None, None], align_corners=False).squeeze(-1).squeeze(-1).squeeze(1) 371 | 372 | @torch.no_grad() 373 | def sample(self, categories, query_points, n_steps=64): 374 | 375 | if categories is not None: 376 | cond = self.condition(categories)[:, None] 377 | else: 378 | cond = None 379 | 380 | B, device = query_points.shape[0], query_points.device 381 | 382 | sigma_max, sigma_min, rho = 80, 0.002, 7 383 | 384 | step_indices = torch.arange(n_steps, dtype=torch.float32, device=device) 385 | 386 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (n_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 387 | 388 | t_steps = torch.as_tensor(sigma_steps) 389 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 390 | 391 | # Main sampling loop. 392 | t_next = t_steps[0] 393 | 394 | context_points = torch.rand(B, 1024*48, 3, device=device) * 2 - 1 395 | 396 | M = 2048 397 | x_i = torch.rand(B, M, 3).to(device) * 2 - 1 398 | s_i = torch.randn(B, 1, 64, 64, 64).to(device) 399 | 400 | context_sdf = self.init(x_i, s_i, context_points, blocks=2048) * t_steps[0] 401 | 402 | query_sdf = self.init(x_i, s_i, query_points, blocks=2048) * t_steps[0] 403 | 404 | query_sdfs = [query_sdf.clone()] 405 | 406 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 407 | 408 | alpha_embeddings = self.model.map_noise(t_cur[None].log() / 4)[:, None].expand(B, -1, -1) 409 | 410 | x = self.model.forward_features(context_points, context_sdf[:, :, None] / (1 + t_cur**2).sqrt(), alpha_embeddings, cond) 411 | 412 | d = self.model.decode(context_points, x, alpha_embeddings).squeeze(-1) 413 | 414 | context_sdf = context_sdf + (context_sdf - d) * (t_next - t_cur) / t_cur 415 | 416 | # print(i, t_next, t_cur, t_next / t_cur) 417 | # print(context_sdf.max().item(), context_sdf.min().item(), context_sdf.mean().item(), context_sdf.std().item()) 418 | 419 | d = self.model.decode(query_points, x, alpha_embeddings).squeeze(-1) 420 | 421 | query_sdf = d 422 | return query_sdf 423 | -------------------------------------------------------------------------------- /sample_class_cond.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from pathlib import Path 3 | import util.misc as misc 4 | from util.shapenet import ShapeNet, category_ids 5 | 6 | import models_ae as models_ae 7 | 8 | import mcubes 9 | import trimesh 10 | from scipy.spatial import cKDTree as KDTree 11 | import numpy as np 12 | import torchvision.transforms as T 13 | import torch.backends.cudnn as cudnn 14 | import torch.nn.functional as F 15 | import torch 16 | import yaml 17 | import math 18 | 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--model', default='Diffusion', type=str, 23 | metavar='MODEL', help='Name of model to train') 24 | parser.add_argument( 25 | '--pth', default='output/checkpoint-130.pth', type=str) 26 | parser.add_argument('--device', default='cuda', 27 | help='device to use for training / testing') 28 | parser.add_argument('--seed', default=0, type=int) 29 | args = parser.parse_args() 30 | 31 | 32 | # import utils 33 | 34 | 35 | def main(): 36 | print(args) 37 | seed = args.seed 38 | torch.manual_seed(seed) 39 | np.random.seed(seed) 40 | 41 | cudnn.benchmark = True 42 | 43 | model = models_ae.__dict__[args.model]() 44 | device = torch.device(args.device) 45 | 46 | model.eval() 47 | model.load_state_dict(torch.load(args.pth, map_location='cpu')[ 48 | 'model'], strict=True) 49 | model.to(device) 50 | # print(model) 51 | 52 | density = 128 53 | gap = 2. / density 54 | x = np.linspace(-1, 1, density+1) 55 | y = np.linspace(-1, 1, density+1) 56 | z = np.linspace(-1, 1, density+1) 57 | xv, yv, zv = np.meshgrid(x, y, z) 58 | grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].cuda() 59 | 60 | with torch.no_grad(): 61 | for idx in range(16): 62 | 63 | categories = torch.Tensor([0] * 1).int().cuda() 64 | outputs = model.sample(categories, grid.expand(1, -1, -1), n_steps=64) 65 | 66 | output = outputs[0] 67 | volume = output.view(density+1, density+1, density+1).permute(1, 0, 2).cpu().numpy() * (-1) 68 | 69 | verts, faces = mcubes.marching_cubes(volume, 0) 70 | verts *= gap 71 | verts -= 1. 72 | m = trimesh.Trimesh(verts, faces) 73 | 74 | m.export('samples/{:03d}.obj'.format(idx)) 75 | 76 | if __name__ == '__main__': 77 | main() 78 | -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .shapenet import ShapeNet 4 | 5 | class AxisScaling(object): 6 | def __init__(self, interval=(0.75, 1.25), jitter=True): 7 | assert isinstance(interval, tuple) 8 | self.interval = interval 9 | self.jitter = jitter 10 | 11 | def __call__(self, surface, point): 12 | scaling = torch.rand(1, 3) * 0.5 + 0.75 13 | surface = surface * scaling 14 | point = point * scaling 15 | 16 | scale = (1 / torch.abs(surface).max().item()) * 0.999999 17 | surface *= scale 18 | point *= scale 19 | 20 | if self.jitter: 21 | surface += 0.005 * torch.randn_like(surface) 22 | surface.clamp_(min=-1, max=1) 23 | 24 | return surface, point 25 | 26 | 27 | def build_shape_surface_occupancy_dataset(split, args): 28 | if split == 'train': 29 | # transform = #transforms.Compose([ 30 | transform = AxisScaling((0.75, 1.25), False) 31 | # ]) 32 | transform = None 33 | return ShapeNet(args.data_path, split=split, transform=transform, sampling=True, num_samples=1024*48, return_surface=True, surface_sampling=True, pc_size=1024*8) 34 | elif split == 'val': 35 | # return ShapeNet(args.data_path, split=split, transform=None, sampling=True, num_samples=1024, return_surface=True, surface_sampling=True, pc_size=1024*16) 36 | return ShapeNet(args.data_path, split=split, transform=None, sampling=True, num_samples=1024*48, return_surface=True, surface_sampling=True, pc_size=1024*8) 37 | else: 38 | return ShapeNet(args.data_path, split=split, transform=None, sampling=True, num_samples=1024*48, return_surface=True, surface_sampling=True, pc_size=1024*8) 39 | 40 | if __name__ == '__main__': 41 | # m = ShapeNet('/home/zhanb0b/data/', 'train', transform=AxisScaling(), sampling=True, num_samples=1024, return_surface=True, surface_sampling=True) 42 | m = ShapeNet('/home/zhanb0b/data/', 'train', transform=AxisScaling(), sampling=True, num_samples=1024, return_surface=True, surface_sampling=True) 43 | p, l, s, c = m[0] 44 | print(p.shape, l.shape, s.shape, c) 45 | print(p.max(dim=0)[0], p.min(dim=0)[0]) 46 | print(p[l==1].max(axis=0)[0], p[l==1].min(axis=0)[0]) 47 | print(s.max(axis=0)[0], s.min(axis=0)[0]) -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # MAE: https://github.com/facebookresearch/mae 4 | # DeiT: https://github.com/facebookresearch/deit 5 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 6 | # -------------------------------------------------------- 7 | 8 | import json 9 | 10 | 11 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 12 | """ 13 | Parameter groups for layer-wise lr decay 14 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 15 | """ 16 | param_group_names = {} 17 | param_groups = {} 18 | 19 | num_layers = len(model.blocks) + 1 20 | 21 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 22 | 23 | for n, p in model.named_parameters(): 24 | if not p.requires_grad: 25 | continue 26 | 27 | # no decay: all 1D parameters and model specific ones 28 | if p.ndim == 1 or n in no_weight_decay_list: 29 | g_decay = "no_decay" 30 | this_decay = 0. 31 | else: 32 | g_decay = "decay" 33 | this_decay = weight_decay 34 | 35 | layer_id = get_layer_id_for_vit(n, num_layers) 36 | group_name = "layer_%d_%s" % (layer_id, g_decay) 37 | 38 | if group_name not in param_group_names: 39 | this_scale = layer_scales[layer_id] 40 | 41 | param_group_names[group_name] = { 42 | "lr_scale": this_scale, 43 | "weight_decay": this_decay, 44 | "params": [], 45 | } 46 | param_groups[group_name] = { 47 | "lr_scale": this_scale, 48 | "weight_decay": this_decay, 49 | "params": [], 50 | } 51 | 52 | param_group_names[group_name]["params"].append(n) 53 | param_groups[group_name]["params"].append(p) 54 | 55 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 56 | 57 | return list(param_groups.values()) 58 | 59 | 60 | def get_layer_id_for_vit(name, num_layers): 61 | """ 62 | Assign a parameter with its layer id 63 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 64 | """ 65 | if name in ['cls_token', 'pos_embed']: 66 | return 0 67 | elif name.startswith('patch_embed'): 68 | return 0 69 | elif name.startswith('blocks'): 70 | return int(name.split('.')[1]) + 1 71 | else: 72 | return num_layers -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # MAE: https://github.com/facebookresearch/mae 4 | # DeiT: https://github.com/facebookresearch/deit 5 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 6 | # -------------------------------------------------------- 7 | 8 | import builtins 9 | import datetime 10 | import os 11 | import time 12 | from collections import defaultdict, deque 13 | from pathlib import Path 14 | 15 | import torch 16 | import torch.distributed as dist 17 | if torch.__version__[0] == '2': 18 | from torch import inf 19 | else: 20 | from torch._six import inf 21 | 22 | class SmoothedValue(object): 23 | """Track a series of values and provide access to smoothed values over a 24 | window or the global series average. 25 | """ 26 | 27 | def __init__(self, window_size=20, fmt=None): 28 | if fmt is None: 29 | fmt = "{median:.4f} ({global_avg:.4f})" 30 | self.deque = deque(maxlen=window_size) 31 | self.total = 0.0 32 | self.count = 0 33 | self.fmt = fmt 34 | 35 | def update(self, value, n=1): 36 | self.deque.append(value) 37 | self.count += n 38 | self.total += value * n 39 | 40 | def synchronize_between_processes(self): 41 | """ 42 | Warning: does not synchronize the deque! 43 | """ 44 | if not is_dist_avail_and_initialized(): 45 | return 46 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 47 | dist.barrier() 48 | dist.all_reduce(t) 49 | t = t.tolist() 50 | self.count = int(t[0]) 51 | self.total = t[1] 52 | 53 | @property 54 | def median(self): 55 | d = torch.tensor(list(self.deque)) 56 | return d.median().item() 57 | 58 | @property 59 | def avg(self): 60 | d = torch.tensor(list(self.deque), dtype=torch.float32) 61 | return d.mean().item() 62 | 63 | @property 64 | def global_avg(self): 65 | return self.total / self.count 66 | 67 | @property 68 | def max(self): 69 | return max(self.deque) 70 | 71 | @property 72 | def value(self): 73 | return self.deque[-1] 74 | 75 | def __str__(self): 76 | return self.fmt.format( 77 | median=self.median, 78 | avg=self.avg, 79 | global_avg=self.global_avg, 80 | max=self.max, 81 | value=self.value) 82 | 83 | 84 | class MetricLogger(object): 85 | def __init__(self, delimiter="\t"): 86 | self.meters = defaultdict(SmoothedValue) 87 | self.delimiter = delimiter 88 | 89 | def update(self, **kwargs): 90 | for k, v in kwargs.items(): 91 | if v is None: 92 | continue 93 | if isinstance(v, torch.Tensor): 94 | v = v.item() 95 | assert isinstance(v, (float, int)) 96 | self.meters[k].update(v) 97 | 98 | def __getattr__(self, attr): 99 | if attr in self.meters: 100 | return self.meters[attr] 101 | if attr in self.__dict__: 102 | return self.__dict__[attr] 103 | raise AttributeError("'{}' object has no attribute '{}'".format( 104 | type(self).__name__, attr)) 105 | 106 | def __str__(self): 107 | loss_str = [] 108 | for name, meter in self.meters.items(): 109 | loss_str.append( 110 | "{}: {}".format(name, str(meter)) 111 | ) 112 | return self.delimiter.join(loss_str) 113 | 114 | def synchronize_between_processes(self): 115 | for meter in self.meters.values(): 116 | meter.synchronize_between_processes() 117 | 118 | def add_meter(self, name, meter): 119 | self.meters[name] = meter 120 | 121 | def log_every(self, iterable, print_freq, header=None): 122 | i = 0 123 | if not header: 124 | header = '' 125 | start_time = time.time() 126 | end = time.time() 127 | iter_time = SmoothedValue(fmt='{avg:.4f}') 128 | data_time = SmoothedValue(fmt='{avg:.4f}') 129 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 130 | log_msg = [ 131 | header, 132 | '[{0' + space_fmt + '}/{1}]', 133 | 'eta: {eta}', 134 | '{meters}', 135 | 'time: {time}', 136 | 'data: {data}' 137 | ] 138 | if torch.cuda.is_available(): 139 | log_msg.append('max mem: {memory:.0f}') 140 | log_msg = self.delimiter.join(log_msg) 141 | MB = 1024.0 * 1024.0 142 | for obj in iterable: 143 | data_time.update(time.time() - end) 144 | yield obj 145 | iter_time.update(time.time() - end) 146 | if i % print_freq == 0 or i == len(iterable) - 1: 147 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 148 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 149 | if torch.cuda.is_available(): 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time), 154 | memory=torch.cuda.max_memory_allocated() / MB)) 155 | else: 156 | print(log_msg.format( 157 | i, len(iterable), eta=eta_string, 158 | meters=str(self), 159 | time=str(iter_time), data=str(data_time))) 160 | i += 1 161 | end = time.time() 162 | total_time = time.time() - start_time 163 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 164 | print('{} Total time: {} ({:.4f} s / it)'.format( 165 | header, total_time_str, total_time / len(iterable))) 166 | 167 | 168 | def setup_for_distributed(is_master): 169 | """ 170 | This function disables printing when not in master process 171 | """ 172 | builtin_print = builtins.print 173 | 174 | def print(*args, **kwargs): 175 | force = kwargs.pop('force', False) 176 | force = force or (get_world_size() > 8) 177 | if is_master:# or force: 178 | now = datetime.datetime.now().time() 179 | builtin_print('[{}] '.format(now), end='') # print with time stamp 180 | builtin_print(*args, **kwargs) 181 | 182 | builtins.print = print 183 | 184 | 185 | def is_dist_avail_and_initialized(): 186 | if not dist.is_available(): 187 | return False 188 | if not dist.is_initialized(): 189 | return False 190 | return True 191 | 192 | 193 | def get_world_size(): 194 | if not is_dist_avail_and_initialized(): 195 | return 1 196 | return dist.get_world_size() 197 | 198 | 199 | def get_rank(): 200 | if not is_dist_avail_and_initialized(): 201 | return 0 202 | return dist.get_rank() 203 | 204 | 205 | def is_main_process(): 206 | return get_rank() == 0 207 | 208 | 209 | def save_on_master(*args, **kwargs): 210 | if is_main_process(): 211 | torch.save(*args, **kwargs) 212 | 213 | 214 | def init_distributed_mode(args): 215 | if args.dist_on_itp: 216 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 217 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 218 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 219 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 220 | os.environ['LOCAL_RANK'] = str(args.gpu) 221 | os.environ['RANK'] = str(args.rank) 222 | os.environ['WORLD_SIZE'] = str(args.world_size) 223 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 224 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 225 | args.rank = int(os.environ["RANK"]) 226 | args.world_size = int(os.environ['WORLD_SIZE']) 227 | args.gpu = int(os.environ['LOCAL_RANK']) 228 | elif 'SLURM_PROCID' in os.environ: 229 | args.rank = int(os.environ['SLURM_PROCID']) 230 | args.gpu = args.rank % torch.cuda.device_count() 231 | else: 232 | print('Not using distributed mode') 233 | setup_for_distributed(is_master=True) # hack 234 | args.distributed = False 235 | return 236 | 237 | args.distributed = True 238 | 239 | torch.cuda.set_device(args.gpu) 240 | args.dist_backend = 'nccl' 241 | print('| distributed init (rank {}): {}, gpu {}'.format( 242 | args.rank, args.dist_url, args.gpu), flush=True) 243 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 244 | world_size=args.world_size, rank=args.rank) 245 | torch.distributed.barrier() 246 | setup_for_distributed(args.rank == 0) 247 | 248 | 249 | class NativeScalerWithGradNormCount: 250 | state_dict_key = "amp_scaler" 251 | 252 | def __init__(self): 253 | self._scaler = torch.cuda.amp.GradScaler() 254 | 255 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 256 | self._scaler.scale(loss).backward(create_graph=create_graph) 257 | if update_grad: 258 | if clip_grad is not None: 259 | assert parameters is not None 260 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 261 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 262 | else: 263 | self._scaler.unscale_(optimizer) 264 | norm = get_grad_norm_(parameters) 265 | self._scaler.step(optimizer) 266 | self._scaler.update() 267 | else: 268 | norm = None 269 | return norm 270 | 271 | def state_dict(self): 272 | return self._scaler.state_dict() 273 | 274 | def load_state_dict(self, state_dict): 275 | self._scaler.load_state_dict(state_dict) 276 | 277 | 278 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 279 | if isinstance(parameters, torch.Tensor): 280 | parameters = [parameters] 281 | parameters = [p for p in parameters if p.grad is not None] 282 | norm_type = float(norm_type) 283 | if len(parameters) == 0: 284 | return torch.tensor(0.) 285 | device = parameters[0].grad.device 286 | if norm_type == inf: 287 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 288 | else: 289 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 290 | return total_norm 291 | 292 | 293 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 294 | output_dir = Path(args.output_dir) 295 | epoch_name = str(epoch) 296 | if loss_scaler is not None: 297 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 298 | for checkpoint_path in checkpoint_paths: 299 | to_save = { 300 | 'model': model_without_ddp.state_dict(), 301 | 'optimizer': optimizer.state_dict(), 302 | 'epoch': epoch, 303 | 'scaler': loss_scaler.state_dict(), 304 | 'args': args, 305 | } 306 | 307 | save_on_master(to_save, checkpoint_path) 308 | else: 309 | client_state = {'epoch': epoch} 310 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 311 | 312 | 313 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 314 | if args.resume: 315 | if args.resume.startswith('https'): 316 | checkpoint = torch.hub.load_state_dict_from_url( 317 | args.resume, map_location='cpu', check_hash=True) 318 | else: 319 | checkpoint = torch.load(args.resume, map_location='cpu') 320 | model_without_ddp.load_state_dict(checkpoint['model']) 321 | print("Resume checkpoint %s" % args.resume) 322 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 323 | optimizer.load_state_dict(checkpoint['optimizer']) 324 | args.start_epoch = checkpoint['epoch'] + 1 325 | if 'scaler' in checkpoint: 326 | loss_scaler.load_state_dict(checkpoint['scaler']) 327 | print("With optim & sched!") 328 | 329 | 330 | def all_reduce_mean(x): 331 | world_size = get_world_size() 332 | if world_size > 1: 333 | x_reduce = torch.tensor(x).cuda() 334 | dist.all_reduce(x_reduce) 335 | x_reduce /= world_size 336 | return x_reduce.item() 337 | else: 338 | return x -------------------------------------------------------------------------------- /util/shapenet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | import random 5 | 6 | import yaml 7 | 8 | import torch 9 | from torch.utils import data 10 | 11 | import numpy as np 12 | 13 | from PIL import Image 14 | 15 | import h5py 16 | 17 | category_ids = { 18 | '02691156': 0, 19 | '02747177': 1, 20 | '02773838': 2, 21 | '02801938': 3, 22 | '02808440': 4, 23 | '02818832': 5, 24 | '02828884': 6, 25 | '02843684': 7, 26 | '02871439': 8, 27 | '02876657': 9, 28 | '02880940': 10, 29 | '02924116': 11, 30 | '02933112': 12, 31 | '02942699': 13, 32 | '02946921': 14, 33 | '02954340': 15, 34 | '02958343': 16, 35 | '02992529': 17, 36 | '03001627': 18, 37 | '03046257': 19, 38 | '03085013': 20, 39 | '03207941': 21, 40 | '03211117': 22, 41 | '03261776': 23, 42 | '03325088': 24, 43 | '03337140': 25, 44 | '03467517': 26, 45 | '03513137': 27, 46 | '03593526': 28, 47 | '03624134': 29, 48 | '03636649': 30, 49 | '03642806': 31, 50 | '03691459': 32, 51 | '03710193': 33, 52 | '03759954': 34, 53 | '03761084': 35, 54 | '03790512': 36, 55 | '03797390': 37, 56 | '03928116': 38, 57 | '03938244': 39, 58 | '03948459': 40, 59 | '03991062': 41, 60 | '04004475': 42, 61 | '04074963': 43, 62 | '04090263': 44, 63 | '04099429': 45, 64 | '04225987': 46, 65 | '04256520': 47, 66 | '04330267': 48, 67 | '04379243': 49, 68 | '04401088': 50, 69 | '04460130': 51, 70 | '04468005': 52, 71 | '04530566': 53, 72 | '04554684': 54, 73 | } 74 | 75 | class ShapeNet(data.Dataset): 76 | def __init__(self, dataset_folder, split, categories=None, transform=None, sampling=True, num_samples=4096, return_surface=True, surface_sampling=True, pc_size=2048, replica=16): 77 | 78 | self.pc_size = pc_size 79 | 80 | self.transform = transform 81 | self.num_samples = num_samples 82 | self.sampling = sampling 83 | self.split = split 84 | 85 | self.dataset_folder = dataset_folder 86 | self.return_surface = return_surface 87 | self.surface_sampling = surface_sampling 88 | 89 | self.dataset_folder = dataset_folder 90 | # self.point_folder = os.path.join(self.dataset_folder, 'ShapeNetV2_sdf') 91 | self.point_folder = os.path.join(self.dataset_folder, 'ShapeNetV2_sdf') 92 | self.mesh_folder = os.path.join(self.dataset_folder, 'ShapeNetV2_watertight') 93 | 94 | # categories = None 95 | if categories is None: 96 | categories = os.listdir(self.point_folder) 97 | categories = [c for c in categories if os.path.isdir(os.path.join(self.point_folder, c)) and c.startswith('0')] 98 | categories.sort() 99 | 100 | # categories = ['03001627'] 101 | 102 | print(categories) 103 | 104 | self.models = [] 105 | for c_idx, c in enumerate(categories): 106 | subpath = os.path.join(self.point_folder, c) 107 | assert os.path.isdir(subpath) 108 | 109 | split_file = os.path.join(subpath.replace('ShapeNetV2_sdf', 'ShapeNetV2_point'), split + '.lst') 110 | with open(split_file, 'r') as f: 111 | models_c = f.read().split('\n') 112 | 113 | self.models += [ 114 | {'category': c, 'model': m.replace('.npz', '')} 115 | for m in models_c 116 | ] 117 | 118 | self.replica = replica 119 | 120 | if self.split == 'train': 121 | self.hf = [] 122 | # self.accum_sum = [0] 123 | for i in range(8): 124 | hf = h5py.File('/ibex/ai/project/c2168/biao/shapenet_sdf_h5/{}-{:03d}.h5'.format(self.split, i), 'r') 125 | self.hf.append(hf) 126 | # self.accum_sum.append(self.accum_sum[-1] + len(hf.keys()) // 5) 127 | # print(self.accum_sum) 128 | else: 129 | self.hf = h5py.File('/ibex/ai/project/c2168/biao/shapenet_sdf_h5/{}-000.h5'.format(self.split), 'r') 130 | 131 | def __getitem__(self, idx): 132 | idx = idx % len(self.models) 133 | 134 | category = self.models[idx]['category'] 135 | 136 | model = self.models[idx]['model'] 137 | 138 | if isinstance(self.hf, list): 139 | hf = self.hf[idx % 8] 140 | else: 141 | hf = self.hf 142 | 143 | vol_points = hf['{}_{}_{}'.format(category, model, 'vol_points')][:] 144 | vol_sdf = hf['{}_{}_{}'.format(category, model, 'vol_sdf')][:] 145 | near_points = hf['{}_{}_{}'.format(category, model, 'near_points')][:] 146 | near_label = hf['{}_{}_{}'.format(category, model, 'near_sdf')][:] 147 | surface = hf['{}_{}_{}'.format(category, model, 'surface_points')][:] 148 | 149 | 150 | if self.return_surface: 151 | if self.surface_sampling: 152 | ind = np.random.default_rng().choice(surface.shape[0], self.pc_size, replace=False) 153 | surface = surface[ind] 154 | surface = torch.from_numpy(surface) 155 | 156 | ind = np.random.default_rng().choice(vol_points.shape[0], self.num_samples, replace=False) 157 | vol_points2 = vol_points[ind] 158 | vol_sdf2 = vol_sdf[ind] 159 | 160 | if self.sampling: 161 | 162 | ind = np.random.default_rng().choice(vol_points.shape[0], 1024, replace=False) 163 | vol_points = vol_points[ind] 164 | vol_sdf = vol_sdf[ind] 165 | 166 | 167 | ind = np.random.default_rng().choice(near_points.shape[0], 1024, replace=False) 168 | near_points = near_points[ind] 169 | near_label = near_label[ind] 170 | 171 | 172 | vol_points = torch.from_numpy(vol_points) 173 | vol_sdf = torch.from_numpy(vol_sdf).float() 174 | 175 | vol_points2 = torch.from_numpy(vol_points2) 176 | vol_sdf2 = torch.from_numpy(vol_sdf2).float() 177 | 178 | if self.split == 'train': 179 | near_points = torch.from_numpy(near_points) 180 | near_label = torch.from_numpy(near_label).float() 181 | 182 | 183 | points = torch.cat([vol_points, near_points], dim=0) 184 | labels = torch.cat([vol_sdf, near_label], dim=0) 185 | else: 186 | 187 | near_points = torch.from_numpy(near_points) 188 | near_label = torch.from_numpy(near_label).float() 189 | 190 | 191 | points = torch.cat([vol_points, near_points], dim=0) 192 | labels = torch.cat([vol_sdf, near_label], dim=0) 193 | 194 | if self.transform: 195 | surface, points = self.transform(surface, points) 196 | 197 | 198 | if self.return_surface: 199 | return points, labels, vol_points2, vol_sdf2, surface, category_ids[category]#, model 200 | else: 201 | return points, labels, category_ids[category] 202 | 203 | def __len__(self): 204 | if self.split != 'train': 205 | return len(self.models) 206 | else: 207 | return len(self.models) * self.replica 208 | --------------------------------------------------------------------------------