├── .gitignore ├── data4robotics ├── models │ ├── __init__.py │ ├── base.py │ ├── resnet.py │ ├── action_distributions.py │ └── vit.py ├── trainers │ ├── __init__.py │ ├── bc.py │ └── base.py ├── __init__.py ├── load_pretrained.py ├── task.py ├── transforms.py ├── agent.py ├── misc.py └── replay_buffer.py ├── experiments ├── agent │ ├── features │ │ ├── r3m.yaml │ │ ├── resnet_gn.yaml │ │ └── vit_base.yaml │ ├── policy │ │ ├── gaussian_mixture.yaml │ │ └── gaussian_constant.yaml │ └── default.yaml ├── trainer │ └── bc.yaml ├── hydra │ └── launcher │ │ └── slurm.yaml ├── task │ └── franka.yaml └── finetune.yaml ├── setup.py ├── download_features.sh ├── pretrained_networks_example.py ├── env.yml ├── LICENSE.md ├── README.md └── finetune.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.jpg 2 | *.gif 3 | *.mp4 4 | *.png 5 | __pycache__/ 6 | bc_finetune/ 7 | visual_features/ 8 | *.egg-info/ -------------------------------------------------------------------------------- /data4robotics/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | -------------------------------------------------------------------------------- /data4robotics/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | -------------------------------------------------------------------------------- /experiments/agent/features/r3m.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.resnet.R3M 8 | size: 18 9 | -------------------------------------------------------------------------------- /data4robotics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from .load_pretrained import load_vit 8 | from .load_pretrained import load_resnet18 9 | -------------------------------------------------------------------------------- /experiments/trainer/bc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.trainers.bc.BehaviorCloning 8 | lr: 0.0001 9 | weight_decay: 0.0001 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from setuptools import setup 8 | 9 | setup(name="data4robotics", 10 | packages=["data4robotics"], 11 | version="0.1" 12 | ) 13 | -------------------------------------------------------------------------------- /download_features.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # make sure the folder doesn't already exist 4 | if [ -d "visual_features" ]; then 5 | echo "Data already downloaded!" 6 | exit 0 7 | fi 8 | 9 | wget --output-document features.zip https://cmu.box.com/shared/static/rrlrp5g6ynk03io4rj9uzf6nik5urfl6 10 | unzip features.zip 11 | rm features.zip 12 | 13 | -------------------------------------------------------------------------------- /experiments/agent/features/resnet_gn.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.resnet.ResNet 8 | size: 18 9 | pretrained: null 10 | restore_path: '' 11 | norm_cfg: 12 | name: group_norm 13 | num_groups: 16 14 | -------------------------------------------------------------------------------- /experiments/agent/policy/gaussian_mixture.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.action_distributions.GaussianMixture 8 | in_dim: ${index:${agent.shared_mlp}, -1} 9 | ac_dim: ${task.ac_dim} 10 | ac_chunk: ${ac_chunk} 11 | num_modes: 5 12 | -------------------------------------------------------------------------------- /experiments/agent/policy/gaussian_constant.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.action_distributions.GaussianSharedScale 8 | in_dim: ${index:${agent.shared_mlp}, -1} 9 | ac_dim: ${task.ac_dim} 10 | ac_chunk: ${ac_chunk} 11 | std_fixed: True 12 | -------------------------------------------------------------------------------- /experiments/agent/features/vit_base.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.vit.load_vit 8 | restore_path: '' 9 | model: 10 | _target_: data4robotics.models.vit.vit_base_patch16 11 | img_size: 224 12 | use_cls: True 13 | drop_path_rate: 0.0 14 | -------------------------------------------------------------------------------- /experiments/agent/default.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
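# Default agent configuration: a visual feature extractor ("features") feeding a
# shared MLP trunk and an action-distribution head ("policy"). The defaults list
# below selects vit_base features with a gaussian_mixture policy; either group can
# be swapped from the command line with a Hydra override, e.g.
# `agent/features=resnet_gn agent/policy=gaussian_constant`.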
5 | 6 | 7 | defaults: 8 | - features: vit_base 9 | - policy: gaussian_mixture 10 | - _self_ 11 | 12 | _target_: data4robotics.agent.Agent 13 | shared_mlp: [512,512] 14 | odim: ${task.obs_dim} 15 | n_cams: ${task.n_cams} 16 | use_obs: True 17 | dropout: 0.2 18 | -------------------------------------------------------------------------------- /experiments/hydra/launcher/slurm.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | defaults: 8 | - submitit_slurm 9 | 10 | timeout_min: 360 11 | partition: default 12 | tasks_per_node: ${devices} 13 | gpus_per_node: ${devices} 14 | cpus_per_task: ${num_workers} 15 | mem_gb: ${mult:${devices},124} 16 | nodes: 1 17 | max_num_timeout: 100 18 | -------------------------------------------------------------------------------- /pretrained_networks_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from data4robotics import load_vit, load_resnet18 8 | 9 | 10 | # load strongest vit/resnet models 11 | vit_transform, vit_model = load_vit() 12 | res_transform, res_model = load_resnet18() 13 | 14 | 15 | # get embeddings from each network 16 | input_img = torch.rand((1, 3, 480, 640)).cuda() 17 | emb_vit = vit_model(vit_transform(input_img)) 18 | emb_res = res_model(res_transform(input_img)) 19 | 20 | 21 | # print out shapes 22 | print('vit_base embedding shape:', emb_vit.shape) 23 | print('resnet18 embedding shape:', emb_res.shape) 24 | -------------------------------------------------------------------------------- /experiments/task/franka.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.task.DefaultTask 8 | obs_dim: 7 9 | ac_dim: 7 10 | n_cams: 1 11 | 12 | 13 | train_buffer: 14 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 15 | buffer_path: ${buffer_path} 16 | transform: ${transform:${train_transform}} 17 | n_test_trans: 500 18 | ac_chunk: ${ac_chunk} 19 | mode: train 20 | cam_idx: 0 21 | 22 | test_buffer: 23 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 24 | buffer_path: ${buffer_path} 25 | transform: ${transform:preproc} 26 | n_test_trans: 500 27 | ac_chunk: ${ac_chunk} 28 | mode: test 29 | cam_idx: ${task.train_buffer.cam_idx} 30 | -------------------------------------------------------------------------------- /data4robotics/models/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
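# BaseModel wraps a backbone module and optionally restores weights from a
# checkpoint whose state dict is stored under a 'features' or 'model' key;
# subclasses expose the backbone's embed_dim.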
5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class BaseModel(nn.Module): 12 | def __init__(self, model, restore_path): 13 | super().__init__() 14 | self._model = model 15 | if restore_path: 16 | print('Restoring model from', restore_path) 17 | state_dict = torch.load(restore_path, map_location='cpu') 18 | state_dict = state_dict['features'] if 'features' in state_dict \ 19 | else state_dict['model'] 20 | self.load_state_dict(state_dict) 21 | 22 | @property 23 | def embed_dim(self): 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: data4robotics 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python ==3.9 9 | - pytorch=2.0.1=py3.9_cuda11.8_cudnn8.7.0_0 10 | - torchvision 11 | - pytorch-cuda=11.8 12 | - cmake=3.22.1=h1fce559_0 13 | - numpy 14 | - pandas 15 | - plotly 16 | - pip 17 | - pytest==7.3.1 18 | - scipy 19 | - tqdm 20 | - patchelf=0.18.0=h59595ed_0 21 | - pip: 22 | - transforms3d 23 | - opencv-python 24 | - hydra-ray-launcher==1.2.0 25 | - hydra-core==1.2.0 26 | - hydra-submitit-launcher==1.2.0 27 | - wandb==0.13.4 28 | - timm==0.6.11 29 | - gym==0.23.1 30 | - huggingface-hub==0.12.1 31 | - dm-control==1.0.10 32 | - dm-env==1.6 33 | - dm-tree==0.1.8 34 | - cloudpickle==2.0.0 35 | - mujoco==2.3.2 36 | - imageio==2.22.1 37 | - imageio-ffmpeg==0.4.7 38 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /experiments/finetune.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
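# Top-level Hydra config consumed by finetune.py. Fields typically overridden from
# the command line are exp_name, buffer_path, and agent.features.restore_path (see
# the README example); ac_chunk and train_transform control action chunking and the
# image augmentation pipeline.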
5 | 
6 | 
7 | defaults:
8 |   - agent: default
9 |   - task: franka
10 |   - trainer: bc
11 |   - override hydra/launcher: slurm
12 |   - _self_
13 | 
14 | 
15 | hydra:
16 |   run:
17 |     dir: bc_finetune/${exp_name}/${hydra:runtime.choices.task}_${hydra:runtime.choices.agent/features}_${now:%Y-%m-%d_%H-%M-%S}
18 |   sweep:
19 |     dir: ${base:}/../bc_finetune/${exp_name}/${now:%Y-%m-%d_%H-%M-%S}
20 |     subdir: run${hydra:job.num}_${hydra:runtime.choices.task}_${hydra:runtime.choices.agent/features}
21 | 
22 | rt: ${hydra:runtime.choices.agent/features}
23 | 
24 | exp_name: test
25 | checkpoint_path: ${exp_name}.ckpt
26 | batch_size: 100
27 | num_workers: 10
28 | max_iterations: 50000
29 | eval_freq: 10000
30 | save_freq: 10000
31 | devices: 1
32 | seed: 292285
33 | 
34 | buffer_path: ./buffer.pkl
35 | ac_chunk: 1
36 | train_transform: medium
37 | 
38 | wandb:
39 |   name: null
40 |   project: data_proj_finetune
41 |   group: ${exp_name}
42 |   sweep_name_prefix: eval
43 |   debug: False
44 | 
--------------------------------------------------------------------------------
/data4robotics/trainers/bc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sudeep Dasari, 2023
2 | 
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | 
7 | import torch
8 | from torch import nn
9 | from data4robotics.trainers.base import BaseTrainer
10 | 
11 | 
12 | class BehaviorCloning(BaseTrainer):
13 |     def __init__(self, agent, device_id, lr=1e-4, weight_decay=1e-4):
14 |         self._lr, self._weight_decay = lr, weight_decay
15 |         super().__init__(agent, device_id)
16 | 
17 |     def training_step(self, batch, global_step):
18 |         (imgs, obs), actions, _ = batch
19 |         imgs, obs, actions = [ar.to(self.device_id) for ar in \
20 |                               (imgs, obs, actions)]
21 | 
22 |         action_dist = self.model(imgs, obs)
23 |         ac_flat = actions.reshape((actions.shape[0], -1))
24 |         loss = -torch.mean(action_dist.log_prob(ac_flat))
25 | 
26 |         self.log("bc_loss", global_step, loss.item())
27 |         return loss
28 | 
29 |     def configure_optimizers(self):
30 |         optimizer = torch.optim.Adam(self.model.parameters(), lr=self._lr,
31 |                                      weight_decay=self._weight_decay)
32 |         return optimizer
33 | 
34 |     def _save_callback(self, save_path, _):
35 |         # pickle and save the agent also
36 |         torch.save(self.model, 'agent.pkl')
37 | 
--------------------------------------------------------------------------------
/data4robotics/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sudeep Dasari, 2023
2 | 
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
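# ResNet visual encoders with a configurable normalization layer (the resnet_gn
# config and the released resnet18 checkpoint use group_norm); R3M wraps the
# pre-trained R3M ResNet backbone returned by load_r3m.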
5 | 
6 | 
7 | import torch
8 | import torch.nn as nn
9 | from r3m import load_r3m
10 | from torchvision import models
11 | from torch.nn.modules.linear import Identity
12 | from data4robotics.models.base import BaseModel
13 | 
14 | 
15 | def _make_norm(norm_cfg):
16 |     if norm_cfg['name'] == 'batch_norm':
17 |         return nn.BatchNorm2d
18 |     if norm_cfg['name'] == 'group_norm':
19 |         num_groups = norm_cfg['num_groups']
20 |         return lambda num_channels: nn.GroupNorm(num_groups, num_channels)
21 |     raise NotImplementedError(f"Missing norm layer: {norm_cfg['name']}")
22 | 
23 | 
24 | def _construct_resnet(size, norm, weights=None):
25 |     if size == 18:
26 |         return models.resnet18(weights=weights, norm_layer=norm)
27 |     if size == 34:
28 |         return models.resnet34(weights=weights, norm_layer=norm)
29 |     if size == 50:
30 |         return models.resnet50(weights=weights, norm_layer=norm)
31 |     raise NotImplementedError(f'Missing size: {size}')
32 | 
33 | 
34 | class ResNet(BaseModel):
35 |     def __init__(self, size, norm_cfg, pretrained=None, restore_path=''):
36 |         norm_layer = _make_norm(norm_cfg)
37 |         model = _construct_resnet(size, norm_layer, pretrained)
38 |         model.fc = Identity()
39 |         super().__init__(model, restore_path)
40 |         self._size = size
41 | 
42 |     def forward(self, x):
43 |         return self._model(x)
44 | 
45 |     @property
46 |     def embed_dim(self):
47 |         return {18: 512, 34: 512, 50: 2048}[self._size]
48 | 
49 | 
50 | class R3M(ResNet):
51 |     def __init__(self, size):
52 |         nn.Module.__init__(self)
53 |         self._model = load_r3m(f'resnet{size}').module.convnet.cpu()
54 |         self._size = size
55 | 
--------------------------------------------------------------------------------
/data4robotics/load_pretrained.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sudeep Dasari, 2023
2 | 
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
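# Convenience loaders for the released visual features. Both load_vit() and
# load_resnet18() download the checkpoints via download_features.sh if needed and
# return a (transform, model) pair; see pretrained_networks_example.py for usage.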
5 | 6 | 7 | import os, torch, data4robotics 8 | from torchvision import transforms 9 | from data4robotics.models import vit 10 | from data4robotics.models import resnet 11 | 12 | 13 | # feature install path 14 | BASE_PATH = os.path.dirname(data4robotics.__file__) + "/../" 15 | FEATURE_PATH = os.path.join(BASE_PATH, "visual_features") 16 | 17 | 18 | def _check_and_download(): 19 | old_cwd = os.getcwd() 20 | 21 | # change cwd to main folder and run download script 22 | os.chdir(BASE_PATH) 23 | download_script = os.path.join(BASE_PATH, 'download_features.sh') 24 | os.system(download_script) 25 | 26 | # change cwd back to old location 27 | os.chdir(old_cwd) 28 | 29 | 30 | def default_transform(): 31 | return transforms.Compose([transforms.Resize((224, 224), antialias=False), 32 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 33 | std=[0.229, 0.224, 0.225])]) 34 | 35 | 36 | def load_vit(model_name="SOUP_1M_DH", device=torch.device('cuda:0')): 37 | _check_and_download() 38 | model = vit.vit_base_patch16(img_size=224, use_cls=True, drop_path_rate=0.0) 39 | restore_path = os.path.join(FEATURE_PATH, f'vit_base/{model_name}.pth') 40 | model = vit.load_vit(model, restore_path) 41 | return default_transform(), model.to(device) 42 | 43 | 44 | def load_resnet18(model_name="IN_1M_resnet18", device=torch.device('cuda:0')): 45 | _check_and_download() 46 | restore_path = os.path.join(FEATURE_PATH, f'resnet18/{model_name}.pth') 47 | model = resnet.ResNet(size=18, pretrained=None, restore_path=restore_path, 48 | norm_cfg=dict(name='group_norm', num_groups=16)) 49 | return default_transform(), model.to(device) 50 | -------------------------------------------------------------------------------- /data4robotics/task.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
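# Task wrapper: builds the train/test DataLoaders from the replay buffers (the train
# buffer is wrapped in an IterableWrapper) and computes/logs the validation loss in eval().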
5 | 6 | 7 | import torch, wandb, json, cv2, imageio, os 8 | import numpy as np 9 | from torch.utils.data import DataLoader, IterableDataset 10 | from data4robotics.replay_buffer import IterableWrapper 11 | _TEST_WORKERS = 4 12 | 13 | 14 | def _build_data_loader(buffer, batch_size, num_workers, is_train=False): 15 | if is_train and not isinstance(buffer, IterableDataset): 16 | buffer = IterableWrapper(buffer) 17 | 18 | return DataLoader(buffer, batch_size=batch_size, 19 | num_workers=num_workers, 20 | shuffle=not isinstance(buffer, IterableDataset), 21 | pin_memory=True, 22 | persistent_workers=True, 23 | drop_last=True, 24 | worker_init_fn= lambda _: np.random.seed()) 25 | 26 | 27 | class DefaultTask: 28 | def __init__(self, train_buffer, test_buffer, n_cams, obs_dim, ac_dim, 29 | batch_size, num_workers): 30 | self.n_cams, self.obs_dim, self.ac_dim = n_cams, obs_dim, ac_dim 31 | self.train_loader = _build_data_loader(train_buffer, batch_size, num_workers, 32 | is_train=True) 33 | if test_buffer is not None: 34 | test_workers = min(num_workers, _TEST_WORKERS) 35 | self.test_loader = _build_data_loader(test_buffer, batch_size, test_workers) 36 | 37 | def eval(self, trainer, global_step): 38 | losses = [] 39 | for batch in self.test_loader: 40 | with torch.no_grad(): 41 | loss = trainer.training_step(batch, global_step) 42 | losses.append(loss.item()) 43 | 44 | mean_val_loss = np.mean(losses) 45 | print(f'Step: {global_step}\tVal Loss: {mean_val_loss:.4f}') 46 | if wandb.run is not None: 47 | wandb.log({'eval/task_loss': mean_val_loss}, step=global_step) 48 | -------------------------------------------------------------------------------- /data4robotics/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
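# Image augmentation pipelines used during training/eval. get_transform_by_name()
# accepts 'preproc' (resize + normalize only), 'basic', 'medium', and 'advanced', and
# is exposed to the yaml configs through the ${transform:...} resolver registered in
# misc.py (e.g. train_transform: medium in finetune.yaml).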
5 | 6 | 7 | from torchvision import transforms 8 | 9 | 10 | def image_norm(size=224): 11 | return transforms.Compose([transforms.Resize((size, size)), 12 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 13 | 14 | 15 | def get_transform_by_name(name, size=224): 16 | if name == 'preproc': 17 | return transforms.Compose([transforms.Resize((size, size), antialias=False), 18 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 19 | if name == 'basic': 20 | return transforms.Compose([transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0), antialias=False), 21 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 22 | if name == 'medium': 23 | kernel_size = int(0.05 * size); kernel_size = kernel_size + (1 - kernel_size % 2) 24 | return transforms.Compose([transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0), antialias=False), 25 | transforms.GaussianBlur(kernel_size=kernel_size), 26 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 27 | if name == 'advanced': 28 | kernel_size = int(0.05 * size); kernel_size = kernel_size + (1 - kernel_size % 2) 29 | color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 30 | return transforms.Compose([transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0), antialias=False), 31 | transforms.RandomApply([color_jitter], p=0.8), 32 | transforms.RandomGrayscale(p=0.2), 33 | transforms.GaussianBlur(kernel_size=kernel_size), 34 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 35 | raise NotImplementedError(f'{name} not found!') 36 | -------------------------------------------------------------------------------- /data4robotics/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
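# Policy network wrapper: images are embedded with the visual features backbone,
# concatenated with the proprioceptive obs vector (if use_obs), passed through the
# shared MLP trunk, and decoded into an action distribution by the policy head.
# forward() accepts imgs of shape (B, n_cams, C, H, W) or (B, C, H, W).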
5 | 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | class Agent(nn.Module): 12 | def __init__(self, features, policy, shared_mlp, odim, 13 | n_cams, use_obs, dropout=0): 14 | super().__init__() 15 | 16 | # store visual, policy, and inverse model 17 | self.visual_features = features 18 | self._policy = policy 19 | 20 | # build shared mlp layers 21 | self._odim = odim if use_obs else 0 22 | self._use_obs, self._n_cams = bool(use_obs), n_cams 23 | mlp_in = self._odim + n_cams * features.embed_dim 24 | mlp_def = [mlp_in] + shared_mlp 25 | layers = [nn.BatchNorm1d(num_features=mlp_in)] 26 | for i, o in zip(mlp_def[:-1], mlp_def[1:]): 27 | layers.append(nn.Dropout(dropout)) 28 | layers.append(nn.Linear(i, o)) 29 | layers.append(nn.ReLU()) 30 | layers.append(nn.Dropout(dropout)) 31 | self._shared_mlp = nn.Sequential(*layers) 32 | 33 | def forward(self, imgs, obs, zero_std=False): 34 | s_t = self._shared_forward(imgs, obs) 35 | action_dist = self._policy(s_t, zero_std=zero_std) 36 | return action_dist 37 | 38 | def get_actions(self, img, obs, zero_std=True): 39 | policy_in = self._shared_forward(img, obs) 40 | return self._policy.get_actions(policy_in, zero_std=zero_std) 41 | 42 | def _shared_forward(self, imgs, obs): 43 | shared_in = torch.cat((self.embed(imgs), obs), dim=1) if self._use_obs \ 44 | else self.embed(imgs) 45 | return self._shared_mlp(shared_in) 46 | 47 | def embed(self, imgs): 48 | if len(imgs.shape) == 5: 49 | B, N, C, H, W = imgs.shape 50 | embeds = self.visual_features(imgs.reshape((B * N, C, H, W))) 51 | embeds = embeds.reshape((B, N * self.visual_features.embed_dim)) 52 | return embeds 53 | return self.visual_features(imgs) 54 | 55 | @property 56 | def odim(self): 57 | return self._odim 58 | 59 | @property 60 | def n_cams(self): 61 | return self._n_cams 62 | 63 | @property 64 | def ac_chunk(self): 65 | return self._policy.ac_chunk 66 | 67 | def restore_features(self, restore_path): 68 | if not restore_path: 69 | print('No restore path supplied!') 70 | return 71 | state_dict = torch.load(restore_path, map_location='cpu')['features'] 72 | self.visual_features.load_state_dict(state_dict) 73 | print(f"Restored {restore_path}!") 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # An Unbiased Look at Datasets for Visuo-Motor Pre-Training 2 | [[Project Page]](https://data4robotics.github.io/) 3 | 4 | This repository offers a minimal Behavior Cloning (BC) implementation using pre-trained representations from our CoRL project. All tests were conducted on a Franka Panda robot, using the [polymetis controller](https://facebookresearch.github.io/fairo/polymetis/). We've also verified that it works on the [R2D2 control stack](https://github.com/AlexanderKhazatsky/R2D2/tree/main). 
5 | 
6 | If you find this codebase or our pre-trained representations useful at all, please cite:
7 | ```
8 | @inproceedings{dasari2023datasets,
9 |     title={An Unbiased Look at Datasets for Visuo-Motor Pre-Training},
10 |     author={Dasari, Sudeep and Srirama, Mohan Kumar and Jain, Unnat and Gupta, Abhinav},
11 |     booktitle={Conference on Robot Learning},
12 |     year={2023},
13 |     organization={PMLR}
14 | }
15 | ```
16 | ## Installation
17 | Our repository is easy to install using miniconda or anaconda:
18 | 
19 | ```
20 | conda env create -f env.yml
21 | conda activate data4robotics
22 | pip install git+https://github.com/AGI-Labs/robobuf.git
23 | pip install git+https://github.com/facebookresearch/r3m.git
24 | pip install -e ./
25 | ```
26 | 
27 | ## Using Pre-Trained Features
28 | You can easily download our pre-trained representations using the provided script: `./download_features.sh`
29 | 
30 | The features are very modular, and easy to use in your own codebase! Please refer to the [example code](https://github.com/SudeepDasari/data4robotics/blob/main/pretrained_networks_example.py) if you're interested in this.
31 | 
32 | ## Training BC Policies
33 | First, you're going to need to convert your training trajectories into our [robobuf](https://github.com/AGI-Labs/robobuf/tree/main) format (pseudo-code below).
34 | ```
35 | def _resize_and_encode(img, size=(256,256)):
36 |     img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
37 |     _, encoded = cv2.imencode(".jpg", img)
38 |     return encoded
39 | 
40 | def convert_trajectories(input_trajs, out_path):
41 |     out_buffer = []
42 |     for traj in tqdm(input_trajs):
43 |         out_traj = []
44 |         for in_obs, in_ac, in_reward in traj:
45 |             out_obs = dict(state=np.array(in_obs['state']).astype(np.float32),
46 |                            enc_cam_0=_resize_and_encode(in_obs['image']))
47 |             out_action = np.array(in_ac).astype(np.float32)
48 |             out_reward = float(in_reward)
49 |             out_traj.append((out_obs, out_action, out_reward))
50 |         out_buffer.append(out_traj)
51 | 
52 |     with open(os.path.join(out_path, 'buf.pkl'), 'wb') as f:
53 |         pkl.dump(out_buffer, f)
54 | ```
55 | 
56 | Once the conversion is complete, you can run the example command below:
57 | ```
58 | python finetune.py exp_name=test agent.features.restore_path=/path/to/SOUP_1M_DH.pth buffer_path=/data/path/buffer.pkl
59 | ```
60 | This will result in a policy checkpoint saved in the `bc_finetune/` folder.
61 | 
--------------------------------------------------------------------------------
/data4robotics/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sudeep Dasari, 2023
2 | 
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
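# Miscellaneous training utilities: the custom OmegaConf resolvers (env, base,
# transform, mult, index) referenced by the yaml configs, SIGUSR2 checkpoint/requeue
# handling for slurm/submitit, and wandb run creation/resumption.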
5 | 6 | 7 | import wandb, torch, signal, functools, time, sys, os, yaml 8 | import numpy as np 9 | from omegaconf import OmegaConf 10 | from hydra.core.hydra_config import HydraConfig 11 | from data4robotics.transforms import get_transform_by_name 12 | OmegaConf.register_new_resolver("env", lambda x: os.environ[x]) 13 | OmegaConf.register_new_resolver("base", lambda: os.path.dirname(os.path.abspath(__file__))) 14 | OmegaConf.register_new_resolver("transform", lambda name: get_transform_by_name(name)) 15 | OmegaConf.register_new_resolver("mult", lambda x, y: int(x) * int(y)) 16 | OmegaConf.register_new_resolver("index", lambda arr, idx: arr[idx]) 17 | 18 | 19 | GLOBAL_STEP = 0 20 | REQUEUE_CAUGHT = False 21 | 22 | 23 | def _signal_helper(signal, frame, prior_handler, trainer, ckpt_path): 24 | global REQUEUE_CAUGHT, GLOBAL_STEP 25 | REQUEUE_CAUGHT = True 26 | 27 | # save train checkpoint 28 | print(f'Caught requeue signal at step: {GLOBAL_STEP}') 29 | trainer.save_checkpoint(ckpt_path, GLOBAL_STEP) 30 | 31 | # return back to submitit handler if it exists 32 | if callable(prior_handler): 33 | return prior_handler(signal, frame) 34 | return sys.exit(-1) 35 | 36 | 37 | def set_checkpoint_handler(trainer, ckpt_path): 38 | global REQUEUE_CAUGHT 39 | REQUEUE_CAUGHT = False 40 | prior_handler = signal.getsignal(signal.SIGUSR2) 41 | handler = functools.partial(_signal_helper, prior_handler=prior_handler, 42 | trainer=trainer, 43 | ckpt_path=ckpt_path) 44 | signal.signal(signal.SIGUSR2, handler) 45 | 46 | 47 | def create_wandb_run(wandb_cfg, job_config, run_id=None): 48 | if wandb_cfg.debug: 49 | return 'null_id' 50 | try: 51 | job_id = HydraConfig().get().job.num 52 | override_dirname = HydraConfig().get().job.override_dirname 53 | name = f'{wandb_cfg.sweep_name_prefix}-{job_id}' 54 | notes = f'{override_dirname}' 55 | except: 56 | name, notes = wandb_cfg.name, None 57 | 58 | wandb_run = wandb.init( 59 | project=wandb_cfg.project, 60 | config=job_config, 61 | group=wandb_cfg.group, 62 | name=name, 63 | notes=notes, 64 | id=run_id, 65 | resume=run_id is not None 66 | ) 67 | return wandb_run.id 68 | 69 | 70 | def init_job(cfg): 71 | cfg_yaml = OmegaConf.to_yaml(cfg) 72 | if os.path.exists('exp_config.yaml'): 73 | old_config = yaml.safe_load(open('exp_config.yaml', 'r')) 74 | create_wandb_run(cfg.wandb, old_config['params'], old_config['wandb_id']) 75 | resume_model = cfg.checkpoint_path 76 | assert os.path.exists(resume_model), '{} does not exist!'.format(cfg.checkpoint_path) 77 | else: 78 | params = yaml.safe_load(cfg_yaml) 79 | wandb_id = create_wandb_run(cfg.wandb, params) 80 | save_dict = dict(wandb_id=wandb_id, params=params) 81 | yaml.dump(save_dict, open('exp_config.yaml', 'w')) 82 | resume_model = None 83 | print('Training w/ Config:') 84 | print(cfg_yaml) 85 | return resume_model 86 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
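# Behavior-cloning finetuning entry point (see the README), launched with Hydra, e.g.:
#   python finetune.py exp_name=test agent.features.restore_path=/path/to/SOUP_1M_DH.pth buffer_path=/data/path/buffer.pkl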
5 | 6 | 7 | import os, hydra, traceback, torch, tqdm, yaml 8 | import numpy as np 9 | from data4robotics import misc 10 | from omegaconf import DictConfig, OmegaConf 11 | base_path = os.path.dirname(os.path.abspath(__file__)) 12 | 13 | 14 | @hydra.main(config_path=os.path.join(base_path, 'experiments'), config_name="finetune.yaml") 15 | def bc_finetune(cfg: DictConfig): 16 | try: 17 | resume_model = misc.init_job(cfg) 18 | 19 | # set random seeds for reproducibility 20 | torch.manual_seed(cfg.seed) 21 | np.random.seed(cfg.seed + 1) 22 | 23 | # build agent from hydra configs 24 | with open('agent_config.yaml', 'w') as f: 25 | agent_yaml = OmegaConf.to_yaml(cfg.agent, resolve=True) 26 | f.write(agent_yaml) 27 | 28 | agent = hydra.utils.instantiate(cfg.agent) 29 | trainer = hydra.utils.instantiate(cfg.trainer, agent=agent, device_id=0) 30 | 31 | # build task, replay buffer, and dataloader 32 | task = hydra.utils.instantiate(cfg.task, batch_size=cfg.batch_size, 33 | num_workers=cfg.num_workers) 34 | 35 | # restore/save the model as required 36 | if resume_model is not None: 37 | misc.GLOBAL_STEP = trainer.load_checkpoint(resume_model) 38 | elif misc.GLOBAL_STEP == 0: 39 | trainer.save_checkpoint(cfg.checkpoint_path, misc.GLOBAL_STEP) 40 | assert misc.GLOBAL_STEP >= 0, "GLOBAL_STEP not loaded correctly!" 41 | 42 | # register checkpoint handler and enter train loop 43 | misc.set_checkpoint_handler(trainer, cfg.checkpoint_path) 44 | print(f'Starting at Global Step {misc.GLOBAL_STEP}') 45 | 46 | trainer.set_train() 47 | train_iterator = iter(task.train_loader) 48 | for itr in (pbar := tqdm.tqdm(range(cfg.max_iterations), postfix=dict(Loss=None))): 49 | if itr < misc.GLOBAL_STEP: 50 | continue 51 | 52 | # infinitely sample batches until the train loop is finished 53 | try: 54 | batch = next(train_iterator) 55 | except StopIteration: 56 | train_iterator = iter(task.train_loader) 57 | batch = next(train_iterator) 58 | 59 | trainer.optim.zero_grad() 60 | loss = trainer.training_step(batch, misc.GLOBAL_STEP) 61 | loss.backward() 62 | trainer.optim.step() 63 | 64 | pbar.set_postfix(dict(Loss=loss.item())) 65 | misc.GLOBAL_STEP += 1 66 | 67 | if misc.GLOBAL_STEP % cfg.eval_freq == 0: 68 | trainer.set_eval() 69 | task.eval(trainer, misc.GLOBAL_STEP) 70 | trainer.set_train() 71 | 72 | if misc.GLOBAL_STEP >= cfg.max_iterations: 73 | trainer.save_checkpoint(cfg.checkpoint_path, misc.GLOBAL_STEP) 74 | return 75 | elif misc.GLOBAL_STEP % cfg.save_freq == 0: 76 | trainer.save_checkpoint(cfg.checkpoint_path, misc.GLOBAL_STEP) 77 | 78 | # gracefully handle and log errors 79 | except: 80 | traceback.print_exc(file=open('exception.log', 'w')) 81 | with open('exception.log', 'r') as f: 82 | print(f.read()) 83 | 84 | 85 | if __name__ == '__main__': 86 | bc_finetune() 87 | -------------------------------------------------------------------------------- /data4robotics/trainers/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
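# Trainer base class: subclasses implement training_step() and configure_optimizers();
# BaseTrainer provides checkpoint save/load, optional DDP wrapping, train/eval mode
# switching, and wandb logging through running means.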
5 | 
6 | 
7 | import torch, wandb
8 | import numpy as np
9 | from abc import ABC, abstractmethod
10 | from torch.nn.parallel import DistributedDataParallel as DDP
11 | 
12 | 
13 | TRAIN_LOG_FREQ, EVAL_LOG_FREQ = 100, 1
14 | class RunningMean:
15 |     def __init__(self, max_len=TRAIN_LOG_FREQ):
16 |         self._values = []
17 |         self._ctr, self._max_len = 0, max_len
18 | 
19 |     def append(self, item):
20 |         self._ctr = (self._ctr + 1) % self._max_len
21 |         if len(self._values) < self._max_len:
22 |             self._values.append(item)
23 |         else:
24 |             self._values[self._ctr] = item
25 | 
26 |     @property
27 |     def mean(self):
28 |         if len(self._values) == 0:
29 |             raise ValueError
30 |         return np.mean(self._values)
31 | 
32 | 
33 | class BaseTrainer(ABC):
34 |     def __init__(self, model, device_id):
35 |         self.model, self.device_id = model, device_id
36 |         self.set_device(device_id)
37 |         self.optim = self.configure_optimizers()
38 |         self._trackers = dict()
39 |         self._is_train = True; self.set_train()
40 | 
41 |     @abstractmethod
42 |     def training_step(self, batch_input, global_step):
43 |         pass
44 | 
45 |     @abstractmethod
46 |     def configure_optimizers(self):
47 |         pass
48 | 
49 |     def save_checkpoint(self, save_path, global_step):
50 |         model = self.model
51 |         model_weights = model.module.state_dict() if isinstance(model, DDP) \
52 |                         else model.state_dict()
53 |         save_dict = dict(model=model_weights,
54 |                          optim=self.optim.state_dict(),
55 |                          global_step = global_step)
56 |         torch.save(save_dict, save_path)
57 | 
58 |     def _save_callback(self, save_path, save_dict):
59 |         pass
60 | 
61 |     def load_checkpoint(self, load_path):
62 |         load_dict = torch.load(load_path)
63 |         model = self.model
64 |         model = model.module if isinstance(model, DDP) \
65 |                 else model
66 |         model.load_state_dict(load_dict['model'])
67 |         self.optim.load_state_dict(load_dict['optim'])
68 |         return load_dict['global_step']
69 | 
70 |     def _load_callback(self, load_path, load_dict):
71 |         pass
72 | 
73 |     def wrap_ddp(self):
74 |         self.model = DDP(self.model, device_ids=[self.device_id])
75 | 
76 |     def set_train(self):
77 |         self._is_train = True
78 |         self.model = self.model.train()
79 | 
80 |     def set_eval(self):
81 |         self._is_train = False
82 |         self.model = self.model.eval()
83 | 
84 |         # reset running mean for eval trackers
85 |         for k in self._trackers:
86 |             if 'eval/' in k:
87 |                 self._trackers[k] = RunningMean()
88 | 
89 |     def log(self, key, global_step, value):
90 |         log_freq = TRAIN_LOG_FREQ if self._is_train else EVAL_LOG_FREQ
91 |         key_prepend = 'train/' if self._is_train else 'eval/'
92 |         key = key_prepend + key
93 | 
94 |         if key not in self._trackers:
95 |             self._trackers[key] = RunningMean()
96 | 
97 |         tracker = self._trackers[key]
98 |         tracker.append(value)
99 | 
100 |         if global_step % log_freq == 0 and wandb.run is not None:
101 |             wandb.log({key: tracker.mean}, step=global_step)
102 | 
103 |     def set_device(self, device_id):
104 |         self.model = self.model.to(device_id)
105 | 
--------------------------------------------------------------------------------
/data4robotics/replay_buffer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sudeep Dasari, 2023
2 | 
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
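# Datasets over (s_t, a_t, s_t+1) transitions. ReplayBuffer expects a pickled list of
# trajectory dicts with 'images', 'observations', and 'actions' arrays, while
# RobobufReplayBuffer loads a robobuf trajectory list (the format produced by the
# conversion snippet in the README). IterableWrapper re-samples random transitions
# indefinitely (up to max_count) for the training loop.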
5 | 6 | 7 | import random, torch, tqdm 8 | import numpy as np 9 | import pickle as pkl 10 | from robobuf import ReplayBuffer as RB 11 | from torch.utils.data import Dataset, IterableDataset 12 | 13 | 14 | # helper functions 15 | _img_to_tensor = lambda x: torch.from_numpy(x.copy()).permute((0, 3, 1, 2)).float() / 255 16 | _to_tensor = lambda x: torch.from_numpy(x).float() 17 | 18 | 19 | BUF_SHUFFLE_RNG = 3904767649 20 | class ReplayBuffer(Dataset): 21 | def __init__(self, buffer_path, transform=None, n_train_demos=200, mode='train', ac_chunk=1): 22 | assert mode in ('train', 'test'), "Mode must be train/test" 23 | buffer_data = self._load_buffer(buffer_path) 24 | assert len(buffer_data) >= n_train_demos, "Not enough demos!" 25 | 26 | # shuffle the list with the fixed seed 27 | rng = random.Random(BUF_SHUFFLE_RNG) 28 | rng.shuffle(buffer_data) 29 | 30 | # split data according to mode 31 | buffer_data = buffer_data[:n_train_demos] if mode == 'train' \ 32 | else buffer_data[n_train_demos:] 33 | 34 | self.transform = transform 35 | self.s_a_sprime = [] 36 | for traj in tqdm.tqdm(buffer_data): 37 | imgs, obs, acs = traj['images'], traj['observations'], traj['actions'] 38 | assert len(obs) == len(acs) and len(acs) == len(imgs), "All time dimensions must match!" 39 | 40 | # pad camera dimension if needed 41 | if len(imgs.shape) == 4: 42 | imgs = imgs[:,None] 43 | 44 | for t in range(len(imgs) - ac_chunk): 45 | i_t, o_t = imgs[t], obs[t] 46 | i_t_prime, o_t_prime = imgs[t+ac_chunk], obs[t+ac_chunk] 47 | a_t = acs[t:t+ac_chunk] 48 | self.s_a_sprime.append(((i_t, o_t), a_t, (i_t_prime, o_t_prime))) 49 | 50 | def _load_buffer(self, buffer_path): 51 | print('loading', buffer_path) 52 | with open(buffer_path, 'rb') as f: 53 | buffer_data = pkl.load(f) 54 | return buffer_data 55 | 56 | def __len__(self): 57 | return len(self.s_a_sprime) 58 | 59 | def __getitem__(self, idx): 60 | (i_t, o_t), a_t, (i_t_prime, o_t_prime) = self.s_a_sprime[idx] 61 | 62 | i_t, i_t_prime = _img_to_tensor(i_t), _img_to_tensor(i_t_prime) 63 | o_t, a_t, o_t_prime = _to_tensor(o_t), _to_tensor(a_t), _to_tensor(o_t_prime) 64 | 65 | if self.transform is not None: 66 | N_CAM = i_t.shape[0] 67 | imgs = torch.cat((i_t, i_t_prime), dim=0) 68 | imgs = self.transform(imgs) 69 | i_t, i_t_prime = imgs[:N_CAM], imgs[N_CAM:] 70 | return (i_t, o_t), a_t, (i_t_prime, o_t_prime) 71 | 72 | 73 | def _embed_img(features, img, device, transform): 74 | img = transform(_img_to_tensor(img)) 75 | with torch.no_grad(): 76 | feat = features(img.to(device)) 77 | return feat.reshape(-1).cpu().numpy().astype(np.float32) 78 | 79 | 80 | class IterableWrapper(IterableDataset): 81 | def __init__(self, wrapped_dataset, max_count=float('inf')): 82 | self.wrapped = wrapped_dataset 83 | self.ctr, self.max_count = 0, max_count 84 | 85 | def __iter__(self): 86 | self.ctr = 0 87 | return self 88 | 89 | def __next__(self): 90 | if self.ctr > self.max_count: 91 | raise StopIteration 92 | 93 | self.ctr += 1 94 | idx = int(np.random.choice(len(self.wrapped))) 95 | return self.wrapped[idx] 96 | 97 | 98 | class RobobufReplayBuffer(ReplayBuffer): 99 | def __init__(self, buffer_path, transform=None, n_test_trans=500, mode='train', ac_chunk=1, cam_idx=0): 100 | assert mode in ('train', 'test'), "Mode must be train/test" 101 | with open(buffer_path, 'rb') as f: 102 | buf = RB.load_traj_list(pkl.load(f)) 103 | assert len(buf) > n_test_trans, "Not enough transitions!" 104 | assert ac_chunk == 1, "Only supports ac_chunk of 1 for now!" 
105 | 106 | # shuffle the list with the fixed seed 107 | rng = random.Random(BUF_SHUFFLE_RNG) 108 | 109 | # get and shuffle list of buf indices 110 | index_list = list(range(len(buf))) 111 | rng.shuffle(index_list) 112 | 113 | # split data according to mode 114 | index_list = index_list[n_test_trans:] if mode == 'train' \ 115 | else index_list[:n_test_trans] 116 | 117 | self.transform = transform 118 | self.s_a_sprime = [] 119 | last = 0 120 | print(f'Building {mode} buffer with cam_idx={cam_idx}') 121 | for idx in tqdm.tqdm(index_list): 122 | t = buf[idx] 123 | if t.next is None: 124 | last += 1 125 | continue 126 | 127 | i_t, o_t = t.obs.image(cam_idx)[None], t.obs.state 128 | i_t_prime, o_t_prime = t.next.obs.image(cam_idx)[None], t.next.obs.state 129 | a_t = t.action 130 | self.s_a_sprime.append(((i_t, o_t), a_t, (i_t_prime, o_t_prime))) 131 | -------------------------------------------------------------------------------- /data4robotics/models/action_distributions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import torch.distributions as D 11 | 12 | 13 | class ActionDistribution(nn.Module): 14 | def __init__(self, ac_dim, ac_chunk=1): 15 | super().__init__() 16 | self._ac_chunk, self._ac_dim = ac_chunk, ac_dim 17 | 18 | @property 19 | def ac_dim(self): 20 | return self._ac_dim 21 | 22 | @property 23 | def ac_chunk(self): 24 | return self._ac_chunk 25 | 26 | @property 27 | def num_ac_pred(self): 28 | return self._ac_chunk * self._ac_dim 29 | 30 | def unflatten_ac_tensor(self, ac_tensor): 31 | out_shape = list(ac_tensor.shape[:-1]) + [self._ac_chunk, self._ac_dim] 32 | return ac_tensor.reshape(out_shape) 33 | 34 | def get_actions(self, inputs, zero_std=True): 35 | acs = self._sample(inputs, zero_std) 36 | return self.unflatten_ac_tensor(acs) 37 | 38 | def _sample(self, inputs, zero_std=True): 39 | dist = self(inputs, zero_std) 40 | return dist.sample() 41 | 42 | 43 | class Deterministic(ActionDistribution): 44 | def __init__(self, in_dim, ac_dim, ac_chunk=1): 45 | super().__init__(ac_dim, ac_chunk) 46 | self._layer = nn.Linear(in_dim, self.num_ac_pred) 47 | 48 | def forward(self, inputs, zero_std=True): 49 | assert zero_std, "No std prediction in this network!" 
50 | return self._layer(inputs) 51 | 52 | def _sample(self, inputs, zero_std=True): 53 | return self(inputs, zero_std) 54 | 55 | 56 | class Gaussian(ActionDistribution): 57 | def __init__(self, in_dim, ac_dim, ac_chunk=1, min_std=1e-4, tanh_mean=False): 58 | super().__init__(ac_dim, ac_chunk) 59 | self._min_std, self._tanh_mean = min_std, tanh_mean 60 | self._mean_net = nn.Linear(in_dim, self.num_ac_pred) 61 | self._scale_net = nn.Linear(in_dim, self.num_ac_pred) 62 | 63 | def forward(self, in_repr, zero_std=False): 64 | B = in_repr.shape[0] 65 | mean = self._mean_net(in_repr).reshape(B, self.num_ac_pred) 66 | scale = self._scale_net(in_repr).reshape(B, self.num_ac_pred) 67 | 68 | # bound the action means and convert scale to std 69 | if self._tanh_mean: 70 | mean = torch.tanh(mean) 71 | std = torch.ones_like(scale) * self._min_std if zero_std else \ 72 | F.softplus(scale) + self._min_std 73 | 74 | # create Normal action distributions 75 | return D.Normal(loc=mean, scale=std) 76 | 77 | 78 | class GaussianSharedScale(ActionDistribution): 79 | def __init__(self, in_dim, ac_dim, ac_chunk=1, min_std=1e-4, tanh_mean=False, 80 | log_std_init=0, std_fixed=False): 81 | super().__init__(ac_dim, ac_chunk) 82 | self._min_std, self._tanh_mean = min_std, tanh_mean 83 | self._mean_net = nn.Linear(in_dim, self.num_ac_pred) 84 | 85 | # create log_std vector and store as param 86 | log_std = torch.Tensor([log_std_init] * ac_dim) 87 | self.register_parameter('log_std', nn.Parameter(log_std, requires_grad=not std_fixed)) 88 | 89 | def forward(self, in_repr, zero_std=False): 90 | B = in_repr.shape[0] 91 | mean = self._mean_net(in_repr).reshape(B, self.num_ac_pred) 92 | scale = self.log_std[None].repeat((B, self._ac_chunk)) 93 | 94 | if self._tanh_mean: 95 | mean = torch.tanh(mean) 96 | std = torch.ones_like(scale) * self._min_std if zero_std else \ 97 | torch.exp(scale) + self._min_std 98 | 99 | # create Normal action distributions 100 | return D.Normal(loc=mean, scale=std) 101 | 102 | 103 | class GaussianMixture(ActionDistribution): 104 | def __init__(self, num_modes, in_dim, ac_dim, ac_chunk=1, min_std=1e-4, tanh_mean=False): 105 | super().__init__(ac_dim, ac_chunk) 106 | self._min_std, self._tanh_mean = min_std, tanh_mean 107 | self._num_modes = num_modes 108 | 109 | self._mean_net = nn.Linear(in_dim, num_modes * self.num_ac_pred) 110 | self._scale_net = nn.Linear(in_dim, num_modes * self.num_ac_pred) 111 | self._logit_net = nn.Linear(in_dim, num_modes) 112 | 113 | def forward(self, in_repr, zero_std=False): 114 | B = in_repr.shape[0] 115 | mean = self._mean_net(in_repr).reshape(B, self._num_modes, self.num_ac_pred) 116 | scale = self._scale_net(in_repr).reshape(B, self._num_modes, self.num_ac_pred) 117 | logits = self._logit_net(in_repr).reshape((B, self._num_modes)) 118 | 119 | # bound the action means and convert scale to std 120 | if self._tanh_mean: 121 | mean = torch.tanh(mean) 122 | std = torch.ones_like(scale) * self._min_std if zero_std else \ 123 | F.softplus(scale) + self._min_std 124 | 125 | # create num_modes independent action distributions 126 | ac_dist = D.Normal(loc=mean, scale=std) 127 | ac_dist = D.Independent(ac_dist, 1) 128 | 129 | # parameterize the mixing distribution and the final GMM 130 | mix_dist = D.Categorical(logits=logits) 131 | gmm_dist = D.MixtureSameFamily(mixture_distribution=mix_dist, 132 | component_distribution=ac_dist) 133 | return gmm_dist 134 | -------------------------------------------------------------------------------- /data4robotics/models/vit.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # adapted from: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | # Modified by Sudeep Dasari 12 | 13 | 14 | import os, torch 15 | import torch.nn as nn 16 | import timm.models.vision_transformer 17 | from functools import partial 18 | from timm.models.vision_transformer import resize_pos_embed 19 | 20 | 21 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 22 | """Vision Transformer with support for global average pooling""" 23 | 24 | def __init__( 25 | self, global_pool=False, use_cls=True, mask_ratio=None, del_head=True, **kwargs 26 | ): 27 | super(VisionTransformer, self).__init__(**kwargs) 28 | if global_pool: 29 | self.classifier_feature = "global_pool" 30 | elif use_cls: 31 | self.classifier_feature = "use_cls_token" 32 | else: 33 | self.classifier_feature = "reshape_embedding" 34 | 35 | if del_head: 36 | del self.head # don't use prediction head 37 | 38 | if self.classifier_feature == "global_pool": 39 | norm_layer = kwargs["norm_layer"] 40 | embed_dim = kwargs["embed_dim"] 41 | self.fc_norm = norm_layer(embed_dim) 42 | 43 | del self.norm # remove the original norm 44 | 45 | if self.classifier_feature == "reshape_embedding": 46 | self.final_spatial = int(self.patch_embed.num_patches**0.5) 47 | self.embed_dim = ( 48 | self.patch_embed.grid_size[0], 49 | self.patch_embed.grid_size[1], 50 | kwargs["embed_dim"], 51 | ) 52 | 53 | self.mask_ratio = mask_ratio 54 | 55 | def random_masking(self, x, mask_ratio): 56 | """ 57 | Perform per-sample random masking by per-sample shuffling. 58 | Per-sample shuffling is done by argsort random noise. 
59 | x: [N, L, D], sequence 60 | """ 61 | N, L, D = x.shape # batch, length, dim 62 | len_keep = int(L * (1 - mask_ratio)) 63 | 64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 65 | 66 | # sort noise for each sample 67 | ids_shuffle = torch.argsort( 68 | noise, dim=1 69 | ) # ascend: small is keep, large is remove 70 | ids_restore = torch.argsort(ids_shuffle, dim=1) 71 | 72 | # keep the first subset 73 | ids_keep = ids_shuffle[:, :len_keep] 74 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 75 | 76 | # generate the binary mask: 0 is keep, 1 is remove 77 | mask = torch.ones([N, L], device=x.device) 78 | mask[:, :len_keep] = 0 79 | # unshuffle to get the binary mask 80 | mask = torch.gather(mask, dim=1, index=ids_restore) 81 | 82 | return x_masked, mask, ids_restore 83 | 84 | def handle_outcome(self, x): 85 | if self.classifier_feature == "global_pool": 86 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 87 | outcome = self.fc_norm(x) 88 | elif self.classifier_feature == "use_cls_token": 89 | x = self.norm(x) 90 | outcome = x[:, 0] # use cls token 91 | elif self.classifier_feature == "reshape_embedding": 92 | x = self.norm(x) 93 | outcome = reshape_embedding( 94 | x[:, 1:] 95 | ) # remove cls token and reshape embedding 96 | else: 97 | raise NotImplementedError 98 | 99 | return outcome 100 | 101 | def forward_features(self, x): 102 | B = x.shape[0] 103 | x = self.patch_embed(x) 104 | 105 | # add pos embed w/o cls token 106 | x = x + self.pos_embed[:, 1:, :] 107 | 108 | # masking: length -> length * mask_ratio 109 | if self.mask_ratio is not None: 110 | x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio) 111 | 112 | # append cls token 113 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 114 | x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1) 115 | 116 | x = self.blocks(x) 117 | return self.handle_outcome(x) 118 | 119 | def forward(self, x): 120 | return self.forward_features(x) 121 | 122 | 123 | class ClipVisionTransformer(VisionTransformer): 124 | def forward_features(self, x): 125 | B = x.shape[0] 126 | x = self.patch_embed(x) 127 | x = torch.cat( 128 | [ 129 | self.cls_token.squeeze() 130 | + torch.zeros(B, 1, x.shape[-1], device=x.device), 131 | x, 132 | ], 133 | dim=1, 134 | ) # shape = [*, grid ** 2 + 1, width] 135 | x = x + self.pos_embed.squeeze().to(x.dtype) 136 | x = self.norm_pre(x) 137 | 138 | x = self.blocks(x) 139 | return self.handle_outcome(x) 140 | 141 | 142 | def reshape_embedding(x): 143 | N, L, D = x.shape 144 | H = W = int(L**0.5) 145 | x = x.reshape(N, H, W, D) 146 | x = torch.einsum("nhwd->ndhw", x) 147 | return x 148 | 149 | 150 | def vit_small_patch16(**kwargs): 151 | """ViT small as defined in the DeiT paper.""" 152 | model = VisionTransformer( 153 | patch_size=16, 154 | embed_dim=384, 155 | depth=12, 156 | num_heads=6, 157 | mlp_ratio=4, 158 | qkv_bias=True, 159 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 160 | **kwargs 161 | ) 162 | return model 163 | 164 | 165 | def vit_base_patch16(**kwargs): 166 | model = VisionTransformer( 167 | patch_size=16, 168 | embed_dim=768, 169 | depth=12, 170 | num_heads=12, 171 | mlp_ratio=4, 172 | qkv_bias=True, 173 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 174 | **kwargs 175 | ) 176 | return model 177 | 178 | 179 | def clip_vit_base_patch16(**kwargs): 180 | model = ClipVisionTransformer( 181 | patch_size=16, 182 | embed_dim=768, 183 | depth=12, 184 | num_heads=12, 185 | mlp_ratio=4, 186 | qkv_bias=True, 187 | norm_layer=partial(nn.LayerNorm, 
eps=1e-6), 188 | # CLIP-specific: 189 | pre_norm=True, 190 | num_classes=512, 191 | **kwargs 192 | ) 193 | return model 194 | 195 | 196 | def vit_large_patch16(**kwargs): 197 | model = VisionTransformer( 198 | patch_size=16, 199 | embed_dim=1024, 200 | depth=24, 201 | num_heads=16, 202 | mlp_ratio=4, 203 | qkv_bias=True, 204 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 205 | **kwargs 206 | ) 207 | return model 208 | 209 | 210 | def vit_huge_patch14(**kwargs): 211 | model = VisionTransformer( 212 | patch_size=14, 213 | embed_dim=1280, 214 | depth=32, 215 | num_heads=16, 216 | mlp_ratio=4, 217 | qkv_bias=True, 218 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 219 | **kwargs 220 | ) 221 | return model 222 | 223 | 224 | def load_vit(model, restore_path): 225 | if restore_path: 226 | print('Restoring model from', restore_path) 227 | state_dict = torch.load(restore_path, map_location='cpu') 228 | state_dict = state_dict['features'] if 'features' in state_dict \ 229 | else state_dict['model'] 230 | 231 | # resize pos_embed if required 232 | if state_dict["pos_embed"].shape != model.pos_embed.shape: 233 | print(f"resizing pos_embed from {state_dict['pos_embed'].shape} to {model.pos_embed.shape}") 234 | state_dict["pos_embed"] = resize_pos_embed( 235 | state_dict["pos_embed"], 236 | model.pos_embed, 237 | getattr(model, "num_tokens", 1), 238 | model.patch_embed.grid_size, 239 | ) 240 | 241 | # filter out keys with name decoder or mask_token 242 | state_dict = { 243 | k: v 244 | for k, v in state_dict.items() 245 | if "decoder" not in k and "mask_token" not in k 246 | } 247 | 248 | # remove norm if using global_pool instead of class token 249 | if model.classifier_feature == "global_pool": 250 | print("Removing extra weights for global_pool") 251 | # remove layer that start with norm 252 | state_dict = {k: v for k, v in state_dict.items() if not k.startswith("norm")} 253 | # add fc_norm in the state dict from the model 254 | state_dict["fc_norm.weight"] = model.fc_norm.weight 255 | state_dict["fc_norm.bias"] = model.fc_norm.bias 256 | 257 | # load state_dict 258 | model.load_state_dict(state_dict) 259 | return model 260 | --------------------------------------------------------------------------------