├── pretrained_models
│   └── reach
│       └── checkpoint_000575
│           ├── .is_checkpoint
│           ├── checkpoint-575
│           ├── checkpoint-575.tune_metadata
│           └── config.yaml
├── envs
│   ├── __init__.py
│   ├── reach
│   │   ├── reach.ttt
│   │   └── reach_env.py
│   ├── wrappers.py
│   └── base_env.py
├── requirements.txt
├── utils
│   ├── geometry.py
│   ├── image.py
│   ├── config.py
│   └── rllib.py
├── configs
│   ├── default_config.yaml
│   └── image_obs_config.yaml
├── README.md
├── .gitignore
└── main.py
/pretrained_models/reach/checkpoint_000575/.is_checkpoint:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from envs.reach.reach_env import ReachEnv
2 |
--------------------------------------------------------------------------------
/envs/reach/reach.ttt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isk03276/LearnToMoveUR3/HEAD/envs/reach/reach.ttt
--------------------------------------------------------------------------------
/pretrained_models/reach/checkpoint_000575/checkpoint-575:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isk03276/LearnToMoveUR3/HEAD/pretrained_models/reach/checkpoint_000575/checkpoint-575
--------------------------------------------------------------------------------
/pretrained_models/reach/checkpoint_000575/checkpoint-575.tune_metadata:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isk03276/LearnToMoveUR3/HEAD/pretrained_models/reach/checkpoint_000575/checkpoint-575.tune_metadata
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 |
3 | gym
4 |
5 | # ML
6 | ray[rllib]
7 |
8 | #opencv
9 | opencv-python-headless==4.4.0.46
10 |
11 | #formatting
12 | black==22.1.0
13 |
14 | #ML framework
15 | torch==1.10.0
16 | tensorboard==2.8.0
--------------------------------------------------------------------------------
/utils/geometry.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 |
5 |
6 | def get_distance_between_two_pts(pts1: np.ndarray, pts2: np.ndarray) -> float:
7 | """
8 |     Calculate the Euclidean distance between two points.
9 |     Args:
10 |         pts1 (np.ndarray): first point
11 |         pts2 (np.ndarray): second point
12 | 
13 |     Returns:
14 |         float: distance
15 | """
16 | return np.linalg.norm(pts1 - pts2)
17 |
--------------------------------------------------------------------------------
/utils/image.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def resize_image(image: np.ndarray, width: int, height: int) -> np.ndarray:
6 | """
7 |     Resize an image.
8 |     Args:
9 |         image (np.ndarray): input image
10 |         width (int): target width
11 |         height (int): target height
12 |
13 | Returns:
14 | np.ndarray: Resized image
15 | """
16 | return cv2.resize(image, dsize=(width, height))
17 |
--------------------------------------------------------------------------------
/pretrained_models/reach/checkpoint_000575/config.yaml:
--------------------------------------------------------------------------------
1 | env: {}
2 | env_config:
3 | use_arm_camera: false
4 | use_image_observation: false
5 | rllib:
6 | batch_mode: complete_episodes
7 | callbacks: !!python/name:utils.rllib.CustomLogCallback ''
8 | clip_actions: true
9 | clip_param: 0.1
10 | clip_rewards: false
11 | entropy_coeff: 0.01
12 | framework: torch
13 | kl_coeff: 0.5
14 | lambda: 0.95
15 | num_gpus: 1
16 | num_sgd_iter: 10
17 | num_workers: 8
18 | observation_filter: NoFilter
19 | rollout_fragment_length: 100
20 | sgd_minibatch_size: 500
21 | train_batch_size: 5000
22 | vf_clip_param: 10.0
23 |
--------------------------------------------------------------------------------
/configs/default_config.yaml:
--------------------------------------------------------------------------------
1 | env : {}
2 | env_config:
3 | use_image_observation: false
4 | use_arm_camera: false
5 |
6 | rllib:
7 | # Whether to clip actions to the action space's low/high range spec.
8 | clip_actions: true
9 | num_gpus: 0
10 | num_workers: 4
11 | framework: torch
12 | lambda: 0.95
13 | kl_coeff: 0.5
14 | clip_rewards: false
15 | clip_param: 0.1
16 | vf_clip_param: 10.0
17 | entropy_coeff: 0.01
18 | train_batch_size: 5000
19 | rollout_fragment_length: 100
20 | sgd_minibatch_size: 500
21 | num_sgd_iter: 10
22 | observation_filter: NoFilter
23 | # Whether to rollout "complete_episodes" or "truncate_episodes".
24 | batch_mode: complete_episodes
25 |
--------------------------------------------------------------------------------
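A minimal sketch (not part of the repo) of how run() in main.py consumes a config like the one above, using the ray[rllib] 1.x Trainer API this repo targets:

```python
# Sketch only: mirrors main.py's run(); paths and keys come from the files above.
from utils.config import load_config

configs = load_config("configs/default_config.yaml")
rllib_config = configs["rllib"]      # the PPO hyperparameters listed above
env_config = configs["env_config"]   # forwarded to make_env() in main.py

# main.py then registers the env with tune.register_env("reach", ...) and builds:
#   trainer = ppo.PPOTrainer(env="reach", config=rllib_config, logger_creator=...)
```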
/utils/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 |
4 | def load_config(config_file: str) -> dict:
5 | """
6 | Load a config file.
7 |
8 | Args:
9 | config_file (str): config file path
10 |
11 | Returns:
12 |         (dict): loaded config
13 | """
14 | with open(config_file, "r") as f:
15 | config = yaml.load(f, Loader=yaml.FullLoader)
16 | return config
17 |
18 |
19 | def save_config(configs: dict, config_file_path: str):
20 | """
21 | Save a config file.
22 | Args:
23 | configs (dict): env, rllib configs
24 |         config_file_path (str): config file path to save
25 | """
26 | with open(config_file_path, "w") as f:
27 | yaml.dump(configs, f)
28 |
--------------------------------------------------------------------------------
/envs/wrappers.py:
--------------------------------------------------------------------------------
1 | from gym import Env, ObservationWrapper
2 | from gym.spaces import Box
3 |
4 | from utils.image import resize_image
5 |
6 |
7 | class ImageObsWrapper(ObservationWrapper):
8 | def __init__(
9 | self, env: Env, use_arm_camera: bool = True, width: int = 84, height: int = 84
10 | ):
11 | super().__init__(env)
12 | self.use_arm_camera = use_arm_camera
13 | self.width = width
14 | self.height = height
15 | self._modify_observation_space()
16 |
17 | def observation(self, _observation):
18 | obs = self.env.render()
19 | obs = resize_image(obs, self.width, self.height)
20 | return obs
21 |
22 | def _modify_observation_space(self):
23 | channel_num = 6 if self.use_arm_camera else 3
24 | self.env.observation_space = Box(0, 255, (self.width, self.height, channel_num))
25 |
--------------------------------------------------------------------------------
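A small usage sketch for the wrapper, mirroring make_env() in main.py (assumes a working CoppeliaSim/PyRep install, since ReachEnv launches a simulator scene):

```python
# Sketch only: wrapping ReachEnv so the agent sees camera images instead of state vectors.
from envs import ReachEnv
from envs.wrappers import ImageObsWrapper

env = ReachEnv(use_arm_camera=True, rendering=False)  # headless CoppeliaSim scene
env = ImageObsWrapper(env, use_arm_camera=True)       # obs become 84x84 images:
                                                      # 6 channels with the arm camera, 3 without
obs = env.reset()                                     # resized render() output, not the 17-dim state
```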
/configs/image_obs_config.yaml:
--------------------------------------------------------------------------------
1 | env : {}
2 | env_config:
3 | use_image_observation: true
4 | use_arm_camera: true
5 |
6 | rllib:
7 | model:
8 | vf_share_layers: false
9 | use_lstm: true
10 | # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
11 | lstm_use_prev_action: false
12 | lstm_cell_size: 256
13 | # Whether to clip actions to the action space's low/high range spec.
14 | clip_actions: true
15 | num_gpus: 1
16 | num_workers: 8
17 | framework: torch
18 | lambda: 0.95
19 | kl_coeff: 0.5
20 | clip_rewards: false
21 | clip_param: 0.1
22 | vf_clip_param: 10.0
23 | entropy_coeff: 0.01
24 | train_batch_size: 5000
25 | rollout_fragment_length: 100
26 | sgd_minibatch_size: 500
27 | num_sgd_iter: 10
28 | observation_filter: NoFilter
29 | # Whether to rollout "complete_episodes" or "truncate_episodes".
30 | batch_mode: complete_episodes
31 |
--------------------------------------------------------------------------------
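Since this config sets use_lstm, actions have to be computed with a recurrent hidden state at test time; test() in main.py does this via make_initial_hidden_state. A reduced sketch, assuming trainer and env were already built as in run():

```python
# Sketch only: rollout with an LSTM policy, as in test() in main.py.
from utils.rllib import make_initial_hidden_state

lstm_cell_size = trainer.config["model"]["lstm_cell_size"]  # 256 in this config
hidden_state = make_initial_hidden_state(lstm_cell_size)    # [zeros(256), zeros(256)]

obs = env.reset()
done = False
while not done:
    action, hidden_state, _ = trainer.compute_action(obs, hidden_state)
    obs, reward, done, info = env.step(action)
```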
/README.md:
--------------------------------------------------------------------------------
1 | # UR3-Deep-Reinforcement-Learning
2 |
3 | Learn various tasks for the UR3 robot arm with a Robotiq 85 gripper.
4 | The learning method is based on deep reinforcement learning (DRL).
5 | This repo uses [CoppeliaSim](http://www.coppeliarobotics.com/) (previously called V-REP) and [PyRep](https://github.com/stepjam/PyRep).
6 | 
7 | - DRL framework: RLlib
8 | - Supported tasks
9 |
10 | |Tasks|Learned Task Example|Learning Curve|
11 | |:---:|:---:|:---:|
12 | |**reach**| *(demo gif)* | *(learning curve plot)* |
13 | |TO DO|-|-|
14 |
15 | ## Install
16 | This repo was tested with Python 3.7.9.
17 |
18 | #### Coppeliasim
19 | PyRep requires CoppeliaSim version **4.1** (other versions may have bugs). Download:
20 | - [Ubuntu 16.04](https://www.coppeliarobotics.com/files/CoppeliaSim_Edu_V4_1_0_Ubuntu16_04.tar.xz)
21 | - [Ubuntu 18.04](https://www.coppeliarobotics.com/files/CoppeliaSim_Edu_V4_1_0_Ubuntu18_04.tar.xz)
22 | - [Ubuntu 20.04](https://www.coppeliarobotics.com/files/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz)
23 |
24 | Add the following to your *~/.bashrc* file (__NOTE__: edit the 'EDIT ME' path in the first line):
25 | ```bash
26 | export COPPELIASIM_ROOT=EDIT/ME/PATH/TO/COPPELIASIM/INSTALL/DIR
27 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT
28 | export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
29 | ```
30 |
31 | #### PyRep
32 | Once you have downloaded and set up CoppeliaSim, you can install PyRep.
33 | From your home workspace:
34 | ```bash
35 | git clone https://github.com/stepjam/PyRep.git
36 | cd PyRep
37 | pip install -r requirements.txt
38 | pip install -e .
39 | ```
40 |
41 | Remember to source your bashrc (`source ~/.bashrc`) or
42 | zshrc (`source ~/.zshrc`) after this.
43 |
44 |
45 | #### LearnToMoveUR3
46 | From your home workspace,
47 | clone the repo and install the Python dependencies:
48 | ```bash
49 | git clone https://github.com/isk03276/LearnToMoveUR3.git
50 | cd LearnToMoveUR3
51 | pip install -r requirements.txt
52 | ```
53 |
54 |
55 | ## Getting Started
56 | ```bash
57 | python main.py --env-id ENV_ID --load-from MODEL_CHECKPOINT_PATH #Train
58 | python main.py --env-id reach --test --load-from MODEL_CHECKPOINT_PATH #Test
59 | ```
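To train with image observations instead of the default state vector, pass the image config (flags as defined in `main.py`):
```bash
python main.py --env-id reach --config-file-path configs/image_obs_config.yaml
```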
60 |
61 |
62 | ## Use Pretrained Model
63 | ```bash
64 | python main.py --env-id reach --load-from pretrained_models/reach --test
65 | ```
66 |
67 |
--------------------------------------------------------------------------------
/envs/reach/reach_env.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | from typing import overload
4 |
5 | from envs.base_env import BaseEnv
6 |
7 | import numpy as np
8 | from pyrep.objects.shape import Shape
9 | from gym.spaces import Box
10 |
11 |
12 | class ReachEnv(BaseEnv):
13 | """
14 | Reach env class.
15 | """
16 |
17 | def __init__(
18 | self,
19 | scene_file="envs/reach/reach.ttt",
20 | use_arm_camera: bool = False,
21 | rendering: bool = True,
22 | ):
23 | super().__init__(scene_file, use_arm_camera, rendering)
24 | self.target = Shape("TargetPoint")
25 | self.target_x_range = (0.2, 0.4)
26 | self.target_y_range = (-0.2, 0.2)
27 | self.target_z_range = (0.4, 0.7)
28 |
29 | self.reach_threshold = 0.05
30 |
31 | def _define_observation_space(self) -> Box:
32 | """
33 | Define/Get observation space.
34 | """
35 | observation_space = Box(float("-inf"), float("inf"), (17,))
36 | return observation_space
37 |
38 | def get_obs(self) -> np.ndarray:
39 | """
40 | Get agent's observation.
41 |         The observation contains the robot state and the relative position of the target object.
42 | Returns:
43 | (np.ndarray) : agent's observation
44 | """
45 | obs = self.get_robot_state()
46 |         target_relative_position = self.get_object_position_relative_to_base_link(
47 |             self.target
48 |         )
49 |         obs.extend(target_relative_position)
50 | return np.array(obs)
51 |
52 | def get_reward(self) -> float:
53 | """
54 |         The reward is based on the distance between the target object and the tip.
55 | Returns:
56 | (float) : reward
57 | """
58 | distance_between_tip_and_target = self.get_distance_from_tip(
59 | self.target.get_position()
60 | )
61 | return -math.log10(distance_between_tip_and_target / 10 + 1)
62 |
63 | def reset_objects(self):
64 | """
65 | Reset a target object.
66 | """
67 | random_point_x = random.uniform(self.target_x_range[0], self.target_x_range[1])
68 | random_point_y = random.uniform(self.target_y_range[0], self.target_y_range[1])
69 | random_point_z = random.uniform(self.target_z_range[0], self.target_z_range[1])
70 | self.target.set_position([random_point_x, random_point_y, random_point_z])
71 |
72 | def is_goal_state(self) -> bool:
73 | """
74 | If the target object and the tip are close, it is considered as a goal state.
75 | """
76 | distance_between_tip_and_target = self.get_distance_from_tip(
77 | self.target.get_position()
78 | )
79 | return distance_between_tip_and_target <= self.reach_threshold
80 |
--------------------------------------------------------------------------------
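Two quick sanity checks on the numbers above (a sketch, not repo code): the 17-dimensional observation is the 14-dimensional robot state from BaseEnv.get_robot_state plus the 3-dimensional relative target position, and the log-based reward stays small and negative near the goal.

```python
# Sketch only: sanity-check ReachEnv's observation size and reward scale.
import math

# BaseEnv.get_robot_state(): 6 arm joint positions + 6 arm joint velocities
# + mean gripper open amount + mean gripper joint velocity = 14 values (the UR3 is a 6-DOF arm).
robot_state_dim = 6 + 6 + 1 + 1
target_dim = 3
assert robot_state_dim + target_dim == 17       # matches Box(-inf, inf, (17,))

# get_reward() returns -log10(d / 10 + 1): 0 at the target, more negative with distance.
for d in (0.0, 0.05, 0.5):                      # 0.05 is reach_threshold
    print(d, -math.log10(d / 10 + 1))           # 0.0, ~-0.0022, ~-0.021
```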
/utils/rllib.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import tempfile
4 | from typing import Callable
5 |
6 | import numpy as np
7 | from ray.rllib.agents.trainer import Trainer
8 | from ray.tune.logger import UnifiedLogger
9 | from ray.rllib.agents.callbacks import DefaultCallbacks
10 |
11 |
12 | def get_current_time() -> str:
13 | """
14 | Generate current time as string.
15 | Returns:
16 | str: current time
17 | """
18 | NOWTIMES = datetime.datetime.now()
19 | curr_time = NOWTIMES.strftime("%y%m%d_%H%M%S")
20 | return curr_time
21 |
22 |
23 | def make_logging_folder(root_dir: str, env_id: str, is_test: bool) -> str:
24 | """
25 | Make a folder for logging.
26 | Args:
27 | root_dir (str): parent directory
28 | env_id (str): env id name
29 | is_test (bool): whether to test the model
30 |
31 | Returns:
32 |         str: created logging folder path
33 | """
34 | task = "Train" if not is_test else "Test"
35 | logdir_prefix = "[{}]{}_{}_".format(task, get_current_time(), env_id)
36 | if not os.path.exists(root_dir):
37 | os.makedirs(root_dir)
38 | logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=root_dir)
39 | return logdir
40 |
41 |
42 | def make_initial_hidden_state(lstm_cell_size: int) -> list:
43 | """
44 | Make initial hidden state for testing lstm-based policy network.
45 | Args:
46 | lstm_cell_size (int): lstm cell size
47 |
48 | Returns:
49 | list: hidden state
50 | """
51 | hidden_state = [np.zeros(lstm_cell_size), np.zeros(lstm_cell_size)]
52 | return hidden_state
53 |
54 |
55 | def save_model(trainer: Trainer, path_to_save: str):
56 | """
57 | Save trained model.
58 | Args:
59 | trainer (Trainer): rllib trainer
60 | """
61 | trainer.save(path_to_save)
62 |
63 |
64 | def load_model(trainer: Trainer, path_to_load: str):
65 | """
66 | Load trained model.
67 | Args:
68 | trainer (Trainer): rllib trainer
69 | path_to_load (str): path to load
70 | """
71 | trainer.restore(path_to_load)
72 |
73 |
74 | def get_logger_creator(logdir: str) -> Callable:
75 | """
76 | Get default logger creator for logging in rllib.
77 | Args:
78 | logdir (str): logging directory path
79 |
80 | Returns:
81 | Callable: logger creator
82 | """
83 |
84 | def logger_creator(config):
85 | return UnifiedLogger(config, logdir, loggers=None)
86 |
87 | return logger_creator
88 |
89 |
90 | class CustomLogCallback(DefaultCallbacks):
91 | """
92 |     Callback based on RLlib's DefaultCallbacks that adds a 'success' custom metric.
93 | """
94 |
95 | def on_episode_end(self, *, episode, **kwargs):
96 | """On episode end, add success custom metric."""
97 | info = episode.last_info_for()
98 | episode.custom_metrics["success"] = int(info["success"])
99 |
--------------------------------------------------------------------------------
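The callback only records a per-episode 'success' value; RLlib aggregates it into custom_metrics["success_mean"], which train() in main.py uses as the stopping criterion. A minimal wiring sketch, assuming rllib_config is the dict loaded from one of the YAML configs:

```python
# Sketch only: registering the callback, as run() in main.py does.
from utils.rllib import CustomLogCallback

rllib_config["callbacks"] = CustomLogCallback   # pass the class, not an instance
# result = trainer.train()
# result["custom_metrics"]["success_mean"]      # aggregated from the per-episode metric
```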
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled Lua sources
2 | luac.out
3 |
4 | # luarocks build files
5 | *.src.rock
6 | *.zip
7 | *.tar.gz
8 |
9 | # Object files
10 | *.o
11 | *.os
12 | *.ko
13 | *.obj
14 | *.elf
15 |
16 | # Precompiled Headers
17 | *.gch
18 | *.pch
19 |
20 | # Libraries
21 | *.lib
22 | *.a
23 | *.la
24 | *.lo
25 | *.def
26 | *.exp
27 |
28 | # Shared objects (inc. Windows DLLs)
29 | *.dll
30 | *.so
31 | *.so.*
32 | *.dylib
33 |
34 | # Executables
35 | *.exe
36 | *.out
37 | *.app
38 | *.i*86
39 | *.x86_64
40 | *.hex
41 |
42 | # Byte-compiled / optimized / DLL files
43 | __pycache__/
44 | *.py[cod]
45 | *$py.class
46 |
47 | # C extensions
48 | *.so
49 |
50 | # Distribution / packaging
51 | .Python
52 | build/
53 | develop-eggs/
54 | dist/
55 | downloads/
56 | eggs/
57 | .eggs/
58 | lib/
59 | lib64/
60 | parts/
61 | sdist/
62 | var/
63 | wheels/
64 | pip-wheel-metadata/
65 | share/python-wheels/
66 | *.egg-info/
67 | .installed.cfg
68 | *.egg
69 | MANIFEST
70 |
71 | # PyInstaller
72 | # Usually these files are written by a python script from a template
73 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
74 | *.manifest
75 | *.spec
76 |
77 | # Installer logs
78 | pip-log.txt
79 | pip-delete-this-directory.txt
80 |
81 | # Unit test / coverage reports
82 | htmlcov/
83 | .tox/
84 | .nox/
85 | .coverage
86 | .coverage.*
87 | .cache
88 | nosetests.xml
89 | coverage.xml
90 | *.cover
91 | *.py,cover
92 | .hypothesis/
93 | .pytest_cache/
94 |
95 | # Translations
96 | *.mo
97 | *.pot
98 |
99 | # Django stuff:
100 | *.log
101 | local_settings.py
102 | db.sqlite3
103 | db.sqlite3-journal
104 |
105 | # Flask stuff:
106 | instance/
107 | .webassets-cache
108 |
109 | # Scrapy stuff:
110 | .scrapy
111 |
112 | # Sphinx documentation
113 | docs/_build/
114 |
115 | # PyBuilder
116 | target/
117 |
118 | # Jupyter Notebook
119 | .ipynb_checkpoints
120 |
121 | # IPython
122 | profile_default/
123 | ipython_config.py
124 |
125 | # pyenv
126 | .python-version
127 |
128 | # pipenv
129 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
130 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
131 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
132 | # install all needed dependencies.
133 | #Pipfile.lock
134 |
135 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
136 | __pypackages__/
137 |
138 | # Celery stuff
139 | celerybeat-schedule
140 | celerybeat.pid
141 |
142 | # SageMath parsed files
143 | *.sage.py
144 |
145 | # Environments
146 | .env
147 | .venv
148 | env/
149 | venv/
150 | ENV/
151 | env.bak/
152 | venv.bak/
153 |
154 | # Spyder project settings
155 | .spyderproject
156 | .spyproject
157 |
158 | # Rope project settings
159 | .ropeproject
160 |
161 | # mkdocs documentation
162 | /site
163 |
164 | # mypy
165 | .mypy_cache/
166 | .dmypy.json
167 | dmypy.json
168 |
169 | # Pyre type checker
170 | .pyre/
171 |
172 | # trained model
173 | checkpoints/
174 | games/
175 | .vscode/
176 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from utils.rllib import (
4 | load_model,
5 | make_initial_hidden_state,
6 | get_logger_creator,
7 | make_logging_folder,
8 | save_model,
9 | CustomLogCallback,
10 | )
11 | from utils.config import load_config, save_config
12 | from envs.wrappers import ImageObsWrapper
13 |
14 | import ray
15 | from ray import tune
16 | from ray.rllib.agents import ppo
17 |
18 |
19 | def make_env(
20 | env_id: str, use_image_observation: bool, use_arm_camera: bool, rendering: bool
21 | ):
22 | if env_id == "reach":
23 | from envs import ReachEnv as env_class
24 | else:
25 | raise NotImplementedError
26 |
27 | env = env_class(use_arm_camera=use_arm_camera, rendering=rendering)
28 | return ImageObsWrapper(env) if use_image_observation else env
29 |
30 |
31 | def train(trainer, target_success_mean, path_to_save, save_interval):
32 | status = "[Train] {:2d} reward {:6.2f} len {:6.2f} success mean {:6.2f}"
33 |
34 | iteration = 0
35 | while True:
36 | result = trainer.train()
37 | success_mean = result["custom_metrics"]["success_mean"]
38 | print(
39 | status.format(
40 | iteration,
41 | result["episode_reward_mean"],
42 | result["episode_len_mean"],
43 | success_mean,
44 | )
45 | )
46 | iteration += 1
47 | if iteration % save_interval == 0:
48 | save_model(trainer, path_to_save)
49 | if success_mean >= target_success_mean:
50 | save_model(trainer, path_to_save)
51 | break
52 |
53 |
54 | def test(env, trainer, test_num):
55 | use_lstm = trainer.config.get("model").get("use_lstm")
56 | lstm_cell_size = trainer.config.get("model").get("lstm_cell_size")
57 | success_list = []
58 | for ep in range(test_num):
59 | done = False
60 | obs = env.reset()
61 | rews = []
62 | if use_lstm:
63 | hidden_state = make_initial_hidden_state(lstm_cell_size)
64 |
65 | status = "[Test] {:2d} reward {:6.2f} len {:6.2f}, success mean {:6.2f}"
66 |
67 | while not done:
68 | if use_lstm:
69 | action, hidden_state, _ = trainer.compute_action(obs, hidden_state)
70 | else:
71 | action = trainer.compute_action(obs)
72 | obs, rew, done, info = env.step(action)
73 | rews.append(rew)
74 | success_list.append(int(info["success"]))
75 | print(
76 | status.format(
77 | ep + 1,
78 | sum(rews) / len(rews),
79 | len(rews),
80 | sum(success_list) / len(success_list),
81 | )
82 | )
83 |
84 |
85 | def run(args):
86 | # load rllib config
87 | ray.init()
88 | configs = load_config(args.config_file_path)
89 | configs_to_save = configs.copy()
90 | rllib_configs = configs["rllib"]
91 | rllib_configs["callbacks"] = CustomLogCallback
92 |
93 | # env setting
94 | env_id = args.env_id
95 | env_config = configs["env_config"]
96 | env_args = {
97 | "env_id": env_id,
98 | "use_image_observation": env_config["use_image_observation"],
99 | "use_arm_camera": env_config["use_arm_camera"],
100 | "rendering": False if args.test else args.render,
101 | }
102 | tune.register_env(
103 | env_id, lambda _: make_env(**env_args),
104 | )
105 |
106 | # logging setting
107 | logdir = make_logging_folder(
108 | root_dir="checkpoints/", env_id=env_id, is_test=args.test
109 | )
110 | save_config(configs_to_save, logdir + "/config.yaml")
111 | logger_creator = get_logger_creator(logdir=logdir)
112 |
113 | # rllib trainer setting
114 | trainer = ppo.PPOTrainer(
115 | env=env_id, config=rllib_configs, logger_creator=logger_creator
116 | )
117 |
118 | if args.load_from is not None:
119 | load_model(trainer, args.load_from)
120 |
121 | if not args.test:
122 | train(trainer, args.target_success_mean, logdir, args.save_interval)
123 |
124 | env_args["rendering"] = args.render
125 | test_env = make_env(**env_args)
126 | test(test_env, trainer, args.test_num)
127 |
128 | ray.shutdown()
129 |
130 |
131 | if __name__ == "__main__":
132 | parser = argparse.ArgumentParser(description="Game environments to learn")
133 | parser.add_argument(
134 | "--env-id", default="reach", type=str, help="game environment id: 'reach', ..."
135 | )
136 | parser.add_argument(
137 | "--config-file-path",
138 | default="configs/default_config.yaml",
139 | type=str,
140 | help="Rllib config file path",
141 | )
142 | parser.add_argument("--render", action="store_true", help="Turn on rendering")
143 | # model
144 | parser.add_argument(
145 | "--save-interval", type=int, default=20, help="Model save interval"
146 | )
147 | parser.add_argument("--load-from", type=str, help="Path to load the model")
148 | # train/test
149 | parser.add_argument(
150 | "--target-success-mean",
151 | type=float,
152 | default=0.99,
153 | help="Learning ends when the current success mean is higher than ther target success mean",
154 | )
155 | parser.add_argument("--test", action="store_true", help="Whether to test the model")
156 | parser.add_argument(
157 | "--test-num", type=int, default=10, help="Number of episodes to test the model"
158 | )
159 |
160 | args = parser.parse_args()
161 | run(args)
162 |
--------------------------------------------------------------------------------
/envs/base_env.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 |
3 | from utils.geometry import get_distance_between_two_pts
4 |
5 | import gym
6 | from gym.spaces import Box
7 | import numpy as np
8 | from pyrep import PyRep
9 | from pyrep.robots.arms.ur3 import UR3
10 | from pyrep.robots.end_effectors.robotiq85_gripper import Robotiq85Gripper
11 | from pyrep.objects.vision_sensor import VisionSensor
12 | from pyrep.objects.shape import Shape
13 | from pyrep.objects import Object
14 |
15 |
16 | class BaseEnv(gym.Env):
17 | """
18 | Base environment class for learning behavior of UR3 robot.
19 | """
20 |
21 | def __init__(
22 | self, scene_file: str, use_arm_camera: bool = False, rendering: bool = True
23 | ):
24 | super().__init__()
25 | self.use_arm_camera = use_arm_camera
26 | self.env = PyRep()
27 | self.env.launch(scene_file, headless=not rendering)
28 | self.env.start()
29 | self.arm = None
30 | self.arm_base_link = None
31 | self.gripper = None
32 | self.tip = None
33 | self.third_view_camera = VisionSensor("kinect_rgb")
34 | self.arm_camera = VisionSensor("arm_camera_rgb")
35 | self.gripper_velocity = 0.2
36 | self._init_robot()
37 |
38 | self.max_time_step = 300
39 | self.current_time_step = 0
40 |
41 | # for the agent to keep the goal states
42 | self.max_consecutive_visit_to_goal = 5
43 | self.cur_consecutive_visit_to_goal = 0
44 |
45 | self.observation_space = self._define_observation_space()
46 | self.action_space = Box(-1.0, 1, (7,))
47 |
48 | def _init_robot(self):
49 | """
50 | Initialize robot.
51 | Assume that (robot arm) UR3 with (gripper) Robotiq85Gripper.
52 | """
53 | self.arm = UR3()
54 | self.arm_base_link = Shape("UR3_link2")
55 | self.gripper = Robotiq85Gripper()
56 | self.tip = self.arm.get_tip()
57 |
58 | def _define_observation_space(self) -> Box:
59 | """
60 | Define/Get observation space.
61 | Returns:
62 | (Box) : defined observation space
63 | """
64 | raise NotImplementedError
65 |
66 | def get_obs(self) -> np.ndarray:
67 | """
68 | Get agent's observation.
69 | Returns:
70 | (np.ndarray) : agent's observation
71 | """
72 | raise NotImplementedError
73 |
74 | def get_reward(self) -> float:
75 | """
76 | Get reward for state/action.
77 | We currently assume reward for state.
78 | Returns:
79 | (float) : reward
80 | """
81 | raise NotImplementedError
82 |
83 | def reset_objects(self):
84 | """
85 | Reset objects in env. (ex. target_object)
86 | """
87 | raise NotImplementedError
88 |
89 | def is_goal_state(self) -> bool:
90 | """
91 | Whether the current state is a goal state or not.
92 | Returns :
93 | (bool) : if current state is a goal state, then return True
94 | """
95 | raise NotImplementedError
96 |
97 | def render(self) -> np.ndarray:
98 | """
99 | Get image observation from cameras.
100 | Returns:
101 | np.ndarray: image observation
102 | """
103 | obs = self.third_view_camera.capture_rgb()
104 | if self.use_arm_camera:
105 | first_view_image = self.arm_camera.capture_rgb()
106 | obs = np.concatenate((obs, first_view_image), axis=2)
107 | return obs
108 |
109 | def get_done_and_info(self) -> Tuple[bool, dict]:
110 | """
111 | Get done and info.
112 |         The done flag indicates whether the episode is over.
113 |         The info dict contains extra information, such as whether the episode succeeded.
114 | Returns :
115 | (bool) : done
116 | (dict) : info
117 | """
118 | is_success = self.is_success()
119 | return self.time_over() | is_success, {"success": is_success}
120 |
121 | def is_success(self) -> bool:
122 | """
123 |         Whether the episode has succeeded.
124 |         Returns:
125 |             (bool) : True if the episode has succeeded
126 | """
127 | if self.is_goal_state():
128 | self.cur_consecutive_visit_to_goal += 1
129 | else:
130 | self.cur_consecutive_visit_to_goal = 0
131 | return self.cur_consecutive_visit_to_goal >= self.max_consecutive_visit_to_goal
132 |
133 | def time_over(self) -> bool:
134 | """
135 |         Whether the episode has run past the maximum time step.
136 |         Returns:
137 |             (bool) : True if the time step limit has been reached
138 |         """
139 |         return self.current_time_step >= self.max_time_step
140 |
141 | def reset(self) -> np.ndarray:
142 | """
143 |         Reset the env and start a new episode.
144 | Returns:
145 | (np.ndarray) : initial obs
146 | """
147 | self.current_time_step = 0
148 | self.cur_consecutive_visit_to_goal = 0
149 | self.env.stop()
150 | self.env.start()
151 | self.arm.set_control_loop_enabled(False)
152 | self.arm.set_motor_locked_at_zero_velocity(True)
153 | self.reset_objects()
154 | self.env.step()
155 | return self.get_obs()
156 |
157 | def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
158 | """
159 |         Execute the action and observe the new obs, reward, etc.
160 | Args:
161 | action (np.ndarray): the action decided by the agent
162 |
163 | Returns:
164 | (np.ndarray) : new obs
165 | (float) : reward
166 | (bool) : done
167 | (dict) : info
168 | """
169 | self.current_time_step += 1
170 |         arm_control = action[:-1]
171 |         # Binarize the last action value into an open (1.0) / close (0.0) gripper command.
172 |         gripper_control = 1.0 if action[-1] > 0.0 else 0.0
173 | self.arm.set_joint_target_velocities(arm_control)
174 | self.gripper.actuate(gripper_control, self.gripper_velocity)
175 | self.env.step()
176 | done, info = self.get_done_and_info()
177 | return self.get_obs(), self.get_reward(), done, info
178 |
179 | def close(self):
180 | """
181 | Close the env.
182 | """
183 | self.env.stop()
184 | self.env.shutdown()
185 |
186 | def get_distance_from_tip(self, object_position: np.ndarray) -> float:
187 | """
188 | Get distance between tip and target object
189 | Args:
190 | object_position (np.ndarray): position of the target object
191 |
192 | Returns:
193 | (float): distance
194 | """
195 | return get_distance_between_two_pts(self.tip.get_position(), object_position)
196 |
197 | def get_robot_state(self) -> List[float]:
198 | """
199 | Get the robot state.
200 |         The robot state contains the robot arm state and the gripper state.
201 |         The arm state contains the arm joint positions and velocities.
202 |         The gripper state contains the mean gripper open amount and the mean gripper joint velocity.
203 | Returns:
204 | (list) : robot state
205 | """
206 | arm_joint_positions = self.arm.get_joint_positions()
207 | arm_joint_velocities = self.arm.get_joint_velocities()
208 | gripper_positions = self.gripper.get_open_amount()
209 | gripper_velocities = self.gripper.get_joint_velocities()
210 | gripper_position_mean = sum(gripper_positions) / len(gripper_positions)
211 | gripper_velocity_mean = sum(gripper_velocities) / len(gripper_velocities)
212 | robot_state = arm_joint_positions
213 | robot_state.extend(arm_joint_velocities)
214 | robot_state.append(gripper_position_mean)
215 | robot_state.append(gripper_velocity_mean)
216 | return robot_state
217 |
218 | def get_object_position_relative_to_base_link(
219 | self, target_object: Object
220 | ) -> np.ndarray:
221 | """
222 |         Get the position of the target object relative to the robot arm base link.
223 | Args:
224 | target_object (Object): target object
225 |
226 | Returns:
227 | (np.ndarray): position array containing x, y, z
228 | """
229 | return target_object.get_position(relative_to=self.arm_base_link)
230 |
--------------------------------------------------------------------------------
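A minimal gym-style interaction loop against the base env's API (a sketch; requires CoppeliaSim 4.1 and PyRep as described in the README):

```python
# Sketch only: random-action rollout against ReachEnv.
from envs import ReachEnv

env = ReachEnv(rendering=False)         # launches a headless CoppeliaSim session
obs = env.reset()                       # 17-dim state observation
done = False
while not done:
    action = env.action_space.sample()  # 7 values: 6 joint target velocities + 1 gripper command
    obs, reward, done, info = env.step(action)
print(info["success"])                  # True only after 5 consecutive steps at the goal
env.close()
```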