├── arm_models
│   ├── __init__.py
│   ├── utils
│   │   ├── augmentation.py
│   │   ├── representation_extractor.py
│   │   └── load_data.py
│   ├── imitation
│   │   ├── networks.py
│   │   └── non_parametric.py
│   ├── dataloaders
│   │   ├── state_dataset.py
│   │   └── visual_dataset.py
│   ├── deploy
│   │   ├── evaluation
│   │   │   └── performance_test_inn.py
│   │   ├── model_scripts
│   │   │   ├── visual_inn.py
│   │   │   └── inn.py
│   │   ├── deploy_scripts
│   │   │   ├── mlp_deploy.py
│   │   │   ├── bc_deploy.py
│   │   │   ├── vinn_deploy.py
│   │   │   └── inn_deploy.py
│   │   └── deploy.py
│   └── train.py
├── .gitignore
├── setup.py
└── README.md
/arm_models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Data
2 | arm_models/data/*
3 |
4 | # Python cache files
5 | __pycache__/
6 |
7 | # Losses and Graphs
8 | arm_models/deploy/evaluation/losses/*
9 | arm_models/deploy/evaluation/graphs/*
10 |
11 | # Model checkpoints
12 | arm_models/deploy/checkpoints/*
13 |
14 | # Wandb logs
15 | arm_models/wandb/*
16 |
17 | # Other installation files
18 | dexterous_arm_models.egg-info/*
19 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from setuptools import setup, find_packages
4 |
5 | print("Installing Dexterous Arm Hardware models package!.")
6 |
7 | if sys.version_info.major != 3:
8 | print("This Python is only compatible with Python 3, but you are running "
9 | "Python {}. The installation will likely fail.".format(sys.version_info.major))
10 |
11 | def read(fname):
12 | return open(os.path.join(os.path.dirname(__file__), fname)).read()
13 |
14 | setup(
15 | name='dexterous_arm_models',
16 | version='1.0.0',
17 | packages=find_packages(),
18 | description='Models that can be deployed on the dexterous arm.',
19 | long_description=read('README.md'),
20 | url='https://github.com/NYU-robot-learning/Dexterous-Arm-Models',
21 | author='Sridhar Pandian',
22 | )
--------------------------------------------------------------------------------
/arm_models/utils/augmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms as T
3 |
4 | def augmentation_generator(task):
5 | if task == "rotate":
6 | norm_values = T.Normalize(
7 | mean = torch.tensor([0.3484, 0.3638, 0.3819]),
8 | std = torch.tensor([0.3224, 0.3151, 0.3166])
9 | )
10 | elif task == "flip":
11 | norm_values = T.Normalize(
12 | mean = torch.tensor([0, 0, 0]),
13 | std = torch.tensor([1, 1, 1])
14 | )
15 | elif task == "spin":
16 | norm_values = T.Normalize(
17 | mean = torch.tensor([0, 0, 0]),
18 | std = torch.tensor([1, 1, 1])
19 | )
20 |     else:
21 |         raise ValueError("Unknown task: {}".format(task))
22 |
23 |     augment_custom = T.Compose([
24 |         T.RandomResizedCrop(224, scale = (0.6, 1)),
25 |         T.RandomApply(torch.nn.ModuleList([T.ColorJitter(.8,.8,.8,.2)]), p=.3),
26 |         T.RandomGrayscale(p = 0.2),
27 |         T.RandomApply(torch.nn.ModuleList([T.GaussianBlur((3, 3), (1.0, 2.0))]), p=0.2),
28 |         norm_values
29 |     ])
30 |
31 |     return augment_custom
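32 |
33 | # Minimal usage sketch (not part of the original file): applies the "rotate"
34 | # augmentation pipeline to a random image tensor. The 256x256 input size is an
35 | # assumption; RandomResizedCrop produces a 224x224 output either way.
36 | if __name__ == '__main__':
37 |     augment = augmentation_generator("rotate")
38 |     dummy_image = torch.rand(3, 256, 256)
39 |     print(augment(dummy_image).shape)  # torch.Size([3, 224, 224])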
--------------------------------------------------------------------------------
/arm_models/imitation/networks.py:
--------------------------------------------------------------------------------
1 | # General torch imports
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | # Imports for pretrained encoders
7 | from torchvision import models
8 |
9 | # Import for Representation learners
10 | from byol_pytorch import BYOL
11 |
12 | # State based Behavior Cloning
13 | class MLP(nn.Module):
14 | def __init__(self):
15 | super(MLP, self).__init__()
16 | self.fc_1 = nn.Linear(15, 128)
17 | self.fc_2 = nn.Linear(128, 512)
18 | self.fc_3 = nn.Linear(512, 512)
19 | self.fc_4 = nn.Linear(512, 128)
20 | self.fc_5 = nn.Linear(128, 12)
21 | self.batch_norm = nn.BatchNorm1d(512)
22 |
23 | def forward(self, x):
24 | x = x.view(-1, 15)
25 | x = F.leaky_relu(self.fc_1(x))
26 | x = F.leaky_relu(self.fc_2(x))
27 | x = self.fc_3(x)
28 | x = F.leaky_relu(self.batch_norm(x))
29 | x = F.leaky_relu(self.fc_4(x))
30 | x = self.fc_5(x)
31 | return x
32 |
33 | # Visual Behavior Cloning
34 | class BehaviorCloning(nn.Module):
35 | def __init__(self):
36 | super(BehaviorCloning, self).__init__()
37 | # Encoder
38 | self.encoder = models.resnet50(pretrained = True)
39 | # Fully Connected Regression Layer
40 | self.fc1 = nn.Linear(1000, 1000)
41 | self.fc2 = nn.Linear(1000, 1024)
42 | self.fc3 = nn.Linear(1024, 12)
43 |
44 | def forward(self, x):
45 | x = self.encoder(x)
46 | x = self.fc1(x)
47 | x = F.leaky_relu(x)
48 | x = self.fc2(x)
49 | x = F.leaky_relu(x)
50 | x = self.fc3(x)
51 | return x
52 |
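53 | # Minimal smoke test (not part of the original file). The 15-dim state (four 3-d
54 | # fingertip positions plus a 3-d object position) and the 12-dim action (four 3-d
55 | # fingertip targets) are inferred from the deploy scripts in this repository.
56 | if __name__ == '__main__':
57 |     mlp = MLP()
58 |     print(mlp(torch.randn(8, 15)).shape)          # torch.Size([8, 12])
59 |
60 |     bc = BehaviorCloning()                        # downloads ResNet-50 weights on first use
61 |     print(bc(torch.randn(8, 3, 224, 224)).shape)  # torch.Size([8, 12])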
--------------------------------------------------------------------------------
/arm_models/dataloaders/state_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 |
5 | from arm_models.utils.load_data import load_state_actions
6 |
7 | class CubeRotationDataset(Dataset):
8 | def __init__(self, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/cube_rotation/complete'):
9 | self.data_path = data_path
10 | self.states, self.actions = load_state_actions(data_path)
11 |
12 | def __len__(self):
13 | return len(self.states)
14 |
15 | def __getitem__(self, idx):
16 | if torch.is_tensor(idx):
17 | idx = idx.tolist()
18 |
19 | return self.states[idx], self.actions[idx]
20 |
21 | class ObjectFlippingDataset(Dataset):
22 | def __init__(self, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/object_flipping/complete'):
23 | self.data_path = data_path
24 | self.states, self.actions = load_state_actions(data_path)
25 |
26 | def __len__(self):
27 | return len(self.states)
28 |
29 | def __getitem__(self, idx):
30 | if torch.is_tensor(idx):
31 | idx = idx.tolist()
32 |
33 | return self.states[idx], self.actions[idx]
34 |
35 | class FidgetSpinningDataset(Dataset):
36 | def __init__(self, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning/complete'):
37 | self.data_path = data_path
38 | self.states, self.actions = load_state_actions(data_path)
39 |
40 | def __len__(self):
41 | return len(self.states)
42 |
43 | def __getitem__(self, idx):
44 | if torch.is_tensor(idx):
45 | idx = idx.tolist()
46 |
47 | return self.states[idx], self.actions[idx]
48 |
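49 | # Minimal usage sketch (not part of the original file); assumes the default data
50 | # path above exists and contains the states/ and actions/ folders expected by
51 | # load_state_actions.
52 | if __name__ == '__main__':
53 |     from torch.utils.data import DataLoader
54 |     dataset = CubeRotationDataset()
55 |     loader = DataLoader(dataset, batch_size = 32, shuffle = True)
56 |     states, actions = next(iter(loader))
57 |     print(states.shape, actions.shape)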
--------------------------------------------------------------------------------
/arm_models/utils/representation_extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torchvision import models
4 |
5 | from tqdm import tqdm
6 |
7 | from arm_models.dataloaders.visual_dataset import *
8 |
9 | from byol_pytorch import BYOL
10 |
11 | REP_TENSOR_PATH = '/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning/complete/representations/representations.pth'
12 | REP_MODEL_CHKPT_PATH = '/home/sridhar/dexterous_arm/models/arm_models/deploy/checkpoints/representation_byol - spin - lowest - train - v1.pth'
13 |
14 | def extract_representations(device, CHKPT_PATH):
15 | # Loading the dataset and creating the dataloader
16 | # dataset = CubeRotationVisualDataset()
17 | # dataset = ObjectFlippingVisualDataset()
18 | dataset = FidgetSpinningVisualDataset()
19 |
20 | dataloader = DataLoader(dataset, batch_size = 1, shuffle = False, pin_memory = True, num_workers = 24)
21 |
22 | # Loading the model
23 | original_encoder_model = models.resnet50(pretrained = True)
24 | encoder = torch.nn.Sequential(*(list(original_encoder_model.children())[:-1]))
25 | encoder = encoder.to(device)
26 |
27 |     learner = BYOL(
28 | encoder,
29 | image_size = 224
30 | )
31 | learner.load_state_dict(torch.load(CHKPT_PATH))
32 | learner.eval()
33 |
34 | representations = []
35 |
36 | # Extracting the representations
37 | for image, action in tqdm(dataloader):
38 | representation = learner.net(image.float().to(device)).squeeze()
39 | representations.append(representation.detach().cpu())
40 |
41 | representation_tensor = torch.stack(representations)
42 | print("Final representation tensor shape:", representation_tensor.squeeze().shape)
43 |
44 | return representation_tensor.squeeze(1)
45 |
46 | def store_representations(representations, DATA_PATH):
47 | torch.save(representations, DATA_PATH)
48 |
49 | if __name__ == '__main__':
50 | # Selecting the GPU to be used
51 | gpu_number = int(input('Enter the GPU number: '))
52 | device = torch.device('cuda:{}'.format(gpu_number))
53 |
54 | print('Using GPU: {} for extracting the representations. \n'.format(torch.cuda.get_device_name(gpu_number)))
55 |
56 | representation_tensor = extract_representations(device, CHKPT_PATH = REP_MODEL_CHKPT_PATH)
57 | store_representations(representation_tensor, REP_TENSOR_PATH)
--------------------------------------------------------------------------------
/arm_models/deploy/evaluation/performance_test_inn.py:
--------------------------------------------------------------------------------
1 | # Performance evaluation of the INN model on held-out demonstration data
2 | import torch
3 | import pickle
4 |
5 | import matplotlib.pyplot as plt
6 |
7 | from IPython import embed
8 |
9 | from arm_models.imitation.non_parametric import INN
10 | from arm_models.utils.load_data import *
11 |
12 | DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/cube_rotation/for_eval"
13 |
14 | PERFORMANCE_METRICS_PATH = "/home/sridhar/dexterous_arm/models/arm_models/deploy/evaluation/losses/INN_losses"
15 |
16 | def performance_test_inn(inn, k, test_states, test_actions):
17 | loss = 0
18 |
19 | # Testing the model
20 | for idx, test_state in enumerate(test_states):
21 | if k == 1:
22 | obtained_action, neighbor_index, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, cube_l2_diff = inn.find_optimum_action(test_state, k)
23 | else:
24 | obtained_action = inn.find_optimum_action(test_state, k)
25 |
26 | action_loss = torch.norm(obtained_action - test_actions[idx])
27 |
28 | loss += action_loss
29 |
30 | normalized_loss = loss / len(test_states)
31 |
32 | print("The testing loss is {}\n".format(normalized_loss))
33 |
34 | return normalized_loss
35 |
36 | if __name__ == "__main__":
37 | k_losses = []
38 |
39 | # Getting the state and action pairs
40 | train_states, train_actions, test_states, test_actions = load_train_test_data(DATA_PATH)
41 |
42 | # Checking if the state action pairs are valid
43 | assert len(test_states) == len(test_actions), "The number of states is not equal to the number of actions!"
44 | print("Total number of train state and action pairs: {}".format(len(train_states)))
45 | print("Total number of test state and action pairs: {}".format(len(test_states)))
46 |
47 | print("\nTesting the model for object priority values from 1 to 50 and finger-tip priority values from 1 to 10.\n")
48 |
49 | plt.figure()
50 |
51 |     # Computing the loss for each finger-priority (1-10) and object-priority (1-50) combination
52 | for finger_priority_idx in range(10):
53 | finger_losses = []
54 |
55 | for priority_idx in range(50):
56 | inn = INN(device="cuda", target_priority = priority_idx + 1, finger_priority = finger_priority_idx + 1)
57 |
58 | # Loading data into the model
59 | inn.get_data(train_states, train_actions)
60 |
61 | print("Computing loss for priority = {}".format(priority_idx + 1))
62 | computed_loss = performance_test_inn(inn, 1, test_states, test_actions)
63 |
64 | finger_losses.append(computed_loss)
65 |
66 | print("The minimum observed loss was: {} for object priority: {}\n".format(min(finger_losses), finger_losses.index(min(finger_losses)) + 1))
67 | plt.plot([x+1 for x in range(50)], finger_losses, label="finger_priority = {}".format(finger_priority_idx + 1))
68 | plt.legend()
69 |
70 |
71 |
72 | # Saving the losses in a pickle file
73 | # loss_file = open(PERFORMANCE_METRICS_PATH, "ab")
74 | # pickle.dump(k_losses, loss_file)
75 | # loss_file.close()
76 |
77 | # Plotting the k-value based losses
78 | plt.xlabel("Object prioritized weights")
79 |
80 | plt.ylabel("Test loss")
81 | plt.show()
82 | # embed()
83 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dexterous Imitation Made Easy
2 | **Authors**: Sridhar Pandian Arunachalam*, [Sneha Silwal](http://ssilwal.com/)*, [Ben Evans](https://bennevans.github.io/), [Lerrel Pinto](https://lerrelpinto.com)
3 |
4 | This is the official implementation of the paper [Dexterous Imitation Made Easy](https://arxiv.org/abs/2203.13251).
5 |
6 | ## Policies on Real Robot
7 |
8 |
9 |
10 |
11 |
12 |
13 | ## Method
14 | 
15 | DIME consists of two phases: demonstration collection, which is performed in real time with visual feedback, and demonstration-based policy learning, which can learn to solve dexterous tasks from a limited number of demonstrations.
16 |
17 | ## Setup
18 | The code base is split into five separate packages for convenience; this repository is one of the five. You can clone and set up each package by following the instructions in its respective repository. The packages are:
19 | - [Robot controller packages](https://github.com/NYU-robot-learning/DIME-Controllers):
20 | - [Allegro Hand Controller](https://github.com/NYU-robot-learning/Allegro-Hand-Controller-DIME).
21 | - [Kinova Arm Controller](https://github.com/NYU-robot-learning/Kinova-Arm-Controller-DIME).
22 | - [Camera packages](https://github.com/NYU-robot-learning/DIME-Camera-Packages)
23 | - [Realsense-ROS](https://github.com/NYU-robot-learning/Realsense-ROS-DIME).
24 | - [AR_Tracker_Alvar](https://github.com/ros-perception/ar_track_alvar).
25 | - (Phase 1) Demonstration collection packages:
26 | - [Teleop with Inverse Kinematics Package](https://github.com/NYU-robot-learning/DIME-IK-TeleOp).
27 | - [State based and Image based Demonstration collection package](https://github.com/NYU-robot-learning/DIME-Demonstrations).
28 | - (Phase 2) Nearest-Neighbor Imitation Learning (present in this repository).
29 | - Simulation [environments](https://github.com/NYU-robot-learning/dime_env) and DAPG related codebase.
30 |
31 | You need to set up the Controller packages and the IK-TeleOp package before using this package.
32 | To install the dependencies for this package with `pip`:
33 | ```
34 | pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
35 | ```
36 | Then install this package with:
37 | ```
38 | pip3 install -e .
39 | ```
40 |
41 | ## Data
42 | All our data can be found in this URL: [https://drive.google.com/drive/folders/1nunGHB2EK9xvlmepNNziDDbt-pH8OAhi](https://drive.google.com/drive/folders/1nunGHB2EK9xvlmepNNziDDbt-pH8OAhi)
43 |
44 | ## Citation
45 |
46 | If you use this repo in your research, please consider citing the paper as follows:
47 | ```
48 | @article{arunachalam2022dime,
49 | title={Dexterous Imitation Made Easy: A Learning-Based Framework for Efficient Dexterous Manipulation},
50 | author={Sridhar Pandian Arunachalam and Sneha Silwal and Ben Evans and Lerrel Pinto},
51 | journal={arXiv preprint arXiv:2203.13251},
52 | year={2022}
53 | }
54 | ```
--------------------------------------------------------------------------------
/arm_models/utils/load_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | def load_train_test_data(data_path = os.path.join(os.path.abspath(os.pardir), "data")):
5 | train_states, train_actions = load_state_actions(os.path.join(data_path, "train"))
6 | test_states, test_actions = load_state_actions(os.path.join(data_path, "validation"))
7 | return train_states, train_actions, test_states, test_actions
8 |
9 | def load_state_actions(data_path):
10 | # Getting the paths which contain the states and actions for each demo
11 | states_path, actions_path= os.path.join(data_path, "states"), os.path.join(data_path, "actions")
12 |
13 | # Getting the list of all demo states and actions
14 | states_demos_list, actions_demos_list = os.listdir(states_path), os.listdir(actions_path)
15 |     assert len(states_demos_list) == len(actions_demos_list), "The number of state demo files does not match the number of action demo files!"
16 |
17 | # Sorting the state action pairs based on the demo numbers
18 | states_demos_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
19 | actions_demos_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
20 |
21 | states, actions = [], []
22 |
23 | for demo_idx in range(len(states_demos_list)):
24 | demo_states = torch.load(os.path.join(states_path, states_demos_list[demo_idx]))
25 | demo_actions = torch.load(os.path.join(actions_path, actions_demos_list[demo_idx]))
26 |         assert len(demo_states) == len(demo_actions), "Number of states: {} from {} and number of actions: {} from {} are not equal.\n".format(demo_states.shape[0], states_demos_list[demo_idx], demo_actions.shape[0], actions_demos_list[demo_idx])
27 |
28 | states.append(demo_states)
29 | actions.append(demo_actions)
30 |
31 | # Converting the lists into a tensor
32 | states, actions = torch.cat(states), torch.cat(actions)
33 |
34 | return states, actions
35 |
36 | def load_representation_actions(data_path):
37 |     # Getting the paths which contain the representations and actions for each demo
38 | representations_path, states_path, actions_path= os.path.join(data_path, "representations"), os.path.join(data_path, "states"), os.path.join(data_path, "actions")
39 |
40 | # Extracting the representations
41 | representations_file_name = os.listdir(representations_path)[0]
42 | representations = torch.load(os.path.join(representations_path, representations_file_name))
43 |
44 | # Getting the list of all demo states and actions
45 | actions_demos_list = os.listdir(actions_path)
46 |
47 | # Sorting the actions based on the demo numbers
48 | actions_demos_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
49 |
50 | actions = []
51 |
52 | for demo_idx in range(len(actions_demos_list)):
53 | demo_actions = torch.load(os.path.join(actions_path, actions_demos_list[demo_idx]))
54 | actions.append(demo_actions)
55 |
56 | # Converting the lists into a tensor
57 | actions = torch.cat(actions)
58 |
59 | return representations, actions
60 |
61 | def load_state_image_data(data_path):
62 | image_data_path = os.path.join(data_path, "images")
63 | demo_image_folders = os.listdir(image_data_path)
64 | demo_image_folders.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
65 |
66 | cumm_demos_image_count = []
67 |
68 | for demo_folder in demo_image_folders:
69 | demo_image_folder_path = os.path.join(image_data_path, demo_folder)
70 | demo_image_count = len(os.listdir(demo_image_folder_path))
71 |
72 | if len(cumm_demos_image_count) > 0:
73 | cumm_demos_image_count.append(cumm_demos_image_count[-1] + demo_image_count - 1)
74 | else:
75 | cumm_demos_image_count.append(demo_image_count)
76 |
77 | return image_data_path, demo_image_folders, cumm_demos_image_count
78 |
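79 | # Expected on-disk layout, inferred from the loaders above (file names are an
80 | # assumption; any names containing the demo number will sort correctly):
81 | #   <data_path>/states/demo_1.pth, demo_2.pth, ...    (tensors of per-step states)
82 | #   <data_path>/actions/demo_1.pth, demo_2.pth, ...   (tensors of per-step actions)
83 | #   <data_path>/representations/representations.pth   (one tensor covering all demos)
84 | #   <data_path>/images/demo_1/state_0.jpg, state_1.jpg, ...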
--------------------------------------------------------------------------------
/arm_models/deploy/model_scripts/visual_inn.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from arm_models.imitation.non_parametric import VINN
4 | from arm_models.utils.load_data import load_representation_actions, load_state_image_data
5 |
6 | DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning/complete"
7 |
8 | class VINNDeploy():
9 | def __init__(self, data_path = DATA_PATH, k = 1, load_image_data = True, device = "cpu"):
10 | self.k = k
11 | self.image_data_path = None
12 |
13 | self.representations, self.actions = load_representation_actions(data_path)
14 |
15 | self.model = VINN(device)
16 | self.model.get_data(self.representations, self.actions)
17 |
18 | if load_image_data is True:
19 | print(data_path)
20 | self.image_data_path, self.demo_image_folders, self.cumm_demos_image_count = load_state_image_data(data_path)
21 |
22 | def get_action(self, representation):
23 | if self.k == 1:
24 | action, neighbor_idx = self.model.find_optimum_action(representation, self.k)
25 | return action.detach().cpu().numpy()
26 | else:
27 | return self.model.find_optimum_action(representation, self.k).detach().cpu().numpy()
28 |
29 | def get_debug_action(self, neighbor_index):
30 | if neighbor_index < self.cumm_demos_image_count[0]:
31 | nn_image_num = neighbor_index
32 | nn_trans_image_num = nn_image_num + 1
33 | demo_num = 0
34 | else:
35 | for idx, cumm_demo_images in enumerate(self.cumm_demos_image_count):
36 | if neighbor_index + 1 > cumm_demo_images:
37 | # Obtaining the demo number
38 | demo_num = idx + 1
39 |
40 | # Getting the corresponding image number
41 | nn_image_num = neighbor_index - cumm_demo_images
42 | nn_trans_image_num = nn_image_num + 1
43 |
44 | nn_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_image_num.item()))
45 | trans_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_trans_image_num.item()))
46 |
47 |
48 | return self.actions[neighbor_index + 1], neighbor_index + 1, nn_state_image_path, trans_state_image_path
49 |
50 | def get_action_with_image(self, representation):
51 | if self.k == 1 and self.image_data_path is not None:
52 | # state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) + list(cube_pos)
53 | calculated_action, neighbor_index = self.model.find_optimum_action(representation, self.k)
54 |
55 | if neighbor_index < self.cumm_demos_image_count[0]:
56 | nn_image_num = neighbor_index
57 | nn_trans_image_num = nn_image_num + 1
58 | demo_num = 0
59 | else:
60 | for idx, cumm_demo_images in enumerate(self.cumm_demos_image_count):
61 | if neighbor_index + 1 > cumm_demo_images:
62 | # Obtaining the demo number
63 | demo_num = idx + 1
64 |
65 | # Getting the corresponding image number
66 | nn_image_num = neighbor_index - cumm_demo_images
67 | nn_trans_image_num = nn_image_num + 1
68 |
69 | nn_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_image_num.item()))
70 | trans_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_trans_image_num.item()))
71 |
72 | return calculated_action.cpu().detach().numpy(), neighbor_index, nn_state_image_path, trans_state_image_path
73 | else:
74 | return self.get_action(representation)
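75 |
76 | # Minimal usage sketch (not part of the original file); assumes representation
77 | # data exists at DATA_PATH and that representations are 2048-dim ResNet-50
78 | # embeddings, matching representation_extractor.py.
79 | if __name__ == '__main__':
80 |     import numpy as np
81 |     deployer = VINNDeploy(k = 1, load_image_data = False, device = "cpu")
82 |     dummy_representation = np.random.randn(2048).astype(np.float32)
83 |     print(deployer.get_action(dummy_representation))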
--------------------------------------------------------------------------------
/arm_models/deploy/model_scripts/inn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from arm_models.imitation.non_parametric import INN
5 | from arm_models.utils.load_data import load_state_actions, load_state_image_data
6 |
7 | class INNDeploy():
8 | def __init__(self, data_path, target_priority = 1, k = 1, load_image_data = False, device = "cpu"):
9 | self.k = k
10 | self.image_data_path = None
11 |
12 | self.states, self.actions = load_state_actions(data_path)
13 |
14 | self.model = INN(device, target_priority = target_priority)
15 | self.model.get_data(self.states, self.actions)
16 |
17 | if load_image_data is True:
18 | print("Loading data from path:", data_path)
19 | self.image_data_path, self.demo_image_folders, self.cumm_demos_image_count = load_state_image_data(data_path)
20 |
21 | def get_action(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position):
22 | state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) + list(object_position)
23 | if self.k == 1:
24 | action, neighbor_idx, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.find_optimum_action(state, self.k)
25 | return action.cpu().detach().numpy()
26 | else:
27 | return self.model.find_optimum_action(state, self.k).cpu().detach().numpy()
28 |
29 | def get_action_with_image(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position):
30 | if self.k == 1 and self.image_data_path is not None:
31 | state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) + list(object_position)
32 | calculated_action, neighbor_index, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.find_optimum_action(state, self.k)
33 |
34 | if neighbor_index < self.cumm_demos_image_count[0]:
35 | nn_image_num = neighbor_index
36 | nn_trans_image_num = nn_image_num + 1
37 | demo_num = 0
38 | else:
39 | for idx, cumm_demo_images in enumerate(self.cumm_demos_image_count):
40 | if neighbor_index + 1 > cumm_demo_images:
41 | # Obtaining the demo number
42 | demo_num = idx + 1
43 |
44 | # Getting the corresponding image number
45 | nn_image_num = neighbor_index - cumm_demo_images
46 | nn_trans_image_num = nn_image_num + 1
47 |
48 | nn_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_image_num.item()))
49 | trans_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_trans_image_num.item()))
50 |
51 | return calculated_action.cpu().detach().numpy(), neighbor_index, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff
52 |         else:
53 |             return self.get_action(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position)
54 |
55 | def get_debug_action(self, neighbor_index):
56 | if neighbor_index < self.cumm_demos_image_count[0]:
57 | nn_image_num = neighbor_index
58 | nn_trans_image_num = nn_image_num + 1
59 | demo_num = 0
60 | else:
61 | for idx, cumm_demo_images in enumerate(self.cumm_demos_image_count):
62 | if neighbor_index + 1 > cumm_demo_images:
63 | # Obtaining the demo number
64 | demo_num = idx + 1
65 |
66 | # Getting the corresponding image number
67 | nn_image_num = neighbor_index - cumm_demo_images
68 | nn_trans_image_num = nn_image_num + 1
69 |
70 | nn_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_image_num.item()))
71 | trans_state_image_path = os.path.join(self.image_data_path, self.demo_image_folders[demo_num], "state_{}.jpg".format(nn_trans_image_num.item()))
72 |
73 | return self.actions[neighbor_index + 1], neighbor_index + 1, nn_state_image_path, trans_state_image_path
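74 |
75 | # Minimal usage sketch (not part of the original file); "/path/to/task/complete"
76 | # is a placeholder for a data directory laid out as described in load_data.py.
77 | if __name__ == '__main__':
78 |     deployer = INNDeploy(data_path = "/path/to/task/complete", k = 1, device = "cpu")
79 |     action = deployer.get_action([0.1] * 3, [0.1] * 3, [0.1] * 3, [0.1] * 3, [0.1] * 3)
80 |     print(action)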
--------------------------------------------------------------------------------
/arm_models/imitation/non_parametric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class INN():
5 | def __init__(self, device="cpu", target_priority = 1, finger_priority = 1):
6 | # Setting the device
7 | self.device = torch.device(device)
8 |
9 | # Initializing the states and action pairs to None
10 | self.states = None
11 | self.actions = None
12 |
13 | self.target_priority = target_priority
14 | self.thumb_priority = finger_priority
15 | self.ring_priority = finger_priority
16 |
17 | def get_data(self, states, actions):
18 | self.states = states.to(self.device)
19 | self.actions = actions.to(self.device)
20 |
21 | def getNearestNeighbors(self, input_state, k):
22 | # Comparing the dataset shape and state
23 | assert torch.tensor(input_state).shape == self.states[0].shape, "There is a data shape mismatch: \n Shape of loaded dataset: {} \n Shape of current state: {}".format(input_state.shape, self.states[0].shape)
24 |
25 | # Getting the k-Nearest Neighbor actions
26 | state_diff = self.states - torch.tensor(input_state).to(self.device)
27 |
28 | index_l2_diff = torch.norm(state_diff[:, :3], dim=1)
29 | middle_l2_diff = torch.norm(state_diff[:, 3:6], dim=1)
30 | ring_l2_diff = torch.norm(state_diff[:, 6:9], dim=1)
31 | thumb_l2_diff = torch.norm(state_diff[:, 9:12], dim=1)
32 | object_l2_diff = torch.norm(state_diff[:, 12:], dim=1)
33 |
34 |         l2_diff = index_l2_diff + middle_l2_diff + (ring_l2_diff * self.ring_priority) + (thumb_l2_diff * self.thumb_priority) + (object_l2_diff * self.target_priority)
35 |
36 | sorted_idxs = torch.argsort(l2_diff)[:k]
37 |
38 |         if k == 1:
39 |             k_nn_actions = self.actions[sorted_idxs[0]]
40 |             return k_nn_actions, sorted_idxs[0].cpu().detach(), index_l2_diff[sorted_idxs], middle_l2_diff[sorted_idxs], ring_l2_diff[sorted_idxs], thumb_l2_diff[sorted_idxs], object_l2_diff[sorted_idxs]
41 |         else:
42 |             return self.actions[sorted_idxs]
43 |
44 |
45 | def find_optimum_action(self, input_state, k):
46 | # Getting the k-Nearest Neighbor actions for the input state
47 | if k == 1:
48 | k_nn_action, neighbor_idx, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.getNearestNeighbors(input_state, k)
49 | return k_nn_action.cpu().detach(), neighbor_idx, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff
50 | else:
51 | k_nn_actions = self.getNearestNeighbors(input_state, k)
52 |
53 | # Getting the mean value from the set of nearest neighbor states
54 | mean_action = torch.mean(k_nn_actions, 0).cpu().detach()
55 | return mean_action
56 |
57 | class VINN():
58 | def __init__(self, device="cpu", threshold = 0.02):
59 | # Setting the device
60 | self.device = torch.device(device)
61 |
62 | # Initializing the representations and action pairs to None
63 | self.representations = None
64 | self.actions = None
65 |
66 |         self.softmax_filter = nn.Softmax(dim = 0)
67 |
68 | def get_data(self, representations, actions):
69 | self.representations = representations.to(self.device)
70 | self.actions = actions.to(self.device)
71 |
72 | def getNearestNeighbors(self, input_representation, k):
73 | # Comparing the dataset shape and representation
74 | assert torch.tensor(input_representation).shape == self.representations[0].shape, "There is a data shape mismatch: \n Shape of loaded representations: {} \n Shape of input representation: {}".format(input_representation.shape, self.representations[0].shape)
75 |
76 | # Getting the k-Nearest Neighbor actions
77 | representation_difference = self.representations - torch.tensor(input_representation).to(self.device)
78 |
79 | l2_diff = torch.norm(representation_difference, dim=1)
80 |         softmax_l2_diff = self.softmax_filter(l2_diff)
81 | sorted_idxs = torch.argsort(softmax_l2_diff)[:k]
82 |
83 |         # Getting the actions corresponding to the closest representations
84 |
85 | if k > 1:
86 | return self.actions[sorted_idxs]
87 | else:
88 | return self.actions[sorted_idxs].cpu().detach(), sorted_idxs[0].cpu().detach()
89 |
90 |
91 | def find_optimum_action(self, input_representation, k):
92 | # Getting the k-Nearest Neighbor actions for the input representation
93 | if k == 1:
94 | return self.getNearestNeighbors(input_representation, k)
95 | else:
96 | k_nn_actions = self.getNearestNeighbors(input_representation, k)
97 |
98 | # Getting the mean value from the set of nearest neighbor representations
99 | return torch.mean(k_nn_actions, 0).cpu().detach()
100 |
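101 | # Minimal usage sketch (not part of the original file), on random data: 15-dim
102 | # states (four 3-d fingertips plus a 3-d object position) and 12-dim actions.
103 | if __name__ == '__main__':
104 |     inn = INN(device = "cpu", target_priority = 2, finger_priority = 1)
105 |     inn.get_data(torch.rand(100, 15), torch.rand(100, 12))
106 |     action, neighbor_idx, *per_part_diffs = inn.find_optimum_action([0.5] * 15, k = 1)
107 |     print(action.shape, neighbor_idx)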
--------------------------------------------------------------------------------
/arm_models/deploy/deploy_scripts/mlp_deploy.py:
--------------------------------------------------------------------------------
1 | # Standard imports
2 | import numpy as np
3 | import cv2
4 | import torch
5 |
6 | # Standard ROS imports
7 | import rospy
8 | from sensor_msgs.msg import Image, JointState
9 | from visualization_msgs.msg import Marker
10 | from cv_bridge import CvBridge, CvBridgeError
11 |
12 | # Dexterous arm control package import
13 | from move_dexarm import DexArmControl
14 |
15 | # Importing the Allegro IK library
16 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL
17 |
18 | # Model Imports
19 | from arm_models.imitation.networks import MLP
20 |
21 | # Other imports
22 | from copy import deepcopy as copy
23 |
24 | # ROS Topics to get the AR Marker data
25 | IMAGE_TOPIC = "/cam_1/color/image_raw"
26 | MARKER_TOPIC = "/visualization_marker"
27 | JOINT_STATE_TOPIC = "/allegroHand/joint_states"
28 |
29 | MODEL_CHKPT_PATH = '/home/sridhar/dexterous_arm/models/arm_models/dexarm_deployment/model_checkpoints/mlp-final.pth'
30 |
31 | class DexArmMLPDeploy():
32 | def __init__(self, device = 'cpu'):
33 | # Ignoring scientific notations
34 |         torch.set_printoptions(sci_mode = False)
35 |
36 | # Initializing ROS deployment node
37 | try:
38 | rospy.init_node("model_deploy")
39 | except:
40 | pass
41 |
42 | # Initializing arm controller
43 | self.arm = DexArmControl()
44 |
45 | # Initializing Allegro Ik Library
46 | self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf")
47 |
48 | # Loading the model and assigning device
49 | self.device = torch.device(device)
50 | self.model = MLP()
51 | self.model = self.model.to(device)
52 | self.model.load_state_dict(torch.load(MODEL_CHKPT_PATH))
53 | self.model.eval()
54 |
55 | # Moving the dexterous arm to home position
56 | self.arm.home_robot()
57 |
58 | # Initializing topic data
59 | self.image = None
60 | self.allegro_joint_state = None
61 | self.ar_marker_data = None
62 |
63 | # Other initializations
64 | self.bridge = CvBridge()
65 |
66 | rospy.Subscriber(MARKER_TOPIC, Marker, self._callback_ar_marker_data, queue_size = 1)
67 | rospy.Subscriber(JOINT_STATE_TOPIC, JointState, self._callback_joint_state, queue_size = 1)
68 | rospy.Subscriber(IMAGE_TOPIC, Image, self._callback_image, queue_size=1)
69 |
70 | def _callback_ar_marker_data(self, data):
71 | self.ar_marker_data = data
72 |
73 | def _callback_joint_state(self, data):
74 | self.allegro_joint_state = data
75 |
76 | def _callback_image(self, data):
77 | try:
78 | self.image = self.bridge.imgmsg_to_cv2(data, "bgr8")
79 | except CvBridgeError as e:
80 | print(e)
81 |
82 | def get_cube_position(self):
83 | cube_position = np.array([self.ar_marker_data.pose.position.x, self.ar_marker_data.pose.position.y, self.ar_marker_data.pose.position.z])
84 | return cube_position
85 |
86 | def get_tip_coords(self):
87 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0]
88 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0]
89 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0]
90 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0]
91 |
92 | return index_coord, middle_coord, ring_coord, thumb_coord
93 |
94 | def update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord):
95 | current_joint_angles = list(self.allegro_joint_state.position)
96 |
97 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4])
98 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8])
99 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12])
100 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16])
101 |
102 | desired_joint_angles = copy(current_joint_angles)
103 |
104 | for idx in range(4):
105 | desired_joint_angles[idx] = index_joint_angles[idx]
106 | desired_joint_angles[4 + idx] = middle_joint_angles[idx]
107 | desired_joint_angles[8 + idx] = ring_joint_angles[idx]
108 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx]
109 |
110 | return desired_joint_angles
111 |
112 | def deploy(self):
113 | while True:
114 | # Waiting for key
115 | next_step = input()
116 |
117 | # Setting the break condition
118 | if next_step == "q":
119 | break
120 |
121 | print("Getting state data...\n")
122 |
123 | print("Current joint angles: {}\n".format(self.allegro_joint_state.position))
124 |
125 | # Getting the state data
126 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self.get_tip_coords()
127 | cube_pos = self.get_cube_position()
128 |
129 | print("Current state data:\n Index-tip position: {}\n Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n Cube position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, cube_pos))
130 |
131 | cv2.imshow("Current state", self.image)
132 | cv2.waitKey(1)
133 |
134 | # Passing the input to the model to get the corresponding action
135 | input_coordinates = torch.tensor(np.concatenate([index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, cube_pos]))
136 | action = self.model(input_coordinates.float().to(self.device)).detach().cpu().numpy()[0]
137 |
138 | print("Corresponding action: ", action)
139 |
140 | # Updating the hand target coordinates
141 | updated_index_tip_coord = action[0:3]
142 | updated_middle_tip_coord = action[3:6]
143 | updated_ring_tip_coord = action[6:9]
144 | updated_thumb_tip_coord = action[9:12]
145 |
146 | print("updated index tip coord:", updated_index_tip_coord)
147 |
148 | desired_joint_angles = self.update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord)
149 |
150 | print("Moving arm to {}\n".format(desired_joint_angles))
151 | self.arm.move_hand(desired_joint_angles)
152 |
153 | if __name__ == '__main__':
154 | d = DexArmMLPDeploy(device = 'cpu')
155 | d.deploy()
--------------------------------------------------------------------------------
/arm_models/deploy/deploy_scripts/bc_deploy.py:
--------------------------------------------------------------------------------
1 | # Standard imports
2 | import cv2
3 | import torch
4 | from PIL import Image as PILImage
5 |
6 | # Standard ROS imports
7 | import rospy
8 | from sensor_msgs.msg import Image, JointState
9 | from cv_bridge import CvBridge, CvBridgeError
10 |
11 | # Dexterous arm control package import
12 | from move_dexarm import DexArmControl
13 |
14 | # Importing the Allegro IK library
15 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL
16 |
17 | # Model Imports
18 | from arm_models.imitation.networks import BehaviorCloning
19 |
20 | # Other imports
21 | from copy import deepcopy as copy
22 | from torchvision import transforms as T
23 |
24 | # ROS Topics to get the AR Marker data
25 | IMAGE_TOPIC = "/cam_1/color/image_raw"
26 | JOINT_STATE_TOPIC = "/allegroHand/joint_states"
27 |
28 | MODEL_CHKPT_PATH = '/home/sridhar/dexterous_arm/models/arm_models/dexarm_deployment/model_checkpoints/behavior_cloning.pth'
29 |
30 | class DexArmBCDeploy():
31 | def __init__(self, device = 'cpu'):
32 | # Ignoring scientific notations
33 |         torch.set_printoptions(sci_mode = False)
34 |
35 | # Initializing ROS deployment node
36 | try:
37 | rospy.init_node("model_deploy")
38 | except:
39 | pass
40 |
41 | # Initializing arm controller
42 | self.arm = DexArmControl()
43 |
44 | # Initializing Allegro Ik Library
45 | self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf")
46 |
47 | # Loading the model and assigning device
48 | self.device = torch.device(device)
49 | self.model = BehaviorCloning()
50 | self.model = self.model.to(device)
51 | self.model.load_state_dict(torch.load(MODEL_CHKPT_PATH))
52 | self.model.eval()
53 |
54 | # Moving the dexterous arm to home position
55 | self.arm.home_robot()
56 |
57 | # Initializing topic data
58 | self.image = None
59 | self.allegro_joint_state = None
60 |
61 | # Initalizing image transformer
62 | self.image_transform = T.Compose([
63 | T.ToTensor(),
64 | T.Resize((224, 224)),
65 | T.Normalize(
66 | mean = torch.tensor([0.3484, 0.3638, 0.3819]),
67 | std = torch.tensor([0.3224, 0.3151, 0.3166])
68 | )
69 | ])
70 |
71 | # Other initializations
72 | self.bridge = CvBridge()
73 |
74 | rospy.Subscriber(JOINT_STATE_TOPIC, JointState, self._callback_joint_state, queue_size = 1)
75 | rospy.Subscriber(IMAGE_TOPIC, Image, self._callback_image, queue_size=1)
76 |
77 | def _callback_ar_marker_data(self, data):
78 | self.ar_marker_data = data
79 |
80 | def _callback_joint_state(self, data):
81 | self.allegro_joint_state = data
82 |
83 | def _callback_image(self, data):
84 | try:
85 | self.image = self.bridge.imgmsg_to_cv2(data, "bgr8")
86 | except CvBridgeError as e:
87 | print(e)
88 |
89 | def get_tip_coords(self):
90 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0]
91 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0]
92 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0]
93 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0]
94 |
95 | return index_coord, middle_coord, ring_coord, thumb_coord
96 |
97 | def update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord):
98 | current_joint_angles = list(self.allegro_joint_state.position)
99 |
100 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4])
101 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8])
102 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12])
103 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16])
104 |
105 | desired_joint_angles = copy(current_joint_angles)
106 |
107 | for idx in range(4):
108 | desired_joint_angles[idx] = index_joint_angles[idx]
109 | desired_joint_angles[4 + idx] = middle_joint_angles[idx]
110 | desired_joint_angles[8 + idx] = ring_joint_angles[idx]
111 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx]
112 |
113 | return desired_joint_angles
114 |
115 | def deploy(self):
116 | while True:
117 | # Waiting for key
118 | next_step = input()
119 |
120 | # Setting the break condition
121 | if next_step == "q":
122 | break
123 |
124 | print("Getting state data...\n")
125 |
126 | print("Current joint angles: {}\n".format(self.allegro_joint_state.position))
127 |
128 | # Getting the state data
129 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self.get_tip_coords()
130 |
131 | print("Current state data:\n Index-tip position: {}\n Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord))
132 |
133 | cv2.imshow("Current state", self.image)
134 | cv2.waitKey(1)
135 |
136 | # Transforming the image before passing it into the model
137 | pil_image = PILImage.fromarray(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
138 | cropped_image = pil_image.crop((450, 220, 900, 920)) # Left, Top, Right, Bottom
139 | transformed_image_tensor = self.image_transform(cropped_image).unsqueeze(0)
140 |
141 | action = self.model(transformed_image_tensor.float().to(self.device)).detach().cpu().numpy()[0]
142 |
143 | print("Corresponding action: ", action)
144 |
145 | # Updating the hand target coordinates
146 | updated_index_tip_coord = action[0:3]
147 | updated_middle_tip_coord = action[3:6]
148 | updated_ring_tip_coord = action[6:9]
149 | updated_thumb_tip_coord = action[9:12]
150 |
151 | print("updated index tip coord:", updated_index_tip_coord)
152 |
153 | desired_joint_angles = self.update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord)
154 |
155 | print("Moving arm to {}\n".format(desired_joint_angles))
156 | self.arm.move_hand(desired_joint_angles)
157 |
158 | if __name__ == '__main__':
159 |     d = DexArmBCDeploy(device = 'cpu')
160 | d.deploy()
--------------------------------------------------------------------------------
/arm_models/deploy/deploy_scripts/vinn_deploy.py:
--------------------------------------------------------------------------------
1 | # Standard imports
2 | import numpy as np
3 | import cv2
4 | import torch
5 | from PIL import Image as PILImage
6 |
7 | # Standard ROS imports
8 | import rospy
9 | from sensor_msgs.msg import Image, JointState
10 | from cv_bridge import CvBridge, CvBridgeError
11 |
12 | # Dexterous arm control package import
13 | from move_dexarm import DexArmControl
14 |
15 | # Importing the Allegro IK library
16 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL
17 |
18 | # Model Imports
19 | from arm_models.deploy.model_scripts.visual_inn import VINNDeploy
20 |
21 | # Other imports
22 | from copy import deepcopy as copy
23 | from torchvision import models
24 | from torchvision import transforms as T
25 |
26 | # ROS Topics to get the AR Marker data
27 | IMAGE_TOPIC = "/cam_1/color/image_raw"
28 | JOINT_STATE_TOPIC = "/allegroHand/joint_states"
29 | ENCODER_MODEL_CHKPT = "/home/sridhar/dexterous_arm/models/arm_models/dexarm_deployment/model_checkpoints/rotation/BYOL-VINN-rotation-lowest-train-loss.pth"
30 |
31 | class DexArmVINNDeploy():
32 | def __init__(self, task, k = 1, use_abs = True, debug = False, device = "cpu"):
33 | # Initializing ROS deployment node
34 | try:
35 | rospy.init_node("model_deploy")
36 | except:
37 | pass
38 |
39 | # Setting the debug parameter only if k is 1
40 | if k == 1:
41 | self.debug = debug
42 | else:
43 | self.debug = False
44 |
45 | self.abs = use_abs
46 |
47 | # Initializing arm controller
48 | print("Initializing controller!")
49 | self.arm = DexArmControl()
50 |
51 | # Initializing Allegro Ik Library
52 | self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf")
53 |
54 | # Initializing INN
55 | print("Initializing model!") # TODO - put data path
56 | if task == "rotate":
57 | if self.debug is True:
58 | self.model = VINNDeploy(k = 1, load_image_data = True, device = device)
59 | else:
60 | self.model = VINNDeploy(k = k, device = device)
61 | elif task == "flip":
62 | if self.debug is True:
63 | self.model = VINNDeploy(k = 1, load_image_data = True, device = device)
64 | else:
65 | self.model = VINNDeploy(k = k, device = device)
66 |
67 | # Making sure the hand moves
68 | self.threshold = 0.02
69 |
70 | # Moving the dexterous arm to home position
71 | print("Homing robot!\n")
72 | self.arm.home_robot()
73 |
74 | # Initializing topic data
75 | self.image = None
76 | self.allegro_joint_state = None
77 |
78 | # Initializing the representation extractor
79 | self.device = torch.device(device)
80 | original_encoder_model = models.resnet50(pretrained = True)
81 | self.encoder = torch.nn.Sequential(*(list(original_encoder_model.children())[:-1]))
82 | self.encoder.to(self.device)
83 | self.encoder.load_state_dict(torch.load(ENCODER_MODEL_CHKPT))
84 | self.encoder.eval()
85 |
86 | # Initializing image transformation function
87 | self.image_transform = T.Compose([
88 | T.ToTensor(),
89 | T.Resize((224, 224)),
90 | T.Normalize(
91 | mean = torch.tensor([0.3484, 0.3638, 0.3819]),
92 | std = torch.tensor([0.3224, 0.3151, 0.3166])
93 | )
94 | ])
95 |
96 | # Other initializations
97 | self.bridge = CvBridge()
98 |
99 | rospy.Subscriber(JOINT_STATE_TOPIC, JointState, self._callback_joint_state, queue_size = 1)
100 | rospy.Subscriber(IMAGE_TOPIC, Image, self._callback_image, queue_size=1)
101 |
102 | def _callback_ar_marker_data(self, data):
103 | self.ar_marker_data = data
104 |
105 | def _callback_joint_state(self, data):
106 | self.allegro_joint_state = data
107 |
108 | def _callback_image(self, data):
109 | try:
110 | self.image = self.bridge.imgmsg_to_cv2(data, "bgr8")
111 | except CvBridgeError as e:
112 | print(e)
113 |
114 | def get_tip_coords(self):
115 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0]
116 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0]
117 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0]
118 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0]
119 |
120 | return index_coord, middle_coord, ring_coord, thumb_coord
121 |
122 | def update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord):
123 | current_joint_angles = list(self.allegro_joint_state.position)
124 |
125 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4])
126 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8])
127 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12])
128 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16])
129 |
130 | desired_joint_angles = copy(current_joint_angles)
131 |
132 | for idx in range(4):
133 | desired_joint_angles[idx] = index_joint_angles[idx]
134 | desired_joint_angles[4 + idx] = middle_joint_angles[idx]
135 | desired_joint_angles[8 + idx] = ring_joint_angles[idx]
136 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx]
137 |
138 | return desired_joint_angles
139 |
140 | def deploy(self):
141 | while True:
142 | if self.allegro_joint_state is None:
143 | print('No allegro state received!')
144 | continue
145 |
146 | if self.image is None:
147 | # print('No robot image received!')
148 | continue
149 |
150 |             # Waiting for key
151 | next_step = input()
152 |
153 | # Setting the break condition
154 | if next_step == "q":
155 | break
156 |
157 | print("Getting state data...\n")
158 |
159 | print("Current joint angles: {}\n".format(self.allegro_joint_state.position))
160 |
161 | # Getting the state data
162 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self.get_tip_coords()
163 | finger_state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord)
164 |
165 | print("Current state data:\n Index-tip position: {}\n Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord))
166 |
167 | cv2.imshow("Current state image", self.image)
168 | cv2.waitKey(1)
169 |
170 | # Transforming the image
171 | pil_image = PILImage.fromarray(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
172 | cropped_image = pil_image.crop((450, 220, 900, 920)) # Left, Top, Right, Bottom
173 | transformed_image_tensor = self.image_transform(cropped_image).unsqueeze(0)
174 |
175 | representation = self.encoder(transformed_image_tensor.float().to(self.device)).squeeze(2).squeeze(2).detach().cpu().numpy()[0]
176 |
177 | if self.debug is True:
178 | action, neighbor_idx, nn_state_image_path, trans_state_image_path = self.model.get_action_with_image(representation)
179 | action = list(action.reshape(12))
180 |
181 | l2_diff = np.linalg.norm(np.array(finger_state[:12]) - np.array(action))
182 |
183 | while l2_diff < self.threshold:
184 | print("Obtaining new action!")
185 | action, neighbor_idx, nn_state_image_path, trans_state_image_path = self.model.get_debug_action(neighbor_idx)
186 | action = list(action.reshape(12))
187 |
188 | l2_diff = np.linalg.norm(np.array(finger_state[:12]) - np.array(action))
189 |
190 |
191 | # Reading the Nearest Neighbor images
192 | nn_state_image = cv2.imread(nn_state_image_path)
193 | trans_state_image = cv2.imread(trans_state_image_path)
194 |
195 |                 # Combining the current state image with the NN images and displaying them
196 | combined_images = np.concatenate((
197 | cv2.resize(self.image, (420, 240), interpolation = cv2.INTER_AREA),
198 | cv2.resize(nn_state_image, (420, 240), interpolation = cv2.INTER_AREA),
199 | cv2.resize(trans_state_image, (420, 240), interpolation = cv2.INTER_AREA)
200 | ), axis=1)
201 |
202 | cv2.imshow("Current state and Nearest Neighbor Images", combined_images)
203 | cv2.waitKey(1)
204 |
205 | else:
206 | action = list(self.model.get_action(representation).reshape(12))
207 |                 # TODO: add a debug branch for k > 1
208 |
209 | print("Corresponding action: ", action)
210 |
211 | if self.abs is True:
212 | updated_index_tip_coord = action[0:3]
213 | updated_middle_tip_coord = action[3:6]
214 | updated_ring_tip_coord = action[6:9]
215 | updated_thumb_tip_coord = action[9:12]
216 | else:
217 | updated_index_tip_coord = np.array(index_tip_coord) + np.array(action[0:3])
218 | updated_middle_tip_coord = np.array(middle_tip_coord) + np.array(action[3:6])
219 | updated_ring_tip_coord = np.array(ring_tip_coord) + np.array(action[6:9])
220 | updated_thumb_tip_coord = np.array(thumb_tip_coord) + np.array(action[9:12])
221 |
222 | desired_joint_angles = self.update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord)
223 |
224 | print("Moving arm to {}\n".format(desired_joint_angles))
225 | self.arm.move_hand(desired_joint_angles)
226 |
227 |
228 | if __name__ == "__main__":
229 | d = DexArmVINNDeploy(task = "rotate", debug = True, k = 1, device = "cuda", use_abs = True)
230 | d.deploy()
--------------------------------------------------------------------------------
/arm_models/deploy/deploy_scripts/inn_deploy.py:
--------------------------------------------------------------------------------
1 | # Standard imports
2 | import numpy as np
3 | import cv2
4 |
5 | # Standard ROS imports
6 | import rospy
7 | from sensor_msgs.msg import Image, JointState
8 | from visualization_msgs.msg import Marker
9 | from cv_bridge import CvBridge, CvBridgeError
10 |
11 | # Dexterous arm control package import
12 | from move_dexarm import DexArmControl
13 |
14 | # Importing the Allegro IK library
15 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL
16 |
17 | # Model Imports
18 | from arm_models.deploy.model_scripts.inn import INNDeploy
19 |
20 | # Other imports
21 | from copy import deepcopy as copy
22 |
23 | # ROS Topics to get the AR Marker data
24 | IMAGE_TOPIC = "/cam_1/color/image_raw"
25 | MARKER_TOPIC = "/visualization_marker"
26 | JOINT_STATE_TOPIC = "/allegroHand/joint_states"
27 |
28 | # Data paths
29 | CUBE_ROTATION_DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/cube_rotation/complete"
30 | OBJECT_FLIPPING_DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/object_flipping/complete"
31 | FIDGET_SPINNING_DATA_PATH = "/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning/complete"
32 |
33 | class DexArmINNDeploy():
34 | def __init__(self, task, k = 1, use_abs = True, debug = False, target_priority = 1, device = "cpu", threshold = 0.02):
35 | # Initializing ROS deployment node
36 | try:
37 | rospy.init_node("model_deploy")
38 | except:
39 | pass
40 |
41 | # Setting the debug parameter only if k is 1
42 | if k == 1:
43 | self.debug = debug
44 | else:
45 | self.debug = False
46 |
47 | self.abs = use_abs
48 | self.task = task
49 |
50 | # Initializing arm controller
51 | print("Initializing controller!")
52 | self.arm = DexArmControl()
53 |
54 | # Initializing Allegro Ik Library
55 | self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf")
56 |
57 | # Initializing INN
58 | print("Initializing model!")
59 | if task == "rotate":
60 | if self.debug is True:
61 | self.model = INNDeploy(k = 1, data_path = CUBE_ROTATION_DATA_PATH, load_image_data = True, target_priority = target_priority, device = device)
62 | else:
63 | self.model = INNDeploy(k = k, data_path = CUBE_ROTATION_DATA_PATH, target_priority = target_priority, device = device)
64 | elif task == "flip":
65 | if self.debug is True:
66 |                 self.model = INNDeploy(k = 1, data_path = OBJECT_FLIPPING_DATA_PATH, load_image_data = True, target_priority = target_priority, device = device)
67 | else:
68 |                 self.model = INNDeploy(k = k, data_path = OBJECT_FLIPPING_DATA_PATH, target_priority = target_priority, device = device)
69 | elif task == "spin":
70 | if self.debug is True:
71 | self.model = INNDeploy(k = 1, data_path = FIDGET_SPINNING_DATA_PATH, load_image_data = True, target_priority = target_priority, device = device)
72 | else:
73 | self.model = INNDeploy(k = k, data_path = FIDGET_SPINNING_DATA_PATH, target_priority = target_priority, device = device)
74 |
75 |
76 |         # Minimum finger-tip displacement required before an action is executed
77 |         self.threshold = threshold
78 |
79 | # Moving the dexterous arm to home position
80 | print("Homing robot!")
81 | if task == "rotate" or task == "flip":
82 | self.arm.home_robot()
83 | elif task == "spin":
84 | self.arm.spin_pos_arm()
85 |
86 | # Initializing topic data
87 | self.image = None
88 | self.allegro_joint_state = None
89 | if task == "rotate" or task == "flip":
90 | self.ar_marker_data = None
91 |
92 | # Other initializations
93 | self.bridge = CvBridge()
94 |
95 | if task == "rotate" or task == "flip":
96 | rospy.Subscriber(MARKER_TOPIC, Marker, self._callback_ar_marker_data, queue_size = 1)
97 | rospy.Subscriber(JOINT_STATE_TOPIC, JointState, self._callback_joint_state, queue_size = 1)
98 | rospy.Subscriber(IMAGE_TOPIC, Image, self._callback_image, queue_size=1)
99 |
100 | def _callback_ar_marker_data(self, data):
101 | self.ar_marker_data = data
102 |
103 | def _callback_joint_state(self, data):
104 | self.allegro_joint_state = data
105 |
106 | def _callback_image(self, data):
107 | try:
108 | self.image = self.bridge.imgmsg_to_cv2(data, "bgr8")
109 | except CvBridgeError as e:
110 | print(e)
111 |
112 | def get_object_position(self):
113 | object_position = np.array([self.ar_marker_data.pose.position.x, self.ar_marker_data.pose.position.y, self.ar_marker_data.pose.position.z])
114 | return object_position
115 |
116 | def get_tip_coords(self):
117 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0]
118 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0]
119 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0]
120 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0]
121 |
122 | return index_coord, middle_coord, ring_coord, thumb_coord
123 |
124 | def update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord):
125 | current_joint_angles = list(self.allegro_joint_state.position)
126 |
127 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4])
128 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8])
129 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12])
130 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16])
131 |
132 | desired_joint_angles = copy(current_joint_angles)
133 |
134 | for idx in range(4):
135 | desired_joint_angles[idx] = index_joint_angles[idx]
136 | desired_joint_angles[4 + idx] = middle_joint_angles[idx]
137 | desired_joint_angles[8 + idx] = ring_joint_angles[idx]
138 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx]
139 |
140 | return desired_joint_angles
141 |
142 | def deploy(self):
143 | print("Deploying model!\n")
144 | while True:
145 | if self.allegro_joint_state is None:
146 | print('No allegro state received!')
147 | continue
148 |
149 |             # Waiting for a keypress ('q' quits the loop)
150 | next_step = input()
151 |
152 | # Setting the break condition
153 | if next_step == "q":
154 | break
155 |
156 | print("Getting state data...\n")
157 |
158 | print("Current joint angles: {}\n".format(self.allegro_joint_state.position))
159 |
160 | # Getting the state data
161 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self.get_tip_coords()
162 |
163 | if self.task == "rotate" or self.task == "flip":
164 | object_pos = self.get_object_position()
165 | elif self.task == "spin":
166 | object_pos = [0, 0, 0]
167 |
168 | state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord) + list(object_pos)
169 |
170 | if self.task == "rotate" or self.task == "flip":
171 | print("Current state data:\n Index-tip position: {}\n Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n Object position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_pos))
172 | elif self.task == "spin":
173 | print("Current state data:\n Index-tip position: {}\n Middle-tip position: {}\n Ring-tip position: {}\n Thumb-tip position: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord))
174 |
175 | if self.debug is True:
176 | action, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.get_action_with_image(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_pos)
177 | action = list(action.reshape(12))
178 |
179 | l2_diff = np.linalg.norm(np.array(state[:12]) - np.array(action))
180 |
181 | while l2_diff < self.threshold:
182 | action, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.get_action_with_image(action[:3], action[3:6], action[6:9], action[9:12], object_pos)
183 | action = list(action.reshape(12))
184 |
185 | l2_diff = np.linalg.norm(np.array(state[:12]) - np.array(action))
186 |
187 | print("Trans state img path: {}".format(trans_state_image_path))
188 |
189 | print("Distances:\n Index distance: {}\n Middle distance: {}\n Ring distance: {}\n Thumb distance: {}\n Cube distance: {}".format(index_l2_diff.item(), middle_l2_diff.item(), ring_l2_diff.item(), thumb_l2_diff.item(), object_l2_diff.item()))
190 |
191 | # Reading the Nearest Neighbor images
192 | nn_state_image = cv2.imread(nn_state_image_path)
193 | trans_state_image = cv2.imread(trans_state_image_path)
194 |
195 |                 # Combining the current state image with the NN images and displaying them
196 | combined_images = np.concatenate((
197 | cv2.resize(self.image, (420, 240), interpolation = cv2.INTER_AREA),
198 | cv2.resize(nn_state_image, (420, 240), interpolation = cv2.INTER_AREA),
199 | cv2.resize(trans_state_image, (420, 240), interpolation = cv2.INTER_AREA)
200 | ), axis=1)
201 |
202 | cv2.imshow("Current state and Nearest Neighbor Images", combined_images)
203 | cv2.waitKey(1)
204 |
205 | else:
206 | action = list(self.model.get_action(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_pos).reshape(12))
207 | l2_diff = np.linalg.norm(np.array(state[:12]) - np.array(action))
208 |
209 | while l2_diff < self.threshold:
210 | action = list(self.model.get_action(action[:3], action[3:6], action[6:9], action[9:12], object_pos).reshape(12))
211 | l2_diff = np.linalg.norm(np.array(state[:12]) - np.array(action))
212 |
213 | print("Corresponding action: ", action)
214 |
215 | if self.abs:
216 | # print("Using absolute values for thumb: {} and for ring finger: {}".format(action[:3], action[3:]))
217 | updated_index_tip_coord = action[0:3]
218 | updated_middle_tip_coord = action[3:6]
219 | updated_ring_tip_coord = action[6:9]
220 | updated_thumb_tip_coord = action[9:12]
221 | else:
222 | updated_index_tip_coord = np.array(index_tip_coord) + np.array(action[0:3])
223 | updated_middle_tip_coord = np.array(middle_tip_coord) + np.array(action[3:6])
224 | updated_ring_tip_coord = np.array(ring_tip_coord) + np.array(action[6:9])
225 | updated_thumb_tip_coord = np.array(thumb_tip_coord) + np.array(action[9:12])
226 |
227 | desired_joint_angles = self.update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord)
228 |
229 | print("Moving arm to {}\n".format(desired_joint_angles))
230 | self.arm.move_hand(desired_joint_angles)
231 |
232 |
233 | if __name__ == "__main__":
234 | # Rotation task
235 | d = DexArmINNDeploy(task = "rotate", debug = True, k = 1, target_priority = 4, device = "cuda")
236 |
237 | # Flipping task
238 | # d = DexArmINNDeploy(task = "flip", debug = True, k = 1, target_priority = 1, device = "cuda")
239 |
240 | # Spinning task
241 | # d = DexArmINNDeploy(task = "spin", debug = True, k = 1, target_priority = 1, device = "cuda")
242 |
243 | d.deploy()
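244 | 
245 |     # Editor's note: a minimal sketch (hypothetical helper, not used above) of
246 |     # how the flat 12-D action decomposes into per-finger targets, matching the
247 |     # slicing used throughout this script:
248 |     #
249 |     #   def split_action(action):
250 |     #       # [0:3] index, [3:6] middle, [6:9] ring, [9:12] thumb tip targets
251 |     #       return action[0:3], action[3:6], action[6:9], action[9:12]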
--------------------------------------------------------------------------------
/arm_models/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 |
6 | from arm_models.dataloaders.state_dataset import *
7 | from arm_models.dataloaders.visual_dataset import *
8 |
9 | from imitation.networks import *
10 | from byol_pytorch import BYOL
11 |
12 | import wandb
13 | import argparse
14 | from tqdm import tqdm
15 |
16 | from utils.augmentation import augmentation_generator
17 |
18 | CHKPT_PATH = os.path.join(os.getcwd(), "deploy", "checkpoints")
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('-t', '--task', type=str)
22 | parser.add_argument('-n', '--network', type=str)
23 | parser.add_argument('-b', '--batch_size', type=int)
24 | parser.add_argument('-e', '--epochs', type=int)
25 | parser.add_argument('-g', '--gpu_num', type=int)
26 | parser.add_argument('-l', '--lr', type=float)
27 | parser.add_argument('-i', '--img_size', type=int)
28 | parser.add_argument('-d', '--train_data_path', type=str)
29 | parser.add_argument('--full_data', action='store_true')
30 | parser.add_argument('-v', '--val_data_path', type=str)
31 | parser.add_argument('-f', '--test_data_path', type=str)
32 | parser.add_argument('-a', '--augment_imgs', action='store_true')
33 | parser.add_argument('-r', '--run', type=int)
34 |
35 | if __name__ == '__main__':
36 | # Getting the parser options
37 | options = parser.parse_args()
38 |
39 | # Selecting the torch device
40 | device = torch.device("cuda:{}".format(options.gpu_num))
41 | print("Using GPU: {} for training. \n".format(options.gpu_num))
42 |
43 |     # Loading the dataset(s) and creating the dataloader(s)
44 | print("Loading dataset(s) and creating dataloader(s).\n")
45 | if options.task == "rotate":
46 | if options.img_size is None:
47 | train_dataset = CubeRotationDataset(options.train_data_path)
48 | if options.val_data_path is not None:
49 | val_dataset = CubeRotationDataset(options.val_data_path)
50 | if options.test_data_path is not None:
51 | test_dataset = CubeRotationDataset(options.test_data_path)
52 | else:
53 | if options.network == "bc":
54 | if options.full_data is True:
55 | print("Using the entire dataset!\n")
56 | train_dataset = CubeRotationVisualDataset(type = None, image_size = options.img_size, data_path = options.train_data_path)
57 | else:
58 | train_dataset = CubeRotationVisualDataset(type = "train", image_size = options.img_size, data_path = options.train_data_path)
59 | if options.val_data_path is not None:
60 | val_dataset = CubeRotationVisualDataset(type = "val", image_size = options.img_size, data_path = options.val_data_path)
61 | if options.test_data_path is not None:
62 | test_dataset = CubeRotationVisualDataset(type = "test", image_size = options.img_size, data_path = options.test_data_path)
63 | elif options.network == "representation_byol":
64 | train_dataset = RepresentationVisualDataset(task = "cube_rotation", data_path = options.train_data_path)
65 | if options.val_data_path is not None:
66 |                 val_dataset = RepresentationVisualDataset(task = "cube_rotation", data_path = options.val_data_path)
67 |             if options.test_data_path is not None:
68 |                 test_dataset = RepresentationVisualDataset(task = "cube_rotation", data_path = options.test_data_path)
69 |
70 | elif options.task == "flip":
71 | if options.img_size is None:
72 | train_dataset = ObjectFlippingDataset(options.train_data_path)
73 | if options.val_data_path is not None:
74 | val_dataset = ObjectFlippingDataset(options.val_data_path)
75 | if options.test_data_path is not None:
76 | test_dataset = ObjectFlippingDataset(options.test_data_path)
77 | else:
78 | if options.network == "bc":
79 | if options.full_data is True:
80 | print("Using the entire dataset!\n")
81 | train_dataset = ObjectFlippingVisualDataset(type = None, image_size = options.img_size, data_path = options.train_data_path)
82 | else:
83 | train_dataset = ObjectFlippingVisualDataset(type = "train", image_size = options.img_size, data_path = options.train_data_path)
84 | if options.val_data_path is not None:
85 | val_dataset = ObjectFlippingVisualDataset(type = "val", image_size = options.img_size, data_path = options.val_data_path)
86 | if options.test_data_path is not None:
87 | test_dataset = ObjectFlippingVisualDataset(type = "test", image_size = options.img_size, data_path = options.test_data_path)
88 | elif options.network == "representation_byol":
89 | train_dataset = RepresentationVisualDataset(task = "object_flipping", data_path = options.train_data_path)
90 | if options.val_data_path is not None:
91 |                 val_dataset = RepresentationVisualDataset(task = "object_flipping", data_path = options.val_data_path)
92 |             if options.test_data_path is not None:
93 |                 test_dataset = RepresentationVisualDataset(task = "object_flipping", data_path = options.test_data_path)
94 |
95 | elif options.task == "spin":
96 | if options.img_size is None:
97 | train_dataset = FidgetSpinningDataset(options.train_data_path)
98 | if options.val_data_path is not None:
99 | val_dataset = FidgetSpinningDataset(options.val_data_path)
100 | if options.test_data_path is not None:
101 | test_dataset = FidgetSpinningDataset(options.test_data_path)
102 | else:
103 | if options.network == "bc":
104 | if options.full_data is True:
105 | print("Using the entire dataset!\n")
106 | train_dataset = FidgetSpinningVisualDataset(type = None, image_size = options.img_size, data_path = options.train_data_path)
107 | else:
108 | train_dataset = FidgetSpinningVisualDataset(type = "train", image_size = options.img_size, data_path = options.train_data_path)
109 | if options.val_data_path is not None:
110 | val_dataset = FidgetSpinningVisualDataset(type = "val", image_size = options.img_size, data_path = options.val_data_path)
111 | if options.test_data_path is not None:
112 | test_dataset = FidgetSpinningVisualDataset(type = "test", image_size = options.img_size, data_path = options.test_data_path)
113 | elif options.network == "representation_byol":
114 | train_dataset = RepresentationVisualDataset(task = "fidget_spinning", data_path = options.train_data_path)
115 | if options.val_data_path is not None:
116 |                 val_dataset = RepresentationVisualDataset(task = "fidget_spinning", data_path = options.val_data_path)
117 |             if options.test_data_path is not None:
118 |                 test_dataset = RepresentationVisualDataset(task = "fidget_spinning", data_path = options.test_data_path)
119 |
120 | train_dataloader = DataLoader(train_dataset, options.batch_size, shuffle = True, pin_memory = True, num_workers = 24)
121 |
122 | if options.val_data_path is not None:
123 | val_dataloader = DataLoader(val_dataset, options.batch_size, shuffle = True, pin_memory = True, num_workers = 24)
124 |
125 | if options.test_data_path is not None:
126 | test_dataloader = DataLoader(test_dataset, options.batch_size, shuffle = True, pin_memory = True, num_workers = 24)
127 |
128 | # Initializing the model based on the model argument
129 |     print("Loading the model ({}) for task: {}\n".format(options.network, options.task))
130 | if options.network == "mlp":
131 | model = MLP().to(device)
132 | elif options.network == "bc":
133 | model = BehaviorCloning().to(device)
134 | elif options.network == "representation_byol":
135 | original_encoder_model = models.resnet50(pretrained = True)
136 | encoder = torch.nn.Sequential(*(list(original_encoder_model.children())[:-1])).to(device)
137 |
138 | if options.augment_imgs is True:
139 | augment_custom = augmentation_generator(options.task)
140 | model = BYOL(
141 | encoder,
142 | image_size = options.img_size,
143 | augment_fn = augment_custom
144 | )
145 | else:
146 | model = BYOL(
147 | encoder,
148 | image_size = options.img_size
149 | )
150 |
151 | # Initialize WandB logging
152 | wandb.init(project = "{} - {}".format(options.task, options.network))
153 |
154 | # Initializing optimizer and other parameters
155 | optimizer = torch.optim.Adam(model.parameters(), lr = options.lr)
156 |
157 | if options.test_data_path is not None:
158 | low_test_loss = np.inf
159 | elif options.val_data_path is not None:
160 | low_val_loss = np.inf
161 | else:
162 | low_train_loss = np.inf
163 |
164 | # Initializing loss
165 | if options.network != "representation_byol":
166 | loss_fn = nn.MSELoss()
167 |
168 | # Training loop
169 | print("Starting training procedure!\n")
170 | for epoch in range(options.epochs):
171 | # Training part
172 | epoch_train_loss = 0
173 |
174 | for input_data, actions in tqdm(train_dataloader):
175 | optimizer.zero_grad()
176 |
177 | if options.network == "representation_byol":
178 | loss = model(input_data.float().to(device))
179 | else:
180 | predicted_actions = model(input_data.float().to(device))
181 | loss = loss_fn(predicted_actions, actions.float().to(device))
182 |
183 | loss.backward()
184 | optimizer.step()
185 |
186 | epoch_train_loss += loss.item() * input_data.shape[0]
187 |
188 | print("Train loss: {}\n".format(epoch_train_loss / len(train_dataset)))
189 | wandb.log({'train loss': epoch_train_loss / len(train_dataset)})
190 |
191 | # Validation part
192 | if options.val_data_path is not None:
193 | epoch_val_loss = 0
194 |
195 | for input_data, actions in tqdm(val_dataloader):
196 | if options.network == "representation_byol":
197 | loss = model(input_data.float().to(device))
198 | else:
199 | predicted_actions = model(input_data.float().to(device))
200 | loss = loss_fn(predicted_actions, actions.float().to(device))
201 |
202 | epoch_val_loss += loss.item() * input_data.shape[0]
203 |
204 | print("Validation loss: {}\n".format(epoch_val_loss / len(val_dataset)))
205 | wandb.log({'validation loss': epoch_val_loss / len(val_dataset)})
206 |
207 | # Testing part
208 | if options.test_data_path is not None:
209 | epoch_test_loss = 0
210 |
211 | for input_data, actions in tqdm(test_dataloader):
212 | if options.network == "representation_byol":
213 | loss = model(input_data.float().to(device))
214 | else:
215 | predicted_actions = model(input_data.float().to(device))
216 | loss = loss_fn(predicted_actions, actions.float().to(device))
217 |
218 | epoch_test_loss += loss.item() * input_data.shape[0]
219 |
220 | print("Test loss: {}\n".format(epoch_test_loss / len(test_dataset)))
221 | wandb.log({'test loss': epoch_test_loss / len(test_dataset)})
222 |
223 |
224 |
225 | # Saving checkpoints based on the lowest stats
226 |         if options.network == "vinn":   # saves the encoder checkpoint every epoch; no loss comparison is made here
227 | low_train_loss = epoch_train_loss / len(train_dataset)
228 | epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - lowest - train - v{}.pth".format(options.network, options.task, options.run))
229 | print("\nLower train loss encountered! Saving checkpoint {}\n".format(epoch_checkpt_path))
230 | torch.save(encoder.state_dict(), epoch_checkpt_path)
231 | else:
232 | if options.test_data_path is not None:
233 | if epoch_test_loss / len(test_dataset) < low_test_loss:
234 | low_test_loss = epoch_test_loss / len(test_dataset)
235 | epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - lowest - test - v{}.pth".format(options.network, options.task, options.run))
236 | print("\nLower test loss encountered! Saving checkpoint {}\n".format(epoch_checkpt_path))
237 | torch.save(model.state_dict(), epoch_checkpt_path)
238 | elif options.val_data_path is not None:
239 | if epoch_val_loss / len(val_dataset) < low_val_loss:
240 | low_val_loss = epoch_val_loss / len(val_dataset)
241 | epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - lowest - val - v{}.pth".format(options.network, options.task, options.run))
242 | print("\nLower validation loss encountered! Saving checkpoint {}\n".format(epoch_checkpt_path))
243 | torch.save(model.state_dict(), epoch_checkpt_path)
244 | else:
245 | if epoch_train_loss / len(train_dataset) < low_train_loss:
246 | low_train_loss = epoch_train_loss / len(train_dataset)
247 | epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - lowest - train - v{}.pth".format(options.network, options.task, options.run))
248 | print("\nLower train loss encountered! Saving checkpoint {}\n".format(epoch_checkpt_path))
249 | torch.save(model.state_dict(), epoch_checkpt_path)
250 |
251 | # Saving final checkpoint
252 |     if options.network == "vinn":
253 | final_epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - final - v{}.pth".format(options.network, options.task, options.run))
254 | print("\nSaving final checkpoint {}\n".format(final_epoch_checkpt_path))
255 | torch.save(encoder.state_dict(), final_epoch_checkpt_path)
256 | else:
257 | final_epoch_checkpt_path = os.path.join(CHKPT_PATH, "{} - {} - final - v{}.pth".format(options.network, options.task, options.run))
258 | print("\nSaving final checkpoint {}\n".format(final_epoch_checkpt_path))
259 | torch.save(model.state_dict(), final_epoch_checkpt_path)
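260 | 
261 | # Illustrative invocations (editor's sketch; the data paths and run numbers
262 | # below are placeholders, not the exact values used for the original runs):
263 | #   python train.py -t rotate -n mlp -b 64 -e 100 -g 0 -l 1e-3 \
264 | #       -d /path/to/cube_rotation/complete -r 1
265 | #   python train.py -t rotate -n representation_byol -b 64 -e 100 -g 0 -l 3e-4 \
266 | #       -i 224 -d /path/to/image_data -a -r 1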
--------------------------------------------------------------------------------
/arm_models/dataloaders/visual_dataset.py:
--------------------------------------------------------------------------------
1 | # Standard imports
2 | import os
3 | from PIL import Image
4 |
5 | import torch
6 | from torch.utils.data import Dataset, DataLoader
7 | from torchvision import transforms as T
8 |
9 | from tqdm import tqdm
10 |
11 | class RepresentationVisualDataset(Dataset):
12 | def __init__(self, task, image_size = 224, data_path = '/home/sridhar/dexterous_arm/demonstrations/image_data'):
13 | # Loading the images
14 | images_dir= os.path.join(data_path, task)
15 |
16 | # Sorting all the demo paths
17 | image_path_list = os.listdir(images_dir)
18 | image_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
19 |
20 | self.image_tensors = []
21 |
22 | # Transforming the images
23 | if task == "cube_rotation":
24 | mean = torch.tensor([0.4640, 0.4933, 0.5223])
25 | std = torch.tensor([0.2890, 0.2671, 0.2530])
26 | elif task == "object_flipping":
27 | mean = torch.tensor([0.4815, 0.5450, 0.5696])
28 | std = torch.tensor([0.2291, 0.2268, 0.2248])
29 | elif task == "fidget_spinning":
30 | mean = torch.tensor([0.4306, 0.3954, 0.3472])
31 | std = torch.tensor([0.2897, 0.2527, 0.2321])
32 |
33 | self.image_preprocessor = T.Compose([
34 | T.ToTensor(),
35 | T.Resize((image_size, image_size)),
36 | T.Normalize(
37 | mean = mean,
38 | std = std
39 | )
40 | ])
41 |
42 | # Loading all the images in the images vector
43 | print("Loading all the images: \n")
44 | for demo_images_path in tqdm(image_path_list):
45 | demo_image_folder_path = os.path.join(images_dir, demo_images_path)
46 |
47 | # Sort the demo images list
48 | demo_images_list = os.listdir(demo_image_folder_path)
49 |
50 | demo_images_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
51 |
52 | # Read each image and append them in the images array
53 | for idx in range(len(demo_images_list) - 1):
54 | try:
55 | image = Image.open(os.path.join(demo_image_folder_path, demo_images_list[idx]))
56 | if task == "cube_rotation":
57 | image = image.crop((500, 160, 950, 600)) # Left, Top, Right, Bottom
58 | elif task == "object_flipping":
59 | image = image.crop((220, 165, 460, 340)) # Left, Top, Right, Bottom
60 | elif task == "fidget_spinning":
61 | image = image.crop((65, 80, 590, 480)) # Left, Top, Right, Bottom
62 |
63 | image_tensor = self.image_preprocessor(image)
64 | self.image_tensors.append(image_tensor.detach())
65 | image.close()
66 |                 except Exception as e:
67 |                     print('Image could not be read: {}'.format(e))
68 | continue
69 |
70 | def __len__(self):
71 | return len(self.image_tensors)
72 |
73 | def __getitem__(self, idx):
74 | return (self.image_tensors[idx], torch.zeros(12))
75 |
76 | class CubeRotationVisualDataset(Dataset):
77 | def __init__(self, type = None, image_size = 224, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/cube_rotation'):
78 | if type == "train":
79 | self.data_path = os.path.join(data_path, "for_eval", "train")
80 | elif type == "val":
81 | self.data_path = os.path.join(data_path, "for_eval", "validation")
82 | elif type == "test":
83 | self.data_path = os.path.join(data_path, "for_eval", "test")
84 | else:
85 | self.data_path = os.path.join(data_path, "complete")
86 |
87 | # Loading all the image and action folder paths
88 | images_dir, actions_dir = os.path.join(self.data_path, 'images'), os.path.join(self.data_path, 'actions')
89 |
90 | # Sorting all the demo paths
91 | image_path_list, action_path_list = os.listdir(images_dir), os.listdir(actions_dir)
92 |
93 | image_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
94 | action_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
95 |
96 | self.image_tensors, self.action_tensors = [], []
97 |
98 | # Transforming the images
99 | if type == "train":
100 | mean = torch.tensor([0.4631, 0.4917, 0.5201])
101 | std = torch.tensor([0.2885, 0.2666, 0.2526])
102 | elif type == "val":
103 | mean = torch.tensor([0.3557, 0.4138, 0.4550])
104 | std = torch.tensor([0.9030, 0.8565, 0.8089])
105 | elif type == "load":
106 | mean = torch.tensor([0, 0, 0])
107 | std = torch.tensor([1, 1, 1])
108 | else:
109 | mean = torch.tensor([0.4631, 0.4923, 0.5215])
110 | std = torch.tensor([0.2891, 0.2674, 0.2535])
111 |
112 | self.image_preprocessor = T.Compose([
113 | T.ToTensor(),
114 | T.Resize((image_size, image_size)),
115 | T.Normalize(
116 | mean = mean,
117 | std = std
118 | )
119 | ])
120 |
121 | # Loading all the images in the images vector
122 | print("Loading all the images: \n")
123 | for demo_images_path in tqdm(image_path_list):
124 | demo_image_folder_path = os.path.join(images_dir, demo_images_path)
125 |
126 | # Sort the demo images list
127 | demo_images_list = os.listdir(demo_image_folder_path)
128 | demo_images_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
129 |
130 | # print("Reading images from {}".format(demo_image_folder_path))
131 |
132 | # Read each image and append them in the images array
133 | for idx in range(len(demo_images_list) - 1):
134 | try:
135 | image = Image.open(os.path.join(demo_image_folder_path, demo_images_list[idx]))
136 | image = image.crop((500, 160, 950, 600)) # Left, Top, Right, Bottom
137 | image_tensor = self.image_preprocessor(image)
138 | self.image_tensors.append(image_tensor.detach())
139 | image.close()
140 |                 except Exception as e:
141 |                     print('Image could not be read: {}'.format(e))
142 | continue
143 |
144 | # Loading all the action vectors
145 | print("\nLoading all the actions: \n")
146 | for demo_action_path in tqdm(action_path_list):
147 | demo_actions = torch.load(os.path.join(actions_dir, demo_action_path))
148 | self.action_tensors.append(demo_actions)
149 |
150 | self.action_tensors = torch.cat(self.action_tensors, dim = 0)
151 |
152 | def __len__(self):
153 | return len(self.image_tensors)
154 |
155 | def __getitem__(self, idx):
156 |         return (self.image_tensors[idx], self.action_tensors[idx])
157 |
158 | class ObjectFlippingVisualDataset(Dataset):
159 | def __init__(self, type = None, image_size = 224, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/object_flipping'):
160 | if type == "train":
161 | self.data_path = os.path.join(data_path, "for_eval", "train")
162 | elif type == "val":
163 | self.data_path = os.path.join(data_path, "for_eval", "validation")
164 | elif type == "test":
165 | self.data_path = os.path.join(data_path, "for_eval", "test")
166 | else:
167 | self.data_path = os.path.join(data_path, "complete")
168 |
169 | # Loading all the image and action folder paths
170 | images_dir, actions_dir = os.path.join(self.data_path, 'images'), os.path.join(self.data_path, 'actions')
171 |
172 | # Sorting all the demo paths
173 | image_path_list, action_path_list = os.listdir(images_dir), os.listdir(actions_dir)
174 |
175 | image_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
176 | action_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
177 |
178 | self.image_tensors, self.action_tensors = [], []
179 |
180 | # Transforming the images
181 | if type == "train":
182 | mean = torch.tensor([0.4631, 0.4917, 0.5201]) # Recalculate
183 | std = torch.tensor([0.2885, 0.2666, 0.2526]) # Recalculate
184 | elif type == "val":
185 | mean = torch.tensor([0.3557, 0.4138, 0.4550]) # Recalculate
186 | std = torch.tensor([0.9030, 0.8565, 0.8089]) # Recalculate
187 | elif type == "load":
188 | mean = torch.tensor([0, 0, 0])
189 | std = torch.tensor([1, 1, 1])
190 | else:
191 | mean = torch.tensor([0.4534, 0.3770, 0.3885])
192 | std = torch.tensor([0.2512, 0.1881, 0.2599])
193 |
194 | self.image_preprocessor = T.Compose([
195 | T.ToTensor(),
196 | T.Resize((image_size, image_size)),
197 | T.Normalize(
198 | mean = mean,
199 |                 std = std
200 | )
201 | ])
202 |
203 | # Loading all the images in the images vector
204 | print("Loading all the images: \n")
205 | for demo_images_path in tqdm(image_path_list):
206 | demo_image_folder_path = os.path.join(images_dir, demo_images_path)
207 |
208 | # Sort the demo images list
209 | demo_images_list = os.listdir(demo_image_folder_path)
210 | demo_images_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
211 |
212 | # Read each image and append them in the images array
213 | for idx in range(len(demo_images_list) - 1):
214 | try:
215 | image = Image.open(os.path.join(demo_image_folder_path, demo_images_list[idx]))
216 | image = image.crop((220, 165, 460, 340)) # Left, Top, Right, Bottom
217 | image_tensor = self.image_preprocessor(image)
218 | self.image_tensors.append(image_tensor.detach())
219 | image.close()
220 |                 except Exception as e:
221 |                     print('Image could not be read: {}'.format(e))
222 | continue
223 |
224 | # Loading all the action vectors
225 | print("\nLoading all the actions: \n")
226 | for demo_action_path in tqdm(action_path_list):
227 | demo_actions = torch.load(os.path.join(actions_dir, demo_action_path))
228 | self.action_tensors.append(demo_actions)
229 |
230 | self.action_tensors = torch.cat(self.action_tensors, dim = 0)
231 |
232 | def __len__(self):
233 | return len(self.image_tensors)
234 |
235 | def __getitem__(self, idx):
236 |         return (self.image_tensors[idx], self.action_tensors[idx])
237 |
238 | class FidgetSpinningVisualDataset(Dataset):
239 | def __init__(self, type = None, image_size = 224, data_path = '/home/sridhar/dexterous_arm/models/arm_models/data/fidget_spinning'):
240 | if type == "train":
241 | self.data_path = os.path.join(data_path, "for_eval", "train")
242 | elif type == "val":
243 | self.data_path = os.path.join(data_path, "for_eval", "validation")
244 | elif type == "test":
245 | self.data_path = os.path.join(data_path, "for_eval", "test")
246 | else:
247 | self.data_path = os.path.join(data_path, "complete")
248 |
249 | # Loading all the image and action folder paths
250 | images_dir, actions_dir = os.path.join(self.data_path, 'images'), os.path.join(self.data_path, 'actions')
251 |
252 | # Sorting all the demo paths
253 | image_path_list, action_path_list = os.listdir(images_dir), os.listdir(actions_dir)
254 |
255 | image_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
256 | action_path_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
257 |
258 | self.image_tensors, self.action_tensors = [], []
259 |
260 | # Transforming the images
261 | if type == "train":
262 | mean = torch.tensor([0.4631, 0.4917, 0.5201]) # To be recalculated
263 | std = torch.tensor([0.2885, 0.2666, 0.2526]) # To be recalculated
264 | elif type == "val":
265 | mean = torch.tensor([0.3557, 0.4138, 0.4550]) # To be recalculated
266 | std = torch.tensor([0.9030, 0.8565, 0.8089]) # To be recalculated
267 | elif type == "load":
268 | mean = torch.tensor([0, 0, 0])
269 | std = torch.tensor([1, 1, 1])
270 | else:
271 | mean = torch.tensor([0.4320, 0.3963, 0.3478])
272 | std = torch.tensor([0.2897, 0.2525, 0.2317])
273 |
274 | self.image_preprocessor = T.Compose([
275 | T.ToTensor(),
276 | T.Resize((image_size, image_size)),
277 | T.Normalize(
278 | mean = mean,
279 | std = std
280 | )
281 | ])
282 |
283 | # Loading all the images in the images vector
284 | print("Loading all the images: \n")
285 | for demo_images_path in tqdm(image_path_list):
286 | demo_image_folder_path = os.path.join(images_dir, demo_images_path)
287 |
288 | # Sort the demo images list
289 | demo_images_list = os.listdir(demo_image_folder_path)
290 | demo_images_list.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
291 |
292 | # print("Reading images from {}".format(demo_image_folder_path))
293 |
294 | # Read each image and append them in the images array
295 | for idx in range(len(demo_images_list) - 1):
296 | try:
297 | image = Image.open(os.path.join(demo_image_folder_path, demo_images_list[idx]))
298 | image = image.crop((65, 80, 590, 480)) # Left, Top, Right, Bottom
299 | image_tensor = self.image_preprocessor(image)
300 | self.image_tensors.append(image_tensor.detach())
301 | image.close()
302 |                 except Exception as e:
303 |                     print('Image could not be read: {}'.format(e))
304 | continue
305 |
306 | # Loading all the action vectors
307 | print("\nLoading all the actions: \n")
308 | for demo_action_path in tqdm(action_path_list):
309 | demo_actions = torch.load(os.path.join(actions_dir, demo_action_path))
310 | self.action_tensors.append(demo_actions)
311 |
312 | self.action_tensors = torch.cat(self.action_tensors, dim = 0)
313 |
314 | def __len__(self):
315 | return len(self.image_tensors)
316 |
317 | def __getitem__(self, idx):
318 |         return (self.image_tensors[idx], self.action_tensors[idx])
319 |
320 | if __name__ == '__main__':
321 |
322 | # To find the mean and std of the image pixels for normalization
323 | # dataset = CubeRotationVisualDataset()
324 | # dataset = ObjectFlippingVisualDataset()
325 | # dataset = FidgetSpinningVisualDataset()
326 | dataset = RepresentationVisualDataset("fidget_spinning")
327 | print("Number of images in the dataset: {}\n".format(len(dataset)))
328 |
329 | dataloader = DataLoader(dataset, batch_size = 128, shuffle = False, pin_memory = True, num_workers = 24)
330 |
331 | psum = torch.tensor([0.0, 0.0, 0.0])
332 | psum_sq = torch.tensor([0.0, 0.0, 0.0])
333 |
334 | for images, actions in tqdm(dataloader):
335 | psum += images.sum(axis = [0, 2, 3])
336 | psum_sq += (images ** 2).sum(axis = [0, 2, 3])
337 |
338 | count = len(dataset) * 224 * 224
339 |
340 | total_mean = psum / count
341 | total_var = (psum_sq / count) - (total_mean ** 2)
342 | total_std = torch.sqrt(total_var)
343 |
344 | print("Mean: {}".format(total_mean))
345 | print("Std: {}".format(total_std))
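346 | 
347 |     # Editor's note: the statistics above rely on the per-channel identity
348 |     # Var[X] = E[X^2] - (E[X])^2. A standalone sanity check, assuming a random
349 |     # batch in place of the dataset:
350 |     #
351 |     #   x = torch.rand(16, 3, 224, 224)
352 |     #   n = x.shape[0] * 224 * 224
353 |     #   mean = x.sum(axis = [0, 2, 3]) / n
354 |     #   std = torch.sqrt((x ** 2).sum(axis = [0, 2, 3]) / n - mean ** 2)
355 |     #   ref = x.permute(1, 0, 2, 3).reshape(3, -1).std(dim = 1, unbiased = False)
356 |     #   assert torch.allclose(std, ref, atol = 1e-4)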
--------------------------------------------------------------------------------
/arm_models/deploy/deploy.py:
--------------------------------------------------------------------------------
1 | # Standard imports
2 | import os
3 | import numpy as np
4 | import torch
5 | import cv2
6 | from PIL import Image as PILImage
7 |
8 | # ROS imports
9 | import rospy
10 | from sensor_msgs.msg import Image, JointState
11 | from visualization_msgs.msg import Marker
12 | from cv_bridge import CvBridge, CvBridgeError
13 |
14 | # Controller imports
15 | from ik_teleop.ik_core.allegro_ik import AllegroInvKDL
16 | from move_dexarm import DexArmControl
17 |
18 | # Model imports
19 | from arm_models.imitation.networks import *
20 | from arm_models.deploy.model_scripts import inn, visual_inn
21 | from arm_models.utils.augmentation import augmentation_generator
22 |
23 | # Miscellaneous imports
24 | from copy import deepcopy as copy
25 | import argparse
26 |
27 | # Other torch imports
28 | from torchvision import transforms as T
29 |
30 | # ROS Topic to get data
31 | ROBOT_IMAGE_TOPIC = "/cam_1/color/image_raw"
32 | AR_MARKER_TOPIC = "/visualization_marker"
33 | HAND_JOINT_STATE_TOPIC = "/allegroHand/joint_states"
34 | ARM_JOINT_STATE_TOPIC = "/j2n6s300_driver/out/joint_state"
35 |
36 | DATA_PATH = os.path.join(os.path.abspath(os.pardir), "data")
37 | URDF_PATH = "/home/sridhar/dexterous_arm/ik_stuff/ik_teleop/urdf_template/allegro_right.urdf"
38 |
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument('-t', '--task', type=str)
41 | parser.add_argument('-m', '--model', type=str)
42 | parser.add_argument('-d', '--device', type=int)
43 | parser.add_argument('--delta', type=int, default = 2)   # threshold = 0.01 * delta
44 | parser.add_argument('--run', type=int)
45 |
46 | # Specific to INN or VINN
47 | parser.add_argument('--k', type=int)
48 | parser.add_argument('-p', '--target_priority', type=int, default = 1)
49 |
50 |
51 | class DexArmDeploy():
52 | def __init__(self, task, model, run, k = None, object_priority = 1, device = "cpu", threshold = 0.02):
53 | # Initializing ROS node for the dexterous arm
54 | print("Initializing ROS node.\n")
55 | rospy.init_node("deploy_node")
56 |
57 | self.task = task
58 | self.model_to_use = model
59 | self.threshold = threshold
60 | self.device = device
61 | self.k = k
62 |
63 | # Initializing the dexterous arm controller
64 | print("Initializing the controller for the dexterous arm.\n")
65 | self.arm = DexArmControl()
66 |
67 | # Positioning robot based on the task
68 | print("Positioning the dexterous arm to perform task: {}\n".format(task))
69 | # if task == "rotate" or task == "flip":
70 | # self.arm.home_robot()
71 | # elif task == "spin":
72 | # self.arm.spin_pos_arm()
73 |
74 | # Initializing the Inverse Kinematics solver for the Allegro Hand
75 | self.allegro_ik = AllegroInvKDL(cfg = None, urdf_path = URDF_PATH)
76 |
77 | # Initializing the model for the specific task
78 | if self.model_to_use == "inn":
79 | print("Initializing INN for task {} with k = {}.\n".format(task, k))
80 | self.model = self._init_INN(task, k, object_priority, device)
81 | elif self.model_to_use == "vinn":
82 | print("Initializing VINN for task {} with k = {}.\n".format(task, k))
83 | self.model, self.learner = self._init_VINN(task, k, device, run)
84 | elif self.model_to_use == "mlp":
85 | print("Initializing a Simple MLP model for task {}.\n".format(task))
86 | self.model = self._init_MLP(task, device, run)
87 | elif self.model_to_use == "bc":
88 | print("Initializing a Behavior Cloning model for task {}.\n".format(task))
89 | self.model = self._init_BC(task, device, run)
90 |
91 | # Loading image transformer for image based models
92 | if self.model_to_use == "vinn" or self.model_to_use == "bc":
93 | if self.task == "rotate":
94 | self.image_transform = T.Compose([
95 | T.ToTensor(),
96 | T.Resize((224, 224)),
97 | T.Normalize(
98 | mean = torch.tensor([0.4631, 0.4923, 0.5215]),
99 | std = torch.tensor([0.2891, 0.2674, 0.2535])
100 | )
101 | ])
102 | elif self.task == "flip":
103 | self.image_transform = T.Compose([
104 | T.ToTensor(),
105 | T.Resize((224, 224)),
106 | T.Normalize(
107 | mean = torch.tensor([0.4534, 0.3770, 0.3885]),
108 | std = torch.tensor([0.2512, 0.1881, 0.2599])
109 | )
110 | ])
111 | elif self.task == "spin":
112 | self.image_transform = T.Compose([
113 | T.ToTensor(),
114 | T.Resize((224, 224)),
115 | T.Normalize(
116 | mean = torch.tensor([0.4306, 0.3954, 0.3472]),
117 | std = torch.tensor([0.2897, 0.2527, 0.2321])
118 | )
119 | ])
120 |
121 | # Realtime data obtained through ROS
122 | self.allegro_joint_state = None
123 | self.kinova_joint_state = None
124 |
125 | self.bridge = CvBridge()
126 | self.robot_image = None
127 |
128 | rospy.Subscriber(HAND_JOINT_STATE_TOPIC, JointState, self._callback_allegro_joint_state, queue_size = 1)
129 | rospy.Subscriber(ARM_JOINT_STATE_TOPIC, JointState, self._callback_kinova_joint_state, queue_size = 1)
130 | rospy.Subscriber(ROBOT_IMAGE_TOPIC, Image, self._callback_robot_image, queue_size=1)
131 |
132 | if self.task == "rotate" or self.task == "flip":
133 | rospy.Subscriber(AR_MARKER_TOPIC, Marker, self._callback_ar_marker_data, queue_size = 1)
134 |
135 | if self.task == "rotate":
136 | self.object_tracked_position = None
137 | self.hand_base_position = None
138 | elif self.task == "flip":
139 | self.object_tracked_position = None
140 |
141 | def _callback_allegro_joint_state(self, data):
142 | self.allegro_joint_state = data
143 |
144 | def _callback_kinova_joint_state(self, data):
145 | self.kinova_joint_state = data
146 |
147 | def _callback_robot_image(self, image):
148 | try:
149 | self.robot_image = self.bridge.imgmsg_to_cv2(image, "bgr8")
150 | except CvBridgeError as e:
151 | print(e)
152 |
153 | def _callback_ar_marker_data(self, data):
154 | if data.id == 0 or data.id == 5:
155 | self.object_tracked_position = np.array([data.pose.position.x, data.pose.position.y, data.pose.position.z])
156 | elif data.id == 8:
157 | self.hand_base_position = np.array([data.pose.position.x, data.pose.position.y, data.pose.position.z])
158 |
159 | def _get_object_position(self):
160 | if self.task == "rotate":
161 | return self.object_tracked_position - self.hand_base_position
162 | elif self.task == "flip":
163 | return self.object_tracked_position
164 |
165 | def _init_INN(self, task, k, object_priority, device):
166 | # Getting the path to load the data
167 | if task == "rotate":
168 | folder = "cube_rotation"
169 | elif task == "flip":
170 | folder = "object_flipping"
171 | elif task == "spin":
172 | folder = "fidget_spinning"
173 | task_data_path = os.path.join(DATA_PATH, folder, "complete")
174 |
175 | # Initializing the model based on k to enable debug
176 | return inn.INNDeploy(
177 | k = k,
178 | data_path = task_data_path,
179 |             load_image_data = (k == 1),
180 | target_priority = object_priority,
181 | device = device
182 | )
183 |
184 | def _init_VINN(self, task, k, device, run):
185 | # Loading the checkpoint based on the task
186 | chkpt_path = os.path.join(os.getcwd(), "checkpoints", "representation_byol - {} - lowest - train - v{}.pth".format(task, run))
187 |
188 | # Loading the representation encoder
189 | original_encoder_model = models.resnet50(pretrained = True)
190 | encoder = torch.nn.Sequential(*(list(original_encoder_model.children())[:-1]))
191 | encoder = encoder.to(device)
192 |
193 |         learner = BYOL(
194 | encoder,
195 | image_size = 224
196 | )
197 | learner.load_state_dict(torch.load(chkpt_path))
198 | learner.eval()
199 |
200 | # Loading the VINN model
201 | model = visual_inn.VINNDeploy(
202 | k = k,
203 |             load_image_data = (k == 1),
204 | device = device
205 | )
206 |
207 | return model, learner
208 |
209 | def _init_MLP(self, task, device, run):
210 | # Loading the checkpoint based on the task
211 | chkpt_path = os.path.join(os.getcwd(), "checkpoints", "mlp - {} - lowest - train - v{}.pth".format(task, run))
212 |
213 | model = MLP().to(torch.device(device))
214 | model.load_state_dict(torch.load(chkpt_path))
215 | model.eval()
216 | return model
217 |
218 | def _init_BC(self, task, device, run):
219 | # Loading the checkpoint based on the task
220 | chkpt_path = os.path.join(os.getcwd(), "checkpoints", "bc - {} - lowest - train - v{}.pth".format(task, run))
221 |
222 | model = BehaviorCloning().to(torch.device(device))
223 | model.load_state_dict(torch.load(chkpt_path))
224 | model.eval()
225 | return model
226 |
227 | def _get_tip_coords(self):
228 | index_coord = self.allegro_ik.finger_forward_kinematics('index', list(self.allegro_joint_state.position)[:4])[0]
229 | middle_coord = self.allegro_ik.finger_forward_kinematics('middle', list(self.allegro_joint_state.position)[4:8])[0]
230 | ring_coord = self.allegro_ik.finger_forward_kinematics('ring', list(self.allegro_joint_state.position)[8:12])[0]
231 | thumb_coord = self.allegro_ik.finger_forward_kinematics('thumb', list(self.allegro_joint_state.position)[12:16])[0]
232 |
233 | return index_coord, middle_coord, ring_coord, thumb_coord
234 |
235 | def _update_joint_state(self, index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord):
236 | current_joint_angles = list(self.allegro_joint_state.position)
237 |
238 | index_joint_angles = self.allegro_ik.finger_inverse_kinematics('index', index_tip_coord, current_joint_angles[0:4])
239 | middle_joint_angles = self.allegro_ik.finger_inverse_kinematics('middle', middle_tip_coord, current_joint_angles[4:8])
240 | ring_joint_angles = self.allegro_ik.finger_inverse_kinematics('ring', ring_tip_coord, current_joint_angles[8:12])
241 | thumb_joint_angles = self.allegro_ik.finger_inverse_kinematics('thumb', thumb_tip_coord, current_joint_angles[12:16])
242 |
243 | desired_joint_angles = copy(current_joint_angles)
244 |
245 | for idx in range(4):
246 | desired_joint_angles[idx] = index_joint_angles[idx]
247 | desired_joint_angles[4 + idx] = middle_joint_angles[idx]
248 | desired_joint_angles[8 + idx] = ring_joint_angles[idx]
249 | desired_joint_angles[12 + idx] = thumb_joint_angles[idx]
250 |
251 | return desired_joint_angles
252 |
253 | def start(self, time_loop = False):
254 | print("\nDeploying model: {}\n".format(self.model_to_use))
255 |
256 | if time_loop is True:
257 |         rate = rospy.Rate(2) # 2 Hz loop rate
258 |
259 | while True:
260 | if time_loop is False:
261 |                 input()   # block until a key is pressed before each step
262 |
263 | # Checking if all the data streams are working
264 | if self.allegro_joint_state is None:
265 | print('No allegro state received!')
266 | continue
267 |
268 | if self.robot_image is None:
269 | print('No robot image received!')
270 | continue
271 |
272 | if self.task == "rotate":
273 | if self.object_tracked_position is None:
274 | print("Object cannot be tracked!")
275 | continue
276 |
277 | if self.hand_base_position is None:
278 | print("Hand base position not found!")
279 | continue
280 |
281 | if self.task == "flip":
282 | if self.object_tracked_position is None:
283 | continue
284 |
285 | # Performing the step
286 | print("********************************************************\n Starting new step \n********************************************************")
287 |
288 | # Displaying the current state data
289 | index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord = self._get_tip_coords()
290 | finger_state = list(index_tip_coord) + list(middle_tip_coord) + list(ring_tip_coord) + list(thumb_tip_coord)
291 | print("Current finger-tip positions:\n Index-tip: {}\n Middle-tip: {}\n Ring-tip: {}\n Thumb-tip: {}\n".format(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord))
292 |
293 | if self.k != 1:
294 | cv2.imshow("Current Robot State Image", self.robot_image)
295 | cv2.waitKey(1)
296 |
297 | if self.task == "rotate" or self.task == "flip":
298 | object_position = list(self._get_object_position())
299 | print("Object is located at position: {}".format(object_position))
300 | elif self.task == "spin":
301 | object_position = [0, 0, 0]
302 |
303 | if self.model_to_use == "vinn" or self.model_to_use == "bc":
304 | pil_image = PILImage.fromarray(cv2.cvtColor(self.robot_image, cv2.COLOR_BGR2RGB))
305 | if self.task == "rotate":
306 | cropped_image = pil_image.crop((500, 160, 950, 600)) # Left, Top, Right, Bottom
307 | elif self.task == "flip": # TODO
308 | cropped_image = pil_image.crop((220, 165, 460, 340)) # Left, Top, Right, Bottom
309 | elif self.task == "spin": # TODO
310 | cropped_image = pil_image.crop((65, 80, 590, 480)) # Left, Top, Right, Bottom
311 |
312 | transformed_image_tensor = self.image_transform(cropped_image).unsqueeze(0)
313 |
314 | # Obtaining the actions through each model
315 | if self.model_to_use == "inn":
316 | if self.k == 1:
317 | print("\nObtaining the action!\n")
318 | action, neighbor_index, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.get_action_with_image(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position)
319 | action = list(action.reshape(12))
320 |
321 | l2_diff = np.linalg.norm(np.array(finger_state) - np.array(action))
322 |
323 | if l2_diff >= self.threshold:
324 | print("Distances:\n Index distance: {}\n Middle distance: {}\n Ring distance: {}\n Thumb distance: {}\n Cube distance: {}".format(index_l2_diff.item(), middle_l2_diff.item(), ring_l2_diff.item(), thumb_l2_diff.item(), object_l2_diff.item()))
325 |
326 | while l2_diff < self.threshold:
327 | print("Action distance is below threshold: {}".format(l2_diff))
328 | action, neighbor_index, nn_state_image_path, trans_state_image_path = self.model.get_debug_action(neighbor_index)
329 | # action, neighbor_index, nn_state_image_path, trans_state_image_path, index_l2_diff, middle_l2_diff, ring_l2_diff, thumb_l2_diff, object_l2_diff = self.model.get_action_with_image(action[:3], action[3:6], action[6:9], action[9:12], object_position)
330 | action = list(action.reshape(12))
331 | print("Primary neighbor index:", neighbor_index)
332 |
333 | l2_diff = np.linalg.norm(np.array(finger_state) - np.array(action))
334 |
335 | # Reading the Nearest Neighbor images
336 | nn_state_image = cv2.imread(nn_state_image_path)
337 | trans_state_image = cv2.imread(trans_state_image_path)
338 |
339 |                     # Combining the current state image with the NN images and displaying them
340 | combined_images = np.concatenate((
341 | cv2.resize(self.robot_image, (420, 240), interpolation = cv2.INTER_AREA),
342 | cv2.resize(nn_state_image, (420, 240), interpolation = cv2.INTER_AREA),
343 | cv2.resize(trans_state_image, (420, 240), interpolation = cv2.INTER_AREA)
344 | ), axis=1)
345 |
346 | cv2.imshow("Current state and Nearest Neighbor Images", combined_images)
347 | cv2.waitKey(1)
348 |
349 | else:
350 | print("Obtaining the action!\n")
351 | action = list(self.model.get_action(index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position).reshape(12))
352 | l2_diff = np.linalg.norm(np.array(finger_state) - np.array(action))
353 |
354 | while l2_diff < self.threshold:
355 | print("Action distance is below threshold: {}".format(l2_diff))
356 | action = list(self.model.get_action(action[:3], action[3:6], action[6:9], action[9:12], object_position).reshape(12))
357 | l2_diff = np.linalg.norm(np.array(finger_state) - np.array(action))
358 |
359 | elif self.model_to_use == "vinn":
360 | print("Image shape:", transformed_image_tensor.float().to(self.device).shape)
361 | representation = self.learner.net(transformed_image_tensor.float().to(self.device)).squeeze().detach().cpu().numpy()
362 | print("Representation shape:", representation.shape)
363 | if self.k == 1:
364 | print("Obtaining the action!\n")
365 | action, neighbor_idx, nn_state_image_path, trans_state_image_path = self.model.get_action_with_image(representation)
366 | action = list(action.reshape(12))
367 |
368 | l2_diff = np.linalg.norm(np.array(finger_state[:12]) - np.array(action))
369 |
370 | while l2_diff < self.threshold:
371 | print("Action distance is below threshold. Rolling out another action from the same trajectory!\n")
372 | action, neighbor_idx, nn_state_image_path, trans_state_image_path = self.model.get_debug_action(neighbor_idx)
373 | action = list(action.reshape(12))
374 |
375 | l2_diff = np.linalg.norm(np.array(finger_state[:12]) - np.array(action))
376 |
377 | # Reading the Nearest Neighbor images
378 | nn_state_image = cv2.imread(nn_state_image_path)
379 | trans_state_image = cv2.imread(trans_state_image_path)
380 |
381 |                     # Combining the current state image with the NN images and displaying them
382 | combined_images = np.concatenate((
383 | cv2.resize(self.robot_image, (420, 240), interpolation = cv2.INTER_AREA),
384 | cv2.resize(nn_state_image, (420, 240), interpolation = cv2.INTER_AREA),
385 | cv2.resize(trans_state_image, (420, 240), interpolation = cv2.INTER_AREA)
386 | ), axis=1)
387 |
388 | cv2.imshow("Current state and Nearest Neighbor Images", combined_images)
389 | cv2.waitKey(1)
390 |
391 | else:
392 | action = list(self.model.get_action(representation).reshape(12))
393 |
394 | elif self.model_to_use == "mlp":
395 | input_coordinates = torch.tensor(np.concatenate([index_tip_coord, middle_tip_coord, ring_tip_coord, thumb_tip_coord, object_position]))
396 | action = self.model(input_coordinates.float().to(self.device)).detach().cpu().numpy()[0]
397 |
398 | elif self.model_to_use == "bc":
399 | action = self.model(transformed_image_tensor.float().to(self.device)).detach().cpu().numpy()[0]
400 |
401 | # Updating the hand target coordinates
402 | updated_index_tip_coord = action[0:3]
403 | updated_middle_tip_coord = action[3:6]
404 | updated_ring_tip_coord = action[6:9]
405 | updated_thumb_tip_coord = action[9:12]
406 |
407 | print("Corresponding finger-tip action:\n Index-tip: {}\n Middle-tip: {}\n Ring-tip: {}\n Thumb-tip: {}\n".format(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord))
408 |
409 | desired_joint_angles = self._update_joint_state(updated_index_tip_coord, updated_middle_tip_coord, updated_ring_tip_coord, updated_thumb_tip_coord)
410 |
411 | print("Moving arm to {}\n".format(desired_joint_angles))
412 | self.arm.move_hand(desired_joint_angles)
413 |
414 | if time_loop is True:
415 | rate.sleep()
416 |
417 |
418 | if __name__ == '__main__':
419 | # Getting options
420 | options = parser.parse_args()
421 |
422 | d = DexArmDeploy(
423 | task = options.task,
424 | model = options.model,
425 | run = options.run,
426 | k = options.k,
427 | object_priority = options.target_priority,
428 | device = "cpu",
429 | threshold = 0.01 * options.delta
430 | )
431 |
432 | d.start()
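433 | 
434 |     # Illustrative invocations (editor's sketch; run numbers are placeholders):
435 |     #   python deploy.py -t rotate -m inn --k 1 -p 4 --delta 2
436 |     #   python deploy.py -t rotate -m vinn --k 1 --run 1 --delta 2
437 |     #   python deploy.py -t spin -m bc --run 1 --delta 2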
--------------------------------------------------------------------------------