├── README.md
├── agent.py
├── buffer.py
├── imgs
│   ├── SAC_discrete_CP.png
│   └── SAC_discrete_LL.png
├── networks.py
├── train.py
└── utils.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SAC Discrete

PyTorch implementation of the discrete Soft Actor-Critic (SAC) algorithm.

# Run
Execute `python train.py`

# Results
### CartPole-v0
![alt_text](imgs/SAC_discrete_CP.png)

### LunarLander-v2
![alt_text](imgs/SAC_discrete_LL.png)
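
# Options
`train.py` exposes a few command-line flags; the ones below mirror the `argparse` options defined in `get_config()` in `train.py` (the values shown are just illustrative):

```
python train.py --env LunarLander-v2 --episodes 300 --seed 1 --batch_size 256 --run_name SAC_LL
```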
--------------------------------------------------------------------------------
/agent.py:
--------------------------------------------------------------------------------
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from networks import Critic, Actor
import copy


class SAC(nn.Module):
    """Interacts with and learns from the environment."""

    def __init__(self,
                 state_size,
                 action_size,
                 device
                 ):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): number of discrete actions
            device (torch.device): device to run the networks on
        """
        super(SAC, self).__init__()
        self.state_size = state_size
        self.action_size = action_size

        self.device = device

        self.gamma = 0.99
        self.tau = 1e-2
        hidden_size = 256
        learning_rate = 5e-4
        self.clip_grad_param = 1

        self.target_entropy = -action_size  # -dim(A)

        self.log_alpha = torch.tensor([0.0], requires_grad=True)
        self.alpha = self.log_alpha.exp().detach()
        self.alpha_optimizer = optim.Adam(params=[self.log_alpha], lr=learning_rate)

        # Actor Network

        self.actor_local = Actor(state_size, action_size, hidden_size).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=learning_rate)

        # Critic Network (w/ Target Network)

        self.critic1 = Critic(state_size, action_size, hidden_size, 2).to(device)
        self.critic2 = Critic(state_size, action_size, hidden_size, 1).to(device)

        # Check that the two critics do not share any parameter tensors.
        assert all(p1 is not p2 for p1, p2 in zip(self.critic1.parameters(), self.critic2.parameters()))

        self.critic1_target = Critic(state_size, action_size, hidden_size).to(device)
        self.critic1_target.load_state_dict(self.critic1.state_dict())

        self.critic2_target = Critic(state_size, action_size, hidden_size).to(device)
        self.critic2_target.load_state_dict(self.critic2.state_dict())

        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=learning_rate)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=learning_rate)


    def get_action(self, state):
        """Returns actions for given state as per current policy."""
        state = torch.from_numpy(state).float().to(self.device)

        with torch.no_grad():
            action = self.actor_local.get_det_action(state)
        return action.numpy()

    def calc_policy_loss(self, states, alpha):
        _, action_probs, log_pis = self.actor_local.evaluate(states)

        q1 = self.critic1(states)
        q2 = self.critic2(states)
        min_Q = torch.min(q1, q2)
        actor_loss = (action_probs * (alpha * log_pis - min_Q)).sum(1).mean()
        log_action_pi = torch.sum(log_pis * action_probs, dim=1)
        return actor_loss, log_action_pi

    def learn(self, step, experiences, gamma, d=1):
        """Updates actor, critics and entropy temperature (alpha) using a batch of experience tuples.
        Q_targets = r + γ * (1 - done) * E_{a'~π}[ min_i Q_target_i(s', a') - α * log π(a'|s') ]
        Critic_loss = MSE(Q(s, a), Q_targets)
        Actor_loss = E_{a~π}[ α * log π(a|s) - min_i Q_i(s, a) ]
        where:
            actor_local(state) -> action probabilities
            critic_target(state) -> Q-values (one per action)
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update actor ---------------------------- #
        current_alpha = copy.deepcopy(self.alpha)
        actor_loss, log_pis = self.calc_policy_loss(states, current_alpha.to(self.device))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Compute alpha loss
        alpha_loss = - (self.log_alpha.exp() * (log_pis.cpu() + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp().detach()

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        with torch.no_grad():
            _, action_probs, log_pis = self.actor_local.evaluate(next_states)
            Q_target1_next = self.critic1_target(next_states)
            Q_target2_next = self.critic2_target(next_states)
            Q_target_next = action_probs * (torch.min(Q_target1_next, Q_target2_next) - self.alpha.to(self.device) * log_pis)

        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (gamma * (1 - dones) * Q_target_next.sum(dim=1).unsqueeze(-1))

        # Compute critic loss
        q1 = self.critic1(states).gather(1, actions.long())
        q2 = self.critic2(states).gather(1, actions.long())

        critic1_loss = 0.5 * F.mse_loss(q1, Q_targets)
        critic2_loss = 0.5 * F.mse_loss(q2, Q_targets)

        # Update critics
        # critic 1
        self.critic1_optimizer.zero_grad()
        critic1_loss.backward(retain_graph=True)
        clip_grad_norm_(self.critic1.parameters(), self.clip_grad_param)
        self.critic1_optimizer.step()
        # critic 2
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        clip_grad_norm_(self.critic2.parameters(), self.clip_grad_param)
        self.critic2_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic1, self.critic1_target)
        self.soft_update(self.critic2, self.critic2_target)

        return actor_loss.item(), alpha_loss.item(), critic1_loss.item(), critic2_loss.item(), current_alpha

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            (τ is taken from self.tau)
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data)
--------------------------------------------------------------------------------
/buffer.py:
--------------------------------------------------------------------------------
import numpy as np
import random
import torch
from collections import deque, namedtuple

class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def __init__(self, buffer_size, batch_size, device):
        """Initialize a ReplayBuffer object.
        Params
        ======
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            device (torch.device): device to place sampled tensors on
        """
        self.device = device
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self):
        """Randomly sample a batch of experiences from memory."""
        experiences = random.sample(self.memory, k=self.batch_size)

        states = torch.from_numpy(np.stack([e.state for e in experiences if e is not None])).float().to(self.device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).float().to(self.device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(self.device)
        next_states = torch.from_numpy(np.stack([e.next_state for e in experiences if e is not None])).float().to(self.device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(self.device)

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)
--------------------------------------------------------------------------------
/imgs/SAC_discrete_CP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BY571/SAC_discrete/d3403e784df073dd0af94b5c01513b95d4c67e20/imgs/SAC_discrete_CP.png
--------------------------------------------------------------------------------
/imgs/SAC_discrete_LL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BY571/SAC_discrete/d3403e784df073dd0af94b5c01513b95d4c67e20/imgs/SAC_discrete_LL.png
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np
import torch.nn.functional as F


def hidden_init(layer):
    fan_in = layer.weight.data.size()[0]
    lim = 1. / np.sqrt(fan_in)
    return (-lim, lim)

class Actor(nn.Module):
    """Actor (Policy) Model."""

    def __init__(self, state_size, action_size, hidden_size=32):
        """Initialize parameters and build model.
        Params
        ======
            state_size (int): Dimension of each state
            action_size (int): Number of discrete actions
            hidden_size (int): Number of nodes in each of the two hidden layers
        """
        super(Actor, self).__init__()

        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, state):

        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        action_probs = self.softmax(self.fc3(x))
        return action_probs

    def evaluate(self, state, epsilon=1e-6):
        action_probs = self.forward(state)

        dist = Categorical(action_probs)
        action = dist.sample().to(state.device)
        # Guard against 0.0 probabilities, since log(0) is undefined
        z = action_probs == 0.0
        z = z.float() * 1e-8
        log_action_probabilities = torch.log(action_probs + z)
        return action.detach().cpu(), action_probs, log_action_probabilities

    def get_action(self, state):
        """
        Samples an action from the categorical policy and returns it together with
        the action probabilities and their (numerically stabilized) log probabilities.
        """
        action_probs = self.forward(state)

        dist = Categorical(action_probs)
        action = dist.sample().to(state.device)
        # Guard against 0.0 probabilities, since log(0) is undefined
        z = action_probs == 0.0
        z = z.float() * 1e-8
        log_action_probabilities = torch.log(action_probs + z)
        return action.detach().cpu(), action_probs, log_action_probabilities

    def get_det_action(self, state):
        # Note: despite the name, this samples from the categorical policy rather than
        # taking the argmax, so exploration is kept during environment rollouts.
        action_probs = self.forward(state)
        dist = Categorical(action_probs)
        action = dist.sample().to(state.device)
        return action.detach().cpu()


class Critic(nn.Module):
    """Critic (Value) Model."""

    def __init__(self, state_size, action_size, hidden_size=32, seed=1):
        """Initialize parameters and build model.
        Params
        ======
            state_size (int): Dimension of each state
            action_size (int): Number of discrete actions (one Q-value per action)
            seed (int): Random seed
            hidden_size (int): Number of nodes in the network layers
        """
        super(Critic, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)
        self.reset_parameters()

    def reset_parameters(self):
        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        self.fc3.weight.data.uniform_(-3e-3, 3e-3)

    def forward(self, state):
        """Build a critic (value) network that maps states -> Q-values (one per action)."""
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------


import gym
import pybullet_envs
import numpy as np
from collections import deque
import torch
import wandb
import argparse
from buffer import ReplayBuffer
import glob
from utils import save, collect_random
import random
from agent import SAC

def get_config():
    parser = argparse.ArgumentParser(description='RL')
    parser.add_argument("--run_name", type=str, default="SAC", help="Run name, default: SAC")
    parser.add_argument("--env", type=str, default="CartPole-v0", help="Gym environment name, default: CartPole-v0")
    parser.add_argument("--episodes", type=int, default=100, help="Number of episodes, default: 100")
    parser.add_argument("--buffer_size", type=int, default=100_000, help="Maximal training dataset size, default: 100_000")
    parser.add_argument("--seed", type=int, default=1, help="Seed, default: 1")
    parser.add_argument("--log_video", type=int, default=0, help="Log agent behaviour to wandb when set to 1, default: 0")
    parser.add_argument("--save_every", type=int, default=100, help="Saves the network every x episodes, default: 100")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size, default: 256")

    args = parser.parse_args()
    return args

def train(config):
    np.random.seed(config.seed)
    random.seed(config.seed)
    torch.manual_seed(config.seed)
    env = gym.make(config.env)

    env.seed(config.seed)
    env.action_space.seed(config.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    steps = 0
    average10 = deque(maxlen=10)
    total_steps = 0

    with wandb.init(project="SAC_Discrete", name=config.run_name, config=config):

        agent = SAC(state_size=env.observation_space.shape[0],
                    action_size=env.action_space.n,
                    device=device)

        wandb.watch(agent, log="gradients", log_freq=10)

        buffer = ReplayBuffer(buffer_size=config.buffer_size, batch_size=config.batch_size, device=device)

        collect_random(env=env, dataset=buffer, num_samples=10000)

        if config.log_video:
            env = gym.wrappers.Monitor(env, './video', video_callable=lambda x: x % 10 == 0, force=True)

        for i in range(1, config.episodes+1):
            state = env.reset()
            episode_steps = 0
            rewards = 0
            while True:
                action = agent.get_action(state)
                steps += 1
                next_state, reward, done, _ = env.step(action)
                buffer.add(state, action, reward, next_state, done)
                policy_loss, alpha_loss, bellman_error1, bellman_error2, current_alpha = agent.learn(steps, buffer.sample(), gamma=0.99)
                state = next_state
                rewards += reward
                episode_steps += 1
                if done:
                    break



            average10.append(rewards)
            total_steps += episode_steps
            print("Episode: {} | Reward: {} | Policy Loss: {} | Steps: {}".format(i, rewards, policy_loss, steps))

            wandb.log({"Reward": rewards,
                       "Average10": np.mean(average10),
                       "Total Steps": total_steps,
                       "Policy Loss": policy_loss,
                       "Alpha Loss": alpha_loss,
                       "Bellman error 1": bellman_error1,
                       "Bellman error 2": bellman_error2,
                       "Alpha": current_alpha,
                       "Steps": steps,
                       "Episode": i,
                       "Buffer size": len(buffer)})

            if (i % 10 == 0) and config.log_video:
                mp4list = glob.glob('video/*.mp4')
                if len(mp4list) > 1:
                    mp4 = mp4list[-2]
                    wandb.log({"gameplays": wandb.Video(mp4, caption='episode: '+str(i-10), fps=4, format="gif"), "Episode": i})

            if i % config.save_every == 0:
                save(config, save_name="SAC_discrete", model=agent.actor_local, wandb=wandb, ep=0)

if __name__ == "__main__":
    config = get_config()
    train(config)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch

def save(args, save_name, model, wandb, ep=None):
    import os
    save_dir = './trained_models/'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if ep is not None:
        torch.save(model.state_dict(), save_dir + args.run_name + save_name + str(ep) + ".pth")
        wandb.save(save_dir + args.run_name + save_name + str(ep) + ".pth")
    else:
        torch.save(model.state_dict(), save_dir + args.run_name + save_name + ".pth")
        wandb.save(save_dir + args.run_name + save_name + ".pth")

def collect_random(env, dataset, num_samples=200):
    state = env.reset()
    for _ in range(num_samples):
        action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)
        dataset.add(state, action, reward, next_state, done)
        state = next_state
        if done:
            state = env.reset()
--------------------------------------------------------------------------------
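For a quick sanity check of a trained policy, a minimal evaluation sketch like the one below can be used; it is not part of the repository. It assumes the default `--run_name SAC`, in which case `utils.save` writes the actor weights to `trained_models/SACSAC_discrete0.pth` (run name + save name + `ep=0`); adjust the path for other run names.

```python
# Minimal evaluation sketch (not part of the repo): load the saved actor and roll out one episode.
import gym
import torch

from networks import Actor

env = gym.make("CartPole-v0")
# hidden_size=256 matches the value hard-coded in agent.SAC
actor = Actor(env.observation_space.shape[0], env.action_space.n, hidden_size=256)
# Path assumes the default run name "SAC" and save_name "SAC_discrete" saved with ep=0
actor.load_state_dict(torch.load("trained_models/SACSAC_discrete0.pth"))
actor.eval()

state = env.reset()
total_reward, done = 0.0, False
while not done:
    with torch.no_grad():
        action = actor.get_det_action(torch.from_numpy(state).float())
    state, reward, done, _ = env.step(action.item())
    total_reward += reward
print("Episode reward:", total_reward)
```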