├── arm_models ├── __init__.py ├── utils │ ├── augmentation.py │ ├── representation_extractor.py │ └── load_data.py ├── imitation │ ├── networks.py │ └── non_parametric.py ├── dataloaders │ ├── state_dataset.py │ └── visual_dataset.py ├── deploy │ ├── evaluation │ │ └── performance_test_inn.py │ ├── model_scripts │ │ ├── visual_inn.py │ │ └── inn.py │ ├── deploy_scripts │ │ ├── mlp_deploy.py │ │ ├── bc_deploy.py │ │ ├── vinn_deploy.py │ │ └── inn_deploy.py │ └── deploy.py └── train.py ├── .gitignore ├── setup.py └── README.md /arm_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | arm_models/data/* 3 | data/* 4 | 5 | # Python cache files 6 | Imitation_NN/__pycache__/* 7 | utils/__pycache__/* 8 | 9 | # Losses and Graphs 10 | arm_models/deploy/evaluation/losses/* 11 | arm_models/deploy/evaluation/graphs/* 12 | 13 | # Model checkpoints 14 | arm_models/deploy/checkpoints/* 15 | 16 | # Wandb logs 17 | arm_models/wandb/* 18 | 19 | # Other installation files 20 | dexterous_arm_models.egg-info/* 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from setuptools import setup, find_packages 4 | 5 | print("Installing the Dexterous Arm Hardware models package!") 6 | 7 | if sys.version_info.major != 3: 8 | print("This Python is only compatible with Python 3, but you are running " 9 | "Python {}. The installation will likely fail.".format(sys.version_info.major)) 10 | 11 | def read(fname): 12 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 13 | 14 | setup( 15 | name='dexterous_arm_models', 16 | version='1.0.0', 17 | packages=find_packages(), 18 | description='Models that can be deployed on the dexterous arm.', 19 | long_description=read('README.md'), 20 | url='https://github.com/NYU-robot-learning/Dexterous-Arm-Models', 21 | author='Sridhar Pandian', 22 | ) -------------------------------------------------------------------------------- /arm_models/utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms as T 3 | 4 | def augmentation_generator(task): 5 | if task == "rotate": 6 | norm_values = T.Normalize( 7 | mean = torch.tensor([0.3484, 0.3638, 0.3819]), 8 | std = torch.tensor([0.3224, 0.3151, 0.3166]) 9 | ) 10 | elif task == "flip": 11 | norm_values = T.Normalize( 12 | mean = torch.tensor([0, 0, 0]), 13 | std = torch.tensor([1, 1, 1]) 14 | ) 15 | elif task == "spin": 16 | norm_values = T.Normalize( 17 | mean = torch.tensor([0, 0, 0]), 18 | std = torch.tensor([1, 1, 1]) 19 | ) 20 | 21 | augment_custom = T.Compose([ 22 | T.RandomResizedCrop(224, scale = (0.6, 1)), 23 | T.RandomApply(torch.nn.ModuleList([T.ColorJitter(.8,.8,.8,.2)]), p=.3), 24 | T.RandomGrayscale(p = 0.2), 25 | T.RandomApply(torch.nn.ModuleList([T.GaussianBlur((3, 3), (1.0, 2.0))]), p=0.2), 26 | norm_values 27 | ]) 28 | 29 | return augment_custom -------------------------------------------------------------------------------- /arm_models/imitation/networks.py: -------------------------------------------------------------------------------- 1 | # General torch imports 2 | import torch 3 | import 
torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # Imports for pretrained encoders 7 | from torchvision import models 8 | 9 | # Import for Representation learners 10 | from byol_pytorch import BYOL 11 | 12 | # State based Behavior Cloning 13 | class MLP(nn.Module): 14 | def __init__(self): 15 | super(MLP, self).__init__() 16 | self.fc_1 = nn.Linear(15, 128) 17 | self.fc_2 = nn.Linear(128, 512) 18 | self.fc_3 = nn.Linear(512, 512) 19 | self.fc_4 = nn.Linear(512, 128) 20 | self.fc_5 = nn.Linear(128, 12) 21 | self.batch_norm = nn.BatchNorm1d(512) 22 | 23 | def forward(self, x): 24 | x = x.view(-1, 15) 25 | x = F.leaky_relu(self.fc_1(x)) 26 | x = F.leaky_relu(self.fc_2(x)) 27 | x = self.fc_3(x) 28 | x = F.leaky_relu(self.batch_norm(x)) 29 | x = F.leaky_relu(self.fc_4(x)) 30 | x = self.fc_5(x) 31 | return x 32 | 33 | # Visual Behavior Cloning 34 | class BehaviorCloning(nn.Module): 35 | def __init__(self): 36 | super(BehaviorCloning, self).__init__() 37 | # Encoder 38 | self.encoder = models.resnet50(pretrained = True) 39 | # Fully Connected Regression Layer 40 | self.fc1 = nn.Linear(1000, 1000) 41 | self.fc2 = nn.Linear(1000, 1024) 42 | self.fc3 = nn.Linear(1024, 12) 43 | 44 | def forward(self, x): 45 | x = self.encoder(x) 46 | x = self.fc1(x) 47 | x = F.leaky_relu(x) 48 | x = self.fc2(x) 49 | x = F.leaky_relu(x) 50 | x = self.fc3(x) 51 | return x 52 | -------------------------------------------------------------------------------- /arm_models/dataloaders/state_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | from arm_models.utils.load_data import load_state_actions 6 | 7 | class CubeRotationDataset(Dataset): 8 | def __init__(self, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/cube_rotation/complete'): 9 | self.data_path = data_path 10 | self.states, self.actions = load_state_actions(data_path) 11 | 12 | def __len__(self): 13 | return len(self.states) 14 | 15 | def __getitem__(self, idx): 16 | if torch.is_tensor(idx): 17 | idx = idx.tolist() 18 | 19 | return self.states[idx], self.actions[idx] 20 | 21 | class ObjectFlippingDataset(Dataset): 22 | def __init__(self, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/object_flipping/complete'): 23 | self.data_path = data_path 24 | self.states, self.actions = load_state_actions(data_path) 25 | 26 | def __len__(self): 27 | return len(self.states) 28 | 29 | def __getitem__(self, idx): 30 | if torch.is_tensor(idx): 31 | idx = idx.tolist() 32 | 33 | return self.states[idx], self.actions[idx] 34 | 35 | class FidgetSpinningDataset(Dataset): 36 | def __init__(self, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning/complete'): 37 | self.data_path = data_path 38 | self.states, self.actions = load_state_actions(data_path) 39 | 40 | def __len__(self): 41 | return len(self.states) 42 | 43 | def __getitem__(self, idx): 44 | if torch.is_tensor(idx): 45 | idx = idx.tolist() 46 | 47 | return self.states[idx], self.actions[idx] 48 | -------------------------------------------------------------------------------- /arm_models/utils/representation_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import models 4 | 5 | from tqdm import tqdm 6 | 7 | from arm_models.dataloaders.visual_dataset import * 8 | 9 | from byol_pytorch import BYOL 10 | 11 | 
REP_TENSOR_PATH = '/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning/complete/representations/representations.pth' 12 | REP_MODEL_CHKPT_PATH = '/home/sridhar/dexterous_arm/models/arm_models/deploy/checkpoints/representation_byol - spin - lowest - train - v1.pth' 13 | 14 | def extract_representations(device, CHKPT_PATH): 15 | # Loading the dataset and creating the dataloader 16 | # dataset = CubeRotationVisualDataset() 17 | # dataset = ObjectFlippingVisualDataset() 18 | dataset = FidgetSpinningVisualDataset() 19 | 20 | dataloader = DataLoader(dataset, batch_size = 1, shuffle = False, pin_memory = True, num_workers = 24) 21 | 22 | # Loading the model 23 | original_encoder_model = models.resnet50(pretrained = True) 24 | encoder = torch.nn.Sequential(*(list(original_encoder_model.children())[:-1])) 25 | encoder = encoder.to(device) 26 | 27 | learner = BYOL ( 28 | encoder, 29 | image_size = 224 30 | ) 31 | learner.load_state_dict(torch.load(CHKPT_PATH)) 32 | learner.eval() 33 | 34 | representations = [] 35 | 36 | # Extracting the representations 37 | for image, action in tqdm(dataloader): 38 | representation = learner.net(image.float().to(device)).squeeze() 39 | representations.append(representation.detach().cpu()) 40 | 41 | representation_tensor = torch.stack(representations) 42 | print("Final representation tensor shape:", representation_tensor.squeeze().shape) 43 | 44 | return representation_tensor.squeeze(1) 45 | 46 | def store_representations(representations, DATA_PATH): 47 | torch.save(representations, DATA_PATH) 48 | 49 | if __name__ == '__main__': 50 | # Selecting the GPU to be used 51 | gpu_number = int(input('Enter the GPU number: ')) 52 | device = torch.device('cuda:{}'.format(gpu_number)) 53 | 54 | print('Using GPU: {} for extracting the representations. 
\n'.format(torch.cuda.get_device_name(gpu_number))) 55 | 56 | representation_tensor = extract_representations(device, CHKPT_PATH = REP_MODEL_CHKPT_PATH) 57 | store_representations(representation_tensor, REP_TENSOR_PATH) -------------------------------------------------------------------------------- /arm_models/deploy/evaluation/performance_test_inn.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import torch 3 | import pickle 4 | 5 | import matplotlib.pyplot as plt 6 | 7 | from IPython import embed 8 | 9 | from arm_models.imitation.non_parametric import INN 10 | from arm_models.utils.load_data import * 11 | 12 | DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/cube_rotation/for_eval" 13 | 14 | PERFORMANCE_METRICS_PATH = "/home/sridhar/dexterous_arm/models/arm_models/deploy/evaluation/losses/INN_losses" 15 | 16 | def performance_test_inn(inn, k, test_states, test_actions): 17 | loss = 0 18 | 19 | # Testing the model 20 | for idx, test_state in enumerate(test_states): 21 | if k == 1: 22 | obtained_action, neighbor_index, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, cube_l2_diff = inn.find_optimum_action(test_state, k) 23 | else: 24 | obtained_action = inn.find_optimum_action(test_state, k) 25 | 26 | action_loss = torch.norm(obtained_action - test_actions[idx]) 27 | 28 | loss += action_loss 29 | 30 | normalized_loss = loss / len(test_states) 31 | 32 | print("The testing loss is {}\n".format(normalized_loss)) 33 | 34 | return normalized_loss 35 | 36 | if __name__ == "__main__": 37 | k_losses = [] 38 | 39 | # Getting the state and action pairs 40 | train_states, train_actions, test_states, test_actions = load_train_test_data(DATA_PATH) 41 | 42 | # Checking if the state action pairs are valid 43 | assert len(test_states) == len(test_actions), "The number of states is not equal to the number of actions!" 
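# Reporting the sizes of the loaded train and test splits before sweeping the priority weights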
44 | print("Total number of train state and action pairs: {}".format(len(train_states))) 45 | print("Total number of test state and action pairs: {}".format(len(test_states))) 46 | 47 | print("\nTesting the model for object priority values from 1 to 50 and finger-tip priority values from 1 to 10.\n") 48 | 49 | plt.figure() 50 | 51 | # Computing the loss for 20 different k-values 52 | for finger_priority_idx in range(10): 53 | finger_losses = [] 54 | 55 | for priority_idx in range(50): 56 | inn = INN(device="cuda", target_priority = priority_idx + 1, finger_priority = finger_priority_idx + 1) 57 | 58 | # Loading data into the model 59 | inn.get_data(train_states, train_actions) 60 | 61 | print("Computing loss for priority = {}".format(priority_idx + 1)) 62 | computed_loss = performance_test_inn(inn, 1, test_states, test_actions) 63 | 64 | finger_losses.append(computed_loss) 65 | 66 | print("The minimum observed loss was: {} for object priority: {}\n".format(min(finger_losses), finger_losses.index(min(finger_losses)) + 1)) 67 | plt.plot([x+1 for x in range(50)], finger_losses, label="finger_priority = {}".format(finger_priority_idx + 1)) 68 | plt.legend() 69 | 70 | 71 | 72 | # Saving the losses in a pickle file 73 | # loss_file = open(PERFORMANCE_METRICS_PATH, "ab") 74 | # pickle.dump(k_losses, loss_file) 75 | # loss_file.close() 76 | 77 | # Plotting the k-value based losses 78 | plt.xlabel("Object prioritized weights") 79 | 80 | plt.ylabel("Test loss") 81 | plt.show() 82 | # embed() 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dexterous Imitation Made Easy 2 | **Authors**: Sridhar Pandian Arunachalam*, [Sneha Silwal](http://ssilwal.com/)*, [Ben Evans](https://bennevans.github.io/), [Lerrel Pinto](https://lerrelpinto.com) 3 | 4 | This is the official implementation of the paper [Dexterous Manipulation Made Easy](https://arxiv.org/abs/2203.13251). 5 | 6 | ## Policies on Real Robot 7 |

8 | 9 | 10 | 11 |

12 | 13 | ## Method 14 | ![DIME](https://github.com/NYU-robot-learning/dime/blob/gh-pages/figs/intro.png) 15 | DIME consists of two phases: demonstration collection, which is performed in real time with visual feedback, and demonstration-based policy learning, which can learn to solve dexterous tasks from a limited number of demonstrations. 16 | 17 | ## Setup 18 | The code base is split into five separate packages for convenience; this repository is one of the five. You can clone and set up each package by following the instructions in its repository. The packages are: 19 | - [Robot controller packages](https://github.com/NYU-robot-learning/DIME-Controllers): 20 | - [Allegro Hand Controller](https://github.com/NYU-robot-learning/Allegro-Hand-Controller-DIME). 21 | - [Kinova Arm Controller](https://github.com/NYU-robot-learning/Kinova-Arm-Controller-DIME). 22 | - [Camera packages](https://github.com/NYU-robot-learning/DIME-Camera-Packages): 23 | - [Realsense-ROS](https://github.com/NYU-robot-learning/Realsense-ROS-DIME). 24 | - [AR_Tracker_Alvar](https://github.com/ros-perception/ar_track_alvar). 25 | - (Phase 1) Demonstration collection packages: 26 | - [Teleop with Inverse Kinematics Package](https://github.com/NYU-robot-learning/DIME-IK-TeleOp). 27 | - [State based and Image based Demonstration collection package](https://github.com/NYU-robot-learning/DIME-Demonstrations). 28 | - (Phase 2) Nearest-Neighbor Imitation Learning (present in this repository). 29 | - Simulation [environments](https://github.com/NYU-robot-learning/dime_env) and the DAPG-related codebase. 30 | 31 | You need to set up the Controller packages and the IK-TeleOp package before using this package. 32 | To install the dependencies for this package with `pip`: 33 | ``` 34 | pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 35 | ``` 36 | Then install this package with: 37 | ``` 38 | pip3 install -e . 
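# Optional: a quick check (example only, not part of the original setup steps) that the
# editable install worked and the arm_models package is importable.
python3 -c "import arm_models"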
39 | ``` 40 | 41 | ## Data 42 | All our data can be found at this URL: [https://drive.google.com/drive/folders/1nunGHB2EK9xvlmepNNziDDbt-pH8OAhi](https://drive.google.com/drive/folders/1nunGHB2EK9xvlmepNNziDDbt-pH8OAhi) 43 | 44 | ## Citation 45 | 46 | If you use this repo in your research, please consider citing the paper as follows: 47 | ``` 48 | @article{arunachalam2022dime, 49 | title={Dexterous Imitation Made Easy: A Learning-Based Framework for Efficient Dexterous Manipulation}, 50 | author={Sridhar Pandian Arunachalam and Sneha Silwal and Ben Evans and Lerrel Pinto}, 51 | journal={arXiv preprint arXiv:2203.13251}, 52 | year={2022} 53 | } 54 | ``` -------------------------------------------------------------------------------- /arm_models/utils/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | def load_train_test_data(data_path = os.path.join(os.path.abspath(os.pardir), "data")): 5 | train_states, train_actions = load_state_actions(os.path.join(data_path, "train")) 6 | test_states, test_actions = load_state_actions(os.path.join(data_path, "validation")) 7 | return train_states, train_actions, test_states, test_actions 8 | 9 | def load_state_actions(data_path): 10 | # Getting the paths which contain the states and actions for each demo 11 | states_path, actions_path = os.path.join(data_path, "states"), os.path.join(data_path, "actions") 12 | 13 | # Getting the list of all demo states and actions 14 | states_demos_list, actions_demos_list = os.listdir(states_path), os.listdir(actions_path) 15 | assert len(states_demos_list) == len(actions_demos_list), "The numbers of state and action demo files are not equal!" 16 | 17 | # Sorting the state action pairs based on the demo numbers 18 | states_demos_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 19 | actions_demos_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 20 | 21 | states, actions = [], [] 22 | 23 | for demo_idx in range(len(states_demos_list)): 24 | demo_states = torch.load(os.path.join(states_path, states_demos_list[demo_idx])) 25 | demo_actions = torch.load(os.path.join(actions_path, actions_demos_list[demo_idx])) 26 | assert len(demo_states) == len(demo_actions), "Number of states: {} from {} and number of actions: {} from {} are not equal.\n".format(demo_states.shape[0], states_demos_list[demo_idx], demo_actions.shape[0], actions_demos_list[demo_idx]) 27 | 28 | states.append(demo_states) 29 | actions.append(demo_actions) 30 | 31 | # Converting the lists into a tensor 32 | states, actions = torch.cat(states), torch.cat(actions) 33 | 34 | return states, actions 35 | 36 | def load_representation_actions(data_path): 37 | # Getting the paths which contain the representations and actions for each demo 38 | representations_path, states_path, actions_path = os.path.join(data_path, "representations"), os.path.join(data_path, "states"), os.path.join(data_path, "actions") 39 | 40 | # Extracting the representations 41 | representations_file_name = os.listdir(representations_path)[0] 42 | representations = torch.load(os.path.join(representations_path, representations_file_name)) 43 | 44 | # Getting the list of all demo actions 45 | actions_demos_list = os.listdir(actions_path) 46 | 47 | # Sorting the actions based on the demo numbers 48 | actions_demos_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 49 | 50 | actions = [] 51 | 52 | for demo_idx in range(len(actions_demos_list)): 53 | demo_actions = 
torch.load(os.path.join(actions_path, actions_demos_list[demo_idx])) 54 | actions.append(demo_actions) 55 | 56 | # Converting the lists into a tensor 57 | actions = torch.cat(actions) 58 | 59 | return representations, actions 60 | 61 | def load_state_image_data(data_path): 62 | image_data_path = os.path.join(data_path, "images") 63 | demo_image_folders = os.listdir(image_data_path) 64 | demo_image_folders.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 65 | 66 | cumm_demos_image_count = [] 67 | 68 | for demo_folder in demo_image_folders: 69 | demo_image_folder_path = os.path.join(image_data_path, demo_folder) 70 | demo_image_count = len(os.listdir(demo_image_folder_path)) 71 | 72 | if len(cumm_demos_image_count) > 0: 73 | cumm_demos_image_count.append(cumm_demos_image_count[-1] + demo_image_count - 1) 74 | else: 75 | cumm_demos_image_count.append(demo_image_count) 76 | 77 | return image_data_path, demo_image_folders, cumm_demos_image_count 78 | -------------------------------------------------------------------------------- /arm_models/deploy/model_scripts/visual_inn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from arm_models.imitation.non_parametric import VINN 4 | from arm_models.utils.load_data import load_representation_actions, load_state_image_data 5 | 6 | DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning/complete" 7 | 8 | class VINNDeploy(): 9 | def __init__(self, data_path = DATA_PATH, k = 1, load_image_data = True, device = "cpu"): 10 | self.k = k 11 | self.image_data_path = None 12 | 13 | self.representations, self.actions = load_representation_actions(data_path) 14 | 15 | self.model = VINN(device) 16 | self.model.get_data(self.representations, self.actions) 17 | 18 | if load_image_data is True: 19 | print(data_path) 20 | self.image_data_path, self.demo_image_folders, self.cumm_demos_image_count = load_state_image_data(data_path) 21 | 22 | def get_action(self, representation): 23 | if self.k == 1: 24 | action, neighbor_idx = self.model.find_optimum_action(representation, self.k) 25 | return action.detach().cpu().numpy() 26 | else: 27 | return self.model.find_optimum_action(representation, self.k).detach().cpu().numpy() 28 | 29 | def get_debug_action(self, neighbor_index): 30 | if neighbor_index < self.cumm_demos_image_count[0]: 31 | nn_image_num = neighbor_index 32 | nn_trans_image_num = nn_image_num + 1 33 | demo_num = 0 34 | else: 35 | for idx, cumm_demo_images in enumerate(self.cumm_demos_image_count): 36 | if neighbor_index + 1 > cumm_demo_images: 37 | # Obtaining the demo number 38 | demo_num = idx + 1 39 | 40 | # Getting the corresponding image number 41 | nn_image_num = neighbor_index - cumm_demo_images 42 | nn_trans_image_num = nn_image_num + 1 43 | 44 | nn_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_image_num.item())) 45 | trans_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_trans_image_num.item())) 46 | 47 | 48 | return self.actions[neighbor_index + 1], neighbor_index + 1, nn_state_image_path, trans_state_image_path 49 | 50 | def get_action_with_image(self, representation): 51 | if self.k == 1 and self.image_data_path is not None: 52 | # state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) + list(cube_pos) 53 | calculated_action, neighbor_index = 
self.model.find_optimum_action(representation, self.k) 54 | 55 | if neighbor_index < self.cumm_demos_image_count[0]: 56 | nn_image_num = neighbor_index 57 | nn_trans_image_num = nn_image_num + 1 58 | demo_num = 0 59 | else: 60 | for idx, cumm_demo_images in enumerate(self.cumm_demos_image_count): 61 | if neighbor_index + 1 > cumm_demo_images: 62 | # Obtaining the demo number 63 | demo_num = idx + 1 64 | 65 | # Getting the corresponding image number 66 | nn_image_num = neighbor_index - cumm_demo_images 67 | nn_trans_image_num = nn_image_num + 1 68 | 69 | nn_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_image_num.item())) 70 | trans_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_trans_image_num.item())) 71 | 72 | return calculated_action.cpu().detach().numpy(), neighbor_index, nn_state_image_path, trans_state_image_path 73 | else: 74 | return self.get_action(representation) -------------------------------------------------------------------------------- /arm_models/deploy/model_scripts/inn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from arm_models.imitation.non_parametric import INN 5 | from arm_models.utils.load_data import load_state_actions, load_state_image_data 6 | 7 | class INNDeploy(): 8 | def __init__(self, data_path, target_priority = 1, k = 1, load_image_data = False, device = "cpu"): 9 | self.k = k 10 | self.image_data_path = None 11 | 12 | self.states, self.actions = load_state_actions(data_path) 13 | 14 | self.model = INN(device, target_priority = target_priority) 15 | self.model.get_data(self.states, self.actions) 16 | 17 | if load_image_data is True: 18 | print("Loading data from path:", data_path) 19 | self.image_data_path, self.demo_image_folders, self.cumm_demos_image_count = load_state_image_data(data_path) 20 | 21 | def get_action(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position): 22 | state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) + list(object_position) 23 | if self.k == 1: 24 | action, neighbor_idx, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.find_optimum_action(state, self.k) 25 | return action.cpu().detach().numpy() 26 | else: 27 | return self.model.find_optimum_action(state, self.k).cpu().detach().numpy() 28 | 29 | def get_action_with_image(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position): 30 | if self.k == 1 and self.image_data_path is not None: 31 | state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) + list(object_position) 32 | calculated_action, neighbor_index, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.find_optimum_action(state, self.k) 33 | 34 | if neighbor_index < self.cumm_demos_image_count[0]: 35 | nn_image_num = neighbor_index 36 | nn_trans_image_num = nn_image_num + 1 37 | demo_num = 0 38 | else: 39 | for idx, cumm_demo_images in enumerate(self.cumm_demos_image_count): 40 | if neighbor_index + 1 > cumm_demo_images: 41 | # Obtaining the demo number 42 | demo_num = idx + 1 43 | 44 | # Getting the corresponding image number 45 | nn_image_num = neighbor_index - cumm_demo_images 46 | nn_trans_image_num = nn_image_num + 1 47 | 48 | nn_state_image_path = 
os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_image_num.item())) 49 | trans_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_trans_image_num.item())) 50 | 51 | return calculated_action.cpu().detach().numpy(), neighbor_index, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff 52 | else: 53 | return self.get_action(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position) 54 | 55 | def get_debug_action(self, neighbor_index): 56 | if neighbor_index < self.cumm_demos_image_count[0]: 57 | nn_image_num = neighbor_index 58 | nn_trans_image_num = nn_image_num + 1 59 | demo_num = 0 60 | else: 61 | for idx, cumm_demo_images in enumerate(self.cumm_demos_image_count): 62 | if neighbor_index + 1 > cumm_demo_images: 63 | # Obtaining the demo number 64 | demo_num = idx + 1 65 | 66 | # Getting the corresponding image number 67 | nn_image_num = neighbor_index - cumm_demo_images 68 | nn_trans_image_num = nn_image_num + 1 69 | 70 | nn_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_image_num.item())) 71 | trans_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_trans_image_num.item())) 72 | 73 | return self.actions[neighbor_index + 1], neighbor_index + 1, nn_state_image_path, trans_state_image_path -------------------------------------------------------------------------------- /arm_models/imitation/non_parametric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class INN(): 5 | def __init__(self, device="cpu", target_priority = 1, finger_priority = 1): 6 | # Setting the device 7 | self.device = torch.device(device) 8 | 9 | # Initializing the states and action pairs to None 10 | self.states = None 11 | self.actions = None 12 | 13 | self.target_priority = target_priority 14 | self.thumb_priority = finger_priority 15 | self.ring_priority = finger_priority 16 | 17 | def get_data(self, states, actions): 18 | self.states = states.to(self.device) 19 | self.actions = actions.to(self.device) 20 | 21 | def getNearestNeighbors(self, input_state, k): 22 | # Comparing the dataset shape and state 23 | assert torch.tensor(input_state).shape == self.states[0].shape, "There is a data shape mismatch: \n Shape of loaded dataset: {} \n Shape of current state: {}".format(input_state.shape, self.states[0].shape) 24 | 25 | # Getting the k-Nearest Neighbor actions 26 | state_diff = self.states - torch.tensor(input_state).to(self.device) 27 | 28 | index_l2_diff = torch.norm(state_diff[:, :3], dim=1) 29 | middle_l2_diff = torch.norm(state_diff[:, 3:6], dim=1) 30 | ring_l2_diff = torch.norm(state_diff[:, 6:9], dim=1) 31 | thumb_l2_diff = torch.norm(state_diff[:, 9:12], dim=1) 32 | object_l2_diff = torch.norm(state_diff[:, 12:], dim=1) 33 | 34 | l2_diff = index_l2_diff + middle_l2_diff + (ring_l2_diff * self.ring_priority) + (thumb_l2_diff * self.thumb_priority) + (object_l2_diff * self.target_priority) 35 | 36 | sorted_idxs = torch.argsort(l2_diff)[:k] 37 | 38 | if k == 1: 39 | k_nn_actions = self.actions[torch.argsort(l2_diff)[0]] 40 | return k_nn_actions, torch.argsort(l2_diff)[0].cpu().detach(), index_l2_diff[sorted_idxs], middle_l2_diff[sorted_idxs], ring_l2_diff[sorted_idxs], thumb_l2_diff[sorted_idxs], 
object_l2_diff[sorted_idxs] 41 | else: 42 | return self.actions[torch.argsort(l2_diff)[:k]] 43 | 44 | 45 | def find_optimum_action(self, input_state, k): 46 | # Getting the k-Nearest Neighbor actions for the input state 47 | if k == 1: 48 | k_nn_action, neighbor_idx, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.getNearestNeighbors(input_state, k) 49 | return k_nn_action.cpu().detach(), neighbor_idx, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff 50 | else: 51 | k_nn_actions = self.getNearestNeighbors(input_state, k) 52 | 53 | # Getting the mean value from the set of nearest neighbor states 54 | mean_action = torch.mean(k_nn_actions, 0).cpu().detach() 55 | return mean_action 56 | 57 | class VINN(): 58 | def __init__(self, device="cpu", threshold = 0.02): 59 | # Setting the device 60 | self.device = torch.device(device) 61 | 62 | # Initializing the representations and action pairs to None 63 | self.representations = None 64 | self.actions = None 65 | 66 | self.softmax_filter = nn.Softmax(dim = 0) 67 | 68 | def get_data(self, representations, actions): 69 | self.representations = representations.to(self.device) 70 | self.actions = actions.to(self.device) 71 | 72 | def getNearestNeighbors(self, input_representation, k): 73 | # Comparing the dataset shape and representation 74 | assert torch.tensor(input_representation).shape == self.representations[0].shape, "There is a data shape mismatch: \n Shape of loaded representations: {} \n Shape of input representation: {}".format(input_representation.shape, self.representations[0].shape) 75 | 76 | # Getting the k-Nearest Neighbor actions 77 | representation_difference = self.representations - torch.tensor(input_representation).to(self.device) 78 | 79 | l2_diff = torch.norm(representation_difference, dim=1) 80 | softmax_l2_diff = self.softmax_filter(l2_diff) 81 | sorted_idxs = torch.argsort(softmax_l2_diff)[:k] 82 | 83 | # Get the actions corresponding to the closest representations 84 | 85 | if k > 1: 86 | return self.actions[sorted_idxs] 87 | else: 88 | return self.actions[sorted_idxs].cpu().detach(), sorted_idxs[0].cpu().detach() 89 | 90 | 91 | def find_optimum_action(self, input_representation, k): 92 | # Getting the k-Nearest Neighbor actions for the input representation 93 | if k == 1: 94 | return self.getNearestNeighbors(input_representation, k) 95 | else: 96 | k_nn_actions = self.getNearestNeighbors(input_representation, k) 97 | 98 | # Getting the mean value from the set of nearest neighbor representations 99 | return torch.mean(k_nn_actions, 0).cpu().detach() 100 | -------------------------------------------------------------------------------- /arm_models/deploy/deploy_scripts/mlp_deploy.py: -------------------------------------------------------------------------------- 1 | # Standard imports 2 | import numpy as np 3 | import cv2 4 | import torch 5 | 6 | # Standard ROS imports 7 | import rospy 8 | from sensor_msgs.msg import Image, JointState 9 | from visualization_msgs.msg import Marker 10 | from cv_bridge import CvBridge, CvBridgeError 11 | 12 | # Dexterous arm control package import 13 | from move_dexarm import DexArmControl 14 | 15 | # Importing the Allegro IK library 16 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL 17 | 18 | # Model Imports 19 | from arm_models.imitation.networks import MLP 20 | 21 | # Other imports 22 | from copy import deepcopy as copy 23 | 24 | # ROS Topics to get the AR Marker data 25 | IMAGE_TOPIC = "/cam_1/color/image_raw" 26 | 
MARKER_TOPIC = "/visualization_marker" 27 | JOINT_STATE_TOPIC = "/allegroHand/joint_states" 28 | 29 | MODEL_CHKPT_PATH = '/home/sridhar/dexterous_arm/models/arm_models/dexarm_deployment/model_checkpoints/mlp-final.pth' 30 | 31 | class DexArmMLPDeploy(): 32 | def __init__(self, device = 'cpu'): 33 | # Ignoring scientific notations 34 | torch.set_printoptions(precision=None) 35 | 36 | # Initializing ROS deployment node 37 | try: 38 | rospy.init_node("model_deploy") 39 | except: 40 | pass 41 | 42 | # Initializing arm controller 43 | self.arm = DexArmControl() 44 | 45 | # Initializing Allegro Ik Library 46 | self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf") 47 | 48 | # Loading the model and assigning device 49 | self.device = torch.device(device) 50 | self.model = MLP() 51 | self.model = self.model.to(device) 52 | self.model.load_state_dict(torch.load(MODEL_CHKPT_PATH)) 53 | self.model.eval() 54 | 55 | # Moving the dexterous arm to home position 56 | self.arm.home_robot() 57 | 58 | # Initializing topic data 59 | self.image = None 60 | self.allegro_joint_state = None 61 | self.ar_marker_data = None 62 | 63 | # Other initializations 64 | self.bridge = CvBridge() 65 | 66 | rospy.Subscriber(MARKER_TOPIC, Marker, self._callback_ar_marker_data, queue_size = 1) 67 | rospy.Subscriber(JOINT_STATE_TOPIC, JointState, self._callback_joint_state, queue_size = 1) 68 | rospy.Subscriber(IMAGE_TOPIC, Image, self._callback_image, queue_size=1) 69 | 70 | def _callback_ar_marker_data(self, data): 71 | self.ar_marker_data = data 72 | 73 | def _callback_joint_state(self, data): 74 | self.allegro_joint_state = data 75 | 76 | def _callback_image(self, data): 77 | try: 78 | self.image = self.bridge.imgmsg_to_cv2(data, "bgr8") 79 | except CvBridgeError as e: 80 | print(e) 81 | 82 | def get_cube_position(self): 83 | cube_position = np.array([self.ar_marker_data.pose.position.x, self.ar_marker_data.pose.position.y, self.ar_marker_data.pose.position.z]) 84 | return cube_position 85 | 86 | def get_tip_coords(self): 87 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0] 88 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0] 89 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0] 90 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0] 91 | 92 | return index_coord, middle_coord, ring_coord, thumb_coord 93 | 94 | def update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord): 95 | current_joint_angles = list(self.allegro_joint_state.position) 96 | 97 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4]) 98 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8]) 99 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12]) 100 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16]) 101 | 102 | desired_joint_angles = copy(current_joint_angles) 103 | 104 | for idx in range(4): 105 | desired_joint_angles[idx] = index_joint_angles[idx] 106 | desired_joint_angles[4 + idx] = middle_joint_angles[idx] 107 | 
desired_joint_angles[8 + idx] = ring_joint_angles[idx] 108 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx] 109 | 110 | return desired_joint_angles 111 | 112 | def deploy(self): 113 | while True: 114 | # Waiting for key 115 | next_step = input() 116 | 117 | # Setting the break condition 118 | if next_step == "q": 119 | break 120 | 121 | print("Getting state data...\n") 122 | 123 | print("Current joint angles: {}\n".format(self.allegro_joint_state.position)) 124 | 125 | # Getting the state data 126 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self.get_tip_coords() 127 | cube_pos = self.get_cube_position() 128 | 129 | print("Current state data:\n Index-tip position: {}\n Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n Cube position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, cube_pos)) 130 | 131 | cv2.imshow("Current state", self.image) 132 | cv2.waitKey(1) 133 | 134 | # Passing the input to the model to get the corresponding action 135 | input_coordinates = torch.tensor(np.concatenate([index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, cube_pos])) 136 | action = self.model(input_coordinates.float().to(self.device)).detach().cpu().numpy()[0] 137 | 138 | print("Corresponding action: ", action) 139 | 140 | # Updating the hand target coordinates 141 | updated_index_tip_coord = action[0:3] 142 | updated_middle_tip_coord = action[3:6] 143 | updated_ring_tip_coord = action[6:9] 144 | updated_thumb_tip_coord = action[9:12] 145 | 146 | print("updated index tip coord:", updated_index_tip_coord) 147 | 148 | desired_joint_angles = self.update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord) 149 | 150 | print("Moving arm to {}\n".format(desired_joint_angles)) 151 | self.arm.move_hand(desired_joint_angles) 152 | 153 | if __name__ == '__main__': 154 | d = DexArmMLPDeploy(device = 'cpu') 155 | d.deploy() -------------------------------------------------------------------------------- /arm_models/deploy/deploy_scripts/bc_deploy.py: -------------------------------------------------------------------------------- 1 | # Standard imports 2 | import cv2 3 | import torch 4 | from PIL import Image as PILImage 5 | 6 | # Standard ROS imports 7 | import rospy 8 | from sensor_msgs.msg import Image, JointState 9 | from cv_bridge import CvBridge, CvBridgeError 10 | 11 | # Dexterous arm control package import 12 | from move_dexarm import DexArmControl 13 | 14 | # Importing the Allegro IK library 15 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL 16 | 17 | # Model Imports 18 | from arm_models.imitation.networks import BehaviorCloning 19 | 20 | # Other imports 21 | from copy import deepcopy as copy 22 | from torchvision import transforms as T 23 | 24 | # ROS Topics to get the AR Marker data 25 | IMAGE_TOPIC = "/cam_1/color/image_raw" 26 | JOINT_STATE_TOPIC = "/allegroHand/joint_states" 27 | 28 | MODEL_CHKPT_PATH = '/home/sridhar/dexterous_arm/models/arm_models/dexarm_deployment/model_checkpoints/behavior_cloning.pth' 29 | 30 | class DexArmBCDeploy(): 31 | def __init__(self, device = 'cpu'): 32 | # Ignoring scientific notations 33 | torch.set_printoptions(precision=None) 34 | 35 | # Initializing ROS deployment node 36 | try: 37 | rospy.init_node("model_deploy") 38 | except: 39 | pass 40 | 41 | # Initializing arm controller 42 | self.arm = DexArmControl() 43 | 44 | # Initializing Allegro Ik Library 45 | 
self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf") 46 | 47 | # Loading the model and assigning device 48 | self.device = torch.device(device) 49 | self.model = BehaviorCloning() 50 | self.model = self.model.to(device) 51 | self.model.load_state_dict(torch.load(MODEL_CHKPT_PATH)) 52 | self.model.eval() 53 | 54 | # Moving the dexterous arm to home position 55 | self.arm.home_robot() 56 | 57 | # Initializing topic data 58 | self.image = None 59 | self.allegro_joint_state = None 60 | 61 | # Initializing image transformer 62 | self.image_transform = T.Compose([ 63 | T.ToTensor(), 64 | T.Resize((224, 224)), 65 | T.Normalize( 66 | mean = torch.tensor([0.3484, 0.3638, 0.3819]), 67 | std = torch.tensor([0.3224, 0.3151, 0.3166]) 68 | ) 69 | ]) 70 | 71 | # Other initializations 72 | self.bridge = CvBridge() 73 | 74 | rospy.Subscriber(JOINT_STATE_TOPIC, JointState, self._callback_joint_state, queue_size = 1) 75 | rospy.Subscriber(IMAGE_TOPIC, Image, self._callback_image, queue_size=1) 76 | 77 | def _callback_ar_marker_data(self, data): 78 | self.ar_marker_data = data 79 | 80 | def _callback_joint_state(self, data): 81 | self.allegro_joint_state = data 82 | 83 | def _callback_image(self, data): 84 | try: 85 | self.image = self.bridge.imgmsg_to_cv2(data, "bgr8") 86 | except CvBridgeError as e: 87 | print(e) 88 | 89 | def get_tip_coords(self): 90 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0] 91 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0] 92 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0] 93 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0] 94 | 95 | return index_coord, middle_coord, ring_coord, thumb_coord 96 | 97 | def update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord): 98 | current_joint_angles = list(self.allegro_joint_state.position) 99 | 100 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4]) 101 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8]) 102 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12]) 103 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16]) 104 | 105 | desired_joint_angles = copy(current_joint_angles) 106 | 107 | for idx in range(4): 108 | desired_joint_angles[idx] = index_joint_angles[idx] 109 | desired_joint_angles[4 + idx] = middle_joint_angles[idx] 110 | desired_joint_angles[8 + idx] = ring_joint_angles[idx] 111 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx] 112 | 113 | return desired_joint_angles 114 | 115 | def deploy(self): 116 | while True: 117 | # Waiting for key 118 | next_step = input() 119 | 120 | # Setting the break condition 121 | if next_step == "q": 122 | break 123 | 124 | print("Getting state data...\n") 125 | 126 | print("Current joint angles: {}\n".format(self.allegro_joint_state.position)) 127 | 128 | # Getting the state data 129 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self.get_tip_coords() 130 | 131 | print("Current state data:\n Index-tip position: {}\n 
Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord)) 132 | 133 | cv2.imshow("Current state", self.image) 134 | cv2.waitKey(1) 135 | 136 | # Transforming the image before passing it into the model 137 | pil_image = PILImage.fromarray(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)) 138 | cropped_image = pil_image.crop((450, 220, 900, 920)) # Left, Top, Right, Bottom 139 | transformed_image_tensor = self.image_transform(cropped_image).unsqueeze(0) 140 | 141 | action = self.model(transformed_image_tensor.float().to(self.device)).detach().cpu().numpy()[0] 142 | 143 | print("Corresponding action: ", action) 144 | 145 | # Updating the hand target coordinates 146 | updated_index_tip_coord = action[0:3] 147 | updated_middle_tip_coord = action[3:6] 148 | updated_ring_tip_coord = action[6:9] 149 | updated_thumb_tip_coord = action[9:12] 150 | 151 | print("updated index tip coord:", updated_index_tip_coord) 152 | 153 | desired_joint_angles = self.update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord) 154 | 155 | print("Moving arm to {}\n".format(desired_joint_angles)) 156 | self.arm.move_hand(desired_joint_angles) 157 | 158 | if __name__ == '__main__': 159 | d = DexArmBCDeploy(device = 'cpu') 160 | d.deploy() -------------------------------------------------------------------------------- /arm_models/deploy/deploy_scripts/vinn_deploy.py: -------------------------------------------------------------------------------- 1 | # Standard imports 2 | import numpy as np 3 | import cv2 4 | import torch 5 | from PIL import Image as PILImage 6 | 7 | # Standard ROS imports 8 | import rospy 9 | from sensor_msgs.msg import Image, JointState 10 | from cv_bridge import CvBridge, CvBridgeError 11 | 12 | # Dexterous arm control package import 13 | from move_dexarm import DexArmControl 14 | 15 | # Importing the Allegro IK library 16 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL 17 | 18 | # Model Imports 19 | from arm_models.deploy.model_scripts.visual_inn import VINNDeploy 20 | 21 | # Other imports 22 | from copy import deepcopy as copy 23 | from torchvision import models 24 | from torchvision import transforms as T 25 | 26 | # ROS Topics to get the AR Marker data 27 | IMAGE_TOPIC = "/cam_1/color/image_raw" 28 | JOINT_STATE_TOPIC = "/allegroHand/joint_states" 29 | ENCODER_MODEL_CHKPT = "/home/sridhar/dexterous_arm/models/arm_models/dexarm_deployment/model_checkpoints/rotation/BYOL-VINN-rotation-lowest-train-loss.pth" 30 | 31 | class DexArmVINNDeploy(): 32 | def __init__(self, task, k = 1, use_abs = True, debug = False, device = "cpu"): 33 | # Initializing ROS deployment node 34 | try: 35 | rospy.init_node("model_deploy") 36 | except: 37 | pass 38 | 39 | # Setting the debug parameter only if k is 1 40 | if k == 1: 41 | self.debug = debug 42 | else: 43 | self.debug = False 44 | 45 | self.abs = use_abs 46 | 47 | # Initializing arm controller 48 | print("Initializing controller!") 49 | self.arm = DexArmControl() 50 | 51 | # Initializing Allegro Ik Library 52 | self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf") 53 | 54 | # Initializing VINN 55 | print("Initializing model!") # TODO - put data path 56 | if task == "rotate": 57 | if self.debug is True: 58 | self.model = VINNDeploy(k = 1, load_image_data = True, device = device) 59 | else: 60 | self.model = 
VINNDeploy(k = k, device = device) 61 | elif task == "flip": 62 | if self.debug is True: 63 | self.model = VINNDeploy(k = 1, load_image_data = True, device = device) 64 | else: 65 | self.model = VINNDeploy(k = k, device = device) 66 | 67 | # Making sure the hand moves 68 | self.threshold = 0.02 69 | 70 | # Moving the dexterous arm to home position 71 | print("Homing robot!\n") 72 | self.arm.home_robot() 73 | 74 | # Initializing topic data 75 | self.image = None 76 | self.allegro_joint_state = None 77 | 78 | # Initializing the representation extractor 79 | self.device = torch.device(device) 80 | original_encoder_model = models.resnet50(pretrained = True) 81 | self.encoder = torch.nn.Sequential(*(list(original_encoder_model.children())[:-1])) 82 | self.encoder.to(self.device) 83 | self.encoder.load_state_dict(torch.load(ENCODER_MODEL_CHKPT)) 84 | self.encoder.eval() 85 | 86 | # Initializing image transformation function 87 | self.image_transform = T.Compose([ 88 | T.ToTensor(), 89 | T.Resize((224, 224)), 90 | T.Normalize( 91 | mean = torch.tensor([0.3484, 0.3638, 0.3819]), 92 | std = torch.tensor([0.3224, 0.3151, 0.3166]) 93 | ) 94 | ]) 95 | 96 | # Other initializations 97 | self.bridge = CvBridge() 98 | 99 | rospy.Subscriber(JOINT_STATE_TOPIC, JointState, self._callback_joint_state, queue_size = 1) 100 | rospy.Subscriber(IMAGE_TOPIC, Image, self._callback_image, queue_size=1) 101 | 102 | def _callback_ar_marker_data(self, data): 103 | self.ar_marker_data = data 104 | 105 | def _callback_joint_state(self, data): 106 | self.allegro_joint_state = data 107 | 108 | def _callback_image(self, data): 109 | try: 110 | self.image = self.bridge.imgmsg_to_cv2(data, "bgr8") 111 | except CvBridgeError as e: 112 | print(e) 113 | 114 | def get_tip_coords(self): 115 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0] 116 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0] 117 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0] 118 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0] 119 | 120 | return index_coord, middle_coord, ring_coord, thumb_coord 121 | 122 | def update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord): 123 | current_joint_angles = list(self.allegro_joint_state.position) 124 | 125 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4]) 126 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8]) 127 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12]) 128 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16]) 129 | 130 | desired_joint_angles = copy(current_joint_angles) 131 | 132 | for idx in range(4): 133 | desired_joint_angles[idx] = index_joint_angles[idx] 134 | desired_joint_angles[4 + idx] = middle_joint_angles[idx] 135 | desired_joint_angles[8 + idx] = ring_joint_angles[idx] 136 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx] 137 | 138 | return desired_joint_angles 139 | 140 | def deploy(self): 141 | while True: 142 | if self.allegro_joint_state is None: 143 | print('No allegro state received!') 144 | continue 145 | 146 | if 
self.image is None: 147 | # print('No robot image received!') 148 | continue 149 | 150 | # Waiting for a key press 151 | next_step = input() 152 | 153 | # Setting the break condition 154 | if next_step == "q": 155 | break 156 | 157 | print("Getting state data...\n") 158 | 159 | print("Current joint angles: {}\n".format(self.allegro_joint_state.position)) 160 | 161 | # Getting the state data 162 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self.get_tip_coords() 163 | finger_state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) 164 | 165 | print("Current state data:\n Index-tip position: {}\n Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord)) 166 | 167 | cv2.imshow("Current state image", self.image) 168 | cv2.waitKey(1) 169 | 170 | # Transforming the image 171 | pil_image = PILImage.fromarray(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)) 172 | cropped_image = pil_image.crop((450, 220, 900, 920)) # Left, Top, Right, Bottom 173 | transformed_image_tensor = self.image_transform(cropped_image).unsqueeze(0) 174 | 175 | representation = self.encoder(transformed_image_tensor.float().to(self.device)).squeeze(2).squeeze(2).detach().cpu().numpy()[0] 176 | 177 | if self.debug is True: 178 | action, neighbor_idx, nn_state_image_path, trans_state_image_path = self.model.get_action_with_image(representation) 179 | action = list(action.reshape(12)) 180 | 181 | l2_diff = np.linalg.norm(np.array(finger_state[:12]) - np.array(action)) 182 | 183 | while l2_diff < self.threshold: 184 | print("Obtaining new action!") 185 | action, neighbor_idx, nn_state_image_path, trans_state_image_path = self.model.get_debug_action(neighbor_idx) 186 | action = list(action.reshape(12)) 187 | 188 | l2_diff = np.linalg.norm(np.array(finger_state[:12]) - np.array(action)) 189 | 190 | 191 | # Reading the Nearest Neighbor images 192 | nn_state_image = cv2.imread(nn_state_image_path) 193 | trans_state_image = cv2.imread(trans_state_image_path) 194 | 195 | # Combining the current state image with the NN images and displaying them 196 | combined_images = np.concatenate(( 197 | cv2.resize(self.image, (420, 240), interpolation = cv2.INTER_AREA), 198 | cv2.resize(nn_state_image, (420, 240), interpolation = cv2.INTER_AREA), 199 | cv2.resize(trans_state_image, (420, 240), interpolation = cv2.INTER_AREA) 200 | ), axis=1) 201 | 202 | cv2.imshow("Current state and Nearest Neighbor Images", combined_images) 203 | cv2.waitKey(1) 204 | 205 | else: 206 | action = list(self.model.get_action(representation).reshape(12)) 207 | # To write debug condition for k > 1 208 | 209 | print("Corresponding action: ", action) 210 | 211 | if self.abs is True: 212 | updated_index_tip_coord = action[0:3] 213 | updated_middle_tip_coord = action[3:6] 214 | updated_ring_tip_coord = action[6:9] 215 | updated_thumb_tip_coord = action[9:12] 216 | else: 217 | updated_index_tip_coord = np.array(index_tip_coord) + np.array(action[0:3]) 218 | updated_middle_tip_coord = np.array(middle_tip_coord) + np.array(action[3:6]) 219 | updated_ring_tip_coord = np.array(ring_tip_coord) + np.array(action[6:9]) 220 | updated_thumb_tip_coord = np.array(thumb_tip_coord) + np.array(action[9:12]) 221 | 222 | desired_joint_angles = self.update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord) 223 | 224 | print("Moving arm to {}\n".format(desired_joint_angles)) 225 | 
self.arm.move_hand(desired_joint_angles) 226 | 227 | 228 | if __name__ == "__main__": 229 | d = DexArmVINNDeploy(task = "rotate", debug = True, k = 1, device = "cuda", use_abs = True) 230 | d.deploy() -------------------------------------------------------------------------------- /arm_models/deploy/deploy_scripts/inn_deploy.py: -------------------------------------------------------------------------------- 1 | # Standard imports 2 | import numpy as np 3 | import cv2 4 | 5 | # Standard ROS imports 6 | import rospy 7 | from sensor_msgs.msg import Image, JointState 8 | from visualization_msgs.msg import Marker 9 | from cv_bridge import CvBridge, CvBridgeError 10 | 11 | # Dexterous arm control package import 12 | from move_dexarm import DexArmControl 13 | 14 | # Importing the Allegro IK library 15 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL 16 | 17 | # Model Imports 18 | from arm_models.deploy.model_scripts.inn import INNDeploy 19 | 20 | # Other imports 21 | from copy import deepcopy as copy 22 | 23 | # ROS Topics to get the AR Marker data 24 | IMAGE_TOPIC = "/cam_1/color/image_raw" 25 | MARKER_TOPIC = "/visualization_marker" 26 | JOINT_STATE_TOPIC = "/allegroHand/joint_states" 27 | 28 | # Data paths 29 | CUBE_ROTATION_DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/cube_rotation/complete" 30 | OBJECT_FLIPPING_DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/object_flipping/complete" 31 | FIDGET_SPINNING_DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning/complete" 32 | 33 | class DexArmINNDeploy(): 34 | def __init__(self, task, k = 1, use_abs = True, debug = False, target_priority = 1, device = "cpu", threshold = 0.02): 35 | # Initializing ROS deployment node 36 | try: 37 | rospy.init_node("model_deploy") 38 | except: 39 | pass 40 | 41 | # Setting the debug parameter only if k is 1 42 | if k == 1: 43 | self.debug = debug 44 | else: 45 | self.debug = False 46 | 47 | self.abs = use_abs 48 | self.task = task 49 | 50 | # Initializing arm controller 51 | print("Initializing controller!") 52 | self.arm = DexArmControl() 53 | 54 | # Initializing Allegro Ik Library 55 | self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf") 56 | 57 | # Initializing INN 58 | print("Initializing model!") 59 | if task == "rotate": 60 | if self.debug is True: 61 | self.model = INNDeploy(k = 1, data_path = CUBE_ROTATION_DATA_PATH, load_image_data = True, target_priority = target_priority, device = device) 62 | else: 63 | self.model = INNDeploy(k = k, data_path = CUBE_ROTATION_DATA_PATH, target_priority = target_priority, device = device) 64 | elif task == "flip": 65 | if self.debug is True: 66 | self.model = INNDeploy(k = 1, data_path = OBJECT_FLIPPING_DATA_PATH, load_image_data = True, target_priority = target_priority, device = device) 67 | else: 68 | self.model = INNDeploy(k = k, data_path = OBJECT_FLIPPING_DATA_PATH, target_priority = target_priority, device = device) 69 | elif task == "spin": 70 | if self.debug is True: 71 | self.model = INNDeploy(k = 1, data_path = FIDGET_SPINNING_DATA_PATH, load_image_data = True, target_priority = target_priority, device = device) 72 | else: 73 | self.model = INNDeploy(k = k, data_path = FIDGET_SPINNING_DATA_PATH, target_priority = target_priority, device = device) 74 | 75 | 76 | # Making sure the hand moves by at least this much 77 | self.threshold = threshold 78 | 79 | # Moving the dexterous arm to home position 80 | print("Homing robot!") 81 | if 
task == "rotate" or task == "flip": 82 | self.arm.home_robot() 83 | elif task == "spin": 84 | self.arm.spin_pos_arm() 85 | 86 | # Initializing topic data 87 | self.image = None 88 | self.allegro_joint_state = None 89 | if task == "rotate" or task == "flip": 90 | self.ar_marker_data = None 91 | 92 | # Other initializations 93 | self.bridge = CvBridge() 94 | 95 | if task == "rotate" or task == "flip": 96 | rospy.Subscriber(MARKER_TOPIC, Marker, self._callback_ar_marker_data, queue_size = 1) 97 | rospy.Subscriber(JOINT_STATE_TOPIC, JointState, self._callback_joint_state, queue_size = 1) 98 | rospy.Subscriber(IMAGE_TOPIC, Image, self._callback_image, queue_size=1) 99 | 100 | def _callback_ar_marker_data(self, data): 101 | self.ar_marker_data = data 102 | 103 | def _callback_joint_state(self, data): 104 | self.allegro_joint_state = data 105 | 106 | def _callback_image(self, data): 107 | try: 108 | self.image = self.bridge.imgmsg_to_cv2(data, "bgr8") 109 | except CvBridgeError as e: 110 | print(e) 111 | 112 | def get_object_position(self): 113 | object_position = np.array([self.ar_marker_data.pose.position.x, self.ar_marker_data.pose.position.y, self.ar_marker_data.pose.position.z]) 114 | return object_position 115 | 116 | def get_tip_coords(self): 117 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0] 118 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0] 119 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0] 120 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0] 121 | 122 | return index_coord, middle_coord, ring_coord, thumb_coord 123 | 124 | def update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord): 125 | current_joint_angles = list(self.allegro_joint_state.position) 126 | 127 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4]) 128 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8]) 129 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12]) 130 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16]) 131 | 132 | desired_joint_angles = copy(current_joint_angles) 133 | 134 | for idx in range(4): 135 | desired_joint_angles[idx] = index_joint_angles[idx] 136 | desired_joint_angles[4 + idx] = middle_joint_angles[idx] 137 | desired_joint_angles[8 + idx] = ring_joint_angles[idx] 138 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx] 139 | 140 | return desired_joint_angles 141 | 142 | def deploy(self): 143 | print("Deploying model!\n") 144 | while True: 145 | if self.allegro_joint_state is None: 146 | print('No allegro state received!') 147 | continue 148 | 149 | # Waiting for key 150 | next_step = input() 151 | 152 | # Setting the break condition 153 | if next_step == "q": 154 | break 155 | 156 | print("Getting state data...\n") 157 | 158 | print("Current joint angles: {}\n".format(self.allegro_joint_state.position)) 159 | 160 | # Getting the state data 161 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self.get_tip_coords() 162 | 163 | if self.task == "rotate" or self.task == "flip": 164 | object_pos = 
self.get_object_position() 165 | elif self.task == "spin": 166 | object_pos = [0, 0, 0] 167 | 168 | state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) + list(object_pos) 169 | 170 | if self.task == "rotate" or self.task == "flip": 171 | print("Current state data:\n Index-tip position: {}\n Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n Object position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_pos)) 172 | elif self.task == "spin": 173 | print("Current state data:\n Index-tip position: {}\n Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord)) 174 | 175 | if self.debug is True: 176 | action, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.get_action_with_image(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_pos) 177 | action = list(action.reshape(12)) 178 | 179 | l2_diff = np.linalg.norm(np.array(state[:12]) - np.array(action)) 180 | 181 | while l2_diff < self.threshold: 182 | action, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.get_action_with_image(action[:3], action[3:6], action[6:9], action[9:12], object_pos) 183 | action = list(action.reshape(12)) 184 | 185 | l2_diff = np.linalg.norm(np.array(state[:12]) - np.array(action)) 186 | 187 | print("Trans state img path: {}".format(trans_state_image_path)) 188 | 189 | print("Distances:\n Index distance: {}\n Middle distance: {}\n Ring distance: {}\n Thumb distance: {}\n Cube distance: {}".format(index_l2_diff.item(), middle_l2_diff.item(), ring_l2_diff.item(), thumb_l2_diff.item(), object_l2_diff.item())) 190 | 191 | # Reading the Nearest Neighbor images 192 | nn_state_image = cv2.imread(nn_state_image_path) 193 | trans_state_image = cv2.imread(trans_state_image_path) 194 | 195 | # Combining the current state image and NN image and printing them 196 | combined_images = np.concatenate(( 197 | cv2.resize(self.image, (420, 240), interpolation = cv2.INTER_AREA), 198 | cv2.resize(nn_state_image, (420, 240), interpolation = cv2.INTER_AREA), 199 | cv2.resize(trans_state_image, (420, 240), interpolation = cv2.INTER_AREA) 200 | ), axis=1) 201 | 202 | cv2.imshow("Current state and Nearest Neighbor Images", combined_images) 203 | cv2.waitKey(1) 204 | 205 | else: 206 | action = list(self.model.get_action(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_pos).reshape(12)) 207 | l2_diff = np.linalg.norm(np.array(state[:12]) - np.array(action)) 208 | 209 | while l2_diff < self.threshold: 210 | action = list(self.model.get_action(action[:3], action[3:6], action[6:9], action[9:12], object_pos).reshape(12)) 211 | l2_diff = np.linalg.norm(np.array(state[:12]) - np.array(action)) 212 | 213 | print("Corresponding action: ", action) 214 | 215 | if self.abs: 216 | # print("Using absolute values for thumb: {} and for ring finger: {}".format(action[:3], action[3:])) 217 | updated_index_tip_coord = action[0:3] 218 | updated_middle_tip_coord = action[3:6] 219 | updated_ring_tip_coord = action[6:9] 220 | updated_thumb_tip_coord = action[9:12] 221 | else: 222 | updated_index_tip_coord = np.array(index_tip_coord) + np.array(action[0:3]) 223 | updated_middle_tip_coord = np.array(middle_tip_coord) + np.array(action[3:6]) 224 | 
updated_ring_tip_coord = np.array(ring_tip_coord) + np.array(action[6:9]) 225 | updated_thumb_tip_coord = np.array(thumb_tip_coord) + np.array(action[9:12]) 226 | 227 | desired_joint_angles = self.update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord) 228 | 229 | print("Moving arm to {}\n".format(desired_joint_angles)) 230 | self.arm.move_hand(desired_joint_angles) 231 | 232 | 233 | if __name__ == "__main__": 234 | # Rotation task 235 | d = DexArmINNDeploy(task = "rotate", debug = True, k = 1, target_priority = 4, device = "cuda") 236 | 237 | # Flipping task 238 | # d = DexArmINNDeploy(task = "flip", debug = True, k = 1, target_priority = 1, device = "cuda") 239 | 240 | # Spinning task 241 | # d = DexArmINNDeploy(task = "spin", debug = True, k = 1, target_priority = 1, device = "cuda") 242 | 243 | d.deploy() -------------------------------------------------------------------------------- /arm_models/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | 6 | from arm_models.dataloaders.state_dataset import * 7 | from arm_models.dataloaders.visual_dataset import * 8 | 9 | from imitation.networks import * 10 | from byol_pytorch import BYOL 11 | 12 | import wandb 13 | import argparse 14 | from tqdm import tqdm 15 | 16 | from utils.augmentation import augmentation_generator 17 | 18 | CHKPT_PATH = os.path.join(os.getcwd(), "deploy", "checkpoints") 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('-t', '--task', type=str) 22 | parser.add_argument('-n', '--network', type=str) 23 | parser.add_argument('-b', '--batch_size', type=int) 24 | parser.add_argument('-e', '--epochs', type=int) 25 | parser.add_argument('-g', '--gpu_num', type=int) 26 | parser.add_argument('-l', '--lr', type=float) 27 | parser.add_argument('-i', '--img_size', type=int) 28 | parser.add_argument('-d', '--train_data_path', type=str) 29 | parser.add_argument('--full_data', action='store_true') 30 | parser.add_argument('-v', '--val_data_path', type=str) 31 | parser.add_argument('-f', '--test_data_path', type=str) 32 | parser.add_argument('-a', '--augment_imgs', action='store_true') 33 | parser.add_argument('-r', '--run', type=int) 34 | 35 | if __name__ == '__main__': 36 | # Getting the parser options 37 | options = parser.parse_args() 38 | 39 | # Selecting the torch device 40 | device = torch.device("cuda:{}".format(options.gpu_num)) 41 | print("Using GPU: {} for training. 
\n".format(options.gpu_num)) 42 | 43 | # Load the dataset and creating dataloader 44 | print("Loading dataset(s) and creating dataloader(s).\n") 45 | if options.task == "rotate": 46 | if options.img_size is None: 47 | train_dataset = CubeRotationDataset(options.train_data_path) 48 | if options.val_data_path is not None: 49 | val_dataset = CubeRotationDataset(options.val_data_path) 50 | if options.test_data_path is not None: 51 | test_dataset = CubeRotationDataset(options.test_data_path) 52 | else: 53 | if options.network == "bc": 54 | if options.full_data is True: 55 | print("Using the entire dataset!\n") 56 | train_dataset = CubeRotationVisualDataset(type = None, image_size = options.img_size, data_path = options.train_data_path) 57 | else: 58 | train_dataset = CubeRotationVisualDataset(type = "train", image_size = options.img_size, data_path = options.train_data_path) 59 | if options.val_data_path is not None: 60 | val_dataset = CubeRotationVisualDataset(type = "val", image_size = options.img_size, data_path = options.val_data_path) 61 | if options.test_data_path is not None: 62 | test_dataset = CubeRotationVisualDataset(type = "test", image_size = options.img_size, data_path = options.test_data_path) 63 | elif options.network == "representation_byol": 64 | train_dataset = RepresentationVisualDataset(task = "cube_rotation", data_path = options.train_data_path) 65 | if options.val_data_path is not None: 66 | val_dataset = RepresentationVisualDataset(task = "cube_rotation", data_path = options.train_data_path) 67 | if options.test_data_path is not None: 68 | test_dataset = RepresentationVisualDataset(task = "cube_rotation", data_path = options.train_data_path) 69 | 70 | elif options.task == "flip": 71 | if options.img_size is None: 72 | train_dataset = ObjectFlippingDataset(options.train_data_path) 73 | if options.val_data_path is not None: 74 | val_dataset = ObjectFlippingDataset(options.val_data_path) 75 | if options.test_data_path is not None: 76 | test_dataset = ObjectFlippingDataset(options.test_data_path) 77 | else: 78 | if options.network == "bc": 79 | if options.full_data is True: 80 | print("Using the entire dataset!\n") 81 | train_dataset = ObjectFlippingVisualDataset(type = None, image_size = options.img_size, data_path = options.train_data_path) 82 | else: 83 | train_dataset = ObjectFlippingVisualDataset(type = "train", image_size = options.img_size, data_path = options.train_data_path) 84 | if options.val_data_path is not None: 85 | val_dataset = ObjectFlippingVisualDataset(type = "val", image_size = options.img_size, data_path = options.val_data_path) 86 | if options.test_data_path is not None: 87 | test_dataset = ObjectFlippingVisualDataset(type = "test", image_size = options.img_size, data_path = options.test_data_path) 88 | elif options.network == "representation_byol": 89 | train_dataset = RepresentationVisualDataset(task = "object_flipping", data_path = options.train_data_path) 90 | if options.val_data_path is not None: 91 | val_dataset = RepresentationVisualDataset(task = "object_flipping", data_path = options.train_data_path) 92 | if options.test_data_path is not None: 93 | test_dataset = RepresentationVisualDataset(task = "object_flipping", data_path = options.train_data_path) 94 | 95 | elif options.task == "spin": 96 | if options.img_size is None: 97 | train_dataset = FidgetSpinningDataset(options.train_data_path) 98 | if options.val_data_path is not None: 99 | val_dataset = FidgetSpinningDataset(options.val_data_path) 100 | if options.test_data_path is not None: 101 
| test_dataset = FidgetSpinningDataset(options.test_data_path) 102 | else: 103 | if options.network == "bc": 104 | if options.full_data is True: 105 | print("Using the entire dataset!\n") 106 | train_dataset = FidgetSpinningVisualDataset(type = None, image_size = options.img_size, data_path = options.train_data_path) 107 | else: 108 | train_dataset = FidgetSpinningVisualDataset(type = "train", image_size = options.img_size, data_path = options.train_data_path) 109 | if options.val_data_path is not None: 110 | val_dataset = FidgetSpinningVisualDataset(type = "val", image_size = options.img_size, data_path = options.val_data_path) 111 | if options.test_data_path is not None: 112 | test_dataset = FidgetSpinningVisualDataset(type = "test", image_size = options.img_size, data_path = options.test_data_path) 113 | elif options.network == "representation_byol": 114 | train_dataset = RepresentationVisualDataset(task = "fidget_spinning", data_path = options.train_data_path) 115 | if options.val_data_path is not None: 116 | val_dataset = RepresentationVisualDataset(task = "fidget_spinning", data_path = options.train_data_path) 117 | if options.test_data_path is not None: 118 | test_dataset = RepresentationVisualDataset(task = "fidget_spinning", data_path = options.train_data_path) 119 | 120 | train_dataloader = DataLoader(train_dataset, options.batch_size, shuffle = True, pin_memory = True, num_workers = 24) 121 | 122 | if options.val_data_path is not None: 123 | val_dataloader = DataLoader(val_dataset, options.batch_size, shuffle = True, pin_memory = True, num_workers = 24) 124 | 125 | if options.test_data_path is not None: 126 | test_dataloader = DataLoader(test_dataset, options.batch_size, shuffle = True, pin_memory = True, num_workers = 24) 127 | 128 | # Initializing the model based on the model argument 129 | print("Loading the model({}) for the task: {}\n".format(options.network, options.task)) 130 | if options.network == "mlp": 131 | model = MLP().to(device) 132 | elif options.network == "bc": 133 | model = BehaviorCloning().to(device) 134 | elif options.network == "representation_byol": 135 | original_encoder_model = models.resnet50(pretrained = True) 136 | encoder = torch.nn.Sequential(*(list(original_encoder_model.children())[:-1])).to(device) 137 | 138 | if options.augment_imgs is True: 139 | augment_custom = augmentation_generator(options.task) 140 | model = BYOL( 141 | encoder, 142 | image_size = options.img_size, 143 | augment_fn = augment_custom 144 | ) 145 | else: 146 | model = BYOL( 147 | encoder, 148 | image_size = options.img_size 149 | ) 150 | 151 | # Initialize WandB logging 152 | wandb.init(project = "{} - {}".format(options.task, options.network)) 153 | 154 | # Initializing optimizer and other parameters 155 | optimizer = torch.optim.Adam(model.parameters(), lr = options.lr) 156 | 157 | if options.test_data_path is not None: 158 | low_test_loss = np.inf 159 | elif options.val_data_path is not None: 160 | low_val_loss = np.inf 161 | else: 162 | low_train_loss = np.inf 163 | 164 | # Initializing loss 165 | if options.network != "representation_byol": 166 | loss_fn = nn.MSELoss() 167 | 168 | # Training loop 169 | print("Starting training procedure!\n") 170 | for epoch in range(options.epochs): 171 | # Training part 172 | epoch_train_loss = 0 173 | 174 | for input_data, actions in tqdm(train_dataloader): 175 | optimizer.zero_grad() 176 | 177 | if options.network == "representation_byol": 178 | loss = model(input_data.float().to(device)) 179 | else: 180 | predicted_actions = 
model(input_data.float().to(device)) 181 | loss = loss_fn(predicted_actions, actions.float().to(device)) 182 | 183 | loss.backward() 184 | optimizer.step() 185 | 186 | epoch_train_loss += loss.item() * input_data.shape[0] 187 | 188 | print("Train loss: {}\n".format(epoch_train_loss / len(train_dataset))) 189 | wandb.log({'train loss': epoch_train_loss / len(train_dataset)}) 190 | 191 | # Validation part 192 | if options.val_data_path is not None: 193 | epoch_val_loss = 0 194 | 195 | for input_data, actions in tqdm(val_dataloader): 196 | if options.network == "representation_byol": 197 | loss = model(input_data.float().to(device)) 198 | else: 199 | predicted_actions = model(input_data.float().to(device)) 200 | loss = loss_fn(predicted_actions, actions.float().to(device)) 201 | 202 | epoch_val_loss += loss.item() * input_data.shape[0] 203 | 204 | print("Validation loss: {}\n".format(epoch_val_loss / len(val_dataset))) 205 | wandb.log({'validation loss': epoch_val_loss / len(val_dataset)}) 206 | 207 | # Testing part 208 | if options.test_data_path is not None: 209 | epoch_test_loss = 0 210 | 211 | for input_data, actions in tqdm(test_dataloader): 212 | if options.network == "representation_byol": 213 | loss = model(input_data.float().to(device)) 214 | else: 215 | predicted_actions = model(input_data.float().to(device)) 216 | loss = loss_fn(predicted_actions, actions.float().to(device)) 217 | 218 | epoch_test_loss += loss.item() * input_data.shape[0] 219 | 220 | print("Test loss: {}\n".format(epoch_test_loss / len(test_dataset))) 221 | wandb.log({'test loss': epoch_test_loss / len(test_dataset)}) 222 | 223 | 224 | 225 | # Saving checkpoints based on the lowest stats 226 | if options.network == "vinn": 227 | low_train_loss = epoch_train_loss / len(train_dataset) 228 | epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - lowest - train - v{}.pth".format(options.network, options.task, options.run)) 229 | print("\nLower train loss encountered! Saving checkpoint {}\n".format(epoch_checkpt_path)) 230 | torch.save(encoder.state_dict(), epoch_checkpt_path) 231 | else: 232 | if options.test_data_path is not None: 233 | if epoch_test_loss / len(test_dataset) < low_test_loss: 234 | low_test_loss = epoch_test_loss / len(test_dataset) 235 | epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - lowest - test - v{}.pth".format(options.network, options.task, options.run)) 236 | print("\nLower test loss encountered! Saving checkpoint {}\n".format(epoch_checkpt_path)) 237 | torch.save(model.state_dict(), epoch_checkpt_path) 238 | elif options.val_data_path is not None: 239 | if epoch_val_loss / len(val_dataset) < low_val_loss: 240 | low_val_loss = epoch_val_loss / len(val_dataset) 241 | epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - lowest - val - v{}.pth".format(options.network, options.task, options.run)) 242 | print("\nLower validation loss encountered! Saving checkpoint {}\n".format(epoch_checkpt_path)) 243 | torch.save(model.state_dict(), epoch_checkpt_path) 244 | else: 245 | if epoch_train_loss / len(train_dataset) < low_train_loss: 246 | low_train_loss = epoch_train_loss / len(train_dataset) 247 | epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - lowest - train - v{}.pth".format(options.network, options.task, options.run)) 248 | print("\nLower train loss encountered! 
Saving checkpoint {}\n".format(epoch_checkpt_path)) 249 | torch.save(model.state_dict(), epoch_checkpt_path) 250 | 251 | # Saving final checkpoint 252 | if options.task == "vinn": 253 | final_epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - final - v{}.pth".format(options.network, options.task, options.run)) 254 | print("\nSaving final checkpoint {}\n".format(final_epoch_checkpt_path)) 255 | torch.save(encoder.state_dict(), final_epoch_checkpt_path) 256 | else: 257 | final_epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - final - v{}.pth".format(options.network, options.task, options.run)) 258 | print("\nSaving final checkpoint {}\n".format(final_epoch_checkpt_path)) 259 | torch.save(model.state_dict(), final_epoch_checkpt_path) -------------------------------------------------------------------------------- /arm_models/dataloaders/visual_dataset.py: -------------------------------------------------------------------------------- 1 | from argparse import Action 2 | import os 3 | from PIL import Image 4 | 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | from torchvision import transforms as T 8 | 9 | from tqdm import tqdm 10 | 11 | class RepresentationVisualDataset(Dataset): 12 | def __init__(self, task, image_size = 224, data_path = '/home/sridhar/dexterous_arm/demonstrations/image_data'): 13 | # Loading the images 14 | images_dir= os.path.join(data_path, task) 15 | 16 | # Sorting all the demo paths 17 | image_path_list = os.listdir(images_dir) 18 | image_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 19 | 20 | self.image_tensors = [] 21 | 22 | # Transforming the images 23 | if task == "cube_rotation": 24 | mean = torch.tensor([0.4640, 0.4933, 0.5223]) 25 | std = torch.tensor([0.2890, 0.2671, 0.2530]) 26 | elif task == "object_flipping": 27 | mean = torch.tensor([0.4815, 0.5450, 0.5696]) 28 | std = torch.tensor([0.2291, 0.2268, 0.2248]) 29 | elif task == "fidget_spinning": 30 | mean = torch.tensor([0.4306, 0.3954, 0.3472]) 31 | std = torch.tensor([0.2897, 0.2527, 0.2321]) 32 | 33 | self.image_preprocessor = T.Compose([ 34 | T.ToTensor(), 35 | T.Resize((image_size, image_size)), 36 | T.Normalize( 37 | mean = mean, 38 | std = std 39 | ) 40 | ]) 41 | 42 | # Loading all the images in the images vector 43 | print("Loading all the images: \n") 44 | for demo_images_path in tqdm(image_path_list): 45 | demo_image_folder_path = os.path.join(images_dir, demo_images_path) 46 | 47 | # Sort the demo images list 48 | demo_images_list = os.listdir(demo_image_folder_path) 49 | 50 | demo_images_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 51 | 52 | # Read each image and append them in the images array 53 | for idx in range(len(demo_images_list) - 1): 54 | try: 55 | image = Image.open(os.path.join(demo_image_folder_path, demo_images_list[idx])) 56 | if task == "cube_rotation": 57 | image = image.crop((500, 160, 950, 600)) # Left, Top, Right, Bottom 58 | elif task == "object_flipping": 59 | image = image.crop((220, 165, 460, 340)) # Left, Top, Right, Bottom 60 | elif task == "fidget_spinning": 61 | image = image.crop((65, 80, 590, 480)) # Left, Top, Right, Bottom 62 | 63 | image_tensor = self.image_preprocessor(image) 64 | self.image_tensors.append(image_tensor.detach()) 65 | image.close() 66 | except: 67 | print('Image cannot be read!') 68 | continue 69 | 70 | def __len__(self): 71 | return len(self.image_tensors) 72 | 73 | def __getitem__(self, idx): 74 | return (self.image_tensors[idx], torch.zeros(12)) 75 | 76 | class 
CubeRotationVisualDataset(Dataset): 77 | def __init__(self, type = None, image_size = 224, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/cube_rotation'): 78 | if type == "train": 79 | self.data_path = os.path.join(data_path, "for_eval", "train") 80 | elif type == "val": 81 | self.data_path = os.path.join(data_path, "for_eval", "validation") 82 | elif type == "test": 83 | self.data_path = os.path.join(data_path, "for_eval", "test") 84 | else: 85 | self.data_path = os.path.join(data_path, "complete") 86 | 87 | # Loading all the image and action folder paths 88 | images_dir, actions_dir = os.path.join(self.data_path, 'images'), os.path.join(self.data_path, 'actions') 89 | 90 | # Sorting all the demo paths 91 | image_path_list, action_path_list = os.listdir(images_dir), os.listdir(actions_dir) 92 | 93 | image_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 94 | action_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 95 | 96 | self.image_tensors, self.action_tensors = [], [] 97 | 98 | # Transforming the images 99 | if type == "train": 100 | mean = torch.tensor([0.4631, 0.4917, 0.5201]) 101 | std = torch.tensor([0.2885, 0.2666, 0.2526]) 102 | elif type == "val": 103 | mean = torch.tensor([0.3557, 0.4138, 0.4550]) 104 | std = torch.tensor([0.9030, 0.8565, 0.8089]) 105 | elif type == "load": 106 | mean = torch.tensor([0, 0, 0]) 107 | std = torch.tensor([1, 1, 1]) 108 | else: 109 | mean = torch.tensor([0.4631, 0.4923, 0.5215]) 110 | std = torch.tensor([0.2891, 0.2674, 0.2535]) 111 | 112 | self.image_preprocessor = T.Compose([ 113 | T.ToTensor(), 114 | T.Resize((image_size, image_size)), 115 | T.Normalize( 116 | mean = mean, 117 | std = std 118 | ) 119 | ]) 120 | 121 | # Loading all the images in the images vector 122 | print("Loading all the images: \n") 123 | for demo_images_path in tqdm(image_path_list): 124 | demo_image_folder_path = os.path.join(images_dir, demo_images_path) 125 | 126 | # Sort the demo images list 127 | demo_images_list = os.listdir(demo_image_folder_path) 128 | demo_images_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 129 | 130 | # print("Reading images from {}".format(demo_image_folder_path)) 131 | 132 | # Read each image and append them in the images array 133 | for idx in range(len(demo_images_list) - 1): 134 | try: 135 | image = Image.open(os.path.join(demo_image_folder_path, demo_images_list[idx])) 136 | image = image.crop((500, 160, 950, 600)) # Left, Top, Right, Bottom 137 | image_tensor = self.image_preprocessor(image) 138 | self.image_tensors.append(image_tensor.detach()) 139 | image.close() 140 | except: 141 | print('Image cannot be read!') 142 | continue 143 | 144 | # Loading all the action vectors 145 | print("\nLoading all the actions: \n") 146 | for demo_action_path in tqdm(action_path_list): 147 | demo_actions = torch.load(os.path.join(actions_dir, demo_action_path)) 148 | self.action_tensors.append(demo_actions) 149 | 150 | self.action_tensors = torch.cat(self.action_tensors, dim = 0) 151 | 152 | def __len__(self): 153 | return len(self.image_tensors) 154 | 155 | def __getitem__(self, idx): 156 | return ((self.image_tensors[idx], self.action_tensors[idx])) 157 | 158 | class ObjectFlippingVisualDataset(Dataset): 159 | def __init__(self, type = None, image_size = 224, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/object_flipping'): 160 | if type == "train": 161 | self.data_path = os.path.join(data_path, "for_eval", "train") 162 | elif type == "val": 163 | self.data_path = 
os.path.join(data_path, "for_eval", "validation") 164 | elif type == "test": 165 | self.data_path = os.path.join(data_path, "for_eval", "test") 166 | else: 167 | self.data_path = os.path.join(data_path, "complete") 168 | 169 | # Loading all the image and action folder paths 170 | images_dir, actions_dir = os.path.join(self.data_path, 'images'), os.path.join(self.data_path, 'actions') 171 | 172 | # Sorting all the demo paths 173 | image_path_list, action_path_list = os.listdir(images_dir), os.listdir(actions_dir) 174 | 175 | image_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 176 | action_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 177 | 178 | self.image_tensors, self.action_tensors = [], [] 179 | 180 | # Transforming the images 181 | if type == "train": 182 | mean = torch.tensor([0.4631, 0.4917, 0.5201]) # Recalculate 183 | std = torch.tensor([0.2885, 0.2666, 0.2526]) # Recalculate 184 | elif type == "val": 185 | mean = torch.tensor([0.3557, 0.4138, 0.4550]) # Recalculate 186 | std = torch.tensor([0.9030, 0.8565, 0.8089]) # Recalculate 187 | elif type == "load": 188 | mean = torch.tensor([0, 0, 0]) 189 | std = torch.tensor([1, 1, 1]) 190 | else: 191 | mean = torch.tensor([0.4534, 0.3770, 0.3885]) 192 | std = torch.tensor([0.2512, 0.1881, 0.2599]) 193 | 194 | self.image_preprocessor = T.Compose([ 195 | T.ToTensor(), 196 | T.Resize((image_size, image_size)), 197 | T.Normalize( 198 | mean = mean, 199 | std = std # To be computed 200 | ) 201 | ]) 202 | 203 | # Loading all the images in the images vector 204 | print("Loading all the images: \n") 205 | for demo_images_path in tqdm(image_path_list): 206 | demo_image_folder_path = os.path.join(images_dir, demo_images_path) 207 | 208 | # Sort the demo images list 209 | demo_images_list = os.listdir(demo_image_folder_path) 210 | demo_images_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 211 | 212 | # Read each image and append them in the images array 213 | for idx in range(len(demo_images_list) - 1): 214 | try: 215 | image = Image.open(os.path.join(demo_image_folder_path, demo_images_list[idx])) 216 | image = image.crop((220, 165, 460, 340)) # Left, Top, Right, Bottom 217 | image_tensor = self.image_preprocessor(image) 218 | self.image_tensors.append(image_tensor.detach()) 219 | image.close() 220 | except: 221 | print('Image cannot be read!') 222 | continue 223 | 224 | # Loading all the action vectors 225 | print("\nLoading all the actions: \n") 226 | for demo_action_path in tqdm(action_path_list): 227 | demo_actions = torch.load(os.path.join(actions_dir, demo_action_path)) 228 | self.action_tensors.append(demo_actions) 229 | 230 | self.action_tensors = torch.cat(self.action_tensors, dim = 0) 231 | 232 | def __len__(self): 233 | return len(self.image_tensors) 234 | 235 | def __getitem__(self, idx): 236 | return ((self.image_tensors[idx], self.action_tensors[idx])) 237 | 238 | class FidgetSpinningVisualDataset(Dataset): 239 | def __init__(self, type = None, image_size = 224, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning'): 240 | if type == "train": 241 | self.data_path = os.path.join(data_path, "for_eval", "train") 242 | elif type == "val": 243 | self.data_path = os.path.join(data_path, "for_eval", "validation") 244 | elif type == "test": 245 | self.data_path = os.path.join(data_path, "for_eval", "test") 246 | else: 247 | self.data_path = os.path.join(data_path, "complete") 248 | 249 | # Loading all the image and action folder paths 250 | images_dir, 
actions_dir = os.path.join(self.data_path, 'images'), os.path.join(self.data_path, 'actions') 251 | 252 | # Sorting all the demo paths 253 | image_path_list, action_path_list = os.listdir(images_dir), os.listdir(actions_dir) 254 | 255 | image_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 256 | action_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 257 | 258 | self.image_tensors, self.action_tensors = [], [] 259 | 260 | # Transforming the images 261 | if type == "train": 262 | mean = torch.tensor([0.4631, 0.4917, 0.5201]) # To be recalculated 263 | std = torch.tensor([0.2885, 0.2666, 0.2526]) # To be recalculated 264 | elif type == "val": 265 | mean = torch.tensor([0.3557, 0.4138, 0.4550]) # To be recalculated 266 | std = torch.tensor([0.9030, 0.8565, 0.8089]) # To be recalculated 267 | elif type == "load": 268 | mean = torch.tensor([0, 0, 0]) 269 | std = torch.tensor([1, 1, 1]) 270 | else: 271 | mean = torch.tensor([0.4320, 0.3963, 0.3478]) 272 | std = torch.tensor([0.2897, 0.2525, 0.2317]) 273 | 274 | self.image_preprocessor = T.Compose([ 275 | T.ToTensor(), 276 | T.Resize((image_size, image_size)), 277 | T.Normalize( 278 | mean = mean, 279 | std = std 280 | ) 281 | ]) 282 | 283 | # Loading all the images in the images vector 284 | print("Loading all the images: \n") 285 | for demo_images_path in tqdm(image_path_list): 286 | demo_image_folder_path = os.path.join(images_dir, demo_images_path) 287 | 288 | # Sort the demo images list 289 | demo_images_list = os.listdir(demo_image_folder_path) 290 | demo_images_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) 291 | 292 | # print("Reading images from {}".format(demo_image_folder_path)) 293 | 294 | # Read each image and append them in the images array 295 | for idx in range(len(demo_images_list) - 1): 296 | try: 297 | image = Image.open(os.path.join(demo_image_folder_path, demo_images_list[idx])) 298 | image = image.crop((65, 80, 590, 480)) # Left, Top, Right, Bottom 299 | image_tensor = self.image_preprocessor(image) 300 | self.image_tensors.append(image_tensor.detach()) 301 | image.close() 302 | except: 303 | print('Image cannot be read!') 304 | continue 305 | 306 | # Loading all the action vectors 307 | print("\nLoading all the actions: \n") 308 | for demo_action_path in tqdm(action_path_list): 309 | demo_actions = torch.load(os.path.join(actions_dir, demo_action_path)) 310 | self.action_tensors.append(demo_actions) 311 | 312 | self.action_tensors = torch.cat(self.action_tensors, dim = 0) 313 | 314 | def __len__(self): 315 | return len(self.image_tensors) 316 | 317 | def __getitem__(self, idx): 318 | return ((self.image_tensors[idx], self.action_tensors[idx])) 319 | 320 | if __name__ == '__main__': 321 | 322 | # To find the mean and std of the image pixels for normalization 323 | # dataset = CubeRotationVisualDataset() 324 | # dataset = ObjectFlippingVisualDataset() 325 | # dataset = FidgetSpinningVisualDataset() 326 | dataset = RepresentationVisualDataset("fidget_spinning") 327 | print("Number of images in the dataset: {}\n".format(len(dataset))) 328 | 329 | dataloader = DataLoader(dataset, batch_size = 128, shuffle = False, pin_memory = True, num_workers = 24) 330 | 331 | psum = torch.tensor([0.0, 0.0, 0.0]) 332 | psum_sq = torch.tensor([0.0, 0.0, 0.0]) 333 | 334 | for images, actions in tqdm(dataloader): 335 | psum += images.sum(axis = [0, 2, 3]) 336 | psum_sq += (images ** 2).sum(axis = [0, 2, 3]) 337 | 338 | count = len(dataset) * 224 * 224 339 | 340 | total_mean = psum / count 341 | 
total_var = (psum_sq / count) - (total_mean ** 2) 342 | total_std = torch.sqrt(total_var) 343 | 344 | print("Mean: {}".format(total_mean)) 345 | print("Std: {}".format(total_std)) -------------------------------------------------------------------------------- /arm_models/deploy/deploy.py: -------------------------------------------------------------------------------- 1 | # Standard imports 2 | import os 3 | import numpy as np 4 | import torch 5 | import cv2 6 | from PIL import Image as PILImage 7 | 8 | # ROS imports 9 | import rospy 10 | from sensor_msgs.msg import Image, JointState 11 | from visualization_msgs.msg import Marker 12 | from cv_bridge import CvBridge, CvBridgeError 13 | 14 | # Controller imports 15 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL 16 | from move_dexarm import DexArmControl 17 | 18 | # Model imports 19 | from arm_models.imitation.networks import * 20 | from arm_models.deploy.model_scripts import inn, visual_inn 21 | from arm_models.utils.augmentation import augmentation_generator 22 | 23 | # Miscellaneous imports 24 | from copy import deepcopy as copy 25 | import argparse 26 | 27 | # Other torch imports 28 | from torchvision import transforms as T 29 | 30 | # ROS Topic to get data 31 | ROBOT_IMAGE_TOPIC = "/cam_1/color/image_raw" 32 | AR_MARKER_TOPIC = "/visualization_marker" 33 | HAND_JOINT_STATE_TOPIC = "/allegroHand/joint_states" 34 | ARM_JOINT_STATE_TOPIC = "/j2n6s300_driver/out/joint_state" 35 | 36 | DATA_PATH = os.path.join(os.path.abspath(os.pardir), "data") 37 | URDF_PATH = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf" 38 | 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('-t', '--task', type=str) 41 | parser.add_argument('-m', '--model', type=str) 42 | parser.add_argument('-d', '--device', type=int) 43 | parser.add_argument('--delta', type=int) 44 | parser.add_argument('--run', type=int) 45 | 46 | # Specific to INN or VINN 47 | parser.add_argument('--k', type=int) 48 | parser.add_argument('-p', '--target_priority', type=int) 49 | 50 | 51 | class DexArmDeploy(): 52 | def __init__(self, task, model, run, k = None, object_priority = 1, device = "cpu", threshold = 0.02): 53 | # Initializing ROS node for the dexterous arm 54 | print("Initializing ROS node.\n") 55 | rospy.init_node("deploy_node") 56 | 57 | self.task = task 58 | self.model_to_use = model 59 | self.threshold = threshold 60 | self.device = device 61 | self.k = k 62 | 63 | # Initializing the dexterous arm controller 64 | print("Initializing the controller for the dexterous arm.\n") 65 | self.arm = DexArmControl() 66 | 67 | # Positioning robot based on the task 68 | print("Positioning the dexterous arm to perform task: {}\n".format(task)) 69 | # if task == "rotate" or task == "flip": 70 | # self.arm.home_robot() 71 | # elif task == "spin": 72 | # self.arm.spin_pos_arm() 73 | 74 | # Initializing the Inverse Kinematics solver for the Allegro Hand 75 | self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = URDF_PATH) 76 | 77 | # Initializing the model for the specific task 78 | if self.model_to_use == "inn": 79 | print("Initializing INN for task {} with k = {}.\n".format(task, k)) 80 | self.model = self._init_INN(task, k, object_priority, device) 81 | elif self.model_to_use == "vinn": 82 | print("Initializing VINN for task {} with k = {}.\n".format(task, k)) 83 | self.model, self.learner = self._init_VINN(task, k, device, run) 84 | elif self.model_to_use == "mlp": 85 | print("Initializing a Simple MLP model for task {}.\n".format(task)) 
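        # Note on the state/action interface assumed here: the MLP (arm_models/imitation/networks.py) maps a 15-D
        # state, the index, middle, ring and thumb finger-tip xyz coordinates concatenated with the 3-D object
        # position, to a 12-D action of target finger-tip coordinates; the mlp branch of start() below assembles
        # its input in that same order before querying the model.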
86 | self.model = self._init_MLP(task, device, run) 87 | elif self.model_to_use == "bc": 88 | print("Initializing a Behavior Cloning model for task {}.\n".format(task)) 89 | self.model = self._init_BC(task, device, run) 90 | 91 | # Loading image transformer for image based models 92 | if self.model_to_use == "vinn" or self.model_to_use == "bc": 93 | if self.task == "rotate": 94 | self.image_transform = T.Compose([ 95 | T.ToTensor(), 96 | T.Resize((224, 224)), 97 | T.Normalize( 98 | mean = torch.tensor([0.4631, 0.4923, 0.5215]), 99 | std = torch.tensor([0.2891, 0.2674, 0.2535]) 100 | ) 101 | ]) 102 | elif self.task == "flip": 103 | self.image_transform = T.Compose([ 104 | T.ToTensor(), 105 | T.Resize((224, 224)), 106 | T.Normalize( 107 | mean = torch.tensor([0.4534, 0.3770, 0.3885]), 108 | std = torch.tensor([0.2512, 0.1881, 0.2599]) 109 | ) 110 | ]) 111 | elif self.task == "spin": 112 | self.image_transform = T.Compose([ 113 | T.ToTensor(), 114 | T.Resize((224, 224)), 115 | T.Normalize( 116 | mean = torch.tensor([0.4306, 0.3954, 0.3472]), 117 | std = torch.tensor([0.2897, 0.2527, 0.2321]) 118 | ) 119 | ]) 120 | 121 | # Realtime data obtained through ROS 122 | self.allegro_joint_state = None 123 | self.kinova_joint_state = None 124 | 125 | self.bridge = CvBridge() 126 | self.robot_image = None 127 | 128 | rospy.Subscriber(HAND_JOINT_STATE_TOPIC, JointState, self._callback_allegro_joint_state, queue_size = 1) 129 | rospy.Subscriber(ARM_JOINT_STATE_TOPIC, JointState, self._callback_kinova_joint_state, queue_size = 1) 130 | rospy.Subscriber(ROBOT_IMAGE_TOPIC, Image, self._callback_robot_image, queue_size=1) 131 | 132 | if self.task == "rotate" or self.task == "flip": 133 | rospy.Subscriber(AR_MARKER_TOPIC, Marker, self._callback_ar_marker_data, queue_size = 1) 134 | 135 | if self.task == "rotate": 136 | self.object_tracked_position = None 137 | self.hand_base_position = None 138 | elif self.task == "flip": 139 | self.object_tracked_position = None 140 | 141 | def _callback_allegro_joint_state(self, data): 142 | self.allegro_joint_state = data 143 | 144 | def _callback_kinova_joint_state(self, data): 145 | self.kinova_joint_state = data 146 | 147 | def _callback_robot_image(self, image): 148 | try: 149 | self.robot_image = self.bridge.imgmsg_to_cv2(image, "bgr8") 150 | except CvBridgeError as e: 151 | print(e) 152 | 153 | def _callback_ar_marker_data(self, data): 154 | if data.id == 0 or data.id == 5: 155 | self.object_tracked_position = np.array([data.pose.position.x, data.pose.position.y, data.pose.position.z]) 156 | elif data.id == 8: 157 | self.hand_base_position = np.array([data.pose.position.x, data.pose.position.y, data.pose.position.z]) 158 | 159 | def _get_object_position(self): 160 | if self.task == "rotate": 161 | return self.object_tracked_position - self.hand_base_position 162 | elif self.task == "flip": 163 | return self.object_tracked_position 164 | 165 | def _init_INN(self, task, k, object_priority, device): 166 | # Getting the path to load the data 167 | if task == "rotate": 168 | folder = "cube_rotation" 169 | elif task == "flip": 170 | folder = "object_flipping" 171 | elif task == "spin": 172 | folder = "fidget_spinning" 173 | task_data_path = os.path.join(DATA_PATH, folder, "complete") 174 | 175 | # Initializing the model based on k to enable debug 176 | return inn.INNDeploy( 177 | k = k, 178 | data_path = task_data_path, 179 | load_image_data = True if k == 1 else False, 180 | target_priority = object_priority, 181 | device = device 182 | ) 183 | 184 | def _init_VINN(self, 
task, k, device, run): 185 | # Loading the checkpoint based on the task 186 | chkpt_path = os.path.join(os.getcwd(), "checkpoints", "representation_byol - {} - lowest - train - v{}.pth".format(task, run)) 187 | 188 | # Loading the representation encoder 189 | original_encoder_model = models.resnet50(pretrained = True) 190 | encoder = torch.nn.Sequential(*(list(original_encoder_model.children())[:-1])) 191 | encoder = encoder.to(device) 192 | 193 | learner = BYOL ( 194 | encoder, 195 | image_size = 224 196 | ) 197 | learner.load_state_dict(torch.load(chkpt_path)) 198 | learner.eval() 199 | 200 | # Loading the VINN model 201 | model = visual_inn.VINNDeploy( 202 | k = k, 203 | load_image_data = True if k == 1 else False, 204 | device = device 205 | ) 206 | 207 | return model, learner 208 | 209 | def _init_MLP(self, task, device, run): 210 | # Loading the checkpoint based on the task 211 | chkpt_path = os.path.join(os.getcwd(), "checkpoints", "mlp - {} - lowest - train - v{}.pth".format(task, run)) 212 | 213 | model = MLP().to(torch.device(device)) 214 | model.load_state_dict(torch.load(chkpt_path)) 215 | model.eval() 216 | return model 217 | 218 | def _init_BC(self, task, device, run): 219 | # Loading the checkpoint based on the task 220 | chkpt_path = os.path.join(os.getcwd(), "checkpoints", "bc - {} - lowest - train - v{}.pth".format(task, run)) 221 | 222 | model = BehaviorCloning().to(torch.device(device)) 223 | model.load_state_dict(torch.load(chkpt_path)) 224 | model.eval() 225 | return model 226 | 227 | def _get_tip_coords(self): 228 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0] 229 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0] 230 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0] 231 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0] 232 | 233 | return index_coord, middle_coord, ring_coord, thumb_coord 234 | 235 | def _update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord): 236 | current_joint_angles = list(self.allegro_joint_state.position) 237 | 238 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4]) 239 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8]) 240 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12]) 241 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16]) 242 | 243 | desired_joint_angles = copy(current_joint_angles) 244 | 245 | for idx in range(4): 246 | desired_joint_angles[idx] = index_joint_angles[idx] 247 | desired_joint_angles[4 + idx] = middle_joint_angles[idx] 248 | desired_joint_angles[8 + idx] = ring_joint_angles[idx] 249 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx] 250 | 251 | return desired_joint_angles 252 | 253 | def start(self, time_loop = False): 254 | print("\nDeploying model: {}\n".format(self.model_to_use)) 255 | 256 | if time_loop is True: 257 | rate = rospy.Rate(2) # 2 sec sleep duration 258 | 259 | while True: 260 | if time_loop is False: 261 | next_step = input() 262 | 263 | # Checking if all the data streams are working 264 | if self.allegro_joint_state is None: 265 | 
print('No allegro state received!') 266 | continue 267 | 268 | if self.robot_image is None: 269 | print('No robot image received!') 270 | continue 271 | 272 | if self.task == "rotate": 273 | if self.object_tracked_position is None: 274 | print("Object cannot be tracked!") 275 | continue 276 | 277 | if self.hand_base_position is None: 278 | print("Hand base position not found!") 279 | continue 280 | 281 | if self.task == "flip": 282 | if self.object_tracked_position is None: 283 | continue 284 | 285 | # Performing the step 286 | print("********************************************************\n Starting new step \n********************************************************") 287 | 288 | # Displaying the current state data 289 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self._get_tip_coords() 290 | finger_state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) 291 | print("Current finger-tip positions:\n Index-tip: {}\n Middle-tip: {}\n Ring-tip: {}\n Thumb-tip: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord)) 292 | 293 | if self.k != 1: 294 | cv2.imshow("Current Robot State Image", self.robot_image) 295 | cv2.waitKey(1) 296 | 297 | if self.task == "rotate" or self.task == "flip": 298 | object_position = list(self._get_object_position()) 299 | print("Object is located at position: {}".format(object_position)) 300 | elif self.task == "spin": 301 | object_position = [0, 0, 0] 302 | 303 | if self.model_to_use == "vinn" or self.model_to_use == "bc": 304 | pil_image = PILImage.fromarray(cv2.cvtColor(self.robot_image, cv2.COLOR_BGR2RGB)) 305 | if self.task == "rotate": 306 | cropped_image = pil_image.crop((500, 160, 950, 600)) # Left, Top, Right, Bottom 307 | elif self.task == "flip": # TODO 308 | cropped_image = pil_image.crop((220, 165, 460, 340)) # Left, Top, Right, Bottom 309 | elif self.task == "spin": # TODO 310 | cropped_image = pil_image.crop((65, 80, 590, 480)) # Left, Top, Right, Bottom 311 | 312 | transformed_image_tensor = self.image_transform(cropped_image).unsqueeze(0) 313 | 314 | # Obtaining the actions through each model 315 | if self.model_to_use == "inn": 316 | if self.k == 1: 317 | print("\nObtaining the action!\n") 318 | action, neighbor_index, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.get_action_with_image(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position) 319 | action = list(action.reshape(12)) 320 | 321 | l2_diff = np.linalg.norm(np.array(finger_state) - np.array(action)) 322 | 323 | if l2_diff >= self.threshold: 324 | print("Distances:\n Index distance: {}\n Middle distance: {}\n Ring distance: {}\n Thumb distance: {}\n Cube distance: {}".format(index_l2_diff.item(), middle_l2_diff.item(), ring_l2_diff.item(), thumb_l2_diff.item(), object_l2_diff.item())) 325 | 326 | while l2_diff < self.threshold: 327 | print("Action distance is below threshold: {}".format(l2_diff)) 328 | action, neighbor_index, nn_state_image_path, trans_state_image_path = self.model.get_debug_action(neighbor_index) 329 | # action, neighbor_index, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.get_action_with_image(action[:3], action[3:6], action[6:9], action[9:12], object_position) 330 | action = list(action.reshape(12)) 331 | print("Primary neighbor index:", neighbor_index) 332 | 333 | l2_diff = 
np.linalg.norm(np.array(finger_state) - np.array(action)) 334 | 335 | # Reading the Nearest Neighbor images 336 | nn_state_image = cv2.imread(nn_state_image_path) 337 | trans_state_image = cv2.imread(trans_state_image_path) 338 | 339 | # Combing the surrent state image and NN image and printing them 340 | combined_images = np.concatenate(( 341 | cv2.resize(self.robot_image, (420, 240), interpolation = cv2.INTER_AREA), 342 | cv2.resize(nn_state_image, (420, 240), interpolation = cv2.INTER_AREA), 343 | cv2.resize(trans_state_image, (420, 240), interpolation = cv2.INTER_AREA) 344 | ), axis=1) 345 | 346 | cv2.imshow("Current state and Nearest Neighbor Images", combined_images) 347 | cv2.waitKey(1) 348 | 349 | else: 350 | print("Obtaining the action!\n") 351 | action = list(self.model.get_action(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position).reshape(12)) 352 | l2_diff = np.linalg.norm(np.array(finger_state) - np.array(action)) 353 | 354 | while l2_diff < self.threshold: 355 | print("Action distance is below threshold: {}".format(l2_diff)) 356 | action = list(self.model.get_action(action[:3], action[3:6], action[6:9], action[9:12], object_position).reshape(12)) 357 | l2_diff = np.linalg.norm(np.array(finger_state) - np.array(action)) 358 | 359 | elif self.model_to_use == "vinn": 360 | print("Image shape:", transformed_image_tensor.float().to(self.device).shape) 361 | representation = self.learner.net(transformed_image_tensor.float().to(self.device)).squeeze().detach().cpu().numpy() 362 | print("Representation shape:", representation.shape) 363 | if self.k == 1: 364 | print("Obtaining the action!\n") 365 | action, neighbor_idx, nn_state_image_path, trans_state_image_path = self.model.get_action_with_image(representation) 366 | action = list(action.reshape(12)) 367 | 368 | l2_diff = np.linalg.norm(np.array(finger_state[:12]) - np.array(action)) 369 | 370 | while l2_diff < self.threshold: 371 | print("Action distance is below threshold. 
Rolling out another action from the same trajectory!\n") 372 | action, neighbor_idx, nn_state_image_path, trans_state_image_path = self.model.get_debug_action(neighbor_idx) 373 | action = list(action.reshape(12)) 374 | 375 | l2_diff = np.linalg.norm(np.array(finger_state[:12]) - np.array(action)) 376 | 377 | # Reading the Nearest Neighbor images 378 | nn_state_image = cv2.imread(nn_state_image_path) 379 | trans_state_image = cv2.imread(trans_state_image_path) 380 | 381 | # Combing the surrent state image and NN image and printing them 382 | combined_images = np.concatenate(( 383 | cv2.resize(self.robot_image, (420, 240), interpolation = cv2.INTER_AREA), 384 | cv2.resize(nn_state_image, (420, 240), interpolation = cv2.INTER_AREA), 385 | cv2.resize(trans_state_image, (420, 240), interpolation = cv2.INTER_AREA) 386 | ), axis=1) 387 | 388 | cv2.imshow("Current state and Nearest Neighbor Images", combined_images) 389 | cv2.waitKey(1) 390 | 391 | else: 392 | action = list(self.model.get_action(representation).reshape(12)) 393 | 394 | elif self.model_to_use == "mlp": 395 | input_coordinates = torch.tensor(np.concatenate([index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position])) 396 | action = self.model(input_coordinates.float().to(self.device)).detach().cpu().numpy()[0] 397 | 398 | elif self.model_to_use == "bc": 399 | action = self.model(transformed_image_tensor.float().to(self.device)).detach().cpu().numpy()[0] 400 | 401 | # Updating the hand target coordinates 402 | updated_index_tip_coord = action[0:3] 403 | updated_middle_tip_coord = action[3:6] 404 | updated_ring_tip_coord = action[6:9] 405 | updated_thumb_tip_coord = action[9:12] 406 | 407 | print("Corresponding finger-tip action:\n Index-tip: {}\n Middle-tip: {}\n Ring-tip: {}\n Thumb-tip: {}\n".format(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord)) 408 | 409 | desired_joint_angles = self._update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord) 410 | 411 | print("Moving arm to {}\n".format(desired_joint_angles)) 412 | self.arm.move_hand(desired_joint_angles) 413 | 414 | if time_loop is True: 415 | rate.sleep() 416 | 417 | 418 | if __name__ == '__main__': 419 | # Getting options 420 | options = parser.parse_args() 421 | 422 | d = DexArmDeploy( 423 | task = options.task, 424 | model = options.model, 425 | run = options.run, 426 | k = options.k, 427 | object_priority = options.target_priority, 428 | device = "cpu", 429 | threshold = 0.01 * options.delta 430 | ) 431 | 432 | d.start() --------------------------------------------------------------------------------
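A usage sketch for the deploy entry point above (the exact command is an assumption; only the flags are taken from the argparse definition in deploy.py): the script is driven by --task, --model, --run, --k, --target_priority and --delta, where --delta scales the 0.01 action-distance threshold passed to DexArmDeploy, and --device is parsed but the script currently pins the device to "cpu". A cube-rotation run with the nearest-neighbor model might therefore be launched as

    python arm_models/deploy/deploy.py --task rotate --model inn --run 1 --k 1 --target_priority 4 --delta 2

with all flag values shown being illustrative.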