├── rainbow_dqn_result.png ├── data_train ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_0.npy ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_10.npy ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_100.npy ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0.npy ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10.npy ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100.npy ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy ├── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0.npy ├── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10.npy ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy └── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100.npy ├── runs └── DQN │ ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_0 │ └── events.out.tfevents.1658479209.DESKTOP-LMKC0MO.1228.0 │ ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_10 │ └── events.out.tfevents.1658479212.DESKTOP-LMKC0MO.10500.0 │ ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_100 │ └── events.out.tfevents.1658479214.DESKTOP-LMKC0MO.9512.0 │ ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0 │ └── events.out.tfevents.1658494481.DESKTOP-LMKC0MO.9316.0 │ ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0 │ └── events.out.tfevents.1658494473.DESKTOP-LMKC0MO.2144.0 │ ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10 │ └── events.out.tfevents.1658512436.DESKTOP-LMKC0MO.9316.1 │ ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100 │ └── events.out.tfevents.1658531515.DESKTOP-LMKC0MO.9316.2 │ ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0 │ └── events.out.tfevents.1658494475.DESKTOP-LMKC0MO.5976.0 │ ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10 │ └── events.out.tfevents.1658511489.DESKTOP-LMKC0MO.2144.1 │ ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100 │ └── events.out.tfevents.1658529336.DESKTOP-LMKC0MO.2144.2 │ ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0 │ └── events.out.tfevents.1658494471.DESKTOP-LMKC0MO.9964.0 │ ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10 │ └── events.out.tfevents.1658510515.DESKTOP-LMKC0MO.9964.1 │ ├── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0 │ └── events.out.tfevents.1658494478.DESKTOP-LMKC0MO.1408.0 │ ├── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10 │ └── events.out.tfevents.1658507126.DESKTOP-LMKC0MO.1408.1 │ ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10 │ └── events.out.tfevents.1658511615.DESKTOP-LMKC0MO.5976.1 │ ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100 │ └── events.out.tfevents.1658528978.DESKTOP-LMKC0MO.5976.2 │ ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100 │ └── events.out.tfevents.1658526626.DESKTOP-LMKC0MO.9964.2 │ └── 
DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100 │ └── events.out.tfevents.1658520541.DESKTOP-LMKC0MO.1408.2 ├── LICENSE ├── README.md ├── sum_tree.py ├── network.py ├── rainbow_dqn.py ├── Rainbow_DQN_main.py └── replay_buffer.py /rainbow_dqn_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/rainbow_dqn_result.png -------------------------------------------------------------------------------- /data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_0.npy -------------------------------------------------------------------------------- /data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_10.npy -------------------------------------------------------------------------------- /data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_100.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy 
-------------------------------------------------------------------------------- /data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy -------------------------------------------------------------------------------- /data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy -------------------------------------------------------------------------------- /data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy -------------------------------------------------------------------------------- /data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy -------------------------------------------------------------------------------- /data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100.npy -------------------------------------------------------------------------------- /runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658479209.DESKTOP-LMKC0MO.1228.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658479209.DESKTOP-LMKC0MO.1228.0 -------------------------------------------------------------------------------- /runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658479212.DESKTOP-LMKC0MO.10500.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658479212.DESKTOP-LMKC0MO.10500.0 -------------------------------------------------------------------------------- /runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658479214.DESKTOP-LMKC0MO.9512.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658479214.DESKTOP-LMKC0MO.9512.0 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494481.DESKTOP-LMKC0MO.9316.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494481.DESKTOP-LMKC0MO.9316.0 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494473.DESKTOP-LMKC0MO.2144.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494473.DESKTOP-LMKC0MO.2144.0 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658512436.DESKTOP-LMKC0MO.9316.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658512436.DESKTOP-LMKC0MO.9316.1 -------------------------------------------------------------------------------- 
/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658531515.DESKTOP-LMKC0MO.9316.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658531515.DESKTOP-LMKC0MO.9316.2 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494475.DESKTOP-LMKC0MO.5976.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494475.DESKTOP-LMKC0MO.5976.0 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658511489.DESKTOP-LMKC0MO.2144.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658511489.DESKTOP-LMKC0MO.2144.1 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658529336.DESKTOP-LMKC0MO.2144.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658529336.DESKTOP-LMKC0MO.2144.2 -------------------------------------------------------------------------------- /runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494471.DESKTOP-LMKC0MO.9964.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494471.DESKTOP-LMKC0MO.9964.0 -------------------------------------------------------------------------------- /runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658510515.DESKTOP-LMKC0MO.9964.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658510515.DESKTOP-LMKC0MO.9964.1 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494478.DESKTOP-LMKC0MO.1408.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494478.DESKTOP-LMKC0MO.1408.0 -------------------------------------------------------------------------------- 
/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658507126.DESKTOP-LMKC0MO.1408.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658507126.DESKTOP-LMKC0MO.1408.1 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658511615.DESKTOP-LMKC0MO.5976.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658511615.DESKTOP-LMKC0MO.5976.1 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658528978.DESKTOP-LMKC0MO.5976.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658528978.DESKTOP-LMKC0MO.5976.2 -------------------------------------------------------------------------------- /runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658526626.DESKTOP-LMKC0MO.9964.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658526626.DESKTOP-LMKC0MO.9964.2 -------------------------------------------------------------------------------- /runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658520541.DESKTOP-LMKC0MO.1408.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658520541.DESKTOP-LMKC0MO.1408.2 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Lizhi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rainbow DQN 2 | This is a concise PyTorch implementation of Rainbow DQN, combining Double Q-learning, a Dueling network, Noisy networks, PER (prioritized experience replay) and n-step Q-learning.
3 | 4 | ## Dependencies 5 | python==3.7.9
6 | numpy==1.19.4
7 | pytorch==1.5.0
8 | tensorboard==0.6.0
9 | gym==0.21.0
10 | 11 | ## How to use my code? 12 | You can directly run Rainbow_DQN_main.py in your own IDE, or launch it from a terminal with 'python Rainbow_DQN_main.py'.
13 | 14 | ### Training environments 15 | You can set 'env_index' in the code to switch between environments, as shown in the snippet below.
16 | env_index=0 corresponds to 'CartPole-v1'
17 | env_index=1 corresponds to 'LunarLander-v2'
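For reference, the environment and random seeds are selected at the bottom of Rainbow_DQN_main.py; the sketch below mirrors that block (switch env_index to 0 to train on CartPole-v1 instead):

```python
# Bottom of Rainbow_DQN_main.py: pick the environment by index
# and launch one training run per random seed.
env_names = ['CartPole-v1', 'LunarLander-v2']
env_index = 1  # 0 -> CartPole-v1, 1 -> LunarLander-v2
for seed in [0, 10, 100]:
    runner = Runner(args=args, env_name=env_names[env_index], number=1, seed=seed)
    runner.run()
```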
18 | 19 | ### How to see the training results? 20 | You can use TensorBoard to visualize the training curves, which are saved in the folder 'runs' (e.g. run 'tensorboard --logdir runs').
21 | The evaluation rewards are saved as NumPy arrays (.npy files) in the folder 'data_train'.
22 | The training curves are shown below.
23 | The right plot is smoothed by averaging over a window of 10 evaluations. The solid line and the shaded region represent the mean and standard deviation over three different random seeds (seed = 0, 10, 100); a minimal plotting sketch is given below.
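If you want to reproduce such a plot from the saved reward files yourself, the following is a minimal sketch (it assumes matplotlib is installed, which is not listed in the dependencies above; the file names follow the pattern used in 'data_train'):

```python
import numpy as np
import matplotlib.pyplot as plt

# Load the evaluation rewards of the three seeds used in this repository.
seeds = [0, 10, 100]
rewards = np.stack([
    np.load('./data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_{}.npy'.format(s))
    for s in seeds
])  # shape: (3, num_evaluations)

# Smooth each curve by averaging over a sliding window of 10 evaluations.
window = 10
kernel = np.ones(window) / window
smoothed = np.stack([np.convolve(r, kernel, mode='valid') for r in rewards])

# The policy is evaluated every 'evaluate_freq' (default 1e3) training steps.
steps = np.arange(smoothed.shape[1]) * 1e3
mean, std = smoothed.mean(axis=0), smoothed.std(axis=0)
plt.plot(steps, mean, label='Rainbow DQN')
plt.fill_between(steps, mean - std, mean + std, alpha=0.3)
plt.xlabel('training steps')
plt.ylabel('evaluation reward')
plt.legend()
plt.show()
```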
24 | ![image](https://github.com/Lizhi-sjtu/Rainbow-DQN-pytorch/blob/main/rainbow_dqn_result.png) 25 | 26 | ## References 27 | [1] Mnih V, Kavukcuoglu K, Silver D, et al. Human-level control through deep reinforcement learning[J]. Nature, 2015, 518(7540): 529-533.
28 | [2] Van Hasselt H, Guez A, Silver D. Deep reinforcement learning with double q-learning[C]//Proceedings of the AAAI conference on artificial intelligence. 2016, 30(1).
29 | [3] Wang Z, Schaul T, Hessel M, et al. Dueling network architectures for deep reinforcement learning[C]//International conference on machine learning. PMLR, 2016: 1995-2003.
30 | [4] Fortunato M, Azar M G, Piot B, et al. Noisy networks for exploration[J]. arXiv preprint arXiv:1706.10295, 2017.
31 | [5] Schaul T, Quan J, Antonoglou I, et al. Prioritized experience replay[J]. arXiv preprint arXiv:1511.05952, 2015.
32 | [6] Hessel M, Modayil J, Van Hasselt H, et al. Rainbow: Combining improvements in deep reinforcement learning[C]//Thirty-second AAAI conference on artificial intelligence. 2018.
33 | 34 | -------------------------------------------------------------------------------- /sum_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class SumTree(object): 5 | """ 6 | Store data with its priority in the tree. 7 | Tree structure and array storage: 8 | 9 | Tree index: 10 | 0 -> storing priority sum 11 | / \ 12 | 1 2 13 | / \ / \ 14 | 3 4 5 6 -> storing priority for transitions 15 | 16 | Array type for storing: 17 | [0,1,2,3,4,5,6] 18 | """ 19 | 20 | def __init__(self, buffer_capacity): 21 | self.buffer_capacity = buffer_capacity # capacity of the replay buffer 22 | self.tree_capacity = 2 * buffer_capacity - 1 # capacity of the sum tree 23 | self.tree = np.zeros(self.tree_capacity) 24 | 25 | def update(self, data_index, priority): 26 | # data_index is the index of the current data in the buffer 27 | # tree_index is the index of the current data in the sum_tree 28 | tree_index = data_index + self.buffer_capacity - 1 # convert the buffer index of the current data into its sum_tree index 29 | change = priority - self.tree[tree_index] # change in the priority of the current data 30 | self.tree[tree_index] = priority # update the priority stored in the leaf node of the last tree layer 31 | # then propagate the change through the tree 32 | while tree_index != 0: # update the priorities of the parent nodes, propagating all the way up to the root 33 | tree_index = (tree_index - 1) // 2 34 | self.tree[tree_index] += change 35 | 36 | def get_index(self, v): 37 | parent_idx = 0 # start from the root of the tree 38 | while True: 39 | child_left_idx = 2 * parent_idx + 1 # indices of the left and right children of the parent node 40 | child_right_idx = child_left_idx + 1 41 | if child_left_idx >= self.tree_capacity: # reach bottom, end search 42 | tree_index = parent_idx # tree_index is the sum_tree index of the sampled data 43 | break 44 | else: # downward search, always search for a higher priority node 45 | if v <= self.tree[child_left_idx]: 46 | parent_idx = child_left_idx 47 | else: 48 | v -= self.tree[child_left_idx] 49 | parent_idx = child_right_idx 50 | 51 | data_index = tree_index - self.buffer_capacity + 1 # tree_index->data_index 52 | return data_index, self.tree[tree_index] # return the buffer index of the sampled data and its corresponding priority 53 | 54 | def get_batch_index(self, current_size, batch_size, beta): 55 | batch_index = np.zeros(batch_size, dtype=np.long) 56 | IS_weight = torch.zeros(batch_size, dtype=torch.float32) 57 | segment = self.priority_sum / batch_size # split [0, priority_sum] into batch_size equal segments and sample one value uniformly from each segment 58 | for i in range(batch_size): 59 | a = segment * i 60 | b = segment * (i + 1) 61 | v = np.random.uniform(a, b) 62 | index, priority = self.get_index(v) 63 | batch_index[i] = index 64 | prob = priority / self.priority_sum # probability of sampling the current data 65 | IS_weight[i] = (current_size * prob) ** (-beta) 66 | IS_weight /= IS_weight.max() # normalization 67 | 68 | return batch_index, IS_weight 69 | 70 | @property 71 | def priority_sum(self): 72 | return self.tree[0] # the root of the tree stores the sum of all priorities 73 | 74 | @property 75 | def priority_max(self): 76 | return self.tree[self.buffer_capacity - 1:].max() # the leaf nodes in the last layer store the priority of each data item 77 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class Dueling_Net(nn.Module): 8 | def __init__(self, args): 9 | super(Dueling_Net, self).__init__() 10 | self.fc1 = nn.Linear(args.state_dim, args.hidden_dim) 11 | self.fc2 = nn.Linear(args.hidden_dim, args.hidden_dim) 12 | if args.use_noisy: 13 | self.V = NoisyLinear(args.hidden_dim, 1) 14 | self.A = NoisyLinear(args.hidden_dim, args.action_dim) 15 |
else: 16 | self.V = nn.Linear(args.hidden_dim, 1) 17 | self.A = nn.Linear(args.hidden_dim, args.action_dim) 18 | 19 | def forward(self, s): 20 | s = torch.relu(self.fc1(s)) 21 | s = torch.relu(self.fc2(s)) 22 | V = self.V(s) # batch_size X 1 23 | A = self.A(s) # batch_size X action_dim 24 | Q = V + (A - torch.mean(A, dim=-1, keepdim=True)) # Q(s,a)=V(s)+A(s,a)-mean(A(s,a)) 25 | return Q 26 | 27 | 28 | class Net(nn.Module): 29 | def __init__(self, args): 30 | super(Net, self).__init__() 31 | self.fc1 = nn.Linear(args.state_dim, args.hidden_dim) 32 | self.fc2 = nn.Linear(args.hidden_dim, args.hidden_dim) 33 | if args.use_noisy: 34 | self.fc3 = NoisyLinear(args.hidden_dim, args.action_dim) 35 | else: 36 | self.fc3 = nn.Linear(args.hidden_dim, args.action_dim) 37 | 38 | def forward(self, s): 39 | s = torch.relu(self.fc1(s)) 40 | s = torch.relu(self.fc2(s)) 41 | Q = self.fc3(s) 42 | return Q 43 | 44 | 45 | class NoisyLinear(nn.Module): 46 | def __init__(self, in_features, out_features, sigma_init=0.5): 47 | super(NoisyLinear, self).__init__() 48 | self.in_features = in_features 49 | self.out_features = out_features 50 | self.sigma_init = sigma_init 51 | 52 | self.weight_mu = nn.Parameter(torch.FloatTensor(out_features, in_features)) 53 | self.weight_sigma = nn.Parameter(torch.FloatTensor(out_features, in_features)) 54 | self.register_buffer('weight_epsilon', torch.FloatTensor(out_features, in_features)) 55 | 56 | self.bias_mu = nn.Parameter(torch.FloatTensor(out_features)) 57 | self.bias_sigma = nn.Parameter(torch.FloatTensor(out_features)) 58 | self.register_buffer('bias_epsilon', torch.FloatTensor(out_features)) 59 | 60 | self.reset_parameters() 61 | self.reset_noise() 62 | 63 | def forward(self, x): 64 | if self.training: 65 | self.reset_noise() 66 | weight = self.weight_mu + self.weight_sigma.mul(self.weight_epsilon) # mul is element-wise multiplication 67 | bias = self.bias_mu + self.bias_sigma.mul(self.bias_epsilon) 68 | 69 | else: 70 | weight = self.weight_mu 71 | bias = self.bias_mu 72 | 73 | return F.linear(x, weight, bias) 74 | 75 | def reset_parameters(self): 76 | mu_range = 1 / math.sqrt(self.in_features) 77 | self.weight_mu.data.uniform_(-mu_range, mu_range) 78 | self.bias_mu.data.uniform_(-mu_range, mu_range) 79 | 80 | self.weight_sigma.data.fill_(self.sigma_init / math.sqrt(self.in_features)) 81 | self.bias_sigma.data.fill_(self.sigma_init / math.sqrt(self.out_features)) # note: divide by out_features here 82 | 83 | def reset_noise(self): 84 | epsilon_i = self.scale_noise(self.in_features) 85 | epsilon_j = self.scale_noise(self.out_features) 86 | self.weight_epsilon.copy_(torch.ger(epsilon_j, epsilon_i)) 87 | self.bias_epsilon.copy_(epsilon_j) 88 | 89 | def scale_noise(self, size): 90 | x = torch.randn(size) # torch.randn samples from a standard Gaussian distribution 91 | x = x.sign().mul(x.abs().sqrt()) 92 | return x 93 | -------------------------------------------------------------------------------- /rainbow_dqn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import copy 4 | from network import Dueling_Net, Net 5 | 6 | 7 | class DQN(object): 8 | def __init__(self, args): 9 | self.action_dim = args.action_dim 10 | self.batch_size = args.batch_size # batch size 11 | self.max_train_steps = args.max_train_steps 12 | self.lr = args.lr # learning rate 13 | self.gamma = args.gamma # discount factor 14 | self.tau = args.tau # Soft update 15 | self.use_soft_update = args.use_soft_update 16 | self.target_update_freq = args.target_update_freq # hard update 17 | self.update_count = 0
18 | 19 | self.grad_clip = args.grad_clip 20 | self.use_lr_decay = args.use_lr_decay 21 | self.use_double = args.use_double 22 | self.use_dueling = args.use_dueling 23 | self.use_per = args.use_per 24 | self.use_n_steps = args.use_n_steps 25 | if self.use_n_steps: 26 | self.gamma = self.gamma ** args.n_steps 27 | 28 | if self.use_dueling: # Whether to use the 'dueling network' 29 | self.net = Dueling_Net(args) 30 | else: 31 | self.net = Net(args) 32 | 33 | self.target_net = copy.deepcopy(self.net) # Copy the online_net to the target_net 34 | 35 | self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr) 36 | 37 | def choose_action(self, state, epsilon): 38 | with torch.no_grad(): 39 | state = torch.unsqueeze(torch.tensor(state, dtype=torch.float), 0) 40 | q = self.net(state) 41 | if np.random.uniform() > epsilon: 42 | action = q.argmax(dim=-1).item() 43 | else: 44 | action = np.random.randint(0, self.action_dim) 45 | return action 46 | 47 | def learn(self, replay_buffer, total_steps): 48 | batch, batch_index, IS_weight = replay_buffer.sample(total_steps) 49 | 50 | with torch.no_grad(): # q_target has no gradient 51 | if self.use_double: # Whether to use the 'double q-learning' 52 | # Use online_net to select the action 53 | a_argmax = self.net(batch['next_state']).argmax(dim=-1, keepdim=True) # shape:(batch_size,1) 54 | # Use target_net to estimate the q_target 55 | q_target = batch['reward'] + self.gamma * (1 - batch['terminal']) * self.target_net(batch['next_state']).gather(-1, a_argmax).squeeze(-1) # shape:(batch_size,) 56 | else: 57 | q_target = batch['reward'] + self.gamma * (1 - batch['terminal']) * self.target_net(batch['next_state']).max(dim=-1)[0] # shape:(batch_size,) 58 | 59 | q_current = self.net(batch['state']).gather(-1, batch['action']).squeeze(-1) # shape:(batch_size,) 60 | td_errors = q_current - q_target # shape:(batch_size,) 61 | 62 | if self.use_per: 63 | loss = (IS_weight * (td_errors ** 2)).mean() 64 | replay_buffer.update_batch_priorities(batch_index, td_errors.detach().numpy()) 65 | else: 66 | loss = (td_errors ** 2).mean() 67 | 68 | self.optimizer.zero_grad() 69 | loss.backward() 70 | if self.grad_clip: 71 | torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.grad_clip) 72 | self.optimizer.step() 73 | 74 | if self.use_soft_update: # soft update 75 | for param, target_param in zip(self.net.parameters(), self.target_net.parameters()): 76 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 77 | else: # hard update 78 | self.update_count += 1 79 | if self.update_count % self.target_update_freq == 0: 80 | self.target_net.load_state_dict(self.net.state_dict()) 81 | 82 | if self.use_lr_decay: # learning rate Decay 83 | self.lr_decay(total_steps) 84 | 85 | def lr_decay(self, total_steps): 86 | lr_now = 0.9 * self.lr * (1 - total_steps / self.max_train_steps) + 0.1 * self.lr 87 | for p in self.optimizer.param_groups: 88 | p['lr'] = lr_now 89 | -------------------------------------------------------------------------------- /Rainbow_DQN_main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import gym 4 | from torch.utils.tensorboard import SummaryWriter 5 | from replay_buffer import * 6 | from rainbow_dqn import DQN 7 | import argparse 8 | 9 | 10 | class Runner: 11 | def __init__(self, args, env_name, number, seed): 12 | self.args = args 13 | self.env_name = env_name 14 | self.number = number 15 | self.seed = seed 16 | 17 | self.env = 
gym.make(env_name) 18 | self.env_evaluate = gym.make(env_name) # When evaluating the policy, we need to rebuild an environment 19 | self.env.seed(seed) 20 | self.env.action_space.seed(seed) 21 | self.env_evaluate.seed(seed) 22 | self.env_evaluate.action_space.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | 26 | self.args.state_dim = self.env.observation_space.shape[0] 27 | self.args.action_dim = self.env.action_space.n 28 | self.args.episode_limit = self.env._max_episode_steps # Maximum number of steps per episode 29 | print("env={}".format(self.env_name)) 30 | print("state_dim={}".format(self.args.state_dim)) 31 | print("action_dim={}".format(self.args.action_dim)) 32 | print("episode_limit={}".format(self.args.episode_limit)) 33 | 34 | if args.use_per and args.use_n_steps: 35 | self.replay_buffer = N_Steps_Prioritized_ReplayBuffer(args) 36 | elif args.use_per: 37 | self.replay_buffer = Prioritized_ReplayBuffer(args) 38 | elif args.use_n_steps: 39 | self.replay_buffer = N_Steps_ReplayBuffer(args) 40 | else: 41 | self.replay_buffer = ReplayBuffer(args) 42 | self.agent = DQN(args) 43 | 44 | self.algorithm = 'DQN' 45 | if args.use_double and args.use_dueling and args.use_noisy and args.use_per and args.use_n_steps: 46 | self.algorithm = 'Rainbow_' + self.algorithm 47 | else: 48 | if args.use_double: 49 | self.algorithm += '_Double' 50 | if args.use_dueling: 51 | self.algorithm += '_Dueling' 52 | if args.use_noisy: 53 | self.algorithm += '_Noisy' 54 | if args.use_per: 55 | self.algorithm += '_PER' 56 | if args.use_n_steps: 57 | self.algorithm += "_N_steps" 58 | 59 | self.writer = SummaryWriter(log_dir='runs/DQN/{}_env_{}_number_{}_seed_{}'.format(self.algorithm, env_name, number, seed)) 60 | 61 | self.evaluate_num = 0 # Record the number of evaluations 62 | self.evaluate_rewards = [] # Record the rewards during evaluation 63 | self.total_steps = 0 # Record the total steps during the training 64 | if args.use_noisy: # if the Noisy network is used, the epsilon-greedy strategy is not needed 65 | self.epsilon = 0 66 | else: 67 | self.epsilon = self.args.epsilon_init 68 | self.epsilon_min = self.args.epsilon_min 69 | self.epsilon_decay = (self.args.epsilon_init - self.args.epsilon_min) / self.args.epsilon_decay_steps 70 | 71 | def run(self, ): 72 | self.evaluate_policy() 73 | while self.total_steps < self.args.max_train_steps: 74 | state = self.env.reset() 75 | done = False 76 | episode_steps = 0 77 | while not done: 78 | action = self.agent.choose_action(state, epsilon=self.epsilon) 79 | next_state, reward, done, _ = self.env.step(action) 80 | episode_steps += 1 81 | self.total_steps += 1 82 | 83 | if not self.args.use_noisy: # Decay epsilon 84 | self.epsilon = self.epsilon - self.epsilon_decay if self.epsilon - self.epsilon_decay > self.epsilon_min else self.epsilon_min 85 | 86 | # When the agent dies, wins, or reaches max_episode_steps, done will be True; we need to distinguish these cases; 87 | # terminal means dead or win, so there is no next state s'; 88 | # but when max_episode_steps is reached, there actually is a next state s'.
89 | if done and episode_steps != self.args.episode_limit: 90 | if self.env_name == 'LunarLander-v2': 91 | if reward <= -100: reward = -1 # good for LunarLander 92 | terminal = True 93 | else: 94 | terminal = False 95 | 96 | self.replay_buffer.store_transition(state, action, reward, next_state, terminal, done) # Store the transition 97 | state = next_state 98 | 99 | if self.replay_buffer.current_size >= self.args.batch_size: 100 | self.agent.learn(self.replay_buffer, self.total_steps) 101 | 102 | if self.total_steps % self.args.evaluate_freq == 0: 103 | self.evaluate_policy() 104 | # Save reward 105 | np.save('./data_train/{}_env_{}_number_{}_seed_{}.npy'.format(self.algorithm, self.env_name, self.number, self.seed), np.array(self.evaluate_rewards)) 106 | 107 | def evaluate_policy(self, ): 108 | evaluate_reward = 0 109 | self.agent.net.eval() 110 | for _ in range(self.args.evaluate_times): 111 | state = self.env_evaluate.reset() 112 | done = False 113 | episode_reward = 0 114 | while not done: 115 | action = self.agent.choose_action(state, epsilon=0) 116 | next_state, reward, done, _ = self.env_evaluate.step(action) 117 | episode_reward += reward 118 | state = next_state 119 | evaluate_reward += episode_reward 120 | self.agent.net.train() 121 | evaluate_reward /= self.args.evaluate_times 122 | self.evaluate_rewards.append(evaluate_reward) 123 | print("total_steps:{} \t evaluate_reward:{} \t epsilon:{}".format(self.total_steps, evaluate_reward, self.epsilon)) 124 | self.writer.add_scalar('step_rewards_{}'.format(self.env_name), evaluate_reward, global_step=self.total_steps) 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser("Hyperparameter Setting for DQN") 129 | parser.add_argument("--max_train_steps", type=int, default=int(4e5), help="Maximum number of training steps") 130 | parser.add_argument("--evaluate_freq", type=float, default=1e3, help="Evaluate the policy every 'evaluate_freq' steps") 131 | parser.add_argument("--evaluate_times", type=float, default=3, help="Number of episodes per evaluation") 132 | 133 | parser.add_argument("--buffer_capacity", type=int, default=int(1e5), help="The maximum replay-buffer capacity") 134 | parser.add_argument("--batch_size", type=int, default=256, help="Batch size") 135 | parser.add_argument("--hidden_dim", type=int, default=256, help="The number of neurons in hidden layers of the neural network") 136 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") 137 | parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor") 138 | parser.add_argument("--epsilon_init", type=float, default=0.5, help="Initial epsilon") 139 | parser.add_argument("--epsilon_min", type=float, default=0.1, help="Minimum epsilon") 140 | parser.add_argument("--epsilon_decay_steps", type=int, default=int(1e5), help="How many steps before the epsilon decays to the minimum") 141 | parser.add_argument("--tau", type=float, default=0.005, help="Soft-update coefficient for the target network") 142 | parser.add_argument("--use_soft_update", type=bool, default=True, help="Whether to use soft update") 143 | parser.add_argument("--target_update_freq", type=int, default=200, help="Update frequency of the target network (hard update)") 144 | parser.add_argument("--n_steps", type=int, default=5, help="Number of steps for n-step Q-learning") 145 | parser.add_argument("--alpha", type=float, default=0.6, help="PER priority exponent (alpha)") 146 | parser.add_argument("--beta_init", type=float, default=0.4, help="Importance-sampling exponent (beta) in PER") 147 |
parser.add_argument("--use_lr_decay", type=bool, default=True, help="Learning rate Decay") 148 | parser.add_argument("--grad_clip", type=float, default=10.0, help="Gradient clip") 149 | 150 | parser.add_argument("--use_double", type=bool, default=True, help="Whether to use double Q-learning") 151 | parser.add_argument("--use_dueling", type=bool, default=True, help="Whether to use dueling network") 152 | parser.add_argument("--use_noisy", type=bool, default=True, help="Whether to use noisy network") 153 | parser.add_argument("--use_per", type=bool, default=True, help="Whether to use PER") 154 | parser.add_argument("--use_n_steps", type=bool, default=True, help="Whether to use n_steps Q-learning") 155 | 156 | args = parser.parse_args() 157 | 158 | env_names = ['CartPole-v1', 'LunarLander-v2'] 159 | env_index = 1 160 | for seed in [0, 10, 100]: 161 | runner = Runner(args=args, env_name=env_names[env_index], number=1, seed=seed) 162 | runner.run() 163 | -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from collections import deque 4 | from sum_tree import SumTree 5 | 6 | 7 | class ReplayBuffer(object): 8 | def __init__(self, args): 9 | self.batch_size = args.batch_size 10 | self.buffer_capacity = args.buffer_capacity 11 | self.current_size = 0 12 | self.count = 0 13 | self.buffer = {'state': np.zeros((self.buffer_capacity, args.state_dim)), 14 | 'action': np.zeros((self.buffer_capacity, 1)), 15 | 'reward': np.zeros(self.buffer_capacity), 16 | 'next_state': np.zeros((self.buffer_capacity, args.state_dim)), 17 | 'terminal': np.zeros(self.buffer_capacity), 18 | } 19 | 20 | def store_transition(self, state, action, reward, next_state, terminal, done): 21 | self.buffer['state'][self.count] = state 22 | self.buffer['action'][self.count] = action 23 | self.buffer['reward'][self.count] = reward 24 | self.buffer['next_state'][self.count] = next_state 25 | self.buffer['terminal'][self.count] = terminal 26 | self.count = (self.count + 1) % self.buffer_capacity # When the 'count' reaches buffer_capacity, it will be reset to 0. 
27 | self.current_size = min(self.current_size + 1, self.buffer_capacity) 28 | 29 | def sample(self, total_steps): 30 | index = np.random.randint(0, self.current_size, size=self.batch_size) 31 | batch = {} 32 | for key in self.buffer.keys(): # numpy->tensor 33 | if key == 'action': 34 | batch[key] = torch.tensor(self.buffer[key][index], dtype=torch.long) 35 | else: 36 | batch[key] = torch.tensor(self.buffer[key][index], dtype=torch.float32) 37 | 38 | return batch, None, None 39 | 40 | 41 | class N_Steps_ReplayBuffer(object): 42 | def __init__(self, args): 43 | self.gamma = args.gamma 44 | self.batch_size = args.batch_size 45 | self.buffer_capacity = args.buffer_capacity 46 | self.current_size = 0 47 | self.count = 0 48 | self.n_steps = args.n_steps 49 | self.n_steps_deque = deque(maxlen=self.n_steps) 50 | self.buffer = {'state': np.zeros((self.buffer_capacity, args.state_dim)), 51 | 'action': np.zeros((self.buffer_capacity, 1)), 52 | 'reward': np.zeros(self.buffer_capacity), 53 | 'next_state': np.zeros((self.buffer_capacity, args.state_dim)), 54 | 'terminal': np.zeros(self.buffer_capacity), 55 | } 56 | 57 | def store_transition(self, state, action, reward, next_state, terminal, done): 58 | transition = (state, action, reward, next_state, terminal, done) 59 | self.n_steps_deque.append(transition) 60 | if len(self.n_steps_deque) == self.n_steps: 61 | state, action, n_steps_reward, next_state, terminal = self.get_n_steps_transition() 62 | self.buffer['state'][self.count] = state 63 | self.buffer['action'][self.count] = action 64 | self.buffer['reward'][self.count] = n_steps_reward 65 | self.buffer['next_state'][self.count] = next_state 66 | self.buffer['terminal'][self.count] = terminal 67 | self.count = (self.count + 1) % self.buffer_capacity # When the 'count' reaches buffer_capacity, it will be reset to 0. 
68 | self.current_size = min(self.current_size + 1, self.buffer_capacity) 69 | 70 | def get_n_steps_transition(self): 71 | state, action = self.n_steps_deque[0][:2] 72 | next_state, terminal = self.n_steps_deque[-1][3:5] 73 | n_steps_reward = 0 74 | for i in reversed(range(self.n_steps)): 75 | r, s_, ter, d = self.n_steps_deque[i][2:] 76 | n_steps_reward = r + self.gamma * (1 - d) * n_steps_reward 77 | if d: 78 | next_state, terminal = s_, ter 79 | 80 | return state, action, n_steps_reward, next_state, terminal 81 | 82 | def sample(self, total_steps): 83 | index = np.random.randint(0, self.current_size, size=self.batch_size) 84 | batch = {} 85 | for key in self.buffer.keys(): # numpy->tensor 86 | if key == 'action': 87 | batch[key] = torch.tensor(self.buffer[key][index], dtype=torch.long) 88 | else: 89 | batch[key] = torch.tensor(self.buffer[key][index], dtype=torch.float32) 90 | 91 | return batch, None, None 92 | 93 | 94 | class Prioritized_ReplayBuffer(object): 95 | def __init__(self, args): 96 | self.max_train_steps = args.max_train_steps 97 | self.alpha = args.alpha 98 | self.beta_init = args.beta_init 99 | self.beta = args.beta_init 100 | self.batch_size = args.batch_size 101 | self.buffer_capacity = args.buffer_capacity 102 | self.sum_tree = SumTree(self.buffer_capacity) 103 | self.current_size = 0 104 | self.count = 0 105 | self.buffer = {'state': np.zeros((self.buffer_capacity, args.state_dim)), 106 | 'action': np.zeros((self.buffer_capacity, 1)), 107 | 'reward': np.zeros(self.buffer_capacity), 108 | 'next_state': np.zeros((self.buffer_capacity, args.state_dim)), 109 | 'terminal': np.zeros(self.buffer_capacity), 110 | } 111 | 112 | def store_transition(self, state, action, reward, next_state, terminal, done): 113 | self.buffer['state'][self.count] = state 114 | self.buffer['action'][self.count] = action 115 | self.buffer['reward'][self.count] = reward 116 | self.buffer['next_state'][self.count] = next_state 117 | self.buffer['terminal'][self.count] = terminal 118 | # For the first experience, initialize its priority to 1.0; otherwise, assign each newly stored experience the current maximum priority 119 | priority = 1.0 if self.current_size == 0 else self.sum_tree.priority_max 120 | self.sum_tree.update(data_index=self.count, priority=priority) # update the priority of the current experience in the sum_tree 121 | self.count = (self.count + 1) % self.buffer_capacity # When the 'count' reaches buffer_capacity, it will be reset to 0.
122 | self.current_size = min(self.current_size + 1, self.buffer_capacity) 123 | 124 | def sample(self, total_steps): 125 | batch_index, IS_weight = self.sum_tree.get_batch_index(current_size=self.current_size, batch_size=self.batch_size, beta=self.beta) 126 | self.beta = self.beta_init + (1 - self.beta_init) * (total_steps / self.max_train_steps) # beta:beta_init->1.0 127 | batch = {} 128 | for key in self.buffer.keys(): # numpy->tensor 129 | if key == 'action': 130 | batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.long) 131 | else: 132 | batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.float32) 133 | 134 | return batch, batch_index, IS_weight 135 | 136 | def update_batch_priorities(self, batch_index, td_errors): # update the priorities of the data at batch_index according to the given td_errors 137 | priorities = (np.abs(td_errors) + 0.01) ** self.alpha 138 | for index, priority in zip(batch_index, priorities): 139 | self.sum_tree.update(data_index=index, priority=priority) 140 | 141 | 142 | class N_Steps_Prioritized_ReplayBuffer(object): 143 | def __init__(self, args): 144 | self.max_train_steps = args.max_train_steps 145 | self.alpha = args.alpha 146 | self.beta_init = args.beta_init 147 | self.beta = args.beta_init 148 | self.gamma = args.gamma 149 | self.batch_size = args.batch_size 150 | self.buffer_capacity = args.buffer_capacity 151 | self.sum_tree = SumTree(self.buffer_capacity) 152 | self.n_steps = args.n_steps 153 | self.n_steps_deque = deque(maxlen=self.n_steps) 154 | self.buffer = {'state': np.zeros((self.buffer_capacity, args.state_dim)), 155 | 'action': np.zeros((self.buffer_capacity, 1)), 156 | 'reward': np.zeros(self.buffer_capacity), 157 | 'next_state': np.zeros((self.buffer_capacity, args.state_dim)), 158 | 'terminal': np.zeros(self.buffer_capacity), 159 | } 160 | self.current_size = 0 161 | self.count = 0 162 | 163 | def store_transition(self, state, action, reward, next_state, terminal, done): 164 | transition = (state, action, reward, next_state, terminal, done) 165 | self.n_steps_deque.append(transition) 166 | if len(self.n_steps_deque) == self.n_steps: 167 | state, action, n_steps_reward, next_state, terminal = self.get_n_steps_transition() 168 | self.buffer['state'][self.count] = state 169 | self.buffer['action'][self.count] = action 170 | self.buffer['reward'][self.count] = n_steps_reward 171 | self.buffer['next_state'][self.count] = next_state 172 | self.buffer['terminal'][self.count] = terminal 173 | # For the first experience in the buffer, assign priority 1.0; otherwise, assign each newly stored experience the current maximum priority 174 | priority = 1.0 if self.current_size == 0 else self.sum_tree.priority_max 175 | self.sum_tree.update(data_index=self.count, priority=priority) # update the priority of the current experience in the sum_tree 176 | self.count = (self.count + 1) % self.buffer_capacity # When 'count' reaches buffer_capacity, it will be reset to 0.
177 | self.current_size = min(self.current_size + 1, self.buffer_capacity) 178 | 179 | def sample(self, total_steps): 180 | batch_index, IS_weight = self.sum_tree.get_batch_index(current_size=self.current_size, batch_size=self.batch_size, beta=self.beta) 181 | self.beta = self.beta_init + (1 - self.beta_init) * (total_steps / self.max_train_steps) # beta:beta_init->1.0 182 | batch = {} 183 | for key in self.buffer.keys(): # numpy->tensor 184 | if key == 'action': 185 | batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.long) 186 | else: 187 | batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.float32) 188 | 189 | return batch, batch_index, IS_weight 190 | 191 | def get_n_steps_transition(self): 192 | state, action = self.n_steps_deque[0][:2] # take s and a from the first transition in the deque 193 | next_state, terminal = self.n_steps_deque[-1][3:5] # take s' and terminal from the last transition in the deque 194 | n_steps_reward = 0 195 | for i in reversed(range(self.n_steps)): # compute n_steps_reward in reverse order 196 | r, s_, ter, d = self.n_steps_deque[i][2:] 197 | n_steps_reward = r + self.gamma * (1 - d) * n_steps_reward 198 | if d: # if done=True, an episode has ended; use this transition's s' and terminal as the next_state and terminal of the n-step transition 199 | next_state, terminal = s_, ter 200 | 201 | return state, action, n_steps_reward, next_state, terminal 202 | 203 | def update_batch_priorities(self, batch_index, td_errors): # update the priorities of the data at batch_index according to the given td_errors 204 | priorities = (np.abs(td_errors) + 0.01) ** self.alpha 205 | for index, priority in zip(batch_index, priorities): 206 | self.sum_tree.update(data_index=index, priority=priority) 207 | --------------------------------------------------------------------------------
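As a quick sanity check of how sum_tree.py and the prioritized buffers above fit together, here is a tiny hypothetical usage sketch (not part of the repository):

```python
from sum_tree import SumTree

# Build a tree for a buffer of 4 transitions and assign priorities 1..4.
tree = SumTree(buffer_capacity=4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree.update(data_index=i, priority=p)

print(tree.priority_sum)    # 10.0 -> the root holds the total priority mass
print(tree.priority_max)    # 4.0  -> used when inserting new transitions
print(tree.get_index(6.5))  # (3, 4.0): 6.5 falls into the segment of data_index 3
```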