├── README.md ├── img ├── car_racing_demo_ppo.gif ├── car_racing_ppo.png ├── car_racing_ppo.svg └── network.png ├── param └── ppo_net_params.pkl ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Car Racing with PyTorch 2 | Solving the car racing problem in OpenAI Gym using Proximal Policy Optimization (PPO). The environment is backed by a real physics engine, so realistic racing maneuvers such as drifting are possible. 3 | 4 | ## Requirement 5 | To run the code, you need 6 | - [pytorch 0.4](https://pytorch.org/) 7 | - [gym 0.10](https://github.com/openai/gym) 8 | - [visdom 0.1.8](https://github.com/facebookresearch/visdom) 9 | 10 | ## Method 11 | Every action is repeated for 8 frames. To capture velocity information, the state is defined as 4 adjacent frames with shape (4, 96, 96). A two-headed FCN represents the actor and the critic. The actor outputs α and β for each action as the parameters of a Beta distribution. 12 |
13 | 14 | ## Training 15 | Start a Visdom server with ```python -m visdom.server```; by default it serves http://localhost:8097/. 16 | 17 | To train the agent, run ```python train.py --render --vis```, or ```python train.py --render``` to train without visdom. 18 | To test, run ```python test.py --render```. 19 | 20 | ## Performance 21 |
22 |
23 | 24 | -------------------------------------------------------------------------------- /img/car_racing_demo_ppo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/pytorch_car_caring/1631ad17b6c7222644988d30bedd4c44bc5453ac/img/car_racing_demo_ppo.gif -------------------------------------------------------------------------------- /img/car_racing_ppo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/pytorch_car_caring/1631ad17b6c7222644988d30bedd4c44bc5453ac/img/car_racing_ppo.png -------------------------------------------------------------------------------- /img/car_racing_ppo.svg: -------------------------------------------------------------------------------- 1 | 0500100015000200400600800PPOEpisodeMoving averaged episode reward -------------------------------------------------------------------------------- /img/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/pytorch_car_caring/1631ad17b6c7222644988d30bedd4c44bc5453ac/img/network.png -------------------------------------------------------------------------------- /param/ppo_net_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtma/pytorch_car_caring/1631ad17b6c7222644988d30bedd4c44bc5453ac/param/ppo_net_params.pkl -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import gym 6 | import torch 7 | import torch.nn as nn 8 | 9 | parser = argparse.ArgumentParser(description='Test the PPO agent for the CarRacing-v0') 10 | parser.add_argument('--action-repeat', type=int, default=8, metavar='N', help='repeat action in N frames (default: 
12)') 11 | parser.add_argument('--img-stack', type=int, default=4, metavar='N', help='stack N image in a state (default: 4)') 12 | parser.add_argument('--seed', type=int, default=0, metavar='N', help='random seed (default: 0)') 13 | parser.add_argument('--render', action='store_true', help='render the environment') 14 | args = parser.parse_args() 15 | 16 | use_cuda = torch.cuda.is_available() 17 | device = torch.device("cuda" if use_cuda else "cpu") 18 | torch.manual_seed(args.seed) 19 | if use_cuda: 20 | torch.cuda.manual_seed(args.seed) 21 | 22 | 23 | class Env(): 24 | """ 25 | Test environment wrapper for CarRacing 26 | """ 27 | 28 | def __init__(self): 29 | self.env = gym.make('CarRacing-v0') 30 | self.env.seed(args.seed) 31 | self.reward_threshold = self.env.spec.reward_threshold 32 | 33 | def reset(self): 34 | self.counter = 0 35 | self.av_r = self.reward_memory() 36 | 37 | self.die = False 38 | img_rgb = self.env.reset() 39 | img_gray = self.rgb2gray(img_rgb) 40 | self.stack = [img_gray] * args.img_stack 41 | return np.array(self.stack) 42 | 43 | def step(self, action): 44 | total_reward = 0 45 | for i in range(args.action_repeat): 46 | img_rgb, reward, die, _ = self.env.step(action) 47 | # don't penalize "die state" 48 | if die: 49 | reward += 100 50 | # green penalty 51 | if np.mean(img_rgb[:, :, 1]) > 185.0: 52 | reward -= 0.05 53 | total_reward += reward 54 | # if no reward recently, end the episode 55 | done = True if self.av_r(reward) <= -0.1 else False 56 | if done or die: 57 | break 58 | img_gray = self.rgb2gray(img_rgb) 59 | self.stack.pop(0) 60 | self.stack.append(img_gray) 61 | assert len(self.stack) == args.img_stack 62 | return np.array(self.stack), total_reward, done, die 63 | 64 | def render(self, *arg): 65 | self.env.render(*arg) 66 | 67 | @staticmethod 68 | def rgb2gray(rgb, norm=True): 69 | gray = np.dot(rgb[..., :], [0.299, 0.587, 0.114]) 70 | if norm: 71 | # normalize 72 | gray = gray / 128. - 1. 
73 | return gray 74 | 75 | @staticmethod 76 | def reward_memory(): 77 | count = 0 78 | length = 100 79 | history = np.zeros(length) 80 | 81 | def memory(reward): 82 | nonlocal count 83 | history[count] = reward 84 | count = (count + 1) % length 85 | return np.mean(history) 86 | 87 | return memory 88 | 89 | 90 | class Net(nn.Module): 91 | """ 92 | Actor-Critic Network for PPO 93 | """ 94 | 95 | def __init__(self): 96 | super(Net, self).__init__() 97 | self.cnn_base = nn.Sequential( # input shape (4, 96, 96) 98 | nn.Conv2d(args.img_stack, 8, kernel_size=4, stride=2), 99 | nn.ReLU(), # activation 100 | nn.Conv2d(8, 16, kernel_size=3, stride=2), # (8, 47, 47) 101 | nn.ReLU(), # activation 102 | nn.Conv2d(16, 32, kernel_size=3, stride=2), # (16, 23, 23) 103 | nn.ReLU(), # activation 104 | nn.Conv2d(32, 64, kernel_size=3, stride=2), # (32, 11, 11) 105 | nn.ReLU(), # activation 106 | nn.Conv2d(64, 128, kernel_size=3, stride=1), # (64, 5, 5) 107 | nn.ReLU(), # activation 108 | nn.Conv2d(128, 256, kernel_size=3, stride=1), # (128, 3, 3) 109 | nn.ReLU(), # activation 110 | ) # output shape (256, 1, 1) 111 | self.v = nn.Sequential(nn.Linear(256, 100), nn.ReLU(), nn.Linear(100, 1)) 112 | self.fc = nn.Sequential(nn.Linear(256, 100), nn.ReLU()) 113 | self.alpha_head = nn.Sequential(nn.Linear(100, 3), nn.Softplus()) 114 | self.beta_head = nn.Sequential(nn.Linear(100, 3), nn.Softplus()) 115 | self.apply(self._weights_init) 116 | 117 | @staticmethod 118 | def _weights_init(m): 119 | if isinstance(m, nn.Conv2d): 120 | nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu')) 121 | nn.init.constant_(m.bias, 0.1) 122 | 123 | def forward(self, x): 124 | x = self.cnn_base(x) 125 | x = x.view(-1, 256) 126 | v = self.v(x) 127 | x = self.fc(x) 128 | alpha = self.alpha_head(x) + 1 129 | beta = self.beta_head(x) + 1 130 | 131 | return (alpha, beta), v 132 | 133 | 134 | class Agent(): 135 | """ 136 | Agent for testing 137 | """ 138 | 139 | def __init__(self): 140 | self.net = 
Net().float().to(device) 141 | 142 | def select_action(self, state): 143 | state = torch.from_numpy(state).float().to(device).unsqueeze(0) 144 | with torch.no_grad(): 145 | alpha, beta = self.net(state)[0] 146 | action = alpha / (alpha + beta) 147 | 148 | action = action.squeeze().cpu().numpy() 149 | return action 150 | 151 | def load_param(self): 152 | self.net.load_state_dict(torch.load('param/ppo_net_params.pkl')) 153 | 154 | 155 | if __name__ == "__main__": 156 | agent = Agent() 157 | agent.load_param() 158 | env = Env() 159 | 160 | training_records = [] 161 | running_score = 0 162 | state = env.reset() 163 | for i_ep in range(10): 164 | score = 0 165 | state = env.reset() 166 | 167 | for t in range(1000): 168 | action = agent.select_action(state) 169 | state_, reward, done, die = env.step(action * np.array([2., 1., 1.]) + np.array([-1., 0., 0.])) 170 | if args.render: 171 | env.render() 172 | score += reward 173 | state = state_ 174 | if done or die: 175 | break 176 | 177 | print('Ep {}\tScore: {:.2f}\t'.format(i_ep, score)) 178 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import gym 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.distributions import Beta 11 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 12 | from utils import DrawLine 13 | 14 | parser = argparse.ArgumentParser(description='Train a PPO agent for the CarRacing-v0') 15 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G', help='discount factor (default: 0.99)') 16 | parser.add_argument('--action-repeat', type=int, default=8, metavar='N', help='repeat action in N frames (default: 8)') 17 | parser.add_argument('--img-stack', type=int, default=4, metavar='N', help='stack N image in a state 
(default: 4)') 18 | parser.add_argument('--seed', type=int, default=0, metavar='N', help='random seed (default: 0)') 19 | parser.add_argument('--render', action='store_true', help='render the environment') 20 | parser.add_argument('--vis', action='store_true', help='use visdom') 21 | parser.add_argument( 22 | '--log-interval', type=int, default=10, metavar='N', help='interval between training status logs (default: 10)') 23 | args = parser.parse_args() 24 | 25 | use_cuda = torch.cuda.is_available() 26 | device = torch.device("cuda" if use_cuda else "cpu") 27 | torch.manual_seed(args.seed) 28 | if use_cuda: 29 | torch.cuda.manual_seed(args.seed) 30 | 31 | transition = np.dtype([('s', np.float64, (args.img_stack, 96, 96)), ('a', np.float64, (3,)), ('a_logp', np.float64), 32 | ('r', np.float64), ('s_', np.float64, (args.img_stack, 96, 96))]) 33 | 34 | 35 | class Env(): 36 | """ 37 | Environment wrapper for CarRacing: action repeat, grayscale frame stacking and reward shaping (die bonus, green-area penalty, early stop on low average reward) 38 | """ 39 | 40 | def __init__(self): 41 | self.env = gym.make('CarRacing-v0') 42 | self.env.seed(args.seed) 43 | self.reward_threshold = self.env.spec.reward_threshold 44 | 45 | def reset(self): 46 | self.counter = 0 47 | self.av_r = self.reward_memory() 48 | 49 | self.die = False 50 | img_rgb = self.env.reset() 51 | img_gray = self.rgb2gray(img_rgb) 52 | self.stack = [img_gray] * args.img_stack # args.img_stack copies of the first frame form the initial state 53 | return np.array(self.stack) 54 | 55 | def step(self, action): 56 | total_reward = 0 57 | for i in range(args.action_repeat): 58 | img_rgb, reward, die, _ = self.env.step(action) 59 | # don't penalize "die state" 60 | if die: 61 | reward += 100 62 | # green penalty 63 | if np.mean(img_rgb[:, :, 1]) > 185.0: 64 | reward -= 0.05 65 | total_reward += reward 66 | # if no reward recently, end the episode 67 | done = True if self.av_r(reward) <= -0.1 else False 68 | if done or die: 69 | break 70 | img_gray = self.rgb2gray(img_rgb) 71 | self.stack.pop(0) 72 | self.stack.append(img_gray) 73 | assert len(self.stack) == args.img_stack 74 |
return np.array(self.stack), total_reward, done, die 75 | 76 | def render(self, *arg): 77 | self.env.render(*arg) 78 | 79 | @staticmethod 80 | def rgb2gray(rgb, norm=True): 81 | # rgb image -> gray via luma weights; with norm=True, gray / 128 - 1 maps [0, 255] to [-1, 1) 82 | gray = np.dot(rgb[..., :], [0.299, 0.587, 0.114]) 83 | if norm: 84 | # normalize 85 | gray = gray / 128. - 1. 86 | return gray 87 | 88 | @staticmethod 89 | def reward_memory(): 90 | # record reward for last 100 steps; the returned closure stores one reward and returns the mean over the zero-initialized window 91 | count = 0 92 | length = 100 93 | history = np.zeros(length) 94 | 95 | def memory(reward): 96 | nonlocal count 97 | history[count] = reward 98 | count = (count + 1) % length 99 | return np.mean(history) 100 | 101 | return memory 102 | 103 | 104 | class Net(nn.Module): 105 | """ 106 | Actor-Critic Network for PPO 107 | """ 108 | 109 | def __init__(self): 110 | super(Net, self).__init__() 111 | self.cnn_base = nn.Sequential( # input shape (4, 96, 96) 112 | nn.Conv2d(args.img_stack, 8, kernel_size=4, stride=2), 113 | nn.ReLU(), # activation 114 | nn.Conv2d(8, 16, kernel_size=3, stride=2), # (8, 47, 47) 115 | nn.ReLU(), # activation 116 | nn.Conv2d(16, 32, kernel_size=3, stride=2), # (16, 23, 23) 117 | nn.ReLU(), # activation 118 | nn.Conv2d(32, 64, kernel_size=3, stride=2), # (32, 11, 11) 119 | nn.ReLU(), # activation 120 | nn.Conv2d(64, 128, kernel_size=3, stride=1), # (64, 5, 5) 121 | nn.ReLU(), # activation 122 | nn.Conv2d(128, 256, kernel_size=3, stride=1), # (128, 3, 3) 123 | nn.ReLU(), # activation 124 | ) # output shape (256, 1, 1) 125 | self.v = nn.Sequential(nn.Linear(256, 100), nn.ReLU(), nn.Linear(100, 1)) 126 | self.fc = nn.Sequential(nn.Linear(256, 100), nn.ReLU()) 127 | self.alpha_head = nn.Sequential(nn.Linear(100, 3), nn.Softplus()) 128 | self.beta_head = nn.Sequential(nn.Linear(100, 3), nn.Softplus()) 129 | self.apply(self._weights_init) 130 | 131 | @staticmethod 132 | def _weights_init(m): 133 | if isinstance(m, nn.Conv2d): 134 | nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu')) 135 |
nn.init.constant_(m.bias, 0.1) 136 | 137 | def forward(self, x): 138 | x = self.cnn_base(x) 139 | x = x.view(-1, 256) 140 | v = self.v(x) 141 | x = self.fc(x) 142 | alpha = self.alpha_head(x) + 1 143 | beta = self.beta_head(x) + 1 144 | 145 | return (alpha, beta), v 146 | 147 | 148 | class Agent(): 149 | """ 150 | Agent for training 151 | """ 152 | max_grad_norm = 0.5 153 | clip_param = 0.1 # epsilon in clipped loss 154 | ppo_epoch = 10 155 | buffer_capacity, batch_size = 2000, 128 156 | 157 | def __init__(self): 158 | self.training_step = 0 159 | self.net = Net().double().to(device) 160 | self.buffer = np.empty(self.buffer_capacity, dtype=transition) 161 | self.counter = 0 162 | 163 | self.optimizer = optim.Adam(self.net.parameters(), lr=1e-3) 164 | 165 | def select_action(self, state): 166 | state = torch.from_numpy(state).double().to(device).unsqueeze(0) 167 | with torch.no_grad(): 168 | alpha, beta = self.net(state)[0] 169 | dist = Beta(alpha, beta) 170 | action = dist.sample() 171 | a_logp = dist.log_prob(action).sum(dim=1) 172 | 173 | action = action.squeeze().cpu().numpy() 174 | a_logp = a_logp.item() 175 | return action, a_logp 176 | 177 | def save_param(self): 178 | torch.save(self.net.state_dict(), 'param/ppo_net_params.pkl') 179 | 180 | def store(self, transition): 181 | self.buffer[self.counter] = transition 182 | self.counter += 1 183 | if self.counter == self.buffer_capacity: 184 | self.counter = 0 185 | return True 186 | else: 187 | return False 188 | 189 | def update(self): 190 | self.training_step += 1 191 | 192 | s = torch.tensor(self.buffer['s'], dtype=torch.double).to(device) 193 | a = torch.tensor(self.buffer['a'], dtype=torch.double).to(device) 194 | r = torch.tensor(self.buffer['r'], dtype=torch.double).to(device).view(-1, 1) 195 | s_ = torch.tensor(self.buffer['s_'], dtype=torch.double).to(device) 196 | 197 | old_a_logp = torch.tensor(self.buffer['a_logp'], dtype=torch.double).to(device).view(-1, 1) 198 | 199 | with torch.no_grad(): 200 | 
target_v = r + args.gamma * self.net(s_)[1] 201 | adv = target_v - self.net(s)[1] 202 | # adv = (adv - adv.mean()) / (adv.std() + 1e-8) 203 | 204 | for _ in range(self.ppo_epoch): 205 | for index in BatchSampler(SubsetRandomSampler(range(self.buffer_capacity)), self.batch_size, False): 206 | 207 | alpha, beta = self.net(s[index])[0] 208 | dist = Beta(alpha, beta) 209 | a_logp = dist.log_prob(a[index]).sum(dim=1, keepdim=True) 210 | ratio = torch.exp(a_logp - old_a_logp[index]) 211 | 212 | surr1 = ratio * adv[index] 213 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv[index] 214 | action_loss = -torch.min(surr1, surr2).mean() 215 | value_loss = F.smooth_l1_loss(self.net(s[index])[1], target_v[index]) 216 | loss = action_loss + 2. * value_loss 217 | 218 | self.optimizer.zero_grad() 219 | loss.backward() 220 | # nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm) 221 | self.optimizer.step() 222 | 223 | 224 | if __name__ == "__main__": 225 | agent = Agent() 226 | env = Env() 227 | if args.vis: 228 | draw_reward = DrawLine(env="car", title="PPO", xlabel="Episode", ylabel="Moving averaged episode reward") 229 | 230 | training_records = [] 231 | running_score = 0 232 | state = env.reset() 233 | for i_ep in range(100000): 234 | score = 0 235 | state = env.reset() 236 | 237 | for t in range(1000): 238 | action, a_logp = agent.select_action(state) 239 | state_, reward, done, die = env.step(action * np.array([2., 1., 1.]) + np.array([-1., 0., 0.])) 240 | if args.render: 241 | env.render() 242 | if agent.store((state, action, a_logp, reward, state_)): 243 | print('updating') 244 | agent.update() 245 | score += reward 246 | state = state_ 247 | if done or die: 248 | break 249 | running_score = running_score * 0.99 + score * 0.01 250 | 251 | if i_ep % args.log_interval == 0: 252 | if args.vis: 253 | draw_reward(xdata=i_ep, ydata=running_score) 254 | print('Ep {}\tLast score: {:.2f}\tMoving average score: {:.2f}'.format(i_ep, score, 
running_score)) 255 | agent.save_param() 256 | if running_score > env.reward_threshold: # threshold read from the gym env spec in Env.__init__ 257 | print("Solved! Running reward is now {} and the last episode runs to {}!".format(running_score, score)) 258 | break 259 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import visdom 2 | import numpy as np 3 | 4 | 5 | class DrawLine(): 6 | """Plot a scalar series as a visdom line chart: the first call creates the window, later calls append one point to it.""" 7 | def __init__(self, env, title, xlabel=None, ylabel=None): 8 | self.vis = visdom.Visdom() 9 | self.update_flag = False 10 | self.env = env 11 | self.xlabel = xlabel 12 | self.ylabel = ylabel 13 | self.title = title 14 | 15 | def __call__( 16 | self, 17 | xdata, 18 | ydata, 19 | ): 20 | if not self.update_flag: # first call: create the window and remember its handle in self.win 21 | self.win = self.vis.line( 22 | X=np.array([xdata]), 23 | Y=np.array([ydata]), 24 | opts=dict( 25 | xlabel=self.xlabel, 26 | ylabel=self.ylabel, 27 | title=self.title, 28 | ), 29 | env=self.env, 30 | ) 31 | self.update_flag = True 32 | else: 33 | self.vis.line( 34 | X=np.array([xdata]), 35 | Y=np.array([ydata]), 36 | win=self.win, 37 | env=self.env, 38 | update='append', 39 | ) 40 | --------------------------------------------------------------------------------