├── .gitignore ├── README.md ├── collect_demons ├── README.md ├── __init__.py ├── demons_config.py ├── imitate_play.py └── main.py ├── dataset_env ├── README.md ├── __init__.py ├── data_aug.py ├── data_config.py ├── deg_base.py ├── file_storage.py ├── organize.sh ├── rlbench_deg.py └── surreal_deg.py ├── db ├── README.md ├── db │ ├── __init__.py │ ├── asgi.py │ ├── settings.py │ ├── urls.py │ └── wsgi.py ├── manage.py ├── static │ ├── css │ │ └── foo │ └── js │ │ └── foo ├── templates │ ├── base.html │ └── vid.html └── traj_db │ ├── __init__.py │ ├── admin.py │ ├── apps.py │ ├── migrations │ └── __init__.py │ ├── models.py │ ├── tests.py │ └── views.py ├── global_config.py ├── model ├── __init__.py ├── layers.py ├── main.py ├── model_config.py ├── models.py ├── test.py └── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | sftp-config.json 3 | db/static/media 4 | web_db/static/media 5 | dataset_env/surreal 6 | dataset_env/furniture 7 | dataset_env/rlbench 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Boilerplate For Data Driven Robotics 2 | & a PyTorch implementation of [Learning Latent Plans from Play](https://learning-from-play.github.io/). 3 | 4 | This repo is meant to provide organized & scalable experimentation for data-driven robot learning. You can adapt it to your own model and environment with minor modifications. 5 | 6 | ## Organization 7 | ### Modules 8 | This setup consists of a database (`db/`), inspired by [[1]](https://arxiv.org/abs/1909.12200), storing meta-data of the collected trajectories, and a light web-app which renders a video of a trajectory. The DEG module (`dataset_env/`) provides easy adaptation to various environments, dataloaders (`deg_base.py`), and a simple interface to interact with the DB and store/retrieve trajectories (`file_storage.py`) - all bundled up. The current implementation includes support for [RLBench](https://github.com/stepjam/RLBench/) and (older) [Robosuite](https://github.com/ARISE-Initiative/robosuite) environments. The collection module (`collect_demons/`) provides data-collection mechanisms such as teleoperation and imitation policies. Every new model can have its own directory; the current `model/` contains a PyTorch implementation of [LfP](https://learning-from-play.github.io/). The training and testing code is defined in `model/` too. 9 | 10 | Additional information about each module is provided in their respective READMEs. 11 | ### Configs 12 | Config common to all the modules is defined in `global_config.py`. Each of the other modules has its own config file (`*_config.py`) which adds to the global config. The config system is designed to update automatically on minor edits (e.g. a change in `env` updates all the paths and other env-related properties). 13 | -------------------------------------------------------------------------------- /collect_demons/README.md: -------------------------------------------------------------------------------- 1 | # Data Collection 2 | 3 | This module provides a pit-stop to collect trajectories. The trajectories are stored as pickle files, and their identifiers, paths and other meta-data are stored in the DB. The `collect_by` config defined in `demons_config` specifies how you want to collect the data - by teleoperation, a specific/random policy, or an imitation-based policy trained on the data. The imitation policy is defined as an RNN-based Gaussian policy in `imitate_play.py`. Since teleoperation and collection of random trajectories are environment-specific, they are defined in the environment's corresponding DEG, but they should be run from the `main.py` here.
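For reference, a trajectory handed to `file_storage.store_trajectoy` is simply a dict of per-step arrays keyed by the active DEG's observation keys plus `'action'`. A minimal sketch follows (the keys shown are the Surreal ones; the shapes are only illustrative placeholders):

```python
# Hedged sketch of the trajectory format expected by dataset_env/file_storage.py.
# Keys follow the active DEG ('image' / 'robot-state' for Surreal); shapes are illustrative.
import numpy as np
from file_storage import store_trajectoy  # dataset_env/ must be on sys.path

trajectory = {
    'image': np.zeros((50, 256, 256, 3), dtype=np.uint8),  # visual observation per step
    'robot-state': np.zeros((50, 40)),                      # proprioceptive observation per step
    'action': np.zeros((50, 8)),                            # joint velocities + gripper actuation
}
store_trajectoy(trajectory, episode_type='teleop')  # pickles the dict and saves a DB metadata row
```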
4 | -------------------------------------------------------------------------------- /collect_demons/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhruvramani/data-driven-robotics/059ac21d516cffce291e2e8ac965c9da44969114/collect_demons/__init__.py -------------------------------------------------------------------------------- /collect_demons/demons_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | 5 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')) 6 | 7 | import utils 8 | from global_config import * 9 | 10 | def get_demons_args(): 11 | parser = get_global_parser() 12 | 13 | # NOTE: 'SURREAL' is a placeholder. The deg is set according to global_config.env -> see below. v 14 | parser.add_argument('--deg', type=env2deg, default='SURREAL') 15 | parser.add_argument("--collect_by", type=str, default='teleop', choices=['teleop', 'imitation', 'expert', 'policy', 'exploration', 'random']) 16 | parser.add_argument("--device", type=str, default="keyboard", choices=["keyboard", "spacemouse"]) 17 | parser.add_argument("--collect_freq", type=int, default=1) 18 | parser.add_argument("--flush_freq", type=int, default=25) # NOTE : RAM Issues, change here : 75 19 | parser.add_argument("--break_traj_success", type=utils.str2bool, default=True) 20 | parser.add_argument("--n_runs", type=int, default=10, #10 21 | help="no. of runs of traj collection, affective when break_traj_success = False") 22 | 23 | # Imitation model 24 | parser.add_argument('--resume', type=utils.str2bool, default=False) 25 | parser.add_argument('--train_imitation', type=utils.str2bool, default=False) 26 | parser.add_argument('--models_save_path', type=str, default=os.path.join(DATA_DIR, 'runs/imitation-models/')) 27 | parser.add_argument('--tensorboard_path', type=str, default=os.path.join(DATA_DIR, 'runs/imitation-tensorboard/')) 28 | parser.add_argument('--load_models', type=utils.str2bool, default=True) 29 | parser.add_argument('--use_model_perception', type=utils.str2bool, default=True) 30 | parser.add_argument('--n_gen_traj', type=int, default=200, help="Number of trajectories to generate by imitation") 31 | 32 | config = parser.parse_args() 33 | config.env_args = env2args(config.env) 34 | config.deg = env2deg(config.env) 35 | config.data_path = os.path.join(config.data_path, '{}_{}/'.format(config.env, config.env_type)) 36 | config.models_save_path = os.path.join(config.models_save_path, '{}_{}/'.format(config.env, config.env_type)) 37 | config.tensorboard_path = os.path.join(config.tensorboard_path, '{}_{}_{}/'.format(config.env, config.env_type, config.exp_name)) 38 | 39 | if config.train_imitation and not config.resume: 40 | utils.recreate_dir(config.models_save_path, config.display_warnings) 41 | utils.recreate_dir(config.tensorboard_path, config.display_warnings) 42 | else: 43 | utils.check_n_create_dir(config.models_save_path, config.display_warnings) 44 | utils.check_n_create_dir(config.tensorboard_path, config.display_warnings) 45 | 46 | utils.check_n_create_dir(config.data_path, config.display_warnings) 47 | 48 | return config 49 | 50 | if __name__ == '__main__': 51 | args = get_demons_args() 52 | print(args.models_save_path) 53 | -------------------------------------------------------------------------------- /collect_demons/imitate_play.py: 
-------------------------------------------------------------------------------- 1 | # TODO : *IMPORTANT* - change the imitation policy to be condtioned on the GROUND STATE rather than visual obv 2 | # While running this imitated policy, collect visual_obv w/ domain randomization 3 | 4 | import os 5 | import sys 6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import tqdm # TODO : Remove TQDMs 9 | from tensorboardX import SummaryWriter 10 | from torch.utils.data import DataLoader 11 | from torch.distributions.normal import Normal 12 | 13 | from demons_config import get_demons_args 14 | 15 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../model')) 16 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../datatset_env')) 17 | 18 | from models import PerceptionModule 19 | from model_config import get_model_args 20 | 21 | from file_storage import store_trajectoy 22 | 23 | device = torch.device('gpu' if torch.cuda.is_available() else 'cpu') 24 | 25 | class ImitationPolicy(torch.nn.Module): 26 | def __init__(self, action_dim=8, state_dim=72, hidden_size=2048, batch_size=1, rnn_type='RNN', num_layers=2): 27 | super(ImitationPolicy, self).__init__() 28 | self.state_dim = state_dim 29 | self.action_dim = action_dim 30 | self.batch_size = batch_size 31 | self.hidden_size = hidden_size 32 | self.rnn_type = rnn_type.upper() 33 | self.num_layers = num_layers 34 | 35 | assert self.rnn_type in ['LSTM', 'GRU', 'RNN'] 36 | self.rnn = {'LSTM' : torch.nn.LSTMCell, 'GRU' : torch.nn.GRUCell, 'RNN' : torch.nn.RNNCell}[self.rnn_type] 37 | 38 | self.rnn1 = self.rnn(self.state_dim, self.hidden_size) 39 | self.rnn2 = self.rnn(self.hidden_size, self.hidden_size) 40 | 41 | # NOTE : Original paper used a Mixture of Logistic (MoL) dist. Implement later. 42 | self.hidden2mean = torch.nn.Linear(self.hidden_size, self.action_dim) 43 | log_std = -0.5 * np.ones(self.action_dim, dtype=np.np.float32) 44 | self.log_std = torch.nn.Parameter(torch.as_tensor(log_std)) 45 | 46 | # The hidden states of the RNN. 47 | self.relu = torch.nn.ReLU() 48 | self.tanh = torch.nn.Tanh() 49 | 50 | def _prepare_obs(self, state, perception_module): 51 | if state.size()[-1] != self.state_dim: 52 | assert perception_module is not None 53 | state = perception_module(state) 54 | 55 | # The hidden states of the RNN. 
56 | self.h1 = torch.randn(state.shape[0], self.hidden_size) 57 | self.h2 = torch.randn(state.shape[0], self.hidden_size) 58 | if self.rnn_type == 'LSTM': 59 | self.c1 = torch.randn(state.shape[0], self.hidden_size) 60 | self.c2 = torch.randn(state.shape[0], self.hidden_size) 61 | 62 | return state 63 | 64 | def _distribution(self, obs): 65 | if self.rnn_type == 'LSTM': 66 | self.h1, self.c1 = self.relu(self.rnn1(obs, (self.h1, self.c1))) 67 | self.h2, self.c2 = self.relu(self.rnn2(self.h1, (self.h2, self.c2))) 68 | else: 69 | self.h1 = self.relu(self.rnn1(obs, self.h1)) 70 | self.h2 = self.relu(self.rnn2(self.h1, self.h2)) 71 | 72 | mean = self.tanh(self.hidden2mean(self.h2)) 73 | std = torch.exp(self.log_std) 74 | return Normal(mean, std) 75 | 76 | def _log_prob_from_distribution(self, policy, action): 77 | return policy.log_prob(action).sum(axis=-1) 78 | 79 | def forward(self, state, action=None, perception_module=None): 80 | obs = self._prepare_obs(state, perception_module) 81 | policy = self._distribution(obs) 82 | 83 | logp_a = None 84 | if action is not None: 85 | logp_a = self._log_prob_from_distribution(policy, action) 86 | 87 | return policy, logp_a 88 | 89 | def step(self, state, perception_module=None): 90 | with torch.no_grad(): 91 | obs = self._prepare_obs(state, perception_module) 92 | policy = self._distribution(obs) 93 | action = policy.sample() 94 | logp_a = self._log_prob_from_distribution(policy, action) 95 | 96 | return action.numpy(), logp_a.numpy() 97 | 98 | def train_imitation(demons_config): 99 | model_config = get_model_args() 100 | 101 | deg = demons_config.deg(get_episode_type='EPISODE_ROBOT_PLAY') 102 | 103 | vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key] 104 | act_dim = deg.action_space 105 | 106 | tensorboard_writer = SummaryWriter(logdir=demons_config.tensorboard_path) 107 | perception_module = PerceptionModule(vobs_dim, dof_dim, model_config.visual_state_dim).to(device) 108 | imitation_policy = ImitationPolicy(act_dim, model_config.combined_state_dim,).to(device) 109 | 110 | params = list(perception_module.parameters()) + list(imitation_policy.parameters()) 111 | print("Number of parameters : {}".format(len(params))) 112 | 113 | optimizer = torch.optim.Adam(params, lr=model_config.learning_rate) 114 | 115 | if(demons_config.load_models): 116 | # TODO : IMPORTANT - Check if file exist before loading 117 | if demons_config.use_model_perception: 118 | perception_module.load_state_dict(torch.load(os.path.join(model_config.models_save_path, 'perception.pth'))) 119 | else : 120 | perception_module.load_state_dict(torch.load(os.path.join(demons_config.models_save_path, 'perception.pth'))) 121 | imitation_policy.load_state_dict(torch.load(os.path.join(demons_config.models_save_path, 'imitation_policy.pth'))) 122 | optimizer.load_state_dict(torch.load(os.path.join(demons_config.models_save_path, 'optimizer.pth'))) 123 | 124 | print("Run : tensorboard --logdir={} --host '0.0.0.0' --port 6006".format(demons_config.tensorboard_path)) 125 | data_loader = DataLoader(deg.traj_dataset, batch_size=model_config.batch_size, shuffle=True, num_workers=1) 126 | max_step_size = len(data_loader.dataset) 127 | 128 | for epoch in tqdm(range(model_config.max_epochs), desc="Check Tensorboard"): 129 | for i, trajectory in enumerate(data_loader): 130 | trajectory = {key : trajectory[key].float().to(device) for key in trajectory.keys()} 131 | visual_obvs, dof_obs, action = trajectory[deg.vis_obv_key], trajectory[deg.dof_obv_key], trajectory['action'] 
132 | batch_size, seq_len = visual_obvs.shape[0], visual_obvs.shape[1] 133 | 134 | visual_obvs = visual_obvs.reshape(batch_size * seq_len, vobs_dim[2], vobs_dim[0], vobs_dim[1]) 135 | dof_obs = dof_obs.reshape(batch_size * seq_len, dof_dim) 136 | actions = trajectory['action'].reshape(batch_size * seq_len, -1) 137 | 138 | states = perception_module(visual_obvs, dof_obs) # DEBUG : Might raise in-place errors 139 | 140 | pi, logp_a = imitation_policy(state=states, action=actions) 141 | 142 | optimizer.zero_grad() 143 | loss = -logp_a 144 | loss = loss.mean() 145 | 146 | tensorboard_writer.add_scalar('Clone Loss', loss, epoch * max_step_size + i) 147 | loss.backward() 148 | optimizer.step() 149 | 150 | if int(i % model_config.save_interval) == 0: 151 | if not demons_config.use_model_perception: 152 | torch.save(perception_module.state_dict(), os.path.join(demons_config.models_save_path, 'perception.pth')) 153 | torch.save(imitation_policy.state_dict(), os.path.join(demons_config.models_save_path, 'imitation_policy.pth')) 154 | torch.save(optimizer.state_dict(), os.path.join(demons_config.models_save_path, 'optimizer.pth')) 155 | 156 | 157 | def imitate_play(): 158 | model_config = get_model_args() 159 | demons_config = get_demons_args() 160 | 161 | deg = demons_config.deg() 162 | env = deg.get_env() 163 | 164 | vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key] 165 | act_dim = deg.action_space 166 | 167 | with torch.no_grad(): 168 | perception_module = PerceptionModule(vobs_dim, dof_dim, model_config.visual_state_dim).to(device) 169 | imitation_policy = ImitationPolicy(act_dim, model_config.combined_state_dim,).to(device) 170 | 171 | if demons_config.use_model_perception: 172 | perception_module.load_state_dict(torch.load(os.path.join(model_config.models_save_path, 'perception.pth'))) 173 | else : 174 | perception_module.load_state_dict(torch.load(os.path.join(demons_config.models_save_path, 'perception.pth'))) 175 | imitation_policy.load_state_dict(torch.load(os.path.join(demons_config.models_save_path, 'imitation_policy.pth'))) 176 | 177 | for run in range(demons_config.n_gen_traj): 178 | obs = env.reset() 179 | tr_vobvs, tr_dof, tr_actions = [], [], [] 180 | 181 | for step in range(demon_config.flush_freq): 182 | visual_obv, dof_obv = torch.from_numpy(obvs[deg.vis_obv_key]).float(), torch.from_numpy(obvs[deg.dof_obv_key]).float() 183 | visual_obv = visual_obv.reshape(1, visual_obv.shape[2], visual_obv.shape[0], visual_obv.shape[1]) 184 | dof_obv = dof_obv.reshape(1, dof_obv.shape[0]) 185 | state = perception_module(visual_obv, dof_obv) 186 | 187 | action, _ = imitation_policy.step(state) 188 | 189 | if int(step % demons_config.collect_freq) == 0: 190 | tr_vobvs.append(visual_obv) 191 | tr_dof.append(dof_obv) 192 | tr_actions.append(action[0]) 193 | 194 | obs, _, done, _ = env.step(action[0]) 195 | 196 | print('Storing Trajectory') 197 | trajectory = {deg.vis_obv_key : np.array(tr_vobvs), deg.dof_obv_key : np.array(tr_dof), 'action' : np.array(tr_actions)} 198 | store_trajectoy(trajectory, 'imitation') 199 | trajectory, tr_vobvs, tr_dof, tr_actions = {}, [], [], [] 200 | 201 | env.close() 202 | 203 | if __name__ == '__main__': 204 | demons_config = get_demons_args() 205 | if demons_config.train_imitation: 206 | train_imitation(demons_config) 207 | 208 | imitate_play() -------------------------------------------------------------------------------- /collect_demons/main.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | 6 | from demons_config import get_demons_args 7 | 8 | if __name__ == "__main__": 9 | demon_config = get_demons_args() 10 | deg = demon_config.deg() 11 | if demon_config.collect_by == 'teleop': 12 | # TODO remove this 13 | # tasks = ['beat_the_buzz', 'block_pyramid', 'change_channel', 'change_clock', 'close_box', 'close_door', 'close_drawer', 'close_fridge', 'close_grill', 'close_jar', 'close_laptop_lid', 'close_microwave', 'empty_container', 'empty_dishwasher', 'get_ice_from_fridge', 'hang_frame_on_hanger', 'hannoi_square', 'hit_ball_with_queue', 'hockey', 'insert_usb_in_computer', 'lamp_off', 'lamp_on', 'light_bulb_in', 'light_bulb_out', 'meat_off_grill', 'meat_on_grill', 'move_hanger', 'open_box', 'open_door', 'open_drawer', 'open_fridge', 'open_grill', 'open_jar', 'open_microwave', 'open_oven', 'open_window', 'open_wine_bottle', 'phone_on_base', 'pick_and_lift', 'pick_up_cup', 'place_cups', 'place_hanger_on_rack', 'place_shape_in_shape_sorter', 'play_jenga', 'plug_charger_in_power_supply', 'pour_from_cup_to_cup', 'press_switch', 'push_button', 'push_buttons', 'put_books_on_bookshelf', 'put_bottle_in_fridge', 'put_groceries_in_cupboard', 'put_item_in_drawer', 'put_knife_in_knife_block', 'put_knife_on_chopping_board', 'put_money_in_safe', 'put_plate_in_colored_dish_rack', 'put_rubbish_in_bin', 'put_shoes_in_box', 'put_toilet_roll_on_stand', 'put_tray_in_oven', 'put_umbrella_in_umbrella_stand', 'reach_and_drag', 'reach_target', 'remove_cups', 'scoop_with_spatula', 'screw_nail', 'set_the_table', 'setup_checkers', 'slide_block_to_target', 'slide_cabinet_open', 'slide_cabinet_open_and_place_cups', 'solve_puzzle', 'stack_blocks', 'stack_cups', 'stack_wine', 'straighten_rope', 'sweep_to_dustpan', 'take_cup_out_from_cabinet', 'take_frame_off_hanger', 'take_item_out_of_drawer', 'take_lid_off_saucepan', 'take_money_out_safe', 'take_off_weighing_scales', 'take_plate_off_colored_dish_rack', 'take_shoes_out_of_box', 'take_toilet_roll_off_stand', 'take_tray_out_of_oven', 'take_umbrella_out_of_umbrella_stand', 'take_usb_out_of_computer', 'toilet_seat_down', 'toilet_seat_up', 'turn_oven_on', 'turn_tap', 'tv_off', 'tv_on', 'unplug_charger', 'water_plants', 'weighing_scales', 'wipe_desk'] 14 | # for i, task in enumerate(tasks): 15 | # print(i, task) 16 | # try : 17 | deg.teleoperate(demon_config) #, task) 18 | # except : 19 | # print("Couldn't collect demos") 20 | elif demon_config.collect_by == 'random': 21 | deg.random_trajectory(demon_config) 22 | elif demon_config.collect_by == 'imitation': 23 | #NOTE : NOT TESTED 24 | import imitate_play 25 | 26 | if demon_config.train_imitation: 27 | imitate_play.train_imitation(demon_config) 28 | imitate_play.imitate_play() -------------------------------------------------------------------------------- /dataset_env/README.md: -------------------------------------------------------------------------------- 1 | # Datasets & Environment Groups 2 | 3 | `DataEnvGroup`s (DEGs) provide a common window to interact with environments and their collected datasets. The abstract class (`deg_base.py`) provides most of the common functionalities and has to be overriden to adopt to a new environment. Each environment has it's own DEG file which provides definitions for the abstract methods and properties - see `rlbench_deg.py` and `surreal_deg.py`. Use the DEG to define & get environments, datasets and dataloaders and observation/action spaces with a common environment-agnostic syntax. 
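For example, a typical interaction looks like the following (a minimal sketch mirroring the `__main__` block of `rlbench_deg.py`; it assumes `env` in `../global_config.py` is set to RLBench and that RLBench is installed):

```python
# Typical DEG workflow; the same calls work for any DataEnvGroup subclass.
from rlbench_deg import RLBenchDataEnvGroup

deg = RLBenchDataEnvGroup()
env = deg.get_env()                             # create the underlying environment
obs = env.reset()
state = deg._get_obs(obs, deg.dof_obv_key)      # env-agnostic observation access

loader = deg.get_traj_dataloader(batch_size=5)  # padded batches from the trajectory DB
for batch in loader:
    print(batch[deg.dof_obv_key].shape, batch['action'].shape)

deg.shutdown_env()
```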
4 | 5 | The data-based configs and the link to the DB are defined in `data_config.py`. The environment is specified in `../global_config.py` and it's modification reflects the changes in data-configs. File and DB related functionalities are defined in `file_storage.py`. Visual augmentations to improve performance, inspired from [RAD](https://mishalaskin.github.io/rad/) - are provided in `data_aug.py`. 6 | 7 | ## Adding New Environments 8 | 9 | All the envs are stored in a directory outside the repo (defined by `ENV_PATH` in each of the DEG files) because of cleanliness. If you want to modify the original environment, clone it into this directory. All the cloning and renaming code goes in `organize.sh`. 10 | 11 | To create a DEG for a new environment, create a new `envname_deg.py` file and inherit the `DataEnvGroup` class. Import the actual environment class and provide definitions for environment-specific abstract methods (like `get_env`, `teleoperate` etc.) and properties (observation space & keys). See `rlbench_deg.py` for example. 12 | 13 | After adding the DEG file, create a table for the environment in `db/` (refer to its README). Add the table info in `taj_db_dict` & `env2keys()` in `data_config.py` and add the DEG info in `env2deg()` in `../global_config.py`. That's it! The rest of the functionalities, properties and configs will adopt automatically to the change in `env` in `../global_config.py`. 14 | -------------------------------------------------------------------------------- /dataset_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhruvramani/data-driven-robotics/059ac21d516cffce291e2e8ac965c9da44969114/dataset_env/__init__.py -------------------------------------------------------------------------------- /dataset_env/data_aug.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | import time 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # NOTE : Source - RAD (https://github.com/MishaLaskin/rad) 11 | # TODO : Will test later 12 | 13 | def random_crop(images, out=192): 14 | ''' 15 | + Arguments: 16 | - images: np.array shape (N,C,H,W) 17 | - out: output size (e.g. 84) 18 | ''' 19 | n, c, h, w = images.shape 20 | crop_max = h - out + 1 21 | w1 = np.random.randint(0, crop_max, n) 22 | h1 = np.random.randint(0, crop_max, n) 23 | cropped = np.empty((n, c, out, out), dtype=images.dtype) 24 | for i, (img, w11, h11) in enumerate(zip(images, w1, h1)): 25 | cropped[i] = img[:, h11:h11 + out, w11:w11 + out] 26 | 27 | return np.reshape(cropped, (n, out, out, c)) 28 | 29 | def grayscale(images): 30 | ''' 31 | + Arguments: 32 | - images: np.array shape (B,C,H,W) 33 | ''' 34 | images = torch.from_numpy(images) 35 | device = images.device 36 | b, c, h, w = images.shape 37 | frames = c // 3 38 | 39 | images = images.view([b, frames, 3, h, w]) 40 | images = images[:, :, 0, ...] * 0.2989 + images[:, :, 1, ...] * 0.587 + images[:, :, 2, ...] 
* 0.114 41 | 42 | images = images.type(torch.uint8).float() 43 | # assert len(images.shape) == 3, images.shape 44 | images = images[:, :, None, :, :] 45 | images = images.numpy() * np.ones([1, 1, 3, 1, 1]) 46 | return np.reshape(images, (b, h, w, 1)) 47 | 48 | def random_grayscale(images, probab=0.3): 49 | ''' 50 | + Arguments: 51 | - images: np.array shape (B,C,H,W) 52 | - probab 53 | ''' 54 | images = torch.from_numpy(images) 55 | device = images.device 56 | in_type = images.type() 57 | images = images * 255. 58 | images = images.type(torch.uint8) 59 | # images: [B, C, H, W] 60 | bs, channels, h, w = images.shape 61 | images = images.to(device) 62 | gray_images = grayscale(images) 63 | rnd = np.random.uniform(0., 1., size=(images.shape[0],)) 64 | mask = rnd <= probab 65 | mask = torch.from_numpy(mask) 66 | frames = images.shape[1] // 3 67 | images = images.view(*gray_images.shape) 68 | mask = mask[:, None] * torch.ones([1, frames]).type(mask.dtype) 69 | mask = mask.type(images.dtype).to(device) 70 | mask = mask[:, :, None, None, None] 71 | out = mask * gray_images + (1 - mask) * images 72 | out = out.view([bs, -1, h, w]).type(in_type) / 255. 73 | return np.reshape(out.numpy(), (bs, h, w, -1)) 74 | 75 | def random_cutout(images, min_cut=10, max_cut=30): 76 | ''' 77 | + Arguments: 78 | - images: np.array shape (B,C,H,W) 79 | - min / max cut: int, min / max size of cutout 80 | ''' 81 | 82 | n, c, h, w = images.shape 83 | w1 = np.random.randint(min_cut, max_cut, n) 84 | h1 = np.random.randint(min_cut, max_cut, n) 85 | 86 | cutouts = np.empty((n, c, h, w), dtype=images.dtype) 87 | for i, (img, w11, h11) in enumerate(zip(images, w1, h1)): 88 | cut_img = img.copy() 89 | cut_img[:, h11:h11 + h11, w11:w11 + w11] = 0 90 | cutouts[i] = cut_img 91 | return np.reshape(cutouts, (n, h, w, c)) 92 | 93 | def random_cutout_color(images, min_cut=10, max_cut=30): 94 | ''' 95 | + Arguments: 96 | - images: np.array shape (N,C,H,W) 97 | - out: output size (e.g. 84) 98 | ''' 99 | n, c, h, w = images.shape 100 | w1 = np.random.randint(min_cut, max_cut, n) 101 | h1 = np.random.randint(min_cut, max_cut, n) 102 | 103 | cutouts = np.empty((n, c, h, w), dtype=images.dtype) 104 | rand_box = np.random.randint(0, 255, size=(n, c)) / 255. 
105 | for i, (img, w11, h11) in enumerate(zip(images, w1, h1)): 106 | cut_img = img.copy() 107 | 108 | cut_img[:, h11:h11 + h11, w11:w11 + w11] = np.tile( 109 | rand_box[i].reshape(-1,1,1), 110 | (1,) + cut_img[:, h11:h11 + h11, w11:w11 + w11].shape[1:]) 111 | 112 | cutouts[i] = cut_img 113 | return np.reshape(cutouts, (n, h, w, c)) 114 | 115 | def random_flip(images, probab=0.2): 116 | ''' 117 | + Arguments: 118 | - images: np.array shape (B,C,H,W) 119 | - probab 120 | ''' 121 | images = torch.from_numpy(images) 122 | device = images.device 123 | bs, channels, h, w = images.shape 124 | 125 | images = images.to(device) 126 | 127 | flipped_images = images.flip([3]) 128 | 129 | rnd = np.random.uniform(0., 1., size=(images.shape[0],)) 130 | mask = rnd <= probab 131 | mask = torch.from_numpy(mask) 132 | frames = images.shape[1] #// 3 133 | images = images.view(*flipped_images.shape) 134 | mask = mask[:, None] * torch.ones([1, frames]).type(mask.dtype) 135 | 136 | mask = mask.type(images.dtype).to(device) 137 | mask = mask[:, :, None, None] 138 | 139 | out = mask * flipped_images + (1 - mask) * images 140 | 141 | out = out.view([bs, h, w, -1]) 142 | return out.numpy() 143 | 144 | def random_rotation(images, probab=0.3): 145 | ''' 146 | + Arguments: 147 | - images: np.array shape (B,C,H,W) 148 | - probab 149 | ''' 150 | images = torch.from_numpy(images) 151 | device = images.device 152 | # images: [B, C, H, W] 153 | bs, channels, h, w = images.shape 154 | 155 | images = images.to(device) 156 | 157 | rot90_images = images.rot90(1,[2,3]) 158 | rot180_images = images.rot90(2,[2,3]) 159 | rot270_images = images.rot90(3,[2,3]) 160 | 161 | rnd = np.random.uniform(0., 1., size=(images.shape[0],)) 162 | rnd_rot = np.random.randint(1, 4, size=(images.shape[0],)) 163 | mask = rnd <= probab 164 | mask = rnd_rot * mask 165 | mask = torch.from_numpy(mask).to(device) 166 | 167 | frames = images.shape[1] 168 | masks = [torch.zeros_like(mask) for _ in range(4)] 169 | for i,m in enumerate(masks): 170 | m[torch.where(mask==i)] = 1 171 | m = m[:, None] * torch.ones([1, frames]).type(mask.dtype).type(images.dtype).to(device) 172 | m = m[:,:,None,None] 173 | masks[i] = m 174 | 175 | 176 | out = masks[0] * images + masks[1] * rot90_images + masks[2] * rot180_images + masks[3] * rot270_images 177 | 178 | out = out.view([bs, h, w, -1]) 179 | return out.numpy() 180 | 181 | def random_convolution(images): 182 | ''' 183 | + Arguments: 184 | - images: np.array shape (B,C,H,W) 185 | ''' 186 | images = torch.from_numpy(images) 187 | _device = images.device 188 | 189 | img_h, img_w = images.shape[2], images.shape[3] 190 | num_stack_channel = images.shape[1] 191 | num_batch = images.shape[0] 192 | num_trans = num_batch 193 | batch_size = int(num_batch / num_trans) 194 | 195 | # initialize random covolution 196 | rand_conv = nn.Conv2d(3, 3, kernel_size=3, bias=False, padding=1).to(_device) 197 | 198 | for trans_index in range(num_trans): 199 | torch.nn.init.xavier_normal_(rand_conv.weight.data) 200 | temp_images = images[trans_index*batch_size:(trans_index+1)*batch_size] 201 | temp_images = temp_images.reshape(-1, 3, img_h, img_w) # (batch x stack, channel, h, w) 202 | rand_out = rand_conv(temp_images) 203 | if trans_index == 0: 204 | total_out = rand_out 205 | else: 206 | total_out = torch.cat((total_out, rand_out), 0) 207 | total_out = total_out.reshape(-1, num_stack_channel, img_h, img_w) 208 | return np.reshape(total_out.numpy(), (num_batch, img_h, img_w, num_stack_channel)) 209 | 210 | def no_aug(images): 211 | return 
images 212 | 213 | def random_color_jitter(images): 214 | ''' 215 | + Arguments: 216 | - images: np.array shape (B,C,H,W) 217 | ''' 218 | b,c,h,w = images.shape 219 | images = images.view(-1,3,h,w) 220 | transform_module = nn.Sequential(ColorJitterLayer(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5, p=1.0, batch_size=128)) 221 | images = transform_module(images).view(b, h, w, c) 222 | return images 223 | 224 | # ------------------------------------------------------------------------------------------------------------------------ 225 | 226 | def rgb2hsv(rgb, eps=1e-8): 227 | # Reference: https://www.rapidtables.com/convert/color/rgb-to-hsv.html 228 | # Reference: https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287 229 | 230 | _device = rgb.device 231 | r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :] 232 | 233 | Cmax = rgb.max(1)[0] 234 | Cmin = rgb.min(1)[0] 235 | delta = Cmax - Cmin 236 | 237 | hue = torch.zeros((rgb.shape[0], rgb.shape[2], rgb.shape[3])).to(_device) 238 | hue[Cmax== r] = (((g - b)/(delta + eps)) % 6)[Cmax == r] 239 | hue[Cmax == g] = ((b - r)/(delta + eps) + 2)[Cmax == g] 240 | hue[Cmax == b] = ((r - g)/(delta + eps) + 4)[Cmax == b] 241 | hue[Cmax == 0] = 0.0 242 | hue = hue / 6. # making hue range as [0, 1.0) 243 | hue = hue.unsqueeze(dim=1) 244 | 245 | saturation = (delta) / (Cmax + eps) 246 | saturation[Cmax == 0.] = 0. 247 | saturation = saturation.to(_device) 248 | saturation = saturation.unsqueeze(dim=1) 249 | 250 | value = Cmax 251 | value = value.to(_device) 252 | value = value.unsqueeze(dim=1) 253 | 254 | return torch.cat((hue, saturation, value), dim=1)#.type(torch.FloatTensor).to(_device) 255 | # return hue, saturation, value 256 | 257 | def hsv2rgb(hsv): 258 | # Reference: https://www.rapidtables.com/convert/color/hsv-to-rgb.html 259 | # Reference: https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287 260 | 261 | _device = hsv.device 262 | 263 | hsv = torch.clamp(hsv, 0, 1) 264 | hue = hsv[:, 0, :, :] * 360. 265 | saturation = hsv[:, 1, :, :] 266 | value = hsv[:, 2, :, :] 267 | 268 | c = value * saturation 269 | x = - c * (torch.abs((hue / 60.) 
% 2 - 1) - 1) 270 | m = (value - c).unsqueeze(dim=1) 271 | 272 | rgb_prime = torch.zeros_like(hsv).to(_device) 273 | 274 | inds = (hue < 60) * (hue >= 0) 275 | rgb_prime[:, 0, :, :][inds] = c[inds] 276 | rgb_prime[:, 1, :, :][inds] = x[inds] 277 | 278 | inds = (hue < 120) * (hue >= 60) 279 | rgb_prime[:, 0, :, :][inds] = x[inds] 280 | rgb_prime[:, 1, :, :][inds] = c[inds] 281 | 282 | inds = (hue < 180) * (hue >= 120) 283 | rgb_prime[:, 1, :, :][inds] = c[inds] 284 | rgb_prime[:, 2, :, :][inds] = x[inds] 285 | 286 | inds = (hue < 240) * (hue >= 180) 287 | rgb_prime[:, 1, :, :][inds] = x[inds] 288 | rgb_prime[:, 2, :, :][inds] = c[inds] 289 | 290 | inds = (hue < 300) * (hue >= 240) 291 | rgb_prime[:, 2, :, :][inds] = c[inds] 292 | rgb_prime[:, 0, :, :][inds] = x[inds] 293 | 294 | inds = (hue < 360) * (hue >= 300) 295 | rgb_prime[:, 2, :, :][inds] = x[inds] 296 | rgb_prime[:, 0, :, :][inds] = c[inds] 297 | 298 | rgb = rgb_prime + torch.cat((m, m, m), dim=1) 299 | rgb = rgb.to(_device) 300 | 301 | return torch.clamp(rgb, 0, 1) 302 | 303 | class ColorJitterLayer(nn.Module): 304 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0, batch_size=128, stack_size=3): 305 | super(ColorJitterLayer, self).__init__() 306 | self.brightness = self._check_input(brightness, 'brightness') 307 | self.contrast = self._check_input(contrast, 'contrast') 308 | self.saturation = self._check_input(saturation, 'saturation') 309 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 310 | clip_first_on_zero=False) 311 | self.prob = p 312 | self.batch_size = batch_size 313 | self.stack_size = stack_size 314 | 315 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 316 | if isinstance(value, numbers.Number): 317 | if value < 0: 318 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 319 | value = [center - value, center + value] 320 | if clip_first_on_zero: 321 | value[0] = max(value[0], 0) 322 | elif isinstance(value, (tuple, list)) and len(value) == 2: 323 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 324 | raise ValueError("{} values should be between {}".format(name, bound)) 325 | else: 326 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 327 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 328 | # or (0., 0.) for hue, do nothing 329 | if value[0] == value[1] == center: 330 | value = None 331 | return value 332 | 333 | def adjust_contrast(self, x): 334 | """ 335 | Args: 336 | x: torch tensor img (rgb type) 337 | Factor: torch tensor with same length as x 338 | 0 gives gray solid image, 1 gives original image, 339 | Returns: 340 | torch tensor image: Brightness adjusted 341 | """ 342 | _device = x.device 343 | factor = torch.empty(self.batch_size, device=_device).uniform_(*self.contrast) 344 | factor = factor.reshape(-1,1).repeat(1, self.stack_size).reshape(-1) 345 | means = torch.mean(x, dim=(2, 3), keepdim=True) 346 | return torch.clamp((x - means) 347 | * factor.view(len(x), 1, 1, 1) + means, 0, 1) 348 | 349 | def adjust_hue(self, x): 350 | _device = x.device 351 | factor = torch.empty(self.batch_size, device=_device).uniform_(*self.hue) 352 | factor = factor.reshape(-1,1).repeat(1, self.stack_size).reshape(-1) 353 | h = x[:, 0, :, :] 354 | h += (factor.view(len(x), 1, 1) * 255. / 360.) 
355 | h = (h % 1) 356 | x[:, 0, :, :] = h 357 | return x 358 | 359 | def adjust_brightness(self, x): 360 | """ 361 | Args: 362 | x: torch tensor img (hsv type) 363 | Factor: 364 | torch tensor with same length as x 365 | 0 gives black image, 1 gives original image, 366 | 2 gives the brightness factor of 2. 367 | Returns: 368 | torch tensor image: Brightness adjusted 369 | """ 370 | _device = x.device 371 | factor = torch.empty(self.batch_size, device=_device).uniform_(*self.brightness) 372 | factor = factor.reshape(-1,1).repeat(1, self.stack_size).reshape(-1) 373 | x[:, 2, :, :] = torch.clamp(x[:, 2, :, :] 374 | * factor.view(len(x), 1, 1), 0, 1) 375 | return torch.clamp(x, 0, 1) 376 | 377 | def adjust_saturate(self, x): 378 | """ 379 | Args: 380 | x: torch tensor img (hsv type) 381 | Factor: 382 | torch tensor with same length as x 383 | 0 gives black image and white, 1 gives original image, 384 | 2 gives the brightness factor of 2. 385 | Returns: 386 | torch tensor image: Brightness adjusted 387 | """ 388 | _device = x.device 389 | factor = torch.empty(self.batch_size, device=_device).uniform_(*self.saturation) 390 | factor = factor.reshape(-1,1).repeat(1, self.stack_size).reshape(-1) 391 | x[:, 1, :, :] = torch.clamp(x[:, 1, :, :] 392 | * factor.view(len(x), 1, 1), 0, 1) 393 | return torch.clamp(x, 0, 1) 394 | 395 | def transform(self, inputs): 396 | hsv_transform_list = [rgb2hsv, self.adjust_brightness, 397 | self.adjust_hue, self.adjust_saturate, 398 | hsv2rgb] 399 | rgb_transform_list = [self.adjust_contrast] 400 | # Shuffle transform 401 | if random.uniform(0,1) >= 0.5: 402 | transform_list = rgb_transform_list + hsv_transform_list 403 | else: 404 | transform_list = hsv_transform_list + rgb_transform_list 405 | for t in transform_list: 406 | inputs = t(inputs) 407 | return inputs 408 | 409 | def forward(self, inputs): 410 | _device = inputs.device 411 | random_inds = np.random.choice( 412 | [True, False], len(inputs), p=[self.prob, 1 - self.prob]) 413 | inds = torch.tensor(random_inds).to(_device) 414 | if random_inds.sum() > 0: 415 | inputs[inds] = self.transform(inputs[inds]) 416 | return inputs 417 | 418 | 419 | aug_to_func = { 420 | 'crop': random_crop, 421 | 'grayscale': random_grayscale, 422 | 'cutout': random_cutout, 423 | 'cutout_color': random_cutout_color, 424 | 'flip': random_flip, 425 | 'rotate': random_rotation, 426 | 'rand_conv': random_convolution, 427 | 'color_jitter': random_color_jitter, 428 | 'no_aug': no_aug, 429 | } 430 | 431 | def apply_augs(images, config): 432 | for aug in config.augs: 433 | n, h, w, c = images.shape 434 | images = np.reshape(images, (n, c, h, w)) 435 | images = aug_to_func[aug](images) 436 | return images -------------------------------------------------------------------------------- /dataset_env/data_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import django 5 | 6 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')) 7 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../db/')) 8 | 9 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "db.settings") 10 | django.setup() 11 | 12 | import utils 13 | from global_config import * 14 | from traj_db.models import SurrealRoboticsSuiteTrajectory, RLBenchTrajectory 15 | 16 | taj_db_dict = {'SURREAL' : SurrealRoboticsSuiteTrajectory,'RLBENCH' : RLBenchTrajectory} 17 | 18 | def get_dataset_args(): 19 | parser = get_global_parser() 20 | 21 | # NOTE: 
'SURREAL' is a placeholder. The dbs are set according to global_config.env -> see below. v 22 | parser.add_argument('--traj_db', type=env2TrajDB, default='SURREAL') 23 | parser.add_argument('--obv_keys', type=env2keys, default='SURREAL') 24 | parser.add_argument('--get_by_task_id', type=bool, default=False) 25 | parser.add_argument('--archives_path', type=str, default=os.path.join(DATA_DIR, 'data_files/archives')) 26 | parser.add_argument('--episode_type', type=ep_type, default='teleop', choices=['teleop', 'imitation', 'expert', 'policy', 'exploration', 'random']) 27 | parser.add_argument('--media_dir', type=str, default=os.path.join(BASE_DIR, 'db/static/media/')) 28 | parser.add_argument('--vid_path', type=str, default='vid.mp4') 29 | parser.add_argument('--fps', type=int, default=30) 30 | parser.add_argument('--vocab_path', type=str, default=os.path.join(DATA_DIR, 'data_files/vocab.pkl')) 31 | 32 | parser.add_argument('--data_agumentation', type=utils.str2bool, default=False) # WARNING : Don't use now, UNSTABLE. 33 | parser.add_argument('--augs', type=utils.str2list, default='crop', help='See others in data_aug.py') 34 | 35 | config = parser.parse_args() 36 | config.env_args = env2args(config.env) 37 | config.traj_db = env2TrajDB(config.env) 38 | config.obv_keys = env2keys(config.env) 39 | config.data_path = os.path.join(config.data_path, '{}_{}/'.format(config.env, config.env_type)) 40 | config.archives_path = os.path.join(config.archives_path, '{}_{}/'.format(config.env, config.env_type)) 41 | config.vid_path = os.path.join(config.media_dir, config.vid_path) 42 | 43 | utils.check_n_create_dir(config.data_path, config.display_warnings) 44 | utils.check_n_create_dir(config.archives_path, config.display_warnings) 45 | utils.check_n_create_dir(config.media_dir, config.display_warnings) 46 | 47 | return config 48 | 49 | def env2TrajDB(string): 50 | if string is None: 51 | return None 52 | return taj_db_dict[string.upper()] 53 | 54 | def ep_type(string): 55 | ep_dict = {'teleop': 'EPISODE_ROBOT_PLAY', 'imitation': 'EPISODE_ROBOT_IMITATED', 56 | 'expert': 'EPISODE_ROBOT_EXPERT', 'policy': 'EPISODE_ROBOT_POLICY', 57 | 'exploration': 'EPISODE_ROBOT_EXPLORED', 'random': 'EPISODE_ROBOT_RANDOM'} 58 | return ep_dict[string.lower()] 59 | 60 | def env2keys(string): 61 | ep_dict = {'RLBENCH' : {'vis_obv_key' : 'left_shoulder_rgb', 'dof_obv_key' : 'state'}, 62 | 'SURREAL' : {'vis_obv_key' : 'image', 'dof_obv_key' : 'robot-state'},} 63 | return ep_dict[string.upper()] 64 | 65 | if __name__ == '__main__': 66 | print("=> Testing data_config.py") 67 | args = get_dataset_args() 68 | print(args.traj_db) 69 | -------------------------------------------------------------------------------- /dataset_env/deg_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | import data_aug as rad 9 | from data_config import get_dataset_args 10 | from file_storage import get_trajectory, get_random_trajectory 11 | 12 | class DataEnvGroup(object): 13 | ''' + NOTE : Create subclass for every environment, eg. 14 | Check `assert self.env_name == 'ENV_NAME'` 15 | ''' 16 | def __init__(self, get_episode_type=None): 17 | ''' + Arguments: 18 | - get_episode_type: Get data of a particular episode_type (teleop, imitation, etc.) 19 | > Default : None, get data with any episode_type. 
20 | ''' 21 | self.config = get_dataset_args() 22 | self.env_name = self.config.env 23 | self.env_type = self.config.env_type 24 | self.episode_type = get_episode_type 25 | self.traj_dataset = self.TrajDataset(self.episode_type, self.config) 26 | 27 | # NOTE : Environment dependent properties 28 | # Set these after inheriting the class. NotImplementedError 29 | self.vis_obv_key = None 30 | self.dof_obv_key = None 31 | 32 | self.obs_space = None 33 | self.action_space = None 34 | 35 | def get_env(self): 36 | raise NotImplementedError 37 | 38 | def _get_obs(self, obs, key): 39 | raise NotImplementedError 40 | 41 | def teleoperate(self, demon_config, task=None): 42 | raise NotImplementedError 43 | 44 | def random_trajectory(self, demons_config): 45 | raise NotImplementedError 46 | 47 | def get_random_goal(self): 48 | assert issubclass(type(self), DataEnvGroup) is True # NOTE : might raise error - remove if so 49 | goal = get_random_trajectory()[0][self.vis_obv_key][-1] 50 | return goal 51 | 52 | class TrajDataset(Dataset): 53 | def __init__(self, episode_type, config): 54 | self.episode_type = episode_type 55 | self.config = config 56 | 57 | def __len__(self): 58 | if self.episode_type is None: 59 | return self.config.traj_db.objects.count() - 1 # HACK 60 | else: 61 | return self.config.traj_db.objects.filter(episode_type=self.episode_type).count() 62 | 63 | def __getitem__(self, idx): 64 | trajectory = get_trajectory(index=idx, episode_type=self.episode_type) 65 | if self.config.data_agumentation: 66 | trajectory[self.vis_obv_key] = rad.apply_augs(trajectory[self.vis_obv_key], self.config) 67 | return trajectory 68 | 69 | def _collate_wrap(self, remove_task_state=False): 70 | # NOTE - if batch_size > 1, it removes the task-dependent states to get a common dim. 
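# (Trajectories in a batch can have different lengths; each key below is therefore zero-padded
#  to the longest sequence in the batch via torch.nn.utils.rnn.pad_sequence with batch_first=True.)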
71 | def pad_collate(batch): 72 | assert None not in [self.vis_obv_key, self.dof_obv_key] 73 | tr_vobvs = [torch.from_numpy(b[self.vis_obv_key]) for b in batch] 74 | tr_dof = [torch.from_numpy(b[self.dof_obv_key][:, : self.obs_space[self.dof_obv_key]] if remove_task_state else b[self.dof_obv_key]) for b in batch] 75 | tr_actions = [torch.from_numpy(b['action']) for b in batch] 76 | 77 | tr_vobvs_pad = pad_sequence(tr_vobvs, batch_first=True, padding_value=0) 78 | tr_dof_pad = pad_sequence(tr_dof, batch_first=True, padding_value=0) 79 | tr_actions_pad = pad_sequence(tr_actions, batch_first=True, padding_value=0) 80 | 81 | padded_batch = {self.vis_obv_key : tr_vobvs_pad, self.dof_obv_key : tr_dof_pad, 'action' : tr_actions_pad} 82 | return padded_batch 83 | 84 | return pad_collate 85 | 86 | def get_traj_dataloader(self, batch_size, num_workers=1, shuffle=True): 87 | dataloader = DataLoader(dataset=self.traj_dataset, batch_size=batch_size, 88 | shuffle=shuffle, num_workers=num_workers, collate_fn=self._collate_wrap(remove_task_state=(batch_size != 1))) 89 | return dataloader 90 | -------------------------------------------------------------------------------- /dataset_env/file_storage.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import uuid 5 | import torch 6 | import pickle 7 | import tarfile 8 | import numpy as np 9 | import torchvision 10 | from random import randint 11 | from data_config import get_dataset_args, ep_type 12 | 13 | from PIL import Image 14 | 15 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')) 16 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../db/')) 17 | config = get_dataset_args() 18 | 19 | from traj_db.models import ArchiveFile 20 | 21 | def store_trajectoy(trajectory, episode_type=config.episode_type, task=None): 22 | ''' 23 | Save trajectory to the corresponding database based on env and env_type specified in config. 24 | + Arguments: 25 | - trajectory: {deg.vis_obv_key : np.array([n]), deg.dof_obv_key : np.array([n]), 'action' : np.array([n])} 26 | - episode_type [optional]: tag to store trajectories with (eg. 'teleop' or 'imitation') 27 | ''' 28 | if 'EPISODE_' not in episode_type: 29 | episode_type = ep_type(episode_type) 30 | if task is None: 31 | task = config.env_type 32 | assert 'action' in trajectory.keys() 33 | 34 | # NOTE : Current data_path is a placeholder. Edited below with UUID. 35 | metadata = config.traj_db(task_id=task, env_id=config.env, 36 | data_path=config.data_path, episode_type=episode_type, traj_steps=trajectory['action'].shape[0]) 37 | 38 | metadata.save() 39 | metadata.data_path = os.path.join(config.data_path, "traj_{}.pt".format(metadata.episode_id)) 40 | metadata.save() 41 | 42 | with open(metadata.data_path, 'wb') as file: 43 | pickle.dump(trajectory, file, protocol=pickle.HIGHEST_PROTOCOL) 44 | 45 | def get_trajectory(episode_type=None, index=None, episode_id=None): 46 | ''' 47 | Gets a particular trajectory from the corresponding database based on env and env_type specified in config. 48 | + Arguments: 49 | - episode_type [optional]: if you want trajectory specific to one episode_type (eg. 'teleop' or 'imitation') 50 | - index [optional]: get trajectory at a particular index 51 | - episode_id [optional]: get trajectory by it's episode_id (primary key) 52 | 53 | + NOTE: either index or episode_id should be not None. 
54 | + NOTE: episode_type, env_type become POINTLESS when you pass episode_id. 55 | ''' 56 | # TODO : If trajectory is in archive-file, get it from there 57 | if episode_id is None and index is None: 58 | return [get_trajectory(episode_id=traj_obj.episode_id) for traj_obj in config.traj_db.objects.all()] 59 | 60 | if index is not None: 61 | if episode_type is None: # TODO : Clean code 62 | metadata = config.traj_db.objects.filter(task_id=config.env_type)[index] if config.get_by_task_id else config.traj_db.objects.all()[index] 63 | else: 64 | metadata = config.traj_db.objects.filter(task_id=config.env_type, episode_type=episode_type)[index] if config.get_by_task_id else config.traj_db.objects.filter(episode_type=episode_type)[index] 65 | elif episode_id is not None: 66 | episode_id = str(episode_id) 67 | metadata = config.traj_db.objects.get(episode_id=uuid.UUID(episode_id)) 68 | 69 | with open(metadata.data_path, 'rb') as file: 70 | trajectory = pickle.load(file) 71 | 72 | return trajectory 73 | 74 | def get_random_trajectory(episode_type=None): 75 | ''' 76 | Gets a random trajectory from the corresponding database based on env and env_type specified in config. 77 | + Arguments: 78 | - episode_type [optional]: if you want trajectory specific to one episode_type (eg. 'teleop' or 'imitation') 79 | ''' 80 | count = config.traj_db.objects.count() 81 | random_index = randint(1, count) 82 | if episode_type is None: 83 | metadata = config.traj_db.objects.filter(task_id=config.env_type)[random_index] if config.get_by_task_id else config.traj_db.objects.all()[random_index] 84 | else: 85 | metadata = config.traj_db.objects.filter(task_id=config.env_type, episode_type=episode_type)[random_index] if config.get_by_task_id else config.traj_db.objects.filter(episode_type=episode_type)[random_index] 86 | 87 | episode_id = str(metadata.episode_id) 88 | task_id = metadata.task_id 89 | trajectory = get_trajectory(episode_id=episode_id) 90 | 91 | return trajectory, episode_id, task_id 92 | 93 | def create_video(trajectory): 94 | ''' 95 | Creates videos and stores video, the initial and the final frame in the paths specified in data_config. 96 | + Arguments: 97 | - trajectory: {deg.vis_obv_key : np.array([n]), deg.dof_obv_key : np.array([n]), 'action' : np.array([n])} 98 | ''' 99 | frames = trajectory[config.obv_keys['vis_obv_key']].astype(np.uint8) 100 | assert frames.shape[-1] == 3 101 | 102 | inital_obv, goal_obv = Image.fromarray(frames[0]), Image.fromarray(frames[-1]) 103 | inital_obv.save(os.path.join(config.media_dir, 'inital.png')) 104 | goal_obv.save(os.path.join(config.media_dir, 'goal.png')) 105 | 106 | if type(frames) is not torch.Tensor: 107 | frames = torch.from_numpy(frames) 108 | 109 | torchvision.io.write_video(config.vid_path, frames, config.fps) 110 | return config.vid_path 111 | 112 | # NOT TESTED 113 | def archive_traj_task(task=config.env_type, episode_type=None, file_name=None): 114 | ''' 115 | Archives trajectories by task (env_type) 116 | + Arguments: 117 | - task: config.env_type - group and archive them all. 118 | - episode_type [optional]: store trajectories w/ same task, episode_type together. 119 | - file_name: the name of the archive file. 
[NOTE: NOT THE PATH] 120 | > NOTE : default file_name: `env_task.tar.gz` 121 | ''' 122 | if episode_type is None: 123 | objects = config.traj_db.objects.get(task_id=task) 124 | f_name = "{}_{}.tar.gz".format(config.env, config.env_type) 125 | else: 126 | objects = config.traj_db.objects.get(task_id=task, episode_type=episode_type) 127 | f_name = "{}_{}_{}.tar.gz".format(config.env, config.env_type, episode_type) 128 | 129 | if file_name is None: 130 | file_name = f_name 131 | file_name = os.path.join(config.archives_path, file_name) 132 | 133 | tar = tarfile.open(file_name, "w:gz") 134 | for metadata in objects: 135 | if metadata.is_archived == True: 136 | continue 137 | 138 | metadata.is_archived = True 139 | metadata.save() 140 | 141 | tar.add(metadata.data_path) 142 | 143 | archive = ArchiveFile(trajectory=metadata, env_id=metadata.env_id, archive_file=file_name) 144 | archive.save() 145 | tar.close() 146 | 147 | def delete_trajectory(episode_id): 148 | obj = config.traj_db.objects.get(episode_id=uuid.UUID(episode_id)) 149 | if os.path.exists(obj.data_path): 150 | os.remove(obj.data_path) 151 | obj.delete() 152 | 153 | def flush_traj_db(): 154 | raise NotImplementedError -------------------------------------------------------------------------------- /dataset_env/organize.sh: -------------------------------------------------------------------------------- 1 | # All envs in `/scratch/envs` 2 | # # ---- Surreal Robotics Suite ---- 3 | # git clone https://github.com/StanfordVL/robosuite.git 4 | # cd robosuite 5 | # pip3 install -r requirements-extra.txt 6 | # cd ../ 7 | # mv ./robosuite ./surreal 8 | # # -------------------------------- 9 | 10 | # # ---- USC's Furniture Dataset ---- 11 | # git clone https://github.com/clvrai/furniture 12 | # cd furniture 13 | # sudo apt-get install libgl1-mesa-dev libgl1-mesa-glx libosmesa6-dev patchelf libopenmpi-dev libglew-dev 14 | # pip install -r requirements.txt 15 | # pip install -r requirements.dev.txt 16 | 17 | # # Download this https://drive.google.com/drive/folders/1ofnw_zid9zlfkjBLY_gl-CozwLUco2ib 18 | # # and extract to furniture dir 19 | # unzip binary.zip 20 | 21 | # # Virtual Display 22 | # # sudo apt-get install xserver-xorg libglu1-mesa-dev freeglut3-dev mesa-common-dev libxmu-dev libxi-dev 23 | # # sudo nvidia-xconfig -a --use-display-device=None --virtual=1280x1024 24 | # # sudo /usr/bin/X :1 & 25 | # # python -m demo_manual --virtual_display :1 26 | # # ---------------------------------- 27 | -------------------------------------------------------------------------------- /dataset_env/rlbench_deg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import gym 4 | import numpy as np 5 | 6 | from deg_base import DataEnvGroup 7 | from file_storage import store_trajectoy 8 | 9 | ENV_PATH = '/scratch/envs/RLBench' 10 | sys.path.append(ENV_PATH) 11 | 12 | import rlbench.gym 13 | from rlbench.environment import Environment 14 | from rlbench.action_modes import ArmActionMode, ActionMode 15 | from rlbench.backend.observation import Observation 16 | from rlbench.observation_config import ObservationConfig 17 | from rlbench.tasks import ReachTarget 18 | 19 | class RLBenchDataEnvGroup(DataEnvGroup): 20 | ''' DataEnvGroup for RLBench environment. 
21 | 22 | + The observation space can be modified through `global_config.env_args` 23 | + Observation space: 24 | - 'state': proprioceptive feature : [37] + task_specific 25 | > robot joint - velocities, positions, forces 26 | > gripper - pose, joint-position, touch_forces 27 | > task_low_dim_state 28 | - RGB-D/RGB image : (128 x 128 by default) 29 | > 'left_shoulder_rgb', 'right_shoulder_rgb' 30 | > 'front_rgb', 'wrist_rgb' 31 | - NOTE : Better to use only one of the RGB obvs rather than all, saves a lot of time while env creation. 32 | 33 | - Refer: https://github.com/stepjam/RLBench/blob/20988254b773aae433146fff3624d8bcb9ed7330/rlbench/observation_config.py 34 | 35 | + The action spaces by default are joint velocities [7] and gripper actuations [1]. 36 | - Dimension : [8] 37 | ''' 38 | def __init__(self, get_episode_type=None): 39 | super(RLBenchDataEnvGroup, self).__init__(get_episode_type) 40 | assert self.env_name == 'RLBENCH' 41 | self.env_obj = None 42 | self.use_gym = self.config.env_args['use_gym'] 43 | 44 | self.observation_mode = self.config.env_args['observation_mode'].lower() 45 | self.left_obv_key = 'left_shoulder_rgb' 46 | self.right_obv_key = 'right_shoulder_rgb' 47 | self.wrist_obv_key = 'wrist_rgb' 48 | self.front_obv_key = 'front_rgb' 49 | 50 | if self.observation_mode not in ['vision', 'state'] and not self.use_gym: 51 | self.config.env_args['vis_obv_key'] = self.observation_mode 52 | 53 | self.vis_obv_key = self.config.env_args['vis_obv_key'] 54 | self.dof_obv_key = 'state' 55 | self.env_action_key = 'joint_velocities' 56 | self.env_gripper_key = 'gripper_open' 57 | 58 | self.obs_space = {self.dof_obv_key : (37), self.left_obv_key : (128, 128, 3), self.right_obv_key : (128, 128, 3), 59 | self.wrist_obv_key : (128, 128, 3), self.front_obv_key: (128, 128, 3)} 60 | 61 | self.action_space, self.gripper_space = (7), (1) 62 | if self.config.env_args['combine_action_space']: 63 | self.action_space += self.gripper_space 64 | 65 | def get_env(self, task=None): 66 | task = task if task else self.config.env_type 67 | if self.use_gym: 68 | assert type(task) == str # NOTE : When using gym, the task has to be represented as a sting. 
69 | assert self.observation_mode in ['vision', 'state'] 70 | 71 | env = gym.make(task, observation_mode=self.config.env_args['observation_mode'], 72 | render_mode=self.config.env_args['render_mode']) 73 | self.env_obj = env 74 | else: 75 | obs_config = ObservationConfig() 76 | if self.observation_mode == 'vision': 77 | obs_config.set_all(True) 78 | elif self.observation_mode == 'state': 79 | obs_config.set_all_high_dim(False) 80 | obs_config.set_all_low_dim(True) 81 | else: 82 | obs_config_dict = {self.left_obv_key : obs_config.left_shoulder_camera, 83 | self.right_obv_key : obs_config.right_shoulder_camera, 84 | self.wrist_obv_key : obs_config.wrist_camera, 85 | self.front_obv_key: obs_config.front_camera} 86 | 87 | assert self.observation_mode in obs_config_dict.keys() 88 | 89 | obs_config.set_all_high_dim(False) 90 | obs_config_dict[self.observation_mode].set_all(True) 91 | obs_config.set_all_low_dim(True) 92 | 93 | # TODO : Write code to change it from env_args 94 | action_mode = ActionMode(ArmActionMode.ABS_JOINT_VELOCITY) 95 | self.env_obj = Environment(action_mode, obs_config=obs_config, headless=True) 96 | 97 | task = task if task else ReachTarget 98 | if type(task) == str: 99 | task = task.split('-')[0] 100 | task = self.env_obj._string_to_task(task) 101 | 102 | self.env_obj.launch() 103 | env = self.env_obj.get_task(task) # NOTE : `env` refered as `task` in RLBench docs. 104 | return env 105 | 106 | def _get_obs(self, obs, key): 107 | assert obs is not None and key is not None 108 | if type(obs) == tuple: 109 | obs = obs[1] 110 | if type(obs) == dict: 111 | return obs[key] 112 | elif type(obs) == Observation: 113 | if key == 'state': 114 | return obs.get_low_dim_data() 115 | return getattr(obs, key) 116 | 117 | def shutdown_env(self): 118 | if self.env_obj is None: 119 | print("Environment not created, call `.get_env()`") 120 | elif self.use_gym: 121 | self.env_obj.close() 122 | else: 123 | self.env_obj.shutdown() 124 | self.env_obj = None 125 | 126 | def teleoperate(self, demons_config, task=None): 127 | if self.config.env_args['keyboard_teleop']: 128 | raise NotImplementedError 129 | else: 130 | if self.env_obj is None or task is None: 131 | env = self.get_env(task) 132 | else: 133 | if type(task) == str: 134 | task = self.env_obj._string_to_task(task.split('-')[0]) 135 | env = self.env_obj.get_task(task) 136 | 137 | if self.use_gym: 138 | demos = env.task.get_demos(demons_config.n_runs, live_demos=True) 139 | else : 140 | demos = env.get_demos(demons_config.n_runs, live_demos=True) 141 | demos = np.array(demos).flatten() 142 | 143 | for i in range(demons_config.n_runs): 144 | sample = demos[i] 145 | if self.observation_mode != 'state': 146 | tr_vobvs = np.array([self._get_obs(obs, self.vis_obv_key) for obs in sample]) 147 | tr_dof = np.array([self._get_obs(obs, self.dof_obv_key).flatten() for obs in sample]) 148 | tr_actions = np.array([self._get_obs(obs, self.env_action_key).flatten() for obs in sample]) 149 | tr_gripper = np.array([[self._get_obs(obs, self.env_gripper_key)] for obs in sample]) 150 | 151 | if self.config.env_args['combine_action_space']: 152 | tr_actions = np.concatenate((tr_actions, tr_gripper), axis=-1) 153 | 154 | print("Storing Trajectory") 155 | trajectory = {self.dof_obv_key : tr_dof, 'action' : tr_actions} 156 | if self.observation_mode != 'state': 157 | trajectory.update({self.vis_obv_key : tr_vobvs}) 158 | store_trajectoy(trajectory, episode_type='teleop', task=task) 159 | 160 | self.shutdown_env() 161 | 162 | def random_trajectory(self, 
demons_config): 163 | env = self.get_env() 164 | obs = env.reset() 165 | 166 | tr_vobvs, tr_dof, tr_actions = [], [], [] 167 | for step in range(demons_config.flush_freq): 168 | if self.observation_mode != 'state': 169 | tr_vobvs.append(np.array(self._get_obs(obs, self.vis_obv_key))) 170 | tr_dof.append(np.array(self._get_obs(obs, self.dof_obv_key).flatten())) 171 | 172 | action = np.random.normal(size=self.action_space[0]) 173 | obs, reward, done, info = env.step(action) 174 | 175 | tr_actions.append(action) 176 | 177 | print("Storing Trajectory") 178 | trajectory = {self.dof_obv_key : np.array(tr_dof), 'action' : np.array(tr_actions)} 179 | if self.observation_mode != 'state': 180 | trajectory.update({self.vis_obv_key : np.array(tr_vobvs)}) 181 | store_trajectoy(trajectory, episode_type='random') 182 | self.shutdown_env() 183 | 184 | if __name__ == '__main__': 185 | from torch.utils.data import DataLoader 186 | print("=> Testing rlbench_deg.py") 187 | 188 | deg = RLBenchDataEnvGroup() 189 | print(deg.obs_space[deg.vis_obv_key], deg.action_space) 190 | env = deg.get_env(task='change_clock') 191 | obs = env.reset() 192 | print(deg._get_obs(obs, deg.dof_obv_key).shape) 193 | print(deg._get_obs(obs, 'task_low_dim_state').shape) 194 | 195 | #traj_data = DataLoader(deg.traj_dataset, batch_size=1, shuffle=False, num_workers=1) 196 | traj_data = deg.get_traj_dataloader(batch_size=5) 197 | for i, b in enumerate(traj_data): 198 | print(i, b[deg.dof_obv_key].shape) 199 | 200 | deg.shutdown_env() 201 | -------------------------------------------------------------------------------- /dataset_env/surreal_deg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | from deg_base import DataEnvGroup 7 | from file_storage import store_trajectoy 8 | 9 | ENV_PATH = '/scratch/envs/robosuite' 10 | sys.path.append(ENV_PATH) 11 | 12 | import robosuite as suite 13 | from robosuite.wrappers import IKWrapper 14 | import robosuite.utils.transform_utils as T 15 | 16 | class SurrealDataEnvGroup(DataEnvGroup): 17 | ''' DataEnvGroup for Surreal Robotics Suite environment. 18 | 19 | + The observation space can be modified through `global_config.env_args` 20 | + Observation space: 21 | - 'robot-state': proprioceptive feature - vector of: 22 | > cos and sin of robot joint positions 23 | > robot joint velocities 24 | > current configuration of the gripper. 25 | - 'object-state': object-centric feature 26 | - 'image': RGB/RGB-D image 27 | > (256 x 256 by default) 28 | - Refer: https://github.com/StanfordVL/robosuite/tree/master/robosuite/environments 29 | 30 | + The action spaces by default are joint velocities and gripper actuations. 31 | - Dimension : [8] 32 | - To use the end-effector action-space use inverse-kinematics using IKWrapper. 
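            > e.g. `env = IKWrapper(suite.make(env_type, **env_args))` — this is what
              `teleoperate()` below does before turning the device's end-effector deltas into
              actions (sketch; the argument names follow this class's config usage).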
33 | - Refer: https://github.com/StanfordVL/robosuite/tree/master/robosuite/wrappers 34 | ''' 35 | def __init__(self, get_episode_type=None): 36 | super(SurrealDataEnvGroup, self).__init__(get_episode_type) 37 | assert self.env_name == 'SURREAL' 38 | 39 | self.vis_obv_key = 'image' 40 | self.dof_obv_key = 'robot-state' 41 | self.obs_space = {self.vis_obv_key: (256, 256, 3), self.dof_obv_key: (30)} 42 | self.action_space = (8) 43 | 44 | def get_env(self): 45 | env = suite.make(self.config.env_type, **self.config.env_args) 46 | return env 47 | 48 | def _get_obs(self, obs, key): 49 | return obs[key] 50 | 51 | def play_trajectory(self): 52 | # TODO 53 | # Refer https://github.com/StanfordVL/robosuite/blob/master/robosuite/scripts/playback_demonstrations_from_hdf5.py 54 | raise NotImplementedError 55 | 56 | def teleoperate(self, demons_config): 57 | env = self.get_env() 58 | # Need to use inverse-kinematics controller to set position using device 59 | env = IKWrapper(env) 60 | 61 | if demons_config.device == "keyboard": 62 | from robosuite.devices import Keyboard 63 | device = Keyboard() 64 | env.viewer.add_keypress_callback("any", device.on_press) 65 | env.viewer.add_keyup_callback("any", device.on_release) 66 | env.viewer.add_keyrepeat_callback("any", device.on_press) 67 | elif demons_config.device == "spacemouse": 68 | from robosuite.devices import SpaceMouse 69 | device = SpaceMouse() 70 | 71 | for run in range(demons_config.n_runs): 72 | obs = env.reset() 73 | env.set_robot_joint_positions([0, -1.18, 0.00, 2.18, 0.00, 0.57, 1.5708]) 74 | # rotate the gripper so we can see it easily - NOTE : REMOVE MAYBE 75 | env.viewer.set_camera(camera_id=2) 76 | env.render() 77 | device.start_control() 78 | 79 | reset = False 80 | task_completion_hold_count = -1 81 | step = 0 82 | tr_vobvs, tr_dof, tr_actions = [], [], [] 83 | 84 | while not reset: 85 | if int(step % demons_config.collect_freq) == 0: 86 | tr_vobvs.append(np.array(obs[self.vis_obv_key])) 87 | tr_dof.append(np.array(obs[self.dof_obv_key].flatten())) 88 | 89 | device_state = device.get_controller_state() 90 | dpos, rotation, grasp, reset = ( 91 | device_state["dpos"], 92 | device_state["rotation"], 93 | device_state["grasp"], 94 | device_state["reset"], 95 | ) 96 | 97 | current = env._right_hand_orn 98 | drotation = current.T.dot(rotation) 99 | dquat = T.mat2quat(drotation) 100 | grasp = grasp - 1. 
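                # Hedged note: the teleop devices report `grasp` as a 0/1 flag, so the shift above
                # feeds {-1, 0} into the gripper slot of the IK action assembled next — this follows
                # robosuite's own demonstration-collection scripts.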
101 | ik_action = np.concatenate([dpos, dquat, [grasp]]) 102 | 103 | obs, _, done, _ = env.step(ik_action) 104 | env.render() 105 | 106 | joint_velocities = np.array(env.controller.commanded_joint_velocities) 107 | if env.env.mujoco_robot.name == "sawyer": 108 | gripper_actuation = np.array(ik_action[7:]) 109 | elif env.env.mujoco_robot.name == "baxter": 110 | gripper_actuation = np.array(ik_action[14:]) 111 | 112 | # NOTE: Action for the normal environment (not inverse kinematic) 113 | action = np.concatenate([joint_velocities, gripper_actuation], axis=0) 114 | 115 | if int(step % demons_config.collect_freq) == 0: 116 | tr_actions.append(action) 117 | 118 | if (int(step % demons_config.flush_freq) == 0) or (demons_config.break_traj_success and task_completion_hold_count == 0): 119 | print("Storing Trajectory") 120 | trajectory = {self.vis_obv_key : np.array(tr_vobvs), self.dof_obv_key : np.array(tr_dof), 'action' : np.array(tr_actions)} 121 | store_trajectoy(trajectory, 'teleop') 122 | trajectory, tr_vobvs, tr_dof, tr_actions = {}, [], [], [] 123 | 124 | if demons_config.break_traj_success and env._check_success(): 125 | if task_completion_hold_count > 0: 126 | task_completion_hold_count -= 1 # latched state, decrement count 127 | else: 128 | task_completion_hold_count = 10 # reset count on first success timestep 129 | else: 130 | task_completion_hold_count = -1 131 | 132 | step += 1 133 | 134 | env.close() 135 | 136 | def random_trajectory(self, demons_config): 137 | env = self.get_env() 138 | obs = env.reset() 139 | env.set_robot_joint_positions([0, -1.18, 0.00, 2.18, 0.00, 0.57, 1.5708]) 140 | 141 | tr_vobvs, tr_dof, tr_actions = [], [], [] 142 | for step in range(demons_config.flush_freq): 143 | tr_vobvs.append(np.array(obs[self.vis_obv_key])) 144 | tr_dof.append(np.array(obs[self.dof_obv_key].flatten())) 145 | 146 | action = np.random.randn(env.dof) 147 | obs, reward, done, info = env.step(action) 148 | 149 | tr_actions.append(action) 150 | 151 | print("Storing Trajectory") 152 | trajectory = {self.vis_obv_key : np.array(tr_vobvs), self.dof_obv_key : np.array(tr_dof), 'action' : np.array(tr_actions)} 153 | store_trajectoy(trajectory, 'random') 154 | env.close() 155 | 156 | 157 | if __name__ == '__main__': 158 | from torch.utils.data import DataLoader 159 | print("=> Testing surreal_deg.py") 160 | 161 | deg = SurrealDataEnvGroup() 162 | print(deg.obs_space[deg.vis_obv_key], deg.action_space) 163 | print(deg.get_env().reset()[deg.vis_obv_key].shape) 164 | 165 | traj_data = DataLoader(deg.traj_dataset, batch_size=1, shuffle=True, num_workers=1) 166 | print(next(iter(traj_data))[deg.dof_obv_key].shape) 167 | -------------------------------------------------------------------------------- /db/README.md: -------------------------------------------------------------------------------- 1 | ## Django Based Database for Trajectory Meta-data. 2 | 3 | ## Setup 4 | To run the usual commands on `manage.py`, run `python3 manage.py` and enter commands when prompted. I had to edit `manage.py` out due to errors with argv and the global config. For running `runserver`, type the command twice (as prompted). 5 | 6 | Run the commands in order to setup the DB from scratch. 7 | ``` 8 | createsuperuser 9 | makemigrations 10 | migrate 11 | ``` 12 | 13 | ## Database 14 | The `traj_db` app provides the database to store the meta-data of the trajectories. The `Trajectory` table is an abstract-table. Each environment gets its own table just by inheriting it (see `traj_db/models.py`). 
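For instance, a minimal sketch of what such an inheriting table looks like (see also "Adding New Environments/Tables" below; `MyEnvTrajectory` is a hypothetical name used only for illustration):
```
# traj_db/models.py — add below the abstract Trajectory model
class MyEnvTrajectory(Trajectory):
    pass    # an empty subclass is enough: Django creates a separate table for this environment

# traj_db/admin.py — register it so it appears in the admin portal
from django.contrib import admin
from .models import MyEnvTrajectory

admin.site.register(MyEnvTrajectory)
```
After adding the subclass, run `makemigrations` and `migrate` again (via `manage.py`, as in Setup above) so the new table is actually created.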
Each trajectory is stored as a pickle file and is represented by UUID. The table includes other information such as the task it's performing, how it was generated etc. All the functionalities to interact with the pickle files and the DB is provided in `../dataset_env/file_storage.py`. 15 | 16 | ## Adding New Environments/Tables 17 | To add a table for a new environment, create a new, empty subclass of `Trajectory` and register it in `traj_db/admin.py`. That's it! Follow `../dataset_env`'s README for other config-modifications. 18 | 19 | ## Views 20 | The `admin/` portal provides a good interface to the database, but usage of functions in `../dataset_env/file_storage.py` is recommended. The index page generates a video of a random trajectory of the current environment and displays it. 21 | -------------------------------------------------------------------------------- /db/db/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhruvramani/data-driven-robotics/059ac21d516cffce291e2e8ac965c9da44969114/db/db/__init__.py -------------------------------------------------------------------------------- /db/db/asgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ASGI config for db project. 3 | 4 | It exposes the ASGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/3.1/howto/deployment/asgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.asgi import get_asgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'db.settings') 15 | 16 | application = get_asgi_application() 17 | -------------------------------------------------------------------------------- /db/db/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for db project. 3 | 4 | Generated by 'django-admin startproject' using Django 3.1. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/3.1/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/3.1/ref/settings/ 11 | """ 12 | import os 13 | from pathlib import Path 14 | 15 | # Build paths inside the project like this: BASE_DIR / 'subdir'. 16 | BASE_DIR = Path(__file__).resolve(strict=True).parent.parent 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/3.1/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = 'n71h#3gzx7m%#8*%gm!fgl)$*()rk0l%f9aj7g1a4l5-syjx*)' 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 
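# (Hedged note: DEBUG presumably stays on because this app is only meant to be served locally
#  for experiment bookkeeping — ALLOWED_HOSTS below lists just a LAN address and localhost.)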
26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = ['10.24.6.58', '127.0.0.1', 'localhost'] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | 'django.contrib.admin', 35 | 'django.contrib.auth', 36 | 'django.contrib.contenttypes', 37 | 'django.contrib.sessions', 38 | 'django.contrib.messages', 39 | 'django.contrib.staticfiles', 40 | 'traj_db.apps.TrajDbConfig', 41 | 'polymorphic', 42 | ] 43 | 44 | MIDDLEWARE = [ 45 | 'django.middleware.security.SecurityMiddleware', 46 | 'django.contrib.sessions.middleware.SessionMiddleware', 47 | 'django.middleware.common.CommonMiddleware', 48 | 'django.middleware.csrf.CsrfViewMiddleware', 49 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 50 | 'django.contrib.messages.middleware.MessageMiddleware', 51 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 52 | ] 53 | 54 | ROOT_URLCONF = 'db.urls' 55 | 56 | TEMPLATE_DIR = os.path.join(BASE_DIR, "templates") 57 | 58 | TEMPLATES = [ 59 | { 60 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 61 | 'DIRS': [TEMPLATE_DIR], 62 | 'APP_DIRS': True, 63 | 'OPTIONS': { 64 | 'context_processors': [ 65 | 'django.template.context_processors.debug', 66 | 'django.template.context_processors.request', 67 | 'django.contrib.auth.context_processors.auth', 68 | 'django.contrib.messages.context_processors.messages', 69 | ], 70 | }, 71 | }, 72 | ] 73 | 74 | WSGI_APPLICATION = 'db.wsgi.application' 75 | 76 | 77 | # Database 78 | # https://docs.djangoproject.com/en/3.1/ref/settings/#databases 79 | 80 | 81 | # NOTE: Might migrate to Postgres in the future. Need sqllite-3 for speed now. 82 | # https://www.digitalocean.com/community/tutorials/how-to-use-postgresql-with-your-django-application-on-ubuntu-14-04 83 | # https://www.vphventures.com/how-to-migrate-your-django-project-from-sqlite-to-postgresql/ 84 | 85 | DATABASES = { 86 | 'default': { 87 | 'ENGINE': 'django.db.backends.sqlite3', 88 | 'NAME': BASE_DIR / 'db.sqlite3', 89 | } 90 | } 91 | 92 | 93 | # Password validation 94 | # https://docs.djangoproject.com/en/3.1/ref/settings/#auth-password-validators 95 | 96 | AUTH_PASSWORD_VALIDATORS = [ 97 | { 98 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 99 | }, 100 | { 101 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 102 | }, 103 | { 104 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 105 | }, 106 | { 107 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 108 | }, 109 | ] 110 | 111 | 112 | # Internationalization 113 | # https://docs.djangoproject.com/en/3.1/topics/i18n/ 114 | 115 | LANGUAGE_CODE = 'en-us' 116 | 117 | TIME_ZONE = 'UTC' 118 | 119 | USE_I18N = True 120 | 121 | USE_L10N = True 122 | 123 | USE_TZ = True 124 | 125 | 126 | # Static files (CSS, JavaScript, Images) 127 | # https://docs.djangoproject.com/en/3.1/howto/static-files/ 128 | 129 | STATIC_URL = '/static/' 130 | STATICFILES_DIRS = [os.path.join(BASE_DIR, 'static'), ] 131 | STATIC_ROOT = os.path.join(BASE_DIR, 'staticfiles') 132 | -------------------------------------------------------------------------------- /db/db/urls.py: -------------------------------------------------------------------------------- 1 | """db URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/3.1/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. 
Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.urls import include, path 14 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 15 | """ 16 | from django.contrib import admin 17 | from django.urls import path 18 | from traj_db import views 19 | 20 | urlpatterns = [ 21 | path('admin/', admin.site.urls), 22 | path('', views.vid, name="traj_video"), 23 | ] 24 | -------------------------------------------------------------------------------- /db/db/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for db project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/3.1/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'db.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /db/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | import os 4 | import sys 5 | 6 | 7 | def main(): 8 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'db.settings') 9 | try: 10 | from django.core.management import execute_from_command_line 11 | except ImportError as exc: 12 | raise ImportError( 13 | "Couldn't import Django. Are you sure it's installed and " 14 | "available on your PYTHONPATH environment variable? Did you " 15 | "forget to activate a virtual environment?" 16 | ) from exc 17 | command = input("Enter Command : ") 18 | execute_from_command_line(["manage.py"] + command.split(" "))#sys.argv) 19 | 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /db/static/css/foo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhruvramani/data-driven-robotics/059ac21d516cffce291e2e8ac965c9da44969114/db/static/css/foo -------------------------------------------------------------------------------- /db/static/js/foo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhruvramani/data-driven-robotics/059ac21d516cffce291e2e8ac965c9da44969114/db/static/js/foo -------------------------------------------------------------------------------- /db/templates/base.html: -------------------------------------------------------------------------------- 1 | {% load static %} 2 | 3 | 4 | 5 |
6 | [template markup lost in extraction — the rest of `base.html` and the trajectory-video page rendered by `views.vid` were stripped by the HTML-unaware dump; the only surviving text is the headings "Initial State" and "Final State", presumably labelling the first and last frames shown alongside the trajectory video]