├── db
├── db
│ ├── __init__.py
│ ├── asgi.py
│ ├── wsgi.py
│ ├── urls.py
│ └── settings.py
├── static
│ ├── css
│ │ └── foo
│ └── js
│ │ └── foo
├── traj_db
│ ├── __init__.py
│ ├── migrations
│ │ └── __init__.py
│ ├── tests.py
│ ├── apps.py
│ ├── admin.py
│ ├── views.py
│ └── models.py
├── manage.py
├── templates
│ ├── vid.html
│ └── base.html
└── README.md
├── model
├── __init__.py
├── main.py
├── model_config.py
├── test.py
├── train.py
├── layers.py
└── models.py
├── dataset_env
├── __init__.py
├── organize.sh
├── README.md
├── data_config.py
├── deg_base.py
├── file_storage.py
├── surreal_deg.py
├── rlbench_deg.py
└── data_aug.py
├── collect_demons
├── __init__.py
├── README.md
├── main.py
├── demons_config.py
└── imitate_play.py
├── utils.py
├── README.md
├── .gitignore
└── global_config.py
/db/db/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/db/static/css/foo:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/db/static/js/foo:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataset_env/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/db/traj_db/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/collect_demons/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/db/traj_db/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/db/traj_db/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 |
3 | # Create your tests here.
4 |
--------------------------------------------------------------------------------
/db/traj_db/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
4 | class TrajDbConfig(AppConfig):
5 | name = 'traj_db'
6 |
--------------------------------------------------------------------------------
/db/traj_db/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 | from .models import ArchiveFile, SurrealRoboticsSuiteTrajectory, TrajectoryTag, RLBenchTrajectory
3 | # Register your models here.
4 |
5 | admin.site.register(RLBenchTrajectory)
6 | admin.site.register(SurrealRoboticsSuiteTrajectory)
7 | admin.site.register(ArchiveFile)
8 | admin.site.register(TrajectoryTag)
9 |
--------------------------------------------------------------------------------
/db/db/asgi.py:
--------------------------------------------------------------------------------
1 | """
2 | ASGI config for db project.
3 |
4 | It exposes the ASGI callable as a module-level variable named ``application``.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/3.1/howto/deployment/asgi/
8 | """
9 |
10 | import os
11 |
12 | from django.core.asgi import get_asgi_application
13 |
14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'db.settings')
15 |
16 | application = get_asgi_application()
17 |
--------------------------------------------------------------------------------
/db/db/wsgi.py:
--------------------------------------------------------------------------------
1 | """
2 | WSGI config for db project.
3 |
4 | It exposes the WSGI callable as a module-level variable named ``application``.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/3.1/howto/deployment/wsgi/
8 | """
9 |
10 | import os
11 |
12 | from django.core.wsgi import get_wsgi_application
13 |
14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'db.settings')
15 |
16 | application = get_wsgi_application()
17 |
--------------------------------------------------------------------------------
/db/traj_db/views.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import uuid
4 | from django.shortcuts import render
5 |
6 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../dataset_env'))
7 |
8 | from data_config import get_dataset_args
9 | import file_storage
10 |
11 | config = get_dataset_args()
12 |
def vid(request):
    """Render the index page after generating a video of a random trajectory.

    A random trajectory is fetched through ``file_storage`` and turned into a
    video file before ``vid.html`` is rendered.

    Args:
        request: the incoming Django request.

    Returns:
        The response produced by rendering ``vid.html``.

    Raises:
        FileNotFoundError: if ``file_storage.create_video`` returned a path
            at which no file exists.
    """
    # Only the trajectory is needed here; the episode/task ids are unused.
    trajectory, _episode_id, _task_id = file_storage.get_random_trajectory()
    vid_path = file_storage.create_video(trajectory)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(vid_path):
        raise FileNotFoundError(
            "Trajectory video was not created at {}".format(vid_path))

    # NOTE(review): the template receives an empty context, so it presumably
    # locates the video through a fixed static path -- confirm.
    return render(request, 'vid.html', {})
18 |
--------------------------------------------------------------------------------
/model/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from model_config import get_model_args
7 | from train import train
8 | from test import test_experiment
9 |
# 'cuda' is PyTorch's device type for NVIDIA GPUs. 'gpu' is not a valid device
# string, so torch.device('gpu') raises as soon as CUDA is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 |
def run():
    """Parse the model config, seed the RNGs, optionally train, then test."""
    print("CHANGE exp_name TO NOT OVERRIDE PREV. EXPERIMENTS.")
    args = get_model_args()

    # NumPy and PyTorch are both seeded from the same configured value.
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    if args.is_train:
        train(args)

    # The test rollout runs whether or not training happened in this call.
    test_experiment(args)
22 |
23 | if __name__ == '__main__':
24 | run()
--------------------------------------------------------------------------------
/collect_demons/README.md:
--------------------------------------------------------------------------------
1 | # Data Collection
2 |
3 | This module provides a pit-stop to collect trajectories. The trajectories are stored as pickle files, and their identifiers, paths and other meta-data are stored in the DB. The `collect_by` config defined in `demons_config` specifies how you want to collect the data - by teleoperation, a specific/random policy or an imitation-based policy trained on the data. The imitation policy is defined by an RNN-based Gaussian policy in `imitate_play.py`. Since teleoperation and the collection of random trajectories are environment-specific, they are defined in the environment's corresponding DEG. However, they should be run from the `main.py` here.
4 |
--------------------------------------------------------------------------------
/db/manage.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Django's command-line utility for administrative tasks."""
3 | import os
4 | import sys
5 |
6 |
def main():
    """Prompt for a Django management command on stdin and execute it.

    The command line is read interactively instead of from ``sys.argv`` and
    is handed to Django prefixed with the program name ``manage.py``.

    Raises:
        ImportError: if Django cannot be imported.
    """
    import shlex

    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'db.settings')
    try:
        from django.core.management import execute_from_command_line
    except ImportError as exc:
        raise ImportError(
            "Couldn't import Django. Are you sure it's installed and "
            "available on your PYTHONPATH environment variable? Did you "
            "forget to activate a virtual environment?"
        ) from exc
    command = input("Enter Command : ")
    # shlex.split collapses repeated/trailing whitespace (str.split(" ") would
    # pass empty-string arguments to Django) and honours shell-style quoting.
    execute_from_command_line(["manage.py"] + shlex.split(command))
19 |
20 |
21 | if __name__ == '__main__':
22 | main()
23 |
--------------------------------------------------------------------------------
/db/db/urls.py:
--------------------------------------------------------------------------------
1 | """db URL Configuration
2 |
3 | The `urlpatterns` list routes URLs to views. For more information please see:
4 | https://docs.djangoproject.com/en/3.1/topics/http/urls/
5 | Examples:
6 | Function views
7 | 1. Add an import: from my_app import views
8 | 2. Add a URL to urlpatterns: path('', views.home, name='home')
9 | Class-based views
10 | 1. Add an import: from other_app.views import Home
11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
12 | Including another URLconf
13 | 1. Import the include() function: from django.urls import include, path
14 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
15 | """
16 | from django.contrib import admin
17 | from django.urls import path
18 | from traj_db import views
19 |
20 | urlpatterns = [
21 | path('admin/', admin.site.urls),
22 | path('', views.vid, name="traj_video"),
23 | ]
24 |
--------------------------------------------------------------------------------
/db/templates/vid.html:
--------------------------------------------------------------------------------
1 | {% extends 'base.html' %}
2 | {% load static%}
3 | {% block title %} Trajectory Video {% endblock title %}
4 | {% block content %}
5 |
6 |
7 |
8 |
9 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |

19 |
Initial State
20 |
21 |
22 |
23 |

24 |
Final State
25 |
26 |
27 |
28 | {% endblock content %}
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import pathlib
5 |
def delete_folder(pth):
    """Recursively delete the directory *pth* and everything inside it.

    Symbolic links found inside the tree are removed as links; their targets
    are never descended into or modified.

    Args:
        pth: path (str or ``pathlib.Path``) of the directory to delete.

    Raises:
        OSError: propagated from ``iterdir``/``unlink``/``rmdir``, e.g. when
            *pth* does not exist or is not a directory.
    """
    pth = pathlib.Path(pth)
    for sub in pth.iterdir():
        # is_dir() follows symlinks; without the is_symlink() guard a link to
        # an outside directory would have its target emptied, and the final
        # rmdir() on the link itself would then fail.
        if sub.is_dir() and not sub.is_symlink():
            delete_folder(sub)
        else:
            sub.unlink()
    pth.rmdir()
14 |
def delete_shutil(path):
    """Remove the directory tree rooted at *path* using ``shutil.rmtree``."""
    remove_tree = shutil.rmtree
    remove_tree(path)
17 |
def recreate_dir(path, display_warning=True):
    """Replace *path* with a fresh, empty directory.

    Anything already present at *path* is removed as a directory tree first;
    the directory (and any missing parents) is then created.

    Args:
        path: directory to (re)create.
        display_warning: when true, print a confirmation message.
    """
    already_present = os.path.exists(path)
    if already_present:
        shutil.rmtree(path)
    os.makedirs(path)

    if not display_warning:
        return
    print("=> Directory {} created.".format(path))
25 |
def check_n_create_dir(path, display_warning=True):
    """Create directory *path* (with parents) unless it already exists.

    Args:
        path: directory that must exist after the call.
        display_warning: when true and the directory is already there, print
            a notice that nothing new was created.
    """
    if os.path.isdir(path):
        if display_warning:
            print("=> Directory {} exists (no new directory created).".format(path))
        return

    target = pathlib.Path(path)
    target.mkdir(parents=True, exist_ok=True)
31 |
def str2bool(string):
    """Return True only when *string* is 'true', compared case-insensitively."""
    lowered = string.lower()
    return lowered == 'true'
34 |
def str2list(string):
    """Split a comma-separated string into a list of its pieces.

    A falsy input (empty string or None) gives an empty list. Pieces are
    returned exactly as split, with no whitespace stripping.
    """
    return string.split(",") if string else []
