├── .gitignore
├── LICENSE
├── README.md
├── env
│   ├── Dockerfile
│   ├── conda_env.yaml
│   └── gym-0.19.0-py3-none-any.whl
├── imgs
│   ├── mmaze.jpg
│   ├── r2i.png
│   └── teaser.jpg
└── recall2imagine
    ├── Dockerfile
    ├── __init__.py
    ├── agent.py
    ├── behaviors.py
    ├── configs.yaml
    ├── configs_backup.yaml
    ├── embodied
    │   ├── __init__.py
    │   ├── core
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── basics.py
    │   │   ├── batch.py
    │   │   ├── batcher.py
    │   │   ├── checkpoint.py
    │   │   ├── config.py
    │   │   ├── counter.py
    │   │   ├── distr.py
    │   │   ├── driver.py
    │   │   ├── flags.py
    │   │   ├── logger.py
    │   │   ├── metrics.py
    │   │   ├── parallel.py
    │   │   ├── path.py
    │   │   ├── random.py
    │   │   ├── space.py
    │   │   ├── timer.py
    │   │   ├── uuid.py
    │   │   ├── when.py
    │   │   ├── worker.py
    │   │   └── wrappers.py
    │   ├── envs
    │   │   ├── atari.py
    │   │   ├── crafter.py
    │   │   ├── dmc.py
    │   │   ├── dmlab.py
    │   │   ├── dummy.py
    │   │   ├── from_dm.py
    │   │   ├── from_gym.py
    │   │   ├── loconav.py
    │   │   ├── loconav_quadruped.py
    │   │   ├── loconav_quadruped.xml
    │   │   ├── minecraft.py
    │   │   ├── minecraft_base.py
    │   │   ├── minecraft_minerl.py
    │   │   ├── pinpad.py
    │   │   └── robodesk.py
    │   ├── replay
    │   │   ├── __init__.py
    │   │   ├── chunk.py
    │   │   ├── generic.py
    │   │   ├── generic_lfs.py
    │   │   ├── lfs_manager.py
    │   │   ├── limiters.py
    │   │   ├── naive_chunks.py
    │   │   ├── replays.py
    │   │   ├── reverb.py
    │   │   ├── saver.py
    │   │   └── selectors.py
    │   ├── run
    │   │   ├── __init__.py
    │   │   ├── eval_only.py
    │   │   ├── parallel.py
    │   │   ├── train.py
    │   │   ├── train_eval.py
    │   │   ├── train_holdout.py
    │   │   └── train_save.py
    │   └── scripts
    │       ├── install-atari.sh
    │       ├── install-dmlab.sh
    │       ├── install-minecraft.sh
    │       ├── plot.py
    │       └── xvfb_run.sh
    ├── expl.py
    ├── jaxagent.py
    ├── jaxutils.py
    ├── nets.py
    ├── ninjax.py
    ├── ssm
    │   ├── __init__.py
    │   ├── common.py
    │   ├── mimo.py
    │   └── siso.py
    ├── ssm_nets.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .pytest_cache
2 | dist
3 | scripts/*
4 | __pycache__/
5 | *.py[codi]
6 | *.egg-info
7 | MUJOCO_LOG.TXT
8 | ;
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023 Danijar Hafner, Artem Zholus, Mohammad Reza Samsami
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mastering Memory Tasks with World Models
2 |
3 | ---
4 |
5 | > M. R. Samsami\*, A. Zholus\*, J. Rajendran, S. Chandar. Mastering Memory Tasks with World Models. ICLR 2024, top‑1.2% (oral)
6 | > __[[Project Website]](https://recall2imagine.github.io)__ __[[Paper]](https://arxiv.org/abs/2403.04253)__ __[[Openreview]](https://openreview.net/forum?id=1vDArHJ68h)__ __[[X Thread]](https://x.com/apsarathchandar/status/1772345328617824502)__
7 |
8 | ---
9 |
10 | This repo contains the official implementation of the Recall to Imagine (R2I) algorithm, introduced in the paper
11 | **Mastering Memory Tasks with World Models**. R2I is a generalist, computationally efficient, model-based agent focusing on memory-intensive reinforcement learning (RL) tasks. It stands out by demonstrating superhuman performance in a complex memory domain, Memory Maze. The codebase, powered by JAX, offers an efficient framework for RL development, incorporating state space models.
12 |
13 |
14 |
15 |
16 |
17 |
18 | ## Table of Contents
19 | - [Prerequisites](#prerequisites)
20 | - [Installation](#installation)
21 |   - [Conda](#conda)
22 |   - [Docker & Singularity](#docker--singularity)
23 | - [Running Experiments with Containers](#running-experiments-with-containers)
24 | - [Reproducing Results from the Paper](#reproducing-results-from-the-paper)
25 |   - [Memory Maze](#memory-maze)
26 |   - [POPGym](#popgym)
27 |   - [BSuite](#bsuite)
28 | - [BibTeX](#bibtex)
29 | - [Acknowledgements](#acknowledgements)
30 |
31 | ## Prerequisites
32 |
33 | Before diving into the installation and running experiments, ensure you meet the following prerequisites:
34 |
35 | | Requirement | Specification |
36 | |--------------------|----------------------------------------------------------------|
37 | | **Hardware** | GPU with CUDA support |
38 | | **Software** | Docker or Singularity for containerized environments, conda |
39 | | **Python Version** | 3.8 |
40 | | **Storage** | At least 130GB for the Memory Maze experiment's replay buffer |
41 |
42 |
43 | ## Installation
44 |
45 | ### Conda
46 |
47 | We provide a self-contained conda environment for running experiments. To set up the environment, run:
48 |
49 | ```sh
50 | conda create -n recall2imagine python=3.8
51 | conda activate recall2imagine
52 | cd env && conda env update -f conda_env.yaml
53 | pip install \
54 | "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
55 | dm-haiku \
56 | flax \
57 | optax==0.1.5
58 | ```
59 |
60 | Although the conda environment should work, the recommended way of setting up the environment is through
61 | docker or singularity containers. See below for details.
62 |
63 | ### Docker & Singularity
64 |
65 | The most reliable way to reproduce the results from the paper is through containers.
66 | To build a docker image do
67 |
68 | ```sh
69 | cd env && docker build -t r2i .
70 | ```
71 |
72 | This step takes several minutes to complete. Alternatively, you can just pull our pre-built image:
73 |
74 | ```sh
75 | docker pull artemzholus/r2i && docker tag artemzholus/r2i r2i
76 | ```
77 |
78 | If you are going to run the code inside of a singularity container, convert the docker image to singularity.
79 | Note that the code below creates a lot of temporary files in the home directory. To avoid that, set the cache directory first:
80 |
81 | ```sh
82 | export SINGULARITY_CACHEDIR=
83 | ```
84 |
85 | If you're not sure whether you need this, just skip this step.
86 |
87 | To convert the image, do
88 |
89 | ```sh
90 | singularity pull docker://artemzholus/r2i
91 | ```
92 |
93 | ## Running experiments with containers
94 |
95 | To run experiments with docker, launch an interactive docker container:
96 |
97 | ```sh
98 | docker run --rm -it --network host -v $(pwd):/code -w /code r2i bash
99 | ```
100 |
101 | This opens an interactive shell inside the container. Run any python training script
102 | within that shell.
103 |
104 | For singularity:
105 |
106 | ```sh
107 | singularity shell --nv
108 | ```
109 |
110 | where the image path is the name of the file that was downloaded by `singularity pull`.
111 |
112 |
113 | ## Reproducing results from the paper
114 |
115 | The code was designed with reproducibility in mind.
116 |
117 | ### Memory Maze
118 |
119 | **TL;DR:** use the following command:
120 |
121 | ```sh
122 | python recall2imagine/train.py \
123 | --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
124 | --wdb_name memory_maze_9x9 \
125 | --logdir ./logs_memory_maze_9x9
126 | ```
127 |
128 | By default, the script uses 1 GPU.
129 | To ensure that you match the training speed reported in our paper, consider using two GPUs
130 | (this does not affect the sample efficiency):
131 |
132 | ```sh
133 | python recall2imagine/train.py \
134 | --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
135 | --wdb_name memory_maze_9x9 \
136 | --jax.train_devices 0 1 \
137 | --jax.policy_devices 0 \
138 | --logdir ./logs_memory_maze_9x9
139 | ```
140 |
141 | **NOTE**: By default, the training script uses a disk-based replay buffer. In Memory Maze, our
142 | default replay buffer size is 10M image RL steps, which takes about 130GB of space.
143 | Therefore, make sure you have enough space before running, since the script allocates
144 | the whole file at the beginning of training.
145 |
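As a rough sanity check on the 130GB figure, the sketch below estimates the raw image storage and checks the free disk space. It assumes 64x64 RGB `uint8` frames, which is not stated in this README; `image_bytes` and `has_enough_space` are hypothetical helpers and not part of the repository. Under that assumption the images alone take about 123GB, so the remainder would be other per-step fields and overhead.

```python
import shutil

# ASSUMPTION: 64x64 RGB uint8 frames. The README only states the step count
# (10M) and the total (~130GB), not the per-step layout.
STEPS = 10_000_000
FRAME_BYTES = 64 * 64 * 3


def image_bytes(steps, frame_bytes=FRAME_BYTES):
  """Raw bytes needed to store one image per step."""
  return steps * frame_bytes


def has_enough_space(path, required_bytes):
  """True if the filesystem holding `path` has at least `required_bytes` free."""
  return shutil.disk_usage(path).free >= required_bytes


total = image_bytes(STEPS)
print(f'{total / 1e9:.1f} GB of raw images')
print('enough space here:', has_enough_space('.', 130 * 10**9))
```

Running the check against the directory you plan to pass as `--logdir` gives an early warning before the script tries to allocate the buffer file.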
146 |
147 |
148 | This experiment may take a long time, and if you are a user of a large compute cluster,
149 | your jobs may have a limited maximal duration. For exactly that scenario, we designed a new
150 | replay buffer for efficient restarting of very long RL experiments. The idea is to keep two
151 | versions of the replay buffer: one in fast but temporary storage, the other in slower but
152 | lifelong storage. Use the following pattern to continue your experiment after an interruption:
153 |
154 | ```sh
155 | python recall2imagine/train.py \
156 | --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
157 | --wdb_name memory_maze_9x9 \
158 | --jax.train_devices 0 1 \
159 | --jax.policy_devices 0 \
160 | --use_lfs True --lfs_dir \
161 | --logdir
162 | ```
163 |
164 | Use this command for both the first run and every subsequent one; it knows how to handle the saved state.
165 | This buffer is very reliable and fault tolerant, and it is how we conducted
166 | our original experiments in the paper.
167 |
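The two-tier idea described above can be illustrated with a minimal sketch. This is not the repository's LFS replay implementation: `TwoTierStore`, its methods, and the directory names are hypothetical and only show the pattern of working in fast storage, mirroring to persistent storage, and restoring after a restart.

```python
import shutil
import tempfile
from pathlib import Path


class TwoTierStore:
  """Hypothetical helper, not the repository's LFS manager."""

  def __init__(self, fast_dir, persistent_dir):
    self.fast = Path(fast_dir)
    self.persistent = Path(persistent_dir)
    self.fast.mkdir(parents=True, exist_ok=True)
    self.persistent.mkdir(parents=True, exist_ok=True)

  def restore(self):
    """Copy completed chunks from persistent to fast storage; return count."""
    count = 0
    for src in self.persistent.iterdir():
      if src.is_file() and src.suffix != '.tmp':
        shutil.copy2(src, self.fast / src.name)
        count += 1
    return count

  def write(self, name, payload):
    """Write a chunk to fast storage only."""
    (self.fast / name).write_bytes(payload)

  def sync(self):
    """Mirror fast storage to persistent storage, renaming each file into place."""
    for src in self.fast.iterdir():
      if not src.is_file():
        continue
      # Copy under a temporary name first so an interruption never leaves a
      # half-written chunk under its final name.
      tmp = self.persistent / (src.name + '.tmp')
      shutil.copy2(src, tmp)
      tmp.replace(self.persistent / src.name)


with tempfile.TemporaryDirectory() as root:
  fast, slow = Path(root) / 'scratch', Path(root) / 'persistent'
  store = TwoTierStore(fast, slow)
  store.write('chunk0.bin', b'steps')
  store.sync()
  shutil.rmtree(fast)  # Simulate losing the temporary storage on preemption.
  resumed = TwoTierStore(fast, slow)
  restored = resumed.restore()
  recovered = (fast / 'chunk0.bin').read_bytes()
```

Because the same constructor handles both an empty and a populated persistent directory, the first run and every restart can go through one code path, which mirrors the "same command for every run" usage described above.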
168 |
169 | ### POPGym
170 |
171 | Use the following command:
172 |
173 | ```sh
174 | python recall2imagine/train.py \
175 | --configs popgym --task gym_popgym:popgym-RepeatPreviousEasy-v0 \
176 | --wdb_name popgym_repeat_previous_easy \
177 | --logdir ./logs_popgym_repeat_previous_easy
178 | ```
179 |
180 | ### BSuite
181 |
182 | (Coming soon!)
183 |
184 | ## BibTeX
185 |
186 | If you find our work useful, please cite our paper:
187 |
188 | ```
189 | @inproceedings{
190 | samsami2024mastering,
191 | title={Mastering Memory Tasks with World Models},
192 | author={Mohammad Reza Samsami and Artem Zholus and Janarthanan Rajendran and Sarath Chandar},
193 | booktitle={The Twelfth International Conference on Learning Representations},
194 | year={2024},
195 | url={https://openreview.net/forum?id=1vDArHJ68h}
196 | }
197 | ```
198 |
199 | ## Acknowledgements
200 |
201 | We thank Danijar Hafner for providing the DreamerV3 implementation, which this repo is based upon.
202 |
203 | [arxiv]: https://github.com/chandar-lab/recall2imagine
204 | [website]: https://recall2imagine.github.io/
205 | [twitter]: https://github.com/chandar-lab/recall2imagine
206 |
--------------------------------------------------------------------------------
/env/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
2 |
3 | ARG DEBIAN_FRONTEND=noninteractive
4 | RUN apt-get update && apt-get install -y \
5 | git xvfb curl \
6 | libglu1-mesa libglu1-mesa-dev libgl1-mesa-dev libosmesa6-dev mesa-utils freeglut3 freeglut3-dev \
7 | libglew2.1 libglfw3 libglfw3-dev libegl-dev zlib1g zlib1g-dev libsdl2-dev libjpeg-dev lua5.1 liblua5.1-0-dev libffi-dev \
8 | build-essential cmake g++ build-essential pkg-config software-properties-common gettext \
9 | ffmpeg patchelf swig unrar unzip zip curl wget tmux \
10 | && rm -rf /var/lib/apt/lists/*
11 |
12 |
13 | ENV CNDA=/conda
14 | RUN mkdir -p $CNDA && chmod 755 $CNDA
15 | RUN curl -Lo $CNDA/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-4.5.11-Linux-x86_64.sh \
16 | && chmod +x $CNDA/miniconda.sh \
17 | && $CNDA/miniconda.sh -b -p $CNDA/miniconda \
18 | && rm $CNDA/miniconda.sh
19 | ENV PATH=$CNDA/miniconda/bin:$PATH
20 | ENV CONDA_AUTO_UPDATE_CONDA=false
21 |
22 | # Create a Python 3.8 environment
23 | RUN $CNDA/miniconda/bin/conda create -y --name py38 python=3.8 \
24 | && $CNDA/miniconda/bin/conda clean -ya
25 | ENV CONDA_DEFAULT_ENV=py38
26 | ENV CONDA_PREFIX=$CNDA/miniconda/envs/$CONDA_DEFAULT_ENV
27 | ENV PATH=$CONDA_PREFIX/bin:$PATH
28 | RUN $CNDA/miniconda/bin/conda clean -ya
29 | # Create conda environment
30 | # Atari
31 | RUN pip3 install --upgrade setuptools pip
32 | ADD ./gym-0.19.0-py3-none-any.whl /root
33 | RUN pip3 install \
34 | tensorflow_probability==0.20.1 \
35 | yacs==0.1.8 \
36 | tensorflow-cpu==2.12.0 \
37 | ruamel.yaml==0.17.32 \
38 | ruamel.yaml.clib==0.2.7 \
39 | moviepy==1.0.3 \
40 | imageio==2.31.1 \
41 | crafter==1.8.1 \
42 | dm-control==1.0.12 \
43 | mujoco==2.3.6 \
44 | robodesk==1.0.0 \
45 | bsuite==0.3.5 \
46 | numpy==1.22.1 \
47 | opt_einsum==3.3.0 \
48 | einops==0.6.1 \
49 | wandb==0.15.5 \
50 | memory-maze==1.0.2 \
51 | popgym==1.0.2 \
52 | gymnasium==0.29.0 \
53 | mazelib==0.9.13 \
54 | procgen==0.10.7 \
55 | atari-py==0.2.9 \
56 | protobuf==3.20.3 \
57 | gast==0.4.0 \
58 | zmq \
59 | /root/gym-0.19.0-py3-none-any.whl \
60 | rich==13.7.0 \
61 | msgpack==1.0.7 \
62 | cloudpickle==1.6.0 \
63 | opencv-python==4.8.0.74
64 |
65 | RUN pip3 install --upgrade \
66 | flax \
67 | dm-haiku \
68 | "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
69 | optax==0.1.5
70 |
71 | RUN wget -L -nv http://www.atarimania.com/roms/Roms.rar && \
72 | unrar x -y Roms.rar && \
73 | python3 -m atari_py.import_roms ROMS && \
74 | rm -rf Roms.rar ROMS.zip ROMS
75 | RUN mkdir -p /root/.mujoco && \
76 | cd /root/.mujoco && \
77 | wget -nv https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz && \
78 | tar -xf mujoco.tar.gz && \
79 | rm mujoco.tar.gz
80 |
81 | ENV OMP_NUM_THREADS 1
82 | ENV PYTHONUNBUFFERED 1
83 | ENV LANG "C.UTF-8"
84 | ENV NUMBA_CACHE_DIR=/tmp
85 | ENV XLA_PYTHON_CLIENT_MEM_FRACTION 0.8
--------------------------------------------------------------------------------
/env/conda_env.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 |   - conda-forge
3 |   - defaults
4 | dependencies:
5 |   - python=3.8
6 |   - pip:
7 |     - tensorflow_probability==0.20.1
8 |     - optax==0.1.5
9 |     - yacs==0.1.8
10 |     - tensorflow-cpu==2.12.0
11 |     - ruamel.yaml==0.17.32
12 |     - ruamel.yaml.clib==0.2.7
13 |     - moviepy==1.0.3
14 |     - imageio==2.31.1
15 |     - crafter==1.8.1
16 |     - dm-control==1.0.12
17 |     - mujoco==2.3.6
18 |     - robodesk==1.0.0
19 |     - bsuite==0.3.5
20 |     - numpy==1.22.1
21 |     - opt_einsum==3.3.0
22 |     - einops==0.6.1
23 |     - wandb==0.15.5
24 |     - memory-maze==1.0.2
25 |     - popgym==1.0.2
26 |     - gymnasium==0.29.0
27 |     - mazelib==0.9.13
28 |     - procgen==0.10.7
29 |     - flax==0.7.0
30 |     - dm-haiku==0.0.10
31 |     - atari-py==0.2.9
32 |     - protobuf==3.20.3
33 |     - gast==0.4.0
34 |     - zmq
35 |     - rich==13.7.0
36 |     - msgpack==1.0.7
37 |     - cloudpickle==1.6.0
38 |     - opencv-python==4.8.0.74
39 |
--------------------------------------------------------------------------------
/env/gym-0.19.0-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chandar-lab/Recall2Imagine/6e317e751ffd8f381e09476f3633567dff9f9234/env/gym-0.19.0-py3-none-any.whl
--------------------------------------------------------------------------------
/imgs/mmaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chandar-lab/Recall2Imagine/6e317e751ffd8f381e09476f3633567dff9f9234/imgs/mmaze.jpg
--------------------------------------------------------------------------------
/imgs/r2i.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chandar-lab/Recall2Imagine/6e317e751ffd8f381e09476f3633567dff9f9234/imgs/r2i.png
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chandar-lab/Recall2Imagine/6e317e751ffd8f381e09476f3633567dff9f9234/imgs/teaser.jpg
--------------------------------------------------------------------------------
/recall2imagine/Dockerfile:
--------------------------------------------------------------------------------
1 | # 1. Test setup:
2 | # docker run -it --rm --gpus all nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04 nvidia-smi
3 | #
4 | # If the above does not work, try adding the --privileged flag
5 | # and changing the command to `sh -c 'ldconfig -v && nvidia-smi'`.
6 | #
7 | # 2. Start training:
8 | # docker build -f dreamerv3/Dockerfile -t img . && \
9 | # docker run -it --rm --gpus all -v ~/logdir:/logdir img \
10 | # sh scripts/xvfb_run.sh python3 dreamerv3/train.py \
11 | # --logdir "/logdir/$(date +%Y%m%d-%H%M%S)" \
12 | # --configs dmc_vision --task dmc_walker_walk
13 | #
14 | # 3. See results:
15 | # tensorboard --logdir ~/logdir
16 |
17 | # System
18 | FROM nvidia/cuda:11.4.2-cudnn8-devel-ubuntu20.04
19 | ARG DEBIAN_FRONTEND=noninteractive
20 | ENV TZ=America/San_Francisco
21 | ENV PYTHONUNBUFFERED 1
22 | ENV PIP_DISABLE_PIP_VERSION_CHECK 1
23 | ENV PIP_NO_CACHE_DIR 1
24 | RUN apt-get update && apt-get install -y \
25 | ffmpeg git python3-pip vim libglew-dev \
26 | x11-xserver-utils xvfb \
27 | && apt-get clean
28 | RUN pip3 install --upgrade pip
29 |
30 | # Envs
31 | ENV MUJOCO_GL egl
32 | ENV DMLAB_DATASET_PATH /dmlab_data
33 | COPY scripts scripts
34 | RUN sh scripts/install-dmlab.sh
35 | RUN sh scripts/install-atari.sh
36 | RUN sh scripts/install-minecraft.sh
37 | ENV NUMBA_CACHE_DIR=/tmp
38 | RUN pip3 install crafter
39 | RUN pip3 install dm_control
40 | RUN pip3 install robodesk
41 | RUN pip3 install bsuite
42 |
43 | # Agent
44 | RUN pip3 install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
45 | RUN pip3 install jaxlib
46 | RUN pip3 install tensorflow_probability
47 | RUN pip3 install optax
48 | RUN pip3 install tensorflow-cpu
49 | ENV XLA_PYTHON_CLIENT_MEM_FRACTION 0.8
50 |
51 | # Google Cloud DNS cache (optional)
52 | ENV GCS_RESOLVE_REFRESH_SECS=60
53 | ENV GCS_REQUEST_CONNECTION_TIMEOUT_SECS=300
54 | ENV GCS_METADATA_REQUEST_TIMEOUT_SECS=300
55 | ENV GCS_READ_REQUEST_TIMEOUT_SECS=300
56 | ENV GCS_WRITE_REQUEST_TIMEOUT_SECS=600
57 |
58 | # Embodied
59 | RUN pip3 install numpy cloudpickle ruamel.yaml rich zmq msgpack
60 | COPY . /embodied
61 | RUN chown -R 1000:root /embodied && chmod -R 775 /embodied
62 |
63 | WORKDIR embodied
64 |
--------------------------------------------------------------------------------
/recall2imagine/__init__.py:
--------------------------------------------------------------------------------
1 | import sys, pathlib
2 | sys.path.append(str(pathlib.Path(__file__).parent))
3 |
4 | from .agent import Agent
5 | configs = Agent.configs
6 | from .train import wrap_env
7 |
--------------------------------------------------------------------------------
/recall2imagine/behaviors.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from tensorflow_probability.substrates import jax as tfp
3 | tfd = tfp.distributions
4 |
5 | from . import agent
6 | from . import expl
7 | from . import ninjax as nj
8 | from . import jaxutils
9 |
10 |
11 | class Greedy(nj.Module):
12 |
13 |   def __init__(self, wm, act_space, config):
14 |     rewfn = lambda s: wm.heads['reward'](s).mean()[1:]
15 |     if config.critic_type == 'vfunction':
16 |       critics = {'extr': agent.VFunction(rewfn, config, name='critic')}
17 |     else:
18 |       raise NotImplementedError(config.critic_type)
19 |     self.ac = agent.ImagActorCritic(
20 |         critics, {'extr': 1.0}, act_space, config, name='ac')
21 |
22 |   def initial(self, batch_size):
23 |     return self.ac.initial(batch_size)
24 |
25 |   def policy(self, latent, state):
26 |     return self.ac.policy(latent, state)
27 |
28 |   def train(self, imagine, start, data):
29 |     return self.ac.train(imagine, start, data)
30 |
31 |   def report(self, data):
32 |     return {}
33 |
34 |
35 | class Random(nj.Module):
36 |
37 |   def __init__(self, wm, act_space, config):
38 |     self.config = config
39 |     self.act_space = act_space
40 |
41 |   def initial(self, batch_size):
42 |     return jnp.zeros(batch_size)
43 |
44 |   def policy(self, latent, state):
45 |     batch_size = len(state)
46 |     shape = (batch_size,) + self.act_space.shape
47 |     if self.act_space.discrete:
48 |       dist = jaxutils.OneHotDist(jnp.zeros(shape))
49 |     else:
50 |       dist = tfd.Uniform(-jnp.ones(shape), jnp.ones(shape))
51 |       dist = tfd.Independent(dist, 1)
52 |     return {'action': dist}, state
53 |
54 |   def train(self, imagine, start, data):
55 |     return None, {}
56 |
57 |   def report(self, data):
58 |     return {}
59 |
60 |
61 | class Explore(nj.Module):
62 |
63 |   REWARDS = {
64 |       'disag': expl.Disag,
65 |   }
66 |
67 |   def __init__(self, wm, act_space, config):
68 |     self.config = config
69 |     self.rewards = {}
70 |     critics = {}
71 |     for key, scale in config.expl_rewards.items():
72 |       if not scale:
73 |         continue
74 |       if key == 'extr':
75 |         rewfn = lambda s: wm.heads['reward'](s).mean()[1:]
76 |         critics[key] = agent.VFunction(rewfn, config, name=key)
77 |       else:
78 |         rewfn = self.REWARDS[key](
79 |             wm, act_space, config, name=key + '_reward')
80 |         critics[key] = agent.VFunction(rewfn, config, name=key)
81 |         self.rewards[key] = rewfn
82 |     scales = {k: v for k, v in config.expl_rewards.items() if v}
83 |     self.ac = agent.ImagActorCritic(
84 |         critics, scales, act_space, config, name='ac')
85 |
86 |   def initial(self, batch_size):
87 |     return self.ac.initial(batch_size)
88 |
89 |   def policy(self, latent, state):
90 |     return self.ac.policy(latent, state)
91 |
92 |   def train(self, imagine, start, data):
93 |     metrics = {}
94 |     for key, rewfn in self.rewards.items():
95 |       mets = rewfn.train(data)
96 |       metrics.update({f'{key}_{k}': v for k, v in mets.items()})
97 |     traj, mets = self.ac.train(imagine, start, data)
98 |     metrics.update(mets)
99 |     return traj, metrics
100 |
101 |   def report(self, data):
102 |     return {}
103 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 |   import rich.traceback
3 |   rich.traceback.install()
4 | except ImportError:
5 |   pass
6 |
7 | from .core import *
8 |
9 | from . import envs
10 | from . import replay
11 | from . import run
12 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Agent, Env, Wrapper, Replay
2 |
3 | from .basics import convert, treemap, pack, unpack
4 | from .basics import print_ as print
5 | from .basics import format_ as format
6 |
7 | from .space import Space
8 | from .path import Path
9 | from .checkpoint import Checkpoint
10 | from .config import Config
11 | from .counter import Counter
12 | from .driver import Driver
13 | from .flags import Flags
14 | from .logger import Logger
15 | from .parallel import Parallel
16 | from .timer import Timer
17 | from .worker import Worker
18 | from .batcher import Batcher, BatcherSM
19 | from .metrics import Metrics
20 | from .uuid import uuid
21 |
22 | from .batch import BatchEnv
23 | from .random import RandomAgent
24 | from .distr import Client, Server, BatchServer
25 |
26 | from . import logger
27 | from . import when
28 | from . import wrappers
29 | from . import distr
30 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/base.py:
--------------------------------------------------------------------------------
1 | class Agent:
2 |
3 |   configs = {}  # dict of dicts
4 |
5 |   def __init__(self, obs_space, act_space, step, config):
6 |     pass
7 |
8 |   def dataset(self, generator_fn):
9 |     raise NotImplementedError(
10 |         'dataset(generator_fn) -> generator_fn')
11 |
12 |   def policy(self, obs, state=None, mode='train'):
13 |     raise NotImplementedError(
14 |         "policy(obs, state=None, mode='train') -> act, state")
15 |
16 |   def train(self, data, state=None):
17 |     raise NotImplementedError(
18 |         'train(data, state=None) -> outs, state, metrics')
19 |
20 |   def report(self, data):
21 |     raise NotImplementedError(
22 |         'report(data) -> metrics')
23 |
24 |   def save(self):
25 |     raise NotImplementedError('save() -> data')
26 |
27 |   def load(self, data):
28 |     raise NotImplementedError('load(data) -> None')
29 |
30 |   def sync(self):
31 |     # This method allows the agent to sync parameters from its training devices
32 |     # to its policy devices in the case of a multi-device agent.
33 |     pass
34 |
35 |
36 | class Env:
37 |
38 |   def __len__(self):
39 |     return 0  # Return positive integer for batched envs.
40 |
41 |   def __bool__(self):
42 |     return True  # Env is always truthy, despite length zero.
43 |
44 |   def __repr__(self):
45 |     return (
46 |         f'{self.__class__.__name__}('
47 |         f'len={len(self)}, '
48 |         f'obs_space={self.obs_space}, '
49 |         f'act_space={self.act_space})')
50 |
51 |   @property
52 |   def obs_space(self):
53 |     # The observation space must contain the keys is_first, is_last, and
54 |     # is_terminal. Commonly, it also contains the keys reward and image. By
55 |     # convention, keys starting with log_ are not consumed by the agent.
56 |     raise NotImplementedError('Returns: dict of spaces')
57 |
58 |   @property
59 |   def act_space(self):
60 |     # The action space must contain the keys action and reset. This
61 |     # restriction may be lifted in the future.
62 |     raise NotImplementedError('Returns: dict of spaces')
63 |
64 |   def step(self, action):
65 |     raise NotImplementedError('Returns: dict')
66 |
67 |   def render(self):
68 |     raise NotImplementedError('Returns: array')
69 |
70 |   def close(self):
71 |     pass
72 |
73 |
74 | class Wrapper:
75 |
76 |   def __init__(self, env):
77 |     self.env = env
78 |
79 |   def __len__(self):
80 |     return len(self.env)
81 |
82 |   def __bool__(self):
83 |     return bool(self.env)
84 |
85 |   def __getattr__(self, name):
86 |     if name.startswith('__'):
87 |       raise AttributeError(name)
88 |     try:
89 |       return getattr(self.env, name)
90 |     except AttributeError:
91 |       raise ValueError(name)
92 |
93 |
94 | class Replay:
95 |
96 |   def __len__(self):
97 |     raise NotImplementedError('Returns: total number of steps')
98 |
99 |   @property
100 |   def stats(self):
101 |     raise NotImplementedError('Returns: metrics')
102 |
103 |   def add(self, transition, worker=0):
104 |     raise NotImplementedError('Returns: None')
105 |
106 |   def add_traj(self, trajectory):
107 |     raise NotImplementedError('Returns: None')
108 |
109 |   def dataset(self):
110 |     raise NotImplementedError('Yields: trajectory')
111 |
112 |   def prioritize(self, keys, priorities):
113 |     pass
114 |
115 |   def save(self):
116 |     pass
117 |
118 |   def load(self, data):
119 |     pass
120 |
--------------------------------------------------------------------------------
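The comments in `Env` describe a dict-based contract: observations carry `is_first`, `is_last`, and `is_terminal`, and actions carry `action` and `reset`. The sketch below shows one way a `step` method could honor that contract. `CountdownEnv` is a hypothetical toy class written for illustration; it does not subclass the real `Env`, it omits the space definitions, and the automatic reset after an episode ends is a choice of this sketch rather than documented behavior.

```python
class CountdownEnv:
  """Hypothetical toy env whose episodes last `length` steps."""

  def __init__(self, length=2):
    self._length = length
    self._step = 0
    self._done = True  # Start a fresh episode on the first call.

  def __len__(self):
    return 0  # Unbatched, matching the base class default.

  def step(self, action):
    # A reset request, or a finished episode, starts a new episode.
    if action['reset'] or self._done:
      self._step = 0
      self._done = False
      return self._obs(0.0, is_first=True)
    self._step += 1
    self._done = self._step >= self._length
    return self._obs(
        float(action['action']), is_last=self._done, is_terminal=self._done)

  def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
    # Every observation carries the three episode-boundary flags.
    return {
        'step': self._step, 'reward': reward, 'is_first': is_first,
        'is_last': is_last, 'is_terminal': is_terminal}


env = CountdownEnv(length=2)
episode = [env.step({'action': 1, 'reset': False}) for _ in range(4)]
flags = [(ob['is_first'], ob['is_last']) for ob in episode]
```

Here `flags` marks the first call as the start of an episode, the third as its end, and the fourth as the start of the next one.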
/recall2imagine/embodied/core/basics.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import pickle
3 |
4 | import numpy as np
5 |
6 | from . import space as spacelib
7 |
8 | try:
9 |   import rich.console
10 |   console = rich.console.Console()
11 | except ImportError:
12 |   console = None
13 |
14 |
15 | CONVERSION = {
16 |     np.floating: np.float32,
17 |     np.signedinteger: np.int64,
18 |     np.uint8: np.uint8,
19 |     bool: bool,
20 | }
21 |
22 |
23 | def convert(value):
24 |   value = np.asarray(value)
25 |   if value.dtype not in CONVERSION.values():
26 |     for src, dst in CONVERSION.items():
27 |       if np.issubdtype(value.dtype, src):
28 |         if value.dtype != dst:
29 |           value = value.astype(dst)
30 |         break
31 |     else:
32 |       raise TypeError(f"Object '{value}' has unsupported dtype: {value.dtype}")
33 |   return value
34 |
35 |
36 | def print_(value, color=None):
37 |   global console
38 |   value = format_(value)
39 |   if console:
40 |     if color:
41 |       value = f'[{color}]{value}[/{color}]'
42 |     console.print(value)
43 |   else:
44 |     builtins.print(value)
45 |
46 |
47 | def format_(value):
48 |   if isinstance(value, dict):
49 |     if value and all(isinstance(x, spacelib.Space) for x in value.values()):
50 |       return '\n'.join(f' {k:<16} {v}' for k, v in value.items())
51 |     items = [f'{format_(k)}: {format_(v)}' for k, v in value.items()]
52 |     return '{' + ', '.join(items) + '}'
53 |   if isinstance(value, list):
54 |     return '[' + ', '.join(f'{format_(x)}' for x in value) + ']'
55 |   if isinstance(value, tuple):
56 |     return '(' + ', '.join(f'{format_(x)}' for x in value) + ')'
57 |   if hasattr(value, 'shape') and hasattr(value, 'dtype'):
58 |     shape = ','.join(str(x) for x in value.shape)
59 |     dtype = value.dtype.name
60 |     for long, short in {'float': 'f', 'uint': 'u', 'int': 'i'}.items():
61 |       dtype = dtype.replace(long, short)
62 |     return f'{dtype}[{shape}]'
63 |   if isinstance(value, bytes):
64 |     value = '0x' + value.hex() if r'\x' in str(value) else str(value)
65 |     if len(value) > 32:
66 |       value = value[:32 - 3] + '...'
67 |   return str(value)
68 |
69 |
70 | def treemap(fn, *trees, isleaf=None):
71 |   assert trees, 'Provide one or more nested Python structures'
72 |   kw = dict(isleaf=isleaf)
73 |   first = trees[0]
74 |   assert all(isinstance(x, type(first)) for x in trees)
75 |   if isleaf and isleaf(trees):
76 |     return fn(*trees)
77 |   if isinstance(first, list):
78 |     assert all(len(x) == len(first) for x in trees), format_(trees)
79 |     return [treemap(
80 |         fn, *[t[i] for t in trees], **kw) for i in range(len(first))]
81 |   if isinstance(first, tuple):
82 |     assert all(len(x) == len(first) for x in trees), format_(trees)
83 |     return tuple([treemap(
84 |         fn, *[t[i] for t in trees], **kw) for i in range(len(first))])
85 |   if isinstance(first, dict):
86 |     assert all(set(x.keys()) == set(first.keys()) for x in trees), (
87 |         format_(trees))
88 |     return {k: treemap(fn, *[t[k] for t in trees], **kw) for k in first}
89 |   return fn(*trees)
90 |
91 |
92 | def pack(data):
93 |   return pickle.dumps(data)
94 |   # import msgpack
95 |   # def fn(data):
96 |   #   if isinstance(data, np.ndarray):
97 |   #     return [b'type_numpy', list(data.shape), data.dtype.name, data.tobytes()]
98 |   #   if isinstance(data, bytes):
99 |   #     return [b'type_bytes', data]
100 |   #   if isinstance(data, tuple):
101 |   #     return [b'type_tuple', *[fn(x) for x in data]]
102 |   #   if isinstance(data, list):
103 |   #     return [fn(x) for x in data]
104 |   #   if isinstance(data, str):
105 |   #     return data.encode('utf-8')
106 |   #   if isinstance(data, dict):
107 |   #     return {k: fn(v) for k, v in data.items()}
108 |   #   if allow_pickle:
109 |   #     primitives = (type(None), bool, int, float, str, bytes)
110 |   #     if not isinstance(data, primitives):
111 |   #       return [b'type_pickle', pickle.dumps(data)]
112 |   #   return data
113 |   # data = fn(data)
114 |   # # print(format_(data))
115 |   # data = msgpack.packb(
116 |   #     data, use_single_float=True, use_bin_type=True, strict_types=True)
117 |   # return data
118 |
119 |
120 | def unpack(buffer):
121 |   return pickle.loads(buffer)
122 |   # import msgpack
123 |   # import pickle
124 |   # def fn(data):
125 |   #   if isinstance(data, list) and data and data[0] == b'type_numpy':
126 |   #     return np.frombuffer(data[3], data[2].decode('utf-8')).reshape(data[1])
127 |   #   if isinstance(data, list) and data and data[0] == b'type_bytes':
128 |   #     return data[1]
129 |   #   if isinstance(data, list) and data and data[0] == b'type_tuple':
130 |   #     return tuple([fn(x) for x in data[1:]])
131 |   #   if isinstance(data, list) and data and data[0] == b'type_pickle':
132 |   #     assert allow_pickle, 'Buffer contains pickled Python objects.'
133 |   #     return pickle.loads(data[1])
134 |   #   if isinstance(data, list):
135 |   #     return [fn(x) for x in data]
136 |   #   if isinstance(data, str):
137 |   #     return data.decode('utf-8')
138 |   #   if isinstance(data, dict):
139 |   #     return {k.decode('utf-8'): fn(v) for k, v in data.items()}
140 |   #   return data
141 |   # data = msgpack.unpackb(buffer, raw=True, use_list=True)
142 |   # data = fn(data)
143 |   # # print(format_(data))
144 |   # return data
145 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/batch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from . import base
4 |
5 |
6 | class BatchEnv(base.Env):
7 |
8 | def __init__(self, envs, parallel):
9 | assert all(len(env) == 0 for env in envs)
10 | assert len(envs) > 0
11 | self._envs = envs
12 | self._parallel = parallel
13 | self._keys = list(self.obs_space.keys())
14 |
15 | @property
16 | def obs_space(self):
17 | return self._envs[0].obs_space
18 |
19 | @property
20 | def act_space(self):
21 | return self._envs[0].act_space
22 |
23 | def __len__(self):
24 | return len(self._envs)
25 |
26 | def step(self, action):
27 | assert all(len(v) == len(self._envs) for v in action.values()), (
28 | len(self._envs), {k: v.shape for k, v in action.items()})
29 | obs = []
30 | for i, env in enumerate(self._envs):
31 | act = {k: v[i] for k, v in action.items()}
32 | obs.append(env.step(act))
33 | if self._parallel:
34 | obs = [ob() for ob in obs]
35 | return {k: np.array([ob[k] for ob in obs]) for k in obs[0]}
36 |
37 | def render(self):
38 | return np.stack([env.render() for env in self._envs])
39 |
40 | def close(self):
41 | for env in self._envs:
42 | try:
43 | env.close()
44 | except Exception:
45 | pass
46 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/batcher.py:
--------------------------------------------------------------------------------
1 | import queue as queuelib
2 | import sys
3 | import threading
4 | import time
5 | import traceback
6 |
7 | import numpy as np
8 |
9 |
10 | class Batcher:
11 |
12 | def __init__(
13 | self, sources, workers=0, postprocess=None,
14 | prefetch_source=4, prefetch_batch=2):
15 | self._workers = workers
16 | self._postprocess = postprocess
17 | if workers:
18 | # Round-robin assign sources to workers.
19 | self._running = True
20 | self._threads = []
21 | self._queues = []
22 | assignments = [([], []) for _ in range(workers)]
23 | for index, source in enumerate(sources):
24 | queue = queuelib.Queue(prefetch_source)
25 | self._queues.append(queue)
26 | assignments[index % workers][0].append(source)
27 | assignments[index % workers][1].append(queue)
28 | for args in assignments:
29 | creator = threading.Thread(
30 | target=self._creator, args=args, daemon=True)
31 | creator.start()
32 | self._threads.append(creator)
33 | self._batches = queuelib.Queue(prefetch_batch)
34 | batcher = threading.Thread(
35 | target=self._batcher, args=(self._queues, self._batches),
36 | daemon=True)
37 | batcher.start()
38 | self._threads.append(batcher)
39 | else:
40 | self._iterators = [source() for source in sources]
41 | self._once = False
42 |
43 | def close(self):
44 | if self._workers:
45 | self._running = False
46 | for thread in self._threads:
47 | thread.join(timeout=1)  # threading.Thread has no close() method.
48 |
49 | def __iter__(self):
50 | if self._once:
51 | raise RuntimeError(
52 | 'You can only create one iterator per Batcher object to ensure that '
53 | 'data is consumed in order. Create another Batcher object instead.')
54 | self._once = True
55 | return self
56 |
57 | def __call__(self):
58 | return self.__iter__()
59 |
60 | def __next__(self):
61 | if self._workers:
62 | batch = self._batches.get()
63 | else:
64 | elems = [next(x) for x in self._iterators]
65 | batch = {k: np.stack([x[k] for x in elems], 0) for k in elems[0]}
66 | if isinstance(batch, Exception):
67 | raise batch
68 | return batch
69 |
70 | def _creator(self, sources, outputs):
71 | try:
72 | iterators = [source() for source in sources]
73 | while self._running:
74 | waiting = True
75 | for iterator, queue in zip(iterators, outputs):
76 | if queue.full():
77 | continue
78 | queue.put(next(iterator))
79 | waiting = False
80 | if waiting:
81 | time.sleep(0.001)
82 | except Exception as e:
83 | e.stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
84 | outputs[0].put(e)
85 | raise
86 |
87 | def _batcher(self, sources, output):
88 | try:
89 | while self._running:
90 | elems = [x.get() for x in sources]
91 | for elem in elems:
92 | if isinstance(elem, Exception):
93 | raise elem
94 | batch = {k: np.stack([x[k] for x in elems], 0) for k in elems[0]}
95 | if self._postprocess:
96 | batch = self._postprocess(batch)
97 | output.put(batch) # Will wait here if the queue is full.
98 | except Exception as e:
99 | e.stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
100 | output.put(e)
101 | raise
102 |
103 |
104 | class BatcherSM:
105 | """
106 | A class for the batching process that uses shared numpy memory.
107 | It is much faster than the standard Batcher when working with
108 | long sequences in batches because the latter sends a lot of data through
109 | queue.Queue, which is much slower than reading from and writing
110 | to shared numpy arrays.
111 | """
112 | def __init__(
113 | self, replay, workers=0, batch_size=1,
114 | batch_sequence_len=1, postprocess=None,
115 | prefetch_source=4, prefetch_batch=2):
116 | self._workers = workers
117 | self._postprocess = postprocess
118 | # we assume replay was filled with at least one step so the serializer is
119 | # initialized already
120 | self._replay = replay
121 | self.prefetch_batch = prefetch_batch
122 | self.batch_size = batch_size
123 | self._serializer = self._replay.serializer
124 | self._batch_buffers = None
125 | self.batch_sequence_len = batch_sequence_len
126 | self.batch_size = batch_size
127 |
128 | if workers:
129 | # Round-robin assign sources to workers.
130 | self._running = True
131 | self._threads = []
132 | self._queues = []
133 | self.tasks = queuelib.Queue(prefetch_batch * batch_size)
134 | self.reports = queuelib.Queue(prefetch_batch * batch_size)
135 | for _ in range(workers):
136 | creator = threading.Thread(
137 | target=self._creator, args=(), daemon=True
138 | )
139 | creator.start()
140 | self._threads.append(creator)
141 | self._outputs = queuelib.Queue(prefetch_batch)
142 | batcher = threading.Thread(
143 | target=self._batcher, args=(),
144 | daemon=True)
145 | batcher.start()
146 | self._threads.append(batcher)
147 | self._once = False
148 |
149 | def close(self):
150 | if self._workers:
151 | self._running = False
152 | for thread in self._threads:
153 | thread.join(timeout=1)  # threading.Thread has no close() method.
154 |
155 | def __iter__(self):
156 | if self._once:
157 | raise RuntimeError(
158 | 'You can only create one iterator per Batcher object to ensure that '
159 | 'data is consumed in order. Create another Batcher object instead.')
160 | self._once = True
161 | return self
162 |
163 | def __call__(self):
164 | return self.__iter__()
165 |
166 | def __next__(self):
167 | if self._workers:
168 | batch = self._outputs.get()
169 | else:
170 | elems = [next(x) for x in self._iterators]
171 | batch = {k: np.stack([x[k] for x in elems], 0) for k in elems[0]}
172 | if isinstance(batch, Exception):
173 | raise batch
174 | return batch
175 |
176 | def _creator(self):
177 | try:
178 | while self._running:
179 | if not self.tasks.empty() and self._replay.ready:
180 | if self._batch_buffers is None:
181 | self._serializer = self._replay.serializer
182 | self._batch_buffers = self._serializer.batch_buffer(
183 | self.prefetch_batch, self.batch_size, self.batch_sequence_len)
184 | flip, task_id = self.tasks.get()
185 | success = self._replay.sample(flip, task_id, self._batch_buffers)
186 | assert success
187 | self.reports.put((flip, task_id))
188 | else:
189 | time.sleep(0.001)
190 | except Exception as e:
191 | e.stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
192 | raise
193 |
194 | def read_buffer(self, flip):
195 | # copy here is necessary as right after this read we'll
196 | # assign another set of data loading tasks which will
197 | # overwrite the batch buffers
198 | return {k: v[flip].copy() for k, v in self._batch_buffers.items()}
199 |
200 | def _batcher(self, ):
201 | if self.tasks.empty():
202 | for flip in range(self.prefetch_batch):
203 | for batch_id in range(self.batch_size):
204 | self.tasks.put((flip, batch_id))
205 | batch_completion = [[False for _ in range(self.batch_size)]
206 | for _ in range(self.prefetch_batch)]
207 | try:
208 | while self._running:
209 | while not self.reports.empty():
210 | flip, task_id = self.reports.get()
211 | batch_completion[flip][task_id] = True
212 | completed = [flip for flip in range(self.prefetch_batch) if all(batch_completion[flip])]
213 | for flip in completed:
214 | batch = self.read_buffer(flip)
215 | for k in range(self.batch_size):
216 | batch_completion[flip][k] = False
217 | self.tasks.put((flip, k))
218 | if self._postprocess:
219 | batch = self._postprocess(batch)
220 | self._outputs.put(batch)
221 | if len(completed) == 0:
222 | time.sleep(0.001)
223 | except Exception as e:
224 | e.stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
225 | raise
226 |
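The `BatcherSM` docstring above argues that writing into preallocated arrays is cheaper than passing whole sequences through queues. The following is a minimal sketch of that double-buffer ("flip") idea. The buffer layout `(prefetch_batch, batch_size, seq_len)` and the helper `fill` are assumptions for illustration, since `serializer.batch_buffer` and `replay.sample` are not shown in this file.

```python
import numpy as np

prefetch_batch, batch_size, seq_len = 2, 3, 4

# Assumed layout: one preallocated slot per (flip, batch element).
buffers = {
    'reward': np.zeros((prefetch_batch, batch_size, seq_len), np.float32),
}


def fill(flip, task_id, value):
  # Hypothetical stand-in for replay.sample(): writes one sequence in place.
  buffers['reward'][flip, task_id] = value


def read_buffer(flip):
  # Copy so that refilling the slot later does not alter the returned batch.
  return {k: v[flip].copy() for k, v in buffers.items()}


for task_id in range(batch_size):
  fill(0, task_id, float(task_id))
batch = read_buffer(0)
fill(0, 0, 99.0)  # The slot is reused; the copied batch is unaffected.
```

Only small `(flip, task_id)` tuples would travel through the queues in this scheme, while the sequence data itself stays in the shared arrays.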
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/checkpoint.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import time
3 |
4 | from . import basics
5 | from . import path
6 |
7 |
8 | class Checkpoint:
9 |
10 | def __init__(self, filename=None, log=True, parallel=True):
11 | self._filename = filename and path.Path(filename)
12 | self._log = log
13 | self._values = {}
14 | self._parallel = parallel
15 | if self._parallel:
16 | self._worker = concurrent.futures.ThreadPoolExecutor(1)
17 | self._promise = None
18 |
19 | def __setattr__(self, name, value):
20 | if name in ('exists', 'save', 'load'):
21 | return super().__setattr__(name, value)
22 | if name.startswith('_'):
23 | return super().__setattr__(name, value)
24 | has_load = hasattr(value, 'load') and callable(value.load)
25 | has_save = hasattr(value, 'save') and callable(value.save)
26 | if not (has_load and has_save):
27 | message = f"Checkpoint entry '{name}' must implement save() and load()."
28 | raise ValueError(message)
29 | self._values[name] = value
30 |
31 | def __getattr__(self, name):
32 | if name.startswith('_'):
33 | raise AttributeError(name)
34 | try:
35 | return self._values[name]
36 | except KeyError:
37 | raise ValueError(name)
38 |
39 | def exists(self, filename=None):
40 | assert self._filename or filename
41 | filename = path.Path(filename or self._filename)
42 | exists = filename.exists()
43 | self._log and exists and print('Found existing checkpoint.')
44 | self._log and not exists and print('Did not find any checkpoint.')
45 | return exists
46 |
47 | def save(self, filename=None, keys=None):
48 | assert self._filename or filename
49 | filename = path.Path(filename or self._filename)
50 | self._log and print(f'Writing checkpoint: {filename}')
51 | if self._parallel:
52 | self._promise and self._promise.result()
53 | self._promise = self._worker.submit(self._save, filename, keys)
54 | else:
55 | self._save(filename, keys)
56 |
57 | def _save(self, filename, keys):
58 | keys = tuple(self._values.keys() if keys is None else keys)
59 | assert all([not k.startswith('_') for k in keys]), keys
60 | data = {k: self._values[k].save() for k in keys}
61 | data['_timestamp'] = time.time()
62 | if filename.exists():
63 | old = filename.parent / (filename.name + '.old')
64 | filename.copy(old)
65 | filename.write(basics.pack(data), mode='wb')
66 | old.remove()
67 | else:
68 | filename.write(basics.pack(data), mode='wb')
69 | self._log and print(f'Wrote checkpoint: {filename}')
70 |
71 | def load(self, filename=None, keys=None):
72 | assert self._filename or filename
73 | filename = path.Path(filename or self._filename)
74 | self._log and print(f'Loading checkpoint: {filename}')
75 | data = basics.unpack(filename.read('rb'))
76 | keys = tuple(data.keys() if keys is None else keys)
77 | all_loaded = True
78 | for key in keys:
79 | if key.startswith('_'):
80 | continue
81 | try:
82 | loaded = self._values[key].load(data[key])
83 | if loaded is None:
84 | loaded = True
85 | all_loaded = loaded and all_loaded
86 | except Exception:
87 | print(f'Error loading {key} from checkpoint.')
88 | raise
89 | if self._log:
90 | age = time.time() - data['_timestamp']
91 | print(f'Loaded checkpoint from {age:.0f} seconds ago.')
92 | return all_loaded
93 |
94 | def load_or_save(self, filename=None):
95 | if self.exists(filename=filename):
96 | loaded = self.load(filename=filename)
97 | if loaded is not None and not loaded:
98 | self.save(filename=filename)
99 | else:
100 | self.save(filename=filename)
101 |
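The `_save` method above copies an existing checkpoint to a `.old` file before overwriting it, so that a crash in the middle of the write still leaves a usable copy. A minimal standalone sketch of that pattern follows, using only the standard library. The names `save_with_backup` and `load_checkpoint` are hypothetical and not part of this repository.

```python
import pathlib
import pickle
import shutil
import tempfile
import time


def save_with_backup(filename, data):
  """Write `data` to `filename`, keeping a `.old` copy during the write."""
  filename = pathlib.Path(filename)
  payload = pickle.dumps({**data, '_timestamp': time.time()})
  if filename.exists():
    old = filename.with_name(filename.name + '.old')
    shutil.copy(filename, old)     # Keep the previous checkpoint as fallback.
    filename.write_bytes(payload)
    old.unlink()                   # Drop the fallback only after the write.
  else:
    filename.write_bytes(payload)


def load_checkpoint(filename):
  return pickle.loads(pathlib.Path(filename).read_bytes())


with tempfile.TemporaryDirectory() as tmp:
  ckpt = pathlib.Path(tmp) / 'checkpoint.ckpt'
  save_with_backup(ckpt, {'step': 1})
  save_with_backup(ckpt, {'step': 2})
  restored = load_checkpoint(ckpt)
  backup_left = ckpt.with_name(ckpt.name + '.old').exists()
```

As in the original, unpickling should only be done on checkpoint files from a trusted source.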
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/config.py:
--------------------------------------------------------------------------------
1 | import io
2 | import json
3 | import re
4 |
5 | from . import path
6 |
7 |
8 | class Config(dict):
9 |
10 | SEP = '.'
11 | IS_PATTERN = re.compile(r'.*[^A-Za-z0-9_.-].*')
12 |
13 | def __init__(self, *args, **kwargs):
14 | mapping = dict(*args, **kwargs)
15 | mapping = self._flatten(mapping)
16 | mapping = self._ensure_keys(mapping)
17 | mapping = self._ensure_values(mapping)
18 | self._flat = mapping
19 | self._nested = self._nest(mapping)
20 | # Need to assign the values to the base class dictionary so that
21 | # conversion to dict does not lose the content.
22 | super().__init__(self._nested)
23 |
24 | @property
25 | def flat(self):
26 | return self._flat.copy()
27 |
28 | def save(self, filename):
29 | filename = path.Path(filename)
30 | if filename.suffix == '.json':
31 | filename.write(json.dumps(dict(self)))
32 | elif filename.suffix in ('.yml', '.yaml'):
33 | import ruamel.yaml as yaml
34 | with io.StringIO() as stream:
35 | yaml.safe_dump(dict(self), stream)
36 | filename.write(stream.getvalue())
37 | else:
38 | raise NotImplementedError(filename.suffix)
39 |
40 | @classmethod
41 | def load(cls, filename):
42 | filename = path.Path(filename)
43 | if filename.suffix == '.json':
44 | return cls(json.loads(filename.read_text()))
45 | elif filename.suffix in ('.yml', '.yaml'):
46 | import ruamel.yaml as yaml
47 | return cls(yaml.safe_load(filename.read_text()))
48 | else:
49 | raise NotImplementedError(filename.suffix)
50 |
51 | def __contains__(self, name):
52 | try:
53 | self[name]
54 | return True
55 | except KeyError:
56 | return False
57 |
58 | def __getattr__(self, name):
59 | if name.startswith('_'):
60 | return super().__getattr__(name)
61 | try:
62 | return self[name]
63 | except KeyError:
64 | raise AttributeError(name)
65 |
66 | def __getitem__(self, name):
67 | result = self._nested
68 | for part in name.split(self.SEP):
69 | try:
70 | result = result[part]
71 | except TypeError:
72 | raise KeyError
73 | if isinstance(result, dict):
74 | result = type(self)(result)
75 | return result
76 |
77 | def __setattr__(self, key, value):
78 | if key.startswith('_'):
79 | return super().__setattr__(key, value)
80 | message = f"Tried to set key '{key}' on immutable config. Use update()."
81 | raise AttributeError(message)
82 |
83 | def __setitem__(self, key, value):
84 | if key.startswith('_'):
85 | return super().__setitem__(key, value)
86 | message = f"Tried to set key '{key}' on immutable config. Use update()."
87 | raise AttributeError(message)
88 |
89 | def __reduce__(self):
90 | return (type(self), (dict(self),))
91 |
92 | def __str__(self):
93 | lines = ['\nConfig:']
94 | keys, vals, typs = [], [], []
95 | for key, val in self.flat.items():
96 | keys.append(key + ':')
97 | vals.append(self._format_value(val))
98 | typs.append(self._format_type(val))
99 | max_key = max(len(k) for k in keys) if keys else 0
100 | max_val = max(len(v) for v in vals) if vals else 0
101 | for key, val, typ in zip(keys, vals, typs):
102 | key = key.ljust(max_key)
103 | val = val.ljust(max_val)
104 | lines.append(f'{key} {val} ({typ})')
105 | return '\n'.join(lines)
106 |
107 | def update(self, *args, **kwargs):
108 | result = self._flat.copy()
109 | inputs = self._flatten(dict(*args, **kwargs))
110 | for key, new in inputs.items():
111 | if self.IS_PATTERN.match(key):
112 | pattern = re.compile(key)
113 | keys = {k for k in result if pattern.match(k)}
114 | else:
115 | keys = [key]
116 | if not keys:
117 | raise KeyError(f'Unknown key or pattern {key}.')
118 | for key in keys:
119 | old = result[key]
120 | try:
121 | if isinstance(old, int) and isinstance(new, float):
122 | if float(int(new)) != new:
123 | message = f"Cannot convert fractional float {new} to int."
124 | raise ValueError(message)
125 | result[key] = type(old)(new)
126 | except (ValueError, TypeError):
127 | raise TypeError(
128 | f"Cannot convert '{new}' to type '{type(old).__name__}' " +
129 | f"for key '{key}' with previous value '{old}'.")
130 | return type(self)(result)
131 |
132 | def _flatten(self, mapping):
133 | result = {}
134 | for key, value in mapping.items():
135 | if isinstance(value, dict):
136 | for k, v in self._flatten(value).items():
137 | if self.IS_PATTERN.match(key) or self.IS_PATTERN.match(k):
138 | combined = f'{key}\\{self.SEP}{k}'
139 | else:
140 | combined = f'{key}{self.SEP}{k}'
141 | result[combined] = v
142 | else:
143 | result[key] = value
144 | return result
145 |
146 | def _nest(self, mapping):
147 | result = {}
148 | for key, value in mapping.items():
149 | parts = key.split(self.SEP)
150 | node = result
151 | for part in parts[:-1]:
152 | if part not in node:
153 | node[part] = {}
154 | node = node[part]
155 | node[parts[-1]] = value
156 | return result
157 |
158 | def _ensure_keys(self, mapping):
159 | for key in mapping:
160 | assert not self.IS_PATTERN.match(key), key
161 | return mapping
162 |
163 | def _ensure_values(self, mapping):
164 | result = json.loads(json.dumps(mapping))
165 | for key, value in result.items():
166 | if isinstance(value, list):
167 | value = tuple(value)
168 | if isinstance(value, tuple):
169 | if len(value) == 0:
170 | message = 'Empty lists are disallowed because their type is unclear.'
171 | raise TypeError(message)
172 | if not isinstance(value[0], (str, float, int, bool)):
173 | message = 'Lists can only contain strings, floats, ints, bools'
174 | message += f' but not {type(value[0])}'
175 | raise TypeError(message)
176 | if not all(isinstance(x, type(value[0])) for x in value[1:]):
177 | message = 'Elements of a list must all be of the same type.'
178 | raise TypeError(message)
179 | result[key] = value
180 | return result
181 |
182 | def _format_value(self, value):
183 | if isinstance(value, (list, tuple)):
184 | return '[' + ', '.join(self._format_value(x) for x in value) + ']'
185 | return str(value)
186 |
187 | def _format_type(self, value):
188 | if isinstance(value, (list, tuple)):
189 | assert len(value) > 0, value
190 | return self._format_type(value[0]) + 's'
191 | return str(type(value).__name__)
192 |
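`Config` stores every value twice: once under dotted keys (`_flat`) and once as nested dicts (`_nested`). The round trip between the two forms can be sketched as below. This is a simplified illustration that leaves out the regex-pattern handling of `_flatten`; the free functions `flatten` and `nest` are hypothetical names.

```python
def flatten(mapping, sep='.'):
  """Turn nested dicts into a single dict with dotted keys."""
  result = {}
  for key, value in mapping.items():
    if isinstance(value, dict):
      for k, v in flatten(value, sep).items():
        result[f'{key}{sep}{k}'] = v
    else:
      result[key] = value
  return result


def nest(mapping, sep='.'):
  """Inverse of flatten(): rebuild nested dicts from dotted keys."""
  result = {}
  for key, value in mapping.items():
    *parents, leaf = key.split(sep)
    node = result
    for part in parents:
      node = node.setdefault(part, {})
    node[leaf] = value
  return result


nested = {'env': {'atari': {'size': 64}}, 'seed': 0}
flat = flatten(nested)
```

Here `flat` holds the keys `env.atari.size` and `seed`, and `nest(flat)` reproduces `nested`.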
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/counter.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 |
4 | @functools.total_ordering
5 | class Counter:
6 |
7 | def __init__(self, initial=0):
8 | self.value = initial
9 |
10 | def __repr__(self):
11 | return f'Counter({self.value})'
12 |
13 | def __int__(self):
14 | return int(self.value)
15 |
16 | def __eq__(self, other):
17 | return int(self) == other
18 |
19 | def __ne__(self, other):
20 | return int(self) != other
21 |
22 | def __lt__(self, other):
23 | return int(self) < other
24 |
25 | def __add__(self, other):
26 | return int(self) + other
27 |
28 | def __radd__(self, other):
29 | return other + int(self)
30 |
31 | def __sub__(self, other):
32 | return int(self) - other
33 |
34 | def __rsub__(self, other):
35 | return other - int(self)
36 |
37 | def increment(self, amount=1):
38 | self.value += amount
39 |
40 | def save(self):
41 | return self.value
42 |
43 | def load(self, value):
44 | self.value = value
45 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/distr.py:
--------------------------------------------------------------------------------
1 | import ctypes
2 | import sys
3 | import threading
4 | import time
5 | import traceback
6 | import uuid
7 |
8 | import numpy as np
9 |
10 | from . import basics
11 |
12 |
13 | class Client:
14 |
15 | def __init__(self, address, timeout_ms=-1, ipv6=False):
16 | import zmq
17 | addresses = [address] if isinstance(address, str) else address
18 | context = zmq.Context.instance()
19 | self.socket = context.socket(zmq.REQ)
20 | self.socket.setsockopt(zmq.IDENTITY, uuid.uuid4().bytes)
21 | self.socket.RCVTIMEO = timeout_ms
22 | for address in addresses:
23 | basics.print_(f'Client connecting to {address}', color='green')
24 | ipv6 and self.socket.setsockopt(zmq.IPV6, 1)
25 | self.socket.connect(address)
26 | self.result = True
27 |
28 | def __call__(self, data):
29 | assert isinstance(data, dict), type(data)
30 | if self.result is None:
31 | self._receive()
32 | self.result = None
33 | self.socket.send(basics.pack(data))
34 | return self._receive
35 |
36 | def _receive(self):
37 | try:
38 | received = self.socket.recv()
39 | except Exception as e:
40 | raise RuntimeError(f'Failed to receive data from server: {e}')
41 | self.result = basics.unpack(received)
42 | if self.result.get('type', 'data') == 'error':
43 | msg = self.result.get('message', None)
44 | raise RuntimeError(f'Server responded with an error: {msg}')
45 | return self.result
46 |
47 |
48 | class Server:
49 |
50 | def __init__(self, address, function, ipv6=False):
51 | import zmq
52 | context = zmq.Context.instance()
53 | self.socket = context.socket(zmq.REP)
54 | basics.print_(f'Server listening at {address}', color='green')
55 | ipv6 and self.socket.setsockopt(zmq.IPV6, 1)
56 | self.socket.bind(address)
57 | self.function = function
58 |
59 | def run(self):
60 | while True:
61 | payload = self.socket.recv()
62 | inputs = basics.unpack(payload)
63 | assert isinstance(inputs, dict), type(inputs)
64 | try:
65 | result = self.function(inputs)
66 | assert isinstance(result, dict), type(result)
67 | except Exception as e:
68 | result = {'type': 'error', 'message': str(e)}
69 | self.socket.send(basics.pack(result))
70 | raise
71 | payload = basics.pack(result)
72 | self.socket.send(payload)
73 |
74 |
75 | class BatchServer:
76 |
77 | def __init__(self, address, batch, function, ipv6=False):
78 | import zmq
79 | context = zmq.Context.instance()
80 | self.socket = context.socket(zmq.ROUTER)
81 | basics.print_(f'BatchServer listening at {address}', color='green')
82 | ipv6 and self.socket.setsockopt(zmq.IPV6, 1)
83 | self.socket.bind(address)
84 | self.function = function
85 | self.batch = batch
86 |
87 | def run(self):
88 | inputs = None
89 | while True:
90 | addresses = []
91 | for i in range(self.batch):
92 | address, empty, payload = self.socket.recv_multipart()
93 | data = basics.unpack(payload)
94 | assert isinstance(data, dict), type(data)
95 | if inputs is None:
96 | inputs = {
97 | k: np.empty((self.batch, *v.shape), v.dtype)
98 | for k, v in data.items() if not isinstance(v, str)}
99 | for key, value in data.items():
100 | inputs[key][i] = value
101 | addresses.append(address)
102 | try:
103 | results = self.function(inputs, [x.hex() for x in addresses])
104 | assert isinstance(results, dict), type(results)
105 | for key, value in results.items():
106 | if not isinstance(value, str):
107 | assert len(value) == self.batch, (key, value.shape)
108 | except Exception as e:
109 | results = {'type': 'error', 'message': str(e)}
110 | self._respond(addresses, results)
111 | raise
112 | self._respond(addresses, results)
113 |
114 | def _respond(self, addresses, results):
115 | for i, address in enumerate(addresses):
116 | payload = basics.pack({
117 | k: v if isinstance(v, str) else v[i]
118 | for k, v in results.items()})
119 | self.socket.send_multipart([address, b'', payload])
120 |
121 |
122 | class Thread(threading.Thread):
123 |
124 | lock = threading.Lock()
125 |
126 | def __init__(self, fn, *args, name=None):
127 | self.fn = fn
128 | self.exitcode = None
129 | name = name or fn.__name__
130 | super().__init__(target=self._wrapper, args=args, name=name, daemon=True)
131 |
132 | def _wrapper(self, *args):
133 | try:
134 | self.fn(*args)
135 | except Exception:
136 | with self.lock:
137 | print('-' * 79)
138 | print(f'Exception in worker: {self.name}')
139 | print('-' * 79)
140 | print(''.join(traceback.format_exception(*sys.exc_info())))
141 | self.exitcode = 1
142 | raise
143 | self.exitcode = 0
144 |
145 | def terminate(self):
146 | if not self.is_alive():
147 | return
148 | if hasattr(self, '_thread_id'):
149 | thread_id = self._thread_id
150 | else:
151 | thread_id = [k for k, v in threading._active.items() if v is self][0]
152 | result = ctypes.pythonapi.PyThreadState_SetAsyncExc(
153 | ctypes.c_long(thread_id), ctypes.py_object(SystemExit))
154 | if result > 1:
155 | ctypes.pythonapi.PyThreadState_SetAsyncExc(
156 | ctypes.c_long(thread_id), None)
157 | print('Shut down worker:', self.name)
158 |
159 |
160 | class Process:
161 |
162 | lock = None
163 | initializers = []
164 |
165 | def __init__(self, fn, *args, name=None):
166 | import multiprocessing
167 | import cloudpickle
168 | mp = multiprocessing.get_context('spawn')
169 | if Process.lock is None:
170 | Process.lock = mp.Lock()
171 | name = name or fn.__name__
172 | initializers = cloudpickle.dumps(self.initializers)
173 | args = (initializers,) + args
174 | self._process = mp.Process(
175 | target=self._wrapper, args=(Process.lock, fn, *args),
176 | name=name)
177 |
178 | def start(self):
179 | self._process.start()
180 |
181 | @property
182 | def name(self):
183 | return self._process.name
184 |
185 | @property
186 | def exitcode(self):
187 | return self._process.exitcode
188 |
189 | def terminate(self):
190 | self._process.terminate()
191 | print('Shut down worker:', self.name)
192 |
193 | def _wrapper(self, lock, fn, *args):
194 | try:
195 | import cloudpickle
196 | initializers, *args = args
197 | for initializer in cloudpickle.loads(initializers):
198 | initializer()
199 | fn(*args)
200 | except Exception:
201 | with lock:
202 | print('-' * 79)
203 | print(f'Exception in worker: {self.name}')
204 | print('-' * 79)
205 | print(''.join(traceback.format_exception(*sys.exc_info())))
206 | raise
207 |
208 |
209 | def run(workers):
210 | [x.start() for x in workers]
211 | while True:
212 | if all(x.exitcode == 0 for x in workers):
213 | print('All workers terminated successfully.')
214 | return
215 | for worker in workers:
216 | if worker.exitcode not in (None, 0):
217 | # Wait for everybody who wants to print their error messages.
218 | time.sleep(1)
219 | [x.terminate() for x in workers if x is not worker]
220 | raise RuntimeError(f'Stopped workers due to crash in {worker.name}.')
221 | time.sleep(0.1)
222 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/driver.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import numpy as np
4 |
5 | from .basics import convert
6 |
7 |
8 | class Driver:
9 |
10 | _CONVERSION = {
11 | np.floating: np.float32,
12 | np.signedinteger: np.int32,
13 | np.uint8: np.uint8,
14 | bool: bool,
15 | }
16 |
17 | def __init__(self, env, **kwargs):
18 | assert len(env) > 0
19 | self._env = env
20 | self._kwargs = kwargs
21 | self._on_steps = []
22 | self._on_episodes = []
23 | self.reset()
24 |
25 | def reset(self):
26 | self._acts = {
27 | k: convert(np.zeros((len(self._env),) + v.shape, v.dtype))
28 | for k, v in self._env.act_space.items()}
29 | self._acts['reset'] = np.ones(len(self._env), bool)
30 | self._eps = [collections.defaultdict(list) for _ in range(len(self._env))]
31 | self._state = None
32 |
33 | def on_step(self, callback):
34 | self._on_steps.append(callback)
35 |
36 | def on_episode(self, callback):
37 | self._on_episodes.append(callback)
38 |
39 | def __call__(self, policy, steps=0, episodes=0):
40 | step, episode = 0, 0
41 | while step < steps or episode < episodes:
42 | step, episode = self._step(policy, step, episode)
43 |
44 | def _step(self, policy, step, episode):
45 | assert all(len(x) == len(self._env) for x in self._acts.values())
46 | acts = {k: v for k, v in self._acts.items() if not k.startswith('log_')}
47 | obs = self._env.step(acts)
48 | obs = {k: convert(v) for k, v in obs.items()}
49 | assert all(len(x) == len(self._env) for x in obs.values()), obs
50 | acts, self._state = policy(obs, self._state, **self._kwargs)
51 | acts = {k: convert(v) for k, v in acts.items()}
52 | if obs['is_last'].any():
53 | mask = 1 - obs['is_last']
54 | acts = {k: v * self._expand(mask, len(v.shape)) for k, v in acts.items()}
55 | acts['reset'] = obs['is_last'].copy()
56 | self._acts = acts
57 | trns = {**obs, **acts}
58 | if obs['is_first'].any():
59 | for i, first in enumerate(obs['is_first']):
60 | if first:
61 | self._eps[i].clear()
62 | for i in range(len(self._env)):
63 | trn = {k: v[i] for k, v in trns.items()}
64 | [self._eps[i][k].append(v) for k, v in trn.items()]
65 | [fn(trn, i, **self._kwargs) for fn in self._on_steps]
66 | step += 1
67 | if obs['is_last'].any():
68 | for i, done in enumerate(obs['is_last']):
69 | if done:
70 | ep = {k: convert(v) for k, v in self._eps[i].items()}
71 | [fn(ep.copy(), i, **self._kwargs) for fn in self._on_episodes]
72 | episode += 1
73 | return step, episode
74 |
75 | def _expand(self, value, dims):
76 | while len(value.shape) < dims:
77 | value = value[..., None]
78 | return value
79 |
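In `Driver._step`, actions for environments whose episode just ended are zeroed out and replaced by a `reset` flag. The mask is broadcast over trailing action dimensions by `_expand`. A small sketch of that masking step, with made-up shapes (three environments, a four-dimensional action) and a free function `expand` standing in for the method:

```python
import numpy as np


def expand(value, dims):
  # Append trailing axes until `value` has `dims` dimensions.
  while value.ndim < dims:
    value = value[..., None]
  return value


is_last = np.array([False, True, False])
acts = {'action': np.ones((3, 4), np.float32)}

mask = 1 - is_last  # 1 where the episode continues, 0 where it ended.
masked = {
    k: v * expand(mask, v.ndim).astype(v.dtype) for k, v in acts.items()}
masked['reset'] = is_last.copy()
```

The explicit `astype` is an addition in this sketch to keep the action dtype unchanged; the original relies on `convert` for dtype handling.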
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/flags.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 |
4 | from . import config
5 |
6 |
7 | class Flags:
8 |
9 | def __init__(self, *args, **kwargs):
10 | self._config = config.Config(*args, **kwargs)
11 |
12 | def parse(self, argv=None, help_exists=True):
13 | parsed, remaining = self.parse_known(argv, help_exists)
14 | for flag in remaining:
15 | if flag.startswith('--'):
16 | raise ValueError(f"Flag '{flag}' did not match any config keys.")
17 | assert not remaining, remaining
18 | return parsed
19 |
20 | def parse_known(self, argv=None, help_exists=False):
21 | if argv is None:
22 | argv = sys.argv[1:]
23 | if '--help' in argv:
24 | print('\nHelp:')
25 | lines = str(self._config).split('\n')[2:]
26 | print('\n'.join('--' + re.sub(r'[:,\[\]]', '', x) for x in lines))
27 | help_exists and sys.exit()
28 | parsed = {}
29 | remaining = []
30 | key = None
31 | vals = None
32 | for arg in argv:
33 | if arg.startswith('--'):
34 | if key:
35 | self._submit_entry(key, vals, parsed, remaining)
36 | if '=' in arg:
37 | key, val = arg.split('=', 1)
38 | vals = [val]
39 | else:
40 | key, vals = arg, []
41 | else:
42 | if key:
43 | vals.append(arg)
44 | else:
45 | remaining.append(arg)
46 | self._submit_entry(key, vals, parsed, remaining)
47 | parsed = self._config.update(parsed)
48 | return parsed, remaining
49 |
50 | def _submit_entry(self, key, vals, parsed, remaining):
51 | if not key and not vals:
52 | return
53 | if not key:
54 | vals = ', '.join(f"'{x}'" for x in vals)
55 | raise ValueError(f"Values {vals} were not preceded by any flag.")
56 | name = key[len('--'):]
57 | if '=' in name:
58 | remaining.extend([key] + vals)
59 | return
60 | if self._config.IS_PATTERN.fullmatch(name):
61 | pattern = re.compile(name)
62 | keys = {k for k in self._config.flat if pattern.fullmatch(k)}
63 | elif name in self._config:
64 | keys = [name]
65 | else:
66 | keys = []
67 | if not keys:
68 | remaining.extend([key] + vals)
69 | return
70 | if not vals:
71 | raise ValueError(f"Flag '{key}' was not followed by any values.")
72 | for key in keys:
73 | parsed[key] = self._parse_flag_value(self._config[key], vals, key)
74 |
75 | def _parse_flag_value(self, default, value, key):
76 | value = value if isinstance(value, (tuple, list)) else (value,)
77 | if isinstance(default, (tuple, list)):
78 | if len(value) == 1 and ',' in value[0]:
79 | value = value[0].split(',')
80 | return tuple(self._parse_flag_value(default[0], [x], key) for x in value)
81 | assert len(value) == 1, value
82 | value = str(value[0])
83 | if default is None:
84 | return value
85 | if isinstance(default, bool):
86 | try:
87 | return bool(['False', 'True'].index(value))
88 | except ValueError:
89 | message = f"Expected bool but got '{value}' for key '{key}'."
90 | raise TypeError(message)
91 | if isinstance(default, int):
92 | try:
93 | value = float(value) # Allow scientific notation for integers.
94 | assert float(int(value)) == value
95 | except (TypeError, ValueError, OverflowError, AssertionError):
96 | message = f"Expected int but got float '{value}' for key '{key}'."
97 | raise TypeError(message)
98 | return int(value)
99 | if isinstance(default, dict):
100 | raise TypeError(
101 | f"Key '{key}' refers to a whole dict. Please specify a subkey.")
102 | return type(default)(value)
103 |
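The integer branch of `_parse_flag_value` above goes through `float` so that values such as `1e6` are accepted for integer keys. A minimal standalone sketch of that idea follows; the function name `parse_int_flag` and the default key name are hypothetical, and this is a simplified stand-in rather than the repository's code.

```python
def parse_int_flag(text, key='steps'):
  """Parse a string as an integer, accepting forms such as '1e6'."""
  try:
    number = float(text)  # Accepts '32', '2.0' and '1e6'.
  except ValueError:
    raise TypeError(f"Expected int but got '{text}' for key '{key}'.")
  if not number.is_integer():  # Also rejects 'inf' and 'nan'.
    raise TypeError(f"Expected int but got float '{text}' for key '{key}'.")
  return int(number)


print(parse_int_flag('1e6'))
print(parse_int_flag('32'))
```

Because the value passes through a float, integers beyond 2**53 may lose precision in this sketch.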
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/metrics.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import warnings
3 |
4 | import numpy as np
5 |
6 |
7 | class Metrics:
8 |
9 | def __init__(self):
10 | self._scalars = collections.defaultdict(list)
11 | self._lasts = {}
12 |
13 | def scalar(self, key, value):
14 | self._scalars[key].append(value)
15 |
16 | def image(self, key, value):
17 | self._lasts[key] = value
18 |
19 | def video(self, key, value):
20 | self._lasts[key] = value
21 |
22 | def add(self, mapping, prefix=None):
23 | for key, value in mapping.items():
24 | key = prefix + '/' + key if prefix else key
25 | if hasattr(value, 'shape') and len(value.shape) > 0:
26 | self._lasts[key] = value
27 | else:
28 | self._scalars[key].append(value)
29 |
30 | def result(self, reset=True):
31 | result = {}
32 | result.update(self._lasts)
33 | with warnings.catch_warnings(): # Ignore empty slice warnings.
34 | warnings.simplefilter('ignore', category=RuntimeWarning)
35 | for key, values in self._scalars.items():
36 | result[key] = np.nanmean(values, dtype=np.float64)
37 | reset and self.reset()
38 | return result
39 |
40 | def reset(self):
41 | self._scalars.clear()
42 | self._lasts.clear()
43 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/parallel.py:
--------------------------------------------------------------------------------
1 | import enum
2 | from functools import partial as bind
3 |
4 | from . import worker
5 |
6 |
7 | class Parallel:
8 |
9 | def __init__(self, ctor, strategy):
10 | self.worker = worker.Worker(
11 | bind(self._respond, ctor), strategy, state=True)
12 | self.callables = {}
13 |
14 | def __getattr__(self, name):
15 | if name.startswith('_'):
16 | raise AttributeError(name)
17 | try:
18 | if name not in self.callables:
19 | self.callables[name] = self.worker(Message.CALLABLE, name)()
20 | if self.callables[name]:
21 | return bind(self.worker, Message.CALL, name)
22 | else:
23 | return self.worker(Message.READ, name)()
24 | except AttributeError:
25 | raise ValueError(name)
26 |
27 | def __len__(self):
28 | return self.worker(Message.CALL, '__len__')()
29 |
30 | def close(self):
31 | self.worker.close()
32 |
33 | @staticmethod
34 | def _respond(ctor, state, message, name, *args, **kwargs):
35 | state = state or ctor()
36 | if message == Message.CALLABLE:
37 | assert not args and not kwargs, (args, kwargs)
38 | result = callable(getattr(state, name))
39 | elif message == Message.CALL:
40 | result = getattr(state, name)(*args, **kwargs)
41 | elif message == Message.READ:
42 | assert not args and not kwargs, (args, kwargs)
43 | result = getattr(state, name)
44 | return state, result
45 |
46 |
47 | class Message(enum.Enum):
48 |
49 | CALLABLE = 2
50 | CALL = 3
51 | READ = 4
52 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/path.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import glob
3 | import os
4 | import re
5 | import shutil
6 |
7 |
8 | class Path:
9 |
10 | filesystems = []
11 |
12 | def __new__(cls, path):
13 | path = str(path)
14 | for impl, pred in cls.filesystems:
15 | if pred(path):
16 | obj = super().__new__(impl)
17 | obj.__init__(path)
18 | return obj
19 | raise NotImplementedError(f'No filesystem supports: {path}')
20 |
21 | def __getnewargs__(self):
22 | return (self._path,)
23 |
24 | def __init__(self, path):
25 | assert isinstance(path, str)
26 | path = re.sub(r'^\./*', '', path) # Remove leading dot or dot slashes.
27 | path = re.sub(r'(?<=[^/])/$', '', path) # Remove single trailing slash.
28 | path = path or '.' # Empty path is represented by a dot.
29 | self._path = path
30 |
31 | def __truediv__(self, part):
32 | sep = '' if self._path.endswith('/') else '/'
33 | return type(self)(f'{self._path}{sep}{str(part)}')
34 |
35 | def __repr__(self):
36 | return f'Path({str(self)})'
37 |
38 | def __fspath__(self):
39 | return str(self)
40 |
41 | def __eq__(self, other):
42 | return self._path == other._path
43 |
44 | def __lt__(self, other):
45 | return self._path < other._path
46 |
47 | def __str__(self):
48 | return self._path
49 |
50 | @property
51 | def parent(self):
52 | if '/' not in self._path:
53 | return type(self)('.')
54 | parent = self._path.rsplit('/', 1)[0]
55 | parent = parent or ('/' if self._path.startswith('/') else '.')
56 | return type(self)(parent)
57 |
58 | @property
59 | def name(self):
60 | if '/' not in self._path:
61 | return self._path
62 | return self._path.rsplit('/', 1)[1]
63 |
64 | @property
65 | def stem(self):
66 | return self.name.split('.', 1)[0] if '.' in self.name else self.name
67 |
68 | @property
69 | def suffix(self):
70 | return ('.' + self.name.split('.', 1)[1]) if '.' in self.name else ''
71 |
72 | def read(self, mode='r'):
73 | assert mode in 'r rb'.split(), mode
74 | with self.open(mode) as f:
75 | return f.read()
76 |
77 | def write(self, content, mode='w'):
78 | assert mode in 'w a wb ab'.split(), mode
79 | with self.open(mode) as f:
80 | f.write(content)
81 |
82 | @contextlib.contextmanager
83 | def open(self, mode='r'):
84 | raise NotImplementedError
85 |
86 | def absolute(self):
87 | raise NotImplementedError
88 |
89 | def glob(self, pattern):
90 | raise NotImplementedError
91 |
92 | def exists(self):
93 | raise NotImplementedError
94 |
95 | def isfile(self):
96 | raise NotImplementedError
97 |
98 | def isdir(self):
99 | raise NotImplementedError
100 |
101 | def mkdirs(self):
102 | raise NotImplementedError
103 |
104 | def remove(self):
105 | raise NotImplementedError
106 |
107 | def rmtree(self):
108 | raise NotImplementedError
109 |
110 | def copy(self, dest):
111 | raise NotImplementedError
112 |
113 | def move(self, dest):
114 | self.copy(dest)
115 | self.remove()
116 |
117 |
118 | class LocalPath(Path):
119 |
120 | def __init__(self, path):
121 | super().__init__(os.path.expanduser(str(path)))
122 |
123 | @contextlib.contextmanager
124 | def open(self, mode='r'):
125 | with open(str(self), mode=mode) as f:
126 | yield f
127 |
128 | def absolute(self):
129 | return type(self)(os.path.abspath(str(self)))
130 |
131 | def glob(self, pattern):
132 | for path in glob.glob(f'{str(self)}/{pattern}'):
133 | yield type(self)(path)
134 |
135 | def exists(self):
136 | return os.path.exists(str(self))
137 |
138 | def isfile(self):
139 | return os.path.isfile(str(self))
140 |
141 | def isdir(self):
142 | return os.path.isdir(str(self))
143 |
144 | def mkdirs(self):
145 | os.makedirs(str(self), exist_ok=True)
146 |
147 | def remove(self):
148 | os.rmdir(str(self)) if self.isdir() else os.remove(str(self))
149 |
150 | def rmtree(self):
151 | shutil.rmtree(self)
152 |
153 | def copy(self, dest):
154 | if self.isfile():
155 | shutil.copy(self, type(self)(dest))
156 | else:
157 | shutil.copytree(self, type(self)(dest), dirs_exist_ok=True)
158 |
159 | def move(self, dest):
160 | shutil.move(self, dest)
161 |
162 |
163 | class GFilePath(Path):
164 |
165 | def __init__(self, path):
166 | path = str(path)
167 | if not (path.startswith('/') or '://' in path):
168 | path = os.path.abspath(os.path.expanduser(path))
169 | super().__init__(path)
170 | import tensorflow as tf
171 | self._gfile = tf.io.gfile
172 |
173 | @contextlib.contextmanager
174 | def open(self, mode='r'):
175 | path = str(self)
176 | if 'a' in mode and path.startswith('/cns/'):
177 | path += '%r=3.2'
178 | if mode.startswith('x') and self.exists():
179 | raise FileExistsError(path)
180 | mode = mode.replace('x', 'w')
181 | with self._gfile.GFile(path, mode) as f:
182 | yield f
183 |
184 | def absolute(self):
185 | return self
186 |
187 | def glob(self, pattern):
188 | for path in self._gfile.glob(f'{str(self)}/{pattern}'):
189 | yield type(self)(path)
190 |
191 | def exists(self):
192 | return self._gfile.exists(str(self))
193 |
194 | def isfile(self):
195 | return self.exists() and not self.isdir()
196 |
197 | def isdir(self):
198 | return self._gfile.isdir(str(self))
199 |
200 | def mkdirs(self):
201 | self._gfile.makedirs(str(self))
202 |
203 | def remove(self):
204 | self._gfile.remove(str(self))
205 |
206 | def rmtree(self):
207 | self._gfile.rmtree(str(self))
208 |
209 | def copy(self, dest):
210 | self._gfile.copy(str(self), str(dest), overwrite=True)
211 |
212 | def move(self, dest):
213 | dest = Path(dest)
214 | if dest.isdir():
215 | dest.rmtree()
216 | self._gfile.rename(str(self), str(dest), overwrite=True)
217 |
218 |
219 | Path.filesystems = [
220 | (GFilePath, lambda path: path.startswith('gs://')),
221 | (GFilePath, lambda path: path.startswith('/cns/')),
222 | (LocalPath, lambda path: True),
223 | ]
224 |
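`Path.__new__` above selects a concrete subclass by testing the path string against an ordered list of `(implementation, predicate)` pairs. The sketch below illustrates that dispatch pattern in isolation; the names `Store`, `RemoteStore`, `LocalStore`, and `backends` are hypothetical and do not exist in the repository.

```python
class Store:
  """Base class whose constructor returns the first matching subclass."""

  backends = []  # (implementation, predicate) pairs, checked in order.

  def __new__(cls, path):
    path = str(path)
    for impl, pred in cls.backends:
      if pred(path):
        # Allocate an instance of the chosen subclass. Python then calls
        # __init__ on it because it is an instance of Store.
        return super().__new__(impl)
    raise NotImplementedError(f'No backend supports: {path}')

  def __init__(self, path):
    self.path = str(path)


class RemoteStore(Store):
  pass


class LocalStore(Store):
  pass


Store.backends = [
    (RemoteStore, lambda path: path.startswith('gs://')),
    (LocalStore, lambda path: True),  # Fallback that matches everything.
]

print(type(Store('gs://bucket/run')).__name__)
print(type(Store('/tmp/run')).__name__)
```

The sketch relies on Python calling `__init__` automatically, which only happens when the returned object is an instance of the class being called, so it is only exercised here through `Store(...)`.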
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/random.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class RandomAgent:
5 |
6 | def __init__(self, act_space):
7 | self.act_space = act_space
8 |
9 | def policy(self, obs, state=None, mode='train'):
10 | batch_size = len(next(iter(obs.values())))
11 | act = {
12 | k: np.stack([v.sample() for _ in range(batch_size)])
13 | for k, v in self.act_space.items() if k != 'reset'}
14 | return act, state
15 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/space.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Space:
5 |
6 | def __init__(self, dtype, shape=(), low=None, high=None):
7 | # For integer types, high is the exclusive upper bound.
8 | shape = (shape,) if isinstance(shape, int) else shape
9 | self._dtype = np.dtype(dtype)
10 | assert self._dtype != object, self._dtype
11 | assert isinstance(shape, tuple), shape
12 | self._low = self._infer_low(dtype, shape, low, high)
13 | self._high = self._infer_high(dtype, shape, low, high)
14 | self._shape = self._infer_shape(dtype, shape, low, high)
15 | self._discrete = (
16 | np.issubdtype(self.dtype, np.integer) or self.dtype == bool)
17 | self._random = np.random.RandomState()
18 |
19 | @property
20 | def dtype(self):
21 | return self._dtype
22 |
23 | @property
24 | def shape(self):
25 | return self._shape
26 |
27 | @property
28 | def low(self):
29 | return self._low
30 |
31 | @property
32 | def high(self):
33 | return self._high
34 |
35 | @property
36 | def discrete(self):
37 | return self._discrete
38 |
39 | def __repr__(self):
40 | return (
41 | f'Space(dtype={self.dtype.name}, '
42 | f'shape={self.shape}, '
43 | f'low={self.low.min()}, '
44 | f'high={self.high.max()})')
45 |
46 | def __contains__(self, value):
47 | value = np.asarray(value)
48 | if value.shape != self.shape:
49 | return False
50 | if (value > self.high).any():
51 | return False
52 | if (value < self.low).any():
53 | return False
54 | if (value.astype(self.dtype).astype(value.dtype) != value).any():
55 | return False
56 | return True
57 |
58 | def sample(self):
59 | low, high = self.low, self.high
60 | if np.issubdtype(self.dtype, np.floating):
61 | low = np.maximum(np.ones(self.shape) * np.finfo(self.dtype).min, low)
62 | high = np.minimum(np.ones(self.shape) * np.finfo(self.dtype).max, high)
63 | return self._random.uniform(low, high, self.shape).astype(self.dtype)
64 |
65 | def _infer_low(self, dtype, shape, low, high):
66 | if low is not None:
67 | try:
68 | return np.broadcast_to(low, shape)
69 | except ValueError:
70 | raise ValueError(f'Cannot broadcast {low} to shape {shape}')
71 | elif np.issubdtype(dtype, np.floating):
72 | return -np.inf * np.ones(shape)
73 | elif np.issubdtype(dtype, np.integer):
74 | return np.iinfo(dtype).min * np.ones(shape, dtype)
75 | elif np.issubdtype(dtype, bool):
76 | return np.zeros(shape, bool)
77 | else:
78 | raise ValueError('Cannot infer low bound from shape and dtype.')
79 |
80 | def _infer_high(self, dtype, shape, low, high):
81 | if high is not None:
82 | try:
83 | return np.broadcast_to(high, shape)
84 | except ValueError:
85 | raise ValueError(f'Cannot broadcast {high} to shape {shape}')
86 | elif np.issubdtype(dtype, np.floating):
87 | return np.inf * np.ones(shape)
88 | elif np.issubdtype(dtype, np.integer):
89 | return np.iinfo(dtype).max * np.ones(shape, dtype)
90 | elif np.issubdtype(dtype, bool):
91 | return np.ones(shape, bool)
92 | else:
93 | raise ValueError('Cannot infer high bound from shape and dtype.')
94 |
95 | def _infer_shape(self, dtype, shape, low, high):
96 | if shape is None and low is not None:
97 | shape = low.shape
98 | if shape is None and high is not None:
99 | shape = high.shape
100 | if not hasattr(shape, '__len__'):
101 | shape = (shape,)
102 | assert all(dim and dim > 0 for dim in shape), shape
103 | return tuple(shape)
104 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/timer.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import contextlib
3 | import time
4 |
5 | import numpy as np
6 |
7 |
8 | class Timer:
9 |
10 | def __init__(self, columns=('frac', 'min', 'avg', 'max', 'count', 'total')):
11 | available = ('frac', 'avg', 'min', 'max', 'count', 'total')
12 | assert all(x in available for x in columns), columns
13 | self._columns = columns
14 | self._durations = collections.defaultdict(list)
15 | self._start = time.time()
16 |
17 | def reset(self):
18 | for timings in self._durations.values():
19 | timings.clear()
20 | self._start = time.time()
21 |
22 | @contextlib.contextmanager
23 | def scope(self, name):
24 | start = time.time()
25 | yield
26 | stop = time.time()
27 | self._durations[name].append(stop - start)
28 |
29 | def wrap(self, name, obj, methods):
30 | for method in methods:
31 | decorator = self.scope(f'{name}.{method}')
32 | setattr(obj, method, decorator(getattr(obj, method)))
33 |
34 | def stats(self, reset=True, log=False):
35 | metrics = {}
36 | metrics['duration'] = time.time() - self._start
37 | for name, durs in self._durations.items():
38 | available = {}
39 | available['count'] = len(durs)
40 | available['total'] = np.sum(durs)
41 | available['frac'] = np.sum(durs) / metrics['duration']
42 | if len(durs):
43 | available['avg'] = np.mean(durs)
44 | available['min'] = np.min(durs)
45 | available['max'] = np.max(durs)
46 | for key, value in available.items():
47 | if key in self._columns:
48 | metrics[f'{name}_{key}'] = value
49 | if log:
50 | self._log(metrics)
51 | if reset:
52 | self.reset()
53 | return metrics
54 |
55 | def _log(self, metrics):
56 | names = self._durations.keys()
57 | names = sorted(names, key=lambda k: -metrics[f'{k}_frac'])
58 | print('Timer:'.ljust(20), ' '.join(x.rjust(8) for x in self._columns))
59 | for name in names:
60 | values = [metrics[f'{name}_{col}'] for col in self._columns]
61 | print(f'{name.ljust(20)}', ' '.join((f'{x:8.4f}' for x in values)))
62 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/uuid.py:
--------------------------------------------------------------------------------
1 | import string
2 | import uuid as uuidlib
3 |
4 | import numpy as np
5 |
6 |
7 | class uuid:
8 | """UUID that is stored as 16 byte string and can be converted to and from
9 | int, string, and array types."""
10 |
11 | DEBUG_ID = None
12 | BASE62 = string.digits + string.ascii_letters
13 | BASE62REV = {x: i for i, x in enumerate(BASE62)}
14 |
15 | @classmethod
16 | def reset(cls, *, debug):
17 | cls.DEBUG_ID = 0 if debug else None
18 |
19 | def __init__(self, value=None):
20 | if value is None:
21 | if self.DEBUG_ID is None:
22 | self.value = uuidlib.uuid4().bytes
23 | else:
24 | type(self).DEBUG_ID += 1
25 | self.value = self.DEBUG_ID.to_bytes(16, 'big')
26 | elif isinstance(value, uuid):
27 | self.value = value.value
28 | elif isinstance(value, int):
29 | self.value = value.to_bytes(16, 'big')
30 | elif isinstance(value, str):
31 | if self.DEBUG_ID is None:
32 | integer = 0
33 | for index, char in enumerate(value[::-1]):
34 | integer += (62 ** index) * self.BASE62REV[char]
35 | self.value = integer.to_bytes(16, 'big')
36 | else:
37 | self.value = int(value).to_bytes(16, 'big')
38 | elif isinstance(value, np.ndarray):
39 | self.value = value.tobytes()
40 | else:
41 | raise ValueError(value)
42 | assert type(self.value) == bytes, type(self.value)
43 | assert len(self.value) == 16, len(self.value)
44 | self._hash = hash(self.value)
45 |
46 | def __int__(self):
47 | return int.from_bytes(self.value, 'big')
48 |
49 | def __str__(self):
50 | if self.DEBUG_ID is not None:
51 | return str(int(self))
52 | chars = []
53 | integer = int(self)
54 | while integer != 0:
55 | chars.append(self.BASE62[integer % 62])
56 | integer //= 62
57 | while len(chars) < 22:
58 | chars.append('0')
59 | return ''.join(chars[::-1])
60 |
61 | def __array__(self):
62 | return np.frombuffer(self.value, np.uint8)
63 |
64 | def __getitem__(self, index):
65 | return self.__array__()[index]
66 |
67 | def __repr__(self):
68 | return str(self)
69 |
70 | def __eq__(self, other):
71 | return self.value == other.value
72 |
73 | def __hash__(self):
74 | return self._hash
75 |
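The `uuid` class above stores 16 raw bytes and renders them as a zero-padded, 22-character base-62 string. A standalone sketch of that round trip is shown below, using only the standard library; the helper names `encode` and `decode` are hypothetical.

```python
import string
import uuid as uuidlib

BASE62 = string.digits + string.ascii_letters
BASE62REV = {char: index for index, char in enumerate(BASE62)}


def encode(raw):
  """Encode 16 raw bytes as a zero-padded, 22-character base-62 string."""
  integer = int.from_bytes(raw, 'big')
  chars = []
  while integer:
    integer, digit = divmod(integer, 62)
    chars.append(BASE62[digit])
  # Digits were produced least significant first, so reverse before padding.
  return ''.join(reversed(chars)).rjust(22, '0')


def decode(text):
  """Invert encode() and return the original 16 bytes."""
  integer = 0
  for char in text:
    integer = integer * 62 + BASE62REV[char]
  return integer.to_bytes(16, 'big')


raw = uuidlib.uuid4().bytes
text = encode(raw)
assert len(text) == 22 and decode(text) == raw
```

Twenty-two characters are enough because 62**22 exceeds 2**128, while 62**21 does not.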
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/when.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | class Every:
5 |
6 | def __init__(self, every, initial=True):
7 | self._every = every
8 | self._initial = initial
9 | self._prev = None
10 |
11 | def __call__(self, step):
12 | step = int(step)
13 | if self._every < 0:
14 | return True
15 | if self._every == 0:
16 | return False
17 | if self._prev is None:
18 | self._prev = (step // self._every) * self._every
19 | return self._initial
20 | if step >= self._prev + self._every:
21 | self._prev += self._every
22 | return True
23 | return False
24 |
25 |
26 | class Ratio:
27 |
28 | def __init__(self, ratio):
29 | assert ratio >= 0, ratio
30 | self._ratio = ratio
31 | self._prev = None
32 |
33 | def __call__(self, step):
34 | step = int(step)
35 | if self._ratio == 0:
36 | return 0
37 | if self._prev is None:
38 | self._prev = step
39 | return 1
40 | repeats = int((step - self._prev) * self._ratio)
41 | self._prev += repeats / self._ratio
42 | return repeats
43 |
44 |
45 | class Once:
46 |
47 | def __init__(self):
48 | self._once = True
49 |
50 | def __call__(self):
51 | if self._once:
52 | self._once = False
53 | return True
54 | return False
55 |
56 |
57 | class Until:
58 |
59 | def __init__(self, until):
60 | self._until = until
61 |
62 | def __call__(self, step):
63 | step = int(step)
64 | if not self._until:
65 | return True
66 | return step < self._until
67 |
68 |
69 | class Clock:
70 |
71 | def __init__(self, every):
72 | self._every = every
73 | self._prev = None
74 |
75 | def __call__(self, step=None):
76 | if self._every < 0:
77 | return True
78 | if self._every == 0:
79 | return False
80 | now = time.time()
81 | if self._prev is None:
82 | self._prev = now
83 | return True
84 | if now >= self._prev + self._every:
85 | # self._prev += self._every
86 | self._prev = now
87 | return True
88 | return False
89 |
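`Ratio` above converts a step counter into a number of repeats, carrying the fractional remainder forward so the long-run average matches the requested ratio. The sketch below is a standalone copy of that logic with a small usage example; the variable name `should_train` and the ratio of 0.5 are illustrative assumptions.

```python
class Ratio:
  """Standalone copy of the Ratio logic above, for illustration only."""

  def __init__(self, ratio):
    assert ratio >= 0, ratio
    self._ratio = ratio
    self._prev = None

  def __call__(self, step):
    step = int(step)
    if self._ratio == 0:
      return 0
    if self._prev is None:
      self._prev = step
      return 1  # The first call always requests one repeat.
    repeats = int((step - self._prev) * self._ratio)
    self._prev += repeats / self._ratio  # Keep the unused fraction.
    return repeats


should_train = Ratio(0.5)  # One repeat for every two steps.
counts = [should_train(step) for step in range(7)]
print(counts)  # [1, 0, 1, 0, 1, 0, 1]
```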
--------------------------------------------------------------------------------
/recall2imagine/embodied/core/worker.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import concurrent.futures
3 | import enum
4 | import os
5 | import sys
6 | import time
7 | import traceback
8 | from functools import partial as bind
9 |
10 |
11 | class Worker:
12 |
13 | initializers = []
14 |
15 | def __init__(self, fn, strategy='thread', state=False):
16 | if not state:
17 | fn = lambda s, *args, fn=fn, **kwargs: (s, fn(*args, **kwargs))
18 | inits = self.initializers
19 | self.impl = {
20 | 'blocking': BlockingWorker,
21 | 'thread': ThreadWorker,
22 | 'process': bind(ProcessPipeWorker, initializers=inits),
23 | 'daemon': bind(ProcessPipeWorker, initializers=inits, daemon=True),
24 | 'process_slow': bind(ProcessWorker, initializers=inits),
25 | }[strategy](fn)
26 | self.promise = None
27 |
28 | def __call__(self, *args, **kwargs):
29 | self.promise and self.promise() # Raise previous exception if any.
30 | self.promise = self.impl(*args, **kwargs)
31 | return self.promise
32 |
33 | def wait(self):
34 | return self.impl.wait()
35 |
36 | def close(self):
37 | self.impl.close()
38 |
39 |
40 | class BlockingWorker:
41 |
42 | def __init__(self, fn):
43 | self.fn = fn
44 | self.state = None
45 |
46 | def __call__(self, *args, **kwargs):
47 | self.state, result = self.fn(self.state, *args, **kwargs)
48 | # return lambda: result
49 | return lambda result=result: result
50 |
51 | def wait(self):
52 | pass
53 |
54 | def close(self):
55 | pass
56 |
57 |
58 | class ThreadWorker:
59 |
60 | def __init__(self, fn):
61 | self.fn = fn
62 | self.state = None
63 | self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
64 | self.futures = []
65 |
66 | def __call__(self, *args, **kwargs):
67 | future = self.executor.submit(self._worker, *args, **kwargs)
68 | self.futures.append(future)
69 | future.add_done_callback(lambda f: self.futures.remove(f))
70 | return future.result
71 |
72 | def wait(self):
73 | concurrent.futures.wait(self.futures)
74 |
75 | def close(self):
76 | self.executor.shutdown(wait=False, cancel_futures=True)
77 |
78 | def _worker(self, *args, **kwargs):
79 | self.state, output = self.fn(self.state, *args, **kwargs)
80 | return output
81 |
82 |
83 | class ProcessWorker:
84 |
85 | def __init__(self, fn, initializers=()):
86 | import cloudpickle
87 | import multiprocessing
88 | fn = cloudpickle.dumps(fn)
89 | initializers = cloudpickle.dumps(initializers)
90 | self.executor = concurrent.futures.ProcessPoolExecutor(
91 | max_workers=1, mp_context=multiprocessing.get_context('spawn'),
92 | initializer=self._initializer, initargs=(fn, initializers))
93 | self.futures = []
94 |
95 | def __call__(self, *args, **kwargs):
96 | future = self.executor.submit(self._worker, *args, **kwargs)
97 | self.futures.append(future)
98 | future.add_done_callback(lambda f: self.futures.remove(f))
99 | return future.result
100 |
101 | def wait(self):
102 | concurrent.futures.wait(self.futures)
103 |
104 | def close(self):
105 | self.executor.shutdown(wait=False, cancel_futures=True)
106 |
107 | @staticmethod
108 | def _initializer(fn, initializers):
109 | global _FN, _STATE
110 | import cloudpickle
111 | _FN = cloudpickle.loads(fn)
112 | _STATE = None
113 | for initializer in cloudpickle.loads(initializers):
114 | initializer()
115 |
116 | @staticmethod
117 | def _worker(*args, **kwargs):
118 | global _FN, _STATE
119 | _STATE, output = _FN(_STATE, *args, **kwargs)
120 | return output
121 |
122 |
123 | class ProcessPipeWorker:
124 |
125 | def __init__(self, fn, initializers=(), daemon=False):
126 | import multiprocessing
127 | import cloudpickle
128 | self._context = multiprocessing.get_context('spawn')
129 | self._pipe, pipe = self._context.Pipe()
130 | fn = cloudpickle.dumps(fn)
131 | initializers = cloudpickle.dumps(initializers)
132 | self._process = self._context.Process(
133 | target=self._loop,
134 | args=(pipe, fn, initializers),
135 | daemon=daemon)
136 | self._process.start()
137 | self._nextid = 0
138 | self._results = {}
139 | assert self._submit(Message.OK)()
140 | atexit.register(self.close)
141 |
142 | def __call__(self, *args, **kwargs):
143 | return self._submit(Message.RUN, (args, kwargs))
144 |
145 | def wait(self):
146 | pass
147 |
148 | def close(self):
149 | try:
150 | self._pipe.send((Message.STOP, self._nextid, None))
151 | self._pipe.close()
152 | except (AttributeError, IOError):
153 | pass # The connection was already closed.
154 | try:
155 | self._process.join(0.1)
156 | if self._process.exitcode is None:
157 | try:
158 | os.kill(self._process.pid, 9)
159 | time.sleep(0.1)
160 | except Exception:
161 | pass
162 | except (AttributeError, AssertionError):
163 | pass
164 |
165 | def _submit(self, message, payload=None):
166 | callid = self._nextid
167 | self._nextid += 1
168 | self._pipe.send((message, callid, payload))
169 | return Future(self._receive, callid)
170 |
171 | def _receive(self, callid):
172 | while callid not in self._results:
173 | try:
174 | message, callid, payload = self._pipe.recv()
175 | except (OSError, EOFError):
176 | raise RuntimeError('Lost connection to worker.')
177 | if message == Message.ERROR:
178 | raise Exception(payload)
179 | assert message == Message.RESULT, message
180 | self._results[callid] = payload
181 | return self._results.pop(callid)
182 |
183 | @staticmethod
184 | def _loop(pipe, function, initializers):
185 | try:
186 | callid = None
187 | state = None
188 | import cloudpickle
189 | initializers = cloudpickle.loads(initializers)
190 | function = cloudpickle.loads(function)
191 | [fn() for fn in initializers]
192 | while True:
193 | if not pipe.poll(0.1):
194 | continue # Wake up for keyboard interrupts.
195 | message, callid, payload = pipe.recv()
196 | if message == Message.OK:
197 | pipe.send((Message.RESULT, callid, True))
198 | elif message == Message.STOP:
199 | return
200 | elif message == Message.RUN:
201 | args, kwargs = payload
202 | state, result = function(state, *args, **kwargs)
203 | pipe.send((Message.RESULT, callid, result))
204 | else:
205 | raise KeyError(f'Invalid message: {message}')
206 | except (EOFError, KeyboardInterrupt):
207 | return
208 | except Exception:
209 | stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
210 | print(f'Error inside process worker: {stacktrace}.', flush=True)
211 | pipe.send((Message.ERROR, callid, stacktrace))
212 | return
213 | finally:
214 | try:
215 | pipe.close()
216 | except Exception:
217 | pass
218 |
219 |
220 | class Future:
221 |
222 | def __init__(self, receive, callid):
223 | self._receive = receive
224 | self._callid = callid
225 | self._result = None
226 | self._complete = False
227 |
228 | def __call__(self):
229 | if not self._complete:
230 | self._result = self._receive(self._callid)
231 | self._complete = True
232 | return self._result
233 |
234 |
235 | class Message(enum.Enum):
236 |
237 | OK = 1
238 | RUN = 2
239 | RESULT = 3
240 | STOP = 4
241 | ERROR = 5
242 |
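Every worker implementation above expects a function of the form `fn(state, *args, **kwargs) -> (state, result)` and returns a promise that yields the result when called. The sketch below shows that convention with a simplified thread-based worker; `StatefulThreadWorker` and `running_total` are hypothetical names, and the class omits the future tracking and error handling of the real `ThreadWorker`.

```python
import concurrent.futures


class StatefulThreadWorker:
  """Simplified stand-in for the ThreadWorker above."""

  def __init__(self, fn):
    self.fn = fn
    self.state = None
    self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

  def __call__(self, *args, **kwargs):
    future = self.executor.submit(self._run, *args, **kwargs)
    return future.result  # Calling the promise blocks until it is done.

  def _run(self, *args, **kwargs):
    # The function receives the previous state and returns the new one.
    self.state, output = self.fn(self.state, *args, **kwargs)
    return output

  def close(self):
    self.executor.shutdown(wait=True)


def running_total(state, value):
  state = (state or 0) + value
  return state, state


worker = StatefulThreadWorker(running_total)
promises = [worker(value) for value in (1, 2, 3)]
totals = [promise() for promise in promises]
worker.close()
print(totals)  # [1, 3, 6]
```

With a single executor thread the calls run in submission order, which is what keeps the state updates sequential in this sketch.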
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/atari.py:
--------------------------------------------------------------------------------
1 | import embodied
2 | import numpy as np
3 |
4 |
5 | class Atari(embodied.Env):
6 |
7 | LOCK = None
8 |
9 | def __init__(
10 | self, name, repeat=4, size=(84, 84), gray=True, noops=0, lives='unused',
11 | sticky=True, actions='all', length=108000, resize='opencv', seed=None):
12 | assert size[0] == size[1]
13 | assert lives in ('unused', 'discount', 'reset'), lives
14 | assert actions in ('all', 'needed'), actions
15 | assert resize in ('opencv', 'pillow'), resize
16 | if self.LOCK is None:
17 | import multiprocessing as mp
18 | mp = mp.get_context('spawn')
19 | self.LOCK = mp.Lock()
20 | self._resize = resize
21 | if self._resize == 'opencv':
22 | import cv2
23 | self._cv2 = cv2
24 | if self._resize == 'pillow':
25 | from PIL import Image
26 | self._image = Image
27 | import gym.envs.atari
28 | if name == 'james_bond':
29 | name = 'jamesbond'
30 | self._repeat = repeat
31 | self._size = size
32 | self._gray = gray
33 | self._noops = noops
34 | self._lives = lives
35 | self._sticky = sticky
36 | self._length = length
37 | self._random = np.random.RandomState(seed)
38 | with self.LOCK:
39 | self._env = gym.envs.atari.AtariEnv(
40 | game=name,
41 | obs_type='image',
42 | frameskip=1, repeat_action_probability=0.25 if sticky else 0.0,
43 | full_action_space=(actions == 'all'))
44 | assert self._env.unwrapped.get_action_meanings()[0] == 'NOOP'
45 | shape = self._env.observation_space.shape
46 | self._buffer = [np.zeros(shape, np.uint8) for _ in range(2)]
47 | self._ale = self._env.unwrapped.ale
48 | self._last_lives = None
49 | self._done = True
50 | self._step = 0
51 |
52 | @property
53 | def obs_space(self):
54 | shape = self._size + (1 if self._gray else 3,)
55 | return {
56 | 'image': embodied.Space(np.uint8, shape),
57 | 'reward': embodied.Space(np.float32),
58 | 'is_first': embodied.Space(bool),
59 | 'is_last': embodied.Space(bool),
60 | 'is_terminal': embodied.Space(bool),
61 | }
62 |
63 | @property
64 | def act_space(self):
65 | return {
66 | 'action': embodied.Space(np.int32, (), 0, self._env.action_space.n),
67 | 'reset': embodied.Space(bool),
68 | }
69 |
70 | def step(self, action):
71 | if action['reset'] or self._done:
72 | with self.LOCK:
73 | self._reset()
74 | self._done = False
75 | self._step = 0
76 | return self._obs(0.0, is_first=True)
77 | total = 0.0
78 | dead = False
79 | for repeat in range(self._repeat):
80 | _, reward, over, info = self._env.step(action['action'])
81 | self._step += 1
82 | total += reward
83 | if repeat == self._repeat - 2:
84 | self._screen(self._buffer[1])
85 | if over:
86 | break
87 | if self._lives != 'unused':
88 | current = self._ale.lives()
89 | if current < self._last_lives:
90 | dead = True
91 | self._last_lives = current
92 | break
93 | if not self._repeat:
94 | self._buffer[1][:] = self._buffer[0][:]
95 | self._screen(self._buffer[0])
96 | self._done = over or (self._length and self._step >= self._length)
97 | return self._obs(
98 | total,
99 | is_last=self._done or (dead and self._lives == 'reset'),
100 | is_terminal=dead or over)
101 |
102 | def _reset(self):
103 | self._env.reset()
104 | if self._noops:
105 | for _ in range(self._random.randint(self._noops)):
106 | _, _, dead, _ = self._env.step(0)
107 | if dead:
108 | self._env.reset()
109 | self._last_lives = self._ale.lives()
110 | self._screen(self._buffer[0])
111 | self._buffer[1].fill(0)
112 |
113 | def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
114 | np.maximum(self._buffer[0], self._buffer[1], out=self._buffer[0])
115 | image = self._buffer[0]
116 | if image.shape[:2] != self._size:
117 | if self._resize == 'opencv':
118 | image = self._cv2.resize(
119 | image, self._size, interpolation=self._cv2.INTER_AREA)
120 | if self._resize == 'pillow':
121 | image = self._image.fromarray(image)
122 | image = image.resize(self._size, self._image.NEAREST)
123 | image = np.array(image)
124 | if self._gray:
125 | weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
126 | image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
127 | image = image[:, :, None]
128 | return dict(
129 | image=image,
130 | reward=reward,
131 | is_first=is_first,
132 | is_last=is_last,
133 | is_terminal=is_last,
134 | )
135 |
136 | def _screen(self, array):
137 | self._ale.getScreenRGB2(array)
138 |
139 | def close(self):
140 | return self._env.close()
141 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/crafter.py:
--------------------------------------------------------------------------------
1 | import embodied
2 | import numpy as np
3 |
4 |
5 | class Crafter(embodied.Env):
6 |
7 | def __init__(self, task, size=(64, 64), outdir=None, seed=None):
8 | assert task in ('reward', 'noreward')
9 | import crafter
10 | self._env = crafter.Env(size=size, reward=(task == 'reward'), seed=seed)
11 | if outdir:
12 | outdir = embodied.Path(outdir)
13 | self._env = crafter.Recorder(
14 | self._env, outdir,
15 | save_stats=True,
16 | save_video=False,
17 | save_episode=False,
18 | )
19 | self._achievements = crafter.constants.achievements.copy()
20 | self._done = True
21 |
22 | @property
23 | def obs_space(self):
24 | spaces = {
25 | 'image': embodied.Space(np.uint8, self._env.observation_space.shape),
26 | 'reward': embodied.Space(np.float32),
27 | 'is_first': embodied.Space(bool),
28 | 'is_last': embodied.Space(bool),
29 | 'is_terminal': embodied.Space(bool),
30 | 'log_reward': embodied.Space(np.float32),
31 | }
32 | spaces.update({
33 | f'log_achievement_{k}': embodied.Space(np.int32)
34 | for k in self._achievements})
35 | return spaces
36 |
37 | @property
38 | def act_space(self):
39 | return {
40 | 'action': embodied.Space(np.int32, (), 0, self._env.action_space.n),
41 | 'reset': embodied.Space(bool),
42 | }
43 |
44 | def step(self, action):
45 | if action['reset'] or self._done:
46 | self._done = False
47 | image = self._env.reset()
48 | return self._obs(image, 0.0, {}, is_first=True)
49 | image, reward, self._done, info = self._env.step(action['action'])
50 | reward = np.float32(reward)
51 | return self._obs(
52 | image, reward, info,
53 | is_last=self._done,
54 | is_terminal=info['discount'] == 0)
55 |
56 | def _obs(
57 | self, image, reward, info,
58 | is_first=False, is_last=False, is_terminal=False):
59 | log_achievements = {
60 | f'log_achievement_{k}': info['achievements'][k] if info else 0
61 | for k in self._achievements}
62 | return dict(
63 | image=image,
64 | reward=reward,
65 | is_first=is_first,
66 | is_last=is_last,
67 | is_terminal=is_terminal,
68 | log_reward=np.float32(info['reward'] if info else 0.0),
69 | **log_achievements,
70 | )
71 |
72 | def render(self):
73 | return self._env.render()
74 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/dmc.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 |
4 | import embodied
5 | import numpy as np
6 |
7 |
8 | class DMC(embodied.Env):
9 |
10 | DEFAULT_CAMERAS = dict(
11 | locom_rodent=1,
12 | quadruped=2,
13 | )
14 |
15 | def __init__(self, env, repeat=1, render=True, size=(64, 64), camera=-1):
16 | # TODO: This env variable is meant for headless GPU machines but may fail
17 | # on CPU-only machines.
18 | if 'MUJOCO_GL' not in os.environ:
19 | os.environ['MUJOCO_GL'] = 'egl'
20 | if isinstance(env, str):
21 | domain, task = env.split('_', 1)
22 | if camera == -1:
23 | camera = self.DEFAULT_CAMERAS.get(domain, 0)
24 | if domain == 'cup': # Only domain with multiple words.
25 | domain = 'ball_in_cup'
26 | if domain == 'manip':
27 | from dm_control import manipulation
28 | env = manipulation.load(task + '_vision')
29 | elif domain == 'locom':
30 | from dm_control.locomotion.examples import basic_rodent_2020
31 | env = getattr(basic_rodent_2020, task)()
32 | else:
33 | from dm_control import suite
34 | env = suite.load(domain, task)
35 | self._dmenv = env
36 | from . import from_dm
37 | self._env = from_dm.FromDM(self._dmenv)
38 | self._env = embodied.wrappers.ExpandScalars(self._env)
39 | self._env = embodied.wrappers.ActionRepeat(self._env, repeat)
40 | self._render = render
41 | self._size = size
42 | self._camera = camera
43 |
44 | @functools.cached_property
45 | def obs_space(self):
46 | spaces = self._env.obs_space.copy()
47 | if self._render:
48 | spaces['image'] = embodied.Space(np.uint8, self._size + (3,))
49 | return spaces
50 |
51 | @functools.cached_property
52 | def act_space(self):
53 | return self._env.act_space
54 |
55 | def step(self, action):
56 | for key, space in self.act_space.items():
57 | if not space.discrete:
58 | assert np.isfinite(action[key]).all(), (key, action[key])
59 | obs = self._env.step(action)
60 | if self._render:
61 | obs['image'] = self.render()
62 | return obs
63 |
64 | def render(self):
65 | return self._dmenv.physics.render(*self._size, camera_id=self._camera)
66 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/dmlab.py:
--------------------------------------------------------------------------------
1 | import embodied
2 | import numpy as np
3 |
4 |
5 | class DMLab(embodied.Env):
6 |
7 | # Small action set used by IMPALA.
8 | IMPALA_ACTION_SET = (
9 | ( 0, 0, 0, 1, 0, 0, 0), # Forward
10 | ( 0, 0, 0, -1, 0, 0, 0), # Backward
11 | ( 0, 0, -1, 0, 0, 0, 0), # Strafe Left
12 | ( 0, 0, 1, 0, 0, 0, 0), # Strafe Right
13 | (-20, 0, 0, 0, 0, 0, 0), # Look Left
14 | ( 20, 0, 0, 0, 0, 0, 0), # Look Right
15 | (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward
16 | ( 20, 0, 0, 1, 0, 0, 0), # Look Right + Forward
17 | ( 0, 0, 0, 0, 1, 0, 0), # Fire
18 | )
19 |
20 | # Large action set used by PopArt and R2D2.
21 | POPART_ACTION_SET = [
22 | ( 0, 0, 0, 1, 0, 0, 0), # FW
23 | ( 0, 0, 0, -1, 0, 0, 0), # BW
24 | ( 0, 0, -1, 0, 0, 0, 0), # Strafe Left
25 | ( 0, 0, 1, 0, 0, 0, 0), # Strafe Right
26 | (-10, 0, 0, 0, 0, 0, 0), # Small LL
27 | ( 10, 0, 0, 0, 0, 0, 0), # Small LR
28 | (-60, 0, 0, 0, 0, 0, 0), # Large LL
29 | ( 60, 0, 0, 0, 0, 0, 0), # Large LR
30 | ( 0, 10, 0, 0, 0, 0, 0), # Look Down
31 | ( 0, -10, 0, 0, 0, 0, 0), # Look Up
32 | (-10, 0, 0, 1, 0, 0, 0), # FW + Small LL
33 | ( 10, 0, 0, 1, 0, 0, 0), # FW + Small LR
34 | (-60, 0, 0, 1, 0, 0, 0), # FW + Large LL
35 | ( 60, 0, 0, 1, 0, 0, 0), # FW + Large LR
36 | ( 0, 0, 0, 0, 1, 0, 0), # Fire
37 | ]
38 |
39 | def __init__(
40 | self, level, repeat=4, size=(64, 64), mode='train',
41 | action_set=IMPALA_ACTION_SET, episodic=True, seed=None):
42 | import deepmind_lab
43 | cache = None
44 | # path = os.environ.get('DMLAB_CACHE', None)
45 | # if path:
46 | # cache = Cache(path)
47 | self._size = size
48 | self._repeat = repeat
49 | self._action_set = action_set
50 | self._episodic = episodic
51 | self._random = np.random.RandomState(seed)
52 | config = dict(height=size[0], width=size[1], logLevel='WARN')
53 | if mode == 'train':
54 | if level.endswith('_test'):
55 | level = level.replace('_test', '_train')
56 | elif mode == 'eval':
57 | config.update(allowHoldOutLevels='true', mixerSeed=0x600D5EED)
58 | else:
59 | raise NotImplementedError(mode)
60 | config = {k: str(v) for k, v in config.items()}
61 | self._env = deepmind_lab.Lab(
62 | level='contributed/dmlab30/' + level,
63 | observations=['RGB_INTERLEAVED'],
64 | level_cache=cache, config=config)
65 | self._prev_image = None
66 | self._done = True
67 |
68 | @property
69 | def obs_space(self):
70 | return {
71 | 'image': embodied.Space(np.uint8, self._size + (3,)),
72 | 'reward': embodied.Space(np.float32),
73 | 'is_first': embodied.Space(bool),
74 | 'is_last': embodied.Space(bool),
75 | 'is_terminal': embodied.Space(bool),
76 | }
77 |
78 | @property
79 | def act_space(self):
80 | return {
81 | 'action': embodied.Space(np.int32, (), 0, len(self._action_set)),
82 | 'reset': embodied.Space(bool),
83 | }
84 |
85 | def step(self, action):
86 | if action['reset'] or self._done:
87 | self._env.reset(seed=self._random.randint(0, 2 ** 31 - 1))
88 | self._done = False
89 | return self._obs(0.0, is_first=True)
90 | raw_action = np.array(self._action_set[action['action']], np.intc)
91 | reward = self._env.step(raw_action, num_steps=self._repeat)
92 | self._done = not self._env.is_running()
93 | return self._obs(reward, is_last=self._done)
94 |
95 | def _obs(self, reward, is_first=False, is_last=False):
96 | return dict(
97 | image=self.render(),
98 | reward=reward,
99 | is_first=is_first,
100 | is_last=is_last,
101 | is_terminal=is_last if self._episodic else False,
102 | )
103 |
104 | def render(self):
105 | if not self._done:
106 | self._prev_image = self._env.observations()['RGB_INTERLEAVED']
107 | return self._prev_image
108 |
109 | def close(self):
110 | self._env.close()
111 |
112 |
113 | class Cache:
114 |
115 | def __init__(self, cache_dir):
116 | self._cache_dir = cache_dir
117 |
118 | def get_path(self, key):
119 | import hashlib, os
120 | key = hashlib.md5(key.encode('utf-8')).hexdigest()
121 | dir_, filename = key[:3], key[3:]
122 | return os.path.join(self._cache_dir, dir_, filename)
123 |
124 | def fetch(self, key, pk3_path):
125 | import tensorflow as tf
126 | path = self.get_path(key)
127 | try:
128 | tf.io.gfile.copy(path, pk3_path, overwrite=True)
129 | return True
130 | except tf.errors.OpError:
131 | return False
132 |
133 | def write(self, key, pk3_path):
134 | import os
135 | import tensorflow as tf
136 | path = self.get_path(key)
137 | try:
138 | if not tf.io.gfile.exists(path):
139 | tf.io.gfile.makedirs(os.path.dirname(path))
140 | tf.io.gfile.copy(pk3_path, path)
141 | except Exception as e:
142 | print(f'Could not store level: {e}')
143 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/dummy.py:
--------------------------------------------------------------------------------
1 | import embodied
2 | import numpy as np
3 |
4 |
5 | class Dummy(embodied.Env):
6 |
7 | def __init__(self, task, size=(64, 64), length=100):
8 | assert task in ('cont', 'disc')
9 | self._task = task
10 | self._size = size
11 | self._length = length
12 | self._step = 0
13 | self._done = False
14 |
15 | @property
16 | def obs_space(self):
17 | return {
18 | 'image': embodied.Space(np.uint8, self._size + (3,)),
19 | 'vector': embodied.Space(np.float32, (7,)),
20 | 'step': embodied.Space(np.int32, (), 0, self._length),
21 | 'reward': embodied.Space(np.float32),
22 | 'is_first': embodied.Space(bool),
23 | 'is_last': embodied.Space(bool),
24 | 'is_terminal': embodied.Space(bool),
25 | }
26 |
27 | @property
28 | def act_space(self):
29 | if self._task == 'cont':
30 | space = embodied.Space(np.float32, (6,))
31 | else:
32 | space = embodied.Space(np.int32, (), 0, 5)
33 | return {'action': space, 'reset': embodied.Space(bool)}
34 |
35 | def step(self, action):
36 | if action['reset'] or self._done:
37 | self._step = 0
38 | self._done = False
39 | return self._obs(0.0, is_first=True)
40 | action = action['action']
41 | self._step += 1
42 | self._done = (self._step >= self._length)
43 | return self._obs(1.0, is_last=self._done, is_terminal=self._done)
44 |
45 | def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
46 | return dict(
47 | image=np.zeros(self._size + (3,), np.uint8),
48 | vector=np.zeros(7, np.float32),
49 | step=self._step,
50 | reward=reward,
51 | is_first=is_first,
52 | is_last=is_last,
53 | is_terminal=is_terminal,
54 | )
55 |
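The `step()` method above resets the episode either on an explicit `reset` request or automatically once the previous episode has ended. A minimal sketch of that protocol follows; `TinyEnv` and `rollout` are hypothetical names introduced for illustration, and the stand-in drops the `embodied` and NumPy dependencies so it runs on its own.

```python
class TinyEnv:
    """Hypothetical stand-in that follows the same step() protocol as Dummy."""

    def __init__(self, length=3):
        self._length = length
        self._step = 0
        self._done = False

    def step(self, action):
        # A reset request, or a finished episode, starts a new episode.
        if action['reset'] or self._done:
            self._step = 0
            self._done = False
            return self._obs(0.0, is_first=True)
        self._step += 1
        self._done = self._step >= self._length
        return self._obs(1.0, is_last=self._done, is_terminal=self._done)

    def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
        return dict(
            step=self._step, reward=reward,
            is_first=is_first, is_last=is_last, is_terminal=is_terminal)


def rollout(env, num_steps):
    """Call step() num_steps times, requesting a reset only on the first call."""
    trace = []
    for i in range(num_steps):
        obs = env.step({'action': 0, 'reset': i == 0})
        trace.append((obs['is_first'], obs['is_last'], obs['reward']))
    return trace


# Six calls on a length-3 episode cross one episode boundary.
trace = rollout(TinyEnv(length=3), 6)
for is_first, is_last, reward in trace:
    print(is_first, is_last, reward)
```

The call after an `is_last` observation ignores the supplied action and returns the first observation of the next episode, so a driver loop does not need to track episode ends itself.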
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/from_dm.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 |
4 | import embodied
5 | import numpy as np
6 |
7 |
8 | class FromDM(embodied.Env):
9 |
10 | def __init__(self, env, obs_key='observation', act_key='action'):
11 | if isinstance(env, str):
12 | if env.startswith('bsuite'):
13 | import bsuite
14 | _, task = env.split('_', 1)
15 | self._env = bsuite.load_from_id(task)
16 | else:
17 | self._env = env
18 | obs_spec = self._env.observation_spec()
19 | act_spec = self._env.action_spec()
20 | self._obs_dict = isinstance(obs_spec, dict)
21 | self._act_dict = isinstance(act_spec, dict)
22 | self._obs_key = not self._obs_dict and obs_key
23 | self._act_key = not self._act_dict and act_key
24 | self._obs_empty = []
25 | self._done = True
26 |
27 | @functools.cached_property
28 | def obs_space(self):
29 | spec = self._env.observation_spec()
30 | spec = spec if self._obs_dict else {self._obs_key: spec}
31 | if 'reward' in spec:
32 | spec['obs_reward'] = spec.pop('reward')
33 | for key, value in spec.copy().items():
34 | if int(np.prod(value.shape)) == 0:
35 | self._obs_empty.append(key)
36 | del spec[key]
37 | return {
38 | 'reward': embodied.Space(np.float32),
39 | 'is_first': embodied.Space(bool),
40 | 'is_last': embodied.Space(bool),
41 | 'is_terminal': embodied.Space(bool),
42 | 'step_no': embodied.Space(np.int32),
43 | **{k or self._obs_key: self._convert(v) for k, v in spec.items()},
44 | }
45 |
46 | @functools.cached_property
47 | def act_space(self):
48 | spec = self._env.action_spec()
49 | spec = spec if self._act_dict else {self._act_key: spec}
50 | return {
51 | 'reset': embodied.Space(bool),
52 | **{k or self._act_key: self._convert(v) for k, v in spec.items()},
53 | }
54 |
55 | def step(self, action):
56 | action = action.copy()
57 | reset = action.pop('reset')
58 | if reset or self._done:
59 | time_step = self._env.reset()
60 | else:
61 | action = action if self._act_dict else action[self._act_key]
62 | time_step = self._env.step(action)
63 | self._done = time_step.last()
64 | return self._obs(time_step)
65 |
66 | def _obs(self, time_step):
67 | if not time_step.first():
68 | assert time_step.discount in (0, 1), time_step.discount
69 | obs = time_step.observation
70 | obs = dict(obs) if self._obs_dict else {self._obs_key: obs}
71 | if 'reward' in obs:
72 | obs['obs_reward'] = obs.pop('reward')
73 | for key in self._obs_empty:
74 | del obs[key]
75 | return dict(
76 | reward=np.float32(0.0 if time_step.first() else time_step.reward),
77 | is_first=time_step.first(),
78 | is_last=time_step.last(),
79 | is_terminal=False if time_step.first() else time_step.discount == 0,
80 | **{k: v if len(v.shape) != 2 else v.flatten() for k, v in obs.items()},
81 | )
82 |
83 | def _convert(self, space):
84 | shape = space.shape
85 | if len(shape) == 2:
86 | shape = (shape[0] * shape[1],)
87 | if hasattr(space, 'num_values'):
88 | return embodied.Space(space.dtype, (), 0, space.num_values)
89 | elif hasattr(space, 'minimum'):
90 | assert np.isfinite(space.minimum).all(), space.minimum
91 | assert np.isfinite(space.maximum).all(), space.maximum
92 | return embodied.Space(
93 | space.dtype, shape, space.minimum, space.maximum)
94 | else:
95 | return embodied.Space(space.dtype, shape, None, None)
96 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/from_gym.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import embodied
4 | import gym
5 | import gymnasium as gym_
6 | import numpy as np
7 |
8 | def flatten(space, onehot=False):
9 | if isinstance(space, gym_.spaces.Discrete):
10 | n = space.n
11 | p = np.eye(n)
12 | tr = lambda x: p[x]
13 | space_ = gym_.spaces.Box(0, 1, (n,))
14 | elif isinstance(space, gym_.spaces.MultiDiscrete):
15 | p = [np.eye(t) for t in space.nvec]
16 | tr = lambda x: np.concatenate([eye[j] for eye, j in zip(p, x)], axis=0)
17 | size = space.nvec.sum()
18 | space_ = gym_.spaces.Box(0, 1, (size,))
19 | elif isinstance(space, gym_.spaces.Tuple):
20 | p = [np.eye(t.n) for t in space.spaces]
21 | tr = lambda x: np.concatenate([eye[j] for eye, j in zip(p, x)], axis=0)
22 | space_ = gym_.spaces.Box(0, 1, (sum([len(pp) for pp in p]),))
23 | return tr, space_
24 |
25 | def unmap_action(a, nvec):
26 | newa = np.zeros(len(nvec), dtype=np.int32)
27 | for i in reversed(range(len(nvec))):
28 | newa[i] = a % nvec[i]
29 | a //= nvec[i]
30 | return newa
31 |
32 | class MyObsPopgym(gym_.Wrapper):
33 | def __init__(self, env):
34 | super().__init__(env)
35 | self.tr, space_ = flatten(self.observation_space)
36 | self.maction = self.action_space if isinstance(self.action_space, gym_.spaces.MultiDiscrete) else None
37 | if self.maction is not None:
38 | self.action_space = gym_.spaces.Discrete(np.prod(self.action_space.nvec))
39 | self.observation_space = gym_.spaces.Dict({'observation': space_})
40 |
41 | def step(self, action):
42 | if self.maction is not None:
43 | action = unmap_action(action, self.maction.nvec)
44 | obs, reward, ter, trun, info = super().step(action)
45 | obs = self.tr(obs)
46 | return {'observation': obs}, reward, ter | trun, info
47 |
48 | def reset(self):
49 | obs, info = super().reset()
50 | return {'observation': self.tr(obs)}
51 |
52 | class FromGym(embodied.Env):
53 |
54 | def __init__(self, env, obs_key='image', act_key='action', **kwargs):
55 | if isinstance(env, str):
56 | try:
57 | self._env = gym.make(env, **kwargs)
58 | except Exception:
59 | self._env = gym_.make(env, **kwargs)
60 | if 'popgym' in env:
61 | self._env = MyObsPopgym(self._env)
62 | else:
63 | assert not kwargs, kwargs
64 | self._env = env
65 | self._obs_dict = hasattr(self._env.observation_space, 'spaces')
66 | self._act_dict = hasattr(self._env.action_space, 'spaces')
67 | self._obs_key = obs_key
68 | self._act_key = act_key
69 | self._done = True
70 | self._info = None
71 |
72 | @property
73 | def info(self):
74 | return self._info
75 |
76 | @functools.cached_property
77 | def obs_space(self):
78 | if self._obs_dict:
79 | spaces = self._flatten(self._env.observation_space.spaces)
80 | else:
81 | spaces = {self._obs_key: self._env.observation_space}
82 | spaces = {k: self._convert(v) for k, v in spaces.items()}
83 | return {
84 | **spaces,
85 | 'reward': embodied.Space(np.float32),
86 | 'is_first': embodied.Space(bool),
87 | 'is_last': embodied.Space(bool),
88 | 'is_terminal': embodied.Space(bool),
89 | 'step_no': embodied.Space(np.int32),
90 | # 'ep_no': embodied.Space(np.int32),
91 | }
92 |
93 | @functools.cached_property
94 | def act_space(self):
95 | if self._act_dict:
96 | spaces = self._flatten(self._env.action_space.spaces)
97 | else:
98 | spaces = {self._act_key: self._env.action_space}
99 | spaces = {k: self._convert(v) for k, v in spaces.items()}
100 | spaces['reset'] = embodied.Space(bool)
101 | return spaces
102 |
103 | def step(self, action):
104 | if action['reset'] or self._done:
105 | self._done = False
106 | obs = self._env.reset()
107 | return self._obs(obs, 0.0, is_first=True)
108 | if self._act_dict:
109 | action = self._unflatten(action)
110 | else:
111 | action = action[self._act_key]
112 | obs, reward, self._done, self._info = self._env.step(action)
113 | return self._obs(
114 | obs, reward,
115 | is_last=bool(self._done),
116 | is_terminal=bool(self._info.get('is_terminal', self._done)))
117 |
118 | def _obs(
119 | self, obs, reward, is_first=False, is_last=False, is_terminal=False):
120 | if not self._obs_dict:
121 | obs = {self._obs_key: obs}
122 | obs = self._flatten(obs)
123 | obs = {k: np.asarray(v) for k, v in obs.items()}
124 | obs.update(
125 | reward=np.float32(reward),
126 | is_first=is_first,
127 | is_last=is_last,
128 | is_terminal=is_terminal)
129 | return obs
130 |
131 | def render(self):
132 | image = self._env.render('rgb_array')
133 | assert image is not None
134 | return image
135 |
136 | def close(self):
137 | try:
138 | self._env.close()
139 | except Exception:
140 | pass
141 |
142 | def _flatten(self, nest, prefix=None):
143 | result = {}
144 | for key, value in nest.items():
145 | key = prefix + '/' + key if prefix else key
146 | if isinstance(value, (gym.spaces.Dict, gym_.spaces.Dict)):
147 | value = value.spaces
148 | if isinstance(value, dict):
149 | result.update(self._flatten(value, key))
150 | else:
151 | result[key] = value
152 | return result
153 |
154 | def _unflatten(self, flat):
155 | result = {}
156 | for key, value in flat.items():
157 | parts = key.split('/')
158 | node = result
159 | for part in parts[:-1]:
160 | if part not in node:
161 | node[part] = {}
162 | node = node[part]
163 | node[parts[-1]] = value
164 | return result
165 |
166 | def _convert(self, space):
167 | if hasattr(space, 'n'):
168 | return embodied.Space(np.int32, (), 0, space.n)
169 | return embodied.Space(space.dtype, space.shape, space.low, space.high)
170 |
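`unmap_action` above decodes a flat discrete action index into one component per `MultiDiscrete` dimension, with the last dimension varying fastest (a mixed-radix decode). The sketch below restates that logic in plain Python lists instead of NumPy arrays; `map_action` is a hypothetical inverse added only to check the round trip and is not part of the repository.

```python
def unmap_action(index, nvec):
    """Decode a flat index into one component per entry of nvec."""
    parts = [0] * len(nvec)
    for i in reversed(range(len(nvec))):
        parts[i] = index % nvec[i]   # component for dimension i
        index //= nvec[i]            # remaining index for earlier dimensions
    return parts


def map_action(parts, nvec):
    """Hypothetical inverse: encode components back into a flat index."""
    index = 0
    for part, n in zip(parts, nvec):
        index = index * n + part
    return index


nvec = [2, 3]  # two sub-actions with 2 and 3 choices, so 6 flat actions
decoded = [unmap_action(i, nvec) for i in range(6)]
for i, parts in enumerate(decoded):
    print(i, parts, map_action(parts, nvec))
```

With `nvec = [2, 3]`, the flat index 5 decodes to `[1, 2]`, and encoding `[1, 2]` returns 5.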
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/loconav_quadruped.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dm_control import composer
4 | from dm_control import mjcf
5 | from dm_control.composer.observation import observable
6 | from dm_control.locomotion.walkers import base
7 | from dm_control.locomotion.walkers import legacy_base
8 | from dm_control.mujoco.wrapper import mjbindings
9 | import numpy as np
10 |
11 | enums = mjbindings.enums
12 | mjlib = mjbindings.mjlib
13 |
14 |
15 | class Quadruped(legacy_base.Walker):
16 |
17 | def _build(self, name='walker', initializer=None):
18 | super()._build(initializer=initializer)
19 | self._mjcf_root = mjcf.from_path(
20 | os.path.join(os.path.dirname(__file__), 'loconav_quadruped.xml'))
21 | if name:
22 | self._mjcf_root.model = name
23 | self._prev_action = np.zeros(
24 | self.action_spec.shape, self.action_spec.dtype)
25 |
26 | def initialize_episode(self, physics, random_state):
27 | self._prev_action = np.zeros_like(self._prev_action)
28 |
29 | def apply_action(self, physics, action, random_state):
30 | super().apply_action(physics, action, random_state)
31 | self._prev_action[:] = action
32 |
33 | def _build_observables(self):
34 | return QuadrupedObservables(self)
35 |
36 | @property
37 | def mjcf_model(self):
38 | return self._mjcf_root
39 |
40 | @property
41 | def upright_pose(self):
42 | return base.WalkerPose()
43 |
44 | @composer.cached_property
45 | def actuators(self):
46 | return self._mjcf_root.find_all('actuator')
47 |
48 | @composer.cached_property
49 | def root_body(self):
50 | return self._mjcf_root.find('body', 'torso')
51 |
52 | @composer.cached_property
53 | def bodies(self):
54 | return tuple(self.mjcf_model.find_all('body'))
55 |
56 | @composer.cached_property
57 | def mocap_tracking_bodies(self):
58 | return tuple(self.mjcf_model.find_all('body'))
59 |
60 | @property
61 | def mocap_joints(self):
62 | return self.mjcf_model.find_all('joint')
63 |
64 | @property
65 | def _foot_bodies(self):
66 | return (
67 | self._mjcf_root.find('body', 'toe_front_left'),
68 | self._mjcf_root.find('body', 'toe_front_right'),
69 | self._mjcf_root.find('body', 'toe_back_right'),
70 | self._mjcf_root.find('body', 'toe_back_left'),
71 | )
72 |
73 | @composer.cached_property
74 | def end_effectors(self):
75 | return self._foot_bodies
76 |
77 | @composer.cached_property
78 | def observable_joints(self):
79 | return self._mjcf_root.find_all('joint')
80 |
81 | @composer.cached_property
82 | def egocentric_camera(self):
83 | return self._mjcf_root.find('camera', 'egocentric')
84 |
85 | def aliveness(self, physics):
86 | return (physics.bind(self.root_body).xmat[-1] - 1.) / 2.
87 |
88 | @composer.cached_property
89 | def ground_contact_geoms(self):
90 | foot_geoms = []
91 | for foot in self._foot_bodies:
92 | foot_geoms.extend(foot.find_all('geom'))
93 | return tuple(foot_geoms)
94 |
95 | @property
96 | def prev_action(self):
97 | return self._prev_action
98 |
99 |
100 | class QuadrupedObservables(legacy_base.WalkerObservables):
101 |
102 | @composer.observable
103 | def actuator_activations(self):
104 | def actuator_activations_in_egocentric_frame(physics):
105 | return physics.data.act
106 | return observable.Generic(actuator_activations_in_egocentric_frame)
107 |
108 | @composer.observable
109 | def root_global_pos(self):
110 | def root_pos(physics):
111 | root_xpos, _ = self._entity.get_pose(physics)
112 | return np.reshape(root_xpos, -1)
113 | return observable.Generic(root_pos)
114 |
115 | @composer.observable
116 | def torso_global_pos(self):
117 | def torso_pos(physics):
118 | root_body = self._entity.root_body
119 | root_body_xpos = physics.bind(root_body).xpos
120 | return np.reshape(root_body_xpos, -1)
121 | return observable.Generic(torso_pos)
122 |
123 | @property
124 | def proprioception(self):
125 | return ([
126 | self.joints_pos, self.joints_vel, self.actuator_activations,
127 | self.sensors_accelerometer, self.sensors_gyro,
128 | self.sensors_velocimeter,
129 | self.sensors_force, self.sensors_torque,
130 | self.world_zaxis,
131 | self.root_global_pos, self.torso_global_pos,
132 | ] + self._collect_from_attachments('proprioception'))
133 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/minecraft.py:
--------------------------------------------------------------------------------
1 | import embodied
2 | import numpy as np
3 |
4 | from . import minecraft_base
5 |
6 |
7 | class Minecraft(embodied.Wrapper):
8 |
9 | def __init__(self, task, *args, **kwargs):
10 | super().__init__({
11 | 'wood': MinecraftWood,
12 | 'climb': MinecraftClimb,
13 | 'diamond': MinecraftDiamond,
14 | }[task](*args, **kwargs))
15 |
16 |
17 | class MinecraftWood(embodied.Wrapper):
18 |
19 | def __init__(self, *args, **kwargs):
20 | actions = BASIC_ACTIONS
21 | self.rewards = [
22 | CollectReward('log', repeated=1),
23 | HealthReward(),
24 | ]
25 | length = kwargs.pop('length', 36000)
26 | env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
27 | env = embodied.wrappers.TimeLimit(env, length)
28 | super().__init__(env)
29 |
30 | def step(self, action):
31 | obs = self.env.step(action)
32 | obs['reward'] = sum([fn(obs, self.env.inventory) for fn in self.rewards])
33 | return obs
34 |
35 |
36 | class MinecraftClimb(embodied.Wrapper):
37 |
38 | def __init__(self, *args, **kwargs):
39 | actions = BASIC_ACTIONS
40 | length = kwargs.pop('length', 36000)
41 | env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
42 | env = embodied.wrappers.TimeLimit(env, length)
43 | super().__init__(env)
44 | self._previous = None
45 | self._health_reward = HealthReward()
46 |
47 | def step(self, action):
48 | obs = self.env.step(action)
49 | x, y, z = obs['log_player_pos']
50 | height = np.float32(y)
51 | if obs['is_first']:
52 | self._previous = height
53 | obs['reward'] = height - self._previous
54 | obs['reward'] += self._health_reward(obs)
55 | self._previous = height
56 | return obs
57 |
58 |
59 | class MinecraftDiamond(embodied.Wrapper):
60 |
61 | def __init__(self, *args, **kwargs):
62 | actions = {
63 | **BASIC_ACTIONS,
64 | 'craft_planks': dict(craft='planks'),
65 | 'craft_stick': dict(craft='stick'),
66 | 'craft_crafting_table': dict(craft='crafting_table'),
67 | 'place_crafting_table': dict(place='crafting_table'),
68 | 'craft_wooden_pickaxe': dict(nearbyCraft='wooden_pickaxe'),
69 | 'craft_stone_pickaxe': dict(nearbyCraft='stone_pickaxe'),
70 | 'craft_iron_pickaxe': dict(nearbyCraft='iron_pickaxe'),
71 | 'equip_stone_pickaxe': dict(equip='stone_pickaxe'),
72 | 'equip_wooden_pickaxe': dict(equip='wooden_pickaxe'),
73 | 'equip_iron_pickaxe': dict(equip='iron_pickaxe'),
74 | 'craft_furnace': dict(nearbyCraft='furnace'),
75 | 'place_furnace': dict(place='furnace'),
76 | 'smelt_iron_ingot': dict(nearbySmelt='iron_ingot'),
77 | }
78 | self.rewards = [
79 | CollectReward('log', once=1),
80 | CollectReward('planks', once=1),
81 | CollectReward('stick', once=1),
82 | CollectReward('crafting_table', once=1),
83 | CollectReward('wooden_pickaxe', once=1),
84 | CollectReward('cobblestone', once=1),
85 | CollectReward('stone_pickaxe', once=1),
86 | CollectReward('iron_ore', once=1),
87 | CollectReward('furnace', once=1),
88 | CollectReward('iron_ingot', once=1),
89 | CollectReward('iron_pickaxe', once=1),
90 | CollectReward('diamond', once=1),
91 | HealthReward(),
92 | ]
93 | length = kwargs.pop('length', 36000)
94 | env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
95 | env = embodied.wrappers.TimeLimit(env, length)
96 | super().__init__(env)
97 |
98 | def step(self, action):
99 | obs = self.env.step(action)
100 | obs['reward'] = sum([fn(obs, self.env.inventory) for fn in self.rewards])
101 | return obs
102 |
103 |
104 | class CollectReward:
105 |
106 | def __init__(self, item, once=0, repeated=0):
107 | self.item = item
108 | self.once = once
109 | self.repeated = repeated
110 | self.previous = 0
111 | self.maximum = 0
112 |
113 | def __call__(self, obs, inventory):
114 | current = inventory[self.item]
115 | if obs['is_first']:
116 | self.previous = current
117 | self.maximum = current
118 | return 0
119 | reward = self.repeated * max(0, current - self.previous)
120 | if self.maximum == 0 and current > 0:
121 | reward += self.once
122 | self.previous = current
123 | self.maximum = max(self.maximum, current)
124 | return reward
125 |
126 |
127 | class HealthReward:
128 |
129 | def __init__(self, scale=0.01):
130 | self.scale = scale
131 | self.previous = None
132 |
133 | def __call__(self, obs, inventory=None):
134 | health = obs['health']
135 | if obs['is_first']:
136 | self.previous = health
137 | return 0
138 | reward = self.scale * (health - self.previous)
139 | self.previous = health
140 | return np.float32(reward)
141 |
142 |
143 | BASIC_ACTIONS = {
144 | 'noop': dict(),
145 | 'attack': dict(attack=1),
146 | 'turn_up': dict(camera=(-15, 0)),
147 | 'turn_down': dict(camera=(15, 0)),
148 | 'turn_left': dict(camera=(0, -15)),
149 | 'turn_right': dict(camera=(0, 15)),
150 | 'forward': dict(forward=1),
151 | 'back': dict(back=1),
152 | 'left': dict(left=1),
153 | 'right': dict(right=1),
154 | 'jump': dict(jump=1, forward=1),
155 | 'place_dirt': dict(place='dirt'),
156 | }
157 |
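`CollectReward` above combines two terms: `once` pays a bonus the first time an item count becomes positive in an episode, and `repeated` pays for every increase in the count. The sketch below restates that logic as a standalone function over a list of inventory counts; `collect_rewards` is a hypothetical helper for illustration, and the first entry of `counts` plays the role of the `is_first` observation.

```python
def collect_rewards(counts, once=0, repeated=0):
    """Return per-step rewards for a sequence of inventory counts."""
    previous = maximum = counts[0]  # the first observation only initializes state
    rewards = [0]
    for current in counts[1:]:
        # Pay for each newly gained item; losses are not penalized.
        reward = repeated * max(0, current - previous)
        # Pay the one-time bonus when the item is obtained for the first time.
        if maximum == 0 and current > 0:
            reward += once
        previous = current
        maximum = max(maximum, current)
        rewards.append(reward)
    return rewards


counts = [0, 0, 1, 3, 2, 4]
once_only = collect_rewards(counts, once=1)      # milestone-style reward
per_item = collect_rewards(counts, repeated=1)   # reward per item gained
print(once_only)
print(per_item)
```

The `once` setting matches how `MinecraftDiamond` rewards each milestone item a single time, while `repeated` matches how `MinecraftWood` rewards every additional log.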
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/minecraft_base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import threading
3 |
4 | import embodied
5 | import numpy as np
6 |
7 |
8 | class MinecraftBase(embodied.Env):
9 |
10 | _LOCK = threading.Lock()
11 |
12 | def __init__(
13 | self, actions,
14 | repeat=1,
15 | size=(64, 64),
16 | break_speed=100.0,
17 | gamma=10.0,
18 | sticky_attack=30,
19 | sticky_jump=10,
20 | pitch_limit=(-60, 60),
21 | logs=True, # TODO
22 | ):
23 | if logs:
24 | logging.basicConfig(level=logging.DEBUG)
25 | self._repeat = repeat
26 | self._size = size
27 | if break_speed != 1.0:
28 | sticky_attack = 0
29 |
30 | # Make env
31 | with self._LOCK:
32 | from . import minecraft_minerl
33 | self._gymenv = minecraft_minerl.MineRLEnv(size, break_speed, gamma).make()
34 | from . import from_gym
35 | self._env = from_gym.FromGym(self._gymenv)
36 | self._inventory = {}
37 |
38 | # Observations
39 | self._inv_keys = [
40 | k for k in self._env.obs_space if k.startswith('inventory/')
41 | if k != 'inventory/log2']
42 | self._step = 0
43 | self._max_inventory = None
44 | self._equip_enum = self._gymenv.observation_space[
45 | 'equipped_items']['mainhand']['type'].values.tolist()
46 | self._obs_space = self.obs_space
47 |
48 | # Actions
49 | self._noop_action = minecraft_minerl.NOOP_ACTION
50 | actions = self._insert_defaults(actions)
51 | self._action_names = tuple(actions.keys())
52 | self._action_values = tuple(actions.values())
53 | message = f'Minecraft action space ({len(self._action_values)}):'
54 | print(message, ', '.join(self._action_names))
55 | self._sticky_attack_length = sticky_attack
56 | self._sticky_attack_counter = 0
57 | self._sticky_jump_length = sticky_jump
58 | self._sticky_jump_counter = 0
59 | self._pitch_limit = pitch_limit
60 | self._pitch = 0
61 |
62 | @property
63 | def obs_space(self):
64 | return {
65 | 'image': embodied.Space(np.uint8, self._size + (3,)),
66 | 'inventory': embodied.Space(np.float32, len(self._inv_keys), 0),
67 | 'inventory_max': embodied.Space(np.float32, len(self._inv_keys), 0),
68 | 'equipped': embodied.Space(np.float32, len(self._equip_enum), 0, 1),
69 | 'reward': embodied.Space(np.float32),
70 | 'health': embodied.Space(np.float32),
71 | 'hunger': embodied.Space(np.float32),
72 | 'breath': embodied.Space(np.float32),
73 | 'is_first': embodied.Space(bool),
74 | 'is_last': embodied.Space(bool),
75 | 'is_terminal': embodied.Space(bool),
76 | **{f'log_{k}': embodied.Space(np.int64) for k in self._inv_keys},
77 | 'log_player_pos': embodied.Space(np.float32, 3),
78 | }
79 |
80 | @property
81 | def act_space(self):
82 | return {
83 | 'action': embodied.Space(np.int64, (), 0, len(self._action_values)),
84 | 'reset': embodied.Space(bool),
85 | }
86 |
87 | def step(self, action):
88 | action = action.copy()
89 | index = action.pop('action')
90 | action.update(self._action_values[index])
91 | action = self._action(action)
92 | if action['reset']:
93 | obs = self._reset()
94 | else:
95 | following = self._noop_action.copy()
96 | for key in ('attack', 'forward', 'back', 'left', 'right'):
97 | following[key] = action[key]
98 | for act in [action] + ([following] * (self._repeat - 1)):
99 | obs = self._env.step(act)
100 | if 'error' in self._env.info:
101 | obs = self._reset()
102 | break
103 | obs = self._obs(obs)
104 | self._step += 1
105 | assert 'pov' not in obs, list(obs.keys())
106 | return obs
107 |
108 | @property
109 | def inventory(self):
110 | return self._inventory
111 |
112 | def _reset(self):
113 | with self._LOCK:
114 | obs = self._env.step({'reset': True})
115 | self._step = 0
116 | self._max_inventory = None
117 | self._sticky_attack_counter = 0
118 | self._sticky_jump_counter = 0
119 | self._pitch = 0
120 | self._inventory = {}
121 | return obs
122 |
123 | def _obs(self, obs):
124 | obs['inventory/log'] += obs.pop('inventory/log2')
125 | self._inventory = {
126 | k.split('/', 1)[1]: obs[k] for k in self._inv_keys
127 | if k != 'inventory/air'}
128 | inventory = np.array([obs[k] for k in self._inv_keys], np.float32)
129 | if self._max_inventory is None:
130 | self._max_inventory = inventory
131 | else:
132 | self._max_inventory = np.maximum(self._max_inventory, inventory)
133 | index = self._equip_enum.index(obs['equipped_items/mainhand/type'])
134 | equipped = np.zeros(len(self._equip_enum), np.float32)
135 | equipped[index] = 1.0
136 | player_x = obs['location_stats/xpos']
137 | player_y = obs['location_stats/ypos']
138 | player_z = obs['location_stats/zpos']
139 | obs = {
140 | 'image': obs['pov'],
141 | 'inventory': inventory,
142 | 'inventory_max': self._max_inventory.copy(),
143 | 'equipped': equipped,
144 | 'health': np.float32(obs['life_stats/life'] / 20),
145 | 'hunger': np.float32(obs['life_stats/food'] / 20),
146 | 'breath': np.float32(obs['life_stats/air'] / 300),
147 | 'reward': 0.0,
148 | 'is_first': obs['is_first'],
149 | 'is_last': obs['is_last'],
150 | 'is_terminal': obs['is_terminal'],
151 | **{f'log_{k}': np.int64(obs[k]) for k in self._inv_keys},
152 | 'log_player_pos': np.array([player_x, player_y, player_z], np.float32),
153 | }
154 | for key, value in obs.items():
155 | space = self._obs_space[key]
156 | if not isinstance(value, np.ndarray):
157 | value = np.array(value)
158 | assert value in space, (key, value, value.dtype, value.shape, space)
159 | return obs
160 |
161 | def _action(self, action):
162 | if self._sticky_attack_length:
163 | if action['attack']:
164 | self._sticky_attack_counter = self._sticky_attack_length
165 | if self._sticky_attack_counter > 0:
166 | action['attack'] = 1
167 | action['jump'] = 0
168 | self._sticky_attack_counter -= 1
169 | if self._sticky_jump_length:
170 | if action['jump']:
171 | self._sticky_jump_counter = self._sticky_jump_length
172 | if self._sticky_jump_counter > 0:
173 | action['jump'] = 1
174 | action['forward'] = 1
175 | self._sticky_jump_counter -= 1
176 | if self._pitch_limit and action['camera'][0]:
177 | lo, hi = self._pitch_limit
178 | if not (lo <= self._pitch + action['camera'][0] <= hi):
179 | action['camera'] = (0, action['camera'][1])
180 | self._pitch += action['camera'][0]
181 | return action
182 |
183 | def _insert_defaults(self, actions):
184 | actions = {name: action.copy() for name, action in actions.items()}
185 | for key, default in self._noop_action.items():
186 | for action in actions.values():
187 | if key not in action:
188 | action[key] = default
189 | return actions
190 |
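The sticky-action and pitch-limit handling in `_action` above can be tried out without MineRL installed. The following is a minimal sketch under that assumption: `StickyActions` and the reduced `NOOP` dictionary are hypothetical names introduced only for illustration and are not part of this repository.

```python
class StickyActions:
  """Hypothetical standalone helper mirroring the logic of `_action`."""

  def __init__(self, sticky_attack=30, sticky_jump=10, pitch_limit=(-60, 60)):
    self.attack_length = sticky_attack
    self.jump_length = sticky_jump
    self.pitch_limit = pitch_limit
    self.attack_counter = 0
    self.jump_counter = 0
    self.pitch = 0

  def __call__(self, action):
    action = dict(action)
    # Once triggered, attack is held for `attack_length` steps and blocks jump.
    if self.attack_length:
      if action['attack']:
        self.attack_counter = self.attack_length
      if self.attack_counter > 0:
        action['attack'] = 1
        action['jump'] = 0
        self.attack_counter -= 1
    # Once triggered, jump is held for `jump_length` steps and forces forward.
    if self.jump_length:
      if action['jump']:
        self.jump_counter = self.jump_length
      if self.jump_counter > 0:
        action['jump'] = 1
        action['forward'] = 1
        self.jump_counter -= 1
    # Drop the pitch component if it would leave the allowed range.
    if self.pitch_limit and action['camera'][0]:
      lo, hi = self.pitch_limit
      if not (lo <= self.pitch + action['camera'][0] <= hi):
        action['camera'] = (0, action['camera'][1])
    self.pitch += action['camera'][0]
    return action


NOOP = dict(camera=(0, 0), forward=0, attack=0, jump=0)
sticky = StickyActions(sticky_attack=3, sticky_jump=0)
first = sticky({**NOOP, 'attack': 1})
held = [sticky(dict(NOOP))['attack'] for _ in range(3)]
```

With `sticky_attack=3`, a single attack request keeps the attack flag set for three environment steps in total (the triggering step plus two follow-up steps), after which no-op actions pass through unchanged.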
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/minecraft_minerl.py:
--------------------------------------------------------------------------------
1 | from minerl.herobraine.env_spec import EnvSpec
2 | from minerl.herobraine.hero import handler
3 | from minerl.herobraine.hero import handlers
4 | from minerl.herobraine.hero import mc
5 | from minerl.herobraine.hero.mc import INVERSE_KEYMAP
6 |
7 |
8 | def edit_options(**kwargs):
9 | import os, pathlib, re
10 | for word in os.popen('pip3 --version').read().split(' '):
11 | if '-packages/pip' in word:
12 | break
13 | else:
14 | raise RuntimeError('Could not find the Python package directory.')
15 | packages = pathlib.Path(word).parent
16 | filename = packages / 'minerl/Malmo/Minecraft/run/options.txt'
17 | options = filename.read_text()
18 | if 'fovEffectScale:' not in options:
19 | options += 'fovEffectScale:1.0\n'
20 | if 'simulationDistance:' not in options:
21 | options += 'simulationDistance:12\n'
22 | for key, value in kwargs.items():
23 | assert f'{key}:' in options, key
24 | assert isinstance(value, str), (value, type(value))
25 | options = re.sub(f'{key}:.*\n', f'{key}:{value}\n', options)
26 | filename.write_text(options)
27 |
28 |
29 | edit_options(
30 | difficulty='2',
31 | renderDistance='6',
32 | simulationDistance='6',
33 | fovEffectScale='0.0',
34 | ao='1',
35 | gamma='5.0',
36 | )
37 |
38 |
39 | class MineRLEnv(EnvSpec):
40 |
41 | def __init__(self, resolution=(64, 64), break_speed=50, gamma=10.0):
42 | self.resolution = resolution
43 | self.break_speed = break_speed
44 | self.gamma = gamma
45 | super().__init__(name='MineRLEnv-v1')
46 |
47 | def create_agent_start(self):
48 | return [
49 | BreakSpeedMultiplier(self.break_speed),
50 | ]
51 |
52 | def create_agent_handlers(self):
53 | return []
54 |
55 | def create_server_world_generators(self):
56 | return [handlers.DefaultWorldGenerator(force_reset=True)]
57 |
58 | def create_server_quit_producers(self):
59 | return [handlers.ServerQuitWhenAnyAgentFinishes()]
60 |
61 | def create_server_initial_conditions(self):
62 | return [
63 | handlers.TimeInitialCondition(
64 | allow_passage_of_time=True,
65 | start_time=0,
66 | ),
67 | handlers.SpawningInitialCondition(
68 | allow_spawning=True,
69 | )
70 | ]
71 |
72 | def create_observables(self):
73 | return [
74 | handlers.POVObservation(self.resolution),
75 | handlers.FlatInventoryObservation(mc.ALL_ITEMS),
76 | handlers.EquippedItemObservation(
77 | mc.ALL_ITEMS, _default='air', _other='other'),
78 | handlers.ObservationFromCurrentLocation(),
79 | handlers.ObservationFromLifeStats(),
80 | ]
81 |
82 | def create_actionables(self):
83 | kw = dict(_other='none', _default='none')
84 | return [
85 | handlers.KeybasedCommandAction('forward', INVERSE_KEYMAP['forward']),
86 | handlers.KeybasedCommandAction('back', INVERSE_KEYMAP['back']),
87 | handlers.KeybasedCommandAction('left', INVERSE_KEYMAP['left']),
88 | handlers.KeybasedCommandAction('right', INVERSE_KEYMAP['right']),
89 | handlers.KeybasedCommandAction('jump', INVERSE_KEYMAP['jump']),
90 | handlers.KeybasedCommandAction('sneak', INVERSE_KEYMAP['sneak']),
91 | handlers.KeybasedCommandAction('attack', INVERSE_KEYMAP['attack']),
92 | handlers.CameraAction(),
93 | handlers.PlaceBlock(['none'] + mc.ALL_ITEMS, **kw),
94 | handlers.EquipAction(['none'] + mc.ALL_ITEMS, **kw),
95 | handlers.CraftAction(['none'] + mc.ALL_ITEMS, **kw),
96 | handlers.CraftNearbyAction(['none'] + mc.ALL_ITEMS, **kw),
97 | handlers.SmeltItemNearby(['none'] + mc.ALL_ITEMS, **kw),
98 | ]
99 |
100 | def is_from_folder(self, folder):
101 | return folder == 'none'
102 |
103 | def get_docstring(self):
104 | return ''
105 |
106 | def determine_success_from_rewards(self, rewards):
107 | return True
108 |
109 | def create_rewardables(self):
110 | return []
111 |
112 | def create_server_decorators(self):
113 | return []
114 |
115 | def create_mission_handlers(self):
116 | return []
117 |
118 | def create_monitors(self):
119 | return []
120 |
121 |
122 | class BreakSpeedMultiplier(handler.Handler):
123 |
124 | def __init__(self, multiplier=1.0):
125 | self.multiplier = multiplier
126 |
127 | def to_string(self):
128 | return f'break_speed({self.multiplier})'
129 |
130 | def xml_template(self):
131 | return '{{multiplier}}'
132 |
133 |
134 | class Gamma(handler.Handler):
135 |
136 | def __init__(self, gamma=2.0):
137 | self.gamma = gamma
138 |
139 | def to_string(self):
140 | return f'gamma({self.gamma})'
141 |
142 | def xml_template(self):
143 | return '{{gamma}}'
144 |
145 |
146 | NOOP_ACTION = dict(
147 | camera=(0, 0), forward=0, back=0, left=0, right=0, attack=0, sprint=0,
148 | jump=0, sneak=0, craft='none', nearbyCraft='none', nearbySmelt='none',
149 | place='none', equip='none',
150 | )
151 |
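`edit_options` above rewrites `key:value` lines of Minecraft's `options.txt` with a regular expression. The sketch below shows the same idea on an in-memory string, assuming a hypothetical helper `set_options` (not part of this repository). Unlike the original, it anchors the pattern to the start of a line and escapes the key, which is a swapped-in variation rather than the repository's behaviour.

```python
import re


def set_options(options, **kwargs):
  """Replace `key:value` lines in an options.txt-style string."""
  for key, value in kwargs.items():
    assert isinstance(value, str), (value, type(value))
    # Anchor to line boundaries so that only a line starting with the key
    # is rewritten.
    pattern = re.compile(rf'^{re.escape(key)}:.*$', re.MULTILINE)
    assert pattern.search(options), key
    # A function replacement avoids interpreting backslashes in `value`.
    options = pattern.sub(lambda match: f'{key}:{value}', options)
  return options


text = 'ao:2\ngamma:0.0\nrenderDistance:12\n'
updated = set_options(text, gamma='5.0', renderDistance='6')
```

Keys that are not listed, such as `ao` here, are left untouched.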
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/pinpad.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import embodied
4 | import numpy as np
5 |
6 |
7 | class PinPad(embodied.Env):
8 |
9 | COLORS = {
10 | '1': (255, 0, 0),
11 | '2': ( 0, 255, 0),
12 | '3': ( 0, 0, 255),
13 | '4': (255, 255, 0),
14 | '5': (255, 0, 255),
15 | '6': ( 0, 255, 255),
16 | '7': (128, 0, 128),
17 | '8': ( 0, 128, 128),
18 | }
19 |
20 | def __init__(self, task, length=10000):
21 | assert length > 0
22 | layout = {
23 | 'three': LAYOUT_THREE,
24 | 'four': LAYOUT_FOUR,
25 | 'five': LAYOUT_FIVE,
26 | 'six': LAYOUT_SIX,
27 | 'seven': LAYOUT_SEVEN,
28 | 'eight': LAYOUT_EIGHT,
29 | }[task]
30 | self.layout = np.array([list(line) for line in layout.split('\n')]).T
31 | assert self.layout.shape == (16, 14), self.layout.shape
32 | self.length = length
33 | self.random = np.random.RandomState()
34 | self.pads = set(self.layout.flatten().tolist()) - set('* #\n')
35 | self.target = tuple(sorted(self.pads))
36 | self.spawns = []
37 | for (x, y), char in np.ndenumerate(self.layout):
38 | if char != '#':
39 | self.spawns.append((x, y))
40 | print(f'Created PinPad env with sequence: {"->".join(self.target)}')
41 | self.sequence = collections.deque(maxlen=len(self.target))
42 | self.player = None
43 | self.steps = None
44 | self.done = None
45 | self.countdown = None
46 |
47 | @property
48 | def act_space(self):
49 | return {
50 | 'action': embodied.Space(np.int64, (), 0, 5),
51 | 'reset': embodied.Space(bool),
52 | }
53 |
54 | @property
55 | def obs_space(self):
56 | return {
57 | 'image': embodied.Space(np.uint8, (64, 64, 3)),
58 | 'reward': embodied.Space(np.float32),
59 | 'is_first': embodied.Space(bool),
60 | 'is_last': embodied.Space(bool),
61 | 'is_terminal': embodied.Space(bool),
62 | }
63 |
64 | def step(self, action):
65 | if self.done or action['reset']:
66 | self.player = self.spawns[self.random.randint(len(self.spawns))]
67 | self.sequence.clear()
68 | self.steps = 0
69 | self.done = False
70 | self.countdown = 0
71 | return self._obs(reward=0.0, is_first=True)
72 | if self.countdown:
73 | self.countdown -= 1
74 | if self.countdown == 0:
75 | self.player = self.spawns[self.random.randint(len(self.spawns))]
76 | self.sequence.clear()
77 | reward = 0.0
78 | move = [(0, 0), (0, 1), (0, -1), (1, 0), (-1, 0)][action['action']]
79 | x = np.clip(self.player[0] + move[0], 0, 15)
80 | y = np.clip(self.player[1] + move[1], 0, 13)
81 | tile = self.layout[x][y]
82 | if tile != '#':
83 | self.player = (x, y)
84 | if tile in self.pads:
85 | if not self.sequence or self.sequence[-1] != tile:
86 | self.sequence.append(tile)
87 | if tuple(self.sequence) == self.target and not self.countdown:
88 | reward += 10.0
89 | self.countdown = 10
90 | self.steps += 1
91 | self.done = self.done or (self.steps >= self.length)
92 | return self._obs(reward=reward, is_last=self.done)
93 |
94 | def render(self):
95 | grid = np.zeros((16, 16, 3), np.uint8) + 255
96 | white = np.array([255, 255, 255])
97 | if self.countdown:
98 | grid[:] = (223, 255, 223)
99 | current = self.layout[self.player[0]][self.player[1]]
100 | for (x, y), char in np.ndenumerate(self.layout):
101 | if char == '#':
102 | grid[x, y] = (192, 192, 192)
103 | elif char in self.pads:
104 | color = np.array(self.COLORS[char])
105 | color = color if char == current else (10 * color + 90 * white) / 100
106 | grid[x, y] = color
107 | grid[self.player] = (0, 0, 0)
108 | grid[:, -2:] = (192, 192, 192)
109 | for i, char in enumerate(self.sequence):
110 | grid[2 * i + 1, -2] = self.COLORS[char]
111 | image = np.repeat(np.repeat(grid, 4, 0), 4, 1)
112 | return image.transpose((1, 0, 2))
113 |
114 | def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
115 | return dict(
116 | image=self.render(), reward=reward, is_first=is_first, is_last=is_last,
117 | is_terminal=is_terminal)
118 |
119 |
120 | LAYOUT_THREE = """
121 | ################
122 | #1111      3333#
123 | #1111      3333#
124 | #1111      3333#
125 | #1111      3333#
126 | #              #
127 | #              #
128 | #              #
129 | #              #
130 | #     2222     #
131 | #     2222     #
132 | #     2222     #
133 | #     2222     #
134 | ################
135 | """.strip('\n')
136 |
137 | LAYOUT_FOUR = """
138 | ################
139 | #1111      4444#
140 | #1111      4444#
141 | #1111      4444#
142 | #1111      4444#
143 | #              #
144 | #              #
145 | #              #
146 | #              #
147 | #3333      2222#
148 | #3333      2222#
149 | #3333      2222#
150 | #3333      2222#
151 | ################
152 | """.strip('\n')
153 |
154 | LAYOUT_FIVE = """
155 | ################
156 | #          4444#
157 | #111       4444#
158 | #111       4444#
159 | #111           #
160 | #111        555#
161 | #           555#
162 | #           555#
163 | #333        555#
164 | #333           #
165 | #333       2222#
166 | #333       2222#
167 | #          2222#
168 | ################
169 | """.strip('\n')
170 |
171 | LAYOUT_SIX = """
172 | ################
173 | #111        555#
174 | #111        555#
175 | #111        555#
176 | #              #
177 | #33          66#
178 | #33          66#
179 | #33          66#
180 | #33          66#
181 | #              #
182 | #444        222#
183 | #444        222#
184 | #444        222#
185 | ################
186 | """.strip('\n')
187 |
188 | LAYOUT_SEVEN = """
189 | ################
190 | #111        444#
191 | #111        444#
192 | #11          44#
193 | #              #
194 | #33          55#
195 | #33          55#
196 | #33          55#
197 | #33          55#
198 | #              #
199 | #66          22#
200 | #666  7777  222#
201 | #666  7777  222#
202 | ################
203 | """.strip('\n')
204 |
205 | LAYOUT_EIGHT = """
206 | ################
207 | #111  8888  444#
208 | #111  8888  444#
209 | #11          44#
210 | #              #
211 | #33          55#
212 | #33          55#
213 | #33          55#
214 | #33          55#
215 | #              #
216 | #66          22#
217 | #666  7777  222#
218 | #666  7777  222#
219 | ################
220 | """.strip('\n')
221 |
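The reward rule in `PinPad.step` above depends on a `collections.deque` whose `maxlen` equals the number of pads: repeated visits to the same pad are ignored, and older visits fall off the left end. The sketch below isolates that rule. `count_rewards` is a hypothetical helper, and it clears the sequence immediately after a match instead of running the 10-step countdown used by the environment.

```python
import collections


def count_rewards(target, visits):
  """Count how often the most recent distinct pad visits equal `target`."""
  sequence = collections.deque(maxlen=len(target))
  rewards = 0
  for tile in visits:
    # Standing on the same pad for several steps only counts once.
    if not sequence or sequence[-1] != tile:
      sequence.append(tile)
    if tuple(sequence) == tuple(target):
      rewards += 1
      sequence.clear()  # Simplification: the env clears after a countdown.
  return rewards


in_order = count_rewards('123', '1123')
after_detour = count_rewards('123', '3123')
wrong_order = count_rewards('123', '132')
```

Because the deque is bounded, an agent that first steps on a wrong pad can still complete the sequence later without an explicit reset.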
--------------------------------------------------------------------------------
/recall2imagine/embodied/envs/robodesk.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import embodied
4 |
5 |
6 | class RoboDesk(embodied.Env):
7 |
8 | def __init__(self, task, mode, repeat=1, length=500, resets=True):
9 | assert mode in ('train', 'eval')
10 | # TODO: This env variable is meant for headless GPU machines but may fail
11 | # on CPU-only machines.
12 | if 'MUJOCO_GL' not in os.environ:
13 | os.environ['MUJOCO_GL'] = 'egl'
14 | try:
15 | from robodesk import robodesk
16 | except ImportError:
17 | import robodesk
18 | task, reward = task.rsplit('_', 1)
19 | if mode == 'eval':
20 | reward = 'success'
21 | assert reward in ('dense', 'sparse', 'success'), reward
22 | self._gymenv = robodesk.RoboDesk(task, reward, repeat, length)
23 | from . import from_gym
24 | self._env = from_gym.FromGym(self._gymenv)
25 |
26 | @property
27 | def obs_space(self):
28 | return self._env.obs_space
29 |
30 | @property
31 | def act_space(self):
32 | return self._env.act_space
33 |
34 | def step(self, action):
35 | obs = self._env.step(action)
36 | obs['is_terminal'] = False
37 | return obs
38 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/replay/__init__.py:
--------------------------------------------------------------------------------
1 | from .generic import Generic
2 | from .generic_lfs import FIFO_LFS
3 | from .reverb import Reverb
4 | from .replays import Uniform
5 | from .naive_chunks import NaiveChunks
6 | from . import selectors
7 | from . import limiters
8 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/replay/chunk.py:
--------------------------------------------------------------------------------
1 | import io
2 | from datetime import datetime
3 |
4 | import embodied
5 | import numpy as np
6 | import uuid as uuidlib
7 |
8 |
9 | class Chunk:
10 | """
11 | Represents a contiguous chunk of RL experience and provides
12 | methods to work with it: appending an experience step, saving,
13 | and loading.
14 | """
15 | def __init__(self, size, successor=None):
16 | now = datetime.now()
17 | self.time = now.strftime("%Y%m%dT%H%M%S") + f'F{now.microsecond:06d}'
18 | self.uuid_b = embodied.uuid()
19 | self.uuid = str(self.uuid_b)
20 | self.uuid_b = self.uuid_b.value
21 | self.successor = successor
22 | self.size = size
23 | self.data = None
24 | self.length = 0
25 |
26 | def __repr__(self):
27 | succ = self.successor or str(embodied.uuid(0))
28 | succ = succ.uuid if isinstance(succ, type(self)) else succ
29 | return (
30 | f'Chunk(uuid={self.uuid}, '
31 | f'succ={succ}, '
32 | f'len={self.length})')
33 |
34 | def __len__(self):
35 | return self.length
36 |
37 | def __bool__(self):
38 | return True
39 |
40 | def append(self, step):
41 | if not self.data:
42 | example = {k: embodied.convert(v) for k, v in step.items()}
43 | self.data = {
44 | k: np.empty((self.size,) + v.shape, v.dtype)
45 | for k, v in example.items()}
46 | for key, value in step.items():
47 | self.data[key][self.length] = value
48 | self.length += 1
49 |
50 | def save(self, directory):
51 | succ = self.successor or str(embodied.uuid(0))
52 | succ = succ.uuid if isinstance(succ, type(self)) else succ
53 | filename = f'{self.time}-{self.uuid}-{succ}-{self.length}.npz'
54 | filename = embodied.Path(directory) / filename
55 | data = {k: embodied.convert(v) for k, v in self.data.items()}
56 | with io.BytesIO() as stream:
57 | np.savez_compressed(stream, **data)
58 | stream.seek(0)
59 | filename.write(stream.read(), mode='wb')
60 | print(f'Saved chunk: {filename.name}')
61 |
62 | @classmethod
63 | def load(cls, filename, load_data=True):
64 | length = int(filename.stem.split('-')[3])
65 | if load_data:
66 | with embodied.Path(filename).open('rb') as f:
67 | data = np.load(f)
68 | data = {k: data[k] for k in data.keys()}
69 | else:
70 | data = None
71 | chunk = cls(length)
72 | chunk.time = filename.stem.split('-')[0]
73 | chunk.uuid = filename.stem.split('-')[1]
74 | chunk.successor = filename.stem.split('-')[2]
75 | chunk.length = length
76 | chunk.data = data
77 | chunk.filename = filename
78 | return chunk
79 |
80 | @classmethod
81 | def scan(cls, directory, capacity=None, shorten=0):
82 | directory = embodied.Path(directory)
83 | filenames, total = [], 0
84 | for filename in reversed(sorted(directory.glob('*.npz'))):
85 | if capacity and total >= capacity:
86 | break
87 | filenames.append(filename)
88 | total += max(0, int(filename.stem.split('-')[3]) - shorten)
89 | return sorted(filenames)
90 |
91 | class ChunkSerializer:
92 | """
93 | This class represents the serialization behaviour for the Chunk
94 | class defined above. The serialization format is
95 | (next chunk uuid, current chunk uuid, experience payload)
96 | where the latter is several numpy arrays serialized to bytes.
97 | The next chunk link is needed to keep the sequential order of data in
98 | the dataloader (see the buffer class defined in generic_lfs.py).
99 | """
100 | def __init__(self, pattern_obj: Chunk, pattern=None):
101 | if pattern is None:
102 | self.pattern = [(k, v.shape, v.dtype) for k, v in pattern_obj.data.items()]
103 | else:
104 | self.pattern = pattern
105 |
106 | def dummy_chunk(self):
107 | return {
108 | k: np.empty(shape=v, dtype=d) for k,v,d in self.pattern
109 | }
110 |
111 | @property
112 | def chunk_size(self):
113 | return sum(np.dtype(dt).itemsize * np.prod(sh) for _, sh, dt in self.pattern) + 16 * 2
114 |
115 | def batch_buffer(self, num_buffers, batch_size, sequence_length):
116 | return {
117 | k: np.empty((num_buffers, batch_size, sequence_length, *shape[1:]), dtype=dtype)
118 | for k, shape, dtype in self.pattern
119 | }
120 |
121 | def serialize(self, chunk: Chunk, buffer: np.ndarray):
122 | offset = 0
123 |
124 | succ = chunk.successor or str(embodied.uuid(0))
125 | succ = succ.uuid if isinstance(succ, type(chunk)) else succ
126 | succ = succ.value if isinstance(succ, embodied.uuid) else bytes(succ)
127 | buffer[offset:offset+16] = np.frombuffer(succ, dtype=np.uint8)
128 | offset += 16
129 |
130 | uuid = chunk.uuid_b
131 | buffer[offset:offset+16] = np.frombuffer(uuid, dtype=np.uint8)
132 | offset += 16
133 |
134 | for k, shape, dtype in self.pattern:
135 | assert shape == chunk.data[k].shape
136 | assert dtype == chunk.data[k].dtype
137 | array = chunk.data[k]
138 | buffer[offset:offset+array.nbytes] = array.view(np.uint8).flat
139 | offset += array.nbytes
140 |
141 | def deserialize(self, buffer):
142 | offset = 0
143 | succ = buffer[offset:offset+16].tobytes()
144 | offset += 16
145 | uuid = buffer[offset:offset+16].tobytes()
146 | offset += 16
147 | destination = {}
148 | for k, shape, dtype in self.pattern:
149 | bytes_size = np.dtype(dtype).itemsize * np.prod(shape)
150 | destination[k] = buffer[offset:offset+bytes_size].view(dtype).reshape(shape).copy()
151 | offset += bytes_size
152 | chunk = Chunk(len(destination[k]))
153 | chunk.data = destination
154 | chunk.uuid_b = uuid
155 | chunk.uuid = embodied.uuid(int.from_bytes(uuid, 'big'))
156 | chunk.successor = succ
157 | return chunk
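
The byte layout described in the `ChunkSerializer` docstring above (16 bytes for the successor UUID, 16 bytes for the chunk's own UUID, then the payload) can be illustrated with the standard library alone. This is a minimal sketch, not the repository's implementation: it uses Python's `uuid` module instead of `embodied.uuid`, a flat list of float32 values instead of a dictionary of NumPy arrays, and the hypothetical function names `pack_chunk` and `unpack_chunk`.

```python
import struct
import uuid

HEADER = 16 * 2  # Successor UUID plus own UUID, 16 bytes each.


def pack_chunk(succ, own, values):
  """Pack two 16-byte UUIDs followed by little-endian float32 values."""
  payload = struct.pack(f'<{len(values)}f', *values)
  return succ.bytes + own.bytes + payload


def unpack_chunk(buffer):
  """Inverse of `pack_chunk`: read both UUIDs, then the float32 payload."""
  succ = uuid.UUID(bytes=buffer[0:16])
  own = uuid.UUID(bytes=buffer[16:32])
  count = (len(buffer) - HEADER) // 4
  values = list(struct.unpack(f'<{count}f', buffer[HEADER:]))
  return succ, own, values


null = uuid.UUID(int=0)  # Stands in for "no successor".
own = uuid.uuid4()
buffer = pack_chunk(null, own, [0.5, 1.0, -2.0])
restored = unpack_chunk(buffer)
```

The fixed-size header is what accounts for the `+ 16 * 2` term in `chunk_size`.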
--------------------------------------------------------------------------------
/recall2imagine/embodied/replay/generic.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import defaultdict, deque
3 | from functools import partial as bind
4 |
5 | import embodied
6 | import numpy as np
7 |
8 | from . import saver
9 |
10 |
11 | class Generic:
12 |
13 | def __init__(
14 | self, length, capacity, remover, sampler, limiter, directory,
15 | overlap=None, online=False, chunks=1024):
16 | assert capacity is None or 1 <= capacity
17 | self.length = length
18 | self.capacity = capacity
19 | self.remover = remover
20 | self.sampler = sampler
21 | self.limiter = limiter
22 | self.stride = 1 if overlap is None else length - overlap
23 | self.streams = defaultdict(bind(deque, maxlen=length))
24 | self.counters = defaultdict(int)
25 | self.table = {}
26 | self.online = online
27 | if self.online:
28 | self.online_queue = deque()
29 | self.online_stride = length
30 | self.online_counters = defaultdict(int)
31 | self.saver = directory and saver.Saver(directory, chunks)
32 | self.metrics = {
33 | 'samples': 0,
34 | 'sample_wait_dur': 0,
35 | 'sample_wait_count': 0,
36 | 'inserts': 0,
37 | 'insert_wait_dur': 0,
38 | 'insert_wait_count': 0,
39 | }
40 | self.load()
41 |
42 | def __len__(self):
43 | return len(self.table)
44 |
45 | def set_agent(self, agent):
46 | """
47 | Placeholder method for API compatibility with the other replay buffers.
48 | """
49 | pass
50 |
51 | @property
52 | def stats(self):
53 | ratio = lambda x, y: x / y if y else np.nan
54 | m = self.metrics
55 | stats = {
56 | 'size': len(self),
57 | 'inserts': m['inserts'],
58 | 'samples': m['samples'],
59 | 'insert_wait_avg': ratio(m['insert_wait_dur'], m['inserts']),
60 | 'insert_wait_frac': ratio(m['insert_wait_count'], m['inserts']),
61 | 'sample_wait_avg': ratio(m['sample_wait_dur'], m['samples']),
62 | 'sample_wait_frac': ratio(m['sample_wait_count'], m['samples']),
63 | }
64 | for key in self.metrics:
65 | self.metrics[key] = 0
66 | return stats
67 |
68 | def add(self, step, worker=0, load=False):
69 | step = {k: v for k, v in step.items() if not k.startswith('log_')}
70 | step['id'] = np.asarray(embodied.uuid(step.get('id')))
71 | stream = self.streams[worker]
72 | stream.append(step)
73 | self.saver and self.saver.add(step, worker)
74 | self.counters[worker] += 1
75 | if self.online:
76 | self.online_counters[worker] += 1
77 | if len(stream) >= self.length and (
78 | self.online_counters[worker] >= self.online_stride):
79 | self.online_queue.append(tuple(stream))
80 | self.online_counters[worker] = 0
81 | if len(stream) < self.length or self.counters[worker] < self.stride:
82 | return
83 | self.counters[worker] = 0
84 | key = embodied.uuid()
85 | seq = tuple(stream)
86 | if load:
87 | assert self.limiter.want_load()[0]
88 | else:
89 | dur = wait(self.limiter.want_insert, 'Replay insert is waiting')
90 | self.metrics['inserts'] += 1
91 | self.metrics['insert_wait_dur'] += dur
92 | self.metrics['insert_wait_count'] += int(dur > 0)
93 | self.table[key] = seq
94 | self.remover[key] = seq
95 | self.sampler[key] = seq
96 | while self.capacity and len(self) > self.capacity:
97 | self._remove(self.remover())
98 |
99 | def _sample(self):
100 | dur = wait(self.limiter.want_sample, 'Replay sample is waiting')
101 | self.metrics['samples'] += 1
102 | self.metrics['sample_wait_dur'] += dur
103 | self.metrics['sample_wait_count'] += int(dur > 0)
104 | if self.online:
105 | try:
106 | seq = self.online_queue.popleft()
107 | except IndexError:
108 | seq = self.table[self.sampler()]
109 | else:
110 | seq = self.table[self.sampler()]
111 | seq = {k: [step[k] for step in seq] for k in seq[0]}
112 | seq = {k: embodied.convert(v) for k, v in seq.items()}
113 | if 'is_first' in seq:
114 | seq['is_first'][0] = True
115 | return seq
116 |
117 | def _remove(self, key):
118 | wait(self.limiter.want_remove, 'Replay remove is waiting')
119 | del self.table[key]
120 | del self.remover[key]
121 | del self.sampler[key]
122 |
123 | def dataset(self):
124 | while True:
125 | yield self._sample()
126 |
127 | def prioritize(self, ids, prios):
128 | if hasattr(self.sampler, 'prioritize'):
129 | self.sampler.prioritize(ids, prios)
130 |
131 | def save(self, wait=False):
132 | if not self.saver:
133 | return
134 | self.saver.save(wait)
135 | # return {
136 | # 'saver': self.saver.save(wait),
137 | # # 'remover': self.remover.save(wait),
138 | # # 'sampler': self.sampler.save(wait),
139 | # # 'limiter': self.limiter.save(wait),
140 | # }
141 |
142 | def maybe_restore(self):
143 | self.load()
144 |
145 | def load(self, data=None):
146 | if not self.saver:
147 | return
148 | workers = set()
149 | for step, worker in self.saver.load(self.capacity, self.length):
150 | workers.add(worker)
151 | self.add(step, worker, load=True)
152 | for worker in workers:
153 | del self.streams[worker]
154 | del self.counters[worker]
155 | # self.remover.load(data['remover'])
156 | # self.sampler.load(data['sampler'])
157 | # self.limiter.load(data['limiter'])
158 |
159 |
160 | def wait(predicate, message, sleep=0.001, notify=1.0):
161 | start = time.time()
162 | notified = False
163 | while True:
164 | allowed, detail = predicate()
165 | duration = time.time() - start
166 | if allowed:
167 | return duration
168 | if not notified and duration >= notify:
169 | print(f'{message} ({detail})')
170 | notified = True
171 | time.sleep(sleep)
172 |
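`Generic.add` above turns a stream of steps into fixed-length sequences using a bounded deque and a per-worker counter, where `stride = length - overlap` (or 1 when `overlap` is `None`). The sketch below reproduces only that windowing rule for a single worker; `windows` is a hypothetical helper, and limiter, sampler, and saver interactions are left out.

```python
from collections import deque


def windows(steps, length, overlap=None):
  """Yield-like helper: collect the sequences `Generic.add` would insert."""
  stride = 1 if overlap is None else length - overlap
  stream = deque(maxlen=length)
  counter = 0
  out = []
  for step in steps:
    stream.append(step)
    counter += 1
    # Wait until the window is full and at least `stride` new steps arrived.
    if len(stream) < length or counter < stride:
      continue
    counter = 0
    out.append(tuple(stream))
  return out


dense = windows(range(6), length=3)
disjoint = windows(range(6), length=3, overlap=0)
half = windows(range(6), length=3, overlap=1)
```

With the default `overlap=None`, every step after the first full window starts a new sequence, so consecutive sequences share `length - 1` steps.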
--------------------------------------------------------------------------------
/recall2imagine/embodied/replay/limiters.py:
--------------------------------------------------------------------------------
1 | import threading
2 |
3 |
4 | class MinSize:
5 |
6 | def __init__(self, minimum):
7 | assert 1 <= minimum, minimum
8 | self.minimum = minimum
9 | self.size = 0
10 | self.lock = threading.Lock()
11 |
12 | def want_load(self):
13 | with self.lock:
14 | self.size += 1
15 | return True, 'ok'
16 |
17 | def want_insert(self):
18 | with self.lock:
19 | self.size += 1
20 | return True, 'ok'
21 |
22 | def want_remove(self):
23 | with self.lock:
24 | if self.size < 1:
25 | return False, 'is empty'
26 | self.size -= 1
27 | return True, 'ok'
28 |
29 | def want_sample(self):
30 | if self.size < self.minimum:
31 | return False, f'too empty: {self.size} < {self.minimum}'
32 | return True, 'ok'
33 |
34 |
35 | class SamplesPerInsert:
36 |
37 | def __init__(self, samples_per_insert, tolerance, minimum=1, unlocked=False):
38 | assert 1 <= minimum
39 | self.samples_per_insert = samples_per_insert
40 | self.unlocked = unlocked
41 | self.minimum = minimum
42 | self.avail = -minimum
43 | self.min_avail = -tolerance
44 | self.max_avail = tolerance * samples_per_insert
45 | self.size = 0
46 | self.lock = threading.Lock()
47 |
48 | def want_load(self):
49 | with self.lock:
50 | self.size += 1
51 | return True, 'ok'
52 |
53 | def want_insert(self):
54 | with self.lock:
55 | if self.unlocked:
56 | self.avail += self.samples_per_insert
57 | self.size += 1
58 | return True, 'ok'
59 | else:
60 | if self.avail >= self.max_avail:
61 | return False, f'rate limited: {self.avail:.3f} >= {self.max_avail:.3f}'
62 | self.avail += self.samples_per_insert
63 | self.size += 1
64 | return True, 'ok'
65 |
66 | def want_remove(self):
67 | with self.lock:
68 | if self.size < 1:
69 | return False, 'is empty'
70 | self.size -= 1
71 | return True, 'ok'
72 |
73 | def want_sample(self):
74 | with self.lock:
75 | if self.size < self.minimum:
76 | return False, f'too empty: {self.size} < {self.minimum}'
77 | if self.unlocked:
78 | self.avail -= 1
79 | return True, 'ok'
80 | else:
81 | if self.avail <= self.min_avail:
82 | return False, f'rate limited: {self.avail:.3f} <= {self.min_avail:.3f}'
83 | self.avail -= 1
84 | return True, 'ok'
85 |
86 |
87 | class Queue:
88 |
89 | def __init__(self, capacity):
90 | assert 1 <= capacity
91 | self.capacity = capacity
92 | self.size = 0
93 | self.lock = threading.Lock()
94 |
95 | def want_load(self):
96 | with self.lock:
97 | self.size += 1
98 | return True, 'ok'
99 |
100 | def want_insert(self):
101 | with self.lock:
102 | if self.size >= self.capacity:
103 | return False, f'is full: {self.size} >= {self.capacity}'
104 | self.size += 1
105 | return True, 'ok'
106 |
107 | def want_remove(self):
108 | with self.lock:
109 | if self.size < 1:
110 | return False, 'is empty'
111 | self.size -= 1
112 | return True, 'ok'
113 |
114 | def want_sample(self):
115 | if self.size < 1:
116 | return False, 'is empty'
117 | else:
118 | return True, 'ok'
119 |
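The `SamplesPerInsert` limiter above keeps a running budget `avail`: each insert adds `samples_per_insert`, each sample subtracts 1, and inserts or samples are refused when the budget leaves the range set by `tolerance`. The sketch below is a single-threaded reduction of that class with the locking and the `unlocked` mode removed; the class name `RateLimiter` and the boolean return values are assumptions made for brevity.

```python
class RateLimiter:
  """Hypothetical single-threaded reduction of `SamplesPerInsert`."""

  def __init__(self, samples_per_insert, tolerance, minimum=1):
    self.samples_per_insert = samples_per_insert
    self.minimum = minimum
    self.avail = -minimum
    self.min_avail = -tolerance
    self.max_avail = tolerance * samples_per_insert
    self.size = 0

  def want_insert(self):
    # Refuse inserts once sampling has fallen too far behind.
    if self.avail >= self.max_avail:
      return False
    self.avail += self.samples_per_insert
    self.size += 1
    return True

  def want_sample(self):
    # Refuse samples while the buffer is too small or the budget is spent.
    if self.size < self.minimum:
      return False
    if self.avail <= self.min_avail:
      return False
    self.avail -= 1
    return True


limiter = RateLimiter(samples_per_insert=2, tolerance=2)
sampled_when_empty = limiter.want_sample()
inserts = 0
while limiter.want_insert():
  inserts += 1
samples = 0
while limiter.want_sample():
  samples += 1
```

In this configuration the budget starts at -1, three inserts raise it to 5 (at or above the upper bound of 4), and seven samples then bring it down to the lower bound of -2.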
--------------------------------------------------------------------------------
/recall2imagine/embodied/replay/naive_chunks.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import threading
3 | import time
4 | import uuid
5 | from collections import deque, defaultdict
6 | from functools import partial as bind
7 |
8 | import numpy as np
9 | import embodied
10 |
11 | from . import chunk as chunklib
12 |
13 |
14 | class NaiveChunks(embodied.Replay):
15 |
16 | def __init__(self, length, capacity=None, directory=None, chunks=1024, seed=0):
17 | assert 1 <= length <= chunks
18 | self.length = length
19 | self.capacity = capacity
20 | self.directory = directory and embodied.Path(directory)
21 | self.chunks = chunks
22 | self.buffers = {}
23 | self.rng = np.random.default_rng(seed)
24 | self.ongoing = defaultdict(bind(chunklib.Chunk, chunks))
25 | if directory:
26 | self.directory.mkdirs()
27 | self.workers = concurrent.futures.ThreadPoolExecutor(16)
28 | self.promises = deque()
29 |
30 | def __len__(self):
31 | return len(self.buffers) * self.length
32 |
33 | @property
34 | def stats(self):
35 | return {'size': len(self), 'chunks': len(self.buffers)}
36 |
37 | def add(self, step, worker=0):
38 | step = {k: v for k, v in step.items() if not k.startswith('log_')}
39 | chunk = self.ongoing[worker]
40 | chunk.append(step)
41 | if len(chunk) >= self.chunks:
42 | self.buffers[chunk.uuid] = self.ongoing.pop(worker)
43 | self.promises.append(self.workers.submit(chunk.save, self.directory))
44 | for promise in [x for x in self.promises if x.done()]:
45 | promise.result()
46 | self.promises.remove(promise)
47 | while self.capacity and len(self) > self.capacity:
48 | del self.buffers[next(iter(self.buffers.keys()))]
49 |
50 | def _sample(self):
51 | counter = 0
52 | while not self.buffers:
53 | if counter % 100 == 0:
54 | print('Replay sample is waiting')
55 | time.sleep(0.1)
56 | counter += 1
57 | keys = tuple(self.buffers.keys())
58 | chunk = self.buffers[keys[self.rng.integers(0, len(keys))]]
59 | idx = self.rng.integers(0, len(chunk) - self.length + 1)
60 | seq = {k: chunk.data[k][idx: idx + self.length].copy() for k in chunk.data.keys()}
61 | seq['is_first'][0] = True
62 | return seq
63 |
64 | def dataset(self):
65 | while True:
66 | yield self._sample()
67 |
68 | def save(self, wait=False):
69 | for chunk in self.ongoing.values():
70 | if chunk.length:
71 | self.promises.append(self.workers.submit(chunk.save, self.directory))
72 | if wait:
73 | [x.result() for x in self.promises]
74 | self.promises.clear()
75 |
76 | def load(self, data=None):
77 | filenames = chunklib.Chunk.scan(self.directory, self.capacity)
78 | if not filenames:
79 | return
80 | threads = min(len(filenames), 32)
81 | with concurrent.futures.ThreadPoolExecutor(threads) as executor:
82 | chunks = list(executor.map(chunklib.Chunk.load, filenames))
83 | self.buffers = {chunk.uuid: chunk for chunk in chunks}
84 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/replay/replays.py:
--------------------------------------------------------------------------------
1 | from . import generic
2 | from . import selectors
3 | from . import limiters
4 |
5 |
6 | class Uniform(generic.Generic):
7 |
8 | def __init__(
9 | self, length, capacity=None, directory=None, online=False, chunks=1024,
10 | min_size=1, samples_per_insert=None, tolerance=1e4, seed=0):
11 | if samples_per_insert:
12 | limiter = limiters.SamplesPerInsert(
13 | samples_per_insert, tolerance, min_size)
14 | else:
15 | limiter = limiters.MinSize(min_size)
16 | assert not capacity or min_size <= capacity
17 | super().__init__(
18 | length=length,
19 | capacity=capacity,
20 | remover=selectors.Fifo(),
21 | sampler=selectors.Uniform(seed),
22 | limiter=limiter,
23 | directory=directory,
24 | online=online,
25 | chunks=chunks,
26 | )
27 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/replay/reverb.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from collections import defaultdict
3 | from functools import partial as bind
4 |
5 | import embodied
6 | import numpy as np
7 |
8 |
9 | class Reverb:
10 |
11 | def __init__(
12 | self, length, capacity=None, directory=None, chunks=None, flush=100):
13 | del chunks
14 | import reverb
15 | self.length = length
16 | self.capacity = capacity
17 | self.directory = directory and embodied.Path(directory)
18 | self.checkpointer = None
19 | self.server = None
20 | self.client = None
21 | self.writers = None
22 | self.counters = None
23 | self.signature = None
24 | self.flush = flush
25 | if self.directory:
26 | self.directory.mkdirs()
27 | path = str(self.directory)
28 | try:
29 | self.checkpointer = reverb.checkpointers.DefaultCheckpointer(path)
30 | except AttributeError:
31 | self.checkpointer = reverb.checkpointers.RecordIOCheckpointer(path)
32 | self.sigpath = self.directory.parent / (self.directory.name + '_sig.pkl')
33 | if self.directory and self.sigpath.exists():
34 | with self.sigpath.open('rb') as file:
35 | self.signature = pickle.load(file)
36 | self._create_server()
37 |
38 | def _create_server(self):
39 | import reverb
40 | import tensorflow as tf
41 | self.server = reverb.Server(tables=[reverb.Table(
42 | name='table',
43 | sampler=reverb.selectors.Uniform(),
44 | remover=reverb.selectors.Fifo(),
45 | max_size=int(self.capacity),
46 | rate_limiter=reverb.rate_limiters.MinSize(1),
47 | signature={
48 | key: tf.TensorSpec(shape, dtype)
49 | for key, (shape, dtype) in self.signature.items()},
50 | )], port=None, checkpointer=self.checkpointer)
51 | self.client = reverb.Client(f'localhost:{self.server.port}')
52 | self.writers = defaultdict(bind(
53 | self.client.trajectory_writer, self.length))
54 | self.counters = defaultdict(int)
55 |
56 | def __len__(self):
57 | if not self.client:
58 | return 0
59 | return self.client.server_info()['table'].current_size
60 |
61 | @property
62 | def stats(self):
63 | return {'size': len(self)}
64 |
65 | def add(self, step, worker=0):
66 | step = {k: v for k, v in step.items() if not k.startswith('log_')}
67 | step = {k: embodied.convert(v) for k, v in step.items()}
68 | step['id'] = np.asarray(embodied.uuid(step.get('id')))
69 | if not self.server:
70 | self.signature = {
71 | k: ((self.length, *v.shape), v.dtype) for k, v in step.items()}
72 | self._create_server()
73 | step = {k: v for k, v in step.items() if not k.startswith('log_')}
74 | writer = self.writers[worker]
75 | writer.append(step)
76 | if len(next(iter(writer.history.values()))) >= self.length:
77 | seq = {k: v[-self.length:] for k, v in writer.history.items()}
78 | writer.create_item('table', priority=1.0, trajectory=seq)
79 | self.counters[worker] += 1
80 | if self.counters[worker] > self.flush:
81 | self.counters[worker] = 0
82 | writer.flush()
83 |
84 | def dataset(self):
85 | import reverb
86 | dataset = reverb.TrajectoryDataset.from_table_signature(
87 | server_address=f'localhost:{self.server.port}',
88 | table='table',
89 | max_in_flight_samples_per_worker=10)
90 | for sample in dataset:
91 | seq = sample.data
92 | seq = {k: embodied.convert(v) for k, v in seq.items()}
93 | # seq['key'] = sample.info.key # uint64
94 | # seq['prob'] = sample.info.probability
95 | if 'is_first' in seq:
96 | seq['is_first'] = np.array(seq['is_first'])
97 | seq['is_first'][0] = True
98 | yield seq
99 |
100 | def prioritize(self, ids, prios):
101 | raise NotImplementedError
102 |
103 | def save(self, wait=False):
104 | for writer in self.writers.values():
105 | writer.flush()
106 | with self.sigpath.open('wb') as file:
107 | file.write(pickle.dumps(self.signature))
108 | if self.directory:
109 | self.client.checkpoint()
110 |
111 | def load(self, data=None):
112 | pass
113 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/replay/saver.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | from multiprocessing.pool import ThreadPool
3 | from collections import defaultdict, deque
4 | from functools import partial as bind
5 | import numpy as np
6 | from tqdm import tqdm
7 | import embodied
8 |
9 | from . import chunk as chunklib
10 |
11 |
12 | class Saver:
13 |
14 | def __init__(self, directory, chunks=1024):
15 | self.directory = embodied.Path(directory)
16 | self.directory.mkdirs()
17 | self.chunks = chunks
18 | self.buffers = defaultdict(bind(chunklib.Chunk, chunks))
19 | self.workers = concurrent.futures.ThreadPoolExecutor(16)
20 | self.promises = deque()
21 | self.loading = False
22 |
23 | def add(self, step, worker):
24 | if self.loading:
25 | return
26 | buffer = self.buffers[worker]
27 | buffer.append(step)
28 | if buffer.length >= self.chunks:
29 | self.buffers[worker] = buffer.successor = chunklib.Chunk(self.chunks)
30 | self.promises.append(self.workers.submit(buffer.save, self.directory))
31 | for promise in [x for x in self.promises if x.done()]:
32 | promise.result()
33 | self.promises.remove(promise)
34 |
35 | def save(self, wait=False):
36 | for buffer in self.buffers.values():
37 | if buffer.length:
38 | self.promises.append(self.workers.submit(buffer.save, self.directory))
39 | if wait:
40 | [x.result() for x in self.promises]
41 | self.promises.clear()
42 |
43 | def load(self, capacity, length):
44 | filenames = chunklib.Chunk.scan(self.directory, capacity, 1)
45 | if not filenames:
46 | return
47 | lazychunks = [chunklib.Chunk.load(f, load_data=False) for f in filenames]
48 | lazychunks1 = {c.uuid: c for c in lazychunks}
49 | was = {k: False for k in lazychunks1}
50 | streamids = {}
51 | streams = set()
52 | while any(not x for x in was.values()):
53 | for chunk in (sorted(lazychunks, key=lambda x: x.time)):
54 | if not was[chunk.uuid]:
55 | stream_id = chunk.uuid
56 | streamids[chunk.uuid] = stream_id
57 | streams.add(stream_id)
58 | was[stream_id] = True
59 | nxt = lazychunks1[chunk.successor] if chunk.successor in lazychunks1 else None
60 | steps = 0
61 | while nxt is not None and nxt.uuid in lazychunks1:
62 | streamids[nxt.uuid] = stream_id
63 | was[nxt.uuid] = True
64 | nxt = lazychunks1[nxt.successor] if nxt.successor in lazychunks1 else None
65 | steps += 1
66 | print(f'Stream: {stream_id} ({len(streams)}th) - {steps} steps')
67 | threads = min(len(filenames), 32)
68 | self.loading = True
69 | with ThreadPool(threads) as executor:
70 | for chunk in tqdm(executor.imap_unordered(chunklib.Chunk.load, filenames), total=len(filenames)):
71 | stream = streamids[chunk.uuid]
72 | for index in range(chunk.length):
73 | step = {k: v[index] for k, v in chunk.data.items()}
74 | yield step, stream
75 | del chunk
76 | self.loading = False
77 |
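The stream-assignment loop in `Saver.load` groups chunks into streams by following each chunk's `successor` link, starting from the oldest chunk not yet visited. The same idea can be sketched on plain data as below. `assign_streams` and its `(time, successor)` tuple layout are illustrative assumptions, not part of the repository, and unlike the loop above the sketch also stops at already-assigned chunks so a cyclic link cannot loop forever.

```python
def assign_streams(chunks):
  """Map each chunk id to the id of the first chunk in its chain.

  `chunks` maps a chunk id to a (time, successor_id_or_None) pair.
  """
  streamids = {}
  # Visit chunks oldest first so that the head of each chain names the stream.
  for uuid in sorted(chunks, key=lambda u: chunks[u][0]):
    if uuid in streamids:
      continue
    current = uuid
    # Follow successor links until the chain leaves the known set of chunks.
    while current in chunks and current not in streamids:
      streamids[current] = uuid
      current = chunks[current][1]
  return streamids


example = {
    'a': (0.0, 'b'),
    'b': (1.0, 'c'),
    'c': (2.0, None),
    'x': (0.5, 'y'),
    'y': (3.0, 'missing'),  # Successor was never written to disk.
}
streams = assign_streams(example)
```

Here `a`, `b`, and `c` end up in stream `a`, while `x` and `y` end up in stream `x`.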
--------------------------------------------------------------------------------
/recall2imagine/embodied/replay/selectors.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | import numpy as np
4 |
5 |
6 | class Fifo:
7 |
8 | def __init__(self):
9 | self.queue = deque()
10 |
11 | def __call__(self):
12 | return self.queue[0]
13 |
14 | def __setitem__(self, key, steps):
15 | self.queue.append(key)
16 |
17 | def __delitem__(self, key):
18 | if self.queue[0] == key:
19 | self.queue.popleft()
20 | else:
21 | # TODO: This branch is unused but very slow.
22 | self.queue.remove(key)
23 |
24 |
25 | class Uniform:
26 |
27 | def __init__(self, seed=0):
28 | self.indices = {}
29 | self.keys = []
30 | self.rng = np.random.default_rng(seed)
31 |
32 | def __call__(self):
33 | index = self.rng.integers(0, len(self.keys)).item()
34 | return self.keys[index]
35 |
36 | def __setitem__(self, key, steps):
37 | self.indices[key] = len(self.keys)
38 | self.keys.append(key)
39 |
40 | def __delitem__(self, key):
41 | index = self.indices.pop(key)
42 | last = self.keys.pop()
43 | if index != len(self.keys):
44 | self.keys[index] = last
45 | self.indices[last] = index
46 |
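`Uniform.__delitem__` above keeps deletion O(1) by moving the last key into the slot of the deleted key, rather than shifting the list. The standalone sketch below reproduces that swap-remove bookkeeping so the index updates can be inspected. `SwapRemoveSampler` is an illustrative stand-in that uses the standard-library `random` module instead of NumPy.

```python
import random


class SwapRemoveSampler:
  """Uniform sampling with O(1) insert and delete (illustrative stand-in)."""

  def __init__(self, seed=0):
    self.indices = {}  # key -> position in self.keys
    self.keys = []
    self.rng = random.Random(seed)

  def __call__(self):
    return self.keys[self.rng.randrange(len(self.keys))]

  def __setitem__(self, key, steps):
    self.indices[key] = len(self.keys)
    self.keys.append(key)

  def __delitem__(self, key):
    index = self.indices.pop(key)
    last = self.keys.pop()
    if index != len(self.keys):
      # The deleted key was not in the last slot: move the former last key
      # into the hole so the list stays dense.
      self.keys[index] = last
      self.indices[last] = index


sampler = SwapRemoveSampler(seed=0)
for key in ['a', 'b', 'c', 'd']:
  sampler[key] = None
del sampler['b']  # 'd' moves into slot 1.
```

The trade-off is that insertion order is not preserved, which is acceptable for uniform sampling but is why the FIFO remover keeps its own `deque`.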
--------------------------------------------------------------------------------
/recall2imagine/embodied/run/__init__.py:
--------------------------------------------------------------------------------
1 | from .eval_only import eval_only
2 | from .parallel import parallel
3 | from .train import train
4 | from .train_eval import train_eval
5 | from .train_holdout import train_holdout
6 | from .train_save import train_save
7 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/run/eval_only.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import embodied
4 | import numpy as np
5 |
6 |
7 | def eval_only(agent, env, logger, args):
8 |
9 | logdir = embodied.Path(args.logdir)
10 | logdir.mkdirs()
11 | print('Logdir', logdir)
12 | should_log = embodied.when.Clock(args.log_every)
13 | step = logger.step
14 | metrics = embodied.Metrics()
15 | print('Observation space:', env.obs_space)
16 | print('Action space:', env.act_space)
17 |
18 | timer = embodied.Timer()
19 | timer.wrap('agent', agent, ['policy'])
20 | timer.wrap('env', env, ['step'])
21 | timer.wrap('logger', logger, ['write'])
22 |
23 | nonzeros = set()
24 | def per_episode(ep):
25 | length = len(ep['reward']) - 1
26 | score = float(ep['reward'].astype(np.float64).sum())
27 | logger.add({'length': length, 'score': score}, prefix='episode')
28 | print(f'Episode has {length} steps and return {score:.1f}.')
29 | stats = {}
30 | for key in args.log_keys_video:
31 | if key in ep:
32 | stats[f'policy_{key}'] = ep[key]
33 | for key, value in ep.items():
34 | if not args.log_zeros and key not in nonzeros and (value == 0).all():
35 | continue
36 | nonzeros.add(key)
37 | if re.match(args.log_keys_sum, key):
38 | stats[f'sum_{key}'] = ep[key].sum()
39 | if re.match(args.log_keys_mean, key):
40 | stats[f'mean_{key}'] = ep[key].mean()
41 | if re.match(args.log_keys_max, key):
42 | stats[f'max_{key}'] = ep[key].max(0).mean()
43 | metrics.add(stats, prefix='stats')
44 |
45 | driver = embodied.Driver(env)
46 | driver.on_episode(lambda ep, worker: per_episode(ep))
47 | driver.on_step(lambda tran, _: step.increment())
48 |
49 | checkpoint = embodied.Checkpoint()
50 | checkpoint.agent = agent
51 | checkpoint.load(args.from_checkpoint, keys=['agent'])
52 |
53 | print('Start evaluation loop.')
54 | policy = lambda *args: agent.policy(*args, mode='eval')
55 | while step < args.steps:
56 | driver(policy, steps=100)
57 | if should_log(step):
58 | logger.add(metrics.result())
59 | logger.add(timer.stats(), prefix='timer')
60 | logger.write(fps=True)
61 | logger.write()
62 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/run/parallel.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import logging
4 | import threading
5 | from collections import defaultdict
6 |
7 | import embodied
8 | import numpy as np
9 |
10 |
11 | def parallel(agent, replay, logger, make_env, num_envs, args):
12 | step = logger.step
13 | timer = embodied.Timer()
14 | timer.wrap('agent', agent, ['policy', 'train', 'report', 'save'])
15 | timer.wrap('replay', replay, ['add', 'save'])
16 | timer.wrap('logger', logger, ['write'])
17 | workers = []
18 | workers.append(embodied.distr.Thread(
19 | actor, step, agent, replay, logger, args.actor_addr, args))
20 | workers.append(embodied.distr.Thread(
21 | learner, step, agent, replay, logger, timer, args))
22 | if num_envs == 1:
23 | workers.append(embodied.distr.Thread(
24 | env, make_env, args.actor_addr, 0, args, timer))
25 | else:
26 | for i in range(num_envs):
27 | workers.append(embodied.distr.Process(
28 | env, make_env, args.actor_addr, i, args))
29 | embodied.distr.run(workers)
30 |
31 |
32 | def actor(step, agent, replay, logger, actor_addr, args):
33 | metrics = embodied.Metrics()
34 | scalars = defaultdict(lambda: defaultdict(list))
35 | videos = defaultdict(lambda: defaultdict(list))
36 | should_log = embodied.when.Clock(args.log_every)
37 |
38 | _, initial = agent.policy(dummy_data(
39 | agent.agent.obs_space, (args.actor_batch,)))
40 | initial = embodied.treemap(lambda x: x[0], initial)
41 | allstates = defaultdict(lambda: initial)
42 | agent.sync()
43 | # step.t = 0
44 |
45 | def callback(obs, env_addrs):
46 | states = [allstates[a] for a in env_addrs]
47 | states = embodied.treemap(lambda *xs: list(xs), *states)
48 | act, states = agent.policy(obs, states)
49 | act['reset'] = obs['is_last'].copy()
50 | for i, a in enumerate(env_addrs):
51 | allstates[a] = embodied.treemap(lambda x: x[i], states)
52 |
53 | trans = {**obs, **act}
54 | for i, a in enumerate(env_addrs):
55 | tran = {k: v[i].copy() for k, v in trans.items()}
56 | replay.add(tran.copy(), worker=a)
57 | [scalars[a][k].append(v) for k, v in tran.items() if v.size == 1]
58 | [videos[a][k].append(tran[k]) for k in args.log_keys_video if k in tran]
59 | step.increment(args.actor_batch)
60 | # print(f'fps: {args.actor_batch / (time.time() - step.t):.3f}')
61 | # step.t = time.time()
62 |
63 |
64 |
65 | for i, a in enumerate(env_addrs):
66 | if not trans['is_last'][i]:
67 | continue
68 | vids = videos.pop(a) if a in videos else {}
69 | ep = {**scalars.pop(a), **vids}
70 | ep = {k: embodied.convert(v) for k, v in ep.items()}
71 | logger.add({
72 | 'length': len(ep['reward']) - 1,
73 | 'score': sum(ep['reward']),
74 | }, prefix='episode')
75 | stats = {}
76 | for key in args.log_keys_video:
77 | if key != 'none':
78 | stats[f'policy_{key}'] = ep[key]
79 | metrics.add(stats, prefix='stats')
80 |
81 | if should_log():
82 | logger.add(metrics.result())
83 |
84 | return act
85 |
86 | print('[actor] Start server')
87 | embodied.BatchServer(actor_addr, args.actor_batch, callback).run()
88 |
89 |
90 | def learner(step, agent, replay, logger, timer, args):
91 | logdir = embodied.Path(args.logdir)
92 | ckpt_dir = embodied.Path(args.checkpoint_dir)
93 | metrics = embodied.Metrics()
94 | should_log = embodied.when.Clock(args.log_every)
95 | should_save = embodied.when.Clock(args.save_every)
96 | should_sync = embodied.when.Every(args.sync_every)
97 | updates = embodied.Counter()
98 |
99 | checkpoint = embodied.Checkpoint(ckpt_dir / 'checkpoint.ckpt')
100 | checkpoint.step = step
101 | checkpoint.agent = agent
102 | checkpoint.replay = replay
103 | is_set = False
104 | while not is_set:
105 | with replay.manager.event_lock:
106 | is_set = bool(replay.manager.init_event.is_set())
107 | if not is_set:
108 | print('Learner is waiting for the actor to initialize the buffer..')
109 | time.sleep(30)
110 |
111 | if args.from_checkpoint:
112 | checkpoint.load(args.from_checkpoint)
113 | checkpoint.load_or_save(filename=replay.manager.tmp_file_path)
114 | # a workaround for the logger step issue when restoring
115 | logger.step.value = embodied.Counter(initial=agent.step.value)
116 |
117 | dataset = agent.dataset(replay)
118 | state = None
119 | stats = dict(last_time=time.time(), last_step=int(step), batch_entries=0)
120 | while True:
121 | batch = next(dataset)
122 | outs, state, mets = agent.train(batch, state)
123 | metrics.add(mets)
124 | updates.increment()
125 | stats['batch_entries'] += batch['is_first'].size
126 |
127 | if should_sync(updates):
128 | agent.sync()
129 |
130 | if should_log():
131 | train = metrics.result()
132 | report = agent.report(batch)
133 | report = {k: v for k, v in report.items() if 'train/' + k not in train}
134 | logger.add(train, prefix='train')
135 | logger.add(report, prefix='report')
136 | logger.add(timer.stats(), prefix='timer')
137 | logger.add(replay.stats, prefix='replay')
138 |
139 | duration = time.time() - stats['last_time']
140 | actor_fps = (int(step) - stats['last_step']) / duration
141 | learner_fps = stats['batch_entries'] / duration
142 | logger.add({
143 | 'actor_fps': actor_fps,
144 | 'learner_fps': learner_fps,
145 | 'train_ratio': learner_fps / actor_fps if actor_fps else np.inf,
146 | }, prefix='parallel')
147 | stats = dict(last_time=time.time(), last_step=int(step), batch_entries=0)
148 | try:
149 | logger.write(fps=True)
150 | except Exception as e:
151 | print(f'logging failed: {e}')
152 |
153 | if should_save():
154 | try:
155 | checkpoint.save()
156 | except Exception as e:
157 | print(f'saving failed: {e}')
158 | pass
159 |
160 |
161 | def env(make_env, actor_addr, i, args, timer=None):
162 | # TODO: Optionally write NPZ episodes.
163 | print(f'[env{i}] Make env')
164 | env = make_env()
165 | if timer:
166 | timer.wrap('env', env, ['step'])
167 | actor = embodied.Client(actor_addr)
168 | act = {k: v.sample() for k, v in env.act_space.items()}
169 | done = False
170 | while True:
171 | act['reset'] = done
172 | obs = env.step(act)
173 | obs = {k: np.asarray(v) for k, v in obs.items()}
174 | done = obs['is_last']
175 | promise = actor(obs)
176 | try:
177 | act = promise()
178 | except RuntimeError:
179 | sys.exit(0)
180 | act = {k: v for k, v in act.items() if not k.startswith('log_')}
181 |
182 |
183 | def dummy_data(spaces, batch_dims):
184 | # TODO: Get rid of this function by adding initial_policy_state() and
185 | # initial_train_state() to the agent API.
186 | spaces = list(spaces.items())
187 | data = {k: np.zeros(v.shape, v.dtype) for k, v in spaces}
188 | for dim in reversed(batch_dims):
189 | data = {k: np.repeat(v[None], dim, axis=0) for k, v in data.items()}
190 | return data
191 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/run/train.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import embodied
4 | import numpy as np
5 | import jax
6 |
7 |
8 | def train(agent, env, replay, logger, args, config):
9 |
10 | logdir = embodied.Path(args.logdir)
11 | logdir.mkdirs()
12 | print('Logdir', logdir)
13 | should_expl = embodied.when.Until(args.expl_until)
14 | should_report = embodied.when.Every(1000000)
15 | should_train = embodied.when.Ratio(args.train_ratio / args.batch_steps)
16 | should_log = embodied.when.Clock(args.log_every)
17 | should_save = embodied.when.Clock(args.save_every)
18 | should_sync = embodied.when.Every(args.sync_every)
19 | should_profile = args.profile_path != 'none'
20 | step = logger.step
21 | updates = embodied.Counter()
22 | metrics = embodied.Metrics()
23 | print('Observation space:', embodied.format(env.obs_space), sep='\n')
24 | print('Action space:', embodied.format(env.act_space), sep='\n')
25 |
26 | timer = embodied.Timer()
27 | timer.wrap('agent', agent, ['policy', 'train', 'report', 'save'])
28 | timer.wrap('env', env, ['step'])
29 | timer.wrap('replay', replay, ['add', 'save'])
30 | timer.wrap('logger', logger, ['write'])
31 |
32 | nonzeros = set()
33 | def per_episode(ep):
34 | length = len(ep['reward']) - 1
35 | score = float(ep['reward'].astype(np.float64).sum())
36 | sum_abs_reward = float(np.abs(ep['reward']).astype(np.float64).sum())
37 | logger.add({
38 | 'length': length,
39 | 'score': score,
40 | 'sum_abs_reward': sum_abs_reward,
41 | 'reward_rate': (np.abs(ep['reward']) >= 0.5).mean(),
42 | }, prefix='episode')
43 | print(f'Episode has {length} steps and return {score:.1f}.')
44 | stats = {}
45 | for key in args.log_keys_video:
46 | if key in ep:
47 | stats[f'policy_{key}'] = ep[key]
48 | for key, value in ep.items():
49 | if not args.log_zeros and key not in nonzeros and (value == 0).all():
50 | continue
51 | nonzeros.add(key)
52 | if re.match(args.log_keys_sum, key):
53 | stats[f'sum_{key}'] = ep[key].sum()
54 | if re.match(args.log_keys_mean, key):
55 | stats[f'mean_{key}'] = ep[key].mean()
56 | if re.match(args.log_keys_max, key):
57 | stats[f'max_{key}'] = ep[key].max(0).mean()
58 | metrics.add(stats, prefix='stats')
59 |
60 | driver = embodied.Driver(env)
61 | driver.on_episode(lambda ep, worker: per_episode(ep))
62 | driver.on_step(lambda tran, _: step.increment())
63 | driver.on_step(replay.add)
64 |
65 | replay.maybe_restore()
66 |
67 | print('Prefill train dataset.')
68 | random_agent = embodied.RandomAgent(env.act_space)
69 | while len(replay) < max(args.batch_steps * config.envs.amount, args.train_fill):
70 | driver(random_agent.policy, steps=100)
71 | logger.add(metrics.result())
72 | logger.write()
73 |
74 | if config.replay == 'lfs':
75 | dataset = agent.dataset(replay, shared_memory=True)
76 | elif config.replay == 'uniform':
77 | dataset = agent.dataset(replay.dataset, shared_memory=False)
78 | state = [None] # To be writable from train step function below.
79 | batch = [None]
80 | def train_step(tran, worker):
81 | for _ in range(should_train(step)):
82 | with timer.scope('dataset'):
83 | batch[0] = next(dataset)
84 | if should_profile:
85 | jax.profiler.start_trace(f"{args.profile_path}/{step.value}")
86 | print(f'profiling step {step}')
87 | outs, state[0], mets = agent.train(batch[0], state[0])
88 | if should_profile:
89 | jax.profiler.stop_trace()
90 | metrics.add(mets, prefix='train')
91 | if 'priority' in outs:
92 | replay.prioritize(outs['key'], outs['priority'])
93 | updates.increment()
94 | if should_sync(updates):
95 | agent.sync()
96 | if should_log(step):
97 | agg = metrics.result()
98 | report = agent.report(batch[0])
99 | report = {k: v for k, v in report.items() if 'train/' + k not in agg}
100 | logger.add(agg)
101 | # TODO: do this rarely
102 | if should_report(step):
103 | logger.add(report, prefix='report')
104 | logger.add(replay.stats, prefix='replay')
105 | logger.add(timer.stats(), prefix='timer')
106 | logger.write(fps=True)
107 | driver.on_step(train_step)
108 |
109 | checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt', parallel=False)
110 | timer.wrap('checkpoint', checkpoint, ['save', 'load'])
111 | checkpoint.step = step
112 | checkpoint.agent = agent
113 | checkpoint.replay = replay
114 | if args.from_checkpoint:
115 | checkpoint.load(args.from_checkpoint)
116 | checkpoint.load_or_save()
117 | should_save(step) # Register that we just saved.
118 |
119 | print('Start training loop.')
120 | policy = lambda *args: agent.policy(
121 | *args, mode='explore' if should_expl(step) else 'train')
122 | while step < args.steps:
123 | driver(policy, steps=100)
124 | if should_save(step):
125 | try:
126 | checkpoint.save()
127 | except Exception as e:
128 | print(f'saving failed: {e}')
129 | pass
130 | logger.write()
131 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/run/train_eval.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import embodied
4 | import numpy as np
5 |
6 |
7 | def train_eval(
8 | agent, train_env, eval_env, train_replay, eval_replay, logger, args):
9 |
10 | logdir = embodied.Path(args.logdir)
11 | logdir.mkdirs()
12 | print('Logdir', logdir)
13 | should_expl = embodied.when.Until(args.expl_until)
14 | should_train = embodied.when.Ratio(args.train_ratio / args.batch_steps)
15 | should_log = embodied.when.Clock(args.log_every)
16 | should_save = embodied.when.Clock(args.save_every)
17 | should_eval = embodied.when.Every(args.eval_every, args.eval_initial)
18 | should_sync = embodied.when.Every(args.sync_every)
19 | step = logger.step
20 | updates = embodied.Counter()
21 | metrics = embodied.Metrics()
22 | print('Observation space:', embodied.format(train_env.obs_space), sep='\n')
23 | print('Action space:', embodied.format(train_env.act_space), sep='\n')
24 |
25 | timer = embodied.Timer()
26 | timer.wrap('agent', agent, ['policy', 'train', 'report', 'save'])
27 | timer.wrap('env', train_env, ['step'])
28 | if hasattr(train_replay, '_sample'):
29 | timer.wrap('replay', train_replay, ['_sample'])
30 |
31 | nonzeros = set()
32 | def per_episode(ep, mode):
33 | length = len(ep['reward']) - 1
34 | score = float(ep['reward'].astype(np.float64).sum())
35 | logger.add({
36 | 'length': length, 'score': score,
37 | 'reward_rate': (ep['reward'] - ep['reward'].min() >= 0.1).mean(),
38 | }, prefix=('episode' if mode == 'train' else f'{mode}_episode'))
39 | print(f'Episode has {length} steps and return {score:.1f}.')
40 | stats = {}
41 | for key in args.log_keys_video:
42 | if key in ep:
43 | stats[f'policy_{key}'] = ep[key]
44 | for key, value in ep.items():
45 | if not args.log_zeros and key not in nonzeros and (value == 0).all():
46 | continue
47 | nonzeros.add(key)
48 | if re.match(args.log_keys_sum, key):
49 | stats[f'sum_{key}'] = ep[key].sum()
50 | if re.match(args.log_keys_mean, key):
51 | stats[f'mean_{key}'] = ep[key].mean()
52 | if re.match(args.log_keys_max, key):
53 | stats[f'max_{key}'] = ep[key].max(0).mean()
54 | metrics.add(stats, prefix=f'{mode}_stats')
55 |
56 | driver_train = embodied.Driver(train_env)
57 | driver_train.on_episode(lambda ep, worker: per_episode(ep, mode='train'))
58 | driver_train.on_step(lambda tran, _: step.increment())
59 | driver_train.on_step(train_replay.add)
60 | driver_eval = embodied.Driver(eval_env)
61 | driver_eval.on_step(eval_replay.add)
62 | driver_eval.on_episode(lambda ep, worker: per_episode(ep, mode='eval'))
63 |
64 | random_agent = embodied.RandomAgent(train_env.act_space)
65 | print('Prefill train dataset.')
66 | while len(train_replay) < max(args.batch_steps, args.train_fill):
67 | driver_train(random_agent.policy, steps=100)
68 | print('Prefill eval dataset.')
69 | while len(eval_replay) < max(args.batch_steps, args.eval_fill):
70 | driver_eval(random_agent.policy, steps=100)
71 | logger.add(metrics.result())
72 | logger.write()
73 |
74 | dataset_train = agent.dataset(train_replay.dataset)
75 | dataset_eval = agent.dataset(eval_replay.dataset)
76 | state = [None] # To be writable from train step function below.
77 | batch = [None]
78 | def train_step(tran, worker):
79 | for _ in range(should_train(step)):
80 | with timer.scope('dataset_train'):
81 | batch[0] = next(dataset_train)
82 | outs, state[0], mets = agent.train(batch[0], state[0])
83 | metrics.add(mets, prefix='train')
84 | if 'priority' in outs:
85 | train_replay.prioritize(outs['key'], outs['priority'])
86 | updates.increment()
87 | if should_sync(updates):
88 | agent.sync()
89 | if should_log(step):
90 | logger.add(metrics.result())
91 | logger.add(agent.report(batch[0]), prefix='report')
92 | with timer.scope('dataset_eval'):
93 | eval_batch = next(dataset_eval)
94 | logger.add(agent.report(eval_batch), prefix='eval')
95 | logger.add(train_replay.stats, prefix='replay')
96 | logger.add(eval_replay.stats, prefix='eval_replay')
97 | logger.add(timer.stats(), prefix='timer')
98 | logger.write(fps=True)
99 | driver_train.on_step(train_step)
100 |
101 | checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
102 | checkpoint.step = step
103 | checkpoint.agent = agent
104 | checkpoint.train_replay = train_replay
105 | checkpoint.eval_replay = eval_replay
106 | if args.from_checkpoint:
107 | checkpoint.load(args.from_checkpoint)
108 | checkpoint.load_or_save()
109 | should_save(step) # Register that we just saved.
110 |
111 | print('Start training loop.')
112 | policy_train = lambda *args: agent.policy(
113 | *args, mode='explore' if should_expl(step) else 'train')
114 | policy_eval = lambda *args: agent.policy(*args, mode='eval')
115 | while step < args.steps:
116 | if should_eval(step):
117 | print('Starting evaluation at step', int(step))
118 | driver_eval.reset()
119 | driver_eval(policy_eval, episodes=max(len(eval_env), args.eval_eps))
120 | driver_train(policy_train, steps=100)
121 | if should_save(step):
122 | checkpoint.save()
123 | logger.write()
124 | logger.write()
125 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/run/train_holdout.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import embodied
4 | import numpy as np
5 |
6 |
7 | def train_holdout(agent, env, train_replay, eval_replay, logger, args):
8 |
9 |   logdir = embodied.Path(args.logdir)
10 |   logdir.mkdirs()
11 |   print('Logdir', logdir)
12 |   should_expl = embodied.when.Until(args.expl_until)
13 |   should_train = embodied.when.Ratio(args.train_ratio / args.batch_steps)
14 |   should_log = embodied.when.Clock(args.log_every)
15 |   should_save = embodied.when.Clock(args.save_every)
16 |   should_sync = embodied.when.Every(args.sync_every)
17 |   step = logger.step
18 |   updates = embodied.Counter()
19 |   metrics = embodied.Metrics()
20 |   print('Observation space:', embodied.format(env.obs_space), sep='\n')
21 |   print('Action space:', embodied.format(env.act_space), sep='\n')
22 |
23 |   timer = embodied.Timer()
24 |   timer.wrap('agent', agent, ['policy', 'train', 'report', 'save'])
25 |   timer.wrap('env', env, ['step'])
26 |   if hasattr(train_replay, '_sample'):
27 |     timer.wrap('replay', train_replay, ['_sample'])
28 |
29 |   nonzeros = set()
30 |   def per_episode(ep):
31 |     length = len(ep['reward']) - 1
32 |     score = float(ep['reward'].astype(np.float64).sum())
33 |     logger.add({
34 |         'length': length, 'score': score,
35 |         'reward_rate': (ep['reward'] - ep['reward'].min() >= 0.1).mean(),
36 |     }, prefix='episode')
37 |     print(f'Episode has {length} steps and return {score:.1f}.')
38 |     stats = {}
39 |     for key in args.log_keys_video:
40 |       if key in ep:
41 |         stats[f'policy_{key}'] = ep[key]
42 |     for key, value in ep.items():
43 |       if not args.log_zeros and key not in nonzeros and (value == 0).all():
44 |         continue
45 |       nonzeros.add(key)
46 |       if re.match(args.log_keys_sum, key):
47 |         stats[f'sum_{key}'] = ep[key].sum()
48 |       if re.match(args.log_keys_mean, key):
49 |         stats[f'mean_{key}'] = ep[key].mean()
50 |       if re.match(args.log_keys_max, key):
51 |         stats[f'max_{key}'] = ep[key].max(0).mean()
52 |     metrics.add(stats, prefix='stats')
53 |
54 |   driver = embodied.Driver(env)
55 |   driver.on_episode(lambda ep, worker: per_episode(ep))
56 |   driver.on_step(lambda tran, _: step.increment())
57 |   driver.on_step(train_replay.add)
58 |
59 |   print('Fill eval dataset.')
60 |   driver_eval = embodied.Driver(env)
61 |   driver_eval.on_step(eval_replay.add)
62 |   random_agent = embodied.RandomAgent(env.act_space)
63 |   while len(eval_replay) < max(args.batch_steps, args.eval_fill):
64 |     print(len(eval_replay), max(args.batch_steps, args.eval_fill))
65 |     driver_eval(random_agent.policy, steps=100)
66 |   del driver_eval
67 |   print('Prefill train dataset.')
68 |   while len(train_replay) < max(args.batch_steps, args.train_fill):
69 |     print(len(train_replay), max(args.batch_steps, args.train_fill))
70 |     driver(random_agent.policy, steps=100)
71 |   logger.add(metrics.result())
72 |   logger.write()
73 |
74 |   dataset_train = agent.dataset(train_replay.dataset)
75 |   dataset_eval = agent.dataset(eval_replay.dataset)
76 |   state = [None]  # To be writable from train step function below.
77 |   batch = [None]
78 |   def train_step(tran, worker):
79 |     for _ in range(should_train(step)):
80 |       with timer.scope('dataset_train'):
81 |         batch[0] = next(dataset_train)
82 |       outs, state[0], mets = agent.train(batch[0], state[0])
83 |       metrics.add(mets, prefix='train')
84 |       if 'priority' in outs:
85 |         train_replay.prioritize(outs['key'], outs['priority'])
86 |       updates.increment()
87 |     if should_sync(updates):
88 |       agent.sync()
89 |     if should_log(step):
90 |       logger.add(metrics.result())
91 |       logger.add(agent.report(batch[0]), prefix='report')
92 |       with timer.scope('dataset_eval'):
93 |         eval_batch = next(dataset_eval)
94 |       logger.add(agent.report(eval_batch), prefix='eval')
95 |       logger.add(train_replay.stats, prefix='replay')
96 |       logger.add(eval_replay.stats, prefix='eval_replay')
97 |       logger.add(timer.stats(), prefix='timer')
98 |       logger.write(fps=True)
99 |   driver.on_step(train_step)
100 |
101 |   checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
102 |   checkpoint.step = step
103 |   checkpoint.agent = agent
104 |   checkpoint.train_replay = train_replay
105 |   checkpoint.eval_replay = eval_replay
106 |   if args.from_checkpoint:
107 |     checkpoint.load(args.from_checkpoint)
108 |   checkpoint.load_or_save()
109 |   should_save(step)  # Register that we just saved.
110 |
111 |   print('Start training loop.')
112 |   policy = lambda *args: agent.policy(
113 |       *args, mode='explore' if should_expl(step) else 'train')
114 |   while step < args.steps:
115 |     # scalars = collections.defaultdict(list)
116 |     # for _ in range(args.eval_samples):
117 |     #   for key, value in agent.report(next(dataset_eval)).items():
118 |     #     if value.shape == ():
119 |     #       scalars[key].append(value)
120 |     # for name, values in scalars.items():
121 |     #   logger.scalar(f'eval/{name}', np.array(values, np.float64).mean())
122 |     # logger.write()
123 |     driver(policy, steps=100)
124 |     if should_save(step):
125 |       checkpoint.save()
126 |   logger.write()
127 |   logger.write()
128 |
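
The `should_train = embodied.when.Ratio(args.train_ratio / args.batch_steps)` scheduler used in `train_step` decides how many gradient updates to run after each environment step. `embodied/core/when.py` is not reproduced in this part of the dump, so the block below is only a minimal sketch of how such a ratio scheduler could behave. The class name `RatioSketch` and its carry-over rule are assumptions made for illustration, not the repository's implementation.

```python
class RatioSketch:
    """Hypothetical stand-in for a train-ratio scheduler (not embodied.when.Ratio)."""

    def __init__(self, ratio):
        assert ratio >= 0, ratio
        self._ratio = ratio   # desired number of updates per environment step
        self._prev = None     # step position already covered by past updates

    def __call__(self, step):
        step = int(step)
        if self._ratio == 0:
            return 0
        if self._prev is None:
            # First call: remember the starting step and allow one update.
            self._prev = step
            return 1
        # Whole number of updates owed since the last accounted position.
        repeats = int((step - self._prev) * self._ratio)
        # Advance only by the steps those updates cover, so fractions carry over.
        self._prev += repeats / self._ratio
        return repeats


should_train = RatioSketch(0.5)  # one update for every two environment steps
updates_per_step = [should_train(step) for step in range(6)]
print(updates_per_step)
```

With a ratio below one, most calls return zero and the fractional remainder is carried forward, which is why `train_step` can safely loop over `range(should_train(step))` on every environment step.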
--------------------------------------------------------------------------------
/recall2imagine/embodied/run/train_save.py:
--------------------------------------------------------------------------------
1 | import io
2 | import re
3 | from datetime import datetime
4 |
5 | import embodied
6 | import numpy as np
7 |
8 |
9 | def train_save(agent, env, replay, logger, args):
10 |
11 |   logdir = embodied.Path(args.logdir)
12 |   logdir.mkdirs()
13 |   print('Logdir:', logdir)
14 |   should_expl = embodied.when.Until(args.expl_until)
15 |   should_train = embodied.when.Ratio(args.train_ratio / args.batch_steps)
16 |   should_log = embodied.when.Clock(args.log_every)
17 |   should_save = embodied.when.Clock(args.save_every)
18 |   should_sync = embodied.when.Every(args.sync_every)
19 |   step = logger.step
20 |   updates = embodied.Counter()
21 |   metrics = embodied.Metrics()
22 |   print('Observation space:', embodied.format(env.obs_space), sep='\n')
23 |   print('Action space:', embodied.format(env.act_space), sep='\n')
24 |
25 |   timer = embodied.Timer()
26 |   timer.wrap('agent', agent, ['policy', 'train', 'report', 'save'])
27 |   timer.wrap('env', env, ['step'])
28 |   timer.wrap('replay', replay, ['add', 'save'])
29 |   timer.wrap('logger', logger, ['write'])
30 |
31 |   nonzeros = set()
32 |   def per_episode(ep):
33 |     length = len(ep['reward']) - 1
34 |     score = float(ep['reward'].astype(np.float64).sum())
35 |     sum_abs_reward = float(np.abs(ep['reward']).astype(np.float64).sum())
36 |     logger.add({
37 |         'length': length,
38 |         'score': score,
39 |         'sum_abs_reward': sum_abs_reward,
40 |         'reward_rate': (np.abs(ep['reward']) >= 0.5).mean(),
41 |     }, prefix='episode')
42 |     print(f'Episode has {length} steps and return {score:.1f}.')
43 |     stats = {}
44 |     for key in args.log_keys_video:
45 |       if key in ep:
46 |         stats[f'policy_{key}'] = ep[key]
47 |     for key, value in ep.items():
48 |       if not args.log_zeros and key not in nonzeros and (value == 0).all():
49 |         continue
50 |       nonzeros.add(key)
51 |       if re.match(args.log_keys_sum, key):
52 |         stats[f'sum_{key}'] = ep[key].sum()
53 |       if re.match(args.log_keys_mean, key):
54 |         stats[f'mean_{key}'] = ep[key].mean()
55 |       if re.match(args.log_keys_max, key):
56 |         stats[f'max_{key}'] = ep[key].max(0).mean()
57 |     metrics.add(stats, prefix='stats')
58 |
59 |   epsdir = embodied.Path(args.logdir) / 'saved_episodes'
60 |   epsdir.mkdirs()
61 |   print('Saving episodes:', epsdir)
62 |   def save(ep):
63 |     time = datetime.now().strftime("%Y%m%dT%H%M%S")
64 |     uuid = str(embodied.uuid())
65 |     score = str(np.round(ep['reward'].sum(), 1)).replace('-', 'm')
66 |     length = len(ep['reward'])
67 |     filename = epsdir / f'{time}-{uuid}-len{length}-rew{score}.npz'
68 |     with io.BytesIO() as stream:
69 |       np.savez_compressed(stream, **ep)
70 |       stream.seek(0)
71 |       filename.write(stream.read(), mode='wb')
72 |     print('Saved episode:', filename)
73 |   saver = embodied.Worker(save, 'thread')
74 |
75 |   driver = embodied.Driver(env)
76 |   driver.on_episode(lambda ep, worker: per_episode(ep))
77 |   driver.on_episode(lambda ep, worker: saver(ep))
78 |   driver.on_step(lambda tran, _: step.increment())
79 |   driver.on_step(replay.add)
80 |
81 |   print('Prefill train dataset.')
82 |   random_agent = embodied.RandomAgent(env.act_space)
83 |   while len(replay) < max(args.batch_steps, args.train_fill):
84 |     driver(random_agent.policy, steps=100)
85 |   logger.add(metrics.result())
86 |   logger.write()
87 |
88 |   dataset = agent.dataset(replay.dataset)
89 |   state = [None]  # To be writable from train step function below.
90 |   batch = [None]
91 |   def train_step(tran, worker):
92 |     for _ in range(should_train(step)):
93 |       with timer.scope('dataset'):
94 |         batch[0] = next(dataset)
95 |       outs, state[0], mets = agent.train(batch[0], state[0])
96 |       metrics.add(mets, prefix='train')
97 |       if 'priority' in outs:
98 |         replay.prioritize(outs['key'], outs['priority'])
99 |       updates.increment()
100 |     if should_sync(updates):
101 |       agent.sync()
102 |     if should_log(step):
103 |       agg = metrics.result()
104 |       report = agent.report(batch[0])
105 |       report = {k: v for k, v in report.items() if 'train/' + k not in agg}
106 |       logger.add(agg)
107 |       logger.add(report, prefix='report')
108 |       logger.add(replay.stats, prefix='replay')
109 |       logger.add(timer.stats(), prefix='timer')
110 |       logger.write(fps=True)
111 |   driver.on_step(train_step)
112 |
113 |   checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
114 |   timer.wrap('checkpoint', checkpoint, ['save', 'load'])
115 |   checkpoint.step = step
116 |   checkpoint.agent = agent
117 |   checkpoint.replay = replay
118 |   if args.from_checkpoint:
119 |     checkpoint.load(args.from_checkpoint)
120 |   checkpoint.load_or_save()
121 |   should_save(step)  # Register that we just saved.
122 |
123 |   print('Start training loop.')
124 |   policy = lambda *args: agent.policy(
125 |       *args, mode='explore' if should_expl(step) else 'train')
126 |   while step < args.steps:
127 |     driver(policy, steps=100)
128 |     if should_save(step):
129 |       checkpoint.save()
130 |   logger.write()
131 |
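
The `save` callback in `train_save.py` encodes metadata in each file name as `{time}-{uuid}-len{length}-rew{score}.npz`, replacing the minus sign of a negative return with `m` so it cannot be confused with the `-` separators. The sketch below shows one way to read that metadata back. The helper `parse_episode_filename` and the example file name are hypothetical, and the pattern assumes the score is written in plain decimal form (for example `12.5` or `m3.0`).

```python
import re

# Matches the "{time}-{uuid}-len{length}-rew{score}.npz" names written by save().
# Assumption: the score is a plain decimal such as "12.5" or "m3.0".
_EPISODE_RE = re.compile(
    r'^(?P<time>\d{8}T\d{6})-(?P<uuid>.+)-len(?P<length>\d+)'
    r'-rew(?P<score>m?\d+(?:\.\d+)?)\.npz$')


def parse_episode_filename(name):
    """Hypothetical helper: recover metadata from a saved-episode file name."""
    match = _EPISODE_RE.match(name)
    if match is None:
        raise ValueError(f'Unrecognized episode filename: {name}')
    score = match.group('score')
    # save() replaced the minus sign with "m"; undo that here.
    score = -float(score[1:]) if score.startswith('m') else float(score)
    return {
        'time': match.group('time'),
        'uuid': match.group('uuid'),
        'length': int(match.group('length')),
        'score': score,
    }


info = parse_episode_filename('20240101T120000-abc123-len501-rewm3.5.npz')
print(info['length'], info['score'])
```

Note that the `length` stored in the file name is `len(ep['reward'])`, which is one larger than the `length` value that `per_episode` logs.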
--------------------------------------------------------------------------------
/recall2imagine/embodied/scripts/install-atari.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | set -eu
3 |
4 | apt-get update
5 | apt-get install -y wget
6 | apt-get install -y unrar
7 | apt-get clean
8 |
9 | pip3 install gym==0.19.0
10 | pip3 install atari-py==0.2.9
11 | pip3 install opencv-python
12 |
13 | mkdir roms && cd roms
14 | wget -L -nv http://www.atarimania.com/roms/Roms.rar
15 | unrar x -o+ Roms.rar
16 | python3 -m atari_py.import_roms ROMS
17 | cd .. && rm -rf roms
18 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/scripts/install-dmlab.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | set -eu
3 |
4 | # Dependencies
5 | apt-get update && apt-get install -y \
6 | build-essential curl freeglut3 gettext git libffi-dev libglu1-mesa \
7 | libglu1-mesa-dev libjpeg-dev liblua5.1-0-dev libosmesa6-dev \
8 | libsdl2-dev lua5.1 pkg-config python-setuptools python3-dev \
9 | software-properties-common unzip zip zlib1g-dev g++
10 | pip3 install numpy
11 |
12 | # Bazel
13 | apt-get install -y apt-transport-https curl gnupg
14 | curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor > bazel.gpg
15 | mv bazel.gpg /etc/apt/trusted.gpg.d/
16 | echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list
17 | apt-get update && apt-get install -y bazel
18 |
19 | # Build
20 | git clone https://github.com/deepmind/lab.git
21 | cd lab
22 | echo 'build --cxxopt=-std=c++17' > .bazelrc
23 | bazel build -c opt //python/pip_package:build_pip_package
24 | ./bazel-bin/python/pip_package/build_pip_package /tmp/dmlab_pkg
25 | pip3 install --force-reinstall /tmp/dmlab_pkg/deepmind_lab-*.whl
26 | cd ..
27 | rm -rf lab
28 |
29 | # Dataset
30 | mkdir dmlab_data
31 | cd dmlab_data
32 | pip3 install Pillow
33 | curl https://bradylab.ucsd.edu/stimuli/ObjectsAll.zip -o ObjectsAll.zip
34 | unzip ObjectsAll.zip
35 | cd OBJECTSALL
36 | python3 << EOM
37 | import os
38 | from PIL import Image
39 | files = [f for f in os.listdir('.') if f.lower().endswith('jpg')]
40 | for i, file in enumerate(sorted(files)):
41 |   print(file)
42 |   im = Image.open(file)
43 |   im.save('../%04d.png' % (i+1))
44 | EOM
45 | cd ..
46 | rm -rf __MACOSX OBJECTSALL ObjectsAll.zip
47 |
48 | apt-get clean
49 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/scripts/install-minecraft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | set -eu
3 |
4 | apt-get update
5 | apt-get install -y libgl1-mesa-dev
6 | apt-get install -y libx11-6
7 | apt-get install -y openjdk-8-jdk
8 | apt-get install -y x11-xserver-utils
9 | apt-get install -y xvfb
10 | apt-get clean
11 |
12 | pip3 install minerl==0.4.4
13 |
--------------------------------------------------------------------------------
/recall2imagine/embodied/scripts/xvfb_run.sh:
--------------------------------------------------------------------------------
1 | xvfb-run -a -s "-screen 0 1024x768x24 -ac +extension GLX +render -noreset" "$@"
2 | # xvfb-run "$@"
3 |
--------------------------------------------------------------------------------
/recall2imagine/expl.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | tree_map = jax.tree_util.tree_map
4 | sg = lambda x: tree_map(jax.lax.stop_gradient, x)
5 |
6 | from . import nets
7 | from . import jaxutils
8 | from . import ninjax as nj
9 |
10 |
11 | class Disag(nj.Module):
12 |
13 |   def __init__(self, wm, act_space, config):
14 |     self.config = config.update({'disag_head.inputs': ['tensor']})
15 |     self.opt = jaxutils.Optimizer(name='disag_opt', **config.expl_opt)
16 |     self.inputs = nets.Input(config.disag_head.inputs, dims='deter')
17 |     self.target = nets.Input(self.config.disag_target, dims='deter')
18 |     self.nets = [
19 |         nets.MLP(shape=None, **self.config.disag_head, name=f'disag{i}')
20 |         for i in range(self.config.disag_models)]
21 |
22 |   def __call__(self, traj):
23 |     inp = self.inputs(traj)
24 |     preds = jnp.array([net(inp).mode() for net in self.nets])
25 |     return preds.std(0).mean(-1)[1:]
26 |
27 |   def train(self, data):
28 |     return self.opt(self.nets, self.loss, data)
29 |
30 |   def loss(self, data):
31 |     inp = sg(self.inputs(data)[:, :-1])
32 |     tar = sg(self.target(data)[:, 1:])
33 |     losses = []
34 |     for net in self.nets:
35 |       net._shape = tar.shape[2:]
36 |       losses.append(-net(inp).log_prob(tar).mean())
37 |     return jnp.array(losses).sum()
38 |
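
`Disag.__call__` turns the ensemble predictions into an exploration signal by taking the standard deviation across ensemble members and then the mean across features (`preds.std(0).mean(-1)`). The block below is a minimal standard-library sketch of that reduction for a single time step. The function name `disagreement` and the toy numbers are assumptions for illustration; it uses the population standard deviation, which is the default of `jnp.std`.

```python
from statistics import fmean, pstdev


def disagreement(preds):
    """Mean over features of the std across ensemble members.

    preds[k][d] is the prediction of ensemble member k for feature d
    (a single time step, for illustration only).
    """
    num_features = len(preds[0])
    per_feature_std = [
        pstdev([member[d] for member in preds])  # population std, ddof=0
        for d in range(num_features)]
    return fmean(per_feature_std)


# Two members disagree on feature 0 and agree on feature 1.
preds = [[1.0, 3.0], [3.0, 3.0]]
bonus = disagreement(preds)
print(bonus)
```

Features on which all members agree contribute zero, so the value grows only where the ensemble's predictions diverge.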
--------------------------------------------------------------------------------
/recall2imagine/ssm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chandar-lab/Recall2Imagine/6e317e751ffd8f381e09476f3633567dff9f9234/recall2imagine/ssm/__init__.py
--------------------------------------------------------------------------------
/recall2imagine/ssm/siso.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from flax import linen as nn
4 | import jax
5 | import jax.numpy as np
6 | from jax.numpy.linalg import eigh, inv, matrix_power
7 | from jax.nn.initializers import normal
8 |
9 | from .common import log_step_initializer, \
10 | slow_scan, \
11 | fast_scan, \
12 | depthwise, \
13 | SequenceBlock, \
14 | batchwise
15 |
16 |
17 | def make_HiPPO(N):
18 |     """
19 |     A standard HiPPO initialization.
20 |     """
21 |     P = np.sqrt(1 + 2 * np.arange(N))
22 |     A = P[:, np.newaxis] * P[np.newaxis, :]
23 |     A = np.tril(A) - np.diag(np.arange(N))
24 |     return -A
25 |
26 | def discrete_DPLR(Lambda, P, Q, B, C, step, L):
27 |     """
28 |     DPLR matrix discretization function.
29 |     """
30 |     # Convert parameters to matrices
31 |     B = B[:, np.newaxis]
32 |     Ct = C[np.newaxis, :]
33 |
34 |     N = Lambda.shape[0]
35 |     A = np.diag(Lambda) - P[:, np.newaxis] @ Q[:, np.newaxis].conj().T
36 |     I = np.eye(N)
37 |
38 |     # Forward Euler
39 |     A0 = (2.0 / step) * I + A
40 |
41 |     # Backward Euler
42 |     D = np.diag(1.0 / ((2.0 / step) - Lambda))
43 |     Qc = Q.conj().T.reshape(1, -1)
44 |     P2 = P.reshape(-1, 1)
45 |     A1 = D - (D @ P2 * (1.0 / (1 + (Qc @ D @ P2))) * Qc @ D)
46 |
47 |     # A bar and B bar
48 |     Ab = A1 @ A0
49 |     Bb = 2 * A1 @ B
50 |
51 |     # Recover Cbar from Ct
52 |     Cb = Ct @ inv(I - matrix_power(Ab, L)).conj()
53 |     return Ab, Bb, Cb.conj()
54 |
55 | def make_NPLR_HiPPO(N):
56 |     """
57 |     Creates a HiPPO matrix and discretizes it.
58 |     """
59 |     # Make -HiPPO
60 |     nhippo = make_HiPPO(N)
61 |
62 |     # Add in a rank 1 term. Makes it Normal.
63 |     P = np.sqrt(np.arange(N) + 0.5)
64 |
65 |
66 |     # HiPPO also specifies the B matrix
67 |     B = np.sqrt(2 * np.arange(N) + 1.0)
68 |     return nhippo, P, B
69 |
70 |
71 | def make_DPLR_HiPPO(N):
72 |     """
73 |     Diagonalize NPLR representation
74 |     """
75 |     A, P, B = make_NPLR_HiPPO(N)
76 |
77 |     S = A + P[:, np.newaxis] * P[np.newaxis, :]
78 |
79 |     # Check skew symmetry
80 |     S_diag = np.diagonal(S)
81 |     Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)
82 |     # assert np.allclose(Lambda_real, S_diag, atol=1e-3)
83 |
84 |     # Diagonalize S to V \Lambda V^*
85 |     Lambda_imag, V = eigh(S * -1j)
86 |
87 |     P = V.conj().T @ P
88 |     B = V.conj().T @ B
89 |     return Lambda_real + 1j * Lambda_imag, P, B, V
90 |
91 | def init(x):
92 |     """
93 |     Factory for constant initializer in Flax
94 |     """
95 |     def _init(key, shape):
96 |         assert shape == x.shape
97 |         return x
98 |
99 |     return _init
100 |
101 | def hippo_initializer(N):
102 |     """
103 |     The initializer function for the DPLR matrices.
104 |     """
105 |     Lambda, P, B, _ = make_DPLR_HiPPO(N)
106 |     return init(Lambda.real), init(Lambda.imag), init(P), init(B)
107 |
108 |
109 | @depthwise
110 | class SISOLayer(nn.Module):
111 |     """
112 |     The SISO SSM (S4).
113 |     """
114 |     N: int
115 |     l_max: int  # apparently in rec mode this has no influence on anything
116 |     parallel: bool = False
117 |     conv: bool = False
118 |
119 |     # Special parameters with multiplicative factor on lr and no weight decay (handled by main train script)
120 |     lr = {
121 |         "Lambda_re": 0.1,
122 |         "Lambda_im": 0.1,
123 |         "P": 0.1,
124 |         "B": 0.1,
125 |         "log_step": 0.1,
126 |     }
127 |
128 |     def setup(self):
129 |         """
130 |         Initializes the SSM parameters.
131 |         """
132 |         # Learned Parameters (C is complex!)
133 |         init_A_re, init_A_im, init_P, init_B = hippo_initializer(self.N)
134 |         self.Lambda_re = self.param("Lambda_re", init_A_re, (self.N,))
135 |         self.Lambda_im = self.param("Lambda_im", init_A_im, (self.N,))
136 |         # Ensure the real part of Lambda is negative
137 |         # (described in the SaShiMi follow-up to S4)
138 |         self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
139 |         self.P = self.param("P", init_P, (self.N,))
140 |         self.B = self.param("B", init_B, (self.N,))
141 |         # C should be init as standard normal
142 |         # This doesn't work due to how JAX handles complex optimizers https://github.com/deepmind/optax/issues/196
143 |         # self.C = self.param("C", normal(stddev=1.0, dtype=np.complex64), (self.N,))
144 |         self.C = self.param("C", normal(stddev=0.5**0.5), (self.N, 2))
145 |         self.C = self.C[..., 0] + 1j * self.C[..., 1]
146 |         self.D = self.param("D", nn.initializers.ones, (1,))
147 |         self.step = np.exp(self.param("log_step", log_step_initializer(), (1,)))
148 |         # Flax trick to cache discrete form during decoding.
149 |         def init_discrete():
150 |             return discrete_DPLR(
151 |                 self.Lambda,
152 |                 self.P,
153 |                 self.P,
154 |                 self.B,
155 |                 self.C,
156 |                 self.step,
157 |                 self.l_max,
158 |             )
159 |         self.ssm = init_discrete()
160 |
161 |     def __call__(self, u, x0, dones=None):
162 |         """
163 |         Forward pass for SSM.
164 |
165 |         when parallel=True:
166 |             Shape:
167 |                 u.shape: (len, dims)
168 |                 u.dtype: float32
169 |                 x_0.shape: (N, dims)
170 |                 x_0.dtype: complex64
171 |             Output Shape:
172 |                 y.shape: (len, dims)
173 |                 y.dtype: float32
174 |                 x.shape: (N, dims)
175 |                 x.dtype: complex64
176 |         """
177 |         if not self.parallel or u.shape[0] == 1:
178 |             x, y = slow_scan(*self.ssm, u, x0)
179 |             Du = jax.vmap(lambda u: self.D * u)(u)
180 |             return (y.real + Du).reshape(-1), x
181 |         y, x = fast_scan(*self.ssm, u, x0, dones)  # .real is happening inside of fast_scan
182 |         Du = jax.vmap(lambda u: self.D * u)(u)
183 |         return (y + Du)[..., 0], x
184 |
185 |
186 | SISOBlock = partial(SequenceBlock, layer_cls=SISOLayer)
187 | SISOBlock = batchwise(SISOBlock)
188 | SISOLayerB = batchwise(SISOLayer)
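
The "Check skew symmetry" step in `make_DPLR_HiPPO` relies on a property of the matrices built by `make_HiPPO` and `make_NPLR_HiPPO`: after adding the rank-one term `P P^T`, the result has a constant diagonal of -1/2 and is skew-symmetric apart from that diagonal. The block below is a small pure-Python check of that property, written as a sketch with hypothetical helper names (`make_hippo`, `add_rank_one`) rather than the JAX functions above.

```python
import math


def make_hippo(n):
    """Pure-Python closed form of the matrix returned by make_HiPPO (that is, -A)."""
    a = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for k in range(n):
            if i > k:
                a[i][k] = -math.sqrt((2 * i + 1) * (2 * k + 1))
            elif i == k:
                a[i][k] = -(i + 1.0)
    return a


def add_rank_one(a):
    """S = A + P P^T with P_i = sqrt(i + 1/2), as in make_DPLR_HiPPO."""
    n = len(a)
    p = [math.sqrt(i + 0.5) for i in range(n)]
    return [[a[i][k] + p[i] * p[k] for k in range(n)] for i in range(n)]


n = 4
s = add_rank_one(make_hippo(n))
# S + S^T should equal -I: diagonal entries of -1/2, skew-symmetric otherwise.
max_residual = max(
    abs(s[i][k] + s[k][i] + (1.0 if i == k else 0.0))
    for i in range(n) for k in range(n))
print(max_residual)
```

Because `S + I/2` is skew-symmetric, `S * -1j` is Hermitian up to a real shift of the diagonal, which is what makes the `eigh` call in `make_DPLR_HiPPO` applicable and why the real part of `Lambda` is taken from the mean of the diagonal.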
--------------------------------------------------------------------------------
/recall2imagine/ssm_nets.py:
--------------------------------------------------------------------------------
1 | from . import ninjax as nj
2 | from . import jaxutils
3 |
4 | from .ssm.siso import SISOBlock
5 | from .ssm.mimo import MIMOBlock
6 | from .nets import RSSM, Linear
7 |
8 | import jax
9 | from jax import numpy as jnp
10 | from tensorflow_probability.substrates import jax as tfp
11 | f32 = jnp.float32
12 | tfd = tfp.distributions
13 | tree_map = jax.tree_util.tree_map
14 | sg = lambda x: tree_map(jax.lax.stop_gradient, x)
15 |
16 | cast = jaxutils.cast_to_compute
17 |
18 | def init_siso_ssm(N, d_model, n_layers, parallel=True, conv=False, use_norm=False, **kw):
19 |     """
20 |     The wrapper function that wraps SISO SSM into FlaxModule for compatibility with the main code.
21 |     """
22 |     if 'l_max' not in kw:
23 |         kw['l_max'] = 255
24 |
25 |     return nj.FlaxModule(
26 |         SISOBlock,
27 |         layer={
28 |             'N': N,
29 |             'l_max': kw['l_max'],
30 |             'parallel': parallel,
31 |             'conv': conv
32 |         },
33 |         name='siso_ssm',
34 |         d_model=d_model,
35 |         n_layers=n_layers,
36 |         dropout=kw.get('dropout', 0.0),
37 |         prenorm=kw.get('prenorm', False),
38 |         mlp=kw.get('mlp', False),
39 |         glu=kw.get('glu', False),
40 |         use_norm=use_norm,
41 |     )
42 |
43 |
44 | def init_mimo_ssm(
45 |     P,
46 |     H,
47 |     n_blocks=1,
48 |     n_layers=1,
49 |     conj_sym=False,
50 |     C_init='',
51 |     discretization='zoh',
52 |     dt_min=0.001,
53 |     dt_max=0.1,
54 |     clip_eigs=False,
55 |     parallel=True,
56 |     use_norm=False,
57 |     reset_mode=False,
58 |     **kw
59 | ):
60 |     """
61 |     The wrapper function that wraps MIMO SSM into FlaxModule for compatibility with the main code.
62 |
63 |     kw:
64 |         n_blocks: int
65 |         conj_sym: bool
66 |         C_init:
67 |     """
68 |
69 |     if conj_sym:
70 |         P = P // 2
71 |
72 |     return nj.FlaxModule(
73 |         MIMOBlock,
74 |         layer={
75 |             # **{
76 |             'H': H,
77 |             'P': P,
78 |             'n_blocks': n_blocks,
79 |             'C_init': C_init,
80 |             'discretization': discretization,
81 |             'dt_min': dt_min,
82 |             'dt_max': dt_max,
83 |             'conj_sym': conj_sym,
84 |             'clip_eigs': clip_eigs,
85 |             'parallel': parallel,
86 |             'reset_mode': reset_mode,
87 |         },
88 |         # },
89 |         name='mimo_ssm',
90 |         d_model=H,
91 |         n_layers=n_layers,
92 |         dropout=kw.get('dropout', 0.0),
93 |         prenorm=kw.get('prenorm', False),
94 |         mlp=kw.get('mlp', False),
95 |         glu=kw.get('glu', False),
96 |         use_norm=use_norm,
97 |     )
98 |
99 |
100 | class S3M(RSSM):
101 |     """
102 |     This class implements RSSM with SISO/MIMO SSM cell.
103 |     """
104 |     def __init__(self, deter=1024, stoch=32, classes=32, units=1024, hidden=128, unroll=False, initial='learned',
105 |                  unimix=0.01, action_clip=1.0, nonrecurrent_enc=False, ssm='mimo', ssm_kwargs=None, **kw):
106 |         if ssm == 'siso':
107 |             self.core = init_siso_ssm(hidden, units, **ssm_kwargs)
108 |         elif ssm == 'mimo':
109 |             self.core = init_mimo_ssm(hidden, units, **ssm_kwargs)
110 |         else:
111 |             raise NotImplementedError("SSM is not implemented")
112 |         self._ssm = ssm
113 |         self._deter = deter
114 |         self._units = units
115 |         self._hidden = hidden
116 |         self._stoch = stoch
117 |         self._classes = classes
118 |         self._n_layers = ssm_kwargs['n_layers']
119 |         self._parallel = ssm_kwargs['parallel']
120 |         self._conv = ssm_kwargs['conv']
121 |         self._unroll = unroll
122 |         self._initial = initial
123 |         self._unimix = unimix
124 |         self._nonrecurrent_enc = nonrecurrent_enc
125 |         self._action_clip = action_clip
126 |         self._kw = kw
127 |         self._kw['units'] = units
128 |
129 |     def initial(self, bs):
130 |         """
131 |         Returns the initial vector for RSSM.
132 |         """
133 |         if self._classes:
134 |             state = dict(
135 |                 deter=jnp.zeros([bs, self._deter], f32),
136 |                 logit=jnp.zeros([bs, self._stoch, self._classes], f32),
137 |                 stoch=jnp.zeros([bs, self._stoch, self._classes], f32))
138 |         else:
139 |             state = dict(
140 |                 deter=jnp.zeros([bs, self._deter], f32),
141 |                 mean=jnp.zeros([bs, self._stoch], f32),
142 |                 std=jnp.ones([bs, self._stoch], f32),
143 |                 stoch=jnp.zeros([bs, self._stoch], f32))
144 |         if self._ssm == 'siso':
145 |             state['hidden'] = jnp.zeros([bs, self._n_layers, self._hidden, self._units], jnp.complex64)
146 |         if self._ssm == 'mimo':
147 |             state['hidden'] = jnp.zeros([bs, self._n_layers, self._hidden], jnp.complex64)
148 |         if self._initial == 'zeros':
149 |             state = cast(state)
150 |             state['hidden'] = state['hidden'].astype(jnp.complex64)
151 |             return state
152 |         elif self._initial == 'learned':
153 |             hidden = self.get('initial_hidden', jnp.zeros, (2,) + state['hidden'][0].shape, f32)
154 |             deter = self.get('initial_deter', jnp.zeros, state['deter'][0].shape, f32)
155 |             hidden = jnp.expand_dims(hidden[0] + 1j * hidden[1], 0)
156 |             state['hidden'] = jnp.repeat(hidden, bs, 0)
157 |             state['deter'] = jnp.repeat(jnp.tanh(deter)[None], bs, 0)
158 |             state['stoch'] = self.get_stoch(cast(state['deter']))
159 |             return cast(state)
160 |         else:
161 |             raise NotImplementedError(self._initial)
162 |
163 |     def _cell(self, x, prev_state):
164 |         """
165 |         Implements one step forward pass of RSSM.
166 |
167 |         x.shape == (batch, units)
168 |         prev_state {
169 |             'hidden' : shape (batch, self._hidden, self._units)
170 |         }
171 |         """
172 |         hidden = prev_state['hidden']  # this is x_t-1
173 |         x = x[:, jnp.newaxis]
174 |         # cell expected shape
175 |         #   u: (batch, 1, self._units)
176 |         #   xk: (batch, self._hidden, self._units)
177 |         output, hidden = self.core(x, hidden)  # y_t, x_t = S*(u_t-1, x_t-1)
178 |         #   y: (batch, 1, self._units)
179 |         #   xk1: (batch, self._hidden, self._units)
180 |         if isinstance(output, tuple):
181 |             output, _ = output
182 |         output = output[:, 0]
183 |         if self._ssm == 'siso':
184 |             # (batch, hidden, layers, units) -> (batch, layers, hidden, units)
185 |             hidden = jnp.transpose(hidden, (0, 2, 1, 3))
186 |         deter = output
187 |         kw = {'winit': 'normal', 'fan': 'avg', 'act': self._kw['act'], 'units': self._deter}
188 |         deter = self.get('out_proj', Linear, **kw)(deter)
189 |         # kw = {**self._kw, 'units': self._deter}
190 |         # deter = self.get('deter_proj', Linear, **kw)(output)
191 |         return deter, {'deter': deter, 'hidden': hidden}
192 |
193 |
194 |     def _cell_scan(self, x, state, first, zero_state):
195 |         """
196 |         Implements sequential forward pass for RSSM.
197 |         """
198 |         swap = lambda x: x.transpose([1, 0] + list(range(2, len(x.shape))))
199 |         if not self._parallel:
200 |             state, (outs, x) = super()._cell_scan(x, state, first, zero_state)
201 |             return state, (tree_map(swap, outs), swap(x))
202 |
203 |         if self._ssm == 'siso':
204 |             first = jnp.tile(first[..., None], (1, 1, x.shape[-1]))
205 |         x, outstate = self.core((swap(x), swap(first)), state['hidden'], zero_state['hidden'])  # hidden.shape = (batch, layers, hidden, units)
206 |         if isinstance(x, tuple):
207 |             x, _ = x
208 |         if self._ssm == 'siso':
209 |             # (batch, seq, units, layers, hidden) -> (batch, seq, layers, hidden, units)
210 |             outstate = jnp.transpose(outstate, (0, 1, 3, 4, 2))
211 |         kw = {'winit': 'normal', 'fan': 'avg', 'act': self._kw['act'], 'units': self._deter}
212 |         x = self.get('out_proj', Linear, **kw)(x)
213 |         return {'deter': x[:, -1], 'hidden': outstate[:, -1]}, ({'deter': x, 'hidden': outstate}, x)
214 |
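
Both `S3M.initial` (the `initial_hidden` parameter with a leading axis of size 2) and `SISOLayer.setup` (the `C` parameter of shape `(N, 2)`) store complex quantities as pairs of real numbers and combine them with `real + 1j * imag` only when they are used. The block below sketches that pattern in plain Python. The helper names are hypothetical, and the seeded sampling is only there to illustrate why a per-component standard deviation of `0.5**0.5` gives a unit-variance complex value.

```python
import random


def to_complex(pairs):
    """Combine (real, imag) pairs into complex numbers."""
    return [complex(real, imag) for real, imag in pairs]


def init_complex_normal(n, rng):
    """Draw n (real, imag) pairs whose combined complex value has unit variance.

    Each component uses std sqrt(0.5), so E|z|^2 = 0.5 + 0.5 = 1.
    """
    std = 0.5 ** 0.5
    return [(rng.gauss(0.0, std), rng.gauss(0.0, std)) for _ in range(n)]


rng = random.Random(0)
pairs = init_complex_normal(20000, rng)   # what an optimizer would see: reals only
c = to_complex(pairs)                     # what the model would use: complex values
mean_power = sum(abs(z) ** 2 for z in c) / len(c)
print(round(mean_power, 2))
```

Keeping the stored parameters real means gradients and optimizer state stay real-valued, which is the workaround the comment in `SISOLayer.setup` refers to.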
--------------------------------------------------------------------------------