├── LICENSE
├── README.md
├── images
│   ├── airplane.gif
│   ├── chair.gif
│   └── cover.gif
├── nbs
│   ├── PointNetClass.ipynb
│   └── PointNetSeg.ipynb
├── requirements.txt
├── source
│   ├── args.py
│   ├── dataset.py
│   ├── model.py
│   └── utils.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # PointNet
4 | PyTorch implementation of "PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation" https://arxiv.org/abs/1612.00593
5 |
6 |
7 |
8 |
9 | Key points of the implementation are explained in detail in [this](https://towardsdatascience.com/deep-learning-on-point-clouds-implementing-pointnet-in-google-colab-1fd65cd3a263) Medium article.
10 |
11 | ## Classification dataset
12 | This code implements object classification on the [ModelNet10](https://modelnet.cs.princeton.edu) dataset.
13 |
14 | As in the original paper, we sample 1024 points on each object's surface, with the sampling probability proportional to face area. We then normalize the object to fit a unit sphere and add Gaussian noise. This is an example of a network input representing a chair:
15 |
16 |
17 |
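Below is a minimal sketch of this preprocessing step. It is illustrative only: the repository's actual implementation lives in `source/utils.py` and `source/dataset.py`, the function name is hypothetical, and the noise scale is an assumed value.

```python
import numpy as np

def sample_points(vertices, faces, n=1024, noise_std=0.02):  # noise_std: assumed value
    """Sample n points on a triangular mesh, weighted by face area."""
    tri = vertices[faces]                                    # (F, 3, 3)
    # area of each triangle: 0.5 * |(B - A) x (C - A)|
    areas = 0.5 * np.linalg.norm(
        np.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0]), axis=1)
    # pick faces with probability proportional to their area
    idx = np.random.choice(len(faces), size=n, p=areas / areas.sum())
    # a uniform point inside each chosen triangle via barycentric coordinates
    u, v = np.random.rand(n, 1), np.random.rand(n, 1)
    flip = (u + v).squeeze() > 1
    u[flip], v[flip] = 1 - u[flip], 1 - v[flip]
    a, b, c = tri[idx, 0], tri[idx, 1], tri[idx, 2]
    pts = a + u * (b - a) + v * (c - a)
    # normalize into the unit sphere, then add Gaussian noise
    pts -= pts.mean(axis=0)
    pts /= np.max(np.linalg.norm(pts, axis=1))
    return pts + np.random.normal(0.0, noise_std, pts.shape)
```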
18 | You can download the dataset by following [this link](https://drive.google.com/open?id=12Mv19pQ84VO8Av50hUXTixSxd5NDjeEB)
19 |
20 | ## Classification performance
21 |
22 | | Class (Accuracy) | Overall | Bathtub | Bed | Chair | Desk | Dresser | Monitor | Night stand | Sofa | Table | Toilet |
23 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
24 | | ModelNet10 | 82.0% | 93.4% | 92.0% | 97.2% | 81.5% | 71.0% | 89.4% | 56.0% | 86.9% | 93.4% | 95.9% |
25 |
26 | Pretrained model is available [here](https://drive.google.com/open?id=1nDG0maaqoTkRkVsOLtUAR9X3kn__LMSL)
27 |
28 | ## Usage
29 | * The first and best option is to run `/nbs/PointNetClass.ipynb`, a notebook with comments and visualizations, in Google Colab.
30 | * The second option is to clone the repository to a local machine and train a model with default parameters:
31 | ```bash
32 | git clone https://github.com/nikitakaraevv/pointnet
33 | wget http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip
34 | unzip -q ModelNet10.zip
35 | cd pointnet/
36 | python train.py
37 | ```
38 | If for some reason it doesn't work, you can install the requirements before running `python train.py`:
39 | ```bash
40 | conda create -n env python=3.7
41 | conda activate env
42 | pip install -r requirements.txt
43 | ```
44 | Another example of running the training script with custom parameters:
45 |
46 | `python train.py --root_dir ../ModelNet10/ --batch_size 16 --lr 0.0001 --epochs 30 --save_model_path ./ckpts`
47 |
48 | ## Part segmentation dataset
49 | The dataset includes 2609 point clouds representing different airplanes, where every point has coordinates in 3D space and a label for the airplane part it belongs to. Since the point clouds contain different numbers of points, while PyTorch requires tensors of the same size to form a batch, we uniformly sample 2000 points from every point cloud (a minimal sketch of this step follows below).
50 |
51 | You can download the dataset by following [this link](https://drive.google.com/drive/u/1/folders/1Z5XA4uJpA86ky0qV1AVgA_G1_ETkq9En)
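
A minimal sketch of the sampling step described above (names are illustrative; the notebook's version uses `random.choices`, which likewise samples with replacement):

```python
import numpy as np

def sample_cloud(points, labels, n=2000):
    """Uniformly sample n (point, label) pairs, with replacement."""
    idx = np.random.randint(0, len(points), size=n)
    return points[idx], labels[idx]
```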
52 |
53 | ## Part segmentation performance
54 | The resulting accuracy on the validation dataset is 88%. In the original paper, the reported part segmentation result for this object category (airplanes) is 83.4%.
55 |
56 | ## Usage
57 | This part of the project is still in development. However, you can already run the notebook `/nbs/PointNetSeg.ipynb` in Colab.
58 |
59 |
60 |
61 | ## Authors
62 | * [Nikita Karaev](https://github.com/nikitakaraevv)
63 | * [Irina Nikulina](https://github.com/washburn125)
64 |
--------------------------------------------------------------------------------
/images/airplane.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikitakaraevv/pointnet/256437e9ab27b197347464cecff87121c5c824ff/images/airplane.gif
--------------------------------------------------------------------------------
/images/chair.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikitakaraevv/pointnet/256437e9ab27b197347464cecff87121c5c824ff/images/chair.gif
--------------------------------------------------------------------------------
/images/cover.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikitakaraevv/pointnet/256437e9ab27b197347464cecff87121c5c824ff/images/cover.gif
--------------------------------------------------------------------------------
/nbs/PointNetSeg.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "name": "python3",
7 | "display_name": "Python 3"
8 | },
9 | "language_info": {
10 | "codemirror_mode": {
11 | "name": "ipython",
12 | "version": 3
13 | },
14 | "file_extension": ".py",
15 | "mimetype": "text/x-python",
16 | "name": "python",
17 | "nbconvert_exporter": "python",
18 | "pygments_lexer": "ipython3",
19 | "version": "3.7.1"
20 | },
21 | "colab": {
22 | "name": "PointNetSeg.ipynb",
23 | "provenance": [],
24 | "collapsed_sections": [
25 | "caWQIszA8r-H"
26 | ],
27 | "machine_shape": "hm",
28 | "include_colab_link": true
29 | },
30 | "accelerator": "GPU"
31 | },
32 | "cells": [
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {
36 | "id": "view-in-github",
37 | "colab_type": "text"
38 | },
39 | "source": [
40 | "
"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "metadata": {
46 | "id": "KuEIVQsJGUbQ",
47 | "colab_type": "code",
48 | "outputId": "ff151e25-a584-4c7a-bb2b-634acedd0952",
49 | "colab": {
50 | "base_uri": "https://localhost:8080/",
51 | "height": 35
52 | }
53 | },
54 | "source": [
55 | "from google.colab import drive\n",
56 | "drive.mount('/content/gdrive', force_remount=True)\n",
57 | "root_dir = \"/content/gdrive/My Drive/PointNet3D\"\n"
58 | ],
59 | "execution_count": 0,
60 | "outputs": [
61 | {
62 | "output_type": "stream",
63 | "text": [
64 | "Mounted at /content/gdrive\n"
65 | ],
66 | "name": "stdout"
67 | }
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "metadata": {
73 | "id": "PLGjzwFkGUbV",
74 | "colab_type": "code",
75 | "outputId": "a19df6db-1edd-4c03-8c1a-3e549eaf08de",
76 | "colab": {
77 | "base_uri": "https://localhost:8080/",
78 | "height": 124
79 | }
80 | },
81 | "source": [
82 | "!pip install path.py;\n",
83 | "from path import Path\n",
84 | "import sys\n",
85 | "sys.path.append(root_dir)"
86 | ],
87 | "execution_count": 0,
88 | "outputs": [
89 | {
90 | "output_type": "stream",
91 | "text": [
92 | "Requirement already satisfied: path.py in /usr/local/lib/python3.6/dist-packages (12.4.0)\n",
93 | "Requirement already satisfied: path<13.2 in /usr/local/lib/python3.6/dist-packages (from path.py) (13.1.0)\n",
94 | "Requirement already satisfied: importlib-metadata>=0.5; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from path<13.2->path.py) (1.3.0)\n",
95 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.5; python_version < \"3.8\"->path<13.2->path.py) (0.6.0)\n",
96 | "Requirement already satisfied: more-itertools in /usr/local/lib/python3.6/dist-packages (from zipp>=0.5->importlib-metadata>=0.5; python_version < \"3.8\"->path<13.2->path.py) (8.0.2)\n"
97 | ],
98 | "name": "stdout"
99 | }
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "metadata": {
105 | "id": "LB7sxVUtGUbY",
106 | "colab_type": "code",
107 | "colab": {}
108 | },
109 | "source": [
110 | "import plotly.graph_objects as go\n",
111 | "import numpy as np\n",
112 | "import scipy.spatial.distance\n",
113 | "import math\n",
114 | "import random\n",
115 | "import utils\n",
116 | "\n",
117 | "\n",
118 | "class10_dir = \"/datasets/ModelNet10txt/ModelNet10/ModelNet10/\""
119 | ],
120 | "execution_count": 0,
121 | "outputs": []
122 | },
123 | {
124 | "cell_type": "code",
125 | "metadata": {
126 | "id": "31PRdnvMOjLd",
127 | "colab_type": "code",
128 | "colab": {}
129 | },
130 | "source": [
131 | "import random\n",
132 | "\n",
133 | "def read_pts(file):\n",
134 | " verts = np.genfromtxt(file)\n",
135 | " return utils.cent_norm(verts)\n",
136 | " #return verts\n",
137 | "\n",
138 | "def read_seg(file):\n",
139 | "    verts = np.genfromtxt(file, dtype=int)\n",
140 | " return verts\n",
141 | "\n",
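"# sample_2000 draws k=2000 points WITH replacement (random.choices), so\n",
"# smaller clouds are upsampled; labels are shifted from 1-based to 0-based.\n",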
142 | "def sample_2000(pts, pts_cat): \n",
143 | " res1 = np.concatenate((pts,np.reshape(pts_cat, (pts_cat.shape[0], 1))), axis= 1)\n",
144 | " res = np.asarray(random.choices(res1, weights=None, cum_weights=None, k=2000))\n",
145 | " images = res[:, 0:3]\n",
146 | " categories = res[:, 3]\n",
147 | " categories-=np.ones(categories.shape)\n",
148 | " return images, categories"
149 | ],
150 | "execution_count": 0,
151 | "outputs": []
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {
156 | "id": "4KNH5UtbWWok",
157 | "colab_type": "text"
158 | },
159 | "source": [
160 | "## Model"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "metadata": {
166 | "id": "yOOrEYSnWV7f",
167 | "colab_type": "code",
168 | "colab": {}
169 | },
170 | "source": [
171 | "import torch\n",
172 | "import torch.nn as nn\n",
173 | "import numpy as np\n",
174 | "import torch.nn.functional as F\n",
175 | "\n",
176 | "class Tnet(nn.Module):\n",
177 | " def __init__(self, k=3):\n",
178 | " super().__init__()\n",
179 | " self.k=k\n",
180 | " self.conv1 = nn.Conv1d(k,64,1)\n",
181 | " self.conv2 = nn.Conv1d(64,128,1)\n",
182 | " self.conv3 = nn.Conv1d(128,1024,1)\n",
183 | " self.fc1 = nn.Linear(1024,512)\n",
184 | " self.fc2 = nn.Linear(512,256)\n",
185 | " self.fc3 = nn.Linear(256,k*k)\n",
186 | "\n",
187 | " self.bn1 = nn.BatchNorm1d(64)\n",
188 | " self.bn2 = nn.BatchNorm1d(128)\n",
189 | " self.bn3 = nn.BatchNorm1d(1024)\n",
190 | " self.bn4 = nn.BatchNorm1d(512)\n",
191 | " self.bn5 = nn.BatchNorm1d(256)\n",
192 | " \n",
193 | "\n",
194 | " def forward(self, input):\n",
195 | " # input.shape == (bs,n,3)\n",
196 | " bs = input.size(0)\n",
197 | " xb = F.relu(self.bn1(self.conv1(input)))\n",
198 | " xb = F.relu(self.bn2(self.conv2(xb)))\n",
199 | " xb = F.relu(self.bn3(self.conv3(xb)))\n",
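"        # max-pool over the point dimension: a symmetric function, so the\n",
"        # pooled feature is invariant to the order of the input points\n",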
200 | " pool = nn.MaxPool1d(xb.size(-1))(xb)\n",
201 | " flat = nn.Flatten(1)(pool)\n",
202 | " xb = F.relu(self.bn4(self.fc1(flat)))\n",
203 | " xb = F.relu(self.bn5(self.fc2(xb)))\n",
204 | " \n",
205 | " #initialize as identity\n",
206 | " init = torch.eye(self.k, requires_grad=True).repeat(bs,1,1)\n",
207 | " if xb.is_cuda:\n",
208 | " init=init.cuda()\n",
209 | " matrix = self.fc3(xb).view(-1,self.k,self.k) + init\n",
210 | " return matrix\n",
211 | "\n",
212 | "\n",
213 | "class Transform(nn.Module):\n",
214 | " def __init__(self):\n",
215 | " super().__init__()\n",
216 | " self.input_transform = Tnet(k=3)\n",
217 | " self.feature_transform = Tnet(k=128)\n",
218 | " self.fc1 = nn.Conv1d(3,64,1)\n",
219 | " self.fc2 = nn.Conv1d(64,128,1) \n",
220 | " self.fc3 = nn.Conv1d(128,128,1)\n",
221 | " self.fc4 = nn.Conv1d(128,512,1)\n",
222 | " self.fc5 = nn.Conv1d(512,2048,1)\n",
223 | "\n",
224 | " \n",
225 | " self.bn1 = nn.BatchNorm1d(64)\n",
226 | " self.bn2 = nn.BatchNorm1d(128)\n",
227 | " self.bn3 = nn.BatchNorm1d(128)\n",
228 | " self.bn4 = nn.BatchNorm1d(512)\n",
229 | " self.bn5 = nn.BatchNorm1d(2048)\n",
230 | "\n",
231 | " def forward(self, input):\n",
232 | " n_pts = input.size()[2]\n",
233 | " matrix3x3 = self.input_transform(input)\n",
234 | " xb = torch.bmm(torch.transpose(input,1,2), matrix3x3).transpose(1,2)\n",
235 | " outs = []\n",
236 | " \n",
237 | " out1 = F.relu(self.bn1(self.fc1(xb)))\n",
238 | " outs.append(out1)\n",
239 | " out2 = F.relu(self.bn2(self.fc2(out1)))\n",
240 | " outs.append(out2)\n",
241 | " out3 = F.relu(self.bn3(self.fc3(out2)))\n",
242 | " outs.append(out3)\n",
243 | " matrix128x128 = self.feature_transform(out3)\n",
244 | " \n",
245 | " out4 = torch.bmm(torch.transpose(out3,1,2), matrix128x128).transpose(1,2) \n",
246 | " outs.append(out4)\n",
247 | " out5 = F.relu(self.bn4(self.fc4(out4)))\n",
248 | " outs.append(out5)\n",
249 | " \n",
250 | " xb = self.bn5(self.fc5(out5))\n",
251 | " \n",
252 | " xb = nn.MaxPool1d(xb.size(-1))(xb)\n",
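"        # repeat the pooled 2048-d global feature once per point so it can\n",
"        # be concatenated with the per-point features collected in outs\n",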
253 | " out6 = nn.Flatten(1)(xb).repeat(n_pts,1,1).transpose(0,2).transpose(0,1)#.repeat(1, 1, n_pts)\n",
254 | " outs.append(out6)\n",
255 | " \n",
256 | " \n",
257 | " return outs, matrix3x3, matrix128x128\n",
258 | "\n",
259 | "\n",
260 | "class PointNetSeg(nn.Module):\n",
261 | " def __init__(self, classes = 10):\n",
262 | " super().__init__()\n",
263 | " self.transform = Transform()\n",
264 | "\n",
265 | " self.fc1 = nn.Conv1d(3008,256,1) \n",
266 | " self.fc2 = nn.Conv1d(256,256,1) \n",
267 | " self.fc3 = nn.Conv1d(256,128,1) \n",
268 | " self.fc4 = nn.Conv1d(128,4,1) \n",
269 | " \n",
270 | "\n",
271 | " self.bn1 = nn.BatchNorm1d(256)\n",
272 | " self.bn2 = nn.BatchNorm1d(256)\n",
273 | " \n",
274 | " self.bn3 = nn.BatchNorm1d(128)\n",
275 | " self.bn4 = nn.BatchNorm1d(4)\n",
276 | " \n",
277 | " self.logsoftmax = nn.LogSoftmax(dim=1)\n",
278 | " \n",
279 | "\n",
280 | " def forward(self, input):\n",
281 | " inputs, matrix3x3, matrix128x128 = self.transform(input)\n",
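"        # skip-concatenation: all per-point features plus the repeated global\n",
"        # feature, 64+128+128+128+512+2048 = 3008 channels (matches fc1 above)\n",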
282 | " stack = torch.cat(inputs,1)\n",
283 | " \n",
284 | " xb = F.relu(self.bn1(self.fc1(stack)))\n",
285 | " \n",
286 | " xb = F.relu(self.bn2(self.fc2(xb)))\n",
287 | " \n",
288 | " xb = F.relu(self.bn3(self.fc3(xb)))\n",
289 | " \n",
290 | "        output = F.relu(self.bn4(self.fc4(xb)))  # note: ReLU zeroes negative logits before the log-softmax\n",
291 | " \n",
292 | " return self.logsoftmax(output), matrix3x3, matrix128x128\n",
293 | "\n"
294 | ],
295 | "execution_count": 0,
296 | "outputs": []
297 | },
298 | {
299 | "cell_type": "markdown",
300 | "metadata": {
301 | "id": "HdXchtFBWZYG",
302 | "colab_type": "text"
303 | },
304 | "source": [
305 | "## Dataset"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "metadata": {
311 | "id": "ut50_1uQCFCc",
312 | "colab_type": "code",
313 | "colab": {}
314 | },
315 | "source": [
316 | "from __future__ import print_function, division\n",
317 | "import os\n",
318 | "import torch\n",
319 | "import pandas as pd\n",
320 | "from skimage import io, transform\n",
321 | "import numpy as np\n",
322 | "import matplotlib.pyplot as plt\n",
323 | "from torch.utils.data import Dataset, DataLoader\n",
324 | "from torchvision import transforms, utils\n",
325 | "from torch.utils.data.dataset import random_split\n",
326 | "import utils\n",
327 | "\n",
328 | "class Data(Dataset):\n",
329 | " \"\"\"Face Landmarks dataset.\"\"\"\n",
330 | "\n",
331 | " def __init__(self, root_dir, valid=False, transform=None):\n",
332 | " \n",
333 | " self.root_dir = root_dir\n",
334 | " self.files = []\n",
335 | " self.valid=valid\n",
336 | "\n",
337 | " newdir = root_dir + '/datasets/airplane_part_seg/02691156/expert_verified/points_label/'\n",
338 | "\n",
339 | " for file in os.listdir(newdir):\n",
340 | " o = {}\n",
341 | " o['category'] = newdir + file\n",
342 | " o['img_path'] = root_dir + '/datasets/airplane_part_seg/02691156/points/'+ file.replace('.seg', '.pts')\n",
343 | " self.files.append(o)\n",
344 | " \n",
345 | "\n",
346 | " def __len__(self):\n",
347 | " return len(self.files)\n",
348 | "\n",
349 | " def __getitem__(self, idx):\n",
350 | " img_path = self.files[idx]['img_path']\n",
351 | " category = self.files[idx]['category']\n",
352 | " with open(img_path, 'r') as f:\n",
353 | " image1 = read_pts(f)\n",
354 | " with open(category, 'r') as f: \n",
355 | " category1 = read_seg(f)\n",
356 | " image2, category2 = sample_2000(image1, category1)\n",
357 | " if not self.valid:\n",
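"            # training-time augmentation: random rotation around the z-axis\n",
"            # plus Gaussian noise (skipped when valid=True)\n",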
358 | " theta = random.random()*360\n",
359 | " image2 = utils.rotation_z(utils.add_noise(image2), theta)\n",
360 | " \n",
361 | " return {'image': np.array(image2, dtype=\"float32\"), 'category': category2.astype(int)}\n"
362 | ],
363 | "execution_count": 0,
364 | "outputs": []
365 | },
366 | {
367 | "cell_type": "code",
368 | "metadata": {
369 | "id": "1mUcFS3Uwci6",
370 | "colab_type": "code",
371 | "outputId": "4d10dda0-7c73-4007-ce71-de106f2a32a6",
372 | "colab": {
373 | "base_uri": "https://localhost:8080/",
374 | "height": 69
375 | }
376 | },
377 | "source": [
378 | "\n",
379 | "dset = Data(root_dir , transform=None)\n",
380 | "train_num = int(len(dset) * 0.95)\n",
381 | "val_num = int(len(dset) *0.05)\n",
382 | "if int(len(dset)) - train_num - val_num >0 :\n",
383 | " train_num = train_num + 1\n",
384 | "elif int(len(dset)) - train_num - val_num < 0:\n",
385 | " train_num = train_num -1\n",
386 | "#train_dataset, val_dataset = random_split(dset, [3000, 118])\n",
387 | "train_dataset, val_dataset = random_split(dset, [train_num, val_num])\n",
388 | "val_dataset.valid = True  # note: random_split returns Subsets; this sets an attribute on the Subset, not on the underlying Data object, so validation samples are still augmented\n",
389 | "\n",
390 | "print('######### Dataset class created #########')\n",
391 | "print('Number of images: ', len(dset))\n",
392 | "print('Sample image shape: ', dset[0]['image'].shape)\n",
393 | "#print('Sample image points categories', dset[0]['category'], end='\\n\\n')\n",
394 | "\n",
395 | "train_loader = DataLoader(dataset=train_dataset, batch_size=64)\n",
396 | "val_loader = DataLoader(dataset=val_dataset, batch_size=64)\n",
397 | "\n",
398 | "#dataloader = torch.utils.data.DataLoader(dset, batch_size=4, shuffle=True, num_workers=4)"
399 | ],
400 | "execution_count": 0,
401 | "outputs": [
402 | {
403 | "output_type": "stream",
404 | "text": [
405 | "######### Dataset class created #########\n",
406 | "Number of images: 2690\n",
407 | "Sample image shape: (2000, 3)\n"
408 | ],
409 | "name": "stdout"
410 | }
411 | ]
412 | },
413 | {
414 | "cell_type": "markdown",
415 | "metadata": {
416 | "id": "gg9RjG7awgVK",
417 | "colab_type": "text"
418 | },
419 | "source": [
420 | "## Training loop"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "metadata": {
426 | "id": "bq9AUuN5WRxI",
427 | "colab_type": "code",
428 | "outputId": "a6c2e5eb-ee77-4f4e-f511-77acffab30fa",
429 | "colab": {
430 | "base_uri": "https://localhost:8080/",
431 | "height": 35
432 | }
433 | },
434 | "source": [
435 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
436 | "print(device)"
437 | ],
438 | "execution_count": 0,
439 | "outputs": [
440 | {
441 | "output_type": "stream",
442 | "text": [
443 | "cuda:0\n"
444 | ],
445 | "name": "stdout"
446 | }
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "metadata": {
452 | "id": "cqXW9-oJwEPm",
453 | "colab_type": "code",
454 | "colab": {}
455 | },
456 | "source": [
457 | "pointnet = PointNetSeg()"
458 | ],
459 | "execution_count": 0,
460 | "outputs": []
461 | },
462 | {
463 | "cell_type": "code",
464 | "metadata": {
465 | "id": "6mA80v2ywHhw",
466 | "colab_type": "code",
467 | "colab": {}
468 | },
469 | "source": [
470 | "pointnet.to(device);"
471 | ],
472 | "execution_count": 0,
473 | "outputs": []
474 | },
475 | {
476 | "cell_type": "code",
477 | "metadata": {
478 | "id": "JV09EA4_wJnR",
479 | "colab_type": "code",
480 | "colab": {}
481 | },
482 | "source": [
483 | "optimizer = torch.optim.Adam(pointnet.parameters(), lr=0.001)"
484 | ],
485 | "execution_count": 0,
486 | "outputs": []
487 | },
488 | {
489 | "cell_type": "code",
490 | "metadata": {
491 | "id": "aDb9rPb_wPWj",
492 | "colab_type": "code",
493 | "colab": {}
494 | },
495 | "source": [
496 | "def pointnetloss(outputs, labels, m3x3, m128x128, alpha = 0.0001):\n",
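"    # NLL loss plus a regularizer pushing both learned feature transforms\n",
"    # towards orthogonality (the norm of I - A A^T should be small)\n",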
497 | " criterion = torch.nn.NLLLoss()\n",
498 | " bs=outputs.size(0)\n",
499 | " id3x3 = torch.eye(3, requires_grad=True).repeat(bs,1,1)\n",
500 | " id128x128 = torch.eye(128, requires_grad=True).repeat(bs,1,1)\n",
501 | " if outputs.is_cuda:\n",
502 | " id3x3=id3x3.cuda()\n",
503 | " id128x128=id128x128.cuda()\n",
504 | " diff3x3 = id3x3-torch.bmm(m3x3,m3x3.transpose(1,2))\n",
505 | " diff128x128 = id128x128-torch.bmm(m128x128,m128x128.transpose(1,2))\n",
506 | " return criterion(outputs, labels) + alpha * (torch.norm(diff3x3)+torch.norm(diff128x128)) / float(bs)\n",
507 | " "
508 | ],
509 | "execution_count": 0,
510 | "outputs": []
511 | },
512 | {
513 | "cell_type": "code",
514 | "metadata": {
515 | "id": "CgaPisZFwVzh",
516 | "colab_type": "code",
517 | "colab": {}
518 | },
519 | "source": [
520 | "def train(model, train_loader, val_loader=None, epochs=15, save=True):\n",
521 | " for epoch in range(epochs): \n",
522 | "        model.train()\n",
523 | " running_loss = 0.0\n",
524 | " for i, data in enumerate(train_loader, 0):\n",
525 | " inputs, labels = data['image'].to(device), data['category'].to(device)\n",
526 | " optimizer.zero_grad()\n",
527 | "            outputs, m3x3, m128x128 = model(inputs.transpose(1,2))\n",
528 | "\n",
529 | "            loss = pointnetloss(outputs, labels, m3x3, m128x128)\n",
530 | " loss.backward()\n",
531 | " optimizer.step()\n",
532 | "\n",
533 | " # print statistics\n",
534 | " running_loss += loss.item()\n",
535 | " if i % 10 == 9: # print every 10 mini-batches\n",
536 | " print('[%d, %5d] loss: %.3f' %\n",
537 | " (epoch + 1, i + 1, running_loss / 10))\n",
538 | " running_loss = 0.0\n",
539 | "\n",
540 | "        model.eval()\n",
541 | " correct = total = 0\n",
542 | "\n",
543 | " # validation\n",
544 | " if val_loader:\n",
545 | " with torch.no_grad():\n",
546 | " for data in val_loader:\n",
547 | " inputs, labels = data['image'].to(device), data['category'].to(device)\n",
548 | "                    outputs, __, __ = model(inputs.transpose(1,2))\n",
549 | " _, predicted = torch.max(outputs.data, 1)\n",
550 | " total += labels.size(0) * labels.size(1) ##\n",
551 | " correct += (predicted == labels).sum().item()\n",
552 | " val_acc = 100 * correct / total\n",
553 | " print('Valid accuracy: %d %%' % val_acc)\n",
554 | "\n",
555 | " # save the model\n",
556 | " if save:\n",
557 | "            torch.save(model.state_dict(), root_dir+\"/modelsSeg/\"+str(epoch)+\"_\"+str(val_acc))  # note: uses val_acc, so a val_loader must be provided\n"
558 | ],
559 | "execution_count": 0,
560 | "outputs": []
561 | },
562 | {
563 | "cell_type": "code",
564 | "metadata": {
565 | "id": "3jjVYFmSv9lu",
566 | "colab_type": "code",
567 | "outputId": "0c443233-6750-4382-b790-333ac78dcf98",
568 | "colab": {
569 | "base_uri": "https://localhost:8080/",
570 | "height": 1000
571 | }
572 | },
573 | "source": [
574 | "train(pointnet, train_loader, val_loader, save=True)\n"
575 | ],
576 | "execution_count": 0,
577 | "outputs": [
578 | {
579 | "output_type": "stream",
580 | "text": [
581 | "[1, 10] loss: 1.203\n",
582 | "[1, 20] loss: 0.924\n",
583 | "[1, 30] loss: 0.844\n",
584 | "[1, 40] loss: 0.800\n",
585 | "Valid accuracy: 78 %\n",
586 | "[2, 10] loss: 0.769\n",
587 | "[2, 20] loss: 0.741\n",
588 | "[2, 30] loss: 0.732\n",
589 | "[2, 40] loss: 0.729\n",
590 | "Valid accuracy: 82 %\n",
591 | "[3, 10] loss: 0.709\n",
592 | "[3, 20] loss: 0.685\n",
593 | "[3, 30] loss: 0.680\n",
594 | "[3, 40] loss: 0.676\n",
595 | "Valid accuracy: 85 %\n",
596 | "[4, 10] loss: 0.663\n",
597 | "[4, 20] loss: 0.642\n",
598 | "[4, 30] loss: 0.637\n",
599 | "[4, 40] loss: 0.636\n",
600 | "Valid accuracy: 87 %\n",
601 | "[5, 10] loss: 0.626\n",
602 | "[5, 20] loss: 0.617\n",
603 | "[5, 30] loss: 0.610\n",
604 | "[5, 40] loss: 0.613\n",
605 | "Valid accuracy: 86 %\n",
606 | "[6, 10] loss: 0.604\n",
607 | "[6, 20] loss: 0.587\n",
608 | "[6, 30] loss: 0.580\n",
609 | "[6, 40] loss: 0.583\n",
610 | "Valid accuracy: 86 %\n",
611 | "[7, 10] loss: 0.574\n",
612 | "[7, 20] loss: 0.563\n",
613 | "[7, 30] loss: 0.558\n",
614 | "[7, 40] loss: 0.565\n",
615 | "Valid accuracy: 86 %\n",
616 | "[8, 10] loss: 0.553\n",
617 | "[8, 20] loss: 0.539\n",
618 | "[8, 30] loss: 0.538\n",
619 | "[8, 40] loss: 0.543\n",
620 | "Valid accuracy: 87 %\n",
621 | "[9, 10] loss: 0.533\n",
622 | "[9, 20] loss: 0.525\n",
623 | "[9, 30] loss: 0.516\n",
624 | "[9, 40] loss: 0.521\n",
625 | "Valid accuracy: 87 %\n",
626 | "[10, 10] loss: 0.522\n",
627 | "[10, 20] loss: 0.506\n",
628 | "[10, 30] loss: 0.503\n",
629 | "[10, 40] loss: 0.507\n",
630 | "Valid accuracy: 87 %\n",
631 | "[11, 10] loss: 0.501\n",
632 | "[11, 20] loss: 0.495\n",
633 | "[11, 30] loss: 0.485\n",
634 | "[11, 40] loss: 0.497\n",
635 | "Valid accuracy: 88 %\n",
636 | "[12, 10] loss: 0.484\n",
637 | "[12, 20] loss: 0.477\n",
638 | "[12, 30] loss: 0.474\n",
639 | "[12, 40] loss: 0.477\n",
640 | "Valid accuracy: 88 %\n",
641 | "[13, 10] loss: 0.472\n",
642 | "[13, 20] loss: 0.458\n",
643 | "[13, 30] loss: 0.456\n",
644 | "[13, 40] loss: 0.464\n",
645 | "Valid accuracy: 87 %\n",
646 | "[14, 10] loss: 0.454\n",
647 | "[14, 20] loss: 0.452\n",
648 | "[14, 30] loss: 0.446\n",
649 | "[14, 40] loss: 0.458\n",
650 | "Valid accuracy: 87 %\n",
651 | "[15, 10] loss: 0.444\n",
652 | "[15, 20] loss: 0.433\n",
653 | "[15, 30] loss: 0.432\n",
654 | "[15, 40] loss: 0.440\n",
655 | "Valid accuracy: 88 %\n"
656 | ],
657 | "name": "stdout"
658 | }
659 | ]
660 | },
661 | {
662 | "cell_type": "markdown",
663 | "metadata": {
664 | "id": "VeUZqen5GlKr",
665 | "colab_type": "text"
666 | },
667 | "source": [
668 | "## Test"
669 | ]
670 | },
671 | {
672 | "cell_type": "markdown",
673 | "metadata": {
674 | "id": "mbBNmdqgGj5-",
675 | "colab_type": "text"
676 | },
677 | "source": [
678 | ""
679 | ]
680 | },
681 | {
682 | "cell_type": "code",
683 | "metadata": {
684 | "id": "Xsk9nSDAI3ba",
685 | "colab_type": "code",
686 | "outputId": "4467b734-e3db-4d48-dce3-b7f269cbaed5",
687 | "colab": {
688 | "base_uri": "https://localhost:8080/",
689 | "height": 867
690 | }
691 | },
692 | "source": [
693 | "pointnet = PointNetSeg()\n",
694 | "pointnet.load_state_dict(torch.load(root_dir+\"/modelsSeg/\"+\"14_88.01940298507462\"))\n",
695 | "pointnet.eval()"
696 | ],
697 | "execution_count": 0,
698 | "outputs": [
699 | {
700 | "output_type": "execute_result",
701 | "data": {
702 | "text/plain": [
703 | "PointNetSeg(\n",
704 | " (transform): Transform(\n",
705 | " (input_transform): Tnet(\n",
706 | " (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))\n",
707 | " (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
708 | " (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))\n",
709 | " (fc1): Linear(in_features=1024, out_features=512, bias=True)\n",
710 | " (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
711 | " (fc3): Linear(in_features=256, out_features=9, bias=True)\n",
712 | " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
713 | " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
714 | " (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
715 | " (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
716 | " (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
717 | " )\n",
718 | " (feature_transform): Tnet(\n",
719 | " (conv1): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n",
720 | " (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
721 | " (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))\n",
722 | " (fc1): Linear(in_features=1024, out_features=512, bias=True)\n",
723 | " (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
724 | " (fc3): Linear(in_features=256, out_features=16384, bias=True)\n",
725 | " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
726 | " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
727 | " (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
728 | " (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
729 | " (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
730 | " )\n",
731 | " (fc1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))\n",
732 | " (fc2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
733 | " (fc3): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n",
734 | " (fc4): Conv1d(128, 512, kernel_size=(1,), stride=(1,))\n",
735 | " (fc5): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))\n",
736 | " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
737 | " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
738 | " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
739 | " (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
740 | " (bn5): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
741 | " )\n",
742 | " (fc1): Conv1d(3008, 256, kernel_size=(1,), stride=(1,))\n",
743 | " (fc2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n",
744 | " (fc3): Conv1d(256, 128, kernel_size=(1,), stride=(1,))\n",
745 | " (fc4): Conv1d(128, 4, kernel_size=(1,), stride=(1,))\n",
746 | " (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
747 | " (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
748 | " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
749 | " (bn4): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
750 | " (logsoftmax): LogSoftmax()\n",
751 | ")"
752 | ]
753 | },
754 | "metadata": {
755 | "tags": []
756 | },
757 | "execution_count": 11
758 | }
759 | ]
760 | },
761 | {
762 | "cell_type": "code",
763 | "metadata": {
764 | "id": "zoE5fRX8GnWR",
765 | "colab_type": "code",
766 | "outputId": "901bdfb2-0445-4497-e772-e2f482b98f12",
767 | "colab": {
768 | "base_uri": "https://localhost:8080/",
769 | "height": 139
770 | }
771 | },
772 | "source": [
773 | "batch = next(iter(val_loader))\n",
774 | "pred = pointnet(batch['image'].transpose(1,2))\n",
775 | "pred_np = np.array(torch.argmax(pred[0],1));\n",
776 | "pred_np\n",
777 | "\n"
778 | ],
779 | "execution_count": 0,
780 | "outputs": [
781 | {
782 | "output_type": "execute_result",
783 | "data": {
784 | "text/plain": [
785 | "array([[1, 1, 0, ..., 0, 0, 1],\n",
786 | " [1, 2, 1, ..., 1, 2, 1],\n",
787 | " [0, 0, 2, ..., 0, 1, 3],\n",
788 | " ...,\n",
789 | " [1, 0, 1, ..., 3, 2, 1],\n",
790 | " [2, 3, 0, ..., 0, 1, 3],\n",
791 | " [1, 1, 1, ..., 0, 0, 1]])"
792 | ]
793 | },
794 | "metadata": {
795 | "tags": []
796 | },
797 | "execution_count": 22
798 | }
799 | ]
800 | },
801 | {
802 | "cell_type": "code",
803 | "metadata": {
804 | "id": "3ISvwR4RL_eT",
805 | "colab_type": "code",
806 | "outputId": "9695bbd3-e74c-41a1-9846-ad438f413a2f",
807 | "colab": {
808 | "base_uri": "https://localhost:8080/",
809 | "height": 35
810 | }
811 | },
812 | "source": [
813 | "batch['image'][0].shape"
814 | ],
815 | "execution_count": 0,
816 | "outputs": [
817 | {
818 | "output_type": "execute_result",
819 | "data": {
820 | "text/plain": [
821 | "torch.Size([2000, 3])"
822 | ]
823 | },
824 | "metadata": {
825 | "tags": []
826 | },
827 | "execution_count": 61
828 | }
829 | ]
830 | },
831 | {
832 | "cell_type": "code",
833 | "metadata": {
834 | "id": "MQXXGh3GJSRf",
835 | "colab_type": "code",
836 | "outputId": "23c684cc-7a4f-467b-a8c6-e916bae70107",
837 | "colab": {
838 | "base_uri": "https://localhost:8080/",
839 | "height": 139
840 | }
841 | },
842 | "source": [
843 | "pred_np==np.array(batch['category'])"
844 | ],
845 | "execution_count": 0,
846 | "outputs": [
847 | {
848 | "output_type": "execute_result",
849 | "data": {
850 | "text/plain": [
851 | "array([[ True, True, True, ..., True, True, True],\n",
852 | " [ True, True, True, ..., True, True, True],\n",
853 | " [ True, True, True, ..., False, True, True],\n",
854 | " ...,\n",
855 | " [ True, True, True, ..., True, True, True],\n",
856 | " [ True, True, True, ..., True, True, True],\n",
857 | " [ True, True, True, ..., False, True, True]])"
858 | ]
859 | },
860 | "metadata": {
861 | "tags": []
862 | },
863 | "execution_count": 23
864 | }
865 | ]
866 | },
867 | {
868 | "cell_type": "code",
869 | "metadata": {
870 | "id": "60Mr2Bp7O9xu",
871 | "colab_type": "code",
872 | "colab": {}
873 | },
874 | "source": [
875 | "acc = (pred_np==np.array(batch['category']))"
876 | ],
877 | "execution_count": 0,
878 | "outputs": []
879 | },
880 | {
881 | "cell_type": "code",
882 | "metadata": {
883 | "id": "Jc1jLPj-PBo-",
884 | "colab_type": "code",
885 | "colab": {}
886 | },
887 | "source": [
888 | "resulting_acc = np.sum(acc, axis=1) / 2000"
889 | ],
890 | "execution_count": 0,
891 | "outputs": []
892 | },
893 | {
894 | "cell_type": "code",
895 | "metadata": {
896 | "id": "w8NwSbbz8jQ-",
897 | "colab_type": "code",
898 | "outputId": "eba05278-cf1a-425b-8462-7684e7a50101",
899 | "colab": {
900 | "base_uri": "https://localhost:8080/",
901 | "height": 156
902 | }
903 | },
904 | "source": [
905 | "resulting_acc"
906 | ],
907 | "execution_count": 0,
908 | "outputs": [
909 | {
910 | "output_type": "execute_result",
911 | "data": {
912 | "text/plain": [
913 | "array([0.863 , 0.906 , 0.91 , 0.867 , 0.9165, 0.882 , 0.9055, 0.8445,\n",
914 | " 0.8405, 0.9235, 0.8865, 0.8855, 0.903 , 0.884 , 0.8945, 0.85 ,\n",
915 | " 0.9125, 0.822 , 0.9345, 0.895 , 0.9135, 0.9395, 0.9385, 0.9165,\n",
916 | " 0.8865, 0.848 , 0.8765, 0.9105, 0.8805, 0.83 , 0.852 , 0.9225,\n",
917 | " 0.906 , 0.7705, 0.883 , 0.785 , 0.811 , 0.8565, 0.866 , 0.868 ,\n",
918 | " 0.7855, 0.7305, 0.9155, 0.8915, 0.9065, 0.805 , 0.875 , 0.89 ,\n",
919 | " 0.813 , 0.9005, 0.8325, 0.833 , 0.879 , 0.9215, 0.8185, 0.933 ,\n",
920 | " 0.9325, 0.9 , 0.833 , 0.8535, 0.8545, 0.895 , 0.8325, 0.9295])"
921 | ]
922 | },
923 | "metadata": {
924 | "tags": []
925 | },
926 | "execution_count": 18
927 | }
928 | ]
929 | },
930 | {
931 | "cell_type": "code",
932 | "metadata": {
933 | "id": "nnt7vxPHKamU",
934 | "colab_type": "code",
935 | "outputId": "1ab89e1a-d06d-47c3-dcae-2876e5c6d605",
936 | "colab": {
937 | "base_uri": "https://localhost:8080/",
938 | "height": 139
939 | }
940 | },
941 | "source": [
942 | "pred_np"
943 | ],
944 | "execution_count": 0,
945 | "outputs": [
946 | {
947 | "output_type": "execute_result",
948 | "data": {
949 | "text/plain": [
950 | "array([[1, 1, 0, ..., 0, 0, 1],\n",
951 | " [1, 2, 1, ..., 1, 2, 1],\n",
952 | " [0, 0, 2, ..., 0, 1, 3],\n",
953 | " ...,\n",
954 | " [1, 0, 1, ..., 3, 2, 1],\n",
955 | " [2, 3, 0, ..., 0, 1, 3],\n",
956 | " [1, 1, 1, ..., 0, 0, 1]])"
957 | ]
958 | },
959 | "metadata": {
960 | "tags": []
961 | },
962 | "execution_count": 25
963 | }
964 | ]
965 | },
966 | {
967 | "cell_type": "code",
968 | "metadata": {
969 | "id": "N9bgpbtnHC2E",
970 | "colab_type": "code",
971 | "outputId": "394d29a3-8cc0-4ee7-d480-edd452a0d1ce",
972 | "colab": {
973 | "base_uri": "https://localhost:8080/"
974 | }
975 | },
976 | "source": [
977 | "x,y,z=np.array(batch['image'][0]).T\n",
978 | "c = np.array(batch['category'][0]).T\n",
979 | "\n",
980 | "fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, \n",
981 | " mode='markers',\n",
982 | " marker=dict(\n",
983 | " size=30,\n",
984 | " color=c, # set color to an array/list of desired values\n",
985 | " colorscale='Viridis', # choose a colorscale\n",
986 | " opacity=1.0\n",
987 | " ))])\n",
988 | "fig.update_traces(marker=dict(size=2,\n",
989 | " line=dict(width=2,\n",
990 | " color='DarkSlateGrey')),\n",
991 | " selector=dict(mode='markers'))\n",
992 | "fig.show()"
993 | ],
994 | "execution_count": 0,
995 | "outputs": [
996 | {
997 | "output_type": "display_data",
998 | "data": {
999 | "text/html": [
1000 | "\n",
1001 | "