├── utils ├── __init__.py ├── constant.py ├── trajectory_utils.py ├── util.py ├── data_augmentation.py ├── realworld_utils.py └── image_generation.py ├── dataset ├── __init__.py ├── trajectory_optimization.py └── dataset.py ├── assets └── img │ ├── model.png │ └── action_relabeling.png ├── command_eval.sh ├── command_train.sh ├── policy ├── README.md ├── diffusion_policy │ ├── robomimic_replay_lowdim_dataset.py │ └── sampler.py └── robomimic │ └── rollout.py ├── LICENSE ├── networks └── resnet.py ├── README.md ├── losses.py ├── eval.py ├── environment.yaml └── train.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junxix/S2I/HEAD/assets/img/model.png -------------------------------------------------------------------------------- /assets/img/action_relabeling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junxix/S2I/HEAD/assets/img/action_relabeling.png -------------------------------------------------------------------------------- /command_eval.sh: -------------------------------------------------------------------------------- 1 | python eval.py --train_data_folder ./lowdim_samples.npy --val_data_folder ./low_dim.hdf5 --size 128 --ckpt ./ckpts/ckpt_epoch_2000.pth -------------------------------------------------------------------------------- /command_train.sh: -------------------------------------------------------------------------------- 1 | python train.py --batch_size 256 --learning_rate 0.005 --temp 0.1 --cosine --aug_path 
./lowdim_samples.npy --method SupCon --epochs 2500 --save_freq 100 --print_freq 1 --size 128 --save_mode realworld -------------------------------------------------------------------------------- /utils/constant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from robomimic.envs.env_base import EnvBase, EnvType 3 | 4 | IMG_MEAN = np.array([0.485, 0.456, 0.406])  # per-channel RGB mean passed to transforms.Normalize (matches the common ImageNet normalization values) 5 | IMG_STD = np.array([0.229, 0.224, 0.225])  # per-channel RGB std, paired with IMG_MEAN 6 | 7 | DEFAULT_CAMERAS = {  # default render camera name(s) per robomimic env type 8 | EnvType.ROBOSUITE_TYPE: ["agentview"], 9 | EnvType.IG_MOMART_TYPE: ["rgb"], 10 | EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"),  # NOTE: an exception *instance* is stored here, not raised; a lookup for gym envs returns this object 11 | } 12 | 13 | CAMERA_NAME = 'cam_750612070851'  # real-world camera sub-directory name, joined as <demo>/<CAMERA_NAME>/color 14 | 15 | DELTA_THETA = 75  # NOTE(review): not referenced in the visible files; meaning and units unconfirmed 16 | -------------------------------------------------------------------------------- /policy/README.md: -------------------------------------------------------------------------------- 1 | # Manipulation Policy 2 | ## BC-RNN 3 | For state-based [BC-RNN](https://github.com/ARISE-Initiative/robomimic), we modified the rollout program to ensure that during simulation evaluation, 50 starting positions are randomly selected. These positions vary across different seeds but remain consistent within the same seed. 4 | 5 | Here are the argument explanations in the rollout process: 6 | * `--config` : Specifies the configuration for the algorithm's structure. 7 | * `--dataset` : The path to the dataset used for training and loading environment parameters. 8 | * `--checkpoint_dir` : The directory containing the checkpoints to be evaluated. 
9 | 10 | ## Diffusion Policy 11 | 12 |
13 | ![Action relabeling](../assets/img/action_relabeling.png) 14 |
15 | 16 | For the [diffusion policy](https://github.com/real-stanford/diffusion_policy) method, we slightly modified the sampling process to align with the action relabeling procedure. Specifically, we adjusted the sampling to ensure that each sampled state corresponds to its preceding state and action, matching the process depicted in the action relabeling diagram. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jingjing Chen, Hongjie Fang, Hao-Shu Fang, Cewu Lu 4 | Copyright (c) 2024 Shanghai Jiao Tong University 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
-------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | 6 | 7 | 8 | class SupConResNet(nn.Module): 9 | """backbone + projection head""" 10 | def __init__(self, name='resnet50', head='mlp', feat_dim=128, pretrained=True): 11 | super(SupConResNet, self).__init__() 12 | if name == 'resnet50': 13 | self.encoder = models.resnet50(pretrained=pretrained) 14 | dim_in = 2048 15 | elif name == 'resnet18': 16 | self.encoder = models.resnet18(pretrained=pretrained) 17 | dim_in = 512 18 | else: 19 | raise ValueError(f"Model {name} not supported") 20 | self.encoder = nn.Sequential(*list(self.encoder.children())[:-1]) 21 | for param in self.encoder.parameters(): 22 | param.requires_grad = False 23 | 24 | if head == 'linear': 25 | self.head = nn.Linear(dim_in, feat_dim) 26 | elif head == 'mlp': 27 | self.head = nn.Sequential( 28 | nn.Linear(dim_in, dim_in), 29 | nn.ReLU(inplace=True), 30 | nn.Linear(dim_in, feat_dim) 31 | ) 32 | else: 33 | raise NotImplementedError( 34 | 'head not supported: {}'.format(head)) 35 | 36 | def forward(self, x): 37 | feat = self.encoder(x) 38 | feat = torch.flatten(feat, 1) 39 | feat = F.normalize(self.head(feat), dim=1) 40 | return feat 41 | 42 | 43 | -------------------------------------------------------------------------------- /utils/trajectory_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import wandb 3 | from scipy.spatial.transform import Rotation 4 | import robosuite.utils.transform_utils as T 5 | 6 | def find_all_change_indices(data): 7 | changes_indices = [] 8 | prev_value = data[0][-1] 9 | 10 | for i in range(1, len(data)): 11 | curr_value = data[i][-1] 12 | if (prev_value == -1 and curr_value == 1) or (prev_value == 1 and 
curr_value == -1): 13 | changes_indices.append(i) 14 | prev_value = curr_value 15 | 16 | return changes_indices 17 | 18 | def slice_trajectory_and_states(trajectory_points, states, change_indices): 19 | trajectory_slices, state_slices = [], [] 20 | start_idx = 0 21 | 22 | for idx in change_indices: 23 | if idx - start_idx >= 10: 24 | trajectory_slices.append(trajectory_points[start_idx:idx]) 25 | state_slices.append(states[start_idx:idx]) 26 | start_idx = idx 27 | 28 | # Handle the last slice if it's long enough 29 | if len(trajectory_points) - start_idx > 15: 30 | trajectory_slices.append(trajectory_points[start_idx:]) 31 | state_slices.append(states[start_idx:]) 32 | 33 | return trajectory_slices, state_slices 34 | 35 | def compute_errors(actions, gt_states, waypoints, return_list=False): 36 | if waypoints[0] != 0: 37 | waypoints = [0] + waypoints 38 | 39 | gt_pos = [p["robot0_eef_pos"] for p in gt_states] 40 | gt_quat = [p["robot0_eef_quat"] for p in gt_states] 41 | keypoints_pos = [actions[k, :3] for k in waypoints] 42 | keypoints_quat = [gt_quat[k] for k in waypoints] 43 | 44 | errors = [] 45 | 46 | for i in range(len(waypoints) - 1): 47 | start_idx = waypoints[i] 48 | end_idx = waypoints[i + 1] 49 | 50 | segment_pos = gt_pos[start_idx:end_idx] 51 | segment_quat = gt_quat[start_idx:end_idx] 52 | 53 | for j in range(len(segment_pos)): 54 | line_vector = keypoints_pos[i + 1] - keypoints_pos[i] 55 | point_vector = segment_pos[j] - keypoints_pos[i] 56 | t = np.clip( 57 | np.dot(point_vector, line_vector) / np.dot(line_vector, line_vector), 58 | 0, 1 59 | ) 60 | proj_point = keypoints_pos[i] + t * line_vector 61 | pos_err = np.linalg.norm(segment_pos[j] - proj_point) 62 | 63 | pred_quat = T.quat_slerp( 64 | keypoints_quat[i], 65 | keypoints_quat[i + 1], 66 | fraction=j/len(segment_quat) 67 | ) 68 | quat_err = ( 69 | Rotation.from_quat(pred_quat) * 70 | Rotation.from_quat(segment_quat[j]).inv() 71 | ).magnitude() 72 | 73 | errors.append(pos_err + quat_err) 74 | 75 | 
max_error = np.max(errors) 76 | return (max_error, errors) if return_list else max_error 77 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | 8 | 9 | class TwoCropTransform: 10 | """Create two crops of the same image""" 11 | def __init__(self, transform): 12 | self.transform = transform 13 | 14 | def __call__(self, x): 15 | return [self.transform(x), self.transform(x)] 16 | 17 | 18 | class AverageMeter(object): 19 | """Computes and stores the average and current value""" 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | 36 | def accuracy(output, target, topk=(1,)): 37 | """Computes the accuracy over the k top predictions for the specified values of k""" 38 | with torch.no_grad(): 39 | maxk = max(topk) 40 | batch_size = target.size(0) 41 | 42 | _, pred = output.topk(maxk, 1, True, True) 43 | pred = pred.t() 44 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 45 | 46 | res = [] 47 | for k in topk: 48 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 49 | res.append(correct_k.mul_(100.0 / batch_size)) 50 | return res 51 | 52 | 53 | def adjust_learning_rate(args, optimizer, epoch): 54 | lr = args.learning_rate 55 | if args.cosine: 56 | eta_min = lr * (args.lr_decay_rate ** 3) 57 | lr = eta_min + (lr - eta_min) * ( 58 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 59 | else: 60 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 61 | if steps > 0: 62 | lr = lr * (args.lr_decay_rate ** steps) 63 | 64 | for param_group in 
optimizer.param_groups: 65 | param_group['lr'] = lr 66 | 67 | 68 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 69 | if args.warm and epoch <= args.warm_epochs: 70 | p = (batch_id + (epoch - 1) * total_batches) / \ 71 | (args.warm_epochs * total_batches) 72 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 73 | 74 | for param_group in optimizer.param_groups: 75 | param_group['lr'] = lr 76 | 77 | 78 | def set_optimizer(opt, model): 79 | optimizer = optim.SGD(model.parameters(), 80 | lr=opt.learning_rate, 81 | momentum=opt.momentum, 82 | weight_decay=opt.weight_decay) 83 | return optimizer 84 | 85 | 86 | def save_model(model, optimizer, opt, epoch, save_file): 87 | print('==> Saving...') 88 | state = { 89 | 'opt': opt, 90 | 'model': model.state_dict(), 91 | 'optimizer': optimizer.state_dict(), 92 | 'epoch': epoch, 93 | } 94 | torch.save(state, save_file) 95 | del state 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Effective Utilization of Mixed-Quality Demonstrations in Robotic Manipulation via Segment-Level Selection and Optimization 2 | 3 | [[Paper]](https://arxiv.org/pdf/2409.19917) [[Project Page]](https://tonyfang.net/s2i/) 4 | 5 | xx picture 6 | 7 | ## 🧑🏻‍💻 Run 8 | For the representation model training stage, run the command `bash command_train.sh` to execute the data training script, which will preprocess the dataset and training the model. 9 | 10 | Here are the argument explanations in the training process: 11 | * `--dataset` : Specifies the entire dataset used for the representation model training. 12 | * `--aug_path` : The path where the results of the augmented dataset will be stored. 13 | * `--save_mode` :Indicates the format or type of the dataset. 14 | * `--size` : Specifies the size to which the images will be resized. 
15 | * `--numbers` : The index or specific identifier used for data augmentation within the dataset. 16 | 17 | For the eval stage, run the command `bash command_eval.sh` to complete the segment selection and trajectory optimization processes. 18 | 19 | Here are the argument explanations in the evaluation process: 20 | * `--train_data_folder` : The path to the augmented sample file (e.g. `./lowdim_samples.npy`, written via `--aug_path` during training) used for distance-weighted voting during the segment selection process. 21 | * `--val_data_folder` : The path to the full mixed-quality demonstration dataset (e.g. `./low_dim.hdf5`) on which segment selection and trajectory optimization are performed. 22 | * `--size` : Specifies the size to which the images will be resized. 23 | 24 | ## 🤖 Training Manipulation Policy 25 | 26 | After applying Select Segments to Imitate (S2I), the dataset can be used directly for downstream manipulation policy training as a plug-and-play solution. 27 | 28 | For simulation experiments, we use the state-based [BC-RNN](https://github.com/ARISE-Initiative/robomimic) and the [Diffusion Policy (DP)](https://github.com/real-stanford/diffusion_policy), which can be applied to both state and image data, as robot manipulation policies. For real-world experiments, we choose [DP](https://github.com/real-stanford/diffusion_policy) and [ACT](https://github.com/tonyzhaozh/act) as our image-based policies, as well as [RISE](https://github.com/rise-policy/rise) as our point-cloud-based policy. Some minor modifications have been made to the sampler and rollout functions. The modified Python files are available in [`./policy`](https://github.com/Junxix/S2I/tree/main/policy). Refer to the [documentation](policy/README.md) for more details. 29 | 30 | ## 🙏 Acknowledgement 31 | 32 | Our code is built upon: [Diffusion Policy](https://github.com/real-stanford/diffusion_policy/), [RoboMimic](https://github.com/ARISE-Initiative/robomimic), 
[SupContrast](https://github.com/HobbitLong/SupContrast), [RISE](https://github.com/rise-policy/rise) and [ACT](https://github.com/tonyzhaozh/act).
We thank all the authors for the contributions to the community. 33 | 34 | ## ✍️ Citation 35 | 36 | If you find S2I useful in your research, please consider citing the following paper: 37 | 38 | ```bibtex 39 | @article{ 40 | chen2024towards, 41 | title = {Towards Effective Utilization of Mixed-Quality Demonstrations in Robotic Manipulation via Segment-Level Selection and Optimization}, 42 | author = {Chen, Jingjing and Fang, Hongjie and Fang, Hao-Shu and Lu, Cewu}, 43 | journal = {arXiv preprint arXiv:2409.19917}, 44 | year = {2024} 45 | } 46 | ``` 47 | 48 | ## 📃 License 49 | 50 | S2I by Jingjing Chen, Hongjie Fang, Hao-Shu Fang, Cewu Lu is licensed under MIT License. 51 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class SupConLoss(nn.Module): 8 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 9 | It also supports the unsupervised contrastive loss in SimCLR""" 10 | def __init__(self, temperature=0.07, contrast_mode='all', 11 | base_temperature=0.07): 12 | super(SupConLoss, self).__init__() 13 | self.temperature = temperature 14 | self.contrast_mode = contrast_mode 15 | self.base_temperature = base_temperature 16 | 17 | def forward(self, features, labels=None, mask=None): 18 | """Compute loss for model. If both `labels` and `mask` are None, 19 | it degenerates to SimCLR unsupervised loss: 20 | https://arxiv.org/pdf/2002.05709.pdf 21 | 22 | Args: 23 | features: hidden vector of shape [bsz, n_views, ...]. 24 | labels: ground truth of shape [bsz]. 25 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 26 | has the same class as sample i. Can be asymmetric. 27 | Returns: 28 | A loss scalar. 
29 | """ 30 | device = (torch.device('cuda') 31 | if features.is_cuda 32 | else torch.device('cpu')) 33 | 34 | if len(features.shape) < 3: 35 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 36 | 'at least 3 dimensions are required') 37 | if len(features.shape) > 3: 38 | features = features.view(features.shape[0], features.shape[1], -1) 39 | 40 | batch_size = features.shape[0] 41 | if labels is not None and mask is not None: 42 | raise ValueError('Cannot define both `labels` and `mask`') 43 | elif labels is None and mask is None: 44 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 45 | elif labels is not None: 46 | labels = labels.contiguous().view(-1, 1) 47 | if labels.shape[0] != batch_size: 48 | raise ValueError('Num of labels does not match num of features') 49 | mask = torch.eq(labels, labels.T).float().to(device) 50 | else: 51 | mask = mask.float().to(device) 52 | 53 | contrast_count = features.shape[1] 54 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 55 | if self.contrast_mode == 'one': 56 | anchor_feature = features[:, 0] 57 | anchor_count = 1 58 | elif self.contrast_mode == 'all': 59 | anchor_feature = contrast_feature 60 | anchor_count = contrast_count 61 | else: 62 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 63 | 64 | # compute logits 65 | anchor_dot_contrast = torch.div( 66 | torch.matmul(anchor_feature, contrast_feature.T), 67 | self.temperature) 68 | # for numerical stability 69 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 70 | logits = anchor_dot_contrast - logits_max.detach() 71 | 72 | # tile mask 73 | mask = mask.repeat(anchor_count, contrast_count) 74 | # mask-out self-contrast cases 75 | logits_mask = torch.scatter( 76 | torch.ones_like(mask), 77 | 1, 78 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 79 | 0 80 | ) 81 | mask = mask * logits_mask 82 | 83 | # compute log_prob 84 | exp_logits = torch.exp(logits) * logits_mask 85 | 
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 86 | 87 | # compute mean of log-likelihood over positive 88 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 89 | 90 | # loss 91 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 92 | loss = loss.view(anchor_count, batch_size).mean() 93 | 94 | return loss 95 | -------------------------------------------------------------------------------- /utils/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | from tqdm import tqdm 5 | from .realworld_utils import * 6 | from .constant import * 7 | from .image_generation import TrajectoryRenderer, PictureGenerator 8 | import argparse 9 | 10 | import robomimic 11 | import robomimic.utils.obs_utils as ObsUtils 12 | import robomimic.utils.env_utils as EnvUtils 13 | import robomimic.utils.file_utils as FileUtils 14 | 15 | def process_slices(generator, trajectory_slices, color_path_slices, ind, total_images_per_slice): 16 | for i, (traj_slice, state_slice) in enumerate(zip(trajectory_slices, color_path_slices)): 17 | generator.generate_positive_picture(traj_slice, state_slice, ind, num_images=total_images_per_slice) 18 | generator.generate_negative_picture(traj_slice, state_slice, ind, num_images=total_images_per_slice) 19 | 20 | 21 | def data_augmentation_realworld(args): 22 | calib_dir = check_directory_exists(os.path.join(args.dataset, "calib")) 23 | 24 | root_dir = os.path.join(args.dataset, "train") 25 | subdirs = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))] 26 | sorted_subdirs = sorted(subdirs, key=lambda x: int(x.split('_scene_')[1].split('_')[0])) 27 | samples = {'images': [], 'end_images': [], 'labels': []} 28 | 29 | total_demos = len(args.numbers) 30 | total_images_per_demo = args.total_images // total_demos 31 | 32 | for ind in args.numbers: 33 | path = os.path.join(root_dir, 
sorted_subdirs[ind], CAMERA_NAME, 'color') 34 | renderer = TrajectoryRenderer(env=None, camera_name=None, save_mode='realworld', calib_dir=calib_dir, root_dir=path) 35 | generator = PictureGenerator(renderer, samples, save_mode=args.save_mode) 36 | 37 | file_paths, trajectory_points, gripper_command = load_demo_files(root_dir, sorted_subdirs, ind) 38 | 39 | change_indices = realworld_change_indices(gripper_command) 40 | trajectory_slices, color_path_slices = realworld_slice(trajectory_points, file_paths, change_indices) 41 | total_images_per_slice = total_images_per_demo // len(trajectory_slices) 42 | process_slices(generator, trajectory_slices, color_path_slices, ind, total_images_per_slice) 43 | 44 | np.save(args.aug_path, samples) 45 | 46 | 47 | def data_augmentation_robomimic(args): 48 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset) 49 | env_type = EnvUtils.get_env_type(env_meta=env_meta) 50 | render_image_names = DEFAULT_CAMERAS[env_type] 51 | 52 | dummy_spec = dict( 53 | obs=dict( 54 | low_dim=["robot0_eef_pos"], 55 | rgb=[], 56 | ), 57 | ) 58 | ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec) 59 | 60 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset) 61 | env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=False, render_offscreen=True) 62 | is_robosuite_env = EnvUtils.is_robosuite_env(env_meta) 63 | 64 | f = h5py.File(args.dataset, "r") 65 | demos = sorted(f["data"].keys(), key=lambda x: int(x[5:])) 66 | samples = {'images': [], 'end_images': [], 'labels': []} 67 | 68 | total_demos = len(args.numbers) 69 | total_images_per_demo = args.total_images // total_demos 70 | 71 | renderer = TrajectoryRenderer(env, render_image_names[0]) 72 | generator = PictureGenerator(renderer, samples, save_mode=args.save_mode) 73 | 74 | for ind in args.numbers: 75 | ep = demos[ind] 76 | states = f[f"data/{ep}/states"][()] 77 | trajectory_points = 
f[f"data/{ep}/obs/robot0_eef_pos"][()] 78 | actions = f[f"data/{ep}/actions"][()] 79 | 80 | initial_state = dict(states=states[0]) 81 | if is_robosuite_env: 82 | initial_state["model"] = f[f"data/{ep}"].attrs["model_file"] 83 | generator.renderer.env.reset() 84 | generator.renderer.env.reset_to(initial_state) 85 | 86 | change_indices = find_all_change_indices(actions) 87 | trajectory_slices, state_slices = slice_trajectory_and_states(trajectory_points, states, change_indices) 88 | total_images_per_slice = total_images_per_demo // len(trajectory_slices) 89 | process_slices(generator, trajectory_slices, state_slices, ind, total_images_per_slice) 90 | 91 | np.save(args.aug_path, samples) 92 | f.close() 93 | 94 | 95 | def data_augmentation(args): 96 | if args.save_mode == 'realworld': 97 | data_augmentation_realworld(args) 98 | else: 99 | data_augmentation_robomimic(args) 100 | 101 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import numpy as np 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | from utils.util import TwoCropTransform, AverageMeter 10 | from networks.resnet import SupConResNet 11 | from dataset.dataset import CustomDataset, ValDataset 12 | from utils.constant import * 13 | 14 | def parse_option(): 15 | parser = argparse.ArgumentParser('Argument for training') 16 | parser.add_argument('--model', type=str, default='resnet50') 17 | parser.add_argument("--save_mode", type=str, default='lowdim', choices=['image', 'lowdim', 'realworld'], help="choose the saving method") 18 | parser.add_argument('--train_data_folder', type=str, default='./lowdim_samples.npy', help='path to custom dataset') 19 | parser.add_argument('--val_data_folder', type=str, default='./low_dim.hdf5', help='path to custom dataset') 20 
| parser.add_argument('--size', type=int, default=128) 21 | parser.add_argument('--ckpt', type=str, default='./ckpt_epoch_2000.pth', 22 | help='path to pre-trained model') 23 | return parser.parse_args() 24 | 25 | def dist_metric(x, y): 26 | return torch.norm(x - y).item() 27 | 28 | def calculate_label(dist_list, k): 29 | top_k_weights = torch.nn.functional.softmax(torch.tensor([d[0] for d in dist_list[:k]]) * -1, dim=0) 30 | action = sum(weight * dist_list[i][1] for i, weight in enumerate(top_k_weights)) 31 | return action 32 | 33 | def clear_folders_if_not_empty(folders): 34 | for folder in folders: 35 | if os.path.exists(folder) and os.listdir(folder): 36 | shutil.rmtree(folder) 37 | os.makedirs(folder) 38 | 39 | def calculate_nearest_neighbors(query_embedding, train_dataset, train_labels, k): 40 | dist_list = [(dist_metric(torch.from_numpy(query_embedding), torch.from_numpy(train_dataset[i])), train_labels[i]) for i in range(len(train_dataset))] 41 | dist_list.sort(key=lambda tup: tup[0]) 42 | return calculate_label(dist_list, k) 43 | 44 | def set_loader(opt): 45 | normalize = transforms.Normalize(mean=IMG_MEAN, std=IMG_STD) 46 | 47 | train_transform = transforms.Compose([ 48 | transforms.RandomResizedCrop(size=opt.size, scale=(0.8, 1.0)), 49 | transforms.ToTensor(), 50 | normalize, 51 | ]) 52 | 53 | val_transform = transforms.Compose([ 54 | transforms.Resize((opt.size, opt.size)), 55 | transforms.ToTensor(), 56 | normalize, 57 | ]) 58 | 59 | train_dataset = CustomDataset(npy_file=opt.train_data_folder, transform=train_transform) 60 | val_dataset = ValDataset(hdf5_file=opt.val_data_folder, transform=val_transform, save_mode = opt.save_mode) 61 | return train_dataset, val_dataset 62 | 63 | def set_model(opt): 64 | model = SupConResNet(name=opt.model) 65 | ckpt = torch.load(opt.ckpt, map_location='cpu') 66 | state_dict = {k.replace("module.", ""): v for k, v in ckpt['model'].items()} 67 | 68 | if torch.cuda.is_available(): 69 | model = model.cuda() 70 | 
cudnn.benchmark = True 71 | model.load_state_dict(state_dict, strict=False) 72 | else: 73 | raise NotImplementedError('This code requires GPU') 74 | return model 75 | 76 | def get_embeddings(train_dataset, model): 77 | model.eval() 78 | embeddings, labels = [], [] 79 | for idx in range(len(train_dataset)): 80 | image, label = train_dataset[idx] 81 | image = image.unsqueeze(0).cuda(non_blocking=True) 82 | with torch.no_grad(): 83 | features = model.encoder(image).flatten(start_dim=1) 84 | embeddings.append(features.cpu().numpy()) 85 | labels.append(label) 86 | return np.concatenate(embeddings), np.array(labels) 87 | 88 | def classifier(val_dataset, train_dataset, train_labels, model, neighbors_num): 89 | device = next(model.parameters()).device 90 | 91 | for idx in tqdm(range(len(val_dataset))): 92 | image_data, demo_idx, small_demo_idx = val_dataset[idx] 93 | image_data = image_data.unsqueeze(0).to(device) 94 | 95 | with torch.no_grad(): 96 | val_embedding = model.encoder(image_data).cpu().numpy() 97 | 98 | label = calculate_nearest_neighbors(val_embedding, train_dataset, train_labels, neighbors_num) 99 | val_dataset.perform_optimization(idx, label=label) 100 | 101 | def main(): 102 | opt = parse_option() 103 | train_dataset, val_dataset = set_loader(opt) 104 | model = set_model(opt) 105 | train_embeddings, train_labels = get_embeddings(train_dataset, model) 106 | classifier(val_dataset, train_embeddings, train_labels, model, neighbors_num=64) 107 | 108 | if __name__ == '__main__': 109 | main() 110 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: mujoco 2 | channels: 3 | - menpo 4 | - conda-forge 5 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch3d/ 6 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 7 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 8 | - 
http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 9 | - defaults 10 | dependencies: 11 | - _libgcc_mutex=0.1=conda_forge 12 | - _openmp_mutex=4.5=2_kmp_llvm 13 | - bzip2=1.0.8=h4bc722e_7 14 | - c-ares=1.34.1=heb4867d_0 15 | - ca-certificates=2024.8.30=hbcca054_0 16 | - elfutils=0.191=h924a536_0 17 | - gettext=0.22.5=he02047a_3 18 | - gettext-tools=0.22.5=he02047a_3 19 | - glew=2.1.0=h9c3ff4c_2 20 | - glfw3=3.2.1=0 21 | - gnutls=3.8.7=h32866dd_0 22 | - icu=75.1=he02047a_0 23 | - keyutils=1.6.1=h166bdaf_0 24 | - krb5=1.21.3=h659f571_0 25 | - ld_impl_linux-64=2.40=h12ee557_0 26 | - libarchive=3.7.4=hfca40fe_0 27 | - libasprintf=0.22.5=he8f35ee_3 28 | - libasprintf-devel=0.22.5=he8f35ee_3 29 | - libcurl=8.10.1=hbbe4b11_0 30 | - libdrm=2.4.123=hb9d3cd8_0 31 | - libedit=3.1.20191231=he28a2e2_2 32 | - libev=4.33=hd590300_2 33 | - libexpat=2.6.3=h5888daf_0 34 | - libffi=3.4.4=h6a678d5_1 35 | - libgcc=14.1.0=h77fa898_1 36 | - libgcc-ng=14.1.0=h69a702a_1 37 | - libgettextpo=0.22.5=he02047a_3 38 | - libgettextpo-devel=0.22.5=he02047a_3 39 | - libglu=9.0.0=ha6d2627_1004 40 | - libiconv=1.17=hd590300_2 41 | - libidn2=2.3.7=hd590300_0 42 | - libllvm19=19.1.1=ha7bfdaf_0 43 | - libmicrohttpd=1.0.1=hbc5bc17_1 44 | - libnghttp2=1.58.0=h47da74e_1 45 | - libnsl=2.0.1=hd590300_0 46 | - libpciaccess=0.18=hd590300_0 47 | - libsqlite=3.46.1=hadc24fc_0 48 | - libssh2=1.11.0=h0841786_0 49 | - libstdcxx=14.1.0=hc0a3c3a_1 50 | - libstdcxx-ng=14.1.0=h4852527_1 51 | - libtasn1=4.19.0=h166bdaf_0 52 | - libunistring=0.9.10=h7f98852_0 53 | - libuuid=2.38.1=h0b41bf4_0 54 | - libxcb=1.17.0=h8a09558_0 55 | - libxcrypt=4.4.36=hd590300_1 56 | - libxml2=2.12.7=he7c6b58_4 57 | - libzlib=1.3.1=hb9d3cd8_2 58 | - llvm-openmp=19.1.1=h024ca30_0 59 | - lz4-c=1.9.4=hcb278e6_0 60 | - lzo=2.10=hd590300_1001 61 | - mesalib=24.2.4=h039c18d_0 62 | - ncurses=6.5=he02047a_1 63 | - nettle=3.9.1=h7ab15ed_0 64 | - openssl=3.3.2=hb9d3cd8_0 65 | - p11-kit=0.24.1=hc5aa10d_0 66 | - pip=24.2=py39h06a4308_0 67 | - 
pthread-stubs=0.4=hb9d3cd8_1002 68 | - python=3.9.18=h0755675_1_cpython 69 | - readline=8.2=h5eee18b_0 70 | - setuptools=75.1.0=py39h06a4308_0 71 | - tk=8.6.13=noxft_h4845f30_101 72 | - tzdata=2024b=h04d1e81_0 73 | - wheel=0.44.0=py39h06a4308_0 74 | - xorg-glproto=1.4.17=hb9d3cd8_1003 75 | - xorg-libx11=1.8.10=h4f16b4b_0 76 | - xorg-libxau=1.0.11=hb9d3cd8_1 77 | - xorg-libxdamage=1.1.6=hb9d3cd8_0 78 | - xorg-libxdmcp=1.1.5=hb9d3cd8_0 79 | - xorg-libxext=1.3.6=hb9d3cd8_0 80 | - xorg-libxfixes=6.0.1=hb9d3cd8_0 81 | - xorg-libxrandr=1.5.4=hb9d3cd8_0 82 | - xorg-libxrender=0.9.11=hb9d3cd8_1 83 | - xorg-libxxf86vm=1.1.5=hb9d3cd8_3 84 | - xorg-xextproto=7.3.0=hb9d3cd8_1004 85 | - xorg-xf86vidmodeproto=2.3.1=hb9d3cd8_1003 86 | - xorg-xorgproto=2024.1=hb9d3cd8_1 87 | - xz=5.4.6=h5eee18b_1 88 | - zstd=1.5.6=ha6fb4c9_0 89 | - pip: 90 | - absl-py==2.1.0 91 | - cachetools==5.5.0 92 | - certifi==2024.8.30 93 | - cffi==1.17.1 94 | - charset-normalizer==3.4.0 95 | - contourpy==1.3.0 96 | - cycler==0.12.1 97 | - cython==0.29.37 98 | - egl-probe==1.0.2 99 | - etils==1.5.2 100 | - fasteners==0.15 101 | - fonttools==4.54.1 102 | - free-mujoco-py==2.1.6 103 | - fsspec==2024.9.0 104 | - glfw==1.12.0 105 | - google-auth==2.35.0 106 | - google-auth-oauthlib==0.4.6 107 | - grpcio==1.66.2 108 | - h5py==3.12.1 109 | - idna==3.10 110 | - imageio==2.35.1 111 | - importlib-metadata==8.5.0 112 | - importlib-resources==6.4.5 113 | - kiwisolver==1.4.7 114 | - llvmlite==0.43.0 115 | - markdown==3.7 116 | - markupsafe==3.0.1 117 | - matplotlib==3.9.2 118 | - monotonic==1.6 119 | - mujoco==3.0.0 120 | - mujoco-py==2.1.2.14 121 | - numba==0.60.0 122 | - numpy==1.23.5 123 | - oauthlib==3.2.2 124 | - opencv-python==4.10.0.84 125 | - packaging==24.1 126 | - patchelf==0.17.2.1 127 | - pillow==10.4.0 128 | - protobuf==3.19.6 129 | - pyasn1==0.6.1 130 | - pyasn1-modules==0.4.1 131 | - pycparser==2.22 132 | - pyopengl==3.1.7 133 | - pyparsing==3.1.4 134 | - python-dateutil==2.9.0.post0 135 | - 
requests==2.32.3 136 | - requests-oauthlib==2.0.0 137 | - robomimic==0.2.0 138 | - robosuite==1.2.0 139 | - rsa==4.9 140 | - scipy==1.13.1 141 | - six==1.16.0 142 | - tensorboard==2.10.1 143 | - tensorboard-data-server==0.6.1 144 | - tensorboard-logger==0.1.0 145 | - tensorboard-plugin-wit==1.8.1 146 | - termcolor==2.5.0 147 | - torch==1.10.0+cu113 148 | - torchaudio==0.10.0+cu113 149 | - torchvision==0.11.1+cu113 150 | - tqdm==4.66.5 151 | - typing-extensions==4.12.2 152 | - urllib3==2.2.3 153 | - werkzeug==3.0.4 154 | - zipp==3.20.2 155 | -------------------------------------------------------------------------------- /policy/diffusion_policy/robomimic_replay_lowdim_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Dict, List 3 | import torch 4 | import numpy as np 5 | import h5py 6 | from tqdm import tqdm 7 | import copy 8 | from diffusion_policy.common.pytorch_util import dict_apply 9 | from diffusion_policy.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer 10 | from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer 11 | from diffusion_policy.model.common.rotation_transformer import RotationTransformer 12 | from diffusion_policy.common.replay_buffer import ReplayBuffer 13 | from diffusion_policy.common.sampler import ( 14 | SequenceSampler, get_val_mask, downsample_mask) 15 | from diffusion_policy.common.normalize_util import ( 16 | robomimic_abs_action_only_normalizer_from_stat, 17 | robomimic_abs_action_only_dual_arm_normalizer_from_stat, 18 | get_identity_normalizer_from_stat, 19 | array_to_stats 20 | ) 21 | 22 | class RobomimicReplayLowdimDataset(BaseLowdimDataset): 23 | def __init__(self, 24 | dataset_path: str, 25 | horizon=1, 26 | pad_before=0, 27 | pad_after=0, 28 | obs_keys: List[str]=[ 29 | 'object', 30 | 'robot0_eef_pos', 31 | 'robot0_eef_quat', 32 | 'robot0_gripper_qpos'], 33 | abs_action=False, 34 | 
rotation_rep='rotation_6d', 35 | use_legacy_normalizer=False, 36 | seed=42, 37 | val_ratio=0.0, 38 | max_train_episodes=None 39 | ): 40 | obs_keys = list(obs_keys) 41 | rotation_transformer = RotationTransformer( 42 | from_rep='axis_angle', to_rep=rotation_rep) 43 | 44 | replay_buffer = ReplayBuffer.create_empty_numpy() 45 | with h5py.File(dataset_path) as file: 46 | demos = file['data'] 47 | for i in tqdm(range(len(demos)), desc="Loading hdf5 to ReplayBuffer"): 48 | demo = demos[f'demo_{i}'] 49 | 50 | episode = _data_to_obs( 51 | marks=demo['marks'], 52 | raw_obs=demo['obs'], 53 | raw_actions=demo['actions'][:].astype(np.float32), 54 | obs_keys=obs_keys, 55 | abs_action=abs_action, 56 | rotation_transformer=rotation_transformer) 57 | replay_buffer.add_episode(episode) 58 | 59 | val_mask = get_val_mask( 60 | n_episodes=replay_buffer.n_episodes, 61 | val_ratio=val_ratio, 62 | seed=seed) 63 | train_mask = ~val_mask 64 | train_mask = downsample_mask( 65 | mask=train_mask, 66 | max_n=max_train_episodes, 67 | seed=seed) 68 | 69 | sampler = SequenceSampler( 70 | replay_buffer=replay_buffer, 71 | sequence_length=horizon, 72 | pad_before=pad_before, 73 | pad_after=pad_after, 74 | episode_mask=train_mask) 75 | 76 | self.replay_buffer = replay_buffer 77 | self.sampler = sampler 78 | self.abs_action = abs_action 79 | self.train_mask = train_mask 80 | self.horizon = horizon 81 | self.pad_before = pad_before 82 | self.pad_after = pad_after 83 | self.use_legacy_normalizer = use_legacy_normalizer 84 | 85 | def get_validation_dataset(self): 86 | val_set = copy.copy(self) 87 | val_set.sampler = SequenceSampler( 88 | replay_buffer=self.replay_buffer, 89 | sequence_length=self.horizon, 90 | pad_before=self.pad_before, 91 | pad_after=self.pad_after, 92 | episode_mask=~self.train_mask 93 | ) 94 | val_set.train_mask = ~self.train_mask 95 | return val_set 96 | 97 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 98 | normalizer = LinearNormalizer() 99 | 100 | # action 101 | stat 
= array_to_stats(self.replay_buffer['action']) 102 | if self.abs_action: 103 | if stat['mean'].shape[-1] > 10: 104 | # dual arm 105 | this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat) 106 | else: 107 | this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat) 108 | 109 | if self.use_legacy_normalizer: 110 | this_normalizer = normalizer_from_stat(stat) 111 | else: 112 | # already normalized 113 | this_normalizer = get_identity_normalizer_from_stat(stat) 114 | normalizer['action'] = this_normalizer 115 | 116 | # aggregate obs stats 117 | obs_stat = array_to_stats(self.replay_buffer['obs']) 118 | 119 | 120 | normalizer['obs'] = normalizer_from_stat(obs_stat) 121 | return normalizer 122 | 123 | def get_all_actions(self) -> torch.Tensor: 124 | return torch.from_numpy(self.replay_buffer['action']) 125 | 126 | def __len__(self): 127 | return len(self.sampler) 128 | 129 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 130 | data = self.sampler.sample_sequence(idx) 131 | torch_data = dict_apply(data, torch.from_numpy) 132 | return torch_data 133 | 134 | def normalizer_from_stat(stat): 135 | max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max()) 136 | scale = np.full_like(stat['max'], fill_value=1/max_abs) 137 | offset = np.zeros_like(stat['max']) 138 | return SingleFieldLinearNormalizer.create_manual( 139 | scale=scale, 140 | offset=offset, 141 | input_stats_dict=stat 142 | ) 143 | 144 | def _data_to_obs(marks, raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer): 145 | obs = np.concatenate([ 146 | raw_obs[key] for key in obs_keys 147 | ], axis=-1).astype(np.float32) 148 | 149 | if abs_action: 150 | is_dual_arm = False 151 | if raw_actions.shape[-1] == 14: 152 | # dual arm 153 | raw_actions = raw_actions.reshape(-1,2,7) 154 | is_dual_arm = True 155 | 156 | pos = raw_actions[...,:3] 157 | rot = raw_actions[...,3:6] 158 | gripper = raw_actions[...,6:] 159 | rot = rotation_transformer.forward(rot) 
160 | raw_actions = np.concatenate([ 161 | pos, rot, gripper 162 | ], axis=-1).astype(np.float32) 163 | 164 | if is_dual_arm: 165 | raw_actions = raw_actions.reshape(-1,20) 166 | 167 | marks = np.array(marks) 168 | data_length = obs.shape[0] 169 | new_marks = np.zeros(data_length, dtype=int) 170 | valid_indices = marks[marks < data_length] 171 | new_marks[valid_indices] = 1 172 | 173 | data = { 174 | 'obs': obs, 175 | 'action': raw_actions, 176 | 'marks': new_marks 177 | } 178 | return data 179 | 180 | -------------------------------------------------------------------------------- /dataset/trajectory_optimization.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | from typing import List, Tuple, Dict, Optional 5 | from dataclasses import dataclass 6 | 7 | @dataclass 8 | class TrajectoryPoint: 9 | """Represents a point in the trajectory with its coordinates.""" 10 | coordinates: np.ndarray 11 | index: int 12 | 13 | class GeometryCalculator: 14 | 15 | @staticmethod 16 | def calculate_vector(point1: np.ndarray, point2: np.ndarray) -> np.ndarray: 17 | return point2 - point1 18 | 19 | @staticmethod 20 | def calculate_angle(vector1: np.ndarray, vector2: np.ndarray) -> float: 21 | dot_product = np.dot(vector1, vector2) 22 | norm_vector1 = np.linalg.norm(vector1) 23 | norm_vector2 = np.linalg.norm(vector2) 24 | cos_theta = dot_product / (norm_vector1 * norm_vector2) 25 | return np.arccos(np.clip(cos_theta, -1.0, 1.0)) 26 | 27 | class CoordinateTransformer: 28 | """Handles coordinate transformations between different reference frames.""" 29 | def __init__(self, env, real_world: bool = False, calib_dir: Optional[str] = None): 30 | self.env = env 31 | self.real_world = real_world 32 | self.calib_dir = calib_dir 33 | 34 | def transform_points(self, trajectory_points: np.ndarray) -> np.ndarray: 35 | if self.real_world: 36 | transformed_points = self._transform_real_world(trajectory_points) 37 | else: 38 | 
transformed_points = self._transform_simulation(trajectory_points) 39 | 40 | return transformed_points[:, :2] # Return only x,y coordinates 41 | 42 | def _transform_real_world(self, points: np.ndarray) -> np.ndarray: 43 | from utils.realworld_utils import translate_points 44 | return translate_points(self.calib_dir, points) 45 | 46 | def _transform_simulation(self, points: np.ndarray) -> np.ndarray: 47 | extrinsic_matrix = self.env.get_camera_extrinsic_matrix('agentview') 48 | camera_position = extrinsic_matrix[:3, 3] 49 | camera_rotation = extrinsic_matrix[:3, :3] 50 | return np.dot(points - camera_position, camera_rotation) 51 | 52 | 53 | class TrajectoryOptimizer: 54 | """Main class for optimizing robot trajectories.""" 55 | def __init__(self, env, real_world: bool = False, calib_dir: Optional[str] = None): 56 | self.geometry = GeometryCalculator() 57 | self.transformer = CoordinateTransformer(env, real_world, calib_dir) 58 | 59 | def optimize_trajectory(self, demo: Dict, demo_idx: int, small_demo_idx: int, 60 | three_dimension: bool = False) -> List[int]: 61 | frame_start = demo['frame_start'] 62 | waypoints = self._calculate_waypoints_dp(demo, three_dimension) 63 | return [mark + frame_start for mark in waypoints] 64 | 65 | def _calculate_waypoints_dp(self, demo: Dict, three_dimension: bool) -> List[int]: 66 | err_threshold = 0.005 67 | actions = demo['actions'] 68 | gt_states = demo['gt_states'] 69 | num_frames = len(actions) 70 | 71 | dp_table = self._initialize_dp_table(num_frames) 72 | 73 | min_error = self._compute_trajectory_errors(actions, gt_states, list(range(1, num_frames))) 74 | if err_threshold < min_error: 75 | return list(range(1, num_frames)) 76 | 77 | return self._fill_dp_table(dp_table, actions, gt_states, num_frames, err_threshold) 78 | 79 | def _initialize_dp_table(self, size: int) -> Dict: 80 | dp_table = {i: (0, []) for i in range(size)} 81 | dp_table[1] = (1, [1]) 82 | return dp_table 83 | 84 | def _compute_trajectory_errors(self, 
actions: np.ndarray, gt_states: np.ndarray, 85 | waypoints: List[int]) -> float: 86 | from utils.trajectory_utils import compute_errors 87 | return compute_errors(actions=actions, gt_states=gt_states, waypoints=waypoints) 88 | 89 | def _fill_dp_table(self, dp_table: Dict, actions: np.ndarray, gt_states: np.ndarray, 90 | num_frames: int, err_threshold: float) -> List[int]: 91 | initial_waypoints = [0, num_frames - 1] 92 | 93 | for i in range(1, num_frames): 94 | min_waypoints_required = float("inf") 95 | best_waypoints = [] 96 | 97 | for k in range(1, i): 98 | waypoints = [j - k for j in initial_waypoints if k <= j < i] + [i - k] 99 | total_err = self._compute_trajectory_errors( 100 | actions[k:i + 1], gt_states[k:i + 1], waypoints 101 | ) 102 | 103 | if total_err < err_threshold: 104 | prev_count, prev_waypoints = dp_table[k - 1] 105 | total_count = 1 + prev_count 106 | 107 | if total_count < min_waypoints_required: 108 | min_waypoints_required = total_count 109 | best_waypoints = prev_waypoints + [i] 110 | 111 | dp_table[i] = (min_waypoints_required, best_waypoints) 112 | 113 | _, waypoints = dp_table[num_frames - 1] 114 | waypoints.extend(initial_waypoints) 115 | return sorted(list(set(waypoints))) 116 | 117 | def _calculate_geometric_waypoints(self, demo: Dict, three_dimension: bool) -> List[int]: 118 | points = np.array(demo['trajectory_points']) 119 | trajectory_points = points if three_dimension else self.transformer.transform_points(points) 120 | 121 | tolerance = self._calculate_tolerance(trajectory_points) 122 | 123 | return self._select_waypoints(trajectory_points, tolerance) 124 | 125 | def _calculate_tolerance(self, points: np.ndarray) -> float: 126 | return np.max(np.linalg.norm(points[1:] - points[:-1], axis=1)) * 2 127 | 128 | def _select_waypoints(self, points: np.ndarray, tolerance: float) -> List[int]: 129 | selected_indices = [0] 130 | current_idx = 0 131 | 132 | while current_idx < len(points) - 1: 133 | if np.array_equal(points[current_idx], 
points[-1]): 134 | break 135 | 136 | next_idx = self._find_next_waypoint( 137 | points, current_idx, selected_indices, tolerance 138 | ) 139 | 140 | if next_idx is None: 141 | break 142 | 143 | selected_indices.append(next_idx) 144 | current_idx = next_idx 145 | 146 | return selected_indices 147 | 148 | def _find_next_waypoint(self, points: np.ndarray, current_idx: int, 149 | selected_indices: List[int], tolerance: float) -> Optional[int]: 150 | candidates = [] 151 | 152 | for i in range(len(points)): 153 | if i in selected_indices: 154 | continue 155 | 156 | distance = np.linalg.norm(points[current_idx] - points[i]) 157 | if distance > tolerance: 158 | continue 159 | 160 | vector_to_candidate = self.geometry.calculate_vector(points[current_idx], points[i]) 161 | vector_to_goal = self.geometry.calculate_vector(points[current_idx], points[-1]) 162 | angle = self.geometry.calculate_angle(vector_to_candidate, vector_to_goal) 163 | 164 | if angle < np.radians(DELTA_THETA): 165 | candidates.append((angle, distance, i)) 166 | 167 | if not candidates: 168 | return None 169 | 170 | candidates.sort(key=lambda x: (x[0], x[1])) 171 | return candidates[0][2] -------------------------------------------------------------------------------- /utils/realworld_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import numpy as np 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | from .constant import * 7 | 8 | 9 | def load_image_files(sub_path): 10 | file_paths = [] 11 | if os.path.exists(sub_path): 12 | npy_files = sorted([f for f in os.listdir(sub_path) if f.endswith('.png')]) 13 | for file_name in npy_files: 14 | file_paths.append(file_name) 15 | else: 16 | print(f"The directory {sub_path} does not exist.") 17 | return file_paths 18 | 19 | def load_trajectory_points(sub_path, file_paths): 20 | trajectory_points = [] 21 | if os.path.exists(sub_path): 22 | for file_name in file_paths: 23 | 
base_name = os.path.splitext(file_name)[0] 24 | file_path = os.path.join(sub_path, base_name + '.npy') 25 | data = np.load(file_path) 26 | first_three_numbers = data[:3] 27 | trajectory_points.append(first_three_numbers) 28 | else: 29 | print(f"The directory {sub_path} does not exist.") 30 | return np.array(trajectory_points) 31 | 32 | def load_gripper_command(sub_path, file_paths): 33 | gripper_command = [] 34 | if os.path.exists(sub_path): 35 | for file_name in file_paths: 36 | base_name = os.path.splitext(file_name)[0] 37 | file_path = os.path.join(sub_path, base_name + '.npy') 38 | data = np.load(file_path)[0] 39 | gripper_command.append(data) 40 | else: 41 | print(f"The directory {sub_path} does not exist.") 42 | return gripper_command 43 | 44 | def realworld_change_indices(gripper_command): 45 | differences = np.diff(gripper_command) 46 | change_indices = [0] 47 | for i, diff in enumerate(differences): 48 | if diff != 0: 49 | if not change_indices: 50 | change_indices.append(i + 1) 51 | else: 52 | current_diff = i + 1 - change_indices[-1] 53 | if current_diff > 5: 54 | change_indices.append(i + 1) 55 | return change_indices 56 | 57 | def realworld_slice(trajectory_points, file_paths, change_indices): 58 | trajectory_slices, state_slices = [], [] 59 | start_idx = 0 60 | for idx in change_indices: 61 | if idx - start_idx >= 10: 62 | trajectory_slices.append(trajectory_points[start_idx:idx]) 63 | state_slices.append(file_paths[start_idx:idx]) 64 | start_idx = idx 65 | if len(trajectory_points) - start_idx > 15: 66 | trajectory_slices.append(trajectory_points[start_idx:]) 67 | state_slices.append(file_paths[start_idx:]) 68 | return trajectory_slices, state_slices 69 | 70 | def get_save_mode_factor(save_mode): 71 | if save_mode == 'lowdim': 72 | return 0 73 | elif save_mode == 'image': 74 | return 0.1 75 | elif save_mode == 'realworld': 76 | return 0.1 77 | else: 78 | raise ValueError(f"Unknown save_mode: {save_mode}") 79 | 80 | def apply_image_filter(image, 
factor): 81 | image_array = np.array(image) 82 | white_image = np.ones_like(image_array) * 255 83 | new_image_array = (image_array * factor + white_image * (1 - factor)).astype(np.uint8) 84 | return Image.fromarray(new_image_array) 85 | 86 | 87 | def load_calibration_data(calib_root_dir): 88 | tcp_file = os.path.join(calib_root_dir, 'tcp.npy') 89 | extrinsics_file = os.path.join(calib_root_dir, 'extrinsics.npy') 90 | intrinsics_file = os.path.join(calib_root_dir, 'intrinsics.npy') 91 | 92 | tcp = np.load(tcp_file) 93 | extrinsics = np.load(extrinsics_file, allow_pickle=True).item() 94 | intrinsics = np.load(intrinsics_file, allow_pickle=True).item() 95 | 96 | return tcp, extrinsics, intrinsics 97 | 98 | def quaternion_to_rotation_matrix(quaternion): 99 | qw, qx, qy, qz = quaternion 100 | R = np.array([ 101 | [1 - 2*qy**2 - 2*qz**2, 2*qx*qy - 2*qz*qw, 2*qx*qz + 2*qy*qw], 102 | [2*qx*qy + 2*qz*qw, 1 - 2*qx**2 - 2*qz**2, 2*qy*qz - 2*qx*qw], 103 | [2*qx*qz - 2*qy*qw, 2*qy*qz + 2*qx*qw, 1 - 2*qx**2 - 2*qy**2] 104 | ]) 105 | return R 106 | 107 | def create_transformation_matrix(position, quaternion): 108 | R = quaternion_to_rotation_matrix(quaternion) 109 | T = np.eye(4) 110 | T[:3, :3] = R 111 | T[:3, 3] = position 112 | return T 113 | 114 | def compute_extrinsic_matrix(extrinsics, M_cam0433_to_end, M_end_to_base): 115 | M_cam0433_to_A = extrinsics['043322070878'][0] 116 | M_cam7506_to_A = extrinsics['750612070851'][0] 117 | 118 | M_cam7506_to_base = M_cam7506_to_A @ np.linalg.inv(M_cam0433_to_A) @ M_cam0433_to_end @ M_end_to_base 119 | return M_cam7506_to_base 120 | 121 | def convert_to_pixel_coordinates(trajectory_points, extrinsic_matrix, camera_matrix): 122 | translated_points = [] 123 | for point in trajectory_points: 124 | object_point_world = np.append(point, 1).reshape(-1, 1) 125 | object_point_camera = extrinsic_matrix @ object_point_world 126 | object_point_pixel = camera_matrix @ object_point_camera 127 | object_point_pixel /= object_point_pixel[2] 128 | 
pixel_point = np.array([int(object_point_pixel[0]), int(object_point_pixel[1])]) 129 | translated_points.append(pixel_point) 130 | return np.array(translated_points) 131 | 132 | def translate_points(calib_root_dir, trajectory_points): 133 | tcp, extrinsics, intrinsics = load_calibration_data(calib_root_dir) 134 | 135 | position = tcp[:3] 136 | quaternion = tcp[3:] 137 | M_end_to_base = create_transformation_matrix(position, quaternion) 138 | 139 | M_cam0433_to_end = np.array([[0, -1, 0, 0], 140 | [1, 0, 0, 0.077], 141 | [0, 0, 1, 0.2665], 142 | [0, 0, 0, 1]]) 143 | 144 | extrinsic_matrix = compute_extrinsic_matrix(extrinsics, M_cam0433_to_end, M_end_to_base) 145 | camera_matrix = intrinsics['750612070851'] 146 | 147 | return convert_to_pixel_coordinates(trajectory_points, extrinsic_matrix, camera_matrix) 148 | 149 | def plot(transformed_points, image): 150 | plt.clf() 151 | projected_points = transformed_points[:, :2] 152 | plt.plot(projected_points[:, 0], -projected_points[:, 1], color='red', linewidth=5) 153 | plt.axis('off') 154 | plt.xlim(-0.45, 0.45) 155 | plt.ylim(-0.5, 0.5) 156 | plt.gcf().set_size_inches(480/96, 480/96) 157 | plt.tight_layout() 158 | 159 | buf = io.BytesIO() 160 | plt.savefig(buf, format='png', transparent=True, dpi=96) 161 | buf.seek(0) 162 | 163 | image1 = Image.open(buf) 164 | image.paste(image1, (0, 0), image1) 165 | buf.close() 166 | return image 167 | 168 | def check_directory_exists(directory): 169 | if os.path.exists(directory) and os.path.isdir(directory): 170 | sub_dirs = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))] 171 | if sub_dirs: 172 | return os.path.join(directory, sub_dirs[0]) 173 | else: 174 | print(f"No sub-directories found in {directory}.") 175 | else: 176 | print(f"The directory {directory} does not exist.") 177 | return None 178 | 179 | def load_demo_files(root_dir, subdirs, ind): 180 | color_path = os.path.join(root_dir, subdirs[ind], CAMERA_NAME, 'color') 181 | tcp_path = 
os.path.join(root_dir, subdirs[ind], CAMERA_NAME, 'tcp') 182 | gripper_command_path = os.path.join(root_dir, subdirs[ind], CAMERA_NAME, 'gripper_command') 183 | 184 | file_paths = load_image_files(color_path) 185 | trajectory_points = load_trajectory_points(tcp_path, file_paths) 186 | gripper_command = load_gripper_command(gripper_command_path, file_paths) 187 | 188 | return file_paths, trajectory_points, gripper_command 189 | -------------------------------------------------------------------------------- /policy/robomimic/rollout.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | import os 6 | import json 7 | import torch 8 | import time 9 | import psutil 10 | import sys 11 | import traceback 12 | import argparse 13 | import numpy as np 14 | from collections import OrderedDict 15 | from torch.utils.data import DataLoader 16 | import logging 17 | import random 18 | 19 | import robomimic.utils.train_utils as TrainUtils 20 | import robomimic.utils.torch_utils as TorchUtils 21 | import robomimic.utils.obs_utils as ObsUtils 22 | import robomimic.utils.env_utils as EnvUtils 23 | import robomimic.utils.file_utils as FileUtils 24 | from robomimic.config import config_factory 25 | from robomimic.algo import algo_factory, RolloutPolicy 26 | from robomimic.utils.log_utils import PrintLogger, DataLogger 27 | 28 | 29 | def set_seed(seed): 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | if torch.cuda.is_available(): 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | class LoggerWriter: 37 | def __init__(self, level): 38 | self.level = level 39 | 40 | def write(self, message): 41 | if message != '\n': 42 | self.level(message) 43 | 44 | def flush(self): 45 | pass 46 | 47 | def setup_logging(log_file_path): 48 | with open(log_file_path, 'w'): 49 | pass 50 | 51 | logging.basicConfig( 52 | level=logging.DEBUG, 53 | format='%(asctime)s - %(levelname)s - %(message)s', 54 | 
handlers=[ 55 | logging.FileHandler(log_file_path), 56 | logging.StreamHandler(sys.stdout) 57 | ] 58 | ) 59 | 60 | sys.stdout = LoggerWriter(logging.info) 61 | 62 | 63 | def rollout_from_checkpoint(config, checkpoint_path, device, seed, video_dir, epoch): 64 | 65 | print(f"\n============= Performing Rollout for Seed {seed} and Epoch {epoch} =============") 66 | 67 | set_seed(seed) 68 | 69 | ObsUtils.initialize_obs_utils_with_config(config) 70 | checkpoint = torch.load(checkpoint_path, map_location=device) 71 | dataset_path = os.path.expanduser(config.train.data) 72 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=config.train.data) 73 | shape_meta = FileUtils.get_shape_metadata_from_dataset( 74 | dataset_path=config.train.data, 75 | all_obs_keys=config.all_obs_keys, 76 | verbose=True 77 | ) 78 | 79 | if config.experiment.env is not None: 80 | env_meta["env_name"] = config.experiment.env 81 | print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30) 82 | 83 | envs = OrderedDict() 84 | env_names = [env_meta["env_name"]] 85 | if config.experiment.additional_envs is not None: 86 | env_names.extend(config.experiment.additional_envs) 87 | 88 | for env_name in env_names: 89 | env = EnvUtils.create_env_from_metadata( 90 | env_meta=env_meta, 91 | env_name=env_name, 92 | render=False, 93 | render_offscreen=config.experiment.render_video, 94 | use_image_obs=shape_meta["use_images"], 95 | ) 96 | envs[env.name] = env 97 | print(envs[env.name]) 98 | 99 | model = algo_factory( 100 | algo_name=config.algo_name, 101 | config=config, 102 | obs_key_shapes=shape_meta["all_shapes"], 103 | ac_dim=shape_meta["ac_dim"], 104 | device=device, 105 | ) 106 | 107 | model.load_state_dict(checkpoint["model"]) 108 | 109 | obs_normalization_stats = None 110 | if config.train.hdf5_normalize_obs: 111 | trainset, _ = TrainUtils.load_data_for_training(config, obs_keys=shape_meta["all_obs_keys"]) 112 | obs_normalization_stats = 
trainset.get_obs_normalization_stats() 113 | 114 | rollout_model = RolloutPolicy(model, obs_normalization_stats=obs_normalization_stats) 115 | 116 | num_episodes = config.experiment.rollout.n 117 | all_rollout_logs, video_paths = TrainUtils.rollout_with_stats( 118 | policy=rollout_model, 119 | envs=envs, 120 | horizon=config.experiment.rollout.horizon, 121 | use_goals=config.use_goals, 122 | num_episodes=num_episodes, 123 | render=False, 124 | video_dir=video_dir, 125 | epoch=epoch, 126 | video_skip=config.experiment.get("video_skip", 5), 127 | terminate_on_success=config.experiment.rollout.terminate_on_success, 128 | ) 129 | 130 | for env_name in all_rollout_logs: 131 | rollout_logs = all_rollout_logs[env_name] 132 | print("\nRollouts results for env {}:".format(env_name)) 133 | print(json.dumps(rollout_logs, sort_keys=True, indent=4)) 134 | 135 | process = psutil.Process(os.getpid()) 136 | mem_usage = int(process.memory_info().rss / 1000000) 137 | print("\nMemory Usage: {} MB\n".format(mem_usage)) 138 | 139 | 140 | def extract_epoch_from_filename(filename): 141 | 142 | try: 143 | base_name = os.path.basename(filename) 144 | epoch_str = base_name.split("model_epoch_")[1].split('_')[0].split(".pth")[0] 145 | return int(epoch_str) 146 | except (IndexError, ValueError): 147 | pass 148 | return None 149 | 150 | 151 | 152 | def main(args): 153 | if args.config is not None: 154 | ext_cfg = json.load(open(args.config, 'r')) 155 | config = config_factory(ext_cfg["algo_name"]) 156 | with config.values_unlocked(): 157 | config.update(ext_cfg) 158 | else: 159 | config = config_factory(args.algo) 160 | 161 | if args.dataset is not None: 162 | config.train.data = args.dataset 163 | 164 | if args.name is not None: 165 | config.experiment.name = args.name 166 | 167 | device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda) 168 | 169 | config.lock() 170 | 171 | # log_file_path = os.path.join(args.checkpoint_dir, "eval-log.txt") 172 | # setup_logging(log_file_path) 
173 | 174 | checkpoint_dir = args.checkpoint_dir 175 | checkpoints = [] 176 | 177 | for root, _, files in os.walk(args.checkpoint_dir): 178 | for file in files: 179 | if file.endswith(".pth"): 180 | epoch = extract_epoch_from_filename(file) 181 | if epoch is None: 182 | print(f"Failed to extract epoch from filename: {file}") 183 | continue 184 | checkpoint_path = os.path.join(root, file) 185 | checkpoints.append((epoch, checkpoint_path)) 186 | 187 | checkpoints.sort(key=lambda x: x[0]) 188 | 189 | for epoch, checkpoint_path in checkpoints: 190 | try: 191 | seed = int(checkpoint_path.split('seed')[-1].split('/')[0]) 192 | except (ValueError, IndexError): 193 | print(f"Failed to extract seed from path: {checkpoint_path}") 194 | continue 195 | 196 | video_dir = os.path.join(os.path.dirname(checkpoint_path), "videos") 197 | os.makedirs(video_dir, exist_ok=True) 198 | 199 | print(f"Testing checkpoint: {checkpoint_path} with seed: {seed}, saving videos to: {video_dir}") 200 | try: 201 | rollout_from_checkpoint(config, checkpoint_path, device=device, seed=seed, video_dir=video_dir, epoch=epoch) 202 | except Exception as e: 203 | print(f"Rollout failed for {checkpoint_path} with error:\n{e}\n\n{traceback.format_exc()}") 204 | 205 | if __name__ == "__main__": 206 | parser = argparse.ArgumentParser() 207 | 208 | parser.add_argument("--config", type=str, help="Path to the config JSON file") 209 | parser.add_argument("--algo", type=str, help="Algorithm name") 210 | parser.add_argument("--dataset", type=str, help="Dataset path") 211 | parser.add_argument("--name", type=str, help="Experiment name") 212 | parser.add_argument("--checkpoint_dir", type=str, required=True, help="Directory containing checkpoint models") 213 | 214 | args = parser.parse_args() 215 | main(args) 216 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
print_function 2 | import os 3 | import sys 4 | import argparse 5 | import time 6 | import math 7 | 8 | import tensorboard_logger as tb_logger 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from torchvision import transforms 12 | 13 | from utils.util import * 14 | from utils.constant import * 15 | from utils.data_augmentation import data_augmentation 16 | from networks.resnet import SupConResNet 17 | from losses import SupConLoss 18 | from dataset.dataset import CustomDataset 19 | 20 | try: 21 | import apex 22 | from apex import amp, optimizers 23 | except ImportError: 24 | pass 25 | 26 | def parse_option(): 27 | """ Parse command-line arguments """ 28 | parser = argparse.ArgumentParser('argument for training') 29 | 30 | # Training configurations 31 | parser.add_argument('--print_freq', type=int, default=10, help='Print frequency') 32 | parser.add_argument('--save_freq', type=int, default=50, help='Save model frequency') 33 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size') 34 | parser.add_argument('--num_workers', type=int, default=16, help='Number of data loading workers') 35 | parser.add_argument('--epochs', type=int, default=3000, help='Total training epochs') 36 | 37 | # Optimization configurations 38 | parser.add_argument('--learning_rate', type=float, default=0.05, help='Learning rate') 39 | parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900', help='Epochs where learning rate decays') 40 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='Learning rate decay rate') 41 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay') 42 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum') 43 | 44 | # Model and dataset settings 45 | parser.add_argument('--model', type=str, default='resnet50', help='Model architecture') 46 | parser.add_argument("--dataset", type=str, default='./low_dim.hdf5', help="path to hdf5 dataset") 47 | 
parser.add_argument('--aug_path', type=str, default=None, help='Path to custom dataset') 48 | parser.add_argument("--save_mode", type=str, default='lowdim', choices=['image', 'lowdim', 'realworld'], help="choose the saving method") 49 | parser.add_argument('--size', type=int, default=128, help='Image size for RandomResizedCrop') 50 | 51 | # Data augmentation 52 | parser.add_argument("--total_images", type=int, default=100, help="total number of images to generate") 53 | parser.add_argument("--numbers", type=int, nargs='+', default=[0, 1, 2], help="list of numbers for processing") 54 | 55 | # Method and loss function configurations 56 | parser.add_argument('--method', type=str, default='SupCon', choices=['SupCon', 'SimCLR'], help='Contrastive learning method') 57 | parser.add_argument('--temp', type=float, default=0.01, help='Temperature for loss function') 58 | 59 | # Paths for saving model and tensorboard logs 60 | parser.add_argument('--model_path', type=str, default='./lowdim/models', help='Path to save model checkpoints') 61 | parser.add_argument('--tb_path', type=str, default='./lowdim/tensorboard', help='Path for tensorboard logs') 62 | 63 | # Other settings 64 | parser.add_argument('--cosine', action='store_true', help='Use cosine annealing learning rate schedule') 65 | parser.add_argument('--syncBN', action='store_true', help='Use synchronized Batch Normalization') 66 | parser.add_argument('--warm', action='store_true', help='Use warm-up for large batch training') 67 | 68 | opt = parser.parse_args() 69 | opt.lr_decay_epochs = list(map(int, opt.lr_decay_epochs.split(','))) 70 | 71 | opt.model_name = f'{opt.method}_{opt.model}_lr_{opt.learning_rate}_decay_{opt.weight_decay}_bsz_{opt.batch_size}_temp_{opt.temp}_imgsize_{opt.size}' 72 | if opt.cosine: 73 | opt.model_name += '_cosine' 74 | if opt.batch_size > 256: 75 | opt.warm = True 76 | if opt.warm: 77 | opt.model_name += '_warm' 78 | opt.warmup_from = 0.01 79 | opt.warm_epochs = 10 80 | eta_min = 
opt.learning_rate * (opt.lr_decay_rate ** 3) if opt.cosine else opt.learning_rate 81 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 82 | 83 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) 84 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 85 | os.makedirs(opt.tb_folder, exist_ok=True) 86 | os.makedirs(opt.save_folder, exist_ok=True) 87 | 88 | return opt 89 | 90 | 91 | def set_loader(opt): 92 | """ Data loader for the training dataset """ 93 | normalize = transforms.Normalize(mean=IMG_MEAN, std=IMG_STD) 94 | 95 | train_transform = transforms.Compose([ 96 | transforms.RandomResizedCrop(size=opt.size, scale=(0.8, 1.)), 97 | transforms.ToTensor(), 98 | normalize, 99 | ]) 100 | 101 | 102 | train_dataset = CustomDataset(npy_file=opt.aug_path, transform=TwoCropTransform(train_transform)) 103 | 104 | train_loader = torch.utils.data.DataLoader( 105 | train_dataset, batch_size=opt.batch_size, shuffle=True, 106 | num_workers=opt.num_workers, pin_memory=True) 107 | 108 | return train_loader 109 | 110 | def set_model(opt): 111 | """ Initialize model and loss function """ 112 | model = SupConResNet(name=opt.model) 113 | criterion = SupConLoss(temperature=opt.temp) 114 | 115 | # Optional synchronized Batch Normalization 116 | if opt.syncBN: 117 | model = apex.parallel.convert_syncbn_model(model) 118 | 119 | if torch.cuda.is_available(): 120 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.device_count() > 1 else model.cuda() 121 | criterion = criterion.cuda() 122 | cudnn.benchmark = True 123 | 124 | return model, criterion 125 | 126 | def train(train_loader, model, criterion, optimizer, epoch, opt): 127 | """ One epoch training loop """ 128 | model.train() 129 | 130 | batch_time = AverageMeter() 131 | data_time = AverageMeter() 132 | losses = AverageMeter() 133 | 134 | end = time.time() 135 | for idx, (images, labels) in enumerate(train_loader): 136 | 
data_time.update(time.time() - end) 137 | 138 | images = torch.cat([images[0], images[1]], dim=0) 139 | images, labels = images.cuda(), labels.cuda(non_blocking=True) 140 | 141 | bsz = labels.size(0) 142 | 143 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 144 | 145 | features = model(images) 146 | f1, f2 = torch.split(features, [bsz, bsz], dim=0) 147 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 148 | 149 | loss = criterion(features, labels) if opt.method == 'SupCon' else criterion(features) 150 | losses.update(loss.item(), bsz) 151 | 152 | optimizer.zero_grad() 153 | loss.backward() 154 | optimizer.step() 155 | 156 | batch_time.update(time.time() - end) 157 | end = time.time() 158 | 159 | if (idx + 1) % opt.print_freq == 0: 160 | print(f'Epoch: [{epoch}][{idx + 1}/{len(train_loader)}]\t' 161 | f'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 162 | f'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t' 163 | f'Loss {losses.val:.3f} ({losses.avg:.3f})') 164 | 165 | return losses.avg 166 | 167 | def main(): 168 | """ Main function to train the model """ 169 | opt = parse_option() 170 | data_augmentation(opt) 171 | 172 | train_loader = set_loader(opt) 173 | model, criterion = set_model(opt) 174 | optimizer = set_optimizer(opt, model) 175 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2) 176 | 177 | for epoch in range(1, opt.epochs + 1): 178 | adjust_learning_rate(opt, optimizer, epoch) 179 | 180 | start_time = time.time() 181 | loss = train(train_loader, model, criterion, optimizer, epoch, opt) 182 | print(f'Epoch {epoch}, Total Time {time.time() - start_time:.2f}, Loss {loss:.4f}, Learning Rate {optimizer.param_groups[0]["lr"]}') 183 | 184 | logger.log_value('loss', loss, epoch) 185 | logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch) 186 | 187 | if epoch % opt.save_freq == 0: 188 | save_file = os.path.join(opt.save_folder, f'ckpt_epoch_{epoch}.pth') 189 | save_model(model, 
optimizer, opt, epoch, save_file) 190 | 191 | save_file = os.path.join(opt.save_folder, 'last.pth') 192 | save_model(model, optimizer, opt, opt.epochs, save_file) 193 | 194 | if __name__ == '__main__': 195 | main() 196 | -------------------------------------------------------------------------------- /policy/diffusion_policy/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | import numba 4 | from diffusion_policy.common.replay_buffer import ReplayBuffer 5 | 6 | 7 | @numba.jit(nopython=True) 8 | def create_indices( 9 | episode_ends: np.ndarray, 10 | sequence_length: int, 11 | episode_mask: np.ndarray, 12 | marks: np.ndarray, 13 | pad_before: int = 0, 14 | pad_after: int = 0, 15 | debug: bool = True, 16 | 17 | ) -> np.ndarray: 18 | 19 | episode_mask.shape == episode_ends.shape 20 | pad_before = min(max(pad_before, 0), sequence_length - 1) 21 | pad_after = min(max(pad_after, 0), sequence_length - 1) 22 | 23 | indices = list() 24 | for i in range(len(episode_ends)): 25 | if not episode_mask[i]: 26 | # skip episode 27 | continue 28 | start_idx = 0 29 | if i > 0: 30 | start_idx = episode_ends[i - 1] 31 | end_idx = episode_ends[i] 32 | episode_length = end_idx - start_idx 33 | 34 | min_start = -pad_before 35 | max_start = episode_length - sequence_length + pad_after 36 | 37 | # range stops one idx before end 38 | for idx in range(min_start, max_start + 1): 39 | buffer_start_idx = max(idx, 0) + start_idx 40 | buffer_end_idx = min(idx + sequence_length, episode_length) + start_idx 41 | start_offset = buffer_start_idx - (idx + start_idx) 42 | end_offset = (idx + sequence_length + start_idx) - buffer_end_idx 43 | sample_start_idx = 0 + start_offset 44 | sample_end_idx = sequence_length - end_offset 45 | if debug: 46 | assert start_offset >= 0 47 | assert end_offset >= 0 48 | assert (sample_end_idx - sample_start_idx) == ( 49 | buffer_end_idx - buffer_start_idx 50 | ) 51 | 52 | 
# TODO 53 | sample = [] 54 | current_idx = buffer_start_idx 55 | for _ in range(3): 56 | if current_idx < buffer_end_idx: 57 | sample.append(current_idx) 58 | current_idx += 1 59 | else: 60 | break 61 | 62 | while len(sample) < (buffer_end_idx - buffer_start_idx) and current_idx < end_idx: 63 | if marks[current_idx] != 0: 64 | sample.append(current_idx) 65 | current_idx += 1 66 | 67 | sample = np.array(sample) 68 | if len(sample) != buffer_end_idx - buffer_start_idx: 69 | pass 70 | else: 71 | indices.append( 72 | [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx] 73 | ) 74 | 75 | indices = np.array(indices) 76 | return indices 77 | 78 | 79 | def get_val_mask(n_episodes, val_ratio, seed=0): 80 | val_mask = np.zeros(n_episodes, dtype=bool) 81 | if val_ratio <= 0: 82 | return val_mask 83 | 84 | # have at least 1 episode for validation, and at least 1 episode for train 85 | n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes - 1) 86 | rng = np.random.default_rng(seed=seed) 87 | val_idxs = rng.choice(n_episodes, size=n_val, replace=False) 88 | val_mask[val_idxs] = True 89 | return val_mask 90 | 91 | 92 | def downsample_mask(mask, max_n, seed=0): 93 | # subsample training data 94 | train_mask = mask 95 | if (max_n is not None) and (np.sum(train_mask) > max_n): 96 | n_train = int(max_n) 97 | curr_train_idxs = np.nonzero(train_mask)[0] 98 | rng = np.random.default_rng(seed=seed) 99 | train_idxs_idx = rng.choice(len(curr_train_idxs), size=n_train, replace=False) 100 | train_idxs = curr_train_idxs[train_idxs_idx] 101 | train_mask = np.zeros_like(train_mask) 102 | train_mask[train_idxs] = True 103 | assert np.sum(train_mask) == n_train 104 | return train_mask 105 | 106 | 107 | class SequenceSampler: 108 | def __init__( 109 | self, 110 | replay_buffer: ReplayBuffer, 111 | sequence_length: int, 112 | pad_before: int = 0, 113 | pad_after: int = 0, 114 | keys=None, 115 | key_first_k=dict(), 116 | episode_mask: Optional[np.ndarray] = None, 117 | ): 118 
| """ 119 | key_first_k: dict str: int 120 | Only take first k data from these keys (to improve perf) 121 | """ 122 | 123 | super().__init__() 124 | assert sequence_length >= 1 125 | if keys is None: 126 | keys = list(replay_buffer.keys()) 127 | 128 | episode_ends = replay_buffer.episode_ends[:] 129 | marks = replay_buffer["marks"] 130 | if episode_mask is None: 131 | episode_mask = np.ones(episode_ends.shape, dtype=bool) 132 | 133 | if np.any(episode_mask): 134 | indices = create_indices( 135 | episode_ends, 136 | sequence_length=sequence_length, 137 | pad_before=pad_before, 138 | pad_after=pad_after, 139 | episode_mask=episode_mask, 140 | marks = np.array(marks), 141 | ) 142 | else: 143 | indices = np.zeros((0, 4), dtype=np.int64) 144 | 145 | # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx) 146 | self.indices = indices 147 | self.keys = list(keys) # prevent OmegaConf list performance problem 148 | self.sequence_length = sequence_length 149 | self.replay_buffer = replay_buffer 150 | self.key_first_k = key_first_k 151 | 152 | def __len__(self): 153 | return len(self.indices) 154 | 155 | def sample_sequence(self, idx): 156 | ( 157 | buffer_start_idx, 158 | buffer_end_idx, 159 | sample_start_idx, 160 | sample_end_idx, 161 | ) = self.indices[idx] 162 | result = dict() 163 | for key in self.keys: 164 | if key == "marks": 165 | continue 166 | # print(key) 167 | # print(self.key_first_k) 168 | input_arr = self.replay_buffer[key] 169 | 170 | # performance optimization, avoid small allocation if possible 171 | marks = self.replay_buffer["marks"] 172 | 173 | if key not in self.key_first_k: 174 | # TODO 175 | input_arr = self.replay_buffer[key] 176 | marks = self.replay_buffer["marks"] 177 | 178 | sample = [] 179 | 180 | # obs_step + 1 181 | current_idx = buffer_start_idx 182 | for _ in range(3): 183 | if current_idx < buffer_end_idx: 184 | sample.append(input_arr[current_idx]) 185 | current_idx += 1 186 | else: 187 | break 188 | 189 | while len(sample) 
< buffer_end_idx-buffer_start_idx and current_idx < buffer_end_idx: 190 | if marks[current_idx] != 0: 191 | sample.append(input_arr[current_idx]) 192 | current_idx += 1 193 | 194 | while len(sample) < buffer_end_idx-buffer_start_idx and current_idx < len(marks): 195 | if marks[current_idx] != 0: 196 | sample.append(input_arr[current_idx]) 197 | current_idx += 1 198 | 199 | sample = np.array(sample) 200 | 201 | if len(sample) != buffer_end_idx-buffer_start_idx: 202 | raise ValueError("Could not fill the sample to the required sequence length.") 203 | 204 | sample = input_arr[buffer_start_idx:buffer_end_idx] 205 | 206 | else: 207 | # performance optimization, only load used obs steps 208 | n_data = buffer_end_idx - buffer_start_idx 209 | k_data = min(self.key_first_k[key], n_data) 210 | # fill value with Nan to catch bugs 211 | # the non-loaded region should never be used 212 | sample = np.full( 213 | (n_data,) + input_arr.shape[1:], 214 | fill_value=np.nan, 215 | dtype=input_arr.dtype, 216 | ) 217 | try: 218 | sample[:k_data] = input_arr[ 219 | buffer_start_idx : buffer_start_idx + k_data 220 | ] 221 | except Exception as e: 222 | import pdb 223 | 224 | pdb.set_trace() 225 | 226 | data = sample 227 | if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length): 228 | data = np.zeros( 229 | shape=(self.sequence_length,) + input_arr.shape[1:], 230 | dtype=input_arr.dtype, 231 | ) 232 | if sample_start_idx > 0: 233 | data[:sample_start_idx] = sample[0] 234 | if sample_end_idx < self.sequence_length: 235 | data[sample_end_idx:] = sample[-1] 236 | data[sample_start_idx:sample_end_idx] = sample 237 | result[key] = data 238 | return result 239 | -------------------------------------------------------------------------------- /utils/image_generation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | import os 4 | import cv2 5 | 6 | from PIL import Image 7 | import matplotlib.pyplot as plt 8 | 
from .realworld_utils import * 9 | import robosuite.utils.transform_utils as T 10 | 11 | def get_camera_extrinsic_matrix(sim, camera_name): 12 | cam_id = sim.model.camera_name2id(camera_name) 13 | camera_pos = sim.data.cam_xpos[cam_id] 14 | camera_rot = sim.data.cam_xmat[cam_id].reshape(3, 3) 15 | R = T.make_pose(camera_pos, camera_rot) 16 | 17 | camera_axis_correction = np.array( 18 | [[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] 19 | ) 20 | R = R @ camera_axis_correction 21 | return R 22 | 23 | class TrajectoryRenderer: 24 | def __init__(self, env, camera_name, save_mode = False, calib_dir=None, root_dir=None): 25 | self.save_mode = save_mode 26 | if self.save_mode == 'realworld': 27 | self.calib_dir = calib_dir 28 | self.root_dir = root_dir 29 | else: 30 | self.env = env 31 | self.camera_name = camera_name 32 | 33 | # self.extrinsic_matrix = env.get_camera_extrinsic_matrix(camera_name) 34 | self.extrinsic_matrix = get_camera_extrinsic_matrix(env.env.sim, camera_name) 35 | 36 | print(self.extrinsic_matrix) 37 | 38 | self.camera_position = self.extrinsic_matrix[:3, 3] 39 | self.camera_rotation = self.extrinsic_matrix[:3, :3] 40 | 41 | 42 | def render_trajectory_image(self, trajectory_points, ind, samples, save_mode, rotate=False, state_slices=None): 43 | if self.save_mode == 'realworld': 44 | if rotate: 45 | trajectory_points = self.rotate_trajectory(trajectory_points, ind) 46 | transformed_points = translate_points(self.calib_dir, trajectory_points) 47 | self.realworld_save_and_append_images(transformed_points, samples, ind, state_slices[0], rotate) 48 | else: 49 | state_slices is not None and self.env.reset_to(state_slices[0]) 50 | frame = self.env.render(mode="rgb_array", height=480, width=480, camera_name=self.camera_name) 51 | image2 = Image.fromarray(frame) 52 | factor = get_save_mode_factor(save_mode) 53 | image_array = apply_image_filter(image2, factor) 54 | 55 | if rotate: 56 | trajectory_points = 
self.rotate_trajectory(trajectory_points, ind) 57 | transformed_points = np.dot(trajectory_points - self.camera_position, self.camera_rotation) 58 | 59 | self.save_and_append_images(transformed_points, samples, ind, image_array, rotate) 60 | 61 | def rotate_trajectory(self, trajectory_points, i): 62 | start_point, end_point, middle_points = trajectory_points[0], trajectory_points[-1], trajectory_points[1:-1] 63 | axis = (end_point - start_point) / np.linalg.norm(end_point - start_point) 64 | angle = np.deg2rad(i * 360 / 30) 65 | rotation_matrix = self.get_rotation_matrix(axis, angle) 66 | 67 | rotated_middle_points = np.dot(middle_points - start_point, rotation_matrix.T) + start_point 68 | return np.vstack([start_point, rotated_middle_points, end_point]) 69 | 70 | def get_rotation_matrix(self, axis, angle): 71 | cos_angle = np.cos(angle) 72 | sin_angle = np.sin(angle) 73 | return np.array([ 74 | [cos_angle + axis[0]**2 * (1 - cos_angle), axis[0]*axis[1]*(1 - cos_angle) - axis[2]*sin_angle, axis[0]*axis[2]*(1 - cos_angle) + axis[1]*sin_angle], 75 | [axis[1]*axis[0]*(1 - cos_angle) + axis[2]*sin_angle, cos_angle + axis[1]**2 * (1 - cos_angle), axis[1]*axis[2]*(1 - cos_angle) - axis[0]*sin_angle], 76 | [axis[2]*axis[0]*(1 - cos_angle) - axis[1]*sin_angle, axis[2]*axis[1]*(1 - cos_angle) + axis[0]*sin_angle, cos_angle + axis[2]**2 * (1 - cos_angle)] 77 | ]) 78 | 79 | def realworld_save_and_append_images(self, transformed_points, samples, ind, image_path, rotate): 80 | img_path = os.path.join(self.root_dir, image_path) 81 | image = cv2.imread(img_path) 82 | factor = get_save_mode_factor(self.save_mode) 83 | image = apply_image_filter(image, factor) 84 | image = np.array(image) 85 | 86 | prev_point = None 87 | for point in transformed_points[:]: 88 | cv2.circle(image, (int(point[0]), int(point[1])), 3, (0, 0, 255), -1) 89 | if prev_point is not None: 90 | cv2.line(image, prev_point, point, (0, 0, 255), thickness=4) 91 | prev_point = point 92 | # save_path = 
'./tmp_realworld/marked_image{}.jpg'.format(ind) 93 | # cv2.imwrite(save_path, image) 94 | samples['images'].append(np.array(image)) 95 | samples['labels'].append(1 if rotate else 0) 96 | 97 | def save_and_append_images(self, transformed_points, samples, ind, image_array, rotate): 98 | image_array = plot(transformed_points, image_array) 99 | # image_array.save(f'./tmp/image_com_{ind}.png') 100 | 101 | samples['images'].append(np.array(image_array)) 102 | samples['labels'].append(1 if rotate else 0) 103 | 104 | 105 | class TrajectoryNoiseGenerator: 106 | def __init__(self, trajectory_points): 107 | self.trajectory_points = trajectory_points 108 | 109 | def render_one_point(self): 110 | start_index = np.random.randint(0, len(self.trajectory_points) - 10) 111 | end_index = start_index + 1 112 | 113 | sub_trajectory = self.trajectory_points[start_index:end_index] 114 | noisy_sub_trajectory = self.add_noise(sub_trajectory, scale_range=(0.02, 0.05)) 115 | 116 | noise_trajectory = self.trajectory_points.copy() 117 | noise_trajectory[start_index:end_index] = noisy_sub_trajectory 118 | self.remove_adjacent_points(noise_trajectory, start_index) 119 | return noise_trajectory 120 | 121 | def add_noise(self, sub_trajectory, scale_range): 122 | noise_scale = np.random.uniform(*scale_range) 123 | noise = noise_scale * np.random.randn(sub_trajectory.shape[0], sub_trajectory.shape[1]) 124 | return sub_trajectory + noise 125 | 126 | def remove_adjacent_points(self, noise_trajectory, start_index): 127 | if start_index > 1: 128 | for i in range(1, 4): 129 | noise_trajectory = np.delete(noise_trajectory, start_index - i, axis=0) 130 | 131 | def render_one_point_circle(self): 132 | start_index = np.random.randint(0, min(len(self.trajectory_points) - 10, len(self.trajectory_points))) 133 | end_index = start_index + np.random.randint(5, 10) 134 | 135 | sub_trajectory = self.trajectory_points[start_index:end_index] 136 | noisy_sub_trajectory = self.add_noise(sub_trajectory, 
scale_range=(0.02, 0.04)) 137 | 138 | inserted_indices = self.get_inserted_indices(start_index, end_index, num_points_range=(10, 20)) 139 | for i, index in enumerate(inserted_indices): 140 | self.trajectory_points = np.insert(self.trajectory_points, index + i, sub_trajectory[0] + np.random.uniform(0.04, 0.06)*np.random.randn(), axis=0) 141 | return self.trajectory_points 142 | 143 | def get_inserted_indices(self, start_index, end_index, num_points_range): 144 | num_inserted_points = np.random.randint(*num_points_range) 145 | inserted_indices = np.random.randint(start_index, end_index, size=num_inserted_points) 146 | inserted_indices.sort() 147 | return inserted_indices 148 | 149 | def render_series(self): 150 | start_index = np.random.randint(0, max(1, len(self.trajectory_points) - 20)) 151 | end_index = start_index + np.random.randint(5, 10) 152 | sub_trajectory = self.trajectory_points[start_index:end_index] 153 | 154 | noisy_sub_trajectory = self.add_noise(sub_trajectory, scale_range=(0.03, 0.06)) 155 | noise_trajectory = self.trajectory_points.copy() 156 | noise_trajectory[start_index:end_index] = noisy_sub_trajectory 157 | return noise_trajectory 158 | 159 | 160 | class TrajectoryGenerator: 161 | def __init__(self, trajectory_points): 162 | self.trajectory_points = trajectory_points 163 | self.noise_generator = TrajectoryNoiseGenerator(trajectory_points) 164 | 165 | def generate_negative_trajectory_points(self): 166 | num_one_point = np.random.randint(0, 8) 167 | noise_trajectory = self.trajectory_points.copy() 168 | 169 | for _ in range(num_one_point): 170 | noise_trajectory = self.noise_generator.render_one_point() 171 | 172 | if num_one_point == 0: 173 | one_point_circle_flag = 1 174 | else: 175 | one_point_circle_flag = np.random.randint(2) 176 | 177 | if one_point_circle_flag: 178 | noise_trajectory = self.noise_generator.render_one_point_circle() 179 | 180 | if not one_point_circle_flag and np.random.randint(2): 181 | noise_trajectory = 
self.noise_generator.render_series() 182 | 183 | return noise_trajectory 184 | 185 | 186 | class PictureGenerator: 187 | def __init__(self, renderer, samples, save_mode): 188 | self.renderer = renderer 189 | self.samples = samples 190 | self.save_mode = save_mode 191 | 192 | def generate_positive_picture(self, trajectory_points, state_slices, ind, num_images=30): 193 | for i in range(num_images): 194 | self.renderer.render_trajectory_image(trajectory_points, i, self.samples, self.save_mode, rotate=True, state_slices=state_slices) 195 | # This part can be expanded if needed 196 | 197 | def generate_negative_picture(self, trajectory_points, state_slices, ind, num_images): 198 | for i in range(num_images): 199 | noise_trajectory = TrajectoryGenerator(trajectory_points).generate_negative_trajectory_points() 200 | self.renderer.render_trajectory_image(noise_trajectory, i, self.samples, self.save_mode, rotate=False, state_slices=state_slices) 201 | 202 | 203 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import h5py 5 | from torch.utils.data import Dataset 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision import transforms, datasets 9 | import robomimic.utils.obs_utils as ObsUtils 10 | import robomimic.utils.env_utils as EnvUtils 11 | import robomimic.utils.file_utils as FileUtils 12 | 13 | from scipy.spatial.transform import Rotation 14 | from .trajectory_optimization import TrajectoryOptimizer 15 | from utils.realworld_utils import * 16 | from utils.constant import * 17 | 18 | 19 | class CustomDataset(Dataset): 20 | def __init__(self, npy_file, transform=None): 21 | """ 22 | Args: 23 | npy_file (string): 24 | transform (callable, optional): 25 | """ 26 | self.transform = transform 27 | self.data = np.load(npy_file, allow_pickle=True).item() 28 | self.inputs = 
self.data['images'] 29 | self.labels = self.data['labels'] 30 | 31 | if len(self.inputs) != len(self.labels): 32 | raise ValueError(f"Length mismatch: inputs({len(self.inputs)}) and labels({len(self.labels)})") 33 | 34 | def __len__(self): 35 | return len(self.labels) 36 | 37 | def __getitem__(self, idx): 38 | input_data = self.inputs[idx] 39 | label = self.labels[idx] 40 | if isinstance(input_data, np.ndarray): 41 | input_data = Image.fromarray(input_data) 42 | 43 | if self.transform: 44 | input_data = self.transform(input_data) 45 | 46 | return input_data, label 47 | 48 | 49 | 50 | 51 | class ValDataset(Dataset): 52 | def __init__(self, hdf5_file, transform=None, save_mode=None): 53 | """ 54 | Args: 55 | hdf5_file (string) 56 | transform (callable, optional) 57 | """ 58 | self.transform = transform 59 | self.save_mode = save_mode 60 | self.hdf5_file = hdf5_file 61 | self.data = h5py.File(hdf5_file, 'r') 62 | 63 | self.demos = list(self.data["data"].keys()) 64 | self.small_demos = {} 65 | self.mapping = {} 66 | self._init_env() 67 | self.optimizer = TrajectoryOptimizer(self.env, real_world=False) 68 | 69 | self._split_demos() 70 | self.data.close() 71 | 72 | 73 | def _init_env(self): 74 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=self.hdf5_file) 75 | env_type = EnvUtils.get_env_type(env_meta=env_meta) 76 | self.render_image_names = DEFAULT_CAMERAS[env_type] 77 | 78 | dummy_spec = dict( 79 | obs=dict( 80 | low_dim=["robot0_eef_pos"], 81 | rgb=[], 82 | ), 83 | ) 84 | ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec) 85 | self.env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=False, render_offscreen=True) 86 | 87 | def _calculate_items(self, demo_idx, states, actions): 88 | states = self.data[f"data/{demo_idx}/states"][()] 89 | traj_len = states.shape[0] 90 | 91 | delta_actions = self.data[f"data/{demo_idx}/actions"][()] 92 | action_pos = np.zeros((traj_len, 3), dtype=delta_actions.dtype) 93 | action_ori 
= np.zeros((traj_len, 3), dtype=delta_actions.dtype) 94 | action_gripper = delta_actions[:, -1:] 95 | 96 | robot = self.env.env.robots[0] 97 | controller = robot.controller 98 | 99 | for i in range(len(states)): 100 | self.env.reset_to({"states": states[i]}) 101 | robot.control(actions[i], policy_step=True) 102 | 103 | action_pos[i] = controller.ee_pos 104 | action_ori[i] = Rotation.from_matrix(controller.ee_ori_mat).as_rotvec() 105 | 106 | actions = np.concatenate([action_pos, action_ori, action_gripper], axis=-1) 107 | return actions 108 | 109 | 110 | def _split_demos(self): 111 | for demo_idx in self.demos: 112 | eef_pos =self.data[f"data/{demo_idx}/obs/robot0_eef_pos"][()] 113 | eef_quat = self.data[f"data/{demo_idx}/obs/robot0_eef_quat"][()] 114 | joint_pos = self.data[f"data/{demo_idx}/obs/robot0_joint_pos"][()] 115 | gt_states = [] 116 | traj_len = eef_pos.shape[0] 117 | for i in range(traj_len): 118 | gt_states.append( 119 | dict( 120 | robot0_eef_pos=eef_pos[i], 121 | robot0_eef_quat=eef_quat[i], 122 | robot0_joint_pos=joint_pos[i], 123 | ) 124 | ) 125 | actions = self.data[f'data/{demo_idx}/actions'][()] 126 | states = self.data[f'data/{demo_idx}/states'][()] 127 | trajectory_points = self.data[f'data/{demo_idx}/obs/robot0_eef_pos'][()] 128 | actions = self._calculate_items(demo_idx, states, actions) 129 | 130 | frames = self._get_frames(actions) 131 | small_demos = self._split_into_small_demos(actions, states, trajectory_points, gt_states, frames) 132 | 133 | self.small_demos[demo_idx] = small_demos 134 | self.mapping[demo_idx] = list(range(len(small_demos))) 135 | 136 | def _get_frames(self, actions): 137 | frames = [0, len(actions) - 1] 138 | for i in range(len(actions) - 1): 139 | if actions[i, -1] != actions[i + 1, -1]: 140 | frames.append(i) 141 | frames.sort() 142 | 143 | merged_frames = [frames[0]] 144 | for i in range(1, len(frames)): 145 | if frames[i] - merged_frames[-1] < 15: 146 | merged_frames.pop() 147 | merged_frames.append(frames[i]) 148 
| 149 | merged_frames.sort() 150 | return merged_frames 151 | 152 | def _split_into_small_demos(self, actions, states, trajectory_points, gt_states, frames): 153 | small_demos = [] 154 | for i in range(len(frames) - 1): 155 | start, end = frames[i], frames[i + 1] 156 | small_demos.append({ 157 | 'actions': actions[start:end], 158 | 'states': states[start:end], 159 | 'gt_states': gt_states[start:end], 160 | 'trajectory_points': trajectory_points[start:end], 161 | 'frame_start': frames[i], 162 | 'frame_end': frames[i+1] 163 | }) 164 | return small_demos 165 | 166 | def __len__(self): 167 | return sum(len(small_demo) for small_demo in self.small_demos.values()) 168 | 169 | def __getitem__(self, idx): 170 | demo_idx, small_demo_idx = self._find_small_demo_index(idx) 171 | small_demo = self.small_demos[demo_idx][small_demo_idx] 172 | 173 | positive_image = self.generate_image(small_demo) 174 | 175 | if self.transform: 176 | positive_image = self.transform(positive_image) 177 | 178 | return positive_image, demo_idx, small_demo_idx 179 | 180 | def _find_small_demo_index(self, idx): 181 | for demo_idx, small_demos in self.small_demos.items(): 182 | if idx < len(small_demos): 183 | return demo_idx, idx 184 | idx -= len(small_demos) 185 | raise IndexError("Index out of range.") 186 | 187 | def _save_marks(self, demo_idx, marks): 188 | with h5py.File(self.hdf5_file, 'a') as f: 189 | if f'data/{demo_idx}/marks' not in f: 190 | f.create_dataset(f'data/{demo_idx}/marks', data=marks) 191 | else: 192 | existing_marks = f[f'data/{demo_idx}/marks'][:] 193 | all_marks = np.unique(np.concatenate((existing_marks, marks))) 194 | all_marks.sort() 195 | del f[f'data/{demo_idx}/marks'] 196 | f.create_dataset(f'data/{demo_idx}/marks', data=all_marks) 197 | 198 | 199 | def visualize_image(self,idx): 200 | demo_idx, small_demo_idx = self._find_small_demo_index(idx) 201 | small_demo = self.small_demos[demo_idx][small_demo_idx] 202 | 203 | positive_image = self.generate_image(small_demo) 204 | 
return positive_image 205 | 206 | def perform_optimization(self, idx, label): 207 | flag = label < 0.5 208 | demo_idx, small_demo_idx = self._find_small_demo_index(idx) 209 | small_demo = self.small_demos[demo_idx][small_demo_idx] 210 | if flag: 211 | marks = self.optimizer.optimize_trajectory(small_demo, demo_idx, small_demo_idx,three_dimension=True) 212 | else: 213 | marks = list(range(small_demo['frame_start'], small_demo['frame_end'])) 214 | 215 | self._save_marks(demo_idx, marks) 216 | 217 | def transform_points(self, trajectory_points): 218 | camera_position = np.array([1.0, 0.0, 1.75]) 219 | camera_rotation = np.array([ 220 | [0.0, -0.70614724, 0.70806503], 221 | [1.0, 0.0, 0.0], 222 | [0.0, 0.70806503, 0.70614724] 223 | ]) 224 | 225 | transformed_points = np.dot(trajectory_points - camera_position, camera_rotation) 226 | return transformed_points 227 | 228 | 229 | def generate_image(self, small_demo, save_mode="image"): 230 | trajectory_points = small_demo['trajectory_points'] 231 | 232 | self.env.reset() 233 | self.env.reset_to(dict(states=small_demo['states'][0])) 234 | frame = self.env.render(mode="rgb_array", height=480, width=480, camera_name=self.render_image_names[0]) 235 | 236 | image = Image.fromarray(frame) 237 | factor = get_save_mode_factor(save_mode=self.save_mode) 238 | image = apply_image_filter(image, factor) 239 | 240 | transformed_points = self.transform_points(trajectory_points) 241 | return plot(transformed_points, image) 242 | 243 | 244 | class RealworldValDataset(Dataset): 245 | def __init__(self, dataset, transform=None, save_mode=None): 246 | """ 247 | Args: 248 | hdf5_file (string) 249 | transform (callable, optional) 250 | """ 251 | self.transform = transform 252 | self.calib_dir = check_directory_exists(os.path.join(dataset, "calib")) 253 | self.save_mode = save_mode 254 | self.root_dir = os.path.join(dataset, "train") 255 | subdirs = [d for d in os.listdir(self.root_dir) if os.path.isdir(os.path.join(self.root_dir, d))] 256 | 
self.subdirs = sorted(subdirs, key=lambda x: int(x.split('_scene_')[1].split('_')[0])) 257 | self.small_demos = {} 258 | self.mapping = {} 259 | 260 | self.optimizer = TrajectoryOptimizer(env=None, real_world=True, calib_dir=self.calib_dir) 261 | self._split_demos() 262 | 263 | def _split_demos(self): 264 | for ind in range(len(self.subdirs)-20): 265 | color_path = os.path.join(self.root_dir, self.subdirs[ind], CAMERA_NAME, 'color') 266 | file_paths, trajectory_points, gripper_command = load_demo_files(self.root_dir, self.subdirs, ind) 267 | frames = realworld_change_indices(gripper_command) 268 | small_demos = self._split_into_small_demos(file_paths, trajectory_points, frames, color_path) 269 | 270 | self.small_demos[ind] = small_demos 271 | self.mapping[ind] = list(range(len(small_demos))) 272 | 273 | def _split_into_small_demos(self, file_paths, trajectory_points, frames, color_path): 274 | small_demos = [] 275 | for i in range(len(frames) - 1): 276 | start, end = frames[i], frames[i + 1] 277 | small_demos.append({ 278 | 'states': file_paths[start:end], 279 | 'trajectory_points': trajectory_points[start:end], 280 | 'frame_start': frames[i], 281 | 'frame_end': frames[i+1], 282 | 'color_path': color_path 283 | }) 284 | return small_demos 285 | 286 | def _find_small_demo_index(self, idx): 287 | for demo_idx, small_demos in self.small_demos.items(): 288 | if idx < len(small_demos): 289 | return demo_idx, idx 290 | idx -= len(small_demos) 291 | raise IndexError("Index out of range.") 292 | 293 | def _generate_image(self, small_demo): 294 | trajectory_points = small_demo['trajectory_points'] 295 | 296 | transformed_points = translate_points(self.calib_dir, trajectory_points) 297 | img_path = os.path.join(small_demo['color_path'], small_demo['states'][0]) 298 | image = cv2.imread(img_path) 299 | 300 | factor = get_save_mode_factor(save_mode=self.save_mode) 301 | image = apply_image_filter(image, factor) 302 | image = np.array(image) 303 | 304 | prev_point = None 305 | 
for point in transformed_points[:]: 306 | cv2.circle(image, (int(point[0]), int(point[1])), 3, (0, 0, 255), -1) 307 | if prev_point is not None: 308 | cv2.line(image, prev_point, point, (0, 0, 255), thickness=4) 309 | prev_point = point 310 | return image 311 | 312 | def __len__(self): 313 | return sum(len(small_demo) for small_demo in self.small_demos.values()) 314 | 315 | def __getitem__(self, idx): 316 | demo_idx, small_demo_idx = self._find_small_demo_index(idx) 317 | small_demo = self.small_demos[demo_idx][small_demo_idx] 318 | positive_image = self._generate_image(small_demo) 319 | 320 | positive_image = self.transform(Image.fromarray(positive_image)) if self.transform and isinstance(positive_image, np.ndarray) else positive_image 321 | 322 | return positive_image, demo_idx, small_demo_idx 323 | 324 | def _save_marks(self, demo_idx, marks, color_path): 325 | npy_file_path = os.path.join(color_path, f'marks_{demo_idx}.npy') 326 | if os.path.exists(npy_file_path): 327 | existing_marks = np.load(npy_file_path) 328 | all_marks = np.unique(np.concatenate((existing_marks, marks))) 329 | else: 330 | all_marks = np.unique(marks) 331 | 332 | all_marks.sort() 333 | np.save(npy_file_path, all_marks) 334 | 335 | def perform_optimization(self, idx, flag=True): 336 | demo_idx, small_demo_idx = self._find_small_demo_index(idx) 337 | small_demo = self.small_demos[demo_idx][small_demo_idx] 338 | if flag: 339 | marks = self.optimizer.optimize_trajectory(small_demo, demo_idx, small_demo_idx,three_dimension=True) 340 | else: 341 | marks = list(range(small_demo['frame_start'], small_demo['frame_end'])) 342 | 343 | self._save_marks(demo_idx, marks, small_demo['color_path']) 344 | 345 | --------------------------------------------------------------------------------