--------------------------------------------------------------------------------
/dataset_env/organize.sh:
--------------------------------------------------------------------------------
1 | # All envs in `/scratch/envs`
2 | # # ---- Surreal Robotics Suite ----
3 | # git clone https://github.com/StanfordVL/robosuite.git
4 | # cd robosuite
5 | # pip3 install -r requirements-extra.txt
6 | # cd ../
7 | # mv ./robosuite ./surreal
8 | # # --------------------------------
9 |
10 | # # ---- USC's Furniture Dataset ----
11 | # git clone https://github.com/clvrai/furniture
12 | # cd furniture
13 | # sudo apt-get install libgl1-mesa-dev libgl1-mesa-glx libosmesa6-dev patchelf libopenmpi-dev libglew-dev
14 | # pip install -r requirements.txt
15 | # pip install -r requirements.dev.txt
16 |
17 | # # Download this https://drive.google.com/drive/folders/1ofnw_zid9zlfkjBLY_gl-CozwLUco2ib
18 | # # and extract to furniture dir
19 | # unzip binary.zip
20 |
21 | # # Virtual Display
22 | # # sudo apt-get install xserver-xorg libglu1-mesa-dev freeglut3-dev mesa-common-dev libxmu-dev libxi-dev
23 | # # sudo nvidia-xconfig -a --use-display-device=None --virtual=1280x1024
24 | # # sudo /usr/bin/X :1 &
25 | # # python -m demo_manual --virtual_display :1
26 | # # ----------------------------------
27 |
--------------------------------------------------------------------------------
/db/templates/base.html:
--------------------------------------------------------------------------------
1 | {% load static %}
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | {% block title %} {% endblock title %}
10 |
11 |
12 |
13 |
21 |
22 |
23 |
24 | {% block content %}
25 | {% endblock content %}
26 |
27 |
--------------------------------------------------------------------------------
/db/README.md:
--------------------------------------------------------------------------------
1 | ## Django Based Database for Trajectory Meta-data.
2 |
3 | ## Setup
4 | To run the usual commands on `manage.py`, run `python3 manage.py` and enter the command when prompted. I had to modify `manage.py` to work this way due to errors involving argv and the global config. To run `runserver`, type the command twice (you will be prompted twice).
5 |
6 | Run the commands in order to setup the DB from scratch.
7 | ```
8 | createsuperuser
9 | makemigrations
10 | migrate
11 | ```
12 |
13 | ## Database
14 | The `traj_db` app provides the database to store the meta-data of the trajectories. The `Trajectory` table is an abstract table. Each environment gets its own table just by inheriting from it (see `traj_db/models.py`). Each trajectory is stored as a pickle file and is identified by a UUID. The table includes other information such as the task being performed, how the trajectory was generated, etc. All the functionality for interacting with the pickle files and the DB is provided in `../dataset_env/file_storage.py`.
15 |
16 | ## Adding New Environments/Tables
17 | To add a table for a new environment, create a new, empty subclass of `Trajectory` and register it in `traj_db/admin.py`. That's it! Follow `../dataset_env`'s README for other config-modifications.
18 |
19 | ## Views
20 | The `admin/` portal provides a good interface to the database, but usage of functions in `../dataset_env/file_storage.py` is recommended. The index page generates a video of a random trajectory of the current environment and displays it.
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Boilerplate For Data Driven Robotics
2 | & a Pytorch implementation of [Learning Latent Plans from Play](https://learning-from-play.github.io/).
3 |
4 | This repo is supposed to provide organized & scalable experimentation of data-driven robotics learning. You can adapt it to your own model and environment with minor modifications.
5 |
6 | ## Organization
7 | ### Modules
8 | This setup consists of a database (`db/`), inspired by [[1]](https://arxiv.org/abs/1909.12200), which stores meta-data of the collected trajectories, and a light web-app which renders a video of a trajectory. The DEG module (`dataset_env/`) provides easy adaptation to various environments, dataloaders (`deg_base.py`), and easy functionality to interact with the DB and store/retrieve trajectories (`file_storage.py`) - all bundled up. The current implementation includes support for [RLBench](https://github.com/stepjam/RLBench/) and (older) [Robosuite](https://github.com/ARISE-Initiative/robosuite) environments. The collection module (`collect_demons/`) provides data-collection mechanisms such as teleoperation and imitation policies. Every new model can have its own directory; the current `model/` contains a Pytorch implementation of [LfP](https://learning-from-play.github.io/). The training and testing code is defined in `model/` too.
9 |
10 | Additional information about each module is provided in their respective READMEs.
11 | ### Configs
12 | Config common to all the modules is defined in `global_config.py`. Each of the other modules has its own config file (`*_config.py`) which adds to the global config. The config system is designed to update automatically on minor edits (e.g. a change in `env` changes all the paths and other env-related properties).
13 |
--------------------------------------------------------------------------------
/dataset_env/README.md:
--------------------------------------------------------------------------------
1 | # Datasets & Environment Groups
2 |
3 | `DataEnvGroup`s (DEGs) provide a common window to interact with environments and their collected datasets. The abstract class (`deg_base.py`) provides most of the common functionality and has to be overridden to adapt to a new environment. Each environment has its own DEG file which provides definitions for the abstract methods and properties - see `rlbench_deg.py` and `surreal_deg.py`. Use the DEG to define & get environments, datasets, dataloaders and observation/action spaces with a common environment-agnostic syntax.
4 |
5 | The data-based configs and the link to the DB are defined in `data_config.py`. The environment is specified in `../global_config.py`, and modifying it is reflected in the data configs. File- and DB-related functionality is defined in `file_storage.py`. Visual augmentations to improve performance, inspired by [RAD](https://mishalaskin.github.io/rad/), are provided in `data_aug.py`.
6 |
7 | ## Adding New Environments
8 |
9 | All the envs are stored in a directory outside the repo (defined by `ENV_PATH` in each of the DEG files) to keep the repo clean. If you want to modify the original environment, clone it into this directory. All the cloning and renaming code goes in `organize.sh`.
10 |
11 | To create a DEG for a new environment, create a new `envname_deg.py` file and inherit the `DataEnvGroup` class. Import the actual environment class and provide definitions for environment-specific abstract methods (like `get_env`, `teleoperate` etc.) and properties (observation space & keys). See `rlbench_deg.py` for example.
12 |
13 | After adding the DEG file, create a table for the environment in `db/` (refer to its README). Add the table info in `taj_db_dict` & `env2keys()` in `data_config.py` and add the DEG info in `env2deg()` in `../global_config.py`. That's it! The rest of the functionality, properties and configs will adapt automatically to the change in `env` in `../global_config.py`.
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | sftp-config.json
3 | db/static/media
4 | web_db/static/media
5 | dataset_env/surreal
6 | dataset_env/furniture
7 | dataset_env/rlbench
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
--------------------------------------------------------------------------------
/model/model_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../'))
6 |
7 | import utils
8 | from global_config import *
9 |
def get_model_args():
    """Build the model config on top of the global parser and prepare its dirs.

    Adds the model-specific arguments, parses the command line, resolves the
    env-dependent fields (`env_args`, `deg`), appends an env/env_type/exp_name
    suffix to the save paths, and makes sure the directories exist.

    Returns:
        The parsed argparse namespace with the derived fields filled in.
    """
    parser = get_global_parser()
    add = parser.add_argument
    to_bool = utils.str2bool

    # NOTE: 'SURREAL' is a placeholder. The deg is set according to global_config.env -> see below. v
    add('--deg', type=env2deg, default='SURREAL')
    add('--seed', '-s', type=int, default=0)
    add('--is_train', type=to_bool, default=False)
    add('--resume', type=to_bool, default=False)

    add('--num_workers', type=int, default=1)
    add('--save_graphs', type=to_bool, default=False)
    add('--tensorboard_path', type=str, default=os.path.join(DATA_DIR, 'runs/tensorboard/'))
    add('--models_save_path', type=str, default=os.path.join(DATA_DIR, 'runs/models/'))
    add('--save_interval_epoch', type=int, default=10)

    add('--max_epochs', type=int, default=4)  # 50
    add('--batch_size', type=int, default=2)
    add('--learning_rate', type=float, default=2e-4)
    add('--n_test_evals', type=int, default=10)
    add('--max_test_timestep', type=int, default=40)

    add('--beta', type=float, default=0.01)
    add('--visual_state_dim', type=int, default=64)
    add('--combined_state_dim', type=int, default=94)
    add('--goal_dim', type=int, default=32)
    add('--latent_dim', type=int, default=256)

    config = parser.parse_args()
    config.env_args = env2args(config.env)
    config.deg = env2deg(config.env)

    # Models and tensorboard logs share the same per-experiment sub-directory.
    run_dir = 'model_{}_{}_{}/'.format(config.env, config.env_type, config.exp_name)
    config.models_save_path = os.path.join(config.models_save_path, run_dir)
    config.tensorboard_path = os.path.join(config.tensorboard_path, run_dir)
    config.data_path = os.path.join(config.data_path, '{}_{}/'.format(config.env, config.env_type))

    # A fresh (non-resumed) training run wipes the previous outputs; any other
    # mode only makes sure the directories exist.
    starting_fresh = config.is_train and not config.resume
    prepare_dir = utils.recreate_dir if starting_fresh else utils.check_n_create_dir
    for run_path in (config.models_save_path, config.tensorboard_path):
        prepare_dir(run_path, config.display_warnings)

    utils.check_n_create_dir(config.data_path, config.display_warnings)

    return config
--------------------------------------------------------------------------------
/model/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy
3 | import torch
4 |
5 | from models import *
6 | #from render_browser import render_browser
7 |
# 'cuda' is PyTorch's device type for NVIDIA GPUs. 'gpu' is not a valid device
# string, so torch.device('gpu') raises as soon as CUDA is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9 |
10 | #@render_browser
def test_experiment(config):
    """Roll out the saved modules in the environment for several episodes.

    Builds the four modules, loads their weights from
    ``config.models_save_path`` and runs ``config.n_test_evals`` episodes.
    In each episode a random goal from the DEG is encoded once, then actions
    are taken until the env reports ``done`` or the step counter exceeds
    ``config.max_test_timestep``.

    Args:
        config: parsed model config; must provide ``deg``, the module
            dimensions, ``models_save_path``, ``n_test_evals`` and
            ``max_test_timestep``.
    """
    deg = config.deg()
    env = deg.get_env()
    vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key]
    act_dim = deg.action_space

    with torch.no_grad():
        perception_module = PerceptionModule(vobs_dim, dof_dim, config.visual_state_dim).to(device)
        visual_goal_encoder = VisualGoalEncoder(config.visual_state_dim, config.goal_dim).to(device)
        plan_proposer = PlanProposerModule(config.combined_state_dim, config.goal_dim, config.latent_dim).to(device)
        control_module = ControlModule(act_dim, config.combined_state_dim, config.goal_dim, config.latent_dim).to(device)

        checkpoints = (
            (perception_module, 'perception.pth'),
            (visual_goal_encoder, 'visual_goal.pth'),
            (plan_proposer, 'plan_proposer.pth'),
            (control_module, 'control_module.pth'),
        )
        for module, filename in checkpoints:
            # map_location lets weights saved on a GPU load on a CPU-only host.
            state_dict = torch.load(os.path.join(config.models_save_path, filename),
                                    map_location=device)
            module.load_state_dict(state_dict)

        obvs = env.reset()
        for _ in range(config.n_test_evals):
            obvs = env.reset()
            # Inputs must live on the same device as the modules.
            goal = torch.from_numpy(deg.get_random_goal()).float().to(device)
            # NOTE(review): reshape reinterprets the buffer rather than moving
            # axes; if the image is (H, W, C), permute(2, 0, 1) would be the
            # channel-first conversion -- confirm against the training code.
            goal = goal.reshape(1, goal.shape[2], goal.shape[0], goal.shape[1])
            # NOTE(review): called with one argument here but two in the loop
            # below -- confirm PerceptionModule supports both forms.
            goal = perception_module(goal)
            goal, _, _, _ = visual_goal_encoder(goal)
            t, done = 0, False

            while (not done) and t <= config.max_test_timestep:
                # TODO : Figure out way to set done true when goal is reached
                # .float() belongs on the tensor returned by from_numpy; the
                # original line had unbalanced parentheses and did not parse.
                visual_obv = torch.from_numpy(deg._get_obs(obvs, deg.vis_obv_key)).float().to(device)
                dof_obv = torch.from_numpy(deg._get_obs(obvs, deg.dof_obv_key)).float().to(device)
                visual_obv = visual_obv.reshape(1, visual_obv.shape[2], visual_obv.shape[0], visual_obv.shape[1])
                dof_obv = dof_obv.reshape(1, dof_obv.shape[0])
                state = perception_module(visual_obv, dof_obv)
                z_p, _, _, _ = plan_proposer(state, goal)

                action, _ = control_module.step(state, goal, z_p)
                # NOTE(review): if step() returns a tensor on a GPU, the env may
                # need action[0].cpu().numpy() -- confirm.
                obvs, _, done, _ = env.step(action[0])
                #yield env.render(mode='rgb_array')
                env.render()
                t += 1
50 |
51 | if __name__ == '__main__':
52 | from model_config import get_model_args
53 | config = get_model_args()
54 | test_experiment(config)
--------------------------------------------------------------------------------
/collect_demons/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import torch
5 |
6 | from demons_config import get_demons_args
7 |
8 | if __name__ == "__main__":
9 | demon_config = get_demons_args()
10 | deg = demon_config.deg()
11 | if demon_config.collect_by == 'teleop':
12 | # TODO remove this
13 | # tasks = ['beat_the_buzz', 'block_pyramid', 'change_channel', 'change_clock', 'close_box', 'close_door', 'close_drawer', 'close_fridge', 'close_grill', 'close_jar', 'close_laptop_lid', 'close_microwave', 'empty_container', 'empty_dishwasher', 'get_ice_from_fridge', 'hang_frame_on_hanger', 'hannoi_square', 'hit_ball_with_queue', 'hockey', 'insert_usb_in_computer', 'lamp_off', 'lamp_on', 'light_bulb_in', 'light_bulb_out', 'meat_off_grill', 'meat_on_grill', 'move_hanger', 'open_box', 'open_door', 'open_drawer', 'open_fridge', 'open_grill', 'open_jar', 'open_microwave', 'open_oven', 'open_window', 'open_wine_bottle', 'phone_on_base', 'pick_and_lift', 'pick_up_cup', 'place_cups', 'place_hanger_on_rack', 'place_shape_in_shape_sorter', 'play_jenga', 'plug_charger_in_power_supply', 'pour_from_cup_to_cup', 'press_switch', 'push_button', 'push_buttons', 'put_books_on_bookshelf', 'put_bottle_in_fridge', 'put_groceries_in_cupboard', 'put_item_in_drawer', 'put_knife_in_knife_block', 'put_knife_on_chopping_board', 'put_money_in_safe', 'put_plate_in_colored_dish_rack', 'put_rubbish_in_bin', 'put_shoes_in_box', 'put_toilet_roll_on_stand', 'put_tray_in_oven', 'put_umbrella_in_umbrella_stand', 'reach_and_drag', 'reach_target', 'remove_cups', 'scoop_with_spatula', 'screw_nail', 'set_the_table', 'setup_checkers', 'slide_block_to_target', 'slide_cabinet_open', 'slide_cabinet_open_and_place_cups', 'solve_puzzle', 'stack_blocks', 'stack_cups', 'stack_wine', 'straighten_rope', 'sweep_to_dustpan', 'take_cup_out_from_cabinet', 'take_frame_off_hanger', 'take_item_out_of_drawer', 'take_lid_off_saucepan', 'take_money_out_safe', 'take_off_weighing_scales', 'take_plate_off_colored_dish_rack', 'take_shoes_out_of_box', 'take_toilet_roll_off_stand', 'take_tray_out_of_oven', 'take_umbrella_out_of_umbrella_stand', 'take_usb_out_of_computer', 'toilet_seat_down', 'toilet_seat_up', 'turn_oven_on', 'turn_tap', 'tv_off', 'tv_on', 'unplug_charger', 'water_plants', 'weighing_scales', 'wipe_desk']
14 | # for i, task in enumerate(tasks):
15 | # print(i, task)
16 | # try :
17 | deg.teleoperate(demon_config) #, task)
18 | # except :
19 | # print("Couldn't collect demos")
20 | elif demon_config.collect_by == 'random':
21 | deg.random_trajectory(demon_config)
22 | elif demon_config.collect_by == 'imitation':
23 | #NOTE : NOT TESTED
24 | import imitate_play
25 |
26 | if demon_config.train_imitation:
27 | imitate_play.train_imitation(demon_config)
28 | imitate_play.imitate_play()
--------------------------------------------------------------------------------
/collect_demons/demons_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../'))
6 |
7 | import utils
8 | from global_config import *
9 |
def get_demons_args():
    """Build the data-collection config on top of the global parser.

    Adds the collection and imitation-model arguments, parses the command
    line, resolves the env-dependent fields (`env_args`, `deg`), appends the
    env-specific suffixes to the paths, and makes sure the directories exist.

    Returns:
        The parsed argparse namespace with the derived fields filled in.
    """
    parser = get_global_parser()
    add = parser.add_argument
    to_bool = utils.str2bool

    # NOTE: 'SURREAL' is a placeholder. The deg is set according to global_config.env -> see below. v
    add('--deg', type=env2deg, default='SURREAL')
    add("--collect_by", type=str, default='teleop', choices=['teleop', 'imitation', 'expert', 'policy', 'exploration', 'random'])
    add("--device", type=str, default="keyboard", choices=["keyboard", "spacemouse"])
    add("--collect_freq", type=int, default=1)
    add("--flush_freq", type=int, default=25)  # NOTE : RAM Issues, change here : 75
    add("--break_traj_success", type=to_bool, default=True)
    add("--n_runs", type=int, default=10,  # 10
        help="no. of runs of traj collection, affective when break_traj_success = False")

    # Imitation model
    add('--resume', type=to_bool, default=False)
    add('--train_imitation', type=to_bool, default=False)
    add('--models_save_path', type=str, default=os.path.join(DATA_DIR, 'runs/imitation-models/'))
    add('--tensorboard_path', type=str, default=os.path.join(DATA_DIR, 'runs/imitation-tensorboard/'))
    add('--load_models', type=to_bool, default=True)
    add('--use_model_perception', type=to_bool, default=True)
    add('--n_gen_traj', type=int, default=200, help="Number of trajectories to generate by imitation")

    config = parser.parse_args()
    config.env_args = env2args(config.env)
    config.deg = env2deg(config.env)

    # Data and models are keyed by env/env_type; tensorboard logs additionally
    # carry the experiment name.
    env_dir = '{}_{}/'.format(config.env, config.env_type)
    exp_dir = '{}_{}_{}/'.format(config.env, config.env_type, config.exp_name)
    config.data_path = os.path.join(config.data_path, env_dir)
    config.models_save_path = os.path.join(config.models_save_path, env_dir)
    config.tensorboard_path = os.path.join(config.tensorboard_path, exp_dir)

    # A fresh (non-resumed) imitation training run wipes the previous outputs;
    # any other mode only makes sure the directories exist.
    starting_fresh = config.train_imitation and not config.resume
    prepare_dir = utils.recreate_dir if starting_fresh else utils.check_n_create_dir
    for run_path in (config.models_save_path, config.tensorboard_path):
        prepare_dir(run_path, config.display_warnings)

    utils.check_n_create_dir(config.data_path, config.display_warnings)

    return config
49 |
if __name__ == '__main__':
    # Smoke test: build the demonstration config and show the resolved model directory.
    print(get_demons_args().models_save_path)
53 |
--------------------------------------------------------------------------------
/global_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import datetime
5 | import argparse
6 |
7 | import utils
8 |
# Directory containing this file; project-relative paths are derived from it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Make the modules inside dataset_env/ importable by their bare names.
sys.path.append(os.path.join(BASE_DIR, 'dataset_env/'))

# Root folder under which data files and training runs are stored.
# Defaults to the original location; set ROBOTICS_DATA_DIR to relocate it.
DATA_DIR = os.environ.get('ROBOTICS_DATA_DIR', '/scratch/robotics_data')
# Wall-clock time at import, formatted so it is safe to embed in file names.
TIME_STAMP = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
14 |
15 | # NOTE : RLBench - Run `Xvfb :DISP_NUM -screen 0 1024x768x24 +extension GLX +render -noreset &`
16 | # `export DISPLAY=:DISP_NUM`, where `DISP_NUM` >= 99 (whichever is free)
17 | # NOTE : RoboSuite - Run `xvfb-run -a -s "-screen 0 1400x900x24" zsh` &
18 | # `export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so` for rendering.
19 |
def get_global_parser():
    ''' Build the argument parser holding the options common to all modules.
        NOTE: The module-level configs append env / env_type (and exp_name for some)
              to the path options after parsing.
            - see model/model_config.py for example.

        + Returns:
            - argparse.ArgumentParser with the global options registered (not yet parsed).
    '''
    global_parser = argparse.ArgumentParser(
        "Data driven robotics", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    option_specs = (
        ('--env', dict(type=str, default='RLBENCH')),  # or SURREAL
        ('--env_type', dict(type=str, default='close_drawer-vision-v0')),  # or SawyerPickPlace
        # NOTE : placeholder, the module configs overwrite it from --env after parsing.
        ('--env_args', dict(type=env2args, default='RLBENCH')),
        ('--exp_name', dict(type=str, default='v0.5')),
        ('--data_path', dict(type=str, default=os.path.join(DATA_DIR, 'data_files/saved_data/'))),
        ('--display_warnings', dict(type=utils.str2bool, default=False)),
    )
    for flag, spec in option_specs:
        global_parser.add_argument(flag, **spec)

    return global_parser
37 |
def env2args(env_name):
    ''' Return the keyword arguments used to construct the given environment.

        + Arguments:
            - env_name: 'SURREAL' or 'RLBENCH'. Matched case-insensitively, like the
              other env2* helpers.

        + Returns:
            - A freshly built dict, so callers may mutate it freely.

        + Raises:
            - KeyError: if env_name is not a known environment.
    '''
    # NOTE - WARNING : Changing camera_height, camera_width changes the camera_obs dim.
    surreal_args = dict(has_renderer=True, has_offscreen_renderer=True, ignore_done=True, use_camera_obs=True,
                        camera_height=256, camera_width=256, camera_name='agentview', use_object_obs=False,
                        reward_shaping=True)

    # NOTE - WARNING : observation_mode='vision' is very heavy
    rlbench_args = dict(use_gym=False, observation_mode='left_shoulder_rgb', vis_obv_key='left_shoulder_rgb',
                        render_mode='rgb_array', keyboard_teleop=False, combine_action_space=True)

    env2args_dict = {'SURREAL': surreal_args, 'RLBENCH': rlbench_args}
    # Non-string input (e.g. None) is looked up untouched so it still fails with KeyError.
    key = env_name.upper() if isinstance(env_name, str) else env_name
    return env2args_dict[key]
49 |
def env2deg(env_name):
    ''' Return the DataEnvGroup class (not an instance) for the given environment.

        The simulator-specific module is imported only for the requested environment,
        so using one environment does not require the other's package to be importable.

        NOTE: *IMPORTANT* - Use this in model_config & demons_config ONLY.
              *DEADLOCK* : can lead to deadlock if an object is created in global_config.

        + Arguments:
            - env_name: 'SURREAL' or 'RLBENCH' (case-insensitive), or None.

        + Returns:
            - The matching DataEnvGroup subclass, or None when env_name is None.

        + Raises:
            - KeyError: if env_name is not a known environment.
    '''
    if env_name is None:
        return None

    key = env_name.upper()
    if key == 'SURREAL':
        from dataset_env.surreal_deg import SurrealDataEnvGroup
        return SurrealDataEnvGroup
    if key == 'RLBENCH':
        from dataset_env.rlbench_deg import RLBenchDataEnvGroup
        return RLBenchDataEnvGroup
    raise KeyError(key)
61 |
if __name__ == '__main__':
    # Manual smoke test of the global parser and the env -> DataEnvGroup lookup.
    print("=> Testing global_config.py")
    parsed = get_global_parser().parse_args()
    print(parsed.env)
    # NOTE(review): env2deg returns a class; confirm action_space is readable without instantiating.
    print(env2deg('SURREAL').action_space)
68 |
--------------------------------------------------------------------------------
/dataset_env/data_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import django
5 |
# Make the repository root (utils, global_config) and the Django project folder
# (db.settings, traj_db) importable from this sub-directory.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../db/'))

# Django has to be configured before any model module is imported, which is why
# the project imports below come after django.setup().
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "db.settings")
django.setup()

import utils
from global_config import *
from traj_db.models import SurrealRoboticsSuiteTrajectory, RLBenchTrajectory

# Upper-cased environment name -> Django model (table) storing its trajectories.
taj_db_dict = {'SURREAL' : SurrealRoboticsSuiteTrajectory,'RLBENCH' : RLBenchTrajectory}
17 |
def get_dataset_args():
    ''' Parse the dataset options on top of the global ones and resolve derived values.

        After parsing, the env-dependent entries (env_args, traj_db, obv_keys) are set
        from `--env`, `episode_type` is mapped to its database label, the data / archive
        paths get `<env>_<env_type>/` appended and the needed directories are created.

        + Returns:
            - argparse.Namespace with the resolved configuration.
    '''
    parser = get_global_parser()

    # NOTE: 'SURREAL' is a placeholder. The dbs are set according to global_config.env -> see below.
    parser.add_argument('--traj_db', type=env2TrajDB, default='SURREAL')
    parser.add_argument('--obv_keys', type=env2keys, default='SURREAL')
    # type=bool would turn every non-empty string (including "False") into True.
    parser.add_argument('--get_by_task_id', type=utils.str2bool, default=False)
    parser.add_argument('--archives_path', type=str, default=os.path.join(DATA_DIR, 'data_files/archives'))
    # argparse validates `choices` AFTER applying `type`, so the short name is parsed as a
    # plain string here and converted to the database label once parsing is done.
    parser.add_argument('--episode_type', type=str, default='teleop',
                        choices=['teleop', 'imitation', 'expert', 'policy', 'exploration', 'random'])
    parser.add_argument('--media_dir', type=str, default=os.path.join(BASE_DIR, 'db/static/media/'))
    parser.add_argument('--vid_path', type=str, default='vid.mp4')
    parser.add_argument('--fps', type=int, default=30)
    parser.add_argument('--vocab_path', type=str, default=os.path.join(DATA_DIR, 'data_files/vocab.pkl'))

    parser.add_argument('--data_agumentation', type=utils.str2bool, default=False)  # WARNING : Don't use now, UNSTABLE.
    parser.add_argument('--augs', type=utils.str2list, default='crop', help='See others in data_aug.py')

    config = parser.parse_args()
    config.env_args = env2args(config.env)
    config.traj_db = env2TrajDB(config.env)
    config.obv_keys = env2keys(config.env)
    config.episode_type = ep_type(config.episode_type)
    config.data_path = os.path.join(config.data_path, '{}_{}/'.format(config.env, config.env_type))
    config.archives_path = os.path.join(config.archives_path, '{}_{}/'.format(config.env, config.env_type))
    config.vid_path = os.path.join(config.media_dir, config.vid_path)

    utils.check_n_create_dir(config.data_path, config.display_warnings)
    utils.check_n_create_dir(config.archives_path, config.display_warnings)
    utils.check_n_create_dir(config.media_dir, config.display_warnings)

    return config
48 |
def env2TrajDB(string):
    ''' Map an environment name (case-insensitive) to its trajectory model via taj_db_dict.
        Returns None when `string` is None; raises KeyError for unknown names.
    '''
    return None if string is None else taj_db_dict[string.upper()]
53 |
def ep_type(string):
    ''' Translate a short episode name (case-insensitive) into the label stored in the database.
        Raises KeyError for names outside the table below.
    '''
    short_names = ('teleop', 'imitation', 'expert', 'policy', 'exploration', 'random')
    db_labels = ('EPISODE_ROBOT_PLAY', 'EPISODE_ROBOT_IMITATED', 'EPISODE_ROBOT_EXPERT',
                 'EPISODE_ROBOT_POLICY', 'EPISODE_ROBOT_EXPLORED', 'EPISODE_ROBOT_RANDOM')
    label_of = dict(zip(short_names, db_labels))
    return label_of[string.lower()]
59 |
def env2keys(string):
    ''' Return the observation keys ('vis_obv_key', 'dof_obv_key') used by an environment.
        The name is matched case-insensitively; unknown names raise KeyError.
    '''
    env_name = string.upper()
    keys_per_env = dict(
        RLBENCH=dict(vis_obv_key='left_shoulder_rgb', dof_obv_key='state'),
        SURREAL=dict(vis_obv_key='image', dof_obv_key='robot-state'),
    )
    return keys_per_env[env_name]
64 |
if __name__ == '__main__':
    # Manual smoke test: parse the dataset config and show the resolved trajectory table.
    print("=> Testing data_config.py")
    print(get_dataset_args().traj_db)
69 |
--------------------------------------------------------------------------------
/dataset_env/deg_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import torch
5 | from torch.nn.utils.rnn import pad_sequence
6 | from torch.utils.data import Dataset, DataLoader
7 |
8 | import data_aug as rad
9 | from data_config import get_dataset_args
10 | from file_storage import get_trajectory, get_random_trajectory
11 |
class DataEnvGroup(object):
    ''' Base class pairing an environment with the stored trajectories collected from it.

        + NOTE : Create subclass for every environment, eg.
            Check `assert self.env_name == 'ENV_NAME'`
    '''
    def __init__(self, get_episode_type=None):
        ''' + Arguments:
                - get_episode_type: Get data of a particular episode_type (teleop, imitation, etc.)
                    > Default : None, get data with any episode_type.
        '''
        self.config = get_dataset_args()
        self.env_name = self.config.env
        self.env_type = self.config.env_type
        self.episode_type = get_episode_type
        self.traj_dataset = self.TrajDataset(self.episode_type, self.config)

        # NOTE : Environment dependent properties
        #        Subclasses must set these after calling this constructor.
        self.vis_obv_key = None
        self.dof_obv_key = None

        self.obs_space = None
        self.action_space = None

    def get_env(self):
        ''' Environment-specific; must be provided by the subclass. '''
        raise NotImplementedError

    def _get_obs(self, obs, key):
        ''' Environment-specific; must be provided by the subclass. '''
        raise NotImplementedError

    def teleoperate(self, demon_config, task=None):
        ''' Environment-specific; must be provided by the subclass. '''
        raise NotImplementedError

    def random_trajectory(self, demons_config):
        ''' Environment-specific; must be provided by the subclass. '''
        raise NotImplementedError

    def get_random_goal(self):
        ''' Return the last visual observation of a randomly chosen stored trajectory. '''
        assert issubclass(type(self), DataEnvGroup) is True  # NOTE : might raise error - remove if so
        goal = get_random_trajectory()[0][self.vis_obv_key][-1]
        return goal

    class TrajDataset(Dataset):
        ''' torch Dataset that loads stored trajectories by index. '''
        def __init__(self, episode_type, config):
            self.episode_type = episode_type
            self.config = config

        def __len__(self):
            if self.episode_type is None:
                return self.config.traj_db.objects.count() - 1  # HACK
            else:
                return self.config.traj_db.objects.filter(episode_type=self.episode_type).count()

        def __getitem__(self, idx):
            trajectory = get_trajectory(index=idx, episode_type=self.episode_type)
            if self.config.data_agumentation:
                # The dataset has no vis_obv_key attribute of its own (that lives on the
                # enclosing DataEnvGroup instance), so the key is taken from the config.
                vis_key = self.config.obv_keys['vis_obv_key']
                trajectory[vis_key] = rad.apply_augs(trajectory[vis_key], self.config)
            return trajectory

    def _collate_wrap(self, remove_task_state=False):
        ''' Build the collate_fn that zero-pads a batch of trajectories to a common length.

            + Arguments:
                - remove_task_state: if True, keep only the first obs_space[dof_obv_key]
                  columns of the dof observations so all items share one width.
        '''
        # NOTE - if batch_size > 1, it removes the task-dependent states to get a common dim.
        def pad_collate(batch):
            assert None not in [self.vis_obv_key, self.dof_obv_key]
            tr_vobvs = [torch.from_numpy(b[self.vis_obv_key]) for b in batch]
            tr_dof = [torch.from_numpy(b[self.dof_obv_key][:, : self.obs_space[self.dof_obv_key]]
                                       if remove_task_state else b[self.dof_obv_key]) for b in batch]
            tr_actions = [torch.from_numpy(b['action']) for b in batch]

            # Shorter trajectories are padded with zeros along the time axis.
            tr_vobvs_pad = pad_sequence(tr_vobvs, batch_first=True, padding_value=0)
            tr_dof_pad = pad_sequence(tr_dof, batch_first=True, padding_value=0)
            tr_actions_pad = pad_sequence(tr_actions, batch_first=True, padding_value=0)

            padded_batch = {self.vis_obv_key: tr_vobvs_pad, self.dof_obv_key: tr_dof_pad, 'action': tr_actions_pad}
            return padded_batch

        return pad_collate

    def get_traj_dataloader(self, batch_size, num_workers=1, shuffle=True):
        ''' Return a DataLoader over the stored trajectories.
            Task-dependent state columns are dropped whenever batch_size != 1.
        '''
        dataloader = DataLoader(dataset=self.traj_dataset, batch_size=batch_size,
                                shuffle=shuffle, num_workers=num_workers,
                                collate_fn=self._collate_wrap(remove_task_state=(batch_size != 1)))
        return dataloader
90 |
--------------------------------------------------------------------------------
/db/db/settings.py:
--------------------------------------------------------------------------------
1 | """
2 | Django settings for db project.
3 |
4 | Generated by 'django-admin startproject' using Django 3.1.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/3.1/topics/settings/
8 |
9 | For the full list of settings and their values, see
10 | https://docs.djangoproject.com/en/3.1/ref/settings/
11 | """
12 | import os
13 | from pathlib import Path
14 |
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve(strict=True).parent.parent


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/3.1/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
# Set DJANGO_SECRET_KEY in the environment for anything but local development;
# the committed value is only a fallback so existing setups keep working.
SECRET_KEY = os.environ.get('DJANGO_SECRET_KEY', 'n71h#3gzx7m%#8*%gm!fgl)$*()rk0l%f9aj7g1a4l5-syjx*)')

# SECURITY WARNING: don't run with debug turned on in production!
# Defaults to True as before; set DJANGO_DEBUG=0 (or false/no/off) to disable.
DEBUG = os.environ.get('DJANGO_DEBUG', 'True').strip().lower() in ('1', 'true', 'yes', 'on')

ALLOWED_HOSTS = ['10.24.6.58', '127.0.0.1', 'localhost']
29 |
30 |
31 | # Application definition
32 |
INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    # Project app holding the trajectory tables (see traj_db/models.py).
    'traj_db.apps.TrajDbConfig',
    # django-polymorphic: provides the PolymorphicModel base used by traj_db.
    'polymorphic',
]

MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'db.urls'

# Project-level template folder (templates/ next to manage.py).
TEMPLATE_DIR = os.path.join(BASE_DIR, "templates")

TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'DIRS': [TEMPLATE_DIR],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        },
    },
]

