├── .gitignore ├── LICENSE ├── README.md ├── env ├── Dockerfile ├── conda_env.yaml └── gym-0.19.0-py3-none-any.whl ├── imgs ├── mmaze.jpg ├── r2i.png └── teaser.jpg └── recall2imagine ├── Dockerfile ├── __init__.py ├── agent.py ├── behaviors.py ├── configs.yaml ├── configs_backup.yaml ├── embodied ├── __init__.py ├── core │ ├── __init__.py │ ├── base.py │ ├── basics.py │ ├── batch.py │ ├── batcher.py │ ├── checkpoint.py │ ├── config.py │ ├── counter.py │ ├── distr.py │ ├── driver.py │ ├── flags.py │ ├── logger.py │ ├── metrics.py │ ├── parallel.py │ ├── path.py │ ├── random.py │ ├── space.py │ ├── timer.py │ ├── uuid.py │ ├── when.py │ ├── worker.py │ └── wrappers.py ├── envs │ ├── atari.py │ ├── crafter.py │ ├── dmc.py │ ├── dmlab.py │ ├── dummy.py │ ├── from_dm.py │ ├── from_gym.py │ ├── loconav.py │ ├── loconav_quadruped.py │ ├── loconav_quadruped.xml │ ├── minecraft.py │ ├── minecraft_base.py │ ├── minecraft_minerl.py │ ├── pinpad.py │ └── robodesk.py ├── replay │ ├── __init__.py │ ├── chunk.py │ ├── generic.py │ ├── generic_lfs.py │ ├── lfs_manager.py │ ├── limiters.py │ ├── naive_chunks.py │ ├── replays.py │ ├── reverb.py │ ├── saver.py │ └── selectors.py ├── run │ ├── __init__.py │ ├── eval_only.py │ ├── parallel.py │ ├── train.py │ ├── train_eval.py │ ├── train_holdout.py │ └── train_save.py └── scripts │ ├── install-atari.sh │ ├── install-dmlab.sh │ ├── install-minecraft.sh │ ├── plot.py │ └── xvfb_run.sh ├── expl.py ├── jaxagent.py ├── jaxutils.py ├── nets.py ├── ninjax.py ├── ssm ├── __init__.py ├── common.py ├── mimo.py └── siso.py ├── ssm_nets.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache 2 | dist 3 | scripts/* 4 | __pycache__/ 5 | *.py[codi] 6 | *.egg-info 7 | MUJOCO_LOG.TXT 8 | ; 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 
2023 Danijar Hafner, Artem Zholus, Mohammad Reza Samsami 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mastering Memory Tasks with World Models 2 | 3 | --- 4 | 5 | > M. R. Samsami\*, A. Zholus\*, J. Rajendran, S. Chandar. Mastering Memory Tasks with World Models. ICLR 2024, top‑1.2% (oral)
6 | > __[[Project Website]](https://recall2imagine.github.io)__       __[[Paper]](https://arxiv.org/abs/2403.04253)__       __[[Openreview]](https://openreview.net/forum?id=1vDArHJ68h)__       __[[X Thread]](https://x.com/apsarathchandar/status/1772345328617824502)__ 7 | 8 | --- 9 | 10 | This repo contains the official implementation of the Recall to Imagine (R2I) algorithm, introduced in the paper 11 | **Mastering Memory Tasks with World Models**. R2I is a generalist, computationally efficient, model-based agent focusing on memory-intensive reinforcement learning (RL) tasks. It stands out by demonstrating superhuman performance in a complex memory domain, Memory Maze. The codebase, powered by JAX, offers an efficient framework for RL development, incorporating state space models. 12 | 13 | 14 |

15 | Teaser 16 |

17 | 18 | ## Table of Contents 19 | - [Prerequisites](#prerequisites) 20 | - [Installation](#installation) 21 | - [Conda](#conda) 22 | - [Docker & Singularity](#docker--singularity) 23 | - [Running Experiments with Containers](#running-experiments-with-containers) 24 | - [Reproducing Results from the Paper](#reproducing-results-from-the-paper) 25 | - [Memory Maze](#memory-maze) 26 | - [POPGym](#popgym) 27 | - [BSuite](#bsuite) 28 | - [BibTeX](#bibtex) 29 | - [Acknowledgements](#acknowledgements) 30 | 31 | ## Prerequisites 32 | 33 | Before diving into the installation and running experiments, ensure you meet the following prerequisites: 34 | 35 | | Requirement | Specification | 36 | |--------------------|----------------------------------------------------------------| 37 | | **Hardware** | GPU with CUDA support | 38 | | **Software** | Docker or Singularity for containerized environments, conda | 39 | | **Python Version** | 3.8 | 40 | | **Storage** | At least 130GB for the Memory Maze experiment's replay buffer | 41 | 42 | 43 | ## Installation 44 | 45 | ### Conda 46 | 47 | We provide a self-contained conda environment for running experiments. To setup the env do 48 | 49 | ```sh 50 | conda create -n recall2imagine python=3.8 51 | conda activate recall2imagine 52 | cd env && conda env update -f conda_env.yaml 53 | pip install \ 54 | "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ 55 | dm-haiku \ 56 | flax \ 57 | optax==0.1.5 58 | ``` 59 | 60 | Though conda environment should work, the recommended way of setting up the environment is through 61 | docker or singularity containers. See below for details. 62 | 63 | ### Docker & Singularity 64 | 65 | The most reliable way to reproduce the results from the paper is through containers. 66 | To build a docker image do 67 | 68 | ```sh 69 | cd env && docker build -t r2i . 70 | ``` 71 | 72 | This steps takes several mins to complete. 
Alternatively, you can just pull our pre-built image: 73 | 74 | ```sh 75 | docker pull artemzholus/r2i && docker tag artemzholus/r2i r2i 76 | ``` 77 | 78 | If you are going to run the code inside of a singularity container, convert the docker image to singularity. 79 | Note that the code below creates a lot of temporary files in the home directory; to avoid that, do 80 | 81 | ```sh 82 | export SINGULARITY_CACHEDIR= 83 | ``` 84 | 85 | If you're not sure whether you need this, just skip this step. 86 | 87 | To convert the image, do 88 | 89 | ```sh 90 | singularity pull docker://artemzholus/r2i 91 | ``` 92 | 93 | ## Running experiments with containers 94 | 95 | To run experiments with docker, launch an interactive docker container: 96 | 97 | ```sh 98 | docker run --rm -it --network host -v $(pwd):/code -w /code r2i bash 99 | ``` 100 | 101 | This will lead to the interactive shell within a container. Run any python training script 102 | within that shell. 103 | 104 | For singularity: 105 | 106 | ```sh 107 | singularity shell --nv 108 | ``` 109 | 110 | where the image path is the file name that was downloaded when doing `singularity pull` 111 | 112 | 113 | ## Reproducing results from the paper 114 | 115 | The code was designed with reproducibility in mind. 116 | 117 | ### Memory Maze 118 | 119 | **TL;DR:** use the following command: 120 | 121 | ```sh 122 | python recall2imagine/train.py \ 123 | --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \ 124 | --wdb_name memory_maze_9x9 \ 125 | --logdir ./logs_memory_maze_9x9 126 | ``` 127 | 128 | By default, the script uses 1 GPU. 
129 | To ensure that you match the training speed reported in our paper, consider using two GPUs 130 | (this does not affect the sample efficiency): 131 | 132 | ```sh 133 | python recall2imagine/train.py \ 134 | --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \ 135 | --wdb_name memory_maze_9x9 \ 136 | --jax.train_devices 0 1 \ 137 | --jax.policy_devices 0 \ 138 | --logdir ./logs_memory_maze_9x9 139 | ``` 140 | 141 | **NOTE**: By default the training script uses disk-based replay buffer. In memory maze our 142 | default replay buffer size is 10M image RL steps which takes about 130GB of space. 143 | Therefore, be sure to have enough space before running, since the scipt will allocate 144 | the whole file in the beginning of the training. 145 | 146 | memory-maze 147 | 148 | Since this experiment may take a long time, if you are a large compute cluster user, 149 | you may have a limited maximal duration of experiments. Exactly for that scenario, we designed a new 150 | replay buffer for efficient restarting of very long RL experiments. The idea is to keep two 151 | versions of the replay buffer - one in the fast, but temporary storage, the other in the slower but 152 | lifelong storage. Use the following pattern to continue your experiment after interruption: 153 | 154 | ```sh 155 | python recall2imagine/train.py \ 156 | --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \ 157 | --wdb_name memory_maze_9x9 \ 158 | --jax.train_devices 0 1 \ 159 | --jax.policy_devices 0 \ 160 | --use_lfs True --lfs_dir \ 161 | --logdir 162 | ``` 163 | 164 | Use this command for both first run and each next one. It knows how to deal with the state. 165 | This buffer is very reliable, it has a good fault tolerance and it is the way how we conducted 166 | our original experiments in the paper. 
167 | 168 | 169 | ### POPGym 170 | 171 | Use the following command: 172 | 173 | ```sh 174 | python recall2imagine/train.py \ 175 | --configs popgym --task gym_popgym:popgym-RepeatPreviousEasy-v0 \ 176 | --wdb_name popgym_repeat_previous_easy \ 177 | --logdir ./logs_popgym_repeat_previous_easy 178 | ``` 179 | 180 | ### BSuite 181 | 182 | (Coming soon!) 183 | 184 | ## BibTeX 185 | 186 | If you find our work useful, please cite our paper: 187 | 188 | ``` 189 | @inproceedings{ 190 | samsami2024mastering, 191 | title={Mastering Memory Tasks with World Models}, 192 | author={Mohammad Reza Samsami and Artem Zholus and Janarthanan Rajendran and Sarath Chandar}, 193 | booktitle={The Twelfth International Conference on Learning Representations}, 194 | year={2024}, 195 | url={https://openreview.net/forum?id=1vDArHJ68h} 196 | } 197 | ``` 198 | 199 | ## Acknowledgements 200 | 201 | We thank Danijar Hafner for providing the DreamerV3 implementation, which this repo is based upon. 202 | 203 | [arxiv]: https://github.com/chandar-lab/recall2imagine 204 | [website]: https://recall2imagine.github.io/ 205 | [twitter]: https://github.com/chandar-lab/recall2imagine 206 | -------------------------------------------------------------------------------- /env/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 2 | 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | RUN apt-get update && apt-get install -y \ 5 | git xvfb curl \ 6 | libglu1-mesa libglu1-mesa-dev libgl1-mesa-dev libosmesa6-dev mesa-utils freeglut3 freeglut3-dev \ 7 | libglew2.1 libglfw3 libglfw3-dev libegl-dev zlib1g zlib1g-dev libsdl2-dev libjpeg-dev lua5.1 liblua5.1-0-dev libffi-dev \ 8 | build-essential cmake g++ build-essential pkg-config software-properties-common gettext \ 9 | ffmpeg patchelf swig unrar unzip zip curl wget tmux \ 10 | && rm -rf /var/lib/apt/lists/* 11 | 12 | 13 | ENV CNDA=/conda 14 | RUN mkdir -p $CNDA && chmod 755 $CNDA 
15 | RUN curl -Lo $CNDA/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-4.5.11-Linux-x86_64.sh \ 16 | && chmod +x $CNDA/miniconda.sh \ 17 | && $CNDA/miniconda.sh -b -p $CNDA/miniconda \ 18 | && rm $CNDA/miniconda.sh 19 | ENV PATH=$CNDA/miniconda/bin:$PATH 20 | ENV CONDA_AUTO_UPDATE_CONDA=false 21 | 22 | # Create a Python 3.8 environment 23 | RUN $CNDA/miniconda/bin/conda create -y --name py38 python=3.8 \ 24 | && $CNDA/miniconda/bin/conda clean -ya 25 | ENV CONDA_DEFAULT_ENV=py38 26 | ENV CONDA_PREFIX=$CNDA/miniconda/envs/$CONDA_DEFAULT_ENV 27 | ENV PATH=$CONDA_PREFIX/bin:$PATH 28 | RUN $CNDA/miniconda/bin/conda clean -ya 29 | # Create conda environment 30 | # Atari 31 | RUN pip3 install --upgrade setuptools pip 32 | ADD ./gym-0.19.0-py3-none-any.whl /root 33 | RUN pip3 install \ 34 | tensorflow_probability==0.20.1 \ 35 | yacs==0.1.8 \ 36 | tensorflow-cpu==2.12.0 \ 37 | ruamel.yaml==0.17.32 \ 38 | ruamel.yaml.clib==0.2.7 \ 39 | moviepy==1.0.3 \ 40 | imageio==2.31.1 \ 41 | crafter==1.8.1 \ 42 | dm-control==1.0.12 \ 43 | mujoco==2.3.6 \ 44 | robodesk==1.0.0 \ 45 | bsuite==0.3.5 \ 46 | numpy==1.22.1 \ 47 | opt_einsum==3.3.0 \ 48 | einops==0.6.1 \ 49 | wandb==0.15.5 \ 50 | memory-maze==1.0.2 \ 51 | popgym==1.0.2 \ 52 | gymnasium==0.29.0 \ 53 | mazelib==0.9.13 \ 54 | procgen==0.10.7 \ 55 | atari-py==0.2.9 \ 56 | protobuf==3.20.3 \ 57 | gast==0.4.0 \ 58 | zmq \ 59 | /root/gym-0.19.0-py3-none-any.whl \ 60 | rich==13.7.0 \ 61 | msgpack==1.0.7 \ 62 | cloudpickle==1.6.0 \ 63 | opencv-python==4.8.0.74 64 | 65 | RUN pip3 install --upgrade \ 66 | flax \ 67 | dm-haiku \ 68 | "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ 69 | optax==0.1.5 70 | 71 | RUN wget -L -nv http://www.atarimania.com/roms/Roms.rar && \ 72 | unrar x -y Roms.rar && \ 73 | python3 -m atari_py.import_roms ROMS && \ 74 | rm -rf Roms.rar ROMS.zip ROMS 75 | RUN mkdir -p /root/.mujoco && \ 76 | cd /root/.mujoco && \ 77 | wget -nv 
https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz && \ 78 | tar -xf mujoco.tar.gz && \ 79 | rm mujoco.tar.gz 80 | 81 | ENV OMP_NUM_THREADS 1 82 | ENV PYTHONUNBUFFERED 1 83 | ENV LANG "C.UTF-8" 84 | ENV NUMBA_CACHE_DIR=/tmp 85 | ENV XLA_PYTHON_CLIENT_MEM_FRACTION 0.8 -------------------------------------------------------------------------------- /env/conda_env.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - defaults 4 | dependencies: 5 | - python=3.8 6 | - pip: 7 | - tensorflow_probability==0.20.1 8 | - optax==0.1.5 9 | - yacs==0.1.8 10 | - tensorflow-cpu==2.12.0 11 | - ruamel.yaml==0.17.32 12 | - ruamel.yaml.clib==0.2.7 13 | - moviepy==1.0.3 14 | - imageio==2.31.1 15 | - crafter==1.8.1 16 | - dm-control==1.0.12 17 | - mujoco==2.3.6 18 | - robodesk==1.0.0 19 | - bsuite==0.3.5 20 | - numpy==1.22.1 21 | - opt_einsum==3.3.0 22 | - einops==0.6.1 23 | - wandb==0.15.5 24 | - memory-maze==1.0.2 25 | - popgym==1.0.2 26 | - gymnasium==0.29.0 27 | - mazelib==0.9.13 28 | - procgen==0.10.7 29 | - flax==0.7.0 30 | - dm-haiku==0.0.10 31 | - atari-py==0.2.9 32 | - protobuf==3.20.3 33 | - gast==0.4.0 34 | - zmq 35 | - rich==13.7.0 36 | - msgpack==1.0.7 37 | - cloudpickle==1.6.0 38 | - opencv-python==4.8.0.74 39 | -------------------------------------------------------------------------------- /env/gym-0.19.0-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chandar-lab/Recall2Imagine/6e317e751ffd8f381e09476f3633567dff9f9234/env/gym-0.19.0-py3-none-any.whl -------------------------------------------------------------------------------- /imgs/mmaze.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chandar-lab/Recall2Imagine/6e317e751ffd8f381e09476f3633567dff9f9234/imgs/mmaze.jpg 
-------------------------------------------------------------------------------- /imgs/r2i.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chandar-lab/Recall2Imagine/6e317e751ffd8f381e09476f3633567dff9f9234/imgs/r2i.png -------------------------------------------------------------------------------- /imgs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chandar-lab/Recall2Imagine/6e317e751ffd8f381e09476f3633567dff9f9234/imgs/teaser.jpg -------------------------------------------------------------------------------- /recall2imagine/Dockerfile: -------------------------------------------------------------------------------- 1 | # 1. Test setup: 2 | # docker run -it --rm --gpus all nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04 nvidia-smi 3 | # 4 | # If the above does not work, try adding the --privileged flag 5 | # and changing the command to `sh -c 'ldconfig -v && nvidia-smi'`. 6 | # 7 | # 2. Start training: 8 | # docker build -f dreamerv3/Dockerfile -t img . && \ 9 | # docker run -it --rm --gpus all -v ~/logdir:/logdir img \ 10 | # sh scripts/xvfb_run.sh python3 dreamerv3/train.py \ 11 | # --logdir "/logdir/$(date +%Y%m%d-%H%M%S)" \ 12 | # --configs dmc_vision --task dmc_walker_walk 13 | # 14 | # 3. 
See results: 15 | # tensorboard --logdir ~/logdir 16 | 17 | # System 18 | FROM nvidia/cuda:11.4.2-cudnn8-devel-ubuntu20.04 19 | ARG DEBIAN_FRONTEND=noninteractive 20 | ENV TZ=America/San_Francisco 21 | ENV PYTHONUNBUFFERED 1 22 | ENV PIP_DISABLE_PIP_VERSION_CHECK 1 23 | ENV PIP_NO_CACHE_DIR 1 24 | RUN apt-get update && apt-get install -y \ 25 | ffmpeg git python3-pip vim libglew-dev \ 26 | x11-xserver-utils xvfb \ 27 | && apt-get clean 28 | RUN pip3 install --upgrade pip 29 | 30 | # Envs 31 | ENV MUJOCO_GL egl 32 | ENV DMLAB_DATASET_PATH /dmlab_data 33 | COPY scripts scripts 34 | RUN sh scripts/install-dmlab.sh 35 | RUN sh scripts/install-atari.sh 36 | RUN sh scripts/install-minecraft.sh 37 | ENV NUMBA_CACHE_DIR=/tmp 38 | RUN pip3 install crafter 39 | RUN pip3 install dm_control 40 | RUN pip3 install robodesk 41 | RUN pip3 install bsuite 42 | 43 | # Agent 44 | RUN pip3 install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 45 | RUN pip3 install jaxlib 46 | RUN pip3 install tensorflow_probability 47 | RUN pip3 install optax 48 | RUN pip3 install tensorflow-cpu 49 | ENV XLA_PYTHON_CLIENT_MEM_FRACTION 0.8 50 | 51 | # Google Cloud DNS cache (optional) 52 | ENV GCS_RESOLVE_REFRESH_SECS=60 53 | ENV GCS_REQUEST_CONNECTION_TIMEOUT_SECS=300 54 | ENV GCS_METADATA_REQUEST_TIMEOUT_SECS=300 55 | ENV GCS_READ_REQUEST_TIMEOUT_SECS=300 56 | ENV GCS_WRITE_REQUEST_TIMEOUT_SECS=600 57 | 58 | # Embodied 59 | RUN pip3 install numpy cloudpickle ruamel.yaml rich zmq msgpack 60 | COPY . 
class Greedy(nj.Module):
    """Behavior that follows the world model's predicted task reward.

    Builds a single 'extr' critic from the reward head and delegates every
    other call to an ``agent.ImagActorCritic``.
    """

    def __init__(self, wm, act_space, config):
        """Create the actor-critic.

        Args:
          wm: World model; only ``wm.heads['reward']`` is used here.
          act_space: Action space forwarded to the actor-critic.
          config: Config object; ``config.critic_type`` must be 'vfunction'.

        Raises:
          NotImplementedError: For any other ``config.critic_type``.
        """
        # Mean of the predicted reward, without its first entry along the
        # leading axis.
        rewfn = lambda s: wm.heads['reward'](s).mean()[1:]
        if config.critic_type == 'vfunction':
            critics = {'extr': agent.VFunction(rewfn, config, name='critic')}
        else:
            raise NotImplementedError(config.critic_type)
        self.ac = agent.ImagActorCritic(
            critics, {'extr': 1.0}, act_space, config, name='ac')

    def initial(self, batch_size):
        """Return the actor-critic's initial state for `batch_size`."""
        return self.ac.initial(batch_size)

    def policy(self, latent, state):
        """Delegate action selection to the actor-critic."""
        return self.ac.policy(latent, state)

    def train(self, imagine, start, data):
        """Delegate training to the actor-critic and return its result."""
        return self.ac.train(imagine, start, data)

    def report(self, data):
        """Return an empty metrics dict; this behavior reports nothing."""
        return {}


class Random(nj.Module):
    """Behavior that returns a fixed, input-independent action distribution."""

    def __init__(self, wm, act_space, config):
        """Store the config and action space; `wm` is accepted but unused."""
        self.config = config
        self.act_space = act_space

    def initial(self, batch_size):
        """Return a zero array of length `batch_size` as placeholder state."""
        return jnp.zeros(batch_size)

    def policy(self, latent, state):
        """Return ``({'action': dist}, state)``; `latent` is ignored.

        The batch size is taken from ``len(state)``.
        """
        batch_size = len(state)
        shape = (batch_size,) + self.act_space.shape
        if self.act_space.discrete:
            # All-zero logits; presumably uniform over actions -- confirm
            # against jaxutils.OneHotDist.
            dist = jaxutils.OneHotDist(jnp.zeros(shape))
        else:
            dist = tfd.Uniform(-jnp.ones(shape), jnp.ones(shape))
            # Treat the last axis as one event rather than separate batches.
            dist = tfd.Independent(dist, 1)
        return {'action': dist}, state

    def train(self, imagine, start, data):
        """Nothing to train; return no trajectory and empty metrics."""
        return None, {}

    def report(self, data):
        """Return an empty metrics dict; this behavior reports nothing."""
        return {}


class Explore(nj.Module):
    """Behavior trained on a weighted mix of reward signals.

    ``config.expl_rewards`` maps reward names to scales. Entries with a falsy
    scale are skipped. 'extr' uses the world model's reward head; any other
    name is looked up in `REWARDS`.
    """

    # Reward modules available by name (besides the built-in 'extr').
    REWARDS = {
        'disag': expl.Disag,
    }

    def __init__(self, wm, act_space, config):
        """Create one critic per enabled reward and the shared actor-critic.

        Raises:
          KeyError: If an enabled reward name is neither 'extr' nor a key of
            `REWARDS`.
        """
        self.config = config
        # Only learned reward modules are stored here; they are the ones that
        # expose a train() method.
        self.rewards = {}
        critics = {}
        for key, scale in config.expl_rewards.items():
            if not scale:
                continue
            if key == 'extr':
                rewfn = lambda s: wm.heads['reward'](s).mean()[1:]
                critics[key] = agent.VFunction(rewfn, config, name=key)
            else:
                rewfn = self.REWARDS[key](
                    wm, act_space, config, name=key + '_reward')
                critics[key] = agent.VFunction(rewfn, config, name=key)
                self.rewards[key] = rewfn
        scales = {k: v for k, v in config.expl_rewards.items() if v}
        self.ac = agent.ImagActorCritic(
            critics, scales, act_space, config, name='ac')

    def initial(self, batch_size):
        """Return the actor-critic's initial state for `batch_size`."""
        return self.ac.initial(batch_size)

    def policy(self, latent, state):
        """Delegate action selection to the actor-critic."""
        return self.ac.policy(latent, state)

    def train(self, imagine, start, data):
        """Train the reward modules, then the actor-critic.

        Returns:
          ``(traj, metrics)`` where `traj` comes from the actor-critic and
          `metrics` merges its metrics with those of every reward module,
          the latter prefixed with ``'<reward name>_'``.
        """
        metrics = {}
        for key, rewfn in self.rewards.items():
            mets = rewfn.train(data)
            # Prefix with the reward name so metrics of different reward
            # modules (and of one module) do not overwrite each other.
            metrics.update({f'{key}_{k}': v for k, v in mets.items()})
        traj, mets = self.ac.train(imagine, start, data)
        metrics.update(mets)
        return traj, metrics

    def report(self, data):
        """Return an empty metrics dict; this behavior reports nothing."""
        return {}
import envs 10 | from . import replay 11 | from . import run 12 | -------------------------------------------------------------------------------- /recall2imagine/embodied/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Agent, Env, Wrapper, Replay 2 | 3 | from .basics import convert, treemap, pack, unpack 4 | from .basics import print_ as print 5 | from .basics import format_ as format 6 | 7 | from .space import Space 8 | from .path import Path 9 | from .checkpoint import Checkpoint 10 | from .config import Config 11 | from .counter import Counter 12 | from .driver import Driver 13 | from .flags import Flags 14 | from .logger import Logger 15 | from .parallel import Parallel 16 | from .timer import Timer 17 | from .worker import Worker 18 | from .batcher import Batcher, BatcherSM 19 | from .metrics import Metrics 20 | from .uuid import uuid 21 | 22 | from .batch import BatchEnv 23 | from .random import RandomAgent 24 | from .distr import Client, Server, BatchServer 25 | 26 | from . import logger 27 | from . import when 28 | from . import wrappers 29 | from . 
class Agent:
    """Interface that agents implement; most methods raise by default."""

    configs = {}  # dict of dicts

    def __init__(self, obs_space, act_space, step, config):
        pass

    def dataset(self, generator_fn):
        """Subclasses must override; expected to return a generator function."""
        raise NotImplementedError(
            'dataset(generator_fn) -> generator_fn')

    def policy(self, obs, state=None, mode='train'):
        """Subclasses must override; expected to return ``(act, state)``."""
        raise NotImplementedError(
            "policy(obs, state=None, mode='train') -> act, state")

    def train(self, data, state=None):
        """Subclasses must override; returns ``(outs, state, metrics)``."""
        raise NotImplementedError(
            'train(data, state=None) -> outs, state, metrics')

    def report(self, data):
        """Subclasses must override; expected to return metrics."""
        raise NotImplementedError(
            'report(data) -> metrics')

    def save(self):
        """Subclasses must override; expected to return the data to store."""
        raise NotImplementedError('save() -> data')

    def load(self, data):
        """Subclasses must override; expected to restore from `data`."""
        raise NotImplementedError('load(data) -> None')

    def sync(self):
        # This method allows the agent to sync parameters from its training
        # devices to its policy devices in the case of a multi-device agent.
        pass


class Env:
    """Interface for environments; unbatched envs have length zero."""

    def __len__(self):
        return 0  # Return positive integer for batched envs.

    def __bool__(self):
        return True  # Env is always truthy, despite length zero.

    def __repr__(self):
        return (
            f'{self.__class__.__name__}('
            f'len={len(self)}, '
            f'obs_space={self.obs_space}, '
            f'act_space={self.act_space})')

    @property
    def obs_space(self):
        # The observation space must contain the keys is_first, is_last, and
        # is_terminal. Commonly, it also contains the keys reward and image.
        # By convention, keys starting with log_ are not consumed by the
        # agent.
        raise NotImplementedError('Returns: dict of spaces')

    @property
    def act_space(self):
        # The action space must contain the keys action and reset. This
        # restriction may be lifted in the future.
        raise NotImplementedError('Returns: dict of spaces')

    def step(self, action):
        """Subclasses must override; expected to return a dict."""
        raise NotImplementedError('Returns: dict')

    def render(self):
        """Subclasses must override; expected to return an array."""
        raise NotImplementedError('Returns: array')

    def close(self):
        """Release resources; a no-op by default."""
        pass


class Wrapper:
    """Forwards length, truthiness and attribute access to a wrapped env."""

    def __init__(self, env):
        self.env = env

    def __len__(self):
        return len(self.env)

    def __bool__(self):
        return bool(self.env)

    def __getattr__(self, name):
        """Look up attributes missing on the wrapper on the wrapped env.

        Raises:
          AttributeError: For dunder names, which are never forwarded.
          ValueError: If the wrapped env does not have the attribute either.
        """
        if name.startswith('__'):
            raise AttributeError(name)
        try:
            return getattr(self.env, name)
        except AttributeError as err:
            # NOTE: ValueError (not AttributeError) is the established
            # contract here, so hasattr() and getattr() with a default
            # propagate the error instead of treating the name as absent.
            # The original error is chained to keep the cause visible.
            raise ValueError(name) from err


class Replay:
    """Interface for replay buffers."""

    def __len__(self):
        raise NotImplementedError('Returns: total number of steps')

    @property
    def stats(self):
        raise NotImplementedError('Returns: metrics')

    def add(self, transition, worker=0):
        """Subclasses must override; expected to return None."""
        raise NotImplementedError('Returns: None')

    def add_traj(self, trajectory):
        """Subclasses must override; expected to return None."""
        raise NotImplementedError('Returns: None')

    def dataset(self):
        """Subclasses must override; expected to yield trajectories."""
        raise NotImplementedError('Yields: trajectory')

    def prioritize(self, keys, priorities):
        """Optional hook; a no-op by default."""
        pass

    def save(self):
        """Optional hook; a no-op by default."""
        pass

    def load(self, data):
        """Optional hook; a no-op by default."""
        pass
# Use rich for colored output when it is installed; otherwise print_ falls
# back to the builtin print.
try:
    import rich.console
    console = rich.console.Console()
except ImportError:
    console = None


# Maps a NumPy dtype category to the single dtype used for that category.
CONVERSION = {
    np.floating: np.float32,
    np.signedinteger: np.int64,
    np.uint8: np.uint8,
    bool: bool,
}


def convert(value):
    """Return `value` as an array with one of the dtypes in `CONVERSION`.

    Floating inputs become float32 and signed integers become int64. Arrays
    that already use a target dtype are returned without a cast.

    Raises:
      TypeError: If the dtype matches none of the supported categories.
    """
    value = np.asarray(value)
    if value.dtype not in CONVERSION.values():
        for src, dst in CONVERSION.items():
            if np.issubdtype(value.dtype, src):
                if value.dtype != dst:
                    value = value.astype(dst)
                break
        else:
            # The loop finished without a matching category.
            raise TypeError(
                f"Object '{value}' has unsupported dtype: {value.dtype}")
    return value


def print_(value, color=None):
    """Print `value` after formatting it with `format_`.

    `color` is only applied when the rich console is available; it is wrapped
    around the text as rich markup.
    """
    value = format_(value)
    if console:
        if color:
            value = f'[{color}]{value}[/{color}]'
        console.print(value)
    else:
        builtins.print(value)


def format_(value):
    """Return a compact, human-readable string for `value`.

    Containers are formatted recursively, objects with `shape` and `dtype`
    are shown as e.g. ``f32[2,3]``, and bytes are shortened to 32 characters.
    """
    if isinstance(value, dict):
        if value and all(
                isinstance(x, spacelib.Space) for x in value.values()):
            # A dict of spaces is shown as one aligned line per key.
            return '\n'.join(f'  {k:<16} {v}' for k, v in value.items())
        items = [f'{format_(k)}: {format_(v)}' for k, v in value.items()]
        return '{' + ', '.join(items) + '}'
    if isinstance(value, list):
        return '[' + ', '.join(f'{format_(x)}' for x in value) + ']'
    if isinstance(value, tuple):
        return '(' + ', '.join(f'{format_(x)}' for x in value) + ')'
    if hasattr(value, 'shape') and hasattr(value, 'dtype'):
        shape = ','.join(str(x) for x in value.shape)
        dtype = value.dtype.name
        # Order matters: 'uint' must be replaced before 'int'.
        for long, short in {'float': 'f', 'uint': 'u', 'int': 'i'}.items():
            dtype = dtype.replace(long, short)
        return f'{dtype}[{shape}]'
    if isinstance(value, bytes):
        # Show hex when the repr contains escaped bytes.
        value = '0x' + value.hex() if r'\x' in str(value) else str(value)
        if len(value) > 32:
            value = value[:32 - 3] + '...'
    return str(value)


def treemap(fn, *trees, isleaf=None):
    """Apply `fn` to corresponding leaves of one or more nested structures.

    Lists, tuples and dicts are traversed recursively and rebuilt with the
    same structure; everything else is a leaf.

    Args:
      fn: Called with one leaf from each tree.
      *trees: Structures of identical type and layout.
      isleaf: Optional predicate that receives the tuple of current subtrees;
        when it returns true, they are passed to `fn` without descending.
    """
    assert trees, 'Provide one or more nested Python structures'
    kw = dict(isleaf=isleaf)
    first = trees[0]
    assert all(isinstance(x, type(first)) for x in trees)
    if isleaf and isleaf(trees):
        return fn(*trees)
    if isinstance(first, list):
        assert all(len(x) == len(first) for x in trees), format_(trees)
        return [treemap(
            fn, *[t[i] for t in trees], **kw) for i in range(len(first))]
    if isinstance(first, tuple):
        assert all(len(x) == len(first) for x in trees), format_(trees)
        return tuple([treemap(
            fn, *[t[i] for t in trees], **kw) for i in range(len(first))])
    if isinstance(first, dict):
        assert all(set(x.keys()) == set(first.keys()) for x in trees), (
            format_(trees))
        return {k: treemap(fn, *[t[k] for t in trees], **kw) for k in first}
    return fn(*trees)


def pack(data):
    """Serialize `data` to bytes with pickle."""
    return pickle.dumps(data)


def unpack(buffer):
    """Deserialize bytes produced by `pack`.

    SECURITY: pickle.loads can execute arbitrary code. Only call this on
    buffers from a trusted source.
    """
    return pickle.loads(buffer)
class Batcher:
    """Stack elements drawn from several sources into batches.

    With `workers == 0` each batch is assembled synchronously inside
    `__next__`. Otherwise creator threads keep one bounded queue per source
    filled, and a single batcher thread stacks one element from every queue
    into a batch, optionally post-processes it and places it on an output
    queue that `__next__` reads from.
    """

    def __init__(
            self, sources, workers=0, postprocess=None,
            prefetch_source=4, prefetch_batch=2):
        """Set up the iterators or start the background threads.

        Args:
            sources: Sequence of zero-argument callables, each returning an
                iterator that yields mappings from key to array-like value.
            workers: Number of creator threads; 0 selects synchronous mode.
            postprocess: Optional callable applied to each stacked batch. It
                is only invoked by the background batcher thread, i.e. when
                `workers` is non-zero.
            prefetch_source: Capacity of each per-source queue.
            prefetch_batch: Capacity of the queue of finished batches.
        """
        self._workers = workers
        self._postprocess = postprocess
        if workers:
            # Round-robin assign sources to workers.
            self._running = True
            self._threads = []
            self._queues = []
            assignments = [([], []) for _ in range(workers)]
            for index, source in enumerate(sources):
                queue = queuelib.Queue(prefetch_source)
                self._queues.append(queue)
                assignments[index % workers][0].append(source)
                assignments[index % workers][1].append(queue)
            for args in assignments:
                creator = threading.Thread(
                    target=self._creator, args=args, daemon=True)
                creator.start()
                self._threads.append(creator)
            self._batches = queuelib.Queue(prefetch_batch)
            batcher = threading.Thread(
                target=self._batcher, args=(self._queues, self._batches),
                daemon=True)
            batcher.start()
            self._threads.append(batcher)
        else:
            self._iterators = [source() for source in sources]
        self._once = False

    def close(self, timeout=1.0):
        """Ask the background threads to stop and wait briefly for them.

        `threading.Thread` has no `close()` method, so the threads are
        signalled through `_running` and then joined. The join is bounded
        because a thread may be blocked on a queue operation; all threads are
        daemons, so one that does not finish in time cannot keep the
        interpreter alive.

        Args:
            timeout: Maximum number of seconds to wait for each thread.
        """
        if self._workers:
            self._running = False
            for thread in self._threads:
                thread.join(timeout)

    def __iter__(self):
        if self._once:
            raise RuntimeError(
                'You can only create one iterator per Batcher object to ensure that '
                'data is consumed in order. Create another Batcher object instead.')
        self._once = True
        return self

    def __call__(self):
        return self.__iter__()

    def __next__(self):
        """Return the next batch, re-raising errors forwarded by the threads."""
        if self._workers:
            batch = self._batches.get()
        else:
            elems = [next(x) for x in self._iterators]
            batch = {k: np.stack([x[k] for x in elems], 0) for k in elems[0]}
        # Background threads hand their exceptions over through the queue.
        if isinstance(batch, Exception):
            raise batch
        return batch

    def _creator(self, sources, outputs):
        """Thread body: keep every assigned queue topped up from its source."""
        try:
            iterators = [source() for source in sources]
            while self._running:
                waiting = True
                for iterator, queue in zip(iterators, outputs):
                    if queue.full():
                        continue
                    queue.put(next(iterator))
                    waiting = False
                if waiting:
                    # All queues are full; back off instead of spinning.
                    time.sleep(0.001)
        except Exception as e:
            e.stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
            # Forward the error so the batcher thread can re-raise it.
            outputs[0].put(e)
            raise

    def _batcher(self, sources, output):
        """Thread body: stack one element per source queue into a batch."""
        try:
            while self._running:
                elems = [x.get() for x in sources]
                for elem in elems:
                    if isinstance(elem, Exception):
                        raise elem
                batch = {k: np.stack([x[k] for x in elems], 0) for k in elems[0]}
                if self._postprocess:
                    batch = self._postprocess(batch)
                output.put(batch)  # Will wait here if the queue is full.
        except Exception as e:
            e.stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
            # Forward the error to the consumer in __next__.
            output.put(e)
            raise
111 | """ 112 | def __init__( 113 | self, replay, workers=0, batch_size=1, 114 | batch_sequence_len=1, postprocess=None, 115 | prefetch_source=4, prefetch_batch=2): 116 | self._workers = workers 117 | self._postprocess = postprocess 118 | # we assume replay was filled with at least one step so the serializer is 119 | # initialized already 120 | self._replay = replay 121 | self.prefetch_batch = prefetch_batch 122 | self.batch_size = batch_size 123 | self._serializer = self._replay.serializer 124 | self._batch_buffers = None 125 | self.batch_sequence_len = batch_sequence_len 126 | self.batch_size = batch_size 127 | 128 | if workers: 129 | # Round-robin assign sources to workers. 130 | self._running = True 131 | self._threads = [] 132 | self._queues = [] 133 | self.tasks = queuelib.Queue(prefetch_batch * batch_size) 134 | self.reports = queuelib.Queue(prefetch_batch * batch_size) 135 | for _ in range(workers): 136 | creator = threading.Thread( 137 | target=self._creator, args=(), daemon=True 138 | ) 139 | creator.start() 140 | self._threads.append(creator) 141 | self._outputs = queuelib.Queue(prefetch_batch) 142 | batcher = threading.Thread( 143 | target=self._batcher, args=(), 144 | daemon=True) 145 | batcher.start() 146 | self._threads.append(batcher) 147 | self._once = False 148 | 149 | def close(self): 150 | if self._workers: 151 | self._running = False 152 | for thread in self._threads: 153 | thread.close() 154 | 155 | def __iter__(self): 156 | if self._once: 157 | raise RuntimeError( 158 | 'You can only create one iterator per Batcher object to ensure that ' 159 | 'data is consumed in order. 
Create another Batcher object instead.') 160 | self._once = True 161 | return self 162 | 163 | def __call__(self): 164 | return self.__iter__() 165 | 166 | def __next__(self): 167 | if self._workers: 168 | batch = self._outputs.get() 169 | else: 170 | elems = [next(x) for x in self._iterators] 171 | batch = {k: np.stack([x[k] for x in elems], 0) for k in elems[0]} 172 | if isinstance(batch, Exception): 173 | raise batch 174 | return batch 175 | 176 | def _creator(self): 177 | try: 178 | while self._running: 179 | if not self.tasks.empty() and self._replay.ready: 180 | if self._batch_buffers is None: 181 | self._serializer = self._replay.serializer 182 | self._batch_buffers = self._serializer.batch_buffer( 183 | self.prefetch_batch, self.batch_size, self.batch_sequence_len) 184 | flip, task_id = self.tasks.get() 185 | success = self._replay.sample(flip, task_id, self._batch_buffers) 186 | assert success 187 | self.reports.put((flip, task_id)) 188 | else: 189 | time.sleep(0.001) 190 | except Exception as e: 191 | e.stacktrace = ''.join(traceback.format_exception(*sys.exc_info())) 192 | raise 193 | 194 | def read_buffer(self, flip): 195 | # copy here is necessary as right after this read we'll 196 | # assign another set of data loading tasks which will 197 | # overwrite the batch buffers 198 | return {k: v[flip].copy() for k, v in self._batch_buffers.items()} 199 | 200 | def _batcher(self, ): 201 | if self.tasks.empty(): 202 | for flip in range(self.prefetch_batch): 203 | for batch_id in range(self.batch_size): 204 | self.tasks.put((flip, batch_id)) 205 | batch_completion = [[False for _ in range(self.batch_size)] 206 | for _ in range(self.prefetch_batch)] 207 | try: 208 | while self._running: 209 | while not self.reports.empty(): 210 | flip, task_id = self.reports.get() 211 | batch_completion[flip][task_id] = True 212 | completed = [flip for flip in range(self.prefetch_batch) if all(batch_completion[flip])] 213 | for flip in completed: 214 | batch = 
class Checkpoint:
    """Collect objects with `save()`/`load()` and persist them as one file.

    Entries are registered by plain attribute assignment
    (`checkpoint.agent = agent`) and can be read back the same way. Names
    starting with an underscore, and the method names `exists`, `save` and
    `load`, are stored as ordinary attributes instead.
    """

    def __init__(self, filename=None, log=True, parallel=True):
        """Create an empty checkpoint.

        Args:
            filename: Default location used when a method is called without
                an explicit filename. May be None.
            log: Whether to print progress messages.
            parallel: If true, `save()` writes on a single background thread.
        """
        self._filename = filename and path.Path(filename)
        self._log = log
        self._values = {}
        self._parallel = parallel
        if self._parallel:
            self._worker = concurrent.futures.ThreadPoolExecutor(1)
            self._promise = None

    def __setattr__(self, name, value):
        if name in ('exists', 'save', 'load'):
            return super().__setattr__(name, value)
        if name.startswith('_'):
            return super().__setattr__(name, value)
        has_load = hasattr(value, 'load') and callable(value.load)
        has_save = hasattr(value, 'save') and callable(value.save)
        if not (has_load and has_save):
            message = f"Checkpoint entry '{name}' must implement save() and load()."
            raise ValueError(message)
        self._values[name] = value

    def __getattr__(self, name):
        """Return the registered entry called `name`.

        Raises:
            AttributeError: For underscore-prefixed names, so that internal
                attributes that are not set yet behave normally.
            ValueError: If no entry with that name has been registered.
        """
        if name.startswith('_'):
            raise AttributeError(name)
        try:
            # Entries live in the dict as keys, not as attributes of the dict.
            return self._values[name]
        except KeyError:
            raise ValueError(name)

    def exists(self, filename=None):
        """Report whether the checkpoint file is present.

        Args:
            filename: Location to check; defaults to the constructor filename.
        """
        assert self._filename or filename
        filename = path.Path(filename or self._filename)
        exists = filename.exists()
        self._log and exists and print('Found existing checkpoint.')
        self._log and not exists and print('Did not find any checkpoint.')
        return exists

    def save(self, filename=None, keys=None):
        """Write the selected entries, in the background if `parallel` is set.

        Args:
            filename: Target location; defaults to the constructor filename.
            keys: Names of the entries to write; all entries when None.
        """
        assert self._filename or filename
        filename = path.Path(filename or self._filename)
        self._log and print(f'Writing checkpoint: {filename}')
        if self._parallel:
            # Wait for the previous write so that two writes never overlap.
            self._promise and self._promise.result()
            self._promise = self._worker.submit(self._save, filename, keys)
        else:
            self._save(filename, keys)

    def _save(self, filename, keys):
        keys = tuple(self._values.keys() if keys is None else keys)
        assert all([not k.startswith('_') for k in keys]), keys
        data = {k: self._values[k].save() for k in keys}
        data['_timestamp'] = time.time()
        if filename.exists():
            # Keep a copy of the previous file until the new one is written.
            old = filename.parent / (filename.name + '.old')
            filename.copy(old)
            filename.write(basics.pack(data), mode='wb')
            old.remove()
        else:
            filename.write(basics.pack(data), mode='wb')
        self._log and print(f'Wrote checkpoint: {filename}')

    def load(self, filename=None, keys=None):
        """Restore entries from a checkpoint file.

        Args:
            filename: Location to read; defaults to the constructor filename.
            keys: Names of the entries to restore; everything stored in the
                file when None. Underscore-prefixed keys are skipped.

        Returns:
            The conjunction of the values returned by the entries' `load()`
            calls, where a `None` return counts as success.
        """
        assert self._filename or filename
        filename = path.Path(filename or self._filename)
        self._log and print(f'Loading checkpoint: {filename}')
        # NOTE(review): basics.unpack appears to be pickle-based, so only
        # checkpoints from trusted sources should be loaded -- confirm.
        data = basics.unpack(filename.read('rb'))
        keys = tuple(data.keys() if keys is None else keys)
        all_loaded = True
        for key in keys:
            if key.startswith('_'):
                continue
            try:
                loaded = self._values[key].load(data[key])
                if loaded is None:
                    loaded = True
                all_loaded = loaded and all_loaded
            except Exception:
                print(f'Error loading {key} from checkpoint.')
                raise
        if self._log:
            age = time.time() - data['_timestamp']
            print(f'Loaded checkpoint from {age:.0f} seconds ago.')
        return all_loaded

    def load_or_save(self, filename=None):
        """Load the checkpoint if it exists, otherwise write a fresh one.

        A checkpoint is also written when the file exists but `load()`
        reports that not every entry could be restored.

        Args:
            filename: Location used for the existence check, the load and the
                save; defaults to the constructor filename.
        """
        if self.exists(filename=filename):
            loaded = self.load(filename)
            if loaded is not None and not loaded:
                self.save(filename)
        else:
            self.save(filename)
22 | super().__init__(self._nested) 23 | 24 | @property 25 | def flat(self): 26 | return self._flat.copy() 27 | 28 | def save(self, filename): 29 | filename = path.Path(filename) 30 | if filename.suffix == '.json': 31 | filename.write(json.dumps(dict(self))) 32 | elif filename.suffix in ('.yml', '.yaml'): 33 | import ruamel.yaml as yaml 34 | with io.StringIO() as stream: 35 | yaml.safe_dump(dict(self), stream) 36 | filename.write(stream.getvalue()) 37 | else: 38 | raise NotImplementedError(filename.suffix) 39 | 40 | @classmethod 41 | def load(cls, filename): 42 | filename = path.Path(filename) 43 | if filename.suffix == '.json': 44 | return cls(json.loads(filename.read_text())) 45 | elif filename.suffix in ('.yml', '.yaml'): 46 | import ruamel.yaml as yaml 47 | return cls(yaml.safe_load(filename.read_text())) 48 | else: 49 | raise NotImplementedError(filename.suffix) 50 | 51 | def __contains__(self, name): 52 | try: 53 | self[name] 54 | return True 55 | except KeyError: 56 | return False 57 | 58 | def __getattr__(self, name): 59 | if name.startswith('_'): 60 | return super().__getattr__(name) 61 | try: 62 | return self[name] 63 | except KeyError: 64 | raise AttributeError(name) 65 | 66 | def __getitem__(self, name): 67 | result = self._nested 68 | for part in name.split(self.SEP): 69 | try: 70 | result = result[part] 71 | except TypeError: 72 | raise KeyError 73 | if isinstance(result, dict): 74 | result = type(self)(result) 75 | return result 76 | 77 | def __setattr__(self, key, value): 78 | if key.startswith('_'): 79 | return super().__setattr__(key, value) 80 | message = f"Tried to set key '{key}' on immutable config. Use update()." 81 | raise AttributeError(message) 82 | 83 | def __setitem__(self, key, value): 84 | if key.startswith('_'): 85 | return super().__setitem__(key, value) 86 | message = f"Tried to set key '{key}' on immutable config. Use update()." 
87 | raise AttributeError(message) 88 | 89 | def __reduce__(self): 90 | return (type(self), (dict(self),)) 91 | 92 | def __str__(self): 93 | lines = ['\nConfig:'] 94 | keys, vals, typs = [], [], [] 95 | for key, val in self.flat.items(): 96 | keys.append(key + ':') 97 | vals.append(self._format_value(val)) 98 | typs.append(self._format_type(val)) 99 | max_key = max(len(k) for k in keys) if keys else 0 100 | max_val = max(len(v) for v in vals) if vals else 0 101 | for key, val, typ in zip(keys, vals, typs): 102 | key = key.ljust(max_key) 103 | val = val.ljust(max_val) 104 | lines.append(f'{key} {val} ({typ})') 105 | return '\n'.join(lines) 106 | 107 | def update(self, *args, **kwargs): 108 | result = self._flat.copy() 109 | inputs = self._flatten(dict(*args, **kwargs)) 110 | for key, new in inputs.items(): 111 | if self.IS_PATTERN.match(key): 112 | pattern = re.compile(key) 113 | keys = {k for k in result if pattern.match(k)} 114 | else: 115 | keys = [key] 116 | if not keys: 117 | raise KeyError(f'Unknown key or pattern {key}.') 118 | for key in keys: 119 | old = result[key] 120 | try: 121 | if isinstance(old, int) and isinstance(new, float): 122 | if float(int(new)) != new: 123 | message = f"Cannot convert fractional float {new} to int." 
124 | raise ValueError(message) 125 | result[key] = type(old)(new) 126 | except (ValueError, TypeError): 127 | raise TypeError( 128 | f"Cannot convert '{new}' to type '{type(old).__name__}' " + 129 | f"for key '{key}' with previous value '{old}'.") 130 | return type(self)(result) 131 | 132 | def _flatten(self, mapping): 133 | result = {} 134 | for key, value in mapping.items(): 135 | if isinstance(value, dict): 136 | for k, v in self._flatten(value).items(): 137 | if self.IS_PATTERN.match(key) or self.IS_PATTERN.match(k): 138 | combined = f'{key}\\{self.SEP}{k}' 139 | else: 140 | combined = f'{key}{self.SEP}{k}' 141 | result[combined] = v 142 | else: 143 | result[key] = value 144 | return result 145 | 146 | def _nest(self, mapping): 147 | result = {} 148 | for key, value in mapping.items(): 149 | parts = key.split(self.SEP) 150 | node = result 151 | for part in parts[:-1]: 152 | if part not in node: 153 | node[part] = {} 154 | node = node[part] 155 | node[parts[-1]] = value 156 | return result 157 | 158 | def _ensure_keys(self, mapping): 159 | for key in mapping: 160 | assert not self.IS_PATTERN.match(key), key 161 | return mapping 162 | 163 | def _ensure_values(self, mapping): 164 | result = json.loads(json.dumps(mapping)) 165 | for key, value in result.items(): 166 | if isinstance(value, list): 167 | value = tuple(value) 168 | if isinstance(value, tuple): 169 | if len(value) == 0: 170 | message = 'Empty lists are disallowed because their type is unclear.' 171 | raise TypeError(message) 172 | if not isinstance(value[0], (str, float, int, bool)): 173 | message = 'Lists can only contain strings, floats, ints, bools' 174 | message += f' but not {type(value[0])}' 175 | raise TypeError(message) 176 | if not all(isinstance(x, type(value[0])) for x in value[1:]): 177 | message = 'Elements of a list must all be of the same type.' 
@functools.total_ordering
class Counter:
    """Mutable integer counter that compares and does arithmetic like an int.

    Comparison and arithmetic operators work on `int(self)` and return plain
    values rather than new counters. `save()` and `load()` expose the raw
    value so the counter can be stored in a checkpoint.
    """

    def __init__(self, initial=0):
        self.value = initial

    def __repr__(self):
        return f'Counter({self.value})'

    def __int__(self):
        return int(self.value)

    def __eq__(self, other):
        return int(self) == other

    def __ne__(self, other):
        return int(self) != other

    def __lt__(self, other):
        return int(self) < other

    def __add__(self, other):
        return int(self) + other

    def __radd__(self, other):
        # Reflected addition is commutative: `other + counter`.
        return other + int(self)

    def __sub__(self, other):
        return int(self) - other

    def __rsub__(self, other):
        # Reflected subtraction keeps the operand order: `other - counter`.
        return other - int(self)

    def increment(self, amount=1):
        """Advance the counter in place by `amount`."""
        self.value += amount

    def save(self):
        """Return the current value for serialization."""
        return self.value

    def load(self, value):
        """Replace the current value with a previously saved one."""
        self.value = value
class Server:
    """Request/reply server that answers each message with `function(inputs)`."""

    def __init__(self, address, function, ipv6=False):
        """Bind a REP socket.

        Args:
            address: ZeroMQ endpoint to bind to.
            function: Callable that takes the unpacked request dict and
                returns a dict.
            ipv6: Whether to enable IPv6 on the socket before binding.
        """
        import zmq
        context = zmq.Context.instance()
        self.socket = context.socket(zmq.REP)
        basics.print_(f'Server listening at {address}', color='green')
        ipv6 and self.socket.setsockopt(zmq.IPV6, 1)
        self.socket.bind(address)
        self.function = function

    def run(self):
        """Serve requests forever.

        If the handler raises, or returns something other than a dict, an
        error dict of the form `{'type': 'error', 'message': ...}` is sent
        back before the exception is re-raised, which ends the loop.
        """
        while True:
            payload = self.socket.recv()
            inputs = basics.unpack(payload)
            assert isinstance(inputs, dict), type(inputs)
            try:
                result = self.function(inputs)
                assert isinstance(result, dict), type(result)
            except Exception as e:
                result = {'type': 'error', 'message': str(e)}
                # Reply with the error dict, not the raw request bytes, so
                # the peer can read the 'type' and 'message' fields.
                self.socket.send(basics.pack(result))
                raise
            payload = basics.pack(result)
            self.socket.send(payload)
*args): 133 | try: 134 | self.fn(*args) 135 | except Exception: 136 | with self.lock: 137 | print('-' * 79) 138 | print(f'Exception in worker: {self.name}') 139 | print('-' * 79) 140 | print(''.join(traceback.format_exception(*sys.exc_info()))) 141 | self.exitcode = 1 142 | raise 143 | self.exitcode = 0 144 | 145 | def terminate(self): 146 | if not self.is_alive(): 147 | return 148 | if hasattr(self, '_thread_id'): 149 | thread_id = self._thread_id 150 | else: 151 | thread_id = [k for k, v in threading._active.items() if v is self][0] 152 | result = ctypes.pythonapi.PyThreadState_SetAsyncExc( 153 | ctypes.c_long(thread_id), ctypes.py_object(SystemExit)) 154 | if result > 1: 155 | ctypes.pythonapi.PyThreadState_SetAsyncExc( 156 | ctypes.c_long(thread_id), None) 157 | print('Shut down worker:', self.name) 158 | 159 | 160 | class Process: 161 | 162 | lock = None 163 | initializers = [] 164 | 165 | def __init__(self, fn, *args, name=None): 166 | import multiprocessing 167 | import cloudpickle 168 | mp = multiprocessing.get_context('spawn') 169 | if Process.lock is None: 170 | Process.lock = mp.Lock() 171 | name = name or fn.__name__ 172 | initializers = cloudpickle.dumps(self.initializers) 173 | args = (initializers,) + args 174 | self._process = mp.Process( 175 | target=self._wrapper, args=(Process.lock, fn, *args), 176 | name=name) 177 | 178 | def start(self): 179 | self._process.start() 180 | 181 | @property 182 | def name(self): 183 | return self._process.name 184 | 185 | @property 186 | def exitcode(self): 187 | return self._process.exitcode 188 | 189 | def terminate(self): 190 | self._process.terminate() 191 | print('Shut down worker:', self.name) 192 | 193 | def _wrapper(self, lock, fn, *args): 194 | try: 195 | import cloudpickle 196 | initializers, *args = args 197 | for initializer in cloudpickle.loads(initializers): 198 | initializer() 199 | fn(*args) 200 | except Exception: 201 | with lock: 202 | print('-' * 79) 203 | print(f'Exception in worker: 
def run(workers):
    """Start all workers and block until they finish or one of them crashes.

    Args:
        workers: Objects exposing `start()`, `terminate()`, `name` and
            `exitcode`, where `exitcode` is None while running and 0 on a
            clean exit.

    Raises:
        RuntimeError: When a worker reports an exit code other than None or
            0. All other workers are terminated before raising.
    """
    for worker in workers:
        worker.start()
    while True:
        if all(w.exitcode == 0 for w in workers):
            print('All workers terminated successfully.')
            return
        for crashed in workers:
            if crashed.exitcode in (None, 0):
                continue
            # Wait for everybody who wants to print their error messages.
            time.sleep(1)
            for other in workers:
                if other is not crashed:
                    other.terminate()
            raise RuntimeError(f'Stopped workers due to crash in {crashed.name}.')
        time.sleep(0.1)
assert all(len(x) == len(self._env) for x in self._acts.values()) 46 | acts = {k: v for k, v in self._acts.items() if not k.startswith('log_')} 47 | obs = self._env.step(acts) 48 | obs = {k: convert(v) for k, v in obs.items()} 49 | assert all(len(x) == len(self._env) for x in obs.values()), obs 50 | acts, self._state = policy(obs, self._state, **self._kwargs) 51 | acts = {k: convert(v) for k, v in acts.items()} 52 | if obs['is_last'].any(): 53 | mask = 1 - obs['is_last'] 54 | acts = {k: v * self._expand(mask, len(v.shape)) for k, v in acts.items()} 55 | acts['reset'] = obs['is_last'].copy() 56 | self._acts = acts 57 | trns = {**obs, **acts} 58 | if obs['is_first'].any(): 59 | for i, first in enumerate(obs['is_first']): 60 | if first: 61 | self._eps[i].clear() 62 | for i in range(len(self._env)): 63 | trn = {k: v[i] for k, v in trns.items()} 64 | [self._eps[i][k].append(v) for k, v in trn.items()] 65 | [fn(trn, i, **self._kwargs) for fn in self._on_steps] 66 | step += 1 67 | if obs['is_last'].any(): 68 | for i, done in enumerate(obs['is_last']): 69 | if done: 70 | ep = {k: convert(v) for k, v in self._eps[i].items()} 71 | [fn(ep.copy(), i, **self._kwargs) for fn in self._on_episodes] 72 | episode += 1 73 | return step, episode 74 | 75 | def _expand(self, value, dims): 76 | while len(value.shape) < dims: 77 | value = value[..., None] 78 | return value 79 | -------------------------------------------------------------------------------- /recall2imagine/embodied/core/flags.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | from . 
import config 5 | 6 | 7 | class Flags: 8 | 9 | def __init__(self, *args, **kwargs): 10 | self._config = config.Config(*args, **kwargs) 11 | 12 | def parse(self, argv=None, help_exists=True): 13 | parsed, remaining = self.parse_known(argv) 14 | for flag in remaining: 15 | if flag.startswith('--'): 16 | raise ValueError(f"Flag '{flag}' did not match any config keys.") 17 | assert not remaining, remaining 18 | return parsed 19 | 20 | def parse_known(self, argv=None, help_exists=False): 21 | if argv is None: 22 | argv = sys.argv[1:] 23 | if '--help' in argv: 24 | print('\nHelp:') 25 | lines = str(self._config).split('\n')[2:] 26 | print('\n'.join('--' + re.sub(r'[:,\[\]]', '', x) for x in lines)) 27 | help_exists and sys.exit() 28 | parsed = {} 29 | remaining = [] 30 | key = None 31 | vals = None 32 | for arg in argv: 33 | if arg.startswith('--'): 34 | if key: 35 | self._submit_entry(key, vals, parsed, remaining) 36 | if '=' in arg: 37 | key, val = arg.split('=', 1) 38 | vals = [val] 39 | else: 40 | key, vals = arg, [] 41 | else: 42 | if key: 43 | vals.append(arg) 44 | else: 45 | remaining.append(arg) 46 | self._submit_entry(key, vals, parsed, remaining) 47 | parsed = self._config.update(parsed) 48 | return parsed, remaining 49 | 50 | def _submit_entry(self, key, vals, parsed, remaining): 51 | if not key and not vals: 52 | return 53 | if not key: 54 | vals = ', '.join(f"'{x}'" for x in vals) 55 | raise ValueError(f"Values {vals} were not preceded by any flag.") 56 | name = key[len('--'):] 57 | if '=' in name: 58 | remaining.extend([key] + vals) 59 | return 60 | if self._config.IS_PATTERN.fullmatch(name): 61 | pattern = re.compile(name) 62 | keys = {k for k in self._config.flat if pattern.fullmatch(k)} 63 | elif name in self._config: 64 | keys = [name] 65 | else: 66 | keys = [] 67 | if not keys: 68 | remaining.extend([key] + vals) 69 | return 70 | if not vals: 71 | raise ValueError(f"Flag '{key}' was not followed by any values.") 72 | for key in keys: 73 | parsed[key] = 
self._parse_flag_value(self._config[key], vals, key) 74 | 75 | def _parse_flag_value(self, default, value, key): 76 | value = value if isinstance(value, (tuple, list)) else (value,) 77 | if isinstance(default, (tuple, list)): 78 | if len(value) == 1 and ',' in value[0]: 79 | value = value[0].split(',') 80 | return tuple(self._parse_flag_value(default[0], [x], key) for x in value) 81 | assert len(value) == 1, value 82 | value = str(value[0]) 83 | if default is None: 84 | return value 85 | if isinstance(default, bool): 86 | try: 87 | return bool(['False', 'True'].index(value)) 88 | except ValueError: 89 | message = f"Expected bool but got '{value}' for key '{key}'." 90 | raise TypeError(message) 91 | if isinstance(default, int): 92 | try: 93 | value = float(value) # Allow scientific notation for integers. 94 | assert float(int(value)) == value 95 | except (TypeError, AssertionError): 96 | message = f"Expected int but got float '{value}' for key '{key}'." 97 | raise TypeError(message) 98 | return int(value) 99 | if isinstance(default, dict): 100 | raise TypeError( 101 | f"Key '{key}' refers to a whole dict. 
class Metrics:
    """Accumulate scalar metrics and keep the latest non-scalar values.

    Scalars are collected in per-key lists and averaged by `result()`.
    Images, videos and other values with a non-empty shape are stored as the
    most recent value per key.
    """

    def __init__(self):
        self._scalars = collections.defaultdict(list)
        self._lasts = {}

    def scalar(self, key, value):
        """Record one scalar observation under `key`."""
        self._scalars[key].append(value)

    def image(self, key, value):
        """Remember `value` as the latest image for `key`."""
        # `_lasts` is a plain dict holding one value per key, the same way
        # `add()` fills it, so assign rather than append.
        self._lasts[key] = value

    def video(self, key, value):
        """Remember `value` as the latest video for `key`."""
        self._lasts[key] = value

    def add(self, mapping, prefix=None):
        """Record every entry of `mapping`, routing by whether it has a shape.

        Args:
            mapping: Mapping from key to value. Values with a non-empty
                `shape` are kept as latest values, everything else is
                recorded as a scalar.
            prefix: Optional string prepended to every key as `prefix/key`.
        """
        for key, value in mapping.items():
            key = prefix + '/' + key if prefix else key
            if hasattr(value, 'shape') and len(value.shape) > 0:
                self._lasts[key] = value
            else:
                self._scalars[key].append(value)

    def result(self, reset=True):
        """Return the latest values plus the NaN-ignoring mean of each scalar.

        Args:
            reset: Whether to clear all recorded values afterwards.
        """
        result = {}
        result.update(self._lasts)
        with warnings.catch_warnings():  # Ignore empty slice warnings.
            warnings.simplefilter('ignore', category=RuntimeWarning)
            for key, values in self._scalars.items():
                result[key] = np.nanmean(values, dtype=np.float64)
        reset and self.reset()
        return result

    def reset(self):
        """Forget all recorded scalars and latest values."""
        self._scalars.clear()
        self._lasts.clear()
import contextlib
import glob
import os
import re
import shutil


class Path:
  """Filesystem-agnostic path with a small pathlib-like interface.

  Constructing `Path(...)` dispatches to the first implementation in
  `Path.filesystems` whose predicate accepts the path string. This base class
  holds the pure string manipulation; subclasses implement the I/O methods.
  """

  filesystems = []

  def __new__(cls, path):
    path = str(path)
    for impl, pred in cls.filesystems:
      if pred(path):
        obj = super().__new__(impl)
        obj.__init__(path)
        return obj
    raise NotImplementedError(f'No filesystem supports: {path}')

  def __getnewargs__(self):
    # Lets pickle re-create the object through `__new__` with the path string.
    return (self._path,)

  def __init__(self, path):
    assert isinstance(path, str)
    path = re.sub(r'^\./*', '', path)  # Remove leading dot or dot slashes.
    path = re.sub(r'(?<=[^/])/$', '', path)  # Remove single trailing slash.
    path = path or '.'  # Empty path is represented by a dot.
    self._path = path

  def __truediv__(self, part):
    sep = '' if self._path.endswith('/') else '/'
    return type(self)(f'{self._path}{sep}{str(part)}')

  def __repr__(self):
    return f'Path({str(self)})'

  def __fspath__(self):
    return str(self)

  def __eq__(self, other):
    # Defer to the other operand for foreign types instead of failing with
    # AttributeError on the missing `_path` attribute.
    if not isinstance(other, Path):
      return NotImplemented
    return self._path == other._path

  def __hash__(self):
    # Defining `__eq__` alone makes a class unhashable; hash consistently with
    # equality so paths can be used in sets and as dict keys.
    return hash(self._path)

  def __lt__(self, other):
    if not isinstance(other, Path):
      return NotImplemented
    return self._path < other._path

  def __str__(self):
    return self._path

  @property
  def parent(self):
    """Path of the containing directory ('.' or '/' at the top)."""
    if '/' not in self._path:
      return type(self)('.')
    parent = self._path.rsplit('/', 1)[0]
    parent = parent or ('/' if self._path.startswith('/') else '.')
    return type(self)(parent)

  @property
  def name(self):
    """Final path component."""
    if '/' not in self._path:
      return self._path
    return self._path.rsplit('/', 1)[1]

  @property
  def stem(self):
    """Name up to the first dot, so 'a.tar.gz' has stem 'a'."""
    return self.name.split('.', 1)[0] if '.' in self.name else self.name

  @property
  def suffix(self):
    """Everything from the first dot of the name, or '' without a dot."""
    return ('.' + self.name.split('.', 1)[1]) if '.' in self.name else ''

  def read(self, mode='r'):
    """Return the whole file content; `mode` is 'r' or 'rb'."""
    assert mode in 'r rb'.split(), mode
    with self.open(mode) as f:
      return f.read()

  def write(self, content, mode='w'):
    """Write `content` to the file; `mode` is one of 'w', 'a', 'wb', 'ab'."""
    assert mode in 'w a wb ab'.split(), mode
    with self.open(mode) as f:
      f.write(content)

  @contextlib.contextmanager
  def open(self, mode='r'):
    raise NotImplementedError

  def absolute(self):
    raise NotImplementedError

  def glob(self, pattern):
    raise NotImplementedError

  def exists(self):
    raise NotImplementedError

  def isfile(self):
    raise NotImplementedError

  def isdir(self):
    raise NotImplementedError

  def mkdirs(self):
    raise NotImplementedError

  def remove(self):
    raise NotImplementedError

  def rmtree(self):
    raise NotImplementedError

  def copy(self, dest):
    raise NotImplementedError

  def move(self, dest):
    """Default move: copy to `dest`, then remove the source."""
    self.copy(dest)
    self.remove()


class LocalPath(Path):
  """Path on the local filesystem, backed by `os`, `glob` and `shutil`."""

  def __init__(self, path):
    super().__init__(os.path.expanduser(str(path)))

  @contextlib.contextmanager
  def open(self, mode='r'):
    with open(str(self), mode=mode) as f:
      yield f

  def absolute(self):
    """Return the absolute version of this path."""
    # `os.path` has no `absolute` function; `abspath` is the correct name.
    return type(self)(os.path.abspath(str(self)))

  def glob(self, pattern):
    for path in glob.glob(f'{str(self)}/{pattern}'):
      yield type(self)(path)

  def exists(self):
    return os.path.exists(str(self))

  def isfile(self):
    return os.path.isfile(str(self))

  def isdir(self):
    return os.path.isdir(str(self))

  def mkdirs(self):
    os.makedirs(str(self), exist_ok=True)

  def remove(self):
    """Remove a file, or a directory via `os.rmdir`."""
    os.rmdir(str(self)) if self.isdir() else os.remove(str(self))

  def rmtree(self):
    shutil.rmtree(self)

  def copy(self, dest):
    if self.isfile():
      shutil.copy(self, type(self)(dest))
    else:
      shutil.copytree(self, type(self)(dest), dirs_exist_ok=True)

  def move(self, dest):
    shutil.move(self, dest)


class GFilePath(Path):
  """Path backed by `tf.io.gfile`; TensorFlow is imported on construction."""

  def __init__(self, path):
    path = str(path)
    if not (path.startswith('/') or '://' in path):
      path = os.path.abspath(os.path.expanduser(path))
    super().__init__(path)
    import tensorflow as tf
    self._gfile = tf.io.gfile

  @contextlib.contextmanager
  def open(self, mode='r'):
    path = str(self)
    if 'a' in mode and path.startswith('/cns/'):
      path += '%r=3.2'
    # Exclusive creation is emulated: check for existence, then open for
    # writing.
    if mode.startswith('x') and self.exists():
      raise FileExistsError(path)
    mode = mode.replace('x', 'w')
    with self._gfile.GFile(path, mode) as f:
      yield f

  def absolute(self):
    return self

  def glob(self, pattern):
    for path in self._gfile.glob(f'{str(self)}/{pattern}'):
      yield type(self)(path)

  def exists(self):
    return self._gfile.exists(str(self))

  def isfile(self):
    return self.exists() and not self.isdir()

  def isdir(self):
    return self._gfile.isdir(str(self))

  def mkdirs(self):
    self._gfile.makedirs(str(self))

  def remove(self):
    self._gfile.remove(str(self))

  def rmtree(self):
    self._gfile.rmtree(str(self))

  def copy(self, dest):
    self._gfile.copy(str(self), str(dest), overwrite=True)

  def move(self, dest):
    """Rename to `dest`, deleting an existing directory at `dest` first."""
    dest = Path(dest)
    if dest.isdir():
      dest.rmtree()
    # Pass plain strings, like every other gfile call in this class.
    self._gfile.rename(str(self), str(dest), overwrite=True)


# Order matters: the first matching predicate wins, so the catch-all local
# implementation comes last.
Path.filesystems = [
    (GFilePath, lambda path: path.startswith('gs://')),
    (GFilePath, lambda path: path.startswith('/cns/')),
    (LocalPath, lambda path: True),
]
class Space:
  """Description of an array-valued space: dtype, shape and value bounds.

  Args:
    dtype: Anything accepted by `np.dtype`.
    shape: Tuple of positive dimensions, or a single int for a 1-D space.
    low: Lower bound, broadcast to `shape`. Inferred from the dtype if None.
    high: Upper bound, broadcast to `shape`. Inferred from the dtype if None.
  """

  def __init__(self, dtype, shape=(), low=None, high=None):
    # For integer types, high is the exclusive upper bound.
    shape = (shape,) if isinstance(shape, int) else shape
    self._dtype = np.dtype(dtype)
    # NOTE(review): an `np.dtype` instance is never identical to the builtin
    # `object`, so this identity check cannot fail; `!=` was presumably meant.
    assert self._dtype is not object, self._dtype
    assert isinstance(shape, tuple), shape
    self._low = self._infer_low(dtype, shape, low, high)
    self._high = self._infer_high(dtype, shape, low, high)
    self._shape = self._infer_shape(dtype, shape, low, high)
    self._discrete = (
        np.issubdtype(self.dtype, np.integer) or self.dtype == bool)
    self._random = np.random.RandomState()

  @property
  def dtype(self):
    return self._dtype

  @property
  def shape(self):
    return self._shape

  @property
  def low(self):
    return self._low

  @property
  def high(self):
    return self._high

  @property
  def discrete(self):
    """True for integer and bool dtypes."""
    return self._discrete

  def __repr__(self):
    return (
        f'Space(dtype={self.dtype.name}, '
        f'shape={self.shape}, '
        f'low={self.low.min()}, '
        f'high={self.high.max()})')

  def __contains__(self, value):
    """Check shape, bounds and lossless representability in `dtype`."""
    value = np.asarray(value)
    if value.shape != self.shape:
      return False
    if (value > self.high).any():
      return False
    if (value < self.low).any():
      return False
    # Reject values that change when round-tripped through the space dtype.
    if (value.astype(self.dtype).astype(value.dtype) != value).any():
      return False
    return True

  def sample(self):
    """Draw a uniformly random element of the space as an array of `dtype`."""
    low, high = self.low, self.high
    if np.issubdtype(self.dtype, np.floating):
      # Clip infinite bounds to the finite range of the dtype.
      low = np.maximum(np.ones(self.shape) * np.finfo(self.dtype).min, low)
      high = np.minimum(np.ones(self.shape) * np.finfo(self.dtype).max, high)
    if self.dtype == bool:
      # Casting a uniform float in [0, 1) straight to bool yields True for
      # every nonzero draw, i.e. practically always. Threshold at 0.5 so both
      # values occur; degenerate bounds (low == high) still return that bound.
      draws = self._random.uniform(
          np.asarray(low, np.float64), np.asarray(high, np.float64),
          self.shape)
      return np.asarray(draws >= 0.5, dtype=self.dtype)
    return self._random.uniform(low, high, self.shape).astype(self.dtype)

  def _infer_low(self, dtype, shape, low, high):
    """Broadcast `low` to `shape`, or derive the lower bound from `dtype`."""
    if low is not None:
      try:
        return np.broadcast_to(low, shape)
      except ValueError:
        raise ValueError(f'Cannot broadcast {low} to shape {shape}')
    elif np.issubdtype(dtype, np.floating):
      return -np.inf * np.ones(shape)
    elif np.issubdtype(dtype, np.integer):
      return np.iinfo(dtype).min * np.ones(shape, dtype)
    elif np.issubdtype(dtype, bool):
      return np.zeros(shape, bool)
    else:
      raise ValueError('Cannot infer low bound from shape and dtype.')

  def _infer_high(self, dtype, shape, low, high):
    """Broadcast `high` to `shape`, or derive the upper bound from `dtype`."""
    if high is not None:
      try:
        return np.broadcast_to(high, shape)
      except ValueError:
        raise ValueError(f'Cannot broadcast {high} to shape {shape}')
    elif np.issubdtype(dtype, np.floating):
      return np.inf * np.ones(shape)
    elif np.issubdtype(dtype, np.integer):
      return np.iinfo(dtype).max * np.ones(shape, dtype)
    elif np.issubdtype(dtype, bool):
      return np.ones(shape, bool)
    else:
      raise ValueError('Cannot infer high bound from shape and dtype.')

  def _infer_shape(self, dtype, shape, low, high):
    """Return `shape` as a tuple, falling back to the shape of the bounds."""
    if shape is None and low is not None:
      shape = low.shape
    if shape is None and high is not None:
      shape = high.shape
    if not hasattr(shape, '__len__'):
      shape = (shape,)
    assert all(dim and dim > 0 for dim in shape), shape
    return tuple(shape)
class uuid:
  """UUID that is stored as 16 byte string and can be converted to and from
  int, string, and array types.

  In normal mode new IDs are random (`uuid4`) and the string form is a
  22-character base62 encoding. After `uuid.reset(debug=True)` new IDs are
  consecutive integers starting at 1 and the string form is the decimal
  number.
  """

  DEBUG_ID = None
  BASE62 = string.digits + string.ascii_letters
  BASE62REV = {x: i for i, x in enumerate(BASE62)}

  @classmethod
  def reset(cls, *, debug):
    """Enable (`debug=True`) or disable the sequential debug ID mode."""
    cls.DEBUG_ID = 0 if debug else None

  def __init__(self, value=None):
    """Create a new ID, or convert `value` (uuid, int, str or ndarray)."""
    if value is None:
      if self.DEBUG_ID is None:
        self.value = uuidlib.uuid4().bytes
      else:
        # The counter lives on the class so all instances share it.
        type(self).DEBUG_ID += 1
        self.value = self.DEBUG_ID.to_bytes(16, 'big')
    elif isinstance(value, uuid):
      self.value = value.value
    elif isinstance(value, int):
      self.value = value.to_bytes(16, 'big')
    elif isinstance(value, str):
      if self.DEBUG_ID is None:
        # Decode base62, least significant character last.
        integer = 0
        for index, char in enumerate(value[::-1]):
          integer += (62 ** index) * self.BASE62REV[char]
        self.value = integer.to_bytes(16, 'big')
      else:
        self.value = int(value).to_bytes(16, 'big')
    elif isinstance(value, np.ndarray):
      self.value = value.tobytes()
    else:
      raise ValueError(value)
    assert type(self.value) == bytes, type(self.value)
    assert len(self.value) == 16, len(self.value)
    self._hash = hash(self.value)

  def __int__(self):
    return int.from_bytes(self.value, 'big')

  def __str__(self):
    if self.DEBUG_ID is not None:
      return str(int(self))
    chars = []
    integer = int(self)
    while integer != 0:
      chars.append(self.BASE62[integer % 62])
      integer //= 62
    # Left-pad (after the reversal below) to a fixed width of 22 characters.
    while len(chars) < 22:
      chars.append('0')
    return ''.join(chars[::-1])

  def __array__(self):
    return np.frombuffer(self.value, np.uint8)

  def __getitem__(self, index):
    return self.__array__()[index]

  def __repr__(self):
    return str(self)

  def __eq__(self, other):
    # Comparing against a foreign type should yield "not equal" rather than
    # fail with AttributeError on the missing `value` attribute.
    if not isinstance(other, uuid):
      return NotImplemented
    return self.value == other.value

  def __hash__(self):
    return self._hash
class Clock:
  """Wall-clock trigger that fires at most once per `every` seconds.

  A negative `every` fires on each call and zero never fires. Otherwise the
  first call fires, and later calls fire once at least `every` seconds have
  passed since the previous firing.
  """

  def __init__(self, every):
    self._every = every
    self._prev = None

  def __call__(self, step=None):
    """Return whether the clock fires now; `step` is accepted but unused."""
    interval = self._every
    if interval < 0:
      return True
    if interval == 0:
      return False
    current = time.time()
    # Fire on the very first call, or when the interval has elapsed. The
    # reference point is reset to the current time, not advanced by a fixed
    # amount.
    if self._prev is None or current >= self._prev + interval:
      self._prev = current
      return True
    return False
class ProcessWorker:
  """Run a stateful function in one spawned subprocess via a process pool.

  `fn` is called as `fn(state, *args, **kwargs)` and must return
  `(new_state, output)`. The state is kept in a module global of the worker
  process and starts as None. `fn` and `initializers` are serialized with
  cloudpickle before being sent to the subprocess.

  Args:
    fn: Callable with the signature described above.
    initializers: Iterable of zero-argument callables that run once in the
      subprocess before any work is executed.
  """

  def __init__(self, fn, initializers=()):
    import cloudpickle
    import multiprocessing
    fn = cloudpickle.dumps(fn)
    initializers = cloudpickle.dumps(initializers)
    self.executor = concurrent.futures.ProcessPoolExecutor(
        max_workers=1, mp_context=multiprocessing.get_context('spawn'),
        initializer=self._initializer, initargs=(fn, initializers))
    self.futures = []

  def __call__(self, *args, **kwargs):
    """Submit one call and return a zero-argument function for its result."""
    future = self.executor.submit(self._worker, *args, **kwargs)
    self.futures.append(future)
    # Stop tracking the future as soon as it has finished.
    future.add_done_callback(lambda f: self.futures.remove(f))
    return future.result

  def wait(self):
    """Block until all currently tracked calls have finished."""
    concurrent.futures.wait(self.futures)

  def close(self):
    """Shut the pool down without waiting, cancelling pending calls."""
    self.executor.shutdown(wait=False, cancel_futures=True)

  @staticmethod
  def _initializer(fn, initializers):
    """Set up the worker process: restore `fn` and run the initializers."""
    global _FN, _STATE
    import cloudpickle
    _FN = cloudpickle.loads(fn)
    _STATE = None
    for initializer in cloudpickle.loads(initializers):
      # Call each unpickled initializer; calling `initializers` here would
      # try to call the pickled bytes object and raise TypeError.
      initializer()

  @staticmethod
  def _worker(*args, **kwargs):
    """Run one call in the worker process and carry the state forward."""
    global _FN, _STATE
    _STATE, output = _FN(_STATE, *args, **kwargs)
    return output
154 | try: 155 | self._process.join(0.1) 156 | if self._process.exitcode is None: 157 | try: 158 | os.kill(self._process.pid, 9) 159 | time.sleep(0.1) 160 | except Exception: 161 | pass 162 | except (AttributeError, AssertionError): 163 | pass 164 | 165 | def _submit(self, message, payload=None): 166 | callid = self._nextid 167 | self._nextid += 1 168 | self._pipe.send((message, callid, payload)) 169 | return Future(self._receive, callid) 170 | 171 | def _receive(self, callid): 172 | while callid not in self._results: 173 | try: 174 | message, callid, payload = self._pipe.recv() 175 | except (OSError, EOFError): 176 | raise RuntimeError('Lost connection to worker.') 177 | if message == Message.ERROR: 178 | raise Exception(payload) 179 | assert message == Message.RESULT, message 180 | self._results[callid] = payload 181 | return self._results.pop(callid) 182 | 183 | @staticmethod 184 | def _loop(pipe, function, initializers): 185 | try: 186 | callid = None 187 | state = None 188 | import cloudpickle 189 | initializers = cloudpickle.loads(initializers) 190 | function = cloudpickle.loads(function) 191 | [fn() for fn in initializers] 192 | while True: 193 | if not pipe.poll(0.1): 194 | continue # Wake up for keyboard interrupts. 
class Atari(embodied.Env):
  """Atari environment built on `gym.envs.atari.AtariEnv`.

  Args:
    name: Game name passed to `AtariEnv` ('james_bond' maps to 'jamesbond').
    repeat: Number of emulator steps executed per call to `step`.
    size: Square `(height, width)` of the returned image.
    gray: Whether to convert the image to a single grayscale channel.
    noops: If nonzero, a random number of no-op steps in `[0, noops)` is
      executed after every reset.
    lives: 'unused' ignores life counts. With 'discount' and 'reset' a lost
      life marks the step as terminal; 'reset' additionally ends the episode.
    sticky: Whether to use a repeat action probability of 0.25.
    actions: 'all' for the full action space, 'needed' for the reduced one.
    length: Maximum number of emulator steps per episode (falsy disables).
    resize: Image resize backend, 'opencv' or 'pillow'.
    seed: Seed for the random generator that draws the no-op count.
  """

  LOCK = None

  def __init__(
      self, name, repeat=4, size=(84, 84), gray=True, noops=0, lives='unused',
      sticky=True, actions='all', length=108000, resize='opencv', seed=None):
    assert size[0] == size[1]
    assert lives in ('unused', 'discount', 'reset'), lives
    assert actions in ('all', 'needed'), actions
    assert resize in ('opencv', 'pillow'), resize
    if self.LOCK is None:
      import multiprocessing as mp
      mp = mp.get_context('spawn')
      # NOTE(review): this assigns an instance attribute, so each instance
      # gets its own lock; confirm whether a shared class-level lock was
      # intended.
      self.LOCK = mp.Lock()
    self._resize = resize
    if self._resize == 'opencv':
      import cv2
      self._cv2 = cv2
    if self._resize == 'pillow':
      from PIL import Image
      self._image = Image
    import gym.envs.atari
    if name == 'james_bond':
      name = 'jamesbond'
    self._repeat = repeat
    self._size = size
    self._gray = gray
    self._noops = noops
    self._lives = lives
    self._sticky = sticky
    self._length = length
    self._random = np.random.RandomState(seed)
    with self.LOCK:
      self._env = gym.envs.atari.AtariEnv(
          game=name,
          obs_type='image',
          frameskip=1, repeat_action_probability=0.25 if sticky else 0.0,
          full_action_space=(actions == 'all'))
    assert self._env.unwrapped.get_action_meanings()[0] == 'NOOP'
    shape = self._env.observation_space.shape
    # Two screen buffers; `_obs` returns their element-wise maximum.
    self._buffer = [np.zeros(shape, np.uint8) for _ in range(2)]
    self._ale = self._env.unwrapped.ale
    self._last_lives = None
    self._done = True
    self._step = 0

  @property
  def obs_space(self):
    shape = self._size + (1 if self._gray else 3,)
    return {
        'image': embodied.Space(np.uint8, shape),
        'reward': embodied.Space(np.float32),
        'is_first': embodied.Space(bool),
        'is_last': embodied.Space(bool),
        'is_terminal': embodied.Space(bool),
    }

  @property
  def act_space(self):
    return {
        'action': embodied.Space(np.int32, (), 0, self._env.action_space.n),
        'reset': embodied.Space(bool),
    }

  def step(self, action):
    """Reset if requested or needed, else apply the action `repeat` times."""
    if action['reset'] or self._done:
      with self.LOCK:
        self._reset()
      self._done = False
      self._step = 0
      return self._obs(0.0, is_first=True)
    total = 0.0
    dead = False
    over = False  # Stays defined even if the repeat loop runs zero times.
    for repeat in range(self._repeat):
      _, reward, over, info = self._env.step(action['action'])
      self._step += 1
      total += reward
      # Capture the second-to-last frame so `_obs` can combine it with the
      # final frame.
      if repeat == self._repeat - 2:
        self._screen(self._buffer[1])
      if over:
        break
      if self._lives != 'unused':
        current = self._ale.lives()
        if current < self._last_lives:
          dead = True
          self._last_lives = current
          break
    if not self._repeat:
      self._buffer[1][:] = self._buffer[0][:]
    self._screen(self._buffer[0])
    self._done = over or (self._length and self._step >= self._length)
    return self._obs(
        total,
        is_last=self._done or (dead and self._lives == 'reset'),
        is_terminal=dead or over)

  def _reset(self):
    """Reset the emulator, run optional random no-ops, refresh buffers."""
    self._env.reset()
    if self._noops:
      for _ in range(self._random.randint(self._noops)):
        _, _, dead, _ = self._env.step(0)
        if dead:
          self._env.reset()
    self._last_lives = self._ale.lives()
    self._screen(self._buffer[0])
    self._buffer[1].fill(0)

  def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
    """Build the observation dict from the current screen buffers."""
    # Element-wise maximum over the two buffered frames, stored in buffer 0.
    np.maximum(self._buffer[0], self._buffer[1], out=self._buffer[0])
    image = self._buffer[0]
    if image.shape[:2] != self._size:
      if self._resize == 'opencv':
        image = self._cv2.resize(
            image, self._size, interpolation=self._cv2.INTER_AREA)
      if self._resize == 'pillow':
        image = self._image.fromarray(image)
        image = image.resize(self._size, self._image.NEAREST)
        image = np.array(image)
    if self._gray:
      weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
      image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
      image = image[:, :, None]
    return dict(
        image=image,
        reward=reward,
        is_first=is_first,
        is_last=is_last,
        # Report the terminal flag computed by `step`. Returning `is_last`
        # here discarded it, so life loss in 'discount' mode was never
        # reported and time-limit truncation was flagged as terminal.
        is_terminal=is_terminal,
    )

  def _screen(self, array):
    """Write the current RGB screen into `array` in place."""
    self._ale.getScreenRGB2(array)

  def close(self):
    return self._env.close()
class Crafter(embodied.Env):
  """Crafter environment exposing reward, episode flags and achievements.

  Args:
    task: 'reward' or 'noreward'; forwarded to `crafter.Env` as a boolean.
    size: Image size passed to `crafter.Env`.
    outdir: If set, the environment is wrapped in a `crafter.Recorder` that
      saves stats (but no videos or episodes) to this directory.
    seed: Seed passed to `crafter.Env`.
  """

  def __init__(self, task, size=(64, 64), outdir=None, seed=None):
    assert task in ('reward', 'noreward')
    import crafter
    use_reward = (task == 'reward')
    self._env = crafter.Env(size=size, reward=use_reward, seed=seed)
    if outdir:
      outdir = embodied.Path(outdir)
      recorder_options = dict(
          save_stats=True, save_video=False, save_episode=False)
      self._env = crafter.Recorder(self._env, outdir, **recorder_options)
    self._achievements = crafter.constants.achievements.copy()
    self._done = True

  @property
  def obs_space(self):
    image_shape = self._env.observation_space.shape
    spaces = {
        'image': embodied.Space(np.uint8, image_shape),
        'reward': embodied.Space(np.float32),
        'is_first': embodied.Space(bool),
        'is_last': embodied.Space(bool),
        'is_terminal': embodied.Space(bool),
        'log_reward': embodied.Space(np.float32),
    }
    # One integer entry per achievement, appended after the fixed keys.
    for name in self._achievements:
      spaces[f'log_achievement_{name}'] = embodied.Space(np.int32)
    return spaces

  @property
  def act_space(self):
    num_actions = self._env.action_space.n
    return {
        'action': embodied.Space(np.int32, (), 0, num_actions),
        'reset': embodied.Space(bool),
    }

  def step(self, action):
    """Advance one step, or start a new episode on reset or after the end."""
    if not (action['reset'] or self._done):
      frame, raw_reward, self._done, info = self._env.step(action['action'])
      return self._obs(
          frame, np.float32(raw_reward), info,
          is_last=self._done,
          is_terminal=info['discount'] == 0)
    self._done = False
    first_frame = self._env.reset()
    return self._obs(first_frame, 0.0, {}, is_first=True)

  def _obs(
      self, image, reward, info,
      is_first=False, is_last=False, is_terminal=False):
    """Assemble the observation dict; an empty `info` yields zeroed logs."""
    achievement_logs = {}
    for name in self._achievements:
      unlocked = info['achievements'][name] if info else 0
      achievement_logs[f'log_achievement_{name}'] = unlocked
    obs = {
        'image': image,
        'reward': reward,
        'is_first': is_first,
        'is_last': is_last,
        'is_terminal': is_terminal,
        'log_reward': np.float32(info['reward'] if info else 0.0),
    }
    obs.update(achievement_logs)
    return obs

  def render(self):
    return self._env.render()
import from_dm 37 | self._env = from_dm.FromDM(self._dmenv) 38 | self._env = embodied.wrappers.ExpandScalars(self._env) 39 | self._env = embodied.wrappers.ActionRepeat(self._env, repeat) 40 | self._render = render 41 | self._size = size 42 | self._camera = camera 43 | 44 | @functools.cached_property 45 | def obs_space(self): 46 | spaces = self._env.obs_space.copy() 47 | if self._render: 48 | spaces['image'] = embodied.Space(np.uint8, self._size + (3,)) 49 | return spaces 50 | 51 | @functools.cached_property 52 | def act_space(self): 53 | return self._env.act_space 54 | 55 | def step(self, action): 56 | for key, space in self.act_space.items(): 57 | if not space.discrete: 58 | assert np.isfinite(action[key]).all(), (key, action[key]) 59 | obs = self._env.step(action) 60 | if self._render: 61 | obs['image'] = self.render() 62 | return obs 63 | 64 | def render(self): 65 | return self._dmenv.physics.render(*self._size, camera_id=self._camera) 66 | -------------------------------------------------------------------------------- /recall2imagine/embodied/envs/dmlab.py: -------------------------------------------------------------------------------- 1 | import embodied 2 | import numpy as np 3 | 4 | 5 | class DMLab(embodied.Env): 6 | 7 | # Small action set used by IMPALA. 8 | IMPALA_ACTION_SET = ( 9 | ( 0, 0, 0, 1, 0, 0, 0), # Forward 10 | ( 0, 0, 0, -1, 0, 0, 0), # Backward 11 | ( 0, 0, -1, 0, 0, 0, 0), # Strafe Left 12 | ( 0, 0, 1, 0, 0, 0, 0), # Strafe Right 13 | (-20, 0, 0, 0, 0, 0, 0), # Look Left 14 | ( 20, 0, 0, 0, 0, 0, 0), # Look Right 15 | (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward 16 | ( 20, 0, 0, 1, 0, 0, 0), # Look Right + Forward 17 | ( 0, 0, 0, 0, 1, 0, 0), # Fire 18 | ) 19 | 20 | # Large action set used by PopArt and R2D2. 
21 | POPART_ACTION_SET = [ 22 | ( 0, 0, 0, 1, 0, 0, 0), # FW 23 | ( 0, 0, 0, -1, 0, 0, 0), # BW 24 | ( 0, 0, -1, 0, 0, 0, 0), # Strafe Left 25 | ( 0, 0, 1, 0, 0, 0, 0), # Strafe Right 26 | (-10, 0, 0, 0, 0, 0, 0), # Small LL 27 | ( 10, 0, 0, 0, 0, 0, 0), # Small LR 28 | (-60, 0, 0, 0, 0, 0, 0), # Large LL 29 | ( 60, 0, 0, 0, 0, 0, 0), # Large LR 30 | ( 0, 10, 0, 0, 0, 0, 0), # Look Down 31 | ( 0, -10, 0, 0, 0, 0, 0), # Look Up 32 | (-10, 0, 0, 1, 0, 0, 0), # FW + Small LL 33 | ( 10, 0, 0, 1, 0, 0, 0), # FW + Small LR 34 | (-60, 0, 0, 1, 0, 0, 0), # FW + Large LL 35 | ( 60, 0, 0, 1, 0, 0, 0), # FW + Large LR 36 | ( 0, 0, 0, 0, 1, 0, 0), # Fire 37 | ] 38 | 39 | def __init__( 40 | self, level, repeat=4, size=(64, 64), mode='train', 41 | action_set=IMPALA_ACTION_SET, episodic=True, seed=None): 42 | import deepmind_lab 43 | cache = None 44 | # path = os.environ.get('DMLAB_CACHE', None) 45 | # if path: 46 | # cache = Cache(path) 47 | self._size = size 48 | self._repeat = repeat 49 | self._action_set = action_set 50 | self._episodic = episodic 51 | self._random = np.random.RandomState(seed) 52 | config = dict(height=size[0], width=size[1], logLevel='WARN') 53 | if mode == 'train': 54 | if level.endswith('_test'): 55 | level = level.replace('_test', '_train') 56 | elif mode == 'eval': 57 | config.update(allowHoldOutLevels='true', mixerSeed=0x600D5EED) 58 | else: 59 | raise NotImplementedError(mode) 60 | config = {k: str(v) for k, v in config.items()} 61 | self._env = deepmind_lab.Lab( 62 | level='contributed/dmlab30/' + level, 63 | observations=['RGB_INTERLEAVED'], 64 | level_cache=cache, config=config) 65 | self._prev_image = None 66 | self._done = True 67 | 68 | @property 69 | def obs_space(self): 70 | return { 71 | 'image': embodied.Space(np.uint8, self._size + (3,)), 72 | 'reward': embodied.Space(np.float32), 73 | 'is_first': embodied.Space(bool), 74 | 'is_last': embodied.Space(bool), 75 | 'is_terminal': embodied.Space(bool), 76 | } 77 | 78 | @property 79 | def 
class Cache:
  """File-based level cache keyed by the MD5 hash of the level key.

  Cached files live at `<cache_dir>/<first 3 hex chars>/<remaining chars>`.
  The `fetch`/`write` pair presumably matches the `level_cache` object that
  DeepMind Lab accepts -- confirm against the DMLab Python API.
  """

  def __init__(self, cache_dir):
    """Remember the root directory under which cached levels are stored."""
    self._cache_dir = cache_dir

  def get_path(self, key):
    """Return the cache file path for `key` (a string)."""
    import hashlib
    import os
    digest = hashlib.md5(key.encode('utf-8')).hexdigest()
    # The first three hex characters become a sub-directory name.
    dir_, filename = digest[:3], digest[3:]
    return os.path.join(self._cache_dir, dir_, filename)

  def fetch(self, key, pk3_path):
    """Copy the cached file for `key` to `pk3_path`.

    Returns:
      True if the copy succeeded, False if TensorFlow raised an `OpError`.
    """
    import tensorflow as tf
    path = self.get_path(key)
    try:
      tf.io.gfile.copy(path, pk3_path, overwrite=True)
      return True
    except tf.errors.OpError:
      return False

  def write(self, key, pk3_path):
    """Store `pk3_path` in the cache under `key` unless already present.

    Failures are reported on stdout and otherwise ignored, so a broken cache
    never interrupts the caller.
    """
    import os
    import tensorflow as tf
    path = self.get_path(key)
    try:
      if not tf.io.gfile.exists(path):
        tf.io.gfile.makedirs(os.path.dirname(path))
        tf.io.gfile.copy(pk3_path, path)
    except Exception as e:
      print(f'Could not store level: {e}')
class Dummy(embodied.Env):
  """Fixed-length environment that emits all-zero images and vectors."""

  def __init__(self, task, size=(64, 64), length=100):
    """Set up the dummy env.

    Args:
      task: 'cont' for a 6-dim float action, 'disc' for a discrete action.
      size: Tuple that is extended with `(3,)` to form the image shape.
      length: Number of steps after which an episode ends.
    """
    assert task in ('cont', 'disc')
    self._task, self._size, self._length = task, size, length
    self._step = 0
    self._done = False

  @property
  def obs_space(self):
    """Return the observation spaces."""
    image_shape = self._size + (3,)
    return {
        'image': embodied.Space(np.uint8, image_shape),
        'vector': embodied.Space(np.float32, (7,)),
        'step': embodied.Space(np.int32, (), 0, self._length),
        'reward': embodied.Space(np.float32),
        'is_first': embodied.Space(bool),
        'is_last': embodied.Space(bool),
        'is_terminal': embodied.Space(bool),
    }

  @property
  def act_space(self):
    """Return the action space for the configured task plus `reset`."""
    if self._task == 'cont':
      action_space = embodied.Space(np.float32, (6,))
    else:
      action_space = embodied.Space(np.int32, (), 0, 5)
    return {'action': action_space, 'reset': embodied.Space(bool)}

  def step(self, action):
    """Advance the step counter; the episode ends after `length` steps."""
    if not (action['reset'] or self._done):
      # Looked up so a missing key still fails, but otherwise ignored.
      _ = action['action']
      self._step += 1
      self._done = (self._step >= self._length)
      return self._obs(1.0, is_last=self._done, is_terminal=self._done)
    self._step = 0
    self._done = False
    return self._obs(0.0, is_first=True)

  def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
    """Build an observation with zero image/vector and the current step."""
    blank_image = np.zeros(self._size + (3,), np.uint8)
    blank_vector = np.zeros(7, np.float32)
    return dict(
        image=blank_image,
        vector=blank_vector,
        step=self._step,
        reward=reward,
        is_first=is_first,
        is_last=is_last,
        is_terminal=is_terminal,
    )
class FromDM(embodied.Env):
  """Adapter that exposes a dm_env-style environment as an `embodied.Env`.

  The wrapped env must provide `observation_spec()`, `action_spec()`,
  `reset()` and `step()`. An observation entry named 'reward' is renamed to
  'obs_reward', entries whose spec has zero elements are dropped, and
  2-dimensional observations are flattened.
  """

  def __init__(self, env, obs_key='observation', act_key='action'):
    """Wrap `env`.

    Args:
      env: Environment object, or a string of the form `bsuite_<id>` that is
        loaded through `bsuite.load_from_id`.
      obs_key: Name given to the observation when the spec is not a dict.
      act_key: Name expected for the action when the spec is not a dict.

    Raises:
      NotImplementedError: If `env` is a string that does not start with
        'bsuite'. Previously this left `self._env` unset and failed later
        with an unrelated `AttributeError`.
    """
    if isinstance(env, str):
      if not env.startswith('bsuite'):
        raise NotImplementedError(env)
      import bsuite
      _, task = env.split('_', 1)
      env = bsuite.load_from_id(task)
    self._env = env
    obs_spec = self._env.observation_spec()
    act_spec = self._env.action_spec()
    self._obs_dict = isinstance(obs_spec, dict)
    self._act_dict = isinstance(act_spec, dict)
    # The keys are False when the corresponding spec is already a dict.
    self._obs_key = not self._obs_dict and obs_key
    self._act_key = not self._act_dict and act_key
    # Computed eagerly so that _obs() drops zero-size observations even when
    # step() is called before obs_space has ever been accessed.
    self._obs_empty = [
        key for key, value in self._named_obs_spec().items()
        if int(np.prod(value.shape)) == 0]
    self._done = True

  def _named_obs_spec(self):
    """Return a private dict of observation specs with 'reward' renamed."""
    spec = self._env.observation_spec()
    # Copy so the env's own spec dict is never modified.
    spec = dict(spec) if self._obs_dict else {self._obs_key: spec}
    if 'reward' in spec:
      spec['obs_reward'] = spec.pop('reward')
    return spec

  @functools.cached_property
  def obs_space(self):
    """Return the observation spaces (computed once and cached)."""
    spec = {
        key: value for key, value in self._named_obs_spec().items()
        if key not in self._obs_empty}
    return {
        'reward': embodied.Space(np.float32),
        'is_first': embodied.Space(bool),
        'is_last': embodied.Space(bool),
        'is_terminal': embodied.Space(bool),
        'step_no': embodied.Space(np.int32),
        **{k or self._obs_key: self._convert(v) for k, v in spec.items()},
    }

  @functools.cached_property
  def act_space(self):
    """Return the action spaces plus `reset` (computed once and cached)."""
    spec = self._env.action_spec()
    spec = spec if self._act_dict else {self._act_key: spec}
    return {
        'reset': embodied.Space(bool),
        **{k or self._act_key: self._convert(v) for k, v in spec.items()},
    }

  def step(self, action):
    """Reset when requested (or after an episode end), otherwise step."""
    action = action.copy()
    reset = action.pop('reset')
    if reset or self._done:
      time_step = self._env.reset()
    else:
      action = action if self._act_dict else action[self._act_key]
      time_step = self._env.step(action)
    self._done = time_step.last()
    return self._obs(time_step)

  def _obs(self, time_step):
    """Convert a time step into an observation dict."""
    if not time_step.first():
      # Only discounts of exactly 0 or 1 are supported.
      assert time_step.discount in (0, 1), time_step.discount
    obs = time_step.observation
    obs = dict(obs) if self._obs_dict else {self._obs_key: obs}
    if 'reward' in obs:
      obs['obs_reward'] = obs.pop('reward')
    for key in self._obs_empty:
      del obs[key]
    return dict(
        reward=np.float32(0.0 if time_step.first() else time_step.reward),
        is_first=time_step.first(),
        is_last=time_step.last(),
        is_terminal=False if time_step.first() else time_step.discount == 0,
        # 2-D observations are flattened to match _convert().
        **{k: v if len(v.shape) != 2 else v.flatten() for k, v in obs.items()},
    )

  def _convert(self, space):
    """Translate a spec object into an `embodied.Space`.

    Specs with `num_values` become scalar discrete spaces, specs with
    `minimum` become bounded spaces (bounds must be finite), and anything
    else becomes an unbounded space. 2-D shapes are flattened to 1-D.
    """
    shape = space.shape
    if len(shape) == 2:
      shape = (shape[0] * shape[1],)
    if hasattr(space, 'num_values'):
      return embodied.Space(space.dtype, (), 0, space.num_values)
    elif hasattr(space, 'minimum'):
      assert np.isfinite(space.minimum).all(), space.minimum
      assert np.isfinite(space.maximum).all(), space.maximum
      return embodied.Space(
          space.dtype, shape, space.minimum, space.maximum)
    else:
      return embodied.Space(space.dtype, shape, None, None)
def unmap_action(a, nvec):
  """Decode a flat action index into one sub-action per dimension.

  The index is read as a mixed-radix number whose digit sizes are given by
  `nvec`, with the last dimension varying fastest.

  Args:
    a: Flat action index; a Python int, NumPy integer, or single-element
      array. It is never modified.
    nvec: Sequence with the number of choices of each dimension.

  Returns:
    An int32 array of length `len(nvec)` holding the per-dimension choices.
  """
  # Work on a private Python int: with an ndarray input, `a //= n` would
  # otherwise divide the caller's array in place.
  remaining = int(np.asarray(a).item())
  newa = np.zeros(len(nvec), dtype=np.int32)
  for i in reversed(range(len(nvec))):
    remaining, newa[i] = divmod(remaining, int(nvec[i]))
  return newa
78 | if self._obs_dict: 79 | spaces = self._flatten(self._env.observation_space.spaces) 80 | else: 81 | spaces = {self._obs_key: self._env.observation_space} 82 | spaces = {k: self._convert(v) for k, v in spaces.items()} 83 | return { 84 | **spaces, 85 | 'reward': embodied.Space(np.float32), 86 | 'is_first': embodied.Space(bool), 87 | 'is_last': embodied.Space(bool), 88 | 'is_terminal': embodied.Space(bool), 89 | 'step_no': embodied.Space(np.int32), 90 | # 'ep_no': embodied.Space(np.int32), 91 | } 92 | 93 | @functools.cached_property 94 | def act_space(self): 95 | if self._act_dict: 96 | spaces = self._flatten(self._env.action_space.spaces) 97 | else: 98 | spaces = {self._act_key: self._env.action_space} 99 | spaces = {k: self._convert(v) for k, v in spaces.items()} 100 | spaces['reset'] = embodied.Space(bool) 101 | return spaces 102 | 103 | def step(self, action): 104 | if action['reset'] or self._done: 105 | self._done = False 106 | obs = self._env.reset() 107 | return self._obs(obs, 0.0, is_first=True) 108 | if self._act_dict: 109 | action = self._unflatten(action) 110 | else: 111 | action = action[self._act_key] 112 | obs, reward, self._done, self._info = self._env.step(action) 113 | return self._obs( 114 | obs, reward, 115 | is_last=bool(self._done), 116 | is_terminal=bool(self._info.get('is_terminal', self._done))) 117 | 118 | def _obs( 119 | self, obs, reward, is_first=False, is_last=False, is_terminal=False): 120 | if not self._obs_dict: 121 | obs = {self._obs_key: obs} 122 | obs = self._flatten(obs) 123 | obs = {k: np.asarray(v) for k, v in obs.items()} 124 | obs.update( 125 | reward=np.float32(reward), 126 | is_first=is_first, 127 | is_last=is_last, 128 | is_terminal=is_terminal) 129 | return obs 130 | 131 | def render(self): 132 | image = self._env.render('rgb_array') 133 | assert image is not None 134 | return image 135 | 136 | def close(self): 137 | try: 138 | self._env.close() 139 | except Exception: 140 | pass 141 | 142 | def _flatten(self, nest, 
class Quadruped(legacy_base.Walker):
  """Quadruped walker loaded from `loconav_quadruped.xml` next to this file."""

  def _build(self, name='walker', initializer=None):
    """Load the MJCF model and allocate the previous-action buffer."""
    super()._build(initializer=initializer)
    xml_path = os.path.join(
        os.path.dirname(__file__), 'loconav_quadruped.xml')
    self._mjcf_root = mjcf.from_path(xml_path)
    if name:
      self._mjcf_root.model = name
    self._prev_action = np.zeros(
        self.action_spec.shape, self.action_spec.dtype)

  def initialize_episode(self, physics, random_state):
    """Replace the previous-action buffer with a fresh zero array."""
    self._prev_action = np.zeros_like(self._prev_action)

  def apply_action(self, physics, action, random_state):
    """Apply the action and remember it as the previous action."""
    super().apply_action(physics, action, random_state)
    self._prev_action[:] = action

  def _build_observables(self):
    """Return the observables object for this walker."""
    return QuadrupedObservables(self)

  @property
  def mjcf_model(self):
    """The root of the loaded MJCF model."""
    return self._mjcf_root

  @property
  def upright_pose(self):
    """A default `WalkerPose`."""
    return base.WalkerPose()

  @composer.cached_property
  def actuators(self):
    """All actuator elements of the model."""
    return self._mjcf_root.find_all('actuator')

  @composer.cached_property
  def root_body(self):
    """The body element named 'torso'."""
    return self._mjcf_root.find('body', 'torso')

  @composer.cached_property
  def bodies(self):
    """Tuple of all body elements of the model."""
    return tuple(self.mjcf_model.find_all('body'))

  @composer.cached_property
  def mocap_tracking_bodies(self):
    """Tuple of all body elements of the model."""
    return tuple(self.mjcf_model.find_all('body'))

  @property
  def mocap_joints(self):
    """All joint elements of the model."""
    return self.mjcf_model.find_all('joint')

  @property
  def _foot_bodies(self):
    """The four toe bodies: front left, front right, back right, back left."""
    toe_names = (
        'toe_front_left', 'toe_front_right',
        'toe_back_right', 'toe_back_left')
    return tuple(self._mjcf_root.find('body', toe) for toe in toe_names)

  @composer.cached_property
  def end_effectors(self):
    """The toe bodies."""
    return self._foot_bodies

  @composer.cached_property
  def observable_joints(self):
    """All joint elements of the model."""
    return self._mjcf_root.find_all('joint')

  @composer.cached_property
  def egocentric_camera(self):
    """The camera element named 'egocentric'."""
    return self._mjcf_root.find('camera', 'egocentric')

  def aliveness(self, physics):
    """Map the last entry of the torso's `xmat` from [-1, 1] to [-1, 0]."""
    last_xmat_entry = physics.bind(self.root_body).xmat[-1]
    return (last_xmat_entry - 1.) / 2.

  @composer.cached_property
  def ground_contact_geoms(self):
    """Tuple of every geom found under the toe bodies."""
    return tuple(
        geom
        for foot in self._foot_bodies
        for geom in foot.find_all('geom'))

  @property
  def prev_action(self):
    """The most recently applied action (zeros at episode start)."""
    return self._prev_action
class MinecraftClimb(embodied.Wrapper):
  """Minecraft task rewarding height gained plus the health reward."""

  def __init__(self, *args, **kwargs):
    """Build the base env with the basic actions and a time limit.

    The optional `length` keyword (default 36000) is removed from `kwargs`
    and used as the time limit; everything else goes to `MinecraftBase`.
    """
    length = kwargs.pop('length', 36000)
    base_env = minecraft_base.MinecraftBase(BASIC_ACTIONS, *args, **kwargs)
    super().__init__(embodied.wrappers.TimeLimit(base_env, length))
    self._previous = None
    self._health_reward = HealthReward()

  def step(self, action):
    """Step the env and overwrite the reward with the climb reward."""
    obs = self.env.step(action)
    # Only the second component of the player position is used.
    _, y, _ = obs['log_player_pos']
    height = np.float32(y)
    if obs['is_first']:
      # No height difference is rewarded on the first step of an episode.
      self._previous = height
    climbed = height - self._previous
    obs['reward'] = climbed
    obs['reward'] += self._health_reward(obs)
    self._previous = height
    return obs
class CollectReward:
  """Reward for collecting an inventory item.

  `repeated` is paid per unit gained since the previous step; `once` is paid
  the first time the count becomes positive after having been zero so far.
  """

  def __init__(self, item, once=0, repeated=0):
    """Store the item name and the two reward scales."""
    self.item = item
    self.once = once
    self.repeated = repeated
    self.previous = 0
    self.maximum = 0

  def __call__(self, obs, inventory):
    """Return the reward for this step and update the running counts."""
    count = inventory[self.item]
    if obs['is_first']:
      # The first step of an episode only initializes the bookkeeping.
      self.previous = self.maximum = count
      return 0
    gained = max(0, count - self.previous)
    first_time = self.maximum == 0 and count > 0
    reward = self.repeated * gained
    if first_time:
      reward += self.once
    self.previous = count
    self.maximum = max(self.maximum, count)
    return reward
# Named basic actions. Each value lists the action fields that differ from
# the no-op action. The insertion order is significant because the tasks
# turn this mapping into an indexed discrete action set.
BASIC_ACTIONS = {
    'noop': {},
    'attack': {'attack': 1},
    'turn_up': {'camera': (-15, 0)},
    'turn_down': {'camera': (15, 0)},
    'turn_left': {'camera': (0, -15)},
    'turn_right': {'camera': (0, 15)},
    'forward': {'forward': 1},
    'back': {'back': 1},
    'left': {'left': 1},
    'right': {'right': 1},
    'jump': {'jump': 1, 'forward': 1},
    'place_dirt': {'place': 'dirt'},
}
import from_gym 35 | self._env = from_gym.FromGym(self._gymenv) 36 | self._inventory = {} 37 | 38 | # Observations 39 | self._inv_keys = [ 40 | k for k in self._env.obs_space if k.startswith('inventory/') 41 | if k != 'inventory/log2'] 42 | self._step = 0 43 | self._max_inventory = None 44 | self._equip_enum = self._gymenv.observation_space[ 45 | 'equipped_items']['mainhand']['type'].values.tolist() 46 | self._obs_space = self.obs_space 47 | 48 | # Actions 49 | self._noop_action = minecraft_minerl.NOOP_ACTION 50 | actions = self._insert_defaults(actions) 51 | self._action_names = tuple(actions.keys()) 52 | self._action_values = tuple(actions.values()) 53 | message = f'Minecraft action space ({len(self._action_values)}):' 54 | print(message, ', '.join(self._action_names)) 55 | self._sticky_attack_length = sticky_attack 56 | self._sticky_attack_counter = 0 57 | self._sticky_jump_length = sticky_jump 58 | self._sticky_jump_counter = 0 59 | self._pitch_limit = pitch_limit 60 | self._pitch = 0 61 | 62 | @property 63 | def obs_space(self): 64 | return { 65 | 'image': embodied.Space(np.uint8, self._size + (3,)), 66 | 'inventory': embodied.Space(np.float32, len(self._inv_keys), 0), 67 | 'inventory_max': embodied.Space(np.float32, len(self._inv_keys), 0), 68 | 'equipped': embodied.Space(np.float32, len(self._equip_enum), 0, 1), 69 | 'reward': embodied.Space(np.float32), 70 | 'health': embodied.Space(np.float32), 71 | 'hunger': embodied.Space(np.float32), 72 | 'breath': embodied.Space(np.float32), 73 | 'is_first': embodied.Space(bool), 74 | 'is_last': embodied.Space(bool), 75 | 'is_terminal': embodied.Space(bool), 76 | **{f'log_{k}': embodied.Space(np.int64) for k in self._inv_keys}, 77 | 'log_player_pos': embodied.Space(np.float32, 3), 78 | } 79 | 80 | @property 81 | def act_space(self): 82 | return { 83 | 'action': embodied.Space(np.int64, (), 0, len(self._action_values)), 84 | 'reset': embodied.Space(bool), 85 | } 86 | 87 | def step(self, action): 88 | action = 
action.copy() 89 | index = action.pop('action') 90 | action.update(self._action_values[index]) 91 | action = self._action(action) 92 | if action['reset']: 93 | obs = self._reset() 94 | else: 95 | following = self._noop_action.copy() 96 | for key in ('attack', 'forward', 'back', 'left', 'right'): 97 | following[key] = action[key] 98 | for act in [action] + ([following] * (self._repeat - 1)): 99 | obs = self._env.step(act) 100 | if 'error' in self._env.info: 101 | obs = self._reset() 102 | break 103 | obs = self._obs(obs) 104 | self._step += 1 105 | assert 'pov' not in obs, list(obs.keys()) 106 | return obs 107 | 108 | @property 109 | def inventory(self): 110 | return self._inventory 111 | 112 | def _reset(self): 113 | with self._LOCK: 114 | obs = self._env.step({'reset': True}) 115 | self._step = 0 116 | self._max_inventory = None 117 | self._sticky_attack_counter = 0 118 | self._sticky_jump_counter = 0 119 | self._pitch = 0 120 | self._inventory = {} 121 | return obs 122 | 123 | def _obs(self, obs): 124 | obs['inventory/log'] += obs.pop('inventory/log2') 125 | self._inventory = { 126 | k.split('/', 1)[1]: obs[k] for k in self._inv_keys 127 | if k != 'inventory/air'} 128 | inventory = np.array([obs[k] for k in self._inv_keys], np.float32) 129 | if self._max_inventory is None: 130 | self._max_inventory = inventory 131 | else: 132 | self._max_inventory = np.maximum(self._max_inventory, inventory) 133 | index = self._equip_enum.index(obs['equipped_items/mainhand/type']) 134 | equipped = np.zeros(len(self._equip_enum), np.float32) 135 | equipped[index] = 1.0 136 | player_x = obs['location_stats/xpos'] 137 | player_y = obs['location_stats/ypos'] 138 | player_z = obs['location_stats/zpos'] 139 | obs = { 140 | 'image': obs['pov'], 141 | 'inventory': inventory, 142 | 'inventory_max': self._max_inventory.copy(), 143 | 'equipped': equipped, 144 | 'health': np.float32(obs['life_stats/life'] / 20), 145 | 'hunger': np.float32(obs['life_stats/food'] / 20), 146 | 'breath': 
np.float32(obs['life_stats/air'] / 300), 147 | 'reward': 0.0, 148 | 'is_first': obs['is_first'], 149 | 'is_last': obs['is_last'], 150 | 'is_terminal': obs['is_terminal'], 151 | **{f'log_{k}': np.int64(obs[k]) for k in self._inv_keys}, 152 | 'log_player_pos': np.array([player_x, player_y, player_z], np.float32), 153 | } 154 | for key, value in obs.items(): 155 | space = self._obs_space[key] 156 | if not isinstance(value, np.ndarray): 157 | value = np.array(value) 158 | assert value in space, (key, value, value.dtype, value.shape, space) 159 | return obs 160 | 161 | def _action(self, action): 162 | if self._sticky_attack_length: 163 | if action['attack']: 164 | self._sticky_attack_counter = self._sticky_attack_length 165 | if self._sticky_attack_counter > 0: 166 | action['attack'] = 1 167 | action['jump'] = 0 168 | self._sticky_attack_counter -= 1 169 | if self._sticky_jump_length: 170 | if action['jump']: 171 | self._sticky_jump_counter = self._sticky_jump_length 172 | if self._sticky_jump_counter > 0: 173 | action['jump'] = 1 174 | action['forward'] = 1 175 | self._sticky_jump_counter -= 1 176 | if self._pitch_limit and action['camera'][0]: 177 | lo, hi = self._pitch_limit 178 | if not (lo <= self._pitch + action['camera'][0] <= hi): 179 | action['camera'] = (0, action['camera'][1]) 180 | self._pitch += action['camera'][0] 181 | return action 182 | 183 | def _insert_defaults(self, actions): 184 | actions = {name: action.copy() for name, action in actions.items()} 185 | for key, default in self._noop_action.items(): 186 | for action in actions.values(): 187 | if key not in action: 188 | action[key] = default 189 | return actions 190 | -------------------------------------------------------------------------------- /recall2imagine/embodied/envs/minecraft_minerl.py: -------------------------------------------------------------------------------- 1 | from minerl.herobraine.env_spec import EnvSpec 2 | from minerl.herobraine.hero import handler 3 | from 
def edit_options(**kwargs):
  """Rewrite entries of the Minecraft `options.txt` bundled with MineRL.

  The file is looked up at `Malmo/Minecraft/run/options.txt` inside the
  installed `minerl` package. `fovEffectScale` and `simulationDistance` are
  appended with default values when absent, then every keyword argument
  replaces the value of the option with the same name.

  Args:
    **kwargs: Option name to new value. Values must be strings and every
      option must already exist in the file (after the defaults were added).

  Raises:
    RuntimeError: If the `minerl` package directory cannot be located.
    AssertionError: If an option is missing or a value is not a string.
  """
  import importlib.util
  import pathlib
  import re
  # Ask the running interpreter where minerl lives instead of parsing
  # `pip3 --version`, which may belong to a different Python installation.
  spec = importlib.util.find_spec('minerl')
  locations = list(getattr(spec, 'submodule_search_locations', None) or [])
  if not locations:
    raise RuntimeError('Could not found python package directory.')
  filename = pathlib.Path(locations[0]) / 'Malmo/Minecraft/run/options.txt'
  options = filename.read_text()
  # Without a trailing newline the appended defaults would be glued onto the
  # last existing option line.
  if options and not options.endswith('\n'):
    options += '\n'
  if 'fovEffectScale:' not in options:
    options += 'fovEffectScale:1.0\n'
  if 'simulationDistance:' not in options:
    options += 'simulationDistance:12\n'
  for key, value in kwargs.items():
    # Anchor at the start of a line and escape the key so that a short key
    # such as `ao` cannot match the tail of a longer option name.
    pattern = re.compile(rf'^{re.escape(key)}:.*$', re.MULTILINE)
    assert pattern.search(options), key
    assert isinstance(value, str), (value, type(value))
    # A callable replacement keeps backslashes in `value` literal.
    options = pattern.sub(lambda _match: f'{key}:{value}', options)
  filename.write_text(options)
| 72 | def create_observables(self): 73 | return [ 74 | handlers.POVObservation(self.resolution), 75 | handlers.FlatInventoryObservation(mc.ALL_ITEMS), 76 | handlers.EquippedItemObservation( 77 | mc.ALL_ITEMS, _default='air', _other='other'), 78 | handlers.ObservationFromCurrentLocation(), 79 | handlers.ObservationFromLifeStats(), 80 | ] 81 | 82 | def create_actionables(self): 83 | kw = dict(_other='none', _default='none') 84 | return [ 85 | handlers.KeybasedCommandAction('forward', INVERSE_KEYMAP['forward']), 86 | handlers.KeybasedCommandAction('back', INVERSE_KEYMAP['back']), 87 | handlers.KeybasedCommandAction('left', INVERSE_KEYMAP['left']), 88 | handlers.KeybasedCommandAction('right', INVERSE_KEYMAP['right']), 89 | handlers.KeybasedCommandAction('jump', INVERSE_KEYMAP['jump']), 90 | handlers.KeybasedCommandAction('sneak', INVERSE_KEYMAP['sneak']), 91 | handlers.KeybasedCommandAction('attack', INVERSE_KEYMAP['attack']), 92 | handlers.CameraAction(), 93 | handlers.PlaceBlock(['none'] + mc.ALL_ITEMS, **kw), 94 | handlers.EquipAction(['none'] + mc.ALL_ITEMS, **kw), 95 | handlers.CraftAction(['none'] + mc.ALL_ITEMS, **kw), 96 | handlers.CraftNearbyAction(['none'] + mc.ALL_ITEMS, **kw), 97 | handlers.SmeltItemNearby(['none'] + mc.ALL_ITEMS, **kw), 98 | ] 99 | 100 | def is_from_folder(self, folder): 101 | return folder == 'none' 102 | 103 | def get_docstring(self): 104 | return '' 105 | 106 | def determine_success_from_rewards(self, rewards): 107 | return True 108 | 109 | def create_rewardables(self): 110 | return [] 111 | 112 | def create_server_decorators(self): 113 | return [] 114 | 115 | def create_mission_handlers(self): 116 | return [] 117 | 118 | def create_monitors(self): 119 | return [] 120 | 121 | 122 | class BreakSpeedMultiplier(handler.Handler): 123 | 124 | def __init__(self, multiplier=1.0): 125 | self.multiplier = multiplier 126 | 127 | def to_string(self): 128 | return f'break_speed({self.multiplier})' 129 | 130 | def xml_template(self): 131 | return 
class PinPad(embodied.Env):
  """Grid world in which the player must step on colored pads in order.

  The layout is a 16x14 character grid: '#' is a wall, digit characters are
  pads and everything else is free floor. Visiting the pads in sorted order
  yields a reward of 10 and starts a 10 step countdown, after which the player
  is moved to a random spawn cell and the visited sequence is cleared.
  """

  COLORS = {
      '1': (255, 0, 0),
      '2': (0, 255, 0),
      '3': (0, 0, 255),
      '4': (255, 255, 0),
      '5': (255, 0, 255),
      '6': (0, 255, 255),
      '7': (128, 0, 128),
      '8': (0, 128, 128),
  }

  # Displacement (dx, dy) selected by each discrete action index.
  _MOVES = ((0, 0), (0, 1), (0, -1), (1, 0), (-1, 0))

  def __init__(self, task, length=10000):
    assert length > 0
    layouts = {
        'three': LAYOUT_THREE,
        'four': LAYOUT_FOUR,
        'five': LAYOUT_FIVE,
        'six': LAYOUT_SIX,
        'seven': LAYOUT_SEVEN,
        'eight': LAYOUT_EIGHT,
    }
    rows = layouts[task].split('\n')
    # Transposed so that the grid is indexed as layout[x][y].
    self.layout = np.array([list(row) for row in rows]).T
    assert self.layout.shape == (16, 14), self.layout.shape
    self.length = length
    self.random = np.random.RandomState()
    self.pads = set(self.layout.flatten().tolist()) - set('* #\n')
    self.target = tuple(sorted(self.pads))
    # Every non-wall cell is a possible spawn position.
    self.spawns = [
        (x, y) for (x, y), char in np.ndenumerate(self.layout)
        if char != '#']
    print(f'Created PinPad env with sequence: {"->".join(self.target)}')
    self.sequence = collections.deque(maxlen=len(self.target))
    self.player = None
    self.steps = None
    self.done = None
    self.countdown = None

  @property
  def act_space(self):
    return {
        'action': embodied.Space(np.int64, (), 0, 5),
        'reset': embodied.Space(bool),
    }

  @property
  def obs_space(self):
    return {
        'image': embodied.Space(np.uint8, (64, 64, 3)),
        'reward': embodied.Space(np.float32),
        'is_first': embodied.Space(bool),
        'is_last': embodied.Space(bool),
        'is_terminal': embodied.Space(bool),
    }

  def _respawn(self):
    """Move the player to a random spawn cell and forget visited pads."""
    self.player = self.spawns[self.random.randint(len(self.spawns))]
    self.sequence.clear()

  def step(self, action):
    """Advance the environment by one action and return the observation."""
    if self.done or action['reset']:
      self._respawn()
      self.steps = 0
      self.done = False
      self.countdown = 0
      return self._obs(reward=0.0, is_first=True)
    if self.countdown:
      self.countdown -= 1
      if not self.countdown:
        self._respawn()
    dx, dy = self._MOVES[action['action']]
    nx = np.clip(self.player[0] + dx, 0, 15)
    ny = np.clip(self.player[1] + dy, 0, 13)
    tile = self.layout[nx][ny]
    if tile != '#':
      self.player = (nx, ny)
    # A pad is recorded only when it differs from the last recorded pad.
    entered_pad = tile in self.pads and (
        not self.sequence or self.sequence[-1] != tile)
    if entered_pad:
      self.sequence.append(tile)
    solved = tuple(self.sequence) == self.target and not self.countdown
    reward = 10.0 if solved else 0.0
    if solved:
      self.countdown = 10
    self.steps += 1
    self.done = self.done or (self.steps >= self.length)
    return self._obs(reward=reward, is_last=self.done)

  def render(self):
    """Draw the grid, the player and the visited sequence as a 64x64 image."""
    # The background is tinted while the post-success countdown is running.
    background = (223, 255, 223) if self.countdown else (255, 255, 255)
    grid = np.zeros((16, 16, 3), np.uint8)
    grid[:] = background
    white = np.array([255, 255, 255])
    px, py = self.player
    standing_on = self.layout[px][py]
    for (x, y), char in np.ndenumerate(self.layout):
      if char == '#':
        grid[x, y] = (192, 192, 192)
        continue
      if char not in self.pads:
        continue
      color = np.array(self.COLORS[char])
      # Pads are drawn faded unless the player stands on that pad type.
      if char != standing_on:
        color = (10 * color + 90 * white) / 100
      grid[x, y] = color
    grid[px, py] = (0, 0, 0)
    # The last two columns form a status bar showing the visited sequence.
    grid[:, -2:] = (192, 192, 192)
    for slot, char in enumerate(self.sequence):
      grid[2 * slot + 1, -2] = self.COLORS[char]
    upscaled = grid.repeat(4, 0).repeat(4, 1)
    return upscaled.transpose((1, 0, 2))

  def _obs(self, reward, is_first=False, is_last=False, is_terminal=False):
    return {
        'image': self.render(),
        'reward': reward,
        'is_first': is_first,
        'is_last': is_last,
        'is_terminal': is_terminal,
    }
class RoboDesk(embodied.Env):
  """Adapter exposing a `robodesk.RoboDesk` gym env through `FromGym`."""

  def __init__(self, task, mode, repeat=1, length=500, resets=True):
    assert mode in ('train', 'eval')
    # TODO: This env variable is meant for headless GPU machines but may fail
    # on CPU-only machines.
    os.environ.setdefault('MUJOCO_GL', 'egl')
    try:
      from robodesk import robodesk
    except ImportError:
      import robodesk
    # The task string has the form '<task name>_<reward type>'.
    name, reward_kind = task.rsplit('_', 1)
    # Evaluation always uses the 'success' reward.
    reward_kind = 'success' if mode == 'eval' else reward_kind
    assert reward_kind in ('dense', 'sparse', 'success'), reward_kind
    self._gymenv = robodesk.RoboDesk(name, reward_kind, repeat, length)
    from . import from_gym
    self._env = from_gym.FromGym(self._gymenv)

  @property
  def obs_space(self):
    return self._env.obs_space

  @property
  def act_space(self):
    return self._env.act_space

  def step(self, action):
    """Step the wrapped env; the result is never marked as terminal."""
    result = self._env.step(action)
    result['is_terminal'] = False
    return result
import limiters 8 | -------------------------------------------------------------------------------- /recall2imagine/embodied/replay/chunk.py: -------------------------------------------------------------------------------- 1 | import io 2 | from datetime import datetime 3 | 4 | import embodied 5 | import numpy as np 6 | import uuid as uuidlib 7 | 8 | 9 | class Chunk: 10 | """ 11 | This class represents a contiguous chunk of RL experience 12 | and implements some useful methods to work with that - 13 | append experience step, save and load. 14 | """ 15 | def __init__(self, size, successor=None): 16 | now = datetime.now() 17 | self.time = now.strftime("%Y%m%dT%H%M%S") + f'F{now.microsecond:06d}' 18 | self.uuid_b = embodied.uuid() 19 | self.uuid = str(self.uuid_b) 20 | self.uuid_b = self.uuid_b.value 21 | self.successor = successor 22 | self.size = size 23 | self.data = None 24 | self.length = 0 25 | 26 | def __repr__(self): 27 | succ = self.successor or str(embodied.uuid(0)) 28 | succ = succ.uuid if isinstance(succ, type(self)) else succ 29 | return ( 30 | f'Chunk(uuid={self.uuid}, ' 31 | f'succ={succ}, ' 32 | f'len={self.length})') 33 | 34 | def __len__(self): 35 | return self.length 36 | 37 | def __bool__(self): 38 | return True 39 | 40 | def append(self, step): 41 | if not self.data: 42 | example = {k: embodied.convert(v) for k, v in step.items()} 43 | self.data = { 44 | k: np.empty((self.size,) + v.shape, v.dtype) 45 | for k, v in example.items()} 46 | for key, value in step.items(): 47 | self.data[key][self.length] = value 48 | self.length += 1 49 | 50 | def save(self, directory): 51 | succ = self.successor or str(embodied.uuid(0)) 52 | succ = succ.uuid if isinstance(succ, type(self)) else succ 53 | filename = f'{self.time}-{self.uuid}-{succ}-{self.length}.npz' 54 | filename = embodied.Path(directory) / filename 55 | data = {k: embodied.convert(v) for k, v in self.data.items()} 56 | with io.BytesIO() as stream: 57 | np.savez_compressed(stream, **data) 58 | 
class ChunkSerializer:
  """
  This class represents the serialization behaviour for the Chunk
  class defined above. The serialization format is
    (next chunk uuid, current chunk uuid, experience payload)
  where the latter is several numpy arrays serialized to bytes.
  The next chunk link is needed to keep the sequential order of data in
  the dataloader (see the buffer class defined in generic_lfs.py).
  """

  # Number of bytes reserved for each of the two uuid header fields.
  UUID_BYTES = 16

  def __init__(self, pattern_obj: Chunk, pattern=None):
    """Create a serializer for chunks with a fixed array layout.

    Args:
      pattern_obj: Chunk whose `data` arrays define the layout. Only used
        when `pattern` is None.
      pattern: Optional explicit list of `(key, shape, dtype)` triples.
    """
    if pattern is None:
      pattern = [(k, v.shape, v.dtype) for k, v in pattern_obj.data.items()]
    self.pattern = pattern

  @staticmethod
  def _nbytes(shape, dtype):
    """Return the byte size of an array with the given shape and dtype."""
    # An integer dtype keeps the product integral even for an empty shape,
    # where np.prod would otherwise return the float 1.0 and break slicing.
    return np.dtype(dtype).itemsize * int(np.prod(shape, dtype=np.int64))

  def _successor_bytes(self, chunk):
    """Return the raw bytes written into the successor header field."""
    succ = chunk.successor
    if not succ:
      # No successor: all-zero id, which `deserialize` reads back as id 0.
      return bytes(self.UUID_BYTES)
    if isinstance(succ, type(chunk)):
      # Use the same raw representation that is written for the chunk's own
      # id, so the link can be matched against the successor's header.
      return succ.uuid_b
    if isinstance(succ, embodied.uuid):
      return succ.value
    # Anything else must already be bytes-like (e.g. the successor field of a
    # deserialized chunk); a str raises TypeError here.
    return bytes(succ)

  def dummy_chunk(self):
    """Return uninitialized arrays matching the pattern, keyed by name."""
    return {
        k: np.empty(shape=shape, dtype=dtype)
        for k, shape, dtype in self.pattern}

  @property
  def chunk_size(self):
    """Total number of bytes of one serialized chunk, including both ids."""
    payload = sum(self._nbytes(sh, dt) for _, sh, dt in self.pattern)
    return payload + 2 * self.UUID_BYTES

  def batch_buffer(self, num_buffers, batch_size, sequence_length):
    """Return uninitialized arrays of shape (buffers, batch, time, ...).

    The trailing dimensions are taken from the pattern shapes with their
    leading dimension removed.
    """
    return {
        k: np.empty(
            (num_buffers, batch_size, sequence_length, *shape[1:]),
            dtype=dtype)
        for k, shape, dtype in self.pattern}

  def serialize(self, chunk: Chunk, buffer: np.ndarray):
    """Write `chunk` into the uint8 `buffer` in the format described above.

    Raises:
      AssertionError: If an array of the chunk does not match the pattern.
      TypeError: If the successor is neither missing, a chunk, a uuid nor
        bytes-like.
    """
    n = self.UUID_BYTES
    buffer[0:n] = np.frombuffer(self._successor_bytes(chunk), dtype=np.uint8)
    buffer[n:2 * n] = np.frombuffer(chunk.uuid_b, dtype=np.uint8)
    offset = 2 * n
    for k, shape, dtype in self.pattern:
      array = chunk.data[k]
      assert shape == array.shape
      assert dtype == array.dtype
      buffer[offset:offset + array.nbytes] = array.view(np.uint8).flat
      offset += array.nbytes

  def deserialize(self, buffer):
    """Rebuild a Chunk from a buffer previously filled by `serialize`.

    The returned chunk owns copies of the payload arrays; its `successor` is
    the raw successor id bytes.
    """
    n = self.UUID_BYTES
    succ = buffer[0:n].tobytes()
    uuid = buffer[n:2 * n].tobytes()
    offset = 2 * n
    destination = {}
    for k, shape, dtype in self.pattern:
      size = self._nbytes(shape, dtype)
      destination[k] = (
          buffer[offset:offset + size].view(dtype).reshape(shape).copy())
      offset += size
    # The chunk size is the leading dimension of the last pattern entry.
    last_key = self.pattern[-1][0]
    chunk = Chunk(len(destination[last_key]))
    chunk.data = destination
    chunk.uuid_b = uuid
    chunk.uuid = embodied.uuid(int.from_bytes(uuid, 'big'))
    chunk.successor = succ
    # NOTE(review): `chunk.length` keeps the constructor default of 0 because
    # the format does not store it -- confirm that callers do not rely on
    # len(chunk) for deserialized chunks.
    return chunk
'sample_wait_avg': ratio(m['sample_wait_dur'], m['samples']), 62 | 'sample_wait_frac': ratio(m['sample_wait_count'], m['samples']), 63 | } 64 | for key in self.metrics: 65 | self.metrics[key] = 0 66 | return stats 67 | 68 | def add(self, step, worker=0, load=False): 69 | step = {k: v for k, v in step.items() if not k.startswith('log_')} 70 | step['id'] = np.asarray(embodied.uuid(step.get('id'))) 71 | stream = self.streams[worker] 72 | stream.append(step) 73 | self.saver and self.saver.add(step, worker) 74 | self.counters[worker] += 1 75 | if self.online: 76 | self.online_counters[worker] += 1 77 | if len(stream) >= self.length and ( 78 | self.online_counters[worker] >= self.online_stride): 79 | self.online_queue.append(tuple(stream)) 80 | self.online_counters[worker] = 0 81 | if len(stream) < self.length or self.counters[worker] < self.stride: 82 | return 83 | self.counters[worker] = 0 84 | key = embodied.uuid() 85 | seq = tuple(stream) 86 | if load: 87 | assert self.limiter.want_load()[0] 88 | else: 89 | dur = wait(self.limiter.want_insert, 'Replay insert is waiting') 90 | self.metrics['inserts'] += 1 91 | self.metrics['insert_wait_dur'] += dur 92 | self.metrics['insert_wait_count'] += int(dur > 0) 93 | self.table[key] = seq 94 | self.remover[key] = seq 95 | self.sampler[key] = seq 96 | while self.capacity and len(self) > self.capacity: 97 | self._remove(self.remover()) 98 | 99 | def _sample(self): 100 | dur = wait(self.limiter.want_sample, 'Replay sample is waiting') 101 | self.metrics['samples'] += 1 102 | self.metrics['sample_wait_dur'] += dur 103 | self.metrics['sample_wait_count'] += int(dur > 0) 104 | if self.online: 105 | try: 106 | seq = self.online_queue.popleft() 107 | except IndexError: 108 | seq = self.table[self.sampler()] 109 | else: 110 | seq = self.table[self.sampler()] 111 | seq = {k: [step[k] for step in seq] for k in seq[0]} 112 | seq = {k: embodied.convert(v) for k, v in seq.items()} 113 | if 'is_first' in seq: 114 | seq['is_first'][0] = True 
def wait(predicate, message, sleep=0.001, notify=1.0):
  """Block until `predicate` allows progress and return the seconds waited.

  Args:
    predicate: Callable returning an `(allowed, detail)` pair.
    message: Text printed once, together with the latest `detail`, when the
      wait has lasted at least `notify` seconds.
    sleep: Seconds to sleep between two calls of `predicate`.
    notify: Waiting time in seconds after which `message` is printed.

  Returns:
    The elapsed time in seconds between entering the function and the
    successful predicate call.
  """
  # A monotonic clock cannot jump when the system time is adjusted, so the
  # measured duration is never negative or inflated.
  start = time.monotonic()
  notified = False
  while True:
    allowed, detail = predicate()
    duration = time.monotonic() - start
    if allowed:
      return duration
    if not notified and duration >= notify:
      print(f'{message} ({detail})')
      notified = True
    time.sleep(sleep)
class SamplesPerInsert:
  """Limiter that balances the number of samples against inserts.

  A budget `avail` grows by `samples_per_insert` on every insert and shrinks
  by one on every sample. Inserts are refused once the budget reaches
  `tolerance * samples_per_insert`, samples are refused once it drops to
  `-tolerance` or while fewer than `minimum` items are stored. With
  `unlocked=True` the budget is still tracked but never blocks.
  """

  def __init__(self, samples_per_insert, tolerance, minimum=1, unlocked=False):
    assert 1 <= minimum
    self.samples_per_insert = samples_per_insert
    self.unlocked = unlocked
    self.minimum = minimum
    self.size = 0
    self.avail = -minimum
    self.min_avail = -tolerance
    self.max_avail = tolerance * samples_per_insert
    self.lock = threading.Lock()

  def want_load(self):
    """Count a loaded item; loading never changes the sample budget."""
    with self.lock:
      self.size += 1
    return True, 'ok'

  def want_insert(self):
    """Ask to insert one item; returns `(allowed, detail)`."""
    with self.lock:
      blocked = not self.unlocked and self.avail >= self.max_avail
      if blocked:
        return False, f'rate limited: {self.avail:.3f} >= {self.max_avail:.3f}'
      self.avail += self.samples_per_insert
      self.size += 1
      return True, 'ok'

  def want_remove(self):
    """Ask to remove one item; refused when nothing is stored."""
    with self.lock:
      if self.size < 1:
        return False, 'is empty'
      self.size -= 1
    return True, 'ok'

  def want_sample(self):
    """Ask to draw one sample; returns `(allowed, detail)`."""
    with self.lock:
      if self.size < self.minimum:
        return False, f'too empty: {self.size} < {self.minimum}'
      blocked = not self.unlocked and self.avail <= self.min_avail
      if blocked:
        return False, f'rate limited: {self.avail:.3f} <= {self.min_avail:.3f}'
      self.avail -= 1
      return True, 'ok'
__init__(self, capacity): 90 | assert 1 <= capacity 91 | self.capacity = capacity 92 | self.size = 0 93 | self.lock = threading.Lock() 94 | 95 | def want_load(self): 96 | with self.lock: 97 | self.size += 1 98 | return True, 'ok' 99 | 100 | def want_insert(self): 101 | with self.lock: 102 | if self.size >= self.capacity: 103 | return False, f'is full: {self.size} >= {self.capacity}' 104 | self.size += 1 105 | return True, 'ok' 106 | 107 | def want_remove(self): 108 | with self.lock: 109 | if self.size < 1: 110 | return False, 'is empty' 111 | self.size -= 1 112 | return True, 'ok' 113 | 114 | def want_sample(self): 115 | if self.size < 1: 116 | return False, 'is empty' 117 | else: 118 | return True, 'ok' 119 | -------------------------------------------------------------------------------- /recall2imagine/embodied/replay/naive_chunks.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import threading 3 | import time 4 | import uuid 5 | from collections import deque, defaultdict 6 | from functools import partial as bind 7 | 8 | import numpy as np 9 | import embodied 10 | 11 | from . 
class NaiveChunks(embodied.Replay):
  """Replay that samples fixed-length windows out of whole chunks.

  Steps are appended to one ongoing chunk per worker. When a chunk reaches
  `chunks` steps it becomes sampleable and, if a directory was given, is
  written to disk in the background. Sampling picks a random finished chunk
  and a random window of `length` steps inside it.

  Args:
    length: Number of steps per sampled sequence; must be <= `chunks`.
    capacity: Upper bound on `len(self)`; `None` means unbounded.
    directory: Where finished chunks are stored; `None` keeps the replay
      purely in memory.
    chunks: Number of steps per chunk.
    seed: Seed for the sampling random generator.
  """

  def __init__(self, length, capacity=None, directory=None, chunks=1024, seed=0):
    assert 1 <= length <= chunks
    self.length = length
    self.capacity = capacity
    self.directory = directory and embodied.Path(directory)
    self.chunks = chunks
    self.buffers = {}
    self.rng = np.random.default_rng(seed)
    self.ongoing = defaultdict(bind(chunklib.Chunk, chunks))
    # Always defined so add() and save() work without a directory; the
    # worker pool only exists when there is somewhere to write to.
    self.workers = None
    self.promises = deque()
    if self.directory:
      self.directory.mkdirs()
      self.workers = concurrent.futures.ThreadPoolExecutor(16)

  def __len__(self):
    return len(self.buffers) * self.length

  @property
  def stats(self):
    return {'size': len(self), 'chunks': len(self.buffers)}

  def _save_chunk(self, chunk):
    """Schedule a background write of `chunk`; no-op without a directory."""
    if self.workers is not None:
      self.promises.append(self.workers.submit(chunk.save, self.directory))

  def add(self, step, worker=0):
    """Append one step for `worker`, dropping keys prefixed with `log_`."""
    step = {k: v for k, v in step.items() if not k.startswith('log_')}
    chunk = self.ongoing[worker]
    chunk.append(step)
    if len(chunk) >= self.chunks:
      self.buffers[chunk.uuid] = self.ongoing.pop(worker)
      self._save_chunk(chunk)
      # Surface exceptions from finished background writes.
      for promise in [x for x in self.promises if x.done()]:
        promise.result()
        self.promises.remove(promise)
    # capacity=None means unbounded; comparing an int against None would
    # raise a TypeError.
    if self.capacity is not None:
      while len(self) > self.capacity:
        del self.buffers[next(iter(self.buffers))]

  def _sample(self):
    """Return one window of `length` steps, blocking until a chunk exists."""
    counter = 0
    while not self.buffers:
      if counter % 100 == 0:
        print('Replay sample is waiting')
      time.sleep(0.1)
      counter += 1
    keys = tuple(self.buffers.keys())
    chunk = self.buffers[keys[self.rng.integers(0, len(keys))]]
    idx = self.rng.integers(0, len(chunk) - self.length + 1)
    seq = {k: chunk.data[k][idx: idx + self.length] for k in chunk.data.keys()}
    # Copy before marking the window start: the slice may be a view into the
    # stored chunk (assuming NumPy arrays), and writing through it would
    # modify the stored data.
    seq['is_first'] = np.array(seq['is_first'])
    seq['is_first'][0] = True
    return seq

  def dataset(self):
    """Yield sampled sequences forever."""
    while True:
      yield self._sample()

  def save(self, wait=False):
    """Write all non-empty ongoing chunks; optionally wait for completion."""
    for chunk in self.ongoing.values():
      if chunk.length:
        self._save_chunk(chunk)
    if wait:
      for promise in self.promises:
        promise.result()
      self.promises.clear()

  def load(self, data=None):
    """Replace the sampleable chunks with those found in the directory."""
    if not self.directory:
      return
    # The original passed the undefined name `capacity` here.
    filenames = chunklib.Chunk.scan(self.directory, self.capacity)
    if not filenames:
      return
    threads = min(len(filenames), 32)
    with concurrent.futures.ThreadPoolExecutor(threads) as executor:
      chunks = list(executor.map(chunklib.Chunk.load, filenames))
    self.buffers = {chunk.uuid: chunk for chunk in chunks}
self.length = length 16 | self.capacity = capacity 17 | self.directory = directory and embodied.Path(directory) 18 | self.checkpointer = None 19 | self.server = None 20 | self.client = None 21 | self.writers = None 22 | self.counters = None 23 | self.signature = None 24 | self.flush = flush 25 | if self.directory: 26 | self.directory.mkdirs() 27 | path = str(self.directory) 28 | try: 29 | self.checkpointer = reverb.checkpointers.DefaultCheckpointer(path) 30 | except AttributeError: 31 | self.checkpointer = reverb.checkpointers.RecordIOCheckpointer(path) 32 | self.sigpath = self.directory.parent / (self.directory.name + '_sig.pkl') 33 | if self.directory and self.sigpath.exists(): 34 | with self.sigpath.open('rb') as file: 35 | self.signature = pickle.load(file) 36 | self._create_server() 37 | 38 | def _create_server(self): 39 | import reverb 40 | import tensorflow as tf 41 | self.server = reverb.Server(tables=[reverb.Table( 42 | name='table', 43 | sampler=reverb.selectors.Uniform(), 44 | remover=reverb.selectors.Fifo(), 45 | max_size=int(self.capacity), 46 | rate_limiter=reverb.rate_limiters.MinSize(1), 47 | signature={ 48 | key: tf.TensorSpec(shape, dtype) 49 | for key, (shape, dtype) in self.signature.items()}, 50 | )], port=None, checkpointer=self.checkpointer) 51 | self.client = reverb.Client(f'localhost:{self.server.port}') 52 | self.writers = defaultdict(bind( 53 | self.client.trajectory_writer, self.length)) 54 | self.counters = defaultdict(int) 55 | 56 | def __len__(self): 57 | if not self.client: 58 | return 0 59 | return self.client.server_info()['table'].current_size 60 | 61 | @property 62 | def stats(self): 63 | return {'size': len(self)} 64 | 65 | def add(self, step, worker=0): 66 | step = {k: v for k, v in step.items() if not k.startswith('log_')} 67 | step = {k: embodied.convert(v) for k, v in step.items()} 68 | step['id'] = np.asarray(embodied.uuid(step.get('id'))) 69 | if not self.server: 70 | self.signature = { 71 | k: ((self.length, *v.shape), 
def save(self, wait=False):
  """Flush pending writes and, if a directory is configured, checkpoint.

  Safe to call before any step was added and for replays created without a
  directory. The original raised AttributeError in both situations, because
  `writers` is None until the server exists and `sigpath` is only defined
  when a directory was given.

  Args:
    wait: Unused; accepted for interface compatibility.
  """
  if not self.server:
    # The server, client and writers are created together; without a server
    # there is nothing to flush and no signature worth persisting.
    return
  for writer in self.writers.values():
    writer.flush()
  if not self.directory:
    return
  with self.sigpath.open('wb') as file:
    file.write(pickle.dumps(self.signature))
  self.client.checkpoint()
class Saver:
  """Writes incoming steps to disk in chunks and streams them back on load.

  Each worker has one chunk being filled. When it is full a new chunk is
  started and linked as its `successor`, and the full one is written in the
  background. `load()` regroups stored chunks into streams by following
  those successor links.
  """

  def __init__(self, directory, chunks=1024):
    self.directory = embodied.Path(directory)
    self.directory.mkdirs()
    self.chunks = chunks
    self.buffers = defaultdict(bind(chunklib.Chunk, chunks))
    self.workers = concurrent.futures.ThreadPoolExecutor(16)
    self.promises = deque()
    # While True, add() ignores steps so that steps replayed from disk by
    # load() are not written a second time.
    self.loading = False

  def add(self, step, worker):
    """Append `step` to the worker's chunk; ignored while loading."""
    if self.loading:
      return
    buffer = self.buffers[worker]
    buffer.append(step)
    if buffer.length >= self.chunks:
      self.buffers[worker] = buffer.successor = chunklib.Chunk(self.chunks)
      self.promises.append(self.workers.submit(buffer.save, self.directory))
      # Surface exceptions from finished background writes.
      for promise in [x for x in self.promises if x.done()]:
        promise.result()
        self.promises.remove(promise)

  def save(self, wait=False):
    """Write all non-empty chunks; optionally wait for completion."""
    for buffer in self.buffers.values():
      if buffer.length:
        self.promises.append(self.workers.submit(buffer.save, self.directory))
    if wait:
      for promise in self.promises:
        promise.result()
      self.promises.clear()

  def _assign_streams(self, filenames):
    """Map every stored chunk uuid to the uuid of the chunk heading its stream."""
    lazychunks = [chunklib.Chunk.load(f, load_data=False) for f in filenames]
    by_uuid = {c.uuid: c for c in lazychunks}
    visited = {k: False for k in by_uuid}
    streamids = {}
    streams = set()
    # One pass in time order visits every chunk: each one is either already
    # claimed by an earlier stream or becomes the head of a new one.
    for chunk in sorted(lazychunks, key=lambda x: x.time):
      if visited[chunk.uuid]:
        continue
      stream_id = chunk.uuid
      streamids[chunk.uuid] = stream_id
      streams.add(stream_id)
      visited[stream_id] = True
      nxt = by_uuid.get(chunk.successor)
      steps = 0  # Counts successor chunks followed, despite the message.
      while nxt is not None:
        streamids[nxt.uuid] = stream_id
        visited[nxt.uuid] = True
        nxt = by_uuid.get(nxt.successor)
        steps += 1
      print(f'Stream: {stream_id} ({len(streams)}th) - {steps} steps')
    return streamids

  def load(self, capacity, length):
    """Yield `(step, stream_id)` for every step stored in the directory.

    This is a generator, so nothing happens until it is iterated. `loading`
    is reset in a `finally` clause, so an abandoned or failing iteration
    cannot leave the saver dropping all later `add()` calls.

    Args:
      capacity: Forwarded to `Chunk.scan` to limit what is loaded.
      length: Unused; accepted for interface compatibility.
    """
    filenames = chunklib.Chunk.scan(self.directory, capacity, 1)
    if not filenames:
      return
    streamids = self._assign_streams(filenames)
    threads = min(len(filenames), 32)
    self.loading = True
    try:
      with ThreadPool(threads) as executor:
        loaded = executor.imap_unordered(chunklib.Chunk.load, filenames)
        for chunk in tqdm(loaded, total=len(filenames)):
          stream = streamids[chunk.uuid]
          for index in range(chunk.length):
            step = {k: v[index] for k, v in chunk.data.items()}
            yield step, stream
          del chunk  # Release the chunk before the next one arrives.
    finally:
      self.loading = False
class Uniform:
  """Selector that returns a uniformly random key among those registered.

  Keys are kept in a dense list with a key-to-position map, so insertion,
  deletion and selection do not scan the list.
  """

  def __init__(self, seed=0):
    self.rng = np.random.default_rng(seed)
    self.keys = []
    self.indices = {}

  def __call__(self):
    """Return one registered key chosen uniformly at random."""
    position = self.rng.integers(0, len(self.keys)).item()
    return self.keys[position]

  def __setitem__(self, key, steps):
    """Register `key`; `steps` is accepted but not used."""
    position = len(self.keys)
    self.keys.append(key)
    self.indices[key] = position

  def __delitem__(self, key):
    """Unregister `key` by moving the last key into its slot."""
    slot = self.indices.pop(key)
    tail = self.keys.pop()
    if slot == len(self.keys):
      # The removed key was the last one; nothing to move.
      return
    self.keys[slot] = tail
    self.indices[tail] = slot
def parallel(agent, replay, logger, make_env, num_envs, args):
  """Run one actor, one learner and `num_envs` environment workers.

  The actor and learner always run as threads. A single environment also
  runs as a thread and shares the timer; several environments each run in
  their own process without one.
  """
  step = logger.step
  timer = embodied.Timer()
  wrapped = (
      ('agent', agent, ['policy', 'train', 'report', 'save']),
      ('replay', replay, ['add', 'save']),
      ('logger', logger, ['write']),
  )
  for name, target, methods in wrapped:
    timer.wrap(name, target, methods)
  workers = [
      embodied.distr.Thread(
          actor, step, agent, replay, logger, args.actor_addr, args),
      embodied.distr.Thread(
          learner, step, agent, replay, logger, timer, args),
  ]
  if num_envs != 1:
    for index in range(num_envs):
      workers.append(embodied.distr.Process(
          env, make_env, args.actor_addr, index, args))
  else:
    workers.append(embodied.distr.Thread(
        env, make_env, args.actor_addr, 0, args, timer))
  embodied.distr.run(workers)
logger.add({ 72 | 'length': len(ep['reward']) - 1, 73 | 'score': sum(ep['reward']), 74 | }, prefix='episode') 75 | stats = {} 76 | for key in args.log_keys_video: 77 | if key != 'none': 78 | stats[f'policy_{key}'] = ep[key] 79 | metrics.add(stats, prefix='stats') 80 | 81 | if should_log(): 82 | logger.add(metrics.result()) 83 | 84 | return act 85 | 86 | print('[actor] Start server') 87 | embodied.BatchServer(actor_addr, args.actor_batch, callback).run() 88 | 89 | 90 | def learner(step, agent, replay, logger, timer, args): 91 | logdir = embodied.Path(args.logdir) 92 | ckpt_dir = embodied.Path(args.checkpoint_dir) 93 | metrics = embodied.Metrics() 94 | should_log = embodied.when.Clock(args.log_every) 95 | should_save = embodied.when.Clock(args.save_every) 96 | should_sync = embodied.when.Every(args.sync_every) 97 | updates = embodied.Counter() 98 | 99 | checkpoint = embodied.Checkpoint(ckpt_dir / 'checkpoint.ckpt') 100 | checkpoint.step = step 101 | checkpoint.agent = agent 102 | checkpoint.replay = replay 103 | is_set = False 104 | while not is_set: 105 | with replay.manager.event_lock: 106 | is_set = bool(replay.manager.init_event.is_set()) 107 | if not is_set: 108 | print('Learner is waiting for the actor to initialize the buffer..') 109 | time.sleep(30) 110 | 111 | if args.from_checkpoint: 112 | checkpoint.load(args.from_checkpoint) 113 | checkpoint.load_or_save(filename=replay.manager.tmp_file_path) 114 | # a workaround for the logger step issue when restoring 115 | logger.step.value = embodied.Counter(initial=agent.step.value) 116 | 117 | dataset = agent.dataset(replay) 118 | state = None 119 | stats = dict(last_time=time.time(), last_step=int(step), batch_entries=0) 120 | while True: 121 | batch = next(dataset) 122 | outs, state, mets = agent.train(batch, state) 123 | metrics.add(mets) 124 | updates.increment() 125 | stats['batch_entries'] += batch['is_first'].size 126 | 127 | if should_sync(updates): 128 | agent.sync() 129 | 130 | if should_log(): 131 | train 
def dummy_data(spaces, batch_dims):
  """Build an all-zero batch matching `spaces`.

  Args:
    spaces: Mapping from key to an object exposing `shape` and `dtype`.
    batch_dims: Sequence of leading batch dimensions to prepend.

  Returns:
    Dict with one zero array per key, of shape `(*batch_dims, *shape)`.
  """
  # TODO: Get rid of this function by adding initial_policy_state() and
  # initial_train_state() to the agent API.
  batch = {}
  for name, space in list(spaces.items()):
    batch[name] = np.zeros(space.shape, space.dtype)
  # Prepend the innermost batch dimension first so the final order matches.
  for size in reversed(batch_dims):
    tiled = {}
    for name, array in batch.items():
      tiled[name] = np.repeat(array[None], size, axis=0)
    batch = tiled
  return batch
and return {score:.1f}.') 44 | stats = {} 45 | for key in args.log_keys_video: 46 | if key in ep: 47 | stats[f'policy_{key}'] = ep[key] 48 | for key, value in ep.items(): 49 | if not args.log_zeros and key not in nonzeros and (value == 0).all(): 50 | continue 51 | nonzeros.add(key) 52 | if re.match(args.log_keys_sum, key): 53 | stats[f'sum_{key}'] = ep[key].sum() 54 | if re.match(args.log_keys_mean, key): 55 | stats[f'mean_{key}'] = ep[key].mean() 56 | if re.match(args.log_keys_max, key): 57 | stats[f'max_{key}'] = ep[key].max(0).mean() 58 | metrics.add(stats, prefix='stats') 59 | 60 | driver = embodied.Driver(env) 61 | driver.on_episode(lambda ep, worker: per_episode(ep)) 62 | driver.on_step(lambda tran, _: step.increment()) 63 | driver.on_step(replay.add) 64 | 65 | replay.maybe_restore() 66 | 67 | print('Prefill train dataset.') 68 | random_agent = embodied.RandomAgent(env.act_space) 69 | while len(replay) < max(args.batch_steps * config.envs.amount, args.train_fill): 70 | driver(random_agent.policy, steps=100) 71 | logger.add(metrics.result()) 72 | logger.write() 73 | 74 | if config.replay == 'lfs': 75 | dataset = agent.dataset(replay, shared_memory=True) 76 | elif config.replay == 'uniform': 77 | dataset = agent.dataset(replay.dataset, shared_memory=False) 78 | state = [None] # To be writable from train step function below. 
79 | batch = [None] 80 | def train_step(tran, worker): 81 | for _ in range(should_train(step)): 82 | with timer.scope('dataset'): 83 | batch[0] = next(dataset) 84 | if should_profile: 85 | jax.profiler.start_trace(f"{args.profile_path}/{step.value}") 86 | print(f'profiling step {step}') 87 | outs, state[0], mets = agent.train(batch[0], state[0]) 88 | if should_profile: 89 | jax.profiler.stop_trace() 90 | metrics.add(mets, prefix='train') 91 | if 'priority' in outs: 92 | replay.prioritize(outs['key'], outs['priority']) 93 | updates.increment() 94 | if should_sync(updates): 95 | agent.sync() 96 | if should_log(step): 97 | agg = metrics.result() 98 | report = agent.report(batch[0]) 99 | report = {k: v for k, v in report.items() if 'train/' + k not in agg} 100 | logger.add(agg) 101 | # TODO: do this rarely 102 | if should_report(step): 103 | logger.add(report, prefix='report') 104 | logger.add(replay.stats, prefix='replay') 105 | logger.add(timer.stats(), prefix='timer') 106 | logger.write(fps=True) 107 | driver.on_step(train_step) 108 | 109 | checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt', parallel=False) 110 | timer.wrap('checkpoint', checkpoint, ['save', 'load']) 111 | checkpoint.step = step 112 | checkpoint.agent = agent 113 | checkpoint.replay = replay 114 | if args.from_checkpoint: 115 | checkpoint.load(args.from_checkpoint) 116 | checkpoint.load_or_save() 117 | should_save(step) # Register that we jused saved. 
118 | 119 | print('Start training loop.') 120 | policy = lambda *args: agent.policy( 121 | *args, mode='explore' if should_expl(step) else 'train') 122 | while step < args.steps: 123 | driver(policy, steps=100) 124 | if should_save(step): 125 | try: 126 | checkpoint.save() 127 | except: 128 | print('saving failed') 129 | pass 130 | logger.write() 131 | -------------------------------------------------------------------------------- /recall2imagine/embodied/run/train_eval.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import embodied 4 | import numpy as np 5 | 6 | 7 | def train_eval( 8 | agent, train_env, eval_env, train_replay, eval_replay, logger, args): 9 | 10 | logdir = embodied.Path(args.logdir) 11 | logdir.mkdirs() 12 | print('Logdir', logdir) 13 | should_expl = embodied.when.Until(args.expl_until) 14 | should_train = embodied.when.Ratio(args.train_ratio / args.batch_steps) 15 | should_log = embodied.when.Clock(args.log_every) 16 | should_save = embodied.when.Clock(args.save_every) 17 | should_eval = embodied.when.Every(args.eval_every, args.eval_initial) 18 | should_sync = embodied.when.Every(args.sync_every) 19 | step = logger.step 20 | updates = embodied.Counter() 21 | metrics = embodied.Metrics() 22 | print('Observation space:', embodied.format(train_env.obs_space), sep='\n') 23 | print('Action space:', embodied.format(train_env.act_space), sep='\n') 24 | 25 | timer = embodied.Timer() 26 | timer.wrap('agent', agent, ['policy', 'train', 'report', 'save']) 27 | timer.wrap('env', train_env, ['step']) 28 | if hasattr(train_replay, '_sample'): 29 | timer.wrap('replay', train_replay, ['_sample']) 30 | 31 | nonzeros = set() 32 | def per_episode(ep, mode): 33 | length = len(ep['reward']) - 1 34 | score = float(ep['reward'].astype(np.float64).sum()) 35 | logger.add({ 36 | 'length': length, 'score': score, 37 | 'reward_rate': (ep['reward'] - ep['reward'].min() >= 0.1).mean(), 38 | }, prefix=('episode' if 
mode == 'train' else f'{mode}_episode')) 39 | print(f'Episode has {length} steps and return {score:.1f}.') 40 | stats = {} 41 | for key in args.log_keys_video: 42 | if key in ep: 43 | stats[f'policy_{key}'] = ep[key] 44 | for key, value in ep.items(): 45 | if not args.log_zeros and key not in nonzeros and (value == 0).all(): 46 | continue 47 | nonzeros.add(key) 48 | if re.match(args.log_keys_sum, key): 49 | stats[f'sum_{key}'] = ep[key].sum() 50 | if re.match(args.log_keys_mean, key): 51 | stats[f'mean_{key}'] = ep[key].mean() 52 | if re.match(args.log_keys_max, key): 53 | stats[f'max_{key}'] = ep[key].max(0).mean() 54 | metrics.add(stats, prefix=f'{mode}_stats') 55 | 56 | driver_train = embodied.Driver(train_env) 57 | driver_train.on_episode(lambda ep, worker: per_episode(ep, mode='train')) 58 | driver_train.on_step(lambda tran, _: step.increment()) 59 | driver_train.on_step(train_replay.add) 60 | driver_eval = embodied.Driver(eval_env) 61 | driver_eval.on_step(eval_replay.add) 62 | driver_eval.on_episode(lambda ep, worker: per_episode(ep, mode='eval')) 63 | 64 | random_agent = embodied.RandomAgent(train_env.act_space) 65 | print('Prefill train dataset.') 66 | while len(train_replay) < max(args.batch_steps, args.train_fill): 67 | driver_train(random_agent.policy, steps=100) 68 | print('Prefill eval dataset.') 69 | while len(eval_replay) < max(args.batch_steps, args.eval_fill): 70 | driver_eval(random_agent.policy, steps=100) 71 | logger.add(metrics.result()) 72 | logger.write() 73 | 74 | dataset_train = agent.dataset(train_replay.dataset) 75 | dataset_eval = agent.dataset(eval_replay.dataset) 76 | state = [None] # To be writable from train step function below. 
77 | batch = [None] 78 | def train_step(tran, worker): 79 | for _ in range(should_train(step)): 80 | with timer.scope('dataset_train'): 81 | batch[0] = next(dataset_train) 82 | outs, state[0], mets = agent.train(batch[0], state[0]) 83 | metrics.add(mets, prefix='train') 84 | if 'priority' in outs: 85 | train_replay.prioritize(outs['key'], outs['priority']) 86 | updates.increment() 87 | if should_sync(updates): 88 | agent.sync() 89 | if should_log(step): 90 | logger.add(metrics.result()) 91 | logger.add(agent.report(batch[0]), prefix='report') 92 | with timer.scope('dataset_eval'): 93 | eval_batch = next(dataset_eval) 94 | logger.add(agent.report(eval_batch), prefix='eval') 95 | logger.add(train_replay.stats, prefix='replay') 96 | logger.add(eval_replay.stats, prefix='eval_replay') 97 | logger.add(timer.stats(), prefix='timer') 98 | logger.write(fps=True) 99 | driver_train.on_step(train_step) 100 | 101 | checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt') 102 | checkpoint.step = step 103 | checkpoint.agent = agent 104 | checkpoint.train_replay = train_replay 105 | checkpoint.eval_replay = eval_replay 106 | if args.from_checkpoint: 107 | checkpoint.load(args.from_checkpoint) 108 | checkpoint.load_or_save() 109 | should_save(step) # Register that we jused saved. 
def train_holdout(agent, env, train_replay, eval_replay, logger, args):
  """Train an agent while keeping a separate held-out replay for eval reports.

  The function first fills `eval_replay` and then `train_replay` with
  transitions collected by a random agent, then runs the main loop in which
  every environment step may trigger gradient updates and periodic logging.

  Args:
    agent: Object exposing `policy`, `train`, `report`, `dataset` and `sync`;
      it is also registered on the checkpoint.
    env: Environment handed to `embodied.Driver`; its `obs_space` and
      `act_space` are printed and `act_space` seeds the random agent.
    train_replay: Replay receiving every training transition. Must support
      `add`, `dataset`, `stats`, `len()` and, when the train outputs contain
      a 'priority' entry, `prioritize`.
    eval_replay: Replay filled once by the random agent and afterwards only
      sampled for the 'eval' report.
    logger: Logger providing `step`, `add` and `write`.
    args: Config object. Fields read here: `logdir`, `expl_until`,
      `train_ratio`, `batch_steps`, `log_every`, `save_every`, `sync_every`,
      `log_keys_video`, `log_zeros`, `log_keys_sum`, `log_keys_mean`,
      `log_keys_max`, `eval_fill`, `train_fill`, `from_checkpoint`, `steps`.
  """

  logdir = embodied.Path(args.logdir)
  logdir.mkdirs()
  print('Logdir', logdir)
  # Schedules: each `should_*` object is called later with a counter and
  # decides whether (or how many times) the associated action should run.
  should_expl = embodied.when.Until(args.expl_until)
  should_train = embodied.when.Ratio(args.train_ratio / args.batch_steps)
  should_log = embodied.when.Clock(args.log_every)
  should_save = embodied.when.Clock(args.save_every)
  should_sync = embodied.when.Every(args.sync_every)
  step = logger.step
  updates = embodied.Counter()
  metrics = embodied.Metrics()
  print('Observation space:', embodied.format(env.obs_space), sep='\n')
  print('Action space:', embodied.format(env.act_space), sep='\n')

  # Instrument the listed methods so their timings show up under 'timer'.
  timer = embodied.Timer()
  timer.wrap('agent', agent, ['policy', 'train', 'report', 'save'])
  timer.wrap('env', env, ['step'])
  if hasattr(train_replay, '_sample'):
    timer.wrap('replay', train_replay, ['_sample'])

  # Keys that have been seen with a non-zero value at least once; used to
  # suppress statistics for keys that are (so far) always zero.
  nonzeros = set()
  def per_episode(ep):
    # Log length/score of a finished episode and accumulate per-key stats.
    length = len(ep['reward']) - 1
    score = float(ep['reward'].astype(np.float64).sum())
    logger.add({
        'length': length, 'score': score,
        # Fraction of steps whose reward is at least 0.1 above the episode
        # minimum.
        'reward_rate': (ep['reward'] - ep['reward'].min() >= 0.1).mean(),
    }, prefix='episode')
    print(f'Episode has {length} steps and return {score:.1f}.')
    stats = {}
    for key in args.log_keys_video:
      if key in ep:
        stats[f'policy_{key}'] = ep[key]
    for key, value in ep.items():
      if not args.log_zeros and key not in nonzeros and (value == 0).all():
        continue
      nonzeros.add(key)
      # The log_keys_* settings are regex patterns matched against the key.
      if re.match(args.log_keys_sum, key):
        stats[f'sum_{key}'] = ep[key].sum()
      if re.match(args.log_keys_mean, key):
        stats[f'mean_{key}'] = ep[key].mean()
      if re.match(args.log_keys_max, key):
        stats[f'max_{key}'] = ep[key].max(0).mean()
    metrics.add(stats, prefix='stats')

  # Callbacks run in registration order: count the step, then store the
  # transition; `train_step` is appended further below.
  driver = embodied.Driver(env)
  driver.on_episode(lambda ep, worker: per_episode(ep))
  driver.on_step(lambda tran, _: step.increment())
  driver.on_step(train_replay.add)

  print('Fill eval dataset.')
  # A temporary driver writes only to the held-out replay; it does not
  # increment `step` and is deleted once the buffer is large enough.
  driver_eval = embodied.Driver(env)
  driver_eval.on_step(eval_replay.add)
  random_agent = embodied.RandomAgent(env.act_space)
  while len(eval_replay) < max(args.batch_steps, args.eval_fill):
    print(len(eval_replay), max(args.batch_steps, args.eval_fill))
    driver_eval(random_agent.policy, steps=100)
  del driver_eval
  print('Prefill train dataset.')
  while len(train_replay) < max(args.batch_steps, args.train_fill):
    print(len(train_replay), max(args.batch_steps, args.train_fill))
    driver(random_agent.policy, steps=100)
  logger.add(metrics.result())
  logger.write()

  dataset_train = agent.dataset(train_replay.dataset)
  dataset_eval = agent.dataset(eval_replay.dataset)
  state = [None]  # To be writable from train step function below.
  batch = [None]
  def train_step(tran, worker):
    # Per-environment-step callback: run as many updates as the ratio
    # schedule allows, then sync and log if due.
    for _ in range(should_train(step)):
      with timer.scope('dataset_train'):
        batch[0] = next(dataset_train)
      outs, state[0], mets = agent.train(batch[0], state[0])
      metrics.add(mets, prefix='train')
      if 'priority' in outs:
        train_replay.prioritize(outs['key'], outs['priority'])
      updates.increment()
    if should_sync(updates):
      agent.sync()
    if should_log(step):
      logger.add(metrics.result())
      # NOTE(review): batch[0] is still None if logging triggers before the
      # first update has run - confirm agent.report tolerates that.
      logger.add(agent.report(batch[0]), prefix='report')
      with timer.scope('dataset_eval'):
        eval_batch = next(dataset_eval)
      logger.add(agent.report(eval_batch), prefix='eval')
      logger.add(train_replay.stats, prefix='replay')
      logger.add(eval_replay.stats, prefix='eval_replay')
      logger.add(timer.stats(), prefix='timer')
      logger.write(fps=True)
  driver.on_step(train_step)

  checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
  checkpoint.step = step
  checkpoint.agent = agent
  checkpoint.train_replay = train_replay
  checkpoint.eval_replay = eval_replay
  if args.from_checkpoint:
    checkpoint.load(args.from_checkpoint)
  checkpoint.load_or_save()
  should_save(step)  # Register that we just saved.

  print('Start training loop.')
  # Note: the lambda's own `*args` shadows the outer config `args` only
  # inside the lambda body.
  policy = lambda *args: agent.policy(
      *args, mode='explore' if should_expl(step) else 'train')
  while step < args.steps:
    # Disabled: averaged scalar eval report over `args.eval_samples` batches.
    # scalars = collections.defaultdict(list)
    # for _ in range(args.eval_samples):
    #   for key, value in agent.report(next(dataset_eval)).items():
    #     if value.shape == ():
    #       scalars[key].append(value)
    # for name, values in scalars.items():
    #   logger.scalar(f'eval/{name}', np.array(values, np.float64).mean())
    # logger.write()
    driver(policy, steps=100)
    if should_save(step):
      checkpoint.save()
  # NOTE(review): written twice in a row; presumably to make sure everything
  # is flushed - confirm whether the second call is needed.
  logger.write()
  logger.write()
len(ep['reward']) - 1 34 | score = float(ep['reward'].astype(np.float64).sum()) 35 | sum_abs_reward = float(np.abs(ep['reward']).astype(np.float64).sum()) 36 | logger.add({ 37 | 'length': length, 38 | 'score': score, 39 | 'sum_abs_reward': sum_abs_reward, 40 | 'reward_rate': (np.abs(ep['reward']) >= 0.5).mean(), 41 | }, prefix='episode') 42 | print(f'Episode has {length} steps and return {score:.1f}.') 43 | stats = {} 44 | for key in args.log_keys_video: 45 | if key in ep: 46 | stats[f'policy_{key}'] = ep[key] 47 | for key, value in ep.items(): 48 | if not args.log_zeros and key not in nonzeros and (value == 0).all(): 49 | continue 50 | nonzeros.add(key) 51 | if re.match(args.log_keys_sum, key): 52 | stats[f'sum_{key}'] = ep[key].sum() 53 | if re.match(args.log_keys_mean, key): 54 | stats[f'mean_{key}'] = ep[key].mean() 55 | if re.match(args.log_keys_max, key): 56 | stats[f'max_{key}'] = ep[key].max(0).mean() 57 | metrics.add(stats, prefix='stats') 58 | 59 | epsdir = embodied.Path(args.logdir) / 'saved_episodes' 60 | epsdir.mkdirs() 61 | print('Saving episodes:', epsdir) 62 | def save(ep): 63 | time = datetime.now().strftime("%Y%m%dT%H%M%S") 64 | uuid = str(embodied.uuid()) 65 | score = str(np.round(ep['reward'].sum(), 1)).replace('-', 'm') 66 | length = len(ep['reward']) 67 | filename = epsdir / f'{time}-{uuid}-len{length}-rew{score}.npz' 68 | with io.BytesIO() as stream: 69 | np.savez_compressed(stream, **ep) 70 | stream.seek(0) 71 | filename.write(stream.read(), mode='wb') 72 | print('Saved episode:', filename) 73 | saver = embodied.Worker(save, 'thread') 74 | 75 | driver = embodied.Driver(env) 76 | driver.on_episode(lambda ep, worker: per_episode(ep)) 77 | driver.on_episode(lambda ep, worker: saver(ep)) 78 | driver.on_step(lambda tran, _: step.increment()) 79 | driver.on_step(replay.add) 80 | 81 | print('Prefill train dataset.') 82 | random_agent = embodied.RandomAgent(env.act_space) 83 | while len(replay) < max(args.batch_steps, args.train_fill): 84 | 
driver(random_agent.policy, steps=100) 85 | logger.add(metrics.result()) 86 | logger.write() 87 | 88 | dataset = agent.dataset(replay.dataset) 89 | state = [None] # To be writable from train step function below. 90 | batch = [None] 91 | def train_step(tran, worker): 92 | for _ in range(should_train(step)): 93 | with timer.scope('dataset'): 94 | batch[0] = next(dataset) 95 | outs, state[0], mets = agent.train(batch[0], state[0]) 96 | metrics.add(mets, prefix='train') 97 | if 'priority' in outs: 98 | replay.prioritize(outs['key'], outs['priority']) 99 | updates.increment() 100 | if should_sync(updates): 101 | agent.sync() 102 | if should_log(step): 103 | agg = metrics.result() 104 | report = agent.report(batch[0]) 105 | report = {k: v for k, v in report.items() if 'train/' + k not in agg} 106 | logger.add(agg) 107 | logger.add(report, prefix='report') 108 | logger.add(replay.stats, prefix='replay') 109 | logger.add(timer.stats(), prefix='timer') 110 | logger.write(fps=True) 111 | driver.on_step(train_step) 112 | 113 | checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt') 114 | timer.wrap('checkpoint', checkpoint, ['save', 'load']) 115 | checkpoint.step = step 116 | checkpoint.agent = agent 117 | checkpoint.replay = replay 118 | if args.from_checkpoint: 119 | checkpoint.load(args.from_checkpoint) 120 | checkpoint.load_or_save() 121 | should_save(step) # Register that we jused saved. 
#!/bin/sh
# Installs DeepMind Lab from source (via Bazel) and prepares the image
# dataset under ./dmlab_data. Intended to run as root inside a container.
set -eu

# Dependencies
apt-get update && apt-get install -y \
    build-essential curl freeglut3 gettext git libffi-dev libglu1-mesa \
    libglu1-mesa-dev libjpeg-dev liblua5.1-0-dev libosmesa6-dev \
    libsdl2-dev lua5.1 pkg-config python-setuptools python3-dev \
    software-properties-common unzip zip zlib1g-dev g++
pip3 install numpy

# Bazel
apt-get install -y apt-transport-https curl gnupg
curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor > bazel.gpg
mv bazel.gpg /etc/apt/trusted.gpg.d/
echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list
apt-get update && apt-get install -y bazel

# Build
git clone https://github.com/deepmind/lab.git
cd lab
echo 'build --cxxopt=-std=c++17' > .bazelrc
bazel build -c opt //python/pip_package:build_pip_package
./bazel-bin/python/pip_package/build_pip_package /tmp/dmlab_pkg
pip3 install --force-reinstall /tmp/dmlab_pkg/deepmind_lab-*.whl
cd ..
rm -rf lab

# Dataset
# -p: do not abort (set -e) when the directory is left over from a
# previous, partially completed run.
mkdir -p dmlab_data
cd dmlab_data
pip3 install Pillow
# -f: fail on HTTP errors instead of saving the error page as the zip file.
# -L: follow redirects.
curl -fL https://bradylab.ucsd.edu/stimuli/ObjectsAll.zip -o ObjectsAll.zip
unzip ObjectsAll.zip
cd OBJECTSALL
# Convert every JPEG to a sequentially numbered PNG one directory up.
# The delimiter is quoted so the shell never expands anything inside the
# embedded Python program.
python3 << 'EOM'
import os
from PIL import Image
files = [f for f in os.listdir('.') if f.lower().endswith('jpg')]
for i, file in enumerate(sorted(files)):
  print(file)
  im = Image.open(file)
  im.save('../%04d.png' % (i+1))
EOM
cd ..
rm -rf __MACOSX OBJECTSALL ObjectsAll.zip

apt-get clean
def make_HiPPO(N):
    """
    A standard HiPPO initialization.

    Returns the negated (N, N) lower-triangular matrix whose un-negated
    entries are sqrt(2i + 1) * sqrt(2j + 1) below the diagonal and i + 1 on
    the diagonal.
    """
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A


def discrete_DPLR(Lambda, P, Q, B, C, step, L):
    """
    DPLR matrix discretization function.

    With A = diag(Lambda) - P Q^*, computes
        Ab = ((2/step) I - A)^-1 ((2/step) I + A)
        Bb = 2 ((2/step) I - A)^-1 B
    and a rescaled output row Cb built from C and (I - Ab^L)^-1.

    Args:
        Lambda, P, Q, B, C: 1-D arrays of the same length N.
        step: discretization step size.
        L: integer power applied to Ab when rescaling C.

    Returns:
        Tuple (Ab, Bb, Cb) with shapes (N, N), (N, 1) and (1, N).
    """
    # Convert parameters to matrices
    B = B[:, np.newaxis]
    Ct = C[np.newaxis, :]

    N = Lambda.shape[0]
    A = np.diag(Lambda) - P[:, np.newaxis] @ Q[:, np.newaxis].conj().T
    I = np.eye(N)

    # (2/step) I + A
    A0 = (2.0 / step) * I + A

    # ((2/step) I - A)^-1, using the Woodbury identity so that only the
    # diagonal part has to be inverted explicitly.
    D = np.diag(1.0 / ((2.0 / step) - Lambda))
    Qc = Q.conj().T.reshape(1, -1)
    P2 = P.reshape(-1, 1)
    A1 = D - (D @ P2 * (1.0 / (1 + (Qc @ D @ P2))) * Qc @ D)

    # A bar and B bar
    Ab = A1 @ A0
    Bb = 2 * A1 @ B

    # Recover Cbar from Ct
    Cb = Ct @ inv(I - matrix_power(Ab, L)).conj()
    return Ab, Bb, Cb.conj()


def make_NPLR_HiPPO(N):
    """
    Creates the pieces of the normal-plus-low-rank form of HiPPO.

    Returns:
        Tuple (nhippo, P, B): the negated HiPPO matrix from `make_HiPPO`,
        the rank-1 vector P = sqrt(n + 1/2) and the input vector
        B = sqrt(2n + 1), for n = 0..N-1.
    """
    # Make -HiPPO
    nhippo = make_HiPPO(N)

    # Add in a rank 1 term. Makes it Normal.
    P = np.sqrt(np.arange(N) + 0.5)

    # HiPPO also specifies the B matrix
    B = np.sqrt(2 * np.arange(N) + 1.0)
    return nhippo, P, B


def make_DPLR_HiPPO(N):
    """
    Diagonalize NPLR representation.

    Returns:
        Tuple (Lambda, P, B, V): complex eigenvalues Lambda, the vectors P
        and B expressed in the eigenbasis (V^* P and V^* B), and the
        eigenvector matrix V.
    """
    A, P, B = make_NPLR_HiPPO(N)

    S = A + P[:, np.newaxis] * P[np.newaxis, :]

    # The real part of every eigenvalue is taken to be the mean of the
    # diagonal of S (the explicit check that the diagonal is constant is
    # disabled in the original code).
    S_diag = np.diagonal(S)
    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)
    # assert np.allclose(Lambda_real, S_diag, atol=1e-3)

    # Diagonalize S to V \Lambda V^*
    Lambda_imag, V = eigh(S * -1j)

    P = V.conj().T @ P
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V


def init(x):
    """
    Factory for constant initializer in Flax.

    The returned function has the signature `(key, shape)`, ignores `key`
    and returns `x` itself.

    Raises:
        ValueError: from the returned function, if `shape` does not match
            `x.shape`. An explicit raise is used instead of `assert` so the
            check is still performed when Python runs with `-O`.
    """
    def _init(key, shape):
        # Compare as tuples so that an equal shape given as a list passes.
        if tuple(shape) != tuple(x.shape):
            raise ValueError(
                f"constant initializer holds shape {tuple(x.shape)}, "
                f"but shape {tuple(shape)} was requested")
        return x

    return _init


def hippo_initializer(N):
    """
    The initializer function for the DPLR matrices.

    Returns:
        Four constant initializers, in order: real part of Lambda,
        imaginary part of Lambda, P and B, all taken from
        `make_DPLR_HiPPO(N)`.
    """
    Lambda, P, B, _ = make_DPLR_HiPPO(N)
    return init(Lambda.real), init(Lambda.imag), init(P), init(B)
133 | init_A_re, init_A_im, init_P, init_B = hippo_initializer(self.N) 134 | self.Lambda_re = self.param("Lambda_re", init_A_re, (self.N,)) 135 | self.Lambda_im = self.param("Lambda_im", init_A_im, (self.N,)) 136 | # Ensure the real part of Lambda is negative 137 | # (described in the SaShiMi follow-up to S4) 138 | self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im 139 | self.P = self.param("P", init_P, (self.N,)) 140 | self.B = self.param("B", init_B, (self.N,)) 141 | # C should be init as standard normal 142 | # This doesn't work due to how JAX handles complex optimizers https://github.com/deepmind/optax/issues/196 143 | # self.C = self.param("C", normal(stddev=1.0, dtype=np.complex64), (self.N,)) 144 | self.C = self.param("C", normal(stddev=0.5**0.5), (self.N, 2)) 145 | self.C = self.C[..., 0] + 1j * self.C[..., 1] 146 | self.D = self.param("D", nn.initializers.ones, (1,)) 147 | self.step = np.exp(self.param("log_step", log_step_initializer(), (1,))) 148 | # Flax trick to cache discrete form during decoding. 149 | def init_discrete(): 150 | return discrete_DPLR( 151 | self.Lambda, 152 | self.P, 153 | self.P, 154 | self.B, 155 | self.C, 156 | self.step, 157 | self.l_max, 158 | ) 159 | self.ssm = init_discrete() 160 | 161 | def __call__(self, u, x0, dones=None): 162 | """ 163 | Forward pass for SSM. 
164 | 165 | when parallel=True: 166 | Shape: 167 | u.shape: (len, dims) 168 | u.dtype: float32 169 | x_0.shape: (N, dims) 170 | x_0.dtyle: complex64 171 | Output Shape: 172 | y.shape: (len, dims) 173 | y.dtype: float32 174 | x.shape: (N, dims) 175 | x.dtype: complex64 176 | """ 177 | if not self.parallel or u.shape[0] == 1: 178 | x, y = slow_scan(*self.ssm, u, x0) 179 | Du = jax.vmap(lambda u: self.D * u)(u) 180 | return (y.real + Du).reshape(-1), x 181 | y, x = fast_scan(*self.ssm, u, x0, dones) # .real is happening inside of fast_scan 182 | Du = jax.vmap(lambda u: self.D * u)(u) 183 | return (y + Du)[..., 0], x 184 | 185 | 186 | SISOBlock = partial(SequenceBlock, layer_cls=SISOLayer) 187 | SISOBlock = batchwise(SISOBlock) 188 | SISOLayerB = batchwise(SISOLayer) -------------------------------------------------------------------------------- /recall2imagine/ssm_nets.py: -------------------------------------------------------------------------------- 1 | from . import ninjax as nj 2 | from . import jaxutils 3 | 4 | from .ssm.siso import SISOBlock 5 | from .ssm.mimo import MIMOBlock 6 | from .nets import RSSM, Linear 7 | 8 | import jax 9 | from jax import numpy as jnp 10 | from tensorflow_probability.substrates import jax as tfp 11 | f32 = jnp.float32 12 | tfd = tfp.distributions 13 | tree_map = jax.tree_util.tree_map 14 | sg = lambda x: tree_map(jax.lax.stop_gradient, x) 15 | 16 | cast = jaxutils.cast_to_compute 17 | 18 | def init_siso_ssm(N, d_model, n_layers, parallel=True, conv=False, use_norm=False, **kw): 19 | """ 20 | The wrapper function that wraps SISO SSM into FlaxModule for compatibility with the main code. 
class S3M(RSSM):
    """
    This class implements RSSM with SISO/MIMO SSM cell.
    """
    def __init__(self, deter=1024, stoch=32, classes=32, units=1024, hidden=128, unroll=False, initial='learned',
                 unimix=0.01, action_clip=1.0, nonrecurrent_enc=False, ssm='mimo', ssm_kwargs=None, **kw):
        """
        Builds the SSM core and stores the model hyper-parameters.

        Args:
            deter: size of the deterministic state.
            stoch, classes: sizes of the stochastic state; `classes` falsy
                selects the mean/std state layout in `initial`.
            units: model width, passed to the SSM wrapper as its second
                positional argument.
            hidden: SSM state size, passed to the SSM wrapper as its first
                positional argument.
            initial: 'zeros' or 'learned', see `initial()`.
            ssm: 'siso' or 'mimo'.
            ssm_kwargs: keyword arguments forwarded to `init_siso_ssm` /
                `init_mimo_ssm`. `None` (the default) is treated as an empty
                mapping. When 'n_layers', 'parallel' or 'conv' are absent,
                the values 1, True and False are recorded, matching the
                wrapper defaults where the wrappers define one.
            **kw: stored in `self._kw` (with 'units' added); `kw['act']` is
                read when building the output projection.

        Raises:
            NotImplementedError: if `ssm` is neither 'siso' nor 'mimo'.
        """
        # The signature advertises None as the default, so make it usable
        # instead of failing on `**None`.
        if ssm_kwargs is None:
            ssm_kwargs = {}
        if ssm == 'siso':
            self.core = init_siso_ssm(hidden, units, **ssm_kwargs)
        elif ssm == 'mimo':
            self.core = init_mimo_ssm(hidden, units, **ssm_kwargs)
        else:
            raise NotImplementedError("SSM is not implemented")
        self._ssm = ssm
        self._deter = deter
        self._units = units
        self._hidden = hidden
        self._stoch = stoch
        self._classes = classes
        # Fall back to defaults rather than raising KeyError on a partially
        # specified ssm_kwargs.
        self._n_layers = ssm_kwargs['n_layers'] if 'n_layers' in ssm_kwargs else 1
        self._parallel = ssm_kwargs['parallel'] if 'parallel' in ssm_kwargs else True
        self._conv = ssm_kwargs['conv'] if 'conv' in ssm_kwargs else False
        self._unroll = unroll
        self._initial = initial
        self._unimix = unimix
        self._nonrecurrent_enc = nonrecurrent_enc
        self._action_clip = action_clip
        self._kw = kw
        self._kw['units'] = units

    def initial(self, bs):
        """
        Returns the initial vector for RSSM.

        Args:
            bs: batch size.

        Returns:
            Dict with 'deter', 'stoch', 'hidden' and either 'logit' (when
            `classes` is set) or 'mean' and 'std'. 'hidden' is complex64
            with shape (bs, n_layers, hidden, units) for 'siso' and
            (bs, n_layers, hidden) for 'mimo'.

        Raises:
            NotImplementedError: if the `initial` mode is neither 'zeros'
                nor 'learned'.
        """
        if self._classes:
            state = dict(
                deter=jnp.zeros([bs, self._deter], f32),
                logit=jnp.zeros([bs, self._stoch, self._classes], f32),
                stoch=jnp.zeros([bs, self._stoch, self._classes], f32))
        else:
            state = dict(
                deter=jnp.zeros([bs, self._deter], f32),
                mean=jnp.zeros([bs, self._stoch], f32),
                std=jnp.ones([bs, self._stoch], f32),
                stoch=jnp.zeros([bs, self._stoch], f32))
        if self._ssm == 'siso':
            state['hidden'] = jnp.zeros([bs, self._n_layers, self._hidden, self._units], jnp.complex64)
        if self._ssm == 'mimo':
            state['hidden'] = jnp.zeros([bs, self._n_layers, self._hidden], jnp.complex64)
        if self._initial == 'zeros':
            state = cast(state)
            # Re-apply complex64 after the cast to the compute dtype.
            state['hidden'] = state['hidden'].astype(jnp.complex64)
            return state
        elif self._initial == 'learned':
            # The complex hidden state is stored as two real parameters
            # (index 0: real part, index 1: imaginary part).
            hidden = self.get('initial_hidden', jnp.zeros, (2,) + state['hidden'][0].shape, f32)
            deter = self.get('initial_deter', jnp.zeros, state['deter'][0].shape, f32)
            hidden = jnp.expand_dims(hidden[0] + 1j * hidden[1], 0)
            state['hidden'] = jnp.repeat(hidden, bs, 0)
            state['deter'] = jnp.repeat(jnp.tanh(deter)[None], bs, 0)
            state['stoch'] = self.get_stoch(cast(state['deter']))
            return cast(state)
        else:
            raise NotImplementedError(self._initial)

    def _cell(self, x, prev_state):
        """
        Implements one step forward pass of RSSM.

        x.shape == (batch, units)
        prev_state {
            'hidden' : shape (batch, self._hidden, self._units)
        }

        Returns:
            Tuple (deter, state) where state is a dict with the projected
            'deter' output and the new SSM 'hidden' state.
        """
        hidden = prev_state['hidden']  # this is x_t-1
        # Add a length-1 axis at position 1 so the core sees a sequence.
        x = x[:, jnp.newaxis]
        # cell expected shape
        # u: (batch, 1, self._units)
        # xk: (batch, self._hidden, self._units)
        output, hidden = self.core(x, hidden)  # y_t, x_t = S*(u_t-1, x_t-1)
        # y: (batch, 1, self._units)
        # xk1: (batch, self._hidden, self._units)
        if isinstance(output, tuple):
            output, _ = output
        output = output[:, 0]
        if self._ssm == 'siso':
            # (batch, hidden, layers, units) -> (batch, layers, hidden, units)
            hidden = jnp.transpose(hidden, (0, 2, 1, 3))
        deter = output
        kw = {'winit': 'normal', 'fan': 'avg', 'act': self._kw['act'], 'units': self._deter}
        deter = self.get('out_proj', Linear, **kw)(deter)
        return deter, {'deter': deter, 'hidden': hidden}

    def _cell_scan(self, x, state, first, zero_state):
        """
        Implements sequential forward pass for RSSM.

        When the core is not in parallel mode this defers to
        `RSSM._cell_scan` and swaps the first two axes of its outputs.
        Otherwise the whole sequence, together with the `first` flags, is
        passed to the SSM core in one call and projected with the same
        'out_proj' layer used by `_cell`.

        Returns:
            Tuple (last_state, (states, x)): `last_state` holds the entries
            at index -1 along axis 1, `states` the full 'deter' and 'hidden'
            sequences.
        """
        swap = lambda x: x.transpose([1, 0] + list(range(2, len(x.shape))))
        if not self._parallel:
            state, (outs, x) = super()._cell_scan(x, state, first, zero_state)
            return state, (tree_map(swap, outs), swap(x))

        if self._ssm == 'siso':
            # Repeat the flags along a new last axis of the same size as
            # the last axis of x.
            first = jnp.tile(first[..., None], (1, 1, x.shape[-1]))
        x, outstate = self.core((swap(x), swap(first)), state['hidden'], zero_state['hidden'])  # hidden.shape = (batch, layers, hidden, units)
        if isinstance(x, tuple):
            x, _ = x
        if self._ssm == 'siso':
            # (batch, seq, units, layers, hidden) -> (batch, seq, layers, hidden, units)
            outstate = jnp.transpose(outstate, (0, 1, 3, 4, 2))
        kw = {'winit': 'normal', 'fan': 'avg', 'act': self._kw['act'], 'units': self._deter}
        x = self.get('out_proj', Linear, **kw)(x)
        return {'deter': x[:, -1], 'hidden': outstate[:, -1]}, ({'deter': x, 'hidden': outstate}, x)