├── pretrained_models
│   └── reach
│       └── checkpoint_000575
│           ├── .is_checkpoint
│           ├── checkpoint-575
│           ├── checkpoint-575.tune_metadata
│           └── config.yaml
├── envs
│   ├── __init__.py
│   ├── reach
│   │   ├── reach.ttt
│   │   └── reach_env.py
│   ├── wrappers.py
│   └── base_env.py
├── requirements.txt
├── utils
│   ├── geometry.py
│   ├── image.py
│   ├── config.py
│   └── rllib.py
├── configs
│   ├── default_config.yaml
│   └── image_obs_config.yaml
├── README.md
├── .gitignore
└── main.py

/pretrained_models/reach/checkpoint_000575/.is_checkpoint:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/envs/__init__.py:
--------------------------------------------------------------------------------
from envs.reach.reach_env import ReachEnv
--------------------------------------------------------------------------------
/envs/reach/reach.ttt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isk03276/LearnToMoveUR3/HEAD/envs/reach/reach.ttt
--------------------------------------------------------------------------------
/pretrained_models/reach/checkpoint_000575/checkpoint-575:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isk03276/LearnToMoveUR3/HEAD/pretrained_models/reach/checkpoint_000575/checkpoint-575
--------------------------------------------------------------------------------
/pretrained_models/reach/checkpoint_000575/checkpoint-575.tune_metadata:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isk03276/LearnToMoveUR3/HEAD/pretrained_models/reach/checkpoint_000575/checkpoint-575.tune_metadata
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy

gym

# ML
ray[rllib]

# OpenCV
opencv-python-headless==4.4.0.46

# formatting
black==22.1.0

# ML framework
torch==1.10.0
tensorboard==2.8.0
--------------------------------------------------------------------------------
/utils/geometry.py:
--------------------------------------------------------------------------------
import math

import numpy as np


def get_distance_between_two_pts(pts1: np.ndarray, pts2: np.ndarray) -> float:
    """
    Calculate the distance between two points.
    Args:
        pts1 (np.ndarray): first point
        pts2 (np.ndarray): second point

    Returns:
        float: distance
    """
    return np.linalg.norm(pts1 - pts2)
--------------------------------------------------------------------------------
/utils/image.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np


def resize_image(image: np.ndarray, width: int, height: int) -> np.ndarray:
    """
    Resize an image.
    Args:
        image (np.ndarray): input image
        width (int): width to resize to
        height (int): height to resize to

    Returns:
        np.ndarray: resized image
    """
    return cv2.resize(image, dsize=(width, height))
--------------------------------------------------------------------------------
/pretrained_models/reach/checkpoint_000575/config.yaml:
--------------------------------------------------------------------------------
env: {}
env_config:
  use_arm_camera: false
  use_image_observation: false
rllib:
  batch_mode: complete_episodes
  callbacks: !!python/name:utils.rllib.CustomLogCallback ''
  clip_actions: true
  clip_param: 0.1
  clip_rewards: false
  entropy_coeff: 0.01
  framework: torch
  kl_coeff: 0.5
  lambda: 0.95
  num_gpus: 1
  num_sgd_iter: 10
  num_workers: 8
  observation_filter: NoFilter
  rollout_fragment_length: 100
  sgd_minibatch_size: 500
  train_batch_size: 5000
  vf_clip_param: 10.0
--------------------------------------------------------------------------------
/configs/default_config.yaml:
--------------------------------------------------------------------------------
env: {}
env_config:
  use_image_observation: false
  use_arm_camera: false

rllib:
  # Whether to clip actions to the action space's low/high range spec.
  clip_actions: true
  num_gpus: 0
  num_workers: 4
  framework: torch
  lambda: 0.95
  kl_coeff: 0.5
  clip_rewards: false
  clip_param: 0.1
  vf_clip_param: 10.0
  entropy_coeff: 0.01
  train_batch_size: 5000
  rollout_fragment_length: 100
  sgd_minibatch_size: 500
  num_sgd_iter: 10
  observation_filter: NoFilter
  # Whether to rollout "complete_episodes" or "truncate_episodes".
  batch_mode: complete_episodes
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
import yaml


def load_config(config_file: str) -> dict:
    """
    Load a config file.

    Args:
        config_file (str): config file path

    Returns:
        (dict): loaded config info
    """
    with open(config_file, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config


def save_config(configs: dict, config_file_path: str):
    """
    Save a config file.
    Args:
        configs (dict): env, rllib configs
        config_file_path (str): config file path to save
    """
    with open(config_file_path, "w") as f:
        yaml.dump(configs, f)
--------------------------------------------------------------------------------
/envs/wrappers.py:
--------------------------------------------------------------------------------
from gym import Env, ObservationWrapper
from gym.spaces import Box

from utils.image import resize_image


class ImageObsWrapper(ObservationWrapper):
    def __init__(
        self, env: Env, use_arm_camera: bool = True, width: int = 84, height: int = 84
    ):
        super().__init__(env)
        self.use_arm_camera = use_arm_camera
        self.width = width
        self.height = height
        self._modify_observation_space()

    def observation(self, _observation):
        obs = self.env.render()
        obs = resize_image(obs, self.width, self.height)
        return obs

    def _modify_observation_space(self):
        channel_num = 6 if self.use_arm_camera else 3
        self.env.observation_space = Box(0, 255, (self.width, self.height, channel_num))
--------------------------------------------------------------------------------
/configs/image_obs_config.yaml:
--------------------------------------------------------------------------------
env: {}
env_config:
  use_image_observation: true
  use_arm_camera: true

rllib:
  model:
    vf_share_layers: false
    use_lstm: true
    # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
    lstm_use_prev_action: false
    lstm_cell_size: 256
  # Whether to clip actions to the action space's low/high range spec.
  clip_actions: true
  num_gpus: 1
  num_workers: 8
  framework: torch
  lambda: 0.95
  kl_coeff: 0.5
  clip_rewards: false
  clip_param: 0.1
  vf_clip_param: 10.0
  entropy_coeff: 0.01
  train_batch_size: 5000
  rollout_fragment_length: 100
  sgd_minibatch_size: 500
  num_sgd_iter: 10
  observation_filter: NoFilter
  # Whether to rollout "complete_episodes" or "truncate_episodes".
  batch_mode: complete_episodes
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UR3-Deep-Reinforcement-Learning

This repo lets you teach a UR3 robot arm fitted with a Robotiq 85 gripper to perform various tasks.
The learning method is deep reinforcement learning (DRL).
In this repo, we use [CoppeliaSim](http://www.coppeliarobotics.com/) (previously called V-REP) and [PyRep](https://github.com/stepjam/PyRep).

- DRL framework: RLlib
- Supported tasks

|Tasks|Learned Task Example|Learning Curve|
|:---:|:---:|:---:|
|**reach**|||
|TO DO|-|-|

## Install
This repo was tested with Python 3.7.9.

#### CoppeliaSim
PyRep requires CoppeliaSim version **4.1** (other versions may have bugs).
Download:
- [Ubuntu 16.04](https://www.coppeliarobotics.com/files/CoppeliaSim_Edu_V4_1_0_Ubuntu16_04.tar.xz)
- [Ubuntu 18.04](https://www.coppeliarobotics.com/files/CoppeliaSim_Edu_V4_1_0_Ubuntu18_04.tar.xz)
- [Ubuntu 20.04](https://www.coppeliarobotics.com/files/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz)

Add the following to your *~/.bashrc* file (__NOTE__: edit the 'EDIT ME' path in the first line):
```bash
export COPPELIASIM_ROOT=EDIT/ME/PATH/TO/COPPELIASIM/INSTALL/DIR
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT
export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
```

#### PyRep
Once you have downloaded and set up CoppeliaSim, you can install PyRep.
Move to your home workspace and run:
```bash
git clone https://github.com/stepjam/PyRep.git
cd PyRep
pip install -r requirements.txt
pip install -e .
```

Remember to source your bashrc (`source ~/.bashrc`) or
zshrc (`source ~/.zshrc`) after this.


#### LearnToMoveUR3
Move to your home workspace, clone the repo, and install the Python dependencies:
```bash
git clone https://github.com/isk03276/LearnToMoveUR3.git
cd LearnToMoveUR3
pip install -r requirements.txt
```


## Getting Started
```bash
python main.py --env-id ENV_ID --load-from MODEL_CHECKPOINT_PATH # Train
python main.py --env-id reach --test --load-from MODEL_CHECKPOINT_PATH # Test
```


## Use Pretrained Model
```bash
python main.py --env-id reach --load-from pretrained_models/reach --test
```
--------------------------------------------------------------------------------
/envs/reach/reach_env.py:
--------------------------------------------------------------------------------
import random
import math
from typing import overload

from envs.base_env import BaseEnv

import numpy as np
from pyrep.objects.shape import Shape
from gym.spaces import Box


class ReachEnv(BaseEnv):
    """
    Reach env class.
    """

    def __init__(
        self,
        scene_file="envs/reach/reach.ttt",
        use_arm_camera: bool = False,
        rendering: bool = True,
    ):
        super().__init__(scene_file, use_arm_camera, rendering)
        self.target = Shape("TargetPoint")
        self.target_x_range = (0.2, 0.4)
        self.target_y_range = (-0.2, 0.2)
        self.target_z_range = (0.4, 0.7)

        self.reach_threshold = 0.05

    def _define_observation_space(self) -> Box:
        """
        Define/Get the observation space.
        """
        observation_space = Box(float("-inf"), float("inf"), (17,))
        return observation_space

    def get_obs(self) -> np.ndarray:
        """
        Get the agent's observation.
        The observation contains the robot state and the relative position of the target object.
        Returns:
            (np.ndarray) : agent's observation
        """
        obs = self.get_robot_state()
        target_relative_position = self.get_object_position_relative_to_base_link(
            self.target
        )
        obs.extend(target_relative_position)
        return np.array(obs)

    def get_reward(self) -> float:
        """
        This reward function is based on the distance between the target object and the tip.
        Returns:
            (float) : reward
        """
        distance_between_tip_and_target = self.get_distance_from_tip(
            self.target.get_position()
        )
        return -math.log10(distance_between_tip_and_target / 10 + 1)

    def reset_objects(self):
        """
        Reset the target object.
        """
        random_point_x = random.uniform(self.target_x_range[0], self.target_x_range[1])
        random_point_y = random.uniform(self.target_y_range[0], self.target_y_range[1])
        random_point_z = random.uniform(self.target_z_range[0], self.target_z_range[1])
        self.target.set_position([random_point_x, random_point_y, random_point_z])

    def is_goal_state(self) -> bool:
        """
        If the target object and the tip are close, it is considered a goal state.
        """
        distance_between_tip_and_target = self.get_distance_from_tip(
            self.target.get_position()
        )
        return distance_between_tip_and_target <= self.reach_threshold
--------------------------------------------------------------------------------
/utils/rllib.py:
--------------------------------------------------------------------------------
import datetime
import os
import tempfile
from typing import Callable

import numpy as np
from ray.rllib.agents.trainer import Trainer
from ray.tune.logger import UnifiedLogger
from ray.rllib.agents.callbacks import DefaultCallbacks


def get_current_time() -> str:
    """
    Generate the current time as a string.
    Returns:
        str: current time
    """
    NOWTIMES = datetime.datetime.now()
    curr_time = NOWTIMES.strftime("%y%m%d_%H%M%S")
    return curr_time


def make_logging_folder(root_dir: str, env_id: str, is_test: bool) -> str:
    """
    Make a folder for logging.
    Args:
        root_dir (str): parent directory
        env_id (str): env id name
        is_test (bool): whether to test the model

    Returns:
        str: created logging folder path
    """
    task = "Train" if not is_test else "Test"
    logdir_prefix = "[{}]{}_{}_".format(task, get_current_time(), env_id)
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)
    logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=root_dir)
    return logdir


def make_initial_hidden_state(lstm_cell_size: int) -> list:
    """
    Make an initial hidden state for testing an LSTM-based policy network.
    Args:
        lstm_cell_size (int): lstm cell size

    Returns:
        list: hidden state
    """
    hidden_state = [np.zeros(lstm_cell_size), np.zeros(lstm_cell_size)]
    return hidden_state


def save_model(trainer: Trainer, path_to_save: str):
    """
    Save the trained model.
    Args:
        trainer (Trainer): rllib trainer
        path_to_save (str): path to save
    """
    trainer.save(path_to_save)


def load_model(trainer: Trainer, path_to_load: str):
    """
    Load a trained model.
    Args:
        trainer (Trainer): rllib trainer
        path_to_load (str): path to load
    """
    trainer.restore(path_to_load)


def get_logger_creator(logdir: str) -> Callable:
    """
    Get the default logger creator for logging in rllib.
    Args:
        logdir (str): logging directory path

    Returns:
        Callable: logger creator
    """

    def logger_creator(config):
        return UnifiedLogger(config, logdir, loggers=None)

    return logger_creator


class CustomLogCallback(DefaultCallbacks):
    """
    Callback based on RLlib's DefaultCallbacks that adds a 'success' custom metric.
    """

    def on_episode_end(self, *, episode, **kwargs):
        """On episode end, add the 'success' custom metric."""
        info = episode.last_info_for()
        episode.custom_metrics["success"] = int(info["success"])
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Compiled Lua sources
luac.out

# luarocks build files
*.src.rock
*.zip
*.tar.gz

# Object files
*.o
*.os
*.ko
*.obj
*.elf

# Precompiled Headers
*.gch
*.pch

# Libraries
*.lib
*.a
*.la
*.lo
*.def
*.exp

# Shared objects (inc. Windows DLLs)
*.dll
*.so
*.so.*
*.dylib

# Executables
*.exe
*.out
*.app
*.i*86
*.x86_64
*.hex

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g.
# github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# trained model
checkpoints/
games/
.vscode/
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse

from utils.rllib import (
    load_model,
    make_initial_hidden_state,
    get_logger_creator,
    make_logging_folder,
    save_model,
    CustomLogCallback,
)
from utils.config import load_config, save_config
from envs.wrappers import ImageObsWrapper

import ray
from ray import tune
from ray.rllib.agents import ppo


def make_env(
    env_id: str, use_image_observation: bool, use_arm_camera: bool, rendering: bool
):
    if env_id == "reach":
        from envs import ReachEnv as env_class
    else:
        raise NotImplementedError

    env = env_class(use_arm_camera=use_arm_camera, rendering=rendering)
    return ImageObsWrapper(env) if use_image_observation else env


def train(trainer, target_success_mean, path_to_save, save_interval):
    status = "[Train] {:2d} reward {:6.2f} len {:6.2f} success mean {:6.2f}"

    iteration = 0
    while True:
        result = trainer.train()
        success_mean = result["custom_metrics"]["success_mean"]
        print(
            status.format(
                iteration,
                result["episode_reward_mean"],
                result["episode_len_mean"],
                success_mean,
            )
        )
        iteration += 1
        if iteration % save_interval == 0:
            save_model(trainer, path_to_save)
        if success_mean >= target_success_mean:
            save_model(trainer, path_to_save)
            break


def test(env, trainer, test_num):
    use_lstm = trainer.config.get("model").get("use_lstm")
    lstm_cell_size = trainer.config.get("model").get("lstm_cell_size")
    success_list = []
    for ep in range(test_num):
        done = False
        obs = env.reset()
        rews = []
        if use_lstm:
            hidden_state = make_initial_hidden_state(lstm_cell_size)

        status = "[Test] {:2d} reward {:6.2f} len {:6.2f}, success mean {:6.2f}"

        while not done:
            if use_lstm:
                action, hidden_state, _ = trainer.compute_action(obs, hidden_state)
            else:
                action = trainer.compute_action(obs)
            obs, rew, done, info = env.step(action)
            rews.append(rew)
        success_list.append(int(info["success"]))
        print(
            status.format(
                ep + 1,
                sum(rews) / len(rews),
                len(rews),
                sum(success_list) / len(success_list),
            )
        )


def run(args):
    # load rllib config
    ray.init()
    configs = load_config(args.config_file_path)
    configs_to_save = configs.copy()
    rllib_configs = configs["rllib"]
    rllib_configs["callbacks"] = CustomLogCallback

    # env setting
    env_id = args.env_id
    env_config = configs["env_config"]
    env_args = {
        "env_id": env_id,
        "use_image_observation":
            env_config["use_image_observation"],
        "use_arm_camera": env_config["use_arm_camera"],
        "rendering": False if args.test else args.render,
    }
    tune.register_env(
        env_id, lambda _: make_env(**env_args),
    )

    # logging setting
    logdir = make_logging_folder(
        root_dir="checkpoints/", env_id=env_id, is_test=args.test
    )
    save_config(configs_to_save, logdir + "/config.yaml")
    logger_creator = get_logger_creator(logdir=logdir)

    # rllib trainer setting
    trainer = ppo.PPOTrainer(
        env=env_id, config=rllib_configs, logger_creator=logger_creator
    )

    if args.load_from is not None:
        load_model(trainer, args.load_from)

    if not args.test:
        train(trainer, args.target_success_mean, logdir, args.save_interval)

    env_args["rendering"] = args.render
    test_env = make_env(**env_args)
    test(test_env, trainer, args.test_num)

    ray.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Game environments to learn")
    parser.add_argument(
        "--env-id", default="reach", type=str, help="game environment id: 'reach', ..."
    )
    parser.add_argument(
        "--config-file-path",
        default="configs/default_config.yaml",
        type=str,
        help="Rllib config file path",
    )
    parser.add_argument("--render", action="store_true", help="Turn on rendering")
    # model
    parser.add_argument(
        "--save-interval", type=int, default=20, help="Model save interval"
    )
    parser.add_argument("--load-from", type=str, help="Path to load the model")
    # train/test
    parser.add_argument(
        "--target-success-mean",
        type=float,
        default=0.99,
        help="Learning ends when the current success mean is higher than the target success mean",
    )
    parser.add_argument("--test", action="store_true", help="Whether to test the model")
    parser.add_argument(
        "--test-num", type=int, default=10, help="Number of episodes to test the model"
    )

    args = parser.parse_args()
    run(args)
--------------------------------------------------------------------------------
/envs/base_env.py:
--------------------------------------------------------------------------------
from typing import Tuple, List

from utils.geometry import get_distance_between_two_pts

import gym
from gym.spaces import Box
import numpy as np
from pyrep import PyRep
from pyrep.robots.arms.ur3 import UR3
from pyrep.robots.end_effectors.robotiq85_gripper import Robotiq85Gripper
from pyrep.objects.vision_sensor import VisionSensor
from pyrep.objects.shape import Shape
from pyrep.objects import Object


class BaseEnv(gym.Env):
    """
    Base environment class for learning behaviors of the UR3 robot.
    """

    def __init__(
        self, scene_file: str, use_arm_camera: bool = False, rendering: bool = True
    ):
        super().__init__()
        self.use_arm_camera = use_arm_camera
        self.env = PyRep()
        self.env.launch(scene_file, headless=not rendering)
        self.env.start()
        self.arm = None
        self.arm_base_link = None
        self.gripper = None
        self.tip = None
        self.third_view_camera = VisionSensor("kinect_rgb")
        self.arm_camera = VisionSensor("arm_camera_rgb")
        self.gripper_velocity = 0.2
        self._init_robot()

        self.max_time_step = 300
        self.current_time_step = 0

        # the agent must stay at the goal for several consecutive steps to succeed
        self.max_consecutive_visit_to_goal = 5
        self.cur_consecutive_visit_to_goal = 0

        self.observation_space = self._define_observation_space()
        self.action_space = Box(-1.0, 1, (7,))

    def _init_robot(self):
        """
        Initialize the robot.
        Assumes a UR3 arm with a Robotiq85 gripper.
        """
        self.arm = UR3()
        self.arm_base_link = Shape("UR3_link2")
        self.gripper = Robotiq85Gripper()
        self.tip = self.arm.get_tip()

    def _define_observation_space(self) -> Box:
        """
        Define/Get the observation space.
        Returns:
            (Box) : defined observation space
        """
        raise NotImplementedError

    def get_obs(self) -> np.ndarray:
        """
        Get the agent's observation.
        Returns:
            (np.ndarray) : agent's observation
        """
        raise NotImplementedError

    def get_reward(self) -> float:
        """
        Get the reward for the state/action.
        We currently assume the reward depends only on the state.
        Returns:
            (float) : reward
        """
        raise NotImplementedError

    def reset_objects(self):
        """
        Reset objects in the env (e.g. the target object).
        """
        raise NotImplementedError

    def is_goal_state(self) -> bool:
        """
        Whether the current state is a goal state or not.
        Returns :
            (bool) : True if the current state is a goal state
        """
        raise NotImplementedError

    def render(self) -> np.ndarray:
        """
        Get an image observation from the cameras.
        Returns:
            np.ndarray: image observation
        """
        obs = self.third_view_camera.capture_rgb()
        if self.use_arm_camera:
            first_view_image = self.arm_camera.capture_rgb()
            obs = np.concatenate((obs, first_view_image), axis=2)
        return obs

    def get_done_and_info(self) -> Tuple[bool, dict]:
        """
        Get done and info.
        done indicates whether the episode is over.
        info contains extra information, such as whether the episode succeeded.
        Returns :
            (bool) : done
            (dict) : info
        """
        is_success = self.is_success()
        return self.time_over() or is_success, {"success": is_success}

    def is_success(self) -> bool:
        """
        Whether the episode has succeeded.
        Returns:
            (bool) : True if the episode has succeeded
        """
        if self.is_goal_state():
            self.cur_consecutive_visit_to_goal += 1
        else:
            self.cur_consecutive_visit_to_goal = 0
        return self.cur_consecutive_visit_to_goal >= self.max_consecutive_visit_to_goal

    def time_over(self) -> bool:
        """
        Whether the episode has run beyond the maximum time step.
        Returns:
            (bool) : True if the maximum time step has been exceeded
        """
        return self.current_time_step >= self.max_time_step

    def reset(self) -> np.ndarray:
        """
        Reset the env and start a new episode.
        Returns:
            (np.ndarray) : initial obs
        """
        self.current_time_step = 0
        self.cur_consecutive_visit_to_goal = 0
        self.env.stop()
        self.env.start()
        self.arm.set_control_loop_enabled(False)
        self.arm.set_motor_locked_at_zero_velocity(True)
        self.reset_objects()
        self.env.step()
        return self.get_obs()

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
        """
        Execute the action and observe the new obs, reward, etc.
        Args:
            action (np.ndarray): the action decided by the agent

        Returns:
            (np.ndarray) : new obs
            (float) : reward
            (bool) : done
            (dict) : info
        """
        self.current_time_step += 1
        arm_control = action[:-1]
        gripper_control = 1.0 if action[-1] > 0.0 else 0.0
        self.arm.set_joint_target_velocities(arm_control)
        self.gripper.actuate(gripper_control, self.gripper_velocity)
        self.env.step()
        done, info = self.get_done_and_info()
        return self.get_obs(), self.get_reward(), done, info

    def close(self):
        """
        Close the env.
        """
        self.env.stop()
        self.env.shutdown()

    def get_distance_from_tip(self, object_position: np.ndarray) -> float:
        """
        Get the distance between the tip and the target object.
        Args:
            object_position (np.ndarray): position of the target object

        Returns:
            (float): distance
        """
        return get_distance_between_two_pts(self.tip.get_position(), object_position)

    def get_robot_state(self) -> List[float]:
        """
        Get the robot state.
        The robot state contains the robot arm state and the gripper state.
        The arm state contains the positions and velocities of the arm joints.
        The gripper state contains the mean open amount and the mean joint velocity of the gripper.
        Returns:
            (list) : robot state
        """
        arm_joint_positions = self.arm.get_joint_positions()
        arm_joint_velocities = self.arm.get_joint_velocities()
        gripper_positions = self.gripper.get_open_amount()
        gripper_velocities = self.gripper.get_joint_velocities()
        gripper_position_mean = sum(gripper_positions) / len(gripper_positions)
        gripper_velocity_mean = sum(gripper_velocities) / len(gripper_velocities)
        robot_state = arm_joint_positions
        robot_state.extend(arm_joint_velocities)
        robot_state.append(gripper_position_mean)
        robot_state.append(gripper_velocity_mean)
        return robot_state

    def get_object_position_relative_to_base_link(
        self, target_object: Object
    ) -> np.ndarray:
        """
        Get the position of the target object relative to the robot arm base link.
        Args:
            target_object (Object): target object

        Returns:
            (np.ndarray): position array containing x, y, z
        """
        return target_object.get_position(relative_to=self.arm_base_link)
--------------------------------------------------------------------------------
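For reference, a minimal sketch of driving the environment directly (without RLlib), e.g. to sanity-check a CoppeliaSim/PyRep install. It assumes the repo root is on `PYTHONPATH` and CoppeliaSim/PyRep are set up as in the README; the random-action rollout is purely illustrative and is not part of the repo.

```python
# Minimal sketch (not part of the repo): roll out one ReachEnv episode with random actions.
from envs import ReachEnv

env = ReachEnv(rendering=False)  # headless; set rendering=True to watch the simulator
obs = env.reset()
done = False
total_reward = 0.0
while not done:
    # 7-dim action: 6 joint target velocities + 1 gripper open/close signal
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    total_reward += reward
print("episode reward: {:.2f}, success: {}".format(total_reward, info["success"]))
env.close()
```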