├── .DS_Store ├── README.md ├── ac ├── CartPole-v0 │ ├── 10core_actor_loss_curve.png │ ├── 10core_critic_loss_curve.png │ ├── 10core_reward_curve.png │ ├── 1core_reward_curve.png │ ├── 2core_reward_curve.png │ ├── 3core_reward_curve.png │ ├── A2C.gif │ ├── A2C.pth │ ├── A3C.gif │ ├── A3C.pth │ ├── AC.gif │ ├── AC.pth │ ├── actor_loss_curve.png │ ├── critic_loss_curve.png │ ├── reward_curve(A2C).png │ ├── reward_curve(A3C).png │ └── reward_curve(AC).png ├── README.md ├── a2c.py ├── a2c_target.py ├── a3c.py ├── ac.py ├── ac_target.py ├── sacwalker.py └── test_env.py ├── custom_env ├── README.md └── snake_env.py ├── ddpg ├── ddpg.py └── td3.py ├── deep_learning ├── 作业1-傅上峰.ipynb ├── 作业1-廖子兴.ipynb ├── 作业1-邓成杰.ipynb ├── 作业1-邹庆翔.ipynb ├── 作业1——手写数字识别.ipynb └── 喻风-minst.ipynb ├── dqn ├── ddqn.py ├── ddrqn.py ├── dqn.py └── drqn.py ├── pg ├── reinforce.py └── reinforce_baseline.py ├── ppo ├── .idea │ ├── .gitignore │ ├── RL_algorithm.iml │ ├── inspectionProfiles │ │ └── profiles_settings.xml │ ├── misc.xml │ └── modules.xml ├── BipedalWalker-v3 │ ├── 2024_01_30_22_54_BipedalWalker-v3.pth │ ├── 2024_01_30_23_01_BipedalWalker-v3.pth │ ├── 2024_01_30_23_05_BipedalWalker-v3.pth │ ├── 2024_01_30_23_16_BipedalWalker-v3.pth │ ├── 2024_01_30_23_17_BipedalWalker-v3.pth │ ├── 2024_01_30_23_26_BipedalWalker-v3.pth │ ├── 2024_01_30_23_32_BipedalWalker-v3.pth │ ├── 2024_01_30_23_35_BipedalWalker-v3.pth │ ├── 2024_01_30_23_39_BipedalWalker-v3.pth │ ├── 2024_01_30_23_40_BipedalWalker-v3.pth │ ├── 2024_01_30_23_50_BipedalWalker-v3.pth │ ├── 2024_01_31_00_09_BipedalWalker-v3.pth │ ├── 2024_01_31_00_11_BipedalWalker-v3.pth │ ├── 2024_01_31_00_15_BipedalWalker-v3.pth │ ├── 2024_02_01_10_58_BipedalWalker-v3.pth │ ├── 2024_02_01_10_59_BipedalWalker-v3.pth │ ├── 2024_02_01_11_00_BipedalWalker-v3.pth │ ├── 2024_02_01_11_01_BipedalWalker-v3.pth │ ├── 2024_02_01_11_02_BipedalWalker-v3.pth │ ├── 2024_02_01_11_03_BipedalWalker-v3.pth │ ├── 2024_02_01_11_04_BipedalWalker-v3.pth │ ├── 2024_02_01_11_06_BipedalWalker-v3.pth │ ├── 2024_02_01_11_07_BipedalWalker-v3.pth │ ├── 2024_02_01_11_08_BipedalWalker-v3.pth │ ├── 2024_02_01_15_41_BipedalWalker-v3.pth │ ├── 2024_02_01_15_48_BipedalWalker-v3.pth │ ├── 2024_02_01_15_49_BipedalWalker-v3.pth │ ├── 2024_02_01_15_50_BipedalWalker-v3.pth │ ├── 2024_02_01_15_51_BipedalWalker-v3.pth │ ├── 2024_02_01_15_53_BipedalWalker-v3.pth │ ├── 2024_02_01_15_54_BipedalWalker-v3.pth │ ├── 2024_02_01_15_59_BipedalWalker-v3.pth │ ├── 2024_02_01_16_23_BipedalWalker-v3.pth │ ├── 2024_02_01_16_25_BipedalWalker-v3.pth │ ├── 2024_02_01_16_26_BipedalWalker-v3.pth │ ├── 2024_02_01_16_27_BipedalWalker-v3.pth │ ├── 2024_02_01_16_28_BipedalWalker-v3.pth │ ├── 2024_02_01_16_29_BipedalWalker-v3.pth │ ├── 2024_02_01_16_30_BipedalWalker-v3.pth │ ├── 2024_02_01_16_31_BipedalWalker-v3.pth │ ├── 2024_02_01_16_32_BipedalWalker-v3.pth │ ├── 2024_02_01_16_33_BipedalWalker-v3.pth │ ├── 2024_02_01_16_34_BipedalWalker-v3.pth │ ├── 2024_02_01_16_35_BipedalWalker-v3.pth │ ├── 2024_02_01_16_36_BipedalWalker-v3.pth │ ├── 2024_02_01_16_42_BipedalWalker-v3.pth │ ├── 2024_02_01_17_53_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_17_55_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_17_56_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_17_58_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_17_59_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_00_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_02_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_05_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_06_BipedalWalkerHardcore-v3.pth │ ├── 
2024_02_01_18_07_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_08_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_09_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_12_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_13_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_24_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_41_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_45_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_51_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_18_56_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_19_11_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_19_19_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_19_26_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_19_44_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_19_52_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_19_56_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_20_12_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_20_16_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_20_18_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_20_22_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_21_10_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_21_14_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_21_25_BipedalWalkerHardcore-v3.pth │ ├── 2024_02_01_21_30_BipedalWalkerHardcore-v3.pth │ └── 2024_02_01_21_54_BipedalWalkerHardcore-v3.pth ├── BipedwalkerHardcoreTest.py ├── MultiFollow │ ├── 2024_02_24_20_31_Follow.pth │ └── 2024_02_24_20_32.gif ├── custom_env │ ├── MultiAgentFollowEnvV1.py │ ├── MultiAgentFollowEnvV2.py │ ├── Walker_Discreate.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── MultiAgentFollowEnvV2.cpython-38.pyc │ │ ├── MultiAgentFollowEnvV3.cpython-38.pyc │ │ ├── Walker_Discreate.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ └── snake_env.cpython-38.pyc │ └── snake_env.py ├── ppo.py ├── ppo_continuous.py ├── ppo_discrete.py ├── ppo_discretemask.py ├── ppo_multiagentfollow.py ├── ppo_multidiscrete.py ├── ppo_multienv.py ├── ppo_sharing.py ├── ppo_snake.py ├── ppoconPendulum.py ├── ppodisPendulum.py ├── runs │ ├── BipedalWalker-v3__ppo_continuous__1__1706597051 │ │ └── events.out.tfevents.1706597051.SKY-20230422NIZ.12216.0 │ ├── BipedalWalker-v3__ppo_continuous__1__1706771659 │ │ └── events.out.tfevents.1706771659.DESKTOP-VAT73EM.12564.0 │ ├── BipedalWalker-v3__ppo_multienv__1__1706755654 │ │ └── events.out.tfevents.1706755654.DESKTOP-VAT73EM.4120.0 │ ├── BipedalWalker-v3__ppo_multienv__1__1706774718 │ │ └── events.out.tfevents.1706774718.DESKTOP-VAT73EM.11980.0 │ ├── BipedalWalkerHardcore-v3__ppo_multienv__1__1706780430 │ │ └── events.out.tfevents.1706780430.DESKTOP-VAT73EM.12616.0 │ ├── Follow__ppo_sharing__1__1708769402 │ │ └── events.out.tfevents.1708769402.LAPTOP-K5GRC2HU.20764.0 │ ├── Snake-v0__ppo_discretemask__1__1706672792 │ │ └── events.out.tfevents.1706672792.SKY-20230422NIZ.11492.0 │ ├── Snake-v0__ppo_discretemask__1__1707018799 │ │ └── events.out.tfevents.1707018799.DESKTOP-VAT73EM.4348.0 │ ├── Snake-v0__ppo_snake__1__1706623864 │ │ └── events.out.tfevents.1706623864.SKY-20230422NIZ.676.0 │ ├── Snake-v0__ppo_snake__1__1708424259 │ │ └── events.out.tfevents.1708424259.LAPTOP-K5GRC2HU.22668.0 │ ├── Walker__ppomultidiscrete__1__1706597666 │ │ └── events.out.tfevents.1706597666.SKY-20230422NIZ.10160.0 │ └── runs.rar ├── snake-v0 │ ├── 2024_01_30_22_14_Snake-v0.pth │ ├── 2024_01_30_22_19_Snake-v0.pth │ ├── 2024_01_30_22_22_Snake-v0.pth │ ├── 2024_01_30_22_36_Snake-v0.pth │ ├── 2024_01_30_22_49_Snake-v0.pth │ ├── 2024_01_31_19_26_Snake-v0.pth │ ├── 2024_01_31_19_27_Snake-v0.pth │ └── 2024_01_31_19_28_Snake-v0.pth ├── snake-v0mask │ 
├── 2024_01_31_11_48_Snake-v0.pth │ ├── 2024_01_31_11_49_Snake-v0.pth │ ├── 2024_01_31_11_51_Snake-v0.pth │ ├── 2024_01_31_11_55_Snake-v0.pth │ ├── 2024_01_31_12_00_Snake-v0.pth │ ├── 2024_01_31_12_09_Snake-v0.pth │ ├── 2024_01_31_12_13_Snake-v0.pth │ ├── 2024_01_31_12_31_Snake-v0.pth │ ├── 2024_01_31_12_48_Snake-v0.pth │ ├── 2024_01_31_17_20_Snake-v0.pth │ └── 2024_01_31_17_21_Snake-v0.pth └── test_follow.py ├── requirements.txt ├── result └── gif │ ├── BipedalWalker-v3 │ ├── BipedalWalker-v3.gif │ ├── SAC.png │ ├── SAC.pth │ ├── SACtrick.png │ └── SACtrick.pth │ ├── BipedalWalkerHardcore-v3 │ ├── BipedalWalkerHardcore-v3.gif │ ├── SAC10000.png │ ├── SAC10000.pth │ ├── SAC3000.png │ └── SAC3000.pth │ ├── Snake-v0 │ ├── A2C.gif │ ├── A2C.png │ ├── A2C.pth │ ├── A2C.yaml │ ├── A2C_TARGET.png │ ├── A2C_TARGET.pth │ ├── AC.png │ ├── AC.pth │ ├── AC_TARGET.png │ ├── AC_TARGET.pth │ ├── DDQN.gif │ ├── DDQN.png │ ├── DDQN.pth │ ├── DDRQN.png │ ├── DDRQN.pth │ ├── DQN.png │ ├── DQN.pth │ ├── DRQN.png │ ├── DRQN.pth │ ├── PPO.gif │ ├── PPO.png │ ├── PPO.pth │ ├── REINFORCE.png │ ├── REINFORCE.pth │ ├── REINFORCE_baseline.gif │ ├── REINFORCE_baseline.png │ └── REINFORCE_baseline.pth │ └── Snake-v1 │ ├── A2C.gif │ ├── DDQN.gif │ ├── PPO.gif │ └── REINFORCE_baseline.gif ├── rl_utils.py ├── test_agent.py └── test_env.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Algorithm 3 | 4 | ## Implemented algorithm 5 | 6 | dqn, ddqn, drqn 7 | 8 | reinforce, reinforce with baseline 9 | 10 | ac, ac with target, a2c, a2c with target, a3c, sac 11 | 12 | ppo, ippo, multidiscrete action ppo 13 | 14 | ## To-be-implemented algorithm 15 | 16 | dueling dqn 17 | 18 | trpo 19 | 20 | ddpg 21 | 22 | 23 | # benchmark 24 | 25 | - custom env 26 | 27 | - Snake-0 28 | 29 | - Walker(BipedalWalker-v3 discrete version) 30 | 31 | 32 | - CartPole-v0 33 | 34 | - Pendulum-v1 35 | 36 | - BipedalWalker-v3 37 | 38 | - BipedalWalkerHardcore-v3 39 | 40 | 41 | -------------------------------------------------------------------------------- /ac/CartPole-v0/10core_actor_loss_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/10core_actor_loss_curve.png -------------------------------------------------------------------------------- /ac/CartPole-v0/10core_critic_loss_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/10core_critic_loss_curve.png -------------------------------------------------------------------------------- /ac/CartPole-v0/10core_reward_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/10core_reward_curve.png -------------------------------------------------------------------------------- /ac/CartPole-v0/1core_reward_curve.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/1core_reward_curve.png -------------------------------------------------------------------------------- /ac/CartPole-v0/2core_reward_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/2core_reward_curve.png -------------------------------------------------------------------------------- /ac/CartPole-v0/3core_reward_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/3core_reward_curve.png -------------------------------------------------------------------------------- /ac/CartPole-v0/A2C.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/A2C.gif -------------------------------------------------------------------------------- /ac/CartPole-v0/A2C.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/A2C.pth -------------------------------------------------------------------------------- /ac/CartPole-v0/A3C.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/A3C.gif -------------------------------------------------------------------------------- /ac/CartPole-v0/A3C.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/A3C.pth -------------------------------------------------------------------------------- /ac/CartPole-v0/AC.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/AC.gif -------------------------------------------------------------------------------- /ac/CartPole-v0/AC.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/AC.pth -------------------------------------------------------------------------------- /ac/CartPole-v0/actor_loss_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/actor_loss_curve.png -------------------------------------------------------------------------------- /ac/CartPole-v0/critic_loss_curve.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/critic_loss_curve.png -------------------------------------------------------------------------------- /ac/CartPole-v0/reward_curve(A2C).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/reward_curve(A2C).png -------------------------------------------------------------------------------- /ac/CartPole-v0/reward_curve(A3C).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/reward_curve(A3C).png -------------------------------------------------------------------------------- /ac/CartPole-v0/reward_curve(AC).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ac/CartPole-v0/reward_curve(AC).png -------------------------------------------------------------------------------- /ac/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This directory implements the Actor-Critic (AC), the Advantage Actor-Critic (A2C) and the Asynchronous Advantage Actor-Critic (A3C) algorithms in the simple CartPole-v0 environment. The implementation is not generic and has so far only been tested in CartPole-v0. 4 | 5 | # How to train 6 | 7 | Run ac.py, a2c.py or a3c.py. 8 | 9 | # How to test 10 | 11 | Run test_env.py. 12 | 13 | # Results 14 | 15 | The figures below show only the reward curves; you can find more in the "ac/CartPole-v0" directory.
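Each plot overlays the raw per-episode return with a sliding-window average. As a minimal sketch of how the smoothed curve is produced (mirroring the `plot_smooth_reward` helper in `ac.py`/`a2c.py`; the stand-in returns array below is only illustrative, and `window_size=100` is that helper's default):

```python
import numpy as np

# Stand-in for the per-episode returns collected during training.
returns = np.random.default_rng(0).uniform(0, 200, size=1000)

# Sliding-window mean that produces the "Smoothed Data Curve" in the plots,
# matching plot_smooth_reward's default window_size=100.
window_size = 100
smoothed = np.convolve(returns, np.ones(window_size) / window_size, mode='valid')
```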
16 | 17 | ![reward_curve(AC)](https://github.com/sunwuzhou03/rl_learning/blob/master/Actor-Critic-Algorithm-Learning/CartPole-v0/reward_curve(AC).png) 18 | 19 | ![reward_curve(A2C)](https://github.com/sunwuzhou03/rl_learning/blob/master/Actor-Critic-Algorithm-Learning/CartPole-v0/reward_curve(A2C).png) 20 | 21 | ![reward_curve(A3C)](https://github.com/sunwuzhou03/rl_learning/blob/master/Actor-Critic-Algorithm-Learning/CartPole-v0/reward_curve(A3C).png) -------------------------------------------------------------------------------- /ac/a2c.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | import yaml 13 | import os 14 | 15 | 16 | def plot_smooth_reward(rewards, 17 | directory="./", 18 | filename="smooth_reward_plot", 19 | y_label='Reward', 20 | x_label='Episode', 21 | window_size=100): 22 | 23 | # 创建目标目录(如果不存在) 24 | os.makedirs(directory, exist_ok=True) 25 | 26 | img_title = filename 27 | # 拼接文件名和扩展名 28 | filename = f"{filename}.png" 29 | 30 | # 构建完整的文件路径 31 | filepath = os.path.join(directory, filename) 32 | 33 | # 计算滑动窗口平均值 34 | smoothed_rewards = np.convolve(rewards, 35 | np.ones(window_size) / window_size, 36 | mode='valid') 37 | 38 | # 绘制原始奖励和平滑奖励曲线 39 | plt.plot(rewards, label='Raw Data Curve') 40 | plt.plot(smoothed_rewards, label='Smoothed Data Curve') 41 | 42 | # 设置图例、标题和轴标签 43 | plt.legend() 44 | plt.title(img_title) 45 | plt.xlabel(x_label) 46 | plt.ylabel(y_label) 47 | 48 | # 保存图像 49 | plt.savefig(filepath) 50 | 51 | # 关闭图像 52 | plt.close() 53 | 54 | 55 | class CriticNet(nn.Module): 56 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 57 | super().__init__() 58 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 59 | self.fc2 = torch.nn.Linear(hidden_dim, 1) 60 | 61 | def forward(self, state): 62 | hidden_state = F.relu(self.fc1(state)) 63 | acion_value = self.fc2(hidden_state) 64 | return acion_value 65 | 66 | 67 | class ActorNet(nn.Module): 68 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 69 | super().__init__() 70 | self.fc1 = nn.Linear(state_dim, hidden_dim) 71 | self.fc2 = nn.Linear(hidden_dim, action_dim) 72 | 73 | def forward(self, state): 74 | 75 | hidden_state = F.relu(self.fc1(state)) 76 | # print(self.fc2(hidden_state)) 77 | probs = F.softmax(self.fc2(hidden_state), dim=1) 78 | return probs 79 | 80 | 81 | class A2C: 82 | def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, 83 | gamma, device) -> None: 84 | self.actor = ActorNet(state_dim, hidden_dim, action_dim).to(device) 85 | self.critic = CriticNet(state_dim, hidden_dim, action_dim).to(device) 86 | self.gamma = gamma 87 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 88 | lr=actor_lr) 89 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 90 | lr=critic_lr) 91 | self.device = device 92 | 93 | def save_model(self, save_path='./', filename='model'): 94 | model = {'actor': self.actor, 'critic': self.critic} 95 | torch.save(model, f"{save_path}\\{filename}.pth") 96 | 97 | def load_model(self, load_path): 98 | model = torch.load(load_path) 99 | self.actor = model['actor'] 100 | self.critic = model['critic'] 101 | 102 | def take_action(self, state): 103 | state = torch.tensor([state], dtype=torch.float).to(self.device) 104 | probs = 
self.actor(state) 105 | action_dist = torch.distributions.Categorical(probs) 106 | action = action_dist.sample() 107 | return action.item() 108 | 109 | def evaluate(self, state): 110 | state = torch.tensor([state], dtype=torch.float).to(self.device) 111 | probs = self.actor(state) 112 | action = torch.argmax(probs, dim=1) 113 | return action.item() 114 | 115 | def update(self, transition_dict): 116 | 117 | rewards = torch.tensor(transition_dict['rewards'], 118 | dtype=torch.float).view(-1, 1).to(self.device) 119 | states = torch.tensor(transition_dict['states'], 120 | dtype=torch.float).to(self.device) 121 | actions = torch.tensor(transition_dict['actions']).view(-1, 1).to( 122 | self.device) 123 | next_states = torch.tensor(transition_dict['next_states'], 124 | dtype=torch.float).to(self.device) 125 | next_actions = torch.tensor(transition_dict['next_actions']).view( 126 | -1, 1).to(self.device) 127 | dones = torch.tensor(transition_dict['dones'], 128 | dtype=torch.float).to(self.device).view(-1, 1) 129 | 130 | v_now = self.critic(states).view(-1, 1) 131 | 132 | v_next = self.critic(next_states).view(-1, 1) 133 | 134 | y_now = (self.gamma * v_next * (1 - dones) + rewards).view(-1, 1) 135 | td_delta = y_now - v_now 136 | log_prob = torch.log(self.actor(states).gather(1, actions)) 137 | 138 | actor_loss = torch.mean(-log_prob * td_delta.detach()) 139 | critic_loss = torch.mean(F.mse_loss(y_now.detach(), v_now)) 140 | 141 | self.actor_optimizer.zero_grad() 142 | self.critic_optimizer.zero_grad() 143 | 144 | actor_loss.backward() 145 | critic_loss.backward() 146 | 147 | self.actor_optimizer.step() 148 | self.critic_optimizer.step() 149 | 150 | 151 | if __name__ == "__main__": 152 | 153 | gamma = 0.99 154 | algorithm_name = "A2C" 155 | num_episodes = 1000 156 | actor_lr = 1e-3 157 | critic_lr = 1e-2 158 | env_name = 'CartPole-v0' 159 | hidden_dim = 128 160 | 161 | #选择设备 162 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 163 | 164 | # 注册环境 165 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 166 | 167 | env = gym.make(env_name) 168 | 169 | random.seed(0) 170 | np.random.seed(0) 171 | env.seed(0) 172 | torch.manual_seed(0) 173 | 174 | state_dim = env.observation_space.shape[0] 175 | 176 | action_dim = env.action_space.n 177 | agent = A2C(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, 178 | device) 179 | 180 | return_list = [] 181 | max_reward = 0 182 | start_time = time.time() 183 | flag = 0 184 | for i in range(10): 185 | with tqdm(total=int(num_episodes / 10), 186 | desc='Iteration %d' % i) as pbar: 187 | for i_episodes in range(int(num_episodes / 10)): 188 | episode_return = 0 189 | state = env.reset() 190 | done = False 191 | transition_dict = { 192 | 'states': [], 193 | 'actions': [], 194 | 'next_states': [], 195 | 'next_actions': [], 196 | 'rewards': [], 197 | 'dones': [] 198 | } 199 | 200 | while not done: 201 | action = agent.take_action(state) 202 | next_state, reward, done, _ = env.step(action) 203 | next_action = agent.take_action(next_state) 204 | 205 | transition_dict['states'].append(state) 206 | transition_dict['actions'].append(action) 207 | transition_dict['next_states'].append(next_state) 208 | transition_dict['next_actions'].append(next_action) 209 | transition_dict['rewards'].append(reward) 210 | transition_dict['dones'].append(done) 211 | 212 | state = next_state 213 | episode_return += reward 214 | agent.update(transition_dict) 215 | 216 | return_list.append(episode_return) 217 | plot_smooth_reward(return_list, 
env_name, "reward_curve(A2C)") 218 | if episode_return >= max_reward: 219 | max_reward = episode_return 220 | agent.save_model(env_name, algorithm_name) 221 | 222 | if episode_return >= 200 and flag == 0: 223 | flag = 1 224 | end_time = time.time() 225 | run_time = end_time - start_time 226 | 227 | # 打印程序运行时间 228 | print(f"A2C到达200需要:{run_time}秒") 229 | 230 | if (i_episodes + 1 % 10 == 0): 231 | pbar.set_postfix({ 232 | 'episode': 233 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 234 | 'return': 235 | '%.3f' % np.mean(return_list[-10:]) 236 | }) 237 | pbar.update(1) 238 | -------------------------------------------------------------------------------- /ac/a2c_target.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | from rl_utils import plot_smooth_reward 13 | 14 | 15 | class CriticNet(nn.Module): 16 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 17 | super().__init__() 18 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 19 | self.fc2 = torch.nn.Linear(hidden_dim, 1) 20 | 21 | def forward(self, state): 22 | hidden_state = F.relu(self.fc1(state)) 23 | acion_value = self.fc2(hidden_state) 24 | return acion_value 25 | 26 | 27 | class ActorNet(nn.Module): 28 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 29 | super().__init__() 30 | self.fc1 = nn.Linear(state_dim, hidden_dim) 31 | self.fc2 = nn.Linear(hidden_dim, action_dim) 32 | 33 | def forward(self, state): 34 | 35 | hidden_state = F.relu(self.fc1(state)) 36 | # print(self.fc2(hidden_state)) 37 | probs = F.softmax(self.fc2(hidden_state), dim=1) 38 | return probs 39 | 40 | 41 | class A2C_TARGET: 42 | def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, 43 | gamma, tau, device) -> None: 44 | self.actor = ActorNet(state_dim, hidden_dim, action_dim).to(device) 45 | self.critic = CriticNet(state_dim, hidden_dim, action_dim).to(device) 46 | self.critic_target = CriticNet(state_dim, hidden_dim, 47 | action_dim).to(device) 48 | self.critic_target.load_state_dict(self.critic.state_dict()) 49 | 50 | self.gamma = gamma 51 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 52 | lr=actor_lr) 53 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 54 | lr=critic_lr) 55 | self.device = device 56 | self.tau = tau 57 | 58 | def save_model(self, save_path='./', filename='model'): 59 | torch.save(self.actor.state_dict(), f"{save_path}\\{filename}.pth") 60 | 61 | def load_model(self, load_path): 62 | self.actor.load_state_dict(torch.load(load_path)) 63 | 64 | def take_action(self, state): 65 | state = torch.tensor([state], dtype=torch.float).to(self.device) 66 | probs = self.actor(state) 67 | action_dist = torch.distributions.Categorical(probs) 68 | action = action_dist.sample() 69 | return action.item() 70 | 71 | def soft_update(self, net, target_net): 72 | for param_target, param in zip(target_net.parameters(), 73 | net.parameters()): 74 | param_target.data.copy_(param_target.data * (1 - self.tau) + 75 | param.data * self.tau) 76 | 77 | def update(self, transition_dict): 78 | rewards = torch.tensor(transition_dict['rewards'], 79 | dtype=torch.float).view(-1, 1).to(self.device) 80 | states = torch.tensor(transition_dict['states'], 81 | dtype=torch.float).to(self.device) 82 
| actions = torch.tensor(transition_dict['actions']).view(-1, 1).to( 83 | self.device) 84 | next_states = torch.tensor(transition_dict['next_states'], 85 | dtype=torch.float).to(self.device) 86 | next_actions = torch.tensor(transition_dict['next_actions']).view( 87 | -1, 1).to(self.device) 88 | dones = torch.tensor(transition_dict['dones'], 89 | dtype=torch.float).to(self.device).view(-1, 1) 90 | 91 | v_now = self.critic(states).view(-1, 1) 92 | v_next = self.critic_target(next_states).view(-1, 1) 93 | y_now = (self.gamma * v_next * (1 - dones) + rewards).view(-1, 1) 94 | td_delta = y_now - v_now 95 | log_prob = torch.log(self.actor(states).gather(1, actions)) 96 | 97 | actor_loss = torch.mean(-log_prob * td_delta.detach()) 98 | critic_loss = torch.mean(F.mse_loss(y_now.detach(), v_now)) 99 | 100 | self.actor_optimizer.zero_grad() 101 | self.critic_optimizer.zero_grad() 102 | 103 | actor_loss.backward() 104 | critic_loss.backward() 105 | 106 | self.actor_optimizer.step() 107 | self.critic_optimizer.step() 108 | self.soft_update(self.critic, self.critic_target) 109 | 110 | 111 | if __name__ == "__main__": 112 | 113 | env_name = 'Snake-v0' #'CartPole-v0' 114 | 115 | # 注册环境 116 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 117 | 118 | env = gym.make(env_name) 119 | 120 | random.seed(0) 121 | np.random.seed(0) 122 | env.seed(0) 123 | torch.manual_seed(0) 124 | gamma = 0.98 125 | algorithm_name = "A2C_TARGET" 126 | num_episodes = 10000 127 | actor_lr = 1e-3 128 | critic_lr = 3e-3 129 | tau = 5e-3 130 | device = torch.device('cuda') 131 | state_dim = env.observation_space.shape[0] 132 | hidden_dim = 128 133 | action_dim = env.action_space.n 134 | 135 | agent = A2C_TARGET(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, 136 | gamma, tau, device) 137 | 138 | return_list = [] 139 | max_reward = 0 140 | for i in range(10): 141 | with tqdm(total=int(num_episodes / 10), 142 | desc='Iteration %d' % i) as pbar: 143 | for i_episodes in range(int(num_episodes / 10)): 144 | episode_return = 0 145 | state = env.reset() 146 | done = False 147 | 148 | transition_dict = { 149 | 'states': [], 150 | 'actions': [], 151 | 'next_states': [], 152 | 'next_actions': [], 153 | 'rewards': [], 154 | 'dones': [] 155 | } 156 | 157 | while not done: 158 | action = agent.take_action(state) 159 | next_state, reward, done, _ = env.step(action) 160 | next_action = agent.take_action(next_state) 161 | 162 | transition_dict['states'].append(state) 163 | transition_dict['actions'].append(action) 164 | transition_dict['next_states'].append(next_state) 165 | transition_dict['next_actions'].append(next_action) 166 | transition_dict['rewards'].append(reward) 167 | transition_dict['dones'].append(done) 168 | 169 | state = next_state 170 | episode_return += reward 171 | if i_episodes == int(num_episodes / 10) - 1: 172 | # time.sleep(0.1) 173 | env.render() 174 | 175 | agent.update(transition_dict) 176 | 177 | return_list.append(episode_return) 178 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 179 | 180 | if episode_return > max_reward: 181 | max_reward = episode_return 182 | agent.save_model(env_name, algorithm_name) 183 | if (i_episodes + 1 % 10 == 0): 184 | pbar.set_postfix({ 185 | 'episode': 186 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 187 | 'return': 188 | '%.3f' % np.mean(return_list[-10:]) 189 | }) 190 | pbar.update(1) 191 | -------------------------------------------------------------------------------- /ac/ac.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | import os 13 | 14 | 15 | def plot_smooth_reward(rewards, 16 | directory="./", 17 | filename="smooth_reward_plot", 18 | y_label='Reward', 19 | x_label='Episode', 20 | window_size=100): 21 | 22 | # 创建目标目录(如果不存在) 23 | os.makedirs(directory, exist_ok=True) 24 | 25 | img_title = filename 26 | # 拼接文件名和扩展名 27 | filename = f"{filename}.png" 28 | 29 | # 构建完整的文件路径 30 | filepath = os.path.join(directory, filename) 31 | 32 | # 计算滑动窗口平均值 33 | smoothed_rewards = np.convolve(rewards, 34 | np.ones(window_size) / window_size, 35 | mode='valid') 36 | 37 | # 绘制原始奖励和平滑奖励曲线 38 | plt.plot(rewards, label='Raw Data Curve') 39 | plt.plot(smoothed_rewards, label='Smoothed Data Curve') 40 | 41 | # 设置图例、标题和轴标签 42 | plt.legend() 43 | plt.title(img_title) 44 | plt.xlabel(x_label) 45 | plt.ylabel(y_label) 46 | 47 | # 保存图像 48 | plt.savefig(filepath) 49 | 50 | # 关闭图像 51 | plt.close() 52 | 53 | 54 | class CriticNet(nn.Module): 55 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 56 | super().__init__() 57 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 58 | self.fc2 = torch.nn.Linear(hidden_dim, action_dim) 59 | 60 | def forward(self, state): 61 | hidden_state = F.relu(self.fc1(state)) 62 | acion_value = self.fc2(hidden_state) 63 | return acion_value 64 | 65 | 66 | class ActorNet(nn.Module): 67 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 68 | super().__init__() 69 | self.fc1 = nn.Linear(state_dim, hidden_dim) 70 | self.fc2 = nn.Linear(hidden_dim, action_dim) 71 | 72 | def forward(self, state): 73 | 74 | hidden_state = F.relu(self.fc1(state)) 75 | # print(self.fc2(hidden_state)) 76 | probs = F.softmax(self.fc2(hidden_state), dim=1) 77 | return probs 78 | 79 | 80 | class AC: 81 | def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, 82 | gamma, device) -> None: 83 | self.actor = ActorNet(state_dim, hidden_dim, action_dim).to(device) 84 | self.critic = CriticNet(state_dim, hidden_dim, action_dim).to(device) 85 | self.gamma = gamma 86 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 87 | lr=actor_lr) 88 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 89 | lr=critic_lr) 90 | self.device = device 91 | 92 | def save_model(self, save_path='./', filename='model'): 93 | model = {'actor': self.actor, 'critic': self.critic} 94 | torch.save(model, f"{save_path}\\{filename}.pth") 95 | 96 | def load_model(self, load_path): 97 | model = torch.load(load_path) 98 | self.actor = model['actor'] 99 | self.critic = model['critic'] 100 | 101 | def take_action(self, state): 102 | state = torch.tensor([state], dtype=torch.float).to(self.device) 103 | probs = self.actor(state) 104 | action_dist = torch.distributions.Categorical(probs) 105 | action = action_dist.sample() 106 | return action.item() 107 | 108 | def evaluate(self, state): 109 | state = torch.tensor([state], dtype=torch.float).to(self.device) 110 | probs = self.actor(state) 111 | action = torch.argmax(probs, dim=1) 112 | return action.item() 113 | 114 | def update(self, transition_dict): 115 | rewards = torch.tensor(transition_dict['rewards'], 116 | dtype=torch.float).view(-1, 1).to(self.device) 117 | states = torch.tensor(transition_dict['states'], 118 
| dtype=torch.float).to(self.device) 119 | actions = torch.tensor(transition_dict['actions']).view(-1, 1).to( 120 | self.device) 121 | next_states = torch.tensor(transition_dict['next_states'], 122 | dtype=torch.float).to(self.device) 123 | next_actions = torch.tensor(transition_dict['next_actions']).view( 124 | -1, 1).to(self.device) 125 | dones = torch.tensor(transition_dict['dones'], 126 | dtype=torch.float).to(self.device).view(-1, 1) 127 | 128 | q_now = self.critic(states).gather(1, actions).view(-1, 1) 129 | q_next = self.critic(next_states).gather(1, next_actions).view(-1, 1) 130 | 131 | y_now = (self.gamma * q_next * (1 - dones) + rewards).view(-1, 1) 132 | td_delta = y_now - q_now 133 | log_prob = torch.log(self.actor(states).gather(1, actions)) 134 | actor_loss = torch.mean(-log_prob * td_delta.detach()) 135 | critic_loss = torch.mean(F.mse_loss(y_now.detach(), q_now)) 136 | 137 | self.actor_optimizer.zero_grad() 138 | self.critic_optimizer.zero_grad() 139 | 140 | actor_loss.backward() 141 | critic_loss.backward() 142 | 143 | self.actor_optimizer.step() 144 | self.critic_optimizer.step() 145 | 146 | 147 | if __name__ == "__main__": 148 | 149 | gamma = 0.98 150 | algorithm_name = "AC" 151 | num_episodes = 1000 152 | 153 | actor_lr = 5e-3 154 | critic_lr = 3e-4 155 | print(algorithm_name, actor_lr, critic_lr) 156 | device = torch.device('cuda') 157 | 158 | env_name = 'CartPole-v0' 159 | 160 | env = gym.make(env_name) 161 | 162 | random.seed(0) 163 | np.random.seed(0) 164 | env.seed(0) 165 | torch.manual_seed(0) 166 | 167 | state_dim = env.observation_space.shape[0] 168 | hidden_dim = 128 169 | action_dim = env.action_space.n 170 | 171 | agent = AC(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, 172 | device) 173 | 174 | return_list = [] 175 | max_reward = 0 176 | for i in range(10): 177 | with tqdm(total=int(num_episodes / 10), 178 | desc='Iteration %d' % i) as pbar: 179 | for i_episodes in range(int(num_episodes / 10)): 180 | episode_return = 0 181 | state = env.reset() 182 | done = False 183 | 184 | transition_dict = { 185 | 'states': [], 186 | 'actions': [], 187 | 'next_states': [], 188 | 'next_actions': [], 189 | 'rewards': [], 190 | 'dones': [] 191 | } 192 | while not done: 193 | action = agent.take_action(state) 194 | next_state, reward, done, _ = env.step(action) 195 | next_action = agent.take_action(next_state) 196 | 197 | transition_dict['states'].append(state) 198 | transition_dict['actions'].append(action) 199 | transition_dict['next_states'].append(next_state) 200 | transition_dict['next_actions'].append(next_action) 201 | transition_dict['rewards'].append(reward) 202 | transition_dict['dones'].append(done) 203 | 204 | state = next_state 205 | episode_return += reward 206 | agent.update(transition_dict) 207 | 208 | return_list.append(episode_return) 209 | plot_smooth_reward(return_list, env_name, "reward_curve(AC)") 210 | 211 | if episode_return >= max_reward: 212 | max_reward = episode_return 213 | agent.save_model(env_name, algorithm_name) 214 | if (i_episodes + 1 % 10 == 0): 215 | pbar.set_postfix({ 216 | 'episode': 217 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 218 | 'return': 219 | '%.3f' % np.mean(return_list[-10:]) 220 | }) 221 | pbar.update(1) 222 | -------------------------------------------------------------------------------- /ac/ac_target.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 
| import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | from rl_utils import plot_smooth_reward 13 | 14 | 15 | class CriticNet(nn.Module): 16 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 17 | super().__init__() 18 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 19 | self.fc2 = torch.nn.Linear(hidden_dim, action_dim) 20 | 21 | def forward(self, state): 22 | hidden_state = F.relu(self.fc1(state)) 23 | acion_value = self.fc2(hidden_state) 24 | return acion_value 25 | 26 | 27 | class ActorNet(nn.Module): 28 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 29 | super().__init__() 30 | self.fc1 = nn.Linear(state_dim, hidden_dim) 31 | self.fc2 = nn.Linear(hidden_dim, action_dim) 32 | 33 | def forward(self, state): 34 | 35 | hidden_state = F.relu(self.fc1(state)) 36 | # print(self.fc2(hidden_state)) 37 | probs = F.softmax(self.fc2(hidden_state), dim=1) 38 | return probs 39 | 40 | 41 | class AC_TARGET: 42 | def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, 43 | gamma, tau, device) -> None: 44 | self.actor = ActorNet(state_dim, hidden_dim, action_dim).to(device) 45 | self.critic = CriticNet(state_dim, hidden_dim, action_dim).to(device) 46 | self.critic_target = CriticNet(state_dim, hidden_dim, 47 | action_dim).to(device) 48 | self.critic_target.load_state_dict(self.critic.state_dict()) 49 | self.gamma = gamma 50 | self.tau = tau 51 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 52 | lr=actor_lr) 53 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 54 | lr=critic_lr) 55 | self.device = device 56 | 57 | def save_model(self, save_path='./', filename='model'): 58 | torch.save(self.actor.state_dict(), f"{save_path}\\{filename}.pth") 59 | 60 | def load_model(self, load_path): 61 | self.actor.load_state_dict(torch.load(load_path)) 62 | 63 | def take_action(self, state): 64 | state = torch.tensor([state], dtype=torch.float).to(self.device) 65 | probs = self.actor(state) 66 | action_dist = torch.distributions.Categorical(probs) 67 | action = action_dist.sample() 68 | return action.item() 69 | 70 | def soft_update(self, net, target_net): 71 | for param_target, param in zip(target_net.parameters(), 72 | net.parameters()): 73 | param_target.data.copy_(param_target.data * (1 - self.tau) + 74 | param.data * self.tau) 75 | 76 | def update(self, transition_dict): 77 | rewards = torch.tensor(transition_dict['rewards'], 78 | dtype=torch.float).view(-1, 1).to(self.device) 79 | states = torch.tensor(transition_dict['states'], 80 | dtype=torch.float).to(self.device) 81 | actions = torch.tensor(transition_dict['actions']).view(-1, 1).to( 82 | self.device) 83 | next_states = torch.tensor(transition_dict['next_states'], 84 | dtype=torch.float).to(self.device) 85 | next_actions = torch.tensor(transition_dict['next_actions']).view( 86 | -1, 1).to(self.device) 87 | dones = torch.tensor(transition_dict['dones'], 88 | dtype=torch.float).to(self.device).view(-1, 1) 89 | 90 | q_now = self.critic(states).gather(1, actions).view(-1, 1) 91 | q_next = self.critic_target(next_states).gather(1, next_actions).view( 92 | -1, 1) 93 | 94 | y_now = (self.gamma * q_next + rewards).view(-1, 1) 95 | # print(y_now) 96 | # print(q_next) 97 | td_delta = y_now - q_now 98 | log_prob = torch.log(self.actor(states).gather(1, actions)) 99 | # print(log_prob) 100 | # print(action) 101 | actor_loss = torch.mean(-log_prob * td_delta.detach()) 102 | critic_loss = 
torch.mean(F.mse_loss(y_now.detach(), q_now)) 103 | self.actor_optimizer.zero_grad() 104 | self.critic_optimizer.zero_grad() 105 | actor_loss.backward() 106 | critic_loss.backward() 107 | self.actor_optimizer.step() 108 | self.critic_optimizer.step() 109 | 110 | self.soft_update(self.critic, self.critic_target) 111 | 112 | 113 | if __name__ == "__main__": 114 | 115 | gamma = 0.99 116 | algorithm_name = "AC_TARGET" 117 | num_episodes = 5000 118 | 119 | actor_lr = 1e-3 120 | critic_lr = 3e-3 121 | print(algorithm_name, actor_lr, critic_lr) 122 | device = torch.device('cuda') 123 | 124 | env_name = 'Snake-v0' #'CartPole-v0' 125 | 126 | # 注册环境 127 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 128 | 129 | env = gym.make(env_name) 130 | 131 | random.seed(0) 132 | np.random.seed(0) 133 | env.seed(0) 134 | torch.manual_seed(0) 135 | 136 | state_dim = env.observation_space.shape[0] 137 | hidden_dim = 128 138 | action_dim = env.action_space.n 139 | tau = 5e-3 140 | agent = AC(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, 141 | tau, device) 142 | 143 | return_list = [] 144 | max_reward = 0 145 | for i in range(20): 146 | with tqdm(total=int(num_episodes / 10), 147 | desc='Iteration %d' % i) as pbar: 148 | for i_episodes in range(int(num_episodes / 10)): 149 | episode_return = 0 150 | state = env.reset() 151 | done = False 152 | 153 | transition_dict = { 154 | 'states': [], 155 | 'actions': [], 156 | 'next_states': [], 157 | 'next_actions': [], 158 | 'rewards': [], 159 | 'dones': [] 160 | } 161 | 162 | while not done: 163 | action = agent.take_action(state) 164 | next_state, reward, done, _ = env.step(action) 165 | next_action = agent.take_action(next_state) 166 | env.render() 167 | 168 | transition_dict['states'].append(state) 169 | transition_dict['actions'].append(action) 170 | transition_dict['next_states'].append(next_state) 171 | transition_dict['next_actions'].append(next_action) 172 | transition_dict['rewards'].append(reward) 173 | transition_dict['dones'].append(done) 174 | 175 | state = next_state 176 | episode_return += reward 177 | if i_episodes == int(num_episodes / 10) - 1: 178 | time.sleep(0.1) 179 | agent.update(transition_dict) 180 | return_list.append(episode_return) 181 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 182 | if episode_return > max_reward: 183 | max_reward = episode_return 184 | agent.save_model(env_name, algorithm_name) 185 | if (i_episodes + 1 % 10 == 0): 186 | pbar.set_postfix({ 187 | 'episode': 188 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 189 | 'return': 190 | '%.3f' % np.mean(return_list[-10:]) 191 | }) 192 | pbar.update(1) 193 | -------------------------------------------------------------------------------- /ac/sacwalker.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import random 3 | import gym 4 | import numpy as np 5 | from tqdm import tqdm 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.distributions import Normal 9 | import matplotlib.pyplot as plt 10 | import rl_utils 11 | 12 | 13 | class PolicyNetContinuous(torch.nn.Module): 14 | def __init__(self, state_dim, hidden_dim, action_dim, action_bound): 15 | super(PolicyNetContinuous, self).__init__() 16 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 17 | self.fc_mu = torch.nn.Linear(hidden_dim, action_dim) 18 | self.fc_std = torch.nn.Linear(hidden_dim, action_dim) 19 | self.action_bound = action_bound 20 | 21 | def forward(self, x): 22 | x = 
F.relu(self.fc1(x)) 23 | mu = self.fc_mu(x) 24 | std = F.softplus(self.fc_std(x)) 25 | dist = Normal(mu, std) 26 | normal_sample = dist.rsample() # rsample()是重参数化采样 27 | 28 | action = torch.tanh(normal_sample) 29 | # 计算tanh_normal分布的对数概率密度 30 | log_pi = dist.log_prob(normal_sample).sum(dim=1, keepdim=True) 31 | log_pi -= ( 32 | 2 * 33 | (np.log(2) - normal_sample - F.softplus(-2 * normal_sample))).sum( 34 | dim=1, keepdim=True) 35 | 36 | action = action * self.action_bound 37 | 38 | return action, log_pi 39 | 40 | 41 | class QValueNetContinuous(torch.nn.Module): 42 | def __init__(self, state_dim, hidden_dim, action_dim): 43 | super(QValueNetContinuous, self).__init__() 44 | self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim) 45 | self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim) 46 | self.fc_out = torch.nn.Linear(hidden_dim, 1) 47 | 48 | def forward(self, x, a): 49 | cat = torch.cat([x, a], dim=1) 50 | x = F.relu(self.fc1(cat)) 51 | x = F.relu(self.fc2(x)) 52 | return self.fc_out(x) 53 | 54 | 55 | class SACContinuous: 56 | ''' 处理连续动作的SAC算法 ''' 57 | def __init__(self, state_dim, hidden_dim, action_dim, action_bound, 58 | actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma, 59 | device): 60 | self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim, 61 | action_bound).to(device) # 策略网络 62 | self.critic_1 = QValueNetContinuous(state_dim, hidden_dim, 63 | action_dim).to(device) # 第一个Q网络 64 | self.critic_2 = QValueNetContinuous(state_dim, hidden_dim, 65 | action_dim).to(device) # 第二个Q网络 66 | self.target_critic_1 = QValueNetContinuous(state_dim, 67 | hidden_dim, action_dim).to( 68 | device) # 第一个目标Q网络 69 | self.target_critic_2 = QValueNetContinuous(state_dim, 70 | hidden_dim, action_dim).to( 71 | device) # 第二个目标Q网络 72 | # 令目标Q网络的初始参数和Q网络一样 73 | self.target_critic_1.load_state_dict(self.critic_1.state_dict()) 74 | self.target_critic_2.load_state_dict(self.critic_2.state_dict()) 75 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 76 | lr=actor_lr) 77 | self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), 78 | lr=critic_lr) 79 | self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), 80 | lr=critic_lr) 81 | # 使用alpha的log值,可以使训练结果比较稳定 82 | self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float) 83 | self.log_alpha.requires_grad = True # 可以对alpha求梯度 84 | self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], 85 | lr=alpha_lr) 86 | self.target_entropy = target_entropy # 目标熵的大小 87 | self.gamma = gamma 88 | self.tau = tau 89 | self.device = device 90 | 91 | def save_model(self, save_path='./', filename='model'): 92 | torch.save(self.actor.state_dict(), f"{save_path}\\{filename}.pth") 93 | 94 | def load_model(self, load_path): 95 | self.actor.load_state_dict(torch.load(load_path)) 96 | 97 | def take_action(self, state): 98 | state = torch.tensor([state], dtype=torch.float).to(self.device) 99 | action = self.actor(state)[0] 100 | # print(action.detach().cpu().numpy()) 101 | return action.detach().cpu().numpy().flatten() 102 | 103 | def calc_target(self, rewards, next_states, dones): # 计算目标Q值 104 | next_actions, log_prob = self.actor(next_states) 105 | entropy = -log_prob 106 | q1_value = self.target_critic_1(next_states, next_actions) 107 | q2_value = self.target_critic_2(next_states, next_actions) 108 | 109 | next_value = torch.min(q1_value, 110 | q2_value) + self.log_alpha.exp() * entropy 111 | td_target = rewards + self.gamma * next_value * (1 - dones) 112 | return td_target 113 | 114 | def soft_update(self, 
net, target_net): 115 | for param_target, param in zip(target_net.parameters(), 116 | net.parameters()): 117 | param_target.data.copy_(param_target.data * (1.0 - self.tau) + 118 | param.data * self.tau) 119 | 120 | def update(self, transition_dict): 121 | states = torch.tensor(transition_dict['states'], 122 | dtype=torch.float).to(self.device) 123 | actions = torch.tensor(transition_dict['actions'], 124 | dtype=torch.float).view(-1, 4).to(self.device) 125 | rewards = torch.tensor(transition_dict['rewards'], 126 | dtype=torch.float).view(-1, 1).to(self.device) 127 | next_states = torch.tensor(transition_dict['next_states'], 128 | dtype=torch.float).to(self.device) 129 | dones = torch.tensor(transition_dict['dones'], 130 | dtype=torch.float).view(-1, 1).to(self.device) 131 | 132 | # 更新两个Q网络 133 | td_target = self.calc_target(rewards, next_states, dones) 134 | # print(self.critic_1(states, actions)) 135 | # print(td_target.detach()) 136 | 137 | critic_1_loss = torch.mean( 138 | F.mse_loss(self.critic_1(states, actions), td_target.detach())) 139 | critic_2_loss = torch.mean( 140 | F.mse_loss(self.critic_2(states, actions), td_target.detach())) 141 | self.critic_1_optimizer.zero_grad() 142 | critic_1_loss.backward() 143 | self.critic_1_optimizer.step() 144 | self.critic_2_optimizer.zero_grad() 145 | critic_2_loss.backward() 146 | self.critic_2_optimizer.step() 147 | 148 | # 更新策略网络 149 | new_actions, log_prob = self.actor(states) 150 | entropy = -log_prob 151 | q1_value = self.critic_1(states, new_actions) 152 | q2_value = self.critic_2(states, new_actions) 153 | actor_loss = torch.mean(-self.log_alpha.exp() * entropy - 154 | torch.min(q1_value, q2_value)) 155 | self.actor_optimizer.zero_grad() 156 | actor_loss.backward() 157 | self.actor_optimizer.step() 158 | 159 | # 更新alpha值 160 | alpha_loss = torch.mean( 161 | (entropy - self.target_entropy).detach() * self.log_alpha.exp()) 162 | self.log_alpha_optimizer.zero_grad() 163 | alpha_loss.backward() 164 | self.log_alpha_optimizer.step() 165 | 166 | self.soft_update(self.critic_1, self.target_critic_1) 167 | self.soft_update(self.critic_2, self.target_critic_2) 168 | 169 | 170 | class ReplayBuffer: 171 | def __init__(self, capacity) -> None: 172 | self.buffer = collections.deque(maxlen=capacity) 173 | 174 | def add(self, state, action, reward, next_state, done): 175 | self.buffer.append([state, action, reward, next_state, done]) 176 | 177 | def sample(self, batch_size): 178 | transitions = random.sample(self.buffer, batch_size) 179 | states, actions, rewards, next_states, dones = zip(*transitions) 180 | return states, actions, rewards, next_states, dones 181 | 182 | def size(self): 183 | return len(self.buffer) 184 | 185 | 186 | if __name__ == "__main__": 187 | 188 | algorithm_name = 'SAC_correct' 189 | env_name = 'BipedalWalker-v3' 190 | env = gym.make(env_name) 191 | state_dim = env.observation_space.shape[0] 192 | action_dim = env.action_space.shape[0] 193 | action_bound = env.action_space.high[0] # 动作最大值 194 | 195 | random.seed(0) 196 | np.random.seed(0) 197 | env.seed(0) 198 | torch.manual_seed(0) 199 | 200 | actor_lr = 4e-4 201 | critic_lr = 4e-4 202 | alpha_lr = 4e-4 203 | num_episodes = 1000 204 | hidden_dim = 256 205 | gamma = 0.98 206 | tau = 0.01 # 软更新参数 207 | buffer_size = 100000 208 | minimal_size = 1000 209 | batch_size = 64 210 | max_step = 1000 211 | target_entropy = -env.action_space.shape[0] 212 | device = torch.device( 213 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 214 | 215 | replay_buffer = 
ReplayBuffer(buffer_size) 216 | agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound, 217 | actor_lr, critic_lr, alpha_lr, target_entropy, tau, 218 | gamma, device) 219 | # agent.load_model('BipedalWalker-v3\SACx.pth') 220 | return_list = [] 221 | max_reward = -float('inf') 222 | for i in range(10): 223 | with tqdm(total=int(num_episodes / 10), 224 | desc='Iteration %d' % i) as pbar: 225 | for i_episodes in range(int(num_episodes / 10)): 226 | episode_return = 0 227 | state = env.reset() 228 | done = False 229 | 230 | for _ in range(max_step): 231 | 232 | action = agent.take_action(state) 233 | 234 | next_state, reward, done, _ = env.step(action) 235 | episode_return += reward 236 | done_ = done 237 | if reward <= -100: 238 | done = True 239 | else: 240 | done = False 241 | 242 | if reward == -100: 243 | reward = -1 244 | 245 | if i_episodes == int(num_episodes / 10) - 1: 246 | env.render() 247 | 248 | replay_buffer.add(state, action, reward, next_state, done) 249 | state = next_state 250 | 251 | if replay_buffer.size() > minimal_size: 252 | b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample( 253 | batch_size) 254 | transition_dict = { 255 | 'states': b_s, 256 | 'actions': b_a, 257 | 'next_states': b_ns, 258 | 'rewards': b_r, 259 | 'dones': b_d 260 | } 261 | 262 | agent.update(transition_dict) 263 | 264 | if done_: 265 | break 266 | 267 | return_list.append(episode_return) 268 | 269 | rl_utils.plot_smooth_reward(return_list, 100, env_name, 270 | algorithm_name) 271 | 272 | if episode_return >= max_reward: 273 | max_reward = episode_return 274 | agent.save_model(env_name, algorithm_name) 275 | rl_utils.plot_smooth_reward(return_list, 100, env_name, 276 | algorithm_name) 277 | 278 | if (i_episodes + 1 % 10 == 0): 279 | pbar.set_postfix({ 280 | 'episode': 281 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 282 | 'max_reward': 283 | '%d' % (max_reward), 284 | 'return': 285 | '%.3f' % np.mean(return_list[-10:]) 286 | }) 287 | pbar.update(1) 288 | -------------------------------------------------------------------------------- /ac/test_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | import yaml 13 | import multiprocessing 14 | import os 15 | import argparse 16 | from a3c import get_args, A3C, ActorNet, CriticNet 17 | from a2c import A2C 18 | from ac import AC 19 | from rl_utils import save_video 20 | 21 | opt = get_args() 22 | env = gym.make(opt.env_name) 23 | 24 | random.seed(0) 25 | np.random.seed(0) 26 | torch.manual_seed(0) 27 | env.seed(0) 28 | 29 | state_dim = env.observation_space.shape[0] 30 | hidden_dim = 128 31 | action_dim = env.action_space.n 32 | 33 | AC_agent = AC(state_dim, hidden_dim, action_dim, opt.actor_lr, opt.critic_lr, 34 | opt.gamma, opt.device) 35 | AC_agent.load_model("CartPole-v0\AC.pth") 36 | AC_agent.actor.cpu() 37 | AC_agent.actor.cpu() 38 | 39 | # use gpu to train A2C algorithm so that we need to cpu() 40 | A2C_agent = A2C(state_dim, hidden_dim, action_dim, opt.actor_lr, opt.critic_lr, 41 | opt.gamma, opt.device) 42 | A2C_agent.load_model("CartPole-v0\A2C.pth") 43 | A2C_agent.actor.cpu() 44 | A2C_agent.actor.cpu() 45 | 46 | A3C_agent = A3C(state_dim, hidden_dim, action_dim, opt.actor_lr, opt.critic_lr, 47 | opt.gamma, opt.device) 48 | 
A3C_agent.load_model("CartPole-v0\A3C.pth") 49 | 50 | print("test AC") 51 | print("-----------------------------------------") 52 | state = env.reset() 53 | done = False 54 | frame_list = [] 55 | total_reward = 0 56 | step = 0 57 | max_step = 200 58 | while step < max_step: 59 | action = AC_agent.evaluate(state) # greedy action from the trained AC policy 60 | state, reward, done, info = env.step(action) 61 | rendered_image = env.render('rgb_array') 62 | frame_list.append(rendered_image) 63 | total_reward += reward 64 | step += 1 65 | 66 | save_video(frame_list, "CartPole-v0", "AC", "gif") 67 | print(total_reward) 68 | 69 | print("test A2C") 70 | print("-----------------------------------------") 71 | state = env.reset() 72 | done = False 73 | frame_list = [] 74 | total_reward = 0 75 | step = 0 76 | max_step = 200 77 | while step < max_step: 78 | action = A2C_agent.evaluate(state) # greedy action from the trained A2C policy 79 | state, reward, done, info = env.step(action) 80 | rendered_image = env.render('rgb_array') 81 | frame_list.append(rendered_image) 82 | total_reward += reward 83 | step += 1 84 | 85 | save_video(frame_list, "CartPole-v0", "A2C", "gif") 86 | print(total_reward) 87 | 88 | print("test A3C") 89 | print("-----------------------------------------") 90 | state = env.reset() 91 | done = False 92 | frame_list = [] 93 | total_reward = 0 94 | step = 0 95 | max_step = 200 96 | while step < max_step: 97 | action = A3C_agent.evaluate(state) # greedy action from the trained A3C policy 98 | state, reward, done, info = env.step(action) 99 | rendered_image = env.render('rgb_array') 100 | frame_list.append(rendered_image) 101 | total_reward += reward 102 | step += 1 103 | 104 | save_video(frame_list, "CartPole-v0", "A3C", "gif") 105 | print(total_reward) 106 | -------------------------------------------------------------------------------- /custom_env/README.md: -------------------------------------------------------------------------------- 1 | # rl_learning 2 | 3 | ## Snake-v0 4 | ### state space: 6 5 | (fx-x, fy-y, left, right, up, down) 6 | 7 | (x, y) is the position of the snake head. 8 | 9 | (fx, fy) is the position of the food. 10 | 11 | left indicates whether the left side of the snake’s head is a boundary or part of the snake’s body. left can only be 0 or 1. 12 | 13 | ... 14 | 15 | ### action space: 4 16 | The snake can move in four directions: up, down, left, and right. 17 | 18 | ### Algorithm 19 | #### Implemented algorithm 20 | dqn, ddqn, drqn 21 | 22 | reinforce, reinforce with baseline 23 | 24 | ac, ac with target, a2c, a2c with target 25 | 26 | ppo 27 | 28 | #### To-be-implemented algorithm 29 | dueling dqn 30 | 31 | a3c 32 | 33 | trpo 34 | 35 | ddpg 36 | 37 | sac 38 | 39 | #### Successful results 40 | 41 | !['ppo 20*20'](https://github.com/sunwuzhou03/rl_learning/blob/master/gif/Snake-v1/PPO.gif) 42 | 43 | You can find more successful results in the gif folder. 44 | 45 | ## CartPole-v0 46 | 47 | You can just modify the env_name from 'Snake-v0' to 'CartPole-v0' and adjust the hyper-parameters to train in this environment. 48 | 49 | The programs include: ac, a2c, ac_target, a2c_target, ddqn, ddrqn, dqn, drqn, ppo, reinforce, reinforce_baseline. 50 | 51 | ## Pendulum-v1 52 | 53 | You can run the ppoconPendulum.py or ppodisPendulum.py program to train an agent in this environment. 54 | 55 | ## BipedalWalker-v3 and BipedalWalkerHardcore-v3 56 | 57 | You can just modify the env_name from 'BipedalWalker-v3' to 'BipedalWalkerHardcore-v3' in the program sacwalker.py so that you can train the agent in either environment.
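As a minimal sketch of that switch (mirroring the environment setup lines at the top of sacwalker.py, with no other changes assumed), only the environment name changes while the SAC training loop stays the same:

```python
import gym

# Pick the benchmark: the regular walker or the hardcore variant.
env_name = 'BipedalWalker-v3'
# env_name = 'BipedalWalkerHardcore-v3'

env = gym.make(env_name)
state_dim = env.observation_space.shape[0]   # 24-dimensional observation
action_dim = env.action_space.shape[0]       # 4 continuous joint torques
action_bound = env.action_space.high[0]      # actions are bounded to [-1, 1]
```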
58 | 59 | ### The results 60 | 61 | #### BipedalWalker-v3 62 | 63 | ![BipedalWalker-v3](https://github.com/sunwuzhou03/rl_learning/blob/master/BipedalWalker-v3/BipedalWalker-v3.gif) 64 | 65 | #### BipedalWalkerHardcore-v3 66 | 67 | ![BipedalWalkerHardcore-v3](https://github.com/sunwuzhou03/rl_learning/blob/master/BipedalWalkerHardcore-v3/BipedalWalkerHardcore-v3.gif) 68 | 69 | -------------------------------------------------------------------------------- /custom_env/snake_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import random 4 | import pygame 5 | import sys 6 | 7 | 8 | class SnakeEnv(gym.Env): 9 | def __init__(self, width=10, height=10): 10 | super(SnakeEnv, self).__init__() 11 | 12 | self.width = width 13 | self.height = height 14 | self.grid = np.zeros((height, width)) 15 | self.snake = [(0, 0)] 16 | self.food = self.generate_food() 17 | self.direction = "right" 18 | self.time = 1 19 | self.cell_size = 40 20 | self.window_size = (self.width * self.cell_size, 21 | self.height * self.cell_size) 22 | # 定义颜色 23 | self.color_bg = (255, 255, 255) 24 | self.color_head = (0, 120, 120) 25 | self.color_snake = (0, 255, 0) 26 | self.color_food = (255, 0, 0) 27 | 28 | # 定义贪吃蛇环境的观测空间和行动空间 29 | inf = self.width + self.height 30 | low = np.array([-inf, -1.0, 0, 0, 0, 0]) # 连续状态空间的最小值 31 | high = np.array([inf, 1.0, 1.0, 1.0, 1.0, 1.0]) # 连续状态空间的最大值 32 | continuous_space = gym.spaces.Box(low, high, shape=(6, ), dtype=float) 33 | self.observation_space = continuous_space 34 | 35 | self.action_space = gym.spaces.Discrete(4) 36 | 37 | def generate_food(self): 38 | while True: 39 | x = random.randint(0, self.width - 1) 40 | y = random.randint(0, self.height - 1) 41 | if (x, y) not in self.snake: 42 | return (x, y) 43 | 44 | def reset(self): 45 | self.grid = np.zeros((self.height, self.width)) 46 | self.snake = [(0, 0)] 47 | self.food = self.generate_food() 48 | self.time = 1 49 | self.direction = "right" 50 | self.update_grid() 51 | 52 | return self.get_state() 53 | 54 | def get_state(self): 55 | state = [] 56 | x, y = self.snake[0] 57 | fx, fy = self.food 58 | xbase = self.width 59 | ybase = self.height 60 | x_norm, y_norm = (fx - x) / xbase, (fy - y) / ybase 61 | state.append(fx - x) 62 | state.append(fy - y) 63 | for gx, gy in [(0, 1), (1, 0), (0, -1), (-1, 0)]: 64 | dx, dy = x + gx, y + gy 65 | if dx < 0 or dy < 0 or dx >= self.width or dy >= self.height or ( 66 | dx, dy) in self.snake: 67 | state.append(0) 68 | continue 69 | else: 70 | state.append(1) #四个方向可以走 71 | return np.array(state) 72 | 73 | def update_grid(self): 74 | self.grid = np.zeros((self.height, self.width)) 75 | x, y = self.snake[0] 76 | self.grid[y, x] = 1 77 | for x, y in self.snake[1:]: 78 | self.grid[y, x] = 2 79 | fx, fy = self.food 80 | self.grid[fy, fx] = 3 81 | 82 | def step(self, action): 83 | x, y = self.snake[0] 84 | if action == 0: # 左 85 | if self.direction != "right": 86 | self.direction = "left" 87 | elif action == 1: # 上 88 | if self.direction != "down": 89 | self.direction = "up" 90 | elif action == 2: # 右 91 | if self.direction != "left": 92 | self.direction = "right" 93 | elif action == 3: # 下 94 | if self.direction != "up": 95 | self.direction = "down" 96 | 97 | if self.direction == "left": 98 | x -= 1 99 | elif self.direction == "up": 100 | y -= 1 101 | elif self.direction == "right": 102 | x += 1 103 | elif self.direction == "down": 104 | y += 1 105 | 106 | if x < 0 or x >= self.width or y < 0 or y >= self.height or ( 107 | 
x, y) in self.snake: 108 | reward = -0.5 - 1 / (len(self.snake)) 109 | self.time = 1 110 | done = True 111 | elif (x, y) == self.food: 112 | reward = 4 + len(self.snake) * 0.1 113 | self.time = 1 114 | self.snake.insert(0, (x, y)) 115 | self.food = self.generate_food() 116 | self.update_grid() 117 | done = False 118 | else: 119 | fx, fy = self.food 120 | d = (abs(x - fx) + abs(y - fy)) 121 | reward = 0.1 * (2 - d) / (self.time) 122 | self.snake.insert(0, (x, y)) 123 | self.snake.pop() 124 | self.update_grid() 125 | done = False 126 | self.time += 1 127 | return self.get_state(), reward, done, {} 128 | 129 | def render(self, mode='human'): 130 | # 初始化pygame窗口 131 | pygame.init() 132 | pygame.font.init() 133 | self.font = pygame.font.Font(None, 30) 134 | self.window = pygame.display.set_mode(self.window_size) 135 | 136 | pygame.display.set_caption("Snake Game") 137 | 138 | if mode == 'rgb_array': 139 | surface = pygame.Surface( 140 | (self.width * self.cell_size, self.height * self.cell_size)) 141 | self.window = surface 142 | 143 | self.window.fill(self.color_bg) 144 | snake_length_text = self.font.render("Length: " + str(len(self.snake)), 145 | True, (0, 25, 25)) 146 | self.window.blit(snake_length_text, (0, 0)) 147 | 148 | for event in pygame.event.get(): 149 | if event.type == pygame.QUIT: 150 | pygame.quit() 151 | sys.exit() 152 | 153 | for y in range(self.height): 154 | for x in range(self.width): 155 | cell_value = self.grid[y, x] 156 | cell_rect = pygame.Rect(x * self.cell_size, y * self.cell_size, 157 | self.cell_size, self.cell_size) 158 | if cell_value == 1: # 贪吃蛇身体 159 | pygame.draw.rect(self.window, self.color_head, cell_rect) 160 | elif cell_value == 2: # 贪吃蛇身体 161 | pygame.draw.rect(self.window, self.color_snake, cell_rect) 162 | elif cell_value == 3: # 食物 163 | pygame.draw.rect(self.window, self.color_food, cell_rect) 164 | 165 | pygame.display.flip() 166 | 167 | if mode == 'rgb_array': 168 | image_array = pygame.surfarray.array3d(self.window) 169 | return image_array 170 | -------------------------------------------------------------------------------- /ddpg/ddpg.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ddpg/ddpg.py -------------------------------------------------------------------------------- /ddpg/td3.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ddpg/td3.py -------------------------------------------------------------------------------- /dqn/ddqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | from rl_utils import plot_smooth_reward 12 | 13 | 14 | class Qnet(torch.nn.Module): 15 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 16 | super().__init__() 17 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 18 | self.fc2 = torch.nn.Linear(hidden_dim, action_dim) 19 | 20 | def forward(self, state): 21 | hidden_state = F.relu(self.fc1(state)) 22 | acion_value = self.fc2(hidden_state) 23 | return acion_value 24 | 25 | 26 | class DDQN: 27 | def __init__(self, state_dim, hidden_dim, 
action_dim, learning_rate, gamma, 28 | epsilon, update_target, device) -> None: 29 | self.Q = Qnet(state_dim, hidden_dim, action_dim).to(device) 30 | self.Q_target = Qnet(state_dim, hidden_dim, action_dim).to(device) 31 | self.Q_target.load_state_dict(self.Q.state_dict()) 32 | self.optimizer = torch.optim.Adam(self.Q.parameters(), 33 | lr=learning_rate) 34 | self.action_dim = action_dim 35 | self.gamma = gamma 36 | self.epsilon = epsilon 37 | self.device = device 38 | self.count = 1 39 | self.update_target = update_target 40 | 41 | def save_model(self, save_path='./', filename='model'): 42 | torch.save(self.Q.state_dict(), f"{save_path}\\{filename}.pth") 43 | 44 | def load_model(self, load_path): 45 | self.Q.load_state_dict(torch.load(load_path)) 46 | self.Q_target.load_state_dict(torch.load(load_path)) 47 | 48 | def take_action(self, state): 49 | state = torch.tensor([state], dtype=torch.float).to(self.device) 50 | Q_value = self.Q(state) 51 | if np.random.random() < self.epsilon: 52 | action = np.random.randint(self.action_dim) 53 | else: 54 | action = torch.argmax(Q_value).item() 55 | return action 56 | 57 | def update(self, states, actions, rewards, next_states, dones): 58 | states_cuda = torch.tensor(states, dtype=torch.float).to(self.device) 59 | actions_cuda = torch.tensor(actions, 60 | dtype=torch.float).view(-1, 61 | 1).to(self.device) 62 | rewards_cuda = torch.tensor(rewards, 63 | dtype=torch.float).to(self.device).view( 64 | -1, 1) 65 | next_states_cuda = torch.tensor(next_states, 66 | dtype=torch.float).to(self.device) 67 | dones_cuda = torch.tensor(dones, 68 | dtype=torch.float).to(self.device).view( 69 | -1, 1) 70 | q_now = self.Q(states_cuda).gather(1, actions_cuda.long()).view(-1, 1) 71 | 72 | actions_star = self.Q(next_states_cuda).max(1)[1].view(-1, 1) 73 | 74 | q_next = self.Q_target(next_states_cuda).gather(1, actions_star).view( 75 | -1, 1) 76 | 77 | y_now = (rewards_cuda + self.gamma * q_next * (1 - dones_cuda)).view( 78 | -1, 1) 79 | td_error = q_now - y_now 80 | 81 | loss = torch.mean(F.mse_loss(q_now, y_now)) 82 | self.optimizer.zero_grad() 83 | loss.backward() 84 | self.optimizer.step() 85 | 86 | if self.count % self.update_target == 0: 87 | self.Q_target.load_state_dict(self.Q.state_dict()) 88 | 89 | self.count += 1 90 | 91 | 92 | class ReplayBuffer: 93 | def __init__(self, capacity) -> None: 94 | self.buffer = collections.deque(maxlen=capacity) 95 | 96 | def add(self, state, action, reward, next_state, done): 97 | self.buffer.append([state, action, reward, next_state, done]) 98 | 99 | def sample(self, batch_size): 100 | transitions = random.sample(self.buffer, batch_size) 101 | states, actions, rewards, next_states, dones = zip(*transitions) 102 | return states, actions, rewards, next_states, dones 103 | 104 | def size(self): 105 | return len(self.buffer) 106 | 107 | 108 | if __name__ == "__main__": 109 | algorithm_name = "demo" 110 | gamma = 0.99 111 | 112 | num_episodes = 5000 113 | buffersize = 10000 114 | minmal_size = 500 115 | batch_size = 64 116 | epsilon = 0.01 117 | learning_rate = 2e-3 118 | device = torch.device('cuda') 119 | 120 | env_name = 'CartPole-v0' 121 | # 注册环境 122 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 123 | 124 | env = gym.make(env_name) 125 | random.seed(0) 126 | np.random.seed(0) 127 | env.seed(0) 128 | torch.manual_seed(0) 129 | replay_buffer = ReplayBuffer(buffersize) 130 | 131 | state_dim = env.observation_space.shape[0] 132 | hidden_dim = 128 133 | action_dim = env.action_space.n 134 | update_target = 100 135 | 
agent = DDQN(state_dim, hidden_dim, action_dim, learning_rate, gamma, 136 | epsilon, update_target, device) 137 | 138 | return_list = [] 139 | max_reward = 0 140 | for i in range(20): 141 | with tqdm(total=int(num_episodes / 10), 142 | desc='Iteration %d' % i) as pbar: 143 | for i_episodes in range(int(num_episodes / 10)): 144 | episode_return = 0 145 | state = env.reset() 146 | done = False 147 | while not done: 148 | action = agent.take_action(state) 149 | next_state, reward, done, _ = env.step(action) 150 | 151 | if i_episodes == int(num_episodes / 10) - 1: 152 | env.render() 153 | time.sleep(0.1) 154 | replay_buffer.add(state, action, reward, next_state, done) 155 | state = next_state 156 | episode_return += reward 157 | if replay_buffer.size() > minmal_size: 158 | b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample( 159 | batch_size) 160 | agent.update(b_s, b_a, b_r, b_ns, b_d) 161 | return_list.append(episode_return) 162 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 163 | if episode_return > max_reward: 164 | max_reward = episode_return 165 | agent.save_model(env_name, algorithm_name) 166 | if (i_episodes + 1 % 10 == 0): 167 | pbar.set_postfix({ 168 | 'episode': 169 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 170 | 'return': 171 | '%.3f' % np.mean(return_list[-10:]) 172 | }) 173 | pbar.update(1) 174 | -------------------------------------------------------------------------------- /dqn/ddrqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | from rl_utils import plot_smooth_reward 12 | 13 | 14 | class Qnet(torch.nn.Module): 15 | def __init__(self, 16 | obs_dim, 17 | hidden_dim, 18 | action_dim, 19 | device, 20 | num_layers=1, 21 | batch_first=True) -> None: 22 | super().__init__() 23 | self.action_dim = action_dim 24 | self.hidden_dim = hidden_dim 25 | self.device = device 26 | self.num_layers = num_layers 27 | self.batch_first = batch_first 28 | self.fc1 = torch.nn.Linear(obs_dim, hidden_dim) 29 | self.fc2 = torch.nn.Linear(hidden_dim, action_dim) 30 | self.lstm = torch.nn.LSTM(hidden_dim, 31 | hidden_dim, 32 | self.num_layers, 33 | batch_first=True) 34 | 35 | def forward(self, state, max_seq=1, batch_size=1): 36 | h0 = torch.zeros(self.num_layers, state.size(0), 37 | self.hidden_dim).to(device) 38 | c0 = torch.zeros(self.num_layers, state.size(0), 39 | self.hidden_dim).to(device) 40 | 41 | state = F.relu(self.fc1(state)).reshape(-1, max_seq, self.hidden_dim) 42 | state, _ = self.lstm(state, (h0, c0)) 43 | action_value = self.fc2(state).view(-1, self.action_dim) 44 | return action_value 45 | 46 | 47 | class DDRQN: 48 | def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, 49 | epsilon, update_target, device) -> None: 50 | self.Q = Qnet(state_dim, hidden_dim, action_dim, device).to(device) 51 | self.Q_target = Qnet(state_dim, hidden_dim, action_dim, 52 | device).to(device) 53 | self.Q_target.load_state_dict(self.Q.state_dict()) 54 | self.optimizer = torch.optim.Adam(self.Q.parameters(), 55 | lr=learning_rate) 56 | self.action_dim = action_dim 57 | print(action_dim) 58 | self.gamma = gamma 59 | self.epsilon = epsilon 60 | self.device = device 61 | self.count = 1 62 | self.update_target = update_target 63 | 64 | def save_model(self, save_path='./', filename='model'): 65 | 
torch.save(self.Q.state_dict(), f"{save_path}\\{filename}.pth") 66 | 67 | def load_model(self, load_path): 68 | self.Q.load_state_dict(torch.load(load_path)) 69 | self.Q_target.load_state_dict(torch.load(load_path)) 70 | 71 | def take_action(self, state): 72 | state = torch.tensor([state], dtype=torch.float).to(self.device) 73 | Q_value = self.Q(state) 74 | if np.random.random() < self.epsilon: 75 | action = np.random.randint(self.action_dim) 76 | else: 77 | action = torch.argmax(Q_value).item() 78 | return action 79 | 80 | def update(self, 81 | states, 82 | actions, 83 | rewards, 84 | next_states, 85 | dones, 86 | max_seq=1, 87 | batch_size=1): 88 | states = torch.tensor(states, dtype=torch.float).to(self.device) 89 | actions = torch.tensor(actions, 90 | dtype=torch.float).view(-1, 1).to(self.device) 91 | rewards = torch.tensor(rewards, 92 | dtype=torch.float).to(self.device).view(-1, 1) 93 | next_states = torch.tensor(next_states, 94 | dtype=torch.float).to(self.device) 95 | dones = torch.tensor(dones, 96 | dtype=torch.float).to(self.device).view(-1, 1) 97 | q_now = self.Q.forward(states, max_seq, 98 | batch_size).gather(1, 99 | actions.long()).view(-1, 1) 100 | 101 | actions_star = self.Q.forward(next_states, max_seq, 102 | batch_size).max(1)[1].view(-1, 1) 103 | 104 | q_next = self.Q_target.forward(next_states, 105 | max_seq, batch_size).gather( 106 | 1, actions_star).view(-1, 1) 107 | 108 | y_now = (rewards + self.gamma * q_next * (1 - dones)).view(-1, 1) 109 | td_error = q_now - y_now 110 | 111 | loss = torch.mean(F.mse_loss(q_now, y_now)) 112 | self.optimizer.zero_grad() 113 | loss.backward() 114 | self.optimizer.step() 115 | 116 | if self.count % self.update_target == 0: 117 | self.Q_target.load_state_dict(self.Q.state_dict()) 118 | 119 | self.count += 1 120 | 121 | 122 | class Recurrent_Memory_ReplayBuffer: 123 | def __init__(self, capacity, max_seq) -> None: 124 | self.max_seq = max_seq 125 | self.buffer = collections.deque(maxlen=capacity) 126 | 127 | def add(self, state, action, reward, next_state, done): 128 | self.buffer.append([state, action, reward, next_state, done]) 129 | 130 | # def sample(self, batch_size): 131 | # transitions = random.sample(self.buffer, batch_size) 132 | # states, actions, rewards, next_states, dones = zip(*transitions) 133 | # return states, actions, rewards, next_states, dones 134 | 135 | def sample(self, batch_size): 136 | # sample episodic memory 137 | states, actions, rewards, next_states, dones = [], [], [], [], [] 138 | for _ in range(batch_size): 139 | finish = random.randint(self.max_seq, self.size() - 1) 140 | begin = finish - self.max_seq 141 | data = [] 142 | for idx in range(begin, finish): 143 | data.append(self.buffer[idx]) 144 | state, action, reward, next_state, done = zip(*data) 145 | states.append(np.vstack(state)) 146 | actions.append(action) 147 | rewards.append(reward) 148 | next_states.append(np.vstack(next_state)) 149 | dones.append(done) 150 | 151 | states = np.array(states) 152 | actions = np.array(actions) 153 | rewards = np.array(rewards) 154 | next_states = np.array(next_states) 155 | dones = np.array(dones) 156 | 157 | return states, actions, rewards, next_states, dones 158 | 159 | def size(self): 160 | return len(self.buffer) 161 | 162 | 163 | if __name__ == "__main__": 164 | algorithm_name = "DDRQN1" 165 | gamma = 0.99 166 | 167 | num_episodes = 5000 168 | buffersize = 10000 169 | minmal_size = 500 170 | batch_size = 64 171 | epsilon = 0.01 172 | learning_rate = 2e-3 173 | max_seq = 1 174 | device = torch.device('cuda') 
175 | 176 | env_name = 'Snake-v0' 177 | # 注册环境 178 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 179 | 180 | env = gym.make(env_name) 181 | random.seed(0) 182 | np.random.seed(0) 183 | env.seed(0) 184 | torch.manual_seed(0) 185 | replay_buffer = Recurrent_Memory_ReplayBuffer(buffersize, max_seq) 186 | 187 | state_dim = env.observation_space.shape[0] 188 | hidden_dim = 128 189 | action_dim = env.action_space.n 190 | update_target = 50 191 | agent = DDRQN(state_dim, hidden_dim, action_dim, learning_rate, gamma, 192 | epsilon, update_target, device) 193 | 194 | return_list = [] 195 | max_reward = 0 196 | for i in range(20): 197 | with tqdm(total=int(num_episodes / 10), 198 | desc='Iteration %d' % i) as pbar: 199 | for i_episodes in range(int(num_episodes / 10)): 200 | episode_return = 0 201 | state = env.reset() 202 | done = False 203 | while not done: 204 | action = agent.take_action(state) 205 | next_state, reward, done, _ = env.step(action) 206 | 207 | if i_episodes == int(num_episodes / 10) - 1: 208 | env.render() 209 | time.sleep(0.1) 210 | replay_buffer.add(state, action, reward, next_state, done) 211 | state = next_state 212 | episode_return += reward 213 | if replay_buffer.size() > minmal_size: 214 | b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample( 215 | batch_size) 216 | agent.update(b_s, b_a, b_r, b_ns, b_d, max_seq, 217 | batch_size) 218 | return_list.append(episode_return) 219 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 220 | if episode_return > max_reward: 221 | max_reward = episode_return 222 | agent.save_model(env_name, algorithm_name) 223 | if (i_episodes + 1 % 10 == 0): 224 | pbar.set_postfix({ 225 | 'episode': 226 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 227 | 'return': 228 | '%.3f' % np.mean(return_list[-10:]) 229 | }) 230 | pbar.update(1) 231 | -------------------------------------------------------------------------------- /dqn/dqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | from ddqn import plot_smooth_reward 12 | 13 | 14 | class Qnet(torch.nn.Module): 15 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 16 | super().__init__() 17 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 18 | self.fc2 = torch.nn.Linear(hidden_dim, action_dim) 19 | 20 | def forward(self, state): 21 | hidden_state = F.relu(self.fc1(state)) 22 | acion_value = self.fc2(hidden_state) 23 | return acion_value 24 | 25 | 26 | class DQN: 27 | def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, 28 | epsilon, device) -> None: 29 | self.Q = Qnet(state_dim, hidden_dim, action_dim).to(device) 30 | self.optimizer = torch.optim.Adam(self.Q.parameters(), 31 | lr=learning_rate) 32 | self.action_dim = action_dim 33 | self.gamma = gamma 34 | self.epsilon = epsilon 35 | self.device = device 36 | self.count = 1 37 | pass 38 | 39 | def save_model(self, save_path='./', filename='model'): 40 | torch.save(self.Q.state_dict(), f"{save_path}\\{filename}.pth") 41 | 42 | def load_model(self, load_path): 43 | self.Q.load_state_dict(torch.load(load_path)) 44 | 45 | def take_action(self, state): 46 | 47 | Q_value = self.Q(state) 48 | if np.random.random() < self.epsilon: 49 | action = np.random.randint(self.action_dim) 50 | else: 51 | action = 
torch.argmax(Q_value).item() 52 | return action 53 | 54 | def update(self, states, actions, rewards, next_states, dones): 55 | states_cuda = torch.tensor(states, dtype=torch.float).to(self.device) 56 | actions_cuda = torch.tensor(actions, 57 | dtype=torch.float).view(-1, 58 | 1).to(self.device) 59 | rewards_cuda = torch.tensor(rewards, 60 | dtype=torch.float).to(self.device).view( 61 | -1, 1) 62 | next_states_cuda = torch.tensor(next_states, 63 | dtype=torch.float).to(self.device) 64 | dones_cuda = torch.tensor(dones, 65 | dtype=torch.float).to(self.device).view( 66 | -1, 1) 67 | q_now = self.Q(states_cuda).gather(1, actions_cuda.long()).view(-1, 1) 68 | 69 | q_next = self.Q(next_states_cuda).max(1)[0].view(-1, 1) 70 | 71 | y_now = (rewards_cuda + self.gamma * q_next * (1 - dones_cuda)).view( 72 | -1, 1) 73 | td_error = q_now - y_now 74 | 75 | loss = torch.mean(F.mse_loss(q_now, y_now)) 76 | self.optimizer.zero_grad() 77 | loss.backward() 78 | self.optimizer.step() 79 | self.count += 1 80 | if self.count == 100: 81 | self.epsilon *= 0.9 82 | 83 | 84 | class ReplayBuffer: 85 | def __init__(self, capacity) -> None: 86 | self.buffer = collections.deque(maxlen=capacity) 87 | 88 | def add(self, state, action, reward, next_state, done): 89 | self.buffer.append([state, action, reward, next_state, done]) 90 | 91 | def sample(self, batch_size): 92 | transitions = random.sample(self.buffer, batch_size) 93 | states, actions, rewards, next_states, dones = zip(*transitions) 94 | return states, actions, rewards, next_states, dones 95 | 96 | def size(self): 97 | return len(self.buffer) 98 | 99 | 100 | gamma = 0.99 101 | 102 | num_episodes = 10000 103 | buffersize = 10000 104 | minmal_size = 500 105 | batch_size = 64 106 | epsilon = 0.01 107 | learning_rate = 2e-3 108 | device = torch.device('cuda') 109 | algorithm_name = 'DQN' 110 | env_name = 'Snake-v0' 111 | # 注册环境 112 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 113 | 114 | env = gym.make(env_name) 115 | random.seed(0) 116 | np.random.seed(0) 117 | env.seed(0) 118 | torch.manual_seed(0) 119 | replay_buffer = ReplayBuffer(buffersize) 120 | 121 | state_dim = env.observation_space.shape[0] 122 | hidden_dim = 128 123 | action_dim = env.action_space.n 124 | 125 | agent = DQN(state_dim, hidden_dim, action_dim, learning_rate, gamma, epsilon, 126 | device) 127 | 128 | return_list = [] 129 | max_reward = 0 130 | for i in range(10): 131 | with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar: 132 | for i_episodes in range(int(num_episodes / 10)): 133 | episode_return = 0 134 | state = env.reset() 135 | done = False 136 | while not done: 137 | action = agent.take_action( 138 | torch.tensor(state, dtype=torch.float).to(device)) 139 | next_state, reward, done, _ = env.step(action) 140 | 141 | if i_episodes == int(num_episodes / 10) - 1: 142 | env.render() 143 | time.sleep(0.1) 144 | replay_buffer.add(state, action, reward, next_state, done) 145 | state = next_state 146 | episode_return += reward 147 | if replay_buffer.size() > minmal_size: 148 | b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size) 149 | 150 | agent.update(b_s, b_a, b_r, b_ns, b_d) 151 | return_list.append(episode_return) 152 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 153 | if episode_return > max_reward: 154 | max_reward = episode_return 155 | agent.save_model(env_name, algorithm_name) 156 | if (i_episodes + 1 % 10 == 0): 157 | pbar.set_postfix({ 158 | 'episode': 159 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 160 | 
'return': 161 | '%.3f' % np.mean(return_list[-10:]) 162 | }) 163 | pbar.update(1) 164 | -------------------------------------------------------------------------------- /dqn/drqn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | from ddqn import plot_smooth_reward 12 | 13 | 14 | class Qnet(torch.nn.Module): 15 | def __init__(self, 16 | obs_dim, 17 | hidden_dim, 18 | action_dim, 19 | device, 20 | num_layers=1, 21 | batch_first=True) -> None: 22 | super().__init__() 23 | self.action_dim = action_dim 24 | self.hidden_dim = hidden_dim 25 | self.device = device 26 | self.num_layers = num_layers 27 | self.batch_first = batch_first 28 | self.fc1 = torch.nn.Linear(obs_dim, hidden_dim) 29 | self.fc2 = torch.nn.Linear(hidden_dim, action_dim) 30 | self.lstm = torch.nn.LSTM(hidden_dim, 31 | hidden_dim, 32 | self.num_layers, 33 | batch_first=True) 34 | 35 | def forward(self, state, max_seq=1, batch_size=1): 36 | h0 = torch.zeros(self.num_layers, state.size(0), 37 | self.hidden_dim).to(device) 38 | c0 = torch.zeros(self.num_layers, state.size(0), 39 | self.hidden_dim).to(device) 40 | 41 | state = F.relu(self.fc1(state)).reshape(-1, max_seq, self.hidden_dim) 42 | state, _ = self.lstm(state, (h0, c0)) 43 | action_value = self.fc2(state).view(-1, self.action_dim) 44 | return action_value 45 | 46 | 47 | class DRQN: 48 | def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, 49 | epsilon, device) -> None: 50 | self.Q = Qnet(state_dim, hidden_dim, action_dim, device).to(device) 51 | self.optimizer = torch.optim.Adam(self.Q.parameters(), 52 | lr=learning_rate) 53 | self.action_dim = action_dim 54 | self.gamma = gamma 55 | self.epsilon = epsilon 56 | self.device = device 57 | self.count = 1 58 | pass 59 | 60 | def save_model(self, save_path='./', filename='model'): 61 | torch.save(self.Q.state_dict(), f"{save_path}\\{filename}.pth") 62 | 63 | def load_model(self, load_path): 64 | self.Q.load_state_dict(torch.load(load_path)) 65 | 66 | def take_action(self, state): 67 | state = torch.tensor([state], dtype=torch.float).to(self.device) 68 | Q_value = self.Q(state) 69 | if np.random.random() < self.epsilon: 70 | action = np.random.randint(self.action_dim) 71 | else: 72 | action = torch.argmax(Q_value).item() 73 | return action 74 | 75 | def update(self, 76 | states, 77 | actions, 78 | rewards, 79 | next_states, 80 | dones, 81 | max_seq=1, 82 | batch_size=1): 83 | states = torch.tensor(states, dtype=torch.float).to(self.device) 84 | actions = torch.tensor(actions, 85 | dtype=torch.float).view(-1, 1).to(self.device) 86 | rewards = torch.tensor(rewards, 87 | dtype=torch.float).to(self.device).view(-1, 1) 88 | next_states = torch.tensor(next_states, 89 | dtype=torch.float).to(self.device) 90 | dones = torch.tensor(dones, 91 | dtype=torch.float).to(self.device).view(-1, 1) 92 | q_now = self.Q.forward(states, max_seq, 93 | batch_size).gather(1, 94 | actions.long()).view(-1, 1) 95 | 96 | q_next = self.Q.forward(next_states, max_seq, 97 | batch_size).max(1)[0].view(-1, 1) 98 | 99 | y_now = (rewards + self.gamma * q_next * (1 - dones)).view(-1, 1) 100 | td_error = q_now - y_now 101 | 102 | loss = torch.mean(F.mse_loss(q_now, y_now)) 103 | self.optimizer.zero_grad() 104 | loss.backward() 105 | self.optimizer.step() 106 | self.count += 1 107 
| if self.count == 100: 108 | self.epsilon *= 0.9 109 | 110 | 111 | class Recurrent_Memory_ReplayBuffer: 112 | def __init__(self, capacity) -> None: 113 | self.buffer = collections.deque(maxlen=capacity) 114 | 115 | def add(self, state, action, reward, next_state, done): 116 | self.buffer.append([state, action, reward, next_state, done]) 117 | 118 | # def sample(self, batch_size): 119 | # transitions = random.sample(self.buffer, batch_size) 120 | # states, actions, rewards, next_states, dones = zip(*transitions) 121 | # return states, actions, rewards, next_states, dones 122 | 123 | def sample(self, batch_size, max_seq): 124 | # sample episodic memory 125 | states, actions, rewards, next_states, dones = [], [], [], [], [] 126 | for _ in range(batch_size): 127 | finish = random.randint(max_seq, self.size() - 1) 128 | begin = finish - max_seq 129 | data = [] 130 | for idx in range(begin, finish): 131 | data.append(self.buffer[idx]) 132 | state, action, reward, next_state, done = zip(*data) 133 | states.append(np.vstack(state)) 134 | actions.append(action) 135 | rewards.append(reward) 136 | next_states.append(np.vstack(next_state)) 137 | dones.append(done) 138 | 139 | states = np.array(states) 140 | actions = np.array(actions) 141 | rewards = np.array(rewards) 142 | next_states = np.array(next_states) 143 | dones = np.array(dones) 144 | 145 | return states, actions, rewards, next_states, dones 146 | 147 | def size(self): 148 | return len(self.buffer) 149 | 150 | 151 | gamma = 0.99 152 | num_episodes = 10000 153 | buffersize = 10000 154 | minmal_size = 500 155 | batch_size = 64 156 | epsilon = 0.01 157 | learning_rate = 2e-3 158 | max_seq = 1 159 | device = torch.device('cuda') 160 | algorithm_name = 'DRQN' 161 | env_name = 'Snake-v0' 162 | # 注册环境 163 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 164 | 165 | env = gym.make(env_name) 166 | random.seed(0) 167 | np.random.seed(0) 168 | env.seed(0) 169 | torch.manual_seed(0) 170 | replay_buffer = Recurrent_Memory_ReplayBuffer(buffersize) 171 | 172 | state_dim = env.observation_space.shape[0] 173 | hidden_dim = 128 174 | action_dim = env.action_space.n 175 | 176 | agent = DRQN(state_dim, hidden_dim, action_dim, learning_rate, gamma, epsilon, 177 | device) 178 | 179 | return_list = [] 180 | max_reward = 0 181 | for i in range(10): 182 | with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar: 183 | for i_episodes in range(int(num_episodes / 10)): 184 | episode_return = 0 185 | state = env.reset() 186 | done = False 187 | while not done: 188 | action = agent.take_action(state) 189 | next_state, reward, done, _ = env.step(action) 190 | 191 | if i_episodes == int(num_episodes / 10) - 1: 192 | env.render() 193 | time.sleep(0.1) 194 | replay_buffer.add(state, action, reward, next_state, done) 195 | state = next_state 196 | episode_return += reward 197 | if replay_buffer.size() > minmal_size: 198 | b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample( 199 | batch_size, max_seq) 200 | 201 | agent.update(b_s, b_a, b_r, b_ns, b_d, max_seq, batch_size) 202 | return_list.append(episode_return) 203 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 204 | if episode_return > max_reward: 205 | max_reward = episode_return 206 | agent.save_model(env_name, algorithm_name) 207 | if (i_episodes + 1 % 10 == 0): 208 | pbar.set_postfix({ 209 | 'episode': 210 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 211 | 'return': 212 | '%.3f' % np.mean(return_list[-10:]) 213 | }) 214 | pbar.update(1) 215 | 
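The recurrent variants above (drqn.py and ddrqn.py) add an LSTM to Qnet and draw contiguous windows of max_seq consecutive transitions from Recurrent_Memory_ReplayBuffer, stacked so the LSTM receives sequences of shape (batch_size, max_seq, obs_dim). Both scripts set max_seq = 1, so the LSTM only ever sees length-1 sequences and carries no real history; raising max_seq is what actually gives the recurrent layer context. The following self-contained sketch reimplements only that sampling step (the repo's class is not imported here, since drqn.py runs its training loop at import time):

```python
import collections
import random

import numpy as np

# Stand-alone sketch of the windowed sampling in Recurrent_Memory_ReplayBuffer.sample().
buffer = collections.deque(maxlen=1000)
obs_dim, max_seq, batch_size = 6, 4, 8

# Fill the buffer with dummy (state, action, reward, next_state, done) transitions.
for t in range(200):
    buffer.append((np.random.rand(obs_dim), t % 4, 0.0,
                   np.random.rand(obs_dim), False))

batch_states = []
for _ in range(batch_size):
    finish = random.randint(max_seq, len(buffer) - 1)     # same indexing as drqn.py
    window = [buffer[i] for i in range(finish - max_seq, finish)]
    states, actions, rewards, next_states, dones = zip(*window)
    batch_states.append(np.vstack(states))                # (max_seq, obs_dim)

batch_states = np.array(batch_states)
print(batch_states.shape)  # (8, 4, 6); with max_seq = 1 this collapses to (8, 1, 6)
```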
-------------------------------------------------------------------------------- /pg/reinforce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | from rl_utils import plot_smooth_reward 13 | 14 | 15 | class ActorNet(nn.Module): 16 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 17 | super().__init__() 18 | self.fc1 = nn.Linear(state_dim, hidden_dim) 19 | self.fc2 = nn.Linear(hidden_dim, action_dim) 20 | 21 | def forward(self, state): 22 | 23 | hidden_state = F.relu(self.fc1(state)) 24 | # print(self.fc2(hidden_state)) 25 | probs = F.softmax(self.fc2(hidden_state), dim=1) 26 | return probs 27 | 28 | 29 | class Reinforce: 30 | def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, 31 | device) -> None: 32 | self.actor = ActorNet(state_dim, hidden_dim, action_dim).to(device) 33 | self.gamma = gamma 34 | self.optimizer = torch.optim.Adam(self.actor.parameters(), 35 | lr=learning_rate) 36 | self.device = device 37 | 38 | def save_model(self, save_path='./', filename='model'): 39 | torch.save(self.actor.state_dict(), f"{save_path}\\{filename}.pth") 40 | 41 | def load_model(self, load_path): 42 | self.actor.load_state_dict(torch.load(load_path)) 43 | 44 | def take_action(self, state): 45 | state = torch.tensor([state], dtype=torch.float).to(self.device) 46 | probs = self.actor(state) 47 | action_dist = torch.distributions.Categorical(probs) 48 | action = action_dist.sample() 49 | return action.item() 50 | 51 | def update(self, transition_dict): 52 | reward_list = transition_dict['rewards'] 53 | state_list = transition_dict['states'] 54 | action_list = transition_dict['actions'] 55 | G = 0 56 | self.optimizer.zero_grad() 57 | for i in reversed(range(len(reward_list))): 58 | reward = reward_list[i] 59 | state = torch.tensor([state_list[i]], 60 | dtype=torch.float).to(self.device) 61 | action = torch.tensor([action_list[i]]).view(-1, 1).to(self.device) 62 | log_prob = torch.log(self.actor(state).gather(1, action)) 63 | G = self.gamma * G + reward 64 | loss = torch.mean(-log_prob * G) 65 | loss.backward() 66 | self.optimizer.step() 67 | 68 | 69 | if __name__ == "__main__": 70 | algorithm_name = "REINFORCE" 71 | gamma = 0.98 72 | 73 | num_episodes = 5000 74 | learning_rate = 2e-3 75 | device = torch.device('cuda') 76 | 77 | env_name = 'Snake-v0' #'CartPole-v0' 78 | # 注册环境 79 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 80 | 81 | env = gym.make(env_name) 82 | 83 | env.seed(0) 84 | torch.manual_seed(0) 85 | 86 | state_dim = env.observation_space.shape[0] 87 | hidden_dim = 128 88 | action_dim = env.action_space.n 89 | update_target = 100 90 | agent = Reinforce(state_dim, hidden_dim, action_dim, learning_rate, gamma, 91 | device) 92 | 93 | return_list = [] 94 | max_reward = 0 95 | for i in range(20): 96 | with tqdm(total=int(num_episodes / 10), 97 | desc='Iteration %d' % i) as pbar: 98 | for i_episodes in range(int(num_episodes / 10)): 99 | episode_return = 0 100 | state = env.reset() 101 | done = False 102 | transition_dict = { 103 | 'states': [], 104 | 'actions': [], 105 | 'next_states': [], 106 | 'rewards': [], 107 | 'dones': [] 108 | } 109 | while not done: 110 | action = agent.take_action(state) 111 | next_state, reward, done, _ = env.step(action) 112 | 
env.render() 113 | transition_dict['states'].append(state) 114 | transition_dict['actions'].append(action) 115 | transition_dict['next_states'].append(next_state) 116 | transition_dict['rewards'].append(reward) 117 | transition_dict['dones'].append(done) 118 | state = next_state 119 | episode_return += reward 120 | if i_episodes == int(num_episodes / 10) - 1: 121 | time.sleep(0.1) 122 | agent.update(transition_dict) 123 | return_list.append(episode_return) 124 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 125 | if episode_return > max_reward: 126 | max_reward = episode_return 127 | agent.save_model(env_name, algorithm_name) 128 | if (i_episodes + 1 % 10 == 0): 129 | pbar.set_postfix({ 130 | 'episode': 131 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 132 | 'return': 133 | '%.3f' % np.mean(return_list[-10:]) 134 | }) 135 | pbar.update(1) 136 | -------------------------------------------------------------------------------- /pg/reinforce_baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | from rl_utils import plot_smooth_reward 13 | 14 | 15 | class ValueNet(torch.nn.Module): 16 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 17 | super().__init__() 18 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 19 | self.fc2 = torch.nn.Linear(hidden_dim, 1) 20 | 21 | def forward(self, state): 22 | hidden_state = F.relu(self.fc1(state)) 23 | value = self.fc2(hidden_state) 24 | return value 25 | 26 | 27 | class ActorNet(nn.Module): 28 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 29 | super().__init__() 30 | self.fc1 = nn.Linear(state_dim, hidden_dim) 31 | self.fc2 = nn.Linear(hidden_dim, action_dim) 32 | 33 | def forward(self, state): 34 | 35 | hidden_state = F.relu(self.fc1(state)) 36 | # print(self.fc2(hidden_state)) 37 | probs = F.softmax(self.fc2(hidden_state), dim=1) 38 | return probs 39 | 40 | 41 | class REINFORCE_BASELINE: 42 | def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, 43 | device) -> None: 44 | self.actor = ActorNet(state_dim, hidden_dim, action_dim).to(device) 45 | self.vnet = ValueNet(state_dim, hidden_dim, action_dim).to(device) 46 | self.gamma = gamma 47 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 48 | lr=learning_rate) 49 | self.vnet_optimizer = torch.optim.Adam(self.vnet.parameters(), 50 | lr=learning_rate) 51 | self.device = device 52 | 53 | def save_model(self, save_path='./', filename='model'): 54 | torch.save(self.actor.state_dict(), f"{save_path}\\{filename}.pth") 55 | 56 | def load_model(self, load_path): 57 | self.actor.load_state_dict(torch.load(load_path)) 58 | 59 | def take_action(self, state): 60 | state = torch.tensor([state], dtype=torch.float).to(self.device) 61 | probs = self.actor(state) 62 | action_dist = torch.distributions.Categorical(probs) 63 | action = action_dist.sample() 64 | return action.item() 65 | 66 | def update(self, transition_dict): 67 | reward_list = transition_dict['rewards'] 68 | state_list = transition_dict['states'] 69 | action_list = transition_dict['actions'] 70 | G = 0 71 | SG = 0 72 | self.actor_optimizer.zero_grad() 73 | self.vnet_optimizer.zero_grad() 74 | for i in reversed(range(len(reward_list))): 75 | reward = 
torch.tensor([reward_list[i]], 76 | dtype=torch.float).view(-1, 77 | 1).to(self.device) 78 | state = torch.tensor([state_list[i]], 79 | dtype=torch.float).to(self.device) 80 | action = torch.tensor([action_list[i]]).view(-1, 1).to(self.device) 81 | value = self.vnet(state).view(-1, 1).to(self.device) 82 | log_prob = torch.log(self.actor(state).gather(1, action)) 83 | G = self.gamma * G + reward 84 | 85 | td_delta = G - value 86 | 87 | actor_loss = -log_prob * td_delta.detach() 88 | actor_loss.backward() 89 | 90 | vnet_loss = F.mse_loss(value, G) 91 | vnet_loss.backward() 92 | 93 | self.actor_optimizer.step() 94 | self.vnet_optimizer.step() 95 | 96 | 97 | if __name__ == "__main__": 98 | 99 | gamma = 0.99 100 | algorithm_name = "REINFORCE_baseline" 101 | num_episodes = 10000 102 | learning_rate = 2e-3 103 | device = torch.device('cuda') 104 | 105 | env_name = 'CartPole-v0' #'CartPole-v0' 106 | # 注册环境 107 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 108 | 109 | env = gym.make(env_name) 110 | 111 | env.seed(0) 112 | torch.manual_seed(0) 113 | 114 | state_dim = env.observation_space.shape[0] 115 | hidden_dim = 128 116 | action_dim = env.action_space.n 117 | update_target = 100 118 | agent = REINFORCE_BASELINE(state_dim, hidden_dim, action_dim, 119 | learning_rate, gamma, device) 120 | 121 | return_list = [] 122 | max_reward = 0 123 | for i in range(10): 124 | with tqdm(total=int(num_episodes / 10), 125 | desc='Iteration %d' % i) as pbar: 126 | for i_episodes in range(int(num_episodes / 10)): 127 | episode_return = 0 128 | state = env.reset() 129 | done = False 130 | transition_dict = { 131 | 'states': [], 132 | 'actions': [], 133 | 'next_states': [], 134 | 'rewards': [], 135 | 'dones': [] 136 | } 137 | while not done: 138 | action = agent.take_action(state) 139 | next_state, reward, done, _ = env.step(action) 140 | 141 | transition_dict['states'].append(state) 142 | transition_dict['actions'].append(action) 143 | transition_dict['next_states'].append(next_state) 144 | transition_dict['rewards'].append(reward) 145 | transition_dict['dones'].append(done) 146 | state = next_state 147 | episode_return += reward 148 | if i_episodes == int(num_episodes / 10) - 1: 149 | env.render() 150 | time.sleep(0.1) 151 | agent.update(transition_dict) 152 | return_list.append(episode_return) 153 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 154 | if episode_return > max_reward: 155 | max_reward = episode_return 156 | agent.save_model(env_name, algorithm_name) 157 | if (i_episodes + 1 % 10 == 0): 158 | pbar.set_postfix({ 159 | 'episode': 160 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 161 | 'return': 162 | '%.3f' % np.mean(return_list[-10:]) 163 | }) 164 | pbar.update(1) 165 | -------------------------------------------------------------------------------- /ppo/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /ppo/.idea/RL_algorithm.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /ppo/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /ppo/.idea/misc.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /ppo/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_22_54_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_22_54_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_01_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_01_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_05_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_05_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_16_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_16_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_17_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_17_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_26_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_26_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_32_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_32_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_35_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_35_BipedalWalker-v3.pth 
-------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_39_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_39_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_40_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_40_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_30_23_50_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_30_23_50_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_31_00_09_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_31_00_09_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_31_00_11_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_31_00_11_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_01_31_00_15_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_01_31_00_15_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_10_58_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_10_58_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_10_59_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_10_59_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_11_00_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_11_00_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_11_01_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_11_01_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_11_02_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_11_02_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_11_03_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_11_03_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_11_04_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_11_04_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_11_06_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_11_06_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_11_07_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_11_07_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_11_08_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_11_08_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_15_41_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_15_41_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_15_48_BipedalWalker-v3.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_15_48_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_15_49_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_15_49_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_15_50_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_15_50_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_15_51_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_15_51_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_15_53_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_15_53_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_15_54_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_15_54_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_15_59_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_15_59_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_23_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_23_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_25_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_25_BipedalWalker-v3.pth -------------------------------------------------------------------------------- 
/ppo/BipedalWalker-v3/2024_02_01_16_26_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_26_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_27_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_27_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_28_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_28_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_29_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_29_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_30_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_30_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_31_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_31_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_32_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_32_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_33_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_33_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_34_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_34_BipedalWalker-v3.pth 
-------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_35_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_35_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_36_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_36_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_16_42_BipedalWalker-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_16_42_BipedalWalker-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_17_53_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_17_53_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_17_55_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_17_55_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_17_56_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_17_56_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_17_58_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_17_58_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_17_59_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_17_59_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_00_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_00_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_02_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_02_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_05_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_05_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_06_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_06_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_07_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_07_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_08_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_08_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_09_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_09_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_12_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_12_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_13_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_13_BipedalWalkerHardcore-v3.pth 
-------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_24_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_24_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_41_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_41_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_45_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_45_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_51_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_51_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_18_56_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_18_56_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_19_11_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_19_11_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_19_19_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_19_19_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_19_26_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_19_26_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_19_44_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_19_44_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_19_52_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_19_52_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_19_56_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_19_56_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_20_12_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_20_12_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_20_16_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_20_16_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_20_18_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_20_18_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_20_22_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_20_22_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_21_10_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_21_10_BipedalWalkerHardcore-v3.pth -------------------------------------------------------------------------------- /ppo/BipedalWalker-v3/2024_02_01_21_14_BipedalWalkerHardcore-v3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_21_14_BipedalWalkerHardcore-v3.pth 
--------------------------------------------------------------------------------
/ppo/BipedalWalker-v3/2024_02_01_21_25_BipedalWalkerHardcore-v3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_21_25_BipedalWalkerHardcore-v3.pth
--------------------------------------------------------------------------------
/ppo/BipedalWalker-v3/2024_02_01_21_30_BipedalWalkerHardcore-v3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_21_30_BipedalWalkerHardcore-v3.pth
--------------------------------------------------------------------------------
/ppo/BipedalWalker-v3/2024_02_01_21_54_BipedalWalkerHardcore-v3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/BipedalWalker-v3/2024_02_01_21_54_BipedalWalkerHardcore-v3.pth
--------------------------------------------------------------------------------
/ppo/BipedwalkerHardcoreTest.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | 
4 | 
5 | if __name__ == "__main__":
6 |     env = gym.make("BipedalWalkerHardcore-v3")
7 |     state = env.reset()  # initialize the environment state
8 |     total_reward = 0
9 |     step_count = 0
10 |     done = False
11 |     from ppo_multienv import PPOAgent
12 |     from ppo_multienv import parse_args
13 |     import torch
14 |     import time
15 | 
16 |     args = parse_args()
17 |     agent = PPOAgent(env.observation_space, env.action_space, args).to(args.device)
18 |     agent = torch.load("BipedalWalker-v3/2024_02_01_21_54_BipedalWalkerHardcore-v3.pth")  # load the saved agent (overwrites the freshly constructed one)
19 |     next_obs = env.reset()  # initialize the state
20 |     while not done:
21 |         next_obs = torch.tensor(next_obs).unsqueeze(0).to(args.device)
22 |         with torch.no_grad():
23 |             action, logprob, entropy, value = agent.get_action_and_value(next_obs)
24 |         next_obs, reward, done, info = env.step(action.flatten().cpu().numpy())
25 |         if done:
26 |             next_obs = env.reset()
27 |         env.render()
28 |         # action = 1#env.action_space.sample() # sample a random action (for demonstration only)
29 |         # state, reward, done, info = env.step(action) # take the action
30 |         total_reward += reward
31 |         step_count += 1
32 |         print('At step {}, reward = {}, done = {}'.format(step_count, reward, done))
33 |     print('Total reward: {}'.format(total_reward))
--------------------------------------------------------------------------------
/ppo/MultiFollow/2024_02_24_20_31_Follow.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/MultiFollow/2024_02_24_20_31_Follow.pth
--------------------------------------------------------------------------------
/ppo/MultiFollow/2024_02_24_20_32.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/MultiFollow/2024_02_24_20_32.gif
--------------------------------------------------------------------------------
/ppo/custom_env/MultiAgentFollowEnvV2.py:
-------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import gym 4 | from gym import spaces 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from matplotlib.animation import FuncAnimation, ArtistAnimation 8 | import matplotlib.animation as animation 9 | 10 | import imageio 11 | 12 | class MultiAgentFollowEnv(gym.Env): 13 | def __init__(self, target_info): 14 | np.random.seed(0) 15 | 16 | self.num_agents = target_info["num_agents"] 17 | 18 | self.episode_length = 0 19 | self.episode_rewards = [0]*self.num_agents 20 | self.episode = 0 21 | 22 | 23 | # 团队奖励和个人奖励的加权 24 | self.tau=0.9 25 | 26 | # 目标状态 27 | self.target_distances=[] 28 | self.target_angles=[] 29 | 30 | for i in range(self.num_agents): 31 | self.target_distances.append(target_info[str(i)][0]) 32 | self.target_angles.append(target_info[str(i)][1]) 33 | 34 | self.target_distances=np.array(self.target_distances) 35 | self.target_angles=np.array(self.target_angles) 36 | 37 | # 定义状态空间:每个智能体包括目标位置、智能体位置、速度和方向 38 | pi = np.pi 39 | inf=np.inf 40 | self.id=[0]*self.num_agents 41 | self.id_=[1]*self.num_agents 42 | self.low = np.array(self.id+[-100.0, -100.0, -100.0,-100.0, 0.0, 0]) 43 | self.high = np.array(self.id_+[100.0, 100.0,100.0,100.0,1.0, 2*pi]) 44 | self.observation_space = spaces.Box(self.low, self.high,(self.num_agents+2+2+2,), dtype=float) 45 | 46 | # 定义动作空间:每个智能体包括速度和方向两个维度,多维离散 47 | self.action_space=gym.spaces.MultiDiscrete([3,3]) 48 | # self.action_space = spaces.Tuple([ 49 | # spaces.Discrete(3), # 速度:加速、减速、保持 50 | # spaces.Discrete(3) # 方向:向左、向右、保持 51 | # ]) 52 | 53 | 54 | # 定义其他环境参数 55 | self.max_steps = 150 56 | self.current_step = 0 57 | 58 | # 到达次数记录以及时间记录 59 | self.arrive_count=0 60 | self.arrive_time=0 61 | 62 | self.fig, self.ax = plt.subplots() 63 | 64 | # 帧列表 65 | self.frames=[] 66 | 67 | # 设置矩形边界限制 68 | self.x_min, self.x_max = 0,3 69 | self.y_min, self.y_max = 0,3 70 | self.main_speed = 0.05 71 | self.move_directions = [np.array([1, 0]), np.array([0, 1]), np.array([-1, 0]), np.array([0, -1])] 72 | self.current_direction_index = 0 73 | 74 | # 重置环境状态,随机设置目标位置和每个智能体初始位置、速度、方向 75 | self.main_position = np.random.uniform(self.x_min, self.x_max, size=(2,)) 76 | 77 | # 随机设置每个智能体初始位置,确保距离目标一定距离 78 | self.agent_positions = np.random.uniform(self.x_min, self.x_max, size=(self.num_agents,2)) 79 | 80 | # 定义每个实体的速度和动作 81 | self.speeds = np.zeros(self.num_agents) 82 | self.directions = np.zeros(self.num_agents) 83 | 84 | 85 | def update_main_position(self): 86 | # 沿着当前方向移动 87 | displacement = self.main_speed * self.move_directions[self.current_direction_index] 88 | new_main_position = self.main_position + displacement 89 | 90 | # 运行轨迹矩形边框 91 | x_min, x_max = 0.5,2.5 92 | y_min, y_max = 0.5,2.5 93 | 94 | # 判断是否需要改变方向 95 | if not (x_min <= new_main_position[0] <= x_max) or \ 96 | not (y_min <= new_main_position[1] <= y_max): 97 | # 如果超出边界,改变方向 98 | self.current_direction_index = (self.current_direction_index + 1) % 4 99 | 100 | # 重新计算移动 101 | displacement = self.main_speed * self.move_directions[self.current_direction_index] 102 | 103 | new_main_position = self.main_position + displacement 104 | 105 | # 更新主目标位置 106 | self.main_position = new_main_position 107 | 108 | # 重置环境状态,随机设置目标位置和每个智能体初始位置、速度、方向 109 | self.main_position = np.clip(self.main_position,x_min, x_max) 110 | 111 | 112 | def calculate_target_coordinates(self): 113 | 114 | # Calculate x and y coordinates of the targets 115 | target_x = self.main_position[0] + self.target_distances * 
np.cos(self.target_angles) 116 | target_y = self.main_position[1] + self.target_distances * np.sin(self.target_angles) 117 | 118 | # Now, target_x and target_y contain the coordinates of each target 119 | target_coordinates = np.column_stack((target_x, target_y)) 120 | 121 | return target_coordinates 122 | 123 | def reset(self,seed=0): 124 | # np.random.seed(seed) 125 | 126 | pi = np.pi 127 | inf=np.inf 128 | 129 | self.frames=[] 130 | 131 | # 重置环境记录 132 | self.episode_length=0 133 | self.episode_rewards=[0]*self.num_agents 134 | 135 | # 重置环境状态,随机设置参考点位置和每个智能体初始位置、速度、方向 136 | self.main_position = np.random.uniform(self.x_min, self.x_max, size=(2,)) 137 | 138 | # 随机设置每个智能体初始位置,确保距离目标一定距离 139 | self.agent_positions = np.random.uniform(self.x_min, self.x_max, size=(self.num_agents,2)) 140 | 141 | self.speeds = np.zeros(self.num_agents) 142 | self.directions = np.zeros(self.num_agents) 143 | 144 | self.current_step = 0 145 | 146 | # 使用广播操作计算每个代理与其他所有代理位置的差异 147 | differences = self.agent_positions[:, np.newaxis, :] - self.agent_positions[np.newaxis, :, :] 148 | # 使用 np.linalg.norm 计算每一行的2范数,即每个代理与其他所有代理之间的欧式距离 149 | distances = np.linalg.norm(differences, axis=2) 150 | # 找到距离当前智能体最小的其他智能体位置 151 | try: 152 | second_min_indices = np.argsort(distances, axis=1)[:, 1] 153 | second_min_values = distances[np.arange(len(distances)), second_min_indices] 154 | done=done or np.any(second_min_values<1) 155 | print(second_min_values) 156 | if np.any(second_min_values<1): 157 | print(second_min_values) 158 | print("碰撞") 159 | except: 160 | second_min_indices=[0]*self.num_agents 161 | 162 | 163 | # 返回新的状态,包括目标位置、智能体位置、速度和方向 164 | states=[] 165 | for i in range(self.num_agents): 166 | self.id[i]=1 167 | states.append(np.concatenate([self.id,self.agent_positions[i]-self.main_position,self.agent_positions[i]-self.agent_positions[second_min_indices[i]],self.speeds[i:i+1],self.directions[i:i+1]])) 168 | self.id[i]=0 169 | 170 | return states 171 | 172 | def step(self, action): 173 | # 解析动作为速度和方向 174 | for i in range(self.num_agents): 175 | speed_action, direction_action = action[i] 176 | 177 | # 根据速度动作更新速度 178 | if speed_action == 0: # 减速 179 | self.speeds[i] -= 0.1 180 | self.speeds[i]=max(self.speeds[i],self.low[-2]) 181 | elif speed_action == 1: # 加速 182 | self.speeds[i] += 0.1 183 | self.speeds[i]=min(self.speeds[i],self.high[-2]) 184 | 185 | # 根据方向动作更新方向 186 | if direction_action == 0: # 向左 187 | self.directions[i] -= 0.1 188 | elif direction_action == 1: # 向右 189 | self.directions[i] += 0.1 190 | self.directions[i]=np.mod(self.directions,2*np.pi) 191 | 192 | 193 | # 计算每个智能体的位移 194 | displacements = self.speeds[:, np.newaxis] * np.column_stack([np.cos(self.directions), np.sin(self.directions)]) 195 | 196 | # 更新每个智能体位置 197 | self.agent_positions += displacements 198 | 199 | # 超出范围则进行取余数操作 200 | self.agent_positions = np.clip(self.agent_positions, self.x_min,self.x_max) 201 | 202 | self.update_main_position() 203 | 204 | # 计算每个代理相对于参考点的位置差异 205 | target_positions=self.calculate_target_coordinates() 206 | differences = self.agent_positions - target_positions 207 | 208 | # 计算奖励 209 | 210 | distances_to_target = np.linalg.norm(differences, axis=1) 211 | 212 | # main位置取余 213 | self.main_position = np.clip(self.main_position,0, 5) 214 | 215 | # 计算个人奖励 216 | distance_reward=-0.1*distances_to_target 217 | 218 | rewards = distance_reward # 距离倒数作为奖励,目标是最小化距离 219 | 220 | # 计算团队奖励 221 | team_reward=np.mean(rewards) 222 | 223 | # 判断是否达到终止条件 224 | done = self.current_step >= self.max_steps 225 | if all(distances < 0.1 for 
distances in distances_to_target): 226 | team_reward+=1 227 | 228 | # 最终奖励 229 | rewards=rewards*self.tau+team_reward*(1-self.tau) 230 | 231 | self.current_step += 1 232 | 233 | # 使用广播操作计算每个代理与其他所有代理位置的差异 234 | differences = self.agent_positions[:, np.newaxis, :] - self.agent_positions[np.newaxis, :, :] 235 | # 使用 np.linalg.norm 计算每一行的2范数,即每个代理与其他所有代理之间的欧式距离 236 | distances = np.linalg.norm(differences, axis=2) 237 | # 找到距离当前智能体最小的其他智能体位置 238 | try: 239 | second_min_indices = np.argsort(distances, axis=1)[:, 1] 240 | second_min_values = distances[np.arange(len(distances)), second_min_indices] 241 | done=done or np.any(second_min_values<1) 242 | if np.any(second_min_values<1): 243 | print(second_min_values) 244 | print("碰撞") 245 | except: 246 | second_min_indices=[0]*self.num_agents 247 | 248 | # 返回新的状态,包括目标位置、智能体位置、速度和方向 249 | states=[] 250 | 251 | for i in range(self.num_agents): 252 | self.id[i]=1 253 | states.append(np.concatenate([self.id,self.agent_positions[i]-target_positions[i],self.agent_positions[i]-self.agent_positions[second_min_indices[i]],self.speeds[i:i+1],self.directions[i:i+1]])) 254 | self.id[i]=0 255 | self.episode_rewards[i]=rewards[i] 256 | 257 | # 组装info 258 | self.episode_length+=1 259 | 260 | info={} 261 | if done: 262 | self.episode += 1 263 | details = {} 264 | details['r'] = self.episode_rewards 265 | details['l'] = self.episode_length-1 266 | details['e'] = self.episode 267 | info['episode'] = details 268 | 269 | if len(self.frames)!=0: 270 | timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M') 271 | save_path = f"Follow/{timestamp}.gif" 272 | imageio.mimsave(save_path, self.frames, duration=0.1) # Adjust duration as needed 273 | print("GIF saved successfully.") 274 | 275 | # print("*"*20) 276 | # print("距离reward",distance_reward) 277 | # print("方位reward",angle_reward,0.1*angle_reward) 278 | 279 | # print("*"*20) 280 | 281 | return states,rewards,done,info 282 | 283 | def add_arrow(self, position, direction, color='black'): 284 | # 添加箭头 285 | arrow_length = 0.1 286 | arrow_head_width = 0.05 287 | 288 | arrow_dx = arrow_length * np.cos(direction) 289 | arrow_dy = arrow_length * np.sin(direction) 290 | 291 | self.ax.arrow(position[0], position[1], arrow_dx, arrow_dy, color=color, width=arrow_head_width) 292 | 293 | def render(self, mode="human"): 294 | # 清空子图内容 295 | self.ax.clear() 296 | 297 | # 绘制目标 298 | self.ax.plot(self.main_position[0], self.main_position[1], 'rs', label='main') 299 | 300 | # 绘制每个智能体 301 | for i in range(self.num_agents): 302 | self.ax.plot(self.agent_positions[i, 0], self.agent_positions[i, 1], 'bo', label=f'agent {i + 1}') 303 | 304 | # 添加运动方向箭头 305 | self.add_arrow(self.agent_positions[i], self.directions[i]) 306 | 307 | # 设置图形范围 308 | self.ax.set_xlim(self.x_min,self.x_max) # 根据需要调整范围 309 | self.ax.set_ylim(self.y_min,self.y_max) # 根据需要调整范围 310 | 311 | # 添加图例 312 | self.ax.legend() 313 | 314 | if mode == "human": 315 | # 显示图形 316 | plt.pause(0.01) # 添加短暂的时间间隔,单位为秒 317 | else: 318 | # 将当前帧的图形添加到列表中 319 | self.frames.append(self.fig_to_array()) 320 | 321 | def fig_to_array(self): 322 | # Convert the current figure to an array of pixels 323 | buf = io.BytesIO() 324 | self.ax.figure.savefig(buf, format='png') 325 | buf.seek(0) 326 | img = imageio.imread(buf) 327 | return img 328 | 329 | 330 | if __name__ == "__main__": 331 | 332 | 333 | pi=np.pi 334 | target_info={"num_agents":1,"0":(0,0),"1":(10,pi/3),"2":(20,pi/6),"3":(20,pi/3)} 335 | 336 | # 循环测试代码 337 | env = MultiAgentFollowEnv(target_info) 338 | 339 | state = env.reset() # 
reset the environment
340 |     done = False
341 | 
342 |     while not done:
343 |         # randomly generate a discrete action for each agent: the first dimension controls speed (accelerate, decelerate, hold),
344 |         # the second controls heading (turn left, turn right, hold)
345 |         actions = np.random.randint(3, size=(env.num_agents, 2))
346 | 
347 |         actions=np.array([[1,1]])
348 | 
349 |         states, rewards, done, _ = env.step(actions)  # take the actions
350 | 
351 |         print(rewards)
352 | 
353 |         env.render(mode="human")  # render the environment state
354 | 
355 |     print("Testing complete.")
356 | 
--------------------------------------------------------------------------------
/ppo/custom_env/Walker_Discreate.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | 
4 | 
5 | class Walker(gym.Env):
6 |     def __init__(self, bins=20):
7 |         self.env = gym.make("BipedalWalker-v3")
8 |         self.env = gym.wrappers.RecordEpisodeStatistics(self.env)
9 |         self.observation_space = self.env.observation_space
10 |         self.action_space = gym.spaces.MultiDiscrete([bins]*self.env.action_space.shape[0])
11 |         self.discrete_action = np.linspace(-1., 1., bins)
12 | 
13 |     def step(self, action):
14 |         continuous_action = self.discrete_action[action]
15 |         next_state, reward, done, info = self.env.step(continuous_action)
16 |         return next_state, reward, done, info
17 | 
18 |     def reset(self):
19 |         next_state = self.env.reset()
20 |         return next_state
21 | 
22 |     def render(self, mode="human"):
23 |         self.env.render(mode=mode)
24 | 
25 |     def seed(self, seed=None):
26 |         self.env.seed(seed)
27 | 
28 | 
29 | if __name__ == "__main__":
30 |     env = Walker()
31 |     state = env.reset()  # initialize the state
32 |     total_reward = 0
33 |     step_count = 0
34 |     done = False
35 |     while not done:
36 |         env.render()
37 |         action = env.action_space.sample()  # sample a random action (for demonstration)
38 |         state, reward, done, info = env.step(action)  # take the action
39 |         total_reward += reward
40 |         step_count += 1
41 |         print('At step {}, reward = {}, done = {}'.format(step_count, reward, done))
42 |     print('Total reward: {}'.format(total_reward))
43 | 
--------------------------------------------------------------------------------
/ppo/custom_env/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/custom_env/__init__.py
--------------------------------------------------------------------------------
/ppo/custom_env/__pycache__/MultiAgentFollowEnvV2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/custom_env/__pycache__/MultiAgentFollowEnvV2.cpython-38.pyc
--------------------------------------------------------------------------------
/ppo/custom_env/__pycache__/MultiAgentFollowEnvV3.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/custom_env/__pycache__/MultiAgentFollowEnvV3.cpython-38.pyc
--------------------------------------------------------------------------------
/ppo/custom_env/__pycache__/Walker_Discreate.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/custom_env/__pycache__/Walker_Discreate.cpython-38.pyc
--------------------------------------------------------------------------------
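
Note on the Walker wrapper above: /ppo/custom_env/Walker_Discreate.py turns BipedalWalker-v3's four continuous torques into a MultiDiscrete([bins] * 4) action space by indexing a shared np.linspace(-1., 1., bins) grid. The snippet below is a minimal, self-contained sketch of just that index-to-torque mapping; it is not a file in this repository, it assumes the wrapper's default bins=20, and the variable names (grid, discrete_action, continuous_action) are illustrative only.

    import numpy as np

    bins = 20                               # same default as Walker(bins=20)
    grid = np.linspace(-1.0, 1.0, bins)     # bin centers shared by all four joints

    # one bin index per joint, as a MultiDiscrete policy head would produce
    discrete_action = np.array([0, 9, 10, 19])

    # fancy indexing maps each index to its continuous torque in [-1, 1]
    continuous_action = grid[discrete_action]
    print(continuous_action)                # approximately [-1., -0.0526, 0.0526, 1.]

Because every joint shares the same grid, a single lookup (self.discrete_action[action]) converts the whole MultiDiscrete action vector at once, which is what Walker.step does before delegating to the underlying environment.

--------------------------------------------------------------------------------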
/ppo/custom_env/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/custom_env/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ppo/custom_env/__pycache__/snake_env.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/custom_env/__pycache__/snake_env.cpython-38.pyc -------------------------------------------------------------------------------- /ppo/custom_env/snake_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import random 4 | import pygame 5 | import sys 6 | import time 7 | 8 | 9 | class SnakeEnv(gym.Env): 10 | def __init__(self, width=20, height=10): 11 | super(SnakeEnv, self).__init__() 12 | 13 | self.width = width 14 | self.height = height 15 | self.grid = np.zeros((height, width)) 16 | self.snake = None 17 | self.food = None 18 | self.action_mask = [True] * 4 19 | self.episode_length = 0 20 | self.episode_reward = 0 21 | self.episode = 0 22 | self.max_episode_length = 200 23 | self.cell_size = 20 24 | self.window_size = (self.width * self.cell_size, 25 | self.height * self.cell_size) 26 | # 定义颜色 27 | self.color_bg = (255, 255, 255) 28 | self.color_head = (0, 120, 120) 29 | self.color_body = (0, 255, 0) 30 | self.color_food = (255, 0, 0) 31 | 32 | # 定义贪吃蛇环境的观测空间和行动空间 33 | inf = self.width + self.height 34 | low = np.array([-10, -1.0, 0, 0, 0, 0]) # 连续状态空间的最小值 35 | high = np.array([10, 1.0, 1.0, 1.0, 1.0, 1.0]) # 连续状态空间的最大值 36 | self.observation_space = gym.spaces.Box(low, high, shape=(6,), dtype=float) 37 | 38 | # 0 左 1上 2右 3下 39 | self.action_space = gym.spaces.Discrete(4) 40 | 41 | def generate_food(self): 42 | while True: 43 | x = random.randint(0, self.width - 1) 44 | y = random.randint(0, self.height - 1) 45 | if (x, y) not in self.snake: 46 | return (x, y) 47 | 48 | def reset(self): 49 | self.grid = np.zeros((self.height, self.width)) 50 | self.snake = [(0, 0)] 51 | self.food = self.generate_food() 52 | self.time = 1 53 | self.action_mask = self.get_mask() 54 | self.episode_length = 0 55 | self.episode_reward = 0 56 | self.update_grid() 57 | return self.get_state() 58 | 59 | # 状态获取函数 60 | def get_state(self): 61 | state = [] 62 | x, y = self.snake[0] 63 | fx, fy = self.food 64 | state.append(fx - x) 65 | state.append(fy - y) 66 | for gx, gy in [(0, 1), (1, 0), (0, -1), (-1, 0)]: 67 | dx, dy = x + gx, y + gy 68 | if dx < 0 or dy < 0 or dx >= self.width or dy >= self.height or ( 69 | dx, dy) in self.snake: 70 | state.append(0) 71 | continue 72 | else: 73 | state.append(1) # 四个方向可以走 74 | return np.array(state, dtype=np.float32) 75 | 76 | # 更新当前动画状态 77 | def update_grid(self): 78 | self.grid = np.zeros((self.height, self.width)) 79 | x, y = self.snake[0] 80 | self.grid[y, x] = 1 81 | for x, y in self.snake[1:]: 82 | self.grid[y, x] = 2 83 | fx, fy = self.food 84 | self.grid[fy, fx] = 3 85 | 86 | # 获取动作mask 87 | def get_mask(self): 88 | action_mask = [True] * 4 89 | x, y = self.snake[0] 90 | for i, (gx, gy) in enumerate([(0, 1), (1, 0), (0, -1), (-1, 0)]): 91 | dx, dy = x + gx, y + gy 92 | if dx < 0 or dy < 0 or dx >= self.width or dy >= self.height or ( 93 | dx, dy) in self.snake: 94 | 
action_mask[i] = False 95 | else: 96 | action_mask[i] = True # True则表示动作可以执行 97 | return action_mask 98 | 99 | def step(self, action): 100 | x, y = self.snake[0] 101 | direction = [(0, 1), (1, 0), (0, -1), (-1, 0)] 102 | x = x + direction[action][0] 103 | y = y + direction[action][1] 104 | 105 | self.episode_length += 1 106 | if x < 0 or x >= self.width or y < 0 or y >= self.height or ( 107 | x, y) in self.snake: 108 | reward = -1 109 | done = True 110 | elif (x, y) == self.food: 111 | reward = 1 112 | self.snake.insert(0, (x, y)) 113 | self.food = self.generate_food() 114 | self.update_grid() 115 | done = False 116 | else: 117 | fx, fy = self.food 118 | d = (abs(x - fx) + abs(y - fy)) 119 | reward = 0 120 | self.snake.insert(0, (x, y)) 121 | self.snake.pop() 122 | self.update_grid() 123 | done = False 124 | 125 | info = {} 126 | self.episode_reward += reward 127 | # 更新action_mask 128 | self.action_mask = self.get_mask() 129 | if done: 130 | self.episode += 1 131 | details = {} 132 | details['r'] = self.episode_reward 133 | details['l'] = self.episode_length 134 | details['e'] = self.episode 135 | info['episode'] = details 136 | return self.get_state(), reward, done, info 137 | 138 | def render(self, mode='human'): 139 | # 初始化pygame窗口 140 | pygame.init() 141 | pygame.font.init() 142 | self.font = pygame.font.Font(None, 30) 143 | self.window = pygame.display.set_mode(self.window_size) 144 | 145 | pygame.display.set_caption("Snake Game") 146 | 147 | if mode == 'rgb_array': 148 | surface = pygame.Surface( 149 | (self.width * self.cell_size, self.height * self.cell_size)) 150 | self.window = surface 151 | 152 | self.window.fill(self.color_bg) 153 | 154 | for event in pygame.event.get(): 155 | if event.type == pygame.QUIT: 156 | pygame.quit() 157 | sys.exit() 158 | 159 | for y in range(self.height): 160 | for x in range(self.width): 161 | cell_value = self.grid[y, x] 162 | cell_rect = pygame.Rect(x * self.cell_size, y * self.cell_size, 163 | self.cell_size, self.cell_size) 164 | 165 | for y in range(self.height): 166 | for x in range(self.width): 167 | cell_value = self.grid[y, x] 168 | cell_rect = pygame.Rect(x * self.cell_size, y * self.cell_size, 169 | self.cell_size, self.cell_size) 170 | if cell_value == 0: # 白色的空白格子 171 | pygame.draw.rect(self.window, (255, 255, 0), cell_rect, 1) 172 | elif cell_value == 1: # 贪吃蛇身体 173 | pygame.draw.rect(self.window, self.color_head, cell_rect) 174 | elif cell_value == 2: # 贪吃蛇身体 175 | pygame.draw.rect(self.window, self.color_body, cell_rect) 176 | elif cell_value == 3: # 食物 177 | # pygame.draw.rect(self.window, self.color_food, cell_rect) 178 | pygame.draw.circle(self.window, self.color_food, 179 | (cell_rect.x + self.cell_size // 2, cell_rect.y + self.cell_size // 2), 180 | self.cell_size // 2) 181 | 182 | snake_length_text = self.font.render("Length: " + str(len(self.snake)), 183 | True, (0, 25, 25)) 184 | self.window.blit(snake_length_text, (0, 0)) 185 | 186 | pygame.display.flip() 187 | 188 | if mode == 'rgb_array': 189 | image_array = pygame.surfarray.array3d(self.window) 190 | return image_array 191 | 192 | def close(self): 193 | pygame.quit() # 释放 Pygame 占用的系统资源 194 | sys.exit() # 关闭程序 195 | 196 | 197 | if __name__ == "__main__": 198 | 199 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 200 | env = gym.make('Snake-v0') 201 | 202 | next_obs = env.reset() # 初始化状态 203 | total_reward = 0 204 | step_count = 0 205 | done = False 206 | from ppo_snake import PPOAgent 207 | from ppo_snake import parse_args 208 | import torch 209 | 210 | 
args = parse_args() 211 | agent = PPOAgent(env.observation_space, env.action_space, args).to(args.device) 212 | agent = torch.load("snake-v0/2024_01_30_22_49_Snake-v0.pth") 213 | while not done: 214 | next_obs = torch.tensor(next_obs).unsqueeze(0).to(args.device) 215 | with torch.no_grad(): 216 | action, logprob, entropy, value = agent.get_action_and_value(next_obs) 217 | next_obs, reward, done, info = env.step(action.flatten().cpu().numpy().item()) 218 | if done: 219 | next_obs = env.reset() 220 | env.render() 221 | time.sleep(0.1) 222 | # action = 1#env.action_space.sample() # 随机采样一个动作,此处用于演示 223 | # state, reward, done, info = env.step(action) # 执行动作 224 | total_reward += reward 225 | step_count += 1 226 | print('At step {}, reward = {}, done = {}'.format(step_count, reward, done)) 227 | print('Total reward: {}'.format(total_reward)) 228 | -------------------------------------------------------------------------------- /ppo/ppo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | from rl_utils import plot_smooth_reward 13 | 14 | 15 | def compute_advantage(gamma, lmbda, td_delta): 16 | td_delta = td_delta.detach().numpy() 17 | advantage_list = [] 18 | advantage = 0.0 19 | for delta in td_delta[::-1]: 20 | advantage = gamma * lmbda * advantage + delta 21 | advantage_list.append(advantage) 22 | advantage_list.reverse() 23 | return torch.tensor(advantage_list, dtype=torch.float) 24 | 25 | 26 | class CriticNet(nn.Module): 27 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 28 | super().__init__() 29 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 30 | self.fc2 = torch.nn.Linear(hidden_dim, 1) 31 | 32 | def forward(self, state): 33 | hidden_state = F.relu(self.fc1(state)) 34 | acion_value = self.fc2(hidden_state) 35 | return acion_value 36 | 37 | 38 | class ActorNet(nn.Module): 39 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 40 | super().__init__() 41 | self.fc1 = nn.Linear(state_dim, hidden_dim) 42 | self.fc2 = nn.Linear(hidden_dim, action_dim) 43 | 44 | def forward(self, state): 45 | 46 | hidden_state = F.relu(self.fc1(state)) 47 | # print(self.fc2(hidden_state)) 48 | probs = F.softmax(self.fc2(hidden_state), dim=1) 49 | return probs 50 | 51 | 52 | class PPO: 53 | def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, 54 | gamma, lmbda, eps, epochs, device) -> None: 55 | self.actor = ActorNet(state_dim, hidden_dim, action_dim).to(device) 56 | self.critic = CriticNet(state_dim, hidden_dim, action_dim).to(device) 57 | self.gamma = gamma 58 | self.lmbda = lmbda 59 | self.eps = eps 60 | self.epochs = epochs 61 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 62 | lr=actor_lr) 63 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 64 | lr=critic_lr) 65 | self.device = device 66 | 67 | def save_model(self, save_path='./', filename='model'): 68 | torch.save(self.actor.state_dict(), f"{save_path}\\{filename}.pth") 69 | 70 | def load_model(self, load_path): 71 | self.actor.load_state_dict(torch.load(load_path)) 72 | 73 | def take_action(self, state): 74 | state = torch.tensor([state], dtype=torch.float).to(self.device) 75 | probs = self.actor(state) 76 | action_dist = torch.distributions.Categorical(probs) 77 | action 
= action_dist.sample() 78 | return action.item() 79 | 80 | def update(self, transition_dict): 81 | rewards = torch.tensor(transition_dict['rewards'], 82 | dtype=torch.float).view(-1, 1).to(self.device) 83 | states = torch.tensor(transition_dict['states'], 84 | dtype=torch.float).to(self.device) 85 | actions = torch.tensor(transition_dict['actions']).view(-1, 1).to( 86 | self.device) 87 | next_states = torch.tensor(transition_dict['next_states'], 88 | dtype=torch.float).to(self.device) 89 | next_actions = torch.tensor(transition_dict['next_actions']).view( 90 | -1, 1).to(self.device) 91 | dones = torch.tensor(transition_dict['dones'], 92 | dtype=torch.float).to(self.device).view(-1, 1) 93 | 94 | v_now = self.critic(states).view(-1, 1) 95 | v_next = self.critic(next_states).view(-1, 1) 96 | y_now = (self.gamma * v_next * (1 - dones) + rewards).view(-1, 1) 97 | td_delta = y_now - v_now 98 | advantage = compute_advantage(self.gamma, self.lmbda, 99 | td_delta.cpu()).to(self.device) 100 | old_log_probs = torch.log(self.actor(states).gather(1, 101 | actions)).detach() 102 | for _ in range(self.epochs): 103 | log_probs = torch.log(self.actor(states).gather(1, actions)) 104 | ratio = torch.exp(log_probs - old_log_probs) 105 | surr1 = ratio * advantage 106 | surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage 107 | actor_loss = torch.mean(-torch.min(surr1, surr2)) 108 | critic_loss = torch.mean( 109 | F.mse_loss(self.critic(states), y_now.detach())) 110 | 111 | self.actor_optimizer.zero_grad() 112 | self.critic_optimizer.zero_grad() 113 | 114 | actor_loss.backward() 115 | critic_loss.backward() 116 | 117 | self.actor_optimizer.step() 118 | self.critic_optimizer.step() 119 | 120 | 121 | if __name__ == "__main__": 122 | 123 | gamma = 0.99 124 | algorithm_name = "PPOdemo" 125 | num_episodes = 10000 126 | actor_lr = 1e-3 127 | critic_lr = 1e-3 128 | lmbda = 0.95 129 | eps = 0.2 130 | epochs = 8 131 | device = torch.device('cuda') 132 | env_name = 'Snake-v0' #'CartPole-v0' 133 | 134 | # 注册环境 135 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 136 | 137 | env = gym.make(env_name) 138 | 139 | random.seed(0) 140 | np.random.seed(0) 141 | env.seed(0) 142 | torch.manual_seed(0) 143 | 144 | state_dim = env.observation_space.shape[0] 145 | hidden_dim = 128 146 | action_dim = env.action_space.n 147 | agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, 148 | lmbda, eps, epochs, device) 149 | 150 | return_list = [] 151 | max_reward = 0 152 | for i in range(10): 153 | with tqdm(total=int(num_episodes / 10), 154 | desc='Iteration %d' % i) as pbar: 155 | for i_episodes in range(int(num_episodes / 10)): 156 | episode_return = 0 157 | state = env.reset() 158 | done = False 159 | 160 | transition_dict = { 161 | 'states': [], 162 | 'actions': [], 163 | 'next_states': [], 164 | 'next_actions': [], 165 | 'rewards': [], 166 | 'dones': [] 167 | } 168 | 169 | while not done: 170 | action = agent.take_action(state) 171 | next_state, reward, done, _ = env.step(action) 172 | next_action = agent.take_action(next_state) 173 | env.render() 174 | 175 | transition_dict['states'].append(state) 176 | transition_dict['actions'].append(action) 177 | transition_dict['next_states'].append(next_state) 178 | transition_dict['next_actions'].append(next_action) 179 | transition_dict['rewards'].append(reward) 180 | transition_dict['dones'].append(done) 181 | 182 | state = next_state 183 | episode_return += reward 184 | if i_episodes == int(num_episodes / 10) - 1: 185 | env.render() 186 | 
time.sleep(0.1) 187 | agent.update(transition_dict) 188 | 189 | return_list.append(episode_return) 190 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 191 | if episode_return > max_reward: 192 | max_reward = episode_return 193 | agent.save_model(env_name, algorithm_name) 194 | if (i_episodes + 1 % 10 == 0): 195 | pbar.set_postfix({ 196 | 'episode': 197 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 198 | 'return': 199 | '%.3f' % np.mean(return_list[-10:]) 200 | }) 201 | pbar.update(1) 202 | -------------------------------------------------------------------------------- /ppo/ppoconPendulum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | from rl_utils import plot_smooth_reward 13 | 14 | 15 | def compute_advantage(gamma, lmbda, td_delta): 16 | td_delta = td_delta.detach().numpy() 17 | advantage_list = [] 18 | advantage = 0.0 19 | for delta in td_delta[::-1]: 20 | advantage = gamma * lmbda * advantage + delta 21 | advantage_list.append(advantage) 22 | advantage_list.reverse() 23 | return torch.tensor(advantage_list, dtype=torch.float) 24 | 25 | 26 | class CriticNet(nn.Module): 27 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 28 | super().__init__() 29 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 30 | self.fc2 = torch.nn.Linear(hidden_dim, 1) 31 | 32 | def forward(self, state): 33 | hidden_state = F.relu(self.fc1(state)) 34 | acion_value = self.fc2(hidden_state) 35 | return acion_value 36 | 37 | 38 | class ActorNet(nn.Module): 39 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 40 | super().__init__() 41 | self.fc1 = nn.Linear(state_dim, hidden_dim) 42 | self.fc_mu = nn.Linear(hidden_dim, action_dim) 43 | self.fc_std = nn.Linear(hidden_dim, action_dim) 44 | 45 | def forward(self, state): 46 | 47 | hidden_state = F.relu(self.fc1(state)) 48 | mu = 2.0 * F.tanh(self.fc_mu(hidden_state)) 49 | sigma = F.softplus(self.fc_std(hidden_state)) 50 | return mu, sigma 51 | 52 | 53 | class PPO: 54 | def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, 55 | gamma, lmbda, eps, epochs, device) -> None: 56 | self.actor = ActorNet(state_dim, hidden_dim, action_dim).to(device) 57 | self.critic = CriticNet(state_dim, hidden_dim, action_dim).to(device) 58 | self.gamma = gamma 59 | self.lmbda = lmbda 60 | self.eps = eps 61 | self.epochs = epochs 62 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 63 | lr=actor_lr) 64 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 65 | lr=critic_lr) 66 | self.device = device 67 | 68 | def save_model(self, save_path='./', filename='model'): 69 | torch.save(self.actor.state_dict(), f"{save_path}\\{filename}.pth") 70 | 71 | def load_model(self, load_path): 72 | self.actor.load_state_dict(torch.load(load_path)) 73 | 74 | def take_action(self, state): 75 | print(state) 76 | state = torch.tensor([state], dtype=torch.float).to(self.device) 77 | mu, sigma = self.actor(state) 78 | action_dist = torch.distributions.Normal(mu, sigma) 79 | action = action_dist.sample() 80 | # print(action) 81 | return action.cpu().numpy().flatten() 82 | 83 | def update(self, transition_dict): 84 | rewards = torch.tensor(transition_dict['rewards'], 85 | dtype=torch.float).view(-1, 
1).to(self.device) 86 | states = torch.tensor(transition_dict['states'], 87 | dtype=torch.float).to(self.device) 88 | actions = torch.tensor(transition_dict['actions']).view(-1, 1).to( 89 | self.device) 90 | next_states = torch.tensor(transition_dict['next_states'], 91 | dtype=torch.float).to(self.device) 92 | next_actions = torch.tensor(transition_dict['next_actions']).view( 93 | -1, 1).to(self.device) 94 | dones = torch.tensor(transition_dict['dones'], 95 | dtype=torch.float).to(self.device).view(-1, 1) 96 | 97 | v_now = self.critic(states).view(-1, 1) 98 | v_next = self.critic(next_states).view(-1, 1) 99 | y_now = (self.gamma * v_next * (1 - dones) + rewards).view(-1, 1) 100 | td_delta = y_now - v_now 101 | advantage = compute_advantage(self.gamma, self.lmbda, 102 | td_delta.cpu()).to(self.device) 103 | mu, sigma = self.actor(states) 104 | action_dists = torch.distributions.Normal(mu.detach(), sigma.detach()) 105 | old_log_probs = action_dists.log_prob(actions) 106 | for _ in range(self.epochs): 107 | 108 | mu, sigma = self.actor(states) 109 | action_dists = torch.distributions.Normal(mu, sigma) 110 | log_probs = action_dists.log_prob(actions) 111 | 112 | ratio = torch.exp(log_probs - old_log_probs) 113 | surr1 = ratio * advantage 114 | surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage 115 | actor_loss = torch.mean(-torch.min(surr1, surr2)) 116 | critic_loss = torch.mean( 117 | F.mse_loss(self.critic(states), y_now.detach())) 118 | 119 | self.actor_optimizer.zero_grad() 120 | self.critic_optimizer.zero_grad() 121 | 122 | actor_loss.backward() 123 | critic_loss.backward() 124 | 125 | self.actor_optimizer.step() 126 | self.critic_optimizer.step() 127 | 128 | 129 | if __name__ == "__main__": 130 | 131 | gamma = 0.9 132 | algorithm_name = "demo" 133 | num_episodes = 2000 134 | actor_lr = 1e-4 135 | critic_lr = 5e-3 136 | lmbda = 0.9 137 | eps = 0.2 138 | epochs = 10 139 | device = torch.device('cuda') 140 | env_name = 'Pendulum-v1' #'CartPole-v0' 141 | max_step = 500 142 | # 注册环境 143 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv') 144 | 145 | env = gym.make(env_name) 146 | 147 | random.seed(0) 148 | np.random.seed(0) 149 | env.seed(0) 150 | torch.manual_seed(0) 151 | 152 | state_dim = env.observation_space.shape[0] 153 | hidden_dim = 128 154 | action_dim = env.action_space.shape[0] 155 | agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, 156 | lmbda, eps, epochs, device) 157 | 158 | return_list = [] 159 | inf = float('inf') 160 | max_reward = -inf 161 | for i in range(10): 162 | with tqdm(total=int(num_episodes / 10), 163 | desc='Iteration %d' % i) as pbar: 164 | for i_episodes in range(int(num_episodes / 10)): 165 | episode_return = 0 166 | state = env.reset() 167 | done = False 168 | 169 | transition_dict = { 170 | 'states': [], 171 | 'actions': [], 172 | 'next_states': [], 173 | 'next_actions': [], 174 | 'rewards': [], 175 | 'dones': [] 176 | } 177 | 178 | for _ in range(max_step): 179 | action = agent.take_action(state) 180 | next_state, reward, done, _ = env.step(action) 181 | 182 | next_action = agent.take_action(next_state) 183 | # print(next_action, next_state, action, state, reward) 184 | 185 | transition_dict['states'].append(state) 186 | transition_dict['actions'].append(action) 187 | transition_dict['next_states'].append(next_state) 188 | transition_dict['next_actions'].append(next_action) 189 | transition_dict['rewards'].append((reward + 8.0) / 8.0) 190 | transition_dict['dones'].append(done) 191 | 192 | state = 
next_state 193 | episode_return += reward 194 | if i_episodes == int(num_episodes / 10) - 1: 195 | env.render() 196 | time.sleep(0.1) 197 | if done: 198 | break 199 | agent.update(transition_dict) 200 | 201 | return_list.append(episode_return) 202 | plot_smooth_reward(return_list, 100, env_name, algorithm_name) 203 | if episode_return > max_reward: 204 | max_reward = episode_return 205 | agent.save_model(env_name, algorithm_name) 206 | if (i_episodes + 1 % 10 == 0): 207 | pbar.set_postfix({ 208 | 'episode': 209 | '%d' % (num_episodes / 10 * i + i_episodes + 1), 210 | 'return': 211 | '%.3f' % np.mean(return_list[-10:]) 212 | }) 213 | pbar.update(1) 214 | -------------------------------------------------------------------------------- /ppo/ppodisPendulum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | import pygame 5 | import collections 6 | import random 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | import time 11 | import torch.nn as nn 12 | from rl_utils import plot_smooth_reward 13 | 14 | 15 | def compute_advantage(gamma, lmbda, td_delta): 16 | td_delta = td_delta.detach().numpy() 17 | advantage_list = [] 18 | advantage = 0.0 19 | for delta in td_delta[::-1]: 20 | advantage = gamma * lmbda * advantage + delta 21 | advantage_list.append(advantage) 22 | advantage_list.reverse() 23 | return torch.tensor(advantage_list, dtype=torch.float) 24 | 25 | 26 | class CriticNet(nn.Module): 27 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 28 | super().__init__() 29 | self.fc1 = torch.nn.Linear(state_dim, hidden_dim) 30 | self.fc2 = torch.nn.Linear(hidden_dim, 1) 31 | 32 | def forward(self, state): 33 | hidden_state = F.relu(self.fc1(state)) 34 | acion_value = self.fc2(hidden_state) 35 | return acion_value 36 | 37 | 38 | class ActorNet(nn.Module): 39 | def __init__(self, state_dim, hidden_dim, action_dim) -> None: 40 | super().__init__() 41 | self.fc1 = nn.Linear(state_dim, hidden_dim) 42 | self.fc2 = nn.Linear(hidden_dim, action_dim) 43 | 44 | def forward(self, state): 45 | 46 | hidden_state = F.relu(self.fc1(state)) 47 | # print(self.fc2(hidden_state)) 48 | probs = F.softmax(self.fc2(hidden_state), dim=1) 49 | return probs 50 | 51 | 52 | class PPO: 53 | def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, 54 | gamma, lmbda, eps, epochs, device) -> None: 55 | self.actor = ActorNet(state_dim, hidden_dim, action_dim).to(device) 56 | self.critic = CriticNet(state_dim, hidden_dim, action_dim).to(device) 57 | self.gamma = gamma 58 | self.lmbda = lmbda 59 | self.eps = eps 60 | self.epochs = epochs 61 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 62 | lr=actor_lr) 63 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 64 | lr=critic_lr) 65 | self.device = device 66 | 67 | def save_model(self, save_path='./', filename='model'): 68 | torch.save(self.actor.state_dict(), f"{save_path}\\{filename}.pth") 69 | 70 | def load_model(self, load_path): 71 | self.actor.load_state_dict(torch.load(load_path)) 72 | 73 | def take_action(self, state): 74 | state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device) 75 | probs = self.actor(state) 76 | action_dist = torch.distributions.Categorical(probs) 77 | action = action_dist.sample() 78 | return action.item() 79 | 80 | def update(self, transition_dict): 81 | rewards = torch.tensor(transition_dict['rewards'], 82 | 
dtype=torch.float).view(-1, 1).to(self.device) 83 | states = torch.tensor(np.array(transition_dict['states']), 84 | dtype=torch.float).to(self.device) 85 | actions = torch.tensor(transition_dict['actions']).view(-1, 1).to( 86 | self.device) 87 | next_states = torch.tensor(transition_dict['next_states'], 88 | dtype=torch.float).to(self.device) 89 | next_actions = torch.tensor(transition_dict['next_actions']).view( 90 | -1, 1).to(self.device) 91 | dones = torch.tensor(transition_dict['dones'], 92 | dtype=torch.float).to(self.device).view(-1, 1) 93 | 94 | v_now = self.critic(states).view(-1, 1) 95 | v_next = self.critic(next_states).view(-1, 1) 96 | y_now = (self.gamma * v_next * (1 - dones) + rewards).view(-1, 1) 97 | td_delta = y_now - v_now 98 | advantage = compute_advantage(self.gamma, self.lmbda, 99 | td_delta.cpu()).to(self.device) 100 | old_log_probs = torch.log(self.actor(states).gather(1, 101 | actions)).detach() 102 | for _ in range(self.epochs): 103 | log_probs = torch.log(self.actor(states).gather(1, actions)) 104 | ratio = torch.exp(log_probs - old_log_probs) 105 | surr1 = ratio * advantage 106 | surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage 107 | actor_loss = torch.mean(-torch.min(surr1, surr2)) 108 | critic_loss = torch.mean( 109 | F.mse_loss(self.critic(states), y_now.detach())) 110 | 111 | self.actor_optimizer.zero_grad() 112 | self.critic_optimizer.zero_grad() 113 | 114 | actor_loss.backward() 115 | critic_loss.backward() 116 | 117 | self.actor_optimizer.step() 118 | self.critic_optimizer.step() 119 | 120 | 121 | if __name__ == "__main__": 122 | 123 | gamma = 0.9 124 | algorithm_name = "PPOdis" 125 | num_episodes = 2000 126 | actor_lr = 1e-3 127 | critic_lr = 2e-3 128 | lmbda = 0.9 129 | eps = 0.2 130 | epochs = 10 131 | device = torch.device('cuda') 132 | env_name = 'Pendulum-v1' #'CartPole-v0' 133 | max_step = 500 134 | 135 | env = gym.make(env_name) 136 | random_seed=3407 137 | random.seed(random_seed) 138 | np.random.seed(random_seed) 139 | torch.manual_seed(random_seed) 140 | 141 | state_dim = env.observation_space.shape[0] 142 | hidden_dim = 256 143 | try: 144 | action_dim = env.action_space.n 145 | except: 146 | action_dim=15 147 | agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, 148 | lmbda, eps, epochs, device) 149 | 150 | return_list = [] 151 | max_reward = 0 152 | 153 | lower_limit = -2 # 下限 154 | upper_limit = 2 # 上限 155 | # 计算每个等份的大小 156 | partition_size = (upper_limit - lower_limit) / (action_dim-1) 157 | # 使用lambda表达式生成等份的起始和结束位置 158 | partitions = [ [lower_limit + partition_size*i] for i in range(action_dim)] 159 | print(partitions) 160 | 161 | 162 | for i in range(10): 163 | with tqdm(total=int(num_episodes / 10), 164 | desc='Iteration %d' % i) as pbar: 165 | for i_episodes in range(int(num_episodes / 10)): 166 | episode_return = 0 167 | state = env.reset(seed=random_seed) 168 | done = False 169 | 170 | transition_dict = { 171 | 'states': [], 172 | 'actions': [], 173 | 'next_states': [], 174 | 'next_actions': [], 175 | 'rewards': [], 176 | 'dones': [] 177 | } 178 | 179 | for _ in range(max_step): 180 | if done: 181 | break 182 | action_index = agent.take_action(state) 183 | action=partitions[action_index] 184 | next_state, reward, done, _ = env.step(action) 185 | next_action_index = agent.take_action(next_state) 186 | transition_dict['states'].append(state) 187 | transition_dict['actions'].append(action_index) 188 | transition_dict['next_states'].append(next_state) 189 | 
/ppo/runs/BipedalWalker-v3__ppo_continuous__1__1706597051/events.out.tfevents.1706597051.SKY-20230422NIZ.12216.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/BipedalWalker-v3__ppo_continuous__1__1706597051/events.out.tfevents.1706597051.SKY-20230422NIZ.12216.0
--------------------------------------------------------------------------------
/ppo/runs/BipedalWalker-v3__ppo_continuous__1__1706771659/events.out.tfevents.1706771659.DESKTOP-VAT73EM.12564.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/BipedalWalker-v3__ppo_continuous__1__1706771659/events.out.tfevents.1706771659.DESKTOP-VAT73EM.12564.0
--------------------------------------------------------------------------------
/ppo/runs/BipedalWalker-v3__ppo_multienv__1__1706755654/events.out.tfevents.1706755654.DESKTOP-VAT73EM.4120.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/BipedalWalker-v3__ppo_multienv__1__1706755654/events.out.tfevents.1706755654.DESKTOP-VAT73EM.4120.0
--------------------------------------------------------------------------------
/ppo/runs/BipedalWalker-v3__ppo_multienv__1__1706774718/events.out.tfevents.1706774718.DESKTOP-VAT73EM.11980.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/BipedalWalker-v3__ppo_multienv__1__1706774718/events.out.tfevents.1706774718.DESKTOP-VAT73EM.11980.0
--------------------------------------------------------------------------------
/ppo/runs/BipedalWalkerHardcore-v3__ppo_multienv__1__1706780430/events.out.tfevents.1706780430.DESKTOP-VAT73EM.12616.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/BipedalWalkerHardcore-v3__ppo_multienv__1__1706780430/events.out.tfevents.1706780430.DESKTOP-VAT73EM.12616.0
--------------------------------------------------------------------------------
/ppo/runs/Follow__ppo_sharing__1__1708769402/events.out.tfevents.1708769402.LAPTOP-K5GRC2HU.20764.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/Follow__ppo_sharing__1__1708769402/events.out.tfevents.1708769402.LAPTOP-K5GRC2HU.20764.0 -------------------------------------------------------------------------------- /ppo/runs/Snake-v0__ppo_discretemask__1__1706672792/events.out.tfevents.1706672792.SKY-20230422NIZ.11492.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/Snake-v0__ppo_discretemask__1__1706672792/events.out.tfevents.1706672792.SKY-20230422NIZ.11492.0 -------------------------------------------------------------------------------- /ppo/runs/Snake-v0__ppo_discretemask__1__1707018799/events.out.tfevents.1707018799.DESKTOP-VAT73EM.4348.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/Snake-v0__ppo_discretemask__1__1707018799/events.out.tfevents.1707018799.DESKTOP-VAT73EM.4348.0 -------------------------------------------------------------------------------- /ppo/runs/Snake-v0__ppo_snake__1__1706623864/events.out.tfevents.1706623864.SKY-20230422NIZ.676.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/Snake-v0__ppo_snake__1__1706623864/events.out.tfevents.1706623864.SKY-20230422NIZ.676.0 -------------------------------------------------------------------------------- /ppo/runs/Snake-v0__ppo_snake__1__1708424259/events.out.tfevents.1708424259.LAPTOP-K5GRC2HU.22668.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/Snake-v0__ppo_snake__1__1708424259/events.out.tfevents.1708424259.LAPTOP-K5GRC2HU.22668.0 -------------------------------------------------------------------------------- /ppo/runs/Walker__ppomultidiscrete__1__1706597666/events.out.tfevents.1706597666.SKY-20230422NIZ.10160.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/Walker__ppomultidiscrete__1__1706597666/events.out.tfevents.1706597666.SKY-20230422NIZ.10160.0 -------------------------------------------------------------------------------- /ppo/runs/runs.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/runs/runs.rar -------------------------------------------------------------------------------- /ppo/snake-v0/2024_01_30_22_14_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0/2024_01_30_22_14_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0/2024_01_30_22_19_Snake-v0.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0/2024_01_30_22_19_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0/2024_01_30_22_22_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0/2024_01_30_22_22_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0/2024_01_30_22_36_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0/2024_01_30_22_36_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0/2024_01_30_22_49_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0/2024_01_30_22_49_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0/2024_01_31_19_26_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0/2024_01_31_19_26_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0/2024_01_31_19_27_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0/2024_01_31_19_27_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0/2024_01_31_19_28_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0/2024_01_31_19_28_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0mask/2024_01_31_11_48_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_11_48_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0mask/2024_01_31_11_49_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_11_49_Snake-v0.pth -------------------------------------------------------------------------------- /ppo/snake-v0mask/2024_01_31_11_51_Snake-v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_11_51_Snake-v0.pth -------------------------------------------------------------------------------- 
/ppo/snake-v0mask/2024_01_31_11_55_Snake-v0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_11_55_Snake-v0.pth
--------------------------------------------------------------------------------
/ppo/snake-v0mask/2024_01_31_12_00_Snake-v0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_12_00_Snake-v0.pth
--------------------------------------------------------------------------------
/ppo/snake-v0mask/2024_01_31_12_09_Snake-v0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_12_09_Snake-v0.pth
--------------------------------------------------------------------------------
/ppo/snake-v0mask/2024_01_31_12_13_Snake-v0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_12_13_Snake-v0.pth
--------------------------------------------------------------------------------
/ppo/snake-v0mask/2024_01_31_12_31_Snake-v0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_12_31_Snake-v0.pth
--------------------------------------------------------------------------------
/ppo/snake-v0mask/2024_01_31_12_48_Snake-v0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_12_48_Snake-v0.pth
--------------------------------------------------------------------------------
/ppo/snake-v0mask/2024_01_31_17_20_Snake-v0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_17_20_Snake-v0.pth
--------------------------------------------------------------------------------
/ppo/snake-v0mask/2024_01_31_17_21_Snake-v0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/ppo/snake-v0mask/2024_01_31_17_21_Snake-v0.pth
--------------------------------------------------------------------------------
/ppo/test_follow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | 
4 | def plot_agent_positions(target_info):
5 | num_agents = target_info["num_agents"]
6 | 
7 | # The main (leader) position
8 | main_position = np.array([0, 0])
9 | 
10 | # Initialise the array of agent positions
11 | agent_positions = np.zeros((num_agents, 2))
12 | 
13 | # Process each agent in turn
14 | for i in range(num_agents):
15 | distance, angle = target_info[str(i)]
16 | # Convert (distance, angle) into a position relative to the main position
17 | agent_positions[i, 0] = distance * np.cos(angle)
18 | agent_positions[i, 1] = distance * np.sin(angle)
19 | 
20 | # Plot the agent positions and the main position
21 | plt.scatter(agent_positions[:, 0], agent_positions[:, 1], label="Agents", marker='o', s=100)
22 | plt.scatter(main_position[0], main_position[1], color='red', label="Main Position", marker='s', s=200)
23 | 
24 | # Add annotations
25 | for i, values in target_info.items():
26 | if i != "num_agents":
27 | distance, angle = values
28 | plt.annotate(f"Agent {i}\nDistance: {distance}\nAngle: {angle:.2f}",
29 | xy=(agent_positions[int(i), 0], agent_positions[int(i), 1]),
30 | xytext=(agent_positions[int(i), 0]+1, agent_positions[int(i), 1]+1),
31 | arrowprops=dict(facecolor='black', shrink=0.05),
32 | fontsize=8, ha='center', va='bottom')
33 | 
34 | # Set figure properties
35 | plt.title("Agent Positions Relative to Main")
36 | plt.xlabel("X Coordinate")
37 | plt.ylabel("Y Coordinate")
38 | plt.legend()
39 | plt.grid(True)
40 | plt.axis('equal')  # equal aspect ratio for both axes
41 | plt.show()
42 | 
43 | # Example data
44 | pi = np.pi
45 | target_info = {"num_agents": 4, "0": (5, pi/6), "1": (5, pi/3), "2": (10, pi/6), "3": (10, pi/3)}
46 | 
47 | 
48 | # Call the function to draw the plot
49 | plot_agent_positions(target_info)
50 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gym==0.25.2
2 | matplotlib==3.5.3
3 | numpy==1.24.2
4 | pygame==2.1.2
5 | PyYAML==6.0.1
6 | torch==1.10.0+cu113
7 | tqdm==4.65.0
8 | 
--------------------------------------------------------------------------------
/result/gif/BipedalWalker-v3/BipedalWalker-v3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalker-v3/BipedalWalker-v3.gif
--------------------------------------------------------------------------------
/result/gif/BipedalWalker-v3/SAC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalker-v3/SAC.png
--------------------------------------------------------------------------------
/result/gif/BipedalWalker-v3/SAC.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalker-v3/SAC.pth
--------------------------------------------------------------------------------
/result/gif/BipedalWalker-v3/SACtrick.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalker-v3/SACtrick.png
--------------------------------------------------------------------------------
/result/gif/BipedalWalker-v3/SACtrick.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalker-v3/SACtrick.pth
--------------------------------------------------------------------------------
/result/gif/BipedalWalkerHardcore-v3/BipedalWalkerHardcore-v3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalkerHardcore-v3/BipedalWalkerHardcore-v3.gif -------------------------------------------------------------------------------- /result/gif/BipedalWalkerHardcore-v3/SAC10000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalkerHardcore-v3/SAC10000.png -------------------------------------------------------------------------------- /result/gif/BipedalWalkerHardcore-v3/SAC10000.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalkerHardcore-v3/SAC10000.pth -------------------------------------------------------------------------------- /result/gif/BipedalWalkerHardcore-v3/SAC3000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalkerHardcore-v3/SAC3000.png -------------------------------------------------------------------------------- /result/gif/BipedalWalkerHardcore-v3/SAC3000.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/BipedalWalkerHardcore-v3/SAC3000.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/A2C.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/A2C.gif -------------------------------------------------------------------------------- /result/gif/Snake-v0/A2C.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/A2C.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/A2C.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/A2C.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/A2C.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | gamma: 0.99 3 | algorithm_name: "A2C" 4 | num_episodes: 5000 5 | actor_lr: 1.0e-3 6 | critic_lr: 3.0e-3 7 | env_name: 'Snake-v0' 8 | hidden_dim: 128 9 | 10 | test: 11 | gamma: 0.99 12 | algorithm_name: "A2C" 13 | num_episodes: 5000 14 | actor_lr: 1.0e-3 15 | critic_lr: 3.0e-3 16 | env_name: 'Snake-v0' 17 | hidden_dim: 128 -------------------------------------------------------------------------------- /result/gif/Snake-v0/A2C_TARGET.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/A2C_TARGET.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/A2C_TARGET.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/A2C_TARGET.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/AC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/AC.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/AC.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/AC.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/AC_TARGET.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/AC_TARGET.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/AC_TARGET.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/AC_TARGET.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/DDQN.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/DDQN.gif -------------------------------------------------------------------------------- /result/gif/Snake-v0/DDQN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/DDQN.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/DDQN.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/DDQN.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/DDRQN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/DDRQN.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/DDRQN.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/DDRQN.pth 
-------------------------------------------------------------------------------- /result/gif/Snake-v0/DQN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/DQN.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/DQN.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/DQN.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/DRQN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/DRQN.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/DRQN.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/DRQN.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/PPO.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/PPO.gif -------------------------------------------------------------------------------- /result/gif/Snake-v0/PPO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/PPO.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/PPO.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/PPO.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/REINFORCE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/REINFORCE.png -------------------------------------------------------------------------------- /result/gif/Snake-v0/REINFORCE.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/REINFORCE.pth -------------------------------------------------------------------------------- /result/gif/Snake-v0/REINFORCE_baseline.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/REINFORCE_baseline.gif -------------------------------------------------------------------------------- /result/gif/Snake-v0/REINFORCE_baseline.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/REINFORCE_baseline.png
--------------------------------------------------------------------------------
/result/gif/Snake-v0/REINFORCE_baseline.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v0/REINFORCE_baseline.pth
--------------------------------------------------------------------------------
/result/gif/Snake-v1/A2C.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v1/A2C.gif
--------------------------------------------------------------------------------
/result/gif/Snake-v1/DDQN.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v1/DDQN.gif
--------------------------------------------------------------------------------
/result/gif/Snake-v1/PPO.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v1/PPO.gif
--------------------------------------------------------------------------------
/result/gif/Snake-v1/REINFORCE_baseline.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunwuzhou03/reinforcement-learning-lab/714b812fb704f11c9a34b3a8bb2b6234f15c9b5e/result/gif/Snake-v1/REINFORCE_baseline.gif
--------------------------------------------------------------------------------
/rl_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import gym
4 | import pygame
5 | import collections
6 | import random
7 | import torch.nn.functional as F
8 | from tqdm import tqdm
9 | import matplotlib.pyplot as plt
10 | import time
11 | import torch.nn as nn
12 | import os
13 | import imageio
14 | 
15 | 
16 | def save_video(frames, directory="./", filename="video", mode="gif", fps=30):
17 | height, width, _ = frames[0].shape
18 | 
19 | # Create the target directory if it does not exist
20 | os.makedirs(directory, exist_ok=True)
21 | 
22 | # Join the file name and the extension
23 | filename = f"{filename}.{mode}"
24 | # Alternatively:
25 | # filename = "{}.{}".format(filename, mode)
26 | 
27 | # Build the full file path
28 | filepath = os.path.join(directory, filename)
29 | 
30 | # Create the video writer
31 | writer = imageio.get_writer(filepath, fps=fps)
32 | 
33 | # Write every frame to the video
34 | for frame in frames:
35 | writer.append_data(frame)
36 | 
37 | # Close the video writer
38 | writer.close()
39 | 
40 | 
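# plot_smooth_reward() below saves a PNG of the raw episode returns together
# with a moving average computed via np.convolve in 'valid' mode, so the
# smoothed curve is window_size - 1 points shorter than the raw one.
# Illustrative call matching how the training scripts use it (directory and
# filename are chosen by the caller):
# plot_smooth_reward(return_list, window_size=100, directory="Pendulum-v1", filename="PPOdis")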
41 | def plot_smooth_reward(rewards,
42 | window_size=100,
43 | directory="./",
44 | filename="smooth_reward_plot"):
45 | 
46 | # Create the target directory if it does not exist
47 | os.makedirs(directory, exist_ok=True)
48 | 
49 | # Join the file name and the extension
50 | filename = f"{filename}.png"
51 | 
52 | # Build the full file path
53 | filepath = os.path.join(directory, filename)
54 | 
55 | # Compute the sliding-window average
56 | smoothed_rewards = np.convolve(rewards,
57 | np.ones(window_size) / window_size,
58 | mode='valid')
59 | 
60 | # Plot the raw and smoothed reward curves
61 | plt.plot(rewards, label='Raw Reward')
62 | plt.plot(smoothed_rewards, label='Smoothed Reward')
63 | 
64 | # Set the legend, title and axis labels
65 | plt.legend()
66 | plt.title('Smoothed Reward')
67 | plt.xlabel('Episode')
68 | plt.ylabel('Reward')
69 | 
70 | # Save the figure
71 | plt.savefig(filepath)
72 | 
73 | # Close the figure
74 | plt.close()
75 | 
--------------------------------------------------------------------------------
/test_agent.py:
--------------------------------------------------------------------------------
1 | from ddqn import DDQN, plot_smooth_reward, ReplayBuffer
2 | from a2c import A2C
3 | from ppo import PPO
4 | from reinforce_baseline import REINFORCE_BASELINE
5 | import numpy as np
6 | import torch
7 | import gym
8 | import pygame
9 | import collections
10 | import random
11 | import torch.nn.functional as F
12 | from tqdm import tqdm
13 | import matplotlib.pyplot as plt
14 | import time
15 | import pygame
16 | import datetime
17 | from rl_utils import save_video
18 | 
19 | gamma = 0.99
20 | num_episodes = 5000
21 | hidden_dim = 128
22 | buffersize = 10000
23 | minimal_size = 500
24 | batch_size = 32
25 | epsilon = 0.01
26 | target_update = 50
27 | learning_rate = 2e-3
28 | max_step = 500
29 | device = torch.device('cuda')
30 | 
31 | env_name = 'Snake-v0'
32 | # Register the custom environment
33 | gym.register(id='Snake-v0', entry_point='snake_env:SnakeEnv')
34 | 
35 | env = gym.make(env_name, width=10, height=10)
36 | random.seed(0)
37 | np.random.seed(0)
38 | env.seed(0)
39 | torch.manual_seed(0)
40 | replay_buffer = ReplayBuffer(buffersize)
41 | 
42 | state_dim = env.observation_space.shape[0]
43 | action_dim = env.action_space.n
44 | 
45 | load_path = 'Snake-v0/DDQN.pth'  # each load_path/agent pair below overrides the previous one; only the last pair (A2C) is actually evaluated
46 | agent = DDQN(state_dim, hidden_dim, action_dim, learning_rate, gamma, epsilon,
47 | target_update, device)
48 | 
49 | load_path = 'Snake-v0/REINFORCE_baseline.pth'
50 | agent = REINFORCE_BASELINE(state_dim, hidden_dim, action_dim, learning_rate,
51 | gamma, device)
52 | 
53 | load_path = 'Snake-v0/PPO.pth' #5best
54 | agent = PPO(state_dim, hidden_dim, action_dim, 1e-3, 1e-3, gamma, 0.95, 0.2, 5,
55 | device)
56 | 
57 | load_path = 'Snake-v0/A2C.pth' #5best
58 | agent = A2C(state_dim, hidden_dim, action_dim, 1e-3, 1e-3, gamma, device)
59 | 
60 | # Load the saved state dict (not used directly; load_model below reads the file itself)
61 | state_dict = torch.load(load_path)
62 | 
63 | # Load the weights into the agent
64 | 
65 | agent.load_model(load_path)
66 | 
67 | state = env.reset()
68 | done = False
69 | frame_list = []
70 | while not done:
71 | action = agent.take_action(state)  # select an action with the trained policy
72 | state, reward, done, info = env.step(action)
73 | rendered_image = env.render('rgb_array')
74 | frame_list.append(rendered_image)
75 | save_video(frame_list)
--------------------------------------------------------------------------------
/test_env.py:
--------------------------------------------------------------------------------
1 | import gym
2 | 
3 | env = gym.make('Pendulum-v1')
4 | 
5 | observation = env.reset()
6 | done = False
7 | 
8 | while not done:
9 | action = env.action_space.sample()  # sample a random action
10 | print(action)
11 | observation, reward, done, info = env.step(action)  # step the environment and collect the observation, reward, etc.
12 | print('Observation:', observation)
13 | print('Reward:', reward)
14 | 
--------------------------------------------------------------------------------
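As a quick end-to-end check of the helpers above, the snippet below (an illustrative sketch, not a repository file) records one random Pendulum-v1 rollout to a GIF with rl_utils.save_video, using the same gym==0.25.2 API as test_env.py; the output filename is arbitrary.

import gym
from rl_utils import save_video

env = gym.make('Pendulum-v1')
observation = env.reset()
done = False
frames = []
while not done:
    action = env.action_space.sample()  # random policy, as in test_env.py
    observation, reward, done, info = env.step(action)
    frames.append(env.render(mode='rgb_array'))  # collect RGB frames for the GIF
save_video(frames, directory='./', filename='pendulum_random', mode='gif', fps=30)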