├── pokemonai ├── __init__.py ├── smogon │ ├── __init__.py │ ├── battle.py │ └── controller.py └── __main__.py ├── requirements.txt ├── README.md ├── setup.py ├── LICENSE └── .gitignore /pokemonai/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pokemonai/smogon/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | websockets>=3.2 2 | aiodns>=1.1.1 3 | aiohttp>=1.3.3 4 | cchardet>=1.1.3 5 | -------------------------------------------------------------------------------- /pokemonai/__main__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from pokemonai.smogon import controller 3 | 4 | asyncio.get_event_loop().run_until_complete(controller.run()) 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](http://play.pokemonshowdown.com/sprites/xyani/charizard.gif) Pokémon AI ![](http://play.pokemonshowdown.com/sprites/xyani/pikachu.gif) 2 | ======================== 3 | --- 4 | ### Goal 5 | Reinforcement learning approach to playing competitive Pokémon. Battling is done on Smogon's [Pokémon Showdown! battle simulator](http://play.pokemonshowdown.com/). 6 | 7 | ### Ideas to Explore 8 | * Batch RL (experience replay/fitted Q iteration) 9 | * Online TD Learning 10 | * PG methods (DDPG) 11 | * Monte Carlo Tree Search (with simplfied assumptions) 12 | * Imitation learning with Smogon leaderboard 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import sys 3 | 4 | with open('README.md') as f: 5 | readme = f.read() 6 | 7 | with open('LICENSE') as f: 8 | license = f.read() 9 | 10 | py_version = sys.version_info[:2] 11 | 12 | if py_version < (3, 5): 13 | raise Exception("websockets requires Python >= 3.5.") 14 | 15 | setup( 16 | name='pokemonai', 17 | version='0.0.1', 18 | description='Reinforcement Learning agent to play competitive Pokemon', 19 | long_description=readme, 20 | author='Rishi Shah', 21 | author_email='rishihahs@gmail.com', 22 | url='https://github.com/rishihahs/pokemonai', 23 | license=license, 24 | packages=find_packages(exclude=('bin')), 25 | package_data={ 26 | 'pokemonai': ['data/*.json'] 27 | }, 28 | entry_points={ 29 | 'console_scripts': [ 30 | 'pokemonai = pokemonai.__main__:main' 31 | ] 32 | }, 33 | ) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Rishi Shah 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | -------------------------------------------------------------------------------- /pokemonai/smogon/battle.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Sequence, Any 2 | import pkgutil 3 | import json 4 | import re 5 | import random 6 | 7 | # Load pokedex 8 | pokedex = json.loads(pkgutil.get_data('pokemonai', 'data/pokedex.json')) 9 | 10 | # Extract JSON from Smogon action request 11 | request_regex = re.compile(r"\|request\|(.+?)$") 12 | # Extract player id for a game 13 | player_regex = re.compile(r"\|player\|(.+?)\|(.+?)\|") 14 | # Extract opponent's pokemon 15 | poke_regex = lambda pid: "\\|poke\\|%s\\|(.+?)\\|" % pid 16 | # Extract active pokemon 17 | switch_regex = lambda pid: "\\|switch\\|%s.+?\\|(.+?)\\|" % pid 18 | 19 | class BattleHandler(object): 20 | 21 | def __init__(self, roomid: str, username: str): 22 | self.roomid = roomid 23 | self.username = username 24 | self.pid = None # Player id in game 25 | self.opponent_pid = None # Opponent player id in game 26 | self.battledata = BattleData() 27 | 28 | 29 | def parse(self, message: str) -> Tuple[Optional[str], bool]: 30 | done = False 31 | response = None 32 | 33 | # Extract player ids 34 | if '|player|' in message: 35 | m = player_regex.search(message) 36 | pid = m.group(1) 37 | uname = m.group(2) 38 | if uname == self.username: 39 | self.pid = pid 40 | else: 41 | self.opponent_pid = pid 42 | 43 | # Extract opponent's pokemon from team preview 44 | if '|poke|' in message: 45 | assert(self.opponent_pid) # We should already know their pid 46 | matches = re.finditer(poke_regex(self.opponent_pid), message) 47 | self.battledata.record_opponent_team([m.group(1) for m in matches]) 48 | print(self.battledata.opponent_team) 49 | 50 | # Extract active pokemon 51 | if '|switch|' in message: 52 | m = re.search(switch_regex(self.opponent_pid), message) 53 | self.battledata.record_opponent_active(m.group(1)) 54 | 55 | # Smogon is requesting an action or switch 56 | if '|request|' in message: 57 | # Data provided with request (i.e. team info) 58 | request_data = json.loads(request_regex.search(message).group(1)) 59 | 60 | if request_data.get('teamPreview'): 61 | choice = self.team_preview(request_data["side"]["pokemon"]) 62 | return ('/team %s' % ''.join(map(str, choice)), done) 63 | elif 'ZMove' in message: 64 | return ('/choose move 1 zmove', done) 65 | else: 66 | return ('/choose move 1', done) 67 | 68 | if '|win|' in message or '|lose|' in message: 69 | done = True 70 | 71 | return (response, done) 72 | 73 | 74 | """ 75 | Given a list of pokemon (team) 76 | returns the order in which team should play 77 | """ 78 | def team_preview(self, pokemon: Sequence[Any]) -> Sequence[int]: 79 | # For now random 80 | choices = list(range(1, len(pokemon))) 81 | random.shuffle(choices) 82 | return choices 83 | 84 | 85 | class BattleData(object): 86 | 87 | def __init__(self): 88 | self.player_team = [] # Our team 89 | self.player_active = None # Our active pokemon in battle 90 | self.opponent_team = [] 91 | self.opponent_active = None 92 | 93 | """ 94 | Given a Smogon pokemon description, 95 | sets that opponent's active pokemon 96 | """ 97 | def record_opponent_active(self, pokemon: str) -> None: 98 | attrs = pokemon.split(',') 99 | pokeid = attrs[0].lower().replace('-', '') 100 | self.opponent_active = any(p for p in self.opponent_team if p['id'] == pokeid) 101 | 102 | """ 103 | Given a list of Smogon pokemon descriptions (i.e. from team preview) 104 | cross references with pokedex to record opponent's team 105 | """ 106 | def record_opponent_team(self, pokemon: Sequence[str]) -> None: 107 | normalized = [] # Normalized names to index into pokedex 108 | genders = [] 109 | for p in pokemon: 110 | attrs = p.split(',') 111 | normalized.append(attrs[0].lower().replace('-', '')) 112 | genders.append(attrs[1] if 1 < len(attrs) else None) # Gender or None if no gender 113 | 114 | # Set opponent's team 115 | self.opponent_team = [] 116 | for idx, name in enumerate(normalized): 117 | poke = pokedex[name] 118 | self.opponent_team.append({ 119 | 'baseStats': poke['baseStats'], 120 | 'weightkg': poke['weightkg'], 121 | 'types': poke['types'], 122 | 'species': poke['species'], 123 | 'gender': genders[idx], 124 | 'id': name 125 | }) 126 | 127 | 128 | -------------------------------------------------------------------------------- /pokemonai/smogon/controller.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import aiohttp 3 | import websockets 4 | import json 5 | import re 6 | 7 | from pokemonai.smogon.battle import BattleHandler 8 | 9 | SMOGON_WEBSOCKET_URI = 'ws://sim.smogon.com:8000/showdown/websocket' 10 | SMOGON_ACTION_URL = 'http://play.pokemonshowdown.com/action.php' 11 | SMOGON_USERNAME = 'gdelta' 12 | SMOGON_PASSWORD = 'gebiet' 13 | 14 | SMOGON_TEAM = '|togekiss|choicescarf|1|airslash,ancientpower,trick,roost|Timid|,,,252,4,252||,0,,,,|||]|mimikyu|mentalherb||trickroom,destinybond,painsplit,shadowsneak||||,,,,,0||1|]|torkoal|firiumz|1|fireblast,|Modest|4,,,252,252,||,0,,,,|||' 15 | 16 | MAX_PARALLEL_GAMES = 1 17 | 18 | # Pool of open battles 19 | pool = set() 20 | 21 | """ 22 | Continuously searches for battles and runs MAX_PARALLEL_GAMES battles at a time 23 | """ 24 | async def run(): 25 | async with websockets.connect(SMOGON_WEBSOCKET_URI) as websocket: 26 | await _connect(websocket) 27 | 28 | while True: 29 | while len(pool) < MAX_PARALLEL_GAMES: 30 | # Preload connection for battle so that it receives messages 31 | conn = await _open_connection() 32 | 33 | # Start Smogon battle search 34 | await websocket.send('|/utm %s' % SMOGON_TEAM) 35 | await websocket.send('|/search gen71v1') 36 | 37 | # Wait for battle initialization 38 | msg = '' 39 | while not '|init|battle' in msg: 40 | msg = await websocket.recv() 41 | 42 | # Room id is in the first line of the message 43 | # of the form >roomid 44 | m = re.match('>(.+?)\n', msg) 45 | roomid = m.group(1) 46 | 47 | # Start battle handler with preloaded connection 48 | bh = BattleHandler(roomid, SMOGON_USERNAME) 49 | pool.add(asyncio.ensure_future(_battle(conn, bh))) 50 | 51 | # Wait for battle handlers to complete 52 | done, _ = await asyncio.wait(pool, return_when=asyncio.FIRST_COMPLETED) 53 | 54 | # Remove done from pool 55 | for d in done: 56 | pool.remove(d) 57 | 58 | 59 | """ 60 | Handles initial handshaking/auth with Smogon 61 | """ 62 | async def _connect(websocket): 63 | # Initial communication 64 | await websocket.send('|/cmd rooms') 65 | await websocket.send('|/autojoin') 66 | 67 | # Wait for 'challstr' response needed for authentication 68 | challstr = '' 69 | while not challstr.startswith('|challstr|'): 70 | challstr = await websocket.recv() 71 | challstr = challstr.replace('|challstr|', '') 72 | 73 | # Authenticate with Smogon 74 | assertion_token = await _authenticate(SMOGON_USERNAME, SMOGON_PASSWORD, challstr) 75 | await websocket.send('|/trn %s,0,%s' % (SMOGON_USERNAME, assertion_token)) 76 | 77 | # Wait for verification that we are logged in 78 | resp = '' 79 | while not resp.startswith('|updateuser|%s' % SMOGON_USERNAME): 80 | resp = await websocket.recv() 81 | 82 | 83 | """ 84 | Authenticates with Smogon server 85 | by taking a username and challstr 86 | to get an 'assertion' token 87 | 88 | Returns: assertion token 89 | """ 90 | async def _authenticate(username, password, challstr): 91 | async with aiohttp.ClientSession() as session: 92 | res = await session.post(SMOGON_ACTION_URL, data={ 93 | 'act': 'login', 94 | 'name': username, 95 | 'pass': password, 96 | 'challstr': challstr 97 | }) 98 | async with res: 99 | body = await res.text() 100 | 101 | # First character is ']' followed by JSON object 102 | return json.loads(body[1:])['assertion'] 103 | 104 | 105 | """ 106 | Opens new websocket and connects to Smogon 107 | NOTE: Does not handle closing of websocket 108 | """ 109 | async def _open_connection(): 110 | websocket = await websockets.connect(SMOGON_WEBSOCKET_URI) 111 | await _connect(websocket) 112 | return websocket 113 | 114 | 115 | """ 116 | Communicates with Smogon about battle 117 | NOTE: Closes given websocket on completion 118 | """ 119 | async def _battle(websocket, battlehandler): 120 | try: 121 | # Start the battle timer 122 | await websocket.send('%s|/timer on' % battlehandler.roomid) 123 | 124 | while True: 125 | msg = await websocket.recv() 126 | if msg.startswith('>' + battlehandler.roomid): 127 | #print(msg) 128 | response, done = battlehandler.parse(msg) 129 | 130 | if response: 131 | await websocket.send('%s|%s' % (battlehandler.roomid, response)) 132 | 133 | if done: 134 | break 135 | finally: 136 | await websocket.close() 137 | 138 | --------------------------------------------------------------------------------