├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── flashrl ├── __init__.py ├── envs │ ├── __init__.py │ ├── grid.pyx │ ├── multigrid.pyx │ └── pong.pyx ├── main.py ├── models.py └── utils.py ├── pyproject.toml ├── requirements.txt ├── setup.py ├── test ├── __init__.py └── test.py └── train.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*.*.*' 7 | 8 | jobs: 9 | build-wheels: 10 | name: Build wheels on ${{ matrix.os }} for Python ${{ matrix.python-version }} 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest, windows-latest, macos-latest] 15 | python-version: ['3.10', '3.11', '3.12', '3.13'] 16 | exclude: 17 | - os: ubuntu-latest 18 | python-version: '3.10' 19 | - os: ubuntu-latest 20 | python-version: '3.11' 21 | - os: ubuntu-latest 22 | python-version: '3.13' 23 | fail-fast: false 24 | 25 | steps: 26 | - name: Checkout code 27 | uses: actions/checkout@v4 28 | 29 | - name: Set up Python ${{ matrix.python-version || '3.12' }} 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: ${{ matrix.python-version || '3.12' }} 33 | 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install build setuptools wheel cython numpy 38 | 39 | - name: Install cibuildwheel (Linux only) 40 | if: matrix.os == 'ubuntu-latest' 41 | run: | 42 | pip install cibuildwheel 43 | 44 | - name: Build wheels (Linux with cibuildwheel) 45 | if: matrix.os == 'ubuntu-latest' 46 | run: | 47 | python -m cibuildwheel --output-dir dist 48 | env: 49 | CIBW_BUILD: "cp310-* cp311-* cp312-* cp313-*" 50 | CIBW_ARCHS: auto64 51 | CIBW_SKIP: "*-musllinux_*" 52 | CIBW_BUILD_VERBOSITY: 1 53 | 54 | - name: Build wheels (Windows and macOS) 55 | if: matrix.os != 'ubuntu-latest' 56 | run: | 57 | python -m build --wheel 58 | 59 | - name: Upload wheels as artifacts 60 | uses: actions/upload-artifact@v4 61 | with: 62 | name: wheels-${{ matrix.os }}-${{ matrix.python-version || 'all' }} 63 | path: dist/*.whl 64 | 65 | publish: 66 | needs: build-wheels 67 | runs-on: ubuntu-latest 68 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') 69 | steps: 70 | - name: Download all wheel artifacts 71 | uses: actions/download-artifact@v4 72 | with: 73 | path: dist 74 | pattern: wheels-* 75 | merge-multiple: true 76 | 77 | - name: Publish to PyPI 78 | env: 79 | TWINE_USERNAME: __token__ 80 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 81 | run: | 82 | pip install twine 83 | twine upload dist/*.whl 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Lukas Fisch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | recursive-include flashrl/envs *.pyx *.c 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # flashrl 2 | `flashrl` does RL with **millions of steps/second 💨 while being tiny**: ~200 lines of code 3 | 4 | 🛠️ `pip install flashrl` or clone the repo & `pip install -r requirements.txt` 5 | - If cloned (or if envs changed), compile: `python setup.py build_ext --inplace` 6 | 7 | 💡 `flashrl` will always be **tiny**: **Read the code** (+paste into LLM) to understand it! 8 | ## Quick Start 🚀 9 | `flashrl` uses a `Learner` that holds an `env` and a `model` (default: `Policy` with LSTM) 10 | 11 | ```python 12 | import flashrl as frl 13 | 14 | learn = frl.Learner(frl.envs.Pong(n_agents=2**14)) 15 | curves = learn.fit(40, steps=16, desc='done') 16 | frl.print_curve(curves['loss'], label='loss') 17 | frl.play(learn.env, learn.model, fps=8) 18 | learn.env.close() 19 | ``` 20 | `.fit` does RL with ~**10 million steps**: `40` iterations × `16` steps × `2**14` agents! 21 | 22 |
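(That is 40 × 16 × 2**14 = 10,485,760 ≈ 10 million environment steps per `.fit` call.)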
25 | 26 | **Run it yourself via `python train.py` and play against the AI** 🪄 27 | 28 |
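After training, you can take over one of the paddles yourself. A minimal sketch (building on the Quick Start `learn`; in `Pong`, `w`/`s` move your paddle, `m` lets the model act, `q` quits):

```python
# assumes `learn` from the Quick Start above is already trained
frl.play(learn.env, learn.model, playable=True, fps=8)
```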
29 | Click here to read a tiny doc 📑
30 | 
31 | `Learner` takes the arguments
32 | - `env`: RL environment
33 | - `model`: A `Policy` model
34 | - `device`: By default, picks `mps` or `cuda` if available, else `cpu`
35 | - `dtype`: By default, `torch.bfloat16` if the device is `cuda`, else `torch.float32`
36 | - `compile_no_lstm`: Speedup via `torch.compile` if `model` has no `lstm`
37 | - `**kwargs`: Passed to the `Policy`, e.g. `hidden_size` or `lstm`
38 | 
39 | `Learner.fit` takes the arguments
40 | - `iters`: Number of iterations
41 | - `steps`: Number of steps in `rollout`
42 | - `desc`: Progress bar description (e.g. `'reward'`)
43 | - `log`: If `True`, `tensorboard` logging is enabled
44 |   - run `tensorboard --logdir=runs` and visit `http://localhost:6006` in the browser!
45 | - `stop_func`: Function that stops training if it returns `True`, e.g.
46 | 
47 | ```python
48 | ...
49 | def stop(kl, **kwargs):
50 |     return kl > .1
51 | 
52 | curves = learn.fit(40, steps=16, stop_func=stop)
53 | ...
54 | ```
55 | - `lr`, `anneal_lr` & the args of `ppo` after `bs`: Hyperparameters
56 | 
57 | The most important functions in `flashrl/utils.py` are
58 | - `print_curve`: Visualizes the loss across the `iters`
59 | - `play`: Plays the environment in the terminal and takes
60 |   - `model`: A `Policy` model
61 |   - `playable`: If `True`, allows you to act (or decide to let the model act)
62 |   - `steps`: Number of steps
63 |   - `fps`: Frames per second
64 |   - `obs`: Name of the env attribute that should be rendered as observations
65 |   - `dump`: If `True`, no frame refresh -> frames accumulate in the terminal
66 |   - `idx`: Agent index between `0` and `n_agents` (default: `0`)
67 | 
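Putting several of these arguments together, a small CPU-only run with `tensorboard` logging could look like this (a sketch; the values are illustrative, not tuned):

```python
import flashrl as frl

env = frl.envs.Grid(n_agents=2**12)
learn = frl.Learner(env, device='cpu', hidden_size=64, lstm=False)  # kwargs are passed to Policy
curves = learn.fit(10, steps=16, log=True, desc='loss')  # writes TensorBoard logs to runs/
frl.print_curve(curves['loss'], label='loss')
env.close()
```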
68 | 69 | ## Environments 🕹️ 70 | **Each env is one Cython(=`.pyx`) file** in `flashrl/envs`. **That's it!** 71 | 72 | To **add custom envs**, use `grid.pyx`, `pong.pyx` or `multigrid.pyx` as a **template**: 73 | - `grid.pyx` for **single-agent** envs (~110 LOC) 74 | - `pong.pyx` for **1 vs 1 agent** envs (~150 LOC) 75 | - `multigrid.pyx` for **multi-agent** envs (~190 LOC) 76 | 77 | | `Grid` | `Pong` | `MultiGrid` | 78 | |-----------------------|-----------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------| 79 | | Agent must reach goal | Agent must score | Agent must reach goal first | 80 | |![grid](https://github.com/user-attachments/assets/f51c9fea-0ab9-45a1-a52e-446cee9fc593)| ![pong](https://github.com/user-attachments/assets/e77332d4-a3f4-432a-b338-98a078fb7dfb)| ![multigrid](https://github.com/user-attachments/assets/bc67c5e5-e820-4cfe-875c-1e545fbddff3)| 81 | 82 | ## Acknowledgements 🙌 83 | I want to thank 84 | - [Joseph Suarez](https://github.com/jsuarez5341) for open sourcing RL envs in C(ython)! Star [PufferLib](https://github.com/PufferAI/PufferLib) ⭐ 85 | - [Costa Huang](https://github.com/vwxyzjn) for open sourcing high-quality single-file RL code! Star [cleanrl](https://github.com/vwxyzjn/cleanrl) ⭐ 86 | 87 | and last but not least... 88 | 89 |
92 | -------------------------------------------------------------------------------- /flashrl/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrl import envs 2 | from .main import Learner 3 | from .models import Policy 4 | from .utils import play, render, set_seed, print_curve 5 | -------------------------------------------------------------------------------- /flashrl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | emoji_maps = {'grid': {0: ' ', 1: '🦠', 2: '🍪'}, 2 | 'pong': {0: ' ', 1: '🔲', 2: '🔴'}, 3 | 'multigrid': {0: ' ', 1: '🧱', 2: '🦠', 3: '🍪'}} 4 | key_maps = {'grid': {'a': 1, 'd': 2, 'w': 3, 's': 4}, 5 | 'pong': {'w': 1, 's': 2}, 6 | 'multigrid': {'a': 1, 'd': 2, 'w': 3, 's': 4}} 7 | from .grid import Grid 8 | from .pong import Pong 9 | from .multigrid import MultiGrid 10 | -------------------------------------------------------------------------------- /flashrl/envs/grid.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | import numpy as np 3 | cimport numpy as np 4 | 5 | from libc.stdlib cimport free, srand, calloc 6 | 7 | cdef extern from *: 8 | ''' 9 | #include 10 | #include 11 | 12 | const char AGENT = 1, GOAL = 2; 13 | const unsigned char LEFT = 1, RIGHT = 2, UP = 3, DOWN = 4; 14 | 15 | typedef struct { 16 | char *obs, *reward, *done; 17 | int size, t, x, y, goal_x, goal_y; 18 | } CGrid; 19 | 20 | void c_reset(CGrid* env) { 21 | env->t = 0; 22 | memset(env->obs, 0, env->size * env->size); 23 | env->x = env->y = env->size / 2; 24 | env->obs[env->x + env->y * env->size] = AGENT; 25 | env->goal_x = rand() % env->size; 26 | env->goal_y = rand() % env->size; 27 | if (env->goal_x == env->x && env->goal_y == env->y) env->goal_x++; 28 | env->obs[env->goal_x + env->goal_y * env->size] = GOAL; 29 | } 30 | 31 | void c_step(CGrid* env, unsigned char act) { 32 | env->reward[0] = 0; 33 | env->done[0] = 0; 34 | env->obs[env->x + env->y * env->size] = 0; 35 | if (act == LEFT) env->x--; 36 | else if (act == RIGHT) env->x++; 37 | else if (act == UP) env->y--; 38 | else if (act == DOWN) env->y++; 39 | if (env->t > 3 * env->size || env->x < 0 || env->y < 0 || env->x >= env->size || env->y >= env->size) { 40 | env->reward[0] = -1; 41 | env->done[0] = 1; 42 | c_reset(env); 43 | return; 44 | } 45 | int position = env->x + env->y * env->size; 46 | if (env->obs[position] == GOAL) { 47 | env->reward[0] = 1; 48 | env->done[0] = 1; 49 | c_reset(env); 50 | return; 51 | } 52 | env->obs[position] = AGENT; 53 | env->t++; 54 | } 55 | ''' 56 | 57 | ctypedef struct CGrid: 58 | char *obs 59 | char *reward 60 | char *done 61 | int size, t, x, y, goal_x, goal_y 62 | 63 | void c_reset(CGrid *env) 64 | void c_step(CGrid *env, unsigned char act) 65 | 66 | 67 | cdef class Grid: 68 | cdef: 69 | CGrid *envs 70 | int n_agents, _n_acts 71 | np.ndarray obs_arr, rewards_arr, dones_arr 72 | cdef char[:, :, :] obs_memview 73 | cdef char[:] rewards_memview 74 | cdef char[:] dones_memview 75 | int size 76 | 77 | def __init__(self, n_agents=2**14, n_acts=5, size=8): 78 | self.envs = calloc(n_agents, sizeof(CGrid)) 79 | self.n_agents = n_agents 80 | self._n_acts = n_acts 81 | self.obs_arr = np.zeros((n_agents, size, size), dtype=np.int8) 82 | self.rewards_arr = np.zeros(n_agents, dtype=np.int8) 83 | self.dones_arr = np.zeros(n_agents, dtype=np.int8) 84 | self.obs_memview = self.obs_arr 85 | self.rewards_memview = self.rewards_arr 86 | 
self.dones_memview = self.dones_arr 87 | cdef int i 88 | for i in range(n_agents): 89 | env = &self.envs[i] 90 | env.obs = &self.obs_memview[i, 0, 0] 91 | env.reward = &self.rewards_memview[i] 92 | env.done = &self.dones_memview[i] 93 | env.size = size 94 | 95 | def reset(self, seed=None): 96 | if seed is not None: 97 | srand(seed) 98 | cdef int i 99 | for i in range(self.n_agents): 100 | c_reset(&self.envs[i]) 101 | return self 102 | 103 | def step(self, np.ndarray acts): 104 | cdef unsigned char[:] acts_memview = acts 105 | cdef int i 106 | for i in range(self.n_agents): 107 | c_step(&self.envs[i], acts_memview[i]) 108 | 109 | def close(self): 110 | free(self.envs) 111 | 112 | @property 113 | def obs(self): return self.obs_arr 114 | 115 | @property 116 | def rewards(self): return self.rewards_arr 117 | 118 | @property 119 | def dones(self): return self.dones_arr 120 | 121 | @property 122 | def n_acts(self): return self._n_acts 123 | -------------------------------------------------------------------------------- /flashrl/envs/multigrid.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | import numpy as np 3 | cimport numpy as np 4 | 5 | from libc.stdlib cimport free, srand, calloc 6 | 7 | cdef extern from *: 8 | ''' 9 | #include 10 | #include 11 | #include 12 | 13 | const char WALL = 1, AGENT = 2, GOAL = 3; 14 | const unsigned char NOOP = 0, DOWN = 1, UP = 2, LEFT = 3, RIGHT = 4; 15 | 16 | typedef struct { 17 | char *obs, *rewards, *dones; 18 | unsigned char *x, *y; 19 | char *total_obs; 20 | int n_agents_per_env, vision, size, t, goal_x, goal_y; 21 | } CMultiGrid; 22 | 23 | void get_obs(CMultiGrid* env) { 24 | int ob_size = 2 * env->vision + 1; 25 | int ob_pixels = ob_size * ob_size; 26 | memset(env->obs, 0, env->n_agents_per_env * ob_pixels); 27 | int center = ob_pixels / 2; 28 | for (int i = 0; i < env->n_agents_per_env; i++) { 29 | for (int x = -env->vision; x <= env->vision; x++) { 30 | for (int y = -env->vision; y <= env->vision; y++) { 31 | char world_x = env->x[i] + x; 32 | char world_y = env->y[i] + y; 33 | if (world_x < 0 || world_x > env->size - 1 || world_y < 0 || world_y > env->size - 1) { 34 | env->obs[i * ob_pixels + center + x * ob_size + y] = WALL; 35 | } 36 | } 37 | } 38 | for (int j = 0; j < env->n_agents_per_env; j++) { 39 | char dx = env->x[j] - env->x[i]; 40 | char dy = env->y[j] - env->y[i]; 41 | if (abs(dx) <= env->vision && abs(dy) <= env->vision) { 42 | env->obs[i * ob_pixels + center + dx * ob_size + dy] = AGENT; 43 | } 44 | } 45 | char dx = env->goal_x - env->x[i]; 46 | char dy = env->goal_y - env->y[i]; 47 | if (abs(dx) <= env->vision && abs(dy) <= env->vision) { 48 | env->obs[i * ob_pixels + center + dx * ob_size + dy] = GOAL; 49 | } 50 | } 51 | } 52 | 53 | void get_total_obs(CMultiGrid* env) { 54 | memset(env->total_obs, 0, env->size * env->size); 55 | for (int i = 0; i < env->n_agents_per_env; i++) { 56 | env->total_obs[env->y[i] + env->x[i] * env->size] = AGENT; 57 | } 58 | env->total_obs[env->goal_y + env->goal_x * env->size] = GOAL; 59 | } 60 | 61 | void c_reset(CMultiGrid* env, char with_total_obs) { 62 | env->t = 0; 63 | env->goal_x = rand() % env->size; 64 | env->goal_y = rand() % env->size; 65 | for (int i = 0; i < env->n_agents_per_env; i++) { 66 | env->x[i] = rand() % env->size; 67 | env->y[i] = rand() % env->size; 68 | } 69 | get_obs(env); 70 | if (with_total_obs > 0) { 71 | get_total_obs(env); 72 | } 73 | } 74 | 75 | void agent_step(CMultiGrid* env, unsigned char *acts, int i, 
char with_total_obs) { 76 | env->rewards[i] = 0; 77 | env->dones[i] = 0; 78 | unsigned char act = acts[i]; 79 | if (act == LEFT) env->x[i]--; 80 | else if (act == RIGHT) env->x[i]++; 81 | else if (act == UP) env->y[i]++; 82 | else if (act == DOWN) env->y[i]--; 83 | if (env->t > 3 * env->size || env->x[i] < 0 || env->y[i] < 0 || env->x[i] >= env->size || env->y[i] >= env->size) { 84 | env->dones[i] = 1; 85 | env->rewards[i] = -1; 86 | c_reset(env, with_total_obs); 87 | return; 88 | } 89 | if (env->x[i] == env->goal_x && env->y[i] == env->goal_y) { 90 | env->dones[i] = 1; 91 | env->rewards[i] = 1; 92 | c_reset(env, with_total_obs); 93 | return; 94 | } 95 | } 96 | 97 | void c_step(CMultiGrid* env, unsigned char *acts, bool with_total_obs) { 98 | for (int i = 0; i < env->n_agents_per_env; i++){ 99 | agent_step(env, acts, i, with_total_obs); 100 | } 101 | get_obs(env); 102 | if (with_total_obs > 0) { 103 | get_total_obs(env); 104 | } 105 | env->t++; 106 | } 107 | ''' 108 | 109 | ctypedef struct CMultiGrid: 110 | char *obs 111 | char *rewards 112 | char *dones 113 | unsigned char *x 114 | unsigned char *y 115 | char *total_obs 116 | int n_agents_per_env, vision, size, t, goal_x, goal_y 117 | 118 | void c_reset(CMultiGrid* env, char with_total_obs) 119 | void c_step(CMultiGrid* env, unsigned char *acts, char with_total_obs) 120 | 121 | cdef class MultiGrid: 122 | cdef: 123 | CMultiGrid* envs 124 | int n_agents, _n_acts, _n_agents_per_env 125 | np.ndarray obs_arr, rewards_arr, dones_arr, x_arr, y_arr, total_obs_arr 126 | cdef char[:, :, :] obs_memview 127 | cdef char[:] rewards_memview 128 | cdef char[:] dones_memview 129 | cdef unsigned char[:] x_memview 130 | cdef unsigned char[:] y_memview 131 | cdef char[:, :, :] total_obs_memview 132 | int size 133 | bint with_total_obs 134 | 135 | def __init__(self, n_agents=2**14, n_acts=5, n_agents_per_env=2, vision=3, size=8): 136 | self.envs = calloc(n_agents // n_agents_per_env, sizeof(CMultiGrid)) 137 | self.n_agents = n_agents 138 | self._n_acts = n_acts 139 | self._n_agents_per_env = n_agents_per_env 140 | self.obs_arr = np.zeros((n_agents, 2*vision+1, 2*vision+1), dtype=np.int8) 141 | self.rewards_arr = np.zeros(n_agents, dtype=np.int8) 142 | self.dones_arr = np.zeros(n_agents, dtype=np.int8) 143 | self.x_arr = np.zeros(n_agents, dtype=np.uint8) 144 | self.y_arr = np.zeros(n_agents, dtype=np.uint8) 145 | self.total_obs_arr = np.zeros((n_agents // n_agents_per_env, size, size), dtype=np.int8) 146 | self.obs_memview = self.obs_arr 147 | self.rewards_memview = self.rewards_arr 148 | self.dones_memview = self.dones_arr 149 | self.x_memview = self.x_arr 150 | self.y_memview = self.y_arr 151 | self.total_obs_memview = self.total_obs_arr 152 | cdef int i, i_agent 153 | for i in range(n_agents // n_agents_per_env): 154 | env = &self.envs[i] 155 | i_agent = n_agents_per_env * i 156 | env.obs = &self.obs_memview[i_agent, 0, 0] 157 | env.rewards = &self.rewards_memview[i_agent] 158 | env.dones = &self.dones_memview[i_agent] 159 | env.x = &self.x_memview[i_agent] 160 | env.y = &self.y_memview[i_agent] 161 | env.total_obs = &self.total_obs_memview[i, 0, 0] 162 | env.n_agents_per_env = n_agents_per_env 163 | env.vision = vision 164 | env.size = size 165 | 166 | def reset(self, seed=None, with_total_obs=False): 167 | if seed is not None: 168 | srand(seed) 169 | cdef int i 170 | for i in range(self.n_agents // self.n_agents_per_env): 171 | c_reset(&self.envs[i], with_total_obs) 172 | return self 173 | 174 | def step(self, np.ndarray acts, with_total_obs=False): 175 
| cdef unsigned char[:] acts_memview = acts 176 | cdef int i 177 | for i in range(self.n_agents // self.n_agents_per_env): 178 | c_step(&self.envs[i], &acts_memview[self.n_agents_per_env * i], with_total_obs) 179 | 180 | def close(self): 181 | free(self.envs) 182 | 183 | @property 184 | def obs(self): return self.obs_arr 185 | 186 | @property 187 | def rewards(self): return self.rewards_arr 188 | 189 | @property 190 | def dones(self): return self.dones_arr 191 | 192 | @property 193 | def n_acts(self): return self._n_acts 194 | 195 | @property 196 | def n_agents_per_env(self): return self._n_agents_per_env 197 | 198 | @property 199 | def total_obs(self): return self.total_obs_arr 200 | -------------------------------------------------------------------------------- /flashrl/envs/pong.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | import numpy as np 3 | cimport numpy as np 4 | 5 | from libc.stdlib cimport free, srand, calloc 6 | 7 | cdef extern from *: 8 | ''' 9 | #include 10 | #include 11 | #include 12 | 13 | const char PADDLE = 1, BALL = 2; 14 | const unsigned char UP = 1, DOWN = 2; 15 | 16 | typedef struct { 17 | char *obs0, *obs1, *reward0, *reward1, *done0, *done1; 18 | int size_x, size_y, t, paddle0_x, paddle0_y, paddle1_x, paddle1_y, x, dx; 19 | float y, dy, max_dy; 20 | } CPong; 21 | 22 | void set_obs(CPong* env, char paddle, char ball) { 23 | for (int i = -1; i < 2; i++) { 24 | if (env->paddle0_y + i >= 0 && env->paddle0_y + i <= env->size_y - 1) { 25 | env->obs0[(env->size_x - 1) - env->paddle0_x + (env->paddle0_y + i) * env->size_x] = paddle; 26 | env->obs1[env->paddle0_x + (env->paddle0_y + i) * env->size_x] = paddle; 27 | } 28 | if (env->paddle1_y + i >= 0 && env->paddle1_y + i <= env->size_y - 1) { 29 | env->obs0[(env->size_x - 1) - env->paddle1_x + (env->paddle1_y + i) * env->size_x] = paddle; 30 | env->obs1[env->paddle1_x + (env->paddle1_y + i) * env->size_x] = paddle; 31 | } 32 | } 33 | env->obs0[(env->size_x - 1) - env->x + (int)(roundf(env->y)) * env->size_x] = ball; 34 | env->obs1[env->x + (int)(roundf(env->y)) * env->size_x] = ball; 35 | } 36 | 37 | void c_reset(CPong* env) { 38 | env->t = 0; 39 | memset(env->obs0, 0, env->size_x * env->size_y); 40 | memset(env->obs1, 0, env->size_x * env->size_y); 41 | env->x = env->size_x / 2; 42 | env->y = rand() % (env->size_y - 1); 43 | env->dx = (rand() % 2) ? 
1 : -1; 44 | env->dy = 2.0f * ((float)rand() / RAND_MAX) - 1.0f; 45 | env->paddle0_x = 0; 46 | env->paddle1_x = env->size_x - 1; 47 | env->paddle0_y = env->paddle1_y = env->size_y / 2; 48 | set_obs(env, PADDLE, BALL); 49 | } 50 | 51 | void c_step(CPong* env, unsigned char act0, unsigned char act1) { 52 | env->reward0[0] = env->reward1[0] = 0; 53 | env->done0[0] = env->done1[0] = 0; 54 | set_obs(env, 0, 0); 55 | if (act0 == UP && env->paddle0_y > 0) env->paddle0_y--; 56 | if (act0 == DOWN && env->paddle0_y < env->size_y - 2) env->paddle0_y++; 57 | if (act1 == UP && env->paddle1_y > 0) env->paddle1_y--; 58 | if (act1 == DOWN && env->paddle1_y < env->size_y - 2) env->paddle1_y++; 59 | env->dy = fminf(fmaxf(env->dy, -env->max_dy), env->max_dy); 60 | env->x += env->dx; 61 | env->y += env->dy; 62 | env->y = fminf(fmaxf(env->y, 0.f), env->size_y - 1.f); 63 | if (env->y <= 0 || env->y >= env->size_y - 1) env->dy = -env->dy; 64 | if (env->x == 1 && env->y >= env->paddle0_y - 1 && env->y <= env->paddle0_y + 1) { 65 | env->dx = -env->dx; 66 | env->dy += env->y - env->paddle0_y; 67 | } 68 | if (env->x == env->size_x - 2 && env->y >= env->paddle1_y - 1 && env->y <= env->paddle1_y + 1) { 69 | env->dx = -env->dx; 70 | env->dy += env->y - env->paddle1_y; 71 | } 72 | if (env->x == 0 || env->x == env->size_x - 1) { 73 | env->reward1[0] = 2 * (char)(env->x == 0) - 1; 74 | env->reward0[0] = -env->reward1[0]; 75 | env->done0[0] = env->done1[0] = 1; 76 | c_reset(env); 77 | } 78 | set_obs(env, PADDLE, BALL); 79 | env->t++; 80 | } 81 | ''' 82 | 83 | ctypedef struct CPong: 84 | char *obs0 85 | char *obs1 86 | char *reward0 87 | char *reward1 88 | char *done0 89 | char *done1 90 | int size_x, size_y, t, paddle0_x, paddle0_y, paddle1_x, paddle1_y, x, dx 91 | float y, dy, max_dy 92 | 93 | void c_reset(CPong* env) 94 | void c_step(CPong* env, unsigned char act0, unsigned char act1) 95 | 96 | cdef class Pong: 97 | cdef: 98 | CPong* envs 99 | int n_agents, _n_acts 100 | np.ndarray obs_arr, rewards_arr, dones_arr 101 | cdef char[:, :, :] obs_memview 102 | cdef char[:] rewards_memview 103 | cdef char[:] dones_memview 104 | int size_x, size_y 105 | float max_dy 106 | 107 | def __init__(self, n_agents=2**14, n_acts=3, size_x=16, size_y=8, max_dy=1.): 108 | self.envs = calloc(n_agents // 2, sizeof(CPong)) 109 | self.n_agents = n_agents 110 | self._n_acts = n_acts 111 | self.obs_arr = np.zeros((n_agents, size_y, size_x), dtype=np.int8) 112 | self.rewards_arr = np.zeros(n_agents, dtype=np.int8) 113 | self.dones_arr = np.zeros(n_agents, dtype=np.int8) 114 | self.obs_memview = self.obs_arr 115 | self.rewards_memview = self.rewards_arr 116 | self.dones_memview = self.dones_arr 117 | cdef int i 118 | for i in range(n_agents // 2): 119 | env = &self.envs[i] 120 | env.obs0, env.obs1 = &self.obs_memview[2 * i, 0, 0], &self.obs_memview[2 * i + 1, 0, 0] 121 | env.reward0, env.reward1 = &self.rewards_memview[2 * i], &self.rewards_memview[2 * i + 1] 122 | env.done0, env.done1 = &self.dones_memview[2 * i], &self.dones_memview[2 * i + 1] 123 | env.size_x = size_x 124 | env.size_y = size_y 125 | env.max_dy = max_dy 126 | 127 | def reset(self, seed=None): 128 | if seed is not None: 129 | srand(seed) 130 | cdef int i 131 | for i in range(self.n_agents // 2): 132 | c_reset(&self.envs[i]) 133 | return self 134 | 135 | def step(self, np.ndarray acts): 136 | cdef unsigned char[:] acts_memview = acts 137 | cdef int i 138 | for i in range(self.n_agents // 2): 139 | c_step(&self.envs[i], acts_memview[2 * i], acts_memview[2 * i + 1]) 140 | 141 | def 
close(self): 142 | free(self.envs) 143 | 144 | @property 145 | def obs(self): return self.obs_arr 146 | 147 | @property 148 | def rewards(self): return self.rewards_arr 149 | 150 | @property 151 | def dones(self): return self.dones_arr 152 | 153 | @property 154 | def n_acts(self): return self._n_acts 155 | -------------------------------------------------------------------------------- /flashrl/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | from .models import Policy 6 | DEVICE = 'mps' if torch.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu' 7 | 8 | 9 | class Learner: 10 | def __init__(self, env, model=None, device=None, dtype=None, compile_no_lstm=False, **kwargs): 11 | self.env = env 12 | self.device = DEVICE if device is None else device 13 | self.dtype = dtype if dtype is not None else torch.bfloat16 if self.device == 'cuda' else torch.float32 14 | self.model = Policy(self.env, **kwargs).to(self.device, self.dtype) if model is None else model 15 | if self.model.lstm is None and compile_no_lstm: # only no-lstm policy gets faster from torch.compile 16 | self.model = torch.compile(self.model, fullgraph=True, mode='reduce-overhead') 17 | self._data, self._np_data, self._rollout_state, self._ppo_state = None, None, None, None 18 | 19 | def fit(self, iters=40, steps=16, lr=.01, bs=None, anneal_lr=True, log=False, desc=None, stop_func=None, **hparams): 20 | bs = bs or len(self.env.obs) // 2 21 | self.setup_data(steps, bs) 22 | logger = SummaryWriter() if log else None 23 | opt = torch.optim.Adam(self.model.parameters(), lr=lr, eps=1e-5) 24 | pbar = tqdm(range(iters), total=iters) 25 | curves = [] 26 | for i in pbar: 27 | opt.param_groups[0]['lr'] = lr * (1 - i / iters) if anneal_lr else lr 28 | self.rollout(steps) 29 | losses = ppo(self.model, opt, bs=bs, state=self._ppo_state, **self._data, **hparams) 30 | if desc: pbar.set_description(f'{desc}: {losses[desc] if desc in losses else self._data[desc].mean():.3f}') 31 | if i: pbar.set_postfix_str(f'{1e-6 * self._data["act"].numel() * pbar.format_dict["rate"]:.1f}M steps/s') 32 | if log: 33 | for k, v in losses.items(): logger.add_scalar(k, v, global_step=i) 34 | for name, param in self.model.named_parameters(): logger.add_histogram(name, param, global_step=i) 35 | curves.append(losses) 36 | if stop_func is not None: 37 | if stop_func(**self._data, **losses): break 38 | return {k: [m[k].item() for m in curves] for k in curves[0]} 39 | 40 | def setup_data(self, steps, bs=None): 41 | x = torch.zeros((len(self.env.obs), steps), dtype=self.dtype, device=self.device) 42 | obs = torch.zeros((*x.shape, *self.env.obs.shape[1:]), dtype=self.dtype, device=self.device) 43 | self._data = {'obs': obs, 'act': x.clone().byte(), 'logprob': x.clone(), 'value': x} 44 | self._np_data = {'reward': x.char().cpu().numpy(), 'done': x.char().cpu().numpy()} 45 | if self.model.lstm is not None: 46 | zeros = torch.zeros((len(obs), self.model.encoder.out_features), dtype=self.dtype, device=self.device) 47 | self._rollout_state = (zeros, zeros.clone()) 48 | if bs is not None: 49 | zeros = torch.zeros((bs, self.model.encoder.out_features), dtype=self.dtype, device=self.device) 50 | self._ppo_state = (zeros, zeros.clone()) 51 | 52 | def rollout(self, steps, state=None, extra_args_list=None, **kwargs): 53 | state = self._rollout_state if state is None else state 54 | if steps != (0 if self._data is None else 
self._data['obs'].shape[1]): self.setup_data(steps) 55 | extra_data = {} if extra_args_list is None else {k: [] for k in extra_args_list} 56 | for i in range(steps): 57 | o = self.to_torch(self.env.obs) 58 | with torch.no_grad(): 59 | act, logp, _, value, state = self.model(o, state=state) 60 | self._data['obs'][:, i] = o 61 | self._data['act'][:, i] = act 62 | self._data['logprob'][:, i] = logp 63 | self._data['value'][:, i] = value 64 | self._np_data['reward'][:, i] = self.env.rewards 65 | self._np_data['done'][:, i] = self.env.dones 66 | for k in extra_data: extra_data[k].append(self.to_torch(getattr(self.env, k).copy())) 67 | self.env.step(act.cpu().numpy(), **kwargs) 68 | self._data.update({k: self.to_torch(v) for k, v in self._np_data.items()}) 69 | return {k: torch.stack(v, dim=1) for k, v in extra_data.items()} 70 | 71 | def to_torch(self, x, non_blocking=True): 72 | return torch.from_numpy(x).to(device=self.device, dtype=self.dtype, non_blocking=non_blocking) 73 | 74 | 75 | def ppo(model, opt, obs, value, act, logprob, reward, done, bs=2**13, gamma=.99, gae_lambda=.95, clip_coef=.1, 76 | value_coef=.5, value_clip_coef=.1, entropy_coef=.01, max_grad_norm=.5, norm_adv=True, state=None): 77 | advs = get_advantages(value, reward, done, gamma=gamma, gae_lambda=gae_lambda) 78 | obs, value, act, logprob, advs = [xs.view(-1, bs, *xs.shape[2:]) for xs in [obs, value, act, logprob, advs]] 79 | returns = advs + value 80 | metrics, metric_keys = [], ['loss', 'policy_loss', 'value_loss', 'entropy_loss', 'kl'] 81 | for o, old_value, a, old_logp, adv, ret in zip(obs, value, act, logprob, advs, returns): 82 | _, logp, entropy, val, state = model(o, state=state, act=a) 83 | state = state if model.lstm is None else (state[0].detach(), state[1].detach()) 84 | logratio = logp - old_logp 85 | ratio = logratio.exp() 86 | adv = (adv - adv.mean()) / (adv.std() + 1e-8) if norm_adv else adv 87 | policy_loss = torch.max(-adv * ratio, -adv * ratio.clip(1 - clip_coef, 1 + clip_coef)).mean() 88 | if value_clip_coef: 89 | v_clipped = old_value + (val - old_value).clip(-value_clip_coef, value_clip_coef) 90 | value_loss = .5 * torch.max((val - ret) ** 2, (v_clipped - ret) ** 2).mean() 91 | else: 92 | value_loss = .5 * ((val - ret) ** 2).mean() 93 | entropy = entropy.mean() 94 | loss = policy_loss + value_coef * value_loss - entropy_coef * entropy 95 | opt.zero_grad() 96 | loss.backward() 97 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) 98 | opt.step() 99 | kl = ((ratio - 1) - logratio).mean() 100 | metrics.append([loss, policy_loss, value_loss, entropy, kl]) 101 | return {k: torch.stack([values[i] for values in metrics]).mean() for i, k in enumerate(metric_keys)} 102 | 103 | 104 | def get_advantages(value, reward, done, gamma=.99, gae_lambda=.95): # see arxiv.org/abs/1506.02438 eq. (16)-(18) 105 | advs = torch.zeros_like(value) 106 | not_done = 1. 
- done 107 | for t in range(1, done.shape[1]): 108 | delta = reward[:, -t] + gamma * value[:, -t] * not_done[:, -t] - value[:, -t - 1] 109 | advs[:, -t-1] = delta + gamma * gae_lambda * not_done[:, -t] * advs[:, -t] 110 | return advs 111 | -------------------------------------------------------------------------------- /flashrl/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | class Policy(torch.nn.Module): 6 | def __init__(self, env, hidden_size=128, lstm=False): 7 | super().__init__() 8 | self.encoder = torch.nn.Linear(math.prod(env.obs.shape[1:]), hidden_size) 9 | self.actor = torch.nn.Linear(hidden_size, env.n_acts) 10 | self.value_head = torch.nn.Linear(hidden_size, 1) 11 | self.lstm = cleanrl_init(torch.nn.LSTMCell(hidden_size, hidden_size)) if lstm else None 12 | 13 | def forward(self, x, state, act=None, with_entropy=None): 14 | with_entropy = act is not None if with_entropy is None else with_entropy 15 | h = self.encoder(x.view(len(x), -1)).relu() 16 | h, c = (h, None) if self.lstm is None else self.lstm(h, state) 17 | value = self.value_head(h)[:, 0] 18 | x = self.actor(h) 19 | act = torch.multinomial(x.softmax(dim=-1), 1)[:, 0] if act is None else act 20 | x = x - x.logsumexp(dim=-1, keepdim=True) 21 | logprob = x.gather(-1, act[..., None].long())[..., 0] 22 | entropy = -(x * x.softmax(dim=-1)).sum(-1) if with_entropy else None 23 | return act.byte(), logprob, entropy, value, (h, c) 24 | 25 | 26 | def cleanrl_init(module): 27 | for name, param in module.named_parameters(): 28 | if 'bias' in name: torch.nn.init.constant_(param, 0) 29 | elif 'weight' in name: torch.nn.init.orthogonal_(param, 1) 30 | return module 31 | -------------------------------------------------------------------------------- /flashrl/utils.py: -------------------------------------------------------------------------------- 1 | import sys, time, random, platform 2 | import torch 3 | import plotille 4 | import numpy as np 5 | 6 | from .envs import key_maps, emoji_maps 7 | 8 | 9 | def play(env, model=None, playable=False, steps=None, fps=4, obs='obs', dump=False, with_data=True, idx=0, **kwargs): 10 | key_map = key_maps[env.__class__.__name__.lower()] 11 | emoji_map = emoji_maps[env.__class__.__name__.lower()] 12 | if playable: print(f'Press {"".join(key_map)} to act, m for model act and q to quit') 13 | data, state = {}, None 14 | for i in range((10000 if playable else 64) if steps is None else steps): 15 | data.update({'step': i}) 16 | render(getattr(env, obs)[idx], cursor_up=i and not dump, emoji_map=emoji_map, data=data if with_data else None) 17 | acts = np.zeros(len(env.obs), dtype=np.uint8) 18 | if model is not None: 19 | o = torch.from_numpy(env.obs).to(device=model.actor.weight.device, dtype=model.actor.weight.dtype) 20 | with torch.no_grad(): acts, logp, entropy, val, state = model(o, state=state, with_entropy=True) 21 | data.update({'model act': acts[idx], 'logp': logp[idx], 'entropy': entropy[idx], 'value': val[idx]}) 22 | acts = acts.cpu().numpy() 23 | key = get_pressed_key() if playable else f'm{time.sleep(1 / fps)}'[:1] 24 | if key == 'q': break 25 | acts[idx] = acts[idx] if key == 'm' else key_map[key] if key in key_map else 0 26 | env.step(acts, **kwargs) 27 | data.update({'act': acts[idx], 'reward': env.rewards[idx], 'done': env.dones[idx]}) 28 | 29 | 30 | def render(ob, cursor_up=True, emoji_map=None, data=None): 31 | if cursor_up: print(f'\033[A\033[{len(ob)}A') 32 | ob = 23 * (ob - ob.min()) / (ob.max() - 
ob.min()) + 232 if emoji_map is None else ob 33 | for i, row in enumerate(ob): 34 | for o in row.tolist(): 35 | print(f'\033[48;5;{f"{232 + o}m" if emoji_map is None else f"232m{emoji_map[o]}"}\033[0m', end='') 36 | if data is not None: 37 | if i < len(data): 38 | print(f'{list(data.keys())[i]}: {list(data.values())[i]:.3g}', end=' ') 39 | print() 40 | 41 | 42 | def set_seed(seed): 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | if seed is not None: 46 | torch.manual_seed(seed) 47 | torch.backends.cudnn.deterministic = True 48 | 49 | 50 | def print_curve(array, label=None, height=8, width=65): 51 | fig = plotille.Figure() 52 | fig._height, fig._width = height, width 53 | fig.y_label = fig.y_label if label is None else label 54 | fig.scatter(list(range(len(array))), array) 55 | print('\n'.join(fig.show().split('\n')[:-2])) 56 | 57 | 58 | def get_pressed_key(): 59 | if platform.system() == 'Windows': 60 | import msvcrt 61 | key = msvcrt.getch() 62 | else: 63 | import tty, termios 64 | fd = sys.stdin.fileno() 65 | old_settings = termios.tcgetattr(fd) 66 | try: 67 | tty.setraw(sys.stdin.fileno()) 68 | key = sys.stdin.read(1) 69 | finally: 70 | termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) 71 | return key 72 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ['setuptools', 'wheel', 'Cython', 'numpy'] 3 | build-backend = 'setuptools.build_meta' 4 | 5 | [project] 6 | name = 'flashrl' 7 | version = '0.2.1' 8 | description = 'Fast reinforcement learning 💨' 9 | authors = [{name = 'codingfisch', email = 'l_fisc17@wwu.de'}] 10 | readme = 'README.md' 11 | dependencies = ['torch', 'Cython', 'tqdm', 'plotille', 'tensorboard'] 12 | 13 | [tool.setuptools] 14 | license-files = [] 15 | packages = ['flashrl', 'flashrl.envs'] 16 | include-package-data = true 17 | 18 | [tool.setuptools.package-data] 19 | flashrl = ['envs/*.pyx', 'envs/*.c'] 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | Cython 3 | tqdm 4 | plotille 5 | tensorboard 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from pathlib import Path 3 | from setuptools import setup, Extension 4 | from Cython.Build import cythonize 5 | 6 | kwargs = {'extra_compile_args': ['-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION', '-O1']} 7 | mods = [Extension(f'flashrl.envs.{Path(fp.stem)}', sources=[fp], **kwargs) for fp in Path('flashrl/envs').glob('*.pyx')] 8 | setup(ext_modules=cythonize(mods), packages=['flashrl', 'flashrl.envs'], include_dirs=[numpy.get_include()], 9 | include_package_data=True) 10 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codingfisch/flashrl/8769edd4383a178678b3146daf7a1023e5e5fb67/test/__init__.py -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flashrl.utils import play, print_curve 3 | from flashrl.models import Policy 4 | from flashrl.main import 
get_advantages, Learner 5 | from flashrl.envs.grid import Grid 6 | 7 | 8 | def test_print_ascii_curve(): 9 | array = [1, 2, 3, 4, 5] 10 | print_curve(array) 11 | 12 | def test_Policy_init(): 13 | env = Grid(n_agents=1, size=5) 14 | model = Policy(env, lstm=True) 15 | assert isinstance(model, Policy), 'Model initialization failed' 16 | assert model.actor.in_features == 128, 'Actor layer initialization failed' 17 | assert model.lstm.hidden_size == 128, 'LSTM layer hidden size is incorrect' 18 | 19 | def test_learner_init(): 20 | env = Grid(n_agents=1, size=5) 21 | learn = Learner(env) 22 | assert isinstance(learn.model, Policy), 'Model initialization failed' 23 | 24 | def test_learner_fit(): 25 | env = Grid(n_agents=2**13, size=5) 26 | learn = Learner(env) 27 | losses = learn.fit(iters=1) 28 | assert 'loss' in losses, 'Fit did not return metrics' 29 | 30 | def test_play(): 31 | env = Grid(n_agents=2**13, size=5) 32 | learn = Learner(env) 33 | learn.fit(iters=1) 34 | play(env, learn.model, steps=16) 35 | 36 | def test_get_advantages(): 37 | values = torch.rand(1, 10) 38 | rewards = torch.rand(1, 10) 39 | dones = torch.zeros(1, 10) 40 | advantages = get_advantages(values, rewards, dones) 41 | assert advantages is not None, 'Advantages calculation failed' 42 | assert advantages.shape == values.shape, 'Advantages shape mismatch' 43 | 44 | if __name__ == '__main__': 45 | test_print_ascii_curve() 46 | test_Policy_init() 47 | test_learner_init() 48 | test_learner_fit() 49 | test_play() 50 | test_get_advantages() 51 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import flashrl as frl 2 | frl.set_seed(SEED:=1) 3 | 4 | env = frl.envs.Pong(n_agents=2**14).reset(SEED) # try one of: Pong, Grid, MultiGrid 5 | learn = frl.Learner(env, hidden_size=128, lstm=True) # faster with lstm=False and smaller hidden_size 6 | curves = learn.fit(40, steps=16, desc='done') # ,lr=1e-2, gamma=.99) set hparams here 7 | frl.print_curve(curves['loss'], label='loss') 8 | frl.play(env, learn.model, fps=8, playable=False) # if env is MultiGrid, try obs='total_obs', with_total_obs=True 9 | env.close() 10 | --------------------------------------------------------------------------------
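As the comments in `train.py` hint, a `MultiGrid` variant of the same script could look roughly like this (a sketch; the hyperparameters are simply carried over from the Pong run, not tuned for `MultiGrid`):

```python
import flashrl as frl
frl.set_seed(SEED := 1)

env = frl.envs.MultiGrid(n_agents=2**14).reset(SEED, with_total_obs=True)
learn = frl.Learner(env, hidden_size=128, lstm=True)
curves = learn.fit(40, steps=16, desc='done')
frl.print_curve(curves['loss'], label='loss')
frl.play(env, learn.model, fps=8, obs='total_obs', with_total_obs=True)  # render the full grid
env.close()
```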