├── img └── teaser.png ├── data ├── synsetoffset2category.txt └── download.sh ├── extension ├── setup.py ├── test.py ├── chamfer_cuda.cpp ├── dist_chamfer.py └── chamfer.cu ├── auxiliary ├── loss.py ├── utils.py ├── dataset.py └── model.py ├── README.md └── training └── train.py /img/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theodeprelle/AtlasNetV2/HEAD/img/teaser.png -------------------------------------------------------------------------------- /data/synsetoffset2category.txt: -------------------------------------------------------------------------------- 1 | plane 02691156 2 | bench 02828884 3 | cabinet 02933112 4 | car 02958343 5 | chair 03001627 6 | monitor 03211117 7 | lamp 03636649 8 | speaker 03691459 9 | firearm 04090263 10 | couch 04256520 11 | table 04379243 12 | cellphone 04401088 13 | watercraft 04530566 14 | -------------------------------------------------------------------------------- /extension/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer', 6 | ext_modules=[ 7 | CUDAExtension('chamfer', [ 8 | 'chamfer_cuda.cpp', 9 | 'chamfer.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /extension/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dist_chamfer 3 | dist = dist_chamfer.chamferDist() 4 | 5 | with torch.enable_grad(): 6 | p1 = torch.rand(10,1000,6) 7 | p2 = torch.rand(10,1500,6) 8 | p1.requires_grad = True 9 | p2.requires_grad = True 10 | points1 = p1.cuda() 11 | points2 = p2.cuda() 12 | cost, _ = dist(points1, points2) 13 | print(cost) 14 | loss = torch.sum(cost) 15 | print(loss) 16 | loss.backward() 17 | print(points1.grad, points2.grad) 18 | -------------------------------------------------------------------------------- /auxiliary/loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import torch 4 | 5 | sys.path.append("./extension/") 6 | import dist_chamfer as ext 7 | distChamferL2 = ext.chamferDist() 8 | 9 | def ChamferLoss(target,prediction): 10 | 11 | dist1, dist2 = distChamferL2(target, prediction) 12 | loss = torch.mean(dist1)+torch.mean(dist2) 13 | 14 | return loss 15 | 16 | class LOSS_LIST: 17 | """list of all the model""" 18 | def __init__(self): 19 | 20 | self.losses={"AtlasNet":ChamferLoss, 21 | "PatchDeformation":ChamferLoss, 22 | "PointTranslation":ChamferLoss,} 23 | 24 | def load(self,options): 25 | 26 | loss = self.losses[options.model] 27 | return loss 28 | -------------------------------------------------------------------------------- /data/download.sh: -------------------------------------------------------------------------------- 1 | #This script download the data from ENPC cloud 2 | 3 | # The point clouds from ShapeNet, with normals 4 | #wget https://cloud.enpc.fr/s/j2ECcKleA1IKNzk/download --no-check-certificate 5 | wget https://cloud.enpc.fr/s/JNf3NAxGbQoQsKY/download --no-check-certificate 6 | 7 | unzip download 8 | rm download 9 | 10 | # The corresponding normalized mesh (for the metro distance) 11 | #wget https://cloud.enpc.fr/s/RATKsfLQUSu0JWW/download --no-check-certificate 12 | wget https://cloud.enpc.fr/s/RATKsfLQUSu0JWW/download --no-check-certificate 13 | 14 | unzip download 15 | rm download 16 | 17 | # the rendered views 18 | #wget https://cloud.enpc.fr/s/S6TCx1QJzviNHq0/download --no-check-certificate 19 | wget https://cloud.enpc.fr/s/8RRYY6dyJ6AF7qk/download --no-check-certificate 20 | 21 | unzip download 22 | rm download 23 | -------------------------------------------------------------------------------- /extension/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /extension/dist_chamfer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | from torch.autograd import Function 4 | import torch 5 | import sys 6 | from numbers import Number 7 | from collections import Set, Mapping, deque 8 | import chamfer 9 | 10 | # Chamfer's distance module @thibaultgroueix 11 | # GPU tensors only 12 | class chamferFunction(Function): 13 | @staticmethod 14 | def forward(ctx, xyz1, xyz2): 15 | batchsize, n, _ = xyz1.size() 16 | _, m, _ = xyz2.size() 17 | 18 | dist1 = torch.zeros(batchsize, n) 19 | dist2 = torch.zeros(batchsize, m) 20 | 21 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 22 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 23 | 24 | dist1 = dist1.cuda() 25 | dist2 = dist2.cuda() 26 | idx1 = idx1.cuda() 27 | idx2 = idx2.cuda() 28 | 29 | chamfer.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 30 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 31 | return dist1, dist2 32 | 33 | @staticmethod 34 | def backward(ctx, graddist1, graddist2): 35 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 36 | graddist1 = graddist1.contiguous() 37 | graddist2 = graddist2.contiguous() 38 | 39 | gradxyz1 = torch.zeros(xyz1.size()) 40 | gradxyz2 = torch.zeros(xyz2.size()) 41 | 42 | gradxyz1 = gradxyz1.cuda() 43 | gradxyz2 = gradxyz2.cuda() 44 | chamfer.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 45 | return gradxyz1, gradxyz2 46 | 47 | class chamferDist(nn.Module): 48 | def __init__(self): 49 | super(chamferDist, self).__init__() 50 | 51 | def forward(self, input1, input2): 52 | return chamferFunction.apply(input1, input2) 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![teaset](img/teaser.png) 2 | 3 | # AtlasNet V2 - Learning Elementary Structures 4 | This work was build upon [Thibault Groueix](https://github.com/ThibaultGROUEIX/)'s [AtlasNet](https://github.com/ThibaultGROUEIX/AtlasNet) and [3D-CODED](https://github.com/ThibaultGROUEIX/3D-CODED) projects. (you might want to have a look at those) 5 | 6 | This repository contains the source codes for the paper [AtlasNet V2 - Learning Elementary Structures](https://arxiv.org/abs/1908.04725). 7 | 8 | ### Citing this work 9 | 10 | If you find this work useful in your research, please consider citing: 11 | 12 | ``` 13 | @inproceedings{deprelle2019learning, 14 | title={Learning elementary structures for 3D shape generation and matching}, 15 | author={Deprelle, Theo and Groueix, Thibault and Fisher, Matthew and Kim, Vladimir and Russell, Bryan and Aubry, Mathieu}, 16 | booktitle={Advances in Neural Information Processing Systems}, 17 | pages={7433--7443}, 18 | year={2019} 19 | } 20 | ``` 21 | 22 | ### Project Page 23 | 24 | The project page is available http://imagine.enpc.fr/~deprellt/atlasnet2/ 25 | 26 | # Install 27 | 28 | ### Clone the repo and install dependencies 29 | 30 | This implementation uses [Pytorch](http://pytorch.org/). 31 | 32 | ```shell 33 | ## Download the repository 34 | git clone https://github.com/TheoDEPRELLE/AtlasNetV2.git 35 | cd AtlasNetV2 36 | ## Create python env with relevant packages 37 | conda create --name atlasnetV2 python=3.7 38 | source activate atlasnetV2 39 | pip install pandas visdom 40 | conda install pytorch torchvision -c pytorch 41 | conda install -c conda-forge matplotlib 42 | # you're done ! Congrats :) 43 | 44 | ``` 45 | 46 | # Training 47 | 48 | ### Data 49 | 50 | ```shell 51 | cd data; ./download_data.sh; cd .. 52 | ``` 53 | We used the [ShapeNet](https://www.shapenet.org/) dataset for 3D models. 54 | 55 | When using the provided data make sure to respect the shapenet [license](https://shapenet.org/terms). 56 | 57 | * [The point clouds from ShapeNet, with normals](https://cloud.enpc.fr/s/j2ECcKleA1IKNzk) go in ``` data/customShapeNet``` 58 | * [The corresponding normalized mesh (for the metro distance)](https://cloud.enpc.fr/s/RATKsfLQUSu0JWW) goes in ``` data/ShapeNetCorev2Normalized``` 59 | * [the rendered views](https://cloud.enpc.fr/s/S6TCx1QJzviNHq0) go in ``` data/ShapeNetRendering``` 60 | 61 | The trained models and some corresponding results are also available online : 62 | 63 | * [The trained_models](https://cloud.enpc.fr/s/qt5M3ZnjF8NZoy4) go in ``` trained_models/``` 64 | 65 | 66 | ### Build chamfer distance 67 | 68 | The chamfer loss is based on a custom cuda code that need to be compile. 69 | 70 | ```shell 71 | source activate pytorch-atlasnet 72 | cd ./extension 73 | python setup.py install 74 | ``` 75 | 76 | 77 | 78 | ### Start training 79 | 80 | * First launch a visdom server : 81 | 82 | ```bash 83 | python -m visdom.server -p 8888 84 | ``` 85 | 86 | * Check out all the options : 87 | 88 | ```shell 89 | git pull; python training/train.py --help 90 | ``` 91 | 92 | * Run the baseline : 93 | 94 | ```shell 95 | git pull; python training/train.py --model AtlasNet --adjust mlp 96 | git pull; python training/train.py --model AtlasNet --adjust linear 97 | ``` 98 | * Run the Patch Deformation module with the different adjustment modules : 99 | 100 | ```shell 101 | git pull; python training/train.py --model PatchDeformation --adjust mlp 102 | git pull; python training/train.py --model PatchDeformation --adjust linear 103 | ``` 104 | 105 | * Run the Point Translation module with the different adjustment modules: 106 | 107 | ```shell 108 | git pull; python training/train.py --model PointTranslation --adjust mlp 109 | git pull; python training/train.py --model PointTranslation --adjust linear 110 | ``` 111 | 112 | 113 | * Monitor your training on http://localhost:8888/ 114 | 115 | ## Models 116 | 117 | The models train on the SURREAL dataset for the FAUST competition can be found [here](https://github.com/ThibaultGROUEIX/3D-CODED) 118 | 119 | ## Acknowledgement 120 | 121 | This work was partly supported by ANR project EnHerit ANR-17-CE23-0008, Labex Bezout, and gifts from Adobe to Ecole des Ponts. 122 | 123 | 124 | ## License 125 | 126 | [MIT](https://github.com/ThibaultGROUEIX/AtlasNet/blob/master/license_MIT) 127 | -------------------------------------------------------------------------------- /auxiliary/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import pickle 4 | import visdom 5 | import torch 6 | import sys 7 | import os 8 | import random 9 | 10 | class COLORS: 11 | HEADER = '\033[95m' 12 | OKBLUE = '\033[94m' 13 | OKGREEN = '\033[92m' 14 | WARNING = '\033[93m' 15 | FAIL = '\033[91m' 16 | ENDC = '\033[0m' 17 | BOLD = '\033[1m' 18 | UNDERLINE = '\033[4m' 19 | 20 | def display_opts(opts): 21 | 22 | 23 | display_msg = """PARAMETERS: 24 | model %s %s %s 25 | adjust %s %s %s 26 | dataset %s %s %s 27 | loadmodel %s %s %s 28 | npatch %s %s %s 29 | npoint %s %s %s 30 | nlatent %s %s %s 31 | nbatch %s %s %s 32 | lrate %s %s %s 33 | nepoch %s %s %s 34 | first decay %s %s %s 35 | second decay %s %s %s 36 | training_id %s %s %s 37 | """%(COLORS.OKGREEN,opts.model,COLORS.ENDC, 38 | COLORS.OKGREEN,opts.adjust,COLORS.ENDC, 39 | COLORS.OKGREEN,opts.dataset,COLORS.ENDC, 40 | COLORS.OKGREEN,opts.loadmodel,COLORS.ENDC, 41 | COLORS.OKGREEN,opts.npatch,COLORS.ENDC, 42 | COLORS.OKGREEN,opts.npoint,COLORS.ENDC, 43 | COLORS.OKGREEN,opts.nlatent,COLORS.ENDC, 44 | COLORS.OKGREEN,opts.nbatch,COLORS.ENDC, 45 | COLORS.OKGREEN,opts.lrate,COLORS.ENDC, 46 | COLORS.OKGREEN,opts.nepoch,COLORS.ENDC, 47 | COLORS.OKGREEN,opts.firstdecay,COLORS.ENDC, 48 | COLORS.OKGREEN,opts.seconddecay,COLORS.ENDC, 49 | COLORS.OKGREEN,opts.training_id,COLORS.ENDC) 50 | 51 | print(display_msg) 52 | 53 | 54 | def display_it(mode, opt, epoch_id, batch_id, loss=None): 55 | """display iteration""" 56 | 57 | if batch_id % 50 == 0: 58 | msg = '' 59 | 60 | if mode == 'train': 61 | msg = "[%s%s%s] - %d/%d - %04d %s%f%s" % (COLORS.OKGREEN, 62 | opt.training_id, 63 | COLORS.ENDC, 64 | epoch_id, 65 | opt.nepoch, 66 | batch_id, 67 | COLORS.BOLD, 68 | loss, 69 | COLORS.ENDC) 70 | 71 | if mode == 'valid': 72 | msg = "[%s%s%s] - %d/%d - %04d %s%f%s" % (COLORS.OKBLUE, 73 | opt.training_id, 74 | COLORS.ENDC, 75 | epoch_id, 76 | opt.nepoch, 77 | batch_id, 78 | COLORS.BOLD, 79 | loss, 80 | COLORS.ENDC) 81 | 82 | if mode == 'test': 83 | msg = "[%s%s%s] - %d/%d" % (COLORS.WARNING, 84 | opt.training_id, COLORS.ENDC, 85 | epoch_id, opt.nepoch) 86 | print(msg) 87 | 88 | 89 | class LOGGER: 90 | """logger of the network loss """ 91 | 92 | def __init__(self): 93 | self.history = [] 94 | self.data = [] 95 | 96 | def add(self, val): 97 | self.data.append(val) 98 | 99 | def save(self, path): 100 | with open(path, "wb") as f: 101 | pickle.dump(self.history, f) 102 | 103 | def mean(self): 104 | m = np.mean(np.array(self.data)) 105 | return m 106 | 107 | def reset(self): 108 | if self.data: 109 | self.history.append(np.mean(np.array(self.data))) 110 | self.data = [] 111 | 112 | 113 | def weights_init(m): 114 | classname = m.__class__.__name__ 115 | if classname.find('Conv') != -1: 116 | m.weight.data.normal_(0.0, 0.02) 117 | elif classname.find('BatchNorm') != -1: 118 | m.weight.data.normal_(1.0, 0.02) 119 | m.bias.data.fill_(0) 120 | 121 | class AverageValueMeter(object): 122 | """Computes and stores the average and current value""" 123 | def __init__(self): 124 | self.reset() 125 | 126 | def reset(self): 127 | self.val = 0 128 | self.avg = 0 129 | self.sum = 0 130 | self.count = 0.0 131 | 132 | def update(self, val, n=1): 133 | self.val = val 134 | self.sum += val * n 135 | self.count += n 136 | self.avg = self.sum / self.count 137 | 138 | CHUNK_SIZE = 150 139 | lenght_line = 60 140 | def my_get_n_random_lines(path, n=5): 141 | MY_CHUNK_SIZE = lenght_line * (n+2) 142 | lenght = os.stat(path).st_size 143 | with open(path, 'r') as file: 144 | file.seek(random.randint(400, lenght - MY_CHUNK_SIZE)) 145 | chunk = file.read(MY_CHUNK_SIZE) 146 | lines = chunk.split(os.linesep) 147 | return lines[1:n+1] 148 | -------------------------------------------------------------------------------- /extension/chamfer.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*3]; 15 | for (int i=blockIdx.x;ibest){ 127 | result[(i*n+j)]=best; 128 | result_i[(i*n+j)]=best_i; 129 | } 130 | } 131 | __syncthreads(); 132 | } 133 | } 134 | } 135 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 136 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 137 | 138 | const auto batch_size = xyz1.size(0); 139 | const auto n = xyz1.size(1); //num_points point cloud A 140 | const auto m = xyz2.size(1); //num_points point cloud B 141 | 142 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 143 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 144 | 145 | cudaError_t err = cudaGetLastError(); 146 | if (err != cudaSuccess) { 147 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 148 | //THError("aborting"); 149 | return 0; 150 | } 151 | return 1; 152 | 153 | 154 | } 155 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 156 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 185 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 186 | 187 | cudaError_t err = cudaGetLastError(); 188 | if (err != cudaSuccess) { 189 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 190 | //THError("aborting"); 191 | return 0; 192 | } 193 | return 1; 194 | 195 | } 196 | 197 | -------------------------------------------------------------------------------- /auxiliary/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import json 4 | import glob 5 | import torch 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | import torch.utils.data as data 9 | from auxiliary.utils import * 10 | 11 | class ShapeNet(data.Dataset): 12 | def __init__(self,train=True,options=None): 13 | rootimg = "./data/ShapeNet/ShapeNetRendering" 14 | rootpc = "./data/customShapeNet" 15 | class_choice = None 16 | npoints = options.npoint 17 | normal = False 18 | balanced = False 19 | gen_view=False 20 | SVR=False 21 | idx=0 22 | self.balanced = balanced 23 | self.normal = normal 24 | self.train = train 25 | self.rootimg = rootimg 26 | self.rootpc = rootpc 27 | self.npoints = npoints 28 | self.datapath = [] 29 | self.catfile = os.path.join('./data/synsetoffset2category.txt') 30 | self.cat = {} 31 | self.meta = {} 32 | self.SVR = SVR 33 | self.gen_view = gen_view 34 | self.idx=idx 35 | with open(self.catfile, 'r') as f: 36 | for line in f: 37 | ls = line.strip().split() 38 | self.cat[ls[0]] = ls[1] 39 | if not class_choice is None: 40 | self.cat = {k:v for k,v in self.cat.items() if k in class_choice} 41 | print(self.cat) 42 | empty = [] 43 | for item in self.cat: 44 | dir_img = os.path.join(self.rootimg, self.cat[item]) 45 | fns_img = sorted(os.listdir(dir_img)) 46 | 47 | try: 48 | dir_point = os.path.join(self.rootpc, self.cat[item], 'ply') 49 | fns_pc = sorted(os.listdir(dir_point)) 50 | except: 51 | fns_pc = [] 52 | fns = [val for val in fns_img if val + '.points.ply' in fns_pc] 53 | print('category ', self.cat[item], 'files ' + str(len(fns)), len(fns)/float(len(fns_img)), "%"), 54 | if train: 55 | fns = fns[:int(len(fns) * 0.8)] 56 | else: 57 | fns = fns[int(len(fns) * 0.8):] 58 | 59 | 60 | if len(fns) != 0: 61 | self.meta[item] = [] 62 | for fn in fns: 63 | objpath = "./data/ShapeNetCorev2/" + self.cat[item] + "/" + fn + "/models/model_normalized.ply" 64 | self.meta[item].append( ( os.path.join(dir_img, fn, "rendering"), os.path.join(dir_point, fn + '.points.ply'), item, objpath, fn ) ) 65 | else: 66 | empty.append(item) 67 | for item in empty: 68 | del self.cat[item] 69 | self.idx2cat = {} 70 | self.size = {} 71 | i = 0 72 | for item in self.cat: 73 | self.idx2cat[i] = item 74 | self.size[i] = len(self.meta[item]) 75 | i = i + 1 76 | # for fn in self.meta[item]: 77 | l = int(len(self.meta[item])) 78 | for fn in self.meta[item][0:l]: 79 | self.datapath.append(fn) 80 | 81 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 82 | std=[0.229, 0.224, 0.225]) 83 | 84 | self.transforms = transforms.Compose([ 85 | transforms.Resize(size = 224, interpolation = 2), 86 | transforms.ToTensor(), 87 | # normalize, 88 | ]) 89 | 90 | # RandomResizedCrop or RandomCrop 91 | self.dataAugmentation = transforms.Compose([ 92 | transforms.RandomCrop(127), 93 | transforms.RandomHorizontalFlip(), 94 | ]) 95 | self.validating = transforms.Compose([ 96 | transforms.CenterCrop(127), 97 | ]) 98 | 99 | self.perCatValueMeter = {} 100 | for item in self.cat: 101 | self.perCatValueMeter[item] = AverageValueMeter() 102 | self.perCatValueMeter_metro = {} 103 | for item in self.cat: 104 | self.perCatValueMeter_metro[item] = AverageValueMeter() 105 | self.transformsb = transforms.Compose([ 106 | transforms.Resize(size = 224, interpolation = 2), 107 | ]) 108 | 109 | def __getitem__(self, index): 110 | fn = self.datapath[index] 111 | with open(fn[1]) as fp: 112 | for i, line in enumerate(fp): 113 | if i == 2: 114 | try: 115 | lenght = int(line.split()[2]) 116 | except ValueError: 117 | print(fn) 118 | print(line) 119 | break 120 | for i in range(15): #this for loop is because of some weird error that happens sometime during loading I didn't track it down and brute force the solution like this. 121 | try: 122 | mystring = my_get_n_random_lines(fn[1], n = self.npoints) 123 | point_set = np.loadtxt(mystring).astype(np.float32) 124 | break 125 | except ValueError as excep: 126 | print(fn) 127 | print(excep) 128 | 129 | # centroid = np.expand_dims(np.mean(point_set[:,0:3], axis = 0), 0) #Useless because dataset has been normalised already 130 | # point_set[:,0:3] = point_set[:,0:3] - centroid 131 | if not self.normal: 132 | point_set = point_set[:,0:3] 133 | else: 134 | point_set[:,3:6] = 0.1 * point_set[:,3:6] 135 | point_set = torch.from_numpy(point_set) 136 | 137 | # load image 138 | if self.SVR: 139 | if self.train: 140 | N_tot = len(os.listdir(fn[0])) - 3 141 | if N_tot==1: 142 | print("only one view in ", fn) 143 | if self.gen_view: 144 | N=0 145 | else: 146 | N = np.random.randint(1,N_tot) 147 | if N < 10: 148 | im = Image.open(os.path.join(fn[0], "0" + str(N) + ".png")) 149 | else: 150 | im = Image.open(os.path.join(fn[0], str(N) + ".png")) 151 | 152 | im = self.dataAugmentation(im) #random crop 153 | else: 154 | if self.idx < 10: 155 | im = Image.open(os.path.join(fn[0], "0" + str(self.idx) + ".png")) 156 | else: 157 | im = Image.open(os.path.join(fn[0], str(self.idx) + ".png")) 158 | im = self.validating(im) #center crop 159 | data = self.transforms(im) #scale 160 | data = data[:3,:,:] 161 | else: 162 | data = 0 163 | return point_set.contiguous() 164 | 165 | 166 | def __len__(self): 167 | return len(self.datapath) 168 | 169 | class DATASET_LIST: 170 | """list of all the dataset""" 171 | def __init__(self): 172 | 173 | self.datasets = {"shapenet":ShapeNet} 174 | self.type = self.datasets.keys() 175 | 176 | def load(self,training,options): 177 | 178 | if training: 179 | print("\nTRAINING DATASET:") 180 | else: 181 | print("VALIDATION DATASET:") 182 | dataset = self.datasets[options.dataset](training, options) 183 | print("\n") 184 | return dataset 185 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import sys 4 | import torch 5 | import torch.utils.data 6 | import datetime 7 | from torch.autograd import Variable 8 | import matplotlib.cm as cm 9 | 10 | sys.path.insert(1,'.') 11 | from auxiliary.loss import * 12 | from auxiliary.model import * 13 | from auxiliary.utils import * 14 | from auxiliary.dataset import * 15 | 16 | #======================================================================================== 17 | # argument parsing 18 | #======================================================================================== 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--patchDim', type=int, default = 2, help='') 21 | parser.add_argument('--patchDeformDim', type=int, default = 3, help='') 22 | parser.add_argument('--model', type=str, default ="AtlasNet", help='') 23 | parser.add_argument('--adjust', type=str, default ="mlp", help='') 24 | parser.add_argument('--lrate', type=float, default =0.001, help='') 25 | parser.add_argument('--nbatch', type=int, default = 16, help='') 26 | parser.add_argument('--nepoch', type=int, default = 325, help='') 27 | parser.add_argument('--npoint', type=int, default = 2500, help='') 28 | parser.add_argument('--npatch', type=int, default = 10, help='') 29 | parser.add_argument('--dataset', type=str, default ="shapenet", help='') 30 | parser.add_argument('--nlatent', type=int, default =1024, help='') 31 | parser.add_argument('--loadmodel', type=str, default=None,help='') 32 | parser.add_argument('--firstdecay', type=int, default = 250, help='') 33 | parser.add_argument('--seconddecay', type=int, default = 300, help='') 34 | parser.add_argument('--training_id', type=str, default=None,help='') 35 | opt = parser.parse_args() 36 | 37 | if opt.training_id == None and opt.model in ['PointTranslation','AtlasNet']: 38 | 39 | opt.training_id = "%s%dD_%sAdj_%s_%dpatch_%dpts_%dep"%(opt.model, 40 | opt.patchDim, 41 | opt.adjust, 42 | opt.dataset, 43 | opt.npatch, 44 | opt.npoint, 45 | opt.nepoch) 46 | 47 | if opt.training_id == None and opt.model == 'PatchDeformation': 48 | 49 | opt.training_id = "%s%dDto%dD_%s_%s_%dpatch_%dpts_%dep"%(opt.model, 50 | opt.patchDim, 51 | opt.patchDeformDim, 52 | opt.adjust, 53 | opt.dataset, 54 | opt.npatch, 55 | opt.npoint, 56 | opt.nepoch) 57 | display_opts(opt) 58 | #======================================================================================== 59 | 60 | 61 | #======================================================================================== 62 | # training element 63 | #======================================================================================== 64 | DATASET = DATASET_LIST() #list of all the methods 65 | MODEL = MODEL_LIST(opt) #list of all the models 66 | LOSS = LOSS_LIST() #list if all the losses 67 | #======================================================================================== 68 | 69 | 70 | #======================================================================================== 71 | # model selection 72 | #======================================================================================== 73 | if opt.model not in MODEL.type: 74 | print(COLORS.FAIL,"ERROR please select the model from : ",COLORS.ENDC) 75 | for model in sorted(MODEL.type): 76 | print(" >",model) 77 | exit() 78 | 79 | network = MODEL.load(opt) # load the loss 80 | network.apply(weights_init) # init the weight 81 | 82 | if opt.loadmodel != None: 83 | network.load_state_dict(opt.loadmodel) 84 | #======================================================================================== 85 | 86 | 87 | #======================================================================================== 88 | # dataset selection 89 | #======================================================================================== 90 | if opt.dataset not in DATASET.type: 91 | print(COLORS.FAIL,"ERROR please select the dataset from : ",COLORS.ENDC) 92 | for dataset in DATASET.type: 93 | print(" >",dataset) 94 | exit() 95 | 96 | dataset_train = DATASET.load(training=True,options=opt) 97 | dataloader_train = torch.utils.data.DataLoader(dataset_train, 98 | shuffle=True, 99 | batch_size=opt.nbatch, 100 | num_workers=12) 101 | 102 | dataset_valid = DATASET.load(training=False,options=opt) 103 | dataloader_valid = torch.utils.data.DataLoader(dataset_valid, 104 | shuffle=False, 105 | batch_size=opt.nbatch, 106 | num_workers=12) 107 | #====================================================================================== 108 | 109 | 110 | #====================================================================================== 111 | # loss selection 112 | #====================================================================================== 113 | loss = LOSS.load(opt) 114 | #====================================================================================== 115 | 116 | 117 | #====================================================================================== 118 | # optimizer 119 | #====================================================================================== 120 | optimizer = torch.optim.Adam(network.parameters(),lr = opt.lrate) 121 | #====================================================================================== 122 | 123 | 124 | #====================================================================================== 125 | # training logs 126 | #====================================================================================== 127 | logger_path = "./log/%s"%opt.training_id 128 | 129 | if not os.path.exists('log'): 130 | os.mkdir('log') 131 | 132 | if not os.path.exists(logger_path): 133 | os.mkdir(logger_path) 134 | 135 | with open('%s/opt.pickle'%logger_path, 'wb') as handle: 136 | pickle.dump(opt, handle, protocol=pickle.HIGHEST_PROTOCOL) 137 | 138 | trainloss_log = LOGGER() 139 | validloss_log = LOGGER() 140 | visdom = visdom.Visdom(env=opt.training_id, port=8888) 141 | #====================================================================================== 142 | 143 | 144 | #====================================================================================== 145 | # TRAINING BEGIN 146 | #====================================================================================== 147 | for epoch_id in range(opt.nepoch): 148 | 149 | network.train() 150 | 151 | trainloss_log.reset() 152 | 153 | if epoch_id == opt.firstdecay: 154 | optimizer = torch.optim.Adam(network.parameters(),lr = opt.lrate/10) 155 | 156 | if epoch_id == opt.seconddecay: 157 | optimizer = torch.optim.Adam(network.parameters(),lr = opt.lrate/100) 158 | 159 | 160 | #================================================================================== 161 | # training 162 | #================================================================================== 163 | for batch_id, batch in enumerate(dataloader_train): 164 | 165 | if batch.size(0) == opt.nbatch: 166 | 167 | optimizer.zero_grad() 168 | 169 | batch = batch.cuda() 170 | prediction, learnedPatches = network(batch.cuda()) 171 | fittingLoss = loss(batch,prediction) 172 | 173 | fittingLoss.backward() 174 | optimizer.step() 175 | 176 | trainloss_log.add(fittingLoss.item()) 177 | display_it("train", opt, epoch_id, batch_id, trainloss_log.mean()) 178 | #================================================================================== 179 | 180 | 181 | #================================================================================== 182 | # validation 183 | #================================================================================== 184 | with torch.no_grad(): 185 | 186 | network.eval() 187 | 188 | for batch_id, batch in enumerate(dataloader_valid): 189 | 190 | batch = batch.cuda() 191 | prediction, learnedPatches = network(batch.cuda()) 192 | fittingLoss = loss(batch,prediction) 193 | 194 | validloss_log.add(fittingLoss.item()) 195 | display_it("valid", opt, epoch_id, batch_id, validloss_log.mean()) 196 | #================================================================================== 197 | 198 | #================================================================================== 199 | # saving loss and parameters 200 | #================================================================================== 201 | if len(validloss_log.history) > 0: 202 | X = np.column_stack((train_log.history, valid_log.history)) 203 | Y = np.column_stack((np.arange(len(train_log.history)), 204 | np.arange(len(train_log.history)))) 205 | visdom.line(X, Y, win="Fitting loss", 206 | opts=dict(title="Fitting loss", legend=["train", "valid"])) 207 | 208 | color = [[125,125,125]]*(batch.size(1)) 209 | cmap = cm.get_cmap('hsv') 210 | for i in range(opt.npatch): 211 | c = cmap(i/opt.npatch-1) 212 | color += [[int(c[0]*255),int(c[1]*255),int(c[2]*255)]]*(opt.npoint//opt.npatch) 213 | 214 | color = np.array(color) 215 | if batch_id < 3: 216 | X = np.vstack((batch[0].data.cpu().numpy(),prediction[0].data.cpu().numpy())) 217 | Y = np.array([1]*batch.size(1)+[2]*opt.npoint) 218 | visdom.scatter(X, 219 | Y, 220 | win='gt%d' % batch_id, 221 | opts=dict(markersize=4, 222 | markercolor=color, 223 | title='gt%d' % batch_id, 224 | legend=['gt', 'batch'])) 225 | 226 | torch.save(network.state_dict(), '%s/network.pth' % (logger_path)) 227 | #==================================================================================i 228 | -------------------------------------------------------------------------------- /auxiliary/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from auxiliary.utils import * 8 | 9 | 10 | #TODO dim -> input dim output dim 11 | 12 | class linearTransformMLP(nn.Module): 13 | """linear transformation module""" 14 | 15 | def __init__(self, nlatent = 1024): 16 | super(linearTransformMLP, self).__init__() 17 | 18 | self.conv1 = torch.nn.Conv1d(nlatent, nlatent//2, 1) 19 | self.conv2 = torch.nn.Conv1d(nlatent//2, nlatent//2, 1) 20 | self.conv3 = torch.nn.Conv1d(nlatent//2, 16, 1) 21 | self.bn1 = torch.nn.BatchNorm1d(nlatent//2) 22 | self.bn2 = torch.nn.BatchNorm1d(nlatent//2) 23 | self.th = nn.Tanh() 24 | 25 | def forward(self, x): 26 | 27 | x = F.relu(self.bn1(self.conv1(x))) 28 | x = F.relu(self.bn2(self.conv2(x))) 29 | x = self.th(self.conv3(x)) 30 | x = x.view(x.size(0),4,4).contiguous() 31 | return x 32 | 33 | class linearAdj(nn.Module): 34 | """ prediction the linear transdormation matrix""" 35 | 36 | def __init__(self, dim = 3,nlatent = 1024): 37 | super(linearAdj, self).__init__() 38 | 39 | self.conv1 = torch.nn.Conv1d(nlatent, nlatent//2, 1) 40 | self.conv2 = torch.nn.Conv1d(nlatent//2, nlatent//2, 1) 41 | self.conv3 = torch.nn.Conv1d(nlatent//2, (dim+1)*3, 1) 42 | self.bn1 = torch.nn.BatchNorm1d(nlatent//2) 43 | self.bn2 = torch.nn.BatchNorm1d(nlatent//2) 44 | self.th = nn.Tanh() 45 | self.dim = dim 46 | 47 | def forward(self, x): 48 | 49 | x = F.relu(self.bn1(self.conv1(x))) 50 | x = F.relu(self.bn2(self.conv2(x))) 51 | x = self.th(self.conv3(x)) 52 | R = x[:,0:self.dim*3].view(x.size(0),self.dim,3).contiguous() 53 | T = x[:,self.dim*3:].view(x.size(0),1,3).contiguous() 54 | return R,T 55 | 56 | class mlpAdj(nn.Module): 57 | def __init__(self, nlatent = 1024): 58 | """Atlas decoder""" 59 | 60 | super(mlpAdj, self).__init__() 61 | self.nlatent = nlatent 62 | self.conv1 = torch.nn.Conv1d(self.nlatent, self.nlatent, 1) 63 | self.conv2 = torch.nn.Conv1d(self.nlatent, self.nlatent//2, 1) 64 | self.conv3 = torch.nn.Conv1d(self.nlatent//2, self.nlatent//4, 1) 65 | self.conv4 = torch.nn.Conv1d(self.nlatent//4, 3, 1) 66 | 67 | self.th = nn.Tanh() 68 | self.bn1 = torch.nn.BatchNorm1d(self.nlatent) 69 | self.bn2 = torch.nn.BatchNorm1d(self.nlatent//2) 70 | self.bn3 = torch.nn.BatchNorm1d(self.nlatent//4) 71 | 72 | def forward(self, x): 73 | batchsize = x.size()[0] 74 | x = F.relu(self.bn1(self.conv1(x))) 75 | x = F.relu(self.bn2(self.conv2(x))) 76 | x = F.relu(self.bn3(self.conv3(x))) 77 | x = self.th(self.conv4(x)) 78 | return x 79 | 80 | class patchDeformationMLP(nn.Module): 81 | """deformation of a 2D patch into a 3D surface""" 82 | 83 | def __init__(self,patchDim=2,patchDeformDim=3,tanh=True): 84 | 85 | super(patchDeformationMLP, self).__init__() 86 | layer_size = 128 87 | self.tanh=tanh 88 | self.conv1 = torch.nn.Conv1d(patchDim, layer_size, 1) 89 | self.conv2 = torch.nn.Conv1d(layer_size, layer_size, 1) 90 | self.conv3 = torch.nn.Conv1d(layer_size, patchDeformDim, 1) 91 | self.bn1 = torch.nn.BatchNorm1d(layer_size) 92 | self.bn2 = torch.nn.BatchNorm1d(layer_size) 93 | self.th = nn.Tanh() 94 | 95 | def forward(self, x): 96 | x = F.relu(self.bn1(self.conv1(x))) 97 | x = F.relu(self.bn2(self.conv2(x))) 98 | if self.tanh: 99 | x = self.th(self.conv3(x)) 100 | else: 101 | x = self.conv3(x) 102 | return x 103 | 104 | class PointNetfeat(nn.Module): 105 | def __init__(self, npoint = 2500, nlatent = 1024): 106 | """Encoder""" 107 | 108 | super(PointNetfeat, self).__init__() 109 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 110 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 111 | self.conv3 = torch.nn.Conv1d(128, nlatent, 1) 112 | self.lin = nn.Linear(nlatent, nlatent) 113 | 114 | self.bn1 = torch.nn.BatchNorm1d(64) 115 | self.bn2 = torch.nn.BatchNorm1d(128) 116 | self.bn3 = torch.nn.BatchNorm1d(nlatent) 117 | self.bn4 = torch.nn.BatchNorm1d(nlatent) 118 | 119 | self.npoint = npoint 120 | self.nlatent = nlatent 121 | 122 | def forward(self, x): 123 | batchsize = x.size()[0] 124 | x = F.relu(self.bn1(self.conv1(x))) 125 | x = F.relu(self.bn2(self.conv2(x))) 126 | x = self.bn3(self.conv3(x)) 127 | x,_ = torch.max(x, 2) 128 | x = x.view(-1, self.nlatent) 129 | x = F.relu(self.bn4(self.lin(x).unsqueeze(-1))) 130 | return x[...,0] 131 | 132 | class AtlasNet(nn.Module): 133 | """Atlas net auto encoder""" 134 | 135 | def __init__(self, options): 136 | 137 | super(AtlasNet, self).__init__() 138 | 139 | self.npoint = options.npoint 140 | self.npatch = options.npatch 141 | self.nlatent = options.nlatent 142 | self.patchDim = options.patchDim 143 | 144 | #encoder and decoder modules 145 | #============================================================================== 146 | self.encoder = PointNetfeat(self.npoint,self.nlatent) 147 | self.decoder = nn.ModuleList([mlpAdj(nlatent = 2 +self.nlatent) for i in range(0,self.npatch)]) 148 | #============================================================================== 149 | 150 | def forward(self, x): 151 | 152 | #encoder 153 | #============================================================================== 154 | x = self.encoder(x.transpose(2,1).contiguous()) 155 | #============================================================================== 156 | 157 | outs = [] 158 | patches = [] 159 | for i in range(0,self.npatch): 160 | 161 | #random patch 162 | #========================================================================== 163 | rand_grid = torch.FloatTensor(x.size(0),self.patchDim,self.npoint//self.npatch).cuda() 164 | rand_grid.data.uniform_(0,1) 165 | rand_grid[:,2:,:] = 0 166 | patches.append(rand_grid[0].transpose(1,0)) 167 | #========================================================================== 168 | 169 | #cat with latent vector and decode 170 | #========================================================================== 171 | y = x.unsqueeze(2).expand(x.size(0),x.size(1), rand_grid.size(2)).contiguous() 172 | y = torch.cat( (rand_grid, y), 1).contiguous() 173 | outs.append(self.decoder[i](y)) 174 | #========================================================================== 175 | 176 | return torch.cat(outs,2).transpose(2,1).contiguous(), patches 177 | 178 | 179 | class AtlasNetLinAdj(nn.Module): 180 | """Atlas net auto encoder""" 181 | 182 | def __init__(self, options): 183 | 184 | super(AtlasNetLinAdj, self).__init__() 185 | 186 | self.npoint = options.npoint 187 | self.npatch = options.npatch 188 | self.nlatent = options.nlatent 189 | self.patchDim = options.patchDim 190 | 191 | #encoder and decoder modules 192 | #============================================================================== 193 | self.encoder = PointNetfeat(self.npoint,self.nlatent) 194 | self.linearTransformMatrix = nn.ModuleList(linearAdj(dim=self.patchDim,nlatent=self.nlatent) for i in range(0,self.npatch)) 195 | #============================================================================== 196 | 197 | def forward(self, x): 198 | 199 | #encoder 200 | #============================================================================== 201 | x = self.encoder(x.transpose(2,1).contiguous()) 202 | #============================================================================== 203 | 204 | outs = [] 205 | patches = [] 206 | for i in range(0,self.npatch): 207 | 208 | #random patch 209 | #========================================================================== 210 | rand_grid = torch.FloatTensor(x.size(0),self.patchDim,self.npoint//self.npatch).cuda() 211 | rand_grid.data.uniform_(0,1) 212 | rand_grid[:,2:,:] = 0 213 | patches.append(rand_grid[0].transpose(1,0)) 214 | #========================================================================== 215 | 216 | #cat with latent vector and decode 217 | #========================================================================== 218 | R,T = self.linearTransformMatrix[i](x.unsqueeze(2)) 219 | rand_grid = torch.bmm(rand_grid.transpose(2,1),R) + T 220 | outs.append(rand_grid) 221 | #========================================================================== 222 | 223 | return torch.cat(outs,2).transpose(2,1).contiguous(), patches 224 | 225 | 226 | class PointTransMLPAdj(nn.Module): 227 | """Atlas net auto encoder""" 228 | 229 | def __init__(self, options): 230 | 231 | super(PointTransMLPAdj, self).__init__() 232 | 233 | self.npoint = options.npoint 234 | self.npatch = options.npatch 235 | self.nlatent = options.nlatent 236 | self.nbatch = options.nbatch 237 | self.dim = options.patchDim 238 | 239 | #encoder and decoder modules 240 | #============================================================================== 241 | self.encoder = PointNetfeat(self.npoint,self.nlatent) 242 | self.decoder = nn.ModuleList([mlpAdj(nlatent = self.dim + self.nlatent) for i in range(0,self.npatch)]) 243 | #============================================================================== 244 | 245 | #patch 246 | #============================================================================== 247 | self.grid = [] 248 | for patchIndex in range(self.npatch): 249 | patch = torch.nn.Parameter(torch.FloatTensor(1,self.dim,self.npoint//self.npatch)) 250 | patch.data.uniform_(0,1) 251 | patch.data[:,2:,:]=0 252 | self.register_parameter("patch%d"%patchIndex,patch) 253 | self.grid.append(patch) 254 | #============================================================================== 255 | 256 | def forward(self, x): 257 | 258 | #encoder 259 | #============================================================================== 260 | x = self.encoder(x.transpose(2,1).contiguous()) 261 | #============================================================================== 262 | 263 | outs = [] 264 | patches = [] 265 | 266 | for i in range(0,self.npatch): 267 | 268 | #random planar patch 269 | #========================================================================== 270 | rand_grid = self.grid[i].expand(x.size(0),-1,-1) 271 | patches.append(rand_grid[0].transpose(1,0)) 272 | #========================================================================== 273 | 274 | #cat with latent vector and decode 275 | #========================================================================== 276 | y = x.unsqueeze(2).expand(x.size(0),x.size(1), rand_grid.size(2)).contiguous() 277 | y = torch.cat( (rand_grid, y), 1).contiguous() 278 | outs.append(self.decoder[i](y)) 279 | #========================================================================== 280 | 281 | return torch.cat(outs,2).transpose(2,1).contiguous(), patches 282 | 283 | 284 | class PointTransLinAdj(nn.Module): 285 | """Ours auto encoder""" 286 | 287 | def __init__(self, options): 288 | 289 | super(PointTransLinAdj, self).__init__() 290 | 291 | self.npoint = options.npoint 292 | self.npatch = options.npatch 293 | self.nlatent = options.nlatent 294 | self.patchDim = options.patchDim 295 | self.patchDeformDim = options.patchDeformDim 296 | self.nbatch = options.nbatch 297 | 298 | #encoder decoder and patch deformation module 299 | #============================================================================== 300 | self.encoder = PointNetfeat(self.npoint,self.nlatent) 301 | self.linearTransformMatrix = nn.ModuleList(linearAdj(dim=self.patchDim,nlatent=self.nlatent) for i in range(0,self.npatch)) 302 | self.patchDeformation = nn.ModuleList(patchDeformationMLP(patchDim = self.patchDim, patchDeformDim = self.patchDeformDim) for i in range(0,self.npatch)) 303 | #============================================================================== 304 | 305 | #patch 306 | #============================================================================== 307 | self.grid = [] 308 | for patchIndex in range(self.npatch): 309 | patch = torch.nn.Parameter(torch.FloatTensor(1,self.patchDim,self.npoint//self.npatch)) 310 | patch.data.uniform_(0,1) 311 | patch.data[:,2:,:]=0 312 | self.register_parameter("patch%d"%patchIndex,patch) 313 | self.grid.append(patch) 314 | #============================================================================== 315 | 316 | def forward(self, x): 317 | 318 | #encode data 319 | #============================================================================== 320 | x = self.encoder(x.transpose(2,1).contiguous()) 321 | #============================================================================== 322 | 323 | outs = [] 324 | patches = [] 325 | for i in range(0,self.npatch): 326 | 327 | #random planar patch 328 | #========================================================================== 329 | rand_grid =self.grid[i].expand(x.size(0),-1,-1).transpose(2,1) 330 | patches.append(rand_grid[0]) 331 | #========================================================================== 332 | 333 | #apply linear tranformation to the patch 334 | #========================================================================== 335 | R,T = self.linearTransformMatrix[i](x.unsqueeze(2)) 336 | rand_grid = torch.bmm(rand_grid,R) + T 337 | outs.append(rand_grid) 338 | #========================================================================== 339 | 340 | return torch.cat(outs,1).contiguous().contiguous(), patches 341 | 342 | 343 | class PatchDeformMLPAdj(nn.Module): 344 | """Atlas net auto encoder""" 345 | 346 | def __init__(self, options): 347 | 348 | super(PatchDeformMLPAdj, self).__init__() 349 | 350 | self.npoint = options.npoint 351 | self.npatch = options.npatch 352 | self.nlatent = options.nlatent 353 | self.nbatch = options.nbatch 354 | self.patchDim = options.patchDim 355 | self.patchDeformDim = options.patchDeformDim 356 | 357 | #encoder decoder and patch deformation module 358 | #============================================================================== 359 | self.encoder = PointNetfeat(self.npoint,self.nlatent) 360 | self.decoder = nn.ModuleList([mlpAdj(nlatent = self.patchDeformDim + self.nlatent) for i in range(0,self.npatch)]) 361 | self.patchDeformation = nn.ModuleList(patchDeformationMLP(patchDim = self.patchDim, patchDeformDim = self.patchDeformDim) for i in range(0,self.npatch)) 362 | #============================================================================== 363 | 364 | def forward(self, x): 365 | 366 | #encoder 367 | #============================================================================== 368 | x = self.encoder(x.transpose(2,1).contiguous()) 369 | #============================================================================== 370 | 371 | outs = [] 372 | patches = [] 373 | for i in range(0,self.npatch): 374 | 375 | #random planar patch 376 | #========================================================================== 377 | rand_grid = torch.FloatTensor(x.size(0),self.patchDim,self.npoint//self.npatch).cuda() 378 | rand_grid.data.uniform_(0,1) 379 | rand_grid[:,2:,:] = 0 380 | rand_grid = self.patchDeformation[i](rand_grid.contiguous()) 381 | patches.append(rand_grid[0].transpose(1,0)) 382 | #========================================================================== 383 | 384 | #cat with latent vector and decode 385 | #========================================================================== 386 | y = x.unsqueeze(2).expand(x.size(0),x.size(1), rand_grid.size(2)).contiguous() 387 | y = torch.cat( (rand_grid, y), 1).contiguous() 388 | outs.append(self.decoder[i](y)) 389 | #========================================================================== 390 | 391 | return torch.cat(outs,2).transpose(2,1).contiguous(), patches 392 | 393 | 394 | class PatchDeformLinAdj(nn.Module): 395 | """Ours auto encoder""" 396 | 397 | def __init__(self, options): 398 | 399 | super(PatchDeformLinAdj, self).__init__() 400 | 401 | self.npoint = options.npoint 402 | self.npatch = options.npatch 403 | self.nlatent = options.nlatent 404 | self.patchDim = options.patchDim 405 | self.patchDeformDim = options.patchDeformDim 406 | 407 | #encoder decoder and patch deformation module 408 | #============================================================================== 409 | self.encoder = PointNetfeat(self.npoint,self.nlatent) 410 | self.linearTransformMatrix = nn.ModuleList(linearAdj(dim = self.patchDeformDim,nlatent=self.nlatent) for i in range(0,self.npatch)) 411 | self.patchDeformation = nn.ModuleList(patchDeformationMLP(patchDim = self.patchDim, patchDeformDim = self.patchDeformDim) for i in range(0,self.npatch)) 412 | #============================================================================== 413 | 414 | def forward(self, x): 415 | 416 | #encode data 417 | #============================================================================== 418 | x = self.encoder(x.transpose(2,1).contiguous()) 419 | #============================================================================== 420 | 421 | outs = [] 422 | patches = [] 423 | 424 | for i in range(0,self.npatch): 425 | 426 | #random planar patch 427 | #========================================================================== 428 | rand_grid = torch.FloatTensor(x.size(0),self.patchDim,self.npoint//self.npatch).cuda() 429 | rand_grid.data.uniform_(0,1) 430 | rand_grid[:,2:,:] = 0 431 | #========================================================================== 432 | 433 | #deform the planar patch 434 | #========================================================================== 435 | rand_grid = self.patchDeformation[i](rand_grid.contiguous()).transpose(2,1) 436 | patches.append(rand_grid[0]) 437 | #========================================================================== 438 | 439 | #apply linear tranformation to the patch 440 | #========================================================================== 441 | R,T = self.linearTransformMatrix[i](x.unsqueeze(2)) 442 | rand_grid = torch.bmm(rand_grid,R) + T 443 | outs.append(rand_grid) 444 | #========================================================================== 445 | 446 | return torch.cat(outs,1).contiguous().contiguous(), patches 447 | 448 | class MODEL_LIST: 449 | """list of all the model""" 450 | def __init__(self, options): 451 | if options.adjust == 'mlp': 452 | self.models = {'AtlasNet':AtlasNet,'PointTranslation':PointTransMLPAdj,'PatchDeformation':PatchDeformMLPAdj} 453 | elif options.adjust == 'linear': 454 | self.models = {'AtlasNet':AtlasNetLinAdj,'PointTranslation':PointTransLinAdj,'PatchDeformation':PatchDeformLinAdj} 455 | else: 456 | print(colors.FAIL,"ERROR please select the model from : ",colors.ENDC) 457 | print(" > mlp") 458 | print(" > linear") 459 | exit() 460 | 461 | self.type = self.models.keys() 462 | 463 | def load(self,options): 464 | 465 | return self.models[options.model](options).cuda() 466 | --------------------------------------------------------------------------------