├── utils
│   ├── __init__.py
│   ├── constant.py
│   ├── trajectory_utils.py
│   ├── util.py
│   ├── data_augmentation.py
│   ├── realworld_utils.py
│   └── image_generation.py
├── dataset
│   ├── __init__.py
│   ├── trajectory_optimization.py
│   └── dataset.py
├── assets
│   └── img
│       ├── model.png
│       └── action_relabeling.png
├── command_eval.sh
├── command_train.sh
├── policy
│   ├── README.md
│   ├── diffusion_policy
│   │   ├── robomimic_replay_lowdim_dataset.py
│   │   └── sampler.py
│   └── robomimic
│       └── rollout.py
├── LICENSE
├── networks
│   └── resnet.py
├── README.md
├── losses.py
├── eval.py
├── environment.yaml
└── train.py

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/assets/img/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Junxix/S2I/HEAD/assets/img/model.png
--------------------------------------------------------------------------------
/assets/img/action_relabeling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Junxix/S2I/HEAD/assets/img/action_relabeling.png
--------------------------------------------------------------------------------
/command_eval.sh:
--------------------------------------------------------------------------------
1 | python eval.py --train_data_folder ./lowdim_samples.npy --val_data_folder ./low_dim.hdf5 --size 128 --ckpt ./ckpts/ckpt_epoch_2000.pth
--------------------------------------------------------------------------------
/command_train.sh:
--------------------------------------------------------------------------------
1 | python train.py --batch_size 256 --learning_rate 0.005 --temp 0.1 --cosine --aug_path ./lowdim_samples.npy --method SupCon --epochs 2500 --save_freq 100 --print_freq 1 --size 128 --save_mode realworld
--------------------------------------------------------------------------------
/utils/constant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from robomimic.envs.env_base import EnvBase, EnvType
3 | 
4 | IMG_MEAN = np.array([0.485, 0.456, 0.406])
5 | IMG_STD = np.array([0.229, 0.224, 0.225])
6 | 
7 | DEFAULT_CAMERAS = {
8 |     EnvType.ROBOSUITE_TYPE: ["agentview"],
9 |     EnvType.IG_MOMART_TYPE: ["rgb"],
10 |     EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"),
11 | }
12 | 
13 | CAMERA_NAME = 'cam_750612070851'
14 | 
15 | DELTA_THETA = 75
16 | 
--------------------------------------------------------------------------------
/policy/README.md:
--------------------------------------------------------------------------------
1 | # Manipulation Policy
2 | ## BC-RNN
3 | For state-based [BC-RNN](https://github.com/ARISE-Initiative/robomimic), we modified the rollout program to ensure that during simulation evaluation, 50 starting positions are randomly selected. These positions vary across different seeds but remain consistent within the same seed.
4 | 
5 | Here are the argument explanations in the rollout process:
6 | * `--config` : Specifies the configuration for the algorithm's structure.
7 | * `--dataset` : The path to the dataset used for training and loading environment parameters.
8 | * `--checkpoint_dir` : The directory containing the checkpoints to be evaluated.
9 | 
10 | ## Diffusion Policy
11 | 
12 | 
14 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
6 |
7 | ## 🧑🏻💻 Run
8 | For the representation model training stage, run the command `bash command_train.sh` to execute the training script, which will preprocess the dataset and train the model.
9 |
10 | Here are the argument explanations in the training process:
11 | * `--dataset` : Specifies the entire dataset used for the representation model training.
12 | * `--aug_path` : The path where the augmented dataset will be stored (a loading sketch follows this list).
13 | * `--save_mode` : Indicates the format or type of the dataset (`image`, `lowdim`, or `realworld`).
14 | * `--size` : Specifies the size to which the images will be resized.
15 | * `--numbers` : The indices of the demonstrations used for data augmentation within the dataset.
16 |
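For reference, the file written to `--aug_path` is a NumPy save of a plain Python dictionary produced by `utils/data_augmentation.py`. A minimal loading sketch, assuming the default `./lowdim_samples.npy` path used in `command_train.sh`:

```python
import numpy as np

# The augmentation step saves a dict with three parallel lists
# (see np.save(args.aug_path, samples) in utils/data_augmentation.py).
samples = np.load("./lowdim_samples.npy", allow_pickle=True).item()
print(samples.keys())          # dict_keys(['images', 'end_images', 'labels'])
print(len(samples["labels"]))  # number of generated augmented samples
```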
17 | For the evaluation stage, run the command `bash command_eval.sh` to perform the segment selection and trajectory optimization processes.
18 |
19 | Here are the argument explanations in the evaluation process:
20 | * `--train_data_folder` : The dataset used for distance-weighted voting during the segment selection process (illustrated in the sketch after this list).
21 | * `--val_data_folder` : The folder containing the full mixed-quality demonstration dataset for validation.
22 | * `--size` : Specifies the size to which the images will be resized.
23 |
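The distance-weighted voting used during segment selection follows the k-nearest-neighbour scheme in `eval.py` (`calculate_nearest_neighbors` / `calculate_label`). A minimal standalone sketch with made-up distances and labels:

```python
import torch

# Hypothetical (distance, label) pairs from a query embedding to its
# k nearest training samples, sorted by distance as in eval.py.
dist_list = [(0.2, 1.0), (0.5, 0.0), (0.9, 1.0)]
k = 3

# Closer neighbours receive larger weights: softmax over negated distances.
weights = torch.nn.functional.softmax(
    torch.tensor([d for d, _ in dist_list[:k]]) * -1, dim=0)
label = sum(w * dist_list[i][1] for i, w in enumerate(weights))
print(float(label))  # a soft label between 0 and 1
```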
24 | ## 🤖 Training Manipulation Policy
25 |
26 | After applying Select Segments to Imitate (S2I), the processed dataset can be directly used for downstream manipulation policy training as a plug-and-play solution.
27 |
28 | For simulation experiments, we use the state-based [BC-RNN](https://github.com/ARISE-Initiative/robomimic) and the [Diffusion Policy (DP)](https://github.com/real-stanford/diffusion_policy), which can be applied to both state and image data, as robot manipulation policies. For real-world experiments, we choose [DP](https://github.com/real-stanford/diffusion_policy) and [ACT](https://github.com/tonyzhaozh/act) as our image-based policies, as well as [RISE](https://github.com/rise-policy/rise) as our point-cloud-based policy. Some minor modifications have been made to the sampler and rollout functions; the modified Python files are available in [`./policy`](https://github.com/Junxix/S2I/tree/main/policy). Refer to the [documentation](policy/README.md) for more details.
29 |
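The modified Diffusion Policy dataset class (`policy/diffusion_policy/robomimic_replay_lowdim_dataset.py`) additionally reads a per-demo `marks` entry holding the frame indices selected by S2I. As a rough, non-authoritative sketch (assuming a robomimic-style HDF5 file named `./low_dim.hdf5`, as in the scripts above), such a file could be inspected like this:

```python
import h5py

with h5py.File("./low_dim.hdf5", "r") as f:
    demo = f["data/demo_0"]
    print(demo["actions"].shape)             # (T, action_dim)
    print(demo["obs/robot0_eef_pos"].shape)  # (T, 3)
    print(demo["marks"][()])                 # frame indices selected by S2I
```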
30 | ## 🙏 Acknowledgement
31 |
32 | Our code is built upon: [Diffusion Policy](https://github.com/real-stanford/diffusion_policy/), [RoboMimic](https://github.com/ARISE-Initiative/robomimic), [SupContrast](https://github.com/HobbitLong/SupContrast), [RISE](https://github.com/rise-policy/rise) and [ACT](https://github.com/tonyzhaozh/act). We thank all the authors for their contributions to the community.
33 |
34 | ## ✍️ Citation
35 |
36 | If you find S2I useful in your research, please consider citing the following paper:
37 |
38 | ```bibtex
39 | @article{
40 | chen2024towards,
41 | title = {Towards Effective Utilization of Mixed-Quality Demonstrations in Robotic Manipulation via Segment-Level Selection and Optimization},
42 | author = {Chen, Jingjing and Fang, Hongjie and Fang, Hao-Shu and Lu, Cewu},
43 | journal = {arXiv preprint arXiv:2409.19917},
44 | year = {2024}
45 | }
46 | ```
47 |
48 | ## 📃 License
49 |
50 | S2I by Jingjing Chen, Hongjie Fang, Hao-Shu Fang, and Cewu Lu is licensed under the MIT License.
51 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class SupConLoss(nn.Module):
8 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
9 | It also supports the unsupervised contrastive loss in SimCLR"""
10 | def __init__(self, temperature=0.07, contrast_mode='all',
11 | base_temperature=0.07):
12 | super(SupConLoss, self).__init__()
13 | self.temperature = temperature
14 | self.contrast_mode = contrast_mode
15 | self.base_temperature = base_temperature
16 |
17 | def forward(self, features, labels=None, mask=None):
18 | """Compute loss for model. If both `labels` and `mask` are None,
19 | it degenerates to SimCLR unsupervised loss:
20 | https://arxiv.org/pdf/2002.05709.pdf
21 |
22 | Args:
23 | features: hidden vector of shape [bsz, n_views, ...].
24 | labels: ground truth of shape [bsz].
25 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
26 | has the same class as sample i. Can be asymmetric.
27 | Returns:
28 | A loss scalar.
29 | """
30 | device = (torch.device('cuda')
31 | if features.is_cuda
32 | else torch.device('cpu'))
33 |
34 | if len(features.shape) < 3:
35 | raise ValueError('`features` needs to be [bsz, n_views, ...],'
36 | 'at least 3 dimensions are required')
37 | if len(features.shape) > 3:
38 | features = features.view(features.shape[0], features.shape[1], -1)
39 |
40 | batch_size = features.shape[0]
41 | if labels is not None and mask is not None:
42 | raise ValueError('Cannot define both `labels` and `mask`')
43 | elif labels is None and mask is None:
44 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
45 | elif labels is not None:
46 | labels = labels.contiguous().view(-1, 1)
47 | if labels.shape[0] != batch_size:
48 | raise ValueError('Num of labels does not match num of features')
49 | mask = torch.eq(labels, labels.T).float().to(device)
50 | else:
51 | mask = mask.float().to(device)
52 |
53 | contrast_count = features.shape[1]
54 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
55 | if self.contrast_mode == 'one':
56 | anchor_feature = features[:, 0]
57 | anchor_count = 1
58 | elif self.contrast_mode == 'all':
59 | anchor_feature = contrast_feature
60 | anchor_count = contrast_count
61 | else:
62 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
63 |
64 | # compute logits
65 | anchor_dot_contrast = torch.div(
66 | torch.matmul(anchor_feature, contrast_feature.T),
67 | self.temperature)
68 | # for numerical stability
69 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
70 | logits = anchor_dot_contrast - logits_max.detach()
71 |
72 | # tile mask
73 | mask = mask.repeat(anchor_count, contrast_count)
74 | # mask-out self-contrast cases
75 | logits_mask = torch.scatter(
76 | torch.ones_like(mask),
77 | 1,
78 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
79 | 0
80 | )
81 | mask = mask * logits_mask
82 |
83 | # compute log_prob
84 | exp_logits = torch.exp(logits) * logits_mask
85 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
86 |
87 | # compute mean of log-likelihood over positive
88 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
89 |
90 | # loss
91 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
92 | loss = loss.view(anchor_count, batch_size).mean()
93 |
94 | return loss
95 |
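A minimal usage sketch of `SupConLoss`, mirroring how `train.py` stacks two augmented views per sample before calling the loss (batch size, feature dimension, and labels below are illustrative):

```python
import torch
from losses import SupConLoss

criterion = SupConLoss(temperature=0.1)

bsz, feat_dim = 8, 128
# Two views per sample, stacked into the expected [bsz, n_views, feat_dim] shape.
f1 = torch.nn.functional.normalize(torch.randn(bsz, feat_dim), dim=1)
f2 = torch.nn.functional.normalize(torch.randn(bsz, feat_dim), dim=1)
features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

labels = torch.randint(0, 2, (bsz,))  # e.g. positive/negative segment labels
loss = criterion(features, labels)    # SupCon mode; omit labels for SimCLR mode
```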
--------------------------------------------------------------------------------
/utils/data_augmentation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import h5py
4 | from tqdm import tqdm
5 | from .realworld_utils import *
6 | from .constant import *
7 | from .image_generation import TrajectoryRenderer, PictureGenerator
8 | import argparse
9 |
10 | import robomimic
11 | import robomimic.utils.obs_utils as ObsUtils
12 | import robomimic.utils.env_utils as EnvUtils
13 | import robomimic.utils.file_utils as FileUtils
14 |
15 | def process_slices(generator, trajectory_slices, color_path_slices, ind, total_images_per_slice):
16 | for i, (traj_slice, state_slice) in enumerate(zip(trajectory_slices, color_path_slices)):
17 | generator.generate_positive_picture(traj_slice, state_slice, ind, num_images=total_images_per_slice)
18 | generator.generate_negative_picture(traj_slice, state_slice, ind, num_images=total_images_per_slice)
19 |
20 |
21 | def data_augmentation_realworld(args):
22 | calib_dir = check_directory_exists(os.path.join(args.dataset, "calib"))
23 |
24 | root_dir = os.path.join(args.dataset, "train")
25 | subdirs = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
26 | sorted_subdirs = sorted(subdirs, key=lambda x: int(x.split('_scene_')[1].split('_')[0]))
27 | samples = {'images': [], 'end_images': [], 'labels': []}
28 |
29 | total_demos = len(args.numbers)
30 | total_images_per_demo = args.total_images // total_demos
31 |
32 | for ind in args.numbers:
33 | path = os.path.join(root_dir, sorted_subdirs[ind], CAMERA_NAME, 'color')
34 | renderer = TrajectoryRenderer(env=None, camera_name=None, save_mode='realworld', calib_dir=calib_dir, root_dir=path)
35 | generator = PictureGenerator(renderer, samples, save_mode=args.save_mode)
36 |
37 | file_paths, trajectory_points, gripper_command = load_demo_files(root_dir, sorted_subdirs, ind)
38 |
39 | change_indices = realworld_change_indices(gripper_command)
40 | trajectory_slices, color_path_slices = realworld_slice(trajectory_points, file_paths, change_indices)
41 | total_images_per_slice = total_images_per_demo // len(trajectory_slices)
42 | process_slices(generator, trajectory_slices, color_path_slices, ind, total_images_per_slice)
43 |
44 | np.save(args.aug_path, samples)
45 |
46 |
47 | def data_augmentation_robomimic(args):
48 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
49 | env_type = EnvUtils.get_env_type(env_meta=env_meta)
50 | render_image_names = DEFAULT_CAMERAS[env_type]
51 |
52 | dummy_spec = dict(
53 | obs=dict(
54 | low_dim=["robot0_eef_pos"],
55 | rgb=[],
56 | ),
57 | )
58 | ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)
59 |
60 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
61 | env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=False, render_offscreen=True)
62 | is_robosuite_env = EnvUtils.is_robosuite_env(env_meta)
63 |
64 | f = h5py.File(args.dataset, "r")
65 | demos = sorted(f["data"].keys(), key=lambda x: int(x[5:]))
66 | samples = {'images': [], 'end_images': [], 'labels': []}
67 |
68 | total_demos = len(args.numbers)
69 | total_images_per_demo = args.total_images // total_demos
70 |
71 | renderer = TrajectoryRenderer(env, render_image_names[0])
72 | generator = PictureGenerator(renderer, samples, save_mode=args.save_mode)
73 |
74 | for ind in args.numbers:
75 | ep = demos[ind]
76 | states = f[f"data/{ep}/states"][()]
77 | trajectory_points = f[f"data/{ep}/obs/robot0_eef_pos"][()]
78 | actions = f[f"data/{ep}/actions"][()]
79 |
80 | initial_state = dict(states=states[0])
81 | if is_robosuite_env:
82 | initial_state["model"] = f[f"data/{ep}"].attrs["model_file"]
83 | generator.renderer.env.reset()
84 | generator.renderer.env.reset_to(initial_state)
85 |
86 | change_indices = find_all_change_indices(actions)
87 | trajectory_slices, state_slices = slice_trajectory_and_states(trajectory_points, states, change_indices)
88 | total_images_per_slice = total_images_per_demo // len(trajectory_slices)
89 | process_slices(generator, trajectory_slices, state_slices, ind, total_images_per_slice)
90 |
91 | np.save(args.aug_path, samples)
92 | f.close()
93 |
94 |
95 | def data_augmentation(args):
96 | if args.save_mode == 'realworld':
97 | data_augmentation_realworld(args)
98 | else:
99 | data_augmentation_robomimic(args)
100 |
101 |
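A rough sketch of calling `data_augmentation` directly; in `train.py` the same fields arrive via `argparse`, and the values below are illustrative:

```python
from types import SimpleNamespace
from utils.data_augmentation import data_augmentation

args = SimpleNamespace(
    save_mode="lowdim",               # 'image' / 'lowdim' use the robomimic branch
    dataset="./low_dim.hdf5",         # robomimic-style hdf5 dataset
    aug_path="./lowdim_samples.npy",  # where the augmented samples are written
    numbers=[0, 1, 2],                # demo indices to augment
    total_images=100,                 # total images spread across demos and slices
)
data_augmentation(args)
```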
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import numpy as np
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | from torchvision import transforms
8 | from tqdm import tqdm
9 | from utils.util import TwoCropTransform, AverageMeter
10 | from networks.resnet import SupConResNet
11 | from dataset.dataset import CustomDataset, ValDataset
12 | from utils.constant import *
13 |
14 | def parse_option():
15 | parser = argparse.ArgumentParser('Argument for evaluation')
16 | parser.add_argument('--model', type=str, default='resnet50')
17 | parser.add_argument("--save_mode", type=str, default='lowdim', choices=['image', 'lowdim', 'realworld'], help="choose the saving method")
18 | parser.add_argument('--train_data_folder', type=str, default='./lowdim_samples.npy', help='path to custom dataset')
19 | parser.add_argument('--val_data_folder', type=str, default='./low_dim.hdf5', help='path to custom dataset')
20 | parser.add_argument('--size', type=int, default=128)
21 | parser.add_argument('--ckpt', type=str, default='./ckpt_epoch_2000.pth',
22 | help='path to pre-trained model')
23 | return parser.parse_args()
24 |
25 | def dist_metric(x, y):
26 | return torch.norm(x - y).item()
27 |
28 | def calculate_label(dist_list, k):
29 | top_k_weights = torch.nn.functional.softmax(torch.tensor([d[0] for d in dist_list[:k]]) * -1, dim=0)
30 | action = sum(weight * dist_list[i][1] for i, weight in enumerate(top_k_weights))
31 | return action
32 |
33 | def clear_folders_if_not_empty(folders):
34 | for folder in folders:
35 | if os.path.exists(folder) and os.listdir(folder):
36 | shutil.rmtree(folder)
37 | os.makedirs(folder)
38 |
39 | def calculate_nearest_neighbors(query_embedding, train_dataset, train_labels, k):
40 | dist_list = [(dist_metric(torch.from_numpy(query_embedding), torch.from_numpy(train_dataset[i])), train_labels[i]) for i in range(len(train_dataset))]
41 | dist_list.sort(key=lambda tup: tup[0])
42 | return calculate_label(dist_list, k)
43 |
44 | def set_loader(opt):
45 | normalize = transforms.Normalize(mean=IMG_MEAN, std=IMG_STD)
46 |
47 | train_transform = transforms.Compose([
48 | transforms.RandomResizedCrop(size=opt.size, scale=(0.8, 1.0)),
49 | transforms.ToTensor(),
50 | normalize,
51 | ])
52 |
53 | val_transform = transforms.Compose([
54 | transforms.Resize((opt.size, opt.size)),
55 | transforms.ToTensor(),
56 | normalize,
57 | ])
58 |
59 | train_dataset = CustomDataset(npy_file=opt.train_data_folder, transform=train_transform)
60 | val_dataset = ValDataset(hdf5_file=opt.val_data_folder, transform=val_transform, save_mode = opt.save_mode)
61 | return train_dataset, val_dataset
62 |
63 | def set_model(opt):
64 | model = SupConResNet(name=opt.model)
65 | ckpt = torch.load(opt.ckpt, map_location='cpu')
66 | state_dict = {k.replace("module.", ""): v for k, v in ckpt['model'].items()}
67 |
68 | if torch.cuda.is_available():
69 | model = model.cuda()
70 | cudnn.benchmark = True
71 | model.load_state_dict(state_dict, strict=False)
72 | else:
73 | raise NotImplementedError('This code requires GPU')
74 | return model
75 |
76 | def get_embeddings(train_dataset, model):
77 | model.eval()
78 | embeddings, labels = [], []
79 | for idx in range(len(train_dataset)):
80 | image, label = train_dataset[idx]
81 | image = image.unsqueeze(0).cuda(non_blocking=True)
82 | with torch.no_grad():
83 | features = model.encoder(image).flatten(start_dim=1)
84 | embeddings.append(features.cpu().numpy())
85 | labels.append(label)
86 | return np.concatenate(embeddings), np.array(labels)
87 |
88 | def classifier(val_dataset, train_dataset, train_labels, model, neighbors_num):
89 | device = next(model.parameters()).device
90 |
91 | for idx in tqdm(range(len(val_dataset))):
92 | image_data, demo_idx, small_demo_idx = val_dataset[idx]
93 | image_data = image_data.unsqueeze(0).to(device)
94 |
95 | with torch.no_grad():
96 | val_embedding = model.encoder(image_data).cpu().numpy()
97 |
98 | label = calculate_nearest_neighbors(val_embedding, train_dataset, train_labels, neighbors_num)
99 | val_dataset.perform_optimization(idx, label=label)
100 |
101 | def main():
102 | opt = parse_option()
103 | train_dataset, val_dataset = set_loader(opt)
104 | model = set_model(opt)
105 | train_embeddings, train_labels = get_embeddings(train_dataset, model)
106 | classifier(val_dataset, train_embeddings, train_labels, model, neighbors_num=64)
107 |
108 | if __name__ == '__main__':
109 | main()
110 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: mujoco
2 | channels:
3 | - menpo
4 | - conda-forge
5 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch3d/
6 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
7 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
8 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
9 | - defaults
10 | dependencies:
11 | - _libgcc_mutex=0.1=conda_forge
12 | - _openmp_mutex=4.5=2_kmp_llvm
13 | - bzip2=1.0.8=h4bc722e_7
14 | - c-ares=1.34.1=heb4867d_0
15 | - ca-certificates=2024.8.30=hbcca054_0
16 | - elfutils=0.191=h924a536_0
17 | - gettext=0.22.5=he02047a_3
18 | - gettext-tools=0.22.5=he02047a_3
19 | - glew=2.1.0=h9c3ff4c_2
20 | - glfw3=3.2.1=0
21 | - gnutls=3.8.7=h32866dd_0
22 | - icu=75.1=he02047a_0
23 | - keyutils=1.6.1=h166bdaf_0
24 | - krb5=1.21.3=h659f571_0
25 | - ld_impl_linux-64=2.40=h12ee557_0
26 | - libarchive=3.7.4=hfca40fe_0
27 | - libasprintf=0.22.5=he8f35ee_3
28 | - libasprintf-devel=0.22.5=he8f35ee_3
29 | - libcurl=8.10.1=hbbe4b11_0
30 | - libdrm=2.4.123=hb9d3cd8_0
31 | - libedit=3.1.20191231=he28a2e2_2
32 | - libev=4.33=hd590300_2
33 | - libexpat=2.6.3=h5888daf_0
34 | - libffi=3.4.4=h6a678d5_1
35 | - libgcc=14.1.0=h77fa898_1
36 | - libgcc-ng=14.1.0=h69a702a_1
37 | - libgettextpo=0.22.5=he02047a_3
38 | - libgettextpo-devel=0.22.5=he02047a_3
39 | - libglu=9.0.0=ha6d2627_1004
40 | - libiconv=1.17=hd590300_2
41 | - libidn2=2.3.7=hd590300_0
42 | - libllvm19=19.1.1=ha7bfdaf_0
43 | - libmicrohttpd=1.0.1=hbc5bc17_1
44 | - libnghttp2=1.58.0=h47da74e_1
45 | - libnsl=2.0.1=hd590300_0
46 | - libpciaccess=0.18=hd590300_0
47 | - libsqlite=3.46.1=hadc24fc_0
48 | - libssh2=1.11.0=h0841786_0
49 | - libstdcxx=14.1.0=hc0a3c3a_1
50 | - libstdcxx-ng=14.1.0=h4852527_1
51 | - libtasn1=4.19.0=h166bdaf_0
52 | - libunistring=0.9.10=h7f98852_0
53 | - libuuid=2.38.1=h0b41bf4_0
54 | - libxcb=1.17.0=h8a09558_0
55 | - libxcrypt=4.4.36=hd590300_1
56 | - libxml2=2.12.7=he7c6b58_4
57 | - libzlib=1.3.1=hb9d3cd8_2
58 | - llvm-openmp=19.1.1=h024ca30_0
59 | - lz4-c=1.9.4=hcb278e6_0
60 | - lzo=2.10=hd590300_1001
61 | - mesalib=24.2.4=h039c18d_0
62 | - ncurses=6.5=he02047a_1
63 | - nettle=3.9.1=h7ab15ed_0
64 | - openssl=3.3.2=hb9d3cd8_0
65 | - p11-kit=0.24.1=hc5aa10d_0
66 | - pip=24.2=py39h06a4308_0
67 | - pthread-stubs=0.4=hb9d3cd8_1002
68 | - python=3.9.18=h0755675_1_cpython
69 | - readline=8.2=h5eee18b_0
70 | - setuptools=75.1.0=py39h06a4308_0
71 | - tk=8.6.13=noxft_h4845f30_101
72 | - tzdata=2024b=h04d1e81_0
73 | - wheel=0.44.0=py39h06a4308_0
74 | - xorg-glproto=1.4.17=hb9d3cd8_1003
75 | - xorg-libx11=1.8.10=h4f16b4b_0
76 | - xorg-libxau=1.0.11=hb9d3cd8_1
77 | - xorg-libxdamage=1.1.6=hb9d3cd8_0
78 | - xorg-libxdmcp=1.1.5=hb9d3cd8_0
79 | - xorg-libxext=1.3.6=hb9d3cd8_0
80 | - xorg-libxfixes=6.0.1=hb9d3cd8_0
81 | - xorg-libxrandr=1.5.4=hb9d3cd8_0
82 | - xorg-libxrender=0.9.11=hb9d3cd8_1
83 | - xorg-libxxf86vm=1.1.5=hb9d3cd8_3
84 | - xorg-xextproto=7.3.0=hb9d3cd8_1004
85 | - xorg-xf86vidmodeproto=2.3.1=hb9d3cd8_1003
86 | - xorg-xorgproto=2024.1=hb9d3cd8_1
87 | - xz=5.4.6=h5eee18b_1
88 | - zstd=1.5.6=ha6fb4c9_0
89 | - pip:
90 | - absl-py==2.1.0
91 | - cachetools==5.5.0
92 | - certifi==2024.8.30
93 | - cffi==1.17.1
94 | - charset-normalizer==3.4.0
95 | - contourpy==1.3.0
96 | - cycler==0.12.1
97 | - cython==0.29.37
98 | - egl-probe==1.0.2
99 | - etils==1.5.2
100 | - fasteners==0.15
101 | - fonttools==4.54.1
102 | - free-mujoco-py==2.1.6
103 | - fsspec==2024.9.0
104 | - glfw==1.12.0
105 | - google-auth==2.35.0
106 | - google-auth-oauthlib==0.4.6
107 | - grpcio==1.66.2
108 | - h5py==3.12.1
109 | - idna==3.10
110 | - imageio==2.35.1
111 | - importlib-metadata==8.5.0
112 | - importlib-resources==6.4.5
113 | - kiwisolver==1.4.7
114 | - llvmlite==0.43.0
115 | - markdown==3.7
116 | - markupsafe==3.0.1
117 | - matplotlib==3.9.2
118 | - monotonic==1.6
119 | - mujoco==3.0.0
120 | - mujoco-py==2.1.2.14
121 | - numba==0.60.0
122 | - numpy==1.23.5
123 | - oauthlib==3.2.2
124 | - opencv-python==4.10.0.84
125 | - packaging==24.1
126 | - patchelf==0.17.2.1
127 | - pillow==10.4.0
128 | - protobuf==3.19.6
129 | - pyasn1==0.6.1
130 | - pyasn1-modules==0.4.1
131 | - pycparser==2.22
132 | - pyopengl==3.1.7
133 | - pyparsing==3.1.4
134 | - python-dateutil==2.9.0.post0
135 | - requests==2.32.3
136 | - requests-oauthlib==2.0.0
137 | - robomimic==0.2.0
138 | - robosuite==1.2.0
139 | - rsa==4.9
140 | - scipy==1.13.1
141 | - six==1.16.0
142 | - tensorboard==2.10.1
143 | - tensorboard-data-server==0.6.1
144 | - tensorboard-logger==0.1.0
145 | - tensorboard-plugin-wit==1.8.1
146 | - termcolor==2.5.0
147 | - torch==1.10.0+cu113
148 | - torchaudio==0.10.0+cu113
149 | - torchvision==0.11.1+cu113
150 | - tqdm==4.66.5
151 | - typing-extensions==4.12.2
152 | - urllib3==2.2.3
153 | - werkzeug==3.0.4
154 | - zipp==3.20.2
155 |
--------------------------------------------------------------------------------
/policy/diffusion_policy/robomimic_replay_lowdim_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Dict, List
3 | import torch
4 | import numpy as np
5 | import h5py
6 | from tqdm import tqdm
7 | import copy
8 | from diffusion_policy.common.pytorch_util import dict_apply
9 | from diffusion_policy.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer
10 | from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
11 | from diffusion_policy.model.common.rotation_transformer import RotationTransformer
12 | from diffusion_policy.common.replay_buffer import ReplayBuffer
13 | from diffusion_policy.common.sampler import (
14 | SequenceSampler, get_val_mask, downsample_mask)
15 | from diffusion_policy.common.normalize_util import (
16 | robomimic_abs_action_only_normalizer_from_stat,
17 | robomimic_abs_action_only_dual_arm_normalizer_from_stat,
18 | get_identity_normalizer_from_stat,
19 | array_to_stats
20 | )
21 |
22 | class RobomimicReplayLowdimDataset(BaseLowdimDataset):
23 | def __init__(self,
24 | dataset_path: str,
25 | horizon=1,
26 | pad_before=0,
27 | pad_after=0,
28 | obs_keys: List[str]=[
29 | 'object',
30 | 'robot0_eef_pos',
31 | 'robot0_eef_quat',
32 | 'robot0_gripper_qpos'],
33 | abs_action=False,
34 | rotation_rep='rotation_6d',
35 | use_legacy_normalizer=False,
36 | seed=42,
37 | val_ratio=0.0,
38 | max_train_episodes=None
39 | ):
40 | obs_keys = list(obs_keys)
41 | rotation_transformer = RotationTransformer(
42 | from_rep='axis_angle', to_rep=rotation_rep)
43 |
44 | replay_buffer = ReplayBuffer.create_empty_numpy()
45 | with h5py.File(dataset_path) as file:
46 | demos = file['data']
47 | for i in tqdm(range(len(demos)), desc="Loading hdf5 to ReplayBuffer"):
48 | demo = demos[f'demo_{i}']
49 |
50 | episode = _data_to_obs(
51 | marks=demo['marks'],
52 | raw_obs=demo['obs'],
53 | raw_actions=demo['actions'][:].astype(np.float32),
54 | obs_keys=obs_keys,
55 | abs_action=abs_action,
56 | rotation_transformer=rotation_transformer)
57 | replay_buffer.add_episode(episode)
58 |
59 | val_mask = get_val_mask(
60 | n_episodes=replay_buffer.n_episodes,
61 | val_ratio=val_ratio,
62 | seed=seed)
63 | train_mask = ~val_mask
64 | train_mask = downsample_mask(
65 | mask=train_mask,
66 | max_n=max_train_episodes,
67 | seed=seed)
68 |
69 | sampler = SequenceSampler(
70 | replay_buffer=replay_buffer,
71 | sequence_length=horizon,
72 | pad_before=pad_before,
73 | pad_after=pad_after,
74 | episode_mask=train_mask)
75 |
76 | self.replay_buffer = replay_buffer
77 | self.sampler = sampler
78 | self.abs_action = abs_action
79 | self.train_mask = train_mask
80 | self.horizon = horizon
81 | self.pad_before = pad_before
82 | self.pad_after = pad_after
83 | self.use_legacy_normalizer = use_legacy_normalizer
84 |
85 | def get_validation_dataset(self):
86 | val_set = copy.copy(self)
87 | val_set.sampler = SequenceSampler(
88 | replay_buffer=self.replay_buffer,
89 | sequence_length=self.horizon,
90 | pad_before=self.pad_before,
91 | pad_after=self.pad_after,
92 | episode_mask=~self.train_mask
93 | )
94 | val_set.train_mask = ~self.train_mask
95 | return val_set
96 |
97 | def get_normalizer(self, **kwargs) -> LinearNormalizer:
98 | normalizer = LinearNormalizer()
99 |
100 | # action
101 | stat = array_to_stats(self.replay_buffer['action'])
102 | if self.abs_action:
103 | if stat['mean'].shape[-1] > 10:
104 | # dual arm
105 | this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
106 | else:
107 | this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)
108 |
109 | if self.use_legacy_normalizer:
110 | this_normalizer = normalizer_from_stat(stat)
111 | else:
112 | # already normalized
113 | this_normalizer = get_identity_normalizer_from_stat(stat)
114 | normalizer['action'] = this_normalizer
115 |
116 | # aggregate obs stats
117 | obs_stat = array_to_stats(self.replay_buffer['obs'])
118 |
119 |
120 | normalizer['obs'] = normalizer_from_stat(obs_stat)
121 | return normalizer
122 |
123 | def get_all_actions(self) -> torch.Tensor:
124 | return torch.from_numpy(self.replay_buffer['action'])
125 |
126 | def __len__(self):
127 | return len(self.sampler)
128 |
129 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
130 | data = self.sampler.sample_sequence(idx)
131 | torch_data = dict_apply(data, torch.from_numpy)
132 | return torch_data
133 |
134 | def normalizer_from_stat(stat):
135 | max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
136 | scale = np.full_like(stat['max'], fill_value=1/max_abs)
137 | offset = np.zeros_like(stat['max'])
138 | return SingleFieldLinearNormalizer.create_manual(
139 | scale=scale,
140 | offset=offset,
141 | input_stats_dict=stat
142 | )
143 |
144 | def _data_to_obs(marks, raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer):
145 | obs = np.concatenate([
146 | raw_obs[key] for key in obs_keys
147 | ], axis=-1).astype(np.float32)
148 |
149 | if abs_action:
150 | is_dual_arm = False
151 | if raw_actions.shape[-1] == 14:
152 | # dual arm
153 | raw_actions = raw_actions.reshape(-1,2,7)
154 | is_dual_arm = True
155 |
156 | pos = raw_actions[...,:3]
157 | rot = raw_actions[...,3:6]
158 | gripper = raw_actions[...,6:]
159 | rot = rotation_transformer.forward(rot)
160 | raw_actions = np.concatenate([
161 | pos, rot, gripper
162 | ], axis=-1).astype(np.float32)
163 |
164 | if is_dual_arm:
165 | raw_actions = raw_actions.reshape(-1,20)
166 |
167 | marks = np.array(marks)
168 | data_length = obs.shape[0]
169 | new_marks = np.zeros(data_length, dtype=int)
170 | valid_indices = marks[marks < data_length]
171 | new_marks[valid_indices] = 1
172 |
173 | data = {
174 | 'obs': obs,
175 | 'action': raw_actions,
176 | 'marks': new_marks
177 | }
178 | return data
179 |
180 |
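To make the `marks` handling in `_data_to_obs` concrete, here is a small worked example of how a list of selected frame indices becomes a per-frame binary mask (the values are illustrative):

```python
import numpy as np

marks = np.array([0, 4, 9, 25])  # selected frame indices for one demo
data_length = 20                 # number of frames actually in the demo

new_marks = np.zeros(data_length, dtype=int)
valid_indices = marks[marks < data_length]  # out-of-range indices (25) are dropped
new_marks[valid_indices] = 1
print(new_marks)  # [1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
```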
--------------------------------------------------------------------------------
/dataset/trajectory_optimization.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import numpy as np
4 | from typing import List, Tuple, Dict, Optional
5 | from dataclasses import dataclass
6 | from utils.constant import DELTA_THETA
7 | @dataclass
8 | class TrajectoryPoint:
9 | """Represents a point in the trajectory with its coordinates."""
10 | coordinates: np.ndarray
11 | index: int
12 |
13 | class GeometryCalculator:
14 |
15 | @staticmethod
16 | def calculate_vector(point1: np.ndarray, point2: np.ndarray) -> np.ndarray:
17 | return point2 - point1
18 |
19 | @staticmethod
20 | def calculate_angle(vector1: np.ndarray, vector2: np.ndarray) -> float:
21 | dot_product = np.dot(vector1, vector2)
22 | norm_vector1 = np.linalg.norm(vector1)
23 | norm_vector2 = np.linalg.norm(vector2)
24 | cos_theta = dot_product / (norm_vector1 * norm_vector2)
25 | return np.arccos(np.clip(cos_theta, -1.0, 1.0))
26 |
27 | class CoordinateTransformer:
28 | """Handles coordinate transformations between different reference frames."""
29 | def __init__(self, env, real_world: bool = False, calib_dir: Optional[str] = None):
30 | self.env = env
31 | self.real_world = real_world
32 | self.calib_dir = calib_dir
33 |
34 | def transform_points(self, trajectory_points: np.ndarray) -> np.ndarray:
35 | if self.real_world:
36 | transformed_points = self._transform_real_world(trajectory_points)
37 | else:
38 | transformed_points = self._transform_simulation(trajectory_points)
39 |
40 | return transformed_points[:, :2] # Return only x,y coordinates
41 |
42 | def _transform_real_world(self, points: np.ndarray) -> np.ndarray:
43 | from utils.realworld_utils import translate_points
44 | return translate_points(self.calib_dir, points)
45 |
46 | def _transform_simulation(self, points: np.ndarray) -> np.ndarray:
47 | extrinsic_matrix = self.env.get_camera_extrinsic_matrix('agentview')
48 | camera_position = extrinsic_matrix[:3, 3]
49 | camera_rotation = extrinsic_matrix[:3, :3]
50 | return np.dot(points - camera_position, camera_rotation)
51 |
52 |
53 | class TrajectoryOptimizer:
54 | """Main class for optimizing robot trajectories."""
55 | def __init__(self, env, real_world: bool = False, calib_dir: Optional[str] = None):
56 | self.geometry = GeometryCalculator()
57 | self.transformer = CoordinateTransformer(env, real_world, calib_dir)
58 |
59 | def optimize_trajectory(self, demo: Dict, demo_idx: int, small_demo_idx: int,
60 | three_dimension: bool = False) -> List[int]:
61 | frame_start = demo['frame_start']
62 | waypoints = self._calculate_waypoints_dp(demo, three_dimension)
63 | return [mark + frame_start for mark in waypoints]
64 |
65 | def _calculate_waypoints_dp(self, demo: Dict, three_dimension: bool) -> List[int]:
66 | err_threshold = 0.005
67 | actions = demo['actions']
68 | gt_states = demo['gt_states']
69 | num_frames = len(actions)
70 |
71 | dp_table = self._initialize_dp_table(num_frames)
72 |
73 | min_error = self._compute_trajectory_errors(actions, gt_states, list(range(1, num_frames)))
74 | if err_threshold < min_error:
75 | return list(range(1, num_frames))
76 |
77 | return self._fill_dp_table(dp_table, actions, gt_states, num_frames, err_threshold)
78 |
79 | def _initialize_dp_table(self, size: int) -> Dict:
80 | dp_table = {i: (0, []) for i in range(size)}
81 | dp_table[1] = (1, [1])
82 | return dp_table
83 |
84 | def _compute_trajectory_errors(self, actions: np.ndarray, gt_states: np.ndarray,
85 | waypoints: List[int]) -> float:
86 | from utils.trajectory_utils import compute_errors
87 | return compute_errors(actions=actions, gt_states=gt_states, waypoints=waypoints)
88 |
89 | def _fill_dp_table(self, dp_table: Dict, actions: np.ndarray, gt_states: np.ndarray,
90 | num_frames: int, err_threshold: float) -> List[int]:
91 | initial_waypoints = [0, num_frames - 1]
92 |
93 | for i in range(1, num_frames):
94 | min_waypoints_required = float("inf")
95 | best_waypoints = []
96 |
97 | for k in range(1, i):
98 | waypoints = [j - k for j in initial_waypoints if k <= j < i] + [i - k]
99 | total_err = self._compute_trajectory_errors(
100 | actions[k:i + 1], gt_states[k:i + 1], waypoints
101 | )
102 |
103 | if total_err < err_threshold:
104 | prev_count, prev_waypoints = dp_table[k - 1]
105 | total_count = 1 + prev_count
106 |
107 | if total_count < min_waypoints_required:
108 | min_waypoints_required = total_count
109 | best_waypoints = prev_waypoints + [i]
110 |
111 | dp_table[i] = (min_waypoints_required, best_waypoints)
112 |
113 | _, waypoints = dp_table[num_frames - 1]
114 | waypoints.extend(initial_waypoints)
115 | return sorted(list(set(waypoints)))
116 |
117 | def _calculate_geometric_waypoints(self, demo: Dict, three_dimension: bool) -> List[int]:
118 | points = np.array(demo['trajectory_points'])
119 | trajectory_points = points if three_dimension else self.transformer.transform_points(points)
120 |
121 | tolerance = self._calculate_tolerance(trajectory_points)
122 |
123 | return self._select_waypoints(trajectory_points, tolerance)
124 |
125 | def _calculate_tolerance(self, points: np.ndarray) -> float:
126 | return np.max(np.linalg.norm(points[1:] - points[:-1], axis=1)) * 2
127 |
128 | def _select_waypoints(self, points: np.ndarray, tolerance: float) -> List[int]:
129 | selected_indices = [0]
130 | current_idx = 0
131 |
132 | while current_idx < len(points) - 1:
133 | if np.array_equal(points[current_idx], points[-1]):
134 | break
135 |
136 | next_idx = self._find_next_waypoint(
137 | points, current_idx, selected_indices, tolerance
138 | )
139 |
140 | if next_idx is None:
141 | break
142 |
143 | selected_indices.append(next_idx)
144 | current_idx = next_idx
145 |
146 | return selected_indices
147 |
148 | def _find_next_waypoint(self, points: np.ndarray, current_idx: int,
149 | selected_indices: List[int], tolerance: float) -> Optional[int]:
150 | candidates = []
151 |
152 | for i in range(len(points)):
153 | if i in selected_indices:
154 | continue
155 |
156 | distance = np.linalg.norm(points[current_idx] - points[i])
157 | if distance > tolerance:
158 | continue
159 |
160 | vector_to_candidate = self.geometry.calculate_vector(points[current_idx], points[i])
161 | vector_to_goal = self.geometry.calculate_vector(points[current_idx], points[-1])
162 | angle = self.geometry.calculate_angle(vector_to_candidate, vector_to_goal)
163 |
164 | if angle < np.radians(DELTA_THETA):
165 | candidates.append((angle, distance, i))
166 |
167 | if not candidates:
168 | return None
169 |
170 | candidates.sort(key=lambda x: (x[0], x[1]))
171 | return candidates[0][2]
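A small sanity check of `GeometryCalculator.calculate_angle`, which `_find_next_waypoint` uses to keep candidate waypoints within `DELTA_THETA` degrees of the direction to the goal (the points below are illustrative):

```python
import numpy as np
from dataset.trajectory_optimization import GeometryCalculator

geom = GeometryCalculator()
current, candidate, goal = np.array([0.0, 0.0]), np.array([1.0, 1.0]), np.array([2.0, 0.0])

v_candidate = geom.calculate_vector(current, candidate)
v_goal = geom.calculate_vector(current, goal)
angle = geom.calculate_angle(v_candidate, v_goal)
print(np.degrees(angle))       # 45.0
print(np.degrees(angle) < 75)  # True: within the DELTA_THETA threshold
```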
--------------------------------------------------------------------------------
/utils/realworld_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import numpy as np
4 | from PIL import Image
5 | import matplotlib.pyplot as plt
6 | from .constant import *
7 |
8 |
9 | def load_image_files(sub_path):
10 | file_paths = []
11 | if os.path.exists(sub_path):
12 | npy_files = sorted([f for f in os.listdir(sub_path) if f.endswith('.png')])
13 | for file_name in npy_files:
14 | file_paths.append(file_name)
15 | else:
16 | print(f"The directory {sub_path} does not exist.")
17 | return file_paths
18 |
19 | def load_trajectory_points(sub_path, file_paths):
20 | trajectory_points = []
21 | if os.path.exists(sub_path):
22 | for file_name in file_paths:
23 | base_name = os.path.splitext(file_name)[0]
24 | file_path = os.path.join(sub_path, base_name + '.npy')
25 | data = np.load(file_path)
26 | first_three_numbers = data[:3]
27 | trajectory_points.append(first_three_numbers)
28 | else:
29 | print(f"The directory {sub_path} does not exist.")
30 | return np.array(trajectory_points)
31 |
32 | def load_gripper_command(sub_path, file_paths):
33 | gripper_command = []
34 | if os.path.exists(sub_path):
35 | for file_name in file_paths:
36 | base_name = os.path.splitext(file_name)[0]
37 | file_path = os.path.join(sub_path, base_name + '.npy')
38 | data = np.load(file_path)[0]
39 | gripper_command.append(data)
40 | else:
41 | print(f"The directory {sub_path} does not exist.")
42 | return gripper_command
43 |
44 | def realworld_change_indices(gripper_command):
45 | differences = np.diff(gripper_command)
46 | change_indices = [0]
47 | for i, diff in enumerate(differences):
48 | if diff != 0:
49 | if not change_indices:
50 | change_indices.append(i + 1)
51 | else:
52 | current_diff = i + 1 - change_indices[-1]
53 | if current_diff > 5:
54 | change_indices.append(i + 1)
55 | return change_indices
56 |
57 | def realworld_slice(trajectory_points, file_paths, change_indices):
58 | trajectory_slices, state_slices = [], []
59 | start_idx = 0
60 | for idx in change_indices:
61 | if idx - start_idx >= 10:
62 | trajectory_slices.append(trajectory_points[start_idx:idx])
63 | state_slices.append(file_paths[start_idx:idx])
64 | start_idx = idx
65 | if len(trajectory_points) - start_idx > 15:
66 | trajectory_slices.append(trajectory_points[start_idx:])
67 | state_slices.append(file_paths[start_idx:])
68 | return trajectory_slices, state_slices
69 |
70 | def get_save_mode_factor(save_mode):
71 | if save_mode == 'lowdim':
72 | return 0
73 | elif save_mode == 'image':
74 | return 0.1
75 | elif save_mode == 'realworld':
76 | return 0.1
77 | else:
78 | raise ValueError(f"Unknown save_mode: {save_mode}")
79 |
80 | def apply_image_filter(image, factor):
81 | image_array = np.array(image)
82 | white_image = np.ones_like(image_array) * 255
83 | new_image_array = (image_array * factor + white_image * (1 - factor)).astype(np.uint8)
84 | return Image.fromarray(new_image_array)
85 |
86 |
87 | def load_calibration_data(calib_root_dir):
88 | tcp_file = os.path.join(calib_root_dir, 'tcp.npy')
89 | extrinsics_file = os.path.join(calib_root_dir, 'extrinsics.npy')
90 | intrinsics_file = os.path.join(calib_root_dir, 'intrinsics.npy')
91 |
92 | tcp = np.load(tcp_file)
93 | extrinsics = np.load(extrinsics_file, allow_pickle=True).item()
94 | intrinsics = np.load(intrinsics_file, allow_pickle=True).item()
95 |
96 | return tcp, extrinsics, intrinsics
97 |
98 | def quaternion_to_rotation_matrix(quaternion):
99 | qw, qx, qy, qz = quaternion
100 | R = np.array([
101 | [1 - 2*qy**2 - 2*qz**2, 2*qx*qy - 2*qz*qw, 2*qx*qz + 2*qy*qw],
102 | [2*qx*qy + 2*qz*qw, 1 - 2*qx**2 - 2*qz**2, 2*qy*qz - 2*qx*qw],
103 | [2*qx*qz - 2*qy*qw, 2*qy*qz + 2*qx*qw, 1 - 2*qx**2 - 2*qy**2]
104 | ])
105 | return R
106 |
107 | def create_transformation_matrix(position, quaternion):
108 | R = quaternion_to_rotation_matrix(quaternion)
109 | T = np.eye(4)
110 | T[:3, :3] = R
111 | T[:3, 3] = position
112 | return T
113 |
114 | def compute_extrinsic_matrix(extrinsics, M_cam0433_to_end, M_end_to_base):
115 | M_cam0433_to_A = extrinsics['043322070878'][0]
116 | M_cam7506_to_A = extrinsics['750612070851'][0]
117 |
118 | M_cam7506_to_base = M_cam7506_to_A @ np.linalg.inv(M_cam0433_to_A) @ M_cam0433_to_end @ M_end_to_base
119 | return M_cam7506_to_base
120 |
121 | def convert_to_pixel_coordinates(trajectory_points, extrinsic_matrix, camera_matrix):
122 | translated_points = []
123 | for point in trajectory_points:
124 | object_point_world = np.append(point, 1).reshape(-1, 1)
125 | object_point_camera = extrinsic_matrix @ object_point_world
126 | object_point_pixel = camera_matrix @ object_point_camera
127 | object_point_pixel /= object_point_pixel[2]
128 | pixel_point = np.array([int(object_point_pixel[0]), int(object_point_pixel[1])])
129 | translated_points.append(pixel_point)
130 | return np.array(translated_points)
131 |
132 | def translate_points(calib_root_dir, trajectory_points):
133 | tcp, extrinsics, intrinsics = load_calibration_data(calib_root_dir)
134 |
135 | position = tcp[:3]
136 | quaternion = tcp[3:]
137 | M_end_to_base = create_transformation_matrix(position, quaternion)
138 |
139 | M_cam0433_to_end = np.array([[0, -1, 0, 0],
140 | [1, 0, 0, 0.077],
141 | [0, 0, 1, 0.2665],
142 | [0, 0, 0, 1]])
143 |
144 | extrinsic_matrix = compute_extrinsic_matrix(extrinsics, M_cam0433_to_end, M_end_to_base)
145 | camera_matrix = intrinsics['750612070851']
146 |
147 | return convert_to_pixel_coordinates(trajectory_points, extrinsic_matrix, camera_matrix)
148 |
149 | def plot(transformed_points, image):
150 | plt.clf()
151 | projected_points = transformed_points[:, :2]
152 | plt.plot(projected_points[:, 0], -projected_points[:, 1], color='red', linewidth=5)
153 | plt.axis('off')
154 | plt.xlim(-0.45, 0.45)
155 | plt.ylim(-0.5, 0.5)
156 | plt.gcf().set_size_inches(480/96, 480/96)
157 | plt.tight_layout()
158 |
159 | buf = io.BytesIO()
160 | plt.savefig(buf, format='png', transparent=True, dpi=96)
161 | buf.seek(0)
162 |
163 | image1 = Image.open(buf)
164 | image.paste(image1, (0, 0), image1)
165 | buf.close()
166 | return image
167 |
168 | def check_directory_exists(directory):
169 | if os.path.exists(directory) and os.path.isdir(directory):
170 | sub_dirs = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
171 | if sub_dirs:
172 | return os.path.join(directory, sub_dirs[0])
173 | else:
174 | print(f"No sub-directories found in {directory}.")
175 | else:
176 | print(f"The directory {directory} does not exist.")
177 | return None
178 |
179 | def load_demo_files(root_dir, subdirs, ind):
180 | color_path = os.path.join(root_dir, subdirs[ind], CAMERA_NAME, 'color')
181 | tcp_path = os.path.join(root_dir, subdirs[ind], CAMERA_NAME, 'tcp')
182 | gripper_command_path = os.path.join(root_dir, subdirs[ind], CAMERA_NAME, 'gripper_command')
183 |
184 | file_paths = load_image_files(color_path)
185 | trajectory_points = load_trajectory_points(tcp_path, file_paths)
186 | gripper_command = load_gripper_command(gripper_command_path, file_paths)
187 |
188 | return file_paths, trajectory_points, gripper_command
189 |
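A quick sanity check for `quaternion_to_rotation_matrix` and `create_transformation_matrix` (identity quaternion and an arbitrary position):

```python
import numpy as np
from utils.realworld_utils import quaternion_to_rotation_matrix, create_transformation_matrix

# The identity quaternion (qw=1, qx=qy=qz=0) should yield the identity rotation.
R = quaternion_to_rotation_matrix(np.array([1.0, 0.0, 0.0, 0.0]))
print(np.allclose(R, np.eye(3)))  # True

# The homogeneous transform stores the rotation block and the translation.
T = create_transformation_matrix(np.array([0.1, 0.2, 0.3]), np.array([1.0, 0.0, 0.0, 0.0]))
print(T[:3, 3])                   # [0.1 0.2 0.3]
```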
--------------------------------------------------------------------------------
/policy/robomimic/rollout.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 |
5 | import os
6 | import json
7 | import torch
8 | import time
9 | import psutil
10 | import sys
11 | import traceback
12 | import argparse
13 | import numpy as np
14 | from collections import OrderedDict
15 | from torch.utils.data import DataLoader
16 | import logging
17 | import random
18 |
19 | import robomimic.utils.train_utils as TrainUtils
20 | import robomimic.utils.torch_utils as TorchUtils
21 | import robomimic.utils.obs_utils as ObsUtils
22 | import robomimic.utils.env_utils as EnvUtils
23 | import robomimic.utils.file_utils as FileUtils
24 | from robomimic.config import config_factory
25 | from robomimic.algo import algo_factory, RolloutPolicy
26 | from robomimic.utils.log_utils import PrintLogger, DataLogger
27 |
28 |
29 | def set_seed(seed):
30 | random.seed(seed)
31 | np.random.seed(seed)
32 | torch.manual_seed(seed)
33 | if torch.cuda.is_available():
34 | torch.cuda.manual_seed_all(seed)
35 |
36 | class LoggerWriter:
37 | def __init__(self, level):
38 | self.level = level
39 |
40 | def write(self, message):
41 | if message != '\n':
42 | self.level(message)
43 |
44 | def flush(self):
45 | pass
46 |
47 | def setup_logging(log_file_path):
48 | with open(log_file_path, 'w'):
49 | pass
50 |
51 | logging.basicConfig(
52 | level=logging.DEBUG,
53 | format='%(asctime)s - %(levelname)s - %(message)s',
54 | handlers=[
55 | logging.FileHandler(log_file_path),
56 | logging.StreamHandler(sys.stdout)
57 | ]
58 | )
59 |
60 | sys.stdout = LoggerWriter(logging.info)
61 |
62 |
63 | def rollout_from_checkpoint(config, checkpoint_path, device, seed, video_dir, epoch):
64 |
65 | print(f"\n============= Performing Rollout for Seed {seed} and Epoch {epoch} =============")
66 |
67 | set_seed(seed)
68 |
69 | ObsUtils.initialize_obs_utils_with_config(config)
70 | checkpoint = torch.load(checkpoint_path, map_location=device)
71 | dataset_path = os.path.expanduser(config.train.data)
72 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=config.train.data)
73 | shape_meta = FileUtils.get_shape_metadata_from_dataset(
74 | dataset_path=config.train.data,
75 | all_obs_keys=config.all_obs_keys,
76 | verbose=True
77 | )
78 |
79 | if config.experiment.env is not None:
80 | env_meta["env_name"] = config.experiment.env
81 | print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30)
82 |
83 | envs = OrderedDict()
84 | env_names = [env_meta["env_name"]]
85 | if config.experiment.additional_envs is not None:
86 | env_names.extend(config.experiment.additional_envs)
87 |
88 | for env_name in env_names:
89 | env = EnvUtils.create_env_from_metadata(
90 | env_meta=env_meta,
91 | env_name=env_name,
92 | render=False,
93 | render_offscreen=config.experiment.render_video,
94 | use_image_obs=shape_meta["use_images"],
95 | )
96 | envs[env.name] = env
97 | print(envs[env.name])
98 |
99 | model = algo_factory(
100 | algo_name=config.algo_name,
101 | config=config,
102 | obs_key_shapes=shape_meta["all_shapes"],
103 | ac_dim=shape_meta["ac_dim"],
104 | device=device,
105 | )
106 |
107 | model.load_state_dict(checkpoint["model"])
108 |
109 | obs_normalization_stats = None
110 | if config.train.hdf5_normalize_obs:
111 | trainset, _ = TrainUtils.load_data_for_training(config, obs_keys=shape_meta["all_obs_keys"])
112 | obs_normalization_stats = trainset.get_obs_normalization_stats()
113 |
114 | rollout_model = RolloutPolicy(model, obs_normalization_stats=obs_normalization_stats)
115 |
116 | num_episodes = config.experiment.rollout.n
117 | all_rollout_logs, video_paths = TrainUtils.rollout_with_stats(
118 | policy=rollout_model,
119 | envs=envs,
120 | horizon=config.experiment.rollout.horizon,
121 | use_goals=config.use_goals,
122 | num_episodes=num_episodes,
123 | render=False,
124 | video_dir=video_dir,
125 | epoch=epoch,
126 | video_skip=config.experiment.get("video_skip", 5),
127 | terminate_on_success=config.experiment.rollout.terminate_on_success,
128 | )
129 |
130 | for env_name in all_rollout_logs:
131 | rollout_logs = all_rollout_logs[env_name]
132 | print("\nRollouts results for env {}:".format(env_name))
133 | print(json.dumps(rollout_logs, sort_keys=True, indent=4))
134 |
135 | process = psutil.Process(os.getpid())
136 | mem_usage = int(process.memory_info().rss / 1000000)
137 | print("\nMemory Usage: {} MB\n".format(mem_usage))
138 |
139 |
140 | def extract_epoch_from_filename(filename):
141 |
142 | try:
143 | base_name = os.path.basename(filename)
144 | epoch_str = base_name.split("model_epoch_")[1].split('_')[0].split(".pth")[0]
145 | return int(epoch_str)
146 | except (IndexError, ValueError):
147 | pass
148 | return None
149 |
150 |
151 |
152 | def main(args):
153 | if args.config is not None:
154 | ext_cfg = json.load(open(args.config, 'r'))
155 | config = config_factory(ext_cfg["algo_name"])
156 | with config.values_unlocked():
157 | config.update(ext_cfg)
158 | else:
159 | config = config_factory(args.algo)
160 |
161 | if args.dataset is not None:
162 | config.train.data = args.dataset
163 |
164 | if args.name is not None:
165 | config.experiment.name = args.name
166 |
167 | device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)
168 |
169 | config.lock()
170 |
171 | # log_file_path = os.path.join(args.checkpoint_dir, "eval-log.txt")
172 | # setup_logging(log_file_path)
173 |
174 | checkpoint_dir = args.checkpoint_dir
175 | checkpoints = []
176 |
177 | for root, _, files in os.walk(args.checkpoint_dir):
178 | for file in files:
179 | if file.endswith(".pth"):
180 | epoch = extract_epoch_from_filename(file)
181 | if epoch is None:
182 | print(f"Failed to extract epoch from filename: {file}")
183 | continue
184 | checkpoint_path = os.path.join(root, file)
185 | checkpoints.append((epoch, checkpoint_path))
186 |
187 | checkpoints.sort(key=lambda x: x[0])
188 |
189 | for epoch, checkpoint_path in checkpoints:
190 | try:
191 | seed = int(checkpoint_path.split('seed')[-1].split('/')[0])
192 | except (ValueError, IndexError):
193 | print(f"Failed to extract seed from path: {checkpoint_path}")
194 | continue
195 |
196 | video_dir = os.path.join(os.path.dirname(checkpoint_path), "videos")
197 | os.makedirs(video_dir, exist_ok=True)
198 |
199 | print(f"Testing checkpoint: {checkpoint_path} with seed: {seed}, saving videos to: {video_dir}")
200 | try:
201 | rollout_from_checkpoint(config, checkpoint_path, device=device, seed=seed, video_dir=video_dir, epoch=epoch)
202 | except Exception as e:
203 | print(f"Rollout failed for {checkpoint_path} with error:\n{e}\n\n{traceback.format_exc()}")
204 |
205 | if __name__ == "__main__":
206 | parser = argparse.ArgumentParser()
207 |
208 | parser.add_argument("--config", type=str, help="Path to the config JSON file")
209 | parser.add_argument("--algo", type=str, help="Algorithm name")
210 | parser.add_argument("--dataset", type=str, help="Dataset path")
211 | parser.add_argument("--name", type=str, help="Experiment name")
212 | parser.add_argument("--checkpoint_dir", type=str, required=True, help="Directory containing checkpoint models")
213 |
214 | args = parser.parse_args()
215 | main(args)
216 |
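The rollout script infers the epoch from the checkpoint filename and the seed from the directory path. A small worked example of that parsing (the file and directory names are hypothetical):

```python
# Hypothetical checkpoint layout: <checkpoint_dir>/seed1/model_epoch_150.pth
filename = "model_epoch_150.pth"
epoch = int(filename.split("model_epoch_")[1].split('_')[0].split(".pth")[0])
print(epoch)  # 150

checkpoint_path = "./ckpts/seed1/model_epoch_150.pth"
seed = int(checkpoint_path.split('seed')[-1].split('/')[0])
print(seed)   # 1
```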
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import argparse
5 | import time
6 | import math
7 |
8 | import tensorboard_logger as tb_logger
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | from torchvision import transforms
12 |
13 | from utils.util import *
14 | from utils.constant import *
15 | from utils.data_augmentation import data_augmentation
16 | from networks.resnet import SupConResNet
17 | from losses import SupConLoss
18 | from dataset.dataset import CustomDataset
19 |
20 | try:
21 | import apex
22 | from apex import amp, optimizers
23 | except ImportError:
24 | pass
25 |
26 | def parse_option():
27 | """ Parse command-line arguments """
28 | parser = argparse.ArgumentParser('argument for training')
29 |
30 | # Training configurations
31 | parser.add_argument('--print_freq', type=int, default=10, help='Print frequency')
32 | parser.add_argument('--save_freq', type=int, default=50, help='Save model frequency')
33 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
34 | parser.add_argument('--num_workers', type=int, default=16, help='Number of data loading workers')
35 | parser.add_argument('--epochs', type=int, default=3000, help='Total training epochs')
36 |
37 | # Optimization configurations
38 | parser.add_argument('--learning_rate', type=float, default=0.05, help='Learning rate')
39 | parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900', help='Epochs where learning rate decays')
40 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='Learning rate decay rate')
41 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay')
42 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum')
43 |
44 | # Model and dataset settings
45 | parser.add_argument('--model', type=str, default='resnet50', help='Model architecture')
46 | parser.add_argument("--dataset", type=str, default='./low_dim.hdf5', help="path to hdf5 dataset")
47 | parser.add_argument('--aug_path', type=str, default=None, help='Path to custom dataset')
48 | parser.add_argument("--save_mode", type=str, default='lowdim', choices=['image', 'lowdim', 'realworld'], help="choose the saving method")
49 | parser.add_argument('--size', type=int, default=128, help='Image size for RandomResizedCrop')
50 |
51 | # Data augmentation
52 | parser.add_argument("--total_images", type=int, default=100, help="total number of images to generate")
53 | parser.add_argument("--numbers", type=int, nargs='+', default=[0, 1, 2], help="list of numbers for processing")
54 |
55 | # Method and loss function configurations
56 | parser.add_argument('--method', type=str, default='SupCon', choices=['SupCon', 'SimCLR'], help='Contrastive learning method')
57 | parser.add_argument('--temp', type=float, default=0.01, help='Temperature for loss function')
58 |
59 | # Paths for saving model and tensorboard logs
60 | parser.add_argument('--model_path', type=str, default='./lowdim/models', help='Path to save model checkpoints')
61 | parser.add_argument('--tb_path', type=str, default='./lowdim/tensorboard', help='Path for tensorboard logs')
62 |
63 | # Other settings
64 | parser.add_argument('--cosine', action='store_true', help='Use cosine annealing learning rate schedule')
65 | parser.add_argument('--syncBN', action='store_true', help='Use synchronized Batch Normalization')
66 | parser.add_argument('--warm', action='store_true', help='Use warm-up for large batch training')
67 |
68 | opt = parser.parse_args()
69 | opt.lr_decay_epochs = list(map(int, opt.lr_decay_epochs.split(',')))
70 |
71 | opt.model_name = f'{opt.method}_{opt.model}_lr_{opt.learning_rate}_decay_{opt.weight_decay}_bsz_{opt.batch_size}_temp_{opt.temp}_imgsize_{opt.size}'
72 | if opt.cosine:
73 | opt.model_name += '_cosine'
74 | if opt.batch_size > 256:
75 | opt.warm = True
76 | if opt.warm:
77 | opt.model_name += '_warm'
78 | opt.warmup_from = 0.01
79 | opt.warm_epochs = 10
80 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) if opt.cosine else opt.learning_rate
81 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
82 |
83 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
84 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
85 | os.makedirs(opt.tb_folder, exist_ok=True)
86 | os.makedirs(opt.save_folder, exist_ok=True)
87 |
88 | return opt
89 |
90 |
91 | def set_loader(opt):
92 | """ Data loader for the training dataset """
93 | normalize = transforms.Normalize(mean=IMG_MEAN, std=IMG_STD)
94 |
95 | train_transform = transforms.Compose([
96 | transforms.RandomResizedCrop(size=opt.size, scale=(0.8, 1.)),
97 | transforms.ToTensor(),
98 | normalize,
99 | ])
100 |
101 |
102 | train_dataset = CustomDataset(npy_file=opt.aug_path, transform=TwoCropTransform(train_transform))
103 |
104 | train_loader = torch.utils.data.DataLoader(
105 | train_dataset, batch_size=opt.batch_size, shuffle=True,
106 | num_workers=opt.num_workers, pin_memory=True)
107 |
108 | return train_loader
109 |
110 | def set_model(opt):
111 | """ Initialize model and loss function """
112 | model = SupConResNet(name=opt.model)
113 | criterion = SupConLoss(temperature=opt.temp)
114 |
115 | # Optional synchronized Batch Normalization
116 | if opt.syncBN:
117 | model = apex.parallel.convert_syncbn_model(model)
118 |
119 | if torch.cuda.is_available():
120 | model = torch.nn.DataParallel(model).cuda() if torch.cuda.device_count() > 1 else model.cuda()
121 | criterion = criterion.cuda()
122 | cudnn.benchmark = True
123 |
124 | return model, criterion
125 |
126 | def train(train_loader, model, criterion, optimizer, epoch, opt):
127 | """ One epoch training loop """
128 | model.train()
129 |
130 | batch_time = AverageMeter()
131 | data_time = AverageMeter()
132 | losses = AverageMeter()
133 |
134 | end = time.time()
135 | for idx, (images, labels) in enumerate(train_loader):
136 | data_time.update(time.time() - end)
137 |
138 | images = torch.cat([images[0], images[1]], dim=0)
139 | images, labels = images.cuda(), labels.cuda(non_blocking=True)
140 |
141 | bsz = labels.size(0)
142 |
143 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
144 |
145 | features = model(images)
146 | f1, f2 = torch.split(features, [bsz, bsz], dim=0)
147 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
148 |
149 | loss = criterion(features, labels) if opt.method == 'SupCon' else criterion(features)
150 | losses.update(loss.item(), bsz)
151 |
152 | optimizer.zero_grad()
153 | loss.backward()
154 | optimizer.step()
155 |
156 | batch_time.update(time.time() - end)
157 | end = time.time()
158 |
159 | if (idx + 1) % opt.print_freq == 0:
160 | print(f'Epoch: [{epoch}][{idx + 1}/{len(train_loader)}]\t'
161 | f'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
162 | f'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
163 | f'Loss {losses.val:.3f} ({losses.avg:.3f})')
164 |
165 | return losses.avg
166 |
167 | def main():
168 | """ Main function to train the model """
169 | opt = parse_option()
170 | data_augmentation(opt)
171 |
172 | train_loader = set_loader(opt)
173 | model, criterion = set_model(opt)
174 | optimizer = set_optimizer(opt, model)
175 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
176 |
177 | for epoch in range(1, opt.epochs + 1):
178 | adjust_learning_rate(opt, optimizer, epoch)
179 |
180 | start_time = time.time()
181 | loss = train(train_loader, model, criterion, optimizer, epoch, opt)
182 | print(f'Epoch {epoch}, Total Time {time.time() - start_time:.2f}, Loss {loss:.4f}, Learning Rate {optimizer.param_groups[0]["lr"]}')
183 |
184 | logger.log_value('loss', loss, epoch)
185 | logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
186 |
187 | if epoch % opt.save_freq == 0:
188 | save_file = os.path.join(opt.save_folder, f'ckpt_epoch_{epoch}.pth')
189 | save_model(model, optimizer, opt, epoch, save_file)
190 |
191 | save_file = os.path.join(opt.save_folder, 'last.pth')
192 | save_model(model, optimizer, opt, opt.epochs, save_file)
193 |
194 | if __name__ == '__main__':
195 | main()
196 |
--------------------------------------------------------------------------------
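Note on the two-view batch handling in train() above: every sample yields two augmented crops (via TwoCropTransform), both views go through the encoder in one forward pass, and the features are regrouped into the [batch, n_views, feat_dim] layout that SupConLoss consumes. Below is a minimal, self-contained sketch of that tensor bookkeeping; the linear encoder is a stand-in for SupConResNet and all shapes are illustrative.

```python
import torch
import torch.nn as nn

bsz, size = 4, 128
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * size * size, 64))  # stand-in for SupConResNet

view1 = torch.randn(bsz, 3, size, size)   # first augmented crop of each sample
view2 = torch.randn(bsz, 3, size, size)   # second crop produced by TwoCropTransform

images = torch.cat([view1, view2], dim=0)              # (2 * bsz, C, H, W), as in train()
features = encoder(images)                             # (2 * bsz, feat_dim)
f1, f2 = torch.split(features, [bsz, bsz], dim=0)
features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
print(features.shape)                                   # torch.Size([4, 2, 64]) -> layout expected by SupConLoss
```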
/policy/diffusion_policy/sampler.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import numpy as np
3 | import numba
4 | from diffusion_policy.common.replay_buffer import ReplayBuffer
5 |
6 |
7 | @numba.jit(nopython=True)
8 | def create_indices(
9 | episode_ends: np.ndarray,
10 | sequence_length: int,
11 | episode_mask: np.ndarray,
12 | marks: np.ndarray,
13 | pad_before: int = 0,
14 | pad_after: int = 0,
15 | debug: bool = True,
16 | ) -> np.ndarray:
17 |
18 |     # episode_mask must provide exactly one flag per episode
19 |     assert episode_mask.shape == episode_ends.shape
20 | pad_before = min(max(pad_before, 0), sequence_length - 1)
21 | pad_after = min(max(pad_after, 0), sequence_length - 1)
22 |
23 | indices = list()
24 | for i in range(len(episode_ends)):
25 | if not episode_mask[i]:
26 | # skip episode
27 | continue
28 | start_idx = 0
29 | if i > 0:
30 | start_idx = episode_ends[i - 1]
31 | end_idx = episode_ends[i]
32 | episode_length = end_idx - start_idx
33 |
34 | min_start = -pad_before
35 | max_start = episode_length - sequence_length + pad_after
36 |
37 | # range stops one idx before end
38 | for idx in range(min_start, max_start + 1):
39 | buffer_start_idx = max(idx, 0) + start_idx
40 | buffer_end_idx = min(idx + sequence_length, episode_length) + start_idx
41 | start_offset = buffer_start_idx - (idx + start_idx)
42 | end_offset = (idx + sequence_length + start_idx) - buffer_end_idx
43 | sample_start_idx = 0 + start_offset
44 | sample_end_idx = sequence_length - end_offset
45 | if debug:
46 | assert start_offset >= 0
47 | assert end_offset >= 0
48 | assert (sample_end_idx - sample_start_idx) == (
49 | buffer_end_idx - buffer_start_idx
50 | )
51 |
52 |                 # build the window: keep its first few steps, then fill the remainder with marked steps only
53 | sample = []
54 | current_idx = buffer_start_idx
55 | for _ in range(3):
56 | if current_idx < buffer_end_idx:
57 | sample.append(current_idx)
58 | current_idx += 1
59 | else:
60 | break
61 |
62 | while len(sample) < (buffer_end_idx - buffer_start_idx) and current_idx < end_idx:
63 | if marks[current_idx] != 0:
64 | sample.append(current_idx)
65 | current_idx += 1
66 |
67 | sample = np.array(sample)
68 |             # keep the window only when enough marked steps exist to fill it completely
69 |             if len(sample) == buffer_end_idx - buffer_start_idx:
70 |                 # columns: buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx
71 | indices.append(
72 | [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx]
73 | )
74 |
75 | indices = np.array(indices)
76 | return indices
77 |
78 |
79 | def get_val_mask(n_episodes, val_ratio, seed=0):
80 | val_mask = np.zeros(n_episodes, dtype=bool)
81 | if val_ratio <= 0:
82 | return val_mask
83 |
84 | # have at least 1 episode for validation, and at least 1 episode for train
85 | n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes - 1)
86 | rng = np.random.default_rng(seed=seed)
87 | val_idxs = rng.choice(n_episodes, size=n_val, replace=False)
88 | val_mask[val_idxs] = True
89 | return val_mask
90 |
91 |
92 | def downsample_mask(mask, max_n, seed=0):
93 | # subsample training data
94 | train_mask = mask
95 | if (max_n is not None) and (np.sum(train_mask) > max_n):
96 | n_train = int(max_n)
97 | curr_train_idxs = np.nonzero(train_mask)[0]
98 | rng = np.random.default_rng(seed=seed)
99 | train_idxs_idx = rng.choice(len(curr_train_idxs), size=n_train, replace=False)
100 | train_idxs = curr_train_idxs[train_idxs_idx]
101 | train_mask = np.zeros_like(train_mask)
102 | train_mask[train_idxs] = True
103 | assert np.sum(train_mask) == n_train
104 | return train_mask
105 |
106 |
107 | class SequenceSampler:
108 | def __init__(
109 | self,
110 | replay_buffer: ReplayBuffer,
111 | sequence_length: int,
112 | pad_before: int = 0,
113 | pad_after: int = 0,
114 | keys=None,
115 | key_first_k=dict(),
116 | episode_mask: Optional[np.ndarray] = None,
117 | ):
118 | """
119 | key_first_k: dict str: int
120 | Only take first k data from these keys (to improve perf)
121 | """
122 |
123 | super().__init__()
124 | assert sequence_length >= 1
125 | if keys is None:
126 | keys = list(replay_buffer.keys())
127 |
128 | episode_ends = replay_buffer.episode_ends[:]
129 | marks = replay_buffer["marks"]
130 | if episode_mask is None:
131 | episode_mask = np.ones(episode_ends.shape, dtype=bool)
132 |
133 | if np.any(episode_mask):
134 | indices = create_indices(
135 | episode_ends,
136 | sequence_length=sequence_length,
137 | pad_before=pad_before,
138 | pad_after=pad_after,
139 | episode_mask=episode_mask,
140 | marks = np.array(marks),
141 | )
142 | else:
143 | indices = np.zeros((0, 4), dtype=np.int64)
144 |
145 | # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx)
146 | self.indices = indices
147 | self.keys = list(keys) # prevent OmegaConf list performance problem
148 | self.sequence_length = sequence_length
149 | self.replay_buffer = replay_buffer
150 | self.key_first_k = key_first_k
151 |
152 | def __len__(self):
153 | return len(self.indices)
154 |
155 | def sample_sequence(self, idx):
156 | (
157 | buffer_start_idx,
158 | buffer_end_idx,
159 | sample_start_idx,
160 | sample_end_idx,
161 | ) = self.indices[idx]
162 | result = dict()
163 | for key in self.keys:
164 | if key == "marks":
165 | continue
166 |             # "marks" itself is skipped above; for every other key we assemble a window
167 |             # of steps, using marked steps to fill it beyond the first few observations
168 | input_arr = self.replay_buffer[key]
169 |
170 | # performance optimization, avoid small allocation if possible
171 | marks = self.replay_buffer["marks"]
172 |
173 | if key not in self.key_first_k:
174 |                 # default path: read the full arrays and fill the window from marked steps
175 | input_arr = self.replay_buffer[key]
176 | marks = self.replay_buffer["marks"]
177 |
178 | sample = []
179 |
180 | # obs_step + 1
181 | current_idx = buffer_start_idx
182 | for _ in range(3):
183 | if current_idx < buffer_end_idx:
184 | sample.append(input_arr[current_idx])
185 | current_idx += 1
186 | else:
187 | break
188 |
189 | while len(sample) < buffer_end_idx-buffer_start_idx and current_idx < buffer_end_idx:
190 | if marks[current_idx] != 0:
191 | sample.append(input_arr[current_idx])
192 | current_idx += 1
193 |
194 | while len(sample) < buffer_end_idx-buffer_start_idx and current_idx < len(marks):
195 | if marks[current_idx] != 0:
196 | sample.append(input_arr[current_idx])
197 | current_idx += 1
198 |
199 | sample = np.array(sample)
200 |
201 | if len(sample) != buffer_end_idx-buffer_start_idx:
202 | raise ValueError("Could not fill the sample to the required sequence length.")
203 |
204 | sample = input_arr[buffer_start_idx:buffer_end_idx]
205 |
206 | else:
207 | # performance optimization, only load used obs steps
208 | n_data = buffer_end_idx - buffer_start_idx
209 | k_data = min(self.key_first_k[key], n_data)
210 | # fill value with Nan to catch bugs
211 | # the non-loaded region should never be used
212 | sample = np.full(
213 | (n_data,) + input_arr.shape[1:],
214 | fill_value=np.nan,
215 | dtype=input_arr.dtype,
216 | )
217 | try:
218 | sample[:k_data] = input_arr[
219 | buffer_start_idx : buffer_start_idx + k_data
220 | ]
221 | except Exception as e:
222 | import pdb
223 |
224 | pdb.set_trace()
225 |
226 | data = sample
227 | if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length):
228 | data = np.zeros(
229 | shape=(self.sequence_length,) + input_arr.shape[1:],
230 | dtype=input_arr.dtype,
231 | )
232 | if sample_start_idx > 0:
233 | data[:sample_start_idx] = sample[0]
234 | if sample_end_idx < self.sequence_length:
235 | data[sample_end_idx:] = sample[-1]
236 | data[sample_start_idx:sample_end_idx] = sample
237 | result[key] = data
238 | return result
239 |
--------------------------------------------------------------------------------
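To make the index bookkeeping in create_indices concrete, the illustrative call below uses hypothetical values and marks every step, so the marks-based filtering keeps all windows and the function reduces to the usual padded-window enumeration (each row is buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx):

```python
import numpy as np

episode_ends = np.array([5], dtype=np.int64)   # one episode of 5 steps
episode_mask = np.ones(1, dtype=bool)
marks = np.ones(5, dtype=np.int64)             # every step marked -> no window gets dropped

idx = create_indices(episode_ends, sequence_length=4, episode_mask=episode_mask,
                     marks=marks, pad_before=1, pad_after=1)
print(idx)
# [[0 3 1 4]    window starts before the episode: padded at the front (sample_start_idx = 1)
#  [0 4 0 4]    fully inside the episode
#  [1 5 0 4]
#  [2 5 0 3]]   window runs past the end: padded at the back (sample_end_idx = 3)
```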
/utils/image_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import io
3 | import os
4 | import cv2
5 |
6 | from PIL import Image
7 | import matplotlib.pyplot as plt
8 | from .realworld_utils import *
9 | import robosuite.utils.transform_utils as T
10 |
11 | def get_camera_extrinsic_matrix(sim, camera_name):
12 | cam_id = sim.model.camera_name2id(camera_name)
13 | camera_pos = sim.data.cam_xpos[cam_id]
14 | camera_rot = sim.data.cam_xmat[cam_id].reshape(3, 3)
15 | R = T.make_pose(camera_pos, camera_rot)
16 |
17 | camera_axis_correction = np.array(
18 | [[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
19 | )
20 | R = R @ camera_axis_correction
21 | return R
22 |
23 | class TrajectoryRenderer:
24 |     def __init__(self, env, camera_name, save_mode=False, calib_dir=None, root_dir=None):
25 | self.save_mode = save_mode
26 | if self.save_mode == 'realworld':
27 | self.calib_dir = calib_dir
28 | self.root_dir = root_dir
29 | else:
30 | self.env = env
31 | self.camera_name = camera_name
32 |
33 | # self.extrinsic_matrix = env.get_camera_extrinsic_matrix(camera_name)
34 | self.extrinsic_matrix = get_camera_extrinsic_matrix(env.env.sim, camera_name)
35 |
36 | print(self.extrinsic_matrix)
37 |
38 | self.camera_position = self.extrinsic_matrix[:3, 3]
39 | self.camera_rotation = self.extrinsic_matrix[:3, :3]
40 |
41 |
42 | def render_trajectory_image(self, trajectory_points, ind, samples, save_mode, rotate=False, state_slices=None):
43 | if self.save_mode == 'realworld':
44 | if rotate:
45 | trajectory_points = self.rotate_trajectory(trajectory_points, ind)
46 | transformed_points = translate_points(self.calib_dir, trajectory_points)
47 | self.realworld_save_and_append_images(transformed_points, samples, ind, state_slices[0], rotate)
48 | else:
49 |             if state_slices is not None: self.env.reset_to(state_slices[0])
50 | frame = self.env.render(mode="rgb_array", height=480, width=480, camera_name=self.camera_name)
51 | image2 = Image.fromarray(frame)
52 | factor = get_save_mode_factor(save_mode)
53 | image_array = apply_image_filter(image2, factor)
54 |
55 | if rotate:
56 | trajectory_points = self.rotate_trajectory(trajectory_points, ind)
57 | transformed_points = np.dot(trajectory_points - self.camera_position, self.camera_rotation)
58 |
59 | self.save_and_append_images(transformed_points, samples, ind, image_array, rotate)
60 |
61 | def rotate_trajectory(self, trajectory_points, i):
62 | start_point, end_point, middle_points = trajectory_points[0], trajectory_points[-1], trajectory_points[1:-1]
63 | axis = (end_point - start_point) / np.linalg.norm(end_point - start_point)
64 | angle = np.deg2rad(i * 360 / 30)
65 | rotation_matrix = self.get_rotation_matrix(axis, angle)
66 |
67 | rotated_middle_points = np.dot(middle_points - start_point, rotation_matrix.T) + start_point
68 | return np.vstack([start_point, rotated_middle_points, end_point])
69 |
70 | def get_rotation_matrix(self, axis, angle):
71 | cos_angle = np.cos(angle)
72 | sin_angle = np.sin(angle)
73 | return np.array([
74 | [cos_angle + axis[0]**2 * (1 - cos_angle), axis[0]*axis[1]*(1 - cos_angle) - axis[2]*sin_angle, axis[0]*axis[2]*(1 - cos_angle) + axis[1]*sin_angle],
75 | [axis[1]*axis[0]*(1 - cos_angle) + axis[2]*sin_angle, cos_angle + axis[1]**2 * (1 - cos_angle), axis[1]*axis[2]*(1 - cos_angle) - axis[0]*sin_angle],
76 | [axis[2]*axis[0]*(1 - cos_angle) - axis[1]*sin_angle, axis[2]*axis[1]*(1 - cos_angle) + axis[0]*sin_angle, cos_angle + axis[2]**2 * (1 - cos_angle)]
77 | ])
78 |
79 | def realworld_save_and_append_images(self, transformed_points, samples, ind, image_path, rotate):
80 | img_path = os.path.join(self.root_dir, image_path)
81 | image = cv2.imread(img_path)
82 | factor = get_save_mode_factor(self.save_mode)
83 | image = apply_image_filter(image, factor)
84 | image = np.array(image)
85 |
86 |         prev_point = None
87 |         for point in transformed_points:
88 |             pt = (int(point[0]), int(point[1]))  # OpenCV drawing functions expect integer pixel coordinates
89 |             cv2.circle(image, pt, 3, (0, 0, 255), -1)
90 |             if prev_point is not None:
91 |                 cv2.line(image, prev_point, pt, (0, 0, 255), thickness=4)
92 |             prev_point = pt
93 |         # cv2.imwrite('./tmp_realworld/marked_image{}.jpg'.format(ind), image)  # optional debug dump
94 | samples['images'].append(np.array(image))
95 | samples['labels'].append(1 if rotate else 0)
96 |
97 | def save_and_append_images(self, transformed_points, samples, ind, image_array, rotate):
98 | image_array = plot(transformed_points, image_array)
99 | # image_array.save(f'./tmp/image_com_{ind}.png')
100 |
101 | samples['images'].append(np.array(image_array))
102 | samples['labels'].append(1 if rotate else 0)
103 |
104 |
105 | class TrajectoryNoiseGenerator:
106 | def __init__(self, trajectory_points):
107 | self.trajectory_points = trajectory_points
108 |
109 | def render_one_point(self):
110 | start_index = np.random.randint(0, len(self.trajectory_points) - 10)
111 | end_index = start_index + 1
112 |
113 | sub_trajectory = self.trajectory_points[start_index:end_index]
114 | noisy_sub_trajectory = self.add_noise(sub_trajectory, scale_range=(0.02, 0.05))
115 |
116 | noise_trajectory = self.trajectory_points.copy()
117 | noise_trajectory[start_index:end_index] = noisy_sub_trajectory
118 |         noise_trajectory = self.remove_adjacent_points(noise_trajectory, start_index)
119 | return noise_trajectory
120 |
121 | def add_noise(self, sub_trajectory, scale_range):
122 | noise_scale = np.random.uniform(*scale_range)
123 | noise = noise_scale * np.random.randn(sub_trajectory.shape[0], sub_trajectory.shape[1])
124 | return sub_trajectory + noise
125 |
126 |     def remove_adjacent_points(self, noise_trajectory, start_index):
127 |         if start_index > 1:
128 |             for i in range(1, 4):
129 |                 noise_trajectory = np.delete(noise_trajectory, start_index - i, axis=0)
130 |         return noise_trajectory  # np.delete returns a copy, so the result must be handed back to the caller
131 | def render_one_point_circle(self):
132 | start_index = np.random.randint(0, min(len(self.trajectory_points) - 10, len(self.trajectory_points)))
133 | end_index = start_index + np.random.randint(5, 10)
134 |
135 | sub_trajectory = self.trajectory_points[start_index:end_index]
136 | noisy_sub_trajectory = self.add_noise(sub_trajectory, scale_range=(0.02, 0.04))
137 |
138 | inserted_indices = self.get_inserted_indices(start_index, end_index, num_points_range=(10, 20))
139 | for i, index in enumerate(inserted_indices):
140 | self.trajectory_points = np.insert(self.trajectory_points, index + i, sub_trajectory[0] + np.random.uniform(0.04, 0.06)*np.random.randn(), axis=0)
141 | return self.trajectory_points
142 |
143 | def get_inserted_indices(self, start_index, end_index, num_points_range):
144 | num_inserted_points = np.random.randint(*num_points_range)
145 | inserted_indices = np.random.randint(start_index, end_index, size=num_inserted_points)
146 | inserted_indices.sort()
147 | return inserted_indices
148 |
149 | def render_series(self):
150 | start_index = np.random.randint(0, max(1, len(self.trajectory_points) - 20))
151 | end_index = start_index + np.random.randint(5, 10)
152 | sub_trajectory = self.trajectory_points[start_index:end_index]
153 |
154 | noisy_sub_trajectory = self.add_noise(sub_trajectory, scale_range=(0.03, 0.06))
155 | noise_trajectory = self.trajectory_points.copy()
156 | noise_trajectory[start_index:end_index] = noisy_sub_trajectory
157 | return noise_trajectory
158 |
159 |
160 | class TrajectoryGenerator:
161 | def __init__(self, trajectory_points):
162 | self.trajectory_points = trajectory_points
163 | self.noise_generator = TrajectoryNoiseGenerator(trajectory_points)
164 |
165 | def generate_negative_trajectory_points(self):
166 | num_one_point = np.random.randint(0, 8)
167 | noise_trajectory = self.trajectory_points.copy()
168 |
169 | for _ in range(num_one_point):
170 | noise_trajectory = self.noise_generator.render_one_point()
171 |
172 | if num_one_point == 0:
173 | one_point_circle_flag = 1
174 | else:
175 | one_point_circle_flag = np.random.randint(2)
176 |
177 | if one_point_circle_flag:
178 | noise_trajectory = self.noise_generator.render_one_point_circle()
179 |
180 | if not one_point_circle_flag and np.random.randint(2):
181 | noise_trajectory = self.noise_generator.render_series()
182 |
183 | return noise_trajectory
184 |
185 |
186 | class PictureGenerator:
187 | def __init__(self, renderer, samples, save_mode):
188 | self.renderer = renderer
189 | self.samples = samples
190 | self.save_mode = save_mode
191 |
192 | def generate_positive_picture(self, trajectory_points, state_slices, ind, num_images=30):
193 | for i in range(num_images):
194 | self.renderer.render_trajectory_image(trajectory_points, i, self.samples, self.save_mode, rotate=True, state_slices=state_slices)
195 | # This part can be expanded if needed
196 |
197 | def generate_negative_picture(self, trajectory_points, state_slices, ind, num_images):
198 | for i in range(num_images):
199 | noise_trajectory = TrajectoryGenerator(trajectory_points).generate_negative_trajectory_points()
200 | self.renderer.render_trajectory_image(noise_trajectory, i, self.samples, self.save_mode, rotate=False, state_slices=state_slices)
201 |
202 |
203 |
--------------------------------------------------------------------------------
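The matrix built by get_rotation_matrix above is the standard Rodrigues (axis-angle) rotation, and rotate_trajectory sweeps the middle of a trajectory around the start-to-end axis. A quick sanity check with hypothetical points, using scipy's equivalent rotation (which should match get_rotation_matrix), confirms the two defining properties: points on the axis stay fixed, and other points keep their distance to the axis.

```python
import numpy as np
from scipy.spatial.transform import Rotation

start = np.array([0.0, 0.0, 0.0])     # hypothetical segment endpoints and one intermediate point
end = np.array([1.0, 0.0, 0.5])
mid = np.array([0.3, 0.1, 0.2])

axis = (end - start) / np.linalg.norm(end - start)
angle = np.deg2rad(2 * 360 / 30)                      # i = 2, same convention as rotate_trajectory
R = Rotation.from_rotvec(angle * axis).as_matrix()    # should equal get_rotation_matrix(axis, angle)

rotated_mid = (mid - start) @ R.T + start             # the transform applied to middle points

# the segment endpoints lie on the rotation axis, so they are left untouched ...
assert np.allclose((end - start) @ R.T + start, end)

# ... while intermediate points keep their distance to that axis (a rigid sweep around it)
def dist_to_axis(p):
    v = p - start
    return np.linalg.norm(v - np.dot(v, axis) * axis)

assert np.isclose(dist_to_axis(mid), dist_to_axis(rotated_mid))
```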
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import h5py
5 | from torch.utils.data import Dataset
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision import transforms, datasets
9 | import robomimic.utils.obs_utils as ObsUtils
10 | import robomimic.utils.env_utils as EnvUtils
11 | import robomimic.utils.file_utils as FileUtils
12 |
13 | from scipy.spatial.transform import Rotation
14 | from .trajectory_optimization import TrajectoryOptimizer
15 | from utils.realworld_utils import *
16 | from utils.constant import *
17 |
18 |
19 | class CustomDataset(Dataset):
20 | def __init__(self, npy_file, transform=None):
21 |         """
22 |         Args:
23 |             npy_file (string): Path to a .npy file holding a dict with 'images' and 'labels' arrays.
24 |             transform (callable, optional): Transform applied to each image (e.g. TwoCropTransform).
25 |         """
26 | self.transform = transform
27 | self.data = np.load(npy_file, allow_pickle=True).item()
28 | self.inputs = self.data['images']
29 | self.labels = self.data['labels']
30 |
31 | if len(self.inputs) != len(self.labels):
32 | raise ValueError(f"Length mismatch: inputs({len(self.inputs)}) and labels({len(self.labels)})")
33 |
34 | def __len__(self):
35 | return len(self.labels)
36 |
37 | def __getitem__(self, idx):
38 | input_data = self.inputs[idx]
39 | label = self.labels[idx]
40 | if isinstance(input_data, np.ndarray):
41 | input_data = Image.fromarray(input_data)
42 |
43 | if self.transform:
44 | input_data = self.transform(input_data)
45 |
46 | return input_data, label
47 |
48 |
49 |
50 |
51 | class ValDataset(Dataset):
52 | def __init__(self, hdf5_file, transform=None, save_mode=None):
53 |         """
54 |         Args:
55 |             hdf5_file (string): Path to a robomimic low-dim hdf5 dataset.
56 |             transform (callable, optional): Transform applied to each rendered trajectory image.
57 |         """
58 | self.transform = transform
59 | self.save_mode = save_mode
60 | self.hdf5_file = hdf5_file
61 | self.data = h5py.File(hdf5_file, 'r')
62 |
63 | self.demos = list(self.data["data"].keys())
64 | self.small_demos = {}
65 | self.mapping = {}
66 | self._init_env()
67 | self.optimizer = TrajectoryOptimizer(self.env, real_world=False)
68 |
69 | self._split_demos()
70 | self.data.close()
71 |
72 |
73 | def _init_env(self):
74 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=self.hdf5_file)
75 | env_type = EnvUtils.get_env_type(env_meta=env_meta)
76 | self.render_image_names = DEFAULT_CAMERAS[env_type]
77 |
78 | dummy_spec = dict(
79 | obs=dict(
80 | low_dim=["robot0_eef_pos"],
81 | rgb=[],
82 | ),
83 | )
84 | ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)
85 | self.env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=False, render_offscreen=True)
86 |
87 | def _calculate_items(self, demo_idx, states, actions):
88 | states = self.data[f"data/{demo_idx}/states"][()]
89 | traj_len = states.shape[0]
90 |
91 | delta_actions = self.data[f"data/{demo_idx}/actions"][()]
92 | action_pos = np.zeros((traj_len, 3), dtype=delta_actions.dtype)
93 | action_ori = np.zeros((traj_len, 3), dtype=delta_actions.dtype)
94 | action_gripper = delta_actions[:, -1:]
95 |
96 | robot = self.env.env.robots[0]
97 | controller = robot.controller
98 |
99 | for i in range(len(states)):
100 | self.env.reset_to({"states": states[i]})
101 | robot.control(actions[i], policy_step=True)
102 |
103 | action_pos[i] = controller.ee_pos
104 | action_ori[i] = Rotation.from_matrix(controller.ee_ori_mat).as_rotvec()
105 |
106 | actions = np.concatenate([action_pos, action_ori, action_gripper], axis=-1)
107 | return actions
108 |
109 |
110 | def _split_demos(self):
111 | for demo_idx in self.demos:
112 |             eef_pos = self.data[f"data/{demo_idx}/obs/robot0_eef_pos"][()]
113 | eef_quat = self.data[f"data/{demo_idx}/obs/robot0_eef_quat"][()]
114 | joint_pos = self.data[f"data/{demo_idx}/obs/robot0_joint_pos"][()]
115 | gt_states = []
116 | traj_len = eef_pos.shape[0]
117 | for i in range(traj_len):
118 | gt_states.append(
119 | dict(
120 | robot0_eef_pos=eef_pos[i],
121 | robot0_eef_quat=eef_quat[i],
122 | robot0_joint_pos=joint_pos[i],
123 | )
124 | )
125 | actions = self.data[f'data/{demo_idx}/actions'][()]
126 | states = self.data[f'data/{demo_idx}/states'][()]
127 | trajectory_points = self.data[f'data/{demo_idx}/obs/robot0_eef_pos'][()]
128 | actions = self._calculate_items(demo_idx, states, actions)
129 |
130 | frames = self._get_frames(actions)
131 | small_demos = self._split_into_small_demos(actions, states, trajectory_points, gt_states, frames)
132 |
133 | self.small_demos[demo_idx] = small_demos
134 | self.mapping[demo_idx] = list(range(len(small_demos)))
135 |
136 | def _get_frames(self, actions):
137 | frames = [0, len(actions) - 1]
138 | for i in range(len(actions) - 1):
139 | if actions[i, -1] != actions[i + 1, -1]:
140 | frames.append(i)
141 | frames.sort()
142 |
143 | merged_frames = [frames[0]]
144 | for i in range(1, len(frames)):
145 | if frames[i] - merged_frames[-1] < 15:
146 | merged_frames.pop()
147 | merged_frames.append(frames[i])
148 |
149 | merged_frames.sort()
150 | return merged_frames
151 |
152 | def _split_into_small_demos(self, actions, states, trajectory_points, gt_states, frames):
153 | small_demos = []
154 | for i in range(len(frames) - 1):
155 | start, end = frames[i], frames[i + 1]
156 | small_demos.append({
157 | 'actions': actions[start:end],
158 | 'states': states[start:end],
159 | 'gt_states': gt_states[start:end],
160 | 'trajectory_points': trajectory_points[start:end],
161 | 'frame_start': frames[i],
162 | 'frame_end': frames[i+1]
163 | })
164 | return small_demos
165 |
166 | def __len__(self):
167 | return sum(len(small_demo) for small_demo in self.small_demos.values())
168 |
169 | def __getitem__(self, idx):
170 | demo_idx, small_demo_idx = self._find_small_demo_index(idx)
171 | small_demo = self.small_demos[demo_idx][small_demo_idx]
172 |
173 | positive_image = self.generate_image(small_demo)
174 |
175 | if self.transform:
176 | positive_image = self.transform(positive_image)
177 |
178 | return positive_image, demo_idx, small_demo_idx
179 |
180 | def _find_small_demo_index(self, idx):
181 | for demo_idx, small_demos in self.small_demos.items():
182 | if idx < len(small_demos):
183 | return demo_idx, idx
184 | idx -= len(small_demos)
185 | raise IndexError("Index out of range.")
186 |
187 | def _save_marks(self, demo_idx, marks):
188 | with h5py.File(self.hdf5_file, 'a') as f:
189 | if f'data/{demo_idx}/marks' not in f:
190 | f.create_dataset(f'data/{demo_idx}/marks', data=marks)
191 | else:
192 | existing_marks = f[f'data/{demo_idx}/marks'][:]
193 | all_marks = np.unique(np.concatenate((existing_marks, marks)))
194 | all_marks.sort()
195 | del f[f'data/{demo_idx}/marks']
196 | f.create_dataset(f'data/{demo_idx}/marks', data=all_marks)
197 |
198 |
199 |     def visualize_image(self, idx):
200 | demo_idx, small_demo_idx = self._find_small_demo_index(idx)
201 | small_demo = self.small_demos[demo_idx][small_demo_idx]
202 |
203 | positive_image = self.generate_image(small_demo)
204 | return positive_image
205 |
206 | def perform_optimization(self, idx, label):
207 | flag = label < 0.5
208 | demo_idx, small_demo_idx = self._find_small_demo_index(idx)
209 | small_demo = self.small_demos[demo_idx][small_demo_idx]
210 | if flag:
211 |             marks = self.optimizer.optimize_trajectory(small_demo, demo_idx, small_demo_idx, three_dimension=True)
212 | else:
213 | marks = list(range(small_demo['frame_start'], small_demo['frame_end']))
214 |
215 | self._save_marks(demo_idx, marks)
216 |
217 | def transform_points(self, trajectory_points):
218 | camera_position = np.array([1.0, 0.0, 1.75])
219 | camera_rotation = np.array([
220 | [0.0, -0.70614724, 0.70806503],
221 | [1.0, 0.0, 0.0],
222 | [0.0, 0.70806503, 0.70614724]
223 | ])
224 |
225 | transformed_points = np.dot(trajectory_points - camera_position, camera_rotation)
226 | return transformed_points
227 |
228 |
229 | def generate_image(self, small_demo, save_mode="image"):
230 | trajectory_points = small_demo['trajectory_points']
231 |
232 | self.env.reset()
233 | self.env.reset_to(dict(states=small_demo['states'][0]))
234 | frame = self.env.render(mode="rgb_array", height=480, width=480, camera_name=self.render_image_names[0])
235 |
236 | image = Image.fromarray(frame)
237 | factor = get_save_mode_factor(save_mode=self.save_mode)
238 | image = apply_image_filter(image, factor)
239 |
240 | transformed_points = self.transform_points(trajectory_points)
241 | return plot(transformed_points, image)
242 |
243 |
244 | class RealworldValDataset(Dataset):
245 | def __init__(self, dataset, transform=None, save_mode=None):
246 |         """
247 |         Args:
248 |             dataset (string): Root directory of the real-world data, containing 'calib' and 'train' subfolders.
249 |             transform (callable, optional): Transform applied to each rendered trajectory image.
250 |         """
251 | self.transform = transform
252 | self.calib_dir = check_directory_exists(os.path.join(dataset, "calib"))
253 | self.save_mode = save_mode
254 | self.root_dir = os.path.join(dataset, "train")
255 | subdirs = [d for d in os.listdir(self.root_dir) if os.path.isdir(os.path.join(self.root_dir, d))]
256 | self.subdirs = sorted(subdirs, key=lambda x: int(x.split('_scene_')[1].split('_')[0]))
257 | self.small_demos = {}
258 | self.mapping = {}
259 |
260 | self.optimizer = TrajectoryOptimizer(env=None, real_world=True, calib_dir=self.calib_dir)
261 | self._split_demos()
262 |
263 | def _split_demos(self):
264 | for ind in range(len(self.subdirs)-20):
265 | color_path = os.path.join(self.root_dir, self.subdirs[ind], CAMERA_NAME, 'color')
266 | file_paths, trajectory_points, gripper_command = load_demo_files(self.root_dir, self.subdirs, ind)
267 | frames = realworld_change_indices(gripper_command)
268 | small_demos = self._split_into_small_demos(file_paths, trajectory_points, frames, color_path)
269 |
270 | self.small_demos[ind] = small_demos
271 | self.mapping[ind] = list(range(len(small_demos)))
272 |
273 | def _split_into_small_demos(self, file_paths, trajectory_points, frames, color_path):
274 | small_demos = []
275 | for i in range(len(frames) - 1):
276 | start, end = frames[i], frames[i + 1]
277 | small_demos.append({
278 | 'states': file_paths[start:end],
279 | 'trajectory_points': trajectory_points[start:end],
280 | 'frame_start': frames[i],
281 | 'frame_end': frames[i+1],
282 | 'color_path': color_path
283 | })
284 | return small_demos
285 |
286 | def _find_small_demo_index(self, idx):
287 | for demo_idx, small_demos in self.small_demos.items():
288 | if idx < len(small_demos):
289 | return demo_idx, idx
290 | idx -= len(small_demos)
291 | raise IndexError("Index out of range.")
292 |
293 | def _generate_image(self, small_demo):
294 | trajectory_points = small_demo['trajectory_points']
295 |
296 | transformed_points = translate_points(self.calib_dir, trajectory_points)
297 | img_path = os.path.join(small_demo['color_path'], small_demo['states'][0])
298 | image = cv2.imread(img_path)
299 |
300 | factor = get_save_mode_factor(save_mode=self.save_mode)
301 | image = apply_image_filter(image, factor)
302 | image = np.array(image)
303 |
304 |         prev_point = None
305 |         for point in transformed_points:
306 |             pt = (int(point[0]), int(point[1]))  # OpenCV drawing functions expect integer pixel coordinates
307 |             cv2.circle(image, pt, 3, (0, 0, 255), -1)
308 |             if prev_point is not None:
309 |                 cv2.line(image, prev_point, pt, (0, 0, 255), thickness=4)
310 |             prev_point = pt
311 |         return image
312 | def __len__(self):
313 | return sum(len(small_demo) for small_demo in self.small_demos.values())
314 |
315 | def __getitem__(self, idx):
316 | demo_idx, small_demo_idx = self._find_small_demo_index(idx)
317 | small_demo = self.small_demos[demo_idx][small_demo_idx]
318 | positive_image = self._generate_image(small_demo)
319 |
320 | positive_image = self.transform(Image.fromarray(positive_image)) if self.transform and isinstance(positive_image, np.ndarray) else positive_image
321 |
322 | return positive_image, demo_idx, small_demo_idx
323 |
324 | def _save_marks(self, demo_idx, marks, color_path):
325 | npy_file_path = os.path.join(color_path, f'marks_{demo_idx}.npy')
326 | if os.path.exists(npy_file_path):
327 | existing_marks = np.load(npy_file_path)
328 | all_marks = np.unique(np.concatenate((existing_marks, marks)))
329 | else:
330 | all_marks = np.unique(marks)
331 |
332 | all_marks.sort()
333 | np.save(npy_file_path, all_marks)
334 |
335 | def perform_optimization(self, idx, flag=True):
336 | demo_idx, small_demo_idx = self._find_small_demo_index(idx)
337 | small_demo = self.small_demos[demo_idx][small_demo_idx]
338 | if flag:
339 |             marks = self.optimizer.optimize_trajectory(small_demo, demo_idx, small_demo_idx, three_dimension=True)
340 | else:
341 | marks = list(range(small_demo['frame_start'], small_demo['frame_end']))
342 |
343 | self._save_marks(demo_idx, marks, small_demo['color_path'])
344 |
345 |
--------------------------------------------------------------------------------
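For reference, ValDataset._get_frames above splits a demo at gripper toggles (plus the two endpoints) and collapses boundaries that are fewer than 15 steps apart onto the later one. A small illustrative run with hypothetical gripper commands (assuming the repository root is on PYTHONPATH; the helper only reads `actions`, so it can be exercised without building the full dataset):

```python
import numpy as np
from dataset.dataset import ValDataset

# 100-step demo whose gripper command (last action dim) toggles at steps 10 and 45
actions = np.zeros((100, 7))
actions[11:46, -1] = 1.0

frames = ValDataset._get_frames(None, actions)   # self is unused inside _get_frames
print(frames)                                    # [10, 45, 99] -- frame 0 merges into 10 (gap < 15 steps)
```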