├── .gitignore
├── LISCENSE.txt
├── README.md
├── gym_snake
│   ├── .gitignore
│   ├── __init__.py
│   └── envs
│       ├── __init__.py
│       ├── snake
│       │   ├── __init__.py
│       │   ├── controller.py
│       │   ├── discrete.py
│       │   ├── grid.py
│       │   ├── grid_unittests.py
│       │   ├── snake.py
│       │   └── snake_unittests.py
│       ├── snake_env.py
│       └── snake_extrahard_env.py
├── requirements.txt
├── setup.py
└── tests
    ├── imgs
    │   ├── biggrid.png
    │   ├── default.png
    │   ├── default_plural.png
    │   ├── nogap.png
    │   └── widegap.png
    └── manual_test.py

/.gitignore:
--------------------------------------------------------------------------------
*.egg-info/

--------------------------------------------------------------------------------
/LISCENSE.txt:
--------------------------------------------------------------------------------
MIT License

Copyright (c) [year] [fullname]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# gym-snake

#### Created in response to OpenAI's [Requests for Research 2.0](https://blog.openai.com/requests-for-research-2/)

## Description
gym-snake is a multi-agent implementation of the classic game [snake](https://www.youtube.com/watch?v=wDbTP0B94AM), implemented as an OpenAI gym environment.

This repo offers two environments: snake-v0 and snake-plural-v0. snake-v0 is the classic snake game; see the section on SnakeEnv for more details. snake-plural-v0 is a version of snake with multiple snakes and multiple snake foods on the map; see the section on SnakeExtraHardEnv for more details.

Many aspects of the game can be changed in both environments. See the Game Details section for specifics.

## Dependencies
- pip
- gym
- numpy
- matplotlib

## Installation
1. Clone this repository
2. Navigate to the cloned repository
3. Run command `$ pip install -e ./`

## Rendering
If you have trouble using the `render()` function in a Jupyter notebook, insert:

    %matplotlib notebook

before calling `render()`.

## Using gym-snake
After installation, you can use gym-snake by making one of two gym environments.
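
The loop below is a minimal sketch of how an agent might drive either environment with random actions. `run_random_episode` is a hypothetical helper, not part of gym-snake. It assumes the old gym API used throughout this README: `reset()` returns an observation and `step()` returns `(observation, reward, done, info)`, which is what `Controller.step` returns. Actions are the integers 0 to 3 (UP, RIGHT, DOWN, LEFT).

```python
import random


def run_random_episode(env, n_snakes=1, n_actions=4, max_steps=1000, seed=None):
    """Play one episode with uniformly random actions; return the summed reward per snake.

    Assumes env.reset() returns an observation and env.step(action) returns
    (observation, reward, done, info). With one snake the action is a single int
    and the reward a scalar; with several snakes both are sequences, one entry per snake.
    """
    rng = random.Random(seed)
    env.reset()
    totals = [0] * n_snakes
    for _ in range(max_steps):
        if n_snakes == 1:
            action = rng.randrange(n_actions)
        else:
            # Keep one entry for every snake that started the game, even dead ones
            action = [rng.randrange(n_actions) for _ in range(n_snakes)]
        _, reward, done, _ = env.step(action)
        rewards = [reward] if n_snakes == 1 else list(reward)
        for i, r in enumerate(rewards):
            totals[i] += r
        if done:
            break
    return totals
```

With gym-snake installed, `run_random_episode(gym.make('snake-v0'))` or `run_random_episode(gym.make('snake-plural-v0'), n_snakes=3)` should work, provided the environments pass the controller's return values through unchanged.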

#### SnakeEnv
Use `gym.make('snake-v0')` to make a new snake environment with the following default options (see Game Details to understand what each variable does):

    grid_size = [15,15]
    unit_size = 10
    unit_gap = 1
    snake_size = 3
    n_snakes = 1
    n_foods = 1

#### SnakeExtraHardEnv
Use `gym.make('snake-plural-v0')` to make a new snake environment with the following default parameters (see Game Details to understand what each variable does):

    grid_size = [25,25]
    unit_size = 10
    unit_gap = 1
    snake_size = 5
    n_snakes = 3
    n_foods = 2


## Game Details
You're probably familiar with the game of snake. This is an OpenAI gym implementation of the game with options for multiple snakes and multiple foods.

#### Rewards
A +1 reward is returned when a snake eats a food.

A -1 reward is returned when a snake dies.

No extra reward is given to the winning snake in plural play.

#### Game Options

- _grid_size_ - A pair (width, height) giving the number of units along each axis of the snake grid.
- _unit_size_ - Side length of a single grid unit, in pixels of the numpy observation array.
- _unit_gap_ - Number of pixels separating each unit of the grid; must satisfy 0 <= unit_gap < unit_size. Space between the units makes it easier to see the direction of the snake's body.
- _snake_size_ - Number of body units for each snake at the start of the game.
- _n_snakes_ - Number of individual snakes on the grid.
- _n_foods_ - Number of food units (the stuff that makes the snakes grow) on the grid at any given time.
- _random_init_ - If set to false, the food units initialize to the same locations at each reset.

Each of these options is a member variable of the environment and takes effect the next time the environment is reset.
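
The `grid_size`, `unit_size`, and `unit_gap` options together fix the layout of the observation pixel array. The sketch below mirrors the arithmetic in `Grid.__init__` and `Grid.cover`; the function names are hypothetical helpers for illustration, not part of the package.

```python
def observation_shape(grid_size, unit_size):
    """Shape (height, width, channels) of the pixel array for grid_size = [width, height] units."""
    return (grid_size[1] * unit_size, grid_size[0] * unit_size, 3)


def unit_pixel_slices(coord, unit_size, unit_gap):
    """Row and column slices covering the colored part of unit coord = (x, y).

    Each unit starts at pixel coord * unit_size; its last unit_gap rows and
    columns are left as gap, as in Grid.cover().
    """
    x, y = coord
    row0, col0 = y * unit_size, x * unit_size
    return (slice(row0, row0 + unit_size - unit_gap),
            slice(col0, col0 + unit_size - unit_gap))


def unit_of_pixel(row, col, unit_size):
    """(x, y) unit containing pixel (row, col); gap pixels belong to the unit left of / above them."""
    return (col // unit_size, row // unit_size)
```

With a numpy observation `obs`, `obs[rows, cols]` selects the colored pixels of one unit. The remaining `unit_gap` pixels to the right of and below each unit are where `Grid.connect()` draws the links between body units.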

To use 5 food tokens in the regular version, for instance, set `n_foods` on the environment and then reset it:

    env = gym.make('snake-v0')
    env.n_foods = 5
    observation = env.reset()

This creates a vanilla snake environment with 5 food tokens on the map.


#### Environment Parameter Examples
Below is the default setting for `snake-v0` play, a 15x15 unit grid.

![default](./tests/imgs/default.png)


Below is the default setting for `snake-plural-v0` play, a 25x25 unit grid.

![default_plural](./tests/imgs/default_plural.png)

Below, `env.unit_gap` is 0 on a 30x30 grid.

![nogap](./tests/imgs/nogap.png)

Below, `env.unit_gap` is set to half the unit size on a 15x15 grid.

![widegap](./tests/imgs/widegap.png)

Below is a big grid with lots of food and small snakes.

![biggrid](./tests/imgs/biggrid.png)

#### General Info
The snake environment is built from three interacting classes: a Snake class, a Grid class, and a Controller class. Each holds information about the environment, and each can be accessed through the gym environment:

    import gym
    import gym_snake

    # Construct Environment
    env = gym.make('snake-v0')
    observation = env.reset() # Constructs an instance of the game

    # Controller
    game_controller = env.controller

    # Grid
    grid_object = game_controller.grid
    grid_pixels = grid_object.grid

    # Snake(s)
    snakes_array = game_controller.snakes
    snake_object1 = snakes_array[0]

#### Using Multiple Snakes
Snakes can be distinguished by the Green value of their `head_color` attribute. Each head color consists of [Red=255, Green=uniqueNonZeroValue, Blue=0].
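
Because head units are the only cells whose red channel is 255 (body units use 1), an agent can locate each snake's head directly in the observation. The sketch below is a hypothetical helper, `find_heads`, that assumes the [255, (i+1)*10, 0] head-color scheme given in this section and samples the top-left pixel of each unit, as `Grid.color_of` does.

```python
def find_heads(obs, grid_size, unit_size):
    """Return {snake_index: (x, y)} for every snake head visible in obs.

    obs is indexed as obs[row][col] -> (R, G, B), like the environment's pixel array.
    Each unit is sampled at its top-left pixel, as Grid.color_of() does.
    """
    heads = {}
    for y in range(grid_size[1]):
        for x in range(grid_size[0]):
            r, g, _ = obs[y * unit_size][x * unit_size]
            if r == 255:  # only head units have a red channel of 255
                # Assumes head color [255, (i+1)*10, 0] for snake index i
                heads[int(g) // 10 - 1] = (x, y)
    return heads
```

Judging from `Controller.move_result`, a snake that died on the current step may still show up here, since its head is only erased on the following step.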

For each snake instantiated, the head color corresponds to its index within the controller's `snakes` array: the head color is [255, (i+1)*10, 0], where i is the index of the snake.

When using multiple snakes, you pass an array of actions at each step, one per snake, and receive an array of rewards, one per snake. A snake receives a reward of -1 on the step it dies and a reward of 0 on every subsequent step. The action at a dead snake's index is ignored, but the action array must still contain one entry for every snake that started the game.

#### Coordinates
Each unit of the game takes up multiple pixels within the grid. Each unit has an x,y coordinate, where (0,0) is the upper-left unit of the grid and (`grid_object.grid_size[0]-1`, `grid_object.grid_size[1]-1`) is the lower-right unit. Positions of snake food and snake bodies are encoded using this coordinate system.

#### Snake Class
This class holds all pertinent information about an individual snake. Useful information includes:

    # Action constants denote the action space.
    snake_object1.UP # Equal to integer 0
    snake_object1.RIGHT # Equal to integer 1
    snake_object1.DOWN # Equal to integer 2
    snake_object1.LEFT # Equal to integer 3

    # Member Variables
    snake_object1.direction # Indicates which direction the snake's head is pointing; initially points DOWN
    snake_object1.head # x,y Coordinate of the snake's head
    snake_object1.head_color # [R,G,B] color of the snake's head, with an R value of 255
    snake_object1.body # deque containing the coordinates of the snake's body ordered from tail to neck [furthest from head, ..., closest to head]

#### Grid Class
This class holds all pertinent information about the grid that the snakes move on. Useful information includes:

    # Color constants give information about the colors of the grid
    # Each is an ndarray with dtype uint8
    grid_object.BODY_COLOR # [1,0,0] Color of snake body units
    grid_object.HEAD_COLOR # [255,0,0] Base head color; the head of snake i is drawn as [255, (i+1)*10, 0]
    grid_object.FOOD_COLOR # [0,0,255] Color of food units
    grid_object.SPACE_COLOR # [0,255,0] Color of blank space

    # Member Variables
    grid_object.unit_size # See Game Options
    grid_object.unit_gap # See Game Options
    grid_object.grid_size # See Game Options
    grid_object.grid # Numpy [R,G,B] pixel array of game

#### Controller Class
The Controller holds a grid object and an array of snakes that move on the grid, and it handles the game logic between them. Actions are taken through this class, and its initialization parameters set the initial parameters of the grid and snake objects in the game. Useful information includes:

    # Member variables
    game_controller.grid # An instance of the grid class for the game
    game_controller.snakes # An array of snake objects that are on the board. If a snake dies, its entry becomes None and it is erased from the grid on the following step.


--------------------------------------------------------------------------------
/gym_snake/.gitignore:
--------------------------------------------------------------------------------
__pycache__/

--------------------------------------------------------------------------------
/gym_snake/__init__.py:
--------------------------------------------------------------------------------
from gym.envs.registration import register

register(
    id='snake-v0',
    entry_point='gym_snake.envs:SnakeEnv',
)
register(
    id='snake-plural-v0',
    entry_point='gym_snake.envs:SnakeExtraHardEnv',
)

--------------------------------------------------------------------------------
/gym_snake/envs/__init__.py:
--------------------------------------------------------------------------------
from gym_snake.envs.snake_env import SnakeEnv
from gym_snake.envs.snake_extrahard_env import SnakeExtraHardEnv

--------------------------------------------------------------------------------
/gym_snake/envs/snake/__init__.py:
--------------------------------------------------------------------------------
from gym_snake.envs.snake.snake import Snake
from gym_snake.envs.snake.grid import Grid
from gym_snake.envs.snake.controller import Controller
from gym_snake.envs.snake.discrete import Discrete

--------------------------------------------------------------------------------
/gym_snake/envs/snake/controller.py:
--------------------------------------------------------------------------------
from gym_snake.envs.snake import Snake
from gym_snake.envs.snake import Grid
import numpy as np

class Controller():
    """
    This class combines the Snake and Grid classes (food is managed by Grid) to handle the game logic.
    """

    def __init__(self, grid_size=[30,30], unit_size=10, unit_gap=1, snake_size=3, n_snakes=1, n_foods=1, random_init=True):

        assert n_snakes < grid_size[0]//3
        assert n_snakes < 25
        assert snake_size < grid_size[1]//2
        assert unit_gap >= 0 and unit_gap < unit_size

        self.snakes_remaining = n_snakes
        self.grid = Grid(grid_size, unit_size, unit_gap)

        self.snakes = []
        self.dead_snakes = []
        for i in range(1,n_snakes+1):
            start_coord = [i*grid_size[0]//(n_snakes+1), snake_size+1]
            self.snakes.append(Snake(start_coord, snake_size))
            color = [self.grid.HEAD_COLOR[0], i*10, 0]
            self.snakes[-1].head_color = color
            self.grid.draw_snake(self.snakes[-1], color)
            self.dead_snakes.append(None)

        if not random_init:
            for i in range(2,n_foods+2):
                start_coord = [i*grid_size[0]//(n_foods+3), grid_size[1]-5]
                self.grid.place_food(start_coord)
        else:
            for i in range(n_foods):
                self.grid.new_food()

    def move_snake(self, direction, snake_idx):
        """
        Moves the specified snake according to the game's rules dependent on the direction.
        Does not draw head and does not check for reward scenarios. See move_result for these
        functionalities.
        """

        snake = self.snakes[snake_idx]
        if snake is None:
            return

        # Cover old head position with body
        self.grid.cover(snake.head, self.grid.BODY_COLOR)
        # Erase tail without popping so as to redraw if food eaten
        self.grid.erase(snake.body[0])
        # Find and set next head position conditioned on direction
        snake.action(direction)

    def move_result(self, direction, snake_idx=0):
        """
        Checks for food and death collisions after moving snake. Draws head of snake if
        no death scenarios.
        """

        snake = self.snakes[snake_idx]
        if snake is None:
            return 0

        # Check for death of snake
        if self.grid.check_death(snake.head):
            self.dead_snakes[snake_idx] = self.snakes[snake_idx]
            self.snakes[snake_idx] = None
            self.grid.cover(snake.head, snake.head_color) # Avoid miscount of grid.open_space
            self.grid.connect(snake.body.popleft(), snake.body[0], self.grid.SPACE_COLOR)
            reward = -1
        # Check for reward
        elif self.grid.food_space(snake.head):
            self.grid.draw(snake.body[0], self.grid.BODY_COLOR) # Redraw tail
            self.grid.connect(snake.body[0], snake.body[1], self.grid.BODY_COLOR)
            self.grid.cover(snake.head, snake.head_color) # Avoid miscount of grid.open_space
            reward = 1
            self.grid.new_food()
        else:
            reward = 0
            empty_coord = snake.body.popleft()
            self.grid.connect(empty_coord, snake.body[0], self.grid.SPACE_COLOR)
            self.grid.draw(snake.head, snake.head_color)

        self.grid.connect(snake.body[-1], snake.head, self.grid.BODY_COLOR)

        return reward

    def kill_snake(self, snake_idx):
        """
        Deletes snake from game and subtracts from snakes_remaining
        """

        assert self.dead_snakes[snake_idx] is not None
        self.grid.erase(self.dead_snakes[snake_idx].head)
        self.grid.erase_snake_body(self.dead_snakes[snake_idx])
        self.dead_snakes[snake_idx] = None
        self.snakes_remaining -= 1

    def step(self, directions):
        """
        Takes an action for each snake in the specified direction and collects their rewards
        and dones.

        directions - a single int (one snake), or a tuple, list, or ndarray of directions
                     corresponding to each snake.
        """

        # A single action may be a Python int or a numpy integer (e.g. from Discrete.sample)
        single_action = isinstance(directions, (int, np.integer))

        # Ensure no more play until reset
        if self.snakes_remaining < 1 or self.grid.open_space < 1:
            if single_action or len(directions) == 1:
                return self.grid.grid.copy(), 0, True, {"snakes_remaining":self.snakes_remaining}
            else:
                return self.grid.grid.copy(), [0]*len(directions), True, {"snakes_remaining":self.snakes_remaining}

        rewards = []

        if single_action:
            directions = [directions]

        for i, direction in enumerate(directions):
            if self.snakes[i] is None and self.dead_snakes[i] is not None:
                self.kill_snake(i)
            self.move_snake(direction,i)
            rewards.append(self.move_result(direction, i))

        done = self.snakes_remaining < 1 or self.grid.open_space < 1
        if len(rewards) == 1:
            return self.grid.grid.copy(), rewards[0], done, {"snakes_remaining":self.snakes_remaining}
        else:
            return self.grid.grid.copy(), rewards, done, {"snakes_remaining":self.snakes_remaining}

--------------------------------------------------------------------------------
/gym_snake/envs/snake/discrete.py:
--------------------------------------------------------------------------------
import numpy as np

class Discrete():
    def __init__(self, n_actions):
        self.dtype = np.int32
        self.n = n_actions
        self.actions = np.arange(self.n, dtype=self.dtype)
        self.shape = self.actions.shape

    def contains(self, argument):
        for action in self.actions:
            if action == argument:
                return True
        return False

    def sample(self):
        return np.random.choice(self.n)

--------------------------------------------------------------------------------
/gym_snake/envs/snake/grid.py:
--------------------------------------------------------------------------------
import numpy as np

class Grid():

    """
    This class contains all data related to the grid in which the game is
    contained.
    The information is stored as a numpy array of pixels.
    The grid is treated as a cartesian [x,y] plane in which [0,0] is located at
    the upper left most pixel and [max_x, max_y] is located at the lower right most pixel.

    Note that it is assumed spaces that can kill a snake have a non-zero value as their 0 channel.
    It is also assumed that HEAD_COLOR has a 255 value as its 0 channel.
    """

    BODY_COLOR = np.array([1,0,0], dtype=np.uint8)
    HEAD_COLOR = np.array([255, 0, 0], dtype=np.uint8)
    FOOD_COLOR = np.array([0,0,255], dtype=np.uint8)
    SPACE_COLOR = np.array([0,255,0], dtype=np.uint8)
    COLORS = np.asarray([BODY_COLOR, HEAD_COLOR, FOOD_COLOR, SPACE_COLOR])

    def __init__(self, grid_size=[30,30], unit_size=10, unit_gap=1):
        """
        grid_size - tuple, list, or ndarray specifying number of atomic units in
                    both the x and y direction
        unit_size - integer denoting the atomic size of grid units in pixels
        unit_gap - integer number of pixels left blank between adjacent units
        """

        self.unit_size = int(unit_size)
        self.unit_gap = unit_gap
        self.grid_size = np.asarray(grid_size, dtype=int) # size in terms of units (np.int was removed in NumPy 1.24)
        height = self.grid_size[1]*self.unit_size
        width = self.grid_size[0]*self.unit_size
        channels = 3
        self.grid = np.zeros((height, width, channels), dtype=np.uint8)
        self.grid[:,:,:] = self.SPACE_COLOR
        self.open_space = grid_size[0]*grid_size[1]

    def check_death(self, head_coord):
        """
        Checks the grid to see if argued head_coord has collided with a death space
        (i.e. snake or wall)

        head_coord - x,y integer coordinates as a tuple, list, or ndarray
        """
        return self.off_grid(head_coord) or self.snake_space(head_coord)

    def color_of(self, coord):
        """
        Returns the color of the specified coordinate

        coord - x,y integer coordinates as a tuple, list, or ndarray
        """

        return self.grid[int(coord[1]*self.unit_size), int(coord[0]*self.unit_size), :]

    def connect(self, coord1, coord2, color=BODY_COLOR):
        """
        Draws connection between two adjacent pieces using the specified color.
        Created to indicate the relative ordering of the snake's body.
        coord1 and coord2 must be adjacent.

        coord1 - x,y integer coordinates as a tuple, list, or ndarray
        coord2 - x,y integer coordinates as a tuple, list, or ndarray
        color - [R,G,B] values as a tuple, list, or ndarray
        """

        # Check for adjacency
        # Next to one another:
        adjacency1 = (np.abs(coord1[0]-coord2[0]) == 1 and np.abs(coord1[1]-coord2[1]) == 0)
        # Stacked on one another:
        adjacency2 = (np.abs(coord1[0]-coord2[0]) == 0 and np.abs(coord1[1]-coord2[1]) == 1)
        assert adjacency1 or adjacency2

        if adjacency1: # x values differ
            min_x, max_x = sorted([coord1[0], coord2[0]])
            min_x = min_x*self.unit_size+self.unit_size-self.unit_gap
            max_x = max_x*self.unit_size
            self.grid[coord1[1]*self.unit_size, min_x:max_x, :] = color
            self.grid[coord1[1]*self.unit_size+self.unit_size-self.unit_gap-1, min_x:max_x, :] = color
        else: # y values differ
            min_y, max_y = sorted([coord1[1], coord2[1]])
            min_y = min_y*self.unit_size+self.unit_size-self.unit_gap
            max_y = max_y*self.unit_size
            self.grid[min_y:max_y, coord1[0]*self.unit_size, :] = color
            self.grid[min_y:max_y, coord1[0]*self.unit_size+self.unit_size-self.unit_gap-1, :] = color

    def cover(self, coord, color):
        """
        Colors a single space on the grid. Use erase if creating an empty space on the grid.
        This function is used like draw but without affecting the open_space count.

        coord - x,y integer coordinates as a tuple, list, or ndarray
        color - [R,G,B] values as a tuple, list, or ndarray
        """

        if self.off_grid(coord):
            return False
        x = int(coord[0]*self.unit_size)
        end_x = x+self.unit_size-self.unit_gap
        y = int(coord[1]*self.unit_size)
        end_y = y+self.unit_size-self.unit_gap
        self.grid[y:end_y, x:end_x, :] = np.asarray(color, dtype=np.uint8)
        return True

    def draw(self, coord, color):
        """
        Colors a single space on the grid. Use erase if creating an empty space on the grid.
        Affects the open_space count.

        coord - x,y integer coordinates as a tuple, list, or ndarray
        color - [R,G,B] values as a tuple, list, or ndarray
        """

        if self.cover(coord, color):
            self.open_space -= 1
            return True
        else:
            return False


    def draw_snake(self, snake, head_color=HEAD_COLOR):
        """
        Draws a snake with the given head color.

        snake - Snake object
        head_color - [R,G,B] values as a tuple, list, or ndarray
        """

        self.draw(snake.head, head_color)
        prev_coord = None
        for i in range(len(snake.body)):
            coord = snake.body.popleft()
            self.draw(coord, self.BODY_COLOR)
            if prev_coord is not None:
                self.connect(prev_coord, coord, self.BODY_COLOR)
            snake.body.append(coord)
            prev_coord = coord
        self.connect(prev_coord, snake.head, self.BODY_COLOR)

    def erase(self, coord):
        """
        Colors the entire coordinate with SPACE_COLOR to erase potential
        connection lines.

        coord - (x,y) as tuple, list, or ndarray
        """
        if self.off_grid(coord):
            return False
        self.open_space += 1
        x = int(coord[0]*self.unit_size)
        end_x = x+self.unit_size
        y = int(coord[1]*self.unit_size)
        end_y = y+self.unit_size
        self.grid[y:end_y, x:end_x, :] = self.SPACE_COLOR
        return True

    def erase_connections(self, coord):
        """
        Colors the dead space of the given coordinate with SPACE_COLOR to erase potential
        connection lines

        coord - (x,y) as tuple, list, or ndarray
        """

        if self.off_grid(coord):
            return False
        # Erase Horizontal Row Below Coord
        x = int(coord[0]*self.unit_size)
        end_x = x+self.unit_size
        y = int(coord[1]*self.unit_size)+self.unit_size-self.unit_gap
        end_y = y+self.unit_gap
        self.grid[y:end_y, x:end_x, :] = self.SPACE_COLOR

        # Erase the Vertical Column to Right of Coord
        x = int(coord[0]*self.unit_size)+self.unit_size-self.unit_gap
        end_x = x+self.unit_gap
        y = int(coord[1]*self.unit_size)
        end_y = y+self.unit_size
        self.grid[y:end_y, x:end_x, :] = self.SPACE_COLOR

        return True

    def erase_snake_body(self, snake):
        """
        Removes the argued snake's body from the grid. The head is not erased here;
        Controller.kill_snake erases it separately.

        snake - Snake object
        """

        for i in range(len(snake.body)):
            self.erase(snake.body.popleft())

    def food_space(self, coord):
        """
        Checks if argued coord is snake food

        coord - x,y integer coordinates as a tuple, list, or ndarray
        """

        return np.array_equal(self.color_of(coord), self.FOOD_COLOR)

    def place_food(self, coord):
        """
        Draws a food at the coord. Ensures the same placement for
        each food at the beginning of a new episode. This is useful for
        experimentation with curiosity driven behaviors.

        coord - x,y integer coordinates as a tuple, list, or ndarray
        """
        if self.open_space < 1 or not np.array_equal(self.color_of(coord), self.SPACE_COLOR):
            return False
        self.draw(coord, self.FOOD_COLOR)
        return True

    def new_food(self):
        """
        Draws a food on a random, open unit of the grid.
        Returns true if space left. Otherwise returns false.
        """

        if self.open_space < 1:
            return False
        coord_not_found = True
        while(coord_not_found):
            coord = (np.random.randint(0,self.grid_size[0]), np.random.randint(0,self.grid_size[1]))
            if np.array_equal(self.color_of(coord), self.SPACE_COLOR):
                coord_not_found = False
        self.draw(coord, self.FOOD_COLOR)
        return True

    def off_grid(self, coord):
        """
        Checks if argued coord is off of the grid

        coord - x,y integer coordinates as a tuple, list, or ndarray
        """

        return coord[0]<0 or coord[0]>=self.grid_size[0] or coord[1]<0 or coord[1]>=self.grid_size[1]

    def snake_space(self, coord):
        """
        Checks if argued coord is occupied by a snake

        coord - x,y integer coordinates as a tuple, list, or ndarray
        """

        color = self.color_of(coord)
        return np.array_equal(color, self.BODY_COLOR) or color[0] == self.HEAD_COLOR[0]

--------------------------------------------------------------------------------
/gym_snake/envs/snake/grid_unittests.py:
--------------------------------------------------------------------------------
import unittest
from grid import Grid
from snake import Snake
import numpy as np

class GridTests(unittest.TestCase):

    grid_size = [30,30]
    unit_size = 10

    def test_grid_Initialization(self):
        grid = Grid(self.grid_size, self.unit_size)
        expected_size = [300,300,3]
        expected_grid = np.zeros(expected_size, dtype=np.uint8)
        expected_grid[:,:,1] = 255
        self.assertTrue(np.array_equal(grid.grid, expected_grid))

    def test_constant_Initialization(self):
        grid = Grid(self.grid_size, self.unit_size)
        self.assertTrue(grid.unit_size == self.unit_size)
        self.assertTrue(np.array_equal(grid.grid_size, self.grid_size))

    def test_color_Initialization(self):
        grid = Grid(self.grid_size, self.unit_size)
        expected_color = np.array([0,255,0], dtype=np.uint8)
        for i in range(grid.grid.shape[0]):
            for j in range(grid.grid.shape[1]):
                self.assertTrue(np.array_equal(grid.grid[i,j,:],expected_color))

    def test_color_of_Color(self):
        grid = Grid(self.grid_size, self.unit_size)
        expected_color = np.array([0,255,0], dtype=np.uint8)
        self.assertTrue(np.array_equal(grid.color_of([0,0]),expected_color))

    def test_color_of_Coordinate(self):
        grid = Grid(self.grid_size, self.unit_size)
        coord = [3,2]
        expected_color = np.array(grid.BODY_COLOR, dtype=np.uint8)
        grid.grid[coord[1]*self.unit_size,coord[0]*self.unit_size,:] = expected_color
        self.assertTrue(np.array_equal(grid.color_of(coord),expected_color))

    def test_draw_Positive(self):
        grid = Grid(self.grid_size, self.unit_size)
        expected_color = np.array(grid.BODY_COLOR, dtype=np.uint8)
        coord = [3,2]
        grid.draw(coord, expected_color)
        for y in range(grid.grid.shape[0]):
            for x in range(grid.grid.shape[1]):
                if y >= coord[1]*self.unit_size and y < coord[1]*self.unit_size+grid.unit_size-grid.unit_gap and x >= coord[0]*self.unit_size and x < coord[0]*self.unit_size+grid.unit_size-grid.unit_gap:
                    self.assertTrue(np.array_equal(grid.grid[y,x,:],expected_color))
                else:
                    self.assertFalse(np.array_equal(grid.grid[y,x,:],expected_color))

    def test_draw_Negative(self):
        grid = Grid(self.grid_size, self.unit_size)
        expected_color = grid.SPACE_COLOR
        coord = [3,2]
        grid.draw(coord, grid.BODY_COLOR)
        for y in range(grid.grid.shape[0]):
            for x in range(grid.grid.shape[1]):
                if y >= coord[1]*self.unit_size and y < coord[1]*self.unit_size+grid.unit_size-grid.unit_gap and x >= coord[0]*self.unit_size and x < coord[0]*self.unit_size+grid.unit_size-grid.unit_gap:
                    self.assertFalse(np.array_equal(grid.grid[y,x,:],expected_color))
                else:
                    self.assertTrue(np.array_equal(grid.grid[y,x,:],expected_color))

    def test_draw_snake_Positive(self):
        grid = Grid(self.grid_size, self.unit_size)
        snake_size = 3
        head_coord = [10,10]
        snake = Snake(head_coord, snake_size)
        grid.draw_snake(snake, head_color=grid.HEAD_COLOR)

        expected_colors = np.array([grid.HEAD_COLOR, grid.BODY_COLOR, grid.BODY_COLOR], dtype=np.uint8)
        expected_coords = np.array([[10,10], [10,9], [10,8]])
        for coord,color in zip(expected_coords, expected_colors):
            self.assertTrue(np.array_equal(grid.color_of(coord), color))

    def test_draw_snake_Negative(self):
        grid = Grid(self.grid_size, self.unit_size)
        snake_size = 3
        head_coord = [10,10]
        snake = Snake(head_coord, snake_size)
        grid.draw_snake(snake, grid.HEAD_COLOR)

        expected_color = grid.SPACE_COLOR
        expected_coords = [(10,10), (10,9), (10,8)]
        for i,j in zip(range(grid.grid_size[0]),range(grid.grid_size[1])):
            coord = (i,j)
            if coord == expected_coords[0] or coord == expected_coords[1] or coord == expected_coords[2]:
                self.assertFalse(np.array_equal(grid.color_of(coord), expected_color))
            else:
                self.assertTrue(np.array_equal(grid.color_of(coord), expected_color))

    def test_draw_snake_Snake_Data(self):
        grid = Grid(self.grid_size, self.unit_size)
        snake_size = 3
        head_coord = [10,10]
        snake = Snake(head_coord, snake_size)
        grid.draw_snake(snake, grid.HEAD_COLOR)

        expected_coords = [[10,8],[10,9]]
        for i in range(len(snake.body)):
            self.assertTrue(np.array_equal(snake.body.popleft(), expected_coords[i]))

    def test_erase_snake_body(self):
        grid = Grid(self.grid_size, self.unit_size)
        snake_size = 3
        head_coord = [10,10]
        snake = Snake(head_coord, snake_size)
        grid.draw_snake(snake, grid.HEAD_COLOR)
        snake.action(1)
        grid.erase_snake_body(snake)

        expected_color = grid.SPACE_COLOR
        for i,j in zip(range(grid.grid_size[0]),range(grid.grid_size[1])):
            coord = (i,j)
            self.assertTrue(np.array_equal(grid.color_of(coord), expected_color))

    def test_new_food(self):
        grid = Grid(self.grid_size, self.unit_size)
        expected_coord = (10,11)
        for x in range(grid.grid_size[0]):
            for y in range(grid.grid_size[1]):
                coord = (x,y)
                if coord != expected_coord:
                    grid.draw(coord, grid.BODY_COLOR)

        self.assertTrue(grid.new_food())
        self.assertTrue(np.array_equal(grid.color_of(expected_coord), grid.FOOD_COLOR))

    def test_new_food_nospace(self):
        grid = Grid(self.grid_size, self.unit_size)
        for x in range(grid.grid_size[0]):
            for y in range(grid.grid_size[1]):
                coord = (x,y)
                grid.draw(coord, grid.BODY_COLOR)
        self.assertFalse(grid.new_food())

    def test_snake_space_BODY(self):
        grid = Grid(self.grid_size, self.unit_size)
        coord = (10,11)
        grid.draw(coord, grid.BODY_COLOR)
        self.assertTrue(grid.snake_space(coord))

    def test_snake_space_HEAD(self):
        grid = Grid(self.grid_size, self.unit_size)
        coord = (10,11)
        grid.draw(coord, grid.HEAD_COLOR)
        self.assertTrue(grid.snake_space(coord))

    def test_snake_space_FOOD(self):
        grid = Grid(self.grid_size, self.unit_size)
        coord = (10,11)
        grid.draw(coord, grid.FOOD_COLOR)
        self.assertFalse(grid.snake_space(coord))

    def test_snake_space_SPACE(self):
        grid = Grid(self.grid_size, self.unit_size)
        coord = (10,11)
        grid.draw(coord, grid.SPACE_COLOR)
161 | self.assertFalse(grid.snake_space(coord)) 162 | 163 | def test_off_grid_UP(self): 164 | grid = Grid(self.grid_size, self.unit_size) 165 | coord = (0,-1) 166 | self.assertTrue(grid.off_grid(coord)) 167 | 168 | def test_off_grid_RIGHT(self): 169 | grid = Grid(self.grid_size, self.unit_size) 170 | coord = (self.grid_size[0],0) 171 | self.assertTrue(grid.off_grid(coord)) 172 | 173 | def test_off_grid_DOWN(self): 174 | grid = Grid(self.grid_size, self.unit_size) 175 | coord = (0,self.grid_size[1]) 176 | self.assertTrue(grid.off_grid(coord)) 177 | 178 | def test_off_grid_LEFT(self): 179 | grid = Grid(self.grid_size, self.unit_size) 180 | coord = (-1,0) 181 | self.assertTrue(grid.off_grid(coord)) 182 | 183 | def test_food_space_FOOD(self): 184 | grid = Grid(self.grid_size, self.unit_size) 185 | coord = (10,11) 186 | grid.draw(coord, grid.FOOD_COLOR) 187 | self.assertTrue(grid.food_space(coord)) 188 | 189 | def test_food_space_BODY(self): 190 | grid = Grid(self.grid_size, self.unit_size) 191 | coord = (10,11) 192 | grid.draw(coord, grid.BODY_COLOR) 193 | self.assertFalse(grid.food_space(coord)) 194 | 195 | def test_food_space_HEAD(self): 196 | grid = Grid(self.grid_size, self.unit_size) 197 | coord = (10,11) 198 | grid.draw(coord, grid.HEAD_COLOR) 199 | self.assertFalse(grid.food_space(coord)) 200 | 201 | def test_food_space_SPACE(self): 202 | grid = Grid(self.grid_size, self.unit_size) 203 | coord = (10,11) 204 | grid.draw(coord, grid.SPACE_COLOR) 205 | self.assertFalse(grid.food_space(coord)) 206 | 207 | def test_connect_x(self): 208 | grid = Grid(self.grid_size, self.unit_size) 209 | expected_color = grid.BODY_COLOR 210 | coord1 = [3,2] 211 | coord2 = [4,2] 212 | grid.connect(coord1, coord2, expected_color) 213 | for y in range(grid.grid.shape[0]): 214 | for x in range(grid.grid.shape[1]): 215 | if (y == coord1[1]*self.unit_size or y == coord1[1]*self.unit_size+grid.unit_size-grid.unit_gap-1) and (x < coord2[0]*self.unit_size and x >= 
coord1[0]*self.unit_size+grid.unit_size-grid.unit_gap): 216 | self.assertTrue(np.array_equal(grid.grid[y,x,:],expected_color)) 217 | else: 218 | self.assertFalse(np.array_equal(grid.grid[y,x,:],expected_color)) 219 | 220 | def test_connect_y(self): 221 | grid = Grid(self.grid_size, self.unit_size) 222 | expected_color = grid.BODY_COLOR 223 | coord1 = [2,3] 224 | coord2 = [2,4] 225 | grid.connect(coord1, coord2, expected_color) 226 | for y in range(grid.grid.shape[0]): 227 | for x in range(grid.grid.shape[1]): 228 | if (x == coord1[0]*self.unit_size or x == coord1[0]*self.unit_size+grid.unit_size-grid.unit_gap-1) and (y < coord2[1]*self.unit_size and y >= coord1[1]*self.unit_size+grid.unit_size-grid.unit_gap): 229 | self.assertTrue(np.array_equal(grid.grid[y,x,:],expected_color)) 230 | else: 231 | self.assertFalse(np.array_equal(grid.grid[y,x,:],expected_color)) 232 | 233 | def test_erase(self): 234 | grid = Grid(self.grid_size, self.unit_size) 235 | coord1 = [2,3] 236 | coord2 = [2,4] 237 | grid.draw(coord1, grid.BODY_COLOR) 238 | grid.draw(coord2, grid.BODY_COLOR) 239 | grid.connect(coord1,coord2) 240 | expected_color = grid.SPACE_COLOR 241 | grid.erase(coord1) 242 | grid.erase(coord2) 243 | for y in range(grid.grid.shape[0]): 244 | for x in range(grid.grid.shape[1]): 245 | self.assertTrue(np.array_equal(grid.grid[y,x,:],expected_color)) 246 | 247 | def test_erase_connections(self): 248 | grid = Grid(self.grid_size, self.unit_size) 249 | coord1 = [2,3] 250 | coord2 = [2,4] 251 | grid.draw(coord1, grid.BODY_COLOR) 252 | grid.connect(coord1,coord2) 253 | grid.erase_connections(coord1) 254 | for y in range(grid.grid.shape[0]): 255 | for x in range(grid.grid.shape[1]): 256 | if y >= coord1[1]*self.unit_size and y < coord1[1]*self.unit_size+grid.unit_size-grid.unit_gap and x >= coord1[0]*self.unit_size and x < coord1[0]*self.unit_size+grid.unit_size-grid.unit_gap: 257 | self.assertTrue(np.array_equal(grid.grid[y,x,:],grid.BODY_COLOR)) 258 | else: 259 | 
self.assertFalse(np.array_equal(grid.grid[y,x,:],grid.BODY_COLOR)) 260 | 261 | def test_open_space(self): 262 | grid = Grid([10,10], self.unit_size) 263 | self.assertTrue(grid.open_space == 100) 264 | for i in range(1,10): 265 | grid.draw([i,i], grid.BODY_COLOR) 266 | self.assertTrue(grid.open_space == 100-i) 267 | for i in range(1,10): 268 | grid.erase([i,i]) 269 | self.assertTrue(grid.open_space == 91+i) 270 | snake_len = 3 271 | snake = Snake((5,5), snake_len) 272 | grid.draw_snake(snake) 273 | self.assertTrue(grid.open_space == 100-snake_len) 274 | 275 | def test_open_space_draw(self): 276 | grid = Grid([10,10], self.unit_size) 277 | for i in range(1,10): 278 | grid.draw([i,i], grid.BODY_COLOR) 279 | self.assertTrue(grid.open_space == 100-i) 280 | 281 | def test_open_space_erase(self): 282 | grid = Grid([10,10], self.unit_size) 283 | for i in range(1,10): 284 | grid.erase([i,i]) 285 | self.assertTrue(grid.open_space == 100+i) 286 | 287 | def test_open_space_draw_snake(self): 288 | grid = Grid([10,10], self.unit_size) 289 | snake_len = 3 290 | snake = Snake((5,5), snake_len) 291 | grid.draw_snake(snake) 292 | self.assertTrue(grid.open_space == 100-snake_len) 293 | 294 | def test_open_space_erase_snake_body(self): 295 | grid = Grid([10,10], self.unit_size) 296 | snake_len = 3 297 | snake = Snake((5,5), snake_len) 298 | grid.erase_snake_body(snake) 299 | self.assertTrue(grid.open_space == 100+snake_len-1) 300 | 301 | 302 | if __name__ == "__main__": 303 | unittest.main() 304 | -------------------------------------------------------------------------------- /gym_snake/envs/snake/snake.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import numpy as np 3 | 4 | class Snake(): 5 | 6 | """ 7 | The Snake class holds all pertinent information regarding the Snake's movement and body. 8 | The position of the snake is tracked using a queue that stores the positions of the body.
9 | 10 | Note: 11 | A potentially more space-efficient implementation could track directional changes rather 12 | than tracking each location of the snake's body. 13 | """ 14 | 15 | UP = 0 16 | RIGHT = 1 17 | DOWN = 2 18 | LEFT = 3 19 | 20 | def __init__(self, head_coord_start, length=3): 21 | """ 22 | head_coord_start - tuple, list, or ndarray denoting the starting coordinates for the snake's head 23 | length - starting number of units in snake's body 24 | """ 25 | 26 | self.direction = self.DOWN 27 | self.head = np.asarray(head_coord_start).astype(int) 28 | self.head_color = np.array([255,0,0], np.uint8) 29 | self.body = deque() 30 | for i in range(length-1, 0, -1): 31 | self.body.append(self.head-np.asarray([0,i]).astype(int)) 32 | 33 | def step(self, coord, direction): 34 | """ 35 | Takes a step in the specified direction from the specified coordinate. 36 | 37 | coord - list, tuple, or numpy array 38 | direction - integer from 0-3 inclusive. 39 | 0: up 40 | 1: right 41 | 2: down 42 | 3: left 43 | """ 44 | 45 | assert direction < 4 and direction >= 0 46 | 47 | if direction == self.UP: 48 | return np.asarray([coord[0], coord[1]-1]).astype(int) 49 | elif direction == self.RIGHT: 50 | return np.asarray([coord[0]+1, coord[1]]).astype(int) 51 | elif direction == self.DOWN: 52 | return np.asarray([coord[0], coord[1]+1]).astype(int) 53 | else: 54 | return np.asarray([coord[0]-1, coord[1]]).astype(int) 55 | 56 | def action(self, direction): 57 | """ 58 | This method sets a new head coordinate and appends the old head 59 | into the body queue. The Controller class handles popping the 60 | last piece of the body if no food is eaten on this step. 61 | 62 | The direction can be any integer value, but will be collapsed 63 | to 0, 1, 2, or 3 corresponding to up, right, down, left respectively. 64 | 65 | direction - integer from 0-3 inclusive.
66 | 0: up 67 | 1: right 68 | 2: down 69 | 3: left 70 | """ 71 | 72 | # Ensure direction is either 0, 1, 2, or 3 73 | direction = (int(direction) % 4) 74 | 75 | if np.abs(self.direction-direction) != 2: 76 | self.direction = direction 77 | 78 | self.body.append(self.head) 79 | self.head = self.step(self.head, self.direction) 80 | 81 | return self.head 82 | -------------------------------------------------------------------------------- /gym_snake/envs/snake/snake_unittests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from gym_snake.envs.snake import Snake 4 | 5 | class SnakeTests(unittest.TestCase): 6 | 7 | head_xy = [0,0] 8 | bod_len = 3 9 | 10 | def test_head_Initialization(self): 11 | kaa = Snake(self.head_xy, self.bod_len) 12 | self.assertTrue(np.array_equal(self.head_xy, kaa.head)) 13 | 14 | def test_body_Initialization(self): 15 | kaa = Snake(self.head_xy, self.bod_len) 16 | expected_body_coords = [[0,-2], [0,-1]] 17 | for i in range(len(kaa.body)): 18 | self.assertTrue(np.array_equal(kaa.body.popleft(), expected_body_coords[i])) 19 | 20 | def test_step_UP(self): 21 | kaa = Snake(self.head_xy, self.bod_len) 22 | expected_coord = [0,-1] 23 | actual_coord = kaa.step(kaa.head, kaa.UP) 24 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 25 | 26 | def test_step_RIGHT(self): 27 | kaa = Snake(self.head_xy, self.bod_len) 28 | expected_coord = [1,0] 29 | actual_coord = kaa.step(kaa.head,kaa.RIGHT) 30 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 31 | 32 | def test_step_DOWN(self): 33 | kaa = Snake(self.head_xy, self.bod_len) 34 | expected_coord = [0,1] 35 | actual_coord = kaa.step(kaa.head,kaa.DOWN) 36 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 37 | 38 | def test_step_LEFT(self): 39 | kaa = Snake(self.head_xy, self.bod_len) 40 | expected_coord = [-1,0] 41 | actual_coord = kaa.step(kaa.head,kaa.LEFT) 42 | 
self.assertTrue(np.array_equal(expected_coord,actual_coord)) 43 | 44 | def test_action_UP(self): 45 | kaa = Snake(self.head_xy, self.bod_len) 46 | kaa.direction = kaa.UP 47 | expected_coord = [0,-1] 48 | actual_coord = kaa.action(kaa.UP) 49 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 50 | 51 | def test_action_RIGHT(self): 52 | kaa = Snake(self.head_xy, self.bod_len) 53 | expected_coord = [1,0] 54 | actual_coord = kaa.action(kaa.RIGHT) 55 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 56 | 57 | def test_action_DOWN(self): 58 | kaa = Snake(self.head_xy, self.bod_len) 59 | kaa.direction = kaa.DOWN 60 | expected_coord = [0,1] 61 | actual_coord = kaa.action(kaa.DOWN) 62 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 63 | 64 | def test_action_LEFT(self): 65 | kaa = Snake(self.head_xy, self.bod_len) 66 | expected_coord = [-1,0] 67 | actual_coord = kaa.action(kaa.LEFT) 68 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 69 | 70 | def test_action_UP_outofrange(self): 71 | kaa = Snake(self.head_xy, self.bod_len) 72 | kaa.direction = kaa.UP 73 | expected_coord = [0,-1] 74 | actual_coord = kaa.action(kaa.UP+4) 75 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 76 | 77 | def test_action_RIGHT_outofrange(self): 78 | kaa = Snake(self.head_xy, self.bod_len) 79 | expected_coord = [1,0] 80 | actual_coord = kaa.action(kaa.RIGHT+4) 81 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 82 | 83 | def test_action_DOWN_outofrange(self): 84 | kaa = Snake(self.head_xy, self.bod_len) 85 | kaa.direction = kaa.DOWN 86 | expected_coord = [0,1] 87 | actual_coord = kaa.action(kaa.DOWN+4) 88 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 89 | 90 | def test_action_LEFT_outofrange(self): 91 | kaa = Snake(self.head_xy, self.bod_len) 92 | expected_coord = [-1,0] 93 | actual_coord = kaa.action(kaa.LEFT+4) 94 | self.assertTrue(np.array_equal(expected_coord,actual_coord)) 95 | 96 | def 
test_action_UP_backwards(self): 97 | kaa = Snake(self.head_xy, self.bod_len) 98 | kaa.direction = kaa.UP 99 | head = kaa.action(kaa.DOWN) 100 | self.assertTrue(np.array_equal(head, [0,-1])) 101 | 102 | def test_action_RIGHT_backwards(self): 103 | kaa = Snake(self.head_xy, self.bod_len) 104 | kaa.direction = kaa.RIGHT 105 | head = kaa.action(kaa.LEFT) 106 | self.assertTrue(np.array_equal(head, [1,0])) 107 | 108 | def test_action_DOWN_backwards(self): 109 | kaa = Snake(self.head_xy, self.bod_len) 110 | kaa.direction = kaa.DOWN 111 | head = kaa.action(kaa.UP) 112 | self.assertTrue(np.array_equal(head, [0,1])) 113 | 114 | def test_action_LEFT_backwards(self): 115 | kaa = Snake(self.head_xy, self.bod_len) 116 | kaa.direction = kaa.LEFT 117 | head = kaa.action(kaa.RIGHT) 118 | self.assertTrue(np.array_equal(head, [-1,0])) 119 | 120 | 121 | 122 | if __name__ == "__main__": 123 | unittest.main() 124 | -------------------------------------------------------------------------------- /gym_snake/envs/snake_env.py: -------------------------------------------------------------------------------- 1 | import os, subprocess, time, signal 2 | import numpy as np 3 | import gym 4 | from gym import error, spaces, utils 5 | from gym.utils import seeding 6 | from gym_snake.envs.snake import Controller, Discrete 7 | 8 | try: 9 | import matplotlib.pyplot as plt 10 | import matplotlib 11 | except ImportError as e: 12 | raise error.DependencyNotInstalled("{}. 
(HINT: see matplotlib documentation for installation https://matplotlib.org/faq/installing_faq.html#installation)".format(e)) 13 | 14 | class SnakeEnv(gym.Env): 15 | metadata = {'render.modes': ['human']} 16 | 17 | def __init__(self, grid_size=[15,15], unit_size=10, unit_gap=1, snake_size=3, n_snakes=1, n_foods=1, random_init=True): 18 | self.grid_size = grid_size 19 | self.unit_size = unit_size 20 | self.unit_gap = unit_gap 21 | self.snake_size = snake_size 22 | self.n_snakes = n_snakes 23 | self.n_foods = n_foods 24 | self.viewer = None 25 | self.random_init = random_init 26 | 27 | self.action_space = spaces.Discrete(4) 28 | 29 | controller = Controller( 30 | self.grid_size, self.unit_size, self.unit_gap, 31 | self.snake_size, self.n_snakes, self.n_foods, 32 | random_init=self.random_init) 33 | grid = controller.grid 34 | self.observation_space = spaces.Box( 35 | low=np.min(grid.COLORS), 36 | high=np.max(grid.COLORS), 37 | shape=grid.grid.shape, dtype=grid.grid.dtype) 38 | 39 | def step(self, action): 40 | self.last_obs, rewards, done, info = self.controller.step(action) 41 | return self.last_obs, rewards, done, info 42 | 43 | def reset(self): 44 | self.controller = Controller(self.grid_size, self.unit_size, self.unit_gap, self.snake_size, self.n_snakes, self.n_foods, random_init=self.random_init) 45 | self.last_obs = self.controller.grid.grid.copy() 46 | return self.last_obs 47 | 48 | def render(self, mode='human', close=False, frame_speed=.1): 49 | if self.viewer is None: 50 | self.fig = plt.figure() 51 | self.viewer = self.fig.add_subplot(111) 52 | plt.ion() 53 | self.fig.show() 54 | self.viewer.clear() 55 | self.viewer.imshow(self.last_obs) 56 | plt.pause(frame_speed) 57 | self.fig.canvas.draw() 58 | 59 | def seed(self, x): 60 | pass 61 | -------------------------------------------------------------------------------- /gym_snake/envs/snake_extrahard_env.py: -------------------------------------------------------------------------------- 1 | import os, subprocess, time, signal 2 | import gym 3 | from gym
import error, spaces, utils 4 | from gym.utils import seeding 5 | from gym_snake.envs.snake import Controller, Discrete 6 | from gym_snake.envs.snake_env import SnakeEnv 7 | 8 | try: 9 | import matplotlib.pyplot as plt 10 | except ImportError as e: 11 | raise error.DependencyNotInstalled("{}. (HINT: see matplotlib documentation for installation https://matplotlib.org/faq/installing_faq.html#installation)".format(e)) 12 | 13 | class SnakeExtraHardEnv(SnakeEnv): 14 | metadata = {'render.modes': ['human']} 15 | 16 | def __init__(self, grid_size=[25,25], unit_size=10, unit_gap=1, snake_size=5, n_snakes=3, n_foods=2, random_init=True): 17 | super().__init__( 18 | grid_size=grid_size, 19 | unit_size=unit_size, 20 | unit_gap=unit_gap, 21 | snake_size=snake_size, 22 | n_snakes=n_snakes, 23 | n_foods=n_foods, 24 | random_init=random_init) 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | gym 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='gym_snake', 4 | version='0.0.1', 5 | author="Satchel Grant", 6 | install_requires=['gym', 'numpy', 'matplotlib'], 7 | python_requires='>=3', 8 | ) 9 | -------------------------------------------------------------------------------- /tests/imgs/biggrid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grantsrb/Gym-Snake/8d1dcb1f5558f6cb6c7fdb700e03c0ec7a798ab5/tests/imgs/biggrid.png -------------------------------------------------------------------------------- /tests/imgs/default.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/grantsrb/Gym-Snake/8d1dcb1f5558f6cb6c7fdb700e03c0ec7a798ab5/tests/imgs/default.png -------------------------------------------------------------------------------- /tests/imgs/default_plural.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grantsrb/Gym-Snake/8d1dcb1f5558f6cb6c7fdb700e03c0ec7a798ab5/tests/imgs/default_plural.png -------------------------------------------------------------------------------- /tests/imgs/nogap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grantsrb/Gym-Snake/8d1dcb1f5558f6cb6c7fdb700e03c0ec7a798ab5/tests/imgs/nogap.png -------------------------------------------------------------------------------- /tests/imgs/widegap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grantsrb/Gym-Snake/8d1dcb1f5558f6cb6c7fdb700e03c0ec7a798ab5/tests/imgs/widegap.png -------------------------------------------------------------------------------- /tests/manual_test.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import gym_snake 3 | import time 4 | 5 | if __name__=="__main__": 6 | print("PRESS q to quit") 7 | print("wasd to move") 8 | env = gym.make('snake-v0') 9 | 10 | done = False 11 | rew = 0 12 | obs = env.reset() 13 | key = "" 14 | action = 0 15 | while key != "q": 16 | env.render() 17 | key = input("action: ") 18 | if key == "w": action = 0 19 | elif key == "d": action = 1 20 | elif key == "s": action = 2 21 | elif key == "a": action = 3 22 | else: pass 23 | obs, rew, done, info = env.step(action) 24 | print("rew:", rew) 25 | print("done:", done) 26 | print("info") 27 | for k in info.keys(): 28 | print(" ", k, ":", info[k]) 29 | if done: 30 | obs = env.reset() 31 |
--------------------------------------------------------------------------------
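The movement rules in `Snake.step` and `Snake.action` can be illustrated with a short, dependency-free sketch. It assumes (x, y) coordinates with y growing downward, as `test_step_UP` and `test_step_DOWN` imply, and it uses plain tuples instead of NumPy arrays. `MiniSnake`, `step`, `OFFSETS`, and `direction_runs` are hypothetical names for illustration only, not part of gym-snake; `direction_runs` is one possible reading of the docstring's note about tracking directional changes instead of every body cell.

```python
from collections import deque

# Direction constants mirror Snake.UP/RIGHT/DOWN/LEFT.
UP, RIGHT, DOWN, LEFT = 0, 1, 2, 3

# (dx, dy) per direction; y grows downward, matching Snake.step.
OFFSETS = {UP: (0, -1), RIGHT: (1, 0), DOWN: (0, 1), LEFT: (-1, 0)}


def step(coord, direction):
    """Return the cell one unit away from coord in the given direction."""
    dx, dy = OFFSETS[direction]
    return (coord[0] + dx, coord[1] + dy)


class MiniSnake:
    """Hypothetical tuple-based stand-in for gym_snake's Snake class."""

    def __init__(self, head, length=3):
        self.direction = DOWN
        self.head = (head[0], head[1])
        # Body is stored tail first, so head (0, 0) with length 3
        # gives (0, -2), (0, -1), the order test_body_Initialization expects.
        self.body = deque(
            (head[0], head[1] - i) for i in range(length - 1, 0, -1)
        )

    def action(self, direction):
        # Collapse any integer to 0-3, e.g. LEFT + 4 -> LEFT.
        direction = int(direction) % 4
        # UP/DOWN and RIGHT/LEFT differ by exactly 2; such a reversal is
        # ignored and the snake keeps its current heading.
        if abs(self.direction - direction) != 2:
            self.direction = direction
        # The old head joins the body; trimming the tail is the caller's job.
        self.body.append(self.head)
        self.head = step(self.head, self.direction)
        return self.head


def direction_runs(cells):
    """Compress adjacent cells (tail to head) into [direction, count] runs."""
    runs = []
    for (x0, y0), (x1, y1) in zip(cells, cells[1:]):
        d = next(k for k, v in OFFSETS.items() if v == (x1 - x0, y1 - y0))
        if runs and runs[-1][0] == d:
            runs[-1][1] += 1
        else:
            runs.append([d, 1])
    return runs


kaa = MiniSnake((0, 0))
kaa.action(RIGHT)
print(direction_runs(list(kaa.body) + [kaa.head]))  # [[2, 2], [1, 1]]
```

As in the original class, `action` only appends the old head to the body; popping the tail when no food is eaten is left to the caller (the `Controller` class in gym-snake). The run-length form stores one entry per turn rather than one per cell, at the cost of walking the runs to recover individual positions.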