├── LICENSE
├── README.md
├── assets
│   ├── framework.png
│   └── torus.gif
├── dataloaders
│   ├── __init__.py
│   ├── cad_models_loader.py
│   ├── cadmulobj_loader.py
│   ├── kittimulobj_loader.py
│   ├── modelnet_loader.py
│   └── shapenet_part_loader.py
├── experiments
│   ├── counting.py
│   ├── reconstruction.py
│   ├── train_basic.py
│   └── train_tearing.py
├── models
│   ├── __init__.py
│   ├── autoencoder.py
│   ├── foldingnet.py
│   ├── pointnet.py
│   ├── tearingnet.py
│   └── tearingnet_graph.py
├── scripts
│   ├── experiments
│   │   ├── counting.sh
│   │   ├── reconstruction.sh
│   │   ├── train_folding_cad.sh
│   │   ├── train_folding_kitti.sh
│   │   ├── train_tearing_cad.sh
│   │   └── train_tearing_kitti.sh
│   ├── gen_data
│   │   ├── gen_cad_mulobj_test_5x5.sh
│   │   ├── gen_cad_mulobj_train_5x5.sh
│   │   ├── gen_kitti_mulobj_test_5x5.sh
│   │   └── gen_kitti_mulobj_train_5x5.sh
│   └── launch.sh
└── util
    ├── __init__.py
    ├── cad_models_collector.py
    ├── mesh_writer.py
    ├── option_handler.py
    └── pcdet_create_groundtruth_database.py

/LICENSE:
--------------------------------------------------------------------------------
1 | LIMITED SOFTWARE EVALUATION LICENSE AGREEMENT
2 | 
3 | 
4 | The following Limited Software Evaluation License (the “License”) constitutes an agreement between you (the “Licensee”) and InterDigital Communications, Inc, a company organized and existing under the laws of the State of Delaware, USA, with its registered offices located at 200 Bellevue Parkway, Suite 300, Wilmington, DE 19809, USA (hereinafter “InterDigital”).
5 | 
6 | This License governs the download and use of the Software (as defined below). Your use of the Software is subject to the terms and conditions set forth in this License. By installing, using, accessing or copying the Software, you hereby irrevocably accept the terms and conditions of this License. If you do not accept all parts of the terms and conditions of this License, you cannot install, use, access nor copy the Software.
7 | 
8 | 
9 | Article 1. Definitions
10 | 
11 | “Affiliate” as used herein shall mean any entity that, directly or indirectly, through one or more intermediates, is controlled by, controls, or is under common control with InterDigital or the Licensee, as the case may be. For purposes of this definition only, the term “control” means the possession of the power to direct or cause the direction of the management and policies of an entity, whether by ownership of voting stock or partnership interest, by contract, or otherwise, including direct or indirect ownership of more than fifty percent (50%) of the voting interest in the entity in question.
12 | 
13 | “Authorized Purpose” means any use of the Software for reproducing the experimental results reported in the following publication: Jiahao Pang, et al., “TearingNet: Point Cloud Autoencoder to Learn Topology-Friendly Representations”, IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2021 (the "Purpose")
14 | 
15 | “Derivative Work” means any work that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship.
16 | 
17 | “Documentation” means textual materials delivered by InterDigital to the Licensee pursuant to this License relating to the Software, in written or electronic format, including but not limited to, technical reference manuals, technical notes, user manuals, and application guides.
18 | 
19 | “Effective Date” means the date Licensee first installs a copy of the Software on any computer.
20 | 21 | “Limited Period” means the life of the copyright owned by InterDigital on the Software in each and every country where such copyright would exist. 22 | 23 | “Intellectual Property Rights” means all copyrights, trademarks, trade secrets, patents, mask words, and any other intellectual property rights recognized in any jurisdiction worldwide, including all applications and registrations with respect thereto. 24 | 25 | "Open Source Software" shall mean any software, including where appropriate, any and all modifications, derivative works, enhancements, upgrades, improvements, fixed bugs, and/or software statically linked to the source code of such software, released under a free or open source software license that requires, as a condition of usage, copy, modification and/or redistribution of such software, that the party: 26 | • Redistribute the Open Source Software royalty-free; and/or 27 | • Redistribute the Open Source Software under the same license/distribution terms as those contained in the open source or free software license under which it was originally released; and/or 28 | • Release to the public, disclose or otherwise make available the source code of the Open Source Software. 29 | 30 | For purposes of this License, by means of example and without limitation, any software that is released or distributed under any of the following licenses shall be qualified as Open Source Software: (i) GNU General Public License (GPL); (ii) GNU Lesser/Library GPL (LGPL); (iii) the Artistic License; (iv) the Mozilla Public License; (v) the Common Public License; (vi) the Sun Community Source License (SCSL); (vii) the Sun Industry Standards Source License (SISSL); (viii) BSD License; (ix) MIT License; (x) Apache Software License; (xi) Open SSL License; (xii) IBM Public License; and (xiii) Open Software License. 31 | 32 | “Software” means the Software with which this license was downloaded. 33 | 34 | 35 | Article 2. License 36 | 37 | InterDigital grants Licensee a free, worldwide, non-exclusive, license to InterDigital’s copyright on the Software to download, use and reproduce solely for the Authorized Purpose for the Limited Period. 38 | 39 | Licensee shall not pay any royalty, license fee or maintenance fee, or other fee of any nature under this License. 40 | 41 | Licensee shall have the right to correct, adapt and modify the Software, provided that such action is made to accomplish the Authorized Purpose. Licensee shall promptly provide a copy of such correction, adaptation or modification to InterDigital after any such change is made. All such changes to the Software shall be deemed Derivative Work. 42 | 43 | 44 | Article 3. Restrictions on use of the Software 45 | 46 | Licensee shall not remove, obscure or modify any copyright, trademark or other proprietary rights notices, marks or labels contained on or within the Software, falsify or delete any author attributions, legal notices or other labels of the origin or source of the material. 47 | 48 | Licensee may reproduce and distribute copies of the Software (including any Derivative Works to which Licensee has rights) in any medium, with or without modifications, provided that any such distribution meets the following conditions: 49 | 1. Licensee must give any other recipients of the Software or Derivative Works a copy of this License; 50 | 2. Licensee must cause any modified files to carry prominent notices stating that Licensee has changed the files; and 51 | 3. 
Licensee must retain, in the source form of any Software or Derivative Works that Licensee reproduces or distributes, all copyright, patent, trademark, and attribution notices from the source form of the Software, excluding those notices that do not pertain to any part of the Derivative Works. 52 | 53 | 54 | Article 4. Ownership 55 | 56 | Title to and ownership of the Software, the Documentation, and/or any Intellectual Property Right protecting the Software and/or the Documentation shall at all times remain with InterDigital. Licensee agrees that except for the limited rights granted to the Software as set forth in Section 2 above, in no event shall anything in this License grant, provide, or convey any other rights, privileges, immunities, or interest in or to any Intellectual Property Rights (including but not limited to patent rights) of InterDigital or any of its Affiliates, whether by implication, estoppel, or otherwise. 57 | 58 | 59 | Article 5. Derivative Works 60 | 61 | Derivative Works created by Licensee shall be owned by Licensee who will be able to use it or distribute it freely, provided that this does not infringe any of InterDigital’s rights. 62 | 63 | Licensee may add its own copyright statement to the Derivative Work and may provide additional or different license terms and conditions for use, reproduction, or distribution of Derivative Work, provided the use, reproduction, and distribution of the Derivative Work otherwise complies with the conditions stated in this License. 64 | 65 | Licensee hereby grants to InterDigital a perpetual, fully paid-up, transferrable, non-exclusive, and worldwide license in, on, and to any copyright of Licensee on any Derivative Work to use, modify, and reproduce the Derivative Work in connection with InterDigital’s use of the Software. Such license shall include the right to correct, adapt, modify, reverse engineer, disassemble, decompile or/and otherwise perform or conduct any action leading to the transformation of Derivative Work, provided that such action is made in connection with InterDigital’s use of the Software. 66 | 67 | 68 | Article 6. Publication/Communication 69 | 70 | Any publication or oral communication resulting from the use of the Software shall be elaborated in good faith and shall not be driven by a deliberate will to denigrate InterDigital or any of its products. In any publication and on any support joined to an oral communication (e.g., a PowerPoint presentation) relating to the Software, the following statement shall be inserted: 71 | “TearingNet is an InterDigital product” 72 | 73 | In any publication, the latest publication about the software shall be properly cited. The latest publication currently is: 74 | Pang, Jiahao, et al.. (2021, June 19-25). TearingNet: Point Cloud Autoencoder to Learn Topology-Friendly Representations. IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 75 | 76 | In any oral communication relating to the Software and/or its use, the Licensee shall orally indicate that the Software is InterDigital’s property. 77 | 78 | 79 | Article 7. No Warranty - Disclaimer 80 | 81 | THE SOFTWARE AND DOCUMENTATION ARE PROVIDED TO LICENSEE ON AN “AS IS” BASIS. INTERDIGITAL MAKES NO WARRANTY THAT THE SOFTWARE WILL OPERATE ON ANY PARTICULAR HARDWARE, PLATFORM, OR ENVIRONMENT. THERE IS NO WARRANTY THAT THE OPERATION OF THE SOFTWARE SHALL BE UNINTERRUPTED, WITHOUT BUGS OR ERROR FREE. 
THE SOFTWARE AND DOCUMENTATION ARE PROVIDED HEREUNDER WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED LIABILITIES AND WARRANTIES OF NONINFRINGEMENT OF INTELLECTUAL PROPERTY, FREEDOM FROM INHERENT DEFECTS, CONFORMITY TO A SAMPLE OR MODEL, MERCHANTABILITY, FITNESS AND/OR SUITABILITY FOR A SPECIFIC OR GENERAL PURPOSE AND THOSE ARISING BY STATUTE OR BY LAW, OR FROM A CAUSE OF DEALING OR USAGE OF TRADE. ANY AND ALL SUCH IMPLIED WARRANTIES ARE FULLY DISCLAIMED BY INTERDIGITAL TO THE MAXIMUM EXTENT ALLOWED BY LAW, AND LICENSEE ACKNOWLEDGES THAT THIS DISCLAIMER OF ALL EXPRESS AND IMPLIED WARRANTIES BY INTERDIGITAL, AS WELL AS LICENSEE’S ACCEPTANCE AND ACKNOWLEDGEMENT OF THE SAME, IS A MATERIAL PART OF THE CONSIDERATION FOR THIS LICENSE. 82 | 83 | InterDigital shall not be obligated to perform or provide any modifications, derivative works, enhancements, upgrades, updates or improvements of the Software or Documentation, or to fix any bug that could arise. 84 | 85 | Licensee at all times uses the Software at its own cost, risk and responsibility. InterDigital shall not be liable for any damages that could accrue by or to Licensee as a result of its use of the Software, either in accordance with this License or not. 86 | 87 | InterDigital shall not be liable for any consequential or indirect losses, including any indirect loss of profits, revenues, business, and/or anticipated savings, whether or not in the contemplation of the Parties at the time of entering into this License unless expressly set out in this License, or arising from gross negligence, willful misconduct or fraud. 88 | 89 | Licensee agrees that it will defend, indemnify and hold harmless InterDigital and its Affiliates against any and all losses, damages, costs and expenses arising from a breach by the Licensee of any of its obligations or representations hereunder, including, without limitation, any third party claims, and/or any claims in connection with any such breach and/or any use of the Software, including any claim from third party arising from access, use, or any other activity in relation to this Software. 90 | 91 | Licensee shall not make any warranty, representation, or commitment on behalf of InterDigital to any other third party. 92 | 93 | 94 | Article 8. Open Source Software 95 | 96 | Licensee hereby represents, warrants, and covenants to InterDigital that Licensee’s use of the Software shall not result in the Contamination of all or any part of the Software, directly or indirectly, or of any Intellectual Property of InterDigital or its Affiliates. 97 | 98 | As used herein, “Contamination” shall mean that the licensing terms under which any Open Source Software, distinct from the Software, is released would also apply to the Software herein, by virtue of such Open Source Software being linked to, combined with, or otherwise connected to the Software. 99 | 100 | 101 | Article 9. No Future Contract Obligation 102 | 103 | Neither this License nor the furnishing of the Software, nor any other InterDigital information provided to Licensee, shall be construed to obligate either party to: (a) enter into any further agreement or negotiation concerning the deployment of the Software; (b) refrain from entering into any agreement or negotiation with any other third party regarding the same or any other subject matter; or (c) refrain from pursuing its business in whatever manner it elects even if this involves competing with the other party. 104 | 105 | 106 | Article 10. 
General Provisions 107 | 108 | 10.1 Severability. If any provision of this License shall be held to be in contravention of applicable law, this License shall be construed as if such provision were not a part thereof, and in all other respects the terms hereof shall remain in full force and effect. 109 | 110 | 10.2 Governing Law. Regardless of the place of execution, delivery, performance or any other aspect of this License, this License and all of the rights of the parties under this License shall be governed by, construed under and enforced in accordance with the substantive law of the State of Delaware, USA, without regard to conflicts of law principles. In case of a dispute that cannot be settled amicably, the state and federal courts located in New Castle County, Delaware, USA, shall have exclusive jurisdiction over such dispute, and each party hereby irrevocably waives any objection to the jurisdiction of such courts, including but not limited to objections of lack of in personam jurisdiction or based on principles of forum non conveniens. 111 | 112 | 10.3 Survival. The provisions of articles 1, 3, 4, 6, 7, 9, 10.1, 10.2 and 10.5 shall survive termination of this License. 113 | 114 | 10.4 Assignment. InterDigital may assign this license to any third Party. Licensee may not assign this agreement to any third party without InterDigital’s prior written approval. 115 | 116 | 10.5 Entire Agreement. This License constitutes the entire agreement between the parties hereto with respect to the subject matter hereof and supersedes any prior agreements or understanding. 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TearingNet: Point Cloud Autoencoder to Learn Topology-Friendly Representations 2 | Created by Jiahao Pang, Duanshun Li, and Dong Tian from InterDigital 3 | 4 |

5 | ![framework](assets/framework.png)
6 | 

7 | 
8 | ## Introduction
9 | This repository contains the implementation of our TearingNet paper accepted at CVPR 2021.
10 | Given a point cloud dataset containing objects with various genera, or scenes with multiple objects, we propose the TearingNet, an autoencoder that tackles the challenging task of representing point clouds with a fixed-length descriptor.
11 | Unlike existing works that directly deform a predefined genus-zero primitive (e.g., a 2D square patch) into an object-level point cloud, our TearingNet
12 | is characterized by a proposed Tearing network module and a Folding network module that interact with each other iteratively.
13 | In particular, the Tearing network module learns the point cloud topology explicitly.
14 | By breaking the edges of a primitive graph, it tears the graph into patches, possibly with holes, to emulate the topology of a target point cloud, leading to faithful reconstructions.
15 | 
16 | ## Installation
17 | * We use Python 3.6, PyTorch 1.3.1 and CUDA 10.0. Example commands to set up a virtual environment with Anaconda are:
18 | ```
19 | conda create -n tearingnet python=3.6
20 | conda activate tearingnet
21 | conda install pytorch=1.3.1 torchvision=0.4.2 cudatoolkit=10.0 -c pytorch
22 | ```
23 | 
24 | * Clone our repo to a folder; assume the name of the folder is ``TearingNet``.
25 | * Check out the nndistance folder from 3D Point Capsule Network, put it under ``TearingNet/util``, and build it according to the instructions of 3D Point Capsule Network.
26 | * We adopt the ShapeNet part data loader from 3D Point Capsule Network. Check out shapenet_part_loader.py and put it under ``TearingNet/dataloaders``.
27 | * Install the TensorboardX, Open3D and h5py packages; example commands are:
28 | ```
29 | conda install -c open3d-admin open3d
30 | conda install -c conda-forge tensorboardx
31 | conda install -c anaconda h5py
32 | ```
33 | 
34 | ## Data Preparation
35 | 
36 | ### KITTI Multi-Object Dataset
37 | 
38 | * Our KITTI Multi-Object (KIMO) Dataset is constructed with kitti_dataset.py of PCDet (commit 95d2ab5). Please clone and install PCDet, then prepare the KITTI dataset according to their instructions.
39 | * Assume the name of the cloned folder is ``PCDet``; please replace the ``create_groundtruth_database()`` function in ``kitti_dataset.py`` with our modified one provided in ``TearingNet/util/pcdet_create_groundtruth_database.py``.
40 | * Prepare the KITTI dataset, then generate the data infos according to the instructions in the README.md of PCDet.
41 | * Create the folders ``TearingNet/dataset`` and ``TearingNet/dataset/kittimulobj``, then put the newly-generated folder ``PCDet/data/kitti/kitti_single`` under ``TearingNet/dataset/kittimulobj``. Also, put the newly-generated file ``PCDet/data/kitti/kitti_dbinfos_object.pkl`` under the ``TearingNet/dataset/kittimulobj`` folder.
42 | * Instead of assembling several single-object point clouds and writing them down as multi-object point clouds, we generate the parameters that describe each multi-object point cloud and assemble it on the fly during training/testing. To obtain the parameters, run our prepared scripts as follows under the ``TearingNet`` folder.
These scripts generate the training and testing splits of the KIMO-5 dataset:
43 | ```
44 | ./scripts/launch.sh ./scripts/gen_data/gen_kitti_mulobj_train_5x5.sh
45 | ./scripts/launch.sh ./scripts/gen_data/gen_kitti_mulobj_test_5x5.sh
46 | ```
47 | * The file structure of the KIMO dataset after these steps becomes:
48 | ```
49 | kittimulobj
50 | ├── kitti_dbinfos_object.pkl
51 | ├── kitti_mulobj_param_test_5x5_2048.pkl
52 | ├── kitti_mulobj_param_train_5x5_2048.pkl
53 | └── kitti_single
54 |     ├── 0_0_Pedestrian.bin
55 |     ├── 1000_0_Car.bin
56 |     ├── 1000_1_Car.bin
57 |     ├── 1000_2_Van.bin
58 |     ...
59 | ```
60 | 
61 | ### CAD Model Multi-Object Dataset
62 | 
63 | * Create the ``TearingNet/dataset/cadmulobj`` folder to hold the CAD Model Multi-Object (CAMO) dataset.
64 | 
65 | * Our CAMO dataset is based on the CAD models of "person", "car", "cone" and "plant" from ModelNet40, and "motorbike" from the ShapeNetPart dataset. Use the scripts download_shapenet_part16_catagories.sh and download_modelnet40_same_with_pointnet.sh from 3D Point Capsule Network to download these two datasets, then organize them according to the following file structure under the ``TearingNet/dataset`` folder:
66 | ```
67 | dataset
68 | ├── cadmulobj
69 | ├── kittimulobj
70 | ├── modelnet40
71 | │   └── modelnet40_ply_hdf5_2048
72 | │       ├── ply_data_test0.h5
73 | │       ├── ply_data_test_0_id2file.json
74 | │       ├── ply_data_test1.h5
75 | │       ├── ply_data_test_1_id2file.json
76 | │       ...
77 | └── shapenet_part
78 |     ├── shapenetcore_partanno_segmentation_benchmark_v0
79 |     │   ├── 02691156
80 |     │   │   ├── points
81 |     │   │   │   ├── 1021a0914a7207aff927ed529ad90a11.pts
82 |     │   │   │   ├── 103c9e43cdf6501c62b600da24e0965.pts
83 |     │   │   │   ├── 105f7f51e4140ee4b6b87e72ead132ed.pts
84 | ...
85 | ```
86 | * Extract the "person", "car", "cone" and "plant" models from ModelNet40, and the "motorbike" models from the ShapeNet part dataset, by running the following Python script under the ``TearingNet`` folder:
87 | ```
88 | python util/cad_models_collector.py
89 | ```
90 | * The previous step generates the file ``TearingNet/dataset/cadmulobj/cad_models.npy``, based on which we generate the parameters for the CAMO dataset. To do so, launch the following scripts:
91 | ```
92 | ./scripts/launch.sh ./scripts/gen_data/gen_cad_mulobj_train_5x5.sh
93 | ./scripts/launch.sh ./scripts/gen_data/gen_cad_mulobj_test_5x5.sh
94 | ```
95 | * The file structure of the CAMO dataset after these steps becomes:
96 | ```
97 | cadmulobj
98 | ├── cad_models.npy
99 | ├── cad_mulobj_param_test_5x5.npy
100 | └── cad_mulobj_param_train_5x5.npy
101 | ```
102 | ## Experiments
103 | 
104 | ### Training
105 | 
106 | We employ a two-stage strategy to train the TearingNet. The first stage trains a FoldingNet (the E-Net & F-Net in the paper). Taking the KIMO dataset as an example, launch the following script under the ``TearingNet`` folder:
107 | ```
108 | ./scripts/launch.sh ./scripts/experiments/train_folding_kitti.sh
109 | ```
110 | After the first stage finishes, a pretrained model is saved in ``TearingNet/results/train_folding_kitti``. To load the pretrained FoldingNet into a TearingNet configuration and perform training, launch the following script:
111 | ```
112 | ./scripts/launch.sh ./scripts/experiments/train_tearing_kitti.sh
113 | ```
114 | To see the meanings of the parameters in ``train_folding_kitti.sh`` and ``train_tearing_kitti.sh``, check the Python script ``TearingNet/util/option_handler.py``.
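For reference, the hand-off between the two stages can be sketched as follows. This is a minimal sketch rather than the training scripts' exact logic: it assumes the checkpoint layout used by ``experiments/counting.py`` and ``experiments/reconstruction.py`` (a dict holding the model options under ``opt`` and the weights under ``model_state_dict``); the checkpoint file name and the ``opt_tearing`` options object below are hypothetical.
```python
import torch
from models.autoencoder import PointCloudAutoencoder

# Stage-1 (FoldingNet) checkpoint written by train_folding_kitti.sh;
# the file name below is hypothetical
checkpoint = torch.load('results/train_folding_kitti/epoch_400.pth')

# Rebuild the autoencoder under a TearingNet configuration (opt_tearing would
# be parsed by util/option_handler.py from the arguments in
# train_tearing_kitti.sh), then copy over the pretrained E-Net/F-Net weights;
# strict=False leaves parameters that exist only in the TearingNet untouched.
ae = PointCloudAutoencoder(opt_tearing)
ae.load_state_dict(checkpoint['model_state_dict'], strict=False)
```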
115 | ### Reconstruction
116 | 
117 | To perform the reconstruction experiment with the trained model, launch the following script:
118 | ```
119 | ./scripts/launch.sh ./scripts/experiments/reconstruction.sh
120 | ```
121 | One may write the reconstructions to PLY files by setting a positive ``PC_WRITE_FREQ`` value. Again, please refer to ``TearingNet/util/option_handler.py`` for the meanings of the individual parameters.
122 | ### Counting
123 | 
124 | To perform the counting experiment with the trained model, launch the following script:
125 | ```
126 | ./scripts/launch.sh ./scripts/experiments/counting.sh
127 | ```
128 | 
129 | ## Citing this Work
130 | Please cite our work if you find it useful for your research:
131 | ```
132 | @inproceedings{pang2021tearingnet,
133 |   title={TearingNet: Point Cloud Autoencoder to Learn Topology-Friendly Representations},
134 |   author={Pang, Jiahao and Li, Duanshun and Tian, Dong},
135 |   booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
136 |   year={2021}
137 | }
138 | ```
139 | 
140 | ## Related Projects
141 | * 3D Point Capsule Networks
142 | * AtlasNet
143 | * AtlasNetV2
144 | 
145 | * PCDet
146 | 
147 | 

148 | ![torus interpolation](assets/torus.gif)
149 | 
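As a quick sanity check after data preparation, the generated splits can be loaded through the ``dataloaders`` package. Below is a minimal sketch (run from the ``TearingNet`` folder) using ``point_cloud_dataset_test()`` from ``dataloaders/__init__.py``; the split name ``test_5x5`` is inferred from the generated file name ``kitti_mulobj_param_test_5x5_2048.pkl``, since the loader resolves the parameter file as ``'kitti_mulobj_param_' + split + '_' + str(num_points) + '.pkl'``.
```python
from dataloaders import point_cloud_dataset_test

# Build the KIMO-5 test split and its PyTorch dataloader
dataset, dataloader = point_cloud_dataset_test(
    'kitti_mulobj', num_points=2048, batch_size=8, test_split='test_5x5')

# Each example is a (2048, 3) float32 point cloud plus per-object class labels
points, labels = dataset[0]
print(points.shape, labels)
```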

--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InterDigitalInc/TearingNet/1e43caa266a11e9d50c2d912064bd39d369bb120/assets/framework.png
--------------------------------------------------------------------------------
/assets/torus.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InterDigitalInc/TearingNet/1e43caa266a11e9d50c2d912064bd39d369bb120/assets/torus.gif
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import numpy as np
4 | from . import cadmulobj_loader, kittimulobj_loader
5 | 
6 | 
7 | # Build the dataset for training accordingly
8 | def point_cloud_dataset_train(dataset_name, num_points, batch_size, train_split='train', val_batch_size=1, num_workers=8):
9 |     train_dataset = None
10 |     train_dataloader = None
11 |     val_dataset = None
12 |     val_dataloader = None
13 |     if dataset_name.lower() == 'cad_mulobj': # CAD model multiple-object dataset
14 |         train_dataset = cadmulobj_loader.CADMultiObjectDataset(num_points=num_points, split=train_split, normalize=True)
15 |         train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
16 |     elif dataset_name.lower() == 'kitti_mulobj': # KITTI multiple-object dataset
17 |         train_dataset = kittimulobj_loader.KITTIMultiObjectDataset(num_points=num_points, split=train_split)
18 |         train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
19 |     return train_dataset, train_dataloader, val_dataset, val_dataloader
20 | 
21 | 
22 | # Build the dataset for testing accordingly
23 | def point_cloud_dataset_test(dataset_name, num_points, batch_size, test_split='test', test_class=None, num_workers=8):
24 |     test_dataset, test_dataloader = None, None
25 |     if dataset_name.lower() == 'cad_mulobj': # Our multiple-object dataset
26 |         test_dataset = cadmulobj_loader.CADMultiObjectDataset(num_points=num_points, split=test_split, normalize=True)
27 |         test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
28 |     elif dataset_name.lower() == 'kitti_mulobj': # Our multiple-object dataset
29 |         test_dataset = kittimulobj_loader.KITTIMultiObjectDataset(num_points=num_points, split=test_split)
30 |         test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
31 |     return test_dataset, test_dataloader
32 | 
33 | 
34 | # Generate a rotation matrix (Rodrigues' formula: R = cos(t) I + sin(t) [k]_x + (1 - cos(t)) k k^T)
35 | def gen_rotation_matrix(theta=-1, rotation_axis=-1):
36 |     all_theta = [0, 90, 180, 270]
37 |     all_axis = np.eye(3)
38 | 
39 |     if theta == -1:
40 |         theta = all_theta[np.random.randint(0, 4)]
41 |     elif theta == -2:
42 |         theta = np.random.rand() * 360
43 |     elif theta == -3:
44 |         theta = (np.random.rand() - 0.5) * 90 # random angle in [-45, 45) degrees
45 |     else: theta = all_theta[theta]
46 |     rotation_theta = np.deg2rad(theta)
47 | 
48 |     if rotation_axis == -1:
49 |         rotation_axis = all_axis[np.random.randint(0, 3), :]
50 |     else: rotation_axis = all_axis[rotation_axis,:]
51 |     sin_xyz = np.sin(rotation_theta) * rotation_axis
52 | 
53 |     R = np.cos(rotation_theta) * np.eye(3)
54 | 
55 |     R[0,1] = -sin_xyz[2]
56 |     R[0,2] = sin_xyz[1]
57 |     R[1,0] = sin_xyz[2]
58 |     R[1,2] = -sin_xyz[0]
59 |     R[2,0] = -sin_xyz[1]
60 |     R[2,1] = sin_xyz[0]
61 |     R = R + (1 - np.cos(rotation_theta)) * np.dot(np.expand_dims(rotation_axis, axis=1), np.expand_dims(rotation_axis, axis=0))
62 |     R = torch.from_numpy(R)
63 |     return R.float()
--------------------------------------------------------------------------------
/dataloaders/cad_models_loader.py:
--------------------------------------------------------------------------------
1 | '''
2 | Dataloader for CAD models of "person", "car", "cone", "plant" from ModelNet40, and "motorbike" from ShapeNetPart
3 | '''
4 | 
5 | import torch.utils.data as data
6 | import os
7 | import os.path
8 | import numpy as np
9 | 
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/'))
12 | 
13 | 
14 | class CADModelsDataset(data.Dataset):
15 | 
16 |     def create_cad_models_dataset_pickle(self, root):
17 |         dict_dataset = np.load(os.path.join(root, 'cadmulobj/cad_models.npy'), allow_pickle=True).item()
18 |         data_person = dict_dataset['person'] # 0. person
19 |         data_car = dict_dataset['car'] # 1. car
20 |         data_cone = dict_dataset['cone'] # 2. cone
21 |         data_plant = dict_dataset['plant'] # 3. plant
22 |         data_motorbike = dict_dataset['motorbike'] # 4. motorbike
23 | 
24 |         label_person = np.ones(data_person.shape[0], dtype=int) * 0
25 |         label_car = np.ones(data_car.shape[0], dtype=int) * 1
26 |         label_cone = np.ones(data_cone.shape[0], dtype=int) * 2
27 |         label_plant = np.ones(data_plant.shape[0], dtype=int) * 3
28 |         label_motorbike = np.ones(data_motorbike.shape[0], dtype=int) * 4
29 |         self.data = np.concatenate((data_person, data_car, data_cone, data_plant, data_motorbike), axis=0)
30 |         self.label = np.concatenate((label_person, label_car, label_cone, label_plant, label_motorbike), axis=0)
31 |         self.total = self.data.shape[0]
32 |         self.obj_type_num = 5
33 | 
34 |     def __init__(self, root=dataset_path, num_points=2048, normalize=True):
35 |         self.npoints = num_points
36 |         self.normalize = normalize
37 |         self.create_cad_models_dataset_pickle(root)
38 | 
39 |     def __getitem__(self, index):
40 |         point_set = self.data[index, 0 : self.npoints, :]
41 |         label = self.label[index]
42 |         if self.normalize:
43 |             point_set = self.pc_normalize(point_set)
44 |         return point_set, label
45 | 
46 |     def __len__(self):
47 |         return self.total
48 | 
49 |     # pc: NxC, return NxC
50 |     def pc_normalize(self, pc):
51 |         l = pc.shape[0]
52 |         centroid = np.mean(pc, axis=0)
53 |         pc = pc - centroid
54 |         m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
55 |         pc = pc / m
56 |         return pc
--------------------------------------------------------------------------------
/dataloaders/cadmulobj_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Multi-object dataset based on CAD models
5 | '''
6 | 
7 | import torch.utils.data as data
8 | import os
9 | import sys
10 | import os.path
11 | import numpy as np
12 | 
13 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
14 | dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/'))
15 | 
16 | 
17 | class CADMultiObjectDataset(data.Dataset):
18 | 
19 |     def __init__(self, root=dataset_path, num_points=2048, split=None, normalize=True):
20 |         from . import cad_models_loader
21 |         self.npoints = num_points
22 |         self.normalize = normalize
23 |         dict_dataset = np.load(os.path.join(root, 'cadmulobj/cad_mulobj_param_' + split.lower() + '.npy'), allow_pickle=True).item()
24 |         self.scene_radius = dict_dataset['scene_radius']
25 |         self.total = dict_dataset['total']
26 |         self.augmentation = dict_dataset['augmentation']
27 |         self.num_data_batch = dict_dataset['num_batch']
28 |         self.batch_num_model = dict_dataset['batch_num_model']
29 |         self.batch_num_example = dict_dataset['batch_num_example']
30 |         self.data_list = dict_dataset['list_example']
31 |         self.max_obj_num = np.max(self.batch_num_model)
32 |         self.base_dataset = cad_models_loader.CADModelsDataset(num_points=2048, normalize=True)
33 |         self.obj_type_num = self.base_dataset.obj_type_num
34 | 
35 |     def __getitem__(self, index):
36 | 
37 |         index_new = index
38 |         label = np.ones(self.max_obj_num, dtype=int) * -1
39 |         point_set = []
40 |         num_points_each = int(np.ceil(self.npoints / len(self.data_list[index_new]['idx'])))
41 | 
42 |         for cnt, idx_obj in enumerate(self.data_list[index_new]['idx']): # take out the models one-by-one
43 |             obj_pc = self.pc_normalize(self.base_dataset[idx_obj][0][:num_points_each, :])
44 |             label[cnt] = self.base_dataset[idx_obj][1]
45 |             trans = self.data_list[index_new]['coor'][cnt].copy()
46 |             if self.augmentation == True: # rotation augmentation if needed
47 |                 obj_pc = obj_pc @ self.gen_rotation_matrix(self.data_list[index_new]['dr'][cnt][0], 1)
48 |             trans[1] = trans[1] - np.min(obj_pc[:,1])
49 |             point_set.append(obj_pc + trans)
50 | 
51 |         point_set = np.vstack(point_set)[:self.npoints,:]
52 |         return point_set.astype(np.float32), label
53 | 
54 |     def __len__(self):
55 |         return self.total
56 | 
57 |     # pc: NxC, return NxC
58 |     def pc_normalize(self, pc):
59 |         l = pc.shape[0]
60 |         centroid = np.mean(pc, axis=0)
61 |         pc = pc - centroid
62 |         m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
63 |         pc = pc / m
64 |         return pc
65 | 
66 |     # Generate rotation matrix, theta is in terms of degrees
67 |     def gen_rotation_matrix(self, theta, rotation_axis):
68 |         all_axis = np.eye(3)
69 |         rotation_theta = np.deg2rad(theta)
70 |         rotation_axis = all_axis[rotation_axis,:]
71 |         sin_xyz = np.sin(rotation_theta)*rotation_axis
72 |         R = np.cos(rotation_theta)*np.eye(3)
73 |         R[0,1] = -sin_xyz[2]
74 |         R[0,2] = sin_xyz[1]
75 |         R[1,0] = sin_xyz[2]
76 |         R[1,2] = -sin_xyz[0]
77 |         R[2,0] = -sin_xyz[1]
78 |         R[2,1] = sin_xyz[0]
79 |         R = R + (1-np.cos(rotation_theta))*np.dot(np.expand_dims(rotation_axis, axis = 1), np.expand_dims(rotation_axis, axis = 0))
80 |         return R
81 | 
82 | 
83 | def draw_coor(radius):
84 |     x = np.random.randint(radius) * 2 - (radius - 1)
85 |     y = np.random.randint(radius) * 2 - (radius - 1)
86 |     z = 0
87 |     return x,y,z
88 | 
89 | 
90 | # Use this main() function to generate data
91 | def main():
92 | 
93 |     if os.path.exists(opt.cad_mulobj_output_file_name + '.npy') == False:
94 |         num_batch = len(opt.cad_mulobj_num_example)
95 |         list_example = []
96 |         for idx_batch in range(num_batch):
97 |             for idx_example in range(opt.cad_mulobj_num_example[idx_batch]):
98 |                 print("Batch: %d, Idx: %d" % (idx_batch, idx_example))
99 |                 coor = np.zeros((opt.cad_mulobj_num_add_model[idx_batch], 3), dtype=float) # coordinate of the object
100 |                 idx = np.zeros(opt.cad_mulobj_num_add_model[idx_batch], dtype=int) # object index from the base dataset
101 |                 dr = np.zeros((opt.cad_mulobj_num_add_model[idx_batch], 2), dtype=int) # object orientation
102 | 
103 |                 # Generate object properties
104 |                 for idx_obj in range(opt.cad_mulobj_num_add_model[idx_batch]):
105 |                     collision = True
106 |                     if opt.augmentation:
107 |                         dr[idx_obj, 0], dr[idx_obj, 1] = np.random.randint(0,360), np.random.randint(0,3) # generate direction
108 |                     else: dr[idx_obj, 0], dr[idx_obj, 1] = 0, 0
109 |                     idx[idx_obj] = np.random.randint(opt.cad_mulobj_num_ava_model) # generate object index
110 |                     while collision == True: # generate coordinate
111 |                         collision = False
112 |                         coor[idx_obj,2], coor[idx_obj,0], coor[idx_obj,1] = draw_coor(opt.cad_mulobj_scene_radius)
113 |                         for check_obj in range(idx_obj): # check collision between idx_obj and check_obj
114 |                             if np.sum(np.power(coor[idx_obj,:] - coor[check_obj,:], 2)) < 4 - 1e-9:
115 |                                 collision = True
116 |                                 break
117 | 
118 |                 # Generate the object index
119 |                 list_example.append({'coor':coor, 'idx':idx, 'dr':dr})
120 | 
121 |         # Save the dataset parameters
122 |         dict_dataset = {
123 |             'scene_radius': opt.cad_mulobj_scene_radius,
124 |             'total': sum(opt.cad_mulobj_num_example),
125 |             'augmentation': opt.augmentation,
126 |             'ava_model_idx': opt.cad_mulobj_num_ava_model,
127 |             'num_batch': len(opt.cad_mulobj_num_example),
128 |             'batch_num_model': opt.cad_mulobj_num_add_model,
129 |             'batch_num_example': opt.cad_mulobj_num_example,
130 |             'list_example': list_example
131 |         }
132 |         np.save(opt.cad_mulobj_output_file_name, dict_dataset)
133 |         print('Dataset %s generation completed.' % opt.cad_mulobj_output_file_name)
134 | 
135 | if __name__ == "__main__":
136 | 
137 |     sys.path.append(os.path.join(BASE_DIR, '..'))
138 |     from util.option_handler import GenerateCADMultiObjectOptionHandler
139 |     option_handler = GenerateCADMultiObjectOptionHandler()
140 |     opt = option_handler.parse_options() # all options are parsed through this command
141 |     option_handler.print_options(opt) # print out all the options
142 |     main()
--------------------------------------------------------------------------------
/dataloaders/kittimulobj_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Multi-object dataset based on KITTI
5 | '''
6 | 
7 | import torch.utils.data as data
8 | import pickle
9 | import os
10 | import os.path
11 | import numpy as np
12 | import sys
13 | 
14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15 | dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/'))
16 | 
17 | 
18 | class KITTIMultiObjectDataset(data.Dataset):
19 | 
20 |     def __init__(self, root=dataset_path, num_points=2048, split='train'):
21 | 
22 |         with open(os.path.join(root, 'kittimulobj/kitti_mulobj_param_' + split.lower() + '_' + str(num_points) +'.pkl'), 'rb') as pickle_file:
23 |             dict_dataset = pickle.load(pickle_file)
24 |         with open(os.path.join(root, 'kittimulobj', dict_dataset['base_dataset'] +'.pkl'), 'rb') as pickle_file:
25 |             self.obj_dataset = pickle.load(pickle_file)
26 |         self.name_dict={'Pedestrian':0, 'Car':1, 'Cyclist':2, 'Van':3, 'Truck':4}
27 |         self.npoints = num_points
28 |         self.scene_radius = dict_dataset['scene_radius']
29 |         self.total = dict_dataset['total']
30 |         self.num_data_batch = dict_dataset['num_batch']
31 |         self.batch_num_model = dict_dataset['batch_num_model']
32 |         self.batch_num_example = dict_dataset['batch_num_example']
33 |         self.data_list = dict_dataset['list_example']
34 |         self.datapath = os.path.join(root, 'kittimulobj/kitti_single')
35 |         self.max_obj_num = np.max(self.batch_num_model)
36 |         self.obj_type_num = len(self.name_dict)
37 |         if split.lower().find('test') >= 0:
38 |             self.test = True
39 |         else: self.test = False
40 | 
41 |     def __getitem__(self, index):
42 | 
43 |         label = np.ones(self.max_obj_num, dtype=int) * -1
44 |         point_set = []
45 |         num_points_each = int(np.ceil(self.npoints / len(self.data_list[index]['idx'])))
46 |         for cnt, idx_obj in enumerate(self.data_list[index]['idx']): # take out the models one-by-one
47 |             obj_pc = np.fromfile(os.path.join(self.datapath, os.path.basename(self.obj_dataset[idx_obj]['path'])), dtype=np.float32).reshape(-1, 4) # read
48 |             if self.test: np.random.seed(0)
49 |             obj_pc = obj_pc[np.random.choice(obj_pc.shape[0],num_points_each), 0:3] # sample
50 |             label[cnt] = self.name_dict[self.obj_dataset[idx_obj]['name']]
51 |             obj_pc = self.pc_normalize(obj_pc, self.obj_dataset[idx_obj]['box3d_lidar'], self.obj_dataset[idx_obj]['name']) # normalize
52 |             obj_pc[:,:2] += self.data_list[index]['coor'][cnt][:2] # translate
53 |             point_set.append(obj_pc)
54 |         point_set = np.vstack(point_set)[:self.npoints,:]
55 | 
56 |         return point_set.astype(np.float32), label
57 | 
58 |     def __len__(self):
59 |         return self.total
60 | 
61 |     # pc: NxC, return NxC
62 |     def pc_normalize(self, pc, bbox, label):
63 |         pc -= bbox[:3]
64 |         box_len = np.sqrt(bbox[3] ** 2 + bbox[4] ** 2 + bbox[5] ** 2)
65 |         pc = pc / (box_len) * 2
66 |         if label == 'Pedestrian' or label == 'Cyclist':
67 |             pc /= 2 # shrink the point sets for models with a person to better emulate a driving scene
68 |         return pc
69 | 
70 | def draw_coor(radius):
71 |     x = np.random.randint(radius) * 2 - (radius - 1)
72 |     y = np.random.randint(radius) * 2 - (radius - 1)
73 |     z = 0
74 |     return x,y,z
75 | 
76 | # Use this main() function to generate data
77 | def main():
78 |     dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/kittimulobj'))
79 |     if os.path.exists(opt.kitti_mulobj_output_file_name + '.pkl') == False:
80 |         with open(os.path.join(dataset_path, "kitti_dbinfos_object.pkl"), 'rb') as pickle_file:
81 |             db_obj = pickle.load(pickle_file)
82 |         useful_obj_list = []
83 |         max_obj_num = int(opt.kitti_mulobj_scene_radius ** 2)
84 |         for add_obj_num in range(max_obj_num):
85 |             useful_obj_list.append([])
86 |         for idx_obj in range(len(db_obj)):
87 |             for add_obj_num in range(max_obj_num):
88 |                 if db_obj[idx_obj]['num_points_in_gt'] >= int(np.ceil(opt.num_points / (add_obj_num + 1))):
89 |                     useful_obj_list[add_obj_num].append(idx_obj)
90 |         num_batch = len(opt.kitti_mulobj_num_example)
91 |         list_example = []
92 |         for idx_batch in range(num_batch):
93 |             for idx_example in range(opt.kitti_mulobj_num_example[idx_batch]):
94 |                 print("Batch: %d, Idx: %d" % (idx_batch, idx_example))
95 |                 coor = np.zeros((opt.kitti_mulobj_num_add_model[idx_batch], 3), dtype=float) # coordinate of the object
96 |                 idx = np.zeros(opt.kitti_mulobj_num_add_model[idx_batch], dtype=int) # object index from the base dataset
97 |                 dr = np.zeros((opt.kitti_mulobj_num_add_model[idx_batch], 2), dtype=int) # object orientation
98 | 
99 |                 # Generate object properties
100 |                 for idx_obj in range(opt.kitti_mulobj_num_add_model[idx_batch]):
101 |                     collision = True
102 |                     idx[idx_obj] = np.random.choice(useful_obj_list[opt.kitti_mulobj_num_add_model[idx_batch]-1]) # generate object index
103 |                     while collision == True: # generate coordinate
104 |                         collision = False
105 |                         coor[idx_obj,0], coor[idx_obj,1], coor[idx_obj,2] = draw_coor(opt.kitti_mulobj_scene_radius)
106 |                         for check_obj in range(idx_obj): # check collision between idx_obj and check_obj
107 |                             if np.sum(np.power(coor[idx_obj,:] - coor[check_obj,:], 2)) < 4 - 1e-9:
108 |                                 collision = True
109 |                                 break
110 | 
111 |                 # Generate the object index
112 |                 list_example.append({'coor':coor, 'idx':idx, 'dr':dr})
113 | 
114 |         # Save the dataset parameters
115 |         dict_dataset = {
116 |             'base_dataset': "kitti_dbinfos_object",
117 |             'scene_radius': opt.kitti_mulobj_scene_radius,
118 |             'total': sum(opt.kitti_mulobj_num_example),
119 |             'num_batch': len(opt.kitti_mulobj_num_example),
120 |             'batch_num_model': opt.kitti_mulobj_num_add_model,
121 |             'batch_num_example': opt.kitti_mulobj_num_example,
122 |             'list_example': list_example
123 |         }
124 |         with open(opt.kitti_mulobj_output_file_name + '.pkl', 'wb') as f:
125 |             print(f.name)
126 |             pickle.dump(dict_dataset, f)
127 |         print('Dataset generation completed.')
128 |     else: print('Dataset already exists.')
129 | 
130 | 
131 | if __name__ == "__main__":
132 | 
133 |     sys.path.append(os.path.join(BASE_DIR, '..'))
134 |     from util.option_handler import GenerateKittiMultiObjectOptionHandler
135 |     option_handler = GenerateKittiMultiObjectOptionHandler()
136 |     opt = option_handler.parse_options() # all options are parsed through this command
137 |     option_handler.print_options(opt) # print out all the options
138 |     main()
--------------------------------------------------------------------------------
/dataloaders/modelnet_loader.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified based on https://github.com/yongheng1991/3D-point-capsule-networks/blob/master/dataloaders/modelnet40_loader.py
3 | '''
4 | 
5 | import os
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 | import h5py
9 | 
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | data_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/modelnet40')) + '/'
12 | 
13 | class ModelNet40(Dataset):
14 | 
15 |     def load_h5(self, file_name):
16 |         f = h5py.File(file_name, 'r')
17 |         data = f['data'][:]
18 |         label = f['label'][:]
19 |         return data, label
20 | 
21 |     def __init__(self, data_path=data_path, num_points=2048, transform=None,
22 |                  phase='train'):
23 |         self.data_path = os.path.join(data_path, 'modelnet40_ply_hdf5_2048')
24 |         self.num_points = num_points
25 |         self.num_classes = 40
26 |         self.transform = transform
27 | 
28 |         # store data
29 |         shape_name_file = os.path.join(self.data_path, 'shape_names.txt')
30 |         self.shape_names = [line.rstrip() for line in open(shape_name_file)]
31 |         self.coordinates = []
32 |         self.labels = []
33 |         try:
34 |             files = os.path.join(self.data_path, '{}_files.txt'.format(phase))
35 |             files = [line.rstrip() for line in open(files)]
36 |             for index, file in enumerate(files):
37 |                 file_name = file.split('/')[-1]
38 |                 files[index] = os.path.join(self.data_path, file_name)
39 |         except FileNotFoundError:
40 |             raise ValueError('Unknown phase or invalid data path.')
41 |         for file in files:
42 |             current_data, current_label = self.load_h5(file)
43 |             current_data = current_data[:, 0:self.num_points, :]
44 |             self.coordinates.append(current_data)
45 |             self.labels.append(current_label)
46 |         self.coordinates = np.vstack(self.coordinates).astype(np.float32)
47 |         self.labels = np.vstack(self.labels).squeeze().astype(np.int64)
48 | 
49 |     def __len__(self):
50 |         return self.coordinates.shape[0]
51 | 
52 |     def __getitem__(self, index):
53 |         # coord = np.transpose(self.coordinates[index]) # 3 * N
54 |         coord = self.coordinates[index]
55 |         label = self.labels[index]
56 |         data = (coord,)
57 |         # transform coordinates
58 |         if self.transform is not None:
59 |             transformed, matrix, mask = self.transform(coord)
60 |             data += (transformed, matrix, mask)
61 |         data += (label,)
62 |         return data
63 | 
64 | def class_extractor(label, loader):
65 |     list_obj=[]
66 |     for i in range(len(loader)):
67 |         data, label_cur = loader.__getitem__(i)
68 |         if label_cur == label:
69 |             list_obj.append(data)
70 |     list_obj = np.vstack(list_obj).reshape(-1,2048,3)
71 |     return list_obj
72 | 
73 | 
74 | if __name__ == '__main__':
75 |     pass # no standalone entry point; this module is used through the dataloaders package
--------------------------------------------------------------------------------
/dataloaders/shapenet_part_loader.py:
--------------------------------------------------------------------------------
1 | # print(dir_point, dir_seg)
2 | #from __future__ import print_function
3 | import torch.utils.data as data
4 | import os
5 | import os.path
6 | import torch
7 | import json
8 | import numpy as np
9 | import sys
10 | 
11 | 
12 | 
13 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
14 | dataset_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/shapenet_part/shapenetcore_partanno_segmentation_benchmark_v0/'))
15 | 
16 | class PartDataset(data.Dataset):
17 |     def __init__(self, root=dataset_path, npoints=2500, classification=False, class_choice=None, split='train', normalize=True):
18 |         self.npoints = npoints
19 |         self.root = root
20 |         self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
21 |         self.cat = {}
22 |         self.classification = classification
23 |         self.normalize = normalize
24 | 
25 |         with open(self.catfile, 'r') as f:
26 |             for line in f:
27 |                 ls = line.strip().split()
28 |                 self.cat[ls[0]] = ls[1]
29 |         # print(self.cat)
30 |         if not class_choice is None:
31 |             self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
32 |         print(self.cat)
33 |         self.meta = {}
34 |         with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
35 |             train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
36 |         with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
37 |             val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
38 |         with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
39 |             test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
40 | 
41 |         for item in self.cat:
42 |             # print('category', item)
43 |             self.meta[item] = []
44 |             dir_point = os.path.join(self.root, self.cat[item], 'points')
45 |             dir_seg = os.path.join(self.root, self.cat[item], 'points_label')
46 |             # print(dir_point, dir_seg)
47 |             fns = sorted(os.listdir(dir_point))
48 |             if split == 'trainval':
49 |                 fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
50 |             elif split == 'train':
51 |                 fns = [fn for fn in fns if fn[0:-4] in train_ids]
52 |             elif split == 'val':
53 |                 fns = [fn for fn in fns if fn[0:-4] in val_ids]
54 |             elif split == 'test':
55 |                 fns = [fn for fn in fns if fn[0:-4] in test_ids]
56 |             else:
57 |                 print('Unknown split: %s. Exiting..'
% (split)) 58 | exit(-1) 59 | 60 | for fn in fns: 61 | token = (os.path.splitext(os.path.basename(fn))[0]) 62 | self.meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg'),self.cat[item], token)) 63 | self.datapath = [] 64 | for item in self.cat: 65 | for fn in self.meta[item]: 66 | self.datapath.append((item, fn[0], fn[1], fn[2], fn[3])) 67 | 68 | self.classes = dict(zip(sorted(self.cat), range(len(self.cat)))) 69 | print(self.classes) 70 | self.num_seg_classes = 0 71 | if not self.classification: 72 | for i in range(len(self.datapath)//50): 73 | l = len(np.unique(np.loadtxt(self.datapath[i][2]).astype(np.uint8))) 74 | if l > self.num_seg_classes: 75 | self.num_seg_classes = l 76 | # print(self.num_seg_classes) 77 | self.cache = {} # from index to (point_set, cls, seg) tuple 78 | self.cache_size = 18000 79 | 80 | def __getitem__(self, index): 81 | if index in self.cache: 82 | # point_set, seg, cls= self.cache[index] 83 | point_set, seg, cls, foldername, filename = self.cache[index] 84 | else: 85 | fn = self.datapath[index] 86 | cls = self.classes[self.datapath[index][0]] 87 | # cls = np.array([cls]).astype(np.int32) 88 | point_set = np.loadtxt(fn[1]).astype(np.float32) 89 | if self.normalize: 90 | point_set = self.pc_normalize(point_set) 91 | seg = np.loadtxt(fn[2]).astype(np.int64) - 1 92 | foldername = fn[3] 93 | filename = fn[4] 94 | if len(self.cache) < self.cache_size: 95 | self.cache[index] = (point_set, seg, cls, foldername, filename) 96 | 97 | #print(point_set.shape, seg.shape) 98 | choice = np.random.choice(len(seg), self.npoints, replace=True) 99 | # resample 100 | point_set = point_set[choice, :] 101 | seg = seg[choice] 102 | 103 | # To Pytorch 104 | point_set = torch.from_numpy(point_set) 105 | seg = torch.from_numpy(seg) 106 | cls = torch.from_numpy(np.array([cls]).astype(np.int64)) 107 | if self.classification: 108 | return point_set, cls 109 | else: 110 | return point_set, seg , cls 111 | 112 | 113 | def __len__(self): 114 | return len(self.datapath) 115 | 116 | def pc_normalize(self, pc): 117 | """ pc: NxC, return NxC """ 118 | l = pc.shape[0] 119 | centroid = np.mean(pc, axis=0) 120 | pc = pc - centroid 121 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 122 | pc = pc / m 123 | return pc 124 | 125 | 126 | if __name__ == '__main__': 127 | print('test') 128 | d = PartDataset( root='../dataset/shapenetcore_partanno_segmentation_benchmark_v0/',classification=True, class_choice='Airplane', npoints=2048, split='test') 129 | ps, cls = d[0] 130 | print(ps.size(), ps.type(), cls.size(), cls.type()) -------------------------------------------------------------------------------- /experiments/counting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Object counting experiment 5 | 6 | author: jpang 7 | created: Apr 27, 2020 (Mon) 13:55 EST 8 | ''' 9 | 10 | import multiprocessing 11 | multiprocessing.set_start_method('spawn', True) 12 | 13 | import torch 14 | import torch.nn.parallel 15 | import sys 16 | import os 17 | import numpy as np 18 | from sklearn.svm import SVC 19 | 20 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 21 | sys.path.append(os.path.join(BASE_DIR, '..')) 22 | 23 | from dataloaders import point_cloud_dataset_test 24 | from models.autoencoder import PointCloudAutoencoder 25 | from util.option_handler import CountingOptionHandler 26 | import time 27 | 28 | import warnings 29 | warnings.filterwarnings("ignore") 30 | 31 
| def main(): 32 | print(torch.cuda.device_count(), "GPUs will be used for object counting.") 33 | 34 | # Build up an autoencoder 35 | t = time.time() 36 | 37 | # Load a saved model 38 | if opt.checkpoint == '': 39 | print("Please provide the model path.") 40 | exit() 41 | else: 42 | path_checkpoints = [os.path.join(opt.checkpoint, f) for f in os.listdir(opt.checkpoint) 43 | if os.path.isfile(os.path.join(opt.checkpoint, f)) and f[-1]=='h'] 44 | if len(path_checkpoints) > 1: 45 | path_checkpoints.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 46 | 47 | # Take care of the dataset 48 | test_dataset, test_dataloader = point_cloud_dataset_test(opt.dataset_name, 49 | opt.num_points, opt.batch_size, opt.count_split) 50 | 51 | # Create folder to save results 52 | if not os.path.exists(opt.exp_name): 53 | os.makedirs(opt.exp_name) 54 | device = torch.device("cuda:0") 55 | 56 | # Go through each checkpoint in the folder 57 | max_avg_mae, min_avg_mae = 0, 1e6 58 | max_min_mae, min_min_mae = 0, 1e6 59 | for cnt, str_checkpoint in enumerate(path_checkpoints): 60 | 61 | if len(path_checkpoints) > 1: 62 | epoch_cur = int(''.join(filter(str.isdigit, str_checkpoint[str_checkpoint.find('epoch_'):]))) 63 | if (opt.epoch_interval[0] < 0) and (len(path_checkpoints) - cnt > opt.epoch_interval[1]): continue 64 | if (opt.epoch_interval[0] >= 0) and (epoch_cur >= opt.epoch_interval[0] and epoch_cur <= opt.epoch_interval[1])== False: continue 65 | else: epoch_cur = -1 66 | checkpoint = torch.load(str_checkpoint) 67 | 68 | # Load the model, may use model configuration from checkpoint 69 | if 'opt' in checkpoint and opt.config_from_checkpoint == True: 70 | checkpoint['opt'].grid_dims = opt.grid_dims 71 | ae = PointCloudAutoencoder(checkpoint['opt']) 72 | print("\nModel configuration loaded from checkpoint %s." % str_checkpoint) 73 | else: 74 | ae = PointCloudAutoencoder(opt) 75 | checkpoint['opt'] = opt 76 | torch.save(checkpoint, str_checkpoint) 77 | print("\nModel configuration written to checkpoint %s." % str_checkpoint) 78 | ae.load_state_dict(checkpoint['model_state_dict']) 79 | print("\nExisting model %s loaded." 
% (str_checkpoint)) 80 | ae.to(device) 81 | ae.eval() # set the autoencoder to evaluation mode 82 | 83 | # Compute the codewords 84 | print('Computing codewords..') 85 | batch_cw = [] 86 | batch_labels = [] 87 | torch.cuda.empty_cache() 88 | len_test = len(test_dataloader) 89 | with torch.no_grad(): 90 | for batch_id, data in enumerate(test_dataloader): 91 | points, labels = data 92 | if(points.size(0) < opt.batch_size): break 93 | points = points.cuda() 94 | with torch.no_grad(): 95 | rec = ae(points) 96 | cw = rec['cw'] 97 | batch_cw.append(cw.detach().cpu().numpy()) 98 | if (opt.dataset_name.lower() != 'torus_orig') and (opt.dataset_name.lower() != 'torus'): 99 | labels = torch.sum(labels>=0, dim=1) 100 | else: labels = labels + 1 101 | batch_labels.append(labels.detach().cpu().numpy()) 102 | if batch_id % opt.print_freq == 0: 103 | elapse = time.time() - t 104 | print(' batch_no: %d/%d, time: %f' % (batch_id, len_test, elapse)) 105 | batch_cw = np.vstack(batch_cw) 106 | batch_labels = np.vstack(batch_labels).reshape((-1,)).astype(int) 107 | 108 | # Divide the folds for experiments 109 | data_cnt_total = len(batch_labels) 110 | data_fold_idx = np.zeros(data_cnt_total, dtype=int) 111 | if (opt.dataset_name.lower() != 'torus_orig') and (opt.dataset_name.lower() != 'torus'): 112 | obj_cnt = np.zeros(test_dataset.max_obj_num, dtype=int) 113 | else: obj_cnt = np.zeros(3, dtype=int) 114 | for cur_obj in range(data_cnt_total): 115 | obj_cnt[batch_labels[cur_obj]-1] += 1 116 | data_fold_idx[cur_obj] = obj_cnt[batch_labels[cur_obj]-1] 117 | obj_cnt = (obj_cnt / opt.count_fold).astype(int) 118 | for cur_obj in range(data_cnt_total): 119 | data_fold_idx[cur_obj] = int((data_fold_idx[cur_obj]-1) / (obj_cnt[batch_labels[cur_obj]-1] + 1e-10)) 120 | 121 | # Perform testing for each fold 122 | mae = np.zeros(opt.count_fold, dtype=float) 123 | for cur_fold in range(opt.count_fold): 124 | data_train = batch_cw[data_fold_idx==cur_fold, :] 125 | label_train = batch_labels[data_fold_idx==cur_fold] 126 | data_test = batch_cw[data_fold_idx!=cur_fold, :] 127 | label_test = batch_labels[data_fold_idx!=cur_fold] 128 | 129 | # Train and test the classifier 130 | classifier = SVC(gamma='scale', C=opt.svm_params[0]) 131 | classifier.fit(data_train, label_train) 132 | pred = classifier.predict(data_test) 133 | mae[cur_fold] = np.mean(abs(label_test-pred)) 134 | print(' fold: {} '.format(cur_fold+1) + 'mae: {:.4f} '.format(mae[cur_fold]) ) 135 | print('avg_mae: {:.4f} '.format(np.mean(mae))) 136 | mae_idx = np.argmin(mae) 137 | print('min_mae: {:.4f} '.format(mae[mae_idx]) + 'min_mae_idx: {} '.format(mae_idx+1)) 138 | 139 | if np.mean(mae) > max_avg_mae: 140 | max_avg_mae = np.mean(mae) 141 | max_avg_mae_epoch = epoch_cur 142 | if np.mean(mae) < min_avg_mae: 143 | min_avg_mae = np.mean(mae) 144 | min_avg_mae_epoch = epoch_cur 145 | if mae[mae_idx] > max_min_mae: 146 | max_min_mae = mae[mae_idx] 147 | max_min_mae_epoch = epoch_cur 148 | if mae[mae_idx] < min_min_mae: 149 | min_min_mae = mae[mae_idx] 150 | min_min_mae_epoch = epoch_cur 151 | 152 | print('\nmax_avg_mae: {:.4f},'.format(max_avg_mae) + ' max_avg_mae_epoch: {:d}'.format(max_avg_mae_epoch)) 153 | print('min_avg_mae: {:.4f},'.format(min_avg_mae) + ' min_avg_mae_epoch: {:d}'.format(min_avg_mae_epoch)) 154 | print('max_min_mae: {:.4f},'.format(max_min_mae) + ' max_min_mae_epoch: {:d}'.format(max_min_mae_epoch)) 155 | print('min_min_mae: {:.4f},'.format(min_min_mae) + ' min_min_mae_epoch: {:d}'.format(min_min_mae_epoch)) 156 | 157 | print('\nDone!') 158 | 159 | 
159 | if __name__ == "__main__":
160 | 
161 |     option_handler = CountingOptionHandler()
162 |     opt = option_handler.parse_options()  # all options are parsed through this command
163 |     option_handler.print_options(opt)  # print out all the options
164 |     main()
165 | 
--------------------------------------------------------------------------------
/experiments/reconstruction.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Reconstruction experiment
5 | '''
6 | 
7 | import multiprocessing
8 | multiprocessing.set_start_method('spawn', True)
9 | 
10 | import open3d as o3d
11 | import torch
12 | import sys
13 | import os
14 | import numpy as np
15 | from PIL import Image
16 | 
17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18 | sys.path.append(os.path.join(BASE_DIR, '..'))
19 | 
20 | from dataloaders import point_cloud_dataset_test
21 | from models.autoencoder import PointCloudAutoencoder
22 | from util.option_handler import TestOptionHandler
23 | from util.mesh_writer import write_ply_mesh
24 | 
25 | import warnings
26 | warnings.filterwarnings("ignore")
27 | 
28 | def main():
29 |     print(torch.cuda.device_count(), "GPUs will be used for testing.")
30 | 
31 |     # Load a saved model
32 |     if opt.checkpoint == '':
33 |         print("Please provide the model path.")
34 |         exit()
35 |     else:
36 |         checkpoint = torch.load(opt.checkpoint)
37 |         if 'opt' in checkpoint and opt.config_from_checkpoint == True:
38 |             checkpoint['opt'].grid_dims = opt.grid_dims
39 |             checkpoint['opt'].xyz_chamfer_weight = opt.xyz_chamfer_weight
40 |             ae = PointCloudAutoencoder(checkpoint['opt'])
41 |             print("\nModel configuration loaded from checkpoint %s." % opt.checkpoint)
42 |         else:
43 |             ae = PointCloudAutoencoder(opt)
44 |             checkpoint['opt'] = opt
45 |             torch.save(checkpoint, opt.checkpoint)
46 |             print("\nModel configuration written to checkpoint %s." % opt.checkpoint)
47 |     ae.load_state_dict(checkpoint['model_state_dict'])
48 |     print("Existing model %s loaded.\n" % (opt.checkpoint))
49 |     device = torch.device("cuda:0")
50 |     ae.to(device)
51 |     ae.eval()  # set the autoencoder to evaluation mode
52 | 
53 |     # Create a folder to write the results
54 |     if not os.path.exists(opt.exp_name):
55 |         os.makedirs(opt.exp_name)
56 | 
57 |     # Take care of the dataset
58 |     _, test_dataloader = point_cloud_dataset_test(opt.dataset_name, opt.num_points, opt.batch_size, opt.test_split)
59 |     point_cnt = opt.grid_dims[0] * opt.grid_dims[1]
60 | 
61 |     # Set the colors of all the points as 0.5 for visualization
62 |     coloring_rec = np.ones((opt.grid_dims[0] * opt.grid_dims[1], 3)) * 0.5
63 |     coloring = np.concatenate((coloring_rec, np.repeat(np.array([np.array(opt.gt_color)]), opt.num_points, axis=0)), axis=0)
64 | 
65 |     # Begin testing
66 |     test_loss_sum, test_loss_sqrt_sum = 0, 0
67 |     batch_id = 0
68 |     print('\nTesting...')
69 |     not_end_yet = True
70 |     it = iter(test_dataloader)
71 |     len_test = len(test_dataloader)
72 | 
73 |     # Iterate over the test set
74 |     while not_end_yet == True:
75 | 
76 |         # Fetch the data
77 |         points, _ = next(it)
78 |         points = points.cuda()
79 |         not_end_yet = batch_id + 1 < len_test
80 |         if(points.size(0) < opt.batch_size): break
81 | 
82 |         # The forward pass
83 |         with torch.no_grad():
84 |             rec = ae(points)
85 |         grid = rec['grid'] if 'grid' in rec else None
86 |         if 'graph_wght' in rec:
87 |             if opt.graph_delete_point_eps > 0:  # may use another eps value
88 |                 old_eps = ae.decoder.graph_filter.graph_eps_sqr
89 |                 ae.decoder.graph_filter.graph_eps_sqr = opt.graph_delete_point_eps ** 2
90 |                 with torch.no_grad():
91 |                     graph_wght = ae(points)['graph_wght']
92 |                 ae.decoder.graph_filter.graph_eps_sqr = old_eps
93 |             else:
94 |                 graph_wght = rec['graph_wght']
95 |             graph_wght = graph_wght.permute([0, 2, 3, 1]).detach().cpu().numpy()
96 |         else: graph_wght = None
97 |         rec = rec['rec']
98 | 
99 |         # Benchmarking the results
100 |         test_loss, test_loss_sqrt = 0, 0
101 |         if not(graph_wght is None) and (opt.graph_delete_point_mode >= 0) and opt.graph_delete_point_eps > 0:
102 |             idx_map = np.zeros((opt.batch_size, opt.grid_dims[0], opt.grid_dims[1]), dtype=int)
103 |             idx = np.zeros(opt.batch_size, dtype=int)
104 | 
105 |             # When graph is used, isolated points are removed
106 |             for i in range(opt.grid_dims[0]):
107 |                 for j in range(opt.grid_dims[1]):
108 |                     degree = np.sum(graph_wght[:, i, j, :] > opt.graph_thres, axis=1)
109 |                     tag = (degree <= opt.graph_delete_point_mode)  # tag the point to be deleted
110 |                     idx_map[:, i, j][tag == True] = -1
111 |                     idx_map[:, i, j][tag == False] = idx[tag == False]
112 |                     idx[tag == False] = idx[tag == False] + 1
113 |             for b in range(opt.batch_size):
114 |                 pc_cur = rec[b, idx_map[b].reshape(opt.grid_dims[0] * opt.grid_dims[1]) >= 0].unsqueeze(0)
115 |                 test_loss += ae.xyz_loss(points[b].unsqueeze(0), pc_cur, xyz_loss_type=2)
116 |                 test_loss_sqrt += ae.xyz_loss(points[b].unsqueeze(0), pc_cur, xyz_loss_type=0)
117 |             test_loss /= opt.batch_size
118 |             test_loss_sqrt /= opt.batch_size
119 |             test_loss_sum = test_loss_sum + test_loss.item()
120 |             test_loss_sqrt_sum = test_loss_sqrt_sum + test_loss_sqrt.item()
121 |         else:
122 |             test_loss = ae.xyz_loss(points, rec, xyz_loss_type=2)
123 |             test_loss_sqrt = ae.xyz_loss(points, rec, xyz_loss_type=0)
124 |             test_loss_sum += test_loss.item()
125 |             test_loss_sqrt_sum += test_loss_sqrt.item()
126 | 
127 |         if batch_id % opt.print_freq == 0:
128 |             print(' batch_no: %d/%d, ch_dist: %f, ch^2_dist: %f' %
129 |                   (batch_id, len_test, test_loss_sqrt.item(), test_loss.item()))
130 | 
131 |         # Write out the first point cloud
132 |         if opt.pc_write_freq > 0 and batch_id % opt.pc_write_freq == 0:
133 |             rec_o3d = o3d.geometry.PointCloud()
134 |             rec = rec[0].data.cpu()
135 |             rec_o3d.colors = o3d.utility.Vector3dVector(coloring[:rec.shape[0], :])
136 |             rec_o3d.points = o3d.utility.Vector3dVector(rec)
137 |             file_rec = os.path.join(opt.exp_name, str(batch_id) + "_rec.ply")
138 |             o3d.io.write_point_cloud(file_rec, rec_o3d)
139 | 
140 |             gt_o3d = o3d.geometry.PointCloud()  # write the ground-truth
141 |             gt_o3d.points = o3d.utility.Vector3dVector(points[0].data.cpu())
142 |             gt_o3d.colors = o3d.utility.Vector3dVector(coloring[rec.shape[0]:, :])
143 |             file_gt = os.path.join(opt.exp_name, str(batch_id) + "_gt.ply")
144 |             o3d.io.write_point_cloud(file_gt, gt_o3d)
145 | 
146 |             if not(grid is None):  # write the torn grid
147 |                 grid = torch.cat((grid[0].contiguous().data.cpu(), torch.zeros(point_cnt, 1)), 1)
148 |                 grid_o3d = o3d.geometry.PointCloud()
149 |                 grid_o3d.points = o3d.utility.Vector3dVector(grid)
150 |                 grid_o3d.colors = o3d.utility.Vector3dVector(coloring_rec)
151 |                 file_grid = os.path.join(opt.exp_name, str(batch_id) + "_grid.ply")
152 |                 o3d.io.write_point_cloud(file_grid, grid_o3d)
153 | 
154 |             if not(graph_wght is None) and opt.write_mesh:
155 |                 output_path = os.path.join(opt.exp_name, str(batch_id) + "_rec_mesh.ply")  # write the reconstructed mesh
156 |                 write_ply_mesh(rec.view(opt.grid_dims[0], opt.grid_dims[1], 3).numpy(),
157 |                     coloring_rec.reshape(opt.grid_dims[0], opt.grid_dims[1], 3), opt.graph_edge_color,
158 |                     output_path, thres=opt.graph_thres, delete_point_mode=opt.graph_delete_point_mode,
159 |                     weights=graph_wght[0], point_color_as_index=opt.point_color_as_index)
160 | 
161 |                 output_path = os.path.join(opt.exp_name, str(batch_id) + "_grid_mesh.ply")  # write the grid mesh
162 |                 write_ply_mesh(grid.view(opt.grid_dims[0], opt.grid_dims[1], 3).numpy(),
163 |                     coloring_rec.reshape(opt.grid_dims[0], opt.grid_dims[1], 3), opt.graph_edge_color,
164 |                     output_path, thres=opt.graph_thres, delete_point_mode=-1,
165 |                     weights=graph_wght[0], point_color_as_index=opt.point_color_as_index)
166 | 
167 |         batch_id = batch_id + 1
168 | 
169 |     # Log the test results
170 |     avg_loss = test_loss_sum / batch_id  # average over the batches actually processed
171 |     avg_loss_sqrt = test_loss_sqrt_sum / batch_id
172 |     print('avg_ch_dist: %f avg_ch^2_dist: %f' % (avg_loss_sqrt, avg_loss))
173 | 
174 |     print('\nDone!\n')
175 | 
176 | if __name__ == "__main__":
177 | 
178 |     option_handler = TestOptionHandler()
179 |     opt = option_handler.parse_options()  # all options are parsed through this command
180 |     option_handler.print_options(opt)  # print out all the options
181 |     main()
182 | 
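# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original script): 'ch_dist' above is the
# augmented Chamfer distance (xyz_loss_type=0, the max of the two directional
# means of point-to-nearest-neighbor distances) and 'ch^2_dist' is the common
# squared-distance Chamfer variant (xyz_loss_type=2). A toy re-computation
# with plain torch ops, ignoring the xyz_chamfer_weight scaling; the real
# script gets the nearest-neighbor distances from the CUDA nndistance module:
def _chamfer_sketch():
    a = torch.tensor([[[0., 0., 0.], [1., 0., 0.]]])  # 1 cloud of 2 points
    b = torch.tensor([[[0., 0., 0.], [0., 1., 0.]]])
    d = torch.cdist(a, b) ** 2                  # squared pairwise distances
    dist1 = d.min(dim=2).values                 # each point of a -> nearest in b
    dist2 = d.min(dim=1).values                 # each point of b -> nearest in a
    ch = torch.max(dist1.sqrt().mean(1), dist2.sqrt().mean(1)).mean()  # type 0
    ch_sq = dist1.mean() + dist2.mean()                                # type 2
    return ch.item(), ch_sq.item()              # -> (0.5, 1.0) for this toy pair
# ---------------------------------------------------------------------------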
--------------------------------------------------------------------------------
/experiments/train_basic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Basic training script
5 | '''
6 | 
7 | import multiprocessing
8 | multiprocessing.set_start_method('spawn', True)
9 | 
10 | import torch
11 | import torch.nn.parallel
12 | import torch.optim as optim
13 | import sys
14 | import os
15 | import numpy as np
16 | 
17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18 | sys.path.append(os.path.join(BASE_DIR, '..'))
19 | 
20 | from dataloaders import point_cloud_dataset_train, gen_rotation_matrix
21 | from models.autoencoder import PointCloudAutoencoder
22 | from util.option_handler import TrainOptionHandler
23 | from tensorboardX import SummaryWriter
24 | import time
25 | 
26 | import warnings
27 | warnings.filterwarnings("ignore")
28 | 
29 | def main():
30 |     # Build up an autoencoder
31 |     t = time.time()
32 |     print(torch.cuda.device_count(), "GPUs will be used for training.")
33 |     device = torch.device("cuda:0")
34 |     ae = PointCloudAutoencoder(opt)
35 |     ae = torch.nn.DataParallel(ae)
36 |     ae.to(device)
37 | 
38 |     # Create folder to save trained models
39 |     if not os.path.exists(opt.exp_name):
40 |         os.makedirs(opt.exp_name)
41 | 
42 |     # Take care of the dataset
43 |     _, train_dataloader, _, _ = point_cloud_dataset_train(opt.dataset_name, opt.num_points, opt.batch_size, opt.train_split)
44 | 
45 |     # Create a tensorboard writer
46 |     if opt.tf_summary: writer = SummaryWriter(comment=str.replace(opt.exp_name, '/', '_'))
47 | 
48 |     # Setup the optimizer and the scheduler
49 |     lr = opt.lr
50 |     optimizer = optim.Adam(ae.parameters(), lr=lr[0], betas=(opt.optim_args[0], opt.optim_args[1]), weight_decay=opt.optim_args[2])
51 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr[1], gamma=lr[2], last_epoch=-1)
52 | 
53 |     # Load a checkpoint if given
54 |     if opt.checkpoint != '':
55 |         checkpoint = torch.load(opt.checkpoint)
56 |         ae.module.load_state_dict(checkpoint['model_state_dict'])
57 |         if opt.load_weight_only == False:
58 |             optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
59 |             scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
60 |         print("Existing model %s loaded.\n" % (opt.checkpoint))
61 |     epoch_start = scheduler.state_dict()['last_epoch']
62 | 
63 |     # Start training
64 |     n_iter = 0
65 |     for epoch in range(epoch_start, opt.n_epoch):
66 |         ae.train()
67 |         print('Training at epoch %d with lr %f' % (epoch, scheduler.get_lr()[0]))
68 | 
69 |         train_loss_sum = 0
70 |         batch_id = 0
71 |         not_end_yet = True
72 |         it = iter(train_dataloader)
73 |         len_train = len(train_dataloader)
74 | 
75 |         # Iterate over the training set
76 |         while not_end_yet == True:
77 |             points, _ = next(it)
78 |             not_end_yet = batch_id + 1 < len_train
79 |             if(points.size(0) < opt.batch_size):
80 |                 break
81 | 
82 |             # Data augmentation
83 |             points = points.cuda()
84 |             if opt.augmentation == True:
85 |                 if opt.augmentation_theta != 0 or opt.augmentation_rotation_axis != 0:  # rotation augmentation
86 |                     rotation_matrix = gen_rotation_matrix(theta=opt.augmentation_theta,
87 |                         rotation_axis=opt.augmentation_rotation_axis).unsqueeze(0).cuda().expand(opt.batch_size, -1, -1)
88 |                     points = torch.bmm(points, rotation_matrix)
89 |                 if opt.augmentation_max_scale - opt.augmentation_min_scale > 1e-3:  # scaling augmentation
90 |                     noise_scale = np.random.uniform(opt.augmentation_min_scale, opt.augmentation_max_scale)
91 |                     points *= noise_scale
92 |                 if (opt.augmentation_flip_axis >= 0) and (np.random.rand() < 0.5):  # flipping augmentation
93 |                     points[:, :, opt.augmentation_flip_axis] *= -1
94 |             optimizer.zero_grad()
95 | 
96 |             # Forward and backward computation
97 |             rec = ae(points)
98 |             rec = rec["rec"]
99 |             train_loss = ae.module.xyz_loss(points, rec)
100 |             train_loss.backward()
101 |             optimizer.step()
102 | 
103 |             # Log the result
104 |             train_loss_sum += train_loss.item()
105 |             if batch_id % opt.print_freq == 0:
106 |                 elapse = time.time() - t
107 |                 print(' batch_no: %d/%d, time/iter: %f/%d, train_loss: %f' %
108 |                       (batch_id, len_train, elapse, n_iter, train_loss.item()))
109 |                 if opt.tf_summary: writer.add_scalar('train/batch_train_loss', train_loss.item(), n_iter)
110 | 
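            # Editor's note: the block below periodically dumps the first cloud in
            # the batch to TensorBoard as a labeled embedding; label 1 marks the
            # reconstructed points and label 0 the corresponding input points.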
111 |             if opt.pc_write_freq > 0 and batch_id % opt.pc_write_freq == 0:
112 |                 labels = np.concatenate((np.ones(rec.shape[1]), np.zeros(points.shape[1])), axis=0).tolist()
113 |                 if opt.tf_summary:
114 |                     writer.add_embedding(torch.cat((rec[0, 0:, 0:3], points[0, :, 0:3]), 0),
115 |                         global_step=n_iter, metadata=labels, tag="pc")
116 | 
117 |             n_iter = n_iter + 1
118 |             batch_id = batch_id + 1
119 | 
120 |         # Output the average loss of current epoch
121 |         avg_loss = train_loss_sum / batch_id  # average over the batches actually processed
122 |         elapse = time.time() - t
123 |         print('Epoch: %d time: %f --- avg_loss: %f lr: %f' % (epoch, elapse, avg_loss, scheduler.get_lr()[0]))
124 |         if opt.tf_summary: writer.add_scalar('train/epoch_train_loss', avg_loss, epoch)
125 |         if opt.tf_summary: writer.add_scalar('train/learning_rate', scheduler.get_lr()[0], epoch)
126 |         scheduler.step()  # advance the scheduler by one epoch
127 | 
128 |         # Save the checkpoint
129 |         if epoch % opt.save_epoch_freq == 0 or epoch == opt.n_epoch - 1:
130 |             dict_name = opt.exp_name + '/epoch_' + str(epoch) + '.pth'
131 |             torch.save({
132 |                 'model_state_dict': ae.module.state_dict(),
133 |                 'optimizer_state_dict': optimizer.state_dict(),
134 |                 'scheduler_state_dict': scheduler.state_dict(),
135 |                 'last_epoch': scheduler.state_dict()['last_epoch'],
136 |                 'opt': opt
137 |             }, dict_name)
138 |             print('Current checkpoint saved to %s.' % (dict_name))
139 |         print('\n')
140 | 
141 |     if opt.tf_summary: writer.close()
142 |     print('Done!')
143 | 
144 | 
145 | if __name__ == "__main__":
146 | 
147 |     option_handler = TrainOptionHandler()
148 |     opt = option_handler.parse_options()  # all options are parsed through this command
149 |     option_handler.print_options(opt)  # print out all the options
150 |     main()
151 | 
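# ---------------------------------------------------------------------------
# Editor's note (not part of the original script): --lr is a triple
# [initial_lr, step_size, gamma]; e.g. LR="0.0002 80 0.5" in the shell scripts
# halves the learning rate every 80 epochs. A minimal sketch of the wiring,
# assuming `net` is any nn.Module:
def _lr_schedule_sketch(net, lr=(0.0002, 80, 0.5)):
    optimizer = optim.Adam(net.parameters(), lr=lr[0])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr[1], gamma=lr[2])
    return optimizer, scheduler
# ---------------------------------------------------------------------------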
--------------------------------------------------------------------------------
/experiments/train_tearing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Training script for TearingNet
5 | '''
6 | 
7 | import multiprocessing
8 | multiprocessing.set_start_method('spawn', True)
9 | 
10 | import torch
11 | import torch.nn.parallel
12 | import torch.optim as optim
13 | import sys
14 | import os
15 | import numpy as np
16 | 
17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18 | sys.path.append(os.path.join(BASE_DIR, '..'))
19 | 
20 | from dataloaders import point_cloud_dataset_train, gen_rotation_matrix
21 | from models.autoencoder import PointCloudAutoencoder
22 | from models.tearingnet import TearingNetBasicModel
23 | from models.tearingnet_graph import TearingNetGraphModel
24 | from util.option_handler import TrainOptionHandler
25 | from tensorboardX import SummaryWriter
26 | import time
27 | 
28 | import warnings
29 | warnings.filterwarnings("ignore")
30 | 
31 | def main():
32 |     # Build up an autoencoder
33 |     t = time.time()
34 |     print(torch.cuda.device_count(), "GPUs will be used for training.")
35 |     device = torch.device("cuda:0")
36 |     ae = PointCloudAutoencoder(opt)
37 |     ae = torch.nn.DataParallel(ae)
38 |     ae.to(device)
39 |     tearingnet_basic = isinstance(ae.module.decoder, TearingNetBasicModel)
40 |     tearingnet_graph = isinstance(ae.module.decoder, TearingNetGraphModel)
41 | 
42 |     # Create folder to save trained models
43 |     if not os.path.exists(opt.exp_name):
44 |         os.makedirs(opt.exp_name)
45 | 
46 |     # Take care of the dataset
47 |     _, train_dataloader, _, _ = point_cloud_dataset_train(dataset_name=opt.dataset_name, \
48 |         num_points=opt.num_points, batch_size=opt.batch_size, train_split=opt.train_split)
49 | 
50 |     # Create a tensorboard writer
51 |     if opt.tf_summary: writer = SummaryWriter(comment=str.replace(opt.exp_name, '/', '_'))
52 | 
53 |     # Setup the optimizer and the scheduler
54 |     lr = opt.lr
55 |     optimizer = optim.Adam(ae.parameters(), lr=lr[0])
56 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr[1], gamma=lr[2], last_epoch=-1)
57 | 
58 |     # Load a checkpoint if given
59 |     if opt.checkpoint != '':
60 |         checkpoint = torch.load(opt.checkpoint)
61 |         if opt.load_weight_only == False:
62 |             ae.module.load_state_dict(checkpoint['model_state_dict'])
63 |             optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
64 |             scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
65 |             print("Existing model %s fully loaded.\n" % (opt.checkpoint))
66 |         else:
67 |             model_dict = ae.module.state_dict()  # load parameters from pre-trained FoldingNet
68 |             for k in checkpoint['model_state_dict']:
69 |                 if k in model_dict:
70 |                     model_dict[k] = checkpoint['model_state_dict'][k]
71 |                     print(" Found weight: " + k)
72 |                 elif k.replace('folding1', 'folding') in model_dict:
73 |                     model_dict[k.replace('folding1', 'folding')] = checkpoint['model_state_dict'][k]
74 |                     print(" Found weight: " + k)
75 |             ae.module.load_state_dict(model_dict)
76 |             print("Existing model %s (partly) loaded.\n" % (opt.checkpoint))
77 |     epoch_start = scheduler.state_dict()['last_epoch']
78 | 
79 |     # Start training
80 |     n_iter = 0
81 |     for epoch in range(epoch_start, opt.n_epoch):
82 |         ae.train()
83 |         print('Training at epoch %d with lr %f' % (epoch, scheduler.get_lr()[0]))
84 | 
85 |         train_loss_xyz_sum = 0
86 |         if tearingnet_basic:
87 |             train_loss_xyz_pre_sum = 0
88 |         if tearingnet_graph:
89 |             train_loss_xyz_pre_sum = 0
90 |             train_loss_xyz_pre2_sum = 0
91 |         train_loss_sum = 0
92 |         batch_id = 0
93 |         not_end_yet = True
94 |         it = iter(train_dataloader)
95 |         len_train = len(train_dataloader)
96 | 
97 |         # Iterate over the training set
98 |         while not_end_yet == True:
99 |             points, _ = next(it)
100 |             not_end_yet = batch_id + 1 < len_train
101 |             if(points.size(0) < opt.batch_size):
102 |                 break
103 | 
104 |             # Data augmentation
105 |             points = points.cuda()
106 |             if opt.augmentation == True:
107 |                 if opt.augmentation_theta != 0 or opt.augmentation_rotation_axis != 0:  # rotation augmentation
108 |                     rotation_matrix = gen_rotation_matrix(theta=opt.augmentation_theta,
109 |                         rotation_axis=opt.augmentation_rotation_axis).unsqueeze(0).cuda().expand(opt.batch_size, -1, -1)
110 |                     points = torch.bmm(points, rotation_matrix)
111 |                 if opt.augmentation_max_scale - opt.augmentation_min_scale > 1e-3:  # scaling augmentation
112 |                     noise_scale = np.random.uniform(opt.augmentation_min_scale, opt.augmentation_max_scale)
113 |                     points *= noise_scale
114 |                 if (opt.augmentation_flip_axis >= 0) and (np.random.rand() < 0.5):  # flipping augmentation
115 |                     points[:, :, opt.augmentation_flip_axis] *= -1
116 |             optimizer.zero_grad()
117 | 
118 |             # Forward pass
119 |             rec = ae(points)
120 |             grid = rec['grid']
121 |             if tearingnet_basic:
122 |                 rec_pre = rec['rec_pre']
123 |             if tearingnet_graph:
124 |                 rec_pre = rec['rec_pre']
125 |                 rec_pre2 = rec['rec_pre2']
126 |             rec = rec['rec']
127 | 
128 |             # Compute the training losses
129 |             train_loss_xyz = ae.module.xyz_loss(points, rec)
130 |             if tearingnet_basic:
131 |                 train_loss_xyz_pre = ae.module.xyz_loss(points, rec_pre)
132 |             if tearingnet_graph:
133 |                 train_loss_xyz_pre = ae.module.xyz_loss(points, rec_pre)
134 |                 train_loss_xyz_pre2 = ae.module.xyz_loss(points, rec_pre2)
135 |             train_loss = train_loss_xyz  # may add other loss if needed
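            # Editor's note: only the final reconstruction loss is backpropagated;
            # the intermediate losses (before tearing / before graph filtering) are
            # computed purely for the logging further below.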
136 |             train_loss.backward()
137 |             optimizer.step()
138 | 
139 |             # Log the result
140 |             train_loss_xyz_sum += train_loss_xyz.item()
141 |             if tearingnet_basic: train_loss_xyz_pre_sum += train_loss_xyz_pre.item()
142 |             if tearingnet_graph:
143 |                 train_loss_xyz_pre_sum += train_loss_xyz_pre.item()
144 |                 train_loss_xyz_pre2_sum += train_loss_xyz_pre2.item()
145 |             train_loss_sum += train_loss.item()
146 |             if batch_id % opt.print_freq == 0:
147 |                 elapse = time.time() - t
148 |                 if tearingnet_basic:
149 |                     print(' batch_no: %d/%d, time/iter: %f/%d, loss_xyz: %f, loss_xyz_pre: %f, loss: %f' %
150 |                           (batch_id, len_train, elapse, n_iter, train_loss_xyz.item(), train_loss_xyz_pre.item(), train_loss.item()))
151 |                 elif tearingnet_graph:
152 |                     print(' batch_no: %d/%d, time/iter: %f/%d, loss_xyz: %f, loss_xyz_pre: %f, loss_xyz_pre2: %f, loss: %f' %
153 |                           (batch_id, len_train, elapse, n_iter, train_loss_xyz.item(), train_loss_xyz_pre.item(), train_loss_xyz_pre2.item(), train_loss.item()))
154 |                 else:
155 |                     print(' batch_no: %d/%d, time/iter: %f/%d, loss_xyz: %f, loss: %f' %
156 |                           (batch_id, len_train, elapse, n_iter, train_loss_xyz.item(), train_loss.item()))
157 |                 if opt.tf_summary: writer.add_scalar('train/batch_train_loss', train_loss.item(), n_iter)
158 | 
159 |             if opt.pc_write_freq > 0 and batch_id % opt.pc_write_freq == 0:
160 |                 labels = np.concatenate((np.ones(rec.shape[1]), np.zeros(points.shape[1])), axis=0).tolist()
161 |                 if opt.tf_summary:
162 |                     writer.add_embedding(torch.cat((rec[0, 0:, 0:3], points[0, :, 0:3]), 0),
163 |                         global_step=n_iter, metadata=labels, tag="pc")  # the output point cloud
164 |                     writer.add_embedding(grid[0, 0:, 0:2], global_step=n_iter, tag="grid")  # the 2D grid being used
165 |             n_iter = n_iter + 1
166 |             batch_id = batch_id + 1
167 | 
168 |         # Output the average loss of current epoch
169 |         avg_loss_xyz = train_loss_xyz_sum / batch_id  # average over the batches actually processed
170 |         if tearingnet_basic:
171 |             avg_loss_xyz_pre = train_loss_xyz_pre_sum / batch_id
172 |         if tearingnet_graph:
173 |             avg_loss_xyz_pre = train_loss_xyz_pre_sum / batch_id
174 |             avg_loss_xyz_pre2 = train_loss_xyz_pre2_sum / batch_id
175 |         avg_loss = train_loss_sum / batch_id
176 | 
177 |         elapse = time.time() - t
178 |         if tearingnet_basic:
179 |             print('Epoch: %d time: %f --- avg_loss_xyz: %f, avg_loss_xyz_pre: %f, avg_loss: %f lr: %f' % \
180 |                 (epoch, elapse, avg_loss_xyz, avg_loss_xyz_pre, avg_loss, scheduler.get_lr()[0]))
181 |         elif tearingnet_graph:
182 |             print('Epoch: %d time: %f --- avg_loss_xyz: %f, avg_loss_xyz_pre: %f, avg_loss_xyz_pre2: %f, avg_loss: %f lr: %f' % \
183 |                 (epoch, elapse, avg_loss_xyz, avg_loss_xyz_pre, avg_loss_xyz_pre2, avg_loss, scheduler.get_lr()[0]))
184 |         else:
185 |             print('Epoch: %d time: %f --- avg_loss_xyz: %f, avg_loss: %f lr: %f' % \
186 |                 (epoch, elapse, avg_loss_xyz, avg_loss, scheduler.get_lr()[0]))
187 |         if opt.tf_summary: writer.add_scalar('train/epoch_train_loss', avg_loss, epoch)
188 |         if opt.tf_summary: writer.add_scalar('train/learning_rate', scheduler.get_lr()[0], epoch)
189 |         scheduler.step()
190 | 
191 |         # Save the checkpoint
192 |         if epoch % opt.save_epoch_freq == 0 or epoch == opt.n_epoch - 1:
193 |             dict_name = opt.exp_name + '/epoch_' + str(epoch) + '.pth'
194 |             torch.save({
195 |                 'model_state_dict': ae.module.state_dict(),
196 |                 'optimizer_state_dict': optimizer.state_dict(),
197 |                 'scheduler_state_dict': scheduler.state_dict(),
198 |                 'last_epoch': scheduler.state_dict()['last_epoch'],
199 |                 'opt': opt
200 |             }, dict_name)
201 |             print('Current checkpoint saved to %s.' % (dict_name))
202 |         print('\n')
203 | 
204 |     if opt.tf_summary: writer.close()
205 |     print('Done!')
206 | 
207 | 
208 | if __name__ == "__main__":
209 | 
210 |     option_handler = TrainOptionHandler()
211 |     opt = option_handler.parse_options()  # all options are parsed through this command
212 |     option_handler.print_options(opt)  # print out all the options
213 |     main()
214 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Initialization and simple utilities
5 | '''
6 | 
7 | import importlib
8 | 
9 | def get_model_class(model_name):
10 | 
11 |     # Enumerate the model files being used here
12 |     model_file_list = ["pointnet", "autoencoder", "foldingnet", "tearingnet", "tearingnet_graph"]
13 | 
14 |     model_class = None
15 |     for filename in model_file_list:
16 | 
17 |         # Retrieve the model class
18 |         modellib = importlib.import_module("models." + filename)
19 |         class_name = model_name + "Model"
20 |         for name, cls in modellib.__dict__.items():
21 |             if name.lower() == class_name.lower():
22 |                 model_class = cls
23 |                 break
24 |         if model_class is not None:
25 |             break
26 | 
27 |     if model_class is None:
28 |         print("The specified model [{}] was not found.".format(model_name))
29 |         exit(0)
30 | 
31 |     return model_class
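# Editor's illustration (not part of the original file): the autoencoder
# resolves its encoder/decoder classes by name, matching "<name>Model"
# case-insensitively against the classes defined in the modules listed above:
#
#     encoder_class = get_model_class("pointnetvanilla")  # -> PointNetVanillaModel
#     decoder_class = get_model_class("tearingnetgraph")  # -> TearingNetGraphModel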
--------------------------------------------------------------------------------
/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Point cloud auto-encoder backbone
5 | '''
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.utils.data
11 | import numpy as np
12 | import sys
13 | import os
14 | 
15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16 | sys.path.append(os.path.join(BASE_DIR, '../util/nndistance'))
17 | from modules.nnd import NNDModule
18 | nn_match = NNDModule()
19 | USE_CUDA = True
20 | 
21 | from . import get_model_class
22 | from .tearingnet import TearingNetBasicModel
23 | from .tearingnet_graph import TearingNetGraphModel
24 | 
25 | 
26 | class PointCloudAutoencoder(nn.Module):
27 | 
28 |     def __init__(self, opt):
29 |         super(PointCloudAutoencoder, self).__init__()
30 | 
31 |         encoder_class = get_model_class(opt.encoder)
32 |         self.encoder = encoder_class(opt)
33 |         decoder_class = get_model_class(opt.decoder)
34 |         self.decoder = decoder_class(opt)
35 |         self.is_train = (opt.phase.lower() == 'train')
36 |         if self.is_train:
37 |             self.xyz_loss_type = opt.xyz_loss_type
38 |             self.xyz_chamfer_weight = opt.xyz_chamfer_weight
39 | 
40 |     def forward(self, data):
41 | 
42 |         cw = self.encoder(data)
43 |         if isinstance(self.decoder, TearingNetBasicModel):
44 |             rec0, rec1, grid = self.decoder(cw)
45 |             return {"rec": rec1, "rec_pre": rec0, "grid": grid, "cw": cw}
46 |         elif isinstance(self.decoder, TearingNetGraphModel):
47 |             rec0, rec1, rec2, grid, graph_wght = self.decoder(cw)
48 |             return {"rec": rec2, "rec_pre": rec1, "rec_pre2": rec0, "grid": grid, "graph_wght": graph_wght, "cw": cw}
49 |         else:
50 |             rec = self.decoder(cw)
51 |             return {"rec": rec, "cw": cw}
52 | 
53 |     def xyz_loss(self, data, rec, xyz_loss_type=-1):
54 | 
55 |         if xyz_loss_type == -1:
56 |             xyz_loss_type = self.xyz_loss_type
57 |         dist1, dist2 = nn_match(data.contiguous(), rec.contiguous())
58 |         dist2 = dist2 * self.xyz_chamfer_weight
59 | 
60 |         # Different variants of the Chamfer distance
61 |         if xyz_loss_type == 0:  # augmented Chamfer distance
62 |             loss = torch.max(torch.mean(torch.sqrt(dist1), 1), torch.mean(torch.sqrt(dist2), 1))
63 |             loss = torch.mean(loss)
64 |         elif xyz_loss_type == 1:
65 |             loss = torch.mean(torch.sqrt(dist1), 1) + torch.mean(torch.sqrt(dist2), 1)
66 |             loss = torch.mean(loss)
67 |         elif xyz_loss_type == 2:  # used in other papers
68 |             loss = torch.mean(dist1) + torch.mean(dist2)
69 |         return loss
70 | 
71 | 
72 | if __name__ == '__main__':
73 |     USE_CUDA = True
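# Editor's note: xyz_chamfer_weight rescales only dist2. Assuming nn_match
# follows the usual nndistance convention (dist1: each input point to its
# nearest reconstructed point, dist2: each reconstructed point to its nearest
# input point), a small weight such as the 0.01 used by the training scripts
# penalizes stray reconstructed points far less than uncovered input points,
# i.e. the loss mainly rewards coverage of the input cloud.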
--------------------------------------------------------------------------------
/models/foldingnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Basic FoldingNet model
5 | '''
6 | 
7 | import torch
8 | import torch.nn as nn
9 | from .pointnet import PointwiseMLP
10 | 
11 | 
12 | class FoldingNetVanilla(nn.Module):
13 | 
14 |     def __init__(self, folding1_dims, folding2_dims):
15 |         super(FoldingNetVanilla, self).__init__()
16 | 
17 |         # The folding decoder
18 |         self.fold1 = PointwiseMLP(folding1_dims, doLastRelu=False)
19 |         if folding2_dims[0] > 0:
20 |             self.fold2 = PointwiseMLP(folding2_dims, doLastRelu=False)
21 |         else: self.fold2 = None
22 | 
23 |     def forward(self, cw, grid, **kwargs):
24 | 
25 |         cw_exp = cw.unsqueeze(1).expand(-1, grid.shape[1], -1)  # batch_size X point_num X code_length
26 | 
27 |         # 1st folding
28 |         in1 = torch.cat((grid, cw_exp), 2)  # batch_size X point_num X (code_length + 2)
29 |         out1 = self.fold1(in1)  # batch_size X point_num X 3
30 | 
31 |         # 2nd folding
32 |         if not(self.fold2 is None):
33 |             in2 = torch.cat((out1, cw_exp), 2)  # batch_size X point_num X (code_length + 3)
34 |             out2 = self.fold2(in2)  # batch_size X point_num X 3
35 |             return out2
36 |         else: return out1
37 | 
38 | 
39 | class FoldingNetVanillaModel(nn.Module):
40 | 
41 |     @staticmethod
42 |     def add_options(parser, isTrain = True):
43 | 
44 |         # Optional arguments
45 |         parser.add_argument('--grid_dims', type=int, nargs='+', help='Grid dimensions.')
46 |         parser.add_argument('--folding1_dims', type=int, nargs='+', default=[514, 512, 512, 3], help='Dimensions of the first folding module.')
47 |         parser.add_argument('--folding2_dims', type=int, nargs='+', default=[515, 512, 512, 3], help='Dimensions of the second folding module.')
48 |         return parser
49 | 
50 |     def __init__(self, opt):
51 |         super(FoldingNetVanillaModel, self).__init__()
52 | 
53 |         # Initialize the 2D grid
54 |         range_x = torch.linspace(-1.0, 1.0, opt.grid_dims[0])
55 |         range_y = torch.linspace(-1.0, 1.0, opt.grid_dims[1])
56 |         x_coor, y_coor = torch.meshgrid(range_x, range_y)
57 |         self.grid = torch.stack([x_coor, y_coor], axis=-1).float().reshape(-1, 2)
58 | 
59 |         # Initialize the folding module
60 |         self.folding1 = FoldingNetVanilla(opt.folding1_dims, opt.folding2_dims)
61 | 
62 |     def forward(self, cw):
63 | 
64 |         grid = self.grid.cuda().unsqueeze(0).expand(cw.shape[0], -1, -1)  # batch_size X point_num X 2
65 |         pc = self.folding1(cw, grid)
66 |         return pc
--------------------------------------------------------------------------------
/models/pointnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Basic PointNet model
5 | '''
6 | 
7 | import torch.nn as nn
8 | from util.option_handler import str2bool
9 | 
10 | 
11 | def get_and_init_FC_layer(din, dout):
12 |     li = nn.Linear(din, dout)
13 |     # init weights/bias
14 |     nn.init.xavier_uniform_(li.weight.data, gain=nn.init.calculate_gain('relu'))
15 |     li.bias.data.fill_(0.)
16 |     return li
17 | 
18 | 
19 | def get_MLP_layers(dims, doLastRelu):
20 |     layers = []
21 |     for i in range(1, len(dims)):
22 |         layers.append(get_and_init_FC_layer(dims[i-1], dims[i]))
23 |         if i == len(dims)-1 and not doLastRelu:
24 |             continue
25 |         layers.append(nn.ReLU())
26 |     return layers
27 | 
28 | 
29 | class PointwiseMLP(nn.Sequential):
30 |     '''N x din -> N x d1 -> N x d2 -> ... -> N x dout'''
31 | 
32 |     def __init__(self, dims, doLastRelu=False):
33 |         layers = get_MLP_layers(dims, doLastRelu)
34 |         super(PointwiseMLP, self).__init__(*layers)
35 | 
36 | 
37 | class GlobalPool(nn.Module):
38 |     '''BxNxK -> BxK'''
39 | 
40 |     def __init__(self, pool_layer):
41 |         super(GlobalPool, self).__init__()
42 |         self.Pool = pool_layer
43 | 
44 |     def forward(self, X):
45 |         X = X.unsqueeze(-3)
46 |         X = self.Pool(X)
47 |         X = X.squeeze(-2)
48 |         X = X.squeeze(-2)
49 |         return X
50 | 
51 | class PointNetGlobalMax(nn.Sequential):
52 |     '''BxNxdims[0] -> Bxdims[-1]'''
53 | 
54 |     def __init__(self, dims, doLastRelu=False):
55 |         layers = [
56 |             PointwiseMLP(dims, doLastRelu=doLastRelu),
57 |             GlobalPool(nn.AdaptiveMaxPool2d((1, dims[-1]))),
58 |         ]
59 |         super(PointNetGlobalMax, self).__init__(*layers)
60 | 
61 | 
62 | class PointNetVanilla(nn.Sequential):
63 | 
64 |     def __init__(self, MLP_dims, FC_dims, MLP_doLastRelu):
65 |         assert(MLP_dims[-1] == FC_dims[0])
66 |         layers = [
67 |             PointNetGlobalMax(MLP_dims, doLastRelu=MLP_doLastRelu),
68 |         ]
69 |         layers.extend(get_MLP_layers(FC_dims, False))
70 |         super(PointNetVanilla, self).__init__(*layers)
71 | 
72 | 
73 | class PointNetVanillaModel(nn.Module):
74 | 
75 |     @staticmethod
76 |     def add_options(parser, isTrain = True):
77 |         parser.add_argument('--pointnet_mlp_dims', type=int, nargs='+', default=[3, 64, 128, 128, 1024], help='Dimensions of the MLP in the PointNet encoder.')
78 |         parser.add_argument('--pointnet_fc_dims', type=int, nargs='+', default=[1024, 512, 512, 512], help='Dimensions of the FC in the PointNet encoder.')
79 |         parser.add_argument("--pointnet_mlp_dolastrelu", type=str2bool, nargs='?', const=True, default=False, help='Apply the last ReLU or not in the PointNet encoder.')
80 |         return parser
81 | 
82 |     def __init__(self, opt):
83 |         super(PointNetVanillaModel, self).__init__()
84 |         self.pointnet = PointNetVanilla(opt.pointnet_mlp_dims, opt.pointnet_fc_dims, opt.pointnet_mlp_dolastrelu)
85 | 
86 |     def forward(self, data):
87 |         return self.pointnet(data)
88 | 
--------------------------------------------------------------------------------
/models/tearingnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Basic TearingNet model
5 | '''
6 | 
7 | import torch
8 | import torch.nn as nn
9 | from .foldingnet import FoldingNetVanilla
10 | 
11 | 
12 | def get_Conv2d_layer(dims, kernel_size, doLastRelu):
13 |     layers = []
14 |     for i in range(1, len(dims)):
15 |         if kernel_size != 1:
16 |             layers.append(nn.ReplicationPad2d(int((kernel_size - 1) / 2)))
17 |         layers.append(nn.Conv2d(in_channels=dims[i-1], out_channels=dims[i],
18 |             kernel_size=kernel_size, stride=1, padding=0, bias=True))
19 |         if i == len(dims)-1 and not doLastRelu:
20 |             continue
21 |         layers.append(nn.ReLU(inplace=True))
22 |     return layers
23 | 
24 | 
25 | class Conv2dLayers(nn.Sequential):
26 |     def __init__(self, dims, kernel_size, doLastRelu=False):
27 |         layers = get_Conv2d_layer(dims, kernel_size, doLastRelu)
28 |         super(Conv2dLayers, self).__init__(*layers)
29 | 
30 | 
31 | class TearingNetBasic(nn.Module):
32 | 
33 |     def __init__(self, tearing1_dims, tearing2_dims, grid_dims, kernel_size=1):
34 |         super(TearingNetBasic, self).__init__()
35 | 
36 |         self.grid_dims = grid_dims
37 |         self.tearing1 = Conv2dLayers(tearing1_dims, kernel_size=kernel_size, doLastRelu=False)
38 |         self.tearing2 = Conv2dLayers(tearing2_dims, kernel_size=kernel_size, doLastRelu=False)
39 | 
40 |     def forward(self, cw, grid, pc, **kwargs):
41 | 
42 |         grid_exp = grid.contiguous().view(grid.shape[0], self.grid_dims[0], self.grid_dims[1], 2)  # batch_size X dim0 X dim1 X 2
43 |         pc_exp = pc.contiguous().view(pc.shape[0], self.grid_dims[0], self.grid_dims[1], 3)  # batch_size X dim0 X dim1 X 3
44 |         cw_exp = cw.unsqueeze(1).unsqueeze(1).expand(-1, self.grid_dims[0], self.grid_dims[1], -1)  # batch_size X dim0 X dim1 X code_length
45 |         in1 = torch.cat((grid_exp, pc_exp, cw_exp), 3).permute([0, 3, 1, 2])
46 | 
47 |         # Compute the torn 2D grid
48 |         out1 = self.tearing1(in1)  # 1st tearing
49 |         in2 = torch.cat((in1, out1), 1)
50 |         out2 = self.tearing2(in2)  # 2nd tearing
51 |         out2 = out2.permute([0, 2, 3, 1]).contiguous().view(grid.shape[0], self.grid_dims[0] * self.grid_dims[1], 2)
52 |         return grid + out2
53 | 
54 | 
55 | class TearingNetBasicModel(nn.Module):
56 | 
57 |     @staticmethod
58 |     def add_options(parser, isTrain = True):
59 | 
60 |         # General option(s)
61 |         parser.add_argument('--grid_dims', type=int, nargs='+', help='Grid dimensions.')
62 | 
63 |         # Options related to the Folding Network
64 |         parser.add_argument('--folding1_dims', type=int, nargs='+', default=[514, 512, 512, 3], help='Dimensions of the first folding module.')
65 |         parser.add_argument('--folding2_dims', type=int, nargs='+', default=[515, 512, 512, 3], help='Dimensions of the second folding module.')
66 | 
67 |         # Options related to the Tearing Network
68 |         parser.add_argument('--tearing1_dims', type=int, nargs='+', default=[523, 256, 128, 64], help='Dimensions of the first tearing module.')
69 |         parser.add_argument('--tearing2_dims', type=int, nargs='+', default=[587, 256, 128, 2], help='Dimensions of the second tearing module.')
70 |         parser.add_argument('--tearing_conv_kernel_size', type=int, default=1, help='Kernel size of the convolutional layers in the Tearing Network, 1 implies MLP.')
71 | 
72 |         return parser
73 | 
74 |     def __init__(self, opt):
75 |         super(TearingNetBasicModel, self).__init__()
76 | 
77 |         # Initialize the regular 2D grid
78 |         range_x = torch.linspace(-1.0, 1.0, opt.grid_dims[0])
79 |         range_y = torch.linspace(-1.0, 1.0, opt.grid_dims[1])
80 |         x_coor, y_coor = torch.meshgrid(range_x, range_y)
81 |         self.grid = torch.stack([x_coor, y_coor], axis=-1).float().reshape(-1, 2)
82 | 
83 |         # Initialize the Folding Network and the Tearing Network
84 |         self.folding = FoldingNetVanilla(opt.folding1_dims, opt.folding2_dims)
85 |         self.tearing = TearingNetBasic(opt.tearing1_dims, opt.tearing2_dims, opt.grid_dims, opt.tearing_conv_kernel_size)
86 | 
87 |     def forward(self, cw):
88 | 
89 |         grid0 = self.grid.cuda().unsqueeze(0).expand(cw.shape[0], -1, -1)  # batch_size X point_num X 2
90 |         pc0 = self.folding(cw, grid0)  # Folding Network
91 |         grid1 = self.tearing(cw, grid0, pc0)  # Tearing Network
92 |         pc1 = self.folding(cw, grid1)  # Folding Network
93 | 
94 |         return pc0, pc1, grid1
--------------------------------------------------------------------------------
/models/tearingnet_graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | TearingNet with graph filtering
5 | '''
6 | 
7 | import torch
8 | import torch.nn as nn
9 | 
10 | from .foldingnet import FoldingNetVanilla
11 | from .tearingnet import TearingNetBasic
12 | 
13 | 
14 | class GraphFilter(nn.Module):
15 | 
16 |     def __init__(self, grid_dims, graph_r, graph_eps, graph_lam):
17 |         super(GraphFilter, self).__init__()
18 |         self.grid_dims = grid_dims
19 |         self.graph_r = graph_r
20 |         self.graph_eps_sqr = graph_eps * graph_eps
21 |         self.graph_lam = graph_lam
22 | 
23 |     def forward(self, grid, pc):
24 | 
25 |         # Data preparation
26 |         bs_cur = pc.shape[0]
27 |         grid_exp = grid.contiguous().view(bs_cur, self.grid_dims[0], self.grid_dims[1], 2)  # batch_size X dim0 X dim1 X 2
28 |         pc_exp = pc.contiguous().view(bs_cur, self.grid_dims[0], self.grid_dims[1], 3)  # batch_size X dim0 X dim1 X 3
29 |         graph_feature = torch.cat((grid_exp, pc_exp), dim=3).permute([0, 3, 1, 2])
30 | 
31 |         # Compute the graph weights
32 |         wght_hori = graph_feature[:,:,:-1,:] - graph_feature[:,:,1:,:]  # horizontal weights
33 |         wght_vert = graph_feature[:,:,:,:-1] - graph_feature[:,:,:,1:]  # vertical weights
34 |         wght_hori = torch.exp(-torch.sum(wght_hori * wght_hori, dim=1) / self.graph_eps_sqr)  # Gaussian weight
35 |         wght_vert = torch.exp(-torch.sum(wght_vert * wght_vert, dim=1) / self.graph_eps_sqr)
36 |         wght_hori = (wght_hori > self.graph_r) * wght_hori
37 |         wght_vert = (wght_vert > self.graph_r) * wght_vert
38 |         wght_lft = torch.cat((torch.zeros([bs_cur, 1, self.grid_dims[1]]).cuda(), wght_hori), 1)  # add left
39 |         wght_rgh = torch.cat((wght_hori, torch.zeros([bs_cur, 1, self.grid_dims[1]]).cuda()), 1)  # add right
40 |         wght_top = torch.cat((torch.zeros([bs_cur, self.grid_dims[0], 1]).cuda(), wght_vert), 2)  # add top
41 |         wght_bot = torch.cat((wght_vert, torch.zeros([bs_cur, self.grid_dims[0], 1]).cuda()), 2)  # add bottom
42 |         wght_all = torch.cat((wght_lft.unsqueeze(1), wght_rgh.unsqueeze(1), wght_top.unsqueeze(1), wght_bot.unsqueeze(1)), 1)
43 | 
44 |         # Perform the actual graph filtering: x = (I - \lambda L) * x
45 |         wght_hori = wght_hori.unsqueeze(1).expand(-1, 3, -1, -1)  # dimension expansion
46 |         wght_vert = wght_vert.unsqueeze(1).expand(-1, 3, -1, -1)
47 |         pc = pc.permute([0, 2, 1]).contiguous().view(bs_cur, 3, self.grid_dims[0], self.grid_dims[1])
48 |         pc_filt = \
49 |             torch.cat((torch.zeros([bs_cur, 3, 1, self.grid_dims[1]]).cuda(), pc[:,:,:-1,:] * wght_hori), 2) + \
50 |             torch.cat((pc[:,:,1:,:] * wght_hori, torch.zeros([bs_cur, 3, 1, self.grid_dims[1]]).cuda()), 2) + \
51 |             torch.cat((torch.zeros([bs_cur, 3, self.grid_dims[0], 1]).cuda(), pc[:,:,:,:-1] * wght_vert), 3) + \
52 |             torch.cat((pc[:,:,:,1:] * wght_vert, torch.zeros([bs_cur, 3, self.grid_dims[0], 1]).cuda()), 3)  # left, right, top, bottom
53 | 
54 |         pc_filt = pc + self.graph_lam * (pc_filt - torch.sum(wght_all, dim=1).unsqueeze(1).expand(-1, 3, -1, -1) * pc)  # equivalent to (I - \lambda L) * x
55 |         pc_filt = pc_filt.view(bs_cur, 3, -1).permute([0, 2, 1])
56 |         return pc_filt, wght_all
57 | 
58 | 
59 | class TearingNetGraphModel(nn.Module):
60 | 
61 |     @staticmethod
62 |     def add_options(parser, isTrain = True):
63 | 
64 |         # General option(s)
65 |         parser.add_argument('--grid_dims', type=int, nargs='+', help='Grid dimensions.')
66 | 
67 |         # Options related to the Folding Network
68 |         parser.add_argument('--folding1_dims', type=int, nargs='+', default=[514, 512, 512, 3], help='Dimensions of the first folding module.')
69 |         parser.add_argument('--folding2_dims', type=int, nargs='+', default=[515, 512, 512, 3], help='Dimensions of the second folding module.')
70 | 
71 |         # Options related to the Tearing Network
72 |         parser.add_argument('--tearing1_dims', type=int, nargs='+', default=[523, 256, 128, 64], help='Dimensions of the first tearing module.')
73 |         parser.add_argument('--tearing2_dims', type=int, nargs='+', default=[587, 256, 128, 2], help='Dimensions of the second tearing module.')
74 |         parser.add_argument('--tearing_conv_kernel_size', type=int, default=1, help='Kernel size of the convolutional layers in the Tearing Network, 1 implies MLP.')
75 | 
76 |         # Options related to graph construction
77 |         parser.add_argument('--graph_r', type=float, default=1e-12, help='Parameter r for the r-neighborhood graph.')
78 |         parser.add_argument('--graph_eps', type=float, default=0.02, help='Parameter epsilon for the graph (bandwidth parameter).')
79 |         parser.add_argument('--graph_lam', type=float, default=0.5, help='Parameter lambda for the graph filter.')
80 | 
81 |         return parser
82 | 
83 |     def __init__(self, opt):
84 |         super(TearingNetGraphModel, self).__init__()
85 | 
86 |         # Initialize the regular 2D grid
87 |         range_x = torch.linspace(-1.0, 1.0, opt.grid_dims[0])
88 |         range_y = torch.linspace(-1.0, 1.0, opt.grid_dims[1])
89 |         x_coor, y_coor = torch.meshgrid(range_x, range_y)
90 |         self.grid = torch.stack([x_coor, y_coor], axis=-1).float().reshape(-1, 2)
91 | 
92 |         # Initialize the Folding Network and the Tearing Network
93 |         self.folding = FoldingNetVanilla(opt.folding1_dims, opt.folding2_dims)
94 |         self.tearing = TearingNetBasic(opt.tearing1_dims, opt.tearing2_dims, opt.grid_dims, opt.tearing_conv_kernel_size)
95 |         self.graph_filter = GraphFilter(opt.grid_dims, opt.graph_r, opt.graph_eps, opt.graph_lam)
96 | 
97 |     def forward(self, cw):
98 | 
99 |         grid0 = self.grid.cuda().unsqueeze(0).expand(cw.shape[0], -1, -1)  # batch_size X point_num X 2
100 |         pc0 = self.folding(cw, grid0)  # Folding Network
101 |         grid1 = self.tearing(cw, grid0, pc0)  # Tearing Network
102 |         pc1 = self.folding(cw, grid1)  # Folding Network
103 |         pc2, graph_wght = self.graph_filter(grid1, pc1)  # Graph Filtering
104 |         return pc0, pc1, pc2, grid1, graph_wght
105 | 
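# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original model): GraphFilter applies one
# step of x <- (I - lambda * L) x on the 4-connected grid graph, where
# L = D - W is the unnormalized graph Laplacian built from the thresholded
# Gaussian edge weights. A minimal dense illustration on 3 points:
def _graph_filter_sketch(lam=0.5):
    x = torch.tensor([[0.0], [1.0], [10.0]])  # 3 points, 1 coordinate each
    W = torch.tensor([[0.0, 1.0, 0.0],
                      [1.0, 0.0, 0.0],        # a single edge between points 0 and 1
                      [0.0, 0.0, 0.0]])
    D = torch.diag(W.sum(dim=1))              # degree matrix
    return x + lam * (W @ x - D @ x)          # points 0 and 1 are smoothed toward
                                              # each other; isolated point 2 is untouched
# ---------------------------------------------------------------------------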
--------------------------------------------------------------------------------
/scripts/experiments/counting.sh:
--------------------------------------------------------------------------------
1 | # Counting experiment
2 | 
3 | EXP="count_tearing_kitti"
4 | PY_NAME="${HOME_DIR}/experiments/counting.py"
5 | EXP_NAME="results/${EXP}"
6 | CHECKPOINT="${HOME_DIR}/results/train_tearing_kitti/"
7 | PHASE="counting"
8 | BATCH_SIZE="8"
9 | COUNT_FOLD="4"
10 | EPOCH_INTERVAL="-1 10"
11 | CONFIG_FROM_CHECKPOINT="True"
12 | DATASET_NAME="kitti_mulobj"
13 | COUNT_SPLIT="test_5x5"
14 | PRINT_FREQ="50"
15 | GRID_DIMS="45 45"
16 | SVM_PARAMS="100"
17 | 
18 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --checkpoint ${CHECKPOINT} --phase ${PHASE} --batch_size ${BATCH_SIZE} --count_fold ${COUNT_FOLD} --epoch_interval ${EPOCH_INTERVAL} --config_from_checkpoint ${CONFIG_FROM_CHECKPOINT} --dataset_name ${DATASET_NAME} --count_split ${COUNT_SPLIT} --print_freq ${PRINT_FREQ} --grid_dims ${GRID_DIMS} --svm_params ${SVM_PARAMS}"
--------------------------------------------------------------------------------
/scripts/experiments/reconstruction.sh:
--------------------------------------------------------------------------------
1 | # Reconstruction experiment
2 | 
3 | EXP="rec_tearing_kitti"
4 | PY_NAME="${HOME_DIR}/experiments/reconstruction.py"
5 | EXP_NAME="results/${EXP}"
6 | CHECKPOINT="${HOME_DIR}/results/train_tearing_kitti/epoch_639.pth"
7 | PHASE="test"
8 | CONFIG_FROM_CHECKPOINT="True"
9 | AUGMENTATION="False"
10 | DATASET_NAME="kitti_mulobj"
11 | TEST_SPLIT="test_5x5"
12 | BATCH_SIZE="32"
13 | GRID_DIMS="45 45"
14 | PRINT_FREQ="5"
15 | PC_WRITE_FREQ="-1"
16 | GT_COLOR="0.2 0.2 0.2"
17 | GRAPH_THRES="0"
18 | GRAPH_EDGE_COLOR="0.6 0.6 0.6"
19 | WRITE_MESH="True"
20 | GRAPH_DELETE_POINT_MODE="0"
21 | GRAPH_DELETE_POINT_EPS="0.08"
22 | 
23 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --checkpoint ${CHECKPOINT} --phase ${PHASE} --config_from_checkpoint ${CONFIG_FROM_CHECKPOINT} --augmentation ${AUGMENTATION} --dataset_name ${DATASET_NAME} --test_split ${TEST_SPLIT} --batch_size ${BATCH_SIZE} --grid_dims ${GRID_DIMS} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --gt_color ${GT_COLOR} --graph_thres ${GRAPH_THRES} --graph_edge_color ${GRAPH_EDGE_COLOR} --write_mesh ${WRITE_MESH} --graph_delete_point_mode ${GRAPH_DELETE_POINT_MODE} --graph_delete_point_eps ${GRAPH_DELETE_POINT_EPS}"
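# Editor's note (not part of the original script): these experiment files only
# define variables; scripts/launch.sh sources the file passed as its first
# argument and then runs "python ${RUN_ARGUMENTS}" on the selected GPU, e.g.,
# from the repository root:
#
#     bash scripts/launch.sh scripts/experiments/reconstruction.sh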
--------------------------------------------------------------------------------
/scripts/experiments/train_folding_cad.sh:
--------------------------------------------------------------------------------
1 | # FoldingNet pretraining
2 | 
3 | EXP="train_folding_cad"
4 | PY_NAME="${HOME_DIR}/experiments/train_basic.py"
5 | EXP_NAME="results/${EXP}"
6 | PHASE="train"
7 | N_EPOCH="640"
8 | SAVE_EPOCH_FREQ="10"
9 | PRINT_FREQ="50"
10 | PC_WRITE_FREQ="1000"
11 | BATCH_SIZE="32"
12 | XYZ_LOSS_TYPE="0"
13 | XYZ_CHAMFER_WEIGHT="0.01"
14 | LR_POLICY="step"
15 | LR="0.0002 80 0.5"
16 | DATASET_NAME="cad_mulobj"
17 | TRAIN_SPLIT="train_5x5"
18 | AUGMENTATION="True"
19 | AUGMENTATION_ROTATION_AXIS="1"
20 | GRID_DIMS="45 45"
21 | ENCODER="pointnetvanilla"
22 | DECODER="foldingnetvanilla"
23 | POINTNET_MLP_DIMS="3 64 64 64 128 1024"
24 | POINTNET_FC_DIMS="1024 512 512"
25 | POINTNET_MLP_DOLASTRELU="False"
26 | FOLDING1_DIMS="514 512 512 3"
27 | FOLDING2_DIMS="515 512 512 3"
28 | 
29 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --n_epoch ${N_EPOCH} --save_epoch_freq ${SAVE_EPOCH_FREQ} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --batch_size ${BATCH_SIZE} --xyz_loss_type ${XYZ_LOSS_TYPE} --xyz_chamfer_weight ${XYZ_CHAMFER_WEIGHT} --lr_policy ${LR_POLICY} --lr ${LR} --dataset_name ${DATASET_NAME} --train_split ${TRAIN_SPLIT} --augmentation ${AUGMENTATION} --augmentation_rotation_axis ${AUGMENTATION_ROTATION_AXIS} --grid_dims ${GRID_DIMS} --encoder ${ENCODER} --decoder ${DECODER} --pointnet_mlp_dims ${POINTNET_MLP_DIMS} --pointnet_fc_dims ${POINTNET_FC_DIMS} --pointnet_mlp_dolastrelu ${POINTNET_MLP_DOLASTRELU} --folding1_dims ${FOLDING1_DIMS} --folding2_dims ${FOLDING2_DIMS}"
--------------------------------------------------------------------------------
/scripts/experiments/train_folding_kitti.sh:
--------------------------------------------------------------------------------
1 | # FoldingNet pretraining
2 | 
3 | EXP="train_folding_kitti"
4 | PY_NAME="${HOME_DIR}/experiments/train_basic.py"
5 | EXP_NAME="results/${EXP}"
6 | PHASE="train"
7 | N_EPOCH="640"
8 | SAVE_EPOCH_FREQ="10"
9 | PRINT_FREQ="50"
10 | PC_WRITE_FREQ="1000"
11 | BATCH_SIZE="32"
12 | XYZ_LOSS_TYPE="0"
13 | XYZ_CHAMFER_WEIGHT="0.01"
14 | LR_POLICY="step"
15 | LR="0.0002 80 0.5"
16 | DATASET_NAME="kitti_mulobj"
17 | TRAIN_SPLIT="train_5x5"
18 | AUGMENTATION="True"
19 | AUGMENTATION_THETA="0"
20 | AUGMENTATION_ROTATION_AXIS="0"
21 | AUGMENTATION_FLIP_AXIS="1"
22 | GRID_DIMS="45 45"
23 | ENCODER="pointnetvanilla"
24 | DECODER="foldingnetvanilla"
25 | POINTNET_MLP_DIMS="3 64 64 64 128 1024"
26 | POINTNET_FC_DIMS="1024 512 512"
27 | POINTNET_MLP_DOLASTRELU="False"
28 | FOLDING1_DIMS="514 512 512 3"
29 | FOLDING2_DIMS="515 512 512 3"
30 | 
31 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --n_epoch ${N_EPOCH} --save_epoch_freq ${SAVE_EPOCH_FREQ} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --batch_size ${BATCH_SIZE} --xyz_loss_type ${XYZ_LOSS_TYPE} --xyz_chamfer_weight ${XYZ_CHAMFER_WEIGHT} --lr_policy ${LR_POLICY} --lr ${LR} --dataset_name ${DATASET_NAME} --train_split ${TRAIN_SPLIT} --augmentation ${AUGMENTATION} --augmentation_theta ${AUGMENTATION_THETA} --augmentation_rotation_axis ${AUGMENTATION_ROTATION_AXIS} --augmentation_flip_axis ${AUGMENTATION_FLIP_AXIS} --grid_dims ${GRID_DIMS} --encoder ${ENCODER} --decoder ${DECODER} --pointnet_mlp_dims ${POINTNET_MLP_DIMS} --pointnet_fc_dims ${POINTNET_FC_DIMS} --pointnet_mlp_dolastrelu ${POINTNET_MLP_DOLASTRELU} --folding1_dims ${FOLDING1_DIMS} --folding2_dims ${FOLDING2_DIMS}"
512 512 64" 30 | TEARING2_DIMS="581 512 512 2" 31 | GRAPH_R="1e-12" 32 | GRAPH_EPS="0.02" 33 | GRAPH_LAM="0.5" 34 | 35 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --checkpoint ${CHECKPOINT} --load_weight_only ${LOAD_WEIGHT_ONLY} --phase ${PHASE} --n_epoch ${N_EPOCH} --save_epoch_freq ${SAVE_EPOCH_FREQ} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --batch_size ${BATCH_SIZE} --xyz_loss_type ${XYZ_LOSS_TYPE} --lr_policy ${LR_POLICY} --lr ${LR} --dataset_name ${DATASET_NAME} --train_split ${TRAIN_SPLIT} --augmentation ${AUGMENTATION} --augmentation_rotation_axis ${AUGMENTATION_ROTATION_AXIS} --grid_dims ${GRID_DIMS} --encoder ${ENCODER} --decoder ${DECODER} --pointnet_mlp_dims ${POINTNET_MLP_DIMS} --pointnet_fc_dims ${POINTNET_FC_DIMS} --pointnet_mlp_dolastrelu ${POINTNET_MLP_DOLASTRELU} --folding1_dims ${FOLDING1_DIMS} --folding2_dims ${FOLDING2_DIMS} --tearing1_dims ${TEARING1_DIMS} --tearing2_dims ${TEARING2_DIMS} --graph_r ${GRAPH_R} --graph_eps ${GRAPH_EPS} --graph_lam ${GRAPH_LAM}" -------------------------------------------------------------------------------- /scripts/experiments/train_tearing_kitti.sh: -------------------------------------------------------------------------------- 1 | # Train the TearingNet 2 | 3 | EXP="train_tearing_kitti" 4 | PY_NAME="${HOME_DIR}/experiments/train_tearing.py" 5 | EXP_NAME="results/${EXP}" 6 | CHECKPOINT="${HOME_DIR}/results/train_folding_kitti/epoch_639.pth" 7 | LOAD_WEIGHT_ONLY="True" 8 | PHASE="train" 9 | N_EPOCH="480" 10 | SAVE_EPOCH_FREQ="10" 11 | PRINT_FREQ="50" 12 | PC_WRITE_FREQ="1000" 13 | BATCH_SIZE="32" 14 | XYZ_LOSS_TYPE="0" 15 | LR_POLICY="step" 16 | LR="0.000001 80 0.5" 17 | DATASET_NAME="kitti_mulobj" 18 | TRAIN_SPLIT="train_5x5" 19 | AUGMENTATION="True" 20 | AUGMENTATION_THETA="0" 21 | AUGMENTATION_ROTATION_AXIS="0" 22 | AUGMENTATION_FLIP_AXIS="1" 23 | GRID_DIMS="45 45" 24 | ENCODER="pointnetvanilla" 25 | DECODER="tearingnetgraph" 26 | POINTNET_MLP_DIMS="3 64 64 64 128 1024" 27 | POINTNET_FC_DIMS="1024 512 512" 28 | POINTNET_MLP_DOLASTRELU="False" 29 | FOLDING1_DIMS="514 512 512 3" 30 | FOLDING2_DIMS="515 512 512 3" 31 | TEARING1_DIMS="517 512 512 64" 32 | TEARING2_DIMS="581 512 512 2" 33 | GRAPH_R="1e-12" 34 | GRAPH_EPS="0.02" 35 | GRAPH_LAM="0.5" 36 | 37 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --checkpoint ${CHECKPOINT} --load_weight_only ${LOAD_WEIGHT_ONLY} --phase ${PHASE} --n_epoch ${N_EPOCH} --save_epoch_freq ${SAVE_EPOCH_FREQ} --print_freq ${PRINT_FREQ} --pc_write_freq ${PC_WRITE_FREQ} --batch_size ${BATCH_SIZE} --xyz_loss_type ${XYZ_LOSS_TYPE} --lr_policy ${LR_POLICY} --lr ${LR} --dataset_name ${DATASET_NAME} --train_split ${TRAIN_SPLIT} --augmentation ${AUGMENTATION} --augmentation_theta ${AUGMENTATION_THETA} --augmentation_rotation_axis ${AUGMENTATION_ROTATION_AXIS} --augmentation_flip_axis ${AUGMENTATION_FLIP_AXIS} --grid_dims ${GRID_DIMS} --encoder ${ENCODER} --decoder ${DECODER} --pointnet_mlp_dims ${POINTNET_MLP_DIMS} --pointnet_fc_dims ${POINTNET_FC_DIMS} --pointnet_mlp_dolastrelu ${POINTNET_MLP_DOLASTRELU} --folding1_dims ${FOLDING1_DIMS} --folding2_dims ${FOLDING2_DIMS} --tearing1_dims ${TEARING1_DIMS} --tearing2_dims ${TEARING2_DIMS} --graph_r ${GRAPH_R} --graph_eps ${GRAPH_EPS} --graph_lam ${GRAPH_LAM}" -------------------------------------------------------------------------------- /scripts/gen_data/gen_cad_mulobj_test_5x5.sh: -------------------------------------------------------------------------------- 1 | # Generate CAD model multiple-object dataset 2 | 3 | 
EXP_NAME="results/gen_cad_mulobj_test_5x5" 4 | PY_NAME="${HOME_DIR}/dataloaders/cadmulobj_loader.py" 5 | PHASE="gen_cadmultiobj" 6 | NUM_POINTS="2048" 7 | AUGMENTATION="True" 8 | CAD_MULOBJ_NUM_ADD_MODEL="3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22" 9 | CAD_MULOBJ_NUM_EXAMPLE="1 4 16 53 143 322 609 974 1328 1550 1550 1328 974 609 322 143 53 16 4 1" 10 | CAD_MULOBJ_NUM_AVA_MODEL="1133" 11 | CAD_MULOBJ_SCENE_RADIUS="5" 12 | CAD_MULOBJ_OUTPUT_FILE_NAME="${HOME_DIR}/dataset/cadmulobj/cad_mulobj_param_test_5x5" 13 | 14 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --num_points ${NUM_POINTS} --augmentation ${AUGMENTATION} --cad_mulobj_num_add_model ${CAD_MULOBJ_NUM_ADD_MODEL} --cad_mulobj_num_example ${CAD_MULOBJ_NUM_EXAMPLE} --cad_mulobj_num_ava_model ${CAD_MULOBJ_NUM_AVA_MODEL} --cad_mulobj_scene_radius ${CAD_MULOBJ_SCENE_RADIUS} --cad_mulobj_output_file_name ${CAD_MULOBJ_OUTPUT_FILE_NAME}" -------------------------------------------------------------------------------- /scripts/gen_data/gen_cad_mulobj_train_5x5.sh: -------------------------------------------------------------------------------- 1 | # Generate CAD model multiple-object dataset 2 | 3 | EXP_NAME="results/gen_cad_mulobj_train_5x5" 4 | PY_NAME="${HOME_DIR}/dataloaders/cadmulobj_loader.py" 5 | PHASE="gen_cadmultiobj" 6 | NUM_POINTS="2048" 7 | AUGMENTATION="True" 8 | CAD_MULOBJ_NUM_ADD_MODEL="3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22" 9 | CAD_MULOBJ_NUM_EXAMPLE="3 19 79 264 716 1612 3044 4871 6642 7749 7749 6642 4871 3044 1612 716 264 79 19 3" 10 | CAD_MULOBJ_NUM_AVA_MODEL="1133" 11 | CAD_MULOBJ_SCENE_RADIUS="5" 12 | CAD_MULOBJ_OUTPUT_FILE_NAME="${HOME_DIR}/dataset/cadmulobj/cad_mulobj_param_train_5x5" 13 | 14 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --num_points ${NUM_POINTS} --augmentation ${AUGMENTATION} --cad_mulobj_num_add_model ${CAD_MULOBJ_NUM_ADD_MODEL} --cad_mulobj_num_example ${CAD_MULOBJ_NUM_EXAMPLE} --cad_mulobj_num_ava_model ${CAD_MULOBJ_NUM_AVA_MODEL} --cad_mulobj_scene_radius ${CAD_MULOBJ_SCENE_RADIUS} --cad_mulobj_output_file_name ${CAD_MULOBJ_OUTPUT_FILE_NAME}" -------------------------------------------------------------------------------- /scripts/gen_data/gen_kitti_mulobj_test_5x5.sh: -------------------------------------------------------------------------------- 1 | # Generate KITTI multiple-object dataset 2 | 3 | EXP_NAME="results/gen_kitti_mulobj_test_5x5" 4 | PY_NAME="${HOME_DIR}/dataloaders/kittimulobj_loader.py" 5 | PHASE="gen_kittimulobj" 6 | NUM_POINTS="2048" 7 | KITTI_MULOBJ_NUM_ADD_MODEL="3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22" 8 | KITTI_MULOBJ_NUM_EXAMPLE="1 4 16 53 143 322 609 974 1328 1550 1550 1328 974 609 322 143 53 16 4 1" 9 | KITTI_MULOBJ_SCENE_RADIUS="5" 10 | KITTI_MULOBJ_OUTPUT_FILE_NAME="${HOME_DIR}/dataset/kittimulobj/kitti_mulobj_param_test_5x5_2048" 11 | 12 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --num_points ${NUM_POINTS} --kitti_mulobj_num_add_model ${KITTI_MULOBJ_NUM_ADD_MODEL} --kitti_mulobj_num_example ${KITTI_MULOBJ_NUM_EXAMPLE} --kitti_mulobj_scene_radius ${KITTI_MULOBJ_SCENE_RADIUS} --kitti_mulobj_output_file_name ${KITTI_MULOBJ_OUTPUT_FILE_NAME}" -------------------------------------------------------------------------------- /scripts/gen_data/gen_kitti_mulobj_train_5x5.sh: -------------------------------------------------------------------------------- 1 | # Generate KITTI multiple-object dataset 2 | 3 | EXP_NAME="results/gen_kitti_mulobj_train_5x5" 4 | 
PY_NAME="${HOME_DIR}/dataloaders/kittimulobj_loader.py" 5 | PHASE="gen_kittimulobj" 6 | NUM_POINTS="2048" 7 | KITTI_MULOBJ_NUM_ADD_MODEL="3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22" 8 | KITTI_MULOBJ_NUM_EXAMPLE="3 19 79 264 716 1612 3044 4871 6642 7749 7749 6642 4871 3044 1612 716 264 79 19 3" 9 | KITTI_MULOBJ_SCENE_RADIUS="5" 10 | KITTI_MULOBJ_OUTPUT_FILE_NAME="${HOME_DIR}/dataset/kittimulobj/kitti_mulobj_param_train_5x5_2048" 11 | 12 | RUN_ARGUMENTS="${PY_NAME} --exp_name ${EXP_NAME} --phase ${PHASE} --num_points ${NUM_POINTS} --kitti_mulobj_num_add_model ${KITTI_MULOBJ_NUM_ADD_MODEL} --kitti_mulobj_num_example ${KITTI_MULOBJ_NUM_EXAMPLE} --kitti_mulobj_scene_radius ${KITTI_MULOBJ_SCENE_RADIUS} --kitti_mulobj_output_file_name ${KITTI_MULOBJ_OUTPUT_FILE_NAME}" -------------------------------------------------------------------------------- /scripts/launch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | USE_GPU="0" # specify a GPU 4 | export CUDA_VISIBLE_DEVICES=${USE_GPU} 5 | echo export CUDA_VISIBLE_DEVICES=${USE_GPU} 6 | 7 | HOME_DIR="$(pwd)" 8 | source $1 # load parameters 9 | 10 | mkdir -p ${EXP_NAME} # run it 11 | LOG_NAME="log.txt" 12 | echo python ${RUN_ARGUMENTS} 13 | python ${RUN_ARGUMENTS} | tee ${EXP_NAME}/${LOG_NAME} & 14 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/TearingNet/1e43caa266a11e9d50c2d912064bd39d369bb120/util/__init__.py -------------------------------------------------------------------------------- /util/cad_models_collector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | A tool to collect the CAD models of "person", "car", "cone", "plant" from ModelNet40, and "motorbike" from ShapeNetPart 5 | ''' 6 | 7 | import os 8 | import numpy as np 9 | 10 | import sys 11 | sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/..') 12 | from dataloaders.modelnet_loader import ModelNet40, class_extractor 13 | from dataloaders.shapenet_part_loader import PartDataset 14 | 15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | data_path=os.path.abspath(os.path.join(BASE_DIR, '../dataset/cadmulobj')) + '/' 17 | 18 | 19 | def main(): 20 | 21 | label_person = 24 22 | label_car = 7 23 | label_cone = 9 24 | label_plant = 26 25 | from_shapenet_part = 'Motorbike' 26 | 27 | loader = ModelNet40(phase='train') 28 | list_person = class_extractor(label_person, loader) 29 | list_car = class_extractor(label_car, loader) 30 | list_cone = class_extractor(label_cone, loader) 31 | list_plant = class_extractor(label_plant, loader) 32 | loader = ModelNet40(phase='test') 33 | list_person = np.concatenate((list_person, class_extractor(label_person, loader)), axis=0) 34 | list_car = np.concatenate((list_car, class_extractor(label_car, loader)), axis=0) 35 | list_cone = np.concatenate((list_cone, class_extractor(label_cone, loader)), axis=0) 36 | list_plant = np.concatenate((list_plant, class_extractor(label_plant, loader)), axis=0) 37 | 38 | list_motorbike=[] 39 | loader = PartDataset(npoints=2048, classification=False, 40 | class_choice=from_shapenet_part, split='trainval', normalize=True) 41 | for i in range(len(loader)): 42 | list_motorbike.append(loader[i][0].numpy()) 43 | loader = PartDataset(npoints=2048, 
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InterDigitalInc/TearingNet/1e43caa266a11e9d50c2d912064bd39d369bb120/util/__init__.py
--------------------------------------------------------------------------------
/util/cad_models_collector.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
A tool to collect the CAD models of "person", "car", "cone", "plant" from ModelNet40, and "motorbike" from ShapeNetPart
'''

import os
import numpy as np

import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/..')
from dataloaders.modelnet_loader import ModelNet40, class_extractor
from dataloaders.shapenet_part_loader import PartDataset

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.abspath(os.path.join(BASE_DIR, '../dataset/cadmulobj')) + '/'


def main():

    label_person = 24
    label_car = 7
    label_cone = 9
    label_plant = 26
    from_shapenet_part = 'Motorbike'

    loader = ModelNet40(phase='train')
    list_person = class_extractor(label_person, loader)
    list_car = class_extractor(label_car, loader)
    list_cone = class_extractor(label_cone, loader)
    list_plant = class_extractor(label_plant, loader)
    loader = ModelNet40(phase='test')
    list_person = np.concatenate((list_person, class_extractor(label_person, loader)), axis=0)
    list_car = np.concatenate((list_car, class_extractor(label_car, loader)), axis=0)
    list_cone = np.concatenate((list_cone, class_extractor(label_cone, loader)), axis=0)
    list_plant = np.concatenate((list_plant, class_extractor(label_plant, loader)), axis=0)

    list_motorbike = []
    loader = PartDataset(npoints=2048, classification=False,
        class_choice=from_shapenet_part, split='trainval', normalize=True)
    for i in range(len(loader)):
        list_motorbike.append(loader[i][0].numpy())
    loader = PartDataset(npoints=2048, classification=False,
        class_choice=from_shapenet_part, split='test', normalize=True)
    for i in range(len(loader)):
        list_motorbike.append(loader[i][0].numpy())
    list_motorbike = np.vstack(list_motorbike).reshape(-1, 2048, 3)

    interest_obj = {"person": list_person, "car": list_car, "cone": list_cone, "plant": list_plant, "motorbike": list_motorbike}
    np.save(os.path.join(data_path, "cad_models"), interest_obj)


if __name__ == '__main__':
    main()
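Since the collector stores a Python dict via np.save, reading the result back requires allow_pickle. A minimal loading sketch (the path assumes the repository root as the working directory; the printed shapes are illustrative):

import numpy as np

models = np.load("dataset/cadmulobj/cad_models.npy", allow_pickle=True).item()
print({k: np.asarray(v).shape for k, v in models.items()})  # e.g. "motorbike": (N, 2048, 3)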
green\n") 70 | f.write("property uchar blue\n") 71 | f.write("end_header\n") 72 | 73 | # Write points 74 | for i in range(points.shape[0]): 75 | for j in range(points.shape[1]): 76 | if idx_map[i,j] >= 0: 77 | f.write(str(points[i,j,0]) + " " + str(points[i,j,1]) + " " + str(points[i,j,2]) + " ") 78 | if point_color_as_index == True: f.write(str(i) + " " + str(j) + " 0\n") 79 | else: f.write(str(int(color[i,j,0] * 255)) + " " + str(int(color[i,j,1] * 255)) + " " + str(int(color[i,j,2] * 255)) + "\n") 80 | 81 | # Write faces 82 | for i in range(points.shape[0] - 1): 83 | for j in range(points.shape[1] - 1): 84 | if (weights[i,j,1] > thres) and (weights[i,j,3] > thres) and (weights[i+1,j+1,0]> thres) and (weights[i+1,j+1,2]> thres): 85 | # f.write("4 " + str(idx_map[i+1,j]) + " " + str(idx_map[i+1,j+1]) + " " + str(idx_map[i,j+1]) + " " + str(idx_map[i,j]) + "\n") 86 | f.write("3 " + str(idx_map[i+1,j]) + " " + str(idx_map[i+1,j+1]) + " " + str(idx_map[i,j+1]) + "\n") 87 | f.write("3 " + str(idx_map[i,j+1]) + " " + str(idx_map[i,j]) + " " + str(idx_map[i+1,j]) + "\n") 88 | f.write("3 " + str(idx_map[i,j+1]) + " " + str(idx_map[i+1,j+1]) + " " + str(idx_map[i+1,j]) + "\n") 89 | f.write("3 " + str(idx_map[i+1,j]) + " " + str(idx_map[i,j]) + " " + str(idx_map[i,j+1]) + "\n") 90 | 91 | # Write meshes 92 | for i in range(points.shape[0]): 93 | for j in range(points.shape[1]): 94 | if idx_map[i,j] >= 0: 95 | if weights[i,j,0] > thres and idx_map[i-1,j] >= 0 and dist(points[i,j,:], points[i-1,j,:]) <= thres_edge: # left 96 | f.write(str(idx_map[i,j]) + " " + str(idx_map[i-1,j]) + " " + str(int(edge_color[0] * 255)) + 97 | " " + str(int(edge_color[1] * 255)) + " " + str(int(edge_color[2] * 255)) + "\n") 98 | if weights[i,j,1] > thres and idx_map[i+1,j] >= 0 and dist(points[i,j,:], points[i+1,j,:]) <= thres_edge: # right 99 | f.write(str(idx_map[i,j]) + " " + str(idx_map[i+1,j]) + " " + str(int(edge_color[0] * 255)) + 100 | " " + str(int(edge_color[1] * 255)) + " " + str(int(edge_color[2] * 255)) + "\n") 101 | if weights[i,j,2] > thres and idx_map[i,j-1] >= 0 and dist(points[i,j,:], points[i,j-1,:]) <= thres_edge: # top 102 | f.write(str(idx_map[i,j]) + " " + str(idx_map[i,j-1]) + " " + str(int(edge_color[0] * 255)) + 103 | " " + str(int(edge_color[1] * 255)) + " " + str(int(edge_color[2] * 255)) + "\n") 104 | if weights[i,j,3] > thres and idx_map[i,j+1] >= 0 and dist(points[i,j,:], points[i,j+1,:]) <= thres_edge: # bottom 105 | f.write(str(idx_map[i,j]) + " " + str(idx_map[i,j+1]) + " " + str(int(edge_color[0] * 255)) + 106 | " " + str(int(edge_color[1] * 255)) + " " + str(int(edge_color[2] * 255)) + "\n") 107 | f.close() -------------------------------------------------------------------------------- /util/option_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Option handler 5 | ''' 6 | 7 | import argparse 8 | import models 9 | 10 | def str2bool(v): 11 | if isinstance(v, bool): 12 | return v 13 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 14 | return True 15 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 16 | return False 17 | else: 18 | raise argparse.ArgumentTypeError('Boolean value expected.') 19 | 20 | class BasicOptionHandler(): 21 | 22 | def add_options(self, parser): 23 | parser.add_argument('--exp_name', type=str, default='experiment_name', help='Name of the experiment, folders are created by this name.') 24 | parser.add_argument('--phase', type=str, default='train', 
--------------------------------------------------------------------------------
/util/option_handler.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Option handler
'''

import argparse
import models

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

class BasicOptionHandler():

    def add_options(self, parser):
        parser.add_argument('--exp_name', type=str, default='experiment_name', help='Name of the experiment, folders are created by this name.')
        parser.add_argument('--phase', type=str, default='train', help='Indicate the current phase: train or test.')
        parser.add_argument('--encoder', type=str, default='pointnetvanilla', help='Choice of the encoder.')
        parser.add_argument('--decoder', type=str, default='foldingnetvanilla', help='Choice of the decoder.')
        parser.add_argument('--checkpoint', type=str, default='', help='Restore from an indicated checkpoint.')
        parser.add_argument('--load_weight_only', type=str2bool, nargs='?', const=True, default=False, help='Load the model weight only when restoring from a checkpoint.')
        parser.add_argument('--xyz_loss_type', type=int, default=0, help='Choose the loss type for the point cloud.')
        parser.add_argument('--xyz_chamfer_weight', type=float, default=1, help='Balance the two terms of the Chamfer distance.')
        parser.add_argument('--batch_size', type=int, default=8, help='Batch size when loading the dataset.')
        parser.add_argument('--num_points', type=int, default=2048, help='Input point set size.')
        parser.add_argument('--dataset_name', default='', help='Dataset name.')
        parser.add_argument('--config_from_checkpoint', type=str2bool, nargs='?', const=True, default=False, help='Load the model configuration from the checkpoint.')
        parser.add_argument('--tf_summary', type=str2bool, nargs='?', const=True, default=True, help='Whether to use TensorBoard for logging.')
        return parser

    def parse_options(self):

        # Initialize parser with basic options
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.add_options(parser)

        # Get the basic options
        opt, _ = parser.parse_known_args()

        # Train or test
        if opt.phase.lower() == 'train':
            self.isTrain = True
        elif opt.phase.lower() == 'test' or opt.phase.lower() == 'counting':
            self.isTrain = False
        elif (opt.phase.lower() == 'gen_cadmultiobj') or (opt.phase.lower() == 'kitti_data') or (opt.phase.lower() == 'gen_kittimulobj'):
            self.parser = parser
            self.isTrain = False
            return opt
        else:
            print("The phase [{}] does not exist.".format(opt.phase))
            exit(0)
        opt.isTrain = self.isTrain # train or test

        # Add options to the parser according to the chosen models
        encoder_option_setter = models.get_model_class(opt.encoder).add_options
        parser = encoder_option_setter(parser, self.isTrain)
        decoder_option_setter = models.get_model_class(opt.decoder).add_options
        parser = decoder_option_setter(parser, self.isTrain)

        self.parser = parser
        opt, _ = parser.parse_known_args()
        return opt

    def print_options(self, opt):
        message = ''
        message += '----------------- Options ---------------\n'
        # for k, v in sorted(vars(opt).items()):
        for k, v in vars(opt).items():
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)
class TrainOptionHandler(BasicOptionHandler):

    def add_options(self, parser):
        parser = BasicOptionHandler.add_options(self, parser)
        parser.add_argument('--n_epoch', type=int, default=100, help='Number of epochs to train.')
        parser.add_argument('--print_freq', type=int, default=20, help='Frequency of displaying training results.')
        parser.add_argument('--pc_write_freq', type=int, default=1000, help='Frequency of writing down the point cloud during training; <= 0 means do not write.')
        parser.add_argument('--save_epoch_freq', type=int, default=2, help='Frequency of saving the trained model.')
        parser.add_argument('--lr', type=float, nargs='+', default=[0.0002], help='Learning rate and its related parameters.')
        parser.add_argument('--train_split', type=str, default='train', help='The split of the dataset for training.')
        parser.add_argument('--lr_module', type=float, nargs='+', help='Specify the learning rate of the modules.')
        parser.add_argument('--optim_args', type=float, nargs='+', default=[0.9, 0.999, 0], help='Parameters of the ADAM optimizer.')
        parser.add_argument('--augmentation', type=str2bool, nargs='?', const=True, default=False, help='Apply data augmentation or not.')
        parser.add_argument('--augmentation_theta', type=int, default=-1, help='Choice of the rotation angle from [0, 90, 180, 270].')
        parser.add_argument('--augmentation_rotation_axis', type=int, default=-1, help='Choice of the rotation axis from np.eye(3).')
        parser.add_argument('--augmentation_flip_axis', type=int, default=-1, help='The axis around which to flip for augmentation.')
        parser.add_argument('--augmentation_min_scale', type=float, default=1, help='Minimum of the scaling factor for augmentation.')
        parser.add_argument('--augmentation_max_scale', type=float, default=1, help='Maximum of the scaling factor for augmentation.')

        return parser

class TestOptionHandler(BasicOptionHandler):

    def add_options(self, parser):
        parser = BasicOptionHandler.add_options(self, parser)
        parser.add_argument('--test_split', type=str, default='test', help='Specify the split being used.')
        parser.add_argument('--print_freq', type=int, default=20, help='Frequency of displaying results during testing.')
        parser.add_argument('--pc_write_freq', type=int, default=1000, help='Frequency of writing down the point cloud during testing.')
        parser.add_argument('--gt_color', type=float, nargs='+', default=[0, 0, 0], help='Color of the ground-truth point cloud.')
        parser.add_argument('--write_mesh', type=str2bool, nargs='?', const=True, default=False, help='Write down meshes or not.')
        parser.add_argument('--graph_thres', type=float, default=-1, help='Threshold of the graph edges.')
        parser.add_argument('--graph_edge_color', type=float, nargs='+', default=[0.5, 0.5, 0.5], help='Color of the graph edges.')
        parser.add_argument('--graph_delete_point_mode', type=int, default=-1, help='Mode of removing points: -1: no removal; 0: remove those without edge; 1: remove those with one edge.')
        parser.add_argument('--graph_delete_point_eps', type=float, default=-1, help='The epsilon to use when evaluating the point-deleted version of the point cloud.')
        parser.add_argument('--thres_edge', type=float, default=3, help='Threshold of the length of the edge to be written.')
        parser.add_argument('--point_color_as_index', type=str2bool, nargs='?', const=True, default=False, help='Whether to regard the point cloud color as the point index.')
        return parser
class CountingOptionHandler(BasicOptionHandler):

    def add_options(self, parser):
        parser = BasicOptionHandler.add_options(self, parser)
        parser.add_argument('--count_split', type=str, default='test', help='The split of the dataset for object counting.')
        parser.add_argument('--count_fold', type=int, default=4, help='Number of folds for the object counting experiment.')
        parser.add_argument('--epoch_interval', type=int, nargs='+', default=[-1, 10], help='Range of the epochs to do object counting.')
        parser.add_argument('--print_freq', type=int, default=20, help='Frequency of displaying object counting results.')
        parser.add_argument('--svm_params', type=float, nargs='+', default=[1e3], help='SVM parameters.')
        return parser

class GenerateCADMultiObjectOptionHandler(BasicOptionHandler):

    def add_options(self, parser):
        parser = BasicOptionHandler.add_options(self, parser)
        parser.add_argument('--cad_mulobj_num_add_model', type=int, nargs='+', default=[2, 3, 4], help='A list. Numbers of 3D models to add into the scene.')
        parser.add_argument('--cad_mulobj_num_example', type=int, nargs='+', default=[1000, 1000, 1000], help='A list. Numbers of examples corresponding to each type of scene.')
        parser.add_argument('--cad_mulobj_num_ava_model', type=int, default=1000, help='Total number of available models in the original dataset.')
        parser.add_argument('--cad_mulobj_scene_radius', type=float, default=5, help='Radius of the generated scene.')
        parser.add_argument('--cad_mulobj_output_file_name', type=str, required=True, help='Path and file name of the output file parametrizing the dataset.')
        parser.add_argument('--augmentation', type=str2bool, nargs='?', const=True, default=False, help='Apply data augmentation or not.')
        return parser

class GenerateKittiMultiObjectOptionHandler(BasicOptionHandler):

    def add_options(self, parser):
        parser = BasicOptionHandler.add_options(self, parser)
        parser.add_argument('--kitti_mulobj_num_add_model', type=int, nargs='+', default=[2, 3, 4], help='A list. Numbers of 3D models to add into the scene.')
        parser.add_argument('--kitti_mulobj_num_example', type=int, nargs='+', default=[1000, 1000, 1000], help='A list. Numbers of examples corresponding to each type of scene.')
        parser.add_argument('--kitti_mulobj_scene_radius', type=float, default=5, help='Radius of the generated scene.')
        parser.add_argument('--kitti_mulobj_output_file_name', type=str, required=True, help='Path and file name of the output file parametrizing the dataset.')
        return parser
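The experiment entry points are expected to pick the handler matching their phase; a minimal usage sketch (the flag values are illustrative, and parsing the encoder/decoder defaults requires the repository's models package to be importable):

import sys
sys.path.append('.')  # assumes the repository root as the working directory
from util.option_handler import TrainOptionHandler

sys.argv = ['train_basic.py', '--exp_name', 'results/demo', '--phase', 'train']
handler = TrainOptionHandler()
opt = handler.parse_options()  # basic options first, then encoder/decoder options
handler.print_options(opt)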
--------------------------------------------------------------------------------
/util/pcdet_create_groundtruth_database.py:
--------------------------------------------------------------------------------
'''
Modified from https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/datasets/kitti/kitti_dataset.py
Replace the create_groundtruth_database() function in kitti_dataset.py of OpenPCDet with the following one.
It relies on the imports already present in kitti_dataset.py (e.g. Path, pickle, np, roiaware_pool3d_utils).
'''

def create_groundtruth_database(self, info_path=None, used_classes=None, split='train'):
    import torch
    from itertools import combinations

    database_save_path = Path(self.root_path) / ('kitti_single' if split == 'train' else ('kitti_single_%s' % split))
    database_save_path.mkdir(parents=True, exist_ok=True)
    with open(info_path, 'rb') as f:
        infos = pickle.load(f)

    # Parameters
    num_point = 1536
    max_num_obj = 22
    interest_class = ['Pedestrian', 'Car', 'Cyclist', 'Van', 'Truck']
    all_db_infos = []
    db_obj = []
    for obj_col in range(max_num_obj):
        all_db_infos.append([])
    all_db_infos.append(np.zeros(max_num_obj, dtype=int)) # accumulator
    db_obj_save_path = Path(self.root_path) / ('kitti_dbinfos_object.pkl')

    frame_obj_list = np.zeros((len(infos), max_num_obj), dtype=int)
    for k in range(len(infos)):
        info = infos[k]
        sample_idx = info['point_cloud']['lidar_idx']
        points = self.get_lidar(sample_idx)
        annos = info['annos']
        names = annos['name']
        difficulty = annos['difficulty']
        bbox = annos['bbox']
        gt_boxes = annos['gt_boxes_lidar']

        num_obj = gt_boxes.shape[0]
        point_indices = roiaware_pool3d_utils.points_in_boxes_cpu(
            torch.from_numpy(points[:, 0:3]), torch.from_numpy(gt_boxes)
        ).numpy() # (nboxes, npoints)

        # Count the object occurrence
        for obj_id in range(num_obj):
            if (names[obj_id] in interest_class) == False: continue
            filename = '%d_%d_%s.bin' % (k, obj_id, names[obj_id])
            filepath = database_save_path / filename
            db_path = str(filepath.relative_to(self.root_path)) # kitti_single/xxxxx.bin
            gt_points = points[point_indices[obj_id] > 0]

            db_obj.append({'name': names[obj_id], 'path': db_path, 'image_idx': sample_idx, 'gt_idx': obj_id,
                           'box3d_lidar': gt_boxes[obj_id], 'num_points_in_gt': gt_points.shape[0],
                           'difficulty': difficulty[obj_id], 'bbox': bbox[obj_id], 'score': annos['score'][obj_id]})

            with open(filepath, 'w') as f:
                gt_points.tofile(f)
            # Mark object obj_id as usable for an (obj_col+1)-object scene when it has
            # enough points; stored as a per-frame bitmask with bit obj_id set
            for obj_col in range(max_num_obj):
                if gt_points.shape[0] >= np.ceil(num_point/(obj_col+1)).astype(int):
                    frame_obj_list[k,obj_col] += 2 ** obj_id

        # Conclude how a frame can be used
        for obj_col in range(max_num_obj):
            if bin(frame_obj_list[k,obj_col])[2:].count('1') >= obj_col + 1:
                # Decode the bitmask back into object indices, then enumerate all
                # (obj_col+1)-object combinations this frame can provide
                obj_indicator = np.array(list(bin(frame_obj_list[k,obj_col])[2:].zfill(max_num_obj)[::-1])) == '1'
                obj_choice = np.arange(max_num_obj)[obj_indicator]
                comb = combinations(obj_choice, obj_col+1)
                for obj_scene in list(comb):

                    # Write down the scene configuration
                    db_info = []
                    for obj_id in obj_scene:
                        filename = '%d_%d_%s.bin' % (k, obj_id, names[obj_id])
                        filepath = database_save_path / filename
                        db_path = str(filepath.relative_to(self.root_path)) # kitti_single/xxxxx.bin
                        db_info.append({'name': names[obj_id], 'path': db_path, 'image_idx': sample_idx, 'gt_idx': obj_id,
                                        'box3d_lidar': gt_boxes[obj_id], 'difficulty': difficulty[obj_id], 'bbox': bbox[obj_id],
                                        'score': annos['score'][obj_id]})
                    all_db_infos[len(obj_scene)-1].append(db_info)
                    all_db_infos[-1][len(obj_scene)-1] += 1
                    print(k, obj_scene)

    with open(db_obj_save_path, 'wb') as f:
        print(f.name)
        pickle.dump(db_obj, f)
--------------------------------------------------------------------------------
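A note on the bookkeeping in create_groundtruth_database(): frame_obj_list packs, per frame and per target scene size, which objects have enough points into an integer bitmask (bit obj_id set via 2 ** obj_id), and the bin(...)/zfill/[::-1] steps decode it back before itertools.combinations enumerates candidate scenes. A toy illustration of the decode step (the mask value is invented):

import numpy as np
from itertools import combinations

max_num_obj = 22
mask = (1 << 0) + (1 << 3) + (1 << 5)  # objects 0, 3 and 5 qualify
obj_indicator = np.array(list(bin(mask)[2:].zfill(max_num_obj)[::-1])) == '1'
obj_choice = np.arange(max_num_obj)[obj_indicator]
print(obj_choice)                         # [0 3 5]
print(list(combinations(obj_choice, 2)))  # [(0, 3), (0, 5), (3, 5)]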