WSGI_APPLICATION = 'db.wsgi.application'
75 |
76 |
77 | # Database
78 | # https://docs.djangoproject.com/en/3.1/ref/settings/#databases
79 |
80 |
# NOTE: Might migrate to Postgres in the future. Using SQLite 3 for speed for now.
# https://www.digitalocean.com/community/tutorials/how-to-use-postgresql-with-your-django-application-on-ubuntu-14-04
# https://www.vphventures.com/how-to-migrate-your-django-project-from-sqlite-to-postgresql/

# Single SQLite database file stored next to manage.py.
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': BASE_DIR / 'db.sqlite3',
    }
}
91 |
92 |
93 | # Password validation
94 | # https://docs.djangoproject.com/en/3.1/ref/settings/#auth-password-validators
95 |
AUTH_PASSWORD_VALIDATORS = [
    {
        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
    },
]


# Internationalization
# https://docs.djangoproject.com/en/3.1/topics/i18n/

LANGUAGE_CODE = 'en-us'

TIME_ZONE = 'UTC'

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/3.1/howto/static-files/

STATIC_URL = '/static/'
# Source folder for static assets; dataset_env's default --media_dir (db/static/media/) lives under it.
STATICFILES_DIRS = [os.path.join(BASE_DIR, 'static'), ]
# Target folder for `collectstatic`.
STATIC_ROOT = os.path.join(BASE_DIR, 'staticfiles')
132 |
--------------------------------------------------------------------------------
/db/traj_db/models.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from django.db import models
3 | from django.conf import settings
4 | from polymorphic.models import PolymorphicModel
5 |
class TrajectoryTag(models.Model):
    '''Table of tags that can be attached to episodes.

    + Attributes:
        - name: Human readable tag name.

    object.trajectory_set : The episodes that have been annotated with this tag
                            (reverse side of the `Trajectory.tag` foreign key).
    '''
    name = models.CharField(max_length=75)

    def __str__(self):
        return self.name
18 |
class Trajectory(PolymorphicModel):
    ''' Base table describing the trajectory schema.
        NOTE : Treated as abstract by convention - create a new subclass/table for each Env.
               It is a concrete PolymorphicModel (the `Meta.abstract` flag below is commented out).
        NOTE : EACH environment has its own separate table to store data.

        REFER: https://django-polymorphic.readthedocs.io/en/latest/quickstart.html
               https://stackoverflow.com/questions/30343212/foreignkey-field-related-to-abstract-model-in-django

        + Attributes:
            - episode_id: A UUID to identify the trajectory (generated with uuid.uuid4).
            - traj_count: Auto-incremented primary key; the index number of this trajectory.
            - env_id: An identifier for the environment. Usually, the environment name.
            - task_id: Stores config.env_type - corresponding to the behavior in the trajectory.
            - traj_steps: Number of time-steps in the trajectory.
            - data_path: Filename holding the trajectory data for this episode.
            - is_archived: Whether the trajectory is present in a tar.gz file or not.
            - time_stamp: Set automatically when the row is created.
                NOTE(review): declared as a TimeField, which keeps no date - confirm
                whether a DateTimeField was intended.
            - episode_type: The type of policy that generated the episode. Possible values are:
                > `EPISODE_ROBOT_IMITATED`: Generated by a script which imitated human-generated play data.
                > `EPISODE_ROBOT_PLAY`: Generated from teleoperated human demonstration as play data.
                > `EPISODE_ROBOT_EXPERT` : Generated from expert human demonstration to solve a particular task.
                > `EPISODE_ROBOT_POLICY`: Generated from a policy aimed to solve a particular task.
                > `EPISODE_ROBOT_EXPLORED` : Data generated from (random) exploration.
                > `EPISODE_ROBOT_RANDOM` : For testing.
            - tag: (TrajectoryTag) Optional single tag attached to this episode; set to NULL if the tag is deleted.
    '''
    # class Meta: # Not needed as PolymorphicModel
    #     abstract = True

    # (stored value, human readable label) pairs for `episode_type`.
    EPISODE_TYPE_CHOICES = [
        ('EPISODE_ROBOT_IMITATED', 'Script Generated by Imitation of Play-Data'),
        ('EPISODE_ROBOT_PLAY', 'Human Demonstration: Play-Data'),
        ('EPISODE_ROBOT_EXPERT', 'Human Demonstration: Expert, Task Specific'),
        ('EPISODE_ROBOT_POLICY', 'Collected by Task Specific Policy'),
        ('EPISODE_ROBOT_EXPLORED', 'Generated by (Random) Exploration'),
        ('EPISODE_ROBOT_RANDOM', 'For Testing')
    ]

    episode_id = models.UUIDField(default=uuid.uuid4, editable=False)
    traj_count = models.AutoField(primary_key=True)
    # NOTE: In Utopia, this should be defined in subclasses.
    #       But => both sub & super have diff. PKs - Error.
    env_id = models.CharField(max_length=50)
    task_id = models.CharField(max_length=50)
    traj_steps = models.IntegerField()
    time_stamp = models.TimeField(auto_now_add=True, auto_now=False)
    data_path = models.CharField(max_length=1024)
    is_archived = models.BooleanField(default=False)
    episode_type = models.CharField(max_length=30, choices=EPISODE_TYPE_CHOICES, default='EPISODE_ROBOT_PLAY')
    tag = models.ForeignKey(TrajectoryTag, on_delete=models.SET_NULL, null=True)

    def __str__(self):
        return "{} - {} : {}".format(self.env_id, self.task_id, self.episode_id)
71 |
class RLBenchTrajectory(Trajectory):
    ''' Table holding the trajectories collected in the RLBench environment.
        Adds no fields of its own; the schema comes from `Trajectory`.
    '''
75 |
class SurrealRoboticsSuiteTrajectory(Trajectory):
    ''' Table holding the trajectories collected in the Surreal Robotics Suite environment.
        Adds no fields of its own; the schema comes from `Trajectory`.
    '''
79 |
class ArchiveFile(models.Model):
    ''' Table describing where episodes are stored in archives.
        This information is relevant if you want to download or extract a specific episode from the archives they are distributed in.

        + Attributes:
            - trajectory: Foreign key into the `Trajectory` table; the row is removed together with its trajectory.
            - env_id: Which env this trajectory belongs to.
                > Same as `trajectory.env_id`.
            - archive_file: Name of the archive file containing the corresponding episode.
    '''
    trajectory = models.ForeignKey(Trajectory, on_delete=models.CASCADE)
    env_id = models.CharField(max_length=50)
    archive_file = models.CharField(max_length=1024)

    def __str__(self):
        return "{} : {}".format(self.trajectory.episode_id, self.archive_file)
