├── rainbow_dqn_result.png
├── data_train
│   ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_0.npy
│   ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_10.npy
│   ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_100.npy
│   ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0.npy
│   ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy
│   ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10.npy
│   ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100.npy
│   ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy
│   ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy
│   ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy
│   ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy
│   ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy
│   ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy
│   ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy
│   ├── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0.npy
│   ├── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10.npy
│   ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy
│   └── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100.npy
├── runs
│   └── DQN
│       ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_0
│       │   └── events.out.tfevents.1658479209.DESKTOP-LMKC0MO.1228.0
│       ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_10
│       │   └── events.out.tfevents.1658479212.DESKTOP-LMKC0MO.10500.0
│       ├── Rainbow_DQN_env_LunarLander-v2_number_1_seed_100
│       │   └── events.out.tfevents.1658479214.DESKTOP-LMKC0MO.9512.0
│       ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0
│       │   └── events.out.tfevents.1658494481.DESKTOP-LMKC0MO.9316.0
│       ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0
│       │   └── events.out.tfevents.1658494473.DESKTOP-LMKC0MO.2144.0
│       ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10
│       │   └── events.out.tfevents.1658512436.DESKTOP-LMKC0MO.9316.1
│       ├── DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100
│       │   └── events.out.tfevents.1658531515.DESKTOP-LMKC0MO.9316.2
│       ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0
│       │   └── events.out.tfevents.1658494475.DESKTOP-LMKC0MO.5976.0
│       ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10
│       │   └── events.out.tfevents.1658511489.DESKTOP-LMKC0MO.2144.1
│       ├── DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100
│       │   └── events.out.tfevents.1658529336.DESKTOP-LMKC0MO.2144.2
│       ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0
│       │   └── events.out.tfevents.1658494471.DESKTOP-LMKC0MO.9964.0
│       ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10
│       │   └── events.out.tfevents.1658510515.DESKTOP-LMKC0MO.9964.1
│       ├── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0
│       │   └── events.out.tfevents.1658494478.DESKTOP-LMKC0MO.1408.0
│       ├── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10
│       │   └── events.out.tfevents.1658507126.DESKTOP-LMKC0MO.1408.1
│       ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10
│       │   └── events.out.tfevents.1658511615.DESKTOP-LMKC0MO.5976.1
│       ├── DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100
│       │   └── events.out.tfevents.1658528978.DESKTOP-LMKC0MO.5976.2
│       ├── DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100
│       │   └── events.out.tfevents.1658526626.DESKTOP-LMKC0MO.9964.2
│       └── DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100
│           └── events.out.tfevents.1658520541.DESKTOP-LMKC0MO.1408.2
├── LICENSE
├── README.md
├── sum_tree.py
├── network.py
├── rainbow_dqn.py
├── Rainbow_DQN_main.py
└── replay_buffer.py
/rainbow_dqn_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/rainbow_dqn_result.png
--------------------------------------------------------------------------------
/data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_0.npy
--------------------------------------------------------------------------------
/data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_10.npy
--------------------------------------------------------------------------------
/data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_100.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_100.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy
--------------------------------------------------------------------------------
/data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0.npy
--------------------------------------------------------------------------------
/data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10.npy
--------------------------------------------------------------------------------
/data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100.npy
--------------------------------------------------------------------------------
/data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/data_train/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100.npy
--------------------------------------------------------------------------------
/runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658479209.DESKTOP-LMKC0MO.1228.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658479209.DESKTOP-LMKC0MO.1228.0
--------------------------------------------------------------------------------
/runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658479212.DESKTOP-LMKC0MO.10500.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658479212.DESKTOP-LMKC0MO.10500.0
--------------------------------------------------------------------------------
/runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658479214.DESKTOP-LMKC0MO.9512.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/Rainbow_DQN_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658479214.DESKTOP-LMKC0MO.9512.0
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494481.DESKTOP-LMKC0MO.9316.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494481.DESKTOP-LMKC0MO.9316.0
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494473.DESKTOP-LMKC0MO.2144.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494473.DESKTOP-LMKC0MO.2144.0
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658512436.DESKTOP-LMKC0MO.9316.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658512436.DESKTOP-LMKC0MO.9316.1
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658531515.DESKTOP-LMKC0MO.9316.2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_PER_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658531515.DESKTOP-LMKC0MO.9316.2
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494475.DESKTOP-LMKC0MO.5976.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494475.DESKTOP-LMKC0MO.5976.0
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658511489.DESKTOP-LMKC0MO.2144.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658511489.DESKTOP-LMKC0MO.2144.1
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658529336.DESKTOP-LMKC0MO.2144.2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658529336.DESKTOP-LMKC0MO.2144.2
--------------------------------------------------------------------------------
/runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494471.DESKTOP-LMKC0MO.9964.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494471.DESKTOP-LMKC0MO.9964.0
--------------------------------------------------------------------------------
/runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658510515.DESKTOP-LMKC0MO.9964.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658510515.DESKTOP-LMKC0MO.9964.1
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494478.DESKTOP-LMKC0MO.1408.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_0/events.out.tfevents.1658494478.DESKTOP-LMKC0MO.1408.0
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658507126.DESKTOP-LMKC0MO.1408.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658507126.DESKTOP-LMKC0MO.1408.1
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658511615.DESKTOP-LMKC0MO.5976.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_10/events.out.tfevents.1658511615.DESKTOP-LMKC0MO.5976.1
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658528978.DESKTOP-LMKC0MO.5976.2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658528978.DESKTOP-LMKC0MO.5976.2
--------------------------------------------------------------------------------
/runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658526626.DESKTOP-LMKC0MO.9964.2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Dueling_Noisy_PER_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658526626.DESKTOP-LMKC0MO.9964.2
--------------------------------------------------------------------------------
/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658520541.DESKTOP-LMKC0MO.1408.2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lizhi-sjtu/Rainbow-DQN-pytorch/HEAD/runs/DQN/DQN_Double_Dueling_Noisy_N_steps_env_LunarLander-v2_number_1_seed_100/events.out.tfevents.1658520541.DESKTOP-LMKC0MO.1408.2
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Lizhi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Rainbow DQN
2 | This is a concise PyTorch implementation of Rainbow DQN, including Double Q-learning, Dueling networks, Noisy networks, PER (prioritized experience replay) and n-step Q-learning.
3 |
4 | ## Dependencies
5 | python==3.7.9
6 | numpy==1.19.4
7 | pytorch==1.5.0
8 | tensorboard==0.6.0
9 | gym==0.21.0
10 |
11 | ## How to use my code?
12 | You can directly run Rainbow_DQN_main.py in your own IDE.
13 |
14 | ### Training environments
15 | You can set 'env_index' in the code to change the environment.
16 | env_index=0 represents 'CartPole-v1'
17 | env_index=1 represents 'LunarLander-v2'
18 |
19 | ### How to see the training results?
20 | You can use TensorBoard to visualize the training curves, which are saved in the folder 'runs'.
21 | The reward data are saved as NumPy arrays in the folder 'data_train'.
22 | The training curves are shown below.
23 | The right picture is smoothed by averaging over a window of size 10. The solid line and the shaded area respectively represent the average and the standard deviation over three different random seeds (0, 10 and 100).
24 | 
25 |
26 | ## Reference
27 | [1] Mnih V, Kavukcuoglu K, Silver D, et al. Human-level control through deep reinforcement learning[J]. nature, 2015, 518(7540): 529-533.
28 | [2] Van Hasselt H, Guez A, Silver D. Deep reinforcement learning with double q-learning[C]//Proceedings of the AAAI conference on artificial intelligence. 2016, 30(1).
29 | [3] Wang Z, Schaul T, Hessel M, et al. Dueling network architectures for deep reinforcement learning[C]//International conference on machine learning. PMLR, 2016: 1995-2003.
30 | [4] Fortunato M, Azar M G, Piot B, et al. Noisy networks for exploration[J]. arXiv preprint arXiv:1706.10295, 2017.
31 | [5] Schaul T, Quan J, Antonoglou I, et al. Prioritized experience replay[J]. arXiv preprint arXiv:1511.05952, 2015.
32 | [6] Hessel M, Modayil J, Van Hasselt H, et al. Rainbow: Combining improvements in deep reinforcement learning[C]//Thirty-second AAAI conference on artificial intelligence. 2018.
33 |
34 |
--------------------------------------------------------------------------------
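
As a small illustration of the README's note that the reward data are saved as NumPy arrays in 'data_train', the sketch below loads one of the saved curves and applies the same window-10 smoothing; matplotlib is an extra assumption here, not one of the listed dependencies.

# Illustrative sketch (assumes matplotlib is installed in addition to the listed dependencies):
# load the evaluation rewards saved by Rainbow_DQN_main.py and plot raw vs. smoothed curves.
import numpy as np
import matplotlib.pyplot as plt

rewards = np.load('./data_train/Rainbow_DQN_env_LunarLander-v2_number_1_seed_0.npy')
smoothed = np.convolve(rewards, np.ones(10) / 10, mode='valid')  # moving average with window 10

plt.plot(rewards, alpha=0.3, label='raw evaluation reward')
plt.plot(np.arange(9, len(rewards)), smoothed, label='smoothed (window=10)')
plt.xlabel('evaluation index (one point per 1e3 training steps)')
plt.ylabel('reward averaged over 3 evaluation episodes')
plt.legend()
plt.show()
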
/sum_tree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | class SumTree(object):
5 |     """
6 |     Store data with its priority in the tree.
7 |     Tree structure and array storage:
8 |
9 |     Tree index:
10 |          0         -> storing priority sum
11 |         / \
12 |        1   2
13 |       / \ / \
14 |      3  4 5  6    -> storing priority for transitions
15 |
16 |     Array type for storing:
17 |     [0, 1, 2, 3, 4, 5, 6]
18 |     """
19 |
20 | def __init__(self, buffer_capacity):
21 |         self.buffer_capacity = buffer_capacity  # capacity of the replay buffer
22 |         self.tree_capacity = 2 * buffer_capacity - 1  # capacity of the sum_tree
23 | self.tree = np.zeros(self.tree_capacity)
24 |
25 | def update(self, data_index, priority):
26 |         # data_index: index of this transition in the replay buffer
27 |         # tree_index: index of the same transition in the sum_tree
28 |         tree_index = data_index + self.buffer_capacity - 1  # convert the buffer index to the corresponding leaf index in the sum_tree
29 |         change = priority - self.tree[tree_index]  # change in this transition's priority
30 |         self.tree[tree_index] = priority  # update the priority stored in the leaf node
31 |         # then propagate the change through the tree
32 |         while tree_index != 0:  # update the priorities of the parent nodes, all the way up to the root
33 | tree_index = (tree_index - 1) // 2
34 | self.tree[tree_index] += change
35 |
36 | def get_index(self, v):
37 |         parent_idx = 0  # start from the root of the tree
38 | while True:
39 |             child_left_idx = 2 * parent_idx + 1  # indices of the left and right children below the parent node
40 | child_right_idx = child_left_idx + 1
41 | if child_left_idx >= self.tree_capacity: # reach bottom, end search
42 |                 tree_index = parent_idx  # tree_index: index of the sampled data in the sum_tree
43 | break
44 | else: # downward search, always search for a higher priority node
45 | if v <= self.tree[child_left_idx]:
46 | parent_idx = child_left_idx
47 | else:
48 | v -= self.tree[child_left_idx]
49 | parent_idx = child_right_idx
50 |
51 | data_index = tree_index - self.buffer_capacity + 1 # tree_index->data_index
52 |         return data_index, self.tree[tree_index]  # return the sampled transition's index in the buffer and its corresponding priority
53 |
54 | def get_batch_index(self, current_size, batch_size, beta):
55 |         batch_index = np.zeros(batch_size, dtype=np.int64)
56 |         IS_weight = torch.zeros(batch_size, dtype=torch.float32)
57 |         segment = self.priority_sum / batch_size  # split [0, priority_sum] into batch_size equal segments and sample one value uniformly from each
58 | for i in range(batch_size):
59 | a = segment * i
60 | b = segment * (i + 1)
61 | v = np.random.uniform(a, b)
62 | index, priority = self.get_index(v)
63 | batch_index[i] = index
64 |             prob = priority / self.priority_sum  # probability of this transition being sampled
65 | IS_weight[i] = (current_size * prob) ** (-beta)
66 | IS_weight /= IS_weight.max() # normalization
67 |
68 | return batch_index, IS_weight
69 |
70 | @property
71 | def priority_sum(self):
72 |         return self.tree[0]  # the root of the tree stores the sum of all priorities
73 |
74 | @property
75 | def priority_max(self):
76 |         return self.tree[self.buffer_capacity - 1:].max()  # only the leaf nodes in the last layer store the per-transition priorities
77 |
--------------------------------------------------------------------------------
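
A minimal usage sketch for the SumTree above, with an arbitrary capacity and made-up priorities, just to show how update(), priority_sum, priority_max and get_batch_index() fit together:

# Illustrative only: exercise SumTree on its own with made-up priorities.
from sum_tree import SumTree

tree = SumTree(buffer_capacity=8)            # 2 * 8 - 1 = 15 nodes; the last 8 are leaves
for i, p in enumerate([0.5, 1.0, 2.0, 4.0]):
    tree.update(data_index=i, priority=p)    # write a leaf and propagate the change up to the root

print(tree.priority_sum)                     # 7.5, stored at the root
print(tree.priority_max)                     # 4.0, the largest leaf priority

# Proportional sampling: transitions with larger priority are drawn more often.
batch_index, IS_weight = tree.get_batch_index(current_size=4, batch_size=3, beta=0.4)
print(batch_index, IS_weight)
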
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | class Dueling_Net(nn.Module):
8 | def __init__(self, args):
9 | super(Dueling_Net, self).__init__()
10 | self.fc1 = nn.Linear(args.state_dim, args.hidden_dim)
11 | self.fc2 = nn.Linear(args.hidden_dim, args.hidden_dim)
12 | if args.use_noisy:
13 | self.V = NoisyLinear(args.hidden_dim, 1)
14 | self.A = NoisyLinear(args.hidden_dim, args.action_dim)
15 | else:
16 | self.V = nn.Linear(args.hidden_dim, 1)
17 | self.A = nn.Linear(args.hidden_dim, args.action_dim)
18 |
19 | def forward(self, s):
20 | s = torch.relu(self.fc1(s))
21 | s = torch.relu(self.fc2(s))
22 | V = self.V(s) # batch_size X 1
23 | A = self.A(s) # batch_size X action_dim
24 | Q = V + (A - torch.mean(A, dim=-1, keepdim=True)) # Q(s,a)=V(s)+A(s,a)-mean(A(s,a))
25 | return Q
26 |
27 |
28 | class Net(nn.Module):
29 | def __init__(self, args):
30 | super(Net, self).__init__()
31 | self.fc1 = nn.Linear(args.state_dim, args.hidden_dim)
32 | self.fc2 = nn.Linear(args.hidden_dim, args.hidden_dim)
33 | if args.use_noisy:
34 | self.fc3 = NoisyLinear(args.hidden_dim, args.action_dim)
35 | else:
36 | self.fc3 = nn.Linear(args.hidden_dim, args.action_dim)
37 |
38 | def forward(self, s):
39 | s = torch.relu(self.fc1(s))
40 | s = torch.relu(self.fc2(s))
41 | Q = self.fc3(s)
42 | return Q
43 |
44 |
45 | class NoisyLinear(nn.Module):
46 | def __init__(self, in_features, out_features, sigma_init=0.5):
47 | super(NoisyLinear, self).__init__()
48 | self.in_features = in_features
49 | self.out_features = out_features
50 | self.sigma_init = sigma_init
51 |
52 | self.weight_mu = nn.Parameter(torch.FloatTensor(out_features, in_features))
53 | self.weight_sigma = nn.Parameter(torch.FloatTensor(out_features, in_features))
54 | self.register_buffer('weight_epsilon', torch.FloatTensor(out_features, in_features))
55 |
56 | self.bias_mu = nn.Parameter(torch.FloatTensor(out_features))
57 | self.bias_sigma = nn.Parameter(torch.FloatTensor(out_features))
58 | self.register_buffer('bias_epsilon', torch.FloatTensor(out_features))
59 |
60 | self.reset_parameters()
61 | self.reset_noise()
62 |
63 | def forward(self, x):
64 | if self.training:
65 | self.reset_noise()
66 |             weight = self.weight_mu + self.weight_sigma.mul(self.weight_epsilon)  # .mul is element-wise multiplication
67 | bias = self.bias_mu + self.bias_sigma.mul(self.bias_epsilon)
68 |
69 | else:
70 | weight = self.weight_mu
71 | bias = self.bias_mu
72 |
73 | return F.linear(x, weight, bias)
74 |
75 | def reset_parameters(self):
76 | mu_range = 1 / math.sqrt(self.in_features)
77 | self.weight_mu.data.uniform_(-mu_range, mu_range)
78 | self.bias_mu.data.uniform_(-mu_range, mu_range)
79 |
80 | self.weight_sigma.data.fill_(self.sigma_init / math.sqrt(self.in_features))
81 |         self.bias_sigma.data.fill_(self.sigma_init / math.sqrt(self.out_features))  # note: divide by out_features here
82 |
83 | def reset_noise(self):
84 | epsilon_i = self.scale_noise(self.in_features)
85 | epsilon_j = self.scale_noise(self.out_features)
86 | self.weight_epsilon.copy_(torch.ger(epsilon_j, epsilon_i))
87 | self.bias_epsilon.copy_(epsilon_j)
88 |
89 | def scale_noise(self, size):
90 |         x = torch.randn(size)  # torch.randn samples from a standard Gaussian distribution
91 | x = x.sign().mul(x.abs().sqrt())
92 | return x
93 |
--------------------------------------------------------------------------------
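
A quick shape check for the networks above; the argparse.Namespace below merely stands in for the parsed arguments the classes read (state_dim, action_dim, hidden_dim, use_noisy), using LunarLander-v2's dimensions as example values:

# Illustrative shape check (argparse.Namespace stands in for the parsed args).
import argparse
import torch
from network import Dueling_Net

args = argparse.Namespace(state_dim=8, action_dim=4, hidden_dim=256, use_noisy=True)
net = Dueling_Net(args)

s = torch.randn(32, args.state_dim)          # a dummy batch of 32 states
print(net(s).shape)                          # torch.Size([32, 4]): one Q-value per action

net.eval()                                   # in eval mode NoisyLinear uses only the mu weights (no noise)
with torch.no_grad():
    print(net(s[:1]).argmax(dim=-1).item())  # greedy action for a single state
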
/rainbow_dqn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import copy
4 | from network import Dueling_Net, Net
5 |
6 |
7 | class DQN(object):
8 | def __init__(self, args):
9 | self.action_dim = args.action_dim
10 | self.batch_size = args.batch_size # batch size
11 | self.max_train_steps = args.max_train_steps
12 | self.lr = args.lr # learning rate
13 | self.gamma = args.gamma # discount factor
14 | self.tau = args.tau # Soft update
15 | self.use_soft_update = args.use_soft_update
16 | self.target_update_freq = args.target_update_freq # hard update
17 | self.update_count = 0
18 |
19 | self.grad_clip = args.grad_clip
20 | self.use_lr_decay = args.use_lr_decay
21 | self.use_double = args.use_double
22 | self.use_dueling = args.use_dueling
23 | self.use_per = args.use_per
24 | self.use_n_steps = args.use_n_steps
25 | if self.use_n_steps:
26 | self.gamma = self.gamma ** args.n_steps
27 |
28 | if self.use_dueling: # Whether to use the 'dueling network'
29 | self.net = Dueling_Net(args)
30 | else:
31 | self.net = Net(args)
32 |
33 | self.target_net = copy.deepcopy(self.net) # Copy the online_net to the target_net
34 |
35 | self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)
36 |
37 | def choose_action(self, state, epsilon):
38 | with torch.no_grad():
39 | state = torch.unsqueeze(torch.tensor(state, dtype=torch.float), 0)
40 | q = self.net(state)
41 | if np.random.uniform() > epsilon:
42 | action = q.argmax(dim=-1).item()
43 | else:
44 | action = np.random.randint(0, self.action_dim)
45 | return action
46 |
47 | def learn(self, replay_buffer, total_steps):
48 | batch, batch_index, IS_weight = replay_buffer.sample(total_steps)
49 |
50 | with torch.no_grad(): # q_target has no gradient
51 | if self.use_double: # Whether to use the 'double q-learning'
52 | # Use online_net to select the action
53 | a_argmax = self.net(batch['next_state']).argmax(dim=-1, keepdim=True) # shape:(batch_size,1)
54 | # Use target_net to estimate the q_target
55 | q_target = batch['reward'] + self.gamma * (1 - batch['terminal']) * self.target_net(batch['next_state']).gather(-1, a_argmax).squeeze(-1) # shape:(batch_size,)
56 | else:
57 | q_target = batch['reward'] + self.gamma * (1 - batch['terminal']) * self.target_net(batch['next_state']).max(dim=-1)[0] # shape:(batch_size,)
58 |
59 | q_current = self.net(batch['state']).gather(-1, batch['action']).squeeze(-1) # shape:(batch_size,)
60 | td_errors = q_current - q_target # shape:(batch_size,)
61 |
62 | if self.use_per:
63 | loss = (IS_weight * (td_errors ** 2)).mean()
64 | replay_buffer.update_batch_priorities(batch_index, td_errors.detach().numpy())
65 | else:
66 | loss = (td_errors ** 2).mean()
67 |
68 | self.optimizer.zero_grad()
69 | loss.backward()
70 | if self.grad_clip:
71 | torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.grad_clip)
72 | self.optimizer.step()
73 |
74 | if self.use_soft_update: # soft update
75 | for param, target_param in zip(self.net.parameters(), self.target_net.parameters()):
76 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
77 | else: # hard update
78 | self.update_count += 1
79 | if self.update_count % self.target_update_freq == 0:
80 | self.target_net.load_state_dict(self.net.state_dict())
81 |
82 | if self.use_lr_decay: # learning rate Decay
83 | self.lr_decay(total_steps)
84 |
85 | def lr_decay(self, total_steps):
86 | lr_now = 0.9 * self.lr * (1 - total_steps / self.max_train_steps) + 0.1 * self.lr
87 | for p in self.optimizer.param_groups:
88 | p['lr'] = lr_now
89 |
--------------------------------------------------------------------------------
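
The heart of learn() is the Double DQN target: the online network picks the greedy action for s', and the target network evaluates it. The toy snippet below recreates just that computation with random stand-in tensors to make the shapes explicit:

# Toy illustration of the Double DQN target used in learn(); all tensors are random stand-ins.
import torch

gamma = 0.99
q_online_next = torch.randn(5, 4)   # online net Q-values for the next states, shape (batch_size, action_dim)
q_target_next = torch.randn(5, 4)   # target net Q-values for the same next states
reward = torch.randn(5)
terminal = torch.zeros(5)           # 1 where the episode truly ended, else 0

a_argmax = q_online_next.argmax(dim=-1, keepdim=True)   # (5, 1): action chosen by the online net
q_target = reward + gamma * (1 - terminal) * q_target_next.gather(-1, a_argmax).squeeze(-1)
print(q_target.shape)               # torch.Size([5]), matching the shape comment in learn()
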
/Rainbow_DQN_main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import gym
4 | from torch.utils.tensorboard import SummaryWriter
5 | from replay_buffer import *
6 | from rainbow_dqn import DQN
7 | import argparse
8 |
9 |
10 | class Runner:
11 | def __init__(self, args, env_name, number, seed):
12 | self.args = args
13 | self.env_name = env_name
14 | self.number = number
15 | self.seed = seed
16 |
17 | self.env = gym.make(env_name)
18 | self.env_evaluate = gym.make(env_name) # When evaluating the policy, we need to rebuild an environment
19 | self.env.seed(seed)
20 | self.env.action_space.seed(seed)
21 | self.env_evaluate.seed(seed)
22 | self.env_evaluate.action_space.seed(seed)
23 | np.random.seed(seed)
24 | torch.manual_seed(seed)
25 |
26 | self.args.state_dim = self.env.observation_space.shape[0]
27 | self.args.action_dim = self.env.action_space.n
28 | self.args.episode_limit = self.env._max_episode_steps # Maximum number of steps per episode
29 | print("env={}".format(self.env_name))
30 | print("state_dim={}".format(self.args.state_dim))
31 | print("action_dim={}".format(self.args.action_dim))
32 | print("episode_limit={}".format(self.args.episode_limit))
33 |
34 | if args.use_per and args.use_n_steps:
35 | self.replay_buffer = N_Steps_Prioritized_ReplayBuffer(args)
36 | elif args.use_per:
37 | self.replay_buffer = Prioritized_ReplayBuffer(args)
38 | elif args.use_n_steps:
39 | self.replay_buffer = N_Steps_ReplayBuffer(args)
40 | else:
41 | self.replay_buffer = ReplayBuffer(args)
42 | self.agent = DQN(args)
43 |
44 | self.algorithm = 'DQN'
45 | if args.use_double and args.use_dueling and args.use_noisy and args.use_per and args.use_n_steps:
46 | self.algorithm = 'Rainbow_' + self.algorithm
47 | else:
48 | if args.use_double:
49 | self.algorithm += '_Double'
50 | if args.use_dueling:
51 | self.algorithm += '_Dueling'
52 | if args.use_noisy:
53 | self.algorithm += '_Noisy'
54 | if args.use_per:
55 | self.algorithm += '_PER'
56 | if args.use_n_steps:
57 | self.algorithm += "_N_steps"
58 |
59 | self.writer = SummaryWriter(log_dir='runs/DQN/{}_env_{}_number_{}_seed_{}'.format(self.algorithm, env_name, number, seed))
60 |
61 | self.evaluate_num = 0 # Record the number of evaluations
62 | self.evaluate_rewards = [] # Record the rewards during the evaluating
63 | self.total_steps = 0 # Record the total steps during the training
64 |         if args.use_noisy:  # when using the noisy network, the epsilon-greedy policy is not needed
65 | self.epsilon = 0
66 | else:
67 | self.epsilon = self.args.epsilon_init
68 | self.epsilon_min = self.args.epsilon_min
69 | self.epsilon_decay = (self.args.epsilon_init - self.args.epsilon_min) / self.args.epsilon_decay_steps
70 |
71 | def run(self, ):
72 | self.evaluate_policy()
73 | while self.total_steps < self.args.max_train_steps:
74 | state = self.env.reset()
75 | done = False
76 | episode_steps = 0
77 | while not done:
78 | action = self.agent.choose_action(state, epsilon=self.epsilon)
79 | next_state, reward, done, _ = self.env.step(action)
80 | episode_steps += 1
81 | self.total_steps += 1
82 |
83 | if not self.args.use_noisy: # Decay epsilon
84 | self.epsilon = self.epsilon - self.epsilon_decay if self.epsilon - self.epsilon_decay > self.epsilon_min else self.epsilon_min
85 |
86 |                 # 'done' becomes True when the agent dies, wins, or reaches max_episode_steps, and we need to distinguish these cases:
87 |                 # 'terminal' means dead or win, i.e. there is no next state s';
88 |                 # but when max_episode_steps is reached, a next state s' actually exists.
89 | if done and episode_steps != self.args.episode_limit:
90 | if self.env_name == 'LunarLander-v2':
91 | if reward <= -100: reward = -1 # good for LunarLander
92 | terminal = True
93 | else:
94 | terminal = False
95 |
96 | self.replay_buffer.store_transition(state, action, reward, next_state, terminal, done) # Store the transition
97 | state = next_state
98 |
99 | if self.replay_buffer.current_size >= self.args.batch_size:
100 | self.agent.learn(self.replay_buffer, self.total_steps)
101 |
102 | if self.total_steps % self.args.evaluate_freq == 0:
103 | self.evaluate_policy()
104 | # Save reward
105 | np.save('./data_train/{}_env_{}_number_{}_seed_{}.npy'.format(self.algorithm, self.env_name, self.number, self.seed), np.array(self.evaluate_rewards))
106 |
107 | def evaluate_policy(self, ):
108 | evaluate_reward = 0
109 | self.agent.net.eval()
110 | for _ in range(self.args.evaluate_times):
111 | state = self.env_evaluate.reset()
112 | done = False
113 | episode_reward = 0
114 | while not done:
115 | action = self.agent.choose_action(state, epsilon=0)
116 | next_state, reward, done, _ = self.env_evaluate.step(action)
117 | episode_reward += reward
118 | state = next_state
119 | evaluate_reward += episode_reward
120 | self.agent.net.train()
121 | evaluate_reward /= self.args.evaluate_times
122 | self.evaluate_rewards.append(evaluate_reward)
123 | print("total_steps:{} \t evaluate_reward:{} \t epsilon:{}".format(self.total_steps, evaluate_reward, self.epsilon))
124 | self.writer.add_scalar('step_rewards_{}'.format(self.env_name), evaluate_reward, global_step=self.total_steps)
125 |
126 |
127 | if __name__ == '__main__':
128 | parser = argparse.ArgumentParser("Hyperparameter Setting for DQN")
129 | parser.add_argument("--max_train_steps", type=int, default=int(4e5), help=" Maximum number of training steps")
130 | parser.add_argument("--evaluate_freq", type=float, default=1e3, help="Evaluate the policy every 'evaluate_freq' steps")
131 |     parser.add_argument("--evaluate_times", type=int, default=3, help="Number of episodes per evaluation")
132 |
133 | parser.add_argument("--buffer_capacity", type=int, default=int(1e5), help="The maximum replay-buffer capacity ")
134 | parser.add_argument("--batch_size", type=int, default=256, help="batch size")
135 | parser.add_argument("--hidden_dim", type=int, default=256, help="The number of neurons in hidden layers of the neural network")
136 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate of actor")
137 | parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
138 | parser.add_argument("--epsilon_init", type=float, default=0.5, help="Initial epsilon")
139 | parser.add_argument("--epsilon_min", type=float, default=0.1, help="Minimum epsilon")
140 | parser.add_argument("--epsilon_decay_steps", type=int, default=int(1e5), help="How many steps before the epsilon decays to the minimum")
141 | parser.add_argument("--tau", type=float, default=0.005, help="soft update the target network")
142 | parser.add_argument("--use_soft_update", type=bool, default=True, help="Whether to use soft update")
143 | parser.add_argument("--target_update_freq", type=int, default=200, help="Update frequency of the target network(hard update)")
144 | parser.add_argument("--n_steps", type=int, default=5, help="n_steps")
145 | parser.add_argument("--alpha", type=float, default=0.6, help="PER parameter")
146 | parser.add_argument("--beta_init", type=float, default=0.4, help="Important sampling parameter in PER")
147 | parser.add_argument("--use_lr_decay", type=bool, default=True, help="Learning rate Decay")
148 | parser.add_argument("--grad_clip", type=float, default=10.0, help="Gradient clip")
149 |
150 | parser.add_argument("--use_double", type=bool, default=True, help="Whether to use double Q-learning")
151 | parser.add_argument("--use_dueling", type=bool, default=True, help="Whether to use dueling network")
152 | parser.add_argument("--use_noisy", type=bool, default=True, help="Whether to use noisy network")
153 | parser.add_argument("--use_per", type=bool, default=True, help="Whether to use PER")
154 | parser.add_argument("--use_n_steps", type=bool, default=True, help="Whether to use n_steps Q-learning")
155 |
156 | args = parser.parse_args()
157 |
158 | env_names = ['CartPole-v1', 'LunarLander-v2']
159 | env_index = 1
160 | for seed in [0, 10, 100]:
161 | runner = Runner(args=args, env_name=env_names[env_index], number=1, seed=seed)
162 | runner.run()
163 |
--------------------------------------------------------------------------------
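
When use_noisy is False, exploration falls back to the linear epsilon-greedy schedule built in __init__ and applied in run(). A stand-alone sketch of that schedule with the default hyperparameters (0.5 decaying to 0.1 over 1e5 steps):

# Stand-alone sketch of the linear epsilon schedule used when the noisy network is disabled.
epsilon_init, epsilon_min, epsilon_decay_steps = 0.5, 0.1, int(1e5)
epsilon_decay = (epsilon_init - epsilon_min) / epsilon_decay_steps

epsilon = epsilon_init
for step in range(1, 200_001):
    epsilon = max(epsilon - epsilon_decay, epsilon_min)   # same clamping as in run()
    if step % 50_000 == 0:
        print(step, round(epsilon, 3))                    # 0.3 at 50k steps, 0.1 from 100k onward
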
/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from collections import deque
4 | from sum_tree import SumTree
5 |
6 |
7 | class ReplayBuffer(object):
8 | def __init__(self, args):
9 | self.batch_size = args.batch_size
10 | self.buffer_capacity = args.buffer_capacity
11 | self.current_size = 0
12 | self.count = 0
13 | self.buffer = {'state': np.zeros((self.buffer_capacity, args.state_dim)),
14 | 'action': np.zeros((self.buffer_capacity, 1)),
15 | 'reward': np.zeros(self.buffer_capacity),
16 | 'next_state': np.zeros((self.buffer_capacity, args.state_dim)),
17 | 'terminal': np.zeros(self.buffer_capacity),
18 | }
19 |
20 | def store_transition(self, state, action, reward, next_state, terminal, done):
21 | self.buffer['state'][self.count] = state
22 | self.buffer['action'][self.count] = action
23 | self.buffer['reward'][self.count] = reward
24 | self.buffer['next_state'][self.count] = next_state
25 | self.buffer['terminal'][self.count] = terminal
26 | self.count = (self.count + 1) % self.buffer_capacity # When the 'count' reaches buffer_capacity, it will be reset to 0.
27 | self.current_size = min(self.current_size + 1, self.buffer_capacity)
28 |
29 | def sample(self, total_steps):
30 | index = np.random.randint(0, self.current_size, size=self.batch_size)
31 | batch = {}
32 | for key in self.buffer.keys(): # numpy->tensor
33 | if key == 'action':
34 | batch[key] = torch.tensor(self.buffer[key][index], dtype=torch.long)
35 | else:
36 | batch[key] = torch.tensor(self.buffer[key][index], dtype=torch.float32)
37 |
38 | return batch, None, None
39 |
40 |
41 | class N_Steps_ReplayBuffer(object):
42 | def __init__(self, args):
43 | self.gamma = args.gamma
44 | self.batch_size = args.batch_size
45 | self.buffer_capacity = args.buffer_capacity
46 | self.current_size = 0
47 | self.count = 0
48 | self.n_steps = args.n_steps
49 | self.n_steps_deque = deque(maxlen=self.n_steps)
50 | self.buffer = {'state': np.zeros((self.buffer_capacity, args.state_dim)),
51 | 'action': np.zeros((self.buffer_capacity, 1)),
52 | 'reward': np.zeros(self.buffer_capacity),
53 | 'next_state': np.zeros((self.buffer_capacity, args.state_dim)),
54 | 'terminal': np.zeros(self.buffer_capacity),
55 | }
56 |
57 | def store_transition(self, state, action, reward, next_state, terminal, done):
58 | transition = (state, action, reward, next_state, terminal, done)
59 | self.n_steps_deque.append(transition)
60 | if len(self.n_steps_deque) == self.n_steps:
61 | state, action, n_steps_reward, next_state, terminal = self.get_n_steps_transition()
62 | self.buffer['state'][self.count] = state
63 | self.buffer['action'][self.count] = action
64 | self.buffer['reward'][self.count] = n_steps_reward
65 | self.buffer['next_state'][self.count] = next_state
66 | self.buffer['terminal'][self.count] = terminal
67 | self.count = (self.count + 1) % self.buffer_capacity # When the 'count' reaches buffer_capacity, it will be reset to 0.
68 | self.current_size = min(self.current_size + 1, self.buffer_capacity)
69 |
70 | def get_n_steps_transition(self):
71 | state, action = self.n_steps_deque[0][:2]
72 | next_state, terminal = self.n_steps_deque[-1][3:5]
73 | n_steps_reward = 0
74 | for i in reversed(range(self.n_steps)):
75 | r, s_, ter, d = self.n_steps_deque[i][2:]
76 | n_steps_reward = r + self.gamma * (1 - d) * n_steps_reward
77 | if d:
78 | next_state, terminal = s_, ter
79 |
80 | return state, action, n_steps_reward, next_state, terminal
81 |
82 | def sample(self, total_steps):
83 | index = np.random.randint(0, self.current_size, size=self.batch_size)
84 | batch = {}
85 | for key in self.buffer.keys(): # numpy->tensor
86 | if key == 'action':
87 | batch[key] = torch.tensor(self.buffer[key][index], dtype=torch.long)
88 | else:
89 | batch[key] = torch.tensor(self.buffer[key][index], dtype=torch.float32)
90 |
91 | return batch, None, None
92 |
93 |
94 | class Prioritized_ReplayBuffer(object):
95 | def __init__(self, args):
96 | self.max_train_steps = args.max_train_steps
97 | self.alpha = args.alpha
98 | self.beta_init = args.beta_init
99 | self.beta = args.beta_init
100 | self.batch_size = args.batch_size
101 | self.buffer_capacity = args.buffer_capacity
102 | self.sum_tree = SumTree(self.buffer_capacity)
103 | self.current_size = 0
104 | self.count = 0
105 | self.buffer = {'state': np.zeros((self.buffer_capacity, args.state_dim)),
106 | 'action': np.zeros((self.buffer_capacity, 1)),
107 | 'reward': np.zeros(self.buffer_capacity),
108 | 'next_state': np.zeros((self.buffer_capacity, args.state_dim)),
109 | 'terminal': np.zeros(self.buffer_capacity),
110 | }
111 |
112 | def store_transition(self, state, action, reward, next_state, terminal, done):
113 | self.buffer['state'][self.count] = state
114 | self.buffer['action'][self.count] = action
115 | self.buffer['reward'][self.count] = reward
116 | self.buffer['next_state'][self.count] = next_state
117 | self.buffer['terminal'][self.count] = terminal
118 |         # The first transition stored gets priority 1.0; after that, every newly stored transition gets the current maximum priority.
119 |         priority = 1.0 if self.current_size == 0 else self.sum_tree.priority_max
120 |         self.sum_tree.update(data_index=self.count, priority=priority)  # update this transition's priority in the sum_tree
121 | self.count = (self.count + 1) % self.buffer_capacity # When the 'count' reaches buffer_capacity, it will be reset to 0.
122 | self.current_size = min(self.current_size + 1, self.buffer_capacity)
123 |
124 | def sample(self, total_steps):
125 | batch_index, IS_weight = self.sum_tree.get_batch_index(current_size=self.current_size, batch_size=self.batch_size, beta=self.beta)
126 | self.beta = self.beta_init + (1 - self.beta_init) * (total_steps / self.max_train_steps) # beta:beta_init->1.0
127 | batch = {}
128 | for key in self.buffer.keys(): # numpy->tensor
129 | if key == 'action':
130 | batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.long)
131 | else:
132 | batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.float32)
133 |
134 | return batch, batch_index, IS_weight
135 |
136 |     def update_batch_priorities(self, batch_index, td_errors):  # update the priorities of the transitions at batch_index according to their TD errors
137 | priorities = (np.abs(td_errors) + 0.01) ** self.alpha
138 | for index, priority in zip(batch_index, priorities):
139 | self.sum_tree.update(data_index=index, priority=priority)
140 |
141 |
142 | class N_Steps_Prioritized_ReplayBuffer(object):
143 | def __init__(self, args):
144 | self.max_train_steps = args.max_train_steps
145 | self.alpha = args.alpha
146 | self.beta_init = args.beta_init
147 | self.beta = args.beta_init
148 | self.gamma = args.gamma
149 | self.batch_size = args.batch_size
150 | self.buffer_capacity = args.buffer_capacity
151 | self.sum_tree = SumTree(self.buffer_capacity)
152 | self.n_steps = args.n_steps
153 | self.n_steps_deque = deque(maxlen=self.n_steps)
154 | self.buffer = {'state': np.zeros((self.buffer_capacity, args.state_dim)),
155 | 'action': np.zeros((self.buffer_capacity, 1)),
156 | 'reward': np.zeros(self.buffer_capacity),
157 | 'next_state': np.zeros((self.buffer_capacity, args.state_dim)),
158 | 'terminal': np.zeros(self.buffer_capacity),
159 | }
160 | self.current_size = 0
161 | self.count = 0
162 |
163 | def store_transition(self, state, action, reward, next_state, terminal, done):
164 | transition = (state, action, reward, next_state, terminal, done)
165 | self.n_steps_deque.append(transition)
166 | if len(self.n_steps_deque) == self.n_steps:
167 | state, action, n_steps_reward, next_state, terminal = self.get_n_steps_transition()
168 | self.buffer['state'][self.count] = state
169 | self.buffer['action'][self.count] = action
170 | self.buffer['reward'][self.count] = n_steps_reward
171 | self.buffer['next_state'][self.count] = next_state
172 | self.buffer['terminal'][self.count] = terminal
173 |             # The first transition stored gets priority 1.0; after that, every newly stored transition gets the current maximum priority.
174 |             priority = 1.0 if self.current_size == 0 else self.sum_tree.priority_max
175 |             self.sum_tree.update(data_index=self.count, priority=priority)  # update this transition's priority in the sum_tree
176 | self.count = (self.count + 1) % self.buffer_capacity # When 'count' reaches buffer_capacity, it will be reset to 0.
177 | self.current_size = min(self.current_size + 1, self.buffer_capacity)
178 |
179 | def sample(self, total_steps):
180 | batch_index, IS_weight = self.sum_tree.get_batch_index(current_size=self.current_size, batch_size=self.batch_size, beta=self.beta)
181 | self.beta = self.beta_init + (1 - self.beta_init) * (total_steps / self.max_train_steps) # beta:beta_init->1.0
182 | batch = {}
183 | for key in self.buffer.keys(): # numpy->tensor
184 | if key == 'action':
185 | batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.long)
186 | else:
187 | batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.float32)
188 |
189 | return batch, batch_index, IS_weight
190 |
191 | def get_n_steps_transition(self):
192 |         state, action = self.n_steps_deque[0][:2]  # s and a of the first transition in the deque
193 |         next_state, terminal = self.n_steps_deque[-1][3:5]  # s' and terminal of the last transition in the deque
194 |         n_steps_reward = 0
195 |         for i in reversed(range(self.n_steps)):  # accumulate n_steps_reward in reverse order
196 |             r, s_, ter, d = self.n_steps_deque[i][2:]
197 |             n_steps_reward = r + self.gamma * (1 - d) * n_steps_reward
198 |             if d:  # if done=True, an episode ended inside the window; use this transition's s' and terminal as the n-step transition's next_state and terminal
199 | next_state, terminal = s_, ter
200 |
201 | return state, action, n_steps_reward, next_state, terminal
202 |
203 |     def update_batch_priorities(self, batch_index, td_errors):  # update the priorities of the transitions at batch_index according to their TD errors
204 | priorities = (np.abs(td_errors) + 0.01) ** self.alpha
205 | for index, priority in zip(batch_index, priorities):
206 | self.sum_tree.update(data_index=index, priority=priority)
207 |
--------------------------------------------------------------------------------
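
As a sanity check on get_n_steps_transition(), the snippet below computes the n-step return it accumulates in two ways, using made-up rewards and assuming no termination inside the window:

# Toy check of the n-step return from get_n_steps_transition(), with no 'done' inside the window.
gamma, rewards = 0.99, [1.0, 2.0, 3.0, 4.0, 5.0]     # n_steps = 5 (the default)

n_steps_reward = 0.0
for r in reversed(rewards):                          # same backward accumulation as in the buffer
    n_steps_reward = r + gamma * n_steps_reward

closed_form = sum(r * gamma ** i for i, r in enumerate(rewards))
print(n_steps_reward, closed_form)                   # both print ~14.6045
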