├── setup.cfg ├── .travis.yml ├── rebuild.sh ├── rebuild.ps1 ├── travis_test.py ├── test.py ├── LICENSE ├── setup.py ├── .gitignore ├── README.md └── retrowrapper.py /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | - "3.6" 5 | install: 6 | - pip install gym-retro 7 | script: 8 | - python travis_test.py 9 | -------------------------------------------------------------------------------- /rebuild.sh: -------------------------------------------------------------------------------- 1 | rm -rf build/ 2 | rm -rf dist/ 3 | rm -rf retrowrapper.egg-info/ 4 | pipreqs --force . 5 | python3 setup.py bdist_wheel 6 | twine upload dist/* 7 | echo "Done. Please remember to make a release on github via:" 8 | echo "git tag -a v -m \"\"" 9 | echo "git push origin v" 10 | -------------------------------------------------------------------------------- /rebuild.ps1: -------------------------------------------------------------------------------- 1 | Remove-Item .\build -Recurse -Force 2 | Remove-Item .\dist -Recurse -Force 3 | Remove-Item .\retrowrapper.egg-info -Recurse -Force 4 | 5 | pipreqs.exe --force . 6 | 7 | python setup.py bdist_wheel 8 | 9 | twine.exe upload dist\* 10 | 11 | Write-Host "Done. 
Please remember to make a release on github via:" 12 | Write-Host "git tag -a v -m " 13 | Write-Host "git push origin v" -------------------------------------------------------------------------------- /travis_test.py: -------------------------------------------------------------------------------- 1 | import retrowrapper 2 | 3 | if __name__ == "__main__": 4 | game = "Airstriker-Genesis" 5 | env1 = retrowrapper.RetroWrapper(game) 6 | env2 = retrowrapper.RetroWrapper(game) 7 | _obs = env1.reset() 8 | _obs = env2.reset() 9 | 10 | done = False 11 | while not done: 12 | action = env1.action_space.sample() 13 | _obs, _rew, done, _info = env1.step(action) 14 | 15 | action = env2.action_space.sample() 16 | _obs, _rew, done, _info = env2.step(action) 17 | 18 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import retrowrapper 2 | 3 | if __name__ == "__main__": 4 | game = "SonicTheHedgehog-Genesis" 5 | state = "GreenHillZone.Act1" 6 | env1 = retrowrapper.RetroWrapper(game, state=state) 7 | env2 = retrowrapper.RetroWrapper(game, state=state) 8 | _obs = env1.reset() 9 | _obs = env2.reset() 10 | 11 | done = False 12 | while not done: 13 | action = env1.action_space.sample() 14 | _obs, _rew, done, _info = env1.step(action) 15 | env1.render() 16 | 17 | action = env2.action_space.sample() 18 | _obs, _rew, done, _info = env2.step(action) 19 | env2.render() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Max Strange 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, 
publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | __doc__ = """Wrapper for OpenAI Retro Gym to allow multiple processes.""" 2 | 3 | from setuptools import setup 4 | 5 | setup( 6 | name="retrowrapper", 7 | version="0.3.0", 8 | author="Max Strange", 9 | author_email="maxfieldstrange@gmail.com", 10 | description="Wrapper for OpenAI Retro Gym environments to allow multiple processes.", 11 | install_requires=["gym-retro"], 12 | license="MIT", 13 | keywords="reinforcement-learning retro ai rl dl deep-learning gym openai", 14 | url="https://github.com/MaxStrange/retrowrapper", 15 | py_modules=["retrowrapper"], 16 | python_requires="~=3.4", 17 | long_description=__doc__, 18 | classifiers=[ 19 | "Development Status :: 3 - Alpha", 20 | "Intended Audience :: Developers", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | "Programming Language :: Python", 24 | "Programming Language :: Python :: 3", 25 | "Programming Language :: Python :: 3.4", 26 | "Programming Language :: Python :: 3.5", 27 | 
"Programming Language :: Python :: 3.6", 28 | "Programming Language :: Python :: 3.7", 29 | ] 30 | ) 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # retrowrapper 2 | 3 | [![Build Status](https://travis-ci.org/MaxStrange/retrowrapper.svg?branch=master)](https://travis-ci.org/MaxStrange/retrowrapper) 4 | 5 | Wrapper for OpenAI Retro envs for parallel execution 6 | 7 | OpenAI's [Retro](https://github.com/openai/retro) exposes an OpenAI [gym](https://gym.openai.com/) interface for Deep Reinforcement Learning, but 8 | unfortunately, their back-end only allows one emulator instance per process. To get around this, I wrote this class. 9 | 10 | ## To Use 11 | 12 | To use it, just instantiate it like you would a normal retro environment, and then treat it exactly the same, but now you can have multiples in a single python process. Magic! 13 | 14 | ```python 15 | import retrowrapper 16 | 17 | if __name__ == "__main__": 18 | game = "SonicTheHedgehog-Genesis" 19 | state = "GreenHillZone.Act1" 20 | env1 = retrowrapper.RetroWrapper(game, state=state) 21 | env2 = retrowrapper.RetroWrapper(game, state=state) 22 | _obs = env1.reset() 23 | _obs = env2.reset() 24 | 25 | done = False 26 | while not done: 27 | action = env1.action_space.sample() 28 | _obs, _rew, done, _info = env1.step(action) 29 | env1.render() 30 | 31 | action = env2.action_space.sample() 32 | _obs, _rew, done, _info = env2.step(action) 33 | env2.render() 34 | ``` 35 | 36 | ## Using a custom make function 37 | 38 | Sometimes you will need a custom make function, for example the `retro_contest` 39 | repository requires you to use their make function rather than `retro.make`. 40 | 41 | In these cases you can use the `retrowrapper.set_retro_make()` to set a new 42 | make function. 
# README.md (continued) -- example usage of a custom make function:
#
#     import retrowrapper
#     from retro_contest.local import make
#
#     retrowrapper.set_retro_make(make)
#
#     env1 = retrowrapper.RetroWrapper(
#         game='SonicTheHedgehog2-Genesis',
#         state='MetropolisZone.Act1'
#     )
#     env2 = retrowrapper.RetroWrapper(
#         game='SonicTheHedgehog2-Genesis',
#         state='MetropolisZone.Act2'
#     )
#
# /retrowrapper.py:
"""
This module exposes the RetroWrapper class.

A RetroWrapper owns a child process that holds the real retro environment.
Attribute accesses and method calls made on the wrapper are sent to that
process over a pair of multiprocessing queues and the answers are sent back.
"""
import gc
import multiprocessing
import pickle

import retro

# Number of attempts made to build the temporary, in-process environment.
MAKE_RETRIES = 5

# Names that belong to the wrapper itself and must never be forwarded to the
# worker process (forwarding needs them, so a miss would recurse forever).
_WORKER_ATTRS = ("_rx", "_tx", "_proc")


def set_retro_make(new_retro_make_func):
    """
    Replace the function used to build environments (default: retro.make).

    The function is stored on the RetroWrapper class and is called as
    new_retro_make_func(game, **kwargs), both in the calling process and in
    each worker process.

    NOTE(review): the worker reads this class attribute in its own process,
    so the override presumably only reaches workers that inherit the parent's
    module state (e.g. the "fork" start method) -- confirm for "spawn".
    """
    RetroWrapper.retro_make_func = new_retro_make_func


def _picklable_exception(exc):
    """
    Return exc if it survives a pickle round trip, else a RuntimeError
    carrying its type name and message, so it can always cross the queue.
    """
    try:
        pickle.loads(pickle.dumps(exc))
    except Exception:
        return RuntimeError("{}: {}".format(type(exc).__name__, exc))
    return exc


def _retrocom(rx, tx, game, kwargs):
    """
    This function is the target for RetroWrapper's internal
    process and does all the work of communicating with the
    environment.

    Requests arrive on rx as (attr, args, kwargs) tuples. Every request
    except "close" is answered on tx with an (ok, payload) pair: (True,
    result) on success, (False, exception) if handling the request raised.
    """
    env = RetroWrapper.retro_make_func(game, **kwargs)

    # Sit around on the queue, waiting for calls from RetroWrapper
    while True:
        attr, args, call_kwargs = rx.get()

        if attr == "close":
            env.close()
            break

        try:
            if attr == RetroWrapper.symbol:
                # Special case: the wrapper is asking whether an attribute is
                # callable. Here args is the attribute name. getattr (rather
                # than __getattribute__) is used so that envs which delegate
                # through __getattr__ answer the same way as a real request.
                reply = callable(getattr(env, args))
            else:
                # Otherwise, handle the request
                reply = getattr(env, attr)
                if callable(reply):
                    reply = reply(*args, **call_kwargs)
        except Exception as exc:
            # Report the failure instead of dying; a dead worker would leave
            # the caller blocked on its reply queue forever.
            tx.put((False, _picklable_exception(exc)))
        else:
            tx.put((True, reply))


class RetroWrapper():
    """
    This class is a thin wrapper around a retro environment.

    The purpose of this class is to protect us from the fact
    that each Python process can only have a single retro
    environment at a time, and we would like potentially
    several.

    This class gets around this limitation by spawning a process
    internally that sits around waiting for retro environment
    API calls, asking its own local copy of the environment, and
    then returning the answer.

    Call functions on this object exactly as if it were a retro env.
    """
    # Marker sent in place of an attribute name to ask "is this callable?".
    symbol = "THIS IS A SPECIAL MESSAGE FOR YOU"
    # Factory used to build environments; see set_retro_make().
    retro_make_func = retro.make

    def __init__(self, game, **kwargs):
        """
        Build a temporary environment in this process to copy its static
        attributes, close it, then start the worker process that owns the
        real environment.

        Raises RuntimeError if the temporary environment cannot be created.
        """
        tempenv = None
        last_error = None
        # Always try at least once, even if MAKE_RETRIES was set to <= 0.
        for _ in range(max(1, MAKE_RETRIES)):
            try:
                tempenv = RetroWrapper.retro_make_func(game, **kwargs)
            except RuntimeError as err:
                # Sometimes we need to gc.collect because previous tempenvs
                # haven't been cleaned up.
                last_error = err
                gc.collect()
            else:
                break

        if tempenv is None:
            raise RuntimeError('Unable to create tempenv') from last_error

        try:
            tempenv.reset()

            # Wrappers don't have gamename or initial_state
            if hasattr(tempenv, 'unwrapped'):
                unwrapped = tempenv.unwrapped
                self.gamename = unwrapped.gamename
                self.initial_state = unwrapped.initial_state

            self.action_space = tempenv.action_space
            self.metadata = tempenv.metadata
            self.observation_space = tempenv.observation_space
            self.reward_range = tempenv.reward_range
        finally:
            # Release the in-process environment even when the above fails,
            # otherwise it would block every later environment in this process.
            tempenv.close()

        self._rx = multiprocessing.Queue()
        self._tx = multiprocessing.Queue()
        # The worker receives on our _tx and answers on our _rx.
        self._proc = multiprocessing.Process(
            target=_retrocom,
            args=(self._tx, self._rx, game, kwargs),
            daemon=True,
        )
        self._proc.start()

    def __del__(self):
        """
        Make sure to clean up.
        """
        self.close()

    def __getattr__(self, attr):
        """
        Forward an attribute the wrapper does not have to the worker process.

        Python only calls this after normal lookup has failed, so anything
        stored on the wrapper (action_space, metadata, ...) never gets here.
        If the remote attribute is callable, a function is returned that
        performs the call in the worker and returns its result; otherwise the
        remote value itself is returned. Exceptions raised in the worker
        (including AttributeError for unknown names) are re-raised here.

        BTW: This doesn't work for magic methods. To get those working is a
        little more involved. TODO
        """
        if attr in _WORKER_ATTRS:
            # The wrapper is not (fully) initialised; forwarding would need
            # these very attributes and recurse without end.
            raise AttributeError(attr)

        # E.g.: Client calls env.step(action)
        if self._ask_if_attr_is_callable(attr):
            # Return a stand-in that pretends to be the function the user
            # was trying to call.
            def wrapper(*args, **kwargs):
                return self._request(attr, args, kwargs)
            return wrapper

        # Plain attribute: fetch the value for the user and return it.
        return self._request(attr, (), {})

    def _request(self, attr, args, kwargs):
        """
        Send one request to the worker, block for the reply, and return the
        result or re-raise the exception the worker reported.
        """
        self._tx.put((attr, args, kwargs))
        ok, payload = self._rx.get()
        if not ok:
            raise payload
        return payload

    def _ask_if_attr_is_callable(self, attr):
        """
        Returns whether or not the attribute is a callable.
        """
        return self._request(RetroWrapper.symbol, attr, {})

    def close(self):
        """
        Shutdown the environment.

        Safe to call more than once and on a wrapper whose construction
        failed: the worker is only signalled and joined while it is alive.
        """
        proc = self.__dict__.get("_proc")
        tx = self.__dict__.get("_tx")
        if proc is None or tx is None:
            return
        if proc.is_alive():
            tx.put(("close", (), {}))
            proc.join()