├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── flashrl
│   ├── __init__.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── grid.pyx
│   │   ├── multigrid.pyx
│   │   └── pong.pyx
│   ├── main.py
│   ├── models.py
│   └── utils.py
├── pyproject.toml
├── requirements.txt
├── setup.py
├── test
│   ├── __init__.py
│   └── test.py
└── train.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Build and Publish to PyPI
2 |
3 | on:
4 |   push:
5 |     tags:
6 |       - 'v*.*.*'
7 |
8 | jobs:
9 |   build-wheels:
10 |     name: Build wheels on ${{ matrix.os }} for Python ${{ matrix.python-version }}
11 |     runs-on: ${{ matrix.os }}
12 |     strategy:
13 |       matrix:
14 |         os: [ubuntu-latest, windows-latest, macos-latest]
15 |         python-version: ['3.10', '3.11', '3.12', '3.13']
16 |         exclude:
17 |           - os: ubuntu-latest
18 |             python-version: '3.10'
19 |           - os: ubuntu-latest
20 |             python-version: '3.11'
21 |           - os: ubuntu-latest
22 |             python-version: '3.13'
23 |       fail-fast: false
24 |
25 |     steps:
26 |       - name: Checkout code
27 |         uses: actions/checkout@v4
28 |
29 |       - name: Set up Python ${{ matrix.python-version || '3.12' }}
30 |         uses: actions/setup-python@v5
31 |         with:
32 |           python-version: ${{ matrix.python-version || '3.12' }}
33 |
34 |       - name: Install dependencies
35 |         run: |
36 |           python -m pip install --upgrade pip
37 |           pip install build setuptools wheel cython numpy
38 |
39 |       - name: Install cibuildwheel (Linux only)
40 |         if: matrix.os == 'ubuntu-latest'
41 |         run: |
42 |           pip install cibuildwheel
43 |
44 |       - name: Build wheels (Linux with cibuildwheel)
45 |         if: matrix.os == 'ubuntu-latest'
46 |         run: |
47 |           python -m cibuildwheel --output-dir dist
48 |         env:
49 |           CIBW_BUILD: "cp310-* cp311-* cp312-* cp313-*"
50 |           CIBW_ARCHS: auto64
51 |           CIBW_SKIP: "*-musllinux_*"
52 |           CIBW_BUILD_VERBOSITY: 1
53 |
54 |       - name: Build wheels (Windows and macOS)
55 |         if: matrix.os != 'ubuntu-latest'
56 |         run: |
57 |           python -m build --wheel
58 |
59 |       - name: Upload wheels as artifacts
60 |         uses: actions/upload-artifact@v4
61 |         with:
62 |           name: wheels-${{ matrix.os }}-${{ matrix.python-version || 'all' }}
63 |           path: dist/*.whl
64 |
65 |   publish:
66 |     needs: build-wheels
67 |     runs-on: ubuntu-latest
68 |     if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
69 |     steps:
70 |       - name: Download all wheel artifacts
71 |         uses: actions/download-artifact@v4
72 |         with:
73 |           path: dist
74 |           pattern: wheels-*
75 |           merge-multiple: true
76 |
77 |       - name: Publish to PyPI
78 |         env:
79 |           TWINE_USERNAME: __token__
80 |           TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
81 |         run: |
82 |           pip install twine
83 |           twine upload dist/*.whl
84 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Lukas Fisch
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | recursive-include flashrl/envs *.pyx *.c
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # flashrl
2 | `flashrl` does RL with **millions of steps/second 💨 while being tiny**: ~200 lines of code
3 |
4 | 🛠️ `pip install flashrl` or clone the repo & `pip install -r requirements.txt`
5 | - If cloned (or if envs changed), compile: `python setup.py build_ext --inplace`
6 |
7 | 💡 `flashrl` will always be **tiny**: **Read the code** (+paste into LLM) to understand it!
8 | ## Quick Start 🚀
9 | `flashrl` uses a `Learner` that holds an `env` and a `model` (default: `Policy` with LSTM)
10 |
11 | ```python
12 | import flashrl as frl
13 |
14 | learn = frl.Learner(frl.envs.Pong(n_agents=2**14))
15 | curves = learn.fit(40, steps=16, desc='done')
16 | frl.print_curve(curves['loss'], label='loss')
17 | frl.play(learn.env, learn.model, fps=8)
18 | learn.env.close()
19 | ```
20 | `.fit` does RL with ~**10 million steps**: `40` iterations × `16` steps × `2**14` agents!
21 |
22 |
23 |
24 |
25 |
26 | **Run it yourself via `python train.py` and play against the AI** 🪄
27 |
28 | <details>
29 | <summary>Click here to read a tiny doc 📑</summary>
30 |
31 | `Learner` takes the arguments
32 | - `env`: RL environment
33 | - `model`: A `Policy` model
34 | - `device`: By default, picks `mps` or `cuda` if available, else `cpu`
35 | - `dtype`: By default, `torch.bfloat16` if the device is `cuda`, else `torch.float32`
36 | - `compile_no_lstm`: Speeds up the `model` via `torch.compile` if it has no `lstm`
37 | - `**kwargs`: Passed to the `Policy`, e.g. `hidden_size` or `lstm`
38 |
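For example (the values below are illustrative, not tuned):

```python
import flashrl as frl

env = frl.envs.Grid(n_agents=2**14)
learn = frl.Learner(env, device='cpu', hidden_size=64, lstm=True)  # hidden_size & lstm go to Policy
```
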
39 | `Learner.fit` takes the arguments
40 | - `iters`: Number of iterations
41 | - `steps`: Number of steps in `rollout`
42 | - `desc`: Progress bar description (e.g. `'reward'`)
43 | - `log`: If `True`, `tensorboard` logging is enabled
44 |   - run `tensorboard --logdir=runs` and visit `http://localhost:6006` in the browser!
45 | - `stop_func`: Function that stops training if it returns `True`, e.g.
46 |
47 | ```python
48 | ...
49 | def stop(kl, **kwargs):
50 |     return kl > .1
51 |
52 | curves = learn.fit(40, steps=16, stop_func=stop)
53 | ...
54 | ```
55 | - `lr`, `anneal_lr` and the args of `ppo` after `bs` (e.g. `gamma` or `clip_coef`): Hyperparameters (example below)
56 |
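A sketch of passing hyperparameters (the values are made up, not recommendations):

```python
curves = learn.fit(40, steps=16, lr=5e-3, anneal_lr=True,   # Learner.fit args
                   gamma=.99, gae_lambda=.95, clip_coef=.2)  # forwarded to ppo via **hparams
```
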
57 | The most important functions in `flashrl/utils.py` are
58 | - `print_curve`: Visualizes the loss across the `iters`
59 | - `play`: Plays the environment in the terminal and takes the arguments
60 |   - `model`: A `Policy` model
61 |   - `playable`: If `True`, allows you to act (or decide to let the model act)
62 |   - `steps`: Number of steps
63 |   - `fps`: Frames per second
64 |   - `obs`: Name of the env attribute that is rendered as the observation (default: `'obs'`)
65 |   - `dump`: If `True`, no frame refresh -> frames accumulate in the terminal
66 |   - `idx`: Agent index between `0` and `n_agents` (default: `0`)
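
For example, after `learn.fit(...)` (the keyword values below are just one possible choice):

```python
frl.play(learn.env, learn.model, playable=True, fps=4)  # w/a/s/d (env-dependent) to act, m = model act, q = quit
# for MultiGrid, render the full grid instead of one agent's egocentric view:
# frl.play(learn.env, learn.model, obs='total_obs', with_total_obs=True)
```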
67 | </details>
68 |
69 | ## Environments 🕹️
70 | **Each env is one Cython(=`.pyx`) file** in `flashrl/envs`. **That's it!**
71 |
72 | To **add custom envs**, use `grid.pyx`, `pong.pyx` or `multigrid.pyx` as a **template** (see the interface sketch below the table):
73 | - `grid.pyx` for **single-agent** envs (~110 LOC)
74 | - `pong.pyx` for **1 vs 1 agent** envs (~150 LOC)
75 | - `multigrid.pyx` for **multi-agent** envs (~190 LOC)
76 |
77 | | `Grid` | `Pong` | `MultiGrid` |
78 | |-----------------------|-----------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------|
79 | | Agent must reach goal | Agent must score | Agent must reach goal first |
80 | || | |
81 |
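`Learner`, `Policy` and `play` only rely on a small env interface: `obs`, `rewards`, `dones`, `n_acts`, `reset`, `step` and `close`. A rough pure-Python sketch of that interface (class name and shapes are placeholders, the real envs are written in Cython for speed):

```python
import numpy as np

class MyEnv:  # hypothetical name, only to show the expected attributes/methods
    def __init__(self, n_agents=2**14, n_acts=5, size=8):
        self.n_agents, self.n_acts = n_agents, n_acts
        self.obs = np.zeros((n_agents, size, size), dtype=np.int8)  # one int8 observation per agent
        self.rewards = np.zeros(n_agents, dtype=np.int8)
        self.dones = np.zeros(n_agents, dtype=np.int8)

    def reset(self, seed=None):
        return self  # fill self.obs here and return self, so that `MyEnv().reset(seed)` chains

    def step(self, acts):  # acts: np.uint8 array of shape (n_agents,)
        pass  # update self.obs, self.rewards and self.dones in-place

    def close(self):
        pass  # free any buffers (the Cython envs free their calloc'd structs here)
```

Note that `play` looks up `key_maps` and `emoji_maps` in `flashrl/envs/__init__.py` by the lower-cased class name, so a new env also needs entries there to be playable.
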
82 | ## Acknowledgements 🙌
83 | I want to thank
84 | - [Joseph Suarez](https://github.com/jsuarez5341) for open sourcing RL envs in C(ython)! Star [PufferLib](https://github.com/PufferAI/PufferLib) ⭐
85 | - [Costa Huang](https://github.com/vwxyzjn) for open sourcing high-quality single-file RL code! Star [cleanrl](https://github.com/vwxyzjn/cleanrl) ⭐
86 |
87 | and last but not least...
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/flashrl/__init__.py:
--------------------------------------------------------------------------------
1 | from flashrl import envs
2 | from .main import Learner
3 | from .models import Policy
4 | from .utils import play, render, set_seed, print_curve
5 |
--------------------------------------------------------------------------------
/flashrl/envs/__init__.py:
--------------------------------------------------------------------------------
1 | emoji_maps = {'grid': {0: ' ', 1: '🦠', 2: '🍪'},
2 |               'pong': {0: ' ', 1: '🔲', 2: '🔴'},
3 |               'multigrid': {0: ' ', 1: '🧱', 2: '🦠', 3: '🍪'}}
4 | key_maps = {'grid': {'a': 1, 'd': 2, 'w': 3, 's': 4},
5 |             'pong': {'w': 1, 's': 2},
6 |             'multigrid': {'a': 1, 'd': 2, 'w': 3, 's': 4}}
7 | from .grid import Grid
8 | from .pong import Pong
9 | from .multigrid import MultiGrid
10 |
--------------------------------------------------------------------------------
/flashrl/envs/grid.pyx:
--------------------------------------------------------------------------------
1 | # cython: language_level=3
2 | import numpy as np
3 | cimport numpy as np
4 |
5 | from libc.stdlib cimport free, srand, calloc
6 |
7 | cdef extern from *:
8 |     '''
9 |     #include <stdlib.h>
10 |     #include <string.h>
11 |
12 |     const char AGENT = 1, GOAL = 2;
13 |     const unsigned char LEFT = 1, RIGHT = 2, UP = 3, DOWN = 4;
14 |
15 |     typedef struct {
16 |         char *obs, *reward, *done;
17 |         int size, t, x, y, goal_x, goal_y;
18 |     } CGrid;
19 |
20 |     void c_reset(CGrid* env) {
21 |         env->t = 0;
22 |         memset(env->obs, 0, env->size * env->size);
23 |         env->x = env->y = env->size / 2;
24 |         env->obs[env->x + env->y * env->size] = AGENT;
25 |         env->goal_x = rand() % env->size;
26 |         env->goal_y = rand() % env->size;
27 |         if (env->goal_x == env->x && env->goal_y == env->y) env->goal_x++;
28 |         env->obs[env->goal_x + env->goal_y * env->size] = GOAL;
29 |     }
30 |
31 |     void c_step(CGrid* env, unsigned char act) {
32 |         env->reward[0] = 0;
33 |         env->done[0] = 0;
34 |         env->obs[env->x + env->y * env->size] = 0;
35 |         if (act == LEFT) env->x--;
36 |         else if (act == RIGHT) env->x++;
37 |         else if (act == UP) env->y--;
38 |         else if (act == DOWN) env->y++;
39 |         if (env->t > 3 * env->size || env->x < 0 || env->y < 0 || env->x >= env->size || env->y >= env->size) {
40 |             env->reward[0] = -1;
41 |             env->done[0] = 1;
42 |             c_reset(env);
43 |             return;
44 |         }
45 |         int position = env->x + env->y * env->size;
46 |         if (env->obs[position] == GOAL) {
47 |             env->reward[0] = 1;
48 |             env->done[0] = 1;
49 |             c_reset(env);
50 |             return;
51 |         }
52 |         env->obs[position] = AGENT;
53 |         env->t++;
54 |     }
55 |     '''
56 |
57 |     ctypedef struct CGrid:
58 |         char *obs
59 |         char *reward
60 |         char *done
61 |         int size, t, x, y, goal_x, goal_y
62 |
63 |     void c_reset(CGrid *env)
64 |     void c_step(CGrid *env, unsigned char act)
65 |
66 |
67 | cdef class Grid:
68 |     cdef:
69 |         CGrid *envs
70 |         int n_agents, _n_acts
71 |         np.ndarray obs_arr, rewards_arr, dones_arr
72 |         cdef char[:, :, :] obs_memview
73 |         cdef char[:] rewards_memview
74 |         cdef char[:] dones_memview
75 |         int size
76 |
77 |     def __init__(self, n_agents=2**14, n_acts=5, size=8):
78 |         self.envs = <CGrid*> calloc(n_agents, sizeof(CGrid))
79 |         self.n_agents = n_agents
80 |         self._n_acts = n_acts
81 |         self.obs_arr = np.zeros((n_agents, size, size), dtype=np.int8)
82 |         self.rewards_arr = np.zeros(n_agents, dtype=np.int8)
83 |         self.dones_arr = np.zeros(n_agents, dtype=np.int8)
84 |         self.obs_memview = self.obs_arr
85 |         self.rewards_memview = self.rewards_arr
86 |         self.dones_memview = self.dones_arr
87 |         cdef int i
88 |         for i in range(n_agents):
89 |             env = &self.envs[i]
90 |             env.obs = &self.obs_memview[i, 0, 0]
91 |             env.reward = &self.rewards_memview[i]
92 |             env.done = &self.dones_memview[i]
93 |             env.size = size
94 |
95 |     def reset(self, seed=None):
96 |         if seed is not None:
97 |             srand(seed)
98 |         cdef int i
99 |         for i in range(self.n_agents):
100 |             c_reset(&self.envs[i])
101 |         return self
102 |
103 |     def step(self, np.ndarray acts):
104 |         cdef unsigned char[:] acts_memview = acts
105 |         cdef int i
106 |         for i in range(self.n_agents):
107 |             c_step(&self.envs[i], acts_memview[i])
108 |
109 |     def close(self):
110 |         free(self.envs)
111 |
112 |     @property
113 |     def obs(self): return self.obs_arr
114 |
115 |     @property
116 |     def rewards(self): return self.rewards_arr
117 |
118 |     @property
119 |     def dones(self): return self.dones_arr
120 |
121 |     @property
122 |     def n_acts(self): return self._n_acts
123 |
--------------------------------------------------------------------------------
/flashrl/envs/multigrid.pyx:
--------------------------------------------------------------------------------
1 | # cython: language_level=3
2 | import numpy as np
3 | cimport numpy as np
4 |
5 | from libc.stdlib cimport free, srand, calloc
6 |
7 | cdef extern from *:
8 |     '''
9 |     #include <stdlib.h>
10 |     #include <string.h>
11 |     #include <stdbool.h>
12 |
13 |     const char WALL = 1, AGENT = 2, GOAL = 3;
14 |     const unsigned char NOOP = 0, DOWN = 1, UP = 2, LEFT = 3, RIGHT = 4;
15 |
16 |     typedef struct {
17 |         char *obs, *rewards, *dones;
18 |         unsigned char *x, *y;
19 |         char *total_obs;
20 |         int n_agents_per_env, vision, size, t, goal_x, goal_y;
21 |     } CMultiGrid;
22 |
23 |     void get_obs(CMultiGrid* env) {
24 |         int ob_size = 2 * env->vision + 1;
25 |         int ob_pixels = ob_size * ob_size;
26 |         memset(env->obs, 0, env->n_agents_per_env * ob_pixels);
27 |         int center = ob_pixels / 2;
28 |         for (int i = 0; i < env->n_agents_per_env; i++) {
29 |             for (int x = -env->vision; x <= env->vision; x++) {
30 |                 for (int y = -env->vision; y <= env->vision; y++) {
31 |                     char world_x = env->x[i] + x;
32 |                     char world_y = env->y[i] + y;
33 |                     if (world_x < 0 || world_x > env->size - 1 || world_y < 0 || world_y > env->size - 1) {
34 |                         env->obs[i * ob_pixels + center + x * ob_size + y] = WALL;
35 |                     }
36 |                 }
37 |             }
38 |             for (int j = 0; j < env->n_agents_per_env; j++) {
39 |                 char dx = env->x[j] - env->x[i];
40 |                 char dy = env->y[j] - env->y[i];
41 |                 if (abs(dx) <= env->vision && abs(dy) <= env->vision) {
42 |                     env->obs[i * ob_pixels + center + dx * ob_size + dy] = AGENT;
43 |                 }
44 |             }
45 |             char dx = env->goal_x - env->x[i];
46 |             char dy = env->goal_y - env->y[i];
47 |             if (abs(dx) <= env->vision && abs(dy) <= env->vision) {
48 |                 env->obs[i * ob_pixels + center + dx * ob_size + dy] = GOAL;
49 |             }
50 |         }
51 |     }
52 |
53 |     void get_total_obs(CMultiGrid* env) {
54 |         memset(env->total_obs, 0, env->size * env->size);
55 |         for (int i = 0; i < env->n_agents_per_env; i++) {
56 |             env->total_obs[env->y[i] + env->x[i] * env->size] = AGENT;
57 |         }
58 |         env->total_obs[env->goal_y + env->goal_x * env->size] = GOAL;
59 |     }
60 |
61 |     void c_reset(CMultiGrid* env, char with_total_obs) {
62 |         env->t = 0;
63 |         env->goal_x = rand() % env->size;
64 |         env->goal_y = rand() % env->size;
65 |         for (int i = 0; i < env->n_agents_per_env; i++) {
66 |             env->x[i] = rand() % env->size;
67 |             env->y[i] = rand() % env->size;
68 |         }
69 |         get_obs(env);
70 |         if (with_total_obs > 0) {
71 |             get_total_obs(env);
72 |         }
73 |     }
74 |
75 |     void agent_step(CMultiGrid* env, unsigned char *acts, int i, char with_total_obs) {
76 |         env->rewards[i] = 0;
77 |         env->dones[i] = 0;
78 |         unsigned char act = acts[i];
79 |         if (act == LEFT) env->x[i]--;
80 |         else if (act == RIGHT) env->x[i]++;
81 |         else if (act == UP) env->y[i]++;
82 |         else if (act == DOWN) env->y[i]--;
83 |         if (env->t > 3 * env->size || env->x[i] < 0 || env->y[i] < 0 || env->x[i] >= env->size || env->y[i] >= env->size) {
84 |             env->dones[i] = 1;
85 |             env->rewards[i] = -1;
86 |             c_reset(env, with_total_obs);
87 |             return;
88 |         }
89 |         if (env->x[i] == env->goal_x && env->y[i] == env->goal_y) {
90 |             env->dones[i] = 1;
91 |             env->rewards[i] = 1;
92 |             c_reset(env, with_total_obs);
93 |             return;
94 |         }
95 |     }
96 |
97 |     void c_step(CMultiGrid* env, unsigned char *acts, bool with_total_obs) {
98 |         for (int i = 0; i < env->n_agents_per_env; i++){
99 |             agent_step(env, acts, i, with_total_obs);
100 |         }
101 |         get_obs(env);
102 |         if (with_total_obs > 0) {
103 |             get_total_obs(env);
104 |         }
105 |         env->t++;
106 |     }
107 |     '''
108 |
109 |     ctypedef struct CMultiGrid:
110 |         char *obs
111 |         char *rewards
112 |         char *dones
113 |         unsigned char *x
114 |         unsigned char *y
115 |         char *total_obs
116 |         int n_agents_per_env, vision, size, t, goal_x, goal_y
117 |
118 |     void c_reset(CMultiGrid* env, char with_total_obs)
119 |     void c_step(CMultiGrid* env, unsigned char *acts, char with_total_obs)
120 |
121 | cdef class MultiGrid:
122 |     cdef:
123 |         CMultiGrid* envs
124 |         int n_agents, _n_acts, _n_agents_per_env
125 |         np.ndarray obs_arr, rewards_arr, dones_arr, x_arr, y_arr, total_obs_arr
126 |         cdef char[:, :, :] obs_memview
127 |         cdef char[:] rewards_memview
128 |         cdef char[:] dones_memview
129 |         cdef unsigned char[:] x_memview
130 |         cdef unsigned char[:] y_memview
131 |         cdef char[:, :, :] total_obs_memview
132 |         int size
133 |         bint with_total_obs
134 |
135 |     def __init__(self, n_agents=2**14, n_acts=5, n_agents_per_env=2, vision=3, size=8):
136 |         self.envs = <CMultiGrid*> calloc(n_agents // n_agents_per_env, sizeof(CMultiGrid))
137 |         self.n_agents = n_agents
138 |         self._n_acts = n_acts
139 |         self._n_agents_per_env = n_agents_per_env
140 |         self.obs_arr = np.zeros((n_agents, 2*vision+1, 2*vision+1), dtype=np.int8)
141 |         self.rewards_arr = np.zeros(n_agents, dtype=np.int8)
142 |         self.dones_arr = np.zeros(n_agents, dtype=np.int8)
143 |         self.x_arr = np.zeros(n_agents, dtype=np.uint8)
144 |         self.y_arr = np.zeros(n_agents, dtype=np.uint8)
145 |         self.total_obs_arr = np.zeros((n_agents // n_agents_per_env, size, size), dtype=np.int8)
146 |         self.obs_memview = self.obs_arr
147 |         self.rewards_memview = self.rewards_arr
148 |         self.dones_memview = self.dones_arr
149 |         self.x_memview = self.x_arr
150 |         self.y_memview = self.y_arr
151 |         self.total_obs_memview = self.total_obs_arr
152 |         cdef int i, i_agent
153 |         for i in range(n_agents // n_agents_per_env):
154 |             env = &self.envs[i]
155 |             i_agent = n_agents_per_env * i
156 |             env.obs = &self.obs_memview[i_agent, 0, 0]
157 |             env.rewards = &self.rewards_memview[i_agent]
158 |             env.dones = &self.dones_memview[i_agent]
159 |             env.x = &self.x_memview[i_agent]
160 |             env.y = &self.y_memview[i_agent]
161 |             env.total_obs = &self.total_obs_memview[i, 0, 0]
162 |             env.n_agents_per_env = n_agents_per_env
163 |             env.vision = vision
164 |             env.size = size
165 |
166 |     def reset(self, seed=None, with_total_obs=False):
167 |         if seed is not None:
168 |             srand(seed)
169 |         cdef int i
170 |         for i in range(self.n_agents // self.n_agents_per_env):
171 |             c_reset(&self.envs[i], with_total_obs)
172 |         return self
173 |
174 |     def step(self, np.ndarray acts, with_total_obs=False):
175 |         cdef unsigned char[:] acts_memview = acts
176 |         cdef int i
177 |         for i in range(self.n_agents // self.n_agents_per_env):
178 |             c_step(&self.envs[i], &acts_memview[self.n_agents_per_env * i], with_total_obs)
179 |
180 |     def close(self):
181 |         free(self.envs)
182 |
183 |     @property
184 |     def obs(self): return self.obs_arr
185 |
186 |     @property
187 |     def rewards(self): return self.rewards_arr
188 |
189 |     @property
190 |     def dones(self): return self.dones_arr
191 |
192 |     @property
193 |     def n_acts(self): return self._n_acts
194 |
195 |     @property
196 |     def n_agents_per_env(self): return self._n_agents_per_env
197 |
198 |     @property
199 |     def total_obs(self): return self.total_obs_arr
200 |
--------------------------------------------------------------------------------
/flashrl/envs/pong.pyx:
--------------------------------------------------------------------------------
1 | # cython: language_level=3
2 | import numpy as np
3 | cimport numpy as np
4 |
5 | from libc.stdlib cimport free, srand, calloc
6 |
7 | cdef extern from *:
8 |     '''
9 |     #include <math.h>
10 |     #include <stdlib.h>
11 |     #include <string.h>
12 |
13 |     const char PADDLE = 1, BALL = 2;
14 |     const unsigned char UP = 1, DOWN = 2;
15 |
16 |     typedef struct {
17 |         char *obs0, *obs1, *reward0, *reward1, *done0, *done1;
18 |         int size_x, size_y, t, paddle0_x, paddle0_y, paddle1_x, paddle1_y, x, dx;
19 |         float y, dy, max_dy;
20 |     } CPong;
21 |
22 |     void set_obs(CPong* env, char paddle, char ball) {
23 |         for (int i = -1; i < 2; i++) {
24 |             if (env->paddle0_y + i >= 0 && env->paddle0_y + i <= env->size_y - 1) {
25 |                 env->obs0[(env->size_x - 1) - env->paddle0_x + (env->paddle0_y + i) * env->size_x] = paddle;
26 |                 env->obs1[env->paddle0_x + (env->paddle0_y + i) * env->size_x] = paddle;
27 |             }
28 |             if (env->paddle1_y + i >= 0 && env->paddle1_y + i <= env->size_y - 1) {
29 |                 env->obs0[(env->size_x - 1) - env->paddle1_x + (env->paddle1_y + i) * env->size_x] = paddle;
30 |                 env->obs1[env->paddle1_x + (env->paddle1_y + i) * env->size_x] = paddle;
31 |             }
32 |         }
33 |         env->obs0[(env->size_x - 1) - env->x + (int)(roundf(env->y)) * env->size_x] = ball;
34 |         env->obs1[env->x + (int)(roundf(env->y)) * env->size_x] = ball;
35 |     }
36 |
37 |     void c_reset(CPong* env) {
38 |         env->t = 0;
39 |         memset(env->obs0, 0, env->size_x * env->size_y);
40 |         memset(env->obs1, 0, env->size_x * env->size_y);
41 |         env->x = env->size_x / 2;
42 |         env->y = rand() % (env->size_y - 1);
43 |         env->dx = (rand() % 2) ? 1 : -1;
44 |         env->dy = 2.0f * ((float)rand() / RAND_MAX) - 1.0f;
45 |         env->paddle0_x = 0;
46 |         env->paddle1_x = env->size_x - 1;
47 |         env->paddle0_y = env->paddle1_y = env->size_y / 2;
48 |         set_obs(env, PADDLE, BALL);
49 |     }
50 |
51 |     void c_step(CPong* env, unsigned char act0, unsigned char act1) {
52 |         env->reward0[0] = env->reward1[0] = 0;
53 |         env->done0[0] = env->done1[0] = 0;
54 |         set_obs(env, 0, 0);
55 |         if (act0 == UP && env->paddle0_y > 0) env->paddle0_y--;
56 |         if (act0 == DOWN && env->paddle0_y < env->size_y - 2) env->paddle0_y++;
57 |         if (act1 == UP && env->paddle1_y > 0) env->paddle1_y--;
58 |         if (act1 == DOWN && env->paddle1_y < env->size_y - 2) env->paddle1_y++;
59 |         env->dy = fminf(fmaxf(env->dy, -env->max_dy), env->max_dy);
60 |         env->x += env->dx;
61 |         env->y += env->dy;
62 |         env->y = fminf(fmaxf(env->y, 0.f), env->size_y - 1.f);
63 |         if (env->y <= 0 || env->y >= env->size_y - 1) env->dy = -env->dy;
64 |         if (env->x == 1 && env->y >= env->paddle0_y - 1 && env->y <= env->paddle0_y + 1) {
65 |             env->dx = -env->dx;
66 |             env->dy += env->y - env->paddle0_y;
67 |         }
68 |         if (env->x == env->size_x - 2 && env->y >= env->paddle1_y - 1 && env->y <= env->paddle1_y + 1) {
69 |             env->dx = -env->dx;
70 |             env->dy += env->y - env->paddle1_y;
71 |         }
72 |         if (env->x == 0 || env->x == env->size_x - 1) {
73 |             env->reward1[0] = 2 * (char)(env->x == 0) - 1;
74 |             env->reward0[0] = -env->reward1[0];
75 |             env->done0[0] = env->done1[0] = 1;
76 |             c_reset(env);
77 |         }
78 |         set_obs(env, PADDLE, BALL);
79 |         env->t++;
80 |     }
81 |     '''
82 |
83 |     ctypedef struct CPong:
84 |         char *obs0
85 |         char *obs1
86 |         char *reward0
87 |         char *reward1
88 |         char *done0
89 |         char *done1
90 |         int size_x, size_y, t, paddle0_x, paddle0_y, paddle1_x, paddle1_y, x, dx
91 |         float y, dy, max_dy
92 |
93 |     void c_reset(CPong* env)
94 |     void c_step(CPong* env, unsigned char act0, unsigned char act1)
95 |
96 | cdef class Pong:
97 |     cdef:
98 |         CPong* envs
99 |         int n_agents, _n_acts
100 |         np.ndarray obs_arr, rewards_arr, dones_arr
101 |         cdef char[:, :, :] obs_memview
102 |         cdef char[:] rewards_memview
103 |         cdef char[:] dones_memview
104 |         int size_x, size_y
105 |         float max_dy
106 |
107 |     def __init__(self, n_agents=2**14, n_acts=3, size_x=16, size_y=8, max_dy=1.):
108 |         self.envs = <CPong*> calloc(n_agents // 2, sizeof(CPong))
109 |         self.n_agents = n_agents
110 |         self._n_acts = n_acts
111 |         self.obs_arr = np.zeros((n_agents, size_y, size_x), dtype=np.int8)
112 |         self.rewards_arr = np.zeros(n_agents, dtype=np.int8)
113 |         self.dones_arr = np.zeros(n_agents, dtype=np.int8)
114 |         self.obs_memview = self.obs_arr
115 |         self.rewards_memview = self.rewards_arr
116 |         self.dones_memview = self.dones_arr
117 |         cdef int i
118 |         for i in range(n_agents // 2):
119 |             env = &self.envs[i]
120 |             env.obs0, env.obs1 = &self.obs_memview[2 * i, 0, 0], &self.obs_memview[2 * i + 1, 0, 0]
121 |             env.reward0, env.reward1 = &self.rewards_memview[2 * i], &self.rewards_memview[2 * i + 1]
122 |             env.done0, env.done1 = &self.dones_memview[2 * i], &self.dones_memview[2 * i + 1]
123 |             env.size_x = size_x
124 |             env.size_y = size_y
125 |             env.max_dy = max_dy
126 |
127 |     def reset(self, seed=None):
128 |         if seed is not None:
129 |             srand(seed)
130 |         cdef int i
131 |         for i in range(self.n_agents // 2):
132 |             c_reset(&self.envs[i])
133 |         return self
134 |
135 |     def step(self, np.ndarray acts):
136 |         cdef unsigned char[:] acts_memview = acts
137 |         cdef int i
138 |         for i in range(self.n_agents // 2):
139 |             c_step(&self.envs[i], acts_memview[2 * i], acts_memview[2 * i + 1])
140 |
141 |     def close(self):
142 |         free(self.envs)
143 |
144 |     @property
145 |     def obs(self): return self.obs_arr
146 |
147 |     @property
148 |     def rewards(self): return self.rewards_arr
149 |
150 |     @property
151 |     def dones(self): return self.dones_arr
152 |
153 |     @property
154 |     def n_acts(self): return self._n_acts
155 |
--------------------------------------------------------------------------------
/flashrl/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from torch.utils.tensorboard import SummaryWriter
4 |
5 | from .models import Policy
6 | DEVICE = 'mps' if torch.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
7 |
8 |
9 | class Learner:
10 |     def __init__(self, env, model=None, device=None, dtype=None, compile_no_lstm=False, **kwargs):
11 |         self.env = env
12 |         self.device = DEVICE if device is None else device
13 |         self.dtype = dtype if dtype is not None else torch.bfloat16 if self.device == 'cuda' else torch.float32
14 |         self.model = Policy(self.env, **kwargs).to(self.device, self.dtype) if model is None else model
15 |         if self.model.lstm is None and compile_no_lstm:  # only no-lstm policy gets faster from torch.compile
16 |             self.model = torch.compile(self.model, fullgraph=True, mode='reduce-overhead')
17 |         self._data, self._np_data, self._rollout_state, self._ppo_state = None, None, None, None
18 |
19 |     def fit(self, iters=40, steps=16, lr=.01, bs=None, anneal_lr=True, log=False, desc=None, stop_func=None, **hparams):
20 |         bs = bs or len(self.env.obs) // 2
21 |         self.setup_data(steps, bs)
22 |         logger = SummaryWriter() if log else None
23 |         opt = torch.optim.Adam(self.model.parameters(), lr=lr, eps=1e-5)
24 |         pbar = tqdm(range(iters), total=iters)
25 |         curves = []
26 |         for i in pbar:
27 |             opt.param_groups[0]['lr'] = lr * (1 - i / iters) if anneal_lr else lr
28 |             self.rollout(steps)
29 |             losses = ppo(self.model, opt, bs=bs, state=self._ppo_state, **self._data, **hparams)
30 |             if desc: pbar.set_description(f'{desc}: {losses[desc] if desc in losses else self._data[desc].mean():.3f}')
31 |             if i: pbar.set_postfix_str(f'{1e-6 * self._data["act"].numel() * pbar.format_dict["rate"]:.1f}M steps/s')
32 |             if log:
33 |                 for k, v in losses.items(): logger.add_scalar(k, v, global_step=i)
34 |                 for name, param in self.model.named_parameters(): logger.add_histogram(name, param, global_step=i)
35 |             curves.append(losses)
36 |             if stop_func is not None:
37 |                 if stop_func(**self._data, **losses): break
38 |         return {k: [m[k].item() for m in curves] for k in curves[0]}
39 |
40 |     def setup_data(self, steps, bs=None):
41 |         x = torch.zeros((len(self.env.obs), steps), dtype=self.dtype, device=self.device)
42 |         obs = torch.zeros((*x.shape, *self.env.obs.shape[1:]), dtype=self.dtype, device=self.device)
43 |         self._data = {'obs': obs, 'act': x.clone().byte(), 'logprob': x.clone(), 'value': x}
44 |         self._np_data = {'reward': x.char().cpu().numpy(), 'done': x.char().cpu().numpy()}
45 |         if self.model.lstm is not None:
46 |             zeros = torch.zeros((len(obs), self.model.encoder.out_features), dtype=self.dtype, device=self.device)
47 |             self._rollout_state = (zeros, zeros.clone())
48 |             if bs is not None:
49 |                 zeros = torch.zeros((bs, self.model.encoder.out_features), dtype=self.dtype, device=self.device)
50 |                 self._ppo_state = (zeros, zeros.clone())
51 |
52 |     def rollout(self, steps, state=None, extra_args_list=None, **kwargs):
53 |         state = self._rollout_state if state is None else state
54 |         if steps != (0 if self._data is None else self._data['obs'].shape[1]): self.setup_data(steps)
55 |         extra_data = {} if extra_args_list is None else {k: [] for k in extra_args_list}
56 |         for i in range(steps):
57 |             o = self.to_torch(self.env.obs)
58 |             with torch.no_grad():
59 |                 act, logp, _, value, state = self.model(o, state=state)
60 |             self._data['obs'][:, i] = o
61 |             self._data['act'][:, i] = act
62 |             self._data['logprob'][:, i] = logp
63 |             self._data['value'][:, i] = value
64 |             self._np_data['reward'][:, i] = self.env.rewards
65 |             self._np_data['done'][:, i] = self.env.dones
66 |             for k in extra_data: extra_data[k].append(self.to_torch(getattr(self.env, k).copy()))
67 |             self.env.step(act.cpu().numpy(), **kwargs)
68 |         self._data.update({k: self.to_torch(v) for k, v in self._np_data.items()})
69 |         return {k: torch.stack(v, dim=1) for k, v in extra_data.items()}
70 |
71 |     def to_torch(self, x, non_blocking=True):
72 |         return torch.from_numpy(x).to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
73 |
74 |
75 | def ppo(model, opt, obs, value, act, logprob, reward, done, bs=2**13, gamma=.99, gae_lambda=.95, clip_coef=.1,
76 |         value_coef=.5, value_clip_coef=.1, entropy_coef=.01, max_grad_norm=.5, norm_adv=True, state=None):
77 |     advs = get_advantages(value, reward, done, gamma=gamma, gae_lambda=gae_lambda)
78 |     obs, value, act, logprob, advs = [xs.view(-1, bs, *xs.shape[2:]) for xs in [obs, value, act, logprob, advs]]
79 |     returns = advs + value
80 |     metrics, metric_keys = [], ['loss', 'policy_loss', 'value_loss', 'entropy_loss', 'kl']
81 |     for o, old_value, a, old_logp, adv, ret in zip(obs, value, act, logprob, advs, returns):
82 |         _, logp, entropy, val, state = model(o, state=state, act=a)
83 |         state = state if model.lstm is None else (state[0].detach(), state[1].detach())
84 |         logratio = logp - old_logp
85 |         ratio = logratio.exp()
86 |         adv = (adv - adv.mean()) / (adv.std() + 1e-8) if norm_adv else adv
87 |         policy_loss = torch.max(-adv * ratio, -adv * ratio.clip(1 - clip_coef, 1 + clip_coef)).mean()
88 |         if value_clip_coef:
89 |             v_clipped = old_value + (val - old_value).clip(-value_clip_coef, value_clip_coef)
90 |             value_loss = .5 * torch.max((val - ret) ** 2, (v_clipped - ret) ** 2).mean()
91 |         else:
92 |             value_loss = .5 * ((val - ret) ** 2).mean()
93 |         entropy = entropy.mean()
94 |         loss = policy_loss + value_coef * value_loss - entropy_coef * entropy
95 |         opt.zero_grad()
96 |         loss.backward()
97 |         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
98 |         opt.step()
99 |         kl = ((ratio - 1) - logratio).mean()
100 |         metrics.append([loss, policy_loss, value_loss, entropy, kl])
101 |     return {k: torch.stack([values[i] for values in metrics]).mean() for i, k in enumerate(metric_keys)}
102 |
103 |
104 | def get_advantages(value, reward, done, gamma=.99, gae_lambda=.95):  # see arxiv.org/abs/1506.02438 eq. (16)-(18)
105 |     advs = torch.zeros_like(value)
106 |     not_done = 1. - done
107 |     for t in range(1, done.shape[1]):
108 |         delta = reward[:, -t] + gamma * value[:, -t] * not_done[:, -t] - value[:, -t - 1]
109 |         advs[:, -t-1] = delta + gamma * gae_lambda * not_done[:, -t] * advs[:, -t]
110 |     return advs
111 |
--------------------------------------------------------------------------------
/flashrl/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 |
5 | class Policy(torch.nn.Module):
6 |     def __init__(self, env, hidden_size=128, lstm=False):
7 |         super().__init__()
8 |         self.encoder = torch.nn.Linear(math.prod(env.obs.shape[1:]), hidden_size)
9 |         self.actor = torch.nn.Linear(hidden_size, env.n_acts)
10 |         self.value_head = torch.nn.Linear(hidden_size, 1)
11 |         self.lstm = cleanrl_init(torch.nn.LSTMCell(hidden_size, hidden_size)) if lstm else None
12 |
13 |     def forward(self, x, state, act=None, with_entropy=None):
14 |         with_entropy = act is not None if with_entropy is None else with_entropy
15 |         h = self.encoder(x.view(len(x), -1)).relu()
16 |         h, c = (h, None) if self.lstm is None else self.lstm(h, state)
17 |         value = self.value_head(h)[:, 0]
18 |         x = self.actor(h)
19 |         act = torch.multinomial(x.softmax(dim=-1), 1)[:, 0] if act is None else act
20 |         x = x - x.logsumexp(dim=-1, keepdim=True)  # log-softmax of the logits
21 |         logprob = x.gather(-1, act[..., None].long())[..., 0]
22 |         entropy = -(x * x.softmax(dim=-1)).sum(-1) if with_entropy else None
23 |         return act.byte(), logprob, entropy, value, (h, c)
24 |
25 |
26 | def cleanrl_init(module):
27 |     for name, param in module.named_parameters():
28 |         if 'bias' in name: torch.nn.init.constant_(param, 0)
29 |         elif 'weight' in name: torch.nn.init.orthogonal_(param, 1)
30 |     return module
31 |
--------------------------------------------------------------------------------
/flashrl/utils.py:
--------------------------------------------------------------------------------
1 | import sys, time, random, platform
2 | import torch
3 | import plotille
4 | import numpy as np
5 |
6 | from .envs import key_maps, emoji_maps
7 |
8 |
9 | def play(env, model=None, playable=False, steps=None, fps=4, obs='obs', dump=False, with_data=True, idx=0, **kwargs):
10 |     key_map = key_maps[env.__class__.__name__.lower()]
11 |     emoji_map = emoji_maps[env.__class__.__name__.lower()]
12 |     if playable: print(f'Press {"".join(key_map)} to act, m for model act and q to quit')
13 |     data, state = {}, None
14 |     for i in range((10000 if playable else 64) if steps is None else steps):
15 |         data.update({'step': i})
16 |         render(getattr(env, obs)[idx], cursor_up=i and not dump, emoji_map=emoji_map, data=data if with_data else None)
17 |         acts = np.zeros(len(env.obs), dtype=np.uint8)
18 |         if model is not None:
19 |             o = torch.from_numpy(env.obs).to(device=model.actor.weight.device, dtype=model.actor.weight.dtype)
20 |             with torch.no_grad(): acts, logp, entropy, val, state = model(o, state=state, with_entropy=True)
21 |             data.update({'model act': acts[idx], 'logp': logp[idx], 'entropy': entropy[idx], 'value': val[idx]})
22 |             acts = acts.cpu().numpy()
23 |         key = get_pressed_key() if playable else f'm{time.sleep(1 / fps)}'[:1]
24 |         if key == 'q': break
25 |         acts[idx] = acts[idx] if key == 'm' else key_map[key] if key in key_map else 0
26 |         env.step(acts, **kwargs)
27 |         data.update({'act': acts[idx], 'reward': env.rewards[idx], 'done': env.dones[idx]})
28 |
29 |
30 | def render(ob, cursor_up=True, emoji_map=None, data=None):
31 |     if cursor_up: print(f'\033[A\033[{len(ob)}A')
32 |     ob = 23 * (ob - ob.min()) / (ob.max() - ob.min()) + 232 if emoji_map is None else ob
33 |     for i, row in enumerate(ob):
34 |         for o in row.tolist():
35 |             print(f'\033[48;5;{f"{232 + o}m" if emoji_map is None else f"232m{emoji_map[o]}"}\033[0m', end='')
36 |         if data is not None:
37 |             if i < len(data):
38 |                 print(f'{list(data.keys())[i]}: {list(data.values())[i]:.3g}', end=' ')
39 |         print()
40 |
41 |
42 | def set_seed(seed):
43 |     np.random.seed(seed)
44 |     random.seed(seed)
45 |     if seed is not None:
46 |         torch.manual_seed(seed)
47 |         torch.backends.cudnn.deterministic = True
48 |
49 |
50 | def print_curve(array, label=None, height=8, width=65):
51 |     fig = plotille.Figure()
52 |     fig._height, fig._width = height, width
53 |     fig.y_label = fig.y_label if label is None else label
54 |     fig.scatter(list(range(len(array))), array)
55 |     print('\n'.join(fig.show().split('\n')[:-2]))
56 |
57 |
58 | def get_pressed_key():
59 |     if platform.system() == 'Windows':
60 |         import msvcrt
61 |         key = msvcrt.getch()
62 |     else:
63 |         import tty, termios
64 |         fd = sys.stdin.fileno()
65 |         old_settings = termios.tcgetattr(fd)
66 |         try:
67 |             tty.setraw(sys.stdin.fileno())
68 |             key = sys.stdin.read(1)
69 |         finally:
70 |             termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
71 |     return key
72 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ['setuptools', 'wheel', 'Cython', 'numpy']
3 | build-backend = 'setuptools.build_meta'
4 |
5 | [project]
6 | name = 'flashrl'
7 | version = '0.2.1'
8 | description = 'Fast reinforcement learning 💨'
9 | authors = [{name = 'codingfisch', email = 'l_fisc17@wwu.de'}]
10 | readme = 'README.md'
11 | dependencies = ['torch', 'Cython', 'tqdm', 'plotille', 'tensorboard']
12 |
13 | [tool.setuptools]
14 | license-files = []
15 | packages = ['flashrl', 'flashrl.envs']
16 | include-package-data = true
17 |
18 | [tool.setuptools.package-data]
19 | flashrl = ['envs/*.pyx', 'envs/*.c']
20 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | Cython
3 | tqdm
4 | plotille
5 | tensorboard
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | from pathlib import Path
3 | from setuptools import setup, Extension
4 | from Cython.Build import cythonize
5 |
6 | kwargs = {'extra_compile_args': ['-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION', '-O1']}
7 | mods = [Extension(f'flashrl.envs.{Path(fp.stem)}', sources=[fp], **kwargs) for fp in Path('flashrl/envs').glob('*.pyx')]
8 | setup(ext_modules=cythonize(mods), packages=['flashrl', 'flashrl.envs'], include_dirs=[numpy.get_include()],
9 |       include_package_data=True)
10 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codingfisch/flashrl/8769edd4383a178678b3146daf7a1023e5e5fb67/test/__init__.py
--------------------------------------------------------------------------------
/test/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from flashrl.utils import play, print_curve
3 | from flashrl.models import Policy
4 | from flashrl.main import get_advantages, Learner
5 | from flashrl.envs.grid import Grid
6 |
7 |
8 | def test_print_ascii_curve():
9 |     array = [1, 2, 3, 4, 5]
10 |     print_curve(array)
11 |
12 | def test_Policy_init():
13 |     env = Grid(n_agents=1, size=5)
14 |     model = Policy(env, lstm=True)
15 |     assert isinstance(model, Policy), 'Model initialization failed'
16 |     assert model.actor.in_features == 128, 'Actor layer initialization failed'
17 |     assert model.lstm.hidden_size == 128, 'LSTM layer hidden size is incorrect'
18 |
19 | def test_learner_init():
20 |     env = Grid(n_agents=1, size=5)
21 |     learn = Learner(env)
22 |     assert isinstance(learn.model, Policy), 'Model initialization failed'
23 |
24 | def test_learner_fit():
25 |     env = Grid(n_agents=2**13, size=5)
26 |     learn = Learner(env)
27 |     losses = learn.fit(iters=1)
28 |     assert 'loss' in losses, 'Fit did not return metrics'
29 |
30 | def test_play():
31 |     env = Grid(n_agents=2**13, size=5)
32 |     learn = Learner(env)
33 |     learn.fit(iters=1)
34 |     play(env, learn.model, steps=16)
35 |
36 | def test_get_advantages():
37 |     values = torch.rand(1, 10)
38 |     rewards = torch.rand(1, 10)
39 |     dones = torch.zeros(1, 10)
40 |     advantages = get_advantages(values, rewards, dones)
41 |     assert advantages is not None, 'Advantages calculation failed'
42 |     assert advantages.shape == values.shape, 'Advantages shape mismatch'
43 |
44 | if __name__ == '__main__':
45 |     test_print_ascii_curve()
46 |     test_Policy_init()
47 |     test_learner_init()
48 |     test_learner_fit()
49 |     test_play()
50 |     test_get_advantages()
51 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import flashrl as frl
2 | frl.set_seed(SEED:=1)
3 |
4 | env = frl.envs.Pong(n_agents=2**14).reset(SEED) # try one of: Pong, Grid, MultiGrid
5 | learn = frl.Learner(env, hidden_size=128, lstm=True) # faster with lstm=False and smaller hidden_size
6 | curves = learn.fit(40, steps=16, desc='done') # ,lr=1e-2, gamma=.99) set hparams here
7 | frl.print_curve(curves['loss'], label='loss')
8 | frl.play(env, learn.model, fps=8, playable=False) # if env is MultiGrid, try obs='total_obs', with_total_obs=True
9 | env.close()
10 |
--------------------------------------------------------------------------------