├── .gitignore ├── LICENSE ├── README.md ├── bridgedata ├── __init__.py ├── data_sets │ ├── __init__.py │ ├── data_augmentation.py │ ├── data_loader.py │ ├── data_utils │ │ ├── __init__.py │ │ └── test_datasets.py │ ├── multi_dataset_loader.py │ └── replay_buffer.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── gcbc.py │ ├── gcbc_images.py │ ├── gcbc_images_context.py │ ├── gcbc_transfer.py │ └── utils │ │ ├── __init__.py │ │ ├── compute_dataset_normalization.py │ │ ├── gradient_reversal_layer.py │ │ ├── layers.py │ │ ├── modelutils.py │ │ ├── orig_resnet │ │ ├── __init__.py │ │ └── resnet.py │ │ ├── recurrent.py │ │ ├── resnet.py │ │ ├── spatial_softmax.py │ │ └── subnetworks.py ├── policies │ ├── __init__.py │ ├── gcbc_policy.py │ └── gcp_policy.py ├── train.py └── utils │ ├── __init__.py │ ├── calc_success_rates.py │ ├── checkpointer.py │ ├── figure_out_scatter.py │ ├── general_utils.py │ ├── tensorboard_logger.py │ └── vis_utils.py ├── bridgedata_experiments ├── __init__.py ├── bc_fromscratch │ └── conf.py ├── dataset_lmdb.py ├── random_mixing_task_id │ ├── conf.py │ └── conf_toykitchen1.py └── task_id_conditioned │ ├── conf.py │ └── conf_exclude_toykitchen1.py ├── docker └── azure │ ├── Dockerfile │ ├── docker_requirements.txt │ ├── doodad_launch.py │ └── files │ ├── 10_nvidia.json │ └── Xdummy ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | bridgedata/config.py 132 | *.idea 133 | *.sif 134 | *.img 135 | *.DS_Store 136 | .vscode/ 137 | wandb 138 | 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Frederik Ebert, Yanlai Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bridge Data Imitation Learning 2 | 3 | This is the accompanying code repository for the paper "Bridge Data: Boosting Generalization of Robotic Skills with Cross-Domain Datasets" ([arXiv paper](https://arxiv.org/abs/2109.13396)). Here is the [project website](https://sites.google.com/view/bridgedata) where you can find more information about how to use and contribute to the dataset. 4 | 5 | ## Installation 6 | 7 | In your `.bashrc` set the environment variables EXP for experiment data and DATA for trainingdata: 8 | 9 | ``` 10 | export EXP= 11 | export DATA= 12 | ``` 13 | 14 | Setup conda environment by running 15 | 16 | ``` 17 | conda create --name bridgedata python=3.6.8 pip 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | then in this directory run 22 | 23 | `pip install -e .` 24 | 25 | Clone the [bridge data robot infrastructure repository](https://github.com/yanlai00/bridge_data_robot_infra), install the dependencies, and run 26 | 27 | `pip install -e widowx_envs` 28 | 29 | ## Examples to run 30 | 31 | ### Training 32 | 33 | #### Single Task Imitation Learning 34 | 35 | `python bridgedata/train.py --path bridgedata_experiments/bc_fromscratch/conf.py` 36 | 37 | The example config file trains the "wipe plate with sponge task". 
You can change the training task and the training parameters in `bridgedata_experiments/bc_fromscratch/conf.py`. 38 | 39 | #### Multi Task Imitation Learning 40 | 41 | `python bridgedata/train.py --path bridgedata_experiments/task_id_conditioned/conf.py` 42 | 43 | The example config file trains a multi-task, task-id conditioned imitation learning policy on all of the tasks in toykitchen1. 44 | 45 | Another example config file `bridgedata_experiments/task_id_conditioned/conf_exclude_toykitchen1.py` trains a multi-task policy on all of the environments except toykitchen1 (to evaluate the transferability of policies). 46 | 47 | #### Multi Task Imitation Learning (with dataset re-balancing) 48 | 49 | `python bridgedata/train.py --path bridgedata_experiments/random_mixing_task_id/conf.py` 50 | 51 | The example config file trains a multi-task, task-id conditioned imitation learning policy on all of the environments except real kitchen 1, plus the wipe plate with sponge task. The dataset is re-balanced such that the wipe plate with sponge task takes up 10% of the training dataset. 52 | 53 | Another example config file `bridgedata_experiments/random_mixing_task_id/conf_toykitchen1.py` rebalances the dataset such that trajectories in toy kitchen 1 take up 30% of the training dataset. 54 | 55 | ## Doodad 56 | 57 | This repository also provides an example script `docker/azure/doodad_launch.py` for launching jobs on cloud compute services like AWS, GCP or Azure with [Doodad](https://github.com/rail-berkeley/doodad). 58 | -------------------------------------------------------------------------------- /bridgedata/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanlai00/bridge_data_imitation_learning/d81a3d9181672f5f26dfbd1844d3017cf7a11367/bridgedata/__init__.py -------------------------------------------------------------------------------- /bridgedata/data_sets/__init__.py: -------------------------------------------------------------------------------- 1 | def get_dataset_class(name): 2 | if name == 'FixLenVideoDataset': 3 | from bridgedata.data_sets.data_loader import FixLenVideoDataset 4 | return FixLenVideoDataset 5 | if name == 'MultiDatasetLoader': 6 | from bridgedata.data_sets.multi_dataset_loader import MultiDatasetLoader 7 | return MultiDatasetLoader 8 | -------------------------------------------------------------------------------- /bridgedata/data_sets/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from bridgedata.utils.general_utils import np_unstack 4 | import torch 5 | 6 | def get_random_crop(images, size, center_crop=False): 7 | # assume images are in Tlen, C, H, W or C, H, W order 8 | im_size = np.array(images.shape[-2:]) 9 | if center_crop: 10 | shift_r = (im_size[0] - size[0]) // 2 11 | shift_c = (im_size[1] - size[1]) // 2 12 | else: 13 | shift_r = np.random.randint(0, im_size[0] - size[0], dtype=np.int) 14 | shift_c = np.random.randint(0, im_size[1] - size[1], dtype=np.int) 15 | 16 | if len(images.shape) == 4: 17 | return images[:, :, shift_r : shift_r + size[0], shift_c : shift_c + size[1]] 18 | elif len(images.shape) == 3: 19 | return images[:, shift_r : shift_r + size[0], shift_c : shift_c + size[1]] 20 | else: 21 | raise ValueError('wrong shape ', images.shape) 22 | 23 | 24 | def get_random_color_aug(images, scale, minus_one_to_one_range=False): 25 | """ 26 | alternative color jitter based on cv2 27 |
:param images: shape: tlen/batch, 3, height, width 28 | :param scale: 29 | :return: 30 | """ 31 | if len(images.shape) == 4: 32 | tlen = images.shape[0] 33 | if isinstance(images, torch.Tensor): 34 | images = images.permute(0, 2, 3, 1) 35 | elif isinstance(images, np.ndarray): 36 | images = images.transpose(0, 2, 3, 1) 37 | else: 38 | raise ValueError('not supported data type!') 39 | images = np.concatenate(np_unstack(images, 0), axis=0) 40 | assert images.dtype == np.float32 41 | if minus_one_to_one_range: 42 | images = (images + 1)/2 # convert to 0 to 1 range 43 | assert np.min(images) >= 0 and np.max(images) <= 1 44 | images = (images*255).astype(np.uint8) 45 | hsv = np.asarray(cv2.cvtColor(images, cv2.COLOR_RGB2HSV)) 46 | hsv_rand = np.random.uniform(np.ones(3) - scale, np.ones(3) + scale) 47 | hsv = np.clip(hsv * hsv_rand[None, None], 0, 255) 48 | hsv = hsv.astype(np.uint8) 49 | rgb = np.asarray(cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)) 50 | images = np.stack(np.split(rgb, tlen, axis=0), axis=0) 51 | images = images.transpose(0, 3, 1, 2) 52 | images = images.astype(np.float32)/255. 53 | return images 54 | elif len(images.shape) == 3: 55 | if isinstance(images, torch.Tensor): 56 | images = images.permute(1, 2, 0) 57 | elif isinstance(images, np.ndarray): 58 | images = images.transpose(1, 2, 0) 59 | else: 60 | raise ValueError('not supported data type!') 61 | assert images.dtype == np.float32 62 | if minus_one_to_one_range: 63 | images = (images + 1)/2 # convert to 0 to 1 range 64 | assert np.min(images) >= 0 and np.max(images) <= 1 65 | images = (images*255).astype(np.uint8) 66 | hsv = np.asarray(cv2.cvtColor(images, cv2.COLOR_RGB2HSV)) 67 | hsv_rand = np.random.uniform(np.ones(3) - scale, np.ones(3) + scale) 68 | hsv = np.clip(hsv * hsv_rand[None, None], 0, 255) 69 | hsv = hsv.astype(np.uint8) 70 | images = np.asarray(cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)) 71 | images = images.transpose(2, 0, 1) 72 | images = images.astype(np.float32)/255. 
73 | return images 74 | else: 75 | raise ValueError('wrong shape ', images.shape) 76 | 77 | -------------------------------------------------------------------------------- /bridgedata/data_sets/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | import glob 4 | import h5py 5 | import random 6 | import imp 7 | from bridgedata.data_sets.data_utils.test_datasets import make_gifs 8 | from torch.utils.data import DataLoader 9 | import os 10 | from bridgedata.utils.general_utils import Configurable 11 | from bridgedata.utils.general_utils import AttrDict, map_dict, resize_video 12 | from bridgedata.data_sets.data_augmentation import get_random_color_aug, get_random_crop 13 | 14 | class BaseVideoDataset(data.Dataset, Configurable): 15 | def __init__(self, data_conf, phase='train', shuffle=True): 16 | """ 17 | 18 | :param data_dir: 19 | :param mpar: 20 | :param data_conf: 21 | :param phase: 22 | :param shuffle: whether to shuffle within batch, set to False for computing metrics 23 | :param dataset_size: 24 | """ 25 | 26 | self._hp = self._default_hparams() 27 | self._override_defaults(data_conf) 28 | 29 | self.phase = phase 30 | self.data_conf = data_conf 31 | self.shuffle = shuffle and phase == 'train' 32 | 33 | import torch.multiprocessing 34 | torch.multiprocessing.set_sharing_strategy('file_system') 35 | 36 | def _default_hparams(self): 37 | default_dict = AttrDict( 38 | n_worker=10, 39 | ) 40 | return AttrDict(default_dict) 41 | 42 | def get_data_loader(self, batch_size): 43 | print('datadir {}, len {} dataset {}'.format(self.data_conf.data_dir, self.phase, len(self))) 44 | print('data loader nworkers', self._hp.n_worker) 45 | return DataLoader(self, batch_size=batch_size, shuffle=self.shuffle, num_workers=self._hp.n_worker, 46 | drop_last=True) 47 | 48 | 49 | class FixLenVideoDataset(BaseVideoDataset): 50 | """ 51 | Variable length video dataset 52 | """ 53 | 54 | def __init__(self, data_conf, phase='train', shuffle=True, transform=None): 55 | """ 56 | :param data_conf: Attrdict with keys 57 | :param phase: 58 | :param shuffle: whether to shuffle within batch, set to False for computing metrics 59 | :param dataset_size: 60 | """ 61 | super().__init__(data_conf, phase, shuffle) 62 | self._hp = self._default_hparams() 63 | self._override_defaults(data_conf) 64 | self.look_for_files(phase) 65 | self.transform = transform 66 | 67 | def look_for_files(self, phase): 68 | if isinstance(self._hp.data_dir, list): 69 | self.filenames = [] 70 | for dir in self._hp.data_dir: 71 | self.filenames += self._get_filenames(dir) 72 | random.seed(1) 73 | random.shuffle(self.filenames) 74 | else: 75 | self.filenames = self._get_filenames(self._hp.data_dir) 76 | self.filenames = self._maybe_post_split(self.filenames) 77 | if self._hp.train_data_fraction < 1 and phase == 'train': 78 | print('###################################') 79 | print("Warning, using {} fraction of data!!!".format(self._hp.train_data_fraction)) 80 | print('###################################') 81 | self.filenames = self.filenames[:int(len(self.filenames) * self._hp.train_data_fraction)] 82 | 83 | if self._hp.max_train_examples and phase == 'train': 84 | print('###################################') 85 | print("Warning, using max train examples {}!!!".format(self._hp.max_train_examples)) 86 | print('###################################') 87 | self.filenames = self.filenames[:self._hp.max_train_examples] 88 | 89 | self.traj_per_file = 
self.get_traj_per_file(self.filenames[0]) 90 | if self._hp.T is None: 91 | self._hp.T = self.get_max_seqlen(self.filenames[0]) 92 | print('init dataloader for phase {} with {} files'.format(phase, len(self.filenames))) 93 | 94 | def _default_hparams(self): 95 | # Data Dimensions 96 | default_dict = AttrDict( 97 | name="", # the name of the dataset, used for writing logs 98 | data_dir=None, 99 | random_crop=False, 100 | image_size_beforecrop=None, 101 | color_augmentation=False, 102 | sel_len=-1, # number of time steps for contiguous sequence that is shifted within sequence of T randomly 103 | sel_camera=0, 104 | concatentate_cameras=False, 105 | T=None, 106 | downsample_img_sz=None, 107 | train_data_fraction=1., 108 | max_train_examples=None 109 | ) 110 | # add new params to parent params 111 | parent_params = super()._default_hparams() 112 | parent_params.update(default_dict) 113 | return parent_params 114 | 115 | def _get_filenames(self, data_dir): 116 | assert 'hdf5' not in data_dir, "hdf5 must not be contained in the data dir!" 117 | filenames = sorted(glob.glob(os.path.join(data_dir, os.path.join('hdf5', self.phase) + '/*'))) 118 | if not filenames: 119 | raise RuntimeError('No filenames found in {}'.format(data_dir)) 120 | random.seed(1) 121 | random.shuffle(filenames) 122 | return filenames 123 | 124 | def get_traj_per_file(self, path): 125 | with h5py.File(path, 'r') as F: 126 | return int(np.array(F['traj_per_file'])) 127 | 128 | def get_max_seqlen(self, path): 129 | # return maximum number of images over all trajectories 130 | with h5py.File(path, 'r') as F: 131 | return int(np.array(F['max_num_images'])) 132 | 133 | def _get_num_from_str(self, s): 134 | return int(''.join(filter(str.isdigit, s))) 135 | 136 | def __getitem__(self, index): 137 | # making sure that different loading threads aren't using the same random seed.
138 | np.random.seed(index) 139 | random.seed(index) 140 | 141 | file_index = index // self.traj_per_file 142 | path = self.filenames[file_index] 143 | 144 | output = self.parse_file(path, index, self.traj_per_file) 145 | if self.transform is not None: 146 | return self.transform(output) 147 | else: 148 | return output 149 | 150 | def parse_file(self, path, index=0, traj_per_file=1): 151 | self.single_filename = str.split(path, '/')[-1] 152 | start_ind_str, _ = path.split('/')[-1][:-3].split('to') 153 | with h5py.File(path, 'r') as F: 154 | ex_index = index % traj_per_file # get the index 155 | key = 'traj{}'.format(ex_index) 156 | data_dict = AttrDict() 157 | if key + '/images' in F: 158 | data_dict.images = np.array(F[key + '/images']) 159 | for name in F[key].keys(): 160 | if name in ['states', 'actions', 'pad_mask']: 161 | data_dict[name] = np.array(F[key + '/' + name]).astype(np.float32) 162 | if name in ['camera_ind', 'num_cameras', 'base_pos_ind', 'num_base_positions']: 163 | # camera index used in simulation when using a random virtual camera 164 | data_dict[name] = np.array(F[key + '/' + name]).astype(np.int) 165 | 166 | 167 | if self._hp.T is not None: 168 | for key in data_dict.keys(): 169 | if key in ['camera_ind', 'num_cameras', 'base_pos_ind', 'num_base_positions', 'domain_ind']: 170 | continue 171 | if key == 'actions': # actions are shorter by one time step 172 | data_dict[key] = data_dict[key][:self._hp.T - 1] 173 | else: 174 | data_dict[key] = data_dict[key][:self._hp.T] 175 | 176 | data_dict = self.process_data_dict(data_dict) 177 | if self._hp.sel_len != -1: 178 | data_dict = self.sample_rand_shifts(data_dict) 179 | 180 | data_dict['tlen'] = data_dict['images'].shape[0] 181 | for k, v in data_dict.items(): 182 | if k in ['camera_ind', 'num_cameras', 'base_pos_ind', 'num_base_positions', 'domain_ind', 'tlen']: 183 | continue 184 | if k == 'actions': 185 | desired_T = self._hp.T - 1 # actions need to be shorter by one since they need to have a start and end-state! 
186 | else: 187 | desired_T = self._hp.T 188 | if v.shape[0] < desired_T: 189 | data_dict[k] = self.pad_tensor(v, desired_T) 190 | 191 | if 'camera_ind' in data_dict and 'base_pos_id' in data_dict: 192 | data_dict['domain_ind'] = data_dict['camera_ind'] * data_dict['num_base_positions'] + data_dict['base_pos_ind'] 193 | elif 'camera_ind' in data_dict: 194 | data_dict['domain_ind'] = data_dict['camera_ind'] 195 | return data_dict 196 | 197 | def process_data_dict(self, data_dict): 198 | if 'images' in data_dict: 199 | images = data_dict['images'] 200 | if self._hp.sel_camera != -1: 201 | assert len(images.shape) == 5 202 | if self._hp.sel_camera == 'random': 203 | cam_ind = np.random.randint(0, images.shape[1]) 204 | data_dict.camera_ind = np.array([cam_ind]) 205 | data_dict.num_cameras = images.shape[1] 206 | images = images[:, cam_ind] 207 | else: 208 | images = images[:, self._hp.sel_camera] 209 | data_dict.camera_ind = np.array([self._hp.sel_camera]) 210 | images = images[:, None] 211 | # Resize video 212 | if len(images.shape) == 5: 213 | imlist = [] 214 | for n in range(images.shape[1]): 215 | imlist.append(self.preprocess_images(images[:, n])) 216 | data_dict.images = np.stack(imlist, axis=1) 217 | else: 218 | data_dict.images = self.preprocess_images(images) 219 | return data_dict 220 | 221 | def sample_rand_shifts(self, data_dict): 222 | """ This function processes data tensors so as to have length equal to max_seq_len 223 | by sampling / padding if necessary """ 224 | offset = np.random.randint(0, self.T - self._hp.sel_len, 1) 225 | 226 | data_dict = map_dict(lambda tensor: self._croplen(tensor, offset, self._hp.sel_len), data_dict) 227 | if 'actions' in data_dict: 228 | data_dict.action_targets = data_dict.action_targets[:-1] 229 | return data_dict 230 | 231 | def preprocess_images(self, images): 232 | assert images.dtype == np.uint8, 'image need to be uint8!' 233 | if self._hp.downsample_img_sz is not None: 234 | images = resize_video(images, (self._hp.downsample_img_sz[0], self._hp.downsample_img_sz[1])) 235 | images = np.transpose(images, [0, 3, 1, 2]) # convert to channel-first 236 | images = images.astype(np.float32) / 255 237 | if self._hp.color_augmentation and self.phase is 'train': 238 | images = get_random_color_aug(images, self._hp.color_augmentation) 239 | if self._hp.random_crop: 240 | assert images.shape[-2:] == tuple(self._hp.image_size_beforecrop) 241 | images = get_random_crop(images, self._hp.random_crop, center_crop=self.phase != 'train') 242 | images = images * 2 - 1 243 | assert images.dtype == np.float32, 'image need to be float32!' 
244 | return images 245 | 246 | def pad_tensor(self, tensor, desired_T): 247 | pad = np.zeros([desired_T - tensor.shape[0]] + list(tensor.shape[1:]), dtype=np.float32) 248 | tensor = np.concatenate([tensor, pad], axis=0) 249 | return tensor 250 | 251 | def _maybe_post_split(self, filenames): 252 | """Splits dataset percentage-wise if respective field defined.""" 253 | try: 254 | return self._split_with_percentage(self.data_conf.train_val_split, filenames) 255 | except (KeyError, AttributeError): 256 | return filenames 257 | 258 | def _split_with_percentage(self, frac, filenames): 259 | assert sum(frac.values()) <= 1.0 # fractions cannot sum up to more than 1 260 | assert self.phase in frac 261 | if self.phase == 'train': 262 | start, end = 0, frac['train'] 263 | elif self.phase == 'val': 264 | start, end = frac['train'], frac['train'] + frac['val'] 265 | else: 266 | start, end = frac['train'] + frac['val'], frac['train'] + frac['val'] + frac['test'] 267 | start, end = int(len(filenames) * start), int(len(filenames) * end) 268 | return filenames[start:end] 269 | 270 | def __len__(self): 271 | return len(self.filenames) * self.traj_per_file 272 | 273 | @staticmethod 274 | def _croplen(val, offset, target_length): 275 | """Pads / crops sequence to desired length.""" 276 | 277 | val = val[int(offset):] 278 | len = val.shape[0] 279 | if len > target_length: 280 | return val[:target_length] 281 | elif len < target_length: 282 | raise ValueError("not enough length") 283 | else: 284 | return val 285 | 286 | @staticmethod 287 | def get_dataset_spec(data_dir): 288 | return imp.load_source('dataset_spec', os.path.join(data_dir, 'dataset_spec.py')).dataset_spec 289 | 290 | 291 | 292 | -------------------------------------------------------------------------------- /bridgedata/data_sets/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanlai00/bridge_data_imitation_learning/d81a3d9181672f5f26dfbd1844d3017cf7a11367/bridgedata/data_sets/data_utils/__init__.py -------------------------------------------------------------------------------- /bridgedata/data_sets/data_utils/test_datasets.py: -------------------------------------------------------------------------------- 1 | from bridgedata.utils.vis_utils import npy_to_gif, npy_to_mp4 2 | from bridgedata.utils.general_utils import np_unstack 3 | import numpy as np 4 | import os 5 | 6 | def make_gifs(loader, outdir=None): 7 | if outdir is None: 8 | outdir = os.environ['HOME'] + '/Desktop' 9 | for i_batch, sample_batched in enumerate(loader): 10 | images = np.asarray(sample_batched['images']) 11 | ncam = images.shape[2] 12 | for cam in range(ncam): 13 | images_cam = images[:, :, cam] 14 | images_cam = (np.transpose((images_cam + 1) / 2, [0, 1, 3, 4, 2]) * 255.).astype(np.uint8) # convert to channel-first 15 | 16 | im_list = [] 17 | for t in range(images_cam.shape[1]): 18 | im_list.append(np.concatenate(np_unstack(images_cam[:, t], axis=0), axis=1)) 19 | # npy_to_gif(im_list, outdir + '/traj{}_cam_{}'.format(i_batch, cam), fps=10) 20 | npy_to_mp4(im_list, outdir + '/traj{}_cam_{}'.format(i_batch, cam), fps=10) 21 | 22 | actions = np.asarray(sample_batched['actions']) 23 | # print('actions', actions) 24 | print('tlen', sample_batched['tlen']) 25 | print('camera_ind', sample_batched['camera_ind']) 26 | # import pdb; pdb.set_trace() 27 | 28 | def measure_time(loader): 29 | import time 30 | tstart = time.time() 31 | n_batch = 100 32 | for i_batch, sample_batched in 
enumerate(loader): 33 | print('ibatch', i_batch) 34 | if i_batch == n_batch: 35 | break 36 | print('average loading time', (time.time() - tstart)/n_batch) 37 | -------------------------------------------------------------------------------- /bridgedata/data_sets/multi_dataset_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from bridgedata.utils.general_utils import AttrDict 3 | import time 4 | import hashlib 5 | import numpy as np 6 | from torch.utils.data import DataLoader 7 | import os 8 | 9 | class MultiDatasetLoader(): 10 | def __init__(self, dataset_dict, phase, shuffle=True): 11 | """ 12 | :param dataset_dict: Attrdict with {single_task: AttrDict(dataset_class=..., data_conf=AttrDict()),} 13 | :param phase: 14 | :param shuffle: 15 | """ 16 | self.data_sets = AttrDict() 17 | self.lengths = AttrDict() 18 | self._hp = AttrDict() 19 | 20 | if shuffle: 21 | self.n_worker = 10 22 | else: 23 | self.n_worker = 1 24 | if 'n_worker' in dataset_dict: 25 | self.n_worker = dataset_dict.pop('n_worker') 26 | for dataset_name, value in dataset_dict.items(): 27 | dataset_class = value.dataclass 28 | data_conf = value.dataconf 29 | self.data_sets[dataset_name] = dataset_class(data_conf, phase, shuffle) 30 | self.lengths[dataset_name] = len(self.data_sets[dataset_name]) 31 | self._hp[dataset_name] = data_conf 32 | 33 | self.phase = phase 34 | self.shuffle = shuffle and phase == 'train' 35 | 36 | def __getitem__(self, index): 37 | dict = AttrDict() 38 | for name, dataset in self.data_sets.items(): 39 | use_index = index % self.lengths[name] 40 | dict[name] = dataset.__getitem__(use_index) 41 | return dict 42 | 43 | def __len__(self): 44 | lengths = [len(d) for d in self.data_sets.values()] 45 | return max(lengths) 46 | 47 | def get_data_loader(self, batch_size): 48 | lengths = [len(d) for d in self.data_sets.values()] 49 | print('phase {} len {} nworker:{} '.format(self.phase, lengths, self.n_worker)) 50 | return DataLoader(self, batch_size=batch_size, shuffle=self.shuffle, num_workers=self.n_worker, 51 | drop_last=True) 52 | 53 | 54 | class RandomMixingDatasetLoader(): 55 | def __init__(self, dataset_dict, phase, shuffle): 56 | """ 57 | :param dataset_dict: Attrdict with {dataset_name: (dataset_class, data_conf)} 58 | :param phase: 59 | :param shuffle: 60 | """ 61 | self.data_sets = {} 62 | self.lengths = {} 63 | self._hp = dataset_dict 64 | self.data_set_sample_probabilities = [] 65 | for key, value in dataset_dict.items(): 66 | if key.startswith('dataset'): 67 | dataset_name, [dataset_class, data_conf, prob] = key, value 68 | self.data_set_sample_probabilities.append(prob) 69 | self.data_sets[dataset_name] = dataset_class(data_conf, phase, shuffle) 70 | self.lengths[dataset_name] = len(self.data_sets[dataset_name]) 71 | 72 | self.sync_train_domain_and_taskdescription_indices() 73 | 74 | self.phase = phase 75 | self.data_conf = data_conf 76 | self.shuffle = shuffle and phase == 'train' 77 | 78 | if shuffle: 79 | self.n_worker = 10 80 | else: 81 | self.n_worker = 1 82 | 83 | # self.n_worker = 0 84 | self._hp.name = 'random_mixing' 85 | 86 | def __getitem__(self, index): 87 | """ 88 | :param index: index referes to the index of the shortest datasets datapoints for the other datasets are selected randomly 89 | :return: 90 | """ 91 | np.random.seed(index) 92 | name = str(np.random.choice(list(self.data_sets.keys()), 1, p=self.data_set_sample_probabilities)[0]) 93 | use_index = index % self.lengths[name] 94 | # print('index 
{} useindex {} length {}'.format(index, use_index, self.lengths[name])) 95 | dict = self.data_sets[name].__getitem__(use_index) 96 | return dict 97 | 98 | def __len__(self): 99 | lengths = [len(d) for d in self.data_sets.values()] 100 | return max(lengths) 101 | 102 | def get_data_loader(self, batch_size): 103 | print('len {} dataset {}'.format(self.phase, len(self))) 104 | return DataLoader(self, batch_size=batch_size, shuffle=self.shuffle, num_workers=self.n_worker, 105 | drop_last=True) 106 | 107 | def sync_train_domain_and_taskdescription_indices(self): 108 | dataset_names = list(self.data_sets.keys()) 109 | all_domains = set([d for dataset_name in dataset_names for d in list(self.data_sets[dataset_name].domain_hash_index.keys())]) 110 | all_taskdescriptions = set([d for dataset_name in dataset_names for d in list(self.data_sets[dataset_name].taskdescription2task_index.keys())]) 111 | 112 | self.domain_hash_index = {domain_hash: index for domain_hash, index in 113 | zip(all_domains, range(len(all_domains)))} 114 | self.taskdescription2task_index = {task_descp: index for task_descp, index in 115 | zip(all_taskdescriptions, range(len(all_taskdescriptions)))} 116 | print('taskdescription2task_index', self.taskdescription2task_index) 117 | 118 | def set_domain_and_taskdescription_indices(self, domain_index, task_index): 119 | """ 120 | This is to make sure that the train and val dataloaders are using the same domain_has_index and taskdescription2task_index 121 | """ 122 | for dataset in list(self.data_sets.values()): 123 | dataset.domain_hash_index = domain_index 124 | dataset.taskdescription2task_index = task_index 125 | 126 | @property 127 | def dataset_stats(self): 128 | return "\n".join([dataset.dataset_stats for dataset in list(self.data_sets.values())]) 129 | 130 | 131 | 132 | def count_hashes(loader): 133 | tstart = time.time() 134 | single_task_hashes = set() 135 | bridge_data_hashes = set() 136 | n_batch_counter = 0 137 | for i_batch, sample_batched in enumerate(loader): 138 | # print('ibatch', counter) 139 | add_hashes(single_task_hashes, sample_batched['single_task']['images']) 140 | add_hashes(bridge_data_hashes, sample_batched['bridge_data']['images']) 141 | n_batch_counter += 1 142 | if n_batch_counter % 500 == 0: 143 | print('batch_counter', n_batch_counter) 144 | 145 | print('batch_counter', n_batch_counter) 146 | print('num hashes single task', len(single_task_hashes)) 147 | print('num hashes bridge task', len(bridge_data_hashes)) 148 | print('average loading time', (time.time() - tstart) / n_batch_counter) 149 | 150 | 151 | def add_hashes(hashes, images): 152 | for b in range(images.shape[0]): 153 | image_string = images[b].numpy().tostring() 154 | hashes.add(hashlib.sha256(image_string).hexdigest()) 155 | 156 | -------------------------------------------------------------------------------- /bridgedata/data_sets/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | from tqdm import tqdm 5 | import copy 6 | from bridgedata.utils.general_utils import AttrDict 7 | # from bridgedata.data_sets.robonet_dataloader import FilteredRoboNetDataset 8 | from bridgedata.utils.general_utils import Configurable 9 | from bridgedata.data_sets.data_augmentation import get_random_color_aug, get_random_crop 10 | from bridgedata.utils.general_utils import select_indices 11 | import os 12 | 13 | 14 | 15 | def make_data_aug(images, hp, phase): 16 | batch_size = images.shape[0] 17 | if 
hp.color_augmentation and phase is 'train': 18 | aug_images = [] 19 | for b in range(batch_size): 20 | aug_images.append(get_random_color_aug(images[b][None], hp.color_augmentation, minus_one_to_one_range=True)) 21 | aug_images = np.concatenate(aug_images, 0) 22 | images = aug_images* 2 - 1 23 | if hp.random_crop: 24 | assert images.shape[-2:] == tuple(hp.image_size_beforecrop) 25 | cropped_images = [] 26 | for b in range(batch_size): 27 | cropped_images.append(get_random_crop(images[b][None], hp.random_crop, center_crop= phase != 'train')) 28 | images = np.concatenate(cropped_images, 0) 29 | if not isinstance(images, torch.Tensor): 30 | images = torch.from_numpy(images).float() 31 | return images 32 | 33 | def apply_data_aug(dict, hp, phase): 34 | for key, value in dict.items(): 35 | if 'image' in key: 36 | dict[key] = make_data_aug(value, hp, phase) 37 | return dict 38 | 39 | class DatasetReplayBuffer(Configurable): 40 | def __init__(self, data_conf, phase='train', shuffle=True): 41 | self._hp = self._default_hparams() 42 | self._override_defaults(data_conf) 43 | self.phase = phase 44 | 45 | print('making replay buffer for', data_conf.data_dir) 46 | load_data_conf = copy.deepcopy(data_conf) 47 | dataset_class = load_data_conf.pop('dataset_type') 48 | # drop keys that are not used in data loader 49 | if 'max_datapoints' in load_data_conf: 50 | load_data_conf.pop('max_datapoints') 51 | 52 | # remove data augmentation since we want to have non-augmented data in the replay buffer 53 | if 'color_augmentation' in load_data_conf: 54 | load_data_conf.pop('color_augmentation') 55 | if 'random_crop' in load_data_conf: 56 | load_data_conf.pop('random_crop') 57 | if 'sel_camera' in load_data_conf: 58 | if load_data_conf['sel_camera'] == 'random': 59 | load_data_conf.pop('sel_camera') 60 | if 'debug' in load_data_conf: 61 | load_data_conf.pop('debug') 62 | if 'num_cams_per_variation' in load_data_conf: 63 | load_data_conf.pop('num_cams_per_variation') 64 | self.loader = dataset_class(load_data_conf, phase).get_data_loader(1) 65 | self.buffer = [] 66 | self.loadDataset() 67 | self.get_data_counter = 0 68 | 69 | def _default_hparams(self): 70 | # Data Dimensions 71 | default_dict = AttrDict( 72 | name="", # the name of the dataset, used for writing logs 73 | dataset_type=None, 74 | max_train_examples=None, 75 | data_dir=None, 76 | n_worker=10, 77 | random_crop=False, 78 | image_size_beforecrop=None, 79 | color_augmentation=False, 80 | sel_len=-1, # number of time steps for contigous sequence that is shifted within sequeence of T randomly 81 | sel_camera=None, 82 | num_cams_per_variation=None, 83 | concatentate_cameras=False, 84 | T=None, 85 | downsample_img_sz=None, 86 | train_data_fraction=1., 87 | robot_list=None, 88 | camera=0, 89 | target_adim=None, 90 | target_sdim=None, 91 | splits="", 92 | debug=False 93 | ) 94 | return AttrDict(default_dict) 95 | 96 | def random_batch(self): 97 | indices = np.random.randint(0, len(self.buffer), self.batch_size) 98 | output_dict = AttrDict() 99 | selected_dicts = np.array(self.buffer)[indices] 100 | for key in selected_dicts[0]: 101 | output_dict[key] = torch.cat([sel[key] for sel in selected_dicts]) 102 | 103 | t0 = np.array([np.random.randint(0, tlen - 1) for tlen in output_dict['tlen']]) 104 | output_dict['final_image'] = select_indices(output_dict.images, output_dict['tlen'] - 1).squeeze() 105 | for tag in ['states', 'actions', 'images']: 106 | output_dict[tag] = select_indices(output_dict[tag], t0).squeeze() 107 | 108 | if 'sel_camera' in self._hp: 109 | if 
self._hp.sel_camera == 'random': 110 | ncam = output_dict.images.shape[1] 111 | cam_ind = np.random.randint(0, ncam, self.batch_size) 112 | if self._hp.num_cams_per_variation is not None: 113 | camera_variation_index = output_dict['camera_variation_index'] 114 | output_dict['domain_ind'] = (camera_variation_index * self._hp.num_cams_per_variation + cam_ind).to(torch.long) 115 | else: 116 | output_dict['domain_ind'] = cam_ind 117 | output_dict.images = select_indices(output_dict.images, cam_ind).squeeze() 118 | output_dict['final_image'] = select_indices(output_dict['final_image'], cam_ind).squeeze() 119 | apply_data_aug(output_dict, self._hp, self.phase) 120 | return output_dict 121 | 122 | def loadDataset(self): 123 | print('loading dataset into replay buffer...') 124 | for sampled_batch in tqdm(self.loader): 125 | self.buffer.append(sampled_batch) 126 | # if self._hp.max_datapoints is not None: 127 | # if len(self.buffer) > self._hp.max_datapoints: 128 | # print('max data points reached!') 129 | # break 130 | 131 | if self._hp.debug: 132 | if len(self.buffer) > 10: 133 | # import pdb; pdb.set_trace() 134 | print('break at 10!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!11') 135 | break 136 | print('done loading.') 137 | 138 | def __next__(self): 139 | self.get_data_counter += 1 140 | if self.get_data_counter < self.__len__(): 141 | batch = self.random_batch() 142 | return batch 143 | else: 144 | raise StopIteration 145 | 146 | def __iter__(self): 147 | self.get_data_counter = 0 148 | return self 149 | 150 | def get_data_loader(self, batch_size): 151 | self.batch_size = batch_size 152 | return self 153 | 154 | def set_batch_size(self, batch_size): 155 | self.batch_size = batch_size 156 | 157 | def __len__(self): 158 | if self.batch_size is None: 159 | raise NotImplementedError('length only implemented for loader!') 160 | return int(len(self.buffer * 50)/self.batch_size) # iterate through the same data 50 times. 
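# Usage sketch (standalone, hypothetical): a minimal example of constructing the
# DatasetReplayBuffer defined above. The dataset path and hyperparameter values
# are illustrative assumptions, not taken from this repository.
from bridgedata.utils.general_utils import AttrDict
from bridgedata.data_sets.data_loader import FixLenVideoDataset
from bridgedata.data_sets.replay_buffer import DatasetReplayBuffer

example_conf = AttrDict(
    dataset_type=FixLenVideoDataset,   # dataset class used to fill the buffer once at startup
    data_dir='/path/to/dataset',       # hypothetical directory containing an hdf5/train split
    image_size_beforecrop=[56, 72],    # illustrative sizes only
    random_crop=[48, 64],
    color_augmentation=0.1,
)
# loadDataset() runs in the constructor and caches all trajectories in memory;
# get_data_loader() stores the batch size and returns the buffer itself, which
# then yields randomly sampled, augmented batches when iterated.
replay_loader = DatasetReplayBuffer(example_conf, phase='train').get_data_loader(batch_size=8)
example_batch = next(iter(replay_loader))  # AttrDict with 'images', 'states', 'actions', 'final_image', ...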
161 | 162 | class MultiDatasetReplayBuffer(): 163 | def __init__(self, dataset_dict, phase, shuffle=True): 164 | self.data_sets = AttrDict() 165 | self._hp = AttrDict() 166 | for dataset_name, value in dataset_dict.items(): 167 | data_conf = value.dataconf 168 | self.data_sets[dataset_name] = DatasetReplayBuffer(data_conf, phase, shuffle) 169 | self._hp[dataset_name] = data_conf 170 | 171 | self.phase = phase 172 | 173 | def __next__(self): 174 | self.get_data_counter += 1 175 | if self.get_data_counter < self.__len__(): 176 | dict = AttrDict() 177 | for name, dataset in self.data_sets.items(): 178 | dict[name] = dataset.random_batch() 179 | return dict 180 | else: 181 | raise StopIteration 182 | 183 | def __iter__(self): 184 | self.get_data_counter = 0 185 | return self 186 | 187 | def __len__(self): 188 | lengths = [len(d) for d in self.data_sets.values()] 189 | return min(lengths) 190 | 191 | def get_data_loader(self, batch_size): 192 | for dataset in self.data_sets.values(): 193 | dataset.set_batch_size(batch_size) 194 | return self 195 | -------------------------------------------------------------------------------- /bridgedata/models/__init__.py: -------------------------------------------------------------------------------- 1 | def get_model_class(name): 2 | if name == 'GCBCImages': 3 | from bridgedata.models.gcbc_images import GCBCImages 4 | return GCBCImages 5 | if name == 'GCBCTransfer': 6 | from bridgedata.models.gcbc_transfer import GCBCTransfer 7 | return GCBCTransfer 8 | if name == 'GCBCImagesContext': 9 | from bridgedata.models.gcbc_images_context import GCBCImagesContext 10 | return GCBCImagesContext 11 | else: 12 | raise ValueError("modelname not found!") 13 | -------------------------------------------------------------------------------- /bridgedata/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from bridgedata.utils.general_utils import AttrDict 5 | import sys 6 | if sys.version_info[0] == 2: 7 | import cPickle as pkl 8 | else: 9 | import pickle as pkl 10 | from bridgedata.utils.general_utils import Configurable 11 | from bridgedata.utils.general_utils import move_to_device 12 | # from bridgedata.models.gcbc_transfer import GCBCTransfer 13 | 14 | class BaseModel(nn.Module, Configurable): 15 | def __init__(self, override_params, logger): 16 | super(BaseModel, self).__init__() 17 | self._hp = self._default_hparams() 18 | self._override_defaults(override_params) 19 | self._logger = logger 20 | 21 | self.normalizing_params = None 22 | if self._hp.dataset_normalization: 23 | if self._hp.store_normalization_inmodel: 24 | self.setup_normalizing_params() 25 | else: 26 | self.load_normalizing_params() 27 | 28 | self.throttle_log_images = 0 29 | 30 | def set_dataset_sufix(self, hp): 31 | self.dataset_sufix = hp.name 32 | 33 | def load_normalizing_params(self): 34 | if 'single_task' in self._hp.data_conf: # skip if using GCBCTransfer model 35 | return 36 | 37 | if 'dataset0' in self._hp.data_conf: # used for RandomMixingDataset 38 | params_dir = self._hp.data_conf.dataset0[1]['data_dir'] 39 | if isinstance(params_dir, list): 40 | params_dir = params_dir[0] 41 | elif self._hp.normalizing_params is not None: 42 | params_dir = self._hp.normalizing_params 43 | else: 44 | # when using a list of data_dirs in the single-dataset loader 45 | if isinstance(self._hp.data_conf['data_dir'], list): 46 | params_dir = self._hp.data_conf['data_dir'][0] 47 | else: 48 | params_dir = 
self._hp.data_conf['data_dir'] 49 | print('getting normalizing params from ', params_dir) 50 | dict = pkl.load(open(params_dir + '/normalizing_params.pkl', "rb")) 51 | self.normalizing_params = move_to_device(dict, self._hp.device) 52 | 53 | def setup_normalizing_params(self): 54 | self.states_mean = nn.Parameter(torch.tensor(torch.zeros(self._hp.state_dim), dtype=torch.float32)) 55 | self.states_std = nn.Parameter(torch.tensor(torch.zeros(self._hp.state_dim), dtype=torch.float32)) 56 | self.actions_mean = nn.Parameter(torch.tensor(torch.zeros(self._hp.action_dim), dtype=torch.float32)) 57 | self.actions_std = nn.Parameter(torch.tensor(torch.zeros(self._hp.action_dim), dtype=torch.float32)) 58 | 59 | def set_normalizing_params(self, dict): 60 | for k, v in dict.items(): 61 | setattr(self, k, nn.Parameter(torch.tensor(v, dtype=torch.float32), requires_grad=False)) 62 | 63 | def _default_hparams(self): 64 | # General Params: 65 | default_dict = AttrDict({ 66 | 'batch_size': -1, 67 | 'max_seq_len': -1, 68 | 'device':torch.device('cuda'), 69 | 'data_conf':None, 70 | 'restore_path':None, 71 | 'dataset_normalization':True, # path to pkl file with normalization parameters 72 | 'store_normalization_inmodel':True, 73 | 'normalizing_params': None, 74 | 'phase': None, 75 | 'stage': 'main' # or finetuning 76 | }) 77 | 78 | # Network params 79 | default_dict.update({ 80 | 'normalization': 'batch', 81 | }) 82 | 83 | # add new params to parent params 84 | return AttrDict(default_dict) 85 | 86 | 87 | def build_network(self): 88 | raise NotImplementedError("Need to implement this function in the subclass!") 89 | 90 | def forward(self, inputs): 91 | raise NotImplementedError("Need to implement this function in the subclass!") 92 | 93 | def loss(self, model_inputs, model_output): 94 | raise NotImplementedError("Need to implement this function in the subclass!") 95 | 96 | def apply_dataset_normalization(self, tensor, name): 97 | """ 98 | :param tensor: 99 | :param name: either 'states' or 'actions' 100 | :return: 101 | """ 102 | return (tensor - self.__getattr__(name + '_mean')) / (self.__getattr__(name + '_std') + 1e-6) 103 | 104 | def unnormalize_dataset(self, tensor, name): 105 | """ 106 | :param tensor: 107 | :param name: either 'states' or 'actions' 108 | :return: 109 | """ 110 | return tensor * self.__getattr__(name + '_std') + self.__getattr__(name + '_mean') 111 | 112 | def log_outputs(self, model_output, inputs, losses, step, phase): 113 | # Log generally useful outputs 114 | self._log_losses(losses, step) 115 | 116 | # if phase == 'train': 117 | # self.log_gradients(step, phase) 118 | 119 | if self.throttle_log_images % 10 == 0: 120 | self.throttle_log_images = 0 121 | for module in self.modules(): 122 | if hasattr(module, '_log_outputs'): 123 | module._log_outputs(model_output, inputs, losses, step, phase) 124 | self.throttle_log_images += 1 125 | 126 | def _log_losses(self, losses, step): 127 | for name, loss in losses.items(): 128 | name += "_" + self.dataset_sufix 129 | if torch.is_tensor(loss): 130 | self._logger.log_scalar(loss, name, step) 131 | else: 132 | self._logger.log_scalar(loss[0], name, step) 133 | 134 | def _restore_params(self, strict=True): 135 | checkpoint = torch.load(self._hp.restore_path, map_location=self._hp.device) 136 | print('restoring parameters from ', self._hp.restore_path) 137 | self.load_state_dict(checkpoint['state_dict'], strict=strict) 138 | 139 | def _load_weights(self, weight_loading_info): 140 | """ 141 | Loads weights of submodels from defined checkpoints + 
scopes. 142 | :param weight_loading_info: list of tuples: [(model_handle, scope, checkpoint_path)] 143 | """ 144 | 145 | def get_filtered_weight_dict(checkpoint_path, scope): 146 | if os.path.isfile(checkpoint_path): 147 | checkpoint = torch.load(checkpoint_path, map_location=self._hp.device) 148 | filtered_state_dict = {} 149 | remove_key_length = len(scope) + 1 # need to remove scope from checkpoint key 150 | for key, item in checkpoint['state_dict'].items(): 151 | if key.startswith(scope): 152 | filtered_state_dict[key[remove_key_length:]] = item 153 | if not filtered_state_dict: 154 | raise ValueError("No variable with scope '{}' found in checkpoint '{}'!".format(scope, checkpoint_path)) 155 | return filtered_state_dict 156 | else: 157 | raise ValueError("Cannot find checkpoint file '{}' for loading '{}'.".format(checkpoint_path, scope)) 158 | 159 | print("") 160 | for loading_op in weight_loading_info: 161 | print(("=> loading '{}' from checkpoint '{}'".format(loading_op[1], loading_op[2]))) 162 | filtered_weight_dict = get_filtered_weight_dict(checkpoint_path=loading_op[2], 163 | scope=loading_op[1]) 164 | loading_op[0].load_state_dict(filtered_weight_dict) 165 | print(("=> loaded '{}' from checkpoint '{}'".format(loading_op[1], loading_op[2]))) 166 | print("") 167 | 168 | def log_gradients(self, step, phase): 169 | grad_norms = list([torch.norm(p.grad.data) for p in self.parameters() if p.grad is not None]) 170 | grad_names = list([name for name, p in self.named_parameters() if p.requires_grad]) 171 | 172 | if len(grad_norms) == 0: 173 | return 174 | grad_norms = torch.stack(grad_norms) 175 | 176 | for name, grad_norm in zip(grad_names, grad_norms): 177 | self._logger.log_scalar(grad_norm.mean(), 'gradients/{}mean_norm'.format(name), step, phase) 178 | self._logger.log_scalar(grad_norm.max(), 'gradients/{}max_norm'.format(name), step, phase) 179 | 180 | self._logger.log_scalar(grad_norms.mean(), 'gradients/mean_norm', step, phase) 181 | self._logger.log_scalar(grad_norms.max(), 'gradients/max_norm', step, phase) 182 | 183 | -------------------------------------------------------------------------------- /bridgedata/models/gcbc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pdb 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from bridgedata.utils.general_utils import AttrDict 7 | from bridgedata.utils.general_utils import select_indices 8 | from bridgedata.models.base_model import BaseModel 9 | from bridgedata.models.utils.layers import BaseProcessingNet 10 | 11 | class GCBCModel(BaseModel): 12 | def __init__(self, overrideparams, logger=None): 13 | super().__init__(overrideparams, logger) 14 | self._hp = self._default_hparams() 15 | self._override_defaults(overrideparams) # override defaults with config file 16 | 17 | assert self._hp.batch_size != -1 18 | assert self._hp.state_dim != -1 19 | assert self._hp.action_dim != -1 20 | 21 | self.build_network() 22 | self.actions = None 23 | 24 | def _default_hparams(self): 25 | default_dict = AttrDict( 26 | state_dim=-1, 27 | action_dim=-1, 28 | goal_cond=True, 29 | goal_state_delta_t=None, 30 | use_conv=False 31 | ) 32 | # add new params to parent params 33 | parent_params = super()._default_hparams() 34 | parent_params.update(default_dict) 35 | return parent_params 36 | 37 | def sample_tsteps(self, states, actions): 38 | tlen = states.shape[1] 39 | 40 | # get positives: 41 | t0 = np.random.randint(0, tlen-1, self._hp.batch_size) 42 
| 43 | sel_states = select_indices(states, t0) 44 | if self._hp.goal_cond: 45 | if self._hp.goal_state_delta_t is not None: 46 | tg = t0 + np.random.randint(1, self._hp.goal_state_delta_t + 1, self._hp.batch_size) 47 | tg = np.clip(tg, 0, tlen - 1) 48 | goal_states = select_indices(states, tg) 49 | else: 50 | goal_states = states[:, -1] 51 | action_pred_input = torch.cat([sel_states, goal_states], dim=1) 52 | else: 53 | action_pred_input = sel_states 54 | actions = select_indices(actions, t0) 55 | return action_pred_input, actions 56 | 57 | def build_network(self): 58 | if self._hp.goal_cond: 59 | inputdim = self._hp.state_dim*2 60 | else: 61 | inputdim = self._hp.state_dim 62 | self.s_encoder = BaseProcessingNet(inputdim, 128, 63 | self._hp.action_dim, num_layers=3, normalization=None) 64 | 65 | def forward(self, inputs): 66 | """ 67 | forward pass at training time 68 | :param 69 | images shape = batch x time x channel x height x width 70 | :return: model_output 71 | """ 72 | action_pred_input, self.actions = self.sample_tsteps(inputs.states, inputs.action_targets) 73 | a_pred = self.s_encoder.forward(action_pred_input) 74 | return AttrDict(a_pred=a_pred) 75 | 76 | def loss(self, model_input, model_output): 77 | losses = AttrDict(mse=torch.nn.MSELoss()(model_output.a_pred, self.actions)) 78 | 79 | # compute total loss 80 | losses.total_loss = torch.stack(list(losses.values())).sum() 81 | return losses 82 | 83 | 84 | 85 | 86 | class GCBCModelTest(GCBCModel): 87 | def __init__(self, overridparams, logger=None): 88 | super(GCBCModelTest, self).__init__(overridparams, logger) 89 | self._restore_params() 90 | 91 | 92 | def forward(self, inputs): 93 | if self._hp.goal_cond: 94 | a_pred_input = torch.cat([inputs.state, inputs.goal_state], dim=1) 95 | else: 96 | a_pred_input = inputs.state 97 | a_pred = self.s_encoder.forward(a_pred_input) 98 | return AttrDict(a_pred=a_pred) 99 | 100 | 101 | -------------------------------------------------------------------------------- /bridgedata/models/gcbc_images_context.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pdb 3 | import torch 4 | import os 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from bridgedata.utils.general_utils import AttrDict 8 | from bridgedata.utils.general_utils import select_indices, trch2npy 9 | from bridgedata.models.base_model import BaseModel 10 | from bridgedata.models.utils.resnet import get_resnet_encoder 11 | 12 | from bridgedata.models.utils.subnetworks import ConvEncoder 13 | from bridgedata.models.utils.layers import BaseProcessingNet 14 | from bridgedata.utils.general_utils import np_unstack 15 | from bridgedata.models.utils.spatial_softmax import SpatialSoftmax 16 | from bridgedata.data_sets.data_augmentation import get_random_crop 17 | from bridgedata.models.gcbc_images import GCBCImages 18 | from bridgedata.models.gcbc_images import get_tlen_from_padmask 19 | import cv2 20 | from bridgedata.models.gcbc_images import GeneralImageEncoder 21 | 22 | 23 | class GCBCImagesContext(GCBCImages): 24 | def __init__(self, overrideparams, logger): 25 | super().__init__(overrideparams, logger) 26 | self._hp = self._default_hparams() 27 | self._override_defaults(overrideparams) # override defaults with config file 28 | 29 | def _default_hparams(self): 30 | default_dict = AttrDict( 31 | encoder_embedding_size=128, 32 | num_context=3, 33 | ) 34 | # add new params to parent params 35 | parent_params = super()._default_hparams() 36 | 
parent_params.update(default_dict) 37 | return parent_params 38 | 39 | def build_network(self): 40 | if self._hp.resnet is not None: 41 | self.encoder = GeneralImageEncoder(self._hp.resnet, out_dim=self._hp.encoder_embedding_size, 42 | use_spatial_softmax=self._hp.encoder_spatial_softmax) 43 | self.embedding_size = self._hp.encoder_embedding_size*2 + self._hp.action_dim*self._hp.num_context 44 | if self._hp.goal_cond: 45 | input_dim = 2*self.embedding_size 46 | else: 47 | input_dim = self.embedding_size 48 | else: 49 | raise NotImplementedError 50 | self.action_predictor = BaseProcessingNet(input_dim, mid_dim=256, out_dim=self._hp.action_dim, num_layers=2) 51 | self.future_action_predictor = BaseProcessingNet(input_dim, mid_dim=256, 52 | out_dim=self._hp.action_dim*self._hp.extra_horizon, num_layers=3) 53 | if self._hp.domain_class_mult: 54 | assert self._hp.num_domains > 1 55 | self.classifier = BaseProcessingNet(input_dim, mid_dim=256, 56 | out_dim=self._hp.num_domains, num_layers=3) 57 | 58 | def get_context(self, actions, batch_size, images, tstart_context): 59 | context_actions = [] 60 | context_images = [] 61 | for b in range(batch_size): 62 | context_actions.append(actions[b, tstart_context[b]:tstart_context[b] + self._hp.num_context]) 63 | context_images.append(images[b, tstart_context[b]:tstart_context[b] + self._hp.num_context]) 64 | context_actions = torch.stack(context_actions, dim=0) 65 | context_images = torch.stack(context_images, dim=0) 66 | return AttrDict(actions=context_actions, images=context_images) 67 | 68 | def get_embedding(self, pred_input, context): 69 | assert np.all(np.array(pred_input.shape[-3:]) == np.array([3, 48, 64])) 70 | embedding = self.encoder(pred_input) 71 | context_emb = [self.encoder(c.squeeze()) for c in torch.split(context.images, 1, 1)] 72 | context_emb = torch.stack(context_emb, dim=0).mean(dim=0) 73 | context_actions = torch.unbind(context.actions, 1) 74 | return torch.cat([embedding, context_emb, *context_actions], dim=1) 75 | 76 | def get_context_image_rows(self): 77 | context_images = torch.unbind(self.context.images, dim=1) 78 | image_rows = [] 79 | for context_image in context_images: 80 | row = trch2npy(torch.cat(torch.unbind((context_image + 1)/2, dim=0), dim=2)).transpose(1, 2, 0) 81 | image_rows.append(row) 82 | return image_rows 83 | 84 | -------------------------------------------------------------------------------- /bridgedata/models/gcbc_transfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from bridgedata.utils.general_utils import AttrDict, trch2npy 5 | from bridgedata.models.gcbc_images import GCBCImages 6 | from bridgedata.models.base_model import BaseModel 7 | from bridgedata.models.utils.layers import BaseProcessingNet 8 | from bridgedata.models.utils.gradient_reversal_layer import ReverseLayerF, compute_alpha 9 | import copy 10 | 11 | class GCBCTransfer(BaseModel): 12 | def __init__(self, overrideparams, logger=None): 13 | super().__init__(overrideparams, logger) 14 | self._hp = self._default_hparams() 15 | self._override_defaults(overrideparams) # override defaults with config file 16 | 17 | self._hp.shared_params.batch_size = self._hp.batch_size 18 | self._hp.shared_params.device = self._hp.device 19 | self._logger = logger 20 | self.build_network() 21 | 22 | def set_dataset_sufix(self, hp): 23 | if hasattr(hp, 'name'): 24 | self.dataset_sufix = hp.name 25 | self.single_task_model.dataset_sufix = hp.name # used 
for validation run 26 | else: 27 | self.dataset_sufix = 'multi_dataset' 28 | self.dataset_hp = hp 29 | 30 | def _default_hparams(self): 31 | default_dict = AttrDict( 32 | shared_params=None, 33 | single_task_params=None, 34 | bridge_data_params=None, 35 | classifier_validation_params=None, 36 | datasource_class_mult=0, 37 | use_grad_reverse=True, 38 | child_model_class=GCBCImages, 39 | alpha_delay=0 40 | ) 41 | # add new params to parent params 42 | parent_params = super()._default_hparams() 43 | parent_params.update(default_dict) 44 | return parent_params 45 | 46 | def build_network(self): 47 | self.build_single_task_model() 48 | self.build_bridge_model() 49 | if self._hp.classifier_validation_params is not None: 50 | self.build_classifier_validation_model() 51 | if self._hp.datasource_class_mult != 0: 52 | self.classifier = BaseProcessingNet(self.single_task_model.embedding_size, mid_dim=256, 53 | out_dim=2, num_layers=3) 54 | 55 | def build_single_task_model(self): 56 | single_task_params = AttrDict() 57 | single_task_params.update(self._hp.shared_params) 58 | single_task_params.update(self._hp.single_task_params) 59 | single_task_params.data_conf = self._hp.data_conf.single_task.dataconf 60 | self.single_task_model = self._hp.child_model_class(single_task_params, self._logger) 61 | self._hp.single_task_params = self.single_task_model._hp 62 | 63 | def build_bridge_model(self): 64 | bridge_data_params = AttrDict() 65 | bridge_data_params.update(self._hp.shared_params) 66 | bridge_data_params.update(self._hp.bridge_data_params) 67 | bridge_data_params.data_conf = self._hp.data_conf.bridge_data.dataconf 68 | self.bridge_data_model = self._hp.child_model_class(bridge_data_params, self._logger) 69 | assert self.bridge_data_model.encoder is not None 70 | delattr(self.bridge_data_model, 'encoder') 71 | self.bridge_data_model.encoder = self.single_task_model.encoder 72 | if self.bridge_data_model._hp.shared_classifier: 73 | assert self.bridge_data_model.classifier is not None 74 | delattr(self.bridge_data_model, 'classifier') 75 | self.bridge_data_model.classifier = self.single_task_model.classifier 76 | self._hp.bridge_data_params = self.bridge_data_model._hp 77 | 78 | def build_classifier_validation_model(self): 79 | bridge_data_params = AttrDict() 80 | bridge_data_params.update(self._hp.shared_params) 81 | bridge_data_params.update(self._hp.classifier_validation_params) 82 | bridge_data_params.data_conf = self._hp.data_conf.classifier_validation.dataconf 83 | self.classifier_validation_model = self._hp.child_model_class(bridge_data_params, self._logger) 84 | assert self.classifier_validation_model.encoder is not None 85 | delattr(self.classifier_validation_model, 'encoder') 86 | self.classifier_validation_model.encoder = self.single_task_model.encoder 87 | self._hp.classifier_validation_params = self.bridge_data_model._hp 88 | 89 | def forward(self, inputs): 90 | out = AttrDict() 91 | if 'single_task' not in inputs: 92 | out.validation_run = self.single_task_model.forward(inputs) 93 | return out 94 | inputs.single_task.global_step = inputs.global_step 95 | inputs.bridge_data.global_step = inputs.global_step 96 | inputs.single_task.max_iterations = inputs.max_iterations 97 | inputs.bridge_data.max_iterations = inputs.max_iterations 98 | out.single_task = self.single_task_model.forward(inputs.single_task) 99 | out.bridge_data = self.bridge_data_model.forward(inputs.bridge_data) 100 | if self._hp.classifier_validation_params is not None: 101 | out.classifier_validation = 
self.classifier_validation_model.forward(inputs.classifier_validation) 102 | if self._hp.datasource_class_mult != 0: 103 | embeddings = torch.cat([out.single_task.embedding, out.bridge_data.embedding], dim=0) 104 | if self._hp.use_grad_reverse: 105 | alpha = compute_alpha(inputs, self._hp.alpha_delay) 106 | out.alpha = alpha 107 | embeddings = ReverseLayerF.apply(embeddings, alpha) 108 | else: 109 | embeddings = embeddings.detach() 110 | out.pred_logit = self.classifier(embeddings) 111 | return out 112 | 113 | def loss(self, model_input, model_output): 114 | if 'validation_run' in model_output: 115 | return self.single_task_model.loss(model_input, model_output.validation_run) 116 | losses = AttrDict() 117 | losses_single_task = self.single_task_model.loss(model_input.single_task, model_output.single_task, compute_total_loss=False) 118 | losses_bridge_data = self.bridge_data_model.loss(model_input.bridge_data, model_output.bridge_data, compute_total_loss=False) 119 | for k, v in losses_single_task.items(): 120 | losses['single_task_' + k] = (v[0], v[1]*self.single_task_model._hp.model_loss_mult) 121 | for k, v in losses_bridge_data.items(): 122 | losses['bridge_data_' + k] = (v[0], v[1]*self.bridge_data_model._hp.model_loss_mult) 123 | if self._hp.classifier_validation_params is not None: 124 | losses_classifier_validation = self.classifier_validation_model.loss(model_input.classifier_validation, model_output.classifier_validation, compute_total_loss=False) 125 | for k, v in losses_classifier_validation.items(): 126 | losses['classifier_validation_' + k] = (v[0], v[1]*self.classifier_validation_model._hp.model_loss_mult) 127 | 128 | if self._hp.datasource_class_mult != 0: 129 | self.class_labels= [] 130 | for i in range(2): 131 | self.class_labels.append(torch.ones(self._hp.batch_size, dtype=torch.long) * i) 132 | self.class_labels = torch.cat(self.class_labels, dim=0).to(torch.device('cuda')) 133 | losses.datasource_classification_loss = [nn.CrossEntropyLoss()(model_output.pred_logit, self.class_labels), 134 | self._hp.datasource_class_mult] 135 | 136 | # compute total loss 137 | losses.total_loss = torch.stack([l[0] * l[1] for l in losses.values()]).sum() 138 | return losses 139 | 140 | def _log_outputs(self, model_output, inputs, losses, step, phase): 141 | if 'validation_run' in model_output: 142 | assert phase == 'val' 143 | self.single_task_model._log_outputs(model_output.validation_run, inputs, losses, step, phase) 144 | return 145 | 146 | self.single_task_model._log_outputs(model_output.single_task, inputs.single_task, losses, step, phase, override_sufix=self.dataset_hp['single_task'].name) 147 | self.bridge_data_model._log_outputs(model_output.bridge_data, inputs.bridge_data, losses, step, phase, override_sufix=self.dataset_hp['bridge_data'].name) 148 | if self._hp.classifier_validation_params is not None: 149 | self.classifier_validation_model._log_outputs(model_output.classifier_validation, inputs.classifier_validation, losses, step, phase, 150 | override_sufix=self.dataset_hp['classifier_validation'].name) 151 | 152 | if self._hp.datasource_class_mult != 0: 153 | predictions = torch.argmax(model_output.pred_logit, dim=1) 154 | error_rate = np.mean(trch2npy(predictions) != trch2npy(self.class_labels)) 155 | self._logger.log_scalar(error_rate, 'datasource_class_error_rate', step) 156 | if self._hp.use_grad_reverse: 157 | self._logger.log_scalar(model_output.alpha, 'data_source_alpha', step) 158 | -------------------------------------------------------------------------------- 
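Note: the GCBCTransfer model above shares its image encoder between the single-task and bridge-data branches and can additionally train a datasource classifier on gradient-reversed embeddings (see gradient_reversal_layer.py below), which pushes the shared encoder toward domain-invariant features. The following is a minimal, self-contained sketch of that gradient-reversal mechanism in isolation; the ReverseLayerF, compute_alpha, and AttrDict imports come from this repository, while the toy encoder, domain classifier, batch data, and optimizer settings are illustrative assumptions rather than the repository's actual training setup.

import torch
import torch.nn as nn
from bridgedata.models.utils.gradient_reversal_layer import ReverseLayerF, compute_alpha
from bridgedata.utils.general_utils import AttrDict

# Toy encoder and domain classifier; the sizes are arbitrary and only for illustration.
encoder = nn.Linear(10, 32)
domain_classifier = nn.Linear(32, 2)
opt = torch.optim.Adam(list(encoder.parameters()) + list(domain_classifier.parameters()), lr=1e-3)

for step in range(100):
    # Pretend the first half of the batch comes from the single-task domain
    # and the second half from the bridge dataset.
    x = torch.randn(16, 10)
    domain_labels = torch.cat([torch.zeros(8, dtype=torch.long), torch.ones(8, dtype=torch.long)])

    emb = encoder(x)
    # alpha ramps from 0 towards 1 over training, following compute_alpha(inputs, alpha_delay).
    alpha = compute_alpha(AttrDict(global_step=step, max_iterations=100), alpha_delay=0)
    # Forward pass is the identity; the backward pass multiplies the gradient by -alpha,
    # so minimizing the classifier loss maximizes domain confusion in the encoder.
    reversed_emb = ReverseLayerF.apply(emb, alpha)
    loss = nn.CrossEntropyLoss()(domain_classifier(reversed_emb), domain_labels)

    opt.zero_grad()
    loss.backward()
    opt.step()

Because the reversed gradient is scaled by alpha, the domain-confusion pressure on the encoder ramps up gradually over training rather than competing with the action-prediction losses from the first step.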
/bridgedata/models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanlai00/bridge_data_imitation_learning/d81a3d9181672f5f26dfbd1844d3017cf7a11367/bridgedata/models/utils/__init__.py -------------------------------------------------------------------------------- /bridgedata/models/utils/compute_dataset_normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def compute_dataset_normalization(dataloader, no_last_dim_norm=True): 4 | states_list = [] 5 | action_list = [] 6 | 7 | print('computing normalization....') 8 | for i_batch, sample_batched in enumerate(dataloader): 9 | if 'actions' not in sample_batched: 10 | raise NotImplementedError('todo!') 11 | states_list.append(sample_batched['states']) 12 | action_list.append(sample_batched['actions']) 13 | if i_batch == 200: 14 | break 15 | 16 | states = np.concatenate(states_list, axis=0) 17 | actions = np.concatenate(action_list, axis=0) 18 | if actions.shape[0] > 1000: 19 | break 20 | 21 | print('state dim: ', states.shape) 22 | print('action dim: ', actions.shape) 23 | if actions.shape[0] < 1000: 24 | print('Warning Very few examples found!!!') 25 | import pdb; pdb.set_trace() 26 | 27 | dict = { 28 | 'states_mean' : np.mean(states, axis=0), 29 | 'states_std' : np.std(states, axis=0), 30 | 'actions_mean': np.mean(actions, axis=0), 31 | 'actions_std': np.std(actions, axis=0), 32 | } 33 | 34 | for dim in range(states.shape[1]): 35 | if dict['states_mean'][dim] == 0 and dict['states_std'][dim] == 0: 36 | dict['states_mean'][dim] = 0 37 | dict['states_std'][dim] = 1 38 | print('##################################') 39 | print('not normalizing state dim {}, since mean and std are zero!!'.format(dim)) 40 | print('##################################') 41 | 42 | for dim in range(actions.shape[1]): 43 | if dict['actions_mean'][dim] == 0 and dict['actions_std'][dim] == 0: 44 | dict['actions_mean'][dim] = 0 45 | dict['actions_std'][dim] = 1 46 | print('##################################') 47 | print('not normalizing action dim {}, since mean and std are zero!!'.format(dim)) 48 | print('##################################') 49 | 50 | if no_last_dim_norm: 51 | print('##################################') 52 | print('not normalizing grasp action!') 53 | print('##################################') 54 | dict['actions_mean'][-1] = 0 55 | dict['actions_std'][-1] = 1 56 | 57 | print('normalization params') 58 | print(dict) 59 | 60 | return dict -------------------------------------------------------------------------------- /bridgedata/models/utils/gradient_reversal_layer.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | 3 | import numpy as np 4 | 5 | def compute_alpha(inputs, alpha_delay): 6 | done_ratio = max(inputs.global_step - alpha_delay, 0) / (inputs.max_iterations - alpha_delay) 7 | alpha = 2. / (1. 
+ np.exp(-10 * done_ratio)) - 1 8 | return alpha 9 | 10 | 11 | class ReverseLayerF(Function): 12 | 13 | @staticmethod 14 | def forward(ctx, x, alpha): 15 | ctx.alpha = alpha 16 | 17 | return x.view_as(x) 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | output = grad_output.neg() * ctx.alpha 22 | 23 | return output, None -------------------------------------------------------------------------------- /bridgedata/models/utils/layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from functools import partial 3 | from bridgedata.utils.general_utils import HasParameters 4 | import math 5 | from bridgedata.utils.general_utils import AttrDict 6 | 7 | 8 | def init_weights_xavier(m): 9 | if isinstance(m, nn.Linear): 10 | nn.init.xavier_normal(m.weight.data) 11 | if m.bias is not None: 12 | m.bias.data.fill_(0) 13 | if isinstance(m, nn.Conv2d): 14 | pass # by default PyTorch uses Kaiming_Normal initializer 15 | 16 | class Block(nn.Sequential): 17 | def __init__(self, **kwargs): 18 | nn.Sequential.__init__(self) 19 | self.params = self.get_default_params() 20 | self.override_defaults(kwargs) 21 | 22 | self.build_block() 23 | self.complete_block() 24 | 25 | def get_default_params(self): 26 | params = AttrDict( 27 | activation=nn.LeakyReLU(0.2, inplace=True), 28 | normalization='batch', 29 | normalization_params=AttrDict() 30 | ) 31 | return params 32 | 33 | def override_defaults(self, override): 34 | for name, value in override.items(): 35 | # print('overriding param {} to value {}'.format(name, value)) 36 | self.params[name] = value 37 | 38 | def build_block(self): 39 | raise NotImplementedError 40 | 41 | def complete_block(self): 42 | if self.params.normalization is not None: 43 | self.params.normalization_params.affine = True 44 | # TODO add a warning if the normalization is over 1 element 45 | if self.params.normalization == 'batch': 46 | normalization = nn.BatchNorm1d if self.params.d == 1 else nn.BatchNorm2d 47 | self.params.normalization_params.track_running_stats = True 48 | 49 | elif self.params.normalization == 'instance': 50 | normalization = nn.InstanceNorm1d if self.params.d == 1 else nn.InstanceNorm2d 51 | self.params.normalization_params.track_running_stats = True 52 | # TODO if affine is false, the biases will not be learned 53 | 54 | elif self.params.normalization == 'group': 55 | normalization = partial(nn.GroupNorm, 8) 56 | if self.params.out_dim < 32: 57 | raise NotImplementedError("note that group norm is likely to not work with this small groups") 58 | 59 | else: 60 | raise ValueError("Normalization type {} unknown".format(self.params.normalization)) 61 | self.add_module('norm', normalization(self.params.out_dim, **self.params.normalization_params)) 62 | 63 | if self.params.activation is not None: 64 | self.add_module('activation', self.params.activation) 65 | 66 | def calc_output_size_and_padding(self, input_size): 67 | """ 68 | :param input_size: list of H, W 69 | :return: 70 | """ 71 | 72 | p = (self.params.kernel_size - self.params.stride) // 2 73 | 74 | s = self.params.stride 75 | k = self.params.kernel_size 76 | i_h = input_size[0] 77 | i_w = input_size[1] 78 | out_h = (i_h + 2 * p - k) / s + 1 79 | out_w = (i_w + 2 * p - k) / s + 1 80 | return [out_h, out_w] 81 | 82 | 83 | class ConvBlock(Block): 84 | def get_default_params(self): 85 | params = super(ConvBlock, self).get_default_params() 86 | params.update(AttrDict( 87 | d=2, 88 | kernel_size=3, 89 | stride=1, 90 | )) 91 | return params 
92 | 93 | def build_block(self): 94 | if self.params.d == 2: 95 | cls = nn.Conv2d 96 | elif self.params.d == 1: 97 | cls = nn.Conv1d 98 | elif self.params.d == -2: 99 | cls = nn.ConvTranspose2d 100 | 101 | padding = (self.params.kernel_size - self.params.stride) // 2 102 | self.add_module('conv', cls( 103 | self.params.in_dim, self.params.out_dim, self.params.kernel_size, self.params.stride, padding)) 104 | 105 | class ConvBlockEnc(ConvBlock): 106 | def get_default_params(self): 107 | params = super(ConvBlockEnc, self).get_default_params() 108 | params.update(AttrDict( 109 | kernel_size=4, 110 | stride=2, 111 | )) 112 | return params 113 | 114 | class ConvBlockDec(ConvBlock): 115 | def get_default_params(self): 116 | params = super(ConvBlockDec, self).get_default_params() 117 | params.update(AttrDict( 118 | d = -2, 119 | kernel_size=4, 120 | stride=2, 121 | )) 122 | return params 123 | 124 | class FCBlock(Block): 125 | def get_default_params(self): 126 | params = super(FCBlock, self).get_default_params() 127 | params.update(AttrDict( 128 | d=1, 129 | )) 130 | return params 131 | 132 | def build_block(self): 133 | self.add_module('linear', nn.Linear(self.params.in_dim, self.params.out_dim)) 134 | 135 | 136 | class Linear(FCBlock): 137 | def get_default_params(self): 138 | params = super(Linear, self).get_default_params() 139 | params.update(AttrDict( 140 | activation=None 141 | )) 142 | return params 143 | 144 | 145 | class BaseProcessingNet(nn.Sequential): 146 | """ Constructs a network that keeps the activation dimensions the same throughout the network 147 | Builds an MLP or CNN, depending on the builder. Alternatively uses custom blocks """ 148 | 149 | def __init__(self, in_dim, mid_dim, out_dim, num_layers, block=FCBlock, 150 | final_activation=None, normalization='batch'): 151 | super(BaseProcessingNet, self).__init__() 152 | 153 | self.add_module('input', block(in_dim=in_dim, out_dim=mid_dim, normalization=None)) 154 | for i in range(num_layers): 155 | self.add_module('pyramid-{}'.format(i), 156 | block(in_dim=mid_dim, out_dim=mid_dim, normalization=normalization)) 157 | 158 | self.add_module('head'.format(i + 1), 159 | block(in_dim=mid_dim, out_dim=out_dim, normalization=None, activation=final_activation)) 160 | self.apply(init_weights_xavier) 161 | 162 | 163 | -------------------------------------------------------------------------------- /bridgedata/models/utils/modelutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def init_weights(m): 4 | classname = m.__class__.__name__ 5 | if classname.find('Conv') != -1 or classname.find('Linear') != -1: 6 | m.weight.data.normal_(0.0, 0.02) 7 | m.bias.data.fill_(0) 8 | elif classname.find('BatchNorm') != -1: 9 | m.weight.data.normal_(1.0, 0.02) 10 | m.bias.data.fill_(0) 11 | 12 | def get_one_hot(nb_digits, active_dim): 13 | """ 14 | param: active_dim: B tensor with indices that need to be set to 1 15 | """ 16 | active_dim = active_dim.type(torch.LongTensor) 17 | batch_size = active_dim.shape[0] 18 | y_onehot = torch.FloatTensor(batch_size, nb_digits) 19 | y_onehot.zero_() 20 | y_onehot.scatter_(1, active_dim[:, None], 1) 21 | return y_onehot 22 | -------------------------------------------------------------------------------- /bridgedata/models/utils/orig_resnet/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yanlai00/bridge_data_imitation_learning/d81a3d9181672f5f26dfbd1844d3017cf7a11367/bridgedata/models/utils/orig_resnet/__init__.py -------------------------------------------------------------------------------- /bridgedata/models/utils/orig_resnet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 7 | 'wide_resnet50_2', 'wide_resnet101_2'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 17 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 18 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 19 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 24 | """3x3 convolution with padding""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=dilation, groups=groups, bias=False, dilation=dilation) 27 | 28 | 29 | def conv1x1(in_planes, out_planes, stride=1): 30 | """1x1 convolution""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | expansion = 1 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 38 | base_width=64, dilation=1, norm_layer=None): 39 | super(BasicBlock, self).__init__() 40 | if norm_layer is None: 41 | norm_layer = nn.BatchNorm2d 42 | if groups != 1 or base_width != 64: 43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = norm_layer(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = norm_layer(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | identity = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class Bottleneck(nn.Module): 75 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 76 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 77 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 
78 | # This variant is also known as ResNet V1.5 and improves accuracy according to 79 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 80 | 81 | expansion = 4 82 | 83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 84 | base_width=64, dilation=1, norm_layer=None): 85 | super(Bottleneck, self).__init__() 86 | if norm_layer is None: 87 | norm_layer = nn.BatchNorm2d 88 | width = int(planes * (base_width / 64.)) * groups 89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 90 | self.conv1 = conv1x1(inplanes, width) 91 | self.bn1 = norm_layer(width) 92 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 93 | self.bn2 = norm_layer(width) 94 | self.conv3 = conv1x1(width, planes * self.expansion) 95 | self.bn3 = norm_layer(planes * self.expansion) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | identity = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv3(out) 112 | out = self.bn3(out) 113 | 114 | if self.downsample is not None: 115 | identity = self.downsample(x) 116 | 117 | out += identity 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | 123 | class ResNet(nn.Module): 124 | 125 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 126 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 127 | norm_layer=None): 128 | super(ResNet, self).__init__() 129 | if norm_layer is None: 130 | norm_layer = nn.BatchNorm2d 131 | self._norm_layer = norm_layer 132 | 133 | self.inplanes = 64 134 | self.dilation = 1 135 | if replace_stride_with_dilation is None: 136 | # each element in the tuple indicates if we should replace 137 | # the 2x2 stride with a dilated convolution instead 138 | replace_stride_with_dilation = [False, False, False] 139 | if len(replace_stride_with_dilation) != 3: 140 | raise ValueError("replace_stride_with_dilation should be None " 141 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 142 | self.groups = groups 143 | self.base_width = width_per_group 144 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 145 | bias=False) 146 | self.bn1 = norm_layer(self.inplanes) 147 | self.relu = nn.ReLU(inplace=True) 148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 149 | self.layer1 = self._make_layer(block, 64, layers[0]) 150 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 151 | dilate=replace_stride_with_dilation[0]) 152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 153 | dilate=replace_stride_with_dilation[1]) 154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 155 | dilate=replace_stride_with_dilation[2]) 156 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 157 | self.fc = nn.Linear(512 * block.expansion, num_classes) 158 | 159 | for m in self.modules(): 160 | if isinstance(m, nn.Conv2d): 161 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 162 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 163 | nn.init.constant_(m.weight, 1) 164 | nn.init.constant_(m.bias, 0) 165 | 166 | # Zero-initialize the last BN in each residual branch, 167 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
168 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 169 | if zero_init_residual: 170 | for m in self.modules(): 171 | if isinstance(m, Bottleneck): 172 | nn.init.constant_(m.bn3.weight, 0) 173 | elif isinstance(m, BasicBlock): 174 | nn.init.constant_(m.bn2.weight, 0) 175 | 176 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 177 | norm_layer = self._norm_layer 178 | downsample = None 179 | previous_dilation = self.dilation 180 | if dilate: 181 | self.dilation *= stride 182 | stride = 1 183 | if stride != 1 or self.inplanes != planes * block.expansion: 184 | downsample = nn.Sequential( 185 | conv1x1(self.inplanes, planes * block.expansion, stride), 186 | norm_layer(planes * block.expansion), 187 | ) 188 | 189 | layers = [] 190 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 191 | self.base_width, previous_dilation, norm_layer)) 192 | self.inplanes = planes * block.expansion 193 | for _ in range(1, blocks): 194 | layers.append(block(self.inplanes, planes, groups=self.groups, 195 | base_width=self.base_width, dilation=self.dilation, 196 | norm_layer=norm_layer)) 197 | 198 | return nn.Sequential(*layers) 199 | 200 | def _forward_impl(self, x): 201 | # See note [TorchScript super()] 202 | x = self.conv1(x) 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | x = self.maxpool(x) 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | 211 | x = self.avgpool(x) 212 | x = torch.flatten(x, 1) 213 | x = self.fc(x) 214 | 215 | return x 216 | 217 | def forward(self, x): 218 | return self._forward_impl(x) 219 | 220 | 221 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 222 | model = ResNet(block, layers, **kwargs) 223 | if pretrained: 224 | state_dict = load_state_dict_from_url(model_urls[arch], 225 | progress=progress) 226 | model.load_state_dict(state_dict) 227 | return model 228 | 229 | 230 | def resnet18(pretrained=False, progress=True, **kwargs): 231 | r"""ResNet-18 model from 232 | `"Deep Residual Learning for Image Recognition" `_ 233 | 234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | progress (bool): If True, displays a progress bar of the download to stderr 237 | """ 238 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 239 | **kwargs) 240 | 241 | 242 | def resnet34(pretrained=False, progress=True, **kwargs): 243 | r"""ResNet-34 model from 244 | `"Deep Residual Learning for Image Recognition" `_ 245 | 246 | Args: 247 | pretrained (bool): If True, returns a model pre-trained on ImageNet 248 | progress (bool): If True, displays a progress bar of the download to stderr 249 | """ 250 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 251 | **kwargs) 252 | 253 | 254 | def resnet50(pretrained=False, progress=True, **kwargs): 255 | r"""ResNet-50 model from 256 | `"Deep Residual Learning for Image Recognition" `_ 257 | 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 263 | **kwargs) 264 | 265 | 266 | def resnet101(pretrained=False, progress=True, **kwargs): 267 | r"""ResNet-101 model from 268 | `"Deep Residual Learning for Image Recognition" `_ 269 | 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | 
progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 275 | **kwargs) 276 | 277 | 278 | def resnet152(pretrained=False, progress=True, **kwargs): 279 | r"""ResNet-152 model from 280 | `"Deep Residual Learning for Image Recognition" `_ 281 | 282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | progress (bool): If True, displays a progress bar of the download to stderr 285 | """ 286 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 287 | **kwargs) 288 | 289 | 290 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 291 | r"""ResNeXt-50 32x4d model from 292 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 293 | 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | progress (bool): If True, displays a progress bar of the download to stderr 297 | """ 298 | kwargs['groups'] = 32 299 | kwargs['width_per_group'] = 4 300 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 301 | pretrained, progress, **kwargs) 302 | 303 | 304 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 305 | r"""ResNeXt-101 32x8d model from 306 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 307 | 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | progress (bool): If True, displays a progress bar of the download to stderr 311 | """ 312 | kwargs['groups'] = 32 313 | kwargs['width_per_group'] = 8 314 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 315 | pretrained, progress, **kwargs) 316 | 317 | 318 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 319 | r"""Wide ResNet-50-2 model from 320 | `"Wide Residual Networks" `_ 321 | 322 | The model is the same as ResNet except for the bottleneck number of channels 323 | which is twice larger in every block. The number of channels in outer 1x1 324 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 325 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 326 | 327 | Args: 328 | pretrained (bool): If True, returns a model pre-trained on ImageNet 329 | progress (bool): If True, displays a progress bar of the download to stderr 330 | """ 331 | kwargs['width_per_group'] = 64 * 2 332 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 333 | pretrained, progress, **kwargs) 334 | 335 | 336 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 337 | r"""Wide ResNet-101-2 model from 338 | `"Wide Residual Networks" `_ 339 | 340 | The model is the same as ResNet except for the bottleneck number of channels 341 | which is twice larger in every block. The number of channels in outer 1x1 342 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 343 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
344 | 345 | Args: 346 | pretrained (bool): If True, returns a model pre-trained on ImageNet 347 | progress (bool): If True, displays a progress bar of the download to stderr 348 | """ 349 | kwargs['width_per_group'] = 64 * 2 350 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 351 | pretrained, progress, **kwargs) 352 | -------------------------------------------------------------------------------- /bridgedata/models/utils/recurrent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from bridgedata.models.utils.modelutils import init_weights 5 | 6 | class LSTM(nn.Module): 7 | def __init__(self, input_size, output_size, hidden_size, n_layers, batch_size): 8 | super(LSTM, self).__init__() 9 | self.input_size = input_size 10 | self.output_size = output_size 11 | self.hidden_size = hidden_size 12 | self.batch_size = batch_size 13 | self.n_layers = n_layers 14 | self.embed = nn.Linear(input_size, hidden_size) 15 | self.lstm = nn.ModuleList([nn.LSTMCell(hidden_size, hidden_size) for i in range(self.n_layers)]) 16 | self.output = nn.Sequential( 17 | nn.Linear(hidden_size, output_size), 18 | #nn.BatchNorm1d(output_size), 19 | nn.Tanh()) 20 | self.apply(init_weights) 21 | 22 | def init_hidden(self): 23 | hidden = [] 24 | for i in range(self.n_layers): 25 | hidden.append((Variable(torch.zeros(self.batch_size, self.hidden_size).cuda()), 26 | Variable(torch.zeros(self.batch_size, self.hidden_size).cuda()))) 27 | self.hidden = hidden 28 | 29 | def forward(self, input): 30 | embedded = self.embed(input.view(-1, self.input_size)) 31 | h_in = embedded 32 | for i in range(self.n_layers): 33 | self.hidden[i] = self.lstm[i](h_in, self.hidden[i]) 34 | h_in = self.hidden[i][0] 35 | 36 | return self.output(h_in) -------------------------------------------------------------------------------- /bridgedata/models/utils/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import sys 5 | from bridgedata.models.utils.orig_resnet.resnet import ResNet, BasicBlock, Bottleneck, model_urls 6 | 7 | def repeat_weights(weights, new_channels): 8 | prev_channels = weights.shape[1] 9 | assert prev_channels == 3, "Original weights should have three input channels" 10 | new_shape = list(weights.shape[:]) 11 | new_shape[1] = new_channels 12 | new_weights = torch.zeros(new_shape, dtype=weights.dtype, layout=weights.layout, device=weights.device) 13 | for i in range(new_channels): 14 | new_weights.data[:, i] = weights[:, i % prev_channels].clone() 15 | return new_weights 16 | 17 | if sys.version_info[0] == 3: 18 | from torch.hub import load_state_dict_from_url 19 | 20 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 21 | model = ResNetCustomStride(block, layers, **kwargs) 22 | if pretrained: 23 | state_dict = load_state_dict_from_url(model_urls[arch], 24 | progress=progress) 25 | state_dict.pop('fc.weight') 26 | state_dict.pop('fc.bias') 27 | model.load_state_dict(state_dict) 28 | return model 29 | 30 | def resnet18(pretrained=False, progress=True, **kwargs): 31 | r"""ResNet-18 model from 32 | `"Deep Residual Learning for Image Recognition" `_ 33 | 34 | Args: 35 | pretrained (bool): If True, returns a model pre-trained on ImageNet 36 | progress (bool): If True, displays a progress bar of the download to stderr 37 | """ 38 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, 
progress, 39 | **kwargs) 40 | 41 | 42 | def resnet34(pretrained=False, progress=True, **kwargs): 43 | r"""ResNet-34 model from 44 | `"Deep Residual Learning for Image Recognition" `_ 45 | 46 | Args: 47 | pretrained (bool): If True, returns a model pre-trained on ImageNet 48 | progress (bool): If True, displays a progress bar of the download to stderr 49 | """ 50 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 51 | **kwargs) 52 | 53 | def resnet34shallow(pretrained=False, progress=True, **kwargs): 54 | r"""ResNet-34 model from 55 | `"Deep Residual Learning for Image Recognition" `_ 56 | 57 | Args: 58 | pretrained (bool): If True, returns a model pre-trained on ImageNet 59 | progress (bool): If True, displays a progress bar of the download to stderr 60 | """ 61 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3, 1], pretrained, progress, strides=(2, 2, 1, 1, 1, 1), planes=(64, 128, 256, 512, 8), 62 | **kwargs) 63 | 64 | 65 | def resnet50(pretrained=False, progress=True, **kwargs): 66 | r"""ResNet-50 model from 67 | `"Deep Residual Learning for Image Recognition" `_ 68 | 69 | Args: 70 | pretrained (bool): If True, returns a model pre-trained on ImageNet 71 | progress (bool): If True, displays a progress bar of the download to stderr 72 | """ 73 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 74 | **kwargs) 75 | 76 | class ResNetCustomStride(ResNet): 77 | # planes = (64, 128, 128, 256) 78 | def __init__(self, block, layers, strides=(2, 2, 1, 1, 1), planes=(64, 128, 256, 512), num_classes=1000, zero_init_residual=False, 79 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 80 | norm_layer=None, create_final_fc_layer=False): 81 | # super(ResNet, self).__init__() 82 | self.block = block 83 | nn.Module.__init__(self) 84 | if norm_layer is None: 85 | norm_layer = nn.BatchNorm2d 86 | self._norm_layer = norm_layer 87 | 88 | self.create_final_fc_layer = create_final_fc_layer 89 | 90 | self.planes = planes 91 | self.inplanes = 64 92 | self.dilation = 1 93 | if replace_stride_with_dilation is None: 94 | # each element in the tuple indicates if we should replace 95 | # the 2x2 stride with a dilated convolution instead 96 | replace_stride_with_dilation = [False, False, False] 97 | if len(replace_stride_with_dilation) != 3: 98 | raise ValueError("replace_stride_with_dilation should be None " 99 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 100 | self.groups = groups 101 | self.base_width = width_per_group 102 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=strides[0], padding=3, 103 | bias=False) 104 | self.bn1 = norm_layer(self.inplanes) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=strides[1], padding=1) 107 | self.layer1 = self._make_layer(block, planes[0], layers[0]) 108 | self.layer2 = self._make_layer(block, planes[1], layers[1], stride=strides[2], 109 | dilate=replace_stride_with_dilation[0]) 110 | self.layer3 = self._make_layer(block, planes[2], layers[2], stride=strides[3], 111 | dilate=replace_stride_with_dilation[1]) 112 | self.layer4 = self._make_layer(block, planes[3], layers[3], stride=strides[4], 113 | dilate=replace_stride_with_dilation[2]) 114 | if len(planes) == 5: 115 | self.layer5 = self._make_layer(block, planes[4], layers[4], stride=strides[5], 116 | dilate=replace_stride_with_dilation[2]) 117 | else: 118 | self.layer5 = None 119 | 120 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 121 | 122 | if 
self.create_final_fc_layer: 123 | self.fc = nn.Linear(512 * block.expansion, num_classes) 124 | 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 128 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 129 | nn.init.constant_(m.weight, 1) 130 | nn.init.constant_(m.bias, 0) 131 | 132 | # Zero-initialize the last BN in each residual branch, 133 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 134 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 135 | if zero_init_residual: 136 | for m in self.modules(): 137 | if isinstance(m, Bottleneck): 138 | nn.init.constant_(m.bn3.weight, 0) 139 | elif isinstance(m, BasicBlock): 140 | nn.init.constant_(m.bn2.weight, 0) 141 | 142 | def get_num_output_featuremaps(self): 143 | if self.block.__name__ == BasicBlock.__name__: 144 | return [12, 16, self.planes[-1]] 145 | if self.block.__name__ == Bottleneck.__name__: 146 | return [12, 16, self.planes[-1]*4] 147 | else: 148 | raise NotImplementedError 149 | 150 | def _forward_impl(self, x): 151 | # See note [TorchScript super()] 152 | x = self.conv1(x) 153 | x = self.bn1(x) 154 | x = self.relu(x) 155 | x = self.maxpool(x) 156 | 157 | x = self.layer1(x) 158 | x = self.layer2(x) 159 | x = self.layer3(x) 160 | x = self.layer4(x) 161 | 162 | if self.layer5 is not None: 163 | return self.layer5(x) 164 | if not self.create_final_fc_layer: 165 | return x 166 | else: 167 | x = self.avgpool(x) 168 | x = torch.flatten(x, 1) 169 | x = self.fc(x) 170 | return x 171 | 172 | 173 | def get_resnet_encoder(resnet_type, channels_in=3, pretrained=True, **kwargs): 174 | if resnet_type == 'resnet50': 175 | Model = resnet50 176 | elif resnet_type == 'resnet34': 177 | Model = resnet34 178 | elif resnet_type == 'resnet34shallow': 179 | Model = resnet34shallow 180 | elif resnet_type == 'resnet18': 181 | Model = resnet18 182 | else: 183 | raise NotImplementedError 184 | model = Model(pretrained=pretrained, progress=True, **kwargs) 185 | for param in model.parameters(): 186 | param.requires_grad = True 187 | 188 | if channels_in != 3: 189 | orig_weights = model.conv1.weight.clone().detach().data 190 | new_weights = repeat_weights(orig_weights, channels_in) 191 | new_layer = nn.Conv2d(channels_in, orig_weights.shape[0], kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 192 | new_layer.weight = nn.Parameter(new_weights) 193 | model.conv1 = new_layer 194 | 195 | return model -------------------------------------------------------------------------------- /bridgedata/models/utils/spatial_softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | import numpy as np 6 | 7 | 8 | class SpatialSoftmax(torch.nn.Module): 9 | def __init__(self, height, width, channel, temperature=None, data_format='NCHW'): 10 | super(SpatialSoftmax, self).__init__() 11 | self.data_format = data_format 12 | self.height = height 13 | self.width = width 14 | self.channel = channel 15 | 16 | if temperature: 17 | self.temperature = Parameter(torch.ones(1) * temperature) 18 | else: 19 | self.temperature = 1. 
20 | 21 | pos_x, pos_y = np.meshgrid( 22 | np.linspace(-1., 1., self.height), 23 | np.linspace(-1., 1., self.width) 24 | ) 25 | self.pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float() 26 | self.pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float() 27 | self.register_buffer('_pos_x', self.pos_x) 28 | self.register_buffer('_pos_y', self.pos_y) 29 | 30 | def forward(self, feature): 31 | # Output: 32 | # (N, C*2) x_0 y_0 ... 33 | 34 | if self.data_format == 'NHWC': 35 | feature = feature.transpose(1, 3).tranpose(2, 3).view(-1, self.height * self.width) 36 | else: 37 | feature = feature.reshape(-1, self.height * self.width) 38 | # feature = feature.view(-1, self.height * self.width) 39 | softmax_attention = F.softmax(feature / self.temperature, dim=-1) 40 | expected_x = torch.sum(Variable(self._pos_x) * softmax_attention, dim=1, keepdim=True) 41 | expected_y = torch.sum(Variable(self._pos_y) * softmax_attention, dim=1, keepdim=True) 42 | # expected_x = torch.sum(self.pos_x * softmax_attention, dim=1, keepdim=True) 43 | # expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True) 44 | expected_xy = torch.cat([expected_x, expected_y], 1) 45 | feature_keypoints = expected_xy.view(-1, self.channel * 2) 46 | 47 | return feature_keypoints -------------------------------------------------------------------------------- /bridgedata/models/utils/subnetworks.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch import Tensor 8 | 9 | from bridgedata.models.utils.layers import ConvBlockEnc, ConvBlockDec, Linear 10 | 11 | def init_weights_xavier(m): 12 | if isinstance(m, nn.Linear): 13 | nn.init.xavier_normal(m.weight.data) 14 | if m.bias is not None: 15 | m.bias.data.fill_(0) 16 | if isinstance(m, nn.Conv2d): 17 | pass # by default PyTorch uses Kaiming_Normal initializer 18 | 19 | class SequentialWithConditional(nn.Sequential): 20 | def __init__(self): 21 | super().__init__() 22 | 23 | def forward(self, inp_dict): 24 | """Computes forward pass through the network outputting all intermediate activations with final output.""" 25 | action = inp_dict['act'] 26 | input = inp_dict['input'] 27 | for i, module in enumerate(self._modules.values()): 28 | if isinstance(module, FiLM): 29 | input = module(input, action) 30 | else: 31 | input = module(input) 32 | return input 33 | 34 | class GetIntermediatesSequential(nn.Sequential): 35 | def __init__(self, stride): 36 | super().__init__() 37 | self.stride = stride 38 | 39 | def forward(self, input): 40 | """Computes forward pass through the network outputting all intermediate activations with final output.""" 41 | skips = [] 42 | for i, module in enumerate(self._modules.values()): 43 | input = module(input) 44 | 45 | if i % self.stride == 0: 46 | skips.append(input) 47 | else: 48 | skips.append(None) 49 | return input, skips[:-1] 50 | 51 | 52 | class FiLM(nn.Module): 53 | def __init__(self, hp, inp_dim, feature_size): 54 | super().__init__() 55 | self._hp = hp 56 | 57 | self.inp_dim = inp_dim 58 | self.feature_size = feature_size 59 | self.linear = Linear(in_dim=inp_dim, out_dim=2*feature_size, builder=self._hp.builder) 60 | 61 | def forward(self, feats, inp): 62 | gb = self.linear(inp) 63 | gamma, beta = gb[:, :self.feature_size], gb[:, self.feature_size:] 64 | gamma = gamma.view(feats.size(0), feats.size(1), 1, 1) 65 | beta = beta.view(feats.size(0), 
feats.size(1), 1, 1) 66 | return feats * gamma + beta 67 | 68 | def get_num_conv_layers(img_sz): 69 | n = math.log2(img_sz[1]) 70 | assert n >= 3, 'imageSize must be at least 8' 71 | return int(n) 72 | 73 | def calc_output_size_and_padding(input_size, k, s, p): 74 | """ 75 | :param input_size: list of H, W 76 | :return: 77 | """ 78 | 79 | i_h = input_size[0] 80 | i_w = input_size[1] 81 | out_h = (i_h + 2 * p - k) / s + 1 82 | out_w = (i_w + 2 * p - k) / s + 1 83 | return [out_h, out_w] 84 | 85 | 86 | class ConvEncoder(nn.Module): 87 | def __init__(self, hp): 88 | super().__init__() 89 | self._hp = hp 90 | 91 | self.n = get_num_conv_layers(hp.img_sz) 92 | if self._hp.use_skips: 93 | self.net = GetIntermediatesSequential(hp.skips_stride) 94 | else: 95 | self.net = nn.Sequential() 96 | 97 | self.size_list = [] # C, H, W 98 | 99 | input_c = hp.input_nc 100 | 101 | print('l-1: indim {} outdim {}'.format(input_c, hp.ngf)) 102 | self.size_list.append([hp.img_sz[0], hp.img_sz[1]]) 103 | 104 | blk = ConvBlockEnc(in_dim=input_c, out_dim=hp.ngf, normalization=None, input_size=self.size_list[-1]) 105 | self.size_list.append(blk.calc_output_size_and_padding(self.size_list[-1])) 106 | self.net.add_module('input', blk) 107 | 108 | for i in range(self.n - 3): 109 | filters_in = hp.ngf * 2 ** i 110 | 111 | blk = ConvBlockEnc(in_dim=filters_in, out_dim=filters_in * 2) 112 | self.size_list.append(blk.calc_output_size_and_padding(self.size_list[-1])) 113 | self.net.add_module('pyramid-{}'.format(i), blk) 114 | print('l{}: indim {} outdim {}'.format(i, filters_in, filters_in*2)) 115 | 116 | # add output layer 117 | self.size_list.append(calc_output_size_and_padding(self.size_list[-1], 3, 1, 1)) 118 | self.net.add_module('head', nn.Conv2d(hp.ngf * 2 ** (self.n - 3), hp.nz_enc, 3, padding=1, stride=1)) 119 | print('l out: indim {} outdim {}'.format(hp.ngf * 2 ** (self.n - 3), hp.nz_enc)) 120 | 121 | self.net.apply(init_weights_xavier) 122 | 123 | def get_output_size(self): 124 | return list(map(int, self.size_list[-1])) 125 | 126 | def forward(self, input): 127 | return self.net(input) 128 | 129 | class ConvDecoder(nn.Module): 130 | def __init__(self, hp): 131 | super().__init__() 132 | self._hp = hp 133 | 134 | self.n = get_num_conv_layers(self.img_sz) 135 | self.net = GetIntermediatesSequential(hp.skips_stride) if hp.use_skips else nn.Sequential() 136 | 137 | # print('l-1: indim {} outdim {}'.format(64, hp./)) 138 | self.net.add_module('head', nn.ConvTranspose2d(64, 32, 4)) 139 | 140 | 141 | for i in range(self.n - 3): 142 | filters_in = 32 // 2 ** i 143 | self.net.add_module('pyramid-{}'.format(i), 144 | ConvBlockDec(in_dim=filters_in, out_dim=filters_in // 2, normalize=hp.apply_dataset_normalization)) 145 | print('l{}: indim {} outdim {}'.format(i, filters_in, filters_in // 2)) 146 | 147 | self.net.add_module('input', ConvBlockDec(in_dim=8, out_dim=hp.input_nc, normalization=None)) 148 | 149 | # add output layer 150 | 151 | # print('l out: indim {} outdim {}'.format(hp.ngf * 2 ** (self.n - 3), hp.nz_enc)) 152 | 153 | self.net.apply(init_weights_xavier) 154 | 155 | def get_output_size(self): 156 | # return (self._hp.nz_enc, self._hp.img_sz[0]//(2**self.n), self._hp.img_sz[1]//(2**self.n)) 157 | return (3, 64, 64) # todo calc this, fix the padding in the convs! 
158 | 159 | def forward(self, input): 160 | return self.net(input) -------------------------------------------------------------------------------- /bridgedata/policies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanlai00/bridge_data_imitation_learning/d81a3d9181672f5f26dfbd1844d3017cf7a11367/bridgedata/policies/__init__.py -------------------------------------------------------------------------------- /bridgedata/policies/gcbc_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from widowx_envs.policies.policy import Policy 5 | import cv2 6 | import glob 7 | from bridgedata.utils.general_utils import AttrDict 8 | from bridgedata.utils.general_utils import np_unstack 9 | import json 10 | 11 | from widowx_envs.utils.datautils.raw2lmdb import crop_image 12 | 13 | class BCPolicyStates(Policy): 14 | """ 15 | Behavioral Cloning Policy 16 | """ 17 | def __init__(self, ag_params, policyparams): 18 | super(BCPolicyStates, self).__init__() 19 | self._hp = self._default_hparams() 20 | self._override_defaults(policyparams) 21 | 22 | model, model_config = self.get_saved_params(policyparams) 23 | model_config['batch_size'] = 1 24 | model_config['restore_path'] = self._hp.restore_path 25 | 26 | self.predictor = model(model_config) 27 | self.predictor.eval() 28 | self.device = torch.device('cuda') 29 | self.predictor.to(self.device) 30 | print('finished setting up policy') 31 | 32 | 33 | def get_saved_params(self, policyparams): 34 | if str.split(policyparams['restore_path'], '/')[-3] == 'finetuning': 35 | search_pattern = '/finetuning_conf*' 36 | stage = 'finetuning' 37 | else: 38 | search_pattern = '/main_conf*' 39 | stage = 'main' 40 | search_pattern = '/'.join(str.split(policyparams['restore_path'], '/')[:-2]) + search_pattern 41 | conffile = glob.glob(search_pattern) 42 | if len(conffile) == 0: 43 | raise ValueError('no conf files found in ', search_pattern) 44 | conffile = conffile[0] 45 | with open(conffile, 'r') as f: 46 | conf = json.load(f) 47 | if conf['train._hp'][stage]['model'] == 'GCBCImages': 48 | from bridgedata.models.gcbc_images import GCBCImagesModelTest 49 | model = GCBCImagesModelTest 50 | new_conf = conf['model_conf'] 51 | elif conf['train._hp'][stage]['model'] == 'GCBCTransfer': 52 | from bridgedata.models.gcbc_images import GCBCImagesModelTest 53 | model = GCBCImagesModelTest 54 | new_conf = conf['model_conf']['shared_params'] 55 | new_conf.update(conf['model_conf'][self._hp.get_sub_model]) 56 | new_conf['get_sub_model'] = self._hp.get_sub_model 57 | else: 58 | raise ValueError('model not found!') 59 | new_conf['identical_default_ok'] = '' 60 | new_conf.update(self._hp.model_override_params) 61 | return model, new_conf 62 | 63 | def reset(self): 64 | super(BCPolicyStates, self).reset() 65 | 66 | def _default_hparams(self): 67 | default_dict = AttrDict({ 68 | 'restore_path': None, 69 | 'verbose': False, 70 | 'type': None, 71 | 'model_override_params': None, 72 | 'get_sub_model': 'single_task_params', 73 | }) 74 | default_dict.update(super(BCPolicyStates, self)._default_hparams()) 75 | return default_dict 76 | 77 | def act(self, t=None, i_tr=None, state=None, loaded_traj_info=None): 78 | self.t = t 79 | self.i_tr = i_tr 80 | goal_states = loaded_traj_info['state'][-1] 81 | 82 | inputs = AttrDict(state=self.npy2trch(state[-1][None]), 83 | goal_state=self.npy2trch(goal_states[None])) 84 | out = 
self.predictor(inputs) 85 | 86 | output = AttrDict() 87 | output.actions = out['a_pred'].data.cpu().numpy()[0] 88 | return output 89 | 90 | @property 91 | def default_action(self): 92 | return np.zeros(self.predictor._hp.n_actions) 93 | 94 | def log_outputs_stateful(self, logger=None, global_step=None, phase=None, dump_dir=None, exec_seq=None, goal=None, index=None, env=None, goal_pos=None, traj=None, topdown_image=None): 95 | logger.log_video(np.transpose(exec_seq, [0, 3, 1, 2]), 'control/traj{}_'.format(index), global_step, phase) 96 | goal_img = np.transpose(goal, [2, 0, 1])[None] 97 | goal_img = torch.tensor(goal_img) 98 | logger.log_images(goal_img, 'control/traj{}_goal'.format(index), global_step, phase) 99 | 100 | def npy2trch(self, arr): 101 | return torch.from_numpy(arr).float().to(self.device) 102 | 103 | class GCBCPolicyImages(BCPolicyStates): 104 | def __init__(self, ag_params, policyparams): 105 | super(GCBCPolicyImages, self).__init__(ag_params, policyparams) 106 | self._hp = self._default_hparams() 107 | self._override_defaults(policyparams) 108 | 109 | def _default_hparams(self): 110 | default_dict = AttrDict({ 111 | 'confirm_first_image': False, 112 | 'crop_image_region': False, 113 | 'stack_goal_images': False, 114 | }) 115 | default_dict.update(super(GCBCPolicyImages, self)._default_hparams()) 116 | return default_dict 117 | 118 | @staticmethod 119 | def _preprocess_input(input): 120 | assert len(input.shape) == 4 # can currently only handle inputs with 4 dims 121 | if input.max() > 1.0: 122 | input = input / 255. 123 | if input.min() >= 0.0: 124 | input = 2*input - 1.0 125 | if input.shape[-1] == 3: 126 | input = input.transpose(0, 3, 1, 2) 127 | return input 128 | 129 | def act(self, t=None, i_tr=None, images=None, state=None, goal=None, goal_image=None): 130 | # Note: goal_image provides n (2) images starting from the last images of the trajectory 131 | self.t = t 132 | self.i_tr = i_tr 133 | self.goal_image = goal_image 134 | 135 | 136 | images = images[t] 137 | if self._hp.crop_image_region: 138 | target_height, target_width = self._hp.model_override_params['data_conf']['image_size_beforecrop'] 139 | if self._hp.crop_image_region == 'select': 140 | from widowx_envs.utils.datautils.annotate_object_pos import Getdesig 141 | if self.t == 0: 142 | self.crop_center = np.array(Getdesig(images[0]).desig, dtype=np.int32) 143 | print('selected position', self.crop_center) 144 | else: 145 | self.crop_center = self._hp.crop_image_region 146 | images = crop_image(target_height, target_width, self.crop_center, images) 147 | 148 | if self._hp.model_override_params['data_conf']['image_size_beforecrop'] != images.shape[2:4]: 149 | h, w = self._hp.model_override_params['data_conf']['image_size_beforecrop'] 150 | resized_images = np.zeros([images.shape[0], h, w, 3], dtype=images.dtype) 151 | for n in range(images.shape[0]): 152 | resized_images[n] = cv2.resize(images[n], (w, h), interpolation=cv2.INTER_AREA) 153 | images = resized_images 154 | 155 | if t == 0 and self._hp.confirm_first_image: 156 | import matplotlib.pyplot as plt 157 | import matplotlib 158 | # matplotlib.use('TkAgg') 159 | plt.switch_backend('Tkagg') 160 | # matplotlib.use('Agg') 161 | if self.predictor._hp.concatenate_cameras: 162 | plt.imshow(np.concatenate(np_unstack(images, axis=0), 0)) 163 | else: 164 | plt.imshow(images[self.predictor._hp.sel_camera]) 165 | print('saving start image to', self.traj_log_dir + '/start_image.png') 166 | # plt.savefig(self.traj_log_dir + '/start_image.png') 167 | plt.show() 168 | 
if self.predictor._hp.goal_cond: 169 | if self._hp.stack_goal_images: 170 | for goal_image_single in goal_image: 171 | plt.imshow(goal_image_single[0].transpose(1, 2, 0)) 172 | plt.show() 173 | else: 174 | plt.imshow(goal_image[0, self.predictor._hp.sel_camera]) 175 | # plt.savefig(self.traj_log_dir + '/goal_image.png') 176 | plt.show() 177 | 178 | images = self.npy2trch(self._preprocess_input(images)) 179 | 180 | inputs = AttrDict(I_0=images) 181 | if self.predictor._hp.goal_cond: 182 | if self._hp.stack_goal_images: 183 | inputs['I_g'] = [self.npy2trch(self._preprocess_input(goal_image_single)) for goal_image_single in goal_image] 184 | else: 185 | inputs['I_g'] = self.npy2trch(self._preprocess_input(goal_image[-1] if len(goal_image.shape) > 4 else goal_image)) 186 | 187 | output = AttrDict() 188 | action = self.predictor(inputs).pred_actions.data.cpu().numpy().squeeze() 189 | print('inferred action', action) 190 | output.actions = action 191 | return output 192 | -------------------------------------------------------------------------------- /bridgedata/policies/gcp_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from visual_mpc.policy.policy import Policy 4 | from bridgedata.models.gcbc import GCBCModelTest 5 | from bridgedata.utils.general_utils import AttrDict 6 | from bridgedata.models.gcp import GCPModelTest 7 | from bridgedata.utils.general_utils import np_unstack 8 | 9 | class GCPPolicyStates(Policy): 10 | """ 11 | conf_invembed.py Behavioral Cloning Policy 12 | """ 13 | def __init__(self, ag_params, policyparams, gpu_id, ngpu): 14 | super(GCPPolicyStates, self).__init__() 15 | 16 | self._hp = self._default_hparams() 17 | self._override_defaults(policyparams) 18 | 19 | update_dict = { 20 | 'batch_size' : 1, 21 | 'state_dim': ag_params['state_dim'], 22 | 'action_dim': ag_params['action_dim'] 23 | } 24 | self.T = ag_params['T'] 25 | self.pred_len = self.T - 1 26 | 27 | self._hp.gcbc_params.update(update_dict) 28 | self._hp.gcp_params.update(update_dict) 29 | 30 | self.gcbc_predictor = self._hp.gcbc_model(self._hp.gcbc_params) 31 | self.gcbc_predictor.eval() 32 | self.gcp_predictor = self._hp.gcp_model(self._hp.gcp_params) 33 | self.gcp_predictor.eval() 34 | 35 | self.device = torch.device('cpu') 36 | 37 | self.current_plan = None 38 | self.all_plans = [] 39 | self.tplan = 0 40 | 41 | def reset(self): 42 | super().reset() 43 | 44 | def _default_hparams(self): 45 | default_dict = { 46 | 'gcbc_params': {}, 47 | 'gcp_params': {}, 48 | 'gcbc_model': GCBCModelTest, 49 | 'gcp_model': GCPModelTest, 50 | 'verbose': False, 51 | 'replan_interval': 10 52 | } 53 | 54 | parent_params = super()._default_hparams() 55 | for k in default_dict.keys(): 56 | parent_params.add_hparam(k, default_dict[k]) 57 | return parent_params 58 | 59 | def act(self, image=None, t=None, i_tr=None, state=None, loaded_traj_info=None): 60 | self.t = t 61 | self.i_tr = i_tr 62 | goal_state = loaded_traj_info['state'][-1] 63 | goal_state = self.npy2trch(goal_state[None]) 64 | 65 | if t % self._hp.replan_interval == 0: 66 | self.tplan = 0 67 | inputs = AttrDict(state=self.npy2trch(state[-1][None]), 68 | goal_state=goal_state) 69 | self.current_plan = self.gcp_predictor(inputs, self.pred_len) 70 | self.all_plans.append(self.current_plan) 71 | 72 | out = self.gcbc_predictor(self.current_plan[self.tplan], goal_state) 73 | self.tplan += 1 74 | output = AttrDict() 75 | output.actions = out['a_pred'].data.cpu().numpy()[0] 76 | 77 | if 
self._hp.verbose and self.t == self.T-1: 78 | self.visualize_plan(image) 79 | return output 80 | 81 | def visualize_plan(self, image): 82 | 83 | im_height = image.shape[1] 84 | im_width = image.shape[2] 85 | total_width = (self.T * 2 + 1) * im_width 86 | total_height = len(self.all_plans)*im_height 87 | 88 | out = np.zeros((total_height, total_width, 3)) 89 | 90 | out[:im_height, :] = np.concatenate(np_unstack(image, axis=0), 1) 91 | 92 | for p, plan in enumerate(self.all_plans): 93 | 94 | import pdb; pdb.set_trace() 95 | colstart = p*self._hp.replan_interval*im_width 96 | out[(p+1)*im_height: (p+2)*im_height, colstart: colstart + self.pred_len * im_width] = plan 97 | 98 | 99 | @property 100 | def default_action(self): 101 | return np.zeros(self.predictor._hp.n_actions) 102 | 103 | def log_outputs_stateful(self, logger=None, global_step=None, phase=None, dump_dir=None, exec_seq=None, goal=None, index=None, env=None, goal_pos=None, traj=None, topdown_image=None): 104 | logger.log_video(np.transpose(exec_seq, [0, 3, 1, 2]), 'control/traj{}_'.format(index), global_step, phase) 105 | goal_img = np.transpose(goal, [2, 0, 1])[None] 106 | goal_img = torch.tensor(goal_img) 107 | logger.log_images(goal_img, 'control/traj{}_goal'.format(index), global_step, phase) 108 | 109 | def npy2trch(self, arr): 110 | return torch.from_numpy(arr).float().to(self.device) 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /bridgedata/train.py: -------------------------------------------------------------------------------- 1 | import matplotlib; 2 | import torch 3 | import imp 4 | 5 | matplotlib.use('Agg'); 6 | import inspect 7 | import copy 8 | import glob 9 | import json 10 | import argparse 11 | import os 12 | import time 13 | import datetime 14 | import numpy as np 15 | from torch import autograd 16 | from torch.optim import Adam 17 | from functools import partial 18 | from bridgedata.utils.general_utils import move_to_device 19 | from bridgedata.utils.general_utils import map_recursive 20 | from bridgedata.utils.general_utils import AverageMeter, RecursiveAverageMeter 21 | from bridgedata.utils.general_utils import AttrDict 22 | from bridgedata.utils.checkpointer import CheckpointHandler 23 | from bridgedata.utils.tensorboard_logger import Logger 24 | from bridgedata.models.gcbc_transfer import GCBCTransfer 25 | from bridgedata.models import get_model_class 26 | from bridgedata.data_sets import get_dataset_class 27 | 28 | from bridgedata.models.utils.compute_dataset_normalization import compute_dataset_normalization 29 | from bridgedata.utils.general_utils import sorted_nicely 30 | import shutil 31 | 32 | def save_checkpoint(state, folder, filename='checkpoint.pth'): 33 | print('saving checkpoint ', os.path.join(folder, filename)) 34 | os.makedirs(folder, exist_ok=True) 35 | torch.save(state, os.path.join(folder, filename)) 36 | return os.path.join(folder, filename) 37 | 38 | def delete_older_checkpoints(path): 39 | files = glob.glob(path + '/weights_itr*.pth') 40 | files = sorted_nicely(files) 41 | if len(files) > 1: 42 | for f in files[:-1]: 43 | os.remove(f) 44 | files = glob.glob(path + '/weights_best*.pth') 45 | files = sorted_nicely(files) 46 | if len(files) > 1: 47 | for f in files[:-1]: 48 | os.remove(f) 49 | 50 | def clear_folder(path): 51 | if os.path.exists(path + '/weights'): 52 | shutil.rmtree(path + '/weights') 53 | if os.path.exists(path + '/events'): 54 | shutil.rmtree(path + '/events') 55 | for f in glob.glob( path + "/*.json"): 56 
| os.remove(f) 57 | 58 | def datetime_str(): 59 | return datetime.datetime.now().strftime("_%Y-%m-%d_%H-%M-%S") 60 | 61 | def make_path(exp_dir, conf_path, prefix, make_new_dir): 62 | # extract the subfolder structure from config path 63 | if conf_path.endswith('.json'): 64 | return '/'.join(str.split(conf_path, '/')[:-1]) 65 | else: 66 | path = conf_path.split('bridgedata_experiments/', 1)[1] 67 | if make_new_dir: 68 | prefix += datetime_str() 69 | base_path = os.path.join(exp_dir, '/'.join(str.split(path, '/')[:-1])) 70 | return os.path.join(base_path, prefix) if prefix else base_path 71 | 72 | def set_seeds(seed): 73 | """Sets all seeds and disables non-determinism in cuDNN backend.""" 74 | torch.manual_seed(seed) 75 | torch.backends.cudnn.deterministic = True 76 | torch.backends.cudnn.benchmark = False 77 | np.random.seed(seed) 78 | 79 | from bridgedata.utils.general_utils import Configurable 80 | 81 | class ModelTrainer(Configurable): 82 | def __init__(self, args): 83 | # Uncomment the following lines to sync tensorboard logs to weights and biases 84 | # import wandb 85 | # from bridgedata.config import WANDB_API_KEY, WANDB_EMAIL, WANDB_USERNAME 86 | # os.environ['WANDB_API_KEY'] = WANDB_API_KEY 87 | # os.environ['WANDB_USER_EMAIL'] = WANDB_EMAIL 88 | # os.environ['WANDB_USERNAME'] = WANDB_USERNAME 89 | # os.environ["WANDB_MODE"] = "run" 90 | # wandb.init(project='bridge_data', reinit=True, sync_tensorboard=True, name=args.prefix) 91 | self.batch_idx = 0 92 | self.args = args 93 | 94 | ## Set up params 95 | self.conf, self.model_conf, self.data_conf = self.get_configs() 96 | if args.data_config_override is not None: 97 | if 'data_dir' in args.data_config_override: 98 | override_dict = {'data_dir': os.environ['DATA'] + '/' + args.data_config_override['data_dir'], 99 | 'name': '_'.join(str.split(args.data_config_override['data_dir'], '/')[-2:])} 100 | else: 101 | override_dict = args.data_config_override 102 | if 'main' in self.data_conf: 103 | self.data_conf['main']['dataconf'].update(override_dict) 104 | if 'finetuning' in self.data_conf: 105 | self.data_conf['finetuning']['dataconf'].update(override_dict) 106 | self.data_conf['finetuning']['val0']['dataconf'].update(override_dict) 107 | print('data_conf after override: ', self.data_conf) 108 | elif args.source_data_config_override is not None: 109 | if 'data_dir' in args.source_data_config_override: 110 | override_dict = {'data_dir': os.environ['DATA'] + '/' + args.source_data_config_override['data_dir'], 111 | 'name': '_'.join(str.split(args.source_data_config_override['data_dir'], '/')[-2:])} 112 | else: 113 | override_dict = args.source_data_config_override 114 | self.data_conf.main.dataconf.dataset0[1].update(override_dict) 115 | self.data_conf.main.val0.dataconf.update(override_dict) 116 | print('data_conf after override: ', self.data_conf) 117 | self._loggers = AttrDict() 118 | 119 | if args.gpu != -1: 120 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 121 | else: 122 | os.environ["CUDA_VISIBLE_DEVICES"] = str(0) 123 | 124 | self._hp = self._default_hparams() 125 | self._override_defaults(self.conf) # override defaults with config file 126 | 127 | self._hp.exp_path = make_path(os.environ['EXP'] + '/bridgedata_experiments', args.path, args.prefix, args.new_dir) 128 | if not args.resume: 129 | clear_folder(self._hp.exp_path) 130 | self.log_dir = log_dir = os.path.join(self._hp.exp_path, 'events') 131 | print('using log dir: ', log_dir) 132 | 133 | self.run_testmetrics = args.metric 134 | if args.deterministic: 135 | 
set_seeds(args.deterministic) 136 | 137 | if args.cpu: 138 | self.device = torch.device('cpu') 139 | self.use_cuda = False 140 | else: 141 | self.use_cuda = torch.cuda.is_available() 142 | self.device = torch.device('cuda') if self.use_cuda else torch.device('cpu') 143 | if not args.skip_main: # if resuming finetuning skip the mainstage 144 | self.run_stage(args, 'main') 145 | if self._hp['finetuning']: 146 | self.run_stage(args, 'finetuning') 147 | 148 | def run_stage(self, args, stage): 149 | print('preparing stage ', stage) 150 | self.stage = stage 151 | reuse_action_predictor = True 152 | if stage == 'finetuning' and not args.resume: # if resuming finetuning we don't initialize from mainstage 153 | if self._hp['main'].model is GCBCTransfer: 154 | trained_model_state_dict = copy.deepcopy(self.model.single_task_model.state_dict()) 155 | else: 156 | trained_model_state_dict = copy.deepcopy(self.model.state_dict()) 157 | 158 | if 'goal_cond' in self.model_conf['main'] and 'goal_cond' not in self.model_conf['finetuning']: 159 | reuse_action_predictor = False 160 | for key in copy.deepcopy(trained_model_state_dict): 161 | if 'action_predictor' in key: 162 | trained_model_state_dict.pop(key) 163 | 164 | if stage != 'main': 165 | self._hp.exp_path = os.path.join(self._hp.exp_path, stage) 166 | print('Writing to the experiment directory: {}'.format(self._hp.exp_path)) 167 | if not os.path.exists(self._hp.exp_path): 168 | os.makedirs(self._hp.exp_path) 169 | self.save_dir = os.path.join(self._hp.exp_path, 'weights') 170 | 171 | dataset_class = self.data_conf[stage].dataclass 172 | data_conf = self.data_conf[stage].dataconf 173 | 174 | model_class = self._hp[self.stage].model 175 | self.model, self.train_loader, self.train_dataset = self.make_model_and_dataset(self._hp.logger, model_class, 176 | dataset_class, data_conf, 'train', stage) 177 | self.model_val, val_loader, val_dataset = self.make_model_and_dataset(self._hp.logger, model_class, 178 | dataset_class, data_conf, 'val', stage) 179 | if self._hp.dataset_normalization: 180 | self.model.set_normalizing_params(self.norm_params) 181 | self.model_val.set_normalizing_params(self.norm_params) 182 | 183 | self.val_loaders = [val_loader] 184 | self.val_datasets = [val_dataset] 185 | self.make_additional_valdatasets(stage) 186 | 187 | # to make sure train and validation loaders are using the same task and domain indices 188 | all_datasets = [self.train_dataset, *self.val_datasets] 189 | all_domains = set([d for dataset in all_datasets for d in list(dataset.domain_hash_index.keys())]) 190 | all_taskdescriptions = set([d for dataset in all_datasets for d in dataset.taskdescription2task_index.keys()]) 191 | domain_hash_index = {domain_hash: index for domain_hash, index in 192 | zip(all_domains, range(len(all_domains)))} 193 | if args.load_task_indices: 194 | with open(args.load_task_indices) as f: 195 | taskdescription2task_index = json.load(f) 196 | else: 197 | taskdescription2task_index = {task_descp: index for task_descp, index in 198 | zip(all_taskdescriptions, range(len(all_taskdescriptions)))} 199 | for dataset in all_datasets: 200 | dataset.set_domain_and_taskdescription_indices(domain_hash_index, taskdescription2task_index) 201 | 202 | with open(os.path.join(self._hp.exp_path, "task_index.json"), 'w') as f: 203 | json.dump(taskdescription2task_index, f, indent=4) 204 | 205 | for dataset in all_datasets: 206 | with open(os.path.join(self._hp.exp_path, "{}_{}_dataset_stats.txt".format(dataset._hp.name, dataset.phase)), 'w') as f: 207 | 
f.write(dataset.dataset_stats) 208 | 209 | self.optimizer = Adam(self.model.parameters(), lr=self._hp.lr, weight_decay=self._hp.weight_decay) 210 | 211 | self._hp.mpar = self.model._hp 212 | save_config({'train._hp': self._hp, 'model_conf': self.model._hp, 'data_conf': self.data_conf}, 213 | os.path.join(self._hp.exp_path, stage + "_conf" + datetime_str() + ".json")) 214 | 215 | if stage.startswith('finetuning') and not args.resume: # if resuming finetuning skip the mainstage 216 | incompatible_keys = self.model.load_state_dict(trained_model_state_dict, strict=False) 217 | if reuse_action_predictor: 218 | assert len(incompatible_keys.missing_keys) == 0 219 | else: 220 | print('Warning, missing keys in stage{}: {}'.format(stage, incompatible_keys.missing_keys)) 221 | print('Warning, unexpected keys in stage{}: {}'.format(stage, incompatible_keys.unexpected_keys)) 222 | if args.resume or self._hp.resume_checkpoint is not None: 223 | if self._hp.resume_checkpoint is not None: 224 | args.resume = self._hp.resume_checkpoint 225 | start_epoch, self.global_step = self.resume(args.resume) 226 | if stage.startswith('finetuning') and 'finetuning' not in args.resume: 227 | start_epoch = 0 228 | self.global_step = 0 229 | args.resume = False # avoid having it crash in the finetuning stage if we resume in the main stage 230 | else: 231 | self.global_step = 0 232 | start_epoch = 0 233 | if 'num_epochs' in self._hp[self.stage]: 234 | num_epochs = self._hp[self.stage].num_epochs 235 | max_iterations = None 236 | else: 237 | num_epochs = None 238 | max_iterations = self._hp[self.stage].max_iterations 239 | self.best_val_loss = float('inf') 240 | self.train(start_epoch, num_epochs, max_iterations) 241 | 242 | def make_additional_valdatasets(self, stage): 243 | val_keys = [k for k in self.data_conf[stage].keys() if k.startswith('val')] 244 | if val_keys == []: 245 | return 246 | for key in val_keys: 247 | print('making extra val dataset with key: ', key) 248 | dataclass = self.data_conf[stage][key].dataclass 249 | dataconf = self.data_conf[stage][key].dataconf 250 | dataset = dataclass(dataconf, 'val', shuffle=True) 251 | loader = dataset.get_data_loader(self._hp.batch_size) 252 | self.val_datasets.append(dataset) 253 | self.val_loaders.append(loader) 254 | 255 | def make_model_and_dataset(self, logger, ModelClass, DatasetClass, data_conf, phase, stage): 256 | logger = logger(os.path.join(self.log_dir, phase)) 257 | self._loggers[phase] = logger 258 | model_conf = copy.deepcopy(self.model_conf) 259 | if 'main' in model_conf: 260 | model_conf = model_conf[stage] 261 | model_conf['batch_size'] = self._hp.batch_size 262 | model_conf['device'] = self.device.type 263 | if stage == 'finetuning': 264 | if 'finetuning_override' in model_conf: 265 | model_conf.update(self.model_conf.finetuning_override) 266 | model_conf['stage'] = 'finetuning' 267 | model_conf['phase'] = phase 268 | model_conf.data_conf = data_conf 269 | model = ModelClass(model_conf, logger) 270 | model.to(self.device) 271 | model.device = self.device 272 | if phase != 'test': 273 | dataset = DatasetClass(data_conf, phase, shuffle=True) 274 | loader = dataset.get_data_loader(self._hp.batch_size) 275 | 276 | if phase == 'train' and self._hp.dataset_normalization: 277 | self.norm_params = compute_dataset_normalization(loader) 278 | 279 | return model, loader, dataset 280 | 281 | def _default_hparams(self): 282 | # put new parameters in here: 283 | default_dict = { 284 | 'resume_checkpoint': None, 285 | 'logger': Logger, 286 | 'batch_size': 32, 287
| 'mpar': None, # model parameters 288 | 'data_conf': None, # model data parameters 289 | 'exp_path': None, # Path to the folder with experiments 290 | 'log_every': 1, 291 | 'delete_older_checkpoints': True, 292 | 'epoch_cycles_train': 1, 293 | 'optimizer': 'adam', # supported: 'adam', 'rmsprop', 'sgd' 294 | 'lr': 1e-4, 295 | 'momentum': 0, # momentum in RMSProp / SGD optimizer 296 | 'adam_beta': 0.9, # beta1 param in Adam 297 | 'main': None, 298 | 'finetuning': None, 299 | 'weight_decay': 0, 300 | 'dataset_normalization': True, 301 | 'delta_step_val': 100, 302 | 'delta_step_control_val': 500, 303 | 'delta_step_save': 500, 304 | } 305 | # add new params to parent params 306 | return AttrDict(default_dict) 307 | 308 | def get_configs(self): 309 | conf_path = os.path.abspath(self.args.path) 310 | 311 | if conf_path.endswith('.json'): 312 | with open(conf_path, 'r') as f: 313 | json_conf = json.load(f) 314 | conf = json_conf['train._hp'] 315 | conf['model'] = get_model_class(conf['model']) 316 | conf['finetuning_model'] = get_model_class(conf['finetuning_model']) 317 | conf['dataset_class'] = get_dataset_class(conf['dataset_class']) 318 | conf['finetuning_dataset_class'] = get_dataset_class(conf['finetuning_dataset_class']) 319 | conf['identical_default_ok'] = '' 320 | if conf['logger'] == 'Logger': 321 | conf['logger'] = Logger 322 | else: 323 | raise NotImplementedError 324 | model_conf = json_conf['model_conf'] 325 | model_conf['identical_default_ok'] = '' 326 | data_conf = json_conf['data_conf'] 327 | data_conf['identical_default_ok'] = '' 328 | else: 329 | print('loading from the config file {}'.format(conf_path)) 330 | conf_module = imp.load_source('conf', self.args.path) 331 | conf = conf_module.configuration 332 | model_conf = conf_module.model_config 333 | data_conf = conf_module.data_config 334 | 335 | return conf, model_conf, data_conf 336 | 337 | def resume(self, ckpt): 338 | weights_file = CheckpointHandler.get_resume_ckpt_file(ckpt, os.path.join(self._hp.exp_path, 'weights')) 339 | global_step, start_epoch, _ = \ 340 | CheckpointHandler.load_weights(weights_file, self.model, 341 | load_step_and_opt=True, optimizer=self.optimizer, 342 | dataset_length=len(self.train_loader) * self._hp.batch_size, 343 | strict=self.args.strict_weight_loading) 344 | self.model.to(self.model.device) 345 | return start_epoch, global_step 346 | 347 | def train(self, start_epoch, num_epochs, max_iterations): 348 | if max_iterations is not None: 349 | num_epochs = int(np.ceil(max_iterations/len(self.train_loader))) 350 | print('setting num_epochs to ', num_epochs) 351 | self.last_global_step_when_val = int(-1e9) # make sure we do val the first time of the train-val cycle. 
352 | self.last_global_step_when_control = int(-1e9) 353 | self.last_global_step_when_save = int(-1e9) 354 | for i, epoch in enumerate(range(start_epoch, num_epochs)): 355 | self.train_val_cycle(epoch, num_epochs) 356 | return epoch 357 | 358 | def train_val_cycle(self, epoch, num_epochs): 359 | if not self.args.no_train: 360 | self.train_epoch(epoch, num_epochs) 361 | if (self.global_step - self.last_global_step_when_val) > self._hp.delta_step_val and not self.args.no_val: 362 | val_loss = self.val() 363 | self.last_global_step_when_val = self.global_step 364 | if (self.global_step - self.last_global_step_when_save) > self._hp.delta_step_save: 365 | if val_loss < self.best_val_loss: 366 | self.best_val_loss = val_loss 367 | save_string = 'weights_best_itr{}.pth'.format(self.global_step) 368 | else: 369 | save_string = 'weights_itr{}.pth'.format(self.global_step) 370 | self._save_file_name = save_checkpoint({ 371 | 'epoch': epoch, 372 | 'global_step': self.global_step, 373 | 'state_dict': self.model.state_dict(), 374 | 'optimizer': self.optimizer.state_dict(), 375 | }, self.save_dir, save_string) 376 | if self._hp.delete_older_checkpoints: 377 | delete_older_checkpoints(self.save_dir) 378 | self.last_global_step_when_save = self.global_step 379 | 380 | @property 381 | def log_outputs_now(self): 382 | return self.global_step % self.log_outputs_interval == 0 383 | 384 | def train_epoch(self, epoch, num_epochs): 385 | self.model.train() 386 | self.model.to(self.device) 387 | epoch_len = len(self.train_loader) 388 | end = time.time() 389 | batch_time = AverageMeter() 390 | upto_log_time = AverageMeter() 391 | data_load_time = AverageMeter() 392 | self.log_outputs_interval = 50 393 | print('starting epoch ', epoch) 394 | self.model.set_dataset_sufix(self.train_dataset._hp) 395 | 396 | for self.batch_idx, sample_batched in enumerate(self.train_loader): 397 | data_load_time.update(time.time() - end) 398 | inputs = move_to_device(sample_batched, self.device) 399 | self.optimizer.zero_grad() 400 | inputs.global_step = self.global_step 401 | inputs.max_iterations = self._hp[self.stage].max_iterations 402 | output = self.model(inputs) 403 | losses = self.model.loss(inputs, output) 404 | losses.total_loss.backward() 405 | self.optimizer.step() 406 | 407 | upto_log_time.update(time.time() - end) 408 | if self.log_outputs_now: 409 | self.model.log_outputs(output, inputs, losses, self.global_step, 410 | phase='train') 411 | batch_time.update(time.time() - end) 412 | 413 | if self.log_outputs_now: 414 | print('GPU {}: {}'.format(os.environ["CUDA_VISIBLE_DEVICES"] if self.use_cuda else 'none', self._hp.exp_path)) 415 | print(('stage {}, itr: {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(self.stage, 416 | self.global_step, epoch, self.batch_idx, len(self.train_loader), 417 | 100. * self.batch_idx / len(self.train_loader), losses.total_loss.item()))) 418 | 419 | print('avg time for loading: {:.2f}s, logs: {:.2f}s, compute: {:.2f}s, total: {:.2f}s' 420 | .format(data_load_time.avg, 421 | batch_time.avg - upto_log_time.avg, 422 | upto_log_time.avg - data_load_time.avg, 423 | batch_time.avg)) 424 | togo_train_time = batch_time.avg * (num_epochs - epoch) * epoch_len / 3600. 
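# togo_train_time above is the estimated remaining wall-clock time in hours:
# average seconds per batch x batches per epoch x epochs still to run, divided by 3600.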
425 | print('ETA: {:.2f}h'.format(togo_train_time)) 426 | 427 | del output, losses 428 | self.global_step = self.global_step + 1 429 | end = time.time() 430 | 431 | if self.global_step > self._hp[self.stage].max_iterations: 432 | break 433 | 434 | self.model.to(torch.device('cpu')) 435 | 436 | def val(self): 437 | val_losses = [] 438 | for val_loader, val_dataset in zip(self.val_loaders, self.val_datasets): 439 | val_loss = self.eval_for_dataset(val_loader, val_dataset) 440 | val_losses.append(val_loss) 441 | return val_losses[0] 442 | 443 | def eval_for_dataset(self, val_loader, val_dataset): 444 | print('Running Testing') 445 | start = time.time() 446 | self.model_val.to(self.device) 447 | self.model_val.load_state_dict(self.model.state_dict()) 448 | self.model_val.eval() 449 | self.model_val.throttle_log_images = 0 # make sure to log images the first val pass! 450 | self.model_val.set_dataset_sufix(val_dataset._hp) 451 | losses_meter = RecursiveAverageMeter() 452 | with autograd.no_grad(): 453 | for batch_idx, sample_batched in enumerate(val_loader): 454 | inputs = move_to_device(sample_batched, self.device) 455 | inputs.global_step = self.global_step 456 | inputs.max_iterations = self._hp[self.stage].max_iterations 457 | output = self.model_val(inputs) 458 | losses = self.model_val.loss(inputs, output) 459 | losses_meter.update(losses) 460 | del losses 461 | 462 | self.model_val.log_outputs( 463 | output, inputs, losses_meter.avg, self.global_step, phase='val') 464 | print(('\nTest set: Average loss: {:.4f} in {:.2f}s over {} batches\n' 465 | .format(losses_meter.avg.total_loss.item(), time.time() - start, batch_idx))) 466 | del output 467 | self.model_val.to(torch.device('cpu')) 468 | return losses_meter.avg.total_loss.item() 469 | 470 | def get_optimizer_class(self): 471 | if self._hp.optimizer == 'adam': 472 | optim = partial(Adam, betas=(self._hp.adam_beta, 0.999)) 473 | else: 474 | raise ValueError("Optimizer '{}' not supported!".format(self._hp.optimizer)) 475 | return optim 476 | 477 | 478 | def save_config(confs, exp_conf_path): 479 | def func(x): 480 | if inspect.isclass(x) or inspect.isfunction(x): 481 | return x.__name__ 482 | else: 483 | return x 484 | confs = map_recursive(func, confs) 485 | 486 | with open(exp_conf_path, 'w') as f: 487 | json.dump(confs, f, indent=4) 488 | 489 | 490 | if __name__ == '__main__': 491 | parser = argparse.ArgumentParser() 492 | parser.add_argument("--path", help="path to the config file directory") 493 | # Folder settings 494 | parser.add_argument("--prefix", help="experiment prefix, if given creates subfolder in experiment directory") 495 | parser.add_argument('--new_dir', default=False, action='store_true', help='If True, concat datetime string to exp_dir.') 496 | 497 | # Running protocol 498 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 499 | help='path to latest checkpoint (default: none)') 500 | parser.add_argument('--skip_main', default=False, action='store_true', 501 | help='if true will go directly to fine-tuning stage') 502 | parser.add_argument('--no_train', default=False, action='store_true', 503 | help='if False will not run training epoch') 504 | parser.add_argument('--no_val', default=False, action='store_true', 505 | help='if False will not run validation epoch') 506 | parser.add_argument('--metric', default=False, action='store_true', 507 | help='if True, run test metrics') 508 | parser.add_argument('--cpu', default=False, 509 | help='if True, use CPU', action='store_true') 510 | 511 | # Misc 512 | 
parser.add_argument('--gpu', default=-1, type=int, 513 | help='will set CUDA_VISIBLE_DEVICES to selected value') 514 | parser.add_argument('--strict_weight_loading', default=True, type=int, 515 | help='if True, uses strict weight loading function') 516 | parser.add_argument('--deterministic', default=False, action='store_true', 517 | help='if True, sets fixed seeds for torch and numpy') 518 | parser.add_argument('--data_config_override', default=None, help='used in doodad for sweeps') 519 | parser.add_argument('--load_task_indices', default=None, help='task indices json file to load') 520 | parser.add_argument('--source_data_config_override', default=None, help='used in doodad for sweeps') 521 | args = parser.parse_args() 522 | ModelTrainer(args) 523 | -------------------------------------------------------------------------------- /bridgedata/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanlai00/bridge_data_imitation_learning/d81a3d9181672f5f26dfbd1844d3017cf7a11367/bridgedata/utils/__init__.py -------------------------------------------------------------------------------- /bridgedata/utils/calc_success_rates.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | THRESHOLD = 0.02 3 | 4 | def calculate_success_rate(file_name): 5 | print(f'using success threshold of {THRESHOLD}') 6 | with open(file_name, 'r') as f: 7 | contents = f.read() 8 | lines = contents.split('\n') 9 | traj_idx = 0 10 | successes = 0 11 | total = 0 12 | traj_found = True 13 | while traj_found: 14 | traj_found = False 15 | data = None 16 | for line in lines: 17 | if line.startswith(f'{traj_idx}: '): 18 | traj_found = True 19 | data = line 20 | if traj_found: 21 | # Extracting numbers... 
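# (assumed per-trajectory line format, inferred from the parsing below:
# "<traj_idx>: <label>, <final_dist>, ..." so that after replacing ':' with ','
# and splitting on ',', data[2] holds the final distance to the goal)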
22 | data = data.replace(':', ',').split(',') 23 | final_dist = float(data[2]) 24 | with open('scores.txt', 'a') as s: 25 | s.write(str(final_dist) + '\n') 26 | if final_dist <= THRESHOLD: 27 | successes += 1 28 | total += 1 29 | traj_idx += 1 30 | print(f'{1.0*successes/total*100}% success rate, of {total} trials.') 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('results_file', type=str, help='path to outputted benchmark results file') 35 | args = parser.parse_args() 36 | calculate_success_rate(args.results_file) -------------------------------------------------------------------------------- /bridgedata/utils/checkpointer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import torch 5 | import sys 6 | import pipes 7 | from bridgedata.utils.general_utils import AttrDict, str2int 8 | 9 | class CheckpointHandler: 10 | @staticmethod 11 | def get_ckpt_name(epoch): 12 | return 'weights_ep{}.pth'.format(epoch) 13 | 14 | @staticmethod 15 | def get_epochs(path): 16 | checkpoint_names = glob.glob(os.path.abspath(path) + "/*.pth") 17 | if len(checkpoint_names) == 0: 18 | raise ValueError("No checkpoints found at {}!".format(path)) 19 | processed_names = [file.split('/')[-1].replace('weights_ep', '').replace('.pth', '') 20 | for file in checkpoint_names] 21 | epochs = list(filter(lambda x: x is not None, [str2int(name) for name in processed_names])) 22 | return epochs 23 | 24 | @staticmethod 25 | def get_resume_ckpt_file(resume, path): 26 | print("Loading from: {}".format(path)) 27 | if resume == 'latest': 28 | max_epoch = np.max(CheckpointHandler.get_epochs(path)) 29 | resume_file = CheckpointHandler.get_ckpt_name(max_epoch) 30 | elif str2int(resume) is not None: 31 | resume_file = CheckpointHandler.get_ckpt_name(resume) 32 | elif '.pth' not in resume: 33 | resume_file = resume + '.pth' 34 | else: 35 | resume_file = resume 36 | 37 | return os.path.join(path, resume_file) 38 | 39 | @staticmethod 40 | def load_weights(weights_file, model, load_step_and_opt=False, optimizer=None, dataset_length=None, strict=True): 41 | success = False 42 | if os.path.isfile(weights_file): 43 | print(("=> loading checkpoint '{}'".format(weights_file))) 44 | checkpoint = torch.load(weights_file, map_location=model.device) 45 | model.load_state_dict(checkpoint['state_dict'], strict=strict) 46 | if load_step_and_opt: 47 | start_epoch = checkpoint['epoch'] + 1 48 | global_step = checkpoint['global_step'] 49 | try: 50 | optimizer.load_state_dict(checkpoint['optimizer']) 51 | except (RuntimeError, ValueError) as e: 52 | if not strict: 53 | print("Could not load optimizer params because of changes in the network + non-strict loading") 54 | pass 55 | else: 56 | raise e 57 | print(("=> loaded checkpoint '{}' (epoch {})" 58 | .format(weights_file, checkpoint['epoch']))) 59 | success = True 60 | else: 61 | # print(("=> no checkpoint found at '{}'".format(weights_file))) 62 | # start_epoch = 0 63 | raise ValueError("Could not find checkpoint file in {}!".format(weights_file)) 64 | 65 | if load_step_and_opt: 66 | return global_step, start_epoch, success 67 | else: 68 | return success 69 | 70 | 71 | @staticmethod 72 | def hack_to_fix_checkpoints(): 73 | import torch 74 | fold = '../experiments/prediction/rec_planner/nav2d/single_wall/soft_fixed_baseline/kl1e-2' 75 | keywords_remove = ['running_var', 'running_mean', 'num_batches_tracked'] 76 | 77 | ckpt = torch.load(fold + 
'/weights_ep17.pth') 78 | state = dict(ckpt['state_dict']) 79 | new_state = {} 80 | 81 | def remove_keyword(keyword): 82 | pop_list = ([key for key in state if keyword in key]); 83 | [state.pop(key) for key in pop_list] 84 | 85 | [remove_keyword(keyword) for keyword in keywords_remove] 86 | 87 | def fn(key): new_state[key.replace('batchnorm', 'norm')] = state[key] 88 | 89 | [fn(key) for key in state if 'batchnorm' in key] 90 | remove_keyword('batchnorm') 91 | 92 | new_state.update(state) 93 | ckpt['state_dict'] = new_state 94 | torch.save(ckpt, fold + '/weights_norun.pth') 95 | 96 | @staticmethod 97 | def another_hack_to_fix_checkpoints(): 98 | import torch 99 | fold = '../experiments/prediction/rec_planner/nav2d/single_wall/soft_fixed_baseline/kl1e-2' 100 | keywords_remove = ['running_var', 'running_mean', 'num_batches_tracked'] 101 | 102 | ckpt = torch.load(fold + '/weights_ep17.pth') 103 | state = dict(ckpt['state_dict']) 104 | new_state = {} 105 | 106 | def remove_keyword(keyword): 107 | pop_list = ([key for key in state if keyword in key]); 108 | [state.pop(key) for key in pop_list] 109 | 110 | [remove_keyword(keyword) for keyword in keywords_remove] 111 | 112 | new_state.update(state) 113 | ckpt['state_dict'] = new_state 114 | torch.save(ckpt, fold + '/weights_norun.pth') 115 | 116 | 117 | @staticmethod 118 | def rename_parameters(dict, old, new): 119 | """ Renames parameters in the network by finding parameters that contain 'old' and replacing 'old' with 'new' 120 | """ 121 | replacements = [key for key in dict if old in key] 122 | 123 | for key in replacements: 124 | dict[key.replace(old, new)] = dict.pop(key) 125 | 126 | 127 | def get_config_path(path): 128 | conf_names = glob.glob(os.path.abspath(path) + "/*.py") 129 | if len(conf_names) == 0: 130 | raise ValueError("No configuration files found at {}!".format(path)) 131 | 132 | # The standard conf 133 | if 'conf_invembed.py' in map(lambda x: x.split('/')[-1], conf_names): 134 | return os.path.join(path, 'conf_invembed.py') 135 | 136 | # Get latest conf 137 | arrays = [np.array(file.split('__')[-1].replace('_', '-').replace('.py', '').split('-'), 138 | dtype=float) for file in filter(lambda x: '__' in x, conf_names)] 139 | # Converts arrays representing time to values that can be compared 140 | values = np.array(list([ar[5] + 100 * ar[4] + (100 ** 2) * ar[3] + (100 ** 3) * ar[2] 141 | + (100 ** 4) * ar[1] + (100 ** 5) * 10 * ar[0] 142 | for ar in arrays])) 143 | conf_ind = np.argmax(values) 144 | return conf_names[conf_ind] 145 | 146 | 147 | def save_git(base_dir): 148 | # save code revision 149 | print('Save git commit and diff to {}/git.txt'.format(base_dir)) 150 | cmds = ["echo `git rev-parse HEAD` > {}".format( 151 | os.path.join(base_dir, 'git.txt')), 152 | "git diff >> {}".format( 153 | os.path.join(base_dir, 'git.txt'))] 154 | print(cmds) 155 | os.system("\n".join(cmds)) 156 | 157 | 158 | def save_cmd(base_dir): 159 | train_cmd = 'python ' + ' '.join([sys.argv[0]] + [pipes.quote(s) for s in sys.argv[1:]]) 160 | train_cmd += '\n' 161 | print('\n' + '*' * 80) 162 | print('Training command:\n' + train_cmd) 163 | print('*' * 80 + '\n') 164 | with open(os.path.join(base_dir, "cmd.txt"), "w") as f: 165 | f.write(train_cmd) 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /bridgedata/utils/figure_out_scatter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | nb_digits = 10 5 | # Dummy input that HAS to be 2D for 
the scatter (you can use view(-1,1) if needed) 6 | # active_dim = active_dim.type(torch.LongTensor).view(-1, 1) 7 | batch_size = 8 8 | T = 14 9 | active_dim = torch.LongTensor(batch_size, T, 1).random_() % nb_digits 10 | # active_dim = active_dim.type(torch.LongTensor)[:, 0] 11 | # One hot encoding buffer that you create out of the loop and just keep reusing 12 | # T = active_dim.shape[1] 13 | batch_size = active_dim.shape[0] 14 | y_onehot = torch.FloatTensor(batch_size, T, nb_digits) 15 | 16 | # In your for loop 17 | y_onehot.zero_() 18 | import pdb; pdb.set_trace() 19 | y_onehot.scatter_(2, active_dim, 1) 20 | 21 | print(active_dim) 22 | print(y_onehot) 23 | import pdb; pdb.set_trace() 24 | -------------------------------------------------------------------------------- /bridgedata/utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import cv2 4 | from PIL import Image 5 | from torchvision.transforms import Resize 6 | import torch 7 | from functools import reduce 8 | import copy 9 | 10 | def str2int(str): 11 | try: 12 | return int(str) 13 | except ValueError: 14 | return None 15 | 16 | 17 | class Configurable: 18 | def _override_defaults(self, params): 19 | params = copy.copy(params) 20 | if 'identical_default_ok' in params: 21 | identical_default_ok = True 22 | params.pop('identical_default_ok') 23 | else: 24 | identical_default_ok = False 25 | 26 | for name, value in params.items(): 27 | # print('overriding param {} to value {}'.format(name, value)) 28 | if value == getattr(self._hp, name) and not identical_default_ok: 29 | raise ValueError("attribute is {} is identical to default value {} !!".format(name, value)) 30 | self._hp[name] = value 31 | 32 | def _default_hparams(self): 33 | return AttrDict() 34 | 35 | class HasParameters: 36 | def __init__(self, **kwargs): 37 | self.build_params(kwargs) 38 | 39 | def build_params(self, inputs): 40 | # If params undefined define params 41 | try: 42 | self.params 43 | except AttributeError: 44 | self.params = self.get_default_params() 45 | self.params.update(inputs) 46 | 47 | def move_to_device(inputs, device): 48 | def func(x): 49 | if isinstance(x, np.ndarray): 50 | x = torch.from_numpy(x) 51 | if isinstance(x, list): 52 | return list(map(lambda x:x.to(device), x)) 53 | if isinstance(x, dict): 54 | return AttrDict(map_dict(func, x)) 55 | else: 56 | return x.to(device) 57 | return AttrDict(map_dict(func, inputs)) 58 | 59 | 60 | def down_sample_imgs(obs, des_size): 61 | obs = copy.deepcopy(obs) 62 | imgs = obs['images'] 63 | target_array = np.zeros([imgs.shape[0], imgs.shape[1], des_size[0], des_size[1], 3], dtype=np.uint8) 64 | for n in range(imgs.shape[1]): 65 | for t in range(imgs.shape[0]): 66 | target_array[t, n] = cv2.resize(imgs[t, n], (des_size[1], des_size[0]), interpolation=cv2.INTER_AREA) 67 | obs['images'] = target_array 68 | return obs 69 | 70 | class AttrDict(dict): 71 | __setattr__ = dict.__setitem__ 72 | 73 | def __getattr__(self, attr): 74 | # Take care that getattr() raises AttributeError, not KeyError. 75 | # Required e.g. for hasattr(), deepcopy and OrderedDict. 
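# Illustrative example (hypothetical values): d = AttrDict(lr=1e-4) supports both d.lr and d['lr'];
# a missing key accessed as an attribute, e.g. d.momentum, raises AttributeError instead of KeyError.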
76 | try: 77 | return self.__getitem__(attr) 78 | except KeyError: 79 | raise AttributeError("Attribute %r not found" % attr) 80 | 81 | def __getstate__(self): return self 82 | def __setstate__(self, d): self = d 83 | 84 | 85 | def np_unstack(array, axis): 86 | arr = np.split(array, array.shape[axis], axis) 87 | arr = [a.squeeze() for a in arr] 88 | return arr 89 | 90 | def map_dict(fn, d): 91 | """takes a dictionary and applies the function to every element""" 92 | return type(d)(map(lambda kv: (kv[0], fn(kv[1])), d.items())) 93 | 94 | 95 | def make_recursive(fn, *argv, **kwargs): 96 | """ Takes a fn and returns a function that can apply fn on tensor structure 97 | which can be a single tensor, tuple or a list. """ 98 | 99 | def recursive_map(tensors): 100 | if tensors is None: 101 | return tensors 102 | elif isinstance(tensors, list) or isinstance(tensors, tuple): 103 | return type(tensors)(map(recursive_map, tensors)) 104 | elif isinstance(tensors, dict): 105 | return type(tensors)(map_dict(recursive_map, tensors)) 106 | elif isinstance(tensors, torch.Tensor): 107 | return fn(tensors, *argv, **kwargs) 108 | else: 109 | try: 110 | return fn(tensors, *argv, **kwargs) 111 | except Exception as e: 112 | print("The following error was raised when recursively applying a function:") 113 | print(e) 114 | raise ValueError("Type {} not supported for recursive map".format(type(tensors))) 115 | 116 | return recursive_map 117 | 118 | 119 | def listdict2dictlist(LD): 120 | """ Converts a list of dicts to a dict of lists """ 121 | 122 | # Take intersection of keys 123 | keys = reduce(lambda x, y: x & y, (map(lambda d: d.keys(), LD))) 124 | return AttrDict({k: [dic[k] for dic in LD] for k in keys}) 125 | 126 | def make_recursive_list(fn): 127 | """ Takes a fn and returns a function that can apply fn across tuples of tensor structures, 128 | each of which can be a single tensor, tuple or a list. 
""" 129 | 130 | def recursive_map(tensors): 131 | if tensors is None: 132 | return tensors 133 | elif isinstance(tensors[0], list) or isinstance(tensors[0], tuple): 134 | return type(tensors[0])(map(recursive_map, zip(*tensors))) 135 | elif isinstance(tensors[0], dict): 136 | return map_dict(recursive_map, listdict2dictlist(tensors)) 137 | elif isinstance(tensors[0], torch.Tensor): 138 | return fn(*tensors) 139 | else: 140 | try: 141 | return fn(*tensors) 142 | except Exception as e: 143 | print("The following error was raised when recursively applying a function:") 144 | print(e) 145 | raise ValueError("Type {} not supported for recursive map".format(type(tensors))) 146 | 147 | return recursive_map 148 | 149 | 150 | recursively = make_recursive 151 | 152 | 153 | def map_recursive(fn, tensors): 154 | return make_recursive(fn)(tensors) 155 | 156 | 157 | def map_recursive_list(fn, tensors): 158 | return make_recursive_list(fn)(tensors) 159 | 160 | def resize_video(video, size): 161 | transformed_video = np.stack([np.asarray(Resize(size)(Image.fromarray(im))) for im in video], axis=0) 162 | return transformed_video 163 | 164 | class AverageMeter(object): 165 | """Computes and stores the average and current value""" 166 | 167 | def __init__(self): 168 | self.reset() 169 | 170 | def reset(self): 171 | self.val = 0 172 | self.avg = 0 173 | self.sum = 0 174 | self.count = 0 175 | 176 | def update(self, val, n=1): 177 | self.val = val 178 | self.sum += val * n 179 | self.count += n 180 | self.avg = self.sum / self.count 181 | 182 | 183 | def select_indices(tensor, indices): 184 | assert len(indices.shape) == 1 185 | new_images = [] 186 | for b in range(tensor.shape[0]): 187 | new_images.append(tensor[b, indices[b]]) 188 | tensor = torch.stack(new_images, dim=0) 189 | return tensor 190 | 191 | class RecursiveAverageMeter(object): 192 | """Computes and stores the average and current value""" 193 | 194 | def __init__(self): 195 | self.reset() 196 | 197 | def reset(self): 198 | self.val = None 199 | self.avg = None 200 | self.sum = None 201 | self.count = 0 202 | 203 | def update(self, val): 204 | self.val = val 205 | if self.sum is None: 206 | self.sum = val 207 | else: 208 | self.sum = map_recursive_list(lambda x, y: x + y, [self.sum, val]) 209 | self.count += 1 210 | self.avg = map_recursive(lambda x: x / self.count, self.sum) 211 | 212 | 213 | def sorted_nicely(l): 214 | """ Sort the given iterable in the way that humans expect.""" 215 | convert = lambda text: int(text) if text.isdigit() else text 216 | alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 217 | return sorted(l, key = alphanum_key) 218 | 219 | 220 | def trch2npy(tensor): 221 | return tensor.data.cpu().numpy() 222 | 223 | def npy2trch(tensor, device='cuda'): 224 | return torch.from_numpy(tensor).to(torch.device(device)) 225 | -------------------------------------------------------------------------------- /bridgedata/utils/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import torchvision 4 | from bridgedata.utils.vis_utils import draw_text_image 5 | from bridgedata.utils.general_utils import select_indices 6 | import torch 7 | from tensorboardX import SummaryWriter 8 | import numpy as np 9 | import copy 10 | from bridgedata.utils.general_utils import np_unstack 11 | 12 | class Logger: 13 | def __init__(self, log_dir, n_logged_samples=10, summary_writer=None): 14 | self._log_dir = log_dir 15 | self._n_logged_samples = 
n_logged_samples 16 | if summary_writer is not None: 17 | self._summ_writer = summary_writer 18 | else: 19 | self._summ_writer = SummaryWriter(log_dir) 20 | 21 | def _loop_batch(self, fn, name, val, *argv, **kwargs): 22 | """Loops the logging function n times.""" 23 | for log_idx in range(min(self._n_logged_samples, len(val))): 24 | name_i = os.path.join(name, "_%d" % log_idx) 25 | fn(name_i, val[log_idx], *argv, **kwargs) 26 | 27 | @staticmethod 28 | def _check_size(val, size): 29 | if isinstance(val, torch.Tensor) or isinstance(val, np.ndarray): 30 | assert len(val.shape) == size, "Size of tensor does not fit required size, {} vs {}".format(len(val.shape), 31 | size) 32 | elif isinstance(val, list): 33 | assert len(val[0].shape) == size - 1, "Size of list element does not fit required size, {} vs {}".format( 34 | len(val[0].shape), size - 1) 35 | else: 36 | raise NotImplementedError("Input type {} not supported for dimensionality check!".format(type(val))) 37 | if (val[0].shape[1] > 10000) or (val[0].shape[2] > 10000): 38 | raise ValueError("This might be a bit too much") 39 | 40 | def log_scalar(self, scalar, name, step, phase=None): 41 | if phase: 42 | sc_name = '{}_{}'.format(name, phase) 43 | else: 44 | sc_name = name 45 | self._summ_writer.add_scalar(sc_name, scalar, step) 46 | 47 | def log_scalars(self, scalar_dict, group_name, step, phase=None): 48 | """Will log all scalars in the same plot.""" 49 | if phase: 50 | sc_name = '{}_{}'.format(group_name, phase) 51 | else: 52 | sc_name = group_name 53 | self._summ_writer.add_scalars('{}_{}'.format(sc_name, phase), scalar_dict, step) 54 | 55 | def log_images(self, image, name, step, phase): 56 | self._check_size(image, 4) # [N, C, H, W] 57 | self._loop_batch(self._summ_writer.add_image, '{}_{}'.format(name, phase), image, step) 58 | 59 | def log_video(self, video_frames, name, step, phase=None, fps=10): 60 | assert len(video_frames.shape) == 4, "Need [T, C, H, W] input tensor for single video logging!" 61 | if not isinstance(video_frames, torch.Tensor): video_frames = torch.tensor(video_frames) 62 | video_frames = video_frames.unsqueeze(0) # add an extra dimension to get grid of size 1 63 | if phase: 64 | sc_name = '{}_{}'.format(name, phase) 65 | else: 66 | sc_name = name 67 | self._summ_writer.add_video(sc_name, video_frames, step, fps=fps) 68 | 69 | def log_image(self, images, name, step, phase): 70 | self._summ_writer.add_image('{}_{}'.format(name, phase), images, step) 71 | 72 | def log_image_grid(self, images, name, step, phase, nrow=8): 73 | assert len(images.shape) == 4, "Image grid logging requires input shape [batch, C, H, W]!" 74 | img_grid = torchvision.utils.make_grid(images, nrow=nrow) 75 | self.log_images(img_grid, '{}_{}'.format(name, phase), step) 76 | 77 | def log_video_grid(self, video_frames, name, step, phase, fps=3): 78 | assert len(video_frames.shape) == 5, "Need [N, T, C, H, W] input tensor for video logging!" 79 | self._summ_writer.add_video('{}_{}'.format(name, phase), video_frames, step, fps=fps) 80 | 81 | def log_figures(self, figure, name, step, phase): 82 | """figure: matplotlib.pyplot figure handle""" 83 | assert figure.shape[0] > 0, "Figure logging requires input shape [batch x figures]!" 
84 | self._loop_batch(self._summ_writer.add_figure, '{}_{}'.format(name, phase), figure, step) 85 | 86 | def log_figure(self, figure, name, step, phase): 87 | """figure: matplotlib.pyplot figure handle""" 88 | self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step) 89 | 90 | def dump_scalars(self, log_path=None): 91 | log_path = os.path.join(self._log_dir, "scalar_data.json") if log_path is None else log_path 92 | self._summ_writer.export_scalars_to_json(log_path) 93 | 94 | def log_kbest_videos(self, model_output, inputs, losses, step, phase): 95 | 96 | def get_per_example_loss_red(): 97 | loss = torch.mean((model_output.a_pred - inputs.sel_actions)**2, dim=-1) 98 | no_aux_loss = torch.mean((model_output.a_pred_no_aux - inputs.sel_actions)**2, dim=-1) 99 | loss_red_perex = (loss - no_aux_loss).cpu().detach().numpy() 100 | loss_reduction_row = np.stack([draw_text_image(str(r), dtype=np.uint8) for r in loss_red_perex], axis=0) 101 | T = inputs.best_matches_states.shape[1] 102 | per_example_loss_row = torch.from_numpy(np.tile(loss_reduction_row[:, None], [1, T, 1, 1, 1])) 103 | return per_example_loss_row 104 | 105 | goal_img = inputs.images.squeeze()[:, -1] 106 | vid = assemble_videos_kbestmatches(inputs.current_img, goal_img, inputs.best_matches_images, get_per_example_loss_red()) 107 | self.log_video(vid, 'nearest_neighbors', step, phase, fps=10) 108 | 109 | def flush(self): 110 | self._summ_writer.flush() 111 | 112 | def assemble_videos_kbestmatches(current_img, goal_img, best_matches_images, per_example_loss_red=None, n_batch_examples=10): 113 | """ 114 | all inputs have to be torch tensors! 115 | :param current_img: 116 | :param goal_img: 117 | :param best_matches_images: [b, t, nbest, row, cols, channel] 118 | :param per_example_loss_red: 119 | :param n_batch_examples: 120 | :return: 121 | """ 122 | video_rows = [] # list of (b, T, rows, cols, 3) 123 | T = best_matches_images.shape[1] 124 | 125 | if per_example_loss_red is not None: 126 | video_rows.append(per_example_loss_red) 127 | 128 | video_rows.append(copy.deepcopy(current_img[:, None].repeat(1, T, 1, 1, 1))) 129 | video_rows.append(copy.deepcopy(goal_img[:, None].repeat(1, T, 1, 1, 1))) 130 | 131 | for i in range(best_matches_images.shape[2]): 132 | video_rows.append(best_matches_images[:, :T, i]) 133 | 134 | video_rows = [v.cpu().numpy() for v in video_rows] 135 | 136 | videos = np.concatenate(video_rows, axis=2) 137 | videos = np_unstack(videos, axis=0) 138 | videos = np.concatenate(videos[:n_batch_examples], axis=2) 139 | videos = np.transpose(videos, [0, 3, 1, 2]) 140 | return videos 141 | 142 | 143 | import bridgedata 144 | 145 | class Mujoco_Renderer(): 146 | def __init__(self, im_height, im_width): 147 | from mujoco_py import load_model_from_path, MjSim 148 | 149 | mujoco_xml = '/'.join(str.split(bridgedata.__file__, '/')[:-1]) \ 150 | + '/environments/tabletop/assets/sawyer_xyz/sawyer_multiobject_textured.xml' 151 | 152 | self.sim = MjSim(load_model_from_path(mujoco_xml)) 153 | self.im_height = im_height 154 | self.im_width = im_width 155 | 156 | def render(self, qpos): 157 | sim_state = self.sim.get_state() 158 | sim_state.qpos[:] = qpos 159 | sim_state.qvel[:] = np.zeros_like(self.sim.data.qvel) 160 | self.sim.set_state(sim_state) 161 | self.sim.forward() 162 | 163 | subgoal_image = self.sim.render(self.im_height, self.im_width, camera_name='cam0') 164 | # plt.imshow(subgoal_image) 165 | # plt.savefig('test.png') 166 | return subgoal_image 167 | 
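if __name__ == '__main__':
    # A minimal usage sketch of the Logger defined above; the log directory and the tensor
    # shapes below are illustrative placeholders (log_video expects a [T, C, H, W] array).
    demo_logger = Logger('./events/train_demo')
    demo_logger.log_scalar(0.42, 'loss/total_loss', step=0, phase='train')
    demo_video = np.random.randint(0, 255, size=(10, 3, 96, 128), dtype=np.uint8)
    demo_logger.log_video(demo_video, 'control/example_rollout', step=0, phase='train')
    demo_logger.flush()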
-------------------------------------------------------------------------------- /bridgedata/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import numpy as np 4 | import moviepy.editor as mpy 5 | import os 6 | from PIL import Image 7 | 8 | def fig2img(fig): 9 | """Converts a given figure handle to a 3-channel numpy image array.""" 10 | fig.canvas.draw() 11 | w, h = fig.canvas.get_width_height() 12 | buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8) 13 | buf.shape = (w, h, 4) 14 | buf = np.roll(buf, 3, axis=2) 15 | w, h, d = buf.shape 16 | return np.array(Image.frombytes("RGBA", (w, h), buf.tostring()), dtype=np.float32)[:, :, :3] / 255. 17 | 18 | 19 | def fig2img_(fig): 20 | """Converts a given figure handle to a 3-channel numpy image array.""" 21 | fig.canvas.draw() 22 | 23 | # Get the RGBA buffer from the figure 24 | w, h = fig.canvas.get_width_height() 25 | buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8) 26 | buf.shape = (w, h, 4) 27 | 28 | # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode 29 | buf = np.roll(buf, 3, axis=2) 30 | 31 | return buf 32 | 33 | 34 | def plot_graph(array, h=400, w=400, dpi=10, linewidth=3.0): 35 | fig = plt.figure(figsize=(w / dpi, h / dpi), dpi=dpi) 36 | if isinstance(array, torch.Tensor): 37 | array = array.cpu().numpy() 38 | plt.xlim(0, array.shape[0] - 1) 39 | plt.xticks(fontsize=100) 40 | plt.yticks(fontsize=100) 41 | plt.plot(array) 42 | plt.grid() 43 | plt.tight_layout() 44 | fig_img = fig2img(fig) 45 | plt.close(fig) 46 | return fig_img 47 | 48 | def plot_bar(array, h=400, w=400, dpi=10, linewidth=3.0): 49 | fig = plt.figure(figsize=(w / dpi, h / dpi), dpi=dpi) 50 | if isinstance(array, torch.Tensor): 51 | array = array.cpu().numpy() 52 | plt.xlim(0, array.shape[0] - 1) 53 | plt.xticks(fontsize=100) 54 | plt.yticks(fontsize=100) 55 | plt.bar(np.arange(array.shape[0]), array) 56 | plt.grid() 57 | plt.tight_layout() 58 | fig_img = fig2img(fig) 59 | plt.close(fig) 60 | return fig_img 61 | 62 | def npy_to_gif(im_list, filename, fps=4): 63 | save_dir = '/'.join(str.split(filename, '/')[:-1]) 64 | if not os.path.exists(save_dir): 65 | print('creating directory: ', save_dir) 66 | os.makedirs(save_dir) 67 | clip = mpy.ImageSequenceClip(im_list, fps=fps) 68 | clip.write_gif(filename + '.gif') 69 | 70 | def npy_to_mp4(im_list, filename, fps=4): 71 | save_dir = '/'.join(str.split(filename, '/')[:-1]) 72 | 73 | if not os.path.exists(save_dir): 74 | print('creating directory: ', save_dir) 75 | os.mkdir(save_dir) 76 | 77 | clip = mpy.ImageSequenceClip(im_list, fps=fps) 78 | clip.write_videofile(filename + '.mp4') 79 | 80 | 81 | from PIL import Image, ImageDraw 82 | 83 | def draw_text_image(text, background_color=(255,255,255), image_size=(30, 64), dtype=np.float32): 84 | 85 | text_image = Image.new('RGB', image_size[::-1], background_color) 86 | draw = ImageDraw.Draw(text_image) 87 | if text: 88 | draw.text((4, 0), text, fill=(0, 0, 0)) 89 | if dtype == np.float32: 90 | return np.array(text_image).astype(np.float32)/255. 
91 | else: 92 | return np.array(text_image) 93 | 94 | 95 | def draw_text_onimage(text, image, color=(255, 0, 0)): 96 | if image.dtype == np.float32: 97 | image = (image*255.).astype(np.uint8) 98 | assert image.dtype == np.uint8 99 | from PIL import Image, ImageDraw 100 | text_image = Image.fromarray(image) 101 | draw = ImageDraw.Draw(text_image) 102 | draw.text((4, 0), text, fill=color) 103 | return np.array(text_image).astype(np.float32)/255. 104 | 105 | import cv2 106 | 107 | def visualize_barplot_array(input_arr, img_size=(64, 64)): 108 | plt.switch_backend('agg') 109 | imgs = [] 110 | for b in range(input_arr.shape[0]): 111 | img = plot_bar(input_arr[b]) 112 | img = cv2.resize(img, (img_size[1], img_size[0]), interpolation=cv2.INTER_CUBIC) 113 | imgs.append((img*255.).astype(np.uint8)) 114 | return imgs 115 | 116 | if __name__ == '__main__': 117 | sigmodis = np.random.random_integers(0, 1, [10, 10]) 118 | visualize_barplot_array(sigmodis) 119 | -------------------------------------------------------------------------------- /bridgedata_experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanlai00/bridge_data_imitation_learning/d81a3d9181672f5f26dfbd1844d3017cf7a11367/bridgedata_experiments/__init__.py -------------------------------------------------------------------------------- /bridgedata_experiments/bc_fromscratch/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from bridgedata.models.gcbc_images import GCBCImages 3 | import numpy as np 4 | from bridgedata.utils.general_utils import AttrDict 5 | current_dir = os.path.dirname(os.path.realpath(__file__)) 6 | from widowx_envs.utils.datautils.lmdb_dataloader import LMDB_Dataset_Pandas 7 | 8 | configuration = AttrDict( 9 | main=AttrDict( 10 | model=GCBCImages, 11 | max_iterations=400000, 12 | ), 13 | ) 14 | 15 | sponge_wipe = AttrDict( 16 | name='sponge_wipe', 17 | random_crop=[96, 128], 18 | color_augmentation=0.1, 19 | image_size_beforecrop=[112, 144], 20 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/berkeley/realkitchen1_counter/', 21 | filtering_function=[lambda dframe: dframe[(dframe['policy_desc'] == 'human_demo, pick up sponge and wipe plate')]], 22 | ) 23 | 24 | data_config = AttrDict( 25 | main=AttrDict( 26 | dataclass=LMDB_Dataset_Pandas, 27 | dataconf=sponge_wipe, 28 | ) 29 | ) 30 | 31 | model_config = AttrDict( 32 | main=AttrDict( 33 | action_dim=7, 34 | state_dim=7, 35 | resnet='resnet34', 36 | img_sz=[96, 128] 37 | ) 38 | ) -------------------------------------------------------------------------------- /bridgedata_experiments/dataset_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from bridgedata.utils.general_utils import AttrDict 3 | from widowx_envs.utils.datautils.lmdb_dataloader import LMDB_Dataset_Pandas, LMDB_Dataset_Success_Classifier 4 | 5 | TOTAL_NUM_TASKS=78 6 | TOTAL_NUM_TASKS_ALIASING=66 7 | 8 | task_name_aliasing_dict = { 9 | "human_demo, put sweet_potato in pot which is in sink": "human_demo, put sweet potato in pot", 10 | "human_demo, put cup from counter or drying rack into sink": "human_demo, put cup from anywhere into sink", 11 | "human_demo, put eggplant in pot or pan": "human_demo, put eggplant into pot or pan", 12 | "human_demo, put eggplant into pan": "human_demo, put eggplant into pot or pan", 13 | "human_demo, put green squash in pot or pan ": "human_demo, put green squash 
into pot or pan", 14 | "human_demo, put pot in sink": "human_demo, put pot or pan in sink", 15 | "human_demo, put pan in sink": "human_demo, put pot or pan in sink", 16 | "human_demo, put pan from stove to sink": "human_demo, put pot or pan in sink", 17 | "human_demo, put pan from drying rack into sink": "human_demo, put pot or pan in sink", 18 | "human_demo, put pan on stove from sink": "human_demo, put pot or pan on stove", 19 | "human_demo, put pot on stove which is near stove": "human_demo, put pot or pan on stove", 20 | "human_demo, put pan from sink into drying rack": "human_demo, put pot or pan from sink into drying rack", 21 | 'human_demo, open small 4-flap box flaps': 'human_demo, open box', 22 | 'human_demo, open white 1-flap box flap': 'human_demo, open box', 23 | 'human_demo, open brown 1-flap box flap': 'human_demo, open box', 24 | 'human_demo, open large 4-flap box flaps': 'human_demo, open box', 25 | 'human_demo, close small 4-flap box flaps': 'human_demo, close box', 26 | 'human_demo, close white 1-flap box flap': 'human_demo, close box', 27 | 'human_demo, close brown 1-flap box flap': 'human_demo, close box', 28 | 'human_demo, close large 4-flap box flaps': 'human_demo, close box', 29 | 'human_demo, put pepper in pan': 'human_demo, put pepper in pot or pan', 30 | } 31 | 32 | excluded_dirs = ['initial_testconfig', 'cropped', 'initial_test_config', 'put_eggplant_in_pot_or_pan','chest', 33 | 'put_big_spoon_from_basket_to_tray', 'put_small_spoon_from_basket_to_tray', 34 | 'put_fork_from_basket_to_tray'] 35 | 36 | bridge_data_config = AttrDict( 37 | name='alldata', 38 | random_crop=[96, 128], 39 | color_augmentation=0.1, 40 | image_size_beforecrop=[112, 144], 41 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 42 | excluded_dirs=excluded_dirs, 43 | aliasing_dict=task_name_aliasing_dict, 44 | ) 45 | 46 | toysink1_room8052 = AttrDict( 47 | name='toysink1_room8052', 48 | random_crop=[96, 128], 49 | color_augmentation=0.1, 50 | image_size_beforecrop=[112, 144], 51 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 52 | filtering_function=[lambda dframe: dframe[(dframe['environment'] == 'toysink1_room8052')]], 53 | excluded_dirs=excluded_dirs, 54 | aliasing_dict=task_name_aliasing_dict, 55 | ) 56 | 57 | validation_conf_toysink1_room8052 = AttrDict( 58 | val0=AttrDict( 59 | dataclass=LMDB_Dataset_Pandas, 60 | dataconf=toysink1_room8052 61 | ), 62 | ) 63 | 64 | toysink3_bww = AttrDict( 65 | name='toysink3_bww', 66 | random_crop=[96, 128], 67 | color_augmentation=0.1, 68 | image_size_beforecrop=[112, 144], 69 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 70 | filtering_function=[lambda dframe: dframe[(dframe['environment'] == 'toysink3_bww')]], 71 | excluded_dirs=excluded_dirs, 72 | aliasing_dict=task_name_aliasing_dict, 73 | ) 74 | 75 | validation_conf_toysink3= AttrDict( 76 | val0=AttrDict( 77 | dataclass=LMDB_Dataset_Pandas, 78 | dataconf=toysink3_bww 79 | ), 80 | ) 81 | 82 | toykitchen2 = AttrDict( 83 | name='toykitchen2', 84 | random_crop=[96, 128], 85 | color_augmentation=0.1, 86 | image_size_beforecrop=[112, 144], 87 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 88 | filtering_function=[lambda dframe: dframe[(dframe['environment'] == 'toykitchen2_room8052') | (dframe['environment'] == 'toykitchen2')]], 89 | excluded_dirs=excluded_dirs, 90 | aliasing_dict=task_name_aliasing_dict, 91 | ) 92 | 93 | toykitchen2_room8052 = AttrDict( 94 | name='toykitchen2_room8052', 95 | random_crop=[96, 128], 96 | 
color_augmentation=0.1, 97 | image_size_beforecrop=[112, 144], 98 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 99 | filtering_function=[lambda dframe: dframe[(dframe['environment'] == 'toykitchen2_room8052')]], 100 | excluded_dirs=excluded_dirs, 101 | aliasing_dict=task_name_aliasing_dict, 102 | ) 103 | 104 | toykitchen1 = AttrDict( 105 | name='toykitchen1', 106 | random_crop=[96, 128], 107 | color_augmentation=0.1, 108 | image_size_beforecrop=[112, 144], 109 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/berkeley/toykitchen1', 110 | aliasing_dict=task_name_aliasing_dict, 111 | excluded_dirs=excluded_dirs, 112 | ) 113 | 114 | excl_toykitchen2 = AttrDict( 115 | name='excl_toykitchen2', 116 | random_crop=[96, 128], 117 | color_augmentation=0.1, 118 | image_size_beforecrop=[112, 144], 119 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 120 | filtering_function=[lambda dframe: dframe[(dframe['environment'] != 'toykitchen2_room8052') & (dframe['environment'] != 'toykitchen2')]], 121 | aliasing_dict=task_name_aliasing_dict, 122 | excluded_dirs=excluded_dirs, 123 | ) 124 | 125 | excl_toykitchen1 = AttrDict( 126 | name='excl_toykitchen1', 127 | random_crop=[96, 128], 128 | color_augmentation=0.1, 129 | image_size_beforecrop=[112, 144], 130 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 131 | filtering_function=[lambda dframe: dframe[(dframe['environment'] != 'toykitchen_bww') & (dframe['environment'] != 'toykitchen1')]], 132 | aliasing_dict=task_name_aliasing_dict, 133 | excluded_dirs=excluded_dirs, 134 | ) 135 | 136 | excl_toysink3= AttrDict( 137 | name='excl_toysink3', 138 | random_crop=[96, 128], 139 | color_augmentation=0.1, 140 | image_size_beforecrop=[112, 144], 141 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 142 | excluded_dirs=['initial_testconfig', 'cropped', 'initial_test_config', 'put_eggplant_in_pot_or_pan'], 143 | filtering_function=[lambda dframe: dframe[(dframe['environment'] != 'toysink3_bww')]], 144 | aliasing_dict=task_name_aliasing_dict, 145 | ) 146 | 147 | validation_conf_toykitchen2_room8052 = AttrDict( 148 | val0=AttrDict( 149 | dataclass=LMDB_Dataset_Pandas, 150 | dataconf=toykitchen2_room8052 151 | ), 152 | ) 153 | 154 | 155 | -------------------------------------------------------------------------------- /bridgedata_experiments/random_mixing_task_id/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from bridgedata.models.gcbc_images import GCBCImages 3 | from bridgedata.utils.general_utils import AttrDict 4 | current_dir = os.path.dirname(os.path.realpath(__file__)) 5 | from bridgedata_experiments.dataset_lmdb import task_name_aliasing_dict 6 | from widowx_envs.utils.datautils.lmdb_dataloader import LMDB_Dataset_Pandas 7 | from bridgedata.data_sets.multi_dataset_loader import RandomMixingDatasetLoader 8 | 9 | configuration = AttrDict( 10 | main=AttrDict( 11 | model=GCBCImages, 12 | max_iterations=400000, 13 | ), 14 | ) 15 | 16 | sponge_wipe = AttrDict( 17 | name='sponge_wipe', 18 | random_crop=[96, 128], 19 | color_augmentation=0.1, 20 | image_size_beforecrop=[112, 144], 21 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/berkeley/realkitchen1_counter/', 22 | filtering_function=[lambda dframe: dframe[(dframe['policy_desc'] == 'human_demo, pick up sponge and wipe plate')]], 23 | aliasing_dict=task_name_aliasing_dict, 24 | ) 25 | 26 | validation_sponge_wipe = AttrDict( 27 | val0=AttrDict( 28 | 
dataclass=LMDB_Dataset_Pandas, 29 | dataconf=sponge_wipe 30 | ), 31 | ) 32 | 33 | excl_real_kitchen_and_toolchest= AttrDict( 34 | name='excl_real_kitchen_and_toolchest', 35 | random_crop=[96, 128], 36 | color_augmentation=0.1, 37 | image_size_beforecrop=[112, 144], 38 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam', 39 | excluded_dirs=['initial_testconfig', 'cropped', 'initial_test_config', 'put_eggplant_in_pot_or_pan', 'realkitchen1_counter', 'realkitchen1_dishwasher', 'tool_chest'], 40 | aliasing_dict=task_name_aliasing_dict, 41 | ) 42 | 43 | source_data = sponge_wipe 44 | validation_data = validation_sponge_wipe 45 | bridge_data = excl_real_kitchen_and_toolchest 46 | 47 | data_config = AttrDict( 48 | main=AttrDict( 49 | dataclass=RandomMixingDatasetLoader, 50 | dataconf=AttrDict( 51 | dataset0=[ 52 | LMDB_Dataset_Pandas, 53 | source_data, 54 | 0.1 55 | ], 56 | dataset1=[ 57 | LMDB_Dataset_Pandas, 58 | bridge_data, 59 | 0.9 60 | ], 61 | ), 62 | **validation_data 63 | ) 64 | ) 65 | 66 | model_config = AttrDict( 67 | main=AttrDict( 68 | action_dim=7, 69 | state_dim=7, 70 | resnet='resnet34', 71 | task_id_conditioning=72, 72 | img_sz=[96, 128] 73 | ) 74 | ) -------------------------------------------------------------------------------- /bridgedata_experiments/random_mixing_task_id/conf_toykitchen1.py: -------------------------------------------------------------------------------- 1 | import os 2 | from bridgedata.models.gcbc_images import GCBCImages 3 | from bridgedata.utils.general_utils import AttrDict 4 | current_dir = os.path.dirname(os.path.realpath(__file__)) 5 | from bridgedata_experiments.dataset_lmdb import task_name_aliasing_dict 6 | from widowx_envs.utils.datautils.lmdb_dataloader import LMDB_Dataset_Pandas 7 | from bridgedata.data_sets.multi_dataset_loader import RandomMixingDatasetLoader 8 | 9 | configuration = AttrDict( 10 | main=AttrDict( 11 | model=GCBCImages, 12 | max_iterations=400000, 13 | ), 14 | ) 15 | 16 | bridge_data_config_kitchen1_aliasing = AttrDict( 17 | name='toykitchen1', 18 | random_crop=[96, 128], 19 | color_augmentation=0.1, 20 | image_size_beforecrop=[112, 144], 21 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 22 | excluded_dirs=['initial_testconfig', 'cropped', 'initial_test_config', 'put_eggplant_in_pot_or_pan', 'tool_chest', 'from_basket_to_tray', 'realkitchen1'], 23 | filtering_function=[lambda dframe: dframe[(dframe['environment'] == 'toykitchen1') | (dframe['environment'] == 'toykitchen_bww')]], 24 | aliasing_dict=task_name_aliasing_dict, 25 | ) 26 | 27 | validation_conf_toykitchen1_aliasing = AttrDict( 28 | val0=AttrDict( 29 | dataclass=LMDB_Dataset_Pandas, 30 | dataconf=bridge_data_config_kitchen1_aliasing 31 | ), 32 | ) 33 | 34 | bridge_data_config = AttrDict( 35 | name='alldata', 36 | random_crop=[96, 128], 37 | color_augmentation=0.1, 38 | image_size_beforecrop=[112, 144], 39 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam', 40 | excluded_dirs=['initial_testconfig', 'cropped', 'initial_test_config', 'put_eggplant_in_pot_or_pan', 'tool_chest', 'from_basket_to_tray', 'realkitchen1'], 41 | aliasing_dict=task_name_aliasing_dict, 42 | ) 43 | 44 | source_data = bridge_data_config_kitchen1_aliasing 45 | validation_data = validation_conf_toykitchen1_aliasing 46 | 47 | data_config = AttrDict( 48 | main=AttrDict( 49 | dataclass=RandomMixingDatasetLoader, 50 | dataconf=AttrDict( 51 | dataset0=[ 52 | LMDB_Dataset_Pandas, 53 | source_data, 54 | 0.3 55 | ], 56 | dataset1=[ 57 | LMDB_Dataset_Pandas, 58 | 
bridge_data_config, 59 | 0.7 60 | ], 61 | ), 62 | **validation_data 63 | ) 64 | ) 65 | 66 | model_config = AttrDict( 67 | main=AttrDict( 68 | action_dim=7, 69 | state_dim=7, 70 | resnet='resnet34', 71 | task_id_conditioning=70, 72 | img_sz=[96, 128] 73 | ) 74 | ) -------------------------------------------------------------------------------- /bridgedata_experiments/task_id_conditioned/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from bridgedata.models.gcbc_images import GCBCImages 3 | import numpy as np 4 | from bridgedata.utils.general_utils import AttrDict 5 | current_dir = os.path.dirname(os.path.realpath(__file__)) 6 | from bridgedata_experiments.dataset_lmdb import TOTAL_NUM_TASKS_ALIASING, task_name_aliasing_dict, bridge_data_config 7 | from widowx_envs.utils.datautils.lmdb_dataloader import LMDB_Dataset_Pandas 8 | 9 | configuration = AttrDict( 10 | main=AttrDict( 11 | model=GCBCImages, 12 | max_iterations=400000, 13 | ), 14 | ) 15 | 16 | bridge_data_config_kitchen1_aliasing = AttrDict( 17 | name='toykitchen1', 18 | random_crop=[96, 128], 19 | color_augmentation=0.1, 20 | image_size_beforecrop=[112, 144], 21 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 22 | filtering_function=[lambda dframe: dframe[(dframe['environment'] == 'toykitchen1')]], 23 | aliasing_dict=task_name_aliasing_dict, 24 | ) 25 | 26 | validation_conf_toykitchen1_aliasing = AttrDict( 27 | val0=AttrDict( 28 | dataclass=LMDB_Dataset_Pandas, 29 | dataconf=bridge_data_config_kitchen1_aliasing 30 | ), 31 | ) 32 | 33 | data_config = AttrDict( 34 | main=AttrDict( 35 | dataclass=LMDB_Dataset_Pandas, 36 | dataconf=bridge_data_config, 37 | **validation_conf_toykitchen1_aliasing 38 | ) 39 | ) 40 | 41 | model_config = AttrDict( 42 | main=AttrDict( 43 | action_dim=7, 44 | state_dim=7, 45 | resnet='resnet34', 46 | task_id_conditioning=TOTAL_NUM_TASKS_ALIASING, 47 | img_sz=[96, 128] 48 | ) 49 | ) -------------------------------------------------------------------------------- /bridgedata_experiments/task_id_conditioned/conf_exclude_toykitchen1.py: -------------------------------------------------------------------------------- 1 | import os 2 | from bridgedata.models.gcbc_images import GCBCImages 3 | import numpy as np 4 | from bridgedata.utils.general_utils import AttrDict 5 | current_dir = os.path.dirname(os.path.realpath(__file__)) 6 | from bridgedata_experiments.dataset_lmdb import TOTAL_NUM_TASKS_ALIASING, bridge_data_config, task_name_aliasing_dict 7 | from widowx_envs.utils.datautils.lmdb_dataloader import LMDB_Dataset_Pandas 8 | 9 | configuration = AttrDict( 10 | main=AttrDict( 11 | model=GCBCImages, 12 | max_iterations=400000, 13 | ), 14 | ) 15 | 16 | bridge_data_config_kitchen1_aliasing = AttrDict( 17 | name='toykitchen1', 18 | random_crop=[96, 128], 19 | color_augmentation=0.1, 20 | image_size_beforecrop=[112, 144], 21 | data_dir=os.environ['DATA'] + '/robonetv2/toykitchen_fixed_cam/', 22 | filtering_function=[lambda dframe: dframe[(dframe['environment'] == 'toykitchen1') | (dframe['environment'] == 'toykitchen_bww')]], 23 | aliasing_dict=task_name_aliasing_dict, 24 | ) 25 | 26 | validation_conf_toykitchen1_aliasing = AttrDict( 27 | val0=AttrDict( 28 | dataclass=LMDB_Dataset_Pandas, 29 | dataconf=bridge_data_config_kitchen1_aliasing 30 | ), 31 | ) 32 | 33 | data_config = AttrDict( 34 | main=AttrDict( 35 | dataclass=LMDB_Dataset_Pandas, 36 | dataconf=bridge_data_config, 37 | **validation_conf_toykitchen1_aliasing 38 | ) 39 | ) 40 | 41 | 
model_config = AttrDict( 42 | main=AttrDict( 43 | action_dim=7, 44 | state_dim=7, 45 | resnet='resnet34', 46 | task_id_conditioning=TOTAL_NUM_TASKS_ALIASING, 47 | img_sz=[96, 128] 48 | ) 49 | ) -------------------------------------------------------------------------------- /docker/azure/Dockerfile: -------------------------------------------------------------------------------- 1 | # We need the CUDA base dockerfile to enable GPU rendering 2 | # on hosts with GPUs. 3 | # The image below is pinned to nvidia/cuda:10.1-cudnn7-devel-ubuntu16.04. 4 | # If updating the base image, be sure to test on GPU since it has broken in the past. 5 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu16.04 6 | 7 | SHELL ["/bin/bash", "-c"] 8 | 9 | ########################################################## 10 | ### System dependencies 11 | ########################################################## 12 | 13 | # Install the system packages needed for building and rendering; Python itself is installed later via Miniconda 14 | RUN apt-get update -q 15 | RUN DEBIAN_FRONTEND=noninteractive apt-get install -y \ 16 | cmake \ 17 | curl \ 18 | git \ 19 | libav-tools \ 20 | libgl1-mesa-dev \ 21 | libgl1-mesa-glx \ 22 | libglew-dev \ 23 | libosmesa6-dev \ 24 | net-tools \ 25 | software-properties-common \ 26 | swig \ 27 | unzip \ 28 | vim \ 29 | wget \ 30 | xpra \ 31 | xserver-xorg-dev \ 32 | zlib1g-dev 33 | RUN apt-get clean 34 | RUN rm -rf /var/lib/apt/lists/* 35 | 36 | 37 | # Use a UTF-8 locale; some Python tools fail under the default ASCII locale 38 | ENV LANG C.UTF-8 39 | 40 | # Xdummy provides a dummy X server for headless rendering 41 | COPY ./files/Xdummy /usr/local/bin/Xdummy 42 | RUN chmod +x /usr/local/bin/Xdummy 43 | 44 | # Workaround for https://bugs.launchpad.net/ubuntu/+source/nvidia-graphics-drivers-375/+bug/1674677 45 | COPY ./files/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json 46 | 47 | # Make the NVIDIA driver libraries visible to dynamically linked programs 48 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH} 49 | 50 | ########################################################## 51 | ### Example Python Installation 52 | ########################################################## 53 | ENV PATH /opt/conda/bin:$PATH 54 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \ 55 | /bin/bash /tmp/miniconda.sh -b -p /opt/conda && \ 56 | rm /tmp/miniconda.sh && \ 57 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 58 | echo ". 
/opt/conda/etc/profile.d/conda.sh" >> /etc/bash.bashrc 59 | 60 | RUN conda update -y --name base conda && conda clean --all -y 61 | 62 | RUN conda create --name bridgedata python=3.6.8 pip 63 | RUN echo "source activate bridgedata" >> ~/.bashrc 64 | ENV OLDPATH $PATH 65 | ENV PATH /opt/conda/envs/bridgedata/bin:$PATH 66 | 67 | # Install packages here 68 | COPY ./docker_requirements.txt requirements.txt 69 | RUN pip install -r requirements.txt 70 | 71 | ENV EXP /data/experiments 72 | ENV DATA /data/trainingdata 73 | 74 | 75 | -------------------------------------------------------------------------------- /docker/azure/docker_requirements.txt: -------------------------------------------------------------------------------- 1 | pyarrow 2 | lmdb 3 | pybullet==2.8.7 4 | cloudpickle==1.3.0 5 | decorator 6 | funcsigs 7 | future 8 | imageio 9 | imageio-ffmpeg 10 | imutils==0.4.6 11 | matplotlib 12 | more-itertools 13 | moviepy 14 | numpy 15 | opencv-python==4.2.0.34 16 | pandas 17 | Pillow 18 | pyquaternion 19 | scikit-image 20 | scipy 21 | six 22 | tensorboard 23 | requests 24 | h5py 25 | torch 26 | torchvision 27 | tensorboardX 28 | nvidia_smi 29 | rospkg 30 | modern_robotics 31 | gym 32 | tqdm 33 | transformations 34 | wandb 35 | -------------------------------------------------------------------------------- /docker/azure/doodad_launch.py: -------------------------------------------------------------------------------- 1 | from doodad.wrappers.easy_launch import sweep_function, save_doodad_config 2 | from bridgedata.train import ModelTrainer 3 | import argparse 4 | import os 5 | 6 | def train(doodad_config, variant): 7 | args = argparse.Namespace() 8 | d = vars(args) 9 | for key, val in variant.items(): 10 | d[key] = val 11 | ModelTrainer(args) 12 | 13 | save_doodad_config(doodad_config) 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--path", help="path to the config file directory", required=True) 18 | parser.add_argument("--prefix", help="experiment prefix, if given creates subfolder in experiment directory", required=True) 19 | parser.add_argument("--dry", action='store_true', help="dry run, local no doodad") 20 | args = parser.parse_args() 21 | 22 | if not os.path.isabs(args.path): 23 | raise ValueError('experiment path must be absolute!') 24 | 25 | params_to_sweep = {} 26 | default_params = { 27 | 'gpu': -1, 28 | 'strict_weight_loading': True, 29 | 'deterministic': False, 30 | 'cpu': False, 31 | 'metric': False, 32 | 'no_val': False, 33 | 'no_train': False, 34 | 'skip_main': False, 35 | 'resume': '', 36 | 'new_dir': True, 37 | 'path': args.path, 38 | 'prefix': args.prefix, 39 | 'data_config_override': None, 40 | 'source_data_config_override': None, 41 | 'load_task_indices': None, 42 | } 43 | if args.dry: 44 | mode = 'here_no_doodad' 45 | use_gpu = True 46 | else: 47 | mode = 'azure' 48 | use_gpu = True 49 | sweep_function( 50 | train, 51 | params_to_sweep, 52 | default_params=default_params, 53 | log_path=args.prefix, 54 | mode=mode, 55 | use_gpu=use_gpu, 56 | num_gpu=1, 57 | ) 58 | 59 | 60 | -------------------------------------------------------------------------------- /docker/azure/files/10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | "file_format_version" : "1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | pyarrow 2 | lmdb 3 | pybullet==2.8.7 4 | cloudpickle==1.3.0 5 | decorator 6 | funcsigs 7 | future 8 | imageio 9 | imageio-ffmpeg 10 | imutils==0.4.6 11 | matplotlib 12 | more-itertools 13 | moviepy 14 | numpy 15 | opencv-python==4.2.0.34 16 | pandas 17 | Pillow 18 | pyquaternion 19 | scikit-image 20 | scipy 21 | six 22 | tensorboard 23 | requests 24 | h5py 25 | torch 26 | torchvision 27 | tensorboardX 28 | nvidia_smi 29 | rospkg 30 | modern_robotics 31 | gym 32 | tqdm 33 | transformations 34 | wandb 35 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | setup( 5 | name='bridgedata', 6 | version='0.2dev', 7 | packages=['bridgedata'], 8 | license='MIT License', 9 | ) --------------------------------------------------------------------------------
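
A note on how the dataset configs above compose: each dataset `AttrDict` hands `LMDB_Dataset_Pandas` an optional list of `filtering_function` callables (each maps the trajectory-metadata DataFrame to a filtered DataFrame) and an `aliasing_dict` that collapses near-duplicate task names into a single label. The snippet below is a minimal sketch of that semantics on a hypothetical metadata DataFrame: the column names `environment` and `policy_desc` are taken from the lambdas in the configs, while the toy rows and the stand-alone application loop are illustrative assumptions, not the actual loader code (which lives in `widowx_envs`).

```
import pandas as pd

# Hypothetical alias map: a small subset of task_name_aliasing_dict from dataset_lmdb.py.
task_name_aliasing_dict = {
    'human_demo, put pan in sink': 'human_demo, put pot or pan in sink',
}

# Hypothetical per-trajectory metadata; the real loader builds this from the LMDB files.
metadata = pd.DataFrame({
    'environment': ['toysink1_room8052', 'toykitchen2_room8052', 'toysink1_room8052'],
    'policy_desc': ['human_demo, put pan in sink',
                    'human_demo, open box',
                    'human_demo, put pepper in pot or pan'],
})

# A filtering_function entry exactly as written in the configs: a list of callables,
# each mapping the metadata DataFrame to a filtered DataFrame.
filtering_function = [lambda dframe: dframe[(dframe['environment'] == 'toysink1_room8052')]]
for f in filtering_function:
    metadata = f(metadata)

# Task names that appear in the aliasing dict are remapped; all others pass through.
metadata['policy_desc'] = metadata['policy_desc'].map(
    lambda name: task_name_aliasing_dict.get(name, name))

print(metadata)
# Rows 0 and 2 remain, with 'put pan in sink' renamed to 'put pot or pan in sink'.
```

Under this reading, the third element of each `dataset0`/`dataset1` entry in the random-mixing configs (0.1/0.9 and 0.3/0.7 above) is presumably the mixing weight that `RandomMixingDatasetLoader` uses when sampling batch elements from the small target dataset versus the larger bridge dataset.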