├── imgs ├── bwh_sacc.png ├── result.jpg └── lldc_sacc.png ├── SAC-Continuous-OLD.zip ├── model ├── PV1_actor10.pth ├── BWv3_actor400.pth ├── PV1_q_critic10.pth ├── BWHv3_actor1400.pth ├── BWv3_q_critic400.pth ├── Humanv4_actor1800.pth ├── BWHv3_q_critic1400.pth └── Humanv4_q_critic1800.pth ├── runs ├── HCv4 2023-11-17 15_18 │ └── events.out.tfevents.1700205518.kim.5283.0 ├── BWv3 2023-11-16 20_28 │ └── events.out.tfevents.1700137718.3060SG.37436.0 ├── BWv3 2023-11-17 00_58 │ └── events.out.tfevents.1700153892.3060SG.38179.0 ├── BWv3 2023-11-17 05_27 │ └── events.out.tfevents.1700170052.3060SG.38842.0 ├── BWv3 2023-11-17 09_57 │ └── events.out.tfevents.1700186236.3060SG.47512.0 ├── HCv4 2023-11-17 14_50 │ └── events.out.tfevents.1700203837.3060SG.60122.0 ├── HCv4 2023-11-17 15_35 │ └── events.out.tfevents.1700206511.3060SG.63696.0 ├── Humanv4 2023-11-17 18_16 │ └── events.out.tfevents.1700216182.kim.5679.0 ├── PV1 2023-11-17 18_24 │ └── events.out.tfevents.1700216647.3060SG.67549.0 ├── BWHv3 2023-11-16 20_51 │ └── events.out.tfevents.1700139070.3060SG.37507.0 ├── BWHv3 2023-11-17 01_20 │ └── events.out.tfevents.1700155227.3060SG.38248.0 ├── BWHv3 2023-11-17 05_50 │ └── events.out.tfevents.1700171402.3060SG.38905.0 ├── BWHv3 2023-11-17 10_19 │ └── events.out.tfevents.1700187570.3060SG.47587.0 ├── LLdV2 2023-11-16 20_07 │ └── events.out.tfevents.1700136473.3060SG.37336.0 ├── LLdV2 2023-11-17 00_37 │ └── events.out.tfevents.1700152641.3060SG.38105.0 ├── LLdV2 2023-11-17 05_06 │ └── events.out.tfevents.1700168813.3060SG.38764.0 └── LLdV2 2023-11-17 09_36 │ └── events.out.tfevents.1700184997.3060SG.47386.0 ├── LICENSE ├── README.md ├── utils.py ├── SAC.py └── main.py /imgs/bwh_sacc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/imgs/bwh_sacc.png -------------------------------------------------------------------------------- /imgs/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/imgs/result.jpg -------------------------------------------------------------------------------- /imgs/lldc_sacc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/imgs/lldc_sacc.png -------------------------------------------------------------------------------- /SAC-Continuous-OLD.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/SAC-Continuous-OLD.zip -------------------------------------------------------------------------------- /model/PV1_actor10.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/PV1_actor10.pth -------------------------------------------------------------------------------- /model/BWv3_actor400.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/BWv3_actor400.pth -------------------------------------------------------------------------------- /model/PV1_q_critic10.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/PV1_q_critic10.pth -------------------------------------------------------------------------------- /model/BWHv3_actor1400.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/BWHv3_actor1400.pth -------------------------------------------------------------------------------- /model/BWv3_q_critic400.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/BWv3_q_critic400.pth -------------------------------------------------------------------------------- /model/Humanv4_actor1800.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/Humanv4_actor1800.pth -------------------------------------------------------------------------------- /model/BWHv3_q_critic1400.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/BWHv3_q_critic1400.pth -------------------------------------------------------------------------------- /model/Humanv4_q_critic1800.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/Humanv4_q_critic1800.pth -------------------------------------------------------------------------------- /runs/HCv4 2023-11-17 15_18/events.out.tfevents.1700205518.kim.5283.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/HCv4 2023-11-17 15_18/events.out.tfevents.1700205518.kim.5283.0 -------------------------------------------------------------------------------- /runs/BWv3 2023-11-16 20_28/events.out.tfevents.1700137718.3060SG.37436.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWv3 2023-11-16 20_28/events.out.tfevents.1700137718.3060SG.37436.0 -------------------------------------------------------------------------------- /runs/BWv3 2023-11-17 00_58/events.out.tfevents.1700153892.3060SG.38179.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWv3 2023-11-17 00_58/events.out.tfevents.1700153892.3060SG.38179.0 -------------------------------------------------------------------------------- /runs/BWv3 2023-11-17 05_27/events.out.tfevents.1700170052.3060SG.38842.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWv3 2023-11-17 05_27/events.out.tfevents.1700170052.3060SG.38842.0 -------------------------------------------------------------------------------- /runs/BWv3 2023-11-17 09_57/events.out.tfevents.1700186236.3060SG.47512.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWv3 2023-11-17 09_57/events.out.tfevents.1700186236.3060SG.47512.0 -------------------------------------------------------------------------------- 
/runs/HCv4 2023-11-17 14_50/events.out.tfevents.1700203837.3060SG.60122.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/HCv4 2023-11-17 14_50/events.out.tfevents.1700203837.3060SG.60122.0 -------------------------------------------------------------------------------- /runs/HCv4 2023-11-17 15_35/events.out.tfevents.1700206511.3060SG.63696.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/HCv4 2023-11-17 15_35/events.out.tfevents.1700206511.3060SG.63696.0 -------------------------------------------------------------------------------- /runs/Humanv4 2023-11-17 18_16/events.out.tfevents.1700216182.kim.5679.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/Humanv4 2023-11-17 18_16/events.out.tfevents.1700216182.kim.5679.0 -------------------------------------------------------------------------------- /runs/PV1 2023-11-17 18_24/events.out.tfevents.1700216647.3060SG.67549.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/PV1 2023-11-17 18_24/events.out.tfevents.1700216647.3060SG.67549.0 -------------------------------------------------------------------------------- /runs/BWHv3 2023-11-16 20_51/events.out.tfevents.1700139070.3060SG.37507.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWHv3 2023-11-16 20_51/events.out.tfevents.1700139070.3060SG.37507.0 -------------------------------------------------------------------------------- /runs/BWHv3 2023-11-17 01_20/events.out.tfevents.1700155227.3060SG.38248.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWHv3 2023-11-17 01_20/events.out.tfevents.1700155227.3060SG.38248.0 -------------------------------------------------------------------------------- /runs/BWHv3 2023-11-17 05_50/events.out.tfevents.1700171402.3060SG.38905.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWHv3 2023-11-17 05_50/events.out.tfevents.1700171402.3060SG.38905.0 -------------------------------------------------------------------------------- /runs/BWHv3 2023-11-17 10_19/events.out.tfevents.1700187570.3060SG.47587.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWHv3 2023-11-17 10_19/events.out.tfevents.1700187570.3060SG.47587.0 -------------------------------------------------------------------------------- /runs/LLdV2 2023-11-16 20_07/events.out.tfevents.1700136473.3060SG.37336.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/LLdV2 2023-11-16 20_07/events.out.tfevents.1700136473.3060SG.37336.0 -------------------------------------------------------------------------------- /runs/LLdV2 2023-11-17 00_37/events.out.tfevents.1700152641.3060SG.38105.0: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/LLdV2 2023-11-17 00_37/events.out.tfevents.1700152641.3060SG.38105.0
--------------------------------------------------------------------------------
/runs/LLdV2 2023-11-17 05_06/events.out.tfevents.1700168813.3060SG.38764.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/LLdV2 2023-11-17 05_06/events.out.tfevents.1700168813.3060SG.38764.0
--------------------------------------------------------------------------------
/runs/LLdV2 2023-11-17 09_36/events.out.tfevents.1700184997.3060SG.47386.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/LLdV2 2023-11-17 09_36/events.out.tfevents.1700184997.3060SG.47386.0
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 XinJingHao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SAC-Continuous-Pytorch
2 | **A clean and robust Pytorch implementation of Soft-Actor-Critic on continuous action space.**
3 | 
4 | BipedalWalkerHardcore | LunarLanderContinuous
5 | :-----------------------:|:-----------------------:|
6 | ![BipedalWalkerHardcore](imgs/bwh_sacc.png) | ![LunarLanderContinuous](imgs/lldc_sacc.png)
7 | 
8 | 
9 | 
10 | **Other RL algorithms by Pytorch can be found [here](https://github.com/XinJingHao/RL-Algorithms-by-Pytorch).**
11 | 
12 | ## Dependencies
13 | ```python
14 | gymnasium==0.29.1
15 | numpy==1.26.1
16 | pytorch==2.1.0
17 | 
18 | python==3.11.5
19 | ```
20 | 
21 | ## How to use my code
22 | ### Train from scratch
23 | ```bash
24 | python main.py
25 | ```
26 | where the default environment is 'Pendulum'.
27 | 
28 | ### Play with trained model
29 | ```bash
30 | python main.py --EnvIdex 0 --render True --Loadmodel True --ModelIdex 10
31 | ```
32 | which will render the 'Pendulum'.
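The other pretrained checkpoints shipped in the **model** folder can be played in the same way. The commands below are illustrative; they assume that `--ModelIdex` matches the number in the corresponding checkpoint filename, and that the Box2D / MuJoCo extras described further down are installed for those environments:
```bash
# ModelIdex values are taken from the checkpoint filenames in ./model
python main.py --EnvIdex 4 --render True --Loadmodel True --ModelIdex 400    # BipedalWalker-v3 (BWv3_actor400.pth)
python main.py --EnvIdex 5 --render True --Loadmodel True --ModelIdex 1400   # BipedalWalkerHardcore-v3 (BWHv3_actor1400.pth)
python main.py --EnvIdex 2 --render True --Loadmodel True --ModelIdex 1800   # Humanoid-v4 (Humanv4_actor1800.pth)
```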
33 | 
34 | ### Change Environment
35 | If you want to train on different environments, just run
36 | ```bash
37 | python main.py --EnvIdex 1
38 | ```
39 | The ```--EnvIdex``` can be set to 0~5, where
40 | ```bash
41 | '--EnvIdex 0' for 'Pendulum-v1'
42 | '--EnvIdex 1' for 'LunarLanderContinuous-v2'
43 | '--EnvIdex 2' for 'Humanoid-v4'
44 | '--EnvIdex 3' for 'HalfCheetah-v4'
45 | '--EnvIdex 4' for 'BipedalWalker-v3'
46 | '--EnvIdex 5' for 'BipedalWalkerHardcore-v3'
47 | ```
48 | 
49 | Note: if you want to train on **BipedalWalker, BipedalWalkerHardcore, or LunarLanderContinuous**, you need to install [box2d-py](https://gymnasium.farama.org/environments/box2d/) first. You can install box2d-py via:
50 | ```bash
51 | pip install gymnasium[box2d]
52 | ```
53 | 
54 | If you want to train on **Humanoid or HalfCheetah**, you need to install [MuJoCo](https://gymnasium.farama.org/environments/mujoco/) first. You can install MuJoCo via:
55 | ```bash
56 | pip install mujoco
57 | pip install gymnasium[mujoco]
58 | ```
59 | 
60 | ### Visualize the training curve
61 | You can use [tensorboard](https://pytorch.org/docs/stable/tensorboard.html) to record and visualize the training curves.
62 | 
63 | - Installation (please make sure PyTorch is installed already):
64 | ```bash
65 | pip install tensorboard
66 | pip install packaging
67 | ```
68 | - Record (the training curves will be saved at '**runs**'):
69 | ```bash
70 | python main.py --write True
71 | ```
72 | 
73 | - Visualization:
74 | ```bash
75 | tensorboard --logdir runs
76 | ```
77 | 
78 | ### Hyperparameter Setting
79 | For more details of the hyperparameter settings, please check 'main.py'.
80 | 
81 | ### Reference
82 | [Soft Actor-Critic Algorithms and Applications](https://arxiv.org/pdf/1812.05905.pdf)
83 | 
84 | ## All Training Curves
85 | ![All training curves](imgs/result.jpg)
86 | All the experiments are trained with the same hyperparameters (see main.py).
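For example, a run comparable to the BWHv3 curves logged under **runs/** should be reproducible with the default hyperparameters:
```bash
# Train SAC on BipedalWalkerHardcore-v3 and log the curve to ./runs
python main.py --EnvIdex 5 --write True
```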
87 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.distributions import Normal
7 | 
8 | def build_net(layer_shape, hidden_activation, output_activation):
9 |     '''Build net with a for loop'''
10 |     layers = []
11 |     for j in range(len(layer_shape)-1):
12 |         act = hidden_activation if j < len(layer_shape)-2 else output_activation
13 |         layers += [nn.Linear(layer_shape[j], layer_shape[j+1]), act()]
14 |     return nn.Sequential(*layers)
15 | 
16 | 
17 | class Actor(nn.Module):
18 |     def __init__(self, state_dim, action_dim, hid_shape, hidden_activation=nn.ReLU, output_activation=nn.ReLU):
19 |         super(Actor, self).__init__()
20 |         layers = [state_dim] + list(hid_shape)
21 | 
22 |         self.a_net = build_net(layers, hidden_activation, output_activation)
23 |         self.mu_layer = nn.Linear(layers[-1], action_dim)
24 |         self.log_std_layer = nn.Linear(layers[-1], action_dim)
25 | 
26 |         self.LOG_STD_MAX = 2
27 |         self.LOG_STD_MIN = -20
28 | 
29 |     def forward(self, state, deterministic, with_logprob):
30 |         '''Network with enforced action bounds'''
31 |         net_out = self.a_net(state)
32 |         mu = self.mu_layer(net_out)
33 |         log_std = self.log_std_layer(net_out)
34 |         log_std = torch.clamp(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)  # I suspect this clamp may hinder learning
35 |         # we learn log_std rather than std, so that exp(log_std) is always > 0
36 |         std = torch.exp(log_std)
37 |         dist = Normal(mu, std)
38 |         if deterministic: u = mu
39 |         else: u = dist.rsample()
40 | 
41 |         '''↓↓↓ Enforcing Action Bounds, see Page 16 of https://arxiv.org/pdf/1812.05905.pdf ↓↓↓'''
42 |         a = torch.tanh(u)
43 |         if with_logprob:
44 |             # Get probability density of logp_pi_a from probability density of u:
45 |             # logp_pi_a = (dist.log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(dim=1, keepdim=True)
46 |             # Derived from the equation above; it avoids using a = tanh(u) in the log term, so there is less gradient vanishing and better numerical stability.
47 | logp_pi_a = dist.log_prob(u).sum(axis=1, keepdim=True) - (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(axis=1, keepdim=True) 48 | else: 49 | logp_pi_a = None 50 | 51 | return a, logp_pi_a 52 | 53 | class Double_Q_Critic(nn.Module): 54 | def __init__(self, state_dim, action_dim, hid_shape): 55 | super(Double_Q_Critic, self).__init__() 56 | layers = [state_dim + action_dim] + list(hid_shape) + [1] 57 | 58 | self.Q_1 = build_net(layers, nn.ReLU, nn.Identity) 59 | self.Q_2 = build_net(layers, nn.ReLU, nn.Identity) 60 | 61 | 62 | def forward(self, state, action): 63 | sa = torch.cat([state, action], 1) 64 | q1 = self.Q_1(sa) 65 | q2 = self.Q_2(sa) 66 | return q1, q2 67 | 68 | #reward engineering for better training 69 | def Reward_adapter(r, EnvIdex): 70 | # For Pendulum-v0 71 | if EnvIdex == 0: 72 | r = (r + 8) / 8 73 | 74 | # For LunarLander 75 | elif EnvIdex == 1: 76 | if r <= -100: r = -10 77 | 78 | # For BipedalWalker 79 | elif EnvIdex == 4 or EnvIdex == 5: 80 | if r <= -100: r = -1 81 | return r 82 | 83 | 84 | def Action_adapter(a,max_action): 85 | #from [-1,1] to [-max,max] 86 | return a*max_action 87 | 88 | def Action_adapter_reverse(act,max_action): 89 | #from [-max,max] to [-1,1] 90 | return act/max_action 91 | 92 | 93 | def evaluate_policy(env, max_action, agent, turns = 3): 94 | total_scores = 0 95 | for j in range(turns): 96 | s, info = env.reset() 97 | done = False 98 | while not done: 99 | # Take deterministic actions at test time 100 | a = agent.select_action(s, deterministic=True) 101 | act = Action_adapter(a, max_action) 102 | s_next, r, dw, tr, info = env.step(act) 103 | done = (dw or tr) 104 | 105 | total_scores += r 106 | s = s_next 107 | return int(total_scores/turns) 108 | 109 | 110 | def str2bool(v): 111 | '''transfer str to bool for argparse''' 112 | if isinstance(v, bool): 113 | return v 114 | if v.lower() in ('yes', 'True','true','TRUE', 't', 'y', '1'): 115 | return True 116 | elif v.lower() in ('no', 'False','false','FALSE', 'f', 'n', '0'): 117 | return False 118 | else: 119 | raise argparse.ArgumentTypeError('Boolean value expected.') 120 | -------------------------------------------------------------------------------- /SAC.py: -------------------------------------------------------------------------------- 1 | from utils import Actor, Double_Q_Critic 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import torch 5 | import copy 6 | 7 | 8 | class SAC_countinuous(): 9 | def __init__(self, **kwargs): 10 | # Init hyperparameters for agent, just like "self.gamma = opt.gamma, self.lambd = opt.lambd, ..." 11 | self.__dict__.update(kwargs) 12 | self.tau = 0.005 13 | 14 | self.actor = Actor(self.state_dim, self.action_dim, (self.net_width,self.net_width)).to(self.dvc) 15 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.a_lr) 16 | 17 | self.q_critic = Double_Q_Critic(self.state_dim, self.action_dim, (self.net_width,self.net_width)).to(self.dvc) 18 | self.q_critic_optimizer = torch.optim.Adam(self.q_critic.parameters(), lr=self.c_lr) 19 | self.q_critic_target = copy.deepcopy(self.q_critic) 20 | # Freeze target networks with respect to optimizers (only update via polyak averaging) 21 | for p in self.q_critic_target.parameters(): 22 | p.requires_grad = False 23 | 24 | self.replay_buffer = ReplayBuffer(self.state_dim, self.action_dim, max_size=int(1e6), dvc=self.dvc) 25 | 26 | if self.adaptive_alpha: 27 | # Target Entropy = −dim(A) (e.g. 
, -6 for HalfCheetah-v2) as given in the paper
28 |             self.target_entropy = torch.tensor(-self.action_dim, dtype=float, requires_grad=True, device=self.dvc)
29 |             # We learn log_alpha instead of alpha to ensure alpha>0
30 |             self.log_alpha = torch.tensor(np.log(self.alpha), dtype=float, requires_grad=True, device=self.dvc)
31 |             self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.c_lr)
32 | 
33 |     def select_action(self, state, deterministic):
34 |         # only used when interacting with the env
35 |         with torch.no_grad():
36 |             state = torch.FloatTensor(state[np.newaxis,:]).to(self.dvc)
37 |             a, _ = self.actor(state, deterministic, with_logprob=False)
38 |             return a.cpu().numpy()[0]
39 | 
40 |     def train(self,):
41 |         s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size)
42 | 
43 |         #----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------#
44 |         with torch.no_grad():
45 |             a_next, log_pi_a_next = self.actor(s_next, deterministic=False, with_logprob=True)
46 |             target_Q1, target_Q2 = self.q_critic_target(s_next, a_next)
47 |             target_Q = torch.min(target_Q1, target_Q2)
48 |             target_Q = r + (~dw) * self.gamma * (target_Q - self.alpha * log_pi_a_next)  # (~dw) stops bootstrapping at terminal (dead or win) states
49 | 
50 |         # Get current Q estimates
51 |         current_Q1, current_Q2 = self.q_critic(s, a)
52 | 
53 |         q_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
54 |         self.q_critic_optimizer.zero_grad()
55 |         q_loss.backward()
56 |         self.q_critic_optimizer.step()
57 | 
58 |         #----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------#
59 |         # Freeze the critic so we don't waste effort computing its gradients while updating the actor
60 |         for params in self.q_critic.parameters(): params.requires_grad = False
61 | 
62 |         a, log_pi_a = self.actor(s, deterministic=False, with_logprob=True)
63 |         current_Q1, current_Q2 = self.q_critic(s, a)
64 |         Q = torch.min(current_Q1, current_Q2)
65 | 
66 |         a_loss = (self.alpha * log_pi_a - Q).mean()
67 |         self.actor_optimizer.zero_grad()
68 |         a_loss.backward()
69 |         self.actor_optimizer.step()
70 | 
71 |         for params in self.q_critic.parameters(): params.requires_grad = True
72 | 
73 |         #----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------#
74 |         if self.adaptive_alpha:
75 |             # We learn log_alpha instead of alpha to ensure alpha>0
76 |             alpha_loss = -(self.log_alpha * (log_pi_a + self.target_entropy).detach()).mean()
77 |             self.alpha_optim.zero_grad()
78 |             alpha_loss.backward()
79 |             self.alpha_optim.step()
80 |             self.alpha = self.log_alpha.exp()
81 | 
82 |         #----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------#
83 |         for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
84 |             target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
85 | 
86 |     def save(self,EnvName, timestep):
87 |         torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
88 |         torch.save(self.q_critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
89 | 
90 |     def load(self,EnvName, timestep):
91 |         self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
92 |         self.q_critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
93 | 
94 | 
95 | class ReplayBuffer():
96 |     def __init__(self, state_dim, action_dim, max_size, dvc):
97 |         self.max_size = max_size
98 |         self.dvc = dvc
99 |         self.ptr = 0
100 |         self.size = 0
101 | 
102 |         self.s = torch.zeros((max_size, state_dim) ,dtype=torch.float,device=self.dvc)
103 |         self.a = torch.zeros((max_size, action_dim) ,dtype=torch.float,device=self.dvc)
104 |         self.r = torch.zeros((max_size, 1) ,dtype=torch.float,device=self.dvc)
105 |         self.s_next = torch.zeros((max_size, state_dim) ,dtype=torch.float,device=self.dvc)
106 |         self.dw = torch.zeros((max_size, 1) ,dtype=torch.bool,device=self.dvc)
107 | 
108 |     def add(self, s, a, r, s_next, dw):
109 |         # Add one transition (a single timestep) at a time
110 |         self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
111 |         self.a[self.ptr] = torch.from_numpy(a).to(self.dvc) # Note that a is numpy.array
112 |         self.r[self.ptr] = r
113 |         self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
114 |         self.dw[self.ptr] = dw
115 | 
116 |         self.ptr = (self.ptr + 1) % self.max_size  # when full, overwrite from the beginning (circular buffer)
117 |         self.size = min(self.size + 1, self.max_size)
118 | 
119 |     def sample(self, batch_size):
120 |         ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
121 |         return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
122 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from utils import str2bool, evaluate_policy, Action_adapter, Action_adapter_reverse, Reward_adapter
2 | from datetime import datetime
3 | from SAC import SAC_countinuous
4 | import gymnasium as gym
5 | import os, shutil
6 | import argparse
7 | import torch
8 | 
9 | 
10 | '''Hyperparameter Setting'''
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--dvc', type=str, default='cuda', help='running device: cuda or cpu')
13 | parser.add_argument('--EnvIdex', type=int, default=0, help='PV1, LLdV2, Humanv4, HCv4, BWv3, BWHv3')
14 | parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training')
15 | parser.add_argument('--render', type=str2bool, default=False, help='Render or Not')
16 | parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not')
17 | parser.add_argument('--ModelIdex', type=int, default=100, help='which model to load')
18 | 
19 | parser.add_argument('--seed', type=int, default=0, help='random seed')
20 | parser.add_argument('--Max_train_steps', type=int, default=int(5e6), help='Max training steps')
21 | parser.add_argument('--save_interval', type=int, default=int(100e3), help='Model saving interval, in steps.')
22 | parser.add_argument('--eval_interval', type=int, default=int(2.5e3), help='Model evaluating interval, in steps.')
23 | parser.add_argument('--update_every', type=int, default=50, help='Training frequency, in steps')
24 | 
25 | parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
26 | parser.add_argument('--net_width', type=int, default=256, help='Hidden net width: s_dim-net_width-net_width-a_dim')
27 | parser.add_argument('--a_lr', type=float, default=3e-4, help='Learning rate of actor')
28 | parser.add_argument('--c_lr', type=float, default=3e-4, help='Learning rate of critic')
29 | parser.add_argument('--batch_size', type=int, default=256, help='batch size of training')
30 | parser.add_argument('--alpha', type=float, default=0.12, help='Entropy coefficient')
31 | parser.add_argument('--adaptive_alpha', type=str2bool, default=True, help='Use adaptive_alpha or Not')
32 | opt = parser.parse_args()
33 | opt.dvc = torch.device(opt.dvc) # from str to torch.device
34 | print(opt)
35 | 
36 | 
37 | def main():
38 |     EnvName = ['Pendulum-v1','LunarLanderContinuous-v2','Humanoid-v4','HalfCheetah-v4','BipedalWalker-v3','BipedalWalkerHardcore-v3']
39 |     BrifEnvName = ['PV1', 'LLdV2', 'Humanv4', 'HCv4','BWv3', 'BWHv3']
40 | 
41 |     # Build Env
42 |     env = gym.make(EnvName[opt.EnvIdex], render_mode = "human" if opt.render else None)
43 |     eval_env = gym.make(EnvName[opt.EnvIdex])
44 |     opt.state_dim = env.observation_space.shape[0]
45 |     opt.action_dim = env.action_space.shape[0]
46 |     opt.max_action = float(env.action_space.high[0])  # remark: action space is [-max, max]
47 |     opt.max_e_steps = env._max_episode_steps
48 |     print(f'Env:{EnvName[opt.EnvIdex]} state_dim:{opt.state_dim} action_dim:{opt.action_dim} '
49 |           f'max_a:{opt.max_action} min_a:{env.action_space.low[0]} max_e_steps:{opt.max_e_steps}')
50 | 
51 |     # Seed Everything
52 |     env_seed = opt.seed
53 |     torch.manual_seed(opt.seed)
54 |     torch.cuda.manual_seed(opt.seed)
55 |     torch.backends.cudnn.deterministic = True
56 |     torch.backends.cudnn.benchmark = False
57 |     print("Random Seed: {}".format(opt.seed))
58 | 
59 |     # Build SummaryWriter to record training curves
60 |     if opt.write:
61 |         from torch.utils.tensorboard import SummaryWriter
62 |         timenow = str(datetime.now())[0:-10]
63 |         timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
64 |         writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
65 |         if os.path.exists(writepath): shutil.rmtree(writepath)
66 |         writer = SummaryWriter(log_dir=writepath)
67 | 
68 | 
69 |     # Build DRL model
70 |     if not os.path.exists('model'): os.mkdir('model')
71 |     agent = SAC_countinuous(**vars(opt)) # vars(): convert the argparse Namespace into a dict
72 |     if opt.Loadmodel: agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex)
73 | 
74 |     if opt.render:
75 |         while True:
76 |             score = evaluate_policy(env, opt.max_action, agent, turns=1)
77 |             print('EnvName:', BrifEnvName[opt.EnvIdex], 'score:', score)
78 |     else:
79 |         total_steps = 0
80 |         while total_steps < opt.Max_train_steps:
81 |             s, info = env.reset(seed=env_seed)  # Do not use opt.seed directly, or it can overfit to opt.seed
82 |             env_seed += 1
83 |             done = False
84 | 
85 |             '''Interact & train'''
86 |             while not done:
87 |                 if total_steps < (5*opt.max_e_steps):
88 |                     act = env.action_space.sample()  # act∈[-max,max]
89 |                     a = Action_adapter_reverse(act, opt.max_action)  # a∈[-1,1]
90 |                 else:
91 |                     a = agent.select_action(s, deterministic=False)  # a∈[-1,1]
92 |                     act = Action_adapter(a, opt.max_action)  # act∈[-max,max]
93 |                 s_next, r, dw, tr, info = env.step(act)  # dw: dead&win; tr: truncated
94 |                 r = Reward_adapter(r, opt.EnvIdex)
95 |                 done = (dw or tr)
96 | 
97 |                 agent.replay_buffer.add(s, a, r, s_next, dw)
98 |                 s = s_next
99 |                 total_steps += 1
100 | 
101 |                 '''train if it's time'''
102 |                 # train 50 times every 50 steps, rather than once per step (works better)
103 | if (total_steps >= 2*opt.max_e_steps) and (total_steps % opt.update_every == 0): 104 | for j in range(opt.update_every): 105 | agent.train() 106 | 107 | '''record & log''' 108 | if total_steps % opt.eval_interval == 0: 109 | ep_r = evaluate_policy(eval_env, opt.max_action, agent, turns=3) 110 | if opt.write: writer.add_scalar('ep_r', ep_r, global_step=total_steps) 111 | print(f'EnvName:{BrifEnvName[opt.EnvIdex]}, Steps: {int(total_steps/1000)}k, Episode Reward:{ep_r}') 112 | 113 | '''save model''' 114 | if total_steps % opt.save_interval == 0: 115 | agent.save(BrifEnvName[opt.EnvIdex], int(total_steps/1000)) 116 | env.close() 117 | eval_env.close() 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | --------------------------------------------------------------------------------