96 |
--------------------------------------------------------------------------------
/model/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from tqdm import tqdm # TODO : Remove TQDM
6 | from tensorboardX import SummaryWriter
7 | from torch.utils.data import DataLoader
8 |
9 | from models import *
10 |
# 'cuda' is torch's device type for GPUs; 'gpu' is not a valid device string.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 |
def train(config):
    ''' Jointly train the perception, goal-encoder, plan and control-policy modules.

        + Arguments:
            - config: model config; this function reads deg, tensorboard_path, models_save_path,
              the *_dim sizes, learning_rate, batch_size, num_workers, beta, max_epochs,
              save_interval_epoch, save_graphs and resume.

        + Side effects:
            - Logs 'loss/total' to tensorboard for every batch.
            - Writes one state-dict file per module plus the optimizer to config.models_save_path
              every `save_interval_epoch` epochs.

        + Raises:
            - FileNotFoundError: if config.resume is set and a checkpoint file is missing.
    '''
    deg = config.deg()
    tensorboard_writer = SummaryWriter(logdir=config.tensorboard_path)

    vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key]
    act_dim = deg.action_space

    perception_module = PerceptionModule(vobs_dim, dof_dim, config.visual_state_dim).to(device)
    visual_goal_encoder = VisualGoalEncoder(config.visual_state_dim, config.goal_dim).to(device)
    plan_recognizer = PlanRecognizerModule(config.combined_state_dim, act_dim, config.latent_dim).to(device)
    plan_proposer = PlanProposerModule(config.combined_state_dim, config.goal_dim, config.latent_dim).to(device)
    control_policy = ControlPolicy(act_dim, config.combined_state_dim, config.goal_dim, config.latent_dim).to(device)

    modules = [perception_module, visual_goal_encoder, plan_recognizer, plan_proposer, control_policy]
    params = [param for module in modules for param in module.parameters()]

    # Count scalar weights, not the number of parameter tensors.
    print("Number of parameters : {}".format(sum(param.numel() for param in params)))
    optimizer = torch.optim.Adam(params, lr=config.learning_rate)

    # (file name, object with state_dict()/load_state_dict()) - shared by resume and save.
    checkpoints = [
        ('perception.pth', perception_module),
        ('visual_goal.pth', visual_goal_encoder),
        ('plan_recognizer.pth', plan_recognizer),
        ('plan_proposer.pth', plan_proposer),
        ('control_policy.pth', control_policy),
        ('optimizer.pth', optimizer),
    ]

    if config.save_graphs:
        for module in modules:
            tensorboard_writer.add_graph(module)

    if config.resume:
        # Fail before touching any module if the checkpoint set is incomplete.
        missing = [name for name, _ in checkpoints
                   if not os.path.isfile(os.path.join(config.models_save_path, name))]
        if missing:
            raise FileNotFoundError("Cannot resume, missing checkpoints in {} : {}".format(
                config.models_save_path, ", ".join(missing)))
        for name, stateful in checkpoints:
            # map_location lets checkpoints written on a GPU machine load on a CPU-only one.
            state = torch.load(os.path.join(config.models_save_path, name), map_location=device)
            stateful.load_state_dict(state)

    print("Run : `tensorboard --logdir={} --host '0.0.0.0' --port 6006`".format(config.tensorboard_path))
    dataloader = deg.get_traj_dataloader(batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)

    def inference(trajectory, goal_state):
        ''' Return the training loss (policy NLL + beta * KL) for one padded batch. '''
        visual_obvs, dof_obs, action = trajectory[deg.vis_obv_key], trajectory[deg.dof_obv_key], trajectory['action']
        batch_size, seq_len = visual_obvs.shape[0], visual_obvs.shape[1]

        # NOTE(review): reshape (not permute) is used to get a channel-first layout;
        #               confirm this matches how the frames are stored.
        visual_obvs = visual_obvs.reshape(batch_size * seq_len, vobs_dim[2], vobs_dim[0], vobs_dim[1])
        dof_obs = dof_obs.reshape(batch_size * seq_len, dof_dim)
        actions = action.reshape(batch_size * seq_len, -1)

        states = perception_module(visual_obvs, dof_obs)  # DEBUG : Might raise in-place errors
        states = states.reshape(batch_size, seq_len, config.combined_state_dim)
        inital_state = states[:, 0]

        _, prior_z_dist, _, _ = plan_proposer(inital_state, goal_state)
        post_z, post_z_dist, _, _ = plan_recognizer(states, action)

        # NOTE : goal_state dims - [batch_size, goal_dim]
        # Repeat the goal and the sampled plan for every time-step, then flatten time into batch.
        goal_states = goal_state.unsqueeze(1).repeat(1, seq_len, 1).reshape(batch_size * seq_len, -1)
        post_zs = post_z.unsqueeze(1).repeat(1, seq_len, 1).reshape(batch_size * seq_len, -1)
        states = states.reshape(batch_size * seq_len, -1)

        _, logp_a = control_policy(state=states, goal=goal_states, zp=post_zs, action=actions)

        loss_pi = -logp_a.mean()
        loss_kl = torch.distributions.kl_divergence(post_z_dist, prior_z_dist).mean()
        return loss_pi + config.beta * loss_kl

    try:
        for epoch in tqdm(range(config.max_epochs), desc="Check Tensorboard"):
            # NOTE : wrap in torch.autograd.detect_anomaly() when debugging NaNs.
            for i, trajectory in enumerate(dataloader):
                trajectory = {key: value.float().to(device) for key, value in trajectory.items()}

                # The goal is the last frame of each trajectory in the batch.
                vis_obvs = trajectory[deg.vis_obv_key]
                last_obvs = vis_obvs[:, -1].reshape(vis_obvs.shape[0], vobs_dim[2], vobs_dim[0], vobs_dim[1])
                goal_state = perception_module(last_obvs)
                goal_state, _, _, _ = visual_goal_encoder(goal_state)

                loss = inference(trajectory, goal_state)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tensorboard_writer.add_scalar('loss/total', loss.item(), epoch * len(dataloader.dataset) + i)

            if int(epoch % config.save_interval_epoch) == 0:
                for name, stateful in checkpoints:
                    torch.save(stateful.state_dict(), os.path.join(config.models_save_path, name))
    finally:
        # Flush pending events even if training is interrupted.
        tensorboard_writer.close()
109 |
if __name__ == '__main__':
    # Script entry point: build the model config, seed torch for reproducibility, then train.
    from model_config import get_model_args

    config = get_model_args()
    torch.manual_seed(config.seed)
    train(config)
--------------------------------------------------------------------------------
/dataset_env/file_storage.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import sys
4 | import uuid
5 | import torch
6 | import pickle
7 | import tarfile
8 | import numpy as np
9 | import torchvision
10 | from random import randint
11 | from data_config import get_dataset_args, ep_type
12 |
13 | from PIL import Image
14 |
# Make the repository root and the Django project folder importable.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../db/'))
# Module-wide dataset configuration, parsed once at import time; every function below reads it.
config = get_dataset_args()

from traj_db.models import ArchiveFile
20 |
def store_trajectoy(trajectory, episode_type=config.episode_type, task=None):
    '''
    Save trajectory to the corresponding database based on env and env_type specified in config.
    The pickle file is written first and the database row is saved once afterwards, so a
    failed write never leaves a row pointing at a missing file.
    + Arguments:
        - trajectory: {deg.vis_obv_key : np.array([n]), deg.dof_obv_key : np.array([n]), 'action' : np.array([n])}
        - episode_type [optional]: tag to store trajectories with (eg. 'teleop' or 'imitation')
        - task [optional]: task id to store; defaults to config.env_type.
    '''
    if 'EPISODE_' not in episode_type:
        episode_type = ep_type(episode_type)
    if task is None:
        task = config.env_type
    assert 'action' in trajectory.keys()

    metadata = config.traj_db(task_id=task, env_id=config.env,
                              data_path=config.data_path, episode_type=episode_type,
                              traj_steps=trajectory['action'].shape[0])
    # episode_id comes from the field default (uuid.uuid4), so it is already set on the
    # unsaved instance and the final file name can be derived before hitting the database.
    metadata.data_path = os.path.join(config.data_path, "traj_{}.pt".format(metadata.episode_id))

    with open(metadata.data_path, 'wb') as file:
        pickle.dump(trajectory, file, protocol=pickle.HIGHEST_PROTOCOL)

    try:
        metadata.save()
    except Exception:
        # Do not leave an unreferenced file behind when the row could not be stored.
        if os.path.exists(metadata.data_path):
            os.remove(metadata.data_path)
        raise
44 |
def get_trajectory(episode_type=None, index=None, episode_id=None):
    '''
    Gets a particular trajectory from the corresponding database based on env and env_type specified in config.
    + Arguments:
        - episode_type [optional]: if you want trajectory specific to one episode_type
          (eg. 'teleop' or 'imitation'; database labels such as 'EPISODE_ROBOT_PLAY' also work)
        - index [optional]: get trajectory at a particular index (rows ordered by primary key)
        - episode_id [optional]: get trajectory by its episode_id (UUID object or string)

    + Returns:
        - The unpickled trajectory, or a list of all trajectories when neither
          index nor episode_id is given.

    + NOTE: episode_type, env_type become POINTLESS when you pass episode_id.
    + NOTE: index takes precedence when both index and episode_id are given.
    '''
    # TODO : If trajectory is in archive-file, get it from there
    if episode_id is None and index is None:
        return [get_trajectory(episode_id=traj_obj.episode_id) for traj_obj in config.traj_db.objects.all()]

    if index is not None:
        filters = {}
        if config.get_by_task_id:
            filters['task_id'] = config.env_type
        if episode_type is not None:
            # Rows store the database label, so accept short names the way store_trajectoy does.
            if 'EPISODE_' not in episode_type:
                episode_type = ep_type(episode_type)
            filters['episode_type'] = episode_type
        # Explicit ordering keeps a given index pointing at the same row across calls.
        metadata = config.traj_db.objects.filter(**filters).order_by('pk')[index]
    else:
        metadata = config.traj_db.objects.get(episode_id=uuid.UUID(str(episode_id)))

    # NOTE: pickle is only safe for files this project wrote itself.
    with open(metadata.data_path, 'rb') as file:
        trajectory = pickle.load(file)

    return trajectory
73 |
def get_random_trajectory(episode_type=None):
    '''
    Gets a random trajectory from the corresponding database based on env and env_type specified in config.
    + Arguments:
        - episode_type [optional]: if you want trajectory specific to one episode_type (eg. 'teleop' or 'imitation')

    + Returns:
        - (trajectory, episode_id as str, task_id)

    + Raises:
        - ValueError: if no stored trajectory matches the filters.
    '''
    filters = {}
    if config.get_by_task_id:
        filters['task_id'] = config.env_type
    if episode_type is not None:
        # Rows store the database label, so accept short names the way store_trajectoy does.
        if 'EPISODE_' not in episode_type:
            episode_type = ep_type(episode_type)
        filters['episode_type'] = episode_type

    candidates = config.traj_db.objects.filter(**filters).order_by('pk')
    # Count the rows that can actually be drawn and use a 0-based index:
    # randint is inclusive on both ends, so the valid range is [0, count - 1].
    count = candidates.count()
    random_index = randint(0, count - 1)
    metadata = candidates[random_index]

    episode_id = str(metadata.episode_id)
    task_id = metadata.task_id
    trajectory = get_trajectory(episode_id=episode_id)

    return trajectory, episode_id, task_id
92 |
def create_video(trajectory):
    '''
    Render the visual observations of a trajectory to a video and save its first and last frame as images.
    + Arguments:
        - trajectory: {deg.vis_obv_key : np.array([n]), deg.dof_obv_key : np.array([n]), 'action' : np.array([n])}
    + Returns:
        - config.vid_path, the location of the written video.
    '''
    vis_key = config.obv_keys['vis_obv_key']
    frames = trajectory[vis_key].astype(np.uint8)
    # Frames are expected with 3 colour channels on the last axis.
    assert frames.shape[-1] == 3

    first_frame = Image.fromarray(frames[0])
    last_frame = Image.fromarray(frames[-1])
    first_frame.save(os.path.join(config.media_dir, 'inital.png'))
    last_frame.save(os.path.join(config.media_dir, 'goal.png'))

    # write_video is handed a tensor.
    frames = frames if type(frames) is torch.Tensor else torch.from_numpy(frames)

    torchvision.io.write_video(config.vid_path, frames, config.fps)
    return config.vid_path
111 |
# NOT TESTED
def archive_traj_task(task=config.env_type, episode_type=None, file_name=None):
    '''
    Archives trajectories by task (env_type)
    + Arguments:
        - task: config.env_type - group and archive them all.
        - episode_type [optional]: store trajectories w/ same task, episode_type together.
        - file_name: the name of the archive file. [NOTE: NOT THE PATH]
            > NOTE : default file_name: `env_task.tar.gz` (`env_task_episode_type.tar.gz` with episode_type)

    + NOTE(review): the archive is opened in "w:gz" mode, so an existing file with the same name
      is overwritten while already-archived trajectories are skipped - confirm this is intended.
    '''
    filters = {'task_id': task}
    name_parts = [config.env, task]
    if episode_type is not None:
        # Rows store the database label, so accept short names the way store_trajectoy does.
        filters['episode_type'] = episode_type if 'EPISODE_' in episode_type else ep_type(episode_type)
        name_parts.append(episode_type)

    if file_name is None:
        file_name = "{}.tar.gz".format("_".join(name_parts))
    file_name = os.path.join(config.archives_path, file_name)

    # filter() returns every matching row; get() returns a single instance (or raises)
    # and cannot be iterated.
    matching = list(config.traj_db.objects.filter(**filters))

    # The context manager closes the archive even if adding a file fails.
    with tarfile.open(file_name, "w:gz") as tar:
        for metadata in matching:
            if metadata.is_archived:
                continue

            # Flag the row only after its file made it into the archive.
            tar.add(metadata.data_path)
            metadata.is_archived = True
            metadata.save()

            archive = ArchiveFile(trajectory=metadata, env_id=metadata.env_id, archive_file=file_name)
            archive.save()
146 |
def delete_trajectory(episode_id):
    '''
    Deletes a trajectory: its data file (if it still exists on disk) and its database row.
    + Arguments:
        - episode_id: string form of the trajectory's UUID.
    '''
    record = config.traj_db.objects.get(episode_id=uuid.UUID(episode_id))
    # The data file may already be gone; only the DB row is guaranteed at this point.
    data_file_present = os.path.exists(record.data_path)
    if data_file_present:
        os.remove(record.data_path)
    record.delete()
152 |
def flush_traj_db():
    '''
    Placeholder for flushing the trajectory database; not implemented yet.
    + Raises:
        - NotImplementedError: always.
    '''
    raise NotImplementedError()
