├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── attach.sh ├── roms └── .keep ├── setup ├── agent-pytorch.docker ├── agent-tf.docker ├── agent.docker ├── bash_profile ├── import_all.sh ├── remote-env-roms.docker ├── remote-env.docker └── setup.sh └── support ├── VERSION ├── gym_remote ├── __init__.py ├── bridge.py ├── client.py ├── exceptions.py └── server.py ├── retro_contest ├── __init__.py ├── __main__.py ├── agent.py ├── docker.py ├── local.py ├── remote.py └── rest.py ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── test_bridge.py └── test_env.py /.dockerignore: -------------------------------------------------------------------------------- 1 | gym-retro/.git 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __cache__ 2 | .cache 3 | *.pyc 4 | *.egg-info 5 | *.eggs 6 | VERSION.txt 7 | roms/*.smd 8 | roms/*.SGD 9 | roms/*.bin 10 | roms/*.68K 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Status:** Archive (code is provided as-is, no updates expected) 2 | 3 | # OpenAI Retro Contest 4 | -------------------------------------------------------------------------------- /attach.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export NVIDIA_DOCKER=${NVIDIA_DOCKER:-"nvidia-docker"} 4 | export COMPO_RESULTS=$(realpath $(dirname -- $0))/results 5 | 6 | mkdir -p $COMPO_RESULTS 7 | 8 | alias \ 9 | docker-retro-contest-agent="\$NVIDIA_DOCKER run --rm -v compo-tmp-vol:/root/compo/tmp agent retro-contest-agent" \ 10 | docker-retro-contest-remote="docker run --rm -v compo-tmp-vol:/root/compo/tmp -v \$COMPO_RESULTS:/root/compo/results remote-env retro-contest-remote" 11 | -------------------------------------------------------------------------------- /roms/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/retro-contest/0512db0baa07b6878b7d0f34a23a7d34838b39e6/roms/.keep -------------------------------------------------------------------------------- /setup/agent-pytorch.docker: -------------------------------------------------------------------------------- 1 | FROM openai/retro-agent:bare-cuda8 2 | 3 | ARG PYTORCH 4 | RUN . 
~/venv/bin/activate && \ 5 | pip install http://download.pytorch.org/whl/cu80/torch-${PYTORCH}-cp35-cp35m-linux_x86_64.whl && \ 6 | pip install torchvision && \ 7 | rm -r ~/.cache 8 | -------------------------------------------------------------------------------- /setup/agent-tf.docker: -------------------------------------------------------------------------------- 1 | ARG CUDA 2 | FROM openai/retro-agent:bare-cuda${CUDA} 3 | 4 | ARG TF 5 | RUN . ~/venv/bin/activate && \ 6 | pip install tensorflow-gpu==$TF && \ 7 | rm -r ~/.cache 8 | -------------------------------------------------------------------------------- /setup/agent.docker: -------------------------------------------------------------------------------- 1 | ARG CUDA 2 | ARG BASE=nvidia/cuda 3 | ARG TAG=${CUDA}-runtime-ubuntu 4 | FROM ${BASE}:${TAG}16.04 5 | 6 | ARG CUDA 7 | ARG CUDNN 8 | 9 | # Set up dependency layers 10 | ENV DEBIAN_FRONTEND=noninteractive 11 | SHELL ["/bin/bash", "-c"] 12 | RUN ([ -z "$CUDA" ] || echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list) && \ 13 | apt-get update && \ 14 | apt-get install -y --no-install-recommends python3-venv build-essential git && \ 15 | ([ -z "$CUDA" ] || apt-get install -y --no-install-recommends cuda-driver-dev-${CUDA} libcudnn${CUDNN}=${CUDNN}.0.*+cuda${CUDA}) && \ 16 | apt-get clean && \ 17 | python3 -m venv ~/venv && \ 18 | . ~/venv/bin/activate && \ 19 | pip install wheel && \ 20 | rm -r ~/.cache 21 | 22 | # Set up competition-specific layers 23 | COPY support /tmp/support 24 | RUN . 
~/venv/bin/activate && \ 25 | pip install gym>=0.9.6 && \ 26 | pip install /tmp/support && \ 27 | rm -r ~/.cache 28 | 29 | RUN echo agent > /root/hostname && \ 30 | mkdir -p /root/compo/tmp /root/compo/out 31 | COPY setup/bash_profile /root/.bash_profile 32 | VOLUME /root/compo/tmp 33 | VOLUME /root/compo/out 34 | WORKDIR /root/compo 35 | ENTRYPOINT ["bash", "-lc", "exec $0 $@"] 36 | CMD retro-contest-agent 37 | -------------------------------------------------------------------------------- /setup/bash_profile: -------------------------------------------------------------------------------- 1 | if [ ! -e /usr/lib/x86_64-linux-gnu/libcuda.so.1 ]; then 2 | ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so.1 3 | fi 4 | mkdir -p $HOME/compo/tmp/sock 5 | HOSTNAME=$(cat $HOME/hostname) 6 | source $HOME/venv/bin/activate 7 | PS1="\u@$HOSTNAME:\W# " 8 | -------------------------------------------------------------------------------- /setup/import_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=$(python -c 'import retro; print(retro.data.path(retro.data.DATA_PATH))') 4 | echo "Importing games from $DATA_PATH..." 5 | CONTAINER_ID=$(docker run -v "$DATA_PATH":/root/roms:ro -d remote-env 'python /tmp/gym-retro/scripts/import.py /root/roms') 6 | docker attach $CONTAINER_ID 7 | docker commit $CONTAINER_ID remote-env:full 8 | docker rm $CONTAINER_ID 9 | -------------------------------------------------------------------------------- /setup/remote-env-roms.docker: -------------------------------------------------------------------------------- 1 | FROM openai/retro-env 2 | 3 | # Set up ROMs 4 | COPY roms /root/roms 5 | RUN . 
~/venv/bin/activate && \ 6 | python -m retro.import /root/roms && \ 7 | rm -r /root/roms 8 | -------------------------------------------------------------------------------- /setup/remote-env.docker: -------------------------------------------------------------------------------- 1 | FROM ubuntu:xenial 2 | 3 | # Set up dependency layers 4 | SHELL ["/bin/bash", "-c"] 5 | RUN apt-get update && \ 6 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 7 | build-essential cmake libpython3-dev libzip-dev pkg-config python3-venv git && \ 8 | apt-get clean && \ 9 | python3 -m venv ~/venv && \ 10 | . ~/venv/bin/activate && \ 11 | pip install wheel && \ 12 | rm -r ~/.cache 13 | 14 | # Set up competition-specific layers 15 | COPY support /tmp/support 16 | RUN . ~/venv/bin/activate && \ 17 | pip install gym>=0.9.6 && \ 18 | pip install git+https://github.com/openai/retro.git@fbb97475859378c2cd7a30670d659744cc2692ea && \ 19 | pip install /tmp/support && \ 20 | rm -r ~/.cache && \ 21 | echo remote-env > /root/hostname && \ 22 | mkdir -p /root/compo/tmp /root/compo/results /root/roms 23 | COPY setup/bash_profile /root/.bash_profile 24 | 25 | VOLUME /root/compo/tmp 26 | VOLUME /root/compo/results 27 | WORKDIR /root/compo 28 | ENTRYPOINT ["bash", "-lc", "exec $0 $@"] 29 | -------------------------------------------------------------------------------- /setup/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cd `dirname $0` 4 | 5 | try_build () { 6 | docker build -t $@ || exit $? 7 | } 8 | 9 | echo - Pulling base images 10 | docker pull -a openai/retro-agent 11 | docker pull -a openai/retro-env 12 | if [ "$1" = "rebuild" ]; then 13 | echo - Building base CPU image 14 | try_build openai/retro-agent:bare --build-arg BASE=ubuntu --build-arg TAG= --pull -f agent.docker .. 
15 | echo - Building base CUDA images 16 | try_build openai/retro-agent:bare-cuda8 --build-arg CUDA=8.0 --build-arg CUDNN=6 --pull -f agent.docker .. 17 | try_build openai/retro-agent:bare-cuda9 --build-arg CUDA=9.0 --build-arg CUDNN=7 --pull -f agent.docker .. 18 | echo - Building base TensorFlow images 19 | try_build openai/retro-agent:tensorflow-1.4 --build-arg CUDA=8 --build-arg TF=1.4.1 - < agent-tf.docker 20 | try_build openai/retro-agent:tensorflow-1.5 --build-arg CUDA=9 --build-arg TF=1.5.1 - < agent-tf.docker 21 | try_build openai/retro-agent:tensorflow-1.7 --build-arg CUDA=9 --build-arg TF=1.7.1 - < agent-tf.docker 22 | try_build openai/retro-agent:tensorflow-1.8 --build-arg CUDA=9 --build-arg TF=1.8.0 - < agent-tf.docker 23 | echo - Building base PyTorch images 24 | try_build openai/retro-agent:pytorch-0.3 --build-arg PYTORCH=0.3.1 - < agent-pytorch.docker 25 | try_build openai/retro-agent:pytorch-0.4 --build-arg PYTORCH=0.4.0 - < agent-pytorch.docker 26 | echo - Building remote image 27 | try_build openai/retro-env -f remote-env.docker .. 28 | fi 29 | if [ -n "$(ls ../roms)" ]; then 30 | echo - Building remote image with ROMs 31 | docker tag openai/retro-env openai/retro-env:bare 32 | try_build openai/retro-env -f remote-env-roms.docker .. 
33 | fi 34 | echo - Tagging images 35 | docker tag openai/retro-agent:tensorflow-1.8 openai/retro-agent:tensorflow-latest 36 | docker tag openai/retro-agent:tensorflow-latest openai/retro-agent:tensorflow 37 | docker tag openai/retro-agent:pytorch-0.4 openai/retro-agent:pytorch 38 | docker tag openai/retro-agent:tensorflow openai/retro-agent:latest 39 | docker tag openai/retro-agent agent 40 | 41 | echo - Installing Python library 42 | pip3 install -e '../support[docker,rest]' 43 | -------------------------------------------------------------------------------- /support/VERSION: -------------------------------------------------------------------------------- 1 | 0.1.1 -------------------------------------------------------------------------------- /support/gym_remote/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_remote.bridge import * 2 | import gym_remote.exceptions as exceptions 3 | -------------------------------------------------------------------------------- /support/gym_remote/bridge.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import gym.spaces 3 | import json 4 | import numpy as np 5 | import os 6 | import socket 7 | 8 | gym_version = tuple(int(x) for x in gym.__version__.split('.')) 9 | 10 | 11 | class Channel: 12 | def __init__(self): 13 | self.sock = None 14 | self.dirty = False 15 | self._value = None 16 | self.annotations = {} 17 | 18 | def set_socket(self, sock): 19 | self.sock = sock 20 | 21 | def set_base(self, base): 22 | pass 23 | 24 | def parse(self, value): 25 | return value 26 | 27 | def unparse(self, value): 28 | return value 29 | 30 | @property 31 | def value(self): 32 | return self.unparse(self._value) 33 | 34 | @value.setter 35 | def value(self, value): 36 | self._value = self.parse(value) 37 | self.dirty = True 38 | 39 | def serialize(self): 40 | return self._value 41 | 42 | def deserialize(self, value): 43 | self._value 
= self.parse(value) 44 | self.dirty = False 45 | 46 | @staticmethod 47 | def make(type, shape, annotations): 48 | types = { 49 | 'int': IntChannel, 50 | 'float': FloatChannel, 51 | 'bool': BoolChannel, 52 | 'int_fold': IntFoldChannel, 53 | 'np': NpChannel, 54 | } 55 | cls = types[type] 56 | if shape: 57 | ob = cls(*eval(shape, {}, {'dtype': np.dtype})) 58 | else: 59 | ob = cls() 60 | if annotations: 61 | for key, value in annotations.items(): 62 | ob.annotate(key, value) 63 | return ob 64 | 65 | def annotate(self, name, value): 66 | self.annotations[name] = str(value) 67 | 68 | 69 | class IntChannel(Channel): 70 | TYPE = 'int' 71 | SHAPE = None 72 | 73 | def parse(self, value): 74 | return int(value) 75 | 76 | 77 | class FloatChannel(Channel): 78 | TYPE = 'float' 79 | SHAPE = None 80 | 81 | def parse(self, value): 82 | return float(value) 83 | 84 | 85 | class BoolChannel(Channel): 86 | TYPE = 'bool' 87 | SHAPE = None 88 | 89 | def parse(self, value): 90 | return bool(value) 91 | 92 | 93 | class IntFoldChannel(Channel): 94 | TYPE = 'int_fold' 95 | 96 | def __init__(self, folds, dtype=np.int8): 97 | super(IntFoldChannel, self).__init__() 98 | self.folds = np.multiply.accumulate([1] + list(folds)[:-1], dtype=int) 99 | self.ranges = np.array(folds, dtype=int) 100 | self.dtype = dtype 101 | self.SHAPE = str(folds) + ',' 102 | 103 | def parse(self, value): 104 | folded = np.dot(self.folds, value % self.ranges) 105 | return int(folded) 106 | 107 | def unparse(self, value): 108 | if value is None: 109 | return None 110 | unfolded = np.full(self.ranges.shape, value) // self.folds % self.ranges 111 | return unfolded.astype(self.dtype) 112 | 113 | def deserialize(self, value): 114 | self._value = int(value) 115 | self.dirty = False 116 | 117 | 118 | class NpChannel(Channel): 119 | TYPE = 'np' 120 | 121 | def __init__(self, shape, dtype): 122 | super(NpChannel, self).__init__() 123 | self.SHAPE = '%s, %s' % (shape, 'dtype("%s")' % np.dtype(dtype).str) 124 | self.shape = shape 
125 | self.dtype = dtype 126 | 127 | def set_base(self, base): 128 | self._value = np.memmap(base, mode='w+', dtype=self.dtype, shape=self.shape) 129 | 130 | @property 131 | def value(self): 132 | return self._value 133 | 134 | @value.setter 135 | def value(self, value): 136 | np.copyto(self._value, value) 137 | self.dirty = True 138 | 139 | def serialize(self): 140 | return True 141 | 142 | def deserialize(self, value): 143 | self.dirty = False 144 | 145 | 146 | class Bridge: 147 | Timeout = socket.timeout 148 | Closed = BrokenPipeError 149 | 150 | def __init__(self, base): 151 | self.base = base 152 | self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 153 | 154 | def close(message): 155 | self.close() 156 | if 'exception' in message: 157 | import gym_remote.exceptions as gre 158 | exception = gre.make(message['exception'], message['reason']) 159 | else: 160 | exception = self.Closed(message['reason']) 161 | raise exception 162 | 163 | def exception(message): 164 | import gym_remote.exceptions as gre 165 | raise gre.make(message['exception'], message['reason']) 166 | 167 | self._channels = {} 168 | self.connection = None 169 | self._buffer = [] 170 | self._message_handlers = { 171 | 'update': self.update_vars, 172 | 'close': close, 173 | 'exception': exception 174 | } 175 | 176 | def __del__(self): 177 | self.close() 178 | 179 | def add_channel(self, name, channel): 180 | if name in self._channels: 181 | raise KeyError(name) 182 | self._channels[name] = channel 183 | channel.set_base(os.path.join(self.base, name)) 184 | return channel 185 | 186 | def wrap(self, name, space): 187 | channel = None 188 | if isinstance(space, gym.spaces.MultiBinary): 189 | if space.n < 64: 190 | channel = IntFoldChannel([2] * space.n, np.uint8) 191 | else: 192 | channel = NpChannel((space.n,), np.uint8) 193 | channel.annotate('n', space.n) 194 | channel.annotate('type', 'MultiBinary') 195 | elif isinstance(space, gym.spaces.Discrete): 196 | channel = IntChannel() 197 | 
channel.annotate('n', space.n) 198 | channel.annotate('type', 'Discrete') 199 | elif isinstance(space, gym.spaces.MultiDiscrete): 200 | if gym_version >= (0, 9, 6): 201 | channel = NpChannel(space.shape, np.int64) 202 | channel.annotate('shape', space.shape[0]) 203 | else: 204 | channel = NpChannel((space.shape,), np.int64) 205 | channel.annotate('shape', space.shape) 206 | channel.annotate('type', 'MultiDiscrete') 207 | elif isinstance(space, gym.spaces.Box): 208 | channel = NpChannel(space.shape, space.high.dtype) 209 | channel.annotate('type', 'Box') 210 | channel.annotate('shape', space.shape) 211 | 212 | if not channel: 213 | raise NotImplementedError('Unsupported space') 214 | 215 | return self.add_channel(name, channel) 216 | 217 | @staticmethod 218 | def unwrap(space): 219 | if space.annotations['type'] == 'MultiBinary': 220 | return gym.spaces.MultiBinary(int(space.annotations['n'])) 221 | if space.annotations['type'] == 'Discrete': 222 | return gym.spaces.Discrete(int(space.annotations['n'])) 223 | if space.annotations['type'] == 'MultiDiscrete': 224 | if gym_version >= (0, 9, 6): 225 | return gym.spaces.MultiDiscrete(space.shape[0]) 226 | else: 227 | return gym.spaces.MultiDiscrete(space.shape) 228 | if space.annotations['type'] == 'Box': 229 | kwargs = {} 230 | if gym_version >= (0, 9, 6): 231 | kwargs['dtype'] = space.dtype 232 | return gym.spaces.Box(low=0, high=255, shape=space.shape, **kwargs) 233 | 234 | def configure_channels(self, channel_info): 235 | for name, info in channel_info.items(): 236 | self._channels[name] = Channel.make(*info) 237 | 238 | def describe_channels(self): 239 | description = {} 240 | for name, channel in self._channels.items(): 241 | description[name] = (channel.TYPE, channel.SHAPE, channel.annotations) 242 | return description 243 | 244 | def listen(self): 245 | sock_path = os.path.join(self.base, 'sock') 246 | self.sock.bind(sock_path) 247 | self.sock.listen(1) 248 | 249 | def connect(self): 250 | sock_path = 
os.path.join(self.base, 'sock') 251 | self.sock.connect(sock_path) 252 | self.connection = self.sock 253 | 254 | def server_accept(self): 255 | self.connection, _ = self.sock.accept() 256 | for name, channel in self._channels.items(): 257 | channel.set_socket(self.connection) 258 | description = self.describe_channels() 259 | self._send_message('description', description) 260 | 261 | def configure_client(self): 262 | description = self._recv_message() 263 | assert description['type'] == 'description' 264 | self.configure_channels(description['content']) 265 | for name, channel in self._channels.items(): 266 | channel.set_socket(self.connection) 267 | channel.set_base(os.path.join(self.base, name)) 268 | return dict(self._channels) 269 | 270 | def _try_send(self, type, content): 271 | try: 272 | self._send_message(type, content) 273 | except self.Closed as e: 274 | try: 275 | while True: 276 | self.recv() 277 | except self.Closed as f: 278 | e = f 279 | self.close() 280 | raise e 281 | 282 | def _send_message(self, type, content): 283 | if not self.connection: 284 | raise self.Closed 285 | message = { 286 | 'type': type, 287 | 'content': content 288 | } 289 | # All messages end in a form feed 290 | message = json.dumps(message) + '\f' 291 | self.connection.sendall(message.encode('utf8')) 292 | 293 | def _recv_message(self): 294 | if not self.connection: 295 | raise self.Closed 296 | while len(self._buffer) < 2: 297 | # There are no fully buffered messages 298 | message = self.connection.recv(4096) 299 | if not message: 300 | raise self.Closed 301 | message = message.split(b'\f') 302 | if self._buffer: 303 | self._buffer[-1] += message.pop(0) 304 | self._buffer.extend(message) 305 | message = self._buffer.pop(0) 306 | return json.loads(message.decode('utf8')) 307 | 308 | def update_vars(self, vars): 309 | for name, value in vars.items(): 310 | self._channels[name].deserialize(value) 311 | 312 | def send(self): 313 | content = {} 314 | for name, channel in 
self._channels.items(): 315 | if channel.dirty: 316 | content[name] = channel.serialize() 317 | self._try_send('update', content) 318 | 319 | def recv(self): 320 | message = self._recv_message() 321 | if not message: 322 | raise self.Closed 323 | self._message_handlers[message['type']](message['content']) 324 | return True 325 | 326 | def close(self, reason=None, exception=None): 327 | if self.sock: 328 | try: 329 | kwargs = {'reason': reason} 330 | if exception: 331 | kwargs['exception'] = exception.ID 332 | self._send_message('close', kwargs) 333 | except self.Closed: 334 | pass 335 | self.sock.close() 336 | if self.sock and self.connection != self.sock: 337 | if self.connection: 338 | self.connection.close() 339 | try: 340 | os.unlink(os.path.join(self.base, 'sock')) 341 | except OSError: 342 | pass 343 | for name, channel in self._channels.items(): 344 | try: 345 | os.unlink(os.path.join(self.base, name)) 346 | except OSError: 347 | pass 348 | self.connection = None 349 | self.sock = None 350 | 351 | def exception(self, exception, reason=None): 352 | content = {'reason': reason, 'exception': exception.ID} 353 | self._try_send('exception', content) 354 | 355 | def settimeout(self, timeout): 356 | self.sock.settimeout(timeout) 357 | if self.connection: 358 | self.connection.settimeout(timeout) 359 | 360 | def __del__(self): 361 | self.close() 362 | -------------------------------------------------------------------------------- /support/gym_remote/client.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import time 3 | 4 | from gym_remote import Bridge 5 | 6 | 7 | class RemoteEnv(gym.Env): 8 | def __init__(self, directory, tries=8): 9 | self.bridge = Bridge(directory) 10 | 11 | # Try a few times to connect 12 | backoff = 2 13 | for x in range(tries): 14 | try: 15 | self.bridge.connect() 16 | break 17 | except FileNotFoundError: 18 | if x + 1 == tries: 19 | raise 20 | time.sleep(backoff) 21 | backoff *= 2 22 | 
23 | self.bridge.configure_client() 24 | self.ch_ac = self.bridge._channels['ac'] 25 | self.ch_ob = self.bridge._channels['ob'] 26 | self.ch_reward = self.bridge._channels['reward'] 27 | self.ch_done = self.bridge._channels['done'] 28 | self.ch_reset = self.bridge._channels['reset'] 29 | self.action_space = self.bridge.unwrap(self.ch_ac) 30 | self.observation_space = self.bridge.unwrap(self.ch_ob) 31 | 32 | def step(self, action): 33 | self.ch_ac.value = action 34 | self.bridge.send() 35 | self.bridge.recv() 36 | 37 | return self.ch_ob.value, self.ch_reward.value, self.ch_done.value, {} 38 | 39 | def reset(self): 40 | self.ch_reset.value = True 41 | self.bridge.send() 42 | self.bridge.recv() 43 | return self.ch_ob.value 44 | 45 | def close(self): 46 | self.bridge.close() 47 | -------------------------------------------------------------------------------- /support/gym_remote/exceptions.py: -------------------------------------------------------------------------------- 1 | import gym_remote as gr 2 | 3 | 4 | class GymRemoteErrorMeta(type): 5 | ID_MAX = 0 6 | ID_LIST = [] 7 | 8 | def __new__(cls, name, bases, dictionary): 9 | dictionary['ID'] = cls.ID_MAX 10 | cls.ID_MAX += 1 11 | try: 12 | bases = (*bases, GymRemoteError) 13 | except NameError: 14 | assert name == 'GymRemoteError' 15 | newcls = super(GymRemoteErrorMeta, cls).__new__(cls, name, bases, dictionary) 16 | cls.ID_LIST.append(newcls) 17 | return newcls 18 | 19 | @classmethod 20 | def make(cls, id, *args, **kwargs): 21 | return cls.ID_LIST[id](*args, **kwargs) 22 | 23 | 24 | class GymRemoteError(RuntimeError, metaclass=GymRemoteErrorMeta): 25 | pass 26 | 27 | 28 | class TimestepTimeoutError(TimeoutError, metaclass=GymRemoteErrorMeta): 29 | pass 30 | 31 | 32 | class WallClockTimeoutError(TimeoutError, metaclass=GymRemoteErrorMeta): 33 | pass 34 | 35 | 36 | class ClientDisconnectError(gr.Bridge.Closed, metaclass=GymRemoteErrorMeta): 37 | pass 38 | 39 | 40 | class ServerDisconnectError(gr.Bridge.Closed, 
metaclass=GymRemoteErrorMeta): 41 | pass 42 | 43 | 44 | class ResetError(metaclass=GymRemoteErrorMeta): 45 | pass 46 | 47 | 48 | def make(id, *args, **kwargs): 49 | return GymRemoteErrorMeta.make(id, *args, **kwargs) 50 | -------------------------------------------------------------------------------- /support/gym_remote/server.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import time 3 | 4 | from gym_remote import Bridge, FloatChannel, BoolChannel 5 | import gym_remote.exceptions as gre 6 | 7 | 8 | class RemoteEnvWrapper(gym.Wrapper): 9 | def __init__(self, env, directory): 10 | gym.Wrapper.__init__(self, env) 11 | self.bridge = Bridge(directory) 12 | self.ch_ac = self.bridge.wrap('ac', env.action_space) 13 | self.ch_ob = self.bridge.wrap('ob', env.observation_space) 14 | self.ch_reward = self.bridge.add_channel('reward', FloatChannel()) 15 | self.ch_done = self.bridge.add_channel('done', BoolChannel()) 16 | self.ch_reset = self.bridge.add_channel('reset', BoolChannel()) 17 | self.bridge.listen() 18 | 19 | def serve(self, timestep_limit=None, wallclock_limit=None, ignore_reset=False): 20 | if wallclock_limit is not None: 21 | end = time.time() + wallclock_limit 22 | self.bridge.settimeout(wallclock_limit) 23 | else: 24 | end = None 25 | ts = 0 26 | 27 | try: 28 | self.bridge.server_accept() 29 | except Bridge.Timeout: 30 | return ts 31 | 32 | done = True 33 | 34 | while timestep_limit is None or ts < timestep_limit: 35 | if wallclock_limit: 36 | t = time.time() 37 | if t >= end: 38 | self.bridge.close(exception=gre.WallClockTimeoutError) 39 | break 40 | self.bridge.settimeout(end - t) 41 | try: 42 | self.bridge.recv() 43 | except Bridge.Timeout: 44 | self.bridge.close(exception=gre.WallClockTimeoutError) 45 | break 46 | except Bridge.Closed: 47 | self.bridge.close(exception=gre.ClientDisconnectError) 48 | break 49 | 50 | if self.ch_reset.value: 51 | if ignore_reset and not done: 52 | 
self.bridge.exception(gre.ResetError) 53 | self.bridge.send() 54 | continue 55 | self.ch_ob.value = self.env.reset() 56 | self.ch_reset.value = False 57 | self.ch_reward.value = 0 58 | self.ch_done.value = False 59 | done = False 60 | else: 61 | if ignore_reset and done: 62 | self.bridge.exception(gre.ResetError) 63 | self.bridge.send() 64 | continue 65 | ob, rew, done, _ = self.env.step(self.ch_ac.value) 66 | self.ch_ob.value = ob 67 | self.ch_reward.value = rew 68 | self.ch_done.value = done 69 | self.bridge.send() 70 | ts += 1 71 | 72 | if timestep_limit and ts >= timestep_limit: 73 | self.bridge.close(exception=gre.TimestepTimeoutError) 74 | return ts 75 | 76 | def close(self): 77 | self.bridge.close() 78 | self.env.close() 79 | -------------------------------------------------------------------------------- /support/retro_contest/__init__.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import gym 3 | import numpy as np 4 | import time 5 | 6 | 7 | class StochasticFrameSkip(gym.Wrapper): 8 | def __init__(self, env, n, stickprob): 9 | gym.Wrapper.__init__(self, env) 10 | self.n = n 11 | self.stickprob = stickprob 12 | self.curac = None 13 | self.rng = np.random.RandomState() 14 | 15 | def reset(self, **kwargs): 16 | self.curac = None 17 | return self.env.reset(**kwargs) 18 | 19 | def step(self, ac): 20 | done = False 21 | totrew = 0 22 | for i in range(self.n): 23 | # First step after reset, use action 24 | if self.curac is None: 25 | self.curac = ac 26 | # First substep, delay with probability=stickprob 27 | elif i == 0: 28 | if self.rng.rand() > self.stickprob: 29 | self.curac = ac 30 | # Second substep, new action definitely kicks in 31 | elif i == 1: 32 | self.curac = ac 33 | ob, rew, done, info = self.env.step(self.curac) 34 | totrew += rew 35 | if done: 36 | break 37 | return ob, totrew, done, info 38 | 39 | 40 | class Monitor(gym.Wrapper): 41 | def __init__(self, env, monitorfile, logfile=None): 42 | 
gym.Wrapper.__init__(self, env) 43 | self.file = open(monitorfile, 'w') 44 | self.csv = csv.DictWriter(self.file, ['r', 'l', 't']) 45 | self.log = open(logfile, 'w') 46 | self.logcsv = csv.DictWriter(self.log, ['l', 't']) 47 | self.episode_reward = 0 48 | self.episode_length = 0 49 | self.total_length = 0 50 | self.start = None 51 | self.csv.writeheader() 52 | self.file.flush() 53 | self.logcsv.writeheader() 54 | self.log.flush() 55 | 56 | def reset(self, **kwargs): 57 | if not self.start: 58 | self.start = time.time() 59 | else: 60 | self.csv.writerow({ 61 | 'r': self.episode_reward, 62 | 'l': self.episode_length, 63 | 't': time.time() - self.start 64 | }) 65 | self.file.flush() 66 | self.episode_length = 0 67 | self.episode_reward = 0 68 | return self.env.reset(**kwargs) 69 | 70 | def step(self, ac): 71 | ob, rew, done, info = self.env.step(ac) 72 | self.episode_length += 1 73 | self.total_length += 1 74 | self.episode_reward += rew 75 | if self.total_length % 1000 == 0: 76 | self.logcsv.writerow({ 77 | 'l': self.total_length, 78 | 't': time.time() - self.start 79 | }) 80 | self.log.flush() 81 | return ob, rew, done, info 82 | 83 | def __del__(self): 84 | self.file.close() 85 | -------------------------------------------------------------------------------- /support/retro_contest/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | try: 5 | import retro_contest.docker as docker_cmd 6 | except ImportError: 7 | docker_cmd = None 8 | 9 | try: 10 | import retro_contest.rest as rest_cmd 11 | except ImportError: 12 | rest_cmd = None 13 | 14 | 15 | def main(argv=sys.argv[1:]): 16 | parser = argparse.ArgumentParser(description='Run OpenAI Retro Contest support code') 17 | parser.set_defaults(func=lambda args: parser.print_help()) 18 | subparsers = parser.add_subparsers() 19 | if docker_cmd: 20 | docker_cmd.init_parser(subparsers) 21 | if rest_cmd: 22 | rest_cmd.init_parsers(subparsers) 23 
| 24 | args = parser.parse_args(argv) 25 | if not args.func(args): 26 | sys.exit(1) 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /support/retro_contest/agent.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym_remote.exceptions as gre 3 | import gym_remote.client as grc 4 | import os 5 | import sys 6 | import traceback 7 | from pkg_resources import EntryPoint 8 | 9 | 10 | def make(socketdir='tmp/sock'): 11 | env = grc.RemoteEnv(socketdir) 12 | return env 13 | 14 | 15 | def run(agent=None, socketdir='tmp/sock', daemonize=False, args=[]): 16 | if daemonize: 17 | pid = os.fork() 18 | if pid > 0: 19 | return 20 | 21 | if agent is None: 22 | print('Running agent: random_agent') 23 | agent = random_agent 24 | elif not callable(agent): 25 | print('Running agent: %s' % agent) 26 | entrypoint = EntryPoint.parse('entry=' + agent) 27 | agent = entrypoint.load(False) 28 | else: 29 | print('Running agent: %r' % agent) 30 | env = make(socketdir) 31 | try: 32 | agent(env, *args) 33 | except gre.GymRemoteError: 34 | pass 35 | 36 | 37 | def random_agent(env, *args): 38 | env.reset() 39 | while True: 40 | action = env.action_space.sample() 41 | try: 42 | ob, reward, done, _ = env.step(action) 43 | except gre.ResetError: 44 | done = True 45 | if done: 46 | env.reset() 47 | 48 | 49 | def main(argv=sys.argv[1:]): 50 | parser = argparse.ArgumentParser(description='Run support code for OpenAI Retro Contest remote environment') 51 | parser.add_argument('--daemonize', '-d', action='store_true', default=False, help='Daemonize (background) the process') 52 | parser.add_argument('entry', type=str, nargs='?', help='Entry point to create an agent') 53 | parser.add_argument('args', nargs='*', help='Optional arguments to the agent') 54 | 55 | args = parser.parse_args(argv) 56 | run(agent=args.entry, daemonize=args.daemonize, 
args=args.args) 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /support/retro_contest/docker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import docker 3 | import io 4 | import os 5 | import random 6 | import requests.exceptions 7 | import sys 8 | import tempfile 9 | import threading 10 | try: 11 | from retro import data_path 12 | except ImportError: 13 | def data_path(): 14 | raise RuntimeError('Could not find Gym Retro data directory') 15 | 16 | 17 | class LogThread: 18 | def __init__(self, container): 19 | self._log = container.logs(stdout=True, stderr=True, stream=True) 20 | self._thread = threading.Thread(target=self._run) 21 | self._active = False 22 | 23 | def start(self): 24 | if self._active: 25 | return 26 | self._active = True 27 | self._thread.start() 28 | 29 | def exit(self): 30 | self._active = False 31 | 32 | def _run(self): 33 | while self._active: 34 | try: 35 | print(next(self._log).decode('utf-8'), end='') 36 | except StopIteration: 37 | break 38 | 39 | 40 | def convert_path(path): 41 | if sys.platform.startswith('win') and path[1] == ':': 42 | path = '/%s%s' % (path[0].lower(), path[2:].replace('\\', '/')) 43 | return path 44 | 45 | 46 | def run(game, state=None, entry=None, **kwargs): 47 | client = docker.from_env() 48 | remote_command = ['retro-contest-remote', 'run', game, *([state] if state else []), '-b', 'results/bk2', '-m', 'results'] 49 | remote_name = kwargs.get('remote_env', 'openai/retro-env') 50 | agent_command = [] 51 | agent_name = kwargs.get('agent', 'agent') 52 | datamount = {} 53 | 54 | if kwargs.get('wallclock_limit') is not None: 55 | remote_command.extend(['-W', str(kwargs['wallclock_limit'])]) 56 | if kwargs.get('timestep_limit') is not None: 57 | remote_command.extend(['-T', str(kwargs['timestep_limit'])]) 58 | if kwargs.get('discrete_actions'): 59 | 
remote_command.extend(['-D']) 60 | 61 | if entry: 62 | agent_command.append(entry) 63 | if kwargs.get('entry_args'): 64 | agent_command.extend(kwargs['entry_args']) 65 | 66 | rand = ''.join(random.sample('abcdefghijklmnopqrstuvwxyz0123456789', 8)) 67 | volname = 'retro-contest-tmp%s' % rand 68 | datamount = {} 69 | agentmount = {} 70 | if kwargs.get('resultsdir'): 71 | results = os.path.realpath(kwargs['resultsdir']) 72 | datamount[convert_path(results)] = {'bind': '/root/compo/results'} 73 | os.makedirs(results, exist_ok=True) 74 | else: 75 | results = None 76 | 77 | if kwargs.get('agentdir'): 78 | agentdir = os.path.realpath(kwargs['agentdir']) 79 | agentmount[convert_path(agentdir)] = {'bind': '/root/compo/out'} 80 | os.makedirs(agentdir, exist_ok=True) 81 | 82 | container_kwargs = {'detach': True, 'network_disabled': True} 83 | remote_kwargs = dict(container_kwargs) 84 | agent_kwargs = dict(container_kwargs) 85 | 86 | if kwargs.get('agent_shm'): 87 | agent_kwargs['shm_size'] = kwargs['agent_shm'] 88 | 89 | bridge = client.volumes.create(volname, driver='local', driver_opts={'type': 'tmpfs', 'device': 'tmpfs'}) 90 | if kwargs.get('use_host_data'): 91 | remote_command = [remote_command[0], '--data-dir', '/root/data', *remote_command[1:]] 92 | datamount[convert_path(data_path())] = {'bind': '/root/data', 'mode': 'ro'} 93 | 94 | try: 95 | remote = client.containers.run(remote_name, remote_command, 96 | volumes={volname: {'bind': '/root/compo/tmp'}, 97 | **datamount}, 98 | **remote_kwargs) 99 | except: 100 | bridge.remove() 101 | raise 102 | 103 | try: 104 | agent = client.containers.run(agent_name, agent_command, 105 | volumes={volname: {'bind': '/root/compo/tmp'}, 106 | **agentmount}, 107 | runtime=kwargs.get('runtime', 'nvidia'), 108 | **agent_kwargs) 109 | except: 110 | remote.kill() 111 | remote.remove() 112 | bridge.remove() 113 | raise 114 | 115 | a_exit = None 116 | r_exit = None 117 | 118 | if not kwargs.get('quiet'): 119 | log_thread = LogThread(agent) 120 
| log_thread.start() 121 | 122 | try: 123 | while True: 124 | try: 125 | a_exit = agent.wait(timeout=5) 126 | break 127 | except requests.exceptions.RequestException: 128 | pass 129 | try: 130 | r_exit = remote.wait(timeout=5) 131 | break 132 | except requests.exceptions.RequestException: 133 | pass 134 | 135 | if a_exit is None: 136 | try: 137 | a_exit = agent.wait(timeout=10) 138 | except requests.exceptions.RequestException: 139 | agent.kill() 140 | if r_exit is None: 141 | try: 142 | r_exit = remote.wait(timeout=10) 143 | except requests.exceptions.RequestException: 144 | remote.kill() 145 | except: 146 | if a_exit is None: 147 | try: 148 | a_exit = agent.wait(timeout=1) 149 | except: 150 | try: 151 | agent.kill() 152 | except docker.errors.APIError: 153 | pass 154 | if r_exit is None: 155 | try: 156 | r_exit = remote.wait(timeout=1) 157 | except: 158 | try: 159 | remote.kill() 160 | except docker.errors.APIError: 161 | pass 162 | raise 163 | finally: 164 | if isinstance(a_exit, dict): 165 | a_exit = a_exit.get('StatusCode') 166 | if isinstance(r_exit, dict): 167 | r_exit = r_exit.get('StatusCode') 168 | 169 | if not kwargs.get('quiet'): 170 | log_thread.exit() 171 | 172 | logs = { 173 | 'remote': (r_exit, remote.logs(stdout=True, stderr=False), remote.logs(stdout=False, stderr=True)), 174 | 'agent': (a_exit, agent.logs(stdout=True, stderr=False), agent.logs(stdout=False, stderr=True)) 175 | } 176 | 177 | if results: 178 | with open(os.path.join(results, 'remote-stdout.txt'), 'w') as f: 179 | f.write(logs['remote'][1].decode('utf-8')) 180 | with open(os.path.join(results, 'remote-stderr.txt'), 'w') as f: 181 | f.write(logs['remote'][2].decode('utf-8')) 182 | with open(os.path.join(results, 'agent-stdout.txt'), 'w') as f: 183 | f.write(logs['agent'][1].decode('utf-8')) 184 | with open(os.path.join(results, 'agent-stderr.txt'), 'w') as f: 185 | f.write(logs['agent'][2].decode('utf-8')) 186 | 187 | remote.remove() 188 | agent.remove() 189 | bridge.remove() 190 | 
191 | return logs 192 | 193 | 194 | def run_args(args): 195 | kwargs = { 196 | 'entry_args': args.args, 197 | 'wallclock_limit': args.wallclock_limit, 198 | 'timestep_limit': args.timestep_limit, 199 | 'discrete_actions': args.discrete_actions, 200 | 'resultsdir': args.results_dir, 201 | 'agentdir': args.agent_dir, 202 | 'quiet': args.quiet, 203 | 'use_host_data': args.use_host_data, 204 | 'agent_shm': args.agent_shm, 205 | } 206 | 207 | if args.no_nv: 208 | kwargs['runtime'] = None 209 | 210 | if args.agent: 211 | kwargs['agent'] = args.agent 212 | 213 | if args.remote_env: 214 | kwargs['remote_env'] = args.remote_env 215 | 216 | results = run(args.game, args.state, args.entry, **kwargs) 217 | if results['remote'][0] or results['agent'][0]: 218 | if results['remote'][0]: 219 | print('Remote exited uncleanly:', results['remote'][0]) 220 | if results['agent'][0]: 221 | print('Agent exited uncleanly', results['agent'][0]) 222 | return False 223 | return True 224 | 225 | 226 | def build(path, tag, install=None, pass_env=False): 227 | from pkg_resources import EntryPoint 228 | import tarfile 229 | if install: 230 | destination = 'module' 231 | else: 232 | destination = 'agent.py' 233 | docker_file = ['FROM openai/retro-agent', 234 | 'COPY context %s' % destination] 235 | 236 | if not install: 237 | docker_file.append('CMD ["python", "-u", "/root/compo/agent.py"]') 238 | else: 239 | docker_file.append('RUN . 
~/venv/bin/activate && pip install -e module') 240 | valid = not any(c in install for c in ' "\\') 241 | if pass_env: 242 | try: 243 | EntryPoint.parse('entry=' + install) 244 | except ValueError: 245 | valid = False 246 | if not valid: 247 | raise ValueError('Invalid entry point') 248 | docker_file.append('CMD ["retro-contest-agent", "%s"]' % install) 249 | else: 250 | if not valid: 251 | raise ValueError('Invalid module name') 252 | docker_file.append('CMD ["python", "-u", "-m", "%s"]' % install) 253 | 254 | print('Creating Docker image...') 255 | docker_file_full = io.BytesIO('\n'.join(docker_file).encode('utf-8')) 256 | client = docker.from_env() 257 | with tempfile.NamedTemporaryFile() as f: 258 | tf = tarfile.open(mode='w:gz', fileobj=f) 259 | docker_file_info = tarfile.TarInfo('Dockerfile') 260 | docker_file_info.size = len(docker_file_full.getvalue()) 261 | tf.addfile(docker_file_info, docker_file_full) 262 | tf.add(path, arcname='context', exclude=lambda fname: fname.endswith('/.git')) 263 | tf.close() 264 | f.seek(0) 265 | client.images.build(fileobj=f, custom_context=True, tag=tag, gzip=True) 266 | print('Done!') 267 | 268 | 269 | def build_args(args): 270 | kwargs = { 271 | 'install': args.install, 272 | 'pass_env': args.pass_env, 273 | } 274 | 275 | try: 276 | build(args.path, args.tag, **kwargs) 277 | except docker.errors.BuildError as be: 278 | print(*[log['stream'] for log in be.build_log if 'stream' in log]) 279 | raise 280 | return True 281 | 282 | 283 | def init_parser(subparsers): 284 | parser_run = subparsers.add_parser('run', description='Run Docker containers locally') 285 | parser_run.set_defaults(func=run_args) 286 | parser_run.add_argument('game', type=str, help='Name of the game to run') 287 | parser_run.add_argument('state', type=str, default=None, nargs='?', help='Name of initial state') 288 | parser_run.add_argument('--entry', '-e', type=str, help='Name of agent entry point') 289 | parser_run.add_argument('--args', '-A', type=str, 
nargs='+', help='Extra agent entry arguments') 290 | parser_run.add_argument('--agent', '-a', type=str, help='Extra agent Docker image') 291 | parser_run.add_argument('--wallclock-limit', '-W', type=float, default=None, help='Maximum time to run in seconds') 292 | parser_run.add_argument('--timestep-limit', '-T', type=int, default=None, help='Maximum time to run in timesteps') 293 | parser_run.add_argument('--no-nv', '-N', action='store_true', help='Disable Nvidia runtime') 294 | parser_run.add_argument('--remote-env', '-R', type=str, help='Remote Docker image') 295 | parser_run.add_argument('--results-dir', '-r', type=str, help='Path to output results') 296 | parser_run.add_argument('--agent-dir', '-o', type=str, help='Path to mount into agent (mounted at /root/compo/out)') 297 | parser_run.add_argument('--discrete-actions', '-D', action='store_true', help='Use a discrete action space') 298 | parser_run.add_argument('--use-host-data', '-d', action='store_true', help='Use the host Gym Retro data directory') 299 | parser_run.add_argument('--quiet', '-q', action='store_true', help='Disable printing agent logs') 300 | parser_run.add_argument('--agent-shm', type=str, help='Agent /dev/shm size') 301 | 302 | parser_build = subparsers.add_parser('build', description='Build agent Docker containers') 303 | parser_build.set_defaults(func=build_args) 304 | parser_build.add_argument('path', type=str, help='Path to a file or package') 305 | parser_build.add_argument('--tag', '-t', required=True, type=str, help='Tag name for the built image') 306 | parser_build.add_argument('--install', '-i', type=str, help='Install as a package and run specified module or entry point (if -e is specified)') 307 | parser_build.add_argument('--pass-env', '-e', action='store_true', help='Pass preconfigured environment to entry point specified by -i') 308 | 309 | 310 | def main(argv=sys.argv[1:]): 311 | parser = argparse.ArgumentParser(description='Run OpenAI Retro Contest support code') 312 | 
parser.set_defaults(func=lambda args: parser.print_help()) 313 | init_parser(parser.add_subparsers()) 314 | args = parser.parse_args(argv) 315 | if not args.func(args): 316 | sys.exit(1) 317 | 318 | 319 | if __name__ == '__main__': 320 | main() 321 | -------------------------------------------------------------------------------- /support/retro_contest/local.py: -------------------------------------------------------------------------------- 1 | import retro 2 | import retro_contest 3 | import gym 4 | import gym.wrappers 5 | 6 | 7 | def make(game, state=retro.State.DEFAULT, discrete_actions=False, bk2dir=None): 8 | use_restricted_actions = retro.Actions.FILTERED 9 | if discrete_actions: 10 | use_restricted_actions = retro.Actions.DISCRETE 11 | try: 12 | env = retro.make(game, state, scenario='contest', use_restricted_actions=use_restricted_actions) 13 | except Exception: 14 | env = retro.make(game, state, use_restricted_actions=use_restricted_actions) 15 | if bk2dir: 16 | env.auto_record(bk2dir) 17 | env = retro_contest.StochasticFrameSkip(env, n=4, stickprob=0.25) 18 | env = gym.wrappers.TimeLimit(env, max_episode_steps=4500) 19 | return env 20 | -------------------------------------------------------------------------------- /support/retro_contest/remote.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import gym_remote.server as grs 4 | import os 5 | import retro 6 | import retro_contest 7 | import retro_contest.local 8 | import sys 9 | 10 | 11 | def make(game, state=retro.STATE_DEFAULT, bk2dir=None, monitordir=None, discrete_actions=False, socketdir=None): 12 | if bk2dir: 13 | os.makedirs(bk2dir, exist_ok=True) 14 | env = retro_contest.local.make(game, state, discrete_actions=discrete_actions, bk2dir=bk2dir) 15 | if monitordir: 16 | env = retro_contest.Monitor(env, os.path.join(monitordir, 'monitor.csv'), os.path.join(monitordir, 'log.csv')) 17 | env = grs.RemoteEnvWrapper(env, 
socketdir) 18 | return env 19 | 20 | 21 | def run(game, state, 22 | wallclock_limit=None, timestep_limit=None, 23 | monitordir=None, bk2dir=None, socketdir=None, 24 | discrete_actions=False, daemonize=False): 25 | if daemonize: 26 | pid = os.fork() 27 | if pid > 0: 28 | return 29 | 30 | env = make(game, state, bk2dir, monitordir, discrete_actions, socketdir) 31 | env.serve(timestep_limit=timestep_limit, wallclock_limit=wallclock_limit, ignore_reset=True) 32 | 33 | 34 | def run_args(args): 35 | run(args.game, args.state, 36 | wallclock_limit=args.wallclock_limit, 37 | timestep_limit=args.timestep_limit, 38 | bk2dir=args.bk2dir, 39 | monitordir=args.monitordir, 40 | socketdir=args.socketdir, 41 | discrete_actions=args.discrete_actions, 42 | daemonize=args.daemonize) 43 | 44 | 45 | def list_games(args): 46 | games = retro.data.list_games() 47 | if args.system: 48 | games = [game for game in games if game.endswith('-' + args.system)] 49 | games.sort() 50 | print(*games, sep='\n') 51 | 52 | 53 | def list_states(args): 54 | if args.game: 55 | games = args.game 56 | else: 57 | games = retro.data.list_games() 58 | games.sort() 59 | for game in games: 60 | states = retro.data.list_states(game) 61 | print(game + ':') 62 | states.sort() 63 | for state in states: 64 | print(' ' + state) 65 | 66 | 67 | def main(argv=sys.argv[1:]): 68 | parser = argparse.ArgumentParser(description='Run support code for OpenAI Retro Contest remote environment') 69 | parser.set_defaults(func=lambda args: parser.print_help()) 70 | parser.add_argument('--data-dir', type=str, help='Use a custom data directory (must be named `data`)') 71 | 72 | subparsers = parser.add_subparsers() 73 | parser_run = subparsers.add_parser('run', description='Run Remote environment') 74 | parser_list = subparsers.add_parser('list', description='List information about environments') 75 | 76 | parser_run.set_defaults(func=run_args) 77 | parser_run.add_argument('game', type=str, help='Name of the game to run') 78 | 
parser_run.add_argument('state', type=str, default=retro.State.DEFAULT, nargs='?', help='Name of initial state') 79 | parser_run.add_argument('--monitordir', '-m', type=str, help='Directory to hold monitor files') 80 | parser_run.add_argument('--bk2dir', '-b', type=str, help='Directory to hold BK2 movies') 81 | parser_run.add_argument('--socketdir', '-s', type=str, default='tmp/sock', help='Directory to hold sockets') 82 | parser_run.add_argument('--daemonize', '-d', action='store_true', default=False, help='Daemonize (background) the process') 83 | parser_run.add_argument('--wallclock-limit', '-W', type=float, default=None, help='Maximum time to run in seconds') 84 | parser_run.add_argument('--timestep-limit', '-T', type=int, default=None, help='Maximum time to run in timesteps') 85 | parser_run.add_argument('--discrete-actions', '-D', action='store_true', help='Use a discrete action space') 86 | 87 | parser_list.set_defaults(func=lambda args: parser_list.print_help()) 88 | subparsers_list = parser_list.add_subparsers() 89 | parser_list_games = subparsers_list.add_parser('games', description='List games') 90 | parser_list_games.set_defaults(func=list_games) 91 | parser_list_games.add_argument('--system', '-s', type=str, help='List for a specific system only') 92 | 93 | parser_list_states = subparsers_list.add_parser('states', description='List') 94 | parser_list_states.set_defaults(func=list_states) 95 | parser_list_states.add_argument('game', type=str, default=None, nargs='*', help='List for specified games only') 96 | 97 | args = parser.parse_args(argv) 98 | if args.data_dir: 99 | retro.data.path(args.data_dir) 100 | args.func(args) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /support/retro_contest/rest.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import docker 3 | import getpass 4 | import itertools 
5 | import json 6 | import os 7 | import requests 8 | import sys 9 | import yaml 10 | from functools import wraps 11 | from requests.auth import HTTPBasicAuth 12 | 13 | config = {} 14 | 15 | 16 | def update_config(key, value): 17 | config[key] = value 18 | write_config() 19 | 20 | 21 | def clear_config(key): 22 | del config[key] 23 | write_config() 24 | 25 | 26 | def write_config(): 27 | os.makedirs(os.path.join(os.path.expanduser('~'), '.config'), exist_ok=True) 28 | c = yaml.dump(config, default_flow_style=False) 29 | with open(os.path.join(os.path.expanduser('~'), '.config/retro-contest.yml'), 'w') as f: 30 | f.write(c) 31 | 32 | 33 | def load_config(): 34 | global config 35 | try: 36 | with open(os.path.join(os.path.expanduser('~'), '.config/retro-contest.yml')) as f: 37 | config = yaml.safe_load(f) 38 | except FileNotFoundError: 39 | pass 40 | 41 | 42 | def login(email, password, server=None): 43 | if not server: 44 | server = config.get('server') 45 | if server and not (server.startswith('http://') or server.startswith('https://')): 46 | server = 'http://' + server 47 | r = requests.post(server + '/rest/login', json={'email': email, 'password': password}) 48 | if r.status_code // 100 == 2: 49 | update_config('cookies', dict(r.cookies)) 50 | update_config('server', server) 51 | return True 52 | return False 53 | 54 | 55 | def login_args(args): 56 | email = args.email 57 | if not email: 58 | email = input('Email address: ') 59 | password = args.password 60 | if args.password_stdin or not password: 61 | password = getpass.getpass() 62 | server = args.server 63 | if login(email, password, server): 64 | print('Login succeeded') 65 | else: 66 | print('Login failed') 67 | return False 68 | return True 69 | 70 | 71 | def leaderboard_args(args): 72 | server = config.get('server') 73 | r = requests.get(server + '/rest/leaderboard') 74 | if r.status_code // 100 == 2: 75 | try: 76 | info = r.json() 77 | board = info.get('leaderboard') 78 | except: 79 | return 80 | for 
place, score in zip(itertools.count(info.get('start', 1)), board): 81 | print('#%i:' % place) 82 | print('- User:', score['name']) 83 | print('- Score:', score['score']) 84 | else: 85 | return False 86 | return True 87 | 88 | 89 | def logout_args(args): 90 | clear_config('cookies') 91 | print('Logged out') 92 | return True 93 | 94 | 95 | def needs_login(f): 96 | @wraps(f) 97 | def wrapped(*args, **kwargs): 98 | server = config.get('server') 99 | cookies = config.get('cookies') 100 | if not server or not cookies: 101 | print('You are not logged in') 102 | return None 103 | return f(server=server, cookies=cookies, *args, **kwargs) 104 | return wrapped 105 | 106 | 107 | @needs_login 108 | def docker_login_args(args, server, cookies): 109 | r = requests.get(server + '/rest/user', cookies=cookies) 110 | if r.status_code != 200 or 'cr' not in r.json(): 111 | print('Failed to obtain container registry') 112 | return False 113 | cr = r.json()['cr'] 114 | client = docker.from_env() 115 | client.login(cr['username'], cr['password'], registry=cr['url']) 116 | print('Logged in') 117 | return True 118 | 119 | 120 | @needs_login 121 | def docker_show_args(args, server, cookies): 122 | r = requests.get(server + '/rest/user', cookies=cookies) 123 | if r.status_code != 200 or 'cr' not in r.json(): 124 | print('Failed to obtain container registry') 125 | return False 126 | cr = r.json()['cr'] 127 | print('Registry URL:', cr['url']) 128 | print('Username:', cr['username']) 129 | if args.show_password: 130 | print('Password:', cr['password']) 131 | return True 132 | 133 | 134 | @needs_login 135 | def docker_list_args(args, server, cookies): 136 | r = requests.get(server + '/rest/user', cookies=cookies) 137 | if r.status_code != 200 or 'cr' not in r.json(): 138 | print('Failed to obtain container registry') 139 | return False 140 | cr = r.json()['cr'] 141 | auth = HTTPBasicAuth(cr['username'], cr['password']) 142 | r = requests.get('https://%s/v2/_catalog' % cr['url'], auth=auth) 143 | 
repos = r.json()['repositories'] 144 | everything = {} 145 | for repo in repos: 146 | r = requests.get('https://%s/v2/%s/tags/list' % (cr['url'], repo), auth=auth) 147 | if r.status_code != 200: 148 | continue 149 | try: 150 | info = r.json() 151 | everything[repo] = info.get('tags') 152 | except: 153 | pass 154 | for k, v in everything.items(): 155 | print(k + ':') 156 | for tag in v: 157 | print(' ' + tag) 158 | return True 159 | 160 | 161 | @needs_login 162 | def show_args(args, server, cookies): 163 | endpoint = server + '/rest/job/status' 164 | if args.all: 165 | endpoint += '/all' 166 | elif args.id: 167 | endpoint += '/%d' % args.id 168 | r = requests.get(endpoint, cookies=cookies) 169 | if r.status_code == 404: 170 | print('No job found') 171 | return False 172 | elif r.status_code == 200: 173 | jobs = r.json() 174 | if not args.all: 175 | jobs = [jobs] 176 | for job in jobs: 177 | if args.verbose: 178 | print('ID:', job['id']) 179 | print('Status:', job['status']) 180 | if 'score' in job: 181 | print('Score:', job['score']) 182 | print('Workers:') 183 | for worker in job['workers']: 184 | print('- Task:', worker['task']) 185 | print(' Status:', worker['state']) 186 | if 'eta' in worker: 187 | print(' ETA (seconds):', worker['eta']) 188 | if 'progress' in worker: 189 | print(' Progress (percent):', worker['progress'] * 100) 190 | if 'score' in worker: 191 | print(' Score:', worker['score']) 192 | if 'error' in worker: 193 | print(' Error:', worker['error']) 194 | else: 195 | print('%i: %s' % (job['id'], job['status'])) 196 | else: 197 | print('Error %i occurred' % r.status_code) 198 | return False 199 | return True 200 | 201 | @needs_login 202 | def kill_args(args, server, cookies): 203 | if not args.yes: 204 | yn = input('Are you sure? 
[y/N] ') 205 | if yn.lower() not in ('y', 'yes'): 206 | print('Not canceled') 207 | return True 208 | r = requests.post(server + '/rest/job/kill', cookies=cookies) 209 | if r.status_code == 404: 210 | print('No job found') 211 | return False 212 | elif r.status_code // 100 == 2: 213 | print('Canceled') 214 | else: 215 | print('Error %i occurred' % r.status_code) 216 | return False 217 | return True 218 | 219 | 220 | @needs_login 221 | def restart_args(args, server, cookies): 222 | if not args.yes: 223 | yn = input('Are you sure? [y/N] ') 224 | if yn.lower() not in ('y', 'yes'): 225 | print('Not restarted') 226 | return True 227 | if args.id: 228 | suffix = '/%d' % args.id 229 | else: 230 | suffix = '' 231 | r = requests.post(server + '/rest/job/restart' + suffix, cookies=cookies) 232 | if r.status_code == 404: 233 | print('No job found') 234 | return False 235 | elif r.status_code // 100 == 2: 236 | print('Restarted') 237 | else: 238 | print('Error %i occurred' % r.status_code) 239 | return False 240 | return True 241 | 242 | 243 | @needs_login 244 | def submit_args(args, server, cookies): 245 | r = requests.get(server + '/rest/user', cookies=cookies) 246 | if r.status_code != 200 or 'cr' not in r.json(): 247 | print('Failed to obtain container registry') 248 | cr = r.json()['cr'] 249 | client = docker.APIClient() 250 | cr['registry'] = cr['url'] 251 | del cr['url'] 252 | tag = args.tag or 'agent:latest' 253 | try: 254 | client.tag(tag, cr['registry'] + '/' + tag) 255 | except requests.exceptions.HTTPError: 256 | print('Could not find local tag') 257 | return False 258 | print('Pushing container...') 259 | size = {} 260 | for line in client.push(cr['registry'] + '/' + tag, stream=True, auth_config=cr): 261 | for line in line.split(b'\r\n'): 262 | if not line: 263 | continue 264 | line = json.loads(line) 265 | if 'status' not in line: 266 | continue 267 | if line['status'] == 'Pushing': 268 | size[line['id']] = int(line['progressDetail']['current']), 
int(line['progressDetail']['total']) 269 | elif line['status'] == 'Pushed': 270 | size[line['id']] = size[line['id']][1], size[line['id']][1] 271 | 272 | current, total = 0, 0 273 | for c, t in size.values(): 274 | current += c 275 | total += t 276 | if total > 0: 277 | print('\u001B[2K\r%i%%' % (100 * current / total), end='', flush=True) 278 | print('\u001B[2K\rPushed, submitting job') 279 | r = requests.post(server + '/rest/job/start', json={'tag': tag}, cookies=cookies) 280 | if r.status_code // 100 == 2: 281 | print('Done') 282 | else: 283 | print('Error %i occurred' % r.status_code) 284 | return False 285 | return True 286 | 287 | 288 | def init_parsers(subparsers): 289 | load_config() 290 | 291 | parser_login = subparsers.add_parser('login', description='Log into server') 292 | parser_login.set_defaults(func=login_args) 293 | parser_login.add_argument('--email', type=str, help='Your email address') 294 | parser_login.add_argument('--password', type=str, help='Your password (you should use --password-stdin instead)') 295 | parser_login.add_argument('--password-stdin', action='store_true', help='Read password from stdin') 296 | parser_login.add_argument('--server', type=str, help='Server to log into') 297 | 298 | parser_logout = subparsers.add_parser('logout', description='Log out of server') 299 | parser_logout.set_defaults(func=logout_args) 300 | 301 | parser_leaderboard = subparsers.add_parser('leaderboard', description='Get leaderboard') 302 | parser_leaderboard.set_defaults(func=leaderboard_args) 303 | 304 | parser_docker = subparsers.add_parser('docker', description='Docker support commands') 305 | parser_docker.set_defaults(func=lambda args: parser_docker.print_help()) 306 | subparsers_docker = parser_docker.add_subparsers() 307 | 308 | parser_docker_login = subparsers_docker.add_parser('login', description='Log into user Docker registry') 309 | parser_docker_login.set_defaults(func=docker_login_args) 310 | 311 | parser_docker_show = 
subparsers_docker.add_parser('show', description='Show information about user Docker registry') 312 | parser_docker_show.set_defaults(func=docker_show_args) 313 | parser_docker_show.add_argument('-p', '--show-password', action='store_true', help='Show login password') 314 | 315 | parser_docker_list = subparsers_docker.add_parser('list', description='List contents of Docker registry') 316 | parser_docker_list.set_defaults(func=docker_list_args) 317 | 318 | parser_job = subparsers.add_parser('job', description='Operations on jobs') 319 | parser_job.set_defaults(func=lambda args: parser_job.print_help()) 320 | subparsers_job = parser_job.add_subparsers() 321 | 322 | parser_job_show = subparsers_job.add_parser('show', description='Show current job, if it exists') 323 | parser_job_show.set_defaults(func=show_args) 324 | parser_job_show.add_argument('id', nargs='?', type=int, help='List a specific job ID') 325 | parser_job_show.add_argument('-v', '--verbose', action='store_true', help='Be more verbose') 326 | parser_job_show.add_argument('-a', '--all', action='store_true', help='Show all jobs') 327 | 328 | parser_job_kill = subparsers_job.add_parser('cancel', description='Cancel current job') 329 | parser_job_kill.set_defaults(func=kill_args) 330 | parser_job_kill.add_argument('-y', '--yes', action='store_true', help='Do not display confirmation') 331 | 332 | parser_job_restart = subparsers_job.add_parser('restart', description='Restart job') 333 | parser_job_restart.set_defaults(func=restart_args) 334 | parser_job_restart.add_argument('-y', '--yes', action='store_true', help='Do not display confirmation') 335 | parser_job_restart.add_argument('id', nargs='?', type=int, help='Job ID to restart (default: latest)') 336 | 337 | parser_job_submit = subparsers_job.add_parser('submit', description='Submit new job') 338 | parser_job_submit.set_defaults(func=submit_args) 339 | parser_job_submit.add_argument('-t', '--tag', type=str, help='Local tag to push') 340 | 341 | 342 | def 
main(argv=sys.argv[1:]): 343 | parser = argparse.ArgumentParser(description='Run OpenAI Retro Contest support code') 344 | parser.set_defaults(func=lambda args: parser.print_help()) 345 | init_parsers(parser.add_subparsers()) 346 | args = parser.parse_args(argv) 347 | if not args.func(args): 348 | sys.exit(1) 349 | 350 | 351 | if __name__ == '__main__': 352 | main() 353 | -------------------------------------------------------------------------------- /support/setup.cfg: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | ignore = E501 3 | -------------------------------------------------------------------------------- /support/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import sys 3 | import os 4 | import shutil 5 | 6 | VERSION_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'VERSION') 7 | 8 | if not os.path.exists(os.path.join(os.path.dirname(__file__), '.git')): 9 | use_scm_version = False 10 | shutil.copy('VERSION', 'gym_remote/VERSION.txt') 11 | else: 12 | def version_scheme(version): 13 | with open(VERSION_PATH) as v: 14 | version_file = v.read() 15 | if version.distance: 16 | version_file += '.dev%d' % version.distance 17 | return version_file 18 | 19 | def local_scheme(version): 20 | v = '' 21 | if version.distance: 22 | v = '+' + version.node 23 | return v 24 | use_scm_version = {'write_to': 'gym_remote/VERSION.txt', 25 | 'version_scheme': version_scheme, 26 | 'local_scheme': local_scheme} 27 | 28 | 29 | setup( 30 | name='retro-contest-support', 31 | version=open(VERSION_PATH, 'r').read(), 32 | license='MIT', 33 | install_requires=[ 34 | 'gym', 35 | ], 36 | extras_require={ 37 | 'retro': 'gym-retro>=0.6.0', 38 | 'docker': 'docker', 39 | 'rest': ['docker', 'pyyaml', 'requests'], 40 | }, 41 | entry_points={ 42 | 'console_scripts': [ 43 | 'retro-contest-remote=retro_contest.remote:main [retro]', 44 | 
'retro-contest-agent=retro_contest.agent:main', 45 | 'retro-contest=retro_contest.__main__:main' 46 | ] 47 | }, 48 | packages=['gym_remote', 'retro_contest'], 49 | setup_requires=['pytest-runner'], 50 | use_scm_version=use_scm_version, 51 | zip_safe=True 52 | ) 53 | -------------------------------------------------------------------------------- /support/tests/__init__.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import pytest 3 | import tempfile 4 | from gym_remote.client import RemoteEnv 5 | from gym_remote.server import RemoteEnvWrapper 6 | 7 | 8 | @pytest.fixture(scope='function') 9 | def tempdir(): 10 | with tempfile.TemporaryDirectory() as dir: 11 | yield dir 12 | 13 | 14 | @pytest.fixture(scope='function') 15 | def process_wrapper(): 16 | with tempfile.TemporaryDirectory() as dir: 17 | def serve(pipe): 18 | make_env = pipe.recv() 19 | env = RemoteEnvWrapper(make_env(), dir) 20 | pipe.send('ok') 21 | 22 | args = pipe.recv() 23 | kwargs = pipe.recv() 24 | env.serve(*args, **kwargs) 25 | 26 | parent_pipe, child_pipe = multiprocessing.Pipe() 27 | proc = multiprocessing.Process(target=serve, args=(child_pipe,)) 28 | proc.start() 29 | 30 | def call(env, *args, **kwargs): 31 | parent_pipe.send(env) 32 | assert parent_pipe.recv() == 'ok' 33 | parent_pipe.send(args) 34 | parent_pipe.send(kwargs) 35 | return RemoteEnv(dir) 36 | 37 | yield call 38 | proc.terminate() 39 | -------------------------------------------------------------------------------- /support/tests/test_bridge.py: -------------------------------------------------------------------------------- 1 | import gym_remote as gr 2 | import os 3 | 4 | from . 
def setup_client_server(base):
    """Build a (client, server) Bridge pair rendezvousing on ``base``."""
    server = gr.Bridge(base)
    server.listen()

    client = gr.Bridge(base)
    client.connect()
    return client, server


def start_bridge(client, server):
    """Complete the handshake on both ends of an established pair."""
    server.server_accept()
    client.configure_client()


def test_bridge_setup_connection(tempdir):
    """A bare listen/connect/handshake succeeds with no channels."""
    pair = setup_client_server(tempdir)
    start_bridge(*pair)


def test_bridge_int(tempdir):
    """An IntChannel added on the server round-trips values to the client."""
    client, server = setup_client_server(tempdir)
    server.add_channel('int', gr.IntChannel())

    start_bridge(client, server)

    # The channel set propagates to the client, with values unset.
    assert list(server._channels) == ['int']
    assert list(client._channels) == ['int']
    assert server._channels['int'].value is None
    assert client._channels['int'].value is None

    for expected in (1, 2):
        server._channels['int'].value = expected
        server.send()
        client.recv()

        assert server._channels['int'].value == expected
        assert client._channels['int'].value == expected


def test_bridge_float(tempdir):
    """A FloatChannel round-trips integral and fractional values."""
    client, server = setup_client_server(tempdir)
    server.add_channel('float', gr.FloatChannel())

    start_bridge(client, server)

    assert list(server._channels) == ['float']
    assert list(client._channels) == ['float']
    assert server._channels['float'].value is None
    assert client._channels['float'].value is None

    for expected in (1, 0.5):
        server._channels['float'].value = expected
        server.send()
        client.recv()

        assert server._channels['float'].value == expected
        assert client._channels['float'].value == expected
def test_bridge_bool(tempdir):
    """A BoolChannel round-trips True and False from server to client."""
    client, server = setup_client_server(tempdir)
    server.add_channel('bool', gr.BoolChannel())

    start_bridge(client, server)

    assert list(server._channels.keys()) == ['bool']
    assert list(client._channels.keys()) == ['bool']

    assert server._channels['bool'].value is None
    assert client._channels['bool'].value is None

    server._channels['bool'].value = True
    server.send()
    client.recv()

    assert server._channels['bool'].value is True
    assert client._channels['bool'].value is True

    server._channels['bool'].value = False
    server.send()
    client.recv()

    assert server._channels['bool'].value is False
    assert client._channels['bool'].value is False


def test_bridge_int_fold(tempdir):
    """An IntFoldChannel round-trips small integer vectors."""
    client, server = setup_client_server(tempdir)
    server.add_channel('int_fold', gr.IntFoldChannel((2, 3)))

    start_bridge(client, server)

    assert list(server._channels.keys()) == ['int_fold']
    assert list(client._channels.keys()) == ['int_fold']

    assert server._channels['int_fold'].value is None
    assert client._channels['int_fold'].value is None

    server._channels['int_fold'].value = [1, 2]
    server.send()
    client.recv()

    assert (server._channels['int_fold'].value == [1, 2]).all()
    assert (client._channels['int_fold'].value == [1, 2]).all()

    server._channels['int_fold'].value = [0, 1]
    server.send()
    client.recv()

    assert (server._channels['int_fold'].value == [0, 1]).all()
    assert (client._channels['int_fold'].value == [0, 1]).all()


def test_bridge_np(tempdir):
    """An NpChannel round-trips whole numpy arrays."""
    import numpy as np
    client, server = setup_client_server(tempdir)
    server.add_channel('np', gr.NpChannel((2, 2), int))

    start_bridge(client, server)

    assert list(server._channels.keys()) == ['np']
    assert list(client._channels.keys()) == ['np']

    server._channels['np'].value = np.zeros((2, 2), int)
    server.send()
    client.recv()

    assert (server._channels['np'].value == np.zeros((2, 2))).all()
    assert (client._channels['np'].value == np.zeros((2, 2))).all()

    server._channels['np'].value = np.ones((2, 2), int)
    server.send()
    client.recv()

    assert (server._channels['np'].value == np.ones((2, 2))).all()
    assert (client._channels['np'].value == np.ones((2, 2))).all()


def test_bridge_np_int16(tempdir):
    """NpChannel preserves explicit dtypes, including byte order.

    NOTE(review): this function was garbled in the source — only
    ``np.dtype('u2')`` survived and the ``bdtype`` assignment was lost,
    the '<'/'>' byte-order prefixes having apparently been eaten by HTML
    escaping.  Reconstructed as little-endian ('<u2') and big-endian
    ('>u2') uint16, matching the 'npl'/'npb' channel names and the
    dtype-equality assertions below — confirm against upstream.
    """
    import numpy as np
    client, server = setup_client_server(tempdir)
    ldtype = np.dtype('<u2')  # little-endian uint16
    bdtype = np.dtype('>u2')  # big-endian uint16
    server.add_channel('npl', gr.NpChannel((2, 2), ldtype))
    server.add_channel('npb', gr.NpChannel((2, 2), bdtype))

    start_bridge(client, server)

    assert set(server._channels.keys()) == {'npl', 'npb'}
    assert set(client._channels.keys()) == {'npl', 'npb'}

    # Byte order must survive the trip, not just the value width.
    assert server._channels['npl'].dtype == client._channels['npl'].dtype
    assert server._channels['npb'].dtype == client._channels['npb'].dtype

    server._channels['npl'].value = np.array((1, 0x100), ldtype)
    server.send()
    client.recv()

    assert (server._channels['npl'].value == [1, 0x100]).all()
    assert (client._channels['npl'].value == [1, 0x100]).all()

    server._channels['npl'].value = np.array((0x100, 1), ldtype)
    server.send()
    client.recv()

    assert (server._channels['npl'].value == [0x100, 1]).all()
    assert (client._channels['npl'].value == [0x100, 1]).all()

    server._channels['npb'].value = np.array((1, 0x100), bdtype)
    server.send()
    client.recv()

    assert (server._channels['npb'].value == [1, 0x100]).all()
    assert (client._channels['npb'].value == [1, 0x100]).all()

    server._channels['npb'].value = np.array((0x100, 1), bdtype)
    server.send()
    client.recv()

    assert (server._channels['npb'].value == [0x100, 1]).all()
    assert (client._channels['npb'].value == [0x100, 1]).all()


def test_bridge_multi(tempdir):
    """Multiple channels update independently across sends."""
    client, server = setup_client_server(tempdir)
    server.add_channel('int', gr.IntChannel())
    server.add_channel('bool', gr.BoolChannel())

    start_bridge(client, server)

    assert set(server._channels.keys()) == {'int', 'bool'}
    assert set(client._channels.keys()) == {'int', 'bool'}

    assert server._channels['int'].value is None
    assert client._channels['int'].value is None
    assert server._channels['bool'].value is None
    assert client._channels['bool'].value is None

    server._channels['bool'].value = True
    server.send()
    client.recv()

    # Untouched channels keep their value; only 'bool' changed.
    assert server._channels['int'].value is None
    assert client._channels['int'].value is None
    assert server._channels['bool'].value is True
    assert client._channels['bool'].value is True

    server._channels['int'].value = 1
    server.send()
    client.recv()

    # Fixed: the original compared ints with `is`, which tests object
    # identity and only works by accident of CPython's small-int cache.
    assert server._channels['int'].value == 1
    assert client._channels['int'].value == 1
    assert server._channels['bool'].value is True
    assert client._channels['bool'].value is True

    server._channels['bool'].value = False
    server._channels['int'].value = 2
    server.send()
    client.recv()

    assert server._channels['int'].value == 2
    assert client._channels['int'].value == 2
    assert server._channels['bool'].value is False
    assert client._channels['bool'].value is False


def test_bridge_clean(tempdir):
    """Closing both ends removes the socket and channel backing files."""
    client, server = setup_client_server(tempdir)
    server.add_channel('np', gr.NpChannel((2, 2), int))

    start_bridge(client, server)

    assert list(server._channels.keys()) == ['np']
    assert list(client._channels.keys()) == ['np']

    assert os.path.exists(os.path.join(tempdir, 'sock'))
    assert os.path.exists(os.path.join(tempdir, 'np'))

    client.close()
    server.close()

    assert not os.path.exists(os.path.join(tempdir, 'sock'))
    assert not os.path.exists(os.path.join(tempdir, 'np'))


def test_bridge_client_clean(tempdir):
    """A client-initiated close surfaces as Bridge.Closed on the server."""
    client, server = setup_client_server(tempdir)
    server.add_channel('np', gr.NpChannel((2, 2), int))

    start_bridge(client, server)

    assert list(server._channels.keys()) == ['np']
    assert list(client._channels.keys()) == ['np']

    assert os.path.exists(os.path.join(tempdir, 'sock'))
    assert os.path.exists(os.path.join(tempdir, 'np'))

    client.close('disconnect')
    try:
        server.recv()
    except gr.Bridge.Closed as e:
        assert str(e) == 'disconnect'
    else:
        # Fixed: the original passed silently when no exception was raised.
        assert False, 'No exception'

    assert not os.path.exists(os.path.join(tempdir, 'sock'))
    assert not os.path.exists(os.path.join(tempdir, 'np'))


def test_bridge_client_buffered_clean(tempdir):
    """A buffered server send still observes the client's disconnect."""
    client, server = setup_client_server(tempdir)
    server.add_channel('np', gr.NpChannel((2, 2), int))

    start_bridge(client, server)

    assert list(server._channels.keys()) == ['np']
    assert list(client._channels.keys()) == ['np']

    assert os.path.exists(os.path.join(tempdir, 'sock'))
    assert os.path.exists(os.path.join(tempdir, 'np'))

    client.close('disconnect')
    try:
        server.send()
        server.recv()
    except gr.Bridge.Closed:
        pass
    else:
        # Fixed: the original passed silently when no exception was raised.
        assert False, 'No exception'

    assert not os.path.exists(os.path.join(tempdir, 'sock'))
    assert not os.path.exists(os.path.join(tempdir, 'np'))


def test_bridge_server_clean(tempdir):
    """A server-initiated close surfaces as Bridge.Closed on the client."""
    client, server = setup_client_server(tempdir)
    server.add_channel('np', gr.NpChannel((2, 2), int))

    start_bridge(client, server)

    assert list(server._channels.keys()) == ['np']
    assert list(client._channels.keys()) == ['np']

    assert os.path.exists(os.path.join(tempdir, 'sock'))
    assert os.path.exists(os.path.join(tempdir, 'np'))

    server.close('disconnect')
    try:
        client.recv()
    except gr.Bridge.Closed as e:
        assert str(e) == 'disconnect'
    else:
        # Fixed: the original passed silently when no exception was raised.
        assert False, 'No exception'

    assert not os.path.exists(os.path.join(tempdir, 'sock'))
    assert not os.path.exists(os.path.join(tempdir, 'np'))


def test_bridge_server_noclient_clean(tempdir):
    """A server with no client ever attached still cleans up its files."""
    server = gr.Bridge(tempdir)
    server.listen()
    server.add_channel('np', gr.NpChannel((2, 2), int))

    assert list(server._channels.keys()) == ['np']

    assert os.path.exists(os.path.join(tempdir, 'sock'))
    assert os.path.exists(os.path.join(tempdir, 'np'))

    server.close('disconnect')

    assert not os.path.exists(os.path.join(tempdir, 'sock'))
    assert not os.path.exists(os.path.join(tempdir, 'np'))


def test_bridge_server_buffered_clean(tempdir):
    """A buffered client send still observes the server's disconnect."""
    client, server = setup_client_server(tempdir)
    server.add_channel('np', gr.NpChannel((2, 2), int))

    start_bridge(client, server)

    assert list(server._channels.keys()) == ['np']
    assert list(client._channels.keys()) == ['np']

    assert os.path.exists(os.path.join(tempdir, 'sock'))
    assert os.path.exists(os.path.join(tempdir, 'np'))

    server.close('disconnect')
    try:
        client.send()
        client.recv()
    except gr.Bridge.Closed:
        pass
    else:
        # Fixed: the original passed silently when no exception was raised.
        assert False, 'No exception'

    assert not os.path.exists(os.path.join(tempdir, 'sock'))
    assert not os.path.exists(os.path.join(tempdir, 'np'))


def test_bridge_exception_server(tempdir):
    """An exception injected by the server is re-raised on the client."""
    client, server = setup_client_server(tempdir)

    start_bridge(client, server)

    server.send()
    client.recv()

    server.exception(gr.exceptions.GymRemoteError)
    try:
        client.recv()
        assert False, 'No exception'
    except gr.exceptions.GymRemoteError:
        pass


def test_bridge_exception_client(tempdir):
    """An exception injected by the client is re-raised on the server."""
    client, server = setup_client_server(tempdir)

    start_bridge(client, server)

    server.send()
    client.recv()

    client.exception(gr.exceptions.GymRemoteError)
    try:
        client.send()
        server.recv()
        assert False, 'No exception'
    except gr.exceptions.GymRemoteError:
        pass
class BitEnv(gym.Env):
    """Env whose action bits directly encode (observation, reward, done)."""

    def __init__(self):
        self.action_space = gym.spaces.Discrete(8)
        self.observation_space = gym.spaces.Discrete(2)

    def step(self, action):
        assert self.action_space.contains(action)
        observation = action & 1    # bit 0 -> observation
        reward = float(action & 2)  # bit 1 -> reward (0.0 or 2.0)
        done = bool(action & 4)     # bit 2 -> done
        return observation, reward, done, {}

    def reset(self):
        return 0


class MultiBitEnv(gym.Env):
    """BitEnv variant taking a MultiBinary(3) action vector."""

    def __init__(self):
        self.action_space = gym.spaces.MultiBinary(3)
        self.observation_space = gym.spaces.Discrete(2)

    def step(self, action):
        assert self.action_space.contains(action)
        observation = action[0]
        reward = float(action[1])
        done = bool(action[2])
        return observation, reward, done, {}

    def reset(self):
        return 0


class StepEnv(gym.Env):
    """Env that counts steps as reward; a truthy action ends the episode."""

    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(1)
        self.reward = 0      # cumulative step count for the episode
        self.done = False    # latched once a truthy action is taken

    def step(self, action):
        # Once done, the reward freezes; further steps are no-ops.
        if not self.done:
            self.reward += 1
            if action:
                self.done = True
        return 0, self.reward, self.done, {}

    def reset(self):
        self.reward = 0
        self.done = False
        return 0


def test_split(process_wrapper):
    """Discrete actions are decoded bit-by-bit through the remote env."""
    env = process_wrapper(BitEnv)

    assert env.step(0) == (0, 0, False, {})
    assert env.step(1) == (1, 0, False, {})
    assert env.step(2) == (0, 2, False, {})
    assert env.step(3) == (1, 2, False, {})
    assert env.step(4) == (0, 0, True, {})


def test_multibinary_split(process_wrapper):
    """MultiBinary actions survive the trip through the remote env."""
    env = process_wrapper(MultiBitEnv)

    assert env.step(np.array([0, 0, 0], np.int8)) == (0, 0, False, {})
    assert env.step(np.array([1, 0, 0], np.int8)) == (1, 0, False, {})
    assert env.step(np.array([0, 1, 0], np.int8)) == (0, 1, False, {})
    assert env.step(np.array([1, 1, 0], np.int8)) == (1, 1, False, {})
    assert env.step(np.array([0, 0, 1], np.int8)) == (0, 0, True, {})


def test_reset(process_wrapper):
    """reset() restores a finished episode to its initial state."""
    env = process_wrapper(StepEnv)

    for _ in range(2):  # two identical episodes separated by reset()
        assert env.reset() == 0
        assert env.step(0) == (0, 1, False, {})
        assert env.step(0) == (0, 2, False, {})
        assert env.step(1) == (0, 3, True, {})
        assert env.step(0) == (0, 3, True, {})


def test_reset_exception(process_wrapper):
    """With ignore_reset, stepping past a second episode raises ResetError."""
    env = process_wrapper(StepEnv, ignore_reset=True)

    assert env.reset() == 0
    assert env.step(0) == (0, 1, False, {})
    assert env.step(0) == (0, 2, False, {})
    assert env.step(1) == (0, 3, True, {})
    assert env.reset() == 0
    assert env.step(0) == (0, 1, False, {})
    assert env.step(0) == (0, 2, False, {})
    assert env.step(1) == (0, 3, True, {})
    try:
        assert env.step(0) == (0, 3, True, {})
    except gre.ResetError:
        return
    except Exception:
        # Fixed: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit during the test run.
        assert False, 'Incorrect exception'
    assert False, 'No exception'


def test_ts_limit(process_wrapper):
    """Exceeding timestep_limit raises TimestepTimeoutError."""
    env = process_wrapper(StepEnv, timestep_limit=5)

    assert env.step(0) == (0, 1, False, {})
    assert env.step(0) == (0, 2, False, {})
    assert env.step(0) == (0, 3, False, {})
    assert env.step(0) == (0, 4, False, {})
    assert env.step(0) == (0, 5, False, {})
    try:
        env.step(0)
    except gre.TimestepTimeoutError:
        return
    except Exception:
        # Fixed: narrowed from a bare `except:` (see test_reset_exception).
        assert False, 'Incorrect exception'
    assert False, 'Remote did not shut down'


def test_wc_limit(process_wrapper):
    """Exceeding wallclock_limit raises WallClockTimeoutError."""
    env = process_wrapper(StepEnv, wallclock_limit=0.1)

    assert env.step(0) == (0, 1, False, {})
    time.sleep(0.2)
    try:
        env.step(0)
    except gre.WallClockTimeoutError:
        return
    except Exception:
        # Fixed: narrowed from a bare `except:` (see test_reset_exception).
        assert False, 'Incorrect exception'
    assert False, 'Remote did not shut down'


def test_cleanup(process_wrapper):
    """Closing the client removes the bridge socket on disk."""
    env = process_wrapper(BitEnv)

    assert os.path.exists(os.path.join(env.bridge.base, 'sock'))

    env.close()
    time.sleep(0.1)  # give the server process a moment to tear down

    assert not os.path.exists(os.path.join(env.bridge.base, 'sock'))