├── run_experiments.sh
├── LICENSE
├── README.md
├── utils.py
├── main.py
└── TD3_BC.py

/run_experiments.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Script to reproduce results

envs=(
    "halfcheetah-random-v0"
    "hopper-random-v0"
    "walker2d-random-v0"
    "halfcheetah-medium-v0"
    "hopper-medium-v0"
    "walker2d-medium-v0"
    "halfcheetah-expert-v0"
    "hopper-expert-v0"
    "walker2d-expert-v0"
    "halfcheetah-medium-expert-v0"
    "hopper-medium-expert-v0"
    "walker2d-medium-expert-v0"
    "halfcheetah-medium-replay-v0"
    "hopper-medium-replay-v0"
    "walker2d-medium-replay-v0"
)

for ((i=0;i<5;i+=1))
do
    for env in ${envs[*]}
    do
        python main.py \
            --env $env \
            --seed $i
    done
done

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Scott Fujimoto

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A Minimalist Approach to Offline Reinforcement Learning

TD3+BC is a simple approach to offline RL in which only two changes are made to TD3: (1) a weighted behavior cloning loss is added to the policy update and (2) the states are normalized. Unlike competing methods, there are no changes to the architecture or the underlying hyperparameters. The paper can be found [here](https://arxiv.org/abs/2106.06860).

### Usage
Paper results were collected with [MuJoCo 1.50](http://www.mujoco.org/) (and [mujoco-py 1.50.1.1](https://github.com/openai/mujoco-py)) in [OpenAI gym 0.17.0](https://github.com/openai/gym) with the [D4RL datasets](https://github.com/rail-berkeley/d4rl). Networks are trained using [PyTorch 1.4.0](https://github.com/pytorch/pytorch) and Python 3.6.
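
For reference, the weighted behavior cloning loss mentioned above reduces to a small change in the TD3 actor update (see `TD3_BC.py`). The snippet below is only an illustrative sketch; the state/action dimensions and the random batch are placeholders:

```
# Illustrative sketch of the TD3+BC actor objective; dimensions and data are dummies.
import torch
import torch.nn.functional as F
import TD3_BC

device = "cuda" if torch.cuda.is_available() else "cpu"
policy = TD3_BC.TD3_BC(state_dim=17, action_dim=6, max_action=1.0, alpha=2.5)
state = torch.randn(256, 17, device=device)   # batch of normalized states
action = torch.randn(256, 6, device=device)   # dataset actions for those states

pi = policy.actor(state)
Q = policy.critic.Q1(state, pi)
lmbda = policy.alpha / Q.abs().mean().detach()           # lambda = alpha / mean|Q|
actor_loss = -lmbda * Q.mean() + F.mse_loss(pi, action)  # RL term + behavior cloning term
```

The only hyperparameter added on top of TD3 is `alpha`, which controls how strongly the Q term is weighted against the behavior cloning term.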

The paper results can be reproduced by running:
```
./run_experiments.sh
```


### Bibtex
```
@inproceedings{fujimoto2021minimalist,
  title={A Minimalist Approach to Offline Reinforcement Learning},
  author={Scott Fujimoto and Shixiang Shane Gu},
  booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
  year={2021},
}
```

---
*This is not an official Google product.*

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


class ReplayBuffer(object):
    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0

        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.next_state = np.zeros((max_size, state_dim))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    def add(self, state, action, next_state, reward, done):
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)


    def sample(self, batch_size):
        ind = np.random.randint(0, self.size, size=batch_size)

        return (
            torch.FloatTensor(self.state[ind]).to(self.device),
            torch.FloatTensor(self.action[ind]).to(self.device),
            torch.FloatTensor(self.next_state[ind]).to(self.device),
            torch.FloatTensor(self.reward[ind]).to(self.device),
            torch.FloatTensor(self.not_done[ind]).to(self.device)
        )


    def convert_D4RL(self, dataset):
        # Replace the buffer contents with a full D4RL transition dataset
        self.state = dataset['observations']
        self.action = dataset['actions']
        self.next_state = dataset['next_observations']
        self.reward = dataset['rewards'].reshape(-1, 1)
        self.not_done = 1. - dataset['terminals'].reshape(-1, 1)
        self.size = self.state.shape[0]


    def normalize_states(self, eps=1e-3):
        # Normalize states to zero mean / unit variance and return the statistics
        mean = self.state.mean(0, keepdims=True)
        std = self.state.std(0, keepdims=True) + eps
        self.state = (self.state - mean) / std
        self.next_state = (self.next_state - mean) / std
        return mean, std

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import gym
import argparse
import os
import d4rl

import utils
import TD3_BC


# Runs policy for X episodes and returns D4RL score
# A fixed seed is used for the eval environment
def eval_policy(policy, env_name, seed, mean, std, seed_offset=100, eval_episodes=10):
    eval_env = gym.make(env_name)
    eval_env.seed(seed + seed_offset)

    avg_reward = 0.
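    # Evaluation reuses the mean/std computed from the offline dataset, so the
    # policy sees states normalized exactly as they were during training.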
    for _ in range(eval_episodes):
        state, done = eval_env.reset(), False
        while not done:
            state = (np.array(state).reshape(1, -1) - mean) / std
            action = policy.select_action(state)
            state, reward, done, _ = eval_env.step(action)
            avg_reward += reward

    avg_reward /= eval_episodes
    d4rl_score = eval_env.get_normalized_score(avg_reward) * 100

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}, D4RL score: {d4rl_score:.3f}")
    print("---------------------------------------")
    return d4rl_score


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--policy", default="TD3_BC")               # Policy name
    parser.add_argument("--env", default="hopper-medium-v0")        # OpenAI gym environment name
    parser.add_argument("--seed", default=0, type=int)              # Sets Gym, PyTorch and Numpy seeds
    parser.add_argument("--eval_freq", default=5e3, type=int)       # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=1e6, type=int)   # Max time steps to run environment
    parser.add_argument("--save_model", action="store_true")        # Save model and optimizer parameters
    parser.add_argument("--load_model", default="")                 # Model load file name, "" doesn't load, "default" uses file_name
    # TD3
    parser.add_argument("--expl_noise", default=0.1, type=float)    # Std of Gaussian exploration noise
    parser.add_argument("--batch_size", default=256, type=int)      # Batch size for both actor and critic
    parser.add_argument("--discount", default=0.99, type=float)     # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)         # Target network update rate
    parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to target policy during critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)    # Range to clip target policy noise
    parser.add_argument("--policy_freq", default=2, type=int)       # Frequency of delayed policy updates
    # TD3 + BC
    parser.add_argument("--alpha", default=2.5, type=float)         # Weighting of the RL (Q) term relative to the BC term
    parser.add_argument("--normalize", default=True)                # Normalize states with dataset statistics
    args = parser.parse_args()

    file_name = f"{args.policy}_{args.env}_{args.seed}"
    print("---------------------------------------")
    print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    if not os.path.exists("./results"):
        os.makedirs("./results")

    if args.save_model and not os.path.exists("./models"):
        os.makedirs("./models")

    env = gym.make(args.env)

    # Set seeds
    env.seed(args.seed)
    env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        "discount": args.discount,
        "tau": args.tau,
        # TD3
        "policy_noise": args.policy_noise * max_action,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        # TD3 + BC
        "alpha": args.alpha
    }

    # Initialize policy
    policy = TD3_BC.TD3_BC(**kwargs)

    if args.load_model != "":
        policy_file = file_name if args.load_model == "default" else args.load_model
        policy.load(f"./models/{policy_file}")

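    # The entire D4RL dataset is loaded into the replay buffer up front; training
    # only samples from this fixed buffer, and the state mean/std computed here are
    # passed to eval_policy so evaluation uses the same normalization.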
    replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
    replay_buffer.convert_D4RL(d4rl.qlearning_dataset(env))
    if args.normalize:
        mean, std = replay_buffer.normalize_states()
    else:
        mean, std = 0, 1

    evaluations = []
    for t in range(int(args.max_timesteps)):
        policy.train(replay_buffer, args.batch_size)
        # Evaluate episode
        if (t + 1) % args.eval_freq == 0:
            print(f"Time steps: {t+1}")
            evaluations.append(eval_policy(policy, args.env, args.seed, mean, std))
            np.save(f"./results/{file_name}", evaluations)
            if args.save_model: policy.save(f"./models/{file_name}")

--------------------------------------------------------------------------------
/TD3_BC.py:
--------------------------------------------------------------------------------
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()

        self.l1 = nn.Linear(state_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, action_dim)

        self.max_action = max_action


    def forward(self, state):
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        return self.max_action * torch.tanh(self.l3(a))


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()

        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)


    def forward(self, state, action):
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)

        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2


    def Q1(self, state, action):
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1


class TD3_BC(object):
    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        policy_freq=2,
        alpha=2.5,
    ):

        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.alpha = alpha

        self.total_it = 0


    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()


    def train(self, replay_buffer, batch_size=256):
        self.total_it += 1

        # Sample replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (
                torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            next_action = (
                self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)

            # Compute the target Q value
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.discount * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:

            # Compute actor loss
            pi = self.actor(state)
            Q = self.critic.Q1(state, pi)
            lmbda = self.alpha / Q.abs().mean().detach()  # scales the Q term so it is comparable to the BC term

            actor_loss = -lmbda * Q.mean() + F.mse_loss(pi, action)

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)


    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")


    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)

--------------------------------------------------------------------------------