--------------------------------------------------------------------------------
/dataset_env/surreal_deg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | from collections import OrderedDict
5 |
6 | from deg_base import DataEnvGroup
7 | from file_storage import store_trajectoy
8 |
9 | ENV_PATH = '/scratch/envs/robosuite'
10 | sys.path.append(ENV_PATH)
11 |
12 | import robosuite as suite
13 | from robosuite.wrappers import IKWrapper
14 | import robosuite.utils.transform_utils as T
15 |
class SurrealDataEnvGroup(DataEnvGroup):
    ''' DataEnvGroup for Surreal Robotics Suite environment.

    + The observation space can be modified through `global_config.env_args`
    + Observation space:
        - 'robot-state': proprioceptive feature - vector of:
            > cos and sin of robot joint positions
            > robot joint velocities
            > current configuration of the gripper.
        - 'object-state': object-centric feature
        - 'image': RGB/RGB-D image
            > (256 x 256 by default)
        - Refer: https://github.com/StanfordVL/robosuite/tree/master/robosuite/environments

    + The action spaces by default are joint velocities and gripper actuations.
        - Dimension : [8]
        - To use the end-effector action-space use inverse-kinematics using IKWrapper.
        - Refer: https://github.com/StanfordVL/robosuite/tree/master/robosuite/wrappers
    '''
    def __init__(self, get_episode_type=None):
        super(SurrealDataEnvGroup, self).__init__(get_episode_type)
        assert self.env_name == 'SURREAL'

        self.vis_obv_key = 'image'
        self.dof_obv_key = 'robot-state'
        # NOTE: `(30)` and `(8)` are plain ints, not 1-tuples; kept as-is because
        # code outside this class may rely on them being ints.
        self.obs_space = {self.vis_obv_key: (256, 256, 3), self.dof_obv_key: (30)}
        self.action_space = (8)

    def get_env(self):
        ''' Builds the robosuite environment named by `config.env_type` with `config.env_args`. '''
        env = suite.make(self.config.env_type, **self.config.env_args)
        return env

    def _get_obs(self, obs, key):
        ''' Returns the entry `key` of the observation dict `obs`. '''
        return obs[key]

    def play_trajectory(self):
        # TODO
        # Refer https://github.com/StanfordVL/robosuite/blob/master/robosuite/scripts/playback_demonstrations_from_hdf5.py
        raise NotImplementedError

    def teleoperate(self, demons_config):
        '''
        Collects tele-operated demonstrations with a keyboard or a SpaceMouse and stores
        them through `store_trajectoy(..., 'teleop')`.
        + Arguments:
            - demons_config: reads `device`, `n_runs`, `collect_freq`, `flush_freq`
              and `break_traj_success`.
        + Raises:
            - ValueError: on an unsupported `demons_config.device` or robot model.
        '''
        # Fail early with a clear message instead of an UnboundLocalError on `device`.
        if demons_config.device not in ("keyboard", "spacemouse"):
            raise ValueError(
                "Unsupported teleoperation device: {!r} (expected 'keyboard' or 'spacemouse')".format(demons_config.device))

        # Need to use inverse-kinematics controller to set position using device
        env = IKWrapper(self.get_env())

        try:
            if demons_config.device == "keyboard":
                from robosuite.devices import Keyboard
                device = Keyboard()
                env.viewer.add_keypress_callback("any", device.on_press)
                env.viewer.add_keyup_callback("any", device.on_release)
                env.viewer.add_keyrepeat_callback("any", device.on_press)
            else:
                from robosuite.devices import SpaceMouse
                device = SpaceMouse()

            for run in range(demons_config.n_runs):
                obs = env.reset()
                env.set_robot_joint_positions([0, -1.18, 0.00, 2.18, 0.00, 0.57, 1.5708])
                # rotate the gripper so we can see it easily - NOTE : REMOVE MAYBE
                env.viewer.set_camera(camera_id=2)
                env.render()
                device.start_control()

                reset = False
                task_completion_hold_count = -1
                step = 0
                tr_vobvs, tr_dof, tr_actions = [], [], []

                while not reset:
                    collect_now = int(step % demons_config.collect_freq) == 0
                    if collect_now:
                        tr_vobvs.append(np.array(obs[self.vis_obv_key]))
                        tr_dof.append(np.array(obs[self.dof_obv_key].flatten()))

                    device_state = device.get_controller_state()
                    dpos, rotation, grasp, reset = (
                        device_state["dpos"],
                        device_state["rotation"],
                        device_state["grasp"],
                        device_state["reset"],
                    )

                    # Express the commanded rotation relative to the current hand orientation.
                    current = env._right_hand_orn
                    drotation = current.T.dot(rotation)
                    dquat = T.mat2quat(drotation)
                    grasp = grasp - 1.
                    ik_action = np.concatenate([dpos, dquat, [grasp]])

                    obs, _, done, _ = env.step(ik_action)
                    env.render()

                    joint_velocities = np.array(env.controller.commanded_joint_velocities)
                    robot_name = env.env.mujoco_robot.name
                    if robot_name == "sawyer":
                        gripper_actuation = np.array(ik_action[7:])
                    elif robot_name == "baxter":
                        gripper_actuation = np.array(ik_action[14:])
                    else:
                        # Previously this fell through and crashed on an unbound variable.
                        raise ValueError("Unsupported robot for teleoperation: {!r}".format(robot_name))

                    # NOTE: Action for the normal environment (not inverse kinematic)
                    action = np.concatenate([joint_velocities, gripper_actuation], axis=0)

                    if collect_now:
                        tr_actions.append(action)

                    # NOTE(review): `step % flush_freq == 0` is also true at step 0, so the first
                    # stored chunk of every run holds a single sample - confirm this is intended.
                    if (int(step % demons_config.flush_freq) == 0) or (demons_config.break_traj_success and task_completion_hold_count == 0):
                        print("Storing Trajectory")
                        trajectory = {self.vis_obv_key : np.array(tr_vobvs), self.dof_obv_key : np.array(tr_dof), 'action' : np.array(tr_actions)}
                        store_trajectoy(trajectory, 'teleop')
                        trajectory, tr_vobvs, tr_dof, tr_actions = {}, [], [], []

                    if demons_config.break_traj_success and env._check_success():
                        if task_completion_hold_count > 0:
                            task_completion_hold_count -= 1 # latched state, decrement count
                        else:
                            task_completion_hold_count = 10 # reset count on first success timestep
                    else:
                        task_completion_hold_count = -1

                    step += 1
        finally:
            # Release the simulator even if collection fails half-way.
            env.close()

    def random_trajectory(self, demons_config):
        '''
        Rolls out `demons_config.flush_freq` steps of standard-normal random actions and
        stores the result through `store_trajectoy(..., 'random')`.
        '''
        env = self.get_env()
        try:
            obs = env.reset()
            env.set_robot_joint_positions([0, -1.18, 0.00, 2.18, 0.00, 0.57, 1.5708])

            tr_vobvs, tr_dof, tr_actions = [], [], []
            for step in range(demons_config.flush_freq):
                tr_vobvs.append(np.array(obs[self.vis_obv_key]))
                tr_dof.append(np.array(obs[self.dof_obv_key].flatten()))

                action = np.random.randn(env.dof)
                obs, reward, done, info = env.step(action)

                tr_actions.append(action)

            print("Storing Trajectory")
            trajectory = {self.vis_obv_key : np.array(tr_vobvs), self.dof_obv_key : np.array(tr_dof), 'action' : np.array(tr_actions)}
            store_trajectoy(trajectory, 'random')
        finally:
            # Release the simulator even if the rollout or the storage fails.
            env.close()
155 |
156 |
if __name__ == '__main__':
    # Manual smoke test: build the env group, reset an env and pull one batch.
    from torch.utils.data import DataLoader
    print("=> Testing surreal_deg.py")

    deg = SurrealDataEnvGroup()
    vis_key, dof_key = deg.vis_obv_key, deg.dof_obv_key
    print(deg.obs_space[vis_key], deg.action_space)

    first_obs = deg.get_env().reset()
    print(first_obs[vis_key].shape)

    loader = DataLoader(deg.traj_dataset, batch_size=1, shuffle=True, num_workers=1)
    first_batch = next(iter(loader))
    print(first_batch[dof_key].shape)
167 |
--------------------------------------------------------------------------------
/model/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn.parameter import Parameter
4 | import numpy as np
5 | from torch.distributions.normal import Normal
6 |
class GaussianNetwork(torch.nn.Module):
    '''
    Two-layer MLP that parameterises a diagonal Gaussian and draws a reparameterised sample.
    + Arguments:
        - in_dim: size of the last axis of the input.
        - latent_dim: size of the Gaussian (and of the sample).
    + forward(x) returns (z, Normal(mean, std), mean, std), each with `latent_dim`
      as last axis and the leading axes of `x`.
    '''
    def __init__(self, in_dim, latent_dim):
        super(GaussianNetwork, self).__init__()
        self.in_dim = in_dim
        self.latent_dim = latent_dim

        self.lin1 = torch.nn.Linear(self.in_dim, 2048)
        self.lin2 = torch.nn.Linear(2048, 2048)
        self.relu = torch.nn.ReLU()

        self.hidden2mean = torch.nn.Linear(2048, self.latent_dim)
        self.hidden2logv = torch.nn.Linear(2048, self.latent_dim)

    def forward(self, x):
        output = self.relu(self.lin1(x))
        output = self.relu(self.lin2(output))

        mean = self.hidden2mean(output)
        logv = self.hidden2logv(output)
        std = torch.exp(0.5 * logv)

        # Reparameterisation trick. `randn_like` draws independent noise for every
        # element (the old `[latent_dim]` vector was shared by the whole batch) and
        # creates it on the device/dtype of `std`.
        eps = torch.randn_like(std)
        z = eps * std + mean

        return z, Normal(mean, std), mean, std
32 |
class SpatialSoftmax(torch.nn.Module):
    '''
    Spatial softmax is used to find the expected pixel location of feature maps.
    Source : https://gist.github.com/jeasinema/1cba9b40451236ba2cfb507687e08834
    + Arguments:
        - height, width, channel: shape of the incoming feature maps.
        - temperature: if truthy, a learnable softmax temperature initialised to this value;
          otherwise a fixed temperature of 1.
        - data_format: 'NCHW' (default) or 'NHWC'.
    Output : Tensor of shape (N, C * 2) - don't use flatten!
             Per sample the layout is [x_0, y_0, x_1, y_1, ...] with coordinates in [-1, 1].
    '''
    def __init__(self, height, width, channel, temperature=None, data_format='NCHW'):
        super(SpatialSoftmax, self).__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            self.temperature = Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.

        # meshgrid(xs, ys) yields (height, width) grids: pos_x varies along the width
        # axis and pos_y along the height axis, matching the row-major flattening of
        # an (H, W) feature map. (The arguments used to be swapped, which is only
        # correct for square maps.)
        pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self.width), np.linspace(-1., 1., self.height))
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

    def forward(self, feature):
        if self.data_format == 'NHWC':
            # (N, H, W, C) -> (N, C, H, W)
            feature = feature.permute(0, 3, 1, 2)
        # `reshape` copies when the tensor is not contiguous (e.g. after the permute).
        feature = feature.reshape(-1, self.height * self.width)

        softmax_attention = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel * 2)

        return feature_keypoints
70 |
class ConditionalVAE(torch.nn.Module):
    '''
    ConditionalVAE : z ~ p(z | x, c)
    + input_size : combined dimension of x, c.
    + latent_size : dimension of z.
    + layer_sizes : list of dims for diff layers of the encoder.
    + decoder : when True, an MLP mirroring the encoder maps z back to `input_size`.
    + forward(x, c) returns (z, dist, mean, logv), preceded by `recon_x` when `decoder` is True.
    '''
    def __init__(self, input_size, latent_size=256, layer_sizes=[2048] * 4, decoder=False):
        super(ConditionalVAE, self).__init__()
        assert type(layer_sizes) == list

        self.input_size = input_size
        self.latent_size = latent_size
        # Private copy: neither the caller's list nor the shared default is ever mutated.
        self.layer_sizes = list(layer_sizes)
        self.decoder = decoder

        encoder_dims = [self.input_size] + self.layer_sizes
        encoder_layers = []
        for in_size, out_size in zip(encoder_dims[:-1], encoder_dims[1:]):
            encoder_layers.append(torch.nn.Linear(in_size, out_size))
            encoder_layers.append(torch.nn.ReLU())

        self.encoder_network = torch.nn.Sequential(*encoder_layers)
        self.hidden2mean = torch.nn.Linear(self.layer_sizes[-1], self.latent_size)
        self.hidden2logv = torch.nn.Linear(self.layer_sizes[-1], self.latent_size) # TODO : Maybe keep logstd fixed
        self.softplus = torch.nn.Softplus()

        if self.decoder:
            # `list.reverse()` works in place and returns None; slice to get a reversed copy.
            self.dlayers_size = self.layer_sizes[::-1]
            decoder_dims = [self.latent_size] + self.dlayers_size
            decoder_layers = []
            for in_size, out_size in zip(decoder_dims[:-1], decoder_dims[1:]):
                decoder_layers.append(torch.nn.Linear(in_size, out_size))
                decoder_layers.append(torch.nn.ReLU())

            # Final projection back to the input space (no activation).
            decoder_layers.append(torch.nn.Linear(self.dlayers_size[-1], self.input_size))
            self.decoder_network = torch.nn.Sequential(*decoder_layers)

    def forward(self, x, c):
        assert x.size()[-1] + c.size()[-1] == self.input_size

        x = torch.cat((x, c), dim=-1)
        x = self.encoder_network(x)

        mean = self.hidden2mean(x)
        logv = self.hidden2logv(x)
        std = torch.exp(0.5 * logv)

        # Reparameterised sample; noise is created on the device/dtype of `std`.
        z = torch.randn_like(std) * std + mean
        dist = Normal(mean, std)

        if self.decoder:
            recon_x = self.decoder_network(z)
            return recon_x, z, dist, mean, logv

        return z, dist, mean, logv
135 |
class SeqVAE(torch.nn.Module):
    '''
    Sequence VAE: an RNN encoder whose final hidden state parameterises z ~ N(mean, std).
    Input sequence shape : (batch, seq, feature)
    + Arguments:
        - input_size: feature dimension of the sequence.
        - latent_size: dimension of z.
        - hidden_size: hidden size of the RNN(s).
        - rnn_type: 'RNN', 'GRU' or 'LSTM' (case-insensitive).
        - decoder: when True, a decoder RNN is run over the input sequence
          (teacher forcing) with its initial hidden state computed from z.
        - num_layers, bidirectional: forwarded to the RNN(s).
    + forward(input_sequence) returns (z, dist, mean, logv), preceded by the decoder
      outputs of shape (batch, seq, hidden_size * num_directions) when `decoder` is True.
    '''
    def __init__(self, input_size, latent_size=256, hidden_size=2048, rnn_type='RNN', decoder=False, num_layers=2, bidirectional=True):
        super(SeqVAE, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size

        self.num_layers = num_layers
        self.rnn_type = rnn_type.upper()
        self.decoder = decoder
        self.bidirectional = bidirectional

        assert self.rnn_type in ['LSTM', 'GRU', 'RNN']

        rnn = {'LSTM' : torch.nn.LSTM, 'RNN' : torch.nn.RNN, 'GRU' : torch.nn.GRU}[self.rnn_type]
        self.encoder_rnn = rnn(self.input_size, self.hidden_size, num_layers=self.num_layers, bidirectional=self.bidirectional, batch_first=True)

        if self.decoder:
            # The decoder consumes the input sequence and is initialised with a state of
            # width `hidden_size` (see `latent2hidden`), so it must map
            # input_size -> hidden_size; the two sizes used to be swapped.
            self.decoder_rnn = rnn(self.input_size, self.hidden_size, num_layers=self.num_layers, bidirectional=self.bidirectional, batch_first=True)

        self.hidden_factor = (2 if bidirectional else 1) * num_layers
        self.hidden2mean = torch.nn.Linear(self.hidden_size * self.hidden_factor, self.latent_size)
        self.hidden2logv = torch.nn.Linear(self.hidden_size * self.hidden_factor, self.latent_size)
        self.softplus = torch.nn.Softplus() # TODO : Check w/ Softplus

        if self.decoder:
            self.latent2hidden = torch.nn.Linear(self.latent_size, self.hidden_size * self.hidden_factor)

    def forward(self, input_sequence):
        # NOTE : Assuming that all sequences are of the same length (no packing/padding).
        batch_size = input_sequence.shape[0]

        _, hidden = self.encoder_rnn(input_sequence)
        if self.rnn_type == 'LSTM':
            # LSTM returns (h_n, c_n); only the hidden state is encoded.
            hidden = hidden[0]

        # h_n is (layers * directions, batch, hidden). Bring the batch axis to the front
        # BEFORE flattening, otherwise features of different samples get interleaved.
        hidden = hidden.transpose(0, 1).reshape(batch_size, self.hidden_size * self.hidden_factor)

        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)

        # Reparameterised sample; noise is created on the device/dtype of `std`.
        z = torch.randn_like(std) * std + mean
        dist = Normal(mean, std)

        if self.decoder:
            hidden = self.latent2hidden(z)
            # (batch, layers * directions * hidden) -> (layers * directions, batch, hidden)
            hidden = hidden.view(batch_size, self.hidden_factor, self.hidden_size).transpose(0, 1).contiguous()
            if self.rnn_type == 'LSTM':
                # LSTM needs (h_0, c_0); start from an empty cell state.
                hidden = (hidden, torch.zeros_like(hidden))

            # Teacher forcing on the (unpacked) input sequence.
            outputs, _ = self.decoder_rnn(input_sequence, hidden)
            return outputs, z, dist, mean, logv

        return z, dist, mean, logv
199 |
if __name__ == '__main__':
    # Smoke test: one strong activation per channel of the first sample, then print
    # the input and output shapes of the spatial softmax layer.
    batch = torch.zeros(10, 3, 3, 3)
    for channel, row, col in ((0, 0, 1), (1, 1, 1), (2, 1, 2)):
        batch[0, channel, row, col] = 10
    softmax_layer = SpatialSoftmax(3, 3, 3, temperature=1)
    keypoints = softmax_layer(batch)
    print(batch.shape, keypoints.shape)
--------------------------------------------------------------------------------
/model/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.distributions.normal import Normal
5 | from layers import SpatialSoftmax, ConditionalVAE, SeqVAE, GaussianNetwork
6 |
7 | '''
8 | > D consists of paired {(O_t, a_t)}
9 | > O_t is the set of observations from each of the robot’s N sensory channels.
10 | > In our experiments, O= {I, p}.
11 | > I = an RGB image observation from a fixed first-person viewpoint.
12 | > p = the internal 8-DOF proprioceptive state of the agent.
13 |
14 | > Since our logs are raw observations,
15 | we define one encoder per sensory channel for N high-dimensional observations to one low-dimensional fused state
16 | s_t = concat([E_1(o1); ... ; E_N(o_N)])
17 | '''
18 |
class PerceptionModule(torch.nn.Module):
    '''
    Maps raw observation (image & proprioception) O_t to a low dimension embedding s_t.
    + Arguments:
        - visual_obv_dim: [H, W, C] of the image observation; H must equal W.
        - dof_obv_dim: shape of the proprioceptive observation (stored only).
        - state_dim: size of the visual embedding.
    + forward(visual_obv, dof_obv=None): returns the visual embedding, concatenated with
      `dof_obv` along the last axis when it is given.
      NOTE(review): the convolutions expect channel-first images - confirm callers permute.
    TODO : Normalize proprioception to have zero mean and unit variance.
    '''
    def __init__(self, visual_obv_dim=[256, 256, 3], dof_obv_dim=[8], state_dim=64):
        super(PerceptionModule, self).__init__()

        self.visual_obv_dim = visual_obv_dim
        self.dof_obv_dim = dof_obv_dim
        self.state_dim = state_dim

        assert visual_obv_dim[0] == visual_obv_dim[1]
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=(8, 8), stride=4, padding=2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(4, 4), stride=2)
        self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3), stride=1)

        # Side length of the feature map after the conv stack, derived from the image
        # size instead of being hard-coded (256 -> 64 -> 31 -> 29).
        side = visual_obv_dim[0]
        for conv in (self.conv1, self.conv2, self.conv3):
            side = self._conv_out_size(side, conv)

        self.ss = SpatialSoftmax(height=side, width=side, channel=64)
        self.lin1 = torch.nn.Linear(128, 512) # SS O/Ps shape (N, C * 2)
        self.lin2 = torch.nn.Linear(512, state_dim)
        self.relu = torch.nn.ReLU()

    @staticmethod
    def _conv_out_size(size, conv):
        ''' Output side length of a square, dilation-1 `Conv2d` applied to a `size` x `size` map. '''
        return (size + 2 * conv.padding[0] - conv.kernel_size[0]) // conv.stride[0] + 1

    def forward(self, visual_obv, dof_obv=None):
        output = self.relu(self.conv1(visual_obv))
        output = self.relu(self.conv2(output))
        output = self.conv3(output)
        output = self.ss(output)
        output = self.relu(self.lin1(output))
        output = self.lin2(output)

        if dof_obv is not None:
            output = torch.cat((output, dof_obv), -1)
        return output
