├── README.md
├── images
│   ├── ablat.png
│   ├── dtr.png
│   ├── interp.png
│   ├── result.png
│   └── scalability.png
└── requirements.txt
/README.md:
--------------------------------------------------------------------------------
1 | # DART (Full code to be released soon)
2 | Official implementation of ICML 2024 paper "Learning to Play Atari in a World of Tokens" by
3 |
4 | [Pranav Agarwal](https://pranaval.github.io/), [Sheldon Andrews](https://profs.etsmtl.ca/sandrews/) and [Samira Ebrahimi Kahou](https://saebrahimi.github.io/).
5 |
6 | In this work, we introduce discrete abstract representations for transformer-based learning (DART), a sample-efficient method that uses discrete representations for both world modeling and behavior learning. On the Atari 100k sample-efficiency benchmark, DART outperforms previous state-of-the-art methods that do not use look-ahead search, with a median human-normalized score of 0.790, and beats humans in 9 out of 26 games.
7 |
8 | [Paper]() [Webpage](https://pranaval.github.io/DART/)
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## Dependencies
16 | Create a virtual environment and install the required packages:
17 | ```
18 | python3.8 -m venv myenv
19 | source myenv/bin/activate
20 | pip install -r requirements.txt
21 | ```
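
Most pins in `requirements.txt` carry a `+computecanada` local version label (for example `absl-py==1.1.0+computecanada`). Such labels normally resolve only against the wheelhouse that built them, so the install step above may fail on other machines. The sketch below is one possible workaround and is not part of the repository: it assumes you are installing from PyPI and strips the local label from each pin. The helper names and the output file name are hypothetical, and it has not been verified that every resulting pin exists on PyPI.

```python
import re
from pathlib import Path

# A local version label is the "+something" suffix at the end of a pinned version.
LOCAL_LABEL = re.compile(r"\+[A-Za-z0-9.]+$")


def strip_local_version(requirement: str) -> str:
    """Return the requirement line without its trailing local version label."""
    return LOCAL_LABEL.sub("", requirement.strip())


def rewrite_requirements(src: str, dst: str) -> list:
    """Write a copy of `src` with local labels removed (file names are up to you)."""
    lines = Path(src).read_text().splitlines()
    cleaned = [strip_local_version(line) for line in lines if line.strip()]
    Path(dst).write_text("\n".join(cleaned) + "\n")
    return cleaned


# Pins without a label are left unchanged.
print(strip_local_version("absl-py==1.1.0+computecanada"))
print(strip_local_version("arch==5.0.1"))
```

Calling `rewrite_requirements("requirements.txt", "requirements-pypi.txt")` and then running `pip install -r requirements-pypi.txt` would be the intended use. Some pins may still need manual adjustment for your platform or Python version.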
22 |
23 | ## Overview
24 | Our main contribution is an approach that uses transformers for both world modeling and policy learning: a transformer decoder (GPT) models the world, and a transformer encoder (ViT) learns the policy. This improves on IRIS, which relies on CNNs and LSTMs for policy learning, potentially limiting its performance. Both the world model and the policy operate on discrete representations. These representations capture abstract features, enabling the transformer-based model to focus on task-specific, fine-grained details. Attending to these details improves decision-making, as demonstrated by our results. To address partial observability, we introduce a memory mechanism that uses self-attention to aggregate task-relevant information and carry it from one time step to the next. DART also offers improved interpretability and sample efficiency. Among methods that do not use look-ahead search, it achieves state-of-the-art results on the Atari 100k benchmark, with a median human-normalized score of 0.790 and superhuman performance in 9 out of 26 games.
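
For readers unfamiliar with the metric, the following minimal sketch shows how a median human-normalized score and a count of superhuman games are commonly computed on Atari benchmarks. It assumes the usual definition, (agent - random) / (human - random). The game names and all numbers are placeholders chosen for illustration; they are not results from the paper.

```python
from statistics import median


def human_normalized_score(agent: float, random_score: float, human: float) -> float:
    """0.0 matches a random policy, 1.0 matches the human reference score."""
    return (agent - random_score) / (human - random_score)


# Placeholder per-game scores (illustrative only, not DART results).
scores = {
    "game_a": {"agent": 50.0, "random_score": 0.0, "human": 100.0},
    "game_b": {"agent": 30.0, "random_score": 10.0, "human": 20.0},
    "game_c": {"agent": 8.0, "random_score": 2.0, "human": 10.0},
}

# Normalize each game, then aggregate across games.
hns = {game: human_normalized_score(**s) for game, s in scores.items()}
median_hns = median(hns.values())

# A game counts as superhuman when its normalized score exceeds 1.0.
superhuman_games = [game for game, value in hns.items() if value > 1.0]

print(hns)
print(median_hns, superhuman_games)
```

The median is used rather than the mean because a few games with very large normalized scores would otherwise dominate the aggregate.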
25 |
26 | ## Paper Associated
27 | If you use this work, please cite:
28 | ```
29 | @inproceedings{agarwal2024learning,
30 | title={Learning to Play Atari in a World of Tokens},
31 | author={Agarwal, Pranav and Andrews, Sheldon and Kahou, Samira Ebrahimi},
32 | booktitle={International Conference on Machine Learning (ICML)},
33 | year={2024}
34 | }
35 | ```
36 |
37 | ## Credits
38 | * Modeling of tokens and the world is inspired by [https://github.com/eloialonso/iris](https://github.com/eloialonso/iris).
39 | * [https://github.com/google-research/vision_transformer](https://github.com/google-research/vision_transformer)
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/images/ablat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/ablat.png
--------------------------------------------------------------------------------
/images/dtr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/dtr.png
--------------------------------------------------------------------------------
/images/interp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/interp.png
--------------------------------------------------------------------------------
/images/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/result.png
--------------------------------------------------------------------------------
/images/scalability.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranavAL/DART/c5f4331181165c1d7fb7ee5545952f7c1c1c33f9/images/scalability.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.1.0+computecanada
2 | ale-py==0.7.5+computecanada
3 | antlr4-python3-runtime==4.9.3+computecanada
4 | appdirs==1.4.3+computecanada
5 | arch==5.0.1
6 | argon2-cffi==21.3.0+computecanada
7 | argon2-cffi-bindings==21.2.0+computecanada
8 | asciitree==0.3.3+computecanada
9 | asttokens==2.0.5+computecanada
10 | astunparse==1.6.3+computecanada
11 | async-timeout==4.0.2+computecanada
12 | atari-py==0.2.6
13 | attrs==21.4.0+computecanada
14 | Automat==22.10.0
15 | AutoROM==0.4.2
16 | AutoROM.accept-rom-license==0.4.2
17 | backcall==0.2.0+computecanada
18 | beautifulsoup4==4.11.1+computecanada
19 | bitmath==1.3.3.1
20 | bleach==5.0.0+computecanada
21 | bokeh==2.4.3+computecanada
22 | cached-property==1.5.2+computecanada
23 | cachetools==4.2.4+computecanada
24 | castepxbin==0.3.0
25 | certifi==2022.5.18.1+computecanada
26 | cffi==1.15.0+computecanada
27 | chardet==4.0.0+computecanada
28 | charset-normalizer==2.0.12+computecanada
29 | chex==0.1.7+computecanada
30 | click==7.1.2+computecanada
31 | cloudpickle==1.6.0+computecanada
32 | colorama==0.4.6+computecanada
33 | colormath==3.0.0+computecanada
34 | commonroad-io==2023.1
35 | commonroad-vehicle-models==3.0.2
36 | constantly==15.1.0
37 | contourpy==1.0.7+computecanada
38 | coverage==7.2.0+computecanada
39 | crafter==1.8.1
40 | cycler==0.11.0+computecanada
41 | Cython==0.29.14+computecanada
42 | debugpy==1.6.0+computecanada
43 | decorator==4.4.2+computecanada
44 | defusedxml==0.7.1+computecanada
45 | deprecation==2.1.0+computecanada
46 | DI-toolkit==0.2.0
47 | DI-treetensor==0.4.1
48 | dill==0.3.7+computecanada
49 | distlib==0.3.0
50 | dm-tree==0.1.8+computecanada
51 | docker-pycreds==0.4.0+computecanada
52 | easydict==1.9+computecanada
53 | einops==0.3.0
54 | entrypoints==0.4+computecanada
55 | enum-tools==0.10.0
56 | execnet==1.9.0+computecanada
57 | executing==0.8.3+computecanada
58 | fasteners==0.17.3+computecanada
59 | fastjsonschema==2.15.3+computecanada
60 | ffmpeg==1.4
61 | filelock==3.0.12+computecanada
62 | Flask==1.1.4
63 | flatbuffers==20190709135844+computecanada
64 | folium==0.2.1
65 | fonttools==4.39.3+computecanada
66 | future==0.18.2+computecanada
67 | gast==0.4.0+computecanada
68 | gitdb==4.0.9+computecanada
69 | gitpython==3.1.27+computecanada
70 | glfw==1.12.0+computecanada
71 | google-auth==1.35.0+computecanada
72 | google-auth-oauthlib==0.4.6+computecanada
73 | google-pasta==0.2.0+computecanada
74 | graphviz==0.20.1+computecanada
75 | grpcio==1.32.0+computecanada
76 | gtrxl-torch==0.1.7
77 | gym==0.19.0
78 | gym-notices==0.0.8
79 | gymnasium==0.27.0
80 | gymnasium-notices==0.0.1+computecanada
81 | h5py==3.8.0+computecanada
82 | hbutils==0.9.1
83 | hickle==5.0.2
84 | hydra-core==1.2.0+computecanada
85 | hyperlink==21.0.0+computecanada
86 | idna==3.3+computecanada
87 | ijson==3.2.0.post0
88 | imageio==2.9.0+computecanada
89 | imageio-ffmpeg==0.4.2+computecanada
90 | importlib-metadata==4.11.4+computecanada
91 | importlib-resources==5.7.1
92 | incremental==22.10.0
93 | iniconfig==2.0.0
94 | ipykernel==6.14.0
95 | ipython==8.4.0+computecanada
96 | ipython-genutils==0.2.0+computecanada
97 | ipywidgets==7.7.0
98 | iso3166==2.1.1
99 | itsdangerous==1.1.0+computecanada
100 | jax==0.4.13+computecanada
101 | jax-jumpy==1.0.0+computecanada
102 | jaxlib==0.4.13+cuda11.cudnn86.computecanada
103 | jedi==0.18.1+computecanada
104 | Jinja2==2.11.3+computecanada
105 | joblib==1.2.0+computecanada
106 | jsonpatch==1.32
107 | jsonpointer==2.3+computecanada
108 | jsonschema==4.6.0+computecanada
109 | jupyter-client==7.3.4+computecanada
110 | jupyter-core==4.10.0+computecanada
111 | jupyterlab-pygments==0.2.2+computecanada
112 | jupyterlab-widgets==1.1.0
113 | kaggle==1.5.13
114 | keras==2.11.0+computecanada
115 | kiwisolver==1.3.2+computecanada
116 | libclang==14.0.1+computecanada
117 | lxml==4.9.1+computecanada
118 | lz4==4.3.2+computecanada
119 | Markdown==3.3.7+computecanada
120 | markdown-it-py==2.2.0+computecanada
121 | markupsafe==2.0.1+computecanada
122 | matplotlib==3.7.0+computecanada
123 | matplotlib-inline==0.1.6
124 | mdurl==0.1.2+computecanada
125 | mistune==0.8.4+computecanada
126 | ml-dtypes==0.1.0+computecanada
127 | monty==2023.5.7
128 | more-itertools==8.2.0+computecanada
129 | moviepy==1.0.3+computecanada
130 | mpire==2.7.1
131 | mpmath==1.3.0+computecanada
132 | mujoco-py==2.1.2.14
133 | nbclient==0.6.4
134 | nbconvert==6.5.0+computecanada
135 | nbformat==5.4.0+computecanada
136 | nest-asyncio==1.5.5+computecanada
137 | networkx==3.1+computecanada
138 | notebook==6.4.12+computecanada
139 | numcodecs==0.8.0+computecanada
140 | numpy==1.24.2+computecanada
141 | oauthlib==3.2.0+computecanada
142 | omegaconf==2.2.3+computecanada
143 | opencv-contrib-python-headless==4.6.0.66
144 | opencv-python==4.1.2.30+computecanada
145 | opendrive2lanelet==1.2.1
146 | opensimplex==0.4.4
147 | opt-einsum==3.3.0+computecanada
148 | optax==0.1.7
149 | packaging==21.3+computecanada
150 | palettable==3.3.3
151 | panda3d==1.10.8+computecanada
152 | panda3d-gltf==0.13
153 | panda3d-simplepbr==0.10
154 | pandas==1.5.3+computecanada
155 | pandocfilters==1.5.0+computecanada
156 | parso==0.8.3+computecanada
157 | pathtools==0.1.2+computecanada
158 | patsy==0.5.3+computecanada
159 | PettingZoo==1.22.3
160 | pexpect==4.8.0+computecanada
161 | phonopy==2.17.1+computecanada
162 | pickleshare==0.7.5+computecanada
163 | Pillow==8.1.2+computecanada
164 | Pillow-SIMD==7.0.0.post3+computecanada
165 | pkgutil-resolve-name==1.3.10+computecanada
166 | plotly==5.14.1
167 | pluggy==1.0.0+computecanada
168 | proglog==0.1.10
169 | prometheus-client==0.14.1+computecanada
170 | promise==2.3+computecanada
171 | prompt-toolkit==3.0.29+computecanada
172 | property-cached==1.6.4+computecanada
173 | protobuf==3.20.1+computecanada
174 | psutil==5.9.0+computecanada
175 | PTable==0.9.2
176 | ptyprocess==0.7.0+computecanada
177 | pure-eval==0.2.2+computecanada
178 | py==1.11.0+computecanada
179 | py-cpuinfo==9.0.0+computecanada
180 | pyasn1==0.4.8+computecanada
181 | pyasn1-modules==0.2.8+computecanada
182 | pybullet==3.0.6
183 | pycparser==2.21+computecanada
184 | pygame==2.1.0+computecanada
185 | pyglet==1.5.0+computecanada
186 | Pygments==2.15.1+computecanada
187 | pymap3d==2.7.0
188 | pymatgen==2022.0.10+computecanada
189 | pynng==0.7.1+computecanada
190 | PyOpenGL==3.1.6
191 | pyparsing==2.4.7+computecanada
192 | pyproj==3.3.0+computecanada
193 | pyrsistent==0.18.1+computecanada
194 | pytest==7.0.1+computecanada
195 | pytest-benchmark==4.0.0
196 | pytest-cov==4.0.0+computecanada
197 | pytest-forked==1.6.0
198 | pytest-xdist==3.2.1
199 | python-dateutil==2.8.2+computecanada
200 | python-slugify==8.0.1
201 | pytimeparse==1.1.8+computecanada
202 | pytz==2022.1+computecanada
203 | PyWavelets==1.3.0+computecanada
204 | PyYAML==6.0+computecanada
205 | pyzmq==23.1.0+computecanada
206 | readerwriterlock==1.0.9
207 | redis==4.6.0
208 | requests==2.28.0+computecanada
209 | requests-oauthlib==1.3.1+computecanada
210 | responses==0.12.1+computecanada
211 | rich==13.3.5
212 | rliable==1.0.7
213 | rsa==4.8+computecanada
214 | Rtree==1.0.1+computecanada
215 | ruamel-yaml-clib==0.2.6+computecanada
216 | ruamel.yaml==0.17.4
217 | scikit-image==0.19.3+computecanada
218 | scikit-learn==1.1.2+computecanada
219 | scipy==1.8.0+computecanada
220 | seaborn==0.12.2+computecanada
221 | seekpath==2.0.1
222 | send2trash==1.8.0+computecanada
223 | sentry-sdk==1.9.9
224 | setproctitle==1.3.2
225 | setuptools-scm==3.5.0+computecanada
226 | sh==2.0.3
227 | shapely==2.0.1+computecanada
228 | Shimmy==0.2.1
229 | shortuuid==1.0.9
230 | six==1.14.0+computecanada
231 | smmap==5.0.0+computecanada
232 | sniffio==1.3.0+computecanada
233 | soupsieve==2.3.2.post1+computecanada
234 | spglib==2.0.2+computecanada
235 | stable-baselines3==1.5.0
236 | stack-data==0.2.0+computecanada
237 | statsmodels==0.13.5+computecanada
238 | sympy==1.11.1+computecanada
239 | tableprint==0.9.1
240 | tabulate==0.9.0+computecanada
241 | tenacity==8.2.2
242 | tensorboard==2.11.2+computecanada
243 | tensorboard-data-server==0.6.1+computecanada
244 | tensorboard-plugin-wit==1.8.1+computecanada
245 | tensorboardX==2.6.2
246 | tensorflow==2.11.0+computecanada
247 | tensorflow-estimator==2.11.0+computecanada
248 | tensorflow-io-gcs-filesystem==0.26.0+computecanada
249 | tensorflow-probability==0.21.0+computecanada
250 | termcolor==2.3.0
251 | terminado==0.15.0+computecanada
252 | text-unidecode==1.3+computecanada
253 | tf-estimator-nightly==2.8.0.dev2021122109+computecanada
254 | threadpoolctl==3.1.0+computecanada
255 | tifffile==2023.4.12+computecanada
256 | tinycss2==1.1.1+computecanada
257 | tomli==2.0.1+computecanada
258 | toolz==0.12.0+computecanada
259 | torch==1.8.1+computecanada
260 | torchfile==0.1.0+computecanada
261 | torchvision==0.9.1+computecanada
262 | tornado==6.1+computecanada
263 | tqdm==4.64.0+computecanada
264 | traitlets==5.2.2.post1
265 | transforms3d==0.3.1+computecanada
266 | treevalue==1.4.11
267 | trimesh==3.9.29
268 | trueskill==0.4.5
269 | Twisted==22.10.0
270 | typing-extensions==4.5.0+computecanada
271 | uncertainties==3.1.7
272 | urllib3==1.26.12+computecanada
273 | URLObject==2.4.3
274 | virtualenv==20.0.18
275 | visdom==0.1.8.9+computecanada
276 | wandb==0.13.3
277 | wcwidth==0.2.5+computecanada
278 | webencodings==0.5.1+computecanada
279 | websocket-client==1.5.1+computecanada
280 | Werkzeug==1.0.1+computecanada
281 | widgetsnbextension==3.6.0
282 | wrapt==1.15.0+computecanada
283 | yacs==0.1.8+computecanada
284 | yapf==0.29.0
285 | yattag==1.15.1
286 | zarr==2.11.3
287 | zipp==3.8.0+computecanada
288 | zope.interface==5.4.0+computecanada
289 |
--------------------------------------------------------------------------------