├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── images
│   │   ├── myochallenge_2023.gif
│   │   └── myochallenge_ranking.png
│   └── myosuite
│       └── assets
│           ├── __init__.py
│           ├── arm
│           │   ├── __init__.py
│           │   └── myo_elbow_1dof6muscles.mjb
│           ├── finger
│           │   ├── __init__.py
│           │   └── myo_finger_v0.mjb
│           └── hand
│               ├── __init__.py
│               ├── myo_hand_baoding.mjb
│               ├── myo_hand_die.mjb
│               ├── myo_hand_pen.mjb
│               └── myo_hand_pose.mjb
├── docker
│   ├── Dockerfile
│   └── requirements.txt
└── src
    ├── definitions.py
    ├── envs
    │   ├── __init__.py
    │   ├── baoding.py
    │   ├── environment_factory.py
    │   ├── pen.py
    │   ├── pose.py
    │   └── reorient.py
    ├── main_ant.py
    ├── main_baoding.py
    ├── main_half_cheetah.py
    ├── main_hopper.py
    ├── main_humanoid.py
    ├── main_pen.py
    ├── main_pose_elbow.py
    ├── main_pose_finger.py
    ├── main_pose_hand.py
    ├── main_reach_finger.py
    ├── main_reach_hand.py
    ├── main_reorient.py
    ├── main_walker.py
    ├── metrics
    │   ├── __init__.py
    │   └── custom_callbacks.py
    ├── models
    │   ├── __init__.py
    │   ├── distributions.py
    │   ├── ppo
    │   │   └── policies.py
    │   └── sac
    │       └── policies.py
    └── train
        ├── __init__.py
        └── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Apple
132 | *DS_Store
133 |
134 | # PyCharm
135 | *.idea
136 |
137 | # Project
138 | /output
139 |
140 | # Tensorboard
141 | .monitor.csv
142 |
143 | # VS Code
144 | .vscode/launch.json
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Mathis Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lattice (Latent Exploration for Reinforcement Learning)
2 |
3 | This repository includes the implementation of Lattice exploration from the paper [Latent Exploration for Reinforcement Learning](https://arxiv.org/abs/2305.20065), published at NeurIPS 2023.
4 |
5 | Lattice introduces random perturbations in the latent state of the policy network, which result in correlated noise across the system's actuators. This form of latent noise can facilitate exploration when controlling high-dimensional systems, especially with redundant actuation, and may find low-effort solutions. A short video explaining the project can be found on [YouTube](https://www.youtube.com/watch?v=_CCF3GDM9jY).
6 |
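The following is a minimal NumPy sketch of this idea (an illustration only, not the implementation in this repository; the readout matrix `W`, the latent vector `z` and the noise scale `sigma` are hypothetical): perturbing the latent vector before the linear action readout produces action noise with covariance proportional to `W @ W.T`, i.e., correlated across actuators, whereas standard action noise is independent per actuator.

```python
import numpy as np

rng = np.random.default_rng(0)

latent_dim, action_dim = 64, 63  # e.g., the 63 muscles of the MyoChallenge arm
W = rng.standard_normal((action_dim, latent_dim)) / np.sqrt(latent_dim)  # linear action readout
z = rng.standard_normal(latent_dim)  # latent state of the policy network
sigma = 0.1

# Standard action-space noise: independent perturbation per actuator
a_action_noise = W @ z + sigma * rng.standard_normal(action_dim)

# Latent (Lattice-style) noise: perturb z, then map it through the readout
a_latent_noise = W @ (z + sigma * rng.standard_normal(latent_dim))

# The latent perturbation has covariance sigma^2 * W @ W.T in action space,
# so actuators driven by similar latent features receive correlated noise.
```
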
7 | Lattice builds on top of Stable Baselines 3 (version 1.6.1) and it is here implemented for Recurrent PPO and SAC. Integration with a more recent version of Stable Baselines 3 and compatibility with more algorithms is [currently under development](https://github.com/albertochiappa/stable-baselines3).
8 |
9 | This project was developed by Alberto Silvio Chiappa, Alessandro Marin Vargas, Ann Zixiang Huang and Alexander Mathis (EPFL).
10 |
11 | ## MyoChallenge 2023
12 |
13 | We used Lattice to train the top submission of the NeurIPS 2023 competition [MyoChallenge](https://sites.google.com/view/myosuite/myochallenge/myochallenge-2023?authuser=0), object manipulation track. With curriculum learning, reward shaping and Lattice exploration, we trained a policy to control a biologically realistic arm with 63 muscles and 27 degrees of freedom to place random objects inside a box of variable shape:
14 |
15 | ![MyoChallenge 2023 object manipulation](data/images/myochallenge_2023.gif)
16 |
17 | We outperformed the other best solutions both in score and effort:
18 |
19 | ![MyoChallenge 2023 ranking](data/images/myochallenge_ranking.png)
20 |
21 | We have also created [a dedicated repository](https://github.com/amathislab/myochallenge_2023eval) for the solution, where we have released the pretrained weights of all the curriculum steps.
22 |
23 | ## Installation
24 |
25 | We recommend using a Docker container to execute the code of this repository. We provide both the Docker image [albertochiappa/myo-cuda-pybullet](https://hub.docker.com/repository/docker/albertochiappa/myo-cuda-pybullet) on DockerHub and the Dockerfile in the [docker](/docker/) folder to build the same image locally.
26 |
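For example, the image can be built locally from the repository root as follows (the image tag below is arbitrary; the build context must be `docker/` so that `requirements.txt` is found by the Dockerfile):

```bash
docker build -t myo-cuda-pybullet docker/
```
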
27 | If you prefer to manually create a Conda environment, you can do so with the commands:
28 |
29 | ```bash
30 | conda create --name lattice python=3.8.10
31 | conda activate lattice
32 | pip install -r docker/requirements.txt
33 | pip install myosuite==1.2.4
34 | pip install --upgrade cloudpickle==2.2.0 pickle5==0.0.11 pybullet==3.2.5
35 | ```
36 |
37 | Please note that there is a version conflict between some packages: e.g., `stable_baselines3` requires a later version of `gym` than the one `myosuite` is compatible with. For this reason we could not include all the requirements in `docker/requirements.txt`. In our experiments the stated incompatibility did not cause any errors.
38 |
39 | ## Training a policy with Lattice
40 |
41 | We provide scripts for various environments of the [MyoSuite](https://sites.google.com/view/myosuite) and [PyBullet](https://pybullet.org).
42 |
43 | Training a policy is as easy as
44 |
45 | ```bash
46 | python main_pose_elbow.py --use_lattice
47 | ```
48 |
49 | if you have created a conda environment.
50 |
51 | If you prefer to use the readily available Docker image, you can train like this:
52 |
53 | ```bash
54 | docker run --rm --gpus all -it \
55 | --mount type=bind,src="$(pwd)/src",target=/src \
56 | --mount type=bind,src="$(pwd)/data",target=/data \
57 | --mount type=bind,src="$(pwd)/output",target=/output \
58 | albertochiappa/myo-cuda-pybullet \
59 | python3 src/main_pose_elbow.py --use_lattice
60 | ```
61 |
62 | The previous command will start training in the `Elbow Pose` environment using Recurrent PPO. Simply change the name of the main script to start training for a different environment. The output of the training, including the configuration used to select the hyperparameters and the Tensorboard logs, is saved in a subfolder of `output/` named after the current date. The code outputs useful information to monitor the training in Tensorboard format, so you can run Tensorboard in the output folder to visualize the learning curves and much more. The different configuration hyperparameters can be set from the command line, e.g., by running
63 |
64 | ```bash
65 | python src/main_humanoid.py --use_sde --use_lattice --freq=8
66 | ```
67 |
68 | In this case, a policy will be trained with SAC in the Humanoid environment, using state-dependent Lattice with update period 8.
69 |
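The training logs are written in Tensorboard format. Assuming they were saved under `output/` as described above, they can be visualized with:

```bash
tensorboard --logdir output/
```
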
70 | ## Structure of the repository
71 |
72 | * src/
73 | * envs/
74 | * Modified MyoSuite environments, where we used parameters different from the defaults (cf. manuscript)
75 | * Factory to instantiate all the environments used in the project
76 | * metrics/
77 | * Stable Baselines 3 callbacks to log useful information during training
78 | * models/
79 | * Implementation of Lattice
80 | * Adaptation of SAC and PPO to use Lattice
81 | * train/
82 | * Trainer class used to manage the training runs
83 | * main_*
84 | * One main script per environment, to launch a training run
85 | * data/
86 | * MuJoCo models and assets for the MyoSuite environments, plus the images used in this README
87 | * docker/
88 | * Definition of the Dockerfile to create the image used to run the experiments, with GPU support
89 |
90 | ## Reference
91 |
92 | If our work was useful to your research, please cite:
93 |
94 | ```
95 | @article{chiappa2023latent,
96 | title={Latent exploration for reinforcement learning},
97 | author={Chiappa, Alberto Silvio and Vargas, Alessandro Marin and Huang, Ann Zixiang and Mathis, Alexander},
98 | journal={Advances in Neural Information Processing Systems (NeurIPS)},
99 | year={2023}
100 | }
101 | ```
102 |
--------------------------------------------------------------------------------
/data/images/myochallenge_2023.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/images/myochallenge_2023.gif
--------------------------------------------------------------------------------
/data/images/myochallenge_ranking.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/images/myochallenge_ranking.png
--------------------------------------------------------------------------------
/data/myosuite/assets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/__init__.py
--------------------------------------------------------------------------------
/data/myosuite/assets/arm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/arm/__init__.py
--------------------------------------------------------------------------------
/data/myosuite/assets/arm/myo_elbow_1dof6muscles.mjb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/arm/myo_elbow_1dof6muscles.mjb
--------------------------------------------------------------------------------
/data/myosuite/assets/finger/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/finger/__init__.py
--------------------------------------------------------------------------------
/data/myosuite/assets/finger/myo_finger_v0.mjb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/finger/myo_finger_v0.mjb
--------------------------------------------------------------------------------
/data/myosuite/assets/hand/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/__init__.py
--------------------------------------------------------------------------------
/data/myosuite/assets/hand/myo_hand_baoding.mjb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/myo_hand_baoding.mjb
--------------------------------------------------------------------------------
/data/myosuite/assets/hand/myo_hand_die.mjb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/myo_hand_die.mjb
--------------------------------------------------------------------------------
/data/myosuite/assets/hand/myo_hand_pen.mjb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/myo_hand_pen.mjb
--------------------------------------------------------------------------------
/data/myosuite/assets/hand/myo_hand_pose.mjb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/data/myosuite/assets/hand/myo_hand_pose.mjb
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
2 |
3 | RUN apt-get update -q \
4 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \
5 | libgl1-mesa-dev \
6 | libgl1-mesa-glx \
7 | libglew-dev \
8 | libosmesa6-dev \
9 | software-properties-common \
10 | patchelf \
11 | python3 \
12 | python3-pip
13 |
14 |
15 | RUN apt-get install -y --no-install-recommends -o APT::Immediate-Configure=false gcc g++ git && \
16 | apt-get clean && \
17 | rm -rf /var/lib/apt/lists/*
18 |
19 |
20 | ENV PYTHONUNBUFFERED 1
21 |
22 | ADD requirements.txt /
23 |
24 | RUN pip install --upgrade pip
25 | RUN pip install -r requirements.txt
26 | RUN pip install myosuite==1.2.*
27 | RUN pip install --upgrade pickle5 cloudpickle pybullet
28 | RUN python3 -c "import mujoco_py"
--------------------------------------------------------------------------------
/docker/requirements.txt:
--------------------------------------------------------------------------------
1 | grpcio==1.47.0
2 | grpcio-tools==1.47.0
3 | protobuf==3.19.5
4 | pyglet==1.5.26
5 | numpy==1.21.6
6 | git+https://github.com/aravindr93/mjrl.git
7 | matplotlib==3.5.3
8 | requests==2.25.1
9 | scikit-learn==1.0.2
10 | sk-video==1.1.10
11 | tabulate==0.8.10
12 | --find-links https://download.pytorch.org/whl/torch_stable.html
13 | torch==1.12.1+cu116
14 | stable_baselines3==1.6.1
15 | sb3-contrib==1.6.0
16 | pandas==1.3.5
17 | tensorboard==2.10.0
--------------------------------------------------------------------------------
/src/definitions.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))  # resolves to the repository root (parent of src/)
4 |
--------------------------------------------------------------------------------
/src/envs/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gym
3 | import myosuite
4 | import numpy as np
5 | from definitions import ROOT_DIR # pylint: disable=import-error
6 | from myosuite.envs.myo import register_env_with_variants
7 |
8 |
9 | myosuite_path = os.path.join(ROOT_DIR, "data", "myosuite")
10 |
11 | # MyoChallenge Baoding: Phase1 env
12 | gym.envs.registration.register(
13 | id="CustomMyoChallengeBaodingP1-v1",
14 | entry_point="envs.baoding:CustomBaodingEnv",
15 | max_episode_steps=200,
16 | kwargs={
17 | "model_path": myosuite_path + "/assets/hand/myo_hand_baoding.mjb",
18 | "normalize_act": True,
19 | # 'goal_time_period': (5, 5),
20 | "goal_xrange": (0.025, 0.025),
21 | "goal_yrange": (0.028, 0.028),
22 | },
23 | )
24 |
25 | # MyoChallenge Die: Phase2 env
26 | gym.envs.registration.register(
27 | id="CustomMyoChallengeDieReorientP2-v0",
28 | entry_point="envs.reorient:CustomReorientEnv",
29 | max_episode_steps=150,
30 | kwargs={
31 | "model_path": myosuite_path + "/assets/hand/myo_hand_die.mjb",
32 | "normalize_act": True,
33 | "frame_skip": 5,
34 | # Randomization in goals
35 | 'goal_pos': (-.020, .020), # +- 2 cm
36 | 'goal_rot': (-3.14, 3.14), # +-180 degrees
37 | # Randomization in physical properties of the die
38 | 'obj_size_change': 0.007, # +-7mm delta change in object size
39 | 'obj_friction_change': (0.2, 0.001, 0.00002) # nominal: 1.0, 0.005, 0.0001
40 | },
41 | )
42 |
43 | register_env_with_variants(id='CustomMyoElbowPoseRandom-v0',
44 | entry_point='envs.pose:CustomPoseEnv',
45 | max_episode_steps=100,
46 | kwargs={
47 | 'model_path': myosuite_path+'/assets/arm/myo_elbow_1dof6muscles.mjb',
48 | 'target_jnt_range': {'r_elbow_flex':(0, 2.27),},
49 | 'viz_site_targets': ('wrist',),
50 | 'normalize_act': True,
51 | 'pose_thd': .175,
52 | 'reset_type': 'random'
53 | }
54 | )
55 |
56 | register_env_with_variants(id='CustomMyoFingerPoseRandom-v0',
57 | entry_point='envs.pose:CustomPoseEnv',
58 | max_episode_steps=100,
59 | kwargs={
60 | 'model_path': myosuite_path + '/assets/finger/myo_finger_v0.mjb',
61 | 'target_jnt_range': {'IFadb':(-.2, .2),
62 | 'IFmcp':(-.4, 1),
63 | 'IFpip':(.1, 1),
64 | 'IFdip':(.1, 1)
65 | },
66 | 'viz_site_targets': ('IFtip',),
67 | 'normalize_act': True,
68 | }
69 | )
70 |
71 | # Hand-Joint posing ==============================
72 | # Create ASL envs ==============================
73 | jnt_namesHand=['pro_sup', 'deviation', 'flexion', 'cmc_abduction', 'cmc_flexion', 'mp_flexion', 'ip_flexion', 'mcp2_flexion', 'mcp2_abduction', 'pm2_flexion', 'md2_flexion', 'mcp3_flexion', 'mcp3_abduction', 'pm3_flexion', 'md3_flexion', 'mcp4_flexion', 'mcp4_abduction', 'pm4_flexion', 'md4_flexion', 'mcp5_flexion', 'mcp5_abduction', 'pm5_flexion', 'md5_flexion']
74 |
75 | ASL_qpos={}
76 | ASL_qpos[0]='0 0 0 0.5624 0.28272 -0.75573 -1.309 1.30045 -0.006982 1.45492 0.998897 1.26466 0 1.40604 0.227795 1.07614 -0.020944 1.46103 0.06284 0.83263 -0.14399 1.571 1.38248'.split(' ')
77 | ASL_qpos[1]='0 0 0 0.0248 0.04536 -0.7854 -1.309 0.366605 0.010473 0.269258 0.111722 1.48459 0 1.45318 1.44532 1.44532 -0.204204 1.46103 1.44532 1.48459 -0.2618 1.47674 1.48459'.split(' ')
78 | ASL_qpos[2]='0 0 0 0.0248 0.04536 -0.7854 -1.13447 0.514973 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 1.44532 -0.204204 1.46103 1.44532 1.48459 -0.2618 1.47674 1.48459'.split(' ')
79 | ASL_qpos[3]='0 0 0 0.3384 0.25305 0.01569 -0.0262045 0.645885 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 1.571 -0.036652 1.52387 1.45318 1.40604 -0.068068 1.39033 1.571'.split(' ')
80 | ASL_qpos[4]='0 0 0 0.6392 -0.147495 -0.7854 -1.309 0.637158 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 0.306345 -0.010472 0.400605 0.133535 0.21994 -0.068068 0.274925 0.01571'.split(' ')
81 | ASL_qpos[5]='0 0 0 0.3384 0.25305 0.01569 -0.0262045 0.645885 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 0.306345 -0.010472 0.400605 0.133535 0.21994 -0.068068 0.274925 0.01571'.split(' ')
82 | ASL_qpos[6]='0 0 0 0.6392 -0.147495 -0.7854 -1.309 0.637158 0.010473 0.128305 0.111722 0.510575 0 0.37704 0.117825 0.306345 -0.010472 0.400605 0.133535 1.1861 -0.2618 1.35891 1.48459'.split(' ')
83 | ASL_qpos[7]='0 0 0 0.524 0.01569 -0.7854 -1.309 0.645885 -0.006982 0.128305 0.111722 0.510575 0 0.37704 0.117825 1.28036 -0.115192 1.52387 1.45318 0.432025 -0.068068 0.18852 0.149245'.split(' ')
84 | ASL_qpos[8]='0 0 0 0.428 0.22338 -0.7854 -1.309 0.645885 -0.006982 0.128305 0.194636 1.39033 0 1.08399 0.573415 0.667675 -0.020944 0 0.06284 0.432025 -0.068068 0.18852 0.149245'.split(' ')
85 | ASL_qpos[9]='0 0 0 0.5624 0.28272 -0.75573 -1.309 1.30045 -0.006982 1.45492 0.998897 0.39275 0 0.18852 0.227795 0.667675 -0.020944 0 0.06284 0.432025 -0.068068 0.18852 0.149245'.split(' ')
86 |
87 | # ASl Eval envs for each numerals
88 | for k in ASL_qpos.keys():
89 | register_env_with_variants(id='CustomMyoHandPose'+str(k)+'Fixed-v0',
90 | entry_point='envs.pose:CustomPoseEnv',
91 | max_episode_steps=100,
92 | kwargs={
93 | 'model_path': myosuite_path + '/assets/hand/myo_hand_pose.mjb',
94 | 'viz_site_targets': ('THtip','IFtip','MFtip','RFtip','LFtip'),
95 | 'target_jnt_value': np.array(ASL_qpos[k],'float'),
96 | 'normalize_act': True,
97 | 'pose_thd': .7,
98 | 'reset_type': "init", # none, init, random
99 | 'target_type': 'fixed', # generate/ fixed
100 | }
101 | )
102 |
103 | # ASL Train Env
104 | m = np.array([ASL_qpos[i] for i in range(10)]).astype(float)
105 | Rpos = {}
106 | for i_n, n in enumerate(jnt_namesHand):
107 | Rpos[n]=(np.min(m[:,i_n]), np.max(m[:,i_n]))
108 |
109 | register_env_with_variants(id='CustomMyoHandPoseRandom-v0', #reconsider
110 | entry_point='envs.pose:CustomPoseEnv',
111 | max_episode_steps=100,
112 | kwargs={
113 | 'model_path': myosuite_path + '/assets/hand/myo_hand_pose.mjb',
114 | 'viz_site_targets': ('THtip','IFtip','MFtip','RFtip','LFtip'),
115 | 'target_jnt_range': Rpos,
116 | 'normalize_act': True,
117 | 'pose_thd': .8,
118 | 'reset_type': "random", # none, init, random
119 | 'target_type': 'generate', # generate/ fixed
120 | }
121 | )
122 |
123 |
124 | # Pen twirl
125 | register_env_with_variants(id='CustomMyoHandPenTwirlRandom-v0',
126 | entry_point='envs.pen:CustomPenEnv',
127 | max_episode_steps=100,
128 | kwargs={
129 | 'model_path': myosuite_path + '/assets/hand/myo_hand_pen.mjb',
130 | 'normalize_act': True,
131 | 'frame_skip': 5,
132 | }
133 | )
--------------------------------------------------------------------------------
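A minimal usage sketch of the registrations above (an illustration, assuming `src/` is the working directory or on `PYTHONPATH`, and using the gym 0.x API targeted by Stable Baselines 3 1.6.1):

```python
import gym

import envs  # noqa: F401  # importing the package runs the registrations above

# One of the environment ids registered in this file
env = gym.make("CustomMyoChallengeBaodingP1-v1")

obs = env.reset()
done = False
while not done:  # episode ends via `done` or the 200-step time limit
    obs, reward, done, info = env.step(env.action_space.sample())
```
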
/src/envs/baoding.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=attribute-defined-outside-init, dangerous-default-value, protected-access, abstract-method, arguments-renamed, import-error
2 | import collections
3 | import random
4 | import numpy as np
5 | from myosuite.envs.myo.base_v0 import BaseV0
6 | from myosuite.envs.myo.myochallenge.baoding_v1 import WHICH_TASK, BaodingEnvV1, Task
7 | from sb3_contrib import RecurrentPPO
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
10 | from envs.environment_factory import EnvironmentFactory
11 |
12 |
13 | class CustomBaodingEnv(BaodingEnvV1):
14 | DEFAULT_RWD_KEYS_AND_WEIGHTS = {
15 | "pos_dist_1": 5.0,
16 | "pos_dist_2": 5.0,
17 | "alive": 0.0,
18 | "act_reg": 0.0,
19 | # "palm_up": 0.0,
20 | }
21 |
22 | def get_reward_dict(self, obs_dict):
23 | # tracking error
24 | target1_dist = np.linalg.norm(obs_dict["target1_err"], axis=-1)
25 | target2_dist = np.linalg.norm(obs_dict["target2_err"], axis=-1)
26 | target_dist = target1_dist + target2_dist
27 | act_mag = (
28 | np.linalg.norm(self.obs_dict["act"], axis=-1) / self.sim.model.na
29 | if self.sim.model.na != 0
30 | else 0
31 | )
32 |
33 | # detect fall
34 | object1_pos = (
35 | obs_dict["object1_pos"][:, :, 2]
36 | if obs_dict["object1_pos"].ndim == 3
37 | else obs_dict["object1_pos"][2]
38 | )
39 | object2_pos = (
40 | obs_dict["object2_pos"][:, :, 2]
41 | if obs_dict["object2_pos"].ndim == 3
42 | else obs_dict["object2_pos"][2]
43 | )
44 | is_fall_1 = object1_pos < self.drop_th
45 | is_fall_2 = object2_pos < self.drop_th
46 | is_fall = np.logical_or(is_fall_1, is_fall_2) # keep both balls up
47 |
48 | rwd_dict = collections.OrderedDict(
49 | (
50 | # Perform reward tuning here --
51 | # Update Optional Keys section below
52 | # Update reward keys (DEFAULT_RWD_KEYS_AND_WEIGHTS) to update final rewards
53 | # Examples: Env comes pre-packaged with two keys pos_dist_1 and pos_dist_2
54 | # Optional Keys
55 | ("pos_dist_1", -1.0 * target1_dist),
56 | ("pos_dist_2", -1.0 * target2_dist),
57 | ("alive", ~is_fall),
58 | # ("palm_up", palm_up_reward),
59 | # Must keys
60 | ("act_reg", -1.0 * act_mag),
61 | ("sparse", -target_dist),
62 | (
63 | "solved",
64 | (target1_dist < self.proximity_th)
65 | * (target2_dist < self.proximity_th)
66 | * (~is_fall),
67 | ),
68 | ("done", is_fall),
69 | )
70 | )
71 | rwd_dict["dense"] = np.sum(
72 | [wt * rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0
73 | )
74 |
75 | # Success Indicator
76 | self.sim.model.geom_rgba[self.object1_gid, :2] = (
77 | np.array([1, 1])
78 | if target1_dist < self.proximity_th
79 | else np.array([0.5, 0.5])
80 | )
81 | self.sim.model.geom_rgba[self.object2_gid, :2] = (
82 | np.array([0.9, 0.7])
83 | if target2_dist < self.proximity_th
84 | else np.array([0.5, 0.5])
85 | )
86 |
87 | return rwd_dict
88 |
89 | def _add_noise_to_palm_position(
90 | self, qpos: np.ndarray, noise: float = 1
91 | ) -> np.ndarray:
92 | assert 0 <= noise <= 1, "Noise must be between 0 and 1"
93 |
94 | # pronation-supination of the wrist
95 | # noise = 1 corresponds to 10 degrees from facing up (one direction only)
96 | qpos[0] = self.np_random.uniform(
97 | low=-np.pi / 2, high=-np.pi / 2 + np.pi / 18 * noise
98 | )
99 |
100 | # ulnar deviation of wrist:
101 | # noise = 1 corresponds to 10 degrees on either side
102 | qpos[1] = self.np_random.uniform(
103 | low=-np.pi / 18 * noise, high=np.pi / 18 * noise
104 | )
105 |
106 | # extension flexion of the wrist
107 | # noise = 1 corresponds to 10 degrees on either side
108 | qpos[2] = self.np_random.uniform(
109 | low=-np.pi / 18 * noise, high=np.pi / 18 * noise
110 | )
111 |
112 | return qpos
113 |
114 | def _add_noise_to_finger_positions(
115 | self, qpos: np.ndarray, noise: float = 1
116 | ) -> np.ndarray:
117 | assert 0 <= noise <= 1, "Noise parameter must be between 0 and 1"
118 |
119 | # thumb all joints
120 | # noise = 1 corresponds to 10 degrees on either side
121 | qpos[3:7] = self.np_random.uniform(
122 | low=-np.pi / 18 * noise, high=np.pi / 18 * noise
123 | )
124 |
125 | # finger joints
126 | # noise = 1 corresponds to 30 degrees bent instead of fully open
127 | qpos[[7, 9, 10, 11, 13, 14, 15, 17, 18, 19, 21, 22]] = self.np_random.uniform(
128 | low=0, high=np.pi / 6 * noise
129 | )
130 |
131 | # finger abduction (sideways angle)
132 | # noise = 1 corresponds to 5 degrees on either side
133 | qpos[[8, 12, 16, 20]] = self.np_random.uniform(
134 | low=-np.pi / 36 * noise, high=np.pi / 36 * noise
135 | )
136 |
137 | return qpos
138 |
139 | def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None):
140 | self.which_task = self.sample_task()
141 | if self.rsi:
142 | # MODIFICATION: randomize starting target position along the cycle
143 | random_phase = np.random.uniform(low=-np.pi, high=np.pi)
144 | else:
145 | random_phase = 0
146 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0 + random_phase
147 | self.ball_2_starting_angle = -1.0 * np.pi / 4.0 + random_phase
148 |
149 | # reset counters
150 | self.counter = 0
151 | self.x_radius = self.np_random.uniform(
152 | low=self.goal_xrange[0], high=self.goal_xrange[1]
153 | )
154 | self.y_radius = self.np_random.uniform(
155 | low=self.goal_yrange[0], high=self.goal_yrange[1]
156 | )
157 |
158 | # reset goal
159 | if time_period is None:
160 | time_period = self.np_random.uniform(
161 | low=self.goal_time_period[0], high=self.goal_time_period[1]
162 | )
163 | self.goal = (
164 | self.create_goal_trajectory(time_step=self.dt, time_period=time_period)
165 | if reset_goal is None
166 | else reset_goal.copy()
167 | )
168 |
169 | # reset scene (MODIFIED from base class MujocoEnv)
170 | qpos = self.init_qpos.copy() if reset_pose is None else reset_pose
171 | qvel = self.init_qvel.copy() if reset_vel is None else reset_vel
172 | self.robot.reset(qpos, qvel)
173 |
174 | if self.rsi:
175 | if np.random.uniform(0, 1) < self.rsi_probability:
176 | self.step(np.zeros(39))
177 |
178 | # update ball positions
179 | obs = self.get_obs().copy()
180 | qpos[23] = obs[35] # ball 1 x-position
181 | qpos[24] = obs[36] # ball 1 y-position
182 | qpos[30] = obs[38] # ball 2 x-position
183 | qpos[31] = obs[39] # ball 2 y-position
184 |
185 | if self.noise_balls:
186 | # update balls x,y,z positions with relative noise
187 | for i in [23, 24, 25, 30, 31, 32]:
188 | qpos[i] += np.random.uniform(
189 | low=-self.noise_balls, high=self.noise_balls
190 | )
191 |
192 | if self.noise_palm:
193 | qpos = self._add_noise_to_palm_position(qpos, self.noise_palm)
194 |
195 | if self.noise_fingers:
196 | qpos = self._add_noise_to_finger_positions(qpos, self.noise_fingers)
197 |
198 | if self.rsi or self.noise_palm or self.noise_fingers or self.noise_balls:
199 | self.set_state(qpos, qvel)
200 |
201 | return self.get_obs()
202 |
203 | def _setup(
204 | self,
205 | frame_skip: int = 10,
206 | drop_th=1.25, # drop height threshold
207 | proximity_th=0.015, # object-target proximity threshold
208 | goal_time_period=(5, 5), # target rotation time period
209 | goal_xrange=(0.025, 0.025), # target rotation: x radius (0.03)
210 | goal_yrange=(0.028, 0.028), # target rotation: y radius (0.02 * 1.5 * 1.2)
211 | obs_keys: list = BaodingEnvV1.DEFAULT_OBS_KEYS,
212 | weighted_reward_keys: list = DEFAULT_RWD_KEYS_AND_WEIGHTS,
213 | task=None,
214 | enable_rsi=False, # random state init for balls
215 | noise_palm=0, # magnitude of noise for palm (between 0 and 1)
216 | noise_fingers=0, # magnitude of noise for fingers (between 0 and 1)
217 | noise_balls=0, # relative magnitude of noise for the balls (1 is 100% relative noise)
218 | rsi_probability=1, # probability of implementing RSI
219 | **kwargs,
220 | ):
221 |
222 | # user parameters
223 | self.task = task
224 | self.which_task = self.sample_task()
225 | self.rsi = enable_rsi
226 | self.noise_palm = noise_palm
227 | self.noise_fingers = noise_fingers
228 | self.drop_th = drop_th
229 | self.proximity_th = proximity_th
230 | self.goal_time_period = goal_time_period
231 | self.goal_xrange = goal_xrange
232 | self.goal_yrange = goal_yrange
233 | self.noise_balls = noise_balls
234 | self.rsi_probability = rsi_probability
235 |
236 | # balls start at these angles
237 | # 1= yellow = top right
238 | # 2= pink = bottom left
239 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0
240 | self.ball_2_starting_angle = -1.0 * np.pi / 4.0
241 |
242 | # init desired trajectory, for rotations
243 | self.center_pos = [-0.0125, -0.07] # [-.0020, -.0522]
244 | self.x_radius = self.np_random.uniform(
245 | low=self.goal_xrange[0], high=self.goal_xrange[1]
246 | )
247 | self.y_radius = self.np_random.uniform(
248 | low=self.goal_yrange[0], high=self.goal_yrange[1]
249 | )
250 |
251 | self.counter = 0
252 | self.goal = self.create_goal_trajectory(
253 | time_step=frame_skip * self.sim.model.opt.timestep, time_period=6
254 | )
255 |
256 | # init target and body sites
257 | self.object1_sid = self.sim.model.site_name2id("ball1_site")
258 | self.object2_sid = self.sim.model.site_name2id("ball2_site")
259 | self.object1_gid = self.sim.model.geom_name2id("ball1")
260 | self.object2_gid = self.sim.model.geom_name2id("ball2")
261 | self.target1_sid = self.sim.model.site_name2id("target1_site")
262 | self.target2_sid = self.sim.model.site_name2id("target2_site")
263 | self.sim.model.site_group[self.target1_sid] = 2
264 | self.sim.model.site_group[self.target2_sid] = 2
265 |
266 | BaseV0._setup(
267 | self,
268 | obs_keys=obs_keys,
269 | weighted_reward_keys=weighted_reward_keys,
270 | frame_skip=frame_skip,
271 | **kwargs,
272 | )
273 |
274 | # reset position
275 | self.init_qpos[:-14] *= 0 # Use fully open as init pos
276 | self.init_qpos[0] = -1.57 # Palm up
277 |
278 | def sample_task(self):
279 | if self.task is None:
280 | return Task(WHICH_TASK)
281 | else:
282 | if self.task == "cw":
283 | return Task(Task.BAODING_CW)
284 | elif self.task == "ccw":
285 | return Task(Task.BAODING_CCW)
286 | elif self.task == "random":
287 | return Task(random.choice(list(Task)))
288 | else:
289 | raise ValueError("Unknown task for baoding: ", self.task)
290 |
291 | def step(self, action):
292 | obs, reward, done, info = super().step(action)
293 | info.update(info.get("rwd_dict"))
294 | return obs, reward, done, info
295 |
296 |
297 | class CustomBaodingP2Env(BaodingEnvV1):
298 | def _setup(
299 | self,
300 | frame_skip: int = 10,
301 | drop_th=1.25, # drop height threshold
302 | proximity_th=0.015, # object-target proximity threshold
303 | goal_time_period=(5, 5), # target rotation time period
304 | goal_xrange=(0.025, 0.025), # target rotation: x radius (0.03)
305 | goal_yrange=(0.028, 0.028), # target rotation: y radius (0.02 * 1.5 * 1.2)
306 | obj_size_range=(0.018, 0.024), # Object size range. Nominal 0.022
307 | obj_mass_range=(0.030, 0.300), # Object weight range. Nominal 43 gms
308 | obj_friction_change=(0.2, 0.001, 0.00002),
309 | task_choice="fixed", # fixed/ random
310 | obs_keys: list = BaodingEnvV1.DEFAULT_OBS_KEYS,
311 | weighted_reward_keys: list = BaodingEnvV1.DEFAULT_RWD_KEYS_AND_WEIGHTS,
312 | enable_rsi=False, # random state init for balls
313 | rsi_probability=1, # probability of implementing RSI
314 | balls_overlap=False,
315 | overlap_probability=0,
316 | limit_init_angle=None,
317 | beta_init_angle=None,
318 | beta_ball_size=None,
319 | beta_ball_mass=None,
320 | noise_fingers=0,
321 | **kwargs,
322 | ):
323 | # user parameters
324 | self.task_choice = task_choice
325 | self.which_task = (
326 | self.np_random.choice(Task) if task_choice == "random" else Task(WHICH_TASK)
327 | )
328 | self.drop_th = drop_th
329 | self.proximity_th = proximity_th
330 | self.goal_time_period = goal_time_period
331 | self.goal_xrange = goal_xrange
332 | self.goal_yrange = goal_yrange
333 | self.rsi = enable_rsi
334 | self.rsi_probability = rsi_probability
335 | self.balls_overlap = balls_overlap
336 | self.overlap_probability = overlap_probability
337 | self.noise_fingers = noise_fingers
338 | self.limit_init_angle = limit_init_angle
339 | self.beta_init_angle = beta_init_angle
340 | self.beta_ball_size = beta_ball_size
341 | self.beta_ball_mass = beta_ball_mass
342 |
343 | # balls start at these angles
344 | # 1= yellow = top right
345 | # 2= pink = bottom left
346 |
347 | if np.random.uniform(0, 1) < self.overlap_probability:
348 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0
349 | self.ball_2_starting_angle = -1.0 * np.pi / 4.0
350 | else:
351 | self.ball_1_starting_angle = 1.0 * np.pi / 4.0
352 | self.ball_2_starting_angle = self.ball_1_starting_angle - np.pi
353 |
354 | # init desired trajectory, for rotations
355 | self.center_pos = [-0.0125, -0.07] # [-.0020, -.0522]
356 | self.x_radius = self.np_random.uniform(
357 | low=self.goal_xrange[0], high=self.goal_xrange[1]
358 | )
359 | self.y_radius = self.np_random.uniform(
360 | low=self.goal_yrange[0], high=self.goal_yrange[1]
361 | )
362 |
363 | self.counter = 0
364 | self.goal = self.create_goal_trajectory(
365 | time_step=frame_skip * self.sim.model.opt.timestep, time_period=6
366 | )
367 |
368 | # init target and body sites
369 | self.object1_bid = self.sim.model.body_name2id("ball1")
370 | self.object2_bid = self.sim.model.body_name2id("ball2")
371 | self.object1_sid = self.sim.model.site_name2id("ball1_site")
372 | self.object2_sid = self.sim.model.site_name2id("ball2_site")
373 | self.object1_gid = self.sim.model.geom_name2id("ball1")
374 | self.object2_gid = self.sim.model.geom_name2id("ball2")
375 | self.target1_sid = self.sim.model.site_name2id("target1_site")
376 | self.target2_sid = self.sim.model.site_name2id("target2_site")
377 | self.sim.model.site_group[self.target1_sid] = 2
378 | self.sim.model.site_group[self.target2_sid] = 2
379 |
380 | # setup for task randomization
381 | self.obj_mass_range = {"low": obj_mass_range[0], "high": obj_mass_range[1]}
382 | self.obj_size_range = {"low": obj_size_range[0], "high": obj_size_range[1]}
383 | self.obj_friction_range = {
384 | "low": self.sim.model.geom_friction[self.object1_gid] - obj_friction_change,
385 | "high": self.sim.model.geom_friction[self.object1_gid]
386 | + obj_friction_change,
387 | }
388 |
389 | BaseV0._setup(
390 | self,
391 | obs_keys=obs_keys,
392 | weighted_reward_keys=weighted_reward_keys,
393 | frame_skip=frame_skip,
394 | **kwargs,
395 | )
396 |
397 | # reset position
398 | self.init_qpos[:-14] *= 0 # Use fully open as init pos
399 | self.init_qpos[0] = -1.57 # Palm up
400 |
401 | def get_reward_dict(self, obs_dict):
402 | # tracking error
403 | target1_dist = np.linalg.norm(obs_dict["target1_err"], axis=-1)
404 | target2_dist = np.linalg.norm(obs_dict["target2_err"], axis=-1)
405 | target_dist = target1_dist + target2_dist
406 | act_mag = (
407 | np.linalg.norm(self.obs_dict["act"], axis=-1) / self.sim.model.na
408 | if self.sim.model.na != 0
409 | else 0
410 | )
411 |
412 | # detect fall
413 | object1_pos = (
414 | obs_dict["object1_pos"][:, :, 2]
415 | if obs_dict["object1_pos"].ndim == 3
416 | else obs_dict["object1_pos"][2]
417 | )
418 | object2_pos = (
419 | obs_dict["object2_pos"][:, :, 2]
420 | if obs_dict["object2_pos"].ndim == 3
421 | else obs_dict["object2_pos"][2]
422 | )
423 | is_fall_1 = object1_pos < self.drop_th
424 | is_fall_2 = object2_pos < self.drop_th
425 | is_fall = np.logical_or(is_fall_1, is_fall_2) # keep both balls up
426 |
427 | rwd_dict = collections.OrderedDict(
428 | (
429 | # Perform reward tuning here --
430 | # Update Optional Keys section below
431 | # Update reward keys (DEFAULT_RWD_KEYS_AND_WEIGHTS) accordingly to update final rewards
432 | # Examples: Env comes pre-packaged with two keys pos_dist_1 and pos_dist_2
433 | # Optional Keys
434 | ("pos_dist_1", -1.0 * target1_dist),
435 | ("pos_dist_2", -1.0 * target2_dist),
436 | # Must keys
437 | ("act_reg", -1.0 * act_mag),
438 | ("alive", ~is_fall),
439 | ("sparse", -target_dist),
440 | (
441 | "solved",
442 | (target1_dist < self.proximity_th)
443 | * (target2_dist < self.proximity_th)
444 | * (~is_fall),
445 | ),
446 | ("done", is_fall),
447 | )
448 | )
449 | rwd_dict["dense"] = np.sum(
450 | [wt * rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0
451 | )
452 |
453 | # Success Indicator
454 | self.sim.model.geom_rgba[self.object1_gid, :2] = (
455 | np.array([1, 1])
456 | if target1_dist < self.proximity_th
457 | else np.array([0.5, 0.5])
458 | )
459 | self.sim.model.geom_rgba[self.object2_gid, :2] = (
460 | np.array([0.9, 0.7])
461 | if target2_dist < self.proximity_th
462 | else np.array([0.5, 0.5])
463 | )
464 |
465 | return rwd_dict
466 |
467 | def _add_noise_to_finger_positions(
468 | self, qpos: np.ndarray, noise: float = 1
469 | ) -> np.ndarray:
470 | assert 0 <= noise <= 1, "Noise parameter must be between 0 and 1"
471 |
472 | # thumb all joints
473 | qpos[4:7] = self.np_random.uniform(
474 | low=-np.pi / 18 * noise, high=np.pi / 18 * noise
475 | )
476 |
477 | # finger joints
478 | qpos[[7, 9, 10, 11, 13, 14, 15, 17, 18, 19, 21, 22]] = self.np_random.uniform(
479 | low=0, high=np.pi / 6 * noise
480 | )
481 |
482 | return qpos
483 |
484 | def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None):
485 |
486 | # reset task
487 | if self.task_choice == "random":
488 | self.which_task = self.np_random.choice(Task)
489 |
490 | if np.random.uniform(0, 1) <= self.overlap_probability:
491 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0
492 | elif self.limit_init_angle is not None:
493 | random_phase = self.np_random.uniform(
494 | low=-self.limit_init_angle, high=self.limit_init_angle
495 | )
496 |
497 | if self.beta_init_angle is not None:
498 | random_phase = (
499 | self.np_random.beta(
500 | self.beta_init_angle[0], self.beta_init_angle[1]
501 | )
502 | * 2
503 | * np.pi
504 | - np.pi
505 | )
506 |
507 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0 + random_phase
508 | else:
509 | self.ball_1_starting_angle = self.np_random.uniform(
510 | low=0, high=2 * np.pi
511 | )
512 |
513 | self.ball_2_starting_angle = self.ball_1_starting_angle - np.pi
514 | # reset counters
515 | self.counter = 0
516 | self.x_radius = self.np_random.uniform(
517 | low=self.goal_xrange[0], high=self.goal_xrange[1]
518 | )
519 | self.y_radius = self.np_random.uniform(
520 | low=self.goal_yrange[0], high=self.goal_yrange[1]
521 | )
522 |
523 | # reset goal
524 | if time_period is None:
525 | time_period = self.np_random.uniform(
526 | low=self.goal_time_period[0], high=self.goal_time_period[1]
527 | )
528 | self.goal = (
529 | self.create_goal_trajectory(time_step=self.dt, time_period=time_period)
530 | if reset_goal is None
531 | else reset_goal.copy()
532 | )
533 |
534 | # balls mass changes
535 | self.sim.model.body_mass[self.object1_bid] = self.np_random.uniform(
536 | **self.obj_mass_range
537 | ) # call to mj_setConst(m,d) is being ignored. Derived quantities won't be updated. The object is a simple shape, so this is a reasonable approximation.
538 | self.sim.model.body_mass[self.object2_bid] = self.np_random.uniform(
539 | **self.obj_mass_range
540 | ) # call to mj_setConst(m,d) is being ignored. Derived quantities won't be updated. The object is a simple shape, so this is a reasonable approximation.
541 |
542 | if self.beta_ball_mass is not None:
543 | self.sim.model.body_mass[self.object1_bid] = (
544 | self.np_random.beta(self.beta_ball_mass[0], self.beta_ball_mass[1])
545 | * (self.obj_mass_range["high"] - self.obj_mass_range["low"])
546 | + self.obj_mass_range["low"]
547 | )
548 | self.sim.model.body_mass[self.object2_bid] = (
549 | self.np_random.beta(self.beta_ball_mass[0], self.beta_ball_mass[1])
550 | * (self.obj_mass_range["high"] - self.obj_mass_range["low"])
551 | + self.obj_mass_range["low"]
552 | )
553 | # balls friction changes
554 | self.sim.model.geom_friction[self.object1_gid] = self.np_random.uniform(
555 | **self.obj_friction_range
556 | )
557 | self.sim.model.geom_friction[self.object2_gid] = self.np_random.uniform(
558 | **self.obj_friction_range
559 | )
560 |
561 | # balls size changes
562 | self.sim.model.geom_size[self.object1_gid] = self.np_random.uniform(
563 | **self.obj_size_range
564 | )
565 | self.sim.model.geom_size[self.object2_gid] = self.np_random.uniform(
566 | **self.obj_size_range
567 | )
568 |
569 | if self.beta_ball_size is not None:
570 | self.sim.model.geom_size[self.object1_gid] = (
571 | self.np_random.beta(self.beta_ball_size[0], self.beta_ball_size[1])
572 | * (self.obj_size_range["high"] - self.obj_size_range["low"])
573 | + self.obj_size_range["low"]
574 | )
575 | self.sim.model.geom_size[self.object2_gid] = (
576 | self.np_random.beta(self.beta_ball_size[0], self.beta_ball_size[1])
577 | * (self.obj_size_range["high"] - self.obj_size_range["low"])
578 | + self.obj_size_range["low"]
579 | )
580 | # reset scene
581 | qpos = self.init_qpos.copy() if reset_pose is None else reset_pose
582 | qvel = self.init_qvel.copy() if reset_vel is None else reset_vel
583 | self.robot.reset(qpos, qvel)
584 |
585 | if self.rsi and np.random.uniform(0, 1) < self.rsi_probability:
586 | random_phase = np.random.uniform(low=-np.pi, high=np.pi)
587 | self.ball_1_starting_angle = 3.0 * np.pi / 4.0 + random_phase
588 | self.ball_2_starting_angle = -1.0 * np.pi / 4.0 + random_phase
589 |
590 | # # reset scene (MODIFIED from base class MujocoEnv)
591 | self.robot.reset(qpos, qvel)
592 | self.step(np.zeros(39))
593 | # update ball positions
594 | obs_dict = self.get_obs_dict(self.sim)
595 | target_1_pos = obs_dict["target1_pos"]
596 | target_2_pos = obs_dict["target2_pos"]
597 | qpos[23] = target_1_pos[0] # ball 1 x-position
598 | qpos[24] = target_1_pos[1] # ball 1 y-position
599 | qpos[30] = target_2_pos[0] # ball 2 x-position
600 | qpos[31] = target_2_pos[1] # ball 2 y-position
601 | self.set_state(qpos, qvel)
602 |
603 | if self.balls_overlap is False:
604 | self.ball_1_starting_angle = self.np_random.uniform(
605 | low=0, high=2 * np.pi
606 | )
607 | self.ball_2_starting_angle = self.ball_1_starting_angle - np.pi
608 |
609 | if self.noise_fingers is not None:
610 | qpos = self._add_noise_to_finger_positions(qpos, self.noise_fingers)
611 | self.set_state(qpos, qvel)
612 |
613 | return self.get_obs()
614 |
615 | def step(self, action):
616 | obs, reward, done, info = super().step(action)
617 | info.update(info.get("rwd_dict"))
618 | return obs, reward, done, info
619 |
--------------------------------------------------------------------------------
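For reference, the `dense` reward assembled in `get_reward_dict` above is simply the weighted sum of the individual terms. A tiny sketch with the default weights of `CustomBaodingEnv` and hypothetical per-step values (not taken from a real rollout):

```python
# Default weights of CustomBaodingEnv and made-up per-step reward terms
weights = {"pos_dist_1": 5.0, "pos_dist_2": 5.0, "alive": 0.0, "act_reg": 0.0}
terms = {"pos_dist_1": -0.03, "pos_dist_2": -0.05, "alive": 1.0, "act_reg": -0.2}

dense = sum(w * terms[key] for key, w in weights.items())  # 5*(-0.03) + 5*(-0.05) = -0.4
```
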
/src/envs/environment_factory.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import pybullet_envs
3 |
4 |
5 | class EnvironmentFactory:
6 | """Static factory to instantiate and register gym environments by name."""
7 |
8 | @staticmethod
9 | def create(env_name, **kwargs):
10 | """Creates an environment given its name as a string, and forwards the kwargs
11 | to its __init__ function.
12 |
13 | Args:
14 | env_name (str): name of the environment
15 |
16 | Raises:
17 | ValueError: if the name of the environment is unknown
18 |
19 | Returns:
20 | gym.env: the selected environment
21 | """
22 | # make myosuite envs
23 | if env_name == "MyoFingerPoseRandom":
24 | return gym.make("myoFingerPoseRandom-v0")
25 | elif env_name == "MyoFingerReachRandom":
26 | return gym.make("myoFingerReachRandom-v0")
27 | elif env_name == "MyoHandReachRandom":
28 | return gym.make("myoHandReachRandom-v0")
29 | elif env_name == "MyoElbowReachRandom":
30 | return gym.make("myoElbowReachRandom-v0")
31 | elif env_name == "CustomMyoBaodingBallsP1":
32 | return gym.make("CustomMyoChallengeBaodingP1-v1", **kwargs)
33 | elif env_name == "CustomMyoReorientP2":
34 | return gym.make("CustomMyoChallengeDieReorientP2-v0", **kwargs)
35 | elif env_name == "CustomMyoElbowPoseRandom":
36 | return gym.make("CustomMyoElbowPoseRandom-v0", **kwargs)
37 | elif env_name == "CustomMyoFingerPoseRandom":
38 | return gym.make("CustomMyoFingerPoseRandom-v0", **kwargs)
39 | elif env_name == "CustomMyoHandPoseRandom":
40 | return gym.make("CustomMyoHandPoseRandom-v0", **kwargs)
41 | elif env_name == "CustomMyoPenTwirlRandom":
42 | return gym.make("CustomMyoHandPenTwirlRandom-v0", **kwargs)
43 | elif env_name == "WalkerBulletEnv":
44 | return gym.make("Walker2DBulletEnv-v0", **kwargs)
45 | elif env_name == "HalfCheetahBulletEnv":
46 | return gym.make("HalfCheetahBulletEnv-v0", **kwargs)
47 | elif env_name == "AntBulletEnv":
48 | return gym.make("AntBulletEnv-v0", **kwargs)
49 | elif env_name == "HopperBulletEnv":
50 | return gym.make("HopperBulletEnv-v0", **kwargs)
51 | elif env_name == "HumanoidBulletEnv":
52 | return gym.make("HumanoidBulletEnv-v0", **kwargs)
53 | else:
54 | raise ValueError("Environment name not recognized:", env_name)
55 |
--------------------------------------------------------------------------------
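A usage sketch of the factory (an illustration, assuming `src/` is on `PYTHONPATH` so that importing `envs` runs the registrations in `src/envs/__init__.py`; the keyword argument shown is one of the options accepted by `CustomBaodingEnv._setup`):

```python
from envs.environment_factory import EnvironmentFactory

# Keyword arguments are forwarded to the environment constructor via gym.make
env = EnvironmentFactory.create("CustomMyoBaodingBallsP1", enable_rsi=True)
obs = env.reset()
```
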
/src/envs/pen.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import numpy as np
4 | from myosuite.envs.env_base import MujocoEnv
5 | from myosuite.envs.myo.base_v0 import BaseV0
6 | from myosuite.envs.myo.pen_v0 import PenTwirlRandomEnvV0
7 | from myosuite.utils.quat_math import euler2quat
8 | from myosuite.utils.vector_math import calculate_cosine
9 |
10 |
11 | class CustomPenEnv(PenTwirlRandomEnvV0):
12 | def get_reward_dict(self, obs_dict):
13 | pos_err = obs_dict["obj_err_pos"]
14 | pos_align = np.linalg.norm(pos_err, axis=-1)
15 | rot_align = calculate_cosine(obs_dict["obj_rot"], obs_dict["obj_des_rot"])
16 | dropped = pos_align > 0.075
17 | act_mag = (
18 | np.linalg.norm(self.obs_dict["act"], axis=-1) / self.sim.model.na
19 | if self.sim.model.na != 0
20 | else 0
21 | )
22 | pos_align_diff = self.pos_align - pos_align # should decrease
23 | rot_align_diff = rot_align - self.rot_align # should increase
24 | alive = ~dropped
25 |
26 | rwd_dict = collections.OrderedDict(
27 | (
28 | # Optional Keys
29 | ("pos_align", -1.0 * pos_align),
30 | ("rot_align", rot_align),
31 | ("pos_align_diff", pos_align_diff),
32 | ("rot_align_diff", rot_align_diff),
33 | ("alive", alive),
34 | ("act_reg", -1.0 * act_mag),
35 | ("drop", -1.0 * dropped),
36 | (
37 | "bonus",
38 | 1.0 * (rot_align > 0.9) * (pos_align < 0.075)
39 | + 5.0 * (rot_align > 0.95) * (pos_align < 0.075),
40 | ),
41 | # Must keys
42 | ("sparse", -1.0 * pos_align + rot_align),
43 | ("solved", (rot_align > 0.95) * (~dropped)),
44 | ("done", dropped),
45 | )
46 | )
47 | rwd_dict["dense"] = np.sum(
48 | [wt * rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0
49 | )
50 | return rwd_dict
51 |
52 | def _setup(
53 | self,
54 | obs_keys: list = PenTwirlRandomEnvV0.DEFAULT_OBS_KEYS,
55 | weighted_reward_keys: list = PenTwirlRandomEnvV0.DEFAULT_RWD_KEYS_AND_WEIGHTS,
56 | goal_orient_range=(
57 | -1,
58 | 1,
59 | ), # can be used to make the task simpler and limit the target orientations
60 | enable_rsi=False,
61 | rsi_distance=0,
62 | **kwargs,
63 | ):
64 | self.target_obj_bid = self.sim.model.body_name2id("target")
65 | self.S_grasp_sid = self.sim.model.site_name2id("S_grasp")
66 | self.obj_bid = self.sim.model.body_name2id("Object")
67 | self.eps_ball_sid = self.sim.model.site_name2id("eps_ball")
68 | self.obj_t_sid = self.sim.model.site_name2id("object_top")
69 | self.obj_b_sid = self.sim.model.site_name2id("object_bottom")
70 | self.tar_t_sid = self.sim.model.site_name2id("target_top")
71 | self.tar_b_sid = self.sim.model.site_name2id("target_bottom")
72 | self.pen_length = np.linalg.norm(
73 | self.sim.model.site_pos[self.obj_t_sid]
74 | - self.sim.model.site_pos[self.obj_b_sid]
75 | )
76 | self.tar_length = np.linalg.norm(
77 | self.sim.model.site_pos[self.tar_t_sid]
78 | - self.sim.model.site_pos[self.tar_b_sid]
79 | )
80 |
81 | self.goal_orient_range = goal_orient_range
82 | self.rsi = enable_rsi
83 | self.rsi_distance = rsi_distance
84 | self.pos_align = 0
85 | self.rot_align = 0
86 |
87 | BaseV0._setup(
88 | self,
89 | obs_keys=obs_keys,
90 | weighted_reward_keys=weighted_reward_keys,
91 | **kwargs,
92 | )
93 | self.init_qpos[:-6] *= 0 # Use fully open as init pos
94 | self.init_qpos[0] = -1.5 # place palm up
95 |
96 | def reset(self):
97 | # randomize target
98 | desired_orien = np.zeros(3)
99 | desired_orien[0] = self.np_random.uniform(
100 | low=self.goal_orient_range[0], high=self.goal_orient_range[1]
101 | )
102 | desired_orien[1] = self.np_random.uniform(
103 | low=self.goal_orient_range[0], high=self.goal_orient_range[1]
104 | )
105 | self.sim.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien)
106 |
107 | if self.rsi:
108 | init_orien = np.zeros(3)
109 | init_orien[:2] = desired_orien[:2] + self.rsi_distance * (
110 | init_orien[:2] - desired_orien[:2]
111 | )
112 | self.sim.model.body_quat[self.obj_bid] = euler2quat(init_orien)
113 |
114 | self.robot.sync_sims(self.sim, self.sim_obsd)
115 | obs = MujocoEnv.reset(self)
116 |
117 | self.pos_align = np.linalg.norm(self.obs_dict["obj_err_pos"], axis=-1)
118 | self.rot_align = calculate_cosine(
119 | self.obs_dict["obj_rot"], self.obs_dict["obj_des_rot"]
120 | )
121 |
122 | return obs
123 |
124 | def step(self, a):
125 | obs, reward, done, info = super().step(a)
126 | self.pos_align = np.linalg.norm(self.obs_dict["obj_err_pos"], axis=-1)
127 | self.rot_align = calculate_cosine(
128 | self.obs_dict["obj_rot"], self.obs_dict["obj_des_rot"]
129 | )
130 | info.update(info.get("rwd_dict"))
131 | return obs, reward, done, info
132 |
133 | def render(self, mode="human"):
134 | return self.sim.render(mode=mode)
135 |
--------------------------------------------------------------------------------
/src/envs/pose.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from myosuite.envs.myo.base_v0 import BaseV0
3 | from myosuite.envs.myo.pose_v0 import PoseEnvV0
4 |
5 |
6 | class CustomPoseEnv(PoseEnvV0):
7 | def _setup(
8 | self,
9 | viz_site_targets: tuple = None, # site to use for targets visualization []
10 | target_jnt_range: dict = None, # joint ranges as tuples {name:(min, max)}_nq
11 | target_jnt_value: list = None, # desired joint vector [des_qpos]_nq
12 | reset_type="init", # none; init; random; sds
13 | target_type="generate", # generate; switch; fixed
14 | obs_keys: list = PoseEnvV0.DEFAULT_OBS_KEYS,
15 | weighted_reward_keys: dict = PoseEnvV0.DEFAULT_RWD_KEYS_AND_WEIGHTS,
16 | pose_thd=0.35,
17 | weight_bodyname=None,
18 | weight_range=None,
19 | sds_distance=0,
20 | target_distance=1, # for non-SDS curriculum, the target is set at a fraction of the full distance
21 | **kwargs,
22 | ):
23 | self.reset_type = reset_type
24 | self.target_type = target_type
25 | self.pose_thd = pose_thd
26 | self.weight_bodyname = weight_bodyname
27 | self.weight_range = weight_range
28 | self.sds_distance = sds_distance
29 | self.target_distance = target_distance
30 |
31 | # resolve joint demands
32 | if target_jnt_range:
33 | self.target_jnt_ids = []
34 | self.target_jnt_range = []
35 | for jnt_name, jnt_range in target_jnt_range.items():
36 | self.target_jnt_ids.append(self.sim.model.joint_name2id(jnt_name))
37 | self.target_jnt_range.append(jnt_range)
38 | self.target_jnt_range = np.array(self.target_jnt_range)
39 | self.target_jnt_value = np.mean(
40 | self.target_jnt_range, axis=1
41 | ) # pseudo targets for init
42 | else:
43 | self.target_jnt_value = target_jnt_value
44 |
45 | BaseV0._setup(
46 | self,
47 | obs_keys=obs_keys,
48 | weighted_reward_keys=weighted_reward_keys,
49 | sites=viz_site_targets,
50 | **kwargs,
51 | )
52 |
53 | def reset(self):
54 | # update weight
55 | if self.weight_bodyname is not None:
56 | bid = self.sim.model.body_name2id(self.weight_bodyname)
57 | gid = self.sim.model.body_geomadr[bid]
58 | weight = self.np_random.uniform(
59 | low=self.weight_range[0], high=self.weight_range[1]
60 | )
61 | self.sim.model.body_mass[bid] = weight
62 | self.sim_obsd.model.body_mass[bid] = weight
63 | # self.sim_obsd.model.geom_size[gid] = self.sim.model.geom_size[gid] * weight/10
64 | self.sim.model.geom_size[gid][0] = 0.01 + 2.5 * weight / 100
65 | # self.sim_obsd.model.geom_size[gid][0] = weight/10
66 |
67 | # update target
68 | if self.target_type == "generate":
69 | # use target_jnt_range to generate targets
70 | self.update_target(restore_sim=True)
71 | elif self.target_type == "fixed":
72 | self.update_target(restore_sim=True)
73 | else:
74 | print("{} Target Type not found ".format(self.target_type))
75 |
76 | # update init state
77 | if self.reset_type is None or self.reset_type == "none":
78 | # no reset; use last state
79 | obs = self.get_obs()
80 | elif self.reset_type == "init":
81 | # reset to init state
82 | obs = BaseV0.reset(self)
83 | elif self.reset_type == "random":
84 | # reset to random state
85 | jnt_init = self.np_random.uniform(
86 | high=self.sim.model.jnt_range[:, 1], low=self.sim.model.jnt_range[:, 0]
87 | )
88 | obs = BaseV0.reset(self, reset_qpos=jnt_init)
89 | elif self.reset_type == "sds":
90 | init_qpos = self.init_qpos.copy()
91 | init_qvel = self.init_qvel.copy()
92 | target_qpos = self.target_jnt_value.copy()
93 | qpos = (1 - self.sds_distance) * target_qpos + self.sds_distance * init_qpos
94 | self.robot.reset(qpos, init_qvel)
95 | obs = self.get_obs()
96 | else:
97 | print("Reset Type not found")
98 |
99 | return obs
100 |
101 | def step(self, action):
102 | obs, reward, done, info = super().step(action)
103 | info.update(info.get("rwd_dict"))
104 | return obs, reward, done, info
105 |
106 | def render(self, mode):
107 | return self.sim.render(mode=mode)
108 |
109 | def get_target_pose(self):
110 | full_distance_target_pose = super().get_target_pose()
111 | init_pose = self.init_qpos.copy()
112 | target_pose = init_pose + self.target_distance * (
113 | full_distance_target_pose - init_pose
114 | )
115 | return target_pose
116 |
--------------------------------------------------------------------------------
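
The target_distance curriculum in CustomPoseEnv above scales the generated target toward the initial pose: get_target_pose returns init_qpos + target_distance * (full_target - init_qpos), so target_distance=1 recovers the full-distance target and smaller values yield easier intermediate targets. A minimal sketch of that interpolation, with made-up joint values:

    import numpy as np

    init_pose = np.array([0.0, 0.1])      # stand-in for init_qpos
    full_target = np.array([1.0, -0.3])   # stand-in for the full-distance target pose
    target_distance = 0.5                 # curriculum fraction, as passed to _setup

    target_pose = init_pose + target_distance * (full_target - init_pose)
    print(target_pose)                    # [ 0.5 -0.1]
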
/src/envs/reorient.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=attribute-defined-outside-init, dangerous-default-value, protected-access, abstract-method, arguments-renamed
2 | import collections
3 | import numpy as np
4 | from myosuite.envs.env_base import MujocoEnv
5 | from myosuite.envs.myo.base_v0 import BaseV0
6 | from myosuite.envs.myo.myochallenge.reorient_v0 import ReorientEnvV0
7 | from myosuite.utils.quat_math import euler2quat
8 |
9 |
10 | class CustomReorientEnv(ReorientEnvV0):
11 | def get_reward_dict(self, obs_dict):
12 | pos_dist_new = np.abs(np.linalg.norm(self.obs_dict["pos_err"], axis=-1))
13 | rot_dist_new = np.abs(np.linalg.norm(self.obs_dict["rot_err"], axis=-1))
14 | pos_dist_diff = self.pos_dist - pos_dist_new
15 | rot_dist_diff = self.rot_dist - rot_dist_new
16 | act_mag = (
17 | np.linalg.norm(self.obs_dict["act"], axis=-1) / self.sim.model.na
18 | if self.sim.model.na != 0
19 | else 0
20 | )
21 | drop = pos_dist_new > self.drop_th
22 |
23 | rwd_dict = collections.OrderedDict(
24 | (
25 | # Perform reward tuning here --
26 | # Update Optional Keys section below
27 | # Update reward keys (DEFAULT_RWD_KEYS_AND_WEIGHTS) accordingly so they contribute to the final reward
28 | # Example: the env comes pre-packaged with the two keys pos_dist and rot_dist
29 | # Optional Keys
30 | ("pos_dist", -1.0 * pos_dist_new),
31 | ("rot_dist", -1.0 * rot_dist_new),
32 | ("pos_dist_diff", pos_dist_diff),
33 | ("rot_dist_diff", rot_dist_diff),
34 | ("alive", ~drop),
35 | # Must keys
36 | ("act_reg", -1.0 * act_mag),
37 | ("sparse", -rot_dist_new - 10.0 * pos_dist_new),
38 | (
39 | "solved",
40 | (
41 | (pos_dist_new < self.pos_th)
42 | and (rot_dist_new < self.rot_th)
43 | and (not drop)
44 | )
45 | * np.ones((1, 1)),
46 | ),
47 | ("done", drop),
48 | )
49 | )
50 | rwd_dict["dense"] = np.sum(
51 | [wt * rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0
52 | )
53 |
54 | # Success indicator: color the target site green when solved, red otherwise
55 | self.sim.model.site_rgba[self.success_indicator_sid, :2] = (
56 | np.array([0, 2]) if rwd_dict["solved"] else np.array([2, 0])
57 | )
58 | return rwd_dict
59 |
60 | def _setup(
61 | self,
62 | obs_keys: list = ReorientEnvV0.DEFAULT_OBS_KEYS,
63 | weighted_reward_keys: list = ReorientEnvV0.DEFAULT_RWD_KEYS_AND_WEIGHTS,
64 | goal_pos=(0.0, 0.0), # goal position range (relative to initial pos)
65 | goal_rot=(0.785, 0.785), # goal rotation range (relative to initial rot)
66 | obj_size_change=0, # object size change (relative to initial size)
67 | obj_friction_change=(
68 | 0,
69 | 0,
70 | 0,
71 | ), # object friction change (relative to initial friction)
72 | pos_th=0.025, # position error threshold
73 | rot_th=0.262, # rotation error threshold
74 | drop_th=0.200, # drop height threshold
75 | enable_rsi=False,
76 | rsi_distance_pos=0,
77 | rsi_distance_rot=0,
78 | goal_rot_x=None,
79 | goal_rot_y=None,
80 | goal_rot_z=None,
81 | guided_trajectory_steps=None,
82 | **kwargs,
83 | ):
84 | self.already_reset = False
85 | self.object_sid = self.sim.model.site_name2id("object_o")
86 | self.goal_sid = self.sim.model.site_name2id("target_o")
87 | self.success_indicator_sid = self.sim.model.site_name2id("target_ball")
88 | self.goal_bid = self.sim.model.body_name2id("target")
89 | self.object_bid = self.sim.model.body_name2id("Object")
90 | self.goal_init_pos = self.sim.data.site_xpos[self.goal_sid].copy()
91 | self.goal_init_rot = self.sim.model.body_quat[self.goal_bid].copy()
92 | self.goal_obj_offset = (
93 | self.sim.data.site_xpos[self.goal_sid]
94 | - self.sim.data.site_xpos[self.object_sid]
95 | ) # visualization offset between target and object
96 | self.goal_pos = goal_pos
97 | self.goal_rot = goal_rot
98 | self.pos_th = pos_th
99 | self.rot_th = rot_th
100 | self.drop_th = drop_th
101 | self.rsi = enable_rsi
102 | self.rsi_distance_pos = rsi_distance_pos
103 | self.rsi_distance_rot = rsi_distance_rot
104 | self.goal_rot_x = goal_rot_x
105 | self.goal_rot_y = goal_rot_y
106 | self.goal_rot_z = goal_rot_z
107 | self.guided_trajectory_steps = guided_trajectory_steps
108 | self.pos_dist = 0
109 | self.rot_dist = 0
110 |
111 | # setup for object randomization
112 | self.target_gid = self.sim.model.geom_name2id("target_dice")
113 | self.target_default_size = self.sim.model.geom_size[self.target_gid].copy()
114 |
115 | object_bid = self.sim.model.body_name2id("Object")
116 | self.object_gid0 = self.sim.model.body_geomadr[object_bid]
117 | self.object_gidn = self.object_gid0 + self.sim.model.body_geomnum[object_bid]
118 | self.object_default_size = self.sim.model.geom_size[
119 | self.object_gid0 : self.object_gidn
120 | ].copy()
121 | self.object_default_pos = self.sim.model.geom_pos[
122 | self.object_gid0 : self.object_gidn
123 | ].copy()
124 |
125 | self.obj_size_change = {"high": obj_size_change, "low": -obj_size_change}
126 | self.obj_friction_range = {
127 | "high": self.sim.model.geom_friction[self.object_gid0 : self.object_gidn]
128 | + obj_friction_change,
129 | "low": self.sim.model.geom_friction[self.object_gid0 : self.object_gidn]
130 | - obj_friction_change,
131 | }
132 |
133 | BaseV0._setup(
134 | self,
135 | obs_keys=obs_keys,
136 | weighted_reward_keys=weighted_reward_keys,
137 | **kwargs,
138 | )
139 | self.init_qpos[:-7] *= 0 # Use fully open as init pos
140 | self.init_qpos[0] = -1.5 # Palm up
141 |
142 | def reset(self, reset_qpos=None, reset_qvel=None):
143 |
144 | # First sample the target position and orientation of the die
145 | self.episode_goal_pos = self.sample_goal_position()
146 | self.episode_goal_rot = self.sample_goal_orientation()
147 |
148 | # Then get the initial position and orientation of the die
149 | if self.rsi:
150 | object_init_pos = (
151 | self.rsi_distance_pos * self.goal_init_pos
152 | + (1 - self.rsi_distance_pos) * self.episode_goal_pos
153 | - self.goal_obj_offset
154 | )
155 | object_init_rot = (
156 | self.rsi_distance_rot * self.goal_init_rot
157 | + (1 - self.rsi_distance_rot) * self.episode_goal_rot
158 | )
159 | else:
160 | object_init_pos = self.goal_init_pos - self.goal_obj_offset
161 | object_init_rot = self.goal_init_rot
162 |
163 | # Set the position of the object
164 | self.sim.model.body_pos[self.object_bid] = object_init_pos
165 | self.sim.model.body_quat[self.object_bid] = object_init_rot
166 |
167 | # Create the target trajectory and set the initial position
168 | self.goal_pos_traj, self.goal_rot_traj = self.create_goal_trajectory(
169 | object_init_pos, object_init_rot
170 | )
171 | self.counter = 0
172 | self.set_die_pos_rot(self.counter)
173 |
174 | # Die friction changes
175 | self.sim.model.geom_friction[
176 | self.object_gid0 : self.object_gidn
177 | ] = self.np_random.uniform(**self.obj_friction_range)
178 |
179 | # Die and Target size changes
180 | del_size = self.np_random.uniform(**self.obj_size_change)
181 | # adjust size of target
182 | self.sim.model.geom_size[self.target_gid] = self.target_default_size + del_size
183 | # adjust size of die
184 | self.sim.model.geom_size[self.object_gid0 : self.object_gidn - 3][:, 1] = (
185 | self.object_default_size[:-3][:, 1] + del_size
186 | )
187 | self.sim.model.geom_size[self.object_gidn - 3 : self.object_gidn] = (
188 | self.object_default_size[-3:] + del_size
189 | )
190 | # adjust boundary of die
191 | object_gpos = self.sim.model.geom_pos[self.object_gid0 : self.object_gidn]
192 | self.sim.model.geom_pos[self.object_gid0 : self.object_gidn] = (
193 | object_gpos
194 | / abs(object_gpos + 1e-16)
195 | * (abs(self.object_default_pos) + del_size)
196 | )
197 |
198 | obs = MujocoEnv.reset(self, reset_qpos, reset_qvel)
199 | self.pos_dist = np.abs(np.linalg.norm(self.obs_dict["pos_err"], axis=-1))
200 | self.rot_dist = np.abs(np.linalg.norm(self.obs_dict["rot_err"], axis=-1))
201 | self.already_reset = True
202 | return obs
203 |
204 | def sample_goal_position(self):
205 | goal_pos = self.goal_init_pos + self.np_random.uniform(
206 | high=self.goal_pos[1], low=self.goal_pos[0], size=3
207 | )
208 | return goal_pos
209 |
210 | def sample_goal_orientation(self):
211 | x_low, x_high = (
212 | self.goal_rot_x[self.np_random.choice(len(self.goal_rot_x))]
213 | if self.goal_rot_x is not None
214 | else self.goal_rot
215 | )
216 | y_low, y_high = (
217 | self.goal_rot_y[self.np_random.choice(len(self.goal_rot_y))]
218 | if self.goal_rot_y is not None
219 | else self.goal_rot
220 | )
221 | z_low, z_high = (
222 | self.goal_rot_z[self.np_random.choice(len(self.goal_rot_z))]
223 | if self.goal_rot_z is not None
224 | else self.goal_rot
225 | )
226 |
227 | goal_rot_x = self.np_random.uniform(x_low, x_high)
228 | goal_rot_y = self.np_random.uniform(y_low, y_high)
229 | goal_rot_z = self.np_random.uniform(z_low, z_high)
230 | goal_rot_quat = euler2quat(np.array([goal_rot_x, goal_rot_y, goal_rot_z]))
231 | return goal_rot_quat
232 |
233 | def step(self, action):
234 | obs, reward, done, info = super().step(action)
235 | self.pos_dist = np.abs(np.linalg.norm(self.obs_dict["pos_err"], axis=-1))
236 | self.rot_dist = np.abs(np.linalg.norm(self.obs_dict["rot_err"], axis=-1))
237 | info.update(info.get("rwd_dict"))
238 |
239 | if self.already_reset:
240 | self.counter += 1
241 | self.set_die_pos_rot(self.counter)
242 | return obs, reward, done, info
243 |
244 | def create_goal_trajectory(self, object_init_pos, object_init_rot):
245 | traj_len = 1000 # Assumes it is larger than the episode len
246 |
247 | pos_traj = np.ones((traj_len, 3)) * self.episode_goal_pos
248 | rot_traj = np.ones((traj_len, 4)) * self.episode_goal_rot
249 |
250 | if (
251 | self.guided_trajectory_steps is not None
252 | ): # Softly reach the target position and orientation
253 | goal_init_pos = object_init_pos + self.goal_obj_offset
254 | guided_pos_traj = np.linspace(
255 | goal_init_pos, self.episode_goal_pos, self.guided_trajectory_steps
256 | )
257 | pos_traj[: self.guided_trajectory_steps] = guided_pos_traj
258 |
259 | guided_rot_traj = np.linspace(
260 | object_init_rot, self.episode_goal_rot, self.guided_trajectory_steps
261 | )
262 | rot_traj[: self.guided_trajectory_steps] = guided_rot_traj
263 | return pos_traj, rot_traj
264 |
265 | def set_die_pos_rot(self, counter):
266 | self.sim.model.body_pos[self.goal_bid] = self.goal_pos_traj[counter]
267 | self.sim.model.body_quat[self.goal_bid] = self.goal_rot_traj[counter]
268 |
--------------------------------------------------------------------------------
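
CustomReorientEnv above supports reference-state initialization (RSI): when enable_rsi is set, reset places the die on a convex combination of the default goal pose and the sampled episode goal, controlled by rsi_distance_pos / rsi_distance_rot (0 starts the object at the goal, 1 at the default initial pose). A minimal sketch of the positional blend, with made-up coordinates:

    import numpy as np

    goal_init_pos = np.array([0.00, 0.00, 0.20])     # stand-in for the default goal site position
    episode_goal_pos = np.array([0.05, 0.00, 0.25])  # stand-in for the sampled episode goal
    goal_obj_offset = np.array([0.00, 0.00, 0.05])   # stand-in for the goal/object visualization offset
    rsi_distance_pos = 0.3

    object_init_pos = (
        rsi_distance_pos * goal_init_pos
        + (1 - rsi_distance_pos) * episode_goal_pos
        - goal_obj_offset
    )
    print(object_init_pos)                           # [0.035 0.    0.185]
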
/src/main_ant.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.sac.policies import LatticeSACPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 | args = parser.parse_args()
41 |
42 | # define constants
43 | ENV_NAME = "AntBulletEnv"
44 |
45 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
46 |
47 | if args.model_path is not None:
48 | model_name = args.model_path.split("/")[-2]
49 | else:
50 | model_name = None
51 |
52 | TENSORBOARD_LOG = (
53 | os.path.join(ROOT_DIR, "output", "training", now)
54 | + f"_ant_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}"
55 | )
56 |
57 | # Reward structure and task parameters:
58 | config = {
59 | }
60 |
61 | max_episode_steps = 1000
62 |
63 | model_config = dict(
64 | policy=LatticeSACPolicy,
65 | device=args.device,
66 | learning_rate=3e-4,
67 | buffer_size=300_000,
68 | learning_starts=10000,
69 | batch_size=256,
70 | tau=0.02,
71 | gamma=0.98,
72 | train_freq=(8, "step"),
73 | gradient_steps=8,
74 | action_noise=None,
75 | replay_buffer_class=None,
76 | ent_coef="auto",
77 | target_update_interval=1,
78 | target_entropy="auto",
79 | seed=args.seed,
80 | use_sde=args.use_sde,
81 | sde_sample_freq=args.freq,
82 | policy_kwargs=dict(
83 | use_lattice=args.use_lattice,
84 | use_expln=True,
85 | log_std_init=args.log_std_init,
86 | activation_fn=nn.GELU,
87 | net_arch=dict(pi=[400, 300], qf=[400, 300]),
88 | std_clip=(1e-3, 1),
89 | expln_eps=1e-6,
90 | clip_mean=2.0,
91 | std_reg=args.std_reg
92 | ),
93 | )
94 |
95 | # Function that creates and monitors vectorized environments:
96 | def make_parallel_envs(env_config, num_env, start_index=0):
97 | def make_env(_):
98 | def _thunk():
99 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
100 | env.seed(args.seed)
101 | env._max_episode_steps = max_episode_steps
102 | env = Monitor(env, TENSORBOARD_LOG)
103 | return env
104 |
105 | return _thunk
106 |
107 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
108 |
109 |
110 | if __name__ == "__main__":
111 | # ensure tensorboard log directory exists and copy this file to track
112 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
113 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
114 |
115 | # Create and wrap the training and evaluation environments
116 | envs = make_parallel_envs(config, args.num_envs)
117 |
118 | if args.env_path is not None:
119 | envs = VecNormalize.load(args.env_path, envs)
120 | else:
121 | envs = VecNormalize(envs)
122 |
123 | # Define callbacks for evaluation and saving the agent
124 | eval_callback = EvalCallback(
125 | eval_env=envs,
126 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
127 | n_eval_episodes=10,
128 | best_model_save_path=TENSORBOARD_LOG,
129 | log_path=TENSORBOARD_LOG,
130 | eval_freq=10_000,
131 | deterministic=True,
132 | render=False,
133 | verbose=1,
134 | )
135 |
136 | checkpoint_callback = CheckpointCallback(
137 | save_freq=25_000,
138 | save_path=TENSORBOARD_LOG,
139 | save_vecnormalize=True,
140 | verbose=1,
141 | )
142 |
143 | tensorboard_callback = TensorboardCallback(
144 | info_keywords=(
145 | )
146 | )
147 |
148 | # Define trainer
149 | trainer = MyoTrainer(
150 | algo="sac",
151 | envs=envs,
152 | env_config=config,
153 | load_model_path=args.model_path,
154 | log_dir=TENSORBOARD_LOG,
155 | model_config=model_config,
156 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
157 | timesteps=10_000_000,
158 | )
159 |
160 | # Train agent
161 | trainer.train(total_timesteps=trainer.timesteps)
162 | trainer.save()
163 |
--------------------------------------------------------------------------------
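
The script above saves the best model via EvalCallback, VecNormalize statistics via EnvDumpCallback / save_vecnormalize=True, and periodic checkpoints via CheckpointCallback, all under TENSORBOARD_LOG; --model_path and --env_path feed such files back in to resume training. A minimal evaluation sketch along the same lines, using the standard Stable-Baselines3 loaders (the file paths below are placeholders, and the exact file names depend on the callbacks):

    from stable_baselines3 import SAC
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

    from envs.environment_factory import EnvironmentFactory

    venv = DummyVecEnv([lambda: EnvironmentFactory.create("AntBulletEnv")])
    venv = VecNormalize.load("path/to/vecnormalize.pkl", venv)  # placeholder path
    venv.training = False      # freeze the observation/reward normalization statistics
    venv.norm_reward = False   # report raw rewards during evaluation

    model = SAC.load("path/to/best_model.zip", env=venv)        # placeholder path
    obs = venv.reset()
    for _ in range(1000):
        action, _ = model.predict(obs, deterministic=True)
        obs, rewards, dones, infos = venv.step(action)
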
/src/main_baoding.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 |
41 | args = parser.parse_args()
42 |
43 | # define constants
44 | ENV_NAME = "CustomMyoBaodingBallsP1"
45 |
46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
47 |
48 | if args.model_path is not None:
49 | model_name = args.model_path.split("/")[-2]
50 | else:
51 | model_name = None
52 |
53 | TENSORBOARD_LOG = (
54 | os.path.join(ROOT_DIR, "output", "training", now)
55 | + f"_baoding_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}"
56 | )
57 |
58 | # Reward structure and task parameters:
59 | config = {
60 | "seed": args.seed,
61 | "weighted_reward_keys": {
62 | "pos_dist_1": 1,
63 | "pos_dist_2": 1,
64 | "act_reg": 0,
65 | "alive": 1,
66 | "solved": 5,
67 | "done": 0,
68 | "sparse": 0,
69 | },
70 | }
71 |
72 | max_episode_steps = 200
73 |
74 | model_config = dict(
75 | policy=LatticeRecurrentActorCriticPolicy,
76 | device=args.device,
77 | batch_size=32,
78 | n_steps=128,
79 | learning_rate=2.55673e-05,
80 | ent_coef=3.62109e-06,
81 | clip_range=0.3,
82 | gamma=0.99,
83 | gae_lambda=0.9,
84 | max_grad_norm=0.7,
85 | vf_coef=0.835671,
86 | n_epochs=10,
87 | use_sde=args.use_sde,
88 | sde_sample_freq=args.freq,
89 | policy_kwargs=dict(
90 | use_lattice=args.use_lattice,
91 | use_expln=True,
92 | ortho_init=False,
93 | log_std_init=args.log_std_init,
94 | activation_fn=nn.ReLU,
95 | net_arch=[dict(pi=[256, 256], vf=[256, 256])],
96 | std_clip=(1e-3, 10),
97 | expln_eps=1e-6,
98 | full_std=False,
99 | std_reg=args.std_reg
100 | ),
101 | )
102 |
103 | # Function that creates and monitors vectorized environments:
104 | def make_parallel_envs(env_config, num_env, start_index=0):
105 | def make_env(_):
106 | def _thunk():
107 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
108 | env.seed(args.seed)
109 | env._max_episode_steps = max_episode_steps
110 | env = Monitor(env, TENSORBOARD_LOG)
111 | return env
112 |
113 | return _thunk
114 |
115 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
116 |
117 |
118 | if __name__ == "__main__":
119 | # ensure tensorboard log directory exists and copy this file to track
120 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
121 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
122 |
123 | # Create and wrap the training and evaluation environments
124 | envs = make_parallel_envs(config, args.num_envs)
125 |
126 | if args.env_path is not None:
127 | envs = VecNormalize.load(args.env_path, envs)
128 | else:
129 | envs = VecNormalize(envs)
130 |
131 | # Define callbacks for evaluation and saving the agent
132 | eval_callback = EvalCallback(
133 | eval_env=envs,
134 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
135 | n_eval_episodes=10,
136 | best_model_save_path=TENSORBOARD_LOG,
137 | log_path=TENSORBOARD_LOG,
138 | eval_freq=10_000,
139 | deterministic=True,
140 | render=False,
141 | verbose=1,
142 | )
143 |
144 | checkpoint_callback = CheckpointCallback(
145 | save_freq=25_000,
146 | save_path=TENSORBOARD_LOG,
147 | save_vecnormalize=True,
148 | verbose=1,
149 | )
150 |
151 | tensorboard_callback = TensorboardCallback(
152 | info_keywords=(
153 | "pos_dist_1",
154 | "pos_dist_2",
155 | "act_reg",
156 | "alive",
157 | "solved",
158 | )
159 | )
160 |
161 | # Define trainer
162 | trainer = MyoTrainer(
163 | algo="recurrent_ppo",
164 | envs=envs,
165 | env_config=config,
166 | load_model_path=args.model_path,
167 | log_dir=TENSORBOARD_LOG,
168 | model_config=model_config,
169 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
170 | timesteps=20_000_000,
171 | )
172 |
173 | # Train agent
174 | trainer.train(total_timesteps=trainer.timesteps)
175 | trainer.save()
176 |
--------------------------------------------------------------------------------
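
In the config above, weighted_reward_keys sets the per-term weights that the environment folds into its dense reward, following the same pattern as CustomReorientEnv.get_reward_dict above (dense = sum of weight * reward term at every step); main_baoding.py then logs the same keys through TensorboardCallback. A small sketch of that weighted sum, with made-up per-step reward values:

    weights = {"pos_dist_1": 1, "pos_dist_2": 1, "act_reg": 0,
               "alive": 1, "solved": 5, "done": 0, "sparse": 0}
    rwd_terms = {"pos_dist_1": -0.04, "pos_dist_2": -0.06, "act_reg": -0.30,
                 "alive": 1.0, "solved": 0.0, "done": 0.0, "sparse": 0.0}  # made-up values
    dense = sum(w * rwd_terms[k] for k, w in weights.items())
    print(round(dense, 3))  # 0.9
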
/src/main_half_cheetah.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.sac.policies import LatticeSACPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 | args = parser.parse_args()
41 |
42 | # define constants
43 | ENV_NAME = "HalfCheetahBulletEnv"
44 |
45 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
46 |
47 | if args.model_path is not None:
48 | model_name = args.model_path.split("/")[-2]
49 | else:
50 | model_name = None
51 |
52 | TENSORBOARD_LOG = (
53 | os.path.join(ROOT_DIR, "output", "training", now)
54 | + f"_half_cheetah_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}"
55 | )
56 |
57 | # Reward structure and task parameters:
58 | config = {
59 | }
60 |
61 | max_episode_steps = 1000
62 | num_envs = args.num_envs
63 |
64 | model_config = dict(
65 | policy=LatticeSACPolicy,
66 | device=args.device,
67 | learning_rate=3e-4,
68 | buffer_size=300_000,
69 | learning_starts=10000,
70 | batch_size=256,
71 | tau=0.02,
72 | gamma=0.98,
73 | train_freq=(8, "step"),
74 | gradient_steps=8,
75 | action_noise=None,
76 | replay_buffer_class=None,
77 | ent_coef="auto",
78 | target_update_interval=1,
79 | target_entropy="auto",
80 | seed=args.seed,
81 | use_sde=args.use_sde,
82 | sde_sample_freq=args.freq,
83 | policy_kwargs=dict(
84 | use_lattice=args.use_lattice,
85 | use_expln=True,
86 | log_std_init=args.log_std_init,
87 | activation_fn=nn.GELU,
88 | net_arch=dict(pi=[400, 300], qf=[400, 300]),
89 | std_clip=(1e-3, 10),
90 | expln_eps=1e-6,
91 | clip_mean=2.0,
92 | std_reg=args.std_reg
93 | ),
94 | )
95 |
96 | # Function that creates and monitors vectorized environments:
97 | def make_parallel_envs(env_config, num_env, start_index=0):
98 | def make_env(_):
99 | def _thunk():
100 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
101 | env.seed(args.seed)
102 | env._max_episode_steps = max_episode_steps
103 | env = Monitor(env, TENSORBOARD_LOG)
104 | return env
105 |
106 | return _thunk
107 |
108 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
109 |
110 |
111 | if __name__ == "__main__":
112 | # ensure tensorboard log directory exists and copy this file to track
113 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
114 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
115 |
116 | # Create and wrap the training and evaluation environments
117 | envs = make_parallel_envs(config, num_envs)
118 |
119 | if args.env_path is not None:
120 | envs = VecNormalize.load(args.env_path, envs)
121 | else:
122 | envs = VecNormalize(envs)
123 |
124 | # Define callbacks for evaluation and saving the agent
125 | eval_callback = EvalCallback(
126 | eval_env=envs,
127 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
128 | n_eval_episodes=10,
129 | best_model_save_path=TENSORBOARD_LOG,
130 | log_path=TENSORBOARD_LOG,
131 | eval_freq=10_000,
132 | deterministic=True,
133 | render=False,
134 | verbose=1,
135 | )
136 |
137 | checkpoint_callback = CheckpointCallback(
138 | save_freq=25_000,
139 | save_path=TENSORBOARD_LOG,
140 | save_vecnormalize=True,
141 | verbose=1,
142 | )
143 |
144 | tensorboard_callback = TensorboardCallback(
145 | info_keywords=(
146 | )
147 | )
148 |
149 | # Define trainer
150 | trainer = MyoTrainer(
151 | algo="sac",
152 | envs=envs,
153 | env_config=config,
154 | load_model_path=args.model_path,
155 | log_dir=TENSORBOARD_LOG,
156 | model_config=model_config,
157 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
158 | timesteps=10_000_000,
159 | )
160 |
161 | # Train agent
162 | trainer.train(total_timesteps=trainer.timesteps)
163 | trainer.save()
164 |
--------------------------------------------------------------------------------
/src/main_hopper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.sac.policies import LatticeSACPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 | args = parser.parse_args()
41 |
42 | # define constants
43 | ENV_NAME = "HopperBulletEnv"
44 |
45 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
46 |
47 | if args.model_path is not None:
48 | model_name = args.model_path.split("/")[-2]
49 | else:
50 | model_name = None
51 |
52 | TENSORBOARD_LOG = (
53 | os.path.join(ROOT_DIR, "output", "training", now)
54 | + f"_hopper_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}"
55 | )
56 |
57 | # Reward structure and task parameters:
58 | config = {
59 | }
60 |
61 | max_episode_steps = 1000 # default: 1000
62 |
63 | model_config = dict(
64 | policy=LatticeSACPolicy,
65 | device=args.device,
66 | learning_rate=3e-4,
67 | buffer_size=300_000,
68 | learning_starts=10000,
69 | batch_size=256,
70 | tau=0.02,
71 | gamma=0.98,
72 | train_freq=(8, "step"),
73 | gradient_steps=8,
74 | action_noise=None,
75 | replay_buffer_class=None,
76 | ent_coef="auto",
77 | target_update_interval=1,
78 | target_entropy="auto",
79 | seed=args.seed,
80 | use_sde=args.use_sde,
81 | sde_sample_freq=args.freq,
82 | policy_kwargs=dict(
83 | use_lattice=args.use_lattice,
84 | use_expln=True,
85 | log_std_init=args.log_std_init,
86 | activation_fn=nn.GELU,
87 | net_arch=dict(pi=[400, 300], qf=[400, 300]),
88 | std_clip=(1e-3, 10),
89 | expln_eps=1e-6,
90 | clip_mean=2.0,
91 | std_reg=args.std_reg
92 | ),
93 | )
94 |
95 | # Function that creates and monitors vectorized environments:
96 | def make_parallel_envs(env_config, num_env, start_index=0):
97 | def make_env(_):
98 | def _thunk():
99 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
100 | env.seed(args.seed)
101 | env._max_episode_steps = max_episode_steps
102 | env = Monitor(env, TENSORBOARD_LOG)
103 | return env
104 |
105 | return _thunk
106 |
107 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
108 |
109 |
110 | if __name__ == "__main__":
111 | # ensure tensorboard log directory exists and copy this file to track
112 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
113 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
114 |
115 | # Create and wrap the training and evaluation environments
116 | envs = make_parallel_envs(config, args.num_envs)
117 |
118 | if args.env_path is not None:
119 | envs = VecNormalize.load(args.env_path, envs)
120 | else:
121 | envs = VecNormalize(envs)
122 |
123 | # Define callbacks for evaluation and saving the agent
124 | eval_callback = EvalCallback(
125 | eval_env=envs,
126 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
127 | n_eval_episodes=10,
128 | best_model_save_path=TENSORBOARD_LOG,
129 | log_path=TENSORBOARD_LOG,
130 | eval_freq=10_000,
131 | deterministic=True,
132 | render=False,
133 | verbose=1,
134 | )
135 |
136 | checkpoint_callback = CheckpointCallback(
137 | save_freq=25_000,
138 | save_path=TENSORBOARD_LOG,
139 | save_vecnormalize=True,
140 | verbose=1,
141 | )
142 |
143 | tensorboard_callback = TensorboardCallback(
144 | info_keywords=(
145 | )
146 | )
147 |
148 | # Define trainer
149 | trainer = MyoTrainer(
150 | algo="sac",
151 | envs=envs,
152 | env_config=config,
153 | load_model_path=args.model_path,
154 | log_dir=TENSORBOARD_LOG,
155 | model_config=model_config,
156 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
157 | timesteps=15_000_000,
158 | )
159 |
160 | # Train agent
161 | trainer.train(total_timesteps=trainer.timesteps)
162 | trainer.save()
163 |
--------------------------------------------------------------------------------
/src/main_humanoid.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.sac.policies import LatticeSACPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 | args = parser.parse_args()
41 |
42 | # define constants
43 | ENV_NAME = "HumanoidBulletEnv"
44 |
45 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
46 |
47 | if args.model_path is not None:
48 | model_name = args.model_path.split("/")[-2]
49 | else:
50 | model_name = None
51 |
52 | TENSORBOARD_LOG = (
53 | os.path.join(ROOT_DIR, "output", "training", now)
54 | + f"_humanoid_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}"
55 | )
56 |
57 | # Reward structure and task parameters:
58 | config = {
59 | }
60 |
61 | max_episode_steps = 1000
62 |
63 | model_config = dict(
64 | policy=LatticeSACPolicy,
65 | device=args.device,
66 | learning_rate=3e-4,
67 | buffer_size=300_000,
68 | learning_starts=10000,
69 | batch_size=256,
70 | tau=0.02,
71 | gamma=0.98,
72 | train_freq=(8, "step"),
73 | gradient_steps=8,
74 | action_noise=None,
75 | replay_buffer_class=None,
76 | ent_coef="auto",
77 | target_update_interval=1,
78 | target_entropy="auto",
79 | seed=args.seed,
80 | use_sde=args.use_sde,
81 | sde_sample_freq=args.freq,
82 | policy_kwargs=dict(
83 | use_lattice=args.use_lattice,
84 | use_expln=True,
85 | log_std_init=args.log_std_init,
86 | activation_fn=nn.GELU,
87 | net_arch=dict(pi=[400, 300], qf=[400, 300]),
88 | std_clip=(1e-3, 1),
89 | expln_eps=1e-6,
90 | clip_mean=2.0,
91 | std_reg=args.std_reg
92 | ),
93 | )
94 |
95 | # Function that creates and monitors vectorized environments:
96 | def make_parallel_envs(env_config, num_env, start_index=0):
97 | def make_env(_):
98 | def _thunk():
99 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
100 | env.seed(args.seed)
101 | env._max_episode_steps = max_episode_steps
102 | env = Monitor(env, TENSORBOARD_LOG)
103 | return env
104 |
105 | return _thunk
106 |
107 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
108 |
109 |
110 | if __name__ == "__main__":
111 | # ensure tensorboard log directory exists and copy this file to track
112 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
113 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
114 |
115 | # Create and wrap the training and evaluation environments
116 | envs = make_parallel_envs(config, args.num_envs)
117 |
118 | if args.env_path is not None:
119 | envs = VecNormalize.load(args.env_path, envs)
120 | else:
121 | envs = VecNormalize(envs)
122 |
123 | # Define callbacks for evaluation and saving the agent
124 | eval_callback = EvalCallback(
125 | eval_env=envs,
126 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
127 | n_eval_episodes=10,
128 | best_model_save_path=TENSORBOARD_LOG,
129 | log_path=TENSORBOARD_LOG,
130 | eval_freq=10_000,
131 | deterministic=True,
132 | render=False,
133 | verbose=1,
134 | )
135 |
136 | checkpoint_callback = CheckpointCallback(
137 | save_freq=25_000,
138 | save_path=TENSORBOARD_LOG,
139 | save_vecnormalize=True,
140 | verbose=1,
141 | )
142 |
143 | tensorboard_callback = TensorboardCallback(
144 | info_keywords=(
145 | )
146 | )
147 |
148 | # Define trainer
149 | trainer = MyoTrainer(
150 | algo="sac",
151 | envs=envs,
152 | env_config=config,
153 | load_model_path=args.model_path,
154 | log_dir=TENSORBOARD_LOG,
155 | model_config=model_config,
156 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
157 | timesteps=20_000_000,
158 | )
159 |
160 | # Train agent
161 | trainer.train(total_timesteps=trainer.timesteps)
162 | trainer.save()
163 |
--------------------------------------------------------------------------------
/src/main_pen.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 |
41 | args = parser.parse_args()
42 |
43 | # define constants
44 | ENV_NAME = "CustomMyoPenTwirlRandom"
45 |
46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
47 |
48 | if args.model_path is not None:
49 | model_name = args.model_path.split("/")[-2]
50 | else:
51 | model_name = None
52 |
53 | TENSORBOARD_LOG = (
54 | os.path.join(ROOT_DIR, "output", "training", now)
55 | + f"_pen_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}"
56 | )
57 |
58 | # Reward structure and task parameters:
59 | config = {
60 | "seed": args.seed,
61 | "weighted_reward_keys": {
62 | "pos_align": 0,
63 | "rot_align": 0,
64 | "pos_align_diff": 1e2,
65 | "rot_align_diff": 1e2,
66 | "alive": 0,
67 | "act_reg": 0,
68 | "drop": 0,
69 | "bonus": 0,
70 | "solved": 1,
71 | "done": 0,
72 | "sparse": 0,
73 | },
74 | "goal_orient_range": (-1, 1),
75 | "enable_rsi": False,
76 | "rsi_distance": None,
77 | }
78 |
79 | max_episode_steps = 100
80 |
81 | model_config = dict(
82 | policy=LatticeRecurrentActorCriticPolicy,
83 | device=args.device,
84 | batch_size=32,
85 | n_steps=128,
86 | learning_rate=2.55673e-05,
87 | ent_coef=3.62109e-06,
88 | clip_range=0.3,
89 | gamma=0.99,
90 | gae_lambda=0.9,
91 | max_grad_norm=0.7,
92 | vf_coef=0.835671,
93 | n_epochs=10,
94 | use_sde=args.use_sde,
95 | sde_sample_freq=args.freq,
96 | policy_kwargs=dict(
97 | use_lattice=args.use_lattice,
98 | use_expln=True,
99 | ortho_init=False,
100 | log_std_init=args.log_std_init,
101 | activation_fn=nn.ReLU,
102 | net_arch=[dict(pi=[256, 256], vf=[256, 256])],
103 | std_clip=(1e-3, 10),
104 | expln_eps=1e-6,
105 | full_std=False,
106 | std_reg=args.std_reg
107 | ),
108 | )
109 |
110 | # Function that creates and monitors vectorized environments:
111 | def make_parallel_envs(env_config, num_env, start_index=0):
112 | def make_env(_):
113 | def _thunk():
114 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
115 | env.seed(args.seed)
116 | env._max_episode_steps = max_episode_steps
117 | env = Monitor(env, TENSORBOARD_LOG)
118 | return env
119 |
120 | return _thunk
121 |
122 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
123 |
124 |
125 | if __name__ == "__main__":
126 | # ensure tensorboard log directory exists and copy this file to track
127 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
128 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
129 |
130 | # Create and wrap the training and evaluation environments
131 | envs = make_parallel_envs(config, args.num_envs)
132 |
133 | if args.env_path is not None:
134 | envs = VecNormalize.load(args.env_path, envs)
135 | else:
136 | envs = VecNormalize(envs)
137 |
138 | # Define callbacks for evaluation and saving the agent
139 | eval_callback = EvalCallback(
140 | eval_env=envs,
141 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
142 | n_eval_episodes=10,
143 | best_model_save_path=TENSORBOARD_LOG,
144 | log_path=TENSORBOARD_LOG,
145 | eval_freq=10_000,
146 | deterministic=True,
147 | render=False,
148 | verbose=1,
149 | )
150 |
151 | checkpoint_callback = CheckpointCallback(
152 | save_freq=25_000,
153 | save_path=TENSORBOARD_LOG,
154 | save_vecnormalize=True,
155 | verbose=1,
156 | )
157 |
158 | tensorboard_callback = TensorboardCallback(
159 | info_keywords=list(config["weighted_reward_keys"].keys())
160 | )
161 |
162 | # Define trainer
163 | trainer = MyoTrainer(
164 | algo="recurrent_ppo",
165 | envs=envs,
166 | env_config=config,
167 | load_model_path=args.model_path,
168 | log_dir=TENSORBOARD_LOG,
169 | model_config=model_config,
170 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
171 | timesteps=20_000_000,
172 | )
173 |
174 | # Train agent
175 | trainer.train(total_timesteps=trainer.timesteps)
176 | trainer.save()
177 |
--------------------------------------------------------------------------------
/src/main_pose_elbow.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=-1.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 | parser.add_argument('--alpha', type=float, default=1.0,
41 | help='Weight of the uncorrelated noise (only for lattice)')
42 |
43 | args = parser.parse_args()
44 |
45 | # define constants
46 | ENV_NAME = "CustomMyoElbowPoseRandom"
47 |
48 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
49 |
50 | if args.model_path is not None:
51 | model_name = args.model_path.split("/")[-2]
52 | else:
53 | model_name = None
54 |
55 | TENSORBOARD_LOG = (
56 | os.path.join(ROOT_DIR, "output", "training", now)
57 | + f"_elbow_pose_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_alpha_{args.alpha}_recurrent_ppo_seed_{args.seed}_resume_{model_name}"
58 | )
59 |
60 | # Reward structure and task parameters:
61 | config = {
62 | "weighted_reward_keys": {
63 | "pose": 1,
64 | "bonus": 0,
65 | "penalty": 1,
66 | "act_reg": 0,
67 | "solved": 1,
68 | "done": 0,
69 | "sparse": 0,
70 | },
71 | "reset_type": "init",
72 | "sds_distance": None,
73 | "weight_bodyname": None,
74 | "weight_range": None,
75 | "target_distance": 1,
76 | }
77 |
78 | max_episode_steps = 100
79 |
80 | model_config = dict(
81 | policy=LatticeRecurrentActorCriticPolicy,
82 | device=args.device,
83 | batch_size=32,
84 | n_steps=128,
85 | learning_rate=2.55673e-05,
86 | ent_coef=3.62109e-06,
87 | clip_range=0.3,
88 | gamma=0.99,
89 | gae_lambda=0.9,
90 | max_grad_norm=0.7,
91 | vf_coef=0.835671,
92 | n_epochs=10,
93 | use_sde=args.use_sde,
94 | sde_sample_freq=args.freq,
95 | policy_kwargs=dict(
96 | use_lattice=args.use_lattice,
97 | use_expln=True,
98 | ortho_init=False,
99 | log_std_init=args.log_std_init,
100 | activation_fn=nn.ReLU,
101 | net_arch=[dict(pi=[256, 256], vf=[256, 256])],
102 | std_clip=(1e-3, 10),
103 | expln_eps=1e-6,
104 | full_std=False,
105 | std_reg=args.std_reg,
106 | alpha=args.alpha
107 | ),
108 | )
109 |
110 | # Function that creates and monitors vectorized environments:
111 | def make_parallel_envs(env_config, num_env, start_index=0):
112 | def make_env(_):
113 | def _thunk():
114 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
115 | env.seed(args.seed)
116 | env._max_episode_steps = max_episode_steps
117 | env = Monitor(env, TENSORBOARD_LOG)
118 | return env
119 |
120 | return _thunk
121 |
122 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
123 |
124 |
125 | if __name__ == "__main__":
126 | # ensure tensorboard log directory exists and copy this file to track
127 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
128 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
129 |
130 | # Create and wrap the training and evaluation environments
131 | envs = make_parallel_envs(config, args.num_envs)
132 |
133 | if args.env_path is not None:
134 | envs = VecNormalize.load(args.env_path, envs)
135 | else:
136 | envs = VecNormalize(envs)
137 |
138 | # Define callbacks for evaluation and saving the agent
139 | eval_callback = EvalCallback(
140 | eval_env=envs,
141 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
142 | n_eval_episodes=10,
143 | best_model_save_path=TENSORBOARD_LOG,
144 | log_path=TENSORBOARD_LOG,
145 | eval_freq=10_000,
146 | deterministic=True,
147 | render=False,
148 | verbose=1,
149 | )
150 |
151 | checkpoint_callback = CheckpointCallback(
152 | save_freq=25_000,
153 | save_path=TENSORBOARD_LOG,
154 | save_vecnormalize=True,
155 | verbose=1,
156 | )
157 |
158 | tensorboard_callback = TensorboardCallback(
159 | info_keywords=(
160 | "pose",
161 | "bonus",
162 | "penalty",
163 | "act_reg",
164 | "done",
165 | "solved",
166 | "sparse",
167 | )
168 | )
169 |
170 | # Define trainer
171 | trainer = MyoTrainer(
172 | algo="recurrent_ppo",
173 | envs=envs,
174 | env_config=config,
175 | load_model_path=args.model_path,
176 | log_dir=TENSORBOARD_LOG,
177 | model_config=model_config,
178 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
179 | timesteps=600_000,
180 | )
181 |
182 | # Train agent
183 | trainer.train(total_timesteps=trainer.timesteps)
184 | trainer.save()
--------------------------------------------------------------------------------
/src/main_pose_finger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 |
41 | args = parser.parse_args()
42 |
43 | # define constants
44 | ENV_NAME = "CustomMyoFingerPoseRandom"
45 |
46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
47 |
48 | if args.model_path is not None:
49 | model_name = args.model_path.split("/")[-2]
50 | else:
51 | model_name = None
52 |
53 | TENSORBOARD_LOG = (
54 | os.path.join(ROOT_DIR, "output", "training", now)
55 | + f"_finger_pose_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}"
56 | )
57 |
58 | # Reward structure and task parameters:
59 | config = {
60 | "weighted_reward_keys": {
61 | "pose": 1,
62 | "bonus": 0,
63 | "penalty": 1,
64 | "act_reg": 0,
65 | "solved": 1,
66 | "done": 0,
67 | "sparse": 0,
68 | },
69 | "seed": args.seed,
70 | "reset_type": "init",
71 | "sds_distance": None,
72 | "weight_bodyname": None,
73 | "weight_range": None,
74 | "target_distance": 0.5,
75 | }
76 |
77 | max_episode_steps = 100
78 |
79 | model_config = dict(
80 | policy=LatticeRecurrentActorCriticPolicy,
81 | device=args.device,
82 | batch_size=32,
83 | n_steps=128,
84 | learning_rate=2.55673e-05,
85 | ent_coef=3.62109e-06,
86 | clip_range=0.3,
87 | gamma=0.99,
88 | gae_lambda=0.9,
89 | max_grad_norm=0.7,
90 | vf_coef=0.835671,
91 | n_epochs=10,
92 | use_sde=args.use_sde,
93 | sde_sample_freq=args.freq,
94 | policy_kwargs=dict(
95 | use_lattice=args.use_lattice,
96 | use_expln=True,
97 | ortho_init=False,
98 | log_std_init=args.log_std_init,
99 | activation_fn=nn.ReLU,
100 | net_arch=[dict(pi=[256, 256], vf=[256, 256])],
101 | std_clip=(1e-3, 10),
102 | expln_eps=1e-6,
103 | full_std=False,
104 | std_reg=args.std_reg
105 | ),
106 | )
107 |
108 | # Function that creates and monitors vectorized environments:
109 | def make_parallel_envs(env_config, num_env, start_index=0):
110 | def make_env(_):
111 | def _thunk():
112 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
113 | env.seed(args.seed)
114 | env._max_episode_steps = max_episode_steps
115 | env = Monitor(env, TENSORBOARD_LOG)
116 | return env
117 |
118 | return _thunk
119 |
120 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
121 |
122 |
123 | if __name__ == "__main__":
124 | # ensure tensorboard log directory exists and copy this file to track
125 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
126 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
127 |
128 | # Create and wrap the training and evaluation environments
129 | envs = make_parallel_envs(config, args.num_envs)
130 |
131 | if args.env_path is not None:
132 | envs = VecNormalize.load(args.env_path, envs)
133 | else:
134 | envs = VecNormalize(envs)
135 |
136 | # Define callbacks for evaluation and saving the agent
137 | eval_callback = EvalCallback(
138 | eval_env=envs,
139 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
140 | n_eval_episodes=10,
141 | best_model_save_path=TENSORBOARD_LOG,
142 | log_path=TENSORBOARD_LOG,
143 | eval_freq=10_000,
144 | deterministic=True,
145 | render=False,
146 | verbose=1,
147 | )
148 |
149 | checkpoint_callback = CheckpointCallback(
150 | save_freq=25_000,
151 | save_path=TENSORBOARD_LOG,
152 | save_vecnormalize=True,
153 | verbose=1,
154 | )
155 |
156 | tensorboard_callback = TensorboardCallback(
157 | info_keywords=(
158 | "pose",
159 | "bonus",
160 | "penalty",
161 | "act_reg",
162 | "done",
163 | "solved",
164 | "sparse",
165 | )
166 | )
167 |
168 | # Define trainer
169 | trainer = MyoTrainer(
170 | algo="recurrent_ppo",
171 | envs=envs,
172 | env_config=config,
173 | load_model_path=args.model_path,
174 | log_dir=TENSORBOARD_LOG,
175 | model_config=model_config,
176 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
177 | timesteps=8_000_000,
178 | )
179 |
180 | # Train agent
181 | trainer.train(total_timesteps=trainer.timesteps)
182 | trainer.save()
183 |
--------------------------------------------------------------------------------
/src/main_pose_hand.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 | args = parser.parse_args()
41 |
42 | # define constants
43 | ENV_NAME = "CustomMyoHandPoseRandom"
44 |
45 | if args.model_path is not None:
46 | model_name = args.model_path.split("/")[-2]
47 | else:
48 | model_name = None
49 |
50 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
51 | TENSORBOARD_LOG = (
52 | os.path.join(ROOT_DIR, "output", "training", now)
53 | + f"_hand_pose_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}"
54 | )
55 |
56 | # Reward structure and task parameters:
57 | config = {
58 | "weighted_reward_keys": {
59 | "pose": 1,
60 | "bonus": 0,
61 | "penalty": 1,
62 | "act_reg": 0,
63 | "solved": 1,
64 | "done": 0,
65 | "sparse": 0,
66 | },
67 | "seed": args.seed,
68 | "reset_type": "init",
69 | "sds_distance": None,
70 | "weight_bodyname": None,
71 | "weight_range": None,
72 | "target_distance": 0.5,
73 | }
74 |
75 | max_episode_steps = 100 # default: 100
76 |
77 | model_config = dict(
78 | policy=LatticeRecurrentActorCriticPolicy,
79 | device=args.device,
80 | batch_size=32,
81 | n_steps=128,
82 | learning_rate=2.55673e-05,
83 | ent_coef=3.62109e-06,
84 | clip_range=0.3,
85 | gamma=0.99,
86 | gae_lambda=0.9,
87 | max_grad_norm=0.7,
88 | vf_coef=0.835671,
89 | n_epochs=10,
90 | use_sde=args.use_sde,
91 | sde_sample_freq=args.freq,
92 | policy_kwargs=dict(
93 | use_lattice=args.use_lattice,
94 | use_expln=True,
95 | ortho_init=False,
96 | log_std_init=args.log_std_init,
97 | activation_fn=nn.ReLU,
98 | net_arch=[dict(pi=[256, 256], vf=[256, 256])],
99 | std_clip=(1e-3, 10),
100 | expln_eps=1e-6,
101 | full_std=False,
102 | std_reg=args.std_reg
103 | ),
104 | )
105 |
106 | # Function that creates and monitors vectorized environments:
107 | def make_parallel_envs(env_config, num_env, start_index=0):
108 | def make_env(_):
109 | def _thunk():
110 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
111 | env.seed(args.seed)
112 | env._max_episode_steps = max_episode_steps
113 | env = Monitor(env, TENSORBOARD_LOG)
114 | return env
115 |
116 | return _thunk
117 |
118 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
119 |
120 |
121 | if __name__ == "__main__":
122 | # ensure tensorboard log directory exists and copy this file to track
123 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
124 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
125 |
126 |     # Create and wrap the training and evaluation environments
127 | envs = make_parallel_envs(config, args.num_envs)
128 |
129 | if args.env_path is not None:
130 | envs = VecNormalize.load(args.env_path, envs)
131 | else:
132 | envs = VecNormalize(envs)
133 |
134 | # Define callbacks for evaluation and saving the agent
135 | eval_callback = EvalCallback(
136 | eval_env=envs,
137 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
138 | n_eval_episodes=10,
139 | best_model_save_path=TENSORBOARD_LOG,
140 | log_path=TENSORBOARD_LOG,
141 | eval_freq=10_000,
142 | deterministic=True,
143 | render=False,
144 | verbose=1,
145 | )
146 |
147 | checkpoint_callback = CheckpointCallback(
148 | save_freq=25_000,
149 | save_path=TENSORBOARD_LOG,
150 | save_vecnormalize=True,
151 | verbose=1,
152 | )
153 |
154 | tensorboard_callback = TensorboardCallback(
155 | info_keywords=(
156 | "pose",
157 | "bonus",
158 | "penalty",
159 | "act_reg",
160 | "done",
161 | "solved",
162 | "sparse",
163 | )
164 | )
165 |
166 | # Define trainer
167 | trainer = MyoTrainer(
168 | algo="recurrent_ppo",
169 | envs=envs,
170 | env_config=config,
171 | load_model_path=args.model_path,
172 | log_dir=TENSORBOARD_LOG,
173 | model_config=model_config,
174 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
175 | timesteps=20_000_000,
176 | )
177 |
178 | # Train agent
179 | trainer.train(total_timesteps=trainer.timesteps)
180 | trainer.save()
181 |
--------------------------------------------------------------------------------
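
A note on the reward configuration used by the pose scripts: weighted_reward_keys assigns a weight to each reward component computed by the environment, and the scalar step reward is, under the usual MyoSuite convention (stated here as an assumption rather than repository code), the weighted sum of those components. A tiny illustration with hypothetical component values:

# Hypothetical reward components for one step of CustomMyoHandPoseRandom
# (illustrative values only, not produced by the repository code).
reward_components = {"pose": -0.4, "bonus": 1.0, "penalty": -0.1, "act_reg": -0.05,
                     "solved": 0.0, "done": 0.0, "sparse": 0.0}
weights = {"pose": 1, "bonus": 0, "penalty": 1, "act_reg": 0, "solved": 1, "done": 0, "sparse": 0}
step_reward = sum(weights[k] * v for k, v in reward_components.items())
print(step_reward)  # -0.5: only "pose", "penalty" and "solved" carry a non-zero weight here
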
/src/main_reach_finger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 |
41 | args = parser.parse_args()
42 |
43 | # define constants
44 | ENV_NAME = "MyoFingerReachRandom"
45 |
46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
47 |
48 | if args.model_path is not None:
49 | model_name = args.model_path.split("/")[-2]
50 | else:
51 | model_name = None
52 |
53 | TENSORBOARD_LOG = (
54 | os.path.join(ROOT_DIR, "output", "training", now)
55 | + f"_finger_reach_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}"
56 | )
57 |
58 | # Reward structure and task parameters:
59 | config = {
60 | "seed": args.seed,
61 | }
62 |
63 | max_episode_steps = 100
64 |
65 | model_config = dict(
66 | policy=LatticeRecurrentActorCriticPolicy,
67 | device=args.device,
68 | batch_size=32,
69 | n_steps=128,
70 | learning_rate=2.55673e-05,
71 | ent_coef=3.62109e-06,
72 | clip_range=0.3,
73 | gamma=0.99,
74 | gae_lambda=0.9,
75 | max_grad_norm=0.7,
76 | vf_coef=0.835671,
77 | n_epochs=10,
78 | use_sde=args.use_sde,
79 | sde_sample_freq=args.freq,
80 | policy_kwargs=dict(
81 | use_lattice=args.use_lattice,
82 | use_expln=True,
83 | ortho_init=False,
84 | log_std_init=args.log_std_init,
85 | activation_fn=nn.ReLU,
86 | net_arch=[dict(pi=[256, 256], vf=[256, 256])],
87 | std_clip=(1e-3, 10),
88 | expln_eps=1e-6,
89 | full_std=False,
90 | std_reg=args.std_reg
91 | ),
92 | )
93 |
94 | # Function that creates and monitors vectorized environments:
95 | def make_parallel_envs(env_config, num_env, start_index=0):
96 | def make_env(_):
97 | def _thunk():
98 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
99 | env.seed(args.seed)
100 | env._max_episode_steps = max_episode_steps
101 | env = Monitor(env, TENSORBOARD_LOG)
102 | return env
103 |
104 | return _thunk
105 |
106 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
107 |
108 |
109 | if __name__ == "__main__":
110 | # ensure tensorboard log directory exists and copy this file to track
111 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
112 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
113 |
114 |     # Create and wrap the training and evaluation environments
115 | envs = make_parallel_envs(config, args.num_envs)
116 |
117 | if args.env_path is not None:
118 | envs = VecNormalize.load(args.env_path, envs)
119 | else:
120 | envs = VecNormalize(envs)
121 |
122 | # Define callbacks for evaluation and saving the agent
123 | eval_callback = EvalCallback(
124 | eval_env=envs,
125 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
126 | n_eval_episodes=10,
127 | best_model_save_path=TENSORBOARD_LOG,
128 | log_path=TENSORBOARD_LOG,
129 | eval_freq=10_000,
130 | deterministic=True,
131 | render=False,
132 | verbose=1,
133 | )
134 |
135 | checkpoint_callback = CheckpointCallback(
136 | save_freq=25_000,
137 | save_path=TENSORBOARD_LOG,
138 | save_vecnormalize=True,
139 | verbose=1,
140 | )
141 |
142 | tensorboard_callback = TensorboardCallback(
143 | info_keywords=(
144 | "rwd_dense",
145 | "rwd_sparse",
146 | "solved",
147 | "done",
148 | )
149 | )
150 |
151 | # Define trainer
152 | trainer = MyoTrainer(
153 | algo="recurrent_ppo",
154 | envs=envs,
155 | env_config=config,
156 | load_model_path=args.model_path,
157 | log_dir=TENSORBOARD_LOG,
158 | model_config=model_config,
159 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
160 | timesteps=15_000_000,
161 | )
162 |
163 | # Train agent
164 | trainer.train(total_timesteps=trainer.timesteps)
165 | trainer.save()
166 |
--------------------------------------------------------------------------------
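
Before launching 16 parallel workers, it can help to sanity-check the environment construction in isolation. A minimal smoke-test sketch, assuming MyoSuite is installed and src/ is on the Python path; it builds a single MyoFingerReachRandom environment the same way the _thunk above does (without the Monitor wrapper) and steps it once with a random action:

from envs.environment_factory import EnvironmentFactory

# Mirror the script's config: only the seed is passed for the reach task.
env = EnvironmentFactory.create("MyoFingerReachRandom", seed=0)
env.seed(0)
env._max_episode_steps = 100
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print(obs.shape, reward, done, sorted(info.keys()))
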
/src/main_reach_hand.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 |
41 | args = parser.parse_args()
42 |
43 | # define constants
44 | ENV_NAME = "MyoHandReachRandom"
45 |
46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
47 |
48 | if args.model_path is not None:
49 | model_name = args.model_path.split("/")[-2]
50 | else:
51 | model_name = None
52 |
53 | TENSORBOARD_LOG = (
54 | os.path.join(ROOT_DIR, "output", "training", now)
55 | + f"_hand_reach_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}"
56 | )
57 |
58 | # Reward structure and task parameters:
59 | config = {
60 | "seed": args.seed,
61 | }
62 |
63 | max_episode_steps = 100
64 |
65 | model_config = dict(
66 | policy=LatticeRecurrentActorCriticPolicy,
67 | device=args.device,
68 | batch_size=32,
69 | n_steps=128,
70 | learning_rate=2.55673e-05,
71 | ent_coef=3.62109e-06,
72 | clip_range=0.3,
73 | gamma=0.99,
74 | gae_lambda=0.9,
75 | max_grad_norm=0.7,
76 | vf_coef=0.835671,
77 | n_epochs=10,
78 | use_sde=args.use_sde,
79 | sde_sample_freq=args.freq,
80 | policy_kwargs=dict(
81 | use_lattice=args.use_lattice,
82 | use_expln=True,
83 | ortho_init=False,
84 | log_std_init=args.log_std_init,
85 | activation_fn=nn.ReLU,
86 | net_arch=[dict(pi=[256, 256], vf=[256, 256])],
87 | std_clip=(1e-3, 10),
88 | expln_eps=1e-6,
89 | full_std=False,
90 | std_reg=args.std_reg
91 | ),
92 | )
93 |
94 | # Function that creates and monitors vectorized environments:
95 | def make_parallel_envs(env_config, num_env, start_index=0):
96 | def make_env(_):
97 | def _thunk():
98 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
99 | env.seed(args.seed)
100 | env._max_episode_steps = max_episode_steps
101 | env = Monitor(env, TENSORBOARD_LOG)
102 | return env
103 |
104 | return _thunk
105 |
106 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
107 |
108 |
109 | if __name__ == "__main__":
110 | # ensure tensorboard log directory exists and copy this file to track
111 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
112 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
113 |
114 |     # Create and wrap the training and evaluation environments
115 | envs = make_parallel_envs(config, args.num_envs)
116 |
117 | if args.env_path is not None:
118 | envs = VecNormalize.load(args.env_path, envs)
119 | else:
120 | envs = VecNormalize(envs)
121 |
122 | # Define callbacks for evaluation and saving the agent
123 | eval_callback = EvalCallback(
124 | eval_env=envs,
125 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
126 | n_eval_episodes=10,
127 | best_model_save_path=TENSORBOARD_LOG,
128 | log_path=TENSORBOARD_LOG,
129 | eval_freq=10_000,
130 | deterministic=True,
131 | render=False,
132 | verbose=1,
133 | )
134 |
135 | checkpoint_callback = CheckpointCallback(
136 | save_freq=25_000,
137 | save_path=TENSORBOARD_LOG,
138 | save_vecnormalize=True,
139 | verbose=1,
140 | )
141 |
142 | tensorboard_callback = TensorboardCallback(
143 | info_keywords=(
144 | "rwd_dense",
145 | "rwd_sparse",
146 | "solved",
147 | "done",
148 | )
149 | )
150 |
151 | # Define trainer
152 | trainer = MyoTrainer(
153 | algo="recurrent_ppo",
154 | envs=envs,
155 | env_config=config,
156 | load_model_path=args.model_path,
157 | log_dir=TENSORBOARD_LOG,
158 | model_config=model_config,
159 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
160 | timesteps=15_000_000,
161 | )
162 |
163 | # Train agent
164 | trainer.train(total_timesteps=trainer.timesteps)
165 | trainer.save()
--------------------------------------------------------------------------------
/src/main_reorient.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.ppo.policies import LatticeRecurrentActorCriticPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Main script to train an agent')
19 |
20 | parser.add_argument('--seed', type=int, default=0,
21 | help='Seed for random number generator')
22 | parser.add_argument('--freq', type=int, default=1,
23 | help='SDE sample frequency')
24 | parser.add_argument('--use_sde', action='store_true', default=False,
25 | help='Flag to use SDE')
26 | parser.add_argument('--use_lattice', action='store_true', default=False,
27 | help='Flag to use lattice')
28 | parser.add_argument('--log_std_init', type=float, default=0.0,
29 | help='Initial log standard deviation')
30 | parser.add_argument('--env_path', type=str,
31 | help='Path to environment file')
32 | parser.add_argument('--model_path', type=str,
33 | help='Path to model file')
34 | parser.add_argument('--num_envs', type=int, default=16,
35 | help='Number of parallel environments')
36 | parser.add_argument('--device', type=str, default="cuda",
37 | help='Device, cuda or cpu')
38 | parser.add_argument('--std_reg', type=float, default=0.0,
39 | help='Additional independent std for the multivariate gaussian (only for lattice)')
40 |
41 | args = parser.parse_args()
42 |
43 | # define constants
44 | ENV_NAME = "CustomMyoReorientP2"
45 |
46 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
47 |
48 | if args.model_path is not None:
49 | model_name = args.model_path.split("/")[-2]
50 | else:
51 | model_name = None
52 |
53 | TENSORBOARD_LOG = (
54 | os.path.join(ROOT_DIR, "output", "training", now)
55 | + f"_reorient_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_recurrent_ppo_seed_{args.seed}_resume_{model_name}"
56 | )
57 |
58 | # Reward structure and task parameters:
59 | config = {
60 | "seed": args.seed,
61 | "weighted_reward_keys": {
62 | "pos_dist": 2,
63 | "rot_dist": 0.2,
64 | "pos_dist_diff": 1e2,
65 | "rot_dist_diff": 1e1,
66 | "alive": 1,
67 | "act_reg": 0,
68 | "solved": 2,
69 | "done": 0,
70 | "sparse": 0,
71 | },
72 | "goal_pos": (-0.0, 0.0),
73 | "goal_rot": (-3.14 * 0.25, 3.14 * 0.25),
74 | "obj_size_change": 0,
75 | "obj_friction_change": (0, 0, 0),
76 | "enable_rsi": False,
77 | "rsi_distance_pos": None,
78 | "rsi_distance_rot": None,
79 | "goal_rot_x": None,
80 | "goal_rot_y": None,
81 | "goal_rot_z": None,
82 | "guided_trajectory_steps": 0,
83 | }
84 |
85 | max_episode_steps = 150
86 |
87 | model_config = dict(
88 | policy=LatticeRecurrentActorCriticPolicy,
89 | device=args.device,
90 | batch_size=32,
91 | n_steps=128,
92 | learning_rate=2.55673e-05,
93 | ent_coef=3.62109e-06,
94 | clip_range=0.3,
95 | gamma=0.99,
96 | gae_lambda=0.9,
97 | max_grad_norm=0.7,
98 | vf_coef=0.835671,
99 | n_epochs=10,
100 | use_sde=args.use_sde,
101 | sde_sample_freq=args.freq,
102 | policy_kwargs=dict(
103 | use_lattice=args.use_lattice,
104 | use_expln=True,
105 | ortho_init=False,
106 | log_std_init=args.log_std_init,
107 | activation_fn=nn.ReLU,
108 | net_arch=[dict(pi=[256, 256], vf=[256, 256])],
109 | std_clip=(1e-3, 10),
110 | expln_eps=1e-6,
111 | full_std=False,
112 | std_reg=args.std_reg
113 | ),
114 | )
115 |
116 | # Function that creates and monitors vectorized environments:
117 | def make_parallel_envs(env_config, num_env, start_index=0):
118 | def make_env(_):
119 | def _thunk():
120 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
121 | env.seed(args.seed)
122 | env._max_episode_steps = max_episode_steps
123 | env = Monitor(env, TENSORBOARD_LOG)
124 | return env
125 |
126 | return _thunk
127 |
128 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
129 |
130 |
131 | if __name__ == "__main__":
132 | # ensure tensorboard log directory exists and copy this file to track
133 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
134 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
135 |
136 |     # Create and wrap the training and evaluation environments
137 | envs = make_parallel_envs(config, args.num_envs)
138 |
139 | if args.env_path is not None:
140 | envs = VecNormalize.load(args.env_path, envs)
141 | else:
142 | envs = VecNormalize(envs)
143 |
144 | # Define callbacks for evaluation and saving the agent
145 | eval_callback = EvalCallback(
146 | eval_env=envs,
147 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
148 | n_eval_episodes=10,
149 | best_model_save_path=TENSORBOARD_LOG,
150 | log_path=TENSORBOARD_LOG,
151 | eval_freq=10_000,
152 | deterministic=True,
153 | render=False,
154 | verbose=1,
155 | )
156 |
157 | checkpoint_callback = CheckpointCallback(
158 | save_freq=25_000,
159 | save_path=TENSORBOARD_LOG,
160 | save_vecnormalize=True,
161 | verbose=1,
162 | )
163 |
164 | tensorboard_callback = TensorboardCallback(
165 | info_keywords=(
166 | "pos_dist",
167 | "rot_dist",
168 | "pos_dist_diff",
169 | "rot_dist_diff",
170 | "act_reg",
171 | "alive",
172 | "solved",
173 | )
174 | )
175 |
176 | # Define trainer
177 | trainer = MyoTrainer(
178 | algo="recurrent_ppo",
179 | envs=envs,
180 | env_config=config,
181 | load_model_path=args.model_path,
182 | log_dir=TENSORBOARD_LOG,
183 | model_config=model_config,
184 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
185 | timesteps=20_000_000,
186 | )
187 |
188 | # Train agent
189 | trainer.train(total_timesteps=trainer.timesteps)
190 | trainer.save()
191 |
--------------------------------------------------------------------------------
/src/main_walker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import torch.nn as nn
5 | from datetime import datetime
6 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import VecNormalize
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 |
11 | from definitions import ROOT_DIR
12 | from envs.environment_factory import EnvironmentFactory
13 | from metrics.custom_callbacks import EnvDumpCallback, TensorboardCallback
14 | from train.trainer import MyoTrainer
15 | from models.sac.policies import LatticeSACPolicy
16 |
17 |
18 | parser = argparse.ArgumentParser(description="Main script to train an agent")
19 |
20 | parser.add_argument("--seed", type=int, default=0,
21 | help="Seed for random number generator")
22 | parser.add_argument("--freq", type=int, default=1,
23 | help="SDE sample frequency")
24 | parser.add_argument("--use_sde", action="store_true",
25 | default=False, help="Flag to use SDE")
26 | parser.add_argument("--use_lattice", action="store_true",
27 | default=False, help="Flag to use lattice")
28 | parser.add_argument("--log_std_init", type=float, default=0.0,
29 | help="Initial log standard deviation")
30 | parser.add_argument("--env_path", type=str,
31 | help="Path to environment file")
32 | parser.add_argument("--model_path", type=str,
33 | help="Path to model file")
34 | parser.add_argument("--num_envs", type=int, default=16,
35 | help="Number of parallel environments")
36 | parser.add_argument("--device", type=str, default="cuda",
37 | help="Device, cuda or cpu")
38 | parser.add_argument("--std_reg", type=float, default=0.0,
39 | help="Additional independent std for the multivariate gaussian (only for lattice)")
40 | args = parser.parse_args()
41 |
42 | # define constants
43 | ENV_NAME = "WalkerBulletEnv"
44 |
45 | now = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
46 |
47 | if args.model_path is not None:
48 | model_name = args.model_path.split("/")[-2]
49 | else:
50 | model_name = None
51 |
52 | TENSORBOARD_LOG = (
53 | os.path.join(ROOT_DIR, "output", "training", now)
54 | + f"_walker_sde_{args.use_sde}_lattice_{args.use_lattice}_freq_{args.freq}_log_std_init_{args.log_std_init}_std_reg_{args.std_reg}_sac_seed_{args.seed}_resume_{model_name}"
55 | )
56 |
57 | # Reward structure and task parameters:
58 | config = {}
59 |
60 | max_episode_steps = 1000
61 |
62 | model_config = dict(
63 | policy=LatticeSACPolicy,
64 | device=args.device,
65 | learning_rate=3e-4,
66 | buffer_size=300_000,
67 | learning_starts=10000,
68 | batch_size=256,
69 | tau=0.02,
70 | gamma=0.98,
71 | train_freq=(8, "step"),
72 | gradient_steps=8,
73 | action_noise=None,
74 | replay_buffer_class=None,
75 | ent_coef="auto",
76 | target_update_interval=1,
77 | target_entropy="auto",
78 | seed=args.seed,
79 | use_sde=args.use_sde,
80 | sde_sample_freq=args.freq, # number of steps
81 | policy_kwargs=dict(
82 | use_lattice=args.use_lattice,
83 | use_expln=True,
84 | log_std_init=args.log_std_init,
85 | activation_fn=nn.GELU,
86 | net_arch=dict(pi=[400, 300], qf=[400, 300]),
87 | std_clip=(1e-3, 10),
88 | expln_eps=1e-6,
89 | clip_mean=2.0,
90 | std_reg=args.std_reg,
91 | ),
92 | )
93 |
94 | # Function that creates and monitors vectorized environments:
95 | def make_parallel_envs(env_config, num_env, start_index=0):
96 | def make_env(_):
97 | def _thunk():
98 | env = EnvironmentFactory.create(ENV_NAME, **env_config)
99 | env.seed(args.seed)
100 | env._max_episode_steps = max_episode_steps
101 | env = Monitor(env, TENSORBOARD_LOG)
102 | return env
103 |
104 | return _thunk
105 |
106 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
107 |
108 |
109 | if __name__ == "__main__":
110 | # ensure tensorboard log directory exists and copy this file to track
111 | os.makedirs(TENSORBOARD_LOG, exist_ok=True)
112 | shutil.copy(os.path.abspath(__file__), TENSORBOARD_LOG)
113 |
114 |     # Create and wrap the training and evaluation environments
115 | envs = make_parallel_envs(config, args.num_envs)
116 |
117 | if args.env_path is not None:
118 | envs = VecNormalize.load(args.env_path, envs)
119 | else:
120 | envs = VecNormalize(envs)
121 |
122 | # Define callbacks for evaluation and saving the agent
123 | eval_callback = EvalCallback(
124 | eval_env=envs,
125 | callback_on_new_best=EnvDumpCallback(TENSORBOARD_LOG, verbose=0),
126 | n_eval_episodes=10,
127 | best_model_save_path=TENSORBOARD_LOG,
128 | log_path=TENSORBOARD_LOG,
129 | eval_freq=10_000,
130 | deterministic=True,
131 | render=False,
132 | verbose=1,
133 | )
134 |
135 | checkpoint_callback = CheckpointCallback(
136 | save_freq=25_000,
137 | save_path=TENSORBOARD_LOG,
138 | save_vecnormalize=True,
139 | verbose=1,
140 | )
141 |
142 | tensorboard_callback = TensorboardCallback(info_keywords=())
143 |
144 | # Define trainer
145 | trainer = MyoTrainer(
146 | algo="sac",
147 | envs=envs,
148 | env_config=config,
149 | load_model_path=args.model_path,
150 | log_dir=TENSORBOARD_LOG,
151 | model_config=model_config,
152 | callbacks=[eval_callback, checkpoint_callback, tensorboard_callback],
153 | timesteps=15_000_000,
154 | )
155 |
156 | # Train agent
157 | trainer.train(total_timesteps=trainer.timesteps)
158 | trainer.save()
159 |
--------------------------------------------------------------------------------
/src/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/src/metrics/__init__.py
--------------------------------------------------------------------------------
/src/metrics/custom_callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from stable_baselines3.common.callbacks import BaseCallback
4 |
5 |
6 | class EnvDumpCallback(BaseCallback):
7 | def __init__(self, save_path, verbose=0):
8 | super().__init__(verbose=verbose)
9 | self.save_path = save_path
10 |
11 | def _on_step(self):
12 | env_path = os.path.join(self.save_path, "training_env.pkl")
13 | if self.verbose > 0:
14 | print("Saving the training environment to path ", env_path)
15 | self.training_env.save(env_path)
16 | return True
17 |
18 |
19 | class TensorboardCallback(BaseCallback):
20 | def __init__(self, info_keywords, verbose=0):
21 | super().__init__(verbose=verbose)
22 | self.info_keywords = info_keywords
23 | self.rollout_info = {}
24 |
25 | def _on_rollout_start(self):
26 | self.rollout_info = {key: [] for key in self.info_keywords}
27 |
28 | def _on_step(self):
29 | for key in self.info_keywords:
30 | vals = [info[key] for info in self.locals["infos"]]
31 | self.rollout_info[key].extend(vals)
32 | return True
33 |
34 | def _on_rollout_end(self):
35 | for key in self.info_keywords:
36 | self.logger.record("rollout/" + key, np.mean(self.rollout_info[key]))
--------------------------------------------------------------------------------
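
These callbacks are consumed by the main_*.py scripts through MyoTrainer, but they also work with a plain stable-baselines3 model. A minimal wiring sketch, assuming a MyoSuite reach environment whose per-step info dict contains the keys listed below (as in main_reach_finger.py); each keyword is averaged over a rollout and logged to TensorBoard as "rollout/<key>" next to SB3's built-in metrics:

from stable_baselines3 import PPO
from envs.environment_factory import EnvironmentFactory
from metrics.custom_callbacks import TensorboardCallback

env = EnvironmentFactory.create("MyoFingerReachRandom", seed=0)
callback = TensorboardCallback(info_keywords=("rwd_dense", "rwd_sparse", "solved", "done"))
# "/tmp/lattice_demo" is a placeholder log directory for this sketch.
model = PPO("MlpPolicy", env, n_steps=256, tensorboard_log="/tmp/lattice_demo")
model.learn(total_timesteps=512, callback=callback)
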
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Optional
4 | from torch import nn
5 | from torch.distributions import MultivariateNormal
6 | from typing import Tuple
7 | from torch.distributions import Normal
8 | from stable_baselines3.common.distributions import DiagGaussianDistribution, TanhBijector, StateDependentNoiseDistribution
9 |
10 |
11 | class LatticeNoiseDistribution(DiagGaussianDistribution):
12 | """
13 |     Lattice noise distribution, non-state-dependent. It does not allow time correlation,
14 |     but is more efficient than the state-dependent version.
15 |
16 | :param action_dim: Dimension of the action space.
17 | """
18 |
19 | def __init__(self, action_dim: int):
20 | super().__init__(action_dim=action_dim)
21 |
22 | def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0, state_dependent: bool = False) -> Tuple[nn.Module, nn.Parameter]:
23 | self.mean_actions = nn.Linear(latent_dim, self.action_dim)
24 | self.std_init = torch.tensor(log_std_init).exp()
25 | if state_dependent:
26 | log_std = nn.Linear(latent_dim, self.action_dim + latent_dim)
27 | else:
28 | log_std = nn.Parameter(torch.zeros(self.action_dim + latent_dim), requires_grad=True)
29 | return self.mean_actions, log_std
30 |
31 | def proba_distribution(self, mean_actions: torch.Tensor, log_std: torch.Tensor) -> "LatticeNoiseDistribution":
32 | """
33 | Create the distribution given its parameters (mean, std)
34 |
35 | :param mean_actions:
36 | :param log_std:
37 | :return:
38 | """
39 | std = log_std.exp() * self.std_init
40 | action_variance = std[..., : self.action_dim] ** 2
41 | latent_variance = std[..., self.action_dim :] ** 2
42 |
43 | sigma_mat = (self.mean_actions.weight * latent_variance[..., None, :]).matmul(self.mean_actions.weight.T)
44 | sigma_mat[..., range(self.action_dim), range(self.action_dim)] += action_variance
45 | self.distribution = MultivariateNormal(mean_actions, sigma_mat)
46 | return self
47 |
48 | def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
49 | return self.distribution.log_prob(actions)
50 |
51 | def entropy(self) -> torch.Tensor:
52 | return self.distribution.entropy()
53 |
54 |
55 | class SquashedLatticeNoiseDistribution(LatticeNoiseDistribution):
56 | """
57 | Lattice noise distribution, followed by a squashing function (tanh) to ensure bounds.
58 |
59 | :param action_dim: Dimension of the action space.
60 | :param epsilon: small value to avoid NaN due to numerical imprecision.
61 | """
62 | def __init__(self, action_dim: int, epsilon: float = 1e-6):
63 | super().__init__(action_dim)
64 | self.epsilon = epsilon
65 | self.gaussian_actions: Optional[torch.Tensor] = None
66 |
67 | def log_prob(self, actions: torch.Tensor, gaussian_actions: Optional[torch.Tensor] = None) -> torch.Tensor:
68 | # Inverse tanh
69 | # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
70 | # We use numpy to avoid numerical instability
71 | if gaussian_actions is None:
72 | # It will be clipped to avoid NaN when inversing tanh
73 | gaussian_actions = TanhBijector.inverse(actions)
74 |
75 | # Log likelihood for a Gaussian distribution
76 | log_prob = super().log_prob(gaussian_actions)
77 | # Squash correction (from original SAC implementation)
78 | # this comes from the fact that tanh is bijective and differentiable
79 | log_prob -= torch.sum(torch.log(1 - actions**2 + self.epsilon), dim=1)
80 | return log_prob
81 |
82 | def entropy(self) -> Optional[torch.Tensor]:
83 | # No analytical form,
84 | # entropy needs to be estimated using -log_prob.mean()
85 | return None
86 |
87 | def sample(self) -> torch.Tensor:
88 | # Reparametrization trick to pass gradients
89 | self.gaussian_actions = super().sample()
90 | return torch.tanh(self.gaussian_actions)
91 |
92 | def mode(self) -> torch.Tensor:
93 | self.gaussian_actions = super().mode()
94 | # Squash the output
95 | return torch.tanh(self.gaussian_actions)
96 |
97 | def log_prob_from_params(self, mean_actions: torch.Tensor, log_std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
98 | action = self.actions_from_params(mean_actions, log_std)
99 | log_prob = self.log_prob(action, self.gaussian_actions)
100 | return action, log_prob
101 |
102 |
103 | class LatticeStateDependentNoiseDistribution(StateDependentNoiseDistribution):
104 | """
105 | Distribution class of Lattice exploration.
106 | Paper: Latent Exploration for Reinforcement Learning https://arxiv.org/abs/2305.20065
107 |
108 | It creates correlated noise across actuators, with a covariance matrix induced by
109 | the network weights. Can improve exploration in high-dimensional systems.
110 |
111 | :param action_dim: Dimension of the action space.
112 | :param full_std: Whether to use (n_features x n_actions) parameters
113 | for the std instead of only (n_features,), defaults to True
114 | :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
115 |         a positive standard deviation (cf paper). It keeps the variance above zero and
116 |         prevents it from growing too fast. In practice, ``exp()`` is usually enough.
117 | Defaults to False
118 | :param squash_output: Whether to squash the output using a tanh function,
119 | this ensures bounds are satisfied, defaults to False
120 |     :param learn_features: Whether to learn features for gSDE or not. This will enable
121 |         gradients to be backpropagated through the features, defaults to False
122 | :param epsilon: small value to avoid NaN due to numerical imprecision, defaults to 1e-6
123 | :param std_clip: clip range for the standard deviation, can be used to prevent extreme values,
124 | defaults to (1e-3, 1.0)
125 | :param std_reg: optional regularization to prevent collapsing to a deterministic policy,
126 | defaults to 0.0
127 | :param alpha: relative weight between action and latent noise, 0 removes the latent noise,
128 | defaults to 1 (equal weight)
129 | """
130 | def __init__(
131 | self,
132 | action_dim: int,
133 | full_std: bool = True,
134 | use_expln: bool = False,
135 | squash_output: bool = False,
136 | learn_features: bool = False,
137 | epsilon: float = 1e-6,
138 | std_clip: Tuple[float, float] = (1e-3, 1.0),
139 | std_reg: float = 0.0,
140 | alpha: float = 1,
141 | ):
142 | super().__init__(
143 | action_dim=action_dim,
144 | full_std=full_std,
145 | use_expln=use_expln,
146 | squash_output=squash_output,
147 | epsilon=epsilon,
148 | learn_features=learn_features,
149 | )
150 | self.min_std, self.max_std = std_clip
151 | self.std_reg = std_reg
152 | self.alpha = alpha
153 |
154 | def get_std(self, log_std: torch.Tensor) -> torch.Tensor:
155 | """
156 | Get the standard deviation from the learned parameter
157 | (log of it by default). This ensures that the std is positive.
158 |
159 | :param log_std:
160 | :return:
161 | """
162 | # Apply correction to remove scaling of action std as a function of the latent
163 | # dimension (see paper for details)
164 | log_std = log_std.clip(min=np.log(self.min_std), max=np.log(self.max_std))
165 | log_std = log_std - 0.5 * np.log(self.latent_sde_dim)
166 |
167 | if self.use_expln:
168 |             # From the gSDE paper: keeps the variance above zero
169 |             # and prevents it from growing too fast
170 |             below_threshold = torch.exp(log_std) * (log_std <= 0)
171 |             # Avoid NaN: zero out values that are below zero
172 | safe_log_std = log_std * (log_std > 0) + self.epsilon
173 | above_threshold = (torch.log1p(safe_log_std) + 1.0) * (log_std > 0)
174 | std = below_threshold + above_threshold
175 | else:
176 | # Use normal exponential
177 | std = torch.exp(log_std)
178 |
179 | if self.full_std:
180 | assert std.shape == (
181 | self.latent_sde_dim,
182 | self.latent_sde_dim + self.action_dim,
183 | )
184 | corr_std = std[:, : self.latent_sde_dim]
185 | ind_std = std[:, -self.action_dim :]
186 | else:
187 | # Reduce the number of parameters:
188 | assert std.shape == (self.latent_sde_dim, 2), std.shape
189 | corr_std = torch.ones(self.latent_sde_dim, self.latent_sde_dim).to(log_std.device) * std[:, 0:1]
190 | ind_std = torch.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std[:, 1:]
191 | return corr_std, ind_std
192 |
193 | def sample_weights(self, log_std: torch.Tensor, batch_size: int = 1) -> None:
194 | """
195 | Sample weights for the noise exploration matrix,
196 | using a centered Gaussian distribution.
197 |
198 | :param log_std:
199 | :param batch_size:
200 | """
201 | corr_std, ind_std = self.get_std(log_std)
202 | self.corr_weights_dist = Normal(torch.zeros_like(corr_std), corr_std)
203 | self.ind_weights_dist = Normal(torch.zeros_like(ind_std), ind_std)
204 |
205 | # Reparametrization trick to pass gradients
206 | self.corr_exploration_mat = self.corr_weights_dist.rsample()
207 | self.ind_exploration_mat = self.ind_weights_dist.rsample()
208 |
209 | # Pre-compute matrices in case of parallel exploration
210 | self.corr_exploration_matrices = self.corr_weights_dist.rsample((batch_size,))
211 | self.ind_exploration_matrices = self.ind_weights_dist.rsample((batch_size,))
212 |
213 | def proba_distribution_net(
214 | self,
215 | latent_dim: int,
216 | log_std_init: float = 0,
217 | latent_sde_dim: Optional[int] = None,
218 | clip_mean: float = 0,
219 | ) -> Tuple[nn.Module, nn.Parameter]:
220 | """
221 | Create the layers and parameter that represent the distribution:
222 | one output will be the deterministic action, the other parameter will be the
223 |         standard deviation of the distribution that controls the weights of the noise matrix,
224 | both for the action perturbation and the latent perturbation.
225 |
226 | :param latent_dim: Dimension of the last layer of the policy (before the action layer)
227 | :param log_std_init: Initial value for the log standard deviation
228 | :param latent_sde_dim: Dimension of the last layer of the features extractor
229 | for gSDE. By default, it is shared with the policy network.
230 | :param clip_mean: From SB3 implementation of SAC, add possibility to hard clip the
231 | mean of the actions.
232 | :return:
233 | """
234 | # Note: we always consider that the noise is based on the features of the last
235 | # layer, so latent_sde_dim is the same as latent_dim
236 | self.mean_actions_net = nn.Linear(latent_dim, self.action_dim)
237 | if clip_mean > 0:
238 | self.clipped_mean_actions_net = nn.Sequential(
239 | self.mean_actions_net,
240 | nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
241 | else:
242 | self.clipped_mean_actions_net = self.mean_actions_net
243 | self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
244 |
245 | log_std = (
246 | torch.ones(self.latent_sde_dim, self.latent_sde_dim + self.action_dim)
247 | if self.full_std
248 | else torch.ones(self.latent_sde_dim, 2)
249 | )
250 |
251 | # Transform it into a parameter so it can be optimized
252 | log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
253 | # Sample an exploration matrix
254 | self.sample_weights(log_std)
255 | return self.clipped_mean_actions_net, log_std
256 |
257 | def proba_distribution(
258 | self,
259 | mean_actions: torch.Tensor,
260 | log_std: torch.Tensor,
261 | latent_sde: torch.Tensor,
262 | ) -> "LatticeNoiseDistribution":
263 | # Detach the last layer features because we do not want to update the noise generation
264 | # to influence the features of the policy
265 | self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
266 | corr_std, ind_std = self.get_std(log_std)
267 | latent_corr_variance = torch.mm(self._latent_sde**2, corr_std**2) # Variance of the hidden state
268 | latent_ind_variance = torch.mm(self._latent_sde**2, ind_std**2) + self.std_reg**2 # Variance of the action
269 |
270 | # First consider the correlated variance
271 | sigma_mat = self.alpha**2 * (self.mean_actions_net.weight * latent_corr_variance[:, None, :]).matmul(
272 | self.mean_actions_net.weight.T
273 | )
274 | # Then the independent one, to be added to the diagonal
275 | sigma_mat[:, range(self.action_dim), range(self.action_dim)] += latent_ind_variance
276 | self.distribution = MultivariateNormal(loc=mean_actions, covariance_matrix=sigma_mat, validate_args=False)
277 | return self
278 |
279 | def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
280 | if self.bijector is not None:
281 | gaussian_actions = self.bijector.inverse(actions)
282 | else:
283 | gaussian_actions = actions
284 | log_prob = self.distribution.log_prob(gaussian_actions)
285 |
286 | if self.bijector is not None:
287 | # Squash correction
288 | log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
289 | return log_prob
290 |
291 | def entropy(self) -> torch.Tensor:
292 | if self.bijector is not None:
293 | return None
294 | return self.distribution.entropy()
295 |
296 | def get_noise(
297 | self,
298 | latent_sde: torch.Tensor,
299 | exploration_mat: torch.Tensor,
300 | exploration_matrices: torch.Tensor,
301 | ) -> torch.Tensor:
302 | latent_sde = latent_sde if self.learn_features else latent_sde.detach()
303 | # Default case: only one exploration matrix
304 | if len(latent_sde) == 1 or len(latent_sde) != len(exploration_matrices):
305 | return torch.mm(latent_sde, exploration_mat)
306 | # Use batch matrix multiplication for efficient computation
307 | # (batch_size, n_features) -> (batch_size, 1, n_features)
308 | latent_sde = latent_sde.unsqueeze(dim=1)
309 | # (batch_size, 1, n_actions)
310 | noise = torch.bmm(latent_sde, exploration_matrices)
311 | return noise.squeeze(dim=1)
312 |
313 | def sample(self) -> torch.Tensor:
314 | latent_noise = self.alpha * self.get_noise(self._latent_sde, self.corr_exploration_mat, self.corr_exploration_matrices)
315 | action_noise = self.get_noise(self._latent_sde, self.ind_exploration_mat, self.ind_exploration_matrices)
316 | actions = self.clipped_mean_actions_net(self._latent_sde + latent_noise) + action_noise
317 | if self.bijector is not None:
318 | return self.bijector.forward(actions)
319 | return actions
320 |
--------------------------------------------------------------------------------
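
A minimal, self-contained sketch of exercising LatticeStateDependentNoiseDistribution on random tensors (shapes and hyperparameters below are illustrative assumptions, not values from the repository). It shows the covariance structure induced by the mean-action layer: noise correlated through the layer weights plus independent per-action noise added to the diagonal:

import torch
from models.distributions import LatticeStateDependentNoiseDistribution

action_dim, latent_dim, batch = 6, 32, 4
dist = LatticeStateDependentNoiseDistribution(action_dim, full_std=False, std_clip=(1e-3, 10.0))

# Builds the mean-action layer and the log-std parameter, and samples an exploration matrix.
mean_net, log_std = dist.proba_distribution_net(latent_dim=latent_dim, log_std_init=0.0)

latent = torch.randn(batch, latent_dim)   # stand-in for the policy's last-layer features
mean_actions = mean_net(latent)           # deterministic actions
dist = dist.proba_distribution(mean_actions, log_std, latent)

actions = dist.sample()                   # exploration noise correlated across actuators
log_prob = dist.log_prob(actions)
print(actions.shape, log_prob.shape)      # torch.Size([4, 6]) torch.Size([4])
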
/src/models/ppo/policies.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.common.preprocessing import get_action_dim
2 | from models.distributions import (
3 | LatticeNoiseDistribution,
4 | LatticeStateDependentNoiseDistribution,
5 | )
6 | from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
7 |
8 |
9 | class LatticeRecurrentActorCriticPolicy(RecurrentActorCriticPolicy):
10 | def __init__(
11 | self,
12 | observation_space,
13 | action_space,
14 | lr_schedule,
15 | use_lattice=True,
16 | std_clip=(1e-3, 10),
17 | expln_eps=1e-6,
18 | std_reg=0,
19 | alpha=1,
20 | **kwargs
21 | ):
22 | super().__init__(observation_space, action_space, lr_schedule, **kwargs)
23 | if use_lattice:
24 | if self.use_sde:
25 | self.dist_kwargs.update(
26 | {
27 | "epsilon": expln_eps,
28 | "std_clip": std_clip,
29 | "std_reg": std_reg,
30 | "alpha": alpha,
31 | }
32 | )
33 | self.action_dist = LatticeStateDependentNoiseDistribution(
34 | get_action_dim(self.action_space), **self.dist_kwargs
35 | )
36 | else:
37 | self.action_dist = LatticeNoiseDistribution(
38 | get_action_dim(self.action_space)
39 | )
40 | self._build(lr_schedule)
41 |
--------------------------------------------------------------------------------
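
A minimal usage sketch, assuming a standard Gym control task ("Pendulum-v1") and illustrative hyperparameters, of how this policy plugs into sb3_contrib's RecurrentPPO, mirroring what the main_*.py scripts configure through model_config:

import gym
import torch.nn as nn
from sb3_contrib import RecurrentPPO
from models.ppo.policies import LatticeRecurrentActorCriticPolicy

env = gym.make("Pendulum-v1")
model = RecurrentPPO(
    policy=LatticeRecurrentActorCriticPolicy,
    env=env,
    use_sde=True,              # Lattice state-dependent noise builds on gSDE
    sde_sample_freq=4,
    policy_kwargs=dict(
        use_lattice=True,
        use_expln=True,
        log_std_init=0.0,
        activation_fn=nn.ReLU,
        std_clip=(1e-3, 10),
        std_reg=0.0,
    ),
    verbose=1,
)
model.learn(total_timesteps=2_048)
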
/src/models/sac/policies.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Type
2 | import torch
3 | from torch import nn
4 | from models.distributions import (
5 | LatticeStateDependentNoiseDistribution,
6 | SquashedLatticeNoiseDistribution,
7 | )
8 | from stable_baselines3.common.preprocessing import get_action_dim
9 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
10 | from stable_baselines3.sac.policies import SACPolicy, Actor
11 |
12 |
13 | class LatticeActor(Actor):
14 | def __init__(
15 | self,
16 | observation_space,
17 | action_space,
18 | net_arch,
19 | features_extractor,
20 | features_dim,
21 | activation_fn: Type[nn.Module] = nn.ReLU,
22 | use_sde: bool = False,
23 | log_std_init: float = -3,
24 | full_std: bool = True,
25 | sde_net_arch: Optional[List[int]] = None,
26 | use_expln: bool = False,
27 | clip_mean: float = 2.0,
28 | normalize_images: bool = True,
29 | use_lattice=False,
30 | std_clip=(1e-3, 10),
31 | expln_eps=1e-6,
32 | std_reg=0,
33 | alpha=1,
34 | ):
35 | super().__init__(
36 | observation_space,
37 | action_space,
38 | net_arch,
39 | features_extractor,
40 | features_dim,
41 | activation_fn=activation_fn,
42 | use_sde=use_sde,
43 | log_std_init=log_std_init,
44 | full_std=full_std,
45 | sde_net_arch=sde_net_arch,
46 | use_expln=use_expln,
47 | clip_mean=clip_mean,
48 | normalize_images=normalize_images,
49 | )
50 | self.use_lattice = use_lattice
51 | self.std_clip = std_clip
52 | self.expln_eps = expln_eps
53 | self.std_reg = std_reg
54 | self.alpha = alpha
55 | if use_lattice:
56 | last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
57 | action_dim = get_action_dim(self.action_space)
58 | if self.use_sde:
59 | self.action_dist = LatticeStateDependentNoiseDistribution(
60 | action_dim,
61 | full_std=full_std,
62 | use_expln=use_expln,
63 | squash_output=True,
64 | learn_features=True,
65 | epsilon=expln_eps,
66 | std_clip=std_clip,
67 | std_reg=std_reg,
68 | alpha=alpha,
69 | )
70 | self.mu, self.log_std = self.action_dist.proba_distribution_net(
71 | latent_dim=last_layer_dim,
72 | latent_sde_dim=last_layer_dim,
73 | log_std_init=log_std_init,
74 | clip_mean=clip_mean
75 | )
76 | else:
77 | self.action_dist = SquashedLatticeNoiseDistribution(action_dim)
78 | self.mu, self.log_std = self.action_dist.proba_distribution_net(last_layer_dim, log_std_init, state_dependent=True)
79 |
80 | def _get_constructor_parameters(self) -> Dict[str, Any]:
81 | data = super()._get_constructor_parameters()
82 | data.update(
83 | dict(
84 | use_lattice=self.use_lattice,
85 | std_clip=self.std_clip,
86 | expln_eps=self.expln_eps,
87 | std_reg=self.std_reg,
88 | alpha=self.alpha,
89 | )
90 | )
91 | return data
92 |
93 | def get_std(self) -> torch.Tensor:
94 | std = super().get_std()
95 | if self.use_lattice:
96 | std = torch.cat(std, dim=1)
97 | return std
98 |
99 |
100 | class LatticeSACPolicy(SACPolicy):
101 | def __init__(
102 | self,
103 | observation_space,
104 | action_space,
105 | lr_schedule,
106 | use_lattice=False,
107 | std_clip=(1e-3, 10),
108 | expln_eps=1e-6,
109 | std_reg=0,
110 | use_sde=False,
111 | alpha=1,
112 | **kwargs
113 | ):
114 | super().__init__(
115 | observation_space, action_space, lr_schedule, use_sde=use_sde, **kwargs
116 | )
117 | self.lattice_kwargs = {
118 | "use_lattice": use_lattice,
119 | "expln_eps": expln_eps,
120 | "std_clip": std_clip,
121 | "std_reg": std_reg,
122 | "alpha": alpha,
123 | }
124 | self.actor_kwargs.update(self.lattice_kwargs)
125 | if use_lattice:
126 | self._build(lr_schedule)
127 |
128 | def make_actor(
129 | self, features_extractor: Optional[BaseFeaturesExtractor] = None
130 | ) -> Actor:
131 | actor_kwargs = self._update_features_extractor(
132 | self.actor_kwargs, features_extractor
133 | )
134 | return LatticeActor(**actor_kwargs).to(self.device)
135 |
136 | def _get_constructor_parameters(self) -> Dict[str, Any]:
137 | data = super()._get_constructor_parameters()
138 | data.update(self.lattice_kwargs)
139 | return data
140 |
--------------------------------------------------------------------------------
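
Analogously, a minimal sketch (same assumptions: "Pendulum-v1" and illustrative hyperparameters) of LatticeSACPolicy with stable-baselines3's SAC, mirroring the model_config of main_walker.py:

import gym
import torch.nn as nn
from stable_baselines3 import SAC
from models.sac.policies import LatticeSACPolicy

env = gym.make("Pendulum-v1")
model = SAC(
    policy=LatticeSACPolicy,
    env=env,
    use_sde=True,              # state-dependent lattice noise in the actor
    sde_sample_freq=8,
    policy_kwargs=dict(
        use_lattice=True,
        use_expln=True,
        log_std_init=0.0,
        activation_fn=nn.GELU,
        std_clip=(1e-3, 10),
        clip_mean=2.0,
        std_reg=0.0,
    ),
    verbose=1,
)
model.learn(total_timesteps=5_000)
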
/src/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amathislab/lattice/846d02fa993b9b80ce5ecb806463e0a05711bad3/src/train/__init__.py
--------------------------------------------------------------------------------
/src/train/trainer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from abc import ABC
4 | from dataclasses import dataclass, field
5 | from typing import List
6 | from stable_baselines3 import PPO, SAC, TD3
7 | from sb3_contrib import RecurrentPPO
8 | from stable_baselines3.common.callbacks import BaseCallback
9 | from stable_baselines3.common.vec_env import VecNormalize
10 |
11 |
12 | class Trainer(ABC):
13 | """
14 | Protocol to train a library-independent RL algorithm on a gym environment.
15 | """
16 |
17 | envs: VecNormalize
18 | env_config: dict
19 | model_config: dict
20 | model_path: str
21 | total_timesteps: int
22 | log: bool
23 |
24 | def _init_agent(self) -> None:
25 | """Initialize the agent."""
26 |
27 | def train(self, total_timesteps: int) -> None:
28 | """Train agent on environment for total_timesteps episdodes."""
29 |
30 |
31 | @dataclass
32 | class MyoTrainer:
33 | algo: str
34 | envs: VecNormalize
35 | env_config: dict
36 | load_model_path: str
37 | log_dir: str
38 | model_config: dict = field(default_factory=dict)
39 | callbacks: List[BaseCallback] = field(default_factory=list)
40 | timesteps: int = 10_000_000
41 |
42 | def __post_init__(self):
43 | self.dump_configs(path=self.log_dir)
44 | self.agent = self._init_agent()
45 |
46 | def dump_configs(self, path: str) -> None:
47 | with open(os.path.join(path, "env_config.json"), "w", encoding="utf8") as f:
48 | json.dump(self.env_config, f, indent=4, default=lambda _: '')
49 | with open(os.path.join(path, "model_config.json"), "w", encoding="utf8") as f:
50 | json.dump(self.model_config, f, indent=4, default=lambda _: '')
51 |
52 | def _init_agent(self):
53 | algo_class = self.get_algo_class()
54 | if self.load_model_path is not None:
55 | return algo_class.load(
56 | self.load_model_path,
57 | env=self.envs,
58 | tensorboard_log=self.log_dir,
59 | custom_objects=self.model_config,
60 | )
61 | print("\nNo model path provided. Initializing new model.\n")
62 | return algo_class(
63 | env=self.envs,
64 | verbose=2,
65 | tensorboard_log=self.log_dir,
66 | **self.model_config,
67 | )
68 |
69 | def train(self, total_timesteps: int) -> None:
70 | self.agent.learn(
71 | total_timesteps=total_timesteps,
72 | callback=self.callbacks,
73 | reset_num_timesteps=True,
74 | )
75 |
76 | def save(self) -> None:
77 | self.agent.save(os.path.join(self.log_dir, "final_model.pkl"))
78 | self.envs.save(os.path.join(self.log_dir, "final_env.pkl"))
79 |
80 | def get_algo_class(self):
81 | if self.algo == "ppo":
82 | return PPO
83 | elif self.algo == "recurrent_ppo":
84 | return RecurrentPPO
85 | elif self.algo == "sac":
86 | return SAC
87 | elif self.algo == "td3":
88 | return TD3
89 | else:
90 | raise ValueError("Unknown algorithm ", self.algo)
91 |
92 |
93 | if __name__ == "__main__":
94 | print("This is a module. Run main.py to train the agent.")
95 |
--------------------------------------------------------------------------------
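
MyoTrainer.save() writes final_model.pkl and final_env.pkl into the log directory. A minimal evaluation sketch, assuming a finished recurrent-PPO run (the run directory below is a hypothetical placeholder) and following the sb3_contrib pattern for rolling out recurrent policies:

import os
import numpy as np
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from envs.environment_factory import EnvironmentFactory

log_dir = "output/training/2023-01-01/00-00-00_example_run"  # hypothetical run directory

venv = DummyVecEnv([lambda: EnvironmentFactory.create("MyoFingerReachRandom", seed=0)])
venv = VecNormalize.load(os.path.join(log_dir, "final_env.pkl"), venv)
venv.training = False        # freeze the normalization statistics
venv.norm_reward = False

model = RecurrentPPO.load(os.path.join(log_dir, "final_model.pkl"), env=venv)
obs = venv.reset()
lstm_states = None
episode_starts = np.ones((venv.num_envs,), dtype=bool)
for _ in range(100):
    action, lstm_states = model.predict(
        obs, state=lstm_states, episode_start=episode_starts, deterministic=True
    )
    obs, rewards, dones, infos = venv.step(action)
    episode_starts = dones
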