52 |
class VisualGoalEncoder(torch.nn.Module):
    '''
    Encodes a goal state s_g into a latent goal embedding z.
    + NOTE The goal has to be embedded by the PerceptionModule first: s_g = P(O_{-1}).
    + NOTE Only the visual observation is passed to the PerceptionModule.
    '''
    def __init__(self, state_dim=64, goal_dim=32):
        super(VisualGoalEncoder, self).__init__()

        self.state_dim, self.goal_dim = state_dim, goal_dim
        self.goal_dist = GaussianNetwork(state_dim, goal_dim)

    def forward(self, visual_obv, perception_module=None):
        # An input whose last axis is not `state_dim` is treated as a raw observation
        # that still has to go through the perception module.
        already_embedded = visual_obv.shape[-1] == self.state_dim
        if not already_embedded:
            assert perception_module is not None
            visual_obv = perception_module(visual_obv, None)

        sample, posterior, mu, sigma = self.goal_dist(visual_obv)
        return sample, posterior, mu, sigma
73 |
class PlanRecognizerModule(torch.nn.Module):
    '''
    Posterior over plans: a SeqVAE that encodes a whole trajectory T into z_p ~ q(z_p|T).
    NOTE : Pass a *SINGLE* trajectory [[[Ot/st, at], ... K]] only, where Ot = [Vobv, DOF].
    NOTE : Trajectory shape : (1, K, 2) (Batch Size = 1).
    '''
    def __init__(self, state_dim=72, act_dim=8, latent_dim=256):
        super(PlanRecognizerModule, self).__init__()
        self.state_dim, self.act_dim, self.latent_dim = state_dim, act_dim, latent_dim

        # Every time step of the sequence is a state concatenated with its action.
        step_dim = state_dim + act_dim
        self.seqVae = SeqVAE(input_size=step_dim, latent_size=latent_dim)

    def forward(self, states, actions):
        trajectory = torch.cat([states, actions], dim=-1)
        latent, posterior, mu, log_var = self.seqVae(trajectory)
        return latent, posterior, mu, log_var
94 |
class PlanProposerModule(torch.nn.Module):
    '''
    A ConditionalVAE which maps initial state s_0 and goal z to latent space z_p ~ p(z_p|s_0, z).
    Represents a prior over z_p.
    + forward(initial_obv, goal_obv, goal_encoder=None, perception_module=None):
        - initial_obv: state embedding whose last axis is `state_dim`.
        - goal_obv: goal embedding (last axis `goal_dim`); anything else is first encoded
          with `goal_encoder` (and `perception_module`).
        - returns (z, dist, mean, logv) of the ConditionalVAE.
    '''
    def __init__(self, state_dim=72, goal_dim=32, latent_dim=256):
        super(PlanProposerModule, self).__init__()
        self.state_dim = state_dim
        self.goal_dim = goal_dim
        self.latent_dim = latent_dim

        self.cVae = ConditionalVAE(state_dim + goal_dim, latent_size=latent_dim)

    def forward(self, initial_obv, goal_obv, goal_encoder=None, perception_module=None):
        assert initial_obv.size()[-1] == self.state_dim

        if goal_obv.size()[-1] != self.goal_dim:
            assert goal_encoder is not None
            if isinstance(goal_encoder, VisualGoalEncoder):
                # The encoder returns (z, dist, mean, std); only the sample z conditions the prior.
                goal_obv = goal_encoder(goal_obv, perception_module)[0]

        z, dist, mean, logv = self.cVae(initial_obv, goal_obv)
        return z, dist, mean, logv
118 |
class ControlPolicy(torch.nn.Module):
    '''
    TODO : Major changes here
    RNN based goal (z), z_p conditioned policy : a_t ~ pi(a_t | s_t, z, z_p).
    + Arguments:
        - action_dim, state_dim, goal_dim, latent_dim: sizes of a_t, s_t, z and z_p.
        - hidden_size: hidden size of the two stacked recurrent cells.
        - rnn_type: 'RNN', 'GRU' or 'LSTM' (case-insensitive).
        - batch_size, num_layers: stored only.
    + forward(...) returns (Normal policy, log-prob of `action` or None).
    + step(...) samples an action without tracking gradients and returns
      (action, log-prob) as numpy arrays.
    '''
    def __init__(self, action_dim=8, state_dim=72, goal_dim=32, latent_dim=256,
                hidden_size=2048, batch_size=1, rnn_type='RNN', num_layers=2):
        super(ControlPolicy, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.goal_dim = goal_dim
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.input_dim = self.state_dim + self.goal_dim + self.latent_dim
        self.hidden_size = hidden_size
        self.rnn_type = rnn_type.upper()
        self.num_layers = num_layers

        assert self.rnn_type in ['LSTM', 'GRU', 'RNN']
        self.rnn = {'LSTM' : torch.nn.LSTMCell, 'GRU' : torch.nn.GRUCell, 'RNN' : torch.nn.RNNCell}[self.rnn_type]

        self.rnn1 = self.rnn(self.input_dim, self.hidden_size)
        self.rnn2 = self.rnn(self.hidden_size, self.hidden_size)

        # NOTE : Original paper used a Mixture of Logistic (MoL) dist. Implement later.
        self.hidden2mean = torch.nn.Linear(self.hidden_size, self.action_dim)
        # State-independent, learnable log standard deviation (initialised to -0.5).
        log_std = -0.5 * np.ones(self.action_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))

        self.relu = torch.nn.ReLU()
        self.tanh = torch.nn.Tanh()

    def _prepare_obs(self, state, goal, zp, goal_encoder=None, perception_module=None):
        ''' Embeds raw state/goal if needed, concatenates (state, goal, zp) and re-draws the hidden state. '''
        if state.size()[-1] != self.state_dim:
            assert perception_module is not None
            state = perception_module(state)
        if goal.size()[-1] != self.goal_dim:
            assert goal_encoder is not None
            if isinstance(goal_encoder, VisualGoalEncoder):
                # The encoder returns (z, dist, mean, std); only the sample z is concatenated.
                goal = goal_encoder(goal, perception_module)[0]

        obs = torch.cat((state, goal, zp), -1)

        # The hidden states of the RNN, re-drawn at random on every call and created
        # on the device/dtype of the observation.
        hidden_shape = (obs.shape[0], self.hidden_size)
        self.h1 = torch.randn(hidden_shape, dtype=obs.dtype, device=obs.device)
        self.h2 = torch.randn(hidden_shape, dtype=obs.dtype, device=obs.device)
        if self.rnn_type == 'LSTM':
            self.c1 = torch.randn(hidden_shape, dtype=obs.dtype, device=obs.device)
            self.c2 = torch.randn(hidden_shape, dtype=obs.dtype, device=obs.device)

        return obs

    def _distribution(self, obs):
        ''' Runs the two recurrent cells and returns the Normal action distribution. '''
        if self.rnn_type == 'LSTM':
            # LSTMCell returns (h, c); the ReLU applies to the hidden state only
            # (it cannot be applied to the tuple).
            h1, self.c1 = self.rnn1(obs, (self.h1, self.c1))
            self.h1 = self.relu(h1)
            h2, self.c2 = self.rnn2(self.h1, (self.h2, self.c2))
            self.h2 = self.relu(h2)
        else:
            self.h1 = self.relu(self.rnn1(obs, self.h1))
            self.h2 = self.relu(self.rnn2(self.h1, self.h2))

        mean = self.tanh(self.hidden2mean(self.h2))
        std = torch.exp(self.log_std)
        return Normal(mean, std)

    def _log_prob_from_distribution(self, policy, action):
        # Sum the per-dimension log-probs: the action dimensions are independent.
        return policy.log_prob(action).sum(dim=-1)

    def forward(self, state, goal, zp, action=None, goal_encoder=None, perception_module=None):
        obs = self._prepare_obs(state, goal, zp, goal_encoder, perception_module)
        policy = self._distribution(obs)

        logp_a = None
        if action is not None:
            logp_a = self._log_prob_from_distribution(policy, action)

        return policy, logp_a

    def step(self, state, goal, zp, goal_encoder=None, perception_module=None):
        with torch.no_grad():
            obs = self._prepare_obs(state, goal, zp, goal_encoder, perception_module)
            policy = self._distribution(obs)
            action = policy.sample()
            logp_a = self._log_prob_from_distribution(policy, action)

        # `.cpu()` is a no-op on CPU tensors and makes the conversion work for CUDA ones.
        return action.cpu().numpy(), logp_a.cpu().numpy()
204 |
--------------------------------------------------------------------------------
/dataset_env/rlbench_deg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import gym
4 | import numpy as np
5 |
6 | from deg_base import DataEnvGroup
7 | from file_storage import store_trajectoy
8 |
9 | ENV_PATH = '/scratch/envs/RLBench'
10 | sys.path.append(ENV_PATH)
11 |
12 | import rlbench.gym
13 | from rlbench.environment import Environment
14 | from rlbench.action_modes import ArmActionMode, ActionMode
15 | from rlbench.backend.observation import Observation
16 | from rlbench.observation_config import ObservationConfig
17 | from rlbench.tasks import ReachTarget
18 |
class RLBenchDataEnvGroup(DataEnvGroup):
    ''' DataEnvGroup for RLBench environment.

    + The observation space can be modified through `global_config.env_args`
    + Observation space:
        - 'state': proprioceptive feature : [37] + task_specific
            > robot joint - velocities, positions, forces
            > gripper - pose, joint-position, touch_forces
            > task_low_dim_state
        - RGB-D/RGB image : (128 x 128 by default)
            > 'left_shoulder_rgb', 'right_shoulder_rgb'
            > 'front_rgb', 'wrist_rgb'
            - NOTE : Better to use only one of the RGB obvs rather than all, saves a lot of time while env creation.

        - Refer: https://github.com/stepjam/RLBench/blob/20988254b773aae433146fff3624d8bcb9ed7330/rlbench/observation_config.py

    + The action spaces by default are joint velocities [7] and gripper actuations [1].
        - Dimension : [8]
    '''
    def __init__(self, get_episode_type=None):
        super(RLBenchDataEnvGroup, self).__init__(get_episode_type)
        assert self.env_name == 'RLBENCH'
        self.env_obj = None
        self.use_gym = self.config.env_args['use_gym']

        self.observation_mode = self.config.env_args['observation_mode'].lower()
        self.left_obv_key = 'left_shoulder_rgb'
        self.right_obv_key = 'right_shoulder_rgb'
        self.wrist_obv_key = 'wrist_rgb'
        self.front_obv_key = 'front_rgb'

        # A camera name used as observation mode also selects the visual observation key.
        if self.observation_mode not in ['vision', 'state'] and not self.use_gym:
            self.config.env_args['vis_obv_key'] = self.observation_mode

        self.vis_obv_key = self.config.env_args['vis_obv_key']
        self.dof_obv_key = 'state'
        self.env_action_key = 'joint_velocities'
        self.env_gripper_key = 'gripper_open'

        self.obs_space = {self.dof_obv_key : (37), self.left_obv_key : (128, 128, 3), self.right_obv_key : (128, 128, 3),
                        self.wrist_obv_key : (128, 128, 3), self.front_obv_key: (128, 128, 3)}

        # NOTE: `(7)` and `(1)` are plain ints, so `action_space` is an int (7 or 8).
        self.action_space, self.gripper_space = (7), (1)
        if self.config.env_args['combine_action_space']:
            self.action_space += self.gripper_space

    def get_env(self, task=None):
        '''
        Creates the environment (gym wrapper or native RLBench task) and remembers the
        object that has to be closed in `shutdown_env`.
        + Arguments:
            - task [optional]: task name/class; defaults to `config.env_type`.
        '''
        task = task if task else self.config.env_type
        if self.use_gym:
            assert type(task) == str # NOTE : When using gym, the task has to be represented as a string.
            assert self.observation_mode in ['vision', 'state']

            env = gym.make(task, observation_mode=self.config.env_args['observation_mode'],
                           render_mode=self.config.env_args['render_mode'])
            self.env_obj = env
        else:
            obs_config = ObservationConfig()
            if self.observation_mode == 'vision':
                obs_config.set_all(True)
            elif self.observation_mode == 'state':
                obs_config.set_all_high_dim(False)
                obs_config.set_all_low_dim(True)
            else:
                # A single camera: switch every high-dim observation off, then only that camera on.
                obs_config_dict = {self.left_obv_key : obs_config.left_shoulder_camera,
                                   self.right_obv_key : obs_config.right_shoulder_camera,
                                   self.wrist_obv_key : obs_config.wrist_camera,
                                   self.front_obv_key: obs_config.front_camera}

                assert self.observation_mode in obs_config_dict.keys()

                obs_config.set_all_high_dim(False)
                obs_config_dict[self.observation_mode].set_all(True)
                obs_config.set_all_low_dim(True)

            # TODO : Write code to change it from env_args
            action_mode = ActionMode(ArmActionMode.ABS_JOINT_VELOCITY)
            self.env_obj = Environment(action_mode, obs_config=obs_config, headless=True)

            task = task if task else ReachTarget
            if type(task) == str:
                task = task.split('-')[0]
                task = self.env_obj._string_to_task(task)

            self.env_obj.launch()
            env = self.env_obj.get_task(task) # NOTE : `env` refered as `task` in RLBench docs.
        return env

    def _get_obs(self, obs, key):
        ''' Reads `key` from a gym dict observation or an RLBench `Observation` (or a tuple holding one at index 1). '''
        assert obs is not None and key is not None
        if type(obs) == tuple:
            obs = obs[1]
        if type(obs) == dict:
            return obs[key]
        elif type(obs) == Observation:
            if key == 'state':
                return obs.get_low_dim_data()
            return getattr(obs, key)

    def shutdown_env(self):
        ''' Closes the gym env / shuts the RLBench environment down and forgets it. '''
        if self.env_obj is None:
            print("Environment not created, call `.get_env()`")
        elif self.use_gym:
            self.env_obj.close()
        else:
            self.env_obj.shutdown()
        self.env_obj = None

    def teleoperate(self, demons_config, task=None):
        '''
        Collects `demons_config.n_runs` live demos through the env's `get_demos` and stores
        each of them with `store_trajectoy(..., episode_type='teleop')`.
        '''
        if self.config.env_args['keyboard_teleop']:
            raise NotImplementedError
        else:
            if self.env_obj is None or task is None:
                env = self.get_env(task)
            else:
                if type(task) == str:
                    task = self.env_obj._string_to_task(task.split('-')[0])
                env = self.env_obj.get_task(task)

            try:
                if self.use_gym:
                    demos = env.task.get_demos(demons_config.n_runs, live_demos=True)
                else :
                    demos = env.get_demos(demons_config.n_runs, live_demos=True)

                # Iterate the demos directly: wrapping them in `np.array(...).flatten()` breaks
                # on demos of different lengths and, for equal lengths, flattens them into
                # single observations.
                for sample in list(demos)[:demons_config.n_runs]:
                    if self.observation_mode != 'state':
                        tr_vobvs = np.array([self._get_obs(obs, self.vis_obv_key) for obs in sample])
                    tr_dof = np.array([self._get_obs(obs, self.dof_obv_key).flatten() for obs in sample])
                    tr_actions = np.array([self._get_obs(obs, self.env_action_key).flatten() for obs in sample])
                    tr_gripper = np.array([[self._get_obs(obs, self.env_gripper_key)] for obs in sample])

                    if self.config.env_args['combine_action_space']:
                        tr_actions = np.concatenate((tr_actions, tr_gripper), axis=-1)

                    print("Storing Trajectory")
                    trajectory = {self.dof_obv_key : tr_dof, 'action' : tr_actions}
                    if self.observation_mode != 'state':
                        trajectory.update({self.vis_obv_key : tr_vobvs})
                    store_trajectoy(trajectory, episode_type='teleop', task=task)
            finally:
                # Release the simulator even if collection or storage fails.
                self.shutdown_env()

    def random_trajectory(self, demons_config):
        '''
        Rolls out `demons_config.flush_freq` steps of standard-normal random actions and
        stores the result with `store_trajectoy(..., episode_type='random')`.
        '''
        env = self.get_env()
        try:
            obs = env.reset()

            tr_vobvs, tr_dof, tr_actions = [], [], []
            for step in range(demons_config.flush_freq):
                if self.observation_mode != 'state':
                    tr_vobvs.append(np.array(self._get_obs(obs, self.vis_obv_key)))
                tr_dof.append(np.array(self._get_obs(obs, self.dof_obv_key).flatten()))

                # `action_space` is an int (see __init__), so it is used as the size directly.
                action = np.random.normal(size=self.action_space)
                obs, reward, done, info = env.step(action)

                tr_actions.append(action)

            print("Storing Trajectory")
            trajectory = {self.dof_obv_key : np.array(tr_dof), 'action' : np.array(tr_actions)}
            if self.observation_mode != 'state':
                trajectory.update({self.vis_obv_key : np.array(tr_vobvs)})
            store_trajectoy(trajectory, episode_type='random')
        finally:
            # Release the simulator even if the rollout or the storage fails.
            self.shutdown_env()
