├── README.md ├── callbacks.py ├── data_loader.py ├── download_modelnet40.sh ├── explaining_gapnet.ipynb ├── gapnet ├── __init__.py ├── layers.py ├── models.py └── utils.py ├── model_cls.py ├── prepare_data.py ├── resources ├── GAPNet Attention.png ├── GAPNet network.png └── GAPNet-input_transform_net.png ├── schedules.py ├── tests └── Tests.py ├── todos.md └── train_cls.py /README.md: -------------------------------------------------------------------------------- 1 | # Pointcloud Experiments with ModelNet40, Including PointNet and GAPNet 2 | 3 | The purpose of this repository is to provide a clean implementation of GAPNet. To facilitate comparability, it includes PointNet as a baseline. The goal is to come up with a readable implementation of GAPNet that outperforms PointNet. 4 | 5 | ## Star this repo if you like it. 6 | 7 | If you enjoy this repo, please give it a star. That would be much appreciated! 8 | 9 | ## Getting in touch. 10 | 11 | If you have a bug report or a feature request, please open an issue here on GitHub. 12 | 13 | - [Subscribe to my YouTube channel](https://www.youtube.com/channel/UCcMEBxcDM034JyJ8J3cggRg?view_as=subscriber). 14 | - [Become a patron](https://www.patreon.com/ai_guru). 15 | - [Subscribe to my newsletter](http://ai-guru.de/newsletter/). 16 | - [Visit my homepage/blog](http://ai-guru.de/). 17 | - [Join me on Slack](https://join.slack.com/t/ai-guru/shared_invite/enQtNDEzNjUwMTIwODM0LTdlOWQ1ZTUyZmQ5YTczOTUxYzk2YWI4ZmE0NTdmZGQxMmUxYmUwYmRhMDg1ZDU0NTUxMDI2OWVkOGFjYTViOGQ). 18 | - [Add me on LinkedIn](https://www.linkedin.com/in/dr-tristan-behrens-ai-guru-734967a2/). 19 | - [Add me on Facebook](https://www.facebook.com/AIGuruTristanBehrens). 20 | 21 | 22 | ## Acknowledgements. 23 | 24 | Based heavily on: 25 | 26 | - [https://github.com/TianzhongSong/PointNet-Keras](https://github.com/TianzhongSong/PointNet-Keras) 27 | - [https://github.com/FrankCAN/GAPointNet](https://github.com/FrankCAN/GAPointNet) 28 | 29 | Those people did really great work! 30 | 31 | ## Networks. 32 | 33 | ### 1. PointNet. 34 | 35 | Reference: [https://arxiv.org/abs/1612.00593](https://arxiv.org/abs/1612.00593) 36 | 37 | PointNet is included as a baseline. 38 | 39 | ### 2. GAPNet. 40 | 41 | Reference: [https://arxiv.org/abs/1905.08705](https://arxiv.org/abs/1905.08705) 42 | 43 | GAPNet combines a graph neural network approach with attention. 44 | 45 | ![GAPNet](https://raw.githubusercontent.com/AI-Guru/pointcloud_experiments/master/resources/GAPNet%20network.png) 46 | 47 | ## How to run. 48 | 49 | Download the ModelNet40 data: 50 | 51 | `sh download_modelnet40.sh` 52 | 53 | Prepare the data: 54 | 55 | `python prepare_data.py` 56 | 57 | Train PointNet: 58 | 59 | `python train_cls.py pointnet` 60 | 61 | Train the GAPNet draft: 62 | 63 | `python train_cls.py gapnet_dev` 64 | 65 | ## Notes. 66 | 67 | This project is under heavy construction. 
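## Usage sketch.

A minimal sketch of using the GAPNet model directly from Python. It assumes a TensorFlow 2.x setup in which `Model.fit` accepts Python generators (2.1 or later) and the HDF5 files produced by `prepare_data.py`:

```python
from gapnet.models import GAPNet
from data_loader import DataGenerator

# Default GAPNet: 1024 points with 3 coordinates each, k=20 neighbors, 16 features per attention head.
model = GAPNet()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# ModelNet40 training split prepared by prepare_data.py (9840 samples).
train = DataGenerator("./ModelNet40/ply_data_train.h5", batch_size=32,
                      number_of_points=1024, nb_classes=40, train=True)
model.fit(train.generator(), steps_per_epoch=9840 // 32, epochs=1)
```

`train_cls.py` does the same with learning-rate scheduling and TensorBoard logging on top.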
68 | -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | import tensorflow.keras.backend as K 2 | from tensorflow.keras.callbacks import Callback, ModelCheckpoint 3 | import yaml 4 | import h5py 5 | import numpy as np 6 | 7 | class Step(Callback): 8 | 9 | def __init__(self, steps, learning_rates, verbose=0): 10 | self.steps = steps 11 | self.lr = learning_rates 12 | self.verbose = verbose 13 | 14 | def change_lr(self, new_lr): 15 | old_lr = K.get_value(self.model.optimizer.lr) 16 | K.set_value(self.model.optimizer.lr, new_lr) 17 | if self.verbose == 1: 18 | print('Learning rate is %g' %new_lr) 19 | 20 | def on_epoch_begin(self, epoch, logs={}): 21 | for i, step in enumerate(self.steps): 22 | if epoch < step: 23 | self.change_lr(self.lr[i]) 24 | return 25 | self.change_lr(self.lr[i+1]) 26 | 27 | #def on_train_begin(self, logs={}): 28 | # pass 29 | 30 | def get_config(self): 31 | config = {'class': type(self).__name__, 32 | 'steps': self.steps, 33 | 'learning_rates': self.lr, 34 | 'verbose': self.verbose} 35 | return config 36 | 37 | @classmethod 38 | def from_config(cls, config): 39 | offset = config.get('epoch_offset', 0) 40 | steps = [step - offset for step in config['steps']] 41 | return cls(steps, config['learning_rates'], 42 | verbose=config.get('verbose', 0)) 43 | 44 | class TriangularCLR(Callback): 45 | 46 | def __init__(self, learning_rates, half_cycle): 47 | self.lr = learning_rates 48 | self.hc = half_cycle 49 | 50 | def on_train_begin(self, logs={}): 51 | # Setup an iteration counter 52 | self.itr = -1 53 | 54 | def on_batch_begin(self, batch, logs={}): 55 | self.itr += 1 56 | cycle = 1 + self.itr/int(2*self.hc) 57 | x = self.itr - (2.*cycle - 1)*self.hc 58 | x /= self.hc 59 | new_lr = self.lr[0] + (self.lr[1] - self.lr[0])*(1 - abs(x))/cycle 60 | 61 | K.set_value(self.model.optimizer.lr, new_lr) 62 | 63 | 64 | class MetaCheckpoint(ModelCheckpoint): 65 | """ 66 | Checkpoints some training information with the model. This should enable 67 | resuming training and having training information on every checkpoint. 
68 | Thanks to Roberto Estevao @robertomest - robertomest@poli.ufrj.br 69 | """ 70 | 71 | def __init__(self, filepath, monitor='val_loss', verbose=0, 72 | save_best_only=False, save_weights_only=False, 73 | mode='auto', period=1, training_args=None, meta=None): 74 | 75 | super(MetaCheckpoint, self).__init__(filepath, monitor='val_loss', 76 | verbose=0, save_best_only=False, 77 | save_weights_only=False, 78 | mode='auto', period=1) 79 | 80 | self.filepath = filepath 81 | self.meta = meta or {'epochs': []} 82 | 83 | if training_args: 84 | self.meta['training_args'] = training_args 85 | 86 | def on_train_begin(self, logs={}): 87 | super(MetaCheckpoint, self).on_train_begin(logs) 88 | 89 | def on_epoch_end(self, epoch, logs={}): 90 | super(MetaCheckpoint, self).on_epoch_end(epoch, logs) 91 | 92 | # Get statistics 93 | self.meta['epochs'].append(epoch) 94 | for k, v in logs.items(): 95 | # Get default gets the value or sets (and gets) the default value 96 | self.meta.setdefault(k, []).append(v) 97 | 98 | # Save to file 99 | filepath = self.filepath.format(epoch=epoch, **logs) 100 | 101 | if self.epochs_since_last_save == 0: 102 | with h5py.File(filepath, 'r+') as f: 103 | meta_group = f.create_group('meta') 104 | meta_group.attrs['training_args'] = yaml.dump( 105 | self.meta.get('training_args', '{}')) 106 | meta_group.create_dataset('epochs', 107 | data=np.array(self.meta['epochs'])) 108 | for k in logs: 109 | meta_group.create_dataset(k, data=np.array(self.meta[k])) 110 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data generator for ModelNet40 3 | reference: https://github.com/garyli1019/pointnet-keras 4 | Date: 08/13/2018 5 | Author: Tianzhong 6 | """ 7 | import numpy as np 8 | import h5py 9 | import random 10 | from tensorflow.keras.utils import to_categorical 11 | 12 | 13 | class DataGenerator: 14 | def __init__(self, file_name, batch_size, number_of_points, nb_classes=40, train=True): 15 | self.fie_name = file_name 16 | self.batch_size = batch_size 17 | self.number_of_points = number_of_points 18 | self.nb_classes = nb_classes 19 | self.train = train 20 | 21 | if number_of_points not in [1024, 2048]: 22 | raise Exception("Invalid number of points.") 23 | 24 | @staticmethod 25 | def rotate_point_cloud(data): 26 | """ Randomly rotate the point clouds to augument the dataset 27 | rotation is per shape based along up direction 28 | Input: 29 | Nx3 array, original point clouds 30 | Return: 31 | Nx3 array, rotated point clouds 32 | """ 33 | rotation_angle = np.random.uniform() * 2 * np.pi 34 | cosval = np.cos(rotation_angle) 35 | sinval = np.sin(rotation_angle) 36 | rotation_matrix = np.array([[cosval, 0, sinval], 37 | [0, 1, 0], 38 | [-sinval, 0, cosval]]) 39 | rotated_data = np.dot(data.reshape((-1, 3)), rotation_matrix) 40 | return rotated_data 41 | 42 | @staticmethod 43 | def jitter_point_cloud(data, sigma=0.01, clip=0.05): 44 | """ Randomly jitter points. jittering is per point. 
45 | Input: 46 | Nx3 array, original point clouds 47 | Return: 48 | Nx3 array, jittered point clouds 49 | """ 50 | N, C = data.shape 51 | assert (clip > 0) 52 | jittered_data = np.clip(sigma * np.random.randn(N, C), -1 * clip, clip) 53 | jittered_data += data 54 | return jittered_data 55 | 56 | def generator(self): 57 | f = h5py.File(self.fie_name, mode='r') 58 | nb_sample = f['data'].shape[0] 59 | while True: 60 | index = [n for n in range(nb_sample)] 61 | random.shuffle(index) 62 | for i in range(nb_sample // self.batch_size): 63 | batch_start = i * self.batch_size 64 | batch_end = (i + 1) * self.batch_size 65 | batch_index = index[batch_start: batch_end] 66 | X = [] 67 | Y = [] 68 | for j in batch_index: 69 | 70 | # Get input and output. 71 | item = f['data'][j] 72 | label = f['label'][j] 73 | 74 | # Downsample. 75 | if self.number_of_points == 1024: 76 | item = item[::2,:] 77 | 78 | # Data augmentation. 79 | if self.train: 80 | is_rotate = random.randint(0, 1) 81 | is_jitter = random.randint(0, 1) 82 | if is_rotate == 1: 83 | item = self.rotate_point_cloud(item) 84 | if is_jitter == 1: 85 | item = self.jitter_point_cloud(item) 86 | X.append(item) 87 | Y.append(label[0]) 88 | Y = to_categorical(np.array(Y), self.nb_classes) 89 | yield np.array(X), Y 90 | -------------------------------------------------------------------------------- /download_modelnet40.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/sh 2 | 3 | wget https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip 4 | unzip modelnet40_ply_hdf5_2048.zip 5 | mkdir ModelNet40 6 | mkdir ModelNet40/test 7 | mkdir ModelNet40/train 8 | cp modelnet40_ply_hdf5_2048/ply_data_train*.h5 ModelNet40/train/ 9 | cp modelnet40_ply_hdf5_2048/ply_data_test*.h5 ModelNet40/test/ 10 | rm modelnet40_ply_hdf5_2048.zip 11 | rm -rf modelnet40_ply_hdf5_2048 12 | -------------------------------------------------------------------------------- /explaining_gapnet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "from mpl_toolkits.mplot3d import Axes3D\n", 12 | "from gapnet.models import GAPNet\n", 13 | "from data_loader import DataGenerator\n", 14 | "import numpy as np" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "model_weights_path = \"logs/gapnet/05-first_full_network/model_weights.h5\"\n", 24 | "model = GAPNet()\n", 25 | "model.load_weights(model_weights_path)\n", 26 | "\n", 27 | "explain_model = model.create_explaining_model()" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# Data preparation.\n", 37 | "test_file = './ModelNet40/ply_data_test.h5'\n", 38 | "\n", 39 | "# Hyperparameters.\n", 40 | "nb_classes = 40\n", 41 | "number_of_points = 1024\n", 42 | "epochs = 100\n", 43 | "batch_size = 1\n", 44 | "\n", 45 | "# Data generators for training and validation.\n", 46 | "val = DataGenerator(test_file, batch_size, number_of_points, nb_classes, train=True).generator() " 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "def render_pointcloud(points, title=None):\n", 56 | " 
\"\"\"\n", 57 | " Renders a point-cloud.\n", 58 | " \"\"\"\n", 59 | "\n", 60 | " fig = plt.figure(figsize=(10, 10))\n", 61 | " ax = fig.add_subplot(111, projection='3d')\n", 62 | "\n", 63 | " ax.scatter(points[:,0], points[:,1], points[:,2], s=4.0, cmap=\"gray\", alpha=0.5)\n", 64 | "\n", 65 | " ax.set_xlabel(\"x\")\n", 66 | " ax.set_ylabel(\"y\")\n", 67 | " ax.set_zlabel(\"z\")\n", 68 | "\n", 69 | " if title != None:\n", 70 | " plt.title(title)\n", 71 | "\n", 72 | " plt.show()\n", 73 | " plt.close()\n", 74 | "\n", 75 | "def render_attention(attention):\n", 76 | " plt.imshow(np.squeeze(attention))\n", 77 | " \n", 78 | " plt.show()\n", 79 | " plt.close()\n", 80 | "\n", 81 | "# Get a sample and render it. \n", 82 | "sample = next(val)[0]\n", 83 | "render_pointcloud(sample[0])\n", 84 | "\n", 85 | "# Use the model for prediction.\n", 86 | "prediction = explain_model.predict(sample)\n", 87 | "\n", 88 | "# Render the transformed pointcloud.\n", 89 | "point_cloud_transformed = prediction[0][0]\n", 90 | "render_pointcloud(point_cloud_transformed)\n", 91 | "\n", 92 | "# Render the one head attention.\n", 93 | "one_head_attention = prediction[1][0]\n", 94 | "render_attention(one_head_attention)\n", 95 | "\n", 96 | "# Render the four head attention.\n", 97 | "four_head_attention = prediction[2][0]\n", 98 | "render_attention(four_head_attention)\n" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 3", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.6.5" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 2 137 | } 138 | -------------------------------------------------------------------------------- /gapnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /gapnet/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | #assert tf.__version__.startswith("1.1"), "Expected tensorflow 1.1X, got {}".format(tf.__version__) 3 | import tensorflow.keras.backend as K 4 | from tensorflow.keras import models, layers 5 | from .utils import assert_shape_is 6 | import numpy as np 7 | 8 | class KNN(tf.keras.layers.Layer): 9 | """ 10 | For a given sequence of vectors, computes the k-nearest neighbors. 11 | """ 12 | 13 | def __init__(self, k, **kwargs): 14 | self.k = k 15 | super(KNN, self).__init__(**kwargs) 16 | 17 | 18 | def build(self, input_shape): 19 | 20 | super(KNN, self).build(input_shape) 21 | 22 | 23 | def call(self, input): 24 | 25 | point_cloud = input 26 | 27 | point_cloud_transpose = K.permute_dimensions(point_cloud, [0, 2, 1]) 28 | 29 | # Compute distances. 
30 | point_cloud_inner = tf.matmul(point_cloud, point_cloud_transpose) 31 | point_cloud_inner = -2 * point_cloud_inner 32 | point_cloud_square = tf.reduce_sum(tf.square(point_cloud), axis=-1, keepdims=True) 33 | point_cloud_square_tranpose = tf.transpose(point_cloud_square, perm=[0, 2, 1]) 34 | adj_matrix = point_cloud_square + point_cloud_inner + point_cloud_square_tranpose 35 | 36 | # Compute indices. 37 | neg_adj = -adj_matrix 38 | _, nn_idx = tf.nn.top_k(neg_adj, k=self.k) 39 | 40 | # Compute the neighbors. 41 | batch_size = tf.shape(point_cloud)[0] # Note: Treat batch-size differently. 42 | num_points = point_cloud.get_shape()[1] 43 | num_dims = point_cloud.get_shape()[2] 44 | idx_ = tf.range(batch_size) * num_points 45 | idx_ = tf.reshape(idx_, [-1, 1, 1]) 46 | point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims]) 47 | point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx + idx_) 48 | 49 | return point_cloud_neighbors 50 | 51 | 52 | class GraphAttention(tf.keras.layers.Layer): 53 | """ 54 | the single-head GAPLayer learns self-attention and neighboring-attention 55 | features in parallel that are then fused together by a non-linear activation 56 | function leaky RELU to obtain attention coefficients, which are further 57 | normalized by a softmax function, then a linear combination operation is 58 | applied to finally generate attention feature. MLP{} denotes multi-layer 59 | perceptron operation, numbers in brace stand for size of a set of filters, 60 | and we use the same notation for the remainder. 61 | """ 62 | 63 | def __init__(self, features_out, batch_normalization=True, **kwargs): 64 | 65 | self.features_out = features_out 66 | self.batch_normalization = batch_normalization=True 67 | 68 | # Call super. 69 | super(GraphAttention, self).__init__(**kwargs) 70 | 71 | 72 | def build(self, input_shapes): 73 | 74 | assert len(input_shapes) == 2 75 | 76 | point_cloud_shape = input_shapes[0].as_list() 77 | self.number_of_points = point_cloud_shape[1] 78 | self.features_in = point_cloud_shape[-1] 79 | 80 | knn_shape = input_shapes[1].as_list() 81 | assert knn_shape[1] == point_cloud_shape[1] 82 | self.k = knn_shape[2] 83 | 84 | # MLP 1 for self attention. 85 | self.self_attention_mlp1 = layers.Dense( 86 | self.features_out, 87 | activation="relu", 88 | name=self.name + "_self_attention_mlp1" 89 | ) 90 | if self.batch_normalization == True: 91 | self.self_attention_bn1 = layers.BatchNormalization() 92 | 93 | # MLP 2 for self attention. 94 | self.self_attention_mlp2 = layers.Dense( 95 | 1, 96 | activation="relu", 97 | name=self.name + "_self_attention_mlp2" 98 | ) 99 | if self.batch_normalization == True: 100 | self.self_attention_bn2 = layers.BatchNormalization() 101 | 102 | # MLP 1 for neighbor attention. 103 | self.neighbor_attention_mlp1 = layers.Dense( 104 | self.features_out, 105 | activation="relu", 106 | name=self.name + "_neighbor_attention_mlp1" 107 | ) 108 | if self.batch_normalization == True: 109 | self.neighbor_attention_bn1 = layers.BatchNormalization() 110 | 111 | # MLP 2 for neighbor attention. 112 | self.neighbor_attention_mlp2 = layers.Dense( 113 | 1, 114 | activation="relu", 115 | name=self.name + "_neighbor_attention_mlp2" 116 | ) 117 | if self.batch_normalization == True: 118 | self.neighbor_attention_bn2 = layers.BatchNormalization() 119 | 120 | # Final bias. 121 | self.output_bias = self.add_variable( 122 | "kernel", 123 | shape=[self.number_of_points, 1, self.features_out]) 124 | 125 | # Call super. 
126 | super(GraphAttention, self).build(input_shapes) 127 | 128 | 129 | def call(self, inputs): 130 | 131 | # The first part of the input is the pointcloud. 132 | point_cloud = inputs[0] 133 | assert_shape_is(point_cloud, (1024, 3)) 134 | 135 | # The second part of the input are the KNNs. 136 | knn = inputs[1] 137 | assert_shape_is(knn, (1024, 20, 3)) 138 | 139 | # Reshape the pointcloud if necessary. 140 | if len(point_cloud.shape) == 4: 141 | pass 142 | elif len(point_cloud.shape) == 3: 143 | point_cloud = K.expand_dims(point_cloud, axis=2) 144 | else: 145 | raise Exception("Invalid shape!") 146 | assert_shape_is(point_cloud, (1024, 1, 3)) 147 | 148 | # Tile the pointcloud to make it compatible with KNN. 149 | point_cloud_tiled = K.tile(point_cloud, [1, 1, self.k, 1]) 150 | assert_shape_is(point_cloud_tiled, (1024, 20, 3)) 151 | 152 | # Compute difference between tiled pointcloud and knn. 153 | point_cloud_knn_difference = point_cloud_tiled - knn 154 | assert_shape_is(point_cloud_knn_difference, (1024, 20, 3)) 155 | 156 | # MLP 1 for self attention including batch normalization. 157 | self_attention = self.self_attention_mlp1(point_cloud) 158 | if self.batch_normalization == True: 159 | self_attention = self.self_attention_bn1(self_attention) 160 | assert_shape_is(self_attention, (1024, 1, 16)) 161 | 162 | # MLP 2 for self attention including batch normalization. 163 | self_attention = self.self_attention_mlp2(self_attention) 164 | if self.batch_normalization == True: 165 | self_attention = self.self_attention_bn2(self_attention) 166 | assert_shape_is(self_attention, (1024, 1, 1)) 167 | 168 | # MLP 1 for neighbor attention including batch normalization. 169 | neighbor_attention = self.neighbor_attention_mlp1(point_cloud_knn_difference) 170 | if self.batch_normalization == True: 171 | neighbor_attention = self.neighbor_attention_bn1(neighbor_attention) 172 | assert_shape_is(neighbor_attention, (1024, 20, 16)) 173 | 174 | # Graph features are the ouput of the first MLP. 175 | graph_features = neighbor_attention 176 | 177 | # MLP 2 for neighbor attention including batch normalization. 178 | neighbor_attention = self.neighbor_attention_mlp2(neighbor_attention) 179 | if self.batch_normalization == True: 180 | neighbor_attention = self.neighbor_attention_bn2(neighbor_attention) 181 | assert_shape_is(neighbor_attention, (1024, 20, 1)) 182 | 183 | # Merge self attention and neighbor attention to get attention coefficients. 184 | logits = self_attention + neighbor_attention 185 | assert_shape_is(logits, (1024, 20, 1)) 186 | logits = K.permute_dimensions(logits, (0, 1, 3, 2)) 187 | assert_shape_is(logits, (1024, 1, 20)) 188 | 189 | # Apply leaky relu and softmax to logits to get attention coefficents. 190 | logits = K.relu(logits, alpha=0.2) 191 | attention_coefficients = K.softmax(logits) 192 | assert_shape_is(attention_coefficients, (1024, 1, 20)) 193 | 194 | # Compute attention features from attention coefficients and graph features. 195 | attention_features = tf.matmul(attention_coefficients, graph_features) 196 | attention_features = tf.add(attention_features, self.output_bias) 197 | attention_features = K.relu(attention_features) 198 | assert_shape_is(attention_features, (1024, 1, 16)) 199 | 200 | # Reshape graph features. 201 | #graph_features = K.expand_dims(graph_features, axis=2) 202 | assert_shape_is(graph_features, (1024, 20, 16)) 203 | 204 | # Done. 
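        # Per-head outputs (N points, k neighbors, F = features_out):
        #   attention_features     -> (N, 1, F)  attention-weighted neighborhood feature
        #   graph_features         -> (N, k, F)  per-neighbor edge features from the first neighbor MLP
        #   attention_coefficients -> (N, 1, k)  softmax-normalized weights over the k neighbors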
205 | return attention_features, graph_features, attention_coefficients 206 | 207 | 208 | class MultiGraphAttention(tf.keras.layers.Layer): 209 | """ 210 | The GAPLayer with M heads, as shown in 2(a) , takes N points with F dimensions as input and concatenates attention feature and graph feature respectively from all heads to generate multi-attention features and multi-graph features as output. 211 | """ 212 | 213 | def __init__(self, k, features_out, heads, batch_normalization=True, **kwargs): 214 | 215 | self.k = k 216 | self.features_out = features_out 217 | self.heads = heads 218 | self.batch_normalization = batch_normalization 219 | #self.bn_decay + bn_decay 220 | 221 | # Call super. 222 | super(MultiGraphAttention, self).__init__(**kwargs) 223 | 224 | 225 | def build(self, input_shape): 226 | 227 | self.graph_attentions = [GraphAttention(features_out=self.features_out, batch_normalization=self.batch_normalization) for _ in range(self.heads)] 228 | 229 | # Call super. 230 | super(MultiGraphAttention, self).build(input_shape) 231 | 232 | 233 | def call(self, inputs): 234 | 235 | # Input for a pointcloud. 236 | point_cloud = inputs 237 | 238 | # Create the KNN layer and apply it to the input. 239 | knn = KNN(k=self.k, name=self.name + "_knn")(point_cloud) 240 | 241 | # Do multi-head attention. 242 | attention_features_list = [] 243 | graph_features_list = [] 244 | attention_coefficients_list = [] 245 | for head_index in range(self.heads): 246 | graph_attention = self.graph_attentions[head_index]([point_cloud, knn]) 247 | attention_features = graph_attention[0] 248 | graph_features = graph_attention[1] 249 | attention_coefficients = graph_attention[2] 250 | 251 | attention_features_list.append(attention_features) 252 | graph_features_list.append(graph_features) 253 | attention_coefficients_list.append(attention_coefficients) 254 | 255 | # Only one head. Return first element of lists. 256 | if self.heads == 1: 257 | multi_attention_features = attention_features_list[0] 258 | multi_graph_features = graph_features_list[0] 259 | multi_attention_coefficients = attention_coefficients_list[0] 260 | 261 | # More than one head. Stack. 262 | else: 263 | multi_attention_features = K.concatenate(attention_features_list, axis=3) 264 | multi_graph_features = K.concatenate(graph_features_list, axis=3) 265 | multi_attention_coefficients = K.concatenate(attention_coefficients_list, axis=3) 266 | 267 | assert_shape_is(multi_attention_features, (1024, 1, 16 * self.heads)) 268 | assert_shape_is(multi_graph_features, (1024, 20, 16 * self.heads)) 269 | assert_shape_is(multi_attention_coefficients, (1024, 1, 20 * self.heads)) 270 | 271 | # Done. 272 | return multi_attention_features, multi_graph_features, multi_attention_coefficients 273 | 274 | def compute_output_shape(self, input_shape): 275 | assert False 276 | return (input_shape[0], 1024, 1, 16 * self.heads) 277 | 278 | class Transform(tf.keras.layers.Layer): 279 | """ 280 | spatial transform network: The spatial transform network is used to make 281 | point cloud invariant to certain transformations. The model learns a 3 × 3 282 | matrix for affine transformation from a single-head GAPLayer with 16 283 | channels. 284 | """ 285 | 286 | def __init__(self, **kwargs): 287 | super(Transform, self).__init__(**kwargs) 288 | 289 | #w_init = tf.zeros_initializer() 290 | #self.transform_w = tf.Variable( 291 | # initial_value=w_init(shape=(256, 3 * 3), dtype='float32'), 292 | # trainable=True 293 | #) 294 | 295 | # Biases for the learned transformation matrix. 
296 | #b_init = tf.zeros_initializer() 297 | #self.transform_b = tf.Variable( 298 | # initial_value=b_init(shape=(3 * 3,), dtype='float32'), 299 | # trainable=True 300 | # ) 301 | 302 | 303 | 304 | def build(self, input_shape): 305 | 306 | # Weights for the learned transformation matrix. 307 | self.transform_w = self.add_variable("transform_w", shape=(256, 3 * 3)) 308 | self.transform_b = self.add_variable("transform_b", shape=(3 * 3,)) 309 | 310 | # MLP 1 on attention features. 311 | self.mlp1 = layers.Dense(64, activation="linear") 312 | self.mlp_bn1 = layers.BatchNormalization() 313 | self.mlp_activation1 = layers.Activation("relu") 314 | 315 | # MLP 2 on attention features. 316 | self.mlp2 = layers.Dense(128, activation="linear") 317 | self.mlp_bn2 = layers.BatchNormalization() 318 | self.mlp_activation2 = layers.Activation("relu") 319 | 320 | # MLP 3 on attention features. 321 | self.mlp3 = layers.Dense(1024, activation="linear") 322 | self.mlp_bn3 = layers.BatchNormalization() 323 | self.mlp_activation3 = layers.Activation("relu") 324 | 325 | # Pooling layer. 326 | self.pooling = layers.MaxPooling2D((1024, 1)) 327 | 328 | # Flatten layer. 329 | self.flatten = layers.Flatten() 330 | 331 | # Dense 1 on attention features. 332 | self.dense1 = layers.Dense(512, activation="linear") 333 | self.bn1 = layers.BatchNormalization() 334 | self.activation1 = layers.Activation("relu") 335 | 336 | # Dense 2 on attention features. 337 | self.dense2 = layers.Dense(256, activation="linear") 338 | self.bn2 = layers.BatchNormalization() 339 | self.activation2 = layers.Activation("relu") 340 | 341 | super(Transform, self).build(input_shape) 342 | 343 | 344 | def call(self, inputs): 345 | 346 | assert len(inputs) == 3 347 | 348 | # Inputs are point cloud, attention features and graph features. 349 | point_cloud = inputs[0] 350 | attention_features = inputs[1] 351 | graph_features = inputs[2] 352 | assert_shape_is(point_cloud, (1024, 3)) 353 | assert_shape_is(attention_features, (1024, 1, 19)) 354 | assert_shape_is(graph_features, (1024, 20, 16)) 355 | 356 | # MLP 1 on attention features. 357 | net = self.mlp1(attention_features) 358 | net = self.mlp_bn1(net) 359 | net = self.mlp_activation1(net) 360 | assert_shape_is(net, (1024, 1, 64)) 361 | 362 | # MLP 2 on attention features. 363 | net = self.mlp2(net) 364 | net = self.mlp_bn2(net) 365 | net = self.mlp_activation2(net) 366 | assert_shape_is(net, (1024, 1, 128)) 367 | 368 | # Maximum for graph features. 369 | graph_features_max = tf.reduce_max(graph_features, axis=2, keepdims=True) 370 | assert_shape_is(graph_features_max, (1024, 1, 16)) 371 | 372 | # Concatenate with max. 373 | net = layers.concatenate([net, graph_features_max]) 374 | assert_shape_is(net, (1024, 1, 144)) 375 | 376 | # MLP 3 on attention features. 377 | net = self.mlp3(net) 378 | net = self.mlp_bn3(net) 379 | net = self.mlp_activation3(net) 380 | assert_shape_is(net, (1024, 1, 1024)) 381 | 382 | # Pooling layer. 383 | net = self.pooling(net) 384 | assert_shape_is(net, (1, 1, 1024)) 385 | 386 | # Flatten layer. 387 | net = self.flatten(net) 388 | assert_shape_is(net, (1024,)) 389 | 390 | # Dense 1. 391 | net = self.dense1(net) 392 | net = self.bn1(net) 393 | net = self.activation1(net) 394 | assert_shape_is(net, (512,)) 395 | 396 | # Dense 2. 397 | net = self.dense2(net) 398 | net = self.bn2(net) 399 | net = self.activation2(net) 400 | assert_shape_is(net, (256,)) 401 | 402 | # Compute the transformation matrix. 
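        # The 256-d global feature is projected to 9 values (net @ transform_w + transform_b);
        # the bias is offset by a flattened 3x3 identity so the layer starts near the identity
        # transform, and the result is reshaped to a 3x3 matrix that is applied to the point cloud.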
403 | transform = tf.matmul(net, self.transform_w) 404 | assert_shape_is(transform, (9,)) 405 | self.transform_b = self.transform_b + tf.constant(np.eye(3).flatten(), dtype=tf.float32) 406 | transform = tf.nn.bias_add(transform, self.transform_b) 407 | assert_shape_is(transform, (9,)) 408 | transform = K.reshape(transform, (-1, 3, 3)) 409 | assert_shape_is(transform, (3, 3)) 410 | 411 | point_cloud_transformed = tf.matmul(point_cloud, transform) 412 | assert_shape_is(point_cloud_transformed, (1024, 3)) 413 | return point_cloud_transformed 414 | 415 | def compute_output_shape(self, input_shape): 416 | return (input_shape[0], 1024, 3) 417 | -------------------------------------------------------------------------------- /gapnet/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import models, layers 3 | from tensorflow.keras import backend as K 4 | from .layers import KNN, Transform, MultiGraphAttention, GraphAttention 5 | from .utils import assert_shape_is 6 | 7 | 8 | class GAPNet(tf.keras.Model): 9 | 10 | def __init__(self, number_of_points=1024, features_in=3, k=20, features_out=16, **kwargs): 11 | super(GAPNet, self).__init__() 12 | 13 | self.number_of_points = number_of_points 14 | self.features_in = features_in 15 | self.k = k 16 | self.features_out = features_out 17 | 18 | self.build_graph(input_shape=(None, number_of_points, features_in)) 19 | 20 | 21 | def build_graph(self, input_shape): 22 | input_shape_nobatch = input_shape[1:] 23 | self.build(input_shape) 24 | inputs = tf.keras.Input(shape=input_shape_nobatch) 25 | 26 | if not hasattr(self, 'call'): 27 | raise AttributeError("User should define 'call' method in sub-class model!") 28 | 29 | _ = self.call(inputs) 30 | 31 | def build(self, input_shape, **kwargs): 32 | 33 | # Create attention layer with one head. 34 | self.onehead_attention = MultiGraphAttention(k=self.k, features_out=self.features_out, heads=1) 35 | 36 | # Create spatial transform layer. 37 | self.transform = Transform() 38 | 39 | # Create attention layer with four heads. 40 | self.fourhead_attention = MultiGraphAttention(k=self.k, features_out=self.features_out, heads=4) 41 | 42 | # MLP 1 on attention features. 43 | self.mlp1 = layers.Dense(64, activation="linear") 44 | self.mlp_bn1 = layers.BatchNormalization() 45 | self.mlp_activation1 = layers.Activation("relu") 46 | 47 | # MLP 2 on attention features. 48 | self.mlp2 = layers.Dense(64, activation="linear") 49 | self.mlp_bn2 = layers.BatchNormalization() 50 | self.mlp_activation2 = layers.Activation("relu") 51 | 52 | # MLP 3 on attention features. 53 | self.mlp3 = layers.Dense(64, activation="linear") 54 | self.mlp_bn3 = layers.BatchNormalization() 55 | self.mlp_activation3 = layers.Activation("relu") 56 | 57 | # MLP 4 on attention features. 58 | self.mlp4 = layers.Dense(128, activation="linear") 59 | self.mlp_bn4 = layers.BatchNormalization() 60 | self.mlp_activation4 = layers.Activation("relu") 61 | 62 | # MLP 5. 63 | self.mlp5 = layers.Dense(1024, activation="linear") 64 | self.mlp_bn5 = layers.BatchNormalization() 65 | self.mlp_activation5 = layers.Activation("relu") 66 | 67 | # Flatten. 68 | self.flatten = layers.Flatten() 69 | 70 | # Dense 1. 71 | self.dense1 = layers.Dense(512, activation="linear") 72 | self.dense_dropout1 = layers.Dropout(0.5) 73 | 74 | # Dense 1. 75 | self.dense2 = layers.Dense(256, activation="linear") 76 | self.dense_dropout2 = layers.Dropout(0.5) 77 | 78 | # Dense 1. 
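        # Classification head: 40-way softmax over the ModelNet40 classes.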
79 | self.dense3 = layers.Dense(40, activation="softmax") 80 | 81 | super(GAPNet, self).build(input_shape) 82 | 83 | 84 | def call(self, inputs): 85 | 86 | point_cloud = inputs 87 | self.point_cloud_in = point_cloud 88 | assert_shape_is(point_cloud, (1024, 3)) 89 | 90 | # First attention layer with one head. 91 | onehead_attention = self.onehead_attention(point_cloud) 92 | onehead_attention_features = onehead_attention[0] 93 | onehead_graph_features = onehead_attention[1] 94 | onehead_attention_coefficients = onehead_attention[2] 95 | self.onehead_attention_coefficients_out = onehead_attention_coefficients 96 | assert_shape_is(onehead_attention_features, (1024, 1, 16)) 97 | assert_shape_is(onehead_graph_features, (1024, 20, 16)) 98 | assert_shape_is(onehead_attention_coefficients, (1024, 1, 20)) 99 | 100 | # Skip connection from point cloud to attention features. 101 | point_cloud_expanded = K.expand_dims(point_cloud, axis=2) 102 | assert_shape_is(point_cloud_expanded, (1024, 1, 3)) 103 | onehead_attention_features = K.concatenate([onehead_attention_features, point_cloud_expanded]) 104 | assert_shape_is(onehead_attention_features, (1024, 1, 19)) 105 | del point_cloud_expanded 106 | 107 | # Spatial transform. 108 | point_cloud_transformed = self.transform([point_cloud, onehead_attention_features, onehead_graph_features]) 109 | assert_shape_is(point_cloud_transformed, (1024, 3)) 110 | self.point_cloud_transformed_out = point_cloud_transformed 111 | del point_cloud 112 | 113 | # Second attention layer with four head. 114 | fourhead_attention = self.fourhead_attention(point_cloud_transformed) 115 | fourhead_attention_features = fourhead_attention[0] 116 | fourhead_graph_features = fourhead_attention[1] 117 | fourhead_attention_coefficients = fourhead_attention[2] 118 | self.fourhead_attention_coefficients_out = fourhead_attention_coefficients 119 | assert_shape_is(fourhead_attention_features, (1024, 1, 64)) 120 | assert_shape_is(fourhead_graph_features, (1024, 20, 64)) 121 | assert_shape_is(fourhead_attention_coefficients, (1024, 1, 80)) 122 | 123 | # Skip connection from transformed point cloud to attention features. 124 | point_cloud_expanded = K.expand_dims(point_cloud_transformed, axis=2) 125 | assert_shape_is(point_cloud_expanded, (1024, 1, 3)) 126 | onehead_attention_features = K.concatenate([fourhead_attention_features, point_cloud_expanded]) 127 | assert_shape_is(onehead_attention_features, (1024, 1, 67)) 128 | 129 | # MLP 1 on attention features. 130 | net1 = self.mlp1(onehead_attention_features) 131 | net1 = self.mlp_bn1(net1) 132 | net1 = self.mlp_activation1(net1) 133 | assert_shape_is(net1, (1024, 1, 64)) 134 | 135 | # MLP 2 on attention features. 136 | net2 = self.mlp2(net1) 137 | net2 = self.mlp_bn2(net2) 138 | net2 = self.mlp_activation2(net2) 139 | assert_shape_is(net2, (1024, 1, 64)) 140 | 141 | # MLP 3 on attention features. 142 | net3 = self.mlp3(net2) 143 | net3 = self.mlp_bn3(net3) 144 | net3 = self.mlp_activation3(net3) 145 | assert_shape_is(net3, (1024, 1, 64)) 146 | 147 | # MLP 4 on attention features. 148 | net4 = self.mlp4(net3) 149 | net4 = self.mlp_bn4(net4) 150 | net4 = self.mlp_activation4(net4) 151 | assert_shape_is(net4, (1024, 1, 128)) 152 | 153 | # Maximum for graph features. 154 | fourhead_graph_features_max = tf.reduce_max(fourhead_graph_features, axis=2, keepdims=True) 155 | assert_shape_is(fourhead_graph_features_max, (1024, 1, 64)) 156 | 157 | # Concatenate all MLPs and maximum of graph features. 
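        # Resulting channels: 64 + 64 + 64 + 128 + 64 = 384.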
158 | net = layers.concatenate([net1, net2, net3, net4, fourhead_graph_features_max]) 159 | assert_shape_is(net, (1024, 1, 384)) 160 | 161 | # MLP 5. 162 | net = self.mlp5(net) 163 | net = self.mlp_bn5(net) 164 | net = self.mlp_activation5(net) 165 | assert_shape_is(net, (1024, 1, 1024)) 166 | 167 | # Maximum for net. 168 | net = K.max(net, axis=1, keepdims=True) 169 | assert_shape_is(net, (1, 1, 1024)) 170 | 171 | # Flatten. 172 | net = self.flatten(net) 173 | assert_shape_is(net, (1024,)) 174 | 175 | # Dense 1. 176 | net = self.dense1(net) 177 | net = self.dense_dropout1(net) 178 | assert_shape_is(net, (512,)) 179 | 180 | # Dense 2. 181 | net = self.dense2(net) 182 | net = self.dense_dropout2(net) 183 | assert_shape_is(net, (256,)) 184 | 185 | # Dense 3. 186 | net = self.dense3(net) 187 | assert_shape_is(net, (40,)) 188 | 189 | return net 190 | 191 | 192 | def create_explaining_model(self): 193 | """ 194 | Creates a neural network that has the auxilary outputs. 195 | """ 196 | 197 | input = self.point_cloud_in 198 | outputs = [ 199 | self.point_cloud_transformed_out, 200 | self.onehead_attention_coefficients_out, 201 | self.fourhead_attention_coefficients_out 202 | ] 203 | return models.Model(input, outputs) 204 | -------------------------------------------------------------------------------- /gapnet/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def assert_shape_is(tensor, expected_shape): 4 | assert tensor.dtype == tf.float32, str(tensor.dtype) 5 | assert isinstance(tensor, tf.Tensor), type(tensor) 6 | assert isinstance(expected_shape, list) or isinstance(expected_shape, tuple), type(expected_shape) 7 | tensor_shape = tensor.shape[1:] 8 | if tensor_shape != expected_shape: 9 | raise Exception("{} is not equal {}".format(tensor_shape, expected_shape)) 10 | -------------------------------------------------------------------------------- /model_cls.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dropout, Input, BatchNormalization, Dense 2 | from tensorflow.keras.layers import Reshape, Lambda, concatenate 3 | from tensorflow.keras.models import Model 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | class MatMul(tf.keras.layers.Layer): 9 | 10 | def __init__(self, **kwargs): 11 | super(MatMul, self).__init__(**kwargs) 12 | 13 | def build(self, input_shape): 14 | # Used purely for shape validation. 
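        # Expects a list [A, B] of rank-3 tensors with A: (batch, n, m) and B: (batch, m, p);
        # the layer returns A @ B with shape (batch, n, p).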
15 | if not isinstance(input_shape, list): 16 | raise ValueError('`MatMul` layer should be called ' 17 | 'on a list of inputs') 18 | if len(input_shape) != 2: 19 | raise ValueError('The input of `MatMul` layer should be a list containing 2 elements') 20 | 21 | if len(input_shape[0]) != 3 or len(input_shape[1]) != 3: 22 | raise ValueError('The dimensions of each element of inputs should be 3') 23 | 24 | if input_shape[0][-1] != input_shape[1][1]: 25 | raise ValueError('The last dimension of inputs[0] should match the dimension 1 of inputs[1]') 26 | 27 | def call(self, inputs): 28 | if not isinstance(inputs, list): 29 | raise ValueError('A `MatMul` layer should be called ' 30 | 'on a list of inputs.') 31 | return tf.matmul(inputs[0], inputs[1]) 32 | 33 | def compute_output_shape(self, input_shape): 34 | output_shape = [input_shape[0][0], input_shape[0][1], input_shape[1][-1]] 35 | return tuple(output_shape) 36 | 37 | 38 | def create_pointnet(number_of_points, nb_classes): 39 | input_points = Input(shape=(number_of_points, 3)) 40 | # issues 41 | # input transformation net 42 | x = Conv1D(64, 1, activation='relu')(input_points) 43 | x = BatchNormalization()(x) 44 | x = Conv1D(128, 1, activation='relu')(x) 45 | x = BatchNormalization()(x) 46 | x = Conv1D(1024, 1, activation='relu')(x) 47 | x = BatchNormalization()(x) 48 | x = MaxPooling1D(pool_size=number_of_points)(x) 49 | 50 | x = Dense(512, activation='relu')(x) 51 | x = BatchNormalization()(x) 52 | x = Dense(256, activation='relu')(x) 53 | x = BatchNormalization()(x) 54 | 55 | x = Dense(9, weights=[np.zeros([256, 9]), np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32)])(x) 56 | input_T = Reshape((3, 3))(x) 57 | 58 | # forward net 59 | g = MatMul()([input_points, input_T]) 60 | g = Conv1D(64, 1, activation='relu')(g) 61 | g = BatchNormalization()(g) 62 | g = Conv1D(64, 1, activation='relu')(g) 63 | g = BatchNormalization()(g) 64 | 65 | # feature transform net 66 | f = Conv1D(64, 1, activation='relu')(g) 67 | f = BatchNormalization()(f) 68 | f = Conv1D(128, 1, activation='relu')(f) 69 | f = BatchNormalization()(f) 70 | f = Conv1D(1024, 1, activation='relu')(f) 71 | f = BatchNormalization()(f) 72 | f = MaxPooling1D(pool_size=number_of_points)(f) 73 | f = Dense(512, activation='relu')(f) 74 | f = BatchNormalization()(f) 75 | f = Dense(256, activation='relu')(f) 76 | f = BatchNormalization()(f) 77 | f = Dense(64 * 64, weights=[np.zeros([256, 64 * 64]), np.eye(64).flatten().astype(np.float32)])(f) 78 | feature_T = Reshape((64, 64))(f) 79 | 80 | # forward net 81 | g = MatMul()([g, feature_T]) 82 | g = Conv1D(64, 1, activation='relu')(g) 83 | g = BatchNormalization()(g) 84 | g = Conv1D(128, 1, activation='relu')(g) 85 | g = BatchNormalization()(g) 86 | g = Conv1D(1024, 1, activation='relu')(g) 87 | g = BatchNormalization()(g) 88 | 89 | # global feature 90 | global_feature = MaxPooling1D(pool_size=number_of_points)(g) 91 | 92 | # point_net_cls 93 | c = Dense(512, activation='relu')(global_feature) 94 | c = BatchNormalization()(c) 95 | c = Dropout(0.5)(c) 96 | c = Dense(256, activation='relu')(c) 97 | c = BatchNormalization()(c) 98 | c = Dropout(0.5)(c) 99 | c = Dense(nb_classes, activation='softmax')(c) 100 | prediction = Flatten()(c) 101 | 102 | model = Model(inputs=input_points, outputs=prediction) 103 | 104 | return model 105 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as 
np 3 | import os 4 | 5 | import tensorflow as tf 6 | 7 | data_path = './ModelNet40/' 8 | 9 | for d in [['train', len(os.listdir(data_path + 'train'))], ['test', len(os.listdir(data_path + 'test'))]]: 10 | data = None 11 | labels = None 12 | for j in range(d[1]): 13 | file_name = data_path + d[0] + '/ply_data_{0}{1}.h5'.format(d[0], j) 14 | f = h5py.File(file_name, mode='r') 15 | if data is None: 16 | data = f['data'] 17 | labels = f['label'] 18 | else: 19 | data = np.vstack((data, f['data'])) 20 | labels = np.vstack((labels, f['label'])) 21 | f.close() 22 | save_name = data_path + '/ply_data_{0}.h5'.format(d[0]) 23 | print(data.shape) 24 | print(labels.shape) 25 | h5_fout = h5py.File(save_name) 26 | h5_fout.create_dataset( 27 | 'data', data=data, 28 | dtype='float32') 29 | h5_fout.create_dataset( 30 | 'label', data=labels, 31 | dtype='float32') 32 | h5_fout.close() 33 | -------------------------------------------------------------------------------- /resources/GAPNet Attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Guru/pointcloud_experiments/de010a942c96e01cc06a1dca42f17548eb1fca0c/resources/GAPNet Attention.png -------------------------------------------------------------------------------- /resources/GAPNet network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Guru/pointcloud_experiments/de010a942c96e01cc06a1dca42f17548eb1fca0c/resources/GAPNet network.png -------------------------------------------------------------------------------- /resources/GAPNet-input_transform_net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Guru/pointcloud_experiments/de010a942c96e01cc06a1dca42f17548eb1fca0c/resources/GAPNet-input_transform_net.png -------------------------------------------------------------------------------- /schedules.py: -------------------------------------------------------------------------------- 1 | from callbacks import Step 2 | 3 | 4 | def onetenth_4_8_12(lr): 5 | steps = [4, 8, 12] 6 | lrs = [lr, lr / 10, lr / 100, lr / 1000] 7 | return Step(steps, lrs) 8 | 9 | 10 | def onetenth_10_15_20(lr): 11 | steps = [10, 15, 15] 12 | lrs = [lr, lr / 10, lr / 100, lr / 1000] 13 | return Step(steps, lrs) 14 | 15 | 16 | def onetenth_50_75(lr): 17 | steps = [50, 75] 18 | lrs = [lr, lr / 10, lr / 100] 19 | return Step(steps, lrs) 20 | 21 | 22 | def wideresnet_step(lr): 23 | steps = [60, 120, 160] 24 | lrs = [lr, lr / 5, lr / 25, lr / 125] 25 | return Step(steps, lrs) 26 | -------------------------------------------------------------------------------- /tests/Tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import sys 3 | sys.path.append("..") 4 | from gapnet.layers import KNN, GraphAttention, MultiGraphAttention, Transform 5 | from gapnet.models import GAPNet 6 | from tensorflow.keras import models, layers 7 | import numpy as np 8 | from tensorflow.keras.utils import plot_model 9 | 10 | class TestMethods(unittest.TestCase): 11 | 12 | #@unittest.skip 13 | def test_knn(self): 14 | """ 15 | Tests the KNN layer. 16 | """ 17 | 18 | number_of_points = 1024 19 | features = 3 20 | k = 20 21 | 22 | # Create the model. 23 | model = models.Sequential() 24 | model.add(KNN(input_shape=(number_of_points, features), k=k)) 25 | 26 | # Check if output is right. 
27 | self.assertEqual(model.outputs[0].shape[1:], (number_of_points, k, features)) 28 | 29 | # TODO Consider checking the KNN condition. 30 | 31 | # Do a prediction. 32 | input = np.array([(x , x, x) for x in range(number_of_points)]) 33 | prediction = model.predict(np.expand_dims(input, axis=0))[0] 34 | print(prediction.shape) 35 | self.assertEqual(prediction.shape, (number_of_points, k, features)) 36 | #print(prediction) 37 | #plt.imshow(prediction) 38 | #plt.show() 39 | #plt.close() 40 | 41 | #@unittest.skip 42 | def test_gap(self): 43 | """ 44 | Tests the graph attention layer. 45 | """ 46 | 47 | number_of_points = 1024 48 | features = 3 49 | k = 20 50 | features_out = 16 51 | 52 | # Input for a pointcloud. 53 | point_cloud = layers.Input(shape=(number_of_points, features)) 54 | 55 | # Create the KNN layer and apply it to the input. 56 | knn = KNN(k=k, name="test_knn")(point_cloud) 57 | 58 | # Create the Graph Attention from point-cloud and KNN. 59 | graph_attention = GraphAttention(features_out)([point_cloud, knn]) 60 | 61 | # Create the model. 62 | model = models.Model(point_cloud, graph_attention) 63 | model.summary() 64 | 65 | # Check if output is right. The first is attention feature. 66 | self.assertEqual(model.outputs[0].shape[1:], (number_of_points, 1, features_out)) 67 | 68 | # Check if output is right. The second is graph feature. 69 | self.assertEqual(model.outputs[1].shape[1:], (number_of_points, 1, k, features_out)) 70 | 71 | # Check if output is right. The second is attention coefficients. 72 | self.assertEqual(model.outputs[2].shape[1:], (number_of_points, 1, k)) 73 | 74 | # Do a prediction. 75 | input = np.array([(x , x, x) for x in range(number_of_points)]) 76 | attention_features, graph_features, attention_coefficients = model.predict(np.expand_dims(input, axis=0)) 77 | #print(prediction) 78 | #plt.imshow(prediction) 79 | #plt.show() 80 | #plt.close() 81 | 82 | 83 | #@unittest.skip 84 | def test_multigap_onehead(self): 85 | """ 86 | Tests the multi head graph attention layer with one head. 87 | """ 88 | 89 | number_of_points = 1024 90 | features_in = 3 91 | k = 20 92 | heads = 1 93 | features_out = 16 94 | 95 | # Input for a pointcloud. 96 | point_cloud = layers.Input(shape=(number_of_points, features_in)) 97 | 98 | # Create the Graph Attention from point-cloud and KNN. 99 | multi_graph_attention = MultiGraphAttention(k=k, features_out=features_out, heads=heads)(point_cloud) 100 | 101 | # Create the model. 102 | model = models.Model(point_cloud, multi_graph_attention) 103 | 104 | # Check if output is right. The first is attention feature. 105 | self.assertEqual(model.outputs[0].shape[1:], (number_of_points, heads, features_out)) 106 | 107 | # Check if output is right. The second is graph feature. 108 | self.assertEqual(model.outputs[1].shape[1:], (number_of_points, k, features_out)) 109 | 110 | # Check if output is right. The second is attention coefficients. 111 | self.assertEqual(model.outputs[2].shape[1:], (number_of_points, heads, k)) 112 | 113 | 114 | #@unittest.skip 115 | def test_multigap_multihead(self): 116 | """ 117 | Tests the multi head graph attention layer with multiple heads. 118 | """ 119 | 120 | number_of_points = 1024 121 | features_in = 3 122 | k = 20 123 | heads = 4 124 | features_out = 16 125 | 126 | # Input for a pointcloud. 127 | point_cloud = layers.Input(shape=(number_of_points, features_in)) 128 | 129 | # Create the Graph Attention from point-cloud and KNN. 
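        # MultiGraphAttention builds its KNN layer internally, so only the raw point cloud is passed in.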
130 | multi_graph_attention = MultiGraphAttention(k=k, features_out=features_out, heads=heads)(point_cloud) 131 | 132 | # Create the model. 133 | model = models.Model(point_cloud, multi_graph_attention) 134 | 135 | # Check if output is right. The first is attention feature. 136 | self.assertEqual(model.outputs[0].shape[1:], (number_of_points, 1, heads * features_out)) 137 | 138 | # Check if output is right. The second is graph feature. 139 | self.assertEqual(model.outputs[1].shape[1:], (number_of_points, k, features_out * heads)) 140 | 141 | # Check if output is right. The second is attention coefficients. 142 | self.assertEqual(model.outputs[2].shape[1:], (number_of_points, 1, heads * k)) 143 | 144 | 145 | #@unittest.skip 146 | def test_transform(self): 147 | 148 | number_of_points = 1024 149 | features = 3 150 | k = 20 151 | features_out = 16 152 | 153 | # Input for a pointcloud. 154 | point_cloud = layers.Input(shape=(number_of_points, features)) 155 | 156 | # Create the transform layer from point-cloud. 157 | point_cloud_transformed = Transform(k=k, features=features_out)(point_cloud) 158 | 159 | # Create the model. 160 | model = models.Model(point_cloud, point_cloud_transformed) 161 | 162 | # Check if output is right. The first is attention feature. 163 | self.assertEqual(model.outputs[0].shape[1:], (number_of_points, features)) 164 | 165 | 166 | def test_model(self): 167 | 168 | # Create the model. 169 | model = GAPNet() 170 | model.summary() 171 | 172 | for x in model.non_trainable_weights: 173 | if "normalization" not in str(x): 174 | print(x) 175 | 176 | 177 | if __name__ == '__main__': 178 | unittest.main() 179 | -------------------------------------------------------------------------------- /todos.md: -------------------------------------------------------------------------------- 1 | - [X] add prefix to training 2 | - [ ] visualize attention 3 | - [X] add batch normalization to GAP layer 4 | - [X] make batch normalization deactivatable 5 | - [ ] replace Dense with CNN 6 | - [ ] in asserts: replace constants with variables 7 | - [ ] what would batch normalization do in our case? 8 | - [X] draw picture of transform layer 9 | - [X] draw picture of GAPNet 10 | - [X] implement transform layer 11 | - [X] implement GAPNet 12 | - [ ] let model yield transformation output 13 | - [ ] see if their CNN stuff is good or bad 14 | - [X] what about the (undocumented) skip connection point_cloud_expanded? 15 | - [ ] https://www.tensorflow.org/guide/upgrade 16 | - [ ] what is this? https://github.com/TianzhongSong/PointNet-Keras/blob/master/callbacks.py Change of LR 17 | - [ ] make sure that trainable parameters are properly displayed in summary() 18 | - [ ] implement load/save 19 | - [ ] write detailed documentation 20 | - [ ] consider commenting on GAPNet implementation details, thoughts and findings 21 | - [ ] consider cleaning up asserts in layers and models 22 | - [X] what are the non_trainable parameters? 23 | - [X] ensure that summary has all the output shapes... how? 24 | - [ ] 25 | - [ ] 26 | - [ ] 27 | - [ ] 28 | 29 | 30 | - [X] write README, with references 31 | - [X] publish 32 | - [X] make assertions for all those shapes 33 | - [X] make assertions for second attention single and multi 34 | - [X] consider squeezing attention coefficients in GAP 35 | - [X] what is tf.tile(input_feature, [1, 1, k, 1]) in gat_layers? 
36 | - [X] use this for training https://github.com/TianzhongSong/PointNet-Keras 37 | - [X] make build model work 38 | - [X] get attentions out of multihead attention 39 | - [X] use 1024 instead of 2048 as number_of_points 40 | - [X] make assertions for first attention single and multi 41 | - [X] move attention to own layer 42 | - [X] move multihead attention to own layer 43 | -------------------------------------------------------------------------------- /train_cls.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.callbacks import ModelCheckpoint 3 | from tensorflow.keras.optimizers import Adam 4 | from data_loader import DataGenerator 5 | from schedules import onetenth_50_75 6 | import os 7 | import matplotlib 8 | matplotlib.use('AGG') 9 | import matplotlib.pyplot as plt 10 | import sys 11 | from model_cls import create_pointnet 12 | from gapnet.models import GAPNet 13 | import shutil 14 | 15 | if tf.__version__.startswith("2"): 16 | tf.executing_eagerly() 17 | 18 | model_name = "gapnet" 19 | training_name = "04-first_full_network" 20 | 21 | 22 | def main(): 23 | 24 | # Check command line arguments. 25 | #if len(sys.argv) != 2 or sys.argv[1] not in model_names: 26 | # print("Must provide name of model.") 27 | # print("Options: " + " ".join(model_names)) 28 | # exit(0) 29 | #model_name = sys.argv[1] 30 | 31 | # Data preparation. 32 | nb_classes = 40 33 | train_file = './ModelNet40/ply_data_train.h5' 34 | test_file = './ModelNet40/ply_data_test.h5' 35 | 36 | # Hyperparameters. 37 | number_of_points = 1024 38 | epochs = 100 39 | batch_size = 32 40 | 41 | # Data generators for training and validation. 42 | train = DataGenerator(train_file, batch_size, number_of_points, nb_classes, train=True) 43 | val = DataGenerator(test_file, batch_size, number_of_points, nb_classes, train=False) 44 | 45 | # Create the model. 46 | if model_name == "pointnet": 47 | model = create_pointnet(number_of_points, nb_classes) 48 | elif model_name == "gapnet": 49 | model = GAPNet() 50 | model.summary() 51 | 52 | # Ensure output paths. 53 | output_path = "logs" 54 | if not os.path.exists(output_path): 55 | os.mkdir(output_path) 56 | output_path = os.path.join(output_path, model_name) 57 | if not os.path.exists(output_path): 58 | os.mkdir(output_path) 59 | output_path = os.path.join(output_path, training_name) 60 | if os.path.exists(output_path): 61 | shutil.rmtree(output_path) 62 | os.mkdir(output_path) 63 | 64 | 65 | # Compile the model. 66 | lr = 0.0001 67 | adam = Adam(lr=lr) 68 | model.compile( 69 | optimizer=adam, 70 | loss='categorical_crossentropy', 71 | metrics=['accuracy'] 72 | ) 73 | 74 | # Checkpoint callback. 75 | checkpoint = ModelCheckpoint( 76 | os.path.join(output_path, "model.h5"), 77 | monitor="val_acc", 78 | save_weights_only=True, 79 | save_best_only=True, 80 | verbose=1 81 | ) 82 | 83 | # Logging training progress with tensorboard. 84 | tensorboard_callback = tf.keras.callbacks.TensorBoard( 85 | log_dir=output_path, 86 | histogram_freq=0, 87 | batch_size=32, 88 | write_graph=True, 89 | write_grads=False, 90 | write_images=True, 91 | embeddings_freq=0, 92 | embeddings_layer_names=None, 93 | embeddings_metadata=None, 94 | embeddings_data=None, 95 | update_freq="epoch" 96 | ) 97 | 98 | callbacks = [] 99 | #callbacks.append(checkpoint) 100 | callbacks.append(onetenth_50_75(lr)) 101 | callbacks.append(tensorboard_callback) 102 | 103 | # Train the model. 
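    # The prepared ModelNet40 split has 9840 training and 2468 test samples,
    # so the step counts below cover each split once per epoch.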
104 | history = model.fit_generator( 105 | train.generator(), 106 | steps_per_epoch=9840 // batch_size, 107 | epochs=epochs, 108 | validation_data=val.generator(), 109 | validation_steps=2468 // batch_size, 110 | callbacks=callbacks, 111 | verbose=1 112 | ) 113 | 114 | # Save history and model. 115 | plot_history(history, output_path) 116 | save_history(history, output_path) 117 | model.save_weights(os.path.join(output_path, "model_weights.h5")) 118 | 119 | 120 | def plot_history(history, result_dir): 121 | if "acc" in history.history: 122 | plt.plot(history.history['acc'], marker='.') 123 | plt.plot(history.history['val_acc'], marker='.') 124 | elif "accuracy" in history.history: 125 | plt.plot(history.history['accuracy'], marker='.') 126 | plt.plot(history.history['val_accuracy'], marker='.') 127 | plt.title('model accuracy') 128 | plt.xlabel('epoch') 129 | plt.ylabel('accuracy') 130 | plt.grid() 131 | plt.legend(['acc', 'val_acc'], loc='lower right') 132 | plt.savefig(os.path.join(result_dir, 'model_accuracy.png')) 133 | plt.close() 134 | 135 | plt.plot(history.history['loss'], marker='.') 136 | plt.plot(history.history['val_loss'], marker='.') 137 | plt.title('model loss') 138 | plt.xlabel('epoch') 139 | plt.ylabel('loss') 140 | plt.grid() 141 | plt.legend(['loss', 'val_loss'], loc='upper right') 142 | plt.savefig(os.path.join(result_dir, 'model_loss.png')) 143 | plt.close() 144 | 145 | 146 | def save_history(history, result_dir): 147 | if "acc" in history.history: 148 | acc = history.history['acc'] 149 | val_acc = history.history['val_acc'] 150 | elif "accuracy" in history.history: 151 | acc = history.history['accuracy'] 152 | val_acc = history.history['val_accuracy'] 153 | loss = history.history['loss'] 154 | val_loss = history.history['val_loss'] 155 | nb_epoch = len(acc) 156 | 157 | with open(os.path.join(result_dir, 'result.txt'), 'w') as fp: 158 | fp.write('epoch\tloss\tacc\tval_loss\tval_acc\n') 159 | for i in range(nb_epoch): 160 | fp.write('{}\t{}\t{}\t{}\t{}\n'.format( 161 | i, loss[i], acc[i], val_loss[i], val_acc[i])) 162 | fp.close() 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | --------------------------------------------------------------------------------