├── README.md
├── images
│   ├── ablat.png
│   ├── dtr.png
│   ├── interp.png
│   ├── result.png
│   └── scalability.png
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
# DART (Full code to be released soon)
Official implementation of the ICML 2024 paper "Learning to Play Atari in a World of Tokens" by

[Pranav Agarwal](https://pranaval.github.io/), [Sheldon Andrews](https://profs.etsmtl.ca/sandrews/) and [Samira Ebrahimi Kahou](https://saebrahimi.github.io/).

In this work, we introduce discrete abstract representations for transformer-based learning (DART), a sample-efficient method that uses discrete representations both to model the world and to learn behavior. On the Atari 100k sample-efficiency benchmark, DART outperforms previous state-of-the-art methods that do not use look-ahead search, reaching a median human-normalized score of 0.790 and beating humans in 9 out of 26 games.

[Paper]() [Webpage](https://pranaval.github.io/DART/)
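
Assuming DART's tokenizer follows the IRIS-style discrete autoencoder credited in this README, a "discrete representation" is obtained by replacing each continuous encoder output with the index of its nearest vector in a learned codebook. The snippet below is a minimal, dependency-free sketch of that nearest-codebook lookup only; the function names, the 2-D toy codebook, and the feature values are hypothetical, and it is not DART's implementation (no encoder, decoder, or codebook training is shown).

```python
from typing import List, Sequence


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(features: Sequence[Sequence[float]],
             codebook: Sequence[Sequence[float]]) -> List[int]:
    """Replace each continuous feature vector by the index of its nearest codebook entry.

    The returned integers are the discrete tokens; a decoder would map them
    back to vectors with codebook[token].
    """
    tokens: List[int] = []
    for f in features:
        # argmin over codebook entries (ties resolve to the lowest index)
        nearest = min(range(len(codebook)), key=lambda k: squared_distance(f, codebook[k]))
        tokens.append(nearest)
    return tokens


if __name__ == "__main__":
    # Hypothetical toy codebook: 4 entries in 2-D (a learned tokenizer would use
    # many more, higher-dimensional entries).
    codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    # Hypothetical encoder outputs, e.g. one vector per image patch.
    features = [[0.1, -0.2], [0.9, 0.1], [0.2, 0.8], [0.7, 0.9]]
    print(quantize(features, codebook))  # → [0, 1, 2, 3]
```

The resulting integer tokens are what a GPT-style world model can predict autoregressively, which is why the lookup discards the continuous values and keeps only indices.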

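The headline figures above (a median human-normalized score of 0.790 and superhuman play in 9 of 26 games) use the standard Atari normalization `HNS = (agent - random) / (human - random)`, where `random` and `human` are per-game reference scores. The sketch below shows that arithmetic and how the median and the superhuman count are derived; the game names and all scores are made up for illustration and are not results from the paper.

```python
from statistics import median
from typing import Dict, Tuple


def human_normalized_score(agent_score: float, random_score: float, human_score: float) -> float:
    """Scale a raw game score so that random play maps to 0 and the human reference to 1."""
    return (agent_score - random_score) / (human_score - random_score)


def summarize(scores: Dict[str, Tuple[float, float, float]]) -> Tuple[float, int]:
    """Return (median HNS across games, number of games with HNS > 1).

    `scores` maps a game name to (agent_score, random_score, human_score).
    Assuming human_score > random_score, an HNS above 1 means the agent
    outscored the human reference on that game.
    """
    hns = [human_normalized_score(a, r, h) for a, r, h in scores.values()]
    superhuman = sum(1 for s in hns if s > 1.0)
    return median(hns), superhuman


if __name__ == "__main__":
    # Hypothetical numbers for three made-up games (not results from the paper).
    toy = {
        "GameA": (30.0, 10.0, 50.0),   # HNS = 0.5
        "GameB": (120.0, 0.0, 100.0),  # HNS = 1.2 (beats the human reference)
        "GameC": (5.0, 0.0, 50.0),     # HNS = 0.1
    }
    print(summarize(toy))  # → (0.5, 1)
```

Reporting the median rather than the mean keeps a few games with very large normalized scores from dominating the summary.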

## Dependencies
Create a virtual environment and install the required packages:
```
python3.8 -m venv myenv
source myenv/bin/activate
pip install -r requirements.txt
```

## Overview
The main contribution of our work is a novel approach that uses transformers for both world modeling and policy learning. Specifically, we use a transformer decoder (GPT) for world modeling and a transformer encoder (ViT) for policy learning. This improves on IRIS, which relies on CNNs and LSTMs for policy learning, a design that may limit its performance.

We use discrete representations for both policy learning and world modeling. These discrete representations capture abstract features, enabling our transformer-based model to focus on fine-grained, task-specific details. Attending to these details improves decision-making, as our results demonstrate. To address partial observability, we introduce a novel memory mechanism that uses self-attention to aggregate task-relevant information and carry it from one time step to the next.

Our model shows improved interpretability and sample efficiency. Among methods without look-ahead search, it achieves state-of-the-art results on the Atari 100k benchmark, with a median human-normalized score of 0.790 and superhuman performance in 9 out of 26 games.

## Citation
If you use this work, please cite:
```
@inproceedings{agarwal2024learning,
  title={Learning to Play Atari in a World of Tokens},
  author={Agarwal, Pranav and Andrews, Sheldon and Kahou, Samira Ebrahimi},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2024}
}
```

## Credits
* Modeling of tokens and the world is inspired by [https://github.com/eloialonso/iris](https://github.com/eloialonso/iris).
* [https://github.com/google-research/vision_transformer](https://github.com/google-research/vision_transformer)

--------------------------------------------------------------------------------
/images/ablat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/ablat.png
--------------------------------------------------------------------------------
/images/dtr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/dtr.png
--------------------------------------------------------------------------------
/images/interp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/interp.png
--------------------------------------------------------------------------------
/images/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/result.png
--------------------------------------------------------------------------------
/images/scalability.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/scalability.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==1.1.0+computecanada
ale-py==0.7.5+computecanada
antlr4-python3-runtime==4.9.3+computecanada
appdirs==1.4.3+computecanada
arch==5.0.1
argon2-cffi==21.3.0+computecanada
argon2-cffi-bindings==21.2.0+computecanada
asciitree==0.3.3+computecanada
asttokens==2.0.5+computecanada
astunparse==1.6.3+computecanada
async-timeout==4.0.2+computecanada
atari-py==0.2.6
attrs==21.4.0+computecanada
Automat==22.10.0
AutoROM==0.4.2
AutoROM.accept-rom-license==0.4.2
backcall==0.2.0+computecanada
beautifulsoup4==4.11.1+computecanada
bitmath==1.3.3.1
bleach==5.0.0+computecanada
bokeh==2.4.3+computecanada
cached-property==1.5.2+computecanada
cachetools==4.2.4+computecanada
castepxbin==0.3.0
certifi==2022.5.18.1+computecanada
cffi==1.15.0+computecanada
chardet==4.0.0+computecanada
charset-normalizer==2.0.12+computecanada
chex==0.1.7+computecanada
click==7.1.2+computecanada
cloudpickle==1.6.0+computecanada
colorama==0.4.6+computecanada
colormath==3.0.0+computecanada
commonroad-io==2023.1
commonroad-vehicle-models==3.0.2
constantly==15.1.0
contourpy==1.0.7+computecanada
coverage==7.2.0+computecanada
crafter==1.8.1
cycler==0.11.0+computecanada
Cython==0.29.14+computecanada
debugpy==1.6.0+computecanada
decorator==4.4.2+computecanada
defusedxml==0.7.1+computecanada
deprecation==2.1.0+computecanada
DI-toolkit==0.2.0
DI-treetensor==0.4.1
dill==0.3.7+computecanada
distlib==0.3.0
dm-tree==0.1.8+computecanada
docker-pycreds==0.4.0+computecanada
easydict==1.9+computecanada
einops==0.3.0
entrypoints==0.4+computecanada
enum-tools==0.10.0
execnet==1.9.0+computecanada
executing==0.8.3+computecanada
fasteners==0.17.3+computecanada
fastjsonschema==2.15.3+computecanada
ffmpeg==1.4
filelock==3.0.12+computecanada
Flask==1.1.4
flatbuffers==20190709135844+computecanada
folium==0.2.1
fonttools==4.39.3+computecanada
future==0.18.2+computecanada
gast==0.4.0+computecanada
gitdb==4.0.9+computecanada
gitpython==3.1.27+computecanada
glfw==1.12.0+computecanada
google-auth==1.35.0+computecanada
google-auth-oauthlib==0.4.6+computecanada
google-pasta==0.2.0+computecanada
graphviz==0.20.1+computecanada
grpcio==1.32.0+computecanada
gtrxl-torch==0.1.7
gym==0.19.0
gym-notices==0.0.8
gymnasium==0.27.0
gymnasium-notices==0.0.1+computecanada
h5py==3.8.0+computecanada
hbutils==0.9.1
hickle==5.0.2
hydra-core==1.2.0+computecanada
hyperlink==21.0.0+computecanada
idna==3.3+computecanada
ijson==3.2.0.post0
imageio==2.9.0+computecanada
imageio-ffmpeg==0.4.2+computecanada
importlib-metadata==4.11.4+computecanada
importlib-resources==5.7.1
incremental==22.10.0
iniconfig==2.0.0
ipykernel==6.14.0
ipython==8.4.0+computecanada
ipython-genutils==0.2.0+computecanada
ipywidgets==7.7.0
iso3166==2.1.1
itsdangerous==1.1.0+computecanada
jax==0.4.13+computecanada
jax-jumpy==1.0.0+computecanada
jaxlib==0.4.13+cuda11.cudnn86.computecanada
jedi==0.18.1+computecanada
Jinja2==2.11.3+computecanada
joblib==1.2.0+computecanada
jsonpatch==1.32
jsonpointer==2.3+computecanada
jsonschema==4.6.0+computecanada
jupyter-client==7.3.4+computecanada
jupyter-core==4.10.0+computecanada
jupyterlab-pygments==0.2.2+computecanada
jupyterlab-widgets==1.1.0
kaggle==1.5.13
keras==2.11.0+computecanada
kiwisolver==1.3.2+computecanada
libclang==14.0.1+computecanada
lxml==4.9.1+computecanada
lz4==4.3.2+computecanada
Markdown==3.3.7+computecanada
markdown-it-py==2.2.0+computecanada
markupsafe==2.0.1+computecanada
matplotlib==3.7.0+computecanada
matplotlib-inline==0.1.6
mdurl==0.1.2+computecanada
mistune==0.8.4+computecanada
ml-dtypes==0.1.0+computecanada
monty==2023.5.7
more-itertools==8.2.0+computecanada
moviepy==1.0.3+computecanada
mpire==2.7.1
mpmath==1.3.0+computecanada
mujoco-py==2.1.2.14
nbclient==0.6.4
nbconvert==6.5.0+computecanada
nbformat==5.4.0+computecanada
nest-asyncio==1.5.5+computecanada
networkx==3.1+computecanada
notebook==6.4.12+computecanada
numcodecs==0.8.0+computecanada
numpy==1.24.2+computecanada
oauthlib==3.2.0+computecanada
omegaconf==2.2.3+computecanada
opencv-contrib-python-headless==4.6.0.66
opencv-python==4.1.2.30+computecanada
opendrive2lanelet==1.2.1
opensimplex==0.4.4
opt-einsum==3.3.0+computecanada
optax==0.1.7
packaging==21.3+computecanada
palettable==3.3.3
panda3d==1.10.8+computecanada
panda3d-gltf==0.13
panda3d-simplepbr==0.10
pandas==1.5.3+computecanada
pandocfilters==1.5.0+computecanada
parso==0.8.3+computecanada
pathtools==0.1.2+computecanada
patsy==0.5.3+computecanada
PettingZoo==1.22.3
pexpect==4.8.0+computecanada
phonopy==2.17.1+computecanada
pickleshare==0.7.5+computecanada
Pillow==8.1.2+computecanada
Pillow-SIMD==7.0.0.post3+computecanada
pkgutil-resolve-name==1.3.10+computecanada
plotly==5.14.1
pluggy==1.0.0+computecanada
proglog==0.1.10
prometheus-client==0.14.1+computecanada
promise==2.3+computecanada
prompt-toolkit==3.0.29+computecanada
property-cached==1.6.4+computecanada
protobuf==3.20.1+computecanada
psutil==5.9.0+computecanada
PTable==0.9.2
ptyprocess==0.7.0+computecanada
pure-eval==0.2.2+computecanada
py==1.11.0+computecanada
py-cpuinfo==9.0.0+computecanada
pyasn1==0.4.8+computecanada
pyasn1-modules==0.2.8+computecanada
pybullet==3.0.6
pycparser==2.21+computecanada
pygame==2.1.0+computecanada
pyglet==1.5.0+computecanada
Pygments==2.15.1+computecanada
pymap3d==2.7.0
pymatgen==2022.0.10+computecanada
pynng==0.7.1+computecanada
PyOpenGL==3.1.6
pyparsing==2.4.7+computecanada
pyproj==3.3.0+computecanada
pyrsistent==0.18.1+computecanada
pytest==7.0.1+computecanada
pytest-benchmark==4.0.0
pytest-cov==4.0.0+computecanada
pytest-forked==1.6.0
pytest-xdist==3.2.1
python-dateutil==2.8.2+computecanada
python-slugify==8.0.1
pytimeparse==1.1.8+computecanada
pytz==2022.1+computecanada
PyWavelets==1.3.0+computecanada
PyYAML==6.0+computecanada
pyzmq==23.1.0+computecanada
readerwriterlock==1.0.9
redis==4.6.0
requests==2.28.0+computecanada
requests-oauthlib==1.3.1+computecanada
responses==0.12.1+computecanada
rich==13.3.5
rliable==1.0.7
rsa==4.8+computecanada
Rtree==1.0.1+computecanada
ruamel-yaml-clib==0.2.6+computecanada
ruamel.yaml==0.17.4
scikit-image==0.19.3+computecanada
scikit-learn==1.1.2+computecanada
scipy==1.8.0+computecanada
seaborn==0.12.2+computecanada
seekpath==2.0.1
send2trash==1.8.0+computecanada
sentry-sdk==1.9.9
setproctitle==1.3.2
setuptools-scm==3.5.0+computecanada
sh==2.0.3
shapely==2.0.1+computecanada
Shimmy==0.2.1
shortuuid==1.0.9
six==1.14.0+computecanada
smmap==5.0.0+computecanada
sniffio==1.3.0+computecanada
soupsieve==2.3.2.post1+computecanada
spglib==2.0.2+computecanada
stable-baselines3==1.5.0
stack-data==0.2.0+computecanada
statsmodels==0.13.5+computecanada
sympy==1.11.1+computecanada
tableprint==0.9.1
tabulate==0.9.0+computecanada
tenacity==8.2.2
tensorboard==2.11.2+computecanada
tensorboard-data-server==0.6.1+computecanada
tensorboard-plugin-wit==1.8.1+computecanada
tensorboardX==2.6.2
tensorflow==2.11.0+computecanada
tensorflow-estimator==2.11.0+computecanada
tensorflow-io-gcs-filesystem==0.26.0+computecanada
tensorflow-probability==0.21.0+computecanada
termcolor==2.3.0
terminado==0.15.0+computecanada
text-unidecode==1.3+computecanada
tf-estimator-nightly==2.8.0.dev2021122109+computecanada
threadpoolctl==3.1.0+computecanada
tifffile==2023.4.12+computecanada
tinycss2==1.1.1+computecanada
tomli==2.0.1+computecanada
toolz==0.12.0+computecanada
torch==1.8.1+computecanada
torchfile==0.1.0+computecanada
torchvision==0.9.1+computecanada
tornado==6.1+computecanada
tqdm==4.64.0+computecanada
traitlets==5.2.2.post1
transforms3d==0.3.1+computecanada
treevalue==1.4.11
trimesh==3.9.29
trueskill==0.4.5
Twisted==22.10.0
typing-extensions==4.5.0+computecanada
uncertainties==3.1.7
urllib3==1.26.12+computecanada
URLObject==2.4.3
virtualenv==20.0.18
visdom==0.1.8.9+computecanada
wandb==0.13.3
wcwidth==0.2.5+computecanada
webencodings==0.5.1+computecanada
websocket-client==1.5.1+computecanada
Werkzeug==1.0.1+computecanada
widgetsnbextension==3.6.0
wrapt==1.15.0+computecanada
yacs==0.1.8+computecanada
yapf==0.29.0
yattag==1.15.1
zarr==2.11.3
zipp==3.8.0+computecanada
zope.interface==5.4.0+computecanada
--------------------------------------------------------------------------------
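
Most pins in requirements.txt carry a `+computecanada` local version label (`jaxlib` carries `+cuda11.cudnn86.computecanada`). Under PEP 440 these mark locally built wheels, presumably from the Compute Canada cluster wheelhouse, and PyPI does not host local versions, so `pip install -r requirements.txt` is likely to fail outside that environment. A minimal sketch that rewrites each `==` pin to its public version is shown below; the script name `strip_local_labels.py` and the output file name are hypothetical, the stripped versions are not guaranteed to exist on PyPI, and GPU builds such as `jaxlib` may need their own package index.

```python
import re
import sys
from typing import Iterable, List

# "name==version+local": capture everything before the "+local" label.
_PINNED_WITH_LOCAL = re.compile(r"^(?P<pin>[A-Za-z0-9_.\-]+==[^+\s]+)\+\S+$")


def strip_local_labels(lines: Iterable[str]) -> List[str]:
    """Remove PEP 440 local version labels (e.g. '+computecanada') from '==' pins.

    Lines without such a label are returned unchanged (minus surrounding whitespace).
    """
    cleaned: List[str] = []
    for line in lines:
        text = line.strip()
        match = _PINNED_WITH_LOCAL.match(text)
        cleaned.append(match.group("pin") if match else text)
    return cleaned


if __name__ == "__main__":
    if len(sys.argv) > 1:
        # e.g. python strip_local_labels.py requirements.txt > requirements-pypi.txt
        with open(sys.argv[1], encoding="utf-8") as handle:
            print("\n".join(strip_local_labels(handle)))
    else:
        # No file given: clean a few sample pins taken from the list above.
        demo = ["absl-py==1.1.0+computecanada",
                "jaxlib==0.4.13+cuda11.cudnn86.computecanada",
                "einops==0.3.0"]
        print("\n".join(strip_local_labels(demo)))
```

Run without arguments, it prints the sample pins with their labels removed; given a path, it prints a cleaned copy of that file, which could then be passed to `pip install -r`.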