├── imgs
│   ├── bwh_sacc.png
│   ├── result.jpg
│   └── lldc_sacc.png
├── SAC-Continuous-OLD.zip
├── model
│   ├── PV1_actor10.pth
│   ├── BWv3_actor400.pth
│   ├── PV1_q_critic10.pth
│   ├── BWHv3_actor1400.pth
│   ├── BWv3_q_critic400.pth
│   ├── Humanv4_actor1800.pth
│   ├── BWHv3_q_critic1400.pth
│   └── Humanv4_q_critic1800.pth
├── runs
│   ├── HCv4 2023-11-17 15_18
│   │   └── events.out.tfevents.1700205518.kim.5283.0
│   ├── BWv3 2023-11-16 20_28
│   │   └── events.out.tfevents.1700137718.3060SG.37436.0
│   ├── BWv3 2023-11-17 00_58
│   │   └── events.out.tfevents.1700153892.3060SG.38179.0
│   ├── BWv3 2023-11-17 05_27
│   │   └── events.out.tfevents.1700170052.3060SG.38842.0
│   ├── BWv3 2023-11-17 09_57
│   │   └── events.out.tfevents.1700186236.3060SG.47512.0
│   ├── HCv4 2023-11-17 14_50
│   │   └── events.out.tfevents.1700203837.3060SG.60122.0
│   ├── HCv4 2023-11-17 15_35
│   │   └── events.out.tfevents.1700206511.3060SG.63696.0
│   ├── Humanv4 2023-11-17 18_16
│   │   └── events.out.tfevents.1700216182.kim.5679.0
│   ├── PV1 2023-11-17 18_24
│   │   └── events.out.tfevents.1700216647.3060SG.67549.0
│   ├── BWHv3 2023-11-16 20_51
│   │   └── events.out.tfevents.1700139070.3060SG.37507.0
│   ├── BWHv3 2023-11-17 01_20
│   │   └── events.out.tfevents.1700155227.3060SG.38248.0
│   ├── BWHv3 2023-11-17 05_50
│   │   └── events.out.tfevents.1700171402.3060SG.38905.0
│   ├── BWHv3 2023-11-17 10_19
│   │   └── events.out.tfevents.1700187570.3060SG.47587.0
│   ├── LLdV2 2023-11-16 20_07
│   │   └── events.out.tfevents.1700136473.3060SG.37336.0
│   ├── LLdV2 2023-11-17 00_37
│   │   └── events.out.tfevents.1700152641.3060SG.38105.0
│   ├── LLdV2 2023-11-17 05_06
│   │   └── events.out.tfevents.1700168813.3060SG.38764.0
│   └── LLdV2 2023-11-17 09_36
│       └── events.out.tfevents.1700184997.3060SG.47386.0
├── LICENSE
├── README.md
├── utils.py
├── SAC.py
└── main.py
/imgs/bwh_sacc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/imgs/bwh_sacc.png
--------------------------------------------------------------------------------
/imgs/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/imgs/result.jpg
--------------------------------------------------------------------------------
/imgs/lldc_sacc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/imgs/lldc_sacc.png
--------------------------------------------------------------------------------
/SAC-Continuous-OLD.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/SAC-Continuous-OLD.zip
--------------------------------------------------------------------------------
/model/PV1_actor10.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/PV1_actor10.pth
--------------------------------------------------------------------------------
/model/BWv3_actor400.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/BWv3_actor400.pth
--------------------------------------------------------------------------------
/model/PV1_q_critic10.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/PV1_q_critic10.pth
--------------------------------------------------------------------------------
/model/BWHv3_actor1400.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/BWHv3_actor1400.pth
--------------------------------------------------------------------------------
/model/BWv3_q_critic400.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/BWv3_q_critic400.pth
--------------------------------------------------------------------------------
/model/Humanv4_actor1800.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/Humanv4_actor1800.pth
--------------------------------------------------------------------------------
/model/BWHv3_q_critic1400.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/BWHv3_q_critic1400.pth
--------------------------------------------------------------------------------
/model/Humanv4_q_critic1800.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/model/Humanv4_q_critic1800.pth
--------------------------------------------------------------------------------
/runs/HCv4 2023-11-17 15_18/events.out.tfevents.1700205518.kim.5283.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/HCv4 2023-11-17 15_18/events.out.tfevents.1700205518.kim.5283.0
--------------------------------------------------------------------------------
/runs/BWv3 2023-11-16 20_28/events.out.tfevents.1700137718.3060SG.37436.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWv3 2023-11-16 20_28/events.out.tfevents.1700137718.3060SG.37436.0
--------------------------------------------------------------------------------
/runs/BWv3 2023-11-17 00_58/events.out.tfevents.1700153892.3060SG.38179.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWv3 2023-11-17 00_58/events.out.tfevents.1700153892.3060SG.38179.0
--------------------------------------------------------------------------------
/runs/BWv3 2023-11-17 05_27/events.out.tfevents.1700170052.3060SG.38842.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWv3 2023-11-17 05_27/events.out.tfevents.1700170052.3060SG.38842.0
--------------------------------------------------------------------------------
/runs/BWv3 2023-11-17 09_57/events.out.tfevents.1700186236.3060SG.47512.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWv3 2023-11-17 09_57/events.out.tfevents.1700186236.3060SG.47512.0
--------------------------------------------------------------------------------
/runs/HCv4 2023-11-17 14_50/events.out.tfevents.1700203837.3060SG.60122.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/HCv4 2023-11-17 14_50/events.out.tfevents.1700203837.3060SG.60122.0
--------------------------------------------------------------------------------
/runs/HCv4 2023-11-17 15_35/events.out.tfevents.1700206511.3060SG.63696.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/HCv4 2023-11-17 15_35/events.out.tfevents.1700206511.3060SG.63696.0
--------------------------------------------------------------------------------
/runs/Humanv4 2023-11-17 18_16/events.out.tfevents.1700216182.kim.5679.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/Humanv4 2023-11-17 18_16/events.out.tfevents.1700216182.kim.5679.0
--------------------------------------------------------------------------------
/runs/PV1 2023-11-17 18_24/events.out.tfevents.1700216647.3060SG.67549.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/PV1 2023-11-17 18_24/events.out.tfevents.1700216647.3060SG.67549.0
--------------------------------------------------------------------------------
/runs/BWHv3 2023-11-16 20_51/events.out.tfevents.1700139070.3060SG.37507.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWHv3 2023-11-16 20_51/events.out.tfevents.1700139070.3060SG.37507.0
--------------------------------------------------------------------------------
/runs/BWHv3 2023-11-17 01_20/events.out.tfevents.1700155227.3060SG.38248.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWHv3 2023-11-17 01_20/events.out.tfevents.1700155227.3060SG.38248.0
--------------------------------------------------------------------------------
/runs/BWHv3 2023-11-17 05_50/events.out.tfevents.1700171402.3060SG.38905.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWHv3 2023-11-17 05_50/events.out.tfevents.1700171402.3060SG.38905.0
--------------------------------------------------------------------------------
/runs/BWHv3 2023-11-17 10_19/events.out.tfevents.1700187570.3060SG.47587.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/BWHv3 2023-11-17 10_19/events.out.tfevents.1700187570.3060SG.47587.0
--------------------------------------------------------------------------------
/runs/LLdV2 2023-11-16 20_07/events.out.tfevents.1700136473.3060SG.37336.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/LLdV2 2023-11-16 20_07/events.out.tfevents.1700136473.3060SG.37336.0
--------------------------------------------------------------------------------
/runs/LLdV2 2023-11-17 00_37/events.out.tfevents.1700152641.3060SG.38105.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/LLdV2 2023-11-17 00_37/events.out.tfevents.1700152641.3060SG.38105.0
--------------------------------------------------------------------------------
/runs/LLdV2 2023-11-17 05_06/events.out.tfevents.1700168813.3060SG.38764.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/LLdV2 2023-11-17 05_06/events.out.tfevents.1700168813.3060SG.38764.0
--------------------------------------------------------------------------------
/runs/LLdV2 2023-11-17 09_36/events.out.tfevents.1700184997.3060SG.47386.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XinJingHao/SAC-Continuous-Pytorch/HEAD/runs/LLdV2 2023-11-17 09_36/events.out.tfevents.1700184997.3060SG.47386.0
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 XinJingHao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SAC-Continuous-Pytorch
2 | **A clean and robust PyTorch implementation of Soft Actor-Critic for continuous action spaces.**
3 |
4 | BipedalWalkerHardcore | LunarLanderContinuous
5 | :-----------------------:|:-----------------------:|
6 | ![BipedalWalkerHardcore](imgs/bwh_sacc.png) | ![LunarLanderContinuous](imgs/lldc_sacc.png) |
7 |
8 |
9 |
10 | **Other RL algorithms implemented in PyTorch can be found [here](https://github.com/XinJingHao/RL-Algorithms-by-Pytorch).**
11 |
12 | ## Dependencies
13 | ```python
14 | gymnasium==0.29.1
15 | numpy==1.26.1
16 | pytorch==2.1.0
17 |
18 | python==3.11.5
19 | ```
20 |
21 | ## How to use my code
22 | ### Train from scratch
23 | ```bash
24 | python main.py
25 | ```
26 | where the default environment is 'Pendulum'.
27 |
28 | ### Play with trained model
29 | ```bash
30 | python main.py --EnvIdex 0 --render True --Loadmodel True --ModelIdex 10
31 | ```
32 | which will render the 'Pendulum'.
33 |
34 | ### Change Environment
35 | If you want to train on different environments, just run
36 | ```bash
37 | python main.py --EnvIdex 1
38 | ```
39 | The ```--EnvIdex``` can be set to 0~5, where
40 | ```bash
41 | '--EnvIdex 0' for 'Pendulum-v1'
42 | '--EnvIdex 1' for 'LunarLanderContinuous-v2'
43 | '--EnvIdex 2' for 'Humanoid-v4'
44 | '--EnvIdex 3' for 'HalfCheetah-v4'
45 | '--EnvIdex 4' for 'BipedalWalker-v3'
46 | '--EnvIdex 5' for 'BipedalWalkerHardcore-v3'
47 | ```
48 |
49 | Note: if you want to train on **BipedalWalker, BipedalWalkerHardcore, or LunarLanderContinuous**, you need to install [box2d-py](https://gymnasium.farama.org/environments/box2d/) first. You can install box2d-py via:
50 | ```bash
51 | pip install gymnasium[box2d]
52 | ```
53 |
54 | If you want to train on **Humanoid or HalfCheetah**, you need to install [MuJoCo](https://gymnasium.farama.org/environments/mujoco/) first. You can install MuJoCo via:
55 | ```bash
56 | pip install mujoco
57 | pip install gymnasium[mujoco]
58 | ```
59 |
60 | ### Visualize the training curve
61 | You can use [tensorboard](https://pytorch.org/docs/stable/tensorboard.html) to record and visualize the training curve.
62 |
63 | - Installation (please make sure PyTorch is installed already):
64 | ```bash
65 | pip install tensorboard
66 | pip install packaging
67 | ```
68 | - Record (the training curves will be saved at '**\runs**'):
69 | ```bash
70 | python main.py --write True
71 | ```
72 |
73 | - Visualization:
74 | ```bash
75 | tensorboard --logdir runs
76 | ```
77 |
78 | ### Hyperparameter Setting
79 | For more details of Hyperparameter Setting, please check 'main.py'
80 |
81 | ### Reference
82 | [Soft Actor-Critic Algorithms and Applications](https://arxiv.org/pdf/1812.05905.pdf)
83 |
84 | ## All Training Curves
85 | ![All training curves](imgs/result.jpg)
86 | All the experiments are trained with the same hyperparameters (see main.py).
87 |
--------------------------------------------------------------------------------
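
The 'Play with trained model' workflow in README.md above can also be reproduced programmatically with the `Actor` defined in utils.py below. A minimal sketch, assuming the repository root as the working directory, the shipped checkpoint `model/PV1_actor10.pth`, and the default `net_width` of 256 from main.py:

```python
import gymnasium as gym
import torch
from utils import Actor, Action_adapter

device = torch.device("cpu")
# Pendulum-v1: state_dim=3, action_dim=1, actions in [-2, 2]; hidden widths follow main.py's default net_width=256
actor = Actor(state_dim=3, action_dim=1, hid_shape=(256, 256)).to(device)
actor.load_state_dict(torch.load("model/PV1_actor10.pth", map_location=device))

env = gym.make("Pendulum-v1", render_mode="human")  # human rendering requires pygame
s, info = env.reset(seed=0)
done, ep_r = False, 0.0
while not done:
    with torch.no_grad():
        a, _ = actor(torch.FloatTensor(s[None, :]).to(device), deterministic=True, with_logprob=False)
    act = Action_adapter(a.cpu().numpy()[0], max_action=2.0)  # rescale [-1, 1] -> [-2, 2]
    s, r, dw, tr, info = env.step(act)
    done, ep_r = (dw or tr), ep_r + r
env.close()
print(f"episode reward: {ep_r:.1f}")
```

This mirrors what `python main.py --EnvIdex 0 --render True --Loadmodel True --ModelIdex 10` does, minus the argparse plumbing.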
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.distributions import Normal
7 |
8 | def build_net(layer_shape, hidden_activation, output_activation):
9 | '''Build a fully-connected net from a list of layer widths'''
10 | layers = []
11 | for j in range(len(layer_shape)-1):
12 | act = hidden_activation if j < len(layer_shape)-2 else output_activation
13 | layers += [nn.Linear(layer_shape[j], layer_shape[j+1]), act()]
14 | return nn.Sequential(*layers)
15 |
16 |
17 | class Actor(nn.Module):
18 | def __init__(self, state_dim, action_dim, hid_shape, hidden_activation=nn.ReLU, output_activation=nn.ReLU):
19 | super(Actor, self).__init__()
20 | layers = [state_dim] + list(hid_shape)
21 |
22 | self.a_net = build_net(layers, hidden_activation, output_activation)
23 | self.mu_layer = nn.Linear(layers[-1], action_dim)
24 | self.log_std_layer = nn.Linear(layers[-1], action_dim)
25 |
26 | self.LOG_STD_MAX = 2
27 | self.LOG_STD_MIN = -20
28 |
29 | def forward(self, state, deterministic, with_logprob):
30 | '''Network with Enforcing Action Bounds'''
31 | net_out = self.a_net(state)
32 | mu = self.mu_layer(net_out)
33 | log_std = self.log_std_layer(net_out)
34 | log_std = torch.clamp(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX) # note: this hard clamp may hinder learning
35 | # we learn log_std rather than std, so that exp(log_std) is always > 0
36 | std = torch.exp(log_std)
37 | dist = Normal(mu, std)
38 | if deterministic: u = mu
39 | else: u = dist.rsample()
40 |
41 | '''↓↓↓ Enforcing Action Bounds, see Page 16 of https://arxiv.org/pdf/1812.05905.pdf ↓↓↓'''
42 | a = torch.tanh(u)
43 | if with_logprob:
44 | # Get the log-probability of a from the probability density of u:
45 | # logp_pi_a = (dist.log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(dim=1, keepdim=True)
46 | # Derived from the equation above; avoiding tanh(u) inside the log reduces gradient vanishing and is more stable.
47 | logp_pi_a = dist.log_prob(u).sum(axis=1, keepdim=True) - (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(axis=1, keepdim=True)
48 | else:
49 | logp_pi_a = None
50 |
51 | return a, logp_pi_a
52 |
53 | class Double_Q_Critic(nn.Module):
54 | def __init__(self, state_dim, action_dim, hid_shape):
55 | super(Double_Q_Critic, self).__init__()
56 | layers = [state_dim + action_dim] + list(hid_shape) + [1]
57 |
58 | self.Q_1 = build_net(layers, nn.ReLU, nn.Identity)
59 | self.Q_2 = build_net(layers, nn.ReLU, nn.Identity)
60 |
61 |
62 | def forward(self, state, action):
63 | sa = torch.cat([state, action], 1)
64 | q1 = self.Q_1(sa)
65 | q2 = self.Q_2(sa)
66 | return q1, q2
67 |
68 | #reward engineering for better training
69 | def Reward_adapter(r, EnvIdex):
70 | # For Pendulum-v1
71 | if EnvIdex == 0:
72 | r = (r + 8) / 8
73 |
74 | # For LunarLander
75 | elif EnvIdex == 1:
76 | if r <= -100: r = -10
77 |
78 | # For BipedalWalker
79 | elif EnvIdex == 4 or EnvIdex == 5:
80 | if r <= -100: r = -1
81 | return r
82 |
83 |
84 | def Action_adapter(a,max_action):
85 | #from [-1,1] to [-max,max]
86 | return a*max_action
87 |
88 | def Action_adapter_reverse(act,max_action):
89 | #from [-max,max] to [-1,1]
90 | return act/max_action
91 |
92 |
93 | def evaluate_policy(env, max_action, agent, turns = 3):
94 | total_scores = 0
95 | for j in range(turns):
96 | s, info = env.reset()
97 | done = False
98 | while not done:
99 | # Take deterministic actions at test time
100 | a = agent.select_action(s, deterministic=True)
101 | act = Action_adapter(a, max_action)
102 | s_next, r, dw, tr, info = env.step(act)
103 | done = (dw or tr)
104 |
105 | total_scores += r
106 | s = s_next
107 | return int(total_scores/turns)
108 |
109 |
110 | def str2bool(v):
111 | '''transfer str to bool for argparse'''
112 | if isinstance(v, bool):
113 | return v
114 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
115 | return True
116 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
117 | return False
118 | else:
119 | raise argparse.ArgumentTypeError('Boolean value expected.')
120 |
--------------------------------------------------------------------------------
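
The log-probability correction in `Actor.forward` above relies on the identity log(1 − tanh(u)²) = 2·(log 2 − u − softplus(−2u)), which avoids computing tanh(u) inside a log. A small sketch (plain PyTorch, no repo code needed) that checks the two forms agree:

```python
import math
import torch
import torch.nn.functional as F

u = torch.linspace(-6, 6, steps=13, dtype=torch.float64)

# Direct form of the tanh-squashing correction (the commented-out line in Actor.forward adds 1e-6 for safety)
direct = torch.log(1 - torch.tanh(u).pow(2))

# Stable reformulation actually used in Actor.forward
stable = 2 * (math.log(2) - u - F.softplus(-2 * u))

print(torch.allclose(direct, stable))  # True: same correction, no explicit tanh inside the log
```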
/SAC.py:
--------------------------------------------------------------------------------
1 | from utils import Actor, Double_Q_Critic
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import torch
5 | import copy
6 |
7 |
8 | class SAC_countinuous():
9 | def __init__(self, **kwargs):
10 | # Init hyperparameters for agent, just like "self.gamma = opt.gamma, self.lambd = opt.lambd, ..."
11 | self.__dict__.update(kwargs)
12 | self.tau = 0.005
13 |
14 | self.actor = Actor(self.state_dim, self.action_dim, (self.net_width,self.net_width)).to(self.dvc)
15 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.a_lr)
16 |
17 | self.q_critic = Double_Q_Critic(self.state_dim, self.action_dim, (self.net_width,self.net_width)).to(self.dvc)
18 | self.q_critic_optimizer = torch.optim.Adam(self.q_critic.parameters(), lr=self.c_lr)
19 | self.q_critic_target = copy.deepcopy(self.q_critic)
20 | # Freeze target networks with respect to optimizers (only update via polyak averaging)
21 | for p in self.q_critic_target.parameters():
22 | p.requires_grad = False
23 |
24 | self.replay_buffer = ReplayBuffer(self.state_dim, self.action_dim, max_size=int(1e6), dvc=self.dvc)
25 |
26 | if self.adaptive_alpha:
27 | # Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
28 | self.target_entropy = torch.tensor(-self.action_dim, dtype=float, requires_grad=True, device=self.dvc)
29 | # We learn log_alpha instead of alpha to ensure alpha>0
30 | self.log_alpha = torch.tensor(np.log(self.alpha), dtype=float, requires_grad=True, device=self.dvc)
31 | self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.c_lr)
32 |
33 | def select_action(self, state, deterministic):
34 | # only used when interact with the env
35 | with torch.no_grad():
36 | state = torch.FloatTensor(state[np.newaxis,:]).to(self.dvc)
37 | a, _ = self.actor(state, deterministic, with_logprob=False)
38 | return a.cpu().numpy()[0]
39 |
40 | def train(self,):
41 | s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size)
42 |
43 | #----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------#
44 | with torch.no_grad():
45 | a_next, log_pi_a_next = self.actor(s_next, deterministic=False, with_logprob=True)
46 | target_Q1, target_Q2 = self.q_critic_target(s_next, a_next)
47 | target_Q = torch.min(target_Q1, target_Q2)
48 | target_Q = r + (~dw) * self.gamma * (target_Q - self.alpha * log_pi_a_next) # no bootstrapping when dw (dead & win) is True
49 |
50 | # Get current Q estimates
51 | current_Q1, current_Q2 = self.q_critic(s, a)
52 |
53 | q_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
54 | self.q_critic_optimizer.zero_grad()
55 | q_loss.backward()
56 | self.q_critic_optimizer.step()
57 |
58 | #----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------#
59 | # Freeze the critic so we don't waste compute on its gradients during the actor update
60 | for params in self.q_critic.parameters(): params.requires_grad = False
61 |
62 | a, log_pi_a = self.actor(s, deterministic=False, with_logprob=True)
63 | current_Q1, current_Q2 = self.q_critic(s, a)
64 | Q = torch.min(current_Q1, current_Q2)
65 |
66 | a_loss = (self.alpha * log_pi_a - Q).mean()
67 | self.actor_optimizer.zero_grad()
68 | a_loss.backward()
69 | self.actor_optimizer.step()
70 |
71 | for params in self.q_critic.parameters(): params.requires_grad = True
72 |
73 | #----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------#
74 | if self.adaptive_alpha:
75 | # We learn log_alpha instead of alpha to ensure alpha>0
76 | alpha_loss = -(self.log_alpha * (log_pi_a + self.target_entropy).detach()).mean()
77 | self.alpha_optim.zero_grad()
78 | alpha_loss.backward()
79 | self.alpha_optim.step()
80 | self.alpha = self.log_alpha.exp()
81 |
82 | #----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------#
83 | for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
84 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
85 |
86 | def save(self,EnvName, timestep):
87 | torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
88 | torch.save(self.q_critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
89 |
90 | def load(self,EnvName, timestep):
91 | self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
92 | self.q_critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
93 |
94 |
95 | class ReplayBuffer():
96 | def __init__(self, state_dim, action_dim, max_size, dvc):
97 | self.max_size = max_size
98 | self.dvc = dvc
99 | self.ptr = 0
100 | self.size = 0
101 |
102 | self.s = torch.zeros((max_size, state_dim) ,dtype=torch.float,device=self.dvc)
103 | self.a = torch.zeros((max_size, action_dim) ,dtype=torch.float,device=self.dvc)
104 | self.r = torch.zeros((max_size, 1) ,dtype=torch.float,device=self.dvc)
105 | self.s_next = torch.zeros((max_size, state_dim) ,dtype=torch.float,device=self.dvc)
106 | self.dw = torch.zeros((max_size, 1) ,dtype=torch.bool,device=self.dvc)
107 |
108 | def add(self, s, a, r, s_next, dw):
109 | # Only one transition (a single timestep) is added per call
110 | self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
111 | self.a[self.ptr] = torch.from_numpy(a).to(self.dvc) # Note that a is numpy.array
112 | self.r[self.ptr] = r
113 | self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
114 | self.dw[self.ptr] = dw
115 |
116 | self.ptr = (self.ptr + 1) % self.max_size # overwrite from the beginning once the buffer is full
117 | self.size = min(self.size + 1, self.max_size)
118 |
119 | def sample(self, batch_size):
120 | ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
121 | return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
122 |
--------------------------------------------------------------------------------
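
`SAC_countinuous` reads its hyperparameters from `**kwargs` (main.py passes the whole argparse namespace via `vars(opt)`). A minimal smoke-test sketch, assuming Pendulum-like dimensions (state_dim=3, action_dim=1) and a deliberately small batch_size; only the keys the class actually uses are supplied:

```python
import numpy as np
import torch
from SAC import SAC_countinuous

agent = SAC_countinuous(
    state_dim=3, action_dim=1, net_width=256, dvc=torch.device("cpu"),
    a_lr=3e-4, c_lr=3e-4, batch_size=32, gamma=0.99,
    alpha=0.12, adaptive_alpha=True,
)

# Fill the replay buffer with random transitions so sample() has something to draw from
for _ in range(64):
    s = np.random.randn(3).astype(np.float32)
    a = np.random.uniform(-1, 1, size=1).astype(np.float32)  # actions are stored in [-1, 1]
    s_next = np.random.randn(3).astype(np.float32)
    agent.replay_buffer.add(s, a, float(np.random.randn()), s_next, False)

agent.train()  # one gradient step each for the twin critics, the actor, alpha, and the target net
a = agent.select_action(np.zeros(3, dtype=np.float32), deterministic=True)
print(a.shape)  # (1,), still in [-1, 1]; Action_adapter maps it to the env's action range
```

In the real training loop, main.py below only starts calling train() after 2*max_e_steps of environment interaction.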
/main.py:
--------------------------------------------------------------------------------
1 | from utils import str2bool, evaluate_policy, Action_adapter, Action_adapter_reverse, Reward_adapter
2 | from datetime import datetime
3 | from SAC import SAC_countinuous
4 | import gymnasium as gym
5 | import os, shutil
6 | import argparse
7 | import torch
8 |
9 |
10 | '''Hyperparameter Setting'''
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--dvc', type=str, default='cuda', help='running device: cuda or cpu')
13 | parser.add_argument('--EnvIdex', type=int, default=0, help='PV1, LLdV2, Humanv4, HCv4, BWv3, BWHv3')
14 | parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training')
15 | parser.add_argument('--render', type=str2bool, default=False, help='Render or Not')
16 | parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not')
17 | parser.add_argument('--ModelIdex', type=int, default=100, help='which model to load')
18 |
19 | parser.add_argument('--seed', type=int, default=0, help='random seed')
20 | parser.add_argument('--Max_train_steps', type=int, default=int(5e6), help='Max training steps')
21 | parser.add_argument('--save_interval', type=int, default=int(100e3), help='Model saving interval, in steps.')
22 | parser.add_argument('--eval_interval', type=int, default=int(2.5e3), help='Model evaluating interval, in steps.')
23 | parser.add_argument('--update_every', type=int, default=50, help='Training frequency, in steps')
24 |
25 | parser.add_argument('--gamma', type=float, default=0.99, help='Discounted Factor')
26 | parser.add_argument('--net_width', type=int, default=256, help='Hidden net width, s_dim-net_width-net_width-a_dim')
27 | parser.add_argument('--a_lr', type=float, default=3e-4, help='Learning rate of actor')
28 | parser.add_argument('--c_lr', type=float, default=3e-4, help='Learning rate of critic')
29 | parser.add_argument('--batch_size', type=int, default=256, help='batch_size of training')
30 | parser.add_argument('--alpha', type=float, default=0.12, help='Entropy coefficient')
31 | parser.add_argument('--adaptive_alpha', type=str2bool, default=True, help='Use adaptive_alpha or Not')
32 | opt = parser.parse_args()
33 | opt.dvc = torch.device(opt.dvc) # from str to torch.device
34 | print(opt)
35 |
36 |
37 | def main():
38 | EnvName = ['Pendulum-v1','LunarLanderContinuous-v2','Humanoid-v4','HalfCheetah-v4','BipedalWalker-v3','BipedalWalkerHardcore-v3']
39 | BrifEnvName = ['PV1', 'LLdV2', 'Humanv4', 'HCv4','BWv3', 'BWHv3']
40 |
41 | # Build Env
42 | env = gym.make(EnvName[opt.EnvIdex], render_mode = "human" if opt.render else None)
43 | eval_env = gym.make(EnvName[opt.EnvIdex])
44 | opt.state_dim = env.observation_space.shape[0]
45 | opt.action_dim = env.action_space.shape[0]
46 | opt.max_action = float(env.action_space.high[0]) # remark: action space [-max, max]
47 | opt.max_e_steps = env._max_episode_steps
48 | print(f'Env:{EnvName[opt.EnvIdex]} state_dim:{opt.state_dim} action_dim:{opt.action_dim} '
49 | f'max_a:{opt.max_action} min_a:{env.action_space.low[0]} max_e_steps:{opt.max_e_steps}')
50 |
51 | # Seed Everything
52 | env_seed = opt.seed
53 | torch.manual_seed(opt.seed)
54 | torch.cuda.manual_seed(opt.seed)
55 | torch.backends.cudnn.deterministic = True
56 | torch.backends.cudnn.benchmark = False
57 | print("Random Seed: {}".format(opt.seed))
58 |
59 | # Build SummaryWriter to record training curves
60 | if opt.write:
61 | from torch.utils.tensorboard import SummaryWriter
62 | timenow = str(datetime.now())[0:-10]
63 | timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
64 | writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
65 | if os.path.exists(writepath): shutil.rmtree(writepath)
66 | writer = SummaryWriter(log_dir=writepath)
67 |
68 |
69 | # Build DRL model
70 | if not os.path.exists('model'): os.mkdir('model')
71 | agent = SAC_countinuous(**vars(opt)) # vars(): convert the argparse namespace to a dict
72 | if opt.Loadmodel: agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex)
73 |
74 | if opt.render:
75 | while True:
76 | score = evaluate_policy(env, opt.max_action, agent, turns=1)
77 | print('EnvName:', BrifEnvName[opt.EnvIdex], 'score:', score)
78 | else:
79 | total_steps = 0
80 | while total_steps < opt.Max_train_steps:
81 | s, info = env.reset(seed=env_seed) # Do not use opt.seed directly, or it can overfit to opt.seed
82 | env_seed += 1
83 | done = False
84 |
85 | '''Interact & train'''
86 | while not done:
87 | if total_steps < (5*opt.max_e_steps):
88 | act = env.action_space.sample() # act∈[-max,max]
89 | a = Action_adapter_reverse(act, opt.max_action) # a∈[-1,1]
90 | else:
91 | a = agent.select_action(s, deterministic=False) # a∈[-1,1]
92 | act = Action_adapter(a, opt.max_action) # act∈[-max,max]
93 | s_next, r, dw, tr, info = env.step(act) # dw: dead&win; tr: truncated
94 | r = Reward_adapter(r, opt.EnvIdex)
95 | done = (dw or tr)
96 |
97 | agent.replay_buffer.add(s, a, r, s_next, dw)
98 | s = s_next
99 | total_steps += 1
100 |
101 | '''train if it's time'''
102 | # train 50 times every 50 steps rather than 1 training per step. Better!
103 | if (total_steps >= 2*opt.max_e_steps) and (total_steps % opt.update_every == 0):
104 | for j in range(opt.update_every):
105 | agent.train()
106 |
107 | '''record & log'''
108 | if total_steps % opt.eval_interval == 0:
109 | ep_r = evaluate_policy(eval_env, opt.max_action, agent, turns=3)
110 | if opt.write: writer.add_scalar('ep_r', ep_r, global_step=total_steps)
111 | print(f'EnvName:{BrifEnvName[opt.EnvIdex]}, Steps: {int(total_steps/1000)}k, Episode Reward:{ep_r}')
112 |
113 | '''save model'''
114 | if total_steps % opt.save_interval == 0:
115 | agent.save(BrifEnvName[opt.EnvIdex], int(total_steps/1000))
116 | env.close()
117 | eval_env.close()
118 |
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
--------------------------------------------------------------------------------
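
For reference, the updates that main.py drives through `agent.train()` (implemented in SAC.py above) can be summarized as follows, with $\bar Q_i$ the target critics, $d$ the stored dead-and-win flag, $\bar{\mathcal H} = -\dim(\mathcal A)$ the target entropy, and $\tau = 0.005$:

$$
\begin{aligned}
y &= r + \gamma\,(1-d)\Big(\min_{i=1,2}\bar Q_i(s',a') - \alpha \log \pi(a'\mid s')\Big), \qquad a' \sim \pi(\cdot\mid s'),\\
L_Q &= \big(Q_1(s,a)-y\big)^2 + \big(Q_2(s,a)-y\big)^2,\\
L_\pi &= \mathbb{E}\Big[\alpha \log \pi(\tilde a\mid s) - \min_{i=1,2} Q_i(s,\tilde a)\Big], \qquad \tilde a \sim \pi(\cdot\mid s),\\
L_\alpha &= -\,\mathbb{E}\Big[\log\alpha\,\big(\log \pi(\tilde a\mid s) + \bar{\mathcal H}\big)\Big], \qquad
\bar\theta \leftarrow \tau\,\theta + (1-\tau)\,\bar\theta .
\end{aligned}
$$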