├── .gitignore ├── LICENSE ├── README.md ├── data ├── images │ ├── myochallenge_2023.gif │ └── myochallenge_ranking.png └── myosuite │ └── assets │ ├── __init__.py │ ├── arm │ ├── __init__.py │ └── myo_elbow_1dof6muscles.mjb │ ├── finger │ ├── __init__.py │ └── myo_finger_v0.mjb │ └── hand │ ├── __init__.py │ ├── myo_hand_baoding.mjb │ ├── myo_hand_die.mjb │ ├── myo_hand_pen.mjb │ └── myo_hand_pose.mjb ├── docker ├── Dockerfile └── requirements.txt └── src ├── definitions.py ├── envs ├── __init__.py ├── baoding.py ├── environment_factory.py ├── pen.py ├── pose.py └── reorient.py ├── main_ant.py ├── main_baoding.py ├── main_half_cheetah.py ├── main_hopper.py ├── main_humanoid.py ├── main_pen.py ├── main_pose_elbow.py ├── main_pose_finger.py ├── main_pose_hand.py ├── main_reach_finger.py ├── main_reach_hand.py ├── main_reorient.py ├── main_walker.py ├── metrics ├── __init__.py └── custom_callbacks.py ├── models ├── __init__.py ├── distributions.py ├── ppo │ └── policies.py └── sac │ └── policies.py └── train ├── __init__.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Apple 132 | *DS_Store 133 | 134 | # PyCharm 135 | *.idea 136 | 137 | # Project 138 | /output 139 | 140 | # Tensorboard 141 | .monitor.csv 142 | 143 | # VS Code 144 | .vscode/launch.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mathis Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lattice (Latent Exploration for Reinforcement Learning) 2 | 3 | This repository includes the implementation of Lattice exploration from the paper [Latent Exploration for Reinforcement Learning](https://arxiv.org/abs/2305.20065), published at NeurIPS 2023. 4 | 5 | Lattice introduces random perturbations in the latent state of the policy network, which result in correlated noise across the system's actuators. This form of latent noise can facilitate exploration when controlling high-dimensional systems, especially with redundant actuation, and may find low-effort solutions. A short video explaining the project can be found on [YouTube](https://www.youtube.com/watch?v=_CCF3GDM9jY). 6 | 7 | Lattice builds on top of Stable Baselines 3 (version 1.6.1) and it is here implemented for Recurrent PPO and SAC. Integration with a more recent version of Stable Baselines 3 and compatibility with more algorithms is [currently under development](https://github.com/albertochiappa/stable-baselines3). 8 | 9 | This project was developed by Alberto Silvio Chiappa, Alessandro Marin Vargas, Ann Zixiang Huang and Alexander Mathis (EPFL). 
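As a rough intuition for why latent noise helps (this is an illustrative sketch added here, not the code in `src/models/`): if the last layer of the policy maps a latent vector `z` to actions via `a = W z`, then perturbing `z` with isotropic Gaussian noise of scale `sigma` produces action noise with covariance `sigma^2 W W^T`, which is in general non-diagonal, i.e. correlated across actuators. The dimensions and noise scale below are arbitrary placeholders:

```python
# Illustrative sketch of the idea behind Lattice (not the repository implementation).
import numpy as np

rng = np.random.default_rng(0)
latent_dim, num_actuators = 8, 4                   # placeholder sizes
W = rng.normal(size=(num_actuators, latent_dim))   # last linear layer of the policy
z = rng.normal(size=latent_dim)                    # latent state for the current observation
sigma = 0.1                                        # placeholder noise scale

eps = sigma * rng.normal(size=latent_dim)          # noise injected in the latent space
action = W @ (z + eps)                             # decoded action = mean + correlated noise

# Covariance of the induced action noise: sigma^2 * W @ W.T (generally non-diagonal),
# whereas independent per-actuator noise would have a diagonal covariance.
print(np.round(sigma**2 * W @ W.T, 3))
```

The full implementation lives in `src/models/distributions.py`, with the Recurrent PPO and SAC policies adapted to use it in `src/models/ppo/policies.py` and `src/models/sac/policies.py`.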
10 | 
11 | ## MyoChallenge 2023
12 | 
13 | We used Lattice to train the top submission of the NeurIPS 2023 competition [MyoChallenge](https://sites.google.com/view/myosuite/myochallenge/myochallenge-2023?authuser=0), object manipulation track. With curriculum learning, reward shaping and Lattice exploration, we trained a policy to control a biologically realistic arm with 63 muscles and 27 degrees of freedom to place random objects inside a box of variable shape:
14 | 
15 | ![relocate](/data/images/myochallenge_2023.gif)
16 | 
17 | We outperformed the other best solutions both in score and effort:
18 | 
19 | ![ranking](/data/images/myochallenge_ranking.png)
20 | 
21 | We have also created [a dedicated repository](https://github.com/amathislab/myochallenge_2023eval) for the solution, where we have released the pretrained weights of all the curriculum steps.
22 | 
23 | ## Installation
24 | 
25 | We recommend using a Docker container to execute the code of this repository. We provide both the docker image [albertochiappa/myo-cuda-pybullet](https://hub.docker.com/repository/docker/albertochiappa/myo-cuda-pybullet) on DockerHub and the Dockerfile in the [docker](/docker/) folder to create the same docker image locally.
26 | 
27 | If you prefer to manually create a Conda environment, you can do so with the commands:
28 | 
29 | ```bash
30 | conda create --name lattice python=3.8.10
31 | conda activate lattice
32 | pip install -r docker/requirements.txt
33 | pip install myosuite==1.2.4
34 | pip install --upgrade cloudpickle==2.2.0 pickle5==0.0.11 pybullet==3.2.5
35 | ```
36 | 
37 | Please note that there is a version conflict between some packages: e.g., `stable_baselines3` requires a later version of `gym` than the one `myosuite` is compatible with. For this reason we could not include all the requirements in `docker/requirements.txt`. In our experiments this incompatibility did not cause any errors.
38 | 
39 | ## Training a policy with Lattice
40 | 
41 | We provide scripts for various environments of the [MyoSuite](https://sites.google.com/view/myosuite) and [PyBullet](https://pybullet.org).
42 | 
43 | Training a policy is as easy as
44 | 
45 | ```bash
46 | python main_pose_elbow.py --use_lattice
47 | ```
48 | 
49 | if you have created a conda environment.
50 | 
51 | If you prefer to use the readily available docker container, you can train like this:
52 | 
53 | ```bash
54 | docker run --rm --gpus all -it \
55 |     --mount type=bind,src="$(pwd)/src",target=/src \
56 |     --mount type=bind,src="$(pwd)/data",target=/data \
57 |     --mount type=bind,src="$(pwd)/output",target=/output \
58 |     albertochiappa/myo-cuda-pybullet \
59 |     python3 src/main_pose_elbow.py --use_lattice
60 | ```
61 | 
62 | The previous command starts training in the `Elbow Pose` environment using Recurrent PPO. Simply change the main script name to start training for a different environment. The output of the training, including the hyperparameter configuration and the Tensorboard logs, is saved in a subfolder of `output/` named after the current date. The code logs useful training metrics in Tensorboard format; you can run Tensorboard on the output folder to visualize the learning curves and much more. The different configuration hyperparameters can be set from the command line, e.g., by running
63 | 
64 | ```bash
65 | python src/main_humanoid.py --use_sde --use_lattice --freq=8
66 | ```
67 | 
68 | In this case, a policy will be trained with SAC in the Humanoid environment, using state-dependent Lattice with update period 8.
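Once training has finished, the saved policy can be reloaded with Stable Baselines 3 for evaluation. The snippet below is a minimal sketch, not a script from this repository: the checkpoint name under `output/<date>/` depends on what the `Trainer` class saved in your run (the path is a placeholder), it assumes you run it from the `src/` folder so that importing the `envs` package registers the custom environments, and if training wrapped the environment with observation normalization, the saved statistics must be applied as well.

```python
# Hedged example: rolling out a trained recurrent policy in a custom MyoSuite env.
import numpy as np
from sb3_contrib import RecurrentPPO

from envs.environment_factory import EnvironmentFactory  # registers the custom envs

env = EnvironmentFactory.create("CustomMyoElbowPoseRandom")
model = RecurrentPPO.load("output/<date>/model.zip")  # placeholder checkpoint path

obs = env.reset()
lstm_states = None                          # recurrent state, reset at episode start
episode_starts = np.ones((1,), dtype=bool)
done, total_reward = False, 0.0
while not done:
    action, lstm_states = model.predict(
        obs, state=lstm_states, episode_start=episode_starts, deterministic=True
    )
    obs, reward, done, info = env.step(action)
    episode_starts = np.array([done])
    total_reward += reward
print("Episode reward:", total_reward)
```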
69 | 
70 | ## Structure of the repository
71 | 
72 | * src/
73 |   * envs/
74 |     * Modified MyoSuite environments, where we used different parameters from the defaults (cf. manuscript)
75 |     * Factory to instantiate all the environments used in the project
76 |   * metrics/
77 |     * Stable Baselines 3 callbacks to record useful information during training
78 |   * models/
79 |     * Implementation of Lattice
80 |     * Adaptation of SAC and PPO to use Lattice
81 |   * train/
82 |     * Trainer class used to manage the training runs
83 |   * main_*
84 |     * One main file per environment, to start a training run
85 | * data/
86 |   * Configuration files for the MyoSuite environments
87 | * docker/
88 |   * Definition of the Dockerfile to create the image used to run the experiments, with GPU support
89 | 
90 | ## Reference
91 | 
92 | If our work was useful to your research, please cite:
93 | 
94 | ```
95 | @article{chiappa2023latent,
96 |   title={Latent exploration for reinforcement learning},
97 |   author={Chiappa, Alberto Silvio and Vargas, Alessandro Marin and Huang, Ann Zixiang and Mathis, Alexander},
98 |   journal={Advances in Neural Information Processing Systems (NeurIPS)},
99 |   year={2023}
100 | }
101 | ```
102 | 
--------------------------------------------------------------------------------
/data/images/myochallenge_2023.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/images/myochallenge_2023.gif
--------------------------------------------------------------------------------
/data/images/myochallenge_ranking.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/images/myochallenge_ranking.png
--------------------------------------------------------------------------------
/data/myosuite/assets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/__init__.py
--------------------------------------------------------------------------------
/data/myosuite/assets/arm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/arm/__init__.py
--------------------------------------------------------------------------------
/data/myosuite/assets/arm/myo_elbow_1dof6muscles.mjb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/arm/myo_elbow_1dof6muscles.mjb
--------------------------------------------------------------------------------
/data/myosuite/assets/finger/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/finger/__init__.py
--------------------------------------------------------------------------------
/data/myosuite/assets/finger/myo_finger_v0.mjb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/finger/myo_finger_v0.mjb -------------------------------------------------------------------------------- /data/myosuite/assets/hand/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/__init__.py -------------------------------------------------------------------------------- /data/myosuite/assets/hand/myo_hand_baoding.mjb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/myo_hand_baoding.mjb -------------------------------------------------------------------------------- /data/myosuite/assets/hand/myo_hand_die.mjb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/myo_hand_die.mjb -------------------------------------------------------------------------------- /data/myosuite/assets/hand/myo_hand_pen.mjb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/myo_hand_pen.mjb -------------------------------------------------------------------------------- /data/myosuite/assets/hand/myo_hand_pose.mjb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/myo_hand_pose.mjb -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04 2 | 3 | RUN apt-get update -q \ 4 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 5 | libgl1-mesa-dev \ 6 | libgl1-mesa-glx \ 7 | libglew-dev \ 8 | libosmesa6-dev \ 9 | software-properties-common \ 10 | patchelf \ 11 | python3 \ 12 | python3-pip 13 | 14 | 15 | RUN apt-get install -y --no-install-recommends -o APT::Immediate-Configure=false gcc g++ git && \ 16 | apt-get clean && \ 17 | rm -rf /var/lib/apt/lists/* 18 | 19 | 20 | ENV PYTHONUNBUFFERED 1 21 | 22 | ADD requirements.txt / 23 | 24 | RUN pip install --upgrade pip 25 | RUN pip install -r requirements.txt 26 | RUN pip install myosuite==1.2.* 27 | RUN pip install --upgrade pickle5 cloudpickle pybullet 28 | RUN python3 -c "import mujoco_py" -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | grpcio==1.47.0 2 | grpcio-tools==1.47.0 3 | protobuf==3.19.5 4 | pyglet==1.5.26 5 | numpy==1.21.6 6 | git+https://github.com/aravindr93/mjrl.git 7 | matplotlib==3.5.3 8 | requests==2.25.1 9 | scikit-learn==1.0.2 10 | sk-video==1.1.10 11 | tabulate==0.8.10 12 | --find-links https://download.pytorch.org/whl/torch_stable.html 13 | torch==1.12.1+cu116 14 | stable_baselines3==1.6.1 15 | sb3-contrib==1.6.0 16 | pandas==1.3.5 17 | tensorboard==2.10.0 -------------------------------------------------------------------------------- /src/definitions.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))) 4 | -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import myosuite 4 | import numpy as np 5 | from definitions import ROOT_DIR # pylint: disable=import-error 6 | from myosuite.envs.myo import register_env_with_variants 7 | 8 | 9 | myosuite_path = os.path.join(ROOT_DIR, "data", "myosuite") 10 | 11 | # MyoChallenge Baoding: Phase1 env 12 | gym.envs.registration.register( 13 | id="CustomMyoChallengeBaodingP1-v1", 14 | entry_point="envs.baoding:CustomBaodingEnv", 15 | max_episode_steps=200, 16 | kwargs={ 17 | "model_path": myosuite_path + "/assets/hand/myo_hand_baoding.mjb", 18 | "normalize_act": True, 19 | # 'goal_time_period': (5, 5), 20 | "goal_xrange": (0.025, 0.025), 21 | "goal_yrange": (0.028, 0.028), 22 | }, 23 | ) 24 | 25 | # MyoChallenge Die: Phase2 env 26 | gym.envs.registration.register( 27 | id="CustomMyoChallengeDieReorientP2-v0", 28 | entry_point="envs.reorient:CustomReorientEnv", 29 | max_episode_steps=150, 30 | kwargs={ 31 | "model_path": myosuite_path + "/assets/hand/myo_hand_die.mjb", 32 | "normalize_act": True, 33 | "frame_skip": 5, 34 | # Randomization in goals 35 | 'goal_pos': (-.020, .020), # +- 2 cm 36 | 'goal_rot': (-3.14, 3.14), # +-180 degrees 37 | # Randomization in physical properties of the die 38 | 'obj_size_change': 0.007, # +-7mm delta change in object size 39 | 'obj_friction_change': (0.2, 0.001, 0.00002) # nominal: 1.0, 0.005, 0.0001 40 | }, 41 | ) 42 | 43 | register_env_with_variants(id='CustomMyoElbowPoseRandom-v0', 44 | entry_point='envs.pose:CustomPoseEnv', 45 | max_episode_steps=100, 46 | kwargs={ 47 | 'model_path': myosuite_path+'/assets/arm/myo_elbow_1dof6muscles.mjb', 48 | 'target_jnt_range': {'r_elbow_flex':(0, 2.27),}, 49 | 'viz_site_targets': ('wrist',), 50 | 'normalize_act': True, 51 | 'pose_thd': .175, 52 | 'reset_type': 'random' 53 | } 54 | ) 55 | 56 | register_env_with_variants(id='CustomMyoFingerPoseRandom-v0', 57 | entry_point='envs.pose:CustomPoseEnv', 58 | max_episode_steps=100, 59 | kwargs={ 60 | 'model_path': myosuite_path + '/assets/finger/myo_finger_v0.mjb', 61 | 'target_jnt_range': {'IFadb':(-.2, .2), 62 | 'IFmcp':(-.4, 1), 63 | 'IFpip':(.1, 1), 64 | 'IFdip':(.1, 1) 65 | }, 66 | 'viz_site_targets': ('IFtip',), 67 | 'normalize_act': True, 68 | } 69 | ) 70 | 71 | # Hand-Joint posing ============================== 72 | # Create ASL envs ============================== 73 | jnt_namesHand=['pro_sup', 'deviation', 'flexion', 'cmc_abduction', 'cmc_flexion', 'mp_flexion', 'ip_flexion', 'mcp2_flexion', 'mcp2_abduction', 'pm2_flexion', 'md2_flexion', 'mcp3_flexion', 'mcp3_abduction', 'pm3_flexion', 'md3_flexion', 'mcp4_flexion', 'mcp4_abduction', 'pm4_flexion', 'md4_flexion', 'mcp5_flexion', 'mcp5_abduction', 'pm5_flexion', 'md5_flexion'] 74 | 75 | ASL_qpos={} 76 | ASL_qpos[0]='0 0 0 0.5624 0.28272 -0.75573 -1.309 1.30045 -0.006982 1.45492 0.998897 1.26466 0 1.40604 0.227795 1.07614 -0.020944 1.46103 0.06284 0.83263 -0.14399 1.571 1.38248'.split(' ') 77 | ASL_qpos[1]='0 0 0 0.0248 0.04536 -0.7854 -1.309 0.366605 0.010473 0.269258 0.111722 1.48459 0 1.45318 1.44532 1.44532 -0.204204 1.46103 1.44532 1.48459 -0.2618 1.47674 1.48459'.split(' ') 78 | ASL_qpos[2]='0 0 0 0.0248 0.04536 -0.7854 -1.13447 
0.514973 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 1.44532 -0.204204 1.46103 1.44532 1.48459 -0.2618 1.47674 1.48459'.split(' ') 79 | ASL_qpos[3]='0 0 0 0.3384 0.25305 0.01569 -0.0262045 0.645885 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 1.571 -0.036652 1.52387 1.45318 1.40604 -0.068068 1.39033 1.571'.split(' ') 80 | ASL_qpos[4]='0 0 0 0.6392 -0.147495 -0.7854 -1.309 0.637158 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 0.306345 -0.010472 0.400605 0.133535 0.21994 -0.068068 0.274925 0.01571'.split(' ') 81 | ASL_qpos[5]='0 0 0 0.3384 0.25305 0.01569 -0.0262045 0.645885 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 0.306345 -0.010472 0.400605 0.133535 0.21994 -0.068068 0.274925 0.01571'.split(' ') 82 | ASL_qpos[6]='0 0 0 0.6392 -0.147495 -0.7854 -1.309 0.637158 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 0.306345 -0.010472 0.400605 0.133535 1.1861 -0.2618 1.35891 1.48459'.split(' ') 83 | ASL_qpos[7]='0 0 0 0.524 0.01569 -0.7854 -1.309 0.645885 -0.006982 0.128305 0.111722 0.510575 0 0.37704 0.117825 1.28036 -0.115192 1.52387 1.45318 0.432025 -0.068068 0.18852 0.149245'.split(' ') 84 | ASL_qpos[8]='0 0 0 0.428 0.22338 -0.7854 -1.309 0.645885 -0.006982 0.128305 0.194636 1.39033 0 1.08399 0.573415 0.667675 -0.020944 0 0.06284 0.432025 -0.068068 0.18852 0.149245'.split(' ') 85 | ASL_qpos[9]='0 0 0 0.5624 0.28272 -0.75573 -1.309 1.30045 -0.006982 1.45492 0.998897 0.39275 0 0.18852 0.227795 0.667675 -0.020944 0 0.06284 0.432025 -0.068068 0.18852 0.149245'.split(' ') 86 | 87 | # ASl Eval envs for each numerals 88 | for k in ASL_qpos.keys(): 89 | register_env_with_variants(id='CustomMyoHandPose'+str(k)+'Fixed-v0', 90 | entry_point='envs.pose:CustomPoseEnv', 91 | max_episode_steps=100, 92 | kwargs={ 93 | 'model_path': myosuite_path + '/assets/hand/myo_hand_pose.mjb', 94 | 'viz_site_targets': ('THtip','IFtip','MFtip','RFtip','LFtip'), 95 | 'target_jnt_value': np.array(ASL_qpos[k],'float'), 96 | 'normalize_act': True, 97 | 'pose_thd': .7, 98 | 'reset_type': "init", # none, init, random 99 | 'target_type': 'fixed', # generate/ fixed 100 | } 101 | ) 102 | 103 | # ASL Train Env 104 | m = np.array([ASL_qpos[i] for i in range(10)]).astype(float) 105 | Rpos = {} 106 | for i_n, n in enumerate(jnt_namesHand): 107 | Rpos[n]=(np.min(m[:,i_n]), np.max(m[:,i_n])) 108 | 109 | register_env_with_variants(id='CustomMyoHandPoseRandom-v0', #reconsider 110 | entry_point='envs.pose:CustomPoseEnv', 111 | max_episode_steps=100, 112 | kwargs={ 113 | 'model_path': myosuite_path + '/assets/hand/myo_hand_pose.mjb', 114 | 'viz_site_targets': ('THtip','IFtip','MFtip','RFtip','LFtip'), 115 | 'target_jnt_range': Rpos, 116 | 'normalize_act': True, 117 | 'pose_thd': .8, 118 | 'reset_type': "random", # none, init, random 119 | 'target_type': 'generate', # generate/ fixed 120 | } 121 | ) 122 | 123 | 124 | # Pen twirl 125 | register_env_with_variants(id='CustomMyoHandPenTwirlRandom-v0', 126 | entry_point='envs.pen:CustomPenEnv', 127 | max_episode_steps=100, 128 | kwargs={ 129 | 'model_path': myosuite_path + '/assets/hand/myo_hand_pen.mjb', 130 | 'normalize_act': True, 131 | 'frame_skip': 5, 132 | } 133 | ) -------------------------------------------------------------------------------- /src/envs/baoding.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=attribute-defined-outside-init, dangerous-default-value, protected-access, abstract-method, arguments-renamed, import-error 2 | import collections 3 | import random 4 | import 
numpy as np 5 | from myosuite.envs.myo.base_v0 import BaseV0 6 | from myosuite.envs.myo.myochallenge.baoding_v1 import WHICH_TASK, BaodingEnvV1, Task 7 | from sb3_contrib import RecurrentPPO 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv 10 | from envs.environment_factory import EnvironmentFactory 11 | 12 | 13 | class CustomBaodingEnv(BaodingEnvV1): 14 | DEFAULT_RWD_KEYS_AND_WEIGHTS = { 15 | "pos_dist_1": 5.0, 16 | "pos_dist_2": 5.0, 17 | "alive": 0.0, 18 | "act_reg": 0.0, 19 | # "palm_up": 0.0, 20 | } 21 | 22 | def get_reward_dict(self, obs_dict): 23 | # tracking error 24 | target1_dist = np.linalg.norm(obs_dict["target1_err"], axis=-1) 25 | target2_dist = np.linalg.norm(obs_dict["target2_err"], axis=-1) 26 | target_dist = target1_dist + target2_dist 27 | act_mag = ( 28 | np.linalg.norm(self.obs_dict["act"], axis=-1) / self.sim.model.na 29 | if self.sim.model.na != 0 30 | else 0 31 | ) 32 | 33 | # detect fall 34 | object1_pos = ( 35 | obs_dict["object1_pos"][:, :, 2] 36 | if obs_dict["object1_pos"].ndim == 3 37 | else obs_dict["object1_pos"][2] 38 | ) 39 | object2_pos = ( 40 | obs_dict["object2_pos"][:, :, 2] 41 | if obs_dict["object2_pos"].ndim == 3 42 | else obs_dict["object2_pos"][2] 43 | ) 44 | is_fall_1 = object1_pos < self.drop_th 45 | is_fall_2 = object2_pos < self.drop_th 46 | is_fall = np.logical_or(is_fall_1, is_fall_2) # keep both balls up 47 | 48 | rwd_dict = collections.OrderedDict( 49 | ( 50 | # Perform reward tuning here -- 51 | # Update Optional Keys section below 52 | # Update reward keys (DEFAULT_RWD_KEYS_AND_WEIGHTS) to update final rewards 53 | # Examples: Env comes pre-packaged with two keys pos_dist_1 and pos_dist_2 54 | # Optional Keys 55 | ("pos_dist_1", -1.0 * target1_dist), 56 | ("pos_dist_2", -1.0 * target2_dist), 57 | ("alive", ~is_fall), 58 | # ("palm_up", palm_up_reward), 59 | # Must keys 60 | ("act_reg", -1.0 * act_mag), 61 | ("sparse", -target_dist), 62 | ( 63 | "solved", 64 | (target1_dist < self.proximity_th) 65 | * (target2_dist < self.proximity_th) 66 | * (~is_fall), 67 | ), 68 | ("done", is_fall), 69 | ) 70 | ) 71 | rwd_dict["dense"] = np.sum( 72 | [wt * rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0 73 | ) 74 | 75 | # Sucess Indicator 76 | self.sim.model.geom_rgba[self.object1_gid, :2] = ( 77 | np.array([1, 1]) 78 | if target1_dist < self.proximity_th 79 | else np.array([0.5, 0.5]) 80 | ) 81 | self.sim.model.geom_rgba[self.object2_gid, :2] = ( 82 | np.array([0.9, 0.7]) 83 | if target1_dist < self.proximity_th 84 | else np.array([0.5, 0.5]) 85 | ) 86 | 87 | return rwd_dict 88 | 89 | def _add_noise_to_palm_position( 90 | self, qpos: np.ndarray, noise: float = 1 91 | ) -> np.ndarray: 92 | assert 0 <= noise <= 1, "Noise must be between 0 and 1" 93 | 94 | # pronation-supination of the wrist 95 | # noise = 1 corresponds to 10 degrees from facing up (one direction only) 96 | qpos[0] = self.np_random.uniform( 97 | low=-np.pi / 2, high=-np.pi / 2 + np.pi / 18 * noise 98 | ) 99 | 100 | # ulnar deviation of wrist: 101 | # noise = 1 corresponds to 10 degrees on either side 102 | qpos[1] = self.np_random.uniform( 103 | low=-np.pi / 18 * noise, high=np.pi / 18 * noise 104 | ) 105 | 106 | # extension flexion of the wrist 107 | # noise = 1 corresponds to 10 degrees on either side 108 | qpos[2] = self.np_random.uniform( 109 | low=-np.pi / 18 * noise, high=np.pi / 18 * noise 110 | ) 111 | 112 | return qpos 113 | 114 | def _add_noise_to_finger_positions( 115 | self, 
qpos: np.ndarray, noise: float = 1 116 | ) -> np.ndarray: 117 | assert 0 <= noise <= 1, "Noise parameter must be between 0 and 1" 118 | 119 | # thumb all joints 120 | # noise = 1 corresponds to 10 degrees on either side 121 | qpos[3:7] = self.np_random.uniform( 122 | low=-np.pi / 18 * noise, high=np.pi / 18 * noise 123 | ) 124 | 125 | # finger joints 126 | # noise = 1 corresponds to 30 degrees bent instead of fully open 127 | qpos[[7, 9, 10, 11, 13, 14, 15, 17, 18, 19, 21, 22]] = self.np_random.uniform( 128 | low=0, high=np.pi / 6 * noise 129 | ) 130 | 131 | # finger abduction (sideways angle) 132 | # noise = 1 corresponds to 5 degrees on either side 133 | qpos[[8, 12, 16, 20]] = self.np_random.uniform( 134 | low=-np.pi / 36 * noise, high=np.pi / 36 * noise 135 | ) 136 | 137 | return qpos 138 | 139 | def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None): 140 | self.which_task = self.sample_task() 141 | if self.rsi: 142 | # MODIFICATION: randomize starting target position along the cycle 143 | random_phase = np.random.uniform(low=-np.pi, high=np.pi) 144 | else: 145 | random_phase = 0 146 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0 + random_phase 147 | self.ball_2_starting_angle = -1.0 * np.pi / 4.0 + random_phase 148 | 149 | # reset counters 150 | self.counter = 0 151 | self.x_radius = self.np_random.uniform( 152 | low=self.goal_xrange[0], high=self.goal_xrange[1] 153 | ) 154 | self.y_radius = self.np_random.uniform( 155 | low=self.goal_yrange[0], high=self.goal_yrange[1] 156 | ) 157 | 158 | # reset goal 159 | if time_period is None: 160 | time_period = self.np_random.uniform( 161 | low=self.goal_time_period[0], high=self.goal_time_period[1] 162 | ) 163 | self.goal = ( 164 | self.create_goal_trajectory(time_step=self.dt, time_period=time_period) 165 | if reset_goal is None 166 | else reset_goal.copy() 167 | ) 168 | 169 | # reset scene (MODIFIED from base class MujocoEnv) 170 | qpos = self.init_qpos.copy() if reset_pose is None else reset_pose 171 | qvel = self.init_qvel.copy() if reset_vel is None else reset_vel 172 | self.robot.reset(qpos, qvel) 173 | 174 | if self.rsi: 175 | if np.random.uniform(0, 1) < self.rsi_probability: 176 | self.step(np.zeros(39)) 177 | 178 | # update ball positions 179 | obs = self.get_obs().copy() 180 | qpos[23] = obs[35] # ball 1 x-position 181 | qpos[24] = obs[36] # ball 1 y-position 182 | qpos[30] = obs[38] # ball 2 x-position 183 | qpos[31] = obs[39] # ball 2 y-position 184 | 185 | if self.noise_balls: 186 | # update balls x,y,z positions with relative noise 187 | for i in [23, 24, 25, 30, 31, 32]: 188 | qpos[i] += np.random.uniform( 189 | low=-self.noise_balls, high=self.noise_balls 190 | ) 191 | 192 | if self.noise_palm: 193 | qpos = self._add_noise_to_palm_position(qpos, self.noise_palm) 194 | 195 | if self.noise_fingers: 196 | qpos = self._add_noise_to_finger_positions(qpos, self.noise_fingers) 197 | 198 | if self.rsi or self.noise_palm or self.noise_fingers or self.noise_balls: 199 | self.set_state(qpos, qvel) 200 | 201 | return self.get_obs() 202 | 203 | def _setup( 204 | self, 205 | frame_skip: int = 10, 206 | drop_th=1.25, # drop height threshold 207 | proximity_th=0.015, # object-target proximity threshold 208 | goal_time_period=(5, 5), # target rotation time period 209 | goal_xrange=(0.025, 0.025), # target rotation: x radius (0.03) 210 | goal_yrange=(0.028, 0.028), # target rotation: x radius (0.02 * 1.5 * 1.2) 211 | obs_keys: list = BaodingEnvV1.DEFAULT_OBS_KEYS, 212 | weighted_reward_keys: list = 
DEFAULT_RWD_KEYS_AND_WEIGHTS, 213 | task=None, 214 | enable_rsi=False, # random state init for balls 215 | noise_palm=0, # magnitude of noise for palm (between 0 and 1) 216 | noise_fingers=0, # magnitude of noise for fingers (between 0 and 1) 217 | noise_balls=0, # relative magnitude of noise for the balls (1 is 100% relative noise) 218 | rsi_probability=1, # probability of implementing RSI 219 | **kwargs, 220 | ): 221 | 222 | # user parameters 223 | self.task = task 224 | self.which_task = self.sample_task() 225 | self.rsi = enable_rsi 226 | self.noise_palm = noise_palm 227 | self.noise_fingers = noise_fingers 228 | self.drop_th = drop_th 229 | self.proximity_th = proximity_th 230 | self.goal_time_period = goal_time_period 231 | self.goal_xrange = goal_xrange 232 | self.goal_yrange = goal_yrange 233 | self.noise_balls = noise_balls 234 | self.rsi_probability = rsi_probability 235 | 236 | # balls start at these angles 237 | # 1= yellow = top right 238 | # 2= pink = bottom left 239 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0 240 | self.ball_2_starting_angle = -1.0 * np.pi / 4.0 241 | 242 | # init desired trajectory, for rotations 243 | self.center_pos = [-0.0125, -0.07] # [-.0020, -.0522] 244 | self.x_radius = self.np_random.uniform( 245 | low=self.goal_xrange[0], high=self.goal_xrange[1] 246 | ) 247 | self.y_radius = self.np_random.uniform( 248 | low=self.goal_yrange[0], high=self.goal_yrange[1] 249 | ) 250 | 251 | self.counter = 0 252 | self.goal = self.create_goal_trajectory( 253 | time_step=frame_skip * self.sim.model.opt.timestep, time_period=6 254 | ) 255 | 256 | # init target and body sites 257 | self.object1_sid = self.sim.model.site_name2id("ball1_site") 258 | self.object2_sid = self.sim.model.site_name2id("ball2_site") 259 | self.object1_gid = self.sim.model.geom_name2id("ball1") 260 | self.object2_gid = self.sim.model.geom_name2id("ball2") 261 | self.target1_sid = self.sim.model.site_name2id("target1_site") 262 | self.target2_sid = self.sim.model.site_name2id("target2_site") 263 | self.sim.model.site_group[self.target1_sid] = 2 264 | self.sim.model.site_group[self.target2_sid] = 2 265 | 266 | BaseV0._setup( 267 | self, 268 | obs_keys=obs_keys, 269 | weighted_reward_keys=weighted_reward_keys, 270 | frame_skip=frame_skip, 271 | **kwargs, 272 | ) 273 | 274 | # reset position 275 | self.init_qpos[:-14] *= 0 # Use fully open as init pos 276 | self.init_qpos[0] = -1.57 # Palm up 277 | 278 | def sample_task(self): 279 | if self.task is None: 280 | return Task(WHICH_TASK) 281 | else: 282 | if self.task == "cw": 283 | return Task(Task.BAODING_CW) 284 | elif self.task == "ccw": 285 | return Task(Task.BAODING_CCW) 286 | elif self.task == "random": 287 | return Task(random.choice(list(Task))) 288 | else: 289 | raise ValueError("Unknown task for baoding: ", self.task) 290 | 291 | def step(self, action): 292 | obs, reward, done, info = super().step(action) 293 | info.update(info.get("rwd_dict")) 294 | return obs, reward, done, info 295 | 296 | 297 | class CustomBaodingP2Env(BaodingEnvV1): 298 | def _setup( 299 | self, 300 | frame_skip: int = 10, 301 | drop_th=1.25, # drop height threshold 302 | proximity_th=0.015, # object-target proximity threshold 303 | goal_time_period=(5, 5), # target rotation time period 304 | goal_xrange=(0.025, 0.025), # target rotation: x radius (0.03) 305 | goal_yrange=(0.028, 0.028), # target rotation: x radius (0.02 * 1.5 * 1.2) 306 | obj_size_range=(0.018, 0.024), # Object size range. Nominal 0.022 307 | obj_mass_range=(0.030, 0.300), # Object weight range. 
Nominal 43 gms 308 | obj_friction_change=(0.2, 0.001, 0.00002), 309 | task_choice="fixed", # fixed/ random 310 | obs_keys: list = BaodingEnvV1.DEFAULT_OBS_KEYS, 311 | weighted_reward_keys: list = BaodingEnvV1.DEFAULT_RWD_KEYS_AND_WEIGHTS, 312 | enable_rsi=False, # random state init for balls 313 | rsi_probability=1, # probability of implementing RSI 314 | balls_overlap=False, 315 | overlap_probability=0, 316 | limit_init_angle=None, 317 | beta_init_angle=None, 318 | beta_ball_size=None, 319 | beta_ball_mass=None, 320 | noise_fingers=0, 321 | **kwargs, 322 | ): 323 | # user parameters 324 | self.task_choice = task_choice 325 | self.which_task = ( 326 | self.np_random.choice(Task) if task_choice == "random" else Task(WHICH_TASK) 327 | ) 328 | self.drop_th = drop_th 329 | self.proximity_th = proximity_th 330 | self.goal_time_period = goal_time_period 331 | self.goal_xrange = goal_xrange 332 | self.goal_yrange = goal_yrange 333 | self.rsi = enable_rsi 334 | self.rsi_probability = rsi_probability 335 | self.balls_overlap = balls_overlap 336 | self.overlap_probability = overlap_probability 337 | self.noise_fingers = noise_fingers 338 | self.limit_init_angle = limit_init_angle 339 | self.beta_init_angle = beta_init_angle 340 | self.beta_ball_size = beta_ball_size 341 | self.beta_ball_mass = beta_ball_mass 342 | 343 | # balls start at these angles 344 | # 1= yellow = top right 345 | # 2= pink = bottom left 346 | 347 | if np.random.uniform(0, 1) < self.overlap_probability: 348 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0 349 | self.ball_2_starting_angle = -1.0 * np.pi / 4.0 350 | else: 351 | self.ball_1_starting_angle = 1.0 * np.pi / 4.0 352 | self.ball_2_starting_angle = self.ball_1_starting_angle - np.pi 353 | 354 | # init desired trajectory, for rotations 355 | self.center_pos = [-0.0125, -0.07] # [-.0020, -.0522] 356 | self.x_radius = self.np_random.uniform( 357 | low=self.goal_xrange[0], high=self.goal_xrange[1] 358 | ) 359 | self.y_radius = self.np_random.uniform( 360 | low=self.goal_yrange[0], high=self.goal_yrange[1] 361 | ) 362 | 363 | self.counter = 0 364 | self.goal = self.create_goal_trajectory( 365 | time_step=frame_skip * self.sim.model.opt.timestep, time_period=6 366 | ) 367 | 368 | # init target and body sites 369 | self.object1_bid = self.sim.model.body_name2id("ball1") 370 | self.object2_bid = self.sim.model.body_name2id("ball2") 371 | self.object1_sid = self.sim.model.site_name2id("ball1_site") 372 | self.object2_sid = self.sim.model.site_name2id("ball2_site") 373 | self.object1_gid = self.sim.model.geom_name2id("ball1") 374 | self.object2_gid = self.sim.model.geom_name2id("ball2") 375 | self.target1_sid = self.sim.model.site_name2id("target1_site") 376 | self.target2_sid = self.sim.model.site_name2id("target2_site") 377 | self.sim.model.site_group[self.target1_sid] = 2 378 | self.sim.model.site_group[self.target2_sid] = 2 379 | 380 | # setup for task randomization 381 | self.obj_mass_range = {"low": obj_mass_range[0], "high": obj_mass_range[1]} 382 | self.obj_size_range = {"low": obj_size_range[0], "high": obj_size_range[1]} 383 | self.obj_friction_range = { 384 | "low": self.sim.model.geom_friction[self.object1_gid] - obj_friction_change, 385 | "high": self.sim.model.geom_friction[self.object1_gid] 386 | + obj_friction_change, 387 | } 388 | 389 | BaseV0._setup( 390 | self, 391 | obs_keys=obs_keys, 392 | weighted_reward_keys=weighted_reward_keys, 393 | frame_skip=frame_skip, 394 | **kwargs, 395 | ) 396 | 397 | # reset position 398 | self.init_qpos[:-14] *= 0 # Use fully open 
as init pos 399 | self.init_qpos[0] = -1.57 # Palm up 400 | 401 | def get_reward_dict(self, obs_dict): 402 | # tracking error 403 | target1_dist = np.linalg.norm(obs_dict["target1_err"], axis=-1) 404 | target2_dist = np.linalg.norm(obs_dict["target2_err"], axis=-1) 405 | target_dist = target1_dist + target2_dist 406 | act_mag = ( 407 | np.linalg.norm(self.obs_dict["act"], axis=-1) / self.sim.model.na 408 | if self.sim.model.na != 0 409 | else 0 410 | ) 411 | 412 | # detect fall 413 | object1_pos = ( 414 | obs_dict["object1_pos"][:, :, 2] 415 | if obs_dict["object1_pos"].ndim == 3 416 | else obs_dict["object1_pos"][2] 417 | ) 418 | object2_pos = ( 419 | obs_dict["object2_pos"][:, :, 2] 420 | if obs_dict["object2_pos"].ndim == 3 421 | else obs_dict["object2_pos"][2] 422 | ) 423 | is_fall_1 = object1_pos < self.drop_th 424 | is_fall_2 = object2_pos < self.drop_th 425 | is_fall = np.logical_or(is_fall_1, is_fall_2) # keep both balls up 426 | 427 | rwd_dict = collections.OrderedDict( 428 | ( 429 | # Perform reward tuning here -- 430 | # Update Optional Keys section below 431 | # Update reward keys (DEFAULT_RWD_KEYS_AND_WEIGHTS) accordingly to update final rewards 432 | # Examples: Env comes pre-packaged with two keys pos_dist_1 and pos_dist_2 433 | # Optional Keys 434 | ("pos_dist_1", -1.0 * target1_dist), 435 | ("pos_dist_2", -1.0 * target2_dist), 436 | # Must keys 437 | ("act_reg", -1.0 * act_mag), 438 | ("alive", ~is_fall), 439 | ("sparse", -target_dist), 440 | ( 441 | "solved", 442 | (target1_dist < self.proximity_th) 443 | * (target2_dist < self.proximity_th) 444 | * (~is_fall), 445 | ), 446 | ("done", is_fall), 447 | ) 448 | ) 449 | rwd_dict["dense"] = np.sum( 450 | [wt * rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0 451 | ) 452 | 453 | # Sucess Indicator 454 | self.sim.model.geom_rgba[self.object1_gid, :2] = ( 455 | np.array([1, 1]) 456 | if target1_dist < self.proximity_th 457 | else np.array([0.5, 0.5]) 458 | ) 459 | self.sim.model.geom_rgba[self.object2_gid, :2] = ( 460 | np.array([0.9, 0.7]) 461 | if target1_dist < self.proximity_th 462 | else np.array([0.5, 0.5]) 463 | ) 464 | 465 | return rwd_dict 466 | 467 | def _add_noise_to_finger_positions( 468 | self, qpos: np.ndarray, noise: float = 1 469 | ) -> np.ndarray: 470 | assert 0 <= noise <= 1, "Noise parameter must be between 0 and 1" 471 | 472 | # thumb all joints 473 | qpos[4:7] = self.np_random.uniform( 474 | low=-np.pi / 18 * noise, high=np.pi / 18 * noise 475 | ) 476 | 477 | # finger joints 478 | qpos[[7, 9, 10, 11, 13, 14, 15, 17, 18, 19, 21, 22]] = self.np_random.uniform( 479 | low=0, high=np.pi / 6 * noise 480 | ) 481 | 482 | return qpos 483 | 484 | def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None): 485 | 486 | # reset task 487 | if self.task_choice == "random": 488 | self.which_task = self.np_random.choice(Task) 489 | 490 | if np.random.uniform(0, 1) <= self.overlap_probability: 491 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0 492 | elif self.limit_init_angle is not None: 493 | random_phase = self.np_random.uniform( 494 | low=-self.limit_init_angle, high=self.limit_init_angle 495 | ) 496 | 497 | if self.beta_init_angle is not None: 498 | random_phase = ( 499 | self.np_random.beta( 500 | self.beta_init_angle[0], self.beta_init_angle[1] 501 | ) 502 | * 2 503 | * np.pi 504 | - np.pi 505 | ) 506 | 507 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0 + random_phase 508 | else: 509 | self.ball_1_starting_angle = self.np_random.uniform( 510 | low=0, high=2 * np.pi 511 | ) 512 
| 513 | self.ball_2_starting_angle = self.ball_1_starting_angle - np.pi 514 | # reset counters 515 | self.counter = 0 516 | self.x_radius = self.np_random.uniform( 517 | low=self.goal_xrange[0], high=self.goal_xrange[1] 518 | ) 519 | self.y_radius = self.np_random.uniform( 520 | low=self.goal_yrange[0], high=self.goal_yrange[1] 521 | ) 522 | 523 | # reset goal 524 | if time_period is None: 525 | time_period = self.np_random.uniform( 526 | low=self.goal_time_period[0], high=self.goal_time_period[1] 527 | ) 528 | self.goal = ( 529 | self.create_goal_trajectory(time_step=self.dt, time_period=time_period) 530 | if reset_goal is None 531 | else reset_goal.copy() 532 | ) 533 | 534 | # balls mass changes 535 | self.sim.model.body_mass[self.object1_bid] = self.np_random.uniform( 536 | **self.obj_mass_range 537 | ) # call to mj_setConst(m,d) is being ignored. Derive quantities wont be updated. Die is simple shape. So this is reasonable approximation. 538 | self.sim.model.body_mass[self.object2_bid] = self.np_random.uniform( 539 | **self.obj_mass_range 540 | ) # call to mj_setConst(m,d) is being ignored. Derive quantities wont be updated. Die is simple shape. So this is reasonable approximation. 541 | 542 | if self.beta_ball_mass is not None: 543 | self.sim.model.body_mass[self.object1_bid] = ( 544 | self.np_random.beta(self.beta_ball_mass[0], self.beta_ball_mass[1]) 545 | * (self.obj_mass_range["high"] - self.obj_mass_range["low"]) 546 | + self.obj_mass_range["low"] 547 | ) 548 | self.sim.model.body_mass[self.object2_bid] = ( 549 | self.np_random.beta(self.beta_ball_mass[0], self.beta_ball_mass[1]) 550 | * (self.obj_mass_range["high"] - self.obj_mass_range["low"]) 551 | + self.obj_mass_range["low"] 552 | ) 553 | # balls friction changes 554 | self.sim.model.geom_friction[self.object1_gid] = self.np_random.uniform( 555 | **self.obj_friction_range 556 | ) 557 | self.sim.model.geom_friction[self.object2_gid] = self.np_random.uniform( 558 | **self.obj_friction_range 559 | ) 560 | 561 | # balls size changes 562 | self.sim.model.geom_size[self.object1_gid] = self.np_random.uniform( 563 | **self.obj_size_range 564 | ) 565 | self.sim.model.geom_size[self.object2_gid] = self.np_random.uniform( 566 | **self.obj_size_range 567 | ) 568 | 569 | if self.beta_ball_size is not None: 570 | self.sim.model.geom_size[self.object1_gid] = ( 571 | self.np_random.beta(self.beta_ball_size[0], self.beta_ball_size[1]) 572 | * (self.obj_size_range["high"] - self.obj_size_range["low"]) 573 | + self.obj_size_range["low"] 574 | ) 575 | self.sim.model.geom_size[self.object2_gid] = ( 576 | self.np_random.beta(self.beta_ball_size[0], self.beta_ball_size[1]) 577 | * (self.obj_size_range["high"] - self.obj_size_range["low"]) 578 | + self.obj_size_range["low"] 579 | ) 580 | # reset scene 581 | qpos = self.init_qpos.copy() if reset_pose is None else reset_pose 582 | qvel = self.init_qvel.copy() if reset_vel is None else reset_vel 583 | self.robot.reset(qpos, qvel) 584 | 585 | if self.rsi and np.random.uniform(0, 1) < self.rsi_probability: 586 | random_phase = np.random.uniform(low=-np.pi, high=np.pi) 587 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0 + random_phase 588 | self.ball_2_starting_angle = -1.0 * np.pi / 4.0 + random_phase 589 | 590 | # # reset scene (MODIFIED from base class MujocoEnv) 591 | self.robot.reset(qpos, qvel) 592 | self.step(np.zeros(39)) 593 | # update ball positions 594 | obs_dict = self.get_obs_dict(self.sim) 595 | target_1_pos = obs_dict["target1_pos"] 596 | target_2_pos = obs_dict["target2_pos"] 597 | 
qpos[23] = target_1_pos[0] # ball 1 x-position 598 | qpos[24] = target_1_pos[1] # ball 1 y-position 599 | qpos[30] = target_2_pos[0] # ball 2 x-position 600 | qpos[31] = target_2_pos[1] # ball 2 y-position 601 | self.set_state(qpos, qvel) 602 | 603 | if self.balls_overlap is False: 604 | self.ball_1_starting_angle = self.np_random.uniform( 605 | low=0, high=2 * np.pi 606 | ) 607 | self.ball_2_starting_angle = self.ball_1_starting_angle - np.pi 608 | 609 | if self.noise_fingers is not None: 610 | qpos = self._add_noise_to_finger_positions(qpos, self.noise_fingers) 611 | self.set_state(qpos, qvel) 612 | 613 | return self.get_obs() 614 | 615 | def step(self, action): 616 | obs, reward, done, info = super().step(action) 617 | info.update(info.get("rwd_dict")) 618 | return obs, reward, done, info 619 | -------------------------------------------------------------------------------- /src/envs/environment_factory.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pybullet_envs 3 | 4 | 5 | class EnvironmentFactory: 6 | """Static factory to instantiate and register gym environments by name.""" 7 | 8 | @staticmethod 9 | def create(env_name, **kwargs): 10 | """Creates an environment given its name as a string, and forwards the kwargs 11 | to its __init__ function. 12 | 13 | Args: 14 | env_name (str): name of the environment 15 | 16 | Raises: 17 | ValueError: if the name of the environment is unknown 18 | 19 | Returns: 20 | gym.env: the selected environment 21 | """ 22 | # make myosuite envs 23 | if env_name == "MyoFingerPoseRandom": 24 | return gym.make("myoFingerPoseRandom-v0") 25 | elif env_name == "MyoFingerReachRandom": 26 | return gym.make("myoFingerReachRandom-v0") 27 | elif env_name == "MyoHandReachRandom": 28 | return gym.make("myoHandReachRandom-v0") 29 | elif env_name == "MyoElbowReachRandom": 30 | return gym.make("myoElbowReachRandom-v0") 31 | elif env_name == "CustomMyoBaodingBallsP1": 32 | return gym.make("CustomMyoChallengeBaodingP1-v1", **kwargs) 33 | elif env_name == "CustomMyoReorientP2": 34 | return gym.make("CustomMyoChallengeDieReorientP2-v0", **kwargs) 35 | elif env_name == "CustomMyoElbowPoseRandom": 36 | return gym.make("CustomMyoElbowPoseRandom-v0", **kwargs) 37 | elif env_name == "CustomMyoFingerPoseRandom": 38 | return gym.make("CustomMyoFingerPoseRandom-v0", **kwargs) 39 | elif env_name == "CustomMyoHandPoseRandom": 40 | return gym.make("CustomMyoHandPoseRandom-v0", **kwargs) 41 | elif env_name == "CustomMyoPenTwirlRandom": 42 | return gym.make("CustomMyoHandPenTwirlRandom-v0", **kwargs) 43 | elif env_name == "WalkerBulletEnv": 44 | return gym.make("Walker2DBulletEnv-v0", **kwargs) 45 | elif env_name == "HalfCheetahBulletEnv": 46 | return gym.make("HalfCheetahBulletEnv-v0", **kwargs) 47 | elif env_name == "AntBulletEnv": 48 | return gym.make("AntBulletEnv-v0", **kwargs) 49 | elif env_name == "HopperBulletEnv": 50 | return gym.make("HopperBulletEnv-v0", **kwargs) 51 | elif env_name == "HumanoidBulletEnv": 52 | return gym.make("HumanoidBulletEnv-v0", **kwargs) 53 | else: 54 | raise ValueError("Environment name not recognized:", env_name) 55 | -------------------------------------------------------------------------------- /src/envs/pen.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | from myosuite.envs.env_base import MujocoEnv 5 | from myosuite.envs.myo.base_v0 import BaseV0 6 | from myosuite.envs.myo.pen_v0 import 
PenTwirlRandomEnvV0 7 | from myosuite.utils.quat_math import euler2quat 8 | from myosuite.utils.vector_math import calculate_cosine 9 | 10 | 11 | class CustomPenEnv(PenTwirlRandomEnvV0): 12 | def get_reward_dict(self, obs_dict): 13 | pos_err = obs_dict["obj_err_pos"] 14 | pos_align = np.linalg.norm(pos_err, axis=-1) 15 | rot_align = calculate_cosine(obs_dict["obj_rot"], obs_dict["obj_des_rot"]) 16 | dropped = pos_align > 0.075 17 | act_mag = ( 18 | np.linalg.norm(self.obs_dict["act"], axis=-1) / self.sim.model.na 19 | if self.sim.model.na != 0 20 | else 0 21 | ) 22 | pos_align_diff = self.pos_align - pos_align # should decrease 23 | rot_align_diff = rot_align - self.rot_align # should increase 24 | alive = ~dropped 25 | 26 | rwd_dict = collections.OrderedDict( 27 | ( 28 | # Optional Keys 29 | ("pos_align", -1.0 * pos_align), 30 | ("rot_align", rot_align), 31 | ("pos_align_diff", pos_align_diff), 32 | ("rot_align_diff", rot_align_diff), 33 | ("alive", alive), 34 | ("act_reg", -1.0 * act_mag), 35 | ("drop", -1.0 * dropped), 36 | ( 37 | "bonus", 38 | 1.0 * (rot_align > 0.9) * (pos_align < 0.075) 39 | + 5.0 * (rot_align > 0.95) * (pos_align < 0.075), 40 | ), 41 | # Must keys 42 | ("sparse", -1.0 * pos_align + rot_align), 43 | ("solved", (rot_align > 0.95) * (~dropped)), 44 | ("done", dropped), 45 | ) 46 | ) 47 | rwd_dict["dense"] = np.sum( 48 | [wt * rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0 49 | ) 50 | return rwd_dict 51 | 52 | def _setup( 53 | self, 54 | obs_keys: list = PenTwirlRandomEnvV0.DEFAULT_OBS_KEYS, 55 | weighted_reward_keys: list = PenTwirlRandomEnvV0.DEFAULT_RWD_KEYS_AND_WEIGHTS, 56 | goal_orient_range=( 57 | -1, 58 | 1, 59 | ), # can be used to make the task simpler and limit the target orientations 60 | enable_rsi=False, 61 | rsi_distance=0, 62 | **kwargs, 63 | ): 64 | self.target_obj_bid = self.sim.model.body_name2id("target") 65 | self.S_grasp_sid = self.sim.model.site_name2id("S_grasp") 66 | self.obj_bid = self.sim.model.body_name2id("Object") 67 | self.eps_ball_sid = self.sim.model.site_name2id("eps_ball") 68 | self.obj_t_sid = self.sim.model.site_name2id("object_top") 69 | self.obj_b_sid = self.sim.model.site_name2id("object_bottom") 70 | self.tar_t_sid = self.sim.model.site_name2id("target_top") 71 | self.tar_b_sid = self.sim.model.site_name2id("target_bottom") 72 | self.pen_length = np.linalg.norm( 73 | self.sim.model.site_pos[self.obj_t_sid] 74 | - self.sim.model.site_pos[self.obj_b_sid] 75 | ) 76 | self.tar_length = np.linalg.norm( 77 | self.sim.model.site_pos[self.tar_t_sid] 78 | - self.sim.model.site_pos[self.tar_b_sid] 79 | ) 80 | 81 | self.goal_orient_range = goal_orient_range 82 | self.rsi = enable_rsi 83 | self.rsi_distance = rsi_distance 84 | self.pos_align = 0 85 | self.rot_align = 0 86 | 87 | BaseV0._setup( 88 | self, 89 | obs_keys=obs_keys, 90 | weighted_reward_keys=weighted_reward_keys, 91 | **kwargs, 92 | ) 93 | self.init_qpos[:-6] *= 0 # Use fully open as init pos 94 | self.init_qpos[0] = -1.5 # place palm up 95 | 96 | def reset(self): 97 | # randomize target 98 | desired_orien = np.zeros(3) 99 | desired_orien[0] = self.np_random.uniform( 100 | low=self.goal_orient_range[0], high=self.goal_orient_range[1] 101 | ) 102 | desired_orien[1] = self.np_random.uniform( 103 | low=self.goal_orient_range[0], high=self.goal_orient_range[1] 104 | ) 105 | self.sim.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien) 106 | 107 | if self.rsi: 108 | init_orien = np.zeros(3) 109 | init_orien[:2] = desired_orien[:2] + self.rsi_distance * ( 110 
| init_orien[:2] - desired_orien[:2] 111 | ) 112 | self.sim.model.body_quat[self.obj_bid] = euler2quat(init_orien) 113 | 114 | self.robot.sync_sims(self.sim, self.sim_obsd) 115 | obs = MujocoEnv.reset(self) 116 | 117 | self.pos_align = np.linalg.norm(self.obs_dict["obj_err_pos"], axis=-1) 118 | self.rot_align = calculate_cosine( 119 | self.obs_dict["obj_rot"], self.obs_dict["obj_des_rot"] 120 | ) 121 | 122 | return obs 123 | 124 | def step(self, a): 125 | obs, reward, done, info = super().step(a) 126 | self.pos_align = np.linalg.norm(self.obs_dict["obj_err_pos"], axis=-1) 127 | self.rot_align = calculate_cosine( 128 | self.obs_dict["obj_rot"], self.obs_dict["obj_des_rot"] 129 | ) 130 | info.update(info.get("rwd_dict")) 131 | return obs, reward, done, info 132 | 133 | def render(self, mode="human"): 134 | return self.sim.render(mode=mode) 135 | -------------------------------------------------------------------------------- /src/envs/pose.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from myosuite.envs.myo.base_v0 import BaseV0 3 | from myosuite.envs.myo.pose_v0 import PoseEnvV0 4 | 5 | 6 | class CustomPoseEnv(PoseEnvV0): 7 | def _setup( 8 | self, 9 | viz_site_targets: tuple = None, # site to use for targets visualization [] 10 | target_jnt_range: dict = None, # joint ranges as tuples {name:(min, max)}_nq 11 | target_jnt_value: list = None, # desired joint vector [des_qpos]_nq 12 | reset_type="init", # none; init; random; sds 13 | target_type="generate", # generate; switch; fixed 14 | obs_keys: list = PoseEnvV0.DEFAULT_OBS_KEYS, 15 | weighted_reward_keys: dict = PoseEnvV0.DEFAULT_RWD_KEYS_AND_WEIGHTS, 16 | pose_thd=0.35, 17 | weight_bodyname=None, 18 | weight_range=None, 19 | sds_distance=0, 20 | target_distance=1, # for non-SDS curriculum, the target is set at a fraction of the full distance 21 | **kwargs, 22 | ): 23 | self.reset_type = reset_type 24 | self.target_type = target_type 25 | self.pose_thd = pose_thd 26 | self.weight_bodyname = weight_bodyname 27 | self.weight_range = weight_range 28 | self.sds_distance = sds_distance 29 | self.target_distance = target_distance 30 | 31 | # resolve joint demands 32 | if target_jnt_range: 33 | self.target_jnt_ids = [] 34 | self.target_jnt_range = [] 35 | for jnt_name, jnt_range in target_jnt_range.items(): 36 | self.target_jnt_ids.append(self.sim.model.joint_name2id(jnt_name)) 37 | self.target_jnt_range.append(jnt_range) 38 | self.target_jnt_range = np.array(self.target_jnt_range) 39 | self.target_jnt_value = np.mean( 40 | self.target_jnt_range, axis=1 41 | ) # pseudo targets for init 42 | else: 43 | self.target_jnt_value = target_jnt_value 44 | 45 | BaseV0._setup( 46 | self, 47 | obs_keys=obs_keys, 48 | weighted_reward_keys=weighted_reward_keys, 49 | sites=viz_site_targets, 50 | **kwargs, 51 | ) 52 | 53 | def reset(self): 54 | # udpate wegith 55 | if self.weight_bodyname is not None: 56 | bid = self.sim.model.body_name2id(self.weight_bodyname) 57 | gid = self.sim.model.body_geomadr[bid] 58 | weight = self.np_random.uniform( 59 | low=self.weight_range[0], high=self.weight_range[1] 60 | ) 61 | self.sim.model.body_mass[bid] = weight 62 | self.sim_obsd.model.body_mass[bid] = weight 63 | # self.sim_obsd.model.geom_size[gid] = self.sim.model.geom_size[gid] * weight/10 64 | self.sim.model.geom_size[gid][0] = 0.01 + 2.5 * weight / 100 65 | # self.sim_obsd.model.geom_size[gid][0] = weight/10 66 | 67 | # update target 68 | if self.target_type == "generate": 69 | # use target_jnt_range to 
generate targets 70 | self.update_target(restore_sim=True) 71 | elif self.target_type == "fixed": 72 | self.update_target(restore_sim=True) 73 | else: 74 | print("{} Target Type not found ".format(self.target_type)) 75 | 76 | # update init state 77 | if self.reset_type is None or self.reset_type == "none": 78 | # no reset; use last state 79 | obs = self.get_obs() 80 | elif self.reset_type == "init": 81 | # reset to init state 82 | obs = BaseV0.reset(self) 83 | elif self.reset_type == "random": 84 | # reset to random state 85 | jnt_init = self.np_random.uniform( 86 | high=self.sim.model.jnt_range[:, 1], low=self.sim.model.jnt_range[:, 0] 87 | ) 88 | obs = BaseV0.reset(self, reset_qpos=jnt_init) 89 | elif self.reset_type == "sds": 90 | init_qpos = self.init_qpos.copy() 91 | init_qvel = self.init_qvel.copy() 92 | target_qpos = self.target_jnt_value.copy() 93 | qpos = (1 - self.sds_distance) * target_qpos + self.sds_distance * init_qpos 94 | self.robot.reset(qpos, init_qvel) 95 | obs = self.get_obs() 96 | else: 97 | print("Reset Type not found") 98 | 99 | return obs 100 | 101 | def step(self, action): 102 | obs, reward, done, info = super().step(action) 103 | info.update(info.get("rwd_dict")) 104 | return obs, reward, done, info 105 | 106 | def render(self, mode): 107 | return self.sim.render(mode=mode) 108 | 109 | def get_target_pose(self): 110 | full_distance_target_pose = super().get_target_pose() 111 | init_pose = self.init_qpos.copy() 112 | target_pose = init_pose + self.target_distance * ( 113 | full_distance_target_pose - init_pose 114 | ) 115 | return target_pose 116 | -------------------------------------------------------------------------------- /src/envs/reorient.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=attribute-defined-outside-init, dangerous-default-value, protected-access, abstract-method, arguments-renamed 2 | import collections 3 | import numpy as np 4 | from myosuite.envs.env_base import MujocoEnv 5 | from myosuite.envs.myo.base_v0 import BaseV0 6 | from myosuite.envs.myo.myochallenge.reorient_v0 import ReorientEnvV0 7 | from myosuite.utils.quat_math import euler2quat 8 | 9 | 10 | class CustomReorientEnv(ReorientEnvV0): 11 | def get_reward_dict(self, obs_dict): 12 | pos_dist_new = np.abs(np.linalg.norm(self.obs_dict["pos_err"], axis=-1)) 13 | rot_dist_new = np.abs(np.linalg.norm(self.obs_dict["rot_err"], axis=-1)) 14 | pos_dist_diff = self.pos_dist - pos_dist_new 15 | rot_dist_diff = self.rot_dist - rot_dist_new 16 | act_mag = ( 17 | np.linalg.norm(self.obs_dict["act"], axis=-1) / self.sim.model.na 18 | if self.sim.model.na != 0 19 | else 0 20 | ) 21 | drop = pos_dist_new > self.drop_th 22 | 23 | rwd_dict = collections.OrderedDict( 24 | ( 25 | # Perform reward tuning here -- 26 | # Update Optional Keys section below 27 | # Update reward keys (DEFAULT_RWD_KEYS_AND_WEIGHTS) accordingly to update final rewards 28 | # Examples: Env comes pre-packaged with two keys pos_dist and rot_dist 29 | # Optional Keys 30 | ("pos_dist", -1.0 * pos_dist_new), 31 | ("rot_dist", -1.0 * rot_dist_new), 32 | ("pos_dist_diff", pos_dist_diff), 33 | ("rot_dist_diff", rot_dist_diff), 34 | ("alive", ~drop), 35 | # Must keys 36 | ("act_reg", -1.0 * act_mag), 37 | ("sparse", -rot_dist_new - 10.0 * pos_dist_new), 38 | ( 39 | "solved", 40 | ( 41 | (pos_dist_new < self.pos_th) 42 | and (rot_dist_new < self.rot_th) 43 | and (not drop) 44 | ) 45 | * np.ones((1, 1)), 46 | ), 47 | ("done", drop), 48 | ) 49 | ) 50 | rwd_dict["dense"] = np.sum( 51 
| [wt * rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0 52 | ) 53 | 54 | # Success Indicator 55 | self.sim.model.site_rgba[self.success_indicator_sid, :2] = ( 56 | np.array([0, 2]) if rwd_dict["solved"] else np.array([2, 0]) 57 | ) 58 | return rwd_dict 59 | 60 | def _setup( 61 | self, 62 | obs_keys: list = ReorientEnvV0.DEFAULT_OBS_KEYS, 63 | weighted_reward_keys: list = ReorientEnvV0.DEFAULT_RWD_KEYS_AND_WEIGHTS, 64 | goal_pos=(0.0, 0.0), # goal position range (relative to initial pos) 65 | goal_rot=(0.785, 0.785), # goal rotation range (relative to initial rot) 66 | obj_size_change=0, # object size change (relative to initial size) 67 | obj_friction_change=( 68 | 0, 69 | 0, 70 | 0, 71 | ), # object friction change (relative to initial friction) 72 | pos_th=0.025, # position error threshold 73 | rot_th=0.262, # rotation error threshold 74 | drop_th=0.200, # drop height threshold 75 | enable_rsi=False, 76 | rsi_distance_pos=0, 77 | rsi_distance_rot=0, 78 | goal_rot_x=None, 79 | goal_rot_y=None, 80 | goal_rot_z=None, 81 | guided_trajectory_steps=None, 82 | **kwargs, 83 | ): 84 | self.already_reset = False 85 | self.object_sid = self.sim.model.site_name2id("object_o") 86 | self.goal_sid = self.sim.model.site_name2id("target_o") 87 | self.success_indicator_sid = self.sim.model.site_name2id("target_ball") 88 | self.goal_bid = self.sim.model.body_name2id("target") 89 | self.object_bid = self.sim.model.body_name2id("Object") 90 | self.goal_init_pos = self.sim.data.site_xpos[self.goal_sid].copy() 91 | self.goal_init_rot = self.sim.model.body_quat[self.goal_bid].copy() 92 | self.goal_obj_offset = ( 93 | self.sim.data.site_xpos[self.goal_sid] 94 | - self.sim.data.site_xpos[self.object_sid] 95 | ) # visualization offset between target and object 96 | self.goal_pos = goal_pos 97 | self.goal_rot = goal_rot 98 | self.pos_th = pos_th 99 | self.rot_th = rot_th 100 | self.drop_th = drop_th 101 | self.rsi = enable_rsi 102 | self.rsi_distance_pos = rsi_distance_pos 103 | self.rsi_distance_rot = rsi_distance_rot 104 | self.goal_rot_x = goal_rot_x 105 | self.goal_rot_y = goal_rot_y 106 | self.goal_rot_z = goal_rot_z 107 | self.guided_trajectory_steps = guided_trajectory_steps 108 | self.pos_dist = 0 109 | self.rot_dist = 0 110 | 111 | # setup for object randomization 112 | self.target_gid = self.sim.model.geom_name2id("target_dice") 113 | self.target_default_size = self.sim.model.geom_size[self.target_gid].copy() 114 | 115 | object_bid = self.sim.model.body_name2id("Object") 116 | self.object_gid0 = self.sim.model.body_geomadr[object_bid] 117 | self.object_gidn = self.object_gid0 + self.sim.model.body_geomnum[object_bid] 118 | self.object_default_size = self.sim.model.geom_size[ 119 | self.object_gid0 : self.object_gidn 120 | ].copy() 121 | self.object_default_pos = self.sim.model.geom_pos[ 122 | self.object_gid0 : self.object_gidn 123 | ].copy() 124 | 125 | self.obj_size_change = {"high": obj_size_change, "low": -obj_size_change} 126 | self.obj_friction_range = { 127 | "high": self.sim.model.geom_friction[self.object_gid0 : self.object_gidn] 128 | + obj_friction_change, 129 | "low": self.sim.model.geom_friction[self.object_gid0 : self.object_gidn] 130 | - obj_friction_change, 131 | } 132 | 133 | BaseV0._setup( 134 | self, 135 | obs_keys=obs_keys, 136 | weighted_reward_keys=weighted_reward_keys, 137 | **kwargs, 138 | ) 139 | self.init_qpos[:-7] *= 0 # Use fully open as init pos 140 | self.init_qpos[0] = -1.5 # Palm up 141 | 142 | def reset(self, reset_qpos=None, reset_qvel=None): 143 | 144 | # First
sample the target position and orientation of the die 145 | self.episode_goal_pos = self.sample_goal_position() 146 | self.episode_goal_rot = self.sample_goal_orientation() 147 | 148 | # Then get the initial position and orientation of the die 149 | if self.rsi: 150 | object_init_pos = ( 151 | self.rsi_distance_pos * self.goal_init_pos 152 | + (1 - self.rsi_distance_pos) * self.episode_goal_pos 153 | - self.goal_obj_offset 154 | ) 155 | object_init_rot = ( 156 | self.rsi_distance_rot * self.goal_init_rot 157 | + (1 - self.rsi_distance_rot) * self.episode_goal_rot 158 | ) 159 | else: 160 | object_init_pos = self.goal_init_pos - self.goal_obj_offset 161 | object_init_rot = self.goal_init_rot 162 | 163 | # Set the position of the object 164 | self.sim.model.body_pos[self.object_bid] = object_init_pos 165 | self.sim.model.body_quat[self.object_bid] = object_init_rot 166 | 167 | # Create the target trajectory and set the initial position 168 | self.goal_pos_traj, self.goal_rot_traj = self.create_goal_trajectory( 169 | object_init_pos, object_init_rot 170 | ) 171 | self.counter = 0 172 | self.set_die_pos_rot(self.counter) 173 | 174 | # Die friction changes 175 | self.sim.model.geom_friction[ 176 | self.object_gid0 : self.object_gidn 177 | ] = self.np_random.uniform(**self.obj_friction_range) 178 | 179 | # Die and Target size changes 180 | del_size = self.np_random.uniform(**self.obj_size_change) 181 | # adjust size of target 182 | self.sim.model.geom_size[self.target_gid] = self.target_default_size + del_size 183 | # adjust size of die 184 | self.sim.model.geom_size[self.object_gid0 : self.object_gidn - 3][:, 1] = ( 185 | self.object_default_size[:-3][:, 1] + del_size 186 | ) 187 | self.sim.model.geom_size[self.object_gidn - 3 : self.object_gidn] = ( 188 | self.object_default_size[-3:] + del_size 189 | ) 190 | # adjust boundary of die 191 | object_gpos = self.sim.model.geom_pos[self.object_gid0 : self.object_gidn] 192 | self.sim.model.geom_pos[self.object_gid0 : self.object_gidn] = ( 193 | object_gpos 194 | / abs(object_gpos + 1e-16) 195 | * (abs(self.object_default_pos) + del_size) 196 | ) 197 | 198 | obs = MujocoEnv.reset(self, reset_qpos, reset_qvel) 199 | self.pos_dist = np.abs(np.linalg.norm(self.obs_dict["pos_err"], axis=-1)) 200 | self.rot_dist = np.abs(np.linalg.norm(self.obs_dict["rot_err"], axis=-1)) 201 | self.already_reset = True 202 | return obs 203 | 204 | def sample_goal_position(self): 205 | goal_pos = self.goal_init_pos + self.np_random.uniform( 206 | high=self.goal_pos[1], low=self.goal_pos[0], size=3 207 | ) 208 | return goal_pos 209 | 210 | def sample_goal_orientation(self): 211 | x_low, x_high = ( 212 | self.goal_rot_x[self.np_random.choice(len(self.goal_rot_x))] 213 | if self.goal_rot_x is not None 214 | else self.goal_rot 215 | ) 216 | y_low, y_high = ( 217 | self.goal_rot_y[self.np_random.choice(len(self.goal_rot_y))] 218 | if self.goal_rot_y is not None 219 | else self.goal_rot 220 | ) 221 | z_low, z_high = ( 222 | self.goal_rot_z[self.np_random.choice(len(self.goal_rot_z))] 223 | if self.goal_rot_z is not None 224 | else self.goal_rot 225 | ) 226 | 227 | goal_rot_x = self.np_random.uniform(x_low, x_high) 228 | goal_rot_y = self.np_random.uniform(y_low, y_high) 229 | goal_rot_z = self.np_random.uniform(z_low, z_high) 230 | goal_rot_quat = euler2quat(np.array([goal_rot_x, goal_rot_y, goal_rot_z])) 231 | return goal_rot_quat 232 | 233 | def step(self, action): 234 | obs, reward, done, info = super().step(action) 235 | self.pos_dist = 
np.abs(np.linalg.norm(self.obs_dict["pos_err"], axis=-1)) 236 | self.rot_dist = np.abs(np.linalg.norm(self.obs_dict["rot_err"], axis=-1)) 237 | info.update(info.get("rwd_dict")) 238 | 239 | if self.already_reset: 240 | self.counter += 1 241 | self.set_die_pos_rot(self.counter) 242 | return obs, reward, done, info 243 | 244 | def create_goal_trajectory(self, object_init_pos, object_init_rot): 245 | traj_len = 1000 # Assumes it is larger than the episode len 246 | 247 | pos_traj = np.ones((traj_len, 3)) * self.episode_goal_pos 248 | rot_traj = np.ones((traj_len, 4)) * self.episode_goal_rot 249 | 250 | if ( 251 | self.guided_trajectory_steps is not None 252 | ): # Softly reach the target position and orientation 253 | goal_init_pos = object_init_pos + self.goal_obj_offset 254 | guided_pos_traj = np.linspace( 255 | goal_init_pos, self.episode_goal_pos, self.guided_trajectory_steps 256 | ) 257 | pos_traj[: self.guided_trajectory_steps] = guided_pos_traj 258 | 259 | guided_rot_traj = np.linspace( 260 | object_init_rot, self.episode_goal_rot, self.guided_trajectory_steps 261 | ) 262 | rot_traj[: self.guided_trajectory_steps] = guided_rot_traj 263 | return pos_traj, rot_traj 264 | 265 | def set_die_pos_rot(self, counter): 266 | self.sim.model.body_pos[self.goal_bid] = self.goal_pos_traj[counter] 267 | self.sim.model.body_quat[self.goal_bid] = self.goal_rot_traj[counter] 268 | -------------------------------------------------------------------------------- /src/main_ant.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.sac.policies import LatticeSACPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | args = parser.parse_args() 41 | 42 | # define constants 43 | ENV_NAME = "AntBulletEnv" 44 | 45 | now = 
datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 46 | 47 | if args.model_path is not None: 48 | model_name = args.model_path.split("/")[-2] 49 | else: 50 | model_name = None 51 | 52 | TENSORBOARD_LOG = ( 53 | os.path.join(ROOT_DIR, "output", "training", now) 54 | + f"_ant_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}" 55 | ) 56 | 57 | # Reward structure and task parameters: 58 | config = { 59 | } 60 | 61 | max_episode_steps = 1000 62 | 63 | model_config = dict( 64 | policy=LatticeSACPolicy, 65 | device=args.device, 66 | learning_rate=3e-4, 67 | buffer_size=300_000, 68 | learning_starts=10000, 69 | batch_size=256, 70 | tau=0.02, 71 | gamma=0.98, 72 | train_freq=(8, "step"), 73 | gradient_steps=8, 74 | action_noise=None, 75 | replay_buffer_class=None, 76 | ent_coef="auto", 77 | target_update_interval=1, 78 | target_entropy="auto", 79 | seed=args.seed, 80 | use_sde=args.use_sde, 81 | sde_sample_freq=args.freq, 82 | policy_kwargs=dict( 83 | use_lattice=args.use_lattice, 84 | use_expln=True, 85 | log_std_init=args.log_std_init, 86 | activation_fn=nn.GELU, 87 | net_arch=dict(pi=[400, 300], qf=[400, 300]), 88 | std_clip=(1e-3, 1), 89 | expln_eps=1e-6, 90 | clip_mean=2.0, 91 | std_reg=args.std_reg 92 | ), 93 | ) 94 | 95 | # Function that creates and monitors vectorized environments: 96 | def make_parallel_envs(env_config, num_env, start_index=0): 97 | def make_env(_): 98 | def _thunk(): 99 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 100 | env.seed(args.seed) 101 | env._max_episode_steps = max_episode_steps 102 | env = Monitor(env, TENSORBOARD_LOG) 103 | return env 104 | 105 | return _thunk 106 | 107 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 108 | 109 | 110 | if __name__ == "__main__": 111 | # ensure tensorboard log directory exists and copy this file to track 112 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 113 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 114 | 115 | # Create and wrap the training and evaluations environments 116 | envs = make_parallel_envs(config, args.num_envs) 117 | 118 | if args.env_path is not None: 119 | envs = VecNormalize.load(args.env_path, envs) 120 | else: 121 | envs = VecNormalize(envs) 122 | 123 | # Define callbacks for evaluation and saving the agent 124 | eval_callback = EvalCallback( 125 | eval_env=envs, 126 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 127 | n_eval_episodes=10, 128 | best_model_save_path=TENSORBOARD_LOG, 129 | log_path=TENSORBOARD_LOG, 130 | eval_freq=10_000, 131 | deterministic=True, 132 | render=False, 133 | verbose=1, 134 | ) 135 | 136 | checkpoint_callback = CheckpointCallback( 137 | save_freq=25_000, 138 | save_path=TENSORBOARD_LOG, 139 | save_vecnormalize=True, 140 | verbose=1, 141 | ) 142 | 143 | tensorboard_callback = TensorboardCallback( 144 | info_keywords=( 145 | ) 146 | ) 147 | 148 | # Define trainer 149 | trainer = MyoTrainer( 150 | algo="sac", 151 | envs=envs, 152 | env_config=config, 153 | load_model_path=args.model_path, 154 | log_dir=TENSORBOARD_LOG, 155 | model_config=model_config, 156 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 157 | timesteps=10_000_000, 158 | ) 159 | 160 | # Train agent 161 | trainer.train(total_timesteps=trainer.timesteps) 162 | trainer.save() 163 | -------------------------------------------------------------------------------- /src/main_baoding.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | 41 | args = parser.parse_args() 42 | 43 | # define constants 44 | ENV_NAME = "CustomMyoBaodingBallsP1" 45 | 46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 47 | 48 | if args.model_path is not None: 49 | model_name = args.model_path.split("/")[-2] 50 | else: 51 | model_name = None 52 | 53 | TENSORBOARD_LOG = ( 54 | os.path.join(ROOT_DIR, "output", "training", now) 55 | + f"_baoding_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}" 56 | ) 57 | 58 | # Reward structure and task parameters: 59 | config = { 60 | "seed": args.seed, 61 | "weighted_reward_keys": { 62 | "pos_dist_1": 1, 63 | "pos_dist_2": 1, 64 | "act_reg": 0, 65 | "alive": 1, 66 | "solved": 5, 67 | "done": 0, 68 | "sparse": 0, 69 | }, 70 | } 71 | 72 | max_episode_steps = 200 73 | 74 | model_config = dict( 75 | policy=LatticeRecurrentActorCriticPolicy, 76 | device=args.device, 77 | batch_size=32, 78 | n_steps=128, 79 | learning_rate=2.55673e-05, 80 | ent_coef=3.62109e-06, 81 | clip_range=0.3, 82 | gamma=0.99, 83 | gae_lambda=0.9, 84 | max_grad_norm=0.7, 85 | vf_coef=0.835671, 86 | n_epochs=10, 87 | use_sde=args.use_sde, 88 | sde_sample_freq=args.freq, 89 | policy_kwargs=dict( 90 | use_lattice=args.use_lattice, 91 | use_expln=True, 92 | ortho_init=False, 93 | log_std_init=args.log_std_init, 94 | activation_fn=nn.ReLU, 95 | net_arch=[dict(pi=[256, 256], vf=[256, 256])], 96 | std_clip=(1e-3, 10), 97 | expln_eps=1e-6, 98 | full_std=False, 99 | std_reg=args.std_reg 100 | ), 101 | ) 102 | 103 | # 
Function that creates and monitors vectorized environments: 104 | def make_parallel_envs(env_config, num_env, start_index=0): 105 | def make_env(_): 106 | def _thunk(): 107 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 108 | env.seed(args.seed) 109 | env._max_episode_steps = max_episode_steps 110 | env = Monitor(env, TENSORBOARD_LOG) 111 | return env 112 | 113 | return _thunk 114 | 115 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 116 | 117 | 118 | if __name__ == "__main__": 119 | # ensure tensorboard log directory exists and copy this file to track 120 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 121 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 122 | 123 | # Create and wrap the training and evaluations environments 124 | envs = make_parallel_envs(config, args.num_envs) 125 | 126 | if args.env_path is not None: 127 | envs = VecNormalize.load(args.env_path, envs) 128 | else: 129 | envs = VecNormalize(envs) 130 | 131 | # Define callbacks for evaluation and saving the agent 132 | eval_callback = EvalCallback( 133 | eval_env=envs, 134 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 135 | n_eval_episodes=10, 136 | best_model_save_path=TENSORBOARD_LOG, 137 | log_path=TENSORBOARD_LOG, 138 | eval_freq=10_000, 139 | deterministic=True, 140 | render=False, 141 | verbose=1, 142 | ) 143 | 144 | checkpoint_callback = CheckpointCallback( 145 | save_freq=25_000, 146 | save_path=TENSORBOARD_LOG, 147 | save_vecnormalize=True, 148 | verbose=1, 149 | ) 150 | 151 | tensorboard_callback = TensorboardCallback( 152 | info_keywords=( 153 | "pos_dist_1", 154 | "pos_dist_2", 155 | "act_reg", 156 | "alive", 157 | "solved", 158 | ) 159 | ) 160 | 161 | # Define trainer 162 | trainer = MyoTrainer( 163 | algo="recurrent_ppo", 164 | envs=envs, 165 | env_config=config, 166 | load_model_path=args.model_path, 167 | log_dir=TENSORBOARD_LOG, 168 | model_config=model_config, 169 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 170 | timesteps=20_000_000, 171 | ) 172 | 173 | # Train agent 174 | trainer.train(total_timesteps=trainer.timesteps) 175 | trainer.save() 176 | -------------------------------------------------------------------------------- /src/main_half_cheetah.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.sac.policies import LatticeSACPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use 
lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | args = parser.parse_args() 41 | 42 | # define constants 43 | ENV_NAME = "HalfCheetahBulletEnv" 44 | 45 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 46 | 47 | if args.model_path is not None: 48 | model_name = args.model_path.split("/")[-2] 49 | else: 50 | model_name = None 51 | 52 | TENSORBOARD_LOG = ( 53 | os.path.join(ROOT_DIR, "output", "training", now) 54 | + f"_half_cheetah_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}" 55 | ) 56 | 57 | # Reward structure and task parameters: 58 | config = { 59 | } 60 | 61 | max_episode_steps = 1000 62 | num_envs = args.num_envs 63 | 64 | model_config = dict( 65 | policy=LatticeSACPolicy, 66 | device=args.device, 67 | learning_rate=3e-4, 68 | buffer_size=300_000, 69 | learning_starts=10000, 70 | batch_size=256, 71 | tau=0.02, 72 | gamma=0.98, 73 | train_freq=(8, "step"), 74 | gradient_steps=8, 75 | action_noise=None, 76 | replay_buffer_class=None, 77 | ent_coef="auto", 78 | target_update_interval=1, 79 | target_entropy="auto", 80 | seed=args.seed, 81 | use_sde=args.use_sde, 82 | sde_sample_freq=args.freq, 83 | policy_kwargs=dict( 84 | use_lattice=args.use_lattice, 85 | use_expln=True, 86 | log_std_init=args.log_std_init, 87 | activation_fn=nn.GELU, 88 | net_arch=dict(pi=[400, 300], qf=[400, 300]), 89 | std_clip=(1e-3, 10), 90 | expln_eps=1e-6, 91 | clip_mean=2.0, 92 | std_reg=args.std_reg 93 | ), 94 | ) 95 | 96 | # Function that creates and monitors vectorized environments: 97 | def make_parallel_envs(env_config, num_env, start_index=0): 98 | def make_env(_): 99 | def _thunk(): 100 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 101 | env.seed(args.seed) 102 | env._max_episode_steps = max_episode_steps 103 | env = Monitor(env, TENSORBOARD_LOG) 104 | return env 105 | 106 | return _thunk 107 | 108 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 109 | 110 | 111 | if __name__ == "__main__": 112 | # ensure tensorboard log directory exists and copy this file to track 113 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 114 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 115 | 116 | # Create and wrap the training and evaluations environments 117 | envs = make_parallel_envs(config, num_envs) 118 | 119 | if args.env_path is not None: 120 | envs = VecNormalize.load(args.env_path, envs) 121 | else: 122 | envs = VecNormalize(envs) 123 | 124 | # Define callbacks for evaluation and saving the agent 125 | eval_callback = EvalCallback( 126 | eval_env=envs, 127 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 128 | n_eval_episodes=10, 129 | best_model_save_path=TENSORBOARD_LOG, 130 | log_path=TENSORBOARD_LOG, 131 | eval_freq=10_000, 132 | deterministic=True, 133 | render=False, 134 | verbose=1, 135 | ) 136 | 137 | checkpoint_callback = 
CheckpointCallback( 138 | save_freq=25_000, 139 | save_path=TENSORBOARD_LOG, 140 | save_vecnormalize=True, 141 | verbose=1, 142 | ) 143 | 144 | tensorboard_callback = TensorboardCallback( 145 | info_keywords=( 146 | ) 147 | ) 148 | 149 | # Define trainer 150 | trainer = MyoTrainer( 151 | algo="sac", 152 | envs=envs, 153 | env_config=config, 154 | load_model_path=args.model_path, 155 | log_dir=TENSORBOARD_LOG, 156 | model_config=model_config, 157 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 158 | timesteps=10_000_000, 159 | ) 160 | 161 | # Train agent 162 | trainer.train(total_timesteps=trainer.timesteps) 163 | trainer.save() 164 | -------------------------------------------------------------------------------- /src/main_hopper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.sac.policies import LatticeSACPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | args = parser.parse_args() 41 | 42 | # define constants 43 | ENV_NAME = "HopperBulletEnv" 44 | 45 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 46 | 47 | if args.model_path is not None: 48 | model_name = args.model_path.split("/")[-2] 49 | else: 50 | model_name = None 51 | 52 | TENSORBOARD_LOG = ( 53 | os.path.join(ROOT_DIR, "output", "training", now) 54 | + f"_hopper_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}" 55 | ) 56 | 57 | # Reward structure and task parameters: 58 | config = { 59 | } 60 | 61 | max_episode_steps = 1000 # default: 1000 62 | 63 | model_config = dict( 64 | policy=LatticeSACPolicy, 65 | device=args.device, 66 | learning_rate=3e-4, 67 | buffer_size=300_000, 68 | learning_starts=10000, 69 | batch_size=256, 
70 | tau=0.02, 71 | gamma=0.98, 72 | train_freq=(8, "step"), 73 | gradient_steps=8, 74 | action_noise=None, 75 | replay_buffer_class=None, 76 | ent_coef="auto", 77 | target_update_interval=1, 78 | target_entropy="auto", 79 | seed=args.seed, 80 | use_sde=args.use_sde, 81 | sde_sample_freq=args.freq, 82 | policy_kwargs=dict( 83 | use_lattice=args.use_lattice, 84 | use_expln=True, 85 | log_std_init=args.log_std_init, 86 | activation_fn=nn.GELU, 87 | net_arch=dict(pi=[400, 300], qf=[400, 300]), 88 | std_clip=(1e-3, 10), 89 | expln_eps=1e-6, 90 | clip_mean=2.0, 91 | std_reg=args.std_reg 92 | ), 93 | ) 94 | 95 | # Function that creates and monitors vectorized environments: 96 | def make_parallel_envs(env_config, num_env, start_index=0): 97 | def make_env(_): 98 | def _thunk(): 99 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 100 | env.seed(args.seed) 101 | env._max_episode_steps = max_episode_steps 102 | env = Monitor(env, TENSORBOARD_LOG) 103 | return env 104 | 105 | return _thunk 106 | 107 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 108 | 109 | 110 | if __name__ == "__main__": 111 | # ensure tensorboard log directory exists and copy this file to track 112 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 113 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 114 | 115 | # Create and wrap the training and evaluations environments 116 | envs = make_parallel_envs(config, args.num_envs) 117 | 118 | if args.env_path is not None: 119 | envs = VecNormalize.load(args.env_path, envs) 120 | else: 121 | envs = VecNormalize(envs) 122 | 123 | # Define callbacks for evaluation and saving the agent 124 | eval_callback = EvalCallback( 125 | eval_env=envs, 126 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 127 | n_eval_episodes=10, 128 | best_model_save_path=TENSORBOARD_LOG, 129 | log_path=TENSORBOARD_LOG, 130 | eval_freq=10_000, 131 | deterministic=True, 132 | render=False, 133 | verbose=1, 134 | ) 135 | 136 | checkpoint_callback = CheckpointCallback( 137 | save_freq=25_000, 138 | save_path=TENSORBOARD_LOG, 139 | save_vecnormalize=True, 140 | verbose=1, 141 | ) 142 | 143 | tensorboard_callback = TensorboardCallback( 144 | info_keywords=( 145 | ) 146 | ) 147 | 148 | # Define trainer 149 | trainer = MyoTrainer( 150 | algo="sac", 151 | envs=envs, 152 | env_config=config, 153 | load_model_path=args.model_path, 154 | log_dir=TENSORBOARD_LOG, 155 | model_config=model_config, 156 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 157 | timesteps=15_000_000, 158 | ) 159 | 160 | # Train agent 161 | trainer.train(total_timesteps=trainer.timesteps) 162 | trainer.save() 163 | -------------------------------------------------------------------------------- /src/main_humanoid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.sac.policies import 
LatticeSACPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | args = parser.parse_args() 41 | 42 | # define constants 43 | ENV_NAME = "HumanoidBulletEnv" 44 | 45 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 46 | 47 | if args.model_path is not None: 48 | model_name = args.model_path.split("/")[-2] 49 | else: 50 | model_name = None 51 | 52 | TENSORBOARD_LOG = ( 53 | os.path.join(ROOT_DIR, "output", "training", now) 54 | + f"_humanoid_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}" 55 | ) 56 | 57 | # Reward structure and task parameters: 58 | config = { 59 | } 60 | 61 | max_episode_steps = 1000 62 | 63 | model_config = dict( 64 | policy=LatticeSACPolicy, 65 | device=args.device, 66 | learning_rate=3e-4, 67 | buffer_size=300_000, 68 | learning_starts=10000, 69 | batch_size=256, 70 | tau=0.02, 71 | gamma=0.98, 72 | train_freq=(8, "step"), 73 | gradient_steps=8, 74 | action_noise=None, 75 | replay_buffer_class=None, 76 | ent_coef="auto", 77 | target_update_interval=1, 78 | target_entropy="auto", 79 | seed=args.seed, 80 | use_sde=args.use_sde, 81 | sde_sample_freq=args.freq, 82 | policy_kwargs=dict( 83 | use_lattice=args.use_lattice, 84 | use_expln=True, 85 | log_std_init=args.log_std_init, 86 | activation_fn=nn.GELU, 87 | net_arch=dict(pi=[400, 300], qf=[400, 300]), 88 | std_clip=(1e-3, 1), 89 | expln_eps=1e-6, 90 | clip_mean=2.0, 91 | std_reg=args.std_reg 92 | ), 93 | ) 94 | 95 | # Function that creates and monitors vectorized environments: 96 | def make_parallel_envs(env_config, num_env, start_index=0): 97 | def make_env(_): 98 | def _thunk(): 99 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 100 | env.seed(args.seed) 101 | env._max_episode_steps = max_episode_steps 102 | env = Monitor(env, TENSORBOARD_LOG) 103 | return env 104 | 105 | return _thunk 106 | 107 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 108 | 109 | 110 | if __name__ == "__main__": 111 | # ensure tensorboard log directory exists and copy this file to track 112 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 113 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 114 | 115 | # Create and wrap the training and evaluations environments 116 | envs = make_parallel_envs(config, args.num_envs) 117 | 118 | if args.env_path is not None: 119 | envs = VecNormalize.load(args.env_path, envs) 120 | 
else: 121 | envs = VecNormalize(envs) 122 | 123 | # Define callbacks for evaluation and saving the agent 124 | eval_callback = EvalCallback( 125 | eval_env=envs, 126 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 127 | n_eval_episodes=10, 128 | best_model_save_path=TENSORBOARD_LOG, 129 | log_path=TENSORBOARD_LOG, 130 | eval_freq=10_000, 131 | deterministic=True, 132 | render=False, 133 | verbose=1, 134 | ) 135 | 136 | checkpoint_callback = CheckpointCallback( 137 | save_freq=25_000, 138 | save_path=TENSORBOARD_LOG, 139 | save_vecnormalize=True, 140 | verbose=1, 141 | ) 142 | 143 | tensorboard_callback = TensorboardCallback( 144 | info_keywords=( 145 | ) 146 | ) 147 | 148 | # Define trainer 149 | trainer = MyoTrainer( 150 | algo="sac", 151 | envs=envs, 152 | env_config=config, 153 | load_model_path=args.model_path, 154 | log_dir=TENSORBOARD_LOG, 155 | model_config=model_config, 156 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 157 | timesteps=20_000_000, 158 | ) 159 | 160 | # Train agent 161 | trainer.train(total_timesteps=trainer.timesteps) 162 | trainer.save() 163 | -------------------------------------------------------------------------------- /src/main_pen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | 41 | args = parser.parse_args() 42 | 43 | # define constants 44 | ENV_NAME = "CustomMyoPenTwirlRandom" 45 | 46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 47 | 48 | if args.model_path is not None: 49 | model_name = args.model_path.split("/")[-2] 50 | else: 51 | model_name = None 52 | 53 | TENSORBOARD_LOG = ( 54 | os.path.join(ROOT_DIR, "output", "training", now) 55 | 
+ f"_pen_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}" 56 | ) 57 | 58 | # Reward structure and task parameters: 59 | config = { 60 | "seed": args.seed, 61 | "weighted_reward_keys": { 62 | "pos_align": 0, 63 | "rot_align": 0, 64 | "pos_align_diff": 1e2, 65 | "rot_align_diff": 1e2, 66 | "alive": 0, 67 | "act_reg": 0, 68 | "drop": 0, 69 | "bonus": 0, 70 | "solved": 1, 71 | "done": 0, 72 | "sparse": 0, 73 | }, 74 | "goal_orient_range": (-1, 1), 75 | "enable_rsi": False, 76 | "rsi_distance": None, 77 | } 78 | 79 | max_episode_steps = 100 80 | 81 | model_config = dict( 82 | policy=LatticeRecurrentActorCriticPolicy, 83 | device=args.device, 84 | batch_size=32, 85 | n_steps=128, 86 | learning_rate=2.55673e-05, 87 | ent_coef=3.62109e-06, 88 | clip_range=0.3, 89 | gamma=0.99, 90 | gae_lambda=0.9, 91 | max_grad_norm=0.7, 92 | vf_coef=0.835671, 93 | n_epochs=10, 94 | use_sde=args.use_sde, 95 | sde_sample_freq=args.freq, 96 | policy_kwargs=dict( 97 | use_lattice=args.use_lattice, 98 | use_expln=True, 99 | ortho_init=False, 100 | log_std_init=args.log_std_init, 101 | activation_fn=nn.ReLU, 102 | net_arch=[dict(pi=[256, 256], vf=[256, 256])], 103 | std_clip=(1e-3, 10), 104 | expln_eps=1e-6, 105 | full_std=False, 106 | std_reg=args.std_reg 107 | ), 108 | ) 109 | 110 | # Function that creates and monitors vectorized environments: 111 | def make_parallel_envs(env_config, num_env, start_index=0): 112 | def make_env(_): 113 | def _thunk(): 114 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 115 | env.seed(args.seed) 116 | env._max_episode_steps = max_episode_steps 117 | env = Monitor(env, TENSORBOARD_LOG) 118 | return env 119 | 120 | return _thunk 121 | 122 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 123 | 124 | 125 | if __name__ == "__main__": 126 | # ensure tensorboard log directory exists and copy this file to track 127 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 128 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 129 | 130 | # Create and wrap the training and evaluations environments 131 | envs = make_parallel_envs(config, args.num_envs) 132 | 133 | if args.env_path is not None: 134 | envs = VecNormalize.load(args.env_path, envs) 135 | else: 136 | envs = VecNormalize(envs) 137 | 138 | # Define callbacks for evaluation and saving the agent 139 | eval_callback = EvalCallback( 140 | eval_env=envs, 141 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 142 | n_eval_episodes=10, 143 | best_model_save_path=TENSORBOARD_LOG, 144 | log_path=TENSORBOARD_LOG, 145 | eval_freq=10_000, 146 | deterministic=True, 147 | render=False, 148 | verbose=1, 149 | ) 150 | 151 | checkpoint_callback = CheckpointCallback( 152 | save_freq=25_000, 153 | save_path=TENSORBOARD_LOG, 154 | save_vecnormalize=True, 155 | verbose=1, 156 | ) 157 | 158 | tensorboard_callback = TensorboardCallback( 159 | info_keywords=list(config["weighted_reward_keys"].keys()) 160 | ) 161 | 162 | # Define trainer 163 | trainer = MyoTrainer( 164 | algo="recurrent_ppo", 165 | envs=envs, 166 | env_config=config, 167 | load_model_path=args.model_path, 168 | log_dir=TENSORBOARD_LOG, 169 | model_config=model_config, 170 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 171 | timesteps=20_000_000, 172 | ) 173 | 174 | # Train agent 175 | trainer.train(total_timesteps=trainer.timesteps) 176 | trainer.save() 177 | 
-------------------------------------------------------------------------------- /src/main_pose_elbow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=-1.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | parser.add_argument('--alpha', type=float, default=1.0, 41 | help='Weight of the uncorrelated noise (only for lattice)') 42 | 43 | args = parser.parse_args() 44 | 45 | # define constants 46 | ENV_NAME = "CustomMyoElbowPoseRandom" 47 | 48 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 49 | 50 | if args.model_path is not None: 51 | model_name = args.model_path.split("/")[-2] 52 | else: 53 | model_name = None 54 | 55 | TENSORBOARD_LOG = ( 56 | os.path.join(ROOT_DIR, "output", "training", now) 57 | + f"_elbow_pose_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_alpha_{args.alpha}_recurrent_ppo_seed_{args.seed}_resume_{model_name}" 58 | ) 59 | 60 | # Reward structure and task parameters: 61 | config = { 62 | "weighted_reward_keys": { 63 | "pose": 1, 64 | "bonus": 0, 65 | "penalty": 1, 66 | "act_reg": 0, 67 | "solved": 1, 68 | "done": 0, 69 | "sparse": 0, 70 | }, 71 | "reset_type": "init", 72 | "sds_distance": None, 73 | "weight_bodyname": None, 74 | "weight_range": None, 75 | "target_distance": 1, 76 | } 77 | 78 | max_episode_steps = 100 79 | 80 | model_config = dict( 81 | policy=LatticeRecurrentActorCriticPolicy, 82 | device=args.device, 83 | batch_size=32, 84 | n_steps=128, 85 | learning_rate=2.55673e-05, 86 | ent_coef=3.62109e-06, 87 | clip_range=0.3, 88 | gamma=0.99, 89 | gae_lambda=0.9, 90 | max_grad_norm=0.7, 91 | vf_coef=0.835671, 92 | n_epochs=10, 93 | use_sde=args.use_sde, 94 | 
sde_sample_freq=args.freq, 95 | policy_kwargs=dict( 96 | use_lattice=args.use_lattice, 97 | use_expln=True, 98 | ortho_init=False, 99 | log_std_init=args.log_std_init, 100 | activation_fn=nn.ReLU, 101 | net_arch=[dict(pi=[256, 256], vf=[256, 256])], 102 | std_clip=(1e-3, 10), 103 | expln_eps=1e-6, 104 | full_std=False, 105 | std_reg=args.std_reg, 106 | alpha=args.alpha 107 | ), 108 | ) 109 | 110 | # Function that creates and monitors vectorized environments: 111 | def make_parallel_envs(env_config, num_env, start_index=0): 112 | def make_env(_): 113 | def _thunk(): 114 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 115 | env.seed(args.seed) 116 | env._max_episode_steps = max_episode_steps 117 | env = Monitor(env, TENSORBOARD_LOG) 118 | return env 119 | 120 | return _thunk 121 | 122 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 123 | 124 | 125 | if __name__ == "__main__": 126 | # ensure tensorboard log directory exists and copy this file to track 127 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 128 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 129 | 130 | # Create and wrap the training and evaluations environments 131 | envs = make_parallel_envs(config, args.num_envs) 132 | 133 | if args.env_path is not None: 134 | envs = VecNormalize.load(args.env_path, envs) 135 | else: 136 | envs = VecNormalize(envs) 137 | 138 | # Define callbacks for evaluation and saving the agent 139 | eval_callback = EvalCallback( 140 | eval_env=envs, 141 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 142 | n_eval_episodes=10, 143 | best_model_save_path=TENSORBOARD_LOG, 144 | log_path=TENSORBOARD_LOG, 145 | eval_freq=10_000, 146 | deterministic=True, 147 | render=False, 148 | verbose=1, 149 | ) 150 | 151 | checkpoint_callback = CheckpointCallback( 152 | save_freq=25_000, 153 | save_path=TENSORBOARD_LOG, 154 | save_vecnormalize=True, 155 | verbose=1, 156 | ) 157 | 158 | tensorboard_callback = TensorboardCallback( 159 | info_keywords=( 160 | "pose", 161 | "bonus", 162 | "penalty", 163 | "act_reg", 164 | "done", 165 | "solved", 166 | "sparse", 167 | ) 168 | ) 169 | 170 | # Define trainer 171 | trainer = MyoTrainer( 172 | algo="recurrent_ppo", 173 | envs=envs, 174 | env_config=config, 175 | load_model_path=args.model_path, 176 | log_dir=TENSORBOARD_LOG, 177 | model_config=model_config, 178 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 179 | timesteps=600_000, 180 | ) 181 | 182 | # Train agent 183 | trainer.train(total_timesteps=trainer.timesteps) 184 | trainer.save() -------------------------------------------------------------------------------- /src/main_pose_finger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main 
script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | 41 | args = parser.parse_args() 42 | 43 | # define constants 44 | ENV_NAME = "CustomMyoFingerPoseRandom" 45 | 46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 47 | 48 | if args.model_path is not None: 49 | model_name = args.model_path.split("/")[-2] 50 | else: 51 | model_name = None 52 | 53 | TENSORBOARD_LOG = ( 54 | os.path.join(ROOT_DIR, "output", "training", now) 55 | + f"_finger_pose_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}" 56 | ) 57 | 58 | # Reward structure and task parameters: 59 | config = { 60 | "weighted_reward_keys": { 61 | "pose": 1, 62 | "bonus": 0, 63 | "penalty": 1, 64 | "act_reg": 0, 65 | "solved": 1, 66 | "done": 0, 67 | "sparse": 0, 68 | }, 69 | "seed": args.seed, 70 | "reset_type": "init", 71 | "sds_distance": None, 72 | "weight_bodyname": None, 73 | "weight_range": None, 74 | "target_distance": 0.5, 75 | } 76 | 77 | max_episode_steps = 100 78 | 79 | model_config = dict( 80 | policy=LatticeRecurrentActorCriticPolicy, 81 | device=args.device, 82 | batch_size=32, 83 | n_steps=128, 84 | learning_rate=2.55673e-05, 85 | ent_coef=3.62109e-06, 86 | clip_range=0.3, 87 | gamma=0.99, 88 | gae_lambda=0.9, 89 | max_grad_norm=0.7, 90 | vf_coef=0.835671, 91 | n_epochs=10, 92 | use_sde=args.use_sde, 93 | sde_sample_freq=args.freq, 94 | policy_kwargs=dict( 95 | use_lattice=args.use_lattice, 96 | use_expln=True, 97 | ortho_init=False, 98 | log_std_init=args.log_std_init, 99 | activation_fn=nn.ReLU, 100 | net_arch=[dict(pi=[256, 256], vf=[256, 256])], 101 | std_clip=(1e-3, 10), 102 | expln_eps=1e-6, 103 | full_std=False, 104 | std_reg=args.std_reg 105 | ), 106 | ) 107 | 108 | # Function that creates and monitors vectorized environments: 109 | def make_parallel_envs(env_config, num_env, start_index=0): 110 | def make_env(_): 111 | def _thunk(): 112 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 113 | env.seed(args.seed) 114 | env._max_episode_steps = max_episode_steps 115 | env = Monitor(env, TENSORBOARD_LOG) 116 | return env 117 | 118 | return _thunk 119 | 120 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 121 | 122 | 123 | if __name__ == "__main__": 124 | # ensure tensorboard log directory exists and copy this file to track 125 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 126 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 127 | 128 | # Create 
and wrap the training and evaluations environments 129 | envs = make_parallel_envs(config, args.num_envs) 130 | 131 | if args.env_path is not None: 132 | envs = VecNormalize.load(args.env_path, envs) 133 | else: 134 | envs = VecNormalize(envs) 135 | 136 | # Define callbacks for evaluation and saving the agent 137 | eval_callback = EvalCallback( 138 | eval_env=envs, 139 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 140 | n_eval_episodes=10, 141 | best_model_save_path=TENSORBOARD_LOG, 142 | log_path=TENSORBOARD_LOG, 143 | eval_freq=10_000, 144 | deterministic=True, 145 | render=False, 146 | verbose=1, 147 | ) 148 | 149 | checkpoint_callback = CheckpointCallback( 150 | save_freq=25_000, 151 | save_path=TENSORBOARD_LOG, 152 | save_vecnormalize=True, 153 | verbose=1, 154 | ) 155 | 156 | tensorboard_callback = TensorboardCallback( 157 | info_keywords=( 158 | "pose", 159 | "bonus", 160 | "penalty", 161 | "act_reg", 162 | "done", 163 | "solved", 164 | "sparse", 165 | ) 166 | ) 167 | 168 | # Define trainer 169 | trainer = MyoTrainer( 170 | algo="recurrent_ppo", 171 | envs=envs, 172 | env_config=config, 173 | load_model_path=args.model_path, 174 | log_dir=TENSORBOARD_LOG, 175 | model_config=model_config, 176 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 177 | timesteps=8_000_000, 178 | ) 179 | 180 | # Train agent 181 | trainer.train(total_timesteps=trainer.timesteps) 182 | trainer.save() 183 | -------------------------------------------------------------------------------- /src/main_pose_hand.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | args = parser.parse_args() 41 | 42 | # define 
constants 43 | ENV_NAME = "CustomMyoHandPoseRandom" 44 | 45 | if args.model_path is not None: 46 | model_name = args.model_path.split("/")[-2] 47 | else: 48 | model_name = None 49 | 50 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 51 | TENSORBOARD_LOG = ( 52 | os.path.join(ROOT_DIR, "output", "training", now) 53 | + f"_hand_pose_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}" 54 | ) 55 | 56 | # Reward structure and task parameters: 57 | config = { 58 | "weighted_reward_keys": { 59 | "pose": 1, 60 | "bonus": 0, 61 | "penalty": 1, 62 | "act_reg": 0, 63 | "solved": 1, 64 | "done": 0, 65 | "sparse": 0, 66 | }, 67 | "seed": args.seed, 68 | "reset_type": "init", 69 | "sds_distance": None, 70 | "weight_bodyname": None, 71 | "weight_range": None, 72 | "target_distance": 0.5, 73 | } 74 | 75 | max_episode_steps = 100 # default: 100 76 | 77 | model_config = dict( 78 | policy=LatticeRecurrentActorCriticPolicy, 79 | device=args.device, 80 | batch_size=32, 81 | n_steps=128, 82 | learning_rate=2.55673e-05, 83 | ent_coef=3.62109e-06, 84 | clip_range=0.3, 85 | gamma=0.99, 86 | gae_lambda=0.9, 87 | max_grad_norm=0.7, 88 | vf_coef=0.835671, 89 | n_epochs=10, 90 | use_sde=args.use_sde, 91 | sde_sample_freq=args.freq, 92 | policy_kwargs=dict( 93 | use_lattice=args.use_lattice, 94 | use_expln=True, 95 | ortho_init=False, 96 | log_std_init=args.log_std_init, 97 | activation_fn=nn.ReLU, 98 | net_arch=[dict(pi=[256, 256], vf=[256, 256])], 99 | std_clip=(1e-3, 10), 100 | expln_eps=1e-6, 101 | full_std=False, 102 | std_reg=args.std_reg 103 | ), 104 | ) 105 | 106 | # Function that creates and monitors vectorized environments: 107 | def make_parallel_envs(env_config, num_env, start_index=0): 108 | def make_env(_): 109 | def _thunk(): 110 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 111 | env.seed(args.seed) 112 | env._max_episode_steps = max_episode_steps 113 | env = Monitor(env, TENSORBOARD_LOG) 114 | return env 115 | 116 | return _thunk 117 | 118 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 119 | 120 | 121 | if __name__ == "__main__": 122 | # ensure tensorboard log directory exists and copy this file to track 123 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 124 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 125 | 126 | # Create and wrap the training and evaluations environments 127 | envs = make_parallel_envs(config, args.num_envs) 128 | 129 | if args.env_path is not None: 130 | envs = VecNormalize.load(args.env_path, envs) 131 | else: 132 | envs = VecNormalize(envs) 133 | 134 | # Define callbacks for evaluation and saving the agent 135 | eval_callback = EvalCallback( 136 | eval_env=envs, 137 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 138 | n_eval_episodes=10, 139 | best_model_save_path=TENSORBOARD_LOG, 140 | log_path=TENSORBOARD_LOG, 141 | eval_freq=10_000, 142 | deterministic=True, 143 | render=False, 144 | verbose=1, 145 | ) 146 | 147 | checkpoint_callback = CheckpointCallback( 148 | save_freq=25_000, 149 | save_path=TENSORBOARD_LOG, 150 | save_vecnormalize=True, 151 | verbose=1, 152 | ) 153 | 154 | tensorboard_callback = TensorboardCallback( 155 | info_keywords=( 156 | "pose", 157 | "bonus", 158 | "penalty", 159 | "act_reg", 160 | "done", 161 | "solved", 162 | "sparse", 163 | ) 164 | ) 165 | 166 | # Define trainer 167 | trainer = MyoTrainer( 168 | algo="recurrent_ppo", 169 | envs=envs, 170 | 
env_config=config, 171 | load_model_path=args.model_path, 172 | log_dir=TENSORBOARD_LOG, 173 | model_config=model_config, 174 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 175 | timesteps=20_000_000, 176 | ) 177 | 178 | # Train agent 179 | trainer.train(total_timesteps=trainer.timesteps) 180 | trainer.save() 181 | -------------------------------------------------------------------------------- /src/main_reach_finger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | 41 | args = parser.parse_args() 42 | 43 | # define constants 44 | ENV_NAME = "MyoFingerReachRandom" 45 | 46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 47 | 48 | if args.model_path is not None: 49 | model_name = args.model_path.split("/")[-2] 50 | else: 51 | model_name = None 52 | 53 | TENSORBOARD_LOG = ( 54 | os.path.join(ROOT_DIR, "output", "training", now) 55 | + f"_finger_reach_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}" 56 | ) 57 | 58 | # Reward structure and task parameters: 59 | config = { 60 | "seed": args.seed, 61 | } 62 | 63 | max_episode_steps = 100 64 | 65 | model_config = dict( 66 | policy=LatticeRecurrentActorCriticPolicy, 67 | device=args.device, 68 | batch_size=32, 69 | n_steps=128, 70 | learning_rate=2.55673e-05, 71 | ent_coef=3.62109e-06, 72 | clip_range=0.3, 73 | gamma=0.99, 74 | gae_lambda=0.9, 75 | max_grad_norm=0.7, 76 | vf_coef=0.835671, 77 | n_epochs=10, 78 | use_sde=args.use_sde, 79 | sde_sample_freq=args.freq, 80 | policy_kwargs=dict( 81 | use_lattice=args.use_lattice, 82 | 
use_expln=True, 83 | ortho_init=False, 84 | log_std_init=args.log_std_init, 85 | activation_fn=nn.ReLU, 86 | net_arch=[dict(pi=[256, 256], vf=[256, 256])], 87 | std_clip=(1e-3, 10), 88 | expln_eps=1e-6, 89 | full_std=False, 90 | std_reg=args.std_reg 91 | ), 92 | ) 93 | 94 | # Function that creates and monitors vectorized environments: 95 | def make_parallel_envs(env_config, num_env, start_index=0): 96 | def make_env(_): 97 | def _thunk(): 98 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 99 | env.seed(args.seed) 100 | env._max_episode_steps = max_episode_steps 101 | env = Monitor(env, TENSORBOARD_LOG) 102 | return env 103 | 104 | return _thunk 105 | 106 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 107 | 108 | 109 | if __name__ == "__main__": 110 | # ensure tensorboard log directory exists and copy this file to track 111 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 112 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 113 | 114 | # Create and wrap the training and evaluations environments 115 | envs = make_parallel_envs(config, args.num_envs) 116 | 117 | if args.env_path is not None: 118 | envs = VecNormalize.load(args.env_path, envs) 119 | else: 120 | envs = VecNormalize(envs) 121 | 122 | # Define callbacks for evaluation and saving the agent 123 | eval_callback = EvalCallback( 124 | eval_env=envs, 125 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 126 | n_eval_episodes=10, 127 | best_model_save_path=TENSORBOARD_LOG, 128 | log_path=TENSORBOARD_LOG, 129 | eval_freq=10_000, 130 | deterministic=True, 131 | render=False, 132 | verbose=1, 133 | ) 134 | 135 | checkpoint_callback = CheckpointCallback( 136 | save_freq=25_000, 137 | save_path=TENSORBOARD_LOG, 138 | save_vecnormalize=True, 139 | verbose=1, 140 | ) 141 | 142 | tensorboard_callback = TensorboardCallback( 143 | info_keywords=( 144 | "rwd_dense", 145 | "rwd_sparse", 146 | "solved", 147 | "done", 148 | ) 149 | ) 150 | 151 | # Define trainer 152 | trainer = MyoTrainer( 153 | algo="recurrent_ppo", 154 | envs=envs, 155 | env_config=config, 156 | load_model_path=args.model_path, 157 | log_dir=TENSORBOARD_LOG, 158 | model_config=model_config, 159 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 160 | timesteps=15_000_000, 161 | ) 162 | 163 | # Train agent 164 | trainer.train(total_timesteps=trainer.timesteps) 165 | trainer.save() 166 | -------------------------------------------------------------------------------- /src/main_reach_hand.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | 
parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | 41 | args = parser.parse_args() 42 | 43 | # define constants 44 | ENV_NAME = "MyoHandReachRandom" 45 | 46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 47 | 48 | if args.model_path is not None: 49 | model_name = args.model_path.split("/")[-2] 50 | else: 51 | model_name = None 52 | 53 | TENSORBOARD_LOG = ( 54 | os.path.join(ROOT_DIR, "output", "training", now) 55 | + f"_hand_reach_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}" 56 | ) 57 | 58 | # Reward structure and task parameters: 59 | config = { 60 | "seed": args.seed, 61 | } 62 | 63 | max_episode_steps = 100 64 | 65 | model_config = dict( 66 | policy=LatticeRecurrentActorCriticPolicy, 67 | device=args.device, 68 | batch_size=32, 69 | n_steps=128, 70 | learning_rate=2.55673e-05, 71 | ent_coef=3.62109e-06, 72 | clip_range=0.3, 73 | gamma=0.99, 74 | gae_lambda=0.9, 75 | max_grad_norm=0.7, 76 | vf_coef=0.835671, 77 | n_epochs=10, 78 | use_sde=args.use_sde, 79 | sde_sample_freq=args.freq, 80 | policy_kwargs=dict( 81 | use_lattice=args.use_lattice, 82 | use_expln=True, 83 | ortho_init=False, 84 | log_std_init=args.log_std_init, 85 | activation_fn=nn.ReLU, 86 | net_arch=[dict(pi=[256, 256], vf=[256, 256])], 87 | std_clip=(1e-3, 10), 88 | expln_eps=1e-6, 89 | full_std=False, 90 | std_reg=args.std_reg 91 | ), 92 | ) 93 | 94 | # Function that creates and monitors vectorized environments: 95 | def make_parallel_envs(env_config, num_env, start_index=0): 96 | def make_env(_): 97 | def _thunk(): 98 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 99 | env.seed(args.seed) 100 | env._max_episode_steps = max_episode_steps 101 | env = Monitor(env, TENSORBOARD_LOG) 102 | return env 103 | 104 | return _thunk 105 | 106 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 107 | 108 | 109 | if __name__ == "__main__": 110 | # ensure tensorboard log directory exists and copy this file to track 111 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 112 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 113 | 114 | # Create and wrap the training and evaluations environments 115 | envs = make_parallel_envs(config, args.num_envs) 116 | 117 | if args.env_path is not None: 118 | envs = VecNormalize.load(args.env_path, envs) 119 | else: 120 | envs = VecNormalize(envs) 121 | 122 | # Define callbacks for evaluation and saving the agent 123 | eval_callback = EvalCallback( 124 | eval_env=envs, 125 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 126 | 
n_eval_episodes=10, 127 | best_model_save_path=TENSORBOARD_LOG, 128 | log_path=TENSORBOARD_LOG, 129 | eval_freq=10_000, 130 | deterministic=True, 131 | render=False, 132 | verbose=1, 133 | ) 134 | 135 | checkpoint_callback = CheckpointCallback( 136 | save_freq=25_000, 137 | save_path=TENSORBOARD_LOG, 138 | save_vecnormalize=True, 139 | verbose=1, 140 | ) 141 | 142 | tensorboard_callback = TensorboardCallback( 143 | info_keywords=( 144 | "rwd_dense", 145 | "rwd_sparse", 146 | "solved", 147 | "done", 148 | ) 149 | ) 150 | 151 | # Define trainer 152 | trainer = MyoTrainer( 153 | algo="recurrent_ppo", 154 | envs=envs, 155 | env_config=config, 156 | load_model_path=args.model_path, 157 | log_dir=TENSORBOARD_LOG, 158 | model_config=model_config, 159 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 160 | timesteps=15_000_000, 161 | ) 162 | 163 | # Train agent 164 | trainer.train(total_timesteps=trainer.timesteps) 165 | trainer.save() -------------------------------------------------------------------------------- /src/main_reorient.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Main script to train an agent') 19 | 20 | parser.add_argument('--seed', type=int, default=0, 21 | help='Seed for random number generator') 22 | parser.add_argument('--freq', type=int, default=1, 23 | help='SDE sample frequency') 24 | parser.add_argument('--use_sde', action='store_true', default=False, 25 | help='Flag to use SDE') 26 | parser.add_argument('--use_lattice', action='store_true', default=False, 27 | help='Flag to use lattice') 28 | parser.add_argument('--log_std_init', type=float, default=0.0, 29 | help='Initial log standard deviation') 30 | parser.add_argument('--env_path', type=str, 31 | help='Path to environment file') 32 | parser.add_argument('--model_path', type=str, 33 | help='Path to model file') 34 | parser.add_argument('--num_envs', type=int, default=16, 35 | help='Number of parallel environments') 36 | parser.add_argument('--device', type=str, default="cuda", 37 | help='Device, cuda or cpu') 38 | parser.add_argument('--std_reg', type=float, default=0.0, 39 | help='Additional independent std for the multivariate gaussian (only for lattice)') 40 | 41 | args = parser.parse_args() 42 | 43 | # define constants 44 | ENV_NAME = "CustomMyoReorientP2" 45 | 46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 47 | 48 | if args.model_path is not None: 49 | model_name = args.model_path.split("/")[-2] 50 | else: 51 | model_name = None 52 | 53 | TENSORBOARD_LOG = ( 54 | os.path.join(ROOT_DIR, "output", "training", now) 55 | + 
f"_reorient_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}" 56 | ) 57 | 58 | # Reward structure and task parameters: 59 | config = { 60 | "seed": args.seed, 61 | "weighted_reward_keys": { 62 | "pos_dist": 2, 63 | "rot_dist": 0.2, 64 | "pos_dist_diff": 1e2, 65 | "rot_dist_diff": 1e1, 66 | "alive": 1, 67 | "act_reg": 0, 68 | "solved": 2, 69 | "done": 0, 70 | "sparse": 0, 71 | }, 72 | "goal_pos": (-0.0, 0.0), 73 | "goal_rot": (-3.14 * 0.25, 3.14 * 0.25), 74 | "obj_size_change": 0, 75 | "obj_friction_change": (0, 0, 0), 76 | "enable_rsi": False, 77 | "rsi_distance_pos": None, 78 | "rsi_distance_rot": None, 79 | "goal_rot_x": None, 80 | "goal_rot_y": None, 81 | "goal_rot_z": None, 82 | "guided_trajectory_steps": 0, 83 | } 84 | 85 | max_episode_steps = 150 86 | 87 | model_config = dict( 88 | policy=LatticeRecurrentActorCriticPolicy, 89 | device=args.device, 90 | batch_size=32, 91 | n_steps=128, 92 | learning_rate=2.55673e-05, 93 | ent_coef=3.62109e-06, 94 | clip_range=0.3, 95 | gamma=0.99, 96 | gae_lambda=0.9, 97 | max_grad_norm=0.7, 98 | vf_coef=0.835671, 99 | n_epochs=10, 100 | use_sde=args.use_sde, 101 | sde_sample_freq=args.freq, 102 | policy_kwargs=dict( 103 | use_lattice=args.use_lattice, 104 | use_expln=True, 105 | ortho_init=False, 106 | log_std_init=args.log_std_init, 107 | activation_fn=nn.ReLU, 108 | net_arch=[dict(pi=[256, 256], vf=[256, 256])], 109 | std_clip=(1e-3, 10), 110 | expln_eps=1e-6, 111 | full_std=False, 112 | std_reg=args.std_reg 113 | ), 114 | ) 115 | 116 | # Function that creates and monitors vectorized environments: 117 | def make_parallel_envs(env_config, num_env, start_index=0): 118 | def make_env(_): 119 | def _thunk(): 120 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 121 | env.seed(args.seed) 122 | env._max_episode_steps = max_episode_steps 123 | env = Monitor(env, TENSORBOARD_LOG) 124 | return env 125 | 126 | return _thunk 127 | 128 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 129 | 130 | 131 | if __name__ == "__main__": 132 | # ensure tensorboard log directory exists and copy this file to track 133 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 134 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 135 | 136 | # Create and wrap the training and evaluations environments 137 | envs = make_parallel_envs(config, args.num_envs) 138 | 139 | if args.env_path is not None: 140 | envs = VecNormalize.load(args.env_path, envs) 141 | else: 142 | envs = VecNormalize(envs) 143 | 144 | # Define callbacks for evaluation and saving the agent 145 | eval_callback = EvalCallback( 146 | eval_env=envs, 147 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 148 | n_eval_episodes=10, 149 | best_model_save_path=TENSORBOARD_LOG, 150 | log_path=TENSORBOARD_LOG, 151 | eval_freq=10_000, 152 | deterministic=True, 153 | render=False, 154 | verbose=1, 155 | ) 156 | 157 | checkpoint_callback = CheckpointCallback( 158 | save_freq=25_000, 159 | save_path=TENSORBOARD_LOG, 160 | save_vecnormalize=True, 161 | verbose=1, 162 | ) 163 | 164 | tensorboard_callback = TensorboardCallback( 165 | info_keywords=( 166 | "pos_dist", 167 | "rot_dist", 168 | "pos_dist_diff", 169 | "rot_dist_diff", 170 | "act_reg", 171 | "alive", 172 | "solved", 173 | ) 174 | ) 175 | 176 | # Define trainer 177 | trainer = MyoTrainer( 178 | algo="recurrent_ppo", 179 | envs=envs, 180 | env_config=config, 181 | load_model_path=args.model_path, 182 
| log_dir=TENSORBOARD_LOG, 183 | model_config=model_config, 184 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 185 | timesteps=20_000_000, 186 | ) 187 | 188 | # Train agent 189 | trainer.train(total_timesteps=trainer.timesteps) 190 | trainer.save() 191 | -------------------------------------------------------------------------------- /src/main_walker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import torch.nn as nn 5 | from datetime import datetime 6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 7 | from stable_baselines3.common.monitor import Monitor 8 | from stable_baselines3.common.vec_env import VecNormalize 9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv 10 | from datetime import datetime 11 | from definitions import ROOT_DIR 12 | from envs.environment_factory import EnvironmentFactory 13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback 14 | from train.trainer import MyoTrainer 15 | from models.sac.policies import LatticeSACPolicy 16 | 17 | 18 | parser = argparse.ArgumentParser(description="Main script to train an agent") 19 | 20 | parser.add_argument("--seed", type=int, default=0, 21 | help="Seed for random number generator") 22 | parser.add_argument("--freq", type=int, default=1, 23 | help="SDE sample frequency") 24 | parser.add_argument("--use_sde", action="store_true", 25 | default=False, help="Flag to use SDE") 26 | parser.add_argument("--use_lattice", action="store_true", 27 | default=False, help="Flag to use lattice") 28 | parser.add_argument("--log_std_init", type=float, default=0.0, 29 | help="Initial log standard deviation") 30 | parser.add_argument("--env_path", type=str, 31 | help="Path to environment file") 32 | parser.add_argument("--model_path", type=str, 33 | help="Path to model file") 34 | parser.add_argument("--num_envs", type=int, default=16, 35 | help="Number of parallel environments") 36 | parser.add_argument("--device", type=str, default="cuda", 37 | help="Device, cuda or cpu") 38 | parser.add_argument("--std_reg", type=float, default=0.0, 39 | help="Additional independent std for the multivariate gaussian (only for lattice)") 40 | args = parser.parse_args() 41 | 42 | # define constants 43 | ENV_NAME = "WalkerBulletEnv" 44 | 45 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S") 46 | 47 | if args.model_path is not None: 48 | model_name = args.model_path.split("/")[-2] 49 | else: 50 | model_name = None 51 | 52 | TENSORBOARD_LOG = ( 53 | os.path.join(ROOT_DIR, "output", "training", now) 54 | + f"_walker_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}" 55 | ) 56 | 57 | # Reward structure and task parameters: 58 | config = {} 59 | 60 | max_episode_steps = 1000 61 | 62 | model_config = dict( 63 | policy=LatticeSACPolicy, 64 | device=args.device, 65 | learning_rate=3e-4, 66 | buffer_size=300_000, 67 | learning_starts=10000, 68 | batch_size=256, 69 | tau=0.02, 70 | gamma=0.98, 71 | train_freq=(8, "step"), 72 | gradient_steps=8, 73 | action_noise=None, 74 | replay_buffer_class=None, 75 | ent_coef="auto", 76 | target_update_interval=1, 77 | target_entropy="auto", 78 | seed=args.seed, 79 | use_sde=args.use_sde, 80 | sde_sample_freq=args.freq, # number of steps 81 | policy_kwargs=dict( 82 | use_lattice=args.use_lattice, 83 | use_expln=True, 84 | 
log_std_init=args.log_std_init, 85 | activation_fn=nn.GELU, 86 | net_arch=dict(pi=[400, 300], qf=[400, 300]), 87 | std_clip=(1e-3, 10), 88 | expln_eps=1e-6, 89 | clip_mean=2.0, 90 | std_reg=args.std_reg, 91 | ), 92 | ) 93 | 94 | # Function that creates and monitors vectorized environments: 95 | def make_parallel_envs(env_config, num_env, start_index=0): 96 | def make_env(_): 97 | def _thunk(): 98 | env = EnvironmentFactory.create(ENV_NAME, **env_config) 99 | env.seed(args.seed) 100 | env._max_episode_steps = max_episode_steps 101 | env = Monitor(env, TENSORBOARD_LOG) 102 | return env 103 | 104 | return _thunk 105 | 106 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 107 | 108 | 109 | if __name__ == "__main__": 110 | # ensure tensorboard log directory exists and copy this file to track 111 | os.makedirs(TENSORBOARD_LOG, exist_ok=True) 112 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG) 113 | 114 | # Create and wrap the training and evaluations environments 115 | envs = make_parallel_envs(config, args.num_envs) 116 | 117 | if args.env_path is not None: 118 | envs = VecNormalize.load(args.env_path, envs) 119 | else: 120 | envs = VecNormalize(envs) 121 | 122 | # Define callbacks for evaluation and saving the agent 123 | eval_callback = EvalCallback( 124 | eval_env=envs, 125 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0), 126 | n_eval_episodes=10, 127 | best_model_save_path=TENSORBOARD_LOG, 128 | log_path=TENSORBOARD_LOG, 129 | eval_freq=10_000, 130 | deterministic=True, 131 | render=False, 132 | verbose=1, 133 | ) 134 | 135 | checkpoint_callback = CheckpointCallback( 136 | save_freq=25_000, 137 | save_path=TENSORBOARD_LOG, 138 | save_vecnormalize=True, 139 | verbose=1, 140 | ) 141 | 142 | tensorboard_callback = TensorboardCallback(info_keywords=()) 143 | 144 | # Define trainer 145 | trainer = MyoTrainer( 146 | algo="sac", 147 | envs=envs, 148 | env_config=config, 149 | load_model_path=args.model_path, 150 | log_dir=TENSORBOARD_LOG, 151 | model_config=model_config, 152 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback], 153 | timesteps=15_000_000, 154 | ) 155 | 156 | # Train agent 157 | trainer.train(total_timesteps=trainer.timesteps) 158 | trainer.save() 159 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/src/metrics/__init__.py -------------------------------------------------------------------------------- /src/metrics/custom_callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from stable_baselines3.common.callbacks import BaseCallback 4 | 5 | 6 | class EnvDumpCallback(BaseCallback): 7 | def __init__(self, save_path, verbose=0): 8 | super().__init__(verbose=verbose) 9 | self.save_path = save_path 10 | 11 | def _on_step(self): 12 | env_path = os.path.join(self.save_path, "training_env.pkl") 13 | if self.verbose > 0: 14 | print("Saving the training environment to path ", env_path) 15 | self.training_env.save(env_path) 16 | return True 17 | 18 | 19 | class TensorboardCallback(BaseCallback): 20 | def __init__(self, info_keywords, verbose=0): 21 | super().__init__(verbose=verbose) 22 | self.info_keywords = info_keywords 23 | self.rollout_info = {} 24 | 25 | def _on_rollout_start(self): 26 | 
self.rollout_info = {key: [] for key in self.info_keywords} 27 | 28 | def _on_step(self): 29 | for key in self.info_keywords: 30 | vals = [info[key] for info in self.locals["infos"]] 31 | self.rollout_info[key].extend(vals) 32 | return True 33 | 34 | def _on_rollout_end(self): 35 | for key in self.info_keywords: 36 | self.logger.record("rollout/" + key, np.mean(self.rollout_info[key])) -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Optional 4 | from torch import nn 5 | from torch.distributions import MultivariateNormal 6 | from typing import Tuple 7 | from torch.distributions import Normal 8 | from stable_baselines3.common.distributions import DiagGaussianDistribution, TanhBijector, StateDependentNoiseDistribution 9 | 10 | 11 | class LatticeNoiseDistribution(DiagGaussianDistribution): 12 | """ 13 | Like Lattice noise distribution, non-state-dependent. Does not allow time correlation, but 14 | it is more efficient. 15 | 16 | :param action_dim: Dimension of the action space. 17 | """ 18 | 19 | def __init__(self, action_dim: int): 20 | super().__init__(action_dim=action_dim) 21 | 22 | def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0, state_dependent: bool = False) -> Tuple[nn.Module, nn.Parameter]: 23 | self.mean_actions = nn.Linear(latent_dim, self.action_dim) 24 | self.std_init = torch.tensor(log_std_init).exp() 25 | if state_dependent: 26 | log_std = nn.Linear(latent_dim, self.action_dim + latent_dim) 27 | else: 28 | log_std = nn.Parameter(torch.zeros(self.action_dim + latent_dim), requires_grad=True) 29 | return self.mean_actions, log_std 30 | 31 | def proba_distribution(self, mean_actions: torch.Tensor, log_std: torch.Tensor) -> "LatticeNoiseDistribution": 32 | """ 33 | Create the distribution given its parameters (mean, std) 34 | 35 | :param mean_actions: 36 | :param log_std: 37 | :return: 38 | """ 39 | std = log_std.exp() * self.std_init 40 | action_variance = std[..., : self.action_dim] ** 2 41 | latent_variance = std[..., self.action_dim :] ** 2 42 | 43 | sigma_mat = (self.mean_actions.weight * latent_variance[..., None, :]).matmul(self.mean_actions.weight.T) 44 | sigma_mat[..., range(self.action_dim), range(self.action_dim)] += action_variance 45 | self.distribution = MultivariateNormal(mean_actions, sigma_mat) 46 | return self 47 | 48 | def log_prob(self, actions: torch.Tensor) -> torch.Tensor: 49 | return self.distribution.log_prob(actions) 50 | 51 | def entropy(self) -> torch.Tensor: 52 | return self.distribution.entropy() 53 | 54 | 55 | class SquashedLatticeNoiseDistribution(LatticeNoiseDistribution): 56 | """ 57 | Lattice noise distribution, followed by a squashing function (tanh) to ensure bounds. 58 | 59 | :param action_dim: Dimension of the action space. 60 | :param epsilon: small value to avoid NaN due to numerical imprecision. 
61 | """ 62 | def __init__(self, action_dim: int, epsilon: float = 1e-6): 63 | super().__init__(action_dim) 64 | self.epsilon = epsilon 65 | self.gaussian_actions: Optional[torch.Tensor] = None 66 | 67 | def log_prob(self, actions: torch.Tensor, gaussian_actions: Optional[torch.Tensor] = None) -> torch.Tensor: 68 | # Inverse tanh 69 | # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x)) 70 | # We use numpy to avoid numerical instability 71 | if gaussian_actions is None: 72 | # It will be clipped to avoid NaN when inversing tanh 73 | gaussian_actions = TanhBijector.inverse(actions) 74 | 75 | # Log likelihood for a Gaussian distribution 76 | log_prob = super().log_prob(gaussian_actions) 77 | # Squash correction (from original SAC implementation) 78 | # this comes from the fact that tanh is bijective and differentiable 79 | log_prob -= torch.sum(torch.log(1 - actions**2 + self.epsilon), dim=1) 80 | return log_prob 81 | 82 | def entropy(self) -> Optional[torch.Tensor]: 83 | # No analytical form, 84 | # entropy needs to be estimated using -log_prob.mean() 85 | return None 86 | 87 | def sample(self) -> torch.Tensor: 88 | # Reparametrization trick to pass gradients 89 | self.gaussian_actions = super().sample() 90 | return torch.tanh(self.gaussian_actions) 91 | 92 | def mode(self) -> torch.Tensor: 93 | self.gaussian_actions = super().mode() 94 | # Squash the output 95 | return torch.tanh(self.gaussian_actions) 96 | 97 | def log_prob_from_params(self, mean_actions: torch.Tensor, log_std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 98 | action = self.actions_from_params(mean_actions, log_std) 99 | log_prob = self.log_prob(action, self.gaussian_actions) 100 | return action, log_prob 101 | 102 | 103 | class LatticeStateDependentNoiseDistribution(StateDependentNoiseDistribution): 104 | """ 105 | Distribution class of Lattice exploration. 106 | Paper: Latent Exploration for Reinforcement Learning https://arxiv.org/abs/2305.20065 107 | 108 | It creates correlated noise across actuators, with a covariance matrix induced by 109 | the network weights. Can improve exploration in high-dimensional systems. 110 | 111 | :param action_dim: Dimension of the action space. 112 | :param full_std: Whether to use (n_features x n_actions) parameters 113 | for the std instead of only (n_features,), defaults to True 114 | :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure 115 | a positive standard deviation (cf paper). It allows to keep variance 116 | above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. 
117 | Defaults to False 118 | :param squash_output: Whether to squash the output using a tanh function, 119 | this ensures bounds are satisfied, defaults to False 120 | :param learn_features: Whether to learn features for gSDE or not, defaults to False 121 | This will enable gradients to be backpropagated through the features, defaults to False 122 | :param epsilon: small value to avoid NaN due to numerical imprecision, defaults to 1e-6 123 | :param std_clip: clip range for the standard deviation, can be used to prevent extreme values, 124 | defaults to (1e-3, 1.0) 125 | :param std_reg: optional regularization to prevent collapsing to a deterministic policy, 126 | defaults to 0.0 127 | :param alpha: relative weight between action and latent noise, 0 removes the latent noise, 128 | defaults to 1 (equal weight) 129 | """ 130 | def __init__( 131 | self, 132 | action_dim: int, 133 | full_std: bool = True, 134 | use_expln: bool = False, 135 | squash_output: bool = False, 136 | learn_features: bool = False, 137 | epsilon: float = 1e-6, 138 | std_clip: Tuple[float, float] = (1e-3, 1.0), 139 | std_reg: float = 0.0, 140 | alpha: float = 1, 141 | ): 142 | super().__init__( 143 | action_dim=action_dim, 144 | full_std=full_std, 145 | use_expln=use_expln, 146 | squash_output=squash_output, 147 | epsilon=epsilon, 148 | learn_features=learn_features, 149 | ) 150 | self.min_std, self.max_std = std_clip 151 | self.std_reg = std_reg 152 | self.alpha = alpha 153 | 154 | def get_std(self, log_std: torch.Tensor) -> torch.Tensor: 155 | """ 156 | Get the standard deviation from the learned parameter 157 | (log of it by default). This ensures that the std is positive. 158 | 159 | :param log_std: 160 | :return: 161 | """ 162 | # Apply correction to remove scaling of action std as a function of the latent 163 | # dimension (see paper for details) 164 | log_std = log_std.clip(min=np.log(self.min_std), max=np.log(self.max_std)) 165 | log_std = log_std - 0.5 * np.log(self.latent_sde_dim) 166 | 167 | if self.use_expln: 168 | # From gSDE paper, it allows to keep variance 169 | # above zero and prevent it from growing too fast 170 | below_threshold = torch.exp(log_std) * (log_std <= 0) 171 | # Avoid NaN: zeros values that are below zero 172 | safe_log_std = log_std * (log_std > 0) + self.epsilon 173 | above_threshold = (torch.log1p(safe_log_std) + 1.0) * (log_std > 0) 174 | std = below_threshold + above_threshold 175 | else: 176 | # Use normal exponential 177 | std = torch.exp(log_std) 178 | 179 | if self.full_std: 180 | assert std.shape == ( 181 | self.latent_sde_dim, 182 | self.latent_sde_dim + self.action_dim, 183 | ) 184 | corr_std = std[:, : self.latent_sde_dim] 185 | ind_std = std[:, -self.action_dim :] 186 | else: 187 | # Reduce the number of parameters: 188 | assert std.shape == (self.latent_sde_dim, 2), std.shape 189 | corr_std = torch.ones(self.latent_sde_dim, self.latent_sde_dim).to(log_std.device) * std[:, 0:1] 190 | ind_std = torch.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std[:, 1:] 191 | return corr_std, ind_std 192 | 193 | def sample_weights(self, log_std: torch.Tensor, batch_size: int = 1) -> None: 194 | """ 195 | Sample weights for the noise exploration matrix, 196 | using a centered Gaussian distribution. 
197 | 198 | :param log_std: 199 | :param batch_size: 200 | """ 201 | corr_std, ind_std = self.get_std(log_std) 202 | self.corr_weights_dist = Normal(torch.zeros_like(corr_std), corr_std) 203 | self.ind_weights_dist = Normal(torch.zeros_like(ind_std), ind_std) 204 | 205 | # Reparametrization trick to pass gradients 206 | self.corr_exploration_mat = self.corr_weights_dist.rsample() 207 | self.ind_exploration_mat = self.ind_weights_dist.rsample() 208 | 209 | # Pre-compute matrices in case of parallel exploration 210 | self.corr_exploration_matrices = self.corr_weights_dist.rsample((batch_size,)) 211 | self.ind_exploration_matrices = self.ind_weights_dist.rsample((batch_size,)) 212 | 213 | def proba_distribution_net( 214 | self, 215 | latent_dim: int, 216 | log_std_init: float = 0, 217 | latent_sde_dim: Optional[int] = None, 218 | clip_mean: float = 0, 219 | ) -> Tuple[nn.Module, nn.Parameter]: 220 | """ 221 | Create the layers and parameter that represent the distribution: 222 | one output will be the deterministic action, the other parameter will be the 223 | standard deviation of the distribution that control the weights of the noise matrix, 224 | both for the action perturbation and the latent perturbation. 225 | 226 | :param latent_dim: Dimension of the last layer of the policy (before the action layer) 227 | :param log_std_init: Initial value for the log standard deviation 228 | :param latent_sde_dim: Dimension of the last layer of the features extractor 229 | for gSDE. By default, it is shared with the policy network. 230 | :param clip_mean: From SB3 implementation of SAC, add possibility to hard clip the 231 | mean of the actions. 232 | :return: 233 | """ 234 | # Note: we always consider that the noise is based on the features of the last 235 | # layer, so latent_sde_dim is the same as latent_dim 236 | self.mean_actions_net = nn.Linear(latent_dim, self.action_dim) 237 | if clip_mean > 0: 238 | self.clipped_mean_actions_net = nn.Sequential( 239 | self.mean_actions_net, 240 | nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean)) 241 | else: 242 | self.clipped_mean_actions_net = self.mean_actions_net 243 | self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim 244 | 245 | log_std = ( 246 | torch.ones(self.latent_sde_dim, self.latent_sde_dim + self.action_dim) 247 | if self.full_std 248 | else torch.ones(self.latent_sde_dim, 2) 249 | ) 250 | 251 | # Transform it into a parameter so it can be optimized 252 | log_std = nn.Parameter(log_std * log_std_init, requires_grad=True) 253 | # Sample an exploration matrix 254 | self.sample_weights(log_std) 255 | return self.clipped_mean_actions_net, log_std 256 | 257 | def proba_distribution( 258 | self, 259 | mean_actions: torch.Tensor, 260 | log_std: torch.Tensor, 261 | latent_sde: torch.Tensor, 262 | ) -> "LatticeNoiseDistribution": 263 | # Detach the last layer features because we do not want to update the noise generation 264 | # to influence the features of the policy 265 | self._latent_sde = latent_sde if self.learn_features else latent_sde.detach() 266 | corr_std, ind_std = self.get_std(log_std) 267 | latent_corr_variance = torch.mm(self._latent_sde**2, corr_std**2) # Variance of the hidden state 268 | latent_ind_variance = torch.mm(self._latent_sde**2, ind_std**2) + self.std_reg**2 # Variance of the action 269 | 270 | # First consider the correlated variance 271 | sigma_mat = self.alpha**2 * (self.mean_actions_net.weight * latent_corr_variance[:, None, :]).matmul( 272 | self.mean_actions_net.weight.T 273 | ) 274 | # Then 
the independent one, to be added to the diagonal 275 | sigma_mat[:, range(self.action_dim), range(self.action_dim)] += latent_ind_variance 276 | self.distribution = MultivariateNormal(loc=mean_actions, covariance_matrix=sigma_mat, validate_args=False) 277 | return self 278 | 279 | def log_prob(self, actions: torch.Tensor) -> torch.Tensor: 280 | if self.bijector is not None: 281 | gaussian_actions = self.bijector.inverse(actions) 282 | else: 283 | gaussian_actions = actions 284 | log_prob = self.distribution.log_prob(gaussian_actions) 285 | 286 | if self.bijector is not None: 287 | # Squash correction 288 | log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1) 289 | return log_prob 290 | 291 | def entropy(self) -> torch.Tensor: 292 | if self.bijector is not None: 293 | return None 294 | return self.distribution.entropy() 295 | 296 | def get_noise( 297 | self, 298 | latent_sde: torch.Tensor, 299 | exploration_mat: torch.Tensor, 300 | exploration_matrices: torch.Tensor, 301 | ) -> torch.Tensor: 302 | latent_sde = latent_sde if self.learn_features else latent_sde.detach() 303 | # Default case: only one exploration matrix 304 | if len(latent_sde) == 1 or len(latent_sde) != len(exploration_matrices): 305 | return torch.mm(latent_sde, exploration_mat) 306 | # Use batch matrix multiplication for efficient computation 307 | # (batch_size, n_features) -> (batch_size, 1, n_features) 308 | latent_sde = latent_sde.unsqueeze(dim=1) 309 | # (batch_size, 1, n_actions) 310 | noise = torch.bmm(latent_sde, exploration_matrices) 311 | return noise.squeeze(dim=1) 312 | 313 | def sample(self) -> torch.Tensor: 314 | latent_noise = self.alpha * self.get_noise(self._latent_sde, self.corr_exploration_mat, self.corr_exploration_matrices) 315 | action_noise = self.get_noise(self._latent_sde, self.ind_exploration_mat, self.ind_exploration_matrices) 316 | actions = self.clipped_mean_actions_net(self._latent_sde + latent_noise) + action_noise 317 | if self.bijector is not None: 318 | return self.bijector.forward(actions) 319 | return actions 320 | -------------------------------------------------------------------------------- /src/models/ppo/policies.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.common.preprocessing import get_action_dim 2 | from models.distributions import ( 3 | LatticeNoiseDistribution, 4 | LatticeStateDependentNoiseDistribution, 5 | ) 6 | from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy 7 | 8 | 9 | class LatticeRecurrentActorCriticPolicy(RecurrentActorCriticPolicy): 10 | def __init__( 11 | self, 12 | observation_space, 13 | action_space, 14 | lr_schedule, 15 | use_lattice=True, 16 | std_clip=(1e-3, 10), 17 | expln_eps=1e-6, 18 | std_reg=0, 19 | alpha=1, 20 | **kwargs 21 | ): 22 | super().__init__(observation_space, action_space, lr_schedule, **kwargs) 23 | if use_lattice: 24 | if self.use_sde: 25 | self.dist_kwargs.update( 26 | { 27 | "epsilon": expln_eps, 28 | "std_clip": std_clip, 29 | "std_reg": std_reg, 30 | "alpha": alpha, 31 | } 32 | ) 33 | self.action_dist = LatticeStateDependentNoiseDistribution( 34 | get_action_dim(self.action_space), **self.dist_kwargs 35 | ) 36 | else: 37 | self.action_dist = LatticeNoiseDistribution( 38 | get_action_dim(self.action_space) 39 | ) 40 | self._build(lr_schedule) 41 | -------------------------------------------------------------------------------- /src/models/sac/policies.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Type 2 | import torch 3 | from torch import nn 4 | from models.distributions import ( 5 | LatticeStateDependentNoiseDistribution, 6 | SquashedLatticeNoiseDistribution, 7 | ) 8 | from stable_baselines3.common.preprocessing import get_action_dim 9 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 10 | from stable_baselines3.sac.policies import SACPolicy, Actor 11 | 12 | 13 | class LatticeActor(Actor): 14 | def __init__( 15 | self, 16 | observation_space, 17 | action_space, 18 | net_arch, 19 | features_extractor, 20 | features_dim, 21 | activation_fn: Type[nn.Module] = nn.ReLU, 22 | use_sde: bool = False, 23 | log_std_init: float = -3, 24 | full_std: bool = True, 25 | sde_net_arch: Optional[List[int]] = None, 26 | use_expln: bool = False, 27 | clip_mean: float = 2.0, 28 | normalize_images: bool = True, 29 | use_lattice=False, 30 | std_clip=(1e-3, 10), 31 | expln_eps=1e-6, 32 | std_reg=0, 33 | alpha=1, 34 | ): 35 | super().__init__( 36 | observation_space, 37 | action_space, 38 | net_arch, 39 | features_extractor, 40 | features_dim, 41 | activation_fn=activation_fn, 42 | use_sde=use_sde, 43 | log_std_init=log_std_init, 44 | full_std=full_std, 45 | sde_net_arch=sde_net_arch, 46 | use_expln=use_expln, 47 | clip_mean=clip_mean, 48 | normalize_images=normalize_images, 49 | ) 50 | self.use_lattice = use_lattice 51 | self.std_clip = std_clip 52 | self.expln_eps = expln_eps 53 | self.std_reg = std_reg 54 | self.alpha = alpha 55 | if use_lattice: 56 | last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim 57 | action_dim = get_action_dim(self.action_space) 58 | if self.use_sde: 59 | self.action_dist = LatticeStateDependentNoiseDistribution( 60 | action_dim, 61 | full_std=full_std, 62 | use_expln=use_expln, 63 | squash_output=True, 64 | learn_features=True, 65 | epsilon=expln_eps, 66 | std_clip=std_clip, 67 | std_reg=std_reg, 68 | alpha=alpha, 69 | ) 70 | self.mu, self.log_std = self.action_dist.proba_distribution_net( 71 | latent_dim=last_layer_dim, 72 | latent_sde_dim=last_layer_dim, 73 | log_std_init=log_std_init, 74 | clip_mean=clip_mean 75 | ) 76 | else: 77 | self.action_dist = SquashedLatticeNoiseDistribution(action_dim) 78 | self.mu, self.log_std = self.action_dist.proba_distribution_net(last_layer_dim, log_std_init, state_dependent=True) 79 | 80 | def _get_constructor_parameters(self) -> Dict[str, Any]: 81 | data = super()._get_constructor_parameters() 82 | data.update( 83 | dict( 84 | use_lattice=self.use_lattice, 85 | std_clip=self.std_clip, 86 | expln_eps=self.expln_eps, 87 | std_reg=self.std_reg, 88 | alpha=self.alpha, 89 | ) 90 | ) 91 | return data 92 | 93 | def get_std(self) -> torch.Tensor: 94 | std = super().get_std() 95 | if self.use_lattice: 96 | std = torch.cat(std, dim=1) 97 | return std 98 | 99 | 100 | class LatticeSACPolicy(SACPolicy): 101 | def __init__( 102 | self, 103 | observation_space, 104 | action_space, 105 | lr_schedule, 106 | use_lattice=False, 107 | std_clip=(1e-3, 10), 108 | expln_eps=1e-6, 109 | std_reg=0, 110 | use_sde=False, 111 | alpha=1, 112 | **kwargs 113 | ): 114 | super().__init__( 115 | observation_space, action_space, lr_schedule, use_sde=use_sde, **kwargs 116 | ) 117 | self.lattice_kwargs = { 118 | "use_lattice": use_lattice, 119 | "expln_eps": expln_eps, 120 | "std_clip": std_clip, 121 | "std_reg": std_reg, 122 | "alpha": alpha, 123 | } 124 | self.actor_kwargs.update(self.lattice_kwargs) 125 
| if use_lattice: 126 | self._build(lr_schedule) 127 | 128 | def make_actor( 129 | self, features_extractor: Optional[BaseFeaturesExtractor] = None 130 | ) -> Actor: 131 | actor_kwargs = self._update_features_extractor( 132 | self.actor_kwargs, features_extractor 133 | ) 134 | return LatticeActor(**actor_kwargs).to(self.device) 135 | 136 | def _get_constructor_parameters(self) -> Dict[str, Any]: 137 | data = super()._get_constructor_parameters() 138 | data.update(self.lattice_kwargs) 139 | return data 140 | -------------------------------------------------------------------------------- /src/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/src/train/__init__.py -------------------------------------------------------------------------------- /src/train/trainer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from abc import ABC 4 | from dataclasses import dataclass, field 5 | from typing import List 6 | from stable_baselines3 import PPO, SAC, TD3 7 | from sb3_contrib import RecurrentPPO 8 | from stable_baselines3.common.callbacks import BaseCallback 9 | from stable_baselines3.common.vec_env import VecNormalize 10 | 11 | 12 | class Trainer(ABC): 13 | """ 14 | Protocol to train a library-independent RL algorithm on a gym environment. 15 | """ 16 | 17 | envs: VecNormalize 18 | env_config: dict 19 | model_config: dict 20 | model_path: str 21 | total_timesteps: int 22 | log: bool 23 | 24 | def _init_agent(self) -> None: 25 | """Initialize the agent.""" 26 | 27 | def train(self, total_timesteps: int) -> None: 28 | """Train agent on environment for total_timesteps episdodes.""" 29 | 30 | 31 | @dataclass 32 | class MyoTrainer: 33 | algo: str 34 | envs: VecNormalize 35 | env_config: dict 36 | load_model_path: str 37 | log_dir: str 38 | model_config: dict = field(default_factory=dict) 39 | callbacks: List[BaseCallback] = field(default_factory=list) 40 | timesteps: int = 10_000_000 41 | 42 | def __post_init__(self): 43 | self.dump_configs(path=self.log_dir) 44 | self.agent = self._init_agent() 45 | 46 | def dump_configs(self, path: str) -> None: 47 | with open(os.path.join(path, "env_config.json"), "w", encoding="utf8") as f: 48 | json.dump(self.env_config, f, indent=4, default=lambda _: '') 49 | with open(os.path.join(path, "model_config.json"), "w", encoding="utf8") as f: 50 | json.dump(self.model_config, f, indent=4, default=lambda _: '') 51 | 52 | def _init_agent(self): 53 | algo_class = self.get_algo_class() 54 | if self.load_model_path is not None: 55 | return algo_class.load( 56 | self.load_model_path, 57 | env=self.envs, 58 | tensorboard_log=self.log_dir, 59 | custom_objects=self.model_config, 60 | ) 61 | print("\nNo model path provided. 
Initializing new model.\n") 62 | return algo_class( 63 | env=self.envs, 64 | verbose=2, 65 | tensorboard_log=self.log_dir, 66 | **self.model_config, 67 | ) 68 | 69 | def train(self, total_timesteps: int) -> None: 70 | self.agent.learn( 71 | total_timesteps=total_timesteps, 72 | callback=self.callbacks, 73 | reset_num_timesteps=True, 74 | ) 75 | 76 | def save(self) -> None: 77 | self.agent.save(os.path.join(self.log_dir, "final_model.pkl")) 78 | self.envs.save(os.path.join(self.log_dir, "final_env.pkl")) 79 | 80 | def get_algo_class(self): 81 | if self.algo == "ppo": 82 | return PPO 83 | elif self.algo == "recurrent_ppo": 84 | return RecurrentPPO 85 | elif self.algo == "sac": 86 | return SAC 87 | elif self.algo == "td3": 88 | return TD3 89 | else: 90 | raise ValueError(f"Unknown algorithm: {self.algo}") 91 | 92 | 93 | if __name__ == "__main__": 94 | print("This is a module. Run one of the main_*.py scripts to train an agent.") 95 | --------------------------------------------------------------------------------
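
A short numeric sketch of the covariance construction used in LatticeStateDependentNoiseDistribution.proba_distribution (src/models/distributions.py above). It is not part of the repository: the toy sizes, random weights and the names W, corr_std, ind_std are placeholders chosen only for illustration. It shows how the variance of the correlated latent noise is propagated through the mean-action layer's weights, and how the independent per-action variance (plus the optional std_reg floor) is added on the diagonal.

import torch

# Toy shapes and random values, assumed for illustration only.
latent_dim, action_dim, batch = 5, 3, 2
latent = torch.randn(batch, latent_dim)          # last policy-layer features (latent_sde)
W = torch.randn(action_dim, latent_dim)          # plays the role of mean_actions_net.weight
corr_std = torch.rand(latent_dim, latent_dim)    # std of the latent-noise weights
ind_std = torch.rand(latent_dim, action_dim)     # std of the per-action noise weights
alpha, std_reg = 1.0, 0.01

# State-dependent variances, as computed in proba_distribution:
latent_corr_var = (latent ** 2) @ (corr_std ** 2)                 # (batch, latent_dim)
latent_ind_var = (latent ** 2) @ (ind_std ** 2) + std_reg ** 2    # (batch, action_dim)

# Correlated part: latent variance mapped through the action layer.
sigma = alpha ** 2 * (W * latent_corr_var[:, None, :]) @ W.T      # (batch, action_dim, action_dim)
# Independent part: added to the diagonal only.
sigma[:, range(action_dim), range(action_dim)] += latent_ind_var
print(sigma.shape)  # torch.Size([2, 3, 3]): a full covariance across actuators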
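
A minimal usage sketch of how the Lattice policy plugs into RecurrentPPO outside of MyoTrainer, assuming the script is run from src/ (so that models.ppo.policies is importable) and substituting a standard Gym task (Pendulum-v1) for the MyoSuite environments; the hyperparameter values are placeholders, not the tuned settings from the main_*.py scripts.

import gym
import torch.nn as nn
from sb3_contrib import RecurrentPPO
from models.ppo.policies import LatticeRecurrentActorCriticPolicy

# Any continuous-control environment works; Pendulum-v1 is a stand-in here.
env = gym.make("Pendulum-v1")

model = RecurrentPPO(
    policy=LatticeRecurrentActorCriticPolicy,
    env=env,
    use_sde=True,            # state-dependent exploration, as in the main_*.py scripts
    sde_sample_freq=4,       # resample the exploration matrices every 4 steps
    policy_kwargs=dict(
        use_lattice=True,    # correlated latent noise (Lattice)
        std_clip=(1e-3, 10),
        std_reg=0.01,        # placeholder value
        activation_fn=nn.ReLU,
    ),
    verbose=1,
)
model.learn(total_timesteps=10_000)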