├── DRQN.py
└── env_Tmaze.py

/DRQN.py:
--------------------------------------------------------------------------------
from collections import deque
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from env_Tmaze import EnvTMaze

class ReplayMemory(object):
    def __init__(self, max_epi_num=50, max_epi_len=300):
        # capacity is the maximum number of episodes
        self.max_epi_num = max_epi_num
        self.max_epi_len = max_epi_len
        self.memory = deque(maxlen=self.max_epi_num)
        self.is_av = False
        self.current_epi = 0
        self.memory.append([])

    def reset(self):
        self.current_epi = 0
        self.memory.clear()
        self.memory.append([])

    def create_new_epi(self):
        self.memory.append([])
        self.current_epi = self.current_epi + 1
        if self.current_epi > self.max_epi_num - 1:
            self.current_epi = self.max_epi_num - 1

    def remember(self, state, action, reward):
        if len(self.memory[self.current_epi]) < self.max_epi_len:
            self.memory[self.current_epi].append([state, action, reward])

    def sample(self):
        # sample a finished episode; the last entry is the one still being filled
        epi_index = random.randint(0, len(self.memory) - 2)
        if self.is_available():
            return self.memory[epi_index]
        else:
            return []

    def size(self):
        return len(self.memory)

    def is_available(self):
        self.is_av = True
        if len(self.memory) <= 1:
            self.is_av = False
        return self.is_av

    def print_info(self):
        for i in range(len(self.memory)):
            print('epi', i, 'length', len(self.memory[i]))

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class DRQN(nn.Module):
    def __init__(self, N_action):
        super(DRQN, self).__init__()
        self.lstm_i_dim = 16    # input dimension of LSTM
        self.lstm_h_dim = 16    # output dimension of LSTM
        self.lstm_N_layer = 1   # number of layers of LSTM
        self.N_action = N_action
        # a 3x3 kernel over the 3x3 observation collapses it to a single 16-dim feature vector
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1)
        self.flat1 = Flatten()
        self.lstm = nn.LSTM(input_size=self.lstm_i_dim, hidden_size=self.lstm_h_dim, num_layers=self.lstm_N_layer)
        self.fc1 = nn.Linear(self.lstm_h_dim, 16)
        self.fc2 = nn.Linear(16, self.N_action)

    def forward(self, x, hidden):
        # x: [seq_len, 3, 3, 3] observations, processed as one sequence with batch size 1
        h1 = F.relu(self.conv1(x))
        h2 = self.flat1(h1)
        h2 = h2.unsqueeze(1)                    # [seq_len, batch=1, lstm_i_dim]
        h3, new_hidden = self.lstm(h2, hidden)
        h4 = F.relu(self.fc1(h3))
        h5 = self.fc2(h4)                       # [seq_len, 1, N_action]
        return h5, new_hidden

class Agent(object):
    def __init__(self, N_action, max_epi_num=50, max_epi_len=300):
        self.N_action = N_action
        self.max_epi_num = max_epi_num
        self.max_epi_len = max_epi_len
        self.drqn = DRQN(self.N_action)
        self.buffer = ReplayMemory(max_epi_num=self.max_epi_num, max_epi_len=self.max_epi_len)
        self.gamma = 0.9
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.drqn.parameters(), lr=1e-3)

    def remember(self, state, action, reward):
        self.buffer.remember(state, action, reward)

    def img_to_tensor(self, img):
        # HWC numpy image -> CHW float tensor
        img_tensor = torch.FloatTensor(img)
        img_tensor = img_tensor.permute(2, 0, 1)
        return img_tensor

    def img_list_to_batch(self, x):
        # transform a list of images into a tensor [seq_len, channel, height, width]
        return torch.stack([self.img_to_tensor(img) for img in x], dim=0)

    def train(self):
        if self.buffer.is_available():
            memo = self.buffer.sample()
            obs_list = []
            action_list = []
            reward_list = []
            for i in range(len(memo)):
                obs_list.append(memo[i][0])
                action_list.append(memo[i][1])
                reward_list.append(memo[i][2])
            obs_list = self.img_list_to_batch(obs_list)
            hidden = (torch.zeros(1, 1, 16), torch.zeros(1, 1, 16))
            Q, hidden = self.drqn(obs_list, hidden)
            Q_est = Q.clone()
            # one-step Q-learning targets along the sampled episode; only the taken actions change
            for t in range(len(memo) - 1):
                max_next_q = torch.max(Q_est[t + 1, 0, :]).clone().detach()
                Q_est[t, 0, action_list[t]] = reward_list[t] + self.gamma * max_next_q
            T = len(memo) - 1
            Q_est[T, 0, action_list[T]] = reward_list[T]

            loss = self.loss_fn(Q, Q_est)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def get_action(self, obs, hidden, epsilon):
        # always run the network so the recurrent state keeps tracking the episode,
        # even when the action itself is chosen at random
        q, new_hidden = self.drqn(self.img_to_tensor(obs).unsqueeze(0), hidden)
        if random.random() > epsilon:
            action = q[0].max(1)[1].item()
        else:
            action = random.randint(0, self.N_action - 1)
        return action, new_hidden

def get_decay(epi_iter):
    # exponentially decaying exploration rate, floored at 0.05
    decay = math.pow(0.999, epi_iter)
    if decay < 0.05:
        decay = 0.05
    return decay

if __name__ == '__main__':
    random.seed()
    env = EnvTMaze(4, random.randint(0, 1))
    max_epi_iter = 30000
    max_MC_iter = 100
    agent = Agent(N_action=4, max_epi_num=5000, max_epi_len=max_MC_iter)
    train_curve = []
    for epi_iter in range(max_epi_iter):
        random.seed()
        env.reset(random.randint(0, 1))
        hidden = (torch.zeros(1, 1, 16), torch.zeros(1, 1, 16))
        for MC_iter in range(max_MC_iter):
            # env.render()
            obs = env.get_obs()
            action, hidden = agent.get_action(obs, hidden, get_decay(epi_iter))
            reward = env.step(action)
            agent.remember(obs, action, reward)
            if reward != 0 or MC_iter == max_MC_iter - 1:
                agent.buffer.create_new_epi()
                break
        print('Episode', epi_iter, 'reward', reward, 'where', env.if_up)
        if epi_iter % 100 == 0:
            train_curve.append(reward)
        if agent.buffer.is_available():
            agent.train()
    np.save("len4_DRQN16_1e3_4.npy", np.array(train_curve))

--------------------------------------------------------------------------------
/env_Tmaze.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import random
import cv2

class EnvTMaze(object):
    def __init__(self, length, if_up):
        self.len = length
        self.occupancy = np.zeros((5, length + 2))
        # outer walls
        for i in range(5):
            self.occupancy[i, 0] = 1
            self.occupancy[i, length + 1] = 1
        for i in range(length + 2):
            self.occupancy[0, i] = 1
            self.occupancy[4, i] = 1
        # corridor walls: only the junction column stays open above and below the corridor
        for i in range(length):
            self.occupancy[1, i] = 1
            self.occupancy[3, i] = 1
        self.if_up = if_up
        self.agt_pos = [2, 1]
        if if_up == 1:
            self.dest_pos = [1, length]
            self.wrong_pos = [3, length]
        else:
            self.dest_pos = [3, length]
            self.wrong_pos = [1, length]

    def reset(self, if_up):
        self.if_up = if_up
        self.agt_pos = [2, 1]
        if if_up == 1:
            self.dest_pos = [1, self.len]
            self.wrong_pos = [3, self.len]
        else:
            self.dest_pos = [3, self.len]
            self.wrong_pos = [1, self.len]

    def step(self, action):
        reward = 0
        if action == 0:     # up
            if self.occupancy[self.agt_pos[0] - 1][self.agt_pos[1]] != 1:   # if can move
                self.agt_pos[0] = self.agt_pos[0] - 1
        if action == 1:     # down
            if self.occupancy[self.agt_pos[0] + 1][self.agt_pos[1]] != 1:   # if can move
                self.agt_pos[0] = self.agt_pos[0] + 1
        if action == 2:     # left
            if self.occupancy[self.agt_pos[0]][self.agt_pos[1] - 1] != 1:   # if can move
                self.agt_pos[1] = self.agt_pos[1] - 1
        if action == 3:     # right
            if self.occupancy[self.agt_pos[0]][self.agt_pos[1] + 1] != 1:   # if can move
                self.agt_pos[1] = self.agt_pos[1] + 1
        if self.agt_pos == self.dest_pos:
            reward = 10
        if self.agt_pos == self.wrong_pos:
            reward = -10
        return reward

    def get_obs(self):
        # 3x3 egocentric RGB observation centred on the agent
        obs = np.zeros((3, 3, 3))
        for i in range(3):
            for j in range(3):
                img_x = self.agt_pos[0] - 1 + i
                img_y = self.agt_pos[1] - 1 + j
                if 0 <= img_x < 5 and 0 <= img_y < self.len + 2:
                    if self.occupancy[img_x, img_y] == 0:
                        obs[i, j, 0] = 1.0
                        obs[i, j, 1] = 1.0
                        obs[i, j, 2] = 1.0
        # the goal cue (green for up, blue for down) is only visible on the start cell
        if self.agt_pos == [2, 1]:
            if self.if_up == 1:
                obs[1, 1, 0] = 0.0
                obs[1, 1, 1] = 1.0
                obs[1, 1, 2] = 0.0
            else:
                obs[1, 1, 0] = 0.0
                obs[1, 1, 1] = 0.0
                obs[1, 1, 2] = 1.0
        return obs

    def get_global_obs(self):
        obs = np.zeros((5, self.len + 2, 3))
        for i in range(5):
            for j in range(self.len + 2):
                if self.occupancy[i, j] == 0:
                    obs[i, j, 0] = 1.0
                    obs[i, j, 1] = 1.0
                    obs[i, j, 2] = 1.0
        if self.if_up == 1:
            obs[2, 1, 0] = 0.0
            obs[2, 1, 1] = 1.0
            obs[2, 1, 2] = 0.0
        else:
            obs[2, 1, 0] = 0.0
            obs[2, 1, 1] = 0.0
            obs[2, 1, 2] = 1.0
        obs[self.agt_pos[0], self.agt_pos[1], 0] = 1.0
        obs[self.agt_pos[0], self.agt_pos[1], 1] = 0.0
        obs[self.agt_pos[0], self.agt_pos[1], 2] = 0.0
        return obs

    def render(self):
        obs = self.get_global_obs()
        enlarge = 10
        new_obs = np.ones((5 * enlarge, (self.len + 2) * enlarge, 3))
        for i in range(5):
            for j in range(self.len + 2):
                if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 0), -1)
                if obs[i][j][0] == 1.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 0.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 0, 255), -1)
                if obs[i][j][0] == 0.0 and obs[i][j][1] == 1.0 and obs[i][j][2] == 0.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (0, 255, 0), -1)
                if obs[i][j][0] == 0.0 and obs[i][j][1] == 0.0 and obs[i][j][2] == 1.0:
                    cv2.rectangle(new_obs, (j * enlarge, i * enlarge), (j * enlarge + enlarge, i * enlarge + enlarge), (255, 0, 0), -1)
        cv2.imshow('image', new_obs)
        cv2.waitKey(10)

    def plot_scene(self):
        fig = plt.figure(figsize=(5, 5))
        gs = GridSpec(3, 3, figure=fig)
        ax1 = fig.add_subplot(gs[0:2, 0:2])
        ax2 = fig.add_subplot(gs[2, 2])
        ax1.imshow(self.get_global_obs())
        ax2.imshow(self.get_obs())
        plt.show()

if __name__ == '__main__':
    env = EnvTMaze(4, random.randint(0, 1))
    max_iter = 100000
    for i in range(max_iter):
        print("iter= ", i)
        env.plot_scene()
        # env.render()
        print('agent at', env.agt_pos, 'dest', env.dest_pos, 'wrong', env.wrong_pos)
        action = random.randint(0, 3)
        reward = env.step(action)
        print('reward', reward)
        if reward != 0:
            print('reset')
            env.reset(random.randint(0, 1))
--------------------------------------------------------------------------------
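
The training script in DRQN.py appends one reward sample every 100 episodes and saves the result to len4_DRQN16_1e3_4.npy, but nothing in the repository plots it. Below is a minimal plotting sketch, assuming the file is in the working directory and matplotlib is installed; the 10-point moving-average window is an arbitrary illustrative choice, not part of the original code.

import numpy as np
import matplotlib.pyplot as plt

curve = np.load("len4_DRQN16_1e3_4.npy")   # one reward sample per 100 training episodes
episodes = np.arange(len(curve)) * 100     # x-axis in training episodes
window = 10                                # assumed smoothing window, not from the repo
smooth = np.convolve(curve, np.ones(window) / window, mode='valid')

plt.plot(episodes, curve, alpha=0.3, label='sampled reward')
plt.plot(episodes[:len(smooth)], smooth, label='moving average')
plt.xlabel('episode')
plt.ylabel('final reward')
plt.legend()
plt.show()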