├── LICENSE
├── README.md
├── assets
│   ├── framework.png
│   └── torus.gif
├── dataloaders
│   ├── __init__.py
│   ├── cad_models_loader.py
│   ├── cadmulobj_loader.py
│   ├── kittimulobj_loader.py
│   ├── modelnet_loader.py
│   └── shapenet_part_loader.py
├── experiments
│   ├── counting.py
│   ├── reconstruction.py
│   ├── train_basic.py
│   └── train_tearing.py
├── models
│   ├── __init__.py
│   ├── autoencoder.py
│   ├── foldingnet.py
│   ├── pointnet.py
│   ├── tearingnet.py
│   └── tearingnet_graph.py
├── scripts
│   ├── experiments
│   │   ├── counting.sh
│   │   ├── reconstruction.sh
│   │   ├── train_folding_cad.sh
│   │   ├── train_folding_kitti.sh
│   │   ├── train_tearing_cad.sh
│   │   └── train_tearing_kitti.sh
│   ├── gen_data
│   │   ├── gen_cad_mulobj_test_5x5.sh
│   │   ├── gen_cad_mulobj_train_5x5.sh
│   │   ├── gen_kitti_mulobj_test_5x5.sh
│   │   └── gen_kitti_mulobj_train_5x5.sh
│   └── launch.sh
└── util
    ├── __init__.py
    ├── cad_models_collector.py
    ├── mesh_writer.py
    ├── option_handler.py
    └── pcdet_create_groundtruth_database.py
/LICENSE:
--------------------------------------------------------------------------------
1 | LIMITED SOFTWARE EVALUATION LICENSE AGREEMENT
2 |
3 |
4 | The following Limited Software Evaluation License (the “License”) constitutes an agreement between you (the “Licensee”) and InterDigital Communications, Inc, a company organized and existing under the laws of the State of Delaware, USA, with its registered offices located at 200 Bellevue Parkway, Suite 300, Wilmington, DE 19809, USA (hereinafter “InterDigital”).
5 |
6 | This License governs the download and use of the Software (as defined below). Your use of the Software is subject to the terms and conditions set forth in this License. By installing, using, accessing or copying the Software, you hereby irrevocably accept the terms and conditions of this License. If you do not accept all parts of the terms and conditions of this License, you cannot install, use, access nor copy the Software.
7 |
8 |
9 | Article 1. Definitions
10 |
11 | “Affiliate” as used herein shall mean any entity that, directly or indirectly, through one or more intermediates, is controlled by, controls, or is under common control with InterDigital or The Licensee, as the case may be. For purposes of this definition only, the term “control” means the possession of the power to direct or cause the direction of the management and policies of an entity, whether by ownership of voting stock or partnership interest, by contract, or otherwise, including direct or indirect ownership of more than fifty percent (50%) of the voting interest in the entity in question.
12 |
13 | “Authorized Purpose” means any use of the Software for reproducing the experimental results reported in the following publication: Jiahao Pang, et al., “TearingNet: Point Cloud Autoencoder to Learn Topology-Friendly Representations”, IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2021 (the “Purpose”).
14 |
15 | “Derivative Work” means any work that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship.
16 |
17 | “Documentation” means textual materials delivered by InterDigital to the Licensee pursuant to this License relating to the Software, in written or electronic format, including but not limited to, technical reference manuals, technical notes, user manuals, and application guides.
18 |
19 | “Effective Date” means the date Licensee first installs a copy of the Software on any computer.
20 |
21 | “Limited Period” means the life of the copyright owned by InterDigital on the Software in each and every country where such copyright would exist.
22 |
23 | “Intellectual Property Rights” means all copyrights, trademarks, trade secrets, patents, mask words, and any other intellectual property rights recognized in any jurisdiction worldwide, including all applications and registrations with respect thereto.
24 |
25 | "Open Source Software" shall mean any software, including where appropriate, any and all modifications, derivative works, enhancements, upgrades, improvements, fixed bugs, and/or software statically linked to the source code of such software, released under a free or open source software license that requires, as a condition of usage, copy, modification and/or redistribution of such software, that the party:
26 | • Redistribute the Open Source Software royalty-free; and/or
27 | • Redistribute the Open Source Software under the same license/distribution terms as those contained in the open source or free software license under which it was originally released; and/or
28 | • Release to the public, disclose or otherwise make available the source code of the Open Source Software.
29 |
30 | For purposes of this License, by means of example and without limitation, any software that is released or distributed under any of the following licenses shall be qualified as Open Source Software: (i) GNU General Public License (GPL); (ii) GNU Lesser/Library GPL (LGPL); (iii) the Artistic License; (iv) the Mozilla Public License; (v) the Common Public License; (vi) the Sun Community Source License (SCSL); (vii) the Sun Industry Standards Source License (SISSL); (viii) BSD License; (ix) MIT License; (x) Apache Software License; (xi) Open SSL License; (xii) IBM Public License; and (xiii) Open Software License.
31 |
32 | “Software” means the Software with which this license was downloaded.
33 |
34 |
35 | Article 2. License
36 |
37 | InterDigital grants Licensee a free, worldwide, non-exclusive, license to InterDigital’s copyright on the Software to download, use and reproduce solely for the Authorized Purpose for the Limited Period.
38 |
39 | Licensee shall not pay any royalty, license fee or maintenance fee, or other fee of any nature under this License.
40 |
41 | Licensee shall have the right to correct, adapt and modify the Software, provided that such action is made to accomplish the Authorized Purpose. Licensee shall promptly provide a copy of such correction, adaptation or modification to InterDigital after any such change is made. All such changes to the Software shall be deemed Derivative Work.
42 |
43 |
44 | Article 3. Restrictions on use of the Software
45 |
46 | Licensee shall not remove, obscure or modify any copyright, trademark or other proprietary rights notices, marks or labels contained on or within the Software, falsify or delete any author attributions, legal notices or other labels of the origin or source of the material.
47 |
48 | Licensee may reproduce and distribute copies of the Software (including any Derivative Works to which Licensee has rights) in any medium, with or without modifications, provided that any such distribution meets the following conditions:
49 | 1. Licensee must give any other recipients of the Software or Derivative Works a copy of this License;
50 | 2. Licensee must cause any modified files to carry prominent notices stating that Licensee has changed the files; and
51 | 3. Licensee must retain, in the source form of any Software or Derivative Works that Licensee reproduces or distributes, all copyright, patent, trademark, and attribution notices from the source form of the Software, excluding those notices that do not pertain to any part of the Derivative Works.
52 |
53 |
54 | Article 4. Ownership
55 |
56 | Title to and ownership of the Software, the Documentation, and/or any Intellectual Property Right protecting the Software and/or the Documentation shall at all times remain with InterDigital. Licensee agrees that except for the limited rights granted to the Software as set forth in Section 2 above, in no event shall anything in this License grant, provide, or convey any other rights, privileges, immunities, or interest in or to any Intellectual Property Rights (including but not limited to patent rights) of InterDigital or any of its Affiliates, whether by implication, estoppel, or otherwise.
57 |
58 |
59 | Article 5. Derivative Works
60 |
61 | Derivative Works created by Licensee shall be owned by Licensee who will be able to use it or distribute it freely, provided that this does not infringe any of InterDigital’s rights.
62 |
63 | Licensee may add its own copyright statement to the Derivative Work and may provide additional or different license terms and conditions for use, reproduction, or distribution of Derivative Work, provided the use, reproduction, and distribution of the Derivative Work otherwise complies with the conditions stated in this License.
64 |
65 | Licensee hereby grants to InterDigital a perpetual, fully paid-up, transferrable, non-exclusive, and worldwide license in, on, and to any copyright of Licensee on any Derivative Work to use, modify, and reproduce the Derivative Work in connection with InterDigital’s use of the Software. Such license shall include the right to correct, adapt, modify, reverse engineer, disassemble, decompile or/and otherwise perform or conduct any action leading to the transformation of Derivative Work, provided that such action is made in connection with InterDigital’s use of the Software.
66 |
67 |
68 | Article 6. Publication/Communication
69 |
70 | Any publication or oral communication resulting from the use of the Software shall be elaborated in good faith and shall not be driven by a deliberate will to denigrate InterDigital or any of its products. In any publication and on any support joined to an oral communication (e.g., a PowerPoint presentation) relating to the Software, the following statement shall be inserted:
71 | “TearingNet is an InterDigital product”
72 |
73 | In any publication, the latest publication about the software shall be properly cited. The latest publication currently is:
74 | Pang, Jiahao, et al. (2021, June 19-25). TearingNet: Point Cloud Autoencoder to Learn Topology-Friendly Representations. IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
75 |
76 | In any oral communication relating to the Software and/or its use, the Licensee shall orally indicate that the Software is InterDigital’s property.
77 |
78 |
79 | Article 7. No Warranty - Disclaimer
80 |
81 | THE SOFTWARE AND DOCUMENTATION ARE PROVIDED TO LICENSEE ON AN “AS IS” BASIS. INTERDIGITAL MAKES NO WARRANTY THAT THE SOFTWARE WILL OPERATE ON ANY PARTICULAR HARDWARE, PLATFORM, OR ENVIRONMENT. THERE IS NO WARRANTY THAT THE OPERATION OF THE SOFTWARE SHALL BE UNINTERRUPTED, WITHOUT BUGS OR ERROR FREE. THE SOFTWARE AND DOCUMENTATION ARE PROVIDED HEREUNDER WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED LIABILITIES AND WARRANTIES OF NONINFRINGEMENT OF INTELLECTUAL PROPERTY, FREEDOM FROM INHERENT DEFECTS, CONFORMITY TO A SAMPLE OR MODEL, MERCHANTABILITY, FITNESS AND/OR SUITABILITY FOR A SPECIFIC OR GENERAL PURPOSE AND THOSE ARISING BY STATUTE OR BY LAW, OR FROM A CAUSE OF DEALING OR USAGE OF TRADE. ANY AND ALL SUCH IMPLIED WARRANTIES ARE FULLY DISCLAIMED BY INTERDIGITAL TO THE MAXIMUM EXTENT ALLOWED BY LAW, AND LICENSEE ACKNOWLEDGES THAT THIS DISCLAIMER OF ALL EXPRESS AND IMPLIED WARRANTIES BY INTERDIGITAL, AS WELL AS LICENSEE’S ACCEPTANCE AND ACKNOWLEDGEMENT OF THE SAME, IS A MATERIAL PART OF THE CONSIDERATION FOR THIS LICENSE.
82 |
83 | InterDigital shall not be obligated to perform or provide any modifications, derivative works, enhancements, upgrades, updates or improvements of the Software or Documentation, or to fix any bug that could arise.
84 |
85 | Licensee at all times uses the Software at its own cost, risk and responsibility. InterDigital shall not be liable for any damages that could accrue by or to Licensee as a result of its use of the Software, either in accordance with this License or not.
86 |
87 | InterDigital shall not be liable for any consequential or indirect losses, including any indirect loss of profits, revenues, business, and/or anticipated savings, whether or not in the contemplation of the Parties at the time of entering into this License unless expressly set out in this License, or arising from gross negligence, willful misconduct or fraud.
88 |
89 | Licensee agrees that it will defend, indemnify and hold harmless InterDigital and its Affiliates against any and all losses, damages, costs and expenses arising from a breach by the Licensee of any of its obligations or representations hereunder, including, without limitation, any third party claims, and/or any claims in connection with any such breach and/or any use of the Software, including any claim from third party arising from access, use, or any other activity in relation to this Software.
90 |
91 | Licensee shall not make any warranty, representation, or commitment on behalf of InterDigital to any other third party.
92 |
93 |
94 | Article 8. Open Source Software
95 |
96 | Licensee hereby represents, warrants, and covenants to InterDigital that Licensee’s use of the Software shall not result in the Contamination of all or any part of the Software, directly or indirectly, or of any Intellectual Property of InterDigital or its Affiliates.
97 |
98 | As used herein, “Contamination” shall mean that the licensing terms under which any Open Source Software, distinct from the Software, is released would also apply to the Software herein, by virtue of such Open Source Software being linked to, combined with, or otherwise connected to the Software.
99 |
100 |
101 | Article 9. No Future Contract Obligation
102 |
103 | Neither this License nor the furnishing of the Software, nor any other InterDigital information provided to Licensee, shall be construed to obligate either party to: (a) enter into any further agreement or negotiation concerning the deployment of the Software; (b) refrain from entering into any agreement or negotiation with any other third party regarding the same or any other subject matter; or (c) refrain from pursuing its business in whatever manner it elects even if this involves competing with the other party.
104 |
105 |
106 | Article 10. General Provisions
107 |
108 | 10.1 Severability. If any provision of this License shall be held to be in contravention of applicable law, this License shall be construed as if such provision were not a part thereof, and in all other respects the terms hereof shall remain in full force and effect.
109 |
110 | 10.2 Governing Law. Regardless of the place of execution, delivery, performance or any other aspect of this License, this License and all of the rights of the parties under this License shall be governed by, construed under and enforced in accordance with the substantive law of the State of Delaware, USA, without regard to conflicts of law principles. In case of a dispute that cannot be settled amicably, the state and federal courts located in New Castle County, Delaware, USA, shall have exclusive jurisdiction over such dispute, and each party hereby irrevocably waives any objection to the jurisdiction of such courts, including but not limited to objections of lack of in personam jurisdiction or based on principles of forum non conveniens.
111 |
112 | 10.3 Survival. The provisions of articles 1, 3, 4, 6, 7, 9, 10.1, 10.2 and 10.5 shall survive termination of this License.
113 |
114 | 10.4 Assignment. InterDigital may assign this license to any third Party. Licensee may not assign this agreement to any third party without InterDigital’s prior written approval.
115 |
116 | 10.5 Entire Agreement. This License constitutes the entire agreement between the parties hereto with respect to the subject matter hereof and supersedes any prior agreements or understanding.
117 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TearingNet: Point Cloud Autoencoder to Learn Topology-Friendly Representations
2 | Created by Jiahao Pang, Duanshun Li, and Dong Tian from InterDigital
3 |
4 |
5 |
6 |
7 |
8 | ## Introduction
9 | This repository contains the implementation of our TearingNet paper accepted to CVPR 2021.
10 | Given a point cloud dataset containing objects with various genera, or scenes with multiple objects, we propose the TearingNet, an autoencoder tackling the challenging task of representing point clouds using a fixed-length descriptor.
11 | Unlike existing works that directly deform predefined genus-zero primitives (e.g., a 2D square patch) into an object-level point cloud, our TearingNet
12 | is characterized by a proposed Tearing network module and a Folding network module interacting with each other iteratively.
13 | In particular, the Tearing network module learns the point cloud topology explicitly.
14 | By breaking the edges of a primitive graph, it tears the graph into patches, possibly with holes, to emulate the topology of a target point cloud, leading to faithful reconstructions.
15 |
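At a high level, the decode pass can be pictured with the following sketch (a conceptual illustration only; the names here are ours, and the actual module interfaces live under ``models/``):
```
# Conceptual sketch of the TearingNet forward pass (illustrative pseudocode)
cw = e_net(points)            # encode the input into a fixed-length codeword
pc0 = f_net(grid, cw)         # fold a 2D grid primitive into a coarse point cloud
grid = t_net(grid, pc0, cw)   # tear the grid: break edges to emulate the target topology
pc1 = f_net(grid, cw)         # fold again with the torn grid for the final reconstruction
```
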
16 | ## Installation
17 | * We use Python 3.6, PyTorch 1.3.1 and CUDA 10.0. Example commands to set up a virtual environment with Anaconda are:
18 | ```
19 | conda create -n tearingnet python=3.6
20 | conda activate tearingnet
21 | conda install pytorch=1.3.1 torchvision=0.4.2 cudatoolkit=10.0 -c pytorch
22 | ```
23 |
24 | * Clone our repo to a folder; assume the name of the folder is ``TearingNet``.
25 | * Check out the nndistance folder from 3D Point Capsule Network, put it under ``TearingNet/util``, and build it according to the instructions of 3D Point Capsule Network.
26 | * We adopt the ShapeNet part data loader from 3D Point Capsule Network. Check out shapenet_part_loader.py and put it under ``TearingNet/dataloaders``.
27 | * Install the TensorboardX, Open3D and h5py packages; example commands are:
28 | ```
29 | conda install -c open3d-admin open3d
30 | conda install -c conda-forge tensorboardx
31 | conda install -c anaconda h5py
32 | ```
33 |
34 | ## Data Preparation
35 |
36 | ### KITTI Multi-Object Dataset
37 |
38 | * Our KITTI Multi-Object (KIMO) Dataset is constructed with kitti_dataset.py of PCDet (commit 95d2ab5). Please clone and install PCDet, then prepare the KITTI dataset according to their instructions.
39 | * Assume the name of the cloned folder is ``PCDet``. Please replace the ``create_groundtruth_database()`` function in ``kitti_dataset.py`` with our modified one provided in ``TearingNet/util/pcdet_create_groundtruth_database.py``.
40 | * Prepare the KITTI dataset, then generate the data infos according to the instructions in the README.md of PCDet.
41 | * Create the folders ``TearingNet/dataset`` and ``TearingNet/dataset/kittimulobj``, then put the newly-generated folder ``PCDet/data/kitti/kitti_single`` under ``TearingNet/dataset/kittimulobj``. Also put the newly-generated file ``PCDet/data/kitti/kitti_dbinfos_object.pkl`` under the ``TearingNet/dataset/kittimulobj`` folder.
42 | * Instead of assembling several single-object point clouds and writing them down as multi-object point clouds, we generate the parameters that describe each multi-object point cloud, then assemble it on the fly during training/testing (a minimal sketch of this mechanism follows the file structure below). To obtain the parameters, run our prepared scripts as follows under the ``TearingNet`` folder. These scripts generate the training and testing splits of the KIMO-5 dataset:
43 | ```
44 | ./scripts/launch.sh ./scripts/gen_data/gen_kitti_mulobj_train_5x5.sh
45 | ./scripts/launch.sh ./scripts/gen_data/gen_kitti_mulobj_test_5x5.sh
46 | ```
47 | * The file structure of the KIMO dataset after these steps becomes:
48 | ```
49 | kittimulobj
50 | ├── kitti_dbinfos_object.pkl
51 | ├── kitti_mulobj_param_test_5x5_2048.pkl
52 | ├── kitti_mulobj_param_train_5x5_2048.pkl
53 | └── kitti_single
54 | ├── 0_0_Pedestrian.bin
55 | ├── 1000_0_Car.bin
56 | ├── 1000_1_Car.bin
57 | ├── 1000_2_Van.bin
58 | ...
59 | ```
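As a minimal sketch of the on-the-fly mechanism mentioned above (the file name and dictionary keys follow the generated parameter files; see ``dataloaders/kittimulobj_loader.py`` for the full assembly logic):
```
import pickle

# Each entry stores, per scene, the single-object indices ('idx') and their
# placements ('coor'); the object points themselves are read from
# kitti_single/*.bin only when an example is actually requested.
with open('dataset/kittimulobj/kitti_mulobj_param_train_5x5_2048.pkl', 'rb') as f:
    params = pickle.load(f)
example = params['list_example'][0]
print(example['idx'], example['coor'])
```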
60 |
61 | ### CAD Model Multi-Object Dataset
62 |
63 | * Create the ``TearingNet/dataset/cadmulobj`` folder to hold the CAD Model Multi-Object (CAMO) dataset.
64 |
65 | * Our CAMO dataset is based on the CAD models of "person", "car", "cone", and "plant" from ModelNet40, and "motorbike" from the ShapeNetPart dataset. Use the scripts download_shapenet_part16_catagories.sh and download_modelnet40_same_with_pointnet.sh from 3D Point Capsule Network to download these two datasets, then organize them according to the following file structure under the ``TearingNet/dataset`` folder:
66 | ```
67 | dataset
68 | ├── cadmulobj
69 | ├── kittimulobj
70 | ├── modelnet40
71 | │ └── modelnet40_ply_hdf5_2048
72 | │ ├── ply_data_test0.h5
73 | │ ├── ply_data_test_0_id2file.json
74 | │ ├── ply_data_test1.h5
75 | │ ├── ply_data_test_1_id2file.json
76 | │ ...
77 | └── shapenet_part
78 | ├── shapenetcore_partanno_segmentation_benchmark_v0
79 | │ ├── 02691156
80 | │ │ ├── points
81 | │ │ │ ├── 1021a0914a7207aff927ed529ad90a11.pts
82 | │ │ │ ├── 103c9e43cdf6501c62b600da24e0965.pts
83 | │ │ │ ├── 105f7f51e4140ee4b6b87e72ead132ed.pts
84 | ...
85 | ```
86 | * Extract the "person", "car", "cone" and "plant" models from ModelNet40, and the "motorbike" models from the ShapeNet part dataset, by running the following Python script under the ``TearingNet`` folder:
87 | ```
88 | python util/cad_models_collector.py
89 | ```
90 | * The previous step generates the file ``TearingNet/dataset/cadmulobj/cad_models.npy``, based on which we generate the parameters for the CAMO dataset. To do so, launch the following scripts:
91 | ```
92 | ./scripts/launch.sh ./scripts/gen_data/gen_cad_mulobj_train_5x5.sh
93 | ./scripts/launch.sh ./scripts/gen_data/gen_cad_mulobj_test_5x5.sh
94 | ```
95 | * The file structure of the CAMO dataset after these steps becomes:
96 | ```
97 | cadmulobj
98 | ├── cad_models.npy
99 | ├── cad_mulobj_param_test_5x5.npy
100 | └── cad_mulobj_param_train_5x5.npy
101 | ```
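A quick way to sanity-check the collected models (the dictionary keys follow ``dataloaders/cad_models_loader.py``):
```
import numpy as np

models = np.load('dataset/cadmulobj/cad_models.npy', allow_pickle=True).item()
for name in ('person', 'car', 'cone', 'plant', 'motorbike'):
    print(name, models[name].shape)  # one array of point clouds per category
```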
102 | ## Experiments
103 |
104 | ### Training
105 |
106 | We employ a two-stage strategy to train the TearingNet. The first stage trains a FoldingNet (the E-Net & F-Net in the paper). Taking the KIMO dataset as an example, launch the following script under the ``TearingNet`` folder:
107 | ```
108 | ./scripts/launch.sh ./scripts/experiments/train_folding_kitti.sh
109 | ```
110 | Having finished the first stage, a pretrained model will be saved in ``TearingNet/results/train_folding_kitti``. To load the pretrained FoldingNet into a TearingNet configuration and perform training, launch the following script:
111 | ```
112 | ./scripts/launch.sh ./scripts/experiments/train_tearing_kitti.sh
113 | ```
114 | To see the meanings of the parameters in ``train_folding_kitti.sh`` and ``train_tearing_kitti.sh``, check the Python script ``TearingNet/util/option_handler.py``.
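The checkpoints written by these scripts store both the model weights and the options used (see ``experiments/``), so stage two can bootstrap from stage one roughly as follows (a sketch only; the checkpoint file name is hypothetical, and the actual logic lives in ``experiments/train_tearing.py``):
```
import torch
from models.autoencoder import PointCloudAutoencoder

ckpt = torch.load('results/train_folding_kitti/some_checkpoint.pth')  # hypothetical name
ae = PointCloudAutoencoder(ckpt['opt'])                  # rebuild from the saved options
ae.load_state_dict(ckpt['model_state_dict'], strict=False)  # assumption: T-Net weights start fresh
```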
115 | ### Reconstruction
116 |
117 | To perform the reconstruction experiment with the trained model, launch the following script:
118 | ```
119 | ./scripts/launch.sh ./scripts/experiments/reconstruction.sh
120 | ```
121 | One may write the reconstructions in PLY format by setting a positive ``PC_WRITE_FREQ`` value. Again, please refer to ``TearingNet/util/option_handler.py`` for the meanings of individual parameters.
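For quick inspection, any (N, 3) point array can also be dumped to PLY with a few lines of Open3D, which is already a dependency; the random array below is a stand-in for a real reconstruction:
```
import numpy as np
import open3d as o3d

xyz = np.random.rand(2048, 3)  # placeholder for a reconstructed point set
pc = o3d.geometry.PointCloud()
pc.points = o3d.utility.Vector3dVector(xyz)
o3d.io.write_point_cloud('example.ply', pc)  # readable by any PLY viewer
```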
122 | ### Counting
123 |
124 | To perform the counting experiment with the trained model, launch the following script:
125 | ```
126 | ./scripts/launch.sh ./scripts/experiments/counting.sh
127 | ```
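The essence of this experiment, condensed from ``experiments/counting.py``: an SVM is fit on the fixed-length codewords to predict the number of objects in each scene, and the mean absolute error is reported across folds:
```
import numpy as np
from sklearn.svm import SVC

def counting_mae(cw_train, y_train, cw_test, y_test, C=1.0):
    classifier = SVC(gamma='scale', C=C)  # same classifier settings as counting.py
    classifier.fit(cw_train, y_train)     # codewords -> object counts
    return np.mean(np.abs(y_test - classifier.predict(cw_test)))
```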
128 |
129 | ## Citing this Work
130 | Please cite our work if you find it useful for your research:
131 | ```
132 | @inproceedings{pang2021tearingnet,
133 | title={TearingNet: Point Cloud Autoencoder to Learn Topology-Friendly Representations},
134 | author={Pang, Jiahao and Li, Duanshun and Tian, Dong},
135 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
136 | year={2021}
137 | }
138 | ```
139 |
140 | ## Related Projects
141 | * 3D Point Capsule Networks
142 | * AtlasNet
143 | * AtlasNetV2
144 |
145 | * PCDet
146 |
147 |
148 |
149 |
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InterDigitalInc/TearingNet/1e43caa266a11e9d50c2d912064bd39d369bb120/assets/framework.png
--------------------------------------------------------------------------------
/assets/torus.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InterDigitalInc/TearingNet/1e43caa266a11e9d50c2d912064bd39d369bb120/assets/torus.gif
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | from . import cadmulobj_loader, kittimulobj_loader
5 |
6 |
7 | # Build the dataset for training accordingly
8 | def point_cloud_dataset_train(dataset_name, num_points, batch_size, train_split='train', val_batch_size=1, num_workers=8):
9 | train_dataset = None
10 | train_dataloader = None
11 | val_dataset = None
12 | val_dataloader = None
13 | if dataset_name.lower() == 'cad_mulobj': # CAD model multiple-object dataset
14 | train_dataset = cadmulobj_loader.CADMultiObjectDataset(num_points=num_points, split=train_split, normalize=True)
15 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
16 | elif dataset_name.lower() == 'kitti_mulobj': # KITTI multiple-object dataset
17 | train_dataset = kittimulobj_loader.KITTIMultiObjectDataset(num_points=num_points, split=train_split)
18 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
19 | return train_dataset, train_dataloader, val_dataset, val_dataloader
20 |
21 |
22 | # Build the dataset for testing accordingly
23 | def point_cloud_dataset_test(dataset_name, num_points, batch_size, test_split='test', test_class=None, num_workers=8):
24 | test_dataset, test_dataloader = None, None
25 | if dataset_name.lower() == 'cad_mulobj': # Our multiple-object dataset
26 | test_dataset = cadmulobj_loader.CADMultiObjectDataset(num_points=num_points, split=test_split, normalize=True)
27 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
28 | elif dataset_name.lower() == 'kitti_mulobj': # Our multiple-object dataset
29 | test_dataset = kittimulobj_loader.KITTIMultiObjectDataset(num_points=num_points, split=test_split)
30 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
31 | return test_dataset, test_dataloader
32 |
33 |
34 | # Function to generate rotation matrix
35 | def gen_rotation_matrix(theta=-1, rotation_axis=-1):
36 | all_theta = [0, 90, 180, 270]
37 | all_axis = np.eye(3)
38 |
39 | if theta == -1:
40 | theta = all_theta[np.random.randint(0, 4)]
41 | elif theta == -2:
42 | theta = np.random.rand() * 360
43 | elif theta == -3:
44 | theta = (np.random.rand() - 0.5) * 90
45 | else: theta = all_theta[theta]
46 | rotation_theta = np.deg2rad(theta)
47 |
48 | if rotation_axis == -1:
49 | rotation_axis = all_axis[np.random.randint(0, 3), :]
50 | else: rotation_axis = all_axis[rotation_axis,:]
51 | sin_xyz = np.sin(rotation_theta) * rotation_axis
52 |
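# Rodrigues' rotation formula: R = cos(t)*I + sin(t)*[k]_x + (1 - cos(t))*k*k^T, with k the unit rotation axis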
53 | R = np.cos(rotation_theta) * np.eye(3)
54 |
55 | R[0,1] = -sin_xyz[2]
56 | R[0,2] = sin_xyz[1]
57 | R[1,0] = sin_xyz[2]
58 | R[1,2] = -sin_xyz[0]
59 | R[2,0] = -sin_xyz[1]
60 | R[2,1] = sin_xyz[0]
61 | R = R + (1 - np.cos(rotation_theta)) * np.dot(np.expand_dims(rotation_axis, axis=1), np.expand_dims(rotation_axis, axis=0))
62 | R = torch.from_numpy(R)
63 | return R.float()
--------------------------------------------------------------------------------
/dataloaders/cad_models_loader.py:
--------------------------------------------------------------------------------
1 | '''
2 | Dataloader for CAD models of "person", "car", "cone", "plant" from ModelNet40, and "motorbike" from ShapeNetPart
3 | '''
4 |
5 | import torch.utils.data as data
6 | import os
7 | import os.path
8 | import numpy as np
9 |
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/'))
12 |
13 |
14 | class CADModelsDataset(data.Dataset):
15 |
16 | def create_cad_models_dataset_pickle(self, root):
17 | dict_dataset = np.load(os.path.join(root, 'cadmulobj/cad_models.npy'), allow_pickle=True).item()
18 | data_person = dict_dataset['person'] # 0. person
19 | data_car = dict_dataset['car'] # 1. car
20 | data_cone = dict_dataset['cone'] # 2. cone
21 | data_plant = dict_dataset['plant'] # 3. plant
22 | data_motorbike = dict_dataset['motorbike'] # 4. motorbike
23 |
24 | label_person = np.ones(data_person.shape[0], dtype=int) * 0
25 | label_car = np.ones(data_car.shape[0], dtype=int) * 1
26 | label_cone = np.ones(data_cone.shape[0], dtype=int) * 2
27 | label_plant = np.ones(data_plant.shape[0], dtype=int) * 3
28 | label_motorbike = np.ones(data_motorbike.shape[0], dtype=int) * 4
29 | self.data = np.concatenate((data_person, data_car, data_cone, data_plant, data_motorbike), axis=0)
30 | self.label = np.concatenate((label_person, label_car, label_cone, label_plant, label_motorbike), axis=0)
31 | self.total = self.data.shape[0]
32 | self.obj_type_num = 5
33 |
34 | def __init__(self, root=dataset_path, num_points=2048, normalize=True):
35 | self.npoints = num_points
36 | self.normalize = normalize
37 | self.create_cad_models_dataset_pickle(root)
38 |
39 | def __getitem__(self, index):
40 | point_set = self.data[index, 0 : self.npoints, :]
41 | label = self.label[index]
42 | if self.normalize:
43 | point_set = self.pc_normalize(point_set)
44 | return point_set, label
45 |
46 | def __len__(self):
47 | return self.total
48 |
49 | # pc: NxC, return NxC
50 | def pc_normalize(self, pc):
51 | l = pc.shape[0]
52 | centroid = np.mean(pc, axis=0)
53 | pc = pc - centroid
54 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
55 | pc = pc / m
56 | return pc
57 |
--------------------------------------------------------------------------------
/dataloaders/cadmulobj_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Multi-object dataset based on CAD models
5 | '''
6 |
7 | import torch.utils.data as data
8 | import os
9 | import sys
10 | import os.path
11 | import numpy as np
12 |
13 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
14 | dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/'))
15 |
16 |
17 | class CADMultiObjectDataset(data.Dataset):
18 |
19 | def __init__(self, root=dataset_path, num_points=2048, split=None, normalize=True):
20 | from . import cad_models_loader
21 | self.npoints = num_points
22 | self.normalize = normalize
23 | dict_dataset = np.load(os.path.join(root, 'cadmulobj/cad_mulobj_param_' + split.lower() + '.npy'), allow_pickle=True).item()
24 | self.scene_radius = dict_dataset['scene_radius']
25 | self.total = dict_dataset['total']
26 | self.augmentation = dict_dataset['augmentation']
27 | self.num_data_batch = dict_dataset['num_batch']
28 | self.batch_num_model = dict_dataset['batch_num_model']
29 | self.batch_num_example = dict_dataset['batch_num_example']
30 | self.data_list = dict_dataset['list_example']
31 | self.max_obj_num = np.max(self.batch_num_model)
32 | self.base_dataset = cad_models_loader.CADModelsDataset(num_points=2048, normalize=True)
33 | self.obj_type_num = self.base_dataset.obj_type_num
34 |
35 | def __getitem__(self, index):
36 |
37 | index_new = index
38 | label = np.ones(self.max_obj_num, dtype=int) * -1
39 | point_set = []
40 | num_points_each = int(np.ceil(self.npoints / len(self.data_list[index_new]['idx'])))
41 |
42 | for cnt, idx_obj in enumerate(self.data_list[index_new]['idx']): # take out the models one-by-one
43 | obj_pc = self.pc_normalize(self.base_dataset[idx_obj][0][:num_points_each, :])
44 | label[cnt] = self.base_dataset[idx_obj][1]
45 | trans = self.data_list[index_new]['coor'][cnt].copy()
46 | if self.augmentation: # rotation augmentation if needed
47 | obj_pc = obj_pc @ self.gen_rotation_matrix(self.data_list[index_new]['dr'][cnt][0], 1)
48 | trans[1] = trans[1] - np.min(obj_pc[:,1])
49 | point_set.append(obj_pc + trans)
50 |
51 | point_set = np.vstack(point_set)[:self.npoints,:]
52 | return point_set.astype(np.float32), label
53 |
54 | def __len__(self):
55 | return self.total
56 |
57 | # pc: NxC, return NxC
58 | def pc_normalize(self, pc):
59 | l = pc.shape[0]
60 | centroid = np.mean(pc, axis=0)
61 | pc = pc - centroid
62 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
63 | pc = pc / m
64 | return pc
65 |
66 | # Generate rotation matrix, theta is in terms of degree
67 | def gen_rotation_matrix(self, theta, rotation_axis):
68 | all_axis = np.eye(3)
69 | rotation_theta = np.deg2rad(theta)
70 | rotation_axis = all_axis[rotation_axis,:]
71 | sin_xyz = np.sin(rotation_theta)*rotation_axis
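# Rodrigues' rotation formula, as in dataloaders/__init__.py: R = cos(t)*I + sin(t)*[k]_x + (1 - cos(t))*k*k^T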
72 | R = np.cos(rotation_theta)*np.eye(3)
73 | R[0,1] = -sin_xyz[2]
74 | R[0,2] = sin_xyz[1]
75 | R[1,0] = sin_xyz[2]
76 | R[1,2] = -sin_xyz[0]
77 | R[2,0] = -sin_xyz[1]
78 | R[2,1] = sin_xyz[0]
79 | R = R + (1-np.cos(rotation_theta))*np.dot(np.expand_dims(rotation_axis, axis = 1), np.expand_dims(rotation_axis, axis = 0))
80 | return R
81 |
82 |
83 | def draw_coor(radius):
84 | x = np.random.randint(radius) * 2 - (radius - 1)
85 | y = np.random.randint(radius) * 2 - (radius - 1)
86 | z = 0
87 | return x,y,z
88 |
89 |
90 | # Use this main() function to generate data
91 | def main():
92 |
93 | if not os.path.exists(opt.cad_mulobj_output_file_name + '.npy'):
94 | num_batch = len(opt.cad_mulobj_num_example)
95 | list_example = []
96 | for idx_batch in range(num_batch):
97 | for idx_example in range(opt.cad_mulobj_num_example[idx_batch]):
98 | print("Batch: %d, Idx: %d" % (idx_batch, idx_example))
99 | coor = np.zeros((opt.cad_mulobj_num_add_model[idx_batch], 3), dtype=float) # coordinate of the object
100 | idx = np.zeros(opt.cad_mulobj_num_add_model[idx_batch], dtype=int) # object index from the base dataset
101 | dr = np.zeros((opt.cad_mulobj_num_add_model[idx_batch], 2), dtype=int) # object orientation
102 |
103 | # Generate object properties
104 | for idx_obj in range(opt.cad_mulobj_num_add_model[idx_batch]):
105 | collision = True
106 | if opt.augmentation:
107 | dr[idx_obj, 0], dr[idx_obj, 1] = np.random.randint(0,360), np.random.randint(0,3) # generate direction
108 | else: dr[idx_obj, 0], dr[idx_obj, 1] = 0, 0
109 | idx[idx_obj] = np.random.randint(opt.cad_mulobj_num_ava_model) # generate object index
110 | while collision: # generate coordinate
111 | collision = False
112 | coor[idx_obj,2], coor[idx_obj,0], coor[idx_obj,1] = draw_coor(opt.cad_mulobj_scene_radius)
113 | for check_obj in range(idx_obj): # check collision between idx_obj and check_obj
114 | if np.sum(np.power(coor[idx_obj,:] - coor[check_obj,:], 2)) < 4 - 1e-9:
115 | collision = True
116 | break
117 |
118 | # Generate the object index
119 | list_example.append({'coor':coor, 'idx':idx, 'dr':dr})
120 |
121 | # Save the dataset parameters
122 | dict_dataset = {
123 | 'scene_radius': opt.cad_mulobj_scene_radius,
124 | 'total': sum(opt.cad_mulobj_num_example),
125 | 'augmentation': opt.augmentation,
126 | 'ava_model_idx': opt.cad_mulobj_num_ava_model,
127 | 'num_batch': len(opt.cad_mulobj_num_example),
128 | 'batch_num_model': opt.cad_mulobj_num_add_model,
129 | 'batch_num_example': opt.cad_mulobj_num_example,
130 | 'list_example': list_example
131 | }
132 | np.save(opt.cad_mulobj_output_file_name, dict_dataset)
133 | print('Dataset %s generation completed.' % opt.cad_mulobj_output_file_name)
134 |
135 | if __name__ == "__main__":
136 |
137 | sys.path.append(os.path.join(BASE_DIR, '..'))
138 | from util.option_handler import GenerateCADMultiObjectOptionHandler
139 | option_handler = GenerateCADMultiObjectOptionHandler()
140 | opt = option_handler.parse_options() # all options are parsed through this command
141 | option_handler.print_options(opt) # print out all the options
142 | main()
--------------------------------------------------------------------------------
/dataloaders/kittimulobj_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Multi-object dataset based on KITTI
5 | '''
6 |
7 | import torch.utils.data as data
8 | import pickle
9 | import os
10 | import os.path
11 | import numpy as np
12 | import sys
13 |
14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15 | dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/'))
16 |
17 |
18 | class KITTIMultiObjectDataset(data.Dataset):
19 |
20 | def __init__(self, root=dataset_path, num_points=2048, split='train'):
21 |
22 | with open(os.path.join(root, 'kittimulobj/kitti_mulobj_param_' + split.lower() + '_' + str(num_points) +'.pkl'), 'rb') as pickle_file:
23 | dict_dataset = pickle.load(pickle_file)
24 | with open(os.path.join(root, 'kittimulobj', dict_dataset['base_dataset'] +'.pkl'), 'rb') as pickle_file:
25 | self.obj_dataset = pickle.load(pickle_file)
26 | self.name_dict={'Pedestrian':0, 'Car':1, 'Cyclist':2, 'Van':3, 'Truck':4}
27 | self.npoints = num_points
28 | self.scene_radius = dict_dataset['scene_radius']
29 | self.total = dict_dataset['total']
30 | self.num_data_batch = dict_dataset['num_batch']
31 | self.batch_num_model = dict_dataset['batch_num_model']
32 | self.batch_num_example = dict_dataset['batch_num_example']
33 | self.data_list = dict_dataset['list_example']
34 | self.datapath = os.path.join(root, 'kittimulobj/kitti_single')
35 | self.max_obj_num = np.max(self.batch_num_model)
36 | self.obj_type_num = len(self.name_dict)
37 | if split.lower().find('test') >= 0:
38 | self.test = True
39 | else: self.test = False
40 |
41 | def __getitem__(self, index):
42 |
43 | label = np.ones(self.max_obj_num, dtype=int) * -1
44 | point_set = []
45 | num_points_each = int(np.ceil(self.npoints / len(self.data_list[index]['idx'])))
46 | for cnt, idx_obj in enumerate(self.data_list[index]['idx']): # take out the models one-by-one
47 | obj_pc = np.fromfile(os.path.join(self.datapath, os.path.basename(self.obj_dataset[idx_obj]['path'])), dtype=np.float32).reshape(-1, 4) # read
48 | if self.test: np.random.seed(0)
49 | obj_pc = obj_pc[np.random.choice(obj_pc.shape[0],num_points_each), 0:3] # sample
50 | label[cnt] = self.name_dict[self.obj_dataset[idx_obj]['name']]
51 | obj_pc = self.pc_normalize(obj_pc, self.obj_dataset[idx_obj]['box3d_lidar'], self.obj_dataset[idx_obj]['name']) # normalize
52 | obj_pc[:,:2] += self.data_list[index]['coor'][cnt][:2] # translate
53 | point_set.append(obj_pc)
54 | point_set = np.vstack(point_set)[:self.npoints,:]
55 |
56 | return point_set.astype(np.float32), label
57 |
58 | def __len__(self):
59 | return self.total
60 |
61 | # pc: NxC, return NxC
62 | def pc_normalize(self, pc, bbox, label):
63 | pc -= bbox[:3]
64 | box_len = np.sqrt(bbox[3] ** 2 + bbox[4] ** 2 + bbox[5] ** 2)
65 | pc = pc / (box_len) * 2
66 | if label == 'Pedestrian' or label == 'Cyclist':
67 | pc /= 2 # shrink the point sets for models with a person to better emulate a driving scene
68 | return pc
69 |
70 | def draw_coor(radius):
71 | x = np.random.randint(radius) * 2 - (radius - 1)
72 | y = np.random.randint(radius) * 2 - (radius - 1)
73 | z = 0
74 | return x,y,z
75 |
76 | # Use this main() function to generate data
77 | def main():
78 | dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/kittimulobj'))
79 | if not os.path.exists(opt.kitti_mulobj_output_file_name + '.pkl'):
80 | with open(os.path.join(dataset_path, "kitti_dbinfos_object.pkl"), 'rb') as pickle_file:
81 | db_obj = pickle.load(pickle_file)
82 | useful_obj_list = []
83 | max_obj_num = int(opt.kitti_mulobj_scene_radius ** 2)
84 | for add_obj_num in range(max_obj_num):
85 | useful_obj_list.append([])
86 | for idx_obj in range(len(db_obj)):
87 | for add_obj_num in range(max_obj_num):
88 | if db_obj[idx_obj]['num_points_in_gt'] >= int(np.ceil(opt.num_points / (add_obj_num + 1))):
89 | useful_obj_list[add_obj_num].append(idx_obj)
90 | num_batch = len(opt.kitti_mulobj_num_example)
91 | list_example = []
92 | for idx_batch in range(num_batch):
93 | for idx_example in range(opt.kitti_mulobj_num_example[idx_batch]):
94 | print("Batch: %d, Idx: %d" % (idx_batch, idx_example))
95 | coor = np.zeros((opt.kitti_mulobj_num_add_model[idx_batch], 3), dtype=float) # coordinate of the object
96 | idx = np.zeros(opt.kitti_mulobj_num_add_model[idx_batch], dtype=int) # object index from the base dataset
97 | dr = np.zeros((opt.kitti_mulobj_num_add_model[idx_batch], 2), dtype=int) # object orientation
98 |
99 | # Generate object properties
100 | for idx_obj in range(opt.kitti_mulobj_num_add_model[idx_batch]):
101 | collision = True
102 | idx[idx_obj] = np.random.choice(useful_obj_list[opt.kitti_mulobj_num_add_model[idx_batch]-1]) # generate object index
103 | while collision: # generate coordinate
104 | collision = False
105 | coor[idx_obj,0], coor[idx_obj,1], coor[idx_obj,2] = draw_coor(opt.kitti_mulobj_scene_radius)
106 | for check_obj in range(idx_obj): # check collision between idx_obj and check_obj
107 | if np.sum(np.power(coor[idx_obj,:] - coor[check_obj,:], 2)) < 4 - 1e-9:
108 | collision = True
109 | break
110 |
111 | # Generate the object index
112 | list_example.append({'coor':coor, 'idx':idx, 'dr':dr})
113 |
114 | # Save the dataset parameters
115 | dict_dataset = {
116 | 'base_dataset': "kitti_dbinfos_object",
117 | 'scene_radius': opt.kitti_mulobj_scene_radius,
118 | 'total': sum(opt.kitti_mulobj_num_example),
119 | 'num_batch': len(opt.kitti_mulobj_num_example),
120 | 'batch_num_model': opt.kitti_mulobj_num_add_model,
121 | 'batch_num_example': opt.kitti_mulobj_num_example,
122 | 'list_example': list_example
123 | }
124 | with open(opt.kitti_mulobj_output_file_name + '.pkl', 'wb') as f:
125 | print(f.name)
126 | pickle.dump(dict_dataset, f)
127 | print('Dataset generation completed.')
128 | else: print('Dataset already exists.')
129 |
130 |
131 | if __name__ == "__main__":
132 |
133 | sys.path.append(os.path.join(BASE_DIR, '..'))
134 | from util.option_handler import GenerateKittiMultiObjectOptionHandler
135 | option_handler = GenerateKittiMultiObjectOptionHandler()
136 | opt = option_handler.parse_options() # all options are parsed through this command
137 | option_handler.print_options(opt) # print out all the options
138 | main()
--------------------------------------------------------------------------------
/dataloaders/modelnet_loader.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified based on https://github.com/yongheng1991/3D-point-capsule-networks/blob/master/dataloaders/modelnet40_loader.py
3 | '''
4 |
5 | import os
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 | import h5py
9 |
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | data_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/modelnet40')) + '/'
12 |
13 | class ModelNet40(Dataset):
14 |
15 | def load_h5(self, file_name):
16 | f = h5py.File(file_name, 'r')
17 | data = f['data'][:]
18 | label = f['label'][:]
19 | return data, label
20 |
21 | def __init__(self, data_path=data_path, num_points=2048, transform=None,
22 | phase='train'):
23 | self.data_path = os.path.join(data_path, 'modelnet40_ply_hdf5_2048')
24 | self.num_points = num_points
25 | self.num_classes = 40
26 | self.transform = transform
27 |
28 | # store data
29 | shape_name_file = os.path.join(self.data_path, 'shape_names.txt')
30 | self.shape_names = [line.rstrip() for line in open(shape_name_file)]
31 | self.coordinates = []
32 | self.labels = []
33 | try:
34 | files = os.path.join(self.data_path, '{}_files.txt'.format(phase))
35 | files = [line.rstrip() for line in open(files)]
36 | for index, file in enumerate(files):
37 | file_name = file.split('/')[-1]
38 | files[index] = os.path.join(self.data_path, file_name)
39 | except FileNotFoundError:
40 | raise ValueError('Unknown phase or invalid data path.')
41 | for file in files:
42 | current_data, current_label = self.load_h5(file)
43 | current_data = current_data[:, 0:self.num_points, :]
44 | self.coordinates.append(current_data)
45 | self.labels.append(current_label)
46 | self.coordinates = np.vstack(self.coordinates).astype(np.float32)
47 | self.labels = np.vstack(self.labels).squeeze().astype(np.int64)
48 |
49 | def __len__(self):
50 | return self.coordinates.shape[0]
51 |
52 | def __getitem__(self, index):
53 | # coord = np.transpose(self.coordinates[index]) # 3 * N
54 | coord = self.coordinates[index]
55 | label = self.labels[index]
56 | data = (coord,)
57 | # transform coordinates
58 | if self.transform is not None:
59 | transformed, matrix, mask = self.transform(coord)
60 | data += (transformed, matrix, mask)
61 | data += (label,)
62 | return data
63 |
64 | def class_extractor(label, loader):
65 | list_obj=[]
66 | for i in range(len(loader)):
67 | data, label_cur = loader.__getitem__(i)
68 | if label_cur == label:
69 | list_obj.append(data)
70 | list_obj = np.vstack(list_obj).reshape(-1,2048,3)
71 | return list_obj
72 |
73 |
74 | if __name__ == '__main__':
75 | dataset = ModelNet40(phase='train') # basic smoke test
76 | print(len(dataset), 'models loaded')
--------------------------------------------------------------------------------
/dataloaders/shapenet_part_loader.py:
--------------------------------------------------------------------------------
2 | #from __future__ import print_function
3 | import torch.utils.data as data
4 | import os
5 | import os.path
6 | import torch
7 | import json
8 | import numpy as np
9 | import sys
10 |
11 |
12 |
13 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
14 | dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/shapenet_part/shapenetcore_partanno_segmentation_benchmark_v0/'))
15 |
16 | class PartDataset(data.Dataset):
17 | def __init__(self, root=dataset_path, npoints=2500, classification=False, class_choice=None, split='train', normalize=True):
18 | self.npoints = npoints
19 | self.root = root
20 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
21 | self.cat = {}
22 | self.classification = classification
23 | self.normalize = normalize
24 |
25 | with open(self.catfile, 'r') as f:
26 | for line in f:
27 | ls = line.strip().split()
28 | self.cat[ls[0]] = ls[1]
29 | # print(self.cat)
30 | if not class_choice is None:
31 | self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
32 | print(self.cat)
33 | self.meta = {}
34 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
35 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
36 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
37 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
38 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
39 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
40 |
41 | for item in self.cat:
42 | # print('category', item)
43 | self.meta[item] = []
44 | dir_point = os.path.join(self.root, self.cat[item], 'points')
45 | dir_seg = os.path.join(self.root, self.cat[item], 'points_label')
46 | # print(dir_point, dir_seg)
47 | fns = sorted(os.listdir(dir_point))
48 | if split == 'trainval':
49 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
50 | elif split == 'train':
51 | fns = [fn for fn in fns if fn[0:-4] in train_ids]
52 | elif split == 'val':
53 | fns = [fn for fn in fns if fn[0:-4] in val_ids]
54 | elif split == 'test':
55 | fns = [fn for fn in fns if fn[0:-4] in test_ids]
56 | else:
57 | print('Unknown split: %s. Exiting..' % (split))
58 | exit(-1)
59 |
60 | for fn in fns:
61 | token = (os.path.splitext(os.path.basename(fn))[0])
62 | self.meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg'),self.cat[item], token))
63 | self.datapath = []
64 | for item in self.cat:
65 | for fn in self.meta[item]:
66 | self.datapath.append((item, fn[0], fn[1], fn[2], fn[3]))
67 |
68 | self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
69 | print(self.classes)
70 | self.num_seg_classes = 0
71 | if not self.classification:
72 | for i in range(len(self.datapath)//50):
73 | l = len(np.unique(np.loadtxt(self.datapath[i][2]).astype(np.uint8)))
74 | if l > self.num_seg_classes:
75 | self.num_seg_classes = l
76 | # print(self.num_seg_classes)
77 | self.cache = {} # from index to (point_set, cls, seg) tuple
78 | self.cache_size = 18000
79 |
80 | def __getitem__(self, index):
81 | if index in self.cache:
82 | # point_set, seg, cls= self.cache[index]
83 | point_set, seg, cls, foldername, filename = self.cache[index]
84 | else:
85 | fn = self.datapath[index]
86 | cls = self.classes[self.datapath[index][0]]
87 | # cls = np.array([cls]).astype(np.int32)
88 | point_set = np.loadtxt(fn[1]).astype(np.float32)
89 | if self.normalize:
90 | point_set = self.pc_normalize(point_set)
91 | seg = np.loadtxt(fn[2]).astype(np.int64) - 1
92 | foldername = fn[3]
93 | filename = fn[4]
94 | if len(self.cache) < self.cache_size:
95 | self.cache[index] = (point_set, seg, cls, foldername, filename)
96 |
97 | #print(point_set.shape, seg.shape)
98 | choice = np.random.choice(len(seg), self.npoints, replace=True)
99 | # resample
100 | point_set = point_set[choice, :]
101 | seg = seg[choice]
102 |
103 | # To Pytorch
104 | point_set = torch.from_numpy(point_set)
105 | seg = torch.from_numpy(seg)
106 | cls = torch.from_numpy(np.array([cls]).astype(np.int64))
107 | if self.classification:
108 | return point_set, cls
109 | else:
110 | return point_set, seg, cls
111 |
112 |
113 | def __len__(self):
114 | return len(self.datapath)
115 |
116 | def pc_normalize(self, pc):
117 | """ pc: NxC, return NxC """
118 | l = pc.shape[0]
119 | centroid = np.mean(pc, axis=0)
120 | pc = pc - centroid
121 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
122 | pc = pc / m
123 | return pc
124 |
125 |
126 | if __name__ == '__main__':
127 | print('test')
128 | d = PartDataset(root='../dataset/shapenet_part/shapenetcore_partanno_segmentation_benchmark_v0/', classification=True, class_choice='Airplane', npoints=2048, split='test')
129 | ps, cls = d[0]
130 | print(ps.size(), ps.type(), cls.size(), cls.type())
--------------------------------------------------------------------------------
/experiments/counting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Object counting experiment
5 |
6 | author: jpang
7 | created: Apr 27, 2020 (Mon) 13:55 EST
8 | '''
9 |
10 | import multiprocessing
11 | multiprocessing.set_start_method('spawn', True)
12 |
13 | import torch
14 | import torch.nn.parallel
15 | import sys
16 | import os
17 | import numpy as np
18 | from sklearn.svm import SVC
19 |
20 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21 | sys.path.append(os.path.join(BASE_DIR, '..'))
22 |
23 | from dataloaders import point_cloud_dataset_test
24 | from models.autoencoder import PointCloudAutoencoder
25 | from util.option_handler import CountingOptionHandler
26 | import time
27 |
28 | import warnings
29 | warnings.filterwarnings("ignore")
30 |
31 | def main():
32 | print(torch.cuda.device_count(), "GPUs will be used for object counting.")
33 |
34 | # Build up an autoencoder
35 | t = time.time()
36 |
37 | # Load a saved model
38 | if opt.checkpoint == '':
39 | print("Please provide the model path.")
40 | exit()
41 | else:
42 | path_checkpoints = [os.path.join(opt.checkpoint, f) for f in os.listdir(opt.checkpoint)
43 | if os.path.isfile(os.path.join(opt.checkpoint, f)) and f[-1]=='h']
44 | if len(path_checkpoints) > 1:
45 | path_checkpoints.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
46 |
47 | # Take care of the dataset
48 | test_dataset, test_dataloader = point_cloud_dataset_test(opt.dataset_name,
49 | opt.num_points, opt.batch_size, opt.count_split)
50 |
51 | # Create folder to save results
52 | if not os.path.exists(opt.exp_name):
53 | os.makedirs(opt.exp_name)
54 | device = torch.device("cuda:0")
55 |
56 | # Go through each checkpoint in the folder
57 | max_avg_mae, min_avg_mae = 0, 1e6
58 | max_min_mae, min_min_mae = 0, 1e6
59 | for cnt, str_checkpoint in enumerate(path_checkpoints):
60 |
61 | if len(path_checkpoints) > 1:
62 | epoch_cur = int(''.join(filter(str.isdigit, str_checkpoint[str_checkpoint.find('epoch_'):])))
63 | if (opt.epoch_interval[0] < 0) and (len(path_checkpoints) - cnt > opt.epoch_interval[1]): continue
64 | if (opt.epoch_interval[0] >= 0) and not (opt.epoch_interval[0] <= epoch_cur <= opt.epoch_interval[1]): continue
65 | else: epoch_cur = -1
66 | checkpoint = torch.load(str_checkpoint)
67 |
68 | # Load the model, may use model configuration from checkpoint
69 | if 'opt' in checkpoint and opt.config_from_checkpoint:
70 | checkpoint['opt'].grid_dims = opt.grid_dims
71 | ae = PointCloudAutoencoder(checkpoint['opt'])
72 | print("\nModel configuration loaded from checkpoint %s." % str_checkpoint)
73 | else:
74 | ae = PointCloudAutoencoder(opt)
75 | checkpoint['opt'] = opt
76 | torch.save(checkpoint, str_checkpoint)
77 | print("\nModel configuration written to checkpoint %s." % str_checkpoint)
78 | ae.load_state_dict(checkpoint['model_state_dict'])
79 | print("\nExisting model %s loaded." % (str_checkpoint))
80 | ae.to(device)
81 | ae.eval() # set the autoencoder to evaluation mode
82 |
83 | # Compute the codewords
84 | print('Computing codewords..')
85 | batch_cw = []
86 | batch_labels = []
87 | torch.cuda.empty_cache()
88 | len_test = len(test_dataloader)
89 | with torch.no_grad():
90 | for batch_id, data in enumerate(test_dataloader):
91 | points, labels = data
92 | if points.size(0) < opt.batch_size: break
93 | points = points.cuda()
94 | rec = ae(points) # the surrounding no_grad block already disables gradients
96 | cw = rec['cw']
97 | batch_cw.append(cw.detach().cpu().numpy())
98 | if (opt.dataset_name.lower() != 'torus_orig') and (opt.dataset_name.lower() != 'torus'):
99 | labels = torch.sum(labels>=0, dim=1)
100 | else: labels = labels + 1
101 | batch_labels.append(labels.detach().cpu().numpy())
102 | if batch_id % opt.print_freq == 0:
103 | elapse = time.time() - t
104 | print(' batch_no: %d/%d, time: %f' % (batch_id, len_test, elapse))
105 | batch_cw = np.vstack(batch_cw)
106 | batch_labels = np.vstack(batch_labels).reshape((-1,)).astype(int)
107 |
108 | # Divide the folds for experiments
109 | data_cnt_total = len(batch_labels)
110 | data_fold_idx = np.zeros(data_cnt_total, dtype=int)
111 | if (opt.dataset_name.lower() != 'torus_orig') and (opt.dataset_name.lower() != 'torus'):
112 | obj_cnt = np.zeros(test_dataset.max_obj_num, dtype=int)
113 | else: obj_cnt = np.zeros(3, dtype=int)
114 | for cur_obj in range(data_cnt_total):
115 | obj_cnt[batch_labels[cur_obj]-1] += 1
116 | data_fold_idx[cur_obj] = obj_cnt[batch_labels[cur_obj]-1]
117 | obj_cnt = (obj_cnt / opt.count_fold).astype(int)
118 | for cur_obj in range(data_cnt_total):
119 | data_fold_idx[cur_obj] = int((data_fold_idx[cur_obj]-1) / (obj_cnt[batch_labels[cur_obj]-1] + 1e-10))
120 |
121 | # Perform testing for each fold
122 | mae = np.zeros(opt.count_fold, dtype=float)
123 | for cur_fold in range(opt.count_fold):
124 | data_train = batch_cw[data_fold_idx==cur_fold, :]
125 | label_train = batch_labels[data_fold_idx==cur_fold]
126 | data_test = batch_cw[data_fold_idx!=cur_fold, :]
127 | label_test = batch_labels[data_fold_idx!=cur_fold]
128 |
129 | # Train and test the classifier
130 | classifier = SVC(gamma='scale', C=opt.svm_params[0])
131 | classifier.fit(data_train, label_train)
132 | pred = classifier.predict(data_test)
133 | mae[cur_fold] = np.mean(abs(label_test-pred))
134 | print(' fold: {} '.format(cur_fold+1) + 'mae: {:.4f} '.format(mae[cur_fold]) )
135 | print('avg_mae: {:.4f} '.format(np.mean(mae)))
136 | mae_idx = np.argmin(mae)
137 | print('min_mae: {:.4f} '.format(mae[mae_idx]) + 'min_mae_idx: {} '.format(mae_idx+1))
138 |
139 | if np.mean(mae) > max_avg_mae:
140 | max_avg_mae = np.mean(mae)
141 | max_avg_mae_epoch = epoch_cur
142 | if np.mean(mae) < min_avg_mae:
143 | min_avg_mae = np.mean(mae)
144 | min_avg_mae_epoch = epoch_cur
145 | if mae[mae_idx] > max_min_mae:
146 | max_min_mae = mae[mae_idx]
147 | max_min_mae_epoch = epoch_cur
148 | if mae[mae_idx] < min_min_mae:
149 | min_min_mae = mae[mae_idx]
150 | min_min_mae_epoch = epoch_cur
151 |
152 | print('\nmax_avg_mae: {:.4f},'.format(max_avg_mae) + ' max_avg_mae_epoch: {:d}'.format(max_avg_mae_epoch))
153 | print('min_avg_mae: {:.4f},'.format(min_avg_mae) + ' min_avg_mae_epoch: {:d}'.format(min_avg_mae_epoch))
154 | print('max_min_mae: {:.4f},'.format(max_min_mae) + ' max_min_mae_epoch: {:d}'.format(max_min_mae_epoch))
155 | print('min_min_mae: {:.4f},'.format(min_min_mae) + ' min_min_mae_epoch: {:d}'.format(min_min_mae_epoch))
156 |
157 | print('\nDone!')
158 |
159 | if __name__ == "__main__":
160 |
161 | option_handler = CountingOptionHandler()
162 | opt = option_handler.parse_options() # all options are parsed through this command
163 | option_handler.print_options(opt) # print out all the options
164 | main()
165 |
--------------------------------------------------------------------------------
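The fold division in counting.py above first counts how many test scenes carry each object count, then splits every count group evenly into count_fold folds, so each fold covers the full range of counts; an SVM trained on one fold's codewords then predicts counts for the rest. Below is a self-contained sketch of the same idea on synthetic data (the 512-dimensional codewords, their correlation with the count, and C=100 are illustrative assumptions, not values read from a checkpoint):

# Self-contained sketch of the per-count fold split and SVM counting above.
# Everything here is synthetic: 512-dim codewords, 4 folds and C=100 are
# illustrative assumptions, not values taken from the repository.
import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)
count_fold = 4
labels = rng.integers(1, 6, size=400)                # object counts in [1, 5]
cw = rng.normal(size=(400, 512)) + labels[:, None]   # codewords correlated with count

# Spread each count value evenly across the folds (round-robin per count)
fold_idx = np.zeros(len(labels), dtype=int)
seen = np.zeros(labels.max(), dtype=int)
for i, lab in enumerate(labels):
    fold_idx[i] = seen[lab - 1] % count_fold
    seen[lab - 1] += 1

mae = []
for f in range(count_fold):                          # train on one fold, test on the rest
    clf = SVC(gamma='scale', C=100.0)
    clf.fit(cw[fold_idx == f], labels[fold_idx == f])
    pred = clf.predict(cw[fold_idx != f])
    mae.append(np.mean(np.abs(labels[fold_idx != f] - pred)))
print('per-fold MAE:', ['%.3f' % m for m in mae], 'avg: %.3f' % np.mean(mae))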
/experiments/reconstruction.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Reconstruction experiment
5 | '''
6 |
7 | import multiprocessing
8 | multiprocessing.set_start_method('spawn', True)
9 |
10 | import open3d as o3d
11 | import torch
12 | import sys
13 | import os
14 | import numpy as np
15 | from PIL import Image
16 |
17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18 | sys.path.append(os.path.join(BASE_DIR, '..'))
19 |
20 | from dataloaders import point_cloud_dataset_test
21 | from models.autoencoder import PointCloudAutoencoder
22 | from util.option_handler import TestOptionHandler
23 | from util.mesh_writer import write_ply_mesh
24 |
25 | import warnings
26 | warnings.filterwarnings("ignore")
27 |
28 | def main():
29 | print(torch.cuda.device_count(), "GPUs will be used for testing.")
30 |
31 | # Load a saved model
32 | if opt.checkpoint == '':
33 | print("Please provide the model path.")
34 | exit()
35 | else:
36 | checkpoint = torch.load(opt.checkpoint)
37 | if 'opt' in checkpoint and opt.config_from_checkpoint == True:
38 | checkpoint['opt'].grid_dims = opt.grid_dims
39 | checkpoint['opt'].xyz_chamfer_weight = opt.xyz_chamfer_weight
40 | ae = PointCloudAutoencoder(checkpoint['opt'])
41 | print("\nModel configuration loaded from checkpoint %s." % opt.checkpoint)
42 | else:
43 | ae = PointCloudAutoencoder(opt)
44 | checkpoint['opt'] = opt
45 | torch.save(checkpoint, opt.checkpoint)
46 | print("\nModel configuration written to checkpoint %s." % opt.checkpoint)
47 | ae.load_state_dict(checkpoint['model_state_dict'])
48 | print("Existing model %s loaded.\n" % (opt.checkpoint))
49 | device = torch.device("cuda:0")
50 | ae.to(device)
51 | ae.eval() # set the autoencoder to evaluation mode
52 |
53 | # Create a folder to write the results
54 | if not os.path.exists(opt.exp_name):
55 | os.makedirs(opt.exp_name)
56 |
57 | # Take care of the dataset
58 | _, test_dataloader = point_cloud_dataset_test(opt.dataset_name, opt.num_points, opt.batch_size, opt.test_split)
59 | point_cnt = opt.grid_dims[0] * opt.grid_dims[1]
60 |
61 |     # Color the reconstructed points gray (0.5); ground-truth points use opt.gt_color
62 | coloring_rec = np.ones((opt.grid_dims[0] * opt.grid_dims[1], 3)) * 0.5
63 | coloring = np.concatenate((coloring_rec, np.repeat(np.array([np.array(opt.gt_color)]), opt.num_points, axis=0)), axis=0)
64 |
65 | # Begin testing
66 | test_loss_sum, test_loss_sqrt_sum = 0, 0
67 | batch_id = 0
68 | print('\nTesting...')
69 | not_end_yet = True
70 | it = iter(test_dataloader)
71 | len_test = len(test_dataloader)
72 |
73 |     # Iterate over the test batches
74 | while not_end_yet == True:
75 |
76 | # Fetch the data
77 | points, _ = next(it)
78 | points = points.cuda()
79 | not_end_yet = batch_id + 1 < len_test
80 | if(points.size(0) < opt.batch_size): break
81 |
82 | # The forward pass
83 | with torch.no_grad():
84 | rec = ae(points)
85 | grid = rec['grid'] if 'grid' in rec else None
86 | if 'graph_wght' in rec:
87 | if opt.graph_delete_point_eps > 0: # may use another eps value
88 | old_eps = ae.decoder.graph_filter.graph_eps_sqr
89 | ae.decoder.graph_filter.graph_eps_sqr = opt.graph_delete_point_eps ** 2
90 | with torch.no_grad():
91 | graph_wght = ae(points)['graph_wght']
92 | ae.decoder.graph_filter.graph_eps_sqr = old_eps
93 | else:
94 | graph_wght = rec['graph_wght']
95 | graph_wght = graph_wght.permute([0, 2, 3, 1]).detach().cpu().numpy()
96 | else: graph_wght = None
97 | rec = rec['rec']
98 |
99 | # Benchmarking the results
100 | test_loss, test_loss_sqrt = 0, 0
101 | if not(graph_wght is None) and (opt.graph_delete_point_mode >=0) and opt.graph_delete_point_eps > 0:
102 | idx_map = np.zeros((opt.batch_size, opt.grid_dims[0], opt.grid_dims[1]), dtype=int)
103 | idx = np.zeros(opt.batch_size, dtype=int)
104 |
105 | # When graph is used, isolated points are removed
106 | for i in range(opt.grid_dims[0]):
107 | for j in range(opt.grid_dims[1]):
108 | degree = np.sum(graph_wght[:, i, j, :] > opt.graph_thres, axis=1)
109 | tag = (degree <= opt.graph_delete_point_mode) # tag the point to be deleted
110 | idx_map[:, i, j][tag == True] = -1
111 | idx_map[:, i, j][tag == False] = idx[tag == False]
112 | idx[tag == False] = idx[tag == False] + 1
113 | for b in range(opt.batch_size):
114 | pc_cur = rec[b,idx_map[b].reshape(opt.grid_dims[0] * opt.grid_dims[1])>=0].unsqueeze(0)
115 | test_loss += ae.xyz_loss(points[b].unsqueeze(0), pc_cur, xyz_loss_type=2)
116 | test_loss_sqrt += ae.xyz_loss(points[b].unsqueeze(0), pc_cur, xyz_loss_type=0)
117 | test_loss /= opt.batch_size
118 | test_loss_sqrt /= opt.batch_size
119 | test_loss_sum = test_loss_sum + test_loss
120 | test_loss_sqrt_sum = test_loss_sqrt_sum + test_loss_sqrt
121 | else:
122 | test_loss = ae.xyz_loss(points, rec, xyz_loss_type=2)
123 | test_loss_sqrt = ae.xyz_loss(points, rec, xyz_loss_type=0)
124 | test_loss_sum += test_loss.item()
125 | test_loss_sqrt_sum += test_loss_sqrt.item()
126 |
127 | if batch_id % opt.print_freq == 0:
128 | print(' batch_no: %d/%d, ch_dist: %f, ch^2_dist: %f' %
129 | (batch_id, len_test, test_loss_sqrt.item(), test_loss.item()))
130 |
131 | # Write down the first point cloud
132 | if opt.pc_write_freq > 0 and batch_id % opt.pc_write_freq == 0:
133 | rec_o3d = o3d.geometry.PointCloud()
134 | rec = rec[0].data.cpu()
135 | rec_o3d.colors = o3d.utility.Vector3dVector(coloring[:rec.shape[0],:])
136 | rec_o3d.points = o3d.utility.Vector3dVector(rec)
137 | file_rec = os.path.join(opt.exp_name, str(batch_id) + "_rec.ply")
138 | o3d.io.write_point_cloud(file_rec, rec_o3d)
139 |
140 | gt_o3d = o3d.geometry.PointCloud() # write the ground-truth
141 | gt_o3d.points = o3d.utility.Vector3dVector(points[0].data.cpu())
142 | gt_o3d.colors = o3d.utility.Vector3dVector(coloring[rec.shape[0]:,:])
143 | file_gt = os.path.join(opt.exp_name, str(batch_id) + "_gt.ply")
144 | o3d.io.write_point_cloud(file_gt, gt_o3d)
145 |
146 | if not(grid is None): # write the torn grid
147 | grid = torch.cat((grid[0].contiguous().data.cpu(), torch.zeros(point_cnt, 1)), 1)
148 | grid_o3d = o3d.geometry.PointCloud()
149 | grid_o3d.points = o3d.utility.Vector3dVector(grid)
150 | grid_o3d.colors = o3d.utility.Vector3dVector(coloring_rec)
151 | file_grid = os.path.join(opt.exp_name, str(batch_id) + "_grid.ply")
152 | o3d.io.write_point_cloud(file_grid, grid_o3d)
153 |
154 | if not(graph_wght is None) and opt.write_mesh:
155 | output_path = os.path.join(opt.exp_name, str(batch_id) + "_rec_mesh.ply") # write the reconstructed mesh
156 | write_ply_mesh(rec.view(opt.grid_dims[0], opt.grid_dims[1], 3).numpy(),
157 | coloring_rec.reshape(opt.grid_dims[0], opt.grid_dims[1], 3), opt.graph_edge_color,
158 | output_path, thres=opt.graph_thres, delete_point_mode=opt.graph_delete_point_mode,
159 | weights=graph_wght[0], point_color_as_index=opt.point_color_as_index)
160 |
161 | output_path = os.path.join(opt.exp_name, str(batch_id) + "_grid_mesh.ply") # write the grid mesh
162 | write_ply_mesh(grid.view(opt.grid_dims[0], opt.grid_dims[1], 3).numpy(),
163 | coloring_rec.reshape(opt.grid_dims[0], opt.grid_dims[1], 3), opt.graph_edge_color,
164 | output_path, thres=opt.graph_thres, delete_point_mode=-1,
165 | weights=graph_wght[0], point_color_as_index=opt.point_color_as_index)
166 |
167 | batch_id = batch_id + 1
168 |
169 | # Log the test results
170 |     avg_loss = test_loss_sum / max(batch_id, 1) # batch_id equals the number of processed batches
171 |     avg_loss_sqrt = test_loss_sqrt_sum / max(batch_id, 1)
172 | print('avg_ch_dist: %f avg_ch^2_dist: %f' % (avg_loss_sqrt, avg_loss))
173 |
174 | print('\nDone!\n')
175 |
176 | if __name__ == "__main__":
177 |
178 | option_handler = TestOptionHandler()
179 | opt = option_handler.parse_options() # all options are parsed through this command
180 | option_handler.print_options(opt) # print out all the options
181 | main()
182 |
--------------------------------------------------------------------------------
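The point-deletion pass in reconstruction.py computes a graph degree per grid point and drops points whose degree falls at or below graph_delete_point_mode before measuring the Chamfer distance. A vectorized NumPy sketch of that rule for one sample follows; the weights are random placeholders rather than a consistent grid graph, and the 45x45 grid size matches the experiment scripts:

# Vectorized sketch of the isolated-point removal in the reconstruction loop.
# graph_wght holds 4 edge weights (left/right/top/bottom) per grid point,
# mirroring the permuted tensor above; the values are random placeholders.
import numpy as np

grid_h, grid_w, thres, delete_point_mode = 45, 45, 0.0, 0
graph_wght = np.random.rand(grid_h, grid_w, 4)   # 4 edge weights per grid point
graph_wght[graph_wght < 0.7] = 0.0               # sparsify so some points isolate
rec = np.random.rand(grid_h * grid_w, 3)         # reconstructed points

degree = np.sum(graph_wght > thres, axis=2)      # per-point graph degree
keep = (degree > delete_point_mode).reshape(-1)  # drop weakly connected points
rec_kept = rec[keep]
print('kept %d of %d points' % (rec_kept.shape[0], rec.shape[0]))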
/experiments/train_basic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Basic training script
5 | '''
6 |
7 | import multiprocessing
8 | multiprocessing.set_start_method('spawn', True)
9 |
10 | import torch
11 | import torch.nn.parallel
12 | import torch.optim as optim
13 | import sys
14 | import os
15 | import numpy as np
16 |
17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18 | sys.path.append(os.path.join(BASE_DIR, '..'))
19 |
20 | from dataloaders import point_cloud_dataset_train, gen_rotation_matrix
21 | from models.autoencoder import PointCloudAutoencoder
22 | from util.option_handler import TrainOptionHandler
23 | from tensorboardX import SummaryWriter
24 | import time
25 |
26 | import warnings
27 | warnings.filterwarnings("ignore")
28 |
29 | def main():
30 | # Build up an autoencoder
31 | t = time.time()
32 | print(torch.cuda.device_count(), "GPUs will be used for training.")
33 | device = torch.device("cuda:0")
34 | ae = PointCloudAutoencoder(opt)
35 | ae = torch.nn.DataParallel(ae)
36 | ae.to(device)
37 |
38 | # Create folder to save trained models
39 | if not os.path.exists(opt.exp_name):
40 | os.makedirs(opt.exp_name)
41 |
42 | # Take care of the dataset
43 | _, train_dataloader, _, _ = point_cloud_dataset_train(opt.dataset_name, opt.num_points, opt.batch_size, opt.train_split)
44 |
45 | # Create a tensorboard writer
46 | if opt.tf_summary: writer = SummaryWriter(comment=str.replace(opt.exp_name,'/','_'))
47 |
48 | # Setup the optimizer and the scheduler
49 | lr = opt.lr
50 | optimizer = optim.Adam(ae.parameters(), lr=lr[0], betas=(opt.optim_args[0], opt.optim_args[1]), weight_decay=opt.optim_args[2])
51 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr[1], gamma=lr[2], last_epoch=-1)
52 |
53 | # Load a check point if given
54 | if opt.checkpoint != '':
55 | checkpoint = torch.load(opt.checkpoint)
56 | ae.module.load_state_dict(checkpoint['model_state_dict'])
57 | if opt.load_weight_only == False:
58 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
59 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
60 | print("Existing model %s loaded.\n" % (opt.checkpoint))
61 | epoch_start = scheduler.state_dict()['last_epoch']
62 |
63 | # Start training
64 | n_iter = 0
65 | for epoch in range(epoch_start, opt.n_epoch):
66 | ae.train()
67 | print('Training at epoch %d with lr %f' % (epoch, scheduler.get_lr()[0]))
68 |
69 | train_loss_sum = 0
70 | batch_id = 0
71 | not_end_yet = True
72 | it = iter(train_dataloader)
73 | len_train = len(train_dataloader)
74 |
75 |         # Iterate over the training batches
76 | while not_end_yet == True:
77 | points, _ = next(it)
78 | not_end_yet = batch_id + 1 < len_train
79 | if(points.size(0) < opt.batch_size):
80 | break
81 |
82 | # Data augmentation
83 | points = points.cuda()
84 | if opt.augmentation == True:
85 | if opt.augmentation_theta != 0 or opt.augmentation_rotation_axis != 0: # rotation augmentation
86 | rotation_matrix = gen_rotation_matrix(theta=opt.augmentation_theta,
87 | rotation_axis=opt.augmentation_rotation_axis).unsqueeze(0).cuda().expand(opt.batch_size, -1 , -1)
88 | points = torch.bmm(points, rotation_matrix)
89 | if opt.augmentation_max_scale - opt.augmentation_min_scale > 1e-3: # scaling augmentation
90 | noise_scale = np.random.uniform(opt.augmentation_min_scale, opt.augmentation_max_scale)
91 | points *= noise_scale
92 |                 if (opt.augmentation_flip_axis >= 0) and (np.random.rand() < 0.5): # flipping augmentation
93 | points[:,:,opt.augmentation_flip_axis] *= -1
94 | optimizer.zero_grad()
95 |
96 | # Forward and backward computation
97 | rec = ae(points)
98 | rec = rec["rec"]
99 | train_loss = ae.module.xyz_loss(points, rec)
100 | train_loss.backward()
101 | optimizer.step()
102 |
103 | # Log the result
104 | train_loss_sum += train_loss.item()
105 | if batch_id % opt.print_freq == 0:
106 | elapse = time.time() - t
107 | print(' batch_no: %d/%d, time/iter: %f/%d, train_loss: %f' %
108 | (batch_id, len_train, elapse, n_iter, train_loss.item()))
109 | if opt.tf_summary: writer.add_scalar('train/batch_train_loss', train_loss.item(), n_iter)
110 |
111 | if opt.pc_write_freq > 0 and batch_id % opt.pc_write_freq == 0:
112 | labels = np.concatenate((np.ones(rec.shape[1]), np.zeros(points.shape[1])), axis=0).tolist()
113 | if opt.tf_summary:
114 | writer.add_embedding(torch.cat((rec[0, 0:, 0:3], points[0, :, 0:3]), 0),
115 | global_step=n_iter, metadata=labels, tag="pc")
116 |
117 | n_iter = n_iter + 1
118 | batch_id = batch_id + 1
119 |
120 | # Output the average loss of current epoch
121 |         avg_loss = train_loss_sum / max(batch_id, 1) # batch_id equals the number of processed batches
122 | elapse = time.time() - t
123 | print('Epoch: %d time: %f --- avg_loss: %f lr: %f' % (epoch, elapse, avg_loss, scheduler.get_lr()[0]))
124 | if opt.tf_summary: writer.add_scalar('train/epoch_train_loss', avg_loss, epoch)
125 | if opt.tf_summary: writer.add_scalar('train/learning_rate', scheduler.get_lr()[0], epoch)
126 | scheduler.step() # scheduler go one step
127 |
128 | # Save the checkpoint
129 | if epoch % opt.save_epoch_freq == 0 or epoch == opt.n_epoch - 1:
130 | dict_name=opt.exp_name + '/epoch_'+str(epoch)+'.pth'
131 | torch.save({
132 | 'model_state_dict': ae.module.state_dict(),
133 | 'optimizer_state_dict': optimizer.state_dict(),
134 | 'scheduler_state_dict': scheduler.state_dict(),
135 | 'last_epoch': scheduler.state_dict()['last_epoch'],
136 | 'opt': opt
137 | }, dict_name)
138 | print('Current checkpoint saved to %s.' % (dict_name))
139 | print('\n')
140 |
141 | if opt.tf_summary: writer.close()
142 | print('Done!')
143 |
144 |
145 | if __name__ == "__main__":
146 |
147 | option_handler = TrainOptionHandler()
148 | opt = option_handler.parse_options() # all options are parsed through this command
149 | option_handler.print_options(opt) # print out all the options
150 | main()
151 |
--------------------------------------------------------------------------------
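The augmentation block in train_basic.py applies up to three transforms per batch: a rotation about one axis, an isotropic scaling, and a random axis flip. Below is a standalone sketch of the same pipeline; the rotation construction is an illustrative stand-in for the repository's gen_rotation_matrix helper in dataloaders, whose exact conventions are not shown in this listing:

# Standalone sketch of the rotation / scaling / flip augmentations above.
# The rotation builder is an illustrative stand-in for gen_rotation_matrix;
# the angle range, axis choice and scale bounds are assumptions.
import numpy as np
import torch

def random_rotation_y(batch, theta_max=np.pi):
    # Random rotation of every cloud in the batch about the y-axis
    theta = np.random.uniform(-theta_max, theta_max)
    c, s = np.cos(theta), np.sin(theta)
    rot = torch.tensor([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]], dtype=batch.dtype)
    return torch.bmm(batch, rot.unsqueeze(0).expand(batch.shape[0], -1, -1))

points = torch.rand(8, 2048, 3)                # batch_size x num_points x 3
points = random_rotation_y(points)             # rotation about one axis
points = points * np.random.uniform(0.9, 1.1)  # isotropic scaling
if np.random.rand() < 0.5:
    points[:, :, 1] *= -1                      # flip along one axis
print(points.shape)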
/experiments/train_tearing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Training script for TearingNet
5 | '''
6 |
7 | import multiprocessing
8 | multiprocessing.set_start_method('spawn', True)
9 |
10 | import torch
11 | import torch.nn.parallel
12 | import torch.optim as optim
13 | import sys
14 | import os
15 | import numpy as np
16 |
17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18 | sys.path.append(os.path.join(BASE_DIR, '..'))
19 |
20 | from dataloaders import point_cloud_dataset_train, gen_rotation_matrix
21 | from models.autoencoder import PointCloudAutoencoder
22 | from models.tearingnet import TearingNetBasicModel
23 | from models.tearingnet_graph import TearingNetGraphModel
24 | from util.option_handler import TrainOptionHandler
25 | from tensorboardX import SummaryWriter
26 | import time
27 |
28 | import warnings
29 | warnings.filterwarnings("ignore")
30 |
31 | def main():
32 | # Build up an autoencoder
33 | t = time.time()
34 | print(torch.cuda.device_count(), "GPUs will be used for training.")
35 | device = torch.device("cuda:0")
36 | ae = PointCloudAutoencoder(opt)
37 | ae = torch.nn.DataParallel(ae)
38 | ae.to(device)
39 | tearingnet_basic = isinstance(ae.module.decoder, TearingNetBasicModel)
40 | tearingnet_graph = isinstance(ae.module.decoder, TearingNetGraphModel)
41 |
42 | # Create folder to save trained models
43 | if not os.path.exists(opt.exp_name):
44 | os.makedirs(opt.exp_name)
45 |
46 | # Take care of the dataset
47 | _, train_dataloader, _, _ = point_cloud_dataset_train(dataset_name=opt.dataset_name, \
48 | num_points=opt.num_points, batch_size=opt.batch_size, train_split=opt.train_split)
49 |
50 | # Create a tensorboard writer
51 | if opt.tf_summary: writer = SummaryWriter(comment=str.replace(opt.exp_name,'/','_'))
52 |
53 | # Setup the optimizer and the scheduler
54 | lr = opt.lr
55 | optimizer = optim.Adam(ae.parameters(), lr=lr[0])
56 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr[1], gamma=lr[2], last_epoch=-1)
57 |
58 | # Load a check point if given
59 | if opt.checkpoint != '':
60 | checkpoint = torch.load(opt.checkpoint)
61 | if opt.load_weight_only == False:
62 | ae.module.load_state_dict(checkpoint['model_state_dict'])
63 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
64 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
65 | print("Existing model %s fully loaded.\n" % (opt.checkpoint))
66 | else:
67 | model_dict = ae.module.state_dict() # load parameters from pre-trained FoldingNet
68 | for k in checkpoint['model_state_dict']:
69 | if k in model_dict:
70 | model_dict[k] = checkpoint['model_state_dict'][k]
71 | print(" Found weight: " + k)
72 | elif k.replace('folding1', 'folding') in model_dict:
73 | model_dict[k.replace('folding1', 'folding')] = checkpoint['model_state_dict'][k]
74 | print(" Found weight: " + k)
75 | ae.module.load_state_dict(model_dict)
76 | print("Existing model %s (partly) loaded.\n" % (opt.checkpoint))
77 | epoch_start = scheduler.state_dict()['last_epoch']
78 |
79 | # Start training
80 | n_iter = 0
81 | for epoch in range(epoch_start, opt.n_epoch):
82 | ae.train()
83 | print('Training at epoch %d with lr %f' % (epoch, scheduler.get_lr()[0]))
84 |
85 | train_loss_xyz_sum = 0
86 | if tearingnet_basic:
87 | train_loss_xyz_pre_sum = 0
88 | if tearingnet_graph:
89 | train_loss_xyz_pre_sum = 0
90 | train_loss_xyz_pre2_sum = 0
91 | train_loss_sum = 0
92 | batch_id = 0
93 | not_end_yet = True
94 | it = iter(train_dataloader)
95 | len_train = len(train_dataloader)
96 |
97 |         # Iterate over the training batches
98 | while not_end_yet == True:
99 | points, _ = next(it)
100 | not_end_yet = batch_id + 1 < len_train
101 | if(points.size(0) < opt.batch_size):
102 | break
103 |
104 |             # Data augmentation
105 | points = points.cuda()
106 | if opt.augmentation == True:
107 | if opt.augmentation_theta != 0 or opt.augmentation_rotation_axis != 0: # rotation augmentation
108 | rotation_matrix = gen_rotation_matrix(theta=opt.augmentation_theta,
109 | rotation_axis=opt.augmentation_rotation_axis).unsqueeze(0).cuda().expand(opt.batch_size, -1 , -1)
110 | points = torch.bmm(points, rotation_matrix)
111 | if opt.augmentation_max_scale - opt.augmentation_min_scale > 1e-3: # scaling augmentation
112 | noise_scale = np.random.uniform(opt.augmentation_min_scale, opt.augmentation_max_scale)
113 | points *= noise_scale
114 |                 if (opt.augmentation_flip_axis >= 0) and (np.random.rand() < 0.5): # flipping augmentation
115 | points[:,:,opt.augmentation_flip_axis] *= -1
116 | optimizer.zero_grad()
117 |
118 | # Forward and backward
119 | rec = ae(points)
120 | grid = rec['grid']
121 | if tearingnet_basic:
122 | rec_pre = rec['rec_pre']
123 | if tearingnet_graph:
124 | rec_pre = rec['rec_pre']
125 | rec_pre2 = rec['rec_pre2']
126 | rec = rec['rec']
127 |
128 | # Forward and backward computation
129 | train_loss_xyz = ae.module.xyz_loss(points, rec)
130 | if tearingnet_basic:
131 | train_loss_xyz_pre = ae.module.xyz_loss(points, rec_pre)
132 | if tearingnet_graph:
133 | train_loss_xyz_pre = ae.module.xyz_loss(points, rec_pre)
134 | train_loss_xyz_pre2 = ae.module.xyz_loss(points, rec_pre2)
135 | train_loss = train_loss_xyz # may add other loss if needed
136 | train_loss.backward()
137 | optimizer.step()
138 |
139 | # Log the result
140 | train_loss_xyz_sum += train_loss_xyz.item()
141 | if tearingnet_basic: train_loss_xyz_pre_sum += train_loss_xyz_pre.item()
142 | if tearingnet_graph:
143 | train_loss_xyz_pre_sum += train_loss_xyz_pre.item()
144 | train_loss_xyz_pre2_sum += train_loss_xyz_pre2.item()
145 | train_loss_sum += train_loss.item()
146 | if batch_id % opt.print_freq == 0:
147 | elapse = time.time() - t
148 | if tearingnet_basic:
149 | print(' batch_no: %d/%d, time/iter: %f/%d, loss_xyz: %f, loss_xyz_pre: %f, loss: %f' %
150 | (batch_id, len_train, elapse, n_iter, train_loss_xyz.item(), train_loss_xyz_pre.item(), train_loss.item()))
151 | elif tearingnet_graph:
152 | print(' batch_no: %d/%d, time/iter: %f/%d, loss_xyz: %f, loss_xyz_pre: %f, loss_xyz_pre2: %f, loss: %f' %
153 | (batch_id, len_train, elapse, n_iter, train_loss_xyz.item(), train_loss_xyz_pre.item(), train_loss_xyz_pre2.item(), train_loss.item()))
154 | else:
155 | print(' batch_no: %d/%d, time/iter: %f/%d, loss_xyz: %f, loss: %f' %
156 | (batch_id, len_train, elapse, n_iter, train_loss_xyz.item(), train_loss.item()))
157 | if opt.tf_summary: writer.add_scalar('train/batch_train_loss', train_loss.item(), n_iter)
158 |
159 | if opt.pc_write_freq > 0 and batch_id % opt.pc_write_freq == 0:
160 | labels = np.concatenate((np.ones(rec.shape[1]), np.zeros(points.shape[1])), axis=0).tolist()
161 | if opt.tf_summary:
162 | writer.add_embedding(torch.cat((rec[0, 0:, 0:3], points[0, :, 0:3]), 0),
163 | global_step=n_iter, metadata=labels, tag="pc") # the output point cloud
164 | writer.add_embedding(grid[0, 0:, 0:2], global_step=n_iter, tag="grid") # the 2D grid being used
165 | n_iter = n_iter + 1
166 | batch_id = batch_id + 1
167 |
168 | # Output the average loss of current epoch
169 |         avg_loss_xyz = train_loss_xyz_sum / max(batch_id, 1) # batch_id equals the number of processed batches
170 |         if tearingnet_basic:
171 |             avg_loss_xyz_pre = train_loss_xyz_pre_sum / max(batch_id, 1)
172 |         if tearingnet_graph:
173 |             avg_loss_xyz_pre = train_loss_xyz_pre_sum / max(batch_id, 1)
174 |             avg_loss_xyz_pre2 = train_loss_xyz_pre2_sum / max(batch_id, 1)
175 |         avg_loss = train_loss_sum / max(batch_id, 1)
176 |
177 | elapse = time.time() - t
178 | if tearingnet_basic:
179 | print('Epoch: %d time: %f --- avg_loss_xyz: %f, avg_loss_xyz_pre: %f, avg_loss: %f lr: %f' % \
180 | (epoch, elapse, avg_loss_xyz, avg_loss_xyz_pre, avg_loss, scheduler.get_lr()[0]))
181 | elif tearingnet_graph:
182 | print('Epoch: %d time: %f --- avg_loss_xyz: %f, avg_loss_xyz_pre: %f, avg_loss_xyz_pre2: %f, avg_loss: %f lr: %f' % \
183 | (epoch, elapse, avg_loss_xyz, avg_loss_xyz_pre, avg_loss_xyz_pre2, avg_loss, scheduler.get_lr()[0]))
184 | else:
185 | print('Epoch: %d time: %f --- avg_loss_xyz: %f, avg_loss: %f lr: %f' % \
186 | (epoch, elapse, avg_loss_xyz, avg_loss, scheduler.get_lr()[0]))
187 | if opt.tf_summary: writer.add_scalar('train/epoch_train_loss', avg_loss, epoch)
188 | if opt.tf_summary: writer.add_scalar('train/learning_rate', scheduler.get_lr()[0], epoch)
189 | scheduler.step()
190 |
191 | # Save the checkpoint
192 | if epoch % opt.save_epoch_freq == 0 or epoch == opt.n_epoch - 1:
193 | dict_name=opt.exp_name + '/epoch_'+str(epoch)+'.pth'
194 | torch.save({
195 | 'model_state_dict': ae.module.state_dict(),
196 | 'optimizer_state_dict': optimizer.state_dict(),
197 | 'scheduler_state_dict': scheduler.state_dict(),
198 | 'last_epoch': scheduler.state_dict()['last_epoch'],
199 | 'opt': opt
200 | }, dict_name)
201 | print('Current checkpoint saved to %s.' % (dict_name))
202 | print('\n')
203 |
204 | if opt.tf_summary: writer.close()
205 | print('Done!')
206 |
207 |
208 | if __name__ == "__main__":
209 |
210 | option_handler = TrainOptionHandler()
211 | opt = option_handler.parse_options() # all options are parsed through this command
212 | option_handler.print_options(opt) # print out all the options
213 | main()
214 |
--------------------------------------------------------------------------------
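When load_weight_only is set, train_tearing.py seeds the TearingNet from a pretrained FoldingNet checkpoint by copying every parameter whose name matches, remapping folding1.* names to folding.* (the FoldingNet decoder calls its module folding1, the TearingNet decoders call theirs folding). A minimal sketch of that remap on placeholder modules, not the real networks:

# Minimal sketch of the partial checkpoint loading above: copy each weight
# whose name matches, remapping 'folding1' -> 'folding'. The two tiny
# ParameterDicts stand in for the pretrained and the new model.
import torch
import torch.nn as nn

src = nn.ParameterDict({'folding1_w': nn.Parameter(torch.ones(2))})  # "pretrained"
dst = nn.ParameterDict({'folding_w': nn.Parameter(torch.zeros(2))})  # "new model"

dst_state = dst.state_dict()
for k, v in src.state_dict().items():
    if k in dst_state:
        dst_state[k] = v
    elif k.replace('folding1', 'folding') in dst_state:
        dst_state[k.replace('folding1', 'folding')] = v
dst.load_state_dict(dst_state)
print(dst['folding_w'])  # now ones, copied from the source checkpoint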
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Initialization and simple utilities
5 | '''
6 |
7 | import importlib
8 |
9 | def get_model_class(model_name):
10 |
11 | # enumerate the model files being used here
12 |     model_file_list = ["pointnet", "autoencoder", "foldingnet", "tearingnet", "tearingnet_graph"]  # matches models/tearingnet_graph.py
13 |
14 | model_class = None
15 | for filename in model_file_list:
16 |
17 | # Retrieve the model class
18 | modellib = importlib.import_module("models." + filename)
19 | class_name = model_name + "Model"
20 | for name, cls in modellib.__dict__.items():
21 | if name.lower() == class_name.lower():
22 | model_class = cls
23 | break
24 | if model_class is not None:
25 | break
26 |
27 | if model_class is None:
28 |         print("The specified model [{}] was not found.".format(model_name))
29 |         exit(1)
30 |
31 | return model_class
--------------------------------------------------------------------------------
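get_model_class resolves an encoder/decoder option string case-insensitively by appending "Model" and scanning the listed module files in order. A quick usage check from the repository root (looking up names defined in later files first imports models.autoencoder, which requires the compiled nndistance extension):

# Resolve an encoder option string to its class; 'pointnetvanilla' is found in
# models/pointnet.py, the first file scanned, so no heavy imports are triggered.
from models import get_model_class

encoder_class = get_model_class('pointnetvanilla')
print(encoder_class.__name__)  # PointNetVanillaModel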
/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Point cloud auto-encoder backbone
5 | '''
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.utils.data
11 | import numpy as np
12 | import sys
13 | import os
14 |
15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16 | sys.path.append(os.path.join(BASE_DIR, '../util/nndistance'))
17 | from modules.nnd import NNDModule
18 | nn_match = NNDModule()
19 | USE_CUDA = True
20 |
21 | from . import get_model_class
22 | from .tearingnet import TearingNetBasicModel
23 | from .tearingnet_graph import TearingNetGraphModel
24 |
25 |
26 | class PointCloudAutoencoder(nn.Module):
27 |
28 | def __init__(self, opt):
29 | super(PointCloudAutoencoder, self).__init__()
30 |
31 | encoder_class = get_model_class(opt.encoder)
32 | self.encoder = encoder_class(opt)
33 | decoder_class = get_model_class(opt.decoder)
34 | self.decoder = decoder_class(opt)
35 | self.is_train = (opt.phase.lower() == 'train')
36 | if self.is_train:
37 | self.xyz_loss_type = opt.xyz_loss_type
38 | self.xyz_chamfer_weight = opt.xyz_chamfer_weight
39 |
40 | def forward(self, data):
41 |
42 | cw = self.encoder(data)
43 | if isinstance(self.decoder, TearingNetBasicModel):
44 | rec0, rec1, grid = self.decoder(cw)
45 | return {"rec": rec1, "rec_pre": rec0, "grid": grid, "cw": cw}
46 | elif isinstance(self.decoder, TearingNetGraphModel):
47 | rec0, rec1, rec2, grid, graph_wght = self.decoder(cw)
48 | return {"rec": rec2, "rec_pre": rec1, "rec_pre2": rec0, "grid": grid, "graph_wght": graph_wght, "cw": cw}
49 | else:
50 | rec = self.decoder(cw)
51 | return {"rec": rec, "cw": cw}
52 |
53 | def xyz_loss(self, data, rec, xyz_loss_type=-1):
54 |
55 | if xyz_loss_type == -1:
56 | xyz_loss_type = self.xyz_loss_type
57 | dist1, dist2 = nn_match(data.contiguous(), rec.contiguous())
58 | dist2 = dist2 * self.xyz_chamfer_weight
59 |
60 | # Different variants of the Chamfer distance
61 | if xyz_loss_type == 0: # augmented Chamfer distance
62 | loss = torch.max(torch.mean(torch.sqrt(dist1), 1), torch.mean(torch.sqrt(dist2), 1))
63 | loss = torch.mean(loss)
64 | elif xyz_loss_type == 1:
65 | loss = torch.mean(torch.sqrt(dist1), 1) + torch.mean(torch.sqrt(dist2), 1)
66 | loss = torch.mean(loss)
67 |         elif xyz_loss_type == 2: # squared Chamfer distance, as used in prior work
68 | loss = torch.mean(dist1) + torch.mean(dist2)
69 | return loss
70 |
71 |
72 | if __name__ == '__main__':
73 | USE_CUDA = True
--------------------------------------------------------------------------------
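nn_match comes from a compiled CUDA nearest-neighbour extension under util/nndistance (not shown in this listing). For small point clouds the same squared nearest-neighbour distances can be reproduced on CPU with torch.cdist, which makes the loss variants above easy to verify without building the extension; this is a reference sketch, not the extension itself:

# CPU reference sketch of the Chamfer variants above, using torch.cdist in
# place of the compiled nndistance extension; only practical for small clouds.
# The xyz_chamfer_weight scaling of dist2 is omitted for brevity.
import torch

def chamfer_dists(a, b):
    d = torch.cdist(a, b) ** 2         # batch x Na x Nb squared distances
    dist1 = d.min(dim=2).values        # for each point in a, nearest in b
    dist2 = d.min(dim=1).values        # for each point in b, nearest in a
    return dist1, dist2

a, b = torch.rand(4, 256, 3), torch.rand(4, 256, 3)
dist1, dist2 = chamfer_dists(a, b)
aug = torch.mean(torch.max(torch.mean(torch.sqrt(dist1), 1),
                           torch.mean(torch.sqrt(dist2), 1)))  # loss type 0
plain = torch.mean(dist1) + torch.mean(dist2)                  # loss type 2
print(aug.item(), plain.item())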
/models/foldingnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Basic FoldingNet model
5 | '''
6 |
7 | import torch
8 | import torch.nn as nn
9 | from .pointnet import PointwiseMLP
10 |
11 |
12 | class FoldingNetVanilla(nn.Module):
13 |
14 | def __init__(self, folding1_dims, folding2_dims):
15 | super(FoldingNetVanilla, self).__init__()
16 |
17 | # The folding decoder
18 | self.fold1 = PointwiseMLP(folding1_dims, doLastRelu=False)
19 | if folding2_dims[0] > 0:
20 | self.fold2 = PointwiseMLP(folding2_dims, doLastRelu=False)
21 | else: self.fold2 = None
22 |
23 | def forward(self, cw, grid, **kwargs):
24 |
25 | cw_exp = cw.unsqueeze(1).expand(-1, grid.shape[1], -1) # batch_size X point_num X code_length
26 |
27 | # 1st folding
28 |         in1 = torch.cat((grid, cw_exp), 2) # batch_size X point_num X (code_length + 2)
29 | out1 = self.fold1(in1) # batch_size X point_num X 3
30 |
31 | # 2nd folding
32 | if not(self.fold2 is None):
33 |             in2 = torch.cat((out1, cw_exp), 2) # batch_size X point_num X (code_length + 3)
34 | out2 = self.fold2(in2) # batch_size X point_num X 3
35 | return out2
36 | else: return out1
37 |
38 |
39 | class FoldingNetVanillaModel(nn.Module):
40 |
41 | @staticmethod
42 | def add_options(parser, isTrain = True):
43 |
44 |         # Some options
45 | parser.add_argument('--grid_dims', type=int, nargs='+', help='Grid dimensions.')
46 | parser.add_argument('--folding1_dims', type=int, nargs='+', default=[514, 512, 512, 3], help='Dimensions of the first folding module.')
47 | parser.add_argument('--folding2_dims', type=int, nargs='+', default=[515, 512, 512, 3], help='Dimensions of the second folding module.')
48 | return parser
49 |
50 | def __init__(self, opt):
51 | super(FoldingNetVanillaModel, self).__init__()
52 |
53 | # Initialize the 2D grid
54 | range_x = torch.linspace(-1.0, 1.0, opt.grid_dims[0])
55 | range_y = torch.linspace(-1.0, 1.0, opt.grid_dims[1])
56 | x_coor, y_coor = torch.meshgrid(range_x, range_y)
57 | self.grid = torch.stack([x_coor, y_coor], axis=-1).float().reshape(-1, 2)
58 |
59 | # Initialize the folding module
60 | self.folding1 = FoldingNetVanilla(opt.folding1_dims, opt.folding2_dims)
61 |
62 | def forward(self, cw):
63 |
64 | grid = self.grid.cuda().unsqueeze(0).expand(cw.shape[0], -1, -1) # batch_size X point_num X 2
65 | pc = self.folding1(cw, grid)
66 | return pc
--------------------------------------------------------------------------------
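The default folding dimensions encode a 512-dimensional codeword: the first fold consumes the 2D grid plus the codeword (512 + 2 = 514 input channels) and the second consumes the intermediate 3D points plus the codeword (512 + 3 = 515). A quick shape check, assuming the 45x45 grid from the experiment scripts:

# Shape check for the folding decoder (run from the repo root); a 512-dim
# codeword is assumed, giving the default widths 514 (512+2) and 515 (512+3).
import torch
from models.foldingnet import FoldingNetVanilla

fold = FoldingNetVanilla([514, 512, 512, 3], [515, 512, 512, 3])
cw = torch.rand(2, 512)           # batch of codewords
grid = torch.rand(2, 45 * 45, 2)  # batch of 2D grid points
print(fold(cw, grid).shape)       # torch.Size([2, 2025, 3])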
/models/pointnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Basic PointNet model
5 | '''
6 |
7 | import torch.nn as nn
8 | from util.option_handler import str2bool
9 |
10 |
11 | def get_and_init_FC_layer(din, dout):
12 | li = nn.Linear(din, dout)
13 | #init weights/bias
14 | nn.init.xavier_uniform_(li.weight.data, gain=nn.init.calculate_gain('relu'))
15 | li.bias.data.fill_(0.)
16 | return li
17 |
18 |
19 | def get_MLP_layers(dims, doLastRelu):
20 | layers = []
21 | for i in range(1, len(dims)):
22 | layers.append(get_and_init_FC_layer(dims[i-1], dims[i]))
23 | if i==len(dims)-1 and not doLastRelu:
24 | continue
25 | layers.append(nn.ReLU())
26 | return layers
27 |
28 |
29 | class PointwiseMLP(nn.Sequential):
30 |     '''N x din -> N x d1 -> N x d2 -> ... -> N x dout'''
31 |
32 | def __init__(self, dims, doLastRelu=False):
33 | layers = get_MLP_layers(dims, doLastRelu)
34 | super(PointwiseMLP, self).__init__(*layers)
35 |
36 |
37 | class GlobalPool(nn.Module):
38 | '''BxNxK -> BxK'''
39 |
40 | def __init__(self, pool_layer):
41 | super(GlobalPool, self).__init__()
42 | self.Pool = pool_layer
43 |
44 | def forward(self, X):
45 | X = X.unsqueeze(-3)
46 | X = self.Pool(X)
47 | X = X.squeeze(-2)
48 | X = X.squeeze(-2)
49 | return X
50 |
51 | class PointNetGlobalMax(nn.Sequential):
52 | '''BxNxdims[0] -> Bxdims[-1]'''
53 |
54 | def __init__(self, dims, doLastRelu=False):
55 | layers = [
56 | PointwiseMLP(dims, doLastRelu=doLastRelu),
57 | GlobalPool(nn.AdaptiveMaxPool2d((1, dims[-1]))),
58 | ]
59 | super(PointNetGlobalMax, self).__init__(*layers)
60 |
61 |
62 | class PointNetVanilla(nn.Sequential):
63 |
64 | def __init__(self, MLP_dims, FC_dims, MLP_doLastRelu):
65 | assert(MLP_dims[-1]==FC_dims[0])
66 | layers = [
67 | PointNetGlobalMax(MLP_dims, doLastRelu=MLP_doLastRelu),
68 | ]
69 | layers.extend(get_MLP_layers(FC_dims, False))
70 | super(PointNetVanilla, self).__init__(*layers)
71 |
72 |
73 | class PointNetVanillaModel(nn.Module):
74 |
75 | @staticmethod
76 | def add_options(parser, isTrain = True):
77 | parser.add_argument('--pointnet_mlp_dims', type=int, nargs='+', default=[3, 64, 128, 128, 1024], help='Dimensions of the MLP in the PointNet encoder.')
78 | parser.add_argument('--pointnet_fc_dims', type=int, nargs='+', default=[1024, 512, 512, 512], help='Dimensions of the FC in the PointNet encoder.')
79 | parser.add_argument("--pointnet_mlp_dolastrelu", type=str2bool, nargs='?', const=True, default=False, help='Apply the last ReLU or not in the PointNet encoder.')
80 | return parser
81 |
82 | def __init__(self, opt):
83 | super(PointNetVanillaModel, self).__init__()
84 | self.pointnet = PointNetVanilla(opt.pointnet_mlp_dims, opt.pointnet_fc_dims, opt.pointnet_mlp_dolastrelu)
85 |
86 | def forward(self, data):
87 | return self.pointnet(data)
88 |
--------------------------------------------------------------------------------
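A corresponding shape check for the encoder, using the MLP/FC dimensions from the training scripts (3 -> ... -> 1024 followed by 1024 -> 512 -> 512), which yields one 512-dimensional codeword per cloud:

# Shape check for the PointNet encoder with the dims used by the train scripts.
import torch
from models.pointnet import PointNetVanilla

enc = PointNetVanilla([3, 64, 64, 64, 128, 1024], [1024, 512, 512], False)
points = torch.rand(2, 2048, 3)  # batch_size x num_points x 3
print(enc(points).shape)         # torch.Size([2, 512])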
/models/tearingnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Basic TearingNet model
5 | '''
6 |
7 | import torch
8 | import torch.nn as nn
9 | from .foldingnet import FoldingNetVanilla
10 |
11 |
12 | def get_Conv2d_layer(dims, kernel_size, doLastRelu):
13 | layers = []
14 | for i in range(1, len(dims)):
15 | if kernel_size != 1:
16 | layers.append(nn.ReplicationPad2d(int((kernel_size - 1) / 2)))
17 | layers.append(nn.Conv2d(in_channels=dims[i-1], out_channels=dims[i],
18 | kernel_size=kernel_size, stride=1, padding=0, bias=True))
19 | if i==len(dims)-1 and not doLastRelu:
20 | continue
21 | layers.append(nn.ReLU(inplace=True))
22 | return layers
23 |
24 |
25 | class Conv2dLayers(nn.Sequential):
26 | def __init__(self, dims, kernel_size, doLastRelu=False):
27 | layers = get_Conv2d_layer(dims, kernel_size, doLastRelu)
28 | super(Conv2dLayers, self).__init__(*layers)
29 |
30 |
31 | class TearingNetBasic(nn.Module):
32 |
33 | def __init__(self, tearing1_dims, tearing2_dims, grid_dims, kernel_size=1):
34 | super(TearingNetBasic, self).__init__()
35 |
36 | self.grid_dims = grid_dims
37 | self.tearing1 = Conv2dLayers(tearing1_dims, kernel_size=kernel_size, doLastRelu=False)
38 | self.tearing2 = Conv2dLayers(tearing2_dims, kernel_size=kernel_size, doLastRelu=False)
39 |
40 | def forward(self, cw, grid, pc, **kwargs):
41 |
42 | grid_exp = grid.contiguous().view(grid.shape[0], self.grid_dims[0], self.grid_dims[1], 2) # batch_size X dim0 X dim1 X 2
43 | pc_exp = pc.contiguous().view(pc.shape[0], self.grid_dims[0], self.grid_dims[1], 3) # batch_size X dim0 X dim1 X 3
44 | cw_exp = cw.unsqueeze(1).unsqueeze(1).expand(-1, self.grid_dims[0], self.grid_dims[1], -1) # batch_size X dim0 X dim1 X code_length
45 | in1 = torch.cat((grid_exp, pc_exp, cw_exp), 3).permute([0, 3, 1, 2])
46 |
47 | # Compute the torn 2D grid
48 | out1 = self.tearing1(in1) # 1st tearing
49 | in2 = torch.cat((in1, out1), 1)
50 | out2 = self.tearing2(in2) # 2nd tearing
51 | out2 = out2.permute([0, 2, 3, 1]).contiguous().view(grid.shape[0], self.grid_dims[0] * self.grid_dims[1], 2)
52 | return grid + out2
53 |
54 |
55 | class TearingNetBasicModel(nn.Module):
56 |
57 | @staticmethod
58 | def add_options(parser, isTrain = True):
59 |
60 |         # General option(s)
61 | parser.add_argument('--grid_dims', type=int, nargs='+', help='Grid dimensions.')
62 |
63 | # Options related to the Folding Network
64 | parser.add_argument('--folding1_dims', type=int, nargs='+', default=[514, 512, 512, 3], help='Dimensions of the first folding module.')
65 | parser.add_argument('--folding2_dims', type=int, nargs='+', default=[515, 512, 512, 3], help='Dimensions of the second folding module.')
66 |
67 | # Options related to the Tearing Network
68 | parser.add_argument('--tearing1_dims', type=int, nargs='+', default=[523, 256, 128, 64], help='Dimensions of the first tearing module.')
69 | parser.add_argument('--tearing2_dims', type=int, nargs='+', default=[587, 256, 128, 2], help='Dimensions of the second tearing module.')
70 | parser.add_argument('--tearing_conv_kernel_size', type=int, default=1, help='Kernel size of the convolutional layers in the Tearing Network, 1 implies MLP.')
71 |
72 | return parser
73 |
74 | def __init__(self, opt):
75 | super(TearingNetBasicModel, self).__init__()
76 |
77 | # Initialize the regular 2D grid
78 | range_x = torch.linspace(-1.0, 1.0, opt.grid_dims[0])
79 | range_y = torch.linspace(-1.0, 1.0, opt.grid_dims[1])
80 | x_coor, y_coor = torch.meshgrid(range_x, range_y)
81 | self.grid = torch.stack([x_coor, y_coor], axis=-1).float().reshape(-1, 2)
82 |
83 | # Initialize the Folding Network and the Tearing Network
84 | self.folding = FoldingNetVanilla(opt.folding1_dims, opt.folding2_dims)
85 | self.tearing = TearingNetBasic(opt.tearing1_dims, opt.tearing2_dims, opt.grid_dims, opt.tearing_conv_kernel_size)
86 |
87 | def forward(self, cw):
88 |
89 | grid0 = self.grid.cuda().unsqueeze(0).expand(cw.shape[0], -1, -1) # batch_size X point_num X 2
90 | pc0 = self.folding(cw, grid0) # Folding Network
91 | grid1 = self.tearing(cw, grid0, pc0) # Tearing Network
92 | pc1 = self.folding(cw, grid1) # Folding Network
93 |
94 | return pc0, pc1, grid1
--------------------------------------------------------------------------------
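The tearing module's input widths follow the same arithmetic: the first stage sees the grid (2) plus the folded points (3) plus the codeword, i.e. 517 channels for the 512-dim codeword configured in the training scripts, and the second stage additionally sees the first stage's 64 features, hence 581. A shape check under those assumptions:

# Shape check for the basic tearing module; a 512-dim codeword is assumed,
# matching the training scripts' TEARING1_DIMS / TEARING2_DIMS (517, 581).
import torch
from models.tearingnet import TearingNetBasic

tear = TearingNetBasic([517, 512, 512, 64], [581, 512, 512, 2], [45, 45])
cw = torch.rand(2, 512)
grid = torch.rand(2, 45 * 45, 2)
pc = torch.rand(2, 45 * 45, 3)
print(tear(cw, grid, pc).shape)  # torch.Size([2, 2025, 2]) -- the torn grid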
/models/tearingnet_graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | TearingNet with graph filtering
5 | '''
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from .foldingnet import FoldingNetVanilla
11 | from .tearingnet import TearingNetBasic
12 |
13 |
14 | class GraphFilter(nn.Module):
15 |
16 | def __init__(self, grid_dims, graph_r, graph_eps, graph_lam):
17 | super(GraphFilter, self).__init__()
18 | self.grid_dims = grid_dims
19 | self.graph_r = graph_r
20 | self.graph_eps_sqr = graph_eps * graph_eps
21 | self.graph_lam = graph_lam
22 |
23 | def forward(self, grid, pc):
24 |
25 | # Data preparation
26 | bs_cur = pc.shape[0]
27 | grid_exp = grid.contiguous().view(bs_cur, self.grid_dims[0], self.grid_dims[1], 2) # batch_size X dim0 X dim1 X 2
28 | pc_exp = pc.contiguous().view(bs_cur, self.grid_dims[0], self.grid_dims[1], 3) # batch_size X dim0 X dim1 X 3
29 | graph_feature = torch.cat((grid_exp, pc_exp), dim=3).permute([0, 3, 1, 2])
30 |
31 | # Compute the graph weights
32 | wght_hori = graph_feature[:,:,:-1,:] - graph_feature[:,:,1:,:] # horizontal weights
33 | wght_vert = graph_feature[:,:,:,:-1] - graph_feature[:,:,:,1:] # vertical weights
34 | wght_hori = torch.exp(-torch.sum(wght_hori * wght_hori, dim=1) / self.graph_eps_sqr) # Gaussian weight
35 | wght_vert = torch.exp(-torch.sum(wght_vert * wght_vert, dim=1) / self.graph_eps_sqr)
36 | wght_hori = (wght_hori > self.graph_r) * wght_hori
37 | wght_vert = (wght_vert > self.graph_r) * wght_vert
38 | wght_lft = torch.cat((torch.zeros([bs_cur, 1, self.grid_dims[1]]).cuda(), wght_hori), 1) # add left
39 | wght_rgh = torch.cat((wght_hori, torch.zeros([bs_cur, 1, self.grid_dims[1]]).cuda()), 1) # add right
40 | wght_top = torch.cat((torch.zeros([bs_cur, self.grid_dims[0], 1]).cuda(), wght_vert), 2) # add top
41 | wght_bot = torch.cat((wght_vert, torch.zeros([bs_cur, self.grid_dims[0], 1]).cuda()), 2) # add bottom
42 | wght_all = torch.cat((wght_lft.unsqueeze(1), wght_rgh.unsqueeze(1), wght_top.unsqueeze(1), wght_bot.unsqueeze(1)), 1)
43 |
44 |         # Perform the actual graph filtering: x = (I - \lambda L) * x
45 | wght_hori = wght_hori.unsqueeze(1).expand(-1, 3, -1, -1) # dimension expansion
46 | wght_vert = wght_vert.unsqueeze(1).expand(-1, 3, -1, -1)
47 | pc = pc.permute([0, 2, 1]).contiguous().view(bs_cur, 3, self.grid_dims[0], self.grid_dims[1])
48 | pc_filt = \
49 | torch.cat((torch.zeros([bs_cur, 3, 1, self.grid_dims[1]]).cuda(), pc[:,:,:-1,:] * wght_hori), 2) + \
50 | torch.cat((pc[:,:,1:,:] * wght_hori, torch.zeros([bs_cur, 3, 1, self.grid_dims[1]]).cuda()), 2) + \
51 | torch.cat((torch.zeros([bs_cur, 3, self.grid_dims[0], 1]).cuda(), pc[:,:,:,:-1] * wght_vert), 3) + \
52 | torch.cat((pc[:,:,:,1:] * wght_vert, torch.zeros([bs_cur, 3, self.grid_dims[0], 1]).cuda()), 3) # left, right, top, bottom
53 |
54 |         pc_filt = pc + self.graph_lam * (pc_filt - torch.sum(wght_all,dim=1).unsqueeze(1).expand(-1, 3, -1, -1) * pc) # equivalent to (I - \lambda L) * x
55 | pc_filt = pc_filt.view(bs_cur, 3, -1).permute([0, 2, 1])
56 | return pc_filt, wght_all
57 |
58 |
59 | class TearingNetGraphModel(nn.Module):
60 |
61 | @staticmethod
62 | def add_options(parser, isTrain = True):
63 |
64 |         # General option(s)
65 | parser.add_argument('--grid_dims', type=int, nargs='+', help='Grid dimensions.')
66 |
67 | # Options related to the Folding Network
68 | parser.add_argument('--folding1_dims', type=int, nargs='+', default=[514, 512, 512, 3], help='Dimensions of the first folding module.')
69 | parser.add_argument('--folding2_dims', type=int, nargs='+', default=[515, 512, 512, 3], help='Dimensions of the second folding module.')
70 |
71 | # Options related to the Tearing Network
72 | parser.add_argument('--tearing1_dims', type=int, nargs='+', default=[523, 256, 128, 64], help='Dimensions of the first tearing module.')
73 | parser.add_argument('--tearing2_dims', type=int, nargs='+', default=[587, 256, 128, 2], help='Dimensions of the second tearing module.')
74 | parser.add_argument('--tearing_conv_kernel_size', type=int, default=1, help='Kernel size of the convolutional layers in the Tearing Network, 1 implies MLP.')
75 |
76 | # Options related to graph construction
77 | parser.add_argument('--graph_r', type=float, default=1e-12, help='Parameter r for the r-neighborhood graph.')
78 | parser.add_argument('--graph_eps', type=float, default=0.02, help='Parameter epsilon for the graph (bandwidth parameter).')
79 | parser.add_argument('--graph_lam', type=float, default=0.5, help='Parameter lambda for the graph filter.')
80 |
81 | return parser
82 |
83 | def __init__(self, opt):
84 | super(TearingNetGraphModel, self).__init__()
85 |
86 | # Initialize the regular 2D grid
87 | range_x = torch.linspace(-1.0, 1.0, opt.grid_dims[0])
88 | range_y = torch.linspace(-1.0, 1.0, opt.grid_dims[1])
89 | x_coor, y_coor = torch.meshgrid(range_x, range_y)
90 | self.grid = torch.stack([x_coor, y_coor], axis=-1).float().reshape(-1, 2)
91 |
92 | # Initialize the Folding Network and the Tearing Network
93 | self.folding = FoldingNetVanilla(opt.folding1_dims, opt.folding2_dims)
94 | self.tearing = TearingNetBasic(opt.tearing1_dims, opt.tearing2_dims, opt.grid_dims, opt.tearing_conv_kernel_size)
95 | self.graph_filter = GraphFilter(opt.grid_dims, opt.graph_r, opt.graph_eps, opt.graph_lam)
96 |
97 | def forward(self, cw):
98 |
99 | grid0 = self.grid.cuda().unsqueeze(0).expand(cw.shape[0], -1, -1) # batch_size X point_num X 2
100 | pc0 = self.folding(cw, grid0) # Folding Network
101 | grid1 = self.tearing(cw, grid0, pc0) # Tearing Network
102 | pc1 = self.folding(cw, grid1) # Folding Network
103 | pc2, graph_wght = self.graph_filter(grid1, pc1) # Graph Filtering
104 | return pc0, pc1, pc2, grid1, graph_wght
105 |
--------------------------------------------------------------------------------
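GraphFilter applies one step of x <- (I - lambda L) x on the grid graph, where the edge weights W are thresholded Gaussians of the joint (grid, point) distances and L = D - W; since torn regions get zero weights, the smoothing stops at tears. The update is easy to verify on a 1-D chain (a NumPy sketch of the equation, not the module itself):

# NumPy sketch of one graph-filtering step x <- (I - lam*L)x on a 1-D chain,
# the same update GraphFilter applies on the 2-D grid graph.
import numpy as np

x = np.array([0.0, 1.0, 10.0])      # signal: the last point is an outlier
w = np.array([1.0, 0.0])            # edge 0-1 kept, edge 1-2 cut (a "tear")
W = np.diag(w, 1) + np.diag(w, -1)  # adjacency of the chain
L = np.diag(W.sum(axis=1)) - W      # graph Laplacian D - W
lam = 0.5
print(x - lam * (L @ x))            # [0.5, 0.5, 10.0]: smoothing stops at the cut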
/scripts/experiments/counting.sh:
--------------------------------------------------------------------------------
1 | # Counting experiment
2 |
3 | EXP="count_tearing_kitti"
4 | PY_NAME="${HOME_DIR}/experiments/counting.py"
5 | EXP_NAME="results/${EXP}"
6 | CHECKPOINT="${HOME_DIR}/results/train_tearing_kitti/"
7 | PHASE="counting"
8 | BATCH_SIZE="8"
9 | COUNT_FOLD="4"
10 | EPOCH_INTERVAL="-1 10"
11 | CONFIG_FROM_CHECKPOINT="True"
12 | DATASET_NAME="kitti_mulobj"
13 | COUNT_SPLIT="test_5x5"
14 | PRINT_FREQ="50"
15 | GRID_DIMS="45 45"
16 | SVM_PARAMS="100"
17 |
18 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --checkpoint ${CHECKPOINT} --phase ${PHASE} --batch_size ${BATCH_SIZE} --count_fold ${COUNT_FOLD} --epoch_interval ${EPOCH_INTERVAL} --config_from_checkpoint ${CONFIG_FROM_CHECKPOINT} --dataset_name ${DATASET_NAME} --count_split ${COUNT_SPLIT} --print_freq ${PRINT_FREQ} --grid_dims ${GRID_DIMS} --svm_params ${SVM_PARAMS}"
--------------------------------------------------------------------------------
/scripts/experiments/reconstruction.sh:
--------------------------------------------------------------------------------
1 | # Reconstruction experiment
2 |
3 | EXP="rec_tearing_kitti"
4 | PY_NAME="${HOME_DIR}/experiments/reconstruction.py"
5 | EXP_NAME="results/${EXP}"
6 | CHECKPOINT="${HOME_DIR}/results/train_tearing_kitti/epoch_639.pth"
7 | PHASE="test"
8 | CONFIG_FROM_CHECKPOINT="True"
9 | AUGMENTATION="False"
10 | DATASET_NAME="kitti_mulobj"
11 | TEST_SPLIT="test_5x5"
12 | BATCH_SIZE="32"
13 | GRID_DIMS="45 45"
14 | PRINT_FREQ="5"
15 | PC_WRITE_FREQ="-1"
16 | GT_COLOR="0.2 0.2 0.2"
17 | GRAPH_THRES="0"
18 | GRAPH_EDGE_COLOR="0.6 0.6 0.6"
19 | WRITE_MESH="True"
20 | GRAPH_DELETE_POINT_MODE="0"
21 | GRAPH_DELETE_POINT_EPS="0.08"
22 |
23 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --checkpoint ${CHECKPOINT} --phase ${PHASE} --config_from_checkpoint ${CONFIG_FROM_CHECKPOINT} --augmentation ${AUGMENTATION} --dataset_name ${DATASET_NAME} --test_split ${TEST_SPLIT} --batch_size ${BATCH_SIZE} --grid_dims ${GRID_DIMS} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --gt_color ${GT_COLOR} --graph_thres ${GRAPH_THRES} --graph_edge_color ${GRAPH_EDGE_COLOR} --write_mesh ${WRITE_MESH} --graph_delete_point_mode ${GRAPH_DELETE_POINT_MODE} --graph_delete_point_eps ${GRAPH_DELETE_POINT_EPS}"
--------------------------------------------------------------------------------
/scripts/experiments/train_folding_cad.sh:
--------------------------------------------------------------------------------
1 | # FoldingNet pretraining
2 |
3 | EXP="train_folding_cad"
4 | PY_NAME="${HOME_DIR}/experiments/train_basic.py"
5 | EXP_NAME="results/${EXP}"
6 | PHASE="train"
7 | N_EPOCH="640"
8 | SAVE_EPOCH_FREQ="10"
9 | PRINT_FREQ="50"
10 | PC_WRITE_FREQ="1000"
11 | BATCH_SIZE="32"
12 | XYZ_LOSS_TYPE="0"
13 | XYZ_CHAMFER_WEIGHT="0.01"
14 | LR_POLICY="step"
15 | LR="0.0002 80 0.5"
16 | DATASET_NAME="cad_mulobj"
17 | TRAIN_SPLIT="train_5x5"
18 | AUGMENTATION="True"
19 | AUGMENTATION_ROTATION_AXIS="1"
20 | GRID_DIMS="45 45"
21 | ENCODER="pointnetvanilla"
22 | DECODER="foldingnetvanilla"
23 | POINTNET_MLP_DIMS="3 64 64 64 128 1024"
24 | POINTNET_FC_DIMS="1024 512 512"
25 | POINTNET_MLP_DOLASTRELU="False"
26 | FOLDING1_DIMS="514 512 512 3"
27 | FOLDING2_DIMS="515 512 512 3"
28 |
29 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --n_epoch ${N_EPOCH} --save_epoch_freq ${SAVE_EPOCH_FREQ} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --batch_size ${BATCH_SIZE} --xyz_loss_type ${XYZ_LOSS_TYPE} --xyz_chamfer_weight ${XYZ_CHAMFER_WEIGHT} --lr_policy ${LR_POLICY} --lr ${LR} --dataset_name ${DATASET_NAME} --train_split ${TRAIN_SPLIT} --augmentation ${AUGMENTATION} --augmentation_rotation_axis ${AUGMENTATION_ROTATION_AXIS} --grid_dims ${GRID_DIMS} --encoder ${ENCODER} --decoder ${DECODER} --pointnet_mlp_dims ${POINTNET_MLP_DIMS} --pointnet_fc_dims ${POINTNET_FC_DIMS} --pointnet_mlp_dolastrelu ${POINTNET_MLP_DOLASTRELU} --folding1_dims ${FOLDING1_DIMS} --folding2_dims ${FOLDING2_DIMS}"
--------------------------------------------------------------------------------
/scripts/experiments/train_folding_kitti.sh:
--------------------------------------------------------------------------------
1 | # FoldingNet pretraining
2 |
3 | EXP="train_folding_kitti"
4 | PY_NAME="${HOME_DIR}/experiments/train_basic.py"
5 | EXP_NAME="results/${EXP}"
6 | PHASE="train"
7 | N_EPOCH="640"
8 | SAVE_EPOCH_FREQ="10"
9 | PRINT_FREQ="50"
10 | PC_WRITE_FREQ="1000"
11 | BATCH_SIZE="32"
12 | XYZ_LOSS_TYPE="0"
13 | XYZ_CHAMFER_WEIGHT="0.01"
14 | LR_POLICY="step"
15 | LR="0.0002 80 0.5"
16 | DATASET_NAME="kitti_mulobj"
17 | TRAIN_SPLIT="train_5x5"
18 | AUGMENTATION="True"
19 | AUGMENTATION_THETA="0"
20 | AUGMENTATION_ROTATION_AXIS="0"
21 | AUGMENTATION_FLIP_AXIS="1"
22 | GRID_DIMS="45 45"
23 | ENCODER="pointnetvanilla"
24 | DECODER="foldingnetvanilla"
25 | POINTNET_MLP_DIMS="3 64 64 64 128 1024"
26 | POINTNET_FC_DIMS="1024 512 512"
27 | POINTNET_MLP_DOLASTRELU="False"
28 | FOLDING1_DIMS="514 512 512 3"
29 | FOLDING2_DIMS="515 512 512 3"
30 |
31 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --n_epoch ${N_EPOCH} --save_epoch_freq ${SAVE_EPOCH_FREQ} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --batch_size ${BATCH_SIZE} --xyz_loss_type ${XYZ_LOSS_TYPE} --xyz_chamfer_weight ${XYZ_CHAMFER_WEIGHT} --lr_policy ${LR_POLICY} --lr ${LR} --dataset_name ${DATASET_NAME} --train_split ${TRAIN_SPLIT} --augmentation ${AUGMENTATION} --augmentation_theta ${AUGMENTATION_THETA} --augmentation_rotation_axis ${AUGMENTATION_ROTATION_AXIS} --augmentation_flip_axis ${AUGMENTATION_FLIP_AXIS} --grid_dims ${GRID_DIMS} --encoder ${ENCODER} --decoder ${DECODER} --pointnet_mlp_dims ${POINTNET_MLP_DIMS} --pointnet_fc_dims ${POINTNET_FC_DIMS} --pointnet_mlp_dolastrelu ${POINTNET_MLP_DOLASTRELU} --folding1_dims ${FOLDING1_DIMS} --folding2_dims ${FOLDING2_DIMS}"
--------------------------------------------------------------------------------
/scripts/experiments/train_tearing_cad.sh:
--------------------------------------------------------------------------------
1 | # Train the TearingNet
2 |
3 | EXP="train_tearing_cad"
4 | PY_NAME="${HOME_DIR}/experiments/train_tearing.py"
5 | EXP_NAME="results/${EXP}"
6 | CHECKPOINT="${HOME_DIR}/results/train_folding_cad/epoch_639.pth"
7 | LOAD_WEIGHT_ONLY="True"
8 | PHASE="train"
9 | N_EPOCH="480"
10 | SAVE_EPOCH_FREQ="10"
11 | PRINT_FREQ="50"
12 | PC_WRITE_FREQ="1000"
13 | BATCH_SIZE="32"
14 | XYZ_LOSS_TYPE="0"
15 | LR_POLICY="step"
16 | LR="0.000001 80 0.5"
17 | DATASET_NAME="cad_mulobj"
18 | TRAIN_SPLIT="train_5x5"
19 | AUGMENTATION="True"
20 | AUGMENTATION_ROTATION_AXIS="1"
21 | GRID_DIMS="45 45"
22 | ENCODER="pointnetvanilla"
23 | DECODER="tearingnetgraph"
24 | POINTNET_MLP_DIMS="3 64 64 64 128 1024"
25 | POINTNET_FC_DIMS="1024 512 512"
26 | POINTNET_MLP_DOLASTRELU="False"
27 | FOLDING1_DIMS="514 512 512 3"
28 | FOLDING2_DIMS="515 512 512 3"
29 | TEARING1_DIMS="517 512 512 64"
30 | TEARING2_DIMS="581 512 512 2"
31 | GRAPH_R="1e-12"
32 | GRAPH_EPS="0.02"
33 | GRAPH_LAM="0.5"
34 |
35 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --checkpoint ${CHECKPOINT} --load_weight_only ${LOAD_WEIGHT_ONLY} --phase ${PHASE} --n_epoch ${N_EPOCH} --save_epoch_freq ${SAVE_EPOCH_FREQ} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --batch_size ${BATCH_SIZE} --xyz_loss_type ${XYZ_LOSS_TYPE} --lr_policy ${LR_POLICY} --lr ${LR} --dataset_name ${DATASET_NAME} --train_split ${TRAIN_SPLIT} --augmentation ${AUGMENTATION} --augmentation_rotation_axis ${AUGMENTATION_ROTATION_AXIS} --grid_dims ${GRID_DIMS} --encoder ${ENCODER} --decoder ${DECODER} --pointnet_mlp_dims ${POINTNET_MLP_DIMS} --pointnet_fc_dims ${POINTNET_FC_DIMS} --pointnet_mlp_dolastrelu ${POINTNET_MLP_DOLASTRELU} --folding1_dims ${FOLDING1_DIMS} --folding2_dims ${FOLDING2_DIMS} --tearing1_dims ${TEARING1_DIMS} --tearing2_dims ${TEARING2_DIMS} --graph_r ${GRAPH_R} --graph_eps ${GRAPH_EPS} --graph_lam ${GRAPH_LAM}"
--------------------------------------------------------------------------------
/scripts/experiments/train_tearing_kitti.sh:
--------------------------------------------------------------------------------
1 | # Train the TearingNet
2 |
3 | EXP="train_tearing_kitti"
4 | PY_NAME="${HOME_DIR}/experiments/train_tearing.py"
5 | EXP_NAME="results/${EXP}"
6 | CHECKPOINT="${HOME_DIR}/results/train_folding_kitti/epoch_639.pth"
7 | LOAD_WEIGHT_ONLY="True"
8 | PHASE="train"
9 | N_EPOCH="480"
10 | SAVE_EPOCH_FREQ="10"
11 | PRINT_FREQ="50"
12 | PC_WRITE_FREQ="1000"
13 | BATCH_SIZE="32"
14 | XYZ_LOSS_TYPE="0"
15 | LR_POLICY="step"
16 | LR="0.000001 80 0.5"
17 | DATASET_NAME="kitti_mulobj"
18 | TRAIN_SPLIT="train_5x5"
19 | AUGMENTATION="True"
20 | AUGMENTATION_THETA="0"
21 | AUGMENTATION_ROTATION_AXIS="0"
22 | AUGMENTATION_FLIP_AXIS="1"
23 | GRID_DIMS="45 45"
24 | ENCODER="pointnetvanilla"
25 | DECODER="tearingnetgraph"
26 | POINTNET_MLP_DIMS="3 64 64 64 128 1024"
27 | POINTNET_FC_DIMS="1024 512 512"
28 | POINTNET_MLP_DOLASTRELU="False"
29 | FOLDING1_DIMS="514 512 512 3"
30 | FOLDING2_DIMS="515 512 512 3"
31 | TEARING1_DIMS="517 512 512 64"
32 | TEARING2_DIMS="581 512 512 2"
33 | GRAPH_R="1e-12"
34 | GRAPH_EPS="0.02"
35 | GRAPH_LAM="0.5"
36 |
37 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --checkpoint ${CHECKPOINT} --load_weight_only ${LOAD_WEIGHT_ONLY} --phase ${PHASE} --n_epoch ${N_EPOCH} --save_epoch_freq ${SAVE_EPOCH_FREQ} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --batch_size ${BATCH_SIZE} --xyz_loss_type ${XYZ_LOSS_TYPE} --lr_policy ${LR_POLICY} --lr ${LR} --dataset_name ${DATASET_NAME} --train_split ${TRAIN_SPLIT} --augmentation ${AUGMENTATION} --augmentation_theta ${AUGMENTATION_THETA} --augmentation_rotation_axis ${AUGMENTATION_ROTATION_AXIS} --augmentation_flip_axis ${AUGMENTATION_FLIP_AXIS} --grid_dims ${GRID_DIMS} --encoder ${ENCODER} --decoder ${DECODER} --pointnet_mlp_dims ${POINTNET_MLP_DIMS} --pointnet_fc_dims ${POINTNET_FC_DIMS} --pointnet_mlp_dolastrelu ${POINTNET_MLP_DOLASTRELU} --folding1_dims ${FOLDING1_DIMS} --folding2_dims ${FOLDING2_DIMS} --tearing1_dims ${TEARING1_DIMS} --tearing2_dims ${TEARING2_DIMS} --graph_r ${GRAPH_R} --graph_eps ${GRAPH_EPS} --graph_lam ${GRAPH_LAM}"
--------------------------------------------------------------------------------
/scripts/gen_data/gen_cad_mulobj_test_5x5.sh:
--------------------------------------------------------------------------------
1 | # Generate CAD model multiple-object dataset
2 |
3 | EXP_NAME="results/gen_cad_mulobj_test_5x5"
4 | PY_NAME="${HOME_DIR}/dataloaders/cadmulobj_loader.py"
5 | PHASE="gen_cadmultiobj"
6 | NUM_POINTS="2048"
7 | AUGMENTATION="True"
8 | CAD_MULOBJ_NUM_ADD_MODEL="3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22"
9 | CAD_MULOBJ_NUM_EXAMPLE="1 4 16 53 143 322 609 974 1328 1550 1550 1328 974 609 322 143 53 16 4 1"
10 | CAD_MULOBJ_NUM_AVA_MODEL="1133"
11 | CAD_MULOBJ_SCENE_RADIUS="5"
12 | CAD_MULOBJ_OUTPUT_FILE_NAME="${HOME_DIR}/dataset/cadmulobj/cad_mulobj_param_test_5x5"
13 |
14 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --num_points ${NUM_POINTS} --augmentation ${AUGMENTATION} --cad_mulobj_num_add_model ${CAD_MULOBJ_NUM_ADD_MODEL} --cad_mulobj_num_example ${CAD_MULOBJ_NUM_EXAMPLE} --cad_mulobj_num_ava_model ${CAD_MULOBJ_NUM_AVA_MODEL} --cad_mulobj_scene_radius ${CAD_MULOBJ_SCENE_RADIUS} --cad_mulobj_output_file_name ${CAD_MULOBJ_OUTPUT_FILE_NAME}"
--------------------------------------------------------------------------------
/scripts/gen_data/gen_cad_mulobj_train_5x5.sh:
--------------------------------------------------------------------------------
1 | # Generate CAD model multiple-object dataset
2 |
3 | EXP_NAME="results/gen_cad_mulobj_train_5x5"
4 | PY_NAME="${HOME_DIR}/dataloaders/cadmulobj_loader.py"
5 | PHASE="gen_cadmultiobj"
6 | NUM_POINTS="2048"
7 | AUGMENTATION="True"
8 | CAD_MULOBJ_NUM_ADD_MODEL="3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22"
9 | CAD_MULOBJ_NUM_EXAMPLE="3 19 79 264 716 1612 3044 4871 6642 7749 7749 6642 4871 3044 1612 716 264 79 19 3"
10 | CAD_MULOBJ_NUM_AVA_MODEL="1133"
11 | CAD_MULOBJ_SCENE_RADIUS="5"
12 | CAD_MULOBJ_OUTPUT_FILE_NAME="${HOME_DIR}/dataset/cadmulobj/cad_mulobj_param_train_5x5"
13 |
14 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --num_points ${NUM_POINTS} --augmentation ${AUGMENTATION} --cad_mulobj_num_add_model ${CAD_MULOBJ_NUM_ADD_MODEL} --cad_mulobj_num_example ${CAD_MULOBJ_NUM_EXAMPLE} --cad_mulobj_num_ava_model ${CAD_MULOBJ_NUM_AVA_MODEL} --cad_mulobj_scene_radius ${CAD_MULOBJ_SCENE_RADIUS} --cad_mulobj_output_file_name ${CAD_MULOBJ_OUTPUT_FILE_NAME}"
--------------------------------------------------------------------------------
/scripts/gen_data/gen_kitti_mulobj_test_5x5.sh:
--------------------------------------------------------------------------------
1 | # Generate KITTI multiple-object dataset
2 |
3 | EXP_NAME="results/gen_kitti_mulobj_test_5x5"
4 | PY_NAME="${HOME_DIR}/dataloaders/kittimulobj_loader.py"
5 | PHASE="gen_kittimulobj"
6 | NUM_POINTS="2048"
7 | KITTI_MULOBJ_NUM_ADD_MODEL="3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22"
8 | KITTI_MULOBJ_NUM_EXAMPLE="1 4 16 53 143 322 609 974 1328 1550 1550 1328 974 609 322 143 53 16 4 1"
9 | KITTI_MULOBJ_SCENE_RADIUS="5"
10 | KITTI_MULOBJ_OUTPUT_FILE_NAME="${HOME_DIR}/dataset/kittimulobj/kitti_mulobj_param_test_5x5_2048"
11 |
12 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --num_points ${NUM_POINTS} --kitti_mulobj_num_add_model ${KITTI_MULOBJ_NUM_ADD_MODEL} --kitti_mulobj_num_example ${KITTI_MULOBJ_NUM_EXAMPLE} --kitti_mulobj_scene_radius ${KITTI_MULOBJ_SCENE_RADIUS} --kitti_mulobj_output_file_name ${KITTI_MULOBJ_OUTPUT_FILE_NAME}"
--------------------------------------------------------------------------------
/scripts/gen_data/gen_kitti_mulobj_train_5x5.sh:
--------------------------------------------------------------------------------
1 | # Generate KITTI multiple-object dataset
2 |
3 | EXP_NAME="results/gen_kitti_mulobj_train_5x5"
4 | PY_NAME="${HOME_DIR}/dataloaders/kittimulobj_loader.py"
5 | PHASE="gen_kittimulobj"
6 | NUM_POINTS="2048"
7 | KITTI_MULOBJ_NUM_ADD_MODEL="3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22"
8 | KITTI_MULOBJ_NUM_EXAMPLE="3 19 79 264 716 1612 3044 4871 6642 7749 7749 6642 4871 3044 1612 716 264 79 19 3"
9 | KITTI_MULOBJ_SCENE_RADIUS="5"
10 | KITTI_MULOBJ_OUTPUT_FILE_NAME="${HOME_DIR}/dataset/kittimulobj/kitti_mulobj_param_train_5x5_2048"
11 |
12 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --num_points ${NUM_POINTS} --kitti_mulobj_num_add_model ${KITTI_MULOBJ_NUM_ADD_MODEL} --kitti_mulobj_num_example ${KITTI_MULOBJ_NUM_EXAMPLE} --kitti_mulobj_scene_radius ${KITTI_MULOBJ_SCENE_RADIUS} --kitti_mulobj_output_file_name ${KITTI_MULOBJ_OUTPUT_FILE_NAME}"
--------------------------------------------------------------------------------
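A note on the four gen_data scripts above: the NUM_ADD_MODEL and NUM_EXAMPLE lists pair positionally, so each entry of NUM_EXAMPLE is the number of scenes to synthesize with the corresponding object count (3 to 22 objects per radius-5 scene). The counts are symmetric around 12-13 objects and total 10,000 test scenes and roughly 50,000 training scenes. A quick sanity check of that pairing (an illustrative sketch, not a repository file):

# Sanity-check the scene-count lists from the gen_data scripts (illustrative only).
num_add_model = list(range(3, 23))  # 3..22 objects per scene
test_counts = [1, 4, 16, 53, 143, 322, 609, 974, 1328, 1550,
               1550, 1328, 974, 609, 322, 143, 53, 16, 4, 1]
train_counts = [3, 19, 79, 264, 716, 1612, 3044, 4871, 6642, 7749,
                7749, 6642, 4871, 3044, 1612, 716, 264, 79, 19, 3]
assert len(num_add_model) == len(test_counts) == len(train_counts) == 20
print(sum(test_counts))   # 10000 test scenes
print(sum(train_counts))  # 49998 training scenes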
/scripts/launch.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | USE_GPU="0" # specify a GPU
4 | export CUDA_VISIBLE_DEVICES=${USE_GPU}
5 | echo export CUDA_VISIBLE_DEVICES=${USE_GPU}
6 |
7 | HOME_DIR="$(pwd)"
8 | source "$1" # load the experiment parameters
9 |
10 | mkdir -p "${EXP_NAME}" # create the experiment folder
11 | LOG_NAME="log.txt"
12 | echo python ${RUN_ARGUMENTS}
13 | python ${RUN_ARGUMENTS} | tee "${EXP_NAME}/${LOG_NAME}" & # run in the background; RUN_ARGUMENTS relies on word splitting
14 |
--------------------------------------------------------------------------------
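Usage note: launch.sh takes a parameter script as its only argument and should be run from the repository root, since HOME_DIR is set to $(pwd) and the parameter scripts build all paths from it, for example: bash scripts/launch.sh scripts/experiments/train_folding_cad.sh. The trailing ampersand backgrounds the run while tee copies the console output to ${EXP_NAME}/log.txt; RUN_ARGUMENTS is deliberately left unquoted so the shell splits it into individual command-line arguments for python.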
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InterDigitalInc/TearingNet/1e43caa266a11e9d50c2d912064bd39d369bb120/util/__init__.py
--------------------------------------------------------------------------------
/util/cad_models_collector.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | A tool to collect the CAD models of "person", "car", "cone", "plant" from ModelNet40, and "motorbike" from ShapeNetPart
5 | '''
6 |
7 | import os
8 | import numpy as np
9 |
10 | import sys
11 | sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/..')
12 | from dataloaders.modelnet_loader import ModelNet40, class_extractor
13 | from dataloaders.shapenet_part_loader import PartDataset
14 |
15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16 | data_path = os.path.abspath(os.path.join(BASE_DIR, '../dataset/cadmulobj')) + '/'
17 |
18 |
19 | def main():
20 |
21 | label_person = 24
22 | label_car = 7
23 | label_cone = 9
24 | label_plant = 26
25 | from_shapenet_part = 'Motorbike'
26 |
27 | loader = ModelNet40(phase='train')
28 | list_person = class_extractor(label_person, loader)
29 | list_car = class_extractor(label_car, loader)
30 | list_cone = class_extractor(label_cone, loader)
31 | list_plant = class_extractor(label_plant, loader)
32 | loader = ModelNet40(phase='test')
33 | list_person = np.concatenate((list_person, class_extractor(label_person, loader)), axis=0)
34 | list_car = np.concatenate((list_car, class_extractor(label_car, loader)), axis=0)
35 | list_cone = np.concatenate((list_cone, class_extractor(label_cone, loader)), axis=0)
36 | list_plant = np.concatenate((list_plant, class_extractor(label_plant, loader)), axis=0)
37 |
38 |     list_motorbike = []
39 | loader = PartDataset(npoints=2048, classification=False,
40 | class_choice=from_shapenet_part, split='trainval', normalize=True)
41 | for i in range(len(loader)):
42 | list_motorbike.append(loader[i][0].numpy())
43 | loader = PartDataset(npoints=2048, classification=False,
44 | class_choice=from_shapenet_part, split='test', normalize=True)
45 | for i in range(len(loader)):
46 | list_motorbike.append(loader[i][0].numpy())
47 |     list_motorbike = np.vstack(list_motorbike).reshape(-1, 2048, 3)
48 |
49 | interest_obj = {"person": list_person, "car":list_car, "cone": list_cone, "plant": list_plant, "motorbike": list_motorbike}
50 | np.save(os.path.join(data_path, "cad_models"), interest_obj)
51 |
52 |
53 | if __name__ == '__main__':
54 | main()
--------------------------------------------------------------------------------
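Because np.save above serializes a Python dict, reading the file back requires allow_pickle=True plus .item() to unwrap the zero-dimensional object array. A minimal sketch of consuming the collected models (the path follows data_path above; shapes other than the motorbike set depend on the ModelNet40 loader):

# Load the dict saved by cad_models_collector.py (illustrative sketch).
import numpy as np

cad_models = np.load('dataset/cadmulobj/cad_models.npy', allow_pickle=True).item()
for name, models in cad_models.items():
    print(name, models.shape)  # e.g., motorbike -> (num_models, 2048, 3)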
/util/mesh_writer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Mesh writer
5 | '''
6 |
7 | import numpy as np
8 |
9 | def dist(i, j):
10 |     return (i[0]-j[0]) ** 2 + (i[1]-j[1]) ** 2 + (i[2]-j[2]) ** 2  # squared Euclidean distance
11 |
12 | def write_ply_mesh(points, color, edge_color, filepath, thres=0, delete_point_mode=-1, weights=None, point_color_as_index=False, thres_edge=1e3):
13 | f = open(filepath, "w")
14 |
15 | if weights is None:
16 | weights = np.zeros((points.shape[0], points.shape[1], 4), dtype=int)
17 | for i in range(points.shape[0]):
18 | for j in range(points.shape[1]):
19 | if i > 0: weights[i,j,0] = 1e6 # left available
20 | if i < points.shape[0] - 1: weights[i,j,1] = 1e6 # right available
21 | if j > 0: weights[i,j,2] = 1e6 # top available
22 | if j < points.shape[1] - 1: weights[i,j,3] = 1e6 # bottom available
23 |
24 | # Initialize the map for drawing points
25 | idx = 0
26 | idx_map = np.zeros((points.shape[0], points.shape[1]), dtype=int)
27 | for i in range(points.shape[0]):
28 | for j in range(points.shape[1]):
29 | degree = sum(weights[i,j,:] > thres)
30 | if degree <= delete_point_mode: # found a point to be removed
31 | idx_map[i,j] = -1
32 | else:
33 | idx_map[i,j] = idx
34 | idx += 1
35 |
36 | # Calculate total number of edges
37 | edge_total = 0
38 | for i in range(points.shape[0]):
39 | for j in range(points.shape[1]):
40 | if idx_map[i,j] >= 0:
41 | if weights[i,j,0] > thres and idx_map[i-1,j] >= 0 and dist(points[i,j,:], points[i-1,j,:]) <= thres_edge: edge_total += 1 # left
42 | if weights[i,j,1] > thres and idx_map[i+1,j] >= 0 and dist(points[i,j,:], points[i+1,j,:]) <= thres_edge: edge_total += 1 # right
43 | if weights[i,j,2] > thres and idx_map[i,j-1] >= 0 and dist(points[i,j,:], points[i,j-1,:]) <= thres_edge: edge_total += 1 # top
44 | if weights[i,j,3] > thres and idx_map[i,j+1] >= 0 and dist(points[i,j,:], points[i,j+1,:]) <= thres_edge: edge_total += 1 # bottom
45 |
46 | # Calculate total number of faces
47 | face_total = 0
48 | for i in range(points.shape[0] - 1):
49 | for j in range(points.shape[1] - 1):
50 | if (weights[i,j,1] > thres) and (weights[i,j,3] > thres) and (weights[i+1,j+1,0]> thres) and (weights[i+1,j+1,2]> thres):
51 | face_total +=1
52 |
53 | # Write header
54 | f.write("ply\n")
55 | f.write("format ascii 1.0\n")
56 | f.write("element vertex " + str(idx) + "\n")
57 | f.write("property float x\n")
58 | f.write("property float y\n")
59 | f.write("property float z\n")
60 | f.write("property uchar red\n")
61 | f.write("property uchar green\n")
62 | f.write("property uchar blue\n")
63 | f.write("element face " + str(face_total * 4) + "\n")
64 | f.write("property list uchar int vertex_index\n")
65 | f.write("element edge " + str(edge_total) + "\n")
66 | f.write("property int vertex1\n")
67 | f.write("property int vertex2\n")
68 | f.write("property uchar red\n")
69 | f.write("property uchar green\n")
70 | f.write("property uchar blue\n")
71 | f.write("end_header\n")
72 |
73 | # Write points
74 | for i in range(points.shape[0]):
75 | for j in range(points.shape[1]):
76 | if idx_map[i,j] >= 0:
77 | f.write(str(points[i,j,0]) + " " + str(points[i,j,1]) + " " + str(points[i,j,2]) + " ")
78 |                 if point_color_as_index: f.write(str(i) + " " + str(j) + " 0\n")
79 | else: f.write(str(int(color[i,j,0] * 255)) + " " + str(int(color[i,j,1] * 255)) + " " + str(int(color[i,j,2] * 255)) + "\n")
80 |
81 |     # Write faces: each quad as two triangles, emitted in both windings so faces are double-sided
82 | for i in range(points.shape[0] - 1):
83 | for j in range(points.shape[1] - 1):
84 | if (weights[i,j,1] > thres) and (weights[i,j,3] > thres) and (weights[i+1,j+1,0]> thres) and (weights[i+1,j+1,2]> thres):
85 | # f.write("4 " + str(idx_map[i+1,j]) + " " + str(idx_map[i+1,j+1]) + " " + str(idx_map[i,j+1]) + " " + str(idx_map[i,j]) + "\n")
86 | f.write("3 " + str(idx_map[i+1,j]) + " " + str(idx_map[i+1,j+1]) + " " + str(idx_map[i,j+1]) + "\n")
87 | f.write("3 " + str(idx_map[i,j+1]) + " " + str(idx_map[i,j]) + " " + str(idx_map[i+1,j]) + "\n")
88 | f.write("3 " + str(idx_map[i,j+1]) + " " + str(idx_map[i+1,j+1]) + " " + str(idx_map[i+1,j]) + "\n")
89 | f.write("3 " + str(idx_map[i+1,j]) + " " + str(idx_map[i,j]) + " " + str(idx_map[i,j+1]) + "\n")
90 |
91 |     # Write edges
92 | for i in range(points.shape[0]):
93 | for j in range(points.shape[1]):
94 | if idx_map[i,j] >= 0:
95 | if weights[i,j,0] > thres and idx_map[i-1,j] >= 0 and dist(points[i,j,:], points[i-1,j,:]) <= thres_edge: # left
96 | f.write(str(idx_map[i,j]) + " " + str(idx_map[i-1,j]) + " " + str(int(edge_color[0] * 255)) +
97 | " " + str(int(edge_color[1] * 255)) + " " + str(int(edge_color[2] * 255)) + "\n")
98 | if weights[i,j,1] > thres and idx_map[i+1,j] >= 0 and dist(points[i,j,:], points[i+1,j,:]) <= thres_edge: # right
99 | f.write(str(idx_map[i,j]) + " " + str(idx_map[i+1,j]) + " " + str(int(edge_color[0] * 255)) +
100 | " " + str(int(edge_color[1] * 255)) + " " + str(int(edge_color[2] * 255)) + "\n")
101 | if weights[i,j,2] > thres and idx_map[i,j-1] >= 0 and dist(points[i,j,:], points[i,j-1,:]) <= thres_edge: # top
102 | f.write(str(idx_map[i,j]) + " " + str(idx_map[i,j-1]) + " " + str(int(edge_color[0] * 255)) +
103 | " " + str(int(edge_color[1] * 255)) + " " + str(int(edge_color[2] * 255)) + "\n")
104 | if weights[i,j,3] > thres and idx_map[i,j+1] >= 0 and dist(points[i,j,:], points[i,j+1,:]) <= thres_edge: # bottom
105 | f.write(str(idx_map[i,j]) + " " + str(idx_map[i,j+1]) + " " + str(int(edge_color[0] * 255)) +
106 | " " + str(int(edge_color[1] * 255)) + " " + str(int(edge_color[2] * 255)) + "\n")
107 | f.close()
--------------------------------------------------------------------------------
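In write_ply_mesh, points and color are expected as H x W x 3 grids (colors as floats in [0, 1]) and edge_color as a single RGB triple; with weights=None every interior grid edge is kept, and thres_edge is compared against the squared edge length. A minimal sketch that writes a flat 10 x 10 patch (the output file name is illustrative):

# Write a small planar grid of points as a PLY mesh (illustrative sketch).
import numpy as np
from util.mesh_writer import write_ply_mesh

h, w = 10, 10
u, v = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
points = np.stack([u, v, np.zeros_like(u)], axis=-1)  # (h, w, 3) grid on the z=0 plane
color = np.tile([1.0, 0.0, 0.0], (h, w, 1))           # all points red
write_ply_mesh(points, color, edge_color=[0.5, 0.5, 0.5], filepath='patch.ply')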
/util/option_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Option handler
5 | '''
6 |
7 | import argparse
8 | import models
9 |
10 | def str2bool(v):
11 | if isinstance(v, bool):
12 | return v
13 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
14 | return True
15 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
16 | return False
17 | else:
18 | raise argparse.ArgumentTypeError('Boolean value expected.')
19 |
20 | class BasicOptionHandler():
21 |
22 | def add_options(self, parser):
23 | parser.add_argument('--exp_name', type=str, default='experiment_name', help='Name of the experiment, folders are created by this name.')
24 | parser.add_argument('--phase', type=str, default='train', help='Indicate current phase, train or test')
25 | parser.add_argument('--encoder', type=str, default='pointnetvanilla', help='Choice of the encoder.')
26 | parser.add_argument('--decoder', type=str, default='foldingnetvanilla', help='Choice of the decoder.')
27 |         parser.add_argument('--checkpoint', type=str, default='', help='Restore from the indicated checkpoint.')
28 |         parser.add_argument('--load_weight_only', type=str2bool, nargs='?', const=True, default=False, help='Load the model weight only when restoring from a checkpoint.')
29 |         parser.add_argument('--xyz_loss_type', type=int, default=0, help='Choose the loss type for point cloud.')
30 |         parser.add_argument('--xyz_chamfer_weight', type=float, default=1, help='Balance the two terms of the Chamfer distance.')
31 |         parser.add_argument('--batch_size', type=int, default=8, help='Batch size when loading the dataset.')
32 |         parser.add_argument('--num_points', type=int, default=2048, help='Input point set size.')
33 |         parser.add_argument('--dataset_name', default='', help='Dataset name.')
34 |         parser.add_argument('--config_from_checkpoint', type=str2bool, nargs='?', const=True, default=False, help='Load the model configuration from checkpoint.')
35 |         parser.add_argument('--tf_summary', type=str2bool, nargs='?', const=True, default=True, help='Whether to use TensorBoard logging.')
36 | return parser
37 |
38 | def parse_options(self):
39 |
40 | # Initialize parser with basic options
41 | parser = argparse.ArgumentParser(
42 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
43 | parser = self.add_options(parser)
44 |
45 | # Get the basic options
46 | opt, _ = parser.parse_known_args()
47 |
48 | # Train or test
49 | if opt.phase.lower() == 'train':
50 | self.isTrain = True
51 | elif opt.phase.lower() == 'test' or opt.phase.lower() == 'counting':
52 | self.isTrain = False
53 | elif (opt.phase.lower() == 'gen_cadmultiobj') or (opt.phase.lower() == 'kitti_data') or (opt.phase.lower() == 'gen_kittimulobj'):
54 | self.parser = parser
55 | self.isTrain = False
56 | return opt
57 | else:
58 | print("The phase [{}] does not exist.".format(opt.phase))
59 |             exit(1)  # an unknown phase is an error
60 | opt.isTrain = self.isTrain # train or test
61 |
62 | # Add options to the parser according to the chosen models
63 | encoder_option_setter = models.get_model_class(opt.encoder).add_options
64 | parser = encoder_option_setter(parser, self.isTrain)
65 | decoder_option_setter = models.get_model_class(opt.decoder).add_options
66 | parser = decoder_option_setter(parser, self.isTrain)
67 |
68 | self.parser = parser
69 | opt, _ = parser.parse_known_args()
70 | return opt
71 |
72 | def print_options(self, opt):
73 | message = ''
74 | message += '----------------- Options ---------------\n'
75 | # For k, v in sorted(vars(opt).items()):
76 | for k, v in vars(opt).items():
77 | comment = ''
78 | default = self.parser.get_default(k)
79 | if v != default:
80 | comment = '\t[default: %s]' % str(default)
81 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
82 | message += '----------------- End -------------------'
83 | print(message)
84 |
85 | class TrainOptionHandler(BasicOptionHandler):
86 |
87 | def add_options(self, parser):
88 | parser = BasicOptionHandler.add_options(self, parser)
89 |         parser.add_argument('--n_epoch', type=int, default=100, help='Number of epochs to train.')
90 | parser.add_argument('--print_freq', type=int, default=20, help='Frequency of displaying training results.')
91 |         parser.add_argument('--pc_write_freq', type=int, default=1000, help='Frequency of writing out point clouds during training; <= 0 means do not write.')
92 | parser.add_argument('--save_epoch_freq', type=int, default=2, help='Frequency of saving the trained model.')
93 | parser.add_argument('--lr', type=float, nargs='+', default=[0.0002], help='Learning rate and its related parameters.')
94 | parser.add_argument('--train_split', type=str, default='train', help='The split of the dataset for training.')
95 | parser.add_argument('--lr_module', type=float, nargs='+', help='Specify the learning rate of the modules.')
96 | parser.add_argument('--optim_args', type=float, nargs='+', default=[0.9, 0.999, 0], help='Parameters of the ADAM optimizer.')
97 | parser.add_argument('--augmentation', type=str2bool, nargs='?', const=True, default=False, help='Apply data augmentation or not.')
98 | parser.add_argument('--augmentation_theta', type=int, default=-1, help='Choice of the rotation angle from [0,90,180,270].')
99 | parser.add_argument('--augmentation_rotation_axis', type=int, default=-1, help='Choice of the rotation axis from np.eye(3).')
100 | parser.add_argument('--augmentation_flip_axis', type=int, default=-1, help='The axis around which to flip for augmentation.')
101 | parser.add_argument('--augmentation_min_scale', type=float, default=1, help='Minimum of the scaling factor for augmentation.')
102 | parser.add_argument('--augmentation_max_scale', type=float, default=1, help='Maximum of the scaling factor for augmentation.')
103 |
104 | return parser
105 |
106 | class TestOptionHandler(BasicOptionHandler):
107 |
108 | def add_options(self, parser):
109 | parser = BasicOptionHandler.add_options(self, parser)
110 | parser.add_argument('--test_split', type=str, default='test', help='Specify the split being used.')
111 | parser.add_argument('--print_freq', type=int, default=20, help='Frequency of displaying results during testing.')
112 |         parser.add_argument('--pc_write_freq', type=int, default=1000, help='Frequency of writing out point clouds during testing.')
113 |         parser.add_argument('--gt_color', type=float, nargs='+', default=[0, 0, 0], help='Color of the ground-truth point cloud.')
114 |         parser.add_argument('--write_mesh', type=str2bool, nargs='?', const=True, default=False, help='Whether to write out meshes.')
115 |         parser.add_argument('--graph_thres', type=float, default=-1, help='Threshold of the graph edges.')
116 |         parser.add_argument('--graph_edge_color', type=float, nargs='+', default=[0.5, 0.5, 0.5], help='Color of the graph edges.')
117 |         parser.add_argument('--graph_delete_point_mode', type=int, default=-1, help='Mode of removing points: -1: no removal; 0: remove those without edge; 1: remove those with one edge.')
118 |         parser.add_argument('--graph_delete_point_eps', type=float, default=-1, help='The epsilon to use when evaluating the point cloud after point removal.')
119 |         parser.add_argument('--thres_edge', type=float, default=3, help='Threshold on the (squared) edge length for an edge to be written.')
120 |         parser.add_argument('--point_color_as_index', type=str2bool, nargs='?', const=True, default=False, help='Whether to regard the point cloud color as the point index.')
121 | return parser
122 |
123 | class CountingOptionHandler(BasicOptionHandler):
124 |
125 | def add_options(self, parser):
126 | parser = BasicOptionHandler.add_options(self, parser)
127 | parser.add_argument('--count_split', type=str, default='test', help='The split of the dataset for object counting.')
128 |         parser.add_argument('--count_fold', type=int, default=4, help='Number of folds for the object counting experiment.')
129 |         parser.add_argument('--epoch_interval', type=int, nargs='+', default=[-1, 10], help='Range of epochs over which to do object counting.')
130 | parser.add_argument('--print_freq', type=int, default=20, help='Frequency of displaying object counting results.')
131 | parser.add_argument('--svm_params', type=float, nargs='+', default=[1e3], help='SVM parameters.')
132 | return parser
133 |
134 | class GenerateCADMultiObjectOptionHandler(BasicOptionHandler):
135 |
136 | def add_options(self, parser):
137 | parser = BasicOptionHandler.add_options(self, parser)
138 | parser.add_argument('--cad_mulobj_num_add_model', type=int, nargs='+', default=[2, 3, 4], help='A list. Numbers of 3D models to add into the scene.')
139 |         parser.add_argument('--cad_mulobj_num_example', type=int, nargs='+', default=[1000, 1000, 1000], help='A list. Numbers of examples corresponding to each scene type.')
140 | parser.add_argument('--cad_mulobj_num_ava_model', type=int, default=1000, help='Total number of available models in the original dataset.')
141 | parser.add_argument('--cad_mulobj_scene_radius', type=float, default=5, help='Radius of the generated scene.')
142 | parser.add_argument('--cad_mulobj_output_file_name', type=str, required=True, help='Path and file name of the output file parametrizing the dataset.')
143 | parser.add_argument('--augmentation', type=str2bool, nargs='?', const=True, default=False, help='Apply data augmentation or not.')
144 | return parser
145 |
146 | class GenerateKittiMultiObjectOptionHandler(BasicOptionHandler):
147 |
148 | def add_options(self, parser):
149 | parser = BasicOptionHandler.add_options(self, parser)
150 | parser.add_argument('--kitti_mulobj_num_add_model', type=int, nargs='+', default=[2, 3, 4], help='A list. Numbers of 3D models to add into the scene.')
151 |         parser.add_argument('--kitti_mulobj_num_example', type=int, nargs='+', default=[1000, 1000, 1000], help='A list. Numbers of examples corresponding to each scene type.')
152 |         parser.add_argument('--kitti_mulobj_scene_radius', type=float, default=5, help='Radius of the generated scene.')
153 |         parser.add_argument('--kitti_mulobj_output_file_name', type=str, required=True, help='Path and file name of the output file parametrizing the dataset.')
154 | return parser
--------------------------------------------------------------------------------
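The experiment entry points presumably pick one of these handlers according to the phase and then parse the command line assembled by the launch scripts; a minimal sketch of that pattern (the surrounding training loop is omitted):

# Parse and echo training options the way an experiment script would (illustrative sketch).
from util.option_handler import TrainOptionHandler

handler = TrainOptionHandler()
opt = handler.parse_options()  # also pulls in the chosen encoder/decoder options
handler.print_options(opt)     # prints every option, flagging non-default values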
/util/pcdet_create_groundtruth_database.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified based on https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/datasets/kitti/kitti_dataset.py:
3 | replace the create_groundtruth_database() function in kitti_dataset.py of OpenPCDet with the following one (it relies on names already imported there, e.g., Path, pickle, np, roiaware_pool3d_utils).
4 | '''
5 |
6 | def create_groundtruth_database(self, info_path=None, used_classes=None, split='train'):
7 | import torch
8 | from itertools import combinations
9 |
10 | database_save_path = Path(self.root_path) / ('kitti_single' if split == 'train' else ('kitti_single_%s' % split))
11 | database_save_path.mkdir(parents=True, exist_ok=True)
12 | with open(info_path, 'rb') as f:
13 | infos = pickle.load(f)
14 |
15 | # Parameters
16 | num_point = 1536
17 | max_num_obj = 22
18 | interest_class = ['Pedestrian', 'Car', 'Cyclist', 'Van', 'Truck']
19 |     all_db_infos = []
20 |     db_obj = []
21 | for obj_col in range(max_num_obj):
22 | all_db_infos.append([])
23 | all_db_infos.append(np.zeros(max_num_obj, dtype=int)) # accumulator
24 | db_obj_save_path = Path(self.root_path) / ('kitti_dbinfos_object.pkl')
25 |
26 | frame_obj_list = np.zeros((len(infos), max_num_obj), dtype=int)
27 | for k in range(len(infos)):
28 | info = infos[k]
29 | sample_idx = info['point_cloud']['lidar_idx']
30 | points = self.get_lidar(sample_idx)
31 | annos = info['annos']
32 | names = annos['name']
33 | difficulty = annos['difficulty']
34 | bbox = annos['bbox']
35 | gt_boxes = annos['gt_boxes_lidar']
36 |
37 | num_obj = gt_boxes.shape[0]
38 | point_indices = roiaware_pool3d_utils.points_in_boxes_cpu(
39 | torch.from_numpy(points[:, 0:3]), torch.from_numpy(gt_boxes)
40 | ).numpy() # (nboxes, npoints)
41 |
42 |         # Save each object of interest and record its occurrence
43 | for obj_id in range(num_obj):
44 |             if names[obj_id] not in interest_class: continue
45 | filename = '%d_%d_%s.bin' % (k, obj_id, names[obj_id])
46 | filepath = database_save_path / filename
47 | db_path = str(filepath.relative_to(self.root_path)) # kitti_single/xxxxx.bin
48 | gt_points = points[point_indices[obj_id] > 0]
49 |
50 | db_obj.append({'name': names[obj_id], 'path': db_path, 'image_idx': sample_idx, 'gt_idx': obj_id,
51 | 'box3d_lidar': gt_boxes[obj_id], 'num_points_in_gt': gt_points.shape[0],
52 | 'difficulty': difficulty[obj_id], 'bbox': bbox[obj_id], 'score': annos['score'][obj_id]})
53 |
54 |             with open(filepath, 'wb') as f:  # binary write for the .bin point files
55 | gt_points.tofile(f)
56 |             for obj_col in range(max_num_obj):  # mark the scene sizes this object has enough points for
57 | if gt_points.shape[0] >= np.ceil(num_point/(obj_col+1)).astype(int):
58 | frame_obj_list[k,obj_col] += 2 ** obj_id
59 |
60 |         # Conclude how the frame can be used: enumerate its qualifying object combinations per scene size
61 | for obj_col in range(max_num_obj):
62 | if bin(frame_obj_list[k,obj_col])[2:].count('1') >= obj_col + 1:
63 | obj_indicator = np.array(list(bin(frame_obj_list[k,obj_col])[2:].zfill(max_num_obj)[::-1]))=='1'
64 | obj_choice = np.arange(max_num_obj)[obj_indicator]
65 | comb = combinations(obj_choice, obj_col+1)
66 | for obj_scene in list(comb):
67 |
68 | # Write down the scene configuration
69 |                         db_info = []
70 | for obj_id in obj_scene:
71 | filename = '%d_%d_%s.bin' % (k, obj_id, names[obj_id])
72 | filepath = database_save_path / filename
73 | db_path = str(filepath.relative_to(self.root_path)) # kitti_single/xxxxx.bin
74 | db_info.append({'name': names[obj_id], 'path': db_path, 'image_idx': sample_idx, 'gt_idx': obj_id,
75 | 'box3d_lidar': gt_boxes[obj_id], 'difficulty': difficulty[obj_id], 'bbox': bbox[obj_id],
76 | 'score': annos['score'][obj_id]})
77 | all_db_infos[len(obj_scene)-1].append(db_info)
78 | all_db_infos[-1][len(obj_scene)-1] += 1
79 |                         print(k, obj_scene)
80 |
81 | with open(db_obj_save_path, 'wb') as f:
82 | print(f.name)
83 | pickle.dump(db_obj, f)
84 |
--------------------------------------------------------------------------------
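The frame_obj_list bookkeeping above packs, for every candidate scene size obj_col + 1, the frame's qualifying objects into a bitmask: bit obj_id is set when that object carries at least ceil(num_point / (obj_col + 1)) points. The decode step then turns the mask back into object indices before enumerating combinations. A small standalone sketch of that decode (the mask and scene size are illustrative values):

# Decode a qualification bitmask into object indices (illustrative values).
from itertools import combinations
import numpy as np

max_num_obj = 22
mask = 0b1011   # objects 0, 1 and 3 have enough points
scene_size = 3  # i.e., obj_col + 1

bits = bin(mask)[2:].zfill(max_num_obj)[::-1]  # LSB-first bit string, as in the function above
obj_choice = np.arange(max_num_obj)[np.array(list(bits)) == '1'].tolist()
if len(obj_choice) >= scene_size:
    for obj_scene in combinations(obj_choice, scene_size):
        print(obj_scene)  # -> (0, 1, 3)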