├── .gitignore
├── LICENSE
├── README.md
├── anaconda_env
│   └── CompoNet.yml
├── data
│   ├── __init__.py
│   ├── data_utils.py
│   ├── part_dataset_ae.py
│   └── part_dataset_pcn.py
├── datasets
│   └── download_data.sh
├── images
│   └── network_architecture.png
├── models.py
├── test.py
├── tf_ops
│   ├── __init__.py
│   └── nn_distance
│       ├── README.md
│       ├── __init__.py
│       ├── tf_nndistance.cpp
│       ├── tf_nndistance.py
│       ├── tf_nndistance_compile.sh
│       ├── tf_nndistance_cpu.py
│       ├── tf_nndistance_g.cu
│       ├── tf_nndistance_g.cu.o
│       └── tf_nndistance_so.so
├── train.py
└── utils
    ├── __init__.py
    ├── compile_render_balls_so.sh
    ├── render_balls_so.cpp
    ├── render_balls_so.so
    ├── show3d_balls.py
    └── tf_util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | /log/*
3 | /datasets/*/
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 nschor
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CompoNet: Learning to Generate the Unseen by Part Synthesis and Composition
2 | Created by Nadav Schor, Oren Katzir, Hao Zhang, Daniel Cohen-Or.
3 |
4 | 
5 |
6 |
7 | ## Introduction
8 | This work is based on our [ICCV paper](https://arxiv.org/abs/1811.07441). We present CompoNet, a generative neural network for 3D shapes that is based on a part-based prior, where the key idea is for the network to synthesize shapes by varying both the shape parts and their compositions.
9 |
10 |
11 | ## Citation
12 | If you find our work useful in your research, please consider citing:
13 |
14 |     @InProceedings{Schor_2019_ICCV,
15 |         author = {Schor, Nadav and Katzir, Oren and Zhang, Hao and Cohen-Or, Daniel},
16 |         title = {CompoNet: Learning to Generate the Unseen by Part Synthesis and Composition},
17 |         booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
18 |         month = {October},
19 |         year = {2019}
20 |     }
21 |
22 |
23 | ## Dependencies
24 | Requirements:
25 | - Python 2.7
26 | - TensorFlow (version 1.4+)
27 | - OpenCV (for visualization)
28 |
29 | Our code has been tested with Python 2.7, TensorFlow 1.4.0, CUDA 8.0 and cuDNN 6.0 on Ubuntu 18.04.
30 |
31 |
32 | ## Installation
33 | Download the source code from the git repository:
34 | ```
35 | git clone https://github.com/nschor/CompoNet
36 | ```
37 | Compile the Chamfer loss op under `CompoNet/tf_ops/nn_distance`, taken from [Fan et al.](https://github.com/fanhqme/PointSetGeneration).
38 | ```
39 | cd CompoNet/tf_ops/nn_distance
40 | ```
41 |
42 | Modify the TensorFlow and CUDA paths in the `tf_nndistance_compile.sh` script and run it (a snippet below shows how to query TensorFlow for its paths).
43 | ```
44 | sh tf_nndistance_compile.sh
45 | ```
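
If you are unsure which TensorFlow paths to use, TensorFlow can report its own include and library directories. Run this inside the environment you compile against:
```
python -c "import tensorflow as tf; print(tf.sysconfig.get_include()); print(tf.sysconfig.get_lib())"
```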
46 | For visualization, go to `utils/`.
47 | ```
48 | cd CompoNet/utils
49 | ```
50 | Run the `compile_render_balls_so.sh` script.
51 | ```
52 | sh compile_render_balls_so.sh
53 | ```
54 |
55 | If you are using Anaconda, we have attached the environment we used, `CompoNet.yml`, under `anaconda_env/`.
56 | Create the environment using:
57 | ```
58 | cd CompoNet/anaconda_env
59 | conda env create -f CompoNet.yml
60 | ```
61 | Activate the environment:
62 | ```
63 | source activate CompoNet
64 | ```
65 | ### Data Set
66 | Download the ShapeNetPart dataset by running the `download_data.sh` script under `datasets/`.
67 | ```
68 | cd CompoNet/datasets
69 | sh download_data.sh
70 | ```
71 | The point clouds will be stored in `CompoNet/datasets/shapenetcore_partanno_segmentation_benchmark_v0`.
72 |
73 |
74 | ### Train CompoNet
75 | To train CompoNet on the Chair category with 400 points per part, run:
76 | ```
77 | python train.py
78 | ```
79 | Check the available options using:
80 | ```
81 | python train.py -h
82 | ```
83 |
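To train on a different category or point count, pass the corresponding flags. For example (assuming `train.py` exposes the same `--category` and `--num_point` flags as `test.py`; confirm with `python train.py -h`):
```
python train.py --category Table --num_point 300
```
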
84 | ### Generate Shapes Using the Trained Model
85 | To generate new shapes and visualize them, run:
86 | ```
87 | python test.py --category category --model_path model_path
88 | ```
89 | Check the available options using:
90 | ```
91 | python test.py -h
92 | ```
93 |
94 | ## License
95 | This project is licensed under the terms of the MIT license (see LICENSE for details).
96 |
--------------------------------------------------------------------------------
/anaconda_env/CompoNet.yml:
--------------------------------------------------------------------------------
1 | name: CompoNet
2 | channels:
3 | - menpo
4 | - conda-forge
5 | - anaconda
6 | - mw
7 | - defaults
8 | dependencies:
9 | - atk=2.8.0
10 | - backports=1.0
11 | - backports.functools_lru_cache=1.5
12 | - backports.shutil_get_terminal_size=1.0.0
13 | - backports_abc=0.5
14 | - blas=1.0
15 | - ca-certificates=2019.1.23
16 | - cairo=1.14.12
17 | - certifi=2018.11.29
18 | - configparser=3.5.0
19 | - cycler=0.10.0
20 | - dbus=1.10.22
21 | - decorator=4.3.0
22 | - entrypoints=0.2.3
23 | - enum34=1.1.6
24 | - expat=2.2.5
25 | - ffmpeg=2.7.0
26 | - fontconfig=2.12.6
27 | - freeimage=3.17.0
28 | - freetype=2.8.1
29 | - functools32=3.2.3.2
30 | - gdk-pixbuf=2.28.2
31 | - glib=2.53.6
32 | - gmp=6.1.2
33 | - gst-plugins-base=1.12.4
34 | - gstreamer=1.12.4
35 | - gtk2=2.24.31
36 | - h5py=2.8.0
37 | - hdf5=1.10.2
38 | - icu=58.2
39 | - imageio=1.5.0
40 | - intel-openmp=2018.0.0
41 | - ipaddress=1.0.22
42 | - ipykernel=4.10.0
43 | - ipython=5.8.0
44 | - ipython_genutils=0.2.0
45 | - ipywidgets=7.4.2
46 | - jasper=1.900.1
47 | - jinja2=2.10
48 | - jpeg=9b
49 | - jsonschema=2.6.0
50 | - jupyter=1.0.0
51 | - jupyter_client=5.2.3
52 | - jupyter_console=5.2.0
53 | - jupyter_core=4.4.0
54 | - libedit=3.1
55 | - libffi=3.2.1
56 | - libgcc-ng=7.2.0
57 | - libgfortran-ng=7.2.0
58 | - libiconv=1.15
59 | - libpng=1.6.34
60 | - libsodium=1.0.16
61 | - libstdcxx-ng=7.2.0
62 | - libtiff=4.0.9
63 | - libxcb=1.12
64 | - libxml2=2.9.7
65 | - linecache2=1.0.0
66 | - markupsafe=1.0
67 | - matplotlib=2.1.2
68 | - mistune=0.8.3
69 | - mkl=2018.0.1
70 | - nbconvert=5.3.1
71 | - nbformat=4.4.0
72 | - ncurses=6.0
73 | - notebook=5.7.0
74 | - olefile=0.45.1
75 | - opencv=2.4.11
76 | - openssl=1.0.2p
77 | - pandas=0.22.0
78 | - pandoc=2.2.3.2
79 | - pandocfilters=1.4.2
80 | - pango=1.22.4
81 | - pathlib2=2.3.2
82 | - pcre=8.39
83 | - pexpect=4.6.0
84 | - pickleshare=0.7.5
85 | - pillow=5.0.0
86 | - pip=19.0.3
87 | - pixman=0.34.0
88 | - pkg-config=0.29.2
89 | - pkgconfig=1.4.0
90 | - prometheus_client=0.4.2
91 | - prompt_toolkit=1.0.15
92 | - ptyprocess=0.6.0
93 | - pygments=2.2.0
94 | - pyparsing=2.2.0
95 | - pyqt=5.6.0
96 | - python=2.7.14
97 | - python-dateutil=2.6.1
98 | - pytz=2017.3
99 | - pyzmq=17.1.2
100 | - qt=5.6.2
101 | - qtconsole=4.4.2
102 | - readline=7.0
103 | - scandir=1.9.0
104 | - scikit-learn=0.19.1
105 | - scipy=1.0.0
106 | - send2trash=1.5.0
107 | - simplegeneric=0.8.1
108 | - singledispatch=3.4.0.3
109 | - sip=4.18
110 | - sqlite=3.22.0
111 | - ssl_match_hostname=3.5.0.1
112 | - subprocess32=3.2.7
113 | - terminado=0.8.1
114 | - testpath=0.4.2
115 | - tk=8.6.7
116 | - tornado=4.5.3
117 | - traceback2=1.4.0
118 | - traitlets=4.3.2
119 | - unittest2=1.1.0
120 | - wcwidth=0.1.7
121 | - webencodings=0.5.1
122 | - widgetsnbextension=3.4.2
123 | - xorg-libxau=1.0.8
124 | - xorg-libxdmcp=1.1.2
125 | - xz=5.2.3
126 | - zeromq=4.2.5
127 | - zlib=1.2.11
128 | - pip:
129 |   - absl-py==0.1.10
130 |   - backports-weakref==1.0.post1
131 |   - bleach==1.5.0
132 |   - funcsigs==1.0.2
133 |   - futures==3.2.0
134 |   - html5lib==0.9999999
135 |   - markdown==3.0.1
136 |   - mock==2.0.0
137 |   - numpy==1.16.2
138 |   - pbr==5.1.3
139 |   - protobuf==3.7.0
140 |   - setuptools==40.8.0
141 |   - six==1.12.0
142 |   - tensorflow==1.5.0
143 |   - tensorflow-gpu==1.4.0
144 |   - tensorflow-tensorboard==0.4.0
145 |   - tflearn==0.3.2
146 |   - werkzeug==0.14.1
147 |   - wheel==0.33.1
148 | prefix: /home/nadav/anaconda2/envs/tensorflow
149 |
150 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/data/__init__.py
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import part_dataset_ae
5 | import part_dataset_pcn
6 |
7 |
8 | def load_data(data_path, num_point, category, seen_split, unseen_split):
9 | pcn_train_dataset, pcn_test_dataset, num_parts = load_pcn_data(data_path, num_point, category, seen_split,
10 | unseen_split)
11 | ae_train_dataset, ae_test_dataset = load_aes_data(data_path, num_point, category, seen_split, unseen_split,
12 | num_parts)
13 |
14 | return pcn_train_dataset, pcn_test_dataset, ae_train_dataset, ae_test_dataset, num_parts
15 |
16 |
17 | def load_pcn_data(data_path, num_point, category, seen_split, unseen_split):
18 | pcn_train_dataset = part_dataset_pcn.PartDatasetPCN(root=data_path, npoints=num_point, class_choice=category,
19 | split=seen_split)
20 | pcn_test_dataset = part_dataset_pcn.PartDatasetPCN(root=data_path, npoints=num_point, class_choice=category,
21 | split=unseen_split)
22 | num_parts = pcn_train_dataset.get_number_of_parts()
23 |
24 | return pcn_train_dataset, pcn_test_dataset, num_parts
25 |
26 |
27 | def load_aes_data(data_path, num_point, category, seen_split, unseen_split, num_parts):
28 | ae_train_dataset = []
29 | ae_test_dataset = []
30 | for i in xrange(num_parts):
31 | print 'Loading part ' + str(i)
32 | ae_train_dataset.append(
33 | part_dataset_ae.PartDatasetAE(root=data_path, npoints=num_point, class_choice=category, split=seen_split,
34 | part_label=i))
35 | ae_test_dataset.append(
36 | part_dataset_ae.PartDatasetAE(root=data_path, npoints=num_point, class_choice=category, split=unseen_split,
37 | part_label=i))
38 |
39 | return ae_train_dataset, ae_test_dataset
40 |
41 |
42 | def pc_normalize(pc):
43 | centroid = np.mean(pc, axis=0)
44 | pc = pc - centroid
45 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
46 | pc = pc / m
47 | return pc
48 |
49 |
50 | def data_parse(catfile, class_choice, root, split):
51 | cat = {}
52 | with open(catfile, 'r') as f:
53 | for line in f:
54 | ls = line.strip().split()
55 | cat[ls[0]] = ls[1]
56 |
57 | cat = {k: v for k, v in cat.items() if k in class_choice}
58 |
59 | meta = {}
60 | with open(os.path.join(root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
61 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
62 | with open(os.path.join(root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
63 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
64 | with open(os.path.join(root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
65 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
66 | for item in cat:
67 | meta[item] = []
68 | dir_point = os.path.join(root, cat[item], 'points')
69 | dir_seg = os.path.join(root, cat[item], 'points_label')
70 | fns = sorted(os.listdir(dir_point))
71 | if split == 'trainval':
72 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
73 | elif split == 'train':
74 | fns = [fn for fn in fns if fn[0:-4] in train_ids]
75 | elif split == 'val':
76 | fns = [fn for fn in fns if fn[0:-4] in val_ids]
77 | elif split == 'test':
78 | fns = [fn for fn in fns if fn[0:-4] in test_ids]
79 | else:
80 | print('Unknown split: %s. Exiting..' % (split))
81 | exit(-1)
82 |
83 | for fn in fns:
84 | token = (os.path.splitext(os.path.basename(fn))[0])
85 | meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))
86 |
87 | return cat, meta
88 |
89 |
90 | def compute_num_of_parts(path):
91 | max_num_parts = 0
92 | for i in range(len(path) / 50):  # checking the first ~2% of the shapes is enough to find the max part count
93 | num_parts = len(np.unique(np.loadtxt(path[i][-1]).astype(np.uint8)))
94 | if num_parts > max_num_parts:
95 | max_num_parts = num_parts
96 | return max_num_parts
97 |
98 |
99 | def get_point_set_and_seg(path, index):
100 | fn = path[index]
101 | point_set = np.loadtxt(fn[1]).astype(np.float32)
102 | seg = np.loadtxt(fn[2]).astype(np.int64) - 1  # part labels are 1-based on disk; shift to 0-based
103 |
104 | return point_set, seg
105 |
106 |
107 | def choose_points(point_set, npoints):
108 | choice = np.random.choice(len(point_set), npoints, replace=True)
109 | point_set = point_set[choice]
110 |
111 | return point_set, choice
112 |
--------------------------------------------------------------------------------
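
Both dataset classes below rely on the pad-or-subsample policy implemented by `choose_points` and its callers: parts with too many points are randomly subsampled, parts with too few are padded by re-drawing existing points. A minimal NumPy sketch of that policy (not part of the repository):
```python
import numpy as np

def resample_to_fixed_size(point_set, npoints):
    # Too many points: keep npoints drawn at random (with replacement,
    # as choose_points does). Too few: pad by re-drawing existing points.
    n = len(point_set)
    if n >= npoints:
        return point_set[np.random.choice(n, npoints, replace=True)]
    extra = point_set[np.random.choice(n, npoints - n, replace=True)]
    return np.append(point_set, extra, axis=0)

part = np.random.randn(137, 3).astype(np.float32)
print(resample_to_fixed_size(part, 400).shape)  # (400, 3)
```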
/data/part_dataset_ae.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import data_utils
4 |
5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
6 |
7 |
8 | class PartDatasetAE():
9 | def __init__(self, root, npoints=400, class_choice='Chair', split='train', part_label=None):
10 | if part_label is None:
11 | print 'Wrong part label - part_dataset_ae'
12 | exit(1)
13 |
14 | self.npoints = npoints
15 | self.part_label = part_label
16 | self.cache = {} # caching the loaded parts
17 |
18 | cat, meta = data_utils.data_parse(os.path.join(root, 'synsetoffset2category.txt'), class_choice, root, split)
19 |
20 | self.datapath = []
21 | for item in cat:
22 | for fn in meta[item]:
23 | # discard missing parts
24 | seg = np.loadtxt(fn[1]).astype(np.int64) - 1
25 | part_points = np.where(seg == self.part_label)
26 | if len(part_points[0]) > 1:
27 | self.datapath.append((item, fn[0], fn[1]))
28 |
29 | def __getitem__(self, index):
30 | if index in self.cache:
31 | point_set = self.cache[index]
32 | else:
33 | point_set, seg = data_utils.get_point_set_and_seg(self.datapath, index)
34 | part_points = np.where(seg == self.part_label)
35 | point_set = data_utils.pc_normalize(point_set[part_points])
36 | self.cache[index] = point_set
37 |
38 | # choose the right number of points by
39 | # randomly picking, if there are too many,
40 | # or re-sampling, if there are fewer than needed
41 | point_set_length = len(point_set)
42 | if point_set_length >= self.npoints:
43 | point_set, _ = data_utils.choose_points(point_set, self.npoints)
44 | else:
45 | extra_point_set, choice = data_utils.choose_points(point_set, self.npoints - point_set_length)
46 | point_set = np.append(point_set, extra_point_set, axis=0)
47 |
48 | return point_set
49 |
50 | def __len__(self):
51 | return len(self.datapath)
52 |
53 |
54 | if __name__ == '__main__':
55 | from utils import show3d_balls
56 |
57 | d = PartDatasetAE(root=os.path.join(BASE_DIR, '../datasets/shapenetcore_partanno_segmentation_benchmark_v0'),
58 | class_choice='Chair', split='test', part_label=1)
59 | i = 27
60 | ps = d[i]
61 | show3d_balls.showpoints(ps, ballradius=8)
62 |
--------------------------------------------------------------------------------
/data/part_dataset_pcn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import data_utils
4 |
5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
6 |
7 |
8 | class PartDatasetPCN:
9 | def __init__(self, root, npoints=400, class_choice='Chair', split='train'):
10 | self.npoints = npoints
11 | self.cache = {} # caching the loaded parts
12 |
13 | cat, meta = data_utils.data_parse(os.path.join(root, 'synsetoffset2category.txt'), class_choice, root, split)
14 |
15 | self.datapath = []
16 | for item in cat:
17 | for fn in meta[item]:
18 | self.datapath.append((item, fn[0], fn[1]))
19 |
20 | self.num_parts = data_utils.compute_num_of_parts(self.datapath)
21 |
22 | def __getitem__(self, index):
23 | if index in self.cache:
24 | parts_point_sets = self.cache[index]
25 | else:
26 | point_set, seg = data_utils.get_point_set_and_seg(self.datapath, index)
27 | point_set = data_utils.pc_normalize(point_set)
28 | parts_point_sets = []
29 | for p in xrange(self.num_parts):
30 | part_points = np.where(seg == p)
31 | if len(part_points[0]) > 1:
32 | part_point_set = point_set[part_points]
33 | is_part_exist = True
34 | else:
35 | part_point_set = np.zeros((self.npoints, 3))
36 | is_part_exist = False
37 | # normalize each part on its own
38 | if is_part_exist:
39 | norm_part_point_set = data_utils.pc_normalize(part_point_set)
40 | else:
41 | norm_part_point_set = part_point_set
42 | parts_point_sets.append((part_point_set, norm_part_point_set, is_part_exist))
43 | self.cache[index] = parts_point_sets
44 |
45 | point_sets = []
46 | for point_set, norm_part_point_set, is_full in parts_point_sets:
47 | # choose the right number of points by
48 | # randomly picking, if there are too many,
49 | # or re-sampling, if there are fewer than needed
50 | point_set_length = len(point_set)
51 | if point_set_length >= self.npoints:
52 | point_set, choice = data_utils.choose_points(point_set, self.npoints)
53 | norm_part_point_set = norm_part_point_set[choice]
54 | else:
55 | extra_point_set, choice = data_utils.choose_points(point_set, self.npoints - point_set_length)
56 | point_set = np.append(point_set, extra_point_set, axis=0)
57 | norm_part_point_set = np.append(norm_part_point_set, norm_part_point_set[choice], axis=0)
58 | point_sets.append((point_set, norm_part_point_set, is_full))
59 |
60 | return point_sets
61 |
62 | def __len__(self):
63 | return len(self.datapath)
64 |
65 | def get_number_of_parts(self):
66 | return self.num_parts
67 |
68 |
69 | if __name__ == '__main__':
70 | from utils import show3d_balls
71 |
72 | d = PartDatasetPCN(root=os.path.join(BASE_DIR, '../datasets/shapenetcore_partanno_segmentation_benchmark_v0'),
73 | class_choice='Chair', split='test')
74 | i = 27
75 | point_sets = d[i]
76 | for p in xrange(d.get_number_of_parts()):
77 | ps, _, _ = point_sets[p]
78 | show3d_balls.showpoints(ps, ballradius=8)
79 |
--------------------------------------------------------------------------------
/datasets/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0.zip
3 | unzip shapenetcore_partanno_segmentation_benchmark_v0.zip
4 | rm shapenetcore_partanno_segmentation_benchmark_v0.zip
5 |
--------------------------------------------------------------------------------
/images/network_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/images/network_architecture.png
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from utils import tf_util
4 | from tf_ops.nn_distance import tf_nndistance
5 |
6 | BN_INIT_DECAY = 0.5
7 | BN_DECAY_RATE = 0.5
8 | BN_DECAY_CLIP = 0.99
9 |
10 |
11 | def placeholder_inputs(batch_size, num_point):
12 | point_clouds_ph = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
13 | gt_ph = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
14 | return point_clouds_ph, gt_ph
15 |
16 |
17 | def build_parts_aes_graphs(num_parts, num_points, part_embedding_dim, base_learning_rate, batch_size, decay_step,
18 | decay_rate, bn_decay_step):
19 | point_clouds_phs = []
20 | ae_ops = []
21 | for i in xrange(num_parts):
22 | with tf.variable_scope('part' + str(i)):
23 | print 'Graph part ' + str(i)
24 | point_clouds_ph, gt_ph = placeholder_inputs(batch_size, num_points)
25 | point_clouds_phs.append(point_clouds_ph)
26 |
27 | is_training_ph = tf.placeholder(tf.bool, shape=())
28 |
29 | batch = tf.Variable(0)
30 | bn_momentum = tf.train.exponential_decay(BN_INIT_DECAY, batch * batch_size, bn_decay_step, BN_DECAY_RATE,
31 | staircase=True)
32 | bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
33 | bn_decay_summary = tf.summary.scalar('bn_decay_' + str(i), bn_decay)  # summary names may not contain spaces
34 |
35 | print "Get model and loss"
36 | pred, end_points = get_model_ae(point_clouds_ph, is_training_ph, batch_size, num_points, bn_decay,
37 | part_embedding_dim)
38 | loss, end_points_tmp = get_loss_ae(pred, gt_ph, end_points)
39 | loss_summary = tf.summary.scalar('loss', loss)
40 |
41 | print "Get training operator"
42 | learning_rate = tf.train.exponential_decay(base_learning_rate, batch * batch_size, decay_step,
43 | decay_rate, staircase=True)
44 | learning_rate_summary = tf.summary.scalar('learning_rate', learning_rate)
45 | optimizer = tf.train.AdamOptimizer(learning_rate)
46 | train_op = optimizer.minimize(loss, global_step=batch)
47 |
48 | merged_summary = tf.summary.merge((bn_decay_summary, loss_summary, learning_rate_summary))
49 |
50 | ae_ops.append({'point_clouds_ph': point_clouds_ph,
51 | 'gt_ph': gt_ph,
52 | 'is_training_ph': is_training_ph,
53 | 'pred': pred,
54 | 'loss': loss,
55 | 'train_op': train_op,
56 | 'merged_summary': merged_summary,
57 | 'step': batch,
58 | 'end_points': end_points})
59 |
60 | return ae_ops, point_clouds_phs
61 |
62 |
63 | def get_model_ae(point_cloud, is_training, batch_size, num_point, bn_decay=None, embedding_dim=64, reuse=False):
64 | input_point_cloud = tf.expand_dims(point_cloud, -1)
65 | net, end_points = ae_encoder(batch_size, num_point, 3, input_point_cloud, is_training, bn_decay=bn_decay,
66 | embedding_dim=embedding_dim)
67 | net = ae_decoder(batch_size, num_point, net, is_training, bn_decay=bn_decay, reuse=reuse)
68 |
69 | return net, end_points
70 |
71 |
72 | def ae_encoder(batch_size, num_point, point_dim, input_image, is_training, bn_decay=None, embedding_dim=128):
73 | net = tf_util.conv2d(input_image, 64, [1, point_dim],
74 | padding='VALID', stride=[1, 1],
75 | bn=True, is_training=is_training,
76 | scope='conv1', bn_decay=bn_decay)
77 | net = tf_util.conv2d(net, 64, [1, 1],
78 | padding='VALID', stride=[1, 1],
79 | bn=True, is_training=is_training,
80 | scope='conv2', bn_decay=bn_decay)
81 | net = tf_util.conv2d(net, 64, [1, 1],
82 | padding='VALID', stride=[1, 1],
83 | bn=True, is_training=is_training,
84 | scope='conv3', bn_decay=bn_decay)
85 | net = tf_util.conv2d(net, 128, [1, 1],
86 | padding='VALID', stride=[1, 1],
87 | bn=True, is_training=is_training,
88 | scope='conv4', bn_decay=bn_decay)
89 | net = tf_util.conv2d(net, embedding_dim, [1, 1],
90 | padding='VALID', stride=[1, 1],
91 | bn=True, is_training=is_training,
92 | scope='conv5', bn_decay=bn_decay)
93 | global_feat = tf_util.max_pool2d(net, [num_point, 1],
94 | padding='VALID', scope='maxpool')
95 | net = tf.reshape(global_feat, [batch_size, -1])
96 | end_points = {'embedding': net}
97 |
98 | return net, end_points
99 |
100 |
101 | def ae_decoder(batch_size, num_point, net, is_training, bn_decay=None, reuse=False):
102 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay,
103 | reuse=reuse)
104 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay,
105 | reuse=reuse)
106 | net = tf_util.fully_connected(net, num_point * 3, activation_fn=None, scope='fc3', reuse=reuse)
107 | net = tf.reshape(net, (batch_size, num_point, 3))
108 |
109 | return net
110 |
111 |
112 | def get_loss_ae(pred, gt, end_points):
113 | dists_forward, _, dists_backward, _ = tf_nndistance.nn_distance(pred, gt)
114 | loss = tf.reduce_mean(dists_forward + dists_backward)
115 | end_points['pcloss'] = loss
116 |
117 | loss = loss * 100
118 | end_points['loss'] = loss
119 | return loss, end_points
120 |
121 |
122 | def build_parts_pcn_graph(ae_ops, point_clouds_ph, num_parts, num_points, noise_embedding_dim, base_learning_rate,
123 | batch_size, decay_step, decay_rate, bn_decay_step):
124 | print '\nGraph PCN'
125 | with tf.variable_scope('pcn'):
126 | pcn_is_training = tf.placeholder(tf.bool, shape=())
127 | y = tf.placeholder(tf.float32, [None, num_parts, num_points, 3], name='y')
128 | y_mask = tf.placeholder(tf.float32, [None, num_parts, num_points], name='y_mask')
129 | noise = tf.placeholder(tf.float32, shape=[None, noise_embedding_dim])
130 |
131 | pcn_enc = tf.concat([ae_ops[0]['end_points']['embedding'], ae_ops[1]['end_points']['embedding']], axis=-1)
132 | for p in xrange(2, num_parts):
133 | pcn_enc = tf.concat([pcn_enc, ae_ops[p]['end_points']['embedding']], axis=-1)
134 | pcn_enc = tf.concat([pcn_enc, noise], axis=-1)
135 |
136 | pcn_batch = tf.Variable(0)
137 | bn_momentum = tf.train.exponential_decay(BN_INIT_DECAY, pcn_batch * batch_size, bn_decay_step, BN_DECAY_RATE,
138 | staircase=True)
139 | pcn_bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
140 | bn_decay_summary = tf.summary.scalar('bn_decay', pcn_bn_decay)
141 |
142 | print "--- Get model and loss"
143 | x = tf.stack(point_clouds_ph, axis=1)
144 | y_hat = get_model_pcn(pcn_enc, x, num_parts, pcn_is_training, bn_decay=pcn_bn_decay)
145 | pcn_loss = get_loss_pcn(y, y_hat, y_mask, num_parts, batch_size)
146 | loss_summary = tf.summary.scalar('loss', pcn_loss)
147 |
148 | print "--- Get training operator"
149 | pcn_learning_rate = tf.train.exponential_decay(base_learning_rate, pcn_batch * batch_size, decay_step,
150 | decay_rate, staircase=True)
151 | learning_rate_summary = tf.summary.scalar('learning_rate', pcn_learning_rate)
152 | pcn_param = [var for var in tf.trainable_variables() if any(x in var.name for x in ['pcn'])]
153 | optimizer = tf.train.AdamOptimizer(pcn_learning_rate).minimize(pcn_loss, global_step=pcn_batch,
154 | var_list=pcn_param)
155 |
156 | merged_summary = tf.summary.merge((bn_decay_summary, loss_summary, learning_rate_summary))
157 |
158 | pcn_ops = ({'y': y,
159 | 'y_mask': y_mask,
160 | 'noise': noise,
161 | 'is_training_ph': pcn_is_training,
162 | 'pcn_enc': pcn_enc,
163 | 'pred': y_hat,
164 | 'loss': pcn_loss,
165 | 'train_op': optimizer,
166 | 'merged_summary': merged_summary,
167 | 'step': pcn_batch})
168 |
169 | return pcn_ops
170 |
171 |
172 | def get_model_pcn(pcn_enc, x, num_parts, is_training, bn_decay=None, reuse=False):
173 | with tf.variable_scope('trans', reuse=reuse):
174 | bias_initializer = np.array([[1., 0, 1, 0, 1, 0] for _ in xrange(num_parts)])
175 | bias_initializer = bias_initializer.astype('float32').flatten()
176 |
177 | net = tf_util.fully_connected(pcn_enc, 256, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay,
178 | reuse=reuse)
179 | net = tf_util.fully_connected(net, 128, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay,
180 | reuse=reuse)
181 | trans = tf_util.fully_connected(net, num_parts * 6, activation_fn=None, scope='fc3',
182 | weights_initializer=tf.zeros_initializer(),
183 | biases_initializer=tf.constant_initializer(bias_initializer), reuse=reuse)
184 |
185 | # Perform transformation
186 | with tf.variable_scope('pcn', reuse=reuse):
187 | zeros_dims = tf.stack([tf.shape(x)[0], 1])
188 | zeros_col = tf.fill(zeros_dims, 0.0)
189 | '''
190 | sx 0 0 tx
191 | 0 sy 0 ty
192 | 0 0 sz tz
193 | '''
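# Each part contributes six parameters [sx, tx, sy, ty, sz, tz], interleaved
# with zero columns into the 3x4 matrix sketched above; the bias initializer
# [1, 0, 1, 0, 1, 0] therefore starts every part at the identity transform
# (unit scale, zero translation).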
194 | trans_mat = tf.concat((tf.expand_dims(trans[:, 0], axis=1), zeros_col, zeros_col,
195 | tf.expand_dims(trans[:, 1], axis=1), zeros_col, tf.expand_dims(trans[:, 2], axis=1),
196 | zeros_col, tf.expand_dims(trans[:, 3], axis=1), zeros_col, zeros_col,
197 | tf.expand_dims(trans[:, 4], axis=1), tf.expand_dims(trans[:, 5], axis=1),
198 |
199 | tf.expand_dims(trans[:, 6], axis=1), zeros_col, zeros_col,
200 | tf.expand_dims(trans[:, 7], axis=1), zeros_col, tf.expand_dims(trans[:, 8], axis=1),
201 | zeros_col, tf.expand_dims(trans[:, 9], axis=1), zeros_col, zeros_col,
202 | tf.expand_dims(trans[:, 10], axis=1), tf.expand_dims(trans[:, 11], axis=1)), axis=1)
203 | for p in xrange(2, num_parts):
204 | start_ind = 6 * p
205 | trans_mat = tf.concat((trans_mat, tf.expand_dims(trans[:, start_ind], axis=1), zeros_col, zeros_col,
206 | tf.expand_dims(trans[:, start_ind + 1], axis=1), zeros_col,
207 | tf.expand_dims(trans[:, start_ind + 2], axis=1), zeros_col,
208 | tf.expand_dims(trans[:, start_ind + 3], axis=1), zeros_col, zeros_col,
209 | tf.expand_dims(trans[:, start_ind + 4], axis=1),
210 | tf.expand_dims(trans[:, start_ind + 5], axis=1)), axis=1)
211 |
212 | trans_mat = tf.reshape(trans_mat, (-1, num_parts, 3, 4))
213 | # adding 1 (w coordinate) to every point (x,y,z,1)
214 | w = tf.ones([tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], 1])
215 | x = tf.concat((x, w), axis=-1)
216 | x_t = tf.transpose(x, [0, 1, 3, 2])
217 | y_hat_t = tf.matmul(trans_mat, x_t)
218 | y_hat = tf.transpose(y_hat_t, [0, 1, 3, 2])
219 |
220 | return y_hat
221 |
222 |
223 | def get_loss_pcn(gt, pred, gt_mask, num_parts, batch_size):
224 | dists_forward_total = tf.zeros(batch_size)
225 | dists_backward_total = tf.zeros(batch_size)
226 | for part in xrange(num_parts):
227 | dists_forward, _, dists_backward, _ = tf_nndistance.nn_distance(pred[:, part], gt[:, part])
228 | # zero out the non-existing parts
229 | dists_forward = tf.reduce_sum(tf.multiply(dists_forward, gt_mask[:, part]), axis=-1)
230 | dists_backward = tf.reduce_sum(tf.multiply(dists_backward, gt_mask[:, part]), axis=-1)
231 | dists_forward_total += dists_forward
232 | dists_backward_total += dists_backward
233 |
234 | loss = dists_forward_total + dists_backward_total
235 | # divide by the number of existing parts (gt_mask averages to 1 for a present part, 0 for a missing one)
236 | div = tf.reduce_sum(tf.reduce_mean(gt_mask, axis=-1), axis=-1)
237 | loss = tf.reduce_mean(tf.div(loss, div))
238 |
239 | return loss * 100
240 |
241 |
242 | def build_test_graph(num_parts, num_points, part_embedding_dim, base_learning_rate, batch_size, decay_step, decay_rate,
243 | bn_decay_step, noise_embedding_dim):
244 | ae_ops, point_clouds_ph = build_parts_aes_graphs(num_parts, num_points, part_embedding_dim, base_learning_rate,
245 | batch_size, decay_step, decay_rate, bn_decay_step)
246 | for i in xrange(num_parts):
247 | with tf.variable_scope('part' + str(i)):
248 | samples = (tf.placeholder(tf.float32, shape=(batch_size, part_embedding_dim)))
249 | dec = ae_decoder(batch_size, num_points, samples, ae_ops[i]["is_training_ph"], reuse=True)
250 | ae_ops[i]['samples'] = samples
251 | ae_ops[i]['dec'] = dec
252 |
253 | pcn_ops = build_parts_pcn_graph(ae_ops, point_clouds_ph, num_parts, num_points, noise_embedding_dim,
254 | base_learning_rate, batch_size, decay_step, decay_rate, bn_decay_step)
255 |
256 | with tf.variable_scope('pcn'):
257 | x_full = tf.stack((ae_ops[0]['dec'], ae_ops[1]['dec']), axis=1)
258 | for p in xrange(2, num_parts):
259 | x_full = tf.concat((x_full, tf.expand_dims(ae_ops[p]['dec'], axis=1)), axis=1)
260 |
261 | cpcn_enc_full = tf.concat([ae_ops[0]['samples'], ae_ops[1]['samples']], axis=-1)
262 | for p in xrange(2, num_parts):
263 | cpcn_enc_full = tf.concat([cpcn_enc_full, ae_ops[p]['samples']], axis=-1)
264 | cpcn_enc_full = tf.concat([cpcn_enc_full, pcn_ops['noise']], axis=-1)
265 |
266 | y_hat_full = get_model_pcn(cpcn_enc_full, x_full, num_parts, pcn_ops['is_training_ph'], reuse=True)
267 | pcn_ops['pred_full'] = y_hat_full
268 |
269 | return ae_ops, pcn_ops
270 |
--------------------------------------------------------------------------------
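
The losses in `models.py` are built on the symmetric Chamfer distance computed by the custom `nn_distance` op: for each point, the squared distance to its nearest neighbor in the other cloud, averaged in both directions (`get_loss_ae` then scales it by 100). A minimal NumPy sketch of that quantity for a single pair of clouds, for illustration only:
```python
import numpy as np

def chamfer_distance(pc1, pc2):
    # pc1: (N, 3), pc2: (M, 3). Squared nearest-neighbor distances in both
    # directions -- the (dists_forward, dists_backward) pair of tf_nndistance.
    d = np.sum((pc1[:, None, :] - pc2[None, :, :]) ** 2, axis=-1)  # (N, M)
    return d.min(axis=1).mean() + d.min(axis=0).mean()

pred = np.random.randn(400, 3).astype(np.float32)
gt = np.random.randn(400, 3).astype(np.float32)
print(chamfer_distance(pred, gt))
```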
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import tensorflow as tf
4 | import os
5 | import models
6 | from data import data_utils
7 | from utils import show3d_balls
8 | from sklearn.mixture import GaussianMixture
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--data_path', default='datasets/shapenetcore_partanno_segmentation_benchmark_v0',
12 | help='Path to the dataset [default: datasets/shapenetcore_partanno_segmentation_benchmark_v0]')
13 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
14 | parser.add_argument('--category', default='Chair', help='Which single class to train on [default: Chair]')
15 | parser.add_argument('--part_embedding_dim', type=int, default=64, help='Embedding dimension of each part [default: 64]')
16 | parser.add_argument('--noise_embedding_dim', type=int, default=16,
17 | help='Embedding dimension of the noise [default: 16]')
18 | parser.add_argument('--num_point', type=int, default=400, help='Number of points per part [default: 400]')
19 | parser.add_argument('--learning_rate', type=float, default=0.001,
20 | help='Initial learning rate for the parts composition network [default: 0.001]')
21 | parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
22 | parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]')
23 | parser.add_argument('--num_parts', type=int, default=0,
24 |                     help='Number of parts; if set to 0 it is inferred from the data, which takes longer [default: 0]')
25 | parser.add_argument('--model_path', default='log/model.ckpt',
26 | help='model checkpoint file path [default: log/model.ckpt]')
27 | parser.add_argument('--num_samples', type=int, default=100,
28 |                     help='Number of generated shapes to be shown [default: 100]')
29 |
30 | FLAGS = parser.parse_args()
31 | GPU_INDEX = FLAGS.gpu
32 | CATEGORY = FLAGS.category
33 | PART_EMBEDDING_DIM = FLAGS.part_embedding_dim
34 | NOISE_EMBEDDING_DIM = FLAGS.noise_embedding_dim
35 | NUM_POINTS = FLAGS.num_point
36 | BASE_LEARNING_RATE = FLAGS.learning_rate
37 | DECAY_STEP = FLAGS.decay_step
38 | DECAY_RATE = FLAGS.decay_rate
39 | NUM_PARTS = FLAGS.num_parts
40 | MODEL_PATH = FLAGS.model_path
41 | NUM_SAMPLES = FLAGS.num_samples
42 |
43 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
44 | COLORS = [np.array([163, 254, 170]), np.array([206, 178, 254]), np.array([248, 250, 132]), np.array([237, 186, 145]),
45 | np.array([192, 144, 145]), np.array([158, 218, 73])]
46 |
47 | print 'Loading data'
48 | # Shapenet official train/test split
49 | DATA_PATH = os.path.join(ROOT_DIR, FLAGS.data_path)
50 | # Using the same splits as when training
51 | if NUM_PARTS == 0:
52 | _, _, AE_TRAIN_DATASET, _, NUM_PARTS = data_utils.load_data(DATA_PATH, NUM_POINTS, CATEGORY, 'test', 'trainval')
53 | else:
54 | AE_TRAIN_DATASET, _ = data_utils.load_aes_data(DATA_PATH, NUM_POINTS, CATEGORY, 'test', 'trainval', NUM_PARTS)
55 |
56 |
57 | def get_model(batch_size):
58 | with tf.Graph().as_default():
59 | with tf.device('/gpu:' + str(GPU_INDEX)):
60 | ae_ops, pcn_ops = models.build_test_graph(NUM_PARTS, NUM_POINTS, PART_EMBEDDING_DIM, BASE_LEARNING_RATE,
61 | batch_size, DECAY_STEP, DECAY_RATE, float(DECAY_STEP),
62 | NOISE_EMBEDDING_DIM)
63 |
64 | saver = tf.train.Saver()
65 | # Create a session
66 | config = tf.ConfigProto()
67 | config.gpu_options.allow_growth = True
68 | config.allow_soft_placement = True
69 | sess = tf.Session(config=config)
70 | # Restore variables from disk.
71 | saver.restore(sess, MODEL_PATH)
72 |
73 | return sess, ae_ops, pcn_ops
74 |
75 |
76 | def compute_embedding_gmm_and_sample_vectors(sess, ae_ops):
77 | num_gmm_components = 20
78 | samples = []
79 | for p in xrange(NUM_PARTS):
80 | part = []
81 | print 'Embedding part ' + str(p)
82 | for i in xrange(len(AE_TRAIN_DATASET[p])):
83 | ps = AE_TRAIN_DATASET[p][i]
84 | feed_dict = {ae_ops[p]['point_clouds_ph']: np.expand_dims(ps, axis=0),
85 | ae_ops[p]['gt_ph']: np.expand_dims(ps, axis=0),
86 | ae_ops[p]['is_training_ph']: False, }
87 |
88 | part.append(sess.run(ae_ops[p]['end_points']['embedding'], feed_dict=feed_dict))
89 | print 'Compute GMM for part', p
90 | gmm = GaussianMixture(n_components=num_gmm_components, covariance_type='full')
91 | gmm.fit(np.squeeze(np.array(part)))
92 | sample, _ = gmm.sample(n_samples=NUM_SAMPLES)
93 | sample = sample[np.random.permutation(np.arange(NUM_SAMPLES))]
94 | samples.append(sample)
95 |
96 | return samples
97 |
98 |
99 | def generate_shapes_from_vectors(sess, ae_ops, pcn_ops, samples):
100 | for i in range(NUM_SAMPLES):
101 | noise = np.random.normal(size=[1, NOISE_EMBEDDING_DIM])
102 | shapes_embedd = np.stack((np.expand_dims(samples[0][i], axis=0), np.expand_dims(samples[1][i], axis=0)), axis=0)
103 | for p in xrange(2, NUM_PARTS):
104 | shapes_embedd = np.concatenate(
105 | (shapes_embedd, np.expand_dims(np.expand_dims(samples[p][i], axis=0), axis=0)), axis=0)
106 |
107 | # Demonstrate a missing part
108 | if np.random.randint(1000) % 2 == 0:
109 | missing_part = True
110 | shapes_embedd[-1] = np.zeros((1, PART_EMBEDDING_DIM))
111 | else:
112 | missing_part = False
113 |
114 | feed_dict = {}
115 | for p in xrange(NUM_PARTS):
116 | feed_dict[ae_ops[p]['samples']] = shapes_embedd[p]
117 | feed_dict[ae_ops[p]['is_training_ph']] = False
118 | feed_dict[pcn_ops['noise']] = noise
119 | feed_dict[pcn_ops['is_training_ph']] = False
120 | pred = sess.run(pcn_ops['pred_full'], feed_dict=feed_dict)
121 |
122 | preds = np.concatenate((pred[0, 0], pred[0, 1]), axis=0)
123 | for p in xrange(2, NUM_PARTS):
124 | preds = np.concatenate((preds, pred[0, p]), axis=0)
125 |
126 | show_3d_point_clouds(preds, missing_part)
127 |
128 |
129 | def show_3d_point_clouds(shapes, is_missing_part):
130 | colors = np.zeros_like(shapes)
131 | for p in xrange(NUM_PARTS):
132 | colors[NUM_POINTS * p:NUM_POINTS * (p + 1), :] = COLORS[p]
133 |
134 | # fix orientation
135 | shapes[:, 1] *= -1
136 | shapes = shapes[:, [1, 2, 0]]
137 |
138 | if is_missing_part:
139 | shapes = shapes[:NUM_POINTS * (NUM_PARTS - 1)]
140 | colors = colors[:NUM_POINTS * (NUM_PARTS - 1)]
141 | show3d_balls.showpoints(shapes, c_gt=colors, ballradius=8, normalizecolor=False, background=[255, 255, 255])
142 |
143 |
144 | def test():
145 | sess, ae_ops, pcn_ops = get_model(batch_size=1)
146 | samples = compute_embedding_gmm_and_sample_vectors(sess, ae_ops)
147 | generate_shapes_from_vectors(sess, ae_ops, pcn_ops, samples)
148 |
149 |
150 | if __name__ == "__main__":
151 | test()
152 |
--------------------------------------------------------------------------------
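
`compute_embedding_gmm_and_sample_vectors` above fits a 20-component, full-covariance Gaussian mixture to each part's training embeddings and draws novel latent codes from it. A standalone toy of that fit-and-sample round trip, with random stand-in embeddings in place of the encoder outputs:
```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Random stand-in for the (num_shapes, 64) embedding matrix that test.py
# collects from ae_ops[p]['end_points']['embedding'].
embeddings = np.random.randn(2000, 64)

gmm = GaussianMixture(n_components=20, covariance_type='full')
gmm.fit(embeddings)
new_codes, _ = gmm.sample(n_samples=10)  # 10 novel 64-d latent codes
print(new_codes.shape)  # (10, 64)
```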
/tf_ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/tf_ops/__init__.py
--------------------------------------------------------------------------------
/tf_ops/nn_distance/README.md:
--------------------------------------------------------------------------------
1 | From https://github.com/fanhqme/PointSetGeneration/tree/master/depthestimate
2 |
--------------------------------------------------------------------------------
/tf_ops/nn_distance/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/tf_ops/nn_distance/__init__.py
--------------------------------------------------------------------------------
/tf_ops/nn_distance/tf_nndistance.cpp:
--------------------------------------------------------------------------------
1 | #include "tensorflow/core/framework/op.h"
2 | #include "tensorflow/core/framework/op_kernel.h"
3 | REGISTER_OP("NnDistance")
4 | .Input("xyz1: float32")
5 | .Input("xyz2: float32")
6 | .Output("dist1: float32")
7 | .Output("idx1: int32")
8 | .Output("dist2: float32")
9 | .Output("idx2: int32");
10 | REGISTER_OP("NnDistanceGrad")
11 | .Input("xyz1: float32")
12 | .Input("xyz2: float32")
13 | .Input("grad_dist1: float32")
14 | .Input("idx1: int32")
15 | .Input("grad_dist2: float32")
16 | .Input("idx2: int32")
17 | .Output("grad_xyz1: float32")
18 | .Output("grad_xyz2: float32");
19 | using namespace tensorflow;
20 |
21 | static void nnsearch(int b,int n,int m,const float * xyz1,const float * xyz2,float * dist,int * idx){
22 | for (int i=0;i<b;i++){
23 | for (int j=0;j<n;j++){
24 | float x1=xyz1[(i*n+j)*3+0];
25 | float y1=xyz1[(i*n+j)*3+1];
26 | float z1=xyz1[(i*n+j)*3+2];
27 | double best=0;
28 | int besti=0;
29 | for (int k=0;k<m;k++){
30 | float x2=xyz2[(i*m+k)*3+0]-x1;
31 | float y2=xyz2[(i*m+k)*3+1]-y1;
32 | float z2=xyz2[(i*m+k)*3+2]-z1;
33 | double d=x2*x2+y2*y2+z2*z2;
34 | if (k==0 || d<best){
35 | best=d;
36 | besti=k;
37 | }
38 | }
39 | dist[i*n+j]=best;
40 | idx[i*n+j]=besti;
41 | }
42 | }
43 | }
44 |
45 | class NnDistanceOp : public OpKernel{
46 | public:
47 | explicit NnDistanceOp(OpKernelConstruction* context):OpKernel(context){}
48 | void Compute(OpKernelContext * context)override{
49 | const Tensor& xyz1_tensor=context->input(0);
50 | const Tensor& xyz2_tensor=context->input(1);
51 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)"));
52 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1"));
53 | int b=xyz1_tensor.shape().dim_size(0);
54 | int n=xyz1_tensor.shape().dim_size(1);
55 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)"));
56 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2"));
57 | int m=xyz2_tensor.shape().dim_size(1);
58 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size"));
59 | auto xyz1_flat=xyz1_tensor.flat<float>();
60 | const float * xyz1=&xyz1_flat(0);
61 | auto xyz2_flat=xyz2_tensor.flat<float>();
62 | const float * xyz2=&xyz2_flat(0);
63 | Tensor * dist1_tensor=NULL;
64 | Tensor * idx1_tensor=NULL;
65 | Tensor * dist2_tensor=NULL;
66 | Tensor * idx2_tensor=NULL;
67 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor));
68 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor));
69 | auto dist1_flat=dist1_tensor->flat<float>();
70 | auto idx1_flat=idx1_tensor->flat<int>();
71 | OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor));
72 | OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor));
73 | auto dist2_flat=dist2_tensor->flat<float>();
74 | auto idx2_flat=idx2_tensor->flat<int>();
75 | float * dist1=&(dist1_flat(0));
76 | int * idx1=&(idx1_flat(0));
77 | float * dist2=&(dist2_flat(0));
78 | int * idx2=&(idx2_flat(0));
79 | nnsearch(b,n,m,xyz1,xyz2,dist1,idx1);
80 | nnsearch(b,m,n,xyz2,xyz1,dist2,idx2);
81 | }
82 | };
83 | REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_CPU), NnDistanceOp);
84 | class NnDistanceGradOp : public OpKernel{
85 | public:
86 | explicit NnDistanceGradOp(OpKernelConstruction* context):OpKernel(context){}
87 | void Compute(OpKernelContext * context)override{
88 | const Tensor& xyz1_tensor=context->input(0);
89 | const Tensor& xyz2_tensor=context->input(1);
90 | const Tensor& grad_dist1_tensor=context->input(2);
91 | const Tensor& idx1_tensor=context->input(3);
92 | const Tensor& grad_dist2_tensor=context->input(4);
93 | const Tensor& idx2_tensor=context->input(5);
94 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)"));
95 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1"));
96 | int b=xyz1_tensor.shape().dim_size(0);
97 | int n=xyz1_tensor.shape().dim_size(1);
98 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)"));
99 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2"));
100 | int m=xyz2_tensor.shape().dim_size(1);
101 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size"));
102 | OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)"));
103 | OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)"));
104 | OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)"));
105 | OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)"));
106 | auto xyz1_flat=xyz1_tensor.flat<float>();
107 | const float * xyz1=&xyz1_flat(0);
108 | auto xyz2_flat=xyz2_tensor.flat<float>();
109 | const float * xyz2=&xyz2_flat(0);
110 | auto idx1_flat=idx1_tensor.flat<int>();
111 | const int * idx1=&idx1_flat(0);
112 | auto idx2_flat=idx2_tensor.flat<int>();
113 | const int * idx2=&idx2_flat(0);
114 | auto grad_dist1_flat=grad_dist1_tensor.flat<float>();
115 | const float * grad_dist1=&grad_dist1_flat(0);
116 | auto grad_dist2_flat=grad_dist2_tensor.flat<float>();
117 | const float * grad_dist2=&grad_dist2_flat(0);
118 | Tensor * grad_xyz1_tensor=NULL;
119 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor));
120 | Tensor * grad_xyz2_tensor=NULL;
121 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor));
122 | auto grad_xyz1_flat=grad_xyz1_tensor->flat<float>();
123 | float * grad_xyz1=&grad_xyz1_flat(0);
124 | auto grad_xyz2_flat=grad_xyz2_tensor->flat<float>();
125 | float * grad_xyz2=&grad_xyz2_flat(0);
126 | for (int i=0;i<b*n*3;i++)
127 | grad_xyz1[i]=0;
128 | for (int i=0;i<b*m*3;i++)
129 | grad_xyz2[i]=0;
130 | for (int i=0;i<b;i++){
131 | for (int j=0;j<n;j++){
132 | float x1=xyz1[(i*n+j)*3+0];
133 | float y1=xyz1[(i*n+j)*3+1];
134 | float z1=xyz1[(i*n+j)*3+2];
135 | int j2=idx1[i*n+j];
136 | float x2=xyz2[(i*m+j2)*3+0];
137 | float y2=xyz2[(i*m+j2)*3+1];
138 | float z2=xyz2[(i*m+j2)*3+2];
139 | float g=grad_dist1[i*n+j]*2;
140 | grad_xyz1[(i*n+j)*3+0]+=g*(x1-x2);
141 | grad_xyz1[(i*n+j)*3+1]+=g*(y1-y2);
142 | grad_xyz1[(i*n+j)*3+2]+=g*(z1-z2);
143 | grad_xyz2[(i*m+j2)*3+0]-=g*(x1-x2);
144 | grad_xyz2[(i*m+j2)*3+1]-=g*(y1-y2);
145 | grad_xyz2[(i*m+j2)*3+2]-=g*(z1-z2);
146 | }
147 | for (int j=0;j<m;j++){
148 | float x1=xyz2[(i*m+j)*3+0];
149 | float y1=xyz2[(i*m+j)*3+1];
150 | float z1=xyz2[(i*m+j)*3+2];
151 | int j2=idx2[i*m+j];
152 | float x2=xyz1[(i*n+j2)*3+0];
153 | float y2=xyz1[(i*n+j2)*3+1];
154 | float z2=xyz1[(i*n+j2)*3+2];
155 | float g=grad_dist2[i*m+j]*2;
156 | grad_xyz2[(i*m+j)*3+0]+=g*(x1-x2);
157 | grad_xyz2[(i*m+j)*3+1]+=g*(y1-y2);
158 | grad_xyz2[(i*m+j)*3+2]+=g*(z1-z2);
159 | grad_xyz1[(i*n+j2)*3+0]-=g*(x1-x2);
160 | grad_xyz1[(i*n+j2)*3+1]-=g*(y1-y2);
161 | grad_xyz1[(i*n+j2)*3+2]-=g*(z1-z2);
162 | }
163 | }
164 | }
165 | };
166 | REGISTER_KERNEL_BUILDER(Name("NnDistanceGrad").Device(DEVICE_CPU), NnDistanceGradOp);
167 |
168 | void NmDistanceKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,float * dist1,int * idx1,float * dist2,int * idx2);
169 | class NnDistanceGpuOp : public OpKernel{
170 | public:
171 | explicit NnDistanceGpuOp(OpKernelConstruction* context):OpKernel(context){}
172 | void Compute(OpKernelContext * context)override{
173 | const Tensor& xyz1_tensor=context->input(0);
174 | const Tensor& xyz2_tensor=context->input(1);
175 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)"));
176 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1"));
177 | int b=xyz1_tensor.shape().dim_size(0);
178 | int n=xyz1_tensor.shape().dim_size(1);
179 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)"));
180 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2"));
181 | int m=xyz2_tensor.shape().dim_size(1);
182 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size"));
183 | auto xyz1_flat=xyz1_tensor.flat<float>();
184 | const float * xyz1=&xyz1_flat(0);
185 | auto xyz2_flat=xyz2_tensor.flat<float>();
186 | const float * xyz2=&xyz2_flat(0);
187 | Tensor * dist1_tensor=NULL;
188 | Tensor * idx1_tensor=NULL;
189 | Tensor * dist2_tensor=NULL;
190 | Tensor * idx2_tensor=NULL;
191 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor));
192 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor));
193 | auto dist1_flat=dist1_tensor->flat<float>();
194 | auto idx1_flat=idx1_tensor->flat<int>();
195 | OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor));
196 | OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor));
197 | auto dist2_flat=dist2_tensor->flat<float>();
198 | auto idx2_flat=idx2_tensor->flat<int>();
199 | float * dist1=&(dist1_flat(0));
200 | int * idx1=&(idx1_flat(0));
201 | float * dist2=&(dist2_flat(0));
202 | int * idx2=&(idx2_flat(0));
203 | NmDistanceKernelLauncher(b,n,xyz1,m,xyz2,dist1,idx1,dist2,idx2);
204 | }
205 | };
206 | REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_GPU), NnDistanceGpuOp);
207 |
208 | void NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2);
209 | class NnDistanceGradGpuOp : public OpKernel{
210 | public:
211 | explicit NnDistanceGradGpuOp(OpKernelConstruction* context):OpKernel(context){}
212 | void Compute(OpKernelContext * context)override{
213 | const Tensor& xyz1_tensor=context->input(0);
214 | const Tensor& xyz2_tensor=context->input(1);
215 | const Tensor& grad_dist1_tensor=context->input(2);
216 | const Tensor& idx1_tensor=context->input(3);
217 | const Tensor& grad_dist2_tensor=context->input(4);
218 | const Tensor& idx2_tensor=context->input(5);
219 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)"));
220 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1"));
221 | int b=xyz1_tensor.shape().dim_size(0);
222 | int n=xyz1_tensor.shape().dim_size(1);
223 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)"));
224 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2"));
225 | int m=xyz2_tensor.shape().dim_size(1);
226 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size"));
227 | OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)"));
228 | OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)"));
229 | OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)"));
230 | OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)"));
231 | auto xyz1_flat=xyz1_tensor.flat<float>();
232 | const float * xyz1=&xyz1_flat(0);
233 | auto xyz2_flat=xyz2_tensor.flat<float>();
234 | const float * xyz2=&xyz2_flat(0);
235 | auto idx1_flat=idx1_tensor.flat<int>();
236 | const int * idx1=&idx1_flat(0);
237 | auto idx2_flat=idx2_tensor.flat<int>();
238 | const int * idx2=&idx2_flat(0);
239 | auto grad_dist1_flat=grad_dist1_tensor.flat<float>();
240 | const float * grad_dist1=&grad_dist1_flat(0);
241 | auto grad_dist2_flat=grad_dist2_tensor.flat<float>();
242 | const float * grad_dist2=&grad_dist2_flat(0);
243 | Tensor * grad_xyz1_tensor=NULL;
244 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor));
245 | Tensor * grad_xyz2_tensor=NULL;
246 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor));
247 | auto grad_xyz1_flat=grad_xyz1_tensor->flat<float>();
248 | float * grad_xyz1=&grad_xyz1_flat(0);
249 | auto grad_xyz2_flat=grad_xyz2_tensor->flat<float>();
250 | float * grad_xyz2=&grad_xyz2_flat(0);
251 | NmDistanceGradKernelLauncher(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_dist2,idx2,grad_xyz1,grad_xyz2);
252 | }
253 | };
254 | REGISTER_KERNEL_BUILDER(Name("NnDistanceGrad").Device(DEVICE_GPU), NnDistanceGradGpuOp);
255 |
--------------------------------------------------------------------------------
/tf_ops/nn_distance/tf_nndistance.py:
--------------------------------------------------------------------------------
1 | """ Compute Chamfer's Distance.
2 |
3 | Original author: Haoqiang Fan.
4 | Modified by Charles R. Qi
5 | """
6 |
7 | import tensorflow as tf
8 | from tensorflow.python.framework import ops
9 | import sys
10 | import os
11 |
12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
13 | nn_distance_module = tf.load_op_library(os.path.join(BASE_DIR, 'tf_nndistance_so.so'))
14 |
15 |
16 | def nn_distance(xyz1, xyz2):
17 | '''
18 | Computes the distance of nearest neighbors for a pair of point clouds
19 | input: xyz1: (batch_size,#points_1,3) the first point cloud
20 | input: xyz2: (batch_size,#points_2,3) the second point cloud
21 | output: dist1: (batch_size,#point_1) distance from first to second
22 | output: idx1: (batch_size,#point_1) nearest neighbor from first to second
23 | output: dist2: (batch_size,#point_2) distance from second to first
24 | output: idx2: (batch_size,#point_2) nearest neighbor from second to first
25 | '''
26 | return nn_distance_module.nn_distance(xyz1, xyz2)
27 |
28 |
29 | # @tf.RegisterShape('NnDistance')
30 | # def _nn_distance_shape(op):
31 | # shape1=op.inputs[0].get_shape().with_rank(3)
32 | # shape2=op.inputs[1].get_shape().with_rank(3)
33 | # return [tf.TensorShape([shape1.dims[0],shape1.dims[1]]),tf.TensorShape([shape1.dims[0],shape1.dims[1]]),
34 | # tf.TensorShape([shape2.dims[0],shape2.dims[1]]),tf.TensorShape([shape2.dims[0],shape2.dims[1]])]
35 | @ops.RegisterGradient('NnDistance')
36 | def _nn_distance_grad(op, grad_dist1, grad_idx1, grad_dist2, grad_idx2):
37 | xyz1 = op.inputs[0]
38 | xyz2 = op.inputs[1]
39 | idx1 = op.outputs[1]
40 | idx2 = op.outputs[3]
41 | return nn_distance_module.nn_distance_grad(xyz1, xyz2, grad_dist1, idx1, grad_dist2, idx2)
42 |
43 |
44 | if __name__ == '__main__':
45 | import numpy as np
46 | import random
47 | import time
48 | from tensorflow.python.ops.gradient_checker import compute_gradient
49 |
50 | random.seed(100)
51 | np.random.seed(100)
52 | with tf.Session('') as sess:
53 | xyz1 = np.random.randn(32, 16384, 3).astype('float32')
54 | xyz2 = np.random.randn(32, 1024, 3).astype('float32')
55 | # with tf.device('/gpu:0'):
56 | if True:
57 | inp1 = tf.Variable(xyz1)
58 | inp2 = tf.constant(xyz2)
59 | reta, retb, retc, retd = nn_distance(inp1, inp2)
60 | loss = tf.reduce_sum(reta) + tf.reduce_sum(retc)
61 | train = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss)
62 | sess.run(tf.initialize_all_variables())
63 | t0 = time.time()
64 | t1 = t0
65 | best = 1e100
66 | for i in xrange(100):
67 | trainloss, _ = sess.run([loss, train])
68 | newt = time.time()
69 | best = min(best, newt - t1)
70 | print i, trainloss, (newt - t0) / (i + 1), best
71 | t1 = newt
72 | # print sess.run([inp1,retb,inp2,retd])
73 | # grads=compute_gradient([inp1,inp2],[(16,32,3),(16,32,3)],loss,(1,),[xyz1,xyz2])
74 | # for i,j in grads:
75 | # print i.shape,j.shape,np.mean(np.abs(i-j)),np.mean(np.abs(i)),np.mean(np.abs(j))
76 | # for i in xrange(10):
77 | # t0=time.time()
78 | # a,b,c,d=sess.run([reta,retb,retc,retd],feed_dict={inp1:xyz1,inp2:xyz2})
79 | # print 'time',time.time()-t0
80 | # print a.shape,b.shape,c.shape,d.shape
81 | # print a.dtype,b.dtype,c.dtype,d.dtype
82 | # samples=np.array(random.sample(range(xyz2.shape[1]),100),dtype='int32')
83 | # dist1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).min(axis=-1)
84 | # idx1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1)
85 | # print np.abs(dist1-a[:,samples]).max()
86 | # print np.abs(idx1-b[:,samples]).max()
87 | # dist2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).min(axis=-1)
88 | # idx2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1)
89 | # print np.abs(dist2-c[:,samples]).max()
90 | # print np.abs(idx2-d[:,samples]).max()
91 |
--------------------------------------------------------------------------------
/tf_ops/nn_distance/tf_nndistance_compile.sh:
--------------------------------------------------------------------------------
1 | /usr/local/cuda-8.0/bin/nvcc tf_nndistance_g.cu -o tf_nndistance_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
2 | g++ -std=c++11 tf_nndistance.cpp tf_nndistance_g.cu.o -o tf_nndistance_so.so -shared -fPIC -I /home/nadav/anaconda2/envs/tensorflow/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -I /home/nadav/anaconda2/envs/tensorflow/lib/python2.7/dist-packages/tensorflow/include/external/nsync/public -lcudart -L /usr/local/cuda-8.0/lib64/ -L/home/nadav/anaconda2/envs/tensorflow/lib/python2.7/dist-packages/tensorflow -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0
3 |
4 |
--------------------------------------------------------------------------------
/tf_ops/nn_distance/tf_nndistance_cpu.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | def nn_distance_cpu(pc1, pc2):
5 | '''
6 | Input:
7 | pc1: float TF tensor in shape (B,N,C) the first point cloud
8 | pc2: float TF tensor in shape (B,M,C) the second point cloud
9 | Output:
10 | dist1: float TF tensor in shape (B,N) distance from first to second
11 | idx1: int32 TF tensor in shape (B,N) nearest neighbor from first to second
12 | dist2: float TF tensor in shape (B,M) distance from second to first
13 | idx2: int32 TF tensor in shape (B,M) nearest neighbor from second to first
14 | '''
15 | N = pc1.get_shape()[1].value
16 | M = pc2.get_shape()[1].value
17 | pc1_expand_tile = tf.tile(tf.expand_dims(pc1,2), [1,1,M,1])
18 | pc2_expand_tile = tf.tile(tf.expand_dims(pc2,1), [1,N,1,1])
19 | pc_diff = pc1_expand_tile - pc2_expand_tile # B,N,M,C
20 | pc_dist = tf.reduce_sum(pc_diff ** 2, axis=-1) # B,N,M
21 | dist1 = tf.reduce_min(pc_dist, axis=2) # B,N
22 | idx1 = tf.argmin(pc_dist, axis=2) # B,N
23 | dist2 = tf.reduce_min(pc_dist, axis=1) # B,M
24 | idx2 = tf.argmin(pc_dist, axis=1) # B,M
25 | return dist1, idx1, dist2, idx2
26 |
27 |
28 | def verify_nn_distance_cpu():
29 | np.random.seed(0)
30 | sess = tf.Session()
31 | pc1arr = np.random.random((1,5,3))
32 | pc2arr = np.random.random((1,6,3))
33 | pc1 = tf.constant(pc1arr)
34 | pc2 = tf.constant(pc2arr)
35 | dist1, idx1, dist2, idx2 = nn_distance_cpu(pc1, pc2)
36 | print(sess.run(dist1))
37 | print(sess.run(idx1))
38 | print(sess.run(dist2))
39 | print(sess.run(idx2))
40 |
41 | dist = np.zeros((5,6))
42 | for i in range(5):
43 | for j in range(6):
44 | dist[i,j] = np.sum((pc1arr[0,i,:] - pc2arr[0,j,:]) ** 2)
45 | print(dist)
46 |
47 | if __name__ == '__main__':
48 | verify_nn_distance_cpu()
49 |
--------------------------------------------------------------------------------
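
Note that nn_distance_cpu materializes the full B x N x M x C pairwise-difference tensor, so it is only practical for modest point counts. A sketch of using it as a drop-in Chamfer loss when the CUDA op is not compiled:

    import tensorflow as tf
    from tf_ops.nn_distance.tf_nndistance_cpu import nn_distance_cpu

    pc1 = tf.random_uniform((4, 400, 3))
    pc2 = tf.random_uniform((4, 400, 3))
    dist1, _, dist2, _ = nn_distance_cpu(pc1, pc2)
    chamfer = tf.reduce_mean(dist1) + tf.reduce_mean(dist2)
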
/tf_ops/nn_distance/tf_nndistance_g.cu:
--------------------------------------------------------------------------------
1 | #if GOOGLE_CUDA
2 | #define EIGEN_USE_GPU
3 | //#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
4 |
5 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
6 | 	const int batch=512;
7 | 	__shared__ float buf[batch*3];
8 | 	for (int i=blockIdx.x;i<b;i+=gridDim.x){
119 | 				if (k2==0||result[(i*n+j)]>best){
120 | 					result[(i*n+j)]=best;
121 | 					result_i[(i*n+j)]=best_i;
122 | 				}
123 | 			}
124 | 			__syncthreads();
125 | 		}
126 | 	}
127 | }
128 | void NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i){
129 | 	NmDistanceKernel<<<dim3(32,16,1),512>>>(b,n,xyz,m,xyz2,result,result_i);
130 | 	NmDistanceKernel<<<dim3(32,16,1),512>>>(b,m,xyz2,n,xyz,result2,result2_i);
131 | }
132 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
133 | 	for (int i=blockIdx.x;i<b;i+=gridDim.x){
134 | 		for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
135 | 			float x1=xyz1[(i*n+j)*3+0];
136 | 			float y1=xyz1[(i*n+j)*3+1];
137 | 			float z1=xyz1[(i*n+j)*3+2];
138 | 			int j2=idx1[i*n+j];
139 | 			float x2=xyz2[(i*m+j2)*3+0];
140 | 			float y2=xyz2[(i*m+j2)*3+1];
141 | 			float z2=xyz2[(i*m+j2)*3+2];
142 | 			float g=grad_dist1[i*n+j]*2;
143 | 			atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
144 | 			atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
145 | 			atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
146 | 			atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
147 | 			atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
148 | 			atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
149 | 		}
150 | 	}
151 | }
152 | void NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2){
153 | 	cudaMemset(grad_xyz1,0,b*n*3*4);
154 | 	cudaMemset(grad_xyz2,0,b*m*3*4);
155 | 	NmDistanceGradKernel<<<dim3(1,16,1),256>>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2);
156 | 	NmDistanceGradKernel<<<dim3(1,16,1),256>>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1);
157 | }
158 |
159 | #endif
160 |
161 |
--------------------------------------------------------------------------------
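
The gradient kernel above implements the analytic derivative of the squared nearest-neighbor distance: with dist1_j = |p_j - q_idx1_j|^2, the gradient with respect to p_j is 2(p_j - q_idx1_j), and the opposite sign is accumulated (via atomicAdd) into the matched point. A NumPy sketch of the same quantity for a single unbatched cloud pair:

    import numpy as np

    p = np.random.randn(5, 3)
    q = np.random.randn(7, 3)
    d = ((p[:, None, :] - q[None, :, :]) ** 2).sum(-1)  # 5 x 7 squared distances
    idx = d.argmin(1)                                   # nearest neighbor of each p in q
    grad_p = 2 * (p - q[idx])                           # d(dist1)/dp, per point
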
/tf_ops/nn_distance/tf_nndistance_g.cu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/tf_ops/nn_distance/tf_nndistance_g.cu.o
--------------------------------------------------------------------------------
/tf_ops/nn_distance/tf_nndistance_so.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/tf_ops/nn_distance/tf_nndistance_so.so
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | import numpy as np
4 | import tensorflow as tf
5 | import os
6 | import sys
7 | import models
8 | from data import data_utils
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--data_path', default='datasets/shapenetcore_partanno_segmentation_benchmark_v0',
12 | help='Path to the dataset [default: datasets/shapenetcore_partanno_segmentation_benchmark_v0]')
13 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
14 | parser.add_argument('--category', default='Chair', help='Which single class to train on [default: Chair]')
15 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
16 | parser.add_argument('--part_embedding_dim', type=int, default=64, help='Embedding dimension of each part [default: 64]')
17 | parser.add_argument('--noise_embedding_dim', type=int, default=16,
18 | help='Embedding dimension of the noise [default: 16]')
19 | parser.add_argument('--num_point', type=int, default=400, help='Number of points per part [default: 400]')
20 | parser.add_argument('--max_epoch_ae', type=int, default=401, help='Number of epochs for each AEs [default: 401]')
21 | parser.add_argument('--max_epoch_pcn', type=int, default=201,
22 | help='Number of epochs for the parts composition network [default: 201]')
23 | parser.add_argument('--batch_size', type=int, default=64, help='Batch Size during training [default: 64]')
24 | parser.add_argument('--learning_rate', type=float, default=0.001,
25 | help='Initial learning rate for the composition network [default: 0.001]')
26 | parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
27 | parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]')
28 | FLAGS = parser.parse_args()
29 | GPU_INDEX = FLAGS.gpu
30 | CATEGORY = FLAGS.category
31 | LOG_DIR = FLAGS.log_dir
32 | PART_EMBEDDING_DIM = FLAGS.part_embedding_dim
33 | NOISE_EMBEDDING_DIM = FLAGS.noise_embedding_dim
34 | NUM_POINTS = FLAGS.num_point
35 | MAX_EPOCH_AE = FLAGS.max_epoch_ae
36 | MAX_EPOCH_PCN = FLAGS.max_epoch_pcn
37 | BATCH_SIZE = FLAGS.batch_size
38 | BASE_LEARNING_RATE = FLAGS.learning_rate
39 | DECAY_STEP = FLAGS.decay_step
40 | DECAY_RATE = FLAGS.decay_rate
41 |
42 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
43 | if not os.path.exists(LOG_DIR):
44 | os.mkdir(LOG_DIR)
45 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
46 | LOG_FOUT.write(str(FLAGS) + '\n')
47 |
48 | print 'Loading data'
49 | # Shapenet official train/test split
50 | DATA_PATH = os.path.join(ROOT_DIR, FLAGS.data_path)
51 | # Using the smaller split, 'test', as the seen set and the larger, 'trainval', as the unseen set
52 | PCN_TRAIN_DATASET, PCN_TEST_DATASET, AE_TRAIN_DATASET, AE_TEST_DATASET, NUM_PARTS = data_utils.load_data(DATA_PATH,
53 | NUM_POINTS,
54 | CATEGORY,
55 | 'test',
56 | 'trainval')
57 |
58 | NOISE = np.random.normal(size=[len(PCN_TRAIN_DATASET), NOISE_EMBEDDING_DIM])
59 | EPOCH_CNT = 0
60 |
61 |
62 | def log_string(out_str):
63 | LOG_FOUT.write(out_str + '\n')
64 | LOG_FOUT.flush()
65 | print(out_str)
66 |
67 |
68 | def get_ae_batch(dataset, idxs, start_idx, end_idx):
69 | batch_size = end_idx - start_idx
70 | batch_data = np.zeros((batch_size, NUM_POINTS, 3))
71 | for i in range(batch_size):
72 | ps = dataset[idxs[i + start_idx]]
73 | batch_data[i, ...] = ps
74 | return batch_data
75 |
76 |
77 | def train_ae_one_epoch(sess, ops_ae, train_dataset, train_writer):
78 | is_training = True
79 | log_string(str(datetime.now()))
80 |
81 | # Shuffle train samples
82 | train_idxs = np.arange(0, len(train_dataset))
83 | np.random.shuffle(train_idxs)
84 | num_batches = len(train_dataset) / BATCH_SIZE
85 |
86 | loss_sum = 0
87 | for batch_idx in range(num_batches):
88 | start_idx = batch_idx * BATCH_SIZE
89 | end_idx = (batch_idx + 1) * BATCH_SIZE
90 | batch_data = get_ae_batch(train_dataset, train_idxs, start_idx, end_idx)
91 | feed_dict = {ops_ae['point_clouds_ph']: batch_data,
92 | ops_ae['gt_ph']: batch_data,
93 | ops_ae['is_training_ph']: is_training, }
94 | summary, step, _, loss_val, pred_val = sess.run([ops_ae['merged_summary'], ops_ae['step'],
95 | ops_ae['train_op'], ops_ae['loss'],
96 | ops_ae['pred']], feed_dict=feed_dict)
97 | train_writer.add_summary(summary, step)
98 | loss_sum += loss_val
99 |
100 | if (batch_idx + 1) % 10 == 0:
101 | log_string(' -- %03d / %03d --' % (batch_idx + 1, num_batches))
102 | log_string('mean loss: %f' % (loss_sum / 10))
103 | loss_sum = 0
104 |
105 |
106 | def get_pcn_batch(dataset, idxs, start_idx, end_idx):
107 | batch_size = end_idx - start_idx
108 | x = np.zeros((NUM_PARTS, batch_size, NUM_POINTS, 3))
109 | y = np.zeros((batch_size, NUM_PARTS, NUM_POINTS, 3))
110 | y_mask = np.zeros((batch_size, NUM_PARTS, NUM_POINTS))
111 | noise = NOISE[idxs[start_idx:end_idx]]
112 |
113 | for i in range(batch_size):
114 | point_sets = dataset[idxs[start_idx + i]]
115 | for p in xrange(NUM_PARTS):
116 | ps, sn, is_full = point_sets[p]
117 | x[p, i, ...] = sn
118 | y[i, p, ...] = ps
119 | y_mask[i, p] = is_full
120 | return x, y, y_mask, noise
121 |
122 |
123 | def train_pcn_one_epoch(sess, pcn_ops, ae_ops, train_dataset, train_writer):
124 | is_training = True
125 | log_string(str(datetime.now()))
126 |
127 | # Shuffle train samples
128 | train_idxs = np.arange(0, len(train_dataset))
129 | np.random.shuffle(train_idxs)
130 | num_batches = len(train_dataset) / BATCH_SIZE
131 |
132 | loss_sum = 0
133 | total_loss = 0
134 | for batch_idx in range(num_batches):
135 | start_idx = batch_idx * BATCH_SIZE
136 | end_idx = (batch_idx + 1) * BATCH_SIZE
137 | x, y, y_mask, noise = get_pcn_batch(train_dataset, train_idxs, start_idx, end_idx)
138 | feed_dict = {pcn_ops['y']: y,
139 | pcn_ops['y_mask']: y_mask,
140 | pcn_ops['noise']: noise,
141 | pcn_ops['is_training_ph']: is_training, }
142 |
143 | for part in xrange(len(x)):
144 | feed_dict[ae_ops[part]['point_clouds_ph']] = x[part]
145 | feed_dict[ae_ops[part]['gt_ph']] = x[part]
146 | feed_dict[ae_ops[part]['is_training_ph']] = False  # the part encoders are kept fixed during the PCN training phase
147 |
148 | summary, step, _, loss_val, pred_val = sess.run(
149 | [pcn_ops['merged_summary'], pcn_ops['step'], pcn_ops['train_op'], pcn_ops['loss'], pcn_ops['pred']],
150 | feed_dict=feed_dict)
151 |
152 | train_writer.add_summary(summary, step)
153 | loss_sum += loss_val
154 | total_loss += loss_val
155 |
156 | if (batch_idx + 1) % 10 == 0:
157 | log_string(' -- %03d / %03d --' % (batch_idx + 1, num_batches))
158 | log_string('mean loss: %f' % (loss_sum / 10))
159 | loss_sum = 0
160 |
161 | return total_loss / float(num_batches)
162 |
163 |
164 | def train():
165 | with tf.Graph().as_default():
166 | with tf.device('/gpu:' + str(GPU_INDEX)):
167 | ''' Parts' AE '''
168 | ae_ops, point_clouds_ph = models.build_parts_aes_graphs(NUM_PARTS, NUM_POINTS, PART_EMBEDDING_DIM,
169 | BASE_LEARNING_RATE, BATCH_SIZE, DECAY_STEP,
170 | DECAY_RATE, float(DECAY_STEP))
171 |
172 | ''' PCN '''
173 | pcn_ops = models.build_parts_pcn_graph(ae_ops, point_clouds_ph, NUM_PARTS, NUM_POINTS, NOISE_EMBEDDING_DIM,
174 | BASE_LEARNING_RATE, BATCH_SIZE, DECAY_STEP, DECAY_RATE,
175 | float(DECAY_STEP))
176 |
177 | saver = tf.train.Saver()
178 |
179 | config = tf.ConfigProto()
180 | config.gpu_options.allow_growth = True
181 | config.allow_soft_placement = True
182 | config.log_device_placement = False
183 | sess = tf.Session(config=config)
184 |
185 | train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), sess.graph)
186 |
187 | init = tf.global_variables_initializer()
188 | sess.run(init)
189 |
190 | ''' Train Parts AE '''
191 | for i in xrange(NUM_PARTS):
192 | print 'Training part ' + str(i)
193 | for epoch in range(MAX_EPOCH_AE):
194 | log_string('**** EPOCH %03d ****' % epoch)
195 | sys.stdout.flush()
196 | train_ae_one_epoch(sess, ae_ops[i], AE_TRAIN_DATASET[i], train_writer)
197 |
198 | ''' Train PCN '''
199 | best_loss = 1e20
200 | for epoch in range(MAX_EPOCH_PCN):
201 | log_string('**** EPOCH %03d ****' % epoch)
202 | sys.stdout.flush()
203 | train_loss = train_pcn_one_epoch(sess, pcn_ops, ae_ops, PCN_TRAIN_DATASET, train_writer)
204 | if train_loss < best_loss:
205 | best_loss = train_loss
206 | save_path = saver.save(sess, os.path.join(LOG_DIR, "best_model_epoch_%03d.ckpt" % epoch))
207 | log_string("Model saved in file: %s" % save_path)
208 | if epoch % 10 == 0:
209 | save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
210 | log_string("Model saved in file: %s" % save_path)
211 |
212 |
213 | if __name__ == "__main__":
214 | log_string('pid: %s' % (str(os.getpid())))
215 | train()
216 | LOG_FOUT.close()
217 |
--------------------------------------------------------------------------------
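
For reference, a typical invocation of the training script with the defaults spelled out (all flags are defined by the argparse block at the top of train.py):

    python train.py --data_path datasets/shapenetcore_partanno_segmentation_benchmark_v0 \
        --category Chair --gpu 0 --num_point 400 --batch_size 64 --log_dir log
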
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/utils/__init__.py
--------------------------------------------------------------------------------
/utils/compile_render_balls_so.sh:
--------------------------------------------------------------------------------
1 | g++ -std=c++11 render_balls_so.cpp -o render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0
2 |
3 |
--------------------------------------------------------------------------------
/utils/render_balls_so.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdio>
2 | #include <vector>
3 | #include <algorithm>
4 | #include <math.h>
5 | using namespace std;
6 |
7 | struct PointInfo{
8 | int x,y,z;
9 | float r,g,b;
10 | };
11 |
12 | extern "C"{
13 |
14 | void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){
15 | 	r=max(r,1);
16 | 	vector<int> depth(h*w,-2100000000);
17 | 	vector<PointInfo> pattern;
18 | 	for (int dx=-r;dx<=r;dx++)
19 | 		for (int dy=-r;dy<=r;dy++)
20 | 			if (dx*dx+dy*dy<r*r){
--------------------------------------------------------------------------------
/utils/show3d_balls.py:
--------------------------------------------------------------------------------
94 | if magnifyBlue > 0:
95 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=0))
96 | if magnifyBlue >= 2:
97 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=0))
98 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=1))
99 | if magnifyBlue >= 2:
100 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=1))
101 | if showrot:
102 | cv2.putText(show, 'xangle %d' % (int(xangle / np.pi * 180)), (30, showsz - 30), 0, 0.5,
103 | cv2.cv.CV_RGB(255, 0, 0))
104 | cv2.putText(show, 'yangle %d' % (int(yangle / np.pi * 180)), (30, showsz - 50), 0, 0.5,
105 | cv2.cv.CV_RGB(255, 0, 0))
106 | cv2.putText(show, 'zoom %d%%' % (int(zoom * 100)), (30, showsz - 70), 0, 0.5, cv2.cv.CV_RGB(255, 0, 0))
107 |
108 | changed = True
109 | stop = False
110 | while not stop:
111 | if changed:
112 | render()
113 | changed = False
114 | cv2.imshow('show3d', show)
115 | if waittime == 0:
116 | cmd = cv2.waitKey(10) % 256
117 | else:
118 | cmd = cv2.waitKey(waittime) % 256
119 | if cmd == ord('q'):
120 | stop = True
121 | elif cmd == ord('Q'):
122 | sys.exit(0)
123 |
124 | if cmd == ord('t') or cmd == ord('p'):
125 | if cmd == ord('t'):
126 | if c_gt is None:
127 | c0 = np.zeros((len(xyz),), dtype='float32') + 255
128 | c1 = np.zeros((len(xyz),), dtype='float32') + 255
129 | c2 = np.zeros((len(xyz),), dtype='float32') + 255
130 | else:
131 | c0 = c_gt[:, 0]
132 | c1 = c_gt[:, 1]
133 | c2 = c_gt[:, 2]
134 | else:
135 | if c_pred is None:
136 | c0 = np.zeros((len(xyz),), dtype='float32') + 255
137 | c1 = np.zeros((len(xyz),), dtype='float32') + 255
138 | c2 = np.zeros((len(xyz),), dtype='float32') + 255
139 | else:
140 | c0 = c_pred[:, 0]
141 | c1 = c_pred[:, 1]
142 | c2 = c_pred[:, 2]
143 | if normalizecolor:
144 | c0 /= (c0.max() + 1e-14) / 255.0
145 | c1 /= (c1.max() + 1e-14) / 255.0
146 | c2 /= (c2.max() + 1e-14) / 255.0
147 | c0 = np.require(c0, 'float32', 'C')
148 | c1 = np.require(c1, 'float32', 'C')
149 | c2 = np.require(c2, 'float32', 'C')
150 | changed = True
151 |
152 | if cmd == ord('n'):
153 | zoom *= 1.1
154 | changed = True
155 | elif cmd == ord('m'):
156 | zoom /= 1.1
157 | changed = True
158 | elif cmd == ord('r'):
159 | zoom = 1.0
160 | changed = True
161 | elif cmd == ord('s'):
162 | cv2.imwrite('show3d.png', show)
163 | elif cmd == ord('f'):
164 | freezerot = ~freezerot
165 | if waittime != 0:
166 | break
167 | return cmd
168 |
169 |
170 | if __name__ == '__main__':
171 | np.random.seed(100)
172 | showpoints(np.random.randn(2500, 3))
173 |
--------------------------------------------------------------------------------
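
The surviving body of showpoints implies optional per-point color arrays c_gt and c_pred (the function header itself was lost in this dump, so the keyword names below are inferred from the body). A small sketch that colors all points red for the 't' view:

    import numpy as np
    from utils.show3d_balls import showpoints

    pts = np.random.randn(2500, 3)
    red = np.zeros((2500, 3), dtype='float32')
    red[:, 0] = 255  # pressing 't' switches to this c_gt coloring
    showpoints(pts, c_gt=red)
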
/utils/tf_util.py:
--------------------------------------------------------------------------------
1 | """ Wrapper functions for TensorFlow layers.
2 |
3 | Author: Charles R. Qi
4 | Date: November 2017
5 | """
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 |
10 |
11 | def _variable_on_cpu(name, shape, initializer, use_fp16=False):
12 | """Helper to create a Variable stored on CPU memory.
13 | Args:
14 | name: name of the variable
15 | shape: list of ints
16 | initializer: initializer for Variable
17 | Returns:
18 | Variable Tensor
19 | """
20 | with tf.device("/cpu:0"):
21 | dtype = tf.float16 if use_fp16 else tf.float32
22 | var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
23 | return var
24 |
25 |
26 | def _variable_with_weight_decay(name, shape, stddev, wd, use_xavier=True, initializer=None):
27 | """Helper to create an initialized Variable with weight decay.
28 |
29 | Note that the Variable is initialized with a truncated normal distribution.
30 | A weight decay is added only if one is specified.
31 |
32 | Args:
33 | name: name of the variable
34 | shape: list of ints
35 | stddev: standard deviation of a truncated Gaussian
36 | wd: add L2Loss weight decay multiplied by this float. If None, weight
37 | decay is not added for this Variable.
38 | use_xavier: bool, whether to use xavier initializer
39 |
40 | Returns:
41 | Variable Tensor
42 | """
43 | if initializer is None:
44 | if use_xavier:
45 | initializer = tf.contrib.layers.xavier_initializer()
46 | else:
47 | initializer = tf.truncated_normal_initializer(stddev=stddev)
48 | var = _variable_on_cpu(name, shape, initializer)
49 | if wd is not None:
50 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
51 | tf.add_to_collection('losses', weight_decay)
52 | return var
53 |
54 |
55 | def conv1d(inputs,
56 | num_output_channels,
57 | kernel_size,
58 | scope,
59 | stride=1,
60 | padding='SAME',
61 | data_format='NHWC',
62 | use_xavier=True,
63 | stddev=1e-3,
64 | weight_decay=None,
65 | activation_fn=tf.nn.relu,
66 | bn=False,
67 | bn_decay=None,
68 | is_training=None):
69 | """ 1D convolution with non-linear operation.
70 |
71 | Args:
72 | inputs: 3-D tensor variable BxLxC
73 | num_output_channels: int
74 | kernel_size: int
75 | scope: string
76 | stride: int
77 | padding: 'SAME' or 'VALID'
78 | data_format: 'NHWC' or 'NCHW'
79 | use_xavier: bool, use xavier_initializer if true
80 | stddev: float, stddev for truncated_normal init
81 | weight_decay: float
82 | activation_fn: function
83 | bn: bool, whether to use batch norm
84 | bn_decay: float or float tensor variable in [0,1]
85 | is_training: bool Tensor variable
86 |
87 | Returns:
88 | Variable tensor
89 | """
90 | with tf.variable_scope(scope) as sc:
91 | assert (data_format == 'NHWC' or data_format == 'NCHW')
92 | if data_format == 'NHWC':
93 | num_in_channels = inputs.get_shape()[-1].value
94 | elif data_format == 'NCHW':
95 | num_in_channels = inputs.get_shape()[1].value
96 | kernel_shape = [kernel_size,
97 | num_in_channels, num_output_channels]
98 | kernel = _variable_with_weight_decay('weights',
99 | shape=kernel_shape,
100 | use_xavier=use_xavier,
101 | stddev=stddev,
102 | wd=weight_decay)
103 | outputs = tf.nn.conv1d(inputs, kernel,
104 | stride=stride,
105 | padding=padding,
106 | data_format=data_format)
107 | biases = _variable_on_cpu('biases', [num_output_channels],
108 | tf.constant_initializer(0.0))
109 | outputs = tf.nn.bias_add(outputs, biases, data_format=data_format)
110 |
111 | if bn:
112 | outputs = batch_norm_for_conv1d(outputs, is_training,
113 | bn_decay=bn_decay, scope='bn',
114 | data_format=data_format)
115 |
116 | if activation_fn is not None:
117 | outputs = activation_fn(outputs)
118 | return outputs
119 |
120 |
121 | def conv2d(inputs,
122 | num_output_channels,
123 | kernel_size,
124 | scope,
125 | stride=[1, 1],
126 | padding='SAME',
127 | data_format='NHWC',
128 | use_xavier=True,
129 | stddev=1e-3,
130 | weight_decay=None,
131 | activation_fn=tf.nn.relu,
132 | bn=False,
133 | bn_decay=None,
134 | is_training=None):
135 | """ 2D convolution with non-linear operation.
136 |
137 | Args:
138 | inputs: 4-D tensor variable BxHxWxC
139 | num_output_channels: int
140 | kernel_size: a list of 2 ints
141 | scope: string
142 | stride: a list of 2 ints
143 | padding: 'SAME' or 'VALID'
144 | data_format: 'NHWC' or 'NCHW'
145 | use_xavier: bool, use xavier_initializer if true
146 | stddev: float, stddev for truncated_normal init
147 | weight_decay: float
148 | activation_fn: function
149 | bn: bool, whether to use batch norm
150 | bn_decay: float or float tensor variable in [0,1]
151 | is_training: bool Tensor variable
152 |
153 | Returns:
154 | Variable tensor
155 | """
156 | with tf.variable_scope(scope) as sc:
157 | kernel_h, kernel_w = kernel_size
158 | assert (data_format == 'NHWC' or data_format == 'NCHW')
159 | if data_format == 'NHWC':
160 | num_in_channels = inputs.get_shape()[-1].value
161 | elif data_format == 'NCHW':
162 | num_in_channels = inputs.get_shape()[1].value
163 | kernel_shape = [kernel_h, kernel_w,
164 | num_in_channels, num_output_channels]
165 | kernel = _variable_with_weight_decay('weights',
166 | shape=kernel_shape,
167 | use_xavier=use_xavier,
168 | stddev=stddev,
169 | wd=weight_decay)
170 | stride_h, stride_w = stride
171 | outputs = tf.nn.conv2d(inputs, kernel,
172 | [1, stride_h, stride_w, 1],
173 | padding=padding,
174 | data_format=data_format)
175 | biases = _variable_on_cpu('biases', [num_output_channels],
176 | tf.constant_initializer(0.0))
177 | outputs = tf.nn.bias_add(outputs, biases, data_format=data_format)
178 |
179 | if bn:
180 | outputs = batch_norm_for_conv2d(outputs, is_training,
181 | bn_decay=bn_decay, scope='bn',
182 | data_format=data_format)
183 |
184 | if activation_fn is not None:
185 | outputs = activation_fn(outputs)
186 | return outputs
187 |
188 |
189 | def conv2d_transpose(inputs,
190 | num_output_channels,
191 | kernel_size,
192 | scope,
193 | stride=[1, 1],
194 | padding='SAME',
195 | data_format='NHWC',
196 | use_xavier=True,
197 | stddev=1e-3,
198 | weight_decay=None,
199 | activation_fn=tf.nn.relu,
200 | bn=False,
201 | bn_decay=None,
202 | is_training=None):
203 | """ 2D convolution transpose with non-linear operation.
204 |
205 | Args:
206 | inputs: 4-D tensor variable BxHxWxC
207 | num_output_channels: int
208 | kernel_size: a list of 2 ints
209 | scope: string
210 | stride: a list of 2 ints
211 | padding: 'SAME' or 'VALID'
212 | use_xavier: bool, use xavier_initializer if true
213 | stddev: float, stddev for truncated_normal init
214 | weight_decay: float
215 | activation_fn: function
216 | bn: bool, whether to use batch norm
217 | bn_decay: float or float tensor variable in [0,1]
218 | is_training: bool Tensor variable
219 |
220 | Returns:
221 | Variable tensor
222 |
223 | Note: conv2d(conv2d_transpose(a, num_out, ksize, stride), a.shape[-1], ksize, stride) == a
224 | """
225 | with tf.variable_scope(scope) as sc:
226 | kernel_h, kernel_w = kernel_size
227 | num_in_channels = inputs.get_shape()[-1].value
228 | kernel_shape = [kernel_h, kernel_w,
229 | num_output_channels, num_in_channels] # reversed compared to conv2d
230 | kernel = _variable_with_weight_decay('weights',
231 | shape=kernel_shape,
232 | use_xavier=use_xavier,
233 | stddev=stddev,
234 | wd=weight_decay)
235 | stride_h, stride_w = stride
236 |
237 | # from slim.convolution2d_transpose
238 | def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
239 | dim_size *= stride_size
240 |
241 | if padding == 'VALID' and dim_size is not None:
242 | dim_size += max(kernel_size - stride_size, 0)
243 | return dim_size
244 |
245 | # calculate output shape
246 | batch_size = inputs.get_shape()[0].value
247 | height = inputs.get_shape()[1].value
248 | width = inputs.get_shape()[2].value
249 | out_height = get_deconv_dim(height, stride_h, kernel_h, padding)
250 | out_width = get_deconv_dim(width, stride_w, kernel_w, padding)
251 | output_shape = [batch_size, out_height, out_width, num_output_channels]
252 |
253 | outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape,
254 | [1, stride_h, stride_w, 1],
255 | padding=padding)
256 | biases = _variable_on_cpu('biases', [num_output_channels],
257 | tf.constant_initializer(0.0))
258 | outputs = tf.nn.bias_add(outputs, biases)
259 |
260 | if bn:
261 | outputs = batch_norm_for_conv2d(outputs, is_training,
262 | bn_decay=bn_decay, scope='bn',
263 | data_format=data_format)
264 |
265 | if activation_fn is not None:
266 | outputs = activation_fn(outputs)
267 | return outputs
268 |
269 |
270 | def conv3d(inputs,
271 | num_output_channels,
272 | kernel_size,
273 | scope,
274 | stride=[1, 1, 1],
275 | padding='SAME',
276 | use_xavier=True,
277 | stddev=1e-3,
278 | weight_decay=None,
279 | activation_fn=tf.nn.relu,
280 | bn=False,
281 | bn_decay=None,
282 | is_training=None):
283 | """ 3D convolution with non-linear operation.
284 |
285 | Args:
286 | inputs: 5-D tensor variable BxDxHxWxC
287 | num_output_channels: int
288 | kernel_size: a list of 3 ints
289 | scope: string
290 | stride: a list of 3 ints
291 | padding: 'SAME' or 'VALID'
292 | use_xavier: bool, use xavier_initializer if true
293 | stddev: float, stddev for truncated_normal init
294 | weight_decay: float
295 | activation_fn: function
296 | bn: bool, whether to use batch norm
297 | bn_decay: float or float tensor variable in [0,1]
298 | is_training: bool Tensor variable
299 |
300 | Returns:
301 | Variable tensor
302 | """
303 | with tf.variable_scope(scope) as sc:
304 | kernel_d, kernel_h, kernel_w = kernel_size
305 | num_in_channels = inputs.get_shape()[-1].value
306 | kernel_shape = [kernel_d, kernel_h, kernel_w,
307 | num_in_channels, num_output_channels]
308 | kernel = _variable_with_weight_decay('weights',
309 | shape=kernel_shape,
310 | use_xavier=use_xavier,
311 | stddev=stddev,
312 | wd=weight_decay)
313 | stride_d, stride_h, stride_w = stride
314 | outputs = tf.nn.conv3d(inputs, kernel,
315 | [1, stride_d, stride_h, stride_w, 1],
316 | padding=padding)
317 | biases = _variable_on_cpu('biases', [num_output_channels],
318 | tf.constant_initializer(0.0))
319 | outputs = tf.nn.bias_add(outputs, biases)
320 |
321 | if bn:
322 | outputs = batch_norm_for_conv3d(outputs, is_training,
323 | bn_decay=bn_decay, scope='bn')
324 |
325 | if activation_fn is not None:
326 | outputs = activation_fn(outputs)
327 | return outputs
328 |
329 |
330 | def fully_connected(inputs,
331 | num_outputs,
332 | scope,
333 | use_xavier=True,
334 | stddev=1e-3,
335 | weight_decay=None,
336 | activation_fn=tf.nn.relu,
337 | bn=False,
338 | bn_decay=None,
339 | is_training=None,
340 | reuse=False,
341 | weights_initializer=None,
342 | biases_initializer=tf.constant_initializer(0.0)):
343 | """ Fully connected layer with non-linear operation.
344 |
345 | Args:
346 | inputs: 2-D tensor BxN
347 | num_outputs: int
348 |
349 | Returns:
350 | Variable tensor of size B x num_outputs.
351 | """
352 | with tf.variable_scope(scope, reuse=reuse) as sc:
353 | num_input_units = inputs.get_shape()[-1].value
354 | weights = _variable_with_weight_decay('weights',
355 | shape=[num_input_units, num_outputs],
356 | use_xavier=use_xavier,
357 | stddev=stddev,
358 | wd=weight_decay,
359 | initializer=weights_initializer)
360 | outputs = tf.matmul(inputs, weights)
361 | biases = _variable_on_cpu('biases', [num_outputs], biases_initializer)
362 | outputs = tf.nn.bias_add(outputs, biases)
363 |
364 | if bn:
365 | outputs = batch_norm_for_fc(outputs, is_training, bn_decay, 'bn')
366 |
367 | if activation_fn is not None:
368 | outputs = activation_fn(outputs)
369 | return outputs
370 |
371 |
372 | def max_pool2d(inputs,
373 | kernel_size,
374 | scope,
375 | stride=[2, 2],
376 | padding='VALID'):
377 | """ 2D max pooling.
378 |
379 | Args:
380 | inputs: 4-D tensor BxHxWxC
381 | kernel_size: a list of 2 ints
382 | stride: a list of 2 ints
383 |
384 | Returns:
385 | Variable tensor
386 | """
387 | with tf.variable_scope(scope) as sc:
388 | kernel_h, kernel_w = kernel_size
389 | stride_h, stride_w = stride
390 | outputs = tf.nn.max_pool(inputs,
391 | ksize=[1, kernel_h, kernel_w, 1],
392 | strides=[1, stride_h, stride_w, 1],
393 | padding=padding,
394 | name=sc.name)
395 | return outputs
396 |
397 |
398 | def avg_pool2d(inputs,
399 | kernel_size,
400 | scope,
401 | stride=[2, 2],
402 | padding='VALID'):
403 | """ 2D avg pooling.
404 |
405 | Args:
406 | inputs: 4-D tensor BxHxWxC
407 | kernel_size: a list of 2 ints
408 | stride: a list of 2 ints
409 |
410 | Returns:
411 | Variable tensor
412 | """
413 | with tf.variable_scope(scope) as sc:
414 | kernel_h, kernel_w = kernel_size
415 | stride_h, stride_w = stride
416 | outputs = tf.nn.avg_pool(inputs,
417 | ksize=[1, kernel_h, kernel_w, 1],
418 | strides=[1, stride_h, stride_w, 1],
419 | padding=padding,
420 | name=sc.name)
421 | return outputs
422 |
423 |
424 | def max_pool3d(inputs,
425 | kernel_size,
426 | scope,
427 | stride=[2, 2, 2],
428 | padding='VALID'):
429 | """ 3D max pooling.
430 |
431 | Args:
432 | inputs: 5-D tensor BxDxHxWxC
433 | kernel_size: a list of 3 ints
434 | stride: a list of 3 ints
435 |
436 | Returns:
437 | Variable tensor
438 | """
439 | with tf.variable_scope(scope) as sc:
440 | kernel_d, kernel_h, kernel_w = kernel_size
441 | stride_d, stride_h, stride_w = stride
442 | outputs = tf.nn.max_pool3d(inputs,
443 | ksize=[1, kernel_d, kernel_h, kernel_w, 1],
444 | strides=[1, stride_d, stride_h, stride_w, 1],
445 | padding=padding,
446 | name=sc.name)
447 | return outputs
448 |
449 |
450 | def avg_pool3d(inputs,
451 | kernel_size,
452 | scope,
453 | stride=[2, 2, 2],
454 | padding='VALID'):
455 | """ 3D avg pooling.
456 |
457 | Args:
458 | inputs: 5-D tensor BxDxHxWxC
459 | kernel_size: a list of 3 ints
460 | stride: a list of 3 ints
461 |
462 | Returns:
463 | Variable tensor
464 | """
465 | with tf.variable_scope(scope) as sc:
466 | kernel_d, kernel_h, kernel_w = kernel_size
467 | stride_d, stride_h, stride_w = stride
468 | outputs = tf.nn.avg_pool3d(inputs,
469 | ksize=[1, kernel_d, kernel_h, kernel_w, 1],
470 | strides=[1, stride_d, stride_h, stride_w, 1],
471 | padding=padding,
472 | name=sc.name)
473 | return outputs
474 |
475 |
476 | def batch_norm_template_unused(inputs, is_training, scope, moments_dims, bn_decay):
477 | """ NOTE: this is older version of the util func. it is deprecated.
478 | Batch normalization on convolutional maps and beyond...
479 | Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
480 |
481 | Args:
482 | inputs: Tensor, k-D input ... x C could be BC or BHWC or BDHWC
483 | is_training: boolean tf.Variable, true indicates training phase
484 | scope: string, variable scope
485 | moments_dims: a list of ints, indicating dimensions for moments calculation
486 | bn_decay: float or float tensor variable, controlling moving average weight
487 | Return:
488 | normed: batch-normalized maps
489 | """
490 | with tf.variable_scope(scope) as sc:
491 | num_channels = inputs.get_shape()[-1].value
492 | beta = _variable_on_cpu(name='beta', shape=[num_channels],
493 | initializer=tf.constant_initializer(0))
494 | gamma = _variable_on_cpu(name='gamma', shape=[num_channels],
495 | initializer=tf.constant_initializer(1.0))
496 | batch_mean, batch_var = tf.nn.moments(inputs, moments_dims, name='moments')
497 | decay = bn_decay if bn_decay is not None else 0.9
498 | ema = tf.train.ExponentialMovingAverage(decay=decay)
499 | # Operator that maintains moving averages of variables.
500 | # Need to set reuse=False, otherwise if reuse, will see moments_1/mean/ExponentialMovingAverage/ does not exist
501 | # https://github.com/shekkizh/WassersteinGAN.tensorflow/issues/3
502 | with tf.variable_scope(tf.get_variable_scope(), reuse=False):
503 | ema_apply_op = tf.cond(is_training,
504 | lambda: ema.apply([batch_mean, batch_var]),
505 | lambda: tf.no_op())
506 |
507 | # Update moving average and return current batch's avg and var.
508 | def mean_var_with_update():
509 | with tf.control_dependencies([ema_apply_op]):
510 | return tf.identity(batch_mean), tf.identity(batch_var)
511 |
512 | # ema.average returns the Variable holding the average of var.
513 | mean, var = tf.cond(is_training,
514 | mean_var_with_update,
515 | lambda: (ema.average(batch_mean), ema.average(batch_var)))
516 | normed = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, 1e-3)
517 | return normed
518 |
519 |
520 | def batch_norm_template(inputs, is_training, scope, moments_dims_unused, bn_decay, data_format='NHWC'):
521 | """ Batch normalization on convolutional maps and beyond...
522 | Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
523 |
524 | Args:
525 | inputs: Tensor, k-D input ... x C could be BC or BHWC or BDHWC
526 | is_training: boolean tf.Variable, true indicates training phase
527 | scope: string, variable scope
528 | moments_dims: a list of ints, indicating dimensions for moments calculation
529 | bn_decay: float or float tensor variable, controlling moving average weight
530 | data_format: 'NHWC' or 'NCHW'
531 | Return:
532 | normed: batch-normalized maps
533 | """
534 | bn_decay = bn_decay if bn_decay is not None else 0.9
535 | return tf.contrib.layers.batch_norm(inputs,
536 | center=True, scale=True,
537 | is_training=is_training, decay=bn_decay, updates_collections=None,
538 | scope=scope,
539 | data_format=data_format)
540 |
541 |
542 | def batch_norm_for_fc(inputs, is_training, bn_decay, scope):
543 | """ Batch normalization on FC data.
544 |
545 | Args:
546 | inputs: Tensor, 2D BxC input
547 | is_training: boolean tf.Variable, true indicates training phase
548 | bn_decay: float or float tensor variable, controlling moving average weight
549 | scope: string, variable scope
550 | Return:
551 | normed: batch-normalized maps
552 | """
553 | return batch_norm_template(inputs, is_training, scope, [0, ], bn_decay)
554 |
555 |
556 | def batch_norm_for_conv1d(inputs, is_training, bn_decay, scope, data_format):
557 | """ Batch normalization on 1D convolutional maps.
558 |
559 | Args:
560 | inputs: Tensor, 3D BLC input maps
561 | is_training: boolean tf.Variable, true indicates training phase
562 | bn_decay: float or float tensor variable, controlling moving average weight
563 | scope: string, variable scope
564 | data_format: 'NHWC' or 'NCHW'
565 | Return:
566 | normed: batch-normalized maps
567 | """
568 | return batch_norm_template(inputs, is_training, scope, [0, 1], bn_decay, data_format)
569 |
570 |
571 | def batch_norm_for_conv2d(inputs, is_training, bn_decay, scope, data_format):
572 | """ Batch normalization on 2D convolutional maps.
573 |
574 | Args:
575 | inputs: Tensor, 4D BHWC input maps
576 | is_training: boolean tf.Variable, true indicates training phase
577 | bn_decay: float or float tensor variable, controlling moving average weight
578 | scope: string, variable scope
579 | data_format: 'NHWC' or 'NCHW'
580 | Return:
581 | normed: batch-normalized maps
582 | """
583 | return batch_norm_template(inputs, is_training, scope, [0, 1, 2], bn_decay, data_format)
584 |
585 |
586 | def batch_norm_for_conv3d(inputs, is_training, bn_decay, scope):
587 | """ Batch normalization on 3D convolutional maps.
588 |
589 | Args:
590 | inputs: Tensor, 5D BDHWC input maps
591 | is_training: boolean tf.Variable, true indicates training phase
592 | bn_decay: float or float tensor variable, controlling moving average weight
593 | scope: string, variable scope
594 | Return:
595 | normed: batch-normalized maps
596 | """
597 | return batch_norm_template(inputs, is_training, scope, [0, 1, 2, 3], bn_decay)
598 |
599 |
600 | def dropout(inputs,
601 | is_training,
602 | scope,
603 | keep_prob=0.5,
604 | noise_shape=None):
605 | """ Dropout layer.
606 |
607 | Args:
608 | inputs: tensor
609 | is_training: boolean tf.Variable
610 | scope: string
611 | keep_prob: float in [0,1]
612 | noise_shape: list of ints
613 |
614 | Returns:
615 | tensor variable
616 | """
617 | with tf.variable_scope(scope) as sc:
618 | outputs = tf.cond(is_training,
619 | lambda: tf.nn.dropout(inputs, keep_prob, noise_shape),
620 | lambda: inputs)
621 | return outputs
622 |
--------------------------------------------------------------------------------
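
As a closing illustration, a sketch of composing these wrappers into a small PointNet-style part encoder of the kind CompoNet trains per part (sizes and scope names are illustrative, not taken from models.py):

    import tensorflow as tf
    import tf_util

    points = tf.placeholder(tf.float32, (64, 400, 3))         # B x N x 3
    is_training = tf.placeholder(tf.bool)

    net = tf.expand_dims(points, -1)                          # B x N x 3 x 1
    net = tf_util.conv2d(net, 64, [1, 3], scope='conv1', padding='VALID',
                         bn=True, is_training=is_training)    # B x N x 1 x 64
    net = tf_util.conv2d(net, 128, [1, 1], scope='conv2', padding='VALID',
                         bn=True, is_training=is_training)    # B x N x 1 x 128
    net = tf_util.max_pool2d(net, [400, 1], scope='pool')     # B x 1 x 1 x 128
    net = tf.reshape(net, (64, 128))
    embedding = tf_util.fully_connected(net, 64, scope='fc1') # B x 64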