182 | self.shutdown_env()
183 |
if __name__ == '__main__':
    # Manual smoke test: create an env, inspect one observation, iterate the stored data.
    from torch.utils.data import DataLoader
    print("=> Testing rlbench_deg.py")

    deg = RLBenchDataEnvGroup()
    print(deg.obs_space[deg.vis_obv_key], deg.action_space)

    env = deg.get_env(task='change_clock')
    obs = env.reset()
    for obs_key in (deg.dof_obv_key, 'task_low_dim_state'):
        print(deg._get_obs(obs, obs_key).shape)

    #traj_data = DataLoader(deg.traj_dataset, batch_size=1, shuffle=False, num_workers=1)
    traj_data = deg.get_traj_dataloader(batch_size=5)
    for batch_idx, batch in enumerate(traj_data):
        print(batch_idx, batch[deg.dof_obv_key].shape)

    deg.shutdown_env()
201 |
--------------------------------------------------------------------------------
/collect_demons/imitate_play.py:
--------------------------------------------------------------------------------
1 | # TODO : *IMPORTANT* - change the imitation policy to be conditioned on the GROUND STATE rather than the visual obv
2 | # While running this imitated policy, collect visual_obv w/ domain randomization
3 |
4 | import os
5 | import sys
6 | import torch
7 | import torch.nn.functional as F
8 | from tqdm import tqdm # TODO : Remove TQDMs
9 | from tensorboardX import SummaryWriter
10 | from torch.utils.data import DataLoader
11 | from torch.distributions.normal import Normal
12 |
13 | from demons_config import get_demons_args
14 |
# Make the sibling `model` and `dataset_env` directories importable when this
# file is run as a script from inside `collect_demons/`.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(_THIS_DIR, '../model'))
sys.path.append(os.path.join(_THIS_DIR, '../dataset_env'))
17 |
18 | from models import PerceptionModule
19 | from model_config import get_model_args
20 |
21 | from file_storage import store_trajectoy
22 |
# 'cuda' is the torch device type for NVIDIA GPUs; 'gpu' is rejected by torch.device().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24 |
class ImitationPolicy(torch.nn.Module):
    """Two-cell recurrent Gaussian policy used for behaviour cloning.

    The state is pushed through two stacked recurrent cells (LSTM, GRU or
    plain RNN), the second hidden state is mapped to a tanh-squashed mean,
    and actions are modelled as a diagonal Normal with a learned,
    state-independent log standard deviation.

    + Arguments:
        - action_dim: size of the action vector
        - state_dim: size of the (already encoded) state vector
        - hidden_size: width of both recurrent cells
        - batch_size: stored on the instance, not used by this class
        - rnn_type: 'LSTM', 'GRU' or 'RNN' (case-insensitive)
        - num_layers: stored on the instance; the network always has two cells

    NOTE : the hidden states are re-drawn from a standard normal on every
    call to `forward` / `step`, so no memory is carried between calls.
    """

    def __init__(self, action_dim=8, state_dim=72, hidden_size=2048, batch_size=1, rnn_type='RNN', num_layers=2):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.rnn_type = rnn_type.upper()
        self.num_layers = num_layers

        assert self.rnn_type in ['LSTM', 'GRU', 'RNN']
        self.rnn = {'LSTM' : torch.nn.LSTMCell, 'GRU' : torch.nn.GRUCell, 'RNN' : torch.nn.RNNCell}[self.rnn_type]

        self.rnn1 = self.rnn(self.state_dim, self.hidden_size)
        self.rnn2 = self.rnn(self.hidden_size, self.hidden_size)

        # NOTE : Original paper used a Mixture of Logistic (MoL) dist. Implement later.
        self.hidden2mean = torch.nn.Linear(self.hidden_size, self.action_dim)
        # Built directly with torch: this module does not import numpy.
        self.log_std = torch.nn.Parameter(torch.full((self.action_dim,), -0.5, dtype=torch.float32))

        self.relu = torch.nn.ReLU()
        self.tanh = torch.nn.Tanh()

    def _prepare_obs(self, state, perception_module):
        """Encode `state` if needed and (re)initialise the hidden states.

        When the last dimension of `state` differs from `state_dim`, the
        input is passed through `perception_module`, which must then be
        provided.
        """
        if state.size()[-1] != self.state_dim:
            assert perception_module is not None
            state = perception_module(state)

        # Create the hidden states next to the weights so the policy also
        # works after `.to(device)`.
        ref = self.hidden2mean.weight
        shape = (state.shape[0], self.hidden_size)
        self.h1 = torch.randn(shape, device=ref.device, dtype=ref.dtype)
        self.h2 = torch.randn(shape, device=ref.device, dtype=ref.dtype)
        if self.rnn_type == 'LSTM':
            self.c1 = torch.randn(shape, device=ref.device, dtype=ref.dtype)
            self.c2 = torch.randn(shape, device=ref.device, dtype=ref.dtype)

        return state

    def _distribution(self, obs):
        """Run both recurrent cells and return the Normal action distribution."""
        if self.rnn_type == 'LSTM':
            # LSTMCell returns an (h, c) tuple; the ReLU applies to the
            # hidden output only, the cell state is kept as returned.
            h1, self.c1 = self.rnn1(obs, (self.h1, self.c1))
            self.h1 = self.relu(h1)
            h2, self.c2 = self.rnn2(self.h1, (self.h2, self.c2))
            self.h2 = self.relu(h2)
        else:
            self.h1 = self.relu(self.rnn1(obs, self.h1))
            self.h2 = self.relu(self.rnn2(self.h1, self.h2))

        mean = self.tanh(self.hidden2mean(self.h2))
        std = torch.exp(self.log_std)
        return Normal(mean, std)

    def _log_prob_from_distribution(self, policy, action):
        """Return the joint log-probability of `action` (summed over action dims)."""
        return policy.log_prob(action).sum(dim=-1)

    def forward(self, state, action=None, perception_module=None):
        """Return `(policy, logp_a)`; `logp_a` is None when no action is given."""
        obs = self._prepare_obs(state, perception_module)
        policy = self._distribution(obs)

        logp_a = None
        if action is not None:
            logp_a = self._log_prob_from_distribution(policy, action)

        return policy, logp_a

    def step(self, state, perception_module=None):
        """Sample an action without tracking gradients.

        Returns `(action, logp_a)` as numpy arrays.
        """
        with torch.no_grad():
            obs = self._prepare_obs(state, perception_module)
            policy = self._distribution(obs)
            action = policy.sample()
            logp_a = self._log_prob_from_distribution(policy, action)

        # .cpu() is a no-op on CPU tensors and is required for CUDA tensors.
        return action.cpu().numpy(), logp_a.cpu().numpy()
97 |
def train_imitation(demons_config):
    """Behaviour-clone the imitation policy on the stored robot-play trajectories.

    + Arguments:
        - demons_config: demonstration config; read here for `deg`,
          `tensorboard_path`, `load_models`, `use_model_perception` and
          `models_save_path`.

    The perception module and the policy are optimised jointly with Adam on
    the negative log-likelihood of the recorded actions.  Checkpoints are
    written every `model_config.save_interval` batches; the perception
    weights are only written when `use_model_perception` is false.
    """
    model_config = get_model_args()

    deg = demons_config.deg(get_episode_type='EPISODE_ROBOT_PLAY')

    vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key]
    act_dim = deg.action_space

    tensorboard_writer = SummaryWriter(logdir=demons_config.tensorboard_path)
    perception_module = PerceptionModule(vobs_dim, dof_dim, model_config.visual_state_dim).to(device)
    imitation_policy = ImitationPolicy(act_dim, model_config.combined_state_dim).to(device)

    params = list(perception_module.parameters()) + list(imitation_policy.parameters())
    # Report the scalar parameter count, not the number of parameter tensors.
    print("Number of parameters : {}".format(sum(p.numel() for p in params)))

    optimizer = torch.optim.Adam(params, lr=model_config.learning_rate)

    if demons_config.load_models:
        # TODO : IMPORTANT - Check if file exist before loading
        # map_location lets checkpoints saved on a GPU be restored on a CPU-only host.
        if demons_config.use_model_perception:
            perception_dir = model_config.models_save_path
        else:
            perception_dir = demons_config.models_save_path
        perception_module.load_state_dict(
            torch.load(os.path.join(perception_dir, 'perception.pth'), map_location=device))
        imitation_policy.load_state_dict(
            torch.load(os.path.join(demons_config.models_save_path, 'imitation_policy.pth'), map_location=device))
        optimizer.load_state_dict(
            torch.load(os.path.join(demons_config.models_save_path, 'optimizer.pth'), map_location=device))

    print("Run : tensorboard --logdir={} --host '0.0.0.0' --port 6006".format(demons_config.tensorboard_path))
    data_loader = DataLoader(deg.traj_dataset, batch_size=model_config.batch_size, shuffle=True, num_workers=1)
    max_step_size = len(data_loader.dataset)

    try:
        for epoch in tqdm(range(model_config.max_epochs), desc="Check Tensorboard"):
            for i, trajectory in enumerate(data_loader):
                trajectory = {key : value.float().to(device) for key, value in trajectory.items()}
                visual_obvs, dof_obs = trajectory[deg.vis_obv_key], trajectory[deg.dof_obv_key]
                batch_size, seq_len = visual_obvs.shape[0], visual_obvs.shape[1]

                # Fold the time axis into the batch axis: every step is
                # treated as an independent sample.
                # NOTE(review): reshape() reinterprets memory, it does not
                # move the channel axis - confirm this matches how the
                # frames are stored and how the perception module was trained.
                visual_obvs = visual_obvs.reshape(batch_size * seq_len, vobs_dim[2], vobs_dim[0], vobs_dim[1])
                dof_obs = dof_obs.reshape(batch_size * seq_len, dof_dim)
                actions = trajectory['action'].reshape(batch_size * seq_len, -1)

                states = perception_module(visual_obvs, dof_obs) # DEBUG : Might raise in-place errors

                _, logp_a = imitation_policy(state=states, action=actions)

                optimizer.zero_grad()
                loss = (-logp_a).mean()

                tensorboard_writer.add_scalar('Clone Loss', loss.item(), epoch * max_step_size + i)
                loss.backward()
                optimizer.step()

                if i % model_config.save_interval == 0:
                    if not demons_config.use_model_perception:
                        torch.save(perception_module.state_dict(), os.path.join(demons_config.models_save_path, 'perception.pth'))
                    torch.save(imitation_policy.state_dict(), os.path.join(demons_config.models_save_path, 'imitation_policy.pth'))
                    torch.save(optimizer.state_dict(), os.path.join(demons_config.models_save_path, 'optimizer.pth'))
    finally:
        # Flush pending events and release the log file handle.
        tensorboard_writer.close()
155 |
156 |
def imitate_play():
    """Roll out the trained imitation policy and store the generated trajectories.

    Loads the perception module and the imitation policy from disk, then
    runs `n_gen_traj` episodes of `flush_freq` steps each.  Every
    `collect_freq`-th step the (visual, dof, action) triple is recorded, and
    each episode is stored with episode type 'imitation'.
    """
    import numpy as np  # local import: this module has no file-level numpy import

    model_config = get_model_args()
    demons_config = get_demons_args()

    deg = demons_config.deg()
    env = deg.get_env()

    vobs_dim, dof_dim = deg.obs_space[deg.vis_obv_key], deg.obs_space[deg.dof_obv_key]
    act_dim = deg.action_space

    try:
        with torch.no_grad():
            perception_module = PerceptionModule(vobs_dim, dof_dim, model_config.visual_state_dim).to(device)
            imitation_policy = ImitationPolicy(act_dim, model_config.combined_state_dim).to(device)

            if demons_config.use_model_perception:
                perception_dir = model_config.models_save_path
            else:
                perception_dir = demons_config.models_save_path
            # map_location lets checkpoints saved on a GPU be restored on a CPU-only host.
            perception_module.load_state_dict(
                torch.load(os.path.join(perception_dir, 'perception.pth'), map_location=device))
            imitation_policy.load_state_dict(
                torch.load(os.path.join(demons_config.models_save_path, 'imitation_policy.pth'), map_location=device))

            for run in range(demons_config.n_gen_traj):
                obs = env.reset()
                tr_vobvs, tr_dof, tr_actions = [], [], []

                for step in range(demons_config.flush_freq):
                    visual_obv = torch.from_numpy(obs[deg.vis_obv_key]).float()
                    dof_obv = torch.from_numpy(obs[deg.dof_obv_key]).float()
                    # NOTE(review): reshape() reinterprets memory, it does not
                    # move the channel axis - confirm this matches training.
                    visual_obv = visual_obv.reshape(1, visual_obv.shape[2], visual_obv.shape[0], visual_obv.shape[1])
                    dof_obv = dof_obv.reshape(1, dof_obv.shape[0])
                    # The modules live on `device`, so their inputs must too.
                    state = perception_module(visual_obv.to(device), dof_obv.to(device))

                    action, _ = imitation_policy.step(state)

                    if step % demons_config.collect_freq == 0:
                        # Store plain numpy arrays rather than torch tensors.
                        tr_vobvs.append(visual_obv.numpy())
                        tr_dof.append(dof_obv.numpy())
                        tr_actions.append(action[0])

                    obs, _, done, _ = env.step(action[0])

                print('Storing Trajectory')
                trajectory = {deg.vis_obv_key : np.array(tr_vobvs), deg.dof_obv_key : np.array(tr_dof), 'action' : np.array(tr_actions)}
                store_trajectoy(trajectory, 'imitation')
    finally:
        # Release the simulator even if loading or a rollout fails.
        env.close()
202 |
if __name__ == '__main__':
    # Script entry point: optionally (re)train the imitation policy first,
    # then always roll it out to generate new trajectories.
    args = get_demons_args()
    if args.train_imitation:
        train_imitation(args)

    imitate_play()
--------------------------------------------------------------------------------
/dataset_env/data_aug.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import random
4 | import time
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # NOTE : Source - RAD (https://github.com/MishaLaskin/rad)
11 | # TODO : Will test later
12 |
def random_crop(images, out=192):
    '''
    Crop every image to a random `out` x `out` window.

    + Arguments:
        - images: np.array shape (N,C,H,W)
        - out: output size (e.g. 84)

    + Returns:
        - np.array of shape (N, out, out, C) obtained by *reshaping* the
          (N, C, out, out) crops.

    + Raises:
        - ValueError if `out` is larger than H or W
    '''
    n, c, h, w = images.shape
    if out > h or out > w:
        raise ValueError("crop size {} exceeds image size {}x{}".format(out, h, w))

    # Height and width get their own offset range so that non-square images
    # are never sliced past their right edge.
    w1 = np.random.randint(0, w - out + 1, n)
    h1 = np.random.randint(0, h - out + 1, n)
    cropped = np.empty((n, c, out, out), dtype=images.dtype)
    for i, (img, left, top) in enumerate(zip(images, w1, h1)):
        cropped[i] = img[:, top:top + out, left:left + out]

    # NOTE(review): reshape (not transpose) is used throughout this file to
    # go back to channel-last; confirm callers expect that layout.
    return np.reshape(cropped, (n, out, out, c))
