├── envs
│   ├── __init__.py
│   ├── xarm-env
│   │   ├── xarm_env
│   │   │   ├── envs
│   │   │   │   ├── __init__.py
│   │   │   │   ├── constants.py
│   │   │   │   └── robot_env.py
│   │   │   └── __init__.py
│   │   └── setup.py
│   └── constants.py
├── cfgs
│   ├── suite
│   │   ├── task
│   │   │   └── xarm_env.yaml
│   │   └── xarm_env_aa.yaml
│   ├── dataloader
│   │   └── xarm_env_aa.yaml
│   ├── agent
│   │   └── bc.yaml
│   ├── config.yaml
│   └── config_eval.yaml
├── requirements.txt
├── .pre-commit-config.yaml
├── env.yml
├── replay_buffer.py
├── LICENSE
├── README.md
├── agent
│   ├── networks
│   │   ├── policy_head.py
│   │   ├── mlp.py
│   │   ├── rgb_modules.py
│   │   └── gpt.py
│   └── bc.py
├── video.py
├── .gitignore
├── convert_to_pkl.py
├── logger.py
├── eval.py
├── utils.py
├── train_bc.py
├── suite
│   └── xarm_env.py
├── process_xarm_data.py
└── read_data
    └── xarm_env_aa.py
/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/envs/xarm-env/xarm_env/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from xarm_env.envs.robot_env import RobotEnv
2 |
--------------------------------------------------------------------------------
/envs/xarm-env/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name="xarm_env", version="0.0.1", install_requires=["gym"])
4 |
--------------------------------------------------------------------------------
/cfgs/suite/task/xarm_env.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 |
4 | suite: xarm_env
5 | tasks:
6 | # - 0529_norot_ballincup #y
7 | - 0802_insert_plug_triangle
8 |
--------------------------------------------------------------------------------
/envs/xarm-env/xarm_env/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 | register(
4 | id="Robot-v1",
5 | entry_point="xarm_env.envs:RobotEnv",
6 | max_episode_steps=400, # 200,
7 | )
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python==4.9.0.80
2 | ipdb
3 | h5py
4 | wandb
5 | termcolor
6 | tensorboard
7 | hydra-core
8 | hydra-submitit-launcher
9 | gym
10 | dm-env
11 | imageio
12 | imageio-ffmpeg
13 | tqdm
14 | scipy
15 | blosc
16 | einops
17 | zmq
18 |
--------------------------------------------------------------------------------
/envs/constants.py:
--------------------------------------------------------------------------------
1 | HOST_ADDRESS = "10.19.216.156"
2 | CAMERA_HOST_ADDRESS = "10.19.216.156"
3 | DEPLOYMENT_PORT = 10000
4 |
5 | # camera
6 | CAMERA_PORT_OFFSET = 10005
7 | CAM_SERIAL_NUMS = {
8 | 1: "239122072217",
9 | 2: "023322062082",
10 | 3: "233522071078",
11 | 4: "141722076049",
12 | }
13 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.4.0
4 | hooks:
5 | - id: check-yaml
6 | - id: end-of-file-fixer
7 | - id: trailing-whitespace
8 | - repo: https://github.com/psf/black
9 | rev: 23.7.0
10 | hooks:
11 | - id: black
12 | - repo: https://github.com/hadialqattan/pycln
13 | rev: v2.2.2
14 | hooks:
15 | - id: pycln
16 |
--------------------------------------------------------------------------------
/envs/xarm-env/xarm_env/envs/constants.py:
--------------------------------------------------------------------------------
1 | HOST_ADDRESS = "10.19.216.156"
2 | CAMERA_HOST_ADDRESS = "10.19.216.156"
3 | DEPLOYMENT_PORT = 10000
4 |
5 | # camera
6 | CAMERA_PORT_OFFSET = 10005
7 | CAM_SERIAL_NUMS = {
8 | 1: "239122072217",
9 | 2: "023322062082",
10 | 3: "233522071078",
11 | 4: "141722076049",
12 | }
13 |
14 | FISH_EYE_CAMERA_PORT_OFFSET = 10010
15 |
16 | FISH_EYE_CAM_SERIAL_NUMS = {51: "18", 52: "26"}
17 |
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: visuoskin
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 |
8 | dependencies:
9 | - python=3.11
10 | - numpy
11 | - pre-commit
12 | - pip
13 | - pytorch=2.1
14 | - torchvision
15 | - pytorch-cuda=11.8
16 | - pip:
17 | - ipdb
18 | - h5py
19 | - termcolor
20 | - tensorboard
21 | - hydra-core
22 | - hydra-submitit-launcher
23 | - dm-env
24 | - imageio
25 |
--------------------------------------------------------------------------------
/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 |
6 | def _worker_init_fn(worker_id):
7 | # seed = np.random.get_state()[1][0] + worker_id
8 | # np.random.seed(seed)
9 | # random.seed(seed)
10 | np.random.seed(worker_id)
11 | random.seed(worker_id)
12 |
13 |
14 | def make_expert_replay_loader(iterable, batch_size):
15 | loader = torch.utils.data.DataLoader(
16 | iterable,
17 | batch_size=batch_size,
18 | num_workers=16,
19 | pin_memory=True,
20 | worker_init_fn=_worker_init_fn,
21 | )
22 | return loader
23 |
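24 | # Usage sketch (mirrors how eval.py builds its loader; `cfg` here is an assumed
25 | # Hydra config object, not something defined in this file):
26 | #   dataset = hydra.utils.call(cfg.expert_dataset)
27 | #   loader = make_expert_replay_loader(dataset, cfg.batch_size)
28 | #   batch = next(iter(loader))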
--------------------------------------------------------------------------------
/cfgs/dataloader/xarm_env_aa.yaml:
--------------------------------------------------------------------------------
1 | bc_dataset:
2 | _target_: read_data.xarm_env_aa.BCDataset
3 | path: ${root_dir}/processed_data_pkl_aa
4 | tasks: ${suite.task.tasks}
5 | num_demos_per_task: ${num_demos_per_task}
6 | temporal_agg: ${temporal_agg}
7 | num_queries: ${num_queries}
8 | img_size: ${img_size}
9 | action_after_steps: ${suite.action_after_steps}
10 | store_actions: true
11 | pixel_keys: ${suite.pixel_keys}
12 | aux_keys: ${suite.aux_keys}
13 | subsample: 5
14 | skip_first_n: 0
15 | relative_actions: true
16 | random_mask_proprio: false
17 | sensor_params: ${suite.sensor_params}
18 |
--------------------------------------------------------------------------------
/cfgs/agent/bc.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.bc.BCAgent
3 | obs_shape: ??? # to be specified later
4 | action_shape: ??? # to be specified later
5 | device: ${device}
6 | lr: 1e-4 #1e-5 #1e-4
7 | hidden_dim: ${suite.hidden_dim}
8 | stddev_schedule: 0.1
9 | stddev_clip: 0.3
10 | use_tb: ${use_tb}
11 | augment: True
12 | encoder_type: ${encoder_type}
13 | policy_type: ${policy_type}
14 | policy_head: ${policy_head}
15 | pixel_keys: ${suite.pixel_keys}
16 | aux_keys: ${suite.aux_keys}
17 | use_aux_inputs: ${use_aux_inputs}
18 | train_encoder: true
19 | norm: false
20 | separate_encoders: false # have a separate encoder for each pixel key
21 | temporal_agg: ${temporal_agg}
22 | max_episode_len: ${suite.task_make_fn.max_episode_len} # to be specified later
23 | num_queries: ${num_queries}
24 | use_actions: ${use_actions}
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Raunaq Bhirangi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cfgs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - local_config
3 | - _self_
4 | - agent: bc
5 | - dataloader: xarm_env_aa
6 | - suite: xarm_env_aa
7 | - override hydra/launcher: submitit_local
8 |
9 | # replay buffer
10 | batch_size: 256
11 | # misc
12 | seed: 0
13 | device: cuda
14 | save_video: true
15 | save_train_video: false
16 | use_tb: true
17 |
18 | # experiment
19 | num_demos_per_task: 5000
20 | encoder_type: 'resnet' # base, patch, resnet
21 | policy_type: 'gpt' # mlp, gpt
22 | img_size: 128
23 | policy_head: deterministic # deterministic, gmm, bet, diffusion, vqbet
24 | use_aux_inputs: true
25 | use_language: false
26 | use_actions: false
27 | sequential_train: false
28 | eval: false
29 | experiment: ${suite.name}_bc
30 | experiment_label: ${policy_head}
31 |
32 | # expert dataset
33 | num_demos: null #10(dmc), 1(metaworld), 1(particle), 1(robotgym)
34 | expert_dataset: ${dataloader.bc_dataset}
35 |
36 | # Load weights
37 | load_bc: false
38 | bc_weight: null
39 |
40 | # Action chunking parameters
41 | temporal_agg: true
42 | num_queries: 10
43 |
44 | # TODO: Fix this
45 | max_episode_len: 1000
46 |
47 | hydra:
48 | job:
49 | chdir: true
50 | run:
51 | dir: ./exp_local/${now:%Y.%m.%d}_${experiment}/${now:%H%M%S}_${experiment_label}
52 | sweep:
53 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}
54 | subdir: ${hydra.job.num}
55 |
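56 | # Any of the keys above can be overridden on the command line via Hydra, e.g.
57 | # `python train_bc.py batch_size=128 temporal_agg=false`
58 | # (values shown are illustrative, not recommendations).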
--------------------------------------------------------------------------------
/cfgs/suite/xarm_env_aa.yaml:
--------------------------------------------------------------------------------
1 | # @package suite
2 | defaults:
3 | - _self_
4 | - task: xarm_env
5 |
6 | suite: xarm_env
7 | name: "xarm_env"
8 |
9 | # task settings
10 | frame_stack: 1
11 | action_repeat: 1
12 | discount: 0.99
13 | hidden_dim: 256
14 | action_type: continuous
15 |
16 | # train settings
17 | num_train_steps: 10000
18 | log_every_steps: 100
19 | save_every_steps: 500
20 | num_train_steps_per_task: 20000 #2500 #5000
21 |
22 | # eval
23 | eval_every_steps: 200000 #20000 #5000
24 | num_eval_episodes: 5
25 |
26 | # data loading
27 | action_after_steps: 1 #8
28 |
29 | # obs_keys
30 | pixel_keys: ["pixels1", "pixels2", "pixels51", "pixels52"]
31 | aux_keys: ["sensor0","sensor1"]
32 | # aux_keys: ["digit80","digit81"]
33 | # aux_keys: ["sensor"]
34 | feature_key: "proprioceptive"
35 | sensor_params:
36 | sensor_type: reskin
37 | subtract_sensor_baseline: true
38 |
39 | # snapshot
40 | save_snapshot: true
41 |
42 | task_make_fn:
43 | _target_: suite.xarm_env.make
44 | frame_stack: ${suite.frame_stack}
45 | action_repeat: ${suite.action_repeat}
46 | height: ${img_size}
47 | width: ${img_size}
48 | max_episode_len: ??? # to be specified later
49 | max_state_dim: ??? # to be specified later
50 | use_egocentric: true
51 | use_fisheye: true
52 | task_description: "just training"
53 | pixel_keys: ${suite.pixel_keys}
54 | aux_keys: ${suite.aux_keys}
55 | sensor_params: ${suite.sensor_params}
56 | eval: ${eval} # eval true mean use robot
57 |
--------------------------------------------------------------------------------
/cfgs/config_eval.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - local_config
3 | - _self_
4 | - agent: bc
5 | - dataloader: xarm_env_aa
6 | - suite: xarm_env_aa
7 | # - override hydra/launcher: submitit_local
8 |
9 | # replay buffer
10 | batch_size: 256
11 | # misc
12 | seed: 0
13 | device: cuda
14 | save_video: true
15 | save_train_video: false
16 | use_tb: true
17 | eval: true
18 |
19 | # experiment
20 | num_demos_per_task: 5000
21 | encoder_type: 'resnet' # base, patch, resnet
22 | policy_type: 'gpt' # mlp, gpt
23 | policy_head: deterministic # deterministic, gmm, bet, diffusion, vqbet
24 | use_aux_inputs: true
25 | use_language: false
26 | use_actions: false
27 | sequential_train: false
28 | irl: false
29 | experiment: ${suite.name}_eval
30 | experiment_label: ${policy_head}
31 |
32 | # expert dataset
33 | num_demos: null #10(dmc), 1(metaworld), 1(particle), 1(robotgym)
34 | expert_dataset: ${dataloader.bc_dataset}
35 |
36 | # Load weights
37 | load_bc: false
38 | bc_weight: null
39 |
40 | # Action chunking parameters
41 | temporal_agg: true
42 | num_queries: 10
43 |
44 | # TODO: Fix this
45 | max_episode_len: 1000
46 |
47 | hydra:
48 | job:
49 | chdir: true
50 | run:
51 | dir: ./exp_local/${now:%Y.%m.%d}_${experiment}/${now:%H%M%S}_${experiment_label}
52 | sweep:
53 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}
54 | subdir: ${hydra.job.num}
55 | # launcher:
56 | # tasks_per_node: 1
57 | # nodes: 1
58 | # submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${experiment}/.slurm
59 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Visuo-Skin (ViSk)
2 | Accompanying code for training VisuoSkin policies as described in the paper:
3 | [Learning Precise, Contact-Rich Manipulation through Uncalibrated Tactile Skins](https://visuoskin.github.io)
4 |
5 |
6 |
7 |
8 |
9 | ## About
10 |
11 | ViSk is a framework for learning visuotactile policies for fine-grained, contact-rich manipulation tasks. ViSk uses a transformer-based architecture in conjunction with [AnySkin](https://any-skin.github.io) and presents a significant improvement over vision-only policies as well as visuotactile policies that use high-dimensional tactile sensors like DIGIT.
12 |
13 | ## Installation
14 | 1. Clone this repository
15 | ```
16 | git clone https://github.com/raunaqbhirangi/visuoskin.git
17 | ```
18 |
19 | 2. Create a conda environment and install dependencies
20 |
21 | ```
22 | conda env create -f env.yml
23 | pip install -r requirements.txt
24 | ```
25 |
26 | 3. Move raw data to your desired location and set `DATA_DIR` in `utils.py` to point to this location. Similarly, set `root_dir` in `cfgs/local_config.yaml`.
27 |
28 | 4. Process data for the `current-task` (name of the directory containing demonstration data for the current task) and convert to pkl.
29 |
30 | ```
31 | python process_xarm_data.py -t current-task
32 | python convert_to_pkl.py -t current-task
33 | ```
34 | 5. Install `xarm-env` using `pip install -e envs/xarm-env`
35 |
36 | 6. Run BC training
37 | ```
38 | python train_bc.py 'suite.task.tasks=[current-task]'
39 | ```
40 |
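41 | 7. (Optional) Evaluate a trained policy. A minimal sketch, assuming `/path/to/snapshot.pt` is replaced with a snapshot produced by training. Evaluation uses `cfgs/config_eval.yaml`, which sets `eval: true` and therefore runs against the real robot setup.
42 | ```
43 | python eval.py 'suite.task.tasks=[current-task]' bc_weight=/path/to/snapshot.pt
44 | ```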
--------------------------------------------------------------------------------
/agent/networks/policy_head.py:
--------------------------------------------------------------------------------
1 | import einops
2 |
3 | # import robomimic.utils.tensor_utils as TensorUtils
4 | import torch
5 | import torch.distributions as D
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import utils
10 |
11 | # from agent.networks.utils.diffusion_policy import DiffusionPolicy
12 | # from agent.networks.utils.vqbet.pretrain_vqvae import init_vqvae, pretrain_vqvae
13 | from agent.networks.mlp import MLP
14 |
15 | ######################################### Deterministic Head #########################################
16 |
17 |
18 | class DeterministicHead(nn.Module):
19 | def __init__(
20 | self,
21 | input_size,
22 | output_size,
23 | hidden_size=1024,
24 | num_layers=2,
25 | action_squash=True,
26 | loss_coef=1.0,
27 | ):
28 | super().__init__()
29 | self.loss_coef = loss_coef
30 |
31 | sizes = [input_size] + [hidden_size] * num_layers + [output_size]
32 | layers = []
33 | for i in range(num_layers):
34 | layers += [nn.Linear(sizes[i], sizes[i + 1]), nn.ReLU()]
35 | layers += [nn.Linear(sizes[-2], sizes[-1])]
36 |
37 | if action_squash:
38 | layers += [nn.Tanh()]
39 |
40 | self.net = nn.Sequential(*layers)
41 |
42 | def forward(self, x, stddev=None, **kwargs):
43 | mu = self.net(x)
44 | std = stddev if stddev is not None else 0.1
45 | std = torch.ones_like(mu) * std
46 | dist = utils.TruncatedNormal(mu, std)
47 | return dist
48 |
49 | def loss_fn(self, dist, target, reduction="mean", **kwargs):
50 | log_probs = dist.log_prob(target)
51 | loss = -log_probs
52 |
53 | if reduction == "mean":
54 | loss = loss.mean() * self.loss_coef
55 | elif reduction == "none":
56 | loss = loss * self.loss_coef
57 | elif reduction == "sum":
58 | loss = loss.sum() * self.loss_coef
59 | else:
60 | raise NotImplementedError
61 |
62 | return {
63 | "actor_loss": loss,
64 | }
65 |
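66 | # Usage sketch (sizes are illustrative assumptions, not values taken from the configs):
67 | #   head = DeterministicHead(input_size=256, output_size=7)
68 | #   dist = head(features, stddev=0.1)                  # utils.TruncatedNormal over actions
69 | #   loss = head.loss_fn(dist, target_actions)["actor_loss"]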
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import imageio
3 | import numpy as np
4 |
5 |
6 | class VideoRecorder:
7 | def __init__(self, root_dir, render_size=256, fps=20):
8 | if root_dir is not None:
9 | self.save_dir = root_dir / "eval_video"
10 | self.save_dir.mkdir(exist_ok=True)
11 | else:
12 | self.save_dir = None
13 |
14 | self.render_size = render_size
15 | self.fps = fps
16 | self.frames = []
17 |
18 | def init(self, env, enabled=True):
19 | self.frames = []
20 | self.enabled = self.save_dir is not None and enabled
21 | self.record(env)
22 |
23 | def record(self, env):
24 | if self.enabled:
25 | if hasattr(env, "physics"):
26 | frame = env.physics.render(
27 | height=self.render_size, width=self.render_size, camera_id=0
28 | )
29 | else:
30 | frame = env.render()
31 | self.frames.append(frame)
32 |
33 | def save(self, file_name):
34 | if self.enabled:
35 | path = self.save_dir / file_name
36 | imageio.mimsave(str(path), self.frames, fps=self.fps)
37 |
38 |
39 | class TrainVideoRecorder:
40 | def __init__(self, root_dir, render_size=256, fps=20):
41 | if root_dir is not None:
42 | self.save_dir = root_dir / "train_video"
43 | self.save_dir.mkdir(exist_ok=True)
44 | else:
45 | self.save_dir = None
46 |
47 | self.render_size = render_size
48 | self.fps = fps
49 | self.frames = []
50 |
51 | def init(self, obs, enabled=True):
52 | self.frames = []
53 | self.enabled = self.save_dir is not None and enabled
54 | self.record(obs)
55 |
56 | def record(self, obs):
57 | if self.enabled:
58 | frame = cv2.resize(
59 | obs[-3:].transpose(1, 2, 0),
60 | dsize=(self.render_size, self.render_size),
61 | interpolation=cv2.INTER_CUBIC,
62 | )
63 | self.frames.append(frame)
64 |
65 | def save(self, file_name):
66 | if self.enabled:
67 | path = self.save_dir / file_name
68 | imageio.mimsave(str(path), self.frames, fps=self.fps)
69 |
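70 | # Usage sketch (mirrors eval.py; recording is a no-op when root_dir is None):
71 | #   recorder = VideoRecorder(Path.cwd())
72 | #   recorder.init(env, enabled=True)
73 | #   recorder.record(env)          # call once per environment step
74 | #   recorder.save("rollout.mp4")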
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | exp_local/
163 | test_env.py
164 | create_robosuite_configs.py
165 | expert_demos/
166 | logs/
167 | outputs/
168 | data/
169 | preprocessing_output/
170 | snapshot/
171 | tb/
172 | train.csv
173 | eval.csv
174 | *.png
175 | exp*/
176 | *.mp4
177 | cfgs/local_config.yaml
178 |
--------------------------------------------------------------------------------
/agent/networks/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Callable, List, Optional
3 |
4 |
5 | class MLP(torch.nn.Sequential):
6 | """This block implements the multi-layer perceptron (MLP) module.
7 | Adapted for backward compatibility from the torchvision library:
8 | https://pytorch.org/vision/0.14/generated/torchvision.ops.MLP.html
9 |
10 | LICENSE:
11 |
12 | From PyTorch:
13 |
14 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
15 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
16 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
17 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
18 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
19 | Copyright (c) 2011-2013 NYU (Clement Farabet)
20 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
21 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
22 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
23 |
24 | From Caffe2:
25 |
26 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
27 |
28 | All contributions by Facebook:
29 | Copyright (c) 2016 Facebook Inc.
30 |
31 | All contributions by Google:
32 | Copyright (c) 2015 Google Inc.
33 | All rights reserved.
34 |
35 | All contributions by Yangqing Jia:
36 | Copyright (c) 2015 Yangqing Jia
37 | All rights reserved.
38 |
39 | All contributions by Kakao Brain:
40 | Copyright 2019-2020 Kakao Brain
41 |
42 | All contributions by Cruise LLC:
43 | Copyright (c) 2022 Cruise LLC.
44 | All rights reserved.
45 |
46 | All contributions from Caffe:
47 | Copyright(c) 2013, 2014, 2015, the respective contributors
48 | All rights reserved.
49 |
50 | All other contributions:
51 | Copyright(c) 2015, 2016 the respective contributors
52 | All rights reserved.
53 |
54 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
55 | copyright over their contributions to Caffe2. The project versioning records
56 | all such contribution and copyright details. If a contributor wants to further
57 | mark their specific copyright on a particular contribution, they should
58 | indicate their copyright solely in the commit message of the change when it is
59 | committed.
60 |
61 | All rights reserved.
62 |
63 | Redistribution and use in source and binary forms, with or without
64 | modification, are permitted provided that the following conditions are met:
65 |
66 | 1. Redistributions of source code must retain the above copyright
67 | notice, this list of conditions and the following disclaimer.
68 |
69 | 2. Redistributions in binary form must reproduce the above copyright
70 | notice, this list of conditions and the following disclaimer in the
71 | documentation and/or other materials provided with the distribution.
72 |
73 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
74 | and IDIAP Research Institute nor the names of its contributors may be
75 | used to endorse or promote products derived from this software without
76 | specific prior written permission.
77 |
78 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
79 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
80 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
81 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
82 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
83 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
84 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
85 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
86 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
87 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
88 | POSSIBILITY OF SUCH DAMAGE.
89 |
90 |
91 | Args:
92 | in_channels (int): Number of channels of the input
93 | hidden_channels (List[int]): List of the hidden channel dimensions
94 | norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
95 | activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
96 | inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
97 | Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
98 | bias (bool): Whether to use bias in the linear layer. Default ``True``
99 | dropout (float): The probability for the dropout layer. Default: 0.0
100 | """
101 |
102 | def __init__(
103 | self,
104 | in_channels: int,
105 | hidden_channels: List[int],
106 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
107 | inplace: Optional[bool] = None,
108 | bias: bool = True,
109 | dropout: float = 0.0,
110 | ):
111 | params = {} if inplace is None else {"inplace": inplace}
112 |
113 | layers = []
114 | in_dim = in_channels
115 | for hidden_dim in hidden_channels[:-1]:
116 | layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
117 | layers.append(activation_layer(**params))
118 | layers.append(torch.nn.Dropout(dropout, **params))
119 | in_dim = hidden_dim
120 |
121 | layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
122 | layers.append(torch.nn.Dropout(dropout, **params))
123 |
124 | super().__init__(*layers)
125 |
--------------------------------------------------------------------------------
/convert_to_pkl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import h5py as h5
3 | import numpy as np
4 | from pandas import read_csv
5 | import pickle as pkl
6 | import cv2
7 | from pathlib import Path
8 | from scipy.spatial.transform import Rotation as R
9 |
10 | from utils import DATA_DIR
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--task-name", "-t", type=str, required=True)
14 | args = parser.parse_args()
15 |
16 | TASK_NAME = args.task_name
17 |
18 | PROCESSED_DATA_PATH = Path(DATA_DIR) / "processed_data/"
19 | SAVE_DATA_PATH = Path(DATA_DIR) / "processed_data_pkl_aa/"
20 |
21 | camera_indices = [1, 2, 51, 52]
22 | img_size = (128, 128)
23 | NUM_DEMOS = None
24 |
25 | # Create the save path
26 | SAVE_DATA_PATH.mkdir(parents=True, exist_ok=True)
27 |
28 | DATASET_PATH = Path(f"{PROCESSED_DATA_PATH}/{TASK_NAME}")
29 |
30 | if (SAVE_DATA_PATH / f"{TASK_NAME}.pkl").exists():
31 | print(f"Data for {TASK_NAME} already exists. Appending to it...")
32 | input("Press Enter to continue...")
33 | data = pkl.load(open(SAVE_DATA_PATH / f"{TASK_NAME}.pkl", "rb"))
34 | observations = data["observations"]
35 | max_cartesian = data["max_cartesian"]
36 | min_cartesian = data["min_cartesian"]
37 | max_gripper = data["max_gripper"]
38 | min_gripper = data["min_gripper"]
39 | else:
40 | # Init storing variables
41 | observations = []
42 |
43 | # Store max and min
44 | max_cartesian, min_cartesian = None, None
45 | max_sensor, min_sensor = None, None
46 | # max_rel_cartesian, min_rel_cartesian = None, None
47 | max_gripper, min_gripper = None, None
48 |
49 | # Load each data point and save in a list
50 | dirs = [x for x in DATASET_PATH.iterdir() if x.is_dir()]
51 | for i, data_point in enumerate(sorted(dirs)):
52 | use_sensor = True
53 | print(f"Processing data point {i+1}/{len(dirs)}")
54 |
55 | if NUM_DEMOS is not None:
56 | if int(str(data_point).split("_")[-1]) >= NUM_DEMOS:
57 | print(f"Skipping data point {data_point}")
58 | continue
59 |
60 | observation = {}
61 | # images
62 | image_dir = data_point / "videos"
63 | if not image_dir.exists():
64 | print(f"Data point {data_point} is incomplete")
65 | continue
66 | for save_idx, idx in enumerate(camera_indices):
67 | # Read the frames in the video
68 | video_path = image_dir / f"camera{idx}.mp4"
69 | cap = cv2.VideoCapture(str(video_path))
70 | if not cap.isOpened():
71 | print(f"Video {video_path} could not be opened")
72 | continue
73 | frames = []
74 | while True:
75 | ret, frame = cap.read()
76 | if not ret:
77 | break
78 | if idx == 52:
79 | # crop the right side of the image for the gripper cam
80 | shape = frame.shape
81 | crop_percent = 0.2
82 | frame = frame[:, : int(shape[1] * (1 - crop_percent))]
83 | frame = cv2.resize(frame, img_size)
84 | frames.append(frame)
85 | if idx < 80:
86 | observation[f"pixels{idx}"] = np.array(frames)
87 | else:
88 | observation[f"digit{idx}"] = np.array(frames)
89 | # read cartesian and gripper states from csv
90 | state_csv_path = data_point / "states.csv"
91 | sensor_csv_path = data_point / "sensor.csv"
92 | state = read_csv(state_csv_path)
93 | try:
94 | sensor_data = read_csv(sensor_csv_path)
95 | sensor_states = sensor_data["sensor_values"].values
96 | sensor_states = np.array(
97 | [
98 | np.array([float(x.strip()) for x in sensor[1:-1].split(",")])
99 | for sensor in sensor_states
100 | ],
101 | dtype=np.float32,
102 | )
103 | except FileNotFoundError:
104 | use_sensor = False
105 | print(f"Sensor data not found for {data_point}")
106 |
107 | # Read cartesian state where every element is a 6D pose
108 | # Separate the pose into values instead of string
109 | cartesian_states = state["pose_aa"].values
110 | cartesian_states = np.array(
111 | [
112 | np.array([float(x.strip()) for x in pose[1:-1].split(",")])
113 | for pose in cartesian_states
114 | ],
115 | dtype=np.float32,
116 | )
117 |
118 | gripper_states = state["gripper_state"].values.astype(np.float32)
119 | observation["cartesian_states"] = cartesian_states.astype(np.float32)
120 | observation["gripper_states"] = gripper_states.astype(np.float32)
121 | if use_sensor:
122 | observation["sensor_states"] = sensor_states.astype(np.float32)
123 |         if max_sensor is None:
124 |             max_sensor = np.max(sensor_states, axis=0)
125 |             min_sensor = np.min(sensor_states, axis=0)
126 |         else:
127 |             max_sensor = np.maximum(max_sensor, np.max(sensor_states, axis=0))
128 |             min_sensor = np.minimum(min_sensor, np.min(sensor_states, axis=0))
129 |         # Accumulate per-axis extrema across demos, mirroring the cartesian and
130 |         # gripper statistics below.
131 |
132 | # update max and min
133 | if max_cartesian is None:
134 | max_cartesian = np.max(cartesian_states, axis=0)
135 | min_cartesian = np.min(cartesian_states, axis=0)
136 | else:
137 | max_cartesian = np.maximum(max_cartesian, np.max(cartesian_states, axis=0))
138 | min_cartesian = np.minimum(min_cartesian, np.min(cartesian_states, axis=0))
139 | if max_gripper is None:
140 | max_gripper = np.max(gripper_states)
141 | min_gripper = np.min(gripper_states)
142 | else:
143 | max_gripper = np.maximum(max_gripper, np.max(gripper_states))
144 | min_gripper = np.minimum(min_gripper, np.min(gripper_states))
145 |
146 | # append to observations
147 | observations.append(observation)
148 |
149 | # Save the data
150 | data = {
151 | "observations": observations,
152 | "max_cartesian": max_cartesian,
153 | "min_cartesian": min_cartesian,
154 | "max_gripper": max_gripper,
155 | "min_gripper": min_gripper,
156 | "max_sensor": max_sensor,
157 | "min_sensor": min_sensor,
158 | }
159 | pkl.dump(data, open(SAVE_DATA_PATH / f"{TASK_NAME}.pkl", "wb"))
160 |
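161 | # The resulting pickle is written to DATA_DIR/processed_data_pkl_aa/<task>.pkl,
162 | # which is where the dataloader config (cfgs/dataloader/xarm_env_aa.yaml) expects
163 | # to find it, assuming ${root_dir} in cfgs/local_config.yaml points at DATA_DIR.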
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | from termcolor import colored
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | BC_TRAIN_FORMAT = [
12 | ("step", "S", "int"),
13 | ("actor_loss", "L", "float"),
14 | ("total_time", "T", "time"),
15 | ]
16 | BC_EVAL_FORMAT = [
17 | ("frame", "F", "int"),
18 | ("step", "S", "int"),
19 | ("episode", "E", "int"),
20 | ("episode_length", "L", "int"),
21 | ("episode_reward", "R", "float"),
22 | ("imitation_reward", "R_i", "float"),
23 | ("total_time", "T", "time"),
24 | ]
25 | SSL_TRAIN_FORMAT = [
26 | ("step", "S", "int"),
27 | ("loss", "L", "float"),
28 | ("total_time", "T", "time"),
29 | ]
30 | SSL_EVAL_FORMAT = [
31 | ("epoch", "E", "int"),
32 | ("step", "S", "int"),
33 |     ("loss", "L", "float"),
34 | ("total_time", "T", "time"),
35 | ]
36 |
37 |
38 | class AverageMeter(object):
39 | def __init__(self):
40 | self._sum = 0
41 | self._count = 0
42 |
43 | def update(self, value, n=1):
44 | self._sum += value
45 | self._count += n
46 |
47 | def value(self):
48 | return self._sum / max(1, self._count)
49 |
50 |
51 | class MetersGroup(object):
52 | def __init__(self, csv_file_name, formating):
53 | self._csv_file_name = csv_file_name
54 | self._formating = formating
55 | self._meters = defaultdict(AverageMeter)
56 | self._csv_file = None
57 | self._csv_writer = None
58 |
59 | def log(self, key, value, n=1):
60 | self._meters[key].update(value, n)
61 |
62 | def _prime_meters(self):
63 | data = dict()
64 | for key, meter in self._meters.items():
65 | if key.startswith("train_vq"):
66 | key = key[len("train_vq") + 1 :]
67 | elif key.startswith("train"):
68 | key = key[len("train") + 1 :]
69 | else:
70 | key = key[len("eval") + 1 :]
71 | key = key.replace("/", "_")
72 | data[key] = meter.value()
73 | return data
74 |
75 | def _remove_old_entries(self, data):
76 | rows = []
77 | with self._csv_file_name.open("r") as f:
78 | reader = csv.DictReader(f)
79 | for row in reader:
80 | if float(row["step"]) >= data["step"]:
81 | break
82 | rows.append(row)
83 | with self._csv_file_name.open("w") as f:
84 | writer = csv.DictWriter(f, fieldnames=sorted(data.keys()), restval=0.0)
85 | writer.writeheader()
86 | for row in rows:
87 | writer.writerow(row)
88 |
89 | def _dump_to_csv(self, data):
90 | if self._csv_writer is None:
91 | should_write_header = True
92 | if self._csv_file_name.exists():
93 | self._remove_old_entries(data)
94 | should_write_header = False
95 |
96 | self._csv_file = self._csv_file_name.open("a")
97 | self._csv_writer = csv.DictWriter(
98 | self._csv_file, fieldnames=sorted(data.keys()), restval=0.0
99 | )
100 | if should_write_header:
101 | self._csv_writer.writeheader()
102 |
103 | self._csv_writer.writerow(data)
104 | self._csv_file.flush()
105 |
106 | def _format(self, key, value, ty):
107 | if ty == "int":
108 | value = int(value)
109 | return f"{key}: {value}"
110 | elif ty == "float":
111 | return f"{key}: {value:.04f}"
112 | elif ty == "time":
113 | value = str(datetime.timedelta(seconds=int(value)))
114 | return f"{key}: {value}"
115 | else:
116 |             raise ValueError(f"invalid format type: {ty}")
117 |
118 | def _dump_to_console(self, data, prefix):
119 | prefix = colored(
120 | prefix, "yellow" if prefix in ["train", "train_vq"] else "green"
121 | )
122 | pieces = [f"| {prefix: <14}"]
123 | for key, disp_key, ty in self._formating:
124 | value = data.get(key, 0)
125 | pieces.append(self._format(disp_key, value, ty))
126 | print(" | ".join(pieces))
127 |
128 | def dump(self, step, prefix):
129 | if len(self._meters) == 0:
130 | return
131 | data = self._prime_meters()
132 | data["frame"] = step
133 | self._dump_to_csv(data)
134 | self._dump_to_console(data, prefix)
135 | self._meters.clear()
136 |
137 |
138 | class Logger(object):
139 | def __init__(self, log_dir, use_tb, mode="bc"):
140 | """
141 | mode: bc, ssl
142 | """
143 | self._log_dir = log_dir
144 | if mode == "bc":
145 | self._train_mg = MetersGroup(
146 | log_dir / "train.csv", formating=BC_TRAIN_FORMAT
147 | )
148 | self._eval_mg = MetersGroup(log_dir / "eval.csv", formating=BC_EVAL_FORMAT)
149 | elif mode == "ssl":
150 | self._train_mg = MetersGroup(
151 | log_dir / "train.csv", formating=SSL_TRAIN_FORMAT
152 | )
153 | self._train_vq_mg = MetersGroup(
154 | log_dir / "train_vq.csv", formating=SSL_TRAIN_FORMAT
155 | )
156 | self._eval_mg = MetersGroup(log_dir / "eval.csv", formating=SSL_EVAL_FORMAT)
157 | if use_tb:
158 | self._sw = SummaryWriter(str(log_dir / "tb"))
159 | else:
160 | self._sw = None
161 |
162 | def _try_sw_log(self, key, value, step):
163 | if self._sw is not None:
164 | self._sw.add_scalar(key, value, step)
165 |
166 | def log(self, key, value, step):
167 | assert key.startswith("train") or key.startswith("eval")
168 | if type(value) == torch.Tensor:
169 | value = value.item()
170 | self._try_sw_log(key, value, step)
171 | # mg = self._train_mg if key.startswith('train') else self._eval_mg
172 | if key.startswith("train_vq"):
173 | mg = self._train_vq_mg
174 | else:
175 | mg = self._train_mg if key.startswith("train") else self._eval_mg
176 | mg.log(key, value)
177 |
178 | def log_metrics(self, metrics, step, ty):
179 | for key, value in metrics.items():
180 | self.log(f"{ty}/{key}", value, step)
181 |
182 | def dump(self, step, ty=None):
183 | if ty is None or ty == "eval":
184 | self._eval_mg.dump(step, "eval")
185 | if ty is None or ty == "train":
186 | self._train_mg.dump(step, "train")
187 |         if (ty is None or ty == "train_vq") and hasattr(self, "_train_vq_mg"):
188 |             self._train_vq_mg.dump(step, "train_vq")
189 |
190 | def log_and_dump_ctx(self, step, ty):
191 | return LogAndDumpCtx(self, step, ty)
192 |
193 |
194 | class LogAndDumpCtx:
195 | def __init__(self, logger, step, ty):
196 | self._logger = logger
197 | self._step = step
198 | self._ty = ty
199 |
200 | def __enter__(self):
201 | return self
202 |
203 | def __call__(self, key, value):
204 | self._logger.log(f"{self._ty}/{key}", value, self._step)
205 |
206 | def __exit__(self, *args):
207 | self._logger.dump(self._step, self._ty)
208 |
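209 | # Usage sketch (mirrors eval.py; logged keys must start with "train" or "eval"):
210 | #   logger = Logger(Path.cwd(), use_tb=True, mode="bc")
211 | #   with logger.log_and_dump_ctx(step, ty="eval") as log:
212 | #       log("episode_reward", reward)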
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import warnings
4 | import os
5 |
6 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
7 | os.environ["MUJOCO_GL"] = "egl" # "osmesa" for calvin
8 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
9 | from pathlib import Path
10 |
11 | import hydra
12 | import torch
13 | import cv2
14 | import numpy as np
15 |
16 | import utils
17 | from logger import Logger
18 | from replay_buffer import make_expert_replay_loader
19 | from video import VideoRecorder
20 |
21 | warnings.filterwarnings("ignore", category=DeprecationWarning)
22 | torch.backends.cudnn.benchmark = True
23 |
24 |
25 | def make_agent(obs_spec, action_spec, cfg):
26 | obs_shape = {}
27 | for key in cfg.suite.pixel_keys:
28 | obs_shape[key] = obs_spec[key].shape
29 | if cfg.use_aux_inputs:
30 | for key in cfg.suite.aux_keys:
31 | obs_shape[key] = obs_spec[key].shape
32 | obs_shape[cfg.suite.feature_key] = obs_spec[cfg.suite.feature_key].shape
33 | cfg.agent.obs_shape = obs_shape
34 | cfg.agent.action_shape = action_spec.shape
35 | return hydra.utils.instantiate(cfg.agent)
36 |
37 |
38 | class WorkspaceIL:
39 | def __init__(self, cfg):
40 | self.work_dir = Path.cwd()
41 | print(f"workspace: {self.work_dir}")
42 |
43 | self.cfg = cfg
44 | utils.set_seed_everywhere(cfg.seed)
45 | self.device = torch.device(cfg.device)
46 |
47 | # load data
48 | dataset_iterable = hydra.utils.call(self.cfg.expert_dataset)
49 | self.expert_replay_loader = make_expert_replay_loader(
50 | dataset_iterable, self.cfg.batch_size
51 | )
52 | self.expert_replay_iter = iter(self.expert_replay_loader)
53 |
54 | # create logger
55 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb)
56 | # create envs
57 | self.cfg.suite.task_make_fn.max_episode_len = 400 # (
58 | # self.expert_replay_loader.dataset._max_episode_len
59 | # )
60 | self.cfg.suite.task_make_fn.max_state_dim = (
61 | self.expert_replay_loader.dataset._max_state_dim
62 | )
63 | if self.cfg.suite.name == "dmc":
64 | self.cfg.suite.task_make_fn.max_action_dim = (
65 | self.expert_replay_loader.dataset._max_action_dim
66 | )
67 | self.env, self.task_descriptions = hydra.utils.call(self.cfg.suite.task_make_fn)
68 |
69 | # create agent
70 | self.agent = make_agent(
71 | self.env[0].observation_spec(), self.env[0].action_spec(), cfg
72 | )
73 |
74 | if self.cfg.sequential_train:
75 | self.scene_idx = 0
76 | self.scene_names = self.cfg.suite.task.scenes
77 | self.task_names = {
78 | task_name: scene[task_name]
79 | for scene in self.cfg.suite.task.tasks
80 | for task_name in scene
81 | }
82 | self.envs_till_idx = len(self.task_names[self.scene_names[0]])
83 | self.expert_replay_loader.dataset.envs_till_idx = self.envs_till_idx
84 | self.steps_till_next_scene = (
85 | self.envs_till_idx * self.cfg.suite.num_train_steps_per_task
86 | )
87 | self.expert_replay_iter = iter(self.expert_replay_loader)
88 | else:
89 | self.envs_till_idx = len(self.env)
90 | self.expert_replay_loader.dataset.envs_till_idx = self.envs_till_idx
91 | self.expert_replay_iter = iter(self.expert_replay_loader)
92 |
93 | # # Discretizer for BeT
94 | # if self.cfg.agent.policy_head in ["bet", "vqbet"]:
95 | # self.agent.discretize(self.expert_replay_loader.dataset.actions)
96 |
97 | self.timer = utils.Timer()
98 | self._global_step = 0
99 | self._global_episode = 0
100 |
101 | self.video_recorder = VideoRecorder(
102 | self.work_dir if self.cfg.save_video else None
103 | )
104 |
105 | @property
106 | def global_step(self):
107 | return self._global_step
108 |
109 | @property
110 | def global_episode(self):
111 | return self._global_episode
112 |
113 | @property
114 | def global_frame(self):
115 | return self.global_step * self.cfg.suite.action_repeat
116 |
117 | def eval(self):
118 | self.agent.train(False)
119 | episode_rewards = []
120 | successes = []
121 | for env_idx in range(self.envs_till_idx):
122 | print(f"evaluating env {env_idx}")
123 | episode, total_reward = 0, 0
124 | eval_until_episode = utils.Until(self.cfg.suite.num_eval_episodes)
125 | success = []
126 |
127 | while eval_until_episode(episode):
128 | time_step = self.env[env_idx].reset()
129 | self.agent.buffer_reset()
130 | step = 0
131 |
132 | # prompt
133 | # if self.cfg.prompt != None and self.cfg.prompt != "intermediate_goal":
134 | # prompt = self.expert_replay_loader.dataset.sample_test(
135 | # 0
136 | # ) # env_idx)
137 | # else:
138 | # prompt = None
139 |
140 | if episode == 0:
141 | self.video_recorder.init(self.env[env_idx], enabled=True)
142 |
143 | # plot obs with cv2
144 | while not time_step.last():
145 | # if self.cfg.prompt == "intermediate_goal":
146 | # prompt = self.expert_replay_loader.dataset.sample_test(
147 | # env_idx, step
148 | # )
149 | with torch.no_grad(), utils.eval_mode(self.agent):
150 | action = self.agent.act(
151 | time_step.observation,
152 | None,
153 | # self.expert_replay_loader.dataset.stats,
154 | self.stats,
155 | step,
156 | self.global_step,
157 | eval_mode=True,
158 | )
159 | print(f"step: {step}")
160 |
161 | time_step = self.env[env_idx].step(action)
162 | self.video_recorder.record(self.env[env_idx])
163 | total_reward += time_step.reward
164 | step += 1
165 |
166 | if self.cfg.suite.name == "calvin" and time_step.reward == 1:
167 | self.agent.buffer_reset()
168 |
169 | episode += 1
170 | success.append(time_step.observation["goal_achieved"])
171 | self.video_recorder.save(f"{self.global_frame}_env{env_idx}.mp4")
172 | episode_rewards.append(total_reward / episode)
173 | successes.append(np.mean(success))
174 |
175 | for _ in range(len(self.env) - self.envs_till_idx):
176 | episode_rewards.append(0)
177 | successes.append(0)
178 |
179 | with self.logger.log_and_dump_ctx(self.global_frame, ty="eval") as log:
180 | for env_idx, reward in enumerate(episode_rewards):
181 | log(f"episode_reward_env{env_idx}", reward)
182 | log(f"success_env{env_idx}", successes[env_idx])
183 | log("episode_reward", np.mean(episode_rewards[: self.envs_till_idx]))
184 | log("success", np.mean(successes))
185 | log("episode_length", step * self.cfg.suite.action_repeat / episode)
186 | log("episode", self.global_episode)
187 | log("step", self.global_step)
188 |
189 | self.agent.train(True)
190 |
191 | def save_snapshot(self):
192 | snapshot = self.work_dir / "snapshot.pt"
193 | self.agent.clear_buffers()
194 | keys_to_save = ["timer", "_global_step", "_global_episode"]
195 | payload = {k: self.__dict__[k] for k in keys_to_save}
196 | payload.update(self.agent.save_snapshot())
197 | with snapshot.open("wb") as f:
198 | torch.save(payload, f)
199 |
200 | self.agent.buffer_reset()
201 |
202 | def load_snapshot(self, snapshots):
203 | # bc
204 | with snapshots["bc"].open("rb") as f:
205 | payload = torch.load(f)
206 | agent_payload = {}
207 | for k, v in payload.items():
208 | if k not in self.__dict__:
209 | agent_payload[k] = v
210 | if "vqvae" in snapshots:
211 | with snapshots["vqvae"].open("rb") as f:
212 | payload = torch.load(f)
213 | agent_payload["vqvae"] = payload
214 | self.agent.load_snapshot(agent_payload, eval=True)
215 | # self.agent.load_snapshot_eval(agent_payload)
216 |
217 | self.stats = payload["stats"]
218 |
219 |
220 | @hydra.main(config_path="cfgs", config_name="config_eval", version_base=None)
221 | def main(cfg):
222 | from eval import WorkspaceIL as W
223 |
224 | root_dir = Path.cwd()
225 | workspace = W(cfg)
226 |
227 | # Load weights
228 | snapshots = {}
229 | # bc
230 | bc_snapshot = Path(cfg.bc_weight)
231 | if not bc_snapshot.exists():
232 | raise FileNotFoundError(f"bc weight not found: {bc_snapshot}")
233 | print(f"loading bc weight: {bc_snapshot}")
234 | snapshots["bc"] = bc_snapshot
235 | # vqvae_snapshot = Path(cfg.vqvae_weight)
236 | # if vqvae_snapshot.exists():
237 | # print(f"loading vqvae weight: {vqvae_snapshot}")
238 | # snapshots["vqvae"] = vqvae_snapshot
239 | workspace.load_snapshot(snapshots)
240 |
241 | workspace.eval()
242 |
243 |
244 | if __name__ == "__main__":
245 | main()
246 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import re
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from omegaconf import OmegaConf
10 | from torch import distributions as pyd
11 | from torch.distributions.utils import _standard_normal
12 |
13 | # Set data directory here
14 | DATA_DIR = ""
15 |
16 |
17 | class eval_mode:
18 | def __init__(self, *models):
19 | self.models = models
20 |
21 | def __enter__(self):
22 | self.prev_states = []
23 | for model in self.models:
24 | self.prev_states.append(model.training)
25 | model.train(False)
26 |
27 | def __exit__(self, *args):
28 | for model, state in zip(self.models, self.prev_states):
29 | model.train(state)
30 | return False
31 |
32 |
33 | def set_seed_everywhere(seed):
34 | torch.manual_seed(seed)
35 | if torch.cuda.is_available():
36 | torch.cuda.manual_seed_all(seed)
37 | np.random.seed(seed)
38 | random.seed(seed)
39 |
40 |
41 | def soft_update_params(net, target_net, tau):
42 | for param, target_param in zip(net.parameters(), target_net.parameters()):
43 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
44 |
45 |
46 | def to_torch(xs, device):
47 | for key, value in xs.items():
48 | xs[key] = torch.as_tensor(value, device=device)
49 | return xs
50 |
51 |
52 | def weight_init(m):
53 | if isinstance(m, nn.Linear):
54 | nn.init.orthogonal_(m.weight.data)
55 | if hasattr(m.bias, "data"):
56 | m.bias.data.fill_(0.0)
57 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
58 | gain = nn.init.calculate_gain("relu")
59 | nn.init.orthogonal_(m.weight.data, gain)
60 | if hasattr(m.bias, "data"):
61 | m.bias.data.fill_(0.0)
62 |
63 |
64 | class Until:
65 | def __init__(self, until, action_repeat=1):
66 | self._until = until
67 | self._action_repeat = action_repeat
68 |
69 | def __call__(self, step):
70 | if self._until is None:
71 | return True
72 | until = self._until // self._action_repeat
73 | return step < until
74 |
75 |
76 | class Every:
77 | def __init__(self, every, action_repeat=1):
78 | self._every = every
79 | self._action_repeat = action_repeat
80 |
81 | def __call__(self, step):
82 | if self._every is None:
83 | return False
84 | every = self._every // self._action_repeat
85 | if step % every == 0:
86 | return True
87 | return False
88 |
89 |
90 | class Timer:
91 | def __init__(self):
92 | self._start_time = time.time()
93 | self._last_time = time.time()
94 | # Keep track of evaluation time so that total time only includes train time
95 | self._eval_start_time = 0
96 | self._eval_time = 0
97 | self._eval_flag = False
98 |
99 | def reset(self):
100 | elapsed_time = time.time() - self._last_time
101 | self._last_time = time.time()
102 | total_time = time.time() - self._start_time - self._eval_time
103 | return elapsed_time, total_time
104 |
105 | def eval(self):
106 | if not self._eval_flag:
107 | self._eval_flag = True
108 | self._eval_start_time = time.time()
109 | else:
110 | self._eval_time += time.time() - self._eval_start_time
111 | self._eval_flag = False
112 | self._eval_start_time = 0
113 |
114 | def total_time(self):
115 | return time.time() - self._start_time - self._eval_time
116 |
117 |
118 | class TruncatedNormal(pyd.Normal):
119 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
120 | super().__init__(loc, scale, validate_args=False)
121 | self.low = low
122 | self.high = high
123 | self.eps = eps
124 |
125 | def _clamp(self, x):
126 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
127 | x = x - x.detach() + clamped_x.detach()
128 | return x
129 |
130 | def sample(self, clip=None, sample_shape=torch.Size()):
131 | shape = self._extended_shape(sample_shape)
132 | eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
133 | eps *= self.scale
134 | if clip is not None:
135 | eps = torch.clamp(eps, -clip, clip)
136 | x = self.loc + eps
137 | return self._clamp(x)
138 |
139 |
140 | def schedule(schdl, step):
141 | try:
142 | return float(schdl)
143 | except ValueError:
144 | match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
145 | if match:
146 | init, final, duration = [float(g) for g in match.groups()]
147 | mix = np.clip(step / duration, 0.0, 1.0)
148 | return (1.0 - mix) * init + mix * final
149 | match = re.match(r"step_linear\((.+),(.+),(.+),(.+),(.+)\)", schdl)
150 | if match:
151 | init, final1, duration1, final2, duration2 = [
152 | float(g) for g in match.groups()
153 | ]
154 | if step <= duration1:
155 | mix = np.clip(step / duration1, 0.0, 1.0)
156 | return (1.0 - mix) * init + mix * final1
157 | else:
158 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
159 | return (1.0 - mix) * final1 + mix * final2
160 | raise NotImplementedError(schdl)
161 |
162 |
163 | class RandomShiftsAug(nn.Module):
164 | def __init__(self, pad):
165 | super().__init__()
166 | self.pad = pad
167 |
168 | def forward(self, x):
169 | n, c, h, w = x.size()
170 | assert h == w
171 | padding = tuple([self.pad] * 4)
172 | x = F.pad(x, padding, "replicate")
173 | eps = 1.0 / (h + 2 * self.pad)
174 | arange = torch.linspace(
175 | -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
176 | )[:h]
177 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
178 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
179 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
180 |
181 | shift = torch.randint(
182 | 0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
183 | )
184 | shift *= 2.0 / (h + 2 * self.pad)
185 |
186 | grid = base_grid + shift
187 | return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
188 |
189 |
190 | class TorchRunningMeanStd:
191 | def __init__(self, epsilon=1e-4, shape=(), device=None):
192 | self.mean = torch.zeros(shape, device=device)
193 | self.var = torch.ones(shape, device=device)
194 | self.count = epsilon
195 |
196 | def update(self, x):
197 | with torch.no_grad():
198 | batch_mean = torch.mean(x, axis=0)
199 | batch_var = torch.var(x, axis=0)
200 | batch_count = x.shape[0]
201 | self.update_from_moments(batch_mean, batch_var, batch_count)
202 |
203 | def update_from_moments(self, batch_mean, batch_var, batch_count):
204 | self.mean, self.var, self.count = update_mean_var_count_from_moments(
205 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count
206 | )
207 |
208 | @property
209 | def std(self):
210 | return torch.sqrt(self.var)
211 |
212 |
213 | def update_mean_var_count_from_moments(
214 | mean, var, count, batch_mean, batch_var, batch_count
215 | ):
216 | delta = batch_mean - mean
217 | tot_count = count + batch_count
218 |
219 |     new_mean = mean + delta * batch_count / tot_count  # weighted (parallel) mean update
220 | m_a = var * count
221 | m_b = batch_var * batch_count
222 | M2 = m_a + m_b + torch.pow(delta, 2) * count * batch_count / tot_count
223 | new_var = M2 / tot_count
224 | new_count = tot_count
225 |
226 | return new_mean, new_var, new_count
227 |
228 |
229 | def batch_norm_to_group_norm(layer):
230 | """Iterates over a whole model (or layer of a model) and replaces every batch norm 2D with a group norm
231 |
232 | Args:
233 | layer: model or one layer of a model like resnet34.layer1 or Sequential(), ...
234 | """
235 |
236 | # num_channels: num_groups
237 | GROUP_NORM_LOOKUP = {
238 | 16: 2, # -> channels per group: 8
239 | 32: 4, # -> channels per group: 8
240 | 64: 8, # -> channels per group: 8
241 | 128: 8, # -> channels per group: 16
242 | 256: 16, # -> channels per group: 16
243 | 512: 32, # -> channels per group: 16
244 | 1024: 32, # -> channels per group: 32
245 | 2048: 32, # -> channels per group: 64
246 | }
247 |
248 | for name, module in layer.named_modules():
249 | if name:
250 | try:
251 |                 # a dotted name like model.layer1.sequential.0.conv1 makes getattr fail; handled in the except below
252 | sub_layer = getattr(layer, name)
253 | if isinstance(sub_layer, torch.nn.BatchNorm2d):
254 | num_channels = sub_layer.num_features
255 | # first level of current layer or model contains a batch norm --> replacing.
256 | layer._modules[name] = torch.nn.GroupNorm(
257 | GROUP_NORM_LOOKUP[num_channels], num_channels
258 | )
259 | except AttributeError:
260 | # go deeper: set name to layer1, getattr will return layer1 --> call this func again
261 | name = name.split(".")[0]
262 | sub_layer = getattr(layer, name)
263 | sub_layer = batch_norm_to_group_norm(sub_layer)
264 | layer.__setattr__(name=name, value=sub_layer)
265 | return layer
266 |
--------------------------------------------------------------------------------
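Note on the schedule() helper above: it accepts either a plain number or a spec string of the
form linear(init,final,duration) / step_linear(init,final1,duration1,final2,duration2). A minimal
usage sketch follows; the spec string and step values are illustrative and not taken from this
repo's configs, and it assumes the repo root is on PYTHONPATH so that `utils` is importable.

# Illustrative only: anneal a coefficient with utils.schedule.
from utils import schedule

spec = "linear(1.0,0.1,10000)"  # start at 1.0, reach 0.1 after 10k steps, then stay there
for step in (0, 5000, 10000, 20000):
    print(step, schedule(spec, step))
# prints 1.0, 0.55, 0.1, 0.1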
/agent/networks/rgb_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the neural modules used to encode the spatial
3 | information of obs_t, i.e., an abstract representation of the current
4 | visual input, optionally conditioned on language.
5 | """
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision
10 |
11 | import utils
12 |
13 |
14 | ###############################################################################
15 | #
16 | # Modules related to encoding visual information (can be conditioned on language)
17 | #
18 | ###############################################################################
19 |
20 |
21 | class BaseEncoder(nn.Module):
22 | def __init__(self, obs_shape):
23 | super().__init__()
24 |
25 | assert len(obs_shape) == 3
26 | self.repr_dim = 512
27 |
28 | self.convnet = nn.Sequential(
29 | nn.Conv2d(obs_shape[0], 32, 3, stride=2),
30 | nn.ReLU(),
31 | nn.Conv2d(32, 32, 3, stride=1),
32 | nn.ReLU(),
33 | nn.Conv2d(32, 32, 3, stride=1),
34 | nn.ReLU(),
35 | nn.Conv2d(32, 32, 3, stride=1),
36 | nn.ReLU(),
37 | )
38 |
39 |         # flattened conv output size for the supported square input resolutions
40 |         conv_out_dims = {84: 39200, 128: 103968, 224: 352800}
41 |         if obs_shape[1] not in conv_out_dims:
42 |             raise ValueError(f"unsupported input size: {obs_shape[1]}")
43 |         dim = conv_out_dims[obs_shape[1]]
44 |
45 |         self.trunk = nn.Sequential(nn.Linear(dim, 512), nn.LayerNorm(512), nn.Tanh())
46 |
47 | self.apply(utils.weight_init)
48 |
49 | def forward(self, obs):
50 | obs = obs - 0.5
51 | h = self.convnet(obs)
52 | # h = h.view(h.shape[0], -1)
53 | h = h.reshape(h.shape[0], -1)
54 | h = self.trunk(h)
55 | return h
56 |
57 |
58 | class PatchEncoder(nn.Module):
59 | """
60 |     A patch encoder that applies a linear projection to patches of an RGB image.
61 | """
62 |
63 | def __init__(
64 | self, input_shape, patch_size=[16, 16], embed_size=64, no_patch_embed_bias=False
65 | ):
66 | super().__init__()
67 | C, H, W = input_shape
68 | num_patches = (H // patch_size[0] // 2) * (W // patch_size[1] // 2)
69 | self.img_size = (H, W)
70 | self.patch_size = patch_size
71 | self.num_patches = num_patches
72 | self.h, self.w = H // patch_size[0] // 2, W // patch_size[1] // 2
73 |
74 | self.conv = nn.Sequential(
75 | nn.Conv2d(
76 | C, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
77 | ),
78 | nn.BatchNorm2d(
79 | 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
80 | ),
81 | nn.ReLU(inplace=True),
82 | )
83 | self.proj = nn.Conv2d(
84 | 64,
85 | embed_size,
86 | kernel_size=patch_size,
87 | stride=patch_size,
88 | bias=False if no_patch_embed_bias else True,
89 | )
90 | self.bn = nn.BatchNorm2d(embed_size)
91 |
92 | def forward(self, x):
93 | B, C, H, W = x.shape
94 | x = self.conv(x)
95 | x = self.proj(x)
96 | x = self.bn(x)
97 | return x
98 |
99 |
100 | class SpatialSoftmax(nn.Module):
101 | """
102 | The spatial softmax layer (https://rll.berkeley.edu/dsae/dsae.pdf)
103 | """
104 |
105 |     def __init__(self, in_c, in_h, in_w, num_kp=None):
106 |         super().__init__()
107 |         # build the 1x1 keypoint projection only when num_kp is given
108 |         if num_kp is None:
109 |             self._num_kp = in_c
110 |             self._spatial_conv = None
111 |         else:
112 |             self._num_kp = num_kp
113 |             self._spatial_conv = nn.Conv2d(in_c, num_kp, kernel_size=1)
114 |
115 |         pos_x, pos_y = torch.meshgrid(
116 |             torch.linspace(-1, 1, in_w).float(),
117 |             torch.linspace(-1, 1, in_h).float(),
118 |         )
119 |         pos_x = pos_x.reshape(1, in_w * in_h)
120 |         pos_y = pos_y.reshape(1, in_w * in_h)
121 |         self.register_buffer("pos_x", pos_x)
122 |         self.register_buffer("pos_y", pos_y)
123 |
124 |         self._in_c = in_c
125 |         self._in_w = in_w
126 |         self._in_h = in_h
127 |
128 | def forward(self, x):
129 | assert x.shape[1] == self._in_c
130 | assert x.shape[2] == self._in_h
131 | assert x.shape[3] == self._in_w
132 |
133 | h = x
134 | if self._num_kp != self._in_c:
135 | h = self._spatial_conv(h)
136 | h = h.contiguous().view(-1, self._in_h * self._in_w)
137 |
138 | attention = F.softmax(h, dim=-1)
139 | keypoint_x = (
140 | (self.pos_x * attention).sum(1, keepdims=True).view(-1, self._num_kp)
141 | )
142 | keypoint_y = (
143 | (self.pos_y * attention).sum(1, keepdims=True).view(-1, self._num_kp)
144 | )
145 | keypoints = torch.cat([keypoint_x, keypoint_y], dim=1)
146 | return keypoints
147 |
148 |
149 | class SpatialProjection(nn.Module):
150 | def __init__(self, input_shape, out_dim):
151 | super().__init__()
152 |
153 | assert (
154 | len(input_shape) == 3
155 | ), "[error] spatial projection: input shape is not a 3-tuple"
156 | in_c, in_h, in_w = input_shape
157 | num_kp = out_dim // 2
158 | self.out_dim = out_dim
159 | self.spatial_softmax = SpatialSoftmax(in_c, in_h, in_w, num_kp=num_kp)
160 | self.projection = nn.Linear(num_kp * 2, out_dim)
161 |
162 | def forward(self, x):
163 | out = self.spatial_softmax(x)
164 | out = self.projection(out)
165 | return out
166 |
167 | def output_shape(self, input_shape):
168 | return input_shape[:-3] + (self.out_dim,)
169 |
170 |
171 | class ResnetEncoder(nn.Module):
172 | """
173 | A Resnet-18-based encoder for mapping an image to a latent vector
174 |
175 | Encode (f) an image into a latent vector.
176 |
177 | y = f(x), where
178 | x: (B, C, H, W)
179 | y: (B, H_out)
180 |
181 | Args:
182 | input_shape: (C, H, W), the shape of the image
183 | output_size: H_out, the latent vector size
184 |         pretrained: whether to use ImageNet-pretrained resnet weights
185 |         freeze: whether to freeze the pretrained resnet
186 |         remove_layer_num: number of top layers to remove
187 |         no_stride: disable striding in the first conv and pooling layers
188 | """
189 |
190 | def __init__(
191 | self,
192 | input_shape,
193 | output_size,
194 | pretrained=False,
195 | freeze=False,
196 | remove_layer_num=2,
197 | no_stride=False,
198 | cond_dim=768,
199 | cond_fusion="film",
200 | ):
201 | super().__init__()
202 |
203 | ### 1. encode input (images) using convolutional layers
204 | assert remove_layer_num <= 5, "[error] please only remove <=5 layers"
205 | layers = list(torchvision.models.resnet18(pretrained=pretrained).children())[
206 | :-remove_layer_num
207 | ]
208 | self.remove_layer_num = remove_layer_num
209 |
210 | assert (
211 | len(input_shape) == 3
212 | ), "[error] input shape of resnet should be (C, H, W)"
213 |
214 | in_channels = input_shape[0]
215 | if in_channels != 3: # has eye_in_hand, increase channel size
216 | conv0 = nn.Conv2d(
217 | in_channels=in_channels,
218 | out_channels=64,
219 | kernel_size=(7, 7),
220 | stride=(2, 2),
221 | padding=(3, 3),
222 | bias=False,
223 | )
224 | layers[0] = conv0
225 |
226 | self.no_stride = no_stride
227 | if self.no_stride:
228 | layers[0].stride = (1, 1)
229 | layers[3].stride = 1
230 |
231 | self.resnet18_base = nn.Sequential(*layers[:4])
232 | self.block_1 = layers[4][0]
233 | self.block_2 = layers[4][1]
234 | self.block_3 = layers[5][0]
235 | self.block_4 = layers[5][1]
236 |
237 | self.cond_fusion = cond_fusion
238 | if cond_fusion == "film":
239 | self.lang_proj1 = nn.Linear(cond_dim, 64 * 2)
240 | self.lang_proj2 = nn.Linear(cond_dim, 64 * 2)
241 | self.lang_proj3 = nn.Linear(cond_dim, 128 * 2)
242 | self.lang_proj4 = nn.Linear(cond_dim, 128 * 2)
243 |
244 | if freeze:
245 | if in_channels != 3:
246 | raise Exception(
247 | "[error] cannot freeze pretrained "
248 | + "resnet with the extra eye_in_hand input"
249 | )
250 |             backbone = [self.resnet18_base, self.block_1, self.block_2, self.block_3, self.block_4]
251 |             for param in nn.ModuleList(backbone).parameters():
252 |                 param.requires_grad = False
253 | ### 2. project the encoded input to a latent space
254 | x = torch.zeros(1, *input_shape)
255 | y = self.block_4(
256 | self.block_3(self.block_2(self.block_1(self.resnet18_base(x))))
257 | )
258 | output_shape = y.shape # compute the out dim
259 | self.projection_layer = SpatialProjection(output_shape[1:], output_size)
260 | self.output_shape = self.projection_layer(y).shape
261 |
262 | # Replace BatchNorm layers with GroupNorm
263 | self.resnet18_base = utils.batch_norm_to_group_norm(self.resnet18_base)
264 | self.block_1 = utils.batch_norm_to_group_norm(self.block_1)
265 | self.block_2 = utils.batch_norm_to_group_norm(self.block_2)
266 | self.block_3 = utils.batch_norm_to_group_norm(self.block_3)
267 | self.block_4 = utils.batch_norm_to_group_norm(self.block_4)
268 |
269 | # self.normlayer = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
270 |
271 | def forward(self, x, cond=None, return_intermediate=False):
272 | # # preprocess
273 | # preprocess = nn.Sequential(self.normlayer)
274 | # x = preprocess(x)
275 |
276 | h = self.resnet18_base(x)
277 |
278 | h = self.block_1(h)
279 | if cond is not None and self.cond_fusion == "film": # FiLM layer
280 | B, C, H, W = h.shape
281 | beta, gamma = torch.split(
282 | self.lang_proj1(cond).reshape(B, C * 2, 1, 1), [C, C], 1
283 | )
284 | h = (1 + gamma) * h + beta
285 |
286 | h = self.block_2(h)
287 | if cond is not None and self.cond_fusion == "film": # FiLM layer
288 | B, C, H, W = h.shape
289 | beta, gamma = torch.split(
290 | self.lang_proj2(cond).reshape(B, C * 2, 1, 1), [C, C], 1
291 | )
292 | h = (1 + gamma) * h + beta
293 |
294 | h = self.block_3(h)
295 | if cond is not None and self.cond_fusion == "film": # FiLM layer
296 | B, C, H, W = h.shape
297 | beta, gamma = torch.split(
298 | self.lang_proj3(cond).reshape(B, C * 2, 1, 1), [C, C], 1
299 | )
300 | h = (1 + gamma) * h + beta
301 |
302 | h = self.block_4(h)
303 | if cond is not None and self.cond_fusion == "film": # FiLM layer
304 | B, C, H, W = h.shape
305 | beta, gamma = torch.split(
306 | self.lang_proj4(cond).reshape(B, C * 2, 1, 1), [C, C], 1
307 | )
308 | h = (1 + gamma) * h + beta
309 |
310 | if not return_intermediate:
311 | h = self.projection_layer(h)
312 | return h
313 |
314 | def output_shape(self):
315 | return self.output_shape
316 |
--------------------------------------------------------------------------------
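A minimal sketch of exercising the FiLM-conditioned ResnetEncoder above on dummy data. The batch
size, image resolution, and conditioning vector are illustrative; it assumes torch/torchvision are
installed and the repo root is on PYTHONPATH so that `utils` resolves.

# Illustrative only: encode a dummy batch with FiLM conditioning.
import torch
from agent.networks.rgb_modules import ResnetEncoder

enc = ResnetEncoder(input_shape=(3, 128, 128), output_size=512, cond_fusion="film")
x = torch.randn(4, 3, 128, 128)   # batch of RGB frames
cond = torch.randn(4, 768)        # conditioning embedding (cond_dim defaults to 768)
z = enc(x, cond=cond)             # (4, 512): FiLM applied at each block, then SpatialProjection
feat = enc(x, cond=cond, return_intermediate=True)  # block_4 feature map before projection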
/train_bc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import warnings
4 | import os
5 |
6 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
7 | os.environ["MUJOCO_GL"] = "egl"
8 | from pathlib import Path
9 |
10 | import hydra
11 | import torch
12 | import numpy as np
13 |
14 | import utils
15 | from logger import Logger
16 | from replay_buffer import make_expert_replay_loader
17 | from video import VideoRecorder
18 |
19 | warnings.filterwarnings("ignore", category=DeprecationWarning)
20 | torch.backends.cudnn.benchmark = True
21 |
22 |
23 | def make_agent(obs_spec, action_spec, cfg):
24 | obs_shape = {}
25 | for key in cfg.suite.pixel_keys:
26 | obs_shape[key] = obs_spec[key].shape
27 | if cfg.use_aux_inputs:
28 | for key in cfg.suite.aux_keys:
29 | obs_shape[key] = obs_spec[key].shape
30 | obs_shape[cfg.suite.feature_key] = obs_spec[cfg.suite.feature_key].shape
31 | cfg.agent.obs_shape = obs_shape
32 | cfg.agent.action_shape = (
33 | action_spec.shape
34 | if cfg.suite.action_type == "continuous"
35 | else action_spec.num_values
36 | )
37 | return hydra.utils.instantiate(cfg.agent)
38 |
39 |
40 | class WorkspaceIL:
41 | def __init__(self, cfg):
42 | self.work_dir = Path.cwd()
43 | print(f"workspace: {self.work_dir}")
44 |
45 | self.cfg = cfg
46 | utils.set_seed_everywhere(cfg.seed)
47 | self.device = torch.device(cfg.device)
48 |
49 | # load data
50 | dataset_iterable = hydra.utils.call(self.cfg.expert_dataset)
51 | self.expert_replay_loader = make_expert_replay_loader(
52 | dataset_iterable, self.cfg.batch_size
53 | )
54 | self.expert_replay_iter = iter(self.expert_replay_loader)
55 | self.stats = self.expert_replay_loader.dataset.stats
56 |
57 | # create logger
58 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb)
59 | # create envs
60 | self.cfg.suite.task_make_fn.max_episode_len = (
61 | self.expert_replay_loader.dataset._max_episode_len
62 | )
63 | self.cfg.suite.task_make_fn.max_state_dim = (
64 | self.expert_replay_loader.dataset._max_state_dim
65 | )
66 | self.env, self.task_descriptions = hydra.utils.call(self.cfg.suite.task_make_fn)
67 |
68 | # create agent
69 | self.agent = make_agent(
70 | self.env[0].observation_spec(), self.env[0].action_spec(), cfg
71 | )
72 |
73 | # TODO: Make this compatible with no eval case
74 | self.envs_till_idx = self.expert_replay_loader.dataset.envs_till_idx
75 | print(f"envs_till_idx: {self.expert_replay_loader.dataset.envs_till_idx}")
76 |
77 | # Discretizer for BeT
78 |
79 | self.timer = utils.Timer()
80 | self._global_step = 0
81 | self._global_episode = 0
82 |
83 | self.video_recorder = VideoRecorder(
84 | self.work_dir if self.cfg.save_video else None
85 | )
86 |
87 | @property
88 | def global_step(self):
89 | return self._global_step
90 |
91 | @property
92 | def global_episode(self):
93 | return self._global_episode
94 |
95 | @property
96 | def global_frame(self):
97 | return self.global_step * self.cfg.suite.action_repeat
98 |
99 | def eval(self):
100 | self.agent.train(False)
101 | episode_rewards = []
102 | successes = []
103 |
104 | num_envs = (
105 | len(self.env) if self.cfg.suite.name == "calvin" else self.envs_till_idx
106 | )
107 |
108 | for env_idx in range(num_envs):
109 | print(f"evaluating env {env_idx}")
110 | episode, total_reward = 0, 0
111 | eval_until_episode = utils.Until(self.cfg.suite.num_eval_episodes)
112 | success = []
113 |
114 | while eval_until_episode(episode):
115 | time_step = self.env[env_idx].reset()
116 | self.agent.buffer_reset()
117 | step = 0
118 |
119 | # prompt
120 |                 if self.cfg.prompt is not None and self.cfg.prompt != "intermediate_goal":
121 | prompt = self.expert_replay_loader.dataset.sample_test(env_idx)
122 | else:
123 | prompt = None
124 |
125 | if episode == 0:
126 | self.video_recorder.init(self.env[env_idx], enabled=True)
127 |
128 | # plot obs with cv2
129 | while not time_step.last():
130 | if self.cfg.prompt == "intermediate_goal":
131 | prompt = self.expert_replay_loader.dataset.sample_test(
132 | env_idx, step
133 | )
134 | with torch.no_grad(), utils.eval_mode(self.agent):
135 | action = self.agent.act(
136 | time_step.observation,
137 | prompt,
138 | self.stats,
139 | step,
140 | self.global_step,
141 | eval_mode=True,
142 | )
143 | time_step = self.env[env_idx].step(action)
144 | self.video_recorder.record(self.env[env_idx])
145 | total_reward += time_step.reward
146 | step += 1
147 |
148 | episode += 1
149 | success.append(time_step.observation["goal_achieved"])
150 | self.video_recorder.save(f"{self.global_step}_env{env_idx}.mp4")
151 | episode_rewards.append(total_reward / episode)
152 | successes.append(np.mean(success))
153 |
154 | for _ in range(len(self.env) - num_envs):
155 | episode_rewards.append(0)
156 | successes.append(0)
157 |
158 | with self.logger.log_and_dump_ctx(self.global_step, ty="eval") as log:
159 | for env_idx, reward in enumerate(episode_rewards):
160 | log(f"episode_reward_env{env_idx}", reward)
161 | log(f"success_env{env_idx}", successes[env_idx])
162 | log("episode_reward", np.mean(episode_rewards[:num_envs]))
163 | log("success", np.mean(successes))
164 | log("episode_length", step * self.cfg.suite.action_repeat / episode)
165 | log("episode", self.global_episode)
166 | log("step", self.global_step)
167 |
168 | self.agent.train(True)
169 |
170 | def train(self):
171 | # predicates
172 | train_until_step = utils.Until(
173 | self.cfg.suite.num_train_steps, 1 # self.cfg.suite.action_repeat
174 | )
175 | log_every_step = utils.Every(
176 | self.cfg.suite.log_every_steps, 1 # self.cfg.suite.action_repeat
177 | )
178 | eval_every_step = utils.Every(
179 | self.cfg.suite.eval_every_steps, 1 # self.cfg.suite.action_repeat
180 | )
181 | save_every_step = utils.Every(
182 | self.cfg.suite.save_every_steps, 1 # self.cfg.suite.action_repeat
183 | )
184 |
185 | metrics = None
186 | while train_until_step(self.global_step):
187 | # try to evaluate
188 | if (
189 | self.cfg.eval
190 | and eval_every_step(self.global_step)
191 | and self.global_step > 0
192 | ):
193 | self.logger.log(
194 | "eval_total_time", self.timer.total_time(), self.global_frame
195 | )
196 | self.eval()
197 |
198 | # update
199 | metrics = self.agent.update(self.expert_replay_iter, self.global_step)
200 | self.logger.log_metrics(metrics, self.global_frame, ty="train")
201 |
202 | # log
203 | if log_every_step(self.global_step):
204 | elapsed_time, total_time = self.timer.reset()
205 | with self.logger.log_and_dump_ctx(self.global_frame, ty="train") as log:
206 | log("total_time", total_time)
207 | log("actor_loss", metrics["actor_loss"])
208 | log("step", self.global_step)
209 |
210 | # save snapshot
211 | if save_every_step(self.global_step):
212 | self.save_snapshot()
213 |
214 | # Update scene
215 | if self.cfg.sequential_train:
216 | if (
217 | self.global_step + 1
218 | ) % self.steps_till_next_scene == 0 and self.scene_idx < len(
219 | self.scene_names
220 | ) - 1:
221 | self.scene_idx = (self.scene_idx + 1) % len(self.scene_names)
222 | self.envs_till_idx += len(
223 | self.task_names[self.scene_names[self.scene_idx]]
224 | )
225 | self.expert_replay_loader.dataset.envs_till_idx = self.envs_till_idx
226 | self.steps_till_next_scene = (
227 | self.envs_till_idx * self.cfg.suite.num_train_steps_per_task
228 | )
229 | self.expert_replay_iter = iter(self.expert_replay_loader)
230 |
231 | # self.agent.reinit_optimizers()
232 |
233 | self._global_step += 1
234 |
235 | def save_snapshot(self):
236 | snapshot_dir = self.work_dir / "snapshot"
237 | snapshot_dir.mkdir(exist_ok=True)
238 | snapshot = snapshot_dir / f"{self.global_step}.pt"
239 | self.agent.clear_buffers()
240 | keys_to_save = ["timer", "_global_step", "_global_episode", "stats"]
241 | payload = {k: self.__dict__[k] for k in keys_to_save}
242 | payload.update(self.agent.save_snapshot())
243 | with snapshot.open("wb") as f:
244 | torch.save(payload, f)
245 |
246 | self.agent.buffer_reset()
247 |
248 | def load_snapshot(self, snapshots, encoder_only=False):
249 | # bc
250 | with snapshots["bc"].open("rb") as f:
251 | payload = torch.load(f)
252 | agent_payload = {}
253 | for k, v in payload.items():
254 | if k not in self.__dict__:
255 | agent_payload[k] = v
256 | if "vqvae" in snapshots:
257 | with snapshots["vqvae"].open("rb") as f:
258 | payload = torch.load(f)
259 | agent_payload["vqvae"] = payload
260 | self.agent.load_snapshot(agent_payload, encoder_only=encoder_only, eval=False)
261 | # self.agent.load_snapshot_eval(agent_payload)
262 |
263 |
264 | @hydra.main(config_path="cfgs", config_name="config", version_base=None)
265 | def main(cfg):
266 | from train_bc import WorkspaceIL as W
267 |
268 | workspace = W(cfg)
269 |
270 | # Load weights
271 | if cfg.load_bc:
272 | # BC weight
273 | snapshots = {}
274 | bc_snapshot = Path(cfg.bc_weight)
275 | if not bc_snapshot.exists():
276 | raise FileNotFoundError(f"bc weight not found: {bc_snapshot}")
277 | print(f"loading bc weight: {bc_snapshot}")
278 | snapshots["bc"] = bc_snapshot
279 | # vqvae_snapshot = Path(cfg.vqvae_weight)
280 | # if vqvae_snapshot.exists():
281 | # print(f"loading vqvae weight: {vqvae_snapshot}")
282 | # snapshots["vqvae"] = vqvae_snapshot
283 | workspace.load_snapshot(snapshots)
284 |
285 | workspace.train()
286 | # workspace.eval()
287 |
288 |
289 | if __name__ == "__main__":
290 | main()
291 |
--------------------------------------------------------------------------------
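Snapshots written by WorkspaceIL.save_snapshot land in <run dir>/snapshot/<global_step>.pt and
hold the workspace bookkeeping keys plus whatever agent.save_snapshot() adds. A small sketch for
inspecting one offline; the path and step number are hypothetical.

# Illustrative only: inspect a training snapshot on CPU.
# Newer PyTorch may need torch.load(..., weights_only=False) because the payload
# contains non-tensor objects (e.g. the Timer and the dataset stats).
import torch

payload = torch.load("snapshot/100000.pt", map_location="cpu")
print(sorted(payload.keys()))  # "timer", "_global_step", "_global_episode", "stats", plus agent entries
print("trained for", payload["_global_step"], "steps")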
/envs/xarm-env/xarm_env/envs/robot_env.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import spaces
3 | import cv2
4 | import numpy as np
5 |
6 | # import pybullet
7 | # import pybullet_data
8 | import pickle
9 | from scipy.spatial.transform import Rotation as R
10 |
11 | from openteach.utils.network import create_request_socket, ZMQCameraSubscriber
12 | from xarm_env.envs.constants import *
13 |
14 |
15 | def get_quaternion_orientation(cartesian):
16 | """
17 | Get quaternion orientation from axis angle representation
18 | """
19 | pos = cartesian[:3]
20 | ori = cartesian[3:]
21 | r = R.from_rotvec(ori)
22 | quat = r.as_quat()
23 | return np.concatenate([pos, quat], axis=-1)
24 |
25 |
26 | class RobotEnv(gym.Env):
27 | def __init__(
28 | self,
29 | height=224,
30 | width=224,
31 | use_robot=True, # True when robot used
32 | use_egocentric=False, # True when egocentric camera used
33 | use_fisheye=True,
34 | sensor_type="reskin",
35 | subtract_sensor_baseline=False,
36 | ):
37 | super(RobotEnv, self).__init__()
38 | self.height = height
39 | self.width = width
40 |
41 | self.use_robot = use_robot
42 | self.use_fish_eye = use_fisheye
43 | self.use_egocentric = use_egocentric
44 | self.sensor_type = sensor_type
45 | self.digit_keys = ["digit80", "digit81"]
46 |
47 | self.subtract_sensor_baseline = subtract_sensor_baseline
48 | self.sensor_prev_state = None
49 | self.sensor_baseline = None
50 |
51 | self.feature_dim = 8 # 10 # 7
52 | self.proprio_dim = 8
53 |
54 | self.n_sensors = 2
55 | self.sensor_dim = 15
56 | self.action_dim = 7
57 |
58 | # Robot limits
59 | # self.cartesian_delta_limits = np.array([-10, 10])
60 |
61 | self.n_channels = 3
62 | self.reward = 0
63 |
64 | self.observation_space = spaces.Box(
65 | low=0, high=255, shape=(height, width, self.n_channels), dtype=np.uint8
66 | )
67 | self.action_space = spaces.Box(
68 | low=0.0, high=1.0, shape=(self.action_dim,), dtype=np.float32
69 | )
70 |
71 | if self.use_robot:
72 | # camera subscribers
73 | self.image_subscribers = {}
74 | for cam_idx in list(CAM_SERIAL_NUMS.keys()):
75 | port = CAMERA_PORT_OFFSET + cam_idx
76 | self.image_subscribers[cam_idx] = ZMQCameraSubscriber(
77 | host=HOST_ADDRESS,
78 | port=port,
79 | topic_type="RGB",
80 | )
81 |
82 | # for fish_eye_cam_idx in range(len(FISH_EYE_CAM_SERIAL_NUMS)):
83 | if use_fisheye:
84 | for fish_eye_cam_idx in list(FISH_EYE_CAM_SERIAL_NUMS.keys()):
85 | port = FISH_EYE_CAMERA_PORT_OFFSET + fish_eye_cam_idx
86 | self.image_subscribers[fish_eye_cam_idx] = ZMQCameraSubscriber(
87 | host=HOST_ADDRESS,
88 | port=port,
89 | topic_type="RGB",
90 | )
91 |
92 | # action request port
93 | self.action_request_socket = create_request_socket(
94 | HOST_ADDRESS, DEPLOYMENT_PORT
95 | )
96 |
97 | def step(self, action):
98 | print("current step's action is: ", action)
99 | action = np.array(action)
100 |
101 | action_dict = {
102 | "xarm": {
103 | "cartesian": action[:-1],
104 | "gripper": action[-1:],
105 | }
106 | }
107 |
108 | # send action
109 | self.action_request_socket.send(pickle.dumps(action_dict, protocol=-1))
110 | ret = self.action_request_socket.recv()
111 | ret = pickle.loads(ret)
112 | if ret == "Command failed!":
113 | print("Command failed!")
114 | # return None, 0, True, None
115 | self.action_request_socket.send(b"get_state")
116 | ret = pickle.loads(self.action_request_socket.recv())
117 | # robot_state = pickle.loads(self.action_request_socket.recv())["robot_state"]["xarm"]
118 | # else:
119 | # # robot_state = ret["robot_state"]["xarm"]
120 | # robot_state = ret["robot_state"]["xarm"]
121 | robot_state = ret["robot_state"]["xarm"]
122 |
123 | # cartesian_pos = robot_state[:3]
124 | # cartesian_ori = robot_state[3:6]
125 | # gripper = robot_state[6]
126 | # cartesian_ori_sin = np.sin(cartesian_ori)
127 | # cartesian_ori_cos = np.cos(cartesian_ori)
128 | # robot_state = np.concatenate(
129 | # [cartesian_pos, cartesian_ori_sin, cartesian_ori_cos, [gripper]], axis=0
130 | # )
131 | cartesian = robot_state[:6]
132 | quat_cartesian = get_quaternion_orientation(cartesian)
133 | robot_state = np.concatenate([quat_cartesian, robot_state[6:]], axis=0)
134 |
135 | # subscribe images
136 | image_dict = {}
137 | for cam_idx, img_sub in self.image_subscribers.items():
138 | image_dict[cam_idx] = img_sub.recv_rgb_image()[0]
139 |
140 | obs = {}
141 | obs["features"] = np.array(robot_state, dtype=np.float32)
142 | obs["proprioceptive"] = np.array(robot_state, dtype=np.float32)
143 | if self.sensor_type == "reskin":
144 |             try:
145 |                 sensor_state = np.array(
146 |                     ret["sensor_state"]["reskin"]["sensor_values"],
147 |                     dtype=np.float32,
148 |                 )
149 |                 # the baseline is only measured in reset() when
150 |                 # subtract_sensor_baseline is enabled, so only subtract it
151 |                 # in that case (it is None otherwise)
152 |                 if self.subtract_sensor_baseline:
153 |                     sensor_state = sensor_state - self.sensor_baseline
154 |                 self.sensor_prev_state = sensor_state
155 |                 sensor_keys = [
156 |                     f"sensor{sensor_idx}" for sensor_idx in range(self.n_sensors)
157 |                 ]
158 |                 for sidx, sensor_key in enumerate(sensor_keys):
159 |                     obs[sensor_key] = sensor_state[
160 |                         sidx * self.sensor_dim : (sidx + 1) * self.sensor_dim
161 |                     ]
162 | except KeyError:
163 | pass
164 | elif self.sensor_type == "digit":
165 | for dkey in self.digit_keys:
166 | obs[dkey] = np.array(ret["sensor_state"][dkey])
167 | obs[dkey] = cv2.resize(obs[dkey], (self.width, self.height))
168 | if self.subtract_sensor_baseline:
169 | obs[dkey] = obs[dkey] - self.sensor_baseline
170 |
171 | for cam_idx, image in image_dict.items():
172 | if cam_idx == 52:
173 | # crop the right side of the image for the gripper cam
174 | img_shape = image.shape
175 | crop_percent = 0.2
176 | image = image[:, : int(img_shape[1] * (1 - crop_percent))]
177 | obs[f"pixels{cam_idx}"] = cv2.resize(image, (self.width, self.height))
178 | return obs, self.reward, False, False, {}
179 |
180 | def reset(self, seed=None): # currently same positions, with gripper opening
181 | if self.use_robot:
182 | print("resetting")
183 | self.action_request_socket.send(b"reset")
184 | reset_state = pickle.loads(self.action_request_socket.recv())
185 |
186 | # subscribe robot state
187 | self.action_request_socket.send(b"get_state")
188 | ret = pickle.loads(self.action_request_socket.recv())
189 | robot_state = ret["robot_state"]["xarm"]
190 | # robot_state = np.array(robot_state, dtype=np.float32)
191 | # cartesian_pos = robot_state[:3]
192 | # cartesian_ori = robot_state[3:6]
193 | # gripper = robot_state[6]
194 | # cartesian_ori_sin = np.sin(cartesian_ori)
195 | # cartesian_ori_cos = np.cos(cartesian_ori)
196 | # robot_state = np.concatenate(
197 | # [cartesian_pos, cartesian_ori_sin, cartesian_ori_cos, [gripper]], axis=0
198 | # )
199 | cartesian = robot_state[:6]
200 | quat_cartesian = get_quaternion_orientation(cartesian)
201 | robot_state = np.concatenate([quat_cartesian, robot_state[6:]], axis=0)
202 |
203 | # subscribe images
204 | image_dict = {}
205 | for cam_idx, img_sub in self.image_subscribers.items():
206 | image_dict[cam_idx] = img_sub.recv_rgb_image()[0]
207 |
208 | obs = {}
209 | obs["features"] = robot_state
210 | obs["proprioceptive"] = robot_state
211 | if self.sensor_type == "reskin":
212 | try:
213 | sensor_state = np.array(
214 | ret["sensor_state"]["reskin"]["sensor_values"]
215 | )
216 | # obs["sensor"] = np.array(sensor_state)
217 | if self.subtract_sensor_baseline:
218 | baseline_meas = []
219 | while len(baseline_meas) < 5:
220 | self.action_request_socket.send(b"get_sensor_state")
221 | ret = pickle.loads(self.action_request_socket.recv())
222 | sensor_state = ret["reskin"]["sensor_values"]
223 | baseline_meas.append(sensor_state)
224 | self.sensor_baseline = np.median(baseline_meas, axis=0)
225 | sensor_state = sensor_state - self.sensor_baseline
226 | self.sensor_prev_state = sensor_state
227 | sensor_keys = [
228 | f"sensor{sensor_idx}" for sensor_idx in range(self.n_sensors)
229 | ]
230 | for sidx, sensor_key in enumerate(sensor_keys):
231 | obs[sensor_key] = sensor_state[
232 | sidx * self.sensor_dim : (sidx + 1) * self.sensor_dim
233 | ]
234 | except KeyError:
235 | pass
236 | elif self.sensor_type == "digit":
237 | for dkey in self.digit_keys:
238 | obs[dkey] = np.array(ret["sensor_state"][dkey])
239 | obs[dkey] = cv2.resize(obs[dkey], (self.width, self.height))
240 | if self.subtract_sensor_baseline:
241 | baseline_meas = []
242 | while len(baseline_meas) < 5:
243 | self.action_request_socket.send(b"get_sensor_state")
244 | ret = pickle.loads(self.action_request_socket.recv())
245 | sensor_state = cv2.resize(
246 | ret[dkey], (self.width, self.height)
247 | )
248 | baseline_meas.append(sensor_state)
249 | self.sensor_baseline = np.median(baseline_meas, axis=0)
250 | obs[dkey] = sensor_state - self.sensor_baseline
251 | # obs["sensor"] = sensor_state - self.sensor_baseline
252 | for cam_idx, image in image_dict.items():
253 | if cam_idx == 52:
254 | # crop the right side of the image for the gripper cam
255 | img_shape = image.shape
256 | crop_percent = 0.2
257 | image = image[:, : int(img_shape[1] * (1 - crop_percent))]
258 | obs[f"pixels{cam_idx}"] = cv2.resize(image, (self.width, self.height))
259 |
260 | return obs
261 | else:
262 | obs = {}
263 | obs["features"] = np.zeros(self.feature_dim)
264 | obs["proprioceptive"] = np.zeros(self.proprio_dim)
265 | for sensor_idx in range(self.n_sensors):
266 | obs[f"sensor{sensor_idx}"] = np.zeros(self.sensor_dim)
267 | self.sensor_baseline = np.zeros(self.sensor_dim * self.n_sensors)
268 | obs["pixels"] = np.zeros((self.height, self.width, self.n_channels))
269 | return obs
270 |
271 | def render(self, mode="rgb_array", width=640, height=480):
272 | print("rendering")
273 | # subscribe images
274 | image_list = []
275 | for _, img_sub in self.image_subscribers.items():
276 | image = img_sub.recv_rgb_image()[0]
277 | image_list.append(cv2.resize(image, (width, height)))
278 |
279 | obs = np.concatenate(image_list, axis=1)
280 | return obs
281 |
282 |
283 | if __name__ == "__main__":
284 | env = RobotEnv()
285 | obs = env.reset()
286 | import ipdb
287 |
288 | ipdb.set_trace()
289 |
290 | for i in range(30):
291 | action = obs["features"]
292 | action[0] += 2
293 |         obs, reward, terminated, truncated, _ = env.step(action)
294 |
295 | for i in range(30):
296 | action = obs["features"]
297 | action[1] += 2
298 |         obs, reward, terminated, truncated, _ = env.step(action)
299 |
300 | for i in range(30):
301 | action = obs["features"]
302 | action[2] += 2
303 |         obs, reward, terminated, truncated, _ = env.step(action)
304 |
--------------------------------------------------------------------------------
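RobotEnv.step expects a flat action whose last entry is the gripper command; the rest is the
Cartesian target, and the whole thing is pickled into {"xarm": {"cartesian": ..., "gripper": ...}}
before being sent over the request socket. A sketch of that layout with placeholder numbers (not a
safe command for a real arm).

# Illustrative only: how a 7-D action maps onto the request sent to the robot server.
import numpy as np

cartesian = np.array([300.0, 0.0, 200.0, 3.14, 0.0, 0.0])  # placeholder pose; units/frame depend on the server
gripper = np.array([1.0])                                   # placeholder gripper command
action = np.concatenate([cartesian, gripper])               # shape (7,) == RobotEnv.action_dim
action_dict = {"xarm": {"cartesian": action[:-1], "gripper": action[-1:]}}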
/agent/networks/gpt.py:
--------------------------------------------------------------------------------
1 | """
2 | An adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
3 | Original source: https://github.com/karpathy/nanoGPT
4 |
5 | Original License:
6 | MIT License
7 |
8 | Copyright (c) 2022 Andrej Karpathy
9 |
10 | Permission is hereby granted, free of charge, to any person obtaining a copy
11 | of this software and associated documentation files (the "Software"), to deal
12 | in the Software without restriction, including without limitation the rights
13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14 | copies of the Software, and to permit persons to whom the Software is
15 | furnished to do so, subject to the following conditions:
16 |
17 | The above copyright notice and this permission notice shall be included in all
18 | copies or substantial portions of the Software.
19 |
20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26 | SOFTWARE.
27 |
28 | Original comments:
29 | Full definition of a GPT Language Model, all of it in this single file.
30 | References:
31 | 1) the official GPT-2 TensorFlow implementation released by OpenAI:
32 | https://github.com/openai/gpt-2/blob/master/src/model.py
33 | 2) huggingface/transformers PyTorch implementation:
34 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
35 | """
36 |
37 | import math
38 | from dataclasses import dataclass
39 |
40 | import torch
41 | import torch.nn as nn
42 | from torch.nn import functional as F
43 |
44 |
45 | # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
46 | def new_gelu(x):
47 | """
48 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
49 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
50 | """
51 | return (
52 | 0.5
53 | * x
54 | * (
55 | 1.0
56 | + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))
57 | )
58 | )
59 |
60 |
61 | class CrossAttention(nn.Module):
62 | def __init__(self, repr_dim, nhead=4, nlayers=4, use_buffer_token=False):
63 | super().__init__()
64 | self.tf_decoder = nn.TransformerDecoder(
65 | nn.TransformerDecoderLayer(
66 | d_model=repr_dim,
67 | nhead=nhead,
68 | dim_feedforward=repr_dim * 4,
69 | batch_first=True,
70 | ),
71 | num_layers=nlayers,
72 | )
73 | if use_buffer_token:
74 | self.buffer_token = nn.Parameter(torch.randn(1, 1, repr_dim))
75 | self.use_buffer_token = use_buffer_token
76 |
77 | def forward(self, feat, cond):
78 | # add buffer token to the beginning of the sequence
79 | if self.use_buffer_token:
80 | batch_size = feat.size(0)
81 | buffer_token = self.buffer_token.expand(batch_size, 1, -1)
82 | cond_with_buffer = torch.cat([buffer_token, cond], dim=1)
83 | return self.tf_decoder(feat, cond_with_buffer)
84 | else:
85 | return self.tf_decoder(feat, cond)
86 |
87 |
88 | class CausalSelfAttention(nn.Module):
89 | def __init__(self, config):
90 | super().__init__()
91 | assert config.n_embd % config.n_head == 0
92 | # key, query, value projections for all heads, but in a batch
93 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
94 | # output projection
95 | self.c_proj = nn.Linear(config.n_embd, config.n_embd)
96 | # regularization
97 | self.attn_dropout = nn.Dropout(config.dropout)
98 | self.resid_dropout = nn.Dropout(config.dropout)
99 | # causal mask to ensure that attention is only applied to the left in the input sequence
100 | self.register_buffer(
101 | "bias",
102 | # torch.ones(1, 1, config.block_size, config.block_size),
103 | torch.tril(torch.ones(config.block_size, config.block_size)).view(
104 | 1, 1, config.block_size, config.block_size
105 | ),
106 | )
107 | self.n_head = config.n_head
108 | self.n_embd = config.n_embd
109 |
110 | def forward(self, x, attn_mask=None):
111 | (
112 | B,
113 | T,
114 | C,
115 | ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
116 |
117 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
118 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
119 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(
120 | 1, 2
121 | ) # (B, nh, T, hs)
122 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(
123 | 1, 2
124 | ) # (B, nh, T, hs)
125 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(
126 | 1, 2
127 | ) # (B, nh, T, hs)
128 |
129 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
130 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
131 | mask = self.bias[:, :, :T, :T]
132 | if attn_mask is not None:
133 | mask = mask * attn_mask
134 | att = att.masked_fill(mask == 0, float("-inf"))
135 | att = F.softmax(att, dim=-1)
136 | att = self.attn_dropout(att)
137 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
138 | y = (
139 | y.transpose(1, 2).contiguous().view(B, T, C)
140 | ) # re-assemble all head outputs side by side
141 |
142 | # output projection
143 | y = self.resid_dropout(self.c_proj(y))
144 | return y
145 |
146 |
147 | class MLP(nn.Module):
148 | def __init__(self, config):
149 | super().__init__()
150 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
151 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
152 | self.dropout = nn.Dropout(config.dropout)
153 |
154 | def forward(self, x):
155 | x = self.c_fc(x)
156 | x = new_gelu(x)
157 | x = self.c_proj(x)
158 | x = self.dropout(x)
159 | return x
160 |
161 |
162 | class Block(nn.Module):
163 | def __init__(self, config):
164 | super().__init__()
165 | self.ln_1 = nn.LayerNorm(config.n_embd)
166 | self.attn = CausalSelfAttention(config)
167 | self.ln_2 = nn.LayerNorm(config.n_embd)
168 | self.mlp = MLP(config)
169 |
170 | def forward(self, x, mask=None):
171 | x = x + self.attn(self.ln_1(x), attn_mask=mask)
172 | x = x + self.mlp(self.ln_2(x))
173 | return x
174 |
175 |
176 | @dataclass
177 | class GPTConfig:
178 | block_size: int = 1024
179 | input_dim: int = 256
180 | output_dim: int = 256
181 | n_layer: int = 12
182 | n_head: int = 12
183 | n_embd: int = 768
184 | dropout: float = 0.1
185 |
186 |
187 | class GPT(nn.Module):
188 | def __init__(self, config):
189 | super().__init__()
190 | assert config.input_dim is not None
191 | assert config.output_dim is not None
192 | assert config.block_size is not None
193 | self.config = config
194 |
195 | self.transformer = nn.ModuleDict(
196 | dict(
197 | wte=nn.Linear(config.input_dim, config.n_embd),
198 | wpe=nn.Embedding(config.block_size, config.n_embd),
199 | drop=nn.Dropout(config.dropout),
200 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
201 | ln_f=nn.LayerNorm(config.n_embd),
202 | )
203 | )
204 | self.lm_head = nn.Linear(config.n_embd, config.output_dim, bias=False)
205 | # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
206 | self.apply(self._init_weights)
207 | for pn, p in self.named_parameters():
208 | if pn.endswith("c_proj.weight"):
209 | torch.nn.init.normal_(
210 | p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
211 | )
212 |
213 | # report number of parameters
214 | n_params = sum(p.numel() for p in self.parameters())
215 | print("number of parameters: %.2fM" % (n_params / 1e6,))
216 |
217 | def forward(self, input, targets=None, mask=None):
218 | device = input.device
219 | b, t, d = input.size()
220 | assert (
221 | t <= self.config.block_size
222 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
223 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
224 | 0
225 | ) # shape (1, t)
226 |
227 | # forward the GPT model itself
228 | tok_emb = self.transformer.wte(
229 | input
230 | ) # token embeddings of shape (b, t, n_embd)
231 | pos_emb = self.transformer.wpe(
232 | pos
233 | ) # position embeddings of shape (1, t, n_embd)
234 | x = self.transformer.drop(tok_emb + pos_emb)
235 | for block in self.transformer.h:
236 | x = block(x, mask=mask)
237 | x = self.transformer.ln_f(x)
238 | logits = self.lm_head(x)
239 | return logits
240 |
241 | def _init_weights(self, module):
242 | if isinstance(module, nn.Linear):
243 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
244 | if module.bias is not None:
245 | torch.nn.init.zeros_(module.bias)
246 | elif isinstance(module, nn.Embedding):
247 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
248 | elif isinstance(module, nn.LayerNorm):
249 | torch.nn.init.zeros_(module.bias)
250 | torch.nn.init.ones_(module.weight)
251 |
252 | def crop_block_size(self, block_size):
253 | assert block_size <= self.config.block_size
254 | self.config.block_size = block_size
255 | self.transformer.wpe.weight = nn.Parameter(
256 | self.transformer.wpe.weight[:block_size]
257 | )
258 | for block in self.transformer.h:
259 | block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
260 |
261 | def configure_optimizers(self, weight_decay, learning_rate, betas):
262 | """
263 | This long function is unfortunately doing something very simple and is being very defensive:
264 | We are separating out all parameters of the model into two buckets: those that will experience
265 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
266 | We are then returning the PyTorch optimizer object.
267 | """
268 |
269 | # separate out all parameters to those that will and won't experience regularizing weight decay
270 | decay = set()
271 | no_decay = set()
272 | whitelist_weight_modules = (torch.nn.Linear,)
273 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
274 | for mn, m in self.named_modules():
275 | for pn, p in m.named_parameters():
276 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
277 | if pn.endswith("bias"):
278 | # all biases will not be decayed
279 | no_decay.add(fpn)
280 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
281 | # weights of whitelist modules will be weight decayed
282 | decay.add(fpn)
283 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
284 | # weights of blacklist modules will NOT be weight decayed
285 | no_decay.add(fpn)
286 |
287 | # validate that we considered every parameter
288 | param_dict = {pn: p for pn, p in self.named_parameters()}
289 | inter_params = decay & no_decay
290 | union_params = decay | no_decay
291 | assert (
292 | len(inter_params) == 0
293 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
294 | assert (
295 | len(param_dict.keys() - union_params) == 0
296 | ), "parameters %s were not separated into either decay/no_decay set!" % (
297 | str(param_dict.keys() - union_params),
298 | )
299 |
300 | # create the pytorch optimizer object
301 | optim_groups = [
302 | {
303 | "params": [param_dict[pn] for pn in sorted(list(decay))],
304 | "weight_decay": weight_decay,
305 | },
306 | {
307 | "params": [param_dict[pn] for pn in sorted(list(no_decay))],
308 | "weight_decay": 0.0,
309 | },
310 | ]
311 | # optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
312 | optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas)
313 | return optimizer
314 |
--------------------------------------------------------------------------------
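Unlike a language model, the GPT above consumes continuous feature tokens (wte is a Linear
projection, not an embedding lookup) and returns one output_dim vector per position. A minimal
forward-pass sketch with hypothetical sizes; it assumes the repo root is on PYTHONPATH.

# Illustrative only: run the GPT trunk on a dummy token sequence.
import torch
from agent.networks.gpt import GPT, GPTConfig

cfg = GPTConfig(block_size=64, input_dim=512, output_dim=512, n_layer=4, n_head=4, n_embd=256)
model = GPT(cfg)
tokens = torch.randn(2, 10, 512)  # (batch, sequence length <= block_size, input_dim)
out = model(tokens)               # (2, 10, 512): per-token outputs of size output_dim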
/suite/xarm_env.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from typing import Any, NamedTuple
3 |
4 | import gym
5 | from gym import Wrapper, spaces
6 |
7 | import xarm_env
8 | import dm_env
9 | import numpy as np
10 | from dm_env import StepType, specs, TimeStep
11 |
12 | import cv2
13 |
14 |
15 | class RGBArrayAsObservationWrapper(dm_env.Environment):
16 | """
17 |     Use env.render(rgb_array) as the observation
18 |     rather than the observation the environment provides.
19 |
20 | From: https://github.com/hill-a/stable-baselines/issues/915
21 | """
22 |
23 | def __init__(
24 | self,
25 | env,
26 | # height,
27 | # width,
28 | max_episode_len=300,
29 | max_state_dim=100,
30 | task_description="",
31 | pixel_keys=["pixels0"],
32 | aux_keys=["proprioceptive"],
33 | use_robot=True,
34 | ):
35 | self._env = env
36 | # self._width = width
37 | # self._height = height
38 | self._max_episode_len = max_episode_len
39 | self._max_state_dim = max_state_dim
40 | self._task_description = task_description
41 | self.pixel_keys = pixel_keys
42 | self.aux_keys = aux_keys
43 | self.use_robot = use_robot
44 |
45 | # task emb
46 |
47 | obs = self._env.reset()
48 | if self.use_robot:
49 | pixels = obs[pixel_keys[0]]
50 | self.observation_space = spaces.Box(
51 | low=0, high=255, shape=pixels.shape, dtype=pixels.dtype
52 | )
53 |
54 | # Action spec
55 | action_spec = self._env.action_space
56 | # self._action_spec = specs.BoundedArray(
57 | # action_spec[0].shape, np.float32, action_spec[0], action_spec[1], "action"
58 | # )
59 | self._action_spec = specs.Array(
60 | shape=action_spec.shape, dtype=action_spec.dtype, name="action"
61 | )
62 | # Observation spec
63 | # robot_state = np.concatenate(
64 | # [obs["robot0_joint_pos"], obs["robot0_gripper_qpos"]]
65 | # )
66 | features = obs["features"]
67 | self._obs_spec = {}
68 | for key in pixel_keys:
69 | self._obs_spec[key] = specs.BoundedArray(
70 | shape=obs[key].shape,
71 | dtype=np.uint8,
72 | minimum=0,
73 | maximum=255,
74 | name=key,
75 | )
76 | else:
77 | pixels, features = obs["pixels"], obs["features"]
78 | self.observation_space = spaces.Box(
79 | low=0, high=255, shape=pixels.shape, dtype=pixels.dtype
80 | )
81 |
82 | # Action spec
83 | action_spec = self._env.action_space
84 | self._action_spec = specs.Array(
85 | shape=action_spec.shape, dtype=action_spec.dtype, name="action"
86 | )
87 |
88 | # Observation spec
89 | self._obs_spec = {}
90 | for key in pixel_keys:
91 | self._obs_spec[key] = specs.BoundedArray(
92 | shape=pixels.shape,
93 | dtype=np.uint8,
94 | minimum=0,
95 | maximum=255,
96 | name=key,
97 | )
98 |
99 | self._obs_spec["proprioceptive"] = specs.BoundedArray(
100 | shape=features.shape,
101 | dtype=np.float32,
102 | minimum=-np.inf,
103 | maximum=np.inf,
104 | name="proprioceptive",
105 | )
106 | self._obs_spec["features"] = specs.BoundedArray(
107 | shape=(self._max_state_dim,),
108 | dtype=np.float32,
109 | minimum=-np.inf,
110 | maximum=np.inf,
111 | name="features",
112 | )
113 |
114 | for key in aux_keys:
115 | if key.startswith("sensor"):
116 | self._obs_spec[key] = specs.BoundedArray(
117 | shape=obs[key].shape,
118 | dtype=np.float32,
119 | minimum=-np.inf,
120 | maximum=np.inf,
121 | name=key,
122 | )
123 | if key.startswith("digit"):
124 | self._obs_spec[key] = specs.BoundedArray(
125 | shape=pixels.shape,
126 | dtype=np.uint8,
127 | minimum=0,
128 | maximum=255,
129 | name=key,
130 | )
131 | # if "sensor" in aux_keys:
132 | # self._obs_spec["sensor"] = specs.BoundedArray(
133 | # shape=obs["sensor"].shape,
134 | # dtype=np.float32,
135 | # minimum=-np.inf,
136 | # maximum=np.inf,
137 | # name="sensor",
138 | # )
139 |
140 | self.render_image = None
141 |
142 | def reset(self, **kwargs):
143 | self._step = 0
144 | obs = self._env.reset(**kwargs)
145 |
146 | observation = {}
147 | for key in self.pixel_keys:
148 | observation[key] = obs[key]
149 | observation["proprioceptive"] = obs["features"]
150 | observation["features"] = obs["features"]
151 | for key in self.aux_keys:
152 | if key.startswith("sensor"):
153 | observation[key] = obs[key]
154 | observation["goal_achieved"] = False
155 | return observation
156 |
157 | def step(self, action):
158 | self._step += 1
159 | obs, reward, truncated, terminated, info = self._env.step(action)
160 | done = truncated or terminated
161 |
162 | observation = {}
163 | for key in self.pixel_keys:
164 | observation[key] = obs[key]
165 | observation["proprioceptive"] = obs["features"]
166 | observation["features"] = obs["features"]
167 | if "sensor" in self.aux_keys:
168 | observation["sensor"] = obs["sensor"]
169 | for key in self.aux_keys:
170 | if key.startswith("sensor"):
171 | observation[key] = obs[key]
172 | observation["goal_achieved"] = done # (self._step == self._max_episode_len)
173 | return observation, reward, done, info
174 |
175 | def observation_spec(self):
176 | return self._obs_spec
177 |
178 | def action_spec(self):
179 | return self._action_spec
180 |
181 | def render(self, mode="rgb_array", width=256, height=256):
182 | return cv2.resize(self._env.render("rgb_array"), (width, height))
183 |
184 | def __getattr__(self, name):
185 | return getattr(self._env, name)
186 |
187 |
188 | class ActionRepeatWrapper(dm_env.Environment):
189 | def __init__(self, env, num_repeats):
190 | self._env = env
191 | self._num_repeats = num_repeats
192 |
193 | def step(self, action):
194 | reward = 0.0
195 | discount = 1.0
196 | for i in range(self._num_repeats):
197 | time_step = self._env.step(action)
198 | reward += (time_step.reward or 0.0) * discount
199 | discount *= time_step.discount
200 | if time_step.last():
201 | break
202 |
203 | return time_step._replace(reward=reward, discount=discount)
204 |
205 | def observation_spec(self):
206 | return self._env.observation_spec()
207 |
208 | def action_spec(self):
209 | return self._env.action_spec()
210 |
211 | def reset(self, **kwargs):
212 | return self._env.reset(**kwargs)
213 |
214 | def __getattr__(self, name):
215 | return getattr(self._env, name)
216 |
217 |
218 | class FrameStackWrapper(dm_env.Environment):
219 | def __init__(self, env, num_frames):
220 | self._env = env
221 | self._num_frames = num_frames
222 |
223 | self.pixel_keys = [
224 | keys for keys in env.observation_spec().keys() if "pixels" in keys
225 | ]
226 | wrapped_obs_spec = env.observation_spec()[self.pixel_keys[0]]
227 |
228 | # frames lists
229 | self._frames = {}
230 | for key in self.pixel_keys:
231 | self._frames[key] = deque([], maxlen=num_frames)
232 |
233 | pixels_shape = wrapped_obs_spec.shape
234 | if len(pixels_shape) == 4:
235 | pixels_shape = pixels_shape[1:]
236 | self._obs_spec = {}
237 | self._obs_spec["features"] = self._env.observation_spec()["features"]
238 | self._obs_spec["proprioceptive"] = self._env.observation_spec()[
239 | "proprioceptive"
240 | ]
241 | for key in self._env.observation_spec().keys():
242 | if key.startswith("sensor"):
243 | self._obs_spec[key] = self._env.observation_spec()[key]
244 | if key.startswith("digit"):
245 | self._obs_spec[key] = self._env.observation_spec()[key]
246 | for key in self.pixel_keys:
247 | self._obs_spec[key] = specs.BoundedArray(
248 | shape=np.concatenate(
249 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0
250 | ),
251 | dtype=np.uint8,
252 | minimum=0,
253 | maximum=255,
254 | name=key,
255 | )
256 |
257 | def _transform_observation(self, time_step):
258 | for key in self.pixel_keys:
259 | assert len(self._frames[key]) == self._num_frames
260 | obs = {}
261 | obs["features"] = time_step.observation["features"]
262 | for key in self.pixel_keys:
263 | obs[key] = np.concatenate(list(self._frames[key]), axis=0)
264 | obs["proprioceptive"] = time_step.observation["proprioceptive"]
265 | try:
266 | for key in time_step.observation.keys():
267 | if key.startswith("sensor"):
268 | obs[key] = time_step.observation[key]
269 | except KeyError:
270 | pass
271 | obs["goal_achieved"] = time_step.observation["goal_achieved"]
272 | return time_step._replace(observation=obs)
273 |
274 | def _extract_pixels(self, time_step):
275 | pixels = {}
276 | for key in self.pixel_keys:
277 | pixels[key] = time_step.observation[key]
278 | if len(pixels[key].shape) == 4:
279 | pixels[key] = pixels[key][0]
280 | pixels[key] = pixels[key].transpose(2, 0, 1)
281 | return pixels
282 |
283 | def reset(self, **kwargs):
284 | time_step = self._env.reset(**kwargs)
285 | pixels = self._extract_pixels(time_step)
286 | for key in self.pixel_keys:
287 | for _ in range(self._num_frames):
288 | self._frames[key].append(pixels[key])
289 | return self._transform_observation(time_step)
290 |
291 | def step(self, action):
292 | time_step = self._env.step(action)
293 | pixels = self._extract_pixels(time_step)
294 | for key in self.pixel_keys:
295 | self._frames[key].append(pixels[key])
296 | return self._transform_observation(time_step)
297 |
298 | def observation_spec(self):
299 | return self._obs_spec
300 |
301 | def action_spec(self):
302 | return self._env.action_spec()
303 |
304 | def __getattr__(self, name):
305 | return getattr(self._env, name)
306 |
307 |
308 | class ActionDTypeWrapper(dm_env.Environment):
309 | def __init__(self, env, dtype):
310 | self._env = env
311 | self._discount = 1.0
312 |
313 | # Action spec
314 | wrapped_action_spec = env.action_spec()
315 | self._action_spec = specs.Array(
316 | shape=wrapped_action_spec.shape, dtype=dtype, name="action"
317 | )
318 |
319 | def step(self, action):
320 | action = action.astype(self._env.action_spec().dtype)
321 | # Make time step for action space
322 | observation, reward, done, info = self._env.step(action)
323 | step_type = StepType.LAST if done else StepType.MID
324 |
325 | return TimeStep(
326 | step_type=step_type,
327 | reward=reward,
328 | discount=self._discount,
329 | observation=observation,
330 | )
331 |
332 | def observation_spec(self):
333 | return self._env.observation_spec()
334 |
335 | def action_spec(self):
336 | return self._action_spec
337 |
338 | def reset(self, **kwargs):
339 | obs = self._env.reset(**kwargs)
340 | return TimeStep(
341 | step_type=StepType.FIRST, reward=0, discount=self._discount, observation=obs
342 | )
343 |
344 | def __getattr__(self, name):
345 | return getattr(self._env, name)
346 |
347 |
348 | class ExtendedTimeStep(NamedTuple):
349 | step_type: Any
350 | reward: Any
351 | discount: Any
352 | observation: Any
353 | action: Any
354 |
355 | def first(self):
356 | return self.step_type == StepType.FIRST
357 |
358 | def mid(self):
359 | return self.step_type == StepType.MID
360 |
361 | def last(self):
362 | return self.step_type == StepType.LAST
363 |
364 | def __getitem__(self, attr):
365 | return getattr(self, attr)
366 |
367 |
368 | class ExtendedTimeStepWrapper(dm_env.Environment):
369 | def __init__(self, env):
370 | self._env = env
371 |
372 | def reset(self, **kwargs):
373 | time_step = self._env.reset(**kwargs)
374 | return self._augment_time_step(time_step)
375 |
376 | def step(self, action):
377 | time_step = self._env.step(action)
378 | return self._augment_time_step(time_step, action)
379 |
380 | def _augment_time_step(self, time_step, action=None):
381 | if action is None:
382 | action_spec = self.action_spec()
383 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
384 | return ExtendedTimeStep(
385 | observation=time_step.observation,
386 | step_type=time_step.step_type,
387 | action=action,
388 | reward=time_step.reward or 0.0,
389 | discount=time_step.discount or 1.0,
390 | )
391 |
392 | def _replace(
393 | self, time_step, observation=None, action=None, reward=None, discount=None
394 | ):
395 | if observation is None:
396 | observation = time_step.observation
397 | if action is None:
398 | action = time_step.action
399 | if reward is None:
400 | reward = time_step.reward
401 | if discount is None:
402 | discount = time_step.discount
403 | return ExtendedTimeStep(
404 | observation=observation,
405 | step_type=time_step.step_type,
406 | action=action,
407 | reward=reward,
408 | discount=discount,
409 | )
410 |
411 | def observation_spec(self):
412 | return self._env.observation_spec()
413 |
414 | def action_spec(self):
415 | return self._env.action_spec()
416 |
417 | def __getattr__(self, name):
418 | return getattr(self._env, name)
419 |
420 |
421 | def make(
422 | frame_stack,
423 | action_repeat,
424 | height,
425 | width,
426 | max_episode_len,
427 | max_state_dim,
428 | use_egocentric,
429 | use_fisheye,
430 | task_description,
431 | pixel_keys,
432 | aux_keys,
433 | sensor_params,
434 | eval, # True means use_robot=True
435 | ):
436 | env = gym.make(
437 | "Robot-v1",
438 | height=height,
439 | width=width,
440 | use_robot=eval,
441 | use_egocentric=use_egocentric,
442 | use_fisheye=use_fisheye,
443 | subtract_sensor_baseline=sensor_params.subtract_sensor_baseline,
444 | )
445 |
446 | # apply wrappers
447 | env = RGBArrayAsObservationWrapper(
448 | env,
449 | max_episode_len=max_episode_len,
450 | max_state_dim=max_state_dim,
451 | task_description=task_description,
452 | pixel_keys=pixel_keys,
453 | aux_keys=aux_keys,
454 | use_robot=eval,
455 | )
456 | env = ActionDTypeWrapper(env, np.float32)
457 | env = ActionRepeatWrapper(env, action_repeat)
458 | env = FrameStackWrapper(env, frame_stack)
459 | env = ExtendedTimeStepWrapper(env)
460 |
461 | return [env], [task_description]
462 |
--------------------------------------------------------------------------------
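make() stacks the wrappers in the order RGBArrayAsObservationWrapper -> ActionDTypeWrapper ->
ActionRepeatWrapper -> FrameStackWrapper -> ExtendedTimeStepWrapper. A sketch of calling it
directly, outside Hydra, with placeholder arguments; it assumes the xarm_env package and its
openteach dependency are installed and the repo root is on PYTHONPATH. With eval=False the
underlying RobotEnv runs in its offline, zero-observation mode and exposes the single "pixels" key.

# Illustrative only: build the wrapped environment without Hydra.
from types import SimpleNamespace
from suite.xarm_env import make

envs, descriptions = make(
    frame_stack=1,
    action_repeat=1,
    height=224,
    width=224,
    max_episode_len=400,
    max_state_dim=8,
    use_egocentric=False,
    use_fisheye=False,
    task_description="placeholder task",
    pixel_keys=["pixels"],          # with a real robot these are per-camera keys, e.g. "pixels1"
    aux_keys=["proprioceptive"],
    sensor_params=SimpleNamespace(subtract_sensor_baseline=False),
    eval=False,                     # eval=False -> RobotEnv(use_robot=False), no hardware needed
)
time_step = envs[0].reset()         # dm_env-style TimeStep with stacked "pixels" frames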
/process_xarm_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import pickle
4 | import cv2
5 | import pandas as pd
6 | import os
7 | import subprocess
8 | import os
9 | import re
10 | import shutil
11 | import h5py
12 | from pathlib import Path
13 |
14 | from utils import DATA_DIR
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--task-name", "-t", type=str, required=True)
18 | args = parser.parse_args()
19 |
20 | TASK_NAME = args.task_name
21 |
22 | DATA_PATH = Path(DATA_DIR)
23 | SAVE_PATH = Path(DATA_DIR) / "processed_data"
24 |
25 | num_demos = None
26 | cam_indices = {
27 | 1: "rgb",
28 | 2: "rgb",
29 | 51: "fish_eye",
30 | 52: "fish_eye",
31 | }
32 | states_file_name = "states"
33 | sensor_file_name = "sensor"
34 |
35 | # Create the save path
36 | SAVE_PATH.mkdir(parents=True, exist_ok=True)
37 | done_flag = True
38 | skip_cam_processing = False
39 | process_type = "cont"
40 |
41 | print(f"#################### Processing task {TASK_NAME} ####################")
42 |
43 | # Check if previous demos from this task exist
44 | if Path(f"{SAVE_PATH}/{TASK_NAME}").exists():
45 | num_prev_demos = len([f for f in (SAVE_PATH / TASK_NAME).iterdir() if f.is_dir()])
46 | if num_prev_demos > 0:
47 | cont_check = input(
48 | f"Previous demonstrations from task {TASK_NAME} exist. Continue from existing demos? y/n."
49 | )
50 | if cont_check == "n":
51 | ow_check = input(
52 | f"Overwrite existing demonstrations from task {TASK_NAME}? y/n."
53 | )
54 | if ow_check == "y":
55 | num_prev_demos = 0
56 | else:
57 | print("Appending new demonstrations to the existing ones.")
58 | process_type = "append"
59 | elif cont_check == "y":
60 | num_prev_demos -= 1 # overwrite the last demo
61 |
62 | else:
63 | num_prev_demos = 0
64 | (SAVE_PATH / TASK_NAME).mkdir(parents=True, exist_ok=True)
65 |
66 | # demo directories
67 | DEMO_DIRS = [
68 | f
69 | for f in (DATA_PATH / TASK_NAME).iterdir()
70 | if f.is_dir() and "fail" not in f.name and "ignore" not in f.name
71 | ]
72 | if num_demos is not None:
73 | DEMO_DIRS = DEMO_DIRS[:num_demos]
74 |
75 | for num, demo_dir in enumerate(sorted(DEMO_DIRS)):
76 | process_sensor = True
77 | # try:
78 | if process_type == "cont" and num < num_prev_demos:
79 | print(f"Skipping demonstration {demo_dir.name}")
80 | continue
81 | if process_type == "append":
82 | demo_id = num + num_prev_demos
83 | elif process_type == "cont":
84 | demo_id = int(demo_dir.name.split("_")[-1])
85 | print("Processing demonstration", demo_dir.name)
86 | output_path = f"{SAVE_PATH}/{TASK_NAME}/demonstration_{demo_id}/"
87 | print("Output path:", output_path)
88 | Path(output_path).mkdir(parents=True, exist_ok=True)
89 | csv_list = [f for f in os.listdir(output_path) if f.endswith(".csv")]
90 | for f in csv_list:
91 | os.remove(os.path.join(output_path, f))
92 | cam_avis = [f"{demo_dir}/cam_{i}_{cam_indices[i]}_video.avi" for i in cam_indices]
93 |
94 | try:
95 | cartesian = h5py.File(f"{demo_dir}/xarm_cartesian_states.h5", "r")
96 | state_timestamps = cartesian["timestamps"]
97 | state_positions = cartesian["cartesian_positions"]
98 |
99 | gripper = h5py.File(f"{demo_dir}/xarm_gripper_states.h5", "r")
100 | gripper_positions = gripper["gripper_positions"]
101 |     except (OSError, KeyError):
102 | print("No cartesian or gripper states found. Skipping this demo.")
103 | continue
104 |
105 | try:
106 | with h5py.File(f"{demo_dir}/reskin_sensor_values.h5") as hf:
107 | sensor_timestamps = np.array(hf["timestamps"])
108 | sensor_values = np.array(hf["sensor_values"])
109 | except FileNotFoundError:
110 | print("No sensor values found. Skipping sensor processing.")
111 | process_sensor = False
112 |
113 | state_positions = np.array(state_positions)
114 | gripper_positions = np.array(gripper_positions)
115 | gripper_positions = gripper_positions.reshape(-1, 1)
116 |
117 | state_timestamps = np.array(state_timestamps)
118 |
119 | # Find indices of timestamps where the robot moves
120 | static_timestamps = []
121 | static = False
122 | start, end = None, None
123 | for i in range(1, len(state_positions) - 1):
124 | if (
125 | np.array_equal(state_positions[i], state_positions[i + 1])
126 |             and not static
127 | ):
128 | static = True
129 | start = i
130 | elif (
131 | not np.array_equal(state_positions[i], state_positions[i + 1])
132 |             and static
133 | ):
134 | static = False
135 | end = i
136 | static_timestamps.append((start, end))
137 | if static:
138 | static_timestamps.append((start, len(state_positions) - 1))
139 |
140 | # read metadata file
141 | CAM_TIMESTAMPS = []
142 | CAM_VALID_LENS = []
143 | skip = False
144 | for idx in cam_indices:
145 | cam_meta_file_path = f"{demo_dir}/cam_{idx}_{cam_indices[idx]}_video.metadata"
146 | with open(cam_meta_file_path, "rb") as f:
147 | image_metadata = pickle.load(f)
148 | image_timestamps = np.asarray(image_metadata["timestamps"]) / 1000.0
149 |
150 | cam_timestamps = dict(timestamps=image_timestamps)
151 | # convert to numpy array
152 | cam_timestamps = np.array(cam_timestamps["timestamps"])
153 |
154 | # Fish eye cam timestamps are divided by 1000
155 | if max(cam_timestamps) < state_timestamps[static_timestamps[0][1]]:
156 | cam_timestamps *= 1000
157 | elif min(cam_timestamps) > state_timestamps[static_timestamps[-1][0]]:
158 | cam_timestamps /= 1000
159 |
160 | valid_indices = []
161 | for k in range(len(static_timestamps) - 1):
162 | start_idx = sum(cam_timestamps < state_timestamps[static_timestamps[k][1]])
163 | end_idx = sum(
164 | cam_timestamps < state_timestamps[static_timestamps[k + 1][0]]
165 | )
166 | valid_indices.extend([i for i in range(start_idx, end_idx)])
167 | cam_timestamps = cam_timestamps[valid_indices]
168 |
169 | # if no valid timestamps, skip
170 | if len(cam_timestamps) == 0:
171 | skip = True
172 | break
173 |
174 | CAM_VALID_LENS.append(valid_indices)
175 | CAM_TIMESTAMPS.append(cam_timestamps)
176 | if skip:
177 | continue
178 |
179 | # cam frames
180 | if not skip_cam_processing:
181 | CAM_FRAMES = []
182 | for idx in range(len(cam_avis)): # cam_indices:
183 | cam_avi = cam_avis[idx]
184 | cam_frames = []
185 | cap_cap = cv2.VideoCapture(cam_avi)
186 | while cap_cap.isOpened():
187 | ret, frame = cap_cap.read()
188 |                 if not ret:
189 | break
190 | cam_frames.append(frame)
191 | cap_cap.release()
192 |
193 | # save frames
194 | cam_frames = np.array(cam_frames)
195 | cam_frames = cam_frames[CAM_VALID_LENS[idx]]
196 | CAM_FRAMES.append(cam_frames)
197 |
198 | rgb_frames = CAM_FRAMES
199 | timestamps = CAM_TIMESTAMPS
200 | timestamps.append(state_timestamps)
201 |
202 | min_time_index = np.argmin([len(timestamp) for timestamp in timestamps])
203 | reference_timestamps = timestamps[min_time_index]
204 | align = []
205 | index = []
206 | for i in range(len(timestamps)):
207 | # aligning frames
208 | if i == min_time_index:
209 | align.append(timestamps[i])
210 | index.append(np.arange(len(timestamps[i])))
211 | continue
212 | curindex = []
213 | currrlist = []
214 | for j in range(len(reference_timestamps)):
215 | curlist = []
216 | for k in range(len(timestamps[i])):
217 | curlist.append(abs(timestamps[i][k] - reference_timestamps[j]))
218 | min_index = curlist.index(min(curlist))
219 | currrlist.append(timestamps[i][min_index])
220 | curindex.append(min_index)
221 | align.append(currrlist)
222 | index.append(curindex)
223 |
224 | index = np.array(index)
225 |
226 | if process_sensor:
227 | if max(sensor_timestamps) < state_timestamps[static_timestamps[0][1]]:
228 | print("Sensor data issue")
229 | elif min(sensor_timestamps) > state_timestamps[static_timestamps[-1][0]]:
230 | print("Sensor data issue")
231 | else:
232 | print("All good with sensor data!")
233 | sensor_valid_indices = []
234 | for k in range(len(static_timestamps) - 1):
235 | start_idx = sum(
236 | sensor_timestamps < state_timestamps[static_timestamps[k][1]]
237 | )
238 | end_idx = sum(
239 | sensor_timestamps < state_timestamps[static_timestamps[k + 1][0]]
240 | )
241 | sensor_valid_indices.extend([i for i in range(start_idx, end_idx)])
242 |
243 | sensor_timestamps = sensor_timestamps[sensor_valid_indices]
244 | sensor_values = sensor_values[sensor_valid_indices]
245 |
246 | sensor_timestamps_test = pd.DataFrame(sensor_timestamps)
247 | sensor_values_test = []
248 | for i in range(len(sensor_values)):
249 | sensor_values_test.append(np.array(sensor_values[i]))
250 | sensor_values_test = pd.DataFrame(
251 | {"column": [list(row) for row in sensor_values_test]}
252 | )
253 |
254 | sensor_test = pd.concat(
255 | [sensor_timestamps_test, sensor_values_test],
256 | axis=1,
257 | )
258 |
259 | with open(output_path + f"big_{sensor_file_name}.csv", "a") as f:
260 | sensor_test.to_csv(
261 | f,
262 | header=["created timestamp", "sensor_values"],
263 | index=False,
264 | )
265 |
266 |     # convert state_timestamps and state_positions to a csv file with header "created timestamp", "pose_aa", "gripper_state"
267 | state_timestamps_test = pd.DataFrame(state_timestamps)
268 | # convert each pose_aa to a list
269 | state_positions_test = state_positions
270 | for i in range(len(state_positions_test)):
271 | state_positions_test[i] = np.array(state_positions_test[i])
272 | state_positions_test = pd.DataFrame(
273 | {"column": [list(row) for row in state_positions_test]}
274 | )
275 |     # gripper positions as a single column
276 | gripper_positions_test = pd.DataFrame(gripper_positions)
277 |
278 | state_test = pd.concat(
279 | [state_timestamps_test, state_positions_test, gripper_positions_test],
280 | axis=1,
281 | )
282 | with open(output_path + f"big_{states_file_name}.csv", "a") as f:
283 | state_test.to_csv(
284 | f,
285 | header=["created timestamp", "pose_aa", "gripper_state"],
286 | index=False,
287 | )
288 |
289 | df = pd.read_csv(output_path + f"big_{states_file_name}.csv")
290 | for i in range(len(reference_timestamps)):
291 | curlist = []
292 | for j in range(len(state_timestamps)):
293 | curlist.append(abs(state_timestamps[j] - reference_timestamps[i]))
294 | min_index = curlist.index(min(curlist))
295 | min_df = df.iloc[min_index]
296 | min_df = min_df.to_frame().transpose()
297 | with open(output_path + f"{states_file_name}.csv", "a") as f:
298 | min_df.to_csv(f, header=f.tell() == 0, index=False)
299 |
300 | if process_sensor:
301 | df = pd.read_csv(output_path + f"big_{sensor_file_name}.csv")
302 | for i in range(len(reference_timestamps)):
303 | curlist = []
304 | for j in range(len(sensor_timestamps)):
305 | curlist.append(abs(sensor_timestamps[j] - reference_timestamps[i]))
306 | min_index = curlist.index(min(curlist))
307 | min_df = df.iloc[min_index]
308 | min_df = min_df.to_frame().transpose()
309 | with open(output_path + f"{sensor_file_name}.csv", "a") as f:
310 | min_df.to_csv(f, header=f.tell() == 0, index=False)
311 |
312 | # Create folders for each camera if they don't exist
313 | output_folder = output_path + "videos"
314 | os.makedirs(output_folder, exist_ok=True)
315 | camera_folders = [f"camera{i}" for i in cam_indices]
316 | for folder in camera_folders:
317 | os.makedirs(os.path.join(output_folder, folder), exist_ok=True)
318 |
319 | # Iterate over each camera and extract the frames based on the indexes
320 | if not skip_cam_processing:
321 | for camera_index, frames in enumerate(rgb_frames):
322 | camera_folder = camera_folders[camera_index]
323 | print(f"Extracting frames for {camera_folder}...")
324 | indexes = index[camera_index]
325 |
326 | # Iterate over the indexes and save the corresponding frames
327 | for i, indexx in enumerate(indexes):
328 | if i % 100 == 0:
329 | print(f"Extracting frame {i}...")
330 | frame = frames[indexx]
331 | # name frame with its timestamp
332 | image_output_path = os.path.join(
333 | output_folder,
334 | camera_folder,
335 | f"frame_{i}_{timestamps[camera_index][indexx]}.jpg",
336 | )
337 | cv2.imwrite(image_output_path, frame)
338 |
339 | csv_file = os.path.join(output_path, f"{states_file_name}.csv")
340 | print(output_path, demo_dir.name)
341 |
342 | def get_timestamp_from_filename(filename):
343 | # Extract the timestamp from the filename using regular expression
344 | timestamp_match = re.search(r"\d+\.\d+", filename)
345 | if timestamp_match:
346 | return float(timestamp_match.group())
347 | else:
348 | return None
349 |
350 | # add desired gripper states
351 | for file in [csv_file]:
352 | df = pd.read_csv(file)
353 | df["desired_gripper_state"] = df["gripper_state"].shift(-1)
354 | df.loc[df.index[-1], "desired_gripper_state"] = df.loc[
355 | df.index[-2], "gripper_state"
356 | ]
357 | df.to_csv(file, index=False)
358 |
359 | def save_only_videos(base_folder_path):
360 | base_folder_path = os.path.join(base_folder_path, "videos")
361 | # Iterate over each camera folder
362 | for cam in cam_indices:
363 | cam_folder = f"camera{cam}"
364 | full_folder_path = os.path.join(base_folder_path, cam_folder)
365 |
366 | # Check if the folder exists
367 | if os.path.exists(full_folder_path):
368 | # List all jpg files
369 | all_files = [
370 | f for f in os.listdir(full_folder_path) if f.endswith(".jpg")
371 | ]
372 |
373 | # Sort files based on the floating-point number in their name
374 | sorted_files = sorted(all_files, key=get_timestamp_from_filename)
375 |
376 | # Write filenames to a temp file
377 | temp_list_filename = os.path.join(base_folder_path, "temp_list.txt")
378 | with open(temp_list_filename, "w") as f:
379 | for filename in sorted_files:
380 | f.write(f"file '{os.path.join(full_folder_path, filename)}'\n")
381 |
382 | # Use ffmpeg to convert sorted images to video
383 | output_video_path = os.path.join(base_folder_path, f"camera{cam}.mp4")
384 | cmd = [
385 | "ffmpeg",
386 | "-f",
387 | "concat",
388 | "-safe",
389 | "0",
390 | "-i",
391 | temp_list_filename,
392 | "-framerate",
393 |                     "30",  # 30 fps; change if needed
394 | "-vcodec",
395 | "libx264",
396 | "-crf",
397 | "18", # quality, lower means better quality
398 | "-pix_fmt",
399 | "yuv420p",
400 | output_video_path,
401 | ]
402 | try:
403 | subprocess.run(cmd, check=True)
404 | except Exception as e:
405 | print(f"EXCEPTION: {e}")
406 | input("Continue?")
407 |
408 | # Delete the temporary list file and the image folder
409 | os.remove(temp_list_filename)
410 | shutil.rmtree(full_folder_path)
411 | else:
412 | print(f"Folder {cam_folder} does not exist!")
413 |
414 | if not skip_cam_processing:
415 | save_only_videos(output_path)
416 |
--------------------------------------------------------------------------------
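
The heart of `process_xarm_data.py` is the alignment step: the stream with the fewest timestamps becomes the reference, and every other camera/state/sensor stream is matched to it by nearest timestamp. The script does this with plain Python loops; the sketch below shows the same matching idea with numpy broadcasting (the `nearest_indices` helper and the toy timestamps are illustrative, not part of the repo):

# Nearest-timestamp matching, equivalent in spirit to the alignment loops above.
import numpy as np

def nearest_indices(stream_ts, reference_ts):
    """For each reference timestamp, index of the closest timestamp in stream_ts."""
    diff = np.abs(stream_ts[None, :] - reference_ts[:, None])  # (n_ref, n_stream)
    return diff.argmin(axis=1)

cam_ts = np.array([0.00, 0.03, 0.06, 0.11, 0.15])   # a denser camera stream
state_ts = np.array([0.00, 0.05, 0.10])             # shortest stream -> reference
print(nearest_indices(cam_ts, state_ts))            # [0 2 3]
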
/read_data/xarm_env_aa.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import cv2
3 | import random
4 | import numpy as np
5 | import pickle as pkl
6 | from pathlib import Path
7 |
8 | import torch
9 | import torchvision.transforms as transforms
10 | from torch.utils.data import IterableDataset
11 | from scipy.spatial.transform import Rotation as R
12 |
13 |
14 | def get_relative_action(actions, action_after_steps):
15 | """
16 |     Convert absolute axis-angle actions to relative axis-angle actions.
17 |     Poses are converted to transformation matrices to get the relative rotation
18 |     (returned as axis-angle); positions are kept as world-frame deltas.
19 | """
20 |
21 | relative_actions = []
22 | for i in range(len(actions)):
23 | # Get relative transformation matrix
24 | # previous pose
25 | pos_prev = actions[i, :3]
26 | ori_prev = actions[i, 3:6]
27 | r_prev = R.from_rotvec(ori_prev).as_matrix()
28 | matrix_prev = np.eye(4)
29 | matrix_prev[:3, :3] = r_prev
30 | matrix_prev[:3, 3] = pos_prev
31 | # current pose
32 | next_idx = min(i + action_after_steps, len(actions) - 1)
33 | pos = actions[next_idx, :3]
34 | ori = actions[next_idx, 3:6]
35 | gripper = actions[next_idx, 6:]
36 | r = R.from_rotvec(ori).as_matrix()
37 | matrix = np.eye(4)
38 | matrix[:3, :3] = r
39 | matrix[:3, 3] = pos
40 | # relative transformation
41 | matrix_rel = np.linalg.inv(matrix_prev) @ matrix
42 | # relative pose
43 | # pos_rel = matrix_rel[:3, 3]
44 | pos_rel = pos - pos_prev
45 | r_rel = R.from_matrix(matrix_rel[:3, :3]).as_rotvec()
46 | # # compute relative rotation
47 | # r_prev = R.from_rotvec(ori_prev).as_matrix()
48 | # r = R.from_rotvec(ori).as_matrix()
49 | # r_rel = np.linalg.inv(r_prev) @ r
50 | # r_rel = R.from_matrix(r_rel).as_rotvec()
51 | # # compute relative translation
52 | # pos_rel = pos - pos_prev
53 | relative_actions.append(np.concatenate([pos_rel, r_rel, gripper]))
54 | # next_idx = min(i + action_after_steps, len(actions) - 1)
55 | # curr_pose, _ = actions[i, :6], actions[i, 6:]
56 | # next_pose, next_gripper = actions[next_idx, :6], actions[next_idx, 6:]
57 |
58 | # last action
59 | last_action = np.zeros_like(actions[-1])
60 | last_action[-1] = actions[-1][-1]
61 | while len(relative_actions) < len(actions):
62 | relative_actions.append(last_action)
63 | return np.array(relative_actions, dtype=np.float32)
64 |
65 |
66 | def get_absolute_action(rel_actions, base_action):
67 | """
68 | Convert relative axis angle actions to absolute axis angle actions
69 | """
70 | actions = np.zeros((len(rel_actions) + 1, rel_actions.shape[-1]))
71 | actions[0] = base_action
72 | for i in range(1, len(rel_actions) + 1):
73 | # if i == 0:
74 | # actions.append(base_action)
75 | # continue
76 | # Get relative transformation matrix
77 | # previous pose
78 | pos_prev = actions[i - 1, :3]
79 | ori_prev = actions[i - 1, 3:6]
80 | r_prev = R.from_rotvec(ori_prev).as_matrix()
81 | matrix_prev = np.eye(4)
82 | matrix_prev[:3, :3] = r_prev
83 | matrix_prev[:3, 3] = pos_prev
84 | # relative pose
85 | pos_rel = rel_actions[i - 1, :3]
86 | r_rel = rel_actions[i - 1, 3:6]
87 | # compute relative transformation matrix
88 | matrix_rel = np.eye(4)
89 | matrix_rel[:3, :3] = R.from_rotvec(r_rel).as_matrix()
90 | matrix_rel[:3, 3] = pos_rel
91 | # compute absolute transformation matrix
92 | matrix = matrix_prev @ matrix_rel
93 | # absolute pose
94 | pos = matrix[:3, 3]
95 | # r = R.from_matrix(matrix[:3, :3]).as_rotvec()
96 | r = R.from_matrix(matrix[:3, :3]).as_euler("xyz")
97 | actions[i] = np.concatenate([pos, r, rel_actions[i - 1, 6:]])
98 | return np.array(actions, dtype=np.float32)
99 |
100 |
101 | def get_quaternion_orientation(cartesian):
102 | """
103 | Get quaternion orientation from axis angle representation
104 | """
105 | new_cartesian = []
106 | for i in range(len(cartesian)):
107 | pos = cartesian[i, :3]
108 | ori = cartesian[i, 3:]
109 | quat = R.from_rotvec(ori).as_quat()
110 | new_cartesian.append(np.concatenate([pos, quat], axis=-1))
111 | return np.array(new_cartesian, dtype=np.float32)
112 |
113 |
114 | class BCDataset(IterableDataset):
115 | def __init__(
116 | self,
117 | path,
118 | tasks,
119 | num_demos_per_task,
120 | temporal_agg,
121 | num_queries,
122 | img_size,
123 | action_after_steps,
124 | store_actions,
125 | pixel_keys,
126 | aux_keys,
127 | subsample,
128 | skip_first_n,
129 | relative_actions,
130 | random_mask_proprio,
131 | sensor_params,
132 | ):
133 | self._img_size = img_size
134 | self._action_after_steps = action_after_steps
135 | self._store_actions = store_actions
136 | self._pixel_keys = pixel_keys
137 | self._aux_keys = aux_keys
138 | self._random_mask_proprio = random_mask_proprio
139 |
140 | self._sensor_type = sensor_params.sensor_type
141 | self._subtract_sensor_baseline = sensor_params.subtract_sensor_baseline
142 | if self._sensor_type == "digit":
143 | use_anyskin_data = False
144 | elif self._sensor_type == "reskin":
145 | use_anyskin_data = True
146 |         else:
147 |             assert self._sensor_type is None
148 |             use_anyskin_data = False
149 | self._num_anyskin_sensors = 2
150 |
151 | # temporal aggregation
152 | self._temporal_agg = temporal_agg
153 | self._num_queries = num_queries
154 |
155 | # get data paths
156 | self._paths = []
157 | self._paths.extend([Path(path) / f"{task}.pkl" for task in tasks])
158 |
159 | paths = {}
160 | idx = 0
161 | for path in self._paths:
162 | paths[idx] = path
163 | idx += 1
164 | del self._paths
165 | self._paths = paths
166 |
167 | # store actions
168 | if self._store_actions:
169 | self.actions = []
170 |
171 | # read data
172 | self._episodes = {}
173 | self._max_episode_len = 0
174 | self._max_state_dim = 7
175 | self._num_samples = 0
176 | min_stat, max_stat = None, None
177 | min_sensor_stat, max_sensor_stat = None, None
178 | digit_mean_stat, digit_std_stat = defaultdict(list), defaultdict(list)
179 | min_sensor_diff_stat, max_sensor_diff_stat = None, None
180 | min_act, max_act = None, None
181 | self.prob = []
182 | sensor_states = []
183 | for _path_idx in self._paths:
184 | print(f"Loading {str(self._paths[_path_idx])}")
185 | # Add to prob
186 | if "fridge" in str(self._paths[_path_idx]):
187 | # self.prob.append(25.0/11.0)
188 | self.prob.append(22.0 / 9.0)
189 | else:
190 | self.prob.append(1)
191 | # read
192 | data = pkl.load(open(str(self._paths[_path_idx]), "rb"))
193 | observations = data["observations"]
194 | # store
195 | self._episodes[_path_idx] = []
196 | for i in range(min(num_demos_per_task, len(observations))):
197 | # compute actions
198 | # absolute actions
199 | actions = np.concatenate(
200 | [
201 | observations[i]["cartesian_states"],
202 | observations[i]["gripper_states"][:, None],
203 | ],
204 | axis=1,
205 | )
206 | if len(actions) == 0:
207 | continue
208 | # skip first n
209 | if skip_first_n is not None:
210 | for key in observations[i].keys():
211 | observations[i][key] = observations[i][key][skip_first_n:]
212 | actions = actions[skip_first_n:]
213 | # subsample
214 | if subsample is not None:
215 | for key in observations[i].keys():
216 | observations[i][key] = observations[i][key][::subsample]
217 | actions = actions[::subsample]
218 | # action after steps
219 | if relative_actions:
220 | actions = get_relative_action(actions, self._action_after_steps)
221 | else:
222 | actions = actions[self._action_after_steps :]
223 | # Convert cartesian states to quaternion orientation
224 | observations[i]["cartesian_states"] = get_quaternion_orientation(
225 | observations[i]["cartesian_states"]
226 | )
227 | if use_anyskin_data:
228 | try:
229 | sensor_baseline = np.median(
230 | observations[i]["sensor_states"][:5], axis=0, keepdims=True
231 | )
232 | if self._subtract_sensor_baseline:
233 | observations[i]["sensor_states"] = (
234 | observations[i]["sensor_states"] - sensor_baseline
235 | )
236 | if max_sensor_stat is None:
237 | max_sensor_stat = np.max(
238 | observations[i]["sensor_states"], axis=0
239 | )
240 | min_sensor_stat = np.min(
241 | observations[i]["sensor_states"], axis=0
242 | )
243 | else:
244 | max_sensor_stat = np.maximum(
245 | max_sensor_stat,
246 | np.max(observations[i]["sensor_states"], axis=0),
247 | )
248 | min_sensor_stat = np.minimum(
249 | min_sensor_stat,
250 | np.min(observations[i]["sensor_states"], axis=0),
251 | )
252 | for sensor_idx in range(self._num_anyskin_sensors):
253 | observations[i][
254 | f"sensor{sensor_idx}_states"
255 | ] = observations[i]["sensor_states"][
256 | ..., sensor_idx * 15 : (sensor_idx + 1) * 15
257 | ]
258 |
259 | except KeyError:
260 | print("WARN: Sensor data not found.")
261 | use_anyskin_data = False
262 | elif self._sensor_type == "digit":
263 | for key in self._aux_keys:
264 | if key.startswith("digit"):
265 | observations[i][key] = (
266 | observations[i][key].astype(np.float32) / 255.0
267 | )
268 | if self._subtract_sensor_baseline:
269 | sensor_baseline = np.median(
270 | observations[i][key][:5], axis=0, keepdims=True
271 | ) # .astype(observations[i][key].dtype)
272 | observations[i][key] = (
273 | observations[i][key] - sensor_baseline
274 | )
275 | delta_filter = np.abs(observations[i][key]) > (
276 | 5.0 / 255.0
277 | )
278 | digit_std_stat[key].append(
279 | observations[i][key][delta_filter]
280 | )
281 | else:
282 | pass
283 |
284 | for key in observations[i].keys():
285 | observations[i][key] = np.concatenate(
286 | [
287 | [observations[i][key][0]],
288 | observations[i][key],
289 | ],
290 | axis=0,
291 | )
292 |
293 | remaining_actions = actions[0]
294 | if relative_actions:
295 | pos = remaining_actions[:-1]
296 | ori_gripper = remaining_actions[-1:]
297 | remaining_actions = np.concatenate(
298 | [np.zeros_like(pos), ori_gripper]
299 | )
300 | actions = np.concatenate(
301 | [
302 | [remaining_actions],
303 | actions,
304 | ],
305 | axis=0,
306 | )
307 | # store
308 | episode = dict(
309 | observation=observations[i],
310 | action=actions,
311 | # task_emb=task_emb,
312 | )
313 | self._episodes[_path_idx].append(episode)
314 | self._max_episode_len = max(
315 | self._max_episode_len,
316 | (
317 | len(observations[i])
318 | if not isinstance(observations[i], dict)
319 | else len(observations[i][self._pixel_keys[0]])
320 | ),
321 | )
322 | self._num_samples += len(observations[i][self._pixel_keys[0]])
323 |
324 | # max, min action
325 | if min_act is None:
326 | min_act = np.min(actions, axis=0)
327 | max_act = np.max(actions, axis=0)
328 | else:
329 | min_act = np.minimum(min_act, np.min(actions, axis=0))
330 | max_act = np.maximum(max_act, np.max(actions, axis=0))
331 |
332 | # store actions
333 | if self._store_actions:
334 | self.actions.append(actions)
335 |
336 | # keep record of max and min stat
337 | max_cartesian = data["max_cartesian"]
338 | min_cartesian = data["min_cartesian"]
339 | max_cartesian = np.concatenate(
340 | [data["max_cartesian"][:3], [1] * 4]
341 | ) # for quaternion
342 | min_cartesian = np.concatenate(
343 | [data["min_cartesian"][:3], [-1] * 4]
344 | ) # for quaternion
345 | max_gripper = data["max_gripper"]
346 | min_gripper = data["min_gripper"]
347 | max_val = np.concatenate([max_cartesian, max_gripper[None]], axis=0)
348 | min_val = np.concatenate([min_cartesian, min_gripper[None]], axis=0)
349 | if max_stat is None:
350 | max_stat = max_val
351 | min_stat = min_val
352 | else:
353 | max_stat = np.maximum(max_stat, max_val)
354 | min_stat = np.minimum(min_stat, min_val)
355 | if use_anyskin_data:
356 | # If baseline is subtracted, use zero as shift and max as scale
357 | if self._subtract_sensor_baseline:
358 | max_sensor_stat = np.maximum(
359 | np.abs(max_sensor_stat), np.abs(min_sensor_stat)
360 | )
361 | min_sensor_stat = np.zeros_like(max_sensor_stat)
362 | # If baseline isn't subtracted, use usual min and max values
363 | else:
364 | if max_sensor_stat is None:
365 | max_sensor_stat = data["max_sensor"]
366 | min_sensor_stat = data["min_sensor"]
367 | else:
368 | max_sensor_stat = np.maximum(
369 | max_sensor_stat, data["max_sensor"]
370 | )
371 | min_sensor_stat = np.minimum(
372 | min_sensor_stat, data["min_sensor"]
373 | )
374 |         min_act[3:6], max_act[3:6] = 0, 1  # leave rotation dims unnormalized
375 | self.stats = {
376 | "actions": {
377 | "min": min_act, # min_stat,
378 | "max": max_act, # max_stat,
379 | },
380 | "proprioceptive": {
381 | "min": min_stat,
382 | "max": max_stat,
383 | },
384 | }
385 | if use_anyskin_data:
386 | for sensor_idx in range(self._num_anyskin_sensors):
387 | sensor_mask = np.zeros_like(min_sensor_stat, dtype=bool)
388 | sensor_mask[sensor_idx * 15 : (sensor_idx + 1) * 15] = True
389 | self.stats[f"sensor{sensor_idx}"] = {
390 | "min": min_sensor_stat[sensor_mask],
391 | "max": max_sensor_stat[sensor_mask],
392 | }
393 |
394 | if not self._subtract_sensor_baseline:
395 | raise NotImplementedError(
396 | "Normalization not implemented without baseline subtraction"
397 | )
398 | for key in self.stats:
399 | if key.startswith("sensor"):
400 | sensor_states = np.concatenate(
401 | [
402 | observations[i][f"{key}_states"]
403 | for i in range(len(observations))
404 | ],
405 | axis=0,
406 | )
407 | sensor_std = (
408 | np.std(sensor_states, axis=0).reshape((5, 3)).max(axis=0)
409 | )
410 | sensor_std[:2] = sensor_std[:2].max()
411 | sensor_std = np.clip(sensor_std * 3, a_min=100, a_max=None)
412 | # max_xyz = np.clip(max_xyz, a_min=400, a_max=None)
413 | self.stats[key]["max"] = np.tile(
414 | sensor_std, int(self.stats[key]["max"].shape[0] / 3)
415 | )
416 | elif self._sensor_type == "digit":
417 | if self._subtract_sensor_baseline:
418 | shared_mean, shared_std = None, None
419 | for key in digit_std_stat:
420 | digit_std_stat[key] = [
421 | 3 * np.concatenate(digit_std_stat[key], axis=0).std()
422 | ] * 3
423 | digit_mean_stat[key] = [0.0] * 3
424 | self.stats[key] = {
425 | "mean": np.array(digit_mean_stat[key])[:, None, None],
426 | "std": np.array(digit_std_stat[key])[:, None, None],
427 | }
428 | if shared_std is None:
429 | shared_std = digit_std_stat[key]
430 | shared_mean = digit_mean_stat[key]
431 | else:
432 | shared_std = np.maximum(shared_std, digit_std_stat[key])
433 | shared_mean = np.minimum(shared_mean, digit_mean_stat[key])
434 | else:
435 | shared_mean = [0.485, 0.456, 0.406]
436 | shared_std = [0.229, 0.224, 0.225]
437 | self.stats["digit"] = {
438 | "mean": np.array(shared_mean)[:, None, None],
439 | "std": np.array(shared_std)[:, None, None],
440 | }
441 | self.digit_aug = transforms.Compose(
442 | [
443 | transforms.ToTensor(),
444 | ]
445 | )
446 | # augmentation
447 | self.aug = transforms.Compose(
448 | [
449 | transforms.ToPILImage(),
450 | transforms.RandomCrop(self._img_size, padding=4),
451 | transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
452 | transforms.ToTensor(),
453 | ]
454 | )
455 |
456 | # Samples from envs
457 | self.envs_till_idx = len(self._episodes)
458 |
459 | self.prob = np.array(self.prob) / np.sum(self.prob)
460 |
461 | def preprocess(self, key, x):
462 | if key.startswith("digit"):
463 | return (x - self.stats["digit"]["mean"]) / (
464 | self.stats["digit"]["std"] + 1e-5
465 | )
466 | return (x - self.stats[key]["min"]) / (
467 | self.stats[key]["max"] - self.stats[key]["min"] + 1e-5
468 | )
469 |
470 | def _sample_episode(self, env_idx=None):
471 | idx = random.randint(0, self.envs_till_idx - 1) if env_idx is None else env_idx
472 |
473 | # sample idx with probability
474 | idx = np.random.choice(list(self._episodes.keys()), p=self.prob)
475 |
476 | episode = random.choice(self._episodes[idx])
477 | return (episode, idx) if env_idx is None else episode
478 |
479 | def _sample(self):
480 | episodes, env_idx = self._sample_episode()
481 | observations = episodes["observation"]
482 | actions = episodes["action"]
483 | sample_idx = np.random.randint(1, len(observations[self._pixel_keys[0]]) - 1)
484 | # Sample obs, action
485 | sampled_pixel = {}
486 | for key in self._pixel_keys:
487 | sampled_pixel[key] = observations[key][-(sample_idx + 1) : -sample_idx]
488 | sampled_pixel[key] = torch.stack(
489 | [
490 | self.aug(sampled_pixel[key][i])
491 | for i in range(len(sampled_pixel[key]))
492 | ]
493 | )
494 |         sampled_state = {}
498 | sampled_state["proprioceptive"] = np.concatenate(
499 | [
500 | observations["cartesian_states"][-(sample_idx + 1) : -sample_idx],
501 | observations["gripper_states"][-(sample_idx + 1) : -sample_idx][
502 | :, None
503 | ],
504 | ],
505 | axis=1,
506 | )
507 |
508 | if self._random_mask_proprio and np.random.rand() < 0.5:
509 | sampled_state["proprioceptive"] = (
510 | np.ones_like(sampled_state["proprioceptive"])
511 | * self.stats["proprioceptive"]["min"]
512 | )
513 | if self._sensor_type == "reskin":
514 | try:
515 | for sensor_idx in range(self._num_anyskin_sensors):
516 | skey = f"sensor{sensor_idx}"
517 | sampled_state[f"{skey}"] = observations[f"{skey}_states"][
518 | -(sample_idx + 1) : -sample_idx
519 | ]
520 | except KeyError:
521 | pass
522 | elif self._sensor_type == "digit":
523 | try:
524 | for sensor_idx in range(self._num_anyskin_sensors):
525 | key = f"digit{80 + sensor_idx}"
526 | sampled_state[key] = observations[key][
527 | -(sample_idx + 1) : -sample_idx
528 | ]
529 | sampled_state[key] = torch.stack(
530 | [
531 | self.digit_aug(sampled_state[key][i])
532 | # self.aug(sampled_state[key][i])
533 | for i in range(len(sampled_state[key]))
534 | ]
535 | )
536 |             except KeyError:
537 | pass
538 |
539 | if self._temporal_agg:
540 | # arrange sampled action to be of shape (1, num_queries, action_dim)
541 | sampled_action = np.zeros((1, self._num_queries, actions.shape[-1]))
542 | num_actions = 1 + self._num_queries - 1
543 | act = np.zeros((num_actions, actions.shape[-1]))
544 | if num_actions - sample_idx < 0:
545 | act[:num_actions] = actions[-(sample_idx) : -sample_idx + num_actions]
546 | else:
547 | act[:sample_idx] = actions[-sample_idx:]
548 | act[sample_idx:] = actions[-1]
549 | sampled_action = np.lib.stride_tricks.sliding_window_view(
550 | act, (self._num_queries, actions.shape[-1])
551 | )
552 | sampled_action = sampled_action[:, 0]
553 | else:
554 | sampled_action = actions[-(sample_idx + 1) : -sample_idx]
555 |
556 | return_dict = {}
557 | for key in self._pixel_keys:
558 | return_dict[key] = sampled_pixel[key]
559 | for key in self._aux_keys:
560 | return_dict[key] = self.preprocess(key, sampled_state[key])
561 |         return_dict["actions"] = self.preprocess("actions", sampled_action)
564 | return return_dict
565 |
566 | def __iter__(self):
567 | while True:
568 | yield self._sample()
569 |
570 | def __len__(self):
571 | return self._num_samples
572 |
--------------------------------------------------------------------------------
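
The relative-action encoding used by `get_relative_action` stores positions as world-frame deltas and rotations as the rotation of the next pose expressed in the previous pose's frame. A small worked example of that rotation math follows; the pose values are made up for illustration and only scipy's `Rotation` API is used:

# Relative-pose math as in get_relative_action (illustrative values).
import numpy as np
from scipy.spatial.transform import Rotation as R

pose_prev = np.array([0.30, 0.00, 0.20, 0.0, 0.0, 0.0])        # x, y, z + axis-angle
pose_next = np.array([0.32, 0.01, 0.20, 0.0, 0.0, np.pi / 6])  # small translation + yaw

r_prev = R.from_rotvec(pose_prev[3:]).as_matrix()
r_next = R.from_rotvec(pose_next[3:]).as_matrix()

pos_rel = pose_next[:3] - pose_prev[:3]               # world-frame position delta
r_rel = R.from_matrix(r_prev.T @ r_next).as_rotvec()  # rotation relative to prev frame

# Composing the previous orientation with the relative rotation recovers the next one,
# which is the property get_absolute_action relies on when rebuilding absolute poses.
assert np.allclose(r_prev @ R.from_rotvec(r_rel).as_matrix(), r_next)
print(pos_rel, r_rel)  # ~[0.02 0.01 0.] and ~[0. 0. 0.5236]
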
/agent/bc.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import hydra
3 | import numpy as np
4 | from collections import deque
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from torchvision import transforms as T
10 |
11 | from agent.networks.rgb_modules import BaseEncoder, ResnetEncoder
12 | import utils
13 |
14 | from agent.networks.policy_head import DeterministicHead
15 | from agent.networks.gpt import GPT, GPTConfig, CrossAttention
16 | from agent.networks.mlp import MLP
17 |
18 |
19 | class Actor(nn.Module):
20 | def __init__(
21 | self,
22 | repr_dim,
23 | act_dim,
24 | hidden_dim,
25 | policy_type="gpt",
26 | policy_head="deterministic",
27 | num_feat_per_step=1,
28 | ):
29 | super().__init__()
30 |
31 | self._policy_type = policy_type
32 | self._policy_head = policy_head
33 | self._repr_dim = repr_dim
34 | self._act_dim = act_dim
35 | self._num_feat_per_step = num_feat_per_step
36 |
37 | self._action_token = nn.Parameter(torch.randn(1, 1, 1, repr_dim))
38 |
39 | # GPT model
40 | if policy_type == "gpt":
41 | self._policy = GPT(
42 | GPTConfig(
43 | block_size=65, # 50, # 51, # 50,
44 | input_dim=repr_dim,
45 | output_dim=hidden_dim,
46 | n_layer=8,
47 | n_head=4,
48 | n_embd=hidden_dim,
49 | dropout=0.1, # 0.6, #0.1,
50 | )
51 | )
52 | else:
53 | raise NotImplementedError
54 | self._action_head = DeterministicHead(
55 | hidden_dim, self._act_dim, hidden_size=hidden_dim, num_layers=2
56 | )
57 | self.apply(utils.weight_init)
58 |
59 | def forward(
60 | self,
61 | obs,
62 | num_prompt_feats,
63 | stddev,
64 | action=None,
65 | cluster_centers=None,
66 | mask=None,
67 | ):
68 | B, T, D = obs.shape
69 | if self._policy_type == "mlp":
70 | if T * D < self._repr_dim:
71 | gt_num_time_steps = (
72 | self._repr_dim // D - num_prompt_feats
73 | ) // self._num_feat_per_step
74 | num_repeat = (
75 | gt_num_time_steps
76 | - (T - num_prompt_feats) // self._num_feat_per_step
77 | )
78 | initial_obs = obs[
79 | :, num_prompt_feats : num_prompt_feats + self._num_feat_per_step
80 | ]
81 | initial_obs = initial_obs.repeat(1, num_repeat, 1)
82 | obs = torch.cat(
83 | [obs[:, :num_prompt_feats], initial_obs, obs[:, num_prompt_feats:]],
84 | dim=1,
85 | )
86 | B, T, D = obs.shape
87 | obs = obs.view(B, 1, T * D)
88 | features = self._policy(obs)
89 | elif self._policy_type == "gpt":
90 | # insert action token at each self._num_feat_per_step interval
91 | prompt = obs[:, :num_prompt_feats]
92 | obs = obs[:, num_prompt_feats:]
93 | obs = obs.view(B, -1, self._num_feat_per_step, obs.shape[-1])
94 | action_token = self._action_token.repeat(B, obs.shape[1], 1, 1)
95 | obs = torch.cat([obs, action_token], dim=-2).view(B, -1, D)
96 | obs = torch.cat([prompt, obs], dim=1)
97 |
98 | if mask is not None:
99 | mask = torch.cat([mask, torch.ones(B, 1).to(mask.device)], dim=1)
100 | mask = mask.view(B, -1, 1, self._num_feat_per_step + 1)
101 | base_mask = torch.ones(
102 | B,
103 | mask.shape[1],
104 | self._num_feat_per_step + 1,
105 | self._num_feat_per_step + 1,
106 | ).to(mask.device)
107 | base_mask[:, :, -1:] = mask
108 |
109 | # get action features
110 | features = self._policy(obs, mask=base_mask if mask is not None else None)
111 | features = features[:, num_prompt_feats:]
112 | num_feat_per_step = self._num_feat_per_step + 1 # +1 for action token
113 | features = features[:, num_feat_per_step - 1 :: num_feat_per_step]
114 |
115 | # action head
116 | pred_action = self._action_head(
117 | features,
118 | stddev,
119 | **{"cluster_centers": cluster_centers, "action_seq": action},
120 | )
121 |
122 | if action is None:
123 | return pred_action
124 | else:
125 | loss = self._action_head.loss_fn(
126 | pred_action,
127 | action,
128 | reduction="mean",
129 | **{"cluster_centers": cluster_centers},
130 | )
131 | return pred_action, loss[0] if isinstance(loss, tuple) else loss
132 |
133 |
134 | class BCAgent:
135 | def __init__(
136 | self,
137 | obs_shape,
138 | action_shape,
139 | device,
140 | lr,
141 | hidden_dim,
142 | stddev_schedule,
143 | stddev_clip,
144 | use_tb,
145 | augment,
146 | encoder_type,
147 | policy_type,
148 | policy_head,
149 | pixel_keys,
150 | aux_keys,
151 | use_aux_inputs,
152 | train_encoder,
153 | norm,
154 | separate_encoders,
155 | temporal_agg,
156 | max_episode_len,
157 | num_queries,
158 | use_actions,
159 | ):
160 | self.device = device
161 | self.lr = lr
162 | self.hidden_dim = hidden_dim
163 | self.stddev_schedule = stddev_schedule
164 | self.stddev_clip = stddev_clip
165 | self.use_tb = use_tb
166 | self.augment = augment
167 | self.encoder_type = encoder_type
168 | self.policy_head = policy_head
169 | self.use_aux_inputs = use_aux_inputs
170 | self.norm = norm
171 | self.train_encoder = train_encoder
172 | self.separate_encoders = separate_encoders
173 | self.use_actions = use_actions # only for the prompt
174 |
175 | # actor parameters
176 | self._act_dim = action_shape[0]
177 |
178 | # keys
179 | self.aux_keys = aux_keys
180 | self.pixel_keys = pixel_keys
181 |
182 | # action chunking params
183 | self.temporal_agg = temporal_agg
184 | self.max_episode_len = max_episode_len
185 | self.num_queries = num_queries if self.temporal_agg else 1
186 |
187 | # number of inputs per time step
188 | num_feat_per_step = len(self.pixel_keys)
189 | if use_aux_inputs:
190 | num_feat_per_step += len(self.aux_keys)
191 |
192 | # observation params
193 | if use_aux_inputs:
194 | aux_shape = {key: obs_shape[key] for key in self.aux_keys}
195 | obs_shape = obs_shape[self.pixel_keys[0]]
196 |
197 | # Track model size
198 | model_size = 0
199 |
200 | # encoder
201 | if self.separate_encoders:
202 | self.encoder = {}
203 | if self.encoder_type == "base":
204 | if self.separate_encoders:
205 | for key in self.pixel_keys:
206 | self.encoder[key] = BaseEncoder(obs_shape).to(device)
207 | self.repr_dim = self.encoder[key].repr_dim
208 | model_size += sum(
209 | p.numel()
210 | for p in self.encoder[key].parameters()
211 | if p.requires_grad
212 | )
213 | else:
214 | self.encoder = BaseEncoder(obs_shape).to(device)
215 | self.repr_dim = self.encoder.repr_dim
216 | model_size += sum(
217 | p.numel() for p in self.encoder.parameters() if p.requires_grad
218 | )
219 | elif self.encoder_type == "resnet":
220 | self.repr_dim = 512
221 | enc_fn = lambda: ResnetEncoder(
222 | obs_shape,
223 | 512,
224 | cond_dim=None,
225 | cond_fusion="none",
226 | ).to(device)
227 | if self.separate_encoders:
228 | for key in self.pixel_keys:
229 | self.encoder[key] = enc_fn()
230 | model_size += sum(
231 | p.numel()
232 | for p in self.encoder[key].parameters()
233 | if p.requires_grad
234 | )
235 | else:
236 | self.encoder = enc_fn()
237 | model_size += sum(
238 | p.numel() for p in self.encoder.parameters() if p.requires_grad
239 | )
240 |
241 |         # projectors for auxiliary inputs (proprioceptive and tactile features)
242 | if use_aux_inputs:
243 | self.aux_projector = nn.ModuleDict()
244 | for key in self.aux_keys:
245 | if key.startswith("digit"):
246 | self.aux_projector[key] = ResnetEncoder(
247 | obs_shape,
248 | 512,
249 | cond_dim=None,
250 | cond_fusion="none",
251 | ).to(device)
252 | else:
253 | self.aux_projector[key] = MLP(
254 | aux_shape[key][0],
255 | hidden_channels=[self.repr_dim, self.repr_dim],
256 | ).to(device)
257 | self.aux_projector[key].apply(utils.weight_init)
258 | model_size += sum(
259 | p.numel()
260 | for p in self.aux_projector[key].parameters()
261 | if p.requires_grad
262 | )
263 |
264 | # projector for actions
265 | if self.use_actions:
266 | self.action_projector = MLP(
267 | self._act_dim, hidden_channels=[self.repr_dim, self.repr_dim]
268 | ).to(device)
269 | self.action_projector.apply(utils.weight_init)
270 | model_size += sum(
271 | p.numel() for p in self.action_projector.parameters() if p.requires_grad
272 | )
273 |
274 | # actor
275 | action_dim = (
276 | self._act_dim * self.num_queries if self.temporal_agg else self._act_dim
277 | )
278 | self.actor = Actor(
279 | self.repr_dim,
280 | action_dim,
281 | hidden_dim,
282 | policy_type,
283 | policy_head,
284 | num_feat_per_step,
285 | ).to(device)
286 | model_size += sum(p.numel() for p in self.actor.parameters() if p.requires_grad)
287 |
288 | print(f"Total number of parameters in the model: {model_size}")
289 |
290 | # optimizers
291 | # encoder
292 | if self.train_encoder:
293 | if self.separate_encoders:
294 | params = []
295 | for key in self.pixel_keys:
296 | params += list(self.encoder[key].parameters())
297 | else:
298 | params = list(self.encoder.parameters())
299 | self.encoder_opt = torch.optim.AdamW(params, lr=lr, weight_decay=1e-4)
300 | # self.encoder_scheduler = torch.optim.lr_scheduler.StepLR(
301 | # self.encoder_opt, step_size=15000, gamma=0.1
302 | # )
303 | # proprio
304 | if self.use_aux_inputs:
305 | self.aux_opt = torch.optim.AdamW(
306 | self.aux_projector.parameters(), lr=lr, weight_decay=1e-4
307 | )
308 | # self.proprio_scheduler = torch.optim.lr_scheduler.StepLR(
309 | # self.proprio_opt, step_size=15000, gamma=0.1
310 | # )
311 |
312 | # action projector
313 | if self.use_actions:
314 | self.action_opt = torch.optim.AdamW(
315 | self.action_projector.parameters(), lr=lr, weight_decay=1e-4
316 | )
317 | # self.action_scheduler = torch.optim.lr_scheduler.StepLR(
318 | # self.action_opt, step_size=15000, gamma=0.1
319 | # )
320 | # actor
321 | self.actor_opt = torch.optim.AdamW(
322 | self.actor.parameters(), lr=lr, weight_decay=1e-4
323 | )
324 | # self.actor_scheduler = torch.optim.lr_scheduler.StepLR(
325 | # self.actor_opt, step_size=15000, gamma=0.1
326 | # )
327 |
328 | # augmentations
329 | if self.norm:
330 | if self.encoder_type == "small":
331 | MEAN = torch.tensor([0.0, 0.0, 0.0])
332 | STD = torch.tensor([1.0, 1.0, 1.0])
333 | elif self.encoder_type == "resnet" or self.norm:
334 | MEAN = torch.tensor([0.485, 0.456, 0.406])
335 | STD = torch.tensor([0.229, 0.224, 0.225])
336 | self.customAug = T.Compose([T.Normalize(mean=MEAN, std=STD)])
337 |
338 | # data augmentation
339 | if self.augment:
340 | # self.aug = utils.RandomShiftsAug(pad=4)
341 | self.test_aug = T.Compose([T.ToPILImage(), T.ToTensor()])
342 | self.digit_aug = T.Compose([T.ToTensor()])
343 |
344 | self.train()
345 | self.buffer_reset()
346 |
347 | def __repr__(self):
348 | return "bc"
349 |
350 | def train(self, training=True):
351 | self.training = training
352 | if training:
353 | if self.separate_encoders:
354 | for key in self.pixel_keys:
355 | if self.train_encoder:
356 | self.encoder[key].train(training)
357 | else:
358 | self.encoder[key].eval()
359 | else:
360 | if self.train_encoder:
361 | self.encoder.train(training)
362 | else:
363 | self.encoder.eval()
364 | if self.use_aux_inputs:
365 | self.aux_projector.train(training)
366 | if self.use_actions:
367 | self.action_projector.train(training)
368 | self.actor.train(training)
369 | else:
370 | if self.separate_encoders:
371 | for key in self.pixel_keys:
372 | self.encoder[key].eval()
373 | else:
374 | self.encoder.eval()
375 | if self.use_aux_inputs:
376 | self.aux_projector.eval()
377 | if self.use_actions:
378 | self.action_projector.eval()
379 | self.actor.eval()
380 |
381 | def buffer_reset(self):
382 | self.observation_buffer = {}
383 | for key in self.pixel_keys:
384 | self.observation_buffer[key] = deque(maxlen=1)
385 | if self.use_aux_inputs:
386 | self.aux_buffer = {}
387 | for key in self.aux_keys:
388 | self.aux_buffer[key] = deque(maxlen=1)
389 |
390 | # temporal aggregation
391 | if self.temporal_agg:
392 | self.all_time_actions = torch.zeros(
393 | [
394 | self.max_episode_len,
395 | self.max_episode_len + self.num_queries,
396 | self._act_dim,
397 | ]
398 | ).to(self.device)
399 |
400 | def clear_buffers(self):
401 | del self.observation_buffer
402 | if self.use_aux_inputs:
403 | del self.aux_buffer
404 | if self.temporal_agg:
405 | del self.all_time_actions
406 |
407 | def reinit_optimizers(self):
408 | if self.train_encoder:
409 | if self.separate_encoders:
410 | params = []
411 | for key in self.pixel_keys:
412 | params += list(self.encoder[key].parameters())
413 | else:
414 | params = list(self.encoder.parameters())
415 | self.encoder_opt = torch.optim.AdamW(params, lr=self.lr, weight_decay=1e-4)
416 | try:
417 | self.aux_cond_opt = torch.optim.AdamW(
418 | self.aux_cond_cross_attn.parameters(),
419 | lr=self.lr,
420 | weight_decay=1e-4,
421 | )
422 | except AttributeError:
423 | print("Not optimizing aux_cond_cross_attn")
424 | if self.use_aux_inputs:
425 | self.aux_opt = torch.optim.AdamW(
426 | self.aux_projector.parameters(), lr=self.lr, weight_decay=1e-4
427 | )
428 | if self.use_actions:
429 | self.action_opt = torch.optim.AdamW(
430 | self.action_projector.parameters(), lr=self.lr, weight_decay=1e-4
431 | )
432 | params = list(self.actor.parameters())
433 | self.actor_opt = torch.optim.AdamW(
434 | self.actor.parameters(), lr=self.lr, weight_decay=1e-4
435 | )
436 |
437 | def act(self, obs, prompt, norm_stats, step, global_step, eval_mode=False):
438 | if norm_stats is not None:
439 |
440 | def pre_process(aux_key, s_qpos):
441 | try:
442 | return (s_qpos - norm_stats[aux_key]["min"]) / (
443 | norm_stats[aux_key]["max"] - norm_stats[aux_key]["min"] + 1e-5
444 | )
445 | except KeyError:
446 | return (s_qpos - norm_stats[aux_key]["mean"]) / (
447 | norm_stats[aux_key]["std"] + 1e-5
448 | )
449 |
450 | # pre_process = lambda aux_key, s_qpos: (
451 | # s_qpos - norm_stats[aux_key]["min"]
452 | # ) / (norm_stats[aux_key]["max"] - norm_stats[aux_key]["min"] + 1e-5)
453 | post_process = (
454 | lambda a: a
455 | * (norm_stats["actions"]["max"] - norm_stats["actions"]["min"])
456 | + norm_stats["actions"]["min"]
457 | )
458 |
459 | # lang projection
460 | lang_features = None
461 |
462 | # add to buffer
463 | features = []
464 | aux_features = []
465 | aux_cond_features = []
466 | if self.use_aux_inputs:
467 | # TODO: Add conditioning here
468 | for key in self.aux_keys:
469 | obs[key] = pre_process(key, obs[key])
470 | if key.startswith("digit"):
471 | self.aux_buffer[key].append(
472 | self.digit_aug(obs[key].transpose(1, 2, 0)).numpy()
473 | )
474 | aux_feat = torch.as_tensor(
475 | np.array(self.aux_buffer[key]), device=self.device
476 | ).float()
477 | aux_feat = self.aux_projector[key](aux_feat)[None, :, :]
478 | else:
479 | self.aux_buffer[key].append(obs[key])
480 | aux_feat = torch.as_tensor(
481 | np.array(self.aux_buffer[key]), device=self.device
482 | ).float()
483 | aux_feat = self.aux_projector[key](aux_feat[None, :, :])
484 | aux_features.append(aux_feat[0])
485 |
486 | for key in self.pixel_keys:
487 | self.observation_buffer[key].append(
488 | self.test_aug(obs[key].transpose(1, 2, 0)).numpy()
489 | )
490 | pixels = torch.as_tensor(
491 | np.array(self.observation_buffer[key]), device=self.device
492 | ).float()
493 | pixels = self.customAug(pixels) if self.norm else pixels
494 | # encoder
495 | pixels = (
496 | self.encoder[key](pixels)
497 | if self.separate_encoders
498 | else self.encoder(pixels)
499 | )
500 | features.append(pixels)
501 |
502 | features.extend(aux_features)
503 | features = torch.cat(features, dim=-1).view(-1, self.repr_dim)
504 |
505 | stddev = utils.schedule(self.stddev_schedule, global_step)
506 | action = self.actor(features.unsqueeze(0), 0, stddev)
507 |
508 | if eval_mode:
509 | action = action.mean
510 | else:
511 | action = action.sample()
512 | if self.temporal_agg:
513 | action = action.view(-1, self.num_queries, self._act_dim)
514 | self.all_time_actions[[step], step : step + self.num_queries] = action[-1:]
515 | actions_for_curr_step = self.all_time_actions[:, step]
516 | actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
517 | actions_for_curr_step = actions_for_curr_step[actions_populated]
518 | k = 0.01
519 | exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
520 | exp_weights = exp_weights / exp_weights.sum()
521 | exp_weights = torch.from_numpy(exp_weights).to(self.device).unsqueeze(dim=1)
522 | action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
523 | if norm_stats is not None:
524 | return post_process(action.cpu().numpy()[0])
525 | return action.cpu().numpy()[0]
526 | else:
527 | if norm_stats is not None:
528 | return post_process(action.cpu().numpy()[0, -1, :])
529 | return action.cpu().numpy()[0, -1, :]
530 |
531 | def update(self, expert_replay_iter, step):
532 | metrics = dict()
533 | batch = next(expert_replay_iter)
534 | data = utils.to_torch(batch, self.device)
535 | action = data["actions"].float()
536 |
537 | # features
538 | features = []
539 | aux_features = []
540 | if self.use_aux_inputs:
541 | for key in self.aux_keys:
542 | aux_feat = data[key].float()
543 | if "digit" in key:
544 | shape = aux_feat.shape
545 | # rearrange
546 | aux_feat = einops.rearrange(aux_feat, "b t c h w -> (b t) c h w")
547 | # augment
548 | # aux_feat = self.aug(aux_feat) if self.augment else aux_feat
549 | aux_feat = self.customAug(aux_feat) if self.norm else aux_feat
550 | aux_feat = self.aux_projector[key](aux_feat)
551 | aux_feat = einops.rearrange(
552 | aux_feat, "(b t) d -> b t d", t=shape[1]
553 | )
554 | else:
555 | aux_feat = self.aux_projector[key](aux_feat)
556 | aux_features.append(aux_feat)
557 | for key in self.pixel_keys:
558 | pixel = data[key].float()
559 | shape = pixel.shape
560 | # rearrange
561 | pixel = einops.rearrange(pixel, "b t c h w -> (b t) c h w")
562 | # augment
563 | # pixel = self.aug(pixel) if self.augment else pixel
564 | pixel = self.customAug(pixel) if self.norm else pixel
565 | # encode
566 | if self.train_encoder:
567 | pixel = (
568 | self.encoder[key](pixel)
569 | if self.separate_encoders
570 | else self.encoder(pixel)
571 | )
572 | else:
573 | with torch.no_grad():
574 | pixel = (
575 | self.encoder[key](pixel)
576 | if self.separate_encoders
577 | else self.encoder(pixel)
578 | )
579 | pixel = einops.rearrange(pixel, "(b t) d -> b t d", t=shape[1])
580 | features.append(pixel)
581 | features.extend(aux_features)
582 | # concatenate
583 | features = torch.cat(features, dim=-1).view(
584 | action.shape[0], -1, self.repr_dim
585 | ) # (B, T * num_feat_per_step, D)
586 |
587 | # rearrange action
588 | if self.temporal_agg:
589 | action = einops.rearrange(action, "b t1 t2 d -> b t1 (t2 d)")
590 |
591 | # actor loss
592 | stddev = utils.schedule(self.stddev_schedule, step)
593 | _, actor_loss = self.actor(
594 | features,
595 | 0,
596 | stddev,
597 | action,
598 | mask=None,
599 | )
600 | if self.train_encoder:
601 | self.encoder_opt.zero_grad(set_to_none=True)
602 |
603 | if self.use_aux_inputs:
604 | self.aux_opt.zero_grad(set_to_none=True)
605 | self.actor_opt.zero_grad(set_to_none=True)
606 | actor_loss["actor_loss"].backward()
607 | if self.train_encoder:
608 | self.encoder_opt.step()
609 | try:
610 | self.aux_cond_opt.step()
611 | except AttributeError:
612 | pass
613 | if self.use_aux_inputs:
614 | self.aux_opt.step()
615 | if self.use_actions:
616 | self.action_opt.step()
617 | self.actor_opt.step()
618 |
619 | if self.use_tb:
620 | for key, value in actor_loss.items():
621 | metrics[key] = value.item()
622 |
623 | return metrics
624 |
625 | def save_snapshot(self):
626 | model_keys = ["actor", "encoder"]
627 | opt_keys = ["actor_opt"]
628 | if self.train_encoder:
629 | opt_keys += ["encoder_opt"]
630 | if self.use_aux_inputs:
631 | model_keys += ["aux_projector"]
632 | opt_keys += ["aux_opt"]
633 | if self.use_actions:
634 | model_keys += ["action_projector"]
635 | opt_keys += ["action_opt"]
636 | # models
637 | payload = {
638 | k: self.__dict__[k].state_dict() for k in model_keys if k != "encoder"
639 | }
640 | if "encoder" in model_keys:
641 | if self.separate_encoders:
642 | for key in self.pixel_keys:
643 | payload[f"encoder_{key}"] = self.encoder[key].state_dict()
644 | else:
645 | payload["encoder"] = self.encoder.state_dict()
646 | # optimizers
647 | payload.update({k: self.__dict__[k] for k in opt_keys})
648 |
649 | others = [
650 | "use_aux_inputs",
651 | "aux_keys",
652 | "use_actions",
653 | "max_episode_len",
654 | ]
655 | payload.update({k: self.__dict__[k] for k in others})
656 | return payload
657 |
658 | def load_snapshot(self, payload, encoder_only=False, eval=False, load_opt=False):
659 | # models
660 | if encoder_only:
661 | model_keys = ["encoder"]
662 | payload = {"encoder": payload}
663 | else:
664 | model_keys = ["actor", "encoder"]
665 | if self.use_aux_inputs:
666 | model_keys += ["aux_projector"]
667 | if self.use_actions:
668 | model_keys += ["action_projector"]
669 |
670 | for k in model_keys:
671 | if k == "encoder" and self.separate_encoders:
672 | for key in self.pixel_keys:
673 | self.encoder[key].load_state_dict(payload[f"encoder_{key}"])
674 | else:
675 | self.__dict__[k].load_state_dict(payload[k])
676 |
677 | if eval:
678 | self.train(False)
679 | return
680 |
681 | # if not eval
682 | if not load_opt:
683 | self.reinit_optimizers()
684 | else:
685 | opt_keys = ["actor_opt"]
686 | if self.train_encoder:
687 | opt_keys += ["encoder_opt"]
688 | if self.use_aux_inputs:
689 | opt_keys += ["aux_opt"]
690 | if self.use_actions:
691 | opt_keys += ["action_opt"]
692 | for k in opt_keys:
693 | self.__dict__[k] = payload[k]
694 | self.train(True)
695 |
--------------------------------------------------------------------------------
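
At inference time `BCAgent.act` uses ACT-style temporal aggregation: each step predicts a chunk of `num_queries` future actions, and the executed action is an exponentially weighted average of every stored chunk that has predicted the current step. The numpy sketch below mirrors the shapes and `k = 0.01` from the code above; the toy chunks and the `aggregate` helper are illustrative, not part of the repo:

# ACT-style temporal aggregation as in BCAgent.act (toy values, numpy instead of torch).
import numpy as np

num_queries, act_dim, k = 3, 2, 0.01
max_episode_len = 10
all_time_actions = np.zeros((max_episode_len, max_episode_len + num_queries, act_dim))

def aggregate(step, new_chunk):
    """Store the chunk predicted at `step`, return the blended action for `step`."""
    all_time_actions[step, step : step + num_queries] = new_chunk
    candidates = all_time_actions[:, step]                    # every chunk's guess for this step
    populated = candidates[np.all(candidates != 0, axis=1)]   # keep rows that were filled in
    weights = np.exp(-k * np.arange(len(populated)))          # older chunks weighted slightly more
    weights /= weights.sum()
    return (populated * weights[:, None]).sum(axis=0)

chunk0 = np.array([[1.0, 0.2], [1.0, 0.3], [1.0, 0.4]])  # predicted at step 0
chunk1 = np.array([[0.9, 0.1], [0.9, 0.2], [0.9, 0.3]])  # predicted at step 1
print(aggregate(0, chunk0))  # -> chunk0[0]
print(aggregate(1, chunk1))  # -> blend of chunk0[1] and chunk1[0]
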