├── LICENSE ├── README.md ├── images ├── airplane.gif ├── chair.gif └── cover.gif ├── nbs ├── PointNetClass.ipynb └── PointNetSeg.ipynb ├── requirements.txt ├── source ├── args.py ├── dataset.py ├── model.py └── utils.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Open In Colab 2 | 3 | # PointNet 4 | PyTorch implementation of "PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation" https://arxiv.org/abs/1612.00593 5 | 6 | 7 | pointnet 8 | 9 | Key points of the implementation are explained in detail in [this](https://towardsdatascience.com/deep-learning-on-point-clouds-implementing-pointnet-in-google-colab-1fd65cd3a263) Medium article. 10 | 11 | ## Classification dataset 12 | This code implements object classification on the [ModelNet10](https://modelnet.cs.princeton.edu) dataset. 13 | 14 | As in the original paper, we sample 1024 points from object surfaces, weighting each face by its area. We then normalize the object to a unit sphere and add Gaussian noise. Here is an example of a network input representing a chair: 15 | 16 | matching points 17 | 18 | You can download the dataset by following [this link](https://drive.google.com/open?id=12Mv19pQ84VO8Av50hUXTixSxd5NDjeEB) 19 | 20 | ## Classification performance 21 | 22 | | Class (Accuracy) | Overall | Bathtub | Bed | Chair | Desk | Dresser | Monitor | Night stand | Sofa | Table | Toilet | 23 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 24 | | ModelNet10 | 82.0% | 93.4% | 92.0% | 97.2% | 81.5% | 71.0% | 89.4% | 56.0% | 86.9% | 93.4% | 95.9% | 25 | 26 | A pretrained model is available [here](https://drive.google.com/open?id=1nDG0maaqoTkRkVsOLtUAR9X3kn__LMSL) 27 | 28 | ## Usage 29 | * The first and best option is to run the notebook `/nbs/PointNetClass.ipynb`, with comments and visualizations, in Google Colab.
30 | * The second option is to clone the repository to a local machine and run the model with default parameters: 31 | ```bash 32 | git clone https://github.com/nikitakaraevv/pointnet 33 | wget http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip 34 | unzip -q ModelNet10.zip 35 | cd pointnet/ 36 | python train.py 37 | ``` 38 | If for some reason it doesn't work, you can install the requirements before running `python train.py`: 39 | ```bash 40 | conda create -n env python=3.7 41 | conda activate env 42 | pip install -r requirements.txt 43 | ``` 44 | Another example of running the model, this time with custom arguments: 45 | 46 | `python train.py --root_dir ../ModelNet10/ --batch_size 16 --lr 0.0001 --epochs 30 --save_model_path ./ckpts` 47 | 48 | ## Part segmentation dataset 49 | The dataset includes 2609 point clouds representing different airplanes, where every point has 3D coordinates and a label indicating which part of the airplane it belongs to. Because the point clouds contain different numbers of points, and PyTorch needs samples of equal size to form a batch tensor, we uniformly sample 2000 points from every point cloud. 50 | 51 | You can download the dataset by following [this link](https://drive.google.com/drive/u/1/folders/1Z5XA4uJpA86ky0qV1AVgA_G1_ETkq9En) 52 | 53 | ## Part segmentation performance 54 | The resulting accuracy on the validation dataset is 88%. In the original paper, the part segmentation accuracy reported for this object category (airplanes) is 83.4%. 55 | 56 | ## Usage 57 | This part of the project is still in development. However, you can already run the notebook `/nbs/PointNetSeg.ipynb` in Colab. 58 | 59 | matching points 60 | 61 | ## Authors 62 | * [Nikita Karaev](https://github.com/nikitakaraevv) 63 | * [Irina Nikulina](https://github.com/washburn125) 64 | -------------------------------------------------------------------------------- /images/airplane.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikitakaraevv/pointnet/256437e9ab27b197347464cecff87121c5c824ff/images/airplane.gif -------------------------------------------------------------------------------- /images/chair.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikitakaraevv/pointnet/256437e9ab27b197347464cecff87121c5c824ff/images/chair.gif -------------------------------------------------------------------------------- /images/cover.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikitakaraevv/pointnet/256437e9ab27b197347464cecff87121c5c824ff/images/cover.gif -------------------------------------------------------------------------------- /nbs/PointNetSeg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "language_info": { 10 | "codemirror_mode": { 11 | "name": "ipython", 12 | "version": 3 13 | }, 14 | "file_extension": ".py", 15 | "mimetype": "text/x-python", 16 | "name": "python", 17 | "nbconvert_exporter": "python", 18 | "pygments_lexer": "ipython3", 19 | "version": "3.7.1" 20 | }, 21 | "colab": { 22 | "name": "PointNetSeg.ipynb", 23 | "provenance": [], 24 | "collapsed_sections": [ 25 | "caWQIszA8r-H" 26 | ], 27 | "machine_shape": "hm", 28 | "include_colab_link": true 29 | }, 30 | "accelerator":
"GPU" 31 | }, 32 | "cells": [ 33 | { 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "id": "view-in-github", 37 | "colab_type": "text" 38 | }, 39 | "source": [ 40 | "\"Open" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "metadata": { 46 | "id": "KuEIVQsJGUbQ", 47 | "colab_type": "code", 48 | "outputId": "ff151e25-a584-4c7a-bb2b-634acedd0952", 49 | "colab": { 50 | "base_uri": "https://localhost:8080/", 51 | "height": 35 52 | } 53 | }, 54 | "source": [ 55 | "from google.colab import drive\n", 56 | "drive.mount('/content/gdrive', force_remount=True)\n", 57 | "root_dir = \"/content/gdrive/My Drive/PointNet3D\"\n" 58 | ], 59 | "execution_count": 0, 60 | "outputs": [ 61 | { 62 | "output_type": "stream", 63 | "text": [ 64 | "Mounted at /content/gdrive\n" 65 | ], 66 | "name": "stdout" 67 | } 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "metadata": { 73 | "id": "PLGjzwFkGUbV", 74 | "colab_type": "code", 75 | "outputId": "a19df6db-1edd-4c03-8c1a-3e549eaf08de", 76 | "colab": { 77 | "base_uri": "https://localhost:8080/", 78 | "height": 124 79 | } 80 | }, 81 | "source": [ 82 | "!pip install path.py;\n", 83 | "from path import Path\n", 84 | "import sys\n", 85 | "sys.path.append(root_dir)" 86 | ], 87 | "execution_count": 0, 88 | "outputs": [ 89 | { 90 | "output_type": "stream", 91 | "text": [ 92 | "Requirement already satisfied: path.py in /usr/local/lib/python3.6/dist-packages (12.4.0)\n", 93 | "Requirement already satisfied: path<13.2 in /usr/local/lib/python3.6/dist-packages (from path.py) (13.1.0)\n", 94 | "Requirement already satisfied: importlib-metadata>=0.5; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from path<13.2->path.py) (1.3.0)\n", 95 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.5; python_version < \"3.8\"->path<13.2->path.py) (0.6.0)\n", 96 | "Requirement already satisfied: more-itertools in /usr/local/lib/python3.6/dist-packages (from zipp>=0.5->importlib-metadata>=0.5; python_version < \"3.8\"->path<13.2->path.py) (8.0.2)\n" 97 | ], 98 | "name": "stdout" 99 | } 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "metadata": { 105 | "id": "LB7sxVUtGUbY", 106 | "colab_type": "code", 107 | "colab": {} 108 | }, 109 | "source": [ 110 | "import plotly.graph_objects as go\n", 111 | "import numpy as np\n", 112 | "import scipy.spatial.distance\n", 113 | "import math\n", 114 | "import random\n", 115 | "import utils\n", 116 | "\n", 117 | "\n", 118 | "class10_dir = \"/datasets/ModelNet10txt/ModelNet10/ModelNet10/\"" 119 | ], 120 | "execution_count": 0, 121 | "outputs": [] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "metadata": { 126 | "id": "31PRdnvMOjLd", 127 | "colab_type": "code", 128 | "colab": {} 129 | }, 130 | "source": [ 131 | "import random\n", 132 | "\n", 133 | "def read_pts(file):\n", 134 | " verts = np.genfromtxt(file)\n", 135 | " return utils.cent_norm(verts)\n", 136 | " #return verts\n", 137 | "\n", 138 | "def read_seg(file):\n", 139 | " verts = np.genfromtxt(file, dtype= (int))\n", 140 | " return verts\n", 141 | "\n", 142 | "def sample_2000(pts, pts_cat): \n", 143 | " res1 = np.concatenate((pts,np.reshape(pts_cat, (pts_cat.shape[0], 1))), axis= 1)\n", 144 | " res = np.asarray(random.choices(res1, weights=None, cum_weights=None, k=2000))\n", 145 | " images = res[:, 0:3]\n", 146 | " categories = res[:, 3]\n", 147 | " categories-=np.ones(categories.shape)\n", 148 | " return images, categories" 149 | ], 150 | "execution_count": 0, 151 | "outputs": [] 152 | }, 153 | 
{ 154 | "cell_type": "markdown", 155 | "metadata": { 156 | "id": "4KNH5UtbWWok", 157 | "colab_type": "text" 158 | }, 159 | "source": [ 160 | "## Model" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "yOOrEYSnWV7f", 167 | "colab_type": "code", 168 | "colab": {} 169 | }, 170 | "source": [ 171 | "import torch\n", 172 | "import torch.nn as nn\n", 173 | "import numpy as np\n", 174 | "import torch.nn.functional as F\n", 175 | "\n", 176 | "class Tnet(nn.Module):\n", 177 | " def __init__(self, k=3):\n", 178 | " super().__init__()\n", 179 | " self.k=k\n", 180 | " self.conv1 = nn.Conv1d(k,64,1)\n", 181 | " self.conv2 = nn.Conv1d(64,128,1)\n", 182 | " self.conv3 = nn.Conv1d(128,1024,1)\n", 183 | " self.fc1 = nn.Linear(1024,512)\n", 184 | " self.fc2 = nn.Linear(512,256)\n", 185 | " self.fc3 = nn.Linear(256,k*k)\n", 186 | "\n", 187 | " self.bn1 = nn.BatchNorm1d(64)\n", 188 | " self.bn2 = nn.BatchNorm1d(128)\n", 189 | " self.bn3 = nn.BatchNorm1d(1024)\n", 190 | " self.bn4 = nn.BatchNorm1d(512)\n", 191 | " self.bn5 = nn.BatchNorm1d(256)\n", 192 | " \n", 193 | "\n", 194 | " def forward(self, input):\n", 195 | " # input.shape == (bs,n,3)\n", 196 | " bs = input.size(0)\n", 197 | " xb = F.relu(self.bn1(self.conv1(input)))\n", 198 | " xb = F.relu(self.bn2(self.conv2(xb)))\n", 199 | " xb = F.relu(self.bn3(self.conv3(xb)))\n", 200 | " pool = nn.MaxPool1d(xb.size(-1))(xb)\n", 201 | " flat = nn.Flatten(1)(pool)\n", 202 | " xb = F.relu(self.bn4(self.fc1(flat)))\n", 203 | " xb = F.relu(self.bn5(self.fc2(xb)))\n", 204 | " \n", 205 | " #initialize as identity\n", 206 | " init = torch.eye(self.k, requires_grad=True).repeat(bs,1,1)\n", 207 | " if xb.is_cuda:\n", 208 | " init=init.cuda()\n", 209 | " matrix = self.fc3(xb).view(-1,self.k,self.k) + init\n", 210 | " return matrix\n", 211 | "\n", 212 | "\n", 213 | "class Transform(nn.Module):\n", 214 | " def __init__(self):\n", 215 | " super().__init__()\n", 216 | " self.input_transform = Tnet(k=3)\n", 217 | " self.feature_transform = Tnet(k=128)\n", 218 | " self.fc1 = nn.Conv1d(3,64,1)\n", 219 | " self.fc2 = nn.Conv1d(64,128,1) \n", 220 | " self.fc3 = nn.Conv1d(128,128,1)\n", 221 | " self.fc4 = nn.Conv1d(128,512,1)\n", 222 | " self.fc5 = nn.Conv1d(512,2048,1)\n", 223 | "\n", 224 | " \n", 225 | " self.bn1 = nn.BatchNorm1d(64)\n", 226 | " self.bn2 = nn.BatchNorm1d(128)\n", 227 | " self.bn3 = nn.BatchNorm1d(128)\n", 228 | " self.bn4 = nn.BatchNorm1d(512)\n", 229 | " self.bn5 = nn.BatchNorm1d(2048)\n", 230 | "\n", 231 | " def forward(self, input):\n", 232 | " n_pts = input.size()[2]\n", 233 | " matrix3x3 = self.input_transform(input)\n", 234 | " xb = torch.bmm(torch.transpose(input,1,2), matrix3x3).transpose(1,2)\n", 235 | " outs = []\n", 236 | " \n", 237 | " out1 = F.relu(self.bn1(self.fc1(xb)))\n", 238 | " outs.append(out1)\n", 239 | " out2 = F.relu(self.bn2(self.fc2(out1)))\n", 240 | " outs.append(out2)\n", 241 | " out3 = F.relu(self.bn3(self.fc3(out2)))\n", 242 | " outs.append(out3)\n", 243 | " matrix128x128 = self.feature_transform(out3)\n", 244 | " \n", 245 | " out4 = torch.bmm(torch.transpose(out3,1,2), matrix128x128).transpose(1,2) \n", 246 | " outs.append(out4)\n", 247 | " out5 = F.relu(self.bn4(self.fc4(out4)))\n", 248 | " outs.append(out5)\n", 249 | " \n", 250 | " xb = self.bn5(self.fc5(out5))\n", 251 | " \n", 252 | " xb = nn.MaxPool1d(xb.size(-1))(xb)\n", 253 | " out6 = nn.Flatten(1)(xb).repeat(n_pts,1,1).transpose(0,2).transpose(0,1)#.repeat(1, 1, n_pts)\n", 254 | " outs.append(out6)\n", 255 | " \n", 256 | " \n", 257 | " return 
outs, matrix3x3, matrix128x128\n", 258 | "\n", 259 | "\n", 260 | "class PointNetSeg(nn.Module):\n", 261 | " def __init__(self, classes = 10):\n", 262 | " super().__init__()\n", 263 | " self.transform = Transform()\n", 264 | "\n", 265 | " self.fc1 = nn.Conv1d(3008,256,1) \n", 266 | " self.fc2 = nn.Conv1d(256,256,1) \n", 267 | " self.fc3 = nn.Conv1d(256,128,1) \n", 268 | " self.fc4 = nn.Conv1d(128,4,1) \n", 269 | " \n", 270 | "\n", 271 | " self.bn1 = nn.BatchNorm1d(256)\n", 272 | " self.bn2 = nn.BatchNorm1d(256)\n", 273 | " \n", 274 | " self.bn3 = nn.BatchNorm1d(128)\n", 275 | " self.bn4 = nn.BatchNorm1d(4)\n", 276 | " \n", 277 | " self.logsoftmax = nn.LogSoftmax(dim=1)\n", 278 | " \n", 279 | "\n", 280 | " def forward(self, input):\n", 281 | " inputs, matrix3x3, matrix128x128 = self.transform(input)\n", 282 | " stack = torch.cat(inputs,1)\n", 283 | " \n", 284 | " xb = F.relu(self.bn1(self.fc1(stack)))\n", 285 | " \n", 286 | " xb = F.relu(self.bn2(self.fc2(xb)))\n", 287 | " \n", 288 | " xb = F.relu(self.bn3(self.fc3(xb)))\n", 289 | " \n", 290 | " output = F.relu(self.bn4(self.fc4(xb)))\n", 291 | " \n", 292 | " return self.logsoftmax(output), matrix3x3, matrix128x128\n", 293 | "\n" 294 | ], 295 | "execution_count": 0, 296 | "outputs": [] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": { 301 | "id": "HdXchtFBWZYG", 302 | "colab_type": "text" 303 | }, 304 | "source": [ 305 | "## Dataset" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "metadata": { 311 | "id": "ut50_1uQCFCc", 312 | "colab_type": "code", 313 | "colab": {} 314 | }, 315 | "source": [ 316 | "from __future__ import print_function, division\n", 317 | "import os\n", 318 | "import torch\n", 319 | "import pandas as pd\n", 320 | "from skimage import io, transform\n", 321 | "import numpy as np\n", 322 | "import matplotlib.pyplot as plt\n", 323 | "from torch.utils.data import Dataset, DataLoader\n", 324 | "from torchvision import transforms, utils\n", 325 | "from torch.utils.data.dataset import random_split\n", 326 | "import utils\n", 327 | "\n", 328 | "class Data(Dataset):\n", 329 | " \"\"\"Face Landmarks dataset.\"\"\"\n", 330 | "\n", 331 | " def __init__(self, root_dir, valid=False, transform=None):\n", 332 | " \n", 333 | " self.root_dir = root_dir\n", 334 | " self.files = []\n", 335 | " self.valid=valid\n", 336 | "\n", 337 | " newdir = root_dir + '/datasets/airplane_part_seg/02691156/expert_verified/points_label/'\n", 338 | "\n", 339 | " for file in os.listdir(newdir):\n", 340 | " o = {}\n", 341 | " o['category'] = newdir + file\n", 342 | " o['img_path'] = root_dir + '/datasets/airplane_part_seg/02691156/points/'+ file.replace('.seg', '.pts')\n", 343 | " self.files.append(o)\n", 344 | " \n", 345 | "\n", 346 | " def __len__(self):\n", 347 | " return len(self.files)\n", 348 | "\n", 349 | " def __getitem__(self, idx):\n", 350 | " img_path = self.files[idx]['img_path']\n", 351 | " category = self.files[idx]['category']\n", 352 | " with open(img_path, 'r') as f:\n", 353 | " image1 = read_pts(f)\n", 354 | " with open(category, 'r') as f: \n", 355 | " category1 = read_seg(f)\n", 356 | " image2, category2 = sample_2000(image1, category1)\n", 357 | " if not self.valid:\n", 358 | " theta = random.random()*360\n", 359 | " image2 = utils.rotation_z(utils.add_noise(image2), theta)\n", 360 | " \n", 361 | " return {'image': np.array(image2, dtype=\"float32\"), 'category': category2.astype(int)}\n" 362 | ], 363 | "execution_count": 0, 364 | "outputs": [] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "metadata": { 369 | 
"id": "1mUcFS3Uwci6", 370 | "colab_type": "code", 371 | "outputId": "4d10dda0-7c73-4007-ce71-de106f2a32a6", 372 | "colab": { 373 | "base_uri": "https://localhost:8080/", 374 | "height": 69 375 | } 376 | }, 377 | "source": [ 378 | "\n", 379 | "dset = Data(root_dir , transform=None)\n", 380 | "train_num = int(len(dset) * 0.95)\n", 381 | "val_num = int(len(dset) *0.05)\n", 382 | "if int(len(dset)) - train_num - val_num >0 :\n", 383 | " train_num = train_num + 1\n", 384 | "elif int(len(dset)) - train_num - val_num < 0:\n", 385 | " train_num = train_num -1\n", 386 | "#train_dataset, val_dataset = random_split(dset, [3000, 118])\n", 387 | "train_dataset, val_dataset = random_split(dset, [train_num, val_num])\n", 388 | "val_dataset.valid=True\n", 389 | "\n", 390 | "print('######### Dataset class created #########')\n", 391 | "print('Number of images: ', len(dset))\n", 392 | "print('Sample image shape: ', dset[0]['image'].shape)\n", 393 | "#print('Sample image points categories', dset[0]['category'], end='\\n\\n')\n", 394 | "\n", 395 | "train_loader = DataLoader(dataset=train_dataset, batch_size=64)\n", 396 | "val_loader = DataLoader(dataset=val_dataset, batch_size=64)\n", 397 | "\n", 398 | "#dataloader = torch.utils.data.DataLoader(dset, batch_size=4, shuffle=True, num_workers=4)" 399 | ], 400 | "execution_count": 0, 401 | "outputs": [ 402 | { 403 | "output_type": "stream", 404 | "text": [ 405 | "######### Dataset class created #########\n", 406 | "Number of images: 2690\n", 407 | "Sample image shape: (2000, 3)\n" 408 | ], 409 | "name": "stdout" 410 | } 411 | ] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "metadata": { 416 | "id": "gg9RjG7awgVK", 417 | "colab_type": "text" 418 | }, 419 | "source": [ 420 | "## Training loop" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "metadata": { 426 | "id": "bq9AUuN5WRxI", 427 | "colab_type": "code", 428 | "outputId": "a6c2e5eb-ee77-4f4e-f511-77acffab30fa", 429 | "colab": { 430 | "base_uri": "https://localhost:8080/", 431 | "height": 35 432 | } 433 | }, 434 | "source": [ 435 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 436 | "print(device)" 437 | ], 438 | "execution_count": 0, 439 | "outputs": [ 440 | { 441 | "output_type": "stream", 442 | "text": [ 443 | "cuda:0\n" 444 | ], 445 | "name": "stdout" 446 | } 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "metadata": { 452 | "id": "cqXW9-oJwEPm", 453 | "colab_type": "code", 454 | "colab": {} 455 | }, 456 | "source": [ 457 | "pointnet = PointNetSeg()" 458 | ], 459 | "execution_count": 0, 460 | "outputs": [] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "metadata": { 465 | "id": "6mA80v2ywHhw", 466 | "colab_type": "code", 467 | "colab": {} 468 | }, 469 | "source": [ 470 | "pointnet.to(device);" 471 | ], 472 | "execution_count": 0, 473 | "outputs": [] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "metadata": { 478 | "id": "JV09EA4_wJnR", 479 | "colab_type": "code", 480 | "colab": {} 481 | }, 482 | "source": [ 483 | "optimizer = torch.optim.Adam(pointnet.parameters(), lr=0.001)" 484 | ], 485 | "execution_count": 0, 486 | "outputs": [] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "metadata": { 491 | "id": "aDb9rPb_wPWj", 492 | "colab_type": "code", 493 | "colab": {} 494 | }, 495 | "source": [ 496 | "def pointnetloss(outputs, labels, m3x3, m128x128, alpha = 0.0001):\n", 497 | " criterion = torch.nn.NLLLoss()\n", 498 | " bs=outputs.size(0)\n", 499 | " id3x3 = torch.eye(3, requires_grad=True).repeat(bs,1,1)\n", 500 | " id128x128 = 
torch.eye(128, requires_grad=True).repeat(bs,1,1)\n", 501 | " if outputs.is_cuda:\n", 502 | " id3x3=id3x3.cuda()\n", 503 | " id128x128=id128x128.cuda()\n", 504 | " diff3x3 = id3x3-torch.bmm(m3x3,m3x3.transpose(1,2))\n", 505 | " diff128x128 = id128x128-torch.bmm(m128x128,m128x128.transpose(1,2))\n", 506 | " return criterion(outputs, labels) + alpha * (torch.norm(diff3x3)+torch.norm(diff128x128)) / float(bs)\n", 507 | " " 508 | ], 509 | "execution_count": 0, 510 | "outputs": [] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "metadata": { 515 | "id": "CgaPisZFwVzh", 516 | "colab_type": "code", 517 | "colab": {} 518 | }, 519 | "source": [ 520 | "def train(model, train_loader, val_loader=None, epochs=15, save=True):\n", 521 | " for epoch in range(epochs): \n", 522 | " pointnet.train()\n", 523 | " running_loss = 0.0\n", 524 | " for i, data in enumerate(train_loader, 0):\n", 525 | " inputs, labels = data['image'].to(device), data['category'].to(device)\n", 526 | " optimizer.zero_grad()\n", 527 | " outputs, m3x3, m64x64 = pointnet(inputs.transpose(1,2))\n", 528 | "\n", 529 | " loss = pointnetloss(outputs, labels, m3x3, m64x64)\n", 530 | " loss.backward()\n", 531 | " optimizer.step()\n", 532 | "\n", 533 | " # print statistics\n", 534 | " running_loss += loss.item()\n", 535 | " if i % 10 == 9: # print every 10 mini-batches\n", 536 | " print('[%d, %5d] loss: %.3f' %\n", 537 | " (epoch + 1, i + 1, running_loss / 10))\n", 538 | " running_loss = 0.0\n", 539 | "\n", 540 | " pointnet.eval()\n", 541 | " correct = total = 0\n", 542 | "\n", 543 | " # validation\n", 544 | " if val_loader:\n", 545 | " with torch.no_grad():\n", 546 | " for data in val_loader:\n", 547 | " inputs, labels = data['image'].to(device), data['category'].to(device)\n", 548 | " outputs, __, __ = pointnet(inputs.transpose(1,2))\n", 549 | " _, predicted = torch.max(outputs.data, 1)\n", 550 | " total += labels.size(0) * labels.size(1) ##\n", 551 | " correct += (predicted == labels).sum().item()\n", 552 | " val_acc = 100 * correct / total\n", 553 | " print('Valid accuracy: %d %%' % val_acc)\n", 554 | "\n", 555 | " # save the model\n", 556 | " if save:\n", 557 | " torch.save(pointnet.state_dict(), root_dir+\"/modelsSeg/\"+str(epoch)+\"_\"+str(val_acc))\n" 558 | ], 559 | "execution_count": 0, 560 | "outputs": [] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "metadata": { 565 | "id": "3jjVYFmSv9lu", 566 | "colab_type": "code", 567 | "outputId": "0c443233-6750-4382-b790-333ac78dcf98", 568 | "colab": { 569 | "base_uri": "https://localhost:8080/", 570 | "height": 1000 571 | } 572 | }, 573 | "source": [ 574 | "train(pointnet, train_loader, val_loader, save=True)\n" 575 | ], 576 | "execution_count": 0, 577 | "outputs": [ 578 | { 579 | "output_type": "stream", 580 | "text": [ 581 | "[1, 10] loss: 1.203\n", 582 | "[1, 20] loss: 0.924\n", 583 | "[1, 30] loss: 0.844\n", 584 | "[1, 40] loss: 0.800\n", 585 | "Valid accuracy: 78 %\n", 586 | "[2, 10] loss: 0.769\n", 587 | "[2, 20] loss: 0.741\n", 588 | "[2, 30] loss: 0.732\n", 589 | "[2, 40] loss: 0.729\n", 590 | "Valid accuracy: 82 %\n", 591 | "[3, 10] loss: 0.709\n", 592 | "[3, 20] loss: 0.685\n", 593 | "[3, 30] loss: 0.680\n", 594 | "[3, 40] loss: 0.676\n", 595 | "Valid accuracy: 85 %\n", 596 | "[4, 10] loss: 0.663\n", 597 | "[4, 20] loss: 0.642\n", 598 | "[4, 30] loss: 0.637\n", 599 | "[4, 40] loss: 0.636\n", 600 | "Valid accuracy: 87 %\n", 601 | "[5, 10] loss: 0.626\n", 602 | "[5, 20] loss: 0.617\n", 603 | "[5, 30] loss: 0.610\n", 604 | "[5, 40] loss: 0.613\n", 605 | "Valid accuracy: 86 %\n", 
606 | "[6, 10] loss: 0.604\n", 607 | "[6, 20] loss: 0.587\n", 608 | "[6, 30] loss: 0.580\n", 609 | "[6, 40] loss: 0.583\n", 610 | "Valid accuracy: 86 %\n", 611 | "[7, 10] loss: 0.574\n", 612 | "[7, 20] loss: 0.563\n", 613 | "[7, 30] loss: 0.558\n", 614 | "[7, 40] loss: 0.565\n", 615 | "Valid accuracy: 86 %\n", 616 | "[8, 10] loss: 0.553\n", 617 | "[8, 20] loss: 0.539\n", 618 | "[8, 30] loss: 0.538\n", 619 | "[8, 40] loss: 0.543\n", 620 | "Valid accuracy: 87 %\n", 621 | "[9, 10] loss: 0.533\n", 622 | "[9, 20] loss: 0.525\n", 623 | "[9, 30] loss: 0.516\n", 624 | "[9, 40] loss: 0.521\n", 625 | "Valid accuracy: 87 %\n", 626 | "[10, 10] loss: 0.522\n", 627 | "[10, 20] loss: 0.506\n", 628 | "[10, 30] loss: 0.503\n", 629 | "[10, 40] loss: 0.507\n", 630 | "Valid accuracy: 87 %\n", 631 | "[11, 10] loss: 0.501\n", 632 | "[11, 20] loss: 0.495\n", 633 | "[11, 30] loss: 0.485\n", 634 | "[11, 40] loss: 0.497\n", 635 | "Valid accuracy: 88 %\n", 636 | "[12, 10] loss: 0.484\n", 637 | "[12, 20] loss: 0.477\n", 638 | "[12, 30] loss: 0.474\n", 639 | "[12, 40] loss: 0.477\n", 640 | "Valid accuracy: 88 %\n", 641 | "[13, 10] loss: 0.472\n", 642 | "[13, 20] loss: 0.458\n", 643 | "[13, 30] loss: 0.456\n", 644 | "[13, 40] loss: 0.464\n", 645 | "Valid accuracy: 87 %\n", 646 | "[14, 10] loss: 0.454\n", 647 | "[14, 20] loss: 0.452\n", 648 | "[14, 30] loss: 0.446\n", 649 | "[14, 40] loss: 0.458\n", 650 | "Valid accuracy: 87 %\n", 651 | "[15, 10] loss: 0.444\n", 652 | "[15, 20] loss: 0.433\n", 653 | "[15, 30] loss: 0.432\n", 654 | "[15, 40] loss: 0.440\n", 655 | "Valid accuracy: 88 %\n" 656 | ], 657 | "name": "stdout" 658 | } 659 | ] 660 | }, 661 | { 662 | "cell_type": "markdown", 663 | "metadata": { 664 | "id": "VeUZqen5GlKr", 665 | "colab_type": "text" 666 | }, 667 | "source": [ 668 | "## test" 669 | ] 670 | }, 671 | { 672 | "cell_type": "markdown", 673 | "metadata": { 674 | "id": "mbBNmdqgGj5-", 675 | "colab_type": "text" 676 | }, 677 | "source": [ 678 | "" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "metadata": { 684 | "id": "Xsk9nSDAI3ba", 685 | "colab_type": "code", 686 | "outputId": "4467b734-e3db-4d48-dce3-b7f269cbaed5", 687 | "colab": { 688 | "base_uri": "https://localhost:8080/", 689 | "height": 867 690 | } 691 | }, 692 | "source": [ 693 | "pointnet = PointNetSeg()\n", 694 | "pointnet.load_state_dict(torch.load(root_dir+\"/modelsSeg/\"+\"14_88.01940298507462\"))\n", 695 | "pointnet.eval()" 696 | ], 697 | "execution_count": 0, 698 | "outputs": [ 699 | { 700 | "output_type": "execute_result", 701 | "data": { 702 | "text/plain": [ 703 | "PointNetSeg(\n", 704 | " (transform): Transform(\n", 705 | " (input_transform): Tnet(\n", 706 | " (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))\n", 707 | " (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", 708 | " (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))\n", 709 | " (fc1): Linear(in_features=1024, out_features=512, bias=True)\n", 710 | " (fc2): Linear(in_features=512, out_features=256, bias=True)\n", 711 | " (fc3): Linear(in_features=256, out_features=9, bias=True)\n", 712 | " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 713 | " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 714 | " (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 715 | " (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 716 | " (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", 717 | " )\n", 718 | " (feature_transform): Tnet(\n", 719 | " (conv1): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n", 720 | " (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", 721 | " (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))\n", 722 | " (fc1): Linear(in_features=1024, out_features=512, bias=True)\n", 723 | " (fc2): Linear(in_features=512, out_features=256, bias=True)\n", 724 | " (fc3): Linear(in_features=256, out_features=16384, bias=True)\n", 725 | " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 726 | " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 727 | " (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 728 | " (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 729 | " (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 730 | " )\n", 731 | " (fc1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))\n", 732 | " (fc2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", 733 | " (fc3): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n", 734 | " (fc4): Conv1d(128, 512, kernel_size=(1,), stride=(1,))\n", 735 | " (fc5): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))\n", 736 | " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 737 | " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 738 | " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 739 | " (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 740 | " (bn5): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 741 | " )\n", 742 | " (fc1): Conv1d(3008, 256, kernel_size=(1,), stride=(1,))\n", 743 | " (fc2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", 744 | " (fc3): Conv1d(256, 128, kernel_size=(1,), stride=(1,))\n", 745 | " (fc4): Conv1d(128, 4, kernel_size=(1,), stride=(1,))\n", 746 | " (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 747 | " (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 748 | " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 749 | " (bn4): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 750 | " (logsoftmax): LogSoftmax()\n", 751 | ")" 752 | ] 753 | }, 754 | "metadata": { 755 | "tags": [] 756 | }, 757 | "execution_count": 11 758 | } 759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "metadata": { 764 | "id": "zoE5fRX8GnWR", 765 | "colab_type": "code", 766 | "outputId": "901bdfb2-0445-4497-e772-e2f482b98f12", 767 | "colab": { 768 | "base_uri": "https://localhost:8080/", 769 | "height": 139 770 | } 771 | }, 772 | "source": [ 773 | "batch = next(iter(val_loader))\n", 774 | "pred = pointnet(batch['image'].transpose(1,2))\n", 775 | "pred_np = np.array(torch.argmax(pred[0],1));\n", 776 | "pred_np\n", 777 | "\n" 778 | ], 779 | "execution_count": 0, 780 | "outputs": [ 781 | { 782 | "output_type": "execute_result", 783 | "data": { 784 | "text/plain": [ 785 | "array([[1, 1, 0, ..., 0, 0, 1],\n", 786 | " [1, 2, 1, ..., 1, 2, 1],\n", 787 | " [0, 0, 2, ..., 0, 1, 3],\n", 788 | " ...,\n", 789 | " [1, 0, 1, ..., 3, 2, 1],\n", 790 | " [2, 3, 0, ..., 0, 1, 3],\n", 791 | " [1, 1, 1, ..., 0, 0, 1]])" 
792 | ] 793 | }, 794 | "metadata": { 795 | "tags": [] 796 | }, 797 | "execution_count": 22 798 | } 799 | ] 800 | }, 801 | { 802 | "cell_type": "code", 803 | "metadata": { 804 | "id": "3ISvwR4RL_eT", 805 | "colab_type": "code", 806 | "outputId": "9695bbd3-e74c-41a1-9846-ad438f413a2f", 807 | "colab": { 808 | "base_uri": "https://localhost:8080/", 809 | "height": 35 810 | } 811 | }, 812 | "source": [ 813 | "batch['image'][0].shape" 814 | ], 815 | "execution_count": 0, 816 | "outputs": [ 817 | { 818 | "output_type": "execute_result", 819 | "data": { 820 | "text/plain": [ 821 | "torch.Size([2000, 3])" 822 | ] 823 | }, 824 | "metadata": { 825 | "tags": [] 826 | }, 827 | "execution_count": 61 828 | } 829 | ] 830 | }, 831 | { 832 | "cell_type": "code", 833 | "metadata": { 834 | "id": "MQXXGh3GJSRf", 835 | "colab_type": "code", 836 | "outputId": "23c684cc-7a4f-467b-a8c6-e916bae70107", 837 | "colab": { 838 | "base_uri": "https://localhost:8080/", 839 | "height": 139 840 | } 841 | }, 842 | "source": [ 843 | "pred_np==np.array(batch['category'])" 844 | ], 845 | "execution_count": 0, 846 | "outputs": [ 847 | { 848 | "output_type": "execute_result", 849 | "data": { 850 | "text/plain": [ 851 | "array([[ True, True, True, ..., True, True, True],\n", 852 | " [ True, True, True, ..., True, True, True],\n", 853 | " [ True, True, True, ..., False, True, True],\n", 854 | " ...,\n", 855 | " [ True, True, True, ..., True, True, True],\n", 856 | " [ True, True, True, ..., True, True, True],\n", 857 | " [ True, True, True, ..., False, True, True]])" 858 | ] 859 | }, 860 | "metadata": { 861 | "tags": [] 862 | }, 863 | "execution_count": 23 864 | } 865 | ] 866 | }, 867 | { 868 | "cell_type": "code", 869 | "metadata": { 870 | "id": "60Mr2Bp7O9xu", 871 | "colab_type": "code", 872 | "colab": {} 873 | }, 874 | "source": [ 875 | "acc = (pred_np==np.array(batch['category']))" 876 | ], 877 | "execution_count": 0, 878 | "outputs": [] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "metadata": { 883 | "id": "Jc1jLPj-PBo-", 884 | "colab_type": "code", 885 | "colab": {} 886 | }, 887 | "source": [ 888 | "resulting_acc = np.sum(acc, axis=1) / 2000" 889 | ], 890 | "execution_count": 0, 891 | "outputs": [] 892 | }, 893 | { 894 | "cell_type": "code", 895 | "metadata": { 896 | "id": "w8NwSbbz8jQ-", 897 | "colab_type": "code", 898 | "outputId": "eba05278-cf1a-425b-8462-7684e7a50101", 899 | "colab": { 900 | "base_uri": "https://localhost:8080/", 901 | "height": 156 902 | } 903 | }, 904 | "source": [ 905 | "resulting_acc" 906 | ], 907 | "execution_count": 0, 908 | "outputs": [ 909 | { 910 | "output_type": "execute_result", 911 | "data": { 912 | "text/plain": [ 913 | "array([0.863 , 0.906 , 0.91 , 0.867 , 0.9165, 0.882 , 0.9055, 0.8445,\n", 914 | " 0.8405, 0.9235, 0.8865, 0.8855, 0.903 , 0.884 , 0.8945, 0.85 ,\n", 915 | " 0.9125, 0.822 , 0.9345, 0.895 , 0.9135, 0.9395, 0.9385, 0.9165,\n", 916 | " 0.8865, 0.848 , 0.8765, 0.9105, 0.8805, 0.83 , 0.852 , 0.9225,\n", 917 | " 0.906 , 0.7705, 0.883 , 0.785 , 0.811 , 0.8565, 0.866 , 0.868 ,\n", 918 | " 0.7855, 0.7305, 0.9155, 0.8915, 0.9065, 0.805 , 0.875 , 0.89 ,\n", 919 | " 0.813 , 0.9005, 0.8325, 0.833 , 0.879 , 0.9215, 0.8185, 0.933 ,\n", 920 | " 0.9325, 0.9 , 0.833 , 0.8535, 0.8545, 0.895 , 0.8325, 0.9295])" 921 | ] 922 | }, 923 | "metadata": { 924 | "tags": [] 925 | }, 926 | "execution_count": 18 927 | } 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "metadata": { 933 | "id": "nnt7vxPHKamU", 934 | "colab_type": "code", 935 | "outputId": "1ab89e1a-d06d-47c3-dcae-2876e5c6d605", 
936 | "colab": { 937 | "base_uri": "https://localhost:8080/", 938 | "height": 139 939 | } 940 | }, 941 | "source": [ 942 | "pred_np" 943 | ], 944 | "execution_count": 0, 945 | "outputs": [ 946 | { 947 | "output_type": "execute_result", 948 | "data": { 949 | "text/plain": [ 950 | "array([[1, 1, 0, ..., 0, 0, 1],\n", 951 | " [1, 2, 1, ..., 1, 2, 1],\n", 952 | " [0, 0, 2, ..., 0, 1, 3],\n", 953 | " ...,\n", 954 | " [1, 0, 1, ..., 3, 2, 1],\n", 955 | " [2, 3, 0, ..., 0, 1, 3],\n", 956 | " [1, 1, 1, ..., 0, 0, 1]])" 957 | ] 958 | }, 959 | "metadata": { 960 | "tags": [] 961 | }, 962 | "execution_count": 25 963 | } 964 | ] 965 | }, 966 | { 967 | "cell_type": "code", 968 | "metadata": { 969 | "id": "N9bgpbtnHC2E", 970 | "colab_type": "code", 971 | "outputId": "394d29a3-8cc0-4ee7-d480-edd452a0d1ce", 972 | "colab": { 973 | "base_uri": "https://localhost:8080/" 974 | } 975 | }, 976 | "source": [ 977 | "x,y,z=np.array(batch['image'][0]).T\n", 978 | "c = np.array(batch['category'][0]).T\n", 979 | "\n", 980 | "fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, \n", 981 | " mode='markers',\n", 982 | " marker=dict(\n", 983 | " size=30,\n", 984 | " color=c, # set color to an array/list of desired values\n", 985 | " colorscale='Viridis', # choose a colorscale\n", 986 | " opacity=1.0\n", 987 | " ))])\n", 988 | "fig.update_traces(marker=dict(size=2,\n", 989 | " line=dict(width=2,\n", 990 | " color='DarkSlateGrey')),\n", 991 | " selector=dict(mode='markers'))\n", 992 | "fig.show()" 993 | ], 994 | "execution_count": 0, 995 | "outputs": [ 996 | { 997 | "output_type": "display_data", 998 | "data": { 999 | "text/html": [ 1000 | "\n", 1001 | "\n", 1002 | "\n", 1003 | "
\n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | "
\n", 1008 | " \n", 1046 | "
\n", 1047 | "\n", 1048 | "" 1049 | ] 1050 | }, 1051 | "metadata": { 1052 | "tags": [] 1053 | } 1054 | } 1055 | ] 1056 | } 1057 | ] 1058 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | aiofiles==0.4.0 3 | aiohttp==3.7.4 4 | antiorm==1.2.1 5 | appnope==0.1.0 6 | asn1crypto==0.24.0 7 | astor==0.7.1 8 | async-timeout==3.0.1 9 | attrs==19.3.0 10 | autopep8==1.5 11 | awsebcli==3.15.0 12 | backcall==0.1.0 13 | beautifulsoup4==4.7.1 14 | bleach==3.3.0 15 | blessed==1.15.0 16 | blis==0.2.4 17 | botocore==1.12.131 18 | Bottleneck==1.2.1 19 | cached-property==1.5.1 20 | cement==2.8.2 21 | certifi==2018.11.29 22 | cffi==1.12.3 23 | chardet==3.0.4 24 | Click==7.0 25 | cloudpickle==1.2.1 26 | colorama==0.3.9 27 | cpplint==1.4.5 28 | cryptography==3.3.2 29 | cycler==0.10.0 30 | cymem==2.0.2 31 | Cython==0.29.14 32 | dask==2021.10.0 33 | db==0.1.1 34 | decorator==4.4.2 35 | defusedxml==0.6.0 36 | docker==3.7.2 37 | docker-compose==1.23.2 38 | docker-pycreds==0.4.0 39 | dockerpty==0.4.1 40 | docopt==0.6.2 41 | docutils==0.14 42 | entrypoints==0.3 43 | et-xmlfile==1.0.1 44 | fastai==1.0.50.post1 45 | fastprogress==0.1.20 46 | fastscript==0.1.4 47 | filelock==3.0.12 48 | fire==0.2.1 49 | flufl.i18n==1.1.3 50 | fn==0.4.3 51 | future==0.16.0 52 | gast==0.2.0 53 | gdown==3.7.4 54 | google-pasta==0.1.7 55 | graphviz==0.13.2 56 | grpcio==1.17.0 57 | gym==0.10.11 58 | h11==0.8.1 59 | h5py==2.9.0 60 | html5lib==0.9999999 61 | httptools==0.0.13 62 | idna==2.7 63 | ig-cpp==0.1.9 64 | imageio==2.5.0 65 | importlib-metadata==1.6.0 66 | ipykernel==5.2.0 67 | ipython==7.16.3 68 | ipython-genutils==0.2.0 69 | ipywidgets==7.5.1 70 | jdcal==1.4.1 71 | jedi==0.16.0 72 | Jinja2==2.11.3 73 | jmespath==0.9.4 74 | joblib==0.14.1 75 | json5==0.9.2 76 | jsonschema==3.2.0 77 | jupyter==1.0.0 78 | jupyter-client==6.1.2 79 | jupyter-console==6.0.0 80 | jupyter-core==4.6.3 81 | jupyterlab==2.2.10 82 | jupyterlab-server==1.0.7 83 | Keras-Applications==1.0.8 84 | Keras-Preprocessing==1.1.0 85 | keyboard==0.13.4 86 | kiwisolver==1.0.1 87 | lxml==4.6.5 88 | Markdown==3.0.1 89 | MarkupSafe==1.1.1 90 | matplotlib==3.0.3 91 | mistune==0.8.4 92 | multidict==4.5.2 93 | murmurhash==1.0.2 94 | nbconvert==5.6.1 95 | nbdev==0.2.13 96 | nbformat==5.0.4 97 | networkx==2.3 98 | nose==1.3.7 99 | notebook==6.4.10 100 | numexpr==2.6.9 101 | numpy==1.21.0 102 | nvidia-ml-py3==7.352.0 103 | onnx==1.4.1 104 | open3d==0.9.0.0 105 | openpyxl==2.6.2 106 | packaging==19.0 107 | pandas==0.24.2 108 | pandocfilters==1.4.2 109 | parso==0.6.2 110 | Pat==0.5.2 111 | path.py==11.5.0 112 | pathspec==0.5.9 113 | pexpect==4.8.0 114 | pickleshare==0.7.5 115 | plac==0.9.6 116 | plotly==4.3.0 117 | plyfile==0.7.1 118 | preshed==2.0.1 119 | prometheus-client==0.7.1 120 | prompt-toolkit==3.0.5 121 | protobuf==3.15.0 122 | ptyprocess==0.6.0 123 | pyamplitude==1.2.0.dev1 124 | pycodestyle==2.5.0 125 | pycparser==2.19 126 | pyglet==1.3.2 127 | Pygments==2.7.4 128 | pymesh==1.0.2 129 | pyntcloud==0.1.2 130 | pynvx==1.0.0 131 | pyobjc==6.1 132 | pyparsing==2.3.1 133 | pyperclip==1.7.0 134 | PyQt5==5.13.2 135 | PyQt5-sip==4.19.19 136 | pyrosbag==0.1.3 137 | pyrsistent==0.16.0 138 | python-dateutil==2.8.1 139 | python-multipart==0.0.5 140 | python-pcl==0.3.0a1 141 | python-telegram-bot==11.1.0 142 | pytz==2018.9 143 | PyWavelets==1.0.3 144 | pyzmq==19.0.0 145 | qtconsole==4.5.5 146 | requests==2.20.1 147 | retrying==1.3.3 
148 | scikit-image==0.14.2 149 | scikit-learn==0.20.3 150 | scipy==1.2.0 151 | semantic-version==2.5.0 152 | Send2Trash==1.5.0 153 | six==1.14.0 154 | sklearn==0.0 155 | soupsieve==1.9 156 | spacy==2.1.3 157 | spotlight==0.1.6 158 | srsly==0.0.5 159 | starlette==0.11.4 160 | tarjan==0.2.3.2 161 | TBB==0.1 162 | termcolor==1.1.0 163 | terminado==0.8.3 164 | testpath==0.4.4 165 | texttable==0.9.1 166 | thinc==7.0.4 167 | toolz==0.10.0 168 | torch==1.4.0 169 | torchvision==0.5.0 170 | torchviz==0.0.1 171 | tornado==6.0.4 172 | tqdm==4.31.1 173 | traitlets==4.3.3 174 | typing==3.6.6 175 | typing-extensions==3.7.2 176 | uvicorn==0.11.7 177 | uvloop==0.12.2 178 | virtualenv==16.4.3 179 | wasabi==0.2.1 180 | wcwidth==0.1.9 181 | webencodings==0.5.1 182 | websocket-client==0.56.0 183 | websockets==9.1 184 | widgetsnbextension==3.5.1 185 | wrapt==1.11.2 186 | xlrd==1.2.0 187 | yarl==1.3.0 188 | zipp==3.1.0 189 | -------------------------------------------------------------------------------- /source/args.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser(description='') 6 | 7 | # experiment settings 8 | parser.add_argument('--root_dir', default='../ModelNet10/', type=str, 9 | help='dataset directory') 10 | parser.add_argument('--batch_size', default=32, type=int, 11 | help='training batch size') 12 | parser.add_argument('--lr', default=1e-3, type=float, 13 | help='learning rate') 14 | parser.add_argument('--epochs', default=15, type=int, 15 | help='number of training epochs') 16 | parser.add_argument('--save_model_path', default='./checkpoints/', type=str, 17 | help='checkpoints dir') 18 | 19 | 20 | args = parser.parse_args() 21 | 22 | assert args.root_dir is not None 23 | 24 | print(' '.join(sys.argv)) 25 | print(args) 26 | 27 | return args 28 | -------------------------------------------------------------------------------- /source/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from path import Path 3 | from source import utils 4 | from torchvision import transforms 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | 8 | 9 | def default_transforms(): 10 | return transforms.Compose([ 11 | utils.PointSampler(1024), 12 | utils.Normalize(), 13 | utils.ToTensor() 14 | ]) 15 | 16 | class PointCloudData(Dataset): 17 | def __init__(self, root_dir, valid=False, folder="train", transform=default_transforms()): 18 | self.root_dir = root_dir 19 | folders = [dir for dir in sorted(os.listdir(root_dir)) if os.path.isdir(root_dir/dir)] 20 | self.classes = {folder: i for i, folder in enumerate(folders)} 21 | self.transforms = transform if not valid else default_transforms() 22 | self.valid = valid 23 | self.files = [] 24 | for category in self.classes.keys(): 25 | new_dir = root_dir/Path(category)/folder 26 | for file in os.listdir(new_dir): 27 | if file.endswith('.off'): 28 | sample = {} 29 | sample['pcd_path'] = new_dir/file 30 | sample['category'] = category 31 | self.files.append(sample) 32 | 33 | def __len__(self): 34 | return len(self.files) 35 | 36 | def __preproc__(self, file): 37 | verts, faces = utils.read_off(file) 38 | if self.transforms: 39 | pointcloud = self.transforms((verts, faces)) 40 | return pointcloud 41 | 42 | def __getitem__(self, idx): 43 | pcd_path = self.files[idx]['pcd_path'] 44 | category = self.files[idx]['category'] 45 | with open(pcd_path, 'r') as f: 46 | pointcloud = 
self.__preproc__(f) 47 | return {'pointcloud': pointcloud, 48 | 'category': self.classes[category]} 49 | -------------------------------------------------------------------------------- /source/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | class Tnet(nn.Module): 7 | def __init__(self, k=3): 8 | super().__init__() 9 | self.k=k 10 | self.conv1 = nn.Conv1d(k,64,1) 11 | self.conv2 = nn.Conv1d(64,128,1) 12 | self.conv3 = nn.Conv1d(128,1024,1) 13 | self.fc1 = nn.Linear(1024,512) 14 | self.fc2 = nn.Linear(512,256) 15 | self.fc3 = nn.Linear(256,k*k) 16 | 17 | self.bn1 = nn.BatchNorm1d(64) 18 | self.bn2 = nn.BatchNorm1d(128) 19 | self.bn3 = nn.BatchNorm1d(1024) 20 | self.bn4 = nn.BatchNorm1d(512) 21 | self.bn5 = nn.BatchNorm1d(256) 22 | 23 | 24 | def forward(self, input): 25 | # input.shape == (bs,k,n) 26 | bs = input.size(0) 27 | xb = F.relu(self.bn1(self.conv1(input))) 28 | xb = F.relu(self.bn2(self.conv2(xb))) 29 | xb = F.relu(self.bn3(self.conv3(xb))) 30 | pool = nn.MaxPool1d(xb.size(-1))(xb) 31 | flat = nn.Flatten(1)(pool) 32 | xb = F.relu(self.bn4(self.fc1(flat))) 33 | xb = F.relu(self.bn5(self.fc2(xb))) 34 | 35 | #initialize as identity 36 | init = torch.eye(self.k, requires_grad=True).repeat(bs,1,1) 37 | if xb.is_cuda: 38 | init=init.cuda() 39 | matrix = self.fc3(xb).view(-1,self.k,self.k) + init 40 | return matrix 41 | 42 | 43 | class Transform(nn.Module): 44 | def __init__(self): 45 | super().__init__() 46 | self.input_transform = Tnet(k=3) 47 | self.feature_transform = Tnet(k=64) 48 | self.conv1 = nn.Conv1d(3,64,1) 49 | 50 | self.conv2 = nn.Conv1d(64,128,1) 51 | self.conv3 = nn.Conv1d(128,1024,1) 52 | 53 | 54 | self.bn1 = nn.BatchNorm1d(64) 55 | self.bn2 = nn.BatchNorm1d(128) 56 | self.bn3 = nn.BatchNorm1d(1024) 57 | 58 | def forward(self, input): 59 | matrix3x3 = self.input_transform(input) 60 | # batch matrix multiplication 61 | xb = torch.bmm(torch.transpose(input,1,2), matrix3x3).transpose(1,2) 62 | 63 | xb = F.relu(self.bn1(self.conv1(xb))) 64 | 65 | matrix64x64 = self.feature_transform(xb) 66 | xb = torch.bmm(torch.transpose(xb,1,2), matrix64x64).transpose(1,2) 67 | 68 | xb = F.relu(self.bn2(self.conv2(xb))) 69 | xb = self.bn3(self.conv3(xb)) 70 | xb = nn.MaxPool1d(xb.size(-1))(xb) 71 | output = nn.Flatten(1)(xb) 72 | return output, matrix3x3, matrix64x64 73 | 74 | class PointNet(nn.Module): 75 | def __init__(self, classes = 10): 76 | super().__init__() 77 | self.transform = Transform() 78 | self.fc1 = nn.Linear(1024, 512) 79 | self.fc2 = nn.Linear(512, 256) 80 | self.fc3 = nn.Linear(256, classes) 81 | 82 | 83 | self.bn1 = nn.BatchNorm1d(512) 84 | self.bn2 = nn.BatchNorm1d(256) 85 | self.dropout = nn.Dropout(p=0.3) 86 | self.logsoftmax = nn.LogSoftmax(dim=1) 87 | 88 | def forward(self, input): 89 | xb, matrix3x3, matrix64x64 = self.transform(input) 90 | xb = F.relu(self.bn1(self.fc1(xb))) 91 | xb = F.relu(self.bn2(self.dropout(self.fc2(xb)))) 92 | output = self.fc3(xb) 93 | return self.logsoftmax(output), matrix3x3, matrix64x64 94 | 95 | 96 | -------------------------------------------------------------------------------- /source/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import random 4 | import torch 5 | 6 | 7 | def read_off(file): 8 | if 'OFF' != file.readline().strip(): 9 | raise ValueError('Not a valid OFF header') 10 | n_verts, n_faces, __ =
11 |     verts = [[float(s) for s in file.readline().strip().split(' ')] for i_vert in range(n_verts)]
12 |     faces = [[int(s) for s in file.readline().strip().split(' ')][1:] for i_face in range(n_faces)]
13 |     return verts, faces
14 |
15 |
16 | class PointSampler(object):
17 |     def __init__(self, output_size):
18 |         assert isinstance(output_size, int)
19 |         self.output_size = output_size
20 |
21 |     def triangle_area(self, pt1, pt2, pt3):
22 |         side_a = np.linalg.norm(pt1 - pt2)
23 |         side_b = np.linalg.norm(pt2 - pt3)
24 |         side_c = np.linalg.norm(pt3 - pt1)
25 |         s = 0.5 * ( side_a + side_b + side_c)
26 |         return max(s * (s - side_a) * (s - side_b) * (s - side_c), 0)**0.5
27 |
28 |     def sample_point(self, pt1, pt2, pt3):
29 |         # barycentric coordinates on a triangle
30 |         # https://mathworld.wolfram.com/BarycentricCoordinates.html
31 |         s, t = sorted([random.random(), random.random()])
32 |         f = lambda i: s * pt1[i] + (t-s)*pt2[i] + (1-t)*pt3[i]
33 |         return (f(0), f(1), f(2))
34 |
35 |
36 |     def __call__(self, mesh):
37 |         verts, faces = mesh
38 |         verts = np.array(verts)
39 |         areas = np.zeros((len(faces)))
40 |
41 |         for i in range(len(areas)):
42 |             areas[i] = (self.triangle_area(verts[faces[i][0]],
43 |                                            verts[faces[i][1]],
44 |                                            verts[faces[i][2]]))
45 |
46 |         sampled_faces = (random.choices(faces,
47 |                                         weights=areas,
48 |                                         cum_weights=None,
49 |                                         k=self.output_size))
50 |
51 |         sampled_points = np.zeros((self.output_size, 3))
52 |
53 |         for i in range(len(sampled_faces)):
54 |             sampled_points[i] = (self.sample_point(verts[sampled_faces[i][0]],
55 |                                                    verts[sampled_faces[i][1]],
56 |                                                    verts[sampled_faces[i][2]]))
57 |
58 |         return sampled_points
59 |
60 | class Normalize(object):
61 |     def __call__(self, pointcloud):
62 |         assert len(pointcloud.shape)==2
63 |
64 |         norm_pointcloud = pointcloud - np.mean(pointcloud, axis=0)
65 |         norm_pointcloud /= np.max(np.linalg.norm(norm_pointcloud, axis=1))
66 |
67 |         return norm_pointcloud
68 |
69 | class RandRotation_z(object):
70 |     def __call__(self, pointcloud):
71 |         assert len(pointcloud.shape)==2
72 |
73 |         theta = random.random() * 2. * math.pi
74 |         rot_matrix = np.array([[ math.cos(theta), -math.sin(theta), 0],
75 |                                [ math.sin(theta),  math.cos(theta), 0],
76 |                                [ 0,                0,               1]])
77 |
78 |         rot_pointcloud = rot_matrix.dot(pointcloud.T).T
79 |         return rot_pointcloud
80 |
81 | class RandomNoise(object):
82 |     def __call__(self, pointcloud):
83 |         assert len(pointcloud.shape)==2
84 |
85 |         noise = np.random.normal(0, 0.02, (pointcloud.shape))
86 |
87 |         noisy_pointcloud = pointcloud + noise
88 |         return noisy_pointcloud
89 |
90 | class ToTensor(object):
91 |     def __call__(self, pointcloud):
92 |         assert len(pointcloud.shape)==2
93 |
94 |         return torch.from_numpy(pointcloud)
95 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import random
4 | import os
5 | import torch
6 | from path import Path
7 | from source import model
8 | from source import dataset
9 | from source import utils
10 | from source.args import parse_args
11 | from torchvision import transforms
12 | from torch.utils.data import Dataset, DataLoader
13 |
14 | random.seed(42)
15 |
16 | def pointnetloss(outputs, labels, m3x3, m64x64, alpha = 0.0001):
17 |     criterion = torch.nn.NLLLoss()
18 |     bs=outputs.size(0)
19 |     id3x3 = torch.eye(3, requires_grad=True).repeat(bs,1,1)
20 |     id64x64 = torch.eye(64, requires_grad=True).repeat(bs,1,1)
21 |     if outputs.is_cuda:
22 |         id3x3=id3x3.cuda()
23 |         id64x64=id64x64.cuda()
24 |     diff3x3 = id3x3-torch.bmm(m3x3,m3x3.transpose(1,2))
25 |     diff64x64 = id64x64-torch.bmm(m64x64,m64x64.transpose(1,2))
26 |     return criterion(outputs, labels) + alpha * (torch.norm(diff3x3)+torch.norm(diff64x64)) / float(bs)
27 |
28 |
29 |
30 | def train(args):
31 |     path = Path(args.root_dir)
32 |
33 |     folders = [dir for dir in sorted(os.listdir(path)) if os.path.isdir(path/dir)]
34 |     classes = {folder: i for i, folder in enumerate(folders)}
35 |
36 |     train_transforms = transforms.Compose([
37 |         utils.PointSampler(1024),
38 |         utils.Normalize(),
39 |         utils.RandRotation_z(),
40 |         utils.RandomNoise(),
41 |         utils.ToTensor()
42 |     ])
43 |     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
44 |     print(device)
45 |     pointnet = model.PointNet()
46 |     pointnet.to(device)
47 |     optimizer = torch.optim.Adam(pointnet.parameters(), lr=args.lr)
48 |
49 |     train_ds = dataset.PointCloudData(path, transform=train_transforms)
50 |     valid_ds = dataset.PointCloudData(path, valid=True, folder='test', transform=train_transforms)
51 |     print('Train dataset size: ', len(train_ds))
52 |     print('Valid dataset size: ', len(valid_ds))
53 |     print('Number of classes: ', len(train_ds.classes))
54 |
55 |     train_loader = DataLoader(dataset=train_ds, batch_size=args.batch_size, shuffle=True)
56 |     valid_loader = DataLoader(dataset=valid_ds, batch_size=args.batch_size*2)
57 |
58 |     try:
59 |         os.mkdir(args.save_model_path)
60 |     except OSError as error:
61 |         print(error)
62 |
63 |     print('Start training')
64 |     for epoch in range(args.epochs):
65 |         pointnet.train()
66 |         running_loss = 0.0
67 |         for i, data in enumerate(train_loader, 0):
68 |             inputs, labels = data['pointcloud'].to(device).float(), data['category'].to(device)
69 |             optimizer.zero_grad()
70 |             outputs, m3x3, m64x64 = pointnet(inputs.transpose(1,2))
71 |
72 |             loss = pointnetloss(outputs, labels, m3x3, m64x64)
73 |             loss.backward()
74 |             optimizer.step()
75 |
76 |             # print statistics
77 |             running_loss += loss.item()
78 |             if i % 10 == 9:    # print every 10 mini-batches
79 |                 print('[Epoch: %d, Batch: %4d / %4d], loss: %.3f' %
80 |                       (epoch + 1, i + 1, len(train_loader), running_loss / 10))
81 |                 running_loss = 0.0
82 |
83 |         pointnet.eval()
84 |         correct = total = 0
85 |
86 |         # validation
87 |         if valid_loader:
88 |             with torch.no_grad():
89 |                 for data in valid_loader:
90 |                     inputs, labels = data['pointcloud'].to(device).float(), data['category'].to(device)
91 |                     outputs, __, __ = pointnet(inputs.transpose(1,2))
92 |                     _, predicted = torch.max(outputs.data, 1)
93 |                     total += labels.size(0)
94 |                     correct += (predicted == labels).sum().item()
95 |             val_acc = 100. * correct / total
96 |             print('Valid accuracy: %d %%' % val_acc)
97 |         # save the model
98 |
99 |         checkpoint = Path(args.save_model_path)/('save_'+str(epoch)+'.pth')
100 |         torch.save(pointnet.state_dict(), checkpoint)
101 |         print('Model saved to ', checkpoint)
102 |
103 | if __name__ == '__main__':
104 |     args = parse_args()
105 |     train(args)
106 |
--------------------------------------------------------------------------------