├── .gitignore ├── .vscode ├── launch.json └── settings.json ├── LICENSE ├── README.md ├── collect_demo.py ├── conf ├── agent │ ├── ddpg.yaml │ ├── ddqn.yaml │ ├── dqn.yaml │ ├── dueldqn.yaml │ ├── ppo.yaml │ ├── sac.yaml │ ├── td3.yaml │ └── trpo.yaml ├── collect_demo.yaml └── train_agent.yaml ├── plot.py ├── requirements.txt ├── src ├── __init__.py ├── base.py ├── ddpg.py ├── ddqn.py ├── dqn.py ├── dueldqn.py ├── ppo.py ├── sac.py ├── td3.py ├── trpo.py └── utils │ ├── __init__.py │ ├── drls │ ├── __init__.py │ ├── buffer.py │ ├── env.py │ └── gae.py │ ├── exp │ ├── __init__.py │ └── prepare.py │ ├── logger │ ├── __init__.py │ ├── _archive.py │ ├── _logger.py │ ├── _plot.py │ └── _sync.py │ ├── net │ ├── __init__.py │ ├── actor.py │ ├── critic.py │ └── ptu.py │ └── ospy │ ├── __init__.py │ ├── dataset.py │ ├── file.py │ └── util.py └── train_agent.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Self added 2 | runs/ 3 | models/ 4 | archived/ 5 | data/ 6 | 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": false, 14 | "args": [ 15 | "agent=ppo", 16 | "env.id=Hopper-v5", 17 | "log.console_output=false" 18 | ] 19 | } 20 | ] 21 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "cSpell.enabled": false 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yi-Chen Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RL-pytorch 2 | Re-implementations of Deep Reinforcement Learning (DRL) algorithms, written in PyTorch.
3 | 4 | ## Installation 5 | 6 | ```bash 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | ## Implemented Algorithms 11 | 12 | - [x] Deep Q Networks (DQN) [[paper](https://www.nature.com/articles/nature14236.pdf)] [[official code](https://github.com/deepmind/dqn)] 13 | - [x] Deep Double Q Networks (DDQN) [[paper](https://arxiv.org/pdf/1509.06461.pdf)] 14 | - [x] Dueling Network Architectures for Deep Reinforcement Learning (DuelDQN) [[paper](https://arxiv.org/pdf/1511.06581.pdf)] 15 | - [x] Continuous control with deep reinforcement learning (DDPG) [[paper](https://arxiv.org/pdf/1509.02971.pdf)] 16 | - [x] Addressing Function Approximation Error in Actor-Critic Methods (TD3) [[paper](https://arxiv.org/pdf/1802.09477.pdf)] [[official code](https://github.com/sfujim/TD3)] 17 | - [x] Soft Actor-Critic Algorithms and Applications (SAC) [[paper](https://arxiv.org/pdf/1812.05905.pdf)] [[official code](https://github.com/rail-berkeley/softlearning/)] 18 | - [x] Trust Region Policy Optimization (TRPO) [[paper](https://arxiv.org/pdf/1502.05477.pdf)] [[official code](https://github.com/joschu/modular_rl)] 19 | - [x] Proximal Policy Optimization (PPO) [[paper](https://arxiv.org/pdf/1707.06347.pdf)] [[official code](https://github.com/openai/baselines)] 20 | 21 | ## Run Experiments 22 | 23 | ```bash 24 | # train an RL agent 25 | # by default, training results are stored at the `runs` dir 26 | python train_agent.py agent=ppo env.id=Hopper-v5 27 | 28 | # plot the training results 29 | python plot.py 30 | 31 | # collect expert demonstrations 32 | python collect_demo.py env.id=Hopper-v5 expert_model_path=models/hopper_sac_expert.pt 33 | ``` 34 | 35 | ## Acknowledgement 36 | With the progress of this project, I found many open-source materials on the Internet to be excellent references. I am deeply grateful for the efforts of their authors. Below is a detailed list. Additionally, I would like to extend my thanks to my friends from [LAMDA-RL](https://github.com/LAMDA-RL) for our helpful discussions. 
37 | 38 | **Codebase** 39 | 40 | + [tianshou](https://github.com/thu-ml/tianshou) 41 | + [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) 42 | + [stable-baselines-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) 43 | + [stable-baselines](https://github.com/Stable-Baselines-Team/stable-baselines) 44 | + [spinningup](https://github.com/openai/spinningup) 45 | + [RL-Adventure2](https://github.com/higgsfield/RL-Adventure-2) 46 | + [unstable_baselines](https://github.com/x35f/unstable_baselines) 47 | + [d4rl_evaluations](https://github.com/rail-berkeley/d4rl_evaluations) 48 | + [TD3](https://github.com/sfujim/TD3) 49 | + [pytorch-trpo](https://github.com/ikostrikov/pytorch-trpo) 50 | 51 | **Blog** 52 | 53 | + [The 37 Implementation Details of Proximal Policy Optimization](https://iclr.iro.umontreal.ca/679b37e0-caab-4710-921b-b59a688075df_1642188062/blog/) 54 | 55 | **Tutorial** 56 | 57 | + [OpenAI SpinningUp](https://spinningup.openai.com/en/latest/index.html) 58 | -------------------------------------------------------------------------------- /collect_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | from typing import Callable, Dict, Tuple 4 | 5 | import gymnasium as gym 6 | import hydra 7 | import numpy as np 8 | import torch as th 9 | from omegaconf import DictConfig, OmegaConf 10 | 11 | from src import BaseRLAgent, create_agent 12 | from src.utils.drls.env import get_env_info, make_env, reset_env_fn 13 | from src.utils.exp.prepare import set_random_seed 14 | from src.utils.logger import TBLogger 15 | from src.utils.net.ptu import load_torch_model, set_torch 16 | from src.utils.ospy.dataset import get_dataset_holder, save_dataset_to_h5 17 | 18 | 19 | @th.no_grad() 20 | def _collect_demo( 21 | policy: BaseRLAgent, 22 | env: gym.Env, 23 | reset_env_fn: Callable, 24 | save_dir: str, 25 | save_name: str, 26 | n_traj: int = 10, 27 | n_step: int = 1000_000, 28 | with_log_prob: bool = False, 29 | seed: int = 0, 30 | ) -> Tuple[str, Dict[str, np.ndarray]]: 31 | """Collect dataset 32 | 33 | :param n_traj: Collect [n_trajs] trajectories 34 | :param n_step: Collect [max_steps] transitions 35 | """ 36 | collected_steps, collected_n_traj = 0, 0 37 | _dataset = get_dataset_holder(with_log_prob) 38 | 39 | next_obs, _ = reset_env_fn(env, seed) 40 | while collected_n_traj < n_traj or collected_steps < n_step: 41 | collected_steps += 1 42 | 43 | obs = next_obs 44 | if with_log_prob: 45 | action, log_prob = policy.select_action( 46 | obs, deterministic=True, return_log_prob=True 47 | ) 48 | else: 49 | action = policy.select_action( 50 | obs, deterministic=True, return_log_prob=False 51 | ) 52 | next_obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy()) 53 | 54 | # insert 55 | _dataset["observations"].append(obs) 56 | _dataset["actions"].append(action) 57 | _dataset["rewards"].append(reward) 58 | _dataset["next_observations"].append(next_obs) 59 | _dataset["terminals"].append(terminated) 60 | _dataset["timeouts"].append(truncated) 61 | 62 | if with_log_prob: 63 | _dataset["infos/action_log_probs"].append(log_prob) 64 | 65 | if terminated or truncated: 66 | next_obs, _ = reset_env_fn(env, seed) 67 | collected_n_traj += 1 68 | 69 | dataset = {} 70 | if with_log_prob: 71 | dataset["infos/action_log_probs"] = np.array( 72 | _dataset["infos/action_log_probs"] 73 | ).astype(np.float64) 74 | 75 | dataset.update( 76 | dict( 77 | 
observations=np.array(_dataset["observations"]).astype(np.float64), 78 | actions=np.array(_dataset["actions"]).astype(np.float64), 79 | next_observations=np.array(_dataset["next_observations"]).astype( 80 | np.float64 81 | ), 82 | rewards=np.array(_dataset["rewards"]).astype(np.float64), 83 | terminals=np.array(_dataset["terminals"]).astype(np.bool_), 84 | timeouts=np.array(_dataset["timeouts"]).astype(np.bool_), 85 | ) 86 | ) 87 | 88 | # dump the saved dataset 89 | save_dataset_to_h5(dataset, save_dir, save_name) 90 | 91 | return ( 92 | f"Successfully save expert demonstration into {save_dir}/{save_name}.hdf5!", 93 | dataset, 94 | ) 95 | 96 | 97 | @hydra.main(config_path="./conf", config_name="collect_demo", version_base="1.3.2") 98 | def main(cfg: DictConfig): 99 | cfg.work_dir = os.getcwd() 100 | # prepare experiment 101 | set_torch() 102 | set_random_seed(cfg.seed) 103 | 104 | # setup logger 105 | logger = TBLogger(args=OmegaConf.to_object(cfg), record_param=cfg.log.record_param) 106 | 107 | # setup environment 108 | env = make_env(cfg.env.id) 109 | OmegaConf.update(cfg, "env[info]", get_env_info(env), merge=False) 110 | 111 | # create agent 112 | agent = create_agent(cfg) 113 | logger.console.info( 114 | load_torch_model(agent.models, join(cfg.work_dir, cfg.expert_model_path)) 115 | ) 116 | 117 | # collect expert dataset 118 | logger.console.info(f"Collecting expert data on the environment {cfg.env.id}...") 119 | logger.console.info( 120 | _collect_demo( 121 | agent, 122 | env, 123 | reset_env_fn, 124 | cfg.demo.save_dir, 125 | cfg.demo.save_name, 126 | cfg.demo.n_traj, 127 | cfg.demo.n_step, 128 | cfg.demo.with_log_prob, 129 | cfg.seed, 130 | )[0] 131 | ) 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /conf/agent/ddpg.yaml: -------------------------------------------------------------------------------- 1 | algo: ddpg 2 | gamma: 0.99 3 | batch_size: 256 4 | expl_std: 0.1 5 | warmup_steps: 10_000 6 | buffer_size: 1000_000 7 | env_steps: 1 8 | 9 | actor: 10 | net_arch: [256, 256] 11 | activation_fn: ReLU 12 | optimizer: Adam 13 | lr: !!float 3e-4 14 | tau: 0.05 15 | 16 | critic: 17 | net_arch: [256, 256] 18 | activation_fn: ReLU 19 | optimizer: Adam 20 | lr: !!float 3e-4 21 | tau: 0.05 22 | -------------------------------------------------------------------------------- /conf/agent/ddqn.yaml: -------------------------------------------------------------------------------- 1 | algo: ddqn 2 | gamma: 0.99 3 | batch_size: 256 4 | target_update_freq: 10 5 | epsilon: 0.1 6 | buffer_size: 8_000 7 | 8 | QNet: 9 | net_arch: [100, 100] 10 | activation_fn: ReLU 11 | optimizer: Adam 12 | lr: !!float 3e-4 -------------------------------------------------------------------------------- /conf/agent/dqn.yaml: -------------------------------------------------------------------------------- 1 | algo: dqn 2 | gamma: 0.99 3 | batch_size: 256 4 | target_update_freq: 10 5 | epsilon: 0.1 6 | buffer_size: 8_000 7 | 8 | QNet: 9 | net_arch: [100, 100] 10 | activation_fn: ReLU 11 | optimizer: Adam 12 | lr: !!float 3e-4 -------------------------------------------------------------------------------- /conf/agent/dueldqn.yaml: -------------------------------------------------------------------------------- 1 | algo: dueldqn 2 | gamma: 0.99 3 | batch_size: 256 4 | target_update_freq: 10 5 | epsilon: 0.1 6 | buffer_size: 8_000 7 | 8 | QNet: 9 | net_arch: [100] 10 | activation_fn: ReLU 11 | optimizer: Adam 12 | lr: 
!!float 3e-4 13 | v_head: [100] 14 | adv_head: [100] 15 | mix_type: max -------------------------------------------------------------------------------- /conf/agent/ppo.yaml: -------------------------------------------------------------------------------- 1 | algo: ppo 2 | gamma: 0.99 3 | batch_size: 256 4 | rollout_steps: 2048 5 | lambda_: 0.97 6 | norm_adv: true 7 | use_td_lambda: true 8 | epsilon: 0.1 9 | buffer_size: -1 10 | entropy_coef: 0. 11 | 12 | actor: 13 | net_arch: [64, 64] 14 | activation_fn: ReLU 15 | state_std_independent: False 16 | optimizer: Adam 17 | lr: !!float 1e-3 18 | n_update: 5 19 | clip: 0.5 20 | 21 | value_net: 22 | net_arch: [64, 64] 23 | activation_fn: ReLU 24 | optimizer: Adam 25 | lr: !!float 1e-3 26 | n_update: 5 -------------------------------------------------------------------------------- /conf/agent/sac.yaml: -------------------------------------------------------------------------------- 1 | algo: sac 2 | gamma: 0.99 3 | batch_size: 256 4 | warmup_steps: 10_000 5 | env_steps: 1 6 | buffer_size: 1000_000 7 | 8 | actor: 9 | net_arch: [256, 256] 10 | activation_fn: ReLU 11 | state_std_independent: False 12 | optimizer: Adam 13 | lr: !!float 3e-4 14 | 15 | critic: 16 | net_arch: [256, 256] 17 | activation_fn: ReLU 18 | optimizer: Adam 19 | lr: !!float 3e-4 20 | tau: 0.05 21 | 22 | log_alpha: 23 | auto_tune: True 24 | init_value: 0.0 25 | optimizer: Adam 26 | lr: !!float 3e-4 27 | -------------------------------------------------------------------------------- /conf/agent/td3.yaml: -------------------------------------------------------------------------------- 1 | algo: td3 2 | gamma: 0.99 3 | batch_size: 256 4 | sigma: 0.2 5 | c: 0.5 6 | policy_freq: 2 7 | warmup_steps: 10_000 8 | env_steps: 1 9 | buffer_size: 1000_000 10 | 11 | actor: 12 | net_arch: [256, 256] 13 | activation_fn: ReLU 14 | optimizer: Adam 15 | lr: !!float 3e-4 16 | tau: 0.05 17 | 18 | critic: 19 | net_arch: [256, 256] 20 | activation_fn: ReLU 21 | optimizer: Adam 22 | lr: !!float 3e-4 23 | tau: 0.05 24 | -------------------------------------------------------------------------------- /conf/agent/trpo.yaml: -------------------------------------------------------------------------------- 1 | algo: trpo 2 | gamma: 0.99 3 | batch_size: 256 4 | rollout_steps: 2048 5 | lambda_: 0.97 6 | norm_adv: true 7 | use_td_lambda: true 8 | buffer_size: 50_000 9 | 10 | # conjugate gradient param 11 | residual_tol: !!float 1e-10 12 | cg_steps: 10 13 | damping: !!float 1e-1 14 | 15 | # line search param 16 | beta: 0.8 17 | max_backtrack: 15 18 | accept_ratio: !!float 1e-1 19 | delta: !!float 1e-2 20 | 21 | actor: 22 | net_arch: [256, 256] 23 | activation_fn: ReLU 24 | state_std_independent: False 25 | 26 | value_net: 27 | net_arch: [256, 256] 28 | activation_fn: ReLU 29 | optimizer: Adam 30 | lr: !!float 1e-3 31 | n_update: 5 -------------------------------------------------------------------------------- /conf/collect_demo.yaml: -------------------------------------------------------------------------------- 1 | work_dir: 2 | seed: 3407 3 | device: cuda:0 4 | description: "Demo" 5 | 6 | # env: 7 | env: 8 | id: Hopper-v4 9 | info: 10 | state_shape: 11 | action_shape: 12 | action_dtype: 13 | 14 | # Note: All paths below are absolute paths or relative paths to project dir 15 | expert_model_path: models/hopper_sac_expert.pt 16 | 17 | # log 18 | log: 19 | record_param: 20 | - description 21 | - seed 22 | - agent.algo 23 | - env.id 24 | 25 | # demo 26 | demo: 27 | n_traj: 0 28 | n_step: 1000_000 29 | save_dir: 
data 30 | save_name: "" 31 | with_log_prob: False 32 | 33 | # algo 34 | defaults: 35 | - agent: sac 36 | - override hydra/hydra_logging: disabled 37 | - override hydra/job_logging: disabled 38 | - _self_ 39 | 40 | hydra: 41 | output_subdir: null 42 | run: 43 | dir: . -------------------------------------------------------------------------------- /conf/train_agent.yaml: -------------------------------------------------------------------------------- 1 | work_dir: 2 | seed: 3407 3 | device: cuda:0 4 | description: "" 5 | 6 | # env: 7 | env: 8 | id: Hopper-v5 9 | info: 10 | state_shape: 11 | action_shape: 12 | action_dtype: 13 | 14 | # log 15 | log: 16 | record_param: 17 | - seed 18 | - agent.algo 19 | - env.id 20 | console_output: true 21 | 22 | # train 23 | train: 24 | max_steps: 1000_000 25 | eval_interval: 5_000 26 | 27 | # algo 28 | defaults: 29 | - agent: ppo 30 | - override hydra/hydra_logging: disabled 31 | - override hydra/job_logging: disabled 32 | - _self_ 33 | 34 | hydra: 35 | output_subdir: null 36 | run: 37 | dir: . -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import seaborn as sbn 7 | 8 | from src.utils.logger import tb2dict, window_smooth 9 | from src.utils.ospy import filter_from_list 10 | 11 | # Hyper-param 12 | WORK_DIR = osp.expanduser("~/workspace/RL-pytorch/runs") 13 | LOG_DIRs = [ 14 | "2024-01-27__18-19-31~seed=3407~agent.algo=ppo~env.id=Hopper-v4", 15 | "2024-01-27__19-14-55~seed=1290~agent.algo=ppo~env.id=Hopper-v4", 16 | ] 17 | KEYs = ["return/eval", "return/train"] 18 | RULE = "events.out.tfevents*" 19 | SMOOTH_WINDOW_SIZE = 10 20 | 21 | # Drawing 22 | sbn.set_style("darkgrid") 23 | 24 | ## 1. Convert tensorboard files to a list of data points 25 | datas = {key: list() for key in KEYs} 26 | env_id, algo = None, None 27 | for log_dir in LOG_DIRs: 28 | # check 29 | _log_dir = log_dir.split("~") 30 | tmp_env_id = filter_from_list(_log_dir, "env.id=*")[0].split("=")[-1] 31 | tmp_algo = filter_from_list(_log_dir, "agent.algo=*")[0].split("=")[-1] 32 | if env_id is None: 33 | env_id = tmp_env_id 34 | else: 35 | assert ( 36 | env_id == tmp_env_id 37 | ), "The data used to plot must come from the same environment!" 38 | if algo is None: 39 | algo = tmp_algo 40 | else: 41 | assert algo == tmp_algo, "The data used to plot must come from the same algorithm!" 42 | # get data 43 | dir_path = osp.join(WORK_DIR, log_dir) 44 | tb_file = filter_from_list(os.listdir(dir_path), RULE)[0] 45 | data = tb2dict(osp.join(dir_path, tb_file), KEYs) 46 | for key in KEYs: 47 | datas[key].append(data[key]) 48 | merged_datas = {key: {"steps": list(), "values": list()} for key in KEYs} 49 | for key in KEYs: # smooth 50 | for i in range(len(datas[key])): 51 | merged_datas[key]["steps"] += datas[key][i]["steps"] 52 | merged_datas[key]["values"] += window_smooth( 53 | datas[key][i]["values"], SMOOTH_WINDOW_SIZE 54 | ) 55 | 56 | ## 2.
Drawing multiple lines in a single picture 57 | for key in KEYs: 58 | sbn.lineplot(data=pd.DataFrame(merged_datas[key]), x="steps", y="values", label=key) 59 | plt.title(f"Learning Curves of {algo} on {env_id}") 60 | plt.xlabel("Steps", size=14) 61 | plt.ylabel("Return", size=14) 62 | plt.yticks(size=14) 63 | plt.legend(loc="lower right", fontsize=14) 64 | plt.savefig("result.pdf") 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym<0.25.0 2 | gymnasium==1.0.0 3 | h5py==3.12.1 4 | hydra-core==1.3.2 5 | loguru==0.7.3 6 | matplotlib==3.10.0 7 | numpy==2.2.1 8 | omegaconf==2.3.0 9 | pandas==2.2.3 10 | paramiko==3.5.0 11 | seaborn==0.13.2 12 | tensorboard==2.18.0 13 | tensorboardX==2.6.2.2 14 | torch==2.5.1 15 | tqdm==4.67.1 16 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from omegaconf import DictConfig 4 | 5 | import utils 6 | 7 | from .base import BaseRLAgent 8 | from .ddpg import DDPGAgent 9 | from .ddqn import DDQNAgent 10 | from .dqn import DQNAgent 11 | from .dueldqn import DuelDQNAgent 12 | from .ppo import PPOAgent 13 | from .sac import SACAgent 14 | from .td3 import TD3Agent 15 | from .trpo import TRPOAgent 16 | 17 | AGENTS: Dict[str, BaseRLAgent] = { 18 | "ddpg": DDPGAgent, 19 | "ddqn": DDQNAgent, 20 | "dqn": DQNAgent, 21 | "dueldqn": DuelDQNAgent, 22 | "ppo": PPOAgent, 23 | "sac": SACAgent, 24 | "td3": TD3Agent, 25 | "trpo": TRPOAgent, 26 | } 27 | 28 | 29 | def create_agent(cfg: DictConfig) -> BaseRLAgent: 30 | """To instantiate an agent""" 31 | 32 | def _get_agent(cfg: DictConfig) -> BaseRLAgent: 33 | """For python annotations""" 34 | return AGENTS[cfg.agent.algo](cfg) 35 | 36 | agent = _get_agent(cfg) 37 | agent.setup_model() 38 | return agent 39 | 40 | 41 | __all__ = [ 42 | DDPGAgent, 43 | DDQNAgent, 44 | DQNAgent, 45 | DuelDQNAgent, 46 | PPOAgent, 47 | SACAgent, 48 | TD3Agent, 49 | TRPOAgent, 50 | BaseRLAgent, 51 | AGENTS, 52 | utils, 53 | ] 54 | -------------------------------------------------------------------------------- /src/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from os.path import join 3 | from typing import Callable, Dict, Union 4 | 5 | import gymnasium as gym 6 | import numpy as np 7 | import torch as th 8 | from omegaconf import DictConfig 9 | from torch import nn, optim 10 | from tqdm import trange 11 | 12 | from src.utils.drls.buffer import TransitionBuffer 13 | from src.utils.logger import TBLogger 14 | from src.utils.net.ptu import save_torch_model, tensor2ndarray 15 | 16 | 17 | class BaseRLAgent(ABC): 18 | """Base for RL""" 19 | 20 | def __init__(self, cfg: DictConfig): 21 | self.cfg = cfg 22 | 23 | # hyper-param 24 | self.work_dir = cfg.work_dir 25 | self.device = th.device(cfg.device) 26 | self.seed = cfg.seed 27 | self.batch_size = self.cfg.agent.batch_size 28 | self.gamma = self.cfg.agent.gamma 29 | 30 | # bind env 31 | self.state_shape = tuple(cfg.env.info.state_shape) 32 | self.action_shape = tuple(cfg.env.info.action_shape) 33 | self.action_dtype = cfg.env.info.action_dtype 34 | 35 | # buffer 36 | buffer_kwarg = { 37 | "state_shape": self.state_shape, 38 | "action_shape": self.action_shape, 39 | "action_dtype": self.action_dtype, 40 | "device": 
self.device, 41 | "buffer_size": self.cfg.agent.buffer_size, 42 | } 43 | 44 | self.trans_buffer = TransitionBuffer(**buffer_kwarg) 45 | 46 | # models 47 | self.models: Dict[str, Union[nn.Module, optim.Optimizer, th.Tensor]] = dict() 48 | 49 | # -------- Initialization --------- 50 | @abstractmethod 51 | def setup_model(self): 52 | raise NotImplementedError 53 | 54 | # -------- Interaction ---------- 55 | @abstractmethod 56 | def select_action( 57 | self, 58 | state: Union[np.ndarray, th.Tensor], 59 | deterministic: bool, 60 | return_log_prob: bool, 61 | **kwarg, 62 | ) -> th.Tensor: 63 | raise NotImplementedError 64 | 65 | @abstractmethod 66 | def update(self) -> Dict: 67 | """Provide the algorithm details for updating parameters""" 68 | raise NotImplementedError 69 | 70 | def learn( 71 | self, 72 | train_env: gym.Env, 73 | eval_env: gym.Env, 74 | reset_env_fn: Callable, 75 | eval_policy: Callable, 76 | logger: TBLogger, 77 | ): 78 | train_return = 0 79 | best_return = -float("inf") 80 | train_steps = self.cfg.train.max_steps 81 | eval_interval = self.cfg.train.eval_interval 82 | 83 | # start training 84 | if self.cfg.log.console_output: 85 | progress_f = None 86 | progress_bar = trange(train_steps) 87 | else: 88 | progress_f = open(join(logger.exp_dir, "progress.txt"), "w") 89 | progress_bar = trange(train_steps, file=progress_f) 90 | next_state, _ = reset_env_fn(train_env, self.seed) 91 | for t in progress_bar: 92 | state = next_state 93 | if "warmup_steps" in self.cfg.agent and t < self.cfg.agent.warmup_steps: 94 | action = train_env.action_space.sample() 95 | else: 96 | action = self.select_action( 97 | state, 98 | deterministic=False, 99 | return_log_prob=False, 100 | action_space=train_env.action_space, 101 | ) 102 | action = tensor2ndarray((action,))[0] 103 | next_state, reward, terminated, truncated, _ = train_env.step(action) 104 | train_return += reward 105 | 106 | # insert transition into buffer 107 | self.trans_buffer.insert_transition( 108 | state, action, next_state, reward, float(terminated) 109 | ) 110 | 111 | # update policy 112 | logger.add_stats(self.update(), t) 113 | 114 | # whether this episode ends 115 | if terminated or truncated: 116 | logger.add_stats({"return/train": train_return}, t) 117 | next_state, _ = reset_env_fn(train_env, self.seed) 118 | train_return = 0 119 | 120 | # evaluate 121 | if (t + 1) % eval_interval == 0: 122 | eval_return = eval_policy(eval_env, reset_env_fn, self, self.seed) 123 | logger.add_stats({"return/eval": eval_return}, t) 124 | 125 | if eval_return > best_return: 126 | logger.console.info( 127 | f"Step {t}: get new best return: {eval_return}!" 
128 | ) 129 | save_torch_model(self.models, logger.ckpt_dir, "best_model") 130 | best_return = eval_return 131 | 132 | if progress_f is not None: 133 | progress_f.close() 134 | -------------------------------------------------------------------------------- /src/ddpg.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, Union 3 | 4 | import numpy as np 5 | import torch as th 6 | import torch.nn.functional as F 7 | from omegaconf import DictConfig 8 | from torch import nn, optim 9 | 10 | from src.utils.net.actor import MLPDeterministicActor 11 | from src.utils.net.critic import MLPCritic 12 | from src.utils.net.ptu import freeze_net, gradient_descent, move_device, polyak_update 13 | 14 | from .base import BaseRLAgent 15 | 16 | 17 | class DDPGAgent(BaseRLAgent): 18 | """Deep Deterministic Policy Gradient (DDPG)""" 19 | 20 | def __init__(self, cfg: DictConfig): 21 | super().__init__(cfg) 22 | 23 | def setup_model(self): 24 | # hyper-param 25 | self.warmup_steps = self.cfg.agent.warmup_steps 26 | self.env_steps = self.cfg.agent.env_steps 27 | 28 | # actor 29 | actor_kwarg = { 30 | "state_shape": self.state_shape, 31 | "net_arch": self.cfg.agent.actor.net_arch, 32 | "action_shape": self.action_shape, 33 | "activation_fn": getattr(nn, self.cfg.agent.actor.activation_fn), 34 | } 35 | self.actor = MLPDeterministicActor(**actor_kwarg) 36 | self.actor_target = deepcopy(self.actor) 37 | self.actor_optim = getattr(optim, self.cfg.agent.actor.optimizer)( 38 | self.actor.parameters(), self.cfg.agent.actor.lr 39 | ) 40 | 41 | # critic 42 | critic_kwarg = { 43 | "input_shape": (self.state_shape[0] + self.action_shape[0],), 44 | "output_shape": (1,), 45 | "net_arch": self.cfg.agent.critic.net_arch, 46 | "activation_fn": getattr(nn, self.cfg.agent.critic.activation_fn), 47 | } 48 | self.critic = MLPCritic(**critic_kwarg) 49 | self.critic_target = deepcopy(self.critic) 50 | self.critic_optim = getattr(optim, self.cfg.agent.critic.optimizer)( 51 | self.critic.parameters(), self.cfg.agent.critic.lr 52 | ) 53 | 54 | freeze_net((self.actor_target, self.critic_target)) 55 | move_device( 56 | (self.actor, self.actor_target, self.critic, self.critic_target), 57 | self.device, 58 | ) 59 | 60 | self.models.update( 61 | { 62 | "actor": self.actor, 63 | "actor_target": self.actor_target, 64 | "actor_optim": self.actor_optim, 65 | "critic": self.critic, 66 | "critic_target": self.critic_target, 67 | "critic_optim": self.critic_optim, 68 | } 69 | ) 70 | 71 | def select_action( 72 | self, 73 | state: Union[np.ndarray, th.Tensor], 74 | deterministic: bool, 75 | actor: nn.Module = None, 76 | **kwargs, 77 | ) -> th.Tensor: 78 | state = th.Tensor(state).to(self.device) if type(state) is np.ndarray else state 79 | 80 | if actor is None: 81 | action = self.actor(state) 82 | else: 83 | action = actor(state) 84 | 85 | action = th.tanh(action) 86 | 87 | # add explore noise 88 | if not deterministic: 89 | noise = th.randn_like(action) * (self.cfg.agent.expl_std) 90 | # by default, the action scale is [-1.,1.] 
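# note: the th.tanh above already squashes the policy output into [-1, 1], so the clamp below only keeps the Gaussian-perturbed action inside that same box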
91 | action = th.clamp(action + noise, -1.0, 1.0) 92 | 93 | return action 94 | 95 | def update(self) -> Dict: 96 | self.stats = dict() 97 | rest_steps = self.trans_buffer.size - self.warmup_steps 98 | if not ( 99 | self.trans_buffer.size < self.batch_size 100 | or rest_steps < 0 101 | or rest_steps % self.env_steps != 0 102 | ): 103 | states, actions, next_states, rewards, dones = self.trans_buffer.sample( 104 | self.batch_size 105 | ) 106 | 107 | # update params 108 | for _ in range(self.env_steps): 109 | self._update_critic(states, actions, next_states, rewards, dones) 110 | self._update_actor(states) 111 | 112 | polyak_update( 113 | self.critic.parameters(), 114 | self.critic_target.parameters(), 115 | self.cfg.agent.critic.tau, 116 | ) 117 | polyak_update( 118 | self.actor.parameters(), 119 | self.actor_target.parameters(), 120 | self.cfg.agent.actor.tau, 121 | ) 122 | 123 | return self.stats 124 | 125 | def _update_critic( 126 | self, 127 | states: th.Tensor, 128 | actions: th.Tensor, 129 | next_states: th.Tensor, 130 | rewards: th.Tensor, 131 | dones: th.Tensor, 132 | ): 133 | with th.no_grad(): 134 | pred_next_actions = self.select_action( 135 | next_states, 136 | deterministic=False, 137 | actor=self.actor_target, 138 | ) 139 | target_Q = self.critic_target(next_states, pred_next_actions) 140 | target_Q = rewards + self.gamma * (1 - dones) * target_Q 141 | Q = self.critic(states, actions) 142 | critic_loss = F.mse_loss(Q, target_Q) 143 | self.stats.update( 144 | {"loss/critic": gradient_descent(self.critic_optim, critic_loss)} 145 | ) 146 | 147 | def _update_actor(self, states: th.Tensor): 148 | pred_actions = self.select_action(states, deterministic=True) 149 | Q = self.critic(states, pred_actions) 150 | actor_loss = -th.mean(Q) 151 | self.stats.update( 152 | {"loss/actor": gradient_descent(self.actor_optim, actor_loss)} 153 | ) 154 | -------------------------------------------------------------------------------- /src/ddqn.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from omegaconf import DictConfig 3 | 4 | from .dqn import DQNAgent 5 | 6 | 7 | class DDQNAgent(DQNAgent): 8 | """Deep Double Q Networks (DDQN)""" 9 | 10 | def __init__(self, cfg: DictConfig): 11 | super().__init__(cfg) 12 | 13 | def _get_q_target(self, next_states: th.Tensor): 14 | with th.no_grad(): 15 | _next_action = th.argmax(self.q_net(next_states), -1, True) 16 | q_target = th.gather(self.q_net_target(next_states), -1, _next_action) 17 | return q_target 18 | -------------------------------------------------------------------------------- /src/dqn.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, Union 3 | 4 | import numpy as np 5 | import torch as th 6 | import torch.nn.functional as F 7 | from omegaconf import DictConfig 8 | from torch import nn, optim 9 | 10 | from src.utils.net.critic import MLPCritic 11 | from src.utils.net.ptu import freeze_net, gradient_descent, move_device 12 | 13 | from .base import BaseRLAgent 14 | 15 | 16 | class DQNAgent(BaseRLAgent): 17 | """Deep Q Networks (DQN)""" 18 | 19 | def __init__(self, cfg: DictConfig): 20 | super().__init__(cfg) 21 | 22 | def setup_model(self): 23 | # hyper-param 24 | self.target_update_freq = self.cfg.agent.target_update_freq 25 | self.epsilon = self.cfg.agent.epsilon 26 | self.global_t = 0 27 | 28 | # Q network 29 | q_net_kwarg = { 30 | "input_shape": self.state_shape, 31 | "output_shape": 
self.action_shape, 32 | "net_arch": self.cfg.agent.QNet.net_arch, 33 | "activation_fn": getattr(nn, self.cfg.agent.QNet.activation_fn), 34 | } 35 | self.q_net = MLPCritic(**q_net_kwarg) 36 | self.q_net_target = deepcopy(self.q_net) 37 | self.q_net_optim = getattr(optim, self.cfg.agent.QNet.optimizer)( 38 | self.q_net.parameters(), self.cfg.agent.QNet.lr 39 | ) 40 | 41 | freeze_net((self.q_net_target,)) 42 | move_device((self.q_net, self.q_net_target), self.device) 43 | 44 | self.models.update( 45 | { 46 | "q_net": self.q_net, 47 | "q_net_target": self.q_net_target, 48 | "q_net_optim": self.q_net_optim, 49 | } 50 | ) 51 | 52 | def select_action( 53 | self, state: Union[np.ndarray, th.Tensor], deterministic: bool, **kwarg 54 | ) -> th.Tensor: 55 | if not deterministic and np.random.random() < self.epsilon: 56 | return kwarg["action_space"].sample() 57 | with th.no_grad(): 58 | state = ( 59 | th.Tensor(state).to(self.device) if type(state) is np.ndarray else state 60 | ) 61 | pred_q = self.q_net(state) 62 | action = th.argmax(pred_q, dim=-1) 63 | return action 64 | 65 | def _get_q_target(self, next_states: th.Tensor): 66 | with th.no_grad(): 67 | q_target, _ = th.max(self.q_net_target(next_states), -1, True) 68 | return q_target 69 | 70 | def _get_q(self, states: th.Tensor, actions: th.Tensor): 71 | q = th.gather(self.q_net(states), -1, actions) 72 | return q 73 | 74 | def update(self) -> Dict: 75 | self.stats = dict() 76 | if self.trans_buffer.size >= self.batch_size: 77 | self.global_t += 1 78 | states, actions, next_states, rewards, dones = self.trans_buffer.sample( 79 | self.batch_size, shuffle=True 80 | ) 81 | 82 | # calculate q target and td target 83 | q_target = self._get_q_target(next_states) 84 | td_target = rewards + self.gamma * (1 - dones) * q_target 85 | 86 | # calculate q 87 | q = self._get_q(states, actions) 88 | 89 | # update q network 90 | loss = F.mse_loss(q, td_target) 91 | self.stats.update( 92 | { 93 | "loss": gradient_descent(self.q_net_optim, loss), 94 | "Q/q": th.mean(q).item(), 95 | "Q/q_target": th.mean(q_target).item(), 96 | } 97 | ) 98 | 99 | # update q target 100 | if self.global_t % self.target_update_freq == 0: 101 | self.q_net_target.load_state_dict(self.q_net.state_dict()) 102 | 103 | return self.stats 104 | -------------------------------------------------------------------------------- /src/dueldqn.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from omegaconf import DictConfig 4 | from torch import nn, optim 5 | 6 | from src.utils.net.critic import MLPDuleQNet 7 | from src.utils.net.ptu import freeze_net, move_device 8 | 9 | from .ddqn import DDQNAgent 10 | 11 | 12 | class DuelDQNAgent(DDQNAgent): 13 | """Dueling Deep Q Networks (DuelDQN)""" 14 | 15 | def __init__(self, cfg: DictConfig): 16 | super().__init__(cfg) 17 | 18 | def setup_model(self): 19 | # hyper-param 20 | self.target_update_freq = self.cfg.agent.target_update_freq 21 | self.epsilon = self.cfg.agent.epsilon 22 | self.global_t = 0 23 | 24 | # Q network 25 | q_net_kwarg = { 26 | "input_shape": self.state_shape, 27 | "output_shape": self.action_shape, 28 | "net_arch": self.cfg.agent.QNet.net_arch, 29 | "v_head": self.cfg.agent.QNet.v_head, 30 | "adv_head": self.cfg.agent.QNet.adv_head, 31 | "activation_fn": getattr(nn, self.cfg.agent.QNet.activation_fn), 32 | "mix_type": self.cfg.agent.QNet.mix_type, 33 | } 34 | self.q_net = MLPDuleQNet(**q_net_kwarg) 35 | self.q_net_target = deepcopy(self.q_net) 36 | self.q_net_optim = 
getattr(optim, self.cfg.agent.QNet.optimizer)( 37 | self.q_net.parameters(), self.cfg.agent.QNet.lr 38 | ) 39 | 40 | freeze_net((self.q_net_target,)) 41 | move_device((self.q_net, self.q_net_target), self.device) 42 | 43 | self.models.update( 44 | { 45 | "q_net": self.q_net, 46 | "q_net_target": self.q_net_target, 47 | "q_net_optim": self.q_net_optim, 48 | } 49 | ) 50 | -------------------------------------------------------------------------------- /src/ppo.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch as th 4 | from omegaconf import DictConfig 5 | from torch import nn, optim 6 | from torch.utils.data import BatchSampler 7 | 8 | from src.utils.drls.gae import GAE 9 | from src.utils.net.actor import MLPGaussianActor 10 | from src.utils.net.critic import MLPCritic 11 | from src.utils.net.ptu import gradient_descent, move_device 12 | 13 | from .trpo import TRPOAgent 14 | 15 | 16 | class PPOAgent(TRPOAgent): 17 | """Proximal Policy Optimization (PPO)""" 18 | 19 | def __init__(self, cfg: DictConfig): 20 | super().__init__(cfg) 21 | 22 | def setup_model(self): 23 | # hyper-param 24 | self.epsilon = self.cfg.agent.epsilon 25 | self.lambda_ = self.cfg.agent.lambda_ 26 | 27 | # GAE 28 | self.gae = GAE( 29 | self.gamma, 30 | self.lambda_, 31 | self.cfg.agent.norm_adv, 32 | self.cfg.agent.use_td_lambda, 33 | ) 34 | 35 | # actor 36 | actor_kwarg = { 37 | "state_shape": self.state_shape, 38 | "net_arch": self.cfg.agent.actor.net_arch, 39 | "action_shape": self.action_shape, 40 | "activation_fn": getattr(nn, self.cfg.agent.actor.activation_fn), 41 | } 42 | self.actor = MLPGaussianActor(**actor_kwarg) 43 | self.actor_optim = getattr(optim, self.cfg.agent.actor.optimizer)( 44 | self.actor.parameters(), self.cfg.agent.actor.lr 45 | ) 46 | 47 | # value network 48 | value_net_kwarg = { 49 | "input_shape": self.state_shape, 50 | "output_shape": (1,), 51 | "net_arch": self.cfg.agent.value_net.net_arch, 52 | "activation_fn": getattr(nn, self.cfg.agent.value_net.activation_fn), 53 | } 54 | self.value_net = MLPCritic(**value_net_kwarg) 55 | self.value_net_optim = getattr(optim, self.cfg.agent.value_net.optimizer)( 56 | self.value_net.parameters(), self.cfg.agent.value_net.lr 57 | ) 58 | 59 | move_device((self.actor, self.value_net), self.device) 60 | 61 | self.models.update( 62 | { 63 | "actor": self.actor, 64 | "actor_optim": self.actor_optim, 65 | "value_net": self.value_net, 66 | "value_net_optim": self.value_net_optim, 67 | } 68 | ) 69 | 70 | def _update_actor(self, states: th.Tensor, actions: th.Tensor): 71 | with th.no_grad(): 72 | _, old_log_probs = self._select_action_dist(states, actions) 73 | 74 | idx = list(range(self.trans_buffer.size)) 75 | for _ in range(self.cfg.agent.value_net.n_update): 76 | random.shuffle(idx) 77 | batches = list( 78 | BatchSampler(idx, batch_size=self.batch_size, drop_last=False) 79 | ) 80 | for batch in batches: 81 | sampled_states, sampled_actions = states[batch], actions[batch] 82 | sampled_action_dist, sampled_log_probs = self._select_action_dist( 83 | sampled_states, sampled_actions 84 | ) 85 | ratio = th.exp(sampled_log_probs - old_log_probs[batch]) 86 | surr1 = ratio * self.adv[batch] 87 | surr2 = ( 88 | th.clamp(ratio, 1.0 - self.epsilon, 1.0 + self.epsilon) 89 | * self.adv[batch] 90 | ) 91 | loss = ( 92 | -th.min(surr1, surr2).mean() 93 | - self.cfg.agent.entropy_coef * sampled_action_dist.entropy().mean() 94 | ) 95 | self.stats.update( 96 | { 97 | "loss/actor": gradient_descent( 98 | 
self.actor_optim, 99 | loss, 100 | self.actor.parameters(), 101 | # experimental results show that clipping grad realy improves performance 102 | self.cfg.agent.actor.clip, 103 | ) 104 | } 105 | ) 106 | -------------------------------------------------------------------------------- /src/sac.py: -------------------------------------------------------------------------------- 1 | import math 2 | from copy import deepcopy 3 | from typing import Dict, Tuple, Union 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn.functional as F 8 | from omegaconf import DictConfig 9 | from torch import nn, optim 10 | 11 | from src.utils.net.actor import MLPGaussianActor 12 | from src.utils.net.critic import MLPTwinCritic 13 | from src.utils.net.ptu import freeze_net, gradient_descent, move_device, polyak_update 14 | 15 | from .base import BaseRLAgent 16 | 17 | 18 | class SACAgent(BaseRLAgent): 19 | """Soft Actor Critic (SAC)""" 20 | 21 | def __init__(self, cfg: DictConfig): 22 | super().__init__(cfg) 23 | 24 | def setup_model(self): 25 | # hyper-param 26 | self.entropy_target = -self.action_shape[0] 27 | self.warmup_steps = self.cfg.agent.warmup_steps 28 | self.env_steps = self.cfg.agent.env_steps 29 | 30 | # actor 31 | actor_kwarg = { 32 | "state_shape": self.state_shape, 33 | "net_arch": self.cfg.agent.actor.net_arch, 34 | "action_shape": self.action_shape, 35 | "state_std_independent": self.cfg.agent.actor.state_std_independent, 36 | "activation_fn": getattr(nn, self.cfg.agent.actor.activation_fn), 37 | } 38 | self.actor = MLPGaussianActor(**actor_kwarg) 39 | self.actor_optim = getattr(optim, self.cfg.agent.actor.optimizer)( 40 | self.actor.parameters(), self.cfg.agent.actor.lr 41 | ) 42 | 43 | # critic 44 | critic_kwarg = { 45 | "input_shape": (self.state_shape[0] + self.action_shape[0],), 46 | "net_arch": self.cfg.agent.critic.net_arch, 47 | "output_shape": (1,), 48 | "activation_fn": getattr(nn, self.cfg.agent.critic.activation_fn), 49 | } 50 | self.critic = MLPTwinCritic(**critic_kwarg) 51 | self.critic_target = deepcopy(self.critic) 52 | self.critic_optim = getattr(optim, self.cfg.agent.critic.optimizer)( 53 | self.critic.parameters(), self.cfg.agent.critic.lr 54 | ) 55 | 56 | # alpha, we optimize log(alpha) because alpha should always be bigger than 0. 
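# i.e. alpha = exp(log_alpha) stays positive for any real log_alpha (see the `alpha` property below); with auto_tune enabled, log_alpha is a learnable tensor that _update_alpha drives towards the target entropy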
57 | if self.cfg.agent.log_alpha.auto_tune: 58 | self.log_alpha = th.tensor( 59 | [self.cfg.agent.log_alpha.init_value], 60 | device=self.device, 61 | requires_grad=True, 62 | ) 63 | self.log_alpha_optim = getattr(optim, self.cfg.agent.log_alpha.optimizer)( 64 | [self.log_alpha], self.cfg.agent.log_alpha.lr 65 | ) 66 | self.models.update( 67 | {"log_alpha": self.log_alpha, "log_alpha_optim": self.log_alpha_optim} 68 | ) 69 | else: 70 | self.log_alpha = th.tensor( 71 | [self.cfg.agent.log_alpha.init_value], device=self.device 72 | ) 73 | 74 | freeze_net((self.critic_target,)) 75 | move_device((self.actor, self.critic, self.critic_target), self.device) 76 | 77 | self.models.update( 78 | { 79 | "actor": self.actor, 80 | "actor_optim": self.actor_optim, 81 | "critic": self.critic, 82 | "critic_target": self.critic_target, 83 | "critic_optim": self.critic_optim, 84 | } 85 | ) 86 | 87 | @property 88 | def alpha(self): 89 | return math.exp(self.log_alpha.item()) 90 | 91 | def select_action( 92 | self, 93 | state: Union[np.ndarray, th.Tensor], 94 | deterministic: bool, 95 | return_log_prob: bool, 96 | **kwarg, 97 | ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: 98 | """ 99 | :param deterministic: whether sample from the action distribution or just the action mean. 100 | :param return_dtype_tensor: whether the returned data's dtype keeps to be torch.Tensor or numpy.ndarray 101 | :param return_log_prob: whether return log_prob 102 | """ 103 | # Due to the squash operation, we need that keep_dtype_tensor == True here. 104 | if return_log_prob: 105 | action, log_prob = self.actor.sample( 106 | state, deterministic, True, self.device 107 | ) 108 | # squash action 109 | log_prob -= th.sum( 110 | 2 * (np.log(2.0) - action - F.softplus(-2 * action)), 111 | axis=-1, 112 | keepdims=True, 113 | ) 114 | else: 115 | action = self.actor.sample(state, deterministic, False, self.device) 116 | log_prob = None 117 | 118 | action = th.tanh(action) 119 | 120 | # # scale, we could instead use gym.wrappers to rescale action space 121 | # # [-1, +1] -> [-action_scale, action_scale] 122 | # if return_log_prob: 123 | # log_prob -= th.sum( 124 | # np.log(1.0 / self.action_scale) * th.ones_like(action), 125 | # axis=-1, 126 | # keepdim=True, 127 | # ) 128 | # action *= self.action_scale 129 | 130 | return (action, log_prob) if return_log_prob else action 131 | 132 | def update(self) -> Dict: 133 | self.stats = dict() 134 | rest_steps = self.trans_buffer.size - self.warmup_steps 135 | if not ( 136 | self.trans_buffer.size < self.batch_size 137 | or rest_steps < 0 138 | or rest_steps % self.env_steps != 0 139 | ): 140 | states, actions, next_states, rewards, dones = self.trans_buffer.sample( 141 | self.batch_size 142 | ) 143 | 144 | # update params 145 | for _ in range(self.env_steps): 146 | self._update_critic(states, actions, next_states, rewards, dones) 147 | if self.cfg.agent.log_alpha.auto_tune: 148 | self._update_alpha(self._update_actor(states)) 149 | else: 150 | self._update_actor(states) 151 | polyak_update( 152 | self.critic.parameters(), 153 | self.critic_target.parameters(), 154 | self.cfg.agent.critic.tau, 155 | ) 156 | 157 | return self.stats 158 | 159 | def _update_critic( 160 | self, 161 | states: th.Tensor, 162 | actions: th.Tensor, 163 | next_states: th.Tensor, 164 | rewards: th.Tensor, 165 | dones: th.Tensor, 166 | ): 167 | with th.no_grad(): 168 | pred_next_actions, pred_next_log_pis = self.select_action( 169 | next_states, 170 | deterministic=False, 171 | return_log_prob=True, 172 | ) 173 | 
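# soft Bellman target: take the min over the twin target critics, subtract the entropy term alpha * log_pi, then apply the discounted backup r + gamma * (1 - done) * target below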
target_Q1, target_Q2 = self.critic_target( 174 | True, next_states, pred_next_actions 175 | ) 176 | target_Q = th.min(target_Q1, target_Q2) - self.alpha * pred_next_log_pis 177 | target_Q = rewards + self.gamma * (1 - dones) * target_Q 178 | 179 | Q1, Q2 = self.critic(True, states, actions) 180 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 181 | 182 | self.stats.update( 183 | {"loss/critic": gradient_descent(self.critic_optim, critic_loss)} 184 | ) 185 | 186 | def _update_actor(self, states: th.Tensor): 187 | pred_actions, pred_log_pis = self.select_action( 188 | states, deterministic=False, return_log_prob=True 189 | ) 190 | Q1, Q2 = self.critic(True, states, pred_actions) 191 | actor_loss = th.mean(self.alpha * pred_log_pis - th.min(Q1, Q2)) 192 | self.stats.update( 193 | {"loss/actor": gradient_descent(self.actor_optim, actor_loss)} 194 | ) 195 | 196 | return pred_log_pis.detach() 197 | 198 | def _update_alpha(self, pred_log_pis: th.Tensor): 199 | """Auto-tune alpha 200 | 201 | Note: pred_log_pis are detached from the computation graph 202 | """ 203 | alpha_loss = th.mean(self.log_alpha * (-pred_log_pis - self.entropy_target)) 204 | self.stats.update( 205 | {"loss/alpha": gradient_descent(self.log_alpha_optim, alpha_loss)} 206 | ) 207 | 208 | # update alpha 209 | self.models["log_alpha"].data = self.log_alpha.data 210 | -------------------------------------------------------------------------------- /src/td3.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, Union 3 | 4 | import numpy as np 5 | import torch as th 6 | import torch.nn.functional as F 7 | from omegaconf import DictConfig 8 | from torch import nn, optim 9 | 10 | from src.utils.net.actor import MLPDeterministicActor 11 | from src.utils.net.critic import MLPTwinCritic 12 | from src.utils.net.ptu import freeze_net, gradient_descent, move_device, polyak_update 13 | 14 | from .base import BaseRLAgent 15 | 16 | 17 | class TD3Agent(BaseRLAgent): 18 | """Twin Delayed Deep Deterministic Policy Gradient (TD3)""" 19 | 20 | def __init__(self, cfg: DictConfig): 21 | super().__init__(cfg) 22 | 23 | def setup_model(self): 24 | # hyper-param 25 | self.warmup_steps = self.cfg.agent.warmup_steps 26 | self.env_steps = self.cfg.agent.env_steps 27 | self.total_train_it = 0 28 | 29 | # actor 30 | actor_kwarg = { 31 | "state_shape": self.state_shape, 32 | "net_arch": self.cfg.agent.actor.net_arch, 33 | "action_shape": self.action_shape, 34 | "activation_fn": getattr(nn, self.cfg.agent.actor.activation_fn), 35 | } 36 | self.actor = MLPDeterministicActor(**actor_kwarg) 37 | self.actor_target = deepcopy(self.actor) 38 | self.actor_optim = getattr(optim, self.cfg.agent.actor.optimizer)( 39 | self.actor.parameters(), self.cfg.agent.actor.lr 40 | ) 41 | 42 | # critic 43 | critic_kwarg = { 44 | "input_shape": (self.state_shape[0] + self.action_shape[0],), 45 | "net_arch": self.cfg.agent.critic.net_arch, 46 | "output_shape": (1,), 47 | "activation_fn": getattr(nn, self.cfg.agent.critic.activation_fn), 48 | } 49 | self.critic = MLPTwinCritic(**critic_kwarg) 50 | self.critic_target = deepcopy(self.critic) 51 | self.critic_optim = getattr(optim, self.cfg.agent.critic.optimizer)( 52 | self.critic.parameters(), self.cfg.agent.critic.lr 53 | ) 54 | 55 | freeze_net((self.actor_target, self.critic_target)) 56 | move_device( 57 | (self.actor, self.actor_target, self.critic, self.critic_target), 58 | self.device, 59 | ) 60 | 61 | self.models.update( 
62 | { 63 | "actor": self.actor, 64 | "actor_target": self.actor_target, 65 | "actor_optim": self.actor_optim, 66 | "critic": self.critic, 67 | "critic_target": self.critic_target, 68 | "critic_optim": self.critic_optim, 69 | } 70 | ) 71 | 72 | def select_action( 73 | self, 74 | state: Union[np.ndarray, th.Tensor], 75 | deterministic: bool, 76 | actor: nn.Module = None, 77 | **kwargs, 78 | ) -> th.Tensor: 79 | state = th.Tensor(state).to(self.device) if type(state) is np.ndarray else state 80 | 81 | if actor is None: 82 | action = self.actor(state) 83 | else: 84 | action = actor(state) 85 | 86 | action = th.tanh(action) 87 | 88 | # add explore noise 89 | if not deterministic: 90 | noise = th.clamp( 91 | th.randn_like(action) * self.cfg.agent.sigma, 92 | -self.cfg.agent.c, 93 | self.cfg.agent.c, 94 | ) 95 | action = th.clamp(action + noise, -1.0, 1.0) 96 | 97 | return action 98 | 99 | def update(self) -> Dict: 100 | self.stats = dict() 101 | rest_steps = self.trans_buffer.size - self.warmup_steps 102 | if not ( 103 | self.trans_buffer.size < self.batch_size 104 | or rest_steps < 0 105 | or rest_steps % self.env_steps != 0 106 | ): 107 | self.total_train_it += 1 108 | states, actions, next_states, rewards, dones = self.trans_buffer.sample( 109 | self.batch_size 110 | ) 111 | 112 | # update params 113 | for _ in range(self.env_steps): 114 | self._update_critic(states, actions, next_states, rewards, dones) 115 | if self.total_train_it % self.cfg.agent.policy_freq == 0: 116 | self._update_actor(states) 117 | 118 | polyak_update( 119 | self.critic.parameters(), 120 | self.critic_target.parameters(), 121 | self.cfg.agent.critic.tau, 122 | ) 123 | polyak_update( 124 | self.actor.parameters(), 125 | self.actor_target.parameters(), 126 | self.cfg.agent.actor.tau, 127 | ) 128 | 129 | return self.stats 130 | 131 | def _update_critic( 132 | self, 133 | states: th.Tensor, 134 | actions: th.Tensor, 135 | next_states: th.Tensor, 136 | rewards: th.Tensor, 137 | dones: th.Tensor, 138 | ): 139 | with th.no_grad(): 140 | pred_next_actions = self.select_action( 141 | next_states, 142 | deterministic=False, 143 | actor=self.actor_target, 144 | ) 145 | target_Q1, target_Q2 = self.critic_target( 146 | True, next_states, pred_next_actions 147 | ) 148 | target_Q = rewards + self.gamma * (1 - dones) * th.min(target_Q1, target_Q2) 149 | Q1, Q2 = self.critic(True, states, actions) 150 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 151 | self.stats.update( 152 | {"loss/critic": gradient_descent(self.critic_optim, critic_loss)} 153 | ) 154 | 155 | def _update_actor(self, states: th.Tensor): 156 | pred_actions = self.select_action(states, deterministic=True) 157 | Q = self.critic(False, states, pred_actions) 158 | actor_loss = -th.mean(Q) 159 | self.stats.update( 160 | {"loss/actor": gradient_descent(self.actor_optim, actor_loss)} 161 | ) 162 | -------------------------------------------------------------------------------- /src/trpo.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Callable, Dict, Tuple, Union 3 | 4 | import numpy as np 5 | import torch as th 6 | import torch.nn.functional as F 7 | from omegaconf import DictConfig 8 | from torch import nn, optim 9 | from torch.autograd import grad 10 | from torch.distributions.kl import kl_divergence 11 | from torch.distributions.normal import Normal 12 | from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters 13 | from torch.utils.data 
import BatchSampler 14 | 15 | from src.utils.drls.gae import GAE 16 | from src.utils.net.actor import MLPGaussianActor 17 | from src.utils.net.critic import MLPCritic 18 | from src.utils.net.ptu import gradient_descent, move_device 19 | 20 | from .base import BaseRLAgent 21 | 22 | 23 | class TRPOAgent(BaseRLAgent): 24 | """Trust Region Policy Optimization (TRPO)""" 25 | 26 | def __init__(self, cfg: DictConfig): 27 | super().__init__(cfg) 28 | 29 | def setup_model(self): 30 | # hyper-param 31 | self.lambda_ = self.cfg.agent.lambda_ 32 | 33 | ## conjugate gradient 34 | self.residual_tol = self.cfg.agent.residual_tol 35 | self.cg_steps = self.cfg.agent.cg_steps 36 | self.damping = self.cfg.agent.damping 37 | 38 | ## line search 39 | self.beta = self.cfg.agent.beta 40 | self.max_backtrack = self.cfg.agent.max_backtrack 41 | self.accept_ratio = self.cfg.agent.accept_ratio 42 | self.delta = self.cfg.agent.delta 43 | 44 | # GAE 45 | self.gae = GAE( 46 | self.gamma, 47 | self.lambda_, 48 | self.cfg.agent.norm_adv, 49 | self.cfg.agent.use_td_lambda, 50 | ) 51 | 52 | # actor 53 | actor_kwarg = { 54 | "state_shape": self.state_shape, 55 | "net_arch": self.cfg.agent.actor.net_arch, 56 | "action_shape": self.action_shape, 57 | "activation_fn": getattr(nn, self.cfg.agent.actor.activation_fn), 58 | } 59 | self.actor = MLPGaussianActor(**actor_kwarg) 60 | 61 | # value network 62 | value_net_kwarg = { 63 | "input_shape": self.state_shape, 64 | "output_shape": (1,), 65 | "net_arch": self.cfg.agent.value_net.net_arch, 66 | "activation_fn": getattr(nn, self.cfg.agent.value_net.activation_fn), 67 | } 68 | self.value_net = MLPCritic(**value_net_kwarg) 69 | self.value_net_optim = getattr(optim, self.cfg.agent.value_net.optimizer)( 70 | self.value_net.parameters(), self.cfg.agent.value_net.lr 71 | ) 72 | 73 | move_device((self.actor, self.value_net), self.device) 74 | 75 | self.models.update( 76 | { 77 | "actor": self.actor, 78 | "value_net": self.value_net, 79 | "value_net_optim": self.value_net_optim, 80 | } 81 | ) 82 | 83 | def select_action( 84 | self, state: np.ndarray, deterministic: bool, return_log_prob: bool, **kwarg 85 | ) -> Union[Tuple[th.Tensor, th.Tensor], th.Tensor]: 86 | return self.actor.sample(state, deterministic, return_log_prob, self.device) 87 | 88 | def update(self) -> Dict: 89 | self.stats = dict() 90 | if self.trans_buffer.size >= self.cfg.agent.rollout_steps: 91 | states, actions, next_states, rewards, dones = self.trans_buffer.buffers 92 | 93 | # get advantage 94 | with th.no_grad(): 95 | Rs, self.adv = self.gae( 96 | self.value_net, states, rewards, next_states, dones 97 | ) 98 | 99 | self._update_actor(states, actions) 100 | self._update_value_net(states, Rs) 101 | 102 | self.trans_buffer.clear() 103 | return self.stats 104 | 105 | def _update_value_net(self, states: th.Tensor, Rs: th.Tensor) -> float: 106 | idx = list(range(self.trans_buffer.size)) 107 | for _ in range(self.cfg.agent.value_net.n_update): 108 | random.shuffle(idx) 109 | batches = list( 110 | BatchSampler(idx, batch_size=self.batch_size, drop_last=False) 111 | ) 112 | for batch in batches: 113 | sampled_states = states[batch] 114 | values = self.value_net(sampled_states) 115 | loss = F.mse_loss(values, Rs[batch]) 116 | self.stats.update( 117 | {"loss/critic": gradient_descent(self.value_net_optim, loss)} 118 | ) 119 | 120 | def _update_actor(self, states: th.Tensor, actions: th.Tensor): 121 | original_actor_param = th.clone( 122 | parameters_to_vector(self.actor.parameters()).data 123 | ) 124 | 125 | ## pg 126 | 
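# policy gradient of the surrogate objective E[(pi / pi_old) * A]: build the loss from the log-prob ratio and advantages, differentiate w.r.t. the actor parameters, and flatten the result into a single vector g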
action_dist, log_probs = self._select_action_dist(states, actions) 127 | old_action_dist = Normal( 128 | action_dist.loc.data.clone(), action_dist.scale.data.clone() 129 | ) 130 | old_log_probs = log_probs.data.clone() 131 | 132 | loss = self._get_surrogate_loss(log_probs, old_log_probs) 133 | pg = grad(loss, self.actor.parameters(), retain_graph=True) 134 | pg = parameters_to_vector(pg).detach() 135 | 136 | ## x = H^{-1} * pg, H = kl_g' 137 | kl = th.mean(kl_divergence(old_action_dist, action_dist)) 138 | kl_g = grad(kl, self.actor.parameters(), create_graph=True) 139 | kl_g = parameters_to_vector(kl_g) 140 | 141 | update_dir = self._conjugate_gradient(kl_g, pg) # x 142 | Fvp = self._Fvp_func(kl_g, update_dir) # Hx 143 | full_step_size = th.sqrt( 144 | 2 * self.delta / th.dot(update_dir, Fvp) 145 | ) # denominator: x^t (Hx) 146 | 147 | ## line search for appropriate step size 148 | self.stats.update({"loss/actor": 0.0}) 149 | 150 | def check_constrain(alpha): 151 | step = alpha * full_step_size * update_dir 152 | with th.no_grad(): 153 | vector_to_parameters( 154 | original_actor_param + step, self.actor.parameters() 155 | ) 156 | try: 157 | new_action_dist, new_log_probs = self._select_action_dist( 158 | states, actions 159 | ) 160 | except Exception: 161 | vector_to_parameters( # restore actor 162 | original_actor_param, self.actor.parameters() 163 | ) 164 | return False 165 | new_loss = self._get_surrogate_loss(new_log_probs, old_log_probs) 166 | new_kl = th.mean(kl_divergence(old_action_dist, new_action_dist)) 167 | actual_improve = new_loss - loss 168 | 169 | if actual_improve.item() > 0.0 and new_kl.item() <= self.delta: 170 | self.stats.update({"loss/actor": new_loss.item()}) 171 | return True 172 | else: 173 | return False 174 | 175 | alpha = self._line_search(check_constrain) 176 | vector_to_parameters( 177 | original_actor_param + alpha * full_step_size * update_dir, 178 | self.actor.parameters(), 179 | ) 180 | 181 | def _select_action_dist( 182 | self, states: th.Tensor, actions: th.Tensor 183 | ) -> Tuple[Normal, th.Tensor]: 184 | action_mean, action_std = self.actor(states) 185 | action_dist = Normal(action_mean, action_std) 186 | log_prob = th.sum(action_dist.log_prob(actions), dim=-1, keepdim=True) 187 | return action_dist, log_prob 188 | 189 | def _line_search(self, check_constrain: Callable) -> float: 190 | alpha = 1.0 / self.beta 191 | for _ in range(self.max_backtrack): 192 | alpha *= self.beta 193 | if check_constrain(alpha): 194 | return alpha 195 | return 0.0 196 | 197 | def _get_surrogate_loss( 198 | self, log_probs: th.Tensor, old_log_probs: th.Tensor 199 | ) -> th.Tensor: 200 | return th.mean(th.exp(log_probs - old_log_probs) * self.adv) 201 | 202 | def _conjugate_gradient(self, kl_g: th.Tensor, pg: th.Tensor) -> th.Tensor: 203 | """To calculate x = H^{-1} pg without explicitly inverting H 204 | 205 | Ref: https://en.wikipedia.org/wiki/Conjugate_gradient_method 206 | 207 | Code modified from: https://github.com/ikostrikov/pytorch-trpo 208 | """ 209 | x = th.zeros_like(pg) 210 | r = pg.clone() 211 | p = pg.clone() 212 | rdotr = th.dot(r, r) 213 | for _ in range(self.cg_steps): 214 | _Fvp = self._Fvp_func(kl_g, p) 215 | alpha = rdotr / th.dot(p, _Fvp) 216 | x += alpha * p 217 | r -= alpha * _Fvp 218 | new_rdotr = th.dot(r, r) 219 | beta = new_rdotr / rdotr 220 | p = r + beta * p 221 | rdotr = new_rdotr 222 | if rdotr < self.residual_tol: 223 | break 224 | return x 225 | 226 | def _Fvp_func(self, kl_g: th.Tensor, p: th.Tensor) -> th.Tensor: 227 | """Fisher vector product""" 228 | gvp = 
th.dot(kl_g, p) 229 | Hvp = grad(gvp, self.actor.parameters(), retain_graph=True) 230 | Hvp = parameters_to_vector(Hvp).detach() 231 | # tricks to stablize 232 | # see https://www2.maths.lth.se/matematiklth/vision/publdb/reports/pdf/byrod-eccv-10.pdf 233 | Hvp += self.damping * p 234 | return Hvp 235 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import drls, logger, net, ospy 2 | 3 | __all__ = ["logger", "drls", "net", "ospy"] 4 | -------------------------------------------------------------------------------- /src/utils/drls/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffer import BaseBuffer, TransitionBuffer 2 | from .env import get_env_info, make_env, reset_env_fn 3 | from .gae import GAE 4 | 5 | __all__ = [BaseBuffer, TransitionBuffer, get_env_info, make_env, reset_env_fn, GAE] 6 | -------------------------------------------------------------------------------- /src/utils/drls/buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | from abc import ABC, abstractmethod 4 | from typing import Dict, List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from ..ospy.dataset import get_one_traj, save_dataset_to_h5, split_dataset_into_trajs 10 | 11 | 12 | class BaseBuffer(ABC): 13 | def __init__( 14 | self, 15 | state_shape: Tuple[int, ...], 16 | action_shape: Tuple[int, ...], 17 | action_dtype: Union[np.int64, np.float32], 18 | device: Union[str, th.device], 19 | buffer_size: int = -1, 20 | ) -> None: 21 | """If buffer size is not specified, it will continually add new items in without removal of old items.""" 22 | self.state_shape = state_shape 23 | self.action_shape = action_shape 24 | 25 | if action_dtype == "int": 26 | self.action_dtype = np.int64 27 | elif action_dtype == "float": 28 | self.action_dtype = np.float32 29 | else: 30 | raise ValueError("Unsupported action dtype!") 31 | 32 | self.device = device 33 | self.buffer_size = buffer_size if buffer_size != -1 else sys.maxsize 34 | self.buffers: List[th.Tensor] = [] 35 | self.clear() 36 | 37 | @abstractmethod 38 | def init_buffer(self) -> None: 39 | raise NotImplementedError 40 | 41 | @abstractmethod 42 | def insert_transition(self) -> None: 43 | raise NotImplementedError 44 | 45 | @abstractmethod 46 | def insert_batch(self) -> None: 47 | raise NotImplementedError 48 | 49 | @abstractmethod 50 | def insert_dataset(self) -> None: 51 | raise NotImplementedError 52 | 53 | def load_dataset( 54 | self, dataset: Dict[str, np.ndarray], n_traj: Optional[int] = None 55 | ) -> None: 56 | """Load [n_traj] trajs into the buffer""" 57 | 58 | if n_traj is None: 59 | self.insert_dataset(dataset) 60 | else: # randomly select [traj_num] trajectories 61 | traj_pairs = split_dataset_into_trajs(dataset) 62 | traj_pair = random.sample(traj_pairs, n_traj) 63 | for start_idx, end_idx in traj_pair: 64 | new_traj = get_one_traj(dataset, start_idx, end_idx) 65 | self.insert_dataset(new_traj) 66 | 67 | def sample( 68 | self, batch_size: Optional[int] = None, shuffle: bool = True 69 | ) -> List[th.Tensor]: 70 | """Randomly sample items from the buffer. 71 | 72 | If batch_size is not provided, we will sample all the stored items. 
73 | """ 74 | idx = list(range(self.size)) 75 | if shuffle: 76 | random.shuffle(idx) 77 | if batch_size is not None: 78 | idx = idx[:batch_size] 79 | 80 | return [buffer[idx] for buffer in self.buffers] 81 | 82 | def clear(self) -> None: 83 | self.init_buffer() 84 | self.buffers = [item.to(self.device) for item in self.buffers] 85 | self.ptr = 0 86 | self.size = 0 87 | self.total_size = 0 # Number of all the pushed items 88 | 89 | 90 | class TransitionBuffer(BaseBuffer): 91 | """ 92 | Transition buffer for single task 93 | """ 94 | 95 | def __init__( 96 | self, 97 | state_shape: Tuple[int, ...], 98 | action_shape: Tuple[int, ...], 99 | action_dtype: Union[np.int64, np.float32], 100 | device: Union[str, th.device], 101 | buffer_size: int = -1, 102 | ) -> None: 103 | super().__init__(state_shape, action_shape, action_dtype, device, buffer_size) 104 | 105 | def init_buffer(self) -> None: 106 | # Unlike some popular implementations, 107 | # we start with an empty buffer located in self.device (may be gpu). 108 | 109 | state_shape = (0,) + self.state_shape 110 | 111 | if self.action_dtype == np.int64: 112 | action_shape = (0, 1) 113 | else: 114 | action_shape = (0,) + self.action_shape 115 | 116 | self.buffers = [ 117 | th.zeros(state_shape), # state_buffer 118 | th.tensor(np.zeros(action_shape, dtype=self.action_dtype)), # action_buffer 119 | th.zeros(state_shape), # next_state_buffer 120 | th.zeros((0, 1)), # reward_buffer 121 | th.zeros((0, 1)), # done_buffer 122 | ] 123 | 124 | def insert_transition( 125 | self, 126 | state: np.ndarray, 127 | action: Union[np.ndarray, int], 128 | next_state: np.ndarray, 129 | reward: Union[np.ndarray, float], 130 | done: Union[np.ndarray, float], 131 | ) -> None: 132 | # state 133 | state, next_state = ( 134 | np.array(item, dtype=np.float32) for item in [state, next_state] 135 | ) 136 | # action 137 | if isinstance(action, (int, np.int64)): 138 | action = [action] 139 | action = np.array(action, dtype=self.action_dtype) 140 | # reward and done 141 | reward, done = ( 142 | ( 143 | np.array([item], dtype=np.float32) 144 | if isinstance(item, (float, np.float32)) 145 | else item 146 | ) 147 | for item in [reward, done] 148 | ) 149 | 150 | new_transition = [state, action, next_state, reward, done] 151 | new_transition = [th.tensor(item).to(self.device) for item in new_transition] 152 | 153 | if self.total_size <= self.buffer_size: 154 | self.buffers = [ 155 | th.cat((self.buffers[i], th.unsqueeze(new_transition[i], dim=0)), dim=0) 156 | for i in range(len(new_transition)) 157 | ] 158 | else: 159 | for buffer, new_data in zip(self.buffers, new_transition): 160 | buffer[self.ptr] = new_data 161 | 162 | # update pointer and size 163 | self.ptr = (self.ptr + 1) % self.buffer_size 164 | self.size = min(self.size + 1, self.buffer_size) 165 | self.total_size += 1 166 | 167 | def insert_batch( 168 | self, 169 | states: np.ndarray, 170 | actions: np.ndarray, 171 | next_states: np.ndarray, 172 | rewards: np.ndarray, 173 | dones: np.ndarray, 174 | ) -> None: 175 | """Insert a batch of transitions""" 176 | for i in range(states.shape[0]): 177 | self.insert_transition( 178 | states[i], actions[i], next_states[i], rewards[i], dones[i] 179 | ) 180 | 181 | def insert_dataset(self, dataset: Dict) -> None: 182 | """Insert dataset into the buffer""" 183 | observations, actions, next_observations, rewards, terminals = ( 184 | dataset["observations"], 185 | dataset["actions"], 186 | dataset["next_observations"], 187 | dataset["rewards"], 188 | dataset["terminals"], 189 | ) # we 
currently not consider the log_pi. But you can insert it with small modifications 190 | self.insert_batch(observations, actions, next_observations, rewards, terminals) 191 | 192 | def save_buffer(self, save_dir: str, file_name: Optional[str] = None) -> None: 193 | buffer = { 194 | "observations": self.buffers[0].cpu().numpy(), 195 | "actions": self.buffers[1].cpu().numpy(), 196 | "next_observations": self.buffers[2].cpu().numpy(), 197 | "rewards": self.buffers[3].cpu().numpy(), 198 | "terminals": self.buffers[4].cpu().numpy(), 199 | } 200 | save_dataset_to_h5( 201 | buffer, save_dir, "buffer" if file_name is None else file_name 202 | ) 203 | -------------------------------------------------------------------------------- /src/utils/drls/env.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple, Union 2 | 3 | import gymnasium as gym 4 | import numpy as np 5 | from gymnasium.spaces import Box, Discrete 6 | 7 | 8 | def _get_space_info(obj: gym.Space) -> Tuple[Tuple[int, ...], str]: 9 | if isinstance(obj, Box): 10 | shape = obj.shape 11 | type_ = "float" 12 | elif isinstance(obj, Discrete): 13 | shape = (obj.n.item(),) 14 | type_ = "int" 15 | else: 16 | raise TypeError("Currently only Box and Discrete are supported!") 17 | return shape, type_ 18 | 19 | 20 | def get_env_info(env: gym.Env) -> Dict[str, Union[Tuple[int, ...], str]]: 21 | state_shape, _ = _get_space_info(env.observation_space) 22 | action_shape, action_dtype = _get_space_info(env.action_space) 23 | 24 | env_info = { 25 | "state_shape": state_shape, 26 | "action_shape": action_shape, 27 | "action_dtype": action_dtype, 28 | } 29 | 30 | if isinstance(env.action_space, Box): 31 | env_info["action_scale"] = float(env.action_space.high[0]) 32 | 33 | return env_info 34 | 35 | 36 | def make_env(env_id: str) -> gym.Env: 37 | """Currently we only support the below simple env style""" 38 | try: 39 | env = gym.make(env_id) 40 | except: 41 | raise ValueError("Unsupported env id!") 42 | return env 43 | 44 | 45 | def reset_env_fn(env: gym.Env, seed: int) -> Tuple[np.ndarray, Dict[str, Any]]: 46 | next_state, info = env.reset(seed=seed) 47 | env.action_space.seed(seed) 48 | env.observation_space.seed(seed) 49 | return (next_state, info) 50 | -------------------------------------------------------------------------------- /src/utils/drls/gae.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch as th 4 | 5 | 6 | class GAE: 7 | """Estimate Advantage using GAE (https://arxiv.org/abs/1506.02438) 8 | 9 | Ref: 10 | [1] https://nn.labml.ai/rl/ppo/gae.html 11 | [2] https://github.com/ikostrikov/pytorch-trpo 12 | """ 13 | 14 | def __init__( 15 | self, 16 | gamma: float, 17 | lambda_: float, 18 | norm_adv: bool = True, 19 | use_td_lambda: bool = True, 20 | ) -> None: 21 | self.gamma = gamma 22 | self.lambda_ = lambda_ 23 | self.norm_adv = norm_adv 24 | self.use_td_lambda = use_td_lambda 25 | 26 | def __call__( 27 | self, 28 | value_net: th.nn.Module, 29 | states: th.Tensor, 30 | rewards: th.Tensor, 31 | next_states: th.Tensor, 32 | dones: th.Tensor, 33 | ) -> Tuple[th.Tensor, th.Tensor]: 34 | """Here we can use two different methods to calculate Returns""" 35 | not_dones = 1.0 - dones 36 | 37 | if self.use_td_lambda: 38 | Rs, advantages = self.td_lambda( 39 | value_net, states, rewards, next_states, not_dones 40 | ) 41 | else: 42 | Rs, advantages = self.gae( 43 | value_net, states, rewards, next_states, 
not_dones 44 | ) 45 | 46 | if self.norm_adv: 47 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) 48 | 49 | return Rs, advantages 50 | 51 | def gae( 52 | self, 53 | value_net: th.nn.Module, 54 | states: th.Tensor, 55 | rewards: th.Tensor, 56 | next_states: th.Tensor, 57 | not_dones: th.Tensor, 58 | ) -> Tuple[th.Tensor, th.Tensor]: 59 | Rs = th.empty_like(rewards) # reward-to-go R_t 60 | advantages = th.empty_like(rewards) # advantage 61 | values = value_net(states) 62 | 63 | last_value = value_net(next_states[-1]) 64 | last_return = th.clone(last_value) 65 | last_advantage = 0.0 66 | 67 | for t in reversed(range(rewards.shape[0])): 68 | # calculate rewards-to-go reward 69 | Rs[t] = rewards[t] + self.gamma * last_return * not_dones[t] 70 | # delta and advantage 71 | delta = rewards[t] + self.gamma * last_value * not_dones[t] - values[t] 72 | advantages[t] = ( 73 | delta + self.gamma * self.lambda_ * not_dones[t] * last_advantage 74 | ) 75 | # update pointer 76 | last_value = th.clone(values[t]) 77 | last_advantage = advantages[t].clone() 78 | last_return = Rs[t].clone() 79 | return Rs, advantages 80 | 81 | def td_lambda( 82 | self, 83 | value_net: th.nn.Module, 84 | states: th.Tensor, 85 | rewards: th.Tensor, 86 | next_states: th.Tensor, 87 | not_dones: th.Tensor, 88 | ) -> Tuple[th.Tensor, th.Tensor]: 89 | # calculate value 90 | values, next_values = value_net(states), value_net(next_states) 91 | # calculate TD errors. 92 | deltas = rewards + self.gamma * next_values * not_dones - values 93 | # initialize gae. 94 | advantages = th.empty_like(rewards) 95 | # calculate gae recursively from behind. 96 | advantages[-1] = deltas[-1] 97 | for t in reversed(range(rewards.size(0) - 1)): 98 | advantages[t] = ( 99 | deltas[t] + self.gamma * self.lambda_ * not_dones[t] * advantages[t + 1] 100 | ) 101 | 102 | return advantages + values, advantages 103 | -------------------------------------------------------------------------------- /src/utils/exp/__init__.py: -------------------------------------------------------------------------------- 1 | from .prepare import set_random_seed 2 | 3 | __all__ = ["set_random_seed"] 4 | -------------------------------------------------------------------------------- /src/utils/exp/prepare.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch as th 5 | 6 | 7 | def set_random_seed(seed: int) -> None: 8 | """ 9 | Seed the different random generators. 
10 | 11 | :param seed: 12 | """ 13 | # Seed python RNG 14 | random.seed(seed) 15 | # Seed numpy RNG 16 | np.random.seed(seed) 17 | # seed the RNG for all devices 18 | th.manual_seed(seed) 19 | -------------------------------------------------------------------------------- /src/utils/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from loguru import logger as console_logger 2 | 3 | from ._archive import archive_logs 4 | from ._logger import TBLogger 5 | from ._plot import average_smooth, tb2dict, window_smooth 6 | from ._sync import download_logs, upload_logs 7 | 8 | __all__ = [ 9 | console_logger, 10 | archive_logs, 11 | TBLogger, 12 | upload_logs, 13 | download_logs, 14 | tb2dict, 15 | average_smooth, 16 | window_smooth, 17 | ] 18 | -------------------------------------------------------------------------------- /src/utils/logger/_archive.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | 4 | from ..ospy.file import copys 5 | 6 | 7 | def archive_logs(exp_name: str, src_dir: str, tgt_dir: str = "archived"): 8 | """Locally archive src_dir/exp_name to tgt_dir""" 9 | os.makedirs(tgt_dir, exist_ok=True) 10 | copys(join(src_dir, exp_name), join(tgt_dir, exp_name)) 11 | -------------------------------------------------------------------------------- /src/utils/logger/_logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datetime import datetime 4 | from os.path import join 5 | from typing import Any, Dict, List 6 | 7 | import loguru 8 | import tqdm 9 | from tensorboardX import SummaryWriter 10 | 11 | from ..ospy.file import copys 12 | 13 | 14 | def _parse_record_param( 15 | args: Dict[str, Any], record_param: List[str] 16 | ) -> Dict[str, Any]: 17 | if args is None or record_param is None: 18 | return None 19 | else: 20 | record_param_dict = dict() 21 | for param in record_param: 22 | params = param.split(".") 23 | value = args 24 | for p in params: 25 | try: 26 | value = value[p] 27 | except: 28 | value = "" 29 | break 30 | record_param_dict[param] = value 31 | return record_param_dict 32 | 33 | 34 | def _get_exp_name(record_param_dict: Dict[str, Any], prefix: str = None): 35 | if prefix is not None: 36 | exp_name = prefix 37 | else: 38 | exp_name = datetime.now().strftime("%Y-%m-%d__%H-%M-%S") 39 | for key, value in record_param_dict.items(): 40 | if isinstance(value, str): 41 | value = "-".join(value.split(" ")) 42 | exp_name = exp_name + f"~{key}={value}" 43 | return exp_name 44 | 45 | 46 | class TBLogger: 47 | """Tensorboard Logger""" 48 | 49 | console = loguru.logger 50 | 51 | def __init__( 52 | self, 53 | work_dir: str = "./", 54 | args: Dict[str, Any] = {}, 55 | root_log_dir: str = "runs", 56 | record_param: List[str] = [], 57 | backup_code: bool = False, 58 | code_files_list: List[str] = None, 59 | console_output: bool = True, 60 | **kwargs, 61 | ): 62 | """ 63 | Args: 64 | work_dir: Path of the current work dir 65 | args: Hyper-parameters and configs 66 | root_log_dir: The root directory for all the logs 67 | record_param: Parameters used to name the log dir 68 | backup_code: Whether to backup code 69 | code_files_list: The list of code file/dir to backup 70 | console_output: Whether to output to the console 71 | """ 72 | self.args = args 73 | self.record_param = record_param 74 | self.work_dir = os.path.abspath(work_dir) 75 | self.root_log_dir = join(work_dir, 
root_log_dir) 76 | self.code_files_list = code_files_list 77 | self.record_param_dict = _parse_record_param(args, record_param) 78 | self.tqdm = tqdm 79 | 80 | # create log dirs 81 | self.exp_name = _get_exp_name(self.record_param_dict) 82 | self.exp_dir = join(self.root_log_dir, self.exp_name) 83 | self._create_artifact_dir() 84 | 85 | # init tb 86 | self.tb = SummaryWriter(log_dir=self.exp_dir, **kwargs) 87 | 88 | # init loguru 89 | if not console_output: 90 | self.console.remove() 91 | self.console_log_file = join(self.exp_dir, "console.log") 92 | self.console.add(self.console_log_file, format="{time} -- {level} -- {message}") 93 | 94 | if backup_code: 95 | self._backup_code() 96 | 97 | # save arguments 98 | self._save_args() 99 | 100 | def _create_artifact_dir(self): 101 | self.ckpt_dir = join(self.exp_dir, "ckpt") 102 | os.makedirs(self.ckpt_dir) # checkpoint, for model, data, etc. 103 | 104 | self.result_dir = join(self.exp_dir, "result") 105 | os.makedirs(self.result_dir) # result, for some intermediate result 106 | 107 | self.code_bk_dir = join(self.exp_dir, "code") 108 | os.makedirs(self.code_bk_dir) # back up code 109 | 110 | def _save_args(self): 111 | if self.args is None: 112 | return 113 | else: 114 | # pp = pprint.PrettyPrinter(indent=4) 115 | # pp.pprint(self.args) 116 | 117 | # self.console.info(f"Arguments: {self.args}") 118 | self.console.info( 119 | f"Save arguments to {join(self.exp_dir, 'parameter.json')}" 120 | ) 121 | with open(join(self.exp_dir, "parameter.json"), "w") as f: 122 | jd = json.dumps(self.args, indent=4) 123 | print(jd, file=f) 124 | 125 | def _backup_code(self): 126 | for code in self.code_files_list: 127 | src_path = join(self.work_dir, code) 128 | tgt_path = join(self.code_bk_dir, code) 129 | copys(src_path, tgt_path) 130 | 131 | # ================ Additional Helper Functions ================ 132 | 133 | def add_stats(self, stats: Dict[str, float], t: int): 134 | for key, value in stats.items(): 135 | self.tb.add_scalar(key, value, t) 136 | -------------------------------------------------------------------------------- /src/utils/logger/_plot.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | from tensorboard.backend.event_processing import event_accumulator 5 | 6 | 7 | def window_smooth(data: List[float], window_size: int = 10) -> List[float]: 8 | """Copy from https://github.com/openai/spinningup/blob/master/spinup/utils/plot.py 9 | 10 | smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k]), where the param 'window_size' equals to 2*k + 1 11 | """ 12 | if window_size > 1: 13 | """ 14 | smooth data with moving window average. 
15 | that is, 16 | smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k]) 17 | where the "smooth" param is width of that window (2k+1) 18 | """ 19 | y = np.ones(window_size) 20 | x = np.asarray(data) 21 | z = np.ones(len(x)) 22 | smooth_data = np.convolve(x, y, "same") / np.convolve(z, y, "same") 23 | smooth_data = smooth_data.tolist() 24 | else: 25 | smooth_data = data 26 | return smooth_data 27 | 28 | 29 | def average_smooth(data: List[float], lambda_: float = 0.6) -> List[float]: 30 | """y[t] = lambda_ * y[t-1] + (1-lambda_) * y[t]""" 31 | smooth_data = [] 32 | for i in range(len(data)): 33 | if i == 0: 34 | smooth_data.append(data[i]) 35 | else: 36 | smooth_data.append(smooth_data[-1] * lambda_ + data[i] * (1 - lambda_)) 37 | return smooth_data 38 | 39 | 40 | def tb2dict(tb_file_path: str, keys: List[str]) -> Dict[str, Dict[str, List[float]]]: 41 | """Convert tensorboard log file into a dict of data points""" 42 | ea = event_accumulator.EventAccumulator(tb_file_path) 43 | ea.Reload() 44 | statistics = dict() 45 | for key in keys: 46 | assert key in ea.scalars.Keys(), f"{key} is not recorded by the tensorboard!" 47 | items = ea.scalars.Items(key) 48 | steps, values = list(), list() 49 | for item in items: 50 | steps.append(item.step) 51 | values.append(item.value) 52 | statistics[key] = {"steps": steps, "values": values} 53 | return statistics 54 | -------------------------------------------------------------------------------- /src/utils/logger/_sync.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | import os 3 | import shutil 4 | from os.path import join 5 | from stat import S_ISDIR as is_remote_dir 6 | from stat import S_ISREG as is_remote_file 7 | 8 | import paramiko 9 | from paramiko.sftp_client import SFTPClient 10 | 11 | # ======================== Connect ======================== 12 | 13 | 14 | def connect_remote( 15 | host: str, 16 | port: int, 17 | ) -> SFTPClient: 18 | client = paramiko.SSHClient() 19 | client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) 20 | 21 | username = input("Please input your username: ") 22 | passwd = getpass.getpass("Please input your password: ") 23 | 24 | print(f"Connecting to {username}@{host}:{port}...") 25 | client.connect(host, port, username, passwd) 26 | print(f"Successfully connected to {username}@{host}:{port}!") 27 | 28 | return client.open_sftp() 29 | 30 | 31 | # ======================== Upload ======================== 32 | 33 | 34 | def _upload_file(sftp: SFTPClient, local_file_path: str, remote_file_path: str): 35 | sftp.put(local_file_path, remote_file_path) 36 | 37 | 38 | def _upload_dir( 39 | sftp: SFTPClient, 40 | local_log_dir: str, 41 | local_src_dir: str, 42 | remote_tgt_dir: str, 43 | verbose: int = 0, 44 | ): 45 | local_work_dir = join(local_src_dir, local_log_dir) 46 | remote_work_dir = join(remote_tgt_dir, local_log_dir) 47 | 48 | # If [local_log_dir] exists in remote, we will re-make it 49 | if local_log_dir in sftp.listdir(remote_tgt_dir): 50 | sftp.rmdir(remote_work_dir) 51 | sftp.mkdir(remote_work_dir) 52 | 53 | for item in os.listdir(local_work_dir): 54 | local_item_path = os.path.join(local_work_dir, item) 55 | remote_item_path = os.path.join(remote_work_dir, item) 56 | 57 | if os.path.isfile(local_item_path): 58 | if verbose == 1: # only report files 59 | print(f"Uploading {local_item_path} to {remote_item_path}...") 60 | _upload_file(sftp, local_item_path, remote_item_path) 61 | elif os.path.isdir(local_item_path): 62 | if verbose == 2: # only 
report dirs 63 | print(f"Uploading {local_item_path} to {remote_item_path}...") 64 | _upload_dir(sftp, item, local_work_dir, remote_work_dir, verbose) 65 | 66 | 67 | def upload_logs( 68 | host: str, 69 | port: int, 70 | local_log_name: str, 71 | local_src_dir: str, 72 | remote_tgt_dir: str, 73 | verbose: int = 0, 74 | ): 75 | """ 76 | Args: 77 | host: IP address of the remote server 78 | port: Port of the SSH 79 | local_log_name: file or directory name 80 | verbose: 81 | - 0, not output info during uploading 82 | - 1, output info of the uploaded files 83 | - 2, output info of the uploaded directories 84 | """ 85 | assert verbose in [0, 1, 2], "verbose must only be in [0, 1, 2]" 86 | local_log_path = os.path.join(local_src_dir, local_log_name) 87 | 88 | sftp = connect_remote(host=host, port=port) 89 | 90 | print(f"Start uploading logs from {local_log_path} to {host}:{port}!") 91 | if os.path.isfile(local_log_path): 92 | _upload_file(sftp, local_log_path, os.path.join(remote_tgt_dir, local_log_name)) 93 | else: 94 | _upload_dir( 95 | sftp, 96 | local_log_name, 97 | local_src_dir, 98 | remote_tgt_dir, 99 | ) 100 | print(f"Successfully finish uploading {local_log_path}!") 101 | 102 | 103 | # ======================== Download ======================== 104 | 105 | 106 | def _download_file(sftp: SFTPClient, remote_file_path: str, local_file_path: str): 107 | sftp.get(remote_file_path, local_file_path) 108 | 109 | 110 | def _download_dir( 111 | sftp: SFTPClient, 112 | remote_log_dir: str, 113 | remote_src_dir: str, 114 | local_tgt_dir: str, 115 | verbose: int = 0, 116 | ): 117 | local_work_dir = join(local_tgt_dir, remote_log_dir) 118 | remote_work_dir = join(remote_src_dir, remote_log_dir) 119 | 120 | # If [local_log_dir] exists in remote, we will re-make it 121 | if remote_log_dir in os.listdir(local_tgt_dir): 122 | shutil.rmtree(local_work_dir) 123 | os.makedirs(local_work_dir) 124 | 125 | for item in sftp.listdir(remote_work_dir): 126 | local_item_path = os.path.join(local_work_dir, item) 127 | remote_item_path = os.path.join(remote_work_dir, item) 128 | 129 | item_attr = sftp.lstat(remote_item_path) 130 | if is_remote_file(item_attr.st_mode): 131 | if verbose == 1: # only report files 132 | print(f"Downloading {remote_item_path} to {local_item_path}...") 133 | _download_file(sftp, remote_item_path, local_item_path) 134 | elif is_remote_dir(item_attr.st_mode): 135 | if verbose == 2: # only report dirs 136 | print(f"Downloading {remote_item_path} to {local_item_path}...") 137 | _download_dir(sftp, item, remote_work_dir, local_work_dir, verbose) 138 | 139 | 140 | def download_logs( 141 | host: str, 142 | port: int, 143 | remote_log_name: str, 144 | remote_src_dir: str, 145 | local_tgt_dir: str, 146 | verbose: int = 0, 147 | ): 148 | """ 149 | Args: 150 | host: IP address of the remote server 151 | port: Port of the SSH 152 | remote_log_name: file or directory name 153 | verbose: 154 | - 0, not output info during downloading 155 | - 1, output info of the downloaded files 156 | - 2, output info of the downloaded directories 157 | """ 158 | assert verbose in [0, 1, 2], "verbose must only be in [0, 1, 2]" 159 | remote_log_path = os.path.join(remote_src_dir, remote_log_name) 160 | 161 | sftp = connect_remote(host=host, port=port) 162 | 163 | print(f"Start downloading {remote_log_path} from {host}:{port} to {local_tgt_dir}!") 164 | if os.path.isfile(remote_log_path): 165 | _download_file( 166 | sftp, remote_log_path, os.path.join(local_tgt_dir, remote_log_name) 167 | ) 168 | else: 169 | _download_dir( 
170 | sftp, 171 | remote_log_name, 172 | remote_src_dir, 173 | local_tgt_dir, 174 | ) 175 | print(f"Successfully finish downloading {remote_log_path}!") 176 | 177 | 178 | # # Example: Upload Logs 179 | # load_dotenv("./remote.env") 180 | # """ 181 | # Content of remote.env: 182 | 183 | # HOSTNAME = "xx.xx.xx.xx" 184 | # PORT = 22 185 | # REMOTE_WORK_DIR = "/path/to/logs" 186 | # """ 187 | # upload( 188 | # hostname=os.environ["HOSTNAME"], 189 | # port=os.environ["PORT"], 190 | # local_log_name="logs", 191 | # local_src_dir="./", 192 | # remote_tgt_dir=os.environ["REMOTE_WORK_DIR"], 193 | # ) 194 | -------------------------------------------------------------------------------- /src/utils/net/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import MLPDeterministicActor, MLPGaussianActor 2 | from .critic import MLPCritic, MLPDuleQNet, MLPTwinCritic 3 | from .ptu import ( 4 | cnn, 5 | freeze_net, 6 | gradient_descent, 7 | load_torch_model, 8 | mlp, 9 | move_device, 10 | orthogonal_init, 11 | save_torch_model, 12 | set_eval_mode, 13 | set_torch, 14 | set_train_mode, 15 | tensor2ndarray, 16 | variable, 17 | ) 18 | 19 | __all__ = [ 20 | MLPDeterministicActor, 21 | MLPGaussianActor, 22 | MLPCritic, 23 | MLPDuleQNet, 24 | MLPTwinCritic, 25 | cnn, 26 | freeze_net, 27 | gradient_descent, 28 | load_torch_model, 29 | mlp, 30 | move_device, 31 | orthogonal_init, 32 | save_torch_model, 33 | set_torch, 34 | tensor2ndarray, 35 | variable, 36 | set_eval_mode, 37 | set_train_mode, 38 | ] 39 | -------------------------------------------------------------------------------- /src/utils/net/actor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import numpy as np 4 | import torch as th 5 | from torch.distributions.normal import Normal 6 | from torch.nn import Module, ReLU 7 | 8 | from .ptu import mlp, orthogonal_init, variable 9 | 10 | 11 | class MLPGaussianActor(Module): 12 | """ 13 | Gaussian actor for continuous action space 14 | """ 15 | 16 | def __init__( 17 | self, 18 | state_shape: Tuple[int, ...], 19 | action_shape: Tuple[int, ...], 20 | net_arch: List[int], 21 | state_std_independent: bool = False, 22 | activation_fn: Module = ReLU, 23 | log_std_max: float = 2, 24 | log_std_min: float = -20, 25 | **kwarg, 26 | ): 27 | """ 28 | :param state_std_independent: whether std is a function of state 29 | """ 30 | super().__init__() 31 | self.log_std_max = log_std_max 32 | self.log_std_min = log_std_min 33 | 34 | # network definition 35 | self.feature_extractor, feature_shape = mlp( 36 | state_shape, (-1,), net_arch, activation_fn, **kwarg 37 | ) 38 | self.mu, _ = mlp(feature_shape, action_shape, [], activation_fn, **kwarg) 39 | 40 | # to unify self.log_std to be a function 41 | if state_std_independent: 42 | self._log_std = variable((1,) + action_shape) 43 | self.log_std = lambda _: self._log_std 44 | else: 45 | self.log_std, _ = mlp( 46 | feature_shape, action_shape, [], activation_fn, **kwarg 47 | ) 48 | 49 | self.apply(orthogonal_init) 50 | 51 | def forward(self, state: th.Tensor): 52 | feature = self.feature_extractor(state) 53 | mu, log_std = self.mu(feature), self.log_std(feature) 54 | log_std = th.clamp(log_std, self.log_std_min, self.log_std_max) 55 | return mu, log_std.exp() 56 | 57 | def sample( 58 | self, 59 | state: Union[th.Tensor, np.ndarray], 60 | deterministic: bool, 61 | return_log_prob: bool, 62 | device: Union[th.device, str], 63 | ) -> 
Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: 64 | state = th.Tensor(state).to(device) if type(state) is np.ndarray else state 65 | 66 | action_mean, action_std = self.forward(state) 67 | dist = Normal(action_mean, action_std) 68 | 69 | if deterministic: 70 | x = action_mean 71 | else: 72 | x = dist.rsample() 73 | 74 | if return_log_prob: 75 | log_prob = th.sum(dist.log_prob(x), axis=-1, keepdims=True) 76 | 77 | return (x, log_prob) if return_log_prob else x 78 | 79 | 80 | class MLPDeterministicActor(Module): 81 | def __init__( 82 | self, 83 | state_shape: Tuple[int,], 84 | action_shape: Tuple[int,], 85 | net_arch: List[int], 86 | activation_fn: Module = ReLU, 87 | **kwarg, 88 | ): 89 | super().__init__() 90 | 91 | self.feature_extractor, feature_shape = mlp( 92 | state_shape, (-1,), net_arch, activation_fn, **kwarg 93 | ) 94 | self.output_head, _ = mlp( 95 | feature_shape, action_shape, [], activation_fn, **kwarg 96 | ) 97 | 98 | self.apply(orthogonal_init) 99 | 100 | def forward(self, state: th.Tensor): 101 | feature = self.feature_extractor(state) 102 | action = self.output_head(feature) 103 | return action 104 | -------------------------------------------------------------------------------- /src/utils/net/critic.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch as th 4 | from torch.nn import Module, ReLU 5 | 6 | from ..net.ptu import mlp, orthogonal_init 7 | 8 | 9 | class MLPCritic(Module): 10 | def __init__( 11 | self, 12 | input_shape: Tuple[int,], 13 | output_shape: Tuple[int,], 14 | net_arch: List[int], 15 | activation_fn: Module = ReLU, 16 | **kwarg, 17 | ): 18 | """ 19 | :param input_dim: input dimension (for vector) or input channel (for image) 20 | """ 21 | super().__init__() 22 | self.value_net, _ = mlp( 23 | input_shape, output_shape, net_arch, activation_fn, **kwarg 24 | ) 25 | self.apply(orthogonal_init) 26 | 27 | def forward(self, *input_) -> th.Tensor: 28 | input_ = th.cat(input_, dim=-1) 29 | return self.value_net(input_) 30 | 31 | 32 | class MLPTwinCritic(Module): 33 | def __init__( 34 | self, 35 | input_shape: Tuple[int,], 36 | output_shape: Tuple[int,], 37 | net_arch: List[int], 38 | activation_fn: Module = ReLU, 39 | **kwarg, 40 | ): 41 | super().__init__() 42 | 43 | self.Q_1, _ = mlp(input_shape, output_shape, net_arch, activation_fn, **kwarg) 44 | self.Q_2, _ = mlp(input_shape, output_shape, net_arch, activation_fn, **kwarg) 45 | 46 | self.apply(orthogonal_init) 47 | 48 | def forward(self, twin_value: bool, *input_): 49 | """ 50 | :param twin_value: whether to return both Q1 and Q2 51 | """ 52 | input_ = th.cat(input_, dim=-1) 53 | return (self.Q_1(input_), self.Q_2(input_)) if twin_value else self.Q_1(input_) 54 | 55 | 56 | class MLPDuleQNet(Module): 57 | """Dueling Q Network""" 58 | 59 | def __init__( 60 | self, 61 | input_shape: Tuple[int,], 62 | output_shape: Tuple[int,], 63 | net_arch: List[int], 64 | v_head: List[int], 65 | adv_head: List[int], 66 | activation_fn: Module = ReLU, 67 | mix_type: str = "max", 68 | **kwarg, 69 | ): 70 | super().__init__() 71 | self.feature_extrator, feature_shape = mlp( 72 | input_shape, (-1,), net_arch, activation_fn, **kwarg 73 | ) 74 | self.value_head, _ = mlp(feature_shape, (1,), v_head, activation_fn, **kwarg) 75 | self.adv_head, _ = mlp( 76 | feature_shape, output_shape, adv_head, activation_fn, **kwarg 77 | ) 78 | self.mix_type = mix_type 79 | 80 | self.apply(orthogonal_init) 81 | 82 | def forward(self, state: th.Tensor): 83 | 
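# dueling aggregation: Q(s, a) = V(s) + A(s, a) - agg_a A(s, a), where agg is the max or mean over actions depending on mix_type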
feature = self.feature_extrator(state) 84 | v = self.value_head(feature) 85 | adv = self.adv_head(feature) 86 | if self.mix_type == "max": 87 | q = v + (adv - th.max(adv, dim=-1, keepdim=True)[0]) 88 | elif self.mix_type == "mean": 89 | q = v + (adv - th.mean(adv, dim=-1, keepdim=True)) 90 | else: 91 | raise NotImplementedError 92 | return q 93 | -------------------------------------------------------------------------------- /src/utils/net/ptu.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | from os.path import join 3 | from typing import Dict, Iterable, List, Tuple, Union 4 | 5 | import torch as th 6 | from torch import nn 7 | from torch.optim import Optimizer 8 | 9 | # --------------------- Setting -------------------- 10 | 11 | 12 | def set_torch(default_th_dtype: th.dtype = th.float32): 13 | th.set_default_dtype(default_th_dtype) 14 | th.utils.backcompat.broadcast_warning.enabled = True 15 | th.utils.backcompat.keepdim_warning.enabled = True 16 | th.set_float32_matmul_precision("high") 17 | 18 | 19 | # --------------------- Tensor --------------------- 20 | 21 | 22 | def tensor2ndarray(tensors: Tuple[th.Tensor]): 23 | """Convert torch.Tensor to numpy.ndarray""" 24 | result = [] 25 | for item in tensors: 26 | if th.is_tensor(item): 27 | result.append(item.detach().cpu().numpy()) 28 | else: 29 | result.append(item) 30 | return result 31 | 32 | 33 | # ------------------- Manipulate NN Module ---------------------- 34 | 35 | 36 | def move_device(modules: List[th.nn.Module], device: Union[str, th.device]): 37 | """Move net to specified device""" 38 | for module in modules: 39 | module.to(device) 40 | 41 | 42 | def freeze_net(nets: List[nn.Module]): 43 | for net in nets: 44 | for p in net.parameters(): 45 | p.requires_grad = False 46 | 47 | 48 | def save_torch_model( 49 | models: Dict[str, Union[nn.Module, th.Tensor]], 50 | ckpt_dir: str, 51 | model_name: str = "models", 52 | file_ext: str = ".pt", 53 | ) -> str: 54 | """Save [Pytorch] model to a pre-specified path 55 | Note: Currently, only th.Tensor and th.nn.Module are supported. 56 | """ 57 | model_name = model_name + file_ext 58 | model_path = join(ckpt_dir, model_name) 59 | state_dicts = {} 60 | for name, model in models.items(): 61 | if isinstance(model, th.Tensor): 62 | state_dicts[name] = {name: model} 63 | else: 64 | state_dicts[name] = model.state_dict() 65 | th.save(state_dicts, model_path) 66 | return f"Successfully save model to {model_path}!" 67 | 68 | 69 | def load_torch_model( 70 | models: Dict[str, Union[nn.Module, th.Tensor]], model_path: str 71 | ) -> str: 72 | """Load [Pytorch] model from a pre-specified path""" 73 | state_dicts = th.load(model_path, weights_only=True) 74 | for name, model in models.items(): 75 | if isinstance(model, th.Tensor): 76 | models[name].data = state_dicts[name][name].data 77 | else: 78 | model.load_state_dict(state_dicts[name]) 79 | return f"Successfully load model from {model_path}!" 
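# A minimal save/load round trip with the two helpers above (an illustrative sketch, not part of the
# library API; "actor" is a placeholder nn.Module and the ckpt_dir path is assumed to already exist):
#
#     models = {"actor": actor, "log_alpha": th.zeros(1, requires_grad=True)}
#     print(save_torch_model(models, ckpt_dir="runs/demo/ckpt", model_name="final_model"))
#     print(load_torch_model(models, "runs/demo/ckpt/final_model.pt"))
#
# Both helpers operate on the same Dict[str, Union[nn.Module, th.Tensor]] layout that the agents
# keep in self.models, so a run can be checkpointed and later restored in place.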
80 | 81 | 82 | def set_train_mode(models: Dict[str, Union[nn.Module, th.Tensor]]): 83 | """Set mode of the models to train""" 84 | for model in models: 85 | if isinstance(models[model], nn.Module): 86 | models[model].train() 87 | 88 | 89 | def set_eval_mode(models: Dict[str, Union[nn.Module, th.Tensor]]): 90 | """Set mode of the models to eval""" 91 | for model in models: 92 | if isinstance(models[model], nn.Module): 93 | models[model].eval() 94 | 95 | 96 | # copied from stable_baselines3 97 | def zip_strict(*iterables: Iterable) -> Iterable: 98 | r""" 99 | ``zip()`` function but enforces that iterables are of equal length. 100 | Raises ``ValueError`` if iterables not of equal length. 101 | Code inspired by Stackoverflow answer for question #32954486. 102 | 103 | :param \*iterables: iterables to ``zip()`` 104 | """ 105 | # As in Stackoverflow #32954486, use 106 | # new object for "empty" in case we have 107 | # Nones in iterable. 108 | sentinel = object() 109 | for combo in zip_longest(*iterables, fillvalue=sentinel): 110 | if sentinel in combo: 111 | raise ValueError("Iterables have different lengths") 112 | yield combo 113 | 114 | 115 | def polyak_update( 116 | params: Iterable[th.Tensor], 117 | target_params: Iterable[th.Tensor], 118 | tau: float, 119 | ) -> None: 120 | """ 121 | Perform a Polyak average update on ``target_params`` using ``params``: 122 | target parameters are slowly updated towards the main parameters. 123 | ``tau``, the soft update coefficient controls the interpolation: 124 | ``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``. 125 | The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors, 126 | or a computation graph, reducing memory cost and improving performance. We scale the target params 127 | by ``1-tau`` (in-place), add the new weights, scaled by ``tau`` and store the result of the sum in the target 128 | params (in place). 129 | See https://github.com/DLR-RM/stable-baselines3/issues/93 130 | 131 | :param params: parameters to use to update the target params 132 | :param target_params: parameters to update 133 | :param tau: the soft update coefficient ("Polyak update", between 0 and 1) 134 | """ 135 | with th.no_grad(): 136 | # zip does not raise an exception if length of parameters does not match. 
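# soft update: target <- (1 - tau) * target + tau * param, applied in place to every parameter pair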
137 | for param, target_param in zip_strict(params, target_params): 138 | target_param.data.mul_(1 - tau) 139 | th.add(target_param.data, param.data, alpha=tau, out=target_param.data) 140 | 141 | 142 | # ------------------ Initialization ---------------------------- 143 | 144 | 145 | def orthogonal_init(m): 146 | """Custom weight init for Conv2D and Linear layers.""" 147 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 148 | nn.init.orthogonal_(m.weight.data) 149 | if hasattr(m.bias, "data"): 150 | m.bias.data.fill_(0.0) 151 | 152 | 153 | # ----------------------- Optimization ---------------------------- 154 | 155 | 156 | def gradient_descent( 157 | net_optim: Optimizer, 158 | loss: th.Tensor, 159 | parameters: Union[th.Tensor, Iterable[th.Tensor]] = None, 160 | max_grad_norm: float = None, 161 | retain_graph: bool = False, 162 | ): 163 | """Update network parameters with gradient descent.""" 164 | net_optim.zero_grad() 165 | loss.backward(retain_graph=retain_graph) 166 | 167 | # gradient clip 168 | if all([parameters, max_grad_norm]): 169 | th.nn.utils.clip_grad_norm_(parameters, max_grad_norm) 170 | 171 | net_optim.step() 172 | return loss.item() 173 | 174 | 175 | # ------------------------ Modules ------------------------ 176 | 177 | 178 | def create_mlp( 179 | input_dim: int, 180 | output_dim: int, 181 | net_arch: List[int], 182 | activation_fn: nn.Module = nn.ReLU, 183 | squash_output: bool = False, 184 | with_bias: bool = True, 185 | ) -> List[nn.Module]: 186 | """ 187 | Copied from stable_baselines 188 | 189 | Create a multi layer perceptron (MLP), which is 190 | a collection of fully-connected layers each followed by an activation function. 191 | 192 | :param input_dim: Dimension of the input vector 193 | :param output_dim: 194 | :param net_arch: Architecture of the neural net 195 | It represents the number of units per layer. 196 | The length of this list is the number of layers. 197 | :param activation_fn: The activation function 198 | to use after each layer. 
199 | :param squash_output: Whether to squash the output using a Tanh 200 | activation function 201 | :param with_bias: If set to False, the layers will not learn an additive bias 202 | :return: 203 | """ 204 | 205 | if len(net_arch) > 0: 206 | modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()] 207 | else: 208 | modules = [] 209 | 210 | for idx in range(len(net_arch) - 1): 211 | modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias)) 212 | modules.append(activation_fn()) 213 | 214 | if output_dim > 0: 215 | last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim 216 | modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias)) 217 | if squash_output: 218 | modules.append(nn.Tanh()) 219 | return modules 220 | 221 | 222 | def variable(shape: Tuple[int, ...]): 223 | return nn.Parameter(th.zeros(shape), requires_grad=True) 224 | 225 | 226 | def mlp( 227 | input_shape: Tuple[int,], 228 | output_shape: Tuple[int,], 229 | net_arch: List[int], 230 | activation_fn: nn.Module = nn.ReLU, 231 | squash_output: bool = False, 232 | ) -> Tuple[List[nn.Module], int]: 233 | """ 234 | :return: (net, feature_dim) 235 | """ 236 | # output feature dimension 237 | if output_shape[0] == -1: 238 | if len(net_arch) > 0: 239 | feature_shape = (net_arch[-1], 0) 240 | else: 241 | raise ValueError("Empty MLP!") 242 | else: 243 | feature_shape = output_shape 244 | # networks 245 | net = nn.Sequential( 246 | *create_mlp( 247 | input_shape[0], output_shape[0], net_arch, activation_fn, squash_output 248 | ) 249 | ) 250 | return net, feature_shape 251 | 252 | 253 | def cnn( 254 | input_shape: List[int], 255 | output_dim: int, 256 | net_arch: List[Tuple[int]], 257 | activation_fn: nn.Module = nn.ReLU, 258 | ) -> Tuple[List[nn.Module], int]: 259 | """ 260 | :param input_shape: (channel, ...) 
261 | :net_arch: list of conv2d, i.e., (output_channel, kernel_size, stride, padding) 262 | """ 263 | input_channel = input_shape[0] 264 | 265 | if len(net_arch) > 0: 266 | module = [nn.Conv2d(input_channel, *net_arch[0]), activation_fn()] 267 | else: 268 | raise ValueError("Empty CNN!") 269 | 270 | # parse modules 271 | for i in range(1, len(net_arch)): 272 | module.append(nn.Conv2d(net_arch[i - 1][0], *net_arch[i])) 273 | module.append(activation_fn()) 274 | net = nn.Sequential(*module) 275 | net.add_module("flatten-0", nn.Flatten()) 276 | 277 | # Compute shape by doing one forward pass 278 | with th.no_grad(): 279 | n_flatten = net(th.randn(input_shape).unsqueeze(dim=0)).shape[1] 280 | 281 | # We use -1 to just extract the feature 282 | if output_dim == -1: 283 | return net, n_flatten 284 | else: 285 | net.add_module("linear-0", nn.Linear(n_flatten, output_dim)) 286 | return net, output_dim 287 | -------------------------------------------------------------------------------- /src/utils/ospy/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import ( 2 | get_dataset, 3 | get_dataset_holder, 4 | get_h5_keys, 5 | get_one_traj, 6 | save_dataset_to_h5, 7 | split_dataset_into_trajs, 8 | ) 9 | from .file import copys 10 | from .util import filter_from_list 11 | 12 | __all__ = [ 13 | get_dataset, 14 | get_dataset_holder, 15 | get_h5_keys, 16 | get_one_traj, 17 | save_dataset_to_h5, 18 | split_dataset_into_trajs, 19 | copys, 20 | filter_from_list, 21 | ] 22 | -------------------------------------------------------------------------------- /src/utils/ospy/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | from typing import Dict, List 4 | 5 | import h5py 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | # ================ Helpers ======================= 10 | 11 | 12 | def get_h5_keys(h5file: h5py.File) -> List[str]: 13 | keys = [] 14 | 15 | def visitor(name, item): 16 | if isinstance(item, h5py.Dataset): 17 | keys.append(name) 18 | 19 | h5file.visititems(visitor) 20 | return keys 21 | 22 | 23 | def get_dataset_holder(with_log_prob: bool): 24 | """To determine the portions of demos""" 25 | dataset = dict( 26 | observations=[], 27 | actions=[], 28 | rewards=[], 29 | next_observations=[], 30 | terminals=[], 31 | timeouts=[], 32 | ) 33 | if with_log_prob: 34 | dataset["infos/action_log_probs"] = [] 35 | return dataset 36 | 37 | 38 | # ================ Utility functions ================ 39 | 40 | 41 | def split_dataset_into_trajs( 42 | dataset: Dict[str, np.ndarray], max_episode_steps: int = None 43 | ): 44 | """Split the [D4RL] style dataset into trajectories 45 | 46 | :return: the corresponding start index and end index (not included) of every trajectories 47 | """ 48 | max_steps = dataset["observations"].shape[0] 49 | if "timeouts" in dataset: 50 | timeout_idx = np.where(dataset["timeouts"] == True)[0] + 1 51 | terminal_idx = np.where(dataset["terminals"] == True)[0] + 1 52 | start_idx = sorted( 53 | set( 54 | [0] 55 | + timeout_idx[timeout_idx < max_steps].tolist() 56 | + terminal_idx[terminal_idx < max_steps].tolist() 57 | + [max_steps] 58 | ) 59 | ) 60 | traj_pairs = list(zip(start_idx[:-1], start_idx[1:])) 61 | else: 62 | if max_episode_steps is None: 63 | raise Exception( 64 | "You have to specify the max_episode_steps if no timeouts in dataset" 65 | ) 66 | else: 67 | traj_pairs = [] 68 | i = 0 69 | while i < max_steps: 70 | start_idx = i 
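# scan forward until a terminal flag or max_episode_steps is reached, so that [start_idx, i) bounds one trajectory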
71 | traj_len = 1 72 | while (traj_len <= max_episode_steps) and (i < max_steps): 73 | i += 1 74 | traj_len += 1 75 | if dataset["terminals"][i - 1]: 76 | break 77 | traj_pairs.append([start_idx, i]) 78 | return traj_pairs 79 | 80 | 81 | # ============================= Save ============================== 82 | 83 | 84 | def save_dataset_to_h5(dataset: Dict[str, np.ndarray], save_dir: str, file_name: str): 85 | """To dump dataset into .hdf5 file 86 | 87 | :param dataset: Dataset to be saved 88 | :param save_dir: To save the collected demos 89 | :param file_name: File name of the saved demos 90 | """ 91 | os.makedirs(save_dir, exist_ok=True) 92 | save_path = join(save_dir, file_name + ".hdf5") 93 | hfile = h5py.File(save_path, "w") 94 | for key, value in dataset.items(): 95 | hfile.create_dataset(key, data=value, compression="gzip") 96 | 97 | 98 | # ============================= Get ============================== 99 | 100 | 101 | def get_one_traj( 102 | dataset: Dict[str, np.ndarray], 103 | start_idx: int, 104 | end_idx: int, 105 | with_log_prob: bool = False, 106 | ): 107 | """Return a trajectory in dataset, from start_idx to end_idx (not included).""" 108 | one_traj = { 109 | "observations": dataset["observations"][start_idx:end_idx], 110 | "actions": dataset["actions"][start_idx:end_idx], 111 | "rewards": dataset["rewards"][start_idx:end_idx], 112 | "next_observations": dataset["next_observations"][start_idx:end_idx], 113 | "terminals": dataset["terminals"][start_idx:end_idx], 114 | } 115 | if with_log_prob and "infos/action_log_probs" in dataset: 116 | one_traj.update( 117 | { 118 | "infos/action_log_probs": dataset["infos/action_log_probs"][ 119 | start_idx:end_idx 120 | ] 121 | } 122 | ) 123 | return one_traj 124 | 125 | 126 | def get_dataset( 127 | use_own_dataset: bool, 128 | own_dataset_path: str = None, 129 | d4rl_env_id: str = None, 130 | **kwargs 131 | ) -> Dict[str, np.ndarray]: 132 | if use_own_dataset: 133 | assert ( 134 | own_dataset_path is not None 135 | ), "To use your own dataset, you must fisrt specify your dataset path" 136 | dataset = dict() 137 | with h5py.File(own_dataset_path, "r") as dataset_file: 138 | for k in tqdm(get_h5_keys(dataset_file), desc="load datafile"): 139 | dataset[k] = dataset_file[k][:] 140 | return dataset 141 | else: 142 | import d4rl 143 | 144 | # d4rl is not compatible with gymnasium but only gym 145 | import gym 146 | 147 | env = gym.make(d4rl_env_id) 148 | import gymnasium as gym 149 | 150 | return env.get_dataset() 151 | -------------------------------------------------------------------------------- /src/utils/ospy/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | def copys(src_path: str, tgt_path: str): 6 | if os.path.isfile(src_path): 7 | shutil.copy(src_path, tgt_path) 8 | elif os.path.isdir(src_path): 9 | shutil.copytree(src_path, tgt_path) 10 | else: 11 | raise TypeError("Unknown code file type!") 12 | -------------------------------------------------------------------------------- /src/utils/ospy/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List 3 | 4 | 5 | def filter_from_list(file_list: List[str], rule: str) -> List[str]: 6 | """Could be used to match files given the [rule]""" 7 | return list(filter(lambda x: re.match(rule, x) != None, file_list)) 8 | -------------------------------------------------------------------------------- /train_agent.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | from typing import Callable 4 | 5 | import gymnasium as gym 6 | import hydra 7 | import numpy as np 8 | import torch as th 9 | from omegaconf import DictConfig, OmegaConf 10 | 11 | from src import BaseRLAgent, create_agent 12 | from src.utils.drls.env import get_env_info, make_env, reset_env_fn 13 | from src.utils.exp.prepare import set_random_seed 14 | from src.utils.logger import TBLogger 15 | from src.utils.net.ptu import ( 16 | save_torch_model, 17 | set_eval_mode, 18 | set_torch, 19 | set_train_mode, 20 | tensor2ndarray, 21 | ) 22 | 23 | 24 | @th.no_grad 25 | def eval_policy( 26 | eval_env: gym.Env, 27 | reset_env_fn: Callable, 28 | policy: BaseRLAgent, 29 | seed: int, 30 | episodes=10, 31 | ): 32 | """Evaluate Policy""" 33 | set_eval_mode(policy.models) 34 | returns = [] 35 | for _ in range(episodes): 36 | (state, _), terminated, truncated = reset_env_fn(eval_env, seed), False, False 37 | return_ = 0.0 38 | while not (terminated or truncated): 39 | action = policy.select_action( 40 | state, 41 | deterministic=True, 42 | return_log_prob=False, 43 | **{"action_space": eval_env.action_space}, 44 | ) 45 | state, reward, terminated, truncated, _ = eval_env.step( 46 | tensor2ndarray((action,))[0] 47 | ) 48 | return_ += reward 49 | returns.append(return_) 50 | set_train_mode(policy.models) 51 | 52 | # average 53 | return np.mean(returns) 54 | 55 | 56 | @hydra.main(config_path="./conf", config_name="train_agent", version_base="1.3.2") 57 | def main(cfg: DictConfig): 58 | cfg.work_dir = os.getcwd() 59 | # prepare experiment 60 | set_torch() 61 | set_random_seed(cfg.seed) 62 | 63 | # setup logger 64 | logger = TBLogger( 65 | args=OmegaConf.to_object(cfg), 66 | record_param=cfg.log.record_param, 67 | console_output=cfg.log.console_output, 68 | ) 69 | 70 | # setup environment 71 | train_env, eval_env = (make_env(cfg.env.id), make_env(cfg.env.id)) 72 | OmegaConf.update(cfg, "env[info]", get_env_info(eval_env), merge=False) 73 | 74 | # create agent 75 | agent = create_agent(cfg) 76 | 77 | # train agent 78 | def ctr_c_handler(_signum, _frame): 79 | """If the program was stopped by ctr+c, we will save the model before leaving""" 80 | logger.console.warning("The program is stopped...") 81 | logger.console.info( 82 | save_torch_model(agent.models, logger.ckpt_dir, "stopped_model") 83 | ) # save model 84 | exit(1) 85 | 86 | signal.signal(signal.SIGINT, ctr_c_handler) 87 | 88 | agent.learn(train_env, eval_env, reset_env_fn, eval_policy, logger) 89 | 90 | # save model 91 | logger.console.info(save_torch_model(agent.models, logger.ckpt_dir, "final_model")) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | --------------------------------------------------------------------------------
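A minimal launch sketch for the entry point above (illustrative only: the agent=td3 group name and the concrete override values are assumptions; the keys mirror the cfg.agent, cfg.env.id, cfg.seed, and cfg.log.console_output fields used in train_agent.py):

    # single training run with Hydra overrides
    python train_agent.py agent=td3 env.id=Hopper-v5 seed=0 log.console_output=true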