├── models ├── __init__.py ├── models.py └── multi_fingan.py ├── utils ├── __init__.py ├── plots.py ├── constrain_hand.py ├── tb_visualizer.py ├── contactutils.py ├── forward_kinematics_barrett.py └── util.py ├── datasets ├── __init__.py ├── custom_dataset_data_loader.py ├── egad_synthetic_oneobject.py ├── dataset.py └── ycb_synthetic_oneobject.py ├── networks ├── __init__.py ├── img_encoder_and_grasp_predictor.py ├── discriminator.py ├── grasp_generator.py └── networks.py ├── options ├── __init__.py ├── test_options.py ├── train_options.py └── base_options.py ├── data ├── objects_in_YCB.npy ├── average_hand_joints_per_taxonomy.npy ├── download_train_data.sh └── download_test_data.sh ├── images └── architecture.png ├── files ├── uniform_rotations.npy └── sample_rotations_one_hemisphere_new.npy ├── .gitignore ├── requirements.txt ├── LICENSE ├── README.md ├── train.py └── test.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/objects_in_YCB.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aalto-intelligent-robotics/Multi-FinGAN/HEAD/data/objects_in_YCB.npy -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aalto-intelligent-robotics/Multi-FinGAN/HEAD/images/architecture.png -------------------------------------------------------------------------------- /files/uniform_rotations.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aalto-intelligent-robotics/Multi-FinGAN/HEAD/files/uniform_rotations.npy -------------------------------------------------------------------------------- /data/average_hand_joints_per_taxonomy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aalto-intelligent-robotics/Multi-FinGAN/HEAD/data/average_hand_joints_per_taxonomy.npy -------------------------------------------------------------------------------- /files/sample_rotations_one_hemisphere_new.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aalto-intelligent-robotics/Multi-FinGAN/HEAD/files/sample_rotations_one_hemisphere_new.npy -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | *.swp 3 | __pycache__ 4 | */__pycache__ 5 | graspit_grasps/ 6 | 
data/meshes 7 | data/train_data 8 | data/test_data 9 | .vscode 10 | results/* 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.9 2 | cvxpy==1.1.3 3 | pyquaternion==0.9.5 4 | pandas==1.0.3 5 | numpy==1.13.3 6 | trimesh==3.6.30 7 | cvxpylayers==0.1.4 8 | cvxopt==1.2.5 9 | torch==1.5.0 10 | joblib==0.14.1 11 | numpy_quaternion==2020.5.19.15.27.24 12 | torchvision==0.6.0 13 | mayavi==4.7.1 14 | Pillow==8.1.2 15 | PyYAML==5.4.1 16 | quaternion==3.5.2.post4 17 | tensorboardX==2.1 -------------------------------------------------------------------------------- /data/download_train_data.sh: -------------------------------------------------------------------------------- 1 | mkdir -p $1"/train_data/" 2 | mkdir -p $1"/meshes/" 3 | wget --no-check-certificate -r 'https://drive.google.com/uc?export=download&id=1nPV-L2oG0g4h1lJ5hszsGpRCul1BcwTU' -O $1"ycb_meshes.zip" 4 | unzip -n -q $1"ycb_meshes.zip" -d $1"/meshes/" 5 | rm $1"ycb_meshes.zip" 6 | wget --no-check-certificate -r 'https://drive.google.com/uc?export=download&id=1X2jwA7wCDVb4J9r-HZWXtRBKD7fa2sOO' -O $1"graspit_training_grasps.zip" 7 | unzip -n -q $1"graspit_training_grasps.zip" -d $1"/train_data/" 8 | rm $1"graspit_training_grasps.zip" -------------------------------------------------------------------------------- /data/download_test_data.sh: -------------------------------------------------------------------------------- 1 | mkdir -p $1"/test_data/" 2 | mkdir -p $1"/meshes/" 3 | wget --no-check-certificate -r 'https://drive.google.com/uc?export=download&id=1nPV-L2oG0g4h1lJ5hszsGpRCul1BcwTU' -O $1"ycb_meshes.zip" 4 | unzip -n -q $1"ycb_meshes.zip" -d $1"/meshes/" 5 | rm $1"ycb_meshes.zip" 6 | wget --no-check-certificate -r 'https://drive.google.com/uc?export=download&id=1eQnMYiOJk1I26L6arWKYvINiInh3EyIe' -O $1"egad_validation_meshes.zip" 7 | unzip -n -q $1"egad_validation_meshes.zip" -d $1"/meshes/" 8 | rm $1"egad_validation_meshes.zip" 9 | wget --no-check-certificate -r 'https://drive.google.com/uc?export=download&id=1bqhKCOdtxqrsRjJ6bvI5E3MF6NUx00Li' -O $1"graspit_test_grasps_simulated_annealing.zip" 10 | unzip -n -q $1"graspit_test_grasps_simulated_annealing.zip" -d $1"/test_data/" 11 | rm $1"graspit_test_grasps_simulated_annealing.zip" -------------------------------------------------------------------------------- /networks/img_encoder_and_grasp_predictor.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .networks import NetworkBase 3 | import torchvision 4 | 5 | 6 | class Network(NetworkBase): 7 | def __init__(self, output_dim=7): 8 | super(Network, self).__init__() 9 | self.model = torchvision.models.resnet50(pretrained=True) 10 | self.model.fc = nn.Linear(2048, output_dim) 11 | self.name = 'image_encoder_and_grasp_predictor' 12 | 13 | def forward(self, x): 14 | x = self.model.conv1(x) 15 | x = self.model.bn1(x) 16 | x = self.model.relu(x) 17 | x = self.model.maxpool(x) 18 | x = self.model.layer1(x) 19 | x = self.model.layer2(x) 20 | x = self.model.layer3(x) 21 | x = self.model.layer4(x) 22 | x = self.model.avgpool(x) 23 | img_representation = x.view(x.size(0), -1) 24 | x = self.model.fc(img_representation) 25 | return x, img_representation 26 | -------------------------------------------------------------------------------- /utils/plots.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from mayavi import mlab 3 | mlab.options.offscreen = True 4 | 5 | 6 | def plot_scene_w_grasps(list_obj_verts, list_obj_faces, list_obj_handverts, list_obj_handfaces): 7 | figure = mlab.figure(1, bgcolor=(1, 1, 1), 8 | fgcolor=(0, 0, 0), size=(640, 480)) 9 | mlab.clf() 10 | for i in range(len(list_obj_verts)): 11 | vertices = list_obj_verts[i] 12 | mlab.triangular_mesh(vertices[:, 0], vertices[:, 1], vertices[:, 2], 13 | list_obj_faces[i], color=(1, 0, 0), opacity=0.5) 14 | for i in range(len(list_obj_handverts)): 15 | vertices = list_obj_handverts[i] 16 | mlab.triangular_mesh(vertices[:, 0], vertices[:, 1], 17 | vertices[:, 2], list_obj_handfaces[i], color=(0, 0, 1)) 18 | mlab.view(azimuth=-90, distance=1.5) 19 | data = mlab.screenshot(figure) 20 | return data 21 | -------------------------------------------------------------------------------- /networks/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .networks import NetworkBase 3 | 4 | 5 | class Discriminator(NetworkBase): 6 | """Discriminator.""" 7 | 8 | def __init__(self, input_dim=7): 9 | super(Discriminator, self).__init__() 10 | self.name = 'discriminator' 11 | self.fc1 = nn.Linear(input_dim, 48) 12 | self.fc2 = nn.Linear(48, 32) 13 | self.fc3 = nn.Linear(32, 16) 14 | self.fc4 = nn.Linear(16, 1) 15 | self.layer_norm1 = nn.LayerNorm(48) 16 | self.layer_norm2 = nn.LayerNorm(32) 17 | self.layer_norm3 = nn.LayerNorm(16) 18 | self.model = nn.Sequential( 19 | self.fc1, 20 | nn.ReLU(), 21 | self.layer_norm1, 22 | self.fc2, 23 | nn.ReLU(), 24 | self.layer_norm2, 25 | self.fc3, 26 | nn.ReLU(), 27 | self.layer_norm3, 28 | self.fc4, 29 | ) 30 | 31 | def forward(self, x): 32 | return self.model(x) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jens Lundell 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /datasets/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from datasets.dataset import DatasetFactory 3 | 4 | 5 | class CustomDatasetDataLoader: 6 | def __init__(self, opt, mode='train'): 7 | self.opt = opt 8 | self.mode = mode 9 | self.num_threds = opt.n_threads_train 10 | self.create_dataset() 11 | 12 | def create_dataset(self): 13 | self.dataset = DatasetFactory.get_by_name( 14 | self.opt.dataset_name, self.opt, self.mode) 15 | if hasattr(self.dataset, 'collate_fn'): 16 | self.dataloader = torch.utils.data.DataLoader( 17 | self.dataset, 18 | batch_size=self.opt.batch_size, 19 | collate_fn=self.dataset.collate_fn, 20 | shuffle=not self.opt.serial_batches and self.mode == 'train', 21 | num_workers=int(self.num_threds), 22 | drop_last=True) 23 | else: 24 | self.dataloader = torch.utils.data.DataLoader( 25 | self.dataset, 26 | batch_size=self.opt.batch_size, 27 | shuffle=not self.opt.serial_batches and self.mode == 'train', 28 | num_workers=int(self.num_threds), 29 | drop_last=True) 30 | 31 | def load_data(self): 32 | return self.dataloader 33 | 34 | def __len__(self): 35 | return len(self.dataset) 36 | -------------------------------------------------------------------------------- /networks/grasp_generator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .networks import NetworkBase 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class Network(NetworkBase): 8 | def __init__(self, input_dim): 9 | super(Network, self).__init__() 10 | self.name = 'grasp_generator' 11 | self.fc0 = nn.Linear(2048, 128) 12 | self.fc1 = nn.Linear(128 + input_dim, 256) 13 | self.fc2 = nn.Linear(256, 256) 14 | self.fc3 = nn.Linear(256, 312) 15 | self.fc4 = nn.Linear(312, 256) 16 | self.fcHR_residual = nn.Linear(256, 128) 17 | self.fcHR_residual_2 = nn.Linear( 18 | 128, 1) 19 | self.fcR = nn.Linear(256, 64) 20 | self.fcR_2 = nn.Linear(64, 3) 21 | self.fcT = nn.Linear(256, 64) 22 | self.fcT_2 = nn.Linear(64, 3) 23 | 24 | def forward(self, image_representation, hand_representations, Ro, Ts): 25 | x_image = self.fc0(image_representation) 26 | x = self.fc1(torch.cat((x_image, hand_representations, Ro, Ts), -1)) 27 | 28 | x = F.relu(x) 29 | x = self.fc2(x) 30 | x = F.relu(x) 31 | x = self.fc3(x) 32 | x = F.relu(x) 33 | x = self.fc4(x) 34 | x = F.relu(x) 35 | 36 | x_hr = self.fcHR_residual(x) 37 | x_hr = F.relu(x_hr) 38 | HR = torch.zeros((Ro.shape[0], 7)).cuda() 39 | # Only output the spread between the fingers and not the actual finger rotations as those are refined 40 | # in the refinement layer 41 | HR[:, :1] = self.fcHR_residual_2( 42 | x_hr) 43 | HR += hand_representations 44 | x_r = self.fcR(x) 45 | R = Ro + self.fcR_2(x_r) 46 | x_t = self.fcT(x) 47 | x_t = F.relu(x_t) 48 | T = Ts + self.fcT_2(x_t) 49 | return HR, R, T 50 | -------------------------------------------------------------------------------- /networks/networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import functools 3 | 4 | 5 | class NetworksFactory: 6 | def __init__(self): 7 | pass 8 | 9 | @staticmethod 10 | def get_by_name(network_name, *args, **kwargs): 11 | 12 | if network_name == 'img_encoder_and_grasp_predictor': 13 | from .img_encoder_and_grasp_predictor import Network 14 | network = Network(*args, **kwargs) 15 
| elif network_name == 'grasp_generator': 16 | from .grasp_generator import Network 17 | network = Network(*args, **kwargs) 18 | elif network_name == 'discriminator': 19 | from .discriminator import Discriminator 20 | network = Discriminator(*args, **kwargs) 21 | else: 22 | raise ValueError("Network %s not recognized." % network_name) 23 | 24 | print("Network %s was created" % network_name) 25 | 26 | return network 27 | 28 | 29 | class NetworkBase(nn.Module): 30 | def __init__(self): 31 | super(NetworkBase, self).__init__() 32 | self.name = 'BaseNetwork' 33 | 34 | def get_name(self): 35 | return self.name 36 | 37 | def init_weights(self): 38 | self.apply(self.weights_init_fn) 39 | 40 | def weights_init_fn(self, m): 41 | classname = m.__class__.__name__ 42 | if classname.find('Conv') != -1: 43 | m.weight.data.normal_(0.0, 0.02) 44 | if hasattr(m.bias, 'data'): 45 | m.bias.data.fill_(0) 46 | elif classname.find('BatchNorm2d') != -1: 47 | m.weight.data.normal_(1.0, 0.02) 48 | m.bias.data.fill_(0) 49 | 50 | def get_norm_layer(self, norm_type='batch'): 51 | if norm_type == 'batch': 52 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 53 | elif norm_type == 'instance': 54 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 55 | elif norm_type == 'batchnorm2d': 56 | norm_layer = nn.BatchNorm2d 57 | else: 58 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 59 | 60 | return norm_layer 61 | -------------------------------------------------------------------------------- /utils/constrain_hand.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from utils import util 4 | import cvxpy as cp 5 | from cvxpylayers.torch import CvxpyLayer 6 | 7 | 8 | class ConstrainHand(nn.Module): 9 | def __init__(self, constrain_method, batch_size, device='cpu'): 10 | super(ConstrainHand, self).__init__() 11 | self.constain_method = constrain_method 12 | if constrain_method == "hard": 13 | self.constraint = nn.Hardtanh(0, util.deg2rad(180)) 14 | if constrain_method == "soft": 15 | self.constraint = nn.Sigmoid() 16 | if constrain_method == "cvx": 17 | self.set_cvx_layer(batch_size, device) 18 | 19 | def forward(self, HR): 20 | HR_unconstrained = HR.clone() 21 | if self.constain_method == "cvx": 22 | solution, = self.cvxpylayer(self.theta_max_torch, self.theta_min_torch, HR) 23 | less_than_min = HR < self.theta_min_torch 24 | HR[less_than_min] = HR[less_than_min] + solution[less_than_min] 25 | more_than_max = HR > self.theta_max_torch 26 | HR[more_than_max] = HR[more_than_max] - solution[more_than_max] 27 | elif self.constain_method == "hard": 28 | HR[:, 0] = self.constraint(1*HR[:, 0]) 29 | elif self.constain_method == "soft": 30 | HR[:, 0] = self.constraint(HR[:, 0])*util.deg2rad(180) 31 | return HR, HR_unconstrained 32 | 33 | def set_cvx_layer(self, batch_size, device): 34 | x = cp.Variable((batch_size, 7)) 35 | theta_max = cp.Parameter((batch_size, 7)) 36 | theta_min = cp.Parameter((batch_size, 7)) 37 | theta = cp.Parameter((batch_size, 7)) 38 | constraints = [theta-x <= theta_max, theta+x >= theta_min] 39 | objective = cp.Minimize(cp.pnorm(x)) 40 | problem = cp.Problem(objective, constraints) 41 | assert problem.is_dpp() 42 | self.cvxpylayer = CvxpyLayer(problem, parameters=[ 43 | theta_max, theta_min, theta], variables=[x]) 44 | eps = 1e-10 45 | self.theta_max_torch = util.deg2rad(torch.tensor( 46 | [180., 140., 140., 140., 48., 48., 48.], requires_grad=True)).to(device)-eps 47 | 
self.theta_max_torch = self.theta_max_torch.unsqueeze( 48 | 0).repeat(batch_size, 1)+eps 49 | self.theta_min_torch = torch.zeros( 50 | (batch_size, 7), requires_grad=True).to(device) 51 | -------------------------------------------------------------------------------- /utils/tb_visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | from . import util 5 | from tensorboardX import SummaryWriter 6 | 7 | 8 | class TBVisualizer: 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.save_path = os.path.join(opt.checkpoints_dir, opt.name) 12 | 13 | self.log_path = os.path.join(self.save_path, 'loss_log2.txt') 14 | self.tb_path = os.path.join(self.save_path, 'summary.json') 15 | self.writer = SummaryWriter(self.save_path) 16 | 17 | with open(self.log_path, "a") as log_file: 18 | now = time.strftime("%c") 19 | log_file.write('================ Training Loss (%s) ================\n' % now) 20 | 21 | def __del__(self): 22 | self.writer.close() 23 | 24 | def display_current_results(self, visuals, it, is_train, save_visuals=False): 25 | for label, image_numpy in visuals.items(): 26 | sum_name = '{}/{}'.format('Train' if is_train else 'Test', label) 27 | if type(image_numpy) is list: 28 | if len(image_numpy) == 0 or len(image_numpy[0]) == 0: 29 | continue 30 | images = (np.stack(image_numpy).transpose(0, 3, 1, 2))/255.0 31 | self.writer.add_images(sum_name, images, it) 32 | else: 33 | self.writer.add_image(sum_name, image_numpy.transpose((2, 0, 1)), it) 34 | 35 | self.writer.export_scalars_to_json(self.tb_path) 36 | 37 | def plot_scalars(self, scalars, it, is_train): 38 | for label, scalar in scalars.items(): 39 | sum_name = '{}/{}'.format('Train' if is_train else 'Test', label) 40 | self.writer.add_scalar(sum_name, scalar, it) 41 | 42 | def print_current_train_errors(self, epoch, i, iters_per_epoch, errors, t, visuals_were_stored): 43 | log_time = time.strftime("[%d/%m/%Y %H:%M:%S]") 44 | visuals_info = "v" if visuals_were_stored else "" 45 | message = '%s (T%s, epoch: %d, it: %d/%d, t/smpl: %.3fs) ' % (log_time, visuals_info, epoch, i, iters_per_epoch, t) 46 | for k, v in errors.items(): 47 | message += '%s:%.3f ' % (k, v) 48 | 49 | print(message) 50 | with open(self.log_path, "a") as log_file: 51 | log_file.write('%s\n' % message) 52 | 53 | def print_current_validate_errors(self, epoch, errors, t): 54 | log_time = time.strftime("[%d/%m/%Y %H:%M:%S]") 55 | message = '%s (V, epoch: %d, time_to_val: %ds) ' % (log_time, epoch, t) 56 | for k, v in errors.items(): 57 | message += '%s:%.3f ' % (k, v) 58 | 59 | print(message) 60 | with open(self.log_path, "a") as log_file: 61 | log_file.write('%s\n' % message) 62 | 63 | def save_images(self, visuals): 64 | for label, image_numpy in visuals.items(): 65 | image_name = '%s.png' % label 66 | save_path = os.path.join(self.save_path, "samples", image_name) 67 | util.save_image(image_numpy, save_path) 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-FinGAN 2 | 3 | This repository includes code used in our work on [Multi-FinGAN: Generative Coarse-To-Fine Samplingof Multi-Finger Grasps](https://arxiv.org/pdf/2012.09696.pdf). More specifically, it includes code to train a Multi-FinGAN model and do the simulation experiments. 
4 | 5 | **Authors**: Jens Lundell\ 6 | **Maintainer**: Jens Lundell, jens.lundell@aalto.fi 7 | **Affiliation**: Intelligent Robotics Lab, Aalto University 8 | 9 | ## Getting Started 10 | 11 | The code was developed for Python 3.6 and Ubuntu 18.04. 12 | 13 | ### Dependencies 14 | 15 | - Python requirements: Run `pip install -r requirements.txt`. 16 | - [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) 17 | - [Barrett kinematics layer](https://github.com/aalto-intelligent-robotics/pytorch_barrett_hand_forward_kinematics_layer) 18 | 19 | Note: for Kaolin you need to check out v0.1: 20 | 21 | ``` 22 | git checkout v0.1 23 | python setup.py install --user 24 | ``` 25 | 26 | ## Model 27 | 28 | ![architecture](images/architecture.png) 29 | 30 | Multi-FinGAN takes a single RGB image of one object and predicts a multi-fingered robotic grasp on that object. Our architecture consists of three stages. First, the object's shape is completed. Next, grasps are generated on the object in the image by feeding the image through an image encoder, predicting grasp types on the object, and then generating 6D grasp poses and finger configurations. Finally, the hand is refined to be close to the surface of the target object but not in collision with it, using our parameter-free, fully differentiable [Barrett kinematics layer](https://github.com/aalto-intelligent-robotics/pytorch_barrett_hand_forward_kinematics_layer). 31 | 32 | ## Train 33 | 34 | To train a Multi-FinGAN model as in the paper, run the following (training data is pulled automatically): 35 | 36 | ``` 37 | python train.py --pregenerate_data 38 | ``` 39 | 40 | We recommend pre-generating training data as this considerably speeds up training. Additional command-line flags are found [here](options/base_options.py) and [here](options/train_options.py). 41 | 42 | To visualize the training progress, run: 43 | 44 | ``` 45 | tensorboard --logdir checkpoints/ --samples_per_plugin=images=100 46 | ``` 47 | 48 | ## Simulation experiments 49 | 50 | Here we detail how to reproduce the Grasping in Simulation experiments in Section V-B of our [paper](https://arxiv.org/pdf/2012.09696.pdf). 51 | 52 | To test the GraspIt! grasps, run: 53 | 54 | ``` 55 | python test.py ... 56 | ``` 57 | 58 | To test Multi-FinGAN, run: 59 | 60 | ``` 61 | python test.py ... 
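# The flags are elided ("...") in the original instructions. A hypothetical invocation,
# using flag names defined in options/test_options.py (the checkpoint path is a placeholder):
# python test.py --checkpoint_dir_load checkpoints/<run_name> --load_epoch latest --test_set ycb --num_grasps_to_sample 30
# For the GraspIt! baseline above, add --graspit instead (no checkpoint is needed then).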
62 | ``` 63 | 64 | ## Headless servers 65 | 66 | If you want to run the code on a headless server but run into problems, try the following: 67 | 68 | ``` 69 | export QT_API=pyqt 70 | xvfb-run python train.py --pregenerate_data 71 | ``` 72 | 73 | ## Citation 74 | 75 | If this code is useful in your research, please consider citing: 76 | 77 | ``` 78 | @article{lundell2020multi, 79 | title={Multi-FinGAN: Generative Coarse-To-Fine Sampling of Multi-Finger Grasps}, 80 | author={Lundell, Jens and Corona, Enric and Le, Tran Nguyen and Verdoja, Francesco and Weinzaepfel, Philippe and Rogez, Gregory and Moreno-Noguer, Francesc and Kyrki, Ville}, 81 | journal={arXiv preprint arXiv:2012.09696}, 82 | year={2020} 83 | } 84 | ``` 85 | 86 | ## License 87 | 88 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 89 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | import torch 3 | import numpy as np 4 | import random 5 | from easydict import EasyDict 6 | 7 | 8 | class TestOptions(BaseOptions): 9 | def initialize(self): 10 | BaseOptions.initialize(self) 11 | self.parser.add_argument('--graspit', action='store_true') 12 | self.parser.add_argument('--threshold_intersections', type=float, default=50) 13 | self.parser.add_argument('--threshold_contact', type=float, default=0.009, 14 | help="threshold in meters to determine whether two vertices are in contact") 15 | self.parser.add_argument('--display', action='store_true') 16 | self.parser.add_argument('--save_folder', type=str, default='./results/simulation_results/') 17 | self.parser.add_argument('--test_set', default="ycb", choices=["ycb", "egad"], type=str) 18 | opts, _ = self.parser.parse_known_args() 19 | if not opts.graspit: 20 | self.parser.add_argument('--num_viewpoints_per_object_test', default=5, type=int) 21 | self.parser.add_argument('--num_grasps_to_sample', type=int, choices=range(2, 200), default=30) 22 | self.parser.add_argument('--checkpoint_dir_load', type=str, help='path to checkpoint we want to load', required="true") 23 | self.parser.add_argument('--load_epoch', type=str, default='latest', 24 | help='which epoch to load? 
') 25 | 26 | def parse(self): 27 | if not self.initialized: 28 | self.initialize() 29 | self.opt = EasyDict(vars(self.parser.parse_args())) 30 | 31 | # set and check load_epoch 32 | if self.opt.graspit: 33 | self.opt.save_folder += "graspit/" 34 | if self.opt.test_set == "ycb": 35 | self.opt.graspit_grasp_dir = './data/test_data/ycb_graspit_grasps/' 36 | self.opt.object_mesh_dir = './data/meshes/ycb_meshes/0*/google_16k/textured_simplified.obj' 37 | elif self.opt.test_set == "egad": 38 | self.opt.graspit_grasp_dir = './data/test_data/egad_graspit_grasps/' 39 | self.opt.object_mesh_dir = './data/meshes/egad_val_set_meshes/*_simplified.ply' 40 | else: 41 | self.opt.save_folder += "multifin_gan/" 42 | if not self.opt.pregenerate_data: 43 | self.opt.n_threads_train = 0 44 | self.opt.n_threads_test = 0 45 | 46 | self.load_epoch() 47 | # In val mode we do not want to run the inference on any seed that was used while training or testing 48 | seed_list = list(range(1, 10000)) 49 | seed_list.remove(self.opt.manual_seed) 50 | self.opt.manual_seed = random.choice(seed_list) 51 | self.opt.no_discriminator = True 52 | self.opt.precomputed_rotations = True 53 | torch.manual_seed(self.opt.manual_seed) 54 | np.random.seed(self.opt.manual_seed) 55 | self.get_set_gpus() 56 | if self.opt.test_set == "ycb": 57 | self.opt.object_mesh_dir = './data/meshes/ycb_meshes/' 58 | self.opt.dataset_name = "ycb" 59 | elif self.opt.test_set == "egad": 60 | self.opt.object_mesh_dir = './data/meshes/egad_val_set_meshes/' 61 | self.opt.dataset_name = "egad" 62 | 63 | self.opt.is_train = False 64 | self.opt.load_network = True 65 | args = vars(self.opt) 66 | 67 | # print in terminal args 68 | self.print(args) 69 | 70 | # save args to file 71 | # self.save(args) 72 | 73 | return self.opt 74 | -------------------------------------------------------------------------------- /utils/contactutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def batch_mesh_contains_points( 5 | ray_origins, 6 | obj_triangles, 7 | direction=torch.Tensor([0.4395064455, 0.617598629942, 8 | 0.652231566745]), device="cpu" 9 | ): 10 | """Time efficient but memory greedy ! 
11 | Computes ALL ray/triangle intersections and then counts them to determine 12 | if point inside mesh 13 | 14 | Args: 15 | ray_origins: (batch_size x point_nb x 3) 16 | obj_triangles: (batch_size, triangle_nb, vertex_nb=3, vertex_coords=3) 17 | tol_thresh: To determine if ray and triangle are // 18 | Returns: 19 | exterior: (batch_size, point_nb) 1 if the point is outside mesh, 0 else 20 | """ 21 | direction = direction.to(device) 22 | tol_thresh = 0.00000000001 23 | # ray_origins.requires_grad = False 24 | # obj_triangles.requires_grad = False 25 | batch_size = obj_triangles.shape[0] 26 | triangle_nb = obj_triangles.shape[1] 27 | point_nb = ray_origins.shape[1] 28 | 29 | # Batch dim and triangle dim will flattened together 30 | batch_points_size = batch_size * triangle_nb 31 | # Direction is random but shared 32 | v0, v1, v2 = obj_triangles[:, :, 0], obj_triangles[:, :, 33 | 1], obj_triangles[:, :, 34 | 2] 35 | # Get edges 36 | v0v1 = v1 - v0 37 | v0v2 = v2 - v0 38 | 39 | # Expand needed vectors 40 | batch_direction = direction.view(1, 1, 3).expand(batch_size, triangle_nb, 41 | 3) 42 | 43 | # Compute ray/triangle intersections 44 | pvec = torch.cross(batch_direction, v0v2, dim=2) 45 | dets = torch.bmm(v0v1.view(batch_points_size, 1, 3), 46 | pvec.view(batch_points_size, 3, 47 | 1)).view(batch_size, triangle_nb) 48 | 49 | # Check if ray and triangle are parallel 50 | parallel = abs(dets) < tol_thresh 51 | invdet = 1 / (dets + 0.1 * tol_thresh) 52 | 53 | # Repeat mesh info as many times as there are rays 54 | triangle_nb = v0.shape[1] 55 | v0 = v0.repeat(1, point_nb, 1) 56 | v0v1 = v0v1.repeat(1, point_nb, 1) 57 | v0v2 = v0v2.repeat(1, point_nb, 1) 58 | hand_verts_repeated = (ray_origins.view(batch_size, point_nb, 1, 3).repeat( 59 | 1, 1, triangle_nb, 1).view(ray_origins.shape[0], 60 | triangle_nb * point_nb, 3)) 61 | pvec = pvec.repeat(1, point_nb, 1) 62 | invdet = invdet.repeat(1, point_nb) 63 | tvec = hand_verts_repeated - v0 64 | u_val = (torch.bmm( 65 | tvec.view(batch_size * tvec.shape[1], 1, 3), 66 | pvec.view(batch_size * tvec.shape[1], 3, 1), 67 | ).view(batch_size, tvec.shape[1]) * invdet) 68 | # Check ray intersects inside triangle 69 | u_correct = (u_val > 0) * (u_val < 1) 70 | qvec = torch.cross(tvec, v0v1, dim=2) 71 | 72 | batch_direction = batch_direction.repeat(1, point_nb, 1) 73 | v_val = (torch.bmm( 74 | batch_direction.view(batch_size * qvec.shape[1], 1, 3), 75 | qvec.view(batch_size * qvec.shape[1], 3, 1), 76 | ).view(batch_size, qvec.shape[1]) * invdet) 77 | v_correct = (v_val > 0) * (u_val + v_val < 1) 78 | t = (torch.bmm( 79 | v0v2.view(batch_size * qvec.shape[1], 1, 3), 80 | qvec.view(batch_size * qvec.shape[1], 3, 1), 81 | ).view(batch_size, qvec.shape[1]) * invdet) 82 | # Check triangle is in front of ray_origin along ray direction 83 | t_pos = t >= tol_thresh 84 | parallel = parallel.repeat(1, point_nb) 85 | # # Check that all intersection conditions are met 86 | try: 87 | not_parallel = 1 - parallel 88 | except: 89 | not_parallel = parallel == False 90 | final_inter = v_correct * u_correct * not_parallel * t_pos 91 | # Reshape batch point/vertices intersection matrix 92 | # final_intersections[batch_idx, point_idx, triangle_idx] == 1 means ray 93 | # intersects triangle 94 | final_intersections = final_inter.view(batch_size, point_nb, triangle_nb) 95 | # Check if intersection number accross mesh is odd to determine if point is 96 | # outside of mesh 97 | exterior = final_intersections.sum(2) % 2 == 0 98 | return exterior 99 | 
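# A minimal usage sketch (added for illustration; not part of the original file).
# batch_mesh_contains_points casts one shared ray from every query point and counts
# ray/triangle intersections, so an even count means the point lies outside the mesh.
# Shapes follow the docstring above; the random tensors below are placeholders only.
if __name__ == "__main__":
    batch_size, point_nb, triangle_nb = 2, 64, 128
    ray_origins = torch.rand(batch_size, point_nb, 3)           # query points, e.g. hand vertices
    obj_triangles = torch.rand(batch_size, triangle_nb, 3, 3)   # per-face vertex coordinates of the object meshes
    exterior = batch_mesh_contains_points(ray_origins, obj_triangles)
    penetrating = ~exterior                                     # odd intersection count -> point is inside the mesh
    print(penetrating.sum(dim=1))                               # number of penetrating points per batch element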
-------------------------------------------------------------------------------- /datasets/egad_synthetic_oneobject.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | import numpy as np 4 | 5 | from datasets.dataset import DatasetBase 6 | import trimesh 7 | import math 8 | 9 | 10 | class Dataset(DatasetBase): 11 | def __init__(self, opt, mode): 12 | super(Dataset, self).__init__(opt, mode) 13 | self.name = 'Dataset_egad_synthetic_one_object' 14 | self.debug = opt.debug 15 | self.object_mesh_dir = opt.object_mesh_dir 16 | self.grasp_dir = opt.grasp_dir 17 | self.should_pregenerate_data = opt.pregenerate_data 18 | self.num_viewpoints_per_object = opt.num_viewpoints_per_object 19 | self.setup_data() 20 | if self.should_pregenerate_data: 21 | self.pregenerate_data(self.num_viewpoints_per_object) 22 | 23 | def split_index(self, index): 24 | obj_id = math.floor(index/self.num_viewpoints_per_object) 25 | viewpoint_idx = index-self.num_viewpoints_per_object*obj_id 26 | return obj_id, viewpoint_idx 27 | 28 | def __getitem__(self, index): 29 | # Get object at random: 30 | 31 | id_obj, viewpoint_idx = self.split_index(index) 32 | 33 | if self.should_pregenerate_data: 34 | color, _, all_obj_verts, all_obj_verts_resampled800, all_obj_faces = self.pregenerated_data_per_object[ 35 | id_obj][viewpoint_idx] 36 | else: 37 | color, _, all_obj_verts, all_obj_verts_resampled800, all_obj_faces = self.generate_data( 38 | id_obj) 39 | 40 | # Normalize: 41 | img = color[0].transpose(1, 2, 0)/256 42 | img = img - self.means_rgb 43 | img = img / self.std_rgb 44 | # pack data 45 | sample = {'rgb_img': img, 46 | 'object_id': id_obj, 47 | '3d_points_object': all_obj_verts, 48 | '3d_faces_object': all_obj_faces, 49 | 'object_points_resampled': all_obj_verts_resampled800, 50 | 'taxonomy': 0, 51 | 'hand_gt_representation': 0, 52 | 'hand_gt_pose': 0, 53 | } 54 | 55 | return sample 56 | 57 | def __len__(self): 58 | return self.num_viewpoints_per_object*self.training_models 59 | 60 | def setup_data(self): 61 | 62 | models_original = glob.glob( 63 | self.object_mesh_dir + '*[0-9].ply') 64 | models_simplified = glob.glob( 65 | self.object_mesh_dir + '*_simplified.ply') 66 | models_original.sort() 67 | models_simplified.sort() 68 | self.models_original = np.array(models_original) 69 | self.models_simplified = [] 70 | for i in self.models_original: 71 | for j in models_simplified: 72 | if i.split("/")[-1].split(".")[0] in j: 73 | self.models_simplified.append(j) 74 | break 75 | 76 | self.models_simplified = np.asarray(self.models_simplified) 77 | self.training_models = self.models_original.shape[0] 78 | for i in range(self.training_models): 79 | obj_orig = trimesh.load(self.models_original[i]) 80 | obj_simp = trimesh.load(self.models_simplified[i]) 81 | object_center = obj_orig.centroid 82 | obj_orig.vertices -= object_center 83 | obj_simp.vertices -= object_center 84 | resampled = trimesh.sample.sample_surface_even(obj_orig, 800)[0] 85 | if resampled.shape[0] < 800: 86 | resampled = trimesh.sample.sample_surface(obj_orig, 800)[0] 87 | self.resampled_objects_800verts.append(resampled) 88 | 89 | self.all_object_faces.append(obj_orig.faces) 90 | self.all_object_vertices.append(obj_orig.vertices) 91 | # Get texture (it has to be face colors): 92 | visual = 255*np.ones((obj_orig.vertices.shape[0], 3)) 93 | colors = visual 94 | triangles = colors[obj_orig.faces] 95 | self.all_object_textures.append(np.uint8(triangles.mean(1))) 96 | 
self.all_object_vertices_simplified.append(obj_simp.vertices) 97 | self.all_object_faces_simplified.append(obj_simp.faces) 98 | 99 | self.resampled_objects_800verts = np.asarray( 100 | self.resampled_objects_800verts) 101 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import kaolin 3 | from kaolin.graphics.nmr.util import get_points_from_angles 4 | import torch 5 | import numpy as np 6 | import pyquaternion 7 | 8 | 9 | class DatasetFactory: 10 | def __init__(self): 11 | pass 12 | 13 | @staticmethod 14 | def get_by_name(dataset_name, opt, mode): 15 | if dataset_name == 'ycb': 16 | from datasets.ycb_synthetic_oneobject import Dataset 17 | dataset = Dataset(opt, mode) 18 | elif dataset_name == 'egad': 19 | from datasets.egad_synthetic_oneobject import Dataset 20 | dataset = Dataset(opt, mode) 21 | else: 22 | raise ValueError("Dataset [%s] not recognized." % dataset_name) 23 | 24 | print('Dataset {} was created'.format(dataset.name)) 25 | return dataset 26 | 27 | 28 | class DatasetBase(data.Dataset): 29 | def __init__(self, opt, mode): 30 | super(DatasetBase, self).__init__() 31 | self.name = 'BaseDataset' 32 | self.opt = opt 33 | self.mode = mode 34 | self.setup_camera() 35 | 36 | self.all_object_vertices = [] 37 | self.all_object_faces = [] 38 | self.all_object_textures = [] 39 | self.all_object_vertices_simplified = [] 40 | self.resampled_objects_800verts = [] 41 | self.all_object_faces_simplified = [] 42 | self.all_grasp_translations = [] 43 | self.all_grasp_rotations = [] 44 | self.all_grasp_hand_configurations = [] 45 | self.all_grasp_taxonomies = [] 46 | 47 | # Resnet normalization values 48 | self.means_rgb = [0.485, 0.456, 0.406] 49 | self.std_rgb = [0.229, 0.224, 0.225] 50 | 51 | self.IMG_EXTENSIONS = [ 52 | '.jpg', '.JPG', '.jpeg', '.JPEG', 53 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 54 | ] 55 | 56 | def get_name(self): 57 | return self.name 58 | 59 | def setup_camera(self): 60 | self.renderer = kaolin.graphics.NeuralMeshRenderer( 61 | camera_mode='look_at') 62 | self.renderer.light_intensity_directional = 0.0 63 | self.renderer.light_intensity_ambient = 1.0 64 | camera_distance = 0.4 65 | elevation = 0.0 66 | azimuth = 0.0 67 | self.renderer.eye = get_points_from_angles( 68 | camera_distance, elevation, azimuth) 69 | 70 | def generate_data(self, id_obj): 71 | random_rot_np = pyquaternion.Quaternion().random().rotation_matrix 72 | random_rot = torch.FloatTensor(random_rot_np) 73 | rotated_verts = torch.matmul( 74 | random_rot, torch.FloatTensor(self.all_object_vertices[id_obj]).T).T 75 | rotated_verts = rotated_verts.unsqueeze(0).cuda() 76 | faces = torch.LongTensor( 77 | self.all_object_faces[id_obj]).unsqueeze(0).cuda() 78 | textures = torch.FloatTensor(self.all_object_textures[id_obj]).cuda() 79 | 80 | # Reshape as needed for renderer: 81 | textures = textures.reshape( 82 | 1, len(textures), 1, 1, 1, 3).repeat(1, 1, 2, 2, 2, 1) 83 | 84 | # Rendering with CUDA: 85 | color, _, _ = self.renderer(rotated_verts, faces, textures=textures) 86 | all_object_vertices = [ 87 | np.matmul(random_rot_np, self.all_object_vertices_simplified[id_obj].T).T] 88 | all_obj_verts_resampled800 = [ 89 | np.matmul(random_rot_np, self.resampled_objects_800verts[id_obj].T).T] 90 | all_obj_faces = [self.all_object_faces_simplified[id_obj]] 91 | 92 | return color.cpu().data.numpy(), random_rot_np, all_object_vertices, 
all_obj_verts_resampled800, all_obj_faces 93 | 94 | def pregenerate_data(self, num_viewpoints_per_object): 95 | self.pregenerated_data_per_object = [] 96 | for id_obj in range(self.training_models): 97 | current_object_data = [] 98 | for _ in range(num_viewpoints_per_object): 99 | color, random_rot, all_object_vertices, all_obj_verts_resampled800, all_obj_faces = self.generate_data( 100 | id_obj) 101 | current_object_data.append( 102 | (color, random_rot, all_object_vertices, all_obj_verts_resampled800, all_obj_faces)) 103 | self.pregenerated_data_per_object.append(current_object_data) 104 | 105 | def collate_fn(self, args): 106 | length = len(args) 107 | keys = list(args[0].keys()) 108 | data = {} 109 | 110 | for _, key in enumerate(keys): 111 | data_type = [] 112 | 113 | if key == 'rgb_img' or key == 'mask_img' or key == 'noise_img' or key == 'plane_eq' or key == 'hand_gt_representation' or key == 'hand_gt_pose': 114 | for j in range(length): 115 | data_type.append(torch.FloatTensor(args[j][key])) 116 | data_type = torch.stack(data_type) 117 | elif key == 'label' or key == 'taxonomy': 118 | labels = [] 119 | for j in range(length): 120 | labels.append(args[j][key]) 121 | data_type = torch.LongTensor(labels) 122 | else: 123 | for j in range(length): 124 | data_type.append(args[j][key]) 125 | data[key] = data_type 126 | return data 127 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.optim import lr_scheduler 4 | 5 | 6 | class ModelsFactory: 7 | def __init__(self): 8 | pass 9 | 10 | @staticmethod 11 | def get_by_name(*args, **kwargs): 12 | model = None 13 | 14 | from .multi_fingan import Model 15 | model = Model(*args, **kwargs) 16 | 17 | print("Model %s was created" % model.name) 18 | return model 19 | 20 | 21 | class BaseModel(object): 22 | def __init__(self, opt): 23 | self.name = 'BaseModel' 24 | 25 | self.opt = opt 26 | self.gpu_ids = opt.gpu_ids 27 | self.is_train = opt.is_train 28 | 29 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 30 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 31 | 32 | def get_name(self): 33 | return self.name 34 | 35 | def is_train(self): 36 | return self.is_train 37 | 38 | def set_input(self, input): 39 | assert False, "set_input not implemented" 40 | 41 | def set_train(self): 42 | assert False, "set_train not implemented" 43 | 44 | def set_eval(self): 45 | assert False, "set_eval not implemented" 46 | 47 | def forward(self, keep_data_for_visuals=False): 48 | assert False, "forward not implemented" 49 | 50 | # used in test time, no backprop 51 | def test(self): 52 | assert False, "test not implemented" 53 | 54 | def get_image_paths(self): 55 | return {} 56 | 57 | def optimize_parameters(self): 58 | assert False, "optimize_parameters not implemented" 59 | 60 | def get_current_visuals(self): 61 | return {} 62 | 63 | def get_current_errors(self): 64 | return {} 65 | 66 | def get_current_scalars(self): 67 | return {} 68 | 69 | def save(self, label): 70 | assert False, "save not implemented" 71 | 72 | def load(self): 73 | assert False, "load not implemented" 74 | 75 | def save_optimizer(self, optimizer, optimizer_label, label): 76 | save_filename = 'opt_epoch_%s_id_%s.pth' % (label, 77 | optimizer_label) 78 | save_path = os.path.join(self.save_dir, save_filename) 79 | torch.save( 80 | { 81 | 'optimizer_state_dict': optimizer.state_dict(), 82 | 
'learning_rate_G': self.current_lr_image_encoder_and_grasp_predictor, 83 | }, save_path) 84 | 85 | def load_optimizer(self, optimizer, optimizer_label, label, device): 86 | load_filename = 'opt_epoch_%s_id_%s.pth' % (label, 87 | optimizer_label) 88 | load_path = os.path.join(self.save_dir, load_filename) 89 | assert os.path.exists( 90 | load_path 91 | ), 'Weights file not found. Have you trained a model!? We are not providing one' % load_path 92 | 93 | checkpoint = torch.load(load_path, map_location=device) 94 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 95 | self.current_lr_image_encoder_and_grasp_predictor = checkpoint['learning_rate_G'] 96 | print('loaded optimizer: %s' % load_path) 97 | 98 | def save_network(self, network, network_label, label, epoch): 99 | save_filename = 'net_epoch_%s_id_%s.pth' % (label, network_label) 100 | save_path = os.path.join(self.save_dir, save_filename) 101 | torch.save( 102 | { 103 | 'epoch': epoch, 104 | 'model_state_dict': network.state_dict(), 105 | }, save_path) 106 | 107 | #print('saved net: %s' % save_path) 108 | 109 | def load_network(self, network, network_label, label, device): 110 | load_filename = 'net_epoch_%s_id_%s.pth' % (label, network_label) 111 | load_path = os.path.join(self.save_dir, load_filename) 112 | assert os.path.exists( 113 | load_path 114 | ), 'Weights file not found. Have you trained a model!? We are not providing one' % load_path 115 | checkpoint = torch.load(load_path, map_location=device) 116 | network.load_state_dict(checkpoint['model_state_dict']) 117 | self.set_epoch(checkpoint["epoch"]) 118 | print('loaded net: %s' % load_path) 119 | 120 | def update_learning_rate(self): 121 | pass 122 | 123 | def print_network(self, network): 124 | num_params = 0 125 | for param in network.parameters(): 126 | num_params += param.numel() 127 | print(network) 128 | print('Total number of parameters: %d' % num_params) 129 | 130 | def get_scheduler(self, optimizer, opt): 131 | if opt.lr_policy == 'lambda': 132 | 133 | def lambda_rule(epoch): 134 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - 135 | opt.niter) / float(opt.niter_decay + 1) 136 | return lr_l 137 | 138 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 139 | elif opt.lr_policy == 'step': 140 | scheduler = lr_scheduler.StepLR(optimizer, 141 | step_size=opt.lr_decay_iters, 142 | gamma=0.1) 143 | elif opt.lr_policy == 'plateau': 144 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 145 | mode='min', 146 | factor=0.2, 147 | threshold=0.01, 148 | patience=5) 149 | else: 150 | return NotImplementedError( 151 | 'learning rate policy [%s] is not implemented', opt.lr_policy) 152 | return scheduler 153 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import numpy as np 3 | import torch 4 | import random 5 | from .base_options import BaseOptions 6 | 7 | 8 | class TrainOptions(BaseOptions): 9 | def initialize(self): 10 | BaseOptions.initialize(self) 11 | self.parser.add_argument( 12 | '--n_threads_train', default=4, type=int, help='# threads for loading data') 13 | self.parser.add_argument('--num_iters_validate', default=10, 14 | type=int, help='# batches to use when validating') 15 | self.parser.add_argument('--print_freq_s', type=int, default=60, 16 | help='frequency of showing training results on console') 17 | self.parser.add_argument('--display_freq_s', 
type=int, default=300, 18 | help='frequency [s] of showing training results on screen') 19 | self.parser.add_argument('--display_freq_s_val', type=int, default=300, 20 | help='frequency [s] of showing validation results on screen') 21 | self.parser.add_argument('--save_latest_freq_s', type=int, default=3600, 22 | help='frequency of saving the latest results') 23 | 24 | self.parser.add_argument('--nepochs_no_decay', type=int, default=400, 25 | help='# of epochs at starting learning rate') 26 | self.parser.add_argument('--nepochs_decay', type=int, default=400, 27 | help='# of epochs to linearly decay learning rate to zero') 28 | 29 | self.parser.add_argument('--train_G_every_n_iterations', type=int, 30 | default=5, help='train G every n iterations') 31 | self.parser.add_argument( 32 | '--optimizer', default="Adam", choices=["Adam", "SGD"], type=str) 33 | 34 | self.parser.add_argument( 35 | '--poses_g_sigma', type=float, default=0.06, help='initial learning rate for adam') 36 | self.parser.add_argument( 37 | '--lr_G', type=float, default=0.0001, help='initial learning rate for G adam') 38 | self.parser.add_argument( 39 | '--G_adam_b1', type=float, default=0.5, help='beta1 for G adam') 40 | self.parser.add_argument( 41 | '--G_adam_b2', type=float, default=0.999, help='beta2 for G adam') 42 | self.parser.add_argument( 43 | '--lr_D', type=float, default=0.0001, help='initial learning rate for D adam') 44 | self.parser.add_argument( 45 | '--D_adam_b1', type=float, default=0.5, help='beta1 for D adam') 46 | self.parser.add_argument( 47 | '--D_adam_b2', type=float, default=0.999, help='beta2 for D adam') 48 | self.parser.add_argument('--lambda_D_prob', type=float, default=1, 49 | help='lambda for real/fake discriminator loss') 50 | self.parser.add_argument( 51 | '--lambda_D_gp', type=float, default=10, help='lambda gradient penalty loss') 52 | 53 | self.parser.add_argument( 54 | '--lambda_G_classification', type=float, default=1.0, help='') 55 | self.parser.add_argument( 56 | '--lambda_G_contactloss', type=float, default=100.0, help='') 57 | self.parser.add_argument( 58 | '--lambda_G_intersections', type=float, default=100.0, help='') 59 | self.parser.add_argument('--no_discriminator', action='store_true', 60 | help='if true, do not train the discriminator') 61 | self.parser.add_argument('--no_classification_loss', action='store_true', 62 | help='if true, do not train with classification loss') 63 | self.parser.add_argument('--no_intersection_loss', action='store_true', 64 | help='if true, do not train with intersection loss') 65 | self.parser.add_argument('--no_contact_loss', action='store_true', 66 | help='if true, do not train with contact loss') 67 | self.parser.add_argument('--no_orientation_loss', action='store_true', 68 | help='if true, do not train with orientation loss') 69 | self.parser.add_argument( 70 | '--lambda_G_orientation', type=float, default=1.0, help='') 71 | self.parser.add_argument( 72 | '--continue_train', action='store_true', help='if true then we continue to train') 73 | self.parser.add_argument('--ablation_study', action='store_true', 74 | help='if true then we train for ablation study') 75 | opts, _ = self.parser.parse_known_args() 76 | if opts.continue_train: 77 | self.parser.add_argument( 78 | '--checkpoint_dir_load', type=str, help='path to checkpoint we want to continue training from', required="true") 79 | self.parser.add_argument('--load_epoch', type=str, default='latest', 80 | help='which epoch to load? 
') 81 | self.is_train = True 82 | 83 | def parse(self): 84 | if not self.initialized: 85 | self.initialize() 86 | self.opt = EasyDict(vars(self.parser.parse_args())) 87 | if not self.opt.pregenerate_data: 88 | self.opt.n_threads_train = 0 89 | self.opt.n_threads_test = 0 90 | 91 | # set is train or set 92 | self.opt.is_train = True 93 | if self.opt.manual_seed is None: 94 | self.opt.manual_seed = random.randint(1, 10000) 95 | 96 | # set and check load_epoch 97 | if self.opt.continue_train: 98 | self.load_epoch() 99 | self.opt.load_network = True 100 | else: 101 | self.set_epoch() 102 | self.opt.load_network = False 103 | torch.manual_seed(self.opt.manual_seed) 104 | np.random.seed(self.opt.manual_seed) 105 | # get and set gpus 106 | if self.opt.ablation_study: 107 | self.opt.checkpoints_dir = self.opt.checkpoints_dir+"/ablation_study/" 108 | 109 | self.opt.dataset_name = "ycb" 110 | self.get_set_gpus() 111 | 112 | args = vars(self.opt) 113 | 114 | # print in terminal args 115 | self.print(args) 116 | 117 | # save args to file 118 | self.save(args) 119 | 120 | return self.opt 121 | -------------------------------------------------------------------------------- /datasets/ycb_synthetic_oneobject.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | import numpy as np 4 | import pickle 5 | 6 | from datasets.dataset import DatasetBase 7 | from utils import util 8 | import pyquaternion 9 | import trimesh 10 | import math 11 | 12 | 13 | class Dataset(DatasetBase): 14 | def __init__(self, opt, mode): 15 | super(Dataset, self).__init__(opt, mode) 16 | self.name = 'Dataset_ycb_synthetic_one_object' 17 | self.debug = opt.debug 18 | self.object_mesh_dir = opt.object_mesh_dir 19 | self.grasp_dir = opt.grasp_dir 20 | self.should_pregenerate_data = opt.pregenerate_data 21 | # All models in the ycb dataset we train on are: 22 | # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 17, 18, 20, 21, 30, 31, 32, 34, 40, 41, 49] 23 | # Test set is now the same as the val set 24 | self.test_models = np.asarray([3, 7, 40]) 25 | self.num_viewpoints_per_object = opt.num_viewpoints_per_object 26 | # read dataset 27 | self.setup_data() 28 | 29 | # Either we pre-generate data once and reuse these all the time which significantly speeds 30 | # up data loading or we generate new data for each batch 31 | if self.should_pregenerate_data: 32 | self.pregenerate_data(self.num_viewpoints_per_object) 33 | 34 | def split_index(self, index): 35 | obj_id = math.floor(index/self.num_viewpoints_per_object) 36 | viewpoint_idx = index-self.num_viewpoints_per_object*obj_id 37 | return obj_id, viewpoint_idx 38 | 39 | def __getitem__(self, index): 40 | # Get object at random: 41 | 42 | id_obj, viewpoint_idx = self.split_index(index) 43 | 44 | if self.should_pregenerate_data: 45 | color, random_rot, all_object_vertices, all_obj_verts_resampled800, all_obj_faces = self.pregenerated_data_per_object[ 46 | id_obj][viewpoint_idx] 47 | else: 48 | color, random_rot, all_object_vertices, all_obj_verts_resampled800, all_obj_faces = self.generate_data( 49 | id_obj) 50 | 51 | if self.mode == "val": 52 | # Val models are both train and test models but from new viewpoints 53 | all_grasp_taxonomies = [] 54 | grasp_repr = [] 55 | grasp_pose = [] 56 | else: 57 | 58 | # Get random grasp: 59 | id_grasp = np.random.randint( 60 | 0, len(self.all_grasp_rotations[id_obj])) 61 | grasp_rot = self.all_grasp_rotations[id_obj][id_grasp] 62 | grasp_trans = self.all_grasp_translations[id_obj][id_grasp] 63 | grasp_dof = 
self.all_grasp_hand_configurations[id_obj][id_grasp] 64 | # Taxonomy ground truth: 65 | all_grasp_taxonomies = self.all_grasp_taxonomies[id_obj] 66 | grasp_rot = pyquaternion.Quaternion( 67 | np.array(grasp_rot)[[3, 0, 1, 2]]).rotation_matrix 68 | grasp_pose = np.eye(4) 69 | grasp_pose[:3, :3] = grasp_rot 70 | grasp_pose[:3, 3] = grasp_trans 71 | 72 | grasp_to_object_transformation = np.eye(4) 73 | grasp_to_object_transformation[:3, :3] = random_rot 74 | grasp_pose = np.matmul(grasp_to_object_transformation, grasp_pose) 75 | grasp_repr = util.joints_to_grasp_representation(grasp_dof) 76 | 77 | # Normalize: 78 | img = color[0].transpose(1, 2, 0)/256 79 | img = img - self.means_rgb 80 | img = img / self.std_rgb 81 | 82 | # pack data 83 | sample = {'rgb_img': img, 84 | 'object_id': id_obj, 85 | 'taxonomy': all_grasp_taxonomies, 86 | 'hand_gt_representation': grasp_repr, 87 | 'hand_gt_pose': grasp_pose, 88 | '3d_points_object': all_object_vertices, 89 | '3d_faces_object': all_obj_faces, 90 | 'object_points_resampled': all_obj_verts_resampled800, 91 | } 92 | 93 | return sample 94 | 95 | def __len__(self): 96 | return self.num_viewpoints_per_object*self.training_models 97 | 98 | def setup_data(self): 99 | 100 | models_original = glob.glob( 101 | self.object_mesh_dir + '/0*/google_16k/textured.obj') 102 | models_simplified = glob.glob( 103 | self.object_mesh_dir + '/0*/google_16k/textured_simplified.obj') 104 | 105 | models_original.sort() 106 | models_simplified.sort() 107 | objects_in_YCB = np.load('./data/objects_in_YCB.npy') 108 | if self.mode == "train": 109 | objects_in_YCB = np.setdiff1d(objects_in_YCB, self.test_models) 110 | elif self.mode == "test": 111 | objects_in_YCB = self.test_models 112 | elif self.mode == "val": 113 | # Val models are both train and test models but from new viewpoints 114 | objects_in_YCB = objects_in_YCB 115 | 116 | self.models_original = np.array(models_original)[objects_in_YCB] 117 | self.models_simplified = [] 118 | for i in self.models_original: 119 | for j in models_simplified: 120 | if i.split("/")[-3] in j: 121 | self.models_simplified.append(j) 122 | break 123 | self.models_simplified = np.asarray(self.models_simplified) 124 | self.training_models = self.models_original.shape[0] 125 | for i in range(self.training_models): 126 | # Kaolin doesn't load textures from obj 127 | # so using Trimesh 128 | obj_orig = trimesh.load(self.models_original[i]) 129 | obj_simp = trimesh.load(self.models_simplified[i]) 130 | object_center = obj_orig.centroid 131 | obj_orig.vertices -= object_center 132 | obj_simp.vertices -= object_center 133 | resampled = trimesh.sample.sample_surface_even(obj_orig, 800)[0] 134 | if resampled.shape[0] < 800: 135 | resampled = trimesh.sample.sample_surface(obj_orig, 800)[0] 136 | self.resampled_objects_800verts.append(resampled) 137 | 138 | self.all_object_faces.append(obj_orig.faces) 139 | self.all_object_vertices.append(obj_orig.vertices) 140 | # Get texture (it has to be face colors): 141 | visual = obj_orig.visual.to_color() 142 | colors = visual.vertex_colors[:, :3] 143 | triangles = colors[obj_orig.faces] 144 | self.all_object_textures.append(np.uint8(triangles.mean(1))) 145 | self.all_object_vertices_simplified.append(obj_simp.vertices) 146 | self.all_object_faces_simplified.append(obj_simp.faces) 147 | 148 | # Get all grasps for this object: 149 | object_grasp_translations = [] 150 | object_grasp_rotations = [] 151 | object_grasp_hand_configurations = [] 152 | grasp_files = glob.glob( 153 | self.grasp_dir + "obj_%d_*" % 
objects_in_YCB[i]) 154 | grasp_files.sort() 155 | available_taxonomy = np.zeros(7) 156 | for file in grasp_files: 157 | data = pickle.load(open(file, 'rb'), encoding='latin') 158 | object_grasp_translations.append( 159 | data['pose'][:3] - object_center) 160 | object_grasp_rotations.append(data['pose'][3:]) 161 | object_grasp_hand_configurations.append(data['joints']) 162 | available_taxonomy[data['taxonomy'] - 1] += 1 163 | 164 | self.all_grasp_translations.append(object_grasp_translations) 165 | self.all_grasp_rotations.append(object_grasp_rotations) 166 | self.all_grasp_hand_configurations.append( 167 | object_grasp_hand_configurations) 168 | self.all_grasp_taxonomies.append((available_taxonomy > 0)*1) 169 | 170 | self.resampled_objects_800verts = np.asarray( 171 | self.resampled_objects_800verts) 172 | -------------------------------------------------------------------------------- /utils/forward_kinematics_barrett.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import util 3 | 4 | 5 | def get_relevant_vertices(): 6 | finger_vertices = [7, 37, 21, 51, 71, 34, 61, 86, 46, 22, 53, 40, 69, 52, 38, 13, 29, 19, 8, 81, 60, 72, 97, 36, 7 | 28, 45, 49, 66, 3, 70, 65, 68, 25, 95, 48, 57, 98, 27, 79, 15, 80, 89, 62, 100, 91, 5, 44, 63] 8 | finger_tip_vertices = [35, 98, 44, 90, 95, 77, 84, 73, 76, 25, 108, 64, 22, 24, 96, 23, 85, 79, 83, 30, 45, 47, 68, 54, 42, 69, 92, 86, 9 | 19, 7, 94, 37, 99, 91, 11, 107, 0, 89, 57, 59, 109, 4, 65, 31, 2, 1, 10, 101, 52, 97, 87, 50, 72, 15, 106, 82, 12, 56, 78, 32, 46, 8] 10 | return finger_vertices, finger_tip_vertices 11 | 12 | 13 | def optimize_fingers(handfullpose, R, T, obj_verts, barrett_layer, object_finger_threshold, optimize_finger_tip=False, step=25, device="cpu"): 14 | relevant_finger_vertices, relevant_finger_tip_vertices = get_relevant_vertices() 15 | handfullpose_converged = handfullpose.clone().to(device) 16 | touching_indexes = 0 17 | 18 | num_samples = 1000//step + 1 19 | pose = torch.zeros((1, 4, 4)).to(device) 20 | pose[0, :3, :3] = R 21 | pose[0, :3, 3] = T 22 | pose = pose.repeat(num_samples, 1, 1) 23 | 24 | inds = torch.linspace(0, util.deg2rad(140), num_samples).to(device) 25 | 26 | def solve_for_rotation(idx): 27 | _, _, all_finger_vertices, all_finger_tip_vertices, _ = barrett_layer(pose, handfullpose_repeated) 28 | # all_relevant_finger_vertices = all_finger_vertices # [:, :, vertices["finger"][0]] 29 | # all_relevant_finger_tip_vertices = all_finger_tip_vertices # [:, :, vertices["finger_tip"][0]] 30 | all_relevant_finger_vertices = all_finger_vertices[:, :, relevant_finger_vertices] 31 | all_relevant_finger_tip_vertices = all_finger_tip_vertices[:, :, relevant_finger_tip_vertices] 32 | # We have three fingers on the barrett hand 33 | touching_indexes = 0 34 | for i in range(3): 35 | current_finger_vertices = all_relevant_finger_vertices[:, i] 36 | current_finger_tip_vertices = all_relevant_finger_tip_vertices[:, i] 37 | current_concat_finger_vertices = torch.cat((current_finger_vertices, current_finger_tip_vertices), dim=1).squeeze() 38 | vertex_solution, converged = get_optimization_angle( 39 | current_concat_finger_vertices, obj_verts, object_finger_threshold, device) 40 | if converged: 41 | touching_indexes += 1 42 | delta_angle = inds[vertex_solution]-handfullpose[0, i+idx] # - inds[vertex_solution] 43 | handfullpose[0, i+idx] = handfullpose[0, i+idx] + delta_angle 44 | return touching_indexes 45 | handfullpose_repeated = 
handfullpose_converged.repeat(num_samples, 1) 46 | handfullpose_repeated[:, 1] = inds 47 | handfullpose_repeated[:, 2] = inds 48 | handfullpose_repeated[:, 3] = inds 49 | touching_indexes = solve_for_rotation(1) 50 | if optimize_finger_tip: 51 | inds = torch.linspace(0, util.deg2rad(40), num_samples).to(device) 52 | handfullpose_repeated = handfullpose.clone().to(device).repeat(num_samples, 1) 53 | handfullpose_repeated[:, 4] = inds 54 | handfullpose_repeated[:, 5] = inds 55 | handfullpose_repeated[:, 6] = inds 56 | touching_indexes += solve_for_rotation(4) 57 | return touching_indexes 58 | 59 | 60 | def optimize_fingers_batched(handfullpose, R, T, obj_verts, barrett_layer, object_finger_threshold, optimize_finger_tip=False, step=25, device="cpu"): 61 | if type(obj_verts) is not torch.Tensor: 62 | obj_verts = torch.FloatTensor(obj_verts).to(device) 63 | 64 | relevant_finger_vertices, relevant_finger_tip_vertices = get_relevant_vertices() 65 | 66 | handfullpose_converged = handfullpose.clone().to(device) 67 | touching_indexes = 0 68 | batch_size = handfullpose.shape[0] 69 | 70 | num_samples = 1000//step + 1 71 | inds = torch.linspace(0, util.deg2rad(140), num_samples).to(device).repeat(batch_size) 72 | 73 | # Here we need to create, for each barrett hand in the batch, a new barrett hand with joints 74 | # set from fully open to fully closed 75 | pose = torch.zeros((inds.shape[0], 4, 4)).to(device) 76 | for i in range(batch_size): 77 | pose[i*num_samples:(i+1)*num_samples, :3, :3] = R[i] 78 | pose[i*num_samples:(i+1)*num_samples, :3, 3] = T[i] 79 | 80 | def solve_for_rotation(idx): 81 | _, _, all_finger_vertices, all_finger_tip_vertices, _ = barrett_layer(pose, handfullpose_repeated) 82 | all_relevant_finger_vertices = all_finger_vertices[:, :, relevant_finger_vertices] 83 | all_relevant_finger_tip_vertices = all_finger_tip_vertices[:, :, relevant_finger_tip_vertices] 84 | # We have three fingers on the barrett hand 85 | for i in range(3): 86 | current_finger_vertices = all_relevant_finger_vertices[:, i] 87 | current_finger_tip_vertices = all_relevant_finger_tip_vertices[:, i] 88 | current_concat_finger_vertices = torch.cat((current_finger_vertices, current_finger_tip_vertices), dim=1) 89 | 90 | current_concat_finger_vertices = current_concat_finger_vertices.reshape( 91 | batch_size, current_concat_finger_vertices.shape[0]//batch_size, current_concat_finger_vertices.shape[-2], 3) 92 | vertex_solution, _ = get_optimization_angle_batched( 93 | current_concat_finger_vertices, obj_verts, object_finger_threshold, device) 94 | handfullpose[:, i+idx] = handfullpose[:, i+idx] + inds[vertex_solution]-handfullpose[:, i+idx] # delta_angle 95 | 96 | handfullpose_repeated = handfullpose_converged.repeat_interleave(num_samples, dim=0) 97 | handfullpose_repeated[:, 1] = inds 98 | handfullpose_repeated[:, 2] = inds 99 | handfullpose_repeated[:, 3] = inds 100 | solve_for_rotation(1) 101 | if optimize_finger_tip: 102 | # TODO: Here the GPUs are synchronizing and the execution time is way slower than not optimizing the fingers 103 | inds = torch.linspace(0, util.deg2rad(40), num_samples, device=device).repeat(batch_size) # .to(device) 104 | handfullpose_repeated = handfullpose.clone().repeat(1, num_samples).view(-1, 7) # repeat_interleave(num_samples, dim=0) 105 | handfullpose_repeated[:, 4] = inds 106 | handfullpose_repeated[:, 5] = inds 107 | handfullpose_repeated[:, 6] = inds 108 | 109 | return touching_indexes 110 | 111 | 112 | def get_optimization_angle_batched(arc_points, obj_verts, 
object_finger_threshold, device): 113 | arc_points = arc_points.unsqueeze(1) 114 | obj_verts = obj_verts.unsqueeze(1).unsqueeze(1) 115 | dists = obj_verts - arc_points 116 | eu_dists = (dists**2).sum(-1) 117 | eu_dists = torch.sqrt(eu_dists.min(1)[0]) 118 | solutions = eu_dists < object_finger_threshold 119 | earliest_in_arc = solutions.max(1)[1] 120 | zero = (earliest_in_arc == 0) 121 | vertex_solution = (999*zero+earliest_in_arc).min(-1)[0] 122 | vertex_solution = vertex_solution*(vertex_solution != 999) 123 | converged = solutions.view(solutions.shape[0], -1).max(-1)[0] == 1 124 | return vertex_solution, converged 125 | 126 | 127 | def get_optimization_angle(arc_points, obj_verts, object_finger_threshold, device): 128 | if type(obj_verts) is not torch.Tensor: 129 | obj_verts = torch.FloatTensor(obj_verts).to(device) 130 | 131 | obj_verts = obj_verts.unsqueeze(1).unsqueeze(1) 132 | arc_points = arc_points.unsqueeze(0) 133 | dists = obj_verts - arc_points 134 | eu_dists = torch.sqrt((dists**2).sum(-1)) 135 | threshold = torch.FloatTensor([object_finger_threshold]).to(device) 136 | 137 | eu_dists = eu_dists.min(0)[0] 138 | solutions = eu_dists < threshold 139 | 140 | earliest_in_arc = solutions.cpu().data.numpy().argmax(0) 141 | earliest_in_arc[earliest_in_arc == 0] = 999 142 | 143 | vertex_solution = earliest_in_arc.min() 144 | if vertex_solution == 999: 145 | vertex_solution = 0 146 | 147 | converged = solutions.max().cpu().data.numpy() == 1 148 | return vertex_solution, converged 149 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from datasets.custom_dataset_data_loader import CustomDatasetDataLoader 4 | from models.models import ModelsFactory 5 | from utils.tb_visualizer import TBVisualizer 6 | from utils import util 7 | import torch 8 | import os 9 | import subprocess 10 | 11 | 12 | class Train: 13 | def __init__(self): 14 | self.opt = TrainOptions().parse() 15 | self.model = ModelsFactory.get_by_name(self.opt) 16 | self.tb_visualizer = TBVisualizer(self.opt) 17 | if self.get_training_data(): 18 | self.setup_train_test_sets() 19 | self.train() 20 | 21 | def get_training_data(self): 22 | if not os.path.isdir("data/train_data/") or len(os.listdir("data/train_data/")) == 0 or not os.path.isdir("data/meshes/") or len(os.listdir("data/meshes/")) == 1: 23 | while True: 24 | download_data = input("Training data does not exist. Want to download it (y/n)? ") 25 | if download_data == "y": 26 | dir = os.path.dirname(os.path.realpath(__file__)) + "/data/" 27 | print("Downloading training data. This will take some time so just sit back and relax.") 28 | subprocess.Popen([dir+'download_train_data.sh %s' % dir], shell=True).wait() 29 | print("Done downloading training data. Continuing with training.") 30 | return True 31 | elif download_data == "n": 32 | print("You chose to not download the data. Terminating training") 33 | return False 34 | else: 35 | print("Training data exists. 
Proceeding with training.") 36 | return True 37 | 38 | def setup_train_test_sets(self): 39 | data_loader_train = CustomDatasetDataLoader(self.opt, mode='train') 40 | self.dataset_train = data_loader_train.load_data() 41 | self.dataset_train_size = len(data_loader_train) 42 | print('#train images = %d' % self.dataset_train_size) 43 | data_loader_val = CustomDatasetDataLoader(self.opt, mode='test') 44 | self.dataset_val = data_loader_val.load_data() 45 | self.dataset_val_size = len(data_loader_val) 46 | print('#val images = %d' % self.dataset_val_size) 47 | 48 | def train(self): 49 | # Here we set the start epoch. It is nonzero only if we continue train or test as the epoch saved for the network 50 | # we load is used 51 | start_epoch = self.model.get_epoch() 52 | self.total_steps = start_epoch * self.dataset_train_size 53 | self.iters_per_epoch = self.dataset_train_size / self.opt.batch_size 54 | self.last_display_time = None 55 | self.last_display_time_val = None 56 | self.last_save_latest_time = None 57 | self.last_print_time = time.time() 58 | self.visuals_per_batch = self.iters_per_epoch//2 59 | for i_epoch in range(start_epoch, self.opt.nepochs_no_decay + self.opt.nepochs_decay + 1): 60 | epoch_start_time = time.time() 61 | 62 | # train epoch 63 | self.train_epoch(i_epoch) 64 | # print epoch info 65 | self.print_epoch_info(time.time() - epoch_start_time, i_epoch) 66 | # update learning rate 67 | self.update_learning_rate(i_epoch) 68 | # save model 69 | self.model.save("latest", i_epoch+1) 70 | self.display_visualizer_train(i_epoch) 71 | if (i_epoch) % 5 == 0: # Only test the network every fifth epoch 72 | self.test(i_epoch, self.total_steps) 73 | 74 | def print_epoch_info(self, time_epoch, epoch_num): 75 | print('End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' % 76 | (epoch_num, self.opt.nepochs_no_decay + self.opt.nepochs_decay, time_epoch, 77 | time_epoch / 60, time_epoch / 3600)) 78 | 79 | def update_learning_rate(self, epoch_num): 80 | if epoch_num > self.opt.nepochs_no_decay: 81 | self.model.update_learning_rate() 82 | 83 | def train_epoch(self, i_epoch): 84 | self.model.set_train() 85 | self.epoch_losses_G = [] 86 | self.epoch_losses_D = [] 87 | self.epoch_scalars = [] 88 | self.epoch_visuals = [] 89 | for i_train_batch, train_batch in enumerate(self.dataset_train): 90 | iter_start_time = time.time() 91 | 92 | self.model.set_input(train_batch) 93 | train_generator = self.train_generator(i_train_batch) 94 | 95 | self.model.optimize_parameters(train_generator=train_generator) 96 | 97 | self.total_steps += self.opt.batch_size 98 | 99 | self.bookkeep_epoch_data(train_generator) 100 | 101 | if ((i_train_batch+1) % self.visuals_per_batch == 0): 102 | self.bookkeep_epoch_visualizations() 103 | self.display_terminal( 104 | iter_start_time, i_epoch, i_train_batch, True) 105 | 106 | def train_generator(self, batch_num): 107 | return ((batch_num+1) % self.opt.train_G_every_n_iterations) == 0 108 | 109 | def bookkeep_epoch_visualizations(self): 110 | self.epoch_visuals.append( 111 | self.model.get_current_visuals()) 112 | 113 | def bookkeep_epoch_data(self, train_generator): 114 | if train_generator: 115 | self.epoch_losses_G.append(self.model.get_current_errors_G()) 116 | self.epoch_scalars.append(self.model.get_current_scalars()) 117 | self.epoch_losses_D.append(self.model.get_current_errors_D()) 118 | 119 | def display_terminal(self, iter_start_time, i_epoch, i_train_batch, visuals_flag): 120 | errors = self.model.get_current_errors() 121 | t = (time.time() - iter_start_time) 
/ self.opt.batch_size 122 | self.tb_visualizer.print_current_train_errors( 123 | i_epoch, i_train_batch, self.iters_per_epoch, errors, t, visuals_flag) 124 | 125 | def display_visualizer_train(self, total_steps): 126 | self.tb_visualizer.display_current_results( 127 | util.concatenate_dictionary(self.epoch_visuals), total_steps, is_train=True, save_visuals=True) 128 | self.tb_visualizer.plot_scalars( 129 | util.average_dictionary(self.epoch_losses_G), total_steps, is_train=True) 130 | self.tb_visualizer.plot_scalars( 131 | util.average_dictionary(self.epoch_losses_D), total_steps, is_train=True) 132 | self.tb_visualizer.plot_scalars( 133 | util.average_dictionary(self.epoch_scalars), total_steps, is_train=True) 134 | 135 | def display_visualizer_test(self, test_epoch_visuals, epoch_num, average_test_results, test_time, total_steps): 136 | self.tb_visualizer.print_current_validate_errors( 137 | epoch_num, average_test_results, test_time) 138 | self.tb_visualizer.plot_scalars( 139 | average_test_results, epoch_num, is_train=False) 140 | self.tb_visualizer.display_current_results( 141 | util.concatenate_dictionary(test_epoch_visuals), total_steps, is_train=False, save_visuals=True) 142 | 143 | def test(self, i_epoch, total_steps): 144 | val_start_time = time.time() 145 | 146 | self.model.set_eval() 147 | test_epoch_visuals = [] 148 | 149 | iters_per_epoch_val = self.dataset_val_size / self.opt.batch_size 150 | visuals_per_val_epoch = max(1, round(iters_per_epoch_val//2)) 151 | errors = [] 152 | with torch.no_grad(): 153 | for i_val_batch, val_batch in enumerate(self.dataset_val): 154 | self.model.set_input(val_batch) 155 | self.model.forward_G(train=True) 156 | errors.append(self.model.get_current_errors_G()) 157 | if (i_val_batch+1) % visuals_per_val_epoch == 0: 158 | test_epoch_visuals.append( 159 | self.model.get_current_visuals()) 160 | 161 | average_test_results = util.average_dictionary(errors) 162 | test_time = (time.time() - val_start_time) 163 | self.display_visualizer_test(test_epoch_visuals, i_epoch, average_test_results, test_time, total_steps) 164 | self.model.set_train() 165 | 166 | 167 | if __name__ == "__main__": 168 | Train() 169 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from utils import util 5 | import torch 6 | import yaml 7 | import sys 8 | from easydict import EasyDict 9 | 10 | 11 | class BaseOptions(): 12 | def __init__(self): 13 | self.parser = argparse.ArgumentParser() 14 | self.initialized = False 15 | 16 | def initialize(self): 17 | self.parser.add_argument('--object_mesh_dir', type=str, 18 | default='./data/meshes/ycb_meshes/', help='path to dataset') 19 | self.parser.add_argument('--grasp_dir', type=str, 20 | default='./data/train_data/graspit_training_grasps/', help='path to dataset') 21 | self.parser.add_argument( 22 | '--batch_size', type=int, default=100, help='input batch size') 23 | self.parser.add_argument( 24 | '--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 25 | self.parser.add_argument('-n', '--name', type=str, default='', 26 | help='name of the experiment. 
It decides where to store samples and models') 27 | self.parser.add_argument( 28 | '--n_threads_test', default=0, type=int, help='# threads for loading data') 29 | self.parser.add_argument( 30 | '--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 31 | self.parser.add_argument('--serial_batches', action='store_true', 32 | help='if true, takes images in order to make batches, otherwise takes them randomly') 33 | self.parser.add_argument('--precomputed_rotations', action='store_true', 34 | help='Use precomputed rotations') 35 | self.parser.add_argument('-ofct', '--object_finger_contact_threshold', type=float, default=0.01, 36 | help='The threshold for object-finger contact in the hand refinement part') 37 | self.parser.add_argument('--random_rot_std', type=float, default=0.1) 38 | self.parser.add_argument('--collision_loss_threshold', 39 | type=float, default=0.01) 40 | self.parser.add_argument( 41 | '--manual_seed', type=int, help='manual seed') 42 | self.parser.add_argument( 43 | '--extra_name', type=str, default='', help='string appended to end of folder which we save data to') 44 | 45 | self.parser.add_argument( 46 | '--pregenerate_data', action='store_true', help='If we want to pregenerate viewpoints') 47 | self.parser.add_argument( 48 | '--debug', 49 | action='store_true', 50 | help='If true, we will only train on the first object' 51 | ) 52 | self.parser.add_argument( 53 | '--num_viewpoints_per_object', 54 | type=int, 55 | default=100, 56 | ) 57 | self.parser.add_argument( 58 | '--constrain_method', choices=['soft', 'hard', 'cvx'], default='soft') 59 | self.parser.add_argument('--optimize_fingertip', action='store_true', 60 | help='If we also want to optimize the finger tip') 61 | 62 | self.initialized = True 63 | 64 | def parse(self): 65 | pass 66 | 67 | def set_folder_name(self): 68 | folder_name = "views_"+str(self.opt.num_viewpoints_per_object) 69 | if not self.opt.no_classification_loss: 70 | folder_name += "_classification_" + \ 71 | str(int(self.opt.lambda_G_classification)) 72 | if not self.opt.no_contact_loss: 73 | folder_name += "_contact_"+str(int(self.opt.lambda_G_contactloss)) 74 | if not self.opt.no_intersection_loss: 75 | folder_name += "_intersection_" + \ 76 | str(int(self.opt.lambda_G_intersections)) 77 | if not self.opt.no_orientation_loss: 78 | folder_name += "_orientation_" + \ 79 | str(int(self.opt.lambda_G_orientation)) 80 | if not self.opt.no_discriminator: 81 | folder_name += "_discriminator" 82 | folder_name += "_adv_"+str(int(self.opt.lambda_D_prob)) 83 | folder_name += "_gp_"+str(int(self.opt.lambda_D_gp)) 84 | folder_name += "_coll_threshold_" + \ 85 | str(self.opt.collision_loss_threshold).split(".")[-1] 86 | folder_name += "_obj_finger_contact_threshold_" + \ 87 | str(self.opt.object_finger_contact_threshold).split(".")[-1] 88 | folder_name += "_finger_constrainer_"+self.opt.constrain_method 89 | folder_name += "_optimizer_"+self.opt.optimizer 90 | folder_name += "_number_of_epochs_" + \ 91 | str(self.opt.nepochs_no_decay+self.opt.nepochs_decay) 92 | folder_name += self.opt.extra_name 93 | return folder_name 94 | 95 | def load_opt_file(self, file): 96 | with open(file, "r") as f: 97 | opt = EasyDict(yaml.load(f, Loader=yaml.FullLoader)) 98 | opt.load_epoch = self.opt.load_epoch 99 | opt.checkpoint_dir_load = self.opt.checkpoint_dir_load 100 | opt.batch_size = self.opt.batch_size 101 | opt.gpu_ids = self.opt.gpu_ids 102 | #opt.gpu_ids = ','.join([str(elem) for elem in opt.gpu_ids]) 103 | self.opt.update(opt) 104 | 105 | 
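# NOTE (illustrative summary of set_folder_name() above, added for clarity): the experiment
# folder name concatenates the viewpoint count, every enabled loss weight, the collision and
# object-finger contact thresholds, the finger-constraint method, the optimizer and the epoch
# budget. With the default 0.01 thresholds and all losses plus the discriminator enabled, the
# name takes the form
#   views_100_classification_<w>_contact_<w>_intersection_<w>_orientation_<w>_discriminator_adv_<w>_gp_<w>_coll_threshold_01_obj_finger_contact_threshold_01_finger_constrainer_soft_optimizer_<name>_number_of_epochs_<n><extra_name>
# where the <w>, <name> and <n> placeholders are filled from the training options defined in
# options/train_options.py. This is also why set_epoch() below refuses to start a fresh run
# when a folder with the same name already exists under --checkpoints_dir.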
def set_epoch(self): 106 | folder_name = self.set_folder_name() 107 | self.opt.name = folder_name+self.opt.name 108 | models_dir = os.path.join( 109 | self.opt.checkpoints_dir, self.opt.name) 110 | if os.path.exists(models_dir): 111 | print("Terminating. The folder " + models_dir + 112 | " exists and you chose to train from start. Either remove the folder or choose to continue train") 113 | sys.exit() 114 | 115 | def load_epoch(self): 116 | if os.path.exists(self.opt.checkpoint_dir_load): 117 | self.load_opt_file(self.opt.checkpoint_dir_load+"opt_train.yaml") 118 | found = False 119 | for file in os.listdir(self.opt.checkpoint_dir_load): 120 | if file.startswith("net_epoch_"+self.opt.load_epoch): 121 | found = True 122 | break 123 | if not found: 124 | print("Terminating. Epoch "+self.opt.load_epoch+" not found ") 125 | sys.exit() 126 | else: 127 | print("Terminating. You want to continue train but the folder " + self.opt.checkpoint_dir_load + 128 | " does not exist. First you need to train a model.") 129 | sys.exit() 130 | 131 | def get_set_gpus(self): 132 | # get gpu ids 133 | str_ids = self.opt.gpu_ids.split(',') 134 | self.opt.gpu_ids = [] 135 | for str_id in str_ids: 136 | id = int(str_id) 137 | if id >= 0: 138 | self.opt.gpu_ids.append(id) 139 | elif id == -1: 140 | self.opt.gpu_ids.append(id) 141 | return 142 | if torch.cuda.is_available(): 143 | while self.opt.gpu_ids[0] >= torch.cuda.device_count(): 144 | self.opt.gpu_ids[0] -= 1 145 | # set gpu ids 146 | if len(self.opt.gpu_ids) > 0: 147 | torch.cuda.set_device(self.opt.gpu_ids[0]) 148 | 149 | def print(self, args): 150 | print('------------ Options -------------') 151 | for k, v in sorted(args.items()): 152 | print('%s: %s' % (str(k), str(v))) 153 | print('-------------- End ----------------') 154 | 155 | def save(self, args): 156 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 157 | print(expr_dir) 158 | util.mkdirs(expr_dir) 159 | file_name = os.path.join(expr_dir, 'opt_%s.txt' % 160 | ('train' if self.is_train else 'test')) 161 | file_name_yaml = os.path.join(expr_dir, 'opt_%s.yaml' % 162 | ('train' if self.is_train else 'test')) 163 | with open(file_name_yaml, 'w') as opt_file: 164 | yaml.dump(args, opt_file) 165 | with open(file_name, 'wt') as opt_file: 166 | opt_file.write('------------ Options -------------\n') 167 | for k, v in sorted(args.items()): 168 | opt_file.write('%s: %s\n' % (str(k), str(v))) 169 | opt_file.write('-------------- End ----------------\n') 170 | file_name = os.path.join(expr_dir, 'command_line.txt') 171 | with open(file_name, 'wt') as opt_file: 172 | opt_file.write(" ".join(sys.argv)) 173 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import OrderedDict 3 | import trimesh 4 | import quaternion 5 | import torch 6 | import math 7 | import os 8 | import numpy as np 9 | from PIL import Image 10 | import copy 11 | import cvxopt as cvx 12 | cvx.solvers.options['show_progress'] = False 13 | 14 | 15 | def batch_pairwise_dist(x, y, use_cuda=True): 16 | bs, num_points_x, points_dim = x.size() 17 | _, num_points_y, _ = y.size() 18 | xx = torch.bmm(x, x.transpose(2, 1)) 19 | yy = torch.bmm(y, y.transpose(2, 1)) 20 | zz = torch.bmm(x, y.transpose(2, 1)) 21 | if use_cuda: 22 | dtype = torch.cuda.LongTensor 23 | else: 24 | dtype = torch.LongTensor 25 | diag_ind_x = torch.arange(0, 
num_points_x).type(dtype) 26 | diag_ind_y = torch.arange(0, num_points_y).type(dtype) 27 | rx = (xx[:, diag_ind_x, 28 | diag_ind_x].unsqueeze(1).expand_as(zz.transpose(2, 1))) 29 | ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz) 30 | P = rx.transpose(2, 1) + ry - 2 * zz 31 | return P 32 | 33 | 34 | def mkdirs(paths): 35 | if isinstance(paths, list) and not isinstance(paths, str): 36 | for path in paths: 37 | mkdir(path) 38 | else: 39 | mkdir(paths) 40 | 41 | 42 | def mkdir(path): 43 | if not os.path.exists(path): 44 | os.makedirs(path) 45 | 46 | 47 | def save_image(image_numpy, image_path): 48 | mkdir(os.path.dirname(image_path)) 49 | image_pil = Image.fromarray(image_numpy) 50 | image_pil.save(image_path) 51 | 52 | 53 | def convert_qt_to_T_matrix(qt): 54 | q = qt[3:] 55 | t = qt[:3] 56 | q = np.quaternion(q[-1], q[0], q[1], q[2]) 57 | R = quaternion.as_rotation_matrix(q) 58 | T = np.eye(4) 59 | T[:3, 3] = t 60 | T[:3, :3] = R 61 | return T 62 | 63 | 64 | def concatenate_barret_vertices_and_faces(vertices, faces, use_torch=False): 65 | all_vertices_per_batch = [] 66 | all_faces_per_batch = [] 67 | 68 | for i in range(vertices[0].shape[0]): 69 | flattened_vertices = [] 70 | flattened_faces = [] 71 | num_vertices = 0 72 | for v, f in zip(vertices, faces): 73 | curr_v = v[i] 74 | curr_f = f+num_vertices 75 | if len(v[i].shape) > 2: 76 | curr_f = curr_f.unsqueeze(0).repeat(v[i].shape[0], 1, 1) 77 | curr_f[1:] += curr_v.shape[1] 78 | if curr_f.shape[0] > 2: 79 | curr_f[2] += curr_v.shape[1] 80 | curr_f = curr_f.reshape( 81 | curr_f.shape[0]*curr_f.shape[1], curr_f.shape[2]) 82 | num_vertices += curr_v.shape[1]*curr_v.shape[0] 83 | curr_v = v[i].reshape( 84 | v[i].shape[0] * v[i].shape[1], v[i].shape[2]) 85 | else: 86 | num_vertices += curr_v.shape[0] 87 | if use_torch: 88 | flattened_vertices.append(curr_v) 89 | flattened_faces.append(curr_f) 90 | else: 91 | flattened_vertices.append(curr_v.cpu().data.numpy()) 92 | flattened_faces.append(curr_f.cpu().data.numpy()) 93 | if use_torch: 94 | all_vertices_per_batch.append(torch.cat(flattened_vertices)) 95 | all_faces_per_batch.append(torch.cat(flattened_faces)) 96 | else: 97 | all_vertices_per_batch.append(np.concatenate(flattened_vertices)) 98 | all_faces_per_batch.append(np.concatenate(flattened_faces)) 99 | if use_torch: 100 | return torch.stack(all_vertices_per_batch), all_faces_per_batch 101 | else: 102 | # np.concatenate(flattened_vertices) 103 | return all_vertices_per_batch, all_faces_per_batch 104 | 105 | 106 | def concatenate_barret_vertices(vertices, use_torch=False): 107 | all_vertices_per_batch = [] 108 | for i in range(vertices[0].shape[0]): 109 | flattened_vertices = [] 110 | for v in vertices: 111 | curr_v = v[i] 112 | if len(v[i].shape) > 2: 113 | curr_v = v[i].reshape( 114 | v[i].shape[0] * v[i].shape[1], v[i].shape[2]) 115 | if use_torch: 116 | flattened_vertices.append(curr_v) 117 | else: 118 | flattened_vertices.append(curr_v.cpu().data.numpy()) 119 | if use_torch: 120 | all_vertices_per_batch.append(torch.cat(flattened_vertices)) 121 | else: 122 | all_vertices_per_batch.append(np.concatenate(flattened_vertices)) 123 | if use_torch: 124 | return torch.stack(all_vertices_per_batch) 125 | else: 126 | return all_vertices_per_batch 127 | 128 | 129 | def concatenate_barret_vertices_and_vertice_areas(vertices, vertice_areas, use_torch=False): 130 | all_vertices_per_batch = [] 131 | all_vertice_areas_per_batch = [] 132 | for i in range(vertices[0].shape[0]): 133 | flattened_vertices = [] 134 | flattened_v_areas = [] 135 | 
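# NOTE: like concatenate_barret_vertices_and_faces() above, the loop below flattens the
# per-part vertex tensors returned by the Barrett layer (palm, knuckle, proximal and distal
# links) into one (total_vertices, 3) array per batch element. Parts that carry an extra
# leading dimension (the three finger copies) are reshaped to (copies * verts, 3) and their
# per-vertex areas are repeated the same number of times, so the returned vertex and area
# arrays stay index-aligned.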
for v, v_a in zip(vertices, vertice_areas): 136 | curr_v = v[i] 137 | curr_v_a = v_a 138 | if len(v[i].shape) > 2: 139 | curr_v = v[i].reshape( 140 | v[i].shape[0] * v[i].shape[1], v[i].shape[2]) 141 | curr_v_a = curr_v_a.repeat(v[i].shape[0]) 142 | if use_torch: 143 | flattened_vertices.append(curr_v) 144 | flattened_v_areas.append(curr_v_a) 145 | else: 146 | flattened_vertices.append(curr_v.cpu().data.numpy()) 147 | flattened_v_areas.append(curr_v_a.cpu().data.numpy()) 148 | if use_torch: 149 | all_vertices_per_batch.append(torch.cat(flattened_vertices)) 150 | all_vertice_areas_per_batch.append(torch.cat(flattened_v_areas)) 151 | else: 152 | all_vertices_per_batch.append(np.concatenate(flattened_vertices)) 153 | all_vertice_areas_per_batch.append( 154 | np.concatenate(flattened_v_areas)) 155 | if use_torch: 156 | return torch.stack(all_vertices_per_batch), torch.stack(all_vertice_areas_per_batch) 157 | else: 158 | return all_vertices_per_batch, all_vertice_areas_per_batch 159 | 160 | 161 | def joints_to_grasp_representation(joints): 162 | dofs = np.zeros(7) 163 | dofs[0] = joints[0] 164 | dofs[1] = joints[4] 165 | dofs[2] = joints[1] 166 | dofs[3] = joints[6] 167 | dofs[4] = joints[5] 168 | dofs[5] = joints[2] 169 | dofs[6] = joints[7] 170 | return dofs 171 | 172 | 173 | def load_mesh(mesh_file): 174 | mesh = trimesh.load(mesh_file) 175 | object_info = {} 176 | object_info["verts"] = mesh.vertices 177 | object_info["faces"] = mesh.faces 178 | object_info["verts_resampled"] = trimesh.sample.sample_surface_even(mesh, 800)[ 179 | 0] 180 | 181 | return object_info 182 | 183 | 184 | def calculate_classification_statistics(y_true, y_pred): 185 | tp = (y_true * y_pred).sum().to(torch.float32) 186 | tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32) 187 | fp = ((1 - y_true) * y_pred).sum().to(torch.float32) 188 | fn = (y_true * (1 - y_pred)).sum().to(torch.float32) 189 | return tp, tn, fp, fn 190 | 191 | 192 | def f1_score(tp, tn, fp, fn): 193 | '''Calculate F1 score. Can work with gpu tensors 194 | 195 | The original implementation is written by Michal Haltuf on Kaggle. 196 | 197 | Returns 198 | ------- 199 | torch.Tensor 200 | `ndim` == 1. 
0 <= val <= 1 201 | 202 | Reference 203 | --------- 204 | - https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric 205 | - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score 206 | - https://discuss.pytorch.org/t/calculating-precision-recall-and-f1-score-in-case-of-multi-label-classification/28265/6 207 | 208 | ''' 209 | epsilon = 1e-7 210 | 211 | precision = tp / (tp + fp + epsilon) 212 | recall = tp / (tp + fn + epsilon) 213 | 214 | f1 = 2 * (precision*recall) / (precision + recall + epsilon) 215 | return f1 216 | 217 | 218 | def rotm_to_axis_angles(R, device="cpu"): 219 | trace = torch.einsum('bii->b', R) 220 | theta = torch.acos((trace-1)/2) 221 | eps1 = 0.01 222 | eps2 = 0.1 223 | axis_angles = torch.zeros((R.shape[0], 3)).to(device) 224 | axis_angles[:, 0] = R[:, 2, 1]-R[:, 1, 2] 225 | axis_angles[:, 1] = R[:, 0, 2]-R[:, 2, 0] 226 | axis_angles[:, 2] = R[:, 1, 0]-R[:, 0, 1] 227 | temp = 1/(2*torch.sin(theta)).unsqueeze(-1) 228 | axis_angles *= temp 229 | singularities = torch.where(((R[:, 0, 1]-R[:, 1, 0]).abs() < eps1) & ((R[:, 0, 2]-R[:, 2, 0]).abs() 230 | < eps1) & ((R[:, 1, 2]-R[:, 2, 1]).abs() < eps1))[0] 231 | if singularities.nelement() > 0: 232 | theta[singularities] = math.pi 233 | for i in singularities: 234 | trace_sing = R[i].trace() 235 | if (((R[i, 0, 1]+R[i, 1, 0]).abs() < eps2) & ((R[i, 0, 2] + R[i, 2, 0]).abs() < eps2) & ((R[i, 1, 2]+R[i, 2, 1]).abs() < eps1) & ((trace_sing-3).abs() < eps2)): 236 | axis_angles[i] = torch.zeros(3) 237 | axis_angles[i, 0] = 1 238 | theta[i] = 0 239 | else: 240 | theta[i] = math.pi 241 | xx = 0.5*(R[i, 0, 0]+1) 242 | yy = 0.5*(R[i, 1, 1]+1) 243 | zz = 0.5*(R[i, 2, 2]+1) 244 | xy = (R[i, 0, 1]+R[i, 1, 0])/4 245 | xz = (R[i, 0, 2]+R[i, 2, 0])/4 246 | yz = (R[i, 1, 2]+R[i, 2, 1])/4 247 | if (xx > yy) & (xx > zz): 248 | if (xx < eps1): 249 | x = 0 250 | y = 0.7071 251 | z = 0.7071 252 | else: 253 | x = torch.sqrt(xx) 254 | y = xy/x 255 | z = xz/x 256 | elif (yy > xx): 257 | if (yy < eps1): 258 | x = 0.7071 259 | y = 0 260 | z = 0.7071 261 | else: 262 | y = torch.sqrt(yy) 263 | x = xy/y 264 | z = yz/y 265 | else: 266 | if (zz < eps1): 267 | x = 0.7071 268 | y = 0.7071 269 | z = 0 270 | else: 271 | z = torch.sqrt(zz) 272 | x = xz/z 273 | y = yz/z 274 | axis_angles[i, 0] = x 275 | axis_angles[i, 1] = y 276 | axis_angles[i, 2] = z 277 | return axis_angles*theta.view(theta.shape[0], 1) 278 | 279 | 280 | def axis_angles_to_rotation_matrix(angle_axis, eps=1e-6, device="cpu"): 281 | theta = angle_axis.norm(dim=-1) 282 | angle_axis = angle_axis/theta.view(theta.shape[0], 1) 283 | k_one = 1.0 284 | wxyz = angle_axis # / (theta + eps) 285 | wx, wy, wz = torch.chunk(wxyz, 3, dim=1) 286 | cos_theta = torch.cos(theta).unsqueeze(1) 287 | sin_theta = torch.sin(theta).unsqueeze(1) 288 | r00 = cos_theta + wx * wx * (k_one - cos_theta) 289 | r10 = wz * sin_theta + wx * wy * (k_one - cos_theta) 290 | r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta) 291 | r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta 292 | r11 = cos_theta + wy * wy * (k_one - cos_theta) 293 | r21 = wx * sin_theta + wy * wz * (k_one - cos_theta) 294 | r02 = wy * sin_theta + wx * wz * (k_one - cos_theta) 295 | r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta) 296 | r22 = cos_theta + wz * wz * (k_one - cos_theta) 297 | rotation_matrix = torch.cat( 298 | [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1).to(device) 299 | return rotation_matrix.view(-1, 3, 3) 300 | 301 | 302 | def rad2deg(rads): 303 | return 180. 
* rads / math.pi 304 | 305 | 306 | def deg2rad(degs): 307 | return math.pi * degs / 180. 308 | 309 | 310 | def valid_hand_conf(degs, device="cpu", only_check_spread=False): 311 | deg_max = torch.FloatTensor([180, 140, 140, 140, 48, 48, 48]).to(device) 312 | deg_min = torch.zeros(7).to(device) 313 | if only_check_spread: 314 | return torch.sum(degs[:, 0]-1e-2 > deg_max[0])+torch.sum(degs[:, 0]+1e-2 < deg_min[0]) 315 | else: 316 | return torch.sum(degs-1e-2 > deg_max)+torch.sum(degs+1e-2 < deg_min) 317 | 318 | 319 | def euclidean_distance(d1, d2): 320 | return torch.norm(d1-d2, dim=-1) 321 | 322 | 323 | def qrot(q, v): 324 | """ 325 | Rotate vector(s) v about the rotation described by quaternion(s) q. 326 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 327 | where * denotes any number of dimensions. 328 | Returns a tensor of shape (*, 3). 329 | """ 330 | assert q.shape[-1] == 4 331 | assert v.shape[-1] == 3 332 | assert q.shape[:-1] == v.shape[:-1] 333 | 334 | original_shape = list(v.shape) 335 | q = q.view(-1, 4) 336 | v = v.view(-1, 3) 337 | 338 | qvec = q[:, 1:] 339 | uv = torch.cross(qvec, v, dim=1) 340 | uuv = torch.cross(qvec, uv, dim=1) 341 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 342 | 343 | 344 | def axis_angle_rot(angle_axis, vec): 345 | theta = angle_axis.norm(dim=-1) 346 | vec = vec.repeat(theta.shape[0], 1) 347 | angle_axis = angle_axis/theta.view(theta.shape[0], 1) 348 | angle_cos = torch.cos(theta) 349 | angle_sin = torch.sin(theta) 350 | vec_rot = vec*angle_cos.view(theta.shape[0], 1)+torch.cross(angle_axis, vec)*angle_sin.view(theta.shape[0], 1) + \ 351 | angle_axis*(angle_axis*vec).sum(-1, keepdim=True) * \ 352 | (1-angle_cos.view(theta.shape[0], 1)) 353 | return vec_rot 354 | 355 | 356 | def axis_angle_distance(random_rots, R): 357 | random_rots_normed = random_rots/random_rots.norm(dim=-1, keepdim=True) 358 | R_normed = R/R.norm(dim=-1, keepdim=True) 359 | return 1-(random_rots_normed*R_normed).sum(-1) 360 | 361 | 362 | def average_dictionary(list_dict): 363 | averaged_dictionary = OrderedDict() 364 | for dict_losses in list_dict: 365 | for key in dict_losses.keys(): 366 | if key in averaged_dictionary: 367 | averaged_dictionary[key] += dict_losses[key] 368 | else: 369 | averaged_dictionary[key] = copy.deepcopy(dict_losses[key]) 370 | for key in averaged_dictionary.keys(): 371 | averaged_dictionary[key] /= len(list_dict) 372 | return averaged_dictionary 373 | 374 | 375 | def concatenate_dictionary(list_dict): 376 | concatenated_dictionary = {} 377 | for curr_dict in list_dict: 378 | for key in curr_dict.keys(): 379 | if key in concatenated_dictionary: 380 | concatenated_dictionary[key].append(curr_dict[key]) 381 | else: 382 | concatenated_dictionary[key] = [curr_dict[key]] 383 | return concatenated_dictionary 384 | 385 | 386 | def intersect_vox(obj_mesh, hand_mesh, pitch=0.01): 387 | obj_vox = obj_mesh.voxelized(pitch=pitch) 388 | obj_points = obj_vox.points 389 | inside = hand_mesh.contains(obj_points) 390 | volume = inside.sum() * np.power(pitch, 3) 391 | return volume 392 | 393 | 394 | def create_mesh(input_vertices, input_faces): 395 | return trimesh.Trimesh(input_vertices, input_faces) 396 | 397 | 398 | def get_intersection(hand_vertices, hand_faces, obj_mesh): 399 | hand_mesh = create_mesh(hand_vertices, hand_faces) 400 | intersections = intersect_vox(hand_mesh, obj_mesh, 0.005) 401 | return intersections 402 | 403 | 404 | def min_norm_vector_in_facet(facet, wrench_regularizer=1e-10): 405 | """ Finds the minimum norm 
point in the convex hull of a given facet (aka simplex) by solving a QP. 406 | 407 | Parameters 408 | ---------- 409 | facet : 6xN :obj:`numpy.ndarray` 410 | vectors forming the facet 411 | wrench_regularizer : float 412 | small float to make quadratic program positive semidefinite 413 | 414 | Returns 415 | ------- 416 | float 417 | minimum norm of any point in the convex hull of the facet 418 | Nx1 :obj:`numpy.ndarray` 419 | vector of coefficients that achieves the minimum 420 | """ 421 | dim = facet.shape[1] # num vertices in facet 422 | 423 | # create alpha weights for vertices of facet 424 | G = facet.T.dot(facet) 425 | grasp_matrix = G + wrench_regularizer * np.eye(G.shape[0]) 426 | 427 | # Solve QP to minimize .5 x'Px + q'x subject to Gx <= h, Ax = b 428 | P = cvx.matrix(2 * grasp_matrix) # quadratic cost for Euclidean dist 429 | q = cvx.matrix(np.zeros((dim, 1))) 430 | G = cvx.matrix(-np.eye(dim)) # greater than zero constraint 431 | h = cvx.matrix(np.zeros((dim, 1))) 432 | A = cvx.matrix(np.ones((1, dim))) # sum constraint to enforce convex 433 | b = cvx.matrix(np.ones(1)) # combinations of vertices 434 | sol = cvx.solvers.qp(P, q, G, h, A, b) 435 | v = np.array(sol['x']) 436 | min_norm = np.sqrt(sol['primal objective']) 437 | 438 | return abs(min_norm), v 439 | 440 | 441 | def grasp_matrix(forces, torques, normals, soft_fingers=False, 442 | finger_radius=0.005, params=None): 443 | if params is not None and 'finger_radius' in params.keys(): 444 | finger_radius = params.finger_radius 445 | num_forces = forces.shape[1] 446 | num_torques = torques.shape[1] 447 | if num_forces != num_torques: 448 | raise ValueError('Need same number of forces and torques') 449 | 450 | num_cols = num_forces 451 | if soft_fingers: 452 | num_normals = 2 453 | if normals.ndim > 1: 454 | num_normals = 2*normals.shape[1] 455 | num_cols = num_cols + num_normals 456 | 457 | torque_scaling = 1 458 | G = np.zeros([6, num_cols]) 459 | for i in range(num_forces): 460 | G[:3, i] = forces[:, i] 461 | # G[3:,i] = forces[:,i] # ZEROS 462 | G[3:, i] = torque_scaling * torques[:, i] 463 | 464 | if soft_fingers: 465 | torsion = np.pi * finger_radius**2 * \ 466 | params.friction_coef * normals * params.torque_scaling 467 | pos_normal_i = -num_normals 468 | neg_normal_i = -num_normals + num_normals / 2 469 | G[3:, pos_normal_i:neg_normal_i] = torsion 470 | G[3:, neg_normal_i:] = -torsion 471 | 472 | return G 473 | 474 | 475 | def get_normal_face(p1, p2, p3): 476 | U = p2 - p1 477 | V = p3 - p1 478 | Nx = U[1]*V[2] - U[2]*V[1] 479 | Ny = U[2]*V[0] - U[0]*V[2] 480 | Nz = U[0]*V[1] - U[1]*V[0] 481 | return [-1*Nx, -1*Ny, -1*Nz] 482 | 483 | 484 | def get_normal_face_batched(vertices, faces): 485 | p1 = vertices[faces[:, 0]] 486 | p2 = vertices[faces[:, 1]] 487 | p3 = vertices[faces[:, 2]] 488 | U = p2 - p1 489 | V = p3 - p1 490 | Nx = U[:, 1]*V[:, 2] - U[:, 2]*V[:, 1] 491 | Ny = U[:, 2]*V[:, 0] - U[:, 0]*V[:, 2] 492 | Nz = U[:, 0]*V[:, 1] - U[:, 1]*V[:, 0] 493 | return -1*torch.stack((Nx, Ny, Nz)).T 494 | 495 | 496 | def get_distance_vertices(obj, hand): 497 | n1 = len(hand) 498 | n2 = len(obj) 499 | 500 | matrix1 = hand[np.newaxis].repeat(n2, 0) 501 | matrix2 = obj[:, np.newaxis].repeat(n1, 1) 502 | dists = np.sqrt(((matrix1-matrix2)**2).sum(-1)) 503 | 504 | return dists.min(0) 505 | 506 | 507 | def get_distance_vertices_batched(obj, hand): 508 | n1 = len(hand[1]) 509 | n2 = len(obj[0]) 510 | 511 | matrix1 = hand.unsqueeze(1).repeat(1, n2, 1, 1) 512 | matrix2 = obj.unsqueeze(2).repeat(1, 1, n1, 1) 513 | dists = 
torch.norm(matrix1 - matrix2, dim=-1) 514 | dists = dists.min(1)[0] 515 | 516 | return dists 517 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from barrett_kinematics.barrett_layer.barrett_layer import BarrettLayer 2 | import pickle 3 | import pyquaternion 4 | import trimesh 5 | import os 6 | import glob 7 | from utils import util 8 | import utils.plots as plot_utils 9 | import torch 10 | import numpy as np 11 | from options.test_options import TestOptions 12 | 13 | import time 14 | from datasets.custom_dataset_data_loader import CustomDatasetDataLoader 15 | from models.models import ModelsFactory 16 | 17 | from joblib import Parallel, delayed 18 | 19 | import pandas as pd 20 | import subprocess 21 | 22 | 23 | class Test: 24 | def __init__(self): 25 | self.opt = TestOptions().parse() 26 | if self.get_test_data(): 27 | self.display = self.opt.display 28 | self.threshold_intersections = self.opt.threshold_intersections 29 | self.threshold_contact = self.opt.threshold_contact 30 | self.vertices_and_parts_on_barrett_hand_that_can_touch_object() 31 | try: 32 | os.makedirs(self.opt.save_folder) 33 | except: 34 | pass 35 | if self.opt.graspit: 36 | self.barrett_layer = BarrettLayer() 37 | 38 | self.save_path = self.opt.save_folder+self.opt.test_set 39 | self.simulation_experiment_graspit() 40 | else: 41 | self.setup_multi_fingan() 42 | with torch.no_grad(): 43 | self.simulation_experiment_multi_fingan() 44 | 45 | def get_test_data(self): 46 | if not os.path.isdir("data/test_data/") or len(os.listdir("data/test_data/")) < 2 or not os.path.isdir("data/meshes/") or len(os.listdir("data/meshes/")) < 2: 47 | while True: 48 | download_data = input("Test data does not exist. Want to download it (y/n)? ") 49 | if download_data == "y": 50 | dir = os.path.dirname(os.path.realpath(__file__)) + "/data/" 51 | print("Downloading test data. This will take some time so just sit back and relax.") 52 | subprocess.Popen([dir+'download_test_data.sh %s' % dir], shell=True).wait() 53 | print("Done downloading test data. Continuing with testing.") 54 | return True 55 | elif download_data == "n": 56 | print("You chose to not download test data. Terminating testing") 57 | return False 58 | else: 59 | print("Testing data exists. 
Proceeding with Testing.") 60 | return True 61 | 62 | def vertices_and_parts_on_barrett_hand_that_can_touch_object(self): 63 | self.finger_vertices = [7, 37, 21, 51, 71, 34, 61, 86, 46, 22, 53, 40, 69, 52, 38, 13, 29, 19, 8, 81, 60, 72, 97, 36, 64 | 28, 45, 49, 66, 3, 70, 65, 68, 25, 95, 48, 57, 98, 27, 79, 15, 80, 89, 62, 100, 91, 5, 44, 63] 65 | self.finger_tip_vertices = [35, 98, 44, 90, 95, 77, 84, 73, 76, 25, 108, 64, 22, 24, 96, 23, 85, 79, 83, 30, 45, 47, 68, 54, 42, 69, 92, 86, 66 | 19, 7, 94, 37, 99, 91, 11, 107, 0, 89, 57, 59, 109, 4, 65, 31, 2, 1, 10, 101, 52, 97, 87, 50, 72, 15, 106, 82, 12, 56, 78, 32, 46, 8] 67 | self.palm_vertices = [18, 205, 20, 90, 129, 166, 107, 70, 27, 161, 34, 211, 80, 114] 68 | self.parts_that_can_touch_vertices = [self.palm_vertices, self.finger_vertices, self.finger_vertices, 69 | self.finger_vertices, self.finger_tip_vertices, self.finger_tip_vertices, self.finger_tip_vertices] 70 | # Only the palm (0), three proximal links (3,4,5) and three distal links (6,7,8) can touch an object 71 | # We exclute the knuckle links on finger 1 and 2 in https://support.barrett.com/wiki/Hand/280/KinematicsJointRangesConversionFactors 72 | self.parts_that_can_touch = [0, 3, 4, 5, 6, 7, 8] 73 | 74 | def setup_multi_fingan(self): 75 | self.device = torch.device('cuda:{}'.format( 76 | self.opt.gpu_ids[0])) if self.opt.gpu_ids[0] != -1 and torch.cuda.is_available() else torch.device('cpu') 77 | self.save_path = self.opt.save_folder+self.opt.test_set 78 | self.num_grasps_to_sample = self.opt.num_grasps_to_sample 79 | # Let's set batch size to 1 as we generate grasps from one viewpoint only 80 | self.opt.batch_size = 1 81 | 82 | self.opt.object_finger_contact_threshold = 0.004 83 | self.opt.optimize_finger_tip = True 84 | self.opt.num_viewpoints_per_object = self.opt.num_viewpoints_per_object_test 85 | 86 | self.opt.n_threads_train = self.opt.n_threads_test 87 | 88 | self.model = ModelsFactory.get_by_name(self.opt) 89 | 90 | data_loader_test = CustomDatasetDataLoader(self.opt, mode='val') 91 | self.dataset_test = data_loader_test.load_data() 92 | self.dataset_test_size = len(data_loader_test) 93 | print('#test images = %d' % self.dataset_test_size) 94 | self.model.set_eval() 95 | 96 | self.hand_in_parts = self.calculate_number_of_hand_vertices_per_part(self.model.barrett_layer) 97 | 98 | def calculate_number_of_hand_vertices_per_part(self, hand): 99 | hand_in_parts = [0] # One palm link 100 | hand_in_parts.append(hand.num_vertices_per_part[0]) 101 | for _ in range(2): # Two knuckles 102 | hand_in_parts.append(hand_in_parts[-1]+hand.num_vertices_per_part[1]) 103 | for _ in range(3): # three proximal links 104 | hand_in_parts.append(hand_in_parts[-1]+hand.num_vertices_per_part[2]) 105 | for _ in range(3): # Three distal links 106 | hand_in_parts.append(hand_in_parts[-1]+hand.num_vertices_per_part[3]) 107 | return hand_in_parts 108 | 109 | def import_meshes(self): 110 | models = glob.glob(self.opt.object_mesh_dir) 111 | models.sort() 112 | models = np.array(models) 113 | all_meshes = {} 114 | for i in range(models.shape[0]): 115 | curr_mesh = {} 116 | if self.opt.test_set == "egad": 117 | name = os.path.splitext(models[i].split("/")[-1])[0] 118 | elif self.opt.test_set == "ycb": 119 | name = models[i].split("/")[-3] 120 | obj = trimesh.load(models[i]) 121 | resampled_objects_800verts = trimesh.sample.sample_surface_even(obj, 800)[ 122 | 0] 123 | curr_mesh["vertices"] = np.expand_dims(obj.vertices, 0) 124 | curr_mesh["faces"] = np.expand_dims(obj.faces, 0) 125 | 
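# NOTE: each entry of all_meshes stores the raw mesh as (1, V, 3) vertices and (1, F, 3)
# faces plus up to 800 evenly sampled surface points used for the contact checks. trimesh's
# sample_surface_even() can return fewer than the requested 800 points; the training dataset
# falls back to sample_surface() in that case, whereas no such fallback is applied here.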
curr_mesh["resampled_vertices"] = resampled_objects_800verts 126 | all_meshes[name] = curr_mesh 127 | return all_meshes 128 | 129 | def simulation_experiment_graspit(self): 130 | grasp_files = glob.glob(self.opt.graspit_grasp_dir+"*pkl") 131 | time_graspit = np.loadtxt(self.opt.graspit_grasp_dir+"time.txt") 132 | 133 | grasp_files.sort() 134 | all_meshes = self.import_meshes() 135 | df = pd.DataFrame(columns=["Number of contacts", "Intersection", "Epsilon quality", 136 | "Epsilon quality from GraspIt!", "Volume quality from GraspIt!", "Objects name", "Time for graspit"]) 137 | 138 | for i, grasp_file in enumerate(grasp_files): 139 | grasps = pickle.load(open(grasp_file, "rb")) 140 | mesh_name = os.path.splitext(os.path.basename(grasp_file))[0] 141 | all_grasp_stabilities = [] 142 | try: 143 | mesh = all_meshes[mesh_name] 144 | except: 145 | continue 146 | if type(grasps) is list: 147 | for grasp_idx, grasp in enumerate(grasps): 148 | data = self.evaluate_graspit_grasp(grasp, mesh, grasp_idx) 149 | all_grasp_stabilities.append(data) 150 | else: 151 | all_grasp_stabilities.append(self.evaluate_graspit_grasp( 152 | grasps, mesh, i)) 153 | #np.save(self.save_path+"_"+str(i)+".npy", np.asarray(all_grasp_stabilities)) 154 | all_grasp_stabilities = np.asarray(all_grasp_stabilities) 155 | sorted_according_to_quality = (-1*all_grasp_stabilities[:, 2]).argsort() 156 | all_grasp_stabilities = all_grasp_stabilities[sorted_according_to_quality] 157 | df_per_object = pd.DataFrame(data=np.asarray(all_grasp_stabilities), columns=["Number of contacts", "Intersection", "Epsilon quality", 158 | "Epsilon quality from GraspIt!", "Volume quality from GraspIt!"]) 159 | df_per_object["Objects name"] = mesh_name 160 | df_per_object["Time for graspit"] = time_graspit[i] 161 | df = df.append(df_per_object) 162 | df.to_csv(self.save_path+"_results.csv") 163 | 164 | def evaluate_graspit_grasp(self, grasp, mesh, i): 165 | grasp_pose = grasp["pose"] 166 | grasp_pose_torch = torch.eye(4).unsqueeze(0) 167 | grasp_pose_torch[:, :3, :3] = torch.from_numpy(pyquaternion.Quaternion( 168 | np.array(grasp_pose[3:])[[3, 0, 1, 2]]).rotation_matrix) 169 | grasp_pose_torch[:, :3, 3] = torch.FloatTensor(grasp_pose[:3]) 170 | grasp_joints_torch = torch.FloatTensor( 171 | util.joints_to_grasp_representation(grasp["joints"])).unsqueeze(0) 172 | 173 | * verts, _ = self.barrett_layer(grasp_pose_torch, grasp_joints_torch) 174 | verts_concatenated, faces_concatenated = util.concatenate_barret_vertices_and_faces( 175 | verts, self.barrett_layer.gripper_faces) 176 | 177 | forces_post, torques_post, normals_post, finger_is_touching, vertices_that_touch = self.get_contact_points( 178 | verts, self.barrett_layer, mesh["resampled_vertices"]) 179 | obj_mesh = util.create_mesh(mesh['vertices'][0], mesh['faces'][0]) 180 | hand_mesh = util.create_mesh( 181 | verts_concatenated[0], faces_concatenated[0]) 182 | intersections = util.intersect_vox(hand_mesh, obj_mesh, 0.005) 183 | if intersections > self.threshold_intersections or len(forces_post) < 3: 184 | data = [ 185 | finger_is_touching.sum(-1), intersections, 0.0, grasp["epsilon"], grasp["volume"]] 186 | else: 187 | G = util.grasp_matrix(np.array(forces_post).transpose(), np.array( 188 | torques_post).transpose(), np.array(normals_post).transpose()) 189 | grasp_metric = util.min_norm_vector_in_facet(G)[0] 190 | data = [finger_is_touching.sum(-1), intersections, grasp_metric, grasp["epsilon"], 191 | grasp["volume"]] 192 | if self.display: 193 | object_vertices = 
[np.asarray(mesh['vertices']).squeeze()] 194 | object_faces = [np.asarray(mesh['faces']).squeeze()] 195 | hand_vertices = [verts_concatenated[0]] 196 | hand_faces = [faces_concatenated[0]] 197 | self.save_grasp_image(hand_vertices, hand_faces, object_vertices, object_faces, "/tmp/grasp_graspit"+str(i)+".png") 198 | return data 199 | 200 | def save_grasp_image(self, hand_vertices, hand_faces, object_vertices, object_faces, save_dir): 201 | image_numpy = plot_utils.plot_scene_w_grasps(object_vertices, object_faces, hand_vertices, hand_faces) 202 | util.save_image(image_numpy, save_dir) 203 | 204 | def simulation_experiment_multi_fingan(self): 205 | all_data_to_save = np.zeros( 206 | (self.dataset_test_size*self.num_grasps_to_sample, 9)) 207 | 208 | print("SAMPLING %d POSSIBLE GRASPS" % (self.num_grasps_to_sample)) 209 | for i_test_batch, test_batch in enumerate(self.dataset_test): 210 | print("PROCESSING TEST SAMPLE %d" % (i_test_batch)) 211 | time_batch = time.time() 212 | self.model.set_input(test_batch) 213 | 214 | resampled_obj_verts = torch.FloatTensor( 215 | self.model.input_obj_resampled_verts[0][0]).to(self.device) 216 | #self.model.input_object_id = [0] 217 | # self.model.input_center_objects = torch.FloatTensor( 218 | # [self.model.input_obj_resampled_verts[0][0].mean(0)]).cuda() 219 | time_s = time.time() 220 | self.model.forward_G(False) 221 | time_generating = time.time()-time_s 222 | print("Time for generating grasps " + str(time_generating)) 223 | verts = self.model.refined_handpose_concatenated 224 | time_s = time.time() 225 | forces_post, torques_post, normals_post, parts_are_touching, vertices_that_touch = self.get_contact_points_batch( 226 | verts, self.model.barrett_layer, resampled_obj_verts, self.num_grasps_to_sample) 227 | time_contact = time.time()-time_s 228 | enough_contacts = parts_are_touching.sum(-1) >= 3 229 | time_s = time.time() 230 | intersections = np.zeros(self.num_grasps_to_sample) 231 | obj_mesh = util.create_mesh( 232 | test_batch['3d_points_object'][0][0], test_batch['3d_faces_object'][0][0]) 233 | 234 | intersections = np.asarray(Parallel(n_jobs=24)( 235 | delayed(util.get_intersection)(verts[i].cpu().numpy(), self.model.refined_handpose_faces[i].cpu().numpy(), obj_mesh) 236 | for i in range(self.num_grasps_to_sample)) 237 | ) 238 | time_intersection = time.time()-time_s 239 | print("Time to calculate all intersections " + str(time_intersection)) 240 | no_intersection = intersections < self.threshold_intersections 241 | grasp_to_check_quality = no_intersection & enough_contacts.cpu().numpy() 242 | 243 | time_s = time.time() 244 | qualities = np.zeros(self.num_grasps_to_sample) 245 | for dim in grasp_to_check_quality.nonzero()[0]: 246 | try: 247 | G = util.grasp_matrix(torch.stack(forces_post[dim]).T.cpu().data.numpy(), torch.stack(torques_post[dim]).T.cpu().data.numpy(), torch.stack( 248 | normals_post[dim]).T.cpu().data.numpy()) 249 | except: 250 | print("Problem with batch "+str(i_test_batch)) 251 | continue 252 | grasp_metric = util.min_norm_vector_in_facet(G)[0] 253 | qualities[dim] = grasp_metric 254 | 255 | if self.display: 256 | best_grasps = (-1*qualities).argsort() 257 | object_vertices = [self.model.input_obj_verts[0][0]] 258 | object_faces = [self.model.input_obj_faces[0][0]] 259 | hand_vertices = [self.model.refined_handpose_concatenated[best_grasps][0].cpu()] 260 | hand_faces = [self.model.refined_handpose_faces[0].cpu()] 261 | self.save_grasp_image(hand_vertices, hand_faces, object_vertices, object_faces, 262 | 
"/tmp/multi_fingan_best_grasp_batch_"+str(i_test_batch)+".png") 263 | 264 | time_for_batch = time.time()-time_batch 265 | time_quality = time.time()-time_s 266 | print("Time to calculate qualities " + str(time_quality)) 267 | print("Total time for batch " + str(time_for_batch)) 268 | batch_data_to_save = self.populate_batch_data( 269 | parts_are_touching.sum(-1).data, intersections, qualities, time_generating, time_contact, time_intersection, time_quality, time_for_batch) 270 | all_data_to_save[i_test_batch*self.num_grasps_to_sample:(i_test_batch+1)*self.num_grasps_to_sample] = batch_data_to_save 271 | df = pd.DataFrame(data=all_data_to_save, columns=["Object id", "Num touches", "Intersection", "Qualities", "Average time per grasp generating", 272 | "Average time per grasp contact", "Average time per grasp intersection", "Average time per grasp qualities", "Average time total"]) 273 | 274 | df.to_csv(self.save_path+"_results.csv") 275 | 276 | def populate_batch_data(self, num_contacts, intersections, qualities, time_to_generate, time_for_contact, time_for_intersections, time_for_quality, total_time): 277 | array_for_data = np.zeros( 278 | (self.num_grasps_to_sample, 9)) 279 | array_for_data[:, 0] = int(self.model.input_object_id[0]) 280 | array_for_data[:, 1] = num_contacts 281 | array_for_data[:, 2] = intersections 282 | array_for_data[:, 3] = qualities 283 | array_for_data[:, 4] = time_to_generate/self.num_grasps_to_sample 284 | array_for_data[:, 5] = time_for_contact/self.num_grasps_to_sample 285 | array_for_data[:, 6] = time_for_intersections / self.num_grasps_to_sample 286 | array_for_data[:, 7] = time_for_quality/self.num_grasps_to_sample 287 | array_for_data[:, 8] = total_time/self.num_grasps_to_sample 288 | return array_for_data 289 | 290 | def concatenate_faces(self, hand, use_torch=False): 291 | if use_torch: 292 | faces = [hand.gripper_faces[0]] 293 | for _ in range(2): 294 | faces.append(hand.gripper_faces[1]) 295 | for _ in range(3): 296 | faces.append(hand.gripper_faces[2]) 297 | for _ in range(3): 298 | faces.append(hand.gripper_faces[3]) 299 | else: 300 | faces = [hand.gripper_faces[0].cpu().data.numpy()] 301 | for i in range(3): 302 | faces.append(hand.gripper_faces[2].cpu().data.numpy()) 303 | for i in range(3): 304 | faces.append(hand.gripper_faces[3].cpu().data.numpy()) 305 | return faces 306 | 307 | def get_contact_points_batch(self, all_hand_vertices, hand, resampled_obj_verts, batch_size): 308 | faces = self.concatenate_faces(hand, True) 309 | all_forces = [[] for _ in range(batch_size)] 310 | all_torques = [[] for _ in range(batch_size)] 311 | all_normals = [[] for _ in range(batch_size)] 312 | vertices_that_touches = [[] for _ in range(batch_size)] 313 | # We consider seven parts: the palm, the three finger tips, and the three finger bases 314 | # parts_touching = torch.zeros(palm_vertices.shape[0], 7) 315 | # distance_touching_vertices = self.calculate_contact_batch(points_3d) 316 | resampled_obj_verts = resampled_obj_verts.unsqueeze( 317 | 0).repeat(batch_size, 1, 1) 318 | dists = util.get_distance_vertices_batched( 319 | resampled_obj_verts, all_hand_vertices) 320 | 321 | parts_are_touching = torch.zeros(batch_size, len(self.parts_that_can_touch)) 322 | for j, i in enumerate(self.parts_that_can_touch): 323 | val, dims = dists[:, self.hand_in_parts[i]:self.hand_in_parts[i+1]][:, self.parts_that_can_touch_vertices[j]].min(dim=1) 324 | parts_are_touching[:, j] = val < self.threshold_contact 325 | batches_to_eval = torch.where(val < self.threshold_contact)[0] 326 | 
            for batch in batches_to_eval:
327 |                 temp = torch.where(dims[batch] == faces[i])[0]
328 |                 normal = util.get_normal_face_batched(
329 |                     all_hand_vertices[batch], faces[i][temp]+self.hand_in_parts[i]).mean(0) * 1e5
330 |                 normal = normal/normal.norm()
331 |                 all_torques[batch].append(
332 |                     torch.FloatTensor([0, 0, 0]).to(self.device))
333 |                 all_normals[batch].append(normal)
334 |                 all_forces[batch].append(normal)
335 |                 vertices_that_touches[batch].append(
336 |                     all_hand_vertices[batch][self.hand_in_parts[i]+dims[batch]])
337 |         return all_forces, all_torques, all_normals, parts_are_touching, vertices_that_touches
338 | 
339 |     def get_contact_points(self, all_hand_vertices, hand, resampled_obj_verts, debug=False):
340 |         palm_vertices, _, all_finger_vertices, all_finger_tip_vertices = all_hand_vertices
341 |         hand_vertices = [palm_vertices[0].cpu().data.numpy()]
342 |         for base_finger in all_finger_vertices[0]:
343 |             hand_vertices.append(base_finger.cpu().data.numpy())
344 |         for finger_tip in all_finger_tip_vertices[0]:
345 |             hand_vertices.append(finger_tip.cpu().data.numpy())
346 |         hand_faces = self.concatenate_faces(hand)
347 |         forces = []
348 |         torques = []
349 |         normals = []
350 |         # We consider seven parts: the palm, the three finger tips, and the three finger bases
351 |         part_is_touching = torch.zeros(len(self.parts_that_can_touch))
352 | 
353 |         vertices_that_touches = []
354 | 
355 |         for i in range(len(self.parts_that_can_touch)):
356 |             # Get the distance between all vertices on the hand and all sampled points on the object
357 |             dists = util.get_distance_vertices(
358 |                 resampled_obj_verts, hand_vertices[i])[self.parts_that_can_touch_vertices[i]]
359 |             if np.min(dists) < self.threshold_contact:
360 |                 part_is_touching[i] = 1
361 |                 vertices_that_touches.append(
362 |                     hand_vertices[i][np.argmin(dists)])
363 |                 # Get all incident faces of the vertex on the hand that is in contact with the object
364 |                 faces = np.where(np.argmin(dists) == hand_faces[i])[0]
365 |                 normal = []
366 |                 # Calculate the normal of all incident faces
367 |                 for j in range(len(faces)):
368 |                     normal.append(util.get_normal_face(hand_vertices[i][hand_faces[i][faces[j], 0]], hand_vertices[i]
369 |                                                        [hand_faces[i][faces[j], 1]], hand_vertices[i][hand_faces[i][faces[j], 2]]))
370 |                 # The contact normal is the average of all face normals that are incident to the vertex in contact
371 |                 normal = np.mean(normal, 0) * 1e5
372 |                 normal = normal/np.sqrt((np.array(normal)**2).sum())
373 |                 torques.append([0, 0, 0])
374 |                 normals.append(normal)
375 |                 forces.append(normal)
376 | 
377 |         return forces, torques, normals, part_is_touching, vertices_that_touches
378 | 
379 | 
380 | if __name__ == '__main__':
381 |     Test()
382 | 
--------------------------------------------------------------------------------
/models/multi_fingan.py:
--------------------------------------------------------------------------------
1 | import utils.plots as plot_utils
2 | from utils import forward_kinematics_barrett as fk
3 | from barrett_kinematics.barrett_layer.barrett_layer import BarrettLayer
4 | import numpy as np
5 | from networks.networks import NetworksFactory
6 | from .models import BaseModel
7 | from torch.autograd import Variable
8 | from collections import OrderedDict
9 | import torch
10 | from utils import contactutils, util, constrain_hand
11 | 
12 | 
13 | class Model(BaseModel):
14 |     def __init__(self, opt):
15 |         super(Model, self).__init__(opt)
16 |         self.name = 'Multi_FinGan'
17 | 
18 |         self.setup_touching_vertices(opt)
19 | 
20 |         self.device = 
torch.device('cuda:{}'.format( 21 | opt.gpu_ids[0])) if opt.gpu_ids[0] != -1 and torch.cuda.is_available() else torch.device('cpu') 22 | 23 | # create networks 24 | self.create_and_init_networks() 25 | 26 | self.constrain_hand = constrain_hand.ConstrainHand( 27 | self.opt.constrain_method, self.opt.batch_size, self.device) 28 | # init train variables 29 | if self.is_train: 30 | self.create_and_init_optimizer() 31 | 32 | self.i_epoch = 0 33 | if self.opt.load_network: 34 | self.load() 35 | 36 | self.init_losses() 37 | self.taxonomy_poses = np.load( 38 | './data/average_hand_joints_per_taxonomy.npy') 39 | self.taxonomy_tensor = torch.FloatTensor( 40 | self.taxonomy_poses).to(self.device) 41 | self.gradient_accumulation_every = 2 42 | self.gradient_accumulation_current_step = 0 43 | 44 | if self.opt.precomputed_rotations: 45 | all_approach_orientation = np.load( 46 | "files/uniform_rotations.npy", allow_pickle=True) 47 | self.approach_orientation = torch.FloatTensor( 48 | all_approach_orientation[self.opt.num_grasps_to_sample-2]) 49 | self.setup_bookkeeping_variables() 50 | 51 | def setup_bookkeeping_variables(self): 52 | self.delta_T = torch.zeros(1).to(self.device) 53 | self.delta_R = torch.zeros(1).to(self.device) 54 | self.delta_HR = torch.zeros(1).to(self.device) 55 | self.invalid_hand_conf = torch.zeros(1).to(self.device) 56 | self.true_positive = torch.zeros(1).to(self.device) 57 | self.true_negative = torch.zeros(1).to(self.device) 58 | self.false_positive = torch.zeros(1).to(self.device) 59 | self.false_negative = torch.zeros(1).to(self.device) 60 | self.display_hand_gt_pose = None 61 | self.display_hand_gt_rep = None 62 | self.display_mesh_vertices = None 63 | self.display_mesh_faces = None 64 | 65 | def setup_touching_vertices(self, opt): 66 | self.touching_hand_vertices = [769, 802, 809, 815, 912, 915, 67 | 923, 929, 934, 937, 1026, 1029, 1030, 1037, 1043, 1045, 1048] 68 | self.touching_hand_vertices += [18, 129, 205] 69 | 70 | def create_and_init_networks(self): 71 | self.image_encoder_and_grasp_predictor = self.create_image_encoder_and_grasp_predictor() 72 | self.image_encoder_and_grasp_predictor.init_weights() 73 | self.grasp_generator = self.create_grasp_generator() 74 | self.grasp_generator.init_weights() 75 | 76 | if len(self.gpu_ids) > 1: 77 | self.image_encoder_and_grasp_predictor = torch.nn.DataParallel( 78 | self.image_encoder_and_grasp_predictor, device_ids=self.gpu_ids) 79 | self.grasp_generator = torch.nn.DataParallel(self.grasp_generator, 80 | device_ids=self.gpu_ids) 81 | self.image_encoder_and_grasp_predictor.to(self.device) 82 | self.grasp_generator.to(self.device) 83 | 84 | # Initialize Barrett layer 85 | self.barrett_layer = BarrettLayer( 86 | device=self.device).to(self.device) 87 | if not self.opt.no_discriminator: 88 | # Discriminator network 89 | self.discriminator = self.create_discriminator() 90 | self.discriminator.init_weights() 91 | if len(self.gpu_ids) > 1: 92 | self.discriminator = torch.nn.DataParallel( 93 | self.discriminator, device_ids=self.gpu_ids) 94 | self.discriminator.to(self.device) 95 | 96 | def create_image_encoder_and_grasp_predictor(self): 97 | return NetworksFactory.get_by_name( 98 | 'img_encoder_and_grasp_predictor') # The output is 6 or 7 as we consider 6 or 7 grasp taxonomy classes for the Barrett hand 99 | 100 | def create_grasp_generator(self): 101 | return NetworksFactory.get_by_name( 102 | 'grasp_generator', input_dim=3 + 3 + 7) # 3 for rotation 3 for tranlsation and 7 for hand joints 103 | 104 | def create_discriminator(self): 
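# The discriminator operates on a compact 7-D grasp descriptor rather than on images or meshes:
# 3 values for the axis-angle hand rotation, 1 for the finger spread, and 3 for the palm
# translation relative to the object center (see forward_D and forward_G below).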
105 | return NetworksFactory.get_by_name('discriminator',
106 | input_dim=3 + 3 + 1) # 3 for rotation, 3 for translation, and 1 for finger spread
107 | 
108 | def create_and_init_optimizer(self):
109 | self.current_lr_image_encoder_and_grasp_predictor = self.opt.lr_G
110 | # initialize optimizers
111 | if self.opt.optimizer == "Adam":
112 | optimizer = torch.optim.Adam
113 | self.optimizer_image_enc_and_grasp_predictor = optimizer(
114 | self.image_encoder_and_grasp_predictor.parameters(),
115 | lr=self.current_lr_image_encoder_and_grasp_predictor,
116 | betas=[self.opt.G_adam_b1, self.opt.G_adam_b2])
117 | 
118 | self.optimizer_grasp_generator = optimizer(self.grasp_generator.parameters(),
119 | lr=self.current_lr_image_encoder_and_grasp_predictor)
120 | if not self.opt.no_discriminator:
121 | self.current_lr_discriminator = self.opt.lr_D
122 | self.optimizer_discriminator = torch.optim.Adam(
123 | self.discriminator.parameters(),
124 | lr=self.current_lr_discriminator,
125 | betas=[self.opt.D_adam_b1, self.opt.D_adam_b2])
126 | elif self.opt.optimizer == "SGD":
127 | optimizer = torch.optim.SGD
128 | self.optimizer_image_enc_and_grasp_predictor = optimizer(
129 | self.image_encoder_and_grasp_predictor.parameters(),
130 | lr=self.current_lr_image_encoder_and_grasp_predictor,
131 | momentum=0.9)
132 | 
133 | self.optimizer_grasp_generator = optimizer(self.grasp_generator.parameters(),
134 | lr=self.current_lr_image_encoder_and_grasp_predictor, momentum=0.9)
135 | if not self.opt.no_discriminator:
136 | self.current_lr_discriminator = self.opt.lr_D
137 | self.optimizer_discriminator = optimizer(
138 | self.discriminator.parameters(),
139 | lr=self.current_lr_discriminator,
140 | momentum=0.9
141 | )
142 | else:
143 | raise ValueError("Optimizer " + self.opt.optimizer +
144 | " not available.")
145 | 
146 | def set_epoch(self, epoch):
147 | self.i_epoch = epoch
148 | 
149 | def get_epoch(self):
150 | return self.i_epoch
151 | 
152 | def init_losses(self):
153 | self.loss_g_contactloss = Variable(self.Tensor([0]))
154 | self.loss_g_interpenetration = Variable(self.Tensor([0]))
155 | self.loss_g_CE = Variable(self.Tensor([0]))
156 | self.acc_g = Variable(self.Tensor([0]))
157 | self.criterion_CE = torch.nn.BCEWithLogitsLoss().to(self.device)
158 | # if self.opt.rot_loss:
159 | self.canonical_approach_vec = torch.FloatTensor(
160 | [[0, 0, 1]]).to(self.device)
161 | self.loss_g_orientation = Variable(self.Tensor([0]))
162 | if not self.opt.no_discriminator:
163 | self.loss_g_fake = Variable(self.Tensor([0]))
164 | self.loss_d_real = Variable(self.Tensor([0]))
165 | self.loss_d_fake = Variable(self.Tensor([0]))
166 | self.loss_d_fakeminusreal = Variable(self.Tensor([0]))
167 | self.loss_d_gp = Variable(self.Tensor([0]))
168 | 
169 | def set_input(self, data_input):
170 | self.input_rgb_img = data_input['rgb_img'].float().permute(
171 | 0, 3, 1, 2).contiguous()
172 | self.input_object_id = data_input['object_id']
173 | if self.opt.dataset_name == 'ycb':
174 | self.input_taxonomy = data_input['taxonomy'].float()
175 | else:
176 | self.input_taxonomy = data_input['taxonomy']
177 | self.input_obj_verts = data_input['3d_points_object']
178 | self.input_obj_faces = data_input['3d_faces_object']
179 | self.input_obj_resampled_verts = data_input['object_points_resampled']
180 | self.input_hand_gt_rep = data_input['hand_gt_representation'].float()
181 | self.input_hand_gt_pose = data_input['hand_gt_pose'].float()
182 | 
183 | if torch.cuda.is_available():
184 | self.input_rgb_img = self.input_rgb_img.to(self.device)
185 | self.input_taxonomy = self.input_taxonomy.to(self.device) 186 | self.input_hand_gt_rep = self.input_hand_gt_rep.to(self.device) 187 | self.input_hand_gt_pose = self.input_hand_gt_pose.to(self.device) 188 | 189 | self.batch_size = self.input_rgb_img.size(0) 190 | self.calculate_center_of_objects() 191 | 192 | def calculate_center_of_objects(self): 193 | center_objects = [] 194 | for i in range(self.batch_size): 195 | center_objects.append(self.input_obj_resampled_verts[i][ 196 | 0].mean(0)) 197 | 198 | self.input_center_objects = torch.FloatTensor( 199 | center_objects).to(self.device) 200 | 201 | def set_train(self): 202 | self.image_encoder_and_grasp_predictor.train() 203 | self.grasp_generator.train() 204 | if not self.opt.no_discriminator: 205 | self.discriminator.train() 206 | self.zero_losses() 207 | self.zero_bookkeeping() 208 | self.is_train = True 209 | 210 | def zero_bookkeeping(self): 211 | self.true_positive *= 0 212 | self.true_negative *= 0 213 | self.false_positive *= 0 214 | self.false_negative *= 0 215 | self.delta_HR *= 0 216 | self.delta_T *= 0 217 | self.delta_R *= 0 218 | self.invalid_hand_conf *= 0 219 | 220 | def zero_losses(self): 221 | self.loss_g_CE *= 0 222 | self.acc_g *= 0 223 | # if self.opt.intersection_loss: 224 | self.loss_g_interpenetration *= 0 225 | # if self.opt.contact_loss: 226 | self.loss_g_contactloss *= 0 227 | # if self.opt.rot_loss: 228 | self.loss_g_orientation *= 0 229 | if not self.opt.no_discriminator: 230 | self.loss_g_fake *= 0 231 | self.loss_d_real *= 0 232 | self.loss_d_fake *= 0 233 | self.loss_d_fake *= 0 234 | self.loss_d_real *= 0 235 | self.loss_d_gp *= 0 236 | 237 | def set_eval(self): 238 | self.image_encoder_and_grasp_predictor.eval() 239 | self.grasp_generator.eval() 240 | if not self.opt.no_discriminator: 241 | self.discriminator.eval() 242 | # Zero all losses 243 | self.zero_losses() 244 | self.is_train = False 245 | 246 | def create_pose(self, T, R): 247 | pose = torch.eye(4).view(1, 4, 4).to(self.device) 248 | pose = pose.repeat(T.shape[0], 1, 1) 249 | pose[:, :3, :3] = R 250 | pose[:, :3, 3] = T 251 | return pose 252 | 253 | def calculate_interpenetration(self, hand_vertices, hand_vertice_areas, batch_size): 254 | interpenetration = torch.FloatTensor([0]).to(self.device) 255 | # INTERSECTION LOSS ON OPTIMIZED HAND! 
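# For each sample in the batch: collect the triangles of every object mesh, use
# contactutils.batch_mesh_contains_points to flag hand vertices that lie inside the object,
# and penalize those vertices by their distance to the nearest resampled object point,
# weighted by the surface area associated with each hand vertex. Samples without any
# penetrating vertices contribute nothing to the loss.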
256 | for i in range(batch_size): 257 | numobjects = len(self.input_obj_verts[i]) 258 | all_triangles = [] 259 | # all_verts = []_ 260 | for j in range(numobjects): 261 | obj_triangles = self.input_obj_verts[i][j][ 262 | self.input_obj_faces[i][j]] 263 | obj_triangles = torch.FloatTensor( 264 | obj_triangles).to(self.device) 265 | all_triangles.append(obj_triangles) 266 | 267 | all_triangles = torch.cat(all_triangles) 268 | 269 | exterior = contactutils.batch_mesh_contains_points( 270 | hand_vertices[i].unsqueeze(0), all_triangles.unsqueeze(0), device=self.device) 271 | penetr_mask = ~exterior 272 | if penetr_mask.sum() == 0: 273 | continue 274 | 275 | allpoints_resampled = torch.FloatTensor( 276 | self.input_obj_resampled_verts[i]).to(self.device).reshape( 277 | -1, 3).unsqueeze(0) 278 | dists = util.batch_pairwise_dist( 279 | hand_vertices[i, penetr_mask[0]].unsqueeze(0), allpoints_resampled) 280 | mins21, _ = torch.min(dists, 2) 281 | mins21[mins21 < 1e-4] = 0 282 | mins21 = mins21*hand_vertice_areas[i, penetr_mask[0]] 283 | 284 | interpenetration = interpenetration + mins21.mean() 285 | 286 | return interpenetration 287 | 288 | def calculate_contact(self, points_3d): 289 | relevantobjs_resampled = torch.FloatTensor([ 290 | self.input_obj_resampled_verts[i][0] 291 | for i in range(self.batch_size) 292 | ]).to(self.device) 293 | 294 | distance_touching_vertices_fake = self.get_touching_distances( 295 | points_3d, relevantobjs_resampled) 296 | distance_touching_vertices_fake[distance_touching_vertices_fake < 297 | self.opt.collision_loss_threshold] = 0 298 | 299 | return distance_touching_vertices_fake 300 | 301 | def get_touching_distances(self, hand_points, object_points): 302 | 303 | relevant_vertices = hand_points[:, self.touching_hand_vertices] 304 | n1 = len(self.touching_hand_vertices) 305 | n2 = len(object_points[0]) 306 | 307 | matrix1 = relevant_vertices.unsqueeze(1).repeat(1, n2, 1, 1) 308 | matrix2 = object_points.unsqueeze(2).repeat(1, 1, n1, 1) 309 | 310 | dists = torch.norm(matrix1 - matrix2, dim=-1) 311 | dists = dists.min(1)[0] 312 | return dists 313 | 314 | def get_distances_single_example(self, relevant_vertices, object_points): 315 | n1 = len(relevant_vertices) 316 | # TODO: HAVE TO DO IT IN A LOOP SINCE OBJECTS ALL HAVE DIFFERENT AMOUNT OF VERTICES 317 | n2 = len(object_points) 318 | 319 | matrix1 = relevant_vertices.unsqueeze(0).repeat(n2, 1, 1) 320 | if torch.cuda.is_available(): 321 | matrix2 = torch.FloatTensor(object_points).to(self.device).unsqueeze( 322 | 1).repeat(1, n1, 1) 323 | else: 324 | matrix2 = torch.FloatTensor(object_points).unsqueeze(1).repeat( 325 | 1, n1, 1) 326 | dists = torch.sqrt(((matrix1 - matrix2)**2).sum(-1)) 327 | return dists.min(0)[0] 328 | 329 | def optimize_parameters(self, 330 | train_generator=True, 331 | ): 332 | if self.is_train: 333 | # convert tensor to variables 334 | self.batch_size = self.input_rgb_img.size(0) 335 | 336 | # train discriminator_ 337 | if not self.opt.no_discriminator: 338 | fake_input_D, real_input_D, loss_D = self.forward_D() 339 | loss_D_gp = self.gradient_penalty_D(fake_input_D, real_input_D) 340 | 341 | self.optimizer_discriminator.zero_grad() 342 | loss = loss_D + loss_D_gp 343 | loss.backward(retain_graph=True) 344 | self.optimizer_discriminator.step() 345 | 346 | # train generator 347 | if train_generator: 348 | self.forward_G(True) 349 | loss_G = self.combine_generator_losses() 350 | loss_G.backward() 351 | self.gradient_accumulation_current_step += 1 352 | 353 | if 
self.gradient_accumulation_current_step % self.gradient_accumulation_every == 0: 354 | self.optimizer_image_enc_and_grasp_predictor.step() 355 | self.optimizer_grasp_generator.step() 356 | self.optimizer_image_enc_and_grasp_predictor.zero_grad() 357 | self.optimizer_grasp_generator.zero_grad() 358 | self.gradient_accumulation_current_step = 0 359 | 360 | def evaluate_prediction(self, prediction): 361 | self.loss_g_CE = self.criterion_CE( 362 | prediction, self.input_taxonomy)*self.opt.lambda_G_classification 363 | thresholded_predictions = (torch.sigmoid(prediction) > 0.5).long() 364 | self.acc_g = (thresholded_predictions.cpu().data.numpy( 365 | ) == self.input_taxonomy.cpu().data.numpy()).mean() 366 | tp, tn, fp, fn = util.calculate_classification_statistics( 367 | self.input_taxonomy, thresholded_predictions) 368 | self.true_positive += tp 369 | self.true_negative += tn 370 | self.false_positive += fp 371 | self.false_negative += fn 372 | 373 | def finger_refinement(self, HR, rot_matrix, T, batched=False): 374 | if batched: 375 | fk.optimize_fingers_batched( 376 | HR, rot_matrix, 377 | T, self.input_obj_resampled_verts[0][ 378 | 0], self.barrett_layer, self.opt.object_finger_contact_threshold, optimize_finger_tip=self.opt.optimize_fingertip, device=self.device) 379 | else: 380 | for i in range(self.batch_size): 381 | _ = fk.optimize_fingers( 382 | HR[i].view(1, -1), rot_matrix[i], 383 | T[i], self.input_obj_resampled_verts[i][ 384 | 0], self.barrett_layer, self.opt.object_finger_contact_threshold, optimize_finger_tip=self.opt.optimize_fingertip, device=self.device) 385 | 386 | def forward_G(self, train=True): 387 | input_img = self.input_rgb_img 388 | prediction, img_representations = self.image_encoder_and_grasp_predictor.forward( 389 | input_img) 390 | if train: 391 | self.evaluate_prediction(prediction) 392 | hand_representations = self.hand_representation_from_prediction( 393 | prediction) 394 | 395 | if self.opt.precomputed_rotations: 396 | # If we use precomputed rotations it means that we are doing the simulation experiments 397 | # and for that we generate self.opt.num_grasps_to_sample grasps but only on one object at a time. 398 | # Therefore, we need to repeat the translation_in, img_representation, and hand_representation 399 | # accordingly 400 | random_rots = torch.FloatTensor(self.approach_orientation).to( 401 | self.device) 402 | translation_in = self.input_center_objects.repeat( 403 | self.opt.num_grasps_to_sample, 1) 404 | img_representations = img_representations.repeat( 405 | self.opt.num_grasps_to_sample, 1) 406 | hand_representations = hand_representations.repeat( 407 | self.opt.num_grasps_to_sample, 1) 408 | else: 409 | axis_angles = util.rotm_to_axis_angles( 410 | self.input_hand_gt_pose[:, :3, :3], self.device) 411 | random_rots = axis_angles + torch.normal( 412 | mean=torch.zeros(self.batch_size, 3)).to(self.device) / 5 413 | translation_in = self.input_center_objects 414 | 415 | hand_configuration, R, T = self.grasp_generator.forward(img_representations, hand_representations, 416 | random_rots, translation_in) 417 | rot_matrix = util.axis_angles_to_rotation_matrix(R) 418 | 419 | hand_configuration, hand_configuration_unconstrained = self.constrain_hand( 420 | hand_configuration) 421 | # Batched hand refinement is only possible when not training, i.e. when we do the simulation experiments in https://arxiv.org/pdf/2012.09696.pdf. 
422 | # The reason for this is because in training mode we operate with much higher batch sizes and the batched hand-refinement layer cannot handle large 423 | # batch sizes 424 | self.finger_refinement( 425 | hand_configuration, rot_matrix, T, batched=not train) 426 | 427 | self.display_hand_gt_pose = self.input_hand_gt_pose 428 | self.display_hand_gt_rep = self.input_hand_gt_rep 429 | self.display_mesh_vertices = self.input_obj_verts 430 | self.display_mesh_faces = self.input_obj_faces 431 | pose = self.create_pose(T, rot_matrix) 432 | 433 | * points_3d, _ = self.barrett_layer(pose, hand_configuration) 434 | if train: 435 | self.delta_T = util.euclidean_distance( 436 | translation_in, T).mean() 437 | 438 | self.invalid_hand_conf = util.valid_hand_conf(util.rad2deg( 439 | hand_configuration), self.device).squeeze() 440 | self.delta_R = util.axis_angle_distance(random_rots, R).mean() 441 | self.delta_HR = (hand_representations - 442 | hand_configuration).abs().mean() 443 | 444 | points_3d, points_3d_area = util.concatenate_barret_vertices_and_vertice_areas( 445 | points_3d, self.barrett_layer.vertice_face_areas, use_torch=True) 446 | self.refined_handpose = points_3d.cpu().data.numpy() 447 | points_3d_area = points_3d_area.to(self.device) 448 | self.calculate_generator_losses( 449 | points_3d, points_3d_area, R, T, translation_in, hand_configuration_unconstrained) 450 | else: 451 | points_3d_concatenated, faces_concatenated = util.concatenate_barret_vertices_and_faces( 452 | points_3d, self.barrett_layer.gripper_faces, use_torch=True) 453 | 454 | self.refined_handpose_concatenated = points_3d_concatenated 455 | self.refined_handpose = points_3d 456 | self.refined_handpose_faces = faces_concatenated 457 | self.refined_HR = hand_configuration 458 | self.refined_R = R 459 | self.refined_T = T 460 | 461 | def calculate_generator_losses(self, points_3d, points_3d_area, R, T, translation_in, hand_configuration_unconstrained): 462 | interpenetration = torch.FloatTensor([0]).to(self.device) 463 | # INTERSECTION LOSS ON OPTIMIZED HAND 464 | interpenetration = self.calculate_interpenetration( 465 | points_3d, points_3d_area, self.batch_size) 466 | self.loss_g_interpenetration = interpenetration / \ 467 | self.batch_size * self.opt.lambda_G_intersections 468 | # CONTACT LOSS ON OPTIMIZED HAND 469 | distance_touching_vertices_fake = self.calculate_contact(points_3d) 470 | self.loss_g_contactloss = distance_touching_vertices_fake.mean( 471 | ) * self.opt.lambda_G_contactloss 472 | # ORIENTATION LOSS 473 | self.loss_g_orientation = self.opt.lambda_G_orientation * \ 474 | self.calculate_rot_loss( 475 | R, T, self.input_center_objects).mean() 476 | if not self.opt.no_discriminator: 477 | fake_input_D = torch.cat( 478 | ( 479 | R, 480 | hand_configuration_unconstrained[:, 0].unsqueeze(1), 481 | T - translation_in), 482 | 1) 483 | d_fake_prob = self.discriminator(fake_input_D) 484 | self.loss_g_fake = self.compute_loss_D( 485 | d_fake_prob, True) * self.opt.lambda_D_prob 486 | 487 | def combine_generator_losses(self): 488 | combined_losses = 0 489 | if not self.opt.no_classification_loss: 490 | combined_losses += self.loss_g_CE 491 | if not self.opt.no_intersection_loss: 492 | combined_losses += self.loss_g_interpenetration[0] 493 | if not self.opt.no_contact_loss: 494 | combined_losses += self.loss_g_contactloss 495 | if not self.opt.no_orientation_loss: 496 | combined_losses += self.loss_g_orientation 497 | if not self.opt.no_discriminator: 498 | combined_losses += self.loss_g_fake 499 | return 
combined_losses 500 | 501 | def calculate_rot_loss(self, rot, T, object_center): 502 | hand_to_obj_vec = object_center-T 503 | hand_to_obj_vec_normed = hand_to_obj_vec / \ 504 | torch.norm(hand_to_obj_vec, dim=-1, keepdim=True) 505 | hand_approach_vector = util.axis_angle_rot( 506 | rot, self.canonical_approach_vec) 507 | 508 | rot_loss = 1-(hand_to_obj_vec_normed*hand_approach_vector).sum(-1) 509 | return rot_loss 510 | 511 | def is_nan(self, tensor, type): 512 | if torch.any(torch.isnan(tensor)): 513 | print(type + " is nan") 514 | 515 | def forward_D(self): 516 | input_img = self.input_rgb_img 517 | axis_angles = util.rotm_to_axis_angles( 518 | self.input_hand_gt_pose[:, :3, :3], self.device) 519 | input_hand_gt_rot = axis_angles 520 | random_rots = axis_angles + torch.normal( 521 | mean=torch.zeros(self.batch_size, 3)).to(self.device) / 5 522 | 523 | prediction, img_representations = self.image_encoder_and_grasp_predictor.forward( 524 | input_img) 525 | self.evaluate_prediction(prediction) 526 | hand_representations = self.hand_representation_from_prediction( 527 | prediction) 528 | 529 | translation_in = self.input_center_objects 530 | 531 | HR, R, T = self.grasp_generator.forward(img_representations, hand_representations, 532 | random_rots, translation_in) 533 | 534 | fake_input_D = torch.cat( 535 | (R, HR[:, 0].unsqueeze(1), T - translation_in), 1).detach() 536 | real_input_D = torch.cat( 537 | (input_hand_gt_rot, self.input_hand_gt_rep[:, 0].unsqueeze(1), 538 | self.input_hand_gt_pose[:, :3, 3] - translation_in), 1).detach() 539 | 540 | d_fake_prob = self.discriminator(fake_input_D) 541 | d_real_prob = self.discriminator(real_input_D) 542 | 543 | self.loss_d_real = self.compute_loss_D( 544 | d_real_prob, True) * self.opt.lambda_D_prob 545 | self.loss_d_fake = self.compute_loss_D( 546 | d_fake_prob, False) * self.opt.lambda_D_prob 547 | 548 | return fake_input_D, real_input_D, self.loss_d_real + self.loss_d_fake 549 | 550 | def hand_representation_from_prediction(self, prediction): 551 | prediction_probs = torch.sigmoid( 552 | prediction) > 0.5 553 | idx = prediction_probs.nonzero() 554 | hand_representations = torch.zeros((self.batch_size, 7)) 555 | rand_nonzero_cols = torch.randint( 556 | self.taxonomy_tensor.shape[0], (self.batch_size,)) 557 | for batch_idx in torch.unique(idx[:, 0]): 558 | nonzero_columns = prediction_probs[batch_idx].nonzero() 559 | num_nonzero_cols = len(nonzero_columns) 560 | rand_nonzero_cols[batch_idx] = nonzero_columns[torch.randint( 561 | num_nonzero_cols, (1,))][0] 562 | hand_representations = self.taxonomy_tensor[rand_nonzero_cols] 563 | return hand_representations 564 | 565 | def gradient_penalty_D(self, fake_input_D, real_input_D): 566 | # interpolate sample 567 | alpha = torch.rand( 568 | self.batch_size, fake_input_D.shape[1]).to(self.device) 569 | alpha.requires_grad = True 570 | interpolated = alpha * real_input_D + (1 - alpha) * fake_input_D 571 | interpolated_prob = self.discriminator(interpolated) 572 | 573 | # compute gradients 574 | grad = torch.autograd.grad(outputs=interpolated_prob, 575 | inputs=interpolated, 576 | grad_outputs=torch.ones( 577 | interpolated_prob.size()).to(self.device), 578 | retain_graph=True, 579 | create_graph=True, 580 | only_inputs=True)[0] 581 | 582 | # penalize gradients 583 | grad = grad.view(grad.size(0), -1) 584 | grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1)) 585 | self.loss_d_gp = torch.mean( 586 | (grad_l2norm - 1)**2) * self.opt.lambda_D_gp 587 | 588 | return self.loss_d_gp 589 | 590 | def 
compute_loss_D(self, estim, is_real): 591 | return -torch.mean(estim) if is_real else torch.mean(estim) 592 | 593 | def get_current_errors(self): 594 | losses = [] 595 | losses.append( 596 | ('g CE', self.loss_g_CE.cpu().data.numpy())) 597 | losses.append( 598 | ('g acc', self.acc_g)) 599 | 600 | losses.append( 601 | ('g contact loss', self.loss_g_contactloss.cpu().data.numpy())) 602 | losses.append( 603 | ('g intersection loss', self.loss_g_interpenetration.cpu().data.numpy())) 604 | losses.append( 605 | ('g orientation loss', self.loss_g_orientation .cpu().data.numpy())) 606 | if not self.opt.no_discriminator: 607 | losses.append(('g fake', self.loss_g_fake.cpu().data.numpy())) 608 | losses.append(('d real', self.loss_d_real.cpu().data.numpy())) 609 | losses.append(('d fake', self.loss_d_fake.cpu().data.numpy())) 610 | losses.append(('d fakeminusreal', self.loss_d_fake.cpu().data.numpy() - 611 | self.loss_d_real.cpu().data.numpy())) 612 | losses.append(('d gp', self.loss_d_gp.cpu().data.numpy())) 613 | return OrderedDict(losses) 614 | 615 | def get_current_errors_G(self): 616 | losses = [] 617 | losses.append( 618 | ('g CE', self.loss_g_CE.cpu().data.numpy())) 619 | losses.append( 620 | ('g acc', self.acc_g)) 621 | losses.append( 622 | ('g contact loss', self.loss_g_contactloss.cpu().data.numpy())) 623 | losses.append( 624 | ('g intersection loss', self.loss_g_interpenetration.cpu().data.numpy())) 625 | losses.append( 626 | ('g orientation loss', self.loss_g_orientation .cpu().data.numpy())) 627 | if not self.opt.no_discriminator: 628 | losses.append(('g fake', self.loss_g_fake.cpu().data.numpy())) 629 | return OrderedDict(losses) 630 | 631 | def get_current_errors_D(self): 632 | losses = [] 633 | if not self.opt.no_discriminator: 634 | losses.append(('d real', self.loss_d_real.cpu().data.numpy())) 635 | losses.append(('d fake', self.loss_d_fake.cpu().data.numpy())) 636 | losses.append(('d fakeminusreal', self.loss_d_fake.cpu().data.numpy() - 637 | self.loss_d_real.cpu().data.numpy())) 638 | losses.append(('d gp', self.loss_d_gp.cpu().data.numpy())) 639 | return OrderedDict(losses) 640 | 641 | def get_current_scalars(self): 642 | scalars = [ 643 | ('mean euclidean dist', self.delta_T.cpu().data.numpy()), 644 | ('mean rotation dist', self.delta_R.cpu().data.numpy()), 645 | ('mean finger spread angle', self.delta_HR.cpu().data.numpy()), 646 | ('lr G', self.current_lr_image_encoder_and_grasp_predictor) 647 | ] 648 | if not self.opt.no_discriminator: 649 | scalars.append(('lr D', self.current_lr_discriminator)) 650 | scalars.append( 651 | ('f1 score', self.f1_score())) 652 | scalars.append( 653 | ('number of invalid configurations', self.invalid_hand_conf.cpu().data.numpy().astype('float32'))) 654 | 655 | return OrderedDict(scalars) 656 | 657 | def get_current_visuals(self): 658 | # visuals return dictionary 659 | visuals = {} 660 | groundtruths = [] 661 | predictions = [] 662 | if self.display_hand_gt_rep is not None: 663 | rand_grasp_idx = np.random.randint( 664 | len(self.display_hand_gt_rep)) 665 | else: 666 | rand_grasp_idx = 0 667 | 668 | if self.display_hand_gt_rep is not None: 669 | * hand_verts, _ = self.barrett_layer( 670 | self.display_hand_gt_pose[rand_grasp_idx].view(1, 4, 4), self.display_hand_gt_rep[rand_grasp_idx].view(1, 7)) 671 | gt_verts, gt_faces = util.concatenate_barret_vertices_and_faces( 672 | hand_verts, self.barrett_layer.gripper_faces) 673 | groundtruths = plot_utils.plot_scene_w_grasps( 674 | self.display_mesh_vertices[rand_grasp_idx], 
self.display_mesh_faces[rand_grasp_idx], gt_verts,
675 | gt_faces)
676 | try:
677 | predictions = plot_utils.plot_scene_w_grasps(
678 | self.display_mesh_vertices[rand_grasp_idx], self.display_mesh_faces[rand_grasp_idx],
679 | [self.refined_handpose[rand_grasp_idx]], gt_faces)
680 | except Exception:
681 | pass
682 | 
683 | visuals['1_groundtruth'] = groundtruths
684 | visuals['2_prediction'] = predictions
685 | return OrderedDict(visuals)
686 | 
687 | def save(self, label, epoch):
688 | # save networks and optimizers
689 | self.save_network(self.image_encoder_and_grasp_predictor,
690 | 'image_encoder_and_grasp_predictor', label, epoch)
691 | self.save_optimizer(self.optimizer_image_enc_and_grasp_predictor,
692 | 'image_encoder_and_grasp_predictor', label)
693 | self.save_network(self.grasp_generator,
694 | 'grasp_generator', label, epoch)
695 | self.save_optimizer(self.optimizer_grasp_generator,
696 | 'grasp_generator', label)
697 | if not self.opt.no_discriminator:
698 | self.save_network(self.discriminator,
699 | 'discriminator', label, epoch)
700 | self.save_optimizer(self.optimizer_discriminator,
701 | 'discriminator', label)
702 | 
703 | def load(self):
704 | load_epoch = self.opt.load_epoch
705 | 
706 | # load networks
707 | self.load_network(self.image_encoder_and_grasp_predictor,
708 | 'image_encoder_and_grasp_predictor', load_epoch, self.device)
709 | self.load_network(self.grasp_generator,
710 | 'grasp_generator', load_epoch, self.device)
711 | if not self.opt.no_discriminator:
712 | self.load_network(self.discriminator,
713 | 'discriminator', load_epoch, self.device)
714 | 
715 | if self.is_train:
716 | # load optimizers
717 | self.load_optimizer(self.optimizer_image_enc_and_grasp_predictor, 'image_encoder_and_grasp_predictor',
718 | load_epoch, self.device)
719 | self.load_optimizer(self.optimizer_grasp_generator, 'grasp_generator',
720 | load_epoch, self.device)
721 | if not self.opt.no_discriminator:
722 | self.load_optimizer(self.optimizer_discriminator, 'discriminator',
723 | load_epoch, self.device)
724 | self.set_learning_rate()
725 | 
726 | def update_learning_rate(self):
727 | # linearly decay the image_encoder_and_grasp_predictor and grasp_generator learning rate
728 | lr_decay_G = self.opt.lr_G / self.opt.nepochs_decay
729 | self.current_lr_image_encoder_and_grasp_predictor -= lr_decay_G
730 | if not self.opt.no_discriminator:
731 | lr_decay_D = self.opt.lr_D / self.opt.nepochs_decay
732 | self.current_lr_discriminator -= lr_decay_D
733 | self.set_learning_rate(plot=False)
734 | print('update image_encoder_and_grasp_predictor and grasp_generator learning rate: %f -> %f' %
735 | (self.current_lr_image_encoder_and_grasp_predictor + lr_decay_G, self.current_lr_image_encoder_and_grasp_predictor))
736 | if not self.opt.no_discriminator:
737 | print('update discriminator learning rate: %f -> %f' %
738 | (self.current_lr_discriminator + lr_decay_D, self.current_lr_discriminator))
739 | 
740 | def set_learning_rate(self, plot=True):
741 | for param_group in self.optimizer_image_enc_and_grasp_predictor.param_groups:
742 | param_group['lr'] = self.current_lr_image_encoder_and_grasp_predictor
743 | for param_group in self.optimizer_grasp_generator.param_groups:
744 | param_group['lr'] = self.current_lr_image_encoder_and_grasp_predictor
745 | if not self.opt.no_discriminator:
746 | for param_group in self.optimizer_discriminator.param_groups:
747 | param_group['lr'] = self.current_lr_discriminator
748 | if plot:
749 | print('set image_encoder_and_grasp_predictor and grasp_generator learning rate to: %f' %
750 |
(self.current_lr_image_encoder_and_grasp_predictor)) 751 | if not self.opt.no_discriminator: 752 | print('set discriminator learning rate to: %f' % 753 | (self.current_lr_discriminator)) 754 | 755 | def f1_score(self): 756 | f1_score = util.f1_score( 757 | self.true_positive, self.true_negative, self.false_positive, self.false_negative) 758 | return f1_score.cpu().numpy() 759 | --------------------------------------------------------------------------------