├── .gitignore ├── README.md ├── fig └── method.png ├── log └── directory_place_holder ├── meta_config ├── mt10 │ ├── modular_2_2_2_256_rand.json │ ├── modular_2_2_2_256_reweight.json │ ├── modular_2_2_2_256_reweight_rand.json │ ├── modular_4_4_2_128_reweight.json │ ├── modular_4_4_2_128_reweight_rand.json │ ├── mtmhsac.json │ ├── mtmhsac_rand.json │ ├── mtsac.json │ └── mtsac_rand.json └── mt50 │ ├── modular_2_2_2_256_reweight.json │ ├── modular_2_2_2_256_reweight_rand.json │ ├── modular_4_4_2_128_reweight.json │ ├── modular_4_4_2_128_reweight_rand.json │ ├── mtmhsac.json │ ├── mtmhsac_rand.json │ ├── mtsac.json │ └── mtsac_rand.json ├── metaworld_utils ├── __init__.py └── meta_env.py ├── starter ├── mt_para_mhmt_sac.py ├── mt_para_mtsac.py └── mt_para_mtsac_modular_gated_cas.py └── torchrl ├── __init__.py ├── algo ├── __init__.py ├── off_policy │ ├── __init__.py │ ├── mt_sac.py │ ├── mtmh_sac.py │ ├── off_rl_algo.py │ ├── sac.py │ ├── twin_sac.py │ └── twin_sac_q.py ├── rl_algo.py └── utils.py ├── collector ├── __init__.py ├── base.py ├── mt.py └── para │ ├── __init__.py │ ├── async_mt.py │ ├── base.py │ └── mt.py ├── env ├── __init__.py ├── base_wrapper.py ├── continuous_wrapper.py └── get_env.py ├── networks ├── __init__.py ├── base.py ├── init.py └── nets.py ├── policies ├── __init__.py ├── continuous_policy.py └── distribution.py ├── replay_buffers ├── __init__.py ├── base.py └── shared │ ├── __init__.py │ ├── base.py │ └── shmarray.py └── utils ├── __init__.py ├── args.py ├── logger.py └── plot_csv.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Soft-Module 2 | 3 | Implementation of "Multi-Task Reinforcement Learning with Soft Modularization" 4 | 5 | Paper Link: [Multi-Task Reinforcement Learning with Soft Modularization](https://arxiv.org/abs/2003.13661) 6 | 7 | ![Demo](./fig/method.png) 8 | 9 | Our project page is at [https://rchalyang.github.io/SoftModule/](https://rchalyang.github.io/SoftModule/) 10 | 11 | ## Setup Environment 12 | 13 | ### Environment Requirements 14 | * Python 3 15 | * PyTorch 1.7 16 | * posix_ipc 17 | * tensorboardX 18 | * tabulate, gym 19 | * MetaWorld (see the next section for setup) 20 | * seaborn (for plotting) 21 | 22 | ### MetaWorld Setup 23 | We evaluated our method on [MetaWorld](https://meta-world.github.io). 24 | 25 | Since [MetaWorld](https://meta-world.github.io) is under active development, we ran all experiments on our fork of MetaWorld (https://github.com/RchalYang/metaworld). 26 | 27 | ``` 28 | # Our MetaWorld installation 29 | git clone https://github.com/RchalYang/metaworld.git 30 | cd metaworld 31 | pip install -e . 32 | ``` 33 | 34 | ## Our Network Structure 35 | 36 | See ```ModularGatedCascadeCondNet``` in ```torchrl/networks/nets.py``` for details. A simplified sketch of the soft routing idea follows. 37 |
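The sketch below is illustrative only: it is not the repo's ```ModularGatedCascadeCondNet```, and the encoders, gating computation, and layer sizes are simplified assumptions. It shows the core soft-modularization idea behind the modular networks: a gating network conditioned on the task embedding produces softmax routing weights that softly mix the outputs of parallel modules at every layer.

```
# Illustrative sketch of soft modularization -- NOT the repo's ModularGatedCascadeCondNet.
# A task-conditioned gate produces softmax weights that softly mix parallel modules per layer.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftModularSketch(nn.Module):
    def __init__(self, obs_dim, emb_dim, num_layers=2, num_modules=2,
                 module_hidden=256, out_dim=8):
        super().__init__()
        self.obs_encoder = nn.Linear(obs_dim, module_hidden)
        self.emb_encoder = nn.Linear(emb_dim, module_hidden)
        # num_layers stacks of num_modules parallel sub-networks
        self.blocks = nn.ModuleList([
            nn.ModuleList([nn.Linear(module_hidden, module_hidden)
                           for _ in range(num_modules)])
            for _ in range(num_layers)])
        # one gating head per layer, emitting a weight per module
        self.gates = nn.ModuleList([
            nn.Linear(module_hidden, num_modules) for _ in range(num_layers)])
        self.head = nn.Linear(module_hidden, out_dim)

    def forward(self, obs, task_embedding):
        x = F.relu(self.obs_encoder(obs))
        # gate input conditioned on both observation features and the task embedding
        g = F.relu(self.emb_encoder(task_embedding)) * x
        for modules, gate in zip(self.blocks, self.gates):
            weights = F.softmax(gate(g), dim=-1)                  # soft routing weights
            outs = torch.stack([F.relu(m(x)) for m in modules], dim=-1)
            x = (outs * weights.unsqueeze(-2)).sum(dim=-1)        # weighted mix of modules
        return self.head(x)
```

For the MT10 policy, ```obs_dim``` would be the observation size and ```emb_dim``` the 10-dimensional task one-hot; the real network additionally cascades gating information across layers and outputs the Gaussian policy parameters.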
38 | ## Training 39 | 40 | All logs and snapshots are stored in the logging directory, which defaults to "./log/EXPERIMENT_NAME". 41 | 42 | EXPERIMENT_NAME can be set with the "--id" argument when starting an experiment, and the log-directory prefix can be set with the "--log_dir" argument. 43 | 44 | ``` 45 | # Modular Network // MT10-Conditioned // Shallow 46 | python starter/mt_para_mtsac_modular_gated_cas.py --config meta_config/mt10/modular_2_2_2_256_reweight_rand.json --id MT10_Conditioned_Modular_Shallow --seed SEED --worker_nums 10 --eval_worker_nums 10 47 | 48 | 49 | # Modular Network // MT10-Fixed // Shallow 50 | python starter/mt_para_mtsac_modular_gated_cas.py --config meta_config/mt10/modular_2_2_2_256_reweight.json --id MT10_Fixed_Modular_Shallow --seed SEED --worker_nums 10 --eval_worker_nums 10 51 | 52 | 53 | # Modular Network // MT10-Conditioned // Deep 54 | python starter/mt_para_mtsac_modular_gated_cas.py --config meta_config/mt10/modular_4_4_2_128_reweight_rand.json --id MT10_Conditioned_Modular_Deep --seed SEED --worker_nums 10 --eval_worker_nums 10 55 | 56 | 57 | # Modular Network // MT10-Fixed // Deep 58 | python starter/mt_para_mtsac_modular_gated_cas.py --config meta_config/mt10/modular_4_4_2_128_reweight.json --id MT10_Fixed_Modular_Deep --seed SEED --worker_nums 10 --eval_worker_nums 10 59 | 60 | 61 | # Modular Network // MT50-Conditioned // Shallow 62 | python starter/mt_para_mtsac_modular_gated_cas.py --config meta_config/mt50/modular_2_2_2_256_reweight_rand.json --id MT50_Conditioned_Modular_Shallow --seed SEED --worker_nums 50 --eval_worker_nums 50 63 | 64 | 65 | # Modular Network // MT50-Fixed // Shallow 66 | python starter/mt_para_mtsac_modular_gated_cas.py --config meta_config/mt50/modular_2_2_2_256_reweight.json --id MT50_Fixed_Modular_Shallow --seed SEED --worker_nums 50 --eval_worker_nums 50 67 | 68 | 69 | # Modular Network // MT50-Conditioned // Deep 70 | python starter/mt_para_mtsac_modular_gated_cas.py --config meta_config/mt50/modular_4_4_2_128_reweight_rand.json --id MT50_Conditioned_Modular_Deep --seed SEED --worker_nums 50 --eval_worker_nums 50 71 | 72 | 73 | # Modular Network // MT50-Fixed // Deep 74 | python starter/mt_para_mtsac_modular_gated_cas.py --config meta_config/mt50/modular_4_4_2_128_reweight.json --id MT50_Fixed_Modular_Deep --seed SEED --worker_nums 50 --eval_worker_nums 50 75 | 76 | ``` 77 | 78 | ## Plot Training Curve 79 | 80 | To plot training curves, use the following command. 81 | 82 | * The "--id" argument accepts one or more experiment names. 83 | 84 | * The "--seed" argument accepts one or more seeds. 85 | 86 | * Replace "mean_success_rate" with a different entry to plot a different metric. 87 | 88 | ``` 89 | python torchrl/utils/plot_csv.py --id EXPERIMENTS --env_name mt10 --entry "mean_success_rate" --add_tag POSTFIX_FOR_OUTPUT_FILES --seed SEEDS 90 | ``` 91 | 92 | 93 | ## Citation 94 | 95 | If you find our work useful, please cite it.
96 | 97 | ``` 98 | @misc{yang2020multitask, 99 | title={Multi-Task Reinforcement Learning with Soft Modularization}, 100 | author={Ruihan Yang and Huazhe Xu and Yi Wu and Xiaolong Wang}, 101 | year={2020}, 102 | eprint={2003.13661}, 103 | archivePrefix={arXiv}, 104 | primaryClass={cs.LG} 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /fig/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RchalYang/Soft-Module/e6d7c8ad362a6b950632236322356da742dc838c/fig/method.png -------------------------------------------------------------------------------- /log/directory_place_holder: -------------------------------------------------------------------------------- 1 | use empty file to add log into git -------------------------------------------------------------------------------- /meta_config/mt10/modular_2_2_2_256_rand.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt10", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal", 9 | "random_init": true 10 | }, 11 | "replay_buffer":{ 12 | "size": 1e6 13 | }, 14 | "net":{ 15 | "hidden_shapes": [400, 400], 16 | "em_hidden_shapes": [400], 17 | "num_layers": 2, 18 | "num_modules": 2, 19 | "module_hidden": 256, 20 | "num_gating_layers": 2, 21 | "gating_hidden": 256, 22 | "add_bn": false, 23 | "pre_softmax": false 24 | }, 25 | "general_setting": { 26 | "discount" : 0.99, 27 | "pretrain_epochs" : 20, 28 | "num_epochs" : 7500, 29 | "epoch_frames" : 200, 30 | "max_episode_frames" : 200, 31 | 32 | "batch_size" : 1280, 33 | "min_pool" : 10000, 34 | 35 | "target_hard_update_period" : 1000, 36 | "use_soft_update" : true, 37 | "tau" : 0.005, 38 | "opt_times" : 200, 39 | 40 | "eval_episodes" : 3 41 | }, 42 | "sac":{ 43 | 44 | "plr" : 3e-4, 45 | "qlr" : 3e-4, 46 | 47 | "reparameterization": true, 48 | "automatic_entropy_tuning": true, 49 | "policy_std_reg_weight": 0, 50 | "policy_mean_reg_weight": 0 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /meta_config/mt10/modular_2_2_2_256_reweight.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt10", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal" 9 | }, 10 | "replay_buffer":{ 11 | "size": 1e6 12 | }, 13 | "net":{ 14 | "hidden_shapes": [400, 400], 15 | "em_hidden_shapes": [400], 16 | "num_layers": 2, 17 | "num_modules": 2, 18 | "module_hidden": 256, 19 | "num_gating_layers": 2, 20 | "gating_hidden": 256, 21 | "add_bn": false, 22 | "pre_softmax": false 23 | }, 24 | "general_setting": { 25 | "discount" : 0.99, 26 | "pretrain_epochs" : 20, 27 | "num_epochs" : 7500, 28 | "epoch_frames" : 200, 29 | "max_episode_frames" : 200, 30 | 31 | "batch_size" : 1280, 32 | "min_pool" : 10000, 33 | 34 | "target_hard_update_period" : 1000, 35 | "use_soft_update" : true, 36 | "tau" : 0.005, 37 | "opt_times" : 200, 38 | 39 | "eval_episodes" : 3 40 | }, 41 | "sac":{ 42 | 43 | "plr" : 3e-4, 44 | "qlr" : 3e-4, 45 | 46 | "reparameterization": true, 47 | "automatic_entropy_tuning": true, 48 | "temp_reweight": true, 49 | "policy_std_reg_weight": 0, 50 | "policy_mean_reg_weight": 0 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /meta_config/mt10/modular_2_2_2_256_reweight_rand.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt10", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal", 9 | "random_init": true 10 | }, 11 | "replay_buffer":{ 12 | "size": 1e6 13 | }, 14 | "net":{ 15 | "hidden_shapes": [400, 400], 16 | "em_hidden_shapes": [400], 17 | "num_layers": 2, 18 | "num_modules": 2, 19 | "module_hidden": 256, 20 | "num_gating_layers": 2, 21 | "gating_hidden": 256, 22 | "add_bn": false, 23 | "pre_softmax": false 24 | }, 25 | "general_setting": { 26 | "discount" : 0.99, 27 | "pretrain_epochs" : 20, 28 | "num_epochs" : 7500, 29 | "epoch_frames" : 200, 30 | "max_episode_frames" : 200, 31 | 32 | "batch_size" : 1280, 33 | "min_pool" : 10000, 34 | 35 | "target_hard_update_period" : 1000, 36 | "use_soft_update" : true, 37 | "tau" : 0.005, 38 | "opt_times" : 200, 39 | 40 | "eval_episodes" : 3 41 | }, 42 | "sac":{ 43 | 44 | "plr" : 3e-4, 45 | "qlr" : 3e-4, 46 | 47 | "reparameterization": true, 48 | "automatic_entropy_tuning": true, 49 | "temp_reweight": true, 50 | "policy_std_reg_weight": 0, 51 | "policy_mean_reg_weight": 0 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /meta_config/mt10/modular_4_4_2_128_reweight.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt10", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal" 9 | }, 10 | "replay_buffer":{ 11 | "size": 1e6 12 | }, 13 | "net":{ 14 | "hidden_shapes": [400, 400], 15 | "em_hidden_shapes": [400], 16 | "num_layers": 4, 17 | "num_modules": 4, 18 | "module_hidden": 128, 19 | "num_gating_layers": 2, 20 | "gating_hidden": 256, 21 | "add_bn": false, 22 | "pre_softmax": false 23 | }, 24 | "general_setting": { 25 | "discount" : 0.99, 26 | "pretrain_epochs" : 20, 27 | "num_epochs" : 7500, 28 | "epoch_frames" : 200, 29 | "max_episode_frames" : 200, 30 | 31 | "batch_size" : 1280, 32 | "min_pool" : 10000, 33 | 34 | "target_hard_update_period" : 1000, 35 | "use_soft_update" : true, 36 | "tau" : 0.005, 37 | "opt_times" : 200, 38 | 39 | "eval_episodes" : 3 40 | }, 41 | "sac":{ 42 | 43 | "plr" : 3e-4, 44 | "qlr" : 3e-4, 45 | 46 | "reparameterization": true, 47 | "automatic_entropy_tuning": true, 48 | "temp_reweight": true, 49 | "policy_std_reg_weight": 0, 50 | "policy_mean_reg_weight": 0 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /meta_config/mt10/modular_4_4_2_128_reweight_rand.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt10", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal", 9 | "random_init": true 10 | }, 11 | "replay_buffer":{ 12 | "size": 1e6 13 | }, 14 | "net":{ 15 | "hidden_shapes": [400, 400], 16 | "em_hidden_shapes": [400], 17 | "num_layers": 4, 18 | "num_modules": 4, 19 | "module_hidden": 128, 20 | "num_gating_layers": 2, 21 | "gating_hidden": 256, 22 | "add_bn": false, 23 | "pre_softmax": false 24 | }, 25 | "general_setting": { 26 | "discount" : 0.99, 27 | "pretrain_epochs" : 20, 28 | "num_epochs" : 7500, 29 | "epoch_frames" : 200, 30 | "max_episode_frames" : 200, 31 | 32 | "batch_size" : 1280, 33 | "min_pool" : 10000, 34 | 35 | "target_hard_update_period" : 1000, 36 | "use_soft_update" : true, 37 | "tau" : 0.005, 38 | "opt_times" : 200, 39 | 40 | "eval_episodes" : 
3 41 | }, 42 | "sac":{ 43 | 44 | "plr" : 3e-4, 45 | "qlr" : 3e-4, 46 | 47 | "reparameterization": true, 48 | "automatic_entropy_tuning": true, 49 | "temp_reweight": true, 50 | "policy_std_reg_weight": 0, 51 | "policy_mean_reg_weight": 0 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /meta_config/mt10/mtmhsac.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt10", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal_and_id" 9 | }, 10 | "replay_buffer":{ 11 | "size": 1e6 12 | }, 13 | "net":{ 14 | "hidden_shapes": [400,400], 15 | "append_hidden_shapes":[400] 16 | }, 17 | "general_setting": { 18 | "discount" : 0.99, 19 | "pretrain_epochs" : 20, 20 | "num_epochs" : 7500, 21 | "epoch_frames" : 200, 22 | "max_episode_frames" : 200, 23 | 24 | "batch_size" : 1280, 25 | "min_pool" : 10000, 26 | 27 | "target_hard_update_period" : 1000, 28 | "use_soft_update" : true, 29 | "tau" : 0.005, 30 | "opt_times" : 200, 31 | 32 | "eval_episodes" : 3 33 | }, 34 | "sac":{ 35 | 36 | "plr" : 3e-4, 37 | "qlr" : 3e-4, 38 | 39 | "reparameterization": true, 40 | "automatic_entropy_tuning": true, 41 | "policy_std_reg_weight": 0, 42 | "policy_mean_reg_weight": 0 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /meta_config/mt10/mtmhsac_rand.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt10", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal_and_id", 9 | "random_init": true 10 | }, 11 | "replay_buffer":{ 12 | "size": 1e6 13 | }, 14 | "net":{ 15 | "hidden_shapes": [400,400], 16 | "append_hidden_shapes":[400] 17 | }, 18 | "general_setting": { 19 | "discount" : 0.99, 20 | "pretrain_epochs" : 20, 21 | "num_epochs" : 7500, 22 | "epoch_frames" : 200, 23 | "max_episode_frames" : 200, 24 | 25 | "batch_size" : 1280, 26 | "min_pool" : 10000, 27 | 28 | "target_hard_update_period" : 1000, 29 | "use_soft_update" : true, 30 | "tau" : 0.005, 31 | "opt_times" : 200, 32 | 33 | "eval_episodes" : 3 34 | }, 35 | "sac":{ 36 | 37 | "plr" : 3e-4, 38 | "qlr" : 3e-4, 39 | 40 | "reparameterization": true, 41 | "automatic_entropy_tuning": true, 42 | "policy_std_reg_weight": 0, 43 | "policy_mean_reg_weight": 0 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /meta_config/mt10/mtsac.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt10", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal_and_id" 9 | }, 10 | "replay_buffer":{ 11 | "size": 1e6 12 | }, 13 | "net":{ 14 | "hidden_shapes": [400,400,400], 15 | "append_hidden_shapes":[] 16 | }, 17 | "general_setting": { 18 | "discount" : 0.99, 19 | "pretrain_epochs" : 20, 20 | "num_epochs" : 7500, 21 | "epoch_frames" : 200, 22 | "max_episode_frames" : 200, 23 | 24 | "batch_size" : 1280, 25 | "min_pool" : 10000, 26 | 27 | "target_hard_update_period" : 1000, 28 | "use_soft_update" : true, 29 | "tau" : 0.005, 30 | "opt_times" : 200, 31 | 32 | "eval_episodes" : 3 33 | }, 34 | "sac":{ 35 | 36 | "plr" : 3e-4, 37 | "qlr" : 3e-4, 38 | 39 | "reparameterization": true, 40 | "automatic_entropy_tuning": true, 41 | "policy_std_reg_weight": 0, 42 | "policy_mean_reg_weight": 0 43 | } 44 | } 45 | 
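All of these JSON files share one schema: ```env```/```meta_env``` blocks for the environment, a ```net``` block for the architecture, a ```general_setting``` block for the training loop, and a ```sac``` block for the algorithm. The snippet below is a minimal sketch of how such a file can be read; ```load_config``` is a stand-in for the starter scripts' ```get_params``` helper, and the path assumes the repository root as the working directory.

```
# Minimal sketch of consuming a meta_config JSON (stand-in for the repo's get_params helper).
import json

def load_config(path):
    with open(path) as f:
        return json.load(f)

params = load_config("meta_config/mt10/mtsac.json")

# Numbers written as 1e6 in JSON are parsed as floats, so the starter scripts
# cast sizes and batch sizes to int before use.
buffer_size = int(params["replay_buffer"]["size"])
batch_size = int(params["general_setting"]["batch_size"])
print(params["env_name"], buffer_size, batch_size)   # mt10 1000000 1280

# The sub-dicts are then unpacked into the agent, roughly as:
#   MTSAC(..., **params["sac"], **params["general_setting"])
```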
-------------------------------------------------------------------------------- /meta_config/mt10/mtsac_rand.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt10", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal_and_id", 9 | "random_init": true 10 | }, 11 | "replay_buffer":{ 12 | "size": 1e6 13 | }, 14 | "net":{ 15 | "hidden_shapes": [400,400,400], 16 | "append_hidden_shapes":[] 17 | }, 18 | "general_setting": { 19 | "discount" : 0.99, 20 | "pretrain_epochs" : 20, 21 | "num_epochs" : 7500, 22 | "epoch_frames" : 200, 23 | "max_episode_frames" : 200, 24 | 25 | "batch_size" : 1280, 26 | "min_pool" : 10000, 27 | 28 | "target_hard_update_period" : 1000, 29 | "use_soft_update" : true, 30 | "tau" : 0.005, 31 | "opt_times" : 200, 32 | 33 | "eval_episodes" : 3 34 | }, 35 | "sac":{ 36 | 37 | "plr" : 3e-4, 38 | "qlr" : 3e-4, 39 | 40 | "reparameterization": true, 41 | "automatic_entropy_tuning": true, 42 | "policy_std_reg_weight": 0, 43 | "policy_mean_reg_weight": 0 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /meta_config/mt50/modular_2_2_2_256_reweight.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt50", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal" 9 | }, 10 | "replay_buffer":{ 11 | "size": 1e7 12 | }, 13 | "net":{ 14 | "hidden_shapes": [400, 400], 15 | "em_hidden_shapes": [400], 16 | "num_layers": 2, 17 | "num_modules": 2, 18 | "module_hidden": 256, 19 | "num_gating_layers": 2, 20 | "gating_hidden": 256, 21 | "add_bn": false, 22 | "pre_softmax": false 23 | }, 24 | "general_setting": { 25 | "discount" : 0.99, 26 | "pretrain_epochs" : 20, 27 | "num_epochs" : 5000, 28 | "epoch_frames" : 200, 29 | "max_episode_frames" : 1000, 30 | 31 | "batch_size" : 6400, 32 | "min_pool" : 10000, 33 | 34 | "target_hard_update_period" : 1000, 35 | "use_soft_update" : true, 36 | "tau" : 0.005, 37 | "opt_times" : 200, 38 | 39 | "eval_episodes" : 3 40 | }, 41 | "sac":{ 42 | 43 | "plr" : 3e-4, 44 | "qlr" : 3e-4, 45 | 46 | "reparameterization": true, 47 | "automatic_entropy_tuning": true, 48 | "temp_reweight": true, 49 | "policy_std_reg_weight": 0, 50 | "policy_mean_reg_weight": 0 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /meta_config/mt50/modular_2_2_2_256_reweight_rand.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt50", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal", 9 | "random_init": true 10 | }, 11 | "replay_buffer":{ 12 | "size": 1e7 13 | }, 14 | "net":{ 15 | "hidden_shapes": [400, 400], 16 | "em_hidden_shapes": [400], 17 | "num_layers": 2, 18 | "num_modules": 2, 19 | "module_hidden": 256, 20 | "num_gating_layers": 2, 21 | "gating_hidden": 256, 22 | "add_bn": false, 23 | "pre_softmax": false 24 | }, 25 | "general_setting": { 26 | "discount" : 0.99, 27 | "pretrain_epochs" : 20, 28 | "num_epochs" : 5000, 29 | "epoch_frames" : 200, 30 | "max_episode_frames" : 1000, 31 | 32 | "batch_size" : 6400, 33 | "min_pool" : 10000, 34 | 35 | "target_hard_update_period" : 1000, 36 | "use_soft_update" : true, 37 | "tau" : 0.005, 38 | "opt_times" : 200, 39 | 40 | "eval_episodes" : 3 41 | }, 42 | "sac":{ 43 | 44 | "plr" : 3e-4, 45 | "qlr" : 3e-4, 46 | 47 | 
"reparameterization": true, 48 | "automatic_entropy_tuning": true, 49 | "temp_reweight": true, 50 | "policy_std_reg_weight": 0, 51 | "policy_mean_reg_weight": 0 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /meta_config/mt50/modular_4_4_2_128_reweight.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt50", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal" 9 | }, 10 | "replay_buffer":{ 11 | "size": 1e7 12 | }, 13 | "net":{ 14 | "hidden_shapes": [400, 400], 15 | "em_hidden_shapes": [400], 16 | "num_layers": 4, 17 | "num_modules": 4, 18 | "module_hidden": 128, 19 | "num_gating_layers": 2, 20 | "gating_hidden": 256, 21 | "add_bn": false, 22 | "pre_softmax": false 23 | }, 24 | "general_setting": { 25 | "discount" : 0.99, 26 | "pretrain_epochs" : 20, 27 | "num_epochs" : 5000, 28 | "epoch_frames" : 200, 29 | "max_episode_frames" : 1000, 30 | 31 | "batch_size" : 6400, 32 | "min_pool" : 10000, 33 | 34 | "target_hard_update_period" : 1000, 35 | "use_soft_update" : true, 36 | "tau" : 0.005, 37 | "opt_times" : 200, 38 | 39 | "eval_episodes" : 3 40 | }, 41 | "sac":{ 42 | 43 | "plr" : 3e-4, 44 | "qlr" : 3e-4, 45 | 46 | "reparameterization": true, 47 | "automatic_entropy_tuning": true, 48 | "temp_reweight": true, 49 | "policy_std_reg_weight": 0, 50 | "policy_mean_reg_weight": 0 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /meta_config/mt50/modular_4_4_2_128_reweight_rand.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt50", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal", 9 | "random_init": true 10 | }, 11 | "replay_buffer":{ 12 | "size": 1e7 13 | }, 14 | "net":{ 15 | "hidden_shapes": [400, 400], 16 | "em_hidden_shapes": [400], 17 | "num_layers": 4, 18 | "num_modules": 4, 19 | "module_hidden": 128, 20 | "num_gating_layers": 2, 21 | "gating_hidden": 256, 22 | "add_bn": false, 23 | "pre_softmax": false 24 | }, 25 | "general_setting": { 26 | "discount" : 0.99, 27 | "pretrain_epochs" : 20, 28 | "num_epochs" : 5000, 29 | "epoch_frames" : 200, 30 | "max_episode_frames" : 1000, 31 | 32 | "batch_size" : 6400, 33 | "min_pool" : 10000, 34 | 35 | "target_hard_update_period" : 1000, 36 | "use_soft_update" : true, 37 | "tau" : 0.005, 38 | "opt_times" : 200, 39 | 40 | "eval_episodes" : 3 41 | }, 42 | "sac":{ 43 | 44 | "plr" : 3e-4, 45 | "qlr" : 3e-4, 46 | 47 | "reparameterization": true, 48 | "automatic_entropy_tuning": true, 49 | "temp_reweight": true, 50 | "policy_std_reg_weight": 0, 51 | "policy_mean_reg_weight": 0 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /meta_config/mt50/mtmhsac.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt50", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal_and_id" 9 | }, 10 | "replay_buffer":{ 11 | "size": 1e7 12 | }, 13 | "net":{ 14 | "hidden_shapes": [400,400], 15 | "append_hidden_shapes":[400] 16 | }, 17 | "general_setting": { 18 | "discount" : 0.99, 19 | "pretrain_epochs" : 20, 20 | "num_epochs" : 5000, 21 | "epoch_frames" : 200, 22 | "max_episode_frames" : 1000, 23 | 24 | "batch_size" : 6400, 25 | "min_pool" : 10000, 26 | 27 | "target_hard_update_period" : 1000, 
28 | "use_soft_update" : true, 29 | "tau" : 0.005, 30 | "opt_times" : 200, 31 | 32 | "eval_episodes" : 3 33 | }, 34 | "sac":{ 35 | 36 | "plr" : 3e-4, 37 | "qlr" : 3e-4, 38 | 39 | "reparameterization": true, 40 | "automatic_entropy_tuning": true, 41 | "policy_std_reg_weight": 0, 42 | "policy_mean_reg_weight": 0 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /meta_config/mt50/mtmhsac_rand.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt50", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal_and_id", 9 | "random_init": true 10 | }, 11 | "replay_buffer":{ 12 | "size": 1e7 13 | }, 14 | "net":{ 15 | "hidden_shapes": [400,400], 16 | "append_hidden_shapes":[400] 17 | }, 18 | "general_setting": { 19 | "discount" : 0.99, 20 | "pretrain_epochs" : 20, 21 | "num_epochs" : 5000, 22 | "epoch_frames" : 200, 23 | "max_episode_frames" : 1000, 24 | 25 | "batch_size" : 6400, 26 | "min_pool" : 10000, 27 | 28 | "target_hard_update_period" : 1000, 29 | "use_soft_update" : true, 30 | "tau" : 0.005, 31 | "opt_times" : 200, 32 | 33 | "eval_episodes" : 3 34 | }, 35 | "sac":{ 36 | 37 | "plr" : 3e-4, 38 | "qlr" : 3e-4, 39 | 40 | "reparameterization": true, 41 | "automatic_entropy_tuning": true, 42 | "policy_std_reg_weight": 0, 43 | "policy_mean_reg_weight": 0 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /meta_config/mt50/mtsac.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt50", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal_and_id" 9 | }, 10 | "replay_buffer":{ 11 | "size": 1e7 12 | }, 13 | "net":{ 14 | "hidden_shapes": [400,400,400], 15 | "append_hidden_shapes":[] 16 | }, 17 | "general_setting": { 18 | "discount" : 0.99, 19 | "pretrain_epochs" : 20, 20 | "num_epochs" : 5000, 21 | "epoch_frames" : 200, 22 | "max_episode_frames" : 1000, 23 | 24 | "batch_size" : 6400, 25 | "min_pool" : 10000, 26 | 27 | "target_hard_update_period" : 1000, 28 | "use_soft_update" : true, 29 | "tau" : 0.005, 30 | "opt_times" : 200, 31 | 32 | "eval_episodes" : 3 33 | }, 34 | "sac":{ 35 | 36 | "plr" : 3e-4, 37 | "qlr" : 3e-4, 38 | 39 | "reparameterization": true, 40 | "automatic_entropy_tuning": true, 41 | "policy_std_reg_weight": 0, 42 | "policy_mean_reg_weight": 0 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /meta_config/mt50/mtsac_rand.json: -------------------------------------------------------------------------------- 1 | { 2 | "env_name" : "mt50", 3 | "env":{ 4 | "reward_scale":1, 5 | "obs_norm":false 6 | }, 7 | "meta_env":{ 8 | "obs_type": "with_goal_and_id", 9 | "random_init": true 10 | }, 11 | "replay_buffer":{ 12 | "size": 1e7 13 | }, 14 | "net":{ 15 | "hidden_shapes": [400,400,400], 16 | "append_hidden_shapes":[] 17 | }, 18 | "general_setting": { 19 | "discount" : 0.99, 20 | "pretrain_epochs" : 20, 21 | "num_epochs" : 5000, 22 | "epoch_frames" : 200, 23 | "max_episode_frames" : 1000, 24 | 25 | "batch_size" : 6400, 26 | "min_pool" : 10000, 27 | 28 | "target_hard_update_period" : 1000, 29 | "use_soft_update" : true, 30 | "tau" : 0.005, 31 | "opt_times" : 200, 32 | 33 | "eval_episodes" : 3 34 | }, 35 | "sac":{ 36 | 37 | "plr" : 3e-4, 38 | "qlr" : 3e-4, 39 | 40 | "reparameterization": true, 41 | "automatic_entropy_tuning": true, 42 | 
"policy_std_reg_weight": 0, 43 | "policy_mean_reg_weight": 0 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /metaworld_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .meta_env import * -------------------------------------------------------------------------------- /metaworld_utils/meta_env.py: -------------------------------------------------------------------------------- 1 | 2 | import gym 3 | from gym import Wrapper 4 | from gym.spaces import Box 5 | import numpy as np 6 | from metaworld.envs.mujoco.sawyer_xyz import * 7 | from metaworld.envs.mujoco.multitask_env import MultiClassMultiTaskEnv 8 | from metaworld.core.serializable import Serializable 9 | import sys 10 | sys.path.append("../..") 11 | from torchrl.env.continuous_wrapper import * 12 | from torchrl.env.get_env import wrap_continuous_env 13 | 14 | 15 | class SingleWrapper(Wrapper): 16 | def __init__(self, env): 17 | self._env = env 18 | self.action_space = env.action_space 19 | self.observation_space = env.observation_space 20 | 21 | def reset(self): 22 | return self._env.reset() 23 | 24 | def seed(self, se): 25 | self._env.seed(se) 26 | 27 | def reset_with_index(self, task_idx): 28 | return self._env.reset() 29 | 30 | def step(self, action): 31 | obs, reward, done, info = self._env.step(action) 32 | return obs, reward, done, info 33 | 34 | def render(self, mode='human', **kwargs): 35 | return self._env.render(mode=mode, **kwargs) 36 | 37 | def close(self): 38 | self._env.close() 39 | 40 | 41 | class MTEnv(MultiClassMultiTaskEnv): 42 | def __init__(self, 43 | task_env_cls_dict, 44 | task_args_kwargs, 45 | sample_all=True, 46 | sample_goals=False, 47 | obs_type='plain', 48 | repeat_times=1, 49 | ): 50 | Serializable.quick_init(self, locals()) 51 | super().__init__( 52 | task_env_cls_dict, 53 | task_args_kwargs, 54 | sample_all, 55 | sample_goals, 56 | obs_type) 57 | 58 | self.train_mode = True 59 | self.repeat_times = repeat_times 60 | 61 | def reset(self, **kwargs): 62 | if self.train_mode: 63 | sample_task = np.random.randint(0, self.num_tasks) 64 | self.set_task(sample_task) 65 | return super().reset(**kwargs) 66 | 67 | def reset_with_index(self, task_idx, **kwargs): 68 | self.set_task(task_idx) 69 | return super().reset(**kwargs) 70 | 71 | def train(self): 72 | self.train_mode = True 73 | 74 | def test(self): 75 | self.train_mode = False 76 | 77 | def render(self, mode='human'): 78 | return super().render(mode=mode) 79 | 80 | @property 81 | def observation_space(self): 82 | if self._obs_type == 'plain': 83 | return self._task_envs[self.observation_space_index].observation_space 84 | else: 85 | plain_high = self._task_envs[self.observation_space_index].observation_space.high 86 | plain_low = self._task_envs[self.observation_space_index].observation_space.low 87 | goal_high = self.active_env.goal_space.high 88 | goal_low = self.active_env.goal_space.low 89 | if self._obs_type == 'with_goal': 90 | return Box( 91 | high=np.concatenate([plain_high, goal_high] + [goal_high] * (self.repeat_times -1) ), 92 | low=np.concatenate([plain_low, goal_low] + [goal_low] * (self.repeat_times -1 ))) 93 | elif self._obs_type == 'with_goal_id' and self._fully_discretized: 94 | goal_id_low = np.zeros(shape=(self._n_discrete_goals * self.repeat_times,)) 95 | goal_id_high = np.ones(shape=(self._n_discrete_goals * self.repeat_times,)) 96 | return Box( 97 | high=np.concatenate([plain_high, goal_id_low,]), 98 | 
low=np.concatenate([plain_low, goal_id_high,])) 99 | elif self._obs_type == 'with_goal_and_id' and self._fully_discretized: 100 | goal_id_low = np.zeros(shape=(self._n_discrete_goals,)) 101 | goal_id_high = np.ones(shape=(self._n_discrete_goals,)) 102 | return Box( 103 | high=np.concatenate([plain_high, goal_id_low, goal_high] + [goal_id_low, goal_high] * (self.repeat_times - 1) ), 104 | low=np.concatenate([plain_low, goal_id_high, goal_low] + [goal_id_high, goal_low] * (self.repeat_times - 1) )) 105 | else: 106 | raise NotImplementedError 107 | 108 | def _augment_observation(self, obs): 109 | # optionally zero-pad observation 110 | if np.prod(obs.shape) < self._max_plain_dim: 111 | zeros = np.zeros( 112 | shape=(self._max_plain_dim - np.prod(obs.shape),) 113 | ) 114 | obs = np.concatenate([obs, zeros]) 115 | 116 | # augment the observation based on obs_type: 117 | if self._obs_type == 'with_goal_id' or self._obs_type == 'with_goal_and_id': 118 | 119 | aug_ob = [] 120 | if self._obs_type == 'with_goal_and_id': 121 | aug_ob.append(self.active_env._state_goal) 122 | # if self._obs_type == 'with_goal_and_id': 123 | # obs = np.concatenate([obs, self.active_env._state_goal]) 124 | task_id = self._env_discrete_index[self._task_names[self.active_task]] + (self.active_env.active_discrete_goal or 0) 125 | task_onehot = np.zeros(shape=(self._n_discrete_goals,), dtype=np.float32) 126 | task_onehot[task_id] = 1. 127 | aug_ob.append(task_onehot) 128 | 129 | obs = np.concatenate([obs] + aug_ob * self.repeat_times) 130 | 131 | elif self._obs_type == 'with_goal': 132 | obs = np.concatenate([obs] + [self.active_env._state_goal] * self.repeat_times ) 133 | return obs 134 | 135 | 136 | def generate_single_task_env(env_id, kwargs): 137 | env = globals()[env_id](**kwargs) 138 | env = SingleWrapper(env) 139 | return env 140 | 141 | 142 | def generate_mt_env(cls_dict, args_kwargs, **kwargs): 143 | copy_kwargs = kwargs.copy() 144 | if "random_init" in copy_kwargs: 145 | del copy_kwargs["random_init"] 146 | env = MTEnv( 147 | task_env_cls_dict=cls_dict, 148 | task_args_kwargs=args_kwargs, 149 | **copy_kwargs 150 | ) 151 | # Set to discretized since the env is actually not used 152 | env._sample_goals = False 153 | env._fully_discretized = True 154 | 155 | goals_dict = { 156 | t: [e.goal.copy()] 157 | for t, e in zip(env._task_names, env._task_envs) 158 | } 159 | env.discretize_goal_space(goals_dict) 160 | return env 161 | 162 | 163 | def generate_single_mt_env(task_cls, task_args, env_rank, num_tasks, 164 | max_obs_dim, env_params, meta_env_params): 165 | 166 | env = task_cls(*task_args['args'], **task_args["kwargs"]) 167 | env.discretize_goal_space(env.goal.copy()) 168 | if "sampled_index" in meta_env_params: 169 | del meta_env_params["sampled_index"] 170 | env = AugObs(env, env_rank, num_tasks, max_obs_dim, meta_env_params) 171 | env = wrap_continuous_env(env, **env_params) 172 | 173 | act_space = env.action_space 174 | if isinstance(act_space, gym.spaces.Box): 175 | env = NormAct(env) 176 | return env 177 | 178 | 179 | def generate_mt10_env(mt_param): 180 | from metaworld.envs.mujoco.env_dict import EASY_MODE_CLS_DICT, EASY_MODE_ARGS_KWARGS 181 | 182 | if "random_init" in mt_param: 183 | for key in EASY_MODE_ARGS_KWARGS: 184 | EASY_MODE_ARGS_KWARGS[key]["kwargs"]["random_init"]=True 185 | 186 | return generate_mt_env(EASY_MODE_CLS_DICT, EASY_MODE_ARGS_KWARGS, **mt_param), \ 187 | EASY_MODE_CLS_DICT, EASY_MODE_ARGS_KWARGS 188 | 189 | 190 | def generate_mt50_env(mt_param): 191 | from 
metaworld.envs.mujoco.env_dict import HARD_MODE_CLS_DICT, HARD_MODE_ARGS_KWARGS 192 | cls_dict = {} 193 | args_kwargs = {} 194 | for k in HARD_MODE_CLS_DICT.keys(): 195 | for task in HARD_MODE_CLS_DICT[k].keys(): 196 | cls_dict[task] = HARD_MODE_CLS_DICT[k][task] 197 | args_kwargs[task] = HARD_MODE_ARGS_KWARGS[k][task] 198 | 199 | if "random_init" in mt_param: 200 | for key in args_kwargs: 201 | args_kwargs[key]["kwargs"]["random_init"]=mt_param["random_init"] 202 | 203 | return generate_mt_env(cls_dict, args_kwargs, **mt_param), \ 204 | cls_dict, args_kwargs 205 | 206 | 207 | def get_meta_env(env_id, env_param, mt_param, return_dicts=True): 208 | cls_dicts = None 209 | args_kwargs = None 210 | if env_id == "mt10": 211 | env, cls_dicts, args_kwargs = generate_mt10_env(mt_param) 212 | elif env_id == "mt50": 213 | env, cls_dicts, args_kwargs = generate_mt50_env(mt_param) 214 | else: 215 | env = generate_single_task_env(env_id, mt_param) 216 | 217 | env = wrap_continuous_env(env, **env_param) 218 | 219 | act_space = env.action_space 220 | if isinstance(act_space, gym.spaces.Box): 221 | env = NormAct(env) 222 | if env_id == "mt10" or env_id == "mt50": 223 | env.num_tasks = len(cls_dicts) 224 | else: 225 | env.num_tasks = 1 226 | 227 | if cls_dicts is not None and return_dicts is True: 228 | return env, cls_dicts, args_kwargs 229 | else: 230 | return env 231 | -------------------------------------------------------------------------------- /starter/mt_para_mhmt_sac.py: -------------------------------------------------------------------------------- 1 | import sys 2 | # import sys 3 | sys.path.append(".") 4 | 5 | import torch 6 | 7 | import os 8 | import time 9 | import os.path as osp 10 | 11 | import numpy as np 12 | 13 | from torchrl.utils import get_args 14 | from torchrl.utils import get_params 15 | from torchrl.env import get_env 16 | 17 | from torchrl.utils import Logger 18 | 19 | args = get_args() 20 | params = get_params(args.config) 21 | 22 | import torchrl.policies as policies 23 | import torchrl.networks as networks 24 | from torchrl.algo import SAC 25 | from torchrl.algo import TwinSAC 26 | from torchrl.algo import TwinSACQ 27 | from torchrl.algo import MTSAC 28 | from torchrl.algo import MTMHSAC 29 | from torchrl.collector.para import ParallelCollector 30 | from torchrl.collector.para import AsyncParallelCollector 31 | from torchrl.collector.para.mt import SingleTaskParallelCollectorBase 32 | from torchrl.collector.para.async_mt import AsyncSingleTaskParallelCollector 33 | from torchrl.collector.para.async_mt import AsyncMultiTaskParallelCollectorUniform 34 | 35 | from torchrl.replay_buffers.shared import SharedBaseReplayBuffer 36 | from torchrl.replay_buffers.shared import AsyncSharedReplayBuffer 37 | import gym 38 | 39 | from metaworld_utils.meta_env import get_meta_env 40 | 41 | def experiment(args): 42 | 43 | device = torch.device("cuda:{}".format(args.device) if args.cuda else "cpu") 44 | 45 | env, cls_dicts, cls_args = get_meta_env( params['env_name'], params['env'], params['meta_env']) 46 | 47 | env.seed(args.seed) 48 | torch.manual_seed(args.seed) 49 | np.random.seed(args.seed) 50 | if args.cuda: 51 | torch.backends.cudnn.deterministic=True 52 | 53 | buffer_param = params['replay_buffer'] 54 | 55 | experiment_name = os.path.split( os.path.splitext( args.config )[0] )[-1] if args.id is None \ 56 | else args.id 57 | logger = Logger( experiment_name , params['env_name'], args.seed, params, args.log_dir ) 58 | 59 | params['general_setting']['env'] = env 60 | 
params['general_setting']['logger'] = logger 61 | params['general_setting']['device'] = device 62 | 63 | params['net']['base_type']=networks.MLPBase 64 | 65 | import torch.multiprocessing as mp 66 | mp.set_start_method('spawn', force=True) 67 | 68 | pf = policies.MultiHeadGuassianContPolicy ( 69 | input_shape = env.observation_space.shape[0], 70 | output_shape = 2 * env.action_space.shape[0], 71 | head_num=env.num_tasks, 72 | **params['net'] ) 73 | qf1 = networks.FlattenBootstrappedNet( 74 | input_shape = env.observation_space.shape[0] + env.action_space.shape[0], 75 | output_shape = 1, 76 | head_num=env.num_tasks, 77 | **params['net'] ) 78 | qf2 = networks.FlattenBootstrappedNet( 79 | input_shape = env.observation_space.shape[0] + env.action_space.shape[0], 80 | output_shape = 1, 81 | head_num=env.num_tasks, 82 | **params['net'] ) 83 | 84 | example_ob = env.reset() 85 | example_dict = { 86 | "obs": example_ob, 87 | "next_obs": example_ob, 88 | "acts": env.action_space.sample(), 89 | "rewards": [0], 90 | "terminals": [False], 91 | "task_idxs": [0] 92 | } 93 | replay_buffer = AsyncSharedReplayBuffer( int(buffer_param['size']), 94 | args.worker_nums 95 | ) 96 | replay_buffer.build_by_example(example_dict) 97 | 98 | params['general_setting']['replay_buffer'] = replay_buffer 99 | 100 | epochs = params['general_setting']['pretrain_epochs'] + \ 101 | params['general_setting']['num_epochs'] 102 | 103 | params['general_setting']['collector'] = AsyncMultiTaskParallelCollectorUniform( 104 | env=env, pf=pf, replay_buffer=replay_buffer, 105 | env_cls = cls_dicts, env_args = [params["env"], cls_args, params["meta_env"]], 106 | device=device, 107 | reset_idx=True, 108 | epoch_frames=params['general_setting']['epoch_frames'], 109 | max_episode_frames=params['general_setting']['max_episode_frames'], 110 | eval_episodes = params['general_setting']['eval_episodes'], 111 | worker_nums=args.worker_nums, eval_worker_nums=args.eval_worker_nums, 112 | train_epochs = epochs, eval_epochs= params['general_setting']['num_epochs'] 113 | ) 114 | params['general_setting']['batch_size'] = int(params['general_setting']['batch_size']) 115 | params['general_setting']['save_dir'] = osp.join(logger.work_dir,"model") 116 | # agent = MTMHSAC( 117 | agent = MTSAC( 118 | pf = pf, 119 | qf1 = qf1, 120 | qf2 = qf2, 121 | task_nums=env.num_tasks, 122 | **params['sac'], 123 | **params['general_setting'] 124 | ) 125 | agent.train() 126 | 127 | if __name__ == "__main__": 128 | experiment(args) 129 | -------------------------------------------------------------------------------- /starter/mt_para_mtsac.py: -------------------------------------------------------------------------------- 1 | import sys 2 | # import sys 3 | sys.path.append(".") 4 | 5 | import torch 6 | 7 | import os 8 | import time 9 | import os.path as osp 10 | 11 | import numpy as np 12 | 13 | from torchrl.utils import get_args 14 | from torchrl.utils import get_params 15 | from torchrl.env import get_env 16 | 17 | from torchrl.utils import Logger 18 | 19 | args = get_args() 20 | params = get_params(args.config) 21 | 22 | import torchrl.policies as policies 23 | import torchrl.networks as networks 24 | from torchrl.algo import SAC 25 | from torchrl.algo import TwinSAC 26 | from torchrl.algo import TwinSACQ 27 | from torchrl.algo import MTSAC 28 | from torchrl.collector.para import ParallelCollector 29 | from torchrl.collector.para import AsyncParallelCollector 30 | from torchrl.collector.para.mt import SingleTaskParallelCollectorBase 31 | from 
torchrl.collector.para.async_mt import AsyncSingleTaskParallelCollector 32 | from torchrl.collector.para.async_mt import AsyncMultiTaskParallelCollectorUniform 33 | 34 | from torchrl.replay_buffers.shared import SharedBaseReplayBuffer 35 | from torchrl.replay_buffers.shared import AsyncSharedReplayBuffer 36 | import gym 37 | 38 | from metaworld_utils.meta_env import get_meta_env 39 | 40 | def experiment(args): 41 | 42 | device = torch.device("cuda:{}".format(args.device) if args.cuda else "cpu") 43 | 44 | env, cls_dicts, cls_args = get_meta_env( params['env_name'], params['env'], params['meta_env']) 45 | 46 | env.seed(args.seed) 47 | torch.manual_seed(args.seed) 48 | np.random.seed(args.seed) 49 | if args.cuda: 50 | torch.backends.cudnn.deterministic=True 51 | 52 | buffer_param = params['replay_buffer'] 53 | 54 | experiment_name = os.path.split( os.path.splitext( args.config )[0] )[-1] if args.id is None \ 55 | else args.id 56 | logger = Logger( experiment_name , params['env_name'], args.seed, params, args.log_dir ) 57 | 58 | params['general_setting']['env'] = env 59 | params['general_setting']['logger'] = logger 60 | params['general_setting']['device'] = device 61 | 62 | params['net']['base_type']=networks.MLPBase 63 | 64 | import torch.multiprocessing as mp 65 | mp.set_start_method('spawn', force=True) 66 | 67 | from torchrl.networks.init import normal_init 68 | 69 | pf = policies.GuassianContPolicy( 70 | input_shape = env.observation_space.shape[0], 71 | output_shape = 2 * env.action_space.shape[0], 72 | **params['net'] ) 73 | qf1 = networks.FlattenNet( 74 | input_shape = env.observation_space.shape[0] + env.action_space.shape[0], 75 | output_shape = 1, 76 | **params['net'] ) 77 | qf2 = networks.FlattenNet( 78 | input_shape = env.observation_space.shape[0] + env.action_space.shape[0], 79 | output_shape = 1, 80 | **params['net'] ) 81 | 82 | example_ob = env.reset() 83 | example_dict = { 84 | "obs": example_ob, 85 | "next_obs": example_ob, 86 | "acts": env.action_space.sample(), 87 | "rewards": [0], 88 | "terminals": [False], 89 | "task_idxs": [0] 90 | } 91 | replay_buffer = AsyncSharedReplayBuffer( int(buffer_param['size']), 92 | args.worker_nums 93 | ) 94 | replay_buffer.build_by_example(example_dict) 95 | 96 | params['general_setting']['replay_buffer'] = replay_buffer 97 | 98 | epochs = params['general_setting']['pretrain_epochs'] + \ 99 | params['general_setting']['num_epochs'] 100 | 101 | print(env.action_space) 102 | print(env.observation_space) 103 | params['general_setting']['collector'] = AsyncMultiTaskParallelCollectorUniform( 104 | env=env, pf=pf, replay_buffer=replay_buffer, 105 | env_cls = cls_dicts, env_args = [params["env"], cls_args, params["meta_env"]], 106 | device=device, 107 | reset_idx=True, 108 | epoch_frames=params['general_setting']['epoch_frames'], 109 | max_episode_frames=params['general_setting']['max_episode_frames'], 110 | eval_episodes = params['general_setting']['eval_episodes'], 111 | worker_nums=args.worker_nums, eval_worker_nums=args.eval_worker_nums, 112 | train_epochs = epochs, eval_epochs= params['general_setting']['num_epochs'] 113 | ) 114 | params['general_setting']['batch_size'] = int(params['general_setting']['batch_size']) 115 | params['general_setting']['save_dir'] = osp.join(logger.work_dir,"model") 116 | agent = MTSAC( 117 | pf = pf, 118 | qf1 = qf1, 119 | qf2 = qf2, 120 | task_nums=env.num_tasks, 121 | **params['sac'], 122 | **params['general_setting'] 123 | ) 124 | agent.train() 125 | 126 | if __name__ == "__main__": 127 | experiment(args) 
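The MTSAC agent built above keeps one entropy temperature per task, and the ```*_reweight``` configs additionally enable ```temp_reweight```, which scales each task's actor and critic losses by a coefficient derived from those temperatures (see ```torchrl/algo/off_policy/mt_sac.py``` further down). Below is a small numeric sketch of that coefficient with simplified shapes; the repo broadcasts it over a per-task batch dimension instead.

```
# Sketch of the per-task loss reweighting enabled by "temp_reweight": true.
# Tasks with a larger log_alpha get a smaller weight; the weights are rescaled by
# the number of tasks so that they average to one.
import torch
import torch.nn.functional as F

task_nums = 10
log_alpha = torch.zeros(task_nums)                      # one temperature per task

reweight_coeff = F.softmax(-log_alpha, dim=0) * task_nums
print(reweight_coeff)                                   # all ones while temperatures are equal

# Once a task's temperature grows, its coefficient shrinks relative to the others:
log_alpha[0] = 1.0
print(F.softmax(-log_alpha, dim=0) * task_nums)

# These coefficients then multiply each task's critic and policy losses before averaging.
```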
128 | -------------------------------------------------------------------------------- /starter/mt_para_mtsac_modular_gated_cas.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | 4 | import torch 5 | 6 | import os 7 | import time 8 | import os.path as osp 9 | 10 | import numpy as np 11 | 12 | from torchrl.utils import get_args 13 | from torchrl.utils import get_params 14 | from torchrl.env import get_env 15 | 16 | from torchrl.utils import Logger 17 | 18 | args = get_args() 19 | params = get_params(args.config) 20 | 21 | import torchrl.policies as policies 22 | import torchrl.networks as networks 23 | from torchrl.algo import SAC 24 | from torchrl.algo import TwinSAC 25 | from torchrl.algo import TwinSACQ 26 | from torchrl.algo import MTSAC 27 | from torchrl.collector.para import ParallelCollector 28 | from torchrl.collector.para import AsyncParallelCollector 29 | from torchrl.collector.para.mt import SingleTaskParallelCollectorBase 30 | from torchrl.collector.para.async_mt import AsyncSingleTaskParallelCollector 31 | from torchrl.collector.para.async_mt import AsyncMultiTaskParallelCollectorUniform 32 | 33 | from torchrl.replay_buffers.shared import SharedBaseReplayBuffer 34 | from torchrl.replay_buffers.shared import AsyncSharedReplayBuffer 35 | import gym 36 | 37 | from metaworld_utils.meta_env import get_meta_env 38 | 39 | import random 40 | 41 | def experiment(args): 42 | 43 | device = torch.device("cuda:{}".format(args.device) if args.cuda else "cpu") 44 | 45 | env, cls_dicts, cls_args = get_meta_env( params['env_name'], params['env'], params['meta_env']) 46 | 47 | env.seed(args.seed) 48 | torch.manual_seed(args.seed) 49 | np.random.seed(args.seed) 50 | random.seed(args.seed) 51 | if args.cuda: 52 | torch.backends.cudnn.deterministic=True 53 | 54 | buffer_param = params['replay_buffer'] 55 | 56 | experiment_name = os.path.split( os.path.splitext( args.config )[0] )[-1] if args.id is None \ 57 | else args.id 58 | logger = Logger( experiment_name , params['env_name'], args.seed, params, args.log_dir ) 59 | 60 | params['general_setting']['env'] = env 61 | params['general_setting']['logger'] = logger 62 | params['general_setting']['device'] = device 63 | 64 | params['net']['base_type']=networks.MLPBase 65 | 66 | import torch.multiprocessing as mp 67 | mp.set_start_method('spawn', force=True) 68 | 69 | from torchrl.networks.init import normal_init 70 | 71 | example_ob = env.reset() 72 | example_embedding = env.active_task_one_hot 73 | 74 | pf = policies.ModularGuassianGatedCascadeCondContPolicy( 75 | input_shape=env.observation_space.shape[0], 76 | em_input_shape=np.prod(example_embedding.shape), 77 | output_shape=2 * env.action_space.shape[0], 78 | **params['net']) 79 | 80 | if args.pf_snap is not None: 81 | pf.load_state_dict(torch.load(args.pf_snap, map_location='cpu')) 82 | 83 | qf1 = networks.FlattenModularGatedCascadeCondNet( 84 | input_shape=env.observation_space.shape[0] + env.action_space.shape[0], 85 | em_input_shape=np.prod(example_embedding.shape), 86 | output_shape=1, 87 | **params['net']) 88 | qf2 = networks.FlattenModularGatedCascadeCondNet( 89 | input_shape=env.observation_space.shape[0] + env.action_space.shape[0], 90 | em_input_shape=np.prod(example_embedding.shape), 91 | output_shape=1, 92 | **params['net']) 93 | 94 | if args.qf1_snap is not None: 95 | qf1.load_state_dict(torch.load(args.qf2_snap, map_location='cpu')) 96 | if args.qf2_snap is not None: 97 | 
qf2.load_state_dict(torch.load(args.qf2_snap, map_location='cpu')) 98 | 99 | example_dict = { 100 | "obs": example_ob, 101 | "next_obs": example_ob, 102 | "acts": env.action_space.sample(), 103 | "rewards": [0], 104 | "terminals": [False], 105 | "task_idxs": [0], 106 | "embedding_inputs": example_embedding 107 | } 108 | 109 | replay_buffer = AsyncSharedReplayBuffer(int(buffer_param['size']), 110 | args.worker_nums 111 | ) 112 | replay_buffer.build_by_example(example_dict) 113 | 114 | params['general_setting']['replay_buffer'] = replay_buffer 115 | 116 | epochs = params['general_setting']['pretrain_epochs'] + \ 117 | params['general_setting']['num_epochs'] 118 | 119 | print(env.action_space) 120 | print(env.observation_space) 121 | params['general_setting']['collector'] = AsyncMultiTaskParallelCollectorUniform( 122 | env=env, pf=pf, replay_buffer=replay_buffer, 123 | env_cls = cls_dicts, env_args = [params["env"], cls_args, params["meta_env"]], 124 | device=device, 125 | reset_idx=True, 126 | epoch_frames=params['general_setting']['epoch_frames'], 127 | max_episode_frames=params['general_setting']['max_episode_frames'], 128 | eval_episodes = params['general_setting']['eval_episodes'], 129 | worker_nums=args.worker_nums, eval_worker_nums=args.eval_worker_nums, 130 | train_epochs = epochs, eval_epochs= params['general_setting']['num_epochs'] 131 | ) 132 | params['general_setting']['batch_size'] = int(params['general_setting']['batch_size']) 133 | params['general_setting']['save_dir'] = osp.join(logger.work_dir,"model") 134 | agent = MTSAC( 135 | pf = pf, 136 | qf1 = qf1, 137 | qf2 = qf2, 138 | task_nums=env.num_tasks, 139 | **params['sac'], 140 | **params['general_setting'] 141 | ) 142 | agent.train() 143 | 144 | if __name__ == "__main__": 145 | experiment(args) 146 | -------------------------------------------------------------------------------- /torchrl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RchalYang/Soft-Module/e6d7c8ad362a6b950632236322356da742dc838c/torchrl/__init__.py -------------------------------------------------------------------------------- /torchrl/algo/__init__.py: -------------------------------------------------------------------------------- 1 | from .off_policy import * 2 | 3 | __all__ = [ 4 | 'SAC', 5 | 'MTSAC', 6 | 'MTMHSAC', 7 | 'DDPG', 8 | 'TwinSAC', 9 | 'TwinSACQ', 10 | 'TD3', 11 | ] -------------------------------------------------------------------------------- /torchrl/algo/off_policy/__init__.py: -------------------------------------------------------------------------------- 1 | from .sac import SAC 2 | from .mt_sac import MTSAC 3 | from .mtmh_sac import MTMHSAC 4 | from .twin_sac import TwinSAC 5 | from .twin_sac_q import TwinSACQ 6 | -------------------------------------------------------------------------------- /torchrl/algo/off_policy/mt_sac.py: -------------------------------------------------------------------------------- 1 | from .twin_sac_q import TwinSACQ 2 | import copy 3 | import torch 4 | import numpy as np 5 | 6 | import torchrl.policies as policies 7 | import torch.nn.functional as F 8 | 9 | class MTSAC(TwinSACQ): 10 | """" 11 | Support Different Temperature for different tasks 12 | """ 13 | def __init__(self, task_nums, 14 | temp_reweight=False, 15 | grad_clip=True, 16 | **kwargs): 17 | super().__init__(**kwargs) 18 | 19 | self.task_nums = task_nums 20 | if self.automatic_entropy_tuning: 21 | self.log_alpha = torch.zeros(self.task_nums).to(self.device) 22 | 
self.log_alpha.requires_grad_() 23 | self.alpha_optimizer = self.optimizer_class( 24 | [self.log_alpha], 25 | lr=self.plr, 26 | ) 27 | self.sample_key = ["obs", "next_obs", "acts", "rewards", 28 | "terminals", "task_idxs"] 29 | 30 | self.pf_flag = isinstance(self.pf, 31 | policies.EmbeddingGuassianContPolicyBase) 32 | 33 | self.idx_flag = isinstance(self.pf, policies.MultiHeadGuassianContPolicy) 34 | 35 | self.temp_reweight = temp_reweight 36 | if self.pf_flag: 37 | self.sample_key.append("embedding_inputs") 38 | self.grad_clip = grad_clip 39 | 40 | def update(self, batch): 41 | self.training_update_num += 1 42 | obs = batch['obs'] 43 | actions = batch['acts'] 44 | next_obs = batch['next_obs'] 45 | rewards = batch['rewards'] 46 | terminals = batch['terminals'] 47 | 48 | if self.pf_flag: 49 | embedding_inputs = batch["embedding_inputs"] 50 | 51 | if self.idx_flag: 52 | task_idx = batch['task_idxs'] 53 | 54 | rewards = torch.Tensor(rewards).to(self.device) 55 | terminals = torch.Tensor(terminals).to(self.device) 56 | obs = torch.Tensor(obs).to(self.device) 57 | actions = torch.Tensor(actions).to(self.device) 58 | next_obs = torch.Tensor(next_obs).to(self.device) 59 | 60 | if self.pf_flag: 61 | embedding_inputs = torch.Tensor(embedding_inputs).to(self.device) 62 | 63 | if self.idx_flag: 64 | task_idx = torch.Tensor(task_idx).to( self.device ).long() 65 | 66 | self.pf.train() 67 | self.qf1.train() 68 | self.qf2.train() 69 | 70 | """ 71 | Policy operations. 72 | """ 73 | if self.idx_flag: 74 | sample_info = self.pf.explore(obs, task_idx, 75 | return_log_probs=True) 76 | else: 77 | if self.pf_flag: 78 | sample_info = self.pf.explore(obs, embedding_inputs, 79 | return_log_probs=True) 80 | else: 81 | sample_info = self.pf.explore(obs, return_log_probs=True) 82 | 83 | mean = sample_info["mean"] 84 | log_std = sample_info["log_std"] 85 | new_actions = sample_info["action"] 86 | log_probs = sample_info["log_prob"] 87 | 88 | if self.idx_flag: 89 | q1_pred = self.qf1([obs, actions], task_idx) 90 | q2_pred = self.qf2([obs, actions], task_idx) 91 | else: 92 | if self.pf_flag: 93 | q1_pred = self.qf1([obs, actions], embedding_inputs) 94 | q2_pred = self.qf2([obs, actions], embedding_inputs) 95 | else: 96 | q1_pred = self.qf1([obs, actions]) 97 | q2_pred = self.qf2([obs, actions]) 98 | 99 | reweight_coeff = 1 100 | if self.automatic_entropy_tuning: 101 | """ 102 | Alpha Loss 103 | """ 104 | batch_size = log_probs.shape[0] 105 | log_alphas = (self.log_alpha.unsqueeze(0)).expand( 106 | (batch_size, self.task_nums)) 107 | log_alphas = log_alphas.unsqueeze(-1) 108 | # log_alphas = log_alphas.gather(1, task_idx) 109 | 110 | alpha_loss = -(log_alphas * 111 | (log_probs + self.target_entropy).detach()).mean() 112 | 113 | self.alpha_optimizer.zero_grad() 114 | alpha_loss.backward() 115 | self.alpha_optimizer.step() 116 | 117 | alphas = (self.log_alpha.exp().detach()).unsqueeze(0) 118 | alphas = alphas.expand((batch_size, self.task_nums)).unsqueeze(-1) 119 | # (batch_size, 1) 120 | if self.temp_reweight: 121 | softmax_temp = F.softmax(-self.log_alpha.detach()).unsqueeze(0) 122 | reweight_coeff = softmax_temp.expand((batch_size, 123 | self.task_nums)) 124 | reweight_coeff = reweight_coeff.unsqueeze(-1) * self.task_nums 125 | else: 126 | alphas = 1 127 | alpha_loss = 0 128 | 129 | with torch.no_grad(): 130 | if self.idx_flag: 131 | target_sample_info = self.pf.explore(next_obs, 132 | task_idx, 133 | return_log_probs=True) 134 | else: 135 | if self.pf_flag: 136 | target_sample_info = self.pf.explore(next_obs, 137 | 
embedding_inputs, 138 | return_log_probs=True) 139 | else: 140 | target_sample_info = self.pf.explore(next_obs, 141 | return_log_probs=True) 142 | 143 | target_actions = target_sample_info["action"] 144 | target_log_probs = target_sample_info["log_prob"] 145 | 146 | if self.idx_flag: 147 | target_q1_pred = self.target_qf1([next_obs, target_actions], 148 | task_idx) 149 | target_q2_pred = self.target_qf2([next_obs, target_actions], 150 | task_idx) 151 | else: 152 | if self.pf_flag: 153 | target_q1_pred = self.target_qf1([next_obs, target_actions], 154 | embedding_inputs) 155 | target_q2_pred = self.target_qf2([next_obs, target_actions], 156 | embedding_inputs) 157 | else: 158 | target_q1_pred = self.target_qf1([next_obs, target_actions]) 159 | target_q2_pred = self.target_qf2([next_obs, target_actions]) 160 | 161 | min_target_q = torch.min(target_q1_pred, target_q2_pred) 162 | target_v_values = min_target_q - alphas * target_log_probs 163 | """ 164 | QF Loss 165 | """ 166 | # q_target = rewards + (1. - terminals) * self.discount * target_v_values 167 | # There is no actual terminate in meta-world -> just filter all time_limit terminal 168 | q_target = rewards + self.discount * target_v_values 169 | 170 | qf1_loss = (reweight_coeff * 171 | ((q1_pred - q_target.detach()) ** 2)).mean() 172 | qf2_loss = (reweight_coeff * 173 | ((q2_pred - q_target.detach()) ** 2)).mean() 174 | 175 | assert q1_pred.shape == q_target.shape, print(q1_pred.shape, q_target.shape) 176 | assert q2_pred.shape == q_target.shape, print(q1_pred.shape, q_target.shape) 177 | 178 | if self.idx_flag: 179 | q_new_actions = torch.min( 180 | self.qf1([obs, new_actions], task_idx), 181 | self.qf2([obs, new_actions], task_idx)) 182 | else: 183 | if self.pf_flag: 184 | q_new_actions = torch.min( 185 | self.qf1([obs, new_actions], embedding_inputs), 186 | self.qf2([obs, new_actions], embedding_inputs)) 187 | else: 188 | q_new_actions = torch.min( 189 | self.qf1([obs, new_actions]), 190 | self.qf2([obs, new_actions])) 191 | """ 192 | Policy Loss 193 | """ 194 | if not self.reparameterization: 195 | raise NotImplementedError 196 | else: 197 | assert log_probs.shape == q_new_actions.shape 198 | policy_loss = (reweight_coeff * 199 | (alphas * log_probs - q_new_actions)).mean() 200 | 201 | std_reg_loss = self.policy_std_reg_weight * (log_std**2).mean() 202 | mean_reg_loss = self.policy_mean_reg_weight * (mean**2).mean() 203 | 204 | policy_loss += std_reg_loss + mean_reg_loss 205 | 206 | """ 207 | Update Networks 208 | """ 209 | 210 | self.pf_optimizer.zero_grad() 211 | policy_loss.backward() 212 | if self.grad_clip: 213 | pf_norm = torch.nn.utils.clip_grad_norm_(self.pf.parameters(), 1) 214 | self.pf_optimizer.step() 215 | 216 | self.qf1_optimizer.zero_grad() 217 | qf1_loss.backward() 218 | if self.grad_clip: 219 | qf1_norm = torch.nn.utils.clip_grad_norm_(self.qf1.parameters(), 1) 220 | self.qf1_optimizer.step() 221 | 222 | self.qf2_optimizer.zero_grad() 223 | qf2_loss.backward() 224 | if self.grad_clip: 225 | qf2_norm = torch.nn.utils.clip_grad_norm_(self.qf2.parameters(), 1) 226 | self.qf2_optimizer.step() 227 | 228 | self._update_target_networks() 229 | 230 | # Information For Logger 231 | info = {} 232 | info['Reward_Mean'] = rewards.mean().item() 233 | 234 | if self.automatic_entropy_tuning: 235 | for i in range(self.task_nums): 236 | info["alpha_{}".format(i)] = self.log_alpha[i].exp().item() 237 | info["Alpha_loss"] = alpha_loss.item() 238 | info['Training/policy_loss'] = policy_loss.item() 239 | info['Training/qf1_loss'] = 
qf1_loss.item() 240 | info['Training/qf2_loss'] = qf2_loss.item() 241 | 242 | if self.grad_clip: 243 | info['Training/pf_norm'] = pf_norm.item() 244 | info['Training/qf1_norm'] = qf1_norm.item() 245 | info['Training/qf2_norm'] = qf2_norm.item() 246 | 247 | info['log_std/mean'] = log_std.mean().item() 248 | info['log_std/std'] = log_std.std().item() 249 | info['log_std/max'] = log_std.max().item() 250 | info['log_std/min'] = log_std.min().item() 251 | 252 | log_probs_display = log_probs.detach() 253 | log_probs_display = (log_probs_display.mean(0)).squeeze(1) 254 | for i in range(self.task_nums): 255 | info["log_prob_{}".format(i)] = log_probs_display[i].item() 256 | 257 | info['log_probs/mean'] = log_probs.mean().item() 258 | info['log_probs/std'] = log_probs.std().item() 259 | info['log_probs/max'] = log_probs.max().item() 260 | info['log_probs/min'] = log_probs.min().item() 261 | 262 | info['mean/mean'] = mean.mean().item() 263 | info['mean/std'] = mean.std().item() 264 | info['mean/max'] = mean.max().item() 265 | info['mean/min'] = mean.min().item() 266 | 267 | return info 268 | 269 | def update_per_epoch(self): 270 | for _ in range(self.opt_times): 271 | batch = self.replay_buffer.random_batch(self.batch_size, 272 | self.sample_key, 273 | reshape=False) 274 | infos = self.update(batch) 275 | self.logger.add_update_info(infos) 276 | -------------------------------------------------------------------------------- /torchrl/algo/off_policy/mtmh_sac.py: -------------------------------------------------------------------------------- 1 | from .twin_sac_q import TwinSACQ 2 | from .mt_sac import MTSAC 3 | import copy 4 | import torch 5 | import torchrl.algo.utils as atu 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | class MTMHSAC(MTSAC): 10 | ## Multi Task Multi Head SAC (Input Processed) 11 | def __init__(self,**kwargs): 12 | super().__init__(**kwargs) 13 | self.head_idx = list(range(self.task_nums)) 14 | self.sample_key = ["obs", "next_obs", "acts", "rewards", "task_idxs", 15 | "terminals"] 16 | if self.pf_flag: 17 | self.sample_key.append("embedding_inputs") 18 | 19 | def update(self, batch): 20 | self.training_update_num += 1 21 | 22 | obs = batch['obs'] 23 | actions = batch['acts'] 24 | next_obs = batch['next_obs'] 25 | rewards = batch['rewards'] 26 | terminals = batch['terminals'] 27 | 28 | # For Task 29 | task_idx = batch['task_idxs'] 30 | # task_onehot = batch['task_onehot'] 31 | if self.pf_flag: 32 | embedding_inputs = batch["embedding_inputs"] 33 | 34 | rewards = torch.Tensor(rewards).to( self.device ) 35 | terminals = torch.Tensor(terminals).to( self.device ) 36 | obs = torch.Tensor(obs).to( self.device ) 37 | actions = torch.Tensor(actions).to( self.device ) 38 | next_obs = torch.Tensor(next_obs).to( self.device ) 39 | # For Task 40 | task_idx = torch.Tensor(task_idx).to( self.device ).long() 41 | # task_onehot = torch.Tensor(task_onehot).to( self.device ) 42 | 43 | if self.pf_flag: 44 | embedding_inputs = torch.Tensor(embedding_inputs).to(self.device) 45 | 46 | """ 47 | Policy operations. 
48 | """ 49 | if self.pf_flag: 50 | sample_info = self.pf.explore(obs, embedding_inputs, 51 | self.head_idx, return_log_probs=True ) 52 | else: 53 | sample_info = self.pf.explore(obs, self.head_idx, return_log_probs=True ) 54 | 55 | mean_list = sample_info["mean"] 56 | log_std_list = sample_info["log_std"] 57 | new_actions_list = sample_info["action"] 58 | log_probs_list = sample_info["log_prob"] 59 | # ent_list = sample_info["ent"] 60 | 61 | means = atu.unsqe_cat_gather(mean_list, task_idx, dim = 1) 62 | 63 | log_stds = atu.unsqe_cat_gather(log_std_list, task_idx, dim = 1) 64 | 65 | new_actions = atu.unsqe_cat_gather(new_actions_list, task_idx, dim = 1) 66 | 67 | # log_probs = torch.cat 68 | log_probs = atu.unsqe_cat_gather(log_probs_list, task_idx, dim = 1) 69 | 70 | if self.pf_flag: 71 | q1_pred_list = self.qf1([obs, actions], embedding_inputs, self.head_idx) 72 | q2_pred_list = self.qf2([obs, actions], embedding_inputs, self.head_idx) 73 | else: 74 | q1_pred_list = self.qf1([obs, actions], self.head_idx) 75 | q2_pred_list = self.qf2([obs, actions], self.head_idx) 76 | 77 | q1_preds = atu.unsqe_cat_gather(q1_pred_list, task_idx, dim = 1) 78 | q2_preds = atu.unsqe_cat_gather(q2_pred_list, task_idx, dim = 1) 79 | 80 | reweight_coeff = 1 81 | if self.automatic_entropy_tuning: 82 | """ 83 | Alpha Loss 84 | """ 85 | batch_size = log_probs.shape[0] 86 | log_alphas = (self.log_alpha.unsqueeze(0)).expand((batch_size, self.task_nums)) 87 | log_alphas = log_alphas.gather(1, task_idx) 88 | 89 | alpha_loss = -(log_alphas * (log_probs + self.target_entropy).detach()).mean() 90 | self.alpha_optimizer.zero_grad() 91 | alpha_loss.backward() 92 | self.alpha_optimizer.step() 93 | # alpha = self.log_alpha.exp() 94 | alphas = (self.log_alpha.exp().detach()).unsqueeze(0).expand((batch_size, self.task_nums)) 95 | alphas = alphas.gather(1, task_idx) 96 | if self.temp_reweight: 97 | softmax_temp = F.softmax(-self.log_alpha.detach()) 98 | reweight_coeff = softmax_temp.unsqueeze(0).expand((batch_size, self.task_nums)) 99 | reweight_coeff = reweight_coeff.gather(1, task_idx) 100 | else: 101 | alphas = 1 102 | alpha_loss = 0 103 | 104 | progress_weight = 1 105 | if self.progress_reweight: 106 | progress_weight = torch.Tensor(self.collector.task_progress) 107 | progress_weight = progress_weight.unsqueeze(0).expand((batch_size, self.task_nums)) 108 | progress_weight = progress_weight.gather(1, task_idx) 109 | reweight_coeff = reweight_coeff * progress_weight 110 | 111 | with torch.no_grad(): 112 | if self.pf_flag: 113 | target_sample_info = self.pf.explore(next_obs, embedding_inputs, 114 | self.head_idx, return_log_probs=True ) 115 | else: 116 | target_sample_info = self.pf.explore(next_obs, self.head_idx, return_log_probs=True ) 117 | 118 | target_actions_list = target_sample_info["action"] 119 | target_actions = atu.unsqe_cat_gather(target_actions_list, task_idx, dim = 1) 120 | 121 | target_log_probs_list = target_sample_info["log_prob"] 122 | target_log_probs = atu.unsqe_cat_gather(target_log_probs_list, task_idx, dim = 1) 123 | 124 | if self.pf_flag: 125 | target_q1_pred_list = self.target_qf1([next_obs, target_actions], 126 | embedding_inputs, self.head_idx) 127 | target_q2_pred_list = self.target_qf2([next_obs, target_actions], 128 | embedding_inputs, self.head_idx) 129 | else: 130 | target_q1_pred_list = self.target_qf1([next_obs, target_actions], self.head_idx) 131 | target_q2_pred_list = self.target_qf2([next_obs, target_actions], self.head_idx) 132 | 133 | target_q1_pred = 
atu.unsqe_cat_gather(target_q1_pred_list, task_idx, dim = 1) 134 | target_q2_pred = atu.unsqe_cat_gather(target_q2_pred_list, task_idx, dim = 1) 135 | 136 | min_target_q = torch.min(target_q1_pred, target_q2_pred) 137 | target_v_values = min_target_q - alphas * target_log_probs 138 | 139 | """ 140 | QF Loss 141 | """ 142 | # q_target = rewards + (1. - terminals) * self.discount * target_v_values 143 | # There is no actual terminate in meta-world -> just filter all time_limit terminal 144 | q_target = rewards + self.discount * target_v_values 145 | 146 | qf1_loss = (reweight_coeff * ((q1_preds - q_target.detach()) ** 2)).mean() 147 | qf2_loss = (reweight_coeff * ((q2_preds - q_target.detach()) ** 2)).mean() 148 | 149 | # """ 150 | # VF Loss 151 | # """ 152 | if self.pf_flag: 153 | q1_new_actions_list = self.qf1([obs, new_actions], 154 | embedding_inputs, self.head_idx) 155 | q2_new_actions_list = self.qf2([obs, new_actions], 156 | embedding_inputs, self.head_idx) 157 | else: 158 | q1_new_actions_list = self.qf1([obs, new_actions], self.head_idx) 159 | q2_new_actions_list = self.qf2([obs, new_actions], self.head_idx) 160 | 161 | q1_new_actions = atu.unsqe_cat_gather(q1_new_actions_list, task_idx, dim = 1) 162 | q2_new_actions = atu.unsqe_cat_gather(q2_new_actions_list, task_idx, dim = 1) 163 | 164 | q_new_actions = torch.min( 165 | q1_new_actions, 166 | q2_new_actions 167 | ) 168 | 169 | """ 170 | Policy Loss 171 | """ 172 | if not self.reparameterization: 173 | raise NotImplementedError 174 | else: 175 | # policy_loss = ( alphas * log_probs - q_new_actions).mean() 176 | assert log_probs.shape == q_new_actions.shape 177 | policy_loss = (reweight_coeff * ( alphas * log_probs - q_new_actions)).mean() 178 | 179 | std_reg_loss = self.policy_std_reg_weight * (log_stds**2).mean() 180 | mean_reg_loss = self.policy_mean_reg_weight * (means**2).mean() 181 | 182 | policy_loss += std_reg_loss + mean_reg_loss 183 | 184 | """ 185 | Update Networks 186 | """ 187 | 188 | self.pf_optimizer.zero_grad() 189 | policy_loss.backward() 190 | self.pf_optimizer.step() 191 | 192 | self.qf1_optimizer.zero_grad() 193 | qf1_loss.backward() 194 | self.qf1_optimizer.step() 195 | 196 | self.qf2_optimizer.zero_grad() 197 | qf2_loss.backward() 198 | self.qf2_optimizer.step() 199 | 200 | self._update_target_networks() 201 | 202 | # Information For Logger 203 | info = {} 204 | info['Reward_Mean'] = rewards.mean().item() 205 | 206 | if self.automatic_entropy_tuning: 207 | for i in range(self.task_nums): 208 | info["alpha_{}".format(i)] = self.log_alpha[i].exp().item() 209 | info["Alpha_loss"] = alpha_loss.item() 210 | info['Training/policy_loss'] = policy_loss.item() 211 | info['Training/qf1_loss'] = qf1_loss.item() 212 | info['Training/qf2_loss'] = qf2_loss.item() 213 | 214 | info['log_std/mean'] = log_stds.mean().item() 215 | info['log_std/std'] = log_stds.std().item() 216 | info['log_std/max'] = log_stds.max().item() 217 | info['log_std/min'] = log_stds.min().item() 218 | 219 | info['log_probs/mean'] = log_probs.mean().item() 220 | info['log_probs/std'] = log_probs.std().item() 221 | info['log_probs/max'] = log_probs.max().item() 222 | info['log_probs/min'] = log_probs.min().item() 223 | 224 | info['mean/mean'] = means.mean().item() 225 | info['mean/std'] = means.std().item() 226 | info['mean/max'] = means.max().item() 227 | info['mean/min'] = means.min().item() 228 | 229 | return info 230 | 231 | def update_per_epoch(self): 232 | for _ in range(self.opt_times): 233 | batch = self.replay_buffer.random_batch(self.batch_size, 
234 | self.sample_key, 235 | reshape=True) 236 | infos = self.update(batch) 237 | self.logger.add_update_info(infos) 238 | -------------------------------------------------------------------------------- /torchrl/algo/off_policy/off_rl_algo.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import math 4 | 5 | import torch 6 | 7 | from torchrl.algo.rl_algo import RLAlgo 8 | 9 | class OffRLAlgo(RLAlgo): 10 | """ 11 | Base RL Algorithm Framework 12 | """ 13 | def __init__(self, 14 | 15 | pretrain_epochs=0, 16 | 17 | min_pool = 0, 18 | 19 | target_hard_update_period = 1000, 20 | use_soft_update = True, 21 | tau = 0.001, 22 | opt_times = 1, 23 | 24 | **kwargs 25 | ): 26 | super(OffRLAlgo, self).__init__(**kwargs) 27 | 28 | # environment relevant information 29 | self.pretrain_epochs = pretrain_epochs 30 | 31 | # target_network update information 32 | self.target_hard_update_period = target_hard_update_period 33 | self.use_soft_update = use_soft_update 34 | self.tau = tau 35 | 36 | # training information 37 | self.opt_times = opt_times 38 | self.min_pool = min_pool 39 | 40 | self.sample_key = [ "obs", "next_obs", "acts", "rewards", "terminals" ] 41 | 42 | def update_per_timestep(self): 43 | if self.replay_buffer.num_steps_can_sample() > max( self.min_pool, self.batch_size ): 44 | for _ in range( self.opt_times ): 45 | batch = self.replay_buffer.random_batch(self.batch_size, self.sample_key) 46 | infos = self.update( batch ) 47 | self.logger.add_update_info( infos ) 48 | 49 | def update_per_epoch(self): 50 | for _ in range( self.opt_times ): 51 | batch = self.replay_buffer.random_batch(self.batch_size, self.sample_key) 52 | infos = self.update( batch ) 53 | self.logger.add_update_info( infos ) 54 | 55 | def pretrain(self): 56 | total_frames = 0 57 | self.pretrain_epochs * self.collector.worker_nums * self.epoch_frames 58 | 59 | for pretrain_epoch in range( self.pretrain_epochs ): 60 | 61 | start = time.time() 62 | 63 | self.start_epoch() 64 | 65 | training_epoch_info = self.collector.train_one_epoch() 66 | for reward in training_epoch_info["train_rewards"]: 67 | self.training_episode_rewards.append(reward) 68 | 69 | finish_epoch_info = self.finish_epoch() 70 | 71 | total_frames += self.collector.active_worker_nums * self.epoch_frames 72 | 73 | infos = {} 74 | 75 | infos["Train_Epoch_Reward"] = training_epoch_info["train_epoch_reward"] 76 | infos["Running_Training_Average_Rewards"] = np.mean(self.training_episode_rewards) 77 | infos.update(finish_epoch_info) 78 | 79 | self.logger.add_epoch_info(pretrain_epoch, total_frames, time.time() - start, infos, csv_write=False ) 80 | 81 | self.pretrain_frames = total_frames 82 | 83 | self.logger.log("Finished Pretrain") 84 | -------------------------------------------------------------------------------- /torchrl/algo/off_policy/sac.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import copy 4 | 5 | import torch 6 | import torch.optim as optim 7 | from torch import nn as nn 8 | 9 | from .off_rl_algo import OffRLAlgo 10 | 11 | class SAC(OffRLAlgo): 12 | """ 13 | SAC 14 | """ 15 | def __init__( 16 | self, 17 | pf, vf, qf, 18 | plr,vlr,qlr, 19 | optimizer_class=optim.Adam, 20 | 21 | policy_std_reg_weight=1e-3, 22 | policy_mean_reg_weight=1e-3, 23 | 24 | reparameterization = True, 25 | automatic_entropy_tuning = True, 26 | target_entropy = None, 27 | **kwargs 28 | ): 29 | super(SAC, 
self).__init__(**kwargs) 30 | self.pf = pf 31 | self.qf = qf 32 | self.vf = vf 33 | self.target_vf = copy.deepcopy(vf) 34 | self.to(self.device) 35 | 36 | self.plr = plr 37 | self.vlr = vlr 38 | self.qlr = qlr 39 | 40 | self.qf_optimizer = optimizer_class( 41 | self.qf.parameters(), 42 | lr=self.qlr, 43 | ) 44 | 45 | self.vf_optimizer = optimizer_class( 46 | self.vf.parameters(), 47 | lr=self.vlr, 48 | ) 49 | 50 | self.pf_optimizer = optimizer_class( 51 | self.pf.parameters(), 52 | lr=self.plr, 53 | ) 54 | 55 | self.automatic_entropy_tuning = automatic_entropy_tuning 56 | if self.automatic_entropy_tuning: 57 | if target_entropy: 58 | self.target_entropy = target_entropy 59 | else: 60 | self.target_entropy = -np.prod(self.env.action_space.shape).item() # from rlkit 61 | self.log_alpha = torch.zeros(1).to(self.device) 62 | self.log_alpha.requires_grad_() 63 | self.alpha_optimizer = optimizer_class( 64 | [self.log_alpha], 65 | lr=self.plr, 66 | ) 67 | 68 | self.qf_criterion = nn.MSELoss() 69 | self.vf_criterion = nn.MSELoss() 70 | 71 | self.policy_std_reg_weight = policy_std_reg_weight 72 | self.policy_mean_reg_weight = policy_mean_reg_weight 73 | 74 | self.reparameterization = reparameterization 75 | 76 | def update(self, batch): 77 | self.training_update_num += 1 78 | 79 | obs = batch['obs'] 80 | actions = batch['acts'] 81 | next_obs = batch['next_obs'] 82 | rewards = batch['rewards'] 83 | terminals = batch['terminals'] 84 | 85 | rewards = torch.Tensor(rewards).to( self.device ) 86 | terminals = torch.Tensor(terminals).to( self.device ) 87 | obs = torch.Tensor(obs).to( self.device ) 88 | actions = torch.Tensor(actions).to( self.device ) 89 | next_obs = torch.Tensor(next_obs).to( self.device ) 90 | 91 | """ 92 | Policy operations. 93 | """ 94 | sample_info = self.pf.explore(obs, return_log_probs=True ) 95 | 96 | mean = sample_info["mean"] 97 | log_std = sample_info["log_std"] 98 | new_actions = sample_info["action"] 99 | log_probs = sample_info["log_prob"] 100 | ent = sample_info["ent"] 101 | 102 | q_pred = self.qf([obs, actions]) 103 | v_pred = self.vf(obs) 104 | 105 | if self.automatic_entropy_tuning: 106 | """ 107 | Alpha Loss 108 | """ 109 | alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean() 110 | self.alpha_optimizer.zero_grad() 111 | alpha_loss.backward() 112 | self.alpha_optimizer.step() 113 | alpha = self.log_alpha.exp() 114 | else: 115 | alpha = 1 116 | alpha_loss = 0 117 | 118 | """ 119 | QF Loss 120 | """ 121 | target_v_values = self.target_vf(next_obs) 122 | q_target = rewards + (1. 
- terminals) * self.discount * target_v_values 123 | qf_loss = self.qf_criterion( q_pred, q_target.detach()) 124 | 125 | """ 126 | VF Loss 127 | """ 128 | q_new_actions = self.qf([obs, new_actions]) 129 | v_target = q_new_actions - alpha * log_probs 130 | vf_loss = self.vf_criterion( v_pred, v_target.detach()) 131 | 132 | """ 133 | Policy Loss 134 | """ 135 | if not self.reparameterization: 136 | log_policy_target = q_new_actions - v_pred 137 | policy_loss = ( 138 | log_probs * ( alpha * log_probs - log_policy_target).detach() 139 | ).mean() 140 | else: 141 | policy_loss = ( alpha * log_probs - q_new_actions).mean() 142 | 143 | std_reg_loss = self.policy_std_reg_weight * (log_std**2).mean() 144 | mean_reg_loss = self.policy_mean_reg_weight * (mean**2).mean() 145 | 146 | policy_loss += std_reg_loss + mean_reg_loss 147 | 148 | """ 149 | Update Networks 150 | """ 151 | 152 | self.pf_optimizer.zero_grad() 153 | policy_loss.backward() 154 | self.pf_optimizer.step() 155 | 156 | self.qf_optimizer.zero_grad() 157 | qf_loss.backward() 158 | self.qf_optimizer.step() 159 | 160 | self.vf_optimizer.zero_grad() 161 | vf_loss.backward() 162 | self.vf_optimizer.step() 163 | 164 | self._update_target_networks() 165 | 166 | # Information For Logger 167 | info = {} 168 | info['Reward_Mean'] = rewards.mean().item() 169 | 170 | if self.automatic_entropy_tuning: 171 | info["Alpha"] = alpha.item() 172 | info["Alpha_loss"] = alpha_loss.item() 173 | info['Training/policy_loss'] = policy_loss.item() 174 | info['Training/vf_loss'] = vf_loss.item() 175 | info['Training/qf_loss'] = qf_loss.item() 176 | 177 | info['log_std/mean'] = log_std.mean().item() 178 | info['log_std/std'] = log_std.std().item() 179 | info['log_std/max'] = log_std.max().item() 180 | info['log_std/min'] = log_std.min().item() 181 | 182 | info['log_probs/mean'] = log_std.mean().item() 183 | info['log_probs/std'] = log_std.std().item() 184 | info['log_probs/max'] = log_std.max().item() 185 | info['log_probs/min'] = log_std.min().item() 186 | 187 | info['mean/mean'] = mean.mean().item() 188 | info['mean/std'] = mean.std().item() 189 | info['mean/max'] = mean.max().item() 190 | info['mean/min'] = mean.min().item() 191 | 192 | return info 193 | 194 | @property 195 | def networks(self): 196 | return [ 197 | self.pf, 198 | self.qf, 199 | self.vf, 200 | self.target_vf 201 | ] 202 | 203 | @property 204 | def snapshot_networks(self): 205 | return [ 206 | ["pf", self.pf], 207 | ["qf", self.qf], 208 | ["vf", self.vf] 209 | ] 210 | 211 | @property 212 | def target_networks(self): 213 | return [ 214 | ( self.vf, self.target_vf ) 215 | ] 216 | -------------------------------------------------------------------------------- /torchrl/algo/off_policy/twin_sac.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import copy 4 | 5 | import torch 6 | import torch.optim as optim 7 | from torch import nn as nn 8 | 9 | from .off_rl_algo import OffRLAlgo 10 | 11 | class TwinSAC(OffRLAlgo): 12 | """ 13 | SAC 14 | """ 15 | 16 | def __init__( 17 | self, 18 | pf, vf, 19 | qf1, qf2, 20 | plr,vlr,qlr, 21 | optimizer_class=optim.Adam, 22 | 23 | policy_std_reg_weight=1e-3, 24 | policy_mean_reg_weight=1e-3, 25 | 26 | reparameterization = True, 27 | automatic_entropy_tuning = True, 28 | target_entropy = None, 29 | **kwargs 30 | ): 31 | super(TwinSAC, self).__init__(**kwargs) 32 | self.pf = pf 33 | self.qf1 = qf1 34 | self.qf2 = qf2 35 | self.vf = vf 36 | self.target_vf = copy.deepcopy(vf) 37 | 
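The deep-copied target network created here is kept in sync by `_update_target_networks` (defined in `rl_algo.py` below): with `use_soft_update=True` it applies Polyak averaging through `soft_update_from_to` from `torchrl/algo/utils.py`. A minimal sketch of that update with a toy module and the default `tau=0.001` from `OffRLAlgo` (illustrative, not the repo's code path):

```
import copy
import torch.nn as nn

vf = nn.Linear(8, 1)                 # stand-in for the value network
target_vf = copy.deepcopy(vf)
tau = 0.001                          # OffRLAlgo default

# target <- (1 - tau) * target + tau * source, parameter by parameter
for t_param, param in zip(target_vf.parameters(), vf.parameters()):
    t_param.data.copy_(t_param.data * (1.0 - tau) + param.data * tau)
```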
self.to(self.device) 38 | 39 | self.plr = plr 40 | self.vlr = vlr 41 | self.qlr = qlr 42 | 43 | self.qf1_optimizer = optimizer_class( 44 | self.qf1.parameters(), 45 | lr=self.qlr, 46 | ) 47 | 48 | self.qf2_optimizer = optimizer_class( 49 | self.qf2.parameters(), 50 | lr=self.qlr, 51 | ) 52 | 53 | self.vf_optimizer = optimizer_class( 54 | self.vf.parameters(), 55 | lr=self.vlr, 56 | ) 57 | 58 | self.pf_optimizer = optimizer_class( 59 | self.pf.parameters(), 60 | lr=self.plr, 61 | ) 62 | 63 | self.automatic_entropy_tuning = automatic_entropy_tuning 64 | if self.automatic_entropy_tuning: 65 | if target_entropy: 66 | self.target_entropy = target_entropy 67 | else: 68 | self.target_entropy = -np.prod(self.env.action_space.shape).item() # from rlkit 69 | self.log_alpha = torch.zeros(1).to(self.device) 70 | self.log_alpha.requires_grad_() 71 | self.alpha_optimizer = optimizer_class( 72 | [self.log_alpha], 73 | lr=self.plr, 74 | ) 75 | 76 | self.qf_criterion = nn.MSELoss() 77 | self.vf_criterion = nn.MSELoss() 78 | 79 | self.policy_std_reg_weight = policy_std_reg_weight 80 | self.policy_mean_reg_weight = policy_mean_reg_weight 81 | 82 | self.reparameterization = reparameterization 83 | 84 | 85 | def update(self, batch): 86 | self.training_update_num += 1 87 | obs = batch['obs'] 88 | actions = batch['acts'] 89 | next_obs = batch['next_obs'] 90 | rewards = batch['rewards'] 91 | terminals = batch['terminals'] 92 | 93 | rewards = torch.Tensor(rewards).to( self.device ) 94 | terminals = torch.Tensor(terminals).to( self.device ) 95 | obs = torch.Tensor(obs).to( self.device ) 96 | actions = torch.Tensor(actions).to( self.device ) 97 | next_obs = torch.Tensor(next_obs).to( self.device ) 98 | 99 | """ 100 | Policy operations. 101 | """ 102 | sample_info = self.pf.explore(obs, return_log_probs=True ) 103 | 104 | mean = sample_info["mean"] 105 | log_std = sample_info["log_std"] 106 | new_actions = sample_info["action"] 107 | log_probs = sample_info["log_prob"] 108 | ent = sample_info["ent"] 109 | 110 | q1_pred = self.qf1([obs, actions]) 111 | q2_pred = self.qf2([obs, actions]) 112 | v_pred = self.vf(obs) 113 | 114 | if self.automatic_entropy_tuning: 115 | """ 116 | Alpha Loss 117 | """ 118 | alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean() 119 | self.alpha_optimizer.zero_grad() 120 | alpha_loss.backward() 121 | self.alpha_optimizer.step() 122 | alpha = self.log_alpha.exp() 123 | else: 124 | alpha = 1 125 | alpha_loss = 0 126 | 127 | """ 128 | QF Loss 129 | """ 130 | target_v_values = self.target_vf(next_obs) 131 | q_target = rewards + (1. 
- terminals) * self.discount * target_v_values 132 | qf1_loss = self.qf_criterion( q1_pred, q_target.detach()) 133 | qf2_loss = self.qf_criterion( q2_pred, q_target.detach()) 134 | 135 | """ 136 | VF Loss 137 | """ 138 | q_new_actions = torch.min(self.qf1([obs, new_actions]), self.qf2([obs, new_actions])) 139 | v_target = q_new_actions - alpha * log_probs 140 | vf_loss = self.vf_criterion( v_pred, v_target.detach()) 141 | 142 | """ 143 | Policy Loss 144 | """ 145 | if not self.reparameterization: 146 | log_policy_target = q_new_actions - v_pred 147 | policy_loss = ( 148 | log_probs * ( alpha * log_probs - log_policy_target).detach() 149 | ).mean() 150 | else: 151 | policy_loss = ( alpha * log_probs - q_new_actions).mean() 152 | 153 | std_reg_loss = self.policy_std_reg_weight * (log_std**2).mean() 154 | mean_reg_loss = self.policy_mean_reg_weight * (mean**2).mean() 155 | 156 | policy_loss += std_reg_loss + mean_reg_loss 157 | 158 | """ 159 | Update Networks 160 | """ 161 | 162 | self.pf_optimizer.zero_grad() 163 | policy_loss.backward() 164 | self.pf_optimizer.step() 165 | 166 | self.qf1_optimizer.zero_grad() 167 | qf1_loss.backward() 168 | self.qf1_optimizer.step() 169 | 170 | self.qf2_optimizer.zero_grad() 171 | qf2_loss.backward() 172 | self.qf2_optimizer.step() 173 | 174 | self.vf_optimizer.zero_grad() 175 | vf_loss.backward() 176 | self.vf_optimizer.step() 177 | 178 | self._update_target_networks() 179 | 180 | # Information For Logger 181 | info = {} 182 | info['Reward_Mean'] = rewards.mean().item() 183 | 184 | if self.automatic_entropy_tuning: 185 | info["Alpha"] = alpha.item() 186 | info["Alpha_loss"] = alpha_loss.item() 187 | info['Training/policy_loss'] = policy_loss.item() 188 | info['Training/vf_loss'] = vf_loss.item() 189 | info['Training/qf1_loss'] = qf1_loss.item() 190 | info['Training/qf2_loss'] = qf2_loss.item() 191 | 192 | info['log_std/mean'] = log_std.mean().item() 193 | info['log_std/std'] = log_std.std().item() 194 | info['log_std/max'] = log_std.max().item() 195 | info['log_std/min'] = log_std.min().item() 196 | 197 | info['log_probs/mean'] = log_probs.mean().item() 198 | info['log_probs/std'] = log_probs.std().item() 199 | info['log_probs/max'] = log_probs.max().item() 200 | info['log_probs/min'] = log_probs.min().item() 201 | 202 | info['mean/mean'] = mean.mean().item() 203 | info['mean/std'] = mean.std().item() 204 | info['mean/max'] = mean.max().item() 205 | info['mean/min'] = mean.min().item() 206 | 207 | return info 208 | 209 | @property 210 | def networks(self): 211 | return [ 212 | self.pf, 213 | self.qf1, 214 | self.qf2, 215 | self.vf, 216 | self.target_vf 217 | ] 218 | 219 | @property 220 | def snapshot_networks(self): 221 | return [ 222 | ["pf", self.pf], 223 | ["qf1", self.qf1], 224 | ["qf2", self.qf2], 225 | ["vf", self.vf] 226 | ] 227 | 228 | @property 229 | def target_networks(self): 230 | return [ 231 | ( self.vf, self.target_vf ) 232 | ] 233 | -------------------------------------------------------------------------------- /torchrl/algo/off_policy/twin_sac_q.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import copy 4 | 5 | import torch 6 | import torch.optim as optim 7 | from torch import nn as nn 8 | 9 | from .off_rl_algo import OffRLAlgo 10 | 11 | class TwinSACQ(OffRLAlgo): 12 | """ 13 | Twin SAC without V 14 | """ 15 | 16 | def __init__( 17 | self, 18 | pf, 19 | qf1, qf2, 20 | plr, qlr, 21 | optimizer_class=optim.Adam, 22 | 23 | policy_std_reg_weight=1e-3, 24 | 
policy_mean_reg_weight=1e-3, 25 | 26 | reparameterization=True, 27 | automatic_entropy_tuning=True, 28 | target_entropy=None, 29 | **kwargs 30 | ): 31 | super(TwinSACQ, self).__init__(**kwargs) 32 | self.pf = pf 33 | self.qf1 = qf1 34 | self.qf2 = qf2 35 | self.target_qf1 = copy.deepcopy(qf1) 36 | self.target_qf2 = copy.deepcopy(qf2) 37 | 38 | self.to(self.device) 39 | 40 | self.plr = plr 41 | self.qlr = qlr 42 | 43 | self.optimizer_class = optimizer_class 44 | 45 | self.qf1_optimizer = optimizer_class( 46 | self.qf1.parameters(), 47 | lr=self.qlr, 48 | ) 49 | 50 | self.qf2_optimizer = optimizer_class( 51 | self.qf2.parameters(), 52 | lr=self.qlr, 53 | ) 54 | 55 | self.pf_optimizer = optimizer_class( 56 | self.pf.parameters(), 57 | lr=self.plr, 58 | ) 59 | 60 | self.automatic_entropy_tuning = automatic_entropy_tuning 61 | if self.automatic_entropy_tuning: 62 | if target_entropy: 63 | self.target_entropy = target_entropy 64 | else: 65 | self.target_entropy = -np.prod(self.env.action_space.shape).item() # from rlkit 66 | self.log_alpha = torch.zeros(1).to(self.device) 67 | self.log_alpha.requires_grad_() 68 | self.alpha_optimizer = optimizer_class( 69 | [self.log_alpha], 70 | lr=self.plr, 71 | ) 72 | 73 | self.qf_criterion = nn.MSELoss() 74 | 75 | self.policy_std_reg_weight = policy_std_reg_weight 76 | self.policy_mean_reg_weight = policy_mean_reg_weight 77 | 78 | self.reparameterization = reparameterization 79 | 80 | def update(self, batch): 81 | self.training_update_num += 1 82 | obs = batch['obs'] 83 | actions = batch['acts'] 84 | next_obs = batch['next_obs'] 85 | rewards = batch['rewards'] 86 | terminals = batch['terminals'] 87 | 88 | rewards = torch.Tensor(rewards).to( self.device ) 89 | terminals = torch.Tensor(terminals).to( self.device ) 90 | obs = torch.Tensor(obs).to( self.device ) 91 | actions = torch.Tensor(actions).to( self.device ) 92 | next_obs = torch.Tensor(next_obs).to( self.device ) 93 | 94 | """ 95 | Policy operations. 96 | """ 97 | sample_info = self.pf.explore(obs, return_log_probs=True ) 98 | 99 | mean = sample_info["mean"] 100 | log_std = sample_info["log_std"] 101 | new_actions = sample_info["action"] 102 | log_probs = sample_info["log_prob"] 103 | 104 | q1_pred = self.qf1([obs, actions]) 105 | q2_pred = self.qf2([obs, actions]) 106 | # v_pred = self.vf(obs) 107 | 108 | if self.automatic_entropy_tuning: 109 | """ 110 | Alpha Loss 111 | """ 112 | alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean() 113 | self.alpha_optimizer.zero_grad() 114 | alpha_loss.backward() 115 | self.alpha_optimizer.step() 116 | alpha = self.log_alpha.exp().detach() 117 | else: 118 | alpha = 1 119 | alpha_loss = 0 120 | 121 | with torch.no_grad(): 122 | target_sample_info = self.pf.explore(next_obs, return_log_probs=True ) 123 | 124 | target_actions = target_sample_info["action"] 125 | target_log_probs = target_sample_info["log_prob"] 126 | 127 | target_q1_pred = self.target_qf1([next_obs, target_actions]) 128 | target_q2_pred = self.target_qf2([next_obs, target_actions]) 129 | min_target_q = torch.min(target_q1_pred, target_q2_pred) 130 | target_v_values = min_target_q - alpha * target_log_probs 131 | """ 132 | QF Loss 133 | """ 134 | q_target = rewards + (1. 
- terminals) * self.discount * target_v_values 135 | qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) 136 | qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) 137 | assert q1_pred.shape == q_target.shape 138 | assert q2_pred.shape == q_target.shape 139 | # qf1_loss = (0.5 * ( q1_pred - q_target.detach() ) ** 2).mean() 140 | # qf2_loss = (0.5 * ( q2_pred - q_target.detach() ) ** 2).mean() 141 | 142 | q_new_actions = torch.min( 143 | self.qf1([obs, new_actions]), 144 | self.qf2([obs, new_actions])) 145 | """ 146 | Policy Loss 147 | """ 148 | if not self.reparameterization: 149 | raise NotImplementedError 150 | else: 151 | assert log_probs.shape == q_new_actions.shape 152 | policy_loss = ( alpha * log_probs - q_new_actions).mean() 153 | 154 | std_reg_loss = self.policy_std_reg_weight * (log_std**2).mean() 155 | mean_reg_loss = self.policy_mean_reg_weight * (mean**2).mean() 156 | 157 | policy_loss += std_reg_loss + mean_reg_loss 158 | 159 | """ 160 | Update Networks 161 | """ 162 | 163 | self.pf_optimizer.zero_grad() 164 | policy_loss.backward() 165 | pf_norm = torch.nn.utils.clip_grad_norm_(self.pf.parameters(), 10) 166 | self.pf_optimizer.step() 167 | 168 | self.qf1_optimizer.zero_grad() 169 | qf1_loss.backward() 170 | qf1_norm = torch.nn.utils.clip_grad_norm_(self.qf1.parameters(), 10) 171 | self.qf1_optimizer.step() 172 | 173 | self.qf2_optimizer.zero_grad() 174 | qf2_loss.backward() 175 | qf2_norm = torch.nn.utils.clip_grad_norm_(self.qf2.parameters(), 10) 176 | self.qf2_optimizer.step() 177 | 178 | self._update_target_networks() 179 | 180 | # Information For Logger 181 | info = {} 182 | info['Reward_Mean'] = rewards.mean().item() 183 | 184 | if self.automatic_entropy_tuning: 185 | info["Alpha"] = alpha.item() 186 | info["Alpha_loss"] = alpha_loss.item() 187 | info['Training/policy_loss'] = policy_loss.item() 188 | info['Training/qf1_loss'] = qf1_loss.item() 189 | info['Training/qf2_loss'] = qf2_loss.item() 190 | 191 | info['Training/pf_norm'] = pf_norm 192 | info['Training/qf1_norm'] = qf1_norm 193 | info['Training/qf2_norm'] = qf2_norm 194 | 195 | info['log_std/mean'] = log_std.mean().item() 196 | info['log_std/std'] = log_std.std().item() 197 | info['log_std/max'] = log_std.max().item() 198 | info['log_std/min'] = log_std.min().item() 199 | 200 | info['log_probs/mean'] = log_probs.mean().item() 201 | info['log_probs/std'] = log_probs.std().item() 202 | info['log_probs/max'] = log_probs.max().item() 203 | info['log_probs/min'] = log_probs.min().item() 204 | 205 | info['mean/mean'] = mean.mean().item() 206 | info['mean/std'] = mean.std().item() 207 | info['mean/max'] = mean.max().item() 208 | info['mean/min'] = mean.min().item() 209 | 210 | return info 211 | 212 | @property 213 | def networks(self): 214 | return [ 215 | self.pf, 216 | self.qf1, 217 | self.qf2, 218 | self.target_qf1, 219 | self.target_qf2 220 | ] 221 | 222 | @property 223 | def snapshot_networks(self): 224 | return [ 225 | ["pf", self.pf], 226 | ["qf1", self.qf1], 227 | ["qf2", self.qf2], 228 | ] 229 | 230 | @property 231 | def target_networks(self): 232 | return [ 233 | ( self.qf1, self.target_qf1 ), 234 | ( self.qf2, self.target_qf2 ) 235 | ] 236 | -------------------------------------------------------------------------------- /torchrl/algo/rl_algo.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import time 3 | from collections import deque 4 | import numpy as np 5 | 6 | import torch 7 | 8 | import torchrl.algo.utils as atu 9 | 10 | import gym 11 | 
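Before the base `RLAlgo` class, a compact recap of the Q-target that `TwinSACQ.update` above computes (and that `MTSAC` reuses without the terminal mask, since MetaWorld episodes only end on time limits). A standalone sketch with toy tensors and illustrative values, not the repo's code path:

```
import torch

discount, alpha = 0.99, 0.2                      # illustrative values
rewards   = torch.tensor([[1.0], [0.5]])
terminals = torch.tensor([[0.0], [1.0]])
target_q1 = torch.tensor([[10.0], [8.0]])        # target_qf1(next_obs, a')
target_q2 = torch.tensor([[9.5], [8.4]])         # target_qf2(next_obs, a')
log_probs = torch.tensor([[-1.2], [-0.8]])       # log pi(a'|next_obs)

# Clipped double-Q soft value of the next state
target_v = torch.min(target_q1, target_q2) - alpha * log_probs

q_target    = rewards + (1. - terminals) * discount * target_v   # TwinSACQ / SAC
q_target_mt = rewards + discount * target_v                      # MTSAC: no terminal mask
```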
12 | import os 13 | import os.path as osp 14 | 15 | class RLAlgo(): 16 | """ 17 | Base RL Algorithm Framework 18 | """ 19 | def __init__(self, 20 | env = None, 21 | replay_buffer = None, 22 | collector = None, 23 | logger = None, 24 | continuous = None, 25 | discount=0.99, 26 | num_epochs = 3000, 27 | epoch_frames = 1000, 28 | max_episode_frames = 999, 29 | batch_size = 128, 30 | device = 'cpu', 31 | train_render = False, 32 | eval_episodes = 1, 33 | eval_render = False, 34 | save_interval = 100, 35 | save_dir = None 36 | ): 37 | 38 | self.env = env 39 | 40 | self.continuous = isinstance(self.env.action_space, gym.spaces.Box) 41 | 42 | self.replay_buffer = replay_buffer 43 | self.collector = collector 44 | # device specification 45 | self.device = device 46 | 47 | # environment relevant information 48 | self.discount = discount 49 | self.num_epochs = num_epochs 50 | self.epoch_frames = epoch_frames 51 | self.max_episode_frames = max_episode_frames 52 | 53 | self.train_render = train_render 54 | self.eval_render = eval_render 55 | 56 | # training information 57 | self.batch_size = batch_size 58 | self.training_update_num = 0 59 | self.sample_key = None 60 | 61 | # Logger & relevant setting 62 | self.logger = logger 63 | 64 | 65 | self.episode_rewards = deque(maxlen=30) 66 | self.training_episode_rewards = deque(maxlen=30) 67 | self.eval_episodes = eval_episodes 68 | 69 | self.save_interval = save_interval 70 | self.save_dir = save_dir 71 | if not osp.exists( self.save_dir ): 72 | os.mkdir( self.save_dir ) 73 | 74 | self.best_eval = None 75 | 76 | def start_epoch(self): 77 | pass 78 | 79 | def finish_epoch(self): 80 | return {} 81 | 82 | def pretrain(self): 83 | pass 84 | 85 | def update_per_epoch(self): 86 | pass 87 | 88 | def snapshot(self, prefix, epoch): 89 | for name, network in self.snapshot_networks: 90 | model_file_name="model_{}_{}.pth".format(name, epoch) 91 | model_path=osp.join(prefix, model_file_name) 92 | torch.save(network.state_dict(), model_path) 93 | 94 | def train(self): 95 | self.pretrain() 96 | total_frames = 0 97 | if hasattr(self, "pretrain_frames"): 98 | total_frames = self.pretrain_frames 99 | 100 | self.start_epoch() 101 | 102 | for epoch in range(self.num_epochs): 103 | self.current_epoch = epoch 104 | start = time.time() 105 | 106 | self.start_epoch() 107 | 108 | explore_start_time = time.time() 109 | training_epoch_info = self.collector.train_one_epoch() 110 | for reward in training_epoch_info["train_rewards"]: 111 | self.training_episode_rewards.append(reward) 112 | explore_time = time.time() - explore_start_time 113 | 114 | train_start_time = time.time() 115 | self.update_per_epoch() 116 | train_time = time.time() - train_start_time 117 | 118 | finish_epoch_info = self.finish_epoch() 119 | 120 | eval_start_time = time.time() 121 | eval_infos = self.collector.eval_one_epoch() 122 | eval_time = time.time() - eval_start_time 123 | 124 | total_frames += self.collector.active_worker_nums * self.epoch_frames 125 | 126 | infos = {} 127 | 128 | for reward in eval_infos["eval_rewards"]: 129 | self.episode_rewards.append(reward) 130 | # del eval_infos["eval_rewards"] 131 | 132 | if self.best_eval is None or \ 133 | np.mean(eval_infos["eval_rewards"]) > self.best_eval: 134 | self.best_eval = np.mean(eval_infos["eval_rewards"]) 135 | self.snapshot(self.save_dir, 'best') 136 | del eval_infos["eval_rewards"] 137 | 138 | infos["Running_Average_Rewards"] = np.mean(self.episode_rewards) 139 | infos["Train_Epoch_Reward"] = training_epoch_info["train_epoch_reward"] 140 | 
infos["Running_Training_Average_Rewards"] = np.mean( 141 | self.training_episode_rewards) 142 | infos["Explore_Time"] = explore_time 143 | infos["Train___Time"] = train_time 144 | infos["Eval____Time"] = eval_time 145 | infos.update(eval_infos) 146 | infos.update(finish_epoch_info) 147 | 148 | self.logger.add_epoch_info(epoch, total_frames, 149 | time.time() - start, infos ) 150 | 151 | if epoch % self.save_interval == 0: 152 | self.snapshot(self.save_dir, epoch) 153 | 154 | self.snapshot(self.save_dir, "finish") 155 | self.collector.terminate() 156 | 157 | def update(self, batch): 158 | raise NotImplementedError 159 | 160 | def _update_target_networks(self): 161 | if self.use_soft_update: 162 | for net, target_net in self.target_networks: 163 | atu.soft_update_from_to(net, target_net, self.tau) 164 | else: 165 | if self.training_update_num % self.target_hard_update_period == 0: 166 | for net, target_net in self.target_networks: 167 | atu.copy_model_params_from_to(net, target_net) 168 | 169 | @property 170 | def networks(self): 171 | return [ 172 | ] 173 | 174 | @property 175 | def snapshot_networks(self): 176 | return [ 177 | ] 178 | 179 | @property 180 | def target_networks(self): 181 | return [ 182 | ] 183 | 184 | def to(self, device): 185 | for net in self.networks: 186 | net.to(device) 187 | -------------------------------------------------------------------------------- /torchrl/algo/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def quantile_regression_loss(coefficient, source, target): 6 | diff = target.unsqueeze(-1) - source.unsqueeze(1) 7 | loss = huber(diff) * (coefficient - (diff.detach() < 0).float()).abs() 8 | loss = loss.mean() 9 | return loss 10 | 11 | 12 | def huber(x, k=1.0): 13 | return torch.where(x.abs() < k, 0.5 * x.pow(2), k * (x.abs() - 0.5 * k)) 14 | 15 | 16 | def soft_update_from_to(source, target, tau): 17 | for target_param, param in zip(target.parameters(), source.parameters()): 18 | target_param.data.copy_( 19 | target_param.data * (1.0 - tau) + param.data * tau 20 | ) 21 | 22 | 23 | def copy_model_params_from_to(source, target): 24 | for target_param, param in zip(target.parameters(), source.parameters()): 25 | target_param.data.copy_(param.data) 26 | 27 | 28 | def unsqe_cat_gather(tensor_list, idx, dim = 1 ): 29 | tensor_list = [tensor.unsqueeze(dim) for tensor in tensor_list] 30 | tensors = torch.cat(tensor_list, dim = dim) 31 | 32 | target_shape = list(tensors.shape) 33 | target_shape[dim] = 1 34 | 35 | view_shape = list(idx.shape) + [1] * (len(target_shape) - len(idx.shape)) 36 | idx = idx.view(view_shape) 37 | idx = idx.expand(tuple(target_shape)) 38 | tensors = tensors.gather(dim, idx).squeeze(dim) 39 | return tensors 40 | 41 | 42 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 43 | """Decreases the learning rate linearly""" 44 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 45 | for param_group in optimizer.param_groups: 46 | param_group['lr'] = lr 47 | 48 | -------------------------------------------------------------------------------- /torchrl/collector/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseCollector 2 | from .mt import MultiTaskCollectorBase 3 | -------------------------------------------------------------------------------- /torchrl/collector/base.py: 
-------------------------------------------------------------------------------- 1 | from collections import deque 2 | import torch 3 | import torch.multiprocessing as mp 4 | import copy 5 | import numpy as np 6 | import gym 7 | 8 | class EnvInfo(): 9 | def __init__(self, 10 | env, 11 | device, 12 | train_render, 13 | eval_render, 14 | epoch_frames, 15 | eval_episodes, 16 | max_episode_frames, 17 | continuous, 18 | env_rank): 19 | 20 | self.current_step = 0 21 | 22 | self.env = env 23 | self.device = device 24 | self.train_render = train_render 25 | self.eval_render = eval_render 26 | self.epoch_frames = epoch_frames 27 | self.eval_episodes = eval_episodes 28 | self.max_episode_frames = max_episode_frames 29 | self.continuous = continuous 30 | self.env_rank = env_rank 31 | 32 | # For Parallel Async 33 | self.env_cls = None 34 | self.env_args = None 35 | 36 | def start_episode(self): 37 | self.current_step = 0 38 | 39 | def finish_episode(self): 40 | pass 41 | 42 | 43 | class BaseCollector: 44 | 45 | def __init__( 46 | self, 47 | env, pf, replay_buffer, 48 | train_render=False, 49 | eval_episodes=1, 50 | eval_render=False, 51 | epoch_frames=1000, 52 | device='cpu', 53 | max_episode_frames = 999): 54 | 55 | self.pf = pf 56 | self.replay_buffer = replay_buffer 57 | 58 | self.env = env 59 | self.env.train() 60 | continuous = isinstance(self.env.action_space, gym.spaces.Box) 61 | self.train_render = train_render 62 | 63 | self.eval_env = copy.copy(env) 64 | self.eval_env._reward_scale = 1 65 | self.eval_episodes = eval_episodes 66 | self.eval_render = eval_render 67 | 68 | self.env_info = EnvInfo( 69 | env, device, train_render, eval_render, 70 | epoch_frames, eval_episodes, 71 | max_episode_frames, continuous, None 72 | ) 73 | self.c_ob = { 74 | "ob": self.env.reset() 75 | } 76 | 77 | self.train_rew = 0 78 | self.training_episode_rewards = deque(maxlen=20) 79 | 80 | # device specification 81 | self.device = device 82 | 83 | self.to(self.device) 84 | 85 | self.epoch_frames = epoch_frames 86 | self.max_episode_frames = max_episode_frames 87 | 88 | self.worker_nums = 1 89 | self.active_worker_nums = 1 90 | 91 | @classmethod 92 | def take_actions(cls, funcs, env_info, ob_info, replay_buffer): 93 | 94 | pf = funcs["pf"] 95 | ob = ob_info["ob"] 96 | out = pf.explore( torch.Tensor( ob ).to(env_info.device).unsqueeze(0)) 97 | act = out["action"] 98 | act = act.detach().cpu().numpy() 99 | 100 | if not env_info.continuous: 101 | act = act[0] 102 | 103 | if type(act) is not int: 104 | if np.isnan(act).any(): 105 | print("NaN detected. 
BOOM") 106 | exit() 107 | 108 | next_ob, reward, done, info = env_info.env.step(act) 109 | if env_info.train_render: 110 | env_info.env.render() 111 | env_info.current_step += 1 112 | 113 | sample_dict = { 114 | "obs":ob, 115 | "next_obs": next_ob, 116 | "acts": act, 117 | "rewards": [reward], 118 | "terminals": [done], 119 | "time_limits": [True if "time_limit" in info else False] 120 | } 121 | 122 | if done or env_info.current_step >= env_info.max_episode_frames: 123 | next_ob = env_info.env.reset() 124 | env_info.finish_episode() 125 | env_info.start_episode() # reset current_step 126 | 127 | replay_buffer.add_sample( sample_dict, env_info.env_rank) 128 | 129 | return next_ob, done, reward, info 130 | 131 | def terminate(self): 132 | pass 133 | 134 | def train_one_epoch(self): 135 | train_rews = [] 136 | train_epoch_reward = 0 137 | self.env.train() 138 | for _ in range(self.epoch_frames): 139 | # Sample actions 140 | next_ob, done, reward, _ = self.__class__.take_actions(self.funcs, 141 | self.env_info, self.c_ob, self.replay_buffer ) 142 | self.c_ob["ob"] = next_ob 143 | # print(self.c_ob) 144 | self.train_rew += reward 145 | train_epoch_reward += reward 146 | if done: 147 | self.training_episode_rewards.append(self.train_rew) 148 | train_rews.append(self.train_rew) 149 | self.train_rew = 0 150 | 151 | return { 152 | 'train_rewards':train_rews, 153 | 'train_epoch_reward':train_epoch_reward 154 | } 155 | 156 | def eval_one_epoch(self): 157 | 158 | eval_infos = {} 159 | eval_rews = [] 160 | 161 | done = False 162 | 163 | self.eval_env = copy.copy(self.env) 164 | self.eval_env.eval() 165 | # print(self.eval_env._obs_mean) 166 | traj_lens = [] 167 | for _ in range(self.eval_episodes): 168 | 169 | eval_ob = self.eval_env.reset() 170 | rew = 0 171 | traj_len = 0 172 | while not done: 173 | act = self.pf.eval_act(torch.Tensor(eval_ob).to( 174 | self.device).unsqueeze(0)) 175 | eval_ob, r, done, _ = self.eval_env.step(act) 176 | rew += r 177 | traj_len += 1 178 | if self.eval_render: 179 | self.eval_env.render() 180 | 181 | eval_rews.append(rew) 182 | traj_lens.append(traj_len) 183 | 184 | done = False 185 | 186 | eval_infos["eval_rewards"] = eval_rews 187 | eval_infos["eval_traj_length"] = np.mean(traj_lens) 188 | return eval_infos 189 | 190 | def to(self, device): 191 | for func in self.funcs: 192 | self.funcs[func].to(device) 193 | 194 | @property 195 | def funcs(self): 196 | return { 197 | "pf": self.pf 198 | } 199 | -------------------------------------------------------------------------------- /torchrl/collector/mt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .base import BaseCollector 5 | 6 | class MultiTaskCollectorBase(BaseCollector): 7 | 8 | @classmethod 9 | def take_actions(cls, funcs, env_info, ob_info, replay_buffer): 10 | 11 | pf = funcs["pf"] 12 | ob = ob_info["ob"] 13 | # idx = ob_info["task_index"] 14 | task_idx = env_info.env.active_task 15 | out = pf.explore( torch.Tensor( ob ).to(env_info.device).unsqueeze(0), 16 | [task_idx]) 17 | act = out["action"] 18 | act = act[0] 19 | act = act.detach().cpu().numpy() 20 | 21 | if not env_info.continuous: 22 | act = act[0] 23 | 24 | if type(act) is not int: 25 | if np.isnan(act).any(): 26 | print("NaN detected. 
BOOM") 27 | exit() 28 | 29 | next_ob, reward, done, info = env_info.env.step(act) 30 | if env_info.train_render: 31 | env_info.env.render() 32 | env_info.current_step += 1 33 | 34 | sample_dict = { 35 | "obs":ob, 36 | "next_obs": next_ob, 37 | "acts": act, 38 | "task_idx": task_idx, 39 | "rewards": [reward], 40 | "terminals": [done] 41 | } 42 | 43 | if done or env_info.current_step >= env_info.max_episode_frames: 44 | next_ob = env_info.env.reset() 45 | env_info.finish_episode() 46 | env_info.start_episode() # reset current_step 47 | 48 | replay_buffer.add_sample( sample_dict, env_info.env_rank) 49 | 50 | return next_ob, done, reward, info 51 | 52 | -------------------------------------------------------------------------------- /torchrl/collector/para/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ParallelCollector 2 | from .base import AsyncParallelCollector -------------------------------------------------------------------------------- /torchrl/collector/para/base.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.multiprocessing as mp 4 | import copy 5 | import numpy as np 6 | import gym 7 | from collections import deque 8 | 9 | from torchrl.collector.base import BaseCollector 10 | from torchrl.collector.base import EnvInfo 11 | 12 | from torchrl.replay_buffers.shared import SharedBaseReplayBuffer 13 | 14 | TIMEOUT_CHILD = 200 15 | 16 | class ParallelCollector(BaseCollector): 17 | 18 | def __init__(self, 19 | env, pf, replay_buffer, 20 | env_cls, env_args, 21 | train_epochs, 22 | eval_epochs, 23 | worker_nums = 4, 24 | eval_worker_nums = 1, 25 | **kwargs): 26 | 27 | super().__init__( 28 | env, pf, replay_buffer, 29 | **kwargs) 30 | 31 | self.env_cls = env_cls 32 | self.env_args = env_args 33 | 34 | self.env_info.device = 'cpu' # CPU For multiprocess sampling 35 | self.shared_funcs = copy.deepcopy(self.funcs) 36 | for key in self.shared_funcs: 37 | self.shared_funcs[key].to(self.env_info.device) 38 | 39 | # assert isinstance(replay_buffer, SharedBaseReplayBuffer), \ 40 | # "Should Use Shared Replay buffer" 41 | self.replay_buffer = replay_buffer 42 | 43 | self.worker_nums = worker_nums 44 | self.active_worker_nums = worker_nums 45 | self.eval_worker_nums = eval_worker_nums 46 | 47 | self.manager = mp.Manager() 48 | self.train_epochs = train_epochs 49 | self.eval_epochs = eval_epochs 50 | self.start_worker() 51 | 52 | @staticmethod 53 | def train_worker_process(cls, shared_funcs, env_info, 54 | replay_buffer, shared_que, 55 | start_barrier, epochs ): 56 | 57 | replay_buffer.rebuild_from_tag() 58 | local_funcs = copy.deepcopy(shared_funcs) 59 | for key in local_funcs: 60 | local_funcs[key].to(env_info.device) 61 | 62 | # Rebuild Env 63 | env_info.env = env_info.env_cls(**env_info.env_args) 64 | 65 | c_ob = { 66 | "ob": env_info.env.reset() 67 | } 68 | train_rew = 0 69 | current_epoch = 0 70 | while True: 71 | start_barrier.wait() 72 | current_epoch += 1 73 | if current_epoch > epochs: 74 | break 75 | 76 | for key in shared_funcs: 77 | local_funcs[key].load_state_dict(shared_funcs[key].state_dict()) 78 | 79 | train_rews = [] 80 | train_epoch_reward = 0 81 | 82 | for _ in range(env_info.epoch_frames): 83 | next_ob, done, reward, _ = cls.take_actions(local_funcs, env_info, c_ob, replay_buffer ) 84 | c_ob["ob"] = next_ob 85 | train_rew += reward 86 | train_epoch_reward += reward 87 | if done: 88 | train_rews.append(train_rew) 89 | train_rew = 0 90 | 91 
| shared_que.put({ 92 | 'train_rewards':train_rews, 93 | 'train_epoch_reward':train_epoch_reward 94 | }) 95 | 96 | @staticmethod 97 | def eval_worker_process(shared_pf, 98 | env_info, shared_que, start_barrier, epochs): 99 | 100 | pf = copy.deepcopy(shared_pf).to(env_info.device) 101 | 102 | # Rebuild Env 103 | env_info.env = env_info.env_cls(**env_info.env_args) 104 | 105 | env_info.env.eval() 106 | env_info.env._reward_scale = 1 107 | current_epoch = 0 108 | 109 | while True: 110 | start_barrier.wait() 111 | current_epoch += 1 112 | if current_epoch > epochs: 113 | break 114 | pf.load_state_dict(shared_pf.state_dict()) 115 | 116 | eval_rews = [] 117 | 118 | done = False 119 | for _ in range(env_info.eval_episodes): 120 | 121 | eval_ob = env_info.env.reset() 122 | rew = 0 123 | while not done: 124 | act = pf.eval_act( torch.Tensor( eval_ob ).to(env_info.device).unsqueeze(0)) 125 | eval_ob, r, done, _ = env_info.env.step( act ) 126 | rew += r 127 | if env_info.eval_render: 128 | env_info.env.render() 129 | 130 | eval_rews.append(rew) 131 | done = False 132 | 133 | shared_que.put({ 134 | 'eval_rewards':eval_rews 135 | }) 136 | 137 | 138 | def start_worker(self): 139 | self.workers = [] 140 | self.shared_que = self.manager.Queue(self.worker_nums) 141 | self.start_barrier = mp.Barrier(self.worker_nums+1) 142 | 143 | self.eval_workers = [] 144 | self.eval_shared_que = self.manager.Queue(self.eval_worker_nums) 145 | self.eval_start_barrier = mp.Barrier(self.eval_worker_nums+1) 146 | 147 | self.env_info.env_cls = self.env_cls 148 | self.env_info.env_args = self.env_args 149 | 150 | for i in range(self.worker_nums): 151 | self.env_info.env_rank = i 152 | p = mp.Process( 153 | target=self.__class__.train_worker_process, 154 | args=( self.__class__, self.shared_funcs, 155 | self.env_info, self.replay_buffer, 156 | self.shared_que, self.start_barrier, 157 | self.train_epochs)) 158 | p.start() 159 | self.workers.append(p) 160 | 161 | for i in range(self.eval_worker_nums): 162 | eval_p = mp.Process( 163 | target=self.__class__.eval_worker_process, 164 | args=(self.shared_funcs["pf"], 165 | self.env_info, self.eval_shared_que, self.eval_start_barrier, 166 | self.eval_epochs)) 167 | eval_p.start() 168 | self.eval_workers.append(eval_p) 169 | 170 | def terminate(self): 171 | self.start_barrier.wait() 172 | self.eval_start_barrier.wait() 173 | for p in self.workers: 174 | p.join() 175 | 176 | for p in self.eval_workers: 177 | p.join() 178 | 179 | def train_one_epoch(self): 180 | self.start_barrier.wait() 181 | train_rews = [] 182 | train_epoch_reward = 0 183 | 184 | for key in self.shared_funcs: 185 | self.shared_funcs[key].load_state_dict(self.funcs[key].state_dict()) 186 | for _ in range(self.worker_nums): 187 | worker_rst = self.shared_que.get() 188 | train_rews += worker_rst["train_rewards"] 189 | train_epoch_reward += worker_rst["train_epoch_reward"] 190 | 191 | return { 192 | 'train_rewards':train_rews, 193 | 'train_epoch_reward':train_epoch_reward 194 | } 195 | 196 | def eval_one_epoch(self): 197 | self.eval_start_barrier.wait() 198 | eval_rews = [] 199 | 200 | self.shared_funcs["pf"].load_state_dict(self.funcs["pf"].state_dict()) 201 | 202 | for _ in range(self.eval_worker_nums): 203 | worker_rst = self.eval_shared_que.get() 204 | eval_rews += worker_rst["eval_rewards"] 205 | 206 | return { 207 | 'eval_rewards':eval_rews, 208 | } 209 | 210 | @property 211 | def funcs(self): 212 | return { 213 | "pf": self.pf 214 | } 215 | 216 | class AsyncParallelCollector(ParallelCollector): 217 | def 
start_worker(self): 218 | self.workers = [] 219 | self.shared_que = self.manager.Queue(self.worker_nums) 220 | self.start_barrier = mp.Barrier(self.worker_nums) 221 | 222 | self.eval_workers = [] 223 | self.eval_shared_que = self.manager.Queue(self.eval_worker_nums) 224 | self.eval_start_barrier = mp.Barrier(self.eval_worker_nums) 225 | 226 | self.env_info.env_cls = self.env_cls 227 | self.env_info.env_args = self.env_args 228 | 229 | for i in range(self.worker_nums): 230 | self.env_info.env_rank = i 231 | p = mp.Process( 232 | target=self.__class__.train_worker_process, 233 | args=( self.__class__, self.shared_funcs, 234 | self.env_info, self.replay_buffer, 235 | self.shared_que, self.start_barrier, 236 | self.train_epochs)) 237 | p.start() 238 | self.workers.append(p) 239 | 240 | for i in range(self.eval_worker_nums): 241 | eval_p = mp.Process( 242 | target=self.__class__.eval_worker_process, 243 | args=(self.pf, 244 | self.env_info, self.eval_shared_que, self.eval_start_barrier, 245 | self.eval_epochs)) 246 | eval_p.start() 247 | self.eval_workers.append(eval_p) 248 | 249 | def terminate(self): 250 | # self.eval_start_barrier.wait() 251 | for p in self.workers: 252 | p.join() 253 | 254 | for p in self.eval_workers: 255 | p.join() 256 | 257 | def train_one_epoch(self): 258 | train_rews = [] 259 | train_epoch_reward = 0 260 | 261 | for key in self.shared_funcs: 262 | self.shared_funcs[key].load_state_dict(self.funcs[key].state_dict()) 263 | for _ in range(self.worker_nums): 264 | worker_rst = self.shared_que.get() 265 | train_rews += worker_rst["train_rewards"] 266 | train_epoch_reward += worker_rst["train_epoch_reward"] 267 | 268 | return { 269 | 'train_rewards':train_rews, 270 | 'train_epoch_reward':train_epoch_reward 271 | } 272 | 273 | def eval_one_epoch(self): 274 | # self.eval_start_barrier.wait() 275 | eval_rews = [] 276 | self.shared_funcs["pf"].load_state_dict(self.funcs["pf"].state_dict()) 277 | for _ in range(self.eval_worker_nums): 278 | worker_rst = self.eval_shared_que.get() 279 | eval_rews += worker_rst["eval_rewards"] 280 | 281 | return { 282 | 'eval_rewards':eval_rews, 283 | } 284 | -------------------------------------------------------------------------------- /torchrl/collector/para/mt.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import copy 5 | import numpy as np 6 | 7 | from .base import ParallelCollector 8 | import torch.multiprocessing as mp 9 | 10 | import torchrl.policies as policies 11 | 12 | class SingleTaskParallelCollectorBase(ParallelCollector): 13 | 14 | def __init__(self, 15 | reset_idx = False, 16 | **kwargs): 17 | self.reset_idx = reset_idx 18 | super().__init__(**kwargs) 19 | 20 | @staticmethod 21 | def eval_worker_process(shared_pf, 22 | env_info, shared_que, start_barrier, terminate_mark, reset_idx): 23 | 24 | pf = copy.deepcopy(shared_pf) 25 | idx_flag = isinstance(pf, policies.MultiHeadGuassianContPolicy) 26 | 27 | env_info.env.eval() 28 | env_info.env._reward_scale = 1 29 | 30 | while True: 31 | start_barrier.wait() 32 | if terminate_mark.value == 1: 33 | break 34 | pf.load_state_dict(shared_pf.state_dict()) 35 | 36 | eval_rews = [] 37 | 38 | done = False 39 | success = 0 40 | for idx in range(env_info.eval_episodes): 41 | if reset_idx: 42 | eval_ob = env_info.env.reset_with_index(idx) 43 | else: 44 | eval_ob = env_info.env.reset() 45 | rew = 0 46 | current_success = 0 47 | while not done: 48 | # act = pf.eval( torch.Tensor( eval_ob ).to(env_info.device).unsqueeze(0)) 49 | if 
idx_flag: 50 | act = pf.eval_act( torch.Tensor( eval_ob ).to(env_info.device).unsqueeze(0), torch.LongTensor([env_info.env_rank]).to(env_info.device) ) # NOTE: task index assumed to be the worker's env_rank 51 | else: 52 | act = pf.eval_act( torch.Tensor( eval_ob ).to(env_info.device).unsqueeze(0)) 53 | eval_ob, r, done, info = env_info.env.step( act ) 54 | rew += r 55 | if env_info.eval_render: 56 | env_info.env.render() 57 | 58 | current_success = max(current_success, info["success"]) 59 | 60 | eval_rews.append(rew) 61 | done = False 62 | success += current_success 63 | 64 | shared_que.put({ 65 | 'eval_rewards': eval_rews, 66 | 'success_rate': success / env_info.eval_episodes 67 | }) 68 | 69 | def start_worker(self): 70 | self.workers = [] 71 | self.shared_que = self.manager.Queue() 72 | self.start_barrier = mp.Barrier(self.worker_nums+1) 73 | self.terminate_mark = mp.Value( 'i', 0 ) 74 | 75 | self.eval_workers = [] 76 | self.eval_shared_que = self.manager.Queue() 77 | self.eval_start_barrier = mp.Barrier(self.eval_worker_nums+1) 78 | 79 | for i in range(self.worker_nums): 80 | self.env_info.env_rank = i 81 | p = mp.Process( 82 | target=self.__class__.train_worker_process, 83 | args=( self.__class__, self.funcs, 84 | self.env_info, self.replay_buffer, 85 | self.shared_que, self.start_barrier, 86 | self.terminate_mark)) 87 | p.start() 88 | self.workers.append(p) 89 | 90 | for i in range(self.eval_worker_nums): 91 | eval_p = mp.Process( 92 | target=self.__class__.eval_worker_process, 93 | args=(self.pf, 94 | self.env_info, self.eval_shared_que, self.eval_start_barrier, 95 | self.terminate_mark, self.reset_idx)) 96 | eval_p.start() 97 | self.eval_workers.append(eval_p) 98 | 99 | def eval_one_epoch(self): 100 | self.eval_start_barrier.wait() 101 | eval_rews = [] 102 | mean_success_rate = 0 103 | for _ in range(self.eval_worker_nums): 104 | worker_rst = self.eval_shared_que.get() 105 | eval_rews += worker_rst["eval_rewards"] 106 | mean_success_rate += worker_rst["success_rate"] 107 | 108 | return { 109 | 'eval_rewards':eval_rews, 110 | 'mean_success_rate': mean_success_rate / self.eval_worker_nums 111 | } 112 | 113 | -------------------------------------------------------------------------------- /torchrl/env/__init__.py: -------------------------------------------------------------------------------- 1 | from .get_env import get_env -------------------------------------------------------------------------------- /torchrl/env/base_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | 5 | class BaseWrapper(gym.Wrapper): 6 | def __init__(self, env): 7 | super(BaseWrapper, self).__init__(env) 8 | self._wrapped_env = env 9 | self.training = True 10 | 11 | def train(self): 12 | if isinstance(self._wrapped_env, BaseWrapper): 13 | self._wrapped_env.train() 14 | self.training = True 15 | 16 | def eval(self): 17 | if isinstance(self._wrapped_env, BaseWrapper): 18 | self._wrapped_env.eval() 19 | self.training = False 20 | 21 | def render(self, mode='human', **kwargs): 22 | return self._wrapped_env.render(mode=mode, **kwargs) 23 | 24 | def __getattr__(self, attr): 25 | if attr == '_wrapped_env': 26 | raise AttributeError() 27 | return getattr(self._wrapped_env, attr) 28 | 29 | 30 | class RewardShift(gym.RewardWrapper, BaseWrapper): 31 | def __init__(self, env, reward_scale=1): 32 | super(RewardShift, self).__init__(env) 33 | self._reward_scale = reward_scale 34 | 35 | def reward(self, reward): 36 | if self.training: 37 | return self._reward_scale * reward 38 | else: 39 | return reward 40 | 41 | 42 | def 
update_mean_var_count_from_moments( 43 | mean, var, count, 44 | batch_mean, batch_var, batch_count): 45 | """ 46 | Imported From OpenAI Baseline 47 | """ 48 | delta = batch_mean - mean 49 | tot_count = count + batch_count 50 | 51 | new_mean = mean + delta * batch_count / tot_count 52 | m_a = var * count 53 | m_b = batch_var * batch_count 54 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 55 | new_var = M2 / tot_count 56 | new_count = tot_count 57 | 58 | return new_mean, new_var, new_count 59 | 60 | 61 | class NormObs(gym.ObservationWrapper, BaseWrapper): 62 | """ 63 | Normalized Observation => Optional, Use Momentum 64 | """ 65 | def __init__(self, env, epsilon=1e-4, clipob=10.): 66 | super(NormObs, self).__init__(env) 67 | self.count = epsilon 68 | self.clipob = clipob 69 | self._obs_mean = np.zeros(env.observation_space.shape[0]) 70 | self._obs_var = np.ones(env.observation_space.shape[0]) 71 | 72 | def _update_obs_estimate(self, obs): 73 | self._obs_mean, self._obs_var, self.count = update_mean_var_count_from_moments( 74 | self._obs_mean, self._obs_var, self.count, obs, np.zeros_like(obs), 1) 75 | 76 | def _apply_normalize_obs(self, raw_obs): 77 | if self.training: 78 | self._update_obs_estimate(raw_obs) 79 | return np.clip( 80 | (raw_obs - self._obs_mean) / (np.sqrt(self._obs_var) + 1e-8), 81 | -self.clipob, self.clipob) 82 | 83 | def observation(self, observation): 84 | return self._apply_normalize_obs(observation) 85 | 86 | 87 | class NormRet(BaseWrapper): 88 | def __init__(self, env, discount=0.99, epsilon=1e-4): 89 | super(NormRet, self).__init__(env) 90 | self._ret = 0 91 | self.count = 1e-4 92 | self.ret_mean = 0 93 | self.ret_var = 1 94 | self.discount = discount 95 | self.epsilon = 1e-4 96 | 97 | def step(self, act): 98 | obs, rews, done, infos = self.env.step(act) 99 | if self.training: 100 | self.ret = self.ret * self.discount + rews 101 | # if self.ret_rms: 102 | self.ret_mean, self.ret_var, self.count = update_mean_var_count_from_moments( 103 | self.ret_mean, self.ret_var, self.count, self.ret, 0, 1) 104 | rews = rews / np.sqrt(self.ret_var + self.epsilon) 105 | self.ret *= (1-done) 106 | # print(self.count, self.ret_mean, self.ret_var) 107 | # print(self.training, rews) 108 | return obs, rews, done, infos 109 | 110 | def reset(self, **kwargs): 111 | self.ret = 0 112 | return self.env.reset(**kwargs) 113 | -------------------------------------------------------------------------------- /torchrl/env/continuous_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | from .base_wrapper import BaseWrapper 5 | from gym.spaces import Box 6 | 7 | 8 | class AugObs(gym.ObservationWrapper, BaseWrapper): 9 | def __init__(self, env, env_rank, num_tasks, max_obs_dim, meta_env_params): 10 | super(AugObs, self).__init__(env) 11 | self.env_rank = env_rank 12 | self.num_tasks = num_tasks 13 | self.task_onehot = np.zeros(shape=(num_tasks,), dtype=np.float32) 14 | self.task_onehot[env_rank] = 1. 
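        # observation() below appends the goal vector and/or this one-hot task id
        # (selected by `obs_type`) to the raw observation and zero-pads it up to
        # `max_obs_dim`, so every task in the multi-task suite shares a single
        # observation dimensionality.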
15 | self.max_obs_dim = max_obs_dim 16 | self.obs_type = meta_env_params["obs_type"] 17 | self.obs_dim = np.prod(env.observation_space.shape) 18 | 19 | if self.obs_type == "with_goal_and_id": 20 | self.obs_dim += num_tasks 21 | self.obs_dim += np.prod(env._state_goal.shape) 22 | elif self.obs_type == "with_goal": 23 | self.obs_dim += np.prod(env._state_goal.shape) 24 | elif self.obs_type == "with_goal_id": 25 | self.obs_dim += num_tasks 26 | 27 | if self.obs_dim < self.max_obs_dim: 28 | self.pedding = np.zeros(self.max_obs_dim - self.obs_dim) 29 | 30 | self.repeat_times = meta_env_params["repeat_times"] \ 31 | if "repeat_times" in meta_env_params else 1 32 | 33 | # self.set_observation_space() 34 | 35 | # if self.obs_type == 'plain': 36 | # self.observation_space = self._wrapped_env.observation_space 37 | # else: 38 | # plain_high = self._wrapped_env.observation_space.high 39 | # plain_low = self._wrapped_env.observation_space.low 40 | # goal_high = self._wrapped_env.goal_space.high 41 | # goal_low = self._wrapped_env.goal_space.low 42 | # if self.obs_type == 'with_goal': 43 | # self.observation_space = Box( 44 | # high=np.concatenate([plain_high, goal_high] + [goal_high] * (self.repeat_times -1) ), 45 | # low=np.concatenate([plain_low, goal_low] + [goal_low] * (self.repeat_times -1 ))) 46 | # elif self.obs_type == 'with_goal_id' and self._fully_discretized: 47 | # goal_id_low = np.zeros(shape=(self._n_discrete_goals * self.repeat_times,)) 48 | # goal_id_high = np.ones(shape=(self._n_discrete_goals * self.repeat_times,)) 49 | # self.observation_space = Box( 50 | # high=np.concatenate([plain_high, goal_id_low,]), 51 | # low=np.concatenate([plain_low, goal_id_high,])) 52 | # elif self.obs_type == 'with_goal_and_id' and self._fully_discretized: 53 | # goal_id_low = np.zeros(shape=(self._n_discrete_goals,)) 54 | # goal_id_high = np.ones(shape=(self._n_discrete_goals,)) 55 | # self.observation_space = Box( 56 | # high=np.concatenate([plain_high, goal_id_low, goal_high] + [goal_id_low, goal_high] * (self.repeat_times - 1) ), 57 | # low=np.concatenate([plain_low, goal_id_high, goal_low] + [goal_id_high, goal_low] * (self.repeat_times - 1) )) 58 | # else: 59 | # raise NotImplementedError 60 | 61 | 62 | def observation(self, observation): 63 | 64 | if self.obs_type == "with_goal_and_id": 65 | aug_ob = np.concatenate([self._wrapped_env._state_goal, 66 | self.task_onehot]) 67 | elif self.obs_type == "with_goal": 68 | aug_ob = self._wrapped_env._state_goal 69 | elif self.obs_type == "with_goal_id": 70 | aug_ob = self.task_onehot 71 | elif self.obs_type == "plain": 72 | aug_ob = [] 73 | 74 | aug_ob = np.concatenate([aug_ob] * self.repeat_times) 75 | if self.obs_dim < self.max_obs_dim: 76 | observation = np.concatenate([observation, self.pedding]) 77 | observation = np.concatenate([observation, aug_ob]) 78 | return observation 79 | 80 | # # @property 81 | # def set_observation_space(self): 82 | # if self._obs_type == 'plain': 83 | # self.observation_space = self._wrapped_env.observation_space 84 | # else: 85 | # plain_high = self._wrapped_env.observation_space.high 86 | # plain_low = self._wrapped_env.observation_space.low 87 | # goal_high = self._wrapped_env.goal_space.high 88 | # goal_low = self._wrapped_env.goal_space.low 89 | # if self._obs_type == 'with_goal': 90 | # self.observation_space = Box( 91 | # high=np.concatenate([plain_high, goal_high] + [goal_high] * (self.repeat_times -1) ), 92 | # low=np.concatenate([plain_low, goal_low] + [goal_low] * (self.repeat_times -1 ))) 93 | # elif 
self._obs_type == 'with_goal_id' and self._fully_discretized: 94 | # goal_id_low = np.zeros(shape=(self._n_discrete_goals * self.repeat_times,)) 95 | # goal_id_high = np.ones(shape=(self._n_discrete_goals * self.repeat_times,)) 96 | # self.observation_space = Box( 97 | # high=np.concatenate([plain_high, goal_id_low,]), 98 | # low=np.concatenate([plain_low, goal_id_high,])) 99 | # elif self._obs_type == 'with_goal_and_id' and self._fully_discretized: 100 | # goal_id_low = np.zeros(shape=(self._n_discrete_goals,)) 101 | # goal_id_high = np.ones(shape=(self._n_discrete_goals,)) 102 | # self.observation_space = Box( 103 | # high=np.concatenate([plain_high, goal_id_low, goal_high] + [goal_id_low, goal_high] * (self.repeat_times - 1) ), 104 | # low=np.concatenate([plain_low, goal_id_high, goal_low] + [goal_id_high, goal_low] * (self.repeat_times - 1) )) 105 | # else: 106 | # raise NotImplementedError 107 | 108 | 109 | class NormObs(gym.ObservationWrapper, BaseWrapper): 110 | """ 111 | Normalized Observation => Optional, Use Momentum 112 | """ 113 | def __init__( self, env, obs_alpha = 0.001 ): 114 | super(NormObs,self).__init__(env) 115 | self._obs_alpha = obs_alpha 116 | self._obs_mean = np.zeros(env.observation_space.shape[0]) 117 | self._obs_var = np.ones(env.observation_space.shape[0]) 118 | 119 | # Check Trajectory is ended by time limit or not 120 | class TimeLimitAugment(gym.Wrapper): 121 | def step(self, action): 122 | obs, rew, done, info = self.env.step(action) 123 | if done and self.env._max_episode_steps == self.env._elapsed_steps: 124 | info['time_limit'] = True 125 | return obs, rew, done, info 126 | 127 | def reset(self, **kwargs): 128 | return self.env.reset(**kwargs) 129 | 130 | 131 | class NormAct(gym.ActionWrapper, BaseWrapper): 132 | """ 133 | Normalized Action => [ -1, 1 ] 134 | """ 135 | def __init__(self, env): 136 | super(NormAct, self).__init__(env) 137 | ub = np.ones(self.env.action_space.shape) 138 | self.action_space = gym.spaces.Box(-1 * ub, ub) 139 | 140 | def action(self, action): 141 | lb = self.env.action_space.low 142 | ub = self.env.action_space.high 143 | scaled_action = lb + (action + 1.) 
* 0.5 * (ub - lb) 144 | return np.clip(scaled_action, lb, ub) 145 | -------------------------------------------------------------------------------- /torchrl/env/get_env.py: -------------------------------------------------------------------------------- 1 | from .continuous_wrapper import * 2 | from .base_wrapper import * 3 | import os 4 | import gym 5 | import mujoco_py 6 | import xml.etree.ElementTree as ET 7 | 8 | 9 | def wrap_continuous_env(env, obs_norm, reward_scale): 10 | env = RewardShift(env, reward_scale) 11 | if obs_norm: 12 | return NormObs(env) 13 | return env 14 | 15 | 16 | def get_env( env_id, env_param ): 17 | # env = gym.make(env_id) 18 | # if "customize" in env_param: 19 | # env = customized_mujoco(env, env_param["customize"]) 20 | # del env_param["customize"] 21 | # env = BaseWrapper(env) 22 | env = BaseWrapper(gym.make(env_id)) 23 | if "rew_norm" in env_param: 24 | env = NormRet(env, **env_param["rew_norm"]) 25 | del env_param["rew_norm"] 26 | 27 | ob_space = env.observation_space 28 | env = wrap_continuous_env(env, **env_param) 29 | 30 | if str(env.__class__.__name__).find('TimeLimit') >= 0: 31 | env = TimeLimitAugment(env) 32 | 33 | act_space = env.action_space 34 | if isinstance(act_space, gym.spaces.Box): 35 | return NormAct(env) 36 | return env 37 | -------------------------------------------------------------------------------- /torchrl/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .nets import * 2 | from .base import * 3 | from .init import * 4 | -------------------------------------------------------------------------------- /torchrl/networks/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchrl.networks.init as init 6 | 7 | class MLPBase(nn.Module): 8 | def __init__(self, input_shape, hidden_shapes, activation_func=F.relu, init_func = init.basic_init, last_activation_func = None ): 9 | super().__init__() 10 | 11 | self.activation_func = activation_func 12 | self.fcs = [] 13 | if last_activation_func is not None: 14 | self.last_activation_func = last_activation_func 15 | else: 16 | self.last_activation_func = activation_func 17 | input_shape = np.prod(input_shape) 18 | 19 | self.output_shape = input_shape 20 | for i, next_shape in enumerate( hidden_shapes ): 21 | fc = nn.Linear(input_shape, next_shape) 22 | init_func(fc) 23 | self.fcs.append(fc) 24 | # set attr for pytorch to track parameters( device ) 25 | self.__setattr__("fc{}".format(i), fc) 26 | 27 | input_shape = next_shape 28 | self.output_shape = next_shape 29 | 30 | def forward(self, x): 31 | 32 | out = x 33 | for fc in self.fcs[:-1]: 34 | out = fc(out) 35 | out = self.activation_func(out) 36 | out = self.fcs[-1](out) 37 | out = self.last_activation_func(out) 38 | return out 39 | 40 | def calc_next_shape(input_shape, conv_info): 41 | """ 42 | take input shape per-layer conv-info as input 43 | """ 44 | out_channels, kernel_size, stride, padding = conv_info 45 | c, h, w = input_shape 46 | # for padding, dilation, kernel_size, stride in conv_info: 47 | h = int((h + 2*padding[0] - ( kernel_size[0] - 1 ) - 1 ) / stride[0] + 1) 48 | w = int((w + 2*padding[1] - ( kernel_size[1] - 1 ) - 1 ) / stride[1] + 1) 49 | return (out_channels, h, w ) 50 | 51 | class CNNBase(nn.Module): 52 | def __init__(self, input_shape, hidden_shapes, activation_func=F.relu, init_func = init.basic_init, 
last_activation_func = None ): 53 | super().__init__() 54 | 55 | current_shape = input_shape 56 | in_channels = input_shape[0] 57 | self.activation_func = activation_func 58 | if last_activation_func is not None: 59 | self.last_activation_func = last_activation_func 60 | else: 61 | self.last_activation_func = activation_func 62 | self.convs = [] 63 | self.output_shape = current_shape[0] * current_shape[1] * current_shape[2] 64 | for i, conv_info in enumerate( hidden_shapes ): 65 | out_channels, kernel_size, stride, padding = conv_info 66 | conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride, padding ) 67 | init_func(conv) 68 | self.convs.append(conv) 69 | # set attr for pytorch to track parameters( device ) 70 | self.__setattr__("conv{}".format(i), conv) 71 | 72 | in_channels = out_channels 73 | current_shape = calc_next_shape( current_shape, conv_info ) 74 | self.output_shape = current_shape[0] * current_shape[1] * current_shape[2] 75 | 76 | def forward(self, x): 77 | 78 | out = x 79 | for conv in self.convs[:-1]: 80 | out = conv(out) 81 | out = self.activation_func(out) 82 | 83 | out = self.convs[-1](out) 84 | out = self.last_activation_func(out) 85 | 86 | batch_size = out.size()[0] 87 | return out.view(batch_size, -1) 88 | -------------------------------------------------------------------------------- /torchrl/networks/init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | 4 | def _fanin_init(tensor, alpha = 0): 5 | size = tensor.size() 6 | if len(size) == 2: 7 | fan_in = size[0] 8 | elif len(size) > 2: 9 | fan_in = np.prod(size[1:]) 10 | else: 11 | raise Exception("Shape must be have dimension at least 2.") 12 | # bound = 1. / np.sqrt(fan_in) 13 | bound = np.sqrt( 1. 
/ ( (1 + alpha * alpha ) * fan_in) ) 14 | return tensor.data.uniform_(-bound, bound) 15 | 16 | def _uniform_init(tensor, param=3e-3): 17 | return tensor.data.uniform_(-param, param) 18 | 19 | def _constant_bias_init(tensor, constant = 0.1): 20 | tensor.data.fill_( constant ) 21 | 22 | def _normal_init(tensor, mean=0, std =1e-3): 23 | return tensor.data.normal_(mean,std) 24 | 25 | def layer_init(layer, weight_init = _fanin_init, bias_init = _constant_bias_init ): 26 | weight_init(layer.weight) 27 | bias_init(layer.bias) 28 | 29 | def basic_init(layer): 30 | layer_init(layer, weight_init = _fanin_init, bias_init = _constant_bias_init) 31 | 32 | def uniform_init(layer): 33 | layer_init(layer, weight_init = _uniform_init, bias_init = _uniform_init ) 34 | 35 | def normal_init(layer): 36 | layer_init(layer, weight_init = _normal_init, bias_init = _normal_init) 37 | 38 | def _orthogonal_init(tensor, gain = np.sqrt(2)): 39 | nn.init.orthogonal_(tensor, gain = gain) 40 | 41 | def orthogonal_init(layer, scale = np.sqrt(2), constant = 0 ): 42 | layer_init( 43 | layer, 44 | weight_init= lambda x:_orthogonal_init(x, gain=scale), 45 | bias_init=lambda x: _constant_bias_init(x, 0)) 46 | -------------------------------------------------------------------------------- /torchrl/networks/nets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchrl.networks.init as init 6 | 7 | 8 | class ZeroNet(nn.Module): 9 | def forward(self, x): 10 | return torch.zeros(1) 11 | 12 | 13 | class Net(nn.Module): 14 | def __init__( 15 | self, output_shape, 16 | base_type, 17 | append_hidden_shapes=[], 18 | append_hidden_init_func=init.basic_init, 19 | net_last_init_func=init.uniform_init, 20 | activation_func=F.relu, 21 | **kwargs): 22 | 23 | super().__init__() 24 | 25 | self.base = base_type(activation_func=activation_func, **kwargs) 26 | self.activation_func = activation_func 27 | append_input_shape = self.base.output_shape 28 | self.append_fcs = [] 29 | for i, next_shape in enumerate(append_hidden_shapes): 30 | fc = nn.Linear(append_input_shape, next_shape) 31 | append_hidden_init_func(fc) 32 | self.append_fcs.append(fc) 33 | # set attr for pytorch to track parameters( device ) 34 | self.__setattr__("append_fc{}".format(i), fc) 35 | append_input_shape = next_shape 36 | 37 | self.last = nn.Linear(append_input_shape, output_shape) 38 | net_last_init_func(self.last) 39 | 40 | def forward(self, x): 41 | out = self.base(x) 42 | 43 | for append_fc in self.append_fcs: 44 | out = append_fc(out) 45 | out = self.activation_func(out) 46 | 47 | out = self.last(out) 48 | return out 49 | 50 | 51 | class FlattenNet(Net): 52 | def forward(self, input): 53 | out = torch.cat(input, dim = -1) 54 | return super().forward(out) 55 | 56 | 57 | def null_activation(x): 58 | return x 59 | 60 | class ModularGatedCascadeCondNet(nn.Module): 61 | def __init__(self, output_shape, 62 | base_type, em_input_shape, input_shape, 63 | em_hidden_shapes, 64 | hidden_shapes, 65 | 66 | num_layers, num_modules, 67 | 68 | module_hidden, 69 | 70 | gating_hidden, num_gating_layers, 71 | 72 | # gated_hidden 73 | add_bn = True, 74 | pre_softmax = False, 75 | cond_ob = True, 76 | module_hidden_init_func = init.basic_init, 77 | last_init_func = init.uniform_init, 78 | activation_func = F.relu, 79 | **kwargs ): 80 | 81 | super().__init__() 82 | 83 | self.base = base_type( 84 | last_activation_func = null_activation, 85 | 
input_shape = input_shape, 86 | activation_func = activation_func, 87 | hidden_shapes = hidden_shapes, 88 | **kwargs ) 89 | self.em_base = base_type( 90 | last_activation_func = null_activation, 91 | input_shape = em_input_shape, 92 | activation_func = activation_func, 93 | hidden_shapes = em_hidden_shapes, 94 | **kwargs ) 95 | 96 | self.activation_func = activation_func 97 | 98 | module_input_shape = self.base.output_shape 99 | self.layer_modules = [] 100 | 101 | self.num_layers = num_layers 102 | self.num_modules = num_modules 103 | 104 | for i in range(num_layers): 105 | layer_module = [] 106 | for j in range( num_modules ): 107 | fc = nn.Linear(module_input_shape, module_hidden) 108 | module_hidden_init_func(fc) 109 | if add_bn: 110 | module = nn.Sequential( 111 | nn.BatchNorm1d(module_input_shape), 112 | fc, 113 | nn.BatchNorm1d(module_hidden) 114 | ) 115 | else: 116 | module = fc 117 | 118 | layer_module.append(module) 119 | self.__setattr__("module_{}_{}".format(i,j), module) 120 | 121 | module_input_shape = module_hidden 122 | self.layer_modules.append(layer_module) 123 | 124 | self.last = nn.Linear(module_input_shape, output_shape) 125 | last_init_func( self.last ) 126 | 127 | assert self.em_base.output_shape == self.base.output_shape, \ 128 | "embedding should has the same dimension with base output for gated" 129 | gating_input_shape = self.em_base.output_shape 130 | self.gating_fcs = [] 131 | for i in range(num_gating_layers): 132 | gating_fc = nn.Linear(gating_input_shape, gating_hidden) 133 | module_hidden_init_func(gating_fc) 134 | self.gating_fcs.append(gating_fc) 135 | self.__setattr__("gating_fc_{}".format(i), gating_fc) 136 | gating_input_shape = gating_hidden 137 | 138 | self.gating_weight_fcs = [] 139 | self.gating_weight_cond_fcs = [] 140 | 141 | self.gating_weight_fc_0 = nn.Linear(gating_input_shape, 142 | num_modules * num_modules ) 143 | last_init_func( self.gating_weight_fc_0) 144 | # self.gating_weight_fcs.append(self.gating_weight_fc_0) 145 | 146 | for layer_idx in range(num_layers-2): 147 | gating_weight_cond_fc = nn.Linear((layer_idx+1) * \ 148 | num_modules * num_modules, 149 | gating_input_shape) 150 | module_hidden_init_func(gating_weight_cond_fc) 151 | self.__setattr__("gating_weight_cond_fc_{}".format(layer_idx+1), 152 | gating_weight_cond_fc) 153 | self.gating_weight_cond_fcs.append(gating_weight_cond_fc) 154 | 155 | gating_weight_fc = nn.Linear(gating_input_shape, 156 | num_modules * num_modules) 157 | last_init_func(gating_weight_fc) 158 | self.__setattr__("gating_weight_fc_{}".format(layer_idx+1), 159 | gating_weight_fc) 160 | self.gating_weight_fcs.append(gating_weight_fc) 161 | 162 | self.gating_weight_cond_last = nn.Linear((num_layers-1) * \ 163 | num_modules * num_modules, 164 | gating_input_shape) 165 | module_hidden_init_func(self.gating_weight_cond_last) 166 | 167 | self.gating_weight_last = nn.Linear(gating_input_shape, num_modules) 168 | last_init_func( self.gating_weight_last ) 169 | 170 | self.pre_softmax = pre_softmax 171 | self.cond_ob = cond_ob 172 | 173 | def forward(self, x, embedding_input, return_weights = False): 174 | # Return weights for visualization 175 | out = self.base(x) 176 | embedding = self.em_base(embedding_input) 177 | 178 | if self.cond_ob: 179 | embedding = embedding * out 180 | 181 | out = self.activation_func(out) 182 | 183 | if len(self.gating_fcs) > 0: 184 | embedding = self.activation_func(embedding) 185 | for fc in self.gating_fcs[:-1]: 186 | embedding = fc(embedding) 187 | embedding = 
self.activation_func(embedding) 188 | embedding = self.gating_fcs[-1](embedding) 189 | 190 | base_shape = embedding.shape[:-1] 191 | 192 | weights = [] 193 | flatten_weights = [] 194 | 195 | raw_weight = self.gating_weight_fc_0(self.activation_func(embedding)) 196 | 197 | weight_shape = base_shape + torch.Size([self.num_modules, 198 | self.num_modules]) 199 | flatten_shape = base_shape + torch.Size([self.num_modules * \ 200 | self.num_modules]) 201 | 202 | raw_weight = raw_weight.view(weight_shape) 203 | 204 | softmax_weight = F.softmax(raw_weight, dim=-1) 205 | weights.append(softmax_weight) 206 | if self.pre_softmax: 207 | flatten_weights.append(raw_weight.view(flatten_shape)) 208 | else: 209 | flatten_weights.append(softmax_weight.view(flatten_shape)) 210 | 211 | for gating_weight_fc, gating_weight_cond_fc in zip(self.gating_weight_fcs, self.gating_weight_cond_fcs): 212 | cond = torch.cat(flatten_weights, dim=-1) 213 | if self.pre_softmax: 214 | cond = self.activation_func(cond) 215 | cond = gating_weight_cond_fc(cond) 216 | cond = cond * embedding 217 | cond = self.activation_func(cond) 218 | 219 | raw_weight = gating_weight_fc(cond) 220 | raw_weight = raw_weight.view(weight_shape) 221 | softmax_weight = F.softmax(raw_weight, dim=-1) 222 | weights.append(softmax_weight) 223 | if self.pre_softmax: 224 | flatten_weights.append(raw_weight.view(flatten_shape)) 225 | else: 226 | flatten_weights.append(softmax_weight.view(flatten_shape)) 227 | 228 | cond = torch.cat(flatten_weights, dim=-1) 229 | if self.pre_softmax: 230 | cond = self.activation_func(cond) 231 | cond = self.gating_weight_cond_last(cond) 232 | cond = cond * embedding 233 | cond = self.activation_func(cond) 234 | 235 | raw_last_weight = self.gating_weight_last(cond) 236 | last_weight = F.softmax(raw_last_weight, dim = -1) 237 | 238 | module_outputs = [(layer_module(out)).unsqueeze(-2) \ 239 | for layer_module in self.layer_modules[0]] 240 | 241 | module_outputs = torch.cat(module_outputs, dim = -2 ) 242 | 243 | # [TODO] Optimize using 1 * 1 convolution. 
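        # Soft routing: for each subsequent layer, every module receives a convex
        # combination of the previous layer's module outputs, weighted by the
        # softmax routing probabilities computed from the task embedding above;
        # the final output mixes the last layer's modules with `last_weight`.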
244 | 245 | for i in range(self.num_layers - 1): 246 | new_module_outputs = [] 247 | for j, layer_module in enumerate(self.layer_modules[i + 1]): 248 | module_input = (module_outputs * \ 249 | weights[i][..., j, :].unsqueeze(-1)).sum(dim=-2) 250 | 251 | module_input = self.activation_func(module_input) 252 | new_module_outputs.append(( 253 | layer_module(module_input) 254 | ).unsqueeze(-2)) 255 | 256 | module_outputs = torch.cat(new_module_outputs, dim = -2) 257 | 258 | out = (module_outputs * last_weight.unsqueeze(-1)).sum(-2) 259 | out = self.activation_func(out) 260 | out = self.last(out) 261 | 262 | if return_weights: 263 | return out, weights, last_weight 264 | return out 265 | 266 | 267 | class FlattenModularGatedCascadeCondNet(ModularGatedCascadeCondNet): 268 | def forward(self, input, embedding_input, return_weights = False): 269 | out = torch.cat( input, dim = -1 ) 270 | return super().forward(out, embedding_input, return_weights = return_weights) 271 | 272 | 273 | class BootstrappedNet(Net): 274 | def __init__(self, output_shape, 275 | head_num = 10, 276 | **kwargs ): 277 | self.head_num = head_num 278 | self.origin_output_shape = output_shape 279 | output_shape *= self.head_num 280 | super().__init__(output_shape = output_shape, **kwargs) 281 | 282 | def forward(self, x, idx): 283 | base_shape = x.shape[:-1] 284 | out = super().forward(x) 285 | out_shape = base_shape + torch.Size([self.origin_output_shape, self.head_num]) 286 | view_idx_shape = base_shape + torch.Size([1, 1]) 287 | expand_idx_shape = base_shape + torch.Size([self.origin_output_shape, 1]) 288 | 289 | out = out.reshape(out_shape) 290 | 291 | idx = idx.view(view_idx_shape) 292 | idx = idx.expand(expand_idx_shape) 293 | 294 | out = out.gather(-1, idx).squeeze(-1) 295 | return out 296 | 297 | 298 | class FlattenBootstrappedNet(BootstrappedNet): 299 | def forward(self, input, idx ): 300 | out = torch.cat( input, dim = -1 ) 301 | return super().forward(out, idx) 302 | -------------------------------------------------------------------------------- /torchrl/policies/__init__.py: -------------------------------------------------------------------------------- 1 | from .continuous_policy import * 2 | from .distribution import * -------------------------------------------------------------------------------- /torchrl/policies/continuous_policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Normal 4 | import numpy as np 5 | import torchrl.networks as networks 6 | from .distribution import TanhNormal 7 | import torch.nn.functional as F 8 | import torchrl.networks.init as init 9 | 10 | LOG_SIG_MAX = 2 11 | LOG_SIG_MIN = -20 12 | 13 | 14 | class UniformPolicyContinuous(nn.Module): 15 | def __init__(self, action_shape): 16 | super().__init__() 17 | self.action_shape = action_shape 18 | 19 | def forward(self, x): 20 | return torch.Tensor(np.random.uniform(-1., 1., self.action_shape)) 21 | 22 | def explore(self, x): 23 | return { 24 | "action": torch.Tensor( 25 | np.random.uniform(-1., 1., self.action_shape)) 26 | } 27 | 28 | 29 | class DetContPolicy(networks.Net): 30 | def forward(self, x): 31 | return torch.tanh(super().forward(x)) 32 | 33 | def eval_act( self, x ): 34 | with torch.no_grad(): 35 | return self.forward(x).squeeze(0).detach().cpu().numpy() 36 | 37 | def explore( self, x ): 38 | return { 39 | "action": self.forward(x).squeeze(0) 40 | } 41 | 42 | 43 | class FixGuassianContPolicy(networks.Net): 44 | 
def __init__(self, norm_std_explore, **kwargs): 45 | super().__init__(**kwargs) 46 | self.norm_std_explore = norm_std_explore 47 | 48 | def forward(self, x): 49 | return torch.tanh(super().forward(x)) 50 | 51 | def eval_act(self, x): 52 | with torch.no_grad(): 53 | return self.forward(x).squeeze(0).detach().cpu().numpy() 54 | 55 | def explore(self, x): 56 | action = self.forward(x).squeeze(0) 57 | action += Normal( 58 | torch.zeros(action.size()), 59 | self.norm_std_explore * torch.ones(action.size()) 60 | ).sample().to(action.device) 61 | 62 | return { 63 | "action": action 64 | } 65 | 66 | 67 | class GuassianContPolicy(networks.Net): 68 | def forward(self, x): 69 | x = super().forward(x) 70 | 71 | mean, log_std = x.chunk(2, dim=-1) 72 | 73 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX) 74 | std = torch.exp(log_std) 75 | 76 | return mean, std, log_std 77 | 78 | def eval_act(self, x): 79 | with torch.no_grad(): 80 | mean, _, _ = self.forward(x) 81 | return torch.tanh(mean.squeeze(0)).detach().cpu().numpy() 82 | 83 | def explore( self, x, return_log_probs = False, return_pre_tanh = False ): 84 | 85 | mean, std, log_std = self.forward(x) 86 | 87 | dis = TanhNormal(mean, std) 88 | 89 | ent = dis.entropy().sum(-1, keepdim=True) 90 | 91 | dic = { 92 | "mean": mean, 93 | "log_std": log_std, 94 | "ent":ent 95 | } 96 | 97 | if return_log_probs: 98 | action, z = dis.rsample(return_pretanh_value=True) 99 | log_prob = dis.log_prob( 100 | action, 101 | pre_tanh_value=z 102 | ) 103 | log_prob = log_prob.sum(dim=-1, keepdim=True) 104 | dic["pre_tanh"] = z.squeeze(0) 105 | dic["log_prob"] = log_prob 106 | else: 107 | if return_pre_tanh: 108 | action, z = dis.rsample(return_pretanh_value=True) 109 | dic["pre_tanh"] = z.squeeze(0) 110 | action = dis.rsample(return_pretanh_value=False) 111 | 112 | dic["action"] = action.squeeze(0) 113 | return dic 114 | 115 | def update(self, obs, actions): 116 | mean, std, log_std = self.forward(obs) 117 | dis = TanhNormal(mean, std) 118 | 119 | log_prob = dis.log_prob(actions).sum(-1, keepdim=True) 120 | ent = dis.entropy().sum(-1, keepdim=True) 121 | 122 | out = { 123 | "mean": mean, 124 | "log_std": log_std, 125 | "log_prob": log_prob, 126 | "ent": ent 127 | } 128 | return out 129 | 130 | 131 | class GuassianContPolicyBasicBias(networks.Net): 132 | 133 | def __init__(self, output_shape, **kwargs): 134 | super().__init__(output_shape=output_shape, **kwargs) 135 | self.logstd = nn.Parameter(torch.zeros(output_shape)) 136 | 137 | def forward(self, x): 138 | mean = super().forward(x) 139 | 140 | logstd = torch.clamp(self.logstd, LOG_SIG_MIN, LOG_SIG_MAX) 141 | std = torch.exp(logstd) 142 | std = std.unsqueeze(0).expand_as(mean) 143 | return mean, std, logstd 144 | 145 | def eval_act( self, x ): 146 | with torch.no_grad(): 147 | mean, std, log_std = self.forward(x) 148 | # return torch.tanh(mean.squeeze(0)).detach().cpu().numpy() 149 | return mean.squeeze(0).detach().cpu().numpy() 150 | 151 | def explore(self, x, return_log_probs = False, return_pre_tanh = False): 152 | 153 | mean, std, log_std = self.forward(x) 154 | 155 | dis = Normal(mean, std) 156 | # dis = TanhNormal(mean, std) 157 | 158 | ent = dis.entropy().sum(1, keepdim=True) 159 | 160 | dic = { 161 | "mean": mean, 162 | "log_std": log_std, 163 | "ent": ent 164 | } 165 | 166 | if return_log_probs: 167 | action = dis.sample() 168 | log_prob = dis.log_prob(action) 169 | log_prob = log_prob.sum(dim=1, keepdim=True) 170 | # dic["pre_tanh"] = z.squeeze(0) 171 | dic["log_prob"] = log_prob 172 | else: 173 | # 
if return_pre_tanh: 174 | # action, z = dis.rsample(return_pretanh_value=True) 175 | # dic["pre_tanh"] = z.squeeze(0) 176 | action = dis.sample() 177 | 178 | dic["action"] = action.squeeze(0) 179 | return dic 180 | 181 | # if return_log_probs: 182 | # action, z = dis.rsample(return_pretanh_value=True) 183 | # log_prob = dis.log_prob( 184 | # action, 185 | # pre_tanh_value=z 186 | # ) 187 | # log_prob = log_prob.sum(dim=1, keepdim=True) 188 | # dic["pre_tanh"] = z.squeeze(0) 189 | # dic["log_prob"] = log_prob 190 | # else: 191 | # if return_pre_tanh: 192 | # action, z = dis.rsample(return_pretanh_value=True) 193 | # dic["pre_tanh"] = z.squeeze(0) 194 | # action = dis.rsample(return_pretanh_value=False) 195 | 196 | # dic["action"] = action.squeeze(0) 197 | # return dic 198 | 199 | def update(self, obs, actions): 200 | mean, std, log_std = self.forward(obs) 201 | # dis = TanhNormal(mean, std) 202 | dis = Normal(mean, std) 203 | # dis = TanhNormal(mean, std) 204 | 205 | log_prob = dis.log_prob(actions).sum(-1, keepdim=True) 206 | ent = dis.entropy().sum(1, keepdim=True) 207 | 208 | out = { 209 | "mean": mean, 210 | "log_std": log_std, 211 | "log_prob": log_prob, 212 | "ent": ent 213 | } 214 | return out 215 | 216 | class EmbeddingGuassianContPolicyBase: 217 | 218 | def eval_act( self, x, embedding_input ): 219 | with torch.no_grad(): 220 | mean, std, log_std = self.forward(x, embedding_input) 221 | return torch.tanh(mean.squeeze(0)).detach().cpu().numpy() 222 | 223 | def explore( self, x, embedding_input, return_log_probs = False, return_pre_tanh = False ): 224 | 225 | mean, std, log_std = self.forward(x, embedding_input) 226 | 227 | dis = TanhNormal(mean, std) 228 | 229 | ent = dis.entropy().sum(-1, keepdim=True) 230 | 231 | dic = { 232 | "mean": mean, 233 | "log_std": log_std, 234 | "ent":ent 235 | } 236 | 237 | if return_log_probs: 238 | action, z = dis.rsample( return_pretanh_value = True ) 239 | log_prob = dis.log_prob( 240 | action, 241 | pre_tanh_value=z 242 | ) 243 | log_prob = log_prob.sum(dim=-1, keepdim=True) 244 | dic["pre_tanh"] = z.squeeze(0) 245 | dic["log_prob"] = log_prob 246 | else: 247 | if return_pre_tanh: 248 | action, z = dis.rsample( return_pretanh_value = True ) 249 | dic["pre_tanh"] = z.squeeze(0) 250 | action = dis.rsample( return_pretanh_value = False ) 251 | 252 | dic["action"] = action.squeeze(0) 253 | return dic 254 | 255 | def update(self, obs, embedding_input, actions): 256 | mean, std, log_std = self.forward(obs, embedding_input) 257 | dis = TanhNormal(mean, std) 258 | 259 | log_prob = dis.log_prob(actions).sum(-1, keepdim=True) 260 | ent = dis.entropy().sum(1, keepdim=True) 261 | 262 | out = { 263 | "mean": mean, 264 | "log_std": log_std, 265 | "log_prob": log_prob, 266 | "ent": ent 267 | } 268 | return out 269 | 270 | 271 | class EmbeddingDetContPolicyBase: 272 | def eval_act( self, x, embedding_input ): 273 | with torch.no_grad(): 274 | return torch.tanh(self.forward(x, embedding_input)).squeeze(0).detach().cpu().numpy() 275 | 276 | 277 | def explore( self, x, embedding_input ): 278 | return { 279 | "action":torch.tanh( 280 | self.forward(x, embedding_input)).squeeze(0)} 281 | 282 | 283 | class ModularGuassianGatedCascadeCondContPolicy(networks.ModularGatedCascadeCondNet, EmbeddingGuassianContPolicyBase): 284 | def forward(self, x, embedding_input, return_weights = False ): 285 | x = super().forward(x, embedding_input, return_weights = return_weights) 286 | if isinstance(x, tuple): 287 | general_weights = x[1] 288 | last_weights = x[2] 289 | x = x[0] 290 | 
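        # The network head emits mean and log-std concatenated along the last
        # dimension; they are split here and log-std is clamped before exp() for
        # numerical stability.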
291 | mean, log_std = x.chunk(2, dim=-1) 292 | 293 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX) 294 | std = torch.exp(log_std) 295 | 296 | if return_weights: 297 | return mean, std, log_std, general_weights, last_weights 298 | # return mean, std, log_std, general_weights 299 | return mean, std, log_std 300 | 301 | def eval_act( self, x, embedding_input, return_weights = False ): 302 | with torch.no_grad(): 303 | if return_weights: 304 | # mean, std, log_std, general_weights, last_weights = self.forward(x, embedding_input, return_weights) 305 | mean, std, log_std, general_weights = self.forward(x, embedding_input, return_weights) 306 | else: 307 | mean, std, log_std = self.forward(x, embedding_input, return_weights) 308 | if return_weights: 309 | # return torch.tanh(mean.squeeze(0)).detach().cpu().numpy(), general_weights, last_weights 310 | return torch.tanh(mean.squeeze(0)).detach().cpu().numpy(), general_weights 311 | return torch.tanh(mean.squeeze(0)).detach().cpu().numpy() 312 | 313 | 314 | def explore( self, x, embedding_input, return_log_probs = False, 315 | return_pre_tanh = False, return_weights = False ): 316 | if return_weights: 317 | mean, std, log_std, general_weights, last_weights = self.forward(x, embedding_input, return_weights) 318 | # general_weights, last_weights = weights 319 | dic = { 320 | "general_weights": general_weights, 321 | "last_weights": last_weights 322 | } 323 | else: 324 | mean, std, log_std = self.forward(x, embedding_input) 325 | dic = {} 326 | 327 | dis = TanhNormal(mean, std) 328 | 329 | ent = dis.entropy().sum(-1, keepdim=True) 330 | 331 | dic.update({ 332 | "mean": mean, 333 | "log_std": log_std, 334 | "ent":ent 335 | }) 336 | 337 | if return_log_probs: 338 | action, z = dis.rsample( return_pretanh_value = True ) 339 | log_prob = dis.log_prob( 340 | action, 341 | pre_tanh_value=z 342 | ) 343 | log_prob = log_prob.sum(dim=-1, keepdim=True) 344 | dic["pre_tanh"] = z.squeeze(0) 345 | dic["log_prob"] = log_prob 346 | else: 347 | if return_pre_tanh: 348 | action, z = dis.rsample( return_pretanh_value = True ) 349 | dic["pre_tanh"] = z.squeeze(0) 350 | action = dis.rsample( return_pretanh_value = False ) 351 | 352 | dic["action"] = action.squeeze(0) 353 | return dic 354 | 355 | 356 | class MultiHeadGuassianContPolicy(networks.BootstrappedNet): 357 | def forward(self, x, idx): 358 | x = super().forward(x, idx) 359 | 360 | mean, log_std = x.chunk(2, dim=-1) 361 | 362 | log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX) 363 | std = torch.exp(log_std) 364 | 365 | return mean, std, log_std 366 | 367 | def eval_act( self, x, idx ): 368 | with torch.no_grad(): 369 | mean, _, _= self.forward(x, idx) 370 | return torch.tanh(mean.squeeze(0)).detach().cpu().numpy() 371 | 372 | def explore( self, x, idx, return_log_probs=False, return_pre_tanh=False): 373 | mean, std, log_std = self.forward(x, idx) 374 | 375 | dis = TanhNormal(mean, std) 376 | 377 | ent = dis.entropy().sum(-1, keepdim=True) 378 | 379 | dic = { 380 | "mean": mean, 381 | "log_std": log_std, 382 | "ent":ent 383 | } 384 | 385 | if return_log_probs: 386 | action, z = dis.rsample( return_pretanh_value = True ) 387 | log_prob = dis.log_prob( 388 | action, 389 | pre_tanh_value=z 390 | ) 391 | log_prob = log_prob.sum(dim=-1, keepdim=True) 392 | dic["pre_tanh"] = z.squeeze(0) 393 | dic["log_prob"] = log_prob 394 | else: 395 | if return_pre_tanh: 396 | action, z = dis.rsample( return_pretanh_value = True ) 397 | dic["pre_tanh"] = z.squeeze(0) 398 | action = dis.rsample( return_pretanh_value = 
False ) 399 | 400 | dic["action"] = action.squeeze(0) 401 | return dic 402 | -------------------------------------------------------------------------------- /torchrl/policies/distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Distribution, Normal 3 | 4 | class TanhNormal(Distribution): 5 | """ 6 | Basically from RLKIT 7 | 8 | Represent distribution of X where 9 | X ~ tanh(Z) 10 | Z ~ N(mean, std) 11 | 12 | Note: this is not very numerically stable. 13 | """ 14 | def __init__(self, normal_mean, normal_std, epsilon=1e-6): 15 | """ 16 | :param normal_mean: Mean of the normal distribution 17 | :param normal_std: Std of the normal distribution 18 | :param epsilon: Numerical stability epsilon when computing log-prob. 19 | """ 20 | self.normal_mean = normal_mean 21 | self.normal_std = normal_std 22 | self.normal = Normal(normal_mean, normal_std) 23 | self.epsilon = epsilon 24 | 25 | def sample_n(self, n, return_pre_tanh_value=False): 26 | z = self.normal.sample_n(n) 27 | if return_pre_tanh_value: 28 | return torch.tanh(z), z 29 | else: 30 | return torch.tanh(z) 31 | 32 | def log_prob(self, value, pre_tanh_value=None): 33 | """ 34 | 35 | :param value: some value, x 36 | :param pre_tanh_value: arctanh(x) 37 | :return: 38 | """ 39 | if pre_tanh_value is None: 40 | pre_tanh_value = torch.log( 41 | (1+value) / (1-value) 42 | ) / 2 43 | return self.normal.log_prob(pre_tanh_value) - torch.log( 44 | 1 - value * value + self.epsilon 45 | ) 46 | 47 | def sample(self, return_pretanh_value=False): 48 | """ 49 | Gradients will and should *not* pass through this operation. 50 | 51 | See https://github.com/pytorch/pytorch/issues/4620 for discussion. 52 | """ 53 | z = self.normal.sample().detach() 54 | 55 | if return_pretanh_value: 56 | return torch.tanh(z), z 57 | else: 58 | return torch.tanh(z) 59 | 60 | def rsample(self, return_pretanh_value=False): 61 | """ 62 | Sampling in the reparameterization case. 
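        Draws eps ~ N(0, I) and returns tanh(normal_mean + normal_std * eps), so
        the sample stays differentiable w.r.t. the mean and std parameters.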
63 | """ 64 | z = ( 65 | self.normal_mean + 66 | self.normal_std * 67 | Normal( 68 | torch.zeros(self.normal_mean.size()), 69 | torch.ones(self.normal_std.size()) 70 | ).sample().to(self.normal_mean.device) 71 | ) 72 | 73 | if return_pretanh_value: 74 | return torch.tanh(z), z 75 | else: 76 | return torch.tanh(z) 77 | 78 | def entropy(self): 79 | return self.normal.entropy() 80 | -------------------------------------------------------------------------------- /torchrl/replay_buffers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseReplayBuffer 2 | from .base import BaseMTReplayBuffer -------------------------------------------------------------------------------- /torchrl/replay_buffers/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class BaseReplayBuffer(): 4 | """ 5 | Basic Replay Buffer 6 | """ 7 | def __init__( 8 | self, max_replay_buffer_size, time_limit_filter=False, 9 | ): 10 | self.worker_nums = 1 11 | self._max_replay_buffer_size = max_replay_buffer_size 12 | self._top = 0 13 | self._size = 0 14 | self.time_limit_filter = time_limit_filter 15 | 16 | def add_sample(self, sample_dict, env_rank=0, **kwargs): 17 | for key in sample_dict: 18 | if not hasattr(self, "_" + key): 19 | self.__setattr__( 20 | "_" + key, 21 | np.zeros((self._max_replay_buffer_size, 1) + \ 22 | np.shape(sample_dict[key]))) 23 | self.__getattribute__("_" + key)[self._top, 0] = sample_dict[key] 24 | self._advance() 25 | 26 | def terminate_episode(self): 27 | pass 28 | 29 | def _advance(self): 30 | self._top = (self._top + 1) % self._max_replay_buffer_size 31 | if self._size < self._max_replay_buffer_size: 32 | self._size += 1 33 | 34 | def random_batch(self, batch_size, sample_key): 35 | indices = np.random.randint(0, self._size, batch_size) 36 | return_dict = {} 37 | for key in sample_key: 38 | return_dict[key] = np.squeeze(self.__getattribute__("_"+key)[indices], axis=1) 39 | return return_dict 40 | 41 | def num_steps_can_sample(self): 42 | return self._size 43 | 44 | class BaseMTReplayBuffer(BaseReplayBuffer): 45 | """ 46 | Just for imitation Learning 47 | """ 48 | def __init__(self, 49 | max_replay_buffer_size, 50 | task_nums, 51 | ): 52 | super(BaseMTReplayBuffer, self).__init__(max_replay_buffer_size) 53 | self.task_nums = task_nums 54 | 55 | # Not USED 56 | def add_sample(self, sample_dict, env_rank = 0, **kwargs): 57 | pass 58 | 59 | def terminate_episode(self): 60 | pass 61 | 62 | def _advance(self): 63 | pass 64 | 65 | def random_batch(self, batch_size, sample_key, reshape = True): 66 | assert batch_size % self.task_nums== 0, \ 67 | "batch size should be dividable by worker_nums" 68 | batch_size //= self.task_nums 69 | size = self.num_steps_can_sample() 70 | indices = np.random.randint(0, size, batch_size) 71 | return_dict = {} 72 | for key in sample_key: 73 | return_dict[key] = self.__getattribute__("_"+key)[indices] 74 | if reshape: 75 | return_dict[key] = return_dict[key].reshape( 76 | (batch_size * self.worker_nums, -1)) 77 | return return_dict 78 | 79 | def num_steps_can_sample(self): 80 | return self._size -------------------------------------------------------------------------------- /torchrl/replay_buffers/shared/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SharedBaseReplayBuffer 2 | from .base import AsyncSharedReplayBuffer 
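# SharedBaseReplayBuffer (defined in base.py below) divides the buffer capacity
# evenly across workers and stores each field in a process-shared NpShmemArray
# keyed by a random tag, so collector processes can write transitions into their
# own column while the learner samples from the same memory without copying;
# AsyncSharedReplayBuffer only relaxes the equal-sample-count check.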
-------------------------------------------------------------------------------- /torchrl/replay_buffers/shared/base.py: -------------------------------------------------------------------------------- 1 | # Since we could ensure that multi-proces would write into the different parts 2 | # For efficiency, we use Multiprocess.RawArray 3 | 4 | from torch.multiprocessing import RawArray 5 | from multiprocessing.managers import BaseProxy 6 | import numpy as np 7 | 8 | from torchrl.replay_buffers.base import BaseReplayBuffer 9 | 10 | from .shmarray import NpShmemArray 11 | 12 | from .shmarray import get_random_tag 13 | 14 | class SharedBaseReplayBuffer(BaseReplayBuffer): 15 | """ 16 | Basic Replay Buffer 17 | """ 18 | def __init__(self, 19 | max_replay_buffer_size, 20 | worker_nums 21 | # example_dict, 22 | # tag 23 | ): 24 | super().__init__(max_replay_buffer_size) 25 | 26 | self.worker_nums = worker_nums 27 | assert self._max_replay_buffer_size % self.worker_nums == 0, \ 28 | "buffer size is not dividable by worker num" 29 | self._max_replay_buffer_size //= self.worker_nums 30 | 31 | if not hasattr(self, "tag"): 32 | self.tag = get_random_tag() 33 | 34 | def build_by_example(self, example_dict): 35 | self._size = NpShmemArray(self.worker_nums, np.int32, self.tag+"_size") 36 | self._top = NpShmemArray(self.worker_nums, np.int32, self.tag+"_top") 37 | 38 | self.tags = {} 39 | self.shapes = {} 40 | for key in example_dict: 41 | if not hasattr( self, "_" + key ): 42 | current_tag = "_"+key 43 | self.tags[current_tag] = self.tag+current_tag 44 | shape = (self._max_replay_buffer_size, self.worker_nums) + \ 45 | np.shape(example_dict[key]) 46 | self.shapes[current_tag] = shape 47 | 48 | np_array = NpShmemArray(shape, np.float32, self.tag+current_tag) 49 | self.__setattr__(current_tag, np_array ) 50 | 51 | def rebuild_from_tag(self): 52 | 53 | self._size = NpShmemArray(self.worker_nums, np.int32, 54 | self.tag+"_size", create=False) 55 | self._top = NpShmemArray(self.worker_nums, np.int32, 56 | self.tag+"_top", create=False) 57 | 58 | for key in self.tags: 59 | np_array = NpShmemArray(self.shapes[key], np.float32, 60 | self.tags[key], create=False) 61 | self.__setattr__(key, np_array ) 62 | 63 | def add_sample(self, sample_dict, worker_rank, **kwargs): 64 | for key in sample_dict: 65 | self.__getattribute__( "_" + key )[self._top[worker_rank], worker_rank] = sample_dict[key] 66 | self._advance(worker_rank) 67 | 68 | def terminate_episode(self): 69 | pass 70 | 71 | def _advance(self, worker_rank): 72 | self._top[worker_rank] = (self._top[worker_rank] + 1) % \ 73 | self._max_replay_buffer_size 74 | if self._size[worker_rank] < self._max_replay_buffer_size: 75 | self._size[worker_rank] = self._size[worker_rank] + 1 76 | 77 | def random_batch(self, batch_size, sample_key, reshape = True): 78 | assert batch_size % self.worker_nums == 0, \ 79 | "batch size should be dividable by worker_nums" 80 | batch_size //= self.worker_nums 81 | size = self.num_steps_can_sample() 82 | indices = np.random.randint(0, size, batch_size) 83 | return_dict = {} 84 | for key in sample_key: 85 | return_dict[key] = self.__getattribute__("_"+key)[indices] 86 | if reshape: 87 | return_dict[key] = return_dict[key].reshape( 88 | (batch_size * self.worker_nums, -1)) 89 | return return_dict 90 | 91 | def num_steps_can_sample(self): 92 | min_size = np.min(self._size) 93 | max_size = np.max(self._size) 94 | assert max_size == min_size, \ 95 | "all worker should gather the same amount of samples" 96 | return min_size 97 | 98 | class 
AsyncSharedReplayBuffer(SharedBaseReplayBuffer): 99 | def num_steps_can_sample(self): 100 | # Use asynchronized sampling could cause sample collected is 101 | # different across different workers but actually it's find 102 | min_size = np.min(self._size) 103 | return min_size -------------------------------------------------------------------------------- /torchrl/replay_buffers/shared/shmarray.py: -------------------------------------------------------------------------------- 1 | """ 2 | From rlpyt // Currently not used in RLPYT 3 | 4 | Slightly modified by Rchal to support MAC 5 | """ 6 | # 7 | # Based on multiprocessing.sharedctypes.RawArray 8 | # 9 | # Uses posix_ipc (http://semanchuk.com/philip/posix_ipc/) to allow shared ctypes arrays 10 | # among unrelated processors 11 | # 12 | # Usage Notes: 13 | # * The first two args (typecode_or_type and size_or_initializer) should work the same as with RawArray. 14 | # * The shared array is accessible by any process, as long as tag matches. 15 | # * The shared memory segment is unlinked when the origin array (that returned 16 | # by ShmemRawArray(..., create=True)) is deleted/gc'ed 17 | # * Creating an shared array using a tag that currently exists will raise an ExistentialError 18 | # * Accessing a shared array using a tag that doesn't exist (or one that has been unlinked) will also 19 | # raise an ExistentialError 20 | # 21 | # Author: Shawn Chin (http://shawnchin.github.com) 22 | # 23 | # Edited for python 3 by: Adam Stooke 24 | # 25 | 26 | import numpy as np 27 | # import os 28 | import time 29 | import sys 30 | import mmap 31 | import ctypes 32 | import posix_ipc 33 | # from _multiprocessing import address_of_buffer # (not in python 3) 34 | from string import ascii_letters, digits 35 | 36 | valid_chars = frozenset("/-_. 
%s%s" % (ascii_letters, digits)) 37 | 38 | typecode_to_type = { 39 | 'c': ctypes.c_char, 'u': ctypes.c_wchar, 40 | 'b': ctypes.c_byte, 'B': ctypes.c_ubyte, 41 | 'h': ctypes.c_short, 'H': ctypes.c_ushort, 42 | 'i': ctypes.c_int, 'I': ctypes.c_uint, 43 | 'l': ctypes.c_long, 'L': ctypes.c_ulong, 44 | 'f': ctypes.c_float, 'd': ctypes.c_double 45 | } 46 | 47 | 48 | def address_of_buffer(buf): # (python 3) 49 | return ctypes.addressof(ctypes.c_char.from_buffer(buf)) 50 | 51 | 52 | class ShmemBufferWrapper: 53 | 54 | def __init__(self, tag, size, create=True): 55 | # default vals so __del__ doesn't fail if __init__ fails to complete 56 | self._mem = None 57 | self._map = None 58 | self._owner = create 59 | self.size = size 60 | 61 | assert 0 <= size < sys.maxsize # sys.maxint (python 3) 62 | flag = (0, posix_ipc.O_CREX)[create] 63 | mem_size = (0, self.size)[create] 64 | 65 | self._mem = posix_ipc.SharedMemory(tag, flags=flag, size=mem_size) 66 | self._map = mmap.mmap(self._mem.fd, self._mem.size) 67 | self._mem.close_fd() 68 | 69 | def get_address(self): 70 | # assert self._map.size() == self.size # (changed for python 3) 71 | assert self._map.size() >= self.size # strictly equal might not meet in MAC 72 | addr = address_of_buffer(self._map) 73 | return addr 74 | 75 | def __del__(self): 76 | if self._map is not None: 77 | self._map.close() 78 | if self._mem is not None and self._owner: 79 | self._mem.unlink() 80 | 81 | 82 | def ShmemRawArray(typecode_or_type, size_or_initializer, tag, create=True): 83 | assert frozenset(tag).issubset(valid_chars) 84 | if tag[0] != "/": 85 | tag = "/%s" % (tag,) 86 | 87 | type_ = typecode_to_type.get(typecode_or_type, typecode_or_type) 88 | if isinstance(size_or_initializer, int): 89 | type_ = type_ * size_or_initializer 90 | else: 91 | type_ = type_ * len(size_or_initializer) 92 | 93 | buffer = ShmemBufferWrapper(tag, ctypes.sizeof(type_), create=create) 94 | obj = type_.from_address(buffer.get_address()) 95 | obj._buffer = buffer 96 | 97 | if not isinstance(size_or_initializer, int): 98 | obj.__init__(*size_or_initializer) 99 | 100 | return obj 101 | 102 | 103 | ############################################################################### 104 | # New Additions (by Adam) # 105 | 106 | 107 | def NpShmemArray(shape, dtype, tag, create=True): 108 | size = int(np.prod(shape)) 109 | nbytes = size * np.dtype(dtype).itemsize 110 | shmem = ShmemRawArray(ctypes.c_char, nbytes, tag, create) 111 | return np.frombuffer(shmem, dtype=dtype, count=size).reshape(shape) 112 | 113 | 114 | def get_random_tag(): 115 | return str(time.time()).replace(".", "")[-9:] 116 | -------------------------------------------------------------------------------- /torchrl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .args import get_args 2 | from .args import get_params 3 | from .logger import Logger 4 | -------------------------------------------------------------------------------- /torchrl/utils/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import torch 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser(description='RL') 8 | 9 | parser.add_argument('--seed', type=int, default=0, 10 | help='random seed (default: 1)') 11 | 12 | parser.add_argument('--worker_nums', type=int, default=4, 13 | help='worker nums') 14 | 15 | parser.add_argument('--eval_worker_nums', type=int, default=2, 16 | help='eval worker nums') 17 | 18 | 
| parser.add_argument("--config", type=str, default=None, 19 | help="config file", ) 20 | 21 | parser.add_argument('--save_dir', type=str, default='./snapshots', 22 | help='directory for snapshots (default: ./snapshots)') 23 | 24 | parser.add_argument('--data_dir', type=str, default='./data', 25 | help='directory for data (default: ./data)') 26 | 27 | parser.add_argument('--log_dir', type=str, default='./log', 28 | help='directory for tensorboard logs (default: ./log)') 29 | 30 | parser.add_argument('--no_cuda', action='store_true', default=False, 31 | help='disables CUDA training') 32 | 33 | parser.add_argument("--device", type=int, default=0, 34 | help="gpu specification", ) 35 | 36 | # tensorboard 37 | parser.add_argument("--id", type=str, default=None, 38 | help="id for tensorboard", ) 39 | 40 | # policy snapshot 41 | parser.add_argument("--pf_snap", type=str, default=None, 42 | help="policy snapshot path", ) 43 | # q function snapshot 44 | parser.add_argument("--qf1_snap", type=str, default=None, 45 | help="q function snapshot path", ) 46 | # q function snapshot 47 | parser.add_argument("--qf2_snap", type=str, default=None, 48 | help="q function snapshot path", ) 49 | 50 | args = parser.parse_args() 51 | 52 | args.cuda = not args.no_cuda and torch.cuda.is_available() 53 | 54 | return args 55 | 56 | def get_params(file_name): 57 | with open(file_name) as f: 58 | params = json.load(f) 59 | return params 60 | -------------------------------------------------------------------------------- /torchrl/utils/logger.py: -------------------------------------------------------------------------------- 1 | import tensorboardX 2 | import logging 3 | import shutil 4 | import os 5 | import numpy as np 6 | from tabulate import tabulate 7 | import sys 8 | import json 9 | import csv 10 | 11 | class Logger(): 12 | def __init__(self, experiment_id, env_name, seed, params, log_dir = "./log"): 13 | 14 | self.logger = logging.getLogger("{}_{}_{}".format(experiment_id,env_name,str(seed))) 15 | 16 | self.logger.handlers = [] 17 | self.logger.propagate = False 18 | sh = logging.StreamHandler(sys.stdout) 19 | format = "%(asctime)s %(threadName)s %(levelname)s: %(message)s" 20 | formatter = logging.Formatter(format) 21 | sh.setFormatter(formatter) 22 | sh.setLevel(logging.INFO) 23 | self.logger.addHandler( sh ) 24 | self.logger.setLevel(logging.INFO) 25 | 26 | work_dir = os.path.join( log_dir, experiment_id, env_name, str(seed) ) 27 | self.work_dir = work_dir 28 | if os.path.exists( work_dir ): 29 | shutil.rmtree(work_dir) 30 | self.tf_writer = tensorboardX.SummaryWriter(work_dir) 31 | 32 | self.csv_file_path = os.path.join(work_dir, 'log.csv') 33 | 34 | self.update_count = 0 35 | self.stored_infos = {} 36 | 37 | with open( os.path.join(work_dir, 'params.json'), 'w' ) as output_param: 38 | json.dump( params, output_param, indent = 2 ) 39 | 40 | self.logger.info("Experiment Name:{}".format(experiment_id)) 41 | self.logger.info( 42 | json.dumps(params, indent = 2 ) 43 | ) 44 | 45 | def log(self, info): 46 | self.logger.info(info) 47 | 48 | def add_update_info(self, infos): 49 | 50 | for info in infos: 51 | if info not in self.stored_infos : 52 | self.stored_infos[info] = [] 53 | self.stored_infos[info].append( infos[info] ) 54 | 55 | self.update_count += 1 56 | 57 | def add_epoch_info(self, epoch_num, total_frames, total_time, infos, csv_write=True): 58 | if csv_write: 59 | if epoch_num == 0: 60 | csv_titles = ["EPOCH", "Time Consumed", "Total Frames"] 61 | csv_values = [epoch_num, total_time, total_frames] 62 | 63 
| self.logger.info("EPOCH:{}".format(epoch_num)) 64 | self.logger.info("Time Consumed:{}s".format(total_time)) 65 | self.logger.info("Total Frames:{}s".format(total_frames)) 66 | 67 | tabulate_list = [["Name", "Value"]] 68 | 69 | for info in infos: 70 | self.tf_writer.add_scalar( info, infos[info], total_frames ) 71 | tabulate_list.append([ info, "{:.5f}".format( infos[info] ) ]) 72 | if csv_write: 73 | if epoch_num == 0: 74 | csv_titles += [info] 75 | csv_values += ["{:.5f}".format(infos[info])] 76 | 77 | tabulate_list.append([]) 78 | 79 | method_list = [ np.mean, np.std, np.max, np.min ] 80 | name_list = [ "Mean", "Std", "Max", "Min" ] 81 | tabulate_list.append( ["Name"] + name_list ) 82 | 83 | for info in self.stored_infos: 84 | 85 | temp_list = [info] 86 | for name, method in zip( name_list, method_list ): 87 | processed_info = method(self.stored_infos[info]) 88 | self.tf_writer.add_scalar( "{}_{}".format( info, name ), 89 | processed_info, total_frames ) 90 | temp_list.append( "{:.5f}".format( processed_info ) ) 91 | if csv_write: 92 | if epoch_num == 0: 93 | csv_titles += ["{}_{}".format(info, name)] 94 | csv_values += ["{:.5f}".format(processed_info)] 95 | 96 | tabulate_list.append( temp_list ) 97 | #clear 98 | self.stored_infos = {} 99 | if csv_write: 100 | with open(self.csv_file_path, 'a') as f: 101 | self.csv_writer = csv.writer(f) 102 | if epoch_num == 0: 103 | self.csv_writer.writerow(csv_titles) 104 | self.csv_writer.writerow(csv_values) 105 | 106 | print( tabulate(tabulate_list) ) 107 | -------------------------------------------------------------------------------- /torchrl/utils/plot_csv.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import sys 6 | import os 7 | from collections import OrderedDict 8 | import argparse 9 | import seaborn as sns 10 | import csv 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser(description='RL') 15 | parser.add_argument('--seed', type=int, nargs='+', default=(0,), 16 | help='random seed (default: (0,))') 17 | parser.add_argument('--max_m', type=int, default=None, 18 | help='maximum million') 19 | parser.add_argument('--smooth_coeff', type=int, default=25, 20 | help='smooth coeff') 21 | parser.add_argument('--env_name', type=str, default='mt10', 22 | help='environment trained on (default: mt10)') 23 | parser.add_argument('--log_dir', type=str, default='./log', 24 | help='directory for tensorboard logs (default: ./log)') 25 | parser.add_argument( "--id", type=str, nargs='+', default=('origin',), 26 | help="id for tensorboard") 27 | parser.add_argument( "--tags", type=str, nargs='+', default=None, 28 | help="id for tensorboard") 29 | parser.add_argument('--output_dir', type=str, default='./fig', 30 | help='directory for plot output (default: ./fig)') 31 | parser.add_argument('--entry', type=str, default='Running_Average_Rewards', 32 | help='Record Entry') 33 | parser.add_argument('--add_tag', type=str, default='', 34 | help='added tag') 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | args = get_args() 40 | env_name = args.env_name 41 | env_id = args.id 42 | 43 | if args.tags is None: 44 | args.tags = args.id 45 | assert len(args.tags) == len(args.id) 46 | 47 | 48 | def post_process(array): 49 | smoth_para = args.smooth_coeff 50 | new_array = [] 51 | for i in range(len(array)): 52 | if i < len(array) - smoth_para: 53 | new_array.append(np.mean(array[i:i+smoth_para])) 54 | else: 55 
55 |             new_array.append(np.mean(array[i:]))
56 |     return new_array
57 | 
58 | 
59 | sns.set("paper")
60 | 
61 | current_palette = sns.color_palette()
62 | sns.palplot(current_palette)
63 | 
64 | fig = plt.figure(figsize=(10,7))
65 | plt.subplots_adjust(left=0.07, bottom=0.15, right=1, top=0.90,
66 |                     wspace=0, hspace=0)
67 | 
68 | ax1 = fig.add_subplot(111)
69 | colors = current_palette
70 | linestyles_choose = ['solid', 'solid', 'solid', 'solid', 'solid', 'solid', 'solid']
71 | 
72 | for eachcolor, eachlinestyle, exp_name, exp_tag in zip(colors, linestyles_choose, args.id, args.tags ):
73 |     min_step_number = 1000000000000
74 |     step_number = []
75 |     all_scores = {}
76 | 
77 |     for seed in args.seed:
78 |         file_path = os.path.join(args.log_dir, exp_name, env_name, str(seed), 'log.csv')
79 | 
80 |         all_scores[seed] = []
81 |         temp_step_number = []
82 |         with open(file_path,'r') as f:
83 |             csv_reader = csv.DictReader(f)
84 |             for row in csv_reader:
85 |                 all_scores[seed].append(float(row[args.entry]))
86 |                 temp_step_number.append(int(row["Total Frames"]))
87 | 
88 |         if temp_step_number[-1] < min_step_number:
89 |             min_step_number = temp_step_number[-1]
90 |             step_number = temp_step_number  # keep the x-axis of the shortest run
91 | 
92 |     all_mean = []
93 |     all_upper = []
94 |     all_lower = []
95 | 
96 |     step_number = np.array(step_number) / 1e6  # frames -> million samples
97 |     final_step = []
98 |     for i in range(len(step_number)):
99 |         if args.max_m is not None and step_number[i] >= args.max_m:
100 |             continue
101 |         final_step.append(step_number[i])
102 |         temp_list = []
103 |         for key, valueList in all_scores.items():
104 |             try:
105 |                 temp_list.append(valueList[i])
106 |             except Exception:
107 |                 print(i)  # this seed has fewer logged epochs; its entry is skipped
108 |                 # exit()
109 |         all_mean.append(np.mean(temp_list))
110 |         all_upper.append(np.mean(temp_list) + np.std(temp_list))
111 |         all_lower.append(np.mean(temp_list) - np.std(temp_list))
112 |     # print(exp_tag, np.mean(all_mean[-10:]))
113 |     all_mean = post_process(all_mean)
114 |     all_lower = post_process(all_lower)
115 |     all_upper = post_process(all_upper)
116 | 
117 |     ax1.plot(final_step, all_mean, label=exp_tag, color=eachcolor, linestyle=eachlinestyle, linewidth=2)
118 |     ax1.plot(final_step, all_upper, color=eachcolor, linestyle=eachlinestyle, alpha = 0.23, linewidth=1)
119 |     ax1.plot(final_step, all_lower, color=eachcolor, linestyle=eachlinestyle, alpha = 0.23, linewidth=1)
120 |     ax1.fill_between(final_step, all_lower, all_upper, alpha=0.2, color=eachcolor)
121 | 
122 | 
123 | ax1.set_xlabel('Million Samples', fontsize=30)
124 | ax1.tick_params(labelsize=25)
125 | 
126 | box = ax1.get_position()
127 | 
128 | leg = ax1.legend(
129 |     loc='best',
130 |     ncol=1,
131 |     fontsize=25)
132 | 
133 | for legobj in leg.legendHandles:
134 |     legobj.set_linewidth(10.0)
135 | 
136 | plt.title("{} {}".format(env_name, args.entry), fontsize=40)
137 | if not os.path.exists( args.output_dir ):
138 |     os.mkdir( args.output_dir )
139 | plt.savefig( os.path.join( args.output_dir, '{}_{}{}.png'.format(env_name, args.entry, args.add_tag) ) )
140 | plt.close()
--------------------------------------------------------------------------------
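For a quick sanity check outside of `plot_csv.py`, a single run's curve can be read straight from the `log.csv` written by `Logger.add_epoch_info`. The sketch below is not part of the repository: it assumes a hypothetical run stored under the default `./log/<id>/<env_name>/<seed>/` layout with the default `Running_Average_Rewards` entry, and reuses the same trailing-window smoothing as `post_process`.

```
# Minimal sketch (assumed run id/env/seed; adjust to an experiment you actually trained).
import csv
import os

import numpy as np

log_dir, exp_id, env_name, seed = "./log", "origin", "mt10", 0
entry = "Running_Average_Rewards"   # --entry default in plot_csv.py
window = 25                         # --smooth_coeff default in plot_csv.py

csv_path = os.path.join(log_dir, exp_id, env_name, str(seed), "log.csv")
frames, scores = [], []
with open(csv_path) as f:
    for row in csv.DictReader(f):
        frames.append(int(row["Total Frames"]))
        scores.append(float(row[entry]))

# Trailing mean over the next `window` entries; the window shrinks naturally near the end.
smoothed = [float(np.mean(scores[i:i + window])) for i in range(len(scores))]
for frame, value in zip(frames[:3], smoothed[:3]):
    print("{:>12d} frames -> {:.5f}".format(frame, value))
```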