├── LICENSE ├── README.md ├── TRAINED_MODEL_LICENSE ├── checkpoints └── download_models.sh ├── data ├── __init__.py ├── base_dataset.py ├── grasp_evaluator_data.py └── grasp_sampling_data.py ├── demo ├── __init__.py ├── data │ ├── blue_mug.npy │ ├── cheezit.npy │ ├── cylinder.npy │ ├── green_bowl.npy │ ├── mustard.npy │ ├── pepper.npy │ ├── red_mug.npy │ ├── sugar.npy │ ├── white_bowl.npy │ └── white_mug.npy ├── examples │ ├── 1.png │ └── 2.png └── main.py ├── eval.py ├── grasp_estimator.py ├── gripper_control_points └── panda.npy ├── gripper_models ├── featuretype.STL ├── mug.obj ├── panda_gripper.obj ├── panda_gripper │ ├── finger.stl │ └── hand.stl ├── panda_pc.npy └── yumi_gripper │ ├── base.stl │ ├── base_coarse.stl │ ├── finger.stl │ └── finger_coarse.stl ├── models ├── __init__.py ├── grasp_net.py ├── losses.py └── networks.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── renderer ├── __init__.py ├── object_renderer.py └── online_object_renderer.py ├── requirements.txt ├── shapenet_ids.txt ├── test.py ├── train.py ├── uniform_quaternions ├── data2_4608.qua └── data3_36864.qua └── utils ├── __init__.py ├── sample.py ├── surface_normal.py ├── utils.py ├── visualization_utils.py └── writer.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 NVIDIA Corporation 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 6-DoF GraspNet: Variational Grasp Generation for Object Manipulation 2 | 3 | This is a PyTorch implementation of [6-DoF 4 | GraspNet](https://arxiv.org/abs/1905.10520). The original Tensorflow 5 | implementation can be found here . 6 | 7 | # License 8 | 9 | The source code is released under [MIT License](LICENSE) and the trained weights are released under [CC-BY-NC-SA 2.0](TRAINED_MODEL_LICENSE). 10 | 11 | ## Installation 12 | 13 | This code has been tested with python 3.6, PyTorch 1.4 and CUDA 10.0 on Ubuntu 14 | 18.04. To install do 15 | 16 | 1) `pip3 install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f ` 17 | 18 | 2) Clone this repository: `git clone 19 | git@github.com:jsll/pytorch_6dof-graspnet.git`. 20 | 21 | 3) Clone pointnet++: `git@github.com:erikwijmans/Pointnet2_PyTorch.git`. 
22 | 23 | 4) Run `cd Pointnet2_PyTorch && pip3 install -r requirements.txt` 24 | 25 | 5) `cd pytorch_6dof-graspnet` 26 | 27 | 6) Run `pip3 install -r requirements.txt` to install necessary python libraries. 28 | 29 | 7) (Optional) Download the trained models either by running `sh 30 | checkpoints/download_models.sh` or manually from [here](https://drive.google.com/file/d/1B0EeVlHbYBki__WszkbY8A3Za941K8QI/view?usp=sharing). Trained 31 | models are released under [CC-BY-NC-SA 2.0](TRAINED_MODEL_LICENSE). 32 | 33 | ## Disclaimer 34 | 35 | The pre-trained models released in this repo are retrained from scratch and not converted from the original ones trained in Tensorflow. I tried to convert the Tensorflow models but with no luck. Although I trained the new models for a substantial amount of time on the same training data, no guarantees to their performance compared to the original work can be given. 36 | 37 | ## Updates 38 | 39 | In the paper, the authors only used gradient-based refinement. Recently, they released a Metropolis-Hastings 40 | sampling method which they found to give better results in shorter time. As a result, I keep the Metropolis-Hastings sampling as the default for the demo. 41 | 42 | This repository also includes an improved grasp sampling network which was 43 | proposed here . The new grasp sampling 44 | network is trained with [Implicit Maximum Likelihood Estimation](https://arxiv.org/pdf/2004.03590.pdf). 45 | 46 | ### Update 9th June 2020 47 | 48 | I have now uploaded new models that are trained for longer and until the test loss flattened. The new models can be downloaded in the same way as detailed in step 7 above. 49 | 50 | ## Demo 51 | 52 | Run the demo using the command below 53 | 54 | ```shell 55 | python -m demo.main 56 | ``` 57 | 58 | Per default, the demo script runs the GAN sampler with sampling based 59 | refinement. To use the VAE sampler and/or gradient refinement run: 60 | 61 | ```shell 62 | python -m demo.main --grasp_sampler_folder checkpoints/vae_pretrained/ --refinement_method gradient 63 | ``` 64 | 65 | ![example](demo/examples/1.png) ![example](demo/examples/2.png) 66 | 67 | ## Dataset 68 | 69 | ### Get ShapeNet Models 70 | 71 | Download the meshes with ids written in [shapenet_ids.txt](shapenet_ids.txt) from Some of the objects are in `ShapenetCore` and `ShapenetSem`. 72 | 73 | ### Prepare ShapeNet Models 74 | 75 | 1. Clone and build: 76 | 2. Create a watertight mesh version assuming the object path is model.obj: `manifold model.obj temp.watertight.obj -s` 77 | 3. Simplify it: `simplify -i temp.watertight.obj -o model.obj -m -r 0.02` 78 | 79 | ### Download the dataset 80 | 81 | The dataset can be downloaded from [here](https://drive.google.com/open?id=1GkFrkvpP-R1letnv6rt_WLSX80o43Jjm). The dataset has 3 folders: 82 | 83 | 1) `grasps` folder: contains all the grasps for each object. 84 | 2) `meshes` folder: has the folder for all the meshes used. Except `cylinder` and `box` the rest of the folders are empty and need to be populated by the downloaded meshes from shapenet. 85 | 3) `splits` folder: contains the train/test split for each of the categories. 86 | 87 | Verify the dataset by running `python grasp_data_reader.py` to visualize the evaluator data and `python grasp_data_reader.py --vae-mode` to visualize only the positive grasps. 
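If you want to sanity-check the downloaded annotations directly, the snippet below is a minimal sketch for inspecting a single grasp file. It only touches the fields that `data/base_dataset.py` reads (`object`, `object_scale`, `transforms`, `quality_flex_object_in_gripper`); the dataset root and the concrete file name are placeholders and will differ on your machine.

```python
import json
import os

import numpy as np

DATASET_ROOT = '/path/to/dataset'  # the same folder passed as $DATASET_ROOT_FOLDER
grasp_file = os.path.join(DATASET_ROOT, 'grasps', 'Mug_some_object.json')  # hypothetical file name

with open(grasp_file) as f:
    annotation = json.load(f)

transforms = np.asarray(annotation['transforms'])  # (N, 4, 4) grasp transforms
flex_quality = np.asarray(annotation['quality_flex_object_in_gripper'])

print('mesh:', annotation['object'], ' scale:', annotation['object_scale'])
print('grasps:', transforms.shape[0],
      ' flex-positive (>0.01):', int(np.sum(flex_quality > 0.01)))
```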
88 | 89 | ## Training 90 | 91 | To train the grasp sampler (vae or gan) or the evaluator with bare minimum configurations run: 92 | 93 | ```shell 94 | python3 train.py --arch {vae,gan,evaluator} --dataset_root_folder $DATASET_ROOT_FOLDER 95 | ``` 96 | 97 | where the `$DATASET_ROOT_FOLDER` is the path to the dataset you downloaded. 98 | 99 | To monitor the training, run `tensorboard --logdir checkpoints/` and click . 100 | 101 | For more training options run 102 | GAN Training Example Command: 103 | 104 | ```shell 105 | python3 train.py --help 106 | ``` 107 | 108 | ## Quantitative Evaluation 109 | 110 | I have not converted the code for doing quantitative evaluation 111 | to PyTorch. I 112 | would appreciate it if someone could convert it and send in a pull request. 113 | 114 | ## Citation 115 | 116 | If you find this work useful in your research, please consider citing the 117 | original authors' work: 118 | 119 | ``` 120 | inproceedings{mousavian2019graspnet, 121 | title={6-DOF GraspNet: Variational Grasp Generation for Object Manipulation}, 122 | author={Arsalan Mousavian and Clemens Eppner and Dieter Fox}, 123 | booktitle={International Conference on Computer Vision (ICCV)}, 124 | year={2019} 125 | } 126 | ``` 127 | 128 | as well as 129 | 130 | ``` 131 | @inproceedings{lundell2023constrained, 132 | title={Constrained generative sampling of 6-dof grasps}, 133 | author={Lundell, Jens and Verdoja, Francesco and Le, Tran Nguyen and Mousavian, Arsalan and Fox, Dieter and Kyrki, Ville}, 134 | booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, 135 | pages={2940--2946}, 136 | year={2023}, 137 | organization={IEEE} 138 | } 139 | ``` 140 | -------------------------------------------------------------------------------- /TRAINED_MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | CC BY-NC-SA 2.0 2 | 3 | THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED. 4 | 5 | BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS. 6 | 7 | 1. Definitions 8 | 9 | "Collective Work" means a work, such as a periodical issue, anthology or encyclopedia, in which the Work in its entirety in unmodified form, along with a number of other contributions, constituting separate and independent works in themselves, are assembled into a collective whole. A work that constitutes a Collective Work will not be considered a Derivative Work (as defined below) for the purposes of this License. 10 | "Derivative Work" means a work based upon the Work or upon the Work and other pre-existing works, such as a translation, musical arrangement, dramatization, fictionalization, motion picture version, sound recording, art reproduction, abridgment, condensation, or any other form in which the Work may be recast, transformed, or adapted, except that a work that constitutes a Collective Work will not be considered a Derivative Work for the purpose of this License. 
For the avoidance of doubt, where the Work is a musical composition or sound recording, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered a Derivative Work for the purpose of this License. 11 | "Licensor" means the individual or entity that offers the Work under the terms of this License. 12 | "Original Author" means the individual or entity who created the Work. 13 | "Work" means the copyrightable work of authorship offered under the terms of this License. 14 | "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation. 15 | "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, Noncommercial, ShareAlike. 16 | 2. Fair Use Rights. Nothing in this license is intended to reduce, limit, or restrict any rights arising from fair use, first sale or other limitations on the exclusive rights of the copyright owner under copyright law or other applicable laws. 17 | 18 | 3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below: 19 | 20 | to reproduce the Work, to incorporate the Work into one or more Collective Works, and to reproduce the Work as incorporated in the Collective Works; 21 | to create and reproduce Derivative Works; 22 | to distribute copies or phonorecords of, display publicly, perform publicly, and perform publicly by means of a digital audio transmission the Work including as incorporated in Collective Works; 23 | to distribute copies or phonorecords of, display publicly, perform publicly, and perform publicly by means of a digital audio transmission Derivative Works; 24 | The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. All rights not expressly granted by Licensor are hereby reserved, including but not limited to the rights set forth in Sections 4(e) and 4(f). 25 | 26 | 4. Restrictions.The license granted in Section 3 above is expressly made subject to and limited by the following restrictions: 27 | 28 | You may distribute, publicly display, publicly perform, or publicly digitally perform the Work only under the terms of this License, and You must include a copy of, or the Uniform Resource Identifier for, this License with every copy or phonorecord of the Work You distribute, publicly display, publicly perform, or publicly digitally perform. You may not offer or impose any terms on the Work that alter or restrict the terms of this License or the recipients' exercise of the rights granted hereunder. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties. You may not distribute, publicly display, publicly perform, or publicly digitally perform the Work with any technological measures that control access or use of the Work in a manner inconsistent with the terms of this License Agreement. 
The above applies to the Work as incorporated in a Collective Work, but this does not require the Collective Work apart from the Work itself to be made subject to the terms of this License. If You create a Collective Work, upon notice from any Licensor You must, to the extent practicable, remove from the Collective Work any reference to such Licensor or the Original Author, as requested. If You create a Derivative Work, upon notice from any Licensor You must, to the extent practicable, remove from the Derivative Work any reference to such Licensor or the Original Author, as requested. 29 | You may distribute, publicly display, publicly perform, or publicly digitally perform a Derivative Work only under the terms of this License, a later version of this License with the same License Elements as this License, or a Creative Commons iCommons license that contains the same License Elements as this License (e.g. Attribution-NonCommercial-ShareAlike 2.0 Japan). You must include a copy of, or the Uniform Resource Identifier for, this License or other license specified in the previous sentence with every copy or phonorecord of each Derivative Work You distribute, publicly display, publicly perform, or publicly digitally perform. You may not offer or impose any terms on the Derivative Works that alter or restrict the terms of this License or the recipients' exercise of the rights granted hereunder, and You must keep intact all notices that refer to this License and to the disclaimer of warranties. You may not distribute, publicly display, publicly perform, or publicly digitally perform the Derivative Work with any technological measures that control access or use of the Work in a manner inconsistent with the terms of this License Agreement. The above applies to the Derivative Work as incorporated in a Collective Work, but this does not require the Collective Work apart from the Derivative Work itself to be made subject to the terms of this License. 30 | You may not exercise any of the rights granted to You in Section 3 above in any manner that is primarily intended for or directed toward commercial advantage or private monetary compensation. The exchange of the Work for other copyrighted works by means of digital file-sharing or otherwise shall not be considered to be intended for or directed toward commercial advantage or private monetary compensation, provided there is no payment of any monetary compensation in connection with the exchange of copyrighted works. 31 | If you distribute, publicly display, publicly perform, or publicly digitally perform the Work or any Derivative Works or Collective Works, You must keep intact all copyright notices for the Work and give the Original Author credit reasonable to the medium or means You are utilizing by conveying the name (or pseudonym if applicable) of the Original Author if supplied; the title of the Work if supplied; to the extent reasonably practicable, the Uniform Resource Identifier, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and in the case of a Derivative Work, a credit identifying the use of the Work in the Derivative Work (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). 
Such credit may be implemented in any reasonable manner; provided, however, that in the case of a Derivative Work or Collective Work, at a minimum such credit will appear where any other comparable authorship credit appears and in a manner at least as prominent as such other comparable authorship credit. 32 | For the avoidance of doubt, where the Work is a musical composition: 33 | 34 | Performance Royalties Under Blanket Licenses. Licensor reserves the exclusive right to collect, whether individually or via a performance rights society (e.g. ASCAP, BMI, SESAC), royalties for the public performance or public digital performance (e.g. webcast) of the Work if that performance is primarily intended for or directed toward commercial advantage or private monetary compensation. 35 | Mechanical Rights and Statutory Royalties. Licensor reserves the exclusive right to collect, whether individually or via a music rights agency or designated agent (e.g. Harry Fox Agency), royalties for any phonorecord You create from the Work ("cover version") and distribute, subject to the compulsory license created by 17 USC Section 115 of the US Copyright Act (or the equivalent in other jurisdictions), if Your distribution of such cover version is primarily intended for or directed toward commercial advantage or private monetary compensation. 36 | Webcasting Rights and Statutory Royalties. For the avoidance of doubt, where the Work is a sound recording, Licensor reserves the exclusive right to collect, whether individually or via a performance-rights society (e.g. SoundExchange), royalties for the public digital performance (e.g. webcast) of the Work, subject to the compulsory license created by 17 USC Section 114 of the US Copyright Act (or the equivalent in other jurisdictions), if Your public digital performance is primarily intended for or directed toward commercial advantage or private monetary compensation. 37 | 5. Representations, Warranties and Disclaimer 38 | 39 | UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU. 40 | 41 | 6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 42 | 43 | 7. Termination 44 | 45 | This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Derivative Works or Collective Works from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License. 46 | Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). 
Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above. 47 | 8. Miscellaneous 48 | 49 | Each time You distribute or publicly digitally perform the Work or a Collective Work, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License. 50 | Each time You distribute or publicly digitally perform a Derivative Work, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License. 51 | If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. 52 | No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent. 53 | This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You. -------------------------------------------------------------------------------- /checkpoints/download_models.sh: -------------------------------------------------------------------------------- 1 | # Download the pre-trained networks network 2 | wget --no-check-certificate -r 'https://drive.google.com/uc?export=download&id=1B0EeVlHbYBki__WszkbY8A3Za941K8QI' -O checkpoints/models.zip 3 | echo "Models downloaded. Starting to unzip" 4 | unzip -q checkpoints/models.zip -d checkpoints/ 5 | rm checkpoints/models.zip 6 | echo "Models downloaded and unzipped." 
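# Tip: very large Google Drive files sometimes require a confirmation token, in
# which case plain wget can save an HTML warning page instead of the zip archive.
# If that happens, an alternative (assuming the gdown package is installed via
# `pip3 install gdown`) is:
#   gdown "https://drive.google.com/uc?id=1B0EeVlHbYBki__WszkbY8A3Za941K8QI" -O checkpoints/models.zip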
7 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_dataset import collate_fn 3 | import threading 4 | 5 | 6 | def CreateDataset(opt): 7 | """loads dataset class""" 8 | 9 | if opt.arch == 'vae' or opt.arch == 'gan': 10 | from data.grasp_sampling_data import GraspSamplingData 11 | dataset = GraspSamplingData(opt) 12 | else: 13 | from data.grasp_evaluator_data import GraspEvaluatorData 14 | dataset = GraspEvaluatorData(opt) 15 | return dataset 16 | 17 | 18 | class DataLoader: 19 | """multi-threaded data loading""" 20 | def __init__(self, opt): 21 | self.opt = opt 22 | self.dataset = CreateDataset(opt) 23 | self.dataloader = torch.utils.data.DataLoader( 24 | self.dataset, 25 | batch_size=opt.num_objects_per_batch, 26 | shuffle=not opt.serial_batches, 27 | num_workers=int(opt.num_threads), 28 | collate_fn=collate_fn) 29 | 30 | def __len__(self): 31 | return min(len(self.dataset), self.opt.max_dataset_size) 32 | 33 | def __iter__(self): 34 | for i, data in enumerate(self.dataloader): 35 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 36 | break 37 | yield data 38 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | import pickle 4 | import os 5 | import copy 6 | import json 7 | from utils.sample import Object 8 | from utils import utils 9 | import glob 10 | from renderer.online_object_renderer import OnlineObjectRenderer 11 | import threading 12 | 13 | 14 | class NoPositiveGraspsException(Exception): 15 | """raised when there's no positive grasps for an object.""" 16 | pass 17 | 18 | 19 | class BaseDataset(data.Dataset): 20 | def __init__(self, 21 | opt, 22 | caching=True, 23 | min_difference_allowed=(0, 0, 0), 24 | max_difference_allowed=(3, 3, 0), 25 | collision_hard_neg_min_translation=(-0.03, -0.03, -0.03), 26 | collision_hard_neg_max_translation=(0.03, 0.03, 0.03), 27 | collision_hard_neg_min_rotation=(-0.6, -0.2, -0.6), 28 | collision_hard_neg_max_rotation=(+0.6, +0.2, +0.6), 29 | collision_hard_neg_num_perturbations=10): 30 | super(BaseDataset, self).__init__() 31 | self.opt = opt 32 | self.mean = 0 33 | self.std = 1 34 | self.ninput_channels = None 35 | self.current_pc = None 36 | self.caching = caching 37 | self.cache = {} 38 | self.collision_hard_neg_min_translation = collision_hard_neg_min_translation 39 | self.collision_hard_neg_max_translation = collision_hard_neg_max_translation 40 | self.collision_hard_neg_min_rotation = collision_hard_neg_min_rotation 41 | self.collision_hard_neg_max_rotation = collision_hard_neg_max_rotation 42 | self.collision_hard_neg_num_perturbations = collision_hard_neg_num_perturbations 43 | self.lock = threading.Lock() 44 | for i in range(3): 45 | assert (collision_hard_neg_min_rotation[i] <= 46 | collision_hard_neg_max_rotation[i]) 47 | assert (collision_hard_neg_min_translation[i] <= 48 | collision_hard_neg_max_translation[i]) 49 | 50 | self.renderer = OnlineObjectRenderer(caching=True) 51 | 52 | if opt.use_uniform_quaternions: 53 | self.all_poses = utils.uniform_quaternions() 54 | else: 55 | self.all_poses = utils.nonuniform_quaternions() 56 | 57 | self.eval_files = [ 58 | json.load(open(f)) for f in glob.glob( 59 | os.path.join(self.opt.dataset_root_folder, 'splits', 
'*.json')) 60 | ] 61 | 62 | def apply_dropout(self, pc): 63 | if self.opt.occlusion_nclusters == 0 or self.opt.occlusion_dropout_rate == 0.: 64 | return np.copy(pc) 65 | 66 | labels = utils.farthest_points(pc, self.opt.occlusion_nclusters, 67 | utils.distance_by_translation_point) 68 | 69 | removed_labels = np.unique(labels) 70 | removed_labels = removed_labels[np.random.rand(removed_labels.shape[0]) 71 | < self.opt.occlusion_dropout_rate] 72 | if removed_labels.shape[0] == 0: 73 | return np.copy(pc) 74 | mask = np.ones(labels.shape, labels.dtype) 75 | for l in removed_labels: 76 | mask = np.logical_and(mask, labels != l) 77 | return pc[mask] 78 | 79 | def render_random_scene(self, camera_pose=None): 80 | """ 81 | Renders a random view and return (pc, camera_pose, object_pose). 82 | object_pose is None for single object per scene. 83 | """ 84 | if camera_pose is None: 85 | viewing_index = np.random.randint(0, high=len(self.all_poses)) 86 | camera_pose = self.all_poses[viewing_index] 87 | 88 | in_camera_pose = copy.deepcopy(camera_pose) 89 | _, _, pc, camera_pose = self.renderer.render(in_camera_pose) 90 | pc = self.apply_dropout(pc) 91 | pc = utils.regularize_pc_point_count(pc, self.opt.npoints) 92 | pc_mean = np.mean(pc, 0, keepdims=True) 93 | pc[:, :3] -= pc_mean[:, :3] 94 | camera_pose[:3, 3] -= pc_mean[0, :3] 95 | 96 | return pc, camera_pose, in_camera_pose 97 | 98 | def change_object_and_render(self, 99 | cad_path, 100 | cad_scale, 101 | camera_pose=None, 102 | thread_id=0): 103 | if camera_pose is None: 104 | viewing_index = np.random.randint(0, high=len(self.all_poses)) 105 | camera_pose = self.all_poses[viewing_index] 106 | 107 | in_camera_pose = copy.deepcopy(camera_pose) 108 | _, _, pc, camera_pose = self.renderer.change_and_render( 109 | cad_path, cad_scale, in_camera_pose, thread_id) 110 | pc = self.apply_dropout(pc) 111 | pc = utils.regularize_pc_point_count(pc, self.opt.npoints) 112 | pc_mean = np.mean(pc, 0, keepdims=True) 113 | pc[:, :3] -= pc_mean[:, :3] 114 | camera_pose[:3, 3] -= pc_mean[0, :3] 115 | 116 | return pc, camera_pose, in_camera_pose 117 | 118 | def change_object(self, cad_path, cad_scale): 119 | self.renderer.change_object(cad_path, cad_scale) 120 | 121 | def read_grasp_file(self, path, return_all_grasps=False): 122 | file_name = path 123 | if self.caching and file_name in self.cache: 124 | pos_grasps, pos_qualities, neg_grasps, neg_qualities, cad, cad_path, cad_scale = copy.deepcopy( 125 | self.cache[file_name]) 126 | return pos_grasps, pos_qualities, neg_grasps, neg_qualities, cad, cad_path, cad_scale 127 | 128 | pos_grasps, pos_qualities, neg_grasps, neg_qualities, cad, cad_path, cad_scale = self.read_object_grasp_data( 129 | path, 130 | ratio_of_grasps_to_be_used=self.opt.grasps_ratio, 131 | return_all_grasps=return_all_grasps) 132 | 133 | if self.caching: 134 | self.cache[file_name] = (pos_grasps, pos_qualities, neg_grasps, 135 | neg_qualities, cad, cad_path, cad_scale) 136 | return copy.deepcopy(self.cache[file_name]) 137 | 138 | return pos_grasps, pos_qualities, neg_grasps, neg_qualities, cad, cad_path, cad_scale 139 | 140 | def read_object_grasp_data(self, 141 | json_path, 142 | quality='quality_flex_object_in_gripper', 143 | ratio_of_grasps_to_be_used=1., 144 | return_all_grasps=False): 145 | """ 146 | Reads the grasps from the json path and loads the mesh and all the 147 | grasps. 
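Returns a tuple of (positive_grasps, positive_qualities, negative_grasps,
negative_qualities, object_mesh, cad_path, cad_scale). Grasps are labelled
positive when both the flex quality and the contact-count quality exceed 0.01
and, unless return_all_grasps is True, each set is clustered with
farthest-point sampling into num_grasp_clusters clusters.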
148 | """ 149 | num_clusters = self.opt.num_grasp_clusters 150 | root_folder = self.opt.dataset_root_folder 151 | 152 | if num_clusters <= 0: 153 | raise NoPositiveGraspsException 154 | 155 | json_dict = json.load(open(json_path)) 156 | 157 | object_model = Object(os.path.join(root_folder, json_dict['object'])) 158 | object_model.rescale(json_dict['object_scale']) 159 | object_model = object_model.mesh 160 | object_mean = np.mean(object_model.vertices, 0, keepdims=1) 161 | 162 | object_model.vertices -= object_mean 163 | grasps = np.asarray(json_dict['transforms']) 164 | grasps[:, :3, 3] -= object_mean 165 | 166 | flex_qualities = np.asarray(json_dict[quality]) 167 | try: 168 | heuristic_qualities = np.asarray( 169 | json_dict['quality_number_of_contacts']) 170 | except KeyError: 171 | heuristic_qualities = np.ones(flex_qualities.shape) 172 | 173 | successful_mask = np.logical_and(flex_qualities > 0.01, 174 | heuristic_qualities > 0.01) 175 | 176 | positive_grasp_indexes = np.where(successful_mask)[0] 177 | negative_grasp_indexes = np.where(~successful_mask)[0] 178 | 179 | positive_grasps = grasps[positive_grasp_indexes, :, :] 180 | negative_grasps = grasps[negative_grasp_indexes, :, :] 181 | positive_qualities = heuristic_qualities[positive_grasp_indexes] 182 | negative_qualities = heuristic_qualities[negative_grasp_indexes] 183 | 184 | def cluster_grasps(grasps, qualities): 185 | cluster_indexes = np.asarray( 186 | utils.farthest_points(grasps, num_clusters, 187 | utils.distance_by_translation_grasp)) 188 | output_grasps = [] 189 | output_qualities = [] 190 | 191 | for i in range(num_clusters): 192 | indexes = np.where(cluster_indexes == i)[0] 193 | if ratio_of_grasps_to_be_used < 1: 194 | num_grasps_to_choose = max( 195 | 1, 196 | int(ratio_of_grasps_to_be_used * float(len(indexes)))) 197 | if len(indexes) == 0: 198 | raise NoPositiveGraspsException 199 | indexes = np.random.choice(indexes, 200 | size=num_grasps_to_choose, 201 | replace=False) 202 | 203 | output_grasps.append(grasps[indexes, :, :]) 204 | output_qualities.append(qualities[indexes]) 205 | 206 | output_grasps = np.asarray(output_grasps) 207 | output_qualities = np.asarray(output_qualities) 208 | 209 | return output_grasps, output_qualities 210 | 211 | if not return_all_grasps: 212 | positive_grasps, positive_qualities = cluster_grasps( 213 | positive_grasps, positive_qualities) 214 | negative_grasps, negative_qualities = cluster_grasps( 215 | negative_grasps, negative_qualities) 216 | num_positive_grasps = np.sum([p.shape[0] for p in positive_grasps]) 217 | num_negative_grasps = np.sum([p.shape[0] for p in negative_grasps]) 218 | else: 219 | num_positive_grasps = positive_grasps.shape[0] 220 | num_negative_grasps = negative_grasps.shape[0] 221 | return positive_grasps, positive_qualities, negative_grasps, negative_qualities, object_model, os.path.join( 222 | root_folder, json_dict['object']), json_dict['object_scale'] 223 | 224 | def sample_grasp_indexes(self, n, grasps, qualities): 225 | """ 226 | Stratified sampling of the grasps. 
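Returns an (n, 2) integer array of (cluster_row, grasp_column) index pairs.
Cluster rows are drawn uniformly from the non-empty clusters (with replacement
when n exceeds their number) and a grasp column is drawn uniformly within each
chosen cluster; raises NoPositiveGraspsException when every cluster is empty.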
227 | """ 228 | nonzero_rows = [i for i in range(len(grasps)) if len(grasps[i]) > 0] 229 | num_clusters = len(nonzero_rows) 230 | replace = n > num_clusters 231 | if num_clusters == 0: 232 | raise NoPositiveGraspsException 233 | 234 | grasp_rows = np.random.choice(range(num_clusters), 235 | size=n, 236 | replace=replace).astype(np.int32) 237 | grasp_rows = [nonzero_rows[i] for i in grasp_rows] 238 | grasp_cols = [] 239 | for grasp_row in grasp_rows: 240 | if len(grasps[grasp_rows]) == 0: 241 | raise ValueError('grasps cannot be empty') 242 | 243 | grasp_cols.append(np.random.randint(len(grasps[grasp_row]))) 244 | 245 | grasp_cols = np.asarray(grasp_cols, dtype=np.int32) 246 | 247 | return np.vstack((grasp_rows, grasp_cols)).T 248 | 249 | def get_mean_std(self): 250 | """ Computes Mean and Standard Deviation from Training Data 251 | If mean/std file doesn't exist, will compute one 252 | :returns 253 | mean: N-dimensional mean 254 | std: N-dimensional standard deviation 255 | ninput_channels: N 256 | (here N=5) 257 | """ 258 | 259 | mean_std_cache = os.path.join(self.opt.dataset_root_folder, 260 | 'mean_std_cache.p') 261 | if not os.path.isfile(mean_std_cache): 262 | print('computing mean std from train data...') 263 | # doesn't run augmentation during m/std computation 264 | num_aug = self.opt.num_aug 265 | self.opt.num_aug = 1 266 | mean, std = np.array(0), np.array(0) 267 | for i, data in enumerate(self): 268 | if i % 500 == 0: 269 | print('{} of {}'.format(i, self.size)) 270 | features = data['edge_features'] 271 | mean = mean + features.mean(axis=1) 272 | std = std + features.std(axis=1) 273 | mean = mean / (i + 1) 274 | std = std / (i + 1) 275 | transform_dict = { 276 | 'mean': mean[:, np.newaxis], 277 | 'std': std[:, np.newaxis], 278 | 'ninput_channels': len(mean) 279 | } 280 | with open(mean_std_cache, 'wb') as f: 281 | pickle.dump(transform_dict, f) 282 | print('saved: ', mean_std_cache) 283 | self.opt.num_aug = num_aug 284 | # open mean / std from file 285 | with open(mean_std_cache, 'rb') as f: 286 | transform_dict = pickle.load(f) 287 | print('loaded mean / std from cache') 288 | self.mean = transform_dict['mean'] 289 | self.std = transform_dict['std'] 290 | self.ninput_channels = transform_dict['ninput_channels'] 291 | 292 | def make_dataset(self): 293 | split_files = os.listdir( 294 | os.path.join(self.opt.dataset_root_folder, 295 | self.opt.splits_folder_name)) 296 | files = [] 297 | for split_file in split_files: 298 | if split_file.find('.json') < 0: 299 | continue 300 | should_go_through = False 301 | if self.opt.allowed_categories == '': 302 | should_go_through = True 303 | if self.opt.blacklisted_categories != '': 304 | if self.opt.blacklisted_categories.find( 305 | split_file[:-5]) >= 0: 306 | should_go_through = False 307 | else: 308 | if self.opt.allowed_categories.find(split_file[:-5]) >= 0: 309 | should_go_through = True 310 | 311 | if should_go_through: 312 | files += [ 313 | os.path.join(self.opt.dataset_root_folder, 314 | self.opt.grasps_folder_name, f) 315 | for f in json.load( 316 | open( 317 | os.path.join(self.opt.dataset_root_folder, 318 | self.opt.splits_folder_name, 319 | split_file)))[self.opt.dataset_split] 320 | ] 321 | return files 322 | 323 | 324 | def collate_fn(batch): 325 | """Creates mini-batch tensors 326 | We should build custom collate_fn rather than using default collate_fn 327 | """ 328 | batch = list(filter(lambda x: x is not None, batch)) # 329 | meta = {} 330 | keys = batch[0].keys() 331 | for key in keys: 332 | meta.update({key: 
np.concatenate([d[key] for d in batch])}) 333 | return meta 334 | -------------------------------------------------------------------------------- /data/grasp_evaluator_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data.base_dataset import BaseDataset, NoPositiveGraspsException 4 | import numpy as np 5 | from utils import utils 6 | import random 7 | import time 8 | try: 9 | from Queue import Queue 10 | except: 11 | from queue import Queue 12 | 13 | 14 | class GraspEvaluatorData(BaseDataset): 15 | def __init__(self, opt, ratio_positive=0.3, ratio_hardnegative=0.4): 16 | BaseDataset.__init__(self, opt) 17 | self.opt = opt 18 | self.device = torch.device('cuda:{}'.format( 19 | opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu') 20 | self.root = opt.dataset_root_folder 21 | self.paths = self.make_dataset() 22 | self.size = len(self.paths) 23 | self.collision_hard_neg_queue = {} 24 | #self.get_mean_std() 25 | opt.input_nc = self.ninput_channels 26 | self.ratio_positive = self.set_ratios(ratio_positive) 27 | self.ratio_hardnegative = self.set_ratios(ratio_hardnegative) 28 | 29 | def set_ratios(self, ratio): 30 | if int(self.opt.num_grasps_per_object * ratio) == 0: 31 | return 1 / self.opt.num_grasps_per_object 32 | return ratio 33 | 34 | def __getitem__(self, index): 35 | path = self.paths[index] 36 | if self.opt.balanced_data: 37 | data = self.get_uniform_evaluator_data(path) 38 | else: 39 | data = self.get_nonuniform_evaluator_data(path) 40 | 41 | gt_control_points = utils.transform_control_points_numpy( 42 | data[1], self.opt.num_grasps_per_object, mode='rt') 43 | 44 | meta = {} 45 | meta['pc'] = data[0][:, :, :3] 46 | meta['grasp_rt'] = gt_control_points[:, :, :3] 47 | meta['labels'] = data[2] 48 | meta['quality'] = data[3] 49 | meta['pc_pose'] = data[4] 50 | meta['cad_path'] = data[5] 51 | meta['cad_scale'] = data[6] 52 | return meta 53 | 54 | def __len__(self): 55 | return self.size 56 | 57 | def get_uniform_evaluator_data(self, path, verify_grasps=False): 58 | pos_grasps, pos_qualities, neg_grasps, neg_qualities, obj_mesh, cad_path, cad_scale = self.read_grasp_file( 59 | path) 60 | 61 | output_pcs = [] 62 | output_grasps = [] 63 | output_qualities = [] 64 | output_labels = [] 65 | output_pc_poses = [] 66 | output_cad_paths = [cad_path] * self.opt.batch_size 67 | output_cad_scales = np.asarray([cad_scale] * self.opt.batch_size, 68 | np.float32) 69 | 70 | num_positive = int(self.opt.batch_size * self.opt.ratio_positive) 71 | positive_clusters = self.sample_grasp_indexes(num_positive, pos_grasps, 72 | pos_qualities) 73 | num_hard_negative = int(self.opt.batch_size * 74 | self.opt.ratio_hardnegative) 75 | num_flex_negative = self.opt.batch_size - num_positive - num_hard_negative 76 | negative_clusters = self.sample_grasp_indexes(num_flex_negative, 77 | neg_grasps, 78 | neg_qualities) 79 | hard_neg_candidates = [] 80 | # Fill in Positive Examples. 
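# (The positive examples themselves are appended further below; this loop first
# perturbs grasps drawn from both the positive and the negative clusters to
# build the pool of hard-negative candidates.)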
81 | 82 | for clusters, grasps, qualities in zip( 83 | [positive_clusters, negative_clusters], [pos_grasps, neg_grasps], 84 | [pos_qualities, neg_qualities]): 85 | for cluster in clusters: 86 | selected_grasp = grasps[cluster[0]][cluster[1]] 87 | selected_quality = qualities[cluster[0]][cluster[1]] 88 | hard_neg_candidates += utils.perturb_grasp( 89 | selected_grasp, 90 | self.collision_hard_neg_num_perturbations, 91 | self.collision_hard_neg_min_translation, 92 | self.collision_hard_neg_max_translation, 93 | self.collision_hard_neg_min_rotation, 94 | self.collision_hard_neg_max_rotation, 95 | ) 96 | 97 | if verify_grasps: 98 | collisions, heuristic_qualities = utils.evaluate_grasps( 99 | output_grasps, obj_mesh) 100 | for computed_quality, expected_quality, g in zip( 101 | heuristic_qualities, output_qualities, output_grasps): 102 | err = abs(computed_quality - expected_quality) 103 | if err > 1e-3: 104 | raise ValueError( 105 | 'Heuristic does not match with the values from data generation {}!={}' 106 | .format(computed_quality, expected_quality)) 107 | 108 | # If queue does not have enough data, fill it up with hard negative examples from the positives. 109 | if path not in self.collision_hard_neg_queue or len( 110 | self.collision_hard_neg_queue[path]) < num_hard_negative: 111 | if path not in self.collision_hard_neg_queue: 112 | self.collision_hard_neg_queue[path] = [] 113 | #hard negatives are perturbations of correct grasps. 114 | collisions, heuristic_qualities = utils.evaluate_grasps( 115 | hard_neg_candidates, obj_mesh) 116 | 117 | hard_neg_mask = collisions | (heuristic_qualities < 0.001) 118 | hard_neg_indexes = np.where(hard_neg_mask)[0].tolist() 119 | np.random.shuffle(hard_neg_indexes) 120 | for index in hard_neg_indexes: 121 | self.collision_hard_neg_queue[path].append( 122 | (hard_neg_candidates[index], -1.0)) 123 | random.shuffle(self.collision_hard_neg_queue[path]) 124 | 125 | # Adding positive grasps 126 | for positive_cluster in positive_clusters: 127 | #print(positive_cluster) 128 | selected_grasp = pos_grasps[positive_cluster[0]][ 129 | positive_cluster[1]] 130 | selected_quality = pos_qualities[positive_cluster[0]][ 131 | positive_cluster[1]] 132 | output_grasps.append(selected_grasp) 133 | output_qualities.append(selected_quality) 134 | output_labels.append(1) 135 | 136 | # Adding hard neg 137 | for i in range(num_hard_negative): 138 | grasp, quality = self.collision_hard_neg_queue[path][i] 139 | output_grasps.append(grasp) 140 | output_qualities.append(quality) 141 | output_labels.append(0) 142 | 143 | self.collision_hard_neg_queue[path] = self.collision_hard_neg_queue[ 144 | path][num_hard_negative:] 145 | 146 | # Adding flex neg 147 | if len(negative_clusters) != num_flex_negative: 148 | raise ValueError( 149 | 'negative clusters should have the same length as num_flex_negative {} != {}' 150 | .format(len(negative_clusters), num_flex_negative)) 151 | 152 | for negative_cluster in negative_clusters: 153 | selected_grasp = neg_grasps[negative_cluster[0]][ 154 | negative_cluster[1]] 155 | selected_quality = neg_qualities[negative_cluster[0]][ 156 | negative_cluster[1]] 157 | output_grasps.append(selected_grasp) 158 | output_qualities.append(selected_quality) 159 | output_labels.append(0) 160 | 161 | #self.change_object(cad_path, cad_scale) 162 | for iter in range(self.opt.num_grasps_per_object): 163 | if iter > 0: 164 | output_pcs.append(np.copy(output_pcs[0])) 165 | output_pc_poses.append(np.copy(output_pc_poses[0])) 166 | else: 167 | pc, camera_pose, _ = 
self.change_object_and_render( 168 | cad_path, 169 | cad_scale, 170 | thread_id=torch.utils.data.get_worker_info().id 171 | if torch.utils.data.get_worker_info() else 0) 172 | output_pcs.append(pc) 173 | output_pc_poses.append(utils.inverse_transform(camera_pose)) 174 | 175 | output_grasps[iter] = camera_pose.dot(output_grasps[iter]) 176 | 177 | output_pcs = np.asarray(output_pcs, dtype=np.float32) 178 | output_grasps = np.asarray(output_grasps, dtype=np.float32) 179 | output_labels = np.asarray(output_labels, dtype=np.int32) 180 | output_qualities = np.asarray(output_qualities, dtype=np.float32) 181 | output_pc_poses = np.asarray(output_pc_poses, dtype=np.float32) 182 | 183 | return output_pcs, output_grasps, output_labels, output_qualities, output_pc_poses, output_cad_paths, output_cad_scales 184 | 185 | def get_nonuniform_evaluator_data(self, path, verify_grasps=False): 186 | 187 | pos_grasps, pos_qualities, neg_grasps, neg_qualities, obj_mesh, cad_path, cad_scale = self.read_grasp_file( 188 | path) 189 | 190 | output_pcs = [] 191 | output_grasps = [] 192 | output_qualities = [] 193 | output_labels = [] 194 | output_pc_poses = [] 195 | output_cad_paths = [cad_path] * self.opt.num_grasps_per_object 196 | output_cad_scales = np.asarray( 197 | [cad_scale] * self.opt.num_grasps_per_object, np.float32) 198 | 199 | num_positive = int(self.opt.num_grasps_per_object * 200 | self.ratio_positive) 201 | positive_clusters = self.sample_grasp_indexes(num_positive, pos_grasps, 202 | pos_qualities) 203 | num_negative = self.opt.num_grasps_per_object - num_positive 204 | negative_clusters = self.sample_grasp_indexes(num_negative, neg_grasps, 205 | neg_qualities) 206 | hard_neg_candidates = [] 207 | # Fill in Positive Examples. 208 | for positive_cluster in positive_clusters: 209 | selected_grasp = pos_grasps[positive_cluster[0]][ 210 | positive_cluster[1]] 211 | selected_quality = pos_qualities[positive_cluster[0]][ 212 | positive_cluster[1]] 213 | output_grasps.append(selected_grasp) 214 | output_qualities.append(selected_quality) 215 | output_labels.append(1) 216 | hard_neg_candidates += utils.perturb_grasp( 217 | selected_grasp, 218 | self.collision_hard_neg_num_perturbations, 219 | self.collision_hard_neg_min_translation, 220 | self.collision_hard_neg_max_translation, 221 | self.collision_hard_neg_min_rotation, 222 | self.collision_hard_neg_max_rotation, 223 | ) 224 | 225 | if verify_grasps: 226 | collisions, heuristic_qualities = utils.evaluate_grasps( 227 | output_grasps, obj_mesh) 228 | for computed_quality, expected_quality, g in zip( 229 | heuristic_qualities, output_qualities, output_grasps): 230 | err = abs(computed_quality - expected_quality) 231 | if err > 1e-3: 232 | raise ValueError( 233 | 'Heuristic does not match with the values from data generation {}!={}' 234 | .format(computed_quality, expected_quality)) 235 | 236 | # If queue does not have enough data, fill it up with hard negative examples from the positives. 237 | if path not in self.collision_hard_neg_queue or self.collision_hard_neg_queue[ 238 | path].qsize() < num_negative: 239 | if path not in self.collision_hard_neg_queue: 240 | self.collision_hard_neg_queue[path] = Queue() 241 | #hard negatives are perturbations of correct grasps. 
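# With probability ratio_hardnegative the queue is refilled with perturbed
# grasps that either collide with the object or score below 0.001 on the
# heuristic quality; otherwise, or if the queue is still too short, the
# flex-negative grasps sampled above are queued instead.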
242 | random_selector = np.random.rand() 243 | if random_selector < self.ratio_hardnegative: 244 | #print('add hard neg') 245 | collisions, heuristic_qualities = utils.evaluate_grasps( 246 | hard_neg_candidates, obj_mesh) 247 | hard_neg_mask = collisions | (heuristic_qualities < 0.001) 248 | hard_neg_indexes = np.where(hard_neg_mask)[0].tolist() 249 | np.random.shuffle(hard_neg_indexes) 250 | for index in hard_neg_indexes: 251 | self.collision_hard_neg_queue[path].put( 252 | (hard_neg_candidates[index], -1.0)) 253 | if random_selector >= self.ratio_hardnegative or self.collision_hard_neg_queue[ 254 | path].qsize() < num_negative: 255 | for negative_cluster in negative_clusters: 256 | selected_grasp = neg_grasps[negative_cluster[0]][ 257 | negative_cluster[1]] 258 | selected_quality = neg_qualities[negative_cluster[0]][ 259 | negative_cluster[1]] 260 | self.collision_hard_neg_queue[path].put( 261 | (selected_grasp, selected_quality)) 262 | 263 | # Use negative examples from queue. 264 | for _ in range(num_negative): 265 | #print('qsize = ', self._collision_hard_neg_queue[file_path].qsize()) 266 | grasp, quality = self.collision_hard_neg_queue[path].get() 267 | output_grasps.append(grasp) 268 | output_qualities.append(quality) 269 | output_labels.append(0) 270 | 271 | for iter in range(self.opt.num_grasps_per_object): 272 | if iter > 0: 273 | output_pcs.append(np.copy(output_pcs[0])) 274 | output_pc_poses.append(np.copy(output_pc_poses[0])) 275 | else: 276 | pc, camera_pose, _ = self.change_object_and_render( 277 | cad_path, 278 | cad_scale, 279 | thread_id=torch.utils.data.get_worker_info().id 280 | if torch.utils.data.get_worker_info() else 0) 281 | #self.change_object(cad_path, cad_scale) 282 | #pc, camera_pose, _ = self.render_random_scene() 283 | 284 | output_pcs.append(pc) 285 | output_pc_poses.append(utils.inverse_transform(camera_pose)) 286 | 287 | output_grasps[iter] = camera_pose.dot(output_grasps[iter]) 288 | 289 | output_pcs = np.asarray(output_pcs, dtype=np.float32) 290 | output_grasps = np.asarray(output_grasps, dtype=np.float32) 291 | output_labels = np.asarray(output_labels, dtype=np.int32) 292 | output_qualities = np.asarray(output_qualities, dtype=np.float32) 293 | output_pc_poses = np.asarray(output_pc_poses, dtype=np.float32) 294 | return output_pcs, output_grasps, output_labels, output_qualities, output_pc_poses, output_cad_paths, output_cad_scales 295 | -------------------------------------------------------------------------------- /data/grasp_sampling_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data.base_dataset import BaseDataset, NoPositiveGraspsException 4 | import numpy as np 5 | from utils import utils 6 | 7 | 8 | class GraspSamplingData(BaseDataset): 9 | def __init__(self, opt): 10 | BaseDataset.__init__(self, opt) 11 | self.opt = opt 12 | self.device = torch.device('cuda:{}'.format( 13 | opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu') 14 | self.root = opt.dataset_root_folder 15 | self.paths = self.make_dataset() 16 | self.size = len(self.paths) 17 | #self.get_mean_std() 18 | opt.input_nc = self.ninput_channels 19 | self.i = 0 20 | 21 | def __getitem__(self, index): 22 | path = self.paths[index] 23 | pos_grasps, pos_qualities, _, _, _, cad_path, cad_scale = self.read_grasp_file( 24 | path) 25 | meta = {} 26 | try: 27 | all_clusters = self.sample_grasp_indexes( 28 | self.opt.num_grasps_per_object, pos_grasps, pos_qualities) 29 | except NoPositiveGraspsException: 30 | if 
self.opt.skip_error: 31 | return None 32 | else: 33 | return self.__getitem__(np.random.randint(0, self.size)) 34 | 35 | #self.change_object(cad_path, cad_scale) 36 | #pc, camera_pose, _ = self.render_random_scene() 37 | pc, camera_pose, _ = self.change_object_and_render( 38 | cad_path, 39 | cad_scale, 40 | thread_id=torch.utils.data.get_worker_info().id 41 | if torch.utils.data.get_worker_info() else 0) 42 | 43 | output_qualities = [] 44 | output_grasps = [] 45 | for iter in range(self.opt.num_grasps_per_object): 46 | selected_grasp_index = all_clusters[iter] 47 | 48 | selected_grasp = pos_grasps[selected_grasp_index[0]][ 49 | selected_grasp_index[1]] 50 | selected_quality = pos_qualities[selected_grasp_index[0]][ 51 | selected_grasp_index[1]] 52 | output_qualities.append(selected_quality) 53 | output_grasps.append(camera_pose.dot(selected_grasp)) 54 | gt_control_points = utils.transform_control_points_numpy( 55 | np.array(output_grasps), self.opt.num_grasps_per_object, mode='rt') 56 | 57 | meta['pc'] = np.array([pc] * self.opt.num_grasps_per_object)[:, :, :3] 58 | meta['grasp_rt'] = np.array(output_grasps).reshape( 59 | len(output_grasps), -1) 60 | 61 | meta['pc_pose'] = np.array([utils.inverse_transform(camera_pose)] * 62 | self.opt.num_grasps_per_object) 63 | meta['cad_path'] = np.array([cad_path] * 64 | self.opt.num_grasps_per_object) 65 | meta['cad_scale'] = np.array([cad_scale] * 66 | self.opt.num_grasps_per_object) 67 | meta['quality'] = np.array(output_qualities) 68 | meta['target_cps'] = np.array(gt_control_points[:, :, :3]) 69 | return meta 70 | 71 | def __len__(self): 72 | return self.size -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/__init__.py -------------------------------------------------------------------------------- /demo/data/blue_mug.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/blue_mug.npy -------------------------------------------------------------------------------- /demo/data/cheezit.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/cheezit.npy -------------------------------------------------------------------------------- /demo/data/cylinder.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/cylinder.npy -------------------------------------------------------------------------------- /demo/data/green_bowl.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/green_bowl.npy -------------------------------------------------------------------------------- /demo/data/mustard.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/mustard.npy 
-------------------------------------------------------------------------------- /demo/data/pepper.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/pepper.npy -------------------------------------------------------------------------------- /demo/data/red_mug.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/red_mug.npy -------------------------------------------------------------------------------- /demo/data/sugar.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/sugar.npy -------------------------------------------------------------------------------- /demo/data/white_bowl.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/white_bowl.npy -------------------------------------------------------------------------------- /demo/data/white_mug.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/data/white_mug.npy -------------------------------------------------------------------------------- /demo/examples/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/examples/1.png -------------------------------------------------------------------------------- /demo/examples/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/demo/examples/2.png -------------------------------------------------------------------------------- /demo/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import argparse 5 | import grasp_estimator 6 | import sys 7 | import os 8 | import glob 9 | import mayavi.mlab as mlab 10 | from utils.visualization_utils import * 11 | import mayavi.mlab as mlab 12 | from utils import utils 13 | from data import DataLoader 14 | 15 | 16 | def make_parser(): 17 | parser = argparse.ArgumentParser( 18 | description='6-DoF GraspNet Demo', 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | parser.add_argument('--grasp_sampler_folder', 21 | type=str, 22 | default='checkpoints/gan_pretrained/') 23 | parser.add_argument('--grasp_evaluator_folder', 24 | type=str, 25 | default='checkpoints/evaluator_pretrained/') 26 | parser.add_argument('--refinement_method', 27 | choices={"gradient", "sampling"}, 28 | default='sampling') 29 | parser.add_argument('--refine_steps', type=int, default=25) 30 | 31 | parser.add_argument('--npy_folder', type=str, default='demo/data/') 32 | parser.add_argument( 33 | '--threshold', 34 | type=float, 35 | default=0.8, 36 | help= 37 | "When choose_fn is something else than all, all grasps with a score given by the evaluator notwork less than the threshold are 
removed" 38 | ) 39 | parser.add_argument( 40 | '--choose_fn', 41 | choices={ 42 | "all", "better_than_threshold", "better_than_threshold_in_sequence" 43 | }, 44 | default='better_than_threshold', 45 | help= 46 | "If all, no grasps are removed. If better than threshold, only the last refined grasps are considered while better_than_threshold_in_sequence consideres all refined grasps" 47 | ) 48 | 49 | parser.add_argument('--target_pc_size', type=int, default=1024) 50 | parser.add_argument('--num_grasp_samples', type=int, default=200) 51 | parser.add_argument( 52 | '--generate_dense_grasps', 53 | action='store_true', 54 | help= 55 | "If enabled, it will create a [num_grasp_samples x num_grasp_samples] dense grid of latent space values and generate grasps from these." 56 | ) 57 | 58 | parser.add_argument( 59 | '--batch_size', 60 | type=int, 61 | default=30, 62 | help= 63 | "Set the batch size of the number of grasps we want to process and can fit into the GPU memory at each forward pass. The batch_size can be increased for a GPU with more memory." 64 | ) 65 | parser.add_argument('--train_data', action='store_true') 66 | opts, _ = parser.parse_known_args() 67 | if opts.train_data: 68 | parser.add_argument('--dataset_root_folder', 69 | required=True, 70 | type=str, 71 | help='path to root directory of the dataset.') 72 | return parser 73 | 74 | 75 | def get_color_for_pc(pc, K, color_image): 76 | proj = pc.dot(K.T) 77 | proj[:, 0] /= proj[:, 2] 78 | proj[:, 1] /= proj[:, 2] 79 | 80 | pc_colors = np.zeros((pc.shape[0], 3), dtype=np.uint8) 81 | for i, p in enumerate(proj): 82 | x = int(p[0]) 83 | y = int(p[1]) 84 | pc_colors[i, :] = color_image[y, x, :] 85 | 86 | return pc_colors 87 | 88 | 89 | def backproject(depth_cv, 90 | intrinsic_matrix, 91 | return_finite_depth=True, 92 | return_selection=False): 93 | 94 | depth = depth_cv.astype(np.float32, copy=True) 95 | 96 | # get intrinsic matrix 97 | K = intrinsic_matrix 98 | Kinv = np.linalg.inv(K) 99 | 100 | # compute the 3D points 101 | width = depth.shape[1] 102 | height = depth.shape[0] 103 | 104 | # construct the 2D points matrix 105 | x, y = np.meshgrid(np.arange(width), np.arange(height)) 106 | ones = np.ones((height, width), dtype=np.float32) 107 | x2d = np.stack((x, y, ones), axis=2).reshape(width * height, 3) 108 | 109 | # backprojection 110 | R = np.dot(Kinv, x2d.transpose()) 111 | 112 | # compute the 3D points 113 | X = np.multiply(np.tile(depth.reshape(1, width * height), (3, 1)), R) 114 | X = np.array(X).transpose() 115 | if return_finite_depth: 116 | selection = np.isfinite(X[:, 0]) 117 | X = X[selection, :] 118 | 119 | if return_selection: 120 | return X, selection 121 | 122 | return X 123 | 124 | 125 | def main(args): 126 | parser = make_parser() 127 | args = parser.parse_args() 128 | grasp_sampler_args = utils.read_checkpoint_args(args.grasp_sampler_folder) 129 | grasp_sampler_args.is_train = False 130 | grasp_evaluator_args = utils.read_checkpoint_args( 131 | args.grasp_evaluator_folder) 132 | grasp_evaluator_args.continue_train = True 133 | estimator = grasp_estimator.GraspEstimator(grasp_sampler_args, 134 | grasp_evaluator_args, args) 135 | if args.train_data: 136 | grasp_sampler_args.dataset_root_folder = args.dataset_root_folder 137 | grasp_sampler_args.num_grasps_per_object = 1 138 | grasp_sampler_args.num_objects_per_batch = 1 139 | dataset = DataLoader(grasp_sampler_args) 140 | for i, data in enumerate(dataset): 141 | generated_grasps, generated_scores = estimator.generate_and_refine_grasps( 142 | data["pc"].squeeze()) 143 | 
mlab.figure(bgcolor=(1, 1, 1)) 144 | draw_scene(data["pc"][0], 145 | grasps=generated_grasps, 146 | grasp_scores=generated_scores) 147 | print('close the window to continue to next object . . .') 148 | mlab.show() 149 | else: 150 | for npy_file in glob.glob(os.path.join(args.npy_folder, '*.npy')): 151 | # Depending on your numpy version you may need to change allow_pickle 152 | # from True to False. 153 | 154 | data = np.load(npy_file, allow_pickle=True, 155 | encoding="latin1").item() 156 | 157 | depth = data['depth'] 158 | image = data['image'] 159 | K = data['intrinsics_matrix'] 160 | # Removing points that are farther than 1 meter or missing depth 161 | # values. 162 | #depth[depth == 0 or depth > 1] = np.nan 163 | 164 | np.nan_to_num(depth, copy=False) 165 | mask = np.where(np.logical_or(depth == 0, depth > 1)) 166 | depth[mask] = np.nan 167 | pc, selection = backproject(depth, 168 | K, 169 | return_finite_depth=True, 170 | return_selection=True) 171 | pc_colors = image.copy() 172 | pc_colors = np.reshape(pc_colors, [-1, 3]) 173 | pc_colors = pc_colors[selection, :] 174 | 175 | # Smoothed pc comes from averaging the depth for 10 frames and removing 176 | # the pixels with jittery depth between those 10 frames. 177 | object_pc = data['smoothed_object_pc'] 178 | generated_grasps, generated_scores = estimator.generate_and_refine_grasps( 179 | object_pc) 180 | mlab.figure(bgcolor=(1, 1, 1)) 181 | draw_scene( 182 | pc, 183 | pc_color=pc_colors, 184 | grasps=generated_grasps, 185 | grasp_scores=generated_scores, 186 | ) 187 | print('close the window to continue to next object . . .') 188 | mlab.show() 189 | 190 | 191 | if __name__ == '__main__': 192 | main(sys.argv[1:]) 193 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | from __future__ import print_function 9 | 10 | import numpy as np 11 | import yaml 12 | import argparse 13 | import numpy 14 | import grasp_estimator 15 | import copy 16 | import sys 17 | import os 18 | import grasp_data_reader 19 | import torch 20 | import glob 21 | import sample 22 | import json 23 | import subprocess 24 | import time 25 | import datetime 26 | import os 27 | from sklearn.metrics import precision_recall_curve, average_precision_score 28 | from scipy import spatial 29 | import shutil 30 | from utils import utils 31 | RADIUS = 0.02 32 | 33 | 34 | def default(obj): 35 | if type(obj).__module__ == np.__name__: 36 | if isinstance(obj, np.ndarray): 37 | return obj.tolist() 38 | else: 39 | return obj.item() 40 | raise TypeError('Unknown type:', type(obj)) 41 | 42 | 43 | def create_directory(path, delete_if_exist=False): 44 | if not os.path.isdir(path): 45 | os.makedirs(path) 46 | else: 47 | if delete_if_exist: 48 | print('***************** deleting folder ', path) 49 | shutil.rmtree(path) 50 | os.makedirs(path) 51 | 52 | 53 | def make_parser(argv): 54 | """ 55 | Outputs a parser. 
56 | """ 57 | parser = argparse.ArgumentParser( 58 | description='Evaluators', 59 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 60 | parser.add_argument('--grasp_sampler_folder', type=str, default='') 61 | parser.add_argument('--grasp_evaluator_folder', type=str, default='') 62 | parser.add_argument('--eval_data_folder', type=str, default='') 63 | parser.add_argument('--generate_data_if_missing', type=int, default=0) 64 | parser.add_argument('--dataset_root_folder', type=str, default='') 65 | parser.add_argument('--num_experiments', type=int, default=100) 66 | parser.add_argument('--eval_split', type=str, default='test') 67 | parser.add_argument('--eval_grasp_evaluator', type=int, default=0) 68 | parser.add_argument('--eval_vae_and_evaluator', type=int, default=1) 69 | parser.add_argument('--output_folder', type=str, default='') 70 | parser.add_argument('--gradient_based_refinement', 71 | action='store_true', 72 | default=False) 73 | 74 | return parser.parse_args(argv[1:]) 75 | 76 | 77 | class Evaluator(): 78 | def __init__(self, 79 | cfg, 80 | create_data_if_not_exist, 81 | output_folder, 82 | eval_experiment_folder, 83 | num_experiments, 84 | eval_grasp_evaluator=False, 85 | eval_vae_and_evaluator=True): 86 | self._should_create_data = create_data_if_not_exist 87 | self._eval_experiment_folder = eval_experiment_folder 88 | self._num_experiments = num_experiments 89 | self._grasp_reader = grasp_data_reader.PointCloudReader( 90 | root_folder=cfg.dataset_root_folder, 91 | batch_size=cfg.num_grasps_per_object, 92 | num_grasp_clusters=cfg.num_grasp_clusters, 93 | npoints=cfg.npoints, 94 | min_difference_allowed=(0, 0, 0), 95 | max_difference_allowed=(3, 3, 0), 96 | occlusion_nclusters=0, 97 | occlusion_dropout_rate=0., 98 | use_uniform_quaternions=False, 99 | ratio_of_grasps_used=1, 100 | ) 101 | self._cfg = cfg 102 | self._grasp_estimator = grasp_estimator.GraspEstimator(cfg) 103 | os.environ['CUDA_VISIBLE_DEVICES'] = str(self._cfg.gpu) 104 | self._sess = tf.Session() 105 | del os.environ['CUDA_VISIBLE_DEVICES'] 106 | self._grasp_estimator.build_network() 107 | self._eval_grasp_evaluator = eval_grasp_evaluator 108 | self._eval_vae_and_evaluator = eval_vae_and_evaluator 109 | self._flex_initialized = False 110 | self._output_folder = output_folder 111 | self.update_time_stamp() 112 | 113 | def read_eval_scene(self, file_path, visualize=False): 114 | if not os.path.isfile(file_path): 115 | if not self._should_create_data: 116 | raise ValueError('could not find data {}'.format(file_path)) 117 | 118 | json_path = self._grasp_reader.generate_object_set( 119 | self._cfg.eval_split) 120 | obj_grasp_data = self._grasp_reader.read_grasp_file( 121 | os.path.join(self._cfg.dataset_root_folder, json_path), True) 122 | obj_pose = self._grasp_reader.arrange_objects( 123 | obj_grasp_data[-3])[0] 124 | in_camera_pose = None 125 | print('changing object to ', obj_grasp_data[-2]) 126 | self._grasp_reader.change_object(obj_grasp_data[-2], 127 | obj_grasp_data[-1]) 128 | pc, camera_pose, in_camera_pose = self._grasp_reader.render_random_scene( 129 | None) 130 | folder_path = file_path[:file_path.rfind('/')] 131 | create_directory(folder_path) 132 | 133 | print('writing {}'.format(file_path)) 134 | np.save(file_path, {'json': json_path, 'obj_pose': obj_pose, 'camera_pose': in_camera_pose}) 135 | else: 136 | 137 | 138 | 139 | 140 | d = np.load(file_path).item() 141 | json_path = d['json'] 142 | obj_pose = d['obj_pose'] 143 | obj_grasp_data = self._grasp_reader.read_grasp_file( 144 | 
os.path.join(self._cfg.dataset_root_folder, json_path), True) 145 | in_camera_pose = d['camera_pose'] 146 | self._grasp_reader.change_object(obj_grasp_data[-2], obj_grasp_data[-1]) 147 | pc, camera_pose, _ = self._grasp_reader.render_random_scene(in_camera_pose) 148 | 149 | 150 | 151 | pos_grasps = np.matmul(np.expand_dims(camera_pose, 0), obj_grasp_data[0]) 152 | neg_grasps = np.matmul(np.expand_dims(camera_pose, 0), 153 | obj_grasp_data[2]) 154 | grasp_labels = np.hstack( 155 | 156 | (np.ones(pos_grasps.shape[0]), np.zeros(neg_grasps.shape[0]))).astype(np.int32) 157 | grasps = np.concatenate((pos_grasps, neg_grasps), 0) 158 | 159 | 160 | if visualize: 161 | from visualization_utils import draw_scene 162 | import mayavi.mlab as mlab 163 | 164 | pos_mask = np.logical_and(grasp_labels == 1, np.random.rand(*grasp_labels.shape) < 0.1) 165 | neg_mask = np.logical_and( 166 | grasp_labels == 0, 167 | np.random.rand(*grasp_labels.shape) < 0.01) 168 | 169 | 170 | 171 | print(grasps[pos_mask, :, :].shape, grasps[neg_mask, :, :].shape) 172 | draw_scene(pc, grasps[pos_mask, :, :]) 173 | mlab.show() 174 | 175 | draw_scene(pc, grasps[neg_mask, :, :]) 176 | mlab.show() 177 | 178 | return pc[:, :3], grasps, grasp_labels, {'cad_path': obj_grasp_data[-2], 'cad_scale': obj_grasp_data[-1], 'to_canonical_transformation': grasp_data_reader.inverse_transform(camera_pose)} 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | def eval_scene(self, file_path, visualize=False): 188 | """ 189 | Returns full_results, evaluator_results. 190 | full_results: Contains information about grasps in canonical pose, scores, 191 | ground truth positive grasps, and also cad path and scale that is used for 192 | flex evaluation. 193 | evaluator_results: Only contains information for the classification of positive 194 | and negative grasps. The info is gt label of each grasp, predicted score for 195 | each grasp, and the 4x4 transformation of each grasp.
196 | """ 197 | pc, grasps, grasps_label, flex_info = self.read_eval_scene(file_path) 198 | canonical_transform = flex_info['to_canonical_transformation'] 199 | evaluator_result = None 200 | full_results = None 201 | if self._eval_grasp_evaluator: 202 | latents = self._grasp_estimator.sample_latents() 203 | output_grasps, output_scores, _ = self._grasp_estimator.predict_grasps( 204 | self._sess, 205 | pc, latents, 0, grasps_rt=grasps) 206 | evaluator_result = (grasps_label, output_scores, output_grasps) 207 | if self._eval_vae_and_evaluator: 208 | latents = np.random.rand(self._cfg.num_samples, self._cfg.latent_size) * 4 - 2 209 | print(pc.shape) 210 | 211 | generated_grasps, generated_scores, _ = self._grasp_estimator.predict_grasps( 212 | self._sess, 213 | pc, 214 | latents, 215 | num_refine_steps=self._cfg.num_refine_steps, 216 | ) 217 | 218 | gt_pos_grasps = [g for g, l in zip(grasps, grasps_label) if l == 1] 219 | gt_pos_grasps = np.asarray(gt_pos_grasps).copy() 220 | gt_pos_grasps_canonical = np.matmul(canonical_transform, gt_pos_grasps) 221 | generated_grasps = np.asarray(generated_grasps) 222 | 223 | print(generated_grasps.shape) 224 | generated_grasps_canonical = np.matmul(canonical_transform, generated_grasps) 225 | 226 | 227 | obj = sample.Object(flex_info['cad_path']) 228 | obj.rescale(flex_info['cad_scale']) 229 | mesh = obj.mesh 230 | mesh_mean = np.mean(mesh.vertices, 0, keepdims=True) 231 | 232 | canonical_pc = pc.dot(canonical_transform[:3, :3].T) 233 | canonical_pc += np.expand_dims(canonical_transform[:3, 3], 0) 234 | 235 | gt_pos_grasps_canonical[:, :3, 3] += mesh_mean 236 | canonical_pc += mesh_mean 237 | generated_grasps_canonical[:, :3, 3] += mesh_mean 238 | 239 | if visualize: 240 | from visualization_utils import draw_scene 241 | import mayavi.mlab as mlab 242 | 243 | draw_scene(canonical_pc, grasps=gt_pos_grasps_canonical, mesh=mesh) 244 | mlab.show() 245 | 246 | draw_scene(canonical_pc, grasps=generated_grasps_canonical, grasp_scores=generated_scores, mesh=mesh) 247 | 248 | 249 | mlab.show() 250 | 251 | 252 | 253 | 254 | full_results = (generated_grasps_canonical, generated_scores, gt_pos_grasps_canonical, flex_info['cad_path'], flex_info['cad_scale']) 255 | 256 | 257 | 258 | return full_results, evaluator_result 259 | 260 | def update_time_stamp(self): 261 | self._signature = self._get_current_time_stamp() 262 | 263 | def _get_current_time_stamp(self): 264 | """No Comments.""" 265 | now = datetime.datetime.now() 266 | return now.strftime("%Y-%m-%d_%H-%M") 267 | 268 | def eval_all(self, plot_curves): 269 | """ 270 | Evaluates all of the test scenes. 271 | 272 | Args: 273 | plot_curves: bool, if True, plots the corresponding figure 274 | for each of the evaluations.
275 | """ 276 | self._grasp_estimator.load_weights(self._sess) 277 | 278 | create_directory(self._eval_experiment_folder) 279 | 280 | num_digits = len(str(self._num_experiments)) 281 | 282 | all_eval_results = [] 283 | all_full_results = [] 284 | 285 | for i in range(self._num_experiments): 286 | full_result, eval_result = self.eval_scene(os.path.join(self._eval_experiment_folder, 'eval_configs', str(i).zfill(num_digits)) + '.npy', False) 287 | if full_result is not None: 288 | 289 | 290 | grasps, scores, gt_grasps, cad_path, cad_scale = full_result 291 | experiment_folder = os.path.join(self._eval_experiment_folder, 'flex_folder', str(i).zfill(num_digits)) 292 | flex_outcomes = self.eval_grasps_on_flex(grasps, cad_path, cad_scale, experiment_folder) 293 | 294 | 295 | all_full_results.append((grasps, scores, 296 | flex_outcomes, gt_grasps)) 297 | 298 | 299 | if eval_result is not None: 300 | all_eval_results.append(eval_result) 301 | 302 | if len(all_eval_results) > 0: 303 | self.metric_classification_mean_ap( 304 | [x[0] for x in all_eval_results], 305 | [x[1] for x in all_eval_results], 306 | plot_curves) 307 | if len(all_full_results) > 0: 308 | self.metric_coverage_success_rate( 309 | [x[0] for x in all_full_results], [x[1] for x in all_full_results], 310 | [x[2] for x in all_full_results], 311 | [x[3] for x in all_full_results], 312 | plot_curves 313 | ) 314 | def metric_classification_mean_ap(self, gt_labels_list, scores_list, visualize): 315 | """ 316 | 317 | Computes the average precision metric for evaluator. 318 | 319 | Args: 320 | gt_labels_list: list of binary numbers indicating the success of 321 | grasps. 1 means grasp is successful and 0 means failure. 322 | scores_list: list of float numbers for the score of each grasp. 323 | visualize: bool, if True, visualizes the plots. 324 | 325 | Returns: 326 | average_precision: area under the curve for precision-recall plot. 327 | best_threshold: float, best threshold that has the highest f-1 328 | measure.
329 | """ 330 | all_gt_labels = [] 331 | all_scores = [] 332 | if len(gt_labels_list) != len(scores_list): 333 | raise ValueError("Length of the lists should match") 334 | 335 | for gt_labels in gt_labels_list: 336 | all_gt_labels += [l for l in gt_labels] 337 | for scores in scores_list: 338 | all_scores += [s for s in scores] 339 | 340 | precision, recall, thresholds = precision_recall_curve(all_gt_labels, all_scores) 341 | average_precision = average_precision_score(all_gt_labels, all_scores) 342 | 343 | f1_score = 2 * (precision * recall) / (precision + recall) 344 | 345 | best_threshold = thresholds[np.argmax(f1_score)] 346 | if visualize: 347 | import matplotlib.pyplot as plt 348 | plt.plot(recall, precision) 349 | 350 | plt.xlabel('Recall') 351 | plt.ylabel('Precision') 352 | plt.ylim([0.0, 1.05]) 353 | plt.xlim([0.0, 1.0]) 354 | plt.title('2-class Precision-Recall curve: AP={0:.2f}, best_threshold = {1:.2f}'.format( 355 | average_precision, best_threshold)) 356 | 357 | plt.show() 358 | 359 | np.save( 360 | os.path.join(self._output_folder, '{}_evaluator.npy'.format(self._signature)), 361 | {'cfg': self._cfg, 'precisions': precision, 'recalls': recall, 362 | 'average_precision': average_precision, 'best_threshold': best_threshold} 363 | ) 364 | 365 | 366 | 367 | 368 | 369 | return average_precision, best_threshold 370 | 371 | def metric_coverage_success_rate(self, grasps_list, scores_list, flex_outcomes_list, gt_grasps_list, visualize): 372 | """ 373 | 374 | 375 | Computes the coverage success rate for grasps of multiple objects. 376 | 377 | Args: 378 | grasps_list: list of numpy array, each numpy array is the predicted 379 | grasps for each object. Each numpy array has shape (n, 4, 4) where 380 | n is the number of predicted grasps for each object. 381 | scores_list: list of numpy array, each numpy array is the predicted 382 | scores for each grasp of the corresponding object. 383 | flex_outcomes_list: list of numpy array, each element of the numpy 384 | array indicates whether that grasp succeeds in grasping the object 385 | or not. 386 | gt_grasps_list: list of numpy array. Each numpy array has shape of 387 | (m, 4, 4) where m is the number of groundtruth grasps for each 388 | object. 389 | visualize: bool. If True, it will plot the curve. 390 | 391 | Returns: 392 | auc: float, area under the curve for the success-coverage plot.
393 | """ 394 | all_trees = [] 395 | all_grasps = [] 396 | all_object_indexes = [] 397 | all_scores = [] 398 | all_flex_outcomes = [] 399 | visited = set() 400 | tot_num_gt_grasps = 0 401 | for i in range(len(grasps_list)): 402 | print('building kd-tree {}/{}'.format(i, len(grasps_list))) 403 | gt_grasps = np.asarray(gt_grasps_list[i]).copy() 404 | all_trees.append(spatial.KDTree(gt_grasps[:, :3, 3])) 405 | tot_num_gt_grasps += gt_grasps.shape[0] 406 | 407 | for g, s, f in zip(grasps_list[i], scores_list[i], flex_outcomes_list[i]): 408 | all_grasps.append(np.asarray(g).copy()) 409 | 410 | all_object_indexes.append(i) 411 | all_scores.append(s) 412 | all_flex_outcomes.append(f) 413 | 414 | all_grasps = np.asarray(all_grasps) 415 | 416 | all_scores = np.asarray(all_scores) 417 | order = np.argsort(-all_scores) 418 | num_covered_so_far = 0 419 | correct_grasps_so_far = 0 420 | num_visited_grasps_so_far = 0 421 | 422 | precisions = [] 423 | recalls = [] 424 | prev_score = None 425 | 426 | for oindex, index in enumerate(order): 427 | if oindex % 1000 == 0: 428 | print(oindex, len(order)) 429 | 430 | object_id = all_object_indexes[index] 431 | close_indexes = all_trees[object_id].query_ball_point(all_grasps[index, :3, 3], RADIUS) 432 | 433 | 434 | num_new_covered_gt_grasps = 0 435 | 436 | for close_index in close_indexes: 437 | key = (object_id, close_index) 438 | if key in visited: 439 | continue 440 | 441 | visited.add(key) 442 | num_new_covered_gt_grasps += 1 443 | 444 | correct_grasps_so_far += all_flex_outcomes[index] 445 | num_visited_grasps_so_far += 1 446 | num_covered_so_far += num_new_covered_gt_grasps 447 | 448 | if prev_score is not None and abs(prev_score - all_scores[index]) < 1e-3: 449 | precisions[-1] = float(correct_grasps_so_far) / num_visited_grasps_so_far 450 | 451 | recalls[-1] = float(num_covered_so_far) / tot_num_gt_grasps 452 | 453 | else: 454 | precisions.append(float(correct_grasps_so_far) / num_visited_grasps_so_far) 455 | recalls.append(float(num_covered_so_far) / tot_num_gt_grasps) 456 | 457 | prev_score = all_scores[index] 458 | 459 | auc = 0 460 | for i in range(1, len(precisions)): 461 | auc += (recalls[i] - recalls[i-1]) * (precisions[i] + precisions[i-1]) * 0.5 462 | 463 | 464 | if visualize: 465 | import matplotlib.pyplot as plt 466 | plt.plot(recalls, precisions) 467 | plt.title('auc = {0:02f}'.format(auc)) 468 | plt.ylim([0.0, 1.05]) 469 | plt.xlim([0.0, 1.0]) 470 | plt.show() 471 | 472 | print('auc = {}'.format(auc)) 473 | np.save( 474 | os.path.join(self._output_folder, '{}_vae+evaluator.npy'.format(self._signature)), 475 | {'precisions': precisions, 'recalls': recalls, 'auc': auc, 'cfg': self._cfg} 476 | ) 477 | 478 | 479 | 480 | 481 | 482 | return auc 483 | 484 | def eval_grasps_on_flex(self, grasps, cad_path, cad_scale, experiment_folder): 485 | """ 486 | 487 | Evaluates the grasps on the flex physics engine and determines whether the 488 | grasps will succeed or not. 489 | 490 | Args: 491 | grasps: numpy array list of grasps for an object. 492 | cad_path: string, path to the obj/stl file of the object. 493 | cad_scale: float, the scale that is applied to the mesh of the 494 | object. 495 | experiment_folder: the folder that is used to copy the temp files 496 | necessary for running the jobs and also aggregating the results. 497 | 498 | Returns: 499 | grasp_success: list of binary numbers. 0 means that the grasp failed, 500 | and 1 means that the grasp succeeded.
501 | """ 502 | raise NotImplementedError("The code for grasp evaluation is not released") 503 | 504 | 505 | def __del__(self): 506 | del self._grasp_reader 507 | 508 | 509 | if __name__ == '__main__': 510 | args = make_parser(sys.argv) 511 | utils.mkdir(args.output_folder) 512 | 513 | grasp_sampler_args = utils.read_checkpoint_args(args.grasp_sampler_folder) 514 | grasp_sampler_args.is_train = False 515 | grasp_evaluator_args = utils.read_checkpoint_args( 516 | args.grasp_evaluator_folder) 517 | grasp_evaluator_args.continue_train = True 518 | if args.gradient_based_refinement: 519 | args.num_refine_steps = 10 520 | args.refinement = "gradient" 521 | else: 522 | args.num_refine_steps = 20 523 | args.refinement = "sampling" 524 | 525 | estimator = grasp_estimator.GraspEstimator(grasp_sampler_args, 526 | grasp_evaluator_args, args) 527 | evaluator = Evaluator( 528 | cfg, 529 | args.generate_data_if_missing, 530 | args.output_folder, 531 | args.eval_data_folder, 532 | args.num_experiments, 533 | eval_grasp_evaluator=args.eval_grasp_evaluator, 534 | eval_vae_and_evaluator=args.eval_vae_and_evaluator, 535 | ) 536 | evaluator.eval_all(True) 537 | del evaluator 538 | -------------------------------------------------------------------------------- /grasp_estimator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from models import create_model 4 | import numpy as np 5 | import torch 6 | import time 7 | import trimesh 8 | import trimesh.transformations as tra 9 | #import surface_normal 10 | import copy 11 | import os 12 | from utils import utils 13 | 14 | 15 | class GraspEstimator: 16 | """ 17 | Includes the code used for running the inference. 18 | """ 19 | def __init__(self, grasp_sampler_opt, grasp_evaluator_opt, opt): 20 | self.grasp_sampler_opt = grasp_sampler_opt 21 | self.grasp_evaluator_opt = grasp_evaluator_opt 22 | self.opt = opt 23 | self.target_pc_size = opt.target_pc_size 24 | self.num_refine_steps = opt.refine_steps 25 | self.refine_method = opt.refinement_method 26 | self.threshold = opt.threshold 27 | self.batch_size = opt.batch_size 28 | self.generate_dense_grasps = opt.generate_dense_grasps 29 | if self.generate_dense_grasps: 30 | self.num_grasps_per_dim = opt.num_grasp_samples 31 | self.num_grasp_samples = opt.num_grasp_samples * opt.num_grasp_samples 32 | else: 33 | self.num_grasp_samples = opt.num_grasp_samples 34 | self.choose_fn = opt.choose_fn 35 | self.choose_fns = { 36 | "all": 37 | None, 38 | "better_than_threshold": 39 | utils.choose_grasps_better_than_threshold, 40 | "better_than_threshold_in_sequence": 41 | utils.choose_grasps_better_than_threshold_in_sequence, 42 | } 43 | self.device = torch.device("cuda:0") 44 | self.grasp_evaluator = create_model(grasp_evaluator_opt) 45 | self.grasp_sampler = create_model(grasp_sampler_opt) 46 | 47 | def keep_inliers(self, grasps, confidences, z, pc, inlier_indices_list): 48 | for i, inlier_indices in enumerate(inlier_indices_list): 49 | grasps[i] = grasps[i][inlier_indices] 50 | confidences[i] = confidences[i][inlier_indices] 51 | z[i] = z[i][inlier_indices] 52 | pc[i] = pc[i][inlier_indices] 53 | 54 | def generate_and_refine_grasps( 55 | self, 56 | pc, 57 | ): 58 | pc_list, pc_mean = self.prepare_pc(pc) 59 | grasps_list, confidence_list, z_list = self.generate_grasps(pc_list) 60 | inlier_indices = utils.get_inlier_grasp_indices(grasps_list, 61 | torch.zeros(1, 3).to( 62 | self.device), 63 | threshold=1.0, 64 | device=self.device) 65 | 
self.keep_inliers(grasps_list, confidence_list, z_list, pc_list, 66 | inlier_indices) 67 | improved_eulers, improved_ts, improved_success = [], [], [] 68 | for pc, grasps in zip(pc_list, grasps_list): 69 | out = self.refine_grasps(pc, grasps, self.refine_method, 70 | self.num_refine_steps) 71 | improved_eulers.append(out[0]) 72 | improved_ts.append(out[1]) 73 | improved_success.append(out[2]) 74 | improved_eulers = np.hstack(improved_eulers) 75 | improved_ts = np.hstack(improved_ts) 76 | improved_success = np.hstack(improved_success) 77 | if self.choose_fn is "all": 78 | selection_mask = np.ones(improved_success.shape, dtype=np.float32) 79 | else: 80 | selection_mask = self.choose_fns[self.choose_fn](improved_eulers, 81 | improved_ts, 82 | improved_success, 83 | self.threshold) 84 | grasps = utils.rot_and_trans_to_grasps(improved_eulers, improved_ts, 85 | selection_mask) 86 | utils.denormalize_grasps(grasps, pc_mean) 87 | refine_indexes, sample_indexes = np.where(selection_mask) 88 | success_prob = improved_success[refine_indexes, 89 | sample_indexes].tolist() 90 | return grasps, success_prob 91 | 92 | def prepare_pc(self, pc): 93 | if pc.shape[0] > self.target_pc_size: 94 | pc = utils.regularize_pc_point_count(pc, self.target_pc_size) 95 | pc_mean = np.mean(pc, 0) 96 | pc -= np.expand_dims(pc_mean, 0) 97 | pc = np.tile(pc, (self.num_grasp_samples, 1, 1)) 98 | pc = torch.from_numpy(pc).float().to(self.device) 99 | pcs = [] 100 | pcs = utils.partition_array_into_subarrays(pc, self.batch_size) 101 | return pcs, pc_mean 102 | 103 | def generate_grasps(self, pcs): 104 | all_grasps = [] 105 | all_confidence = [] 106 | all_z = [] 107 | if self.generate_dense_grasps: 108 | latent_samples = self.grasp_sampler.net.module.generate_dense_latents( 109 | self.num_grasps_per_dim) 110 | latent_samples = utils.partition_array_into_subarrays( 111 | latent_samples, self.batch_size) 112 | for latent_sample, pc in zip(latent_samples, pcs): 113 | grasps, confidence, z = self.grasp_sampler.generate_grasps( 114 | pc, latent_sample) 115 | all_grasps.append(grasps) 116 | all_confidence.append(confidence) 117 | all_z.append(z) 118 | else: 119 | for pc in pcs: 120 | grasps, confidence, z = self.grasp_sampler.generate_grasps(pc) 121 | all_grasps.append(grasps) 122 | all_confidence.append(confidence) 123 | all_z.append(z) 124 | return all_grasps, all_confidence, all_z 125 | 126 | def refine_grasps(self, pc, grasps, refine_method, num_refine_steps=10): 127 | 128 | grasp_eulers, grasp_translations = utils.convert_qt_to_rt(grasps) 129 | if refine_method == "gradient": 130 | improve_fun = self.improve_grasps_gradient_based 131 | grasp_eulers = torch.autograd.Variable(grasp_eulers.to( 132 | self.device), 133 | requires_grad=True) 134 | grasp_translations = torch.autograd.Variable(grasp_translations.to( 135 | self.device), 136 | requires_grad=True) 137 | 138 | else: 139 | improve_fun = self.improve_grasps_sampling_based 140 | 141 | improved_success = [] 142 | improved_eulers = [] 143 | improved_ts = [] 144 | improved_eulers.append(grasp_eulers.cpu().data.numpy()) 145 | improved_ts.append(grasp_translations.cpu().data.numpy()) 146 | last_success = None 147 | for i in range(num_refine_steps): 148 | success_prob, last_success = improve_fun(pc, grasp_eulers, 149 | grasp_translations, 150 | last_success) 151 | improved_success.append(success_prob.cpu().data.numpy()) 152 | improved_eulers.append(grasp_eulers.cpu().data.numpy()) 153 | improved_ts.append(grasp_translations.cpu().data.numpy()) 154 | 155 | # we need to run the 
success on the final improved grasps 156 | grasp_pcs = utils.control_points_from_rot_and_trans( 157 | grasp_eulers, grasp_translations, self.device) 158 | improved_success.append( 159 | self.grasp_evaluator.evaluate_grasps( 160 | pc, grasp_pcs).squeeze().cpu().data.numpy()) 161 | 162 | return np.asarray(improved_eulers), np.asarray( 163 | improved_ts), np.asarray(improved_success) 164 | 165 | def improve_grasps_gradient_based( 166 | self, pcs, grasp_eulers, grasp_trans, last_success 167 | ): #euler_angles, translation, eval_and_improve, metadata): 168 | grasp_pcs = utils.control_points_from_rot_and_trans( 169 | grasp_eulers, grasp_trans, self.device) 170 | 171 | success = self.grasp_evaluator.evaluate_grasps(pcs, grasp_pcs) 172 | success.squeeze().backward( 173 | torch.ones(success.shape[0]).to(self.device)) 174 | delta_t = grasp_trans.grad 175 | norm_t = torch.norm(delta_t, p=2, dim=-1).to(self.device) 176 | # Adjust the alpha so that it won't update more than 1 cm. Gradient is only valid 177 | # in small neighborhood. 178 | alpha = torch.min(0.01 / norm_t, torch.tensor(1.0).to(self.device)) 179 | grasp_trans.data += grasp_trans.grad * alpha[:, None] 180 | temp = grasp_eulers.clone() 181 | grasp_eulers.data += grasp_eulers.grad * alpha[:, None] 182 | return success.squeeze(), None 183 | 184 | def improve_grasps_sampling_based(self, 185 | pcs, 186 | grasp_eulers, 187 | grasp_trans, 188 | last_success=None): 189 | with torch.no_grad(): 190 | if last_success is None: 191 | grasp_pcs = utils.control_points_from_rot_and_trans( 192 | grasp_eulers, grasp_trans, self.device) 193 | last_success = self.grasp_evaluator.evaluate_grasps( 194 | pcs, grasp_pcs) 195 | 196 | delta_t = 2 * (torch.rand(grasp_trans.shape).to(self.device) - 0.5) 197 | delta_t *= 0.02 198 | delta_euler_angles = ( 199 | torch.rand(grasp_eulers.shape).to(self.device) - 0.5) * 2 200 | perturbed_translation = grasp_trans + delta_t 201 | perturbed_euler_angles = grasp_eulers + delta_euler_angles 202 | grasp_pcs = utils.control_points_from_rot_and_trans( 203 | perturbed_euler_angles, perturbed_translation, self.device) 204 | 205 | perturbed_success = self.grasp_evaluator.evaluate_grasps( 206 | pcs, grasp_pcs) 207 | ratio = perturbed_success / torch.max( 208 | last_success, 209 | torch.tensor(0.0001).to(self.device)) 210 | 211 | mask = torch.rand(ratio.shape).to(self.device) <= ratio 212 | 213 | next_success = last_success 214 | ind = torch.where(mask)[0] 215 | next_success[ind] = perturbed_success[ind] 216 | grasp_trans[ind].data = perturbed_translation.data[ind] 217 | grasp_eulers[ind].data = perturbed_euler_angles.data[ind] 218 | return last_success.squeeze(), next_success 219 | -------------------------------------------------------------------------------- /gripper_control_points/panda.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/gripper_control_points/panda.npy -------------------------------------------------------------------------------- /gripper_models/featuretype.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/gripper_models/featuretype.STL -------------------------------------------------------------------------------- /gripper_models/panda_gripper.obj: -------------------------------------------------------------------------------- 1 | v 
0.05421821303665638 0.007778378669172524 0.11076592143774033 2 | v 0.039979396890921635 -0.00613801833242178 0.11197762053608895 3 | v 0.042534183906391264 0.010387184098362923 0.0585316962443525 4 | v 0.054274293556809426 -0.00583982700482011 0.11220345135927201 5 | v 0.06640338242053986 -0.01035996433347464 0.05855462865121663 6 | v 0.03997647545023938 0.008730395697057247 0.09456484627127648 7 | v 0.03994319871620974 0.008622935973107815 0.10935282196998597 8 | v 0.05610156148672104 0.010479452088475226 0.07738816600441933 9 | v 0.0652202693372965 0.010388848371803757 0.07758761990964413 10 | v 0.06525340951979161 -0.01040047220885754 0.07743658919036389 11 | v 0.05391758117824793 0.008841173723340033 0.10898935657143594 12 | v 0.05510264165699482 -0.010494819842278959 0.07683558996915818 13 | v 0.03990887965395814 0.005481416825205088 0.11212470989823342 14 | v 0.053988714665174485 -0.008616076782345774 0.1097279139995575 15 | v 0.05429311186075211 0.005277918651700018 0.11224903401136399 16 | v 0.039867356615141035 -0.00869277399033308 0.10856622692346574 17 | v 0.06630941011011601 0.010401263833045956 0.058566509751579725 18 | v 0.04258330418728292 -0.010448622517287731 0.05854680107415188 19 | v -0.05421821303665638 -0.007778378669172525 0.11076592143774033 20 | v -0.039979396890921635 0.00613801833242178 0.11197762053608895 21 | v -0.042534183906391264 -0.010387184098362923 0.0585316962443525 22 | v -0.054274293556809426 0.005839827004820108 0.11220345135927201 23 | v -0.06640338242053986 0.010359964333474636 0.05855462865121663 24 | v -0.03997647545023938 -0.008730395697057247 0.09456484627127648 25 | v -0.03994319871620974 -0.008622935973107815 0.10935282196998597 26 | v -0.05610156148672104 -0.010479452088475227 0.07738816600441933 27 | v -0.0652202693372965 -0.01038884837180376 0.07758761990964413 28 | v -0.06525340951979161 0.010400472208857536 0.07743658919036389 29 | v -0.05391758117824793 -0.008841173723340034 0.10898935657143594 30 | v -0.05510264165699482 0.010494819842278957 0.07683558996915818 31 | v -0.03990887965395814 -0.005481416825205088 0.11212470989823342 32 | v -0.053988714665174485 0.008616076782345772 0.1097279139995575 33 | v -0.05429311186075211 -0.00527791865170002 0.11224903401136399 34 | v -0.039867356615141035 0.00869277399033308 0.10856622692346574 35 | v -0.06630941011011601 -0.01040126383304596 0.058566509751579725 36 | v -0.04258330418728292 0.010448622517287731 0.05854680107415188 37 | v -0.04178430512547493 -0.028010861948132515 -0.0063964021392166615 38 | v -0.09769713878631592 -0.019066786393523216 0.024627480655908585 39 | v -0.09701105952262878 0.01980205439031124 0.020913975313305855 40 | v -0.09099503606557846 -0.00017339896294288337 0.0005767836119048297 41 | v 0.0005787304835394025 -0.031635917723178864 0.005973074585199356 42 | v -0.09544170647859573 -0.021822135895490646 0.021756364032626152 43 | v -0.10028521716594696 0.016218392178416252 0.04586976766586304 44 | v 0.1000911220908165 -0.014017123728990555 0.056026946753263474 45 | v -0.09460194408893585 0.018555866554379463 0.056617286056280136 46 | v -0.09433312714099884 0.021342597901821136 0.008942007087171078 47 | v 0.08797474205493927 -0.0006270006415434182 -0.024961603805422783 48 | v -0.09127645939588547 -0.014421950094401836 0.06568973511457443 49 | v -0.09320925921201706 -0.01604551635682583 0.06361687183380127 50 | v -0.10042587667703629 -0.004494380671530962 0.058119092136621475 51 | v 0.09249575436115265 -0.012987935915589333 0.06571194529533386 52 | v -0.09116631746292114 
-0.023555776104331017 0.004607920069247484 53 | v 0.09070632606744766 -0.015471714548766613 0.06519009917974472 54 | v -0.08543447405099869 -0.022732771933078766 -0.004120331723242998 55 | v -0.09289686381816864 -0.023026082664728165 0.009876862168312073 56 | v 0.09790761768817902 -0.01751038245856762 0.05166616290807724 57 | v 0.0005585200269706547 0.03158979117870331 0.006214227061718702 58 | v 0.10119467973709106 0.003602118231356144 -0.012627636082470417 59 | v 0.09665588289499283 0.0004695942625403404 0.06307835876941681 60 | v 0.09384757280349731 0.017607156187295914 0.058448486030101776 61 | v -0.04204234480857849 0.02801508456468582 -0.006349603645503521 62 | v 0.09945143014192581 3.1802206649445e-05 -0.017161205410957336 63 | v 0.04551994055509567 0.0019174328772351146 -0.025659434497356415 64 | v -0.09025220572948456 -0.01676723174750805 0.06433220207691193 65 | v 0.10030148178339005 0.001190581009723246 0.05862313136458397 66 | v -0.07437754422426224 0.024816755205392838 -0.007155700121074915 67 | v 0.10259784758090973 -0.0034295637160539627 -0.00896429643034935 68 | v 0.1027156412601471 0.003494761884212494 -0.008029340766370296 69 | v -0.08589234948158264 -0.02449524775147438 -0.0018327292054891586 70 | v 0.09406288713216782 0.0131387272849679 0.06468318402767181 71 | v -0.08206511288881302 -0.025270354002714157 0.0051970407366752625 72 | v -0.09466791152954102 -0.02065763622522354 0.009537984617054462 73 | v -0.08824997395277023 0.022314293310046196 -0.0019331998191773891 74 | v -0.09747105836868286 -0.0016167220892384648 0.06230733543634415 75 | v 0.09552478045225143 -0.017053674906492233 0.02212120220065117 76 | v -0.08335519582033157 0.022376984357833862 -0.005526112858206034 77 | v -0.09936285763978958 -0.016994841396808624 0.05411478504538536 78 | v 0.0968022570014 0.017033156007528305 0.030322037637233734 79 | v 0.09160291403532028 0.01695552095770836 0.005562833044677973 80 | v -0.09012892097234726 0.016734914854168892 0.06443186104297638 81 | v 0.09177957475185394 0.01571837067604065 0.06491253525018692 82 | v 0.0840023085474968 0.0018427835311740637 -0.02552490122616291 83 | v 0.08797289431095123 0.0030759950168430805 -0.02397872507572174 84 | v 0.09962863475084305 -0.016013137996196747 0.04930143058300018 85 | v -0.09201552718877792 -0.01844867318868637 0.058393169194459915 86 | v -0.08025997132062912 -0.0008337647304870188 -0.007321717217564583 87 | v -0.07576971501111984 -0.025980981066823006 -0.005082232877612114 88 | v -0.10019978880882263 0.015550931915640831 0.05467259883880615 89 | v 0.09941626340150833 0.015804897993803024 0.05497027188539505 90 | v 0.10374269634485245 -1.7281157852266915e-05 -0.00919930450618267 91 | v 0.08254846930503845 -0.0008678357116878033 -0.025870200246572495 92 | v 0.0875329002737999 -0.016812244430184364 -0.0012064524926245213 93 | v 0.08223627507686615 -0.016532698646187782 -0.005118109285831451 94 | v -0.09373555332422256 0.022707808762788773 0.014340022578835487 95 | v -0.09371249377727509 -0.012566703371703625 0.06516847014427185 96 | v -0.07666800171136856 -0.024650242179632187 -0.0069321258924901485 97 | v 0.08927122503519058 0.01713424362242222 0.0636143907904625 98 | v 0.08776598423719406 -0.0032150978222489357 -0.023884393274784088 99 | v -0.09726057201623917 -0.019229214638471603 0.05058842897415161 100 | v -0.09369184076786041 0.020670883357524872 0.04330829158425331 101 | v 0.09740705043077469 0.017585095018148422 0.051984645426273346 102 | v 0.09855398535728455 -0.01663215272128582 0.05473393574357033 103 | v 
0.09344169497489929 -0.014617033302783966 0.06450004875659943 104 | v 0.08296618610620499 0.00381033169105649 -0.024449335411190987 105 | v -0.09092690050601959 -0.021324951201677322 0.0009798021055758 106 | v -0.09280849248170853 -0.0001125619382946752 0.06596215069293976 107 | v -0.0917946845293045 0.021482910960912704 0.0026841284707188606 108 | v 0.09998264163732529 -0.009323876351118088 0.058489199727773666 109 | v 0.08358591049909592 -0.0036368216387927532 -0.024606257677078247 110 | v 0.1001875177025795 0.012505676597356796 0.056894149631261826 111 | v -0.09290558844804764 0.015396904200315475 0.06455627083778381 112 | v 0.0851321741938591 0.016558213159441948 -0.0038727361243218184 113 | v 0.09294531494379044 -0.0005056463996879756 0.06595310568809509 114 | v 0.10115781426429749 -0.0036167786456644535 -0.012610324658453465 115 | v -0.07790137827396393 0.02295910380780697 -0.007399038877338171 116 | v -0.0857401043176651 0.024729391559958458 -0.0012316935462877154 117 | v -0.10016821324825287 -0.014623090624809265 0.055734917521476746 118 | v -0.09951794147491455 -0.018192630261182785 0.043814171105623245 119 | v 0.09070031344890594 -0.017254667356610298 0.0630820095539093 120 | v 0.0919061228632927 -0.016804175451397896 0.006295484956353903 121 | v 0.09953752160072327 0.016230100765824318 0.051584091037511826 122 | v -0.08118050545454025 0.025447947904467583 0.0035006047692149878 123 | v -0.09906721860170364 0.017129460349678993 0.05430515855550766 124 | v -0.08656162023544312 -0.00033731618896126747 -0.004163281060755253 125 | v -0.09461534768342972 -0.00031412430689670146 0.007574658375233412 126 | v -0.07529757916927338 0.026034310460090637 -0.005030847620218992 127 | v -0.08017436414957047 -0.02276112325489521 -0.006909539457410574 128 | v 0.0018608596874400973 0.03161578252911568 0.0011797059560194612 129 | v 0.0458698496222496 -0.0015001518186181784 -0.02592480182647705 130 | v -0.0817025899887085 0.024515172466635704 -0.005051423329859972 131 | v -0.10003473609685898 0.009941554628312588 0.05834079533815384 132 | v -0.09267213940620422 0.013539588078856468 0.0656878799200058 133 | v -0.09849356859922409 0.019268833100795746 0.0449417382478714 134 | v -0.09040140360593796 0.023869164288043976 0.004368236754089594 135 | v 0.0019865017384290695 -0.031597502529621124 0.001152931246906519 136 | v -0.09849606454372406 -1.593970591784455e-05 0.027081793174147606 137 | v 0.10398972034454346 -4.109224391868338e-05 0.005690876394510269 138 | v 0.09192700684070587 0.01342480443418026 0.06573130935430527 139 | f 5 18 3 140 | f 3 8 17 141 | f 17 8 9 142 | f 18 6 3 143 | f 10 9 4 144 | f 3 17 5 145 | f 14 12 10 146 | f 12 16 18 147 | f 14 16 12 148 | f 7 3 6 149 | f 18 16 6 150 | f 16 13 6 151 | f 10 12 18 152 | f 10 18 5 153 | f 10 5 9 154 | f 9 15 4 155 | f 1 15 9 156 | f 7 13 1 157 | f 3 7 8 158 | f 10 4 14 159 | f 5 17 9 160 | f 8 11 9 161 | f 11 1 9 162 | f 4 13 2 163 | f 15 13 4 164 | f 16 2 13 165 | f 2 16 14 166 | f 13 7 6 167 | f 1 13 15 168 | f 7 11 8 169 | f 2 14 4 170 | f 1 11 7 171 | f 23 36 21 172 | f 21 26 35 173 | f 35 26 27 174 | f 36 24 21 175 | f 28 27 22 176 | f 21 35 23 177 | f 32 30 28 178 | f 30 34 36 179 | f 32 34 30 180 | f 25 21 24 181 | f 36 34 24 182 | f 34 31 24 183 | f 28 30 36 184 | f 28 36 23 185 | f 28 23 27 186 | f 27 33 22 187 | f 19 33 27 188 | f 25 31 19 189 | f 21 25 26 190 | f 28 22 32 191 | f 23 35 27 192 | f 26 29 27 193 | f 29 19 27 194 | f 22 31 20 195 | f 33 31 22 196 | f 34 20 31 197 | f 20 34 32 198 | f 31 25 24 199 | f 19 31 33 200 | f 25 29 
26 201 | f 20 32 22 202 | f 19 29 25 203 | f 80 97 57 204 | f 75 56 135 205 | f 87 135 41 206 | f 78 128 101 207 | f 128 57 101 208 | f 45 80 57 209 | f 120 135 92 210 | f 41 135 56 211 | f 75 135 120 212 | f 121 137 90 213 | f 114 67 120 214 | f 71 87 41 215 | f 37 135 87 216 | f 106 95 48 217 | f 119 53 64 218 | f 79 128 78 219 | f 97 60 57 220 | f 60 101 57 221 | f 110 137 121 222 | f 88 133 43 223 | f 63 61 104 224 | f 100 57 122 225 | f 135 109 93 226 | f 119 41 56 227 | f 44 84 137 228 | f 114 92 98 229 | f 106 48 51 230 | f 51 113 106 231 | f 132 106 113 232 | f 113 138 132 233 | f 71 41 99 234 | f 78 101 121 235 | f 104 128 112 236 | f 68 79 78 237 | f 128 104 61 238 | f 45 100 133 239 | f 100 45 57 240 | f 100 122 134 241 | f 100 94 133 242 | f 94 100 134 243 | f 128 61 126 244 | f 128 126 122 245 | f 128 122 57 246 | f 66 61 63 247 | f 129 86 115 248 | f 86 129 127 249 | f 115 66 63 250 | f 109 135 37 251 | f 62 114 98 252 | f 93 109 98 253 | f 92 93 98 254 | f 92 135 93 255 | f 102 119 56 256 | f 62 58 90 257 | f 84 102 56 258 | f 137 84 90 259 | f 74 106 132 260 | f 106 74 95 261 | f 80 132 81 262 | f 138 81 132 263 | f 119 85 41 264 | f 85 119 64 265 | f 85 99 41 266 | f 47 62 98 267 | f 83 104 112 268 | f 81 97 80 269 | f 43 133 39 270 | f 131 74 132 271 | f 133 123 45 272 | f 125 136 39 273 | f 126 61 66 274 | f 124 73 76 275 | f 105 54 69 276 | f 82 129 63 277 | f 129 96 127 278 | f 75 84 56 279 | f 68 121 90 280 | f 62 90 114 281 | f 87 71 69 282 | f 48 53 51 283 | f 64 53 48 284 | f 113 70 138 285 | f 44 137 108 286 | f 38 136 125 287 | f 74 50 95 288 | f 43 117 50 289 | f 112 79 58 290 | f 58 79 68 291 | f 58 68 90 292 | f 81 89 60 293 | f 89 110 121 294 | f 70 89 81 295 | f 89 70 110 296 | f 78 121 68 297 | f 89 121 101 298 | f 101 60 89 299 | f 50 74 131 300 | f 63 104 82 301 | f 125 39 46 302 | f 136 43 39 303 | f 127 54 124 304 | f 73 40 107 305 | f 124 76 86 306 | f 86 76 115 307 | f 63 129 115 308 | f 37 96 129 309 | f 37 87 96 310 | f 130 73 116 311 | f 109 37 129 312 | f 72 38 125 313 | f 117 95 50 314 | f 79 112 128 315 | f 65 70 59 316 | f 88 43 131 317 | f 131 43 50 318 | f 134 46 94 319 | f 123 133 88 320 | f 132 111 131 321 | f 111 88 131 322 | f 111 45 123 323 | f 54 105 124 324 | f 116 126 130 325 | f 46 134 107 326 | f 105 125 40 327 | f 40 125 107 328 | f 91 129 82 329 | f 109 129 91 330 | f 92 114 120 331 | f 84 44 102 332 | f 67 90 84 333 | f 67 75 120 334 | f 113 59 70 335 | f 65 59 108 336 | f 137 110 65 337 | f 65 108 137 338 | f 52 69 71 339 | f 83 62 47 340 | f 47 82 83 341 | f 62 83 58 342 | f 112 58 83 343 | f 97 81 60 344 | f 70 81 138 345 | f 70 65 110 346 | f 82 104 83 347 | f 111 123 88 348 | f 111 132 80 349 | f 45 111 80 350 | f 125 46 107 351 | f 39 94 46 352 | f 94 39 133 353 | f 115 130 66 354 | f 126 116 122 355 | f 116 134 122 356 | f 105 40 124 357 | f 124 40 73 358 | f 86 127 124 359 | f 84 75 67 360 | f 67 114 90 361 | f 103 108 51 362 | f 108 59 51 363 | f 44 103 102 364 | f 53 103 51 365 | f 55 42 72 366 | f 118 43 136 367 | f 117 43 118 368 | f 52 105 69 369 | f 91 82 47 370 | f 73 130 76 371 | f 126 66 130 372 | f 73 107 134 373 | f 116 73 134 374 | f 44 108 103 375 | f 51 59 113 376 | f 119 103 53 377 | f 49 99 85 378 | f 85 64 49 379 | f 77 99 49 380 | f 52 55 72 381 | f 99 42 71 382 | f 42 55 71 383 | f 55 52 71 384 | f 52 72 105 385 | f 117 49 95 386 | f 64 48 49 387 | f 130 115 76 388 | f 69 54 96 389 | f 96 54 127 390 | f 96 87 69 391 | f 77 117 118 392 | f 72 42 38 393 | f 99 118 42 394 | f 118 99 77 395 | f 105 72 125 
396 | f 117 77 49 397 | f 48 95 49 398 | f 98 91 47 399 | f 119 102 103 400 | f 38 118 136 401 | f 98 109 91 402 | f 118 38 42 403 | -------------------------------------------------------------------------------- /gripper_models/panda_gripper/finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/gripper_models/panda_gripper/finger.stl -------------------------------------------------------------------------------- /gripper_models/panda_gripper/hand.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/gripper_models/panda_gripper/hand.stl -------------------------------------------------------------------------------- /gripper_models/panda_pc.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/gripper_models/panda_pc.npy -------------------------------------------------------------------------------- /gripper_models/yumi_gripper/base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/gripper_models/yumi_gripper/base.stl -------------------------------------------------------------------------------- /gripper_models/yumi_gripper/base_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/gripper_models/yumi_gripper/base_coarse.stl -------------------------------------------------------------------------------- /gripper_models/yumi_gripper/finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/gripper_models/yumi_gripper/finger.stl -------------------------------------------------------------------------------- /gripper_models/yumi_gripper/finger_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/gripper_models/yumi_gripper/finger_coarse.stl -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | def create_model(opt): 2 | from .grasp_net import GraspNetModel 3 | model = GraspNetModel(opt) 4 | return model 5 | -------------------------------------------------------------------------------- /models/grasp_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . 
import networks 3 | from os.path import join 4 | import utils.utils as utils 5 | 6 | 7 | class GraspNetModel: 8 | """ Class for training Model weights 9 | 10 | :args opt: structure containing configuration params 11 | e.g., 12 | --dataset_mode -> sampling / evaluation) 13 | """ 14 | def __init__(self, opt): 15 | self.opt = opt 16 | self.gpu_ids = opt.gpu_ids 17 | self.is_train = opt.is_train 18 | if self.gpu_ids and self.gpu_ids[0] >= torch.cuda.device_count(): 19 | self.gpu_ids[0] = torch.cuda.device_count() - 1 20 | self.device = torch.device('cuda:{}'.format( 21 | self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 22 | self.save_dir = join(opt.checkpoints_dir, opt.name) 23 | self.optimizer = None 24 | self.loss = None 25 | self.pcs = None 26 | self.grasps = None 27 | # load/define networks 28 | self.net = networks.define_classifier(opt, self.gpu_ids, opt.arch, 29 | opt.init_type, opt.init_gain, 30 | self.device) 31 | 32 | self.criterion = networks.define_loss(opt) 33 | 34 | self.confidence_loss = None 35 | if self.opt.arch == "vae": 36 | self.kl_loss = None 37 | self.reconstruction_loss = None 38 | elif self.opt.arch == "gan": 39 | self.reconstruction_loss = None 40 | else: 41 | self.classification_loss = None 42 | 43 | if self.is_train: 44 | self.optimizer = torch.optim.Adam(self.net.parameters(), 45 | lr=opt.lr, 46 | betas=(opt.beta1, 0.999)) 47 | self.scheduler = networks.get_scheduler(self.optimizer, opt) 48 | if not self.is_train or opt.continue_train: 49 | self.load_network(opt.which_epoch, self.is_train) 50 | 51 | def set_input(self, data): 52 | input_pcs = torch.from_numpy(data['pc']).contiguous() 53 | input_grasps = torch.from_numpy(data['grasp_rt']).float() 54 | if self.opt.arch == "evaluator": 55 | targets = torch.from_numpy(data['labels']).float() 56 | else: 57 | targets = torch.from_numpy(data['target_cps']).float() 58 | self.pcs = input_pcs.to(self.device).requires_grad_(self.is_train) 59 | self.grasps = input_grasps.to(self.device).requires_grad_( 60 | self.is_train) 61 | self.targets = targets.to(self.device) 62 | 63 | def generate_grasps(self, pcs, z=None): 64 | with torch.no_grad(): 65 | return self.net.module.generate_grasps(pcs, z=z) 66 | 67 | def evaluate_grasps(self, pcs, gripper_pcs): 68 | success, _ = self.net.module(pcs, gripper_pcs) 69 | return torch.sigmoid(success) 70 | 71 | def forward(self): 72 | return self.net(self.pcs, self.grasps, train=self.is_train) 73 | 74 | def backward(self, out): 75 | if self.opt.arch == 'vae': 76 | predicted_cp, confidence, mu, logvar = out 77 | predicted_cp = utils.transform_control_points( 78 | predicted_cp, predicted_cp.shape[0], device=self.device) 79 | self.reconstruction_loss, self.confidence_loss = self.criterion[1]( 80 | predicted_cp, 81 | self.targets, 82 | confidence=confidence, 83 | confidence_weight=self.opt.confidence_weight, 84 | device=self.device) 85 | self.kl_loss = self.opt.kl_loss_weight * self.criterion[0]( 86 | mu, logvar, device=self.device) 87 | self.loss = self.kl_loss + self.reconstruction_loss + self.confidence_loss 88 | elif self.opt.arch == 'gan': 89 | predicted_cp, confidence = out 90 | predicted_cp = utils.transform_control_points( 91 | predicted_cp, predicted_cp.shape[0], device=self.device) 92 | self.reconstruction_loss, self.confidence_loss = self.criterion( 93 | predicted_cp, 94 | self.targets, 95 | confidence=confidence, 96 | confidence_weight=self.opt.confidence_weight, 97 | device=self.device) 98 | self.loss = self.reconstruction_loss + self.confidence_loss 99 | elif self.opt.arch == 
'evaluator': 100 | grasp_classification, confidence = out 101 | self.classification_loss, self.confidence_loss = self.criterion( 102 | grasp_classification.squeeze(), 103 | self.targets, 104 | confidence, 105 | self.opt.confidence_weight, 106 | device=self.device) 107 | self.loss = self.classification_loss + self.confidence_loss 108 | 109 | self.loss.backward() 110 | 111 | def optimize_parameters(self): 112 | self.optimizer.zero_grad() 113 | out = self.forward() 114 | self.backward(out) 115 | self.optimizer.step() 116 | 117 | 118 | ################## 119 | 120 | def load_network(self, which_epoch, train=True): 121 | """load model from disk""" 122 | save_filename = '%s_net.pth' % which_epoch 123 | load_path = join(self.save_dir, save_filename) 124 | net = self.net 125 | if isinstance(net, torch.nn.DataParallel): 126 | net = net.module 127 | print('loading the model from %s' % load_path) 128 | checkpoint = torch.load(load_path, map_location=self.device) 129 | if hasattr(checkpoint['model_state_dict'], '_metadata'): 130 | del checkpoint['model_state_dict']._metadata 131 | net.load_state_dict(checkpoint['model_state_dict']) 132 | if train: 133 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 134 | self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 135 | self.opt.epoch_count = checkpoint["epoch"] 136 | else: 137 | net.eval() 138 | 139 | def save_network(self, net_name, epoch_num): 140 | """save model to disk""" 141 | save_filename = '%s_net.pth' % (net_name) 142 | save_path = join(self.save_dir, save_filename) 143 | torch.save( 144 | { 145 | 'epoch': epoch_num + 1, 146 | 'model_state_dict': self.net.module.cpu().state_dict(), 147 | 'optimizer_state_dict': self.optimizer.state_dict(), 148 | 'scheduler_state_dict': self.scheduler.state_dict(), 149 | }, save_path) 150 | 151 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 152 | self.net.cuda(self.gpu_ids[0]) 153 | 154 | def update_learning_rate(self): 155 | """update learning rate (called once every epoch)""" 156 | self.scheduler.step() 157 | lr = self.optimizer.param_groups[0]['lr'] 158 | print('learning rate = %.7f' % lr) 159 | 160 | def test(self): 161 | """tests model 162 | returns: number correct and total number 163 | """ 164 | with torch.no_grad(): 165 | out = self.forward() 166 | prediction, confidence = out 167 | if self.opt.arch == "vae": 168 | predicted_cp = utils.transform_control_points( 169 | prediction, prediction.shape[0], device=self.device) 170 | reconstruction_loss, _ = self.criterion[1]( 171 | predicted_cp, 172 | self.targets, 173 | confidence=confidence, 174 | confidence_weight=self.opt.confidence_weight, 175 | device=self.device) 176 | return reconstruction_loss, 1 177 | elif self.opt.arch == "gan": 178 | predicted_cp = utils.transform_control_points( 179 | prediction, prediction.shape[0], device=self.device) 180 | reconstruction_loss, _ = self.criterion( 181 | predicted_cp, 182 | self.targets, 183 | confidence=confidence, 184 | confidence_weight=self.opt.confidence_weight, 185 | device=self.device) 186 | return reconstruction_loss, 1 187 | else: 188 | 189 | predicted = torch.round(torch.sigmoid(prediction)).squeeze() 190 | correct = (predicted == self.targets).sum().item() 191 | return correct, len(self.targets) 192 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def 
control_point_l1_loss_better_than_threshold(pred_control_points, 6 | gt_control_points, 7 | confidence, 8 | confidence_threshold, 9 | device="cpu"): 10 | npoints = pred_control_points.shape[1] 11 | mask = torch.greater_equal(confidence, confidence_threshold) 12 | mask_ratio = torch.mean(mask) 13 | mask = torch.repeat_interleave(mask, npoints, dim=1) 14 | p1 = pred_control_points[mask] 15 | p2 = gt_control_points[mask] 16 | 17 | return control_point_l1_loss(p1, p2), mask_ratio 18 | 19 | 20 | def accuracy_better_than_threshold(pred_success_logits, 21 | gt, 22 | confidence, 23 | confidence_threshold, 24 | device="cpu"): 25 | """ 26 | Computes average precision for the grasps with confidence > threshold. 27 | """ 28 | pred_classes = torch.argmax(pred_success_logits, -1) 29 | correct = torch.equal(pred_classes, gt) 30 | mask = torch.squeeze(torch.greater_equal(confidence, confidence_threshold), 31 | -1) 32 | 33 | positive_acc = torch.sum(correct * mask * gt) / torch.max( 34 | torch.sum(mask * gt), torch.tensor(1)) 35 | negative_acc = torch.sum(correct * mask * (1. - gt)) / torch.max( 36 | torch.sum(mask * (1. - gt)), torch.tensor(1)) 37 | 38 | return 0.5 * (positive_acc + negative_acc), torch.sum(mask) / gt.shape[0] 39 | 40 | 41 | def control_point_l1_loss(pred_control_points, 42 | gt_control_points, 43 | confidence=None, 44 | confidence_weight=None, 45 | device="cpu"): 46 | """ 47 | Computes the l1 loss between the predicted control points and the 48 | groundtruth control points on the gripper. 49 | """ 50 | #print('control_point_l1_loss', pred_control_points.shape, 51 | # gt_control_points.shape) 52 | error = torch.sum(torch.abs(pred_control_points - gt_control_points), -1) 53 | error = torch.mean(error, -1) 54 | if confidence is not None: 55 | assert (confidence_weight is not None) 56 | error *= confidence 57 | confidence_term = torch.mean( 58 | torch.log(torch.max( 59 | confidence, 60 | torch.tensor(1e-10).to(device)))) * confidence_weight 61 | #print('confidence_term = ', confidence_term.shape) 62 | 63 | #print('l1_error = {}'.format(error.shape)) 64 | if confidence is None: 65 | return torch.mean(error) 66 | else: 67 | return torch.mean(error), -confidence_term 68 | 69 | 70 | def classification_with_confidence_loss(pred_logit, 71 | gt, 72 | confidence, 73 | confidence_weight, 74 | device="cpu"): 75 | """ 76 | Computes the cross entropy loss and confidence term that penalizes 77 | outputing zero confidence. Returns cross entropy loss and the confidence 78 | regularization term. 79 | """ 80 | classification_loss = torch.nn.functional.binary_cross_entropy_with_logits( 81 | pred_logit, gt) 82 | confidence_term = torch.mean( 83 | torch.log(torch.max( 84 | confidence, 85 | torch.tensor(1e-10).to(device)))) * confidence_weight 86 | 87 | return classification_loss, -confidence_term 88 | 89 | 90 | def min_distance_loss(pred_control_points, 91 | gt_control_points, 92 | confidence=None, 93 | confidence_weight=None, 94 | threshold=None, 95 | device="cpu"): 96 | """ 97 | Computes the minimum distance (L1 distance)between each gt control point 98 | and any of the predicted control points. 99 | 100 | Args: 101 | pred_control_points: tensor of (N_pred, M, 4) shape. N is the number of 102 | grasps. M is the number of points on the gripper. 103 | gt_control_points: (N_gt, M, 4) 104 | confidence: tensor of N_pred, tensor for the confidence of each 105 | prediction. 106 | confidence_weight: float, the weight for confidence loss. 
107 | """ 108 | pred_shape = pred_control_points.shape 109 | gt_shape = gt_control_points.shape 110 | 111 | if len(pred_shape) != 3: 112 | raise ValueError( 113 | "pred_control_point should have len of 3. {}".format(pred_shape)) 114 | if len(gt_shape) != 3: 115 | raise ValueError( 116 | "gt_control_point should have len of 3. {}".format(gt_shape)) 117 | if pred_shape != gt_shape: 118 | raise ValueError("shapes do not match {} != {}".format( 119 | pred_shape, gt_shape)) 120 | 121 | # N_pred x Ngt x M x 3 122 | error = pred_control_points.unsqueeze(1) - gt_control_points.unsqueeze(0) 123 | error = torch.sum(torch.abs(error), 124 | -1) # L1 distance of error (N_pred, N_gt, M) 125 | error = torch.mean( 126 | error, -1) # average L1 for all the control points. (N_pred, N_gt) 127 | 128 | min_distance_error, closest_index = error.min( 129 | 0) #[0] # take the min distance for each gt control point. (N_gt) 130 | #print('min_distance_error', get_shape(min_distance_error)) 131 | if confidence is not None: 132 | #print('closest_index', get_shape(closest_index)) 133 | selected_confidence = torch.nn.functional.one_hot( 134 | closest_index, 135 | num_classes=closest_index.shape[0]).float() # (N_gt, N_pred) 136 | selected_confidence *= confidence 137 | #print('selected_confidence', selected_confidence) 138 | selected_confidence = torch.sum(selected_confidence, -1) # N_gt 139 | #print('selected_confidence', selected_confidence) 140 | min_distance_error *= selected_confidence 141 | confidence_term = torch.mean( 142 | torch.log(torch.max( 143 | confidence, 144 | torch.tensor(1e-4).to(device)))) * confidence_weight 145 | else: 146 | confidence_term = 0. 147 | 148 | return torch.mean(min_distance_error), -confidence_term 149 | 150 | 151 | def min_distance_better_than_threshold(pred_control_points, 152 | gt_control_points, 153 | confidence, 154 | confidence_threshold, 155 | device="cpu"): 156 | error = pred_control_points.unsqueeze(1) - gt_control_points.unsqueeze( 157 | 0) 158 | error = torch.sum(torch.abs(error), 159 | -1) # L1 distance of error (N_pred, N_gt, M) 160 | error = torch.mean( 161 | error, -1) # average L1 for all the control points. (N_pred, N_gt) 162 | error = torch.min(error, -1)[0] # (N_pred,) 163 | mask = torch.ge(confidence, confidence_threshold) 164 | mask = torch.squeeze(mask, dim=-1) 165 | 166 | return torch.mean(error[mask]), torch.mean(mask.float()) 167 | 168 | 169 | def kl_divergence(mu, log_sigma, device="cpu"): 170 | """ 171 | Computes the kl divergence for batch of mu and log_sigma. 172 | """ 173 | return torch.mean( 174 | -.5 * torch.sum(1.
+ log_sigma - mu**2 - torch.exp(log_sigma), dim=-1)) 175 | 176 | 177 | def confidence_loss(confidence, confidence_weight, device="cpu"): 178 | return torch.mean( 179 | torch.log(torch.max( 180 | confidence, 181 | torch.tensor(1e-10).to(device)))) * confidence_weight 182 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | import torch.nn.functional as F 7 | from models import losses 8 | from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d as BN 9 | import pointnet2_ops.pointnet2_modules as pointnet2 10 | 11 | 12 | def get_scheduler(optimizer, opt): 13 | if opt.lr_policy == 'lambda': 14 | 15 | def lambda_rule(epoch): 16 | lr_l = 1.0 - max( 17 | 0, epoch + 1 + 1 - opt.niter) / float(opt.niter_decay + 1) 18 | return lr_l 19 | 20 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 21 | elif opt.lr_policy == 'step': 22 | scheduler = lr_scheduler.StepLR(optimizer, 23 | step_size=opt.lr_decay_iters, 24 | gamma=0.1) 25 | elif opt.lr_policy == 'plateau': 26 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 27 | mode='min', 28 | factor=0.2, 29 | threshold=0.01, 30 | patience=5) 31 | else: 32 | return NotImplementedError( 33 | 'learning rate policy [%s] is not implemented', opt.lr_policy) 34 | return scheduler 35 | 36 | 37 | def init_weights(net, init_type, init_gain): 38 | def init_func(m): 39 | classname = m.__class__.__name__ 40 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 41 | or classname.find('Linear') != -1): 42 | if init_type == 'normal': 43 | init.normal_(m.weight.data, 0.0, init_gain) 44 | elif init_type == 'xavier': 45 | init.xavier_normal_(m.weight.data, gain=init_gain) 46 | elif init_type == 'kaiming': 47 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 48 | elif init_type == 'orthogonal': 49 | init.orthogonal_(m.weight.data, gain=init_gain) 50 | else: 51 | raise NotImplementedError( 52 | 'initialization method [%s] is not implemented' % 53 | init_type) 54 | elif classname.find('BatchNorm') != -1: 55 | init.normal_(m.weight.data, 1.0, init_gain) 56 | init.constant_(m.bias.data, 0.0) 57 | 58 | net.apply(init_func) 59 | 60 | 61 | def init_net(net, init_type, init_gain, gpu_ids): 62 | if len(gpu_ids) > 0: 63 | assert (torch.cuda.is_available()) 64 | net.cuda(gpu_ids[0]) 65 | net = net.cuda() 66 | net = torch.nn.DataParallel(net, gpu_ids) 67 | if init_type != 'none': 68 | init_weights(net, init_type, init_gain) 69 | return net 70 | 71 | 72 | def define_classifier(opt, gpu_ids, arch, init_type, init_gain, device): 73 | net = None 74 | if arch == 'vae': 75 | net = GraspSamplerVAE(opt.model_scale, opt.pointnet_radius, 76 | opt.pointnet_nclusters, opt.latent_size, device) 77 | elif arch == 'gan': 78 | net = GraspSamplerGAN(opt.model_scale, opt.pointnet_radius, 79 | opt.pointnet_nclusters, opt.latent_size, device) 80 | elif arch == 'evaluator': 81 | net = GraspEvaluator(opt.model_scale, opt.pointnet_radius, 82 | opt.pointnet_nclusters, device) 83 | else: 84 | raise NotImplementedError('model name [%s] is not recognized' % arch) 85 | return init_net(net, init_type, init_gain, gpu_ids) 86 | 87 | 88 | def define_loss(opt): 89 | if opt.arch == 'vae': 90 | kl_loss = losses.kl_divergence 91 | reconstruction_loss = losses.control_point_l1_loss 92 | return kl_loss, 
reconstruction_loss 93 | elif opt.arch == 'gan': 94 | reconstruction_loss = losses.min_distance_loss 95 | return reconstruction_loss 96 | elif opt.arch == 'evaluator': 97 | loss = losses.classification_with_confidence_loss 98 | return loss 99 | else: 100 | raise NotImplementedError("Loss not found") 101 | 102 | 103 | class GraspSampler(nn.Module): 104 | def __init__(self, latent_size, device): 105 | super(GraspSampler, self).__init__() 106 | self.latent_size = latent_size 107 | self.device = device 108 | 109 | def create_decoder(self, model_scale, pointnet_radius, pointnet_nclusters, 110 | num_input_features): 111 | # The number of input features for the decoder is 3+latent space where 3 112 | # represents the x, y, z position of the point-cloud 113 | 114 | self.decoder = base_network(pointnet_radius, pointnet_nclusters, 115 | model_scale, num_input_features) 116 | self.q = nn.Linear(model_scale * 1024, 4) 117 | self.t = nn.Linear(model_scale * 1024, 3) 118 | self.confidence = nn.Linear(model_scale * 1024, 1) 119 | 120 | def decode(self, xyz, z): 121 | xyz_features = self.concatenate_z_with_pc(xyz, 122 | z).transpose(-1, 123 | 1).contiguous() 124 | for module in self.decoder[0]: 125 | xyz, xyz_features = module(xyz, xyz_features) 126 | x = self.decoder[1](xyz_features.squeeze(-1)) 127 | predicted_qt = torch.cat( 128 | (F.normalize(self.q(x), p=2, dim=-1), self.t(x)), -1) 129 | 130 | return predicted_qt, torch.sigmoid(self.confidence(x)).squeeze() 131 | 132 | def concatenate_z_with_pc(self, pc, z): 133 | z.unsqueeze_(1) 134 | z = z.expand(-1, pc.shape[1], -1) 135 | return torch.cat((pc, z), -1) 136 | 137 | def get_latent_size(self): 138 | return self.latent_size 139 | 140 | 141 | class GraspSamplerVAE(GraspSampler): 142 | """Network for learning a generative VAE grasp-sampler 143 | """ 144 | def __init__(self, 145 | model_scale, 146 | pointnet_radius=0.02, 147 | pointnet_nclusters=128, 148 | latent_size=2, 149 | device="cpu"): 150 | super(GraspSamplerVAE, self).__init__(latent_size, device) 151 | self.create_encoder(model_scale, pointnet_radius, pointnet_nclusters) 152 | 153 | self.create_decoder(model_scale, pointnet_radius, pointnet_nclusters, 154 | latent_size + 3) 155 | self.create_bottleneck(model_scale * 1024, latent_size) 156 | 157 | def create_encoder( 158 | self, 159 | model_scale, 160 | pointnet_radius, 161 | pointnet_nclusters, 162 | ): 163 | # The number of input features for the encoder is 19: the x, y, z 164 | # position of the point-cloud and the flattened 4x4=16 grasp pose matrix 165 | self.encoder = base_network(pointnet_radius, pointnet_nclusters, 166 | model_scale, 19) 167 | 168 | def create_bottleneck(self, input_size, latent_size): 169 | mu = nn.Linear(input_size, latent_size) 170 | logvar = nn.Linear(input_size, latent_size) 171 | self.latent_space = nn.ModuleList([mu, logvar]) 172 | 173 | def encode(self, xyz, xyz_features): 174 | for module in self.encoder[0]: 175 | xyz, xyz_features = module(xyz, xyz_features) 176 | return self.encoder[1](xyz_features.squeeze(-1)) 177 | 178 | def bottleneck(self, z): 179 | return self.latent_space[0](z), self.latent_space[1](z) 180 | 181 | def reparameterize(self, mu, logvar): 182 | std = torch.exp(0.5 * logvar) 183 | eps = torch.randn_like(std) 184 | return mu + eps * std 185 | 186 | def forward(self, pc, grasp=None, train=True): 187 | if train: 188 | return self.forward_train(pc, grasp) 189 | else: 190 | return self.forward_test(pc, grasp) 191 | 192 | def forward_train(self, pc, grasp): 193 | input_features = torch.cat( 194 | (pc, 
grasp.unsqueeze(1).expand(-1, pc.shape[1], -1)), 195 | -1).transpose(-1, 1).contiguous() 196 | z = self.encode(pc, input_features) 197 | mu, logvar = self.bottleneck(z) 198 | z = self.reparameterize(mu, logvar) 199 | qt, confidence = self.decode(pc, z) 200 | return qt, confidence, mu, logvar 201 | 202 | def forward_test(self, pc, grasp): 203 | input_features = torch.cat( 204 | (pc, grasp.unsqueeze(1).expand(-1, pc.shape[1], -1)), 205 | -1).transpose(-1, 1).contiguous() 206 | z = self.encode(pc, input_features) 207 | mu, _ = self.bottleneck(z) 208 | qt, confidence = self.decode(pc, mu) 209 | return qt, confidence 210 | 211 | def sample_latent(self, batch_size): 212 | return torch.randn(batch_size, self.latent_size).to(self.device) 213 | 214 | def generate_grasps(self, pc, z=None): 215 | if z is None: 216 | z = self.sample_latent(pc.shape[0]) 217 | qt, confidence = self.decode(pc, z) 218 | return qt, confidence, z.squeeze() 219 | 220 | def generate_dense_latents(self, resolution): 221 | """ 222 | For the VAE sampler we consider dense latents to correspond to those between -2 and 2 223 | """ 224 | latents = torch.meshgrid(*[ 225 | torch.linspace(-2, 2, resolution) for i in range(self.latent_size) 226 | ]) 227 | return torch.stack([latents[i].flatten() for i in range(len(latents))], 228 | dim=-1).to(self.device) 229 | 230 | 231 | class GraspSamplerGAN(GraspSampler): 232 | """ 233 | Altough the name says this sampler is based on the GAN formulation, it is 234 | not actually optimizing based on the commonly known adversarial game. 235 | Instead, it is based on the Implicit Maximum Likelihood Estimation from 236 | https://arxiv.org/pdf/1809.09087.pdf which is similar to the GAN formulation 237 | but with new insights that avoids e.g. mode collapses. 238 | """ 239 | def __init__(self, 240 | model_scale, 241 | pointnet_radius, 242 | pointnet_nclusters, 243 | latent_size=2, 244 | device="cpu"): 245 | super(GraspSamplerGAN, self).__init__(latent_size, device) 246 | self.create_decoder(model_scale, pointnet_radius, pointnet_nclusters, 247 | latent_size + 3) 248 | 249 | def sample_latent(self, batch_size): 250 | return torch.rand(batch_size, self.latent_size).to(self.device) 251 | 252 | def forward(self, pc, grasps=None, train=True): 253 | z = self.sample_latent(pc.shape[0]) 254 | return self.decode(pc, z) 255 | 256 | def generate_grasps(self, pc, z=None): 257 | if z is None: 258 | z = self.sample_latent(pc.shape[0]) 259 | qt, confidence = self.decode(pc, z) 260 | return qt, confidence, z.squeeze() 261 | 262 | def generate_dense_latents(self, resolution): 263 | latents = torch.meshgrid(*[ 264 | torch.linspace(0, 1, resolution) for i in range(self.latent_size) 265 | ]) 266 | return torch.stack([latents[i].flatten() for i in range(len(latents))], 267 | dim=-1).to(self.device) 268 | 269 | 270 | class GraspEvaluator(nn.Module): 271 | def __init__(self, 272 | model_scale=1, 273 | pointnet_radius=0.02, 274 | pointnet_nclusters=128, 275 | device="cpu"): 276 | super(GraspEvaluator, self).__init__() 277 | self.create_evaluator(pointnet_radius, model_scale, pointnet_nclusters) 278 | self.device = device 279 | 280 | def create_evaluator(self, pointnet_radius, model_scale, 281 | pointnet_nclusters): 282 | # The number of input features for the evaluator is 4: the x, y, z 283 | # position of the concatenated gripper and object point-clouds and an 284 | # extra binary feature, which is 0 for the object and 1 for the gripper, 285 | # to tell these point-clouds apart 286 | self.evaluator = 
base_network(pointnet_radius, pointnet_nclusters, 287 | model_scale, 4) 288 | self.predictions_logits = nn.Linear(1024 * model_scale, 1) 289 | self.confidence = nn.Linear(1024 * model_scale, 1) 290 | 291 | def evaluate(self, xyz, xyz_features): 292 | for module in self.evaluator[0]: 293 | xyz, xyz_features = module(xyz, xyz_features) 294 | return self.evaluator[1](xyz_features.squeeze(-1)) 295 | 296 | def forward(self, pc, gripper_pc, train=True): 297 | pc, pc_features = self.merge_pc_and_gripper_pc(pc, gripper_pc) 298 | x = self.evaluate(pc, pc_features.contiguous()) 299 | return self.predictions_logits(x), torch.sigmoid(self.confidence(x)) 300 | 301 | def merge_pc_and_gripper_pc(self, pc, gripper_pc): 302 | """ 303 | Merges the object point cloud and gripper point cloud and 304 | adds a binary auxiliary feature that indicates whether each point 305 | belongs to the object or to the gripper. 306 | """ 307 | pc_shape = pc.shape 308 | gripper_shape = gripper_pc.shape 309 | assert (len(pc_shape) == 3) 310 | assert (len(gripper_shape) == 3) 311 | assert (pc_shape[0] == gripper_shape[0]) 312 | 313 | npoints = pc_shape[1] 314 | batch_size = pc_shape[0] 315 | 316 | l0_xyz = torch.cat((pc, gripper_pc), 1) 317 | labels = [ 318 | torch.ones(pc.shape[1], 1, dtype=torch.float32), 319 | torch.zeros(gripper_pc.shape[1], 1, dtype=torch.float32) 320 | ] 321 | labels = torch.cat(labels, 0) 322 | labels.unsqueeze_(0) 323 | labels = labels.repeat(batch_size, 1, 1) 324 | 325 | l0_points = torch.cat([l0_xyz, labels.to(self.device)], 326 | -1).transpose(-1, 1) 327 | return l0_xyz, l0_points 328 | 329 | 330 | def base_network(pointnet_radius, pointnet_nclusters, scale, in_features): 331 | sa1_module = pointnet2.PointnetSAModule( 332 | npoint=pointnet_nclusters, 333 | radius=pointnet_radius, 334 | nsample=64, 335 | mlp=[in_features, 64 * scale, 64 * scale, 128 * scale]) 336 | sa2_module = pointnet2.PointnetSAModule( 337 | npoint=32, 338 | radius=0.04, 339 | nsample=128, 340 | mlp=[128 * scale, 128 * scale, 128 * scale, 256 * scale]) 341 | 342 | sa3_module = pointnet2.PointnetSAModule( 343 | mlp=[256 * scale, 256 * scale, 256 * scale, 512 * scale]) 344 | 345 | sa_modules = nn.ModuleList([sa1_module, sa2_module, sa3_module]) 346 | fc_layer = nn.Sequential(nn.Linear(512 * scale, 1024 * scale), 347 | nn.BatchNorm1d(1024 * scale), nn.ReLU(True), 348 | nn.Linear(1024 * scale, 1024 * scale), 349 | nn.BatchNorm1d(1024 * scale), nn.ReLU(True)) 350 | return nn.ModuleList([sa_modules, fc_layer]) 351 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from utils import utils 4 | import torch 5 | import shutil 6 | import yaml 7 | 8 | 9 | class BaseOptions: 10 | def __init__(self): 11 | self.parser = argparse.ArgumentParser( 12 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 13 | self.initialized = False 14 | 15 | def initialize(self): 16 | # data params 17 | self.parser.add_argument( 18 | '--dataset_root_folder', 19 | type=str, 20 | default= 21 | '/home/jens/Documents/datasets/grasping/unified_grasp_data/', 22 | help='path to root directory 
of the dataset.') 23 | self.parser.add_argument('--num_objects_per_batch', 24 | type=int, 25 | default=1, 26 | help='data batch size.') 27 | self.parser.add_argument('--num_grasps_per_object', 28 | type=int, 29 | default=64) 30 | self.parser.add_argument('--npoints', 31 | type=int, 32 | default=1024, 33 | help='number of points in each batch') 34 | self.parser.add_argument( 35 | '--occlusion_nclusters', 36 | type=int, 37 | default=0, 38 | help= 39 | 'clusters the points to nclusters to be selected for simulating the dropout' 40 | ) 41 | self.parser.add_argument( 42 | '--occlusion_dropout_rate', 43 | type=float, 44 | default=0, 45 | help= 46 | 'probability at which the clusters are removed from point cloud.') 47 | self.parser.add_argument('--depth_noise', type=float, 48 | default=0.0) # to be used in the data reader. 49 | self.parser.add_argument('--num_grasp_clusters', type=int, default=32) 50 | self.parser.add_argument('--arch', 51 | choices={"vae", "gan", "evaluator"}, 52 | default='vae') 53 | self.parser.add_argument('--max_dataset_size', 54 | type=int, 55 | default=float("inf"), 56 | help='Maximum number of samples per epoch') 57 | self.parser.add_argument('--num_threads', 58 | default=3, 59 | type=int, 60 | help='# threads for loading data') 61 | self.parser.add_argument( 62 | '--gpu_ids', 63 | type=str, 64 | default='0', 65 | help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 66 | self.parser.add_argument('--checkpoints_dir', 67 | type=str, 68 | default='./checkpoints', 69 | help='models are saved here') 70 | self.parser.add_argument( 71 | '--serial_batches', 72 | action='store_true', 73 | help='if true, takes meshes in order, otherwise takes them randomly' 74 | ) 75 | self.parser.add_argument('--seed', 76 | type=int, 77 | help='if specified, uses seed') 78 | self.parser.add_argument( 79 | '--gripper', 80 | type=str, 81 | default='panda', 82 | help= 83 | 'type of the gripper. Leave it to panda if you want to use it for franka robot' 84 | ) 85 | self.parser.add_argument('--latent_size', type=int, default=2) 86 | self.parser.add_argument( 87 | '--gripper_pc_npoints', 88 | type=int, 89 | default=-1, 90 | help= 91 | 'number of points representing the gripper. -1 just uses the points on the finger and also the base. other values use subsampling of the gripper mesh' 92 | ) 93 | self.parser.add_argument( 94 | '--merge_pcs_in_vae_encoder', 95 | type=int, 96 | default=0, 97 | help= 98 | 'whether to create unified pc in encoder by coloring the points (similar to evaluator' 99 | ) 100 | self.parser.add_argument( 101 | '--allowed_categories', 102 | type=str, 103 | default='', 104 | help= 105 | 'if left blank uses all the categories in the /splits/.json, otherwise only chooses the categories that are set.' 106 | ) 107 | self.parser.add_argument('--blacklisted_categories', 108 | type=str, 109 | default='', 110 | help='The opposite of allowed categories') 111 | 112 | self.parser.add_argument('--use_uniform_quaternions', 113 | type=int, 114 | default=0) 115 | self.parser.add_argument( 116 | '--model_scale', 117 | type=int, 118 | default=1, 119 | help= 120 | 'the scale of the parameters. Use scale >= 1. Scale=2 increases the number of parameters in model by 4x.' 121 | ) 122 | self.parser.add_argument( 123 | '--splits_folder_name', 124 | type=str, 125 | default='splits', 126 | help= 127 | 'Folder name for the directory that has all the jsons for train/test splits.' 
128 | ) 129 | self.parser.add_argument( 130 | '--grasps_folder_name', 131 | type=str, 132 | default='grasps', 133 | help= 134 | 'Directory that contains the grasps. Will be joined with the dataset_root_folder and the file names as defined in the splits.' 135 | ) 136 | self.parser.add_argument( 137 | '--pointnet_radius', 138 | help='Radius for ball query for PointNet++, just the first layer', 139 | type=float, 140 | default=0.02) 141 | self.parser.add_argument( 142 | '--pointnet_nclusters', 143 | help= 144 | 'Number of cluster centroids for PointNet++, just the first layer', 145 | type=int, 146 | default=128) 147 | self.parser.add_argument( 148 | '--init_type', 149 | type=str, 150 | default='normal', 151 | help='network initialization [normal|xavier|kaiming|orthogonal]') 152 | self.parser.add_argument( 153 | '--init_gain', 154 | type=float, 155 | default=0.02, 156 | help='scaling factor for normal, xavier and orthogonal.') 157 | self.parser.add_argument( 158 | '--grasps_ratio', 159 | type=float, 160 | default=1.0, 161 | help= 162 | 'used for checking the effect of number of grasps per object on the success of the model.' 163 | ) 164 | self.parser.add_argument( 165 | '--skip_error', 166 | action='store_true', 167 | help= 168 | 'Will not fill the dataset with a new grasp if it raises NoPositiveGraspsException' 169 | ) 170 | self.parser.add_argument( 171 | '--balanced_data', 172 | action='store_true', 173 | default=False, 174 | ) 175 | self.parser.add_argument( 176 | '--confidence_weight', 177 | type=float, 178 | default=1.0, 179 | help= 180 | 'initially I wanted to compute confidence for vae and evaluator outputs, ' 181 | 'setting the confidence weight to 1. immediately pushes the confidence to 1.0.' 182 | ) 183 | 184 | def parse(self): 185 | if not self.initialized: 186 | self.initialize() 187 | self.opt, unknown = self.parser.parse_known_args() 188 | self.opt.is_train = self.is_train # train or test 189 | if self.opt.is_train: 190 | self.opt.dataset_split = "train" 191 | else: 192 | self.opt.dataset_split = "test" 193 | self.opt.batch_size = self.opt.num_objects_per_batch * \ 194 | self.opt.num_grasps_per_object 195 | str_ids = self.opt.gpu_ids.split(',') 196 | self.opt.gpu_ids = [] 197 | for str_id in str_ids: 198 | id = int(str_id) 199 | if id >= 0: 200 | self.opt.gpu_ids.append(id) 201 | # set gpu ids 202 | if len(self.opt.gpu_ids) > 0: 203 | torch.cuda.set_device(self.opt.gpu_ids[0]) 204 | 205 | args = vars(self.opt) 206 | 207 | if self.opt.seed is not None: 208 | import numpy as np 209 | import random 210 | torch.manual_seed(self.opt.seed) 211 | np.random.seed(self.opt.seed) 212 | random.seed(self.opt.seed) 213 | 214 | if self.is_train: 215 | print('------------ Options -------------') 216 | for k, v in sorted(args.items()): 217 | print('%s: %s' % (str(k), str(v))) 218 | print('-------------- End ----------------') 219 | 220 | # save to the disk 221 | name = self.opt.arch 222 | name += "_lr_" + str(self.opt.lr).split(".")[-1] + "_bs_" + str( 223 | self.opt.batch_size) 224 | name += "_scale_" + str(self.opt.model_scale) + "_npoints_" + str( 225 | self.opt.pointnet_nclusters) + "_radius_" + str( 226 | self.opt.pointnet_radius).split(".")[-1] 227 | if self.opt.arch == "vae" or self.opt.arch == "gan": 228 | name += "_latent_size_" + str(self.opt.latent_size) 229 | 230 | self.opt.name = name 231 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 232 | if os.path.isdir(expr_dir) and not self.opt.continue_train: 233 | option = "Directory " + expr_dir + \ 234 | " already exists 
and you have not chosen to continue to train.\nDo you want to override that training instance with a new one the press (Y/N)." 235 | print(option) 236 | while True: 237 | choice = input() 238 | if choice.upper() == "Y": 239 | print("Overriding directory " + expr_dir) 240 | shutil.rmtree(expr_dir) 241 | utils.mkdir(expr_dir) 242 | break 243 | elif choice.upper() == "N": 244 | print( 245 | "Terminating. Remember, if you want to continue to train from a saved instance then run the script with the flag --continue_train" 246 | ) 247 | return None 248 | else: 249 | utils.mkdir(expr_dir) 250 | 251 | yaml_path = os.path.join(expr_dir, 'opt.yaml') 252 | with open(yaml_path, 'w') as yaml_file: 253 | yaml.dump(args, yaml_file) 254 | 255 | file_name = os.path.join(expr_dir, 'opt.txt') 256 | with open(file_name, 'wt') as opt_file: 257 | opt_file.write('------------ Options -------------\n') 258 | for k, v in sorted(args.items()): 259 | opt_file.write('%s: %s\n' % (str(k), str(v))) 260 | opt_file.write('-------------- End ----------------\n') 261 | return self.opt 262 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument( 8 | '--which_epoch', 9 | type=str, 10 | default='latest', 11 | help='which epoch to load? set to latest to use latest cached model' 12 | ) 13 | 14 | self.is_train = False -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument( 8 | '--print_freq', 9 | type=int, 10 | default=100, 11 | help='frequency of showing training results on console') 12 | self.parser.add_argument('--save_latest_freq', 13 | type=int, 14 | default=250, 15 | help='frequency of saving the latest results') 16 | self.parser.add_argument( 17 | '--save_epoch_freq', 18 | type=int, 19 | default=1, 20 | help='frequency of saving checkpoints at the end of epochs') 21 | self.parser.add_argument( 22 | '--run_test_freq', 23 | type=int, 24 | default=1, 25 | help='frequency of running test in training script') 26 | self.parser.add_argument( 27 | '--continue_train', 28 | action='store_true', 29 | help='continue training: load the latest model') 30 | self.parser.add_argument( 31 | '--epoch_count', 32 | type=int, 33 | default=1, 34 | help= 35 | 'the starting epoch count, we save the model by , +, ...' 36 | ) 37 | self.parser.add_argument('--phase', 38 | type=str, 39 | default='train', 40 | help='train, val, test, etc') 41 | self.parser.add_argument( 42 | '--which_epoch', 43 | type=str, 44 | default='latest', 45 | help='which epoch to load? 
set to latest to use latest cached model' 46 | ) 47 | self.parser.add_argument('--niter', 48 | type=int, 49 | default=100, 50 | help='# of iter at starting learning rate') 51 | self.parser.add_argument( 52 | '--niter_decay', 53 | type=int, 54 | default=2000, 55 | help='# of iter to linearly decay learning rate to zero') 56 | self.parser.add_argument('--beta1', 57 | type=float, 58 | default=0.9, 59 | help='momentum term of adam') 60 | self.parser.add_argument('--lr', 61 | type=float, 62 | default=0.0002, 63 | help='initial learning rate for adam') 64 | self.parser.add_argument( 65 | '--lr_policy', 66 | type=str, 67 | default='lambda', 68 | help='learning rate policy: lambda|step|plateau') 69 | self.parser.add_argument( 70 | '--lr_decay_iters', 71 | type=int, 72 | default=50, 73 | help='multiply by a gamma every lr_decay_iters iterations') 74 | self.parser.add_argument('--kl_loss_weight', type=float, default=0.01) 75 | self.parser.add_argument('--no_vis', 76 | action='store_true', 77 | help='will not use tensorboard') 78 | self.parser.add_argument('--verbose_plot', 79 | action='store_true', 80 | help='plots network weights, etc.') 81 | self.is_train = True -------------------------------------------------------------------------------- /renderer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/renderer/__init__.py -------------------------------------------------------------------------------- /renderer/object_renderer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | import trimesh 5 | import trimesh.transformations as tra 6 | import pyrender 7 | import numpy as np 8 | import copy 9 | import cv2 10 | import h5py 11 | import utils.sample as sample 12 | import math 13 | import sys 14 | import argparse 15 | import os 16 | 17 | 18 | class ObjectRenderer: 19 | def __init__(self, fov=np.pi / 6, object_paths=[], object_scales=[]): 20 | """ 21 | Args: 22 | fov: float, 23 | """ 24 | self._fov = fov 25 | self.mesh = None 26 | self._scene = None 27 | self.tmesh = None 28 | self._init_scene() 29 | self._object_nodes = [] 30 | self._object_means = [] 31 | self._object_distances = [] 32 | 33 | assert (isinstance(object_paths, list)) 34 | assert (len(object_paths) > 0) 35 | 36 | for path, scale in zip(object_paths, object_scales): 37 | node, obj_mean = self._load_object(path, scale) 38 | self._object_nodes.append(node) 39 | self._object_means.append(obj_mean) 40 | 41 | def _init_scene(self): 42 | self._scene = pyrender.Scene() 43 | camera = pyrender.PerspectiveCamera( 44 | yfov=self._fov, aspectRatio=1.0, 45 | znear=0.001) # do not change aspect ratio 46 | camera_pose = tra.euler_matrix(np.pi, 0, 0) 47 | 48 | self._scene.add(camera, pose=camera_pose, name='camera') 49 | 50 | light = pyrender.SpotLight(color=np.ones(4), 51 | intensity=3., 52 | innerConeAngle=np.pi / 16, 53 | outerConeAngle=np.pi / 6.0) 54 | self._scene.add(light, pose=camera_pose, name='light') 55 | 56 | self.renderer = pyrender.OffscreenRenderer(400, 400) 57 | 58 | def _load_object(self, path, scale=1.0): 59 | obj = sample.Object(path) 60 | obj.rescale(scale) 61 | print('rescaling with scale', scale) 62 | 63 | tmesh = obj.mesh 64 | tmesh_mean = np.mean(tmesh.vertices, 0) 65 | tmesh.vertices -= np.expand_dims(tmesh_mean, 0) 66 | 67 | lbs = np.min(tmesh.vertices, 0) 68 | 
ubs = np.max(tmesh.vertices, 0) 69 | self._object_distances.append(np.max(ubs - lbs) * 5) 70 | 71 | print(self._object_distances) 72 | 73 | self.tmesh = copy.deepcopy(tmesh) 74 | mesh = pyrender.Mesh.from_trimesh(tmesh) 75 | 76 | return self._scene.add(mesh, 77 | name='object'), np.expand_dims(tmesh_mean, 0) 78 | 79 | def _to_pointcloud(self, depth): 80 | fy = fx = 0.5 / np.tan(self._fov * 0.5) # aspectRatio is one. 81 | height = depth.shape[0] 82 | width = depth.shape[1] 83 | 84 | mask = np.where(depth > 0) 85 | 86 | x = mask[1] 87 | y = mask[0] 88 | 89 | normalized_x = (x.astype(np.float32) - width * 0.5) / width 90 | normalized_y = (y.astype(np.float32) - height * 0.5) / height 91 | 92 | world_x = normalized_x * depth[y, x] / fx 93 | world_y = normalized_y * depth[y, x] / fy 94 | world_z = depth[y, x] 95 | ones = np.ones(world_z.shape[0], dtype=np.float32) 96 | 97 | return np.vstack((world_x, world_y, world_z, ones)).T 98 | 99 | def render(self, object_poses, render_pc=True): 100 | assert (isinstance(object_poses, list)) 101 | assert (len(object_poses) == len(self._object_nodes)) 102 | 103 | all_transferred_poses = [] 104 | for object_pose, object_node, object_distance in zip( 105 | object_poses, self._object_nodes, self._object_distances): 106 | transferred_pose = object_pose.copy() 107 | transferred_pose[2, 3] = object_distance 108 | all_transferred_poses.append(transferred_pose) 109 | self._scene.set_pose(object_node, transferred_pose) 110 | 111 | color, depth = self.renderer.render(self._scene) 112 | 113 | if render_pc: 114 | pc = self._to_pointcloud(depth) 115 | else: 116 | pc = None 117 | 118 | return color, depth, pc, all_transferred_poses 119 | 120 | def render_all_and_save_to_h5(self, output_path, all_eulers, vis=False): 121 | """ 122 | Args: 123 | output_path: path of the h5 file. 124 | all_eulers: list of 3 elemenet-tuples indicating the euler angles 125 | in degrees. 
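        vis: currently unused; the mayavi visualization calls in the body are commented out.
        Note: each rendered point cloud is randomly re-sampled to exactly 3000 points (MAX_POINTS) before it is stored in the 'pcs' dataset of the h5 file.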
126 | """ 127 | if len(self._object_nodes) != 1: 128 | raise ValueError( 129 | 'object nodes should have 1 element, not {}'.format( 130 | len(self._object_nodes))) 131 | 132 | hf = h5py.File(output_path) 133 | #point_grp = hf.create_group('points') 134 | #pose_grp = hf.create_group('object_poses') 135 | mean_grp = hf.create_dataset('object_mean', data=self._object_means[0]) 136 | 137 | pcs = [] 138 | rotations = [] 139 | #import mayavi.mlab as mlab 140 | for i, euler in enumerate(all_eulers): 141 | assert isinstance(euler, tuple) and len(euler) == 3 142 | rotation = tra.euler_matrix(*euler) 143 | color, _, pc, final_rotation = self.render([rotation]) 144 | MAX_POINTS = 3000 145 | if pc.shape[0] > MAX_POINTS: 146 | pc = pc[np.random.choice( 147 | range(pc.shape[0]), replace=False, size=MAX_POINTS), :] 148 | elif pc.shape[0] < MAX_POINTS: 149 | pc = pc[np.random.choice( 150 | range(pc.shape[0]), replace=True, size=MAX_POINTS), :] 151 | 152 | #print('{}/{}: {}'.format(i, len(all_eulers), pc.shape)) 153 | cv2.imshow('w', color) 154 | cv2.waitKey(1) 155 | 156 | # mlab.figure() 157 | #mlab.points3d(pc[:, 0], pc[:, 1], pc[:, 2]) 158 | # mlab.show() 159 | key = '{}_{}_{}'.format(euler[0], euler[1], euler[2]) 160 | pcs.append(pc) 161 | rotations.append(final_rotation[0]) 162 | 163 | hf.create_dataset('pcs', data=pcs, compression='gzip') 164 | hf.create_dataset('object_poses', data=rotations, compression='gzip') 165 | 166 | hf.close() 167 | 168 | @property 169 | def object_distances(self): 170 | return self._object_distances 171 | -------------------------------------------------------------------------------- /renderer/online_object_renderer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | import numpy as np 5 | import copy 6 | import cv2 7 | import h5py 8 | import utils.sample as sample 9 | import utils.utils as utils 10 | import math 11 | import sys 12 | import argparse 13 | import os 14 | import time 15 | # Uncomment following line for headless rendering 16 | os.environ["PYOPENGL_PLATFORM"] = "egl" 17 | import pyrender 18 | 19 | import trimesh 20 | import trimesh.transformations as tra 21 | from multiprocessing import Manager 22 | import multiprocessing as mp 23 | 24 | 25 | class OnlineObjectRenderer: 26 | def __init__(self, fov=np.pi / 6, caching=True): 27 | """ 28 | Args: 29 | fov: float, 30 | """ 31 | self._fov = fov 32 | self._fy = self._fx = 1 / (0.5 / np.tan(self._fov * 0.5) 33 | ) # aspectRatio is one. 
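        # Note: _fx/_fy store the reciprocal of the normalized pinhole focal
        # length f = 0.5 / tan(fov / 2); _to_pointcloud() therefore unprojects a
        # pixel with normalized image coordinate u and depth d as
        # world_x = _fx * u * d, which equals u * d / f (the same convention as
        # ObjectRenderer._to_pointcloud above).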
34 | self.mesh = None 35 | self._scene = None 36 | self.tmesh = None 37 | self._init_scene() 38 | self._current_context = None 39 | self._cache = {} if caching else None 40 | self._caching = caching 41 | 42 | def _init_scene(self): 43 | self._scene = pyrender.Scene() 44 | camera = pyrender.PerspectiveCamera( 45 | yfov=self._fov, aspectRatio=1.0, 46 | znear=0.001) # do not change aspect ratio 47 | camera_pose = tra.euler_matrix(np.pi, 0, 0) 48 | 49 | self._scene.add(camera, pose=camera_pose, name='camera') 50 | 51 | #light = pyrender.SpotLight(color=np.ones(4), intensity=3., innerConeAngle=np.pi/16, outerConeAngle=np.pi/6.0) 52 | #self._scene.add(light, pose=camera_pose, name='light') 53 | 54 | self.renderer = None 55 | 56 | def _load_object(self, path, scale): 57 | if (path, scale) in self._cache: 58 | return self._cache[(path, scale)] 59 | obj = sample.Object(path) 60 | obj.rescale(scale) 61 | 62 | tmesh = obj.mesh 63 | tmesh_mean = np.mean(tmesh.vertices, 0) 64 | tmesh.vertices -= np.expand_dims(tmesh_mean, 0) 65 | 66 | lbs = np.min(tmesh.vertices, 0) 67 | ubs = np.max(tmesh.vertices, 0) 68 | object_distance = np.max(ubs - lbs) * 5 69 | 70 | mesh = pyrender.Mesh.from_trimesh(tmesh) 71 | 72 | context = { 73 | 'tmesh': copy.deepcopy(tmesh), 74 | 'distance': object_distance, 75 | 'node': pyrender.Node(mesh=mesh), 76 | 'mesh_mean': np.expand_dims(tmesh_mean, 0), 77 | } 78 | 79 | self._cache[(path, scale)] = context 80 | return self._cache[(path, scale)] 81 | 82 | def change_object(self, path, scale): 83 | if self._current_context is not None: 84 | self._scene.remove_node(self._current_context['node']) 85 | 86 | if not self._caching: 87 | self._cache = {} 88 | self._current_context = self._load_object(path, scale) 89 | self._scene.add_node(self._current_context['node']) 90 | 91 | def current_context(self): 92 | return self._current_context 93 | 94 | def _to_pointcloud(self, depth): 95 | height = depth.shape[0] 96 | width = depth.shape[1] 97 | 98 | mask = np.where(depth > 0) 99 | 100 | x = mask[1] 101 | y = mask[0] 102 | 103 | normalized_x = (x.astype(np.float32) - width * 0.5) / width 104 | normalized_y = (y.astype(np.float32) - height * 0.5) / height 105 | 106 | world_x = self._fx * normalized_x * depth[y, x] 107 | world_y = self._fy * normalized_y * depth[y, x] 108 | world_z = depth[y, x] 109 | ones = np.ones(world_z.shape[0], dtype=np.float32) 110 | 111 | return np.vstack((world_x, world_y, world_z, ones)).T 112 | 113 | def change_and_render(self, cad_path, cad_scale, pose, render_pc=True): 114 | self.change_object(cad_path, cad_scale) 115 | color, depth, pc, transferred_pose = self.render(pose) 116 | 117 | return color, depth, pc, transferred_pose 118 | 119 | def render(self, pose, render_pc=True): 120 | if self.renderer is None: 121 | self.renderer = pyrender.OffscreenRenderer(400, 400) 122 | if self._current_context is None: 123 | raise ValueError('invoke change_object first') 124 | transferred_pose = pose.copy() 125 | transferred_pose[2, 3] = self._current_context['distance'] 126 | self._scene.set_pose(self._current_context['node'], transferred_pose) 127 | 128 | color, depth = self.renderer.render(self._scene) 129 | 130 | if render_pc: 131 | pc = self._to_pointcloud(depth) 132 | else: 133 | pc = None 134 | 135 | return color, depth, pc, transferred_pose 136 | 137 | def render_canonical_pc(self, poses): 138 | all_pcs = [] 139 | for pose in poses: 140 | _, _, pc, pose = self.render(pose) 141 | pc = pc.dot(utils.inverse_transform(pose).T) 142 | all_pcs.append(pc) 143 | all_pcs = 
np.concatenate(all_pcs, 0) 144 | return all_pcs 145 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.10.0 2 | tqdm==4.45.0 3 | mayavi==4.7.1 4 | torch==1.4.0+cu100 -f 5 | pointnet2_ops==3.0.0 6 | trimesh==3.6.30 7 | numpy==1.17.4 8 | pyrender==0.1.39 9 | matplotlib==3.1.1 10 | easydict==1.9 11 | opencv_python==4.2.0.34 12 | PyYAML==5.4 13 | tensorboardX==2.0 14 | python-fcl 15 | rtree 16 | -------------------------------------------------------------------------------- /shapenet_ids.txt: -------------------------------------------------------------------------------- 1 | d7305324e9dd49eccee5e41d780064a2 2 | f7d776fd68b126f23b67070c4a034f08 3 | fad118b32085f3f2c2c72e575af174cd 4 | d851cbc873de1c4d3b6eb309177a6753 5 | c382be46c8ab7815d333084c1357713e 6 | 387b695db51190d3be276203d0b1a33f 7 | 162201dfe14b73f0281365259d1cf342 8 | 2efc35a3625fa50961a9876fa6384765 9 | 1be6b2c84cdab826c043c2d07bb83fc8 10 | 3702fc385b6c7708fc33503fd88ecb34 11 | a86d587f38569fdf394a7890920ef7fd 12 | 1bc5d303ff4d6e7e1113901b72a68e7c 13 | 6a772d12b98ab61dc26651d9d35b77ca 14 | 8bb057d18e2fcc4779368d1198f406e7 15 | 15787789482f045d8add95bf56d3d2fa 16 | 4be4184845972fba5ea36d5a57a5a3bb 17 | 6c379385bf0a23ffdec712af445786fe 18 | b6f30c63c946c286cf6897d8875cfd5e 19 | 1f035aa5fc6da0983ecac81e09b15ea9 20 | 8099b9d546231dd27b9c6deef486a7d8 21 | d75af64aa166c24eacbe2257d0988c9c 22 | 8e840ba109f252534c6955b188e290e0 23 | 7984d4980d5b07bceba393d429f71de3 24 | 9a52843cc89cd208362be90aaa182ec6 25 | 17952a204c0a9f526c69dceb67157a66 26 | 4b93a81c3534e56d19620b61f6587b3e 27 | 3093367916fb5216823323ed0e090a6f 28 | 8f6c86feaa74698d5c91ee20ade72edc 29 | 6faf1f04bde838e477f883dde7397db2 30 | ff1a44e1c1785d618bca309f2c51966a 31 | d2e7e725aa6b39f0d333084c1357713e 32 | b7e705de46ebdcc14af54ba5738cb1c5 33 | 6b810dbc89542fd8a531220b48579115 34 | 633379db14d4d2b287dd60af81c93a3c 35 | 10f6e09036350e92b3f21f1137c3c347 36 | c5eb3234b73037562825656dc457df78 37 | e94e46bc5833f2f5e57b873e4f3ef3a4 38 | 4b8b10d03552e0891898dfa8eb8eefff 39 | f626192a5930d6c712f0124e8fa3930b 40 | fa23aa60ec51c8e4c40fe5637f0a27e1 41 | b811555ccf5ef6c4948fa2daa427fe1f 42 | 3143a4accdc23349cac584186c95ce9b 43 | be3c2533130dd3da55f46d55537192b6 44 | 48e260a614c0fd4434a8988fdcee4fde 45 | 594b22f21daf33ce6aea2f18ee404fd5 46 | 3d3e993f7baa4d7ef1ff24a8b1564a36 47 | 46955fddcc83a50f79b586547e543694 48 | 4845731dbf7522b07492cbf7d8bec255 49 | c34718bd10e378186c6c61abcbd83e5a 50 | e9499e4a9f632725d6e865157050a80e 51 | d74bc917899133e080c257afea181fa2 52 | 3ae3a9b74f96fef28fe15648f042f0d9 53 | 34ae0b61b0d8aaf2d7b20fded0142d7a 54 | ecb86f63e92e346a25c70fb1df3f879b 55 | 8012f52dd0a4d2f718a93a45bf780820 56 | 83fd5e88eca47f1dab298e30fc6f45ba 57 | 8eab5598b81afd7bab5b523beb03efcd 58 | dc0926ce09d6ce78eb8e919b102c6c08 59 | 9afea0432f292379dc0e610397fef7f9 60 | dd381b3459767f7b18f18cdcd25d1bbb 61 | 5e896bc124bc0af9fd590443d27a974e 62 | 3108a736282eec1bc58e834f0b160845 63 | 128ecbc10df5b05d96eaf1340564a4de 64 | 20b7adb178ea2c71d8892a9c05c4aa0e 65 | 5fe74baba21bba7ca4eec1b19b3a18f8 66 | 244894af3ba967ccd957eaf7f4edb205 67 | 9d8c711750a73b06ad1d789f3b2120d0 68 | 336122c3105440d193e42e2720468bf0 69 | 3dbd66422997d234b811ffed11682339 70 | 4eefe941048189bdb8046e84ebdc62d2 71 | a637500654ca8d16c97cfc3e8a6b1d16 72 | 85a2511c375b5b32f72755048bac3f96 73 | 586e67c53f181dc22adf8abaa25e0215 74 | 83827973c79ca7631c9ec1e03e401f54 75 | 
cd418bf69c286701e301960cce35ab79 76 | 960c5c5bff2d3a4bbced73c51e99f8b2 77 | 6ca2149ac6d3699130612f5c0ef21eb8 78 | 68582543c4c6d0bccfdfe3f21f42a111 79 | 74221bae887b0e801ab89e10ca4a3aec 80 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from options.test_options import TestOptions 2 | from data import DataLoader 3 | from models import create_model 4 | from utils.writer import Writer 5 | 6 | 7 | def run_test(epoch=-1, name=""): 8 | print('Running Test') 9 | opt = TestOptions().parse() 10 | opt.serial_batches = True # no shuffle 11 | opt.name = name 12 | dataset = DataLoader(opt) 13 | model = create_model(opt) 14 | writer = Writer(opt) 15 | # test 16 | writer.reset_counter() 17 | 18 | for i, data in enumerate(dataset): 19 | model.set_input(data) 20 | ncorrect, nexamples = model.test() 21 | writer.update_counter(ncorrect, nexamples) 22 | writer.print_acc(epoch, writer.acc) 23 | return writer.acc 24 | 25 | 26 | if __name__ == '__main__': 27 | run_test() 28 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data import DataLoader 4 | from models import create_model 5 | from utils.writer import Writer 6 | from test import run_test 7 | import threading 8 | 9 | 10 | def main(): 11 | opt = TrainOptions().parse() 12 | if opt == None: 13 | return 14 | 15 | dataset = DataLoader(opt) 16 | dataset_size = len(dataset) * opt.num_grasps_per_object 17 | model = create_model(opt) 18 | writer = Writer(opt) 19 | total_steps = 0 20 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 21 | epoch_start_time = time.time() 22 | iter_data_time = time.time() 23 | epoch_iter = 0 24 | for i, data in enumerate(dataset): 25 | iter_start_time = time.time() 26 | if total_steps % opt.print_freq == 0: 27 | t_data = iter_start_time - iter_data_time 28 | total_steps += opt.batch_size 29 | epoch_iter += opt.batch_size 30 | model.set_input(data) 31 | model.optimize_parameters() 32 | if total_steps % opt.print_freq == 0: 33 | loss_types = [] 34 | if opt.arch == "vae": 35 | loss = [ 36 | model.loss, model.kl_loss, model.reconstruction_loss, 37 | model.confidence_loss 38 | ] 39 | loss_types = [ 40 | "total_loss", "kl_loss", "reconstruction_loss", 41 | "confidence loss" 42 | ] 43 | elif opt.arch == "gan": 44 | loss = [ 45 | model.loss, model.reconstruction_loss, 46 | model.confidence_loss 47 | ] 48 | loss_types = [ 49 | "total_loss", "reconstruction_loss", "confidence_loss" 50 | ] 51 | else: 52 | loss = [ 53 | model.loss, model.classification_loss, 54 | model.confidence_loss 55 | ] 56 | loss_types = [ 57 | "total_loss", "classification_loss", "confidence_loss" 58 | ] 59 | t = (time.time() - iter_start_time) / opt.batch_size 60 | writer.print_current_losses(epoch, epoch_iter, loss, t, t_data, 61 | loss_types) 62 | writer.plot_loss(loss, epoch, epoch_iter, dataset_size, 63 | loss_types) 64 | 65 | if i % opt.save_latest_freq == 0: 66 | print('saving the latest model (epoch %d, total_steps %d)' % 67 | (epoch, total_steps)) 68 | model.save_network('latest', epoch) 69 | 70 | iter_data_time = time.time() 71 | 72 | if epoch % opt.save_epoch_freq == 0: 73 | print('saving the model at the end of epoch %d, iters %d' % 74 | (epoch, total_steps)) 75 | model.save_network('latest', epoch) 76 | 
model.save_network(str(epoch), epoch) 77 | 78 | print('End of epoch %d / %d \t Time Taken: %d sec' % 79 | (epoch, opt.niter + opt.niter_decay, 80 | time.time() - epoch_start_time)) 81 | model.update_learning_rate() 82 | if opt.verbose_plot: 83 | writer.plot_model_wts(model, epoch) 84 | 85 | if epoch % opt.run_test_freq == 0: 86 | acc = run_test(epoch, name=opt.name) 87 | writer.plot_acc(acc, epoch) 88 | 89 | writer.close() 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsll/pytorch_6dof-graspnet/634a18c9f087ea25e7eced58c01db13394443fb9/utils/__init__.py -------------------------------------------------------------------------------- /utils/surface_normal.py: -------------------------------------------------------------------------------- 1 | try: 2 | import mayavi.mlab as mlab 3 | from visualization_utils import draw_scene 4 | import matplotlib.pyplot as plt 5 | except: 6 | pass 7 | import numpy as np 8 | import trimesh.transformations as tra 9 | import trimesh 10 | 11 | 12 | def cov_matrix(center, points): 13 | if points.shape[0] == 0: 14 | return None 15 | n = points.shape[0] 16 | diff = points - np.expand_dims(center, 0) 17 | cov = diff.T.dot(diff) / diff.shape[0] 18 | cov /= n 19 | 20 | eigen_values, eigen_vectors = np.linalg.eig(cov) 21 | 22 | order = np.argsort(-eigen_values) 23 | 24 | return eigen_values[order], eigen_vectors[:, order] 25 | 26 | 27 | def choose_direction(direction, point): 28 | dot = np.sum(direction * point) 29 | if dot >= 0: 30 | return -direction 31 | return direction 32 | 33 | 34 | def propose_grasps(pc, radius, num_grasps=1, vis=False): 35 | output_grasps = [] 36 | 37 | for _ in range(num_grasps): 38 | center_index = np.random.randint(pc.shape[0]) 39 | center_point = pc[center_index, :].copy() 40 | d = np.sqrt(np.sum(np.square(pc - np.expand_dims(center_point, 0)), 41 | -1)) 42 | index = np.where(d < radius)[0] 43 | neighbors = pc[index, :] 44 | 45 | eigen_values, eigen_vectors = cov_matrix(center_point, neighbors) 46 | direction = eigen_vectors[:, 2] 47 | 48 | direction = choose_direction(direction, center_point) 49 | 50 | surface_orientation = trimesh.geometry.align_vectors([0, 0, 1], 51 | direction) 52 | roll_orientation = tra.quaternion_matrix( 53 | tra.quaternion_about_axis(np.random.uniform(0, 2 * np.pi), 54 | [0, 0, 1])) 55 | gripper_transform = surface_orientation.dot(roll_orientation) 56 | gripper_transform[:3, 3] = center_point 57 | 58 | translation_transform = np.eye(4) 59 | translation_transform[2, 3] = -np.random.uniform(0.0669, 0.1122) 60 | 61 | gripper_transform = gripper_transform.dot(translation_transform) 62 | output_grasps.append(gripper_transform.copy()) 63 | 64 | if vis: 65 | draw_scene(pc, grasps=output_grasps) 66 | mlab.show() 67 | 68 | return np.asarray(output_grasps) 69 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import os 4 | import math 5 | import time 6 | import trimesh.transformations as tra 7 | import json 8 | from utils import sample 9 | import torch 10 | import yaml 11 | from easydict import EasyDict as edict 12 | 13 | GRIPPER_PC = np.load('gripper_models/panda_pc.npy', 14 | allow_pickle=True).item()['points'] 15 | GRIPPER_PC[:, 
3] = 1. 16 | 17 | 18 | def farthest_points(data, 19 | nclusters, 20 | dist_func, 21 | return_center_indexes=False, 22 | return_distances=False, 23 | verbose=False): 24 | """ 25 | Performs farthest point sampling on data points. 26 | Args: 27 | data: numpy array of the data points. 28 | nclusters: int, number of clusters. 29 | dist_dunc: distance function that is used to compare two data points. 30 | return_center_indexes: bool, If True, returns the indexes of the center of 31 | clusters. 32 | return_distances: bool, If True, return distances of each point from centers. 33 | 34 | Returns clusters, [centers, distances]: 35 | clusters: numpy array containing the cluster index for each element in 36 | data. 37 | centers: numpy array containing the integer index of each center. 38 | distances: numpy array of [npoints] that contains the closest distance of 39 | each point to any of the cluster centers. 40 | """ 41 | if nclusters >= data.shape[0]: 42 | if return_center_indexes: 43 | return np.arange(data.shape[0], 44 | dtype=np.int32), np.arange(data.shape[0], 45 | dtype=np.int32) 46 | 47 | return np.arange(data.shape[0], dtype=np.int32) 48 | 49 | clusters = np.ones((data.shape[0], ), dtype=np.int32) * -1 50 | distances = np.ones((data.shape[0], ), dtype=np.float32) * 1e7 51 | centers = [] 52 | for iter in range(nclusters): 53 | index = np.argmax(distances) 54 | centers.append(index) 55 | shape = list(data.shape) 56 | for i in range(1, len(shape)): 57 | shape[i] = 1 58 | 59 | broadcasted_data = np.tile(np.expand_dims(data[index], 0), shape) 60 | new_distances = dist_func(broadcasted_data, data) 61 | distances = np.minimum(distances, new_distances) 62 | clusters[distances == new_distances] = iter 63 | if verbose: 64 | print('farthest points max distance : {}'.format( 65 | np.max(distances))) 66 | 67 | if return_center_indexes: 68 | if return_distances: 69 | return clusters, np.asarray(centers, dtype=np.int32), distances 70 | return clusters, np.asarray(centers, dtype=np.int32) 71 | 72 | return clusters 73 | 74 | 75 | def distance_by_translation_grasp(p1, p2): 76 | """ 77 | Gets two nx4x4 numpy arrays and computes the translation of all the 78 | grasps. 79 | """ 80 | t1 = p1[:, :3, 3] 81 | t2 = p2[:, :3, 3] 82 | return np.sqrt(np.sum(np.square(t1 - t2), axis=-1)) 83 | 84 | 85 | def distance_by_translation_point(p1, p2): 86 | """ 87 | Gets two nx3 points and computes the distance between point p1 and p2. 88 | """ 89 | return np.sqrt(np.sum(np.square(p1 - p2), axis=-1)) 90 | 91 | 92 | def regularize_pc_point_count(pc, npoints, use_farthest_point=False): 93 | """ 94 | If point cloud pc has less points than npoints, it oversamples. 95 | Otherwise, it downsample the input pc to have npoint points. 96 | use_farthest_point: indicates whether to use farthest point sampling 97 | to downsample the points. Farthest point sampling version runs slower. 
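    A minimal usage sketch (assuming `pc` is an (N, 3) numpy array):
        pc = regularize_pc_point_count(pc, npoints=1024)  # pc.shape == (1024, 3)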
98 | """ 99 | if pc.shape[0] > npoints: 100 | if use_farthest_point: 101 | _, center_indexes = farthest_points(pc, 102 | npoints, 103 | distance_by_translation_point, 104 | return_center_indexes=True) 105 | else: 106 | center_indexes = np.random.choice(range(pc.shape[0]), 107 | size=npoints, 108 | replace=False) 109 | pc = pc[center_indexes, :] 110 | else: 111 | required = npoints - pc.shape[0] 112 | if required > 0: 113 | index = np.random.choice(range(pc.shape[0]), size=required) 114 | pc = np.concatenate((pc, pc[index, :]), axis=0) 115 | return pc 116 | 117 | 118 | def perturb_grasp(grasp, num, min_translation, max_translation, min_rotation, 119 | max_rotation): 120 | """ 121 | Self explanatory. 122 | """ 123 | output_grasps = [] 124 | for _ in range(num): 125 | sampled_translation = [ 126 | np.random.uniform(lb, ub) 127 | for lb, ub in zip(min_translation, max_translation) 128 | ] 129 | sampled_rotation = [ 130 | np.random.uniform(lb, ub) 131 | for lb, ub in zip(min_rotation, max_rotation) 132 | ] 133 | grasp_transformation = tra.euler_matrix(*sampled_rotation) 134 | grasp_transformation[:3, 3] = sampled_translation 135 | output_grasps.append(np.matmul(grasp, grasp_transformation)) 136 | 137 | return output_grasps 138 | 139 | 140 | def evaluate_grasps(grasp_tfs, obj_mesh): 141 | """ 142 | Check the collision of the grasps and also heuristic quality for each 143 | grasp. 144 | """ 145 | collisions, _ = sample.in_collision_with_gripper( 146 | obj_mesh, 147 | grasp_tfs, 148 | gripper_name='panda', 149 | silent=True, 150 | ) 151 | qualities = sample.grasp_quality_point_contacts( 152 | grasp_tfs, 153 | collisions, 154 | object_mesh=obj_mesh, 155 | gripper_name='panda', 156 | silent=True, 157 | ) 158 | 159 | return np.asarray(collisions), np.asarray(qualities) 160 | 161 | 162 | def inverse_transform(trans): 163 | """ 164 | Computes the inverse of 4x4 transform. 
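    For a rigid transform T = [[R, t], [0, 1]] the code below builds the
    closed-form inverse [[R.T, -R.T @ t], [0, 1]] instead of calling a generic
    4x4 matrix inverse.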
165 | """ 166 | rot = trans[:3, :3] 167 | t = trans[:3, 3] 168 | rot = np.transpose(rot) 169 | t = -np.matmul(rot, t) 170 | output = np.zeros((4, 4), dtype=np.float32) 171 | output[3][3] = 1 172 | output[:3, :3] = rot 173 | output[:3, 3] = t 174 | 175 | return output 176 | 177 | 178 | def uniform_quaternions(): 179 | quaternions = [ 180 | l[:-1].split('\t') for l in open( 181 | '../uniform_quaternions/data2_4608.qua', 'r').readlines() 182 | ] 183 | 184 | quaternions = [[float(t[0]), 185 | float(t[1]), 186 | float(t[2]), 187 | float(t[3])] for t in quaternions] 188 | quaternions = np.asarray(quaternions) 189 | quaternions = np.roll(quaternions, 1, axis=1) 190 | return [tra.quaternion_matrix(q) for q in quaternions] 191 | 192 | 193 | def nonuniform_quaternions(): 194 | all_poses = [] 195 | for az in np.linspace(0, np.pi * 2, 30): 196 | for el in np.linspace(-np.pi / 2, np.pi / 2, 30): 197 | all_poses.append(tra.euler_matrix(el, az, 0)) 198 | return all_poses 199 | 200 | 201 | def print_network(net): 202 | """Print the total number of parameters in the network 203 | Parameters: 204 | network 205 | """ 206 | print('---------- Network initialized -------------') 207 | num_params = 0 208 | for param in net.parameters(): 209 | num_params += param.numel() 210 | print('[Network] Total number of parameters : %.3f M' % (num_params / 1e6)) 211 | print('-----------------------------------------------') 212 | 213 | 214 | def merge_pc_and_gripper_pc(pc, 215 | gripper_pc, 216 | instance_mode=0, 217 | pc_latent=None, 218 | gripper_pc_latent=None): 219 | """ 220 | Merges the object point cloud and gripper point cloud and 221 | adds a binary auxiliary feature that indicates whether each point 222 | belongs to the object or to the gripper. 223 | """ 224 | 225 | pc_shape = pc.shape 226 | gripper_shape = gripper_pc.shape 227 | assert (len(pc_shape) == 3) 228 | assert (len(gripper_shape) == 3) 229 | assert (pc_shape[0] == gripper_shape[0]) 230 | 231 | npoints = pc.shape[1] 232 | batch_size = pc.shape[0] 233 | 234 | if instance_mode == 1: 235 | assert pc_shape[-1] == 3 236 | latent_dist = [pc_latent, gripper_pc_latent] 237 | latent_dist = torch.cat(latent_dist, 1) 238 | 239 | l0_xyz = torch.cat((pc, gripper_pc), 1) 240 | labels = [ 241 | torch.ones((pc.shape[1], 1), dtype=torch.float32), 242 | torch.zeros((gripper_pc.shape[1], 1), dtype=torch.float32) 243 | ] 244 | labels = torch.cat(labels, 0) 245 | labels = torch.unsqueeze(labels, 0) 246 | labels = labels.repeat(batch_size, 1, 1) 247 | 248 | if instance_mode == 1: 249 | l0_points = torch.cat([l0_xyz, latent_dist, labels], -1) 250 | else: 251 | l0_points = torch.cat([l0_xyz, labels], -1) 252 | 253 | return l0_xyz, l0_points 254 | 255 | 256 | def get_gripper_pc(batch_size, npoints, use_torch=True): 257 | """ 258 | Returns a numpy array or a tensor of shape (batch_size x npoints x 4). 259 | Represents gripper with the specified number of points. 260 | use_torch: switches between output tensor or numpy array.
261 | """ 262 | output = np.copy(GRIPPER_PC) 263 | if npoints != -1: 264 | assert (npoints > 0 and npoints <= output.shape[0] 265 | ), 'gripper_pc_npoint is too large {} > {}'.format( 266 | npoints, output.shape[0]) 267 | output = output[:npoints] 268 | output = np.expand_dims(output, 0) 269 | else: 270 | raise ValueError('npoints should not be -1.') 271 | 272 | if use_torch: 273 | output = torch.tensor(output, torch.float32) 274 | output = output.repeat(batch, size, 1, 1) 275 | return output 276 | else: 277 | output = np.tile(output, [batch_size, 1, 1]) 278 | 279 | return output 280 | 281 | 282 | def get_control_point_tensor(batch_size, use_torch=True, device="cpu"): 283 | """ 284 | Outputs a tensor of shape (batch_size x 6 x 3). 285 | use_tf: switches between outputing a tensor and outputing a numpy array. 286 | """ 287 | control_points = np.load('./gripper_control_points/panda.npy')[:, :3] 288 | control_points = [[0, 0, 0], [0, 0, 0], control_points[0, :], 289 | control_points[1, :], control_points[-2, :], 290 | control_points[-1, :]] 291 | control_points = np.asarray(control_points, dtype=np.float32) 292 | control_points = np.tile(np.expand_dims(control_points, 0), 293 | [batch_size, 1, 1]) 294 | 295 | if use_torch: 296 | return torch.tensor(control_points).to(device) 297 | 298 | return control_points 299 | 300 | 301 | def transform_control_points(gt_grasps, batch_size, mode='qt', device="cpu"): 302 | """ 303 | Transforms canonical points using gt_grasps. 304 | mode = 'qt' expects gt_grasps to have (batch_size x 7) where each 305 | element is catenation of quaternion and translation for each 306 | grasps. 307 | mode = 'rt': expects to have shape (batch_size x 4 x 4) where 308 | each element is 4x4 transformation matrix of each grasp. 309 | """ 310 | assert (mode == 'qt' or mode == 'rt'), mode 311 | grasp_shape = gt_grasps.shape 312 | if mode == 'qt': 313 | assert (len(grasp_shape) == 2), grasp_shape 314 | assert (grasp_shape[-1] == 7), grasp_shape 315 | control_points = get_control_point_tensor(batch_size, device=device) 316 | num_control_points = control_points.shape[1] 317 | input_gt_grasps = gt_grasps 318 | 319 | gt_grasps = torch.unsqueeze(input_gt_grasps, 320 | 1).repeat(1, num_control_points, 1) 321 | 322 | gt_q = gt_grasps[:, :, :4] 323 | gt_t = gt_grasps[:, :, 4:] 324 | gt_control_points = qrot(gt_q, control_points) 325 | gt_control_points += gt_t 326 | 327 | return gt_control_points 328 | else: 329 | assert (len(grasp_shape) == 3), grasp_shape 330 | assert (grasp_shape[1] == 4 and grasp_shape[2] == 4), grasp_shape 331 | control_points = get_control_point_tensor(batch_size, device=device) 332 | shape = control_points.shape 333 | ones = torch.ones((shape[0], shape[1], 1), dtype=torch.float32) 334 | control_points = torch.cat((control_points, ones), -1) 335 | return torch.matmul(control_points, gt_grasps.permute(0, 2, 1)) 336 | 337 | 338 | def transform_control_points_numpy(gt_grasps, batch_size, mode='qt'): 339 | """ 340 | Transforms canonical points using gt_grasps. 341 | mode = 'qt' expects gt_grasps to have (batch_size x 7) where each 342 | element is catenation of quaternion and translation for each 343 | grasps. 344 | mode = 'rt': expects to have shape (batch_size x 4 x 4) where 345 | each element is 4x4 transformation matrix of each grasp. 
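    Returns (batch_size x 6 x 3) control points in 'qt' mode and homogeneous
    (batch_size x 6 x 4) control points in 'rt' mode, where 6 is the number of
    gripper control points produced by get_control_point_tensor.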
346 | """ 347 | assert (mode == 'qt' or mode == 'rt'), mode 348 | grasp_shape = gt_grasps.shape 349 | if mode == 'qt': 350 | assert (len(grasp_shape) == 2), grasp_shape 351 | assert (grasp_shape[-1] == 7), grasp_shape 352 | control_points = get_control_point_tensor(batch_size, use_torch=False) 353 | num_control_points = control_points.shape[1] 354 | input_gt_grasps = gt_grasps 355 | gt_grasps = np.expand_dims(input_gt_grasps, 356 | 1).repeat(num_control_points, axis=1) 357 | gt_q = gt_grasps[:, :, :4] 358 | gt_t = gt_grasps[:, :, 4:] 359 | 360 | gt_control_points = rotate_point_by_quaternion(control_points, gt_q) 361 | gt_control_points += gt_t 362 | 363 | return gt_control_points 364 | else: 365 | assert (len(grasp_shape) == 3), grasp_shape 366 | assert (grasp_shape[1] == 4 and grasp_shape[2] == 4), grasp_shape 367 | control_points = get_control_point_tensor(batch_size, use_torch=False) 368 | shape = control_points.shape 369 | ones = np.ones((shape[0], shape[1], 1), dtype=np.float32) 370 | control_points = np.concatenate((control_points, ones), -1) 371 | return np.matmul(control_points, np.transpose(gt_grasps, (0, 2, 1))) 372 | 373 | 374 | def quaternion_mult(q, r): 375 | """ 376 | Multiply quaternion(s) q with quaternion(s) r. 377 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 378 | Returns q*r as a tensor of shape (*, 4). 379 | """ 380 | assert q.shape[-1] == 4 381 | assert r.shape[-1] == 4 382 | 383 | original_shape = q.shape 384 | 385 | # Compute outer product 386 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 387 | 388 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 389 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 390 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 391 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 392 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 393 | 394 | 395 | def conj_quaternion(q): 396 | """ 397 | Conjugate of quaternion q. 398 | """ 399 | q_conj = q.clone() 400 | q_conj[:, :, 1:] *= -1 401 | return q_conj 402 | 403 | 404 | def rotate_point_by_quaternion(point, q, device="cpu"): 405 | """ 406 | Takes in points with shape of (batch_size x n x 3) and quaternions with 407 | shape of (batch_size x n x 4) and returns a tensor with shape of 408 | (batch_size x n x 3) which is the rotation of the point with quaternion 409 | q. 
410 | """ 411 | shape = point.shape 412 | q_shape = q.shape 413 | 414 | assert (len(shape) == 3), 'point shape = {} q shape = {}'.format( 415 | shape, q_shape) 416 | assert (shape[-1] == 3), 'point shape = {} q shape = {}'.format( 417 | shape, q_shape) 418 | assert (len(q_shape) == 3), 'point shape = {} q shape = {}'.format( 419 | shape, q_shape) 420 | assert (q_shape[-1] == 4), 'point shape = {} q shape = {}'.format( 421 | shape, q_shape) 422 | assert (q_shape[1] == shape[1]), 'point shape = {} q shape = {}'.format( 423 | shape, q_shape) 424 | 425 | q_conj = conj_quaternion(q) 426 | r = torch.cat([ 427 | torch.zeros( 428 | (shape[0], shape[1], 1), dtype=point.dtype).to(device), point 429 | ], 430 | dim=-1) 431 | final_point = quaternion_mult(quaternion_mult(q, r), q_conj) 432 | final_output = final_point[:, :, 433 | 1:] #torch.slice(final_point, [0, 0, 1], shape) 434 | return final_output 435 | 436 | 437 | def tc_rotation_matrix(az, el, th, batched=False): 438 | if batched: 439 | 440 | cx = torch.cos(torch.reshape(az, [-1, 1])) 441 | cy = torch.cos(torch.reshape(el, [-1, 1])) 442 | cz = torch.cos(torch.reshape(th, [-1, 1])) 443 | sx = torch.sin(torch.reshape(az, [-1, 1])) 444 | sy = torch.sin(torch.reshape(el, [-1, 1])) 445 | sz = torch.sin(torch.reshape(th, [-1, 1])) 446 | 447 | ones = torch.ones_like(cx) 448 | zeros = torch.zeros_like(cx) 449 | 450 | rx = torch.cat([ones, zeros, zeros, zeros, cx, -sx, zeros, sx, cx], 451 | dim=-1) 452 | ry = torch.cat([cy, zeros, sy, zeros, ones, zeros, -sy, zeros, cy], 453 | dim=-1) 454 | rz = torch.cat([cz, -sz, zeros, sz, cz, zeros, zeros, zeros, ones], 455 | dim=-1) 456 | 457 | rx = torch.reshape(rx, [-1, 3, 3]) 458 | ry = torch.reshape(ry, [-1, 3, 3]) 459 | rz = torch.reshape(rz, [-1, 3, 3]) 460 | 461 | return torch.matmul(rz, torch.matmul(ry, rx)) 462 | else: 463 | cx = torch.cos(az) 464 | cy = torch.cos(el) 465 | cz = torch.cos(th) 466 | sx = torch.sin(az) 467 | sy = torch.sin(el) 468 | sz = torch.sin(th) 469 | 470 | rx = torch.stack([[1., 0., 0.], [0, cx, -sx], [0, sx, cx]], dim=0) 471 | ry = torch.stack([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dim=0) 472 | rz = torch.stack([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]], dim=0) 473 | 474 | return torch.matmul(rz, torch.matmul(ry, rx)) 475 | 476 | 477 | def mkdir(path): 478 | if not os.path.isdir(path): 479 | os.makedirs(path) 480 | 481 | 482 | def control_points_from_rot_and_trans(grasp_eulers, 483 | grasp_translations, 484 | device="cpu"): 485 | rot = tc_rotation_matrix(grasp_eulers[:, 0], 486 | grasp_eulers[:, 1], 487 | grasp_eulers[:, 2], 488 | batched=True) 489 | grasp_pc = get_control_point_tensor(grasp_eulers.shape[0], device=device) 490 | grasp_pc = torch.matmul(grasp_pc, rot.permute(0, 2, 1)) 491 | grasp_pc += grasp_translations.unsqueeze(1).expand(-1, grasp_pc.shape[1], 492 | -1) 493 | return grasp_pc 494 | 495 | 496 | def rot_and_trans_to_grasps(euler_angles, translations, selection_mask): 497 | grasps = [] 498 | refine_indexes, sample_indexes = np.where(selection_mask) 499 | for refine_index, sample_index in zip(refine_indexes, sample_indexes): 500 | rt = tra.euler_matrix(*euler_angles[refine_index, sample_index, :]) 501 | rt[:3, 3] = translations[refine_index, sample_index, :] 502 | grasps.append(rt) 503 | return grasps 504 | 505 | 506 | def convert_qt_to_rt(grasps): 507 | Ts = grasps[:, 4:] 508 | Rs = qeuler(grasps[:, :4], "zyx") 509 | return Rs, Ts 510 | 511 | 512 | def qeuler(q, order, epsilon=0): 513 | """ 514 | Convert quaternion(s) q to Euler angles. 
515 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 516 | Returns a tensor of shape (*, 3). 517 | """ 518 | assert q.shape[-1] == 4 519 | 520 | original_shape = list(q.shape) 521 | original_shape[-1] = 3 522 | q = q.view(-1, 4) 523 | 524 | q0 = q[:, 0] 525 | q1 = q[:, 1] 526 | q2 = q[:, 2] 527 | q3 = q[:, 3] 528 | 529 | if order == 'xyz': 530 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 531 | y = torch.asin( 532 | torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 533 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 534 | elif order == 'yzx': 535 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 536 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 537 | z = torch.asin( 538 | torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 539 | elif order == 'zxy': 540 | x = torch.asin( 541 | torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 542 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 543 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 544 | elif order == 'xzy': 545 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 546 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 547 | z = torch.asin( 548 | torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 549 | elif order == 'yxz': 550 | x = torch.asin( 551 | torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 552 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 553 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 554 | elif order == 'zyx': 555 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 556 | y = torch.asin( 557 | torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 558 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 559 | else: 560 | raise ValueError("Invalid order " + order) 561 | 562 | return torch.stack((x, y, z), dim=1).view(original_shape) 563 | 564 | 565 | def read_checkpoint_args(folder_path): 566 | return edict(yaml.load(open(os.path.join(folder_path, 'opt.yaml')))) 567 | 568 | 569 | def choose_grasps_better_than_threshold(eulers, 570 | translations, 571 | probs, 572 | threshold=0.7): 573 | """ 574 | Chooses the grasps that have scores higher than the input threshold. 575 | """ 576 | print('choose_better_than_threshold threshold=', threshold) 577 | return np.asarray(probs >= threshold, dtype=np.float32) 578 | 579 | 580 | def choose_grasps_better_than_threshold_in_sequence(eulers, 581 | translations, 582 | probs, 583 | threshold=0.7): 584 | """ 585 | Chooses the grasps with the maximum score in the sequence of grasp refinements. 586 | """ 587 | output = np.zeros(probs.shape, dtype=np.float32) 588 | max_index = np.argmax(probs, 0) 589 | max_value = np.max(probs, 0) 590 | for i in range(probs.shape[1]): 591 | if max_value[i] > threshold: 592 | output[max_index[i]][i] = 1. 593 | return output 594 | 595 | 596 | def denormalize_grasps(grasps, mean=0, std=1): 597 | temp = 1 / std 598 | for grasp in grasps: 599 | grasp[:3, 3] = (std * grasp[:3, 3] + mean) 600 | 601 | 602 | def quat2mat(quat): 603 | """Convert quaternion coefficients to rotation matrix. 604 | Args: 605 | quat: first three coeff of quaternion of rotation. 
fourth is then computed to have a norm of 1 -- size = [B, 3]
606 |     Returns:
607 |         Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
608 |     """
609 |     norm_quat = torch.cat([quat[:, :1].detach() * 0 + 1, quat], dim=1)
610 |     norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
611 |     w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:,
612 |                                                               2], norm_quat[:,
613 |                                                                             3]
614 | 
615 |     B = quat.size(0)
616 | 
617 |     w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
618 |     wx, wy, wz = w * x, w * y, w * z
619 |     xy, xz, yz = x * y, x * z, y * z
620 | 
621 |     rotMat = torch.stack([
622 |         w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
623 |         w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
624 |         w2 - x2 - y2 + z2
625 |     ],
626 |                          dim=1).reshape(B, 3, 3)
627 |     return rotMat
628 | 
629 | 
630 | def qrot(q, v):
631 |     """
632 |     Rotate vector(s) v about the rotation described by quaternion(s) q.
633 |     Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
634 |     where * denotes any number of dimensions.
635 |     Returns a tensor of shape (*, 3).
636 |     """
637 |     assert q.shape[-1] == 4
638 |     assert v.shape[-1] == 3
639 |     assert q.shape[:-1] == v.shape[:-1]
640 | 
641 |     original_shape = list(v.shape)
642 |     q = q.view(-1, 4)
643 |     v = v.view(-1, 3)
644 | 
645 |     qvec = q[:, 1:]
646 |     uv = torch.cross(qvec, v, dim=1)
647 |     uuv = torch.cross(qvec, uv, dim=1)
648 |     return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
649 | 
650 | 
651 | def get_inlier_grasp_indices(grasp_list, query_point, threshold=1.0, device="cpu"):
652 |     """Returns, for each element of grasp_list, the indices of the grasps whose distance between the mid point of the finger tips and the query point is less than the threshold value.
653 | 
654 |     Arguments:
655 |         grasps are given as a list of [B, 7] tensors where B is the number of grasps and the
656 |         7 values are the concatenation of the quaternion and translation of each grasp.
657 |         query_point is a 1x3 point in 3D space.
658 | threshold represents the maximum distance between a grasp and the query_point 659 | """ 660 | indices_to_keep = [] 661 | for grasps in grasp_list: 662 | grasp_cps = transform_control_points(grasps, 663 | grasps.shape[0], 664 | device=device) 665 | mid_points = get_mid_of_contact_points(grasp_cps) 666 | dist = torch.norm(mid_points - query_point, 2, dim=-1) 667 | indices_to_keep.append(torch.where(dist <= threshold)) 668 | return indices_to_keep 669 | 670 | 671 | def get_mid_of_contact_points(grasp_cps): 672 | mid = (grasp_cps[:, 0, :] + grasp_cps[:, 1, :]) / 2.0 673 | return mid 674 | 675 | 676 | def euclid_dist(point1, point2): 677 | return np.linalg.norm(point1 - point2) 678 | 679 | def partition_array_into_subarrays(array, sub_array_size): 680 | subarrays = [] 681 | for i in range(0, math.ceil(array.shape[0] / sub_array_size)): 682 | subarrays.append(array[i * sub_array_size:(i + 1) * sub_array_size]) 683 | return subarrays -------------------------------------------------------------------------------- /utils/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import mayavi.mlab as mlab 4 | from utils import utils, sample 5 | import numpy as np 6 | import trimesh 7 | 8 | 9 | def get_color_plasma_org(x): 10 | import matplotlib.pyplot as plt 11 | return tuple([x for i, x in enumerate(plt.cm.plasma(x)) if i < 3]) 12 | 13 | 14 | def get_color_plasma(x): 15 | return tuple([float(1 - x), float(x), float(0)]) 16 | 17 | 18 | def plot_mesh(mesh): 19 | assert type(mesh) == trimesh.base.Trimesh 20 | mlab.triangular_mesh(mesh.vertices[:, 0], 21 | mesh.vertices[:, 1], 22 | mesh.vertices[:, 2], 23 | mesh.faces, 24 | colormap='Blues') 25 | 26 | 27 | def draw_scene(pc, 28 | grasps=[], 29 | grasp_scores=None, 30 | grasp_color=None, 31 | gripper_color=(0, 1, 0), 32 | mesh=None, 33 | show_gripper_mesh=False, 34 | grasps_selection=None, 35 | visualize_diverse_grasps=False, 36 | min_seperation_distance=0.03, 37 | pc_color=None, 38 | plasma_coloring=False, 39 | target_cps=None): 40 | """ 41 | Draws the 3D scene for the object and the scene. 42 | Args: 43 | pc: point cloud of the object 44 | grasps: list of 4x4 numpy array indicating the transformation of the grasps. 45 | grasp_scores: grasps will be colored based on the scores. If left 46 | empty, grasps are visualized in green. 47 | grasp_color: if it is a tuple, sets the color for all the grasps. If list 48 | is provided it is the list of tuple(r,g,b) for each grasp. 49 | mesh: If not None, shows the mesh of the object. Type should be trimesh 50 | mesh. 51 | show_gripper_mesh: If True, shows the gripper mesh for each grasp. 52 | grasp_selection: if provided, filters the grasps based on the value of 53 | each selection. 1 means select ith grasp. 0 means exclude the grasp. 54 | visualize_diverse_grasps: sorts the grasps based on score. Selects the 55 | top score grasp to visualize and then choose grasps that are not within 56 | min_seperation_distance distance of any of the previously selected 57 | grasps. Only set it to True to declutter the grasps for better 58 | visualization. 59 | pc_color: if provided, should be a n x 3 numpy array for color of each 60 | point in the point cloud pc. Each number should be between 0 and 1. 61 | plasma_coloring: If True, sets the plasma colormap for visualizting the 62 | pc. 
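    Example (an illustrative sketch; assumes it is run from the repository root
    so that gripper_control_points/panda.npy resolves, and that mayavi can open
    a window, so finish with mlab.show()):
        pc = np.random.rand(500, 3) * 0.1
        g1 = np.eye(4)
        g2 = np.eye(4)
        g2[:3, 3] = [0.0, 0.0, 0.05]
        draw_scene(pc, grasps=[g1, g2], grasp_scores=[0.9, 0.4])
        mlab.show()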
63 | """ 64 | 65 | max_grasps = 100 66 | grasps = np.array(grasps) 67 | 68 | if grasp_scores is not None: 69 | grasp_scores = np.array(grasp_scores) 70 | 71 | if len(grasps) > max_grasps: 72 | 73 | print('Downsampling grasps, there are too many') 74 | chosen_ones = np.random.randint(low=0, 75 | high=len(grasps), 76 | size=max_grasps) 77 | grasps = grasps[chosen_ones] 78 | if grasp_scores is not None: 79 | grasp_scores = grasp_scores[chosen_ones] 80 | 81 | if mesh is not None: 82 | if type(mesh) == list: 83 | for elem in mesh: 84 | plot_mesh(elem) 85 | else: 86 | plot_mesh(mesh) 87 | 88 | if pc_color is None and pc is not None: 89 | if plasma_coloring: 90 | mlab.points3d(pc[:, 0], 91 | pc[:, 1], 92 | pc[:, 2], 93 | pc[:, 2], 94 | colormap='plasma') 95 | else: 96 | mlab.points3d(pc[:, 0], 97 | pc[:, 1], 98 | pc[:, 2], 99 | color=(0.1, 0.1, 1), 100 | scale_factor=0.01) 101 | elif pc is not None: 102 | if plasma_coloring: 103 | mlab.points3d(pc[:, 0], 104 | pc[:, 1], 105 | pc[:, 2], 106 | pc_color[:, 0], 107 | colormap='plasma') 108 | else: 109 | rgba = np.zeros((pc.shape[0], 4), dtype=np.uint8) 110 | rgba[:, :3] = np.asarray(pc_color) 111 | rgba[:, 3] = 255 112 | src = mlab.pipeline.scalar_scatter(pc[:, 0], pc[:, 1], pc[:, 2]) 113 | src.add_attribute(rgba, 'colors') 114 | src.data.point_data.set_active_scalars('colors') 115 | g = mlab.pipeline.glyph(src) 116 | g.glyph.scale_mode = "data_scaling_off" 117 | g.glyph.glyph.scale_factor = 0.01 118 | 119 | grasp_pc = np.squeeze(utils.get_control_point_tensor(1, False), 0) 120 | grasp_pc[2, 2] = 0.059 121 | grasp_pc[3, 2] = 0.059 122 | 123 | mid_point = 0.5 * (grasp_pc[2, :] + grasp_pc[3, :]) 124 | 125 | modified_grasp_pc = [] 126 | modified_grasp_pc.append(np.zeros((3, ), np.float32)) 127 | modified_grasp_pc.append(mid_point) 128 | modified_grasp_pc.append(grasp_pc[2]) 129 | modified_grasp_pc.append(grasp_pc[4]) 130 | modified_grasp_pc.append(grasp_pc[2]) 131 | modified_grasp_pc.append(grasp_pc[3]) 132 | modified_grasp_pc.append(grasp_pc[5]) 133 | 134 | grasp_pc = np.asarray(modified_grasp_pc) 135 | 136 | def transform_grasp_pc(g): 137 | output = np.matmul(grasp_pc, g[:3, :3].T) 138 | output += np.expand_dims(g[:3, 3], 0) 139 | 140 | return output 141 | 142 | if grasp_scores is not None: 143 | indexes = np.argsort(-np.asarray(grasp_scores)) 144 | else: 145 | indexes = range(len(grasps)) 146 | 147 | print('draw scene ', len(grasps)) 148 | 149 | selected_grasps_so_far = [] 150 | removed = 0 151 | 152 | if grasp_scores is not None: 153 | min_score = np.min(grasp_scores) 154 | max_score = np.max(grasp_scores) 155 | top5 = np.array(grasp_scores).argsort()[-5:][::-1] 156 | 157 | for ii in range(len(grasps)): 158 | i = indexes[ii] 159 | if grasps_selection is not None: 160 | if grasps_selection[i] == False: 161 | continue 162 | 163 | g = grasps[i] 164 | is_diverse = True 165 | for prevg in selected_grasps_so_far: 166 | distance = np.linalg.norm(prevg[:3, 3] - g[:3, 3]) 167 | 168 | if distance < min_seperation_distance: 169 | is_diverse = False 170 | break 171 | 172 | if visualize_diverse_grasps: 173 | if not is_diverse: 174 | removed += 1 175 | continue 176 | else: 177 | if grasp_scores is not None: 178 | print('selected', i, grasp_scores[i], min_score, max_score) 179 | else: 180 | print('selected', i) 181 | selected_grasps_so_far.append(g) 182 | 183 | if isinstance(gripper_color, list): 184 | pass 185 | elif grasp_scores is not None: 186 | normalized_score = (grasp_scores[i] - 187 | min_score) / (max_score - min_score + 0.0001) 188 | if grasp_color is 
not None: 189 | gripper_color = grasp_color[ii] 190 | else: 191 | gripper_color = get_color_plasma(normalized_score) 192 | 193 | if min_score == 1.0: 194 | gripper_color = (0.0, 1.0, 0.0) 195 | 196 | if show_gripper_mesh: 197 | gripper_mesh = sample.Object( 198 | 'gripper_models/panda_gripper.obj').mesh 199 | gripper_mesh.apply_transform(g) 200 | mlab.triangular_mesh( 201 | gripper_mesh.vertices[:, 0], 202 | gripper_mesh.vertices[:, 1], 203 | gripper_mesh.vertices[:, 2], 204 | gripper_mesh.faces, 205 | color=gripper_color, 206 | opacity=1 if visualize_diverse_grasps else 0.5) 207 | else: 208 | pts = np.matmul(grasp_pc, g[:3, :3].T) 209 | pts += np.expand_dims(g[:3, 3], 0) 210 | if isinstance(gripper_color, list): 211 | mlab.plot3d(pts[:, 0], 212 | pts[:, 1], 213 | pts[:, 2], 214 | color=gripper_color[i], 215 | tube_radius=0.003, 216 | opacity=1) 217 | else: 218 | tube_radius = 0.001 219 | mlab.plot3d(pts[:, 0], 220 | pts[:, 1], 221 | pts[:, 2], 222 | color=gripper_color, 223 | tube_radius=tube_radius, 224 | opacity=1) 225 | if target_cps is not None: 226 | mlab.points3d(target_cps[ii, :, 0], 227 | target_cps[ii, :, 1], 228 | target_cps[ii, :, 2], 229 | color=(1.0, 0.0, 0), 230 | scale_factor=0.01) 231 | 232 | print('removed {} similar grasps'.format(removed)) 233 | 234 | 235 | def get_axis(): 236 | # hacky axis for mayavi 237 | axis = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 238 | axis_x = np.array([np.linspace(0, 0.10, 50), np.zeros(50), np.zeros(50)]).T 239 | axis_y = np.array([np.zeros(50), np.linspace(0, 0.10, 50), np.zeros(50)]).T 240 | axis_z = np.array([np.zeros(50), np.zeros(50), np.linspace(0, 0.10, 50)]).T 241 | axis = np.concatenate([axis_x, axis_y, axis_z], axis=0) 242 | return axis 243 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | try: 5 | from tensorboardX import SummaryWriter 6 | except ImportError as error: 7 | print('tensorboard X not installed, visualizing wont be available') 8 | SummaryWriter = None 9 | 10 | 11 | class Writer: 12 | def __init__(self, opt): 13 | self.name = opt.name 14 | self.opt = opt 15 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 16 | self.log_name = os.path.join(self.save_dir, 'loss_log.txt') 17 | self.testacc_log = os.path.join(self.save_dir, 'testacc_log.txt') 18 | self.start_logs() 19 | self.nexamples = 0 20 | self.confidence_acc = 0 21 | self.ncorrect = 0 22 | 23 | if opt.is_train and not opt.no_vis and SummaryWriter is not None: 24 | self.display = SummaryWriter( 25 | logdir=os.path.join(self.opt.checkpoints_dir, self.opt.name) + 26 | "/tensorboard") #comment=opt.name) 27 | else: 28 | self.display = None 29 | 30 | def start_logs(self): 31 | """ creates test / train log files """ 32 | if self.opt.is_train: 33 | with open(self.log_name, "a") as log_file: 34 | now = time.strftime("%c") 35 | log_file.write( 36 | '================ Training Loss (%s) ================\n' % 37 | now) 38 | else: 39 | with open(self.testacc_log, "a") as log_file: 40 | now = time.strftime("%c") 41 | log_file.write( 42 | '================ Testing Acc (%s) ================\n' % 43 | now) 44 | 45 | def print_current_losses(self, 46 | epoch, 47 | i, 48 | losses, 49 | t, 50 | t_data, 51 | loss_types="total_loss"): 52 | """ prints train loss to terminal / file """ 53 | if type(losses) == list: 54 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f)' \ 55 | % 
(epoch, i, t, t_data) 56 | for (loss_type, loss_value) in zip(loss_types, losses): 57 | message += ' %s: %.3f' % (loss_type, loss_value.item()) 58 | else: 59 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) loss: %.3f ' \ 60 | % (epoch, i, t, t_data, losses.item()) 61 | print(message) 62 | with open(self.log_name, "a") as log_file: 63 | log_file.write('%s\n' % message) 64 | 65 | def plot_loss(self, losses, epoch, i, n, loss_types): 66 | iters = i + (epoch - 1) * n 67 | if self.display: 68 | if type(losses) == list: 69 | for (loss_type, loss_value) in zip(loss_types, losses): 70 | self.display.add_scalar('data/train_loss/' + loss_type, 71 | loss_value, iters) 72 | else: 73 | self.display.add_scalar('data/train_loss', losses, iters) 74 | 75 | def plot_model_wts(self, model, epoch): 76 | if self.opt.is_train and self.display: 77 | for name, param in model.net.named_parameters(): 78 | self.display.add_histogram(name, 79 | param.clone().cpu().data.numpy(), 80 | epoch) 81 | 82 | def print_acc(self, epoch, acc): 83 | """ prints test accuracy to terminal / file """ 84 | if self.opt.arch == "evaluator": 85 | message = 'epoch: {}, TEST ACC: [{:.5} %]\n' \ 86 | .format(epoch, acc * 100) 87 | else: 88 | message = 'epoch: {}, TEST REC LOSS: [{:.5}]\n' \ 89 | .format(epoch, acc) 90 | 91 | print(message) 92 | with open(self.testacc_log, "a") as log_file: 93 | log_file.write('%s\n' % message) 94 | 95 | def plot_acc(self, acc, epoch): 96 | if self.display: 97 | if self.opt.arch == "evaluator": 98 | self.display.add_scalar('data/test_acc/grasp_prediction', acc, 99 | epoch) 100 | else: 101 | self.display.add_scalar('data/test_loss/grasp_reconstruction', 102 | acc, epoch) 103 | 104 | def reset_counter(self): 105 | """ 106 | counts # of correct examples 107 | """ 108 | self.ncorrect = 0 109 | self.nexamples = 0 110 | 111 | def update_counter(self, ncorrect, nexamples): 112 | self.nexamples += nexamples 113 | self.ncorrect += ncorrect 114 | 115 | @property 116 | def acc(self): 117 | return float(self.ncorrect) / self.nexamples 118 | 119 | def close(self): 120 | if self.display is not None: 121 | self.display.close() 122 | --------------------------------------------------------------------------------
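A minimal sketch of driving the `Writer` above from a training loop. The `opt` namespace here is hypothetical (in the real pipeline it comes from `options/train_options.py`), and the `checkpoints/<name>` directory must exist before `Writer` opens its log files, so the sketch creates it explicitly:

```python
import os
import torch
from argparse import Namespace

from utils.writer import Writer

# Hypothetical option values standing in for options/train_options.py.
opt = Namespace(name="example_run", checkpoints_dir="checkpoints",
                is_train=True, no_vis=True)
os.makedirs(os.path.join(opt.checkpoints_dir, opt.name), exist_ok=True)

writer = Writer(opt)
for epoch in range(1, 3):
    # A constant tensor stands in for the real training loss.
    writer.print_current_losses(epoch, i=0, losses=torch.tensor(0.5), t=0.1, t_data=0.01)
writer.close()
```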