├── requirements.txt
├── gym_snake
│   ├── envs
│   │   ├── __init__.py
│   │   ├── snake_env.py
│   │   └── snake.py
│   └── __init__.py
├── rl
│   ├── default_model.pth
│   ├── model.py
│   └── agent.py
├── LICENSE.md
├── README.md
├── .gitignore
├── test.py
└── train.py

/requirements.txt:
--------------------------------------------------------------------------------
gym==0.14.0
numpy==1.16.1
torch==1.1.0

--------------------------------------------------------------------------------
/gym_snake/envs/__init__.py:
--------------------------------------------------------------------------------
from gym_snake.envs.snake_env import SnakeEnv

--------------------------------------------------------------------------------
/rl/default_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Platun0v/snake-gym/HEAD/rl/default_model.pth

--------------------------------------------------------------------------------
/gym_snake/__init__.py:
--------------------------------------------------------------------------------
from gym.envs.registration import register

register(
    id='Snake-v0',
    entry_point='gym_snake.envs:SnakeEnv',
)

--------------------------------------------------------------------------------
/rl/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class DQN(nn.Module):
    def __init__(self, state_size, action_size, seed):
        super(DQN, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_size)

    def forward(self, h):
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        return self.fc3(h)

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Platonov Nickolay

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# snake-gym

## Description
snake-gym is an implementation of the classic game Snake, packaged as an OpenAI Gym environment.

## Dependencies
+ gym (0.14.0)
+ numpy (1.16.1)
+ torch (1.1.0)

## Installation
1. Clone the repository: `git clone https://github.com/Platun0v/snake-gym.git`
2. `cd` into snake-gym and run: `pip install -r requirements.txt`

## Using ready-made programs

### Test network
To watch the pretrained snake, run `python test.py`. All parameters are passed as command-line flags, e.g. `python test.py --render --times 10`.

#### Parameters
+ _load_path_ - path to the trained network weights (default: `rl/default_model.pth`)
+ _render_ - render the game while the agent plays
+ _times_ - number of episodes to run
+ _seed_ - random seed
+ _blocks_ - number of blocks along each side of the square grid
+ _block_size_ - size of one block in pixels

### Train network
To train the snake, run `python train.py`, e.g. `python train.py --episodes 2000 --save_path checkpoint.pth`.

#### Parameters
+ _save_path_ - path where the trained network weights are saved (default: `checkpoint.pth`)
+ _render_ - render the game during training
+ _episodes_ - number of training episodes
+ _max_t_ - maximum number of steps per episode
+ _eps_start_, _eps_end_, _eps_decay_ - epsilon-greedy exploration schedule
+ _seed_ - random seed
+ _blocks_ - number of blocks along each side of the square grid
+ _block_size_ - size of one block in pixels

## Using the environment
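As implemented in `gym_snake/envs/snake_env.py` and `gym_snake/envs/snake.py`, the observation is a 5-element vector: the distance to the nearest obstacle (wall or own body) straight ahead, to the left and to the right of the current heading, followed by the apple's position relative to the head in the snake's frame of reference, all divided by the number of blocks. The action space is `Discrete(3)`: `0` keeps the current direction, `1` turns left, `2` turns right. This summary is a reading of the source rather than separate documentation; a quick way to inspect the spaces:

```python
import gym
import gym_snake

env = gym.make('Snake-v0', blocks=10, block_size=50)

print(env.observation_space)  # Box(5,): obstacle distances (ahead/left/right) + relative apple coordinates
print(env.action_space)       # Discrete(3): 0 = straight, 1 = turn left, 2 = turn right

state = env.reset()
print(state.shape)            # (5,)
```

The loop below drives the environment with random actions: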
```python
import gym
import gym_snake

env = gym.make('Snake-v0')

for i in range(100):
    env.reset()
    for t in range(1000):
        env.render()
        state, reward, done, info = env.step(env.action_space.sample())
        if done:
            print('episode {} finished after {} timesteps'.format(i, t))
            break
```
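### Custom rewards
The reward scheme can be swapped out through `SnakeEnv.set_rewards`, which has to be called before `reset()` because the values are applied when the `Snake` object is rebuilt. The sketch below simply re-passes the defaults from `gym_snake/envs/snake.py` (a small step penalty, an apple reward scaled by the number of apples already eaten, a death penalty, and a harsher penalty for dying within the first 15 steps); treat the numbers as an illustration of the API, not as tuned values.

```python
from math import sqrt

import gym
import gym_snake

env = gym.make('Snake-v0')

# Same values as the defaults in gym_snake/envs/snake.py
env.set_rewards(
    rew_step=-0.25,        # reward for an ordinary step
    rew_apple=3.5,         # base reward for eating an apple
    rew_death=-10.0,       # reward for dying
    rew_death2=-100.0,     # reward for dying within the first 15 steps
    rew_apple_func=lambda cnt, rew: sqrt(cnt) * rew,  # apple reward grows with the apple count
)

state = env.reset()  # the new rewards take effect here
```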
## Reference
https://habr.com/ru/post/465477/

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
**/.idea

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
from time import sleep
import gym
import gym_snake
from rl.model import DQN
import torch
import numpy as np


class Agent:
    def __init__(self, state_size, action_size, pth_path, seed):
        self.model = DQN(state_size, action_size, seed)
        self.model.load_state_dict(torch.load(pth_path))

    def act(self, state):
        state = torch.from_numpy(state).float()
        self.model.eval()
        with torch.no_grad():
            action_values = self.model(state)

        return np.argmax(action_values.data.numpy())


def main(load_path, render, times, seed, block_size, blocks):
    env = get_env(seed, block_size, blocks)
    agent = Agent(env.observation_space.shape[0], env.action_space.n, load_path, seed)
    watch_agent(agent, env, times, render)


def get_env(seed, block_size, blocks):
    env = gym.make('Snake-v0', block_size=block_size, blocks=blocks)
    env.seed(seed)
    return env


def watch_agent(agent, env, times, render):
    scores = []
    apples = []

    for i in range(1, times + 1):
        state = env.reset()
        score = 0
        steps_after_last_apple = 0

        while True:
            if render:
                env.render()
                sleep(0.05)
            action = agent.act(state)
            state, reward, done, info = env.step(action)
            score += reward
            if done:
                break

            steps_after_last_apple += 1
            if steps_after_last_apple > 200:
                break
            if info['apple_ate']:
                steps_after_last_apple = 0

        scores.append(score)
        apples.append(info['apples'])
        print(f'\rEpisode {i}\t'
              f'Average apples: {np.mean(apples):.2f}\t'
              f'Average score: {np.mean(scores):.2f}', end='')
    env.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Test trained agent')
    parser.add_argument('--load_path', default='rl/default_model.pth', type=str)
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--times', default=3, type=int)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--blocks', default=10, type=int)
    parser.add_argument('--block_size', default=50, type=int)

    args = parser.parse_args()
    main(
        load_path=args.load_path,
        render=args.render,
        times=args.times,
        seed=args.seed,
        block_size=args.block_size,
        blocks=args.blocks,
    )

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
from time import sleep
import gym
import gym_snake
import torch
import numpy as np
from rl.agent import Agent
from collections import deque


def main(save_path, render, seed, block_size, blocks, episodes, max_t, eps_start, eps_end, eps_decay):
    env = get_env(seed, block_size, blocks)
    agent = Agent(env.observation_space.shape[0], env.action_space.n, seed)
    agent = train_dqn(agent, env, episodes, max_t, eps_start, eps_end, eps_decay, render)
    torch.save(agent.qnetwork_local.state_dict(), save_path)


def get_env(seed, block_size, blocks):
    env = gym.make('Snake-v0', block_size=block_size, blocks=blocks)
    env.seed(seed)
    return env


def train_dqn(agent, env, episodes, max_t, eps_start, eps_end, eps_decay, render):
    scores = []
    apples = []
    scores_window = deque(maxlen=100)
    apples_window = deque(maxlen=100)
    eps = eps_start
    for i in range(1, episodes + 1):
        state = env.reset()
        score = 0
        for t in range(max_t):
            action = agent.act(state, eps)
            next_state, reward, done, info = env.step(action)
            if render:
                env.render()
            agent.step(state, action, reward, next_state, done)
            state = next_state
            score += reward
            if done:
                break
        scores_window.append(score)  # save most recent score
        apples_window.append(info['apples'])
        scores.append(score)  # save most recent score
        apples.append(info['apples'])
        eps = max(eps_end, eps_decay*eps)  # decrease epsilon
        print(f'\rEpisode {i}\t'
              f'Average apples: {np.mean(apples):.2f}\t'
              f'Average score: {np.mean(scores):.2f}', end='')
        if i % 100 == 0:
            print(f'\rEpisode {i}\t'
                  f'Average apples: {np.mean(apples):.2f}\t'
                  f'Average score: {np.mean(scores):.2f}')
    return agent


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Train agent')
    parser.add_argument('--save_path', default='checkpoint.pth', type=str)
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--blocks', default=10, type=int)
    parser.add_argument('--block_size', default=50, type=int)
    parser.add_argument('--episodes', default=2000, type=int)
    parser.add_argument('--max_t', default=1500, type=int)
    parser.add_argument('--eps_start', default=1.0, type=float)
    parser.add_argument('--eps_end', default=0.01, type=float)
    parser.add_argument('--eps_decay', default=0.995, type=float)

    args = parser.parse_args()
    main(
        save_path=args.save_path,
        render=args.render,
        seed=args.seed,
        block_size=args.block_size,
        blocks=args.blocks,
        episodes=args.episodes,
        max_t=args.max_t,
        eps_start=args.eps_start,
        eps_end=args.eps_end,
        eps_decay=args.eps_decay,
    )

--------------------------------------------------------------------------------
/gym_snake/envs/snake_env.py:
--------------------------------------------------------------------------------
import gym
from gym import error, spaces, utils
from gym.utils import seeding
import numpy as np

from gym_snake.envs.snake import Snake

import logging

logger = logging.getLogger(__name__)


class SnakeEnv(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self, blocks=10, block_size=50):
        self.blocks = blocks
        self.width = block_size * blocks
        self.snake = None

        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            dtype=np.float32,
            low=np.array([0, 0, 0, -1, -1]),
            high=np.array([1, 1, 1, 1, 1]),
        )

        self.seed()
        self.viewer = None
        self.rewards = None

    def set_rewards(self, rew_step, rew_apple, rew_death, rew_death2, rew_apple_func):
        self.rewards = [rew_step, rew_apple, rew_death, rew_death2, rew_apple_func]

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action):
        if action != 0:
            self.snake.direction = self.snake.DIRECTIONS[self.snake.direction[action]]

        info = {}

        self.snake.update()
        info['apple_ate'] = self.snake.apple_ate

        raw_state, reward, done = self.snake.get_raw_state()
        info['apples'] = self.snake.cnt_apples

        state = np.array(raw_state, dtype=np.float32)
        state /= self.blocks

        return state, reward, done, info

    def reset(self):
        if self.rewards:
            self.snake = Snake(self.blocks, self.width // self.blocks, self.np_random,
                               rew_step=self.rewards[0], rew_apple=self.rewards[1],
                               rew_death=self.rewards[2], rew_death2=self.rewards[3],
                               rew_apple_func=self.rewards[4],)
        else:
            self.snake = Snake(self.blocks, self.width // self.blocks, self.np_random)
        raw_state = self.snake.get_raw_state()

        state = np.array(raw_state[0], dtype=np.float32)
        state /= self.blocks

        return state

    def render(self, mode='human'):
        from gym.envs.classic_control import rendering
        w = self.snake.blockw

        if self.viewer is None:
            self.viewer = rendering.Viewer(self.width, self.width)
            apple = self._create_block(w)
            self.apple_trans = rendering.Transform()
            apple.add_attr(self.apple_trans)
            apple.set_color(*self.snake.apple.color)
            self.viewer.add_geom(apple)

            head = self._create_block(w)
            self.head_trans = rendering.Transform()
            head.add_attr(self.head_trans)
            head.set_color(*self.snake.head.color)
            self.viewer.add_geom(head)

            self.body = []
            for i in range(len(self.snake.body)):
                body = self._create_block(w)
                body_trans = rendering.Transform()
                body.add_attr(body_trans)
                body.set_color(*self.snake.body[0].color)

                self.body.append(body_trans)
                self.viewer.add_geom(body)

        self.apple_trans.set_translation(self.snake.apple.x, self.snake.apple.y)
        self.head_trans.set_translation(self.snake.head.x, self.snake.head.y)

        if len(self.snake.body) > len(self.body):
            body = self._create_block(w)
            body_trans = rendering.Transform()
            body.add_attr(body_trans)
            body.set_color(*self.snake.body[0].color)

            self.body.append(body_trans)
            self.viewer.add_geom(body)
        elif len(self.snake.body) < len(self.body):
            self.body, trash = self.body[len(self.body) - len(self.snake.body):], \
                               self.body[:len(self.body) - len(self.snake.body)]
            for i in range(len(trash)):
                trash[i].set_translation(-w, -w)

        for i in range(len(self.body)):
            self.body[i].set_translation(self.snake.body[i].x, self.snake.body[i].y)

        self.viewer.render()

    def _create_block(self, w):
        from gym.envs.classic_control import rendering
        return rendering.FilledPolygon([(0, 0), (0, w), (w, w), (w, 0)])

    def close(self):
        if self.viewer:
            self.viewer.close()
            self.viewer = None

--------------------------------------------------------------------------------
/rl/agent.py:
--------------------------------------------------------------------------------
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from collections import namedtuple, deque

from rl.model import DQN

BUFFER_SIZE = int(1e5)  # replay buffer size
BATCH_SIZE = 64         # minibatch size
GAMMA = 0.99            # discount factor
TAU = 1e-3              # for soft update of target parameters
LR = 5e-4               # learning rate

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent:
    def __init__(self, state_size, action_size, seed):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Network
        self.qnetwork_local = DQN(state_size, action_size, seed).to(device)
        self.qnetwork_target = DQN(state_size, action_size, seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # If enough samples are available in memory, get random subset and learn
        if len(self.memory) > BATCH_SIZE:
            experiences = self.memory.sample()
            self.learn(experiences, GAMMA)

    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.

        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # get max predicted q values (for next states) from target model
        Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)

        # Compute Q targets for current states
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    @staticmethod
    def soft_update(local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def __init__(self, action_size, buffer_size, batch_size, seed):
        """Initialize a ReplayBuffer object.
        Params
        ======
            action_size (int): dimension of each action
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            seed (int): random seed
        """
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.seed = random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self):
        """Randomly sample a batch of experiences from memory."""
        experiences = random.sample(self.memory, k=self.batch_size)

        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)

        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)

--------------------------------------------------------------------------------
/gym_snake/envs/snake.py:
--------------------------------------------------------------------------------
from math import sqrt


class Block:
    def __init__(self, x, y, w, color):
        self.x = x
        self.y = y
        self.w = w
        self.color = color


class Snake:
    DIRECTIONS = {
        'UP': ((0, 1), 'LEFT', 'RIGHT'),
        'DOWN': ((0, -1), 'RIGHT', 'LEFT'),
        'LEFT': ((-1, 0), 'DOWN', 'UP'),
        'RIGHT': ((1, 0), 'UP', 'DOWN'),
    }

    def __init__(self, blocks, block_len, random,
                 rew_step=-0.25, rew_apple=3.5, rew_death=-10.0, rew_death2=-100.0,
                 rew_apple_func=lambda cnt, rew: sqrt(cnt) * rew):
        self.blockw = block_len
        self.blocks = blocks
        self.random = random

        x, y = self.blocks // 2, self.blocks // 2
        self.direction = self.DIRECTIONS['UP']
        self.game_over = False

        self.rew_step = rew_step
        self.rew_apple = rew_apple
        self.rew_death = rew_death
        self.rew_death2 = rew_death2
        self.rew_apple_func = rew_apple_func

        self.body = []
        for i in range(4):
            self.body.append(
                Block(x * block_len,
                      (y - i - 1) * block_len,
                      block_len,
                      (0, 255, 0)))
        self.head = Block(x * block_len,
                          y * block_len,
                          block_len,
                          (0, 255, 255),
                          )
        self.apple = None
        self.generate_apple()
        self.apple_ate = False
        self.cnt_apples = 0
        self.cnt_steps = 0

    def generate_apple(self):
        while True:
            x, y = self.random.randint(0, self.blocks - 1), self.random.randint(0, self.blocks - 1)
            if self.head.x == x * self.blockw and self.head.y == y * self.blockw:
                continue

            flag = False
            for e in self.body:
                if e.x == x * self.blockw and e.y == y * self.blockw:
                    flag = True
                    continue
            if flag:
                continue

            self.apple = Block(
                x * self.blockw,
                y * self.blockw,
                self.blockw,
                (255, 0, 0)
            )
            break

    def update(self):
        if self.head.x < 0 or \
                self.head.x > (self.blocks - 1) * self.blockw or \
                self.head.y < 0 or \
                self.head.y > (self.blocks - 1) * self.blockw:
            self.game_over = True
            return

        self.head.color = (0, 255, 0)
        self.body = [self.head] + self.body[:]
        self.head = Block(self.head.x + self.direction[0][0] * self.blockw,
                          self.head.y + self.direction[0][1] * self.blockw,
                          self.blockw,
                          (0, 255, 255))

        if self.apple is None:
            self.generate_apple()

        if self.apple.x == self.head.x and self.apple.y == self.head.y:
            self.generate_apple()
            self.apple_ate = True
        else:
            self.body = self.body[:-1]

        for e in self.body:
            if e.x == self.head.x and e.y == self.head.y:
                self.game_over = True

    def get_raw_state(self):
        reward = self.rew_step
        self.cnt_steps += 1
        if self.apple_ate:
            self.cnt_apples += 1
            self.apple_ate = False
            reward = self.rew_apple_func(self.cnt_apples, self.rew_apple)
        elif self.game_over:
            if self.cnt_steps < 15:
                reward = self.rew_death2
            else:
                reward = self.rew_death

        state = [
            self.head.x // self.blockw,  # from head to left side
            self.blocks - (self.head.x // self.blockw) - 1,  # from head to right side
            self.head.y // self.blockw,  # from head to down side
            self.blocks - (self.head.y // self.blockw) - 1,  # from head to up side
        ]

        for e in self.body:
            if e.x == self.head.x:
                if e.y < self.head.y:
                    state[2] = min(state[2], (self.head.y - e.y) // self.blockw - 1)
                else:
                    state[3] = min(state[3], (- self.head.y + e.y) // self.blockw - 1)
            elif e.y == self.head.y:
                if e.x < self.head.x:
                    state[0] = min(state[0], (self.head.x - e.x) // self.blockw - 1)
                else:
                    state[1] = min(state[1], (- self.head.x + e.x) // self.blockw - 1)

        apple_crd = [
            (-self.head.x + self.apple.x) // self.blockw,
            (-self.head.y + self.apple.y) // self.blockw,
        ]

        if self.direction == self.DIRECTIONS['UP']:
            state = [state[3], state[0], state[1]]
        if self.direction == self.DIRECTIONS['LEFT']:
            state = [state[0], state[2], state[3]]
            if apple_crd[0] * apple_crd[1] > 0:
                apple_crd[1] *= -1
            else:
                apple_crd[0] *= -1
            apple_crd[0], apple_crd[1] = apple_crd[1], apple_crd[0]
        if self.direction == self.DIRECTIONS['DOWN']:
            state = [state[2], state[1], state[0]]
            apple_crd[0] *= -1
            apple_crd[1] *= -1
        if self.direction == self.DIRECTIONS['RIGHT']:
            state = [state[1], state[3], state[2]]
            if apple_crd[0] * apple_crd[1] > 0:
                apple_crd[0] *= -1
            else:
                apple_crd[1] *= -1
            apple_crd[0], apple_crd[1] = apple_crd[1], apple_crd[0]

        state.extend(apple_crd)

        return state, reward, self.game_over

--------------------------------------------------------------------------------