├── PyriteConfig ├── .gitignore └── README.md ├── PyriteUtility ├── PyriteUtility │ ├── __init__.py │ ├── planning_control │ │ ├── rrt_plotting.py │ │ └── shooting_sampling.py │ ├── data_pipeline │ │ ├── umift_mask.json │ │ ├── file.py │ │ ├── zarr_examination.py │ │ ├── shared_memory │ │ │ └── shared_memory_util.py │ │ ├── real_data_delete_fields.py │ │ ├── real_data_just_update_meta.py │ │ ├── postprocessing_add_virtual_target_label.py │ │ ├── real_data_check_keys.py │ │ ├── indexing.py │ │ ├── plot_correction_data.py │ │ ├── real_data_check_timing.py │ │ └── real_data_processing.py │ ├── common.py │ ├── plotting │ │ └── matplotlib_helpers.py │ ├── math │ │ └── numerical_differentiation.py │ ├── pytorch_utils │ │ └── model_io.py │ ├── hardware_interface │ │ ├── test_log_video.py │ │ └── multi_camera_visualizer.py │ ├── computer_vision │ │ └── computer_vision_utility.py │ ├── umi_utils │ │ └── usb_util.py │ └── audio │ │ ├── multi_mic.py │ │ └── audio_recorder.py ├── README.md ├── setup.py └── .gitignore ├── PyriteML ├── diffusion_policy │ ├── model │ │ ├── residual │ │ │ ├── conv_mlp.py │ │ │ └── mlp.py │ │ ├── bet │ │ │ ├── libraries │ │ │ │ └── mingpt │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── LICENSE │ │ │ │ │ └── utils.py │ │ │ ├── action_ae │ │ │ │ └── __init__.py │ │ │ ├── latent_generators │ │ │ │ ├── latent_generator.py │ │ │ │ └── transformer.py │ │ │ └── utils.py │ │ ├── common │ │ │ ├── module_attr_mixin.py │ │ │ ├── shape_util.py │ │ │ ├── dict_of_tensor_mixin.py │ │ │ └── lr_scheduler.py │ │ ├── diffusion │ │ │ ├── positional_embedding.py │ │ │ ├── conv1d_components.py │ │ │ ├── ema_model.py │ │ │ └── mlp.py │ │ └── vision │ │ │ ├── model_getter.py │ │ │ └── force_spec_encoder.py │ ├── common │ │ ├── env_util.py │ │ ├── precise_sleep.py │ │ ├── nested_dict_util.py │ │ ├── pymunk_util.py │ │ ├── checkpoint_util.py │ │ ├── pytorch_util.py │ │ ├── json_logger.py │ │ ├── cv2_util.py │ │ └── pose_repr_util.py │ ├── policy │ │ ├── base_image_policy.py │ │ └── base_lowdim_policy.py │ ├── dataset │ │ └── base_dataset.py │ └── config │ │ └── task │ │ ├── online_correction_single_arm_no_base_action_no_force.yaml │ │ ├── stow_conv.yaml │ │ ├── stow_spec.yaml │ │ ├── correction_single_arm_xyzforce_only.yaml │ │ └── stow_no_force.yaml ├── multimodal_representation │ ├── multimodal │ │ ├── __init__.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── tests │ │ │ │ ├── __init__.py │ │ │ │ └── test_layers.py │ │ │ ├── base_models │ │ │ │ ├── __init__.py │ │ │ │ └── layers.py │ │ │ └── models_utils.py │ │ ├── trainers │ │ │ └── __init__.py │ │ ├── scripts │ │ │ └── run_all_tests.sh │ │ ├── dataset │ │ │ └── download_data.sh │ │ ├── dataloaders │ │ │ ├── __init__.py │ │ │ ├── ProcessFlow.py │ │ │ ├── ProcessForce.py │ │ │ └── ToTensor.py │ │ ├── configs │ │ │ └── training_default.yaml │ │ ├── mini_main.py │ │ └── logger.py │ ├── LICENSE │ ├── README.md │ ├── .gitignore │ └── requirements.txt ├── conda_environment.yaml ├── LICENSE ├── train.py ├── online_learning │ ├── learner.py │ ├── actor.py │ └── configs │ │ └── config_v1.py ├── scripts │ ├── test_rmq_client.py │ ├── test_rmq_server.py │ └── test_rmq.ipynb └── README.md ├── PyriteEnvSuites ├── README.md └── envs │ └── task │ └── manip_server_handle_env.py ├── LICENSE ├── base_policy.md └── .gitignore /PyriteConfig/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/residual/conv_mlp.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/planning_control/rrt_plotting.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/bet/libraries/mingpt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/models/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/models/base_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/scripts/run_all_tests.sh: -------------------------------------------------------------------------------- 1 | python -m pytest models/tests -------------------------------------------------------------------------------- /PyriteUtility/README.md: -------------------------------------------------------------------------------- 1 | # PyriteUtilities 2 | This package contains robotics utilities functions and types. -------------------------------------------------------------------------------- /PyriteEnvSuites/README.md: -------------------------------------------------------------------------------- 1 | # PyriteEnvSuites 2 | This package contains robot environments and data generation pipelines. 3 | -------------------------------------------------------------------------------- /PyriteConfig/README.md: -------------------------------------------------------------------------------- 1 | # PyriteConfig 2 | This package contains configs and type conversions for tasks in the Pyrite project. 
-------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/dataset/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://downloads.cs.stanford.edu/juno/triangle_real_data.zip -O _tmp.zip 4 | 5 | unzip _tmp.zip 6 | rm _tmp.zip 7 | -------------------------------------------------------------------------------- /PyriteUtility/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="PyriteUtility", 5 | version="0.0.1", 6 | packages=find_packages(), 7 | install_requires=[ 8 | ], 9 | ) 10 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .MultimodalManipulationDataset import MultimodalManipulationDataset 2 | from .ProcessForce import ProcessForce 3 | from .ProcessFlow import ProcessFlow 4 | from .ToTensor import ToTensor 5 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/umift_mask.json: -------------------------------------------------------------------------------- 1 | { 2 | "umift_uw_mask": [ 3 | [33, 200], 4 | [56, 164], 5 | [196, 163], 6 | [223, 191], 7 | [223, 200] 8 | ] 9 | , 10 | "mask_uw_resolution": [224, 224] 11 | 12 | } -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/file.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | # write a function to save date in pickle format 4 | def save_data_as_pickle(data, file_name): 5 | with open(file_name, 'wb') as f: 6 | pickle.dump(data, f) 7 | 8 | # write a function to load data from pickle format 9 | def load_data_from_pickle(file_name): 10 | with open(file_name, 'rb') as f: 11 | data = pickle.load(f) 12 | return data 13 | 14 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/common/module_attr_mixin.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ModuleAttrMixin(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | self._dummy_variable = nn.Parameter(requires_grad=False) 7 | 8 | @property 9 | def device(self): 10 | return next(iter(self.parameters())).device 11 | 12 | @property 13 | def dtype(self): 14 | return next(iter(self.parameters())).dtype 15 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/diffusion/positional_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class SinusoidalPosEmb(nn.Module): 6 | def __init__(self, dim, max_value=10000): 7 | super().__init__() 8 | self.dim = dim 9 | self.max_value = max_value 10 | 11 | def forward(self, x): 12 | device = x.device 13 | half_dim = self.dim // 2 14 | emb = math.log(self.max_value) / (half_dim - 1) 15 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 16 | emb = x[:, None] * emb[None, :] 17 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 18 | return emb 19 | 
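if __name__ == "__main__":
    # Minimal usage sketch of SinusoidalPosEmb (illustration only; the dim and
    # batch size below are arbitrary assumptions, not values taken from any
    # config in this repository).
    pos_emb = SinusoidalPosEmb(dim=128)
    timesteps = torch.arange(16)        # (B,) diffusion timesteps
    emb = pos_emb(timesteps)            # (16, 128): first half sin, second half cos
    print(emb.shape)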
-------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/dataloaders/ProcessFlow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ProcessFlow(object): 6 | """Process optical flow into a pyramid. 7 | Args: 8 | pyramid_scale (list): scaling factors to downsample 9 | the spatial pyramid 10 | """ 11 | 12 | def __init__(self, pyramid_scales=[2, 4, 8]): 13 | assert isinstance(pyramid_scales, list) 14 | self.pyramid_scales = pyramid_scales 15 | 16 | def __call__(self, sample): 17 | # subsampling to create small flow images 18 | for scale in self.pyramid_scales: 19 | scaled_flow = sample['flow'][::scale, ::scale] 20 | sample['flow{}'.format(scale)] = scaled_flow 21 | return sample 22 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/common/env_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def render_env_video(env, states, actions=None): 6 | observations = states 7 | imgs = list() 8 | for i in range(len(observations)): 9 | state = observations[i] 10 | env.set_state(state) 11 | if i == 0: 12 | env.set_state(state) 13 | img = env.render() 14 | # draw action 15 | if actions is not None: 16 | action = actions[i] 17 | coord = (action / 512 * 96).astype(np.int32) 18 | cv2.drawMarker(img, coord, 19 | color=(255,0,0), markerType=cv2.MARKER_CROSS, 20 | markerSize=8, thickness=1) 21 | imgs.append(img) 22 | imgs = np.array(imgs) 23 | return imgs 24 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/common/shape_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | def get_module_device(m: nn.Module): 6 | device = torch.device('cpu') 7 | try: 8 | param = next(iter(m.parameters())) 9 | device = param.device 10 | except StopIteration: 11 | pass 12 | return device 13 | 14 | @torch.no_grad() 15 | def get_output_shape( 16 | input_shape: Tuple[int], 17 | net: Callable[[torch.Tensor], torch.Tensor] 18 | ): 19 | device = get_module_device(net) 20 | test_input = torch.zeros((1,)+tuple(input_shape), device=device) 21 | test_output = net(test_input) 22 | output_shape = tuple(test_output.shape[1:]) 23 | return output_shape 24 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/common.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, List 2 | import torch 3 | import signal 4 | 5 | 6 | def dict_apply( 7 | x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor] 8 | ) -> Dict[str, torch.Tensor]: 9 | result = dict() 10 | for key, value in x.items(): 11 | if isinstance(value, dict): 12 | result[key] = dict_apply(value, func) 13 | else: 14 | result[key] = func(value) 15 | return result 16 | 17 | 18 | class GracefulKiller: 19 | kill_now = False 20 | 21 | def __init__(self): 22 | signal.signal(signal.SIGINT, self.exit_gracefully) 23 | signal.signal(signal.SIGTERM, self.exit_gracefully) 24 | 25 | def exit_gracefully(self, signum, frame): 26 | self.kill_now = True 27 | -------------------------------------------------------------------------------- 
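A minimal usage sketch for the two helpers in PyriteUtility/PyriteUtility/common.py above. The dict keys, tensor shapes, and loop body are illustrative assumptions, not values used elsewhere in the repository; the import path follows the PyriteUtility setup.py packaging shown earlier.

```python
import time
import torch

from PyriteUtility.common import dict_apply, GracefulKiller

# dict_apply maps a function over every tensor leaf of a (possibly nested) dict.
batch = {"obs": {"rgb": torch.zeros(2, 3, 96, 96)}, "action": torch.zeros(2, 7)}
batch = dict_apply(batch, lambda t: t.to("cpu", non_blocking=True))

# GracefulKiller turns SIGINT/SIGTERM into a flag so a control loop can exit cleanly.
killer = GracefulKiller()
while not killer.kill_now:
    time.sleep(0.1)  # placeholder for real per-iteration work
print("shutting down cleanly")
```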
/PyriteML/multimodal_representation/multimodal/configs/training_default.yaml: -------------------------------------------------------------------------------- 1 | training_type: "selfsupervised" 2 | log_level: 'INFO' 3 | 4 | test: False 5 | 6 | # Ablations 7 | encoder: False 8 | deterministic: False 9 | vision: 1.0 10 | depth: 1.0 11 | proprio: 1.0 12 | force: 1.0 13 | sceneflow: 1.0 14 | opticalflow: 1.0 15 | contact: 1.0 16 | pairing: 1.0 17 | eedelta: 1.0 18 | 19 | # Training parameters 20 | lr: 0.0001 21 | beta1: 0.9 22 | seed: 1234 23 | max_epoch: 50 24 | batch_size: 64 25 | ep_length: 50 26 | zdim: 128 27 | action_dim: 4 28 | 29 | # Dataset params 30 | dataset_params: 31 | force_name: "force" 32 | action_dim: 4 33 | 34 | load: '' 35 | logging_folder: logging/ 36 | 37 | 38 | # path to dataset hdf5 file' 39 | dataset: "dataset/triangle_real_data/" 40 | 41 | val_ratio: 0.20 42 | cuda: True 43 | num_workers: 8 44 | img_record_n: 500 45 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/common/precise_sleep.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | def precise_sleep(dt: float, slack_time: float=0.001, time_func=time.monotonic): 4 | """ 5 | Use hybrid of time.sleep and spinning to minimize jitter. 6 | Sleep dt - slack_time seconds first, then spin for the rest. 7 | """ 8 | t_start = time_func() 9 | if dt > slack_time: 10 | time.sleep(dt - slack_time) 11 | t_end = t_start + dt 12 | while time_func() < t_end: 13 | pass 14 | return 15 | 16 | def precise_wait(t_end: float, slack_time: float=0.001, time_func=time.monotonic): 17 | t_start = time_func() 18 | t_wait = t_end - t_start 19 | if t_wait > 0: 20 | t_sleep = t_wait - slack_time 21 | if t_sleep > 0: 22 | time.sleep(t_sleep) 23 | while time_func() < t_end: 24 | pass 25 | return 26 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/vision/model_getter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | def get_resnet(name, weights=None, **kwargs): 5 | """ 6 | name: resnet18, resnet34, resnet50 7 | weights: "IMAGENET1K_V1", "r3m" 8 | """ 9 | # load r3m weights 10 | if (weights == "r3m") or (weights == "R3M"): 11 | return get_r3m(name=name, **kwargs) 12 | 13 | func = getattr(torchvision.models, name) 14 | resnet = func(weights=weights, **kwargs) 15 | resnet.fc = torch.nn.Identity() 16 | return resnet 17 | 18 | def get_r3m(name, **kwargs): 19 | """ 20 | name: resnet18, resnet34, resnet50 21 | """ 22 | import r3m 23 | r3m.device = 'cpu' 24 | model = r3m.load_r3m(name) 25 | r3m_model = model.module 26 | resnet_model = r3m_model.convnet 27 | resnet_model = resnet_model.to('cpu') 28 | return resnet_model 29 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/dataloaders/ProcessForce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ProcessForce(object): 6 | """Truncate a time series of force readings with a window size. 
7 | Args: 8 | window_size (int): Length of the history window that is 9 | used to truncate the force readings 10 | """ 11 | 12 | def __init__(self, window_size, key='force', tanh=False): 13 | assert isinstance(window_size, int) 14 | self.window_size = window_size 15 | self.key = key 16 | self.tanh = tanh 17 | 18 | def __call__(self, sample): 19 | force = sample[self.key] 20 | force = force[-self.window_size:] 21 | if self.tanh: 22 | force = np.tanh(force) # remove very large force readings 23 | sample[self.key] = force.transpose() 24 | return sample 25 | -------------------------------------------------------------------------------- /PyriteML/conda_environment.yaml: -------------------------------------------------------------------------------- 1 | name: pyrite 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - pytorch 9 | - torchvision 10 | - torchaudio 11 | - huggingface_hub 12 | - wandb 13 | - timm 14 | - diffusers 15 | - accelerate 16 | - threadpoolctl 17 | - plotly 18 | - dill 19 | - einops 20 | - hydra-core 21 | - ipykernel 22 | - ipython 23 | - matplotlib 24 | - omegaconf 25 | - opencv 26 | - pandas 27 | - pyyaml 28 | - scipy 29 | - tqdm 30 | - zarr=2.18 31 | - spatialmath-python 32 | - numcodecs 33 | - scikit-video 34 | - scikit-fda 35 | # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625 36 | - llvm-openmp=14 37 | - eigen=3.4 38 | - cmake 39 | - pybind11 40 | - boost 41 | - fmt 42 | - spdlog 43 | - imageio-ffmpeg 44 | - nbformat 45 | - cvxpy -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/common/nested_dict_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | def nested_dict_map(f, x): 4 | """ 5 | Map f over all leaf of nested dict x 6 | """ 7 | 8 | if not isinstance(x, dict): 9 | return f(x) 10 | y = dict() 11 | for key, value in x.items(): 12 | y[key] = nested_dict_map(f, value) 13 | return y 14 | 15 | def nested_dict_reduce(f, x): 16 | """ 17 | Map f over all values of nested dict x, and reduce to a single value 18 | """ 19 | if not isinstance(x, dict): 20 | return x 21 | 22 | reduced_values = list() 23 | for value in x.values(): 24 | reduced_values.append(nested_dict_reduce(f, value)) 25 | y = functools.reduce(f, reduced_values) 26 | return y 27 | 28 | 29 | def nested_dict_check(f, x): 30 | bool_dict = nested_dict_map(f, x) 31 | result = nested_dict_reduce(lambda x, y: x and y, bool_dict) 32 | return result 33 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/dataloaders/ToTensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ToTensor(object): 6 | """Convert ndarrays in sample to Tensors.""" 7 | 8 | def __init__(self, device=None): 9 | self.device = device 10 | 11 | def __call__(self, sample): 12 | # swap color axis because 13 | # numpy image: H x W x C 14 | # torch image: C X H X W 15 | 16 | # transpose flow into 2 x H x W 17 | for k in sample.keys(): 18 | if k.startswith('flow'): 19 | sample[k] = sample[k].transpose((2, 0, 1)) 20 | 21 | # convert numpy arrays to pytorch tensors 22 | new_dict = dict() 23 | for k, v in sample.items(): 24 | if self.device is None: 25 | # torch.tensor(v, device = self.device, dtype = torch.float32) 26 | new_dict[k] = torch.FloatTensor(v) 27 | else: 28 | 
new_dict[k] = torch.from_numpy(v).float() 29 | 30 | return new_dict 31 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/policy/base_image_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin 5 | from diffusion_policy.model.common.normalizer import LinearNormalizer 6 | 7 | class BaseImagePolicy(ModuleAttrMixin): 8 | # init accepts keyword argument shape_meta, see config/task/*_image.yaml 9 | 10 | def predict_action(self, obs_dict: Dict[str, torch.Tensor], fixed_action_prefix: torch.Tensor=None) -> Dict[str, torch.Tensor]: 11 | """ 12 | obs_dict: 13 | str: B,To,* 14 | fixed_action_prefix: 15 | B, Tp, Da 16 | return: B,Ta,Da 17 | """ 18 | raise NotImplementedError() 19 | 20 | # reset state for stateful policies 21 | def reset(self): 22 | pass 23 | 24 | # ========== training =========== 25 | # no standard training interface except setting normalizer 26 | def set_normalizer(self, normalizer: LinearNormalizer): 27 | raise NotImplementedError() 28 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/bet/libraries/mingpt/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Yifan Hou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/plotting/matplotlib_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def set_axes_equal(ax): 6 | """ 7 | Make axes of 3D plot have equal scale so that spheres appear as spheres, 8 | cubes as cubes, etc. 9 | 10 | Input 11 | ax: a matplotlib axis, e.g., as output from plt.gca(). 12 | """ 13 | 14 | x_limits = ax.get_xlim3d() 15 | y_limits = ax.get_ylim3d() 16 | z_limits = ax.get_zlim3d() 17 | 18 | x_range = abs(x_limits[1] - x_limits[0]) 19 | x_middle = np.mean(x_limits) 20 | y_range = abs(y_limits[1] - y_limits[0]) 21 | y_middle = np.mean(y_limits) 22 | z_range = abs(z_limits[1] - z_limits[0]) 23 | z_middle = np.mean(z_limits) 24 | 25 | # The plot bounding box is a sphere in the sense of the infinity 26 | # norm, hence I call half the max range the plot radius. 27 | plot_radius = 0.5 * max([x_range, y_range, z_range]) 28 | 29 | ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius]) 30 | ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius]) 31 | ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius]) 32 | -------------------------------------------------------------------------------- /PyriteML/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Columbia Artificial Intelligence and Robotics Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /PyriteML/diffusion_policy/policy/base_lowdim_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin 5 | from diffusion_policy.model.common.normalizer import LinearNormalizer 6 | 7 | class BaseLowdimPolicy(ModuleAttrMixin): 8 | # ========= inference ============ 9 | # also as self.device and self.dtype for inference device transfer 10 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 11 | """ 12 | obs_dict: 13 | obs: B,To,Do 14 | return: 15 | action: B,Ta,Da 16 | To = 3 17 | Ta = 4 18 | T = 6 19 | |o|o|o| 20 | | | |a|a|a|a| 21 | |o|o| 22 | | |a|a|a|a|a| 23 | | | | | |a|a| 24 | """ 25 | raise NotImplementedError() 26 | 27 | # reset state for stateful policies 28 | def reset(self): 29 | pass 30 | 31 | # ========== training =========== 32 | # no standard training interface except setting normalizer 33 | def set_normalizer(self, normalizer: LinearNormalizer): 34 | raise NotImplementedError() 35 | 36 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Stanford Interactive Perception and Robot Learning Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/mini_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import yaml 4 | 5 | from logger import Logger 6 | from trainers.selfsupervised import selfsupervised 7 | 8 | if __name__ == "__main__": 9 | 10 | # Load the config file 11 | parser = argparse.ArgumentParser(description="Sensor fusion model") 12 | parser.add_argument("--config", help="YAML config file") 13 | parser.add_argument("--notes", default="", help="run notes") 14 | parser.add_argument("--dev", type=bool, default=False, help="run in dev mode") 15 | parser.add_argument( 16 | "--continuation", 17 | type=bool, 18 | default=False, 19 | help="continue a previous run. 
Will continue the log file", 20 | ) 21 | args = parser.parse_args() 22 | 23 | # Add the yaml to the config args parse 24 | with open(args.config) as f: 25 | configs = yaml.load(f) 26 | 27 | # Merge configs and args 28 | for arg in vars(args): 29 | configs[arg] = getattr(args, arg) 30 | 31 | # Initialize the loggers 32 | logger = Logger(configs) 33 | 34 | # Initialize the trainer 35 | trainer = selfsupervised(configs, logger) 36 | 37 | trainer.train() 38 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Representation 2 | 3 | Code for Making Sense of Vision and Touch. 4 | https://sites.google.com/view/visionandtouch 5 | 6 | Code written by: Matthew Tan, Michelle Lee, Peter Zachares, Yuke Zhu 7 | 8 | ## requirements 9 | `pip install -r requirements.txt` 10 | 11 | ## get dataset 12 | 13 | ``` 14 | cd multimodal/dataset 15 | ./download_data.sh 16 | ``` 17 | ## run training 18 | 19 | `python mini_main.py --config configs/training_default.yaml` 20 | 21 | 22 | ## ROBOT DATASET 23 | ---- 24 | action Dataset {50, 4}\ 25 | contact Dataset {50, 50}\ 26 | depth_data Dataset {50, 128, 128, 1}\ 27 | ee_forces_continuous Dataset {50, 50, 6}\ 28 | ee_ori Dataset {50, 4}\ 29 | ee_pos Dataset {50, 3}\ 30 | ee_vel Dataset {50, 3}\ 31 | ee_vel_ori Dataset {50, 3}\ 32 | ee_yaw Dataset {50, 4}\ 33 | ee_yaw_delta Dataset {50, 4}\ 34 | image Dataset {50, 128, 128, 3}\ 35 | joint_pos Dataset {50, 7}\ 36 | joint_vel Dataset {50, 7}\ 37 | optical_flow Dataset {50, 128, 128, 2}\ 38 | proprio Dataset {50, 8}\ 39 | 40 | -------------------------------------------------------------------------------- /PyriteEnvSuites/envs/task/manip_server_handle_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__)) 5 | PACKAGE_PATH = os.path.join(SCRIPT_PATH, "../../") 6 | sys.path.append(os.path.join(SCRIPT_PATH, "../../../")) 7 | 8 | import numpy as np 9 | 10 | from PyriteEnvSuites.envs.task.manip_server_env import ManipServerEnv 11 | 12 | class ManipServerHandleEnv(ManipServerEnv): 13 | """ 14 | This class is a wrapper for the ManipServerEnv class. 15 | It wraps the get observation function for the stow robot with handle. 
16 | 17 | """ 18 | def __init__(self, *args, **kwargs): 19 | super(ManipServerHandleEnv, self).__init__(*args, **kwargs) 20 | 21 | def get_observation_from_buffer(self): 22 | obs = super(ManipServerHandleEnv, self).get_sparse_observation_from_buffer() 23 | return obs 24 | 25 | def start_saving_data_for_a_new_episode(self, episode_name = ""): 26 | self.server.start_listening_key_events() 27 | self.server.start_saving_data_for_a_new_episode(episode_name) 28 | 29 | def stop_saving_data(self): 30 | self.server.stop_saving_data() 31 | self.server.stop_listening_key_events() 32 | 33 | 34 | def get_episode_folder(self): 35 | return self.server.get_episode_folder() 36 | -------------------------------------------------------------------------------- /PyriteML/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | Training: 4 | python train.py --config-name=train_diffusion_lowdim_workspace 5 | """ 6 | 7 | import sys 8 | import os 9 | 10 | script_dir = os.path.dirname(os.path.abspath(__file__)) 11 | project_root = os.path.join(script_dir, "..") 12 | sys.path.append(os.path.join(project_root, "PyriteUtility")) 13 | 14 | # use line-buffering for both stdout and stderr 15 | sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1) 16 | sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1) 17 | 18 | import hydra 19 | from omegaconf import OmegaConf 20 | import pathlib 21 | from diffusion_policy.workspace.base_workspace import BaseWorkspace 22 | 23 | 24 | # allows arbitrary python code execution in configs using the ${eval:''} resolver 25 | OmegaConf.register_new_resolver("eval", eval, replace=True) 26 | 27 | @hydra.main( 28 | version_base=None, 29 | config_path=str(pathlib.Path(__file__).parent.joinpath( 30 | 'diffusion_policy','config')) 31 | ) 32 | def main(cfg: OmegaConf): 33 | # resolve immediately so all the ${now:} resolvers 34 | # will use the same time. 35 | OmegaConf.resolve(cfg) 36 | 37 | cls = hydra.utils.get_class(cfg._target_) 38 | workspace: BaseWorkspace = cls(cfg) 39 | workspace.run() 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/zarr_examination.py: -------------------------------------------------------------------------------- 1 | import zarr 2 | import numpy as np 3 | import sys 4 | import os 5 | import matplotlib.pyplot as plt 6 | from tqdm import tqdm 7 | 8 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__)) 9 | sys.path.append(os.path.join(SCRIPT_PATH, "../../")) 10 | 11 | if "PYRITE_DATASET_FOLDERS" not in os.environ: 12 | raise ValueError("Please set the environment variable PYRITE_DATASET_FOLDERS") 13 | dataset_folder_path = os.environ.get("PYRITE_DATASET_FOLDERS") 14 | 15 | 16 | from PyriteUtility.computer_vision.imagecodecs_numcodecs import ( 17 | register_codecs, 18 | Jpeg2k, 19 | JpegXl, 20 | ) 21 | 22 | register_codecs() 23 | 24 | 25 | 26 | # Config for umift (single robot) 27 | dataset_path = dataset_folder_path + "/online_belt_assembly_50/processed" 28 | id_list = [0] 29 | 30 | # ‘r’ means read only (must exist); ‘r+’ means read/write (must exist); ‘a’ means read/write (create if doesn’t exist); ‘w’ means create (overwrite if exists); ‘w-’ means create (fail if exists). 
31 | buffer = zarr.open(dataset_path, mode="r+") 32 | 33 | for ep, ep_data in buffer["data"].items(): 34 | for key in ep_data.keys(): 35 | print(f"key: {key}, shape: {ep_data[key].shape}, dtype: {ep_data[key].dtype}") 36 | break # Remove this line to iterate through all episodes 37 | 38 | print("All Done!") 39 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/shared_memory/shared_memory_util.py: -------------------------------------------------------------------------------- 1 | # This file is copied from universal_manipulation_interface 2 | # https://github.com/real-stanford/universal_manipulation_interface 3 | # https://umi-gripper.github.io/ 4 | 5 | from typing import Tuple 6 | from dataclasses import dataclass 7 | import numpy as np 8 | from multiprocessing.managers import SharedMemoryManager 9 | from atomics import atomicview, MemoryOrder, UINT 10 | 11 | 12 | @dataclass 13 | class ArraySpec: 14 | name: str 15 | shape: Tuple[int] 16 | dtype: np.dtype 17 | 18 | 19 | class SharedAtomicCounter: 20 | def __init__(self, shm_manager: SharedMemoryManager, size: int = 8): # 64bit int 21 | shm = shm_manager.SharedMemory(size=size) 22 | self.shm = shm 23 | self.size = size 24 | self.store(0) # initialize 25 | 26 | @property 27 | def buf(self): 28 | return self.shm.buf[: self.size] 29 | 30 | def load(self) -> int: 31 | with atomicview(buffer=self.buf, atype=UINT) as a: 32 | value = a.load(order=MemoryOrder.ACQUIRE) 33 | return value 34 | 35 | def store(self, value: int): 36 | with atomicview(buffer=self.buf, atype=UINT) as a: 37 | a.store(value, order=MemoryOrder.RELEASE) 38 | 39 | def add(self, value: int): 40 | with atomicview(buffer=self.buf, atype=UINT) as a: 41 | a.add(value, order=MemoryOrder.ACQ_REL) 42 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/diffusion/conv1d_components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from einops.layers.torch import Rearrange 5 | 6 | 7 | class Downsample1d(nn.Module): 8 | def __init__(self, dim): 9 | super().__init__() 10 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 11 | 12 | def forward(self, x): 13 | return self.conv(x) 14 | 15 | class Upsample1d(nn.Module): 16 | def __init__(self, dim): 17 | super().__init__() 18 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 19 | 20 | def forward(self, x): 21 | return self.conv(x) 22 | 23 | class Conv1dBlock(nn.Module): 24 | ''' 25 | Conv1d --> GroupNorm --> Mish 26 | ''' 27 | 28 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 29 | super().__init__() 30 | 31 | self.block = nn.Sequential( 32 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 33 | # Rearrange('batch channels horizon -> batch channels 1 horizon'), 34 | nn.GroupNorm(n_groups, out_channels), 35 | # Rearrange('batch channels 1 horizon -> batch channels horizon'), 36 | nn.Mish(), 37 | ) 38 | 39 | def forward(self, x): 40 | return self.block(x) 41 | 42 | 43 | def test(): 44 | cb = Conv1dBlock(256, 128, kernel_size=3) 45 | x = torch.zeros((1,256,16)) 46 | o = cb(x) 47 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/common/dict_of_tensor_mixin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 
3 | 4 | class DictOfTensorMixin(nn.Module): 5 | def __init__(self, params_dict=None): 6 | super().__init__() 7 | if params_dict is None: 8 | params_dict = nn.ParameterDict() 9 | self.params_dict = params_dict 10 | 11 | @property 12 | def device(self): 13 | return next(iter(self.parameters())).device 14 | 15 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 16 | def dfs_add(dest, keys, value: torch.Tensor): 17 | if len(keys) == 1: 18 | dest[keys[0]] = value 19 | return 20 | 21 | if keys[0] not in dest: 22 | dest[keys[0]] = nn.ParameterDict() 23 | dfs_add(dest[keys[0]], keys[1:], value) 24 | 25 | def load_dict(state_dict, prefix): 26 | out_dict = nn.ParameterDict() 27 | for key, value in state_dict.items(): 28 | value: torch.Tensor 29 | if key.startswith(prefix): 30 | param_keys = key[len(prefix):].split('.')[1:] 31 | # if len(param_keys) == 0: 32 | # import pdb; pdb.set_trace() 33 | dfs_add(out_dict, param_keys, value.clone()) 34 | return out_dict 35 | 36 | self.params_dict = load_dict(state_dict, prefix + 'params_dict') 37 | self.params_dict.requires_grad_(False) 38 | return 39 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/common/pymunk_util.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import pymunk 3 | import pymunk.pygame_util 4 | import numpy as np 5 | 6 | COLLTYPE_DEFAULT = 0 7 | COLLTYPE_MOUSE = 1 8 | COLLTYPE_BALL = 2 9 | 10 | def get_body_type(static=False): 11 | body_type = pymunk.Body.DYNAMIC 12 | if static: 13 | body_type = pymunk.Body.STATIC 14 | return body_type 15 | 16 | 17 | def create_rectangle(space, 18 | pos_x,pos_y,width,height, 19 | density=3,static=False): 20 | body = pymunk.Body(body_type=get_body_type(static)) 21 | body.position = (pos_x,pos_y) 22 | shape = pymunk.Poly.create_box(body,(width,height)) 23 | shape.density = density 24 | space.add(body,shape) 25 | return body, shape 26 | 27 | 28 | def create_rectangle_bb(space, 29 | left, bottom, right, top, 30 | **kwargs): 31 | pos_x = (left + right) / 2 32 | pos_y = (top + bottom) / 2 33 | height = top - bottom 34 | width = right - left 35 | return create_rectangle(space, pos_x, pos_y, width, height, **kwargs) 36 | 37 | def create_circle(space, pos_x, pos_y, radius, density=3, static=False): 38 | body = pymunk.Body(body_type=get_body_type(static)) 39 | body.position = (pos_x, pos_y) 40 | shape = pymunk.Circle(body, radius=radius) 41 | shape.density = density 42 | shape.collision_type = COLLTYPE_BALL 43 | space.add(body, shape) 44 | return body, shape 45 | 46 | def get_body_state(body): 47 | state = np.zeros(6, dtype=np.float32) 48 | state[:2] = body.position 49 | state[2] = body.angle 50 | state[3:5] = body.velocity 51 | state[5] = body.angular_velocity 52 | return state 53 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/residual/mlp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__)) 5 | sys.path.append(os.path.join(SCRIPT_PATH, "../../../")) 6 | 7 | from typing import Dict 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class MLPResidual(nn.Module): 13 | def __init__(self, input_dim, action_dim, action_horizon, hidden_dims=None, dropout=0.0): 14 | super().__init__() 15 | if hidden_dims is 
None: 16 | hidden_dims = [] 17 | self.input_layer = nn.Linear(input_dim, hidden_dims[0]) 18 | self.hidden_layers = nn.ModuleList() 19 | for i in range(len(hidden_dims)-1): 20 | self.hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1])) 21 | self.action_dim = action_dim 22 | self.action_horizon = action_horizon 23 | self.output_layer = nn.Linear(hidden_dims[-1], action_dim * action_horizon) 24 | self.activation = nn.GELU() 25 | self.dropout = dropout 26 | 27 | def forward(self, x): 28 | h = self.input_layer(x) 29 | h = self.activation(h) 30 | h = F.dropout(h, p=self.dropout, training=self.training) 31 | for layer in self.hidden_layers: 32 | h = layer(h) 33 | h = self.activation(h) 34 | h = F.dropout(h, p=self.dropout, training=self.training) 35 | h = self.output_layer(h) 36 | h = h.reshape(-1, self.action_horizon, self.action_dim) 37 | return h 38 | 39 | def compute_loss(self, x, target): 40 | pred = self.forward(x) 41 | loss = F.mse_loss(pred, target) 42 | return loss 43 | 44 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/real_data_delete_fields.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__)) 5 | sys.path.append(os.path.join(SCRIPT_PATH, "../../")) 6 | 7 | from PyriteUtility.data_pipeline.processing_functions import process_one_episode_into_zarr, generate_meta_for_zarr 8 | 9 | import pathlib 10 | import shutil 11 | import numpy as np 12 | import zarr 13 | import cv2 14 | import concurrent.futures 15 | 16 | # check environment variables 17 | if "PYRITE_RAW_DATASET_FOLDERS" not in os.environ: 18 | raise ValueError("Please set the environment variable PYRITE_RAW_DATASET_FOLDERS") 19 | if "PYRITE_DATASET_FOLDERS" not in os.environ: 20 | raise ValueError("Please set the environment variable PYRITE_DATASET_FOLDERS") 21 | 22 | 23 | # specify the input and output directories 24 | id_list = [0] # single robot 25 | # id_list = [0, 1] # bimanual 26 | 27 | output_dir = pathlib.Path( 28 | os.environ.get("PYRITE_DATASET_FOLDERS") + "/online_stow_nb_v5_50/processed_no_correction/data" 29 | ) 30 | 31 | # clean and create output folders 32 | 33 | files_to_delete = [] 34 | 35 | for episode_name in os.listdir(output_dir): 36 | if episode_name.startswith("."): 37 | continue 38 | 39 | episode_dir = output_dir.joinpath(episode_name) 40 | print("Checking episode: ", episode_dir) 41 | for id in id_list: 42 | files_to_delete.append(episode_dir.joinpath("policy_pose_command_" + str(id))) 43 | files_to_delete.append(episode_dir.joinpath("policy_time_stamps_" + str(id))) 44 | 45 | for file in files_to_delete: 46 | print("Deleting files: ", file) 47 | input("Press Enter to continue...") 48 | for file in files_to_delete: 49 | if os.path.exists(file): 50 | shutil.rmtree(file) 51 | print("All done! 
Deleted files") -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/real_data_just_update_meta.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__)) 5 | sys.path.append(os.path.join(SCRIPT_PATH, "../../")) 6 | 7 | from PyriteUtility.data_pipeline.processing_functions import process_one_episode_into_zarr, generate_meta_for_zarr 8 | 9 | import pathlib 10 | import shutil 11 | import numpy as np 12 | import zarr 13 | import cv2 14 | import concurrent.futures 15 | 16 | CORRECTION = False # set to true if you want to use the correction data 17 | 18 | # check environment variables 19 | if "PYRITE_RAW_DATASET_FOLDERS" not in os.environ: 20 | raise ValueError("Please set the environment variable PYRITE_RAW_DATASET_FOLDERS") 21 | if "PYRITE_DATASET_FOLDERS" not in os.environ: 22 | raise ValueError("Please set the environment variable PYRITE_DATASET_FOLDERS") 23 | 24 | 25 | # specify the input and output directories 26 | id_list = [0] # single robot 27 | # id_list = [0, 1] # bimanual 28 | 29 | output_dir = pathlib.Path( 30 | os.environ.get("PYRITE_DATASET_FOLDERS") + "/belt_assembly_offline_50_total_190" 31 | ) 32 | 33 | # open the zarr store 34 | store = zarr.DirectoryStore(path=output_dir) 35 | root = zarr.open(store=store, mode="a") 36 | 37 | episode_config = { 38 | "output_dir": output_dir, 39 | "id_list": id_list, 40 | "num_threads": 10, 41 | "has_correction": CORRECTION, 42 | "save_video": False, 43 | "max_workers": 32 44 | } 45 | 46 | print("Generating metadata") 47 | from PyriteUtility.computer_vision.imagecodecs_numcodecs import register_codecs 48 | 49 | register_codecs() 50 | 51 | 52 | count = generate_meta_for_zarr(root, episode_config) 53 | print(f"All done! Generated {count} episodes in {output_dir}") 54 | print("The only thing left is to run postprocess_add_virtual_target_label.py") 55 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/bet/libraries/mingpt/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def set_seed(seed): 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed_all(seed) 13 | 14 | 15 | def top_k_logits(logits, k): 16 | v, ix = torch.topk(logits, k) 17 | out = logits.clone() 18 | out[out < v[:, [-1]]] = -float("Inf") 19 | return out 20 | 21 | 22 | @torch.no_grad() 23 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): 24 | """ 25 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 26 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 27 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 28 | of block_size, unlike an RNN that has an infinite context window. 
29 | """ 30 | block_size = model.get_block_size() 31 | model.eval() 32 | for k in range(steps): 33 | x_cond = ( 34 | x if x.size(1) <= block_size else x[:, -block_size:] 35 | ) # crop context if needed 36 | logits, _ = model(x_cond) 37 | # pluck the logits at the final step and scale by temperature 38 | logits = logits[:, -1, :] / temperature 39 | # optionally crop probabilities to only the top k options 40 | if top_k is not None: 41 | logits = top_k_logits(logits, top_k) 42 | # apply softmax to convert to probabilities 43 | probs = F.softmax(logits, dim=-1) 44 | # sample from the distribution or take the most likely 45 | if sample: 46 | ix = torch.multinomial(probs, num_samples=1) 47 | else: 48 | _, ix = torch.topk(probs, k=1, dim=-1) 49 | # append to the sequence and continue 50 | x = torch.cat((x, ix), dim=1) 51 | 52 | return x 53 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/common/checkpoint_util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | import os 3 | 4 | class TopKCheckpointManager: 5 | def __init__(self, 6 | save_dir, 7 | monitor_key: str, 8 | mode='min', 9 | k=1, 10 | format_str='epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt' 11 | ): 12 | assert mode in ['max', 'min'] 13 | assert k >= 0 14 | 15 | self.save_dir = save_dir 16 | self.monitor_key = monitor_key 17 | self.mode = mode 18 | self.k = k 19 | self.format_str = format_str 20 | self.path_value_map = dict() 21 | 22 | def get_ckpt_path(self, data: Dict[str, float]) -> Optional[str]: 23 | if self.k == 0: 24 | return None 25 | 26 | value = data[self.monitor_key] 27 | ckpt_path = os.path.join( 28 | self.save_dir, self.format_str.format(**data)) 29 | 30 | if len(self.path_value_map) < self.k: 31 | # under-capacity 32 | self.path_value_map[ckpt_path] = value 33 | return ckpt_path 34 | 35 | # at capacity 36 | sorted_map = sorted(self.path_value_map.items(), key=lambda x: x[1]) 37 | min_path, min_value = sorted_map[0] 38 | max_path, max_value = sorted_map[-1] 39 | 40 | delete_path = None 41 | if self.mode == 'max': 42 | if value > min_value: 43 | delete_path = min_path 44 | else: 45 | if value < max_value: 46 | delete_path = max_path 47 | 48 | if delete_path is None: 49 | return None 50 | else: 51 | del self.path_value_map[delete_path] 52 | self.path_value_map[ckpt_path] = value 53 | 54 | if not os.path.exists(self.save_dir): 55 | os.mkdir(self.save_dir) 56 | 57 | if os.path.exists(delete_path): 58 | os.remove(delete_path) 59 | return ckpt_path 60 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/bet/action_ae/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | import abc 5 | 6 | from typing import Optional, Union 7 | 8 | import diffusion_policy.model.bet.utils as utils 9 | 10 | 11 | class AbstractActionAE(utils.SaveModule, abc.ABC): 12 | @abc.abstractmethod 13 | def fit_model( 14 | self, 15 | input_dataloader: DataLoader, 16 | eval_dataloader: DataLoader, 17 | obs_encoding_net: Optional[nn.Module] = None, 18 | ) -> None: 19 | pass 20 | 21 | @abc.abstractmethod 22 | def encode_into_latent( 23 | self, 24 | input_action: torch.Tensor, 25 | input_rep: Optional[torch.Tensor], 26 | ) -> torch.Tensor: 27 | """ 28 | Given the input action, discretize it. 
29 | 30 | Inputs: 31 | input_action (shape: ... x action_dim): The input action to discretize. This can be in a batch, 32 | and is generally assumed that the last dimnesion is the action dimension. 33 | 34 | Outputs: 35 | discretized_action (shape: ... x num_tokens): The discretized action. 36 | """ 37 | raise NotImplementedError 38 | 39 | @abc.abstractmethod 40 | def decode_actions( 41 | self, 42 | latent_action_batch: Optional[torch.Tensor], 43 | input_rep_batch: Optional[torch.Tensor] = None, 44 | ) -> torch.Tensor: 45 | """ 46 | Given a discretized action, convert it to a continuous action. 47 | 48 | Inputs: 49 | latent_action_batch (shape: ... x num_tokens): The discretized action 50 | generated by the discretizer. 51 | 52 | Outputs: 53 | continuous_action (shape: ... x action_dim): The continuous action. 54 | """ 55 | raise NotImplementedError 56 | 57 | @property 58 | @abc.abstractmethod 59 | def num_latents(self) -> Union[int, float]: 60 | """ 61 | Number of possible latents for this generator, useful for state priors that use softmax. 62 | """ 63 | return float("inf") 64 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/models/tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from models.base_models.layers import conv2d, CausalConv1D 9 | 10 | 11 | class TestConv2d: 12 | def test_same_shape_no_dilation(self): 13 | x = torch.randn(1, 1, 5, 5) 14 | conv = conv2d(1, 1, 3) 15 | with torch.no_grad(): 16 | out = conv(x) 17 | assert out.shape[2:] == x.shape[2:] 18 | 19 | def test_same_shape_with_dilation(self): 20 | x = torch.randn(1, 1, 5, 5) 21 | conv = conv2d(1, 1, 3, dilation=2) 22 | with torch.no_grad(): 23 | out = conv(x) 24 | assert out.shape[2:] == x.shape[2:] 25 | 26 | 27 | class TestCausalConv1d: 28 | def test_same_shape_no_dilation(self): 29 | x = torch.randn(1, 1, 6) 30 | conv1d = CausalConv1D(1, 1, 3) 31 | with torch.no_grad(): 32 | out = conv1d(x) 33 | assert out.shape[2:] == x.shape[2:] 34 | 35 | def test_same_shape_with_dilation(self): 36 | x = torch.randn(1, 1, 6) 37 | conv1d = CausalConv1D(1, 1, 3, dilation=2) 38 | with torch.no_grad(): 39 | out = conv1d(x) 40 | assert out.shape[2:] == x.shape[2:] 41 | 42 | def test_causality_no_dilation(self): 43 | stride = 1 44 | length = 6 45 | dilation = 1 46 | kernel_size = 3 47 | x = torch.randn(1, 1, length) 48 | conv1d = CausalConv1D(1, 1, kernel_size, stride, dilation, bias=False) 49 | with torch.no_grad(): 50 | actual = conv1d(x) 51 | actual = actual.numpy().squeeze() 52 | weight = conv1d.weight.detach().clone().squeeze().numpy() 53 | padding = (int((kernel_size - 1) * dilation), 0) 54 | padded_x = F.pad(x, padding).detach().squeeze().numpy() 55 | expected = [] 56 | for i in range(length): 57 | expected.append(weight @ padded_x[i : i + 3]) 58 | expected = np.asarray(expected) 59 | assert np.allclose(actual, expected) 60 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/common/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from diffusers.optimization import ( 2 | Union, SchedulerType, Optional, 3 | Optimizer, TYPE_TO_SCHEDULER_FUNCTION 4 | ) 5 | 6 | def get_scheduler( 7 | name: Union[str, SchedulerType], 8 | optimizer: Optimizer, 9 | num_warmup_steps: 
Optional[int] = None, 10 | num_training_steps: Optional[int] = None, 11 | **kwargs 12 | ): 13 | """ 14 | Added kwargs vs diffuser's original implementation 15 | 16 | Unified API to get any scheduler from its name. 17 | 18 | Args: 19 | name (`str` or `SchedulerType`): 20 | The name of the scheduler to use. 21 | optimizer (`torch.optim.Optimizer`): 22 | The optimizer that will be used during training. 23 | num_warmup_steps (`int`, *optional*): 24 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being 25 | optional), the function will raise an error if it's unset and the scheduler type requires it. 26 | num_training_steps (`int``, *optional*): 27 | The number of training steps to do. This is not required by all schedulers (hence the argument being 28 | optional), the function will raise an error if it's unset and the scheduler type requires it. 29 | """ 30 | name = SchedulerType(name) 31 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 32 | if name == SchedulerType.CONSTANT: 33 | return schedule_func(optimizer, **kwargs) 34 | 35 | # All other schedulers require `num_warmup_steps` 36 | if num_warmup_steps is None: 37 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") 38 | 39 | if name == SchedulerType.CONSTANT_WITH_WARMUP: 40 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs) 41 | 42 | # All other schedulers require `num_training_steps` 43 | if num_training_steps is None: 44 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") 45 | 46 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs) 47 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/.gitignore: -------------------------------------------------------------------------------- 1 | # directories 2 | 3 | *.DS_Store* 4 | data/ 5 | tmp/ 6 | **/*.pkl 7 | 8 | logging/* 9 | 10 | *.h5py 11 | sftp-config.json 12 | 13 | # Byte-compiled / optimized / DLL files 14 | **__pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # Vim swap files 19 | *~ 20 | *.swp 21 | *.swo 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | db.sqlite3 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # celery beat schedule file 103 | celerybeat-schedule 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/math/numerical_differentiation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def finite_difference(y, i, dt, order=1): 4 | """ 5 | Compute the finite difference of y at index i with a time step dt. 6 | https://web.media.mit.edu/~crtaylor/calculator.html 7 | 8 | params: 9 | y: list of vectors. [N, D] 10 | i: indices of the points to compute the derivative at. 11 | dt: float 12 | order: order of differentiation. 13 | """ 14 | N = len(y) 15 | 16 | if N < 5: 17 | raise ValueError("The array size must be greater than 5. Got: ", N) 18 | 19 | if order == 0: 20 | return y 21 | elif order == 1: 22 | y_x = (1*y[i-2]-8*y[i-1]+0*y[i+0]+8*y[i+1]-1*y[i+2])/(12*1.0*dt**1) 23 | return y_x 24 | elif order == 2: 25 | y_xx = (-1*y[i-2]+16*y[i-1]-30*y[i+0]+16*y[i+1]-1*y[i+2])/(12*1.0*dt**2) 26 | return y_xx 27 | elif order == 3: 28 | y_xxx = (-1*y[i-2]+2*y[i-1]+0*y[i+0]-2*y[i+1]+1*y[i+2])/(2*1.0*dt**3) 29 | return y_xxx 30 | else: 31 | raise ValueError("Only 0, 1, 2, 3 orders are implemented.") 32 | 33 | 34 | 35 | # wrote a test to check the correctness of the finite_difference function by creating a sine wave and then computing the derivative of the sine wave at a point using the finite_difference function. 
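# Note on the stencils above (added for clarity, not part of the original file):
# order 1 and order 2 use the standard 5-point central differences with O(dt^4)
# truncation error, and order 3 uses the 5-point central difference with O(dt^2)
# error, e.g.
#   y'(x_i)  ~= ( y[i-2] - 8*y[i-1] + 8*y[i+1] - y[i+2]) / (12*dt)
#   y''(x_i) ~= (-y[i-2] + 16*y[i-1] - 30*y[i] + 16*y[i+1] - y[i+2]) / (12*dt**2)
# Valid indices therefore need i-2 >= 0 and i+2 <= len(y)-1. A minimal usage
# sketch:
#   y = np.sin(np.linspace(0, 10, 100))
#   dt = 10.0 / 99.0
#   yd = finite_difference(y, np.arange(2, 98), dt, order=1)  # approximates cos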
36 | def test(): 37 | import matplotlib.pyplot as plt 38 | x = np.linspace(0, 10, 100) 39 | y = np.sin(x) 40 | yd = np.cos(x) 41 | ydd = -np.sin(x) 42 | yddd = -np.cos(x) 43 | 44 | dt = x[1] - x[0] 45 | ids = np.arange(10, 90) 46 | yd_fd = finite_difference(y, ids, dt, 1) 47 | ydd_fd = finite_difference(y, ids, dt, 2) 48 | yddd_fd = finite_difference(y, ids, dt, 3) 49 | 50 | # make a subplot for each of yd, ydd, yddd 51 | fig, axs = plt.subplots(3, 1, layout='constrained') 52 | 53 | 54 | 55 | axs[0].plot(x[ids], yd[ids], label="yd") 56 | axs[0].plot(x[ids], yd_fd, label="yd_fd") 57 | axs[1].plot(x[ids], ydd[ids], label="ydd") 58 | axs[1].plot(x[ids], ydd_fd, label="ydd_fd") 59 | axs[2].plot(x[ids], yddd[ids], label="yddd") 60 | axs[2].plot(x[ids], yddd_fd, label="yddd_fd") 61 | axs[0].legend() 62 | axs[1].legend() 63 | axs[2].legend() 64 | plt.show() 65 | 66 | 67 | 68 | 69 | if __name__ == "__main__": 70 | test() -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | astor==0.7.1 3 | attrs==19.3.0 4 | backcall==0.1.0 5 | bleach==3.1.0 6 | catkin-pkg==0.4.10 7 | certifi==2018.11.29 8 | cffi==1.13.2 9 | chardet==3.0.4 10 | cycler==0.10.0 11 | Cython==0.29.6 12 | DateTime==4.3 13 | decorator==4.3.0 14 | defusedxml==0.5.0 15 | e==1.4.5 16 | entrypoints==0.3 17 | envs==1.3 18 | future==0.17.1 19 | gast==0.2.2 20 | gitdb2==2.0.6 21 | GitPython==3.0.4 22 | glfw==1.9.1 23 | grpcio==1.20.0 24 | gym==0.10.9 25 | h5py==2.9.0 26 | hjson==3.0.1 27 | idna==2.8 28 | imageio==2.6.1 29 | importlib-metadata==1.5.0 30 | ipdb==0.12.2 31 | ipykernel==5.1.0 32 | ipython==7.2.0 33 | ipython-genutils==0.2.0 34 | ipywidgets==7.4.2 35 | jedi==0.13.2 36 | Jinja2==2.10 37 | jsonschema==2.6.0 38 | jupyter==1.0.0 39 | jupyter-client==5.2.4 40 | jupyter-console==6.0.0 41 | jupyter-core==4.4.0 42 | Keras-Applications==1.0.7 43 | Keras-Preprocessing==1.0.9 44 | kiwisolver==1.0.1 45 | lockfile==0.12.2 46 | lxml==4.4.2 47 | Markdown==3.1 48 | MarkupSafe==1.1.0 49 | matplotlib==3.0.2 50 | mistune==0.8.4 51 | mock==2.0.0 52 | more-itertools==8.2.0 53 | mpmath==1.1.0 54 | nbconvert==5.4.0 55 | nbformat==4.4.0 56 | notebook==5.7.4 57 | numpy==1.16.0 58 | packaging==20.1 59 | pandas==0.24.2 60 | pandocfilters==1.4.2 61 | parso==0.3.1 62 | pathlib2==2.3.5 63 | pbr==5.1.3 64 | pexpect==4.6.0 65 | pickleshare==0.7.5 66 | Pillow==5.4.1 67 | pluggy==0.13.1 68 | prometheus-client==0.5.0 69 | prompt-toolkit==2.0.7 70 | protobuf==3.6.1 71 | ptyprocess==0.6.0 72 | py==1.8.1 73 | pycodestyle==2.5.0 74 | pycparser==2.19 75 | pyflakes==2.1.1 76 | pyglet==1.3.2 77 | Pygments==2.3.1 78 | pyparsing==2.3.1 79 | pyquaternion==0.9.5 80 | pytest==5.3.5 81 | python-dateutil==2.7.5 82 | pytz==2018.9 83 | PyYAML==3.13 84 | pyzmq==17.1.2 85 | qtconsole==4.4.3 86 | requests==2.21.0 87 | scipy==1.2.0 88 | seaborn==0.9.0 89 | Send2Trash==1.5.0 90 | six==1.12.0 91 | smmap2==2.0.5 92 | snakeviz==2.0.1 93 | sympy==1.3 94 | tensorboard==1.12.2 95 | tensorboardX==1.6 96 | tensorflow-estimator==1.13.0 97 | termcolor==1.1.0 98 | terminado==0.8.1 99 | testpath==0.4.2 100 | torch==1.1.0 101 | torchsummary==1.5.1 102 | torchvision==0.3.0 103 | tornado==5.1.1 104 | tqdm==4.36.1 105 | traitlets==4.3.2 106 | urllib3==1.24.1 107 | wcwidth==0.1.7 108 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/dataset/base_dataset.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn 5 | from diffusion_policy.model.common.normalizer import LinearNormalizer 6 | 7 | 8 | class BaseDataset(torch.utils.data.Dataset): 9 | def get_validation_dataset(self) -> 'BaseDataset': 10 | # return an empty dataset by default 11 | return BaseDataset() 12 | 13 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 14 | raise NotImplementedError() 15 | 16 | def get_all_actions(self) -> torch.Tensor: 17 | raise NotImplementedError() 18 | 19 | def __len__(self) -> int: 20 | return 0 21 | 22 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 23 | """ 24 | output: 25 | obs: 26 | key: T, * 27 | action: T, Da 28 | """ 29 | raise NotImplementedError() 30 | 31 | 32 | class BaseLowdimDataset(torch.utils.data.Dataset): 33 | def get_validation_dataset(self) -> 'BaseLowdimDataset': 34 | # return an empty dataset by default 35 | return BaseLowdimDataset() 36 | 37 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 38 | raise NotImplementedError() 39 | 40 | def get_all_actions(self) -> torch.Tensor: 41 | raise NotImplementedError() 42 | 43 | def __len__(self) -> int: 44 | return 0 45 | 46 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 47 | """ 48 | output: 49 | obs: T, Do 50 | action: T, Da 51 | """ 52 | raise NotImplementedError() 53 | 54 | 55 | class BaseImageDataset(torch.utils.data.Dataset): 56 | def get_validation_dataset(self) -> 'BaseImageDataset': 57 | # return an empty dataset by default 58 | return BaseImageDataset() 59 | 60 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 61 | raise NotImplementedError() 62 | 63 | def get_all_actions(self) -> torch.Tensor: 64 | raise NotImplementedError() 65 | 66 | def __len__(self) -> int: 67 | return 0 68 | 69 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 70 | """ 71 | output: 72 | obs: 73 | key: T, * 74 | action: T, Da 75 | """ 76 | raise NotImplementedError() 77 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/pytorch_utils/model_io.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__)) 5 | sys.path.append(os.path.join(SCRIPT_PATH, "../../..")) 6 | 7 | import torch 8 | import dill 9 | import hydra 10 | 11 | from PyriteML.diffusion_policy.workspace.base_workspace import BaseWorkspace 12 | from PyriteML.diffusion_policy.workspace.train_diffusion_unet_image_finetune_workspace import ( 13 | TrainDiffusionUnetImageFinetuneWorkspace, 14 | ) 15 | 16 | 17 | def load_policy(ckpt_path, device): 18 | # load checkpoint 19 | if not ckpt_path.endswith(".ckpt"): 20 | ckpt_path = os.path.join(ckpt_path, "checkpoints", "latest.ckpt") 21 | payload = torch.load(open(ckpt_path, "rb"), map_location="cpu", pickle_module=dill) 22 | cfg = payload["cfg"] 23 | # print("model_name:", cfg.policy.obs_encoder.model_name) 24 | # print("dataset_path:", cfg.task.dataset.dataset_path) 25 | 26 | cls = hydra.utils.get_class(cfg._target_) 27 | # TODO(Yifan) Get policy independent from workspace 28 | workspace = cls(cfg) 29 | workspace: BaseWorkspace 30 | workspace.load_payload(payload, exclude_keys=['optimizer'], include_keys=None) 31 | 32 | policy = workspace.model 33 | if cfg.training.use_ema: 34 | policy = workspace.ema_model 35 | policy.num_inference_steps = ( 36 | cfg.policy.num_inference_steps 37 
| ) # DDIM inference iterations 38 | 39 | policy.eval().to(device) 40 | policy.reset() 41 | return policy, cfg.task.shape_meta 42 | 43 | 44 | def serialize_model(ckpt_path): 45 | policy, shape_meta = load_policy(ckpt_path, "cuda") 46 | sm = torch.jit.script(policy) 47 | sm.save(ckpt_path.replace(".ckpt", ".pt")) 48 | 49 | 50 | # testing 51 | class MyModule(torch.nn.Module): 52 | def __init__(self, N, M): 53 | super(MyModule, self).__init__() 54 | self.weight = torch.nn.Parameter(torch.rand(N, M)) 55 | 56 | def forward(self, input): 57 | if input.sum() > 0: 58 | output = self.weight.mv(input) 59 | else: 60 | output = self.weight + input 61 | return output 62 | 63 | def hahaha(self, x): 64 | x = x + 1 65 | return x 66 | 67 | 68 | if __name__ == "__main__": 69 | policy, meta = load_policy( 70 | "/shared_local/training_outputs/test_kl_belt_new/checkpoints/latest.ckpt", "cuda" 71 | ) -------------------------------------------------------------------------------- /PyriteML/online_learning/learner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | import dill 5 | import hydra 6 | import pickle 7 | import time 8 | import numpy as np 9 | import numpy.typing as npt 10 | from torch.utils.data import DataLoader 11 | import robotmq as rmq 12 | 13 | class Learner: 14 | def __init__(self, 15 | network_server_endpoint: str, 16 | network_weight_topic: str, 17 | transitions_server_endpoint: str, 18 | transitions_topic: str, 19 | network_weight_expire_time_s: int): 20 | self.network_weight_server = rmq.RMQServer( 21 | server_name="network_weight_server", server_endpoint=network_server_endpoint 22 | ) 23 | self.network_weight_server.add_topic(network_weight_topic, network_weight_expire_time_s) 24 | print("[Learner] network_weight_server created") 25 | self.transitions_client = rmq.RMQClient( 26 | client_name="transitions_client", server_endpoint=transitions_server_endpoint 27 | ) 28 | print("[Learner] transitions_client created") 29 | 30 | self.network_weight_topic = network_weight_topic 31 | self.transitions_topic = transitions_topic 32 | 33 | ### payloads: {"model": model, "sparse_normalizer": sparse_normalizer} 34 | def send_network_weights(self, payloads: dict): 35 | start_time = time.time() 36 | pickle_data = pickle.dumps(payloads) 37 | dump_end_time = time.time() 38 | self.network_weight_server.put_data(self.network_weight_topic, pickle_data) 39 | send_end_time = time.time() 40 | 41 | print( 42 | f"[Learner] [send_network_weights] Data size: {len(pickle_data) / 1024**2:.3f}MB. dump: {dump_end_time - start_time:.4f}s, send: {send_end_time - dump_end_time: .4f}s)" 43 | ) 44 | 45 | def receive_transitions(self): 46 | retrieve_start_time = time.time() 47 | retrieved_data, timestamp = self.transitions_client.pop_data(topic=self.transitions_topic, n=1) 48 | retrieve_end_time = time.time() 49 | 50 | if retrieved_data: 51 | transitions = pickle.loads(retrieved_data[0]) 52 | 53 | print( 54 | f"[Learner] [receive_transitions] Received data size: {len(retrieved_data[0]) / 1024**2:.3f}MB. 
retrieve: {retrieve_end_time - retrieve_start_time:.4f}s, load: {time.time() - retrieve_end_time:.4f}s)" 55 | ) 56 | return transitions 57 | return None -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/models/models_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from torch.distributions import Normal 5 | 6 | 7 | def init_weights(modules): 8 | """ 9 | Weight initialization from original SensorFusion Code 10 | """ 11 | for m in modules: 12 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 13 | nn.init.kaiming_normal_(m.weight.data) 14 | if m.bias is not None: 15 | m.bias.data.zero_() 16 | elif isinstance(m, nn.BatchNorm2d): 17 | m.weight.data.fill_(1) 18 | m.bias.data.zero_() 19 | 20 | 21 | def sample_gaussian(m, v, device): 22 | 23 | epsilon = Normal(0, 1).sample(m.size()) 24 | z = m + torch.sqrt(v) * epsilon.to(device) 25 | 26 | return z 27 | 28 | 29 | def gaussian_parameters(h, dim=-1): 30 | 31 | m, h = torch.split(h, h.size(dim) // 2, dim=dim) 32 | v = F.softplus(h) + 1e-8 33 | return m, v 34 | 35 | 36 | def product_of_experts(m_vect, v_vect): 37 | 38 | T_vect = 1.0 / v_vect 39 | 40 | mu = (m_vect * T_vect).sum(2) * (1 / T_vect.sum(2)) 41 | var = 1 / T_vect.sum(2) 42 | 43 | return mu, var 44 | 45 | 46 | def duplicate(x, rep): 47 | 48 | return x.expand(rep, *x.shape).reshape(-1, *x.shape[1:]) 49 | 50 | 51 | def depth_deconv(in_planes, out_planes): 52 | return nn.Sequential( 53 | nn.Conv2d( 54 | in_planes, 16, kernel_size=3, stride=1, padding=(3 - 1) // 2, bias=True 55 | ), 56 | nn.LeakyReLU(0.1, inplace=True), 57 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=(3 - 1) // 2, bias=True), 58 | nn.LeakyReLU(0.1, inplace=True), 59 | nn.ConvTranspose2d( 60 | 16, out_planes, kernel_size=4, stride=2, padding=1, bias=True 61 | ), 62 | nn.LeakyReLU(0.1, inplace=True), 63 | ) 64 | 65 | 66 | def rescaleImage(image, output_size=128, scale=1 / 255.0): 67 | """Rescale the image in a sample to a given size. 68 | Args: 69 | output_size (tuple or int): Desired output size. If tuple, output is 70 | matched to output_size. If int, smaller of image edges is matched 71 | to output_size keeping aspect ratio the same. 
72 | """ 73 | image_transform = image * scale 74 | return image_transform.transpose(1, 3).transpose(2, 3) 75 | 76 | 77 | def filter_depth(depth_image): 78 | depth_image = torch.where( 79 | depth_image > 1e-7, depth_image, torch.zeros_like(depth_image) 80 | ) 81 | return torch.where(depth_image < 2, depth_image, torch.zeros_like(depth_image)) 82 | -------------------------------------------------------------------------------- /PyriteML/scripts/test_rmq_client.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from typing import Dict, Callable, Tuple, List 4 | 5 | SCRIPT_PATH = "/home/yifanhou/git/PyriteML/scripts" 6 | sys.path.append(os.path.join(SCRIPT_PATH, '../')) 7 | 8 | 9 | import numpy as np 10 | import torch 11 | import time 12 | import dill 13 | import hydra 14 | import pickle 15 | import time 16 | import numpy as np 17 | import numpy.typing as npt 18 | from torch.utils.data import DataLoader 19 | import robotmq as rmq 20 | 21 | from diffusion_policy.workspace.base_workspace import BaseWorkspace 22 | from diffusion_policy.dataset.base_dataset import BaseImageDataset, BaseDataset 23 | from diffusion_policy.workspace.train_diffusion_unet_image_workspace import TrainDiffusionUnetImageWorkspace 24 | 25 | client = rmq.RMQClient( 26 | client_name="asynchronous_client", server_endpoint="ipc:///tmp/feeds/0" 27 | ) 28 | 29 | # client = rmq.RMQClient("asynchronous_client", "tcp://localhost:5555") 30 | print("Client created") 31 | 32 | while True: 33 | retrieve_start_time = time.time() 34 | retrieved_data, timestamp = client.pop_data(topic="test_checkpoints", n=-1) 35 | retrieve_end_time = time.time() 36 | 37 | if retrieved_data: 38 | received_data = pickle.loads(retrieved_data[0]) 39 | 40 | print( 41 | f"Data size: {len(retrieved_data[0]) / 1024**2:.3f}MB. 
retrieve: {retrieve_end_time - retrieve_start_time:.4f}s, load: {time.time() - retrieve_end_time:.4f}s)" 42 | ) 43 | 44 | statedict = received_data.state_dict() 45 | for key, value in statedict.items(): 46 | print(key, ": ", value.shape) 47 | # # use the received payload 48 | # cfg = received_data['cfg'] 49 | # print("dataset_path:", cfg.task.dataset.dataset_path) 50 | 51 | # cls = hydra.utils.get_class(cfg._target_) 52 | # workspace = cls(cfg) 53 | # workspace: BaseWorkspace 54 | # workspace.load_payload(received_data, exclude_keys=None, include_keys=None) 55 | 56 | # policy = workspace.model 57 | # if cfg.training.use_ema: 58 | # policy = workspace.ema_model 59 | # policy.num_inference_steps = cfg.policy.num_inference_steps # DDIM inference iterations 60 | 61 | # device = torch.device('cpu') 62 | # policy.eval().to(device) 63 | # policy.reset() 64 | 65 | # # use normalizer saved in the policy 66 | # sparse_normalizer, dense_normalizer = policy.get_normalizer() 67 | 68 | # shape_meta = cfg.task.shape_meta 69 | # print("shape_meta:", shape_meta) 70 | break 71 | 72 | print("No data retrieved ...") 73 | time.sleep(0.2) -------------------------------------------------------------------------------- /PyriteML/scripts/test_rmq_server.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from typing import Dict, Callable, Tuple, List 4 | 5 | SCRIPT_PATH = "/home/yifanhou/git/PyriteML/scripts" 6 | sys.path.append(os.path.join(SCRIPT_PATH, '../')) 7 | 8 | 9 | import numpy as np 10 | import torch 11 | import time 12 | import dill 13 | import hydra 14 | import pickle 15 | import time 16 | import numpy as np 17 | import numpy.typing as npt 18 | from torch.utils.data import DataLoader 19 | import robotmq as rmq 20 | 21 | from diffusion_policy.workspace.base_workspace import BaseWorkspace 22 | from diffusion_policy.dataset.base_dataset import BaseImageDataset, BaseDataset 23 | from diffusion_policy.workspace.train_diffusion_unet_image_workspace import TrainDiffusionUnetImageWorkspace 24 | 25 | # data_path = "/home/yifanhou/training_outputs/" 26 | # ckpt_path = data_path + "2025.03.05_21.53.36_stow_no_force_202_stow_80/checkpoints/latest.ckpt" 27 | 28 | data_path = "/shared_local/training_outputs/" 29 | ckpt_path = data_path + "2025.03.07_13.55.19_stow_residual_residual_mlp/checkpoints/latest.ckpt" 30 | 31 | # load checkpoint 32 | if not ckpt_path.endswith('.ckpt'): 33 | ckpt_path = os.path.join(ckpt_path, 'checkpoints', 'latest.ckpt') 34 | residual_payload = torch.load(open(ckpt_path, 'rb'), map_location='cpu', pickle_module=dill) 35 | 36 | residual_cfg = residual_payload["cfg"] 37 | residual_cls = hydra.utils.get_class(residual_cfg._target_) 38 | residual_workspace = residual_cls(residual_cfg) 39 | residual_workspace: BaseWorkspace 40 | residual_workspace.load_payload(residual_payload, exclude_keys=None, include_keys=None) 41 | residual_policy = residual_workspace.model 42 | residual_obs_encoder = residual_workspace.obs_encoder 43 | # residual_policy.eval().to(device) 44 | # residual_obs_encoder.eval().to(device) 45 | # residual_shape_meta = residual_cfg.task.shape_meta 46 | # residual_normalizer = pickle.load(open(os.path.join(checkpoint_folder_path + pipeline_para["residual_ckpt_path"], "sparse_normalizer.pkl"), "rb")) 47 | 48 | 49 | server = rmq.RMQServer( 50 | server_name="test_rmq_server", server_endpoint="ipc:///tmp/feeds/0" 51 | ) 52 | 53 | print("Server created") 54 | 55 | server.add_topic("test_checkpoints", 10) 56 | 57 | 
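# (Added comment) What follows: the loaded residual policy object is serialized
# with pickle and published on the "test_checkpoints" topic via server.put_data().
# The counterpart script, test_rmq_client.py, connects to the same
# "ipc:///tmp/feeds/0" endpoint, pops that topic, unpickles the object, and
# prints the shapes of the tensors in its state_dict().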
input("Topic established. Press Enter to send data...") 58 | 59 | # Serialize the checkpoint 60 | start_time = time.time() 61 | 62 | pickle_data = pickle.dumps(residual_policy) 63 | # pickle_data = pickle.dumps(payload) 64 | dump_end_time = time.time() 65 | server.put_data("test_checkpoints", pickle_data) 66 | send_end_time = time.time() 67 | time.sleep(0.01) 68 | 69 | print( 70 | f"[Server] Data size: {len(pickle_data) / 1024**2:.3f}MB. dump: {dump_end_time - start_time:.4f}s, send: {send_end_time - dump_end_time: .4f}s)" 71 | ) -------------------------------------------------------------------------------- /PyriteML/online_learning/actor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | import dill 5 | import hydra 6 | import pickle 7 | import time 8 | import numpy as np 9 | import numpy.typing as npt 10 | from torch.utils.data import DataLoader 11 | import robotmq as rmq 12 | 13 | # Check and receive res network weights 14 | # send transitions to learner 15 | 16 | class Actor: 17 | def __init__(self, 18 | network_server_endpoint: str, 19 | network_weight_topic: str, 20 | transitions_server_endpoint: str, 21 | transitions_topic: str, 22 | transitions_topic_expire_time_s: int): 23 | self.network_weight_client = rmq.RMQClient( 24 | client_name="network_weight_client", server_endpoint=network_server_endpoint 25 | ) 26 | print("[Actor] network_weight_client created") 27 | 28 | self.transitions_server = rmq.RMQServer( 29 | server_name="transitions_server", server_endpoint=transitions_server_endpoint 30 | ) 31 | self.transitions_server.add_topic(transitions_topic, transitions_topic_expire_time_s) 32 | print("[Actor] transitions_server created") 33 | 34 | self.network_weight_topic = network_weight_topic 35 | self.transitions_topic = transitions_topic 36 | 37 | def receive_network_weights(self, workspace): 38 | retrieve_start_time = time.time() 39 | retrieved_data, timestamp = self.network_weight_client.pop_data(topic=self.network_weight_topic, n=-1) 40 | retrieve_end_time = time.time() 41 | 42 | if retrieved_data: 43 | data = pickle.loads(retrieved_data[0]) 44 | workspace.model = data["model"] 45 | for key, value in data["trainable_obs_encoders"].items(): 46 | workspace.trainable_obs_encoders[key] = value 47 | workspace.sparse_normalizer = data["sparse_normalizer"] 48 | 49 | print( 50 | f"[Actor] [receive_network_weights] Received data size: {len(retrieved_data[0]) / 1024**2:.3f}MB. retrieve: {retrieve_end_time - retrieve_start_time:.4f}s, load: {time.time() - retrieve_end_time:.4f}s)" 51 | ) 52 | return True 53 | return False 54 | 55 | def send_transitions(self, transitions: dict): 56 | start_time = time.time() 57 | pickle_data = pickle.dumps(transitions) 58 | dump_end_time = time.time() 59 | self.transitions_server.put_data(self.transitions_topic, pickle_data) 60 | send_end_time = time.time() 61 | print( 62 | f"[Actor][send_transitions] Data size: {len(pickle_data) / 1024:.3f}KB. 
dump: {dump_end_time - start_time:.4f}s, send: {send_end_time - dump_end_time: .4f}s)" 63 | ) 64 | # time.sleep(1) 65 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/hardware_interface/test_log_video.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.join(sys.path[0], "../..")) # PyriteUtility 5 | 6 | import cv2 7 | import json 8 | import time 9 | import numpy as np 10 | from multiprocessing.managers import SharedMemoryManager 11 | 12 | from PyriteUtility.umi_utils.usb_util import ( 13 | reset_all_elgato_devices, 14 | get_sorted_v4l_paths, 15 | ) 16 | from PyriteUtility.hardware_interface.multi_uvc_camera import MultiUvcCamera 17 | from PyriteUtility.hardware_interface.multi_camera_visualizer import ( 18 | MultiCameraVisualizer, 19 | ) 20 | from PyriteUtility.hardware_interface.video_recorder import VideoRecorder 21 | 22 | save_video_path = "/home/yifanhou/data/experiment_log" 23 | v4l_paths = [ 24 | "/dev/v4l/by-id/usb-Elgato_Cam_Link_4K_A29YB41521ZMA3-video-index0", 25 | "/dev/v4l/by-id/usb-Elgato_Game_Capture_HD60_X_00000001-video-index0", 26 | ] 27 | 28 | fps = 60 29 | resolution = (1920, 1080) 30 | 31 | # resolution = [(3840, 2160), (3840, 2160)] 32 | 33 | 34 | def test(): 35 | 36 | # Find and reset all Elgato capture cards. 37 | # Required to workaround a firmware bug. 38 | reset_all_elgato_devices() 39 | 40 | # Wait for all v4l cameras to be back online 41 | time.sleep(0.1) 42 | 43 | video_recorder = [ 44 | VideoRecorder.create_hevc_nvenc( 45 | fps=fps, input_pix_fmt="bgr24", bit_rate=6000 * 1000 46 | ) 47 | for v in v4l_paths 48 | ] 49 | 50 | with SharedMemoryManager() as shm_manager: 51 | with MultiUvcCamera( 52 | dev_video_paths=v4l_paths, 53 | shm_manager=shm_manager, 54 | resolution=resolution, 55 | capture_fps=fps, 56 | video_recorder=video_recorder, 57 | verbose=False, 58 | ) as camera: 59 | print("Started camera") 60 | with MultiCameraVisualizer( 61 | camera=camera, row=2, col=1, vis_fps=fps, rgb_to_bgr=False 62 | ) as multi_cam_vis: 63 | 64 | cv2.setNumThreads(1) 65 | video_path = save_video_path + f"/{time.strftime('%Y%m%d_%H%M%S')}/" 66 | rec_start_time = time.time() + 1 67 | camera.start_recording(video_path, start_time=rec_start_time) 68 | camera.restart_put(rec_start_time) 69 | time.sleep(1.5) 70 | 71 | print("Recording started") 72 | while True: 73 | time.sleep(0.5) 74 | if time.time() - rec_start_time > 20: 75 | print("----------Time is up!----------") 76 | break 77 | 78 | camera.stop_recording() 79 | camera.stop() 80 | 81 | 82 | if __name__ == "__main__": 83 | test() 84 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/hardware_interface/multi_camera_visualizer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import multiprocessing as mp 3 | import numpy as np 4 | import cv2 5 | from threadpoolctl import threadpool_limits 6 | 7 | 8 | class MultiCameraVisualizer(mp.Process): 9 | def __init__( 10 | self, 11 | camera, 12 | row, 13 | col, 14 | window_name="Multi Cam Vis", 15 | vis_fps=60, 16 | fill_value=0, 17 | rgb_to_bgr=True, 18 | ): 19 | super().__init__() 20 | self.row = row 21 | self.col = col 22 | self.window_name = window_name 23 | self.vis_fps = vis_fps 24 | self.fill_value = fill_value 25 | self.rgb_to_bgr = rgb_to_bgr 26 | self.camera = camera 27 | # shared variables 28 | self.stop_event 
= mp.Event() 29 | self.flag_running = False 30 | 31 | def __enter__(self): 32 | self.start() 33 | return self 34 | 35 | def __exit__(self, exc_type, exc_val, exc_tb): 36 | self.stop() 37 | 38 | def start(self, wait=False): 39 | super().start() 40 | 41 | def stop(self, wait=False): 42 | self.stop_event.set() 43 | if wait: 44 | self.stop_wait() 45 | 46 | def start_wait(self): 47 | pass 48 | 49 | def stop_wait(self): 50 | self.join() 51 | 52 | def run(self): 53 | cv2.setNumThreads(1) 54 | threadpool_limits(1) 55 | channel_slice = slice(None) 56 | if self.rgb_to_bgr: 57 | channel_slice = slice(None, None, -1) 58 | 59 | vis_data = None 60 | vis_img = None 61 | self.flag_running = True 62 | print("Visualizer is running. Press q to quit") 63 | while not self.stop_event.is_set(): 64 | vis_data = self.camera.get_vis(out=vis_data) 65 | color = vis_data["color"] 66 | N, H, W, C = color.shape 67 | assert C == 3 68 | oh = H * self.row 69 | ow = W * self.col 70 | if vis_img is None: 71 | vis_img = np.full( 72 | (oh, ow, 3), fill_value=self.fill_value, dtype=np.uint8 73 | ) 74 | for row in range(self.row): 75 | for col in range(self.col): 76 | idx = col + row * self.col 77 | h_start = H * row 78 | h_end = h_start + H 79 | w_start = W * col 80 | w_end = w_start + W 81 | if idx < N: 82 | # opencv uses bgr 83 | vis_img[h_start:h_end, w_start:w_end] = color[ 84 | idx, :, :, channel_slice 85 | ] 86 | cv2.imshow(self.window_name, vis_img) 87 | cv2.pollKey() 88 | time.sleep(1 / self.vis_fps) 89 | -------------------------------------------------------------------------------- /base_policy.md: -------------------------------------------------------------------------------- 1 | Follow the following steps if you want to collect data and train a base policy by yourself. 2 | 3 | You do not need these steps if you are using an existing checkpoint as the base policy. 4 | 5 | # Data collection 6 | (Requires robot controller setup) 7 | The data collection pipeline is wrapped in `hardware_interfaces/applications/manipulation_data_collection". 8 | 1. Check the correct config file is selected in `hardware_interfaces/applications/manipulation_data_collection/src/main.cc`. 9 | 2. Build `hardware_interfaces` follow its readme. 10 | 3. On the UR teach pendant, make sure you calibrated the TCP mass. 11 | 4. Edit the config file specified in step 1, make sure you have the correct hardware IP/ID, data saving path, etc. 12 | 5. Launch the manipulation_data_collection binary: 13 | ``` sh 14 | cd hardware_interfaces/build 15 | ./applications/manipulation_data_collection/manipulation_data_collection 16 | ``` 17 | Then follow the on screen instructions. 18 | 19 | Our data collection pipeline saves data episode by episode. The saved data folder looks like this: 20 | ``` 21 | current_dataset/ 22 | episode_1727294514 23 | episode_1727294689 24 | episode_1727308394/ 25 | rgb_0 26 | rgb_1/ 27 | img_count_timestamp.jpg 28 | ... 29 | robot_data_0.json 30 | wrench_data_0.json 31 | ... 32 | ``` 33 | Within an episode, each file/folder corresponds to a device. Every frame of data is saved with a timestamp that was calibrated across all devices. For rgb images, its timestamp is saved in its file name, e.g. 34 | ``` 35 | img_000695_29345.186724_ms 36 | ``` 37 | means that this image is the 695th frame saved in this episode, and it is saved at 29345.186724ms since the program launched. 38 | 39 | ## Data postprocessing 40 | We postprocess the data to match the data format for training (zarr). 
41 | ``` python 42 | python cr-dagger/PyriteUtility/data_pipeline/real_data_processing.py 43 | ``` 44 | Specify `id_list`, `input_dir`, `output_dir` then run the script. This script will compress the images into a [zarr](https://zarr.dev/) database, then generate meta data for the whole dataset. This step creates a new folder at `output_dir`. 45 | 46 | # Launch Training 47 | Train a diffusion policy as the base policy. 48 | 49 | 1. Set path to your zarr data in your task config under PyriteML/diffusion_policy/config/task/stow_no_force.yaml 50 | 2. Make sure the above task yaml is being selected in the workspace config at PyriteML/diffusion_policy/config/train_dp_workspace.yaml 51 | 3. Launch training with: 52 | ``` sh 53 | clear; 54 | cd PyriteML; 55 | HYDRA_FULL_ERROR=1 accelerate launch train.py --config-name=train_dp_workspace 56 | ``` 57 | Example of multi-gpu training: 58 | ``` sh 59 | HYDRA_FULL_ERROR=1 accelerate launch --gpu_ids 4,5 --num_processes=2 --main_process_port=28888 train.py --config-name=train_dp_workspace 60 | ``` 61 | -------------------------------------------------------------------------------- /PyriteML/README.md: -------------------------------------------------------------------------------- 1 | # PyriteML 2 | Models and training pipelines. 3 | 4 | ## Install Pyrite Packages 5 | The following is tested on Ubuntu 22.04. 6 | 7 | 1. Install [mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) 8 | 2. Clone the pyrite repos. 9 | ``` sh 10 | git clone git@github.com:yifan-hou/PyriteEnvSuites.git 11 | git clone git@github.com:yifan-hou/PyriteConfig.git 12 | git clone git@github.com:yifan-hou/PyriteUtility.git 13 | git clone --recursive git@github.com:yifan-hou/PyriteML.git 14 | ``` 15 | 3. Create a virtual env called `pyrite`: 16 | ``` sh 17 | cd PyriteML 18 | # Note 1: If env create gets stuck, you can create an empty environment, then install pytorch/torchvision/torchaudio following official pytorch installation instructions, then install the rest via mamba. 19 | # Note 2: zarr 3 changed many interfaces and does not work for PyriteML. We recommend to use zarr 2.18 20 | mamba env create -f conda_environment.yaml 21 | # after finish, activate it using 22 | mamba activate pyrite 23 | # a few pip installs 24 | pip install v4l2py 25 | pip install toppra 26 | pip install atomics 27 | pip install vit-pytorch # Need at least 1.7.12, which was not available in conda 28 | pip install imagecodecs # Need at least 2023.9.18, which caused lots of conflicts in conda 29 | 30 | # Install local packages 31 | cd PyriteUtilities 32 | pip install -e . 33 | ``` 34 | 4. Setup environment variables: add the following to your .bashrc or .zshrc, edit according to your local path. 35 | ``` sh 36 | # where the collected raw data folders are 37 | export PYRITE_RAW_DATASET_FOLDERS=$HOME/data/real 38 | # where the post-processed data folders are 39 | export PYRITE_DATASET_FOLDERS=$HOME/data/real_processed 40 | # Each training session will create a folder here. 41 | export PYRITE_CHECKPOINT_FOLDERS=$HOME/training_outputs 42 | # Hardware configs. 43 | export PYRITE_HARDWARE_CONFIG_FOLDERS=$HOME/hardware_interfaces/workcell/ur_test_bench/config 44 | # Logging folder. 45 | export PYRITE_CONTROL_LOG_FOLDERS=$HOME/data/control_log 46 | ``` 47 | 48 | ## Train a diffusion policy 49 | Train a diffusion policy as the base policy. 50 | 51 | 1. Set path to your zarr data in your task config under PyriteML/diffusion_policy/config/task/stow_no_force.yaml 52 | 2. 
Make sure the above task yaml is being selected in the workspace config at PyriteML/diffusion_policy/config/train_dp_workspace.yaml 53 | 3. Launch training with: 54 | ``` sh 55 | clear; 56 | cd PyriteML; 57 | HYDRA_FULL_ERROR=1 accelerate launch train.py --config-name=train_dp_workspace 58 | ``` 59 | Example of multi-gpu training: 60 | ``` sh 61 | HYDRA_FULL_ERROR=1 accelerate launch --gpu_ids 4,5 --num_processes=2 --main_process_port=28888 train.py --config-name=train_dp_workspace 62 | ``` 63 | 64 | 65 | 66 | 67 | ## 🏷️ License 68 | This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details. 69 | 70 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/bet/latent_generators/latent_generator.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | from typing import Tuple, Optional 4 | 5 | import diffusion_policy.model.bet.utils as utils 6 | 7 | 8 | class AbstractLatentGenerator(abc.ABC, utils.SaveModule): 9 | """ 10 | Abstract class for a generative model that can generate latents given observation representations. 11 | 12 | In the probabilisitc sense, this model fits and samples from P(latent|observation) given some observation. 13 | """ 14 | 15 | @abc.abstractmethod 16 | def get_latent_and_loss( 17 | self, 18 | obs_rep: torch.Tensor, 19 | target_latents: torch.Tensor, 20 | seq_masks: Optional[torch.Tensor] = None, 21 | ) -> Tuple[torch.Tensor, torch.Tensor]: 22 | """ 23 | Given a set of observation representation and generated latents, get the encoded latent and the loss. 24 | 25 | Inputs: 26 | input_action: Batch of the actions taken in the multimodal demonstrations. 27 | target_latents: Batch of the latents that the generator should learn to generate the actions from. 28 | seq_masks: Batch of masks that indicate which timesteps are valid. 29 | 30 | Outputs: 31 | latent: The sampled latent from the observation. 32 | loss: The loss of the latent generator. 33 | """ 34 | pass 35 | 36 | @abc.abstractmethod 37 | def generate_latents( 38 | self, seq_obses: torch.Tensor, seq_masks: torch.Tensor 39 | ) -> torch.Tensor: 40 | """ 41 | Given a batch of sequences of observations, generate a batch of sequences of latents. 42 | 43 | Inputs: 44 | seq_obses: Batch of sequences of observations, of shape seq x batch x dim, following the transformer convention. 45 | seq_masks: Batch of sequences of masks, of shape seq x batch, following the transformer convention. 46 | 47 | Outputs: 48 | seq_latents: Batch of sequences of latents of shape seq x batch x latent_dim. 49 | """ 50 | pass 51 | 52 | def get_optimizer( 53 | self, weight_decay: float, learning_rate: float, betas: Tuple[float, float] 54 | ) -> torch.optim.Optimizer: 55 | """ 56 | Default optimizer class. Override this if you want to use a different optimizer. 
57 | """ 58 | return torch.optim.Adam( 59 | self.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=betas 60 | ) 61 | 62 | 63 | class LatentGeneratorDataParallel(torch.nn.DataParallel): 64 | def get_latent_and_loss(self, *args, **kwargs): 65 | return self.module.get_latent_and_loss(*args, **kwargs) # type: ignore 66 | 67 | def generate_latents(self, *args, **kwargs): 68 | return self.module.generate_latents(*args, **kwargs) # type: ignore 69 | 70 | def get_optimizer(self, *args, **kwargs): 71 | return self.module.get_optimizer(*args, **kwargs) # type: ignore 72 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/postprocessing_add_virtual_target_label.py: -------------------------------------------------------------------------------- 1 | # This script does the following: 2 | # 1. Compute the virtual target pose and stiffness based on the force/torque sensor data, add them to the zarr file 3 | # 2. Optionally plot the virtual target and the target in 3D space 4 | # 3. If ts_pose_command is not available, the script will populate it with ts_pose_fb. 5 | 6 | import zarr 7 | import numpy as np 8 | import sys 9 | import os 10 | import matplotlib.pyplot as plt 11 | from tqdm import tqdm 12 | 13 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__)) 14 | sys.path.append(os.path.join(SCRIPT_PATH, "../../")) 15 | 16 | import concurrent.futures 17 | 18 | from PyriteUtility.data_pipeline.processing_functions import compute_vt_for_episode 19 | 20 | if "PYRITE_DATASET_FOLDERS" not in os.environ: 21 | raise ValueError("Please set the environment variable PYRITE_DATASET_FOLDERS") 22 | dataset_folder_path = os.environ.get("PYRITE_DATASET_FOLDERS") 23 | 24 | # Config for umift (single robot) 25 | dataset_path = dataset_folder_path + "/online_belt_assembly_50/processed" 26 | id_list = [0] 27 | 28 | # # Config for vase wiping (bimanual) 29 | # dataset_path = dataset_folder_path + "/vase_wiping_v5.2/" 30 | # id_list = [0, 1] 31 | 32 | # ‘r’ means read only (must exist); ‘r+’ means read/write (must exist); ‘a’ means read/write (create if doesn’t exist); ‘w’ means create (overwrite if exists); ‘w-’ means create (fail if exists). 
33 | buffer = zarr.open(dataset_path, mode="r+") 34 | 35 | stiffness_estimation_paras = { 36 | "k_max": 5000, # 1cm 50N 37 | "k_min": 200, # 1cm 2.5N 38 | "f_low": 0.5, 39 | "f_high": 5, 40 | "dim": 3, 41 | "characteristic_length": 0.02, 42 | } 43 | 44 | vt_config = { 45 | "stiffness_estimation_para": stiffness_estimation_paras, 46 | "wrench_moving_average_window_size": 500, # should be around 1s of data, 47 | "flag_real": True, # False for simulation data 48 | "num_of_process": 1, # 5 49 | "flag_plot": False, 50 | "fin_every_n": 50, # 50 51 | "id_list": id_list, 52 | } 53 | 54 | if vt_config["flag_plot"]: 55 | assert vt_config["num_of_process"] == 1, "Plotting is not supported for multi-process" 56 | 57 | 58 | if vt_config["num_of_process"] == 1: 59 | for ep, ep_data in tqdm(buffer["data"].items(), desc="Episodes"): 60 | compute_vt_for_episode(ep, ep_data, vt_config) 61 | else: 62 | with concurrent.futures.ProcessPoolExecutor(max_workers=vt_config["num_of_process"]) as executor: 63 | futures = [ 64 | executor.submit( 65 | compute_vt_for_episode, 66 | ep, 67 | ep_data, 68 | vt_config, 69 | ) 70 | for ep, ep_data in tqdm(buffer["data"].items(), desc="Episodes") 71 | ] 72 | for future in concurrent.futures.as_completed(futures): 73 | if not future.result(): 74 | raise RuntimeError("Multi-processing failed!") 75 | 76 | 77 | print("All Done!") 78 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/common/pytorch_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, List 2 | import collections 3 | import torch 4 | import torch.nn as nn 5 | 6 | def dict_apply( 7 | x: Dict[str, torch.Tensor], 8 | func: Callable[[torch.Tensor], torch.Tensor] 9 | ) -> Dict[str, torch.Tensor]: 10 | result = dict() 11 | for key, value in x.items(): 12 | if isinstance(value, dict): 13 | result[key] = dict_apply(value, func) 14 | else: 15 | result[key] = func(value) 16 | return result 17 | 18 | def pad_remaining_dims(x, target): 19 | assert x.shape == target.shape[:len(x.shape)] 20 | return x.reshape(x.shape + (1,)*(len(target.shape) - len(x.shape))) 21 | 22 | def dict_apply_split( 23 | x: Dict[str, torch.Tensor], 24 | split_func: Callable[[torch.Tensor], Dict[str, torch.Tensor]] 25 | ) -> Dict[str, torch.Tensor]: 26 | results = collections.defaultdict(dict) 27 | for key, value in x.items(): 28 | result = split_func(value) 29 | for k, v in result.items(): 30 | results[k][key] = v 31 | return results 32 | 33 | def dict_apply_reduce( 34 | x: List[Dict[str, torch.Tensor]], 35 | reduce_func: Callable[[List[torch.Tensor]], torch.Tensor] 36 | ) -> Dict[str, torch.Tensor]: 37 | result = dict() 38 | for key in x[0].keys(): 39 | result[key] = reduce_func([x_[key] for x_ in x]) 40 | return result 41 | 42 | 43 | def replace_submodules( 44 | root_module: nn.Module, 45 | predicate: Callable[[nn.Module], bool], 46 | func: Callable[[nn.Module], nn.Module]) -> nn.Module: 47 | """ 48 | predicate: Return true if the module is to be replaced. 49 | func: Return new module to use. 
50 | """ 51 | if predicate(root_module): 52 | return func(root_module) 53 | 54 | bn_list = [k.split('.') for k, m 55 | in root_module.named_modules(remove_duplicate=True) 56 | if predicate(m)] 57 | for *parent, k in bn_list: 58 | parent_module = root_module 59 | if len(parent) > 0: 60 | parent_module = root_module.get_submodule('.'.join(parent)) 61 | if isinstance(parent_module, nn.Sequential): 62 | src_module = parent_module[int(k)] 63 | else: 64 | src_module = getattr(parent_module, k) 65 | tgt_module = func(src_module) 66 | if isinstance(parent_module, nn.Sequential): 67 | parent_module[int(k)] = tgt_module 68 | else: 69 | setattr(parent_module, k, tgt_module) 70 | # verify that all BN are replaced 71 | bn_list = [k.split('.') for k, m 72 | in root_module.named_modules(remove_duplicate=True) 73 | if predicate(m)] 74 | assert len(bn_list) == 0 75 | return root_module 76 | 77 | def optimizer_to(optimizer, device): 78 | for state in optimizer.state.values(): 79 | for k, v in state.items(): 80 | if isinstance(v, torch.Tensor): 81 | state[k] = v.to(device=device) 82 | return optimizer 83 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/computer_vision/computer_vision_utility.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import math 3 | import cv2 4 | import numpy as np 5 | 6 | def get_image_transform_with_border(input_res, output_res, bgr_to_rgb: bool=False): 7 | """ adds a border to make the input image square, and then resizes it to the output resolution """ 8 | iw, ih = input_res 9 | interp_method = cv2.INTER_AREA 10 | 11 | # Determine the size of the square 12 | size = max(iw, ih) 13 | top = (size - ih) // 2 14 | bottom = size - ih - top 15 | left = (size - iw) // 2 16 | right = size - iw - left 17 | 18 | def transform(img: np.ndarray): 19 | assert img.shape == (ih, iw, 3) 20 | # Add border to make the image square 21 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0]) 22 | # Resize 23 | 24 | if img.dtype == np.float16: 25 | img = img.astype(np.float32) 26 | # print("DEBUG img type:", type(img)) 27 | # print("DEBUG img shape:", getattr(img, 'shape', 'NO SHAPE')) 28 | 29 | 30 | img = cv2.resize(img, output_res, interpolation=interp_method) 31 | if img.dtype == np.float32: 32 | img = img.astype(np.float16) 33 | if bgr_to_rgb: 34 | img = img[:, :, ::-1] 35 | return img 36 | 37 | return transform 38 | 39 | def get_image_transform( 40 | input_res: Tuple[int,int]=(1280,720), 41 | output_res: Tuple[int,int]=(640,480), 42 | bgr_to_rgb: bool=False): 43 | 44 | iw, ih = input_res 45 | ow, oh = output_res 46 | rw, rh = None, None 47 | interp_method = cv2.INTER_AREA 48 | 49 | if (iw/ih) >= (ow/oh): 50 | # input is wider 51 | rh = oh 52 | rw = math.ceil(rh / ih * iw) 53 | if oh > ih: 54 | interp_method = cv2.INTER_LINEAR 55 | else: 56 | rw = ow 57 | rh = math.ceil(rw / iw * ih) 58 | if ow > iw: 59 | interp_method = cv2.INTER_LINEAR 60 | 61 | w_slice_start = (rw - ow) // 2 62 | w_slice = slice(w_slice_start, w_slice_start + ow) 63 | h_slice_start = (rh - oh) // 2 64 | h_slice = slice(h_slice_start, h_slice_start + oh) 65 | c_slice = slice(None) 66 | if bgr_to_rgb: 67 | c_slice = slice(None, None, -1) 68 | 69 | def transform(img: np.ndarray): 70 | assert img.shape == ((ih,iw,3)) 71 | # resize 72 | img = cv2.resize(img, (rw, rh), interpolation=interp_method) 73 | # crop 74 | img = img[h_slice, w_slice, c_slice] 75 | return img 
76 | return transform 77 | 78 | 79 | # Apply mask to image. 80 | # img: [h,w,3] image 81 | # mask_polygon_vertices: [n,2] mask polygon vertices 82 | def apply_polygon_mask(img: np.ndarray, mask_polygon_vertices: np.ndarray, color: Tuple[int,int,int]=(0,0,0)): 83 | mask_pts = np.array(mask_polygon_vertices, dtype=np.int32) # Extract mask points 84 | mask = np.ones_like(img, dtype=np.uint8) * 255 85 | cv2.fillPoly(mask, [mask_pts], color) 86 | # apply the mask to the images 87 | img_masked = cv2.bitwise_and(img, mask) 88 | return img_masked -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/umi_utils/usb_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from subprocess import Popen, PIPE, DEVNULL 3 | import fcntl 4 | import pathlib 5 | 6 | 7 | def create_usb_list(): 8 | device_list = list() 9 | lsusb_out = Popen('lsusb -v', shell=True, bufsize=64, 10 | stdin=PIPE, stdout=PIPE, stderr=DEVNULL, 11 | close_fds=True).stdout.read().strip().decode('utf-8') 12 | usb_devices = lsusb_out.split('%s%s' % (os.linesep, os.linesep)) 13 | for device_categories in usb_devices: 14 | if not device_categories: 15 | continue 16 | categories = device_categories.split(os.linesep) 17 | device_stuff = categories[0].strip().split() 18 | bus = device_stuff[1] 19 | device = device_stuff[3][:-1] 20 | device_dict = {'bus': bus, 'device': device} 21 | device_info = ' '.join(device_stuff[6:]) 22 | device_dict['description'] = device_info 23 | for category in categories: 24 | if not category: 25 | continue 26 | categoryinfo = category.strip().split() 27 | if categoryinfo[0] == 'iManufacturer': 28 | manufacturer_info = ' '.join(categoryinfo[2:]) 29 | device_dict['manufacturer'] = manufacturer_info 30 | if categoryinfo[0] == 'iProduct': 31 | device_info = ' '.join(categoryinfo[2:]) 32 | device_dict['device'] = device_info 33 | path = '/dev/bus/usb/%s/%s' % (bus, device) 34 | device_dict['path'] = path 35 | 36 | device_list.append(device_dict) 37 | return device_list 38 | 39 | def reset_usb_device(dev_path): 40 | USBDEVFS_RESET = 21780 41 | try: 42 | f = open(dev_path, 'w', os.O_WRONLY) 43 | fcntl.ioctl(f, USBDEVFS_RESET, 0) 44 | print('Successfully reset %s' % dev_path) 45 | except PermissionError as ex: 46 | raise PermissionError('Try running "sudo chmod 777 {}"'.format(dev_path)) 47 | 48 | def reset_all_elgato_devices(): 49 | """ 50 | Find and reset all Elgato capture cards. 51 | Required to workaround a firmware bug. 
52 | """ 53 | 54 | # enumerate UBS device to find Elgato Capture Card 55 | device_list = create_usb_list() 56 | 57 | for dev in device_list: 58 | if 'Elgato' in dev['description']: 59 | dev_usb_path = dev['path'] 60 | print("Resetting Elgato device at %s" % dev_usb_path) 61 | reset_usb_device(dev_usb_path) 62 | 63 | def get_sorted_v4l_paths(by_id=True): 64 | """ 65 | If by_id, sort devices by device name + serial number (preserves device order) 66 | else, sort devices by usb bus id (preserves usb port order) 67 | """ 68 | 69 | dirname = 'by-id' 70 | if not by_id: 71 | dirname = 'by-path' 72 | v4l_dir = pathlib.Path('/dev/v4l').joinpath(dirname) 73 | 74 | valid_paths = list() 75 | for dev_path in sorted(v4l_dir.glob("*video*")): 76 | name = dev_path.name 77 | 78 | # only keep devices ends with "index0" 79 | # since they are the only valid video devices 80 | index_str = name.split('-')[-1] 81 | assert index_str.startswith('index') 82 | index = int(index_str[5:]) 83 | if index == 0: 84 | valid_paths.append(dev_path) 85 | 86 | result = [str(x.absolute()) for x in valid_paths] 87 | 88 | return result 89 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/vision/force_spec_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from scipy import signal 4 | import numpy as np 5 | 6 | 7 | def convert_to_spec(data, fmin=0, fmax=200, sample_rate=7000, pad_spec=False): 8 | specs = [] 9 | for i in range(data.shape[-1]): 10 | f, t, Sxx = signal.spectrogram( 11 | data[..., i].detach().cpu().numpy(), 12 | fs=sample_rate, 13 | nperseg=512, 14 | noverlap=512 // 4, 15 | nfft=1024, 16 | ) 17 | freq_slice = np.where((f >= fmin) & (f <= fmax)) 18 | # keep only frequencies of interest 19 | f = f[freq_slice] 20 | Sxx = Sxx[:, freq_slice, :] 21 | 22 | if pad_spec: 23 | padded_values = Sxx[:, :, :, 0] 24 | Sxx = np.concatenate([padded_values[:, :, :, None], Sxx], axis=-1) 25 | specs.append(Sxx) 26 | specs = np.concatenate(specs, axis=1) 27 | specs = torch.tensor(specs).float().to(data.device) # (B, 6, 30, 17/18) 28 | return specs 29 | 30 | 31 | class CoordConv(nn.Module): 32 | """Add coordinates in [0,1] to an image, like CoordConv paper.""" 33 | 34 | def forward(self, x): 35 | # needs N,C,H,W inputs 36 | assert x.ndim == 4 37 | h, w = x.shape[2:] 38 | ones_h = x.new_ones((h, 1)) 39 | type_dev = dict(dtype=x.dtype, device=x.device) 40 | lin_h = torch.linspace(-1, 1, h, **type_dev)[:, None] 41 | ones_w = x.new_ones((1, w)) 42 | lin_w = torch.linspace(-1, 1, w, **type_dev)[None, :] 43 | new_maps_2d = torch.stack((lin_h * ones_w, lin_w * ones_h), dim=0) 44 | new_maps_4d = new_maps_2d[None] 45 | assert new_maps_4d.shape == (1, 2, h, w), (x.shape, new_maps_4d.shape) 46 | batch_size = x.size(0) 47 | new_maps_4d_batch = new_maps_4d.repeat(batch_size, 1, 1, 1) 48 | result = torch.cat((x, new_maps_4d_batch), dim=1) 49 | return result 50 | 51 | 52 | class ForceSpecEncoder(nn.Module): 53 | def __init__(self, model, norm_spec): 54 | super().__init__() 55 | # modify model to take in 8 channels 56 | model[0] = nn.Conv2d( 57 | 8, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 58 | ) 59 | self.model = model 60 | self.coord_conv = CoordConv() # similar as positional encoding 61 | self.norm_spec = norm_spec 62 | 63 | def forward(self, spec): 64 | EPS = 1e-8 65 | # spec: B x C(6) x F x T 66 | log_spec = torch.log(spec + EPS) 67 | x = log_spec 68 | if 
self.norm_spec.is_norm: 69 | x = (x - self.norm_spec.min) / (self.norm_spec.max - self.norm_spec.min) 70 | x = x * 2 - 1 71 | 72 | x = self.coord_conv(x) 73 | # x: B x C(8) x F x T 74 | x = self.model(x) 75 | return x 76 | 77 | 78 | class ForceSpecTransformer(nn.Module): 79 | def __init__(self, model, norm_spec): 80 | super().__init__() 81 | self.model = model 82 | self.norm_spec = norm_spec 83 | 84 | def forward(self, spec): 85 | EPS = 1e-8 86 | # spec: B x C(6) x F x T 87 | log_spec = torch.log(spec + EPS) 88 | x = log_spec 89 | if self.norm_spec.is_norm: 90 | x = (x - self.norm_spec.min) / (self.norm_spec.max - self.norm_spec.min) 91 | x = x * 2 - 1 92 | 93 | x = self.model(x) 94 | return x 95 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/planning_control/shooting_sampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import time 4 | 5 | from PyriteUtility.planning_control.tree import Tree 6 | from PyriteUtility.spatial_math import spatial_utilities as su 7 | 8 | from PyriteGenesis.envs.genesis_base_task_env import State, Action 9 | 10 | class ShootingSampling(): 11 | """ 12 | Base class for shooting sampling 13 | 14 | """ 15 | def __init__(self, config, batch_simulate_func, validify_action_func, log_folder_path): 16 | self.config = config 17 | self.batch_simulate_func = batch_simulate_func 18 | self.log_folder_path = log_folder_path 19 | 20 | def rand_vel_sample(self, n_vel_samples): 21 | dim = self.config["action_dim"] 22 | 23 | nominal_vel = np.zeros(dim) 24 | 25 | robot_nominal_vel = nominal_vel[:-4] # Exclude the gripper joint 26 | gripper_vel = 0 27 | # Sample from a Gaussian distribution 28 | gripper_delta_action = np.random.normal(loc=robot_nominal_vel, scale=0.1, size=(n_vel_samples, 1, dim-4)) 29 | 30 | # normalize the delta_action 31 | norm = np.linalg.norm(gripper_delta_action, axis=2, keepdims=True) 32 | norm[norm == 0] = 1 # avoid division by zero 33 | gripper_delta_action = gripper_delta_action / norm * self.config["action_mag_per_step"] 34 | 35 | return np.concatenate([gripper_delta_action, gripper_vel * np.ones((n_vel_samples, 1, 4))], axis=2) 36 | 37 | def update_action(self, action_samples, costs): 38 | # pick the action corresponding to the minimum cost 39 | min_cost_index = np.argmin(costs) 40 | action = action_samples[min_cost_index] 41 | return action, min_cost_index 42 | 43 | def solve(self, 44 | pose7_items, # (N_item, 7) 45 | robot_state, # (N, D) 46 | is_grasp, # (N,) 47 | target_item_id, # (N,) 48 | pose7_WGoal_all): 49 | 50 | N = len(robot_state) 51 | n_vel_samples = np.floor(self.config["num_envs"] / N).astype(int) 52 | 53 | action_sample = self.rand_vel_sample(n_vel_samples) # (ns, 1, d) 54 | states_all, actions_all, env_safety_num_frames, pose7_items_stabilized, target_item_ids_all = self.batch_simulate_func( 55 | pose7_items=pose7_items, 56 | robot_state=robot_state, 57 | batch_actions=action_sample, 58 | is_grasp = is_grasp, 59 | target_item_id=target_item_id, 60 | pose7_WGoal_all=pose7_WGoal_all 61 | ) 62 | states_list = states_all.extract_to_list(env_safety_num_frames, 3) 63 | actions_list = actions_all.extract_to_list(env_safety_num_frames, 3) 64 | pose7_items_stabilized = list(pose7_items_stabilized[env_safety_num_frames >= 3]) 65 | target_item_ids = list(target_item_ids_all[env_safety_num_frames >= 3]) 66 | # data_log.append({ 67 | # "states": states_list, 68 | # "actions": actions_list, 69 | # }) 70 
| return states_list, actions_list, pose7_items_stabilized, target_item_ids 71 | 72 | if __name__ == "__main__": 73 | # Example usage 74 | config = { 75 | "action_horizon": 20, 76 | "action_dim": 7, 77 | "n_samples": 2 #2048, 78 | } 79 | 80 | predictive_sampling = ShootingSampling(config) 81 | 82 | sample = predictive_sampling.rand_vel_sample(None) 83 | print("Sampled actions:", sample) 84 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/diffusion/ema_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch.nn.modules.batchnorm import _BatchNorm 4 | 5 | class EMAModel: 6 | """ 7 | Exponential Moving Average of models weights 8 | """ 9 | 10 | def __init__( 11 | self, 12 | model, 13 | update_after_step=0, 14 | inv_gamma=1.0, 15 | power=2 / 3, 16 | min_value=0.0, 17 | max_value=0.9999 18 | ): 19 | """ 20 | @crowsonkb's notes on EMA Warmup: 21 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan 22 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), 23 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 24 | at 215.4k steps). 25 | Args: 26 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 27 | power (float): Exponential factor of EMA warmup. Default: 2/3. 28 | min_value (float): The minimum EMA decay rate. Default: 0. 29 | """ 30 | 31 | self.averaged_model = model 32 | self.averaged_model.eval() 33 | self.averaged_model.requires_grad_(False) 34 | 35 | self.update_after_step = update_after_step 36 | self.inv_gamma = inv_gamma 37 | self.power = power 38 | self.min_value = min_value 39 | self.max_value = max_value 40 | 41 | self.decay = 0.0 42 | self.optimization_step = 0 43 | 44 | def get_decay(self, optimization_step): 45 | """ 46 | Compute the decay factor for the exponential moving average. 47 | """ 48 | step = max(0, optimization_step - self.update_after_step - 1) 49 | value = 1 - (1 + step / self.inv_gamma) ** -self.power 50 | 51 | if step <= 0: 52 | return 0.0 53 | 54 | return max(self.min_value, min(value, self.max_value)) 55 | 56 | @torch.no_grad() 57 | def step(self, new_model): 58 | self.decay = self.get_decay(self.optimization_step) 59 | 60 | # old_all_dataptrs = set() 61 | # for param in new_model.parameters(): 62 | # data_ptr = param.data_ptr() 63 | # if data_ptr != 0: 64 | # old_all_dataptrs.add(data_ptr) 65 | 66 | all_dataptrs = set() 67 | for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()): 68 | for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)): 69 | # iterative over immediate parameters only. 70 | if isinstance(param, dict): 71 | raise RuntimeError('Dict parameter not supported') 72 | 73 | # data_ptr = param.data_ptr() 74 | # if data_ptr != 0: 75 | # all_dataptrs.add(data_ptr) 76 | 77 | if isinstance(module, _BatchNorm): 78 | # skip batchnorms 79 | ema_param.copy_(param.to(dtype=ema_param.dtype).data) 80 | elif not param.requires_grad: 81 | ema_param.copy_(param.to(dtype=ema_param.dtype).data) 82 | else: 83 | ema_param.mul_(self.decay) 84 | ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) 85 | 86 | # verify that iterating over module and then parameters is identical to parameters recursively. 
87 | # assert old_all_dataptrs == all_dataptrs 88 | self.optimization_step += 1 89 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/models/base_models/layers.py: -------------------------------------------------------------------------------- 1 | """Neural network layers. 2 | """ 3 | 4 | import torch.nn as nn 5 | 6 | 7 | def crop_like(input, target): 8 | if input.size()[2:] == target.size()[2:]: 9 | return input 10 | else: 11 | return input[:, :, : target.size(2), : target.size(3)] 12 | 13 | 14 | def deconv(in_planes, out_planes): 15 | return nn.Sequential( 16 | nn.ConvTranspose2d( 17 | in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False 18 | ), 19 | nn.LeakyReLU(0.1, inplace=True), 20 | ) 21 | 22 | 23 | def predict_flow(in_planes): 24 | return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=False) 25 | 26 | 27 | def conv2d(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 28 | """`same` convolution with LeakyReLU, i.e. output shape equals input shape. 29 | Args: 30 | in_planes (int): The number of input feature maps. 31 | out_planes (int): The number of output feature maps. 32 | kernel_size (int): The filter size. 33 | dilation (int): The filter dilation factor. 34 | stride (int): The filter stride. 35 | """ 36 | # compute new filter size after dilation 37 | # and necessary padding for `same` output size 38 | dilated_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size 39 | same_padding = (dilated_kernel_size - 1) // 2 40 | 41 | return nn.Sequential( 42 | nn.Conv2d( 43 | in_channels, 44 | out_channels, 45 | kernel_size=kernel_size, 46 | stride=stride, 47 | padding=same_padding, 48 | dilation=dilation, 49 | bias=bias, 50 | ), 51 | nn.LeakyReLU(0.1, inplace=True), 52 | ) 53 | 54 | 55 | class View(nn.Module): 56 | def __init__(self, size): 57 | super(View, self).__init__() 58 | self.size = size 59 | 60 | def forward(self, tensor): 61 | return tensor.view(self.size) 62 | 63 | 64 | class Flatten(nn.Module): 65 | """Flattens convolutional feature maps for fc layers. 66 | """ 67 | 68 | def __init__(self): 69 | super().__init__() 70 | 71 | def forward(self, x): 72 | return x.reshape(x.size(0), -1) 73 | 74 | 75 | class CausalConv1D(nn.Conv1d): 76 | """A causal 1D convolution. 77 | """ 78 | 79 | def __init__( 80 | self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True 81 | ): 82 | self.__padding = (kernel_size - 1) * dilation 83 | 84 | super().__init__( 85 | in_channels, 86 | out_channels, 87 | kernel_size=kernel_size, 88 | stride=stride, 89 | padding=self.__padding, 90 | dilation=dilation, 91 | bias=bias, 92 | ) 93 | 94 | def forward(self, x): 95 | res = super().forward(x) 96 | if self.__padding != 0: 97 | return res[:, :, : -self.__padding] 98 | return res 99 | 100 | 101 | class ResidualBlock(nn.Module): 102 | """A simple residual block. 
103 | """ 104 | 105 | def __init__(self, channels): 106 | super().__init__() 107 | 108 | self.conv1 = conv2d(channels, channels, bias=False) 109 | self.conv2 = conv2d(channels, channels, bias=False) 110 | self.bn1 = nn.BatchNorm2d(channels) 111 | self.bn2 = nn.BatchNorm2d(channels) 112 | self.act = nn.LeakyReLU(0.1, inplace=True) # nn.ReLU(inplace=True) 113 | 114 | def forward(self, x): 115 | out = self.act(x) 116 | out = self.act(self.bn1(self.conv1(out))) 117 | out = self.bn2(self.conv2(out)) 118 | return out + x 119 | -------------------------------------------------------------------------------- /PyriteML/online_learning/configs/config_v1.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | if "PYRITE_HARDWARE_CONFIG_FOLDERS" not in os.environ: 4 | raise ValueError( 5 | "Please set the environment variable PYRITE_HARDWARE_CONFIG_FOLDERS" 6 | ) 7 | if "PYRITE_DATASET_FOLDERS" not in os.environ: 8 | raise ValueError("Please set the environment variable PYRITE_DATASET_FOLDERS") 9 | if "PYRITE_CONTROL_LOG_FOLDERS" not in os.environ: 10 | raise ValueError("Please set the environment variable PYRITE_CONTROL_LOG_FOLDERS") 11 | 12 | hardware_config_folder_path = os.environ.get("PYRITE_HARDWARE_CONFIG_FOLDERS") 13 | data_folder_path = os.environ.get("PYRITE_DATASET_FOLDERS") 14 | 15 | run_learner_on_server = False 16 | 17 | control_para = { 18 | "raw_time_step_s": 0.002, # dt of raw data collection. Used to compute time step from time_s such that the downsampling according to shape_meta works. 19 | "slow_down_factor": 1, # set to 2 to slow down the execution by 2x. Does not affect data saving rate 20 | "sparse_execution_horizon": 12, # execution horizon for the base policy 21 | "delay_tolerance_s": 0.3, # delay larger than this will trigger termination 22 | "max_duration_s": 3500, # actor will quit after this long 23 | "test_nominal_target": False, 24 | "translational_stiffness": [1000, 1000, 1000], 25 | "rotational_stiffness": 70, # or 25 for belt task 26 | "send_transitions_to_server": True, # if False, actor will not send robot data to the learner server. Useful for evaluation 27 | "no_visual_mode": False, 28 | "device": "cuda", 29 | # below are debugging options. Disabled by default 30 | "fix_orientation": False, 31 | "scale_and_cap_residual_action": False, 32 | "residual_action_scale_ratio": 1.0, # ratio of residual to nominal action 33 | } 34 | 35 | hardware_para = { 36 | "hardware_config_path": hardware_config_folder_path + "/belt_assembly.yaml", 37 | } 38 | 39 | learner_para = { 40 | # if residual_ckpt_path is specified and points to a valid checkpoint, the residual learner will load the checkpoint 41 | # Note that currently the new checkpoints will still be saved to a new folder 42 | "residual_ckpt_path": None, 43 | # "residual_ckpt_path": "/2025.05.13_14.21.08_belt_residual_no_base_action_online_residual_mlp", 44 | 45 | "num_episodes_before_first_training": 50, # start training after a good number of episodes collected 46 | 47 | # Below are parameters for multi-batch training. 48 | # These features are not used in the paper. 49 | "num_of_initial_episodes": 0, # first n episodes, both correction and no correction data are used 50 | "num_of_new_episodes": 0, # last n episodes to sample 50% from. 
set to 0 to disable 51 | } 52 | 53 | # online_learning_para needs to be the same across learner and actor 54 | online_learning_para = { 55 | "data_folder_path": data_folder_path + "/online_belt/", # where to save online correction data 56 | "policy_workspace_config_name": "train_online_residual_mlp_workspace", # workspace config name for the residual policy 57 | "transformers": False, 58 | "network_weight_topic": "network_weights_topic", 59 | "transitions_topic": "transitions_topic", 60 | "network_weight_expire_time_s": 3500, # time after which the actor-learner communication is considered lost 61 | "transitions_topic_expire_time_s": 3500, # time after which the actor-learner communication is considered lost 62 | } 63 | 64 | if run_learner_on_server: 65 | online_learning_para["network_server_endpoint"] = "tcp://localhost:18889" 66 | online_learning_para["transitions_server_endpoint"] = "tcp://localhost:18888" 67 | else: 68 | # IPC is a lot faster than tcp 69 | online_learning_para["network_server_endpoint"] = "ipc:///tmp/feeds/2" 70 | online_learning_para["transitions_server_endpoint"] = "ipc:///tmp/feeds/3" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/diffusion/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | import einops 6 | from einops.layers.torch import Rearrange 7 | from collections import OrderedDict 8 | 9 | from diffusion_policy.model.diffusion.conv1d_components import ( 10 | Downsample1d, 11 | Upsample1d, 12 | Conv1dBlock, 13 | ) 14 | from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class MLP(nn.Module): 20 | def __init__( 21 | self, 22 | in_channels, 23 | out_channels, 24 | hidden_channels, # e.g [256, 128] 25 | ): 26 | super().__init__() 27 | 28 | # TO try: 29 | # 1. residual implementation with feedforward 30 | # 2. experiment with layers 2~8 31 | # 3. 
experiment with hidden layer dimension 512~2048 32 | self.blocks = nn.Sequential( 33 | OrderedDict( 34 | [ 35 | ("dense1", nn.Linear(in_channels, hidden_channels[0])), 36 | ("act1", nn.ReLU()), 37 | ("dense2", nn.Linear(hidden_channels[0], hidden_channels[1])), 38 | ("act2", nn.ReLU()), 39 | ("output", nn.Linear(hidden_channels[1], out_channels)), 40 | ] 41 | ) 42 | ) 43 | 44 | def forward( 45 | self, 46 | x, # (B, :) 47 | ): 48 | """ 49 | returns: 50 | out : [ batch_size x action_dimension ] 51 | """ 52 | out = self.blocks(x) 53 | return out 54 | 55 | 56 | class MLP_conditioned_with_time_encoding(nn.Module): 57 | def __init__( 58 | self, 59 | in_channels, 60 | out_channels, 61 | action_horizon, # used for time encoding 62 | ): 63 | super().__init__() 64 | 65 | # get encoder of time stamp 66 | dsed = in_channels 67 | time_encoder = nn.Sequential( 68 | SinusoidalPosEmb(dsed, max_value=action_horizon), 69 | nn.Linear(dsed, dsed * 4), 70 | nn.Mish(), 71 | nn.Linear(dsed * 4, dsed), 72 | ) 73 | 74 | # TO try: 75 | # 1. residual implementation with feedforward 76 | # 2. experiment with layers 2~8 77 | # 3. experiment with hidden layer dimension 512~2048 78 | self.blocks = nn.Sequential( 79 | OrderedDict( 80 | [ 81 | ("dense1", nn.Linear(in_channels, 256)), 82 | ("act1", nn.ReLU()), 83 | ("dense2", nn.Linear(256, 256)), 84 | ("act2", nn.ReLU()), 85 | ("output", nn.Linear(256, out_channels)), 86 | ] 87 | ) 88 | ) 89 | 90 | self.time_encoder = time_encoder 91 | 92 | def forward( 93 | self, 94 | x, # (B, :) 95 | cond, # (B, :) 96 | timestep, # (1) 97 | ): 98 | """ 99 | returns: 100 | out : [ batch_size x action_dimension ] 101 | """ 102 | # 1. time 103 | timesteps = timestep 104 | if not torch.is_tensor(timesteps): 105 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 106 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device) 107 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 108 | timesteps = timesteps[None].to(x.device) 109 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 110 | timesteps = timesteps.expand(x.shape[0]) 111 | time_encoding = self.time_encoder(timesteps) # (B, diffusion_step_embed_dim) 112 | 113 | global_feature = torch.cat([x, cond], axis=-1) 114 | global_feature = global_feature + time_encoding 115 | 116 | out = self.blocks(global_feature) 117 | return out 118 | -------------------------------------------------------------------------------- /PyriteUtility/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | example_demo_session 3 | data 4 | data_local 5 | data_workspace 6 | outputs 7 | wandb 8 | **/.DS_Store 9 | *.lprof 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/#use-with-ide 120 | .pdm.toml 121 | 122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # add diffusion_policy/env/ 142 | !diffusion_policy/env/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | # PyCharm 169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 171 | # and can be added to the global gitignore or merged into this file. For a more nuclear 172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
173 | #.idea/
174 | 
--------------------------------------------------------------------------------
/PyriteUtility/PyriteUtility/data_pipeline/real_data_check_keys.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | 
4 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__))
5 | sys.path.append(os.path.join(SCRIPT_PATH, "../../"))
6 | 
7 | import pathlib
8 | import numpy as np
9 | import zarr
10 | import cv2
11 | import concurrent.futures
12 | 
13 | CORRECTION = True  # set to True if you want to use the correction data
14 | 
15 | # check environment variables
16 | if "PYRITE_RAW_DATASET_FOLDERS" not in os.environ:
17 |     raise ValueError("Please set the environment variable PYRITE_RAW_DATASET_FOLDERS")
18 | if "PYRITE_DATASET_FOLDERS" not in os.environ:
19 |     raise ValueError("Please set the environment variable PYRITE_DATASET_FOLDERS")
20 | 
21 | 
22 | # specify the dataset directory to check
23 | id_list = [0]  # single robot
24 | # id_list = [0, 1]  # bimanual
25 | 
26 | ft_sensor_configuration = "handle_on_robot"  # "handle_on_sensor" or "handle_on_robot"
27 | 
28 | output_dir = pathlib.Path(
29 |     os.environ.get("PYRITE_DATASET_FOLDERS") + "/online_belt_v2/raw"
30 | )
31 | 
32 | # open the zarr store
33 | store = zarr.DirectoryStore(path=output_dir)
34 | root = zarr.open(store=store, mode="a")
35 | 
36 | print("Reading data from: ", output_dir)
37 | episode_names = os.listdir(output_dir)
38 | 
39 | episode_config = {
40 |     "output_dir": output_dir,
41 |     "id_list": id_list,
42 |     "ft_sensor_configuration": ft_sensor_configuration,
43 |     "num_threads": 1,
44 |     "has_correction": CORRECTION,
45 |     "save_video": False,
46 |     "max_workers": 32
47 | }
48 | 
49 | 
50 | import pandas as pd
51 | 
52 | def check_keys(episode_name, output_root, config):
53 |     if episode_name.startswith("."):
54 |         return True
55 | 
56 |     # print(f"[process_one_episode_into_zarr] episode_name: {episode_name}, episode_id: {episode_id}")
57 |     episode_dir = pathlib.Path(config["output_dir"]).joinpath(episode_name)
58 |     try:
59 |         json_path = episode_dir.joinpath("key_data.json")
60 |         df_key_data = pd.read_json(json_path)
61 |         data_key = np.vstack(df_key_data["key_event"]).flatten()
62 |         # Check if the data_key is empty
63 |         if data_key.size == 0:
64 |             print(f"\033[31mError: No key data found for {episode_dir}\033[0m.")
65 |             return False
66 |         if len(data_key) == 1 and data_key[0] == -1:
67 |             print(f"\033[31mError: there's no correction for {episode_dir}\033[0m.")
68 |             print(data_key)
69 |             return False
70 | 
71 |         for i in range(0, len(data_key)-1, 2):
72 |             if (data_key[i] != 1 and data_key[i] != -1) or (data_key[i + 1] != 0 and data_key[i + 1] != -1):
73 |                 print(f"\033[31mError: Key data is not in the correct format for {episode_dir}\033[0m.")
74 |                 print(data_key)
75 |                 return False
76 | 
77 |     except FileNotFoundError:
78 |         # Handle the case where the file does not exist
79 |         print(f"\033[31mError: JSON file not found for {episode_dir}\033[0m.")
80 |     except ValueError as e:
81 |         # Handle the case where the JSON is invalid
82 |         print(f"\033[31mError: Invalid JSON format - {e} - for {episode_dir}\033[0m.")
83 |     except Exception as e:
84 |         # Handle other exceptions that might occur
85 |         print(f"\033[31mAn unexpected error occurred for {episode_dir}: {e}\033[0m.")
86 |     else:
87 |         # This block executes if no exception occurs, i.e. the episode passed the checks.
88 |         # Returning True keeps the executor below from treating a clean episode as a failure.
89 |         return True
90 | 
91 | 
92 | with concurrent.futures.ProcessPoolExecutor(max_workers=1) as
executor: 93 | futures = [ 94 | executor.submit( 95 | check_keys, 96 | episode_name, 97 | root, 98 | episode_config, 99 | ) 100 | for episode_name in sorted(episode_names) 101 | ] 102 | for future in concurrent.futures.as_completed(futures): 103 | if not future.result(): 104 | raise RuntimeError("Multi-processing failed!") 105 | 106 | print("Finished reading.") 107 | 108 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/indexing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_sample_ids(query, horizon, down_sample_steps, backwards=False, closed=False): 5 | """ 6 | Get the sample ids for the down-sampling of the data. 7 | :param query: (Q,) The query id(s). 8 | :param horizon: Integer scalar. The horizon H of the data. 9 | :param down_sample_steps:Integer scalar. The number of steps to down-sample. 10 | :param backwards: Boolean. Whether to sample backwards from the query point. 11 | :param closed: Boolean. Whether to include the closed end. 12 | 13 | :return: The sampled ids (Q, H) 14 | """ 15 | if isinstance(query, list): 16 | query = np.array(query) 17 | 18 | if closed: 19 | sample_horizon = horizon + 1 20 | else: 21 | sample_horizon = horizon 22 | 23 | local_id = np.arange(sample_horizon) * down_sample_steps 24 | 25 | if not backwards: 26 | return [query_i + local_id for query_i in query] 27 | else: 28 | local_id = local_id[::-1] 29 | return [query_i - local_id for query_i in query] 30 | 31 | 32 | def get_samples( 33 | raw_data, query, horizon, down_sample_steps, backwards=False, closed=False 34 | ): 35 | """ 36 | Get the sample ids for the down-sampling of the data. 37 | :param raw_data: (T, D) The raw data. D could have multiple dimnsions. 38 | For other parameters, see get_sample_ids. 39 | 40 | :return: The sampled data (Q, H, ...). 41 | """ 42 | return raw_data[ 43 | get_sample_ids(query, horizon, down_sample_steps, backwards, closed) 44 | ] 45 | 46 | 47 | def get_dense_query_points_in_horizon( 48 | sparse_total_steps_per_horizon, 49 | dense_action_horizon, 50 | dense_action_down_sample_steps, 51 | delta_steps, 52 | ): 53 | """ 54 | Get the dense query points in the horizon. The queries, which are indices in the original 55 | raw array, is also used as time steps of dense queries. 56 | Adjacent query points are delta_steps apart. 57 | 58 | Example when delta_steps = 3, sparse total steps = 20, dense total steps = 6: 59 | sparse raw steps: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 60 | dense queries: * - - * - - * - - * - - * - - 61 | dense horizons 0 1 2 3 4 5 62 | 0 1 2 3 4 5 63 | 0 1 2 3 4 5 64 | 0 1 2 3 4 5 65 | 0 1 2 3 4 5 66 | num of queries = 5, since the 6th query is out of the horizon. 67 | 68 | :param delta_steps: Integer scalar. The number of steps between two adjacent query points. 
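    :return: (num_of_queries,) integer array of query indices into the raw data,
        starting at 0 and spaced delta_steps apart. num_of_queries is chosen so
        that the last dense horizon still fits inside the sparse horizon.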
69 | """ 70 | 71 | sparse_raw_steps = sparse_total_steps_per_horizon 72 | dense_raw_steps = dense_action_horizon * dense_action_down_sample_steps 73 | num_of_queries = (sparse_raw_steps - dense_raw_steps) // delta_steps + 1 74 | queries = np.arange(num_of_queries) * delta_steps 75 | return queries 76 | 77 | 78 | def test(): 79 | # test 80 | raw_data = 10 * np.arange(20) 81 | query = [10, 12] 82 | horizon = 3 83 | down_sample_steps = 2 84 | 85 | print("== Test get_samples ==") 86 | s = get_samples( 87 | raw_data, query, horizon, down_sample_steps, backwards=True, closed=False 88 | ) 89 | print("raw_data:", raw_data) 90 | print("query:", query) 91 | print("horizon:", horizon) 92 | print("down_sample_steps:", down_sample_steps) 93 | 94 | print(s) 95 | 96 | print("== Test get_dense_query_points_in_horizon ==") 97 | sparse_action_horizon = 5 98 | sparse_action_down_sample_steps = 4 99 | dense_action_horizon = 2 100 | dense_action_down_sample_steps = 3 101 | delta_steps = 3 102 | queries = get_dense_query_points_in_horizon( 103 | sparse_action_horizon, 104 | sparse_action_down_sample_steps, 105 | dense_action_horizon, 106 | dense_action_down_sample_steps, 107 | delta_steps, 108 | ) 109 | print(queries) 110 | 111 | 112 | if __name__ == "__main__": 113 | test() 114 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/common/json_logger.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, Any, Sequence 2 | import os 3 | import copy 4 | import json 5 | import numbers 6 | import pandas as pd 7 | 8 | 9 | def read_json_log(path: str, 10 | required_keys: Sequence[str]=tuple(), 11 | **kwargs) -> pd.DataFrame: 12 | """ 13 | Read json-per-line file, with potentially incomplete lines. 14 | kwargs passed to pd.read_json 15 | """ 16 | lines = list() 17 | with open(path, 'r') as f: 18 | while True: 19 | # one json per line 20 | line = f.readline() 21 | if len(line) == 0: 22 | # EOF 23 | break 24 | elif not line.endswith('\n'): 25 | # incomplete line 26 | break 27 | is_relevant = False 28 | for k in required_keys: 29 | if k in line: 30 | is_relevant = True 31 | break 32 | if is_relevant: 33 | lines.append(line) 34 | if len(lines) < 1: 35 | return pd.DataFrame() 36 | json_buf = f'[{",".join([line for line in (line.strip() for line in lines) if line])}]' 37 | df = pd.read_json(json_buf, **kwargs) 38 | return df 39 | 40 | class JsonLogger: 41 | def __init__(self, path: str, 42 | filter_fn: Optional[Callable[[str,Any],bool]]=None): 43 | if filter_fn is None: 44 | filter_fn = lambda k,v: isinstance(v, numbers.Number) 45 | 46 | # default to append mode 47 | self.path = path 48 | self.filter_fn = filter_fn 49 | self.file = None 50 | self.last_log = None 51 | 52 | def start(self): 53 | # use line buffering 54 | try: 55 | self.file = file = open(self.path, 'r+', buffering=1) 56 | except FileNotFoundError: 57 | self.file = file = open(self.path, 'w+', buffering=1) 58 | 59 | # Move the pointer (similar to a cursor in a text editor) to the end of the file 60 | pos = file.seek(0, os.SEEK_END) 61 | 62 | # Read each character in the file one at a time from the last 63 | # character going backwards, searching for a newline character 64 | # If we find a new line, exit the search 65 | while pos > 0 and file.read(1) != "\n": 66 | pos -= 1 67 | file.seek(pos, os.SEEK_SET) 68 | # now the file pointer is at one past the last '\n' 69 | # and pos is at the last '\n'. 
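        # last_line_end below marks the position just past the final complete
        # (newline-terminated) line; anything after it is an incomplete trailing
        # line, which gets truncated further below before new logs are appended.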
70 | last_line_end = file.tell() 71 | 72 | # find the start of second last line 73 | pos = max(0, pos-1) 74 | file.seek(pos, os.SEEK_SET) 75 | while pos > 0 and file.read(1) != "\n": 76 | pos -= 1 77 | file.seek(pos, os.SEEK_SET) 78 | # now the file pointer is at one past the second last '\n' 79 | last_line_start = file.tell() 80 | 81 | if last_line_start < last_line_end: 82 | # has last line of json 83 | last_line = file.readline() 84 | self.last_log = json.loads(last_line) 85 | 86 | # remove the last incomplete line 87 | file.seek(last_line_end) 88 | file.truncate() 89 | 90 | def stop(self): 91 | self.file.close() 92 | self.file = None 93 | 94 | def __enter__(self): 95 | self.start() 96 | return self 97 | 98 | def __exit__(self, exc_type, exc_val, exc_tb): 99 | self.stop() 100 | 101 | def log(self, data: dict): 102 | filtered_data = dict( 103 | filter(lambda x: self.filter_fn(*x), data.items())) 104 | # save current as last log 105 | self.last_log = filtered_data 106 | for k, v in filtered_data.items(): 107 | if isinstance(v, numbers.Integral): 108 | filtered_data[k] = int(v) 109 | elif isinstance(v, numbers.Number): 110 | filtered_data[k] = float(v) 111 | buf = json.dumps(filtered_data) 112 | # ensure one line per json 113 | buf = buf.replace('\n','') + '\n' 114 | self.file.write(buf) 115 | 116 | def get_last_log(self): 117 | return copy.deepcopy(self.last_log) 118 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/plot_correction_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zarr 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import cv2 6 | import pandas as pd 7 | import PyriteUtility.spatial_math.spatial_utilities as su 8 | 9 | def load_episode_data(episode_folder): 10 | """Load data from a given episode folder.""" 11 | episode_data = {} 12 | 13 | # Load policy inference data 14 | policy_inference_path = os.path.join(episode_folder, "policy_inference.zarr") 15 | if os.path.exists(policy_inference_path): 16 | policy_data = zarr.open(policy_inference_path, mode='r') 17 | episode_data['ts_targets'] = policy_data[f'ts_targets_0'][:].reshape(-1, 7) 18 | episode_data['timestamps'] = policy_data['timestamps_s'][:].reshape(-1) 19 | 20 | # Load robot data 21 | robot_data_path = os.path.join(episode_folder, "robot_data_0.json") 22 | rgb_folder = os.path.join(episode_folder, "rgb_0") 23 | 24 | if os.path.exists(robot_data_path): 25 | episode_data['robot_data'] = pd.read_json(robot_data_path) 26 | 27 | if os.path.exists(rgb_folder): 28 | episode_data['rgb_files'] = sorted([os.path.join(rgb_folder, f) for f in os.listdir(rgb_folder) if f.endswith('.jpg') or f.endswith('.png')]) 29 | 30 | return episode_data 31 | 32 | def create_video_from_rgb(rgb_files, output_dir="", fps=30): 33 | """Create a video from RGB images.""" 34 | if not rgb_files: 35 | print("No RGB images found.") 36 | return 37 | 38 | frame = cv2.imread(rgb_files[0]) 39 | height, width, _ = frame.shape 40 | 41 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 42 | video_writer = cv2.VideoWriter(os.path.join(output_dir, "video.mp4"), fourcc, fps, (width, height)) 43 | 44 | for file in rgb_files: 45 | frame = cv2.imread(file) 46 | video_writer.write(frame) 47 | 48 | video_writer.release() 49 | 50 | def plot_policy_vs_robot_data(robot_data, ts_targets, timestamps, output_dir=""): 51 | """Plot policy inference data with robot data for comparison, aligned by timestamps.""" 52 | 53 | # 
Align timestamps 54 | robot_timestamps = robot_data['robot_time_stamps'].to_numpy()/ 1000.0 55 | predicted_ts_id = np.searchsorted(timestamps, robot_timestamps) 56 | Npolicy = len(timestamps) 57 | predicted_ts_id = np.minimum(predicted_ts_id, Npolicy - 1) 58 | 59 | robot_data = np.vstack(robot_data['ts_pose_fb']) 60 | predicted_data = ts_targets[predicted_ts_id] 61 | 62 | # transform from pose7 to pose3 63 | feedback_pose = su.pose7_to_SE3(robot_data) 64 | predicted_pose = su.pose7_to_SE3(predicted_data) 65 | # for i, column in enumerate(['x', 'y', 'z']): # Assuming first 3 columns are position data 66 | # plt.plot(timestamps, aligned_robot_data[column], label=f'Robot {column}', linestyle='dashed') 67 | # plt.plot(timestamps, ts_targets[:, i], label=f'Policy {column}', linestyle='solid') 68 | 69 | fig = plt.figure() 70 | ax = plt.axes(projection='3d') 71 | ax.scatter(feedback_pose[:, 0], feedback_pose[:, 1], feedback_pose[:, 2], label='Robot Position', color='red', s=1) 72 | ax.scatter(predicted_pose[:, 0], predicted_pose[:, 1], predicted_pose[:, 2], label='Policy Position', color='blue', s=1) 73 | 74 | plt.title("Policy Inference vs Robot Data") 75 | plt.legend() 76 | plt.grid() 77 | plt.savefig(os.path.join(output_dir, "policy_vs_robot.png")) 78 | 79 | def visualize_episode(episode_folder): 80 | """Load and visualize an episode.""" 81 | print(f"Loading episode from {episode_folder}") 82 | data = load_episode_data(episode_folder) 83 | 84 | if 'robot_data' in data and 'ts_targets' in data and 'timestamps' in data: 85 | plot_policy_vs_robot_data(data['robot_data'], data['ts_targets'], data['timestamps'], output_dir='.') 86 | 87 | if 'rgb_files' in data: 88 | create_video_from_rgb(data['rgb_files'], output_dir='.') 89 | 90 | if __name__ == "__main__": 91 | dataset_folder = '/shared_local/data/raw/correction_new' 92 | # episode_folders = [os.path.join(dataset_folder, ep) for ep in os.listdir(dataset_folder)] 93 | episode_folders = [os.path.join(dataset_folder, ep) for ep in sorted(os.listdir(dataset_folder))[:1]] 94 | 95 | for ep in episode_folders: 96 | visualize_episode(ep) 97 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/real_data_check_timing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__)) 5 | sys.path.append(os.path.join(SCRIPT_PATH, "../../")) 6 | 7 | from PyriteUtility.data_pipeline.processing_functions import process_one_episode_into_zarr, generate_meta_for_zarr 8 | 9 | import pathlib 10 | import shutil 11 | import numpy as np 12 | import zarr 13 | import cv2 14 | import concurrent.futures 15 | 16 | CORRECTION = False # set to true if you want to use the correction data 17 | 18 | # check environment variables 19 | if "PYRITE_RAW_DATASET_FOLDERS" not in os.environ: 20 | raise ValueError("Please set the environment variable PYRITE_RAW_DATASET_FOLDERS") 21 | if "PYRITE_DATASET_FOLDERS" not in os.environ: 22 | raise ValueError("Please set the environment variable PYRITE_DATASET_FOLDERS") 23 | 24 | 25 | # specify the input and output directories 26 | id_list = [0] # single robot 27 | # id_list = [0, 1] # bimanual 28 | 29 | ft_sensor_configuration = "handle_on_robot" # "handle_on_sensor" or "handle_on_robot" 30 | 31 | input_dir = pathlib.Path( 32 | os.environ.get("PYRITE_RAW_DATASET_FOLDERS") + "/belt_assembly_v1" 33 | ) 34 | output_dir = pathlib.Path( 35 | 
    os.environ.get("PYRITE_DATASET_FOLDERS") + "/belt_assembly_v1"
36 | )
37 | 
38 | # open the zarr store
39 | store = zarr.DirectoryStore(path=output_dir)
40 | root = zarr.open(store=store, mode="a")
41 | 
42 | print("Reading data from input_dir: ", input_dir)
43 | episode_names = os.listdir(input_dir)
44 | 
45 | episode_config = {
46 |     "input_dir": input_dir,
47 |     "output_dir": output_dir,
48 |     "id_list": id_list,
49 |     "ft_sensor_configuration": ft_sensor_configuration,
50 |     "num_threads": 1,
51 |     "has_correction": CORRECTION,
52 |     "save_video": False,
53 |     "max_workers": 32
54 | }
55 | 
56 | 
57 | import pandas as pd
58 | 
59 | def check_timing(episode_name, output_root, config):
60 |     if episode_name.startswith("."):
61 |         return True
62 | 
63 |     # info about input
64 |     episode_id = episode_name[8:]
65 |     # print(f"[process_one_episode_into_zarr] episode_name: {episode_name}, episode_id: {episode_id}")
66 |     episode_dir = pathlib.Path(config["input_dir"]).joinpath(episode_name)
67 | 
68 |     # read low dim data
69 |     data_ts_pose_fb = []
70 |     data_robot_time_stamps = []
71 |     data_robot_wrench = []
72 |     data_wrench = []
73 |     data_wrench_filtered = []
74 |     data_wrench_time_stamps = []
75 |     data_masks = []
76 |     # print(f"Reading low dim data for : {episode_dir}")
77 |     try:
78 |         for id in config["id_list"]:
79 |             # read robot data
80 |             json_path = episode_dir.joinpath("robot_data_" + str(id) + ".json")
81 |             df_robot_data = pd.read_json(json_path)
82 |             robot_time_stamps = df_robot_data["robot_time_stamps"].to_numpy()
83 | 
84 |             delta_time = robot_time_stamps[1:] - robot_time_stamps[:-1]
85 |             average_delta_time = np.mean(delta_time)
86 |             print(f"Average delta time for {episode_name}: {average_delta_time}")
87 | 
88 | 
89 |     except FileNotFoundError:
90 |         # Handle the case where the file does not exist
91 |         print(f"\033[31mError: JSON file not found for {episode_dir}\033[0m.")
92 |     except ValueError as e:
93 |         # Handle the case where the JSON is invalid
94 |         print(f"\033[31mError: Invalid JSON format - {e} - for {episode_dir}\033[0m.")
95 |     except Exception as e:
96 |         # Handle other exceptions that might occur
97 |         print(f"\033[31mAn unexpected error occurred for {episode_dir}: {e}\033[0m.")
98 |     else:
99 |         # This block executes if no exception occurs, i.e. the episode was read cleanly.
100 |         # Returning True keeps the executor below from treating a clean episode as a failure.
101 |         return True
102 | 
103 | 
104 | with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
105 |     futures = [
106 |         executor.submit(
107 |             check_timing,
108 |             episode_name,
109 |             root,
110 |             episode_config,
111 |         )
112 |         for episode_name in sorted(episode_names)
113 |     ]
114 |     for future in concurrent.futures.as_completed(futures):
115 |         if not future.result():
116 |             raise RuntimeError("Multi-processing failed!")
117 | 
118 | print("Finished reading.")
119 | 
120 | 
--------------------------------------------------------------------------------
/PyriteML/diffusion_policy/model/bet/latent_generators/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import einops
5 | import diffusion_policy.model.bet.latent_generators.latent_generator as latent_generator
6 | 
7 | from diffusion_policy.model.diffusion.transformer_for_diffusion import TransformerForDiffusion
8 | from diffusion_policy.model.bet.libraries.loss_fn import FocalLoss, soft_cross_entropy
9 | 
10 | from typing import Optional, Tuple
11 | 
12 | class Transformer(latent_generator.AbstractLatentGenerator):
13 |     def
__init__( 14 | self, 15 | input_dim: int, 16 | num_bins: int, 17 | action_dim: int, 18 | horizon: int, 19 | focal_loss_gamma: float, 20 | offset_loss_scale: float, 21 | **kwargs 22 | ): 23 | super().__init__() 24 | self.model = TransformerForDiffusion( 25 | input_dim=input_dim, 26 | output_dim=num_bins * (1 + action_dim), 27 | horizon=horizon, 28 | **kwargs 29 | ) 30 | self.vocab_size = num_bins 31 | self.focal_loss_gamma = focal_loss_gamma 32 | self.offset_loss_scale = offset_loss_scale 33 | self.action_dim = action_dim 34 | 35 | def get_optimizer(self, **kwargs) -> torch.optim.Optimizer: 36 | return self.model.configure_optimizers(**kwargs) 37 | 38 | def get_latent_and_loss(self, 39 | obs_rep: torch.Tensor, 40 | target_latents: torch.Tensor, 41 | return_loss_components=True, 42 | ) -> Tuple[torch.Tensor, torch.Tensor]: 43 | target_latents, target_offsets = target_latents 44 | target_latents = target_latents.view(-1) 45 | criterion = FocalLoss(gamma=self.focal_loss_gamma) 46 | 47 | t = torch.tensor(0, device=self.model.device) 48 | output = self.model(obs_rep, t) 49 | logits = output[:, :, : self.vocab_size] 50 | offsets = output[:, :, self.vocab_size :] 51 | batch = logits.shape[0] 52 | seq = logits.shape[1] 53 | offsets = einops.rearrange( 54 | offsets, 55 | "N T (V A) -> (N T) V A", # N = batch, T = seq 56 | V=self.vocab_size, 57 | A=self.action_dim, 58 | ) 59 | # calculate (optionally soft) cross entropy and offset losses 60 | class_loss = criterion(logits.view(-1, logits.size(-1)), target_latents) 61 | # offset loss is only calculated on the target class 62 | # if soft targets, argmax is considered the target class 63 | selected_offsets = offsets[ 64 | torch.arange(offsets.size(0)), 65 | target_latents.view(-1), 66 | ] 67 | offset_loss = self.offset_loss_scale * F.mse_loss( 68 | selected_offsets, target_offsets.view(-1, self.action_dim) 69 | ) 70 | loss = offset_loss + class_loss 71 | logits = einops.rearrange(logits, "batch seq classes -> seq batch classes") 72 | offsets = einops.rearrange( 73 | offsets, 74 | "(N T) V A -> T N V A", # ? N, T order? Anyway does not affect loss and training (might affect visualization) 75 | N=batch, 76 | T=seq, 77 | ) 78 | return ( 79 | (logits, offsets), 80 | loss, 81 | {"offset": offset_loss, "class": class_loss, "total": loss}, 82 | ) 83 | 84 | def generate_latents( 85 | self, obs_rep: torch.Tensor 86 | ) -> torch.Tensor: 87 | t = torch.tensor(0, device=self.model.device) 88 | output = self.model(obs_rep, t) 89 | logits = output[:, :, : self.vocab_size] 90 | offsets = output[:, :, self.vocab_size :] 91 | offsets = einops.rearrange( 92 | offsets, 93 | "N T (V A) -> (N T) V A", # N = batch, T = seq 94 | V=self.vocab_size, 95 | A=self.action_dim, 96 | ) 97 | 98 | probs = F.softmax(logits, dim=-1) 99 | batch, seq, choices = probs.shape 100 | # Sample from the multinomial distribution, one per row. 
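        # probs.view(-1, choices) flattens to (batch * seq, num_bins); multinomial
        # draws one bin index per flattened row, giving (batch * seq, 1), which is
        # rearranged back to (batch, seq, 1) below.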
101 | sampled_data = torch.multinomial(probs.view(-1, choices), num_samples=1) 102 | sampled_data = einops.rearrange( 103 | sampled_data, "(batch seq) 1 -> batch seq 1", batch=batch, seq=seq 104 | ) 105 | sampled_offsets = offsets[ 106 | torch.arange(offsets.shape[0]), sampled_data.flatten() 107 | ].view(batch, seq, self.action_dim) 108 | return (sampled_data, sampled_offsets) 109 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/data_pipeline/real_data_processing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_PATH = os.path.abspath(os.path.dirname(__file__)) 5 | sys.path.append(os.path.join(SCRIPT_PATH, "../../")) 6 | 7 | from PyriteUtility.data_pipeline.processing_functions import process_one_episode_into_zarr, generate_meta_for_zarr 8 | 9 | import pathlib 10 | import shutil 11 | import numpy as np 12 | import zarr 13 | import cv2 14 | import concurrent.futures 15 | 16 | CORRECTION = False # set to true if you want to use the correction data 17 | 18 | # check environment variables 19 | if "PYRITE_RAW_DATASET_FOLDERS" not in os.environ: 20 | raise ValueError("Please set the environment variable PYRITE_RAW_DATASET_FOLDERS") 21 | if "PYRITE_DATASET_FOLDERS" not in os.environ: 22 | raise ValueError("Please set the environment variable PYRITE_DATASET_FOLDERS") 23 | 24 | 25 | # specify the input and output director 26 | id_list = [0] # single robot 27 | # id_list = [0, 1] # bimanual 28 | 29 | ft_sensor_configuration = "handle_on_robot" # "handle_on_sensor" or "handle_on_robot" 30 | 31 | input_dir = pathlib.Path( 32 | os.environ.get("PYRITE_RAW_DATASET_FOLDERS") + "/belt_assembly_50" 33 | ) 34 | # input_dir = pathlib.Path( 35 | # os.environ.get("PYRITE_DATASET_FOLDERS") + "/online_stow_nb_v5_50/raw" 36 | # ) 37 | output_dir = pathlib.Path( 38 | os.environ.get("PYRITE_DATASET_FOLDERS") + "/belt_assembly_offline_50" 39 | ) 40 | 41 | # clean and create output folders 42 | if os.path.exists(output_dir): 43 | shutil.rmtree(output_dir) 44 | 45 | # # check for black images 46 | # def check_black_images(rgb_file_list, rgb_dir, i, prefix): 47 | # f = rgb_file_list[i] 48 | # img = cv2.imread(str(rgb_dir.joinpath(f))) 49 | # # print the mean of the image 50 | # img_mean = np.mean(img) 51 | # if img_mean < 50: 52 | # print(f"{prefix}, {f} has mean value of {img_mean}") 53 | # return True 54 | 55 | 56 | # for episode_name in os.listdir(input_dir): 57 | # if episode_name.startswith("."): 58 | # continue 59 | 60 | # episode_dir = input_dir.joinpath(episode_name) 61 | # for id in id_list: 62 | # rgb_dir = episode_dir.joinpath("rgb_" + str(id)) 63 | # rgb_file_list = os.listdir(rgb_dir) 64 | # num_raw_images = len(rgb_file_list) 65 | # print(f"Checking for black images in {rgb_dir}") 66 | # with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 67 | # futures = set() 68 | # for i in range(len(rgb_file_list)): 69 | # futures.add( 70 | # executor.submit( 71 | # check_black_images, 72 | # rgb_file_list, 73 | # rgb_dir, 74 | # i, 75 | # f"{episode_name} rgb_{id}", 76 | # ) 77 | # ) 78 | 79 | # completed, futures = concurrent.futures.wait(futures) 80 | # for f in completed: 81 | # if not f.result(): 82 | # raise RuntimeError("Failed to read image!") 83 | # exit() 84 | 85 | # open the zarr store 86 | store = zarr.DirectoryStore(path=output_dir) 87 | root = zarr.open(store=store, mode="a") 88 | 89 | print("Reading data from input_dir: ", input_dir) 90 | 
episode_names = os.listdir(input_dir) 91 | 92 | episode_config = { 93 | "input_dir": input_dir, 94 | "output_dir": output_dir, 95 | "id_list": id_list, 96 | "ft_sensor_configuration": ft_sensor_configuration, 97 | "num_threads": 10, 98 | "has_correction": CORRECTION, 99 | "save_video": False, 100 | "max_workers": 32 101 | } 102 | 103 | with concurrent.futures.ProcessPoolExecutor(max_workers=3) as executor: 104 | futures = [ 105 | executor.submit( 106 | process_one_episode_into_zarr, 107 | episode_name, 108 | root, 109 | episode_config, 110 | ) 111 | for episode_name in episode_names 112 | ] 113 | for future in concurrent.futures.as_completed(futures): 114 | if not future.result(): 115 | raise RuntimeError("Multi-processing failed!") 116 | 117 | print("Finished reading. Now start generating metadata") 118 | from PyriteUtility.computer_vision.imagecodecs_numcodecs import register_codecs 119 | 120 | register_codecs() 121 | 122 | 123 | count = generate_meta_for_zarr(root, episode_config) 124 | print(f"All done! Generated {count} episodes in {output_dir}") 125 | print("The only thing left is to run postprocess_add_virtual_target_label.py") 126 | -------------------------------------------------------------------------------- /PyriteML/scripts/test_rmq.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "from typing import Dict, Callable, Tuple, List\n", 12 | "\n", 13 | "SCRIPT_PATH = \"/home/yifanhou/git/PyriteML/scripts\"\n", 14 | "sys.path.append(os.path.join(SCRIPT_PATH, '../'))\n", 15 | "\n", 16 | "\n", 17 | "import numpy as np\n", 18 | "import torch\n", 19 | "import time\n", 20 | "import dill\n", 21 | "import hydra\n", 22 | "from torch.utils.data import DataLoader\n", 23 | "\n", 24 | "\n", 25 | "from diffusion_policy.workspace.base_workspace import BaseWorkspace\n", 26 | "from diffusion_policy.dataset.base_dataset import BaseImageDataset, BaseDataset\n", 27 | "from diffusion_policy.workspace.train_diffusion_unet_image_workspace import TrainDiffusionUnetImageWorkspace\n", 28 | "\n", 29 | "data_path = \"/home/yifanhou/training_outputs/\"\n", 30 | "ckpt_path = data_path + \"2025.03.05_21.53.36_stow_no_force_202_stow_80/checkpoints/latest.ckpt\"\n", 31 | "\n", 32 | "device = torch.device('cpu')\n", 33 | "\n", 34 | "# load checkpoint\n", 35 | "if not ckpt_path.endswith('.ckpt'):\n", 36 | " ckpt_path = os.path.join(ckpt_path, 'checkpoints', 'latest.ckpt')\n", 37 | "payload = torch.load(open(ckpt_path, 'rb'), map_location='cpu', pickle_module=dill)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "import robotmq as rmq\n", 47 | "import pickle\n", 48 | "import time\n", 49 | "import numpy as np\n", 50 | "import numpy.typing as npt\n", 51 | "\n", 52 | "server = rmq.RMQServer(\n", 53 | " server_name=\"test_rmq_server\", server_endpoint=\"ipc:///tmp/feeds/0\"\n", 54 | ")\n", 55 | "client = rmq.RMQClient(\n", 56 | " client_name=\"test_rmq_client\", server_endpoint=\"ipc:///tmp/feeds/0\"\n", 57 | ")\n", 58 | "print(\"Server and client created\")\n", 59 | "\n", 60 | "server.add_topic(\"test_checkpoints\", 10)\n", 61 | "\n", 62 | "# Serialize the checkpoint\n", 63 | "start_time = time.time()\n", 64 | "pickle_data = pickle.dumps(payload)\n", 65 | "dump_end_time = time.time()\n", 66 | 
"server.put_data(\"test_checkpoints\", pickle_data)\n", 67 | "send_end_time = time.time()\n", 68 | "time.sleep(0.01)\n", 69 | "\n", 70 | "retrieve_start_time = time.time()\n", 71 | "retrieved_data, timestamp = client.peek_data(topic=\"test_checkpoints\", order=\"latest\", n=1)\n", 72 | "retrieve_end_time = time.time()\n", 73 | "received_data = pickle.loads(retrieved_data[0])\n", 74 | "\n", 75 | "\n", 76 | "print(\n", 77 | " f\"Data size: {len(pickle_data) / 1024**2:.3f}MB. dump: {dump_end_time - start_time:.4f}s, send: {send_end_time - dump_end_time: .4f}s, retrieve: {retrieve_end_time - retrieve_start_time:.4f}s, load: {time.time() - retrieve_end_time:.4f}s)\"\n", 78 | ")\n", 79 | "\n", 80 | "\n", 81 | "# use the received payload\n", 82 | "cfg = received_data['cfg']\n", 83 | "print(\"dataset_path:\", cfg.task.dataset.dataset_path)\n", 84 | "\n", 85 | "cls = hydra.utils.get_class(cfg._target_)\n", 86 | "workspace = cls(cfg)\n", 87 | "workspace: BaseWorkspace\n", 88 | "workspace.load_payload(received_data, exclude_keys=None, include_keys=None)\n", 89 | "\n", 90 | "policy = workspace.model\n", 91 | "if cfg.training.use_ema:\n", 92 | " policy = workspace.ema_model\n", 93 | "policy.num_inference_steps = cfg.policy.num_inference_steps # DDIM inference iterations\n", 94 | "\n", 95 | "policy.eval().to(device)\n", 96 | "policy.reset()\n", 97 | "\n", 98 | "# use normalizer saved in the policy\n", 99 | "sparse_normalizer, dense_normalizer = policy.get_normalizer()\n", 100 | "\n", 101 | "shape_meta = cfg.task.shape_meta\n", 102 | "\n" 103 | ] 104 | } 105 | ], 106 | "metadata": { 107 | "kernelspec": { 108 | "display_name": "pyrite", 109 | "language": "python", 110 | "name": "python3" 111 | }, 112 | "language_info": { 113 | "codemirror_mode": { 114 | "name": "ipython", 115 | "version": 3 116 | }, 117 | "file_extension": ".py", 118 | "mimetype": "text/x-python", 119 | "name": "python", 120 | "nbconvert_exporter": "python", 121 | "pygments_lexer": "ipython3", 122 | "version": "3.12.3" 123 | } 124 | }, 125 | "nbformat": 4, 126 | "nbformat_minor": 2 127 | } 128 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/config/task/online_correction_single_arm_no_base_action_no_force.yaml: -------------------------------------------------------------------------------- 1 | name: belt_residual_no_base_action_no_force 2 | 3 | # rgb vs. low_dim: raw data are either rgb images or low_dim vectors. frames are aligned within each type. 4 | # obs vs. action: obs data are used as policy input; action are used as labels for policy output. 5 | # dense vs. sparse: used for dense prediction vs. 
sparse prediction 6 | 7 | # down_sample_steps: how many steps to skip in the raw data for the given usage 8 | # horizon: how many steps to look ahead(action) or back(obs) after downsample for the given usage 9 | sparse_obs_rgb_down_sample_steps: 1 10 | sparse_obs_rgb_horizon: 1 11 | 12 | sparse_obs_low_dim_down_sample_steps: 5 13 | sparse_obs_low_dim_horizon: 3 14 | 15 | sparse_obs_base_down_sample_steps: 50 16 | sparse_obs_base_horizon: 6 17 | 18 | sparse_action_down_sample_steps: 10 19 | sparse_action_horizon: 5 20 | 21 | 22 | shape_meta: &shape_meta 23 | # acceptable types: rgb, low_dim 24 | # fields under raw and obs must be consistent with FlipUpDataset.raw_to_obs_action() 25 | id_list: [0] 26 | raw: # describes what exists in data 27 | rgb_0: 28 | shape: [3, 224, 224] 29 | type: rgb 30 | ts_pose_fb_0: 31 | shape: [7] 32 | type: low_dim 33 | ts_pose_command_0: 34 | shape: [7] 35 | type: low_dim 36 | ts_pose_virtual_target_0: 37 | shape: [7] 38 | type: low_dim 39 | stiffness_0: 40 | shape: [1] 41 | type: low_dim 42 | key_event_0: 43 | shape: [1] 44 | type: low_dim 45 | rgb_time_stamps_0: 46 | shape: [1] 47 | type: timestamp 48 | robot_time_stamps_0: 49 | shape: [1] 50 | type: timestamp 51 | key_event_time_stamps_0: 52 | shape: [1] 53 | type: timestamp 54 | 55 | obs: # describes observations loaded to memory 56 | rgb_0: 57 | shape: [3, 224, 224] 58 | type: rgb 59 | robot0_eef_pos: 60 | shape: [3] 61 | type: low_dim 62 | robot0_eef_rot_axis_angle: 63 | shape: [6] 64 | type: low_dim 65 | rotation_rep: rotation_6d 66 | rgb_time_stamps_0: 67 | shape: [1] 68 | type: timestamp 69 | robot_time_stamps_0: 70 | shape: [1] 71 | type: timestamp 72 | action: # describes actions loaded to memory, computed from robot command 73 | shape: [9] # 9 for residual reference pose, 6 for tool wrench 74 | rotation_rep: rotation_6d 75 | sample: # describes samples used in a batch 76 | # keys here must exist in obs/action above. 77 | # shape, type and rotation_rep are inherited from obs/action above. 
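    # Rough guide (using the values defined at the top of this file): each key
    # spans about horizon * down_sample_steps raw steps, e.g. low-dim obs with
    # horizon 3 and down_sample_steps 5 look back over ~15 raw steps, while the
    # sparse action with horizon 5 and down_sample_steps 10 covers ~50 raw steps.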
78 | obs: 79 | sparse: 80 | rgb_0: 81 | horizon: ${task.sparse_obs_rgb_horizon} # int 82 | down_sample_steps: ${task.sparse_obs_rgb_down_sample_steps} # int 83 | robot0_eef_pos: 84 | horizon: ${task.sparse_obs_low_dim_horizon} # int 85 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 86 | robot0_eef_rot_axis_angle: # exists in data 87 | horizon: ${task.sparse_obs_low_dim_horizon} # int 88 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 89 | action: 90 | sparse: 91 | horizon: ${task.sparse_action_horizon} 92 | down_sample_steps: ${task.sparse_action_down_sample_steps} # int 93 | training_duration_per_sparse_query: 200 # TODO (in ms) 94 | 95 | task_name: &task_name belt_residual 96 | 97 | dataset: 98 | _target_: diffusion_policy.dataset.dynamic_dataset.DynamicDataset 99 | shape_meta: *shape_meta 100 | sparse_query_frequency_down_sample_steps: 8 101 | action_padding: False 102 | temporally_independent_normalization: False 103 | seed: 42 104 | val_ratio: 0.05 105 | normalize_wrench: False 106 | weighted_sampling: 1 # if > 1, duplicate the correction sample's ids weighted_sampling times 107 | correction_horizon: 10 # the extra action horizon sampled before and after the "start of correction" 108 | new_episode_prob: 0.5 109 | correction_force_threshold: 3.0 110 | correction_torque_threshold: 1.0 111 | use_raw_policy_timestamps: False 112 | virtual_target_config: 113 | stiffness_estimation_para: 114 | k_max: 5000 # 1cm 50N 115 | k_min: 200 # 1cm 2.5N 116 | f_low: 0.5 117 | f_high: 5 118 | dim: 3 119 | characteristic_length: 0.02 120 | wrench_moving_average_window_size: 500 # should be around 1s of data, 121 | flag_real: True # False for simulation data 122 | num_of_process: 5 # 5 123 | flag_plot: False 124 | fin_every_n: 50 # 50 125 | id_list: ${shape_meta.id_list} 126 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/model/bet/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import OrderedDict 4 | from typing import List, Optional 5 | 6 | import einops 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | from torch.utils.data import random_split 12 | import wandb 13 | 14 | 15 | def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None): 16 | if hidden_depth == 0: 17 | mods = [nn.Linear(input_dim, output_dim)] 18 | else: 19 | mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] 20 | for i in range(hidden_depth - 1): 21 | mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] 22 | mods.append(nn.Linear(hidden_dim, output_dim)) 23 | if output_mod is not None: 24 | mods.append(output_mod) 25 | trunk = nn.Sequential(*mods) 26 | return trunk 27 | 28 | 29 | class eval_mode: 30 | def __init__(self, *models, no_grad=False): 31 | self.models = models 32 | self.no_grad = no_grad 33 | self.no_grad_context = torch.no_grad() 34 | 35 | def __enter__(self): 36 | self.prev_states = [] 37 | for model in self.models: 38 | self.prev_states.append(model.training) 39 | model.train(False) 40 | if self.no_grad: 41 | self.no_grad_context.__enter__() 42 | 43 | def __exit__(self, *args): 44 | if self.no_grad: 45 | self.no_grad_context.__exit__(*args) 46 | for model, state in zip(self.models, self.prev_states): 47 | model.train(state) 48 | return False 49 | 50 | 51 | def freeze_module(module: nn.Module) -> nn.Module: 52 | for param in module.parameters(): 53 | 
param.requires_grad = False 54 | module.eval() 55 | return module 56 | 57 | 58 | def set_seed_everywhere(seed): 59 | torch.manual_seed(seed) 60 | if torch.cuda.is_available(): 61 | torch.cuda.manual_seed_all(seed) 62 | np.random.seed(seed) 63 | random.seed(seed) 64 | 65 | 66 | def shuffle_along_axis(a, axis): 67 | idx = np.random.rand(*a.shape).argsort(axis=axis) 68 | return np.take_along_axis(a, idx, axis=axis) 69 | 70 | 71 | def transpose_batch_timestep(*args): 72 | return (einops.rearrange(arg, "b t ... -> t b ...") for arg in args) 73 | 74 | 75 | class TrainWithLogger: 76 | def reset_log(self): 77 | self.log_components = OrderedDict() 78 | 79 | def log_append(self, log_key, length, loss_components): 80 | for key, value in loss_components.items(): 81 | key_name = f"{log_key}/{key}" 82 | count, sum = self.log_components.get(key_name, (0, 0.0)) 83 | self.log_components[key_name] = ( 84 | count + length, 85 | sum + (length * value.detach().cpu().item()), 86 | ) 87 | 88 | def flush_log(self, epoch, iterator=None): 89 | log_components = OrderedDict() 90 | iterator_log_component = OrderedDict() 91 | for key, value in self.log_components.items(): 92 | count, sum = value 93 | to_log = sum / count 94 | log_components[key] = to_log 95 | # Set the iterator status 96 | log_key, name_key = key.split("/") 97 | iterator_log_name = f"{log_key[0]}{name_key[0]}".upper() 98 | iterator_log_component[iterator_log_name] = to_log 99 | postfix = ",".join( 100 | "{}:{:.2e}".format(key, iterator_log_component[key]) 101 | for key in iterator_log_component.keys() 102 | ) 103 | if iterator is not None: 104 | iterator.set_postfix_str(postfix) 105 | wandb.log(log_components, step=epoch) 106 | self.log_components = OrderedDict() 107 | 108 | 109 | class SaveModule(nn.Module): 110 | def set_snapshot_path(self, path): 111 | self.snapshot_path = path 112 | print(f"Setting snapshot path to {self.snapshot_path}") 113 | 114 | def save_snapshot(self): 115 | os.makedirs(self.snapshot_path, exist_ok=True) 116 | torch.save(self.state_dict(), self.snapshot_path / "snapshot.pth") 117 | 118 | def load_snapshot(self): 119 | self.load_state_dict(torch.load(self.snapshot_path / "snapshot.pth")) 120 | 121 | 122 | def split_datasets(dataset, train_fraction=0.95, random_seed=42): 123 | dataset_length = len(dataset) 124 | lengths = [ 125 | int(train_fraction * dataset_length), 126 | dataset_length - int(train_fraction * dataset_length), 127 | ] 128 | train_set, val_set = random_split( 129 | dataset, lengths, generator=torch.Generator().manual_seed(random_seed) 130 | ) 131 | return train_set, val_set 132 | -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/audio/multi_mic.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union, Dict, Callable 2 | import numbers 3 | import copy 4 | import time 5 | import pathlib 6 | from multiprocessing.managers import SharedMemoryManager 7 | import numpy as np 8 | from PyriteUtility.audio.mic import Microphone 9 | from PyriteUtility.audio.audio_recorder import AudioRecorder 10 | 11 | class MultiMicrophone: 12 | def __init__(self, 13 | shm_manager: Optional[SharedMemoryManager]=None, 14 | get_max_k=30, 15 | receive_latency=0.0, 16 | device_id=[0], 17 | num_channel=2, 18 | block_size=800, 19 | audio_sr=48000, 20 | put_downsample=True, 21 | audio_recorder: Optional[Union[AudioRecorder, List[AudioRecorder]]]=None, 22 | ): 23 | super().__init__() 24 | 25 | if shm_manager is None: 
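# No shared-memory manager was supplied: create and start a local one here; it is
# handed to every Microphone constructed below so they can allocate their buffers.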
26 | shm_manager = SharedMemoryManager() 27 | shm_manager.start() 28 | 29 | mics = dict() 30 | for i, id in enumerate(device_id): 31 | mics[id] = Microphone( 32 | shm_manager=shm_manager, 33 | get_max_k=get_max_k, 34 | receive_latency=receive_latency, 35 | device_id=int(id), 36 | num_channel=num_channel, 37 | block_size=block_size, 38 | audio_sr=audio_sr, 39 | put_downsample=put_downsample, 40 | audio_recorder=audio_recorder[i]) 41 | 42 | self.mics = mics 43 | self.shm_manager = shm_manager 44 | 45 | def __enter__(self): 46 | self.start() 47 | return self 48 | 49 | def __exit__(self, exc_type, exc_val, exc_tb): 50 | self.stop() 51 | 52 | @property 53 | def n_mics(self): 54 | return len(self.mics) 55 | 56 | @property 57 | def is_ready(self): 58 | is_ready = True 59 | for mic in self.mics.values(): 60 | if not mic.is_ready: 61 | is_ready = False 62 | return is_ready 63 | 64 | def start(self, wait=True, put_start_time=None): 65 | if put_start_time is None: 66 | put_start_time = time.time() 67 | for mic in self.mics.values(): 68 | mic.start(wait=False, put_start_time=put_start_time) 69 | 70 | if wait: 71 | self.start_wait() 72 | 73 | def stop(self, wait=True): 74 | for mic in self.mics.values(): 75 | mic.stop(wait=False) 76 | 77 | if wait: 78 | self.stop_wait() 79 | 80 | def start_wait(self): 81 | for mic in self.mics.values(): 82 | mic.start_wait() 83 | 84 | def stop_wait(self): 85 | for mic in self.mics.values(): 86 | mic.join() 87 | 88 | def get(self, k=None, out=None) -> Dict[int, Dict[str, np.ndarray]]: 89 | """ 90 | Return order T,H,W,C 91 | { 92 | 0: { 93 | 'rgb': (T,H,W,C), 94 | 'timestamp': (T,) 95 | }, 96 | 1: ... 97 | } 98 | """ 99 | if out is None: 100 | out = dict() 101 | for _, mic in enumerate(self.mics.values()): 102 | i = mic.device_id 103 | this_out = None 104 | if i in out: 105 | this_out = out[i] 106 | this_out = mic.get(k=k, out=this_out) 107 | out[i] = this_out 108 | return out 109 | 110 | def start_recording(self, audio_path: Union[str, List[str]], start_time: float): 111 | if isinstance(audio_path, str): 112 | # directory 113 | video_dir = pathlib.Path(audio_path) 114 | assert video_dir.parent.is_dir() 115 | video_dir.mkdir(parents=True, exist_ok=True) 116 | audio_path = list() 117 | for i in range(self.n_mics): 118 | audio_path.append( 119 | str(video_dir.joinpath(f'{i}.wav').absolute())) 120 | assert len(audio_path) == self.n_mics 121 | 122 | for i, mic in enumerate(self.mics.values()): 123 | mic.start_recording(audio_path[i], start_time) 124 | 125 | def stop_recording(self): 126 | for i, mic in enumerate(self.mics.values()): 127 | mic.stop_recording() 128 | 129 | def restart_put(self, start_time): 130 | for mic in self.mics.values(): 131 | mic.restart_put(start_time) 132 | 133 | 134 | def repeat_to_list(x, n: int, cls): 135 | if x is None: 136 | x = [None] * n 137 | if isinstance(x, cls): 138 | x = [copy.deepcopy(x) for _ in range(n)] 139 | assert len(x) == n 140 | return x -------------------------------------------------------------------------------- /PyriteUtility/PyriteUtility/audio/audio_recorder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import av 4 | import numpy as np 5 | import sounddevice as sd 6 | import soundfile as sf 7 | import multiprocessing as mp 8 | 9 | from PyriteUtility.data_pipeline.shared_memory.shared_memory_queue import SharedMemoryQueue, Full, Empty 10 | from PyriteUtility.umi_utils.timestamp_accumulator import get_accumulate_timestamp_idxs 11 | 12 | 13 | class 
AudioEncoderProcess(mp.Process): 14 | def __init__(self, 15 | shm_manager, 16 | data_example: np.ndarray, 17 | file_path, 18 | codec, sr, num_channel, input_audio_fmt, 19 | **kwargs): 20 | super().__init__() 21 | 22 | self.file_path = file_path 23 | self.codec = codec 24 | self.sr = sr 25 | self.num_channel = num_channel 26 | self.input_audio_fmt = input_audio_fmt 27 | self.kwargs = kwargs 28 | self.shape = None 29 | self.dtype = None 30 | 31 | self.audio_queue = SharedMemoryQueue.create_from_examples( 32 | shm_manager=shm_manager, 33 | examples={'audio_block': data_example}, 34 | buffer_size=128 35 | ) 36 | self.stop_event = mp.Event() 37 | 38 | def stop(self, wait=True): 39 | # wake up thread waiting on queue 40 | self.stop_event.set() 41 | if wait: 42 | self.join() 43 | 44 | def put_audio_block(self, audio_block: np.ndarray): 45 | assert audio_block is not None 46 | self.audio_queue.put({'audio_block': audio_block}) 47 | 48 | def run(self): 49 | with sf.SoundFile(self.file_path, mode='w', samplerate=self.sr, channels=self.num_channel) as file: 50 | data = None 51 | while not self.stop_event.is_set(): 52 | try: 53 | data = self.audio_queue.get(out=data) 54 | file.write(data['audio_block']) 55 | except Empty: 56 | time.sleep(0.5/60) 57 | 58 | 59 | class AudioRecorder(): 60 | def __init__(self, shm_manager, sr, num_channel, codec, input_audio_fmt, **kwargs): 61 | self.shm_manager = shm_manager 62 | self.sr = sr 63 | self.num_channel = num_channel 64 | self.codec = codec 65 | self.input_audio_fmt = input_audio_fmt 66 | self.kwargs = kwargs 67 | 68 | self._reset_state() 69 | 70 | def _reset_state(self): 71 | self.file_path = None 72 | self.enc_thread = None 73 | self.start_time = None 74 | self.next_global_idx = 0 75 | 76 | def __del__(self): 77 | self.stop() 78 | 79 | def is_ready(self): 80 | return self.start_time is not None 81 | 82 | def start(self, file_path, start_time=None): 83 | if self.is_ready(): 84 | # if still recording, stop first and start anew. 
85 | self.stop() 86 | 87 | self.file_path = file_path 88 | self.start_time = start_time 89 | 90 | def write_frame(self, audio_block: np.ndarray, frame_time): 91 | if not self.is_ready(): 92 | raise RuntimeError('Must run start() before writing!') 93 | 94 | # create encode threads if not already 95 | if self.enc_thread is None: 96 | self.enc_thread = AudioEncoderProcess( 97 | shm_manager=self.shm_manager, 98 | data_example=audio_block, 99 | file_path=self.file_path, 100 | codec=self.codec, 101 | sr=self.sr, 102 | num_channel=self.num_channel, 103 | input_audio_fmt=self.input_audio_fmt, 104 | **self.kwargs 105 | ) 106 | self.enc_thread.start() 107 | 108 | n_repeats = 1 109 | # if self.start_time is not None: 110 | # local_idxs, global_idxs, self.next_global_idx \ 111 | # = get_accumulate_timestamp_idxs( 112 | # # only one timestamp 113 | # timestamps=[frame_time], 114 | # start_time=self.start_time, 115 | # dt=1/self.put_fps, 116 | # next_global_idx=self.next_global_idx 117 | # ) 118 | # # number of apperance means repeats 119 | # n_repeats = len(local_idxs) 120 | 121 | # print(n_repeats) 122 | if self.start_time is not None and frame_time >= self.start_time: 123 | self.enc_thread.put_audio_block(audio_block) 124 | 125 | def stop(self): 126 | if not self.is_ready(): 127 | return 128 | self.enc_thread.stop(wait=True) 129 | # reset runtime parameters 130 | self._reset_state() -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/config/task/stow_conv.yaml: -------------------------------------------------------------------------------- 1 | # Mod2: 19D action, no dense 2 | name: stow_conv 3 | 4 | # rgb vs. low_dim: raw data are either rgb images or low_dim vectors. frames are aligned within each type. 5 | # obs vs. action: obs data are used as policy input; action are used as labels for policy output. 6 | # dense vs. sparse: used for dense prediction vs. sparse prediction 7 | 8 | # down_sample_steps: how many steps to skip in the raw data for the given usage 9 | # horizon: how many steps to look ahead(action) or back(obs) after downsample for the given usage 10 | sparse_obs_rgb_down_sample_steps: 10 11 | sparse_obs_rgb_horizon: 2 12 | 13 | sparse_obs_low_dim_down_sample_steps: 5 14 | sparse_obs_low_dim_horizon: 3 15 | 16 | sparse_obs_wrench_down_sample_steps: 4 17 | sparse_obs_wrench_horizon: 32 18 | 19 | sparse_action_down_sample_steps: 20 20 | sparse_action_horizon: 16 21 | 22 | # The following parameter is used to avoid action padding when sampling queries. 23 | # It is the duration of a full action prediction horizon in ms. 24 | # This is needed since my sampler uses timestamp instead of index to align data. 
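# For example, with sparse_action_horizon: 16 and sparse_action_down_sample_steps: 20 above,
# one sparse action query covers 16 * 20 = 320 raw robot steps; the value below is the
# approximate wall-clock duration of that span.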
25 | action_horizon_duration_buffer_ms: 2100 26 | 27 | shape_meta: &shape_meta 28 | # acceptable types: rgb, low_dim 29 | # fields under raw and obs must be consistent with FlipUpDataset.raw_to_obs_action() 30 | id_list: [0] 31 | raw: # describes what exists in data 32 | rgb_0: 33 | shape: [3, 224, 224] 34 | type: rgb 35 | ts_pose_fb_0: 36 | shape: [7] 37 | type: low_dim 38 | ts_pose_command_0: 39 | shape: [7] 40 | type: low_dim 41 | ts_pose_virtual_target_0: 42 | shape: [7] 43 | type: low_dim 44 | stiffness_0: 45 | shape: [1] 46 | type: low_dim 47 | wrench_0: 48 | shape: [6] 49 | type: low_dim 50 | rgb_time_stamps_0: 51 | shape: [1] 52 | type: timestamp 53 | robot_time_stamps_0: 54 | shape: [1] 55 | type: timestamp 56 | wrench_time_stamps_0: 57 | shape: [1] 58 | type: timestamp 59 | 60 | obs: # describes observations loaded to memory 61 | rgb_0: 62 | shape: [3, 224, 224] 63 | type: rgb 64 | robot0_eef_pos: 65 | shape: [3] 66 | type: low_dim 67 | robot0_eef_rot_axis_angle: 68 | shape: [6] 69 | type: low_dim 70 | rotation_rep: rotation_6d 71 | robot0_eef_wrench: 72 | shape: [6] 73 | type: low_dim 74 | rgb_time_stamps_0: 75 | shape: [1] 76 | type: timestamp 77 | robot_time_stamps_0: 78 | shape: [1] 79 | type: timestamp 80 | wrench_time_stamps_0: 81 | shape: [1] 82 | type: timestamp 83 | action: # describes actions loaded to memory, computed from robot command 84 | shape: [19] # 9 for reference pose, 9 for virtual target, 1 for stiffness 85 | rotation_rep: rotation_6d 86 | sample: # describes samples used in a batch 87 | # keys here must exist in obs/action above. 88 | # shape, type and rotation_rep are inherited from obs/action above. 89 | obs: 90 | sparse: 91 | rgb_0: 92 | horizon: ${task.sparse_obs_rgb_horizon} # int 93 | down_sample_steps: ${task.sparse_obs_rgb_down_sample_steps} # int 94 | robot0_eef_pos: 95 | horizon: ${task.sparse_obs_low_dim_horizon} # int 96 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 97 | robot0_eef_rot_axis_angle: # exists in data 98 | horizon: ${task.sparse_obs_low_dim_horizon} # int 99 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 100 | robot0_eef_wrench: 101 | horizon: ${task.sparse_obs_wrench_horizon} # int 102 | down_sample_steps: ${task.sparse_obs_wrench_down_sample_steps} # float 103 | action: 104 | sparse: 105 | horizon: ${task.sparse_action_horizon} 106 | down_sample_steps: ${task.sparse_action_down_sample_steps} # int 107 | training_duration_per_sparse_query: ${task.action_horizon_duration_buffer_ms} 108 | 109 | 110 | task_name: &task_name flip_up 111 | dataset_path: ${oc.env:PYRITE_DATASET_FOLDERS}/stow_test/small_pod_80 112 | 113 | env_runner: # used in workspace for computing metrics 114 | _target_: diffusion_policy.env_runner.real_pusht_image_runner.RealPushTImageRunner 115 | 116 | dataset: 117 | _target_: diffusion_policy.dataset.virtual_target_dataset.VirtualTargetDataset 118 | shape_meta: *shape_meta 119 | dataset_path: ${task.dataset_path} 120 | sparse_query_frequency_down_sample_steps: 8 121 | hack_linear_interpolated_dense_action: False 122 | # cache_dir: null 123 | action_padding: False 124 | temporally_independent_normalization: False 125 | seed: 42 126 | val_ratio: 0.05 127 | normalize_wrench: False 128 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/config/task/stow_spec.yaml: -------------------------------------------------------------------------------- 1 | # Mod2: 19D action, no dense 2 | name: flip_up_new 3 | 4 | # rgb 
vs. low_dim: raw data are either rgb images or low_dim vectors. frames are aligned within each type. 5 | # obs vs. action: obs data are used as policy input; action are used as labels for policy output. 6 | # dense vs. sparse: used for dense prediction vs. sparse prediction 7 | 8 | # down_sample_steps: how many steps to skip in the raw data for the given usage 9 | # horizon: how many steps to look ahead(action) or back(obs) after downsample for the given usage 10 | sparse_obs_rgb_down_sample_steps: 10 11 | sparse_obs_rgb_horizon: 2 12 | 13 | sparse_obs_low_dim_down_sample_steps: 5 14 | sparse_obs_low_dim_horizon: 3 15 | 16 | sparse_obs_wrench_down_sample_steps: 1 17 | sparse_obs_wrench_horizon: 7000 18 | 19 | sparse_action_down_sample_steps: 20 20 | sparse_action_horizon: 16 21 | 22 | # The following parameter is used to avoid action padding when sampling queries. 23 | # It is the duration of a full action prediction horizon in ms. 24 | # This is needed since my sampler uses timestamp instead of index to align data. 25 | action_horizon_duration_buffer_ms: 2100 26 | 27 | shape_meta: &shape_meta 28 | # acceptable types: rgb, low_dim 29 | # fields under raw and obs must be consistent with FlipUpDataset.raw_to_obs_action() 30 | id_list: [0] 31 | raw: # describes what exists in data 32 | rgb_0: 33 | shape: [3, 224, 224] 34 | type: rgb 35 | ts_pose_fb_0: 36 | shape: [7] 37 | type: low_dim 38 | ts_pose_command_0: 39 | shape: [7] 40 | type: low_dim 41 | ts_pose_virtual_target_0: 42 | shape: [7] 43 | type: low_dim 44 | stiffness_0: 45 | shape: [1] 46 | type: low_dim 47 | wrench_0: 48 | shape: [6] 49 | type: low_dim 50 | rgb_time_stamps_0: 51 | shape: [1] 52 | type: timestamp 53 | robot_time_stamps_0: 54 | shape: [1] 55 | type: timestamp 56 | wrench_time_stamps_0: 57 | shape: [1] 58 | type: timestamp 59 | 60 | obs: # describes observations loaded to memory 61 | rgb_0: 62 | shape: [3, 224, 224] 63 | type: rgb 64 | robot0_eef_pos: 65 | shape: [3] 66 | type: low_dim 67 | robot0_eef_rot_axis_angle: 68 | shape: [6] 69 | type: low_dim 70 | rotation_rep: rotation_6d 71 | robot0_eef_wrench: 72 | shape: [6] 73 | type: low_dim 74 | rgb_time_stamps_0: 75 | shape: [1] 76 | type: timestamp 77 | robot_time_stamps_0: 78 | shape: [1] 79 | type: timestamp 80 | wrench_time_stamps_0: 81 | shape: [1] 82 | type: timestamp 83 | action: # describes actions loaded to memory, computed from robot command 84 | shape: [19] # 9 for reference pose, 9 for virtual target, 1 for stiffness 85 | rotation_rep: rotation_6d 86 | sample: # describes samples used in a batch 87 | # keys here must exist in obs/action above. 88 | # shape, type and rotation_rep are inherited from obs/action above. 
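# For example, with the values above: robot0_eef_wrench uses horizon 7000 at
# down_sample_steps 1, i.e. each query keeps the last 7000 raw wrench samples,
# while rgb_0 keeps 2 frames spaced 10 raw steps apart.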
89 | obs: 90 | sparse: 91 | rgb_0: 92 | horizon: ${task.sparse_obs_rgb_horizon} # int 93 | down_sample_steps: ${task.sparse_obs_rgb_down_sample_steps} # int 94 | robot0_eef_pos: 95 | horizon: ${task.sparse_obs_low_dim_horizon} # int 96 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 97 | robot0_eef_rot_axis_angle: # exists in data 98 | horizon: ${task.sparse_obs_low_dim_horizon} # int 99 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 100 | robot0_eef_wrench: 101 | horizon: ${task.sparse_obs_wrench_horizon} # int 102 | down_sample_steps: ${task.sparse_obs_wrench_down_sample_steps} # float 103 | action: 104 | sparse: 105 | horizon: ${task.sparse_action_horizon} 106 | down_sample_steps: ${task.sparse_action_down_sample_steps} # int 107 | training_duration_per_sparse_query: ${task.action_horizon_duration_buffer_ms} 108 | 109 | 110 | task_name: &task_name flip_up 111 | dataset_path: ${oc.env:PYRITE_DATASET_FOLDERS}/stow_test/small_pod_80 112 | 113 | env_runner: # used in workspace for computing metrics 114 | _target_: diffusion_policy.env_runner.real_pusht_image_runner.RealPushTImageRunner 115 | 116 | dataset: 117 | _target_: diffusion_policy.dataset.virtual_target_dataset.VirtualTargetDataset 118 | shape_meta: *shape_meta 119 | dataset_path: ${task.dataset_path} 120 | sparse_query_frequency_down_sample_steps: 8 121 | hack_linear_interpolated_dense_action: False 122 | # cache_dir: null 123 | action_padding: False 124 | temporally_independent_normalization: False 125 | seed: 42 126 | val_ratio: 0.05 127 | normalize_wrench: False 128 | -------------------------------------------------------------------------------- /PyriteML/multimodal_representation/multimodal/logger.py: -------------------------------------------------------------------------------- 1 | import git 2 | from tensorboardX import SummaryWriter 3 | import datetime 4 | import time 5 | import os 6 | 7 | import logging 8 | import sys 9 | import yaml 10 | 11 | 12 | class Logger(object): 13 | """ 14 | Hooks for print statements and tensorboard logging 15 | """ 16 | 17 | def __init__(self, configs): 18 | 19 | self.configs = configs 20 | 21 | time_str = datetime.datetime.fromtimestamp(time.time()).strftime("%Y%m%d%H%M") 22 | prefix_str = time_str + "_" + configs["notes"] 23 | if configs["dev"]: 24 | prefix_str = "dev_" + prefix_str 25 | 26 | self.log_folder = os.path.join(self.configs["logging_folder"], prefix_str) 27 | self.tb_prefix = prefix_str 28 | 29 | self.setup_checks() 30 | self.create_folder_structure() 31 | self.setup_loggers() 32 | self.dump_init_info() 33 | 34 | def create_folder_structure(self): 35 | """ 36 | Creates the folder structure for logging. 
Subfolders can be added here 37 | """ 38 | base_dir = self.log_folder 39 | sub_folders = ["runs", "models"] 40 | 41 | if not os.path.exists(self.configs["logging_folder"]): 42 | os.mkdir(self.configs["logging_folder"]) 43 | 44 | if not os.path.exists(base_dir): 45 | os.mkdir(base_dir) 46 | 47 | for sf in sub_folders: 48 | if not os.path.exists(os.path.join(base_dir, sf)): 49 | os.mkdir(os.path.join(base_dir, sf)) 50 | 51 | def setup_loggers(self): 52 | """ 53 | Sets up a logger that logs to both file and stdout 54 | """ 55 | log_path = os.path.join(self.log_folder, "log.log") 56 | 57 | self.print_logger = logging.getLogger() 58 | self.print_logger.setLevel( 59 | getattr(logging, self.configs["log_level"].upper(), None) 60 | ) 61 | handlers = [logging.StreamHandler(sys.stdout), logging.FileHandler(log_path)] 62 | formatter = logging.Formatter( 63 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 64 | ) 65 | for h in handlers: 66 | h.setFormatter(formatter) 67 | self.print_logger.addHandler(h) 68 | 69 | # Setup Tensorboard 70 | self.tb = SummaryWriter(os.path.join(self.log_folder, "runs", self.tb_prefix)) 71 | 72 | def setup_checks(self): 73 | """ 74 | Verifies that all changes have been committed 75 | Verifies that hashes match (if continuation) 76 | """ 77 | repo = git.Repo(search_parent_directories=True) 78 | sha = repo.head.object.hexsha 79 | 80 | 81 | # Test for continuation 82 | if self.configs["continuation"]: 83 | self.log_folder = self.configs["logging_folder"] 84 | with open(os.path.join(self.log_folder, "log.log"), "r") as old_log: 85 | for line in old_log: 86 | find_str = "Git hash" 87 | if line.find(find_str) != -1: 88 | old_sha = line[line.find(find_str) + len(find_str) + 2 : -4] 89 | assert sha == old_sha 90 | 91 | def dump_init_info(self): 92 | """ 93 | Saves important info for replicability 94 | """ 95 | if not self.configs["continuation"]: 96 | self.configs["logging_folder"] = self.log_folder 97 | else: 98 | self.print("=" * 80) 99 | self.print("Continuing log") 100 | self.print("=" * 80) 101 | 102 | repo = git.Repo(search_parent_directories=True) 103 | sha = repo.head.object.hexsha 104 | 105 | self.print("Git hash: {}".format(sha)) 106 | self.print("Dumping YAML file") 107 | self.print("Configs: ", yaml.dump(self.configs)) 108 | 109 | # Save the start of every run 110 | if "start_weights" not in self.configs: 111 | self.configs["start_weights"] = [] 112 | self.configs["start_weights"].append(self.configs["load"]) 113 | 114 | with open(os.path.join(self.log_folder, "configs.yml"), "w") as outfile: 115 | yaml.dump(self.configs, outfile) 116 | self.tb.add_text("hyperparams", str(self.configs)) 117 | 118 | def end_itr(self, weights_path): 119 | """ 120 | Perform all operations needed at end of iteration 121 | 1). 
Save configs with latest weights 122 | """ 123 | self.configs["latest_weights"] = weights_path 124 | with open(os.path.join(self.log_folder, "configs.yml"), "w") as outfile: 125 | yaml.dump(self.configs, outfile) 126 | 127 | def print(self, *args): 128 | """ 129 | Wrapper for print statement 130 | """ 131 | self.print_logger.info(args) 132 | 133 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/common/cv2_util.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import math 3 | import cv2 4 | import numpy as np 5 | 6 | def draw_reticle(img, u, v, label_color): 7 | """ 8 | Draws a reticle (cross-hair) on the image at the given position on top of 9 | the original image. 10 | @param img (In/Out) uint8 3 channel image 11 | @param u X coordinate (width) 12 | @param v Y coordinate (height) 13 | @param label_color tuple of 3 ints for RGB color used for drawing. 14 | """ 15 | # Cast to int. 16 | u = int(u) 17 | v = int(v) 18 | 19 | white = (255, 255, 255) 20 | cv2.circle(img, (u, v), 10, label_color, 1) 21 | cv2.circle(img, (u, v), 11, white, 1) 22 | cv2.circle(img, (u, v), 12, label_color, 1) 23 | cv2.line(img, (u, v + 1), (u, v + 3), white, 1) 24 | cv2.line(img, (u + 1, v), (u + 3, v), white, 1) 25 | cv2.line(img, (u, v - 1), (u, v - 3), white, 1) 26 | cv2.line(img, (u - 1, v), (u - 3, v), white, 1) 27 | 28 | 29 | def draw_text( 30 | img, 31 | *, 32 | text, 33 | uv_top_left, 34 | color=(255, 255, 255), 35 | fontScale=0.5, 36 | thickness=1, 37 | fontFace=cv2.FONT_HERSHEY_SIMPLEX, 38 | outline_color=(0, 0, 0), 39 | line_spacing=1.5, 40 | ): 41 | """ 42 | Draws multiline with an outline. 43 | """ 44 | assert isinstance(text, str) 45 | 46 | uv_top_left = np.array(uv_top_left, dtype=float) 47 | assert uv_top_left.shape == (2,) 48 | 49 | for line in text.splitlines(): 50 | (w, h), _ = cv2.getTextSize( 51 | text=line, 52 | fontFace=fontFace, 53 | fontScale=fontScale, 54 | thickness=thickness, 55 | ) 56 | uv_bottom_left_i = uv_top_left + [0, h] 57 | org = tuple(uv_bottom_left_i.astype(int)) 58 | 59 | if outline_color is not None: 60 | cv2.putText( 61 | img, 62 | text=line, 63 | org=org, 64 | fontFace=fontFace, 65 | fontScale=fontScale, 66 | color=outline_color, 67 | thickness=thickness * 3, 68 | lineType=cv2.LINE_AA, 69 | ) 70 | cv2.putText( 71 | img, 72 | text=line, 73 | org=org, 74 | fontFace=fontFace, 75 | fontScale=fontScale, 76 | color=color, 77 | thickness=thickness, 78 | lineType=cv2.LINE_AA, 79 | ) 80 | 81 | uv_top_left += [0, h * line_spacing] 82 | 83 | 84 | def get_image_transform( 85 | input_res: Tuple[int,int]=(1280,720), 86 | output_res: Tuple[int,int]=(640,480), 87 | bgr_to_rgb: bool=False): 88 | 89 | iw, ih = input_res 90 | ow, oh = output_res 91 | rw, rh = None, None 92 | interp_method = cv2.INTER_AREA 93 | 94 | if (iw/ih) >= (ow/oh): 95 | # input is wider 96 | rh = oh 97 | rw = math.ceil(rh / ih * iw) 98 | if oh > ih: 99 | interp_method = cv2.INTER_LINEAR 100 | else: 101 | rw = ow 102 | rh = math.ceil(rw / iw * ih) 103 | if ow > iw: 104 | interp_method = cv2.INTER_LINEAR 105 | 106 | w_slice_start = (rw - ow) // 2 107 | w_slice = slice(w_slice_start, w_slice_start + ow) 108 | h_slice_start = (rh - oh) // 2 109 | h_slice = slice(h_slice_start, h_slice_start + oh) 110 | c_slice = slice(None) 111 | if bgr_to_rgb: 112 | c_slice = slice(None, None, -1) 113 | 114 | def transform(img: np.ndarray): 115 | assert img.shape == ((ih,iw,3)) 116 | # resize 117 | img = 
cv2.resize(img, (rw, rh), interpolation=interp_method) 118 | # crop 119 | img = img[h_slice, w_slice, c_slice] 120 | return img 121 | return transform 122 | 123 | def optimal_row_cols( 124 | n_cameras, 125 | in_wh_ratio, 126 | max_resolution=(1920, 1080) 127 | ): 128 | out_w, out_h = max_resolution 129 | out_wh_ratio = out_w / out_h 130 | 131 | n_rows = np.arange(n_cameras,dtype=np.int64) + 1 132 | n_cols = np.ceil(n_cameras / n_rows).astype(np.int64) 133 | cat_wh_ratio = in_wh_ratio * (n_cols / n_rows) 134 | ratio_diff = np.abs(out_wh_ratio - cat_wh_ratio) 135 | best_idx = np.argmin(ratio_diff) 136 | best_n_row = n_rows[best_idx] 137 | best_n_col = n_cols[best_idx] 138 | best_cat_wh_ratio = cat_wh_ratio[best_idx] 139 | 140 | rw, rh = None, None 141 | if best_cat_wh_ratio >= out_wh_ratio: 142 | # cat is wider 143 | rw = math.floor(out_w / best_n_col) 144 | rh = math.floor(rw / in_wh_ratio) 145 | else: 146 | rh = math.floor(out_h / best_n_row) 147 | rw = math.floor(rh * in_wh_ratio) 148 | 149 | # crop_resolution = (rw, rh) 150 | return rw, rh, best_n_col, best_n_row 151 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/common/pose_repr_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_relative_pose(pos, rot, base_pos, base_rot_mat, 5 | rot_transformer_to_mat, 6 | rot_transformer_to_target, 7 | backward=False, 8 | delta=False): 9 | if not backward: 10 | # forward pass 11 | if not delta: 12 | output_pos = pos if base_pos is None else pos - base_pos 13 | output_rot = rot_transformer_to_target.forward( 14 | rot_transformer_to_mat.forward(rot) @ np.linalg.inv(base_rot_mat)) 15 | return output_pos, output_rot 16 | else: 17 | all_pos = np.concatenate([base_pos[None,...], pos], axis=0) 18 | output_pos = np.diff(all_pos, axis=0) 19 | 20 | rot_mat = rot_transformer_to_mat.forward(rot) 21 | all_rot_mat = np.concatenate([base_rot_mat[None,...], rot_mat], axis=0) 22 | prev_rot = np.linalg.inv(all_rot_mat[:-1]) 23 | curr_rot = all_rot_mat[1:] 24 | rot = np.matmul(curr_rot, prev_rot) 25 | output_rot = rot_transformer_to_target.forward(rot) 26 | return output_pos, output_rot 27 | 28 | else: 29 | # backward pass 30 | if not delta: 31 | output_pos = pos if base_pos is None else pos + base_pos 32 | output_rot = rot_transformer_to_mat.inverse( 33 | rot_transformer_to_target.inverse(rot) @ base_rot_mat) 34 | return output_pos, output_rot 35 | else: 36 | output_pos = np.cumsum(pos, axis=0) + base_pos 37 | 38 | rot_mat = rot_transformer_to_target.inverse(rot) 39 | output_rot_mat = np.zeros_like(rot_mat) 40 | curr_rot = base_rot_mat 41 | for i in range(len(rot_mat)): 42 | curr_rot = rot_mat[i] @ curr_rot 43 | output_rot_mat[i] = curr_rot 44 | output_rot = rot_transformer_to_mat.inverse(output_rot_mat) # accumulated absolute rotations back to the input representation 45 | return output_pos, output_rot 46 | 47 | 48 | def convert_pose_mat_rep(pose_mat, base_pose_mat, pose_rep='abs', backward=False): 49 | if not backward: 50 | # training transform 51 | if pose_rep == 'abs': 52 | return pose_mat 53 | elif pose_rep == 'rel': 54 | # legacy buggy implementation 55 | # for compatibility 56 | pos = pose_mat[...,:3,3] - base_pose_mat[:3,3] 57 | rot = pose_mat[...,:3,:3] @ np.linalg.inv(base_pose_mat[:3,:3]) 58 | out = np.copy(pose_mat) 59 | out[...,:3,:3] = rot 60 | out[...,:3,3] = pos 61 | return out 62 | elif pose_rep == 'relative': 63 | out = np.linalg.inv(base_pose_mat) @ pose_mat 64 | return out 65 | elif pose_rep == 'delta': 66 | all_pos = 
np.concatenate([base_pose_mat[None,:3,3], pose_mat[...,:3,3]], axis=0) 67 | out_pos = np.diff(all_pos, axis=0) 68 | 69 | all_rot_mat = np.concatenate([base_pose_mat[None,:3,:3], pose_mat[...,:3,:3]], axis=0) 70 | prev_rot = np.linalg.inv(all_rot_mat[:-1]) 71 | curr_rot = all_rot_mat[1:] 72 | out_rot = np.matmul(curr_rot, prev_rot) 73 | 74 | out = np.copy(pose_mat) 75 | out[...,:3,:3] = out_rot 76 | out[...,:3,3] = out_pos 77 | return out 78 | else: 79 | raise RuntimeError(f"Unsupported pose_rep: {pose_rep}") 80 | 81 | else: 82 | # eval transform 83 | if pose_rep == 'abs': 84 | return pose_mat 85 | elif pose_rep == 'rel': 86 | # legacy buggy implementation 87 | # for compatibility 88 | pos = pose_mat[...,:3,3] + base_pose_mat[:3,3] 89 | rot = pose_mat[...,:3,:3] @ base_pose_mat[:3,:3] 90 | out = np.copy(pose_mat) 91 | out[...,:3,:3] = rot 92 | out[...,:3,3] = pos 93 | return out 94 | elif pose_rep == 'relative': 95 | out = base_pose_mat @ pose_mat 96 | return out 97 | elif pose_rep == 'delta': 98 | output_pos = np.cumsum(pose_mat[...,:3,3], axis=0) + base_pose_mat[:3,3] 99 | 100 | output_rot_mat = np.zeros_like(pose_mat[...,:3,:3]) 101 | curr_rot = base_pose_mat[:3,:3] 102 | for i in range(len(pose_mat)): 103 | curr_rot = pose_mat[i,:3,:3] @ curr_rot 104 | output_rot_mat[i] = curr_rot 105 | 106 | out = np.copy(pose_mat) 107 | out[...,:3,:3] = output_rot_mat 108 | out[...,:3,3] = output_pos 109 | return out 110 | else: 111 | raise RuntimeError(f"Unsupported pose_rep: {pose_rep}") 112 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/config/task/correction_single_arm_xyzforce_only.yaml: -------------------------------------------------------------------------------- 1 | name: stow_residual_xyzforce_only 2 | 3 | # rgb vs. low_dim: raw data are either rgb images or low_dim vectors. frames are aligned within each type. 4 | # obs vs. action: obs data are used as policy input; action are used as labels for policy output. 5 | # dense vs. sparse: used for dense prediction vs. 
sparse prediction 6 | 7 | # down_sample_steps: how many steps to skip in the raw data for the given usage 8 | # horizon: how many steps to look ahead(action) or back(obs) after downsample for the given usage 9 | sparse_obs_rgb_down_sample_steps: 10 10 | sparse_obs_rgb_horizon: 1 11 | 12 | sparse_obs_low_dim_down_sample_steps: 5 13 | sparse_obs_low_dim_horizon: 3 14 | 15 | sparse_obs_wrench_down_sample_steps: 4 16 | sparse_obs_wrench_horizon: 32 17 | 18 | sparse_action_down_sample_steps: 10 19 | sparse_action_horizon: 5 20 | 21 | 22 | shape_meta: &shape_meta 23 | # acceptable types: rgb, low_dim 24 | # fields under raw and obs must be consistent with FlipUpDataset.raw_to_obs_action() 25 | id_list: [0] 26 | raw: # describes what exists in data 27 | rgb_0: 28 | shape: [3, 224, 224] 29 | type: rgb 30 | ts_pose_fb_0: 31 | shape: [7] 32 | type: low_dim 33 | ts_pose_command_0: 34 | shape: [7] 35 | type: low_dim 36 | policy_pose_command_0: 37 | shape: [7] 38 | type: low_dim 39 | ts_pose_virtual_target_0: 40 | shape: [7] 41 | type: low_dim 42 | stiffness_0: 43 | shape: [1] 44 | type: low_dim 45 | wrench_0: 46 | shape: [6] 47 | type: low_dim 48 | rgb_time_stamps_0: 49 | shape: [1] 50 | type: timestamp 51 | robot_time_stamps_0: 52 | shape: [1] 53 | type: timestamp 54 | wrench_time_stamps_0: 55 | shape: [1] 56 | type: timestamp 57 | policy_time_stamps_0: 58 | shape: [1] 59 | type: timestamp 60 | 61 | obs: # describes observations loaded to memory 62 | rgb_0: 63 | shape: [3, 224, 224] 64 | type: rgb 65 | robot0_eef_pos: 66 | shape: [3] 67 | type: low_dim 68 | robot0_eef_rot_axis_angle: 69 | shape: [6] 70 | type: low_dim 71 | rotation_rep: rotation_6d 72 | policy_robot0_eef_pos: 73 | shape: [3] 74 | type: low_dim 75 | policy_robot0_eef_rot_axis_angle: 76 | shape: [6] 77 | type: low_dim 78 | rotation_rep: rotation_6d 79 | robot0_eef_wrench: 80 | shape: [6] 81 | type: low_dim 82 | rgb_time_stamps_0: 83 | shape: [1] 84 | type: timestamp 85 | robot_time_stamps_0: 86 | shape: [1] 87 | type: timestamp 88 | policy_time_stamps_0: 89 | shape: [1] 90 | type: timestamp 91 | wrench_time_stamps_0: 92 | shape: [1] 93 | type: timestamp 94 | action: # describes actions loaded to memory, computed from robot command 95 | shape: [3] # xyz force only 96 | rotation_rep: rotation_6d 97 | sample: # describes samples used in a batch 98 | # keys here must exist in obs/action above. 99 | # shape, type and rotation_rep are inherited from obs/action above. 
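# For example, with the values above: policy_robot0_eef_pos keeps 12 consecutive policy pose
# commands (down_sample_steps 1), robot0_eef_pos keeps 3 poses spaced 5 raw steps apart, and
# the 3-dimensional action carries only the xyz force components noted in the action entry.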
100 | obs: 101 | sparse: 102 | rgb_0: 103 | horizon: ${task.sparse_obs_rgb_horizon} # int 104 | down_sample_steps: ${task.sparse_obs_rgb_down_sample_steps} # int 105 | robot0_eef_pos: 106 | horizon: ${task.sparse_obs_low_dim_horizon} # int 107 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 108 | robot0_eef_rot_axis_angle: # exists in data 109 | horizon: ${task.sparse_obs_low_dim_horizon} # int 110 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 111 | policy_robot0_eef_pos: 112 | horizon: 12 # int 113 | down_sample_steps: 1 # float 114 | policy_robot0_eef_rot_axis_angle: 115 | horizon: 12 # int 116 | down_sample_steps: 1 # float 117 | robot0_eef_wrench: 118 | horizon: ${task.sparse_obs_wrench_horizon} # int 119 | down_sample_steps: ${task.sparse_obs_wrench_down_sample_steps} # float 120 | action: 121 | sparse: 122 | horizon: ${task.sparse_action_horizon} 123 | down_sample_steps: ${task.sparse_action_down_sample_steps} # int 124 | training_duration_per_sparse_query: 200 # TODO (in ms) 125 | 126 | task_name: &task_name stow_residual 127 | dataset_path: ${oc.env:PYRITE_DATASET_FOLDERS}/correction 128 | 129 | env_runner: # used in workspace for computing metrics 130 | _target_: diffusion_policy.env_runner.real_pusht_image_runner.RealPushTImageRunner 131 | 132 | dataset: 133 | _target_: diffusion_policy.dataset.virtual_target_dataset.VirtualTargetDataset 134 | shape_meta: *shape_meta 135 | dataset_path: ${task.dataset_path} 136 | sparse_query_frequency_down_sample_steps: 2 # TODO 137 | # cache_dir: null 138 | action_padding: False 139 | temporally_independent_normalization: False 140 | seed: 42 141 | val_ratio: 0.05 142 | hack_linear_interpolated_dense_action: False 143 | normalize_wrench: False 144 | -------------------------------------------------------------------------------- /PyriteML/diffusion_policy/config/task/stow_no_force.yaml: -------------------------------------------------------------------------------- 1 | name: belt_angled_150 2 | 3 | # rgb vs. low_dim: raw data are either rgb images or low_dim vectors. frames are aligned within each type. 4 | # obs vs. action: obs data are used as policy input; action are used as labels for policy output. 5 | 6 | # down_sample_steps: how many steps to skip in the raw data for the given usage 7 | # horizon: how many steps to look ahead(action) or back(obs) after downsample for the given usage 8 | sparse_obs_rgb_down_sample_steps: 10 9 | sparse_obs_rgb_horizon: 2 10 | 11 | sparse_obs_low_dim_down_sample_steps: 5 12 | sparse_obs_low_dim_horizon: 3 13 | 14 | sparse_action_down_sample_steps: 50 15 | sparse_action_horizon: 32 16 | 17 | # The following parameter is used to avoid action padding when sampling queries. 18 | # It is the duration of a full action prediction horizon in ms. 19 | # This is needed since my sampler uses timestamp instead of index to align data. 
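# For example, with sparse_action_horizon: 32 and sparse_action_down_sample_steps: 50 above,
# one query spans 32 * 50 = 1600 raw robot steps; the value below covers that span in wall-clock time.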
20 | action_horizon_duration_buffer_ms: 3300 21 | 22 | shape_meta: &shape_meta 23 | # acceptable types: rgb, low_dim 24 | # fields under raw and obs must be consistent with FlipUpDataset.raw_to_obs_action() 25 | id_list: [0] 26 | raw: # describes what exists in data 27 | rgb_0: 28 | shape: [3, 224, 224] 29 | type: rgb 30 | ts_pose_fb_0: 31 | shape: [7] 32 | type: low_dim 33 | ts_pose_command_0: 34 | shape: [7] 35 | type: low_dim 36 | ts_pose_virtual_target_0: 37 | shape: [7] 38 | type: low_dim 39 | stiffness_0: 40 | shape: [1] 41 | type: low_dim 42 | wrench_0: 43 | shape: [6] 44 | type: low_dim 45 | rgb_time_stamps_0: 46 | shape: [1] 47 | type: timestamp 48 | robot_time_stamps_0: 49 | shape: [1] 50 | type: timestamp 51 | wrench_time_stamps_0: 52 | shape: [1] 53 | type: timestamp 54 | 55 | obs: # describes observations loaded to memory 56 | rgb_0: 57 | shape: [3, 224, 224] 58 | type: rgb 59 | robot0_eef_pos: 60 | shape: [3] 61 | type: low_dim 62 | robot0_eef_rot_axis_angle: 63 | shape: [6] 64 | type: low_dim 65 | rotation_rep: rotation_6d 66 | # robot0_abs_eef_rot_axis_angle: 67 | # shape: [6] 68 | # type: low_dim 69 | # rotation_rep: rotation_6d 70 | rgb_time_stamps_0: 71 | shape: [1] 72 | type: timestamp 73 | robot_time_stamps_0: 74 | shape: [1] 75 | type: timestamp 76 | action: # describes actions loaded to memory, computed from robot command 77 | shape: [9] # 9 for reference pose, 9 for virtual target, 1 for stiffness 78 | rotation_rep: rotation_6d 79 | sample: # describes samples used in a batch 80 | # keys here must exist in obs/action above. 81 | # shape, type and rotation_rep are inherited from obs/action above. 82 | obs: 83 | sparse: 84 | rgb_0: 85 | horizon: ${task.sparse_obs_rgb_horizon} # int 86 | down_sample_steps: ${task.sparse_obs_rgb_down_sample_steps} # int 87 | robot0_eef_pos: 88 | horizon: ${task.sparse_obs_low_dim_horizon} # int 89 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 90 | robot0_eef_rot_axis_angle: # exists in data 91 | horizon: ${task.sparse_obs_low_dim_horizon} # int 92 | down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 93 | # robot0_abs_eef_rot_axis_angle: # exists in data 94 | # horizon: ${task.sparse_obs_low_dim_horizon} # int 95 | # down_sample_steps: ${task.sparse_obs_low_dim_down_sample_steps} # float 96 | action: 97 | sparse: 98 | horizon: ${task.sparse_action_horizon} 99 | down_sample_steps: ${task.sparse_action_down_sample_steps} # int 100 | training_duration_per_sparse_query: ${task.action_horizon_duration_buffer_ms} 101 | 102 | 103 | task_name: &task_name 104 | # dataset_path: ${oc.env:PYRITE_DATASET_FOLDERS}/stow_new_blade_dagger1 # offline dagger data 105 | # dataset_path: ${oc.env:PYRITE_DATASET_FOLDERS}/online_stow_nb_v5_50/processed_no_correction # online correction 1st batch data 106 | # dataset_path: ${oc.env:PYRITE_DATASET_FOLDERS}/finetune_online_batch10/processed 107 | # dataset_path: ${oc.env:PYRITE_DATASET_FOLDERS}/belt_assembly 108 | # dataset_path: ${oc.env:PYRITE_DATASET_FOLDERS}/belt_assembly_offline_50_total_190 109 | dataset_path: ${oc.env:PYRITE_DATASET_FOLDERS}/belt_angled_150 110 | 111 | env_runner: # used in workspace for computing metrics 112 | _target_: diffusion_policy.env_runner.real_pusht_image_runner.RealPushTImageRunner 113 | 114 | dataset: 115 | _target_: diffusion_policy.dataset.virtual_target_dataset.VirtualTargetDataset 116 | shape_meta: *shape_meta 117 | dataset_path: ${task.dataset_path} 118 | sparse_query_frequency_down_sample_steps: 8 119 | 
hack_linear_interpolated_dense_action: False 120 | # cache_dir: null 121 | action_padding: True 122 | temporally_independent_normalization: False 123 | seed: 42 124 | val_ratio: 0.05 125 | normalize_wrench: False 126 | # weighted_sampling: 1 # if > 1, duplicate the correction sample's ids weighted_sampling times 127 | # correction_horizon: 10 # the extra action horizon sampled before and after the "start of correction" --------------------------------------------------------------------------------
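The task YAML files above are written for Hydra-style instantiation: each dataset block carries a _target_ class path, and values such as ${task.sparse_action_horizon} or ${oc.env:PYRITE_DATASET_FOLDERS} are OmegaConf interpolations. The following is a minimal sketch of how one such file could be resolved and its dataset built in isolation. It assumes the PyriteML package plus hydra-core and omegaconf are installed; the "task" wrapper key, the config path, and the environment-variable value are illustrative assumptions, not code taken from this repository.

import os

from omegaconf import OmegaConf
import hydra.utils

# ${oc.env:PYRITE_DATASET_FOLDERS} in dataset_path reads this environment variable.
os.environ.setdefault("PYRITE_DATASET_FOLDERS", "/tmp/pyrite_datasets")  # assumed location

# Load one task config; its interpolations are written as ${task.*}, so wrap the
# node under a "task" key before resolving (an assumption about how the training
# entry point composes these files).
task_cfg = OmegaConf.load("PyriteML/diffusion_policy/config/task/stow_conv.yaml")
root = OmegaConf.create({"task": task_cfg})
OmegaConf.resolve(root)

# dataset._target_ names VirtualTargetDataset; the remaining keys become its kwargs.
dataset = hydra.utils.instantiate(root.task.dataset)
# If the dataset follows the usual torch Dataset API, len(dataset) gives the
# number of sparse queries available for training.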