├── .gitignore
├── snake.gif
├── history-2022-12-18-03-28-52
│   ├── model_1000.pkl
│   ├── model_10000.pkl
│   ├── model_11000.pkl
│   ├── model_12000.pkl
│   ├── model_13000.pkl
│   ├── model_14000.pkl
│   ├── model_15000.pkl
│   ├── model_16000.pkl
│   ├── model_17000.pkl
│   ├── model_18000.pkl
│   ├── model_19000.pkl
│   ├── model_2000.pkl
│   ├── model_20000.pkl
│   ├── model_21000.pkl
│   ├── model_22000.pkl
│   ├── model_23000.pkl
│   ├── model_24000.pkl
│   ├── model_25000.pkl
│   ├── model_26000.pkl
│   ├── model_27000.pkl
│   ├── model_28000.pkl
│   ├── model_29000.pkl
│   ├── model_3000.pkl
│   ├── model_30000.pkl
│   ├── model_31000.pkl
│   ├── model_32000.pkl
│   ├── model_33000.pkl
│   ├── model_33300.pkl
│   ├── model_4000.pkl
│   ├── model_5000.pkl
│   ├── model_6000.pkl
│   ├── model_7000.pkl
│   ├── model_8000.pkl
│   └── model_9000.pkl
├── history-2022-12-18-09-38-36
│   ├── model_1000.pkl
│   ├── model_10000.pkl
│   ├── model_11000.pkl
│   ├── model_12000.pkl
│   ├── model_13000.pkl
│   ├── model_14000.pkl
│   ├── model_15000.pkl
│   ├── model_16000.pkl
│   ├── model_17000.pkl
│   ├── model_18000.pkl
│   ├── model_19000.pkl
│   ├── model_2000.pkl
│   ├── model_20000.pkl
│   ├── model_20335.pkl
│   ├── model_3000.pkl
│   ├── model_4000.pkl
│   ├── model_5000.pkl
│   ├── model_6000.pkl
│   ├── model_7000.pkl
│   ├── model_8000.pkl
│   └── model_9000.pkl
├── buffer.py
├── model.py
├── snake.py
├── README.md
├── snake_norender.py
└── ddqn.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
runs
__*

--------------------------------------------------------------------------------
/snake.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/panjd123/D3QN-Snake/HEAD/snake.gif

--------------------------------------------------------------------------------
/history-2022-12-18-03-28-52/*.pkl, /history-2022-12-18-09-38-36/*.pkl:
--------------------------------------------------------------------------------
Binary model checkpoints (listed in the tree above). Each checkpoint resolves to
https://raw.githubusercontent.com/panjd123/D3QN-Snake/HEAD/<history directory>/<checkpoint>.pkl
--------------------------------------------------------------------------------
/buffer.py:
--------------------------------------------------------------------------------
import numpy as np


class Buffer:
    """A simple FIFO experience-replay buffer with optional prioritized sampling."""

    def __init__(self, capcity=100000) -> None:
        self.capcity = capcity
        self.is_prioritized = False
        self._size = 0
        # one list per transition component
        self.obs, self.act, self.rew, self.done, self.obs_next = [[] for _ in range(5)]
        self.weight = []

    def add(self, obs, act, rew, done, obs_next, info=None, wei=1):
        # drop the oldest transition once the buffer is full
        if self._size >= self.capcity:
            self.obs.pop(0)
            self.act.pop(0)
            self.rew.pop(0)
            self.done.pop(0)
            self.obs_next.pop(0)

            if self.is_prioritized:
                self.weight.pop(0)

            self._size -= 1

        self.obs.append(obs)
        self.act.append(act)
        self.rew.append(rew)
        self.done.append(done)
        self.obs_next.append(obs_next)

        if self.is_prioritized:
            self.weight.append(wei)

        self._size += 1

    def sample(self, batch_size):
        # prioritized sampling draws transitions in proportion to their weight,
        # otherwise sampling is uniform (with replacement)
        if self.is_prioritized:
            p = np.array(self.weight)
            p = p / np.sum(p)
            arg = np.random.choice(np.arange(self._size), batch_size, p=p)
        else:
            arg = np.random.choice(np.arange(self._size), batch_size)
        return (np.array(self.obs)[arg], np.array(self.act)[arg], np.array(self.rew)[arg],
                np.array(self.done)[arg], np.array(self.obs_next)[arg])

    def __len__(self) -> int:
        return self._size
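
# The block below is an added usage sketch and is not part of the original repository:
# a tiny smoke test of the Buffer API, assuming 4-dimensional observations and 4 actions.
if __name__ == '__main__':
    buf = Buffer(capcity=1000)
    for _ in range(10):
        buf.add(obs=np.random.rand(4), act=np.random.randint(4), rew=0.0,
                done=0, obs_next=np.random.rand(4))
    obs_b, act_b, rew_b, done_b, next_b = buf.sample(batch_size=4)
    print(len(buf), obs_b.shape, act_b.shape)  # 10 (4, 4) (4,)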
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch import nn


class CNN(nn.Module):
    """Feature extractor: a small CNN for image-like observations, a single linear layer for 1-D ones."""

    def __init__(self, input_shape) -> None:
        super().__init__()
        self.input_shape = input_shape
        if len(input_shape) == 2:
            input_shape = (1, *input_shape)
        if len(input_shape) == 3:
            self.nn = nn.Sequential(
                nn.Conv2d(input_shape[0], 64, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.Conv2d(64, 128, kernel_size=3, stride=1),
                nn.ReLU()
            )
        elif len(input_shape) == 1:
            self.nn = nn.Sequential(
                nn.Linear(input_shape[0], 64),
                nn.ReLU()
            )

        # run a dummy forward pass to infer the flattened feature size
        x = torch.zeros(input_shape)
        x = x.unsqueeze(0)
        x = self.nn(x)
        self.num_feature = int(np.prod(x.shape[1:]))
        print('CNN features:', self.num_feature)

    def forward(self, x):
        return self.nn(x)


class VDNet(nn.Module):
    """Dueling head: a state-value stream (vnet) and an advantage stream (dnet)."""

    def __init__(self, num_feature, num_act) -> None:
        super().__init__()
        self.vnet = nn.Sequential(
            nn.Linear(num_feature, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        self.dnet = nn.Sequential(
            nn.Linear(num_feature, 256),
            nn.ReLU(),
            nn.Linear(256, 32),
            nn.ReLU(),
            nn.Linear(32, num_act)
        )

    def forward(self, x):
        v = self.vnet(x)
        d = self.dnet(x)
        # dueling aggregation: Q = V + A - mean(A), with the advantage centred over the action dimension
        return v + d - d.mean(dim=-1, keepdim=True)


class DuelingNetwork(nn.Module):
    def __init__(self, input_shape, num_act) -> None:
        super().__init__()
        self.features = CNN(input_shape)
        num_feature = self.features.num_feature
        self.vdnet = VDNet(num_feature, num_act)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.vdnet(x)
        return x


if __name__ == '__main__':
    dn = DuelingNetwork((1, 16, 16), 2)
    x = torch.rand(1, 1, 16, 16)
    print(dn.forward(x))
--------------------------------------------------------------------------------
/snake.py:
--------------------------------------------------------------------------------
from pygame.locals import *
import pygame
import numpy as np
from snake_norender import *


class Snake(Snake_norender):
    """Pygame-rendered snake environment; the game logic lives in Snake_norender."""

    key2id = {K_LEFT: 1, K_RIGHT: 2, K_UP: 0,
              K_DOWN: 3, K_a: 1, K_d: 2, K_w: 0, K_s: 3}

    def __init__(self, visual_dis=1) -> None:
        super().__init__(visual_dis)
        pygame.init()
        self.screen = pygame.display.set_mode(self.settings.screen_size)
        pygame.display.set_caption("Snake")

    def draw_cell(self):
        # draw the grid lines
        for r in range(self.settings.row):
            pygame.draw.line(self.screen, self.settings.cell_color,
                             (0, r * self.settings.cell_height),
                             (self.settings.screen_width, r * self.settings.cell_height))
        for c in range(self.settings.col):
            pygame.draw.line(self.screen, self.settings.cell_color,
                             (c * self.settings.cell_width, 0),
                             (c * self.settings.cell_width, self.settings.screen_height))

    def draw_rect(self, point, color=None):
        left = point.col * self.settings.cell_width
        top = point.row * self.settings.cell_height
        if not color:
            pygame.draw.rect(self.screen, point.color,
                             (left, top, self.settings.cell_width, self.settings.cell_height))
        else:
            pygame.draw.rect(self.screen, color,
                             (left, top, self.settings.cell_width, self.settings.cell_height))

    def runner(self):
        flag_get_key = False

        if self.is_render:
            clock = pygame.time.Clock()

        # wait for the space key to (re)start the game
        if not self.game_going:
            for event in pygame.event.get():
                if event.type == KEYDOWN:
                    if event.key == K_SPACE:
                        self.reset()
                        self.game_going = True
                elif event.type == QUIT:
                    exit(0)
            return

        # read the direction from the keyboard; ignore keys that are not mapped
        for event in pygame.event.get():
            if event.type == QUIT:
                exit(0)
            elif event.type == KEYDOWN and not flag_get_key and event.key in self.key2id:
                flag_get_key = True
                self.step(self.key2id[event.key])

        if not flag_get_key:
            self.step(self.direct_id)

        if self.is_render:
            self.render()
            clock.tick(10 * self.settings.snake_speed)

    # draw the screen
    def render(self, caption='Snake'):
        self.screen.fill(self.settings.bg_color)
        for body in self.bodys[1:]:
            self.draw_rect(body)
        self.draw_rect(self.bodys[0], self.settings.head_color)
        self.draw_rect(self.food)
        self.draw_cell()
        for event in pygame.event.get():
            # drain the event queue so the window stays responsive
            pass
        pygame.display.set_caption(caption)
        pygame.display.flip()


if __name__ == '__main__':
    snake = Snake()
    snake.run_game()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# D3QN Snake

A Snake AI trained with reinforcement learning that starts to play reasonably well after only a few minutes of training.

The techniques used include experience replay, double Q-learning, and a dueling network; see [Reinforcement Learning DQN quick start](https://blog.csdn.net/qq_32461955/article/details/126040912). A minimal sketch of the resulting update rule is shown below.

The repository also includes two models trained on a CPU (see the preset model description below).

![](snake.gif)
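
The training update implemented in `ddqn.py` combines a double-DQN target with a soft (Polyak) target-network update. The snippet below is a minimal, self-contained sketch of that rule; the plain linear networks, batch size, and random tensors are illustrative stand-ins, not the project's actual `DuelingNetwork` or replay buffer.

```python
import torch
import torch.nn.functional as F

obs_dim, num_act, gamma, lam = 13, 4, 0.99, 0.05
online = torch.nn.Linear(obs_dim, num_act)      # stand-in for the online DuelingNetwork
target = torch.nn.Linear(obs_dim, num_act)      # stand-in for the target network
target.load_state_dict(online.state_dict())

s, s_next = torch.rand(32, obs_dim), torch.rand(32, obs_dim)
a = torch.randint(num_act, (32,))
r, done = torch.rand(32), torch.zeros(32)

q = online(s).gather(1, a.unsqueeze(1)).squeeze(1)                       # Q(s, a)
with torch.no_grad():
    a_star = online(s_next).argmax(dim=1)                                # online net picks the action
    q_next = target(s_next).gather(1, a_star.unsqueeze(1)).squeeze(1)    # target net evaluates it
    y = r + (1.0 - done) * gamma * q_next
loss = F.mse_loss(q, y)
loss.backward()

# soft (Polyak) target update, as in DDQN.learning()
for p, tp in zip(online.parameters(), target.parameters()):
    tp.data.copy_((1 - lam) * tp.data + lam * p.data)
```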
## Usages

### Requirements

- `numpy=1.22.4`
- `matplotlib=3.5.1`
- `pygame=2.1.2`
- `tqdm=4.64.0`
- `torch=1.13.0`

### How to start

#### Demo

- Play only, no training: `python ddqn.py --visual 2 --play --model_load history-2022-12-18-03-28-52/model_33300.pkl`

#### Training

- From barely being able to move to eating about 8 pieces of food (about 6 minutes): `python ddqn.py --visual 1 --step 100000 --model_load history-2022-12-18-09-38-36/model_10000.pkl --model_save tmp.pkl`

- From scratch to convergence (about 50 minutes): `python ddqn.py --visual 1 --step 1000000 --render 0 --model_save tmp.pkl`

### Further usages

For further usage, see `python ddqn.py -h`:

```cmd
usage: ddqn.py [-h] [--step STEP] [--history HISTORY] [--render RENDER] [--train] [--play] [--visual VISUAL]
               [--model_load MODEL_LOAD] [--model_save MODEL_SAVE] [--test TEST] [--epsilon EPSILON] [--log]

optional arguments:
  -h, --help            show this help message and exit
  --step STEP           the number of step it will train
  --history HISTORY     after HISTORY generations, save model every 1000 generations
  --render RENDER       render level while training, 0: no render, 1: render once after die, 2[default]: render every step
  --train               only train
  --play                only play
  --visual VISUAL       the manhattan distance that snakes can see
  --model_load MODEL_LOAD
                        the path of the loading model
  --model_save MODEL_SAVE
                        the model's output path
  --test TEST
  --epsilon EPSILON     probability of using random movement during training
  --log                 output log while training
```

If you want to play Snake manually, run `python snake.py`; this only requires `pygame`.

## About the model

### Input

- Two groups of normalized coordinates give the position of the snake's head and the offset from the head to the food.

- Several Boolean values indicate, for every cell within Manhattan distance `visual` of the head, whether it contains an obstacle such as a wall or the snake's body.

- Four integers record the last four moves, which effectively tells the snake where its first four body segments are.

- A floating-point number gives the number of steps since food was last eaten, normalized by the board area; it is used so the snake learns to avoid endless loops.

> The full model architecture is described in [Reinforcement Learning DQN quick start](https://blog.csdn.net/qq_32461955/article/details/126040912).
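
Putting the pieces together (reading off `get_s` in `snake_norender.py`), the observation is a flat vector: 4 coordinate values, one value per visible cell, the last 4 moves, and the "laziness" ratio. The following small calculation is an illustrative sketch (not code from the repository) of the resulting vector length:

```python
# Count the observation components produced by Snake_norender.get_s() in '1d' mode.
def obs_length(visual):
    # cells within Manhattan distance `visual` of the head, excluding the head itself
    cells = sum(1 for i in range(-visual, visual + 1)
                for j in range(-visual, visual + 1)
                if 0 < abs(i) + abs(j) <= visual)
    return 4 + cells + 4 + 1  # head/food coords + vision + last moves + laziness

print(obs_length(1), obs_length(2))  # 13 21
```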
### visual = 1

On my laptop CPU, this model usually gets on track within 10 minutes and reaches a fairly good level in about half an hour, which corresponds to roughly 500,000 training steps.

### visual = 2

This is the default setting. With it the model converges somewhat more slowly: it takes about 15 minutes before the effect becomes visible, and a few hours to reach a good level.

### Reward

- eat food: +1.2
- move closer to the food: +0.1
- move away from the food: -0.3
- die: -2

## Preset Model Description

The program's default visual distance is 2, and the model in the repository root is also a `visual=2` model.

- The training history of the `visual=2` preset model is stored in `history-2022-12-18-03-28-52`; after 5 hours of training its average score is $28.10$ pieces of food.

- The training history of the `visual=1` preset model is stored in `history-2022-12-18-09-38-36`; after 1 hour of training its average score is $20.58$ pieces of food.

> Neither model appears to have fully converged yet.
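
To reproduce the average scores above, the `--test` option can be pointed at a history directory: `ddqn.py` then loads every checkpoint in that directory in turn, plays 100 games with each, and plots the mean score per checkpoint. A hypothetical invocation, assembled from the options above rather than documented in the original README, would be `python ddqn.py --visual 2 --test history-2022-12-18-03-28-52`.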
--------------------------------------------------------------------------------
/snake_norender.py:
--------------------------------------------------------------------------------
import numpy as np


class Settings:
    def __init__(self):
        # the color, size and speed settings
        self.screen_width = 400
        self.screen_height = 400
        self.col = 13
        self.row = 13
        self.bg_color = (200, 200, 200)
        self.snake_color = (100, 100, 100)
        self.head_color = (0, 0, 0)
        self.food_color = (255, 245, 225)
        self.cell_color = (0, 180, 100)
        self.snake_speed = 1
        self.cell_width = self.screen_width / self.col
        self.cell_height = self.screen_height / self.row
        self.screen_size = (self.screen_width, self.screen_height)


class Point:
    row = 0
    col = 0
    color = (0, 0, 0)

    def __init__(self, row=0, col=0, color=(0, 0, 0)):
        self.row = row
        self.col = col
        self.color = color

    def __eq__(self, rhs):
        return self.row == rhs.row and self.col == rhs.col

    def copy(self):
        return Point(self.row, self.col, self.color)


def distance(p1, p2):
    # Manhattan distance between two grid points
    return np.abs(p1.row - p2.row) + np.abs(p1.col - p2.col)


class Snake_norender:
    """Headless snake environment: game logic, observations and rewards, no drawing."""

    settings = Settings()
    direct_col = [0, -1, 1, 0]
    direct_row = [-1, 0, 0, 1]
    # action ids: 0 = up, 1 = left, 2 = right, 3 = down
    direct_id = 0
    bodys = []
    score_point = 0
    game_going = False
    screen = None
    is_render = True
    silent = True
    history_act = [0, 0, 0, 0]
    vis = np.zeros((settings.row, settings.col))
    mode = '1d'
    tic = 0
    lst_eat = 0
    visual_dis = 1

    def __init__(self, visual_dis=1) -> None:
        self.visual_dis = visual_dis
        self.reset()

    def setvis(self, p: Point, t=1):
        # mark (or clear) a cell of the occupancy grid
        self.vis[p.row-1][p.col-1] = t

    def get_sn(self, p):
        # occupancy of every cell within Manhattan distance `visual_dis` of p;
        # cells outside the board count as obstacles (1.0)
        dis = self.visual_dis
        x = p.row-1
        y = p.col-1
        sn = []
        for j in range(-dis, dis+1):
            for i in range(-dis, dis+1):
                if (i == 0 and j == 0) or (abs(i)+abs(j) > dis):
                    continue
                tx = x+i
                ty = y+j
                if 0 <= tx < self.settings.row and 0 <= ty < self.settings.col:
                    sn.append(self.vis[tx][ty])
                else:
                    sn.append(1.0)
        return sn

    def get_body_food(self):
        # normalized head position and offset from the head to the food
        return [self.bodys[0].row/self.settings.row, self.bodys[0].col/self.settings.col,
                self.food.row/self.settings.row-self.bodys[0].row/self.settings.row,
                self.food.col/self.settings.col-self.bodys[0].col/self.settings.col]

    def reset(self):
        self.lst_eat = 0
        self.tic = 0
        self.bodys = [Point(self.settings.row // 2, self.settings.col // 2,
                            self.settings.snake_color)]
        self.vis = np.zeros((self.settings.row, self.settings.col))
        self.setvis(self.bodys[0])
        self.food = self.create_food()
        self.score_point = 0
        return self.get_s(mode=self.mode), 0, 0, 0, {'score': self.score_point}

    def draw_cell(self):
        pass

    def draw_rect(self, point, color=None):
        pass

    def create_food(self):
        if len(self.bodys) < 0:
            # rejection sampling (currently unreachable: the condition above is never true)
            while True:
                new_food = Point(np.random.randint(0, self.settings.row - 1),
                                 np.random.randint(0, self.settings.col - 1),
                                 self.settings.food_color)
                is_coll = False
                for body in self.bodys:
                    if body == new_food:
                        is_coll = True
                        break
                if not is_coll:
                    return new_food
        else:
            # place the food on a uniformly random free cell
            points = np.ones(self.settings.row*self.settings.col)
            for body in self.bodys:
                points[body.row*self.settings.col+body.col] = 0
            points = np.nonzero(points)[0]
            id = np.random.randint(len(points))
            ij = points[id]
            i = ij//self.settings.col
            j = ij % self.settings.col
            return Point(i, j, self.settings.food_color)

    def get_lazy(self):
        # steps since the last food was eaten, normalized by the board area
        return (self.tic - self.lst_eat)/(self.settings.col*self.settings.row)
    def step(self, action):
        rew = 0
        self.tic += 1

        # moving straight back into yourself immediately ends the episode
        if action == 3-self.direct_id:
            return rew, 1, self.get_s(mode=self.mode), {'score': self.score_point}

        self.history_act.pop(0)
        self.history_act.append(action)

        self.direct_id = action

        self.bodys.insert(0, self.bodys[0].copy())

        dis_old = distance(self.bodys[0], self.food)

        coll = 0
        # move the head
        self.bodys[0].row += self.direct_row[self.direct_id]
        self.bodys[0].col += self.direct_col[self.direct_id]
        self.setvis(self.bodys[0])

        # collision with a wall
        if self.bodys[0].col < 0:
            coll = 1
        if self.bodys[0].col > self.settings.col - 1:
            coll = 1
        if self.bodys[0].row < 0:
            coll = 1
        if self.bodys[0].row > self.settings.row - 1:
            coll = 1

        eat = (self.bodys[0] == self.food)
        if eat:
            self.food = self.create_food()
            self.score_point += 1
            rew += 1.2
            self.lst_eat = self.tic
        if not eat:
            # the tail only advances when nothing was eaten
            self.setvis(self.bodys[-1], 0)
            self.bodys.pop()

        # shaping reward: small bonus for approaching the food, penalty for moving away
        dis_new = distance(self.bodys[0], self.food)
        if dis_new > dis_old:
            rew -= 0.3
        else:
            rew += 0.1

        # collision with the body
        for body in self.bodys[1:]:
            if self.bodys[0] == body:
                coll = 1
                break

        # end the episode if the snake wanders too long without eating
        if self.get_lazy() > 1:
            coll = 1

        if coll:
            self.game_going = False
            if not self.silent:
                print("You died; press space to restart")
            rew -= 2.0
        return rew, coll, self.get_s(mode=self.mode), {'score': self.score_point}

    def runner(self):
        self.step(self.direct_id)

    def get_s(self, mode='1d'):
        # observation: either a flat feature vector ('1d') or an occupancy image ('2d')
        if mode == '1d':
            return [*self.get_body_food(), *self.get_sn(self.bodys[0]),
                    *self.history_act, self.get_lazy()]
        elif mode == '2d':
            ret = np.zeros((self.settings.row, self.settings.col))
            p = 1
            for body in self.bodys:
                ret[body.row-1, body.col-1] = p
                p *= 0.98
            ret[self.food.row-1, self.food.col-1] = -1
            return ret[np.newaxis, :, :]

    # draw the screen
    def render(self, caption='Snake'):
        pass

    def run_game(self):
        self.render()
        self.silent = False
        while True:
            self.runner()

    def random_action(self, act=-1):
        # random action that never directly reverses the previous move `act`
        act_next = np.random.randint(3)
        if act_next >= 3-act:
            return act_next+1
        else:
            return act_next
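
# The block below is an added usage sketch and is not part of the original repository:
# a random-rollout smoke test of the headless environment.
if __name__ == '__main__':
    env = Snake_norender(visual_dis=2)
    obs, act, rew, done, info = env.reset()
    total = 0.0
    while not done:
        act = env.random_action(act)          # random move that never reverses directly
        rew, done, obs, info = env.step(act)
        total += rew
    print('episode reward:', total, 'score:', info['score'])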
--------------------------------------------------------------------------------
/ddqn.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from model import DuelingNetwork
from buffer import Buffer
from snake import Snake, Snake_norender
from tqdm import tqdm
import time
import os
import os.path as osp
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from collections import deque
import pygame

parser = ArgumentParser()
parser.add_argument("--step", type=int, default=1000,
                    help="the number of step it will train")
parser.add_argument("--history", type=int, default=0,
                    help="after HISTORY generations, save model every 1000 generations")
parser.add_argument("--render", type=int, default=2,
                    help="render level while training, 0: no render, 1: render once after die, 2[default]: render every step")
parser.add_argument("--train", action="store_true", help="only train")
parser.add_argument("--play", action="store_true", help="only play")
parser.add_argument("--visual", type=int, default=2,
                    help="the manhattan distance that snakes can see, note that this argument will affect the model's parameter size, if you plan to load a model, pay attention to the corresponding.")
parser.add_argument("--model_load", type=str,
                    default="model.pkl", help="the path of the loading model")
parser.add_argument("--model_save", type=str,
                    default="model.pkl", help="the model's output path")
parser.add_argument("--test", type=str, default="")
parser.add_argument("--epsilon", type=float, default=0.05,
                    help="probability of using random movement during training")
parser.add_argument("--log", action="store_true",
                    help="output log while training")

argument = parser.parse_args()
history_dir = 'history-'+time.strftime('%Y-%m-%d-%H-%M-%S')

if not argument.play and not argument.test:
    os.mkdir(history_dir)


class DDQN:
    def __init__(self, input_shape, num_act, env: Snake, gamma=0.99, lamda=0.05, epsilon=0.05) -> None:
        self.gamma = gamma
        self.lamda = lamda
        self.model = DuelingNetwork(input_shape, num_act)
        self.target_model = DuelingNetwork(input_shape, num_act)

        # to continue training from a checkpoint:
        # self.model.load_state_dict(torch.load('model40w.pkl'))
        # self.target_model.load_state_dict(self.model.state_dict())

        self.train_after = 0
        self.expl_before = 0
        # NOTE: with a buffer capacity and batch size of 1, only the latest transition is replayed
        self.buffer = Buffer(capcity=1)
        self.env = env
        self.epsilon = epsilon
        self.batch_size = 1
        self.log = []
        # self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = 'cpu'
        self.model.to(self.device)
        self.target_model.to(self.device)
        self.model_optim = optim.Adam(self.model.parameters(), lr=0.0001)

    def learning(self):
        obs, act, rew, done, obs_next = self.buffer.sample(self.batch_size)
        obs = torch.tensor(obs, dtype=torch.float).to(self.device)
        act = torch.tensor(act, dtype=torch.long).to(self.device)
        rew = torch.tensor(rew, dtype=torch.float).to(self.device)
        done = torch.tensor(done, dtype=torch.float).to(self.device)
        obs_next = torch.tensor(obs_next, dtype=torch.float).to(self.device)

        # Q(s, a) of the taken action (the indexing assumes batch_size == 1)
        q_value = self.model(obs)
        q_value = q_value[:, act]

        # double DQN: the online model picks the next action, the target model evaluates it
        with torch.no_grad():
            q_next_value = self.model(obs_next)
            q_value_max_arg = torch.argmax(q_next_value, dim=1)
            q_next_value = self.target_model(obs_next)
            q_next_value = q_next_value[:, q_value_max_arg]

        expected_q_value = rew + (1 - done) * self.gamma * q_next_value

        loss = F.mse_loss(q_value, expected_q_value)
        self.model_optim.zero_grad()
        loss.backward()
        self.model_optim.step()

        # soft (Polyak) update of the target network
        for model_param, target_param in zip(self.model.parameters(), self.target_model.parameters()):
            target_param.data.copy_(
                (1-self.lamda)*target_param.data + self.lamda*model_param.data)

        return loss.item()

    def training(self, max_step=1000, detail_render=False, is_log=True, queue_maxlen=50):
        try:
            epoch = 0
            writer = SummaryWriter()
            total_reward = 0
            obs, act, rew, done, info = self.env.reset()
            scores = deque(maxlen=queue_maxlen)
            rewards = deque(maxlen=queue_maxlen)
            for step in tqdm(range(max_step)):
                if step > self.expl_before:
                    act = self.select_action(obs, act)
                else:
                    act = self.env.random_action(act)

                rew, done, obs_next, info = self.env.step(act)

                if detail_render:
                    self.env.render()

                self.buffer.add(obs, act, rew, done, obs_next)

                if step > self.train_after:
                    loss = self.learning()
                    writer.add_scalar('loss', loss, step)

                total_reward += rew

                if done:
                    writer.add_scalars(
                        'score', {'reward': total_reward, 'score': info['score']}, epoch)

                    if is_log:
                        rewards.append(total_reward)
                        scores.append(info['score'])
                        mean_reward = np.mean(list(rewards))
                        mean_score = np.mean(list(scores))
                        tqdm.write(f'{epoch}: '+str(mean_reward) +
                                   ", "+str(mean_score))

                    self.env.render(f"Snake, generation {epoch}")

                    obs, act, rew, done, info = self.env.reset()
                    epoch += 1
                    if epoch > argument.history and epoch % 1000 == 0:
                        torch.save(self.model.state_dict(), osp.join(
                            history_dir, f"model_{epoch}.pkl"))
                    total_reward = 0
                else:
                    obs = obs_next
            writer.close()
            torch.save(self.model.state_dict(), osp.join(
                history_dir, f"model_{epoch}.pkl"))
        except KeyboardInterrupt:
            torch.save(self.model.state_dict(), 'model_interrupt.pkl')
            torch.save(self.target_model.state_dict(),
                       'target_model_interrupt.pkl')
            writer.close()
            raise KeyboardInterrupt

    def select_action(self, obs, act):
        # epsilon-greedy action selection
        if np.random.rand() > 1-self.epsilon:
            return self.env.random_action(act)
        q_value = self.model(torch.tensor(
            obs, dtype=torch.float, device=self.device).unsqueeze(0))
        q_value_max_arg = torch.argmax(q_value, dim=1)
        return q_value_max_arg.item()
    def play(self, max_epoch=100, delay=30, is_render=True):
        self.env.render('Snake')
        self.epsilon = 0
        rewards = []
        scores = []
        for epoch in range(max_epoch):
            total_reward = 0
            obs, act, rew, done, info = self.env.reset()
            while not done:
                act = self.select_action(obs, act)
                rew, done, obs, info = self.env.step(act)
                if is_render:
                    self.env.render()
                    pygame.time.delay(delay)
                total_reward += rew
            rewards.append(total_reward)
            scores.append(info['score'])
        self.env.render()
        print('mean reward:', np.mean(rewards), 'mean score:', np.mean(scores))
        return rewards, scores


def test(ddqn: DDQN, dir='history-2022-12-18-03-28-52', epoch=100):
    # evaluate every checkpoint in `dir` and plot the mean score per checkpoint
    print("testing", dir)
    gens = []
    scores = []
    for root, dirs, files in os.walk(dir):
        files = sorted(files, key=lambda x: int(x[6:-4]))
        for name in tqdm(files):
            ddqn.model.load_state_dict(torch.load(osp.join(root, name)))
            rews, scos = ddqn.play(max_epoch=epoch, delay=0)
            scores.append(np.mean(scos))
            gens.append(name[6:-4])
            tqdm.write(name+"\t"+str(scores[-1]))
    gens = np.array(gens)
    scores = np.array(scores)
    plt.plot(scores)
    plt.show()


if __name__ == '__main__':
    if argument.render == 0:
        env_train = Snake_norender(visual_dis=argument.visual)
    elif argument.render >= 1:
        env_train = Snake(visual_dis=argument.visual)

    obs_length = len(env_train.reset()[0])
    print('obs_length=', obs_length)

    ddqn = DDQN((obs_length,), 4, env_train, epsilon=argument.epsilon)

    if argument.test:
        test(ddqn, argument.test)
        exit(0)

    if not argument.play:  # train
        # continue training from previously saved checkpoints if they exist
        try:
            ddqn.model.load_state_dict(torch.load(argument.model_load))
        except Exception as e:
            print(e)
            print('Warning: failed to load the model, using initialized parameters')

        try:
            ddqn.target_model.load_state_dict(torch.load('target_model.pkl'))
        except Exception as e:
            print(e)
            print('Warning: failed to load the target model, using initialized parameters')
        ddqn.training(max_step=argument.step,
                      detail_render=argument.render == 2, is_log=argument.log)
        torch.save(ddqn.model.state_dict(), argument.model_save)
        torch.save(ddqn.target_model.state_dict(), 'target_model.pkl')

    if not argument.train:  # play
        env_play = Snake(visual_dis=argument.visual)
        ddqn_play = DDQN((obs_length,), 4, env_play, epsilon=argument.epsilon)
        if not argument.play:  # play right after training
            try:
                ddqn_play.model.load_state_dict(
                    torch.load(argument.model_save))
            except Exception as e:
                print(e)
                print('Warning: failed to load the model, using initialized parameters')
        else:  # play without training
            try:
                ddqn_play.model.load_state_dict(
                    torch.load(argument.model_load))
            except Exception as e:
                print(e)
                print('Warning: failed to load the model, using initialized parameters')
        ddqn_play.play(10)