28 |
def grayscale(images):
    '''
    Convert stacked RGB frames to grayscale, keeping three channels per frame.

    + Arguments:
        - images: np.array shape (B,C,H,W) with C a multiple of 3
          (C // 3 stacked RGB frames); values are truncated to uint8, so a
          0..255 range is assumed

    + Returns:
        - float32 np.array of shape (B, H, W, C) obtained by *reshaping* the
          (B, C, H, W) result; every frame's three channels hold the same
          gray value

    + Raises:
        - ValueError if C is not a multiple of 3
    '''
    frames = torch.from_numpy(images)
    b, c, h, w = frames.shape
    if c % 3 != 0:
        raise ValueError("expected a multiple of 3 channels, got {}".format(c))

    rgb = frames.reshape(b, c // 3, 3, h, w)
    gray = rgb[:, :, 0, ...] * 0.2989 + rgb[:, :, 1, ...] * 0.587 + rgb[:, :, 2, ...] * 0.114

    # Truncate to whole intensity levels, then go back to float.
    gray = gray.type(torch.uint8).float().numpy()
    # Replicate the gray value over the three channels of each frame.
    gray = np.repeat(gray[:, :, None, :, :], 3, axis=2)
    return np.reshape(gray, (b, h, w, c))
47 |
def random_grayscale(images, probab=0.3):
    '''
    Convert each sample of the batch to grayscale with probability `probab`.

    + Arguments:
        - images: np.array shape (B,C,H,W) with C a multiple of 3; values are
          multiplied by 255 and truncated to uint8, so a 0..1 range is assumed
        - probab: per-sample probability of being converted

    + Returns:
        - np.array of shape (B, H, W, C) obtained by *reshaping* the
          (B, C, H, W) result.  All samples, converted or not, are quantised
          to multiples of 1/255.

    + Raises:
        - ValueError if C is not a multiple of 3
    '''
    batch = torch.from_numpy(images)
    in_type = batch.type()
    bs, channels, h, w = batch.shape
    if channels % 3 != 0:
        raise ValueError("expected a multiple of 3 channels, got {}".format(channels))
    frames = channels // 3

    pixels = (batch * 255.).type(torch.uint8).reshape(bs, frames, 3, h, w)

    # The luma conversion is done here on the tensor directly, so this
    # function does not depend on the numpy-based `grayscale` helper.
    gray = pixels[:, :, 0, ...] * 0.2989 + pixels[:, :, 1, ...] * 0.587 + pixels[:, :, 2, ...] * 0.114
    gray = gray.type(torch.uint8)[:, :, None, :, :].expand(-1, -1, 3, -1, -1)

    # One draw per sample decides whether the whole frame stack is converted.
    rnd = np.random.uniform(0., 1., size=(bs,))
    chosen = torch.from_numpy(rnd <= probab)[:, None, None, None, None]

    out = torch.where(chosen, gray, pixels)
    out = out.reshape(bs, -1, h, w).type(in_type) / 255.
    return np.reshape(out.numpy(), (bs, h, w, -1))
74 |
def random_cutout(images, min_cut=10, max_cut=30):
    '''
    Zero out one random box per image.

    + Arguments:
        - images: np.array shape (B,C,H,W)
        - min / max cut: int, bounds of the random draw (max exclusive)

    The drawn value is used both as the start offset and as the extent of
    the box, i.e. rows [t, 2t) and columns [l, 2l) are blanked.  The input
    is not modified; the result is reshaped to (B, H, W, C).
    '''
    n, c, h, w = images.shape
    lefts = np.random.randint(min_cut, max_cut, n)
    tops = np.random.randint(min_cut, max_cut, n)

    cutouts = np.empty((n, c, h, w), dtype=images.dtype)
    cutouts[...] = images
    for idx, (top, left) in enumerate(zip(tops, lefts)):
        cutouts[idx, :, top:2 * top, left:2 * left] = 0
    return cutouts.reshape(n, h, w, c)
92 |
def random_cutout_color(images, min_cut=10, max_cut=30):
    '''
    Fill one random box per image with a random flat colour.

    + Arguments:
        - images: np.array shape (N,C,H,W)
        - min / max cut: int, bounds of the random draw (max exclusive)

    The drawn value is used both as the start offset and as the extent of
    the box.  The fill colour is one value per channel, drawn from
    {0, 1/255, ..., 254/255}.  The input is not modified; the result is
    reshaped to (N, H, W, C).
    '''
    n, c, h, w = images.shape
    lefts = np.random.randint(min_cut, max_cut, n)
    tops = np.random.randint(min_cut, max_cut, n)
    fills = np.random.randint(0, 255, size=(n, c)) / 255.

    cutouts = np.empty((n, c, h, w), dtype=images.dtype)
    cutouts[...] = images
    for idx, (top, left) in enumerate(zip(tops, lefts)):
        # Broadcasting the per-channel colour over the box replaces the explicit tile.
        cutouts[idx, :, top:2 * top, left:2 * left] = fills[idx][:, None, None]
    return cutouts.reshape(n, h, w, c)
114 |
def random_flip(images, probab=0.2):
    '''
    Mirror each sample along its last axis with probability `probab`.

    + Arguments:
        - images: np.array shape (B,C,H,W)
        - probab: per-sample probability of being flipped

    + Returns:
        - np.array with the result viewed as (B, H, W, C)
    '''
    batch = torch.from_numpy(images)
    bs, channels, h, w = batch.shape

    mirrored = batch.flip([3])

    # One draw per sample; the 0/1 weight broadcasts over C, H and W.
    draw = np.random.uniform(0., 1., size=(bs,))
    pick = torch.from_numpy(draw <= probab).type(batch.dtype)
    pick = pick[:, None, None, None]

    blended = pick * mirrored + (1 - pick) * batch
    return blended.view([bs, h, w, -1]).numpy()
143 |
def random_rotation(images, probab=0.3):
    '''
    Rotate each sample by a random multiple of 90 degrees with probability `probab`.

    + Arguments:
        - images: np.array shape (B,C,H,W); the blend below adds the
          rotated copies together, which requires H == W
        - probab: per-sample probability of being rotated

    + Returns:
        - np.array with the result viewed as (B, H, W, C)
    '''
    batch = torch.from_numpy(images)
    bs, channels, h, w = batch.shape

    # Index k holds the batch rotated by k quarter turns in the (H, W) plane.
    rotations = [batch] + [batch.rot90(k, [2, 3]) for k in (1, 2, 3)]

    draw = np.random.uniform(0., 1., size=(bs,))
    turns = np.random.randint(1, 4, size=(bs,))
    # 0 means "leave as is", 1..3 is the number of quarter turns.
    choice = torch.from_numpy(turns * (draw <= probab))

    per_channel = torch.ones([1, channels]).type(choice.dtype).type(batch.dtype)
    out = None
    for k, rotated in enumerate(rotations):
        selector = torch.zeros_like(choice)
        selector[choice == k] = 1
        weight = (selector[:, None] * per_channel)[:, :, None, None]
        term = weight * rotated
        out = term if out is None else out + term

    return out.view([bs, h, w, -1]).numpy()
180 |
def random_convolution(images):
    '''
    Filter every sample with its own freshly initialised random 3x3 convolution.

    + Arguments:
        - images: np.array shape (B,C,H,W) with C a multiple of 3; each
          group of 3 channels is filtered as one RGB frame

    + Returns:
        - float32 np.array of shape (B, H, W, C) obtained by *reshaping* the
          (B, C, H, W) result
    '''
    # Conv2d weights are float32, so the input is converted to match.
    batch = torch.from_numpy(images).float()
    num_batch, num_stack_channel, img_h, img_w = batch.shape
    if num_batch == 0:
        return np.reshape(batch.numpy(), (0, img_h, img_w, num_stack_channel))

    # initialize random convolution
    rand_conv = nn.Conv2d(3, 3, kernel_size=3, bias=False, padding=1).to(batch.device)

    outputs = []
    # The conv weights require grad; without no_grad() the result could not
    # be converted with .numpy().
    with torch.no_grad():
        for sample in batch:
            torch.nn.init.xavier_normal_(rand_conv.weight)
            frames = sample.reshape(-1, 3, img_h, img_w) # (stack, channel, h, w)
            outputs.append(rand_conv(frames))

    total_out = torch.cat(outputs, 0).reshape(-1, num_stack_channel, img_h, img_w)
    return np.reshape(total_out.numpy(), (num_batch, img_h, img_w, num_stack_channel))
209 |
def no_aug(images):
    '''
    Identity augmentation: hands back the very same object it was given.
    '''
    return images
212 |
def random_color_jitter(images):
    '''
    Apply random brightness / contrast / saturation / hue jitter to every frame.

    + Arguments:
        - images: np.array (or torch tensor) of shape (B,C,H,W) with C a
          multiple of 3; float values in [0, 1] are assumed, since the jitter
          layer clamps to that range

    + Returns:
        - for np.array input: np.array of shape (B, H, W, C) obtained by
          *reshaping* the result; the input array is not modified
        - for tensor input: tensor viewed as (B, H, W, C); the jitter layer
          writes into the input tensor
    '''
    b, c, h, w = images.shape
    is_numpy = isinstance(images, np.ndarray)
    # clone(): from_numpy shares memory and the jitter layer writes in place.
    frames = torch.from_numpy(images).clone() if is_numpy else images
    frames = frames.view(-1, 3, h, w)

    # One factor per sample, shared by the c // 3 frames of its stack, so the
    # layer's factor count matches the number of frames passed in.
    jitter = ColorJitterLayer(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5,
                              p=1.0, batch_size=b, stack_size=c // 3)
    jittered = jitter(frames).view(b, h, w, c)
    return jittered.numpy() if is_numpy else jittered
223 |
224 | # ------------------------------------------------------------------------------------------------------------------------
225 |
def rgb2hsv(rgb, eps=1e-8):
    '''
    Convert a batch of RGB images to HSV.

    + Arguments:
        - rgb: torch tensor shape (B,3,H,W)
        - eps: added to denominators to avoid division by zero

    + Returns:
        - torch tensor shape (B,3,H,W) holding hue (scaled by 1/6),
          saturation and value

    References:
        https://www.rapidtables.com/convert/color/rgb-to-hsv.html
        https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287
    '''
    _device = rgb.device
    red, green, blue = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]

    peak = torch.max(rgb, dim=1).values
    spread = peak - torch.min(rgb, dim=1).values
    denom = spread + eps

    hue = torch.zeros((rgb.shape[0], rgb.shape[2], rgb.shape[3])).to(_device)
    # Applied in this order on purpose: when several channels tie for the
    # maximum, the later rule overwrites the earlier one.
    sector_rules = (
        (peak == red, ((green - blue) / denom) % 6),
        (peak == green, (blue - red) / denom + 2),
        (peak == blue, (red - green) / denom + 4),
    )
    for where_max, sector_hue in sector_rules:
        hue[where_max] = sector_hue[where_max]

    black = peak == 0
    hue[black] = 0.0
    hue = hue / 6. # making hue range as [0, 1.0)

    saturation = spread / (peak + eps)
    saturation[black] = 0.

    planes = (hue[:, None], saturation.to(_device)[:, None], peak.to(_device)[:, None])
    return torch.cat(planes, dim=1)
256 |
def hsv2rgb(hsv):
    '''
    Convert a batch of HSV images back to RGB.

    + Arguments:
        - hsv: torch tensor shape (B,3,H,W); all channels are clamped to
          [0, 1] first, hue 1.0 corresponds to 360 degrees

    + Returns:
        - torch tensor shape (B,3,H,W) with values clamped to [0, 1]

    References:
        https://www.rapidtables.com/convert/color/hsv-to-rgb.html
        https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287
    '''
    hsv = torch.clamp(hsv, 0, 1)
    # Wrap 360 degrees back to 0: a hue of exactly 1.0 would otherwise match
    # none of the six sectors below and come out grey instead of red.
    hue = (hsv[:, 0, :, :] * 360.) % 360.
    saturation = hsv[:, 1, :, :]
    value = hsv[:, 2, :, :]

    c = value * saturation
    x = - c * (torch.abs((hue / 60.) % 2 - 1) - 1)
    m = (value - c).unsqueeze(dim=1)

    rgb_prime = torch.zeros_like(hsv)

    # (R, G, B) contribution for each 60-degree sector; None leaves the channel at 0.
    sector_layout = (
        (c, x, None),
        (x, c, None),
        (None, c, x),
        (None, x, c),
        (x, None, c),
        (c, None, x),
    )
    for sector, components in enumerate(sector_layout):
        inds = (hue >= 60 * sector) & (hue < 60 * (sector + 1))
        for channel, component in enumerate(components):
            if component is not None:
                rgb_prime[:, channel, :, :][inds] = component[inds]

    rgb = rgb_prime + torch.cat((m, m, m), dim=1)
    return torch.clamp(rgb, 0, 1)
302 |
class ColorJitterLayer(nn.Module):
    '''
    Batched colour jitter (brightness, contrast, saturation, hue) on torch tensors.

    + Arguments:
        - brightness / contrast / saturation: a non-negative number `v`
          (factor range [max(0, 1 - v), 1 + v]) or an explicit (min, max) pair
        - hue: a non-negative number `v` (shift range [-v, v]) or an
          explicit (min, max) pair within (-0.5, 0.5)
        - p: per-image probability of being transformed
        - batch_size: kept for backward compatibility; the number of random
          factors now follows the number of images actually passed in
        - stack_size: number of consecutive images sharing one random factor

    A range that amounts to "no change" is stored as None and the
    corresponding adjustment is skipped.
    '''
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0, batch_size=128, stack_size=3):
        super(ColorJitterLayer, self).__init__()
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)
        self.prob = p
        self.batch_size = batch_size
        self.stack_size = stack_size

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        '''
        Normalise a jitter setting to a [min, max] range, or None for "no change".

        Raises ValueError for a negative number or an out-of-bound pair and
        TypeError for anything that is neither a number nor a 2-element
        tuple/list.
        '''
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    def _sample_factor(self, x, bounds):
        '''
        Draw one uniform factor per stack of `stack_size` images, one entry per image of `x`.

        The number of draws follows `len(x)`, so the layer also works when
        fewer than `batch_size * stack_size` images are passed in.
        '''
        groups = -(-len(x) // self.stack_size)  # ceil division
        factor = torch.empty(groups, device=x.device).uniform_(*bounds)
        factor = factor.reshape(-1, 1).repeat(1, self.stack_size).reshape(-1)
        return factor[:len(x)]

    def adjust_contrast(self, x):
        '''
        + Arguments:
            - x: torch tensor img (rgb type), shape (N,3,H,W)
        + Returns:
            - contrast adjusted image, clamped to [0, 1]; a factor of 0
              gives a solid per-channel mean image, 1 the original
        '''
        factor = self._sample_factor(x, self.contrast)
        means = torch.mean(x, dim=(2, 3), keepdim=True)
        return torch.clamp((x - means)
                           * factor.view(len(x), 1, 1, 1) + means, 0, 1)

    def adjust_hue(self, x):
        '''
        + Arguments:
            - x: torch tensor img (hsv type), shape (N,3,H,W); modified in place
        + Returns:
            - x with its hue channel shifted and wrapped into [0, 1)
        '''
        factor = self._sample_factor(x, self.hue)
        shift = factor.view(len(x), 1, 1) * 255. / 360.
        x[:, 0, :, :] = (x[:, 0, :, :] + shift) % 1
        return x

    def adjust_brightness(self, x):
        '''
        + Arguments:
            - x: torch tensor img (hsv type), shape (N,3,H,W); the value
              channel is modified in place
        + Returns:
            - brightness adjusted image, clamped to [0, 1]; a factor of 0
              gives a black image, 1 the original
        '''
        factor = self._sample_factor(x, self.brightness)
        x[:, 2, :, :] = torch.clamp(x[:, 2, :, :]
                                    * factor.view(len(x), 1, 1), 0, 1)
        return torch.clamp(x, 0, 1)

    def adjust_saturate(self, x):
        '''
        + Arguments:
            - x: torch tensor img (hsv type), shape (N,3,H,W); the
              saturation channel is modified in place
        + Returns:
            - saturation adjusted image, clamped to [0, 1]; a factor of 0
              gives a black-and-white image, 1 the original
        '''
        factor = self._sample_factor(x, self.saturation)
        x[:, 1, :, :] = torch.clamp(x[:, 1, :, :]
                                    * factor.view(len(x), 1, 1), 0, 1)
        return torch.clamp(x, 0, 1)

    def transform(self, inputs):
        '''
        Apply the enabled adjustments; the RGB (contrast) and HSV groups run in random order.
        '''
        # Adjustments whose range is None are disabled and skipped.
        hsv_steps = [step for step, bounds in (
            (self.adjust_brightness, self.brightness),
            (self.adjust_hue, self.hue),
            (self.adjust_saturate, self.saturation),
        ) if bounds is not None]
        # Only pay for the RGB <-> HSV round trip when an HSV step is enabled.
        hsv_transform_list = [rgb2hsv] + hsv_steps + [hsv2rgb] if hsv_steps else []
        rgb_transform_list = [self.adjust_contrast] if self.contrast is not None else []
        # Shuffle transform
        if random.uniform(0,1) >= 0.5:
            transform_list = rgb_transform_list + hsv_transform_list
        else:
            transform_list = hsv_transform_list + rgb_transform_list
        for t in transform_list:
            inputs = t(inputs)
        return inputs

    def forward(self, inputs):
        '''
        Transform each image with probability `p`; the result is written back into `inputs`.
        '''
        _device = inputs.device
        random_inds = np.random.choice(
            [True, False], len(inputs), p=[self.prob, 1 - self.prob])
        inds = torch.tensor(random_inds).to(_device)
        if random_inds.sum() > 0:
            inputs[inds] = self.transform(inputs[inds])
        return inputs
417 |
418 |
# Lookup table from the augmentation names accepted in `config.augs` to the
# functions implementing them.
aug_to_func = dict(
    crop=random_crop,
    grayscale=random_grayscale,
    cutout=random_cutout,
    cutout_color=random_cutout_color,
    flip=random_flip,
    rotate=random_rotation,
    rand_conv=random_convolution,
    color_jitter=random_color_jitter,
    no_aug=no_aug,
)
430 |
def apply_augs(images, config):
    '''
    Apply the augmentations named in `config.augs`, in order.

    + Arguments:
        - images: np.array shape (N,H,W,C)
        - config: object with an `augs` iterable of keys of `aug_to_func`

    + Returns:
        - the augmented np.array (the input itself when nothing is applied)

    + Raises:
        - KeyError for an unknown augmentation name
    '''
    for aug in config.augs:
        if aug == 'no_aug':
            # `no_aug` returns its input untouched, so running it on the
            # reshaped array would leave the batch in (N,C,H,W) form and
            # corrupt the shapes seen by the following steps.
            continue
        n, h, w, c = images.shape
        # NOTE(review): reshape (not transpose) is what the augmentation
        # functions undo on return; confirm this layout is intended.
        images = aug_to_func[aug](np.reshape(images, (n, c, h, w)))
    return images
--------------------------------------------------------------------------------