├── README.md
├── Teaser.png
├── binvox_rw.py
├── complete_train.sh
├── deform_train.sh
├── joint_test.sh
├── joint_train.sh
├── main.py
├── models.py
├── patch_train.sh
├── pc_util.py
├── splits
│   ├── compose2.py
│   ├── compose_total.py
│   ├── content_boat_test.txt
│   ├── content_boat_train.txt
│   ├── content_cabinet_test.txt
│   ├── content_cabinet_train.txt
│   ├── content_car_test.txt
│   ├── content_car_train.txt
│   ├── content_chair_test.txt
│   ├── content_chair_train.txt
│   ├── content_couch_test.txt
│   ├── content_couch_train.txt
│   ├── content_lamp_test.txt
│   ├── content_lamp_train.txt
│   ├── content_plane_test.txt
│   ├── content_plane_train.txt
│   ├── content_table_test.txt
│   ├── content_table_train.txt
│   └── merge_splits.py
├── train_complete.py
├── train_deform.py
├── train_joint.py
├── train_patch.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # PatchRD
2 | Code for the ECCV 2022 paper PatchRD: Detail-Preserving Shape Completion by Learning Patch Retrieval and Deformation. ([PDF](https://arxiv.org/pdf/2207.11790.pdf))
3 |
4 |
5 |
6 |
7 | ## Installation
8 | Install [PyTorch](https://pytorch.org/get-started/locally/). A GPU is required. The code has been tested with Ubuntu 18.04, PyTorch v1.9.0 and CUDA 10.2.
9 |
10 | Install [ChamferDistancePytorch](https://github.com/ThibaultGROUEIX/ChamferDistancePytorch) by following the instructions in its GitHub repository.
11 |
12 | Install [psbody](https://github.com/MPI-IS/mesh) for mesh reading and writing.
13 |
14 | Install the following Python dependencies (with `pip install`):
15 |
16 | - numpy
17 | - scipy
18 | - h5py
19 | - scikit-learn (imported as `sklearn`)
20 | - scikit-image (imported as `skimage`)
21 | - opencv-python
22 | - kornia
23 | - torch==1.9.0 (PyTorch, see above)
24 |
25 | ## Data Download
26 | You can download the dataset [here](https://www.dropbox.com/s/phiw7kjw7d1dfu4/data.zip?dl=0). It is a voxelized dataset covering 8 ShapeNet categories. The train/test splits are in `./splits`.
27 | To evaluate the Chamfer distance, we convert our output to a point cloud and use the 16384-point clouds from the [Completion3D](https://completion3d.stanford.edu/) dataset as the ground truth. You can download them either from the official website or directly [here](https://www.dropbox.com/s/e1g071m10xm7zug/completion3d.zip?dl=0).
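The evaluation above compares a point cloud derived from the output voxels with the Completion3D ground truth. The repository relies on ChamferDistancePytorch for this; as a minimal, dependency-light sketch of the same idea, the snippet below converts an occupancy grid into voxel-center points and computes a symmetric Chamfer distance with NumPy. The function names, the normalization to the unit cube, and the "mean squared nearest-neighbor distance in both directions" convention are assumptions; Chamfer conventions differ between papers (squared vs. unsquared, sum vs. mean).

```python
import numpy as np

def voxels_to_points(occ):
    """Return centers of occupied voxels, normalized to [0, 1]^3 (assumed convention)."""
    idx = np.argwhere(occ)                     # (N, 3) integer indices of occupied voxels
    return (idx + 0.5) / np.array(occ.shape)   # voxel centers scaled into the unit cube

def chamfer_distance(a, b):
    """Symmetric Chamfer distance: mean squared nearest-neighbor distance, both directions."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)   # (Na, Nb) pairwise squared distances
    return d2.min(axis=1).mean() + d2.min(axis=0).mean()

if __name__ == "__main__":
    occ = np.zeros((4, 4, 4), dtype=bool)
    occ[1, 2, 3] = True
    pts = voxels_to_points(occ)
    print(pts)                          # [[0.375 0.625 0.875]]
    print(chamfer_distance(pts, pts))   # 0.0
```

The brute-force distance matrix is O(Na·Nb) in memory, which is too large for two 16384-point clouds; a KD-tree (for example `scipy.spatial.cKDTree`) or the GPU implementation is the practical choice at that scale.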
28 |
29 |
30 | ## Demo with Pre-trained Models
31 | You can download the pre-trained models for the chair category [here](https://www.dropbox.com/s/y6yzehq39ereekx/checkpoints_chair.zip?dl=0).
32 |
33 | To run the completion pipeline, first **uncomment `--dump_deform`, `--mode test`, `--small_dataset` in `patch_train.sh` and `deform_train.sh`, then uncomment `--small_dataset` in `joint_test.sh`**.
34 |
35 | Then run:
36 |
37 | ```
38 | ./patch_train.sh
39 | ```
40 | ```
41 | ./deform_train.sh
42 | ```
43 | ```
44 | ./joint_test.sh
45 | ```
46 | The input and output meshes for a small set of samples are written to the `samples_joint` folder.
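The demo requires uncommenting flags such as `--dump_deform`, `--mode test` and `--small_dataset` in the bash scripts. A small hypothetical helper for doing this programmatically is sketched below; `uncomment_flags` is not part of the repository, and it only edits text, so you would read a script such as `deform_train.sh`, pass its contents in, and write the result back.

```python
import re

# Matches lines like "#--mode test \" and captures indentation, the flag text, and the flag name.
FLAG_LINE = re.compile(r"^(\s*)#(--(\w+).*)$")

def uncomment_flags(script_text, flags):
    """Remove the leading '#' from commented-out lines whose flag name is in `flags`."""
    out = []
    for line in script_text.splitlines():
        m = FLAG_LINE.match(line)
        if m and m.group(3) in flags:
            line = m.group(1) + m.group(2)   # keep indentation, drop the '#'
        out.append(line)
    return "\n".join(out) + "\n"

if __name__ == "__main__":
    demo = "python main.py \\\n--gpu 0 \\\n#--mode test \\\n#--continue_train \\\n"
    print(uncomment_flags(demo, {"mode"}))
```

Because the scripts chain arguments with trailing `\` continuations, a line that stays commented ends the command. Flags you uncomment must therefore sit above any line that remains commented, which holds for the demo flags in `deform_train.sh` and `joint_test.sh`.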
47 |
48 |
49 |
50 | ## Training
51 | Training proceeds in four steps.
52 | The default settings train the model for the chair category. To train on another category, change the arguments `--data_content` and `--data_dir` in each bash script.
53 |
54 | ### Step 1: Coarse Completion
55 | This stage takes a partial shape as input and outputs the coarse full shape (a 4x downsampled version of the detailed full shape).
56 | ```
57 | ./complete_train.sh
58 | ```
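The coarse target is a 4x downsampled version of the detailed shape, which matches the default `--output_size 128` and `--input_size 32`. One common way to build such a target is to mark a coarse voxel occupied if any fine voxel in its 4x4x4 block is occupied (max-pooling). The sketch below shows that rule; it is an assumption, and the repository's data loader may downsample differently (for example by averaging and thresholding).

```python
import numpy as np

def downsample_occupancy(vox, factor=4):
    """Max-pool a boolean occupancy grid: a coarse voxel is occupied if any voxel in its block is."""
    d, h, w = vox.shape
    assert d % factor == 0 and h % factor == 0 and w % factor == 0, "shape must be divisible by factor"
    # Split every axis into (coarse index, offset within block), then reduce over the offsets.
    blocks = vox.reshape(d // factor, factor, h // factor, factor, w // factor, factor)
    return blocks.any(axis=(1, 3, 5))

if __name__ == "__main__":
    fine = np.zeros((128, 128, 128), dtype=bool)
    fine[5, 70, 127] = True
    coarse = downsample_occupancy(fine)
    print(coarse.shape, np.argwhere(coarse).tolist())   # (32, 32, 32) [[1, 17, 31]]
```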
59 |
60 | ### Step 2: Retrieval Learning
61 | First, train the patch encoder to learn feature embeddings for pairs of coarse and detailed patches.
62 | ```
63 | ./patch_train.sh
64 | ```
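`PatchEncoder` in `models.py` maps a voxel patch to a `z_dim`-channel embedding (`--z_dim` defaults to 8 in `main.py`). Once coarse and detailed patches share an embedding space, retrieval reduces to a nearest-neighbor search. The brute-force sketch below illustrates this; the function name and `k` are hypothetical, and the repository may use a different distance or search structure.

```python
import numpy as np

def retrieve_topk(query, database, k=5):
    """For each query embedding (M, z), return indices and squared L2 distances
    of the k nearest database embeddings (N, z), nearest first."""
    d2 = ((query[:, None, :] - database[None, :, :]) ** 2).sum(-1)   # (M, N)
    idx = np.argsort(d2, axis=1, kind="stable")[:, :k]               # ascending distance
    return idx, np.take_along_axis(d2, idx, axis=1)

if __name__ == "__main__":
    db = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 5.0, 5.0]])
    q = np.array([[0.9, 0.0, 0.0]])
    idx, dist = retrieve_topk(q, db, k=2)
    print(idx.tolist())   # [[1, 0]]
```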
65 | Then dump the intermediate retrieval results used to train the deformation. **Add `--dump_deform` to `patch_train.sh`**, then run:
66 |
67 | ```
68 | ./patch_train.sh
69 | ```
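The layer comments at the top of `models.py` trace how an 18-voxel patch (the default `--patch_size 18`) shrinks to a single spatial cell in `PatchEncoder`. The sketch below reproduces that trace with the standard convolution output-size formula, floor((n + 2p - k) / s) + 1, using the kernel, stride and padding of `conv_1` to `conv_7` as read from `models.py` with `USE_25=True`.

```python
def conv_out(n, k, s=1, p=0):
    """Spatial output size of a Conv3d along one axis (dilation 1)."""
    return (n + 2 * p - k) // s + 1

# (kernel, stride, padding) of PatchEncoder.conv_1 ... conv_7 with USE_25=True
PATCH_ENCODER_LAYERS = [(4, 1, 0), (3, 2, 0), (3, 1, 0), (3, 1, 0), (3, 2, 0), (3, 1, 1), (1, 1, 0)]

def trace_sizes(n, layers=PATCH_ENCODER_LAYERS):
    """Return the spatial size before the first layer and after every layer."""
    sizes = [n]
    for k, s, p in layers:
        sizes.append(conv_out(sizes[-1], k, s, p))
    return sizes

if __name__ == "__main__":
    print(trace_sizes(18))   # [18, 15, 7, 5, 3, 1, 1, 1]
```

So each 18³ patch yields one `z_dim`-dimensional vector, which is what makes the nearest-neighbor retrieval over patches a plain vector search.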
70 |
71 | ### Step 3: Initial Deformation Learning
72 | This step learns the initial deformation for the joint learning stage. Run:
73 |
74 | ```
75 | ./deform_train.sh
76 | ```
77 | Then dump the intermediate initial deformation results used to train the deformation and blending stage. **Add `--dump_deform` to `deform_train.sh`**, then run:
78 |
79 | ```
80 | ./deform_train.sh
81 | ```
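The deformation parametrization lives in `train_deform.py` and `train_joint.py`, which are not shown here. `main.py` does describe `--trans_limit` as a "translation limit upon the initial deformation", so, as an assumption-laden illustration only, the sketch below clamps a predicted translation refinement to ±`trans_limit` around an initial estimate and applies an integer voxel shift to a patch with zero fill. Both functions are hypothetical; sub-voxel translations would need interpolation (for example trilinear sampling) instead of the integer shift used here.

```python
import numpy as np

def shift_patch(patch, offset):
    """Translate a 3D patch by integer voxel offsets, filling vacated voxels with zeros."""
    out = np.zeros_like(patch)
    src, dst = [], []
    for o, n in zip(offset, patch.shape):
        o = int(o)
        if abs(o) >= n:                       # shifted completely out of the patch
            return out
        src.append(slice(max(0, -o), n - max(0, o)))
        dst.append(slice(max(0, o), n - max(0, -o)))
    out[tuple(dst)] = patch[tuple(src)]
    return out

def refine_translation(initial, delta, trans_limit=2.0):
    """Hypothetical: clamp a predicted refinement to +/- trans_limit around the initial translation."""
    return np.asarray(initial, dtype=float) + np.clip(delta, -trans_limit, trans_limit)

if __name__ == "__main__":
    p = np.zeros((4, 4, 4))
    p[0, 0, 0] = 1
    t = refine_translation([1, 0, 0], [5, -0.5, -3], trans_limit=2.0)   # -> [3, -0.5, -2]
    print(np.argwhere(shift_patch(p, np.rint(t).astype(int))).tolist())
```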
82 |
83 | ### Step 4: Deformation and Blending
84 | This step jointly learns the deformation and blending. Run:
85 | ```
86 | ./joint_train.sh
87 | ```
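The blending step merges many deformed patches that overlap in the output volume. The blending weights are predicted by the network trained in `train_joint.py`, which is not shown here; the sketch below only illustrates the generic operation of combining overlapping patches with per-patch (or per-voxel) weights via a normalized weighted sum. All names are hypothetical, and patches are assumed to lie fully inside the volume.

```python
import numpy as np

def blend_patches(shape, patches, corners, weights):
    """Blend overlapping patches into one volume by a weighted average.

    patches: list of 3D arrays; corners: list of (x, y, z) minimum corners;
    weights: list of scalars or arrays broadcastable to each patch.
    """
    acc = np.zeros(shape, dtype=float)    # weighted sum of patch values
    wsum = np.zeros(shape, dtype=float)   # sum of weights per voxel
    for patch, (x, y, z), w in zip(patches, corners, weights):
        px, py, pz = patch.shape
        wv = np.broadcast_to(np.asarray(w, dtype=float), patch.shape)
        acc[x:x + px, y:y + py, z:z + pz] += wv * patch
        wsum[x:x + px, y:y + py, z:z + pz] += wv
    # Uncovered voxels (zero total weight) stay empty.
    return np.divide(acc, wsum, out=np.zeros(shape), where=wsum > 0)

if __name__ == "__main__":
    a, b = np.ones((2, 2, 2)), np.zeros((2, 2, 2))
    vol = blend_patches((4, 4, 4), [a, b], [(0, 0, 0), (1, 1, 1)], [1.0, 3.0])
    print(vol[0, 0, 0], vol[1, 1, 1], vol[3, 3, 3])   # 1.0 0.25 0.0
```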
88 |
89 | ## Evaluation
90 | After training all four networks, run the following script to complete randomly cropped shapes. The input and output meshes are saved in `./samples_joint`, and the Chamfer distance is computed because `joint_test.sh` sets `--compute_cd`.
91 | ```
92 | ./joint_test.sh
93 | ```
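The evaluation completes shapes from which a random region has been removed. `main.py` exposes `--csize` ("average size of cropped area") and `--c_range` ("deviation of cropped size"), but the cropping code itself is not shown. The sketch below is one hypothetical reading: it removes an axis-aligned cube whose edge length is drawn uniformly from [csize - c_range // 2, csize + c_range // 2]. The sampling rule, the cube shape, and the resolution at which the crop applies are all assumptions.

```python
import numpy as np

def random_crop(vox, csize=26, c_range=26, rng=None):
    """Hypothetical partial-shape generator: zero out a random axis-aligned cube.

    Returns the partial grid, the cube's minimum corner, and its edge length.
    """
    rng = np.random.default_rng() if rng is None else rng
    lo = max(1, csize - c_range // 2)
    hi = csize + c_range // 2
    size = int(min(rng.integers(lo, hi + 1), min(vox.shape)))   # clip to the volume
    start = [int(rng.integers(0, n - size + 1)) for n in vox.shape]
    partial = vox.copy()
    partial[start[0]:start[0] + size, start[1]:start[1] + size, start[2]:start[2] + size] = 0
    return partial, start, size

if __name__ == "__main__":
    full = np.ones((64, 64, 64), dtype=bool)
    partial, start, size = random_crop(full, rng=np.random.default_rng(0))
    print(start, size, int((~partial).sum()) == size ** 3)
```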
94 |
95 |
96 | ## Citation
97 | If you find our work useful in your research, please consider citing:
98 |
99 | @inproceedings{sun2022patchrd,
100 | author = {Sun, Bo and Kim, Vladimir(Vova) and Huang, Qixing and Aigerman, Noam and Chaudhuri, Siddhartha},
101 | title = {PatchRD: Detail-Preserving Shape Completion by Learning Patch Retrieval and Deformation},
102 | booktitle = {Proceedings of the IEEE European Conference on Computer Vision},
103 | year = {2022}
104 | }
105 |
--------------------------------------------------------------------------------
/Teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/PatchRD/223af88e57f31daf2a016683d2804921fab903e2/Teaser.png
--------------------------------------------------------------------------------
/binvox_rw.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2012 Daniel Maturana
2 | # This file is part of binvox-rw-py.
3 | #
4 | # binvox-rw-py is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # binvox-rw-py is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with binvox-rw-py. If not, see .
16 | #
17 |
18 | """
19 | Binvox to Numpy and back.
20 |
21 |
22 | >>> import numpy as np
23 | >>> import binvox_rw
24 | >>> with open('chair.binvox', 'rb') as f:
25 | ... m1 = binvox_rw.read_as_3d_array(f)
26 | ...
27 | >>> m1.dims
28 | [32, 32, 32]
29 | >>> m1.scale
30 | 41.133000000000003
31 | >>> m1.translate
32 | [0.0, 0.0, 0.0]
33 | >>> with open('chair_out.binvox', 'wb') as f:
34 | ... m1.write(f)
35 | ...
36 | >>> with open('chair_out.binvox', 'rb') as f:
37 | ... m2 = binvox_rw.read_as_3d_array(f)
38 | ...
39 | >>> m1.dims==m2.dims
40 | True
41 | >>> m1.scale==m2.scale
42 | True
43 | >>> m1.translate==m2.translate
44 | True
45 | >>> np.all(m1.data==m2.data)
46 | True
47 |
48 | >>> with open('chair.binvox', 'rb') as f:
49 | ... md = binvox_rw.read_as_3d_array(f)
50 | ...
51 | >>> with open('chair.binvox', 'rb') as f:
52 | ... ms = binvox_rw.read_as_coord_array(f)
53 | ...
54 | >>> data_ds = binvox_rw.dense_to_sparse(md.data)
55 | >>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32)
56 | >>> np.all(data_sd==md.data)
57 | True
58 | >>> # the ordering of elements returned by numpy.nonzero changes with axis
59 | >>> # ordering, so to compare for equality we first lexically sort the voxels.
60 | >>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)])
61 | True
62 | """
63 |
64 | import numpy as np
65 |
66 | class Voxels(object):
67 | """ Holds a binvox model.
68 | data is either a three-dimensional numpy boolean array (dense representation)
69 | or a two-dimensional numpy float array (coordinate representation).
70 |
71 | dims, translate and scale are the model metadata.
72 |
73 | dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model.
74 |
75 | scale and translate relate the voxels to the original model coordinates.
76 |
77 | To translate voxel coordinates i, j, k to original coordinates x, y, z:
78 |
79 | x_n = (i+.5)/dims[0]
80 | y_n = (j+.5)/dims[1]
81 | z_n = (k+.5)/dims[2]
82 | x = scale*x_n + translate[0]
83 | y = scale*y_n + translate[1]
84 | z = scale*z_n + translate[2]
85 |
86 | """
87 |
88 | def __init__(self, data, dims, translate, scale, axis_order):
89 | self.data = data
90 | self.dims = dims
91 | self.translate = translate
92 | self.scale = scale
93 | assert (axis_order in ('xzy', 'xyz'))
94 | self.axis_order = axis_order
95 |
96 | def clone(self):
97 | data = self.data.copy()
98 | dims = self.dims[:]
99 | translate = self.translate[:]
100 | return Voxels(data, dims, translate, self.scale, self.axis_order)
101 |
102 | def write(self, fp):
103 | write(self, fp)
104 |
105 | def read_header(fp):
106 | """ Read binvox header. Mostly meant for internal use.
107 | """
108 | line = fp.readline().strip()
109 | if not line.startswith(b'#binvox'):
110 | raise IOError('Not a binvox file')
111 | dims = list(map(int, fp.readline().strip().split(b' ')[1:]))
112 | translate = list(map(float, fp.readline().strip().split(b' ')[1:]))
113 | scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0]
114 | line = fp.readline()
115 | return dims, translate, scale
116 |
117 | def read_as_3d_array(fp, fix_coords=True):
118 | """ Read binary binvox format as array.
119 |
120 | Returns the model with accompanying metadata.
121 |
122 | Voxels are stored in a three-dimensional numpy array, which is simple and
123 | direct, but may use a lot of memory for large models. (Storage requirements
124 | are d^3 bytes, where d is the dimension of the binvox model, since numpy
125 | boolean arrays use one byte per element.)
126 |
127 | Doesn't do any checks on input except for the '#binvox' line.
128 | """
129 | dims, translate, scale = read_header(fp)
130 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
131 | # if just using reshape() on the raw data:
132 | # indexing the array as array[i,j,k], the indices map into the
133 | # coords as:
134 | # i -> x
135 | # j -> z
136 | # k -> y
137 | # if fix_coords is true, then data is rearranged so that
138 | # mapping is
139 | # i -> x
140 | # j -> y
141 | # k -> z
142 | values, counts = raw_data[::2], raw_data[1::2]
143 | data = np.repeat(values, counts).astype(bool)
144 | data = data.reshape(dims)
145 | if fix_coords:
146 | # xzy to xyz TODO the right thing
147 | data = np.transpose(data, (0, 2, 1))
148 | axis_order = 'xyz'
149 | else:
150 | axis_order = 'xzy'
151 | return Voxels(data, dims, translate, scale, axis_order)
152 |
153 | def read_as_coord_array(fp, fix_coords=True):
154 | """ Read binary binvox format as coordinates.
155 |
156 | Returns binvox model with voxels in a "coordinate" representation, i.e. a
157 | 3 x N array where N is the number of nonzero voxels. Each column
158 | corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates
159 | of the voxel. (The odd ordering is due to the way binvox format lays out
160 | data). Note that coordinates refer to the binvox voxels, without any
161 | scaling or translation.
162 |
163 | Use this to save memory if your model is very sparse (mostly empty).
164 |
165 | Doesn't do any checks on input except for the '#binvox' line.
166 | """
167 | dims, translate, scale = read_header(fp)
168 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
169 |
170 | values, counts = raw_data[::2], raw_data[1::2]
171 |
172 | sz = np.prod(dims)
173 | index, end_index = 0, 0
174 | end_indices = np.cumsum(counts)
175 | indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)
176 |
177 | values = values.astype(bool)
178 | indices = indices[values]
179 | end_indices = end_indices[values]
180 |
181 | nz_voxels = []
182 | for index, end_index in zip(indices, end_indices):
183 | nz_voxels.extend(range(index, end_index))
184 | nz_voxels = np.array(nz_voxels)
185 | # TODO are these dims correct?
186 | # according to docs,
187 | # index = x * wxh + z * width + y; // wxh = width * height = d * d
188 |
189 | x = nz_voxels // (dims[0]*dims[1])
190 | zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y
191 | z = zwpy // dims[0]
192 | y = zwpy % dims[0]
193 | if fix_coords:
194 | data = np.vstack((x, y, z))
195 | axis_order = 'xyz'
196 | else:
197 | data = np.vstack((x, z, y))
198 | axis_order = 'xzy'
199 |
200 | #return Voxels(data, dims, translate, scale, axis_order)
201 | return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)
202 |
203 | def dense_to_sparse(voxel_data, dtype=int):
204 | """ From dense representation to sparse (coordinate) representation.
205 | No coordinate reordering.
206 | """
207 | if voxel_data.ndim!=3:
208 | raise ValueError('voxel_data is wrong shape; should be 3D array.')
209 | return np.asarray(np.nonzero(voxel_data), dtype)
210 |
211 | def sparse_to_dense(voxel_data, dims, dtype=bool):
212 | if voxel_data.ndim!=2 or voxel_data.shape[0]!=3:
213 | raise ValueError('voxel_data is wrong shape; should be 3xN array.')
214 | if np.isscalar(dims):
215 | dims = [dims]*3
216 | dims = np.atleast_2d(dims).T
217 | # truncate to integers
218 | xyz = voxel_data.astype(int)
219 | # discard voxels that fall outside dims
220 | valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
221 | xyz = xyz[:,valid_ix]
222 | out = np.zeros(dims.flatten(), dtype=dtype)
223 | out[tuple(xyz)] = True
224 | return out
225 |
226 | #def get_linear_index(x, y, z, dims):
227 | #""" Assuming xzy order. (y increasing fastest.
228 | #TODO ensure this is right when dims are not all same
229 | #"""
230 | #return x*(dims[1]*dims[2]) + z*dims[1] + y
231 |
232 | def write(voxel_model, fp):
233 | """ Write binary binvox format.
234 |
235 | Note that when saving a model in sparse (coordinate) format, it is first
236 | converted to dense format.
237 |
238 | Doesn't check if the model is 'sane'.
239 |
240 | """
241 | if voxel_model.data.ndim==2:
242 | # TODO avoid conversion to dense
243 | dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims)
244 | else:
245 | dense_voxel_data = voxel_model.data
246 |
247 | fp.write(b'#binvox 1\n')
248 | fp.write(('dim '+' '.join(map(str, voxel_model.dims))+'\n').encode('ascii'))
249 | fp.write(('translate '+' '.join(map(str, voxel_model.translate))+'\n').encode('ascii'))
250 | fp.write(('scale '+str(voxel_model.scale)+'\n').encode('ascii'))
251 | fp.write(b'data\n')
252 | if not voxel_model.axis_order in ('xzy', 'xyz'):
253 | raise ValueError('Unsupported voxel model axis order')
254 |
255 | if voxel_model.axis_order=='xzy':
256 | voxels_flat = dense_voxel_data.flatten()
257 | elif voxel_model.axis_order=='xyz':
258 | voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()
259 |
260 | # keep a sort of state machine for writing run length encoding
261 | state = voxels_flat[0]
262 | ctr = 0
263 | for c in voxels_flat:
264 | if c==state:
265 | ctr += 1
266 | # if ctr hits max, dump
267 | if ctr==255:
268 | fp.write(bytes([int(state)]))
269 | fp.write(bytes([ctr]))
270 | ctr = 0
271 | else:
272 | # if switch state, dump
273 | fp.write(bytes([int(state)]))
274 | fp.write(bytes([ctr]))
275 | state = c
276 | ctr = 1
277 | # flush out remainders
278 | if ctr > 0:
279 | fp.write(bytes([int(state)]))
280 | fp.write(bytes([ctr]))
281 |
282 | if __name__ == '__main__':
283 | import doctest
284 | doctest.testmod()
285 |
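`binvox_rw.py` stores voxels as run-length-encoded (value, count) byte pairs with counts capped at 255, and maps voxel indices to model coordinates with the formula in the `Voxels` docstring. The standalone sketch below illustrates both in pure Python, independent of the module; the function names are hypothetical. Splitting long runs at 255 matches how the module's `write` caps its counter, so the decoded result is the same.

```python
def rle_encode(flat):
    """Encode a flat sequence of 0/1 values as binvox-style (value, count) pairs, count <= 255."""
    out = bytearray()
    i = 0
    while i < len(flat):
        value, count = flat[i], 0
        while i < len(flat) and flat[i] == value and count < 255:
            i += 1
            count += 1
        out += bytes([int(value), count])
    return bytes(out)

def rle_decode(data):
    """Inverse of rle_encode: expand (value, count) pairs into a list of 0/1 values."""
    flat = []
    for value, count in zip(data[::2], data[1::2]):
        flat.extend([value] * count)
    return flat

def voxel_to_world(i, j, k, dims, translate, scale):
    """Map voxel indices (i, j, k) to model coordinates (x, y, z) per the Voxels docstring."""
    x_n, y_n, z_n = (i + .5) / dims[0], (j + .5) / dims[1], (k + .5) / dims[2]
    return (scale * x_n + translate[0], scale * y_n + translate[1], scale * z_n + translate[2])

if __name__ == "__main__":
    seq = [0] * 300 + [1] * 3 + [0]
    enc = rle_encode(seq)
    print(list(enc))                 # [0, 255, 0, 45, 1, 3, 0, 1]
    print(rle_decode(enc) == seq)    # True
    print(voxel_to_world(0, 0, 0, [32, 32, 32], [0.0, 0.0, 0.0], 32.0))   # (0.5, 0.5, 0.5)
```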
--------------------------------------------------------------------------------
/complete_train.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --data_content content_chair \
3 | --data_dir ./data/03001627/ \
4 | --input_size 32 \
5 | --output_size 128 \
6 | --sample_dir samples_complete \
7 | --checkpoint_dir checkpoint_complete \
8 | --gpu 0 \
9 | --epoch 101 \
10 | --lr 1e-4 \
11 | --g_dim 32 \
12 | --w_posi 1.5 \
13 | --w_mask 0.5 \
14 | --csize 26 \
15 | --c_range 26 \
16 | --train_complete \
17 | --model_name coarse_comp \
18 | #--mode test \
19 | #--continue_train \
20 |
--------------------------------------------------------------------------------
/deform_train.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --data_content content_chair \
3 | --data_dir ./data/03001627/ \
4 | --input_size 32 \
5 | --output_size 128 \
6 | --sample_dir samples_deform \
7 | --checkpoint_dir checkpoint_deform \
8 | --gpu 0 \
9 | --epoch 100 \
10 | --lr 1e-4 \
11 | --z_dim 3 \
12 | --g_dim 32 \
13 | --K 2 \
14 | --w_dis 3.0 \
15 | --csize 26 \
16 | --c_range 26 \
17 | --max_sample_num 600 \
18 | --train_deform \
19 | --model_name chair_model0 \
20 | --dump_deform_path ./dump_deform/chair/model0 \
21 | #--dump_deform \
22 | #--small_dataset \
23 | #--mode test \
24 | #--continue_train \
25 |
--------------------------------------------------------------------------------
/joint_test.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --data_content content_chair \
3 | --data_dir ./data/03001627/ \
4 | --input_size 32 \
5 | --output_size 128 \
6 | --sample_dir samples_joint \
7 | --checkpoint_dir checkpoint_joint \
8 | --log_dir logs/joint \
9 | --gpu 0 \
10 | --epoch 1 \
11 | --lr 2e-4 \
12 | --g_dim 32 \
13 | --gw_dim 32 \
14 | --max_wd_num 400 \
15 | --loc_size 5 \
16 | --wd_size 8 \
17 | --sample_step 2 \
18 | --trans_limit 2.0 \
19 | --w_s 10 \
20 | --train_joint \
21 | --model_name chair_model0 \
22 | --dump_deform_path ./dump_deform/chair/model0 \
23 | --compute_cd \
24 | --mode test \
25 | #--small_dataset \
26 |
--------------------------------------------------------------------------------
/joint_train.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --data_content content_chair \
3 | --data_dir ./data/03001627/ \
4 | --input_size 32 \
5 | --output_size 128 \
6 | --sample_dir samples_joint \
7 | --checkpoint_dir checkpoint_joint \
8 | --log_dir logs/joint \
9 | --gpu 0 \
10 | --epoch 1 \
11 | --lr 2e-4 \
12 | --g_dim 32 \
13 | --gw_dim 32 \
14 | --max_wd_num 400 \
15 | --loc_size 5 \
16 | --wd_size 8 \
17 | --sample_step 2 \
18 | --trans_limit 2.0 \
19 | --w_s 10 \
20 | --train_joint \
21 | --model_name chair_model0 \
22 | --dump_deform_path ./dump_deform/chair/model0 \
23 | #--small_dataset \
24 | #--compute_cd \
25 | #--continue_train \
26 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 |
5 | parser = argparse.ArgumentParser()
6 |
7 | # training options
8 | parser.add_argument("--train_complete", action="store_true", dest="train_complete", default=False, help="True for training the coarse completor")
9 | parser.add_argument("--train_patch", action="store_true", dest="train_patch", default=False, help="True for training feature embedding for retrieval")
10 | parser.add_argument("--train_deform", action="store_true", dest="train_deform", default=False, help="True for training initial deformation")
11 | parser.add_argument("--train_joint", action="store_true", dest="train_joint", default=False, help="True for jointly training deformation and blending")
12 | parser.add_argument("--mode", action="store", dest="mode", default="train", help="mode: train or test")
13 | parser.add_argument("--gpu", action="store", dest="gpu", default="0", help="Which GPU to use")
14 |
15 | # training parameters
16 | parser.add_argument("--epoch", action="store", dest="epoch", default=20, type=int, help="Epoch to train")
17 | parser.add_argument("--batch_size", action="store", dest="batch_size", default=1, type=int, help="Batch size [1]")
18 | parser.add_argument("--lr", action="store", dest="lr", default=1e-4, type=float, help="Learning rate")
19 | parser.add_argument("--decay_step", action="store", dest="decay_step", default=1, type=int, help="lr decay step")
20 | parser.add_argument("--lr_decay", action="store", dest="lr_decay", default=0.99, type=float, help="lr decay rate")
21 | parser.add_argument("--continue_train", action="store_true", dest="continue_train", default=False, help="If True, continue training, otherwise train from scratch")
22 |
23 | # dump and eval options
24 | parser.add_argument("--dump_deform", action="store_true", dest="dump_deform", default=False, help="True for dumping intermediate retrieval and deformation results in mode --train_patch and --train_deform")
25 | parser.add_argument("--compute_cd", action="store_true", dest="compute_cd", default=False, help="True for evaluating chamfer distance and dumping mesh results in mode --train_joint")
26 | parser.add_argument("--small_dataset", action="store_true", dest="small_dataset", default=False, help="True to use only four samples in test mode")
27 |
28 |
29 | # model parameters
30 | parser.add_argument("--g_dim", action="store", dest="g_dim", default=32, type=int, help="Channel dimension for models")
31 | parser.add_argument("--z_dim", action="store", dest="z_dim", default=8, type=int, help="Dimension of the latent feature embedding for retrieval learning")
32 |
33 | # data parameters
34 | parser.add_argument("--input_size", action="store", dest="input_size", default=32, type=int, help="Input voxel size")
35 | parser.add_argument("--output_size", action="store", dest="output_size", default=128, type=int, help="Output voxel size")
36 | parser.add_argument("--patch_size", action="store", dest="patch_size", default=18, type=int, help="Patch size for retrieval and deformation")
37 | parser.add_argument("--csize", action="store", dest="csize", default=26, type=int, help="Average size of cropped area")
38 | parser.add_argument("--c_range", action="store", dest="c_range", default=26, type=int, help="Deviation of the cropped size")
39 | parser.add_argument("--K", action="store", dest="K", default=2, type=int, help="Ratio of randomly sampled patch pairs vs. similar pairs in retrieval learning")
40 | parser.add_argument("--max_sample_num", action="store", dest="max_sample_num", default=200, type=int, help="Maximum number of patches in retrieval learning")
41 |
42 |
43 | # paths
44 | parser.add_argument("--data_content", action="store", dest="data_content", help="Data category. See ./splits for all categories")
45 | parser.add_argument("--data_dir", action="store", dest="data_dir", help="Root directory of dataset")
46 | parser.add_argument("--dump_deform_path", action="store", dest="dump_deform_path", help="Directory to dump intermediate retrieval and initial deformation results.")
47 | parser.add_argument("--model_name", action="store", dest="model_name", default="checkpoint", help="Model name.")
48 | parser.add_argument("--checkpoint_dir", action="store", dest="checkpoint_dir", default="checkpoint", help="Directory name to save the checkpoints")
49 | parser.add_argument("--sample_dir", action="store", dest="sample_dir", default="./samples/", help="Directory name to save the image samples")
50 | parser.add_argument("--log_dir", action="store", dest="log_dir", default="./logs/", help="Directory name to save the training logs")
51 |
52 |
53 | # loss weights
54 | parser.add_argument("--w_mask", action="store", dest="w_mask", default=1, type=float, help="[coarse completion]: weight for the occupied area in the reconstruction loss")
55 | parser.add_argument("--w_posi", action="store", dest="w_posi", default=1, type=float, help="[coarse completion]: weight for the occupied area in the cross entropy loss")
56 | parser.add_argument("--w_ident", action="store", dest="w_ident", default=0, type=float, help="[retrieval learning]: weight for similar pairs")
57 | parser.add_argument("--w_dis", action="store", dest="w_dis", default=1, type=float, help="[initial deformation]: weight for the patch distance prediction")
58 | parser.add_argument("--w_r", action="store", dest="w_r", default=1, type=float, help="[deformation and blending]: weight for reconstruction term")
59 | parser.add_argument("--w_s", action="store", dest="w_s", default=0, type=float, help="[deformation and blending]: weight for smoothness term")
60 |
61 | # parameters for training deformation and blending (--train_joint)
62 | parser.add_argument("--max_wd_num", action="store", dest="max_wd_num", default=32, type=int, help="Maximum number of windows")
63 | parser.add_argument("--sample_step", action="store", dest="sample_step", default=32, type=int, help="Sample stride of retrieved patches")
64 | parser.add_argument("--loc_size", action="store", dest="loc_size", default=32, type=int, help="Number of windows in one subvolume")
65 | parser.add_argument("--wd_size", action="store", dest="wd_size", default=8, type=int, help="Window size")
66 | parser.add_argument("--trans_limit", action="store", dest="trans_limit", default=1e-4, type=float, help="Translation limit upon the initial deformation")
67 | parser.add_argument("--gw_dim", action="store", dest="gw_dim", default=32, type=int, help="Channel dimension for the window encoding branch")
68 |
69 | FLAGS = parser.parse_args()
70 |
71 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
72 | os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.gpu
73 |
74 |
75 | if FLAGS.train_patch:
76 | from train_patch import MODEL_PATCH
77 | model = MODEL_PATCH(FLAGS)
78 | model.train(FLAGS)
79 |
80 | elif FLAGS.train_complete:
81 | from train_complete import MODEL_COMPLETE
82 | model = MODEL_COMPLETE(FLAGS)
83 | model.train(FLAGS)
84 |
85 | elif FLAGS.train_deform:
86 | print('***********train deform')
87 | from train_deform import MODEL_DEFORM
88 | model = MODEL_DEFORM(FLAGS)
89 | model.train(FLAGS)
90 |
91 | elif FLAGS.train_joint:
92 | print('***********train joint')
93 | from train_joint import MODEL_JOINT
94 | model = MODEL_JOINT(FLAGS)
95 | model.train(FLAGS)
96 | else:
97 | print('Invalid training options!')
98 |
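`main.py` dispatches on an if/elif chain, so if several `--train_*` flags are passed only the first one checked (`--train_patch`, then `--train_complete`, `--train_deform`, `--train_joint`) runs, and passing none only prints "Invalid training options!". A hypothetical alternative, sketched below with the standard `argparse` API, lets the parser reject both cases up front; it is not the repository's parser and only reproduces the stage flags and `--mode`.

```python
import argparse

def build_parser():
    """Hypothetical variant of the main.py parser where exactly one stage flag is required."""
    parser = argparse.ArgumentParser()
    stage = parser.add_mutually_exclusive_group(required=True)
    for name in ("train_complete", "train_patch", "train_deform", "train_joint"):
        stage.add_argument("--" + name, action="store_true")
    # Restricting --mode also turns typos such as "tset" into an immediate error.
    parser.add_argument("--mode", default="train", choices=["train", "test"])
    return parser

if __name__ == "__main__":
    args = build_parser().parse_args(["--train_joint", "--mode", "test"])
    print(args.train_joint, args.mode)   # True test
```

With this design, `python main.py --train_joint --train_patch` exits with a usage error instead of silently running only the patch stage.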
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.backends.cudnn as cudnn
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import optim
6 | from torch.autograd import Variable
7 | import numpy as np
8 |
9 | #cell = 4
10 | #input 256
11 | #output 120 (128-4-4)
12 | #receptive field = 18
13 |
14 | # 0 18
15 | #conv 4x4 s1 4 15
16 | #conv 3x3 s2 6 7
17 | #conv 3x3 s1 10 5
18 | #conv 3x3 s1 14 3
19 | #conv 3x3 s1 18 1
20 | #conv 1x1 s1 1 1
21 |
22 | # 0 41
23 | #conv 3x3 s1 4 39
24 | #conv 3x3 s2 6 19
25 | #conv 3x3 s1 10 17
26 | #conv 3x3 s1 14 15
27 | #conv 3x3 s2 18 7
28 | #conv 3x3 s1 26 7
29 |
30 | leaky_f = 0.02
31 | USE_25=True
32 |
33 | class CoarseCompletor_skip(nn.Module):
34 | def __init__(self, d_dim, ):
35 | super(CoarseCompletor_skip, self).__init__()
36 | self.d_dim = d_dim
37 |
38 | self.conv_1 = nn.Conv3d(1, self.d_dim, 5, stride=1, padding=2, bias=True)
39 | self.conv_2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True) # 64
40 |
41 | self.conv_3 = nn.Conv3d(self.d_dim*2, self.d_dim*2, 3, stride=1, padding=1, bias=True)
42 | self.conv_4 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True) # 32
43 |
44 | self.conv_5 = nn.Conv3d(self.d_dim*4, self.d_dim*4, 3, stride=1, padding=1, bias=True)
45 | self.conv_6 = nn.Conv3d(self.d_dim*4, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 16
46 |
47 | self.conv_7 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=1, padding=1, bias=True)
48 | self.conv_8 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 8
49 |
50 | self.conv_9 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 4
51 | self.conv_10 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=1, padding=1, bias=True) # 4
52 |
53 | self.dconv_1 = nn.ConvTranspose3d(self.d_dim*8, self.d_dim*8, 4, stride=2, padding=1, bias=True) # 8
54 |
55 | self.dconv_2 = nn.Conv3d(self.d_dim*16, self.d_dim*8, 3, stride=1, padding=1, bias=True)
56 | self.dconv_3 = nn.ConvTranspose3d(self.d_dim*8, self.d_dim*8, 4, stride=2, padding=1, bias=True) # 16
57 |
58 | self.dconv_4 = nn.Conv3d(self.d_dim*16, self.d_dim*8, 3, stride=1, padding=1, bias=True)
59 | self.dconv_5 = nn.ConvTranspose3d(self.d_dim*8, self.d_dim*4, 4, stride=2, padding=1, bias=True) # 32
60 |
61 | self.dconv_6 = nn.Conv3d(self.d_dim*8, self.d_dim*4, 3, stride=1, padding=1, bias=True)
62 | self.dconv_7 = nn.ConvTranspose3d(self.d_dim*4, self.d_dim*2, 4, stride=2, padding=1, bias=True) # 64
63 |
64 | self.dconv_8 = nn.Conv3d(self.d_dim*4, self.d_dim*2, 3, stride=1, padding=1, bias=True)
65 | self.dconv_9 = nn.ConvTranspose3d(self.d_dim*2, self.d_dim, 4, stride=2, padding=1, bias=True) # 128
66 |
67 | self.dconv_10 = nn.Conv3d(self.d_dim, 1, 3, stride=1, padding=1, bias=True)
68 |
69 | def forward(self, partial_in, is_training=False):
70 | out = partial_in
71 | out = self.conv_1(out)
72 | out1 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
73 |
74 | out = self.conv_2(out1)
75 | out2 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
76 |
77 | out = self.conv_3(out2)
78 | out3 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
79 |
80 | out = self.conv_4(out3)
81 | out4 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
82 |
83 | out = self.conv_5(out4)
84 | out5 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
85 |
86 | out = self.conv_6(out5)
87 | out6 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
88 |
89 | out = self.conv_7(out6)
90 | out7 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
91 |
92 | out = self.conv_8(out7)
93 | out8 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
94 |
95 | out = self.conv_9(out8)
96 | out9 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
97 |
98 | out = self.conv_10(out9)
99 | out10 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
100 |
101 | out = self.dconv_1(out10)
102 | outd1 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
103 |
104 | out = torch.cat([outd1, out8], dim=1)
105 | out = self.dconv_2(out)
106 | outd2 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
107 |
108 | out = self.dconv_3(outd2)
109 | outd3 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
110 |
111 | out = torch.cat([outd3, out6], dim=1)
112 | out = self.dconv_4(out)
113 | outd4 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
114 |
115 | out = self.dconv_5(outd4)
116 | outd5 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
117 |
118 | out = torch.cat([outd5, out4], dim=1)
119 | out = self.dconv_6(out)
120 | outd6 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
121 |
122 | out = self.dconv_7(outd6)
123 | outd7 = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
124 |
125 | out = torch.cat([outd7, out2], dim=1)
126 | out = self.dconv_8(out)
127 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
128 |
129 | out = self.dconv_9(out)
130 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
131 |
132 | out = self.dconv_10(out)
133 |
134 | #out = torch.max(torch.min(out, out*0.002+0.998), out*0.002-0.998)
135 | out = torch.sigmoid(out)
136 | return out
137 |
138 |
139 | class PatchEncoder(nn.Module):
140 | def __init__(self, d_dim, z_dim):
141 | super(PatchEncoder, self).__init__()
142 | self.d_dim = d_dim
143 | self.z_dim = z_dim
144 |
145 | self.conv_1 = nn.Conv3d(1, self.d_dim, 4, stride=1, padding=0, bias=True)
146 | self.conv_2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=0, bias=True)
147 | self.conv_3 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=1, padding=0, bias=True)
148 | self.conv_4 = nn.Conv3d(self.d_dim*4, self.d_dim*8, 3, stride=1, padding=0, bias=True)
149 | self.conv_5 = nn.Conv3d(self.d_dim*8, self.d_dim*16, 3, stride=2, padding=0, bias=True)
150 | if USE_25:
151 | self.conv_6 = nn.Conv3d(self.d_dim*16, self.d_dim*16, 3, stride=1, padding=1, bias=True) # extra layer
152 | self.conv_7 = nn.Conv3d(self.d_dim*32, self.z_dim, 1, stride=1, padding=0, bias=True)
153 | else:
154 |
155 | self.conv_6 = nn.Conv3d(self.d_dim*16, self.z_dim*2, 1, stride=1, padding=0, bias=True)
156 | self.conv_7 = nn.Conv3d(self.z_dim*2, self.z_dim, 1, stride=1, padding=0, bias=True)
157 |
158 | def forward(self, voxels, return_feat=False, is_training=False):
159 | out = voxels
160 |
161 | out = self.conv_1(out)
162 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
163 | feat1 = out.detach()
164 |
165 | out = self.conv_2(out)
166 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
167 | feat2 = out.detach()
168 |
169 | out = self.conv_3(out)
170 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
171 | feat3 = out.detach()
172 |
173 | out = self.conv_4(out)
174 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
175 | feat4 = out.detach()
176 |
177 | out_1 = self.conv_5(out)
178 | out_1 = F.leaky_relu(out_1, negative_slope=leaky_f, inplace=True)
179 | feat5 = out_1.detach()
180 |
181 | if USE_25:
182 | out_2 = self.conv_6(out_1)
183 | out_2 = F.leaky_relu(out_2, negative_slope=leaky_f, inplace=True)
184 | feat6 = out_2.detach()
185 | out = self.conv_7(torch.cat([out_1, out_2], dim=1))
186 | else:
187 | out = self.conv_6(out_1)
188 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
189 | out = self.conv_7(out)
190 |
191 | out = torch.max(torch.min(out, out*0.002+0.998), out*0.002-0.998)
192 | # out = torch.sigmoid(out)
193 | if return_feat:
194 | return out, feat1, feat2, feat3, feat4, feat5, (feat6 if USE_25 else None)  # feat6 only exists when USE_25
195 | return out
196 |
197 |
198 | def reshape_to_size_torch(v1, v2_shape):
199 | _,_,x1,y1,z1 = v1.shape
200 | x,y,z = v2_shape
201 | new_v1 = v1
202 | padding = np.zeros(6)
203 | if z1 < z:
204 | padding[0], padding[1] = int((z-z1)/2), z-z1-int((z-z1)/2)
205 | else:
206 | new_v1 = new_v1[:,:,:, :,int((z1-z)/2):int((z1-z)/2)+z ]
207 | if y1 < y:
208 | padding[2], padding[3] = int((y-y1)/2), y-y1-int((y-y1)/2)
209 | else:
210 | new_v1 = new_v1[:,:,:, int((y1-y)/2):int((y1-y)/2)+y,: ]
211 | if x1 < x:
212 | padding[4], padding[5] = int((x-x1)/2), x-x1-int((x-x1)/2)
213 | else:
214 | new_v1 = new_v1[:,:,int((x1-x)/2):int((x1-x)/2)+x,:,: ]
215 | new_v1 = F.pad(new_v1, tuple(int(p) for p in padding))  # plain ints; np.int8 overflows for pads > 127
216 | return new_v1
217 |
218 | class PatchDeformer(nn.Module):
219 | def __init__(self, d_dim, z_dim, pre_dis=False, include_coarse=False):
220 | super(PatchDeformer, self).__init__()
221 | self.d_dim = d_dim
222 | self.z_dim = z_dim
223 | self.pred_dis = pre_dis
224 | self.include_coarse = include_coarse
225 |
226 | self.conv_1 = nn.Conv3d(1, self.d_dim, 3, stride=1, padding=1, bias=True)
227 | self.conv_2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True)
228 | self.conv_3 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True)
229 |
230 | self.conv_d1 = nn.Conv3d(1, self.d_dim, 3, stride=1, padding=1, bias=True)
231 | self.conv_d2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True)
232 | self.conv_d3 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True)
233 |
234 | self.fc1 = nn.Linear(self.d_dim*8*5*5*5, 4096)
235 | self.fc2 = nn.Linear(4096, 1024)
236 | self.fc3 = nn.Linear(1024, self.z_dim)
237 |
238 | if self.pred_dis:
239 | self.dis_fc1 = nn.Linear(self.d_dim*8*5*5*5, 2048)
240 | self.dis_fc2 = nn.Linear(2048, 512)
241 | self.dis_fc3 = nn.Linear(512, 1)
242 |
243 | if self.include_coarse:
244 | self.cdis_fc1 = nn.Linear(self.d_dim*4*5*5*5, 1024)
245 | self.cdis_fc2 = nn.Linear(1024, 512)
246 | self.cdis_fc3 = nn.Linear(512, 1)
247 |
248 | self.ddis_fc1 = nn.Linear(self.d_dim*4*5*5*5, 1024)
249 | self.ddis_fc2 = nn.Linear(1024, 512)
250 | self.ddis_fc3 = nn.Linear(512, 1)
251 |
252 | def forward(self, c_vox, d_vox, is_training=False):
253 |
254 | out = c_vox
255 | out = self.conv_1(out)
256 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
257 | out = self.conv_2(out)
258 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
259 | out = self.conv_3(out)
260 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
261 | out_1 = out.view(-1, 5*5*5*self.d_dim*4)
262 |
263 | out = d_vox
264 | out = self.conv_d1(out)
265 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
266 | out = self.conv_d2(out)
267 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
268 | out = self.conv_d3(out)
269 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
270 | out_2 = out.view(-1, 5*5*5*self.d_dim*4)
271 |
272 |
273 | out = torch.cat([out_1, out_2], dim=1)
274 | out = self.fc1(out)
275 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
276 | out = self.fc2(out)
277 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
278 | out = self.fc3(out)
279 | # out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
280 | out_de = torch.max(torch.min(out, out*0.002+0.998), out*0.002-0.998)
281 |
282 | if self.pred_dis:
283 | out = torch.cat([out_1, out_2], dim=1)
284 | out = self.dis_fc1(out)
285 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
286 | out = self.dis_fc2(out)
287 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
288 | out = self.dis_fc3(out)
289 |
290 | out_dis = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
291 | # out_dis = torch.max(torch.min(out, out*0.002+0.998), out*0.002)
292 | if not self.include_coarse:
293 | return out_de, out_dis
294 | else:
295 | # out = torch.cat([out_1, out_2], dim=1)
296 | out = out_1
297 | out = self.cdis_fc1(out)
298 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
299 | out = self.cdis_fc2(out)
300 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
301 | out = self.cdis_fc3(out)
302 | out_c_dis = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
303 |
304 | out = out_2
305 | out = self.ddis_fc1(out)
306 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
307 | out = self.ddis_fc2(out)
308 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
309 | out = self.ddis_fc3(out)
310 | out_d_dis = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
311 |
312 | return out_de, out_dis, out_c_dis, out_d_dis
313 |
314 | return out_de
315 |
316 | class JointDeformer(nn.Module):
317 | def __init__(self, d_dim, dw_dim, k_dim,loc_size, use_mean_x=False, wd_size=8):
318 | super(JointDeformer, self).__init__()
319 | self.d_dim = d_dim
320 | self.dw_dim = dw_dim
321 | self.k_dim = k_dim
322 | self.d = loc_size
323 | self.use_mean_x = use_mean_x
324 | self.wd_size = wd_size
325 |
326 | # input
327 | self.conv_1 = nn.Conv3d(2, self.d_dim, 5, stride=1, padding=2, bias=True)
328 | self.conv_2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True) # 64
329 | self.conv_3 = nn.Conv3d(self.d_dim*2, self.d_dim*2, 3, stride=1, padding=1, bias=True)
330 | self.conv_4 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True) # 32
331 |
332 | self.conv_5 = nn.Conv3d(self.d_dim*4, self.d_dim*4, 3, stride=1, padding=1, bias=True)
333 | self.conv_6 = nn.Conv3d(self.d_dim*4, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 16
334 | self.conv_7 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 8
335 | self.conv_8 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 4
336 |
337 | # partial
338 | self.conv_p1 = nn.Conv3d(1, self.d_dim, 5, stride=1, padding=2, bias=True)
339 | self.conv_p2 = nn.Conv3d(self.d_dim, self.d_dim*2, 3, stride=2, padding=1, bias=True) # 64
340 | self.conv_p3 = nn.Conv3d(self.d_dim*2, self.d_dim*2, 3, stride=1, padding=1, bias=True)
341 | self.conv_p4 = nn.Conv3d(self.d_dim*2, self.d_dim*4, 3, stride=2, padding=1, bias=True) # 32
342 |
343 | self.conv_p5 = nn.Conv3d(self.d_dim*4, self.d_dim*4, 3, stride=1, padding=1, bias=True)
344 | self.conv_p6 = nn.Conv3d(self.d_dim*4, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 16
345 | self.conv_p7 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 8
346 | self.conv_p8 = nn.Conv3d(self.d_dim*8, self.d_dim*8, 3, stride=2, padding=1, bias=True) # 4
347 |
348 | # windows
349 | self.conv_w0 = nn.Conv3d(self.k_dim, self.dw_dim, 5, stride=1, padding=2, bias=True)
350 | self.conv_w1 = nn.Conv3d(self.dw_dim, self.dw_dim, 3, stride=1, padding=1, bias=True)
351 | self.conv_w2 = nn.Conv3d(self.dw_dim, self.dw_dim*2, 3, stride=2, padding=1, bias=True) # 12 . 16
352 | self.conv_w3 = nn.Conv3d(self.dw_dim*2, self.dw_dim*2, 3, stride=1, padding=1, bias=True)
353 | self.conv_w4 = nn.Conv3d(self.dw_dim*2, self.dw_dim*4, 3, stride=2, padding=1, bias=True) # 6 . 8
354 |
355 | self.conv_w5 = nn.Conv3d(self.dw_dim*4, self.dw_dim*4, 3, stride=1, padding=1, bias=True)
356 | self.conv_w6 = nn.Conv3d(self.dw_dim*4, self.dw_dim*8, 3, stride=2, padding=1, bias=True) # 3 . 4
357 | if self.wd_size==16:
358 | self.conv_w7 = nn.Conv3d(self.dw_dim*8, self.dw_dim*8, 3, stride=2, padding=1, bias=True) # 3 . 4
359 |
360 |
361 | self.fc1 = nn.Linear((self.dw_dim*8*self.d*self.d*self.d + self.d_dim*8*4*4*4*2), 4096)
362 | self.fc2 = nn.Linear(4096, 1024)
363 |
364 | # X
365 | self.fc_x1 = nn.Linear(1024, 256)
366 | self.fc_x2 = nn.Linear(256, 1024)
367 | self.fc_x3 = nn.Linear(1024, self.dw_dim*8*self.d*self.d*self.d)
368 | self.dconv_x0 = nn.Conv3d(self.dw_dim*8, self.dw_dim*8, 3, stride=1, padding=1, bias=True)
369 | self.dconv_x1 = nn.ConvTranspose3d(self.dw_dim*8, self.dw_dim*4, 4, stride=2, padding=1, bias=True) # 8
370 | self.dconv_x2 = nn.Conv3d(self.dw_dim*4, self.dw_dim*4, 3, stride=1, padding=1, bias=True)
371 | self.dconv_x3 = nn.ConvTranspose3d(self.dw_dim*4, self.dw_dim*2, 4, stride=2, padding=1, bias=True) # 16
372 | self.dconv_x4 = nn.Conv3d(self.dw_dim*2, self.dw_dim*2, 3, stride=1, padding=1, bias=True)
373 | self.dconv_x5 = nn.ConvTranspose3d(self.dw_dim*2, self.dw_dim, 4, stride=2, padding=1, bias=True) # 32
374 | self.dconv_x6 = nn.Conv3d(self.dw_dim, self.dw_dim, 3, stride=1, padding=1, bias=True)
375 | self.dconv_x7 = nn.Conv3d(self.dw_dim, self.k_dim, 3, stride=1, padding=1, bias=True)
376 |
377 | # D
378 | self.fc_d1 = nn.Linear(1024, 1024)
379 | self.fc_d2 = nn.Linear(1024, 1024)
380 | self.fc_d3 = nn.Linear(1024, self.k_dim*3)
381 |
382 |
383 | def forward(self, partial_in, input_in, windows_in, window_masks_in, loc_mask_in, is_training=False):
384 | leaky_f = 0.02
385 | out = torch.cat([input_in, window_masks_in], dim=1)
386 | out = self.conv_1(out)
387 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
388 | out = self.conv_2(out)
389 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
390 | out = self.conv_3(out)
391 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
392 | out = self.conv_4(out)
393 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
394 | out = self.conv_5(out)
395 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
396 | out = self.conv_6(out)
397 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
398 | out = self.conv_7(out)
399 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
400 | out = self.conv_8(out)
401 | out_coa = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
402 |
403 | out = partial_in
404 | out = self.conv_p1(out)
405 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
406 | out = self.conv_p2(out)
407 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
408 | out = self.conv_p3(out)
409 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
410 | out = self.conv_p4(out)
411 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
412 | out = self.conv_p5(out)
413 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
414 | out = self.conv_p6(out)
415 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
416 | out = self.conv_p7(out)
417 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
418 | out = self.conv_p8(out)
419 | out_p = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
420 |
421 | out = windows_in
422 | out = self.conv_w0(out)
423 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
424 | out = self.conv_w1(out)
425 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
426 | out = self.conv_w2(out)
427 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
428 | out = self.conv_w3(out)
429 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
430 | out = self.conv_w4(out)
431 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
432 | out = self.conv_w5(out)
433 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
434 | out = self.conv_w6(out)
435 | out_w = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
436 | if self.wd_size==16:
437 | out_w = self.conv_w7(out_w)
438 | out_w = F.leaky_relu(out_w, negative_slope=leaky_f, inplace=True)
439 |
440 | out_coa = out_coa.view(1, self.d_dim*8*4*4*4)
441 | out_p = out_p.view(1, self.d_dim*8*4*4*4)
442 | out_w =out_w.view(1, self.dw_dim*8*self.d*self.d*self.d)
443 |
444 | out = torch.cat((out_w, out_coa, out_p), dim=1)
445 | out = self.fc1(out)
446 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
447 | out = self.fc2(out)
448 | out_latent = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
449 |
450 | # D
451 | out = self.fc_d1(out_latent)
452 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
453 | out = self.fc_d2(out)
454 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
455 | out = self.fc_d3(out)
456 | # out_D = torch.max(torch.min(out, out*0.002+0.998), out*0.002-0.998)
457 | out_D = torch.clip(out, -1.0, 1.0)
458 | out_D = out_D.view(self.k_dim, 3)
459 |
460 | # X
461 | out = self.fc_x1(out_latent)
462 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
463 | out = self.fc_x2(out)
464 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
465 | out = self.fc_x3(out)
466 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
467 |
468 | out = out.view(1, self.dw_dim*8, self.d,self.d,self.d)
469 | out = self.dconv_x0(out)
470 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
471 | #print('0',out.shape)
472 | out = self.dconv_x1(out)
473 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
474 | #print('1',out.shape)
475 | out = self.dconv_x2(out)
476 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
477 | #print('2',out.shape)
478 | out = self.dconv_x3(out)
479 | #print('3',out.shape)
480 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
481 | out = self.dconv_x4(out)
482 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
483 | #print('4',out.shape)
484 | out = self.dconv_x5(out)
485 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
486 | #print('5',out.shape)
487 | out = self.dconv_x6(out)
488 | out = F.leaky_relu(out, negative_slope=leaky_f, inplace=True)
489 | #print('6',out.shape)
490 | out_X = self.dconv_x7(out) # [1, kdim, vx, vy, vz]
491 |
492 | # out_X = torch.max(torch.min(out_X,out_X*0.002+0.998), out_X*0.002)
493 | # out_X = torch.sigmoid(out_X)
494 | out_X = torch.clip(out_X, 0.0, 1.0)
495 | # loc_mask_in = F.interpolate(loc_mask_in, scale_factor=8, mode='nearest')
496 | out_X = F.avg_pool3d(out_X, kernel_size=8, stride=8)
497 |
498 | if self.use_mean_x:
499 | out_X = out_X*loc_mask_in/(torch.sum(out_X*loc_mask_in, dim=1).unsqueeze(1)+1e-5)
500 | else:
501 | out_X = torch.exp(out_X)*loc_mask_in/(torch.sum(torch.exp(out_X)*loc_mask_in, dim=1).unsqueeze(1)+1e-5)
502 |
503 | #print('out_x', out_X.mean())
504 | return out_D, out_X
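`PatchEncoder.forward` and `PatchDeformer.forward` bound their outputs with `torch.max(torch.min(out, out*0.002+0.998), out*0.002-0.998)`, while `JointDeformer` uses a hard `torch.clip` instead. Read piecewise, the expression is the identity on [-1, 1] and continues with a slope of 0.002 outside that interval, so the gradient there is small but never exactly zero. The scalar sketch below mirrors the expression in plain Python; `leaky_clip` and its `slope` parameter are illustrative names, not part of the repository.

```python
def leaky_clip(x: float, slope: float = 0.002) -> float:
    """Scalar version of torch.max(torch.min(x, x*s + (1-s)), x*s - (1-s)).

    Identity on [-1, 1]; outside that interval the output keeps growing
    with slope `slope` instead of saturating.
    """
    upper = slope * x + (1.0 - slope)   # line through (1, 1) with slope `slope`
    lower = slope * x - (1.0 - slope)   # line through (-1, -1) with slope `slope`
    return max(min(x, upper), lower)


if __name__ == "__main__":
    for v in (-3.0, -1.0, 0.25, 1.0, 3.0):
        print(v, leaky_clip(v))
```

With the default slope, an input of 3.0 maps to about 1.004 rather than 1.0, which is the "leak" that keeps the layer trainable when predictions overshoot.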
505 |
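`reshape_to_size_torch` center-crops or zero-pads each spatial axis of a 5-D tensor to a target shape. `F.pad` reads its padding tuple starting from the last dimension, which is why the z pads fill slots 0-1, the y pads slots 2-3 and the x pads slots 4-5. The sketch below reproduces the per-axis arithmetic on plain integers so the split is easy to check; `center_fit` and `pad_tuple` are hypothetical helper names, not functions in the repository.

```python
def center_fit(src: int, dst: int):
    """Per-axis plan used by reshape_to_size_torch.

    Returns ("pad", before, after) when src < dst, otherwise
    ("crop", start, stop) so that slicing [start:stop] keeps the centered
    dst entries.
    """
    if src < dst:
        before = (dst - src) // 2          # the smaller half goes first
        return ("pad", before, dst - src - before)
    start = (src - dst) // 2
    return ("crop", start, start + dst)


def pad_tuple(src_xyz, dst_xyz):
    """Build the F.pad argument (z_lo, z_hi, y_lo, y_hi, x_lo, x_hi)."""
    pads = []
    for src, dst in reversed(list(zip(src_xyz, dst_xyz))):  # last dim first
        kind, a, b = center_fit(src, dst)
        pads += [a, b] if kind == "pad" else [0, 0]          # crops pad nothing
    return tuple(pads)


if __name__ == "__main__":
    print(center_fit(5, 8))     # pad 1 before, 2 after
    print(center_fit(11, 8))    # keep indices 1..8
    print(pad_tuple((5, 8, 11), (8, 8, 8)))
```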
506 |
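The hard-coded sizes in the fully connected layers (`5*5*5` in `PatchDeformer`, `4*4*4` in `JointDeformer`) follow from the standard output-size formulas: floor((n + 2p - k)/s) + 1 for `nn.Conv3d`, and (n - 1)s - 2p + k for `nn.ConvTranspose3d` with no output padding or dilation. The sketch below traces one spatial axis through the layer stacks. The 128^3 input for `JointDeformer` (which matches the `# 64 ... # 4` comments) and a `loc_size` of 4 are assumptions for illustration; the repository does not fix them here.

```python
def conv_out(n: int, k: int, s: int = 1, p: int = 0) -> int:
    """Spatial output size of nn.Conv3d along one axis (dilation 1)."""
    return (n + 2 * p - k) // s + 1


def deconv_out(n: int, k: int, s: int = 1, p: int = 0) -> int:
    """Spatial output size of nn.ConvTranspose3d (dilation 1, output_padding 0)."""
    return (n - 1) * s - 2 * p + k


def trace(n, layers):
    """Apply (fn, kernel, stride, padding) layer specs to size n; return every size."""
    sizes = [n]
    for fn, k, s, p in layers:
        n = fn(n, k, s, p)
        sizes.append(n)
    return sizes


# PatchDeformer branch: conv_1 (k3,s1,p1), conv_2 (k3,s2,p1), conv_3 (k3,s2,p1)
PATCH_DEFORMER = [(conv_out, 3, 1, 1), (conv_out, 3, 2, 1), (conv_out, 3, 2, 1)]

# JointDeformer input/partial branch: conv_1 .. conv_8
JOINT_ENCODER = [(conv_out, 5, 1, 2), (conv_out, 3, 2, 1), (conv_out, 3, 1, 1),
                 (conv_out, 3, 2, 1), (conv_out, 3, 1, 1), (conv_out, 3, 2, 1),
                 (conv_out, 3, 2, 1), (conv_out, 3, 2, 1)]

# JointDeformer X decoder: the three stride-2 transposed convs (k4,s2,p1);
# the stride-1, padding-1 convs in between keep the size unchanged.
X_UPSAMPLE = [(deconv_out, 4, 2, 1)] * 3

if __name__ == "__main__":
    # Which cubic input sizes give the 5x5x5 grid that PatchDeformer.fc1 expects?
    print([n for n in range(8, 33) if trace(n, PATCH_DEFORMER)[-1] == 5])
    print(trace(128, JOINT_ENCODER))           # ends in 4 -> 4*4*4 in fc1
    d = 4                                       # assumed loc_size
    up = trace(d, X_UPSAMPLE)[-1]               # 8 * d
    print(up, conv_out(up, 8, 8, 0))            # avg_pool3d(kernel 8, stride 8) -> d
```

The last line shows why `F.avg_pool3d(out_X, kernel_size=8, stride=8)` brings the upsampled map back to the `d x d x d` location grid used by `loc_mask_in`.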
507 |
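At the end of `JointDeformer.forward`, `out_X` becomes per-location weights over the `k_dim` retrieved candidates: either a masked softmax, `exp(x) * mask / (sum(exp(x) * mask) + 1e-5)`, or, with `use_mean_x`, a masked normalization of the raw values. Candidates whose `loc_mask_in` entry is 0 receive zero weight, and the `1e-5` term makes a fully masked location come out as all-zero weights instead of a division by zero. The sketch below applies the same formulas to one location's list of scores; `candidate_weights` is an illustrative name, not a repository function.

```python
import math


def candidate_weights(scores, mask, use_mean_x=False, eps=1e-5):
    """Normalize one location's candidate scores the way JointDeformer does.

    scores: per-candidate values (out_X after clipping to [0, 1] and pooling)
    mask:   1 for valid candidates, 0 for masked ones (loc_mask_in)
    """
    if use_mean_x:
        vals = [s * m for s, m in zip(scores, mask)]
    else:
        vals = [math.exp(s) * m for s, m in zip(scores, mask)]
    total = sum(vals) + eps            # eps avoids 0/0 when every entry is masked
    return [v / total for v in vals]


if __name__ == "__main__":
    w = candidate_weights([0.2, 0.5, 0.9], [1, 0, 1])
    print([round(x, 4) for x in w])    # middle candidate is masked out
    print(candidate_weights([0.3, 0.7], [0, 0]))   # [0.0, 0.0]
```

Because of the `eps` term the unmasked weights sum to slightly less than 1, which is negligible for scores in [0, 1].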
--------------------------------------------------------------------------------
/patch_train.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --data_content content_chair \
3 | --data_dir ./data/03001627/ \
4 | --input_size 32 \
5 | --output_size 128 \
6 | --sample_dir samples_patch \
7 | --checkpoint_dir checkpoint_patch \
8 | --gpu 0 \
9 | --epoch 250 \
10 | --lr 1e-4 \
11 | --z_dim 128 \
12 | --g_dim 32 \
13 | --K 2 \
14 | --w_ident 8 \
15 | --csize 26 \
16 | --c_range 26 \
17 | --max_sample_num 400 \
18 | --train_patch \
19 | --model_name chair_model0 \
20 | --dump_deform_path ./dump_deform/chair/model0 \
21 | #--dump_deform \
22 | #--small_dataset \
23 | #--mode test \
24 | #--continue_train \
25 |
--------------------------------------------------------------------------------
/pc_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """ Utility functions for processing point clouds.
7 | Author: Charles R. Qi and Or Litany
8 | """
9 |
10 | import os
11 | import sys
12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
13 | sys.path.append(BASE_DIR)
14 |
15 | # Point cloud IO
16 | import numpy as np
17 | try:
18 | from plyfile import PlyData, PlyElement
19 | except:
20 | print("Please install the module 'plyfile' for PLY i/o, e.g.")
21 | print("pip install plyfile")
22 | sys.exit(-1)
23 |
24 |
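25 | # Mesh IO (optional: only the bbox/cylinder mesh exporters below need trimesh)
26 | try: import trimesh
27 | except ImportError: trimesh = None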
28 | import matplotlib.pyplot as pyplot
29 |
30 | # ----------------------------------------
31 | # Point Cloud Sampling
32 | # ----------------------------------------
33 |
34 | def random_sampling(pc, num_sample, replace=None, return_choices=False):
35 | """ Input is NxC, output is num_samplexC
36 | """
37 | if replace is None: replace = (pc.shape[0]num_sample:
138 | pc = random_sampling(pc, num_sample, False)
139 | elif pc.shape[0]num_sample:
185 | pc = random_sampling(pc, num_sample, False)
186 | elif pc.shape[0]np.max(labels))
219 |
220 | vertex = []
221 | #colors = [pyplot.cm.jet(i/float(num_classes)) for i in range(num_classes)]
222 | colors = [colormap(i/float(num_classes)) for i in range(num_classes)]
223 | for i in range(N):
224 | c = colors[labels[i]]
225 | c = [int(x*255) for x in c]
226 | vertex.append( (points[i,0],points[i,1],points[i,2],c[0],c[1],c[2]) )
227 | vertex = np.array(vertex, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4'),('red', 'u1'), ('green', 'u1'),('blue', 'u1')])
228 |
229 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
230 | PlyData([el], text=True).write(filename)
231 |
232 | def write_ply_rgb(points, colors, out_filename, num_classes=None):
233 | """ Color (N,3) points with RGB colors (N,3) within range [0,255] as OBJ file """
234 | colors = colors.astype(int)
235 | N = points.shape[0]
236 | fout = open(out_filename, 'w')
237 | for i in range(N):
238 | c = colors[i,:]
239 | fout.write('v %f %f %f %d %d %d\n' % (points[i,0],points[i,1],points[i,2],c[0],c[1],c[2]))
240 | fout.close()
241 |
242 | # ----------------------------------------
243 | # Simple Point cloud and Volume Renderers
244 | # ----------------------------------------
245 |
246 | def pyplot_draw_point_cloud(points, output_filename):
247 | """ points is a Nx3 numpy array """
248 | import matplotlib.pyplot as plt
249 | fig = plt.figure()
250 | ax = fig.add_subplot(111, projection='3d')
251 | ax.scatter(points[:,0], points[:,1], points[:,2])
252 | ax.set_xlabel('x')
253 | ax.set_ylabel('y')
254 | ax.set_zlabel('z')
255 | plt.savefig(output_filename)
256 |
257 | def pyplot_draw_volume(vol, output_filename):
258 | """ vol is of size vsize*vsize*vsize
259 | output an image to output_filename
260 | """
261 | points = volume_to_point_cloud(vol)
262 | pyplot_draw_point_cloud(points, output_filename)
263 |
264 | # ----------------------------------------
265 | # Simple Point manipulations
266 | # ----------------------------------------
267 | def rotate_point_cloud(points, rotation_matrix=None):
268 | """ Input: (n,3), Output: (n,3) """
269 | # Rotate around the Z axis through the centroid; returns a new array (the input is not modified).
270 | if rotation_matrix is None:
271 | rotation_angle = np.random.uniform() * 2 * np.pi
272 | sinval, cosval = np.sin(rotation_angle), np.cos(rotation_angle)
273 | rotation_matrix = np.array([[cosval, sinval, 0],
274 | [-sinval, cosval, 0],
275 | [0, 0, 1]])
276 | ctr = points.mean(axis=0)
277 | rotated_data = np.dot(points-ctr, rotation_matrix) + ctr
278 | return rotated_data, rotation_matrix
279 |
280 | def rotate_pc_along_y(pc, rot_angle):
281 | ''' Input pc is NxC points with first 3 channels as XYZ; pc is rotated in place and returned.
282 | z is facing forward, x is leftward, y is downward
283 | '''
284 | cosval = np.cos(rot_angle)
285 | sinval = np.sin(rot_angle)
286 | rotmat = np.array([[cosval, -sinval],[sinval, cosval]])
287 | pc[:,[0,2]] = np.dot(pc[:,[0,2]], np.transpose(rotmat))
288 | return pc
289 |
290 | def roty(t):
291 | """Rotation about the y-axis."""
292 | c = np.cos(t)
293 | s = np.sin(t)
294 | return np.array([[c, 0, s],
295 | [0, 1, 0],
296 | [-s, 0, c]])
297 |
298 | def roty_batch(t):
299 | """Rotation about the y-axis.
300 | t: (x1,x2,...xn)
301 | return: (x1,x2,...,xn,3,3)
302 | """
303 | input_shape = t.shape
304 | output = np.zeros(tuple(list(input_shape)+[3,3]))
305 | c = np.cos(t)
306 | s = np.sin(t)
307 | output[...,0,0] = c
308 | output[...,0,2] = s
309 | output[...,1,1] = 1
310 | output[...,2,0] = -s
311 | output[...,2,2] = c
312 | return output
313 |
314 | def rotz(t):
315 | """Rotation about the z-axis."""
316 | c = np.cos(t)
317 | s = np.sin(t)
318 | return np.array([[c, -s, 0],
319 | [s, c, 0],
320 | [0, 0, 1]])
321 |
322 |
323 |
324 | # ----------------------------------------
325 | # BBox
326 | # ----------------------------------------
327 | def bbox_corner_dist_measure(crnr1, crnr2):
328 | """ compute distance between box corners to replace iou
329 | Args:
330 | crnr1, crnr2: Nx3 points of box corners in camera axis (y points down)
331 | output is a scalar between 0 and 1
332 | """
333 |
334 | dist = sys.maxsize
335 | for y in range(4):
336 | rows = ([(x+y)%4 for x in range(4)] + [4+(x+y)%4 for x in range(4)])
337 | d_ = np.linalg.norm(crnr2[rows, :] - crnr1, axis=1).sum() / 8.0
338 | if d_ < dist:
339 | dist = d_
340 |
341 | u = sum([np.linalg.norm(x[0,:] - x[6,:]) for x in [crnr1, crnr2]])/2.0
342 |
343 | measure = max(1.0 - dist/u, 0)
344 |
345 |
346 |
347 | return measure
348 |
349 |
350 | def point_cloud_to_bbox(points):
351 | """ Extract the axis aligned box from a pcl or batch of pcls
352 | Args:
353 | points: Nx3 points or BxNx3
354 | output is 6 dim: xyz pos of center and 3 lengths
355 | """
356 | which_dim = len(points.shape) - 2 # first dim if a single cloud and second if batch
357 | mn, mx = points.min(which_dim), points.max(which_dim)
358 | lengths = mx - mn
359 | cntr = 0.5*(mn + mx)
360 | return np.concatenate([cntr, lengths], axis=which_dim)
361 |
362 | def write_bbox(scene_bbox, out_filename):
363 | """Export scene bbox to meshes
364 | Args:
365 | scene_bbox: (N x 6 numpy array): xyz pos of center and 3 lengths
366 | out_filename: (string) filename
367 | Note:
368 | To visualize the boxes in MeshLab.
369 | 1. Select the objects (the boxes)
370 | 2. Filters -> Polygon and Quad Mesh -> Turn into Quad-Dominant Mesh
371 | 3. Select Wireframe view.
372 | """
373 | def convert_box_to_trimesh_fmt(box):
374 | ctr = box[:3]
375 | lengths = box[3:]
376 | trns = np.eye(4)
377 | trns[0:3, 3] = ctr
378 | trns[3,3] = 1.0
379 | box_trimesh_fmt = trimesh.creation.box(lengths, trns)
380 | return box_trimesh_fmt
381 |
382 | scene = trimesh.scene.Scene()
383 | for box in scene_bbox:
384 | scene.add_geometry(convert_box_to_trimesh_fmt(box))
385 |
386 | mesh_list = trimesh.util.concatenate(scene.dump())
387 | # save to ply file
388 | mesh_list.export(out_filename, file_type='ply')
389 |
390 | return
391 |
392 | def write_oriented_bbox(scene_bbox, out_filename):
393 | """Export oriented (around Z axis) scene bbox to meshes
394 | Args:
395 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz)
396 | and heading angle around Z axis.
397 | Y forward, X right, Z upward. heading angle of positive X is 0,
398 | heading angle of positive Y is 90 degrees.
399 | out_filename: (string) filename
400 | """
401 | def heading2rotmat(heading_angle):
402 | """3x3 rotation matrix for a heading angle about the Z axis."""
403 | rotmat = np.zeros((3,3))
404 | rotmat[2,2] = 1
405 | cosval = np.cos(heading_angle)
406 | sinval = np.sin(heading_angle)
407 | rotmat[0:2,0:2] = np.array([[cosval, -sinval],[sinval, cosval]])
408 | return rotmat
409 |
410 | def convert_oriented_box_to_trimesh_fmt(box):
411 | ctr = box[:3]
412 | lengths = box[3:6]
413 | trns = np.eye(4)
414 | trns[0:3, 3] = ctr
415 | trns[3,3] = 1.0
416 | trns[0:3,0:3] = heading2rotmat(box[6])
417 | box_trimesh_fmt = trimesh.creation.box(lengths, trns)
418 | return box_trimesh_fmt
419 |
420 | scene = trimesh.scene.Scene()
421 | for box in scene_bbox:
422 | scene.add_geometry(convert_oriented_box_to_trimesh_fmt(box))
423 |
424 | mesh_list = trimesh.util.concatenate(scene.dump())
425 | # save to ply file
426 | mesh_list.export(out_filename, file_type='ply')
427 |
428 | return
429 |
430 | def write_oriented_bbox_camera_coord(scene_bbox, out_filename):
431 | """Export oriented (around Y axis) scene bbox to meshes
432 | Args:
433 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz)
434 | and heading angle around Y axis.
435 | Z forward, X rightward, Y downward. heading angle of positive X is 0,
436 | heading angle of negative Z is 90 degrees.
437 | out_filename: (string) filename
438 | """
439 | def heading2rotmat(heading_angle):
440 | """3x3 rotation matrix for a heading angle about the Y axis."""
441 | rotmat = np.zeros((3,3))
442 | rotmat[1,1] = 1
443 | cosval = np.cos(heading_angle)
444 | sinval = np.sin(heading_angle)
445 | rotmat[0,:] = np.array([cosval, 0, sinval])
446 | rotmat[2,:] = np.array([-sinval, 0, cosval])
447 | return rotmat
448 |
449 | def convert_oriented_box_to_trimesh_fmt(box):
450 | ctr = box[:3]
451 | lengths = box[3:6]
452 | trns = np.eye(4)
453 | trns[0:3, 3] = ctr
454 | trns[3,3] = 1.0
455 | trns[0:3,0:3] = heading2rotmat(box[6])
456 | box_trimesh_fmt = trimesh.creation.box(lengths, trns)
457 | return box_trimesh_fmt
458 |
459 | scene = trimesh.scene.Scene()
460 | for box in scene_bbox:
461 | scene.add_geometry(convert_oriented_box_to_trimesh_fmt(box))
462 |
463 | mesh_list = trimesh.util.concatenate(scene.dump())
464 | # save to ply file
465 | mesh_list.export(out_filename, file_type='ply')
466 |
467 | return
468 |
469 | def write_lines_as_cylinders(pcl, filename, rad=0.005, res=64):
470 | """Create lines represented as cylinders connecting pairs of 3D points
471 | Args:
472 | pcl: (N x 2 x 3 numpy array): N pairs of xyz pos
473 | filename: (string) filename for the output mesh (ply) file
474 | rad: radius for the cylinder
475 | res: number of sections used to create the cylinder
476 | """
477 | scene = trimesh.scene.Scene()
478 | for src,tgt in pcl:
479 | # compute line
480 | vec = tgt - src
481 | M = trimesh.geometry.align_vectors([0,0,1],vec, False)
482 | vec = tgt - src # compute again since align_vectors modifies vec in-place!
483 | M[:3,3] = 0.5*src + 0.5*tgt
484 | height = np.sqrt(np.dot(vec, vec))
485 | scene.add_geometry(trimesh.creation.cylinder(radius=rad, height=height, sections=res, transform=M))
486 | mesh_list = trimesh.util.concatenate(scene.dump())
487 | mesh_list.export('%s.ply'%(filename), file_type='ply')
488 |
489 | # ----------------------------------------
490 | # Testing
491 | # ----------------------------------------
492 | if __name__ == '__main__':
493 | print('running some tests')
494 |
495 | ############
496 | ## Test "write_lines_as_cylinders"
497 | ############
498 | pcl = np.random.rand(32, 2, 3)
499 | write_lines_as_cylinders(pcl, 'point_connectors')
500 | input()
501 |
502 |
503 | scene_bbox = np.zeros((1,7))
504 | scene_bbox[0,3:6] = np.array([1,2,3]) # dx,dy,dz
505 | scene_bbox[0,6] = np.pi/4 # 45 degrees
506 | write_oriented_bbox(scene_bbox, 'single_obb_45degree.ply')
507 | ############
508 | ## Test point_cloud_to_bbox
509 | ############
510 | pcl = np.random.rand(32, 16, 3)
511 | pcl_bbox = point_cloud_to_bbox(pcl)
512 | assert pcl_bbox.shape == (32, 6)
513 |
514 | pcl = np.random.rand(16, 3)
515 | pcl_bbox = point_cloud_to_bbox(pcl)
516 | assert pcl_bbox.shape == (6,)
517 |
518 | ############
519 | ## Test corner distance
520 | ############
521 | crnr1 = np.array([[2.59038660e+00, 8.96107932e-01, 4.73305349e+00],
522 | [4.12281644e-01, 8.96107932e-01, 4.48046631e+00],
523 | [2.97129656e-01, 8.96107932e-01, 5.47344275e+00],
524 | [2.47523462e+00, 8.96107932e-01, 5.72602993e+00],
525 | [2.59038660e+00, 4.41155793e-03, 4.73305349e+00],
526 | [4.12281644e-01, 4.41155793e-03, 4.48046631e+00],
527 | [2.97129656e-01, 4.41155793e-03, 5.47344275e+00],
528 | [2.47523462e+00, 4.41155793e-03, 5.72602993e+00]])
529 | crnr2 = crnr1
530 |
531 | print(bbox_corner_dist_measure(crnr1, crnr2))
532 |
533 |
534 |
535 | print('tests PASSED')
536 |
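`rotate_point_cloud` multiplies row vectors on the right, `np.dot(points - ctr, R)`, with `R = [[c, s, 0], [-s, c, 0], [0, 0, 1]]`. That matrix is the transpose of `rotz(t)`, so the call rotates each point counter-clockwise by t about the Z axis through the centroid, exactly as applying `rotz(t)` to column vectors would. The stdlib-only sketch below checks this equivalence; `row_times` and `mat_vec` are hypothetical helpers standing in for the two NumPy products.

```python
import math


def rotz(t):
    """Same matrix as pc_util.rotz: rotation about Z for column vectors."""
    c, s = math.cos(t), math.sin(t)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def rotate_point_cloud_matrix(t):
    """Matrix built inside pc_util.rotate_point_cloud (used as p @ R)."""
    c, s = math.cos(t), math.sin(t)
    return [[c, s, 0.0], [-s, c, 0.0], [0.0, 0.0, 1.0]]


def row_times(p, m):
    """Row vector p (length 3) times 3x3 matrix m, i.e. np.dot(p, m)."""
    return [sum(p[i] * m[i][j] for i in range(3)) for j in range(3)]


def mat_vec(m, p):
    """3x3 matrix m times column vector p, i.e. np.dot(m, p)."""
    return [sum(m[i][j] * p[j] for j in range(3)) for i in range(3)]


if __name__ == "__main__":
    t = math.pi / 2
    p = [1.0, 0.0, 0.0]
    print(row_times(p, rotate_point_cloud_matrix(t)))  # approximately [0, 1, 0]
    print(mat_vec(rotz(t), p))                          # approximately [0, 1, 0]
```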
537 |
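`bbox_corner_dist_measure` compares two 8-corner boxes without computing IoU. It tries the four cyclic relabelings of the two faces, keeps the smallest mean corner-to-corner distance, divides it by the average length of the corner-0-to-corner-6 diagonal, and clamps the result at 0. A pure-Python sketch of the same computation follows; it assumes the corner layout implied by the `rows` indexing (corners 0-3 on one face, 4-7 on the opposite face, in the same cyclic order), and `corner_dist_measure` is an illustrative name.

```python
import math


def corner_dist_measure(c1, c2):
    """Pure-Python mirror of pc_util.bbox_corner_dist_measure.

    c1, c2: sequences of 8 (x, y, z) corners; corners 0-3 form one face and
    4-7 the opposite face, listed in the same cyclic order.
    """
    best = math.inf
    for shift in range(4):
        # same cyclic relabeling as `rows` in the original function
        rows = [(i + shift) % 4 for i in range(4)] + [4 + (i + shift) % 4 for i in range(4)]
        d = sum(math.dist(c2[r], c1[i]) for i, r in enumerate(rows)) / 8.0
        best = min(best, d)
    # normalizer: mean length of the corner-0-to-corner-6 diagonal of both boxes
    u = (math.dist(c1[0], c1[6]) + math.dist(c2[0], c2[6])) / 2.0
    return max(1.0 - best / u, 0.0)


if __name__ == "__main__":
    cube = [(0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0),
            (0, 0, 1), (1, 0, 1), (1, 1, 1), (0, 1, 1)]
    print(corner_dist_measure(cube, cube))   # 1.0
    far = [(x + 10, y, z) for x, y, z in cube]
    print(corner_dist_measure(cube, far))    # 0.0 (offset exceeds the diagonal)
```

Trying all four cyclic shifts makes the measure insensitive to which corner of a face is listed first, which is the point of using it in place of IoU.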
538 |
539 |
--------------------------------------------------------------------------------
/splits/compose2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | #path = '/home/bo/data/shapenet_processed2/02801938'
5 | #path = '/home/bo/data/shapenet_processed/02828884'
6 | path = '/home/bo/data/shapenet_processed'
7 |
8 | #f1 = open('content_bench_train.txt', 'w')
9 | #f2 = open('content_bench_test.txt', 'w')
10 | f1 = open('total_sub_train.txt', 'w')
11 | f2 = open('total_sub_test.txt', 'w')
12 |
13 | f0 = open('content_chair_train.txt', 'r')
14 | lines = f0.readlines()
15 | for l in lines:
16 | f1.write('03001627/%s\n'%(l.strip()))
17 | f0.close()
18 | f0 = open('content_chair_test.txt', 'r')
19 | lines = f0.readlines()
20 | for l in lines[0:200]:
21 | f2.write('03001627/%s\n'%(l.strip()))
22 | f0.close()
23 |
24 | f0 = open('content_table_train.txt', 'r')
25 | lines = f0.readlines()
26 | for l in lines:
27 | f1.write('04379243/%s\n'%(l.strip()))
28 | f0.close()
29 | f0 = open('content_table_test.txt', 'r')
30 | lines = f0.readlines()
31 | for l in lines[0:200]:
32 | f2.write('04379243/%s\n'%(l.strip()))
33 | f0.close()
34 |
35 |
36 | for cat in os.listdir(path):
37 | if cat in ['03001627', '04379243']:
38 | continue
39 | if not cat in ['02801938', '02828884', '03691459']:
40 | continue
41 | if not cat.startswith('0'):
42 | continue
43 | names = os.listdir(os.path.join(path, cat))
44 | data_len = len(names)
45 |
46 | for i in range(data_len):
47 | if i=int(4*data_len/5) and i