├── tianshou ├── data │ ├── buffer │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── base.cpython-37.pyc │ │ │ ├── prio.cpython-37.pyc │ │ │ ├── cached.cpython-37.pyc │ │ │ ├── manager.cpython-37.pyc │ │ │ ├── vecbuf.cpython-37.pyc │ │ │ └── __init__.cpython-37.pyc │ │ ├── vecbuf.py │ │ ├── cached.py │ │ └── prio.py │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── segtree.cpython-37.pyc │ │ │ └── converter.cpython-37.pyc │ │ └── segtree.py │ ├── __pycache__ │ │ ├── batch.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── collector.cpython-37.pyc │ │ ├── abstracter.cpython-37.pyc │ │ ├── interfaces.cpython-37.pyc │ │ ├── abstracter_adv.cpython-37.pyc │ │ ├── necsa_collector.cpython-37.pyc │ │ ├── necsa_adv_collector.cpython-37.pyc │ │ └── necsa_atari_collector.cpython-37.pyc │ └── __init__.py ├── utils │ ├── logger │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── base.cpython-37.pyc │ │ │ ├── wandb.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── tensorboard.cpython-37.pyc │ │ ├── tensorboard.py │ │ └── base.py │ ├── net │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── common.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── discrete.cpython-37.pyc │ │ │ └── continuous.cpython-37.pyc │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── warning.cpython-37.pyc │ │ ├── statistics.cpython-37.pyc │ │ ├── lr_scheduler.cpython-37.pyc │ │ └── progress_bar.cpython-37.pyc │ ├── warning.py │ ├── __init__.py │ ├── progress_bar.py │ ├── lr_scheduler.py │ └── statistics.py ├── policy │ ├── imitation │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── base.cpython-37.pyc │ │ │ ├── bcq.cpython-37.pyc │ │ │ ├── cql.cpython-37.pyc │ │ │ ├── gail.cpython-37.pyc │ │ │ ├── td3_bc.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── discrete_bcq.cpython-37.pyc │ │ │ ├── discrete_cql.cpython-37.pyc │ │ │ └── discrete_crr.cpython-37.pyc │ │ ├── base.py │ │ ├── discrete_cql.py │ │ ├── td3_bc.py │ │ ├── discrete_bcq.py │ │ └── discrete_crr.py │ ├── modelbased │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── icm.cpython-37.pyc │ │ │ ├── psrl.cpython-37.pyc │ │ │ └── __init__.cpython-37.pyc │ │ └── icm.py │ ├── modelfree │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── pg.cpython-37.pyc │ │ │ ├── a2c.cpython-37.pyc │ │ │ ├── bdq.cpython-37.pyc │ │ │ ├── c51.cpython-37.pyc │ │ │ ├── ddpg.cpython-37.pyc │ │ │ ├── dqn.cpython-37.pyc │ │ │ ├── fqf.cpython-37.pyc │ │ │ ├── iqn.cpython-37.pyc │ │ │ ├── npg.cpython-37.pyc │ │ │ ├── ppo.cpython-37.pyc │ │ │ ├── redq.cpython-37.pyc │ │ │ ├── sac.cpython-37.pyc │ │ │ ├── td3.cpython-37.pyc │ │ │ ├── trpo.cpython-37.pyc │ │ │ ├── qrdqn.cpython-37.pyc │ │ │ ├── rainbow.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── discrete_sac.cpython-37.pyc │ │ ├── rainbow.py │ │ ├── qrdqn.py │ │ ├── c51.py │ │ ├── iqn.py │ │ ├── pg.py │ │ └── td3.py │ ├── multiagent │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── mapolicy.cpython-37.pyc │ ├── __pycache__ │ │ ├── base.cpython-37.pyc │ │ ├── random.cpython-37.pyc │ │ └── __init__.cpython-37.pyc │ ├── random.py │ └── __init__.py ├── env │ ├── __pycache__ │ │ ├── utils.cpython-37.pyc │ │ ├── venvs.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── gym_wrappers.cpython-37.pyc │ │ ├── pettingzoo_env.cpython-37.pyc │ │ └── venv_wrappers.cpython-37.pyc │ ├── worker │ │ ├── __pycache__ │ │ │ ├── ray.cpython-37.pyc │ │ │ ├── base.cpython-37.pyc │ │ │ ├── dummy.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── 
subproc.cpython-37.pyc │ │ ├── __init__.py │ │ ├── dummy.py │ │ ├── ray.py │ │ └── base.py │ ├── utils.py │ ├── __init__.py │ ├── gym_wrappers.py │ ├── venv_wrappers.py │ └── pettingzoo_env.py ├── trainer │ ├── __pycache__ │ │ ├── base.cpython-37.pyc │ │ ├── utils.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── offline.cpython-37.pyc │ │ ├── onpolicy.cpython-37.pyc │ │ └── offpolicy.cpython-37.pyc │ ├── __init__.py │ ├── utils.py │ └── offline.py ├── exploration │ ├── __pycache__ │ │ ├── random.cpython-37.pyc │ │ └── __init__.cpython-37.pyc │ ├── __init__.py │ └── random.py └── __init__.py ├── .gitignore ├── rebuttal ├── Alien.png ├── Enduro.png ├── Qbert.png ├── Breakout.png ├── MsPacman.png ├── Alien-grid.png ├── Alien-scores.png ├── Alien-state.png ├── Alien-steps.png ├── multi-steps.png ├── Alien-backbone.png ├── Alien-epsilon.png ├── SpaceInvaders.png └── score_advantage.png ├── scripts ├── Ant-v3 │ ├── train_baseline.sh │ └── train_NECSA_TD3.sh ├── Reacher-v2 │ ├── train_baseline.sh │ └── train_NECSA_TD3.sh ├── Swimmer-v3 │ ├── train_baseline.sh │ └── train_NECSA_TD3.sh ├── Hopper-v3 │ ├── train_baseline.sh │ ├── train_NECSA_ADV_TD3.sh │ └── train_NECSA_TD3.sh ├── Walker2d-v3 │ ├── train_baseline.sh │ ├── train_NECSA_4_step_TD3.sh │ ├── train_NECSA_ADV_TD3.sh │ └── train_NECSA_TD3.sh ├── Humanoid-v3 │ ├── train_baseline.sh │ └── train_NECSA_TD3.sh ├── HalfCheetah-v3 │ ├── train_baseline.sh │ ├── train_NECSA_TD3.sh │ └── train_NECSA_DDPG.sh ├── InvertedPendulum-v2 │ ├── train_baseline.sh │ └── train_NECSA_TD3.sh ├── Alien │ ├── train_baseline.sh │ └── train_NECSA.sh ├── Qbert │ ├── train_baseline.sh │ └── train_NECSA.sh ├── Enduro │ ├── train_baseline.sh │ └── train_NECSA.sh ├── InvertedDoublePendulum-v2 │ ├── train_baseline.sh │ └── train_NECSA_TD3.sh ├── Breakout │ ├── train_baseline.sh │ └── train_NECSA.sh ├── MsPacman │ ├── train_baseline.sh │ └── train_NECSA.sh └── SpaceInvaders │ ├── train_baseline.sh │ └── train_NECSA.sh ├── mujoco_env.py ├── requirements.txt └── README.md /tianshou/data/buffer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tianshou/data/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tianshou/utils/logger/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tianshou/utils/net/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tianshou/policy/imitation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tianshou/policy/modelbased/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tianshou/policy/modelfree/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tianshou/policy/multiagent/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | log/ 2 | __pycache__ 3 | nohup* -------------------------------------------------------------------------------- /rebuttal/Alien.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Alien.png -------------------------------------------------------------------------------- /rebuttal/Enduro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Enduro.png -------------------------------------------------------------------------------- /rebuttal/Qbert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Qbert.png -------------------------------------------------------------------------------- /rebuttal/Breakout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Breakout.png -------------------------------------------------------------------------------- /rebuttal/MsPacman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/MsPacman.png -------------------------------------------------------------------------------- /rebuttal/Alien-grid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Alien-grid.png -------------------------------------------------------------------------------- /rebuttal/Alien-scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Alien-scores.png -------------------------------------------------------------------------------- /rebuttal/Alien-state.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Alien-state.png -------------------------------------------------------------------------------- /rebuttal/Alien-steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Alien-steps.png -------------------------------------------------------------------------------- /rebuttal/multi-steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/multi-steps.png -------------------------------------------------------------------------------- /rebuttal/Alien-backbone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Alien-backbone.png -------------------------------------------------------------------------------- /rebuttal/Alien-epsilon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/Alien-epsilon.png 
-------------------------------------------------------------------------------- /rebuttal/SpaceInvaders.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/SpaceInvaders.png -------------------------------------------------------------------------------- /rebuttal/score_advantage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhuo-1994/NECSA/HEAD/rebuttal/score_advantage.png --------------------------------------------------------------------------------
/tianshou/exploration/__init__.py: -------------------------------------------------------------------------------- 1 | from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise 2 | 3 | __all__ = [ 4 | "BaseNoise", 5 | "GaussianNoise", 6 | "OUNoise", 7 | ] 8 | --------------------------------------------------------------------------------
/tianshou/utils/warning.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter("once", DeprecationWarning) 4 | 5 | 6 | def deprecation(msg: str) -> None: 7 | """Deprecation warning wrapper.""" 8 | warnings.warn(msg, category=DeprecationWarning, stacklevel=2) 9 | --------------------------------------------------------------------------------
/tianshou/__init__.py: -------------------------------------------------------------------------------- 1 | from tianshou import data, env, exploration, policy, trainer, utils 2 | 3 | __version__ = "0.4.9" 4 | 5 | __all__ = [ 6 | "env", 7 | "data", 8 | "utils", 9 | "policy", 10 | "trainer", 11 | "exploration", 12 | ] 13 | --------------------------------------------------------------------------------
/tianshou/env/worker/__init__.py: -------------------------------------------------------------------------------- 1 | from tianshou.env.worker.base import EnvWorker 2 | from tianshou.env.worker.dummy import DummyEnvWorker 3 | from tianshou.env.worker.ray import RayEnvWorker 4 | from tianshou.env.worker.subproc import SubprocEnvWorker 5 | 6 | __all__ = [ 7 | "EnvWorker", 8 | "DummyEnvWorker", 9 | "SubprocEnvWorker", 10 | "RayEnvWorker", 11 | ] 12 | --------------------------------------------------------------------------------
/tianshou/env/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import cloudpickle 4 | 5 | 6 | class CloudpickleWrapper(object): 7 | """A cloudpickle wrapper used in SubprocVectorEnv.""" 8 | 9 | def __init__(self, data: Any) -> None: 10 | self.data = data 11 | 12 | def __getstate__(self) -> str: 13 | return cloudpickle.dumps(self.data) 14 | 15 | def __setstate__(self, data: str) -> None: 16 | self.data = cloudpickle.loads(data) 17 | --------------------------------------------------------------------------------
/scripts/Ant-v3/train_baseline.sh: -------------------------------------------------------------------------------- 1 | # python ddpg.py --task Ant-v3 --epoch 1000 2 | # pkill -f ddpg 3 | # python ddpg.py --task Ant-v3 --epoch 1000 4 | # pkill -f ddpg 5 | # python ddpg.py --task Ant-v3 --epoch 1000 6 | # pkill -f ddpg 7 | # python ddpg.py --task Ant-v3 --epoch 1000 8 | # pkill -f ddpg 9 | # python ddpg.py --task Ant-v3 --epoch 1000 10 | # pkill -f ddpg 11 | 12 | 13 | 14 | python td3.py --task Ant-v3 --epoch 1000 15 | pkill -f
td3 16 | python td3.py --task Ant-v3 --epoch 1000 17 | pkill -f td3 18 | python td3.py --task Ant-v3 --epoch 1000 19 | pkill -f td3 20 | python td3.py --task Ant-v3 --epoch 1000 21 | pkill -f td3 22 | python td3.py --task Ant-v3 --epoch 1000 23 | pkill -f td3 -------------------------------------------------------------------------------- /scripts/Reacher-v2/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python ddpg.py --task Reacher-v2 --epoch 100 2 | pkill -f ddpg 3 | python ddpg.py --task Reacher-v2 --epoch 100 4 | pkill -f ddpg 5 | python ddpg.py --task Reacher-v2 --epoch 100 6 | pkill -f ddpg 7 | python ddpg.py --task Reacher-v2 --epoch 100 8 | pkill -f ddpg 9 | python ddpg.py --task Reacher-v2 --epoch 100 10 | pkill -f ddpg 11 | 12 | 13 | 14 | python td3.py --task Reacher-v2 --epoch 100 15 | pkill -f td3 16 | python td3.py --task Reacher-v2 --epoch 100 17 | pkill -f td3 18 | python td3.py --task Reacher-v2 --epoch 100 19 | pkill -f td3 20 | python td3.py --task Reacher-v2 --epoch 100 21 | pkill -f td3 22 | python td3.py --task Reacher-v2 --epoch 100 23 | pkill -f td3 -------------------------------------------------------------------------------- /scripts/Swimmer-v3/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python ddpg.py --task Swimmer-v3 --epoch 400 2 | pkill -f ddpg 3 | python ddpg.py --task Swimmer-v3 --epoch 400 4 | pkill -f ddpg 5 | python ddpg.py --task Swimmer-v3 --epoch 400 6 | pkill -f ddpg 7 | python ddpg.py --task Swimmer-v3 --epoch 400 8 | pkill -f ddpg 9 | python ddpg.py --task Swimmer-v3 --epoch 400 10 | pkill -f ddpg 11 | 12 | 13 | python td3.py --task Swimmer-v3 --epoch 400 14 | pkill -f td3 15 | python td3.py --task Swimmer-v3 --epoch 400 16 | pkill -f td3 17 | python td3.py --task Swimmer-v3 --epoch 400 18 | pkill -f td3 19 | python td3.py --task Swimmer-v3 --epoch 400 20 | pkill -f td3 21 | python td3.py --task Swimmer-v3 --epoch 400 22 | pkill -f td3 -------------------------------------------------------------------------------- /scripts/Hopper-v3/train_baseline.sh: -------------------------------------------------------------------------------- 1 | #python ddpg.py --task Hopper-v3 --epoch 400 2 | #pkill -f ddpg 3 | #python ddpg.py --task Hopper-v3 --epoch 400 4 | #pkill -f ddpg 5 | #python ddpg.py --task Hopper-v3 --epoch 400 6 | #pkill -f ddpg 7 | #python ddpg.py --task Hopper-v3 --epoch 400 8 | #pkill -f ddpg 9 | #python ddpg.py --task Hopper-v3 --epoch 400 10 | #pkill -f ddpg 11 | 12 | 13 | 14 | python td3.py --task Hopper-v3 --epoch 400 15 | pkill -f td3 16 | python td3.py --task Hopper-v3 --epoch 400 17 | pkill -f td3 18 | python td3.py --task Hopper-v3 --epoch 400 19 | pkill -f td3 20 | python td3.py --task Hopper-v3 --epoch 400 21 | pkill -f td3 22 | python td3.py --task Hopper-v3 --epoch 400 23 | pkill -f td3 -------------------------------------------------------------------------------- /scripts/Walker2d-v3/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python ddpg.py --task Walker2d-v3 --epoch 400 2 | pkill -f ddpg 3 | python ddpg.py --task Walker2d-v3 --epoch 400 4 | pkill -f ddpg 5 | python ddpg.py --task Walker2d-v3 --epoch 400 6 | pkill -f ddpg 7 | python ddpg.py --task Walker2d-v3 --epoch 400 8 | pkill -f ddpg 9 | python ddpg.py --task Walker2d-v3 --epoch 400 10 | pkill -f ddpg 11 | 12 | 13 | 14 | python td3.py --task Walker2d-v3 --epoch 400 15 | pkill -f td3 16 | python 
td3.py --task Walker2d-v3 --epoch 400 17 | pkill -f td3 18 | python td3.py --task Walker2d-v3 --epoch 400 19 | pkill -f td3 20 | python td3.py --task Walker2d-v3 --epoch 400 21 | pkill -f td3 22 | python td3.py --task Walker2d-v3 --epoch 400 23 | pkill -f td3 -------------------------------------------------------------------------------- /scripts/Humanoid-v3/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python ddpg.py --task Humanoid-v3 --epoch 1000 2 | pkill -f ddpg 3 | python ddpg.py --task Humanoid-v3 --epoch 1000 4 | pkill -f ddpg 5 | python ddpg.py --task Humanoid-v3 --epoch 1000 6 | pkill -f ddpg 7 | python ddpg.py --task Humanoid-v3 --epoch 1000 8 | pkill -f ddpg 9 | python ddpg.py --task Humanoid-v3 --epoch 1000 10 | pkill -f ddpg 11 | 12 | 13 | 14 | python td3.py --task Humanoid-v3 --epoch 1000 15 | pkill -f td3 16 | python td3.py --task Humanoid-v3 --epoch 1000 17 | pkill -f td3 18 | python td3.py --task Humanoid-v3 --epoch 1000 19 | pkill -f td3 20 | python td3.py --task Humanoid-v3 --epoch 1000 21 | pkill -f td3 22 | python td3.py --task Humanoid-v3 --epoch 1000 23 | pkill -f td3 -------------------------------------------------------------------------------- /scripts/HalfCheetah-v3/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python ddpg.py --task HalfCheetah-v3 --epoch 1000 2 | pkill -f ddpg 3 | python ddpg.py --task HalfCheetah-v3 --epoch 1000 4 | pkill -f ddpg 5 | python ddpg.py --task HalfCheetah-v3 --epoch 1000 6 | pkill -f ddpg 7 | python ddpg.py --task HalfCheetah-v3 --epoch 1000 8 | pkill -f ddpg 9 | python ddpg.py --task HalfCheetah-v3 --epoch 1000 10 | pkill -f ddpg 11 | 12 | 13 | python td3.py --task HalfCheetah-v3 --epoch 1000 14 | pkill -f td3 15 | python td3.py --task HalfCheetah-v3 --epoch 1000 16 | pkill -f td3 17 | python td3.py --task HalfCheetah-v3 --epoch 1000 18 | pkill -f td3 19 | python td3.py --task HalfCheetah-v3 --epoch 1000 20 | pkill -f td3 21 | python td3.py --task HalfCheetah-v3 --epoch 1000 22 | pkill -f td3 -------------------------------------------------------------------------------- /tianshou/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils package.""" 2 | 3 | from tianshou.utils.logger.base import BaseLogger, LazyLogger 4 | from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger 5 | from tianshou.utils.logger.wandb import WandbLogger 6 | from tianshou.utils.lr_scheduler import MultipleLRSchedulers 7 | from tianshou.utils.progress_bar import DummyTqdm, tqdm_config 8 | from tianshou.utils.statistics import MovAvg, RunningMeanStd 9 | from tianshou.utils.warning import deprecation 10 | 11 | __all__ = [ 12 | "MovAvg", 13 | "RunningMeanStd", 14 | "tqdm_config", 15 | "DummyTqdm", 16 | "BaseLogger", 17 | "TensorboardLogger", 18 | "BasicLogger", 19 | "LazyLogger", 20 | "WandbLogger", 21 | "deprecation", 22 | "MultipleLRSchedulers", 23 | ] 24 | -------------------------------------------------------------------------------- /scripts/InvertedPendulum-v2/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python ddpg.py --task InvertedPendulum-v2 --epoch 100 2 | pkill -f ddpg 3 | python ddpg.py --task InvertedPendulum-v2 --epoch 100 4 | pkill -f ddpg 5 | python ddpg.py --task InvertedPendulum-v2 --epoch 100 6 | pkill -f ddpg 7 | python ddpg.py --task InvertedPendulum-v2 --epoch 100 8 | pkill -f ddpg 9 
| python ddpg.py --task InvertedPendulum-v2 --epoch 100 10 | pkill -f ddpg 11 | 12 | 13 | 14 | python td3.py --task InvertedPendulum-v2 --epoch 100 15 | pkill -f td3 16 | python td3.py --task InvertedPendulum-v2 --epoch 100 17 | pkill -f td3 18 | python td3.py --task InvertedPendulum-v2 --epoch 100 19 | pkill -f td3 20 | python td3.py --task InvertedPendulum-v2 --epoch 100 21 | pkill -f td3 22 | python td3.py --task InvertedPendulum-v2 --epoch 100 23 | pkill -f td3 -------------------------------------------------------------------------------- /tianshou/env/__init__.py: -------------------------------------------------------------------------------- 1 | """Env package.""" 2 | 3 | from tianshou.env.gym_wrappers import ContinuousToDiscrete, MultiDiscreteToDiscrete 4 | from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper 5 | from tianshou.env.venvs import ( 6 | BaseVectorEnv, 7 | DummyVectorEnv, 8 | RayVectorEnv, 9 | ShmemVectorEnv, 10 | SubprocVectorEnv, 11 | ) 12 | 13 | try: 14 | from tianshou.env.pettingzoo_env import PettingZooEnv 15 | except ImportError: 16 | pass 17 | 18 | __all__ = [ 19 | "BaseVectorEnv", 20 | "DummyVectorEnv", 21 | "SubprocVectorEnv", 22 | "ShmemVectorEnv", 23 | "RayVectorEnv", 24 | "VectorEnvWrapper", 25 | "VectorEnvNormObs", 26 | "PettingZooEnv", 27 | "ContinuousToDiscrete", 28 | "MultiDiscreteToDiscrete", 29 | ] 30 | -------------------------------------------------------------------------------- /scripts/Alien/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python dqn.py --task AlienNoFrameskip-v4 --epoch 500 2 | killall -9 python 3 | python dqn.py --task AlienNoFrameskip-v4 --epoch 500 4 | killall -9 python 5 | python dqn.py --task AlienNoFrameskip-v4 --epoch 500 6 | killall -9 python 7 | python dqn.py --task AlienNoFrameskip-v4 --epoch 500 8 | killall -9 python 9 | python dqn.py --task AlienNoFrameskip-v4 --epoch 500 10 | killall -9 python 11 | 12 | 13 | python rainbow.py --task AlienNoFrameskip-v4 --epoch 500 14 | killall -9 python 15 | python rainbow.py --task AlienNoFrameskip-v4 --epoch 500 16 | killall -9 python 17 | python rainbow.py --task AlienNoFrameskip-v4 --epoch 500 18 | killall -9 python 19 | python rainbow.py --task AlienNoFrameskip-v4 --epoch 500 20 | killall -9 python 21 | python rainbow.py --task AlienNoFrameskip-v4 --epoch 500 22 | killall -9 python 23 | -------------------------------------------------------------------------------- /scripts/Qbert/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python dqn.py --task QbertNoFrameskip-v4 --epoch 500 2 | killall -9 python 3 | python dqn.py --task QbertNoFrameskip-v4 --epoch 500 4 | killall -9 python 5 | python dqn.py --task QbertNoFrameskip-v4 --epoch 500 6 | killall -9 python 7 | python dqn.py --task QbertNoFrameskip-v4 --epoch 500 8 | killall -9 python 9 | python dqn.py --task QbertNoFrameskip-v4 --epoch 500 10 | killall -9 python 11 | 12 | 13 | python rainbow.py --task QbertNoFrameskip-v4 --epoch 500 14 | killall -9 python 15 | python rainbow.py --task QbertNoFrameskip-v4 --epoch 500 16 | killall -9 python 17 | python rainbow.py --task QbertNoFrameskip-v4 --epoch 500 18 | killall -9 python 19 | python rainbow.py --task QbertNoFrameskip-v4 --epoch 500 20 | killall -9 python 21 | python rainbow.py --task QbertNoFrameskip-v4 --epoch 500 22 | killall -9 python 23 | -------------------------------------------------------------------------------- 
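Note on the Atari baseline scripts (Alien and Qbert above, Enduro, Breakout, MsPacman and SpaceInvaders below): each one simply repeats the same dqn.py or rainbow.py invocation five times with a process cleanup in between. A minimal, hypothetical Python runner is sketched here as an alternative to the copy-pasted shell lines; it is not part of the repository and assumes only the dqn.py / rainbow.py entry points and the --task / --epoch flags already used in these scripts.

import subprocess

# Hypothetical convenience runner (not part of the repository): reproduces the
# five repeated baseline runs per algorithm that the shell scripts perform.
GAMES = ["Alien", "Qbert", "Enduro", "Breakout", "MsPacman", "SpaceInvaders"]
SCRIPTS = ["dqn.py", "rainbow.py"]  # flags below are taken from the shell scripts
RUNS_PER_SETTING = 5

for game in GAMES:
    task = f"{game}NoFrameskip-v4"
    for script in SCRIPTS:
        for _ in range(RUNS_PER_SETTING):
            # Train one run; keep going even if a single run fails.
            subprocess.run(
                ["python", script, "--task", task, "--epoch", "500"],
                check=False,
            )
            # The shell scripts call `killall -9 python` between runs to clean
            # up stray workers; a narrower pattern is used here so the runner
            # does not kill itself.
            subprocess.run(["pkill", "-f", script], check=False)

--------------------------------------------------------------------------------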
/scripts/Enduro/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python dqn.py --task EnduroNoFrameskip-v4 --epoch 500 2 | killall -9 python 3 | python dqn.py --task EnduroNoFrameskip-v4 --epoch 500 4 | killall -9 python 5 | python dqn.py --task EnduroNoFrameskip-v4 --epoch 500 6 | killall -9 python 7 | python dqn.py --task EnduroNoFrameskip-v4 --epoch 500 8 | killall -9 python 9 | python dqn.py --task EnduroNoFrameskip-v4 --epoch 500 10 | killall -9 python 11 | 12 | python rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 13 | killall -9 python 14 | python rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 15 | killall -9 python 16 | python rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 17 | killall -9 python 18 | python rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 19 | killall -9 python 20 | python rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 21 | killall -9 python 22 | 23 | -------------------------------------------------------------------------------- /scripts/InvertedDoublePendulum-v2/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python ddpg.py --task InvertedDoublePendulum-v2 --epoch 100 2 | pkill -f ddpg 3 | python ddpg.py --task InvertedDoublePendulum-v2 --epoch 100 4 | pkill -f ddpg 5 | python ddpg.py --task InvertedDoublePendulum-v2 --epoch 100 6 | pkill -f ddpg 7 | python ddpg.py --task InvertedDoublePendulum-v2 --epoch 100 8 | pkill -f ddpg 9 | python ddpg.py --task InvertedDoublePendulum-v2 --epoch 100 10 | pkill -f ddpg 11 | 12 | 13 | 14 | 15 | python td3.py --task InvertedDoublePendulum-v2 --epoch 100 16 | pkill -f td3 17 | python td3.py --task InvertedDoublePendulum-v2 --epoch 100 18 | pkill -f td3 19 | python td3.py --task InvertedDoublePendulum-v2 --epoch 100 20 | pkill -f td3 21 | python td3.py --task InvertedDoublePendulum-v2 --epoch 100 22 | pkill -f td3 23 | python td3.py --task InvertedDoublePendulum-v2 --epoch 100 24 | pkill -f td3 -------------------------------------------------------------------------------- /scripts/Breakout/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python dqn.py --task BreakoutNoFrameskip-v4 --epoch 500 2 | killall -9 python 3 | python dqn.py --task BreakoutNoFrameskip-v4 --epoch 500 4 | killall -9 python 5 | python dqn.py --task BreakoutNoFrameskip-v4 --epoch 500 6 | killall -9 python 7 | python dqn.py --task BreakoutNoFrameskip-v4 --epoch 500 8 | killall -9 python 9 | python dqn.py --task BreakoutNoFrameskip-v4 --epoch 500 10 | killall -9 python 11 | 12 | 13 | python rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 14 | killall -9 python 15 | python rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 16 | killall -9 python 17 | python rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 18 | killall -9 python 19 | python rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 20 | killall -9 python 21 | python rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 22 | killall -9 python 23 | -------------------------------------------------------------------------------- /scripts/MsPacman/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python dqn.py --task MsPacmanNoFrameskip-v4 --epoch 500 2 | killall -9 python 3 | python dqn.py --task MsPacmanNoFrameskip-v4 --epoch 500 4 | killall -9 python 5 | python dqn.py --task MsPacmanNoFrameskip-v4 --epoch 500 6 | killall -9 python 7 | python dqn.py --task 
MsPacmanNoFrameskip-v4 --epoch 500 8 | killall -9 python 9 | python dqn.py --task MsPacmanNoFrameskip-v4 --epoch 500 10 | killall -9 python 11 | 12 | 13 | python rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 14 | killall -9 python 15 | python rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 16 | killall -9 python 17 | python rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 18 | killall -9 python 19 | python rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 20 | killall -9 python 21 | python rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 22 | killall -9 python 23 | -------------------------------------------------------------------------------- /scripts/SpaceInvaders/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python dqn.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 2 | killall -9 python 3 | python dqn.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 4 | killall -9 python 5 | python dqn.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 6 | killall -9 python 7 | python dqn.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 8 | killall -9 python 9 | python dqn.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 10 | killall -9 python 11 | 12 | python rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 13 | killall -9 python 14 | python rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 15 | killall -9 python 16 | python rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 17 | killall -9 python 18 | python rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 19 | killall -9 python 20 | python rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 21 | killall -9 python 22 | 23 | -------------------------------------------------------------------------------- /tianshou/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """Trainer package.""" 2 | 3 | from tianshou.trainer.base import BaseTrainer 4 | from tianshou.trainer.offline import ( 5 | OfflineTrainer, 6 | offline_trainer, 7 | offline_trainer_iter, 8 | ) 9 | from tianshou.trainer.offpolicy import ( 10 | OffpolicyTrainer, 11 | offpolicy_trainer, 12 | offpolicy_trainer_iter, 13 | ) 14 | from tianshou.trainer.onpolicy import ( 15 | OnpolicyTrainer, 16 | onpolicy_trainer, 17 | onpolicy_trainer_iter, 18 | ) 19 | from tianshou.trainer.utils import gather_info, test_episode 20 | 21 | __all__ = [ 22 | "BaseTrainer", 23 | "offpolicy_trainer", 24 | "offpolicy_trainer_iter", 25 | "OffpolicyTrainer", 26 | "onpolicy_trainer", 27 | "onpolicy_trainer_iter", 28 | "OnpolicyTrainer", 29 | "offline_trainer", 30 | "offline_trainer_iter", 31 | "OfflineTrainer", 32 | "test_episode", 33 | "gather_info", 34 | ] 35 | -------------------------------------------------------------------------------- /tianshou/utils/progress_bar.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | tqdm_config = { 4 | "dynamic_ncols": True, 5 | "ascii": True, 6 | } 7 | 8 | 9 | class DummyTqdm: 10 | """A dummy tqdm class that keeps stats but without progress bar. 11 | 12 | It supports ``__enter__`` and ``__exit__``, update and a dummy 13 | ``set_postfix``, which is the interface that trainers use. 14 | 15 | .. note:: 16 | 17 | Using ``disable=True`` in tqdm config results in infinite loop, thus 18 | this class is created. See the discussion at #641 for details. 
19 | """ 20 | 21 | def __init__(self, total: int, **kwargs: Any): 22 | self.total = total 23 | self.n = 0 24 | 25 | def set_postfix(self, **kwargs: Any) -> None: 26 | pass 27 | 28 | def update(self, n: int = 1) -> None: 29 | self.n += n 30 | 31 | def __enter__(self) -> "DummyTqdm": 32 | return self 33 | 34 | def __exit__(self, *args: Any, **kwargs: Any) -> None: 35 | pass 36 | -------------------------------------------------------------------------------- /scripts/Walker2d-v3/train_NECSA_4_step_TD3.sh: -------------------------------------------------------------------------------- 1 | 2 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 4 --grid_num 5 --epsilon 0.1 3 | pkill -f necsa 4 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 4 --grid_num 5 --epsilon 0.1 5 | pkill -f necsa 6 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 4 --grid_num 5 --epsilon 0.1 7 | pkill -f necsa 8 | 9 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 5 --grid_num 5 --epsilon 0.1 10 | pkill -f necsa 11 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 5 --grid_num 5 --epsilon 0.1 12 | pkill -f necsa 13 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 5 --grid_num 5 --epsilon 0.1 14 | pkill -f necsa 15 | 16 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 6 --grid_num 5 --epsilon 0.1 17 | pkill -f necsa 18 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 6 --grid_num 5 --epsilon 0.1 19 | pkill -f necsa 20 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 6 --grid_num 5 --epsilon 0.1 21 | pkill -f necsa -------------------------------------------------------------------------------- /tianshou/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Data package.""" 2 | # isort:skip_file 3 | 4 | from tianshou.data.batch import Batch 5 | from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as 6 | from tianshou.data.utils.segtree import SegmentTree 7 | from tianshou.data.buffer.base import ReplayBuffer 8 | from tianshou.data.buffer.prio import PrioritizedReplayBuffer 9 | from tianshou.data.buffer.manager import ( 10 | ReplayBufferManager, 11 | PrioritizedReplayBufferManager, 12 | ) 13 | from tianshou.data.buffer.vecbuf import ( 14 | VectorReplayBuffer, 15 | PrioritizedVectorReplayBuffer, 16 | ) 17 | from tianshou.data.buffer.cached import CachedReplayBuffer 18 | from tianshou.data.collector import Collector, AsyncCollector 19 | from tianshou.data.necsa_collector import NECSA_Collector 20 | from tianshou.data.necsa_atari_collector import NECSA_Atari_Collector 21 | from tianshou.data.necsa_adv_collector import NECSA_Adv_Collector 22 | 23 | __all__ = [ 24 | "Batch", 25 | "to_numpy", 26 | "to_torch", 27 | "to_torch_as", 28 | "SegmentTree", 29 | "ReplayBuffer", 30 | "PrioritizedReplayBuffer", 31 | "ReplayBufferManager", 32 | "PrioritizedReplayBufferManager", 33 | "VectorReplayBuffer", 34 | "PrioritizedVectorReplayBuffer", 35 | "CachedReplayBuffer", 36 | "Collector", 37 | "AsyncCollector", 38 | ] 39 | -------------------------------------------------------------------------------- /mujoco_env.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import gym 4 | 5 | from tianshou.env import ShmemVectorEnv, VectorEnvNormObs 6 | 7 | try: 8 | import envpool 9 | except ImportError: 10 | envpool = None 11 | 12 | 13 | def make_mujoco_env(task, seed, training_num, test_num, obs_norm): 14 | """Wrapper function 
for Mujoco env. 15 | 16 | If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env. 17 | 18 | :return: a tuple of (single env, training envs, test envs). 19 | """ 20 | if envpool is not None: 21 | train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed) 22 | test_envs = envpool.make_gym(task, num_envs=test_num, seed=seed) 23 | else: 24 | warnings.warn( 25 | "Recommend using envpool (pip install envpool) " 26 | "to run Mujoco environments more efficiently." 27 | ) 28 | env = gym.make(task) 29 | train_envs = ShmemVectorEnv( 30 | [lambda: gym.make(task) for _ in range(training_num)] 31 | ) 32 | test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(test_num)]) 33 | env.seed(seed) 34 | train_envs.seed(seed) 35 | test_envs.seed(seed) 36 | if obs_norm: 37 | # obs norm wrapper 38 | train_envs = VectorEnvNormObs(train_envs) 39 | test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) 40 | test_envs.set_obs_rms(train_envs.get_obs_rms()) 41 | return env, train_envs, test_envs 42 | -------------------------------------------------------------------------------- /tianshou/policy/random.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | import numpy as np 4 | 5 | from tianshou.data import Batch 6 | from tianshou.policy import BasePolicy 7 | 8 | 9 | class RandomPolicy(BasePolicy): 10 | """A random agent used in multi-agent learning. 11 | 12 | It randomly chooses an action from the legal actions. 13 | """ 14 | 15 | def forward( 16 | self, 17 | batch: Batch, 18 | state: Optional[Union[dict, Batch, np.ndarray]] = None, 19 | **kwargs: Any, 20 | ) -> Batch: 21 | """Compute the random action over the given batch data. 22 | 23 | The input should contain a mask in batch.obs, where "True" marks an 24 | available action and "False" an unavailable one. For example, 25 | ``batch.obs.mask == np.array([[False, True, False]])`` means that, with batch 26 | size 1, action "1" is available while actions "0" and "2" are unavailable. 27 | 28 | :return: A :class:`~tianshou.data.Batch` with "act" key, containing 29 | the random action. 30 | 31 | .. seealso:: 32 | 33 | Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for 34 | a more detailed explanation. 35 | """ 36 | mask = batch.obs.mask 37 | logits = np.random.rand(*mask.shape) 38 | logits[~mask] = -np.inf 39 | return Batch(act=logits.argmax(axis=-1)) 40 | 41 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 42 | """Since a random agent learns nothing, it returns an empty dict.""" 43 | return {} 44 | -------------------------------------------------------------------------------- /tianshou/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | 5 | 6 | class MultipleLRSchedulers: 7 | """A wrapper for multiple learning rate schedulers. 8 | 9 | Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called, 10 | it calls the step() method of each of the schedulers that it contains.
11 | Example usage: 12 | :: 13 | 14 | scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2) 15 | scheduler2 = ExponentialLR(opt2, gamma=0.9) 16 | scheduler = MultipleLRSchedulers(scheduler1, scheduler2) 17 | policy = PPOPolicy(..., lr_scheduler=scheduler) 18 | """ 19 | 20 | def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR): 21 | self.schedulers = args 22 | 23 | def step(self) -> None: 24 | """Take a step in each of the learning rate schedulers.""" 25 | for scheduler in self.schedulers: 26 | scheduler.step() 27 | 28 | def state_dict(self) -> List[Dict]: 29 | """Get state_dict for each of the learning rate schedulers. 30 | 31 | :return: A list of state_dict of learning rate schedulers. 32 | """ 33 | return [s.state_dict() for s in self.schedulers] 34 | 35 | def load_state_dict(self, state_dict: List[Dict]) -> None: 36 | """Load states from state_dict. 37 | 38 | :param List[Dict] state_dict: A list of learning rate scheduler 39 | state_dict, in the same order as the schedulers. 40 | """ 41 | for (s, sd) in zip(self.schedulers, state_dict): 42 | s.__dict__.update(sd) 43 | -------------------------------------------------------------------------------- /scripts/Hopper-v3/train_NECSA_ADV_TD3.sh: -------------------------------------------------------------------------------- 1 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 2 | pkill -f necsa 3 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 4 | pkill -f necsa 5 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 6 | pkill -f necsa 7 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 8 | pkill -f necsa 9 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 10 | pkill -f necsa 11 | 12 | 13 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 14 | pkill -f necsa 15 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 16 | pkill -f necsa 17 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 18 | pkill -f necsa 19 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 20 | pkill -f necsa 21 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 22 | pkill -f necsa 23 | 24 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.1 25 | pkill -f necsa 26 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.1 27 | pkill -f necsa 28 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.1 29 | pkill -f necsa 30 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.1 31 | pkill -f necsa 32 | python necsa_adv_td3.py --task Hopper-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.1 33 | pkill -f necsa -------------------------------------------------------------------------------- /scripts/Walker2d-v3/train_NECSA_ADV_TD3.sh: -------------------------------------------------------------------------------- 1 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 2 | pkill -f necsa 3 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 4 | pkill -f necsa 5 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 6 | 
pkill -f necsa 7 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 8 | pkill -f necsa 9 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 10 | pkill -f necsa 11 | 12 | 13 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 14 | pkill -f necsa 15 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 16 | pkill -f necsa 17 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 18 | pkill -f necsa 19 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 20 | pkill -f necsa 21 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 22 | pkill -f necsa 23 | 24 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 25 | pkill -f necsa 26 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 27 | pkill -f necsa 28 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 29 | pkill -f necsa 30 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 31 | pkill -f necsa 32 | python necsa_adv_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 33 | pkill -f necsa -------------------------------------------------------------------------------- /tianshou/env/worker/dummy.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional, Tuple, Union 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from tianshou.env.worker import EnvWorker 7 | 8 | 9 | class DummyEnvWorker(EnvWorker): 10 | """Dummy worker used in sequential vector environments.""" 11 | 12 | def __init__(self, env_fn: Callable[[], gym.Env]) -> None: 13 | self.env = env_fn() 14 | super().__init__(env_fn) 15 | 16 | def get_env_attr(self, key: str) -> Any: 17 | return getattr(self.env, key) 18 | 19 | def set_env_attr(self, key: str, value: Any) -> None: 20 | setattr(self.env.unwrapped, key, value) 21 | 22 | def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: 23 | if "seed" in kwargs: 24 | super().seed(kwargs["seed"]) 25 | return self.env.reset(**kwargs) 26 | 27 | @staticmethod 28 | def wait( # type: ignore 29 | workers: List["DummyEnvWorker"], wait_num: int, timeout: Optional[float] = None 30 | ) -> List["DummyEnvWorker"]: 31 | # Sequential EnvWorker objects are always ready 32 | return workers 33 | 34 | def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: 35 | if action is None: 36 | self.result = self.env.reset(**kwargs) 37 | else: 38 | self.result = self.env.step(action) # type: ignore 39 | 40 | def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: 41 | super().seed(seed) 42 | try: 43 | return self.env.seed(seed) 44 | except NotImplementedError: 45 | self.env.reset(seed=seed) 46 | return [seed] # type: ignore 47 | 48 | def render(self, **kwargs: Any) -> Any: 49 | return self.env.render(**kwargs) 50 | 51 | def close_env(self) -> None: 52 | self.env.close() 53 | -------------------------------------------------------------------------------- /tianshou/policy/modelfree/rainbow.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from tianshou.data import Batch 4 | from tianshou.policy import C51Policy 
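# sample_noise(model) re-draws the noise of all NoisyLinear layers in `model` and
# returns True if at least one such layer exists; learn() below calls it on the
# online network and, when a target network is used, on model_old as well.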
5 | from tianshou.utils.net.discrete import sample_noise 6 | 7 | 8 | class RainbowPolicy(C51Policy): 9 | """Implementation of Rainbow DQN. arXiv:1710.02298. 10 | 11 | :param torch.nn.Module model: a model following the rules in 12 | :class:`~tianshou.policy.BasePolicy`. (s -> logits) 13 | :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. 14 | :param float discount_factor: in [0, 1]. 15 | :param int num_atoms: the number of atoms in the support set of the 16 | value distribution. Default to 51. 17 | :param float v_min: the value of the smallest atom in the support set. 18 | Default to -10.0. 19 | :param float v_max: the value of the largest atom in the support set. 20 | Default to 10.0. 21 | :param int estimation_step: the number of steps to look ahead. Default to 1. 22 | :param int target_update_freq: the target network update frequency (0 if 23 | you do not use the target network). Default to 0. 24 | :param bool reward_normalization: normalize the reward to Normal(0, 1). 25 | Default to False. 26 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 27 | optimizer in each policy.update(). Default to None (no lr_scheduler). 28 | 29 | .. seealso:: 30 | 31 | Please refer to :class:`~tianshou.policy.C51Policy` for more detailed 32 | explanation. 33 | """ 34 | 35 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 36 | sample_noise(self.model) 37 | if self._target and sample_noise(self.model_old): 38 | self.model_old.train() # so that NoisyLinear takes effect 39 | return super().learn(batch, **kwargs) 40 | -------------------------------------------------------------------------------- /scripts/Hopper-v3/train_NECSA_TD3.sh: -------------------------------------------------------------------------------- 1 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 --mode state_action 2 | pkill -f necsa 3 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 --mode state_action 4 | pkill -f necsa 5 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 --mode state_action 6 | pkill -f necsa 7 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 --mode state_action 8 | pkill -f necsa 9 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.1 --mode state_action 10 | pkill -f necsa 11 | 12 | 13 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 --mode state_action 14 | pkill -f necsa 15 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 --mode state_action 16 | pkill -f necsa 17 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 --mode state_action 18 | pkill -f necsa 19 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 --mode state_action 20 | pkill -f necsa 21 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.1 --mode state_action 22 | pkill -f necsa 23 | 24 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.1 --mode state_action 25 | pkill -f necsa 26 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.1 --mode state_action 27 | pkill -f necsa 28 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.1 --mode state_action 29 | pkill -f necsa 30 | python necsa_td3.py --task Hopper-v3 --epoch 400 
--step 3 --grid_num 5 --epsilon 0.1 --mode state_action 31 | pkill -f necsa 32 | python necsa_td3.py --task Hopper-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.1 --mode state_action 33 | pkill -f necsa -------------------------------------------------------------------------------- /scripts/Walker2d-v3/train_NECSA_TD3.sh: -------------------------------------------------------------------------------- 1 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 --mode state_action 2 | pkill -f necsa 3 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 --mode state_action 4 | pkill -f necsa 5 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 --mode state_action 6 | pkill -f necsa 7 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 --mode state_action 8 | pkill -f necsa 9 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 1 --grid_num 5 --epsilon 0.2 --mode state_action 10 | pkill -f necsa 11 | 12 | 13 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 --mode state_action 14 | pkill -f necsa 15 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 --mode state_action 16 | pkill -f necsa 17 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 --mode state_action 18 | pkill -f necsa 19 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 --mode state_action 20 | pkill -f necsa 21 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 2 --grid_num 5 --epsilon 0.2 --mode state_action 22 | pkill -f necsa 23 | 24 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 --mode state_action 25 | pkill -f necsa 26 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 --mode state_action 27 | pkill -f necsa 28 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 --mode state_action 29 | pkill -f necsa 30 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 --mode state_action 31 | pkill -f necsa 32 | python necsa_td3.py --task Walker2d-v3 --epoch 400 --step 3 --grid_num 5 --epsilon 0.2 --mode state_action 33 | pkill -f necsa -------------------------------------------------------------------------------- /scripts/Alien/train_NECSA.sh: -------------------------------------------------------------------------------- 1 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 2 | killall -9 python 3 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 4 | killall -9 python 5 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 6 | killall -9 python 7 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 8 | killall -9 python 9 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 10 | killall -9 python 11 | 12 | 13 | 14 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 15 | killall -9 python 16 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 17 | killall -9 python 18 | 
python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 19 | killall -9 python 20 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 21 | killall -9 python 22 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 23 | killall -9 python 24 | 25 | 26 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 27 | killall -9 python 28 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 29 | killall -9 python 30 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 31 | killall -9 python 32 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 33 | killall -9 python 34 | python necsa_dqn.py --task AlienNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 35 | killall -9 python 36 | 37 | -------------------------------------------------------------------------------- /scripts/Qbert/train_NECSA.sh: -------------------------------------------------------------------------------- 1 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 2 | killall -9 python 3 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 4 | killall -9 python 5 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 6 | killall -9 python 7 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 8 | killall -9 python 9 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 10 | killall -9 python 11 | 12 | 13 | 14 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 15 | killall -9 python 16 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 17 | killall -9 python 18 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 19 | killall -9 python 20 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 21 | killall -9 python 22 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 23 | killall -9 python 24 | 25 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 26 | killall -9 python 27 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 28 | killall -9 python 29 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 30 | killall -9 python 31 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 32 | killall -9 python 33 | python necsa_rainbow.py --task QbertNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 34 | killall -9 python 35 | -------------------------------------------------------------------------------- /scripts/InvertedPendulum-v2/train_NECSA_TD3.sh: 
-------------------------------------------------------------------------------- 1 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 2 | pkill -f necsa 3 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 4 | pkill -f necsa 5 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 6 | pkill -f necsa 7 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 8 | pkill -f necsa 9 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 10 | pkill -f necsa 11 | 12 | 13 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 14 | pkill -f necsa 15 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 16 | pkill -f necsa 17 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 18 | pkill -f necsa 19 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 20 | pkill -f necsa 21 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 22 | pkill -f necsa 23 | 24 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 25 | pkill -f necsa 26 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 27 | pkill -f necsa 28 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 29 | pkill -f necsa 30 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 31 | pkill -f necsa 32 | python necsa_td3.py --task InvertedPendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 33 | pkill -f necsa -------------------------------------------------------------------------------- /scripts/Enduro/train_NECSA.sh: -------------------------------------------------------------------------------- 1 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 2 | killall -9 python 3 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 4 | killall -9 python 5 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 6 | killall -9 python 7 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 8 | killall -9 python 9 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 10 | killall -9 python 11 | 12 | 13 | 14 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 15 | killall -9 python 16 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 17 | killall -9 python 18 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 19 | killall -9 python 20 | python 
necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 21 | killall -9 python 22 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 23 | killall -9 python 24 | 25 | 26 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 27 | killall -9 python 28 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 29 | killall -9 python 30 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 31 | killall -9 python 32 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 33 | killall -9 python 34 | python necsa_rainbow.py --task EnduroNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 35 | killall -9 python 36 | 37 | -------------------------------------------------------------------------------- /scripts/MsPacman/train_NECSA.sh: -------------------------------------------------------------------------------- 1 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 2 | killall -9 python 3 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 4 | killall -9 python 5 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 6 | killall -9 python 7 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 8 | killall -9 python 9 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 10 | killall -9 python 11 | 12 | 13 | 14 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 15 | killall -9 python 16 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 17 | killall -9 python 18 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 19 | killall -9 python 20 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 21 | killall -9 python 22 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 23 | killall -9 python 24 | 25 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 26 | killall -9 python 27 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 28 | killall -9 python 29 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 30 | killall -9 python 31 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 32 | killall -9 python 33 | python necsa_rainbow.py --task MsPacmanNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 34 | killall -9 python 35 | -------------------------------------------------------------------------------- /scripts/Breakout/train_NECSA.sh: -------------------------------------------------------------------------------- 1 | 
python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 2 | killall -9 python 3 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 4 | killall -9 python 5 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 6 | killall -9 python 7 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 8 | killall -9 python 9 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 10 | killall -9 python 11 | 12 | 13 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 14 | killall -9 python 15 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 16 | killall -9 python 17 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 18 | killall -9 python 19 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 20 | killall -9 python 21 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 22 | killall -9 python 23 | 24 | 25 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 26 | killall -9 python 27 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 28 | killall -9 python 29 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 30 | killall -9 python 31 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 32 | killall -9 python 33 | python necsa_rainbow.py --task BreakoutNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 34 | killall -9 python 35 | 36 | -------------------------------------------------------------------------------- /scripts/InvertedDoublePendulum-v2/train_NECSA_TD3.sh: -------------------------------------------------------------------------------- 1 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 2 | pkill -f necsa 3 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 4 | pkill -f necsa 5 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 6 | pkill -f necsa 7 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 8 | pkill -f necsa 9 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 1 --grid_num 10 --epsilon 0.1 --mode state_action 10 | pkill -f necsa 11 | 12 | 13 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 14 | pkill -f necsa 15 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 16 | pkill -f necsa 17 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 18 | pkill -f 
necsa 19 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 20 | pkill -f necsa 21 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 2 --grid_num 10 --epsilon 0.1 --mode state_action 22 | pkill -f necsa 23 | 24 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 25 | pkill -f necsa 26 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 27 | pkill -f necsa 28 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 29 | pkill -f necsa 30 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 31 | pkill -f necsa 32 | python necsa_td3.py --task InvertedDoublePendulum-v2 --epoch 100 --step 3 --grid_num 10 --epsilon 0.1 --mode state_action 33 | pkill -f necsa -------------------------------------------------------------------------------- /scripts/SpaceInvaders/train_NECSA.sh: -------------------------------------------------------------------------------- 1 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 2 | killall -9 python 3 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 4 | killall -9 python 5 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 6 | killall -9 python 7 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 8 | killall -9 python 9 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 1 --epsilon 0.1 --mode hidden --reduction 10 | killall -9 python 11 | 12 | 13 | 14 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 15 | killall -9 python 16 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 17 | killall -9 python 18 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 19 | killall -9 python 20 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 21 | killall -9 python 22 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 2 --epsilon 0.1 --mode hidden --reduction 23 | killall -9 python 24 | 25 | 26 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 27 | killall -9 python 28 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 29 | killall -9 python 30 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 31 | killall -9 python 32 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 33 | killall -9 python 34 | python necsa_rainbow.py --task SpaceInvadersNoFrameskip-v4 --epoch 500 --step 3 --epsilon 0.1 --mode hidden --reduction 35 | killall -9 python 36 | 37 | 
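Each NECSA Atari script above repeats one configuration five times and sweeps ``--step`` over 1-3, killing stray Python processes between runs. The snippet below is only a compact sketch of that sweep (same flags as the SpaceInvaders script above; the Python driver itself is not part of the original scripts)::

    import itertools
    import subprocess

    # Five repetitions per --step setting, mirroring the shell script above.
    for step, _run in itertools.product((1, 2, 3), range(5)):
        subprocess.run(
            [
                "python", "necsa_rainbow.py",
                "--task", "SpaceInvadersNoFrameskip-v4",
                "--epoch", "500", "--step", str(step),
                "--epsilon", "0.1", "--mode", "hidden", "--reduction",
            ],
            check=False,  # like the shell scripts, continue even if one run fails
        )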
-------------------------------------------------------------------------------- /scripts/Reacher-v2/train_NECSA_TD3.sh: -------------------------------------------------------------------------------- 1 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 1 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 2 | pkill -f necsa 3 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 1 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 4 | pkill -f necsa 5 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 1 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 6 | pkill -f necsa 7 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 1 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 8 | pkill -f necsa 9 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 1 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 10 | pkill -f necsa 11 | 12 | 13 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 2 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 14 | pkill -f necsa 15 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 2 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 16 | pkill -f necsa 17 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 2 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 18 | pkill -f necsa 19 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 2 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 20 | pkill -f necsa 21 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 2 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 22 | pkill -f necsa 23 | 24 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 3 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 25 | pkill -f necsa 26 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 3 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 27 | pkill -f necsa 28 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 3 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 29 | pkill -f necsa 30 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 3 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 31 | pkill -f necsa 32 | python necsa_td3.py --task Reacher-v2 --epoch 100 --step 3 --grid_num 10 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action 33 | pkill -f necsa -------------------------------------------------------------------------------- /scripts/Swimmer-v3/train_NECSA_TD3.sh: -------------------------------------------------------------------------------- 1 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 1 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 2 | pkill -f necsa 3 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 1 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 4 | pkill -f necsa 5 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 1 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 6 | pkill -f necsa 7 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 1 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 8 | pkill -f necsa 9 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 1 --grid_num 10 --epsilon 0.1 
--state_min -6 --state_max 6 --mode state_action 10 | pkill -f necsa 11 | 12 | 13 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 2 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 14 | pkill -f necsa 15 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 2 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 16 | pkill -f necsa 17 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 2 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 18 | pkill -f necsa 19 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 2 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 20 | pkill -f necsa 21 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 2 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 22 | pkill -f necsa 23 | 24 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 3 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 25 | pkill -f necsa 26 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 3 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 27 | pkill -f necsa 28 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 3 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 29 | pkill -f necsa 30 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 3 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 31 | pkill -f necsa 32 | python necsa_td3.py --task Swimmer-v3 --epoch 400 --step 3 --grid_num 10 --epsilon 0.1 --state_min -6 --state_max 6 --mode state_action 33 | pkill -f necsa -------------------------------------------------------------------------------- /tianshou/env/gym_wrappers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import gym 4 | import numpy as np 5 | 6 | 7 | class ContinuousToDiscrete(gym.ActionWrapper): 8 | """Gym environment wrapper to take discrete action in a continuous environment. 9 | 10 | :param gym.Env env: gym environment with continuous action space. 11 | :param int action_per_dim: number of discrete actions in each dimension 12 | of the action space. 13 | """ 14 | 15 | def __init__(self, env: gym.Env, action_per_dim: Union[int, List[int]]) -> None: 16 | super().__init__(env) 17 | assert isinstance(env.action_space, gym.spaces.Box) 18 | low, high = env.action_space.low, env.action_space.high 19 | if isinstance(action_per_dim, int): 20 | action_per_dim = [action_per_dim] * env.action_space.shape[0] 21 | assert len(action_per_dim) == env.action_space.shape[0] 22 | self.action_space = gym.spaces.MultiDiscrete(action_per_dim) 23 | self.mesh = np.array( 24 | [np.linspace(lo, hi, a) for lo, hi, a in zip(low, high, action_per_dim)], 25 | dtype=object 26 | ) 27 | 28 | def action(self, act: np.ndarray) -> np.ndarray: 29 | # modify act 30 | assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}." 31 | if len(act.shape) == 1: 32 | return np.array([self.mesh[i][a] for i, a in enumerate(act)]) 33 | return np.array([[self.mesh[i][a] for i, a in enumerate(a_)] for a_ in act]) 34 | 35 | 36 | class MultiDiscreteToDiscrete(gym.ActionWrapper): 37 | """Gym environment wrapper to take discrete action in multidiscrete environment. 38 | 39 | :param gym.Env env: gym environment with multidiscrete action space. 
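For example, wrapping an environment whose action space is ``MultiDiscrete([3, 4])`` yields ``Discrete(12)``, and the flat action ``7`` decodes back to ``[1, 3]``::

    wrapped = MultiDiscreteToDiscrete(env)  # env.action_space == MultiDiscrete([3, 4])
    wrapped.action(np.array(7))             # -> array([1, 3])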
40 | """ 41 | 42 | def __init__(self, env: gym.Env) -> None: 43 | super().__init__(env) 44 | assert isinstance(env.action_space, gym.spaces.MultiDiscrete) 45 | nvec = env.action_space.nvec 46 | assert nvec.ndim == 1 47 | self.bases = np.ones_like(nvec) 48 | for i in range(1, len(self.bases)): 49 | self.bases[i] = self.bases[i - 1] * nvec[-i] 50 | self.action_space = gym.spaces.Discrete(np.prod(nvec)) 51 | 52 | def action(self, act: np.ndarray) -> np.ndarray: 53 | converted_act = [] 54 | for b in np.flip(self.bases): 55 | converted_act.append(act // b) 56 | act = act % b 57 | return np.array(converted_act).transpose() 58 | -------------------------------------------------------------------------------- /scripts/HalfCheetah-v3/train_NECSA_TD3.sh: -------------------------------------------------------------------------------- 1 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 2 | # pkill -f necsa 3 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 4 | # pkill -f necsa 5 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 6 | # pkill -f necsa 7 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 8 | # pkill -f necsa 9 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 10 | # pkill -f necsa 11 | 12 | 13 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 14 | # pkill -f necsa 15 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 16 | # pkill -f necsa 17 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 18 | # pkill -f necsa 19 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 20 | # pkill -f necsa 21 | # python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 22 | # pkill -f necsa 23 | 24 | python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.4 --mode state_action 25 | pkill -f necsa 26 | python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.4 --mode state_action 27 | pkill -f necsa 28 | python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.4 --mode state_action 29 | pkill -f necsa 30 | python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.4 --mode state_action 31 | pkill -f necsa 32 | python necsa_td3.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.4 --mode state_action 33 | pkill -f necsa 34 | 35 | -------------------------------------------------------------------------------- /tianshou/policy/imitation/base.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from tianshou.data import Batch, to_torch 8 | from tianshou.policy import BasePolicy 9 | 10 | 11 | class ImitationPolicy(BasePolicy): 12 | """Implementation of vanilla imitation learning. 13 | 14 | :param torch.nn.Module model: a model following the rules in 15 | :class:`~tianshou.policy.BasePolicy`. (s -> a) 16 | :param torch.optim.Optimizer optim: for optimizing the model. 17 | :param gym.Space action_space: env's action space. 18 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 19 | optimizer in each policy.update(). Default to None (no lr_scheduler). 20 | 21 | .. seealso:: 22 | 23 | Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed 24 | explanation. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model: torch.nn.Module, 30 | optim: torch.optim.Optimizer, 31 | **kwargs: Any, 32 | ) -> None: 33 | super().__init__(**kwargs) 34 | self.model = model 35 | self.optim = optim 36 | assert self.action_type in ["continuous", "discrete"], \ 37 | "Please specify action_space." 38 | 39 | def forward( 40 | self, 41 | batch: Batch, 42 | state: Optional[Union[dict, Batch, np.ndarray]] = None, 43 | **kwargs: Any, 44 | ) -> Batch: 45 | logits, hidden = self.model(batch.obs, state=state, info=batch.info) 46 | if self.action_type == "discrete": 47 | act = logits.max(dim=1)[1] 48 | else: 49 | act = logits 50 | return Batch(logits=logits, act=act, state=hidden) 51 | 52 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 53 | self.optim.zero_grad() 54 | if self.action_type == "continuous": # regression 55 | act = self(batch).act 56 | act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) 57 | loss = F.mse_loss(act, act_target) 58 | elif self.action_type == "discrete": # classification 59 | act = F.log_softmax(self(batch).logits, dim=-1) 60 | act_target = to_torch(batch.act, dtype=torch.long, device=act.device) 61 | loss = F.nll_loss(act, act_target) 62 | loss.backward() 63 | self.optim.step() 64 | return {"loss": loss.item()} 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wheel==0.38.4 2 | setuptools==63.2.0 3 | absl-py==1.0.0 4 | ale-py==0.8.1 5 | astor==0.8.1 6 | astunparse==1.6.3 7 | atari-py==0.2.9 8 | cachetools==5.0.0 9 | certifi==2023.5.7 10 | cffi==1.15.0 11 | charset-normalizer==2.0.12 12 | cloudpickle==1.6.0 13 | conda-pack==0.6.0 14 | cycler==0.11.0 15 | Cython==0.29.28 16 | dm-env==1.6 17 | dm-tree==0.1.8 18 | envpool==0.8.2 19 | Farama-Notifications==0.0.4 20 | fasteners==0.17.3 21 | fonttools==4.32.0 22 | future==0.18.2 23 | gast==0.3.3 24 | glfw==2.5.3 25 | google-auth==2.19.0 26 | google-auth-oauthlib==1.0.0 27 | google-pasta==0.2.0 28 | grpcio==1.54.2 29 | gym==0.19 30 | gym-notices==0.0.6 31 | gymnasium==0.28.1 32 | h5py==3.7.0 33 | idna==3.3 34 | imageio==2.1.2 35 | importlib-metadata==4.11.3 36 | importlib-resources==5.12.0 37 | jax-jumpy==1.0.0 38 | joblib==1.1.0 39 | Keras-Applications==1.0.8 40 | Keras-Preprocessing==1.1.2 41 | kiwisolver==1.4.2 42 | llvmlite==0.39.0 43 | Markdown==3.3.6 44 | matplotlib==3.5.1 45 | mujoco-py==2.1.2.14 46 | numba==0.56.0 47 | numpy==1.22.0 48 | nvidia-cublas-cu12==12.1.0.26 49 | nvidia-cuda-cupti-cu12==12.1.62 
50 | nvidia-cuda-nvcc-cu12==12.1.66 51 | nvidia-cuda-runtime-cu12==12.1.55 52 | nvidia-cudnn-cu12==8.9.0.131 53 | nvidia-cufft-cu12==11.0.2.4 54 | nvidia-curand-cu12==10.3.2.56 55 | nvidia-cusolver-cu12==11.4.4.55 56 | nvidia-cusparse-cu12==12.0.2.55 57 | nvidia-dali-cuda110==1.23.0 58 | nvidia-dali-nvtf-plugin==1.23.0+nv23.3 59 | nvidia-horovod==0.27.0+nv23.3 60 | nvidia-nccl-cu12==2.17.1 61 | nvidia-nvjitlink-cu12==12.1.55 62 | nvidia-pyindex==1.0.9 63 | oauthlib==3.2.0 64 | opencv-python==4.7.0.72 65 | opt-einsum==3.3.0 66 | optree==0.9.1 67 | packaging==21.3 68 | Pillow==7.2.0 69 | protobuf==3.20.0 70 | psutil==5.7.0 71 | pyasn1==0.4.8 72 | pyasn1-modules==0.2.8 73 | pycparser==2.21 74 | pyglet==1.5.0 75 | pyparsing==3.0.8 76 | python-dateutil==2.8.2 77 | PyYAML==6.0 78 | requests==2.27.1 79 | requests-oauthlib==1.3.1 80 | rsa==4.8 81 | scikit-learn==1.0.2 82 | scipy==1.7.3 83 | six==1.16.0 84 | sklearn==0.0 85 | tensorboard==2.13.0 86 | tensorboard-data-server==0.7.0 87 | tensorboard-plugin-wit==1.8.1 88 | tensorflow-estimator==1.15.1 89 | tensorrt==8.6.0 90 | termcolor==2.2.0 91 | threadpoolctl==3.1.0 92 | torch==1.11.0 93 | torchaudio==0.11.0 94 | torchvision==0.12.0 95 | tqdm==4.64.0 96 | types-protobuf==4.23.0.1 97 | typing_extensions==4.6.2 98 | urllib3==1.26.9 99 | Werkzeug==2.1.1 100 | wrapt==1.15.0 101 | zipp==3.8.0 102 | -------------------------------------------------------------------------------- /scripts/HalfCheetah-v3/train_NECSA_DDPG.sh: -------------------------------------------------------------------------------- 1 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 2 | # pkill -f necsa 3 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 4 | # pkill -f necsa 5 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 6 | # pkill -f necsa 7 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 8 | # pkill -f necsa 9 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 1 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 10 | # pkill -f necsa 11 | 12 | 13 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 14 | # pkill -f necsa 15 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 16 | # pkill -f necsa 17 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 18 | # pkill -f necsa 19 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 20 | # pkill -f necsa 21 | # python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 2 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action 22 | # pkill -f necsa 23 | 24 | python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.4 --mode state_action 25 | pkill -f necsa 26 | python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 
--epsilon 0.4 --mode state_action 27 | pkill -f necsa 28 | python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.4 --mode state_action 29 | pkill -f necsa 30 | python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.4 --mode state_action 31 | pkill -f necsa 32 | python necsa_ddpg.py --task HalfCheetah-v3 --epoch 1000 --step 3 --grid_num 6 --state_min -6 --state_max 6 --epsilon 0.4 --mode state_action 33 | pkill -f necsa 34 | 35 | -------------------------------------------------------------------------------- /tianshou/data/buffer/vecbuf.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | 5 | from tianshou.data import ( 6 | PrioritizedReplayBuffer, 7 | PrioritizedReplayBufferManager, 8 | ReplayBuffer, 9 | ReplayBufferManager, 10 | ) 11 | 12 | 13 | class VectorReplayBuffer(ReplayBufferManager): 14 | """VectorReplayBuffer contains n ReplayBuffer with the same size. 15 | 16 | It is used for storing transition from different environments yet keeping the order 17 | of time. 18 | 19 | :param int total_size: the total size of VectorReplayBuffer. 20 | :param int buffer_num: the number of ReplayBuffer it uses, which are under the same 21 | configuration. 22 | 23 | Other input arguments (stack_num/ignore_obs_next/save_only_last_obs/sample_avail) 24 | are the same as :class:`~tianshou.data.ReplayBuffer`. 25 | 26 | .. seealso:: 27 | 28 | Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 29 | """ 30 | 31 | def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: 32 | assert buffer_num > 0 33 | size = int(np.ceil(total_size / buffer_num)) 34 | buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)] 35 | super().__init__(buffer_list) 36 | 37 | 38 | class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): 39 | """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size. 40 | 41 | It is used for storing transition from different environments yet keeping the order 42 | of time. 43 | 44 | :param int total_size: the total size of PrioritizedVectorReplayBuffer. 45 | :param int buffer_num: the number of PrioritizedReplayBuffer it uses, which are 46 | under the same configuration. 47 | 48 | Other input arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/ 49 | sample_avail) are the same as :class:`~tianshou.data.PrioritizedReplayBuffer`. 50 | 51 | .. seealso:: 52 | 53 | Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
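A minimal construction sketch (``alpha``/``beta`` are forwarded to every underlying :class:`~tianshou.data.PrioritizedReplayBuffer`)::

    buf = PrioritizedVectorReplayBuffer(total_size=20000, buffer_num=10, alpha=0.6, beta=0.4)
    buf.set_beta(0.5)  # update beta of all sub-buffers at once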
54 | """ 55 | 56 | def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: 57 | assert buffer_num > 0 58 | size = int(np.ceil(total_size / buffer_num)) 59 | buffer_list = [ 60 | PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num) 61 | ] 62 | super().__init__(buffer_list) 63 | 64 | def set_beta(self, beta: float) -> None: 65 | for buffer in self.buffers: 66 | buffer.set_beta(beta) 67 | -------------------------------------------------------------------------------- /tianshou/policy/__init__.py: -------------------------------------------------------------------------------- 1 | """Policy package.""" 2 | # isort:skip_file 3 | 4 | from tianshou.policy.base import BasePolicy 5 | from tianshou.policy.random import RandomPolicy 6 | from tianshou.policy.modelfree.dqn import DQNPolicy 7 | from tianshou.policy.modelfree.bdq import BranchingDQNPolicy 8 | from tianshou.policy.modelfree.c51 import C51Policy 9 | from tianshou.policy.modelfree.rainbow import RainbowPolicy 10 | from tianshou.policy.modelfree.qrdqn import QRDQNPolicy 11 | from tianshou.policy.modelfree.iqn import IQNPolicy 12 | from tianshou.policy.modelfree.fqf import FQFPolicy 13 | from tianshou.policy.modelfree.pg import PGPolicy 14 | from tianshou.policy.modelfree.a2c import A2CPolicy 15 | from tianshou.policy.modelfree.npg import NPGPolicy 16 | from tianshou.policy.modelfree.ddpg import DDPGPolicy 17 | from tianshou.policy.modelfree.ppo import PPOPolicy 18 | from tianshou.policy.modelfree.trpo import TRPOPolicy 19 | from tianshou.policy.modelfree.td3 import TD3Policy 20 | from tianshou.policy.modelfree.sac import SACPolicy 21 | from tianshou.policy.modelfree.redq import REDQPolicy 22 | from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy 23 | from tianshou.policy.imitation.base import ImitationPolicy 24 | from tianshou.policy.imitation.bcq import BCQPolicy 25 | from tianshou.policy.imitation.cql import CQLPolicy 26 | from tianshou.policy.imitation.td3_bc import TD3BCPolicy 27 | from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy 28 | from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy 29 | from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy 30 | from tianshou.policy.imitation.gail import GAILPolicy 31 | from tianshou.policy.modelbased.psrl import PSRLPolicy 32 | from tianshou.policy.modelbased.icm import ICMPolicy 33 | from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager 34 | 35 | __all__ = [ 36 | "BasePolicy", 37 | "RandomPolicy", 38 | "DQNPolicy", 39 | "BranchingDQNPolicy", 40 | "C51Policy", 41 | "RainbowPolicy", 42 | "QRDQNPolicy", 43 | "IQNPolicy", 44 | "FQFPolicy", 45 | "PGPolicy", 46 | "A2CPolicy", 47 | "NPGPolicy", 48 | "DDPGPolicy", 49 | "PPOPolicy", 50 | "TRPOPolicy", 51 | "TD3Policy", 52 | "SACPolicy", 53 | "REDQPolicy", 54 | "DiscreteSACPolicy", 55 | "ImitationPolicy", 56 | "BCQPolicy", 57 | "CQLPolicy", 58 | "TD3BCPolicy", 59 | "DiscreteBCQPolicy", 60 | "DiscreteCQLPolicy", 61 | "DiscreteCRRPolicy", 62 | "GAILPolicy", 63 | "PSRLPolicy", 64 | "ICMPolicy", 65 | "MultiAgentPolicyManager", 66 | ] 67 | -------------------------------------------------------------------------------- /tianshou/env/worker/ray.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional, Tuple, Union 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from tianshou.env.worker import EnvWorker 7 | 8 | try: 9 | import ray 10 | except ImportError: 11 | pass 
12 | 13 | 14 | class _SetAttrWrapper(gym.Wrapper): 15 | 16 | def set_env_attr(self, key: str, value: Any) -> None: 17 | setattr(self.env.unwrapped, key, value) 18 | 19 | def get_env_attr(self, key: str) -> Any: 20 | return getattr(self.env, key) 21 | 22 | 23 | class RayEnvWorker(EnvWorker): 24 | """Ray worker used in RayVectorEnv.""" 25 | 26 | def __init__(self, env_fn: Callable[[], gym.Env]) -> None: 27 | self.env = ray.remote(_SetAttrWrapper).options( # type: ignore 28 | num_cpus=0 29 | ).remote(env_fn()) 30 | super().__init__(env_fn) 31 | 32 | def get_env_attr(self, key: str) -> Any: 33 | return ray.get(self.env.get_env_attr.remote(key)) 34 | 35 | def set_env_attr(self, key: str, value: Any) -> None: 36 | ray.get(self.env.set_env_attr.remote(key, value)) 37 | 38 | def reset(self, **kwargs: Any) -> Any: 39 | if "seed" in kwargs: 40 | super().seed(kwargs["seed"]) 41 | return ray.get(self.env.reset.remote(**kwargs)) 42 | 43 | @staticmethod 44 | def wait( # type: ignore 45 | workers: List["RayEnvWorker"], wait_num: int, timeout: Optional[float] = None 46 | ) -> List["RayEnvWorker"]: 47 | results = [x.result for x in workers] 48 | ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) 49 | return [workers[results.index(result)] for result in ready_results] 50 | 51 | def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: 52 | # self.result is actually a handle 53 | if action is None: 54 | self.result = self.env.reset.remote(**kwargs) 55 | else: 56 | self.result = self.env.step.remote(action) 57 | 58 | def recv( 59 | self 60 | ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: 61 | return ray.get(self.result) # type: ignore 62 | 63 | def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: 64 | super().seed(seed) 65 | try: 66 | return ray.get(self.env.seed.remote(seed)) 67 | except NotImplementedError: 68 | self.env.reset.remote(seed=seed) 69 | return None 70 | 71 | def render(self, **kwargs: Any) -> Any: 72 | return ray.get(self.env.render.remote(**kwargs)) 73 | 74 | def close_env(self) -> None: 75 | ray.get(self.env.close.remote()) 76 | -------------------------------------------------------------------------------- /scripts/Ant-v3/train_NECSA_TD3.sh: -------------------------------------------------------------------------------- 1 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.05 --mode state_action --reduction 2 | killall -9 python 3 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.05 --mode state_action --reduction 4 | killall -9 python 5 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.05 --mode state_action --reduction 6 | killall -9 python 7 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.05 --mode state_action --reduction 8 | killall -9 python 9 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.05 --mode state_action --reduction 10 | killall -9 python 11 | 12 | 13 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 2 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 14 | killall -9 python 15 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 2 --grid_num 6 
--state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 16 | killall -9 python 17 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 2 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 18 | killall -9 python 19 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 2 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 20 | killall -9 python 21 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 2 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 22 | killall -9 python 23 | 24 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 3 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 25 | killall -9 python 26 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 3 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 27 | killall -9 python 28 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 3 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 29 | killall -9 python 30 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 3 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 31 | killall -9 python 32 | python necsa_td3.py --task Ant-v3 --epoch 1000 --step 3 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.1 --mode state_action --reduction 33 | killall -9 python -------------------------------------------------------------------------------- /tianshou/exploration/random.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Sequence, Union 3 | 4 | import numpy as np 5 | 6 | 7 | class BaseNoise(ABC, object): 8 | """The action noise base class.""" 9 | 10 | def __init__(self) -> None: 11 | super().__init__() 12 | 13 | def reset(self) -> None: 14 | """Reset to the initial state.""" 15 | pass 16 | 17 | @abstractmethod 18 | def __call__(self, size: Sequence[int]) -> np.ndarray: 19 | """Generate new noise.""" 20 | raise NotImplementedError 21 | 22 | 23 | class GaussianNoise(BaseNoise): 24 | """The vanilla Gaussian process, for exploration in DDPG by default.""" 25 | 26 | def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None: 27 | super().__init__() 28 | self._mu = mu 29 | assert 0 <= sigma, "Noise std should not be negative." 30 | self._sigma = sigma 31 | 32 | def __call__(self, size: Sequence[int]) -> np.ndarray: 33 | return np.random.normal(self._mu, self._sigma, size) 34 | 35 | 36 | class OUNoise(BaseNoise): 37 | """Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG. 38 | 39 | Usage: 40 | :: 41 | 42 | # init 43 | self.noise = OUNoise() 44 | # generate noise 45 | noise = self.noise(logits.shape, eps) 46 | 47 | For required parameters, you can refer to the stackoverflow page. However, 48 | our experiment result shows that (similar to OpenAI SpinningUp) using 49 | vanilla Gaussian process has little difference from using the 50 | Ornstein-Uhlenbeck process. 
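A quick comparison sketch (``sigma=0.1`` matches the ``GaussianNoise(sigma=0.1)`` default used by the TD3-style policies in this package; the action dimension is illustrative):
::

    from tianshou.exploration import GaussianNoise, OUNoise

    gauss = GaussianNoise(sigma=0.1)
    ou = OUNoise()
    print(gauss((6,)).shape, ou((6,)).shape)  # one noise value per action dimension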
51 | """ 52 | 53 | def __init__( 54 | self, 55 | mu: float = 0.0, 56 | sigma: float = 0.3, 57 | theta: float = 0.15, 58 | dt: float = 1e-2, 59 | x0: Optional[Union[float, np.ndarray]] = None, 60 | ) -> None: 61 | super().__init__() 62 | self._mu = mu 63 | self._alpha = theta * dt 64 | self._beta = sigma * np.sqrt(dt) 65 | self._x0 = x0 66 | self.reset() 67 | 68 | def reset(self) -> None: 69 | """Reset to the initial state.""" 70 | self._x = self._x0 71 | 72 | def __call__(self, size: Sequence[int], mu: Optional[float] = None) -> np.ndarray: 73 | """Generate new noise. 74 | 75 | Return an numpy array which size is equal to ``size``. 76 | """ 77 | if self._x is None or isinstance( 78 | self._x, np.ndarray 79 | ) and self._x.shape != size: 80 | self._x = 0.0 81 | if mu is None: 82 | mu = self._mu 83 | r = self._beta * np.random.normal(size=size) 84 | self._x = self._x + self._alpha * (mu - self._x) + r 85 | return self._x # type: ignore 86 | -------------------------------------------------------------------------------- /scripts/Humanoid-v3/train_NECSA_TD3.sh: -------------------------------------------------------------------------------- 1 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 2 | killall -9 python 3 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 4 | killall -9 python 5 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 6 | killall -9 python 7 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 8 | killall -9 python 9 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 1 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 10 | killall -9 python 11 | 12 | 13 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 2 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 14 | killall -9 python 15 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 2 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 16 | killall -9 python 17 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 2 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 18 | killall -9 python 19 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 2 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 20 | killall -9 python 21 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 2 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.2 --mode state_action --reduction 22 | killall -9 python 23 | 24 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 3 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.3 --mode state_action --reduction 25 | killall -9 python 26 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 3 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.3 --mode state_action --reduction 27 | killall -9 python 28 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 3 --grid_num 6 
--state_dim 24 --state_min -6 --state_max 6 --epsilon 0.3 --mode state_action --reduction 29 | killall -9 python 30 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 3 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.3 --mode state_action --reduction 31 | killall -9 python 32 | python necsa_td3.py --task Humanoid-v3 --epoch 1000 --step 3 --grid_num 6 --state_dim 24 --state_min -6 --state_max 6 --epsilon 0.3 --mode state_action --reduction 33 | killall -9 python -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Episodic Control with State Abstraction 2 | * [NECSA](https://openreview.net/pdf?id=C2fsSj3ZGiU) is based on the [tianshou](https://tianshou.readthedocs.io/en/master/index.html) platform. Please refer to the original repo for installation. 3 | 4 | ## 0 Introduction 5 | 6 | * NECSA is implemented as a supplementary layer on top of tianshou. Please refer to tianshou/data/necsa_collector.py and tianshou/data/necsa_atari_collector.py for details. 7 | 8 | ## 1 Requirements 9 | 10 | * Refer to requirements.txt 11 | 12 | ## 2 Anaconda and Python 13 | 14 | * wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh 15 | * bash ./Anaconda3-2020.11-Linux-x86_64.sh 16 | * (adjust $pathToAnaconda to your install location) echo 'export PATH="$pathToAnaconda/anaconda3/bin:$PATH"' >> ~/.bashrc 17 | * (optional) conda config --set auto_activate_base false 18 | * conda create -n necsa python=3.8.5 19 | * conda activate necsa 20 | * pip3 install -r requirements.txt 21 | 22 | ## 3 Install Atari and MuJoCo 23 | 24 | * Download the [ROM files](http://www.atarimania.com/rom_collection_archive_atari_2600_roms.html) for Atari, unzip them, and execute: 25 | * python -m atari_py.import_roms <path to the unzipped ROM folder> 26 | * wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 27 | * tar xvf mujoco210-linux-x86_64.tar.gz && mkdir -p ~/.mujoco && mv mujoco210 ~/.mujoco/mujoco210 28 | * wget https://www.roboti.us/file/mjkey.txt -O ~/.mujoco/mjkey.txt 29 | * echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco210/bin" >> ~/.bashrc 30 | 31 | 32 | ## 4 Execution 33 | 34 | * Example: 35 | 36 | python necsa_td3.py --task Walker2d-v3 --epoch 1000 --step 3 --grid_num 5 --epsilon 0.2 --mode state_action 37 | 38 | * Execute the scripts: 39 | 40 | bash scripts/HalfCheetah-v3/train_NECSA_TD3.sh 41 | 42 | ## 5 Experiment results 43 | 44 | * Data will be automatically saved to ./results 45 | 46 | ## 6 Citing and Thanks 47 | 48 | * Our program depends heavily on tianshou; we thank its developers for their efforts. Please kindly cite the paper below if you reference our repo. 49 | 50 | ```latex 51 | @article{tianshou, 52 | title={Tianshou: A Highly Modularized Deep Reinforcement Learning Library}, 53 | author={Weng, Jiayi and Chen, Huayu and Yan, Dong and You, Kaichao and Duburcq, Alexis and Zhang, Minghao and Su, Yi and Su, Hang and Zhu, Jun}, 54 | journal={arXiv preprint arXiv:2107.14171}, 55 | year={2021} 56 | } 57 | ``` 58 | 59 | * Our work NECSA is also inspired by three state-of-the-art episodic control algorithms: [EMAC](https://github.com/schatty/EMAC), [EVA](https://github.com/AnnaNikitaRL/EVA) and [GEM](https://github.com/MouseHu/GEM). Please refer to the corresponding repos for details.
60 | 61 | ``` 62 | @article{kuznetsov2021solving, 63 | title={Solving Continuous Control with Episodic Memory}, 64 | author={Kuznetsov, Igor and Filchenkov, Andrey}, 65 | journal={arXiv preprint arXiv:2106.08832}, 66 | year={2021} 67 | } 68 | ``` 69 | 70 | ``` 71 | @article{hansen2018fast, 72 | title={Fast deep reinforcement learning using online adjustments from the past}, 73 | author={Hansen, Steven and Pritzel, Alexander and Sprechmann, Pablo and Barreto, Andr{\'e} and Blundell, Charles}, 74 | journal={Advances in Neural Information Processing Systems}, 75 | volume={31}, 76 | year={2018} 77 | } 78 | ``` 79 | 80 | ``` 81 | @article{hu2021generalizable, 82 | title={Generalizable episodic memory for deep reinforcement learning}, 83 | author={Hu, Hao and Ye, Jianing and Zhu, Guangxiang and Ren, Zhizhou and Zhang, Chongjie}, 84 | journal={arXiv preprint arXiv:2103.06469}, 85 | year={2021} 86 | } 87 | ``` 88 | 89 | -------------------------------------------------------------------------------- /tianshou/policy/imitation/discrete_cql.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from tianshou.data import Batch, to_torch 8 | from tianshou.policy import QRDQNPolicy 9 | 10 | 11 | class DiscreteCQLPolicy(QRDQNPolicy): 12 | """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. 13 | 14 | :param torch.nn.Module model: a model following the rules in 15 | :class:`~tianshou.policy.BasePolicy`. (s -> logits) 16 | :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. 17 | :param float discount_factor: in [0, 1]. 18 | :param int num_quantiles: the number of quantile midpoints in the inverse 19 | cumulative distribution function of the value. Default to 200. 20 | :param int estimation_step: the number of steps to look ahead. Default to 1. 21 | :param int target_update_freq: the target network update frequency (0 if 22 | you do not use the target network). 23 | :param bool reward_normalization: normalize the reward to Normal(0, 1). 24 | Default to False. 25 | :param float min_q_weight: the weight for the cql loss. 26 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 27 | optimizer in each policy.update(). Default to None (no lr_scheduler). 28 | 29 | .. seealso:: 30 | Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed 31 | explanation. 
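To illustrate what ``min_q_weight`` scales, the conservative penalty computed in ``learn()`` below is simply ``logsumexp(Q) - Q(s, a_data)``; a standalone sketch with dummy tensors:
::

    import torch

    q = torch.randn(32, 6)              # Q-values: batch of 32 states, 6 actions
    act = torch.randint(0, 6, (32,))    # actions taken in the (offline) dataset
    dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
    min_q_loss = q.logsumexp(1).mean() - dataset_expec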
32 | """ 33 | 34 | def __init__( 35 | self, 36 | model: torch.nn.Module, 37 | optim: torch.optim.Optimizer, 38 | discount_factor: float = 0.99, 39 | num_quantiles: int = 200, 40 | estimation_step: int = 1, 41 | target_update_freq: int = 0, 42 | reward_normalization: bool = False, 43 | min_q_weight: float = 10.0, 44 | **kwargs: Any, 45 | ) -> None: 46 | super().__init__( 47 | model, optim, discount_factor, num_quantiles, estimation_step, 48 | target_update_freq, reward_normalization, **kwargs 49 | ) 50 | self._min_q_weight = min_q_weight 51 | 52 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 53 | if self._target and self._iter % self._freq == 0: 54 | self.sync_weight() 55 | self.optim.zero_grad() 56 | weight = batch.pop("weight", 1.0) 57 | all_dist = self(batch).logits 58 | act = to_torch(batch.act, dtype=torch.long, device=all_dist.device) 59 | curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) 60 | target_dist = batch.returns.unsqueeze(1) 61 | # calculate each element's difference between curr_dist and target_dist 62 | dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") 63 | huber_loss = ( 64 | dist_diff * 65 | (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs() 66 | ).sum(-1).mean(1) 67 | qr_loss = (huber_loss * weight).mean() 68 | # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ 69 | # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 70 | batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer 71 | # add CQL loss 72 | q = self.compute_q_value(all_dist, None) 73 | dataset_expec = q.gather(1, act.unsqueeze(1)).mean() 74 | negative_sampling = q.logsumexp(1).mean() 75 | min_q_loss = negative_sampling - dataset_expec 76 | loss = qr_loss + min_q_loss * self._min_q_weight 77 | loss.backward() 78 | self.optim.step() 79 | self._iter += 1 80 | return { 81 | "loss": loss.item(), 82 | "loss/qr": qr_loss.item(), 83 | "loss/cql": min_q_loss.item(), 84 | } 85 | -------------------------------------------------------------------------------- /tianshou/data/buffer/cached.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | 5 | from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager 6 | 7 | 8 | class CachedReplayBuffer(ReplayBufferManager): 9 | """CachedReplayBuffer contains a given main buffer and n cached buffers, \ 10 | ``cached_buffer_num * ReplayBuffer(size=max_episode_length)``. 11 | 12 | The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ... 13 | | cached_buffers[cached_buffer_num - 1] |``. 14 | 15 | The data is first stored in cached buffers. When an episode is terminated, the data 16 | will move to the main buffer and the corresponding cached buffer will be reset. 17 | 18 | :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` function 19 | behaves normally. 20 | :param int cached_buffer_num: number of ReplayBuffer needs to be created for cached 21 | buffer. 22 | :param int max_episode_length: the maximum length of one episode, used in each 23 | cached buffer's maxsize. 24 | 25 | .. seealso:: 26 | 27 | Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
28 | """ 29 | 30 | def __init__( 31 | self, 32 | main_buffer: ReplayBuffer, 33 | cached_buffer_num: int, 34 | max_episode_length: int, 35 | ) -> None: 36 | assert cached_buffer_num > 0 and max_episode_length > 0 37 | assert type(main_buffer) == ReplayBuffer 38 | kwargs = main_buffer.options 39 | buffers = [main_buffer] + [ 40 | ReplayBuffer(max_episode_length, **kwargs) 41 | for _ in range(cached_buffer_num) 42 | ] 43 | super().__init__(buffer_list=buffers) 44 | self.main_buffer = self.buffers[0] 45 | self.cached_buffers = self.buffers[1:] 46 | self.cached_buffer_num = cached_buffer_num 47 | 48 | def add( 49 | self, 50 | batch: Batch, 51 | buffer_ids: Optional[Union[np.ndarray, List[int]]] = None 52 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 53 | """Add a batch of data into CachedReplayBuffer. 54 | 55 | Each of the data's length (first dimension) must equal to the length of 56 | buffer_ids. By default the buffer_ids is [0, 1, ..., cached_buffer_num - 1]. 57 | 58 | Return (current_index, episode_reward, episode_length, episode_start_index) 59 | with each of the shape (len(buffer_ids), ...), where (current_index[i], 60 | episode_reward[i], episode_length[i], episode_start_index[i]) refers to the 61 | cached_buffer_ids[i]th cached buffer's corresponding episode result. 62 | """ 63 | if buffer_ids is None: 64 | buf_arr = np.arange(1, 1 + self.cached_buffer_num) 65 | else: # make sure it is np.ndarray 66 | buf_arr = np.asarray(buffer_ids) + 1 67 | ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buf_arr) 68 | # find the terminated episode, move data from cached buf to main buf 69 | updated_ptr, updated_ep_idx = [], [] 70 | done = batch.done.astype(bool) 71 | for buffer_idx in buf_arr[done]: 72 | index = self.main_buffer.update(self.buffers[buffer_idx]) 73 | if len(index) == 0: # unsuccessful move, replace with -1 74 | index = [-1] 75 | updated_ep_idx.append(index[0]) 76 | updated_ptr.append(index[-1]) 77 | self.buffers[buffer_idx].reset() 78 | self._lengths[0] = len(self.main_buffer) 79 | self._lengths[buffer_idx] = 0 80 | self.last_index[0] = index[-1] 81 | self.last_index[buffer_idx] = self._offset[buffer_idx] 82 | ptr[done] = updated_ptr 83 | ep_idx[done] = updated_ep_idx 84 | return ptr, ep_rew, ep_len, ep_idx 85 | -------------------------------------------------------------------------------- /tianshou/utils/logger/tensorboard.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Tuple 2 | 3 | from tensorboard.backend.event_processing import event_accumulator 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger 7 | from tianshou.utils.warning import deprecation 8 | 9 | 10 | class TensorboardLogger(BaseLogger): 11 | """A logger that relies on tensorboard SummaryWriter by default to visualize \ 12 | and log statistics. 13 | 14 | :param SummaryWriter writer: the writer to log data. 15 | :param int train_interval: the log interval in log_train_data(). Default to 1000. 16 | :param int test_interval: the log interval in log_test_data(). Default to 1. 17 | :param int update_interval: the log interval in log_update_data(). Default to 1000. 18 | :param int save_interval: the save interval in save_data(). Default to 1 (save at 19 | the end of each epoch). 20 | :param bool write_flush: whether to flush tensorboard result after each 21 | add_scalar operation. Default to True. 
22 | """ 23 | 24 | def __init__( 25 | self, 26 | writer: SummaryWriter, 27 | train_interval: int = 1000, 28 | test_interval: int = 1, 29 | update_interval: int = 1000, 30 | save_interval: int = 1, 31 | write_flush: bool = True, 32 | ) -> None: 33 | super().__init__(train_interval, test_interval, update_interval) 34 | self.save_interval = save_interval 35 | self.write_flush = write_flush 36 | self.last_save_step = -1 37 | self.writer = writer 38 | 39 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: 40 | for k, v in data.items(): 41 | self.writer.add_scalar(k, v, global_step=step) 42 | if self.write_flush: # issue 580 43 | self.writer.flush() # issue #482 44 | 45 | def save_data( 46 | self, 47 | epoch: int, 48 | env_step: int, 49 | gradient_step: int, 50 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, 51 | ) -> None: 52 | if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: 53 | self.last_save_step = epoch 54 | save_checkpoint_fn(epoch, env_step, gradient_step) 55 | self.write("save/epoch", epoch, {"save/epoch": epoch}) 56 | self.write("save/env_step", env_step, {"save/env_step": env_step}) 57 | self.write( 58 | "save/gradient_step", gradient_step, 59 | {"save/gradient_step": gradient_step} 60 | ) 61 | 62 | def restore_data(self) -> Tuple[int, int, int]: 63 | ea = event_accumulator.EventAccumulator(self.writer.log_dir) 64 | ea.Reload() 65 | 66 | try: # epoch / gradient_step 67 | epoch = ea.scalars.Items("save/epoch")[-1].step 68 | self.last_save_step = self.last_log_test_step = epoch 69 | gradient_step = ea.scalars.Items("save/gradient_step")[-1].step 70 | self.last_log_update_step = gradient_step 71 | except KeyError: 72 | epoch, gradient_step = 0, 0 73 | try: # offline trainer doesn't have env_step 74 | env_step = ea.scalars.Items("save/env_step")[-1].step 75 | self.last_log_train_step = env_step 76 | except KeyError: 77 | env_step = 0 78 | 79 | return epoch, env_step, gradient_step 80 | 81 | 82 | class BasicLogger(TensorboardLogger): 83 | """BasicLogger has changed its name to TensorboardLogger in #427. 84 | 85 | This class is for compatibility. 86 | """ 87 | 88 | def __init__(self, *args: Any, **kwargs: Any) -> None: 89 | deprecation( 90 | "Class BasicLogger is marked as deprecated and will be removed soon. " 91 | "Please use TensorboardLogger instead." 92 | ) 93 | super().__init__(*args, **kwargs) 94 | -------------------------------------------------------------------------------- /tianshou/env/worker/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Callable, List, Optional, Tuple, Union 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from tianshou.utils import deprecation 8 | 9 | 10 | class EnvWorker(ABC): 11 | """An abstract worker for an environment.""" 12 | 13 | def __init__(self, env_fn: Callable[[], gym.Env]) -> None: 14 | self._env_fn = env_fn 15 | self.is_closed = False 16 | self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], 17 | Tuple[np.ndarray, dict], np.ndarray] 18 | self.action_space = self.get_env_attr("action_space") # noqa: B009 19 | self.is_reset = False 20 | 21 | @abstractmethod 22 | def get_env_attr(self, key: str) -> Any: 23 | pass 24 | 25 | @abstractmethod 26 | def set_env_attr(self, key: str, value: Any) -> None: 27 | pass 28 | 29 | def send(self, action: Optional[np.ndarray]) -> None: 30 | """Send action signal to low-level worker. 
31 | 32 | When action is None, it indicates sending "reset" signal; otherwise 33 | it indicates "step" signal. The paired return value from "recv" 34 | function is determined by such kind of different signal. 35 | """ 36 | if hasattr(self, "send_action"): 37 | deprecation( 38 | "send_action will soon be deprecated. " 39 | "Please use send and recv for your own EnvWorker." 40 | ) 41 | if action is None: 42 | self.is_reset = True 43 | self.result = self.reset() 44 | else: 45 | self.is_reset = False 46 | self.send_action(action) # type: ignore 47 | 48 | def recv( 49 | self 50 | ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[ 51 | np.ndarray, dict], np.ndarray]: # noqa:E125 52 | """Receive result from low-level worker. 53 | 54 | If the last "send" function sends a NULL action, it only returns a 55 | single observation; otherwise it returns a tuple of (obs, rew, done, 56 | info). 57 | """ 58 | if hasattr(self, "get_result"): 59 | deprecation( 60 | "get_result will soon be deprecated. " 61 | "Please use send and recv for your own EnvWorker." 62 | ) 63 | if not self.is_reset: 64 | self.result = self.get_result() # type: ignore 65 | return self.result 66 | 67 | @abstractmethod 68 | def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: 69 | pass 70 | 71 | def step( 72 | self, action: np.ndarray 73 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 74 | """Perform one timestep of the environment's dynamic. 75 | 76 | "send" and "recv" are coupled in sync simulation, so users only call 77 | "step" function. But they can be called separately in async 78 | simulation, i.e. someone calls "send" first, and calls "recv" later. 79 | """ 80 | self.send(action) 81 | return self.recv() # type: ignore 82 | 83 | @staticmethod 84 | def wait( 85 | workers: List["EnvWorker"], 86 | wait_num: int, 87 | timeout: Optional[float] = None 88 | ) -> List["EnvWorker"]: 89 | """Given a list of workers, return those ready ones.""" 90 | raise NotImplementedError 91 | 92 | def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: 93 | return self.action_space.seed(seed) # issue 299 94 | 95 | @abstractmethod 96 | def render(self, **kwargs: Any) -> Any: 97 | """Render the environment.""" 98 | pass 99 | 100 | @abstractmethod 101 | def close_env(self) -> None: 102 | pass 103 | 104 | def close(self) -> None: 105 | if self.is_closed: 106 | return None 107 | self.is_closed = True 108 | self.close_env() 109 | -------------------------------------------------------------------------------- /tianshou/trainer/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, Callable, Dict, Optional, Union 3 | 4 | import numpy as np 5 | 6 | from tianshou.data import Collector 7 | from tianshou.policy import BasePolicy 8 | from tianshou.utils import BaseLogger 9 | 10 | 11 | def test_episode( 12 | policy: BasePolicy, 13 | collector: Collector, 14 | test_fn: Optional[Callable[[int, Optional[int]], None]], 15 | epoch: int, 16 | n_episode: int, 17 | logger: Optional[BaseLogger] = None, 18 | global_step: Optional[int] = None, 19 | reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, 20 | ) -> Dict[str, Any]: 21 | """A simple wrapper of testing policy in collector.""" 22 | collector.reset_env() 23 | collector.reset_buffer() 24 | policy.eval() 25 | if test_fn: 26 | test_fn(epoch, global_step) 27 | result = collector.collect(n_episode=n_episode) 28 | if reward_metric: 29 | rew = 
reward_metric(result["rews"]) 30 | result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) 31 | if logger and global_step is not None: 32 | logger.log_test_data(result, global_step) 33 | return result 34 | 35 | 36 | def gather_info( 37 | start_time: float, 38 | train_collector: Optional[Collector], 39 | test_collector: Optional[Collector], 40 | best_reward: float, 41 | best_reward_std: float, 42 | ) -> Dict[str, Union[float, str]]: 43 | """A simple wrapper of gathering information from collectors. 44 | 45 | :return: A dictionary with the following keys: 46 | 47 | * ``train_step`` the total collected step of training collector; 48 | * ``train_episode`` the total collected episode of training collector; 49 | * ``train_time/collector`` the time for collecting transitions in the \ 50 | training collector; 51 | * ``train_time/model`` the time for training models; 52 | * ``train_speed`` the speed of training (env_step per second); 53 | * ``test_step`` the total collected step of test collector; 54 | * ``test_episode`` the total collected episode of test collector; 55 | * ``test_time`` the time for testing; 56 | * ``test_speed`` the speed of testing (env_step per second); 57 | * ``best_reward`` the best reward over the test results; 58 | * ``duration`` the total elapsed time. 59 | """ 60 | duration = max(0, time.time() - start_time) 61 | model_time = duration 62 | result: Dict[str, Union[float, str]] = { 63 | "duration": f"{duration:.2f}s", 64 | "train_time/model": f"{model_time:.2f}s", 65 | } 66 | if test_collector is not None: 67 | model_time = max(0, duration - test_collector.collect_time) 68 | test_speed = test_collector.collect_step / test_collector.collect_time 69 | result.update( 70 | { 71 | "test_step": test_collector.collect_step, 72 | "test_episode": test_collector.collect_episode, 73 | "test_time": f"{test_collector.collect_time:.2f}s", 74 | "test_speed": f"{test_speed:.2f} step/s", 75 | "best_reward": best_reward, 76 | "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", 77 | "duration": f"{duration:.2f}s", 78 | "train_time/model": f"{model_time:.2f}s", 79 | } 80 | ) 81 | if train_collector is not None: 82 | model_time = max(0, model_time - train_collector.collect_time) 83 | if test_collector is not None: 84 | train_speed = train_collector.collect_step / ( 85 | duration - test_collector.collect_time 86 | ) 87 | else: 88 | train_speed = train_collector.collect_step / duration 89 | result.update( 90 | { 91 | "train_step": train_collector.collect_step, 92 | "train_episode": train_collector.collect_episode, 93 | "train_time/collector": f"{train_collector.collect_time:.2f}s", 94 | "train_time/model": f"{model_time:.2f}s", 95 | "train_speed": f"{train_speed:.2f} step/s", 96 | } 97 | ) 98 | return result 99 | -------------------------------------------------------------------------------- /tianshou/utils/statistics.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class MovAvg(object): 9 | """Class for moving average. 10 | 11 | It will automatically exclude the infinity and NaN. 
Usage: 12 | :: 13 | 14 | >>> stat = MovAvg(size=66) 15 | >>> stat.add(torch.tensor(5)) 16 | 5.0 17 | >>> stat.add(float('inf')) # which will not add to stat 18 | 5.0 19 | >>> stat.add([6, 7, 8]) 20 | 6.5 21 | >>> stat.get() 22 | 6.5 23 | >>> print(f'{stat.mean():.2f}±{stat.std():.2f}') 24 | 6.50±1.12 25 | """ 26 | 27 | def __init__(self, size: int = 100) -> None: 28 | super().__init__() 29 | self.size = size 30 | self.cache: List[np.number] = [] 31 | self.banned = [np.inf, np.nan, -np.inf] 32 | 33 | def add( 34 | self, data_array: Union[Number, np.number, list, np.ndarray, torch.Tensor] 35 | ) -> float: 36 | """Add a scalar into :class:`MovAvg`. 37 | 38 | You can add ``torch.Tensor`` with only one element, a python scalar, or 39 | a list of python scalar. 40 | """ 41 | if isinstance(data_array, torch.Tensor): 42 | data_array = data_array.flatten().cpu().numpy() 43 | if np.isscalar(data_array): 44 | data_array = [data_array] 45 | for number in data_array: # type: ignore 46 | if number not in self.banned: 47 | self.cache.append(number) 48 | if self.size > 0 and len(self.cache) > self.size: 49 | self.cache = self.cache[-self.size:] 50 | return self.get() 51 | 52 | def get(self) -> float: 53 | """Get the average.""" 54 | if len(self.cache) == 0: 55 | return 0.0 56 | return float(np.mean(self.cache)) 57 | 58 | def mean(self) -> float: 59 | """Get the average. Same as :meth:`get`.""" 60 | return self.get() 61 | 62 | def std(self) -> float: 63 | """Get the standard deviation.""" 64 | if len(self.cache) == 0: 65 | return 0.0 66 | return float(np.std(self.cache)) 67 | 68 | 69 | class RunningMeanStd(object): 70 | """Calculates the running mean and std of a data stream. 71 | 72 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 73 | 74 | :param mean: the initial mean estimation for data array. Default to 0. 75 | :param std: the initial standard error estimation for data array. Default to 1. 76 | :param float clip_max: the maximum absolute value for data array. Default to 77 | 10.0. 78 | :param float epsilon: To avoid division by zero. 
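A short sketch of the intended use (the shapes are illustrative):
::

    import numpy as np
    from tianshou.utils import RunningMeanStd

    rms = RunningMeanStd()
    rms.update(np.random.randn(64, 3))     # one batch of 64 samples with 3 features
    normed = rms.norm(np.random.randn(3))  # standardized and clipped to +/- clip_max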
79 | """ 80 | 81 | def __init__( 82 | self, 83 | mean: Union[float, np.ndarray] = 0.0, 84 | std: Union[float, np.ndarray] = 1.0, 85 | clip_max: Optional[float] = 10.0, 86 | epsilon: float = np.finfo(np.float32).eps.item(), 87 | ) -> None: 88 | self.mean, self.var = mean, std 89 | self.clip_max = clip_max 90 | self.count = 0 91 | self.eps = epsilon 92 | 93 | def norm(self, data_array: Union[float, np.ndarray]) -> Union[float, np.ndarray]: 94 | data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps) 95 | if self.clip_max: 96 | data_array = np.clip(data_array, -self.clip_max, self.clip_max) 97 | return data_array 98 | 99 | def update(self, data_array: np.ndarray) -> None: 100 | """Add a batch of item into RMS with the same shape, modify mean/var/count.""" 101 | batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0) 102 | batch_count = len(data_array) 103 | 104 | delta = batch_mean - self.mean 105 | total_count = self.count + batch_count 106 | 107 | new_mean = self.mean + delta * batch_count / total_count 108 | m_a = self.var * self.count 109 | m_b = batch_var * batch_count 110 | m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count 111 | new_var = m_2 / total_count 112 | 113 | self.mean, self.var = new_mean, new_var 114 | self.count = total_count 115 | -------------------------------------------------------------------------------- /tianshou/policy/modelfree/qrdqn.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Dict, Optional 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from tianshou.data import Batch, ReplayBuffer 9 | from tianshou.policy import DQNPolicy 10 | 11 | 12 | class QRDQNPolicy(DQNPolicy): 13 | """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. 14 | 15 | :param torch.nn.Module model: a model following the rules in 16 | :class:`~tianshou.policy.BasePolicy`. (s -> logits) 17 | :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. 18 | :param float discount_factor: in [0, 1]. 19 | :param int num_quantiles: the number of quantile midpoints in the inverse 20 | cumulative distribution function of the value. Default to 200. 21 | :param int estimation_step: the number of steps to look ahead. Default to 1. 22 | :param int target_update_freq: the target network update frequency (0 if 23 | you do not use the target network). 24 | :param bool reward_normalization: normalize the reward to Normal(0, 1). 25 | Default to False. 26 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 27 | optimizer in each policy.update(). Default to None (no lr_scheduler). 28 | 29 | .. seealso:: 30 | 31 | Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed 32 | explanation. 
33 | """ 34 | 35 | def __init__( 36 | self, 37 | model: torch.nn.Module, 38 | optim: torch.optim.Optimizer, 39 | discount_factor: float = 0.99, 40 | num_quantiles: int = 200, 41 | estimation_step: int = 1, 42 | target_update_freq: int = 0, 43 | reward_normalization: bool = False, 44 | **kwargs: Any, 45 | ) -> None: 46 | super().__init__( 47 | model, optim, discount_factor, estimation_step, target_update_freq, 48 | reward_normalization, **kwargs 49 | ) 50 | assert num_quantiles > 1, "num_quantiles should be greater than 1" 51 | self._num_quantiles = num_quantiles 52 | tau = torch.linspace(0, 1, self._num_quantiles + 1) 53 | self.tau_hat = torch.nn.Parameter( 54 | ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False 55 | ) 56 | warnings.filterwarnings("ignore", message="Using a target size") 57 | 58 | def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: 59 | batch = buffer[indices] # batch.obs_next: s_{t+n} 60 | if self._target: 61 | act = self(batch, input="obs_next").act 62 | next_dist = self(batch, model="model_old", input="obs_next").logits 63 | else: 64 | next_batch = self(batch, input="obs_next") 65 | act = next_batch.act 66 | next_dist = next_batch.logits 67 | next_dist = next_dist[np.arange(len(act)), act, :] 68 | return next_dist # shape: [bsz, num_quantiles] 69 | 70 | def compute_q_value( 71 | self, logits: torch.Tensor, mask: Optional[np.ndarray] 72 | ) -> torch.Tensor: 73 | return super().compute_q_value(logits.mean(2), mask) 74 | 75 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 76 | if self._target and self._iter % self._freq == 0: 77 | self.sync_weight() 78 | self.optim.zero_grad() 79 | weight = batch.pop("weight", 1.0) 80 | curr_dist = self(batch).logits 81 | act = batch.act 82 | curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) 83 | target_dist = batch.returns.unsqueeze(1) 84 | # calculate each element's difference between curr_dist and target_dist 85 | dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") 86 | huber_loss = ( 87 | dist_diff * 88 | (self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()).abs() 89 | ).sum(-1).mean(1) 90 | loss = (huber_loss * weight).mean() 91 | # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ 92 | # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 93 | batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer 94 | loss.backward() 95 | self.optim.step() 96 | self._iter += 1 97 | return {"loss": loss.item()} 98 | -------------------------------------------------------------------------------- /tianshou/data/buffer/prio.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from tianshou.data import Batch, ReplayBuffer, SegmentTree, to_numpy 7 | 8 | 9 | class PrioritizedReplayBuffer(ReplayBuffer): 10 | """Implementation of Prioritized Experience Replay. arXiv:1511.05952. 11 | 12 | :param float alpha: the prioritization exponent. 13 | :param float beta: the importance sample soft coefficient. 14 | :param bool weight_norm: whether to normalize returned weights with the maximum 15 | weight value within the batch. Default to True. 16 | 17 | .. seealso:: 18 | 19 | Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
20 | """ 21 | 22 | def __init__( 23 | self, 24 | size: int, 25 | alpha: float, 26 | beta: float, 27 | weight_norm: bool = True, 28 | **kwargs: Any 29 | ) -> None: 30 | # will raise KeyError in PrioritizedVectorReplayBuffer 31 | # super().__init__(size, **kwargs) 32 | ReplayBuffer.__init__(self, size, **kwargs) 33 | assert alpha > 0.0 and beta >= 0.0 34 | self._alpha, self._beta = alpha, beta 35 | self._max_prio = self._min_prio = 1.0 36 | # save weight directly in this class instead of self._meta 37 | self.weight = SegmentTree(size) 38 | self.__eps = np.finfo(np.float32).eps.item() 39 | self.options.update(alpha=alpha, beta=beta) 40 | self._weight_norm = weight_norm 41 | 42 | def init_weight(self, index: Union[int, np.ndarray]) -> None: 43 | self.weight[index] = self._max_prio**self._alpha 44 | 45 | def update(self, buffer: ReplayBuffer) -> np.ndarray: 46 | indices = super().update(buffer) 47 | self.init_weight(indices) 48 | return indices 49 | 50 | def add( 51 | self, 52 | batch: Batch, 53 | buffer_ids: Optional[Union[np.ndarray, List[int]]] = None 54 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 55 | ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids) 56 | self.init_weight(ptr) 57 | return ptr, ep_rew, ep_len, ep_idx 58 | 59 | def sample_indices(self, batch_size: int) -> np.ndarray: 60 | if batch_size > 0 and len(self) > 0: 61 | scalar = np.random.rand(batch_size) * self.weight.reduce() 62 | return self.weight.get_prefix_sum_idx(scalar) # type: ignore 63 | else: 64 | return super().sample_indices(batch_size) 65 | 66 | def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]: 67 | """Get the importance sampling weight. 68 | 69 | The "weight" in the returned Batch is the weight on loss function to debias 70 | the sampling process (some transition tuples are sampled more often so their 71 | losses are weighted less). 72 | """ 73 | # important sampling weight calculation 74 | # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) 75 | # simplified formula: (p_j/p_min)**(-beta) 76 | return (self.weight[index] / self._min_prio)**(-self._beta) 77 | 78 | def update_weight( 79 | self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor] 80 | ) -> None: 81 | """Update priority weight by index in this buffer. 82 | 83 | :param np.ndarray index: index you want to update weight. 84 | :param np.ndarray new_weight: new priority weight you want to update. 
85 | """ 86 | weight = np.abs(to_numpy(new_weight)) + self.__eps 87 | self.weight[index] = weight**self._alpha 88 | self._max_prio = max(self._max_prio, weight.max()) 89 | self._min_prio = min(self._min_prio, weight.min()) 90 | 91 | def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch: 92 | if isinstance(index, slice): # change slice to np array 93 | # buffer[:] will get all available data 94 | indices = self.sample_indices(0) if index == slice(None) \ 95 | else self._indices[:len(self)][index] 96 | else: 97 | indices = index # type: ignore 98 | batch = super().__getitem__(indices) 99 | weight = self.get_weight(indices) 100 | # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154 101 | batch.weight = weight / np.max(weight) if self._weight_norm else weight 102 | return batch 103 | 104 | def set_beta(self, beta: float) -> None: 105 | self._beta = beta 106 | -------------------------------------------------------------------------------- /tianshou/env/venv_wrappers.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | 5 | from tianshou.env.venvs import GYM_RESERVED_KEYS, BaseVectorEnv 6 | from tianshou.utils import RunningMeanStd 7 | 8 | 9 | class VectorEnvWrapper(BaseVectorEnv): 10 | """Base class for vectorized environments wrapper.""" 11 | 12 | def __init__(self, venv: BaseVectorEnv) -> None: 13 | self.venv = venv 14 | self.is_async = venv.is_async 15 | 16 | def __len__(self) -> int: 17 | return len(self.venv) 18 | 19 | def __getattribute__(self, key: str) -> Any: 20 | if key in GYM_RESERVED_KEYS: # reserved keys in gym.Env 21 | return getattr(self.venv, key) 22 | else: 23 | return super().__getattribute__(key) 24 | 25 | def get_env_attr( 26 | self, 27 | key: str, 28 | id: Optional[Union[int, List[int], np.ndarray]] = None, 29 | ) -> List[Any]: 30 | return self.venv.get_env_attr(key, id) 31 | 32 | def set_env_attr( 33 | self, 34 | key: str, 35 | value: Any, 36 | id: Optional[Union[int, List[int], np.ndarray]] = None, 37 | ) -> None: 38 | return self.venv.set_env_attr(key, value, id) 39 | 40 | def reset( 41 | self, 42 | id: Optional[Union[int, List[int], np.ndarray]] = None, 43 | **kwargs: Any, 44 | ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]: 45 | return self.venv.reset(id, **kwargs) 46 | 47 | def step( 48 | self, 49 | action: np.ndarray, 50 | id: Optional[Union[int, List[int], np.ndarray]] = None, 51 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 52 | return self.venv.step(action, id) 53 | 54 | def seed( 55 | self, 56 | seed: Optional[Union[int, List[int]]] = None, 57 | ) -> List[Optional[List[int]]]: 58 | return self.venv.seed(seed) 59 | 60 | def render(self, **kwargs: Any) -> List[Any]: 61 | return self.venv.render(**kwargs) 62 | 63 | def close(self) -> None: 64 | self.venv.close() 65 | 66 | 67 | class VectorEnvNormObs(VectorEnvWrapper): 68 | """An observation normalization wrapper for vectorized environments. 69 | 70 | :param bool update_obs_rms: whether to update obs_rms. Default to True. 
71 | """ 72 | 73 | def __init__( 74 | self, 75 | venv: BaseVectorEnv, 76 | update_obs_rms: bool = True, 77 | ) -> None: 78 | super().__init__(venv) 79 | # initialize observation running mean/std 80 | self.update_obs_rms = update_obs_rms 81 | self.obs_rms = RunningMeanStd() 82 | 83 | def reset( 84 | self, 85 | id: Optional[Union[int, List[int], np.ndarray]] = None, 86 | **kwargs: Any, 87 | ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]: 88 | rval = self.venv.reset(id, **kwargs) 89 | returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and ( 90 | isinstance(rval[1], dict) or isinstance(rval[1][0], dict) 91 | ) 92 | if returns_info: 93 | obs, info = rval 94 | else: 95 | obs = rval 96 | 97 | if isinstance(obs, tuple): 98 | raise TypeError( 99 | "Tuple observation space is not supported. ", 100 | "Please change it to array or dict space", 101 | ) 102 | 103 | if self.obs_rms and self.update_obs_rms: 104 | self.obs_rms.update(obs) 105 | obs = self._norm_obs(obs) 106 | if returns_info: 107 | return obs, info 108 | else: 109 | return obs 110 | 111 | def step( 112 | self, 113 | action: np.ndarray, 114 | id: Optional[Union[int, List[int], np.ndarray]] = None, 115 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 116 | obs, rew, done, info = self.venv.step(action, id) 117 | if self.obs_rms and self.update_obs_rms: 118 | self.obs_rms.update(obs) 119 | return self._norm_obs(obs), rew, done, info 120 | 121 | def _norm_obs(self, obs: np.ndarray) -> np.ndarray: 122 | if self.obs_rms: 123 | return self.obs_rms.norm(obs) # type: ignore 124 | return obs 125 | 126 | def set_obs_rms(self, obs_rms: RunningMeanStd) -> None: 127 | """Set with given observation running mean/std.""" 128 | self.obs_rms = obs_rms 129 | 130 | def get_obs_rms(self) -> RunningMeanStd: 131 | """Return observation running mean/std.""" 132 | return self.obs_rms 133 | -------------------------------------------------------------------------------- /tianshou/policy/modelfree/c51.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from tianshou.data import Batch, ReplayBuffer 7 | from tianshou.policy import DQNPolicy 8 | 9 | 10 | class C51Policy(DQNPolicy): 11 | """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. 12 | 13 | :param torch.nn.Module model: a model following the rules in 14 | :class:`~tianshou.policy.BasePolicy`. (s -> logits) 15 | :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. 16 | :param float discount_factor: in [0, 1]. 17 | :param int num_atoms: the number of atoms in the support set of the 18 | value distribution. Default to 51. 19 | :param float v_min: the value of the smallest atom in the support set. 20 | Default to -10.0. 21 | :param float v_max: the value of the largest atom in the support set. 22 | Default to 10.0. 23 | :param int estimation_step: the number of steps to look ahead. Default to 1. 24 | :param int target_update_freq: the target network update frequency (0 if 25 | you do not use the target network). Default to 0. 26 | :param bool reward_normalization: normalize the reward to Normal(0, 1). 27 | Default to False. 28 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 29 | optimizer in each policy.update(). Default to None (no lr_scheduler). 30 | 31 | .. seealso:: 32 | 33 | Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed 34 | explanation. 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | model: torch.nn.Module, 40 | optim: torch.optim.Optimizer, 41 | discount_factor: float = 0.99, 42 | num_atoms: int = 51, 43 | v_min: float = -10.0, 44 | v_max: float = 10.0, 45 | estimation_step: int = 1, 46 | target_update_freq: int = 0, 47 | reward_normalization: bool = False, 48 | **kwargs: Any, 49 | ) -> None: 50 | super().__init__( 51 | model, optim, discount_factor, estimation_step, target_update_freq, 52 | reward_normalization, **kwargs 53 | ) 54 | assert num_atoms > 1, "num_atoms should be greater than 1" 55 | assert v_min < v_max, "v_max should be larger than v_min" 56 | self._num_atoms = num_atoms 57 | self._v_min = v_min 58 | self._v_max = v_max 59 | self.support = torch.nn.Parameter( 60 | torch.linspace(self._v_min, self._v_max, self._num_atoms), 61 | requires_grad=False, 62 | ) 63 | self.delta_z = (v_max - v_min) / (num_atoms - 1) 64 | 65 | def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: 66 | return self.support.repeat(len(indices), 1) # shape: [bsz, num_atoms] 67 | 68 | def compute_q_value( 69 | self, logits: torch.Tensor, mask: Optional[np.ndarray] 70 | ) -> torch.Tensor: 71 | return super().compute_q_value((logits * self.support).sum(2), mask) 72 | 73 | def _target_dist(self, batch: Batch) -> torch.Tensor: 74 | if self._target: 75 | act = self(batch, input="obs_next").act 76 | next_dist = self(batch, model="model_old", input="obs_next").logits 77 | else: 78 | next_batch = self(batch, input="obs_next") 79 | act = next_batch.act 80 | next_dist = next_batch.logits 81 | next_dist = next_dist[np.arange(len(act)), act, :] 82 | target_support = batch.returns.clamp(self._v_min, self._v_max) 83 | # An amazing trick for calculating the projection gracefully. 84 | # ref: https://github.com/ShangtongZhang/DeepRL 85 | target_dist = ( 86 | 1 - (target_support.unsqueeze(1) - self.support.view(1, -1, 1)).abs() / 87 | self.delta_z 88 | ).clamp(0, 1) * next_dist.unsqueeze(1) 89 | return target_dist.sum(-1) 90 | 91 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 92 | if self._target and self._iter % self._freq == 0: 93 | self.sync_weight() 94 | self.optim.zero_grad() 95 | with torch.no_grad(): 96 | target_dist = self._target_dist(batch) 97 | weight = batch.pop("weight", 1.0) 98 | curr_dist = self(batch).logits 99 | act = batch.act 100 | curr_dist = curr_dist[np.arange(len(act)), act, :] 101 | cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1) 102 | loss = (cross_entropy * weight).mean() 103 | # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100 104 | batch.weight = cross_entropy.detach() # prio-buffer 105 | loss.backward() 106 | self.optim.step() 107 | self._iter += 1 108 | return {"loss": loss.item()} 109 | -------------------------------------------------------------------------------- /tianshou/policy/modelbased/icm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch 8 | from tianshou.policy import BasePolicy 9 | from tianshou.utils.net.discrete import IntrinsicCuriosityModule 10 | 11 | 12 | class ICMPolicy(BasePolicy): 13 | """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. 14 | 15 | :param BasePolicy policy: a base policy to add ICM to. 16 | :param IntrinsicCuriosityModule model: the ICM model. 
17 | :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. 18 | :param float lr_scale: the scaling factor for ICM learning. 19 | :param float forward_loss_weight: the weight for forward model loss. 20 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 21 | optimizer in each policy.update(). Default to None (no lr_scheduler). 22 | 23 | .. seealso:: 24 | 25 | Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed 26 | explanation. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | policy: BasePolicy, 32 | model: IntrinsicCuriosityModule, 33 | optim: torch.optim.Optimizer, 34 | lr_scale: float, 35 | reward_scale: float, 36 | forward_loss_weight: float, 37 | **kwargs: Any, 38 | ) -> None: 39 | super().__init__(**kwargs) 40 | self.policy = policy 41 | self.model = model 42 | self.optim = optim 43 | self.lr_scale = lr_scale 44 | self.reward_scale = reward_scale 45 | self.forward_loss_weight = forward_loss_weight 46 | 47 | def train(self, mode: bool = True) -> "ICMPolicy": 48 | """Set the module in training mode.""" 49 | self.policy.train(mode) 50 | self.training = mode 51 | self.model.train(mode) 52 | return self 53 | 54 | def forward( 55 | self, 56 | batch: Batch, 57 | state: Optional[Union[dict, Batch, np.ndarray]] = None, 58 | **kwargs: Any, 59 | ) -> Batch: 60 | """Compute action over the given batch data by inner policy. 61 | 62 | .. seealso:: 63 | 64 | Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for 65 | more detailed explanation. 66 | """ 67 | return self.policy.forward(batch, state, **kwargs) 68 | 69 | def exploration_noise(self, act: Union[np.ndarray, Batch], 70 | batch: Batch) -> Union[np.ndarray, Batch]: 71 | return self.policy.exploration_noise(act, batch) 72 | 73 | def set_eps(self, eps: float) -> None: 74 | """Set the eps for epsilon-greedy exploration.""" 75 | if hasattr(self.policy, "set_eps"): 76 | self.policy.set_eps(eps) # type: ignore 77 | else: 78 | raise NotImplementedError() 79 | 80 | def process_fn( 81 | self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray 82 | ) -> Batch: 83 | """Pre-process the data from the provided replay buffer. 84 | 85 | Used in :meth:`update`. Check out :ref:`process_fn` for more information. 86 | """ 87 | mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next) 88 | batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss) 89 | batch.rew += to_numpy(mse_loss * self.reward_scale) 90 | return self.policy.process_fn(batch, buffer, indices) 91 | 92 | def post_process_fn( 93 | self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray 94 | ) -> None: 95 | """Post-process the data from the provided replay buffer. 96 | 97 | Typical usage is to update the sampling weight in prioritized 98 | experience replay. Used in :meth:`update`. 
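        For :class:`ICMPolicy` this hook first delegates to the inner policy and
        then restores the original environment reward saved in
        ``batch.policy.orig_rew``, undoing the intrinsic bonus
        (``reward_scale * mse_loss``) that :meth:`process_fn` added to ``batch.rew``.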
99 | """ 100 | self.policy.post_process_fn(batch, buffer, indices) 101 | batch.rew = batch.policy.orig_rew # restore original reward 102 | 103 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 104 | res = self.policy.learn(batch, **kwargs) 105 | self.optim.zero_grad() 106 | act_hat = batch.policy.act_hat 107 | act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) 108 | inverse_loss = F.cross_entropy(act_hat, act).mean() 109 | forward_loss = batch.policy.mse_loss.mean() 110 | loss = ( 111 | (1 - self.forward_loss_weight) * inverse_loss + 112 | self.forward_loss_weight * forward_loss 113 | ) * self.lr_scale 114 | loss.backward() 115 | self.optim.step() 116 | res.update( 117 | { 118 | "loss/icm": loss.item(), 119 | "loss/icm/forward": forward_loss.item(), 120 | "loss/icm/inverse": inverse_loss.item() 121 | } 122 | ) 123 | return res 124 | -------------------------------------------------------------------------------- /tianshou/policy/imitation/td3_bc.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from tianshou.data import Batch, to_torch_as 7 | from tianshou.exploration import BaseNoise, GaussianNoise 8 | from tianshou.policy import TD3Policy 9 | 10 | 11 | class TD3BCPolicy(TD3Policy): 12 | """Implementation of TD3+BC. arXiv:2106.06860. 13 | 14 | :param torch.nn.Module actor: the actor network following the rules in 15 | :class:`~tianshou.policy.BasePolicy`. (s -> logits) 16 | :param torch.optim.Optimizer actor_optim: the optimizer for actor network. 17 | :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a)) 18 | :param torch.optim.Optimizer critic1_optim: the optimizer for the first 19 | critic network. 20 | :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a)) 21 | :param torch.optim.Optimizer critic2_optim: the optimizer for the second 22 | critic network. 23 | :param float tau: param for soft update of the target network. Default to 0.005. 24 | :param float gamma: discount factor, in [0, 1]. Default to 0.99. 25 | :param float exploration_noise: the exploration noise, add to the action. 26 | Default to ``GaussianNoise(sigma=0.1)`` 27 | :param float policy_noise: the noise used in updating policy network. 28 | Default to 0.2. 29 | :param int update_actor_freq: the update frequency of actor network. 30 | Default to 2. 31 | :param float noise_clip: the clipping range used in updating policy network. 32 | Default to 0.5. 33 | :param float alpha: the value of alpha, which controls the weight for TD3 learning 34 | relative to behavior cloning. 35 | :param bool reward_normalization: normalize the reward to Normal(0, 1). 36 | Default to False. 37 | :param bool action_scaling: whether to map actions from range [-1, 1] to range 38 | [action_spaces.low, action_spaces.high]. Default to True. 39 | :param str action_bound_method: method to bound action to range [-1, 1], can be 40 | either "clip" (for simply clipping the action) or empty string for no bounding. 41 | Default to "clip". 42 | :param Optional[gym.Space] action_space: env's action space, mandatory if you want 43 | to use option "action_scaling" or "action_bound_method". Default to None. 44 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 45 | optimizer in each policy.update(). Default to None (no lr_scheduler). 46 | 47 | .. 
seealso:: 48 | 49 | Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed 50 | explanation. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | actor: torch.nn.Module, 56 | actor_optim: torch.optim.Optimizer, 57 | critic1: torch.nn.Module, 58 | critic1_optim: torch.optim.Optimizer, 59 | critic2: torch.nn.Module, 60 | critic2_optim: torch.optim.Optimizer, 61 | tau: float = 0.005, 62 | gamma: float = 0.99, 63 | exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1), 64 | policy_noise: float = 0.2, 65 | update_actor_freq: int = 2, 66 | noise_clip: float = 0.5, 67 | alpha: float = 2.5, 68 | reward_normalization: bool = False, 69 | estimation_step: int = 1, 70 | **kwargs: Any, 71 | ) -> None: 72 | super().__init__( 73 | actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau, 74 | gamma, exploration_noise, policy_noise, update_actor_freq, noise_clip, 75 | reward_normalization, estimation_step, **kwargs 76 | ) 77 | self._alpha = alpha 78 | 79 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 80 | # critic 1&2 81 | td1, critic1_loss = self._mse_optimizer( 82 | batch, self.critic1, self.critic1_optim 83 | ) 84 | td2, critic2_loss = self._mse_optimizer( 85 | batch, self.critic2, self.critic2_optim 86 | ) 87 | batch.weight = (td1 + td2) / 2.0 # prio-buffer 88 | 89 | # actor 90 | if self._cnt % self._freq == 0: 91 | act = self(batch, eps=0.0).act 92 | q_value = self.critic1(batch.obs, act) 93 | lmbda = self._alpha / q_value.abs().mean().detach() 94 | actor_loss = -lmbda * q_value.mean() + F.mse_loss( 95 | act, to_torch_as(batch.act, act) 96 | ) 97 | self.actor_optim.zero_grad() 98 | actor_loss.backward() 99 | self._last = actor_loss.item() 100 | self.actor_optim.step() 101 | self.sync_weight() 102 | self._cnt += 1 103 | return { 104 | "loss/actor": self._last, 105 | "loss/critic1": critic1_loss.item(), 106 | "loss/critic2": critic2_loss.item(), 107 | } 108 | -------------------------------------------------------------------------------- /tianshou/policy/modelfree/iqn.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from tianshou.data import Batch, to_numpy 8 | from tianshou.policy import QRDQNPolicy 9 | 10 | 11 | class IQNPolicy(QRDQNPolicy): 12 | """Implementation of Implicit Quantile Network. arXiv:1806.06923. 13 | 14 | :param torch.nn.Module model: a model following the rules in 15 | :class:`~tianshou.policy.BasePolicy`. (s -> logits) 16 | :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. 17 | :param float discount_factor: in [0, 1]. 18 | :param int sample_size: the number of samples for policy evaluation. 19 | Default to 32. 20 | :param int online_sample_size: the number of samples for online model 21 | in training. Default to 8. 22 | :param int target_sample_size: the number of samples for target model 23 | in training. Default to 8. 24 | :param int estimation_step: the number of steps to look ahead. Default to 1. 25 | :param int target_update_freq: the target network update frequency (0 if 26 | you do not use the target network). 27 | :param bool reward_normalization: normalize the reward to Normal(0, 1). 28 | Default to False. 29 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 30 | optimizer in each policy.update(). Default to None (no lr_scheduler). 31 | 32 | .. 
seealso:: 33 | 34 | Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed 35 | explanation. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | model: torch.nn.Module, 41 | optim: torch.optim.Optimizer, 42 | discount_factor: float = 0.99, 43 | sample_size: int = 32, 44 | online_sample_size: int = 8, 45 | target_sample_size: int = 8, 46 | estimation_step: int = 1, 47 | target_update_freq: int = 0, 48 | reward_normalization: bool = False, 49 | **kwargs: Any, 50 | ) -> None: 51 | super().__init__( 52 | model, optim, discount_factor, sample_size, estimation_step, 53 | target_update_freq, reward_normalization, **kwargs 54 | ) 55 | assert sample_size > 1, "sample_size should be greater than 1" 56 | assert online_sample_size > 1, "online_sample_size should be greater than 1" 57 | assert target_sample_size > 1, "target_sample_size should be greater than 1" 58 | self._sample_size = sample_size # for policy eval 59 | self._online_sample_size = online_sample_size 60 | self._target_sample_size = target_sample_size 61 | 62 | def forward( 63 | self, 64 | batch: Batch, 65 | state: Optional[Union[dict, Batch, np.ndarray]] = None, 66 | model: str = "model", 67 | input: str = "obs", 68 | **kwargs: Any, 69 | ) -> Batch: 70 | if model == "model_old": 71 | sample_size = self._target_sample_size 72 | elif self.training: 73 | sample_size = self._online_sample_size 74 | else: 75 | sample_size = self._sample_size 76 | model = getattr(self, model) 77 | obs = batch[input] 78 | obs_next = obs.obs if hasattr(obs, "obs") else obs 79 | (logits, taus), hidden = model( 80 | obs_next, sample_size=sample_size, state=state, info=batch.info 81 | ) 82 | q = self.compute_q_value(logits, getattr(obs, "mask", None)) 83 | if not hasattr(self, "max_action_num"): 84 | self.max_action_num = q.shape[1] 85 | act = to_numpy(q.max(dim=1)[1]) 86 | return Batch(logits=logits, act=act, state=hidden, taus=taus) 87 | 88 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 89 | if self._target and self._iter % self._freq == 0: 90 | self.sync_weight() 91 | self.optim.zero_grad() 92 | weight = batch.pop("weight", 1.0) 93 | action_batch = self(batch) 94 | curr_dist, taus = action_batch.logits, action_batch.taus 95 | act = batch.act 96 | curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) 97 | target_dist = batch.returns.unsqueeze(1) 98 | # calculate each element's difference between curr_dist and target_dist 99 | dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") 100 | huber_loss = ( 101 | dist_diff * 102 | (taus.unsqueeze(2) - 103 | (target_dist - curr_dist).detach().le(0.).float()).abs() 104 | ).sum(-1).mean(1) 105 | loss = (huber_loss * weight).mean() 106 | # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ 107 | # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 108 | batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer 109 | loss.backward() 110 | self.optim.step() 111 | self._iter += 1 112 | return {"loss": loss.item()} 113 | -------------------------------------------------------------------------------- /tianshou/env/pettingzoo_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Any, Dict, List, Tuple, Union 3 | 4 | import gym.spaces 5 | from pettingzoo.utils.env import AECEnv 6 | from pettingzoo.utils.wrappers import BaseWrapper 7 | 8 | 9 | class PettingZooEnv(AECEnv, ABC): 10 | """The interface for petting zoo environments. 
11 | 12 | Multi-agent environments must be wrapped as 13 | :class:`~tianshou.env.PettingZooEnv`. Here is the usage: 14 | :: 15 | 16 | env = PettingZooEnv(...) 17 | # obs is a dict containing obs, agent_id, and mask 18 | obs = env.reset() 19 | action = policy(obs) 20 | obs, rew, done, info = env.step(action) 21 | env.close() 22 | 23 | The available action's mask is set to True, otherwise it is set to False. 24 | Further usage can be found at :ref:`marl_example`. 25 | """ 26 | 27 | def __init__(self, env: BaseWrapper): 28 | super().__init__() 29 | self.env = env 30 | # agent idx list 31 | self.agents = self.env.possible_agents 32 | self.agent_idx = {} 33 | for i, agent_id in enumerate(self.agents): 34 | self.agent_idx[agent_id] = i 35 | 36 | self.rewards = [0] * len(self.agents) 37 | 38 | # Get first observation space, assuming all agents have equal space 39 | self.observation_space: Any = self.env.observation_space(self.agents[0]) 40 | 41 | # Get first action space, assuming all agents have equal space 42 | self.action_space: Any = self.env.action_space(self.agents[0]) 43 | 44 | assert all(self.env.observation_space(agent) == self.observation_space 45 | for agent in self.agents), \ 46 | "Observation spaces for all agents must be identical. Perhaps " \ 47 | "SuperSuit's pad_observations wrapper can help (useage: " \ 48 | "`supersuit.aec_wrappers.pad_observations(env)`" 49 | 50 | assert all(self.env.action_space(agent) == self.action_space 51 | for agent in self.agents), \ 52 | "Action spaces for all agents must be identical. Perhaps " \ 53 | "SuperSuit's pad_action_space wrapper can help (useage: " \ 54 | "`supersuit.aec_wrappers.pad_action_space(env)`" 55 | 56 | self.reset() 57 | 58 | def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: 59 | self.env.reset(*args, **kwargs) 60 | observation, _, _, info = self.env.last(self) 61 | if isinstance(observation, dict) and 'action_mask' in observation: 62 | observation_dict = { 63 | 'agent_id': self.env.agent_selection, 64 | 'obs': observation['observation'], 65 | 'mask': 66 | [True if obm == 1 else False for obm in observation['action_mask']] 67 | } 68 | else: 69 | if isinstance(self.action_space, gym.spaces.Discrete): 70 | observation_dict = { 71 | 'agent_id': self.env.agent_selection, 72 | 'obs': observation, 73 | 'mask': [True] * self.env.action_space(self.env.agent_selection).n 74 | } 75 | else: 76 | observation_dict = { 77 | 'agent_id': self.env.agent_selection, 78 | 'obs': observation, 79 | } 80 | 81 | if "return_info" in kwargs and kwargs["return_info"]: 82 | return observation_dict, info 83 | else: 84 | return observation_dict 85 | 86 | def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: 87 | self.env.step(action) 88 | observation, rew, done, info = self.env.last() 89 | if isinstance(observation, dict) and 'action_mask' in observation: 90 | obs = { 91 | 'agent_id': self.env.agent_selection, 92 | 'obs': observation['observation'], 93 | 'mask': 94 | [True if obm == 1 else False for obm in observation['action_mask']] 95 | } 96 | else: 97 | if isinstance(self.action_space, gym.spaces.Discrete): 98 | obs = { 99 | 'agent_id': self.env.agent_selection, 100 | 'obs': observation, 101 | 'mask': [True] * self.env.action_space(self.env.agent_selection).n 102 | } 103 | else: 104 | obs = {'agent_id': self.env.agent_selection, 'obs': observation} 105 | 106 | for agent_id, reward in self.env.rewards.items(): 107 | self.rewards[self.agent_idx[agent_id]] = reward 108 | return obs, self.rewards, done, info 109 | 110 | 
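    # Note: step() returns the full per-agent reward list (ordered as in
    # ``self.agents``/``possible_agents`` via ``self.agent_idx``), copied from
    # the underlying AEC environment's ``rewards`` dict, rather than a scalar.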
def close(self) -> None: 111 | self.env.close() 112 | 113 | def seed(self, seed: Any = None) -> None: 114 | try: 115 | self.env.seed(seed) 116 | except NotImplementedError: 117 | self.env.reset(seed=seed) 118 | 119 | def render(self, mode: str = "human") -> Any: 120 | return self.env.render(mode) 121 | -------------------------------------------------------------------------------- /tianshou/data/utils/segtree.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import numpy as np 4 | from numba import njit 5 | 6 | 7 | class SegmentTree: 8 | """Implementation of Segment Tree. 9 | 10 | The segment tree stores an array ``arr`` with size ``n``. It supports value 11 | update and fast query of the sum for the interval ``[left, right)`` in 12 | O(log n) time. The detailed procedure is as follows: 13 | 14 | 1. Pad the array to have length of power of 2, so that leaf nodes in the \ 15 | segment tree have the same depth. 16 | 2. Store the segment tree in a binary heap. 17 | 18 | :param int size: the size of segment tree. 19 | """ 20 | 21 | def __init__(self, size: int) -> None: 22 | bound = 1 23 | while bound < size: 24 | bound *= 2 25 | self._size = size 26 | self._bound = bound 27 | self._value = np.zeros([bound * 2]) 28 | self._compile() 29 | 30 | def __len__(self) -> int: 31 | return self._size 32 | 33 | def __getitem__(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]: 34 | """Return self[index].""" 35 | return self._value[index + self._bound] 36 | 37 | def __setitem__( 38 | self, index: Union[int, np.ndarray], value: Union[float, np.ndarray] 39 | ) -> None: 40 | """Update values in segment tree. 41 | 42 | Duplicate values in ``index`` are handled by numpy: later index 43 | overwrites previous ones. 44 | :: 45 | 46 | >>> a = np.array([1, 2, 3, 4]) 47 | >>> a[[0, 1, 0, 1]] = [4, 5, 6, 7] 48 | >>> print(a) 49 | [6 7 3 4] 50 | """ 51 | if isinstance(index, int): 52 | index, value = np.array([index]), np.array([value]) 53 | assert np.all(0 <= index) and np.all(index < self._size) 54 | _setitem(self._value, index + self._bound, value) 55 | 56 | def reduce(self, start: int = 0, end: Optional[int] = None) -> float: 57 | """Return operation(value[start:end]).""" 58 | if start == 0 and end is None: 59 | return self._value[1] 60 | if end is None: 61 | end = self._size 62 | if end < 0: 63 | end += self._size 64 | return _reduce(self._value, start + self._bound - 1, end + self._bound) 65 | 66 | def get_prefix_sum_idx(self, value: Union[float, 67 | np.ndarray]) -> Union[int, np.ndarray]: 68 | r"""Find the index with given value. 69 | 70 | Return the minimum index for each ``v`` in ``value`` so that 71 | :math:`v \le \mathrm{sums}_i`, where 72 | :math:`\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j`. 73 | 74 | .. warning:: 75 | 76 | Please make sure all of the values inside the segment tree are 77 | non-negative when using this function. 
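        A small example (the stored values are chosen only for illustration)::

            >>> tree = SegmentTree(4)
            >>> tree[np.arange(4)] = np.array([1.0, 2.0, 3.0, 4.0])
            >>> tree.reduce()                  # total sum stored at the root
            10.0
            >>> tree.get_prefix_sum_idx(2.5)   # prefix sums are [1, 3, 6, 10]
            1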
78 | """ 79 | assert np.all(value >= 0.0) and np.all(value < self._value[1]) 80 | single = False 81 | if not isinstance(value, np.ndarray): 82 | value = np.array([value]) 83 | single = True 84 | index = _get_prefix_sum_idx(value, self._bound, self._value) 85 | return index.item() if single else index 86 | 87 | def _compile(self) -> None: 88 | f64 = np.array([0, 1], dtype=np.float64) 89 | f32 = np.array([0, 1], dtype=np.float32) 90 | i64 = np.array([0, 1], dtype=np.int64) 91 | _setitem(f64, i64, f64) 92 | _setitem(f64, i64, f32) 93 | _reduce(f64, 0, 1) 94 | _get_prefix_sum_idx(f64, 1, f64) 95 | _get_prefix_sum_idx(f32, 1, f64) 96 | 97 | 98 | @njit 99 | def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None: 100 | """Numba version, 4x faster: 0.1 -> 0.024.""" 101 | tree[index] = value 102 | while index[0] > 1: 103 | index //= 2 104 | tree[index] = tree[index * 2] + tree[index * 2 + 1] 105 | 106 | 107 | @njit 108 | def _reduce(tree: np.ndarray, start: int, end: int) -> float: 109 | """Numba version, 2x faster: 0.009 -> 0.005.""" 110 | # nodes in (start, end) should be aggregated 111 | result = 0.0 112 | while end - start > 1: # (start, end) interval is not empty 113 | if start % 2 == 0: 114 | result += tree[start + 1] 115 | start //= 2 116 | if end % 2 == 1: 117 | result += tree[end - 1] 118 | end //= 2 119 | return result 120 | 121 | 122 | @njit 123 | def _get_prefix_sum_idx(value: np.ndarray, bound: int, sums: np.ndarray) -> np.ndarray: 124 | """Numba version (v0.51), 5x speed up with size=100000 and bsz=64. 125 | 126 | vectorized np: 0.0923 (numpy best) -> 0.024 (now) 127 | for-loop: 0.2914 -> 0.019 (but not so stable) 128 | """ 129 | index = np.ones(value.shape, dtype=np.int64) 130 | while index[0] < bound: 131 | index *= 2 132 | lsons = sums[index] 133 | direct = lsons < value 134 | value -= lsons * direct 135 | index += direct 136 | index -= bound 137 | return index 138 | -------------------------------------------------------------------------------- /tianshou/utils/logger/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from numbers import Number 3 | from typing import Callable, Dict, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | 7 | LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]] 8 | 9 | 10 | class BaseLogger(ABC): 11 | """The base class for any logger which is compatible with trainer. 12 | 13 | Try to overwrite write() method to use your own writer. 14 | 15 | :param int train_interval: the log interval in log_train_data(). Default to 1000. 16 | :param int test_interval: the log interval in log_test_data(). Default to 1. 17 | :param int update_interval: the log interval in log_update_data(). Default to 1000. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | train_interval: int = 1000, 23 | test_interval: int = 1, 24 | update_interval: int = 1000, 25 | ) -> None: 26 | super().__init__() 27 | self.train_interval = train_interval 28 | self.test_interval = test_interval 29 | self.update_interval = update_interval 30 | self.last_log_train_step = -1 31 | self.last_log_test_step = -1 32 | self.last_log_update_step = -1 33 | 34 | @abstractmethod 35 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: 36 | """Specify how the writer is used to log data. 37 | 38 | :param str step_type: namespace which the data dict belongs to. 39 | :param int step: stands for the ordinate of the data dict. 
40 | :param dict data: the data to write with format ``{key: value}``. 41 | """ 42 | pass 43 | 44 | def log_train_data(self, collect_result: dict, step: int) -> None: 45 | """Use writer to log statistics generated during training. 46 | 47 | :param collect_result: a dict containing information of data collected in 48 | training stage, i.e., returns of collector.collect(). 49 | :param int step: stands for the timestep the collect_result being logged. 50 | """ 51 | if collect_result["n/ep"] > 0: 52 | if step - self.last_log_train_step >= self.train_interval: 53 | log_data = { 54 | "train/episode": collect_result["n/ep"], 55 | "train/reward": collect_result["rew"], 56 | "train/length": collect_result["len"], 57 | } 58 | self.write("train/env_step", step, log_data) 59 | self.last_log_train_step = step 60 | 61 | def log_test_data(self, collect_result: dict, step: int) -> None: 62 | """Use writer to log statistics generated during evaluating. 63 | 64 | :param collect_result: a dict containing information of data collected in 65 | evaluating stage, i.e., returns of collector.collect(). 66 | :param int step: stands for the timestep the collect_result being logged. 67 | """ 68 | assert collect_result["n/ep"] > 0 69 | if step - self.last_log_test_step >= self.test_interval: 70 | log_data = { 71 | "test/env_step": step, 72 | "test/reward": collect_result["rew"], 73 | "test/length": collect_result["len"], 74 | "test/reward_std": collect_result["rew_std"], 75 | "test/length_std": collect_result["len_std"], 76 | } 77 | self.write("test/env_step", step, log_data) 78 | self.last_log_test_step = step 79 | 80 | def log_update_data(self, update_result: dict, step: int) -> None: 81 | """Use writer to log statistics generated during updating. 82 | 83 | :param update_result: a dict containing information of data collected in 84 | updating stage, i.e., returns of policy.update(). 85 | :param int step: stands for the timestep the collect_result being logged. 86 | """ 87 | if step - self.last_log_update_step >= self.update_interval: 88 | log_data = {f"update/{k}": v for k, v in update_result.items()} 89 | self.write("update/gradient_step", step, log_data) 90 | self.last_log_update_step = step 91 | 92 | def save_data( 93 | self, 94 | epoch: int, 95 | env_step: int, 96 | gradient_step: int, 97 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, 98 | ) -> None: 99 | """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. 100 | 101 | :param int epoch: the epoch in trainer. 102 | :param int env_step: the env_step in trainer. 103 | :param int gradient_step: the gradient_step in trainer. 104 | :param function save_checkpoint_fn: a hook defined by user, see trainer 105 | documentation for detail. 106 | """ 107 | pass 108 | 109 | def restore_data(self) -> Tuple[int, int, int]: 110 | """Return the metadata from existing log. 111 | 112 | If it finds nothing or an error occurs during the recover process, it will 113 | return the default parameters. 114 | 115 | :return: epoch, env_step, gradient_step. 116 | """ 117 | pass 118 | 119 | 120 | class LazyLogger(BaseLogger): 121 | """A logger that does nothing. 
Used as the placeholder in trainer.""" 122 | 123 | def __init__(self) -> None: 124 | super().__init__() 125 | 126 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: 127 | """The LazyLogger writes nothing.""" 128 | pass 129 | -------------------------------------------------------------------------------- /tianshou/policy/imitation/discrete_bcq.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, Dict, Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from tianshou.data import Batch, ReplayBuffer, to_torch 9 | from tianshou.policy import DQNPolicy 10 | 11 | 12 | class DiscreteBCQPolicy(DQNPolicy): 13 | """Implementation of discrete BCQ algorithm. arXiv:1910.01708. 14 | 15 | :param torch.nn.Module model: a model following the rules in 16 | :class:`~tianshou.policy.BasePolicy`. (s -> q_value) 17 | :param torch.nn.Module imitator: a model following the rules in 18 | :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) 19 | :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. 20 | :param float discount_factor: in [0, 1]. 21 | :param int estimation_step: the number of steps to look ahead. Default to 1. 22 | :param int target_update_freq: the target network update frequency. 23 | :param float eval_eps: the epsilon-greedy noise added in evaluation. 24 | :param float unlikely_action_threshold: the threshold (tau) for unlikely 25 | actions, as shown in Equ. (17) in the paper. Default to 0.3. 26 | :param float imitation_logits_penalty: regularization weight for imitation 27 | logits. Default to 1e-2. 28 | :param bool reward_normalization: normalize the reward to Normal(0, 1). 29 | Default to False. 30 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 31 | optimizer in each policy.update(). Default to None (no lr_scheduler). 32 | 33 | .. seealso:: 34 | 35 | Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed 36 | explanation. 37 | """ 38 | 39 | def __init__( 40 | self, 41 | model: torch.nn.Module, 42 | imitator: torch.nn.Module, 43 | optim: torch.optim.Optimizer, 44 | discount_factor: float = 0.99, 45 | estimation_step: int = 1, 46 | target_update_freq: int = 8000, 47 | eval_eps: float = 1e-3, 48 | unlikely_action_threshold: float = 0.3, 49 | imitation_logits_penalty: float = 1e-2, 50 | reward_normalization: bool = False, 51 | **kwargs: Any, 52 | ) -> None: 53 | super().__init__( 54 | model, optim, discount_factor, estimation_step, target_update_freq, 55 | reward_normalization, **kwargs 56 | ) 57 | assert target_update_freq > 0, "BCQ needs target network setting." 
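        # The imitation network below estimates the behavior policy's action
        # logits; forward() masks out of the greedy argmax any action whose
        # imitation logit trails the best action's by more than -log(tau),
        # i.e. whose relative probability is below unlikely_action_threshold.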
58 | self.imitator = imitator 59 | assert 0.0 <= unlikely_action_threshold < 1.0, \ 60 | "unlikely_action_threshold should be in [0, 1)" 61 | if unlikely_action_threshold > 0: 62 | self._log_tau = math.log(unlikely_action_threshold) 63 | else: 64 | self._log_tau = -np.inf 65 | assert 0.0 <= eval_eps < 1.0 66 | self.eps = eval_eps 67 | self._weight_reg = imitation_logits_penalty 68 | 69 | def train(self, mode: bool = True) -> "DiscreteBCQPolicy": 70 | self.training = mode 71 | self.model.train(mode) 72 | self.imitator.train(mode) 73 | return self 74 | 75 | def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: 76 | batch = buffer[indices] # batch.obs_next: s_{t+n} 77 | # target_Q = Q_old(s_, argmax(Q_new(s_, *))) 78 | act = self(batch, input="obs_next").act 79 | target_q, _ = self.model_old(batch.obs_next) 80 | target_q = target_q[np.arange(len(act)), act] 81 | return target_q 82 | 83 | def forward( # type: ignore 84 | self, 85 | batch: Batch, 86 | state: Optional[Union[dict, Batch, np.ndarray]] = None, 87 | input: str = "obs", 88 | **kwargs: Any, 89 | ) -> Batch: 90 | obs = batch[input] 91 | q_value, state = self.model(obs, state=state, info=batch.info) 92 | if not hasattr(self, "max_action_num"): 93 | self.max_action_num = q_value.shape[1] 94 | imitation_logits, _ = self.imitator(obs, state=state, info=batch.info) 95 | 96 | # mask actions for argmax 97 | ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values 98 | mask = (ratio < self._log_tau).float() 99 | act = (q_value - np.inf * mask).argmax(dim=-1) 100 | 101 | return Batch( 102 | act=act, state=state, q_value=q_value, imitation_logits=imitation_logits 103 | ) 104 | 105 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 106 | if self._iter % self._freq == 0: 107 | self.sync_weight() 108 | self._iter += 1 109 | 110 | target_q = batch.returns.flatten() 111 | result = self(batch) 112 | imitation_logits = result.imitation_logits 113 | current_q = result.q_value[np.arange(len(target_q)), batch.act] 114 | act = to_torch(batch.act, dtype=torch.long, device=target_q.device) 115 | q_loss = F.smooth_l1_loss(current_q, target_q) 116 | i_loss = F.nll_loss(F.log_softmax(imitation_logits, dim=-1), act) 117 | reg_loss = imitation_logits.pow(2).mean() 118 | loss = q_loss + i_loss + self._weight_reg * reg_loss 119 | 120 | self.optim.zero_grad() 121 | loss.backward() 122 | self.optim.step() 123 | 124 | return { 125 | "loss": loss.item(), 126 | "loss/q": q_loss.item(), 127 | "loss/i": i_loss.item(), 128 | "loss/reg": reg_loss.item(), 129 | } 130 | -------------------------------------------------------------------------------- /tianshou/policy/imitation/discrete_crr.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Dict 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.distributions import Categorical 7 | 8 | from tianshou.data import Batch, to_torch, to_torch_as 9 | from tianshou.policy.modelfree.pg import PGPolicy 10 | 11 | 12 | class DiscreteCRRPolicy(PGPolicy): 13 | r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. 14 | 15 | :param torch.nn.Module actor: the actor network following the rules in 16 | :class:`~tianshou.policy.BasePolicy`. (s -> logits) 17 | :param torch.nn.Module critic: the action-value critic (i.e., Q function) 18 | network. (s -> Q(s, \*)) 19 | :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. 
20 | :param float discount_factor: in [0, 1]. Default to 0.99. 21 | :param str policy_improvement_mode: type of the weight function f. Possible 22 | values: "binary"/"exp"/"all". Default to "exp". 23 | :param float ratio_upper_bound: when policy_improvement_mode is "exp", the value 24 | of the exp function is upper-bounded by this parameter. Default to 20. 25 | :param float beta: when policy_improvement_mode is "exp", this is the denominator 26 | of the exp function. Default to 1. 27 | :param float min_q_weight: weight for CQL loss/regularizer. Default to 10. 28 | :param int target_update_freq: the target network update frequency (0 if 29 | you do not use the target network). Default to 0. 30 | :param bool reward_normalization: normalize the reward to Normal(0, 1). 31 | Default to False. 32 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 33 | optimizer in each policy.update(). Default to None (no lr_scheduler). 34 | 35 | .. seealso:: 36 | Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed 37 | explanation. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | actor: torch.nn.Module, 43 | critic: torch.nn.Module, 44 | optim: torch.optim.Optimizer, 45 | discount_factor: float = 0.99, 46 | policy_improvement_mode: str = "exp", 47 | ratio_upper_bound: float = 20.0, 48 | beta: float = 1.0, 49 | min_q_weight: float = 10.0, 50 | target_update_freq: int = 0, 51 | reward_normalization: bool = False, 52 | **kwargs: Any, 53 | ) -> None: 54 | super().__init__( 55 | actor, 56 | optim, 57 | lambda x: Categorical(logits=x), # type: ignore 58 | discount_factor, 59 | reward_normalization, 60 | **kwargs, 61 | ) 62 | self.critic = critic 63 | self._target = target_update_freq > 0 64 | self._freq = target_update_freq 65 | self._iter = 0 66 | if self._target: 67 | self.actor_old = deepcopy(self.actor) 68 | self.actor_old.eval() 69 | self.critic_old = deepcopy(self.critic) 70 | self.critic_old.eval() 71 | else: 72 | self.actor_old = self.actor 73 | self.critic_old = self.critic 74 | assert policy_improvement_mode in ["exp", "binary", "all"] 75 | self._policy_improvement_mode = policy_improvement_mode 76 | self._ratio_upper_bound = ratio_upper_bound 77 | self._beta = beta 78 | self._min_q_weight = min_q_weight 79 | 80 | def sync_weight(self) -> None: 81 | self.actor_old.load_state_dict(self.actor.state_dict()) 82 | self.critic_old.load_state_dict(self.critic.state_dict()) 83 | 84 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignore 85 | if self._target and self._iter % self._freq == 0: 86 | self.sync_weight() 87 | self.optim.zero_grad() 88 | q_t = self.critic(batch.obs) 89 | act = to_torch(batch.act, dtype=torch.long, device=q_t.device) 90 | qa_t = q_t.gather(1, act.unsqueeze(1)) 91 | # Critic loss 92 | with torch.no_grad(): 93 | target_a_t, _ = self.actor_old(batch.obs_next) 94 | target_m = Categorical(logits=target_a_t) 95 | q_t_target = self.critic_old(batch.obs_next) 96 | rew = to_torch_as(batch.rew, q_t_target) 97 | expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) 98 | expected_target_q[batch.done > 0] = 0.0 99 | target = rew.unsqueeze(1) + self._gamma * expected_target_q 100 | critic_loss = 0.5 * F.mse_loss(qa_t, target) 101 | # Actor loss 102 | act_target, _ = self.actor(batch.obs) 103 | dist = Categorical(logits=act_target) 104 | expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True) 105 | advantage = qa_t - expected_policy_q 106 | if self._policy_improvement_mode == "binary": 107 | actor_loss_coef 
= (advantage > 0).float() 108 | elif self._policy_improvement_mode == "exp": 109 | actor_loss_coef = ( 110 | (advantage / self._beta).exp().clamp(0, self._ratio_upper_bound) 111 | ) 112 | else: 113 | actor_loss_coef = 1.0 # effectively behavior cloning 114 | actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean() 115 | # CQL loss/regularizer 116 | min_q_loss = (q_t.logsumexp(1) - qa_t).mean() 117 | loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss 118 | loss.backward() 119 | self.optim.step() 120 | self._iter += 1 121 | return { 122 | "loss": loss.item(), 123 | "loss/actor": actor_loss.item(), 124 | "loss/critic": critic_loss.item(), 125 | "loss/cql": min_q_loss.item(), 126 | } 127 | -------------------------------------------------------------------------------- /tianshou/trainer/offline.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Optional, Union 2 | 3 | import numpy as np 4 | 5 | from tianshou.data import Collector, ReplayBuffer 6 | from tianshou.policy import BasePolicy 7 | from tianshou.trainer.base import BaseTrainer 8 | from tianshou.utils import BaseLogger, LazyLogger 9 | 10 | 11 | class OfflineTrainer(BaseTrainer): 12 | """Create an iterator class for offline training procedure. 13 | 14 | :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. 15 | :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. 16 | This buffer must be populated with experiences for offline RL. 17 | :param Collector test_collector: the collector used for testing. If it's None, 18 | then no testing will be performed. 19 | :param int max_epoch: the maximum number of epochs for training. The training 20 | process might be finished before reaching ``max_epoch`` if ``stop_fn`` is 21 | set. 22 | :param int update_per_epoch: the number of policy network updates, so-called 23 | gradient steps, per epoch. 24 | :param episode_per_test: the number of episodes for one policy evaluation. 25 | :param int batch_size: the batch size of sample data, which is going to feed in 26 | the policy network. 27 | :param function test_fn: a hook called at the beginning of testing in each 28 | epoch. 29 | It can be used to perform custom additional operations, with the signature 30 | ``f(num_epoch: int, step_idx: int) -> None``. 31 | :param function save_best_fn: a hook called when the undiscounted average mean 32 | reward in evaluation phase gets better, with the signature 33 | ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously. 34 | :param function save_checkpoint_fn: a function to save training process and 35 | return the saved checkpoint path, with the signature ``f(epoch: int, 36 | env_step: int, gradient_step: int) -> str``; you can save whatever you want. 37 | Because offline-RL doesn't have env_step, the env_step is always 0 here. 38 | :param bool resume_from_log: resume gradient_step and other metadata from 39 | existing tensorboard log. Default to False. 40 | :param function stop_fn: a function with signature ``f(mean_rewards: float) -> 41 | bool``, receives the average undiscounted returns of the testing result, 42 | returns a boolean which indicates whether reaching the goal. 43 | :param function reward_metric: a function with signature ``f(rewards: 44 | np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape 45 | (num_episode,)``, used in multi-agent RL. 
We need to return a single scalar 46 | for each episode's result to monitor training in the multi-agent RL 47 | setting. This function specifies what is the desired metric, e.g., the 48 | reward of agent 1 or the average reward over all agents. 49 | :param BaseLogger logger: A logger that logs statistics during 50 | updating/testing. Default to a logger that doesn't log anything. 51 | :param bool verbose: whether to print the information. Default to True. 52 | :param bool show_progress: whether to display a progress bar when training. 53 | Default to True. 54 | """ 55 | 56 | __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:]) 57 | 58 | def __init__( 59 | self, 60 | policy: BasePolicy, 61 | buffer: ReplayBuffer, 62 | test_collector: Optional[Collector], 63 | max_epoch: int, 64 | update_per_epoch: int, 65 | episode_per_test: int, 66 | batch_size: int, 67 | test_fn: Optional[Callable[[int, Optional[int]], None]] = None, 68 | stop_fn: Optional[Callable[[float], bool]] = None, 69 | save_best_fn: Optional[Callable[[BasePolicy], None]] = None, 70 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, 71 | resume_from_log: bool = False, 72 | reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, 73 | logger: BaseLogger = LazyLogger(), 74 | verbose: bool = True, 75 | show_progress: bool = True, 76 | **kwargs: Any, 77 | ): 78 | super().__init__( 79 | learning_type="offline", 80 | policy=policy, 81 | buffer=buffer, 82 | test_collector=test_collector, 83 | max_epoch=max_epoch, 84 | update_per_epoch=update_per_epoch, 85 | step_per_epoch=update_per_epoch, 86 | episode_per_test=episode_per_test, 87 | batch_size=batch_size, 88 | test_fn=test_fn, 89 | stop_fn=stop_fn, 90 | save_best_fn=save_best_fn, 91 | save_checkpoint_fn=save_checkpoint_fn, 92 | resume_from_log=resume_from_log, 93 | reward_metric=reward_metric, 94 | logger=logger, 95 | verbose=verbose, 96 | show_progress=show_progress, 97 | **kwargs, 98 | ) 99 | 100 | def policy_update_fn( 101 | self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None 102 | ) -> None: 103 | """Perform one off-line policy update.""" 104 | assert self.buffer 105 | self.gradient_step += 1 106 | losses = self.policy.update(self.batch_size, self.buffer) 107 | data.update({"gradient_step": str(self.gradient_step)}) 108 | self.log_update_data(data, losses) 109 | 110 | 111 | def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore 112 | """Wrapper for offline_trainer run method. 113 | 114 | It is identical to ``OfflineTrainer(...).run()``. 115 | 116 | :return: See :func:`~tianshou.trainer.gather_info`. 117 | """ 118 | return OfflineTrainer(*args, **kwargs).run() 119 | 120 | 121 | offline_trainer_iter = OfflineTrainer 122 | -------------------------------------------------------------------------------- /tianshou/policy/modelfree/pg.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Type, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as 7 | from tianshou.policy import BasePolicy 8 | from tianshou.utils import RunningMeanStd 9 | 10 | 11 | class PGPolicy(BasePolicy): 12 | """Implementation of REINFORCE algorithm. 13 | 14 | :param torch.nn.Module model: a model following the rules in 15 | :class:`~tianshou.policy.BasePolicy`. (s -> logits) 16 | :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. 
17 | :param dist_fn: distribution class for computing the action. 18 | :type dist_fn: Type[torch.distributions.Distribution] 19 | :param float discount_factor: in [0, 1]. Default to 0.99. 20 | :param bool action_scaling: whether to map actions from range [-1, 1] to range 21 | [action_spaces.low, action_spaces.high]. Default to True. 22 | :param str action_bound_method: method to bound action to range [-1, 1], can be 23 | either "clip" (for simply clipping the action), "tanh" (for applying tanh 24 | squashing) for now, or empty string for no bounding. Default to "clip". 25 | :param Optional[gym.Space] action_space: env's action space, mandatory if you want 26 | to use option "action_scaling" or "action_bound_method". Default to None. 27 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 28 | optimizer in each policy.update(). Default to None (no lr_scheduler). 29 | :param bool deterministic_eval: whether to use deterministic action instead of 30 | stochastic action sampled by the policy. Default to False. 31 | 32 | .. seealso:: 33 | 34 | Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed 35 | explanation. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | model: torch.nn.Module, 41 | optim: torch.optim.Optimizer, 42 | dist_fn: Type[torch.distributions.Distribution], 43 | discount_factor: float = 0.99, 44 | reward_normalization: bool = False, 45 | action_scaling: bool = True, 46 | action_bound_method: str = "clip", 47 | deterministic_eval: bool = False, 48 | **kwargs: Any, 49 | ) -> None: 50 | super().__init__( 51 | action_scaling=action_scaling, 52 | action_bound_method=action_bound_method, 53 | **kwargs 54 | ) 55 | self.actor = model 56 | self.optim = optim 57 | self.dist_fn = dist_fn 58 | assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" 59 | self._gamma = discount_factor 60 | self._rew_norm = reward_normalization 61 | self.ret_rms = RunningMeanStd() 62 | self._eps = 1e-8 63 | self._deterministic_eval = deterministic_eval 64 | 65 | def process_fn( 66 | self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray 67 | ) -> Batch: 68 | r"""Compute the discounted returns for each transition. 69 | 70 | .. math:: 71 | G_t = \sum_{i=t}^T \gamma^{i-t}r_i 72 | 73 | where :math:`T` is the terminal time step, :math:`\gamma` is the 74 | discount factor, :math:`\gamma \in [0, 1]`. 75 | """ 76 | v_s_ = np.full(indices.shape, self.ret_rms.mean) 77 | unnormalized_returns, _ = self.compute_episodic_return( 78 | batch, buffer, indices, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0 79 | ) 80 | if self._rew_norm: 81 | batch.returns = (unnormalized_returns - self.ret_rms.mean) / \ 82 | np.sqrt(self.ret_rms.var + self._eps) 83 | self.ret_rms.update(unnormalized_returns) 84 | else: 85 | batch.returns = unnormalized_returns 86 | return batch 87 | 88 | def forward( 89 | self, 90 | batch: Batch, 91 | state: Optional[Union[dict, Batch, np.ndarray]] = None, 92 | **kwargs: Any, 93 | ) -> Batch: 94 | """Compute action over the given batch data. 95 | 96 | :return: A :class:`~tianshou.data.Batch` which has 4 keys: 97 | 98 | * ``act`` the action. 99 | * ``logits`` the network's raw output. 100 | * ``dist`` the action distribution. 101 | * ``state`` the hidden state. 102 | 103 | .. seealso:: 104 | 105 | Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for 106 | more detailed explanation. 
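        With ``deterministic_eval=True`` and the policy not in training mode, the
        action is chosen greedily: ``logits.argmax(-1)`` for discrete action
        spaces and the first distribution parameter ``logits[0]`` (typically the
        mean) for continuous ones; otherwise the action is sampled from ``dist``.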
107 | """ 108 | logits, hidden, _ = self.actor(batch.obs, state=state) 109 | if isinstance(logits, tuple): 110 | dist = self.dist_fn(*logits) 111 | else: 112 | dist = self.dist_fn(logits) 113 | if self._deterministic_eval and not self.training: 114 | if self.action_type == "discrete": 115 | act = logits.argmax(-1) 116 | elif self.action_type == "continuous": 117 | act = logits[0] 118 | else: 119 | act = dist.sample() 120 | return Batch(logits=logits, act=act, state=hidden, dist=dist) 121 | 122 | def learn( # type: ignore 123 | self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any 124 | ) -> Dict[str, List[float]]: 125 | losses = [] 126 | for _ in range(repeat): 127 | for minibatch in batch.split(batch_size, merge_last=True): 128 | self.optim.zero_grad() 129 | result = self(minibatch) 130 | dist = result.dist 131 | act = to_torch_as(minibatch.act, result.act) 132 | ret = to_torch(minibatch.returns, torch.float, result.act.device) 133 | log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1) 134 | loss = -(log_prob * ret).mean() 135 | loss.backward() 136 | self.optim.step() 137 | losses.append(loss.item()) 138 | 139 | return {"loss": losses} 140 | -------------------------------------------------------------------------------- /tianshou/policy/modelfree/td3.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Dict, Optional 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from tianshou.data import Batch, ReplayBuffer 8 | from tianshou.exploration import BaseNoise, GaussianNoise 9 | from tianshou.policy import DDPGPolicy 10 | 11 | 12 | class TD3Policy(DDPGPolicy): 13 | """Implementation of TD3, arXiv:1802.09477. 14 | 15 | :param torch.nn.Module actor: the actor network following the rules in 16 | :class:`~tianshou.policy.BasePolicy`. (s -> logits) 17 | :param torch.optim.Optimizer actor_optim: the optimizer for actor network. 18 | :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a)) 19 | :param torch.optim.Optimizer critic1_optim: the optimizer for the first 20 | critic network. 21 | :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a)) 22 | :param torch.optim.Optimizer critic2_optim: the optimizer for the second 23 | critic network. 24 | :param float tau: param for soft update of the target network. Default to 0.005. 25 | :param float gamma: discount factor, in [0, 1]. Default to 0.99. 26 | :param float exploration_noise: the exploration noise, add to the action. 27 | Default to ``GaussianNoise(sigma=0.1)`` 28 | :param float policy_noise: the noise used in updating policy network. 29 | Default to 0.2. 30 | :param int update_actor_freq: the update frequency of actor network. 31 | Default to 2. 32 | :param float noise_clip: the clipping range used in updating policy network. 33 | Default to 0.5. 34 | :param bool reward_normalization: normalize the reward to Normal(0, 1). 35 | Default to False. 36 | :param bool action_scaling: whether to map actions from range [-1, 1] to range 37 | [action_spaces.low, action_spaces.high]. Default to True. 38 | :param str action_bound_method: method to bound action to range [-1, 1], can be 39 | either "clip" (for simply clipping the action) or empty string for no bounding. 40 | Default to "clip". 41 | :param Optional[gym.Space] action_space: env's action space, mandatory if you want 42 | to use option "action_scaling" or "action_bound_method". Default to None. 
43 | :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in 44 | optimizer in each policy.update(). Default to None (no lr_scheduler). 45 | 46 | .. seealso:: 47 | 48 | Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed 49 | explanation. 50 | """ 51 | 52 | def __init__( 53 | self, 54 | actor: torch.nn.Module, 55 | actor_optim: torch.optim.Optimizer, 56 | critic1: torch.nn.Module, 57 | critic1_optim: torch.optim.Optimizer, 58 | critic2: torch.nn.Module, 59 | critic2_optim: torch.optim.Optimizer, 60 | tau: float = 0.005, 61 | gamma: float = 0.99, 62 | exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1), 63 | policy_noise: float = 0.2, 64 | update_actor_freq: int = 2, 65 | noise_clip: float = 0.5, 66 | reward_normalization: bool = False, 67 | estimation_step: int = 1, 68 | **kwargs: Any, 69 | ) -> None: 70 | super().__init__( 71 | actor, actor_optim, None, None, tau, gamma, exploration_noise, 72 | reward_normalization, estimation_step, **kwargs 73 | ) 74 | self.critic1, self.critic1_old = critic1, deepcopy(critic1) 75 | self.critic1_old.eval() 76 | self.critic1_optim = critic1_optim 77 | self.critic2, self.critic2_old = critic2, deepcopy(critic2) 78 | self.critic2_old.eval() 79 | self.critic2_optim = critic2_optim 80 | self._policy_noise = policy_noise 81 | self._freq = update_actor_freq 82 | self._noise_clip = noise_clip 83 | self._cnt = 0 84 | self._last = 0 85 | 86 | def train(self, mode: bool = True) -> "TD3Policy": 87 | self.training = mode 88 | self.actor.train(mode) 89 | self.critic1.train(mode) 90 | self.critic2.train(mode) 91 | return self 92 | 93 | def sync_weight(self) -> None: 94 | self.soft_update(self.critic1_old, self.critic1, self.tau) 95 | self.soft_update(self.critic2_old, self.critic2, self.tau) 96 | self.soft_update(self.actor_old, self.actor, self.tau) 97 | 98 | def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: 99 | batch = buffer[indices] # batch.obs: s_{t+n} 100 | act_ = self(batch, model="actor_old", input="obs_next").act 101 | noise = torch.randn(size=act_.shape, device=act_.device) * self._policy_noise 102 | if self._noise_clip > 0.0: 103 | noise = noise.clamp(-self._noise_clip, self._noise_clip) 104 | act_ += noise 105 | target_q = torch.min( 106 | self.critic1_old(batch.obs_next, act_), 107 | self.critic2_old(batch.obs_next, act_), 108 | ) 109 | return target_q 110 | 111 | def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: 112 | # critic 1&2 113 | td1, critic1_loss = self._mse_optimizer( 114 | batch, self.critic1, self.critic1_optim 115 | ) 116 | td2, critic2_loss = self._mse_optimizer( 117 | batch, self.critic2, self.critic2_optim 118 | ) 119 | batch.weight = (td1 + td2) / 2.0 # prio-buffer 120 | 121 | # actor 122 | if self._cnt % self._freq == 0: 123 | actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean() 124 | self.actor_optim.zero_grad() 125 | actor_loss.backward() 126 | self._last = actor_loss.item() 127 | self.actor_optim.step() 128 | self.sync_weight() 129 | self._cnt += 1 130 | return { 131 | "loss/actor": self._last, 132 | "loss/critic1": critic1_loss.item(), 133 | "loss/critic2": critic2_loss.item(), 134 | } 135 | --------------------------------------------------------------------------------