├── envs ├── __init__.py ├── xarm-env │ ├── xarm_env │ │ ├── envs │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ └── robot_env.py │ │ └── __init__.py │ └── setup.py └── constants.py ├── cfgs ├── suite │ ├── task │ │ └── xarm_env.yaml │ └── xarm_env_aa.yaml ├── dataloader │ └── xarm_env_aa.yaml ├── agent │ └── bc.yaml ├── config.yaml └── config_eval.yaml ├── requirements.txt ├── .pre-commit-config.yaml ├── env.yml ├── replay_buffer.py ├── LICENSE ├── README.md ├── agent ├── networks │ ├── policy_head.py │ ├── mlp.py │ ├── rgb_modules.py │ └── gpt.py └── bc.py ├── video.py ├── .gitignore ├── convert_to_pkl.py ├── logger.py ├── eval.py ├── utils.py ├── train_bc.py ├── suite └── xarm_env.py ├── process_xarm_data.py └── read_data └── xarm_env_aa.py /envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /envs/xarm-env/xarm_env/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from xarm_env.envs.robot_env import RobotEnv 2 | -------------------------------------------------------------------------------- /envs/xarm-env/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="xarm_env", version="0.0.1", install_requires=["gym"]) 4 | -------------------------------------------------------------------------------- /cfgs/suite/task/xarm_env.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | suite: xarm_env 5 | tasks: 6 | # - 0529_norot_ballincup #y 7 | - 0802_insert_plug_triangle 8 | -------------------------------------------------------------------------------- /envs/xarm-env/xarm_env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id="Robot-v1", 5 | entry_point="xarm_env.envs:RobotEnv", 6 | max_episode_steps=400, # 200, 7 | ) 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.9.0.80 2 | ipdb 3 | h5py 4 | wandb 5 | termcolor 6 | tensorboard 7 | hydra-core 8 | hydra-submitit-launcher 9 | gym 10 | dm-env 11 | imageio 12 | imageio-ffmpeg 13 | tqdm 14 | scipy 15 | blosc 16 | einops 17 | zmq 18 | -------------------------------------------------------------------------------- /envs/constants.py: -------------------------------------------------------------------------------- 1 | HOST_ADDRESS = "10.19.216.156" 2 | CAMERA_HOST_ADDRESS = "10.19.216.156" 3 | DEPLOYMENT_PORT = 10000 4 | 5 | # camera 6 | CAMERA_PORT_OFFSET = 10005 7 | CAM_SERIAL_NUMS = { 8 | 1: "239122072217", 9 | 2: "023322062082", 10 | 3: "233522071078", 11 | 4: "141722076049", 12 | } 13 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 23.7.0 10 | hooks: 11 | - id: black 12 | - repo: https://github.com/hadialqattan/pycln 13 | rev: v2.2.2 14 | hooks: 15 | - id: pycln 16 | 
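The hooks above (check-yaml, end-of-file-fixer, trailing-whitespace, black, and pycln) only run once pre-commit is registered with the git checkout. A minimal sketch of the usual workflow, assuming the `pre-commit` package from `env.yml` is installed in the active environment:

```
pre-commit install          # register the hooks for this clone; they then run on every commit
pre-commit run --all-files  # lint/format the entire tree on demand
```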
-------------------------------------------------------------------------------- /envs/xarm-env/xarm_env/envs/constants.py: -------------------------------------------------------------------------------- 1 | HOST_ADDRESS = "10.19.216.156" 2 | CAMERA_HOST_ADDRESS = "10.19.216.156" 3 | DEPLOYMENT_PORT = 10000 4 | 5 | # camera 6 | CAMERA_PORT_OFFSET = 10005 7 | CAM_SERIAL_NUMS = { 8 | 1: "239122072217", 9 | 2: "023322062082", 10 | 3: "233522071078", 11 | 4: "141722076049", 12 | } 13 | 14 | FISH_EYE_CAMERA_PORT_OFFSET = 10010 15 | 16 | FISH_EYE_CAM_SERIAL_NUMS = {51: "18", 52: "26"} 17 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: visuoskin 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | 8 | dependencies: 9 | - python=3.11 10 | - numpy 11 | - pre-commit 12 | - pip 13 | - pytorch=2.1 14 | - torchvision 15 | - pytorch-cuda=11.8 16 | - pip: 17 | - ipdb 18 | - h5py 19 | - termcolor 20 | - tensorboard 21 | - hydra-core 22 | - hydra-submitit-launcher 23 | - dm-env 24 | - imageio 25 | -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | def _worker_init_fn(worker_id): 7 | # seed = np.random.get_state()[1][0] + worker_id 8 | # np.random.seed(seed) 9 | # random.seed(seed) 10 | np.random.seed(worker_id) 11 | random.seed(worker_id) 12 | 13 | 14 | def make_expert_replay_loader(iterable, batch_size): 15 | loader = torch.utils.data.DataLoader( 16 | iterable, 17 | batch_size=batch_size, 18 | num_workers=16, 19 | pin_memory=True, 20 | worker_init_fn=_worker_init_fn, 21 | ) 22 | return loader 23 | -------------------------------------------------------------------------------- /cfgs/dataloader/xarm_env_aa.yaml: -------------------------------------------------------------------------------- 1 | bc_dataset: 2 | _target_: read_data.xarm_env_aa.BCDataset 3 | path: ${root_dir}/processed_data_pkl_aa 4 | tasks: ${suite.task.tasks} 5 | num_demos_per_task: ${num_demos_per_task} 6 | temporal_agg: ${temporal_agg} 7 | num_queries: ${num_queries} 8 | img_size: ${img_size} 9 | action_after_steps: ${suite.action_after_steps} 10 | store_actions: true 11 | pixel_keys: ${suite.pixel_keys} 12 | aux_keys: ${suite.aux_keys} 13 | subsample: 5 14 | skip_first_n: 0 15 | relative_actions: true 16 | random_mask_proprio: false 17 | sensor_params: ${suite.sensor_params} 18 | -------------------------------------------------------------------------------- /cfgs/agent/bc.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.bc.BCAgent 3 | obs_shape: ??? # to be specified later 4 | action_shape: ??? 
# to be specified later 5 | device: ${device} 6 | lr: 1e-4 #1e-5 #1e-4 7 | hidden_dim: ${suite.hidden_dim} 8 | stddev_schedule: 0.1 9 | stddev_clip: 0.3 10 | use_tb: ${use_tb} 11 | augment: True 12 | encoder_type: ${encoder_type} 13 | policy_type: ${policy_type} 14 | policy_head: ${policy_head} 15 | pixel_keys: ${suite.pixel_keys} 16 | aux_keys: ${suite.aux_keys} 17 | use_aux_inputs: ${use_aux_inputs} 18 | train_encoder: true 19 | norm: false 20 | separate_encoders: false # have a separate encoder for each pixel key 21 | temporal_agg: ${temporal_agg} 22 | max_episode_len: ${suite.task_make_fn.max_episode_len} # to be specified later 23 | num_queries: ${num_queries} 24 | use_actions: ${use_actions} 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Raunaq Bhirangi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /cfgs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - local_config 3 | - _self_ 4 | - agent: bc 5 | - dataloader: xarm_env_aa 6 | - suite: xarm_env_aa 7 | - override hydra/launcher: submitit_local 8 | 9 | # replay buffer 10 | batch_size: 256 11 | # misc 12 | seed: 0 13 | device: cuda 14 | save_video: true 15 | save_train_video: false 16 | use_tb: true 17 | 18 | # experiment 19 | num_demos_per_task: 5000 20 | encoder_type: 'resnet' # base, patch, resnet 21 | policy_type: 'gpt' # mlp, gpt 22 | img_size: 128 23 | policy_head: deterministic # deterministic, gmm, bet, diffusion, vqbet 24 | use_aux_inputs: true 25 | use_language: false 26 | use_actions: false 27 | sequential_train: false 28 | eval: false 29 | experiment: ${suite.name}_bc 30 | experiment_label: ${policy_head} 31 | 32 | # expert dataset 33 | num_demos: null #10(dmc), 1(metaworld), 1(particle), 1(robotgym) 34 | expert_dataset: ${dataloader.bc_dataset} 35 | 36 | # Load weights 37 | load_bc: false 38 | bc_weight: null 39 | 40 | # Action chunking parameters 41 | temporal_agg: true 42 | num_queries: 10 43 | 44 | # TODO: Fix this 45 | max_episode_len: 1000 46 | 47 | hydra: 48 | job: 49 | chdir: true 50 | run: 51 | dir: ./exp_local/${now:%Y.%m.%d}_${experiment}/${now:%H%M%S}_${experiment_label} 52 | sweep: 53 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S} 54 | subdir: ${hydra.job.num} 55 | -------------------------------------------------------------------------------- /cfgs/suite/xarm_env_aa.yaml: -------------------------------------------------------------------------------- 1 | # @package suite 2 | defaults: 3 | - _self_ 4 | - task: xarm_env 5 | 6 | suite: xarm_env 7 | name: "xarm_env" 8 | 9 | # task settings 10 | frame_stack: 1 11 | action_repeat: 1 12 | discount: 0.99 13 | hidden_dim: 256 14 | action_type: continuous 15 | 16 | # train settings 17 | num_train_steps: 10000 18 | log_every_steps: 100 19 | save_every_steps: 500 20 | num_train_steps_per_task: 20000 #2500 #5000 21 | 22 | # eval 23 | eval_every_steps: 200000 #20000 #5000 24 | num_eval_episodes: 5 25 | 26 | # data loading 27 | action_after_steps: 1 #8 28 | 29 | # obs_keys 30 | pixel_keys: ["pixels1", "pixels2", "pixels51", "pixels52"] 31 | aux_keys: ["sensor0","sensor1"] 32 | # aux_keys: ["digit80","digit81"] 33 | # aux_keys: ["sensor"] 34 | feature_key: "proprioceptive" 35 | sensor_params: 36 | sensor_type: reskin 37 | subtract_sensor_baseline: true 38 | 39 | # snapshot 40 | save_snapshot: true 41 | 42 | task_make_fn: 43 | _target_: suite.xarm_env.make 44 | frame_stack: ${suite.frame_stack} 45 | action_repeat: ${suite.action_repeat} 46 | height: ${img_size} 47 | width: ${img_size} 48 | max_episode_len: ??? # to be specified later 49 | max_state_dim: ??? 
# to be specified later 50 | use_egocentric: true 51 | use_fisheye: true 52 | task_description: "just training" 53 | pixel_keys: ${suite.pixel_keys} 54 | aux_keys: ${suite.aux_keys} 55 | sensor_params: ${suite.sensor_params} 56 | eval: ${eval} # eval true mean use robot 57 | -------------------------------------------------------------------------------- /cfgs/config_eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - local_config 3 | - _self_ 4 | - agent: bc 5 | - dataloader: xarm_env_aa 6 | - suite: xarm_env_aa 7 | # - override hydra/launcher: submitit_local 8 | 9 | # replay buffer 10 | batch_size: 256 11 | # misc 12 | seed: 0 13 | device: cuda 14 | save_video: true 15 | save_train_video: false 16 | use_tb: true 17 | eval: true 18 | 19 | # experiment 20 | num_demos_per_task: 5000 21 | encoder_type: 'resnet' # base, patch, resnet 22 | policy_type: 'gpt' # mlp, gpt 23 | policy_head: deterministic # deterministic, gmm, bet, diffusion, vqbet 24 | use_aux_inputs: true 25 | use_language: false 26 | use_actions: false 27 | sequential_train: false 28 | irl: false 29 | experiment: ${suite.name}_eval 30 | experiment_label: ${policy_head} 31 | 32 | # expert dataset 33 | num_demos: null #10(dmc), 1(metaworld), 1(particle), 1(robotgym) 34 | expert_dataset: ${dataloader.bc_dataset} 35 | 36 | # Load weights 37 | load_bc: false 38 | bc_weight: null 39 | 40 | # Action chunking parameters 41 | temporal_agg: true 42 | num_queries: 10 43 | 44 | # TODO: Fix this 45 | max_episode_len: 1000 46 | 47 | hydra: 48 | job: 49 | chdir: true 50 | run: 51 | dir: ./exp_local/${now:%Y.%m.%d}_${experiment}/${now:%H%M%S}_${experiment_label} 52 | sweep: 53 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S} 54 | subdir: ${hydra.job.num} 55 | # launcher: 56 | # tasks_per_node: 1 57 | # nodes: 1 58 | # submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${experiment}/.slurm 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visuo-Skin (ViSk) 2 | Accompanying code for training VisuoSkin policies as described in the paper: 3 | [Learning Precise, Contact-Rich Manipulation through Uncalibrated Tactile Skins](https://visuoskin.github.io) 4 |
5 | 6 | fig1 7 |
8 | 9 | ## About 10 | 11 | ViSk is a framework for learning visuotactile policies for fine-grained, contact-rich manipulation tasks. ViSk uses a transformer-based architecture in conjunction with [AnySkin](https://any-skin.github.io) and presents a significant improvement over vision-only policies as well as visuotactile policies that use high-dimensional tactile sensors like DIGIT. 12 | 13 | ## Installation 14 | 1. Clone this repository 15 | ``` 16 | git clone https://github.com/raunaqbhirangi/visuoskin.git 17 | ``` 18 | 19 | 2. Create a conda environment and install dependencies 20 | 21 | ``` 22 | conda env create -f env.yml 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | 3. Move raw data to your desired location and set `DATA_DIR` in `utils.py` to point to this location. Similarly, set `root_dir` in `cfgs/local_config.yaml`. 27 | 28 | 4. Process data for `current-task` (the name of the directory containing demonstration data for the current task) and convert it to pkl. 29 | 30 | ``` 31 | python process_xarm_data.py -t current-task 32 | python convert_to_pkl.py -t current-task 33 | ``` 34 | 5. Install `xarm-env` using `pip install -e envs/xarm-env` 36 | 6. Run BC training 37 | ``` 38 | python train_bc.py 'suite.task.tasks=[current-task]' 39 | ``` 40 | -------------------------------------------------------------------------------- /agent/networks/policy_head.py: -------------------------------------------------------------------------------- 1 | import einops 2 | 3 | # import robomimic.utils.tensor_utils as TensorUtils 4 | import torch 5 | import torch.distributions as D 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import utils 10 | 11 | # from agent.networks.utils.diffusion_policy import DiffusionPolicy 12 | # from agent.networks.utils.vqbet.pretrain_vqvae import init_vqvae, pretrain_vqvae 13 | from agent.networks.mlp import MLP 14 | 15 | ######################################### Deterministic Head ######################################### 16 | 17 | 18 | class DeterministicHead(nn.Module): 19 | def __init__( 20 | self, 21 | input_size, 22 | output_size, 23 | hidden_size=1024, 24 | num_layers=2, 25 | action_squash=True, 26 | loss_coef=1.0, 27 | ): 28 | super().__init__() 29 | self.loss_coef = loss_coef 30 | 31 | sizes = [input_size] + [hidden_size] * num_layers + [output_size] 32 | layers = [] 33 | for i in range(num_layers): 34 | layers += [nn.Linear(sizes[i], sizes[i + 1]), nn.ReLU()] 35 | layers += [nn.Linear(sizes[-2], sizes[-1])] 36 | 37 | if action_squash: 38 | layers += [nn.Tanh()] 39 | 40 | self.net = nn.Sequential(*layers) 41 | 42 | def forward(self, x, stddev=None, **kwargs): 43 | mu = self.net(x) 44 | std = stddev if stddev is not None else 0.1 45 | std = torch.ones_like(mu) * std 46 | dist = utils.TruncatedNormal(mu, std) 47 | return dist 48 | 49 | def loss_fn(self, dist, target, reduction="mean", **kwargs): 50 | log_probs = dist.log_prob(target) 51 | loss = -log_probs 52 | 53 | if reduction == "mean": 54 | loss = loss.mean() * self.loss_coef 55 | elif reduction == "none": 56 | loss = loss * self.loss_coef 57 | elif reduction == "sum": 58 | loss = loss.sum() * self.loss_coef 59 | else: 60 | raise NotImplementedError 61 | 62 | return { 63 | "actor_loss": loss, 64 | } 65 | -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | 5 | 6 | class VideoRecorder: 7 | def
__init__(self, root_dir, render_size=256, fps=20): 8 | if root_dir is not None: 9 | self.save_dir = root_dir / "eval_video" 10 | self.save_dir.mkdir(exist_ok=True) 11 | else: 12 | self.save_dir = None 13 | 14 | self.render_size = render_size 15 | self.fps = fps 16 | self.frames = [] 17 | 18 | def init(self, env, enabled=True): 19 | self.frames = [] 20 | self.enabled = self.save_dir is not None and enabled 21 | self.record(env) 22 | 23 | def record(self, env): 24 | if self.enabled: 25 | if hasattr(env, "physics"): 26 | frame = env.physics.render( 27 | height=self.render_size, width=self.render_size, camera_id=0 28 | ) 29 | else: 30 | frame = env.render() 31 | self.frames.append(frame) 32 | 33 | def save(self, file_name): 34 | if self.enabled: 35 | path = self.save_dir / file_name 36 | imageio.mimsave(str(path), self.frames, fps=self.fps) 37 | 38 | 39 | class TrainVideoRecorder: 40 | def __init__(self, root_dir, render_size=256, fps=20): 41 | if root_dir is not None: 42 | self.save_dir = root_dir / "train_video" 43 | self.save_dir.mkdir(exist_ok=True) 44 | else: 45 | self.save_dir = None 46 | 47 | self.render_size = render_size 48 | self.fps = fps 49 | self.frames = [] 50 | 51 | def init(self, obs, enabled=True): 52 | self.frames = [] 53 | self.enabled = self.save_dir is not None and enabled 54 | self.record(obs) 55 | 56 | def record(self, obs): 57 | if self.enabled: 58 | frame = cv2.resize( 59 | obs[-3:].transpose(1, 2, 0), 60 | dsize=(self.render_size, self.render_size), 61 | interpolation=cv2.INTER_CUBIC, 62 | ) 63 | self.frames.append(frame) 64 | 65 | def save(self, file_name): 66 | if self.enabled: 67 | path = self.save_dir / file_name 68 | imageio.mimsave(str(path), self.frames, fps=self.fps) 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | exp_local/ 163 | test_env.py 164 | create_robosuite_configs.py 165 | expert_demos/ 166 | logs/ 167 | outputs/ 168 | data/ 169 | preprocessing_output/ 170 | snapshot/ 171 | tb/ 172 | train.csv 173 | eval.csv 174 | *.png 175 | exp*/ 176 | *.mp4 177 | cfgs/local_config.yaml 178 | -------------------------------------------------------------------------------- /agent/networks/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Callable, List, Optional 3 | 4 | 5 | class MLP(torch.nn.Sequential): 6 | """This block implements the multi-layer perceptron (MLP) module. 7 | Adapted for backward compatibility from the torchvision library: 8 | https://pytorch.org/vision/0.14/generated/torchvision.ops.MLP.html 9 | 10 | LICENSE: 11 | 12 | From PyTorch: 13 | 14 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 15 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 16 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 17 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 18 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 19 | Copyright (c) 2011-2013 NYU (Clement Farabet) 20 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 21 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 22 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 23 | 24 | From Caffe2: 25 | 26 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 27 | 28 | All contributions by Facebook: 29 | Copyright (c) 2016 Facebook Inc. 30 | 31 | All contributions by Google: 32 | Copyright (c) 2015 Google Inc. 33 | All rights reserved. 34 | 35 | All contributions by Yangqing Jia: 36 | Copyright (c) 2015 Yangqing Jia 37 | All rights reserved. 38 | 39 | All contributions by Kakao Brain: 40 | Copyright 2019-2020 Kakao Brain 41 | 42 | All contributions by Cruise LLC: 43 | Copyright (c) 2022 Cruise LLC. 44 | All rights reserved. 45 | 46 | All contributions from Caffe: 47 | Copyright(c) 2013, 2014, 2015, the respective contributors 48 | All rights reserved. 49 | 50 | All other contributions: 51 | Copyright(c) 2015, 2016 the respective contributors 52 | All rights reserved. 53 | 54 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 55 | copyright over their contributions to Caffe2. The project versioning records 56 | all such contribution and copyright details. If a contributor wants to further 57 | mark their specific copyright on a particular contribution, they should 58 | indicate their copyright solely in the commit message of the change when it is 59 | committed. 60 | 61 | All rights reserved. 62 | 63 | Redistribution and use in source and binary forms, with or without 64 | modification, are permitted provided that the following conditions are met: 65 | 66 | 1. Redistributions of source code must retain the above copyright 67 | notice, this list of conditions and the following disclaimer. 68 | 69 | 2. Redistributions in binary form must reproduce the above copyright 70 | notice, this list of conditions and the following disclaimer in the 71 | documentation and/or other materials provided with the distribution. 72 | 73 | 3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 74 | and IDIAP Research Institute nor the names of its contributors may be 75 | used to endorse or promote products derived from this software without 76 | specific prior written permission. 77 | 78 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 79 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 80 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 81 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 82 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 83 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 84 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 85 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 86 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 87 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 88 | POSSIBILITY OF SUCH DAMAGE. 89 | 90 | 91 | Args: 92 | in_channels (int): Number of channels of the input 93 | hidden_channels (List[int]): List of the hidden channel dimensions 94 | norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None`` 95 | activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` 96 | inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place. 97 | Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer. 98 | bias (bool): Whether to use bias in the linear layer. Default ``True`` 99 | dropout (float): The probability for the dropout layer. 
Default: 0.0 100 | """ 101 | 102 | def __init__( 103 | self, 104 | in_channels: int, 105 | hidden_channels: List[int], 106 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, 107 | inplace: Optional[bool] = None, 108 | bias: bool = True, 109 | dropout: float = 0.0, 110 | ): 111 | params = {} if inplace is None else {"inplace": inplace} 112 | 113 | layers = [] 114 | in_dim = in_channels 115 | for hidden_dim in hidden_channels[:-1]: 116 | layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) 117 | layers.append(activation_layer(**params)) 118 | layers.append(torch.nn.Dropout(dropout, **params)) 119 | in_dim = hidden_dim 120 | 121 | layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias)) 122 | layers.append(torch.nn.Dropout(dropout, **params)) 123 | 124 | super().__init__(*layers) 125 | -------------------------------------------------------------------------------- /convert_to_pkl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py as h5 3 | import numpy as np 4 | from pandas import read_csv 5 | import pickle as pkl 6 | import cv2 7 | from pathlib import Path 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | from utils import DATA_DIR 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--task-name", "-t", type=str, required=True) 14 | args = parser.parse_args() 15 | 16 | TASK_NAME = args.task_name 17 | 18 | PROCESSED_DATA_PATH = Path(DATA_DIR) / "processed_data/" 19 | SAVE_DATA_PATH = Path(DATA_DIR) / "processed_data_pkl_aa/" 20 | 21 | camera_indices = [1, 2, 51, 52] 22 | img_size = (128, 128) 23 | NUM_DEMOS = None 24 | 25 | # Create the save path 26 | SAVE_DATA_PATH.mkdir(parents=True, exist_ok=True) 27 | 28 | DATASET_PATH = Path(f"{PROCESSED_DATA_PATH}/{TASK_NAME}") 29 | 30 | if (SAVE_DATA_PATH / f"{TASK_NAME}.pkl").exists(): 31 | print(f"Data for {TASK_NAME} already exists. 
Appending to it...") 32 | input("Press Enter to continue...") 33 | data = pkl.load(open(SAVE_DATA_PATH / f"{TASK_NAME}.pkl", "rb")) 34 | observations = data["observations"] 35 | max_cartesian = data["max_cartesian"] 36 | min_cartesian = data["min_cartesian"] 37 | max_gripper = data["max_gripper"] 38 | min_gripper = data["min_gripper"] 39 | else: 40 | # Init storing variables 41 | observations = [] 42 | 43 | # Store max and min 44 | max_cartesian, min_cartesian = None, None 45 | max_sensor, min_sensor = None, None 46 | # max_rel_cartesian, min_rel_cartesian = None, None 47 | max_gripper, min_gripper = None, None 48 | 49 | # Load each data point and save in a list 50 | dirs = [x for x in DATASET_PATH.iterdir() if x.is_dir()] 51 | for i, data_point in enumerate(sorted(dirs)): 52 | use_sensor = True 53 | print(f"Processing data point {i+1}/{len(dirs)}") 54 | 55 | if NUM_DEMOS is not None: 56 | if int(str(data_point).split("_")[-1]) >= NUM_DEMOS: 57 | print(f"Skipping data point {data_point}") 58 | continue 59 | 60 | observation = {} 61 | # images 62 | image_dir = data_point / "videos" 63 | if not image_dir.exists(): 64 | print(f"Data point {data_point} is incomplete") 65 | continue 66 | for save_idx, idx in enumerate(camera_indices): 67 | # Read the frames in the video 68 | video_path = image_dir / f"camera{idx}.mp4" 69 | cap = cv2.VideoCapture(str(video_path)) 70 | if not cap.isOpened(): 71 | print(f"Video {video_path} could not be opened") 72 | continue 73 | frames = [] 74 | while True: 75 | ret, frame = cap.read() 76 | if not ret: 77 | break 78 | if idx == 52: 79 | # crop the right side of the image for the gripper cam 80 | shape = frame.shape 81 | crop_percent = 0.2 82 | frame = frame[:, : int(shape[1] * (1 - crop_percent))] 83 | frame = cv2.resize(frame, img_size) 84 | frames.append(frame) 85 | if idx < 80: 86 | observation[f"pixels{idx}"] = np.array(frames) 87 | else: 88 | observation[f"digit{idx}"] = np.array(frames) 89 | # read cartesian and gripper states from csv 90 | state_csv_path = data_point / "states.csv" 91 | sensor_csv_path = data_point / "sensor.csv" 92 | state = read_csv(state_csv_path) 93 | try: 94 | sensor_data = read_csv(sensor_csv_path) 95 | sensor_states = sensor_data["sensor_values"].values 96 | sensor_states = np.array( 97 | [ 98 | np.array([float(x.strip()) for x in sensor[1:-1].split(",")]) 99 | for sensor in sensor_states 100 | ], 101 | dtype=np.float32, 102 | ) 103 | except FileNotFoundError: 104 | use_sensor = False 105 | print(f"Sensor data not found for {data_point}") 106 | 107 | # Read cartesian state where every element is a 6D pose 108 | # Separate the pose into values instead of string 109 | cartesian_states = state["pose_aa"].values 110 | cartesian_states = np.array( 111 | [ 112 | np.array([float(x.strip()) for x in pose[1:-1].split(",")]) 113 | for pose in cartesian_states 114 | ], 115 | dtype=np.float32, 116 | ) 117 | 118 | gripper_states = state["gripper_state"].values.astype(np.float32) 119 | observation["cartesian_states"] = cartesian_states.astype(np.float32) 120 | observation["gripper_states"] = gripper_states.astype(np.float32) 121 | if use_sensor: 122 | observation["sensor_states"] = sensor_states.astype(np.float32) 123 | if max_sensor is None: 124 | max_sensor = np.max(sensor_states) 125 | min_sensor = np.min(sensor_states) 126 | else: 127 | max_sensor = np.maximum(max_sensor, np.max(sensor_states)) 128 | min_sensor = np.minimum(min_sensor, np.min(sensor_states)) 129 | max_sensor = np.max(sensor_states, axis=0) 130 | min_sensor = 
np.min(sensor_states, axis=0) 131 | 132 | # update max and min 133 | if max_cartesian is None: 134 | max_cartesian = np.max(cartesian_states, axis=0) 135 | min_cartesian = np.min(cartesian_states, axis=0) 136 | else: 137 | max_cartesian = np.maximum(max_cartesian, np.max(cartesian_states, axis=0)) 138 | min_cartesian = np.minimum(min_cartesian, np.min(cartesian_states, axis=0)) 139 | if max_gripper is None: 140 | max_gripper = np.max(gripper_states) 141 | min_gripper = np.min(gripper_states) 142 | else: 143 | max_gripper = np.maximum(max_gripper, np.max(gripper_states)) 144 | min_gripper = np.minimum(min_gripper, np.min(gripper_states)) 145 | 146 | # append to observations 147 | observations.append(observation) 148 | 149 | # Save the data 150 | data = { 151 | "observations": observations, 152 | "max_cartesian": max_cartesian, 153 | "min_cartesian": min_cartesian, 154 | "max_gripper": max_gripper, 155 | "min_gripper": min_gripper, 156 | "max_sensor": max_sensor, 157 | "min_sensor": min_sensor, 158 | } 159 | pkl.dump(data, open(SAVE_DATA_PATH / f"{TASK_NAME}.pkl", "wb")) 160 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from termcolor import colored 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | BC_TRAIN_FORMAT = [ 12 | ("step", "S", "int"), 13 | ("actor_loss", "L", "float"), 14 | ("total_time", "T", "time"), 15 | ] 16 | BC_EVAL_FORMAT = [ 17 | ("frame", "F", "int"), 18 | ("step", "S", "int"), 19 | ("episode", "E", "int"), 20 | ("episode_length", "L", "int"), 21 | ("episode_reward", "R", "float"), 22 | ("imitation_reward", "R_i", "float"), 23 | ("total_time", "T", "time"), 24 | ] 25 | SSL_TRAIN_FORMAT = [ 26 | ("step", "S", "int"), 27 | ("loss", "L", "float"), 28 | ("total_time", "T", "time"), 29 | ] 30 | SSL_EVAL_FORMAT = [ 31 | ("epoch", "E", "int"), 32 | ("step", "S", "int"), 33 | ("loss", "E", "float"), 34 | ("total_time", "T", "time"), 35 | ] 36 | 37 | 38 | class AverageMeter(object): 39 | def __init__(self): 40 | self._sum = 0 41 | self._count = 0 42 | 43 | def update(self, value, n=1): 44 | self._sum += value 45 | self._count += n 46 | 47 | def value(self): 48 | return self._sum / max(1, self._count) 49 | 50 | 51 | class MetersGroup(object): 52 | def __init__(self, csv_file_name, formating): 53 | self._csv_file_name = csv_file_name 54 | self._formating = formating 55 | self._meters = defaultdict(AverageMeter) 56 | self._csv_file = None 57 | self._csv_writer = None 58 | 59 | def log(self, key, value, n=1): 60 | self._meters[key].update(value, n) 61 | 62 | def _prime_meters(self): 63 | data = dict() 64 | for key, meter in self._meters.items(): 65 | if key.startswith("train_vq"): 66 | key = key[len("train_vq") + 1 :] 67 | elif key.startswith("train"): 68 | key = key[len("train") + 1 :] 69 | else: 70 | key = key[len("eval") + 1 :] 71 | key = key.replace("/", "_") 72 | data[key] = meter.value() 73 | return data 74 | 75 | def _remove_old_entries(self, data): 76 | rows = [] 77 | with self._csv_file_name.open("r") as f: 78 | reader = csv.DictReader(f) 79 | for row in reader: 80 | if float(row["step"]) >= data["step"]: 81 | break 82 | rows.append(row) 83 | with self._csv_file_name.open("w") as f: 84 | writer = csv.DictWriter(f, fieldnames=sorted(data.keys()), restval=0.0) 85 | writer.writeheader() 86 | for 
row in rows: 87 | writer.writerow(row) 88 | 89 | def _dump_to_csv(self, data): 90 | if self._csv_writer is None: 91 | should_write_header = True 92 | if self._csv_file_name.exists(): 93 | self._remove_old_entries(data) 94 | should_write_header = False 95 | 96 | self._csv_file = self._csv_file_name.open("a") 97 | self._csv_writer = csv.DictWriter( 98 | self._csv_file, fieldnames=sorted(data.keys()), restval=0.0 99 | ) 100 | if should_write_header: 101 | self._csv_writer.writeheader() 102 | 103 | self._csv_writer.writerow(data) 104 | self._csv_file.flush() 105 | 106 | def _format(self, key, value, ty): 107 | if ty == "int": 108 | value = int(value) 109 | return f"{key}: {value}" 110 | elif ty == "float": 111 | return f"{key}: {value:.04f}" 112 | elif ty == "time": 113 | value = str(datetime.timedelta(seconds=int(value))) 114 | return f"{key}: {value}" 115 | else: 116 | raise f"invalid format type: {ty}" 117 | 118 | def _dump_to_console(self, data, prefix): 119 | prefix = colored( 120 | prefix, "yellow" if prefix in ["train", "train_vq"] else "green" 121 | ) 122 | pieces = [f"| {prefix: <14}"] 123 | for key, disp_key, ty in self._formating: 124 | value = data.get(key, 0) 125 | pieces.append(self._format(disp_key, value, ty)) 126 | print(" | ".join(pieces)) 127 | 128 | def dump(self, step, prefix): 129 | if len(self._meters) == 0: 130 | return 131 | data = self._prime_meters() 132 | data["frame"] = step 133 | self._dump_to_csv(data) 134 | self._dump_to_console(data, prefix) 135 | self._meters.clear() 136 | 137 | 138 | class Logger(object): 139 | def __init__(self, log_dir, use_tb, mode="bc"): 140 | """ 141 | mode: bc, ssl 142 | """ 143 | self._log_dir = log_dir 144 | if mode == "bc": 145 | self._train_mg = MetersGroup( 146 | log_dir / "train.csv", formating=BC_TRAIN_FORMAT 147 | ) 148 | self._eval_mg = MetersGroup(log_dir / "eval.csv", formating=BC_EVAL_FORMAT) 149 | elif mode == "ssl": 150 | self._train_mg = MetersGroup( 151 | log_dir / "train.csv", formating=SSL_TRAIN_FORMAT 152 | ) 153 | self._train_vq_mg = MetersGroup( 154 | log_dir / "train_vq.csv", formating=SSL_TRAIN_FORMAT 155 | ) 156 | self._eval_mg = MetersGroup(log_dir / "eval.csv", formating=SSL_EVAL_FORMAT) 157 | if use_tb: 158 | self._sw = SummaryWriter(str(log_dir / "tb")) 159 | else: 160 | self._sw = None 161 | 162 | def _try_sw_log(self, key, value, step): 163 | if self._sw is not None: 164 | self._sw.add_scalar(key, value, step) 165 | 166 | def log(self, key, value, step): 167 | assert key.startswith("train") or key.startswith("eval") 168 | if type(value) == torch.Tensor: 169 | value = value.item() 170 | self._try_sw_log(key, value, step) 171 | # mg = self._train_mg if key.startswith('train') else self._eval_mg 172 | if key.startswith("train_vq"): 173 | mg = self._train_vq_mg 174 | else: 175 | mg = self._train_mg if key.startswith("train") else self._eval_mg 176 | mg.log(key, value) 177 | 178 | def log_metrics(self, metrics, step, ty): 179 | for key, value in metrics.items(): 180 | self.log(f"{ty}/{key}", value, step) 181 | 182 | def dump(self, step, ty=None): 183 | if ty is None or ty == "eval": 184 | self._eval_mg.dump(step, "eval") 185 | if ty is None or ty == "train": 186 | self._train_mg.dump(step, "train") 187 | if ty is None or ty == "train_vq": 188 | self._train_vq_mg.dump(step, "train_vq") 189 | 190 | def log_and_dump_ctx(self, step, ty): 191 | return LogAndDumpCtx(self, step, ty) 192 | 193 | 194 | class LogAndDumpCtx: 195 | def __init__(self, logger, step, ty): 196 | self._logger = logger 197 | self._step = step 198 
| self._ty = ty 199 | 200 | def __enter__(self): 201 | return self 202 | 203 | def __call__(self, key, value): 204 | self._logger.log(f"{self._ty}/{key}", value, self._step) 205 | 206 | def __exit__(self, *args): 207 | self._logger.dump(self._step, self._ty) 208 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import warnings 4 | import os 5 | 6 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 7 | os.environ["MUJOCO_GL"] = "egl" # "osmesa" for calvin 8 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 9 | from pathlib import Path 10 | 11 | import hydra 12 | import torch 13 | import cv2 14 | import numpy as np 15 | 16 | import utils 17 | from logger import Logger 18 | from replay_buffer import make_expert_replay_loader 19 | from video import VideoRecorder 20 | 21 | warnings.filterwarnings("ignore", category=DeprecationWarning) 22 | torch.backends.cudnn.benchmark = True 23 | 24 | 25 | def make_agent(obs_spec, action_spec, cfg): 26 | obs_shape = {} 27 | for key in cfg.suite.pixel_keys: 28 | obs_shape[key] = obs_spec[key].shape 29 | if cfg.use_aux_inputs: 30 | for key in cfg.suite.aux_keys: 31 | obs_shape[key] = obs_spec[key].shape 32 | obs_shape[cfg.suite.feature_key] = obs_spec[cfg.suite.feature_key].shape 33 | cfg.agent.obs_shape = obs_shape 34 | cfg.agent.action_shape = action_spec.shape 35 | return hydra.utils.instantiate(cfg.agent) 36 | 37 | 38 | class WorkspaceIL: 39 | def __init__(self, cfg): 40 | self.work_dir = Path.cwd() 41 | print(f"workspace: {self.work_dir}") 42 | 43 | self.cfg = cfg 44 | utils.set_seed_everywhere(cfg.seed) 45 | self.device = torch.device(cfg.device) 46 | 47 | # load data 48 | dataset_iterable = hydra.utils.call(self.cfg.expert_dataset) 49 | self.expert_replay_loader = make_expert_replay_loader( 50 | dataset_iterable, self.cfg.batch_size 51 | ) 52 | self.expert_replay_iter = iter(self.expert_replay_loader) 53 | 54 | # create logger 55 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb) 56 | # create envs 57 | self.cfg.suite.task_make_fn.max_episode_len = 400 # ( 58 | # self.expert_replay_loader.dataset._max_episode_len 59 | # ) 60 | self.cfg.suite.task_make_fn.max_state_dim = ( 61 | self.expert_replay_loader.dataset._max_state_dim 62 | ) 63 | if self.cfg.suite.name == "dmc": 64 | self.cfg.suite.task_make_fn.max_action_dim = ( 65 | self.expert_replay_loader.dataset._max_action_dim 66 | ) 67 | self.env, self.task_descriptions = hydra.utils.call(self.cfg.suite.task_make_fn) 68 | 69 | # create agent 70 | self.agent = make_agent( 71 | self.env[0].observation_spec(), self.env[0].action_spec(), cfg 72 | ) 73 | 74 | if self.cfg.sequential_train: 75 | self.scene_idx = 0 76 | self.scene_names = self.cfg.suite.task.scenes 77 | self.task_names = { 78 | task_name: scene[task_name] 79 | for scene in self.cfg.suite.task.tasks 80 | for task_name in scene 81 | } 82 | self.envs_till_idx = len(self.task_names[self.scene_names[0]]) 83 | self.expert_replay_loader.dataset.envs_till_idx = self.envs_till_idx 84 | self.steps_till_next_scene = ( 85 | self.envs_till_idx * self.cfg.suite.num_train_steps_per_task 86 | ) 87 | self.expert_replay_iter = iter(self.expert_replay_loader) 88 | else: 89 | self.envs_till_idx = len(self.env) 90 | self.expert_replay_loader.dataset.envs_till_idx = self.envs_till_idx 91 | self.expert_replay_iter = iter(self.expert_replay_loader) 92 | 93 | # # Discretizer for BeT 94 | # if 
self.cfg.agent.policy_head in ["bet", "vqbet"]: 95 | # self.agent.discretize(self.expert_replay_loader.dataset.actions) 96 | 97 | self.timer = utils.Timer() 98 | self._global_step = 0 99 | self._global_episode = 0 100 | 101 | self.video_recorder = VideoRecorder( 102 | self.work_dir if self.cfg.save_video else None 103 | ) 104 | 105 | @property 106 | def global_step(self): 107 | return self._global_step 108 | 109 | @property 110 | def global_episode(self): 111 | return self._global_episode 112 | 113 | @property 114 | def global_frame(self): 115 | return self.global_step * self.cfg.suite.action_repeat 116 | 117 | def eval(self): 118 | self.agent.train(False) 119 | episode_rewards = [] 120 | successes = [] 121 | for env_idx in range(self.envs_till_idx): 122 | print(f"evaluating env {env_idx}") 123 | episode, total_reward = 0, 0 124 | eval_until_episode = utils.Until(self.cfg.suite.num_eval_episodes) 125 | success = [] 126 | 127 | while eval_until_episode(episode): 128 | time_step = self.env[env_idx].reset() 129 | self.agent.buffer_reset() 130 | step = 0 131 | 132 | # prompt 133 | # if self.cfg.prompt != None and self.cfg.prompt != "intermediate_goal": 134 | # prompt = self.expert_replay_loader.dataset.sample_test( 135 | # 0 136 | # ) # env_idx) 137 | # else: 138 | # prompt = None 139 | 140 | if episode == 0: 141 | self.video_recorder.init(self.env[env_idx], enabled=True) 142 | 143 | # plot obs with cv2 144 | while not time_step.last(): 145 | # if self.cfg.prompt == "intermediate_goal": 146 | # prompt = self.expert_replay_loader.dataset.sample_test( 147 | # env_idx, step 148 | # ) 149 | with torch.no_grad(), utils.eval_mode(self.agent): 150 | action = self.agent.act( 151 | time_step.observation, 152 | None, 153 | # self.expert_replay_loader.dataset.stats, 154 | self.stats, 155 | step, 156 | self.global_step, 157 | eval_mode=True, 158 | ) 159 | print(f"step: {step}") 160 | 161 | time_step = self.env[env_idx].step(action) 162 | self.video_recorder.record(self.env[env_idx]) 163 | total_reward += time_step.reward 164 | step += 1 165 | 166 | if self.cfg.suite.name == "calvin" and time_step.reward == 1: 167 | self.agent.buffer_reset() 168 | 169 | episode += 1 170 | success.append(time_step.observation["goal_achieved"]) 171 | self.video_recorder.save(f"{self.global_frame}_env{env_idx}.mp4") 172 | episode_rewards.append(total_reward / episode) 173 | successes.append(np.mean(success)) 174 | 175 | for _ in range(len(self.env) - self.envs_till_idx): 176 | episode_rewards.append(0) 177 | successes.append(0) 178 | 179 | with self.logger.log_and_dump_ctx(self.global_frame, ty="eval") as log: 180 | for env_idx, reward in enumerate(episode_rewards): 181 | log(f"episode_reward_env{env_idx}", reward) 182 | log(f"success_env{env_idx}", successes[env_idx]) 183 | log("episode_reward", np.mean(episode_rewards[: self.envs_till_idx])) 184 | log("success", np.mean(successes)) 185 | log("episode_length", step * self.cfg.suite.action_repeat / episode) 186 | log("episode", self.global_episode) 187 | log("step", self.global_step) 188 | 189 | self.agent.train(True) 190 | 191 | def save_snapshot(self): 192 | snapshot = self.work_dir / "snapshot.pt" 193 | self.agent.clear_buffers() 194 | keys_to_save = ["timer", "_global_step", "_global_episode"] 195 | payload = {k: self.__dict__[k] for k in keys_to_save} 196 | payload.update(self.agent.save_snapshot()) 197 | with snapshot.open("wb") as f: 198 | torch.save(payload, f) 199 | 200 | self.agent.buffer_reset() 201 | 202 | def load_snapshot(self, snapshots): 203 | # bc 204 | with 
snapshots["bc"].open("rb") as f: 205 | payload = torch.load(f) 206 | agent_payload = {} 207 | for k, v in payload.items(): 208 | if k not in self.__dict__: 209 | agent_payload[k] = v 210 | if "vqvae" in snapshots: 211 | with snapshots["vqvae"].open("rb") as f: 212 | payload = torch.load(f) 213 | agent_payload["vqvae"] = payload 214 | self.agent.load_snapshot(agent_payload, eval=True) 215 | # self.agent.load_snapshot_eval(agent_payload) 216 | 217 | self.stats = payload["stats"] 218 | 219 | 220 | @hydra.main(config_path="cfgs", config_name="config_eval", version_base=None) 221 | def main(cfg): 222 | from eval import WorkspaceIL as W 223 | 224 | root_dir = Path.cwd() 225 | workspace = W(cfg) 226 | 227 | # Load weights 228 | snapshots = {} 229 | # bc 230 | bc_snapshot = Path(cfg.bc_weight) 231 | if not bc_snapshot.exists(): 232 | raise FileNotFoundError(f"bc weight not found: {bc_snapshot}") 233 | print(f"loading bc weight: {bc_snapshot}") 234 | snapshots["bc"] = bc_snapshot 235 | # vqvae_snapshot = Path(cfg.vqvae_weight) 236 | # if vqvae_snapshot.exists(): 237 | # print(f"loading vqvae weight: {vqvae_snapshot}") 238 | # snapshots["vqvae"] = vqvae_snapshot 239 | workspace.load_snapshot(snapshots) 240 | 241 | workspace.eval() 242 | 243 | 244 | if __name__ == "__main__": 245 | main() 246 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from omegaconf import OmegaConf 10 | from torch import distributions as pyd 11 | from torch.distributions.utils import _standard_normal 12 | 13 | # Set data directory here 14 | DATA_DIR = "" 15 | 16 | 17 | class eval_mode: 18 | def __init__(self, *models): 19 | self.models = models 20 | 21 | def __enter__(self): 22 | self.prev_states = [] 23 | for model in self.models: 24 | self.prev_states.append(model.training) 25 | model.train(False) 26 | 27 | def __exit__(self, *args): 28 | for model, state in zip(self.models, self.prev_states): 29 | model.train(state) 30 | return False 31 | 32 | 33 | def set_seed_everywhere(seed): 34 | torch.manual_seed(seed) 35 | if torch.cuda.is_available(): 36 | torch.cuda.manual_seed_all(seed) 37 | np.random.seed(seed) 38 | random.seed(seed) 39 | 40 | 41 | def soft_update_params(net, target_net, tau): 42 | for param, target_param in zip(net.parameters(), target_net.parameters()): 43 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 44 | 45 | 46 | def to_torch(xs, device): 47 | for key, value in xs.items(): 48 | xs[key] = torch.as_tensor(value, device=device) 49 | return xs 50 | 51 | 52 | def weight_init(m): 53 | if isinstance(m, nn.Linear): 54 | nn.init.orthogonal_(m.weight.data) 55 | if hasattr(m.bias, "data"): 56 | m.bias.data.fill_(0.0) 57 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 58 | gain = nn.init.calculate_gain("relu") 59 | nn.init.orthogonal_(m.weight.data, gain) 60 | if hasattr(m.bias, "data"): 61 | m.bias.data.fill_(0.0) 62 | 63 | 64 | class Until: 65 | def __init__(self, until, action_repeat=1): 66 | self._until = until 67 | self._action_repeat = action_repeat 68 | 69 | def __call__(self, step): 70 | if self._until is None: 71 | return True 72 | until = self._until // self._action_repeat 73 | return step < until 74 | 75 | 76 | class Every: 77 | def __init__(self, every, action_repeat=1): 78 | 
self._every = every 79 | self._action_repeat = action_repeat 80 | 81 | def __call__(self, step): 82 | if self._every is None: 83 | return False 84 | every = self._every // self._action_repeat 85 | if step % every == 0: 86 | return True 87 | return False 88 | 89 | 90 | class Timer: 91 | def __init__(self): 92 | self._start_time = time.time() 93 | self._last_time = time.time() 94 | # Keep track of evaluation time so that total time only includes train time 95 | self._eval_start_time = 0 96 | self._eval_time = 0 97 | self._eval_flag = False 98 | 99 | def reset(self): 100 | elapsed_time = time.time() - self._last_time 101 | self._last_time = time.time() 102 | total_time = time.time() - self._start_time - self._eval_time 103 | return elapsed_time, total_time 104 | 105 | def eval(self): 106 | if not self._eval_flag: 107 | self._eval_flag = True 108 | self._eval_start_time = time.time() 109 | else: 110 | self._eval_time += time.time() - self._eval_start_time 111 | self._eval_flag = False 112 | self._eval_start_time = 0 113 | 114 | def total_time(self): 115 | return time.time() - self._start_time - self._eval_time 116 | 117 | 118 | class TruncatedNormal(pyd.Normal): 119 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 120 | super().__init__(loc, scale, validate_args=False) 121 | self.low = low 122 | self.high = high 123 | self.eps = eps 124 | 125 | def _clamp(self, x): 126 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 127 | x = x - x.detach() + clamped_x.detach() 128 | return x 129 | 130 | def sample(self, clip=None, sample_shape=torch.Size()): 131 | shape = self._extended_shape(sample_shape) 132 | eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) 133 | eps *= self.scale 134 | if clip is not None: 135 | eps = torch.clamp(eps, -clip, clip) 136 | x = self.loc + eps 137 | return self._clamp(x) 138 | 139 | 140 | def schedule(schdl, step): 141 | try: 142 | return float(schdl) 143 | except ValueError: 144 | match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) 145 | if match: 146 | init, final, duration = [float(g) for g in match.groups()] 147 | mix = np.clip(step / duration, 0.0, 1.0) 148 | return (1.0 - mix) * init + mix * final 149 | match = re.match(r"step_linear\((.+),(.+),(.+),(.+),(.+)\)", schdl) 150 | if match: 151 | init, final1, duration1, final2, duration2 = [ 152 | float(g) for g in match.groups() 153 | ] 154 | if step <= duration1: 155 | mix = np.clip(step / duration1, 0.0, 1.0) 156 | return (1.0 - mix) * init + mix * final1 157 | else: 158 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 159 | return (1.0 - mix) * final1 + mix * final2 160 | raise NotImplementedError(schdl) 161 | 162 | 163 | class RandomShiftsAug(nn.Module): 164 | def __init__(self, pad): 165 | super().__init__() 166 | self.pad = pad 167 | 168 | def forward(self, x): 169 | n, c, h, w = x.size() 170 | assert h == w 171 | padding = tuple([self.pad] * 4) 172 | x = F.pad(x, padding, "replicate") 173 | eps = 1.0 / (h + 2 * self.pad) 174 | arange = torch.linspace( 175 | -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype 176 | )[:h] 177 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 178 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 179 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 180 | 181 | shift = torch.randint( 182 | 0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype 183 | ) 184 | shift *= 2.0 / (h + 2 * self.pad) 185 | 186 | grid = base_grid + shift 187 | return 
F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) 188 | 189 | 190 | class TorchRunningMeanStd: 191 | def __init__(self, epsilon=1e-4, shape=(), device=None): 192 | self.mean = torch.zeros(shape, device=device) 193 | self.var = torch.ones(shape, device=device) 194 | self.count = epsilon 195 | 196 | def update(self, x): 197 | with torch.no_grad(): 198 | batch_mean = torch.mean(x, axis=0) 199 | batch_var = torch.var(x, axis=0) 200 | batch_count = x.shape[0] 201 | self.update_from_moments(batch_mean, batch_var, batch_count) 202 | 203 | def update_from_moments(self, batch_mean, batch_var, batch_count): 204 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 205 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count 206 | ) 207 | 208 | @property 209 | def std(self): 210 | return torch.sqrt(self.var) 211 | 212 | 213 | def update_mean_var_count_from_moments( 214 | mean, var, count, batch_mean, batch_var, batch_count 215 | ): 216 | delta = batch_mean - mean 217 | tot_count = count + batch_count 218 | 219 | new_mean = mean + delta * batch_count / tot_count 220 | m_a = var * count 221 | m_b = batch_var * batch_count 222 | M2 = m_a + m_b + torch.pow(delta, 2) * count * batch_count / tot_count 223 | new_var = M2 / tot_count 224 | new_count = tot_count 225 | 226 | return new_mean, new_var, new_count 227 | 228 | 229 | def batch_norm_to_group_norm(layer): 230 | """Iterates over a whole model (or layer of a model) and replaces every batch norm 2D with a group norm 231 | 232 | Args: 233 | layer: model or one layer of a model like resnet34.layer1 or Sequential(), ... 234 | """ 235 | 236 | # num_channels: num_groups 237 | GROUP_NORM_LOOKUP = { 238 | 16: 2, # -> channels per group: 8 239 | 32: 4, # -> channels per group: 8 240 | 64: 8, # -> channels per group: 8 241 | 128: 8, # -> channels per group: 16 242 | 256: 16, # -> channels per group: 16 243 | 512: 32, # -> channels per group: 16 244 | 1024: 32, # -> channels per group: 32 245 | 2048: 32, # -> channels per group: 64 246 | } 247 | 248 | for name, module in layer.named_modules(): 249 | if name: 250 | try: 251 | # name might be something like model.layer1.sequential.0.conv1; getattr on the full dotted name won't work, so the except branch below handles that case 252 | sub_layer = getattr(layer, name) 253 | if isinstance(sub_layer, torch.nn.BatchNorm2d): 254 | num_channels = sub_layer.num_features 255 | # first level of current layer or model contains a batch norm --> replacing. 256 | layer._modules[name] = torch.nn.GroupNorm( 257 | GROUP_NORM_LOOKUP[num_channels], num_channels 258 | ) 259 | except AttributeError: 260 | # go deeper: set name to layer1, getattr will return layer1 --> call this func again 261 | name = name.split(".")[0] 262 | sub_layer = getattr(layer, name) 263 | sub_layer = batch_norm_to_group_norm(sub_layer) 264 | layer.__setattr__(name=name, value=sub_layer) 265 | return layer 266 | -------------------------------------------------------------------------------- /agent/networks/rgb_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains all neural modules related to encoding the spatial 3 | information of obs_t, i.e., the abstracted knowledge of the current visual 4 | input conditioned on the language.
5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | 11 | import utils 12 | 13 | 14 | ############################################################################### 15 | # 16 | # Modules related to encoding visual information (can conditioned on language) 17 | # 18 | ############################################################################### 19 | 20 | 21 | class BaseEncoder(nn.Module): 22 | def __init__(self, obs_shape): 23 | super().__init__() 24 | 25 | assert len(obs_shape) == 3 26 | self.repr_dim = 512 27 | 28 | self.convnet = nn.Sequential( 29 | nn.Conv2d(obs_shape[0], 32, 3, stride=2), 30 | nn.ReLU(), 31 | nn.Conv2d(32, 32, 3, stride=1), 32 | nn.ReLU(), 33 | nn.Conv2d(32, 32, 3, stride=1), 34 | nn.ReLU(), 35 | nn.Conv2d(32, 32, 3, stride=1), 36 | nn.ReLU(), 37 | ) 38 | 39 | if obs_shape[1] == 84: 40 | dim = 39200 41 | elif obs_shape[1] == 128: 42 | dim = 103968 43 | elif obs_shape[1] == 224: 44 | dim = 352800 45 | self.trunk = nn.Sequential(nn.Linear(dim, 512), nn.LayerNorm(512), nn.Tanh()) 46 | 47 | self.apply(utils.weight_init) 48 | 49 | def forward(self, obs): 50 | obs = obs - 0.5 51 | h = self.convnet(obs) 52 | # h = h.view(h.shape[0], -1) 53 | h = h.reshape(h.shape[0], -1) 54 | h = self.trunk(h) 55 | return h 56 | 57 | 58 | class PatchEncoder(nn.Module): 59 | """ 60 | A patch encoder that does a linear projection of patches in a RGB image. 61 | """ 62 | 63 | def __init__( 64 | self, input_shape, patch_size=[16, 16], embed_size=64, no_patch_embed_bias=False 65 | ): 66 | super().__init__() 67 | C, H, W = input_shape 68 | num_patches = (H // patch_size[0] // 2) * (W // patch_size[1] // 2) 69 | self.img_size = (H, W) 70 | self.patch_size = patch_size 71 | self.num_patches = num_patches 72 | self.h, self.w = H // patch_size[0] // 2, W // patch_size[1] // 2 73 | 74 | self.conv = nn.Sequential( 75 | nn.Conv2d( 76 | C, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 77 | ), 78 | nn.BatchNorm2d( 79 | 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True 80 | ), 81 | nn.ReLU(inplace=True), 82 | ) 83 | self.proj = nn.Conv2d( 84 | 64, 85 | embed_size, 86 | kernel_size=patch_size, 87 | stride=patch_size, 88 | bias=False if no_patch_embed_bias else True, 89 | ) 90 | self.bn = nn.BatchNorm2d(embed_size) 91 | 92 | def forward(self, x): 93 | B, C, H, W = x.shape 94 | x = self.conv(x) 95 | x = self.proj(x) 96 | x = self.bn(x) 97 | return x 98 | 99 | 100 | class SpatialSoftmax(nn.Module): 101 | """ 102 | The spatial softmax layer (https://rll.berkeley.edu/dsae/dsae.pdf) 103 | """ 104 | 105 | def __init__(self, in_c, in_h, in_w, num_kp=None): 106 | super().__init__() 107 | self._spatial_conv = nn.Conv2d(in_c, num_kp, kernel_size=1) 108 | 109 | pos_x, pos_y = torch.meshgrid( 110 | torch.linspace(-1, 1, in_w).float(), 111 | torch.linspace(-1, 1, in_h).float(), 112 | ) 113 | 114 | pos_x = pos_x.reshape(1, in_w * in_h) 115 | pos_y = pos_y.reshape(1, in_w * in_h) 116 | self.register_buffer("pos_x", pos_x) 117 | self.register_buffer("pos_y", pos_y) 118 | 119 | if num_kp is None: 120 | self._num_kp = in_c 121 | else: 122 | self._num_kp = num_kp 123 | 124 | self._in_c = in_c 125 | self._in_w = in_w 126 | self._in_h = in_h 127 | 128 | def forward(self, x): 129 | assert x.shape[1] == self._in_c 130 | assert x.shape[2] == self._in_h 131 | assert x.shape[3] == self._in_w 132 | 133 | h = x 134 | if self._num_kp != self._in_c: 135 | h = self._spatial_conv(h) 136 | h = h.contiguous().view(-1, self._in_h * self._in_w) 
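# At this point h has shape (B * num_kp, H * W): one flattened activation map per keypoint.
# The softmax below turns each map into a spatial probability distribution, and each keypoint
# is the expected image coordinate under that distribution, i.e. keypoint_x = sum_i p_i * pos_x_i
# (likewise for y). Rough usage sketch (shapes are illustrative, not taken from this repo):
#   ss = SpatialSoftmax(in_c=512, in_h=7, in_w=7, num_kp=32)
#   kp = ss(torch.randn(8, 512, 7, 7))  # -> (8, 64): 32 x-coordinates followed by 32 y-coordinates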
137 | 138 | attention = F.softmax(h, dim=-1) 139 | keypoint_x = ( 140 | (self.pos_x * attention).sum(1, keepdims=True).view(-1, self._num_kp) 141 | ) 142 | keypoint_y = ( 143 | (self.pos_y * attention).sum(1, keepdims=True).view(-1, self._num_kp) 144 | ) 145 | keypoints = torch.cat([keypoint_x, keypoint_y], dim=1) 146 | return keypoints 147 | 148 | 149 | class SpatialProjection(nn.Module): 150 | def __init__(self, input_shape, out_dim): 151 | super().__init__() 152 | 153 | assert ( 154 | len(input_shape) == 3 155 | ), "[error] spatial projection: input shape is not a 3-tuple" 156 | in_c, in_h, in_w = input_shape 157 | num_kp = out_dim // 2 158 | self.out_dim = out_dim 159 | self.spatial_softmax = SpatialSoftmax(in_c, in_h, in_w, num_kp=num_kp) 160 | self.projection = nn.Linear(num_kp * 2, out_dim) 161 | 162 | def forward(self, x): 163 | out = self.spatial_softmax(x) 164 | out = self.projection(out) 165 | return out 166 | 167 | def output_shape(self, input_shape): 168 | return input_shape[:-3] + (self.out_dim,) 169 | 170 | 171 | class ResnetEncoder(nn.Module): 172 | """ 173 | A ResNet-18-based encoder for mapping an image to a latent vector 174 | 175 | Encode (f) an image into a latent vector. 176 | 177 | y = f(x), where 178 | x: (B, C, H, W) 179 | y: (B, H_out) 180 | 181 | Args: 182 | input_shape: (C, H, W), the shape of the image 183 | output_size: H_out, the latent vector size 184 | pretrained: whether to use a pretrained ResNet 185 | freeze: whether to freeze the pretrained ResNet 186 | remove_layer_num: number of top layers to remove 187 | no_stride: do not use striding 188 | """ 189 | 190 | def __init__( 191 | self, 192 | input_shape, 193 | output_size, 194 | pretrained=False, 195 | freeze=False, 196 | remove_layer_num=2, 197 | no_stride=False, 198 | cond_dim=768, 199 | cond_fusion="film", 200 | ): 201 | super().__init__() 202 | 203 | ### 1. encode input (images) using convolutional layers 204 | assert remove_layer_num <= 5, "[error] please only remove <=5 layers" 205 | layers = list(torchvision.models.resnet18(pretrained=pretrained).children())[ 206 | :-remove_layer_num 207 | ] 208 | self.remove_layer_num = remove_layer_num 209 | 210 | assert ( 211 | len(input_shape) == 3 212 | ), "[error] input shape of resnet should be (C, H, W)" 213 | 214 | in_channels = input_shape[0] 215 | if in_channels != 3: # has eye_in_hand, increase channel size 216 | conv0 = nn.Conv2d( 217 | in_channels=in_channels, 218 | out_channels=64, 219 | kernel_size=(7, 7), 220 | stride=(2, 2), 221 | padding=(3, 3), 222 | bias=False, 223 | ) 224 | layers[0] = conv0 225 | 226 | self.no_stride = no_stride 227 | if self.no_stride: 228 | layers[0].stride = (1, 1) 229 | layers[3].stride = 1 230 | 231 | self.resnet18_base = nn.Sequential(*layers[:4]) 232 | self.block_1 = layers[4][0] 233 | self.block_2 = layers[4][1] 234 | self.block_3 = layers[5][0] 235 | self.block_4 = layers[5][1] 236 | 237 | self.cond_fusion = cond_fusion 238 | if cond_fusion == "film": 239 | self.lang_proj1 = nn.Linear(cond_dim, 64 * 2) 240 | self.lang_proj2 = nn.Linear(cond_dim, 64 * 2) 241 | self.lang_proj3 = nn.Linear(cond_dim, 128 * 2) 242 | self.lang_proj4 = nn.Linear(cond_dim, 128 * 2) 243 | 244 | if freeze: 245 | if in_channels != 3: 246 | raise Exception( 247 | "[error] cannot freeze pretrained " 248 | + "resnet with the extra eye_in_hand input" 249 | ) 250 | for module in (self.resnet18_base, self.block_1, self.block_2, self.block_3, self.block_4): # freeze the pretrained ResNet weights 251 | for param in module.parameters(): 252 | param.requires_grad = False 253 | ### 2.
project the encoded input to a latent space 254 | x = torch.zeros(1, *input_shape) 255 | y = self.block_4( 256 | self.block_3(self.block_2(self.block_1(self.resnet18_base(x)))) 257 | ) 258 | output_shape = y.shape # compute the out dim 259 | self.projection_layer = SpatialProjection(output_shape[1:], output_size) 260 | self.output_shape = self.projection_layer(y).shape 261 | 262 | # Replace BatchNorm layers with GroupNorm 263 | self.resnet18_base = utils.batch_norm_to_group_norm(self.resnet18_base) 264 | self.block_1 = utils.batch_norm_to_group_norm(self.block_1) 265 | self.block_2 = utils.batch_norm_to_group_norm(self.block_2) 266 | self.block_3 = utils.batch_norm_to_group_norm(self.block_3) 267 | self.block_4 = utils.batch_norm_to_group_norm(self.block_4) 268 | 269 | # self.normlayer = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 270 | 271 | def forward(self, x, cond=None, return_intermediate=False): 272 | # # preprocess 273 | # preprocess = nn.Sequential(self.normlayer) 274 | # x = preprocess(x) 275 | 276 | h = self.resnet18_base(x) 277 | 278 | h = self.block_1(h) 279 | if cond is not None and self.cond_fusion == "film": # FiLM layer 280 | B, C, H, W = h.shape 281 | beta, gamma = torch.split( 282 | self.lang_proj1(cond).reshape(B, C * 2, 1, 1), [C, C], 1 283 | ) 284 | h = (1 + gamma) * h + beta 285 | 286 | h = self.block_2(h) 287 | if cond is not None and self.cond_fusion == "film": # FiLM layer 288 | B, C, H, W = h.shape 289 | beta, gamma = torch.split( 290 | self.lang_proj2(cond).reshape(B, C * 2, 1, 1), [C, C], 1 291 | ) 292 | h = (1 + gamma) * h + beta 293 | 294 | h = self.block_3(h) 295 | if cond is not None and self.cond_fusion == "film": # FiLM layer 296 | B, C, H, W = h.shape 297 | beta, gamma = torch.split( 298 | self.lang_proj3(cond).reshape(B, C * 2, 1, 1), [C, C], 1 299 | ) 300 | h = (1 + gamma) * h + beta 301 | 302 | h = self.block_4(h) 303 | if cond is not None and self.cond_fusion == "film": # FiLM layer 304 | B, C, H, W = h.shape 305 | beta, gamma = torch.split( 306 | self.lang_proj4(cond).reshape(B, C * 2, 1, 1), [C, C], 1 307 | ) 308 | h = (1 + gamma) * h + beta 309 | 310 | if not return_intermediate: 311 | h = self.projection_layer(h) 312 | return h 313 | 314 | def output_shape(self): 315 | return self.output_shape 316 | -------------------------------------------------------------------------------- /train_bc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import warnings 4 | import os 5 | 6 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 7 | os.environ["MUJOCO_GL"] = "egl" 8 | from pathlib import Path 9 | 10 | import hydra 11 | import torch 12 | import numpy as np 13 | 14 | import utils 15 | from logger import Logger 16 | from replay_buffer import make_expert_replay_loader 17 | from video import VideoRecorder 18 | 19 | warnings.filterwarnings("ignore", category=DeprecationWarning) 20 | torch.backends.cudnn.benchmark = True 21 | 22 | 23 | def make_agent(obs_spec, action_spec, cfg): 24 | obs_shape = {} 25 | for key in cfg.suite.pixel_keys: 26 | obs_shape[key] = obs_spec[key].shape 27 | if cfg.use_aux_inputs: 28 | for key in cfg.suite.aux_keys: 29 | obs_shape[key] = obs_spec[key].shape 30 | obs_shape[cfg.suite.feature_key] = obs_spec[cfg.suite.feature_key].shape 31 | cfg.agent.obs_shape = obs_shape 32 | cfg.agent.action_shape = ( 33 | action_spec.shape 34 | if cfg.suite.action_type == "continuous" 35 | else action_spec.num_values 36 | ) 37 | return 
hydra.utils.instantiate(cfg.agent) 38 | 39 | 40 | class WorkspaceIL: 41 | def __init__(self, cfg): 42 | self.work_dir = Path.cwd() 43 | print(f"workspace: {self.work_dir}") 44 | 45 | self.cfg = cfg 46 | utils.set_seed_everywhere(cfg.seed) 47 | self.device = torch.device(cfg.device) 48 | 49 | # load data 50 | dataset_iterable = hydra.utils.call(self.cfg.expert_dataset) 51 | self.expert_replay_loader = make_expert_replay_loader( 52 | dataset_iterable, self.cfg.batch_size 53 | ) 54 | self.expert_replay_iter = iter(self.expert_replay_loader) 55 | self.stats = self.expert_replay_loader.dataset.stats 56 | 57 | # create logger 58 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb) 59 | # create envs 60 | self.cfg.suite.task_make_fn.max_episode_len = ( 61 | self.expert_replay_loader.dataset._max_episode_len 62 | ) 63 | self.cfg.suite.task_make_fn.max_state_dim = ( 64 | self.expert_replay_loader.dataset._max_state_dim 65 | ) 66 | self.env, self.task_descriptions = hydra.utils.call(self.cfg.suite.task_make_fn) 67 | 68 | # create agent 69 | self.agent = make_agent( 70 | self.env[0].observation_spec(), self.env[0].action_spec(), cfg 71 | ) 72 | 73 | # TODO: Make this compatible with no eval case 74 | self.envs_till_idx = self.expert_replay_loader.dataset.envs_till_idx 75 | print(f"envs_till_idx: {self.expert_replay_loader.dataset.envs_till_idx}") 76 | 77 | # Discretizer for BeT 78 | 79 | self.timer = utils.Timer() 80 | self._global_step = 0 81 | self._global_episode = 0 82 | 83 | self.video_recorder = VideoRecorder( 84 | self.work_dir if self.cfg.save_video else None 85 | ) 86 | 87 | @property 88 | def global_step(self): 89 | return self._global_step 90 | 91 | @property 92 | def global_episode(self): 93 | return self._global_episode 94 | 95 | @property 96 | def global_frame(self): 97 | return self.global_step * self.cfg.suite.action_repeat 98 | 99 | def eval(self): 100 | self.agent.train(False) 101 | episode_rewards = [] 102 | successes = [] 103 | 104 | num_envs = ( 105 | len(self.env) if self.cfg.suite.name == "calvin" else self.envs_till_idx 106 | ) 107 | 108 | for env_idx in range(num_envs): 109 | print(f"evaluating env {env_idx}") 110 | episode, total_reward = 0, 0 111 | eval_until_episode = utils.Until(self.cfg.suite.num_eval_episodes) 112 | success = [] 113 | 114 | while eval_until_episode(episode): 115 | time_step = self.env[env_idx].reset() 116 | self.agent.buffer_reset() 117 | step = 0 118 | 119 | # prompt 120 | if self.cfg.prompt != None and self.cfg.prompt != "intermediate_goal": 121 | prompt = self.expert_replay_loader.dataset.sample_test(env_idx) 122 | else: 123 | prompt = None 124 | 125 | if episode == 0: 126 | self.video_recorder.init(self.env[env_idx], enabled=True) 127 | 128 | # plot obs with cv2 129 | while not time_step.last(): 130 | if self.cfg.prompt == "intermediate_goal": 131 | prompt = self.expert_replay_loader.dataset.sample_test( 132 | env_idx, step 133 | ) 134 | with torch.no_grad(), utils.eval_mode(self.agent): 135 | action = self.agent.act( 136 | time_step.observation, 137 | prompt, 138 | self.stats, 139 | step, 140 | self.global_step, 141 | eval_mode=True, 142 | ) 143 | time_step = self.env[env_idx].step(action) 144 | self.video_recorder.record(self.env[env_idx]) 145 | total_reward += time_step.reward 146 | step += 1 147 | 148 | episode += 1 149 | success.append(time_step.observation["goal_achieved"]) 150 | self.video_recorder.save(f"{self.global_step}_env{env_idx}.mp4") 151 | episode_rewards.append(total_reward / episode) 152 | 
successes.append(np.mean(success)) 153 | 154 | for _ in range(len(self.env) - num_envs): 155 | episode_rewards.append(0) 156 | successes.append(0) 157 | 158 | with self.logger.log_and_dump_ctx(self.global_step, ty="eval") as log: 159 | for env_idx, reward in enumerate(episode_rewards): 160 | log(f"episode_reward_env{env_idx}", reward) 161 | log(f"success_env{env_idx}", successes[env_idx]) 162 | log("episode_reward", np.mean(episode_rewards[:num_envs])) 163 | log("success", np.mean(successes)) 164 | log("episode_length", step * self.cfg.suite.action_repeat / episode) 165 | log("episode", self.global_episode) 166 | log("step", self.global_step) 167 | 168 | self.agent.train(True) 169 | 170 | def train(self): 171 | # predicates 172 | train_until_step = utils.Until( 173 | self.cfg.suite.num_train_steps, 1 # self.cfg.suite.action_repeat 174 | ) 175 | log_every_step = utils.Every( 176 | self.cfg.suite.log_every_steps, 1 # self.cfg.suite.action_repeat 177 | ) 178 | eval_every_step = utils.Every( 179 | self.cfg.suite.eval_every_steps, 1 # self.cfg.suite.action_repeat 180 | ) 181 | save_every_step = utils.Every( 182 | self.cfg.suite.save_every_steps, 1 # self.cfg.suite.action_repeat 183 | ) 184 | 185 | metrics = None 186 | while train_until_step(self.global_step): 187 | # try to evaluate 188 | if ( 189 | self.cfg.eval 190 | and eval_every_step(self.global_step) 191 | and self.global_step > 0 192 | ): 193 | self.logger.log( 194 | "eval_total_time", self.timer.total_time(), self.global_frame 195 | ) 196 | self.eval() 197 | 198 | # update 199 | metrics = self.agent.update(self.expert_replay_iter, self.global_step) 200 | self.logger.log_metrics(metrics, self.global_frame, ty="train") 201 | 202 | # log 203 | if log_every_step(self.global_step): 204 | elapsed_time, total_time = self.timer.reset() 205 | with self.logger.log_and_dump_ctx(self.global_frame, ty="train") as log: 206 | log("total_time", total_time) 207 | log("actor_loss", metrics["actor_loss"]) 208 | log("step", self.global_step) 209 | 210 | # save snapshot 211 | if save_every_step(self.global_step): 212 | self.save_snapshot() 213 | 214 | # Update scene 215 | if self.cfg.sequential_train: 216 | if ( 217 | self.global_step + 1 218 | ) % self.steps_till_next_scene == 0 and self.scene_idx < len( 219 | self.scene_names 220 | ) - 1: 221 | self.scene_idx = (self.scene_idx + 1) % len(self.scene_names) 222 | self.envs_till_idx += len( 223 | self.task_names[self.scene_names[self.scene_idx]] 224 | ) 225 | self.expert_replay_loader.dataset.envs_till_idx = self.envs_till_idx 226 | self.steps_till_next_scene = ( 227 | self.envs_till_idx * self.cfg.suite.num_train_steps_per_task 228 | ) 229 | self.expert_replay_iter = iter(self.expert_replay_loader) 230 | 231 | # self.agent.reinit_optimizers() 232 | 233 | self._global_step += 1 234 | 235 | def save_snapshot(self): 236 | snapshot_dir = self.work_dir / "snapshot" 237 | snapshot_dir.mkdir(exist_ok=True) 238 | snapshot = snapshot_dir / f"{self.global_step}.pt" 239 | self.agent.clear_buffers() 240 | keys_to_save = ["timer", "_global_step", "_global_episode", "stats"] 241 | payload = {k: self.__dict__[k] for k in keys_to_save} 242 | payload.update(self.agent.save_snapshot()) 243 | with snapshot.open("wb") as f: 244 | torch.save(payload, f) 245 | 246 | self.agent.buffer_reset() 247 | 248 | def load_snapshot(self, snapshots, encoder_only=False): 249 | # bc 250 | with snapshots["bc"].open("rb") as f: 251 | payload = torch.load(f) 252 | agent_payload = {} 253 | for k, v in payload.items(): 254 | if k not in 
self.__dict__: 255 | agent_payload[k] = v 256 | if "vqvae" in snapshots: 257 | with snapshots["vqvae"].open("rb") as f: 258 | payload = torch.load(f) 259 | agent_payload["vqvae"] = payload 260 | self.agent.load_snapshot(agent_payload, encoder_only=encoder_only, eval=False) 261 | # self.agent.load_snapshot_eval(agent_payload) 262 | 263 | 264 | @hydra.main(config_path="cfgs", config_name="config", version_base=None) 265 | def main(cfg): 266 | from train_bc import WorkspaceIL as W 267 | 268 | workspace = W(cfg) 269 | 270 | # Load weights 271 | if cfg.load_bc: 272 | # BC weight 273 | snapshots = {} 274 | bc_snapshot = Path(cfg.bc_weight) 275 | if not bc_snapshot.exists(): 276 | raise FileNotFoundError(f"bc weight not found: {bc_snapshot}") 277 | print(f"loading bc weight: {bc_snapshot}") 278 | snapshots["bc"] = bc_snapshot 279 | # vqvae_snapshot = Path(cfg.vqvae_weight) 280 | # if vqvae_snapshot.exists(): 281 | # print(f"loading vqvae weight: {vqvae_snapshot}") 282 | # snapshots["vqvae"] = vqvae_snapshot 283 | workspace.load_snapshot(snapshots) 284 | 285 | workspace.train() 286 | # workspace.eval() 287 | 288 | 289 | if __name__ == "__main__": 290 | main() 291 | -------------------------------------------------------------------------------- /envs/xarm-env/xarm_env/envs/robot_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import cv2 4 | import numpy as np 5 | 6 | # import pybullet 7 | # import pybullet_data 8 | import pickle 9 | from scipy.spatial.transform import Rotation as R 10 | 11 | from openteach.utils.network import create_request_socket, ZMQCameraSubscriber 12 | from xarm_env.envs.constants import * 13 | 14 | 15 | def get_quaternion_orientation(cartesian): 16 | """ 17 | Get quaternion orientation from axis angle representation 18 | """ 19 | pos = cartesian[:3] 20 | ori = cartesian[3:] 21 | r = R.from_rotvec(ori) 22 | quat = r.as_quat() 23 | return np.concatenate([pos, quat], axis=-1) 24 | 25 | 26 | class RobotEnv(gym.Env): 27 | def __init__( 28 | self, 29 | height=224, 30 | width=224, 31 | use_robot=True, # True when robot used 32 | use_egocentric=False, # True when egocentric camera used 33 | use_fisheye=True, 34 | sensor_type="reskin", 35 | subtract_sensor_baseline=False, 36 | ): 37 | super(RobotEnv, self).__init__() 38 | self.height = height 39 | self.width = width 40 | 41 | self.use_robot = use_robot 42 | self.use_fish_eye = use_fisheye 43 | self.use_egocentric = use_egocentric 44 | self.sensor_type = sensor_type 45 | self.digit_keys = ["digit80", "digit81"] 46 | 47 | self.subtract_sensor_baseline = subtract_sensor_baseline 48 | self.sensor_prev_state = None 49 | self.sensor_baseline = None 50 | 51 | self.feature_dim = 8 # 10 # 7 52 | self.proprio_dim = 8 53 | 54 | self.n_sensors = 2 55 | self.sensor_dim = 15 56 | self.action_dim = 7 57 | 58 | # Robot limits 59 | # self.cartesian_delta_limits = np.array([-10, 10]) 60 | 61 | self.n_channels = 3 62 | self.reward = 0 63 | 64 | self.observation_space = spaces.Box( 65 | low=0, high=255, shape=(height, width, self.n_channels), dtype=np.uint8 66 | ) 67 | self.action_space = spaces.Box( 68 | low=0.0, high=1.0, shape=(self.action_dim,), dtype=np.float32 69 | ) 70 | 71 | if self.use_robot: 72 | # camera subscribers 73 | self.image_subscribers = {} 74 | for cam_idx in list(CAM_SERIAL_NUMS.keys()): 75 | port = CAMERA_PORT_OFFSET + cam_idx 76 | self.image_subscribers[cam_idx] = ZMQCameraSubscriber( 77 | host=HOST_ADDRESS, 78 | port=port, 79 | 
topic_type="RGB", 80 | ) 81 | 82 | # for fish_eye_cam_idx in range(len(FISH_EYE_CAM_SERIAL_NUMS)): 83 | if use_fisheye: 84 | for fish_eye_cam_idx in list(FISH_EYE_CAM_SERIAL_NUMS.keys()): 85 | port = FISH_EYE_CAMERA_PORT_OFFSET + fish_eye_cam_idx 86 | self.image_subscribers[fish_eye_cam_idx] = ZMQCameraSubscriber( 87 | host=HOST_ADDRESS, 88 | port=port, 89 | topic_type="RGB", 90 | ) 91 | 92 | # action request port 93 | self.action_request_socket = create_request_socket( 94 | HOST_ADDRESS, DEPLOYMENT_PORT 95 | ) 96 | 97 | def step(self, action): 98 | print("current step's action is: ", action) 99 | action = np.array(action) 100 | 101 | action_dict = { 102 | "xarm": { 103 | "cartesian": action[:-1], 104 | "gripper": action[-1:], 105 | } 106 | } 107 | 108 | # send action 109 | self.action_request_socket.send(pickle.dumps(action_dict, protocol=-1)) 110 | ret = self.action_request_socket.recv() 111 | ret = pickle.loads(ret) 112 | if ret == "Command failed!": 113 | print("Command failed!") 114 | # return None, 0, True, None 115 | self.action_request_socket.send(b"get_state") 116 | ret = pickle.loads(self.action_request_socket.recv()) 117 | # robot_state = pickle.loads(self.action_request_socket.recv())["robot_state"]["xarm"] 118 | # else: 119 | # # robot_state = ret["robot_state"]["xarm"] 120 | # robot_state = ret["robot_state"]["xarm"] 121 | robot_state = ret["robot_state"]["xarm"] 122 | 123 | # cartesian_pos = robot_state[:3] 124 | # cartesian_ori = robot_state[3:6] 125 | # gripper = robot_state[6] 126 | # cartesian_ori_sin = np.sin(cartesian_ori) 127 | # cartesian_ori_cos = np.cos(cartesian_ori) 128 | # robot_state = np.concatenate( 129 | # [cartesian_pos, cartesian_ori_sin, cartesian_ori_cos, [gripper]], axis=0 130 | # ) 131 | cartesian = robot_state[:6] 132 | quat_cartesian = get_quaternion_orientation(cartesian) 133 | robot_state = np.concatenate([quat_cartesian, robot_state[6:]], axis=0) 134 | 135 | # subscribe images 136 | image_dict = {} 137 | for cam_idx, img_sub in self.image_subscribers.items(): 138 | image_dict[cam_idx] = img_sub.recv_rgb_image()[0] 139 | 140 | obs = {} 141 | obs["features"] = np.array(robot_state, dtype=np.float32) 142 | obs["proprioceptive"] = np.array(robot_state, dtype=np.float32) 143 | if self.sensor_type == "reskin": 144 | try: 145 | sensor_state = ret["sensor_state"]["reskin"]["sensor_values"] 146 | sensor_state_sub = ( 147 | np.array(sensor_state, dtype=np.float32) - self.sensor_baseline 148 | ) 149 | self.sensor_prev_state = sensor_state_sub 150 | sensor_keys = [ 151 | f"sensor{sensor_idx}" for sensor_idx in range(self.n_sensors) 152 | ] 153 | for sidx, sensor_key in enumerate(sensor_keys): 154 | if self.subtract_sensor_baseline: 155 | obs[sensor_key] = sensor_state_sub[ 156 | sidx * self.sensor_dim : (sidx + 1) * self.sensor_dim 157 | ] 158 | else: 159 | obs[sensor_key] = sensor_state[ 160 | sidx * self.sensor_dim : (sidx + 1) * self.sensor_dim 161 | ] 162 | except KeyError: 163 | pass 164 | elif self.sensor_type == "digit": 165 | for dkey in self.digit_keys: 166 | obs[dkey] = np.array(ret["sensor_state"][dkey]) 167 | obs[dkey] = cv2.resize(obs[dkey], (self.width, self.height)) 168 | if self.subtract_sensor_baseline: 169 | obs[dkey] = obs[dkey] - self.sensor_baseline 170 | 171 | for cam_idx, image in image_dict.items(): 172 | if cam_idx == 52: 173 | # crop the right side of the image for the gripper cam 174 | img_shape = image.shape 175 | crop_percent = 0.2 176 | image = image[:, : int(img_shape[1] * (1 - crop_percent))] 177 | obs[f"pixels{cam_idx}"] = 
cv2.resize(image, (self.width, self.height)) 178 | return obs, self.reward, False, False, {} 179 | 180 | def reset(self, seed=None): # currently same positions, with gripper opening 181 | if self.use_robot: 182 | print("resetting") 183 | self.action_request_socket.send(b"reset") 184 | reset_state = pickle.loads(self.action_request_socket.recv()) 185 | 186 | # subscribe robot state 187 | self.action_request_socket.send(b"get_state") 188 | ret = pickle.loads(self.action_request_socket.recv()) 189 | robot_state = ret["robot_state"]["xarm"] 190 | # robot_state = np.array(robot_state, dtype=np.float32) 191 | # cartesian_pos = robot_state[:3] 192 | # cartesian_ori = robot_state[3:6] 193 | # gripper = robot_state[6] 194 | # cartesian_ori_sin = np.sin(cartesian_ori) 195 | # cartesian_ori_cos = np.cos(cartesian_ori) 196 | # robot_state = np.concatenate( 197 | # [cartesian_pos, cartesian_ori_sin, cartesian_ori_cos, [gripper]], axis=0 198 | # ) 199 | cartesian = robot_state[:6] 200 | quat_cartesian = get_quaternion_orientation(cartesian) 201 | robot_state = np.concatenate([quat_cartesian, robot_state[6:]], axis=0) 202 | 203 | # subscribe images 204 | image_dict = {} 205 | for cam_idx, img_sub in self.image_subscribers.items(): 206 | image_dict[cam_idx] = img_sub.recv_rgb_image()[0] 207 | 208 | obs = {} 209 | obs["features"] = robot_state 210 | obs["proprioceptive"] = robot_state 211 | if self.sensor_type == "reskin": 212 | try: 213 | sensor_state = np.array( 214 | ret["sensor_state"]["reskin"]["sensor_values"] 215 | ) 216 | # obs["sensor"] = np.array(sensor_state) 217 | if self.subtract_sensor_baseline: 218 | baseline_meas = [] 219 | while len(baseline_meas) < 5: 220 | self.action_request_socket.send(b"get_sensor_state") 221 | ret = pickle.loads(self.action_request_socket.recv()) 222 | sensor_state = ret["reskin"]["sensor_values"] 223 | baseline_meas.append(sensor_state) 224 | self.sensor_baseline = np.median(baseline_meas, axis=0) 225 | sensor_state = sensor_state - self.sensor_baseline 226 | self.sensor_prev_state = sensor_state 227 | sensor_keys = [ 228 | f"sensor{sensor_idx}" for sensor_idx in range(self.n_sensors) 229 | ] 230 | for sidx, sensor_key in enumerate(sensor_keys): 231 | obs[sensor_key] = sensor_state[ 232 | sidx * self.sensor_dim : (sidx + 1) * self.sensor_dim 233 | ] 234 | except KeyError: 235 | pass 236 | elif self.sensor_type == "digit": 237 | for dkey in self.digit_keys: 238 | obs[dkey] = np.array(ret["sensor_state"][dkey]) 239 | obs[dkey] = cv2.resize(obs[dkey], (self.width, self.height)) 240 | if self.subtract_sensor_baseline: 241 | baseline_meas = [] 242 | while len(baseline_meas) < 5: 243 | self.action_request_socket.send(b"get_sensor_state") 244 | ret = pickle.loads(self.action_request_socket.recv()) 245 | sensor_state = cv2.resize( 246 | ret[dkey], (self.width, self.height) 247 | ) 248 | baseline_meas.append(sensor_state) 249 | self.sensor_baseline = np.median(baseline_meas, axis=0) 250 | obs[dkey] = sensor_state - self.sensor_baseline 251 | # obs["sensor"] = sensor_state - self.sensor_baseline 252 | for cam_idx, image in image_dict.items(): 253 | if cam_idx == 52: 254 | # crop the right side of the image for the gripper cam 255 | img_shape = image.shape 256 | crop_percent = 0.2 257 | image = image[:, : int(img_shape[1] * (1 - crop_percent))] 258 | obs[f"pixels{cam_idx}"] = cv2.resize(image, (self.width, self.height)) 259 | 260 | return obs 261 | else: 262 | obs = {} 263 | obs["features"] = np.zeros(self.feature_dim) 264 | obs["proprioceptive"] = np.zeros(self.proprio_dim) 
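# use_robot=False: return zero-filled placeholder observations with the same keys and shapes
# as the real pipeline, so observation specs and offline dataloaders can be built without hardware.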
265 | for sensor_idx in range(self.n_sensors): 266 | obs[f"sensor{sensor_idx}"] = np.zeros(self.sensor_dim) 267 | self.sensor_baseline = np.zeros(self.sensor_dim * self.n_sensors) 268 | obs["pixels"] = np.zeros((self.height, self.width, self.n_channels)) 269 | return obs 270 | 271 | def render(self, mode="rgb_array", width=640, height=480): 272 | print("rendering") 273 | # subscribe images 274 | image_list = [] 275 | for _, img_sub in self.image_subscribers.items(): 276 | image = img_sub.recv_rgb_image()[0] 277 | image_list.append(cv2.resize(image, (width, height))) 278 | 279 | obs = np.concatenate(image_list, axis=1) 280 | return obs 281 | 282 | 283 | if __name__ == "__main__": 284 | env = RobotEnv() 285 | obs = env.reset() 286 | import ipdb 287 | 288 | ipdb.set_trace() 289 | 290 | for i in range(30): 291 | action = obs["features"] 292 | action[0] += 2 293 | obs, reward, done, _ = env.step(action) 294 | 295 | for i in range(30): 296 | action = obs["features"] 297 | action[1] += 2 298 | obs, reward, done, _ = env.step(action) 299 | 300 | for i in range(30): 301 | action = obs["features"] 302 | action[2] += 2 303 | obs, reward, done, _ = env.step(action) 304 | -------------------------------------------------------------------------------- /agent/networks/gpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | An adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch. 3 | Original source: https://github.com/karpathy/nanoGPT 4 | 5 | Original License: 6 | MIT License 7 | 8 | Copyright (c) 2022 Andrej Karpathy 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | Original comments: 29 | Full definition of a GPT Language Model, all of it in this single file. 30 | References: 31 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 32 | https://github.com/openai/gpt-2/blob/master/src/model.py 33 | 2) huggingface/transformers PyTorch implementation: 34 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 35 | """ 36 | 37 | import math 38 | from dataclasses import dataclass 39 | 40 | import torch 41 | import torch.nn as nn 42 | from torch.nn import functional as F 43 | 44 | 45 | # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) 46 | def new_gelu(x): 47 | """ 48 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 
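    Computed here with the tanh approximation:
        GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))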
49 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 50 | """ 51 | return ( 52 | 0.5 53 | * x 54 | * ( 55 | 1.0 56 | + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))) 57 | ) 58 | ) 59 | 60 | 61 | class CrossAttention(nn.Module): 62 | def __init__(self, repr_dim, nhead=4, nlayers=4, use_buffer_token=False): 63 | super().__init__() 64 | self.tf_decoder = nn.TransformerDecoder( 65 | nn.TransformerDecoderLayer( 66 | d_model=repr_dim, 67 | nhead=nhead, 68 | dim_feedforward=repr_dim * 4, 69 | batch_first=True, 70 | ), 71 | num_layers=nlayers, 72 | ) 73 | if use_buffer_token: 74 | self.buffer_token = nn.Parameter(torch.randn(1, 1, repr_dim)) 75 | self.use_buffer_token = use_buffer_token 76 | 77 | def forward(self, feat, cond): 78 | # add buffer token to the beginning of the sequence 79 | if self.use_buffer_token: 80 | batch_size = feat.size(0) 81 | buffer_token = self.buffer_token.expand(batch_size, 1, -1) 82 | cond_with_buffer = torch.cat([buffer_token, cond], dim=1) 83 | return self.tf_decoder(feat, cond_with_buffer) 84 | else: 85 | return self.tf_decoder(feat, cond) 86 | 87 | 88 | class CausalSelfAttention(nn.Module): 89 | def __init__(self, config): 90 | super().__init__() 91 | assert config.n_embd % config.n_head == 0 92 | # key, query, value projections for all heads, but in a batch 93 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) 94 | # output projection 95 | self.c_proj = nn.Linear(config.n_embd, config.n_embd) 96 | # regularization 97 | self.attn_dropout = nn.Dropout(config.dropout) 98 | self.resid_dropout = nn.Dropout(config.dropout) 99 | # causal mask to ensure that attention is only applied to the left in the input sequence 100 | self.register_buffer( 101 | "bias", 102 | # torch.ones(1, 1, config.block_size, config.block_size), 103 | torch.tril(torch.ones(config.block_size, config.block_size)).view( 104 | 1, 1, config.block_size, config.block_size 105 | ), 106 | ) 107 | self.n_head = config.n_head 108 | self.n_embd = config.n_embd 109 | 110 | def forward(self, x, attn_mask=None): 111 | ( 112 | B, 113 | T, 114 | C, 115 | ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 116 | 117 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 118 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 119 | k = k.view(B, T, self.n_head, C // self.n_head).transpose( 120 | 1, 2 121 | ) # (B, nh, T, hs) 122 | q = q.view(B, T, self.n_head, C // self.n_head).transpose( 123 | 1, 2 124 | ) # (B, nh, T, hs) 125 | v = v.view(B, T, self.n_head, C // self.n_head).transpose( 126 | 1, 2 127 | ) # (B, nh, T, hs) 128 | 129 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 130 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 131 | mask = self.bias[:, :, :T, :T] 132 | if attn_mask is not None: 133 | mask = mask * attn_mask 134 | att = att.masked_fill(mask == 0, float("-inf")) 135 | att = F.softmax(att, dim=-1) 136 | att = self.attn_dropout(att) 137 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 138 | y = ( 139 | y.transpose(1, 2).contiguous().view(B, T, C) 140 | ) # re-assemble all head outputs side by side 141 | 142 | # output projection 143 | y = self.resid_dropout(self.c_proj(y)) 144 | return y 145 | 146 | 147 | class MLP(nn.Module): 148 | def __init__(self, config): 149 | super().__init__() 150 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) 151 | self.c_proj = nn.Linear(4 * 
config.n_embd, config.n_embd) 152 | self.dropout = nn.Dropout(config.dropout) 153 | 154 | def forward(self, x): 155 | x = self.c_fc(x) 156 | x = new_gelu(x) 157 | x = self.c_proj(x) 158 | x = self.dropout(x) 159 | return x 160 | 161 | 162 | class Block(nn.Module): 163 | def __init__(self, config): 164 | super().__init__() 165 | self.ln_1 = nn.LayerNorm(config.n_embd) 166 | self.attn = CausalSelfAttention(config) 167 | self.ln_2 = nn.LayerNorm(config.n_embd) 168 | self.mlp = MLP(config) 169 | 170 | def forward(self, x, mask=None): 171 | x = x + self.attn(self.ln_1(x), attn_mask=mask) 172 | x = x + self.mlp(self.ln_2(x)) 173 | return x 174 | 175 | 176 | @dataclass 177 | class GPTConfig: 178 | block_size: int = 1024 179 | input_dim: int = 256 180 | output_dim: int = 256 181 | n_layer: int = 12 182 | n_head: int = 12 183 | n_embd: int = 768 184 | dropout: float = 0.1 185 | 186 | 187 | class GPT(nn.Module): 188 | def __init__(self, config): 189 | super().__init__() 190 | assert config.input_dim is not None 191 | assert config.output_dim is not None 192 | assert config.block_size is not None 193 | self.config = config 194 | 195 | self.transformer = nn.ModuleDict( 196 | dict( 197 | wte=nn.Linear(config.input_dim, config.n_embd), 198 | wpe=nn.Embedding(config.block_size, config.n_embd), 199 | drop=nn.Dropout(config.dropout), 200 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 201 | ln_f=nn.LayerNorm(config.n_embd), 202 | ) 203 | ) 204 | self.lm_head = nn.Linear(config.n_embd, config.output_dim, bias=False) 205 | # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper 206 | self.apply(self._init_weights) 207 | for pn, p in self.named_parameters(): 208 | if pn.endswith("c_proj.weight"): 209 | torch.nn.init.normal_( 210 | p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) 211 | ) 212 | 213 | # report number of parameters 214 | n_params = sum(p.numel() for p in self.parameters()) 215 | print("number of parameters: %.2fM" % (n_params / 1e6,)) 216 | 217 | def forward(self, input, targets=None, mask=None): 218 | device = input.device 219 | b, t, d = input.size() 220 | assert ( 221 | t <= self.config.block_size 222 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 223 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( 224 | 0 225 | ) # shape (1, t) 226 | 227 | # forward the GPT model itself 228 | tok_emb = self.transformer.wte( 229 | input 230 | ) # token embeddings of shape (b, t, n_embd) 231 | pos_emb = self.transformer.wpe( 232 | pos 233 | ) # position embeddings of shape (1, t, n_embd) 234 | x = self.transformer.drop(tok_emb + pos_emb) 235 | for block in self.transformer.h: 236 | x = block(x, mask=mask) 237 | x = self.transformer.ln_f(x) 238 | logits = self.lm_head(x) 239 | return logits 240 | 241 | def _init_weights(self, module): 242 | if isinstance(module, nn.Linear): 243 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 244 | if module.bias is not None: 245 | torch.nn.init.zeros_(module.bias) 246 | elif isinstance(module, nn.Embedding): 247 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 248 | elif isinstance(module, nn.LayerNorm): 249 | torch.nn.init.zeros_(module.bias) 250 | torch.nn.init.ones_(module.weight) 251 | 252 | def crop_block_size(self, block_size): 253 | assert block_size <= self.config.block_size 254 | self.config.block_size = block_size 255 | self.transformer.wpe.weight = nn.Parameter( 256 | self.transformer.wpe.weight[:block_size] 257 | ) 
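# Each attention block also registered a (1, 1, block_size, block_size) causal-mask buffer
# ("bias"); crop those buffers below so the mask matches the new, smaller block size.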
258 | for block in self.transformer.h: 259 | block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] 260 | 261 | def configure_optimizers(self, weight_decay, learning_rate, betas): 262 | """ 263 | This long function is unfortunately doing something very simple and is being very defensive: 264 | We are separating out all parameters of the model into two buckets: those that will experience 265 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 266 | We are then returning the PyTorch optimizer object. 267 | """ 268 | 269 | # separate out all parameters to those that will and won't experience regularizing weight decay 270 | decay = set() 271 | no_decay = set() 272 | whitelist_weight_modules = (torch.nn.Linear,) 273 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 274 | for mn, m in self.named_modules(): 275 | for pn, p in m.named_parameters(): 276 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 277 | if pn.endswith("bias"): 278 | # all biases will not be decayed 279 | no_decay.add(fpn) 280 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 281 | # weights of whitelist modules will be weight decayed 282 | decay.add(fpn) 283 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): 284 | # weights of blacklist modules will NOT be weight decayed 285 | no_decay.add(fpn) 286 | 287 | # validate that we considered every parameter 288 | param_dict = {pn: p for pn, p in self.named_parameters()} 289 | inter_params = decay & no_decay 290 | union_params = decay | no_decay 291 | assert ( 292 | len(inter_params) == 0 293 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) 294 | assert ( 295 | len(param_dict.keys() - union_params) == 0 296 | ), "parameters %s were not separated into either decay/no_decay set!" 
% ( 297 | str(param_dict.keys() - union_params), 298 | ) 299 | 300 | # create the pytorch optimizer object 301 | optim_groups = [ 302 | { 303 | "params": [param_dict[pn] for pn in sorted(list(decay))], 304 | "weight_decay": weight_decay, 305 | }, 306 | { 307 | "params": [param_dict[pn] for pn in sorted(list(no_decay))], 308 | "weight_decay": 0.0, 309 | }, 310 | ] 311 | # optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) 312 | optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas) 313 | return optimizer 314 | -------------------------------------------------------------------------------- /suite/xarm_env.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import Any, NamedTuple 3 | 4 | import gym 5 | from gym import Wrapper, spaces 6 | 7 | import xarm_env 8 | import dm_env 9 | import numpy as np 10 | from dm_env import StepType, specs, TimeStep 11 | 12 | import cv2 13 | 14 | 15 | class RGBArrayAsObservationWrapper(dm_env.Environment): 16 | """ 17 | Use env.render(rgb_array) as observation 18 | rather than the observation environment provides 19 | 20 | From: https://github.com/hill-a/stable-baselines/issues/915 21 | """ 22 | 23 | def __init__( 24 | self, 25 | env, 26 | # height, 27 | # width, 28 | max_episode_len=300, 29 | max_state_dim=100, 30 | task_description="", 31 | pixel_keys=["pixels0"], 32 | aux_keys=["proprioceptive"], 33 | use_robot=True, 34 | ): 35 | self._env = env 36 | # self._width = width 37 | # self._height = height 38 | self._max_episode_len = max_episode_len 39 | self._max_state_dim = max_state_dim 40 | self._task_description = task_description 41 | self.pixel_keys = pixel_keys 42 | self.aux_keys = aux_keys 43 | self.use_robot = use_robot 44 | 45 | # task emb 46 | 47 | obs = self._env.reset() 48 | if self.use_robot: 49 | pixels = obs[pixel_keys[0]] 50 | self.observation_space = spaces.Box( 51 | low=0, high=255, shape=pixels.shape, dtype=pixels.dtype 52 | ) 53 | 54 | # Action spec 55 | action_spec = self._env.action_space 56 | # self._action_spec = specs.BoundedArray( 57 | # action_spec[0].shape, np.float32, action_spec[0], action_spec[1], "action" 58 | # ) 59 | self._action_spec = specs.Array( 60 | shape=action_spec.shape, dtype=action_spec.dtype, name="action" 61 | ) 62 | # Observation spec 63 | # robot_state = np.concatenate( 64 | # [obs["robot0_joint_pos"], obs["robot0_gripper_qpos"]] 65 | # ) 66 | features = obs["features"] 67 | self._obs_spec = {} 68 | for key in pixel_keys: 69 | self._obs_spec[key] = specs.BoundedArray( 70 | shape=obs[key].shape, 71 | dtype=np.uint8, 72 | minimum=0, 73 | maximum=255, 74 | name=key, 75 | ) 76 | else: 77 | pixels, features = obs["pixels"], obs["features"] 78 | self.observation_space = spaces.Box( 79 | low=0, high=255, shape=pixels.shape, dtype=pixels.dtype 80 | ) 81 | 82 | # Action spec 83 | action_spec = self._env.action_space 84 | self._action_spec = specs.Array( 85 | shape=action_spec.shape, dtype=action_spec.dtype, name="action" 86 | ) 87 | 88 | # Observation spec 89 | self._obs_spec = {} 90 | for key in pixel_keys: 91 | self._obs_spec[key] = specs.BoundedArray( 92 | shape=pixels.shape, 93 | dtype=np.uint8, 94 | minimum=0, 95 | maximum=255, 96 | name=key, 97 | ) 98 | 99 | self._obs_spec["proprioceptive"] = specs.BoundedArray( 100 | shape=features.shape, 101 | dtype=np.float32, 102 | minimum=-np.inf, 103 | maximum=np.inf, 104 | name="proprioceptive", 105 | ) 106 | self._obs_spec["features"] = 
specs.BoundedArray( 107 | shape=(self._max_state_dim,), 108 | dtype=np.float32, 109 | minimum=-np.inf, 110 | maximum=np.inf, 111 | name="features", 112 | ) 113 | 114 | for key in aux_keys: 115 | if key.startswith("sensor"): 116 | self._obs_spec[key] = specs.BoundedArray( 117 | shape=obs[key].shape, 118 | dtype=np.float32, 119 | minimum=-np.inf, 120 | maximum=np.inf, 121 | name=key, 122 | ) 123 | if key.startswith("digit"): 124 | self._obs_spec[key] = specs.BoundedArray( 125 | shape=pixels.shape, 126 | dtype=np.uint8, 127 | minimum=0, 128 | maximum=255, 129 | name=key, 130 | ) 131 | # if "sensor" in aux_keys: 132 | # self._obs_spec["sensor"] = specs.BoundedArray( 133 | # shape=obs["sensor"].shape, 134 | # dtype=np.float32, 135 | # minimum=-np.inf, 136 | # maximum=np.inf, 137 | # name="sensor", 138 | # ) 139 | 140 | self.render_image = None 141 | 142 | def reset(self, **kwargs): 143 | self._step = 0 144 | obs = self._env.reset(**kwargs) 145 | 146 | observation = {} 147 | for key in self.pixel_keys: 148 | observation[key] = obs[key] 149 | observation["proprioceptive"] = obs["features"] 150 | observation["features"] = obs["features"] 151 | for key in self.aux_keys: 152 | if key.startswith("sensor"): 153 | observation[key] = obs[key] 154 | observation["goal_achieved"] = False 155 | return observation 156 | 157 | def step(self, action): 158 | self._step += 1 159 | obs, reward, truncated, terminated, info = self._env.step(action) 160 | done = truncated or terminated 161 | 162 | observation = {} 163 | for key in self.pixel_keys: 164 | observation[key] = obs[key] 165 | observation["proprioceptive"] = obs["features"] 166 | observation["features"] = obs["features"] 167 | if "sensor" in self.aux_keys: 168 | observation["sensor"] = obs["sensor"] 169 | for key in self.aux_keys: 170 | if key.startswith("sensor"): 171 | observation[key] = obs[key] 172 | observation["goal_achieved"] = done # (self._step == self._max_episode_len) 173 | return observation, reward, done, info 174 | 175 | def observation_spec(self): 176 | return self._obs_spec 177 | 178 | def action_spec(self): 179 | return self._action_spec 180 | 181 | def render(self, mode="rgb_array", width=256, height=256): 182 | return cv2.resize(self._env.render("rgb_array"), (width, height)) 183 | 184 | def __getattr__(self, name): 185 | return getattr(self._env, name) 186 | 187 | 188 | class ActionRepeatWrapper(dm_env.Environment): 189 | def __init__(self, env, num_repeats): 190 | self._env = env 191 | self._num_repeats = num_repeats 192 | 193 | def step(self, action): 194 | reward = 0.0 195 | discount = 1.0 196 | for i in range(self._num_repeats): 197 | time_step = self._env.step(action) 198 | reward += (time_step.reward or 0.0) * discount 199 | discount *= time_step.discount 200 | if time_step.last(): 201 | break 202 | 203 | return time_step._replace(reward=reward, discount=discount) 204 | 205 | def observation_spec(self): 206 | return self._env.observation_spec() 207 | 208 | def action_spec(self): 209 | return self._env.action_spec() 210 | 211 | def reset(self, **kwargs): 212 | return self._env.reset(**kwargs) 213 | 214 | def __getattr__(self, name): 215 | return getattr(self._env, name) 216 | 217 | 218 | class FrameStackWrapper(dm_env.Environment): 219 | def __init__(self, env, num_frames): 220 | self._env = env 221 | self._num_frames = num_frames 222 | 223 | self.pixel_keys = [ 224 | keys for keys in env.observation_spec().keys() if "pixels" in keys 225 | ] 226 | wrapped_obs_spec = env.observation_spec()[self.pixel_keys[0]] 227 | 228 | # frames 
lists 229 | self._frames = {} 230 | for key in self.pixel_keys: 231 | self._frames[key] = deque([], maxlen=num_frames) 232 | 233 | pixels_shape = wrapped_obs_spec.shape 234 | if len(pixels_shape) == 4: 235 | pixels_shape = pixels_shape[1:] 236 | self._obs_spec = {} 237 | self._obs_spec["features"] = self._env.observation_spec()["features"] 238 | self._obs_spec["proprioceptive"] = self._env.observation_spec()[ 239 | "proprioceptive" 240 | ] 241 | for key in self._env.observation_spec().keys(): 242 | if key.startswith("sensor"): 243 | self._obs_spec[key] = self._env.observation_spec()[key] 244 | if key.startswith("digit"): 245 | self._obs_spec[key] = self._env.observation_spec()[key] 246 | for key in self.pixel_keys: 247 | self._obs_spec[key] = specs.BoundedArray( 248 | shape=np.concatenate( 249 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0 250 | ), 251 | dtype=np.uint8, 252 | minimum=0, 253 | maximum=255, 254 | name=key, 255 | ) 256 | 257 | def _transform_observation(self, time_step): 258 | for key in self.pixel_keys: 259 | assert len(self._frames[key]) == self._num_frames 260 | obs = {} 261 | obs["features"] = time_step.observation["features"] 262 | for key in self.pixel_keys: 263 | obs[key] = np.concatenate(list(self._frames[key]), axis=0) 264 | obs["proprioceptive"] = time_step.observation["proprioceptive"] 265 | try: 266 | for key in time_step.observation.keys(): 267 | if key.startswith("sensor"): 268 | obs[key] = time_step.observation[key] 269 | except KeyError: 270 | pass 271 | obs["goal_achieved"] = time_step.observation["goal_achieved"] 272 | return time_step._replace(observation=obs) 273 | 274 | def _extract_pixels(self, time_step): 275 | pixels = {} 276 | for key in self.pixel_keys: 277 | pixels[key] = time_step.observation[key] 278 | if len(pixels[key].shape) == 4: 279 | pixels[key] = pixels[key][0] 280 | pixels[key] = pixels[key].transpose(2, 0, 1) 281 | return pixels 282 | 283 | def reset(self, **kwargs): 284 | time_step = self._env.reset(**kwargs) 285 | pixels = self._extract_pixels(time_step) 286 | for key in self.pixel_keys: 287 | for _ in range(self._num_frames): 288 | self._frames[key].append(pixels[key]) 289 | return self._transform_observation(time_step) 290 | 291 | def step(self, action): 292 | time_step = self._env.step(action) 293 | pixels = self._extract_pixels(time_step) 294 | for key in self.pixel_keys: 295 | self._frames[key].append(pixels[key]) 296 | return self._transform_observation(time_step) 297 | 298 | def observation_spec(self): 299 | return self._obs_spec 300 | 301 | def action_spec(self): 302 | return self._env.action_spec() 303 | 304 | def __getattr__(self, name): 305 | return getattr(self._env, name) 306 | 307 | 308 | class ActionDTypeWrapper(dm_env.Environment): 309 | def __init__(self, env, dtype): 310 | self._env = env 311 | self._discount = 1.0 312 | 313 | # Action spec 314 | wrapped_action_spec = env.action_spec() 315 | self._action_spec = specs.Array( 316 | shape=wrapped_action_spec.shape, dtype=dtype, name="action" 317 | ) 318 | 319 | def step(self, action): 320 | action = action.astype(self._env.action_spec().dtype) 321 | # Make time step for action space 322 | observation, reward, done, info = self._env.step(action) 323 | step_type = StepType.LAST if done else StepType.MID 324 | 325 | return TimeStep( 326 | step_type=step_type, 327 | reward=reward, 328 | discount=self._discount, 329 | observation=observation, 330 | ) 331 | 332 | def observation_spec(self): 333 | return self._env.observation_spec() 334 | 335 | def 
action_spec(self): 336 | return self._action_spec 337 | 338 | def reset(self, **kwargs): 339 | obs = self._env.reset(**kwargs) 340 | return TimeStep( 341 | step_type=StepType.FIRST, reward=0, discount=self._discount, observation=obs 342 | ) 343 | 344 | def __getattr__(self, name): 345 | return getattr(self._env, name) 346 | 347 | 348 | class ExtendedTimeStep(NamedTuple): 349 | step_type: Any 350 | reward: Any 351 | discount: Any 352 | observation: Any 353 | action: Any 354 | 355 | def first(self): 356 | return self.step_type == StepType.FIRST 357 | 358 | def mid(self): 359 | return self.step_type == StepType.MID 360 | 361 | def last(self): 362 | return self.step_type == StepType.LAST 363 | 364 | def __getitem__(self, attr): 365 | return getattr(self, attr) 366 | 367 | 368 | class ExtendedTimeStepWrapper(dm_env.Environment): 369 | def __init__(self, env): 370 | self._env = env 371 | 372 | def reset(self, **kwargs): 373 | time_step = self._env.reset(**kwargs) 374 | return self._augment_time_step(time_step) 375 | 376 | def step(self, action): 377 | time_step = self._env.step(action) 378 | return self._augment_time_step(time_step, action) 379 | 380 | def _augment_time_step(self, time_step, action=None): 381 | if action is None: 382 | action_spec = self.action_spec() 383 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 384 | return ExtendedTimeStep( 385 | observation=time_step.observation, 386 | step_type=time_step.step_type, 387 | action=action, 388 | reward=time_step.reward or 0.0, 389 | discount=time_step.discount or 1.0, 390 | ) 391 | 392 | def _replace( 393 | self, time_step, observation=None, action=None, reward=None, discount=None 394 | ): 395 | if observation is None: 396 | observation = time_step.observation 397 | if action is None: 398 | action = time_step.action 399 | if reward is None: 400 | reward = time_step.reward 401 | if discount is None: 402 | discount = time_step.discount 403 | return ExtendedTimeStep( 404 | observation=observation, 405 | step_type=time_step.step_type, 406 | action=action, 407 | reward=reward, 408 | discount=discount, 409 | ) 410 | 411 | def observation_spec(self): 412 | return self._env.observation_spec() 413 | 414 | def action_spec(self): 415 | return self._env.action_spec() 416 | 417 | def __getattr__(self, name): 418 | return getattr(self._env, name) 419 | 420 | 421 | def make( 422 | frame_stack, 423 | action_repeat, 424 | height, 425 | width, 426 | max_episode_len, 427 | max_state_dim, 428 | use_egocentric, 429 | use_fisheye, 430 | task_description, 431 | pixel_keys, 432 | aux_keys, 433 | sensor_params, 434 | eval, # True means use_robot=True 435 | ): 436 | env = gym.make( 437 | "Robot-v1", 438 | height=height, 439 | width=width, 440 | use_robot=eval, 441 | use_egocentric=use_egocentric, 442 | use_fisheye=use_fisheye, 443 | subtract_sensor_baseline=sensor_params.subtract_sensor_baseline, 444 | ) 445 | 446 | # apply wrappers 447 | env = RGBArrayAsObservationWrapper( 448 | env, 449 | max_episode_len=max_episode_len, 450 | max_state_dim=max_state_dim, 451 | task_description=task_description, 452 | pixel_keys=pixel_keys, 453 | aux_keys=aux_keys, 454 | use_robot=eval, 455 | ) 456 | env = ActionDTypeWrapper(env, np.float32) 457 | env = ActionRepeatWrapper(env, action_repeat) 458 | env = FrameStackWrapper(env, frame_stack) 459 | env = ExtendedTimeStepWrapper(env) 460 | 461 | return [env], [task_description] 462 | -------------------------------------------------------------------------------- /process_xarm_data.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pickle 4 | import cv2 5 | import pandas as pd 6 | import os 7 | import subprocess 8 | import os 9 | import re 10 | import shutil 11 | import h5py 12 | from pathlib import Path 13 | 14 | from utils import DATA_DIR 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--task-name", "-t", type=str, required=True) 18 | args = parser.parse_args() 19 | 20 | TASK_NAME = args.task_name 21 | 22 | DATA_PATH = Path(DATA_DIR) 23 | SAVE_PATH = Path(DATA_DIR) / "processed_data" 24 | 25 | num_demos = None 26 | cam_indices = { 27 | 1: "rgb", 28 | 2: "rgb", 29 | 51: "fish_eye", 30 | 52: "fish_eye", 31 | } 32 | states_file_name = "states" 33 | sensor_file_name = "sensor" 34 | 35 | # Create the save path 36 | SAVE_PATH.mkdir(parents=True, exist_ok=True) 37 | done_flag = True 38 | skip_cam_processing = False 39 | process_type = "cont" 40 | 41 | print(f"#################### Processing task {TASK_NAME} ####################") 42 | 43 | # Check if previous demos from this task exist 44 | if Path(f"{SAVE_PATH}/{TASK_NAME}").exists(): 45 | num_prev_demos = len([f for f in (SAVE_PATH / TASK_NAME).iterdir() if f.is_dir()]) 46 | if num_prev_demos > 0: 47 | cont_check = input( 48 | f"Previous demonstrations from task {TASK_NAME} exist. Continue from existing demos? y/n." 49 | ) 50 | if cont_check == "n": 51 | ow_check = input( 52 | f"Overwrite existing demonstrations from task {TASK_NAME}? y/n." 53 | ) 54 | if ow_check == "y": 55 | num_prev_demos = 0 56 | else: 57 | print("Appending new demonstrations to the existing ones.") 58 | process_type = "append" 59 | elif cont_check == "y": 60 | num_prev_demos -= 1 # overwrite the last demo 61 | 62 | else: 63 | num_prev_demos = 0 64 | (SAVE_PATH / TASK_NAME).mkdir(parents=True, exist_ok=True) 65 | 66 | # demo directories 67 | DEMO_DIRS = [ 68 | f 69 | for f in (DATA_PATH / TASK_NAME).iterdir() 70 | if f.is_dir() and "fail" not in f.name and "ignore" not in f.name 71 | ] 72 | if num_demos is not None: 73 | DEMO_DIRS = DEMO_DIRS[:num_demos] 74 | 75 | for num, demo_dir in enumerate(sorted(DEMO_DIRS)): 76 | process_sensor = True 77 | # try: 78 | if process_type == "cont" and num < num_prev_demos: 79 | print(f"Skipping demonstration {demo_dir.name}") 80 | continue 81 | if process_type == "append": 82 | demo_id = num + num_prev_demos 83 | elif process_type == "cont": 84 | demo_id = int(demo_dir.name.split("_")[-1]) 85 | print("Processing demonstration", demo_dir.name) 86 | output_path = f"{SAVE_PATH}/{TASK_NAME}/demonstration_{demo_id}/" 87 | print("Output path:", output_path) 88 | Path(output_path).mkdir(parents=True, exist_ok=True) 89 | csv_list = [f for f in os.listdir(output_path) if f.endswith(".csv")] 90 | for f in csv_list: 91 | os.remove(os.path.join(output_path, f)) 92 | cam_avis = [f"{demo_dir}/cam_{i}_{cam_indices[i]}_video.avi" for i in cam_indices] 93 | 94 | try: 95 | cartesian = h5py.File(f"{demo_dir}/xarm_cartesian_states.h5", "r") 96 | state_timestamps = cartesian["timestamps"] 97 | state_positions = cartesian["cartesian_positions"] 98 | 99 | gripper = h5py.File(f"{demo_dir}/xarm_gripper_states.h5", "r") 100 | gripper_positions = gripper["gripper_positions"] 101 | except: 102 | print("No cartesian or gripper states found. 
Skipping this demo.") 103 | continue 104 | 105 | try: 106 | with h5py.File(f"{demo_dir}/reskin_sensor_values.h5") as hf: 107 | sensor_timestamps = np.array(hf["timestamps"]) 108 | sensor_values = np.array(hf["sensor_values"]) 109 | except FileNotFoundError: 110 | print("No sensor values found. Skipping sensor processing.") 111 | process_sensor = False 112 | 113 | state_positions = np.array(state_positions) 114 | gripper_positions = np.array(gripper_positions) 115 | gripper_positions = gripper_positions.reshape(-1, 1) 116 | 117 | state_timestamps = np.array(state_timestamps) 118 | 119 | # Find indices of timestamps where the robot moves 120 | static_timestamps = [] 121 | static = False 122 | start, end = None, None 123 | for i in range(1, len(state_positions) - 1): 124 | if ( 125 | np.array_equal(state_positions[i], state_positions[i + 1]) 126 | and static == False 127 | ): 128 | static = True 129 | start = i 130 | elif ( 131 | not np.array_equal(state_positions[i], state_positions[i + 1]) 132 | and static == True 133 | ): 134 | static = False 135 | end = i 136 | static_timestamps.append((start, end)) 137 | if static: 138 | static_timestamps.append((start, len(state_positions) - 1)) 139 | 140 | # read metadata file 141 | CAM_TIMESTAMPS = [] 142 | CAM_VALID_LENS = [] 143 | skip = False 144 | for idx in cam_indices: 145 | cam_meta_file_path = f"{demo_dir}/cam_{idx}_{cam_indices[idx]}_video.metadata" 146 | with open(cam_meta_file_path, "rb") as f: 147 | image_metadata = pickle.load(f) 148 | image_timestamps = np.asarray(image_metadata["timestamps"]) / 1000.0 149 | 150 | cam_timestamps = dict(timestamps=image_timestamps) 151 | # convert to numpy array 152 | cam_timestamps = np.array(cam_timestamps["timestamps"]) 153 | 154 | # Fish eye cam timestamps are divided by 1000 155 | if max(cam_timestamps) < state_timestamps[static_timestamps[0][1]]: 156 | cam_timestamps *= 1000 157 | elif min(cam_timestamps) > state_timestamps[static_timestamps[-1][0]]: 158 | cam_timestamps /= 1000 159 | 160 | valid_indices = [] 161 | for k in range(len(static_timestamps) - 1): 162 | start_idx = sum(cam_timestamps < state_timestamps[static_timestamps[k][1]]) 163 | end_idx = sum( 164 | cam_timestamps < state_timestamps[static_timestamps[k + 1][0]] 165 | ) 166 | valid_indices.extend([i for i in range(start_idx, end_idx)]) 167 | cam_timestamps = cam_timestamps[valid_indices] 168 | 169 | # if no valid timestamps, skip 170 | if len(cam_timestamps) == 0: 171 | skip = True 172 | break 173 | 174 | CAM_VALID_LENS.append(valid_indices) 175 | CAM_TIMESTAMPS.append(cam_timestamps) 176 | if skip: 177 | continue 178 | 179 | # cam frames 180 | if not skip_cam_processing: 181 | CAM_FRAMES = [] 182 | for idx in range(len(cam_avis)): # cam_indices: 183 | cam_avi = cam_avis[idx] 184 | cam_frames = [] 185 | cap_cap = cv2.VideoCapture(cam_avi) 186 | while cap_cap.isOpened(): 187 | ret, frame = cap_cap.read() 188 | if ret == False: 189 | break 190 | cam_frames.append(frame) 191 | cap_cap.release() 192 | 193 | # save frames 194 | cam_frames = np.array(cam_frames) 195 | cam_frames = cam_frames[CAM_VALID_LENS[idx]] 196 | CAM_FRAMES.append(cam_frames) 197 | 198 | rgb_frames = CAM_FRAMES 199 | timestamps = CAM_TIMESTAMPS 200 | timestamps.append(state_timestamps) 201 | 202 | min_time_index = np.argmin([len(timestamp) for timestamp in timestamps]) 203 | reference_timestamps = timestamps[min_time_index] 204 | align = [] 205 | index = [] 206 | for i in range(len(timestamps)): 207 | # aligning frames 208 | if i == min_time_index: 209 | 
align.append(timestamps[i]) 210 | index.append(np.arange(len(timestamps[i]))) 211 | continue 212 | curindex = [] 213 | currrlist = [] 214 | for j in range(len(reference_timestamps)): 215 | curlist = [] 216 | for k in range(len(timestamps[i])): 217 | curlist.append(abs(timestamps[i][k] - reference_timestamps[j])) 218 | min_index = curlist.index(min(curlist)) 219 | currrlist.append(timestamps[i][min_index]) 220 | curindex.append(min_index) 221 | align.append(currrlist) 222 | index.append(curindex) 223 | 224 | index = np.array(index) 225 | 226 | if process_sensor: 227 | if max(sensor_timestamps) < state_timestamps[static_timestamps[0][1]]: 228 | print("Sensor data issue") 229 | elif min(sensor_timestamps) > state_timestamps[static_timestamps[-1][0]]: 230 | print("Sensor data issue") 231 | else: 232 | print("All good with sensor data!") 233 | sensor_valid_indices = [] 234 | for k in range(len(static_timestamps) - 1): 235 | start_idx = sum( 236 | sensor_timestamps < state_timestamps[static_timestamps[k][1]] 237 | ) 238 | end_idx = sum( 239 | sensor_timestamps < state_timestamps[static_timestamps[k + 1][0]] 240 | ) 241 | sensor_valid_indices.extend([i for i in range(start_idx, end_idx)]) 242 | 243 | sensor_timestamps = sensor_timestamps[sensor_valid_indices] 244 | sensor_values = sensor_values[sensor_valid_indices] 245 | 246 | sensor_timestamps_test = pd.DataFrame(sensor_timestamps) 247 | sensor_values_test = [] 248 | for i in range(len(sensor_values)): 249 | sensor_values_test.append(np.array(sensor_values[i])) 250 | sensor_values_test = pd.DataFrame( 251 | {"column": [list(row) for row in sensor_values_test]} 252 | ) 253 | 254 | sensor_test = pd.concat( 255 | [sensor_timestamps_test, sensor_values_test], 256 | axis=1, 257 | ) 258 | 259 | with open(output_path + f"big_{sensor_file_name}.csv", "a") as f: 260 | sensor_test.to_csv( 261 | f, 262 | header=["created timestamp", "sensor_values"], 263 | index=False, 264 | ) 265 | 266 | # convert left_state_timestamps and left_state_positions to a csv file with header "created timestamp", "pose_aa", "gripper_state" 267 | state_timestamps_test = pd.DataFrame(state_timestamps) 268 | # convert each pose_aa to a list 269 | state_positions_test = state_positions 270 | for i in range(len(state_positions_test)): 271 | state_positions_test[i] = np.array(state_positions_test[i]) 272 | state_positions_test = pd.DataFrame( 273 | {"column": [list(row) for row in state_positions_test]} 274 | ) 275 | # convert left_gripper to True and False 276 | gripper_positions_test = pd.DataFrame(gripper_positions) 277 | 278 | state_test = pd.concat( 279 | [state_timestamps_test, state_positions_test, gripper_positions_test], 280 | axis=1, 281 | ) 282 | with open(output_path + f"big_{states_file_name}.csv", "a") as f: 283 | state_test.to_csv( 284 | f, 285 | header=["created timestamp", "pose_aa", "gripper_state"], 286 | index=False, 287 | ) 288 | 289 | df = pd.read_csv(output_path + f"big_{states_file_name}.csv") 290 | for i in range(len(reference_timestamps)): 291 | curlist = [] 292 | for j in range(len(state_timestamps)): 293 | curlist.append(abs(state_timestamps[j] - reference_timestamps[i])) 294 | min_index = curlist.index(min(curlist)) 295 | min_df = df.iloc[min_index] 296 | min_df = min_df.to_frame().transpose() 297 | with open(output_path + f"{states_file_name}.csv", "a") as f: 298 | min_df.to_csv(f, header=f.tell() == 0, index=False) 299 | 300 | if process_sensor: 301 | df = pd.read_csv(output_path + f"big_{sensor_file_name}.csv") 302 | for i in 
range(len(reference_timestamps)): 303 | curlist = [] 304 | for j in range(len(sensor_timestamps)): 305 | curlist.append(abs(sensor_timestamps[j] - reference_timestamps[i])) 306 | min_index = curlist.index(min(curlist)) 307 | min_df = df.iloc[min_index] 308 | min_df = min_df.to_frame().transpose() 309 | with open(output_path + f"{sensor_file_name}.csv", "a") as f: 310 | min_df.to_csv(f, header=f.tell() == 0, index=False) 311 | 312 | # Create folders for each camera if they don't exist 313 | output_folder = output_path + "videos" 314 | os.makedirs(output_folder, exist_ok=True) 315 | camera_folders = [f"camera{i}" for i in cam_indices] 316 | for folder in camera_folders: 317 | os.makedirs(os.path.join(output_folder, folder), exist_ok=True) 318 | 319 | # Iterate over each camera and extract the frames based on the indexes 320 | if not skip_cam_processing: 321 | for camera_index, frames in enumerate(rgb_frames): 322 | camera_folder = camera_folders[camera_index] 323 | print(f"Extracting frames for {camera_folder}...") 324 | indexes = index[camera_index] 325 | 326 | # Iterate over the indexes and save the corresponding frames 327 | for i, indexx in enumerate(indexes): 328 | if i % 100 == 0: 329 | print(f"Extracting frame {i}...") 330 | frame = frames[indexx] 331 | # name frame with its timestamp 332 | image_output_path = os.path.join( 333 | output_folder, 334 | camera_folder, 335 | f"frame_{i}_{timestamps[camera_index][indexx]}.jpg", 336 | ) 337 | cv2.imwrite(image_output_path, frame) 338 | 339 | csv_file = os.path.join(output_path, f"{states_file_name}.csv") 340 | print(output_path, demo_dir.name) 341 | 342 | def get_timestamp_from_filename(filename): 343 | # Extract the timestamp from the filename using regular expression 344 | timestamp_match = re.search(r"\d+\.\d+", filename) 345 | if timestamp_match: 346 | return float(timestamp_match.group()) 347 | else: 348 | return None 349 | 350 | # add desired gripper states 351 | for file in [csv_file]: 352 | df = pd.read_csv(file) 353 | df["desired_gripper_state"] = df["gripper_state"].shift(-1) 354 | df.loc[df.index[-1], "desired_gripper_state"] = df.loc[ 355 | df.index[-2], "gripper_state" 356 | ] 357 | df.to_csv(file, index=False) 358 | 359 | def save_only_videos(base_folder_path): 360 | base_folder_path = os.path.join(base_folder_path, "videos") 361 | # Iterate over each camera folder 362 | for cam in cam_indices: 363 | cam_folder = f"camera{cam}" 364 | full_folder_path = os.path.join(base_folder_path, cam_folder) 365 | 366 | # Check if the folder exists 367 | if os.path.exists(full_folder_path): 368 | # List all jpg files 369 | all_files = [ 370 | f for f in os.listdir(full_folder_path) if f.endswith(".jpg") 371 | ] 372 | 373 | # Sort files based on the floating-point number in their name 374 | sorted_files = sorted(all_files, key=get_timestamp_from_filename) 375 | 376 | # Write filenames to a temp file 377 | temp_list_filename = os.path.join(base_folder_path, "temp_list.txt") 378 | with open(temp_list_filename, "w") as f: 379 | for filename in sorted_files: 380 | f.write(f"file '{os.path.join(full_folder_path, filename)}'\n") 381 | 382 | # Use ffmpeg to convert sorted images to video 383 | output_video_path = os.path.join(base_folder_path, f"camera{cam}.mp4") 384 | cmd = [ 385 | "ffmpeg", 386 | "-f", 387 | "concat", 388 | "-safe", 389 | "0", 390 | "-i", 391 | temp_list_filename, 392 | "-framerate", 393 | "30", # assuming 24 fps, change if needed 394 | "-vcodec", 395 | "libx264", 396 | "-crf", 397 | "18", # quality, lower means better quality 
398 | "-pix_fmt", 399 | "yuv420p", 400 | output_video_path, 401 | ] 402 | try: 403 | subprocess.run(cmd, check=True) 404 | except Exception as e: 405 | print(f"EXCEPTION: {e}") 406 | input("Continue?") 407 | 408 | # Delete the temporary list file and the image folder 409 | os.remove(temp_list_filename) 410 | shutil.rmtree(full_folder_path) 411 | else: 412 | print(f"Folder {cam_folder} does not exist!") 413 | 414 | if not skip_cam_processing: 415 | save_only_videos(output_path) 416 | -------------------------------------------------------------------------------- /read_data/xarm_env_aa.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import cv2 3 | import random 4 | import numpy as np 5 | import pickle as pkl 6 | from pathlib import Path 7 | 8 | import torch 9 | import torchvision.transforms as transforms 10 | from torch.utils.data import IterableDataset 11 | from scipy.spatial.transform import Rotation as R 12 | 13 | 14 | def get_relative_action(actions, action_after_steps): 15 | """ 16 | Convert absolute axis angle actions to relative axis angle actions 17 | Action has both position and orientation. Convert to transformation matrix, get 18 | relative transformation matrix, convert back to axis angle 19 | """ 20 | 21 | relative_actions = [] 22 | for i in range(len(actions)): 23 | # Get relative transformation matrix 24 | # previous pose 25 | pos_prev = actions[i, :3] 26 | ori_prev = actions[i, 3:6] 27 | r_prev = R.from_rotvec(ori_prev).as_matrix() 28 | matrix_prev = np.eye(4) 29 | matrix_prev[:3, :3] = r_prev 30 | matrix_prev[:3, 3] = pos_prev 31 | # current pose 32 | next_idx = min(i + action_after_steps, len(actions) - 1) 33 | pos = actions[next_idx, :3] 34 | ori = actions[next_idx, 3:6] 35 | gripper = actions[next_idx, 6:] 36 | r = R.from_rotvec(ori).as_matrix() 37 | matrix = np.eye(4) 38 | matrix[:3, :3] = r 39 | matrix[:3, 3] = pos 40 | # relative transformation 41 | matrix_rel = np.linalg.inv(matrix_prev) @ matrix 42 | # relative pose 43 | # pos_rel = matrix_rel[:3, 3] 44 | pos_rel = pos - pos_prev 45 | r_rel = R.from_matrix(matrix_rel[:3, :3]).as_rotvec() 46 | # # compute relative rotation 47 | # r_prev = R.from_rotvec(ori_prev).as_matrix() 48 | # r = R.from_rotvec(ori).as_matrix() 49 | # r_rel = np.linalg.inv(r_prev) @ r 50 | # r_rel = R.from_matrix(r_rel).as_rotvec() 51 | # # compute relative translation 52 | # pos_rel = pos - pos_prev 53 | relative_actions.append(np.concatenate([pos_rel, r_rel, gripper])) 54 | # next_idx = min(i + action_after_steps, len(actions) - 1) 55 | # curr_pose, _ = actions[i, :6], actions[i, 6:] 56 | # next_pose, next_gripper = actions[next_idx, :6], actions[next_idx, 6:] 57 | 58 | # last action 59 | last_action = np.zeros_like(actions[-1]) 60 | last_action[-1] = actions[-1][-1] 61 | while len(relative_actions) < len(actions): 62 | relative_actions.append(last_action) 63 | return np.array(relative_actions, dtype=np.float32) 64 | 65 | 66 | def get_absolute_action(rel_actions, base_action): 67 | """ 68 | Convert relative axis angle actions to absolute axis angle actions 69 | """ 70 | actions = np.zeros((len(rel_actions) + 1, rel_actions.shape[-1])) 71 | actions[0] = base_action 72 | for i in range(1, len(rel_actions) + 1): 73 | # if i == 0: 74 | # actions.append(base_action) 75 | # continue 76 | # Get relative transformation matrix 77 | # previous pose 78 | pos_prev = actions[i - 1, :3] 79 | ori_prev = actions[i - 1, 3:6] 80 | r_prev = R.from_rotvec(ori_prev).as_matrix() 81 | 
matrix_prev = np.eye(4) 82 | matrix_prev[:3, :3] = r_prev 83 | matrix_prev[:3, 3] = pos_prev 84 | # relative pose 85 | pos_rel = rel_actions[i - 1, :3] 86 | r_rel = rel_actions[i - 1, 3:6] 87 | # compute relative transformation matrix 88 | matrix_rel = np.eye(4) 89 | matrix_rel[:3, :3] = R.from_rotvec(r_rel).as_matrix() 90 | matrix_rel[:3, 3] = pos_rel 91 | # compute absolute transformation matrix 92 | matrix = matrix_prev @ matrix_rel 93 | # absolute pose 94 | pos = matrix[:3, 3] 95 | # r = R.from_matrix(matrix[:3, :3]).as_rotvec() 96 | r = R.from_matrix(matrix[:3, :3]).as_euler("xyz") 97 | actions[i] = np.concatenate([pos, r, rel_actions[i - 1, 6:]]) 98 | return np.array(actions, dtype=np.float32) 99 | 100 | 101 | def get_quaternion_orientation(cartesian): 102 | """ 103 | Get quaternion orientation from axis angle representation 104 | """ 105 | new_cartesian = [] 106 | for i in range(len(cartesian)): 107 | pos = cartesian[i, :3] 108 | ori = cartesian[i, 3:] 109 | quat = R.from_rotvec(ori).as_quat() 110 | new_cartesian.append(np.concatenate([pos, quat], axis=-1)) 111 | return np.array(new_cartesian, dtype=np.float32) 112 | 113 | 114 | class BCDataset(IterableDataset): 115 | def __init__( 116 | self, 117 | path, 118 | tasks, 119 | num_demos_per_task, 120 | temporal_agg, 121 | num_queries, 122 | img_size, 123 | action_after_steps, 124 | store_actions, 125 | pixel_keys, 126 | aux_keys, 127 | subsample, 128 | skip_first_n, 129 | relative_actions, 130 | random_mask_proprio, 131 | sensor_params, 132 | ): 133 | self._img_size = img_size 134 | self._action_after_steps = action_after_steps 135 | self._store_actions = store_actions 136 | self._pixel_keys = pixel_keys 137 | self._aux_keys = aux_keys 138 | self._random_mask_proprio = random_mask_proprio 139 | 140 | self._sensor_type = sensor_params.sensor_type 141 | self._subtract_sensor_baseline = sensor_params.subtract_sensor_baseline 142 | if self._sensor_type == "digit": 143 | use_anyskin_data = False 144 | elif self._sensor_type == "reskin": 145 | use_anyskin_data = True 146 | else: 147 | assert self._sensor_type == None 148 | 149 | self._num_anyskin_sensors = 2 150 | 151 | # temporal aggregation 152 | self._temporal_agg = temporal_agg 153 | self._num_queries = num_queries 154 | 155 | # get data paths 156 | self._paths = [] 157 | self._paths.extend([Path(path) / f"{task}.pkl" for task in tasks]) 158 | 159 | paths = {} 160 | idx = 0 161 | for path in self._paths: 162 | paths[idx] = path 163 | idx += 1 164 | del self._paths 165 | self._paths = paths 166 | 167 | # store actions 168 | if self._store_actions: 169 | self.actions = [] 170 | 171 | # read data 172 | self._episodes = {} 173 | self._max_episode_len = 0 174 | self._max_state_dim = 7 175 | self._num_samples = 0 176 | min_stat, max_stat = None, None 177 | min_sensor_stat, max_sensor_stat = None, None 178 | digit_mean_stat, digit_std_stat = defaultdict(list), defaultdict(list) 179 | min_sensor_diff_stat, max_sensor_diff_stat = None, None 180 | min_act, max_act = None, None 181 | self.prob = [] 182 | sensor_states = [] 183 | for _path_idx in self._paths: 184 | print(f"Loading {str(self._paths[_path_idx])}") 185 | # Add to prob 186 | if "fridge" in str(self._paths[_path_idx]): 187 | # self.prob.append(25.0/11.0) 188 | self.prob.append(22.0 / 9.0) 189 | else: 190 | self.prob.append(1) 191 | # read 192 | data = pkl.load(open(str(self._paths[_path_idx]), "rb")) 193 | observations = data["observations"] 194 | # store 195 | self._episodes[_path_idx] = [] 196 | for i in range(min(num_demos_per_task, 
len(observations))): 197 | # compute actions 198 | # absolute actions 199 | actions = np.concatenate( 200 | [ 201 | observations[i]["cartesian_states"], 202 | observations[i]["gripper_states"][:, None], 203 | ], 204 | axis=1, 205 | ) 206 | if len(actions) == 0: 207 | continue 208 | # skip first n 209 | if skip_first_n is not None: 210 | for key in observations[i].keys(): 211 | observations[i][key] = observations[i][key][skip_first_n:] 212 | actions = actions[skip_first_n:] 213 | # subsample 214 | if subsample is not None: 215 | for key in observations[i].keys(): 216 | observations[i][key] = observations[i][key][::subsample] 217 | actions = actions[::subsample] 218 | # action after steps 219 | if relative_actions: 220 | actions = get_relative_action(actions, self._action_after_steps) 221 | else: 222 | actions = actions[self._action_after_steps :] 223 | # Convert cartesian states to quaternion orientation 224 | observations[i]["cartesian_states"] = get_quaternion_orientation( 225 | observations[i]["cartesian_states"] 226 | ) 227 | if use_anyskin_data: 228 | try: 229 | sensor_baseline = np.median( 230 | observations[i]["sensor_states"][:5], axis=0, keepdims=True 231 | ) 232 | if self._subtract_sensor_baseline: 233 | observations[i]["sensor_states"] = ( 234 | observations[i]["sensor_states"] - sensor_baseline 235 | ) 236 | if max_sensor_stat is None: 237 | max_sensor_stat = np.max( 238 | observations[i]["sensor_states"], axis=0 239 | ) 240 | min_sensor_stat = np.min( 241 | observations[i]["sensor_states"], axis=0 242 | ) 243 | else: 244 | max_sensor_stat = np.maximum( 245 | max_sensor_stat, 246 | np.max(observations[i]["sensor_states"], axis=0), 247 | ) 248 | min_sensor_stat = np.minimum( 249 | min_sensor_stat, 250 | np.min(observations[i]["sensor_states"], axis=0), 251 | ) 252 | for sensor_idx in range(self._num_anyskin_sensors): 253 | observations[i][ 254 | f"sensor{sensor_idx}_states" 255 | ] = observations[i]["sensor_states"][ 256 | ..., sensor_idx * 15 : (sensor_idx + 1) * 15 257 | ] 258 | 259 | except KeyError: 260 | print("WARN: Sensor data not found.") 261 | use_anyskin_data = False 262 | elif self._sensor_type == "digit": 263 | for key in self._aux_keys: 264 | if key.startswith("digit"): 265 | observations[i][key] = ( 266 | observations[i][key].astype(np.float32) / 255.0 267 | ) 268 | if self._subtract_sensor_baseline: 269 | sensor_baseline = np.median( 270 | observations[i][key][:5], axis=0, keepdims=True 271 | ) # .astype(observations[i][key].dtype) 272 | observations[i][key] = ( 273 | observations[i][key] - sensor_baseline 274 | ) 275 | delta_filter = np.abs(observations[i][key]) > ( 276 | 5.0 / 255.0 277 | ) 278 | digit_std_stat[key].append( 279 | observations[i][key][delta_filter] 280 | ) 281 | else: 282 | pass 283 | 284 | for key in observations[i].keys(): 285 | observations[i][key] = np.concatenate( 286 | [ 287 | [observations[i][key][0]], 288 | observations[i][key], 289 | ], 290 | axis=0, 291 | ) 292 | 293 | remaining_actions = actions[0] 294 | if relative_actions: 295 | pos = remaining_actions[:-1] 296 | ori_gripper = remaining_actions[-1:] 297 | remaining_actions = np.concatenate( 298 | [np.zeros_like(pos), ori_gripper] 299 | ) 300 | actions = np.concatenate( 301 | [ 302 | [remaining_actions], 303 | actions, 304 | ], 305 | axis=0, 306 | ) 307 | # store 308 | episode = dict( 309 | observation=observations[i], 310 | action=actions, 311 | # task_emb=task_emb, 312 | ) 313 | self._episodes[_path_idx].append(episode) 314 | self._max_episode_len = max( 315 | self._max_episode_len, 
316 | ( 317 | len(observations[i]) 318 | if not isinstance(observations[i], dict) 319 | else len(observations[i][self._pixel_keys[0]]) 320 | ), 321 | ) 322 | self._num_samples += len(observations[i][self._pixel_keys[0]]) 323 | 324 | # max, min action 325 | if min_act is None: 326 | min_act = np.min(actions, axis=0) 327 | max_act = np.max(actions, axis=0) 328 | else: 329 | min_act = np.minimum(min_act, np.min(actions, axis=0)) 330 | max_act = np.maximum(max_act, np.max(actions, axis=0)) 331 | 332 | # store actions 333 | if self._store_actions: 334 | self.actions.append(actions) 335 | 336 | # keep record of max and min stat 337 | max_cartesian = data["max_cartesian"] 338 | min_cartesian = data["min_cartesian"] 339 | max_cartesian = np.concatenate( 340 | [data["max_cartesian"][:3], [1] * 4] 341 | ) # for quaternion 342 | min_cartesian = np.concatenate( 343 | [data["min_cartesian"][:3], [-1] * 4] 344 | ) # for quaternion 345 | max_gripper = data["max_gripper"] 346 | min_gripper = data["min_gripper"] 347 | max_val = np.concatenate([max_cartesian, max_gripper[None]], axis=0) 348 | min_val = np.concatenate([min_cartesian, min_gripper[None]], axis=0) 349 | if max_stat is None: 350 | max_stat = max_val 351 | min_stat = min_val 352 | else: 353 | max_stat = np.maximum(max_stat, max_val) 354 | min_stat = np.minimum(min_stat, min_val) 355 | if use_anyskin_data: 356 | # If baseline is subtracted, use zero as shift and max as scale 357 | if self._subtract_sensor_baseline: 358 | max_sensor_stat = np.maximum( 359 | np.abs(max_sensor_stat), np.abs(min_sensor_stat) 360 | ) 361 | min_sensor_stat = np.zeros_like(max_sensor_stat) 362 | # If baseline isn't subtracted, use usual min and max values 363 | else: 364 | if max_sensor_stat is None: 365 | max_sensor_stat = data["max_sensor"] 366 | min_sensor_stat = data["min_sensor"] 367 | else: 368 | max_sensor_stat = np.maximum( 369 | max_sensor_stat, data["max_sensor"] 370 | ) 371 | min_sensor_stat = np.minimum( 372 | min_sensor_stat, data["min_sensor"] 373 | ) 374 | min_act[3:6], max_act[3:6] = 0, 1 ################################# 375 | self.stats = { 376 | "actions": { 377 | "min": min_act, # min_stat, 378 | "max": max_act, # max_stat, 379 | }, 380 | "proprioceptive": { 381 | "min": min_stat, 382 | "max": max_stat, 383 | }, 384 | } 385 | if use_anyskin_data: 386 | for sensor_idx in range(self._num_anyskin_sensors): 387 | sensor_mask = np.zeros_like(min_sensor_stat, dtype=bool) 388 | sensor_mask[sensor_idx * 15 : (sensor_idx + 1) * 15] = True 389 | self.stats[f"sensor{sensor_idx}"] = { 390 | "min": min_sensor_stat[sensor_mask], 391 | "max": max_sensor_stat[sensor_mask], 392 | } 393 | 394 | if not self._subtract_sensor_baseline: 395 | raise NotImplementedError( 396 | "Normalization not implemented without baseline subtraction" 397 | ) 398 | for key in self.stats: 399 | if key.startswith("sensor"): 400 | sensor_states = np.concatenate( 401 | [ 402 | observations[i][f"{key}_states"] 403 | for i in range(len(observations)) 404 | ], 405 | axis=0, 406 | ) 407 | sensor_std = ( 408 | np.std(sensor_states, axis=0).reshape((5, 3)).max(axis=0) 409 | ) 410 | sensor_std[:2] = sensor_std[:2].max() 411 | sensor_std = np.clip(sensor_std * 3, a_min=100, a_max=None) 412 | # max_xyz = np.clip(max_xyz, a_min=400, a_max=None) 413 | self.stats[key]["max"] = np.tile( 414 | sensor_std, int(self.stats[key]["max"].shape[0] / 3) 415 | ) 416 | elif self._sensor_type == "digit": 417 | if self._subtract_sensor_baseline: 418 | shared_mean, shared_std = None, None 419 | for key in 
digit_std_stat: 420 | digit_std_stat[key] = [ 421 | 3 * np.concatenate(digit_std_stat[key], axis=0).std() 422 | ] * 3 423 | digit_mean_stat[key] = [0.0] * 3 424 | self.stats[key] = { 425 | "mean": np.array(digit_mean_stat[key])[:, None, None], 426 | "std": np.array(digit_std_stat[key])[:, None, None], 427 | } 428 | if shared_std is None: 429 | shared_std = digit_std_stat[key] 430 | shared_mean = digit_mean_stat[key] 431 | else: 432 | shared_std = np.maximum(shared_std, digit_std_stat[key]) 433 | shared_mean = np.minimum(shared_mean, digit_mean_stat[key]) 434 | else: 435 | shared_mean = [0.485, 0.456, 0.406] 436 | shared_std = [0.229, 0.224, 0.225] 437 | self.stats["digit"] = { 438 | "mean": np.array(shared_mean)[:, None, None], 439 | "std": np.array(shared_std)[:, None, None], 440 | } 441 | self.digit_aug = transforms.Compose( 442 | [ 443 | transforms.ToTensor(), 444 | ] 445 | ) 446 | # augmentation 447 | self.aug = transforms.Compose( 448 | [ 449 | transforms.ToPILImage(), 450 | transforms.RandomCrop(self._img_size, padding=4), 451 | transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2), 452 | transforms.ToTensor(), 453 | ] 454 | ) 455 | 456 | # Samples from envs 457 | self.envs_till_idx = len(self._episodes) 458 | 459 | self.prob = np.array(self.prob) / np.sum(self.prob) 460 | 461 | def preprocess(self, key, x): 462 | if key.startswith("digit"): 463 | return (x - self.stats["digit"]["mean"]) / ( 464 | self.stats["digit"]["std"] + 1e-5 465 | ) 466 | return (x - self.stats[key]["min"]) / ( 467 | self.stats[key]["max"] - self.stats[key]["min"] + 1e-5 468 | ) 469 | 470 | def _sample_episode(self, env_idx=None): 471 | idx = random.randint(0, self.envs_till_idx - 1) if env_idx is None else env_idx 472 | 473 | # sample idx with probability 474 | idx = np.random.choice(list(self._episodes.keys()), p=self.prob) 475 | 476 | episode = random.choice(self._episodes[idx]) 477 | return (episode, idx) if env_idx is None else episode 478 | 479 | def _sample(self): 480 | episodes, env_idx = self._sample_episode() 481 | observations = episodes["observation"] 482 | actions = episodes["action"] 483 | sample_idx = np.random.randint(1, len(observations[self._pixel_keys[0]]) - 1) 484 | # Sample obs, action 485 | sampled_pixel = {} 486 | for key in self._pixel_keys: 487 | sampled_pixel[key] = observations[key][-(sample_idx + 1) : -sample_idx] 488 | sampled_pixel[key] = torch.stack( 489 | [ 490 | self.aug(sampled_pixel[key][i]) 491 | for i in range(len(sampled_pixel[key])) 492 | ] 493 | ) 494 | sampled_state = {} 495 | sampled_state = {} 496 | 497 | sampled_state = {} 498 | sampled_state["proprioceptive"] = np.concatenate( 499 | [ 500 | observations["cartesian_states"][-(sample_idx + 1) : -sample_idx], 501 | observations["gripper_states"][-(sample_idx + 1) : -sample_idx][ 502 | :, None 503 | ], 504 | ], 505 | axis=1, 506 | ) 507 | 508 | if self._random_mask_proprio and np.random.rand() < 0.5: 509 | sampled_state["proprioceptive"] = ( 510 | np.ones_like(sampled_state["proprioceptive"]) 511 | * self.stats["proprioceptive"]["min"] 512 | ) 513 | if self._sensor_type == "reskin": 514 | try: 515 | for sensor_idx in range(self._num_anyskin_sensors): 516 | skey = f"sensor{sensor_idx}" 517 | sampled_state[f"{skey}"] = observations[f"{skey}_states"][ 518 | -(sample_idx + 1) : -sample_idx 519 | ] 520 | except KeyError: 521 | pass 522 | elif self._sensor_type == "digit": 523 | try: 524 | for sensor_idx in range(self._num_anyskin_sensors): 525 | key = f"digit{80 + sensor_idx}" 526 | sampled_state[key] = 
observations[key][ 527 | -(sample_idx + 1) : -sample_idx 528 | ] 529 | sampled_state[key] = torch.stack( 530 | [ 531 | self.digit_aug(sampled_state[key][i]) 532 | # self.aug(sampled_state[key][i]) 533 | for i in range(len(sampled_state[key])) 534 | ] 535 | ) 536 | except KeyError as e: 537 | pass 538 | 539 | if self._temporal_agg: 540 | # arrange sampled action to be of shape (1, num_queries, action_dim) 541 | sampled_action = np.zeros((1, self._num_queries, actions.shape[-1])) 542 | num_actions = 1 + self._num_queries - 1 543 | act = np.zeros((num_actions, actions.shape[-1])) 544 | if num_actions - sample_idx < 0: 545 | act[:num_actions] = actions[-(sample_idx) : -sample_idx + num_actions] 546 | else: 547 | act[:sample_idx] = actions[-sample_idx:] 548 | act[sample_idx:] = actions[-1] 549 | sampled_action = np.lib.stride_tricks.sliding_window_view( 550 | act, (self._num_queries, actions.shape[-1]) 551 | ) 552 | sampled_action = sampled_action[:, 0] 553 | else: 554 | sampled_action = actions[-(sample_idx + 1) : -sample_idx] 555 | 556 | return_dict = {} 557 | for key in self._pixel_keys: 558 | return_dict[key] = sampled_pixel[key] 559 | for key in self._aux_keys: 560 | return_dict[key] = self.preprocess(key, sampled_state[key]) 561 | return_dict["actions"] = self.preprocess("actions", sampled_action) 562 | return_dict["actions"] = self.preprocess("actions", sampled_action) 563 | return_dict["actions"] = self.preprocess("actions", sampled_action) 564 | return return_dict 565 | 566 | def __iter__(self): 567 | while True: 568 | yield self._sample() 569 | 570 | def __len__(self): 571 | return self._num_samples 572 | -------------------------------------------------------------------------------- /agent/bc.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import hydra 3 | import numpy as np 4 | from collections import deque 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from torchvision import transforms as T 10 | 11 | from agent.networks.rgb_modules import BaseEncoder, ResnetEncoder 12 | import utils 13 | 14 | from agent.networks.policy_head import DeterministicHead 15 | from agent.networks.gpt import GPT, GPTConfig, CrossAttention 16 | from agent.networks.mlp import MLP 17 | 18 | 19 | class Actor(nn.Module): 20 | def __init__( 21 | self, 22 | repr_dim, 23 | act_dim, 24 | hidden_dim, 25 | policy_type="gpt", 26 | policy_head="deterministic", 27 | num_feat_per_step=1, 28 | ): 29 | super().__init__() 30 | 31 | self._policy_type = policy_type 32 | self._policy_head = policy_head 33 | self._repr_dim = repr_dim 34 | self._act_dim = act_dim 35 | self._num_feat_per_step = num_feat_per_step 36 | 37 | self._action_token = nn.Parameter(torch.randn(1, 1, 1, repr_dim)) 38 | 39 | # GPT model 40 | if policy_type == "gpt": 41 | self._policy = GPT( 42 | GPTConfig( 43 | block_size=65, # 50, # 51, # 50, 44 | input_dim=repr_dim, 45 | output_dim=hidden_dim, 46 | n_layer=8, 47 | n_head=4, 48 | n_embd=hidden_dim, 49 | dropout=0.1, # 0.6, #0.1, 50 | ) 51 | ) 52 | else: 53 | raise NotImplementedError 54 | self._action_head = DeterministicHead( 55 | hidden_dim, self._act_dim, hidden_size=hidden_dim, num_layers=2 56 | ) 57 | self.apply(utils.weight_init) 58 | 59 | def forward( 60 | self, 61 | obs, 62 | num_prompt_feats, 63 | stddev, 64 | action=None, 65 | cluster_centers=None, 66 | mask=None, 67 | ): 68 | B, T, D = obs.shape 69 | if self._policy_type == "mlp": 70 | if T * D < self._repr_dim: 71 | gt_num_time_steps = ( 72 | self._repr_dim // D - 
num_prompt_feats 73 | ) // self._num_feat_per_step 74 | num_repeat = ( 75 | gt_num_time_steps 76 | - (T - num_prompt_feats) // self._num_feat_per_step 77 | ) 78 | initial_obs = obs[ 79 | :, num_prompt_feats : num_prompt_feats + self._num_feat_per_step 80 | ] 81 | initial_obs = initial_obs.repeat(1, num_repeat, 1) 82 | obs = torch.cat( 83 | [obs[:, :num_prompt_feats], initial_obs, obs[:, num_prompt_feats:]], 84 | dim=1, 85 | ) 86 | B, T, D = obs.shape 87 | obs = obs.view(B, 1, T * D) 88 | features = self._policy(obs) 89 | elif self._policy_type == "gpt": 90 | # insert action token at each self._num_feat_per_step interval 91 | prompt = obs[:, :num_prompt_feats] 92 | obs = obs[:, num_prompt_feats:] 93 | obs = obs.view(B, -1, self._num_feat_per_step, obs.shape[-1]) 94 | action_token = self._action_token.repeat(B, obs.shape[1], 1, 1) 95 | obs = torch.cat([obs, action_token], dim=-2).view(B, -1, D) 96 | obs = torch.cat([prompt, obs], dim=1) 97 | 98 | if mask is not None: 99 | mask = torch.cat([mask, torch.ones(B, 1).to(mask.device)], dim=1) 100 | mask = mask.view(B, -1, 1, self._num_feat_per_step + 1) 101 | base_mask = torch.ones( 102 | B, 103 | mask.shape[1], 104 | self._num_feat_per_step + 1, 105 | self._num_feat_per_step + 1, 106 | ).to(mask.device) 107 | base_mask[:, :, -1:] = mask 108 | 109 | # get action features 110 | features = self._policy(obs, mask=base_mask if mask is not None else None) 111 | features = features[:, num_prompt_feats:] 112 | num_feat_per_step = self._num_feat_per_step + 1 # +1 for action token 113 | features = features[:, num_feat_per_step - 1 :: num_feat_per_step] 114 | 115 | # action head 116 | pred_action = self._action_head( 117 | features, 118 | stddev, 119 | **{"cluster_centers": cluster_centers, "action_seq": action}, 120 | ) 121 | 122 | if action is None: 123 | return pred_action 124 | else: 125 | loss = self._action_head.loss_fn( 126 | pred_action, 127 | action, 128 | reduction="mean", 129 | **{"cluster_centers": cluster_centers}, 130 | ) 131 | return pred_action, loss[0] if isinstance(loss, tuple) else loss 132 | 133 | 134 | class BCAgent: 135 | def __init__( 136 | self, 137 | obs_shape, 138 | action_shape, 139 | device, 140 | lr, 141 | hidden_dim, 142 | stddev_schedule, 143 | stddev_clip, 144 | use_tb, 145 | augment, 146 | encoder_type, 147 | policy_type, 148 | policy_head, 149 | pixel_keys, 150 | aux_keys, 151 | use_aux_inputs, 152 | train_encoder, 153 | norm, 154 | separate_encoders, 155 | temporal_agg, 156 | max_episode_len, 157 | num_queries, 158 | use_actions, 159 | ): 160 | self.device = device 161 | self.lr = lr 162 | self.hidden_dim = hidden_dim 163 | self.stddev_schedule = stddev_schedule 164 | self.stddev_clip = stddev_clip 165 | self.use_tb = use_tb 166 | self.augment = augment 167 | self.encoder_type = encoder_type 168 | self.policy_head = policy_head 169 | self.use_aux_inputs = use_aux_inputs 170 | self.norm = norm 171 | self.train_encoder = train_encoder 172 | self.separate_encoders = separate_encoders 173 | self.use_actions = use_actions # only for the prompt 174 | 175 | # actor parameters 176 | self._act_dim = action_shape[0] 177 | 178 | # keys 179 | self.aux_keys = aux_keys 180 | self.pixel_keys = pixel_keys 181 | 182 | # action chunking params 183 | self.temporal_agg = temporal_agg 184 | self.max_episode_len = max_episode_len 185 | self.num_queries = num_queries if self.temporal_agg else 1 186 | 187 | # number of inputs per time step 188 | num_feat_per_step = len(self.pixel_keys) 189 | if use_aux_inputs: 190 | num_feat_per_step += 
len(self.aux_keys) 191 | 192 | # observation params 193 | if use_aux_inputs: 194 | aux_shape = {key: obs_shape[key] for key in self.aux_keys} 195 | obs_shape = obs_shape[self.pixel_keys[0]] 196 | 197 | # Track model size 198 | model_size = 0 199 | 200 | # encoder 201 | if self.separate_encoders: 202 | self.encoder = {} 203 | if self.encoder_type == "base": 204 | if self.separate_encoders: 205 | for key in self.pixel_keys: 206 | self.encoder[key] = BaseEncoder(obs_shape).to(device) 207 | self.repr_dim = self.encoder[key].repr_dim 208 | model_size += sum( 209 | p.numel() 210 | for p in self.encoder[key].parameters() 211 | if p.requires_grad 212 | ) 213 | else: 214 | self.encoder = BaseEncoder(obs_shape).to(device) 215 | self.repr_dim = self.encoder.repr_dim 216 | model_size += sum( 217 | p.numel() for p in self.encoder.parameters() if p.requires_grad 218 | ) 219 | elif self.encoder_type == "resnet": 220 | self.repr_dim = 512 221 | enc_fn = lambda: ResnetEncoder( 222 | obs_shape, 223 | 512, 224 | cond_dim=None, 225 | cond_fusion="none", 226 | ).to(device) 227 | if self.separate_encoders: 228 | for key in self.pixel_keys: 229 | self.encoder[key] = enc_fn() 230 | model_size += sum( 231 | p.numel() 232 | for p in self.encoder[key].parameters() 233 | if p.requires_grad 234 | ) 235 | else: 236 | self.encoder = enc_fn() 237 | model_size += sum( 238 | p.numel() for p in self.encoder.parameters() if p.requires_grad 239 | ) 240 | 241 | # projector for proprioceptive features 242 | if use_aux_inputs: 243 | self.aux_projector = nn.ModuleDict() 244 | for key in self.aux_keys: 245 | if key.startswith("digit"): 246 | self.aux_projector[key] = ResnetEncoder( 247 | obs_shape, 248 | 512, 249 | cond_dim=None, 250 | cond_fusion="none", 251 | ).to(device) 252 | else: 253 | self.aux_projector[key] = MLP( 254 | aux_shape[key][0], 255 | hidden_channels=[self.repr_dim, self.repr_dim], 256 | ).to(device) 257 | self.aux_projector[key].apply(utils.weight_init) 258 | model_size += sum( 259 | p.numel() 260 | for p in self.aux_projector[key].parameters() 261 | if p.requires_grad 262 | ) 263 | 264 | # projector for actions 265 | if self.use_actions: 266 | self.action_projector = MLP( 267 | self._act_dim, hidden_channels=[self.repr_dim, self.repr_dim] 268 | ).to(device) 269 | self.action_projector.apply(utils.weight_init) 270 | model_size += sum( 271 | p.numel() for p in self.action_projector.parameters() if p.requires_grad 272 | ) 273 | 274 | # actor 275 | action_dim = ( 276 | self._act_dim * self.num_queries if self.temporal_agg else self._act_dim 277 | ) 278 | self.actor = Actor( 279 | self.repr_dim, 280 | action_dim, 281 | hidden_dim, 282 | policy_type, 283 | policy_head, 284 | num_feat_per_step, 285 | ).to(device) 286 | model_size += sum(p.numel() for p in self.actor.parameters() if p.requires_grad) 287 | 288 | print(f"Total number of parameters in the model: {model_size}") 289 | 290 | # optimizers 291 | # encoder 292 | if self.train_encoder: 293 | if self.separate_encoders: 294 | params = [] 295 | for key in self.pixel_keys: 296 | params += list(self.encoder[key].parameters()) 297 | else: 298 | params = list(self.encoder.parameters()) 299 | self.encoder_opt = torch.optim.AdamW(params, lr=lr, weight_decay=1e-4) 300 | # self.encoder_scheduler = torch.optim.lr_scheduler.StepLR( 301 | # self.encoder_opt, step_size=15000, gamma=0.1 302 | # ) 303 | # proprio 304 | if self.use_aux_inputs: 305 | self.aux_opt = torch.optim.AdamW( 306 | self.aux_projector.parameters(), lr=lr, weight_decay=1e-4 307 | ) 308 | # 
self.proprio_scheduler = torch.optim.lr_scheduler.StepLR( 309 | # self.proprio_opt, step_size=15000, gamma=0.1 310 | # ) 311 | 312 | # action projector 313 | if self.use_actions: 314 | self.action_opt = torch.optim.AdamW( 315 | self.action_projector.parameters(), lr=lr, weight_decay=1e-4 316 | ) 317 | # self.action_scheduler = torch.optim.lr_scheduler.StepLR( 318 | # self.action_opt, step_size=15000, gamma=0.1 319 | # ) 320 | # actor 321 | self.actor_opt = torch.optim.AdamW( 322 | self.actor.parameters(), lr=lr, weight_decay=1e-4 323 | ) 324 | # self.actor_scheduler = torch.optim.lr_scheduler.StepLR( 325 | # self.actor_opt, step_size=15000, gamma=0.1 326 | # ) 327 | 328 | # augmentations 329 | if self.norm: 330 | if self.encoder_type == "small": 331 | MEAN = torch.tensor([0.0, 0.0, 0.0]) 332 | STD = torch.tensor([1.0, 1.0, 1.0]) 333 | elif self.encoder_type == "resnet" or self.norm: 334 | MEAN = torch.tensor([0.485, 0.456, 0.406]) 335 | STD = torch.tensor([0.229, 0.224, 0.225]) 336 | self.customAug = T.Compose([T.Normalize(mean=MEAN, std=STD)]) 337 | 338 | # data augmentation 339 | if self.augment: 340 | # self.aug = utils.RandomShiftsAug(pad=4) 341 | self.test_aug = T.Compose([T.ToPILImage(), T.ToTensor()]) 342 | self.digit_aug = T.Compose([T.ToTensor()]) 343 | 344 | self.train() 345 | self.buffer_reset() 346 | 347 | def __repr__(self): 348 | return "bc" 349 | 350 | def train(self, training=True): 351 | self.training = training 352 | if training: 353 | if self.separate_encoders: 354 | for key in self.pixel_keys: 355 | if self.train_encoder: 356 | self.encoder[key].train(training) 357 | else: 358 | self.encoder[key].eval() 359 | else: 360 | if self.train_encoder: 361 | self.encoder.train(training) 362 | else: 363 | self.encoder.eval() 364 | if self.use_aux_inputs: 365 | self.aux_projector.train(training) 366 | if self.use_actions: 367 | self.action_projector.train(training) 368 | self.actor.train(training) 369 | else: 370 | if self.separate_encoders: 371 | for key in self.pixel_keys: 372 | self.encoder[key].eval() 373 | else: 374 | self.encoder.eval() 375 | if self.use_aux_inputs: 376 | self.aux_projector.eval() 377 | if self.use_actions: 378 | self.action_projector.eval() 379 | self.actor.eval() 380 | 381 | def buffer_reset(self): 382 | self.observation_buffer = {} 383 | for key in self.pixel_keys: 384 | self.observation_buffer[key] = deque(maxlen=1) 385 | if self.use_aux_inputs: 386 | self.aux_buffer = {} 387 | for key in self.aux_keys: 388 | self.aux_buffer[key] = deque(maxlen=1) 389 | 390 | # temporal aggregation 391 | if self.temporal_agg: 392 | self.all_time_actions = torch.zeros( 393 | [ 394 | self.max_episode_len, 395 | self.max_episode_len + self.num_queries, 396 | self._act_dim, 397 | ] 398 | ).to(self.device) 399 | 400 | def clear_buffers(self): 401 | del self.observation_buffer 402 | if self.use_aux_inputs: 403 | del self.aux_buffer 404 | if self.temporal_agg: 405 | del self.all_time_actions 406 | 407 | def reinit_optimizers(self): 408 | if self.train_encoder: 409 | if self.separate_encoders: 410 | params = [] 411 | for key in self.pixel_keys: 412 | params += list(self.encoder[key].parameters()) 413 | else: 414 | params = list(self.encoder.parameters()) 415 | self.encoder_opt = torch.optim.AdamW(params, lr=self.lr, weight_decay=1e-4) 416 | try: 417 | self.aux_cond_opt = torch.optim.AdamW( 418 | self.aux_cond_cross_attn.parameters(), 419 | lr=self.lr, 420 | weight_decay=1e-4, 421 | ) 422 | except AttributeError: 423 | print("Not optimizing aux_cond_cross_attn") 424 | if 
self.use_aux_inputs: 425 | self.aux_opt = torch.optim.AdamW( 426 | self.aux_projector.parameters(), lr=self.lr, weight_decay=1e-4 427 | ) 428 | if self.use_actions: 429 | self.action_opt = torch.optim.AdamW( 430 | self.action_projector.parameters(), lr=self.lr, weight_decay=1e-4 431 | ) 432 | params = list(self.actor.parameters()) 433 | self.actor_opt = torch.optim.AdamW( 434 | self.actor.parameters(), lr=self.lr, weight_decay=1e-4 435 | ) 436 | 437 | def act(self, obs, prompt, norm_stats, step, global_step, eval_mode=False): 438 | if norm_stats is not None: 439 | 440 | def pre_process(aux_key, s_qpos): 441 | try: 442 | return (s_qpos - norm_stats[aux_key]["min"]) / ( 443 | norm_stats[aux_key]["max"] - norm_stats[aux_key]["min"] + 1e-5 444 | ) 445 | except KeyError: 446 | return (s_qpos - norm_stats[aux_key]["mean"]) / ( 447 | norm_stats[aux_key]["std"] + 1e-5 448 | ) 449 | 450 | # pre_process = lambda aux_key, s_qpos: ( 451 | # s_qpos - norm_stats[aux_key]["min"] 452 | # ) / (norm_stats[aux_key]["max"] - norm_stats[aux_key]["min"] + 1e-5) 453 | post_process = ( 454 | lambda a: a 455 | * (norm_stats["actions"]["max"] - norm_stats["actions"]["min"]) 456 | + norm_stats["actions"]["min"] 457 | ) 458 | 459 | # lang projection 460 | lang_features = None 461 | 462 | # add to buffer 463 | features = [] 464 | aux_features = [] 465 | aux_cond_features = [] 466 | if self.use_aux_inputs: 467 | # TODO: Add conditioning here 468 | for key in self.aux_keys: 469 | obs[key] = pre_process(key, obs[key]) 470 | if key.startswith("digit"): 471 | self.aux_buffer[key].append( 472 | self.digit_aug(obs[key].transpose(1, 2, 0)).numpy() 473 | ) 474 | aux_feat = torch.as_tensor( 475 | np.array(self.aux_buffer[key]), device=self.device 476 | ).float() 477 | aux_feat = self.aux_projector[key](aux_feat)[None, :, :] 478 | else: 479 | self.aux_buffer[key].append(obs[key]) 480 | aux_feat = torch.as_tensor( 481 | np.array(self.aux_buffer[key]), device=self.device 482 | ).float() 483 | aux_feat = self.aux_projector[key](aux_feat[None, :, :]) 484 | aux_features.append(aux_feat[0]) 485 | 486 | for key in self.pixel_keys: 487 | self.observation_buffer[key].append( 488 | self.test_aug(obs[key].transpose(1, 2, 0)).numpy() 489 | ) 490 | pixels = torch.as_tensor( 491 | np.array(self.observation_buffer[key]), device=self.device 492 | ).float() 493 | pixels = self.customAug(pixels) if self.norm else pixels 494 | # encoder 495 | pixels = ( 496 | self.encoder[key](pixels) 497 | if self.separate_encoders 498 | else self.encoder(pixels) 499 | ) 500 | features.append(pixels) 501 | 502 | features.extend(aux_features) 503 | features = torch.cat(features, dim=-1).view(-1, self.repr_dim) 504 | 505 | stddev = utils.schedule(self.stddev_schedule, global_step) 506 | action = self.actor(features.unsqueeze(0), 0, stddev) 507 | 508 | if eval_mode: 509 | action = action.mean 510 | else: 511 | action = action.sample() 512 | if self.temporal_agg: 513 | action = action.view(-1, self.num_queries, self._act_dim) 514 | self.all_time_actions[[step], step : step + self.num_queries] = action[-1:] 515 | actions_for_curr_step = self.all_time_actions[:, step] 516 | actions_populated = torch.all(actions_for_curr_step != 0, axis=1) 517 | actions_for_curr_step = actions_for_curr_step[actions_populated] 518 | k = 0.01 519 | exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) 520 | exp_weights = exp_weights / exp_weights.sum() 521 | exp_weights = torch.from_numpy(exp_weights).to(self.device).unsqueeze(dim=1) 522 | action = (actions_for_curr_step * 
exp_weights).sum(dim=0, keepdim=True) 523 | if norm_stats is not None: 524 | return post_process(action.cpu().numpy()[0]) 525 | return action.cpu().numpy()[0] 526 | else: 527 | if norm_stats is not None: 528 | return post_process(action.cpu().numpy()[0, -1, :]) 529 | return action.cpu().numpy()[0, -1, :] 530 | 531 | def update(self, expert_replay_iter, step): 532 | metrics = dict() 533 | batch = next(expert_replay_iter) 534 | data = utils.to_torch(batch, self.device) 535 | action = data["actions"].float() 536 | 537 | # features 538 | features = [] 539 | aux_features = [] 540 | if self.use_aux_inputs: 541 | for key in self.aux_keys: 542 | aux_feat = data[key].float() 543 | if "digit" in key: 544 | shape = aux_feat.shape 545 | # rearrange 546 | aux_feat = einops.rearrange(aux_feat, "b t c h w -> (b t) c h w") 547 | # augment 548 | # aux_feat = self.aug(aux_feat) if self.augment else aux_feat 549 | aux_feat = self.customAug(aux_feat) if self.norm else aux_feat 550 | aux_feat = self.aux_projector[key](aux_feat) 551 | aux_feat = einops.rearrange( 552 | aux_feat, "(b t) d -> b t d", t=shape[1] 553 | ) 554 | else: 555 | aux_feat = self.aux_projector[key](aux_feat) 556 | aux_features.append(aux_feat) 557 | for key in self.pixel_keys: 558 | pixel = data[key].float() 559 | shape = pixel.shape 560 | # rearrange 561 | pixel = einops.rearrange(pixel, "b t c h w -> (b t) c h w") 562 | # augment 563 | # pixel = self.aug(pixel) if self.augment else pixel 564 | pixel = self.customAug(pixel) if self.norm else pixel 565 | # encode 566 | if self.train_encoder: 567 | pixel = ( 568 | self.encoder[key](pixel) 569 | if self.separate_encoders 570 | else self.encoder(pixel) 571 | ) 572 | else: 573 | with torch.no_grad(): 574 | pixel = ( 575 | self.encoder[key](pixel) 576 | if self.separate_encoders 577 | else self.encoder(pixel) 578 | ) 579 | pixel = einops.rearrange(pixel, "(b t) d -> b t d", t=shape[1]) 580 | features.append(pixel) 581 | features.extend(aux_features) 582 | # concatenate 583 | features = torch.cat(features, dim=-1).view( 584 | action.shape[0], -1, self.repr_dim 585 | ) # (B, T * num_feat_per_step, D) 586 | 587 | # rearrange action 588 | if self.temporal_agg: 589 | action = einops.rearrange(action, "b t1 t2 d -> b t1 (t2 d)") 590 | 591 | # actor loss 592 | stddev = utils.schedule(self.stddev_schedule, step) 593 | _, actor_loss = self.actor( 594 | features, 595 | 0, 596 | stddev, 597 | action, 598 | mask=None, 599 | ) 600 | if self.train_encoder: 601 | self.encoder_opt.zero_grad(set_to_none=True) 602 | 603 | if self.use_aux_inputs: 604 | self.aux_opt.zero_grad(set_to_none=True) 605 | self.actor_opt.zero_grad(set_to_none=True) 606 | actor_loss["actor_loss"].backward() 607 | if self.train_encoder: 608 | self.encoder_opt.step() 609 | try: 610 | self.aux_cond_opt.step() 611 | except AttributeError: 612 | pass 613 | if self.use_aux_inputs: 614 | self.aux_opt.step() 615 | if self.use_actions: 616 | self.action_opt.step() 617 | self.actor_opt.step() 618 | 619 | if self.use_tb: 620 | for key, value in actor_loss.items(): 621 | metrics[key] = value.item() 622 | 623 | return metrics 624 | 625 | def save_snapshot(self): 626 | model_keys = ["actor", "encoder"] 627 | opt_keys = ["actor_opt"] 628 | if self.train_encoder: 629 | opt_keys += ["encoder_opt"] 630 | if self.use_aux_inputs: 631 | model_keys += ["aux_projector"] 632 | opt_keys += ["aux_opt"] 633 | if self.use_actions: 634 | model_keys += ["action_projector"] 635 | opt_keys += ["action_opt"] 636 | # models 637 | payload = { 638 | k: 
self.__dict__[k].state_dict() for k in model_keys if k != "encoder" 639 | } 640 | if "encoder" in model_keys: 641 | if self.separate_encoders: 642 | for key in self.pixel_keys: 643 | payload[f"encoder_{key}"] = self.encoder[key].state_dict() 644 | else: 645 | payload["encoder"] = self.encoder.state_dict() 646 | # optimizers 647 | payload.update({k: self.__dict__[k] for k in opt_keys}) 648 | 649 | others = [ 650 | "use_aux_inputs", 651 | "aux_keys", 652 | "use_actions", 653 | "max_episode_len", 654 | ] 655 | payload.update({k: self.__dict__[k] for k in others}) 656 | return payload 657 | 658 | def load_snapshot(self, payload, encoder_only=False, eval=False, load_opt=False): 659 | # models 660 | if encoder_only: 661 | model_keys = ["encoder"] 662 | payload = {"encoder": payload} 663 | else: 664 | model_keys = ["actor", "encoder"] 665 | if self.use_aux_inputs: 666 | model_keys += ["aux_projector"] 667 | if self.use_actions: 668 | model_keys += ["action_projector"] 669 | 670 | for k in model_keys: 671 | if k == "encoder" and self.separate_encoders: 672 | for key in self.pixel_keys: 673 | self.encoder[key].load_state_dict(payload[f"encoder_{key}"]) 674 | else: 675 | self.__dict__[k].load_state_dict(payload[k]) 676 | 677 | if eval: 678 | self.train(False) 679 | return 680 | 681 | # if not eval 682 | if not load_opt: 683 | self.reinit_optimizers() 684 | else: 685 | opt_keys = ["actor_opt"] 686 | if self.train_encoder: 687 | opt_keys += ["encoder_opt"] 688 | if self.use_aux_inputs: 689 | opt_keys += ["aux_opt"] 690 | if self.use_actions: 691 | opt_keys += ["action_opt"] 692 | for k in opt_keys: 693 | self.__dict__[k] = payload[k] 694 | self.train(True) 695 | --------------------------------------------------------------------------------
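Usage note (illustrative only): at train time the Hydra configs wire read_data.xarm_env_aa.BCDataset into a PyTorch DataLoader whose iterator is passed to agent.bc.BCAgent.update, while dataset.stats supplies the normalization statistics that evaluation can later hand to BCAgent.act as norm_stats. The sketch below shows only the shape of that loop and assumes agent and dataset were already instantiated from the configs; it is not code from this repository.

# Minimal training-loop sketch, assuming `agent` is a BCAgent and `dataset`
# is a BCDataset built from the Hydra cfgs above (hypothetical wiring).
import torch

def train(agent, dataset, num_steps=1000, batch_size=64):
    # BCDataset is an IterableDataset, so the DataLoader only batches its samples.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    expert_replay_iter = iter(loader)
    for step in range(num_steps):
        # One gradient step: BCAgent.update pulls a batch, encodes pixel/aux keys,
        # and optimizes the actor (and the encoder when train_encoder is set).
        metrics = agent.update(expert_replay_iter, step)
        if step % 100 == 0:
            print(step, metrics)
    # dataset.stats holds the min/max (or mean/std) normalization values for
    # actions, proprioception, and sensors; pass it to agent.act(..., norm_stats=...).
    return dataset.stats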