├── README.md
├── Teaser.png
├── binvox_rw.py
├── complete_train.sh
├── deform_train.sh
├── joint_test.sh
├── joint_train.sh
├── main.py
├── models.py
├── patch_train.sh
├── pc_util.py
├── splits
│   ├── compose2.py
│   ├── compose_total.py
│   ├── content_boat_test.txt
│   ├── content_boat_train.txt
│   ├── content_cabinet_test.txt
│   ├── content_cabinet_train.txt
│   ├── content_car_test.txt
│   ├── content_car_train.txt
│   ├── content_chair_test.txt
│   ├── content_chair_train.txt
│   ├── content_couch_test.txt
│   ├── content_couch_train.txt
│   ├── content_lamp_test.txt
│   ├── content_lamp_train.txt
│   ├── content_plane_test.txt
│   ├── content_plane_train.txt
│   ├── content_table_test.txt
│   ├── content_table_train.txt
│   └── merge_splits.py
├── train_complete.py
├── train_deform.py
├── train_joint.py
├── train_patch.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# PatchRD
Code for ECCV 2022 paper PatchRD: Detail-Preserving Shape Completion by Learning Patch Retrieval and Deformation. ([PDF](https://arxiv.org/pdf/2207.11790.pdf))



## Installation
Install [PyTorch](https://pytorch.org/get-started/locally/). You need access to a GPU. The code is tested with Ubuntu 18.04, PyTorch v1.9.0 and CUDA 10.2.

Install [ChamferDistancePytorch](https://github.com/ThibaultGROUEIX/ChamferDistancePytorch) by following the instructions in its GitHub repository.

Install [psbody](https://github.com/MPI-IS/mesh) for mesh reading and writing.

Install the following Python dependencies (with `pip install`):

    numpy
    scipy
    h5py
    scikit-learn
    scikit-image
    opencv-python
    kornia
    torch==1.9.0

## Data Download
You can download the dataset [here](https://www.dropbox.com/s/phiw7kjw7d1dfu4/data.zip?dl=0). It is a voxelized dataset covering 8 ShapeNet classes. The train/test splits are in `./splits`.
To evaluate the chamfer distance, we convert our output to a point cloud and use the 16384-point clouds from the [Completion3D](https://completion3d.stanford.edu/) dataset as the ground truth. You can download them either from the official website or directly [here](https://www.dropbox.com/s/e1g071m10xm7zug/completion3d.zip?dl=0).


## Demo with Pre-trained Models
You can download the pre-trained models for the chair category [here](https://www.dropbox.com/s/y6yzehq39ereekx/checkpoints_chair.zip?dl=0).

To run the completion pipeline, first **uncomment `--dump_deform`, `--mode test` and `--small_dataset` in `patch_train.sh` and `deform_train.sh`, then uncomment `--small_dataset` in `joint_test.sh`**.

Then run:

```
./patch_train.sh
```
```
./deform_train.sh
```
```
./joint_test.sh
```
You will get the input and output meshes for a small set of samples in the `samples_joint` folder.



## Training
There are four steps to train our framework.
The default settings train the model for the chair category.
You can change the arguments `--data_content` and `--data_dir` in each bash script to train on other categories.

### Step 1: Coarse Completion
This stage takes a partial shape as input and outputs the coarse full shape (a 4x-downsampled version of the detailed full shape).
```
./complete_train.sh
```

### Step 2: Retrieval Learning
First, train the patch encoder to learn feature embeddings for pairs of coarse and detailed patches.
```
./patch_train.sh
```
Then dump the intermediate retrieval results, which are needed to train the deformation. **Uncomment `--dump_deform` in `patch_train.sh`**, then run

```
./patch_train.sh
```

### Step 3: Initial Deformation Learning
This step learns the initial deformation used by the joint learning stage. Run

```
./deform_train.sh
```
Then dump the intermediate initial deformation results, which are needed to train the deformation and blending stage. **Uncomment `--dump_deform` in `deform_train.sh`**, then run

```
./deform_train.sh
```

### Step 4: Deformation and Blending
This step jointly learns the deformation and blending. Run
```
./joint_train.sh
```

## Evaluation
After training all four networks, run the following script to complete a randomly cropped shape. The input and output meshes are written to `./samples_joint`.
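With `--compute_cd` (enabled in `joint_test.sh`), the output is converted to a point cloud and compared against the 16384-point Completion3D ground truth using the chamfer distance. The repository relies on ChamferDistancePytorch for this; the NumPy function below is only a minimal brute-force sketch of one common convention (mean squared nearest-neighbour distance in each direction, summed). The helper name `chamfer_distance` is hypothetical, and since other works use unsquared distances or different scaling, values from this sketch are not guaranteed to match the numbers the evaluation script reports.

```python
import numpy as np


def chamfer_distance(p: np.ndarray, q: np.ndarray) -> float:
    """Symmetric chamfer distance between point clouds p (N, 3) and q (M, 3).

    Convention assumed here (hypothetical helper, not the repository's code):
    mean squared distance from each point to its nearest neighbour in the
    other cloud, summed over both directions.
    """
    # Pairwise squared distances, shape (N, M); brute force, O(N*M) memory.
    diff = p[:, None, :] - q[None, :, :]
    d2 = np.einsum("nmk,nmk->nm", diff, diff)
    # Nearest neighbour of every p-point in q (rows) and of every q-point in p (columns).
    return float(d2.min(axis=1).mean() + d2.min(axis=0).mean())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = rng.random((1024, 3))
    print(chamfer_distance(a, a))  # identical clouds: 0.0
    print(chamfer_distance(a, a + np.array([0.1, 0.0, 0.0])))  # small shift: small positive value
```

The intermediate arrays grow with N·M, so two full 16384-point clouds would need several gigabytes; for full-size clouds a KD-tree such as `scipy.spatial.cKDTree` or a GPU kernel is the usual choice. The evaluation itself is launched with: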
```
./joint_test.sh
```


## Citation
If you find our work useful in your research, please consider citing:

    @inproceedings{sun2022patchrd,
      author = {Sun, Bo and Kim, Vladimir (Vova) and Huang, Qixing and Aigerman, Noam and Chaudhuri, Siddhartha},
      title = {PatchRD: Detail-Preserving Shape Completion by Learning Patch Retrieval and Deformation},
      booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
      year = {2022}
    }

--------------------------------------------------------------------------------
/Teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/PatchRD/223af88e57f31daf2a016683d2804921fab903e2/Teaser.png
--------------------------------------------------------------------------------
/binvox_rw.py:
--------------------------------------------------------------------------------
# Copyright (C) 2012 Daniel Maturana
# This file is part of binvox-rw-py.
#
# binvox-rw-py is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# binvox-rw-py is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>.
#

"""
Binvox to Numpy and back.


>>> import numpy as np
>>> import binvox_rw
>>> with open('chair.binvox', 'rb') as f:
...     m1 = binvox_rw.read_as_3d_array(f)
...
>>> m1.dims
[32, 32, 32]
>>> m1.scale
41.133000000000003
>>> m1.translate
[0.0, 0.0, 0.0]
>>> with open('chair_out.binvox', 'wb') as f:
...     m1.write(f)
...
>>> with open('chair_out.binvox', 'rb') as f:
...     m2 = binvox_rw.read_as_3d_array(f)
...
>>> m1.dims==m2.dims
True
>>> m1.scale==m2.scale
True
>>> m1.translate==m2.translate
True
>>> np.all(m1.data==m2.data)
True

>>> with open('chair.binvox', 'rb') as f:
...     md = binvox_rw.read_as_3d_array(f)
...
>>> with open('chair.binvox', 'rb') as f:
...     ms = binvox_rw.read_as_coord_array(f)
...
>>> data_ds = binvox_rw.dense_to_sparse(md.data)
>>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32)
>>> np.all(data_sd==md.data)
True
>>> # the ordering of elements returned by numpy.nonzero changes with axis
>>> # ordering, so to compare for equality we first lexically sort the voxels.
>>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)])
True
"""

import numpy as np

class Voxels(object):
    """ Holds a binvox model.
    data is either a three-dimensional numpy boolean array (dense representation)
    or a two-dimensional numpy float array (coordinate representation).

    dims, translate and scale are the model metadata.

    dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model.

    scale and translate relate the voxels to the original model coordinates.

    To translate voxel coordinates i, j, k to original coordinates x, y, z:

    x_n = (i+.5)/dims[0]
    y_n = (j+.5)/dims[1]
    z_n = (k+.5)/dims[2]
    x = scale*x_n + translate[0]
    y = scale*y_n + translate[1]
    z = scale*z_n + translate[2]

    """

    def __init__(self, data, dims, translate, scale, axis_order):
        self.data = data
        self.dims = dims
        self.translate = translate
        self.scale = scale
        assert (axis_order in ('xzy', 'xyz'))
        self.axis_order = axis_order

    def clone(self):
        data = self.data.copy()
        dims = self.dims[:]
        translate = self.translate[:]
        return Voxels(data, dims, translate, self.scale, self.axis_order)

    def write(self, fp):
        write(self, fp)

def read_header(fp):
    """ Read binvox header. Mostly meant for internal use.
    """
    line = fp.readline().strip()
    if not line.startswith(b'#binvox'):
        raise IOError('Not a binvox file')
    dims = list(map(int, fp.readline().strip().split(b' ')[1:]))
    translate = list(map(float, fp.readline().strip().split(b' ')[1:]))
    scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0]
    # consume the 'data' line that precedes the run-length-encoded voxels
    line = fp.readline()
    return dims, translate, scale

def read_as_3d_array(fp, fix_coords=True):
    """ Read binary binvox format as array.

    Returns the model with accompanying metadata.

    Voxels are stored in a three-dimensional numpy array, which is simple and
    direct, but may use a lot of memory for large models. (Storage requirements
    are d^3 bytes, where d is the side length of the binvox model; numpy
    boolean arrays use one byte per element.)

    Doesn't do any checks on input except for the '#binvox' line.
    """
    dims, translate, scale = read_header(fp)
    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
    # if just using reshape() on the raw data:
    # indexing the array as array[i,j,k], the indices map into the
    # coords as:
    # i -> x
    # j -> z
    # k -> y
    # if fix_coords is true, then data is rearranged so that
    # mapping is
    # i -> x
    # j -> y
    # k -> z
    values, counts = raw_data[::2], raw_data[1::2]
    # np.bool was removed in NumPy 1.24; the builtin bool is equivalent here
    data = np.repeat(values, counts).astype(bool)
    data = data.reshape(dims)
    if fix_coords:
        # xzy to xyz TODO the right thing
        data = np.transpose(data, (0, 2, 1))
        axis_order = 'xyz'
    else:
        axis_order = 'xzy'
    return Voxels(data, dims, translate, scale, axis_order)

def read_as_coord_array(fp, fix_coords=True):
    """ Read binary binvox format as coordinates.

    Returns binvox model with voxels in a "coordinate" representation, i.e. a
    3 x N array where N is the number of nonzero voxels. Each column
    corresponds to a nonzero voxel and the 3 rows are its (x, y, z)
    coordinates if fix_coords is True, or (x, z, y) otherwise. (The odd
    default ordering is due to the way the binvox format lays out data.)
    Note that coordinates refer to the binvox voxels, without any
    scaling or translation.

    Use this to save memory if your model is very sparse (mostly empty).

    Doesn't do any checks on input except for the '#binvox' line.
    """
    dims, translate, scale = read_header(fp)
    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)

    values, counts = raw_data[::2], raw_data[1::2]

    sz = np.prod(dims)
    index, end_index = 0, 0
    end_indices = np.cumsum(counts)
    indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)

    values = values.astype(bool)
    indices = indices[values]
    end_indices = end_indices[values]

    nz_voxels = []
    for index, end_index in zip(indices, end_indices):
        nz_voxels.extend(range(index, end_index))
    nz_voxels = np.array(nz_voxels)
    # TODO are these dims correct?
    # according to docs,
    # index = x * wxh + z * width + y; // wxh = width * height = d * d

    # integer division (//): under Python 3, '/' would give fractional coordinates
    x = nz_voxels // (dims[0]*dims[1])
    zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y
    z = zwpy // dims[0]
    y = zwpy % dims[0]
    if fix_coords:
        data = np.vstack((x, y, z))
        axis_order = 'xyz'
    else:
        data = np.vstack((x, z, y))
        axis_order = 'xzy'

    #return Voxels(data, dims, translate, scale, axis_order)
    return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)

def dense_to_sparse(voxel_data, dtype=int):
    """ From dense representation to sparse (coordinate) representation.
    No coordinate reordering.
    """
    if voxel_data.ndim!=3:
        raise ValueError('voxel_data is wrong shape; should be 3D array.')
    return np.asarray(np.nonzero(voxel_data), dtype)

def sparse_to_dense(voxel_data, dims, dtype=bool):
    if voxel_data.ndim!=2 or voxel_data.shape[0]!=3:
        raise ValueError('voxel_data is wrong shape; should be 3xN array.')
    if np.isscalar(dims):
        dims = [dims]*3
    dims = np.atleast_2d(dims).T
    # truncate to integers
    xyz = voxel_data.astype(int)
    # discard voxels that fall outside dims
    valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
    xyz = xyz[:,valid_ix]
    out = np.zeros(dims.flatten(), dtype=dtype)
    out[tuple(xyz)] = True
    return out

#def get_linear_index(x, y, z, dims):
    #""" Assuming xzy order. (y increasing fastest.)
    #TODO ensure this is right when dims are not all same
    #"""
    #return x*(dims[1]*dims[2]) + z*dims[1] + y

def write(voxel_model, fp):
    """ Write binary binvox format.

    Note that when saving a model in sparse (coordinate) format, it is first
    converted to dense format.

    Doesn't check if the model is 'sane'.

    fp must be a file opened in binary mode ('wb').
    """
    if voxel_model.data.ndim==2:
        # TODO avoid conversion to dense
        dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims)
    else:
        dense_voxel_data = voxel_model.data

    # the header is ASCII text, but fp is a binary stream, so write bytes
    fp.write(b'#binvox 1\n')
    fp.write(('dim '+' '.join(map(str, voxel_model.dims))+'\n').encode('ascii'))
    fp.write(('translate '+' '.join(map(str, voxel_model.translate))+'\n').encode('ascii'))
    fp.write(('scale '+str(voxel_model.scale)+'\n').encode('ascii'))
    fp.write(b'data\n')
    if not voxel_model.axis_order in ('xzy', 'xyz'):
        raise ValueError('Unsupported voxel model axis order')

    if voxel_model.axis_order=='xzy':
        voxels_flat = dense_voxel_data.flatten()
    elif voxel_model.axis_order=='xyz':
        voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()

    # keep a sort of state machine for writing run length encoding;
    # each run is stored as a (value, count) byte pair with count <= 255
    state = voxels_flat[0]
    ctr = 0
    for c in voxels_flat:
        if c==state:
            ctr += 1
            # if ctr hits max, dump
            if ctr==255:
                fp.write(bytes([int(state), ctr]))
                ctr = 0
        else:
            # if switch state, dump
            fp.write(bytes([int(state), ctr]))
            state = c
            ctr = 1
    # flush out remainders
    if ctr > 0:
        fp.write(bytes([int(state), ctr]))

if __name__ == '__main__':
    import doctest
    doctest.testmod()

--------------------------------------------------------------------------------
/complete_train.sh:
--------------------------------------------------------------------------------
python main.py \
--data_content content_chair \
--data_dir ./data/03001627/ \
--input_size 32 \
--output_size 128 \
--sample_dir samples_complete \
--checkpoint_dir checkpoint_complete \
--gpu 0 \
--epoch 101 \
--lr 1e-4 \
--g_dim 32 \
--w_posi 1.5 \
--w_mask 0.5 \
--csize 26 \
--c_range 26 \
--train_complete \
--model_name coarse_comp \
#--mode test \
#--continue_train \

--------------------------------------------------------------------------------
/deform_train.sh:
--------------------------------------------------------------------------------
python main.py \
--data_content content_chair \
--data_dir ./data/03001627/ \
--input_size 32 \
--output_size 128 \
--sample_dir samples_deform \
--checkpoint_dir checkpoint_deform \
--gpu 0 \
--epoch 100 \
--lr 1e-4 \
--z_dim 3 \
--g_dim 32 \
--K 2 \
--w_dis 3.0 \
--csize 26 \
--c_range 26 \
--max_sample_num 600 \
--train_deform \
--model_name chair_model0 \
--dump_deform_path ./dump_deform/chair/model0 \
#--dump_deform \
#--small_dataset \
#--mode test \
#--continue_train \

--------------------------------------------------------------------------------
/joint_test.sh:
--------------------------------------------------------------------------------
python main.py \
--data_content content_chair \
--data_dir ./data/03001627/ \
--input_size 32 \
--output_size 128 \
--sample_dir samples_joint \
--checkpoint_dir checkpoint_joint \
--log_dir logs/joint \
--gpu 0 \
--epoch 1 \
--lr 2e-4 \
--g_dim 32 \
--gw_dim 32 \
--max_wd_num 400 \
--loc_size 5 \
--wd_size 8 \
--sample_step 2 \
--trans_limit 2.0 \
--w_s 10 \
--train_joint \
--model_name chair_model0 \
--dump_deform_path ./dump_deform/chair/model0 \
--compute_cd \
--mode test \
#--small_dataset \

--------------------------------------------------------------------------------
/joint_train.sh:
--------------------------------------------------------------------------------
python main.py \
--data_content content_chair \
--data_dir ./data/03001627/ \
--input_size 32 \
--output_size 128 \
--sample_dir samples_joint \
--checkpoint_dir checkpoint_joint \
--log_dir logs/joint \
--gpu 0 \
--epoch 1 \
--lr 2e-4 \
--g_dim 32 \
--gw_dim 32 \
--max_wd_num 400 \
--loc_size 5 \
--wd_size 8 \
--sample_step 2 \
--trans_limit 2.0 \
--w_s 10 \
--train_joint \
--model_name chair_model0 \
--dump_deform_path ./dump_deform/chair/model0 \
#--small_dataset \
#--compute_cd \
#--continue_train \

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import argparse

parser = argparse.ArgumentParser()

# training options
parser.add_argument("--train_complete", action="store_true", dest="train_complete", default=False, help="True for training the coarse completor")
parser.add_argument("--train_patch", action="store_true", dest="train_patch", default=False, help="True for training feature embedding for retrieval")
parser.add_argument("--train_deform", action="store_true", dest="train_deform", default=False, help="True for training initial deformation")
parser.add_argument("--train_joint", action="store_true", dest="train_joint", default=False, help="True for jointly training deformation and blending")
parser.add_argument("--mode", action="store", dest="mode", default="train", help="mode: train or test")
parser.add_argument("--gpu", action="store", dest="gpu", default="0", help="Which GPU to use")

# training parameters
parser.add_argument("--epoch", action="store", dest="epoch", default=20, type=int, help="Number of epochs to train")
parser.add_argument("--batch_size", action="store", dest="batch_size", default=1, type=int, help="Batch size [1]")
parser.add_argument("--lr", action="store", dest="lr", default=1e-4, type=float, help="Learning rate")
parser.add_argument("--decay_step",
    action="store", dest="decay_step", default=1, type=int, help="lr decay step")
parser.add_argument("--lr_decay", action="store", dest="lr_decay", default=0.99, type=float, help="lr decay rate")
parser.add_argument("--continue_train", action="store_true", dest="continue_train", default=False, help="If True, continue training, otherwise train from scratch")

# dump and eval options
parser.add_argument("--dump_deform", action="store_true", dest="dump_deform", default=False, help="True for dumping intermediate retrieval and deformation results with --train_patch or --train_deform")
parser.add_argument("--compute_cd", action="store_true", dest="compute_cd", default=False, help="True for evaluating chamfer distance and dumping mesh results with --train_joint")
parser.add_argument("--small_dataset", action="store_true", dest="small_dataset", default=False, help="True to use only four samples in test mode")


# model parameters
parser.add_argument("--g_dim", action="store", dest="g_dim", default=32, type=int, help="Channel dimension for models")
parser.add_argument("--z_dim", action="store", dest="z_dim", default=8, type=int, help="Dimension of the latent feature embedding for retrieval learning")

# data parameters
parser.add_argument("--input_size", action="store", dest="input_size", default=32, type=int, help="Input voxel size")
parser.add_argument("--output_size", action="store", dest="output_size", default=128, type=int, help="Output voxel size")
parser.add_argument("--patch_size", action="store", dest="patch_size", default=18, type=int, help="Patch size for retrieval and deformation")
parser.add_argument("--csize", action="store", dest="csize", default=26, type=int, help="Average size of cropped area")
parser.add_argument("--c_range", action="store", dest="c_range", default=26, type=int, help="Deviation of cropped size")
parser.add_argument("--K",
    action="store", dest="K", default=2, type=int, help="Ratio of randomly sampled patch pairs to similar pairs in retrieval learning")
parser.add_argument("--max_sample_num", action="store", dest="max_sample_num", default=200, type=int, help="Maximum number of patches in retrieval learning")


# paths
parser.add_argument("--data_content", action="store", dest="data_content", help="Data category. See ./splits for all categories")
parser.add_argument("--data_dir", action="store", dest="data_dir", help="Root directory of dataset")
parser.add_argument("--dump_deform_path", action="store", dest="dump_deform_path", help="Directory to dump intermediate retrieval and initial deformation results.")
parser.add_argument("--model_name", action="store", dest="model_name", default="checkpoint", help="Model name.")
parser.add_argument("--checkpoint_dir", action="store", dest="checkpoint_dir", default="checkpoint", help="Directory name to save the checkpoints")
parser.add_argument("--sample_dir", action="store", dest="sample_dir", default="./samples/", help="Directory name to save the output samples")
parser.add_argument("--log_dir", action="store", dest="log_dir", default="./logs/", help="Directory name to save the training logs")


# loss weights
parser.add_argument("--w_mask", action="store", dest="w_mask", default=1, type=float, help="[coarse completion]: weight for the occupied area in the reconstruction loss")
parser.add_argument("--w_posi", action="store", dest="w_posi", default=1, type=float, help="[coarse completion]: weight for the occupied area in the cross entropy loss")
parser.add_argument("--w_ident", action="store", dest="w_ident", default=0, type=float, help="[retrieval learning]: weight for similar pairs")
parser.add_argument("--w_dis", action="store", dest="w_dis", default=1, type=float, help="[initial deformation]: weight for the patch distance prediction")
parser.add_argument("--w_r", action="store", dest="w_r", default=1, type=float, help="[deformation and blending]: weight for reconstruction term")
parser.add_argument("--w_s", action="store", dest="w_s", default=0, type=float, help="[deformation and blending]: weight for smoothness term")

# parameters for training deformation and blending (--train_joint)
parser.add_argument("--max_wd_num", action="store", dest="max_wd_num", default=32, type=int, help="Maximum number of windows")
parser.add_argument("--sample_step", action="store", dest="sample_step", default=32, type=int, help="Sample stride of retrieved patches")
parser.add_argument("--loc_size", action="store", dest="loc_size", default=32, type=int, help="Number of windows in one subvolume")
parser.add_argument("--wd_size", action="store", dest="wd_size", default=8, type=int, help="Window size")
parser.add_argument("--trans_limit", action="store", dest="trans_limit", default=1e-4, type=float, help="Translation limit relative to the initial deformation")
parser.add_argument("--gw_dim", action="store", dest="gw_dim", default=32, type=int, help="Channel dimension for the window encoding branch")

FLAGS = parser.parse_args()

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.gpu


if FLAGS.train_patch:
    from train_patch import MODEL_PATCH
    model = MODEL_PATCH(FLAGS)
    model.train(FLAGS)

elif FLAGS.train_complete:
    from train_complete import MODEL_COMPLETE
    model = MODEL_COMPLETE(FLAGS)
    model.train(FLAGS)

elif FLAGS.train_deform:
    print('***********train deform')
    from train_deform import MODEL_DEFORM
    model = MODEL_DEFORM(FLAGS)
    model.train(FLAGS)

elif FLAGS.train_joint:
    print('***********train joint')
    from train_joint import MODEL_JOINT
    model = MODEL_JOINT(FLAGS)
    model.train(FLAGS)
else:
    print('Invalid training options!')

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
import numpy as np

#cell = 4
#input 256
#output 120 (128-4-4)
#receptive field = 18

# 0 18
#conv 4x4 s1 4 15
#conv 3x3 s2 6 7
#conv 3x3 s1 10 5
#conv 3x3 s1 14 3
#conv 3x3 s1 18 1
#conv 1x1 s1 1 1

# 0 41
#conv 3x3 s1 4 39
#conv 3x3 s2 6 19
#conv 3x3 s1 10 17
#conv 3x3 s1 14 15
#conv 3x3 s2 18 7
#conv 3x3 s1 26 7

leaky_f = 0.02
USE_25=True

class CoarseCompletor_skip(nn.Module):
    def __init__(self, d_dim, ):
        super(CoarseCompletor_skip, self).__init__()
        self.d_dim = d_dim

        self.conv_1 = nn.Conv3d(1, self.d_dim, 5, stride=1, padding=2, bias=True)
        self.conv_2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True) # 64

        self.conv_3 = nn.Conv3d(self.d_dim*2, self.d_dim*2, 3, stride=1, padding=1, bias=True)
        self.conv_4 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True) # 32

        self.conv_5 = nn.Conv3d(self.d_dim*4, self.d_dim*4, 3, stride=1, padding=1, bias=True)
        self.conv_6 = nn.Conv3d(self.d_dim*4, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 16

        self.conv_7 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=1, padding=1, bias=True)
        self.conv_8 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 8

        self.conv_9 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 4
        self.conv_10 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=1, padding=1, bias=True) # 4

        self.dconv_1 = nn.ConvTranspose3d(self.d_dim*8, self.d_dim*8, 4, stride=2, padding=1, bias=True) # 8

        self.dconv_2 = nn.Conv3d(self.d_dim*16, self.d_dim*8, 3, stride=1, padding=1, bias=True)
        self.dconv_3 = nn.ConvTranspose3d(self.d_dim*8, self.d_dim*8, 4, stride=2, padding=1, bias=True) # 16

        self.dconv_4 = nn.Conv3d(self.d_dim*16, self.d_dim*8, 3, stride=1, padding=1, bias=True)
        self.dconv_5 = nn.ConvTranspose3d(self.d_dim*8, self.d_dim*4, 4, stride=2, padding=1, bias=True) # 32

        self.dconv_6 = nn.Conv3d(self.d_dim*8, self.d_dim*4, 3, stride=1, padding=1, bias=True)
        self.dconv_7 = nn.ConvTranspose3d(self.d_dim*4, self.d_dim*2, 4, stride=2, padding=1, bias=True) # 64

        self.dconv_8 = nn.Conv3d(self.d_dim*4, self.d_dim*2, 3, stride=1, padding=1, bias=True)
        self.dconv_9 = nn.ConvTranspose3d(self.d_dim*2, self.d_dim, 4, stride=2, padding=1, bias=True) # 128

        self.dconv_10 = nn.Conv3d(self.d_dim, 1, 3, stride=1, padding=1, bias=True)

    def forward(self, partial_in, is_training=False):
        out = partial_in
        out = self.conv_1(out)
        out1 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.conv_2(out1)
        out2 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.conv_3(out2)
        out3 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.conv_4(out3)
        out4 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.conv_5(out4)
        out5 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.conv_6(out5)
        out6 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.conv_7(out6)
        out7 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.conv_8(out7)
        out8 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.conv_9(out8)
        out9 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.conv_10(out9)
        out10 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.dconv_1(out10)
        outd1 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = torch.cat([outd1, out8], dim=1)
        out = self.dconv_2(out)
        outd2 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.dconv_3(outd2)
        outd3 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = torch.cat([outd3, out6], dim=1)
        out = self.dconv_4(out)
        outd4 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.dconv_5(outd4)
        outd5 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = torch.cat([outd5, out4], dim=1)
        out = self.dconv_6(out)
        outd6 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.dconv_7(outd6)
        outd7 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = torch.cat([outd7, out2], dim=1)
        out = self.dconv_8(out)
        out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.dconv_9(out)
        out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)

        out = self.dconv_10(out)

        #out = torch.max(torch.min(out, out*0.002+0.998), out*0.002-0.998)
        out = torch.sigmoid(out)
        return out


class PatchEncoder(nn.Module):
    def __init__(self, d_dim, z_dim):
        super(PatchEncoder, self).__init__()
        self.d_dim = d_dim
        self.z_dim = z_dim

        self.conv_1 = nn.Conv3d(1, self.d_dim, 4, stride=1, padding=0, bias=True)
        self.conv_2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=0, bias=True)
        self.conv_3 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=1, padding=0, bias=True)
        self.conv_4 = nn.Conv3d(self.d_dim*4, self.d_dim*8, 3, stride=1, padding=0, bias=True)
        self.conv_5 = nn.Conv3d(self.d_dim*8, self.d_dim*16, 3,
stride=2, padding=0, bias=True) 150 | if USE_25: 151 | self.conv_6 = nn.Conv3d(self.d_dim*16, self.d_dim*16, 3, stride=1, padding=1, bias=True) # extra layer 152 | self.conv_7 = nn.Conv3d(self.d_dim*32, self.z_dim, 1, stride=1, padding=0, bias=True) 153 | else: 154 | 155 | self.conv_6 = nn.Conv3d(self.d_dim*16, self.z_dim*2, 1, stride=1, padding=0, bias=True) 156 | self.conv_7 = nn.Conv3d(self.z_dim*2, self.z_dim, 1, stride=1, padding=0, bias=True) 157 | 158 | def forward(self, voxels, return_feat=False, is_training=False): 159 | out = voxels 160 | 161 | out = self.conv_1(out) 162 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 163 | feat1 = out.detach() 164 | 165 | out = self.conv_2(out) 166 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 167 | feat2 = out.detach() 168 | 169 | out = self.conv_3(out) 170 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 171 | feat3 = out.detach() 172 | 173 | out = self.conv_4(out) 174 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 175 | feat4 = out.detach() 176 | 177 | out_1 = self.conv_5(out) 178 | out_1 = F.leaky_relu(out_1, negative_slope=leaky_f, inplace=True) 179 | feat5 = out_1.detach() 180 | 181 | if USE_25: 182 | out_2 = self.conv_6(out_1) 183 | out_2 = F.leaky_relu(out_2, negative_slope=leaky_f, inplace=True) 184 | feat6 = out_2.detach() 185 | out = self.conv_7(torch.cat([out_1, out_2], dim=1)) 186 | else: 187 | out = self.conv_6(out_1) 188 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 189 | out = self.conv_7(out) 190 | 191 | out = torch.max(torch.min(out, out*0.002+0.998), out*0.002-0.998) 192 | # out = torch.sigmoid(out) 193 | if return_feat: 194 | return out, feat1, feat2, feat3, feat4, feat5, feat6 195 | return out 196 | 197 | 198 | def reshape_to_size_torch(v1, v2_shape): 199 | _,_,x1,y1,z1 = v1.shape 200 | x,y,z = v2_shape 201 | new_v1 = v1 202 | padding = np.zeros(6) 203 | if z1 < z: 204 | padding[0], padding[1] = int((z-z1)/2), 
z-z1-int((z-z1)/2) 205 | else: 206 | new_v1 = new_v1[:,:,:, :,int((z1-z)/2):int((z1-z)/2)+z ] 207 | if y1 < y: 208 | padding[2], padding[3] = int((y-y1)/2), y-y1-int((y-y1)/2) 209 | else: 210 | new_v1 = new_v1[:,:,:, int((y1-y)/2):int((y1-y)/2)+y,: ] 211 | if x1 < x: 212 | padding[4], padding[5] = int((x-x1)/2), x-x1-int((x-x1)/2) 213 | else: 214 | new_v1 = new_v1[:,:,int((x1-x)/2):int((x1-x)/2)+x,:,: ] 215 | new_v1 = F.pad(new_v1, tuple(int(p) for p in padding)) # plain ints; np.int8 would overflow for pads > 127 216 | return new_v1 217 | 218 | class PatchDeformer(nn.Module): 219 | def __init__(self, d_dim, z_dim, pre_dis=False, include_coarse=False): 220 | super(PatchDeformer, self).__init__() 221 | self.d_dim = d_dim 222 | self.z_dim = z_dim 223 | self.pred_dis = pre_dis 224 | self.include_coarse = include_coarse 225 | 226 | self.conv_1 = nn.Conv3d(1, self.d_dim, 3, stride=1, padding=1, bias=True) 227 | self.conv_2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True) 228 | self.conv_3 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True) 229 | 230 | self.conv_d1 = nn.Conv3d(1, self.d_dim, 3, stride=1, padding=1, bias=True) 231 | self.conv_d2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True) 232 | self.conv_d3 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True) 233 | 234 | self.fc1 = nn.Linear(self.d_dim*8*5*5*5, 4096) 235 | self.fc2 = nn.Linear(4096, 1024) 236 | self.fc3 = nn.Linear(1024, self.z_dim) 237 | 238 | if self.pred_dis: 239 | self.dis_fc1 = nn.Linear(self.d_dim*8*5*5*5, 2048) 240 | self.dis_fc2 = nn.Linear(2048, 512) 241 | self.dis_fc3 = nn.Linear(512, 1) 242 | 243 | if self.include_coarse: 244 | self.cdis_fc1 = nn.Linear(self.d_dim*4*5*5*5, 1024) 245 | self.cdis_fc2 = nn.Linear(1024, 512) 246 | self.cdis_fc3 = nn.Linear(512, 1) 247 | 248 | self.ddis_fc1 = nn.Linear(self.d_dim*4*5*5*5, 1024) 249 | self.ddis_fc2 = nn.Linear(1024, 512) 250 | self.ddis_fc3 = nn.Linear(512, 1) 251 | 252 | def forward(self, c_vox, d_vox, 
is_training=False): 253 | 254 | out = c_vox 255 | out = self.conv_1(out) 256 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 257 | out = self.conv_2(out) 258 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 259 | out = self.conv_3(out) 260 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 261 | out_1 = out.view(-1, 5*5*5*self.d_dim*4) 262 | 263 | out = d_vox 264 | out = self.conv_d1(out) 265 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 266 | out = self.conv_d2(out) 267 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 268 | out = self.conv_d3(out) 269 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 270 | out_2 = out.view(-1, 5*5*5*self.d_dim*4) 271 | 272 | 273 | out = torch.cat([out_1, out_2], dim=1) 274 | out = self.fc1(out) 275 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 276 | out = self.fc2(out) 277 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 278 | out = self.fc3(out) 279 | # out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 280 | out_de = torch.max(torch.min(out, out*0.002+0.998), out*0.002-0.998) 281 | 282 | if self.pred_dis: 283 | out = torch.cat([out_1, out_2], dim=1) 284 | out = self.dis_fc1(out) 285 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 286 | out = self.dis_fc2(out) 287 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 288 | out = self.dis_fc3(out) 289 | 290 | out_dis = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 291 | # out_dis = torch.max(torch.min(out, out*0.002+0.998), out*0.002) 292 | if not self.include_coarse: 293 | return out_de, out_dis 294 | else: 295 | # out = torch.cat([out_1, out_2], dim=1) 296 | out = out_1 297 | out = self.cdis_fc1(out) 298 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 299 | out = self.cdis_fc2(out) 300 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 301 | out = self.cdis_fc3(out) 302 | out_c_dis 
= F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 303 | 304 | out = out_2 305 | out = self.ddis_fc1(out) 306 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 307 | out = self.ddis_fc2(out) 308 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 309 | out = self.ddis_fc3(out) 310 | out_d_dis = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 311 | 312 | return out_de, out_dis, out_c_dis, out_d_dis 313 | 314 | return out_de 315 | 316 | class JointDeformer(nn.Module): 317 | def __init__(self, d_dim, dw_dim, k_dim,loc_size, use_mean_x=False, wd_size=8): 318 | super(JointDeformer, self).__init__() 319 | self.d_dim = d_dim 320 | self.dw_dim = dw_dim 321 | self.k_dim = k_dim 322 | self.d = loc_size 323 | self.use_mean_x = use_mean_x 324 | self.wd_size = wd_size 325 | 326 | # input 327 | self.conv_1 = nn.Conv3d(2, self.d_dim, 5, stride=1, padding=2, bias=True) 328 | self.conv_2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True) # 64 329 | self.conv_3 = nn.Conv3d(self.d_dim*2, self.d_dim*2, 3, stride=1, padding=1, bias=True) 330 | self.conv_4 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True) # 32 331 | 332 | self.conv_5 = nn.Conv3d(self.d_dim*4, self.d_dim*4, 3, stride=1, padding=1, bias=True) 333 | self.conv_6 = nn.Conv3d(self.d_dim*4, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 16 334 | self.conv_7 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 8 335 | self.conv_8 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 4 336 | 337 | # partial 338 | self.conv_p1 = nn.Conv3d(1, self.d_dim, 5, stride=1, padding=2, bias=True) 339 | self.conv_p2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True) # 64 340 | self.conv_p3 = nn.Conv3d(self.d_dim*2, self.d_dim*2, 3, stride=1, padding=1, bias=True) 341 | self.conv_p4 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True) # 32 342 | 343 | 
self.conv_p5 = nn.Conv3d(self.d_dim*4, self.d_dim*4, 3, stride=1, padding=1, bias=True) 344 | self.conv_p6 = nn.Conv3d(self.d_dim*4, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 16 345 | self.conv_p7 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 8 346 | self.conv_p8 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 4 347 | 348 | # windows 349 | self.conv_w0 = nn.Conv3d(self.k_dim, self.dw_dim, 5, stride=1, padding=2, bias=True) 350 | self.conv_w1 = nn.Conv3d(self.dw_dim, self.dw_dim, 3, stride=1, padding=1, bias=True) 351 | self.conv_w2 = nn.Conv3d(self.dw_dim, self.dw_dim*2, 3, stride=2, padding=1, bias=True) # 12 . 16 352 | self.conv_w3 = nn.Conv3d(self.dw_dim*2, self.dw_dim*2, 3, stride=1, padding=1, bias=True) 353 | self.conv_w4 = nn.Conv3d(self.dw_dim*2, self.dw_dim*4, 3, stride=2, padding=1, bias=True) # 6 . 8 354 | 355 | self.conv_w5 = nn.Conv3d(self.dw_dim*4, self.dw_dim*4, 3, stride=1, padding=1, bias=True) 356 | self.conv_w6 = nn.Conv3d(self.dw_dim*4, self.dw_dim*8, 3, stride=2, padding=1, bias=True) # 3 . 4 357 | if self.wd_size==16: 358 | self.conv_w7 = nn.Conv3d(self.dw_dim*8, self.dw_dim*8, 3, stride=2, padding=1, bias=True) # 3 . 
4 359 | 360 | 361 | self.fc1 = nn.Linear((self.dw_dim*8*self.d*self.d*self.d + self.d_dim*8*4*4*4*2), 4096) 362 | self.fc2 = nn.Linear(4096, 1024) 363 | 364 | # X 365 | self.fc_x1 = nn.Linear(1024, 256) 366 | self.fc_x2 = nn.Linear(256, 1024) 367 | self.fc_x3 = nn.Linear(1024, self.dw_dim*8*self.d*self.d*self.d) 368 | self.dconv_x0 = nn.Conv3d(self.dw_dim*8, self.dw_dim*8, 3, stride=1, padding=1, bias=True) 369 | self.dconv_x1 = nn.ConvTranspose3d(self.dw_dim*8, self.dw_dim*4, 4, stride=2, padding=1, bias=True) # 8 370 | self.dconv_x2 = nn.Conv3d(self.dw_dim*4, self.dw_dim*4, 3, stride=1, padding=1, bias=True) 371 | self.dconv_x3 = nn.ConvTranspose3d(self.dw_dim*4, self.dw_dim*2, 4, stride=2, padding=1, bias=True) # 16 372 | self.dconv_x4 = nn.Conv3d(self.dw_dim*2, self.dw_dim*2, 3, stride=1, padding=1, bias=True) # output must match dconv_x5 input 373 | self.dconv_x5 = nn.ConvTranspose3d(self.dw_dim*2, self.dw_dim, 4, stride=2, padding=1, bias=True) # 32 374 | self.dconv_x6 = nn.Conv3d(self.dw_dim, self.dw_dim, 3, stride=1, padding=1, bias=True) 375 | self.dconv_x7 = nn.Conv3d(self.dw_dim, self.k_dim, 3, stride=1, padding=1, bias=True) 376 | 377 | # D 378 | self.fc_d1 = nn.Linear(1024, 1024) 379 | self.fc_d2 = nn.Linear(1024, 1024) 380 | self.fc_d3 = nn.Linear(1024, self.k_dim*3) 381 | 382 | 383 | def forward(self, partial_in, input_in, windows_in, window_masks_in, loc_mask_in, is_training=False): 384 | leaky_f = 0.02 385 | out = torch.cat([input_in, window_masks_in], dim=1) 386 | out = self.conv_1(out) 387 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 388 | out = self.conv_2(out) 389 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 390 | out = self.conv_3(out) 391 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 392 | out = self.conv_4(out) 393 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 394 | out = self.conv_5(out) 395 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 396 | out = self.conv_6(out) 397 | out = 
F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 398 | out = self.conv_7(out) 399 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 400 | out = self.conv_8(out) 401 | out_coa = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 402 | 403 | out = partial_in 404 | out = self.conv_p1(out) 405 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 406 | out = self.conv_p2(out) 407 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 408 | out = self.conv_p3(out) 409 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 410 | out = self.conv_p4(out) 411 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 412 | out = self.conv_p5(out) 413 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 414 | out = self.conv_p6(out) 415 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 416 | out = self.conv_p7(out) 417 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 418 | out = self.conv_p8(out) 419 | out_p = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 420 | 421 | out = windows_in 422 | out = self.conv_w0(out) 423 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 424 | out = self.conv_w1(out) 425 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 426 | out = self.conv_w2(out) 427 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 428 | out = self.conv_w3(out) 429 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 430 | out = self.conv_w4(out) 431 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 432 | out = self.conv_w5(out) 433 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 434 | out = self.conv_w6(out) 435 | out_w = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 436 | if self.wd_size==16: 437 | out_w = self.conv_w7(out_w) 438 | out_w = F.leaky_relu(out_w, negative_slope=leaky_f, inplace=True) 439 | 440 | out_coa = out_coa.view(1, self.d_dim*8*4*4*4) 441 | out_p = 
out_p.view(1, self.d_dim*8*4*4*4) 442 | out_w =out_w.view(1, self.dw_dim*8*self.d*self.d*self.d) 443 | 444 | out = torch.cat((out_w, out_coa, out_p), dim=1) 445 | out = self.fc1(out) 446 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 447 | out = self.fc2(out) 448 | out_latent = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 449 | 450 | # D 451 | out = self.fc_d1(out_latent) 452 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 453 | out = self.fc_d2(out) 454 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 455 | out = self.fc_d3(out) 456 | # out_D = torch.max(torch.min(out, out*0.002+0.998), out*0.002-0.998) 457 | out_D = torch.clip(out, -1.0, 1.0) 458 | out_D = out_D.view(self.k_dim, 3) 459 | 460 | # X 461 | out = self.fc_x1(out_latent) 462 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 463 | out = self.fc_x2(out) 464 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 465 | out = self.fc_x3(out) 466 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 467 | 468 | out = out.view(1, self.dw_dim*8, self.d,self.d,self.d) 469 | out = self.dconv_x0(out) 470 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 471 | #print('0',out.shape) 472 | out = self.dconv_x1(out) 473 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 474 | #print('1',out.shape) 475 | out = self.dconv_x2(out) 476 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 477 | #print('2',out.shape) 478 | out = self.dconv_x3(out) 479 | #print('3',out.shape) 480 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 481 | out = self.dconv_x4(out) 482 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 483 | #print('4',out.shape) 484 | out = self.dconv_x5(out) 485 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 486 | #print('5',out.shape) 487 | out = self.dconv_x6(out) 488 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True) 489 | 
#print('6',out.shape) 490 | out_X = self.dconv_x7(out) # [1, kdim, vx, vy, vz] 491 | 492 | # out_X = torch.max(torch.min(out_X,out_X*0.002+0.998), out_X*0.002) 493 | # out_X = torch.sigmoid(out_X) 494 | out_X = torch.clip(out_X, 0.0, 1.0) 495 | # loc_mask_in = F.interpolate(loc_mask_in, scale_factor=8, mode='nearest') 496 | out_X = F.avg_pool3d(out_X, kernel_size=8, stride=8) 497 | 498 | if self.use_mean_x: 499 | out_X = out_X*loc_mask_in/(torch.sum(out_X*loc_mask_in, dim=1).unsqueeze(1)+1e-5) 500 | else: 501 | out_X = torch.exp(out_X)*loc_mask_in/(torch.sum(torch.exp(out_X)*loc_mask_in, dim=1).unsqueeze(1)+1e-5) 502 | 503 | #print('out_x', out_X.mean()) 504 | return out_D, out_X 505 | 506 | 507 | -------------------------------------------------------------------------------- /patch_train.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --data_content content_chair \ 3 | --data_dir ./data/03001627/ \ 4 | --input_size 32 \ 5 | --output_size 128 \ 6 | --sample_dir samples_patch \ 7 | --checkpoint_dir checkpoint_patch \ 8 | --gpu 0 \ 9 | --epoch 250 \ 10 | --lr 1e-4 \ 11 | --z_dim 128 \ 12 | --g_dim 32 \ 13 | --K 2 \ 14 | --w_ident 8 \ 15 | --csize 26 \ 16 | --c_range 26 \ 17 | --max_sample_num 400 \ 18 | --train_patch \ 19 | --model_name chair_model0 \ 20 | --dump_deform_path ./dump_deform/chair/model0 \ 21 | #--dump_deform \ 22 | #--small_dataset \ 23 | #--mode test \ 24 | #--continue_train \ 25 | -------------------------------------------------------------------------------- /pc_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ Utility functions for processing point clouds. 7 | Author: Charles R. 
Qi and Or Litany 8 | """ 9 | 10 | import os 11 | import sys 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | sys.path.append(BASE_DIR) 14 | 15 | # Point cloud IO 16 | import numpy as np 17 | try: 18 | from plyfile import PlyData, PlyElement 19 | except: 20 | print("Please install the module 'plyfile' for PLY i/o, e.g.") 21 | print("pip install plyfile") 22 | sys.exit(-1) 23 | 24 | 25 | # Mesh IO 26 | #import trimesh 27 | 28 | import matplotlib.pyplot as pyplot 29 | 30 | # ---------------------------------------- 31 | # Point Cloud Sampling 32 | # ---------------------------------------- 33 | 34 | def random_sampling(pc, num_sample, replace=None, return_choices=False): 35 | """ Input is NxC, output is num_samplexC 36 | """ 37 | if replace is None: replace = (pc.shape[0]num_sample: 138 | pc = random_sampling(pc, num_sample, False) 139 | elif pc.shape[0]num_sample: 185 | pc = random_sampling(pc, num_sample, False) 186 | elif pc.shape[0]np.max(labels)) 219 | 220 | vertex = [] 221 | #colors = [pyplot.cm.jet(i/float(num_classes)) for i in range(num_classes)] 222 | colors = [colormap(i/float(num_classes)) for i in range(num_classes)] 223 | for i in range(N): 224 | c = colors[labels[i]] 225 | c = [int(x*255) for x in c] 226 | vertex.append( (points[i,0],points[i,1],points[i,2],c[0],c[1],c[2]) ) 227 | vertex = np.array(vertex, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4'),('red', 'u1'), ('green', 'u1'),('blue', 'u1')]) 228 | 229 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 230 | PlyData([el], text=True).write(filename) 231 | 232 | def write_ply_rgb(points, colors, out_filename, num_classes=None): 233 | """ Color (N,3) points with RGB colors (N,3) within range [0,255] as OBJ file """ 234 | colors = colors.astype(int) 235 | N = points.shape[0] 236 | fout = open(out_filename, 'w') 237 | for i in range(N): 238 | c = colors[i,:] 239 | fout.write('v %f %f %f %d %d %d\n' % (points[i,0],points[i,1],points[i,2],c[0],c[1],c[2])) 240 | fout.close() 
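Despite its name, `write_ply_rgb` above writes OBJ-style vertex lines of the form `v x y z r g b`, and its `%f` format keeps six decimal places. The following is a minimal round-trip sketch, assuming exactly that line format: `write_obj_rgb` mirrors the loop in `write_ply_rgb`, while `read_obj_rgb` is a hypothetical reader that is not part of `pc_util.py`.

```python
import os
import tempfile

import numpy as np


def write_obj_rgb(points, colors, out_filename):
    """Write one 'v x y z r g b' line per point (same format as write_ply_rgb)."""
    colors = colors.astype(int)
    with open(out_filename, 'w') as fout:
        for p, c in zip(points, colors):
            fout.write('v %f %f %f %d %d %d\n' % (p[0], p[1], p[2], c[0], c[1], c[2]))


def read_obj_rgb(in_filename):
    """Hypothetical reader: parse 'v x y z r g b' lines back into two (N,3) arrays."""
    pts, cols = [], []
    with open(in_filename) as fin:
        for line in fin:
            parts = line.split()
            # Skip anything that is not a colored vertex line (faces, comments, ...).
            if len(parts) == 7 and parts[0] == 'v':
                pts.append([float(v) for v in parts[1:4]])
                cols.append([int(v) for v in parts[4:7]])
    return np.array(pts, dtype=np.float64), np.array(cols, dtype=np.int64)


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    points = rng.random((5, 3))
    colors = rng.integers(0, 256, size=(5, 3))
    path = os.path.join(tempfile.mkdtemp(), 'colored_points.obj')
    write_obj_rgb(points, colors, path)
    pts, cols = read_obj_rgb(path)
    # Coordinates survive up to the six-decimal rounding of '%f'; colors exactly.
    print(np.abs(pts - points).max() < 1e-5, (cols == colors).all())
```

Because the file is plain text with no header, MeshLab and similar viewers should open it as an `.obj` rather than a `.ply`.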
241 | 242 | # ---------------------------------------- 243 | # Simple Point cloud and Volume Renderers 244 | # ---------------------------------------- 245 | 246 | def pyplot_draw_point_cloud(points, output_filename): 247 | """ points is a Nx3 numpy array """ 248 | import matplotlib.pyplot as plt 249 | fig = plt.figure() 250 | ax = fig.add_subplot(111, projection='3d') 251 | ax.scatter(points[:,0], points[:,1], points[:,2]) 252 | ax.set_xlabel('x') 253 | ax.set_ylabel('y') 254 | ax.set_zlabel('z') 255 | fig.savefig(output_filename) 256 | 257 | def pyplot_draw_volume(vol, output_filename): 258 | """ vol is of size vsize*vsize*vsize 259 | output an image to output_filename 260 | """ 261 | points = volume_to_point_cloud(vol) 262 | pyplot_draw_point_cloud(points, output_filename) 263 | 264 | # ---------------------------------------- 265 | # Simple Point manipulations 266 | # ---------------------------------------- 267 | def rotate_point_cloud(points, rotation_matrix=None): 268 | """ Input: (n,3), Output: (n,3) """ 269 | # Rotate about the centroid (default: random rotation around the Z axis); returns a new array. 
270 | if rotation_matrix is None: 271 | rotation_angle = np.random.uniform() * 2 * np.pi 272 | sinval, cosval = np.sin(rotation_angle), np.cos(rotation_angle) 273 | rotation_matrix = np.array([[cosval, sinval, 0], 274 | [-sinval, cosval, 0], 275 | [0, 0, 1]]) 276 | ctr = points.mean(axis=0) 277 | rotated_data = np.dot(points-ctr, rotation_matrix) + ctr 278 | return rotated_data, rotation_matrix 279 | 280 | def rotate_pc_along_y(pc, rot_angle): 281 | ''' Input pc is NxC points with first 3 channels as XYZ; pc is rotated in place and returned. 282 | z is facing forward, x is leftward, y is downward 283 | ''' 284 | cosval = np.cos(rot_angle) 285 | sinval = np.sin(rot_angle) 286 | rotmat = np.array([[cosval, -sinval],[sinval, cosval]]) 287 | pc[:,[0,2]] = np.dot(pc[:,[0,2]], np.transpose(rotmat)) 288 | return pc 289 | 290 | def roty(t): 291 | """Rotation about the y-axis.""" 292 | c = np.cos(t) 293 | s = np.sin(t) 294 | return np.array([[c, 0, s], 295 | [0, 1, 0], 296 | [-s, 0, c]]) 297 | 298 | def roty_batch(t): 299 | """Rotation about the y-axis. 
300 | t: (x1,x2,...xn) 301 | return: (x1,x2,...,xn,3,3) 302 | """ 303 | input_shape = t.shape 304 | output = np.zeros(tuple(list(input_shape)+[3,3])) 305 | c = np.cos(t) 306 | s = np.sin(t) 307 | output[...,0,0] = c 308 | output[...,0,2] = s 309 | output[...,1,1] = 1 310 | output[...,2,0] = -s 311 | output[...,2,2] = c 312 | return output 313 | 314 | def rotz(t): 315 | """Rotation about the z-axis.""" 316 | c = np.cos(t) 317 | s = np.sin(t) 318 | return np.array([[c, -s, 0], 319 | [s, c, 0], 320 | [0, 0, 1]]) 321 | 322 | 323 | 324 | # ---------------------------------------- 325 | # BBox 326 | # ---------------------------------------- 327 | def bbox_corner_dist_measure(crnr1, crnr2): 328 | """ compute distance between box corners to replace iou 329 | Args: 330 | crnr1, crnr2: Nx3 points of box corners in camera axis (y points down) 331 | output is a scalar between 0 and 1 332 | """ 333 | 334 | dist = sys.maxsize 335 | for y in range(4): 336 | rows = ([(x+y)%4 for x in range(4)] + [4+(x+y)%4 for x in range(4)]) 337 | d_ = np.linalg.norm(crnr2[rows, :] - crnr1, axis=1).sum() / 8.0 338 | if d_ < dist: 339 | dist = d_ 340 | 341 | u = sum([np.linalg.norm(x[0,:] - x[6,:]) for x in [crnr1, crnr2]])/2.0 342 | 343 | measure = max(1.0 - dist/u, 0) 344 | print(measure) 345 | 346 | 347 | return measure 348 | 349 | 350 | def point_cloud_to_bbox(points): 351 | """ Extract the axis aligned box from a pcl or batch of pcls 352 | Args: 353 | points: Nx3 points or BxNx3 354 | output is 6 dim: xyz pos of center and 3 lengths 355 | """ 356 | which_dim = len(points.shape) - 2 # first dim if a single cloud and second if batch 357 | mn, mx = points.min(which_dim), points.max(which_dim) 358 | lengths = mx - mn 359 | cntr = 0.5*(mn + mx) 360 | return np.concatenate([cntr, lengths], axis=which_dim) 361 | 362 | def write_bbox(scene_bbox, out_filename): 363 | """Export scene bbox to meshes 364 | Args: 365 | scene_bbox: (N x 6 numpy array): xyz pos of center and 3 lengths 366 | out_filename: 
(string) filename 367 | Note: 368 | To visualize the boxes in MeshLab. 369 | 1. Select the objects (the boxes) 370 | 2. Filters -> Polygon and Quad Mesh -> Turn into Quad-Dominant Mesh 371 | 3. Select Wireframe view. 372 | """ 373 | def convert_box_to_trimesh_fmt(box): 374 | ctr = box[:3] 375 | lengths = box[3:] 376 | trns = np.eye(4) 377 | trns[0:3, 3] = ctr 378 | trns[3,3] = 1.0 379 | box_trimesh_fmt = trimesh.creation.box(lengths, trns) 380 | return box_trimesh_fmt 381 | 382 | scene = trimesh.scene.Scene() 383 | for box in scene_bbox: 384 | scene.add_geometry(convert_box_to_trimesh_fmt(box)) 385 | 386 | mesh_list = trimesh.util.concatenate(scene.dump()) 387 | # save to ply file 388 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type='ply') 389 | 390 | return 391 | 392 | def write_oriented_bbox(scene_bbox, out_filename): 393 | """Export oriented (around Z axis) scene bbox to meshes 394 | Args: 395 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz) 396 | and heading angle around Z axis. 397 | Y forward, X right, Z upward. heading angle of positive X is 0, 398 | heading angle of positive Y is 90 degrees. 
399 | out_filename: (string) filename 400 | """ 401 | def heading2rotmat(heading_angle): 402 | pass 403 | rotmat = np.zeros((3,3)) 404 | rotmat[2,2] = 1 405 | cosval = np.cos(heading_angle) 406 | sinval = np.sin(heading_angle) 407 | rotmat[0:2,0:2] = np.array([[cosval, -sinval],[sinval, cosval]]) 408 | return rotmat 409 | 410 | def convert_oriented_box_to_trimesh_fmt(box): 411 | ctr = box[:3] 412 | lengths = box[3:6] 413 | trns = np.eye(4) 414 | trns[0:3, 3] = ctr 415 | trns[3,3] = 1.0 416 | trns[0:3,0:3] = heading2rotmat(box[6]) 417 | box_trimesh_fmt = trimesh.creation.box(lengths, trns) 418 | return box_trimesh_fmt 419 | 420 | scene = trimesh.scene.Scene() 421 | for box in scene_bbox: 422 | scene.add_geometry(convert_oriented_box_to_trimesh_fmt(box)) 423 | 424 | mesh_list = trimesh.util.concatenate(scene.dump()) 425 | # save to ply file 426 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type='ply') 427 | 428 | return 429 | 430 | def write_oriented_bbox_camera_coord(scene_bbox, out_filename): 431 | """Export oriented (around Y axis) scene bbox to meshes 432 | Args: 433 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz) 434 | and heading angle around Y axis. 435 | Z forward, X rightward, Y downward. heading angle of positive X is 0, 436 | heading angle of negative Z is 90 degrees. 
437 | out_filename: (string) filename 438 | """ 439 | def heading2rotmat(heading_angle): 440 | pass 441 | rotmat = np.zeros((3,3)) 442 | rotmat[1,1] = 1 443 | cosval = np.cos(heading_angle) 444 | sinval = np.sin(heading_angle) 445 | rotmat[0,:] = np.array([cosval, 0, sinval]) 446 | rotmat[2,:] = np.array([-sinval, 0, cosval]) 447 | return rotmat 448 | 449 | def convert_oriented_box_to_trimesh_fmt(box): 450 | ctr = box[:3] 451 | lengths = box[3:6] 452 | trns = np.eye(4) 453 | trns[0:3, 3] = ctr 454 | trns[3,3] = 1.0 455 | trns[0:3,0:3] = heading2rotmat(box[6]) 456 | box_trimesh_fmt = trimesh.creation.box(lengths, trns) 457 | return box_trimesh_fmt 458 | 459 | scene = trimesh.scene.Scene() 460 | for box in scene_bbox: 461 | scene.add_geometry(convert_oriented_box_to_trimesh_fmt(box)) 462 | 463 | mesh_list = trimesh.util.concatenate(scene.dump()) 464 | # save to ply file 465 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type='ply') 466 | 467 | return 468 | 469 | def write_lines_as_cylinders(pcl, filename, rad=0.005, res=64): 470 | """Create lines represented as cylinders connecting pairs of 3D points 471 | Args: 472 | pcl: (N x 2 x 3 numpy array): N pairs of xyz pos 473 | filename: (string) filename for the output mesh (ply) file 474 | rad: radius for the cylinder 475 | res: number of sections used to create the cylinder 476 | """ 477 | scene = trimesh.scene.Scene() 478 | for src,tgt in pcl: 479 | # compute line 480 | vec = tgt - src 481 | M = trimesh.geometry.align_vectors([0,0,1],vec, False) 482 | vec = tgt - src # compute again since align_vectors modifies vec in-place! 
483 | M[:3,3] = 0.5*src + 0.5*tgt 484 | height = np.sqrt(np.dot(vec, vec)) 485 | scene.add_geometry(trimesh.creation.cylinder(radius=rad, height=height, sections=res, transform=M)) 486 | mesh_list = trimesh.util.concatenate(scene.dump()) 487 | trimesh.io.export.export_mesh(mesh_list, '%s.ply'%(filename), file_type='ply') 488 | 489 | # ---------------------------------------- 490 | # Testing 491 | # ---------------------------------------- 492 | if __name__ == '__main__': 493 | print('running some tests') 494 | 495 | ############ 496 | ## Test "write_lines_as_cylinders" 497 | ############ 498 | pcl = np.random.rand(32, 2, 3) 499 | write_lines_as_cylinders(pcl, 'point_connectors') 500 | input() 501 | 502 | 503 | scene_bbox = np.zeros((1,7)) 504 | scene_bbox[0,3:6] = np.array([1,2,3]) # dx,dy,dz 505 | scene_bbox[0,6] = np.pi/4 # 45 degrees 506 | write_oriented_bbox(scene_bbox, 'single_obb_45degree.ply') 507 | ############ 508 | ## Test point_cloud_to_bbox 509 | ############ 510 | pcl = np.random.rand(32, 16, 3) 511 | pcl_bbox = point_cloud_to_bbox(pcl) 512 | assert pcl_bbox.shape == (32, 6) 513 | 514 | pcl = np.random.rand(16, 3) 515 | pcl_bbox = point_cloud_to_bbox(pcl) 516 | assert pcl_bbox.shape == (6,) 517 | 518 | ############ 519 | ## Test corner distance 520 | ############ 521 | crnr1 = np.array([[2.59038660e+00, 8.96107932e-01, 4.73305349e+00], 522 | [4.12281644e-01, 8.96107932e-01, 4.48046631e+00], 523 | [2.97129656e-01, 8.96107932e-01, 5.47344275e+00], 524 | [2.47523462e+00, 8.96107932e-01, 5.72602993e+00], 525 | [2.59038660e+00, 4.41155793e-03, 4.73305349e+00], 526 | [4.12281644e-01, 4.41155793e-03, 4.48046631e+00], 527 | [2.97129656e-01, 4.41155793e-03, 5.47344275e+00], 528 | [2.47523462e+00, 4.41155793e-03, 5.72602993e+00]]) 529 | crnr2 = crnr1 530 | 531 | print(bbox_corner_dist_measure(crnr1, crnr2)) 532 | 533 | 534 | 535 | print('tests PASSED') 536 | 537 | 538 | 539 | -------------------------------------------------------------------------------- 
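`bbox_corner_dist_measure` in `pc_util.py` above scores two 8-corner boxes by the smallest mean corner distance over four cyclic re-orderings of the corners, divides it by the average length of the corner-0-to-corner-6 diagonal of the two boxes, and clips the result at 0. The sketch below restates that computation without the debug `print`, assuming (as in the test data above) that corners 0-3 and 4-7 are two opposite faces listed in the same cyclic order. `box_corners` is a hypothetical helper for building such boxes; `roty` is the same matrix as `pc_util.roty`.

```python
import numpy as np


def box_corners(center, lengths):
    """Hypothetical helper: 8 corners of an axis-aligned box.

    Corners 0-3 form the face at y = cy - dy/2 and corners 4-7 the face at
    y = cy + dy/2, in the same cyclic order, so corners 0 and 6 are diagonal.
    """
    cx, cy, cz = center
    hx, hy, hz = np.asarray(lengths, dtype=float) / 2.0
    face = [(cx + hx, cz + hz), (cx - hx, cz + hz), (cx - hx, cz - hz), (cx + hx, cz - hz)]
    top = [(x, cy - hy, z) for x, z in face]
    bottom = [(x, cy + hy, z) for x, z in face]
    return np.array(top + bottom)


def roty(t):
    """Rotation about the y-axis (same matrix as pc_util.roty)."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])


def corner_dist_measure(crnr1, crnr2):
    """Same computation as bbox_corner_dist_measure, minus the print."""
    best = np.inf
    for shift in range(4):
        # Cyclically re-index both faces of crnr2 by the same shift.
        rows = [(i + shift) % 4 for i in range(4)] + [4 + (i + shift) % 4 for i in range(4)]
        d = np.linalg.norm(crnr2[rows, :] - crnr1, axis=1).sum() / 8.0
        best = min(best, d)
    # Normalize by the mean 0-6 diagonal length of the two boxes.
    u = sum(np.linalg.norm(c[0, :] - c[6, :]) for c in (crnr1, crnr2)) / 2.0
    return max(1.0 - best / u, 0.0)


if __name__ == '__main__':
    a = box_corners((0.0, 0.0, 0.0), (2.0, 2.0, 2.0))
    print(corner_dist_measure(a, a))                       # identical boxes
    print(corner_dist_measure(a, a + [0.5, 0.0, 0.0]))    # shifted along x
    print(corner_dist_measure(a, a @ roty(np.pi / 2).T))  # quarter turn about y
```

The cyclic shift only absorbs re-orderings that correspond to a rotation within the faces (here, about y), so a quarter turn scores 1.0. A box whose two faces are listed in swapped order, or whose corners run in the opposite cyclic direction, would not be recognized as identical.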
/splits/compose2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | #path = '/home/bo/data/shapenet_processed2/02801938' 5 | #path = '/home/bo/data/shapenet_processed/02828884' 6 | path = '/home/bo/data/shapenet_processed' 7 | 8 | #f1 = open('content_bench_train.txt', 'w') 9 | #f2 = open('content_bench_test.txt', 'w') 10 | f1 = open('total_sub_train.txt', 'w') 11 | f2 = open('total_sub_test.txt', 'w') 12 | 13 | f0 = open('content_chair_train.txt', 'r') 14 | lines = f0.readlines() 15 | for l in lines: 16 | f1.write('03001627/%s\n'%(l.strip())) 17 | f0.close() 18 | f0 = open('content_chair_test.txt', 'r') 19 | lines = f0.readlines() 20 | for l in lines[0:200]: 21 | f2.write('03001627/%s\n'%(l.strip())) 22 | f0.close() 23 | 24 | f0 = open('content_table_train.txt', 'r') 25 | lines = f0.readlines() 26 | for l in lines: 27 | f1.write('04379243/%s\n'%(l.strip())) 28 | f0.close() 29 | f0 = open('content_table_test.txt', 'r') 30 | lines = f0.readlines() 31 | for l in lines[0:200]: 32 | f2.write('04379243/%s\n'%(l.strip())) 33 | f0.close() 34 | 35 | 36 | for cat in os.listdir(path): 37 | if cat in ['03001627', '04379243']: 38 | continue 39 | if not cat in ['02801938', '02828884', '03691459']: 40 | continue 41 | if not cat.startswith('0'): 42 | continue 43 | names = os.listdir(os.path.join(path, cat)) 44 | data_len = len(names) 45 | 46 | for i in range(data_len): 47 | if i=int(4*data_len/5) and i