├── .gitignore ├── LICENSE ├── README.md ├── anaconda_env └── CompoNet.yml ├── data ├── __init__.py ├── data_utils.py ├── part_dataset_ae.py └── part_dataset_pcn.py ├── datasets └── download_data.sh ├── images └── network_architecture.png ├── models.py ├── test.py ├── tf_ops ├── __init__.py └── nn_distance │ ├── README.md │ ├── __init__.py │ ├── tf_nndistance.cpp │ ├── tf_nndistance.py │ ├── tf_nndistance_compile.sh │ ├── tf_nndistance_cpu.py │ ├── tf_nndistance_g.cu │ ├── tf_nndistance_g.cu.o │ └── tf_nndistance_so.so ├── train.py └── utils ├── __init__.py ├── compile_render_balls_so.sh ├── render_balls_so.cpp ├── render_balls_so.so ├── show3d_balls.py └── tf_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | /log/* 3 | /datasets/*/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 nschor 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CompoNet: Learning to Generate the Unseen by Part Synthesis and Composition 2 | Created by Nadav Schor, Oren Katzir, Hao Zhang, Daniel Cohen-Or. 3 | 4 | ![representative](https://github.com/nschor/CompoNet/blob/master/images/network_architecture.png) 5 | 6 | 7 | ## Introduction 8 | This work is based on our [ICCV paper](https://arxiv.org/abs/1811.07441). We present CompoNet, a generative neural network for 3D shapes that is based on a part-based prior, where the key idea is for the network to synthesize shapes by varying both the shape parts and their compositions. 
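At a high level, generation is two-stage: per-part generators synthesize each part independently, and a composition network places the parts by predicting a per-part, per-axis scale and translation (the 3x4 affine matrices built in `models.py`). As a minimal NumPy sketch of that composition step (illustrative only; `compose_parts` and the toy inputs below are not part of this code base):
```
import numpy as np

def compose_parts(part_clouds, scales, translations):
    # part_clouds:  list of (num_points, 3) arrays, one generated point cloud per part
    # scales:       (num_parts, 3) per-axis scales, the diagonal of the predicted affine matrix
    # translations: (num_parts, 3) offsets, the last column of the predicted affine matrix
    placed = [pc * s + t for pc, s, t in zip(part_clouds, scales, translations)]
    return np.concatenate(placed, axis=0)  # one merged (num_parts * num_points, 3) shape

# Toy usage: place two 400-point parts into a single shape
parts = [np.random.randn(400, 3), np.random.randn(400, 3)]
shape = compose_parts(parts,
                      scales=np.array([[1.0, 0.5, 1.0], [0.8, 1.0, 0.8]]),
                      translations=np.array([[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]]))
assert shape.shape == (800, 3)
```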
9 | 10 | 11 | ## Citation 12 | If you find our work useful in your research, please consider citing: 13 | 14 | @InProceedings{Schor_2019_ICCV, 15 | author = {Schor, Nadav and Katzir, Oren and Zhang, Hao and Cohen-Or, Daniel}, 16 | title = {CompoNet: Learning to Generate the Unseen by Part Synthesis and Composition}, 17 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 18 | month = {October}, 19 | year = {2019} 20 | } 21 | 22 | 23 | ## Dependencies 24 | Requirements: 25 | - Python 2.7 26 | - TensorFlow (version 1.4+) 27 | - OpenCV (for visualization) 28 | 29 | Our code has been tested with Python 2.7, TensorFlow 1.4.0, CUDA 8.0, and cuDNN 6.0 on Ubuntu 18.04. 30 | 31 | 32 | ## Installation 33 | Download the source code from the git repository: 34 | ``` 35 | git clone https://github.com/nschor/CompoNet 36 | ``` 37 | Compile the Chamfer loss code under `CompoNet/tf_ops/nn_distance`, taken from [Fan et al.](https://github.com/fanhqme/PointSetGeneration): 38 | ``` 39 | cd CompoNet/tf_ops/nn_distance 40 | ``` 41 | 42 | Modify the TensorFlow and CUDA paths in the `tf_nndistance_compile.sh` script and run it; a short sanity check for the compiled op is given at the end of this README. 43 | ``` 44 | sh tf_nndistance_compile.sh 45 | ``` 46 | For visualization, go to `utils/`: 47 | ``` 48 | cd CompoNet/utils 49 | ``` 50 | Run the `compile_render_balls_so.sh` script: 51 | ``` 52 | sh compile_render_balls_so.sh 53 | ``` 54 | 55 | If you are using Anaconda, we have included the environment we used, `CompoNet.yml`, under `anaconda_env/`. 56 | Create the environment using: 57 | ``` 58 | cd CompoNet/anaconda_env 59 | conda env create -f CompoNet.yml 60 | ``` 61 | Activate the environment: 62 | ``` 63 | source activate CompoNet 64 | ``` 65 | ### Data Set 66 | Download the ShapeNetPart dataset by running the `download_data.sh` script under `datasets/`: 67 | ``` 68 | cd CompoNet/datasets 69 | sh download_data.sh 70 | ``` 71 | The point clouds will be stored in `CompoNet/datasets/shapenetcore_partanno_segmentation_benchmark_v0`. 72 | 73 | 74 | ### Train CompoNet 75 | To train CompoNet on the Chair category with 400 points per part, run: 76 | ``` 77 | python train.py 78 | ``` 79 | Check the available options using: 80 | ``` 81 | python train.py -h 82 | ``` 83 | 84 | ### Generate Shapes Using the Trained Model 85 | To generate new shapes and visualize them, run: 86 | ``` 87 | python test.py --category category --model_path model_path 88 | ``` 89 | Check the available options using: 90 | ``` 91 | python test.py -h 92 | ``` 93 | 94 | ## License 95 | This project is licensed under the terms of the MIT license (see LICENSE for details).
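### Sanity Check for the Compiled Chamfer Op
A minimal sketch to confirm the compiled `tf_nndistance_so.so` loads and computes a symmetric Chamfer distance (run from the repository root; the `nn_distance` signature matches `tf_ops/nn_distance/tf_nndistance.py`, and the loss mirrors `get_loss_ae` in `models.py`):
```
import numpy as np
import tensorflow as tf
from tf_ops.nn_distance import tf_nndistance

# Two random batches of 400 3D points each
xyz1 = tf.constant(np.random.randn(1, 400, 3).astype('float32'))
xyz2 = tf.constant(np.random.randn(1, 400, 3).astype('float32'))
# dist1: squared distance from each point in xyz1 to its nearest neighbor in xyz2
# (dist2 is the reverse direction)
dist1, idx1, dist2, idx2 = tf_nndistance.nn_distance(xyz1, xyz2)
with tf.Session() as sess:
    print(sess.run(tf.reduce_mean(dist1 + dist2)))  # a small positive scalar
```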
96 | -------------------------------------------------------------------------------- /anaconda_env/CompoNet.yml: -------------------------------------------------------------------------------- 1 | name: CompoNet 2 | channels: 3 | - menpo 4 | - conda-forge 5 | - anaconda 6 | - mw 7 | - defaults 8 | dependencies: 9 | - atk=2.8.0 10 | - backports=1.0 11 | - backports.functools_lru_cache=1.5 12 | - backports.shutil_get_terminal_size=1.0.0 13 | - backports_abc=0.5 14 | - blas=1.0 15 | - ca-certificates=2019.1.23 16 | - cairo=1.14.12 17 | - certifi=2018.11.29 18 | - configparser=3.5.0 19 | - cycler=0.10.0 20 | - dbus=1.10.22 21 | - decorator=4.3.0 22 | - entrypoints=0.2.3 23 | - enum34=1.1.6 24 | - expat=2.2.5 25 | - ffmpeg=2.7.0 26 | - fontconfig=2.12.6 27 | - freeimage=3.17.0 28 | - freetype=2.8.1 29 | - functools32=3.2.3.2 30 | - gdk-pixbuf=2.28.2 31 | - glib=2.53.6 32 | - gmp=6.1.2 33 | - gst-plugins-base=1.12.4 34 | - gstreamer=1.12.4 35 | - gtk2=2.24.31 36 | - h5py=2.8.0 37 | - hdf5=1.10.2 38 | - icu=58.2 39 | - imageio=1.5.0 40 | - intel-openmp=2018.0.0 41 | - ipaddress=1.0.22 42 | - ipykernel=4.10.0 43 | - ipython=5.8.0 44 | - ipython_genutils=0.2.0 45 | - ipywidgets=7.4.2 46 | - jasper=1.900.1 47 | - jinja2=2.10 48 | - jpeg=9b 49 | - jsonschema=2.6.0 50 | - jupyter=1.0.0 51 | - jupyter_client=5.2.3 52 | - jupyter_console=5.2.0 53 | - jupyter_core=4.4.0 54 | - libedit=3.1 55 | - libffi=3.2.1 56 | - libgcc-ng=7.2.0 57 | - libgfortran-ng=7.2.0 58 | - libiconv=1.15 59 | - libpng=1.6.34 60 | - libsodium=1.0.16 61 | - libstdcxx-ng=7.2.0 62 | - libtiff=4.0.9 63 | - libxcb=1.12 64 | - libxml2=2.9.7 65 | - linecache2=1.0.0 66 | - markupsafe=1.0 67 | - matplotlib=2.1.2 68 | - mistune=0.8.3 69 | - mkl=2018.0.1 70 | - nbconvert=5.3.1 71 | - nbformat=4.4.0 72 | - ncurses=6.0 73 | - notebook=5.7.0 74 | - olefile=0.45.1 75 | - opencv=2.4.11 76 | - openssl=1.0.2p 77 | - pandas=0.22.0 78 | - pandoc=2.2.3.2 79 | - pandocfilters=1.4.2 80 | - pango=1.22.4 81 | - pathlib2=2.3.2 82 | - pcre=8.39 83 | - pexpect=4.6.0 84 | - pickleshare=0.7.5 85 | - pillow=5.0.0 86 | - pip=19.0.3 87 | - pixman=0.34.0 88 | - pkg-config=0.29.2 89 | - pkgconfig=1.4.0 90 | - prometheus_client=0.4.2 91 | - prompt_toolkit=1.0.15 92 | - ptyprocess=0.6.0 93 | - pygments=2.2.0 94 | - pyparsing=2.2.0 95 | - pyqt=5.6.0 96 | - python=2.7.14 97 | - python-dateutil=2.6.1 98 | - pytz=2017.3 99 | - pyzmq=17.1.2 100 | - qt=5.6.2 101 | - qtconsole=4.4.2 102 | - readline=7.0 103 | - scandir=1.9.0 104 | - scikit-learn=0.19.1 105 | - scipy=1.0.0 106 | - send2trash=1.5.0 107 | - simplegeneric=0.8.1 108 | - singledispatch=3.4.0.3 109 | - sip=4.18 110 | - sqlite=3.22.0 111 | - ssl_match_hostname=3.5.0.1 112 | - subprocess32=3.2.7 113 | - terminado=0.8.1 114 | - testpath=0.4.2 115 | - tk=8.6.7 116 | - tornado=4.5.3 117 | - traceback2=1.4.0 118 | - traitlets=4.3.2 119 | - unittest2=1.1.0 120 | - wcwidth=0.1.7 121 | - webencodings=0.5.1 122 | - widgetsnbextension=3.4.2 123 | - xorg-libxau=1.0.8 124 | - xorg-libxdmcp=1.1.2 125 | - xz=5.2.3 126 | - zeromq=4.2.5 127 | - zlib=1.2.11 128 | - pip: 129 | - absl-py==0.1.10 130 | - backports-weakref==1.0.post1 131 | - bleach==1.5.0 132 | - funcsigs==1.0.2 133 | - futures==3.2.0 134 | - html5lib==0.9999999 135 | - markdown==3.0.1 136 | - mock==2.0.0 137 | - numpy==1.16.2 138 | - pbr==5.1.3 139 | - protobuf==3.7.0 140 | - setuptools==40.8.0 141 | - six==1.12.0 142 | - tensorflow==1.5.0 143 | - tensorflow-gpu==1.4.0 144 | - tensorflow-tensorboard==0.4.0 145 | - tflearn==0.3.2 146 | - werkzeug==0.14.1 147 | - 
wheel==0.33.1 148 | prefix: /home/nadav/anaconda2/envs/tensorflow 149 | 150 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/data/__init__.py -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import part_dataset_ae 5 | import part_dataset_pcn 6 | 7 | 8 | def load_data(data_path, num_point, category, seen_split, unseen_split): 9 | pcn_train_dataset, pcn_test_dataset, num_parts = load_pcn_data(data_path, num_point, category, seen_split, 10 | unseen_split) 11 | ae_train_dataset, ae_test_dataset = load_aes_data(data_path, num_point, category, seen_split, unseen_split, 12 | num_parts) 13 | 14 | return pcn_train_dataset, pcn_test_dataset, ae_train_dataset, ae_test_dataset, num_parts 15 | 16 | 17 | def load_pcn_data(data_path, num_point, category, seen_split, unseen_split): 18 | pcn_train_dataset = part_dataset_pcn.PartDatasetPCN(root=data_path, npoints=num_point, class_choice=category, 19 | split=seen_split) 20 | pcn_test_dataset = part_dataset_pcn.PartDatasetPCN(root=data_path, npoints=num_point, class_choice=category, 21 | split=unseen_split) 22 | num_parts = pcn_train_dataset.get_number_of_parts() 23 | 24 | return pcn_train_dataset, pcn_test_dataset, num_parts 25 | 26 | 27 | def load_aes_data(data_path, num_point, category, seen_split, unseen_split, num_parts): 28 | ae_train_dataset = [] 29 | ae_test_dataset = [] 30 | for i in xrange(num_parts): 31 | print 'Loading part ' + str(i) 32 | ae_train_dataset.append( 33 | part_dataset_ae.PartDatasetAE(root=data_path, npoints=num_point, class_choice=category, split=seen_split, 34 | part_label=i)) 35 | ae_test_dataset.append( 36 | part_dataset_ae.PartDatasetAE(root=data_path, npoints=num_point, class_choice=category, split=unseen_split, 37 | part_label=i)) 38 | 39 | return ae_train_dataset, ae_test_dataset 40 | 41 | 42 | def pc_normalize(pc): 43 | centroid = np.mean(pc, axis=0) 44 | pc = pc - centroid 45 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 46 | pc = pc / m 47 | return pc 48 | 49 | 50 | def data_parse(catfile, class_choice, root, split): 51 | cat = {} 52 | with open(catfile, 'r') as f: 53 | for line in f: 54 | ls = line.strip().split() 55 | cat[ls[0]] = ls[1] 56 | 57 | cat = {k: v for k, v in cat.items() if k in class_choice} 58 | 59 | meta = {} 60 | with open(os.path.join(root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f: 61 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 62 | with open(os.path.join(root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f: 63 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 64 | with open(os.path.join(root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f: 65 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 66 | for item in cat: 67 | meta[item] = [] 68 | dir_point = os.path.join(root, cat[item], 'points') 69 | dir_seg = os.path.join(root, cat[item], 'points_label') 70 | fns = sorted(os.listdir(dir_point)) 71 | if split == 'trainval': 72 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] 73 | elif split == 'train': 74 | fns = [fn for fn in fns if fn[0:-4] in train_ids] 75 | elif split 
== 'val': 76 | fns = [fn for fn in fns if fn[0:-4] in val_ids] 77 | elif split == 'test': 78 | fns = [fn for fn in fns if fn[0:-4] in test_ids] 79 | else: 80 | print('Unknown split: %s. Exiting..' % (split)) 81 | exit(-1) 82 | 83 | for fn in fns: 84 | token = (os.path.splitext(os.path.basename(fn))[0]) 85 | meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg'))) 86 | 87 | return cat, meta 88 | 89 | 90 | def compute_num_of_parts(path): 91 | max_num_parts = 0 92 | for i in range(len(path) / 50): 93 | num_parts = len(np.unique(np.loadtxt(path[i][-1]).astype(np.uint8))) 94 | if num_parts > max_num_parts: 95 | max_num_parts = num_parts 96 | return max_num_parts 97 | 98 | 99 | def get_point_set_and_seg(path, index): 100 | fn = path[index] 101 | point_set = np.loadtxt(fn[1]).astype(np.float32) 102 | seg = np.loadtxt(fn[2]).astype(np.int64) - 1 103 | 104 | return point_set, seg 105 | 106 | 107 | def choose_points(point_set, npoints): 108 | choice = np.random.choice(len(point_set), npoints, replace=True) 109 | point_set = point_set[choice] 110 | 111 | return point_set, choice 112 | -------------------------------------------------------------------------------- /data/part_dataset_ae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import data_utils 4 | 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | 8 | class PartDatasetAE(): 9 | def __init__(self, root, npoints=400, class_choice='Chair', split='train', part_label=None): 10 | if part_label is None: 11 | print 'Wrong part label - part_dataset_ae' 12 | exit(1) 13 | 14 | self.npoints = npoints 15 | self.part_label = part_label 16 | self.cache = {} # caching the loaded parts 17 | 18 | cat, meta = data_utils.data_parse(os.path.join(root, 'synsetoffset2category.txt'), class_choice, root, split) 19 | 20 | self.datapath = [] 21 | for item in cat: 22 | for fn in meta[item]: 23 | # discard missing parts 24 | seg = np.loadtxt(fn[1]).astype(np.int64) - 1 25 | part_points = np.where(seg == self.part_label) 26 | if len(part_points[0]) > 1: 27 | self.datapath.append((item, fn[0], fn[1])) 28 | 29 | def __getitem__(self, index): 30 | if index in self.cache: 31 | point_set = self.cache[index] 32 | else: 33 | point_set, seg = data_utils.get_point_set_and_seg(self.datapath, index) 34 | part_points = np.where(seg == self.part_label) 35 | point_set = data_utils.pc_normalize(point_set[part_points]) 36 | self.cache[index] = point_set 37 | 38 | # choose the right number of point by 39 | # randomly picking, if there are too many 40 | # or re-sampling, if there are less than needed 41 | point_set_length = len(point_set) 42 | if point_set_length >= self.npoints: 43 | point_set, _ = data_utils.choose_points(point_set, self.npoints) 44 | else: 45 | extra_point_set, choice = data_utils.choose_points(point_set, self.npoints - point_set_length) 46 | point_set = np.append(point_set, extra_point_set, axis=0) 47 | 48 | return point_set 49 | 50 | def __len__(self): 51 | return len(self.datapath) 52 | 53 | 54 | if __name__ == '__main__': 55 | from utils import show3d_balls 56 | 57 | d = PartDatasetAE(root=os.path.join(BASE_DIR, '../data/shapenetcore_partanno_segmentation_benchmark_v0'), 58 | class_choice='Chair', split='test', part_label=1) 59 | i = 27 60 | ps = d[i] 61 | show3d_balls.showpoints(ps, ballradius=8) 62 | -------------------------------------------------------------------------------- /data/part_dataset_pcn.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import data_utils 4 | 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | 8 | class PartDatasetPCN: 9 | def __init__(self, root, npoints=400, class_choice='Chair', split='train'): 10 | self.npoints = npoints 11 | self.cache = {} # caching the loaded parts 12 | 13 | cat, meta = data_utils.data_parse(os.path.join(root, 'synsetoffset2category.txt'), class_choice, root, split) 14 | 15 | self.datapath = [] 16 | for item in cat: 17 | for fn in meta[item]: 18 | self.datapath.append((item, fn[0], fn[1])) 19 | 20 | self.num_parts = data_utils.compute_num_of_parts(self.datapath) 21 | 22 | def __getitem__(self, index): 23 | if index in self.cache: 24 | parts_point_sets = self.cache[index] 25 | else: 26 | point_set, seg = data_utils.get_point_set_and_seg(self.datapath, index) 27 | point_set = data_utils.pc_normalize(point_set) 28 | parts_point_sets = [] 29 | for p in xrange(self.num_parts): 30 | part_points = np.where(seg == p) 31 | if len(part_points[0]) > 1: 32 | part_point_set = point_set[part_points] 33 | is_part_exist = True 34 | else: 35 | part_point_set = np.zeros((self.npoints, 3)) 36 | is_part_exist = False 37 | # normalized each part on its own 38 | if is_part_exist: 39 | norm_part_point_set = data_utils.pc_normalize(part_point_set) 40 | else: 41 | norm_part_point_set = part_point_set 42 | parts_point_sets.append((part_point_set, norm_part_point_set, is_part_exist)) 43 | self.cache[index] = parts_point_sets 44 | 45 | point_sets = [] 46 | for point_set, norm_part_point_set, is_full in parts_point_sets: 47 | # choose the right number of point by 48 | # randomly picking, if there are too many 49 | # or re-sampling, if there are less than needed 50 | point_set_length = len(point_set) 51 | if point_set_length >= self.npoints: 52 | point_set, choice = data_utils.choose_points(point_set, self.npoints) 53 | norm_part_point_set = norm_part_point_set[choice] 54 | else: 55 | extra_point_set, choice = data_utils.choose_points(point_set, self.npoints - point_set_length) 56 | point_set = np.append(point_set, extra_point_set, axis=0) 57 | norm_part_point_set = np.append(norm_part_point_set, norm_part_point_set[choice], axis=0) 58 | point_sets.append((point_set, norm_part_point_set, is_full)) 59 | 60 | return point_sets 61 | 62 | def __len__(self): 63 | return len(self.datapath) 64 | 65 | def get_number_of_parts(self): 66 | return self.num_parts 67 | 68 | 69 | if __name__ == '__main__': 70 | from utils import show3d_balls 71 | 72 | d = PartDatasetPCN(root=os.path.join(BASE_DIR, '../data/shapenetcore_partanno_segmentation_benchmark_v0'), 73 | class_choice='Chair', split='test') 74 | i = 27 75 | point_sets = d[i] 76 | for p in xrange(d.get_number_of_parts()): 77 | ps, _, _ = point_sets[p] 78 | show3d_balls.showpoints(ps, ballradius=8) 79 | -------------------------------------------------------------------------------- /datasets/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0.zip 3 | unzip shapenetcore_partanno_segmentation_benchmark_v0.zip 4 | rm shapenetcore_partanno_segmentation_benchmark_v0.zip 5 | -------------------------------------------------------------------------------- /images/network_architecture.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/images/network_architecture.png -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from utils import tf_util 4 | from tf_ops.nn_distance import tf_nndistance 5 | 6 | BN_INIT_DECAY = 0.5 7 | BN_DECAY_RATE = 0.5 8 | BN_DECAY_CLIP = 0.99 9 | 10 | 11 | def placeholder_inputs(batch_size, num_point): 12 | point_clouds_ph = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3)) 13 | gt_ph = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3)) 14 | return point_clouds_ph, gt_ph 15 | 16 | 17 | def build_parts_aes_graphs(num_parts, num_points, part_embedding_dim, base_learning_rate, batch_size, decay_step, 18 | decay_rate, bn_decay_step): 19 | point_clouds_phs = [] 20 | ae_ops = [] 21 | for i in xrange(num_parts): 22 | with tf.variable_scope('part' + str(i)): 23 | print 'Graph part ' + str(i) 24 | point_clouds_ph, gt_ph = placeholder_inputs(batch_size, num_points) 25 | point_clouds_phs.append(point_clouds_ph) 26 | 27 | is_training_ph = tf.placeholder(tf.bool, shape=()) 28 | 29 | batch = tf.Variable(0) 30 | bn_momentum = tf.train.exponential_decay(BN_INIT_DECAY, batch * batch_size, bn_decay_step, BN_DECAY_RATE, 31 | staircase=True) 32 | bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum) 33 | bn_decay_summary = tf.summary.scalar('bn_decay' + ' ' + str(i), bn_decay) 34 | 35 | print "Get model and loss" 36 | pred, end_points = get_model_ae(point_clouds_ph, is_training_ph, batch_size, num_points, bn_decay, 37 | part_embedding_dim) 38 | loss, end_points_tmp = get_loss_ae(pred, gt_ph, end_points) 39 | loss_summary = tf.summary.scalar('loss', loss) 40 | 41 | print "Get training operator" 42 | learning_rate = tf.train.exponential_decay(base_learning_rate, batch * batch_size, decay_step, 43 | decay_rate, staircase=True) 44 | learning_rate_summary = tf.summary.scalar('learning_rate', learning_rate) 45 | optimizer = tf.train.AdamOptimizer(learning_rate) 46 | train_op = optimizer.minimize(loss, global_step=batch) 47 | 48 | merged_summary = tf.summary.merge((bn_decay_summary, loss_summary, learning_rate_summary)) 49 | 50 | ae_ops.append({'point_clouds_ph': point_clouds_ph, 51 | 'gt_ph': gt_ph, 52 | 'is_training_ph': is_training_ph, 53 | 'pred': pred, 54 | 'loss': loss, 55 | 'train_op': train_op, 56 | 'merged_summary': merged_summary, 57 | 'step': batch, 58 | 'end_points': end_points}) 59 | 60 | return ae_ops, point_clouds_phs 61 | 62 | 63 | def get_model_ae(point_cloud, is_training, batch_size, num_point, bn_decay=None, embedding_dim=64, reuse=False): 64 | input_point_cloud = tf.expand_dims(point_cloud, -1) 65 | net, end_points = ae_encoder(batch_size, num_point, 3, input_point_cloud, is_training, bn_decay=bn_decay, 66 | embedding_dim=embedding_dim) 67 | net = ae_decoder(batch_size, num_point, net, is_training, bn_decay=bn_decay, reuse=reuse) 68 | 69 | return net, end_points 70 | 71 | 72 | def ae_encoder(batch_size, num_point, point_dim, input_image, is_training, bn_decay=None, embedding_dim=128): 73 | net = tf_util.conv2d(input_image, 64, [1, point_dim], 74 | padding='VALID', stride=[1, 1], 75 | bn=True, is_training=is_training, 76 | scope='conv1', bn_decay=bn_decay) 77 | net = tf_util.conv2d(net, 64, [1, 1], 78 | padding='VALID', stride=[1, 1], 79 | bn=True, is_training=is_training, 80 | scope='conv2', bn_decay=bn_decay) 81 | net 
= tf_util.conv2d(net, 64, [1, 1], 82 | padding='VALID', stride=[1, 1], 83 | bn=True, is_training=is_training, 84 | scope='conv3', bn_decay=bn_decay) 85 | net = tf_util.conv2d(net, 128, [1, 1], 86 | padding='VALID', stride=[1, 1], 87 | bn=True, is_training=is_training, 88 | scope='conv4', bn_decay=bn_decay) 89 | net = tf_util.conv2d(net, embedding_dim, [1, 1], 90 | padding='VALID', stride=[1, 1], 91 | bn=True, is_training=is_training, 92 | scope='conv5', bn_decay=bn_decay) 93 | global_feat = tf_util.max_pool2d(net, [num_point, 1], 94 | padding='VALID', scope='maxpool') 95 | net = tf.reshape(global_feat, [batch_size, -1]) 96 | end_points = {'embedding': net} 97 | 98 | return net, end_points 99 | 100 | 101 | def ae_decoder(batch_size, num_point, net, is_training, bn_decay=None, reuse=False): 102 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay, 103 | reuse=reuse) 104 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay, 105 | reuse=reuse) 106 | net = tf_util.fully_connected(net, num_point * 3, activation_fn=None, scope='fc3', reuse=reuse) 107 | net = tf.reshape(net, (batch_size, num_point, 3)) 108 | 109 | return net 110 | 111 | 112 | def get_loss_ae(pred, gt, end_points): 113 | dists_forward, _, dists_backward, _ = tf_nndistance.nn_distance(pred, gt) 114 | loss = tf.reduce_mean(dists_forward + dists_backward) 115 | end_points['pcloss'] = loss 116 | 117 | loss = loss * 100 118 | end_points['loss'] = loss 119 | return loss, end_points 120 | 121 | 122 | def build_parts_pcn_graph(ae_ops, point_clouds_ph, num_parts, num_points, noise_embedding_dim, base_learning_rate, 123 | batch_size, decay_step, decay_rate, bn_decay_step): 124 | print '\nGraph PCN' 125 | with tf.variable_scope('pcn'): 126 | pcn_is_training = tf.placeholder(tf.bool, shape=()) 127 | y = tf.placeholder(tf.float32, [None, num_parts, num_points, 3], name='y') 128 | y_mask = tf.placeholder(tf.float32, [None, num_parts, num_points], name='y_mask') 129 | noise = tf.placeholder(tf.float32, shape=[None, noise_embedding_dim]) 130 | 131 | pcn_enc = tf.concat([ae_ops[0]['end_points']['embedding'], ae_ops[1]['end_points']['embedding']], axis=-1) 132 | for p in xrange(2, num_parts): 133 | pcn_enc = tf.concat([pcn_enc, ae_ops[p]['end_points']['embedding']], axis=-1) 134 | pcn_enc = tf.concat([pcn_enc, noise], axis=-1) 135 | 136 | pcn_batch = tf.Variable(0) 137 | bn_momentum = tf.train.exponential_decay(BN_INIT_DECAY, pcn_batch * batch_size, bn_decay_step, BN_DECAY_RATE, 138 | staircase=True) 139 | pcn_bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum) 140 | bn_decay_summary = tf.summary.scalar('bn_decay', pcn_bn_decay) 141 | 142 | print "--- Get model and loss" 143 | x = tf.stack(point_clouds_ph, axis=1) 144 | y_hat = get_model_pcn(pcn_enc, x, num_parts, pcn_is_training, bn_decay=pcn_bn_decay) 145 | pcn_loss = get_loss_pcn(y, y_hat, y_mask, num_parts, batch_size) 146 | loss_summary = tf.summary.scalar('loss', pcn_loss) 147 | 148 | print "--- Get training operator" 149 | pcn_learning_rate = tf.train.exponential_decay(base_learning_rate, pcn_batch * batch_size, decay_step, 150 | decay_rate, staircase=True) 151 | learning_rate_summary = tf.summary.scalar('learning_rate', pcn_learning_rate) 152 | pcn_param = [var for var in tf.trainable_variables() if any(x in var.name for x in ['pcn'])] 153 | optimizer = tf.train.AdamOptimizer(pcn_learning_rate).minimize(pcn_loss, global_step=pcn_batch, 154 | var_list=pcn_param) 155 | 156 | 
merged_summary = tf.summary.merge((bn_decay_summary, loss_summary, learning_rate_summary)) 157 | 158 | pcn_ops = ({'y': y, 159 | 'y_mask': y_mask, 160 | 'noise': noise, 161 | 'is_training_ph': pcn_is_training, 162 | 'pcn_enc': pcn_enc, 163 | 'pred': y_hat, 164 | 'loss': pcn_loss, 165 | 'train_op': optimizer, 166 | 'merged_summary': merged_summary, 167 | 'step': pcn_batch}) 168 | 169 | return pcn_ops 170 | 171 | 172 | def get_model_pcn(pcn_enc, x, num_parts, is_training, bn_decay=None, reuse=False): 173 | with tf.variable_scope('trans', reuse=reuse): 174 | bias_initializer = np.array([[1., 0, 1, 0, 1, 0] for _ in xrange(num_parts)]) 175 | bias_initializer = bias_initializer.astype('float32').flatten() 176 | 177 | net = tf_util.fully_connected(pcn_enc, 256, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay, 178 | reuse=reuse) 179 | net = tf_util.fully_connected(net, 128, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay, 180 | reuse=reuse) 181 | trans = tf_util.fully_connected(net, num_parts * 6, activation_fn=None, scope='fc3', 182 | weights_initializer=tf.zeros_initializer(), 183 | biases_initializer=tf.constant_initializer(bias_initializer), reuse=reuse) 184 | 185 | # Perform transformation 186 | with tf.variable_scope('pcn', reuse=reuse): 187 | zeros_dims = tf.stack([tf.shape(x)[0], 1]) 188 | zeros_col = tf.fill(zeros_dims, 0.0) 189 | ''' 190 | sx 0 0 tx 191 | 0 sy 0 ty 192 | 0 0 sz tz 193 | ''' 194 | trans_mat = tf.concat((tf.expand_dims(trans[:, 0], axis=1), zeros_col, zeros_col, 195 | tf.expand_dims(trans[:, 1], axis=1), zeros_col, tf.expand_dims(trans[:, 2], axis=1), 196 | zeros_col, tf.expand_dims(trans[:, 3], axis=1), zeros_col, zeros_col, 197 | tf.expand_dims(trans[:, 4], axis=1), tf.expand_dims(trans[:, 5], axis=1), 198 | 199 | tf.expand_dims(trans[:, 6], axis=1), zeros_col, zeros_col, 200 | tf.expand_dims(trans[:, 7], axis=1), zeros_col, tf.expand_dims(trans[:, 8], axis=1), 201 | zeros_col, tf.expand_dims(trans[:, 9], axis=1), zeros_col, zeros_col, 202 | tf.expand_dims(trans[:, 10], axis=1), tf.expand_dims(trans[:, 11], axis=1)), axis=1) 203 | for p in xrange(2, num_parts): 204 | start_ind = 6 * p 205 | trans_mat = tf.concat((trans_mat, tf.expand_dims(trans[:, start_ind], axis=1), zeros_col, zeros_col, 206 | tf.expand_dims(trans[:, start_ind + 1], axis=1), zeros_col, 207 | tf.expand_dims(trans[:, start_ind + 2], axis=1), zeros_col, 208 | tf.expand_dims(trans[:, start_ind + 3], axis=1), zeros_col, zeros_col, 209 | tf.expand_dims(trans[:, start_ind + 4], axis=1), 210 | tf.expand_dims(trans[:, start_ind + 5], axis=1)), axis=1) 211 | 212 | trans_mat = tf.reshape(trans_mat, (-1, num_parts, 3, 4)) 213 | # adding 1 (w coordinate) to every point (x,y,z,1) 214 | w = tf.ones([tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], 1]) 215 | x = tf.concat((x, w), axis=-1) 216 | x_t = tf.transpose(x, [0, 1, 3, 2]) 217 | y_hat_t = tf.matmul(trans_mat, x_t) 218 | y_hat = tf.transpose(y_hat_t, [0, 1, 3, 2]) 219 | 220 | return y_hat 221 | 222 | 223 | def get_loss_pcn(gt, pred, gt_mask, num_parts, batch_size): 224 | dists_forward_total = tf.zeros(batch_size) 225 | dists_backward_total = tf.zeros(batch_size) 226 | for part in xrange(num_parts): 227 | dists_forward, _, dists_backward, _ = tf_nndistance.nn_distance(pred[:, part], gt[:, part]) 228 | # zero out the non-existing parts 229 | dists_forward = tf.reduce_sum(tf.multiply(dists_forward, gt_mask[:, part]), axis=-1) 230 | dists_backward = tf.reduce_sum(tf.multiply(dists_backward, gt_mask[:, part]), axis=-1) 231 | 
dists_forward_total += dists_forward 232 | dists_backward_total += dists_backward 233 | 234 | loss = dists_forward_total + dists_backward_total 235 | # divide by the number of parts 236 | div = tf.reduce_sum(tf.reduce_mean(gt_mask, axis=-1), axis=-1) 237 | loss = tf.reduce_mean(tf.div(loss, div)) 238 | 239 | return loss * 100 240 | 241 | 242 | def build_test_graph(num_parts, num_points, part_embedding_dim, base_learning_rate, batch_size, decay_step, decay_rate, 243 | bn_decay_step, noise_embedding_dim): 244 | ae_ops, point_clouds_ph = build_parts_aes_graphs(num_parts, num_points, part_embedding_dim, base_learning_rate, 245 | batch_size, decay_step, decay_rate, bn_decay_step) 246 | for i in xrange(num_parts): 247 | with tf.variable_scope('part' + str(i)): 248 | samples = (tf.placeholder(tf.float32, shape=(batch_size, part_embedding_dim))) 249 | dec = ae_decoder(batch_size, num_points, samples, ae_ops[i]["is_training_ph"], reuse=True) 250 | ae_ops[i]['samples'] = samples 251 | ae_ops[i]['dec'] = dec 252 | 253 | pcn_ops = build_parts_pcn_graph(ae_ops, point_clouds_ph, num_parts, num_points, noise_embedding_dim, 254 | base_learning_rate, batch_size, decay_step, decay_rate, bn_decay_step) 255 | 256 | with tf.variable_scope('pcn'): 257 | x_full = tf.stack((ae_ops[0]['dec'], ae_ops[1]['dec']), axis=1) 258 | for p in xrange(2, num_parts): 259 | x_full = tf.concat((x_full, tf.expand_dims(ae_ops[p]['dec'], axis=1)), axis=1) 260 | 261 | cpcn_enc_full = tf.concat([ae_ops[0]['samples'], ae_ops[1]['samples']], axis=-1) 262 | for p in xrange(2, num_parts): 263 | cpcn_enc_full = tf.concat([cpcn_enc_full, ae_ops[p]['samples']], axis=-1) 264 | cpcn_enc_full = tf.concat([cpcn_enc_full, pcn_ops['noise']], axis=-1) 265 | 266 | y_hat_full = get_model_pcn(cpcn_enc_full, x_full, num_parts, pcn_ops['is_training_ph'], reuse=True) 267 | pcn_ops['pred_full'] = y_hat_full 268 | 269 | return ae_ops, pcn_ops 270 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tensorflow as tf 4 | import os 5 | import models 6 | from data import data_utils 7 | from utils import show3d_balls 8 | from sklearn.mixture import GaussianMixture 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--data_path', default='datasets/shapenetcore_partanno_segmentation_benchmark_v0', 12 | help='Path to the dataset [default: datasets/shapenetcore_partanno_segmentation_benchmark_v0]') 13 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]') 14 | parser.add_argument('--category', default='Chair', help='Which single class to train on [default: Chair]') 15 | parser.add_argument('--part_embedding_dim', type=int, default=64, help='Embedding dimension of each part [default: 64]') 16 | parser.add_argument('--noise_embedding_dim', type=int, default=16, 17 | help='Embedding dimension of the noise [default: 16]') 18 | parser.add_argument('--num_point', type=int, default=400, help='Number of points per part [default: 400]') 19 | parser.add_argument('--learning_rate', type=float, default=0.001, 20 | help='Initial learning rate for the parts composition network [default: 0.001]') 21 | parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]') 22 | parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]') 23 | 
parser.add_argument('--num_parts', type=int, default=0, 24 | help='Number of parts; if set to 0, it is computed from the data, which takes longer [default: 0]') 25 | parser.add_argument('--model_path', default='log/model.ckpt', 26 | help='Model checkpoint file path [default: log/model.ckpt]') 27 | parser.add_argument('--num_samples', type=int, default=100, 28 | help='Number of generated shapes to be shown [default: 100]') 29 | 30 | FLAGS = parser.parse_args() 31 | GPU_INDEX = FLAGS.gpu 32 | CATEGORY = FLAGS.category 33 | PART_EMBEDDING_DIM = FLAGS.part_embedding_dim 34 | NOISE_EMBEDDING_DIM = FLAGS.noise_embedding_dim 35 | NUM_POINTS = FLAGS.num_point 36 | BASE_LEARNING_RATE = FLAGS.learning_rate 37 | DECAY_STEP = FLAGS.decay_step 38 | DECAY_RATE = FLAGS.decay_rate 39 | NUM_PARTS = FLAGS.num_parts 40 | MODEL_PATH = FLAGS.model_path 41 | NUM_SAMPLES = FLAGS.num_samples 42 | 43 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 44 | COLORS = [np.array([163, 254, 170]), np.array([206, 178, 254]), np.array([248, 250, 132]), np.array([237, 186, 145]), 45 | np.array([192, 144, 145]), np.array([158, 218, 73])] 46 | 47 | print 'Loading data' 48 | # ShapeNet official train/test split 49 | DATA_PATH = os.path.join(ROOT_DIR, FLAGS.data_path) 50 | # Using the same splits as when training 51 | if NUM_PARTS == 0: 52 | _, _, AE_TRAIN_DATASET, _, NUM_PARTS = data_utils.load_data(DATA_PATH, NUM_POINTS, CATEGORY, 'test', 'trainval') 53 | else: 54 | AE_TRAIN_DATASET, _ = data_utils.load_aes_data(DATA_PATH, NUM_POINTS, CATEGORY, 'test', 'trainval', NUM_PARTS) 55 | 56 | 57 | def get_model(batch_size): 58 | with tf.Graph().as_default(): 59 | with tf.device('/gpu:' + str(GPU_INDEX)): 60 | ae_ops, pcn_ops = models.build_test_graph(NUM_PARTS, NUM_POINTS, PART_EMBEDDING_DIM, BASE_LEARNING_RATE, 61 | batch_size, DECAY_STEP, DECAY_RATE, float(DECAY_STEP), 62 | NOISE_EMBEDDING_DIM) 63 | 64 | saver = tf.train.Saver() 65 | # Create a session 66 | config = tf.ConfigProto() 67 | config.gpu_options.allow_growth = True 68 | config.allow_soft_placement = True 69 | sess = tf.Session(config=config) 70 | # Restore variables from disk.
71 | saver.restore(sess, MODEL_PATH) 72 | 73 | return sess, ae_ops, pcn_ops 74 | 75 | 76 | def compute_embedding_gmm_and_sample_vectors(sess, ae_ops): 77 | num_gmm_components = 20 78 | samples = [] 79 | for p in xrange(NUM_PARTS): 80 | part = [] 81 | print 'Embedding part ' + str(p) 82 | for i in xrange(len(AE_TRAIN_DATASET[p])): 83 | ps = AE_TRAIN_DATASET[p][i] 84 | feed_dict = {ae_ops[p]['point_clouds_ph']: np.expand_dims(ps, axis=0), 85 | ae_ops[p]['gt_ph']: np.expand_dims(ps, axis=0), 86 | ae_ops[p]['is_training_ph']: False, } 87 | 88 | part.append(sess.run(ae_ops[p]['end_points']['embedding'], feed_dict=feed_dict)) 89 | print 'Compute GMM for part', p 90 | gmm = GaussianMixture(n_components=num_gmm_components, covariance_type='full') 91 | gmm.fit(np.squeeze(np.array(part))) 92 | sample, _ = gmm.sample(n_samples=NUM_SAMPLES) 93 | sample = sample[np.random.permutation(np.arange(NUM_SAMPLES))] 94 | samples.append(sample) 95 | 96 | return samples 97 | 98 | 99 | def generate_shapes_from_vectors(sess, ae_ops, pcn_ops, samples): 100 | for i in range(NUM_SAMPLES): 101 | noise = np.random.normal(size=[1, NOISE_EMBEDDING_DIM]) 102 | shapes_embedd = np.stack((np.expand_dims(samples[0][i], axis=0), np.expand_dims(samples[1][i], axis=0)), axis=0) 103 | for p in xrange(2, NUM_PARTS): 104 | shapes_embedd = np.concatenate( 105 | (shapes_embedd, np.expand_dims(np.expand_dims(samples[p][i], axis=0), axis=0)), axis=0) 106 | 107 | # Demonstrate a missing part: for roughly half of the samples, zero out the last part's embedding 108 | if np.random.randint(1000) % 2 == 0: 109 | missing_part = True 110 | shapes_embedd[-1] = np.zeros((1, PART_EMBEDDING_DIM)) 111 | else: 112 | missing_part = False 113 | 114 | feed_dict = {} 115 | for p in xrange(NUM_PARTS): 116 | feed_dict[ae_ops[p]['samples']] = shapes_embedd[p] 117 | feed_dict[ae_ops[p]['is_training_ph']] = False 118 | feed_dict[pcn_ops['noise']] = noise 119 | feed_dict[pcn_ops['is_training_ph']] = False 120 | pred = sess.run(pcn_ops['pred_full'], feed_dict=feed_dict) 121 | 122 | preds = np.concatenate((pred[0, 0], pred[0, 1]), axis=0) 123 | for p in xrange(2, NUM_PARTS): 124 | preds = np.concatenate((preds, pred[0, p]), axis=0) 125 | 126 | show_3d_point_clouds(preds, missing_part) 127 | 128 | 129 | def show_3d_point_clouds(shapes, is_missing_part): 130 | colors = np.zeros_like(shapes) 131 | for p in xrange(NUM_PARTS): 132 | colors[NUM_POINTS * p:NUM_POINTS * (p + 1), :] = COLORS[p] 133 | 134 | # fix orientation 135 | shapes[:, 1] *= -1 136 | shapes = shapes[:, [1, 2, 0]] 137 | 138 | if is_missing_part: 139 | shapes = shapes[:NUM_POINTS * (NUM_PARTS - 1)] 140 | colors = colors[:NUM_POINTS * (NUM_PARTS - 1)] 141 | show3d_balls.showpoints(shapes, c_gt=colors, ballradius=8, normalizecolor=False, background=[255, 255, 255]) 142 | 143 | 144 | def test(): 145 | sess, ae_ops, pcn_ops = get_model(batch_size=1) 146 | samples = compute_embedding_gmm_and_sample_vectors(sess, ae_ops) 147 | generate_shapes_from_vectors(sess, ae_ops, pcn_ops, samples) 148 | 149 | 150 | if __name__ == "__main__": 151 | test() 152 | -------------------------------------------------------------------------------- /tf_ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/tf_ops/__init__.py -------------------------------------------------------------------------------- /tf_ops/nn_distance/README.md: -------------------------------------------------------------------------------- 1 | From
https://github.com/fanhqme/PointSetGeneration/tree/master/depthestimate 2 | -------------------------------------------------------------------------------- /tf_ops/nn_distance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/tf_ops/nn_distance/__init__.py -------------------------------------------------------------------------------- /tf_ops/nn_distance/tf_nndistance.cpp: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/op_kernel.h" 3 | REGISTER_OP("NnDistance") 4 | .Input("xyz1: float32") 5 | .Input("xyz2: float32") 6 | .Output("dist1: float32") 7 | .Output("idx1: int32") 8 | .Output("dist2: float32") 9 | .Output("idx2: int32"); 10 | REGISTER_OP("NnDistanceGrad") 11 | .Input("xyz1: float32") 12 | .Input("xyz2: float32") 13 | .Input("grad_dist1: float32") 14 | .Input("idx1: int32") 15 | .Input("grad_dist2: float32") 16 | .Input("idx2: int32") 17 | .Output("grad_xyz1: float32") 18 | .Output("grad_xyz2: float32"); 19 | using namespace tensorflow; 20 | 21 | static void nnsearch(int b,int n,int m,const float * xyz1,const float * xyz2,float * dist,int * idx){ 22 | for (int i=0;iinput(0); 50 | const Tensor& xyz2_tensor=context->input(1); 51 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)")); 52 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1")); 53 | int b=xyz1_tensor.shape().dim_size(0); 54 | int n=xyz1_tensor.shape().dim_size(1); 55 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)")); 56 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2")); 57 | int m=xyz2_tensor.shape().dim_size(1); 58 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size")); 59 | auto xyz1_flat=xyz1_tensor.flat(); 60 | const float * xyz1=&xyz1_flat(0); 61 | auto xyz2_flat=xyz2_tensor.flat(); 62 | const float * xyz2=&xyz2_flat(0); 63 | Tensor * dist1_tensor=NULL; 64 | Tensor * idx1_tensor=NULL; 65 | Tensor * dist2_tensor=NULL; 66 | Tensor * idx2_tensor=NULL; 67 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor)); 68 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor)); 69 | auto dist1_flat=dist1_tensor->flat(); 70 | auto idx1_flat=idx1_tensor->flat(); 71 | OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor)); 72 | OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor)); 73 | auto dist2_flat=dist2_tensor->flat(); 74 | auto idx2_flat=idx2_tensor->flat(); 75 | float * dist1=&(dist1_flat(0)); 76 | int * idx1=&(idx1_flat(0)); 77 | float * dist2=&(dist2_flat(0)); 78 | int * idx2=&(idx2_flat(0)); 79 | nnsearch(b,n,m,xyz1,xyz2,dist1,idx1); 80 | nnsearch(b,m,n,xyz2,xyz1,dist2,idx2); 81 | } 82 | }; 83 | REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_CPU), NnDistanceOp); 84 | class NnDistanceGradOp : public OpKernel{ 85 | public: 86 | explicit NnDistanceGradOp(OpKernelConstruction* context):OpKernel(context){} 87 | void Compute(OpKernelContext * context)override{ 88 | const Tensor& 
xyz1_tensor=context->input(0); 89 | const Tensor& xyz2_tensor=context->input(1); 90 | const Tensor& grad_dist1_tensor=context->input(2); 91 | const Tensor& idx1_tensor=context->input(3); 92 | const Tensor& grad_dist2_tensor=context->input(4); 93 | const Tensor& idx2_tensor=context->input(5); 94 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)")); 95 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1")); 96 | int b=xyz1_tensor.shape().dim_size(0); 97 | int n=xyz1_tensor.shape().dim_size(1); 98 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)")); 99 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2")); 100 | int m=xyz2_tensor.shape().dim_size(1); 101 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size")); 102 | OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)")); 103 | OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)")); 104 | OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)")); 105 | OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)")); 106 | auto xyz1_flat=xyz1_tensor.flat(); 107 | const float * xyz1=&xyz1_flat(0); 108 | auto xyz2_flat=xyz2_tensor.flat(); 109 | const float * xyz2=&xyz2_flat(0); 110 | auto idx1_flat=idx1_tensor.flat(); 111 | const int * idx1=&idx1_flat(0); 112 | auto idx2_flat=idx2_tensor.flat(); 113 | const int * idx2=&idx2_flat(0); 114 | auto grad_dist1_flat=grad_dist1_tensor.flat(); 115 | const float * grad_dist1=&grad_dist1_flat(0); 116 | auto grad_dist2_flat=grad_dist2_tensor.flat(); 117 | const float * grad_dist2=&grad_dist2_flat(0); 118 | Tensor * grad_xyz1_tensor=NULL; 119 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor)); 120 | Tensor * grad_xyz2_tensor=NULL; 121 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor)); 122 | auto grad_xyz1_flat=grad_xyz1_tensor->flat(); 123 | float * grad_xyz1=&grad_xyz1_flat(0); 124 | auto grad_xyz2_flat=grad_xyz2_tensor->flat(); 125 | float * grad_xyz2=&grad_xyz2_flat(0); 126 | for (int i=0;iinput(0); 174 | const Tensor& xyz2_tensor=context->input(1); 175 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)")); 176 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1")); 177 | int b=xyz1_tensor.shape().dim_size(0); 178 | int n=xyz1_tensor.shape().dim_size(1); 179 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)")); 180 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2")); 181 | int m=xyz2_tensor.shape().dim_size(1); 182 | 
OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size")); 183 | auto xyz1_flat=xyz1_tensor.flat(); 184 | const float * xyz1=&xyz1_flat(0); 185 | auto xyz2_flat=xyz2_tensor.flat(); 186 | const float * xyz2=&xyz2_flat(0); 187 | Tensor * dist1_tensor=NULL; 188 | Tensor * idx1_tensor=NULL; 189 | Tensor * dist2_tensor=NULL; 190 | Tensor * idx2_tensor=NULL; 191 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor)); 192 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor)); 193 | auto dist1_flat=dist1_tensor->flat(); 194 | auto idx1_flat=idx1_tensor->flat(); 195 | OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor)); 196 | OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor)); 197 | auto dist2_flat=dist2_tensor->flat(); 198 | auto idx2_flat=idx2_tensor->flat(); 199 | float * dist1=&(dist1_flat(0)); 200 | int * idx1=&(idx1_flat(0)); 201 | float * dist2=&(dist2_flat(0)); 202 | int * idx2=&(idx2_flat(0)); 203 | NmDistanceKernelLauncher(b,n,xyz1,m,xyz2,dist1,idx1,dist2,idx2); 204 | } 205 | }; 206 | REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_GPU), NnDistanceGpuOp); 207 | 208 | void NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2); 209 | class NnDistanceGradGpuOp : public OpKernel{ 210 | public: 211 | explicit NnDistanceGradGpuOp(OpKernelConstruction* context):OpKernel(context){} 212 | void Compute(OpKernelContext * context)override{ 213 | const Tensor& xyz1_tensor=context->input(0); 214 | const Tensor& xyz2_tensor=context->input(1); 215 | const Tensor& grad_dist1_tensor=context->input(2); 216 | const Tensor& idx1_tensor=context->input(3); 217 | const Tensor& grad_dist2_tensor=context->input(4); 218 | const Tensor& idx2_tensor=context->input(5); 219 | OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)")); 220 | OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1")); 221 | int b=xyz1_tensor.shape().dim_size(0); 222 | int n=xyz1_tensor.shape().dim_size(1); 223 | OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)")); 224 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2")); 225 | int m=xyz2_tensor.shape().dim_size(1); 226 | OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size")); 227 | OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)")); 228 | OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)")); 229 | OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)")); 230 | OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)")); 231 | auto xyz1_flat=xyz1_tensor.flat(); 232 | const float * 
xyz1=&xyz1_flat(0); 233 | auto xyz2_flat=xyz2_tensor.flat(); 234 | const float * xyz2=&xyz2_flat(0); 235 | auto idx1_flat=idx1_tensor.flat(); 236 | const int * idx1=&idx1_flat(0); 237 | auto idx2_flat=idx2_tensor.flat(); 238 | const int * idx2=&idx2_flat(0); 239 | auto grad_dist1_flat=grad_dist1_tensor.flat(); 240 | const float * grad_dist1=&grad_dist1_flat(0); 241 | auto grad_dist2_flat=grad_dist2_tensor.flat(); 242 | const float * grad_dist2=&grad_dist2_flat(0); 243 | Tensor * grad_xyz1_tensor=NULL; 244 | OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor)); 245 | Tensor * grad_xyz2_tensor=NULL; 246 | OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor)); 247 | auto grad_xyz1_flat=grad_xyz1_tensor->flat(); 248 | float * grad_xyz1=&grad_xyz1_flat(0); 249 | auto grad_xyz2_flat=grad_xyz2_tensor->flat(); 250 | float * grad_xyz2=&grad_xyz2_flat(0); 251 | NmDistanceGradKernelLauncher(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_dist2,idx2,grad_xyz1,grad_xyz2); 252 | } 253 | }; 254 | REGISTER_KERNEL_BUILDER(Name("NnDistanceGrad").Device(DEVICE_GPU), NnDistanceGradGpuOp); 255 | -------------------------------------------------------------------------------- /tf_ops/nn_distance/tf_nndistance.py: -------------------------------------------------------------------------------- 1 | """ Compute Chamfer's Distance. 2 | 3 | Original author: Haoqiang Fan. 4 | Modified by Charles R. Qi 5 | """ 6 | 7 | import tensorflow as tf 8 | from tensorflow.python.framework import ops 9 | import sys 10 | import os 11 | 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | nn_distance_module = tf.load_op_library(os.path.join(BASE_DIR, 'tf_nndistance_so.so')) 14 | 15 | 16 | def nn_distance(xyz1, xyz2): 17 | ''' 18 | Computes the distance of nearest neighbors for a pair of point clouds 19 | input: xyz1: (batch_size,#points_1,3) the first point cloud 20 | input: xyz2: (batch_size,#points_2,3) the second point cloud 21 | output: dist1: (batch_size,#point_1) distance from first to second 22 | output: idx1: (batch_size,#point_1) nearest neighbor from first to second 23 | output: dist2: (batch_size,#point_2) distance from second to first 24 | output: idx2: (batch_size,#point_2) nearest neighbor from second to first 25 | ''' 26 | return nn_distance_module.nn_distance(xyz1, xyz2) 27 | 28 | 29 | # @tf.RegisterShape('NnDistance') 30 | # def _nn_distance_shape(op): 31 | # shape1=op.inputs[0].get_shape().with_rank(3) 32 | # shape2=op.inputs[1].get_shape().with_rank(3) 33 | # return [tf.TensorShape([shape1.dims[0],shape1.dims[1]]),tf.TensorShape([shape1.dims[0],shape1.dims[1]]), 34 | # tf.TensorShape([shape2.dims[0],shape2.dims[1]]),tf.TensorShape([shape2.dims[0],shape2.dims[1]])] 35 | @ops.RegisterGradient('NnDistance') 36 | def _nn_distance_grad(op, grad_dist1, grad_idx1, grad_dist2, grad_idx2): 37 | xyz1 = op.inputs[0] 38 | xyz2 = op.inputs[1] 39 | idx1 = op.outputs[1] 40 | idx2 = op.outputs[3] 41 | return nn_distance_module.nn_distance_grad(xyz1, xyz2, grad_dist1, idx1, grad_dist2, idx2) 42 | 43 | 44 | if __name__ == '__main__': 45 | import numpy as np 46 | import random 47 | import time 48 | from tensorflow.python.ops.gradient_checker import compute_gradient 49 | 50 | random.seed(100) 51 | np.random.seed(100) 52 | with tf.Session('') as sess: 53 | xyz1 = np.random.randn(32, 16384, 3).astype('float32') 54 | xyz2 = np.random.randn(32, 1024, 3).astype('float32') 55 | # with tf.device('/gpu:0'): 56 | if True: 57 | inp1 = tf.Variable(xyz1) 58 | inp2 = 
tf.constant(xyz2) 59 | reta, retb, retc, retd = nn_distance(inp1, inp2) 60 | loss = tf.reduce_sum(reta) + tf.reduce_sum(retc) 61 | train = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss) 62 | sess.run(tf.initialize_all_variables()) 63 | t0 = time.time() 64 | t1 = t0 65 | best = 1e100 66 | for i in xrange(100): 67 | trainloss, _ = sess.run([loss, train]) 68 | newt = time.time() 69 | best = min(best, newt - t1) 70 | print i, trainloss, (newt - t0) / (i + 1), best 71 | t1 = newt 72 | # print sess.run([inp1,retb,inp2,retd]) 73 | # grads=compute_gradient([inp1,inp2],[(16,32,3),(16,32,3)],loss,(1,),[xyz1,xyz2]) 74 | # for i,j in grads: 75 | # print i.shape,j.shape,np.mean(np.abs(i-j)),np.mean(np.abs(i)),np.mean(np.abs(j)) 76 | # for i in xrange(10): 77 | # t0=time.time() 78 | # a,b,c,d=sess.run([reta,retb,retc,retd],feed_dict={inp1:xyz1,inp2:xyz2}) 79 | # print 'time',time.time()-t0 80 | # print a.shape,b.shape,c.shape,d.shape 81 | # print a.dtype,b.dtype,c.dtype,d.dtype 82 | # samples=np.array(random.sample(range(xyz2.shape[1]),100),dtype='int32') 83 | # dist1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).min(axis=-1) 84 | # idx1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) 85 | # print np.abs(dist1-a[:,samples]).max() 86 | # print np.abs(idx1-b[:,samples]).max() 87 | # dist2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).min(axis=-1) 88 | # idx2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) 89 | # print np.abs(dist2-c[:,samples]).max() 90 | # print np.abs(idx2-d[:,samples]).max() 91 | -------------------------------------------------------------------------------- /tf_ops/nn_distance/tf_nndistance_compile.sh: -------------------------------------------------------------------------------- 1 | /usr/local/cuda-8.0/bin/nvcc tf_nndistance_g.cu -o tf_nndistance_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC 2 | g++ -std=c++11 tf_nndistance.cpp tf_nndistance_g.cu.o -o tf_nndistance_so.so -shared -fPIC -I /home/nadav/anaconda2/envs/tensorflow/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -I /home/nadav/anaconda2/envs/tensorflow/lib/python2.7/dist-packages/tensorflow/include/external/nsync/public -lcudart -L /usr/local/cuda-8.0/lib64/ -L/home/nadav/anaconda2/envs/tensorflow/lib/python2.7/dist-packages/tensorflow -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0 3 | 4 | -------------------------------------------------------------------------------- /tf_ops/nn_distance/tf_nndistance_cpu.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def nn_distance_cpu(pc1, pc2): 5 | ''' 6 | Input: 7 | pc1: float TF tensor in shape (B,N,C) the first point cloud 8 | pc2: float TF tensor in shape (B,M,C) the second point cloud 9 | Output: 10 | dist1: float TF tensor in shape (B,N) distance from first to second 11 | idx1: int32 TF tensor in shape (B,N) nearest neighbor from first to second 12 | dist2: float TF tensor in shape (B,M) distance from second to first 13 | idx2: int32 TF tensor in shape (B,M) nearest neighbor from second to first 14 | ''' 15 | N = pc1.get_shape()[1].value 16 | M = pc2.get_shape()[1].value 17 | pc1_expand_tile = tf.tile(tf.expand_dims(pc1,2), [1,1,M,1]) 18 | pc2_expand_tile = tf.tile(tf.expand_dims(pc2,1), [1,N,1,1]) 19 | pc_diff = pc1_expand_tile - pc2_expand_tile # B,N,M,C 20 | pc_dist = tf.reduce_sum(pc_diff ** 2, axis=-1) # B,N,M 21 | dist1 = 
tf.reduce_min(pc_dist, axis=2) # B,N 22 | idx1 = tf.argmin(pc_dist, axis=2) # B,N 23 | dist2 = tf.reduce_min(pc_dist, axis=1) # B,M 24 | idx2 = tf.argmin(pc_dist, axis=1) # B,M 25 | return dist1, idx1, dist2, idx2 26 | 27 | 28 | def verify_nn_distance_cpu(): 29 | np.random.seed(0) 30 | sess = tf.Session() 31 | pc1arr = np.random.random((1,5,3)) 32 | pc2arr = np.random.random((1,6,3)) 33 | pc1 = tf.constant(pc1arr) 34 | pc2 = tf.constant(pc2arr) 35 | dist1, idx1, dist2, idx2 = nn_distance_cpu(pc1, pc2) 36 | print(sess.run(dist1)) 37 | print(sess.run(idx1)) 38 | print(sess.run(dist2)) 39 | print(sess.run(idx2)) 40 | 41 | dist = np.zeros((5,6)) 42 | for i in range(5): 43 | for j in range(6): 44 | dist[i,j] = np.sum((pc1arr[0,i,:] - pc2arr[0,j,:]) ** 2) 45 | print(dist) 46 | 47 | if __name__ == '__main__': 48 | verify_nn_distance_cpu() 49 | -------------------------------------------------------------------------------- /tf_ops/nn_distance/tf_nndistance_g.cu: -------------------------------------------------------------------------------- 1 | #if GOOGLE_CUDA 2 | #define EIGEN_USE_GPU 3 | //#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 4 | 5 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 6 | const int batch=512; 7 | __shared__ float buf[batch*3]; 8 | for (int i=blockIdx.x;i<b;i+=gridDim.x){ 119 | if (k==0 || result[(i*n+j)]>best){ 120 | result[(i*n+j)]=best; 121 | result_i[(i*n+j)]=best_i; 122 | } 123 | } 124 | __syncthreads(); 125 | } 126 | } 127 | } 128 | void NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i){ 129 | NmDistanceKernel<<<dim3(32,16,1),512>>>(b,n,xyz,m,xyz2,result,result_i); 130 | NmDistanceKernel<<<dim3(32,16,1),512>>>(b,m,xyz2,n,xyz,result2,result2_i); 131 | } 132 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 133 | for (int i=blockIdx.x;i<b;i+=gridDim.x){ 154 | void NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2){ 155 | NmDistanceGradKernel<<<dim3(1,16,1),256>>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 156 | NmDistanceGradKernel<<<dim3(1,16,1),256>>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 157 | } 158 | 159 | #endif 160 | 161 | -------------------------------------------------------------------------------- /tf_ops/nn_distance/tf_nndistance_g.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/tf_ops/nn_distance/tf_nndistance_g.cu.o -------------------------------------------------------------------------------- /tf_ops/nn_distance/tf_nndistance_so.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/tf_ops/nn_distance/tf_nndistance_so.so -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import numpy as np 4 | import tensorflow as tf 5 | import os 6 | import sys 7 | import models 8 | from data import data_utils 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--data_path', default='datasets/shapenetcore_partanno_segmentation_benchmark_v0', 12 | help='Path to the dataset [default: datasets/shapenetcore_partanno_segmentation_benchmark_v0]') 13 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]') 14
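
# As a quick aside on the nn_distance op defined above: its four outputs are
# usually reduced to a single, symmetric Chamfer loss. The sketch below is
# illustrative only -- the reduction CompoNet actually optimizes is built in
# models.py, and the helper name chamfer_loss is ours, not part of the repo.
import tensorflow as tf
from tf_ops.nn_distance.tf_nndistance import nn_distance

def chamfer_loss(pred, gt):
    # pred, gt: (batch_size, num_points, 3) float32 point clouds.
    dist1, _, dist2, _ = nn_distance(pred, gt)  # squared NN distances, both directions
    return tf.reduce_mean(dist1) + tf.reduce_mean(dist2)
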
| parser.add_argument('--category', default='Chair', help='Which single class to train on [default: Chair]') 15 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]') 16 | parser.add_argument('--part_embedding_dim', type=int, default=64, help='Embedding dimension of each part [default: 64]') 17 | parser.add_argument('--noise_embedding_dim', type=int, default=16, 18 | help='Embedding dimension of the noise [default: 16]') 19 | parser.add_argument('--num_point', type=int, default=400, help='Number of points per part [default: 400]') 20 | parser.add_argument('--max_epoch_ae', type=int, default=401, help='Number of epochs for each AEs [default: 401]') 21 | parser.add_argument('--max_epoch_pcn', type=int, default=201, 22 | help='Number of epochs for the parts composition network [default: 201]') 23 | parser.add_argument('--batch_size', type=int, default=64, help='Batch Size during training [default: 64]') 24 | parser.add_argument('--learning_rate', type=float, default=0.001, 25 | help='Initial learning rate for the composition network [default: 0.001]') 26 | parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]') 27 | parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]') 28 | FLAGS = parser.parse_args() 29 | GPU_INDEX = FLAGS.gpu 30 | CATEGORY = FLAGS.category 31 | LOG_DIR = FLAGS.log_dir 32 | PART_EMBEDDING_DIM = FLAGS.part_embedding_dim 33 | NOISE_EMBEDDING_DIM = FLAGS.noise_embedding_dim 34 | NUM_POINTS = FLAGS.num_point 35 | MAX_EPOCH_AE = FLAGS.max_epoch_ae 36 | MAX_EPOCH_PCN = FLAGS.max_epoch_pcn 37 | BATCH_SIZE = FLAGS.batch_size 38 | BASE_LEARNING_RATE = FLAGS.learning_rate 39 | DECAY_STEP = FLAGS.decay_step 40 | DECAY_RATE = FLAGS.decay_rate 41 | 42 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 43 | if not os.path.exists(LOG_DIR): 44 | os.mkdir(LOG_DIR) 45 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w') 46 | LOG_FOUT.write(str(FLAGS) + '\n') 47 | 48 | print 'Loading data' 49 | # Shapenet official train/test split 50 | DATA_PATH = os.path.join(ROOT_DIR, FLAGS.data_path) 51 | # Using the smaller set, test, as seen and the larger, trainval, as unseen 52 | PCN_TRAIN_DATASET, PCN_TEST_DATASET, AE_TRAIN_DATASET, AE_TEST_DATASET, NUM_PARTS = data_utils.load_data(DATA_PATH, 53 | NUM_POINTS, 54 | CATEGORY, 55 | 'test', 56 | 'trainval') 57 | 58 | NOISE = np.random.normal(size=[len(PCN_TRAIN_DATASET), NOISE_EMBEDDING_DIM]) 59 | EPOCH_CNT = 0 60 | 61 | 62 | def log_string(out_str): 63 | LOG_FOUT.write(out_str + '\n') 64 | LOG_FOUT.flush() 65 | print(out_str) 66 | 67 | 68 | def get_ae_batch(dataset, idxs, start_idx, end_idx): 69 | batch_size = end_idx - start_idx 70 | batch_data = np.zeros((batch_size, NUM_POINTS, 3)) 71 | for i in range(batch_size): 72 | ps = dataset[idxs[i + start_idx]] 73 | batch_data[i, ...] 
= ps 74 | return batch_data 75 | 76 | 77 | def train_ae_one_epoch(sess, ops_ae, train_dataset, train_writer): 78 | is_training = True 79 | log_string(str(datetime.now())) 80 | 81 | # Shuffle train samples 82 | train_idxs = np.arange(0, len(train_dataset)) 83 | np.random.shuffle(train_idxs) 84 | num_batches = len(train_dataset) / BATCH_SIZE 85 | 86 | loss_sum = 0 87 | for batch_idx in range(num_batches): 88 | start_idx = batch_idx * BATCH_SIZE 89 | end_idx = (batch_idx + 1) * BATCH_SIZE 90 | batch_data = get_ae_batch(train_dataset, train_idxs, start_idx, end_idx) 91 | feed_dict = {ops_ae['point_clouds_ph']: batch_data, 92 | ops_ae['gt_ph']: batch_data, 93 | ops_ae['is_training_ph']: is_training, } 94 | summary, step, _, loss_val, pred_val = sess.run([ops_ae['merged_summary'], ops_ae['step'], 95 | ops_ae['train_op'], ops_ae['loss'], 96 | ops_ae['pred']], feed_dict=feed_dict) 97 | train_writer.add_summary(summary, step) 98 | loss_sum += loss_val 99 | 100 | if (batch_idx + 1) % 10 == 0: 101 | log_string(' -- %03d / %03d --' % (batch_idx + 1, num_batches)) 102 | log_string('mean loss: %f' % (loss_sum / 10)) 103 | loss_sum = 0 104 | 105 | 106 | def get_pcn_batch(dataset, idxs, start_idx, end_idx): 107 | batch_size = end_idx - start_idx 108 | x = np.zeros((NUM_PARTS, batch_size, NUM_POINTS, 3)) 109 | y = np.zeros((batch_size, NUM_PARTS, NUM_POINTS, 3)) 110 | y_mask = np.zeros((batch_size, NUM_PARTS, NUM_POINTS)) 111 | noise = NOISE[idxs[start_idx:end_idx]] 112 | 113 | for i in range(batch_size): 114 | point_sets = dataset[idxs[start_idx + i]] 115 | for p in xrange(NUM_PARTS): 116 | ps, sn, is_full = point_sets[p] 117 | x[p, i, ...] = sn 118 | y[i, p, ...] = ps 119 | y_mask[i, p] = is_full 120 | return x, y, y_mask, noise 121 | 122 | 123 | def train_pcn_one_epoch(sess, pcn_ops, ae_ops, train_dataset, train_writer): 124 | is_training = True 125 | log_string(str(datetime.now())) 126 | 127 | # Shuffle train samples 128 | train_idxs = np.arange(0, len(train_dataset)) 129 | np.random.shuffle(train_idxs) 130 | num_batches = len(train_dataset) / BATCH_SIZE 131 | 132 | loss_sum = 0 133 | total_loss = 0 134 | for batch_idx in range(num_batches): 135 | start_idx = batch_idx * BATCH_SIZE 136 | end_idx = (batch_idx + 1) * BATCH_SIZE 137 | x, y, y_mask, noise = get_pcn_batch(train_dataset, train_idxs, start_idx, end_idx) 138 | feed_dict = {pcn_ops['y']: y, 139 | pcn_ops['y_mask']: y_mask, 140 | pcn_ops['noise']: noise, 141 | pcn_ops['is_training_ph']: is_training, } 142 | 143 | for part in xrange(len(x)): 144 | feed_dict[ae_ops[part]['point_clouds_ph']] = x[part] 145 | feed_dict[ae_ops[part]['gt_ph']] = x[part] 146 | feed_dict[ae_ops[part]['is_training_ph']] = False # encoders are set in the PCN training phase 147 | 148 | summary, step, _, loss_val, pred_val = sess.run( 149 | [pcn_ops['merged_summary'], pcn_ops['step'], pcn_ops['train_op'], pcn_ops['loss'], pcn_ops['pred']], 150 | feed_dict=feed_dict) 151 | 152 | train_writer.add_summary(summary, step) 153 | loss_sum += loss_val 154 | total_loss += loss_val 155 | 156 | if (batch_idx + 1) % 10 == 0: 157 | log_string(' -- %03d / %03d --' % (batch_idx + 1, num_batches)) 158 | log_string('mean loss: %f' % (loss_sum / 10)) 159 | loss_sum = 0 160 | 161 | return total_loss / float(num_batches) 162 | 163 | 164 | def train(): 165 | with tf.Graph().as_default(): 166 | with tf.device('/gpu:' + str(GPU_INDEX)): 167 | ''' Parts' AE ''' 168 | ae_ops, point_clouds_ph = models.build_parts_aes_graphs(NUM_PARTS, NUM_POINTS, PART_EMBEDDING_DIM, 169 | BASE_LEARNING_RATE, 
BATCH_SIZE, DECAY_STEP, 170 | DECAY_RATE, float(DECAY_STEP)) 171 | 172 | ''' PCN ''' 173 | pcn_ops = models.build_parts_pcn_graph(ae_ops, point_clouds_ph, NUM_PARTS, NUM_POINTS, NOISE_EMBEDDING_DIM, 174 | BASE_LEARNING_RATE, BATCH_SIZE, DECAY_STEP, DECAY_RATE, 175 | float(DECAY_STEP)) 176 | 177 | saver = tf.train.Saver() 178 | 179 | config = tf.ConfigProto() 180 | config.gpu_options.allow_growth = True 181 | config.allow_soft_placement = True 182 | config.log_device_placement = False 183 | sess = tf.Session(config=config) 184 | 185 | train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), sess.graph) 186 | 187 | init = tf.global_variables_initializer() 188 | sess.run(init) 189 | 190 | ''' Train Parts AE ''' 191 | for i in xrange(NUM_PARTS): 192 | print 'Training part ' + str(i) 193 | for epoch in range(MAX_EPOCH_AE): 194 | log_string('**** EPOCH %03d ****' % epoch) 195 | sys.stdout.flush() 196 | train_ae_one_epoch(sess, ae_ops[i], AE_TRAIN_DATASET[i], train_writer) 197 | 198 | ''' Train PCN ''' 199 | best_loss = 1e20 200 | for epoch in range(MAX_EPOCH_PCN): 201 | log_string('**** EPOCH %03d ****' % epoch) 202 | sys.stdout.flush() 203 | train_loss = train_pcn_one_epoch(sess, pcn_ops, ae_ops, PCN_TRAIN_DATASET, train_writer) 204 | if train_loss < best_loss: 205 | best_loss = train_loss 206 | save_path = saver.save(sess, os.path.join(LOG_DIR, "best_model_epoch_%03d.ckpt" % epoch)) 207 | log_string("Model saved in file: %s" % save_path) 208 | if epoch % 10 == 0: 209 | save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt")) 210 | log_string("Model saved in file: %s" % save_path) 211 | 212 | 213 | if __name__ == "__main__": 214 | log_string('pid: %s' % (str(os.getpid()))) 215 | train() 216 | LOG_FOUT.close() 217 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/utils/__init__.py -------------------------------------------------------------------------------- /utils/compile_render_balls_so.sh: -------------------------------------------------------------------------------- 1 | g++ -std=c++11 render_balls_so.cpp -o render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 2 | 3 | -------------------------------------------------------------------------------- /utils/render_balls_so.cpp: -------------------------------------------------------------------------------- 1 | #include <cstdio> 2 | #include <vector> 3 | #include <algorithm> 4 | #include <math.h> 5 | using namespace std; 6 | 7 | struct PointInfo{ 8 | int x,y,z; 9 | float r,g,b; 10 | }; 11 | 12 | extern "C"{ 13 | 14 | void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){ 15 | r=max(r,1); 16 | vector<int> depth(h*w,-2100000000); 17 | vector<PointInfo> pattern; 18 | for (int dx=-r;dx<=r;dx++) 19 | for (int dy=-r;dy<=r;dy++) 20 | if (dx*dx+dy*dy=h || y2<0 || y2>=w) && depth[x2*w+y2] -------------------------------------------------------------------------------- /utils/render_balls_so.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nschor/CompoNet/816da713d36d96d715a9c026c8ac4ae568a24780/utils/render_balls_so.so -------------------------------------------------------------------------------- /utils/show3d_balls.py: -------------------------------------------------------------------------------- 94 | if magnifyBlue > 0: 95 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=0)) 96 | if magnifyBlue >= 2: 97 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=0)) 98 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=1)) 99 | if magnifyBlue >= 2: 100 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=1)) 101 | if showrot: 102 | cv2.putText(show, 'xangle %d' % (int(xangle / np.pi * 180)), (30, showsz - 30), 0, 0.5, 103 |
cv2.cv.CV_RGB(255, 0, 0)) 104 | cv2.putText(show, 'yangle %d' % (int(yangle / np.pi * 180)), (30, showsz - 50), 0, 0.5, 105 | cv2.cv.CV_RGB(255, 0, 0)) 106 | cv2.putText(show, 'zoom %d%%' % (int(zoom * 100)), (30, showsz - 70), 0, 0.5, cv2.cv.CV_RGB(255, 0, 0)) 107 | 108 | changed = True 109 | stop = False 110 | while not stop: 111 | if changed: 112 | render() 113 | changed = False 114 | cv2.imshow('show3d', show) 115 | if waittime == 0: 116 | cmd = cv2.waitKey(10) % 256 117 | else: 118 | cmd = cv2.waitKey(waittime) % 256 119 | if cmd == ord('q'): 120 | stop = True 121 | elif cmd == ord('Q'): 122 | sys.exit(0) 123 | 124 | if cmd == ord('t') or cmd == ord('p'): 125 | if cmd == ord('t'): 126 | if c_gt is None: 127 | c0 = np.zeros((len(xyz),), dtype='float32') + 255 128 | c1 = np.zeros((len(xyz),), dtype='float32') + 255 129 | c2 = np.zeros((len(xyz),), dtype='float32') + 255 130 | else: 131 | c0 = c_gt[:, 0] 132 | c1 = c_gt[:, 1] 133 | c2 = c_gt[:, 2] 134 | else: 135 | if c_pred is None: 136 | c0 = np.zeros((len(xyz),), dtype='float32') + 255 137 | c1 = np.zeros((len(xyz),), dtype='float32') + 255 138 | c2 = np.zeros((len(xyz),), dtype='float32') + 255 139 | else: 140 | c0 = c_pred[:, 0] 141 | c1 = c_pred[:, 1] 142 | c2 = c_pred[:, 2] 143 | if normalizecolor: 144 | c0 /= (c0.max() + 1e-14) / 255.0 145 | c1 /= (c1.max() + 1e-14) / 255.0 146 | c2 /= (c2.max() + 1e-14) / 255.0 147 | c0 = np.require(c0, 'float32', 'C') 148 | c1 = np.require(c1, 'float32', 'C') 149 | c2 = np.require(c2, 'float32', 'C') 150 | changed = True 151 | 152 | if cmd == ord('n'): 153 | zoom *= 1.1 154 | changed = True 155 | elif cmd == ord('m'): 156 | zoom /= 1.1 157 | changed = True 158 | elif cmd == ord('r'): 159 | zoom = 1.0 160 | changed = True 161 | elif cmd == ord('s'): 162 | cv2.imwrite('show3d.png', show) 163 | elif cmd == ord('f'): 164 | freezerot = ~freezerot 165 | if waittime != 0: 166 | break 167 | return cmd 168 | 169 | 170 | if __name__ == '__main__': 171 | np.random.seed(100) 172 | showpoints(np.random.randn(2500, 3)) 173 | -------------------------------------------------------------------------------- /utils/tf_util.py: -------------------------------------------------------------------------------- 1 | """ Wrapper functions for TensorFlow layers. 2 | 3 | Author: Charles R. Qi 4 | Date: November 2017 5 | """ 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | 11 | def _variable_on_cpu(name, shape, initializer, use_fp16=False): 12 | """Helper to create a Variable stored on CPU memory. 13 | Args: 14 | name: name of the variable 15 | shape: list of ints 16 | initializer: initializer for Variable 17 | Returns: 18 | Variable Tensor 19 | """ 20 | with tf.device("/cpu:0"): 21 | dtype = tf.float16 if use_fp16 else tf.float32 22 | var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype) 23 | return var 24 | 25 | 26 | def _variable_with_weight_decay(name, shape, stddev, wd, use_xavier=True, initializer=None): 27 | """Helper to create an initialized Variable with weight decay. 28 | 29 | Note that the Variable is initialized with a truncated normal distribution. 30 | A weight decay is added only if one is specified. 31 | 32 | Args: 33 | name: name of the variable 34 | shape: list of ints 35 | stddev: standard deviation of a truncated Gaussian 36 | wd: add L2Loss weight decay multiplied by this float. If None, weight 37 | decay is not added for this Variable. 
38 | use_xavier: bool, whether to use xavier initializer 39 | 40 | Returns: 41 | Variable Tensor 42 | """ 43 | if initializer is None: 44 | if use_xavier: 45 | initializer = tf.contrib.layers.xavier_initializer() 46 | else: 47 | initializer = tf.truncated_normal_initializer(stddev=stddev) 48 | var = _variable_on_cpu(name, shape, initializer) 49 | if wd is not None: 50 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') 51 | tf.add_to_collection('losses', weight_decay) 52 | return var 53 | 54 | 55 | def conv1d(inputs, 56 | num_output_channels, 57 | kernel_size, 58 | scope, 59 | stride=1, 60 | padding='SAME', 61 | data_format='NHWC', 62 | use_xavier=True, 63 | stddev=1e-3, 64 | weight_decay=None, 65 | activation_fn=tf.nn.relu, 66 | bn=False, 67 | bn_decay=None, 68 | is_training=None): 69 | """ 1D convolution with non-linear operation. 70 | 71 | Args: 72 | inputs: 3-D tensor variable BxLxC 73 | num_output_channels: int 74 | kernel_size: int 75 | scope: string 76 | stride: int 77 | padding: 'SAME' or 'VALID' 78 | data_format: 'NHWC' or 'NCHW' 79 | use_xavier: bool, use xavier_initializer if true 80 | stddev: float, stddev for truncated_normal init 81 | weight_decay: float 82 | activation_fn: function 83 | bn: bool, whether to use batch norm 84 | bn_decay: float or float tensor variable in [0,1] 85 | is_training: bool Tensor variable 86 | 87 | Returns: 88 | Variable tensor 89 | """ 90 | with tf.variable_scope(scope) as sc: 91 | assert (data_format == 'NHWC' or data_format == 'NCHW') 92 | if data_format == 'NHWC': 93 | num_in_channels = inputs.get_shape()[-1].value 94 | elif data_format == 'NCHW': 95 | num_in_channels = inputs.get_shape()[1].value 96 | kernel_shape = [kernel_size, 97 | num_in_channels, num_output_channels] 98 | kernel = _variable_with_weight_decay('weights', 99 | shape=kernel_shape, 100 | use_xavier=use_xavier, 101 | stddev=stddev, 102 | wd=weight_decay) 103 | outputs = tf.nn.conv1d(inputs, kernel, 104 | stride=stride, 105 | padding=padding, 106 | data_format=data_format) 107 | biases = _variable_on_cpu('biases', [num_output_channels], 108 | tf.constant_initializer(0.0)) 109 | outputs = tf.nn.bias_add(outputs, biases, data_format=data_format) 110 | 111 | if bn: 112 | outputs = batch_norm_for_conv1d(outputs, is_training, 113 | bn_decay=bn_decay, scope='bn', 114 | data_format=data_format) 115 | 116 | if activation_fn is not None: 117 | outputs = activation_fn(outputs) 118 | return outputs 119 | 120 | 121 | def conv2d(inputs, 122 | num_output_channels, 123 | kernel_size, 124 | scope, 125 | stride=[1, 1], 126 | padding='SAME', 127 | data_format='NHWC', 128 | use_xavier=True, 129 | stddev=1e-3, 130 | weight_decay=None, 131 | activation_fn=tf.nn.relu, 132 | bn=False, 133 | bn_decay=None, 134 | is_training=None): 135 | """ 2D convolution with non-linear operation. 
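    Example (an illustrative sketch, not a CompoNet default; the shapes follow
    the common PointNet idiom of expanding a B x N x 3 cloud to B x N x 3 x 1):

        points = tf.placeholder(tf.float32, [32, 400, 3])
        image = tf.expand_dims(points, -1)  # 32 x 400 x 3 x 1
        net = conv2d(image, 64, [1, 3], scope='conv1', padding='VALID',
                     bn=True, bn_decay=0.9, is_training=tf.constant(True))
        # net has shape 32 x 400 x 1 x 64: one 64-d feature per point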
136 | 137 | Args: 138 | inputs: 4-D tensor variable BxHxWxC 139 | num_output_channels: int 140 | kernel_size: a list of 2 ints 141 | scope: string 142 | stride: a list of 2 ints 143 | padding: 'SAME' or 'VALID' 144 | data_format: 'NHWC' or 'NCHW' 145 | use_xavier: bool, use xavier_initializer if true 146 | stddev: float, stddev for truncated_normal init 147 | weight_decay: float 148 | activation_fn: function 149 | bn: bool, whether to use batch norm 150 | bn_decay: float or float tensor variable in [0,1] 151 | is_training: bool Tensor variable 152 | 153 | Returns: 154 | Variable tensor 155 | """ 156 | with tf.variable_scope(scope) as sc: 157 | kernel_h, kernel_w = kernel_size 158 | assert (data_format == 'NHWC' or data_format == 'NCHW') 159 | if data_format == 'NHWC': 160 | num_in_channels = inputs.get_shape()[-1].value 161 | elif data_format == 'NCHW': 162 | num_in_channels = inputs.get_shape()[1].value 163 | kernel_shape = [kernel_h, kernel_w, 164 | num_in_channels, num_output_channels] 165 | kernel = _variable_with_weight_decay('weights', 166 | shape=kernel_shape, 167 | use_xavier=use_xavier, 168 | stddev=stddev, 169 | wd=weight_decay) 170 | stride_h, stride_w = stride 171 | outputs = tf.nn.conv2d(inputs, kernel, 172 | [1, stride_h, stride_w, 1], 173 | padding=padding, 174 | data_format=data_format) 175 | biases = _variable_on_cpu('biases', [num_output_channels], 176 | tf.constant_initializer(0.0)) 177 | outputs = tf.nn.bias_add(outputs, biases, data_format=data_format) 178 | 179 | if bn: 180 | outputs = batch_norm_for_conv2d(outputs, is_training, 181 | bn_decay=bn_decay, scope='bn', 182 | data_format=data_format) 183 | 184 | if activation_fn is not None: 185 | outputs = activation_fn(outputs) 186 | return outputs 187 | 188 | 189 | def conv2d_transpose(inputs, 190 | num_output_channels, 191 | kernel_size, 192 | scope, 193 | stride=[1, 1], 194 | padding='SAME', 195 | data_format='NHWC', 196 | use_xavier=True, 197 | stddev=1e-3, 198 | weight_decay=None, 199 | activation_fn=tf.nn.relu, 200 | bn=False, 201 | bn_decay=None, 202 | is_training=None): 203 | """ 2D convolution transpose with non-linear operation. 
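    Example (an illustrative sketch; sizes are assumptions). With 'SAME'
    padding each spatial dimension is scaled by its stride, inverting a
    matching conv2d:

        x = tf.placeholder(tf.float32, [8, 16, 16, 128])
        up = conv2d_transpose(x, 64, [4, 4], scope='deconv1', stride=[2, 2])
        # up has shape 8 x 32 x 32 x 64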
204 | 205 | Args: 206 | inputs: 4-D tensor variable BxHxWxC 207 | num_output_channels: int 208 | kernel_size: a list of 2 ints 209 | scope: string 210 | stride: a list of 2 ints 211 | padding: 'SAME' or 'VALID' 212 | use_xavier: bool, use xavier_initializer if true 213 | stddev: float, stddev for truncated_normal init 214 | weight_decay: float 215 | activation_fn: function 216 | bn: bool, whether to use batch norm 217 | bn_decay: float or float tensor variable in [0,1] 218 | is_training: bool Tensor variable 219 | 220 | Returns: 221 | Variable tensor 222 | 223 | Note: conv2d(conv2d_transpose(a, num_out, ksize, stride), a.shape[-1], ksize, stride) == a 224 | """ 225 | with tf.variable_scope(scope) as sc: 226 | kernel_h, kernel_w = kernel_size 227 | num_in_channels = inputs.get_shape()[-1].value 228 | kernel_shape = [kernel_h, kernel_w, 229 | num_output_channels, num_in_channels] # reversed to conv2d 230 | kernel = _variable_with_weight_decay('weights', 231 | shape=kernel_shape, 232 | use_xavier=use_xavier, 233 | stddev=stddev, 234 | wd=weight_decay) 235 | stride_h, stride_w = stride 236 | 237 | # from slim.convolution2d_transpose 238 | def get_deconv_dim(dim_size, stride_size, kernel_size, padding): 239 | dim_size *= stride_size 240 | 241 | if padding == 'VALID' and dim_size is not None: 242 | dim_size += max(kernel_size - stride_size, 0) 243 | return dim_size 244 | 245 | # caculate output shape 246 | batch_size = inputs.get_shape()[0].value 247 | height = inputs.get_shape()[1].value 248 | width = inputs.get_shape()[2].value 249 | out_height = get_deconv_dim(height, stride_h, kernel_h, padding) 250 | out_width = get_deconv_dim(width, stride_w, kernel_w, padding) 251 | output_shape = [batch_size, out_height, out_width, num_output_channels] 252 | 253 | outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape, 254 | [1, stride_h, stride_w, 1], 255 | padding=padding) 256 | biases = _variable_on_cpu('biases', [num_output_channels], 257 | tf.constant_initializer(0.0)) 258 | outputs = tf.nn.bias_add(outputs, biases) 259 | 260 | if bn: 261 | outputs = batch_norm_for_conv2d(outputs, is_training, 262 | bn_decay=bn_decay, scope='bn', 263 | data_format=data_format) 264 | 265 | if activation_fn is not None: 266 | outputs = activation_fn(outputs) 267 | return outputs 268 | 269 | 270 | def conv3d(inputs, 271 | num_output_channels, 272 | kernel_size, 273 | scope, 274 | stride=[1, 1, 1], 275 | padding='SAME', 276 | use_xavier=True, 277 | stddev=1e-3, 278 | weight_decay=None, 279 | activation_fn=tf.nn.relu, 280 | bn=False, 281 | bn_decay=None, 282 | is_training=None): 283 | """ 3D convolution with non-linear operation. 
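    Example (an illustrative sketch on a voxel grid; sizes are assumptions):

        vox = tf.placeholder(tf.float32, [4, 32, 32, 32, 1])
        net = conv3d(vox, 16, [3, 3, 3], scope='conv3d_1')
        # net has shape 4 x 32 x 32 x 32 x 16 ('SAME' padding, stride 1)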
284 | 285 | Args: 286 | inputs: 5-D tensor variable BxDxHxWxC 287 | num_output_channels: int 288 | kernel_size: a list of 3 ints 289 | scope: string 290 | stride: a list of 3 ints 291 | padding: 'SAME' or 'VALID' 292 | use_xavier: bool, use xavier_initializer if true 293 | stddev: float, stddev for truncated_normal init 294 | weight_decay: float 295 | activation_fn: function 296 | bn: bool, whether to use batch norm 297 | bn_decay: float or float tensor variable in [0,1] 298 | is_training: bool Tensor variable 299 | 300 | Returns: 301 | Variable tensor 302 | """ 303 | with tf.variable_scope(scope) as sc: 304 | kernel_d, kernel_h, kernel_w = kernel_size 305 | num_in_channels = inputs.get_shape()[-1].value 306 | kernel_shape = [kernel_d, kernel_h, kernel_w, 307 | num_in_channels, num_output_channels] 308 | kernel = _variable_with_weight_decay('weights', 309 | shape=kernel_shape, 310 | use_xavier=use_xavier, 311 | stddev=stddev, 312 | wd=weight_decay) 313 | stride_d, stride_h, stride_w = stride 314 | outputs = tf.nn.conv3d(inputs, kernel, 315 | [1, stride_d, stride_h, stride_w, 1], 316 | padding=padding) 317 | biases = _variable_on_cpu('biases', [num_output_channels], 318 | tf.constant_initializer(0.0)) 319 | outputs = tf.nn.bias_add(outputs, biases) 320 | 321 | if bn: 322 | outputs = batch_norm_for_conv3d(outputs, is_training, 323 | bn_decay=bn_decay, scope='bn') 324 | 325 | if activation_fn is not None: 326 | outputs = activation_fn(outputs) 327 | return outputs 328 | 329 | 330 | def fully_connected(inputs, 331 | num_outputs, 332 | scope, 333 | use_xavier=True, 334 | stddev=1e-3, 335 | weight_decay=None, 336 | activation_fn=tf.nn.relu, 337 | bn=False, 338 | bn_decay=None, 339 | is_training=None, 340 | reuse=False, 341 | weights_initializer=None, 342 | biases_initializer=tf.constant_initializer(0.0)): 343 | """ Fully connected layer with non-linear operation. 344 | 345 | Args: 346 | inputs: 2-D tensor BxN 347 | num_outputs: int 348 | 349 | Returns: 350 | Variable tensor of size B x num_outputs. 351 | """ 352 | with tf.variable_scope(scope, reuse=reuse) as sc: 353 | num_input_units = inputs.get_shape()[-1].value 354 | weights = _variable_with_weight_decay('weights', 355 | shape=[num_input_units, num_outputs], 356 | use_xavier=use_xavier, 357 | stddev=stddev, 358 | wd=weight_decay, 359 | initializer=weights_initializer) 360 | outputs = tf.matmul(inputs, weights) 361 | biases = _variable_on_cpu('biases', [num_outputs], biases_initializer) 362 | outputs = tf.nn.bias_add(outputs, biases) 363 | 364 | if bn: 365 | outputs = batch_norm_for_fc(outputs, is_training, bn_decay, 'bn') 366 | 367 | if activation_fn is not None: 368 | outputs = activation_fn(outputs) 369 | return outputs 370 | 371 | 372 | def max_pool2d(inputs, 373 | kernel_size, 374 | scope, 375 | stride=[2, 2], 376 | padding='VALID'): 377 | """ 2D max pooling. 378 | 379 | Args: 380 | inputs: 4-D tensor BxHxWxC 381 | kernel_size: a list of 2 ints 382 | stride: a list of 2 ints 383 | 384 | Returns: 385 | Variable tensor 386 | """ 387 | with tf.variable_scope(scope) as sc: 388 | kernel_h, kernel_w = kernel_size 389 | stride_h, stride_w = stride 390 | outputs = tf.nn.max_pool(inputs, 391 | ksize=[1, kernel_h, kernel_w, 1], 392 | strides=[1, stride_h, stride_w, 1], 393 | padding=padding, 394 | name=sc.name) 395 | return outputs 396 | 397 | 398 | def avg_pool2d(inputs, 399 | kernel_size, 400 | scope, 401 | stride=[2, 2], 402 | padding='VALID'): 403 | """ 2D avg pooling. 
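    Example (an illustrative sketch; sizes are assumptions). A kernel that
    spans the full point axis produces one global feature vector per shape:

        x = tf.placeholder(tf.float32, [8, 400, 1, 64])
        pooled = avg_pool2d(x, [400, 1], scope='avgpool')
        # pooled has shape 8 x 1 x 1 x 64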
404 | 405 | Args: 406 | inputs: 4-D tensor BxHxWxC 407 | kernel_size: a list of 2 ints 408 | stride: a list of 2 ints 409 | 410 | Returns: 411 | Variable tensor 412 | """ 413 | with tf.variable_scope(scope) as sc: 414 | kernel_h, kernel_w = kernel_size 415 | stride_h, stride_w = stride 416 | outputs = tf.nn.avg_pool(inputs, 417 | ksize=[1, kernel_h, kernel_w, 1], 418 | strides=[1, stride_h, stride_w, 1], 419 | padding=padding, 420 | name=sc.name) 421 | return outputs 422 | 423 | 424 | def max_pool3d(inputs, 425 | kernel_size, 426 | scope, 427 | stride=[2, 2, 2], 428 | padding='VALID'): 429 | """ 3D max pooling. 430 | 431 | Args: 432 | inputs: 5-D tensor BxDxHxWxC 433 | kernel_size: a list of 3 ints 434 | stride: a list of 3 ints 435 | 436 | Returns: 437 | Variable tensor 438 | """ 439 | with tf.variable_scope(scope) as sc: 440 | kernel_d, kernel_h, kernel_w = kernel_size 441 | stride_d, stride_h, stride_w = stride 442 | outputs = tf.nn.max_pool3d(inputs, 443 | ksize=[1, kernel_d, kernel_h, kernel_w, 1], 444 | strides=[1, stride_d, stride_h, stride_w, 1], 445 | padding=padding, 446 | name=sc.name) 447 | return outputs 448 | 449 | 450 | def avg_pool3d(inputs, 451 | kernel_size, 452 | scope, 453 | stride=[2, 2, 2], 454 | padding='VALID'): 455 | """ 3D avg pooling. 456 | 457 | Args: 458 | inputs: 5-D tensor BxDxHxWxC 459 | kernel_size: a list of 3 ints 460 | stride: a list of 3 ints 461 | 462 | Returns: 463 | Variable tensor 464 | """ 465 | with tf.variable_scope(scope) as sc: 466 | kernel_d, kernel_h, kernel_w = kernel_size 467 | stride_d, stride_h, stride_w = stride 468 | outputs = tf.nn.avg_pool3d(inputs, 469 | ksize=[1, kernel_d, kernel_h, kernel_w, 1], 470 | strides=[1, stride_d, stride_h, stride_w, 1], 471 | padding=padding, 472 | name=sc.name) 473 | return outputs 474 | 475 | 476 | def batch_norm_template_unused(inputs, is_training, scope, moments_dims, bn_decay): 477 | """ NOTE: this is older version of the util func. it is deprecated. 478 | Batch normalization on convolutional maps and beyond... 479 | Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow 480 | 481 | Args: 482 | inputs: Tensor, k-D input ... x C could be BC or BHWC or BDHWC 483 | is_training: boolean tf.Varialbe, true indicates training phase 484 | scope: string, variable scope 485 | moments_dims: a list of ints, indicating dimensions for moments calculation 486 | bn_decay: float or float tensor variable, controling moving average weight 487 | Return: 488 | normed: batch-normalized maps 489 | """ 490 | with tf.variable_scope(scope) as sc: 491 | num_channels = inputs.get_shape()[-1].value 492 | beta = _variable_on_cpu(name='beta', shape=[num_channels], 493 | initializer=tf.constant_initializer(0)) 494 | gamma = _variable_on_cpu(name='gamma', shape=[num_channels], 495 | initializer=tf.constant_initializer(1.0)) 496 | batch_mean, batch_var = tf.nn.moments(inputs, moments_dims, name='moments') 497 | decay = bn_decay if bn_decay is not None else 0.9 498 | ema = tf.train.ExponentialMovingAverage(decay=decay) 499 | # Operator that maintains moving averages of variables. 
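        # For reference: ema.apply([v]) keeps a shadow copy of each variable v
        # and refreshes it on every training step as
        #     shadow = decay * shadow + (1 - decay) * v
        # so ema.average(batch_mean) / ema.average(batch_var) below return the
        # smoothed population statistics used at test time.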
500 | # Need to set reuse=False, otherwise if reuse, will see moments_1/mean/ExponentialMovingAverage/ does not exist 501 | # https://github.com/shekkizh/WassersteinGAN.tensorflow/issues/3 502 | with tf.variable_scope(tf.get_variable_scope(), reuse=False): 503 | ema_apply_op = tf.cond(is_training, 504 | lambda: ema.apply([batch_mean, batch_var]), 505 | lambda: tf.no_op()) 506 | 507 | # Update moving average and return current batch's avg and var. 508 | def mean_var_with_update(): 509 | with tf.control_dependencies([ema_apply_op]): 510 | return tf.identity(batch_mean), tf.identity(batch_var) 511 | 512 | # ema.average returns the Variable holding the average of var. 513 | mean, var = tf.cond(is_training, 514 | mean_var_with_update, 515 | lambda: (ema.average(batch_mean), ema.average(batch_var))) 516 | normed = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, 1e-3) 517 | return normed 518 | 519 | 520 | def batch_norm_template(inputs, is_training, scope, moments_dims_unused, bn_decay, data_format='NHWC'): 521 | """ Batch normalization on convolutional maps and beyond... 522 | Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow 523 | 524 | Args: 525 | inputs: Tensor, k-D input ... x C could be BC or BHWC or BDHWC 526 | is_training: boolean tf.Varialbe, true indicates training phase 527 | scope: string, variable scope 528 | moments_dims: a list of ints, indicating dimensions for moments calculation 529 | bn_decay: float or float tensor variable, controling moving average weight 530 | data_format: 'NHWC' or 'NCHW' 531 | Return: 532 | normed: batch-normalized maps 533 | """ 534 | bn_decay = bn_decay if bn_decay is not None else 0.9 535 | return tf.contrib.layers.batch_norm(inputs, 536 | center=True, scale=True, 537 | is_training=is_training, decay=bn_decay, updates_collections=None, 538 | scope=scope, 539 | data_format=data_format) 540 | 541 | 542 | def batch_norm_for_fc(inputs, is_training, bn_decay, scope): 543 | """ Batch normalization on FC data. 544 | 545 | Args: 546 | inputs: Tensor, 2D BxC input 547 | is_training: boolean tf.Varialbe, true indicates training phase 548 | bn_decay: float or float tensor variable, controling moving average weight 549 | scope: string, variable scope 550 | Return: 551 | normed: batch-normalized maps 552 | """ 553 | return batch_norm_template(inputs, is_training, scope, [0, ], bn_decay) 554 | 555 | 556 | def batch_norm_for_conv1d(inputs, is_training, bn_decay, scope, data_format): 557 | """ Batch normalization on 1D convolutional maps. 558 | 559 | Args: 560 | inputs: Tensor, 3D BLC input maps 561 | is_training: boolean tf.Varialbe, true indicates training phase 562 | bn_decay: float or float tensor variable, controling moving average weight 563 | scope: string, variable scope 564 | data_format: 'NHWC' or 'NCHW' 565 | Return: 566 | normed: batch-normalized maps 567 | """ 568 | return batch_norm_template(inputs, is_training, scope, [0, 1], bn_decay, data_format) 569 | 570 | 571 | def batch_norm_for_conv2d(inputs, is_training, bn_decay, scope, data_format): 572 | """ Batch normalization on 2D convolutional maps. 
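    Example (an illustrative sketch; shapes and decay are assumptions). With
    'NHWC' input, one mean/variance pair is estimated per channel over the
    batch and both spatial axes:

        net = tf.placeholder(tf.float32, [32, 400, 1, 64])
        is_training = tf.placeholder(tf.bool, shape=())
        normed = batch_norm_for_conv2d(net, is_training, bn_decay=0.9,
                                       scope='bn', data_format='NHWC')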
573 | 574 | Args: 575 | inputs: Tensor, 4D BHWC input maps 576 | is_training: boolean tf.Varialbe, true indicates training phase 577 | bn_decay: float or float tensor variable, controling moving average weight 578 | scope: string, variable scope 579 | data_format: 'NHWC' or 'NCHW' 580 | Return: 581 | normed: batch-normalized maps 582 | """ 583 | return batch_norm_template(inputs, is_training, scope, [0, 1, 2], bn_decay, data_format) 584 | 585 | 586 | def batch_norm_for_conv3d(inputs, is_training, bn_decay, scope): 587 | """ Batch normalization on 3D convolutional maps. 588 | 589 | Args: 590 | inputs: Tensor, 5D BDHWC input maps 591 | is_training: boolean tf.Varialbe, true indicates training phase 592 | bn_decay: float or float tensor variable, controling moving average weight 593 | scope: string, variable scope 594 | Return: 595 | normed: batch-normalized maps 596 | """ 597 | return batch_norm_template(inputs, is_training, scope, [0, 1, 2, 3], bn_decay) 598 | 599 | 600 | def dropout(inputs, 601 | is_training, 602 | scope, 603 | keep_prob=0.5, 604 | noise_shape=None): 605 | """ Dropout layer. 606 | 607 | Args: 608 | inputs: tensor 609 | is_training: boolean tf.Variable 610 | scope: string 611 | keep_prob: float in [0,1] 612 | noise_shape: list of ints 613 | 614 | Returns: 615 | tensor variable 616 | """ 617 | with tf.variable_scope(scope) as sc: 618 | outputs = tf.cond(is_training, 619 | lambda: tf.nn.dropout(inputs, keep_prob, noise_shape), 620 | lambda: inputs) 621 | return outputs 622 | --------------------------------------------------------------------------------
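# A closing illustration of how the tf_util wrappers above compose into a
# small PointNet-style per-part encoder. This sketch is ours, under assumed
# shapes -- the actual CompoNet encoders and decoders are defined in models.py.
import tensorflow as tf
import utils.tf_util as tf_util

def toy_part_encoder(point_cloud, embedding_dim, is_training):
    # point_cloud: (B, N, 3) float32 tensor -> (B, embedding_dim) embedding.
    net = tf.expand_dims(point_cloud, -1)                      # B x N x 3 x 1
    net = tf_util.conv2d(net, 64, [1, 3], scope='conv1', padding='VALID',
                         bn=True, is_training=is_training)     # B x N x 1 x 64
    net = tf_util.conv2d(net, 128, [1, 1], scope='conv2', padding='VALID',
                         bn=True, is_training=is_training)     # B x N x 1 x 128
    num_point = net.get_shape()[1].value
    net = tf_util.max_pool2d(net, [num_point, 1], scope='maxpool')  # B x 1 x 1 x 128
    net = tf.reshape(net, [-1, 128])
    return tf_util.fully_connected(net, embedding_dim, scope='fc',
                                   activation_fn=None)

# Example usage (assumed shapes): encoding a batch of 400-point parts into the
# 64-d part embedding that train.py uses by default.
# parts = tf.placeholder(tf.float32, [64, 400, 3])
# code = toy_part_encoder(parts, 64, tf.constant(True))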