├── .github └── workflows │ └── install.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── c_gae.pyx ├── clean_pufferl.py ├── cleanrl_ppo_atari.py ├── config ├── atari │ ├── beam_rider.ini │ ├── breakout.ini │ ├── default.ini │ ├── enduro.ini │ ├── pong.ini │ ├── qbert.ini │ ├── seaquest.ini │ └── space_invaders.ini ├── box2d.ini ├── bsuite.ini ├── butterfly.ini ├── classic_control.ini ├── classic_control_continuous.ini ├── crafter.ini ├── default.ini ├── dm_control.ini ├── dm_lab.ini ├── doom.ini ├── gpudrive.ini ├── griddly.ini ├── gvgai.ini ├── magent.ini ├── microrts.ini ├── minerl.ini ├── minigrid.ini ├── minihack.ini ├── mujoco.ini ├── nethack.ini ├── nmmo.ini ├── ocean │ ├── breakout.ini │ ├── connect4.ini │ ├── continuous.ini │ ├── enduro.ini │ ├── go.ini │ ├── grid.ini │ ├── moba.ini │ ├── nmmo3.ini │ ├── pong.ini │ ├── pysquared.ini │ ├── rware.ini │ ├── sanity.ini │ ├── snake.ini │ ├── squared.ini │ ├── tactical.ini │ ├── trash_pickup.ini │ └── tripletriad.ini ├── open_spiel.ini ├── pokemon_red.ini ├── procgen.ini ├── slimevolley.ini ├── stable_retro.ini ├── starcraft.ini └── trade_sim.ini ├── demo.py ├── evaluate_elos.py ├── pufferlib ├── __init__.py ├── cleanrl.py ├── emulation.py ├── environment.py ├── environments │ ├── __init__.py │ ├── atari │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── box2d │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── bsuite │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── butterfly │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── classic_control │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── classic_control_continuous │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── crafter │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── dm_control │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── dm_lab │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── gpudrive │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── griddly │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── gvgai │ │ └── environment.py │ ├── links_awaken │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── magent │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── microrts │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── minerl │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── minigrid │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── minihack │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── mujoco │ │ ├── __init__.py │ │ ├── cleanrl.py │ │ ├── environment.py │ │ └── policy.py │ ├── nethack │ │ ├── Hack-Regular.ttf │ │ ├── __init__.py │ │ ├── environment.py │ │ ├── torch.py │ │ └── wrapper.py │ ├── nmmo │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── open_spiel │ │ ├── __init__.py │ │ ├── environment.py │ │ ├── gymnasium_environment.py │ │ ├── pettingzoo_environment.py │ │ ├── torch.py │ │ └── utils.py │ ├── pokemon_red │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── procgen │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── slimevolley │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── smac │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── stable_retro │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py │ ├── test │ │ ├── __init__.py │ │ ├── environment.py │ │ ├── mock_environments.py │ │ └── torch.py │ ├── trade_sim │ │ ├── 
__init__.py │ │ ├── environment.py │ │ └── torch.py │ └── vizdoom │ │ ├── __init__.py │ │ ├── environment.py │ │ └── torch.py ├── exceptions.py ├── extensions.pyx ├── models.py ├── namespace.py ├── ocean │ ├── __init__.py │ ├── breakout │ │ ├── breakout.c │ │ ├── breakout.h │ │ ├── breakout.py │ │ └── cy_breakout.pyx │ ├── connect4 │ │ ├── connect4.c │ │ ├── connect4.h │ │ ├── connect4.py │ │ ├── connect4game │ │ └── cy_connect4.pyx │ ├── enduro │ │ ├── cy_enduro.pyx │ │ ├── enduro.c │ │ ├── enduro.h │ │ └── enduro.py │ ├── environment.py │ ├── go │ │ ├── cy_go.pyx │ │ ├── go.c │ │ ├── go.h │ │ └── go.py │ ├── grid │ │ ├── __init__.py │ │ ├── c_grid.pyx │ │ ├── cy_grid.pyx │ │ ├── grid.c │ │ ├── grid.h │ │ └── grid.py │ ├── moba │ │ ├── __init__.py │ │ ├── cy_moba.pyx │ │ ├── game_map.h │ │ ├── moba.c │ │ ├── moba.h │ │ └── moba.py │ ├── nmmo3 │ │ ├── cy_nmmo3.pyx │ │ ├── make_sprite_sheets.py │ │ ├── nmmo3.c │ │ ├── nmmo3.h │ │ ├── nmmo3.py │ │ ├── simplex.h │ │ └── tile_atlas.h │ ├── pong │ │ ├── cy_pong.pyx │ │ ├── pong.c │ │ ├── pong.h │ │ └── pong.py │ ├── render.py │ ├── robocode │ │ ├── build_local.sh │ │ ├── robocode │ │ ├── robocode.c │ │ └── robocode.h │ ├── rocket_lander │ │ ├── cy_rocket_lander.pyx │ │ ├── rocket_lander.c │ │ ├── rocket_lander.h │ │ └── rocket_lander.py │ ├── rware │ │ ├── cy_rware.pyx │ │ ├── rware.c │ │ ├── rware.h │ │ └── rware.py │ ├── sanity.py │ ├── snake │ │ ├── README.md │ │ ├── __init__.py │ │ ├── cy_snake.pyx │ │ ├── snake.c │ │ ├── snake.h │ │ └── snake.py │ ├── squared │ │ ├── cy_squared.pyx │ │ ├── pysquared.py │ │ ├── squared.c │ │ ├── squared.h │ │ └── squared.py │ ├── tactical │ │ ├── __init__.py │ │ ├── build_local.sh │ │ ├── c_tactical.pyx │ │ ├── maps.h │ │ ├── tactical.c │ │ ├── tactical.h │ │ └── tactical.py │ ├── tcg │ │ ├── build_local.sh │ │ ├── build_web.sh │ │ ├── tcg.c │ │ └── tcg.h │ ├── torch.py │ ├── trash_pickup │ │ ├── README.md │ │ ├── cy_trash_pickup.pyx │ │ ├── trash_pickup.c │ │ ├── trash_pickup.h │ │ └── trash_pickup.py │ └── tripletriad │ │ ├── cy_tripletriad.pyx │ │ ├── tripletriad.c │ │ ├── tripletriad.h │ │ └── tripletriad.py ├── policy_ranker.py ├── policy_store.py ├── postprocess.py ├── puffernet.h ├── puffernet.pyx ├── pytorch.py ├── resources │ ├── breakout_weights.bin │ ├── connect4.pt │ ├── connect4_weights.bin │ ├── enduro │ │ ├── enduro_spritesheet.png │ │ └── enduro_weights.bin │ ├── go_weights.bin │ ├── moba │ │ ├── bloom_shader_100.fs │ │ ├── bloom_shader_330.fs │ │ ├── dota_map.png │ │ ├── game_map.npy │ │ ├── map_shader_100.fs │ │ ├── map_shader_330.fs │ │ ├── moba_assets.png │ │ └── moba_weights.bin │ ├── nmmo3 │ │ ├── ASSETS_LICENSE.md │ │ ├── air_0.png │ │ ├── air_1.png │ │ ├── air_2.png │ │ ├── air_3.png │ │ ├── air_4.png │ │ ├── air_5.png │ │ ├── air_6.png │ │ ├── air_7.png │ │ ├── air_8.png │ │ ├── air_9.png │ │ ├── earth_0.png │ │ ├── earth_1.png │ │ ├── earth_2.png │ │ ├── earth_3.png │ │ ├── earth_4.png │ │ ├── earth_5.png │ │ ├── earth_6.png │ │ ├── earth_7.png │ │ ├── earth_8.png │ │ ├── earth_9.png │ │ ├── fire_0.png │ │ ├── fire_1.png │ │ ├── fire_2.png │ │ ├── fire_3.png │ │ ├── fire_4.png │ │ ├── fire_5.png │ │ ├── fire_6.png │ │ ├── fire_7.png │ │ ├── fire_8.png │ │ ├── fire_9.png │ │ ├── inventory_64.png │ │ ├── inventory_64_press.png │ │ ├── inventory_64_selected.png │ │ ├── items_condensed.png │ │ ├── map_shader_100.fs │ │ ├── map_shader_330.fs │ │ ├── merged_sheet.png │ │ ├── neutral_0.png │ │ ├── neutral_1.png │ │ ├── neutral_2.png │ │ ├── neutral_3.png │ │ ├── neutral_4.png │ │ ├── 
neutral_5.png │ │ ├── neutral_6.png │ │ ├── neutral_7.png │ │ ├── neutral_8.png │ │ ├── neutral_9.png │ │ ├── nmmo3_help.png │ │ ├── nmmo_1500.bin │ │ ├── nmmo_2025.bin │ │ ├── water_0.png │ │ ├── water_1.png │ │ ├── water_2.png │ │ ├── water_3.png │ │ ├── water_4.png │ │ ├── water_5.png │ │ ├── water_6.png │ │ ├── water_7.png │ │ ├── water_8.png │ │ └── water_9.png │ ├── pong_weights.bin │ ├── puffers_128.png │ ├── robocode │ │ └── robocode.png │ ├── rware_weights.bin │ ├── snake_weights.bin │ └── tripletriad_weights.bin ├── spaces.py ├── utils.py ├── vector.py ├── version.py └── wrappers.py ├── pyproject.toml ├── resources ├── save_net_flat.py ├── sb3_demo.py ├── scripts ├── build_ocean.sh ├── minshell.html ├── sweep_atari.sh ├── train_atari.sh ├── train_ocean.sh ├── train_procgen.sh └── train_sanity.sh ├── setup.py └── tests ├── __init__.py ├── mem_test.py ├── pool ├── envpool_results.npy ├── plot_packing.py ├── test_basic_multprocessing.py ├── test_envpool.py └── test_multiprocessing.py ├── test.py ├── test_api.py ├── test_atari_reset.py ├── test_carbs.py ├── test_cleanrl_utils.py ├── test_extensions.py ├── test_flatten.py ├── test_import_performance.py ├── test_namespace.py ├── test_nested.py ├── test_nmmo3_compile.py ├── test_performance.py ├── test_pokemon_red.py ├── test_policy_pool.py ├── test_puffernet.py ├── test_pytorch.py ├── test_record_array.py ├── test_record_emulation.py ├── test_registry.sh ├── test_render.py ├── test_rich.py ├── test_utils.py └── time_alloc.py /.github/workflows/install.yml: -------------------------------------------------------------------------------- 1 | name: install 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | test: 9 | name: test ${{ matrix.py }} - ${{ matrix.os }} - ${{ matrix.env }} 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: 15 | - ubuntu-latest 16 | - macos-latest 17 | py: 18 | - "3.11" 19 | - "3.10" 20 | - "3.9" 21 | env: 22 | - pip 23 | - conda 24 | steps: 25 | - name: Checkout code 26 | uses: actions/checkout@v3 27 | 28 | - name: Setup Conda 29 | if: matrix.env == 'conda' 30 | uses: conda-incubator/setup-miniconda@v2 31 | with: 32 | python-version: ${{ matrix.py }} 33 | miniconda-version: "latest" 34 | activate-environment: test-env 35 | auto-update-conda: true 36 | 37 | - name: Setup Python for pip 38 | if: matrix.env == 'pip' 39 | uses: actions/setup-python@v4 40 | with: 41 | python-version: ${{ matrix.py }} 42 | 43 | - name: Upgrade pip 44 | run: python -m pip install -U pip 45 | 46 | - name: Install pufferlib 47 | run: pip install -e . 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 PufferAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-include *.pyx 2 | global-include *.pxd 3 | global-include *.h 4 | global-include *.py 5 | recursive-include pufferlib/resources * 6 | recursive-exclude experiments * 7 | recursive-exclude wandb * 8 | recursive-exclude tests * 9 | include raylib-5.0_linux_amd64/lib/libraylib.a 10 | include raylib-5.0_macos/lib/libraylib.a 11 | recursive-exclude raylib-5.0_webassembly * 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![figure](https://pufferai.github.io/source/resource/header.png) 2 | 3 | [![PyPI version](https://badge.fury.io/py/pufferlib.svg)](https://badge.fury.io/py/pufferlib) 4 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pufferlib) 5 | ![Github Actions](https://github.com/PufferAI/PufferLib/actions/workflows/install.yml/badge.svg) 6 | [![](https://dcbadge.vercel.app/api/server/spT4huaGYV?style=plastic)](https://discord.gg/spT4huaGYV) 7 | [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow%20%40jsuarez5341)](https://twitter.com/jsuarez5341) 8 | 9 | PufferLib is the reinforcement learning library I wish existed during my PhD. It started as a compatibility layer to make working with complex environments a breeze. Now, it's a high-performance toolkit for research and industry with optimized parallel simulation, environments that run and train at 1M+ steps/second, and tons of quality of life improvements for practitioners. All our tools are free and open source. We also offer priority service for companies, startups, and labs! 10 | 11 | ![Trailer](https://github.com/PufferAI/puffer.ai/blob/main/docs/assets/puffer_2.gif?raw=true) 12 | 13 | All of our documentation is hosted at [puffer.ai](https://puffer.ai "PufferLib Documentation"). @jsuarez5341 on [Discord](https://discord.gg/puffer) for support -- post here before opening issues. We're always looking for new contributors, too! 14 | 15 | ## Star to puff up the project! 
16 | 17 | 18 | 19 | 20 | 21 | Star History Chart 22 | 23 | 24 | -------------------------------------------------------------------------------- /c_gae.pyx: -------------------------------------------------------------------------------- 1 | # distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION 2 | # cython: language_level=3 3 | # cython: boundscheck=False 4 | # cython: initializedcheck=False 5 | # cython: wraparound=False 6 | # cython: nonecheck=False 7 | 8 | import numpy as np 9 | cimport numpy as cnp 10 | 11 | def compute_gae(cnp.ndarray dones, cnp.ndarray values, 12 | cnp.ndarray rewards, float gamma, float gae_lambda): 13 | '''Fast Cython implementation of Generalized Advantage Estimation (GAE)''' 14 | cdef int num_steps = len(rewards) 15 | cdef cnp.ndarray advantages = np.zeros(num_steps, dtype=np.float32) 16 | cdef float[:] c_advantages = advantages 17 | cdef float[:] c_dones = dones 18 | cdef float[:] c_values = values 19 | cdef float[:] c_rewards = rewards 20 | 21 | cdef float lastgaelam = 0 22 | cdef float nextnonterminal, delta 23 | cdef int t, t_cur, t_next 24 | for t in range(num_steps-1): 25 | t_cur = num_steps - 2 - t 26 | t_next = num_steps - 1 - t 27 | nextnonterminal = 1.0 - c_dones[t_next] 28 | delta = c_rewards[t_next] + gamma * c_values[t_next] * nextnonterminal - c_values[t_cur] 29 | lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam 30 | c_advantages[t_cur] = lastgaelam 31 | 32 | return advantages 33 | 34 | 35 | -------------------------------------------------------------------------------- /config/atari/beam_rider.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = atari 3 | env_name = beam_rider 4 | 5 | [train] 6 | total_timesteps = 9_762_779 7 | batch_size = 65_536 8 | minibatch_size = 1024 9 | update_epochs = 3 10 | bptt_horizon = 2 11 | learning_rate = 0.00041171401568673385 12 | gae_lambda = 0.14527976163861273 13 | gamma = 0.990622479610104 14 | ent_coef = 0.010996558409985507 15 | clip_coef = 0.4966414480536032 16 | vf_clip_coef = 0.13282356641582535 17 | vf_coef = 0.985913502481555 18 | max_grad_norm = 0.9385297894477844 19 | 20 | [env] 21 | frameskip = 4 22 | repeat_action_probability = 0.0 23 | -------------------------------------------------------------------------------- /config/atari/breakout.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = atari 3 | env_name = breakout 4 | 5 | [train] 6 | total_timesteps = 25_000_000 7 | batch_size = 65_536 8 | minibatch_size = 512 9 | update_epochs = 3 10 | bptt_horizon = 16 11 | learning_rate = 0.0005426281444434721 12 | gae_lambda = 0.8538481400576657 13 | gamma = 0.9955183835557186 14 | ent_coef = 0.003385776743651383 15 | clip_coef = 0.07485999166174963 16 | vf_clip_coef = 0.20305217614276536 17 | vf_coef = 0.18278144162218027 18 | max_grad_norm = 1.0495498180389404 19 | 20 | [env] 21 | frameskip = 4 22 | repeat_action_probability = 0.0 23 | -------------------------------------------------------------------------------- /config/atari/default.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = atari 3 | 4 | env_name = adventure air_raid alien amidar assault asterix asteroids atlantis2 atlantis backgammon bank_heist basic_math battle_zone berzerk blackjack bowling boxing carnival casino centipede chopper_command combat crazy_climber crossbow darkchambers defender demon_attack donkey_kong double_dunk 
earthworld elevator_action entombed et fishing_derby flag_capture freeway frogger frostbite galaxian gopher gravitar hangman haunted_house hero human_cannonball ice_hockey jamesbond journey_escape joust kaboom kangaroo keystone_kapers king_kong klax koolaid krull kung_fu_master laser_gates lost_luggage mario_bros maze_craze miniature_golf montezuma_revenge mr_do ms_pacman name_this_game othello pacman phoenix pitfall2 pitfall pooyan private_eye riverraid road_runner robotank sir_lancelot skiing solaris space_war star_gunner superman surround tennis tetris tic_tac_toe_3d time_pilot trondead turmoil tutankham up_n_down venture video_checkers video_chess video_cube video_pinball warlords wizard_of_wor word_zapper yars_revenge zaxxon 5 | 6 | policy_name = Policy 7 | rnn_name = Recurrent 8 | 9 | [train] 10 | num_envs = 144 11 | num_workers = 24 12 | env_batch_size = 48 13 | zero_copy = False 14 | batch_size = 32768 15 | minibatch_size = 1024 16 | update_epochs = 2 17 | bptt_horizon = 8 18 | total_timesteps = 10_000_000 19 | anneal_lr = False 20 | 21 | [env] 22 | frameskip = 4 23 | repeat_action_probability = 0.0 24 | 25 | [sweep.parameters.env.parameters.frameskip] 26 | distribution = uniform 27 | min = 1 28 | max = 10 29 | 30 | #[sweep.parameters.env.parameters.repeat_action_probability] 31 | #distribution = uniform 32 | #min = 0 33 | #max = 1 34 | 35 | [sweep.parameters.train.parameters.total_timesteps] 36 | distribution = uniform 37 | min = 5_000_000 38 | max = 200_000_000 39 | 40 | [sweep.parameters.train.parameters.batch_size] 41 | distribution = uniform 42 | min = 16384 43 | max = 65536 44 | 45 | [sweep.parameters.train.parameters.minibatch_size] 46 | distribution = uniform 47 | min = 512 48 | max = 8192 49 | -------------------------------------------------------------------------------- /config/atari/enduro.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = atari 3 | env_name = enduro 4 | 5 | [train] 6 | total_timesteps = 16_657_125 7 | batch_size = 16384 8 | minibatch_size = 4096 9 | update_epochs = 6 10 | bptt_horizon = 8 11 | learning_rate = 0.0003495115734491776 12 | gae_lambda = 0.5996818325556474 13 | gamma = 0.9895491287086732 14 | ent_coef = 0.0021720638001863288 15 | clip_coef = 0.7140062837950062 16 | vf_clip_coef = 0.02629607191897852 17 | vf_coef = 0.9842251826587504 18 | max_grad_norm = 0.8422542810440063 19 | 20 | [env] 21 | frameskip = 6 22 | repeat_action_probability = 0.0 23 | -------------------------------------------------------------------------------- /config/atari/pong.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = atari 3 | env_name = pong 4 | rnn_name = Recurrent 5 | 6 | [train] 7 | total_timesteps = 5_000_000 8 | batch_size = 32768 9 | minibatch_size = 1024 10 | update_epochs = 2 11 | bptt_horizon = 8 12 | learning_rate = 0.0006112614226003401 13 | gae_lambda = 0.9590507508564148 14 | gamma = 0.9671759718055382 15 | ent_coef = 0.01557519441744131 16 | clip_coef = 0.3031963355045393 17 | vf_clip_coef = 0.13369578727174328 18 | vf_coef = 0.9274225135298954 19 | max_grad_norm = 1.392141580581665 20 | 21 | [env] 22 | frameskip = 4 23 | repeat_action_probability = 0.0 24 | -------------------------------------------------------------------------------- /config/atari/qbert.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = atari 3 | env_name = qbert 4 | 5 | [train] 6 | total_timesteps = 
15_000_000 7 | batch_size = 32_768 8 | minibatch_size = 1024 9 | update_epochs = 3 10 | bptt_horizon = 8 11 | learning_rate = 0.00104284086325656 12 | gae_lambda = 0.8573007456819492 13 | gamma = 0.9426362777287904 14 | ent_coef = 0.025180053429464784 15 | clip_coef = 0.23123278532103236 16 | vf_clip_coef = 0.12751979973690886 17 | vf_coef = 0.5903166418793799 18 | max_grad_norm = 0.1610541045665741 19 | 20 | [env] 21 | frameskip = 4 22 | repeat_action_probability = 0.0 23 | -------------------------------------------------------------------------------- /config/atari/seaquest.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = atari 3 | env_name = seaquest 4 | 5 | [train] 6 | total_timesteps = 6_683_099 7 | batch_size = 65_536 8 | minibatch_size = 4096 9 | update_epochs = 4 10 | bptt_horizon = 8 11 | learning_rate = 0.0012161186756094724 12 | gae_lambda = 0.6957035398791592 13 | gamma = 0.9925043678586688 14 | ent_coef = 0.032082891906869346 15 | clip_coef = 0.24504077831570073 16 | vf_clip_coef = 0.18204547640437296 17 | vf_coef = 0.5850012910005633 18 | max_grad_norm = 0.6649078130722046 19 | 20 | [env] 21 | frameskip = 4 22 | repeat_action_probability = 0.0 23 | -------------------------------------------------------------------------------- /config/atari/space_invaders.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = atari 3 | env_name = space_invaders 4 | 5 | [train] 6 | total_timesteps = 7_500_000 7 | batch_size = 65_536 8 | minibatch_size = 4_096 9 | update_epochs = 4 10 | bptt_horizon = 8 11 | learning_rate = 0.0012161186756094724 12 | gae_lambda = 0.6957035398791592 13 | gamma = 0.9925043678586688 14 | ent_coef = 0.032082891906869346 15 | clip_coef = 0.24504077831570073 16 | vf_clip_coef = 0.18204547640437296 17 | vf_coef = 0.5850012910005633 18 | max_grad_norm = 0.6649078130722046 19 | 20 | [env] 21 | frameskip = 4 22 | repeat_action_probability = 0.0 23 | -------------------------------------------------------------------------------- /config/box2d.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = box2d 3 | env_name = car-racing 4 | 5 | [train] 6 | num_envs = 48 7 | num_workers = 24 8 | env_batch_size = 48 9 | zero_copy = False 10 | batch_size = 16384 11 | minibatch_size = 2048 12 | -------------------------------------------------------------------------------- /config/bsuite.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = bsuite 3 | env_name = bandit/0 4 | 5 | [train] 6 | total_timesteps = 1_000_000 7 | num_envs = 1 8 | -------------------------------------------------------------------------------- /config/butterfly.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = butterfly 3 | env_name = cooperative_pong_v5 4 | -------------------------------------------------------------------------------- /config/classic_control.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = classic_control 3 | env_name = cartpole mountaincar 4 | 5 | [train] 6 | total_timesteps = 500_000 7 | num_envs = 64 8 | env_batch_size = 64 9 | -------------------------------------------------------------------------------- /config/classic_control_continuous.ini: -------------------------------------------------------------------------------- 1 | [base] 
2 | package = classic_control_continuous 3 | env_name = mountaincar-continuous 4 | 5 | [train] 6 | total_timesteps = 500_000 7 | num_envs = 64 8 | env_batch_size = 64 9 | -------------------------------------------------------------------------------- /config/crafter.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = crafter 3 | env_name = crafter 4 | 5 | [train] 6 | num_envs = 96 7 | num_workers = 24 8 | env_batch_size = 48 9 | zero_copy = False 10 | batch_size = 6144 11 | -------------------------------------------------------------------------------- /config/default.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = None 3 | env_name = None 4 | vec = native 5 | policy_name = Policy 6 | rnn_name = None 7 | max_suggestion_cost = 3600 8 | 9 | [env] 10 | 11 | [policy] 12 | 13 | [rnn] 14 | 15 | [train] 16 | seed = 1 17 | torch_deterministic = True 18 | cpu_offload = False 19 | device = cuda 20 | total_timesteps = 10_000_000 21 | learning_rate = 2.5e-4 22 | anneal_lr = True 23 | gamma = 0.99 24 | gae_lambda = 0.95 25 | update_epochs = 4 26 | norm_adv = True 27 | clip_coef = 0.1 28 | clip_vloss = True 29 | vf_coef = 0.5 30 | vf_clip_coef = 0.1 31 | max_grad_norm = 0.5 32 | ent_coef = 0.01 33 | target_kl = None 34 | 35 | num_envs = 8 36 | num_workers = 8 37 | env_batch_size = None 38 | zero_copy = True 39 | data_dir = experiments 40 | checkpoint_interval = 200 41 | batch_size = 1024 42 | minibatch_size = 512 43 | bptt_horizon = 16 44 | compile = False 45 | compile_mode = reduce-overhead 46 | 47 | [sweep] 48 | method = bayes 49 | name = sweep 50 | 51 | [sweep.metric] 52 | goal = maximize 53 | name = environment/episode_return 54 | 55 | [sweep.parameters.train.parameters.learning_rate] 56 | distribution = log_uniform_values 57 | min = 1e-5 58 | max = 1e-1 59 | 60 | [sweep.parameters.train.parameters.gamma] 61 | distribution = uniform 62 | min = 0.0 63 | max = 1.0 64 | 65 | [sweep.parameters.train.parameters.gae_lambda] 66 | distribution = uniform 67 | min = 0.0 68 | max = 1.0 69 | 70 | [sweep.parameters.train.parameters.update_epochs] 71 | distribution = int_uniform 72 | min = 1 73 | max = 4 74 | 75 | [sweep.parameters.train.parameters.vf_coef] 76 | distribution = uniform 77 | min = 0.0 78 | max = 1.0 79 | 80 | [sweep.parameters.train.parameters.max_grad_norm] 81 | distribution = uniform 82 | min = 0.0 83 | max = 10.0 84 | 85 | [sweep.parameters.train.parameters.ent_coef] 86 | distribution = log_uniform_values 87 | min = 1e-5 88 | max = 1e-1 89 | 90 | [sweep.parameters.train.parameters.bptt_horizon] 91 | values = [1, 2, 4, 8, 16] 92 | -------------------------------------------------------------------------------- /config/dm_control.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = dm_control 3 | env_name = dmc 4 | -------------------------------------------------------------------------------- /config/dm_lab.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = dm_lab 3 | env_name = dml 4 | -------------------------------------------------------------------------------- /config/doom.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = vizdoom 3 | env_name = doom 4 | 5 | [train] 6 | num_envs = 144 7 | num_workers = 24 8 | env_batch_size = 48 9 | zero_copy = False 10 | batch_size = 8192 11 | minibatch_size 
= 2048 12 | update_epochs = 1 13 | -------------------------------------------------------------------------------- /config/gpudrive.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = gpudrive 3 | env_name = gpudrive 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [env] 8 | num_worlds = 512 9 | 10 | [train] 11 | total_timesteps = 10_000_000 12 | num_envs = 1 13 | num_workers = 1 14 | env_batch_size = 1 15 | zero_copy = False 16 | batch_size = 262_144 17 | update_epochs = 5 18 | minibatch_size = 32768 19 | bptt_horizon = 4 20 | anneal_lr = False 21 | gae_lambda = 0.95 22 | gamma = 0.99 23 | clip_coef = 0.2 24 | vf_coef = 0.5 25 | vf_clip_coef = 0.2 26 | max_grad_norm = 0.5 27 | ent_coef = 0.00 28 | learning_rate = 0.0003 29 | checkpoint_interval = 1000 30 | device = cuda 31 | 32 | [sweep.metric] 33 | goal = maximize 34 | name = environment/goal_achieved 35 | 36 | [sweep.parameters.train.parameters.batch_size] 37 | distribution = uniform 38 | min = 32768 39 | max = 524288 40 | 41 | [sweep.parameters.train.parameters.minibatch_size] 42 | distribution = uniform 43 | min = 2048 44 | max = 32768 45 | 46 | 47 | -------------------------------------------------------------------------------- /config/griddly.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = griddly 3 | env_name = spiders 4 | -------------------------------------------------------------------------------- /config/gvgai.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = gvgai 3 | env_name = zelda 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [train] 8 | total_timesteps = 1_000_000 9 | checkpoint_interval = 1000 10 | learning_rate = 0.00024290984560207393 11 | num_envs = 96 12 | num_workers = 24 13 | env_batch_size = 32 14 | update_epochs = 1 15 | zero_copy = False 16 | bptt_horizon = 16 17 | batch_size = 4096 18 | minibatch_size = 1024 19 | compile = False 20 | anneal_lr = False 21 | device = cuda 22 | 23 | [sweep.metric] 24 | goal = maximize 25 | name = environment/reward 26 | 27 | [sweep.parameters.train.parameters.total_timesteps] 28 | distribution = log_uniform_values 29 | min = 500_000_000 30 | max = 10_000_000_000 31 | 32 | [sweep.parameters.train.parameters.batch_size] 33 | values = [4096, 8192, 16384] 34 | 35 | [sweep.parameters.train.parameters.minibatch_size] 36 | values = [512, 1024, 2048, 4096] 37 | -------------------------------------------------------------------------------- /config/magent.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = magent 3 | env_name = battle_v4 4 | -------------------------------------------------------------------------------- /config/microrts.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = microrts 3 | env_name = GlobalAgentCombinedRewardEnv 4 | -------------------------------------------------------------------------------- /config/minerl.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = minerl 3 | env_name = MineRLNavigateDense-v0 4 | -------------------------------------------------------------------------------- /config/minigrid.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = minigrid 3 | env_name = minigrid 4 | 5 | [train] 6 | total_timesteps 
= 1_000_000 7 | num_envs = 48 8 | num_workers = 6 9 | env_batch_size = 48 10 | batch_size = 6144 11 | minibatch_size = 768 12 | update_epochs = 4 13 | anneal_lr = False 14 | gae_lambda = 0.95 15 | gamma = 0.95 16 | ent_coef = 0.025 17 | learning_rate = 2.5e-4 18 | -------------------------------------------------------------------------------- /config/minihack.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = minihack 3 | env_name = minihack 4 | 5 | [train] 6 | num_envs = 48 7 | num_workers = 24 8 | env_batch_size = 48 9 | zero_copy = False 10 | batch_size = 6144 11 | minibatch_size = 1536 12 | -------------------------------------------------------------------------------- /config/mujoco.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = mujoco 3 | env_name = HalfCheetah-v4 Hopper-v4 Swimmer-v4 Walker2d-v4 Ant-v4 Humanoid-v4 Reacher-v4 InvertedPendulum-v4 InvertedDoublePendulum-v4 Pusher-v4 HumanoidStandup-v4 4 | policy_name = CleanRLPolicy 5 | # rnn_name = Recurrent 6 | 7 | # The following is from https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action.py 8 | [train] 9 | seed = 1 10 | torch_deterministic = True 11 | cpu_offload = False 12 | device = cuda 13 | total_timesteps = 1_000_000 14 | learning_rate = 3e-4 15 | anneal_lr = True 16 | gamma = 0.99 17 | gae_lambda = 0.95 18 | update_epochs = 10 19 | norm_adv = True 20 | clip_coef = 0.2 21 | clip_vloss = True 22 | vf_coef = 0.5 23 | vf_clip_coef = 0.2 24 | max_grad_norm = 0.5 25 | ent_coef = 0.0 26 | target_kl = None 27 | 28 | num_envs = 1 29 | num_workers = 1 30 | env_batch_size = None 31 | zero_copy = False 32 | data_dir = experiments 33 | checkpoint_interval = 200 34 | batch_size = 2048 35 | minibatch_size = 32 36 | bptt_horizon = 1 37 | compile = False 38 | compile_mode = reduce-overhead -------------------------------------------------------------------------------- /config/nethack.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = nethack 3 | env_name = nethack 4 | 5 | [train] 6 | num_envs = 72 7 | num_workers = 24 8 | env_batch_size = 48 9 | zero_copy = False 10 | batch_size = 6144 11 | minibatch_size = 1536 12 | update_epochs = 1 13 | -------------------------------------------------------------------------------- /config/nmmo.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = nmmo 3 | env_name = nmmo 4 | 5 | [train] 6 | num_envs = 4 7 | env_batch_size = 4 8 | num_workers = 4 9 | batch_size = 4096 10 | minibatch_size = 2048 11 | -------------------------------------------------------------------------------- /config/ocean/breakout.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_breakout 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [env] 8 | num_envs = 1024 9 | 10 | [train] 11 | total_timesteps =300000000 12 | checkpoint_interval = 50 13 | num_envs = 1 14 | num_workers = 1 15 | env_batch_size = 1 16 | batch_size = 262144 17 | update_epochs = 4 18 | learning_rate = 0.0005978428084749276 19 | minibatch_size = 4096 20 | bptt_horizon = 16 21 | anneal_lr = False 22 | device = cuda 23 | gamma = 0.9257755108746066 24 | gae_lambda = 0.8783667470139129 25 | ent_coef = 0.0027080029654114927 26 | max_grad_norm = 0.3808319568634033 27 | vf_coef = 0.17343129599886223 28 | 
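Each per-environment [train] table above overrides the shared defaults in config/default.ini. As a rough illustration of that layering (a minimal sketch using only the standard library, not the actual loader wired into demo.py or clean_pufferl.py), the two files can be merged like this:

    # Illustrative sketch only (assumed helper, not the repo's loader):
    # merge config/default.ini with a per-environment override such as
    # config/ocean/breakout.ini.
    import configparser

    def load_train_config(env_ini='config/ocean/breakout.ini'):
        parser = configparser.ConfigParser()
        # Later files win, mirroring the default -> per-env layering of these configs.
        parser.read(['config/default.ini', env_ini])
        train = {}
        for key, raw in parser['train'].items():
            # INI values are plain strings; underscored literals such as 10_000_000
            # are accepted by int()/float(), so numeric keys convert directly.
            for cast in (int, float):
                try:
                    train[key] = cast(raw)
                    break
                except ValueError:
                    pass
            else:
                train[key] = raw  # booleans, device names, etc. stay strings here
        return train

    if __name__ == '__main__':
        cfg = load_train_config()
        print(cfg['total_timesteps'], cfg['batch_size'], cfg['learning_rate'])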
-------------------------------------------------------------------------------- /config/ocean/connect4.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_connect4 4 | vec = multiprocessing 5 | policy_name = Policy 6 | rnn_name = Recurrent 7 | 8 | [env] 9 | num_envs = 512 10 | 11 | [train] 12 | total_timesteps = 10_000_000 13 | checkpoint_interval = 50 14 | num_envs = 8 15 | num_workers = 8 16 | env_batch_size = 1 17 | batch_size = 32768 18 | update_epochs = 3 19 | minibatch_size = 8192 20 | bptt_horizon = 8 21 | max_grad_norm = 0.05481921136379242 22 | learning_rate = 0.00859505079095484 23 | env_coef = 0.02805873082160289 24 | gae_lambda = 0.2930961059311335 25 | gamma = 0.978843792530436 26 | vf_coef = 0.960235238467549 27 | anneal_lr = False 28 | device = cuda 29 | 30 | [sweep.metric] 31 | goal = maximize 32 | name = environment/score 33 | 34 | [sweep.parameters.train.parameters.total_timesteps] 35 | distribution = uniform 36 | min = 10_000_000 37 | max = 100_000_000 38 | -------------------------------------------------------------------------------- /config/ocean/continuous.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_continuous 4 | 5 | [train] 6 | total_timesteps = 1_000_000 7 | anneal_lr = False 8 | num_envs = 64 9 | batch_size = 16384 10 | minibatch_size = 4096 11 | update_epochs = 1 12 | gamma = 0.8 13 | ent_coef = 0.05 14 | -------------------------------------------------------------------------------- /config/ocean/enduro.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_enduro 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [env] 8 | num_envs = 4096 9 | 10 | [train] 11 | total_timesteps = 500_000_000 12 | checkpoint_interval = 200 13 | num_envs = 1 14 | num_workers = 1 15 | env_batch_size = 1 16 | batch_size = 131072 17 | update_epochs = 1 18 | minibatch_size = 16384 19 | bptt_horizon = 16 20 | clip_coef = 0.2 21 | vf_clip_coef = 0.2 22 | vf_coef = 0.5 23 | ent_coef = 0.005 24 | gae_lambda = 0.95 25 | gamma = 0.97 26 | learning_rate = 0.001 27 | max_grad_norm = 0.5 28 | anneal_lr = False 29 | device = cuda 30 | -------------------------------------------------------------------------------- /config/ocean/go.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_go 4 | vec = multiprocessing 5 | policy_name = Go 6 | rnn_name = Recurrent 7 | 8 | [env] 9 | num_envs = 2048 10 | reward_move_pass = -0.47713279724121094 11 | reward_move_valid = 0 12 | reward_move_invalid = -0.47179355621337893 13 | reward_opponent_capture = -0.5240603446960449 14 | reward_player_capture = 0.22175729274749756 15 | grid_size = 7 16 | 17 | [train] 18 | total_timesteps = 2_000_000_000 19 | checkpoint_interval = 50 20 | num_envs = 2 21 | num_workers = 2 22 | env_batch_size =1 23 | batch_size = 524288 24 | update_epochs = 1 25 | minibatch_size = 131072 26 | bptt_horizon = 16 27 | learning_rate = 0.0015 28 | ent_coef = 0.013460194258584548 29 | gae_lambda = 0.90 30 | gamma = 0.95 31 | max_grad_norm = 0.8140400052070618 32 | vf_coef = 0.48416485817685223 33 | anneal_lr = False 34 | device = cpu 35 | 36 | [sweep.parameters.env.parameters.reward_move_invalid] 37 | distribution = uniform 38 | min = -1.0 39 | max = 0.0 40 | 41 | 
[sweep.parameters.env.parameters.reward_move_pass] 42 | distribution = uniform 43 | min = -1.0 44 | max = 0.0 45 | 46 | [sweep.parameters.env.parameters.reward_player_capture] 47 | distribution = uniform 48 | min = 0.0 49 | max = 1.0 50 | 51 | [sweep.parameters.env.parameters.reward_opponent_capture] 52 | distribution = uniform 53 | min = -1.0 54 | max = 0.0 55 | -------------------------------------------------------------------------------- /config/ocean/grid.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_grid 4 | vec = multiprocessing 5 | policy_name = Policy 6 | rnn_name = Recurrent 7 | 8 | [train] 9 | total_timesteps = 100_000_000 10 | checkpoint_interval = 1000 11 | learning_rate = 0.001 12 | num_envs = 2 13 | num_workers = 2 14 | env_batch_size = 1 15 | update_epochs = 1 16 | bptt_horizon = 16 17 | batch_size = 131072 18 | minibatch_size = 32768 19 | compile = False 20 | anneal_lr = False 21 | device = cuda 22 | 23 | [sweep.metric] 24 | goal = maximize 25 | name = environment/episode_return 26 | 27 | [sweep.parameters.train.parameters.batch_size] 28 | distribution = uniform 29 | min = 65536 30 | max = 524288 31 | 32 | [sweep.parameters.train.parameters.minibatch_size] 33 | distribution = uniform 34 | min = 8192 35 | max = 65536 36 | -------------------------------------------------------------------------------- /config/ocean/moba.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_moba 4 | vec = multiprocessing 5 | policy_name = MOBA 6 | rnn_name = Recurrent 7 | 8 | [train] 9 | total_timesteps = 250_000_000 10 | checkpoint_interval = 50 11 | learning_rate = 0.0033394116514234556 12 | num_envs = 8 13 | num_workers = 8 14 | env_batch_size = 4 15 | update_epochs = 3 16 | gamma = 0.9885385317249888 17 | gae_lambda = 0.8723856970733372 18 | clip_coef = 0.1 19 | vf_clip_coef = 0.1 20 | vf_coef = 0.0957932474946704 21 | ent_coef = 0.00006591576198600687 22 | max_grad_norm = 1.8240838050842283 23 | bptt_horizon = 16 24 | batch_size = 1024000 25 | minibatch_size = 256000 26 | compile = False 27 | anneal_lr = False 28 | device = cuda 29 | 30 | [env] 31 | reward_death = 0.0 32 | reward_xp = 0.0016926873475313188 33 | reward_distance = 0.0 34 | reward_tower = 4.525112152099609 35 | num_envs = 100 36 | 37 | [sweep.metric] 38 | goal = maximize 39 | name = environment/radiant_towers_alive 40 | 41 | [sweep.parameters.env.parameters.reward_death] 42 | distribution = uniform 43 | min = -5.0 44 | max = 0 45 | 46 | [sweep.parameters.env.parameters.reward_xp] 47 | distribution = uniform 48 | min = 0.0 49 | max = 0.05 50 | 51 | [sweep.parameters.env.parameters.reward_tower] 52 | distribution = uniform 53 | min = 0.0 54 | max = 5.0 55 | 56 | [sweep.parameters.train.parameters.total_timesteps] 57 | distribution = uniform 58 | min = 200_000_000 59 | max = 2_000_000_000 60 | -------------------------------------------------------------------------------- /config/ocean/nmmo3.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_nmmo3 4 | vec = multiprocessing 5 | policy_name = NMMO3 6 | rnn_name = NMMO3LSTM 7 | 8 | [train] 9 | total_timesteps = 107000000000 10 | checkpoint_interval = 1000 11 | learning_rate = 0.0004573146765703167 12 | num_envs = 2 13 | num_workers = 2 14 | env_batch_size = 1 15 | update_epochs = 1 16 | gamma = 0.7647543366891623 17 
| gae_lambda = 0.996005622445478 18 | ent_coef = 0.01210084358004069 19 | max_grad_norm = 0.6075578331947327 20 | vf_coef = 0.3979089612467003 21 | bptt_horizon = 16 22 | batch_size = 262144 23 | minibatch_size = 32768 24 | compile = False 25 | anneal_lr = False 26 | 27 | [env] 28 | reward_combat_level = 2.9437930583953857 29 | reward_prof_level = 1.445250153541565 30 | reward_item_level = 1.3669428825378418 31 | reward_market = 0 32 | reward_death = -2.46451187133789 33 | 34 | [sweep.metric] 35 | goal = maximize 36 | name = environment/min_comb_prof 37 | 38 | [sweep.parameters.env.parameters.reward_combat_level] 39 | distribution = uniform 40 | min = 0.0 41 | max = 5.0 42 | 43 | [sweep.parameters.env.parameters.reward_prof_level] 44 | distribution = uniform 45 | min = 0.0 46 | max = 5.0 47 | 48 | [sweep.parameters.env.parameters.reward_item_level] 49 | distribution = uniform 50 | min = 0.0 51 | max = 5.0 52 | 53 | [sweep.parameters.env.parameters.reward_death_mmo] 54 | distribution = uniform 55 | min = -5.0 56 | max = 0.0 57 | 58 | [sweep.parameters.train.parameters.total_timesteps] 59 | distribution = uniform 60 | min = 1_000_000_000 61 | max = 10_000_000_000 62 | 63 | -------------------------------------------------------------------------------- /config/ocean/pong.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_pong 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [env] 8 | num_envs = 1024 9 | 10 | [train] 11 | total_timesteps = 20_000_000 12 | checkpoint_interval = 25 13 | num_envs = 1 14 | num_workers = 1 15 | env_batch_size = 1 16 | batch_size = 131072 17 | update_epochs = 3 18 | minibatch_size = 8192 19 | bptt_horizon = 16 20 | ent_coef = 0.004602 21 | gae_lambda = 0.979 22 | gamma = 0.9879 23 | learning_rate = 0.001494 24 | anneal_lr = False 25 | device = cuda 26 | max_grad_norm = 3.592 27 | vf_coef = 0.4122 28 | 29 | [sweep.metric] 30 | goal = maximize 31 | name = environment/score 32 | 33 | [sweep.parameters.train.parameters.total_timesteps] 34 | distribution = uniform 35 | min = 10_000_000 36 | max = 30_000_000 37 | -------------------------------------------------------------------------------- /config/ocean/pysquared.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_pysquared 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [env] 8 | num_envs = 1 9 | 10 | [train] 11 | total_timesteps = 40_000_000 12 | checkpoint_interval = 50 13 | num_envs = 12288 14 | num_workers = 12 15 | env_batch_size = 4096 16 | batch_size = 131072 17 | update_epochs = 1 18 | minibatch_size = 8192 19 | learning_rate = 0.0017 20 | anneal_lr = False 21 | device = cuda 22 | -------------------------------------------------------------------------------- /config/ocean/rware.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_rware 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [env] 8 | num_envs = 1024 9 | 10 | [train] 11 | total_timesteps = 175000000 12 | checkpoint_interval = 25 13 | num_envs = 1 14 | num_workers = 1 15 | env_batch_size = 1 16 | batch_size = 131072 17 | update_epochs = 1 18 | minibatch_size = 32768 19 | bptt_horizon = 8 20 | anneal_lr = False 21 | ent_coef = 0.019885424670094166 22 | device = cuda 23 | learning_rate=0.0018129721882644975 24 | gamma = 0.9543211781474217 25 | gae_lambda = 
0.8297991396183212 26 | vf_coef = 0.3974834958825928 27 | clip_coef = 0.1 28 | vf_clip_coef = 0.1 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /config/ocean/sanity.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_bandit puffer_memory puffer_multiagent puffer_password puffer_spaces puffer_stochastic 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [train] 8 | total_timesteps = 50_000 9 | learning_rate = 0.017 10 | num_envs = 8 11 | num_workers = 2 12 | env_batch_size = 8 13 | batch_size = 1024 14 | minibatch_size = 128 15 | bptt_horizon = 4 16 | device = cpu 17 | 18 | [sweep.parameters.train.parameters.batch_size] 19 | distribution = uniform 20 | min = 512 21 | max = 2048 22 | 23 | [sweep.parameters.train.parameters.minibatch_size] 24 | distribution = uniform 25 | min = 64 26 | max = 512 27 | -------------------------------------------------------------------------------- /config/ocean/snake.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_snake 4 | vec = multiprocessing 5 | rnn_name = Recurrent 6 | 7 | [env] 8 | vision = 5 9 | 10 | [train] 11 | total_timesteps = 200_000_000 12 | num_envs = 2 13 | num_workers = 2 14 | env_batch_size = 1 15 | batch_size = 131072 16 | update_epochs = 1 17 | minibatch_size = 32768 18 | bptt_horizon = 16 19 | anneal_lr = False 20 | gae_lambda = 0.9776227170639571 21 | gamma = 0.8567482546637853 22 | clip_coef = 0.011102333784435113 23 | vf_coef = 0.3403069830175013 24 | vf_clip_coef = 0.26475190539131727 25 | max_grad_norm = 0.8660179376602173 26 | ent_coef = 0.01376980586465873 27 | learning_rate = 0.002064722899262613 28 | checkpoint_interval = 1000 29 | device = cuda 30 | 31 | [sweep.metric] 32 | goal = maximize 33 | name = environment/reward 34 | -------------------------------------------------------------------------------- /config/ocean/squared.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_squared 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [env] 8 | num_envs = 4096 9 | 10 | [train] 11 | total_timesteps = 20_000_000 12 | checkpoint_interval = 50 13 | num_envs = 1 14 | num_workers = 1 15 | env_batch_size = 1 16 | batch_size = 131072 17 | update_epochs = 1 18 | minibatch_size = 8192 19 | learning_rate = 0.017 20 | anneal_lr = False 21 | device = cuda 22 | -------------------------------------------------------------------------------- /config/ocean/tactical.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_tactical 4 | -------------------------------------------------------------------------------- /config/ocean/trash_pickup.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = trash_pickup puffer_trash_pickup 4 | vec = multiprocessing 5 | policy_name = TrashPickup 6 | rnn_name = Recurrent 7 | 8 | [env] 9 | num_envs = 1024 # Recommended: 4096 (recommended start value) / num_agents 10 | grid_size = 10 11 | num_agents = 4 12 | num_trash = 20 13 | num_bins = 1 14 | max_steps = 150 15 | report_interval = 32 16 | agent_sight_range = 5 # only used with 2D local crop obs space 17 | 18 | [train] 19 | total_timesteps = 100_000_000 20 | checkpoint_interval 
= 200 21 | num_envs = 2 22 | num_workers = 2 23 | env_batch_size = 1 24 | batch_size = 131072 25 | update_epochs = 1 26 | minibatch_size = 16384 27 | bptt_horizon = 8 28 | anneal_lr = False 29 | device = cuda 30 | learning_rate=0.001 31 | gamma = 0.95 32 | gae_lambda = 0.85 33 | vf_ceof = 0.4 34 | clip_coef = 0.1 35 | vf_clip_coef = 0.1 36 | ent_coef = 0.01 37 | 38 | [sweep.metric] 39 | goal = maximize 40 | name = environment/episode_return 41 | 42 | [sweep.parameters.train.parameters.learning_rate] 43 | distribution = log_uniform_values 44 | min = 0.000001 45 | max = 0.01 46 | 47 | [sweep.parameters.train.parameters.gamma] 48 | distribution = uniform 49 | min = 0 50 | max = 1 51 | 52 | [sweep.parameters.train.parameters.gae_lambda] 53 | distribution = uniform 54 | min = 0 55 | max = 1 56 | 57 | [sweep.parameters.train.parameters.update_epochs] 58 | distribution = int_uniform 59 | min = 1 60 | max = 4 61 | 62 | [sweep.parameters.train.parameters.ent_coef] 63 | distribution = log_uniform_values 64 | min = 1e-5 65 | max = 1e-1 66 | -------------------------------------------------------------------------------- /config/ocean/tripletriad.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = ocean 3 | env_name = puffer_tripletriad 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | 7 | [env] 8 | num_envs = 4096 9 | 10 | [train] 11 | total_timesteps = 200_000_000 12 | checkpoint_interval = 50 13 | num_envs = 1 14 | num_workers = 1 15 | env_batch_size = 1 16 | batch_size = 131072 17 | update_epochs = 2 18 | minibatch_size = 16384 19 | bptt_horizon = 16 20 | ent_coef = 0.0050619133743733105 21 | gae_lambda = 0.9440403722133228 22 | gamma = 0.9685297452478734 23 | learning_rate = 0.001092406907391121 24 | anneal_lr = False 25 | device = cuda 26 | 27 | [sweep.metric] 28 | goal = maximize 29 | name = environment/score 30 | 31 | [sweep.parameters.train.parameters.total_timesteps] 32 | distribution = uniform 33 | min = 100_000_000 34 | max = 500_000_000 35 | -------------------------------------------------------------------------------- /config/open_spiel.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = open_spiel 3 | env_name = connect_four 4 | 5 | [train] 6 | num_envs = 32 7 | batch_size = 4096 8 | -------------------------------------------------------------------------------- /config/pokemon_red.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = pokemon_red 3 | env_name = pokemon_red 4 | 5 | [train] 6 | total_timesteps = 1_000_000 7 | num_envs = 96 8 | num_workers = 24 9 | env_batch_size = 32 10 | zero_copy = False 11 | update_epochs = 3 12 | gamma = 0.998 13 | batch_size = 65536 14 | minibatch_size = 2048 15 | compile = True 16 | learning_rate = 2.0e-4 17 | anneal_lr = False 18 | -------------------------------------------------------------------------------- /config/procgen.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = procgen 3 | env_name = bigfish bossfight caveflyer chaser climber coinrun dodgeball fruitbot heist jumper leaper maze miner ninja plunder starpilot 4 | 5 | [train] 6 | total_timesteps = 25_000_000 7 | learning_rate = 0.0005 8 | num_workers = 24 9 | num_envs = 96 10 | env_batch_size = 48 11 | zero_copy = False 12 | batch_size = 16384 13 | minibatch_size = 2048 14 | gamma = 0.99 15 | gae_lambda = 0.95 16 | anneal_lr = False 17 | 
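The dotted [sweep.*] section names used throughout these files (for example [sweep.parameters.train.parameters.learning_rate] in config/ocean/trash_pickup.ini) encode a nested dictionary: method and name at the top level, a metric block, and per-hyperparameter distribution/min/max blocks under parameters.train.parameters. The snippet below is a hedged sketch of that expansion into a W&B-style sweep dict; it is not code from the repo, and values are left as strings for simplicity.

    # Sketch (assumed helper): expand dotted [sweep.*] sections into nested dicts.
    import configparser

    def sweep_config(path='config/ocean/trash_pickup.ini'):
        parser = configparser.ConfigParser()
        parser.read(['config/default.ini', path])  # per-env file overrides defaults
        sweep = {}
        for section in parser.sections():
            if section != 'sweep' and not section.startswith('sweep.'):
                continue
            node = sweep
            for part in section.split('.')[1:]:   # walk past the leading 'sweep'
                node = node.setdefault(part, {})
            for key, value in parser[section].items():
                node[key] = value                 # values stay strings in this sketch
        return sweep

    print(sweep_config()['parameters']['train']['parameters']['learning_rate'])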
-------------------------------------------------------------------------------- /config/slimevolley.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = slimevolley 3 | env_name = slimevolley 4 | 5 | [train] 6 | num_envs=1536 7 | num_workers=24 8 | env_batch_size=512 9 | zero_copy=False 10 | batch_size=65536 11 | minibatch_size=8192 12 | update_epochs=1 13 | -------------------------------------------------------------------------------- /config/stable_retro.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = stable_retro 3 | env_name = Airstriker-Genesis 4 | -------------------------------------------------------------------------------- /config/starcraft.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = smac 3 | -------------------------------------------------------------------------------- /config/trade_sim.ini: -------------------------------------------------------------------------------- 1 | [base] 2 | package = trade_sim 3 | env_name = trade_sim 4 | policy_name = Policy 5 | rnn_name = Recurrent 6 | vec = multiprocessing 7 | #vec = serial 8 | 9 | #[env] 10 | #num_envs = 128 11 | 12 | [train] 13 | total_timesteps = 100_000_000 14 | num_envs = 2048 15 | num_workers = 16 16 | env_batch_size = 1024 17 | batch_size = 262144 18 | gamma = 0.95 19 | 20 | [sweep] 21 | method = protein 22 | name = sweep 23 | 24 | [sweep.metric] 25 | goal = maximize 26 | name = score 27 | min = 0 28 | max = 864 29 | 30 | [sweep.train.total_timesteps] 31 | distribution = log_normal 32 | min = 2e7 33 | max = 1e8 34 | mean = 5e7 35 | scale = auto 36 | -------------------------------------------------------------------------------- /pufferlib/__init__.py: -------------------------------------------------------------------------------- 1 | from pufferlib import version 2 | __version__ = version.__version__ 3 | 4 | import os 5 | import sys 6 | 7 | # Silence noisy dependencies 8 | import warnings 9 | warnings.filterwarnings("ignore", category=DeprecationWarning) 10 | 11 | # Silence noisy packages 12 | original_stdout = sys.stdout 13 | original_stderr = sys.stderr 14 | sys.stdout = open(os.devnull, 'w') 15 | sys.stderr = open(os.devnull, 'w') 16 | try: 17 | import gymnasium 18 | import pygame 19 | except ImportError: 20 | pass 21 | sys.stdout.close() 22 | sys.stderr.close() 23 | sys.stdout = original_stdout 24 | sys.stderr = original_stderr 25 | 26 | from pufferlib.namespace import namespace, dataclass, Namespace 27 | from pufferlib import environments 28 | from pufferlib.environment import PufferEnv 29 | -------------------------------------------------------------------------------- /pufferlib/environments/__init__.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import pufferlib 3 | 4 | def try_import(module_path, package_name=None): 5 | if package_name is None: 6 | package_name = module_path 7 | try: 8 | package = __import__(module_path) 9 | except ImportError as e: 10 | raise ImportError( 11 | f'{e.args[0]}\n\n' 12 | 'This is probably an installation error. Try: ' 13 | f'pip install pufferlib[{package_name}]. ' 14 | 15 | 'Note that some environments have non-python dependencies. ' 16 | 'These are included in PufferTank. Or, you can install ' 17 | 'manually by following the instructions provided by the ' 18 | 'environment meaintainers. 
But some are finicky, so we ' 19 | 'recommend using PufferTank.' 20 | ) from e 21 | return package 22 | -------------------------------------------------------------------------------- /pufferlib/environments/atari/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/atari/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | 4 | class Recurrent(pufferlib.models.LSTMWrapper): 5 | def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1): 6 | super().__init__(env, policy, input_size, hidden_size, num_layers) 7 | 8 | class Policy(pufferlib.models.Convolutional): 9 | def __init__(self, env, input_size=512, hidden_size=512, output_size=512, 10 | framestack=1, flat_size=64*6*9): 11 | super().__init__( 12 | env=env, 13 | input_size=input_size, 14 | hidden_size=hidden_size, 15 | output_size=output_size, 16 | framestack=framestack, 17 | flat_size=flat_size, 18 | ) 19 | -------------------------------------------------------------------------------- /pufferlib/environments/box2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/box2d/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gymnasium 4 | import functools 5 | 6 | import pufferlib.emulation 7 | import pufferlib.environments 8 | import pufferlib.postprocess 9 | 10 | 11 | def env_creator(name='car-racing'): 12 | return functools.partial(make, name=name) 13 | 14 | def make(name, domain_randomize=True, continuous=False, render_mode='rgb_array', buf=None): 15 | if name == 'car-racing': 16 | name = 'CarRacing-v2' 17 | 18 | env = gymnasium.make(name, render_mode=render_mode, 19 | domain_randomize=domain_randomize, continuous=continuous) 20 | env = pufferlib.postprocess.EpisodeStats(env) 21 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 22 | -------------------------------------------------------------------------------- /pufferlib/environments/box2d/torch.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | 4 | import pufferlib.models 5 | 6 | class Recurrent(pufferlib.models.LSTMWrapper): 7 | def __init__(self, env, policy, 8 | input_size=128, hidden_size=128, num_layers=1): 9 | super().__init__(env, policy, 10 | input_size, hidden_size, num_layers) 11 | 12 | class Policy(pufferlib.models.Convolutional): 13 | def __init__(self, env, 14 | input_size=128, hidden_size=128, output_size=128, 15 | framestack=3, flat_size=64*8*8): 16 | super().__init__( 17 | env=env, 18 | input_size=input_size, 19 | hidden_size=hidden_size, 20 | output_size=output_size, 21 | framestack=framestack, 22 | flat_size=flat_size, 23 | 
channels_last=True, 24 | ) 25 | -------------------------------------------------------------------------------- /pufferlib/environments/bsuite/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/bsuite/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import gym 3 | import functools 4 | 5 | import pufferlib.emulation 6 | import pufferlib.wrappers 7 | 8 | import bsuite 9 | from bsuite.utils import gym_wrapper 10 | 11 | def env_creator(name='bandit/0'): 12 | return functools.partial(make, name) 13 | 14 | def make(name='bandit/0', results_dir='experiments/bsuite', overwrite=True, buf=None): 15 | '''BSuite environments''' 16 | bsuite = pufferlib.environments.try_import('bsuite') 17 | from bsuite.utils import gym_wrapper 18 | env = bsuite.load_and_record_to_csv(name, results_dir, overwrite=overwrite) 19 | env = gym_wrapper.GymFromDMEnv(env) 20 | env = BSuiteStopper(env) 21 | env = pufferlib.wrappers.GymToGymnasium(env) 22 | env = pufferlib.emulation.GymnasiumPufferEnv(env, buf=buf) 23 | return env 24 | 25 | class BSuiteStopper: 26 | def __init__(self, env): 27 | self.env = env 28 | self.num_episodes = 0 29 | 30 | self.step = self.env.step 31 | self.render = self.env.render 32 | self.close = self.env.close 33 | self.observation_space = self.env.observation_space 34 | self.action_space = self.env.action_space 35 | 36 | def reset(self): 37 | '''Forces the environment to stop after the 38 | number of episodes required by bsuite''' 39 | self.num_episodes += 1 40 | 41 | if self.num_episodes >= self.env.bsuite_num_episodes: 42 | exit(0) 43 | 44 | return self.env.reset() 45 | -------------------------------------------------------------------------------- /pufferlib/environments/bsuite/torch.py: -------------------------------------------------------------------------------- 1 | from pufferlib.models import Default as Policy 2 | -------------------------------------------------------------------------------- /pufferlib/environments/butterfly/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/butterfly/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | from pettingzoo.utils.conversions import aec_to_parallel_wrapper 3 | import functools 4 | 5 | import pufferlib.emulation 6 | import pufferlib.environments 7 | 8 | 9 | def env_creator(name='cooperative_pong_v5'): 10 | return functools.partial(make, name) 11 | 12 | def make(name, buf=None): 13 | pufferlib.environments.try_import('pettingzoo.butterfly', 'butterfly') 14 | if name == 'cooperative_pong_v5': 15 | from pettingzoo.butterfly import cooperative_pong_v5 as pong 16 | env_cls = pong.raw_env 17 | elif name == 
'knights_archers_zombies_v10': 18 | from pettingzoo.butterfly import knights_archers_zombies_v10 as kaz 19 | env_cls = kaz.raw_env 20 | else: 21 | raise ValueError(f'Unknown environment: {name}') 22 | 23 | env = env_cls() 24 | env = aec_to_parallel_wrapper(env) 25 | return pufferlib.emulation.PettingZooPufferEnv(env=env, buf=buf) 26 | -------------------------------------------------------------------------------- /pufferlib/environments/butterfly/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | 4 | class Policy(pufferlib.models.Convolutional): 5 | def __init__( 6 | self, 7 | env, 8 | flat_size=3520, 9 | channels_last=True, 10 | downsample=4, 11 | input_size=512, 12 | hidden_size=128, 13 | output_size=128, 14 | **kwargs 15 | ): 16 | super().__init__( 17 | env, 18 | framestack=3, 19 | flat_size=flat_size, 20 | channels_last=channels_last, 21 | downsample=downsample, 22 | input_size=input_size, 23 | hidden_size=hidden_size, 24 | output_size=output_size, 25 | **kwargs 26 | ) 27 | -------------------------------------------------------------------------------- /pufferlib/environments/classic_control/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/classic_control/environment.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | from gymnasium.envs import classic_control 3 | import functools 4 | import numpy as np 5 | 6 | import pufferlib 7 | import pufferlib.emulation 8 | import pufferlib.postprocess 9 | 10 | ALIASES = { 11 | 'cartpole': 'CartPole-v0', 12 | 'mountaincar': 'MountainCar-v0', 13 | } 14 | 15 | def env_creator(name='cartpole'): 16 | return functools.partial(make, name) 17 | 18 | def make(name, render_mode='rgb_array', buf=None): 19 | '''Create an environment by name''' 20 | 21 | if name in ALIASES: 22 | name = ALIASES[name] 23 | 24 | env = gymnasium.make(name, render_mode=render_mode) 25 | if name == 'MountainCar-v0': 26 | env = MountainCarWrapper(env) 27 | 28 | #env = gymnasium.wrappers.NormalizeObservation(env) 29 | env = gymnasium.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -1, 1)) 30 | #env = gymnasium.wrappers.NormalizeReward(env, gamma=gamma) 31 | env = gymnasium.wrappers.TransformReward(env, lambda reward: np.clip(reward, -1, 1)) 32 | env = pufferlib.postprocess.EpisodeStats(env) 33 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 34 | 35 | class MountainCarWrapper(gymnasium.Wrapper): 36 | def step(self, action): 37 | obs, reward, terminated, truncated, info = self.env.step(action) 38 | reward = abs(obs[0]+0.5) 39 | return obs, reward, terminated, truncated, info 40 | 41 | -------------------------------------------------------------------------------- /pufferlib/environments/classic_control/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | 4 | class Policy(pufferlib.models.Default): 5 | def __init__(self, env, hidden_size=64): 6 | super().__init__(env, hidden_size) 7 | -------------------------------------------------------------------------------- 
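A minimal usage sketch (not part of the repository) of the preprocessing applied in classic_control/environment.py above, written against plain Gymnasium. Only the MountainCar shaping term abs(x + 0.5) and the clip-to-[-1, 1] transforms come from the file above; the rollout call at the end is illustrative.

import gymnasium
import numpy as np

class ShapedMountainCar(gymnasium.Wrapper):
    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Reward distance from the valley bottom at x ~= -0.5, as in MountainCarWrapper
        reward = abs(obs[0] + 0.5)
        return obs, reward, terminated, truncated, info

env = gymnasium.make('MountainCar-v0')
env = ShapedMountainCar(env)
env = gymnasium.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -1, 1))
env = gymnasium.wrappers.TransformReward(env, lambda r: np.clip(r, -1, 1))

obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
--------------------------------------------------------------------------------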
/pufferlib/environments/classic_control_continuous/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/classic_control_continuous/environment.py: -------------------------------------------------------------------------------- 1 | import gymnasium 2 | import functools 3 | 4 | import pufferlib 5 | import pufferlib.emulation 6 | import pufferlib.postprocess 7 | 8 | 9 | def env_creator(name='MountainCarContinuous-v0'): 10 | return functools.partial(make, name) 11 | 12 | def make(name, render_mode='rgb_array', buf=None): 13 | '''Create an environment by name''' 14 | env = gymnasium.make(name, render_mode=render_mode) 15 | if name == 'MountainCarContinuous-v0': 16 | env = MountainCarWrapper(env) 17 | 18 | env = pufferlib.postprocess.ClipAction(env) 19 | env = pufferlib.postprocess.EpisodeStats(env) 20 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 21 | 22 | class MountainCarWrapper(gymnasium.Wrapper): 23 | def step(self, action): 24 | obs, reward, terminated, truncated, info = self.env.step(action) 25 | reward = abs(obs[0]+0.5) 26 | return obs, reward, terminated, truncated, info 27 | 28 | -------------------------------------------------------------------------------- /pufferlib/environments/classic_control_continuous/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | 4 | class Policy(pufferlib.models.Default): 5 | def __init__(self, env, hidden_size=64): 6 | super().__init__(env, hidden_size) 7 | -------------------------------------------------------------------------------- /pufferlib/environments/crafter/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/crafter/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gym 4 | import gymnasium 5 | import shimmy 6 | import functools 7 | 8 | import pufferlib 9 | import pufferlib.emulation 10 | import pufferlib.environments 11 | import pufferlib.postprocess 12 | import pufferlib.utils 13 | 14 | 15 | class TransposeObs(gym.Wrapper): 16 | def observation(self, observation): 17 | return observation.transpose(2, 0, 1) 18 | 19 | def env_creator(name='crafter'): 20 | return functools.partial(make, name) 21 | 22 | def make(name, buf=None): 23 | '''Crafter creation function''' 24 | if name == 'crafter': 25 | name = 'CrafterReward-v1' 26 | 27 | pufferlib.environments.try_import('crafter') 28 | env = gym.make(name) 29 | env.reset = pufferlib.utils.silence_warnings(env.reset) 30 | env = shimmy.GymV21CompatibilityV0(env=env) 31 | env = RenderWrapper(env) 32 | env = TransposeObs(env) 33 | env = pufferlib.postprocess.EpisodeStats(env) 34 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 35 | 36 | class 
RenderWrapper(gym.Wrapper): 37 | def __init__(self, env): 38 | super().__init__(env) 39 | self.env = env 40 | 41 | @property 42 | def render_mode(self): 43 | return 'rgb_array' 44 | 45 | def render(self, *args, **kwargs): 46 | return self.env.unwrapped.env.unwrapped.render((256,256)) 47 | -------------------------------------------------------------------------------- /pufferlib/environments/crafter/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | 4 | class Policy(pufferlib.models.Convolutional): 5 | def __init__( 6 | self, 7 | env, 8 | flat_size=1024, 9 | channels_last=True, 10 | downsample=1, 11 | input_size=512, 12 | hidden_size=128, 13 | output_size=128, 14 | **kwargs 15 | ): 16 | super().__init__( 17 | env, 18 | framestack=3, 19 | flat_size=flat_size, 20 | channels_last=channels_last, 21 | downsample=downsample, 22 | input_size=input_size, 23 | hidden_size=hidden_size, 24 | output_size=output_size, 25 | **kwargs 26 | ) 27 | -------------------------------------------------------------------------------- /pufferlib/environments/dm_control/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/dm_control/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gym 4 | import shimmy 5 | import functools 6 | 7 | import pufferlib 8 | import pufferlib.emulation 9 | import pufferlib.environments 10 | 11 | 12 | def env_creator(name='walker'): 13 | '''Deepmind Control environment creation function 14 | 15 | No support for bindings yet because PufferLib does 16 | not support continuous action spaces.''' 17 | return functools.partial(make, name) 18 | 19 | def make(name, task_name='walk', buf=None): 20 | '''Untested. 
Let us know in Discord if you want to use dmc in PufferLib.''' 21 | dm_control = pufferlib.environments.try_import('dm_control.suite', 'dmc') 22 | env = dm_control.suite.load(name, task_name) 23 | env = shimmy.DmControlCompatibilityV0(env=env) 24 | return pufferlib.emulation.GymnasiumPufferEnv(env, buf=buf) 25 | -------------------------------------------------------------------------------- /pufferlib/environments/dm_control/torch.py: -------------------------------------------------------------------------------- 1 | from pufferlib.models import Default as Policy 2 | -------------------------------------------------------------------------------- /pufferlib/environments/dm_lab/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/dm_lab/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gym 4 | import shimmy 5 | import functools 6 | 7 | import pufferlib 8 | import pufferlib.emulation 9 | import pufferlib.environments 10 | 11 | 12 | def env_creator(name='seekavoid_arena_01'): 13 | '''Deepmind Lab binding creation function 14 | dm-lab requires extensive setup. Use PufferTank.''' 15 | return functools.partial(make, name=name) 16 | 17 | def make(name, buf=None): 18 | '''Deepmind Lab binding creation function 19 | dm-lab requires extensive setup. Currently dropped from PufferTank. 20 | Let us know if you need this for your work.''' 21 | dm_lab = pufferlib.environments.try_import('deepmind_lab', 'dm-lab') 22 | env = dm_lab.Lab(name, ['RGB_INTERLEAVED']) 23 | env = shimmy.DmLabCompatibilityV0(env=env) 24 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 25 | -------------------------------------------------------------------------------- /pufferlib/environments/dm_lab/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | 4 | class Policy(pufferlib.models.Convolutional): 5 | def __init__( 6 | self, 7 | env, 8 | flat_size=3136, 9 | channels_last=True, 10 | downsample=1, 11 | input_size=512, 12 | hidden_size=128, 13 | output_size=128, 14 | **kwargs 15 | ): 16 | super().__init__( 17 | env, 18 | framestack=3, 19 | flat_size=flat_size, 20 | channels_last=channels_last, 21 | downsample=downsample, 22 | input_size=input_size, 23 | hidden_size=hidden_size, 24 | output_size=output_size, 25 | **kwargs 26 | ) 27 | -------------------------------------------------------------------------------- /pufferlib/environments/gpudrive/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/griddly/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7
| else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/griddly/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gym 4 | import shimmy 5 | import functools 6 | 7 | import pufferlib 8 | import pufferlib.emulation 9 | import pufferlib.environments 10 | import pufferlib.postprocess 11 | 12 | ALIASES = { 13 | 'spiders': 'GDY-Spiders-v0', 14 | } 15 | 16 | def env_creator(name='spiders'): 17 | return functools.partial(make, name) 18 | 19 | # TODO: fix griddly 20 | def make(name, buf=None): 21 | '''Griddly creation function 22 | 23 | Note that Griddly environments do not have observation spaces until 24 | they are created and reset''' 25 | if name in ALIASES: 26 | name = ALIASES[name] 27 | 28 | import warnings 29 | warnings.warn('Griddly has been segfaulting in the latest build and we do not know why. Submit a PR if you find a fix!') 30 | pufferlib.environments.try_import('griddly') 31 | with pufferlib.utils.Suppress(): 32 | env = gym.make(name) 33 | env.reset() # Populate observation space 34 | 35 | env = shimmy.GymV21CompatibilityV0(env=env) 36 | env = pufferlib.postprocess.EpisodeStats(env) 37 | return pufferlib.emulation.GymnasiumPufferEnv(env, buf=buf) 38 | -------------------------------------------------------------------------------- /pufferlib/environments/griddly/torch.py: -------------------------------------------------------------------------------- 1 | from pufferlib.models import Default as Policy 2 | -------------------------------------------------------------------------------- /pufferlib/environments/gvgai/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import numpy as np 3 | import functools 4 | 5 | import gym 6 | 7 | import pufferlib 8 | import pufferlib.emulation 9 | import pufferlib.environments 10 | import pufferlib.utils 11 | import pufferlib.postprocess 12 | import pufferlib.wrappers 13 | 14 | def env_creator(name='zelda'): 15 | if name == 'zelda': 16 | name = 'gvgai-zelda-lvl0-v0' 17 | return functools.partial(make, name) 18 | 19 | def make(name, obs_type='grayscale', frameskip=4, full_action_space=False, 20 | repeat_action_probability=0.0, render_mode='rgb_array', buf=None): 21 | '''Atari creation function''' 22 | pufferlib.environments.try_import('gym_gvgai') 23 | env = gym.make(name) 24 | env = pufferlib.wrappers.GymToGymnasium(env) 25 | env = pufferlib.postprocess.EpisodeStats(env) 26 | env = pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 27 | return env 28 | 29 | -------------------------------------------------------------------------------- /pufferlib/environments/links_awaken/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make_env 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/links_awaken/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gymnasium 4 | 5 | from links_awaken import 
LinksAwakenV1 as env_creator 6 | 7 | import pufferlib.emulation 8 | 9 | 10 | def make_env(headless: bool = True, state_path=None, buf=None): 11 | '''Links Awakening''' 12 | env = env_creator(headless=headless, state_path=state_path) 13 | env = gymnasium.wrappers.ResizeObservation(env, shape=(72, 80)) 14 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, 15 | postprocessor_cls=pufferlib.emulation.BasicPostprocessor, buf=buf) 16 | -------------------------------------------------------------------------------- /pufferlib/environments/links_awaken/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | from pufferlib.pytorch import LSTM 3 | 4 | 5 | class Recurrent: 6 | input_size = 512 7 | hidden_size = 512 8 | num_layers = 1 9 | 10 | class Policy(pufferlib.models.Convolutional): 11 | def __init__(self, env, input_size=512, hidden_size=512, output_size=512, 12 | framestack=3, flat_size=64*5*6): 13 | super().__init__( 14 | env=env, 15 | input_size=input_size, 16 | hidden_size=hidden_size, 17 | output_size=output_size, 18 | framestack=framestack, 19 | flat_size=flat_size, 20 | channels_last=True, 21 | ) 22 | -------------------------------------------------------------------------------- /pufferlib/environments/magent/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/magent/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | from pettingzoo.utils.conversions import aec_to_parallel_wrapper 3 | import functools 4 | 5 | import pufferlib.emulation 6 | import pufferlib.environments 7 | import pufferlib.wrappers 8 | 9 | 10 | def env_creator(name='battle_v4'): 11 | return functools.partial(make, name) 12 | pufferlib.environments.try_import('pettingzoo.magent', 'magent') 13 | 14 | def make(name, buf=None): 15 | '''MAgent Battle V4 creation function''' 16 | if name == 'battle_v4': 17 | from pettingzoo.magent import battle_v4 18 | env_cls = battle_v4.env 19 | else: 20 | raise ValueError(f'Unknown environment name {name}') 21 | 22 | env = env_cls() 23 | env = aec_to_parallel_wrapper(env) 24 | env = pufferlib.wrappers.PettingZooTruncatedWrapper(env) 25 | return pufferlib.emulation.PettingZooPufferEnv(env=env, buf=buf) 26 | -------------------------------------------------------------------------------- /pufferlib/environments/magent/torch.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | import pufferlib.models 4 | 5 | 6 | class Policy(pufferlib.models.Policy): 7 | '''Based off of the DQN policy in MAgent''' 8 | def __init__(self, env, hidden_size=256, output_size=256, kernel_num=32): 9 | '''The CleanRL default Atari policy: a stack of three convolutions followed by a linear layer 10 | 11 | Takes framestack as a mandatory keyword arguments. 
Suggested default is 1 frame 12 | with LSTM or 4 frames without.''' 13 | super().__init__(env) 14 | self.num_actions = self.action_space.n 15 | 16 | self.network = nn.Sequential( 17 | pufferlib.pytorch.layer_init(nn.Conv2d(5, kernel_num, 3)), 18 | nn.ReLU(), 19 | pufferlib.pytorch.layer_init(nn.Conv2d(kernel_num, kernel_num, 3)), 20 | nn.ReLU(), 21 | nn.Flatten(), 22 | pufferlib.pytorch.layer_init(nn.Linear(kernel_num*9*9, hidden_size)), 23 | nn.ReLU(), 24 | pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)), 25 | nn.ReLU(), 26 | ) 27 | 28 | self.actor = pufferlib.pytorch.layer_init(nn.Linear(output_size, self.num_actions), std=0.01) 29 | self.value_function = pufferlib.pytorch.layer_init(nn.Linear(output_size, 1), std=1) 30 | 31 | def critic(self, hidden): 32 | return self.value_function(hidden) 33 | 34 | def encode_observations(self, observations): 35 | observations = observations.permute(0, 3, 1, 2) 36 | return self.network(observations), None 37 | 38 | def decode_actions(self, hidden, lookup): 39 | action = self.actor(hidden) 40 | value = self.value_function(hidden) 41 | return action, value 42 | -------------------------------------------------------------------------------- /pufferlib/environments/microrts/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/microrts/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import numpy as np 3 | 4 | import warnings 5 | import shimmy 6 | import functools 7 | 8 | import pufferlib.emulation 9 | import pufferlib.environments 10 | 11 | 12 | def env_creator(name='GlobalAgentCombinedRewardEnv'): 13 | return functools.partial(make, name) 14 | 15 | def make(name, buf=None): 16 | '''Gym MicroRTS creation function 17 | 18 | This library appears broken. Step crashes in Java. 
19 | ''' 20 | pufferlib.environments.try_import('gym_microrts') 21 | if name == 'GlobalAgentCombinedRewardEnv': 22 | from gym_microrts.envs import GlobalAgentCombinedRewardEnv 23 | else: 24 | raise ValueError(f'Unknown environment: {name}') 25 | 26 | with pufferlib.utils.Suppress(): 27 | return GlobalAgentCombinedRewardEnv() 28 | 29 | env.reset = pufferlib.utils.silence_warnings(env.reset) 30 | env.step = pufferlib.utils.silence_warnings(env.step) 31 | 32 | env = MicroRTS(env) 33 | env = shimmy.GymV21CompatibilityV0(env=env) 34 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 35 | 36 | class MicroRTS: 37 | def __init__(self, env): 38 | self.env = env 39 | self.observation_space = self.env.observation_space 40 | self.action_space = self.env.action_space 41 | self.render = self.env.render 42 | self.close = self.env.close 43 | self.seed = self.env.seed 44 | 45 | def reset(self): 46 | return self.env.reset().astype(np.int32) 47 | 48 | def step(self, action): 49 | o, r, d, i = self.env.step(action) 50 | return o.astype(np.int32), r, d, i 51 | -------------------------------------------------------------------------------- /pufferlib/environments/microrts/torch.py: -------------------------------------------------------------------------------- 1 | from pufferlib.models import Default as Policy 2 | -------------------------------------------------------------------------------- /pufferlib/environments/minerl/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/minerl/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gym 4 | import shimmy 5 | import functools 6 | 7 | import pufferlib 8 | import pufferlib.emulation 9 | import pufferlib.environments 10 | import pufferlib.utils 11 | 12 | 13 | def env_creator(name='MineRLBasaltFindCave-v0'): 14 | return functools.partial(make, name=name) 15 | 16 | def make(name, buf=None): 17 | '''Minecraft environment creation function''' 18 | 19 | pufferlib.environments.try_import('minerl') 20 | 21 | # Monkey patch to add .itmes to old gym.spaces.Dict 22 | #gym.spaces.Dict.items = lambda self: self.spaces.items() 23 | 24 | #with pufferlib.utils.Suppress(): 25 | env = gym.make(name) 26 | 27 | env = shimmy.GymV21CompatibilityV0(env=env) 28 | return pufferlib.emulation.GymnasiumPufferEnv(env, buf=buf) 29 | -------------------------------------------------------------------------------- /pufferlib/environments/minerl/torch.py: -------------------------------------------------------------------------------- 1 | from pufferlib.models import Default as Policy 2 | -------------------------------------------------------------------------------- /pufferlib/environments/minigrid/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- 
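The try/except block that closes the __init__.py above is the standard pattern across this tree: torch-dependent Policy/Recurrent classes are exported only when torch imports cleanly, and Recurrent falls back to None. A short sketch (illustrative, not from the repository) of how downstream code can resolve whatever a package exports under this pattern:

import importlib

def load_policies(pkg_name='pufferlib.environments.minigrid'):
    pkg = importlib.import_module(pkg_name)
    policy_cls = getattr(pkg, 'Policy', None)        # defined only when torch is available
    recurrent_cls = getattr(pkg, 'Recurrent', None)  # may be None even with torch installed
    make_env = pkg.env_creator()                     # functools.partial(make, name='minigrid')
    return policy_cls, recurrent_cls, make_env
--------------------------------------------------------------------------------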
/pufferlib/environments/minigrid/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gymnasium 4 | import functools 5 | 6 | import pufferlib.emulation 7 | import pufferlib.environments 8 | import pufferlib.postprocess 9 | 10 | ALIASES = { 11 | 'minigrid': 'MiniGrid-LavaGapS7-v0', 12 | } 13 | 14 | 15 | def env_creator(name='minigrid'): 16 | return functools.partial(make, name=name) 17 | 18 | def make(name, render_mode='rgb_array', buf=None): 19 | if name in ALIASES: 20 | name = ALIASES[name] 21 | 22 | minigrid = pufferlib.environments.try_import('minigrid') 23 | env = gymnasium.make(name, render_mode=render_mode) 24 | env = MiniGridWrapper(env) 25 | env = pufferlib.postprocess.EpisodeStats(env) 26 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 27 | 28 | class MiniGridWrapper: 29 | def __init__(self, env): 30 | self.env = env 31 | self.observation_space = gymnasium.spaces.Dict({ 32 | k: v for k, v in self.env.observation_space.items() if 33 | k != 'mission' 34 | }) 35 | self.action_space = self.env.action_space 36 | self.close = self.env.close 37 | self.render = self.env.render 38 | self.close = self.env.close 39 | self.render_mode = 'rgb_array' 40 | 41 | def reset(self, seed=None, options=None): 42 | self.tick = 0 43 | obs, info = self.env.reset(seed=seed) 44 | del obs['mission'] 45 | return obs, info 46 | 47 | def step(self, action): 48 | obs, reward, done, truncated, info = self.env.step(action) 49 | del obs['mission'] 50 | 51 | self.tick += 1 52 | if self.tick == 100: 53 | done = True 54 | 55 | return obs, reward, done, truncated, info 56 | -------------------------------------------------------------------------------- /pufferlib/environments/minigrid/torch.py: -------------------------------------------------------------------------------- 1 | from pufferlib.models import Default as Policy 2 | -------------------------------------------------------------------------------- /pufferlib/environments/minihack/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/minihack/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gym 4 | import shimmy 5 | import functools 6 | 7 | import pufferlib 8 | import pufferlib.emulation 9 | import pufferlib.environments 10 | 11 | 12 | EXTRA_OBS_KEYS = [ 13 | 'tty_chars', 14 | 'tty_colors', 15 | 'tty_cursor', 16 | ] 17 | 18 | ALIASES = { 19 | 'minihack': 'MiniHack-River-v0', 20 | } 21 | 22 | def env_creator(name='minihack'): 23 | return functools.partial(make, name) 24 | 25 | def make(name, buf=None): 26 | '''NetHack binding creation function''' 27 | if name in ALIASES: 28 | name = ALIASES[name] 29 | 30 | import minihack 31 | pufferlib.environments.try_import('minihack') 32 | obs_key = minihack.base.MH_DEFAULT_OBS_KEYS + EXTRA_OBS_KEYS 33 | env = gym.make(name, observation_keys=obs_key) 34 | env = shimmy.GymV21CompatibilityV0(env=env) 35 | env = MinihackWrapper(env) 36 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 37 | 38 | class MinihackWrapper: 39 | def 
__init__(self, env): 40 | self.env = env 41 | self.observation_space = self.env.observation_space 42 | self.action_space = self.env.action_space 43 | self.close = self.env.close 44 | self.close = self.env.close 45 | self.render_mode = 'ansi' 46 | 47 | def reset(self, seed=None): 48 | obs, info = self.env.reset(seed=seed) 49 | self.obs = obs 50 | return obs, info 51 | 52 | def step(self, action): 53 | obs, reward, done, truncated, info = self.env.step(action) 54 | self.obs = obs 55 | return obs, reward, done, truncated, info 56 | 57 | def render(self): 58 | import nle 59 | chars = nle.nethack.tty_render( 60 | self.obs['tty_chars'], self.obs['tty_colors'], self.obs['tty_cursor']) 61 | return chars 62 | 63 | -------------------------------------------------------------------------------- /pufferlib/environments/minihack/torch.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import pufferlib.pytorch 4 | from pufferlib.environments.nethack import Policy 5 | 6 | class Recurrent(pufferlib.models.LSTMWrapper): 7 | def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1): 8 | super().__init__(env, policy, input_size, hidden_size, num_layers) 9 | -------------------------------------------------------------------------------- /pufferlib/environments/mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator 2 | 3 | try: 4 | # NOTE: demo.py looks the policy class from the torch module 5 | import pufferlib.environments.mujoco.policy as torch 6 | except ImportError: 7 | pass 8 | else: 9 | from .policy import Policy 10 | try: 11 | from .policy import Recurrent 12 | except: 13 | Recurrent = None -------------------------------------------------------------------------------- /pufferlib/environments/mujoco/environment.py: -------------------------------------------------------------------------------- 1 | 2 | from pdb import set_trace as T 3 | 4 | import functools 5 | 6 | import numpy as np 7 | import gymnasium 8 | 9 | import pufferlib 10 | import pufferlib.emulation 11 | import pufferlib.environments 12 | import pufferlib.postprocess 13 | 14 | 15 | def single_env_creator(env_name, capture_video, gamma, run_name=None, idx=None, obs_norm=True, pufferl=False, buf=None): 16 | if capture_video and idx == 0: 17 | assert run_name is not None, "run_name must be specified when capturing videos" 18 | env = gymnasium.make(env_name, render_mode="rgb_array") 19 | env = gymnasium.wrappers.RecordVideo(env, f"videos/{run_name}") 20 | else: 21 | env = gymnasium.make(env_name) 22 | 23 | env = pufferlib.postprocess.ClipAction(env) # NOTE: this changed actions space 24 | env = pufferlib.postprocess.EpisodeStats(env) 25 | 26 | if obs_norm: 27 | env = gymnasium.wrappers.NormalizeObservation(env) 28 | env = gymnasium.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) 29 | 30 | env = gymnasium.wrappers.NormalizeReward(env, gamma=gamma) 31 | env = gymnasium.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) 32 | 33 | if pufferl is True: 34 | env = pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 35 | 36 | return env 37 | 38 | 39 | def cleanrl_env_creator(env_name, run_name, capture_video, gamma, idx): 40 | kwargs = { 41 | "env_name": env_name, 42 | "run_name": run_name, 43 | "capture_video": capture_video, 44 | "gamma": gamma, 45 | "idx": idx, 46 | "pufferl": False, 47 | } 48 | return 
functools.partial(single_env_creator, **kwargs) 49 | 50 | 51 | # Keep it simple for pufferl demo, for now 52 | def env_creator(env_name="HalfCheetah-v4", gamma=0.99): 53 | default_kwargs = { 54 | "env_name": env_name, 55 | "capture_video": False, 56 | "gamma": gamma, 57 | "pufferl": True, 58 | } 59 | return functools.partial(single_env_creator, **default_kwargs) 60 | -------------------------------------------------------------------------------- /pufferlib/environments/nethack/Hack-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/environments/nethack/Hack-Regular.ttf -------------------------------------------------------------------------------- /pufferlib/environments/nethack/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/nethack/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import shimmy 4 | import gym 5 | import functools 6 | 7 | import pufferlib 8 | import pufferlib.emulation 9 | import pufferlib.environments 10 | import pufferlib.postprocess 11 | #from .wrapper import RenderCharImagesWithNumpyWrapper 12 | 13 | def env_creator(name='nethack'): 14 | return functools.partial(make, name) 15 | 16 | def make(name, buf=None): 17 | '''NetHack binding creation function''' 18 | if name == 'nethack': 19 | name = 'NetHackScore-v0' 20 | 21 | nle = pufferlib.environments.try_import('nle') 22 | env = gym.make(name) 23 | #env = RenderCharImagesWithNumpyWrapper(env) 24 | env = shimmy.GymV21CompatibilityV0(env=env) 25 | env = NethackWrapper(env) 26 | env = pufferlib.postprocess.EpisodeStats(env) 27 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 28 | 29 | class NethackWrapper: 30 | def __init__(self, env): 31 | self.env = env 32 | self.observation_space = self.env.observation_space 33 | self.action_space = self.env.action_space 34 | self.close = self.env.close 35 | self.close = self.env.close 36 | self.render_mode = 'ansi' 37 | 38 | def reset(self, seed=None): 39 | obs, info = self.env.reset(seed=seed) 40 | self.obs = obs 41 | return obs, info 42 | 43 | def step(self, action): 44 | obs, reward, done, truncated, info = self.env.step(action) 45 | self.obs = obs 46 | return obs, reward, done, truncated, info 47 | 48 | def render(self): 49 | import nle 50 | chars = nle.nethack.tty_render( 51 | self.obs['tty_chars'], self.obs['tty_colors'], self.obs['tty_cursor']) 52 | return chars 53 | -------------------------------------------------------------------------------- /pufferlib/environments/nethack/torch.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import pufferlib.models 8 | import pufferlib.pytorch 9 | from pufferlib.pytorch import layer_init 10 | 11 | 12 | class Recurrent(pufferlib.models.LSTMWrapper): 13 | def __init__(self, env, policy, input_size=256, hidden_size=256, num_layers=1): 14 | 
super().__init__(env, policy, input_size, hidden_size, num_layers) 15 | 16 | class Policy(nn.Module): 17 | def __init__(self, env): 18 | super().__init__() 19 | self.dtype = pufferlib.pytorch.nativize_dtype(env.emulated) 20 | 21 | self.blstats_net = nn.Sequential( 22 | nn.Embedding(256, 32), 23 | nn.Flatten(), 24 | ) 25 | 26 | self.char_embed = nn.Embedding(256, 32) 27 | self.chars_net = nn.Sequential( 28 | layer_init(nn.Conv2d(32, 32, 5, stride=(2, 3))), 29 | nn.ReLU(), 30 | layer_init(nn.Conv2d(32, 64, 5, stride=(1, 3))), 31 | nn.ReLU(), 32 | layer_init(nn.Conv2d(64, 64, 3, stride=1)), 33 | nn.ReLU(), 34 | nn.Flatten(), 35 | ) 36 | 37 | self.proj = nn.Linear(864+960, 256) 38 | self.actor = layer_init(nn.Linear(256, 8), std=0.01) 39 | self.critic = layer_init(nn.Linear(256, 1), std=1) 40 | 41 | def forward(self, x): 42 | hidden = self.encode_observations(x) 43 | actions, value = self.decode_actions(hidden, None) 44 | return actions, value 45 | 46 | def encode_observations(self, x): 47 | x = x.type(torch.uint8) # Undo bad cleanrl cast 48 | x = pufferlib.pytorch.nativize_tensor(x, self.dtype) 49 | 50 | blstats = torch.clip(x['blstats'] + 1, 0, 255).int() 51 | blstats = self.blstats_net(blstats) 52 | 53 | chars = self.char_embed(x['chars'].int()) 54 | chars = torch.permute(chars, (0, 3, 1, 2)) 55 | chars = self.chars_net(chars) 56 | 57 | concat = torch.cat([blstats, chars], dim=1) 58 | return self.proj(concat) 59 | 60 | def decode_actions(self, hidden, lookup, concat=None): 61 | value = self.critic(hidden) 62 | action = self.actor(hidden) 63 | return action, value 64 | -------------------------------------------------------------------------------- /pufferlib/environments/nmmo/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/nmmo/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import numpy as np 3 | import functools 4 | 5 | import pufferlib 6 | import pufferlib.emulation 7 | import pufferlib.environments 8 | import pufferlib.wrappers 9 | import pufferlib.postprocess 10 | 11 | 12 | def env_creator(name='nmmo'): 13 | return functools.partial(make, name) 14 | 15 | def make(name, *args, buf=None, **kwargs): 16 | '''Neural MMO creation function''' 17 | nmmo = pufferlib.environments.try_import('nmmo') 18 | env = nmmo.Env(*args, **kwargs) 19 | env = NMMOWrapper(env) 20 | env = pufferlib.postprocess.MultiagentEpisodeStats(env) 21 | env = pufferlib.postprocess.MeanOverAgents(env) 22 | return pufferlib.emulation.PettingZooPufferEnv(env=env, buf=buf) 23 | 24 | class NMMOWrapper(pufferlib.postprocess.PettingZooWrapper): 25 | '''Remove task spam''' 26 | @property 27 | def render_mode(self): 28 | return 'rgb_array' 29 | 30 | def render(self): 31 | '''Quick little renderer for NMMO''' 32 | tiles = self.env.tile_map[:, :, 2].astype(np.uint8) 33 | render = np.zeros((tiles.shape[0], tiles.shape[1], 3), dtype=np.uint8) 34 | BROWN = (136, 69, 19) 35 | render[tiles == 1] = (0, 0, 255) 36 | render[tiles == 2] = (0, 255, 0) 37 | render[tiles == 3] = BROWN 38 | render[tiles == 4] = (64, 255, 64) 39 | render[tiles == 5] = (128, 128, 128) 40 | 
render[tiles == 6] = BROWN 41 | render[tiles == 7] = (255, 128, 128) 42 | render[tiles == 8] = BROWN 43 | render[tiles == 9] = (128, 255, 128) 44 | render[tiles == 10] = BROWN 45 | render[tiles == 11] = (128, 128, 255) 46 | render[tiles == 12] = BROWN 47 | render[tiles == 13] = (192, 255, 192) 48 | render[tiles == 14] = (0, 0, 255) 49 | render[tiles == 15] = (64, 64, 255) 50 | 51 | for agent in self.env.realm.players.values(): 52 | agent_r = agent.row.val 53 | agent_c = agent.col.val 54 | render[agent_r, agent_c, :] = (255, 255, 0) 55 | 56 | for npc in self.env.realm.npcs.values(): 57 | agent_r = npc.row.val 58 | agent_c = npc.col.val 59 | render[agent_r, agent_c, :] = (255, 0, 0) 60 | 61 | return render 62 | 63 | def reset(self, seed=None): 64 | obs, infos = self.env.reset(seed=seed) 65 | self.obs = obs 66 | return obs, infos 67 | 68 | def step(self, actions): 69 | obs, rewards, dones, truncateds, infos = self.env.step(actions) 70 | infos = {k: list(v['task'].values())[0] for k, v in infos.items()} 71 | self.obs = obs 72 | return obs, rewards, dones, truncateds, infos 73 | 74 | def close(self): 75 | return self.env.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /pufferlib/environments/open_spiel/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/open_spiel/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import numpy as np 3 | import functools 4 | 5 | import pufferlib 6 | from pufferlib import namespace 7 | import pufferlib.emulation 8 | import pufferlib.environments 9 | 10 | 11 | def env_creator(name='connect_four'): 12 | '''OpenSpiel creation function''' 13 | return functools.partial(make, name) 14 | 15 | def make( 16 | name, 17 | multiplayer=False, 18 | n_rollouts=5, 19 | max_simulations=10, 20 | min_simulations=None, 21 | buf=None 22 | ): 23 | '''OpenSpiel creation function''' 24 | pyspiel = pufferlib.environments.try_import('pyspiel', 'open_spiel') 25 | env = pyspiel.load_game(name) 26 | 27 | if min_simulations is None: 28 | min_simulations = max_simulations 29 | 30 | from pufferlib.environments.open_spiel.gymnasium_environment import ( 31 | OpenSpielGymnasiumEnvironment 32 | ) 33 | from pufferlib.environments.open_spiel.pettingzoo_environment import ( 34 | OpenSpielPettingZooEnvironment 35 | ) 36 | 37 | kwargs = dict( 38 | env=env, 39 | n_rollouts=int(n_rollouts), 40 | min_simulations=int(min_simulations), 41 | max_simulations=int(max_simulations), 42 | ) 43 | 44 | if multiplayer: 45 | env = OpenSpielPettingZooEnvironment(**kwargs) 46 | wrapper_cls = pufferlib.emulation.PettingZooPufferEnv 47 | else: 48 | env = OpenSpielGymnasiumEnvironment(**kwargs) 49 | wrapper_cls = pufferlib.emulation.GymnasiumPufferEnv 50 | 51 | return wrapper_cls( 52 | env=env, 53 | postprocessor_cls=pufferlib.emulation.BasicPostprocessor, 54 | buf=buf, 55 | ) 56 | 57 | -------------------------------------------------------------------------------- /pufferlib/environments/open_spiel/gymnasium_environment.py: -------------------------------------------------------------------------------- 1 | 
from pdb import set_trace as T 2 | import numpy as np 3 | 4 | from open_spiel.python.algorithms import mcts 5 | 6 | import pufferlib 7 | from pufferlib import namespace 8 | from pufferlib.environments.open_spiel.utils import ( 9 | solve_chance_nodes, 10 | get_obs_and_infos, 11 | observation_space, 12 | action_space, 13 | init, 14 | render, 15 | close, 16 | ) 17 | 18 | 19 | def create_bots(state, seed): 20 | assert seed is not None, 'seed must be set' 21 | rnd_state = np.random.RandomState(seed) 22 | 23 | evaluator = mcts.RandomRolloutEvaluator( 24 | n_rollouts=state.n_rollouts, 25 | random_state=rnd_state 26 | ) 27 | 28 | return [mcts.MCTSBot( 29 | game=state.env, 30 | uct_c=2, 31 | max_simulations=a, 32 | evaluator=evaluator, 33 | random_state=rnd_state, 34 | child_selection_fn=mcts.SearchNode.puct_value, 35 | solve=True, 36 | ) for a in range(state.min_simulations, state.max_simulations + 1)] 37 | 38 | def reset(state, seed = None, options = None): 39 | state.state = state.env.new_initial_state() 40 | 41 | if not state.has_reset: 42 | state.has_reset = True 43 | state.seed_value = seed 44 | np.random.seed(seed) 45 | state.all_bots = create_bots(state, seed) 46 | 47 | state.bot = np.random.choice(state.all_bots) 48 | 49 | if np.random.rand() < 0.5: 50 | bot_atn = state.bot.step(state.state) 51 | state.state.apply_action(bot_atn) 52 | 53 | obs, infos = get_obs_and_infos(state) 54 | player = state.state.current_player() 55 | return obs[player], infos[player] 56 | 57 | def step(state, action): 58 | player = state.state.current_player() 59 | solve_chance_nodes(state) 60 | state.state.apply_action(action) 61 | 62 | # Take other move with a bot 63 | if not state.state.is_terminal(): 64 | bot_atn = state.bot.step(state.state) 65 | solve_chance_nodes(state) 66 | state.state.apply_action(bot_atn) 67 | 68 | # Now that we have applied all actions, get the next obs. 69 | obs, all_infos = get_obs_and_infos(state) 70 | reward = state.state.returns()[player] 71 | info = all_infos[player] 72 | 73 | # Are we done? 
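# The terminal handling below also records an opponent-strength stat: when the
# game ends, info gains a key that encodes the MCTS bot's simulation budget
# (e.g. 'win_mcts_10'), set to 1 if this agent's return was +1 and 0 otherwise.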
74 | terminated = state.state.is_terminal() 75 | if terminated: 76 | key = f'win_mcts_{state.bot.max_simulations}' 77 | info[key] = int(reward==1) 78 | 79 | return obs[player], reward, terminated, False, info 80 | 81 | class OpenSpielGymnasiumEnvironment: 82 | __init__ = init 83 | step = step 84 | reset = reset 85 | observation_space = property(observation_space) 86 | action_space = property(action_space) 87 | render = render 88 | close = close 89 | -------------------------------------------------------------------------------- /pufferlib/environments/open_spiel/pettingzoo_environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import numpy as np 3 | 4 | import pufferlib 5 | from pufferlib import namespace 6 | 7 | from pufferlib.environments.open_spiel.utils import ( 8 | solve_chance_nodes, 9 | get_obs_and_infos, 10 | observation_space, 11 | action_space, 12 | init, 13 | render, 14 | close, 15 | ) 16 | 17 | def agents(state): 18 | return state.agents 19 | 20 | def possible_agents(state): 21 | return list(range(state.env.num_players())) 22 | 23 | def pz_observation_space(state, agent): 24 | return observation_space(state) 25 | 26 | def pz_action_space(state, agent): 27 | return action_space(state) 28 | 29 | def reset(state, seed = None, options = None): 30 | state.state = state.env.new_initial_state() 31 | obs, infos = get_obs_and_infos(state) 32 | state.agents = state.possible_agents 33 | 34 | if not state.has_reset: 35 | state.has_reset = True 36 | state.seed_value = seed 37 | np.random.seed(seed) 38 | 39 | return obs, infos 40 | 41 | def step(state, actions): 42 | curr_player = state.state.current_player() 43 | solve_chance_nodes(state) 44 | state.state.apply_action(actions[curr_player]) 45 | obs, infos = get_obs_and_infos(state) 46 | rewards = {ag: r for ag, r in enumerate(state.state.returns())} 47 | 48 | # Are we done? 49 | is_terminated = state.state.is_terminal() 50 | terminateds = {a: False for a in obs} 51 | truncateds = {a: False for a in obs} 52 | 53 | if is_terminated: 54 | terminateds = {a: True for a in state.possible_agents} 55 | state.agents = [] 56 | 57 | return obs, rewards, terminateds, truncateds, infos 58 | 59 | class OpenSpielPettingZooEnvironment: 60 | __init__ = init 61 | step = step 62 | reset = reset 63 | agents = lambda state: state.agents 64 | possible_agents = property(possible_agents) 65 | observation_space = pz_observation_space 66 | action_space = pz_action_space 67 | render = render 68 | close = close 69 | -------------------------------------------------------------------------------- /pufferlib/environments/open_spiel/torch.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import numpy as np 3 | 4 | import torch 5 | from torch import nn 6 | 7 | import pufferlib.emulation 8 | from pufferlib.models import Policy as Base 9 | 10 | class Policy(Base): 11 | def __init__(self, env, input_size=128, hidden_size=128): 12 | '''Default PyTorch policy, meant for debugging. 13 | This should run with any environment but is unlikely to learn anything. 14 | 15 | Uses a single linear layer + relu to encode observations and a list of 16 | linear layers to decode actions. The value function is a single linear layer. 
17 | ''' 18 | super().__init__(env) 19 | 20 | self.flat_observation_space = env.flat_observation_space 21 | self.flat_observation_structure = env.flat_observation_structure 22 | 23 | self.encoder = nn.Linear(np.prod( 24 | env.structured_observation_space['obs'].shape), hidden_size) 25 | self.decoder = nn.Linear(hidden_size, self.action_space.n) 26 | 27 | self.value_head = nn.Linear(hidden_size, 1) 28 | 29 | def encode_observations(self, observations): 30 | '''Linear encoder function''' 31 | observations = pufferlib.emulation.unpack_batched_obs(observations, 32 | self.flat_observation_space, self.flat_observation_structure) 33 | obs = observations['obs'].view(observations['obs'].shape[0], -1) 34 | self.action_mask = observations['action_mask'] 35 | 36 | hidden = torch.relu(self.encoder(obs)) 37 | return hidden, None 38 | 39 | def decode_actions(self, hidden, lookup, concat=True): 40 | '''Concatenated linear decoder function''' 41 | value = self.value_head(hidden) 42 | action = self.decoder(hidden) 43 | action = action.masked_fill(self.action_mask == 0, -1e9) 44 | return action, value -------------------------------------------------------------------------------- /pufferlib/environments/pokemon_red/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/pokemon_red/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import gymnasium 4 | import functools 5 | 6 | from pokegym import Environment 7 | 8 | import pufferlib.emulation 9 | import pufferlib.postprocess 10 | 11 | 12 | def env_creator(name='pokemon_red'): 13 | return functools.partial(make, name) 14 | 15 | def make(name, headless: bool = True, state_path=None, buf=None): 16 | '''Pokemon Red''' 17 | env = Environment(headless=headless, state_path=state_path) 18 | env = RenderWrapper(env) 19 | env = pufferlib.postprocess.EpisodeStats(env) 20 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 21 | 22 | class RenderWrapper(gymnasium.Wrapper): 23 | def __init__(self, env): 24 | self.env = env 25 | 26 | @property 27 | def render_mode(self): 28 | return 'rgb_array' 29 | 30 | def render(self): 31 | return self.env.screen.screen_ndarray() 32 | -------------------------------------------------------------------------------- /pufferlib/environments/pokemon_red/torch.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | 4 | import pufferlib.models 5 | 6 | 7 | class Recurrent(pufferlib.models.LSTMWrapper): 8 | def __init__(self, env, policy, 9 | input_size=512, hidden_size=512, num_layers=1): 10 | super().__init__(env, policy, 11 | input_size, hidden_size, num_layers) 12 | 13 | class Policy(pufferlib.models.Convolutional): 14 | def __init__(self, env, 15 | input_size=512, hidden_size=512, output_size=512, 16 | framestack=4, flat_size=64*5*6): 17 | super().__init__( 18 | env=env, 19 | input_size=input_size, 20 | hidden_size=hidden_size, 21 | output_size=output_size, 22 | framestack=framestack, 23 | flat_size=flat_size, 24 | channels_last=True, 25 | ) 26 | 27 | 28 | ''' 29 | class 
Policy(pufferlib.models.ProcgenResnet): 30 | def __init__(self, env, cnn_width=16, mlp_width=512): 31 | super().__init__( 32 | env=env, 33 | cnn_width=cnn_width, 34 | mlp_width=mlp_width, 35 | ) 36 | ''' 37 | -------------------------------------------------------------------------------- /pufferlib/environments/procgen/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/procgen/torch.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | from torch import nn 3 | import pufferlib.models 4 | 5 | # This policy ended up being useful broadly 6 | # so I included it in the defaults 7 | 8 | class Recurrent(pufferlib.models.LSTMWrapper): 9 | def __init__(self, env, policy, input_size=256, hidden_size=256, num_layers=1): 10 | super().__init__(env, policy, input_size, hidden_size, num_layers) 11 | 12 | class Policy (nn.Module): 13 | def __init__(self, env, *args, input_size=256, hidden_size=256, 14 | output_size=256, **kwargs): 15 | '''The CleanRL default NatureCNN policy used for Atari. 16 | It's just a stack of three convolutions followed by a linear layer 17 | 18 | Takes framestack as a mandatory keyword argument. Suggested default is 1 frame 19 | with LSTM or 4 frames without.''' 20 | super().__init__() 21 | 22 | self.network= nn.Sequential( 23 | pufferlib.pytorch.layer_init(nn.Conv2d(3, 16, 8, stride=4)), 24 | nn.ReLU(), 25 | pufferlib.pytorch.layer_init(nn.Conv2d(16, 32, 4, stride=2)), 26 | nn.ReLU(), 27 | nn.Flatten(), 28 | pufferlib.pytorch.layer_init(nn.Linear(1152, hidden_size)), 29 | nn.ReLU(), 30 | ) 31 | self.actor = pufferlib.pytorch.layer_init( 32 | nn.Linear(hidden_size, env.single_action_space.n), std=0.01) 33 | self.value_fn = pufferlib.pytorch.layer_init( 34 | nn.Linear(output_size, 1), std=1) 35 | 36 | def forward(self, observations): 37 | hidden, lookup = self.encode_observations(observations) 38 | actions, value = self.decode_actions(hidden, lookup) 39 | return actions, value 40 | 41 | def encode_observations(self, observations): 42 | observations = observations.permute(0, 3, 1, 2) 43 | return self.network(observations.float() / 255.0), None 44 | 45 | def decode_actions(self, flat_hidden, lookup, concat=None): 46 | action = self.actor(flat_hidden) 47 | value = self.value_fn(flat_hidden) 48 | return action, value 49 | 50 | Policy = pufferlib.models.ProcgenResnet 51 | -------------------------------------------------------------------------------- /pufferlib/environments/slimevolley/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/slimevolley/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import numpy as np 3 | import functools 4 | 5 | import gym 6 | import shimmy 7 | 8 | import 
pufferlib 9 | import pufferlib.emulation 10 | import pufferlib.environments 11 | import pufferlib.utils 12 | import pufferlib.postprocess 13 | 14 | 15 | def env_creator(name='SlimeVolley-v0'): 16 | return functools.partial(make, name) 17 | 18 | def make(name, render_mode='rgb_array', buf=None): 19 | if name == 'slimevolley': 20 | name = 'SlimeVolley-v0' 21 | 22 | from slimevolleygym import SlimeVolleyEnv 23 | SlimeVolleyEnv.atari_mode = True 24 | env = SlimeVolleyEnv() 25 | env.policy.predict = lambda obs: np.random.randint(0, 2, 3) 26 | env = SlimeVolleyMultiDiscrete(env) 27 | env = SkipWrapper(env, repeat_count=4) 28 | env = shimmy.GymV21CompatibilityV0(env=env) 29 | env = pufferlib.postprocess.EpisodeStats(env) 30 | return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf) 31 | 32 | class SlimeVolleyMultiDiscrete(gym.Wrapper): 33 | def __init__(self, env): 34 | super().__init__(env) 35 | #self.action_space = gym.spaces.MultiDiscrete( 36 | # [2 for _ in range(env.action_space.n)]) 37 | 38 | def reset(self, seed=None): 39 | return self.env.reset().astype(np.float32) 40 | 41 | def step(self, action): 42 | obs, reward, done, info = self.env.step(action) 43 | return obs.astype(np.float32), reward, done, info 44 | 45 | class SkipWrapper(gym.Wrapper): 46 | """ 47 | Generic common frame skipping wrapper 48 | Will perform action for `x` additional steps 49 | """ 50 | def __init__(self, env, repeat_count): 51 | super(SkipWrapper, self).__init__(env) 52 | self.repeat_count = repeat_count 53 | self.stepcount = 0 54 | 55 | def step(self, action): 56 | done = False 57 | total_reward = 0 58 | current_step = 0 59 | while current_step < (self.repeat_count + 1) and not done: 60 | self.stepcount += 1 61 | obs, reward, done, info = self.env.step(action) 62 | total_reward += reward 63 | current_step += 1 64 | 65 | return obs, total_reward, done, info 66 | 67 | def reset(self): 68 | self.stepcount = 0 69 | return self.env.reset() 70 | 71 | -------------------------------------------------------------------------------- /pufferlib/environments/slimevolley/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | Recurrent = pufferlib.models.LSTMWrapper 4 | Policy = pufferlib.models.Default 5 | -------------------------------------------------------------------------------- /pufferlib/environments/smac/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/smac/environment.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import pufferlib 4 | import pufferlib.emulation 5 | import pufferlib.environments 6 | import pufferlib.wrappers 7 | 8 | 9 | def env_creator(name='smac'): 10 | return functools.partial(make, name) 11 | 12 | def make(name, buf=None): 13 | '''Starcraft Multiagent Challenge creation function 14 | 15 | Support for SMAC is WIP because environments do not function without 16 | an action-masked baseline policy.''' 17 | pufferlib.environments.try_import('smac') 18 | from smac.env.pettingzoo.StarCraft2PZEnv import _parallel_env as smac_env 19 | 20 | env = smac_env(1000) 21 | env = 
pufferlib.wrappers.PettingZooTruncatedWrapper(env) 22 | env = pufferlib.emulation.PettingZooPufferEnv(env, buf=buf) 23 | return env 24 | -------------------------------------------------------------------------------- /pufferlib/environments/smac/torch.py: -------------------------------------------------------------------------------- 1 | from pufferlib.models import Default as Policy 2 | -------------------------------------------------------------------------------- /pufferlib/environments/stable_retro/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/stable_retro/environment.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import numpy as np 3 | 4 | import gymnasium as gym 5 | import functools 6 | 7 | import pufferlib 8 | import pufferlib.emulation 9 | import pufferlib.environments 10 | 11 | 12 | def env_creator(name='Airstriker-Genesis'): 13 | return functools.partial(make, name) 14 | 15 | def make(name='Airstriker-Genesis', framestack=4, buf=None): 16 | '''Atari creation function with default CleanRL preprocessing based on Stable Baselines3 wrappers''' 17 | retro = pufferlib.environments.try_import('retro', 'stable-retro') 18 | 19 | from stable_baselines3.common.atari_wrappers import ( 20 | ClipRewardEnv, 21 | EpisodicLifeEnv, 22 | FireResetEnv, 23 | MaxAndSkipEnv, 24 | ) 25 | with pufferlib.utils.Suppress(): 26 | env = retro.make(name) 27 | 28 | env = gym.wrappers.RecordEpisodeStatistics(env) 29 | env = MaxAndSkipEnv(env, skip=4) 30 | env = ClipRewardEnv(env) 31 | env = gym.wrappers.ResizeObservation(env, (84, 84)) 32 | env = gym.wrappers.GrayScaleObservation(env) 33 | env = gym.wrappers.FrameStack(env, framestack) 34 | return pufferlib.emulation.GymnasiumPufferEnv( 35 | env=env, postprocessor_cls=AtariFeaturizer, buf=buf) 36 | 37 | class AtariFeaturizer(pufferlib.emulation.Postprocessor): 38 | def reset(self, obs): 39 | self.epoch_return = 0 40 | self.epoch_length = 0 41 | self.done = False 42 | 43 | #@property 44 | #def observation_space(self): 45 | # return gym.spaces.Box(0, 255, (1, 84, 84), dtype=np.uint8) 46 | 47 | def observation(self, obs): 48 | return np.array(obs) 49 | return np.array(obs[1], dtype=np.float32) 50 | 51 | def reward_done_truncated_info(self, reward, done, truncated, info): 52 | return reward, done, truncated, info 53 | if 'lives' in info: 54 | if info['lives'] == 0 and done: 55 | info['return'] = info['episode']['r'] 56 | info['length'] = info['episode']['l'] 57 | info['time'] = info['episode']['t'] 58 | return reward, True, info 59 | return reward, False, info 60 | 61 | if self.done: 62 | return reward, done, info 63 | 64 | if done: 65 | info['return'] = self.epoch_return 66 | info['length'] = self.epoch_length 67 | self.done = True 68 | else: 69 | self.epoch_length += 1 70 | self.epoch_return += reward 71 | 72 | return reward, done, info 73 | -------------------------------------------------------------------------------- /pufferlib/environments/stable_retro/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | 4 | class Recurrent: 5 | 
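    # Plain config holder: LSTM input/hidden sizes and layer count, presumably read by the
    # training setup when it wraps Policy in an LSTM. Other envs in this repo subclass
    # pufferlib.models.LSTMWrapper directly instead.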
input_size = 512 6 | hidden_size = 512 7 | num_layers = 1 8 | 9 | class Policy(pufferlib.models.Convolutional): 10 | def __init__(self, env, input_size=512, hidden_size=512, output_size=512, 11 | framestack=4, flat_size=64*7*7): 12 | super().__init__( 13 | env=env, 14 | input_size=input_size, 15 | hidden_size=hidden_size, 16 | output_size=output_size, 17 | framestack=framestack, 18 | flat_size=flat_size, 19 | ) 20 | -------------------------------------------------------------------------------- /pufferlib/environments/test/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import ( 2 | GymnasiumPerformanceEnv, 3 | PettingZooPerformanceEnv, 4 | GymnasiumTestEnv, 5 | PettingZooTestEnv, 6 | make_all_mock_environments, 7 | MOCK_OBSERVATION_SPACES, 8 | MOCK_ACTION_SPACES, 9 | ) 10 | 11 | from .mock_environments import MOCK_SINGLE_AGENT_ENVIRONMENTS 12 | from .mock_environments import MOCK_MULTI_AGENT_ENVIRONMENTS 13 | 14 | try: 15 | import torch 16 | except ImportError: 17 | pass 18 | else: 19 | from .torch import Policy 20 | try: 21 | from .torch import Recurrent 22 | except: 23 | Recurrent = None 24 | -------------------------------------------------------------------------------- /pufferlib/environments/test/torch.py: -------------------------------------------------------------------------------- 1 | from pufferlib.models import Default as Policy 2 | -------------------------------------------------------------------------------- /pufferlib/environments/trade_sim/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator, make 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/trade_sim/environment.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | 4 | import pufferlib 5 | 6 | from src.simulation.env import TradingEnvironment 7 | 8 | def env_creator(name='metta'): 9 | return functools.partial(make, name) 10 | 11 | def make(name, config_path='trade_sim/config/experiment_config.yaml', render_mode='human', buf=None, seed=1): 12 | '''Crafter creation function''' 13 | from src.utils.config_manager import ConfigManager 14 | from src.data_ingestion.historical_data_reader import HistoricalDataReader 15 | 16 | config_manager = ConfigManager(config_path) 17 | config = config_manager.config 18 | data_reader = HistoricalDataReader(config_manager) 19 | data, _ = data_reader.preprocess_data() 20 | 21 | # Create environment 22 | env = TradingEnvironmentPuff(config_manager.config, data) 23 | return pufferlib.emulation.GymnasiumPufferEnv(env, buf=buf) 24 | 25 | class TradingEnvironmentPuff(TradingEnvironment): 26 | def __init__(self, config, data): 27 | super().__init__(config, data) 28 | 29 | def reset(self): 30 | obs, info = super().reset() 31 | return obs.astype(np.float32), info 32 | 33 | def step(self, action): 34 | obs, reward, terminated, truncated, info = super().step(action) 35 | 36 | if not terminated and not truncated: 37 | info = {} 38 | 39 | return obs.astype(np.float32), reward, terminated, truncated, info 40 | 41 | -------------------------------------------------------------------------------- /pufferlib/environments/trade_sim/torch.py: 
-------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | Policy = pufferlib.models.Default 4 | Recurrent = pufferlib.models.LSTMWrapper 5 | -------------------------------------------------------------------------------- /pufferlib/environments/vizdoom/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_creator 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/environments/vizdoom/torch.py: -------------------------------------------------------------------------------- 1 | import pufferlib.models 2 | 3 | 4 | class Recurrent(pufferlib.models.LSTMWrapper): 5 | def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1): 6 | super().__init__(env, policy, input_size, hidden_size, num_layers) 7 | 8 | class Policy(pufferlib.models.Convolutional): 9 | def __init__(self, env, input_size=512, hidden_size=512, output_size=512, 10 | framestack=1, flat_size=64*4*6): 11 | super().__init__( 12 | env=env, 13 | input_size=input_size, 14 | hidden_size=hidden_size, 15 | output_size=output_size, 16 | framestack=framestack, 17 | flat_size=flat_size, 18 | channels_last=True 19 | ) 20 | -------------------------------------------------------------------------------- /pufferlib/exceptions.py: -------------------------------------------------------------------------------- 1 | class EnvironmentSetupError(RuntimeError): 2 | def __init__(self, e, package): 3 | super().__init__(self.message) 4 | 5 | class APIUsageError(RuntimeError): 6 | """Exception raised when the API is used incorrectly.""" 7 | 8 | def __init__(self, message="API usage error."): 9 | self.message = message 10 | super().__init__(self.message) 11 | 12 | class InvalidAgentError(ValueError): 13 | """Exception raised when an invalid agent key is used.""" 14 | 15 | def __init__(self, agent_id, agents): 16 | message = ( 17 | f'Invalid agent/team ({agent_id}) specified. ' 18 | f'Valid values:\n{agents}' 19 | ) 20 | super().__init__(message) 21 | -------------------------------------------------------------------------------- /pufferlib/extensions.pyx: -------------------------------------------------------------------------------- 1 | # distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION 2 | # cython: language_level=3 3 | # cython: boundscheck=False 4 | # cython: initializedcheck=False 5 | # cython: wraparound=False 6 | # cython: nonecheck=False 7 | 8 | '''Cythonized implementations of PufferLib's emulation functions 9 | 10 | emulate is about 2x faster than Python. Nativize is only slightly faster. 
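emulate recursively copies a native sample (dict / tuple / scalar) into a structured numpy array; nativize views a flat array as the matching struct dtype and rebuilds the native sample for a given space.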
11 | ''' 12 | 13 | import numpy as np 14 | cimport numpy as cnp 15 | 16 | from pufferlib.spaces import Tuple, Dict, Discrete 17 | 18 | 19 | def emulate(cnp.ndarray np_struct, object sample): 20 | cdef str k 21 | cdef int i 22 | 23 | if isinstance(sample, dict): 24 | for k, v in sample.items(): 25 | emulate(np_struct[k], v) 26 | elif isinstance(sample, tuple): 27 | for i, v in enumerate(sample): 28 | emulate(np_struct[f'f{i}'], v) 29 | else: 30 | np_struct[()] = sample 31 | 32 | cdef object _nativize(np_struct, object space): 33 | cdef str k 34 | cdef int i 35 | 36 | if isinstance(space, Discrete): 37 | return np_struct.item() 38 | elif isinstance(space, Tuple): 39 | return tuple(_nativize(np_struct[f'f{i}'], elem) 40 | for i, elem in enumerate(space)) 41 | elif isinstance(space, Dict): 42 | return {k: _nativize(np_struct[k], value) 43 | for k, value in space.items()} 44 | else: 45 | return np_struct 46 | 47 | def nativize(arr, object space, cnp.dtype struct_dtype): 48 | np_struct = np.asarray(arr).view(struct_dtype)[0] 49 | return _nativize(np_struct, space) 50 | -------------------------------------------------------------------------------- /pufferlib/namespace.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | from types import SimpleNamespace 3 | from collections.abc import Mapping 4 | 5 | def __getitem__(self, key): 6 | return self.__dict__[key] 7 | 8 | def keys(self): 9 | return self.__dict__.keys() 10 | 11 | def values(self): 12 | return self.__dict__.values() 13 | 14 | def items(self): 15 | return self.__dict__.items() 16 | 17 | def __iter__(self): 18 | return iter(self.__dict__) 19 | 20 | def __len__(self): 21 | return len(self.__dict__) 22 | 23 | class Namespace(SimpleNamespace, Mapping): 24 | __getitem__ = __getitem__ 25 | __iter__ = __iter__ 26 | __len__ = __len__ 27 | keys = keys 28 | values = values 29 | items = items 30 | 31 | def dataclass(cls): 32 | # Safely get annotations 33 | annotations = getattr(cls, '__annotations__', {}) 34 | 35 | # Combine both annotated and non-annotated fields 36 | all_fields = {**{k: None for k in annotations.keys()}, **cls.__dict__} 37 | all_fields = {k: v for k, v in all_fields.items() if not callable(v) and not k.startswith('__')} 38 | 39 | def __init__(self, **kwargs): 40 | for field, default_value in all_fields.items(): 41 | setattr(self, field, kwargs.get(field, default_value)) 42 | 43 | cls.__init__ = __init__ 44 | setattr(cls, "__getitem__", __getitem__) 45 | setattr(cls, "__iter__", __iter__) 46 | setattr(cls, "__len__", __len__) 47 | setattr(cls, "keys", keys) 48 | setattr(cls, "values", values) 49 | setattr(cls, "items", items) 50 | return cls 51 | 52 | def namespace(self=None, **kwargs): 53 | if self is None: 54 | return Namespace(**kwargs) 55 | self.__dict__.update(kwargs) 56 | -------------------------------------------------------------------------------- /pufferlib/ocean/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import * 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | pass 7 | else: 8 | from .torch import Policy 9 | try: 10 | from .torch import Recurrent 11 | except: 12 | Recurrent = None 13 | -------------------------------------------------------------------------------- /pufferlib/ocean/breakout/breakout.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "breakout.h" 3 | #include "puffernet.h" 4 | 5 | void demo() { 
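    // Demo mode: run a pretrained policy locally with puffernet. The exported weights are
    // loaded from disk and wrapped in a single-agent LinearLSTM (119 observation dims,
    // 4 actions), which is queried every frame unless the user holds LEFT SHIFT to drive
    // the paddle with the arrow keys / WASD.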
6 | Weights* weights = load_weights("resources/breakout_weights.bin", 148101); 7 | LinearLSTM* net = make_linearlstm(weights, 1, 119, 4); 8 | 9 | Breakout env = { 10 | .frameskip = 1, 11 | .width = 576, 12 | .height = 330, 13 | .paddle_width = 62, 14 | .paddle_height = 8, 15 | .ball_width = 32, 16 | .ball_height = 32, 17 | .brick_width = 32, 18 | .brick_height = 12, 19 | .brick_rows = 6, 20 | .brick_cols = 18, 21 | }; 22 | allocate(&env); 23 | c_reset(&env); 24 | 25 | Client* client = make_client(&env); 26 | 27 | while (!WindowShouldClose()) { 28 | // User can take control of the paddle 29 | if (IsKeyDown(KEY_LEFT_SHIFT)) { 30 | env.actions[0] = 0; 31 | if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) env.actions[0] = 1; 32 | if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) env.actions[0] = 2; 33 | if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) env.actions[0] = 3; 34 | } else { 35 | forward_linearlstm(net, env.observations, env.actions); 36 | } 37 | 38 | c_step(&env); 39 | c_render(client, &env); 40 | } 41 | free_linearlstm(net); 42 | free(weights); 43 | free_allocated(&env); 44 | close_client(client); 45 | } 46 | 47 | void performance_test() { 48 | long test_time = 10; 49 | Breakout env = { 50 | .frameskip = 1, 51 | .width = 576, 52 | .height = 330, 53 | .paddle_width = 62, 54 | .paddle_height = 8, 55 | .ball_width = 32, 56 | .ball_height = 32, 57 | .brick_width = 32, 58 | .brick_height = 12, 59 | .brick_rows = 6, 60 | .brick_cols = 18, 61 | }; 62 | allocate(&env); 63 | c_reset(&env); 64 | 65 | long start = time(NULL); 66 | int i = 0; 67 | while (time(NULL) - start < test_time) { 68 | env.actions[0] = rand() % 4; 69 | c_step(&env); 70 | i++; 71 | } 72 | long end = time(NULL); 73 | printf("SPS: %ld\n", i / (end - start)); 74 | free_initialized(&env); 75 | } 76 | 77 | int main() { 78 | //performance_test(); 79 | demo(); 80 | return 0; 81 | } 82 | -------------------------------------------------------------------------------- /pufferlib/ocean/breakout/breakout.py: -------------------------------------------------------------------------------- 1 | '''High-perf Pong 2 | 3 | Inspired from https://gist.github.com/Yttrmin/18ecc3d2d68b407b4be1 4 | & https://jair.org/index.php/jair/article/view/10819/25823 5 | & https://www.youtube.com/watch?v=PSQt5KGv7Vk 6 | ''' 7 | 8 | import numpy as np 9 | import gymnasium 10 | 11 | import pufferlib 12 | from pufferlib.ocean.breakout.cy_breakout import CyBreakout 13 | 14 | class Breakout(pufferlib.PufferEnv): 15 | def __init__(self, num_envs=1, render_mode=None, report_interval=128, 16 | frameskip=4, width=576, height=330, 17 | paddle_width=62, paddle_height=8, 18 | ball_width=32, ball_height=32, 19 | brick_width=32, brick_height=12, 20 | brick_rows=6, brick_cols=18, buf=None): 21 | self.single_observation_space = gymnasium.spaces.Box(low=0, high=1, 22 | shape=(11 + brick_rows*brick_cols,), dtype=np.float32) 23 | self.single_action_space = gymnasium.spaces.Discrete(4) 24 | self.report_interval = report_interval 25 | self.render_mode = render_mode 26 | self.num_agents = num_envs 27 | 28 | super().__init__(buf) 29 | self.c_envs = CyBreakout(self.observations, self.actions, self.rewards, 30 | self.terminals, num_envs, frameskip, width, height, 31 | paddle_width, paddle_height, ball_width, ball_height, 32 | brick_width, brick_height, brick_rows, brick_cols) 33 | 34 | def reset(self, seed=None): 35 | self.c_envs.reset() 36 | self.tick = 0 37 | return self.observations, [] 38 | 39 | def step(self, actions): 40 | self.actions[:] = actions 41 | self.c_envs.step() 42 | 43 | info 
= [] 44 | if self.tick % self.report_interval == 0: 45 | log = self.c_envs.log() 46 | if log['episode_length'] > 0: 47 | info.append(log) 48 | 49 | self.tick += 1 50 | return (self.observations, self.rewards, 51 | self.terminals, self.truncations, info) 52 | 53 | def render(self): 54 | self.c_envs.render() 55 | 56 | def close(self): 57 | self.c_envs.close() 58 | 59 | def test_performance(timeout=10, atn_cache=1024): 60 | env = CyBreakout(num_envs=1000) 61 | env.reset() 62 | tick = 0 63 | 64 | actions = np.random.randint(0, 2, (atn_cache, env.num_envs)) 65 | 66 | import time 67 | start = time.time() 68 | while time.time() - start < timeout: 69 | atn = actions[tick % atn_cache] 70 | env.step(atn) 71 | tick += 1 72 | 73 | print(f'SPS: %f', env.num_envs * tick / (time.time() - start)) 74 | 75 | if __name__ == '__main__': 76 | test_performance() 77 | -------------------------------------------------------------------------------- /pufferlib/ocean/connect4/connect4.c: -------------------------------------------------------------------------------- 1 | #include "connect4.h" 2 | #include "puffernet.h" 3 | #include "time.h" 4 | 5 | const unsigned char NOOP = 8; 6 | 7 | void interactive() { 8 | Weights* weights = load_weights("resources/connect4_weights.bin", 138632); 9 | LinearLSTM* net = make_linearlstm(weights, 1, 42, 7); 10 | 11 | CConnect4 env = { 12 | .width = 672, 13 | .height = 576, 14 | .piece_width = 96, 15 | .piece_height = 96, 16 | }; 17 | allocate_cconnect4(&env); 18 | c_reset(&env); 19 | 20 | Client* client = make_client(env.width, env.height); 21 | float observations[42] = {0}; 22 | int actions[1] = {0}; 23 | 24 | int tick = 0; 25 | while (!WindowShouldClose()) { 26 | env.actions[0] = NOOP; 27 | // user inputs 1 - 7 key pressed 28 | if (IsKeyDown(KEY_LEFT_SHIFT)) { 29 | if(IsKeyPressed(KEY_ONE)) env.actions[0] = 0; 30 | if(IsKeyPressed(KEY_TWO)) env.actions[0] = 1; 31 | if(IsKeyPressed(KEY_THREE)) env.actions[0] = 2; 32 | if(IsKeyPressed(KEY_FOUR)) env.actions[0] = 3; 33 | if(IsKeyPressed(KEY_FIVE)) env.actions[0] = 4; 34 | if(IsKeyPressed(KEY_SIX)) env.actions[0] = 5; 35 | if(IsKeyPressed(KEY_SEVEN)) env.actions[0] = 6; 36 | } else if (tick % 30 == 0) { 37 | for (int i = 0; i < 42; i++) { 38 | observations[i] = env.observations[i]; 39 | } 40 | forward_linearlstm(net, (float*)&observations, (int*)&actions); 41 | env.actions[0] = actions[0]; 42 | } 43 | 44 | tick = (tick + 1) % 60; 45 | if (env.actions[0] >= 0 && env.actions[0] <= 6) { 46 | c_step(&env); 47 | } 48 | 49 | c_render(client, &env); 50 | } 51 | free_linearlstm(net); 52 | free(weights); 53 | close_client(client); 54 | free_allocated_cconnect4(&env); 55 | } 56 | 57 | void performance_test() { 58 | long test_time = 10; 59 | CConnect4 env = { 60 | .width = 672, 61 | .height = 576, 62 | .piece_width = 96, 63 | .piece_height = 96, 64 | }; 65 | allocate_cconnect4(&env); 66 | c_reset(&env); 67 | 68 | long start = time(NULL); 69 | int i = 0; 70 | while (time(NULL) - start < test_time) { 71 | env.actions[0] = rand() % 7; 72 | c_step(&env); 73 | i++; 74 | } 75 | long end = time(NULL); 76 | printf("SPS: %ld\n", i / (end - start)); 77 | free_allocated_cconnect4(&env); 78 | } 79 | 80 | int main() { 81 | //performance_test(); 82 | interactive(); 83 | return 0; 84 | } 85 | -------------------------------------------------------------------------------- /pufferlib/ocean/connect4/connect4.py: -------------------------------------------------------------------------------- 1 | '''High-perf Pong 2 | 3 | Inspired from 
https://gist.github.com/Yttrmin/18ecc3d2d68b407b4be1 4 | & https://jair.org/index.php/jair/article/view/10819/25823 5 | & https://www.youtube.com/watch?v=PSQt5KGv7Vk 6 | ''' 7 | 8 | import numpy as np 9 | import gymnasium 10 | 11 | import pufferlib 12 | from pufferlib.ocean.connect4.cy_connect4 import CyConnect4 13 | 14 | 15 | class Connect4(pufferlib.PufferEnv): 16 | def __init__(self, num_envs=1, render_mode=None, report_interval=128, 17 | width=672, height=576, piece_width=96, piece_height=96, buf=None): 18 | 19 | self.single_observation_space = gymnasium.spaces.Box(low=0, high=1, 20 | shape=(42,), dtype=np.float32) 21 | self.single_action_space = gymnasium.spaces.Discrete(7) 22 | self.report_interval = report_interval 23 | self.render_mode = render_mode 24 | self.num_agents = num_envs 25 | 26 | super().__init__(buf=buf) 27 | self.c_envs = CyConnect4(self.observations, self.actions, self.rewards, 28 | self.terminals, num_envs, width, height, piece_width, piece_height) 29 | 30 | def reset(self, seed=None): 31 | self.c_envs.reset() 32 | self.tick = 0 33 | return self.observations, [] 34 | 35 | def step(self, actions): 36 | self.actions[:] = actions 37 | self.c_envs.step() 38 | self.tick += 1 39 | 40 | info = [] 41 | if self.tick % self.report_interval == 0: 42 | log = self.c_envs.log() 43 | if log['episode_length'] > 0: 44 | info.append(log) 45 | 46 | return (self.observations, self.rewards, 47 | self.terminals, self.truncations, info) 48 | 49 | def render(self): 50 | self.c_envs.render() 51 | 52 | def close(self): 53 | self.c_envs.close() 54 | 55 | 56 | def test_performance(timeout=10, atn_cache=1024, num_envs=1024): 57 | import time 58 | 59 | env = Connect4(num_envs=num_envs) 60 | env.reset() 61 | tick = 0 62 | 63 | actions = np.random.randint( 64 | 0, 65 | env.single_action_space.n + 1, 66 | (atn_cache, num_envs), 67 | ) 68 | 69 | start = time.time() 70 | while time.time() - start < timeout: 71 | atn = actions[tick % atn_cache] 72 | env.step(atn) 73 | tick += 1 74 | 75 | print(f'SPS: {num_envs * tick / (time.time() - start)}') 76 | 77 | 78 | if __name__ == '__main__': 79 | test_performance() 80 | -------------------------------------------------------------------------------- /pufferlib/ocean/connect4/connect4game: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/ocean/connect4/connect4game -------------------------------------------------------------------------------- /pufferlib/ocean/grid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/ocean/grid/__init__.py -------------------------------------------------------------------------------- /pufferlib/ocean/grid/grid.c: -------------------------------------------------------------------------------- 1 | #include "grid.h" 2 | 3 | unsigned int actions[41] = {NORTH, NORTH, NORTH, NORTH, NORTH, NORTH, 4 | EAST, EAST, EAST, EAST, EAST, EAST, SOUTH, WEST, WEST, WEST, NORTH, WEST, 5 | WEST, WEST, SOUTH, SOUTH, SOUTH, SOUTH, SOUTH, SOUTH, SOUTH, SOUTH, SOUTH, 6 | SOUTH, SOUTH, SOUTH, EAST, EAST, EAST, EAST, EAST, EAST, EAST, EAST, SOUTH 7 | }; 8 | 9 | void test_multiple_envs() { 10 | Env** envs = (Env**)calloc(10, sizeof(Env*)); 11 | for (int i = 0; i < 10; i++) { 12 | envs[i] = alloc_locked_room_env(); 13 | reset_locked_room(envs[i]); 14 | } 15 | 16 | for (int i 
= 0; i < 41; i++) { 17 | for (int j = 0; j < 10; j++) { 18 | envs[j]->actions[0] = actions[i]; 19 | step(envs[j]); 20 | } 21 | } 22 | for (int i = 0; i < 10; i++) { 23 | free_allocated_grid(envs[i]); 24 | } 25 | free(envs); 26 | printf("Done\n"); 27 | } 28 | 29 | int main() { 30 | int width = 32; 31 | int height = 32; 32 | int num_agents = 1; 33 | int horizon = 128; 34 | float agent_speed = 1; 35 | int vision = 5; 36 | bool discretize = true; 37 | 38 | int render_cell_size = 32; 39 | int seed = 42; 40 | 41 | //test_multiple_envs(); 42 | //exit(0); 43 | 44 | Env* env = alloc_locked_room_env(); 45 | reset_locked_room(env); 46 | /* 47 | Env* env = allocate_grid(width, height, num_agents, horizon, 48 | vision, agent_speed, discretize); 49 | env->agents[0].spawn_y = 16; 50 | env->agents[0].spawn_x = 16; 51 | env->agents[0].color = AGENT_2; 52 | Env* env = alloc_locked_room_env(); 53 | reset_locked_room(env); 54 | */ 55 | 56 | Renderer* renderer = init_renderer(render_cell_size, width, height); 57 | 58 | int t = 0; 59 | while (!WindowShouldClose()) { 60 | // User can take control of the first agent 61 | env->actions[0] = PASS; 62 | if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) env->actions[0] = NORTH; 63 | if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) env->actions[0] = SOUTH; 64 | if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) env->actions[0] = WEST; 65 | if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) env->actions[0] = EAST; 66 | 67 | //for (int i = 0; i < num_agents; i++) { 68 | // env->actions[i] = rand() % 4; 69 | //} 70 | //env->actions[0] = actions[t]; 71 | bool done = step(env); 72 | if (done) { 73 | printf("Done\n"); 74 | reset_locked_room(env); 75 | } 76 | render_global(renderer, env); 77 | 78 | /* 79 | t++; 80 | if (t == 41) { 81 | exit(0); 82 | } 83 | */ 84 | } 85 | close_renderer(renderer); 86 | free_allocated_grid(env); 87 | return 0; 88 | } 89 | 90 | -------------------------------------------------------------------------------- /pufferlib/ocean/moba/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/ocean/moba/__init__.py -------------------------------------------------------------------------------- /pufferlib/ocean/pong/pong.c: -------------------------------------------------------------------------------- 1 | #include "pong.h" 2 | #include "puffernet.h" 3 | 4 | int main() { 5 | Weights* weights = load_weights("resources/pong_weights.bin", 133764); 6 | LinearLSTM* net = make_linearlstm(weights, 1, 8, 3); 7 | 8 | Pong env = { 9 | .width = 500, 10 | .height = 640, 11 | .paddle_width = 20, 12 | .paddle_height = 70, 13 | //.ball_width = 10, 14 | //.ball_height = 15, 15 | .ball_width = 32, 16 | .ball_height = 32, 17 | .paddle_speed = 8, 18 | .ball_initial_speed_x = 10, 19 | .ball_initial_speed_y = 1, 20 | .ball_speed_y_increment = 3, 21 | .ball_max_speed_y = 13, 22 | .max_score = 21, 23 | .frameskip = 1, 24 | }; 25 | allocate(&env); 26 | 27 | Client* client = make_client(&env); 28 | 29 | c_reset(&env); 30 | while (!WindowShouldClose()) { 31 | // User can take control of the paddle 32 | if (IsKeyDown(KEY_LEFT_SHIFT)) { 33 | env.actions[0] = 0; 34 | if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) env.actions[0] = 1; 35 | if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) env.actions[0] = 2; 36 | } else { 37 | forward_linearlstm(net, env.observations, env.actions); 38 | } 39 | 40 | c_step(&env); 41 | c_render(client, &env); 42 | } 43 | free_linearlstm(net); 
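    // Remaining teardown: weights buffer, env allocations, then the raylib client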
44 | free(weights); 45 | free_allocated(&env); 46 | close_client(client); 47 | } 48 | 49 | -------------------------------------------------------------------------------- /pufferlib/ocean/robocode/build_local.sh: -------------------------------------------------------------------------------- 1 | clang -Wall -Wuninitialized -Wmisleading-indentation -fsanitize=address,undefined,bounds,pointer-overflow,leak -ferror-limit=3 -g -o robocode robocode.c -I./raylib-5.0_linux_amd64/include/ -L./raylib-5.0_linux_amd64/lib/ -lraylib -lGL -lm -lpthread -ldl -lrt -lX11 -DPLATFORM_DESKTOP 2 | 3 | 4 | -------------------------------------------------------------------------------- /pufferlib/ocean/robocode/robocode: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/ocean/robocode/robocode -------------------------------------------------------------------------------- /pufferlib/ocean/robocode/robocode.c: -------------------------------------------------------------------------------- 1 | #include "robocode.h" 2 | 3 | int main() { 4 | Env env = {0}; 5 | env.num_agents = 2; 6 | env.width = 768; 7 | env.height = 576; 8 | allocate_env(&env); 9 | reset(&env); 10 | 11 | Client* client = make_client(&env); 12 | 13 | while (!WindowShouldClose()) { 14 | for (int i = 0; i < NUM_ACTIONS; i++) { 15 | env.actions[i] = 0; 16 | } 17 | 18 | env.actions[0] = 16.0f; 19 | float x = env.robots[0].x; 20 | float y = env.robots[0].y; 21 | float op_x = env.robots[1].x; 22 | float op_y = env.robots[1].y; 23 | float gun_heading = env.robots[0].gun_heading; 24 | float angle_to_op = 180*atan2(op_y - y, op_x - x)/M_PI; 25 | float gun_delta = angle_to_op - gun_heading; 26 | if (gun_delta < -180) gun_delta += 360; 27 | env.actions[2] = (gun_delta > 0) ? 1.0f : -1.0f; 28 | if (gun_delta < 5 && gun_delta > -5) env.actions[4] = 1.0; 29 | 30 | env.actions[5] = 16.0f; 31 | x = env.robots[1].x; 32 | y = env.robots[1].y; 33 | op_x = env.robots[0].x; 34 | op_y = env.robots[0].y; 35 | gun_heading = env.robots[1].gun_heading; 36 | angle_to_op = 180*atan2(op_y - y, op_x - x)/M_PI; 37 | gun_delta = angle_to_op - gun_heading; 38 | if (gun_delta < -180) gun_delta += 360; 39 | env.actions[7] = (gun_delta > 0) ? 1.0f : -1.0f; 40 | if (gun_delta < 5 && gun_delta > -5) env.actions[9] = 1.0; 41 | 42 | 43 | //if (IsKeyPressed(KEY_ESCAPE)) break; 44 | if (IsKeyDown(KEY_W)) env.actions[0] = 16.0f; 45 | if (IsKeyDown(KEY_S)) env.actions[0] = -16.0f; 46 | if (IsKeyDown(KEY_A)) env.actions[1] = -2.0f; 47 | if (IsKeyDown(KEY_D)) env.actions[1] = 2.0f; 48 | if (IsKeyDown(KEY_Q)) env.actions[2] = -1.0f; 49 | if (IsKeyDown(KEY_E)) env.actions[2] = 1.0f; 50 | if (IsKeyDown(KEY_SPACE)) env.actions[4] = 1.0f; 51 | 52 | step(&env); 53 | render(client, &env); 54 | } 55 | CloseWindow(); 56 | return 0; 57 | } 58 | -------------------------------------------------------------------------------- /pufferlib/ocean/rocket_lander/rocket_lander.c: -------------------------------------------------------------------------------- 1 | #include "rocket_lander.h" 2 | 3 | int main(void) { 4 | demo(); 5 | return 0; 6 | } 7 | 8 | 9 | 10 | /* 11 | Entity legs[2] = {0}; 12 | for (int i = 0; i < 2; i++) { 13 | float leg_i = (i == 0) ? 
-1 : 1; 14 | b2Vec2 leg_extent = (b2Vec2){LEG_W / SCALE, LEG_H / SCALE}; 15 | 16 | b2BodyDef leg = b2DefaultBodyDef(); 17 | leg.type = b2_dynamicBody; 18 | leg.position = (b2Vec2){-leg_i * LEG_AWAY, INITIAL_Y - LANDER_HEIGHT/2 - leg_extent.y/2}; 19 | //leg.position = (b2Vec2){0, 0}; 20 | leg.rotation = b2MakeRot(leg_i * 1.05); 21 | b2BodyId leg_id = b2CreateBody(world_id, &leg); 22 | 23 | b2Polygon leg_box = b2MakeBox(leg_extent.x, leg_extent.y); 24 | b2ShapeDef leg_shape = b2DefaultShapeDef(); 25 | b2CreatePolygonShape(leg_id, &leg_shape, &leg_box); 26 | 27 | float joint_x = leg_i*LANDER_WIDTH/2; 28 | float joint_y = INITIAL_Y - LANDER_HEIGHT/2 - leg_extent.y/2; 29 | b2Vec2 joint_p = (b2Vec2){joint_x, joint_y}; 30 | 31 | b2RevoluteJointDef joint = b2DefaultRevoluteJointDef(); 32 | joint.bodyIdA = lander_id; 33 | joint.bodyIdB = leg_id; 34 | joint.localAnchorA = b2Body_GetLocalPoint(lander_id, joint_p); 35 | joint.localAnchorB = b2Body_GetLocalPoint(leg_id, joint_p); 36 | joint.localAnchorB = (b2Vec2){i * 0.5, LEG_DOWN}; 37 | joint.enableMotor = true; 38 | joint.enableLimit = true; 39 | joint.maxMotorTorque = LEG_SPRING_TORQUE; 40 | joint.motorSpeed = 0.3*i; 41 | 42 | if (i == 0) { 43 | joint.lowerAngle = 40 * DEGTORAD; 44 | joint.upperAngle = 45 * DEGTORAD; 45 | } else { 46 | joint.lowerAngle = -45 * DEGTORAD; 47 | joint.upperAngle = -40 * DEGTORAD; 48 | } 49 | 50 | b2JointId joint_id = b2CreateRevoluteJoint(world_id, &joint); 51 | 52 | legs[i] = (Entity){ 53 | .extent = leg_extent, 54 | .bodyId = leg_id, 55 | }; 56 | } 57 | */ 58 | 59 | 60 | -------------------------------------------------------------------------------- /pufferlib/ocean/rocket_lander/rocket_lander.py: -------------------------------------------------------------------------------- 1 | '''High-perf Pong 2 | 3 | Inspired from https://gist.github.com/Yttrmin/18ecc3d2d68b407b4be1 4 | & https://jair.org/index.php/jair/article/view/10819/25823 5 | & https://www.youtube.com/watch?v=PSQt5KGv7Vk 6 | ''' 7 | 8 | import numpy as np 9 | import gymnasium 10 | 11 | import pufferlib 12 | from pufferlib.ocean.rocket_lander.cy_rocket_lander import CyRocketLander 13 | 14 | class RocketLander(pufferlib.PufferEnv): 15 | def __init__(self, num_envs=1, render_mode=None, report_interval=32, buf=None): 16 | self.single_observation_space = gymnasium.spaces.Box(low=0, high=1, 17 | shape=(6,), dtype=np.float32) 18 | self.single_action_space = gymnasium.spaces.Discrete(4) 19 | self.render_mode = render_mode 20 | self.num_agents = num_envs 21 | self.report_interval = report_interval 22 | 23 | super().__init__(buf) 24 | self.float_actions = np.zeros((num_envs, 3), dtype=np.float32) 25 | self.c_envs = CyRocketLander(self.observations, self.float_actions, self.rewards, 26 | self.terminals, self.truncations, num_envs) 27 | 28 | def reset(self, seed=None): 29 | self.tick = 0 30 | self.c_envs.reset() 31 | return self.observations, [] 32 | 33 | def step(self, actions): 34 | self.float_actions[:, :] = 0 35 | self.float_actions[:, 0] = actions == 1 36 | self.float_actions[:, 1] = actions == 2 37 | self.float_actions[:, 2] = actions == 3 38 | self.c_envs.step() 39 | 40 | info = [] 41 | if self.tick % self.report_interval == 0: 42 | log = self.c_envs.log() 43 | if log['episode_length'] > 0: 44 | info.append(log) 45 | 46 | self.tick += 1 47 | return (self.observations, self.rewards, 48 | self.terminals, self.truncations, info) 49 | 50 | def render(self): 51 | self.c_envs.render() 52 | 53 | def close(self): 54 | self.c_envs.close() 55 | 56 | def 
test_performance(timeout=10, atn_cache=1024): 57 | env = RocketLander(num_envs=1000) 58 | env.reset() 59 | tick = 0 60 | 61 | actions = np.random.randint(0, 2, (atn_cache, env.num_envs)) 62 | 63 | import time 64 | start = time.time() 65 | while time.time() - start < timeout: 66 | atn = actions[tick % atn_cache] 67 | env.step(atn) 68 | tick += 1 69 | 70 | print(f'SPS: %f', env.num_envs * tick / (time.time() - start)) 71 | 72 | if __name__ == '__main__': 73 | test_performance() 74 | -------------------------------------------------------------------------------- /pufferlib/ocean/rware/rware.py: -------------------------------------------------------------------------------- 1 | '''High-perf Pong 2 | 3 | Inspired from https://gist.github.com/Yttrmin/18ecc3d2d68b407b4be1 4 | & https://jair.org/index.php/jair/article/view/10819/25823 5 | & https://www.youtube.com/watch?v=PSQt5KGv7Vk 6 | ''' 7 | 8 | import numpy as np 9 | import gymnasium 10 | 11 | import pufferlib 12 | from pufferlib.ocean.rware.cy_rware import CyRware 13 | 14 | PLAYER_OBS_N = 27 15 | 16 | class Rware(pufferlib.PufferEnv): 17 | def __init__(self, num_envs=1, render_mode=None, report_interval=1, 18 | width=1280, height=1024, 19 | num_agents=4, 20 | map_choice=1, 21 | num_requested_shelves=4, 22 | grid_square_size=64, 23 | human_agent_idx=0, 24 | reward_type=1, 25 | buf = None): 26 | 27 | # env 28 | self.num_agents = num_envs*num_agents 29 | self.render_mode = render_mode 30 | self.report_interval = report_interval 31 | 32 | self.num_obs = 27 33 | self.single_observation_space = gymnasium.spaces.Box(low=0, high=1, 34 | shape=(self.num_obs,), dtype=np.float32) 35 | self.single_action_space = gymnasium.spaces.Discrete(5) 36 | 37 | super().__init__(buf=buf) 38 | self.c_envs = CyRware(self.observations, self.actions, self.rewards, 39 | self.terminals, num_envs, width, height, map_choice, num_agents, num_requested_shelves, grid_square_size, human_agent_idx) 40 | 41 | 42 | def reset(self, seed=None): 43 | self.c_envs.reset() 44 | self.tick = 0 45 | return self.observations, [] 46 | 47 | def step(self, actions): 48 | self.actions[:] = actions 49 | self.c_envs.step() 50 | self.tick += 1 51 | 52 | info = [] 53 | if self.tick % self.report_interval == 0: 54 | log = self.c_envs.log() 55 | if log['episode_length'] > 0: 56 | info.append(log) 57 | return (self.observations, self.rewards, 58 | self.terminals, self.truncations, info) 59 | 60 | def render(self): 61 | self.c_envs.render() 62 | 63 | def close(self): 64 | self.c_envs.close() 65 | 66 | def test_performance(timeout=10, atn_cache=1024): 67 | num_envs=1000; 68 | env = MyRware(num_envs=num_envs) 69 | env.reset() 70 | tick = 0 71 | 72 | actions = np.random.randint(0, env.single_action_space.n, (atn_cache, 5*num_envs)) 73 | 74 | import time 75 | start = time.time() 76 | while time.time() - start < timeout: 77 | atn = actions[tick % atn_cache] 78 | env.step(atn) 79 | tick += 1 80 | 81 | sps = num_envs * tick / (time.time() - start) 82 | print(f'SPS: {sps:,}') 83 | if __name__ == '__main__': 84 | test_performance() 85 | -------------------------------------------------------------------------------- /pufferlib/ocean/snake/README.md: -------------------------------------------------------------------------------- 1 | # PufferLib Multi-Snake 2 | 3 | This is a simple multi-agent snake environment runnable with any number of snakes, board size, food, etc. 
I originally implemented this to demonstrate how simple it is to implement ultra high performance environments that run at millions of steps per second. The exact same approaches you see here are used in all of my more complex simulators. 4 | 5 | # Cython version 6 | 7 | The Cython version is the original. It runs over 10M steps/second/core on a high-end CPU. This is the version that we currently have bound to training. You can use it with the PufferLib demo script (--env snake) or import it from pufferlib/environments/ocean. There are a number of default board sizes and settings. If you would like to contribute games to PufferLib, you can use this project as a template. There is a bit of bloat in the .py file because we have to trick PufferLib's vectorization into thinking this is a vecenv. In the future, there will be a more standard advanced API. 8 | 9 | Key concepts: 10 | - Memory views: Cython provides a way to access numpy arrays as C arrays or structs. This gives you C-speed numpy indexing and prevents you from having to copy data around. When running with multiprocessing, the observation buffers are stored in shared memory, so you are literally simulating into the experience buffer. 11 | - No memory management: All data is allocated by Numpy and passed to C. This is fast and also prevents any chance of leaks 12 | - No python callbacks: Compile and optimize with annotations enabled (see setup.py) to ensure that the Cython code never calls back to Python. You should be able to get >>1M agent steps/second for almost any sim 13 | 14 | # C version 15 | 16 | The C version is a direct port of the Cython version, plus a few minor tweaks. It includes a pure C raylib client and a pure C MLP forward pass for running local inference. I made this so that we could run a cool demo in the browser 100% client side. I may port additional simulators in the future, and you are welcome to contribute C code to PufferLib, but this is not required. You can make things plenty fast in Cython. To build this locally, all you need is the raylib source. If you want to build for web, follow RayLib's emscripten setup. 
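# Minimal Cython pattern

To make the key concepts above concrete, here is a minimal, hypothetical sketch (not the actual snake code; all names and shapes are invented) of the structure these environments share: the env holds typed memoryviews into numpy buffers allocated by the trainer and writes observations and rewards in place, so the step loop compiles down to a plain C loop with no Python callbacks.

```cython
# toy_env.pyx - illustrative only
cdef class CyToyEnv:
    cdef:
        unsigned char[:, :] observations  # typed memoryviews into shared numpy buffers
        int[:] actions
        float[:] rewards
        int num_agents

    def __init__(self, unsigned char[:, :] observations, int[:] actions,
            float[:] rewards, int num_agents):
        # No copies and no allocation: the env simulates directly into the trainer's buffers
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.num_agents = num_agents

    def step(self):
        cdef int i
        for i in range(self.num_agents):  # typed index keeps this loop in C
            self.rewards[i] = 1.0 if self.actions[i] == 0 else 0.0
            self.observations[i, 0] = <unsigned char>self.actions[i]
```

Compiling with annotations enabled (`cython -a`) highlights any line that still interacts with Python, which is the quickest way to check the "no Python callbacks" point above.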
17 | 18 | -------------------------------------------------------------------------------- /pufferlib/ocean/snake/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/ocean/snake/__init__.py -------------------------------------------------------------------------------- /pufferlib/ocean/snake/snake.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "snake.h" 3 | #include "puffernet.h" 4 | 5 | int demo() { 6 | CSnake env = { 7 | .num_snakes = 256, 8 | .width = 640, 9 | .height = 360, 10 | .max_snake_length = 200, 11 | .food = 4096, 12 | .vision = 5, 13 | .leave_corpse_on_death = true, 14 | .reward_food = 1.0f, 15 | .reward_corpse = 0.5f, 16 | .reward_death = -1.0f, 17 | }; 18 | allocate_csnake(&env); 19 | c_reset(&env); 20 | 21 | Weights* weights = load_weights("resources/snake_weights.bin", 148357); 22 | LinearLSTM* net = make_linearlstm(weights, env.num_snakes, env.obs_size, 4); 23 | Client* client = make_client(2, env.width, env.height); 24 | 25 | while (!WindowShouldClose()) { 26 | // User can take control of the first snake 27 | if (IsKeyDown(KEY_LEFT_SHIFT)) { 28 | if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) env.actions[0] = 0; 29 | if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) env.actions[0] = 1; 30 | if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) env.actions[0] = 2; 31 | if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) env.actions[0] = 3; 32 | } else { 33 | for (int i = 0; i < env.num_snakes*env.obs_size; i++) { 34 | net->obs[i] = (float)env.observations[i]; 35 | } 36 | forward_linearlstm(net, net->obs, env.actions); 37 | } 38 | c_step(&env); 39 | c_render(client, &env); 40 | } 41 | free_linearlstm(net); 42 | free(weights); 43 | close_client(client); 44 | free_csnake(&env); 45 | return 0; 46 | } 47 | 48 | void test_performance(float test_time) { 49 | CSnake env = { 50 | .num_snakes = 1024, 51 | .width = 1280, 52 | .height = 720, 53 | .max_snake_length = 200, 54 | .food = 16384, 55 | .vision = 5, 56 | .leave_corpse_on_death = true, 57 | .reward_food = 1.0f, 58 | .reward_corpse = 0.5f, 59 | .reward_death = -1.0f, 60 | }; 61 | allocate_csnake(&env); 62 | c_reset(&env); 63 | 64 | int start = time(NULL); 65 | int i = 0; 66 | while (time(NULL) - start < test_time) { 67 | for (int j = 0; j < env.num_snakes; j++) { 68 | env.actions[j] = rand()%4; 69 | } 70 | c_step(&env); 71 | i++; 72 | } 73 | int end = time(NULL); 74 | printf("SPS: %f\n", (float)env.num_snakes*i / (end - start)); 75 | } 76 | 77 | int main() { 78 | demo(); 79 | //test_performance(30); 80 | return 0; 81 | } 82 | -------------------------------------------------------------------------------- /pufferlib/ocean/squared/cy_squared.pyx: -------------------------------------------------------------------------------- 1 | cimport numpy as cnp 2 | from libc.stdlib cimport calloc, free 3 | 4 | cdef extern from "squared.h": 5 | ctypedef struct Squared: 6 | unsigned char* observations 7 | int* actions 8 | float* rewards 9 | unsigned char* terminals 10 | int size 11 | int tick 12 | int r 13 | int c 14 | 15 | ctypedef struct Client 16 | 17 | void c_reset(Squared* env) 18 | void c_step(Squared* env) 19 | Client* make_client(Squared* env) 20 | void close_client(Client* client) 21 | void c_render(Client* client, Squared* env) 22 | 23 | cdef class CySquared: 24 | cdef: 25 | Squared* envs 26 | Client* client 27 | int num_envs 28 | int size 29 | 30 
| def __init__(self, unsigned char[:, :] observations, int[:] actions, 31 | float[:] rewards, unsigned char[:] terminals, int num_envs, int size): 32 | 33 | self.envs = calloc(num_envs, sizeof(Squared)) 34 | self.num_envs = num_envs 35 | self.client = NULL 36 | 37 | cdef int i 38 | for i in range(num_envs): 39 | self.envs[i] = Squared( 40 | observations = &observations[i, 0], 41 | actions = &actions[i], 42 | rewards = &rewards[i], 43 | terminals = &terminals[i], 44 | size=size, 45 | ) 46 | 47 | def reset(self): 48 | cdef int i 49 | for i in range(self.num_envs): 50 | c_reset(&self.envs[i]) 51 | 52 | def step(self): 53 | cdef int i 54 | for i in range(self.num_envs): 55 | c_step(&self.envs[i]) 56 | 57 | def render(self): 58 | cdef Squared* env = &self.envs[0] 59 | if self.client == NULL: 60 | self.client = make_client(env) 61 | 62 | c_render(self.client, env) 63 | 64 | def close(self): 65 | if self.client != NULL: 66 | close_client(self.client) 67 | self.client = NULL 68 | 69 | free(self.envs) 70 | -------------------------------------------------------------------------------- /pufferlib/ocean/squared/squared.c: -------------------------------------------------------------------------------- 1 | #include "squared.h" 2 | #include "puffernet.h" 3 | 4 | int main() { 5 | //Weights* weights = load_weights("resources/pong_weights.bin", 133764); 6 | //LinearLSTM* net = make_linearlstm(weights, 1, 8, 3); 7 | 8 | Squared env = {.size = 11}; 9 | allocate(&env); 10 | 11 | Client* client = make_client(&env); 12 | 13 | c_reset(&env); 14 | while (!WindowShouldClose()) { 15 | if (IsKeyDown(KEY_LEFT_SHIFT)) { 16 | env.actions[0] = 0; 17 | if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) env.actions[0] = UP; 18 | if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) env.actions[0] = DOWN; 19 | if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) env.actions[0] = LEFT; 20 | if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) env.actions[0] = RIGHT; 21 | } else { 22 | env.actions[0] = NOOP; 23 | //forward_linearlstm(net, env.observations, env.actions); 24 | } 25 | c_step(&env); 26 | c_render(client, &env); 27 | } 28 | //free_linearlstm(net); 29 | //free(weights); 30 | free_allocated(&env); 31 | close_client(client); 32 | } 33 | 34 | -------------------------------------------------------------------------------- /pufferlib/ocean/squared/squared.py: -------------------------------------------------------------------------------- 1 | '''A simple sample environment. 
Use this as a template for your own envs.''' 2 | 3 | import gymnasium 4 | import numpy as np 5 | 6 | import pufferlib 7 | from pufferlib.ocean.squared.cy_squared import CySquared 8 | 9 | 10 | class Squared(pufferlib.PufferEnv): 11 | def __init__(self, num_envs=1, render_mode=None, size=11, buf=None): 12 | self.single_observation_space = gymnasium.spaces.Box(low=0, high=1, 13 | shape=(size*size,), dtype=np.uint8) 14 | self.single_action_space = gymnasium.spaces.Discrete(5) 15 | self.render_mode = render_mode 16 | self.num_agents = num_envs 17 | 18 | super().__init__(buf) 19 | self.c_envs = CySquared(self.observations, self.actions, 20 | self.rewards, self.terminals, num_envs, size) 21 | 22 | def reset(self, seed=None): 23 | self.c_envs.reset() 24 | return self.observations, [] 25 | 26 | def step(self, actions): 27 | self.actions[:] = actions 28 | self.c_envs.step() 29 | 30 | episode_returns = self.rewards[self.terminals] 31 | 32 | info = [] 33 | if len(episode_returns) > 0: 34 | info = [{ 35 | 'reward': np.mean(episode_returns), 36 | }] 37 | 38 | return (self.observations, self.rewards, 39 | self.terminals, self.truncations, info) 40 | 41 | def render(self): 42 | self.c_envs.render() 43 | 44 | def close(self): 45 | self.c_envs.close() 46 | -------------------------------------------------------------------------------- /pufferlib/ocean/tactical/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/ocean/tactical/__init__.py -------------------------------------------------------------------------------- /pufferlib/ocean/tactical/build_local.sh: -------------------------------------------------------------------------------- 1 | clang -Wall -Wuninitialized -Wmisleading-indentation -fsanitize=address -ferror-limit=3 -g -o tacticalgame tactical.h -I/opt/homebrew/opt/raylib/include -L/opt/homebrew/opt/raylib/lib/ -lraylib -lm -lpthread -ldl -DPLATFORM_DESKTOP 2 | -------------------------------------------------------------------------------- /pufferlib/ocean/tactical/c_tactical.pyx: -------------------------------------------------------------------------------- 1 | cimport numpy as cnp 2 | 3 | cdef extern from "tactical.h": 4 | ctypedef struct Tactical: 5 | int num_agents 6 | unsigned char* observations 7 | int* actions 8 | float* rewards 9 | 10 | Tactical* init_tactical() 11 | void reset(Tactical* env) 12 | void step(Tactical* env) 13 | 14 | void free_tactical(Tactical* env) 15 | 16 | ctypedef struct GameRenderer 17 | 18 | GameRenderer* init_game_renderer(Tactical* env) 19 | int render_game(GameRenderer* renderer, Tactical* env) 20 | void close_game_renderer(GameRenderer* renderer) 21 | 22 | 23 | cdef class CTactical: 24 | cdef Tactical* env 25 | cdef GameRenderer* renderer 26 | 27 | def __init__(self, 28 | cnp.ndarray observations, 29 | cnp.ndarray rewards, 30 | cnp.ndarray actions,): 31 | env = init_tactical() 32 | self.env = env 33 | 34 | env.observations = observations.data 35 | env.actions = actions.data 36 | env.rewards = rewards.data 37 | 38 | self.renderer = NULL 39 | 40 | def reset(self): 41 | reset(self.env) 42 | 43 | def step(self): 44 | step(self.env) 45 | 46 | def render(self): 47 | if self.renderer == NULL: 48 | import os 49 | cwd = os.getcwd() 50 | os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) 51 | self.renderer = init_game_renderer(self.env) 52 | os.chdir(cwd) 53 | 54 | return 
render_game(self.renderer, self.env) 55 | 56 | def close(self): 57 | if self.renderer != NULL: 58 | close_game_renderer(self.renderer) 59 | self.renderer = NULL 60 | 61 | free_tactical(self.env) 62 | -------------------------------------------------------------------------------- /pufferlib/ocean/tactical/tactical.c: -------------------------------------------------------------------------------- 1 | #include "tactical.h" 2 | 3 | 4 | int main() { 5 | Tactical* env = init_tactical(); 6 | // allocate(&env); 7 | 8 | GameRenderer* client = init_game_renderer(env); 9 | 10 | reset(env); 11 | while (!WindowShouldClose()) { 12 | if (IsKeyPressed(KEY_Q) || IsKeyPressed(KEY_BACKSPACE)) break; 13 | step(env); 14 | render_game(client, env); 15 | } 16 | // free_linearlstm(net); 17 | // free(weights); 18 | // free_allocated(&env); 19 | // close_client(client); 20 | } 21 | 22 | -------------------------------------------------------------------------------- /pufferlib/ocean/tcg/build_local.sh: -------------------------------------------------------------------------------- 1 | clang -Wall -Wuninitialized -Wmisleading-indentation -fsanitize=address,undefined,bounds,pointer-overflow,leak -ferror-limit=3 -g -o tcg tcg.c -I./raylib-5.0_linux_amd64/include/ -L./raylib-5.0_linux_amd64/lib/ -lraylib -lGL -lm -lpthread -ldl -lrt -lX11 -DPLATFORM_DESKTOP 2 | 3 | 4 | -------------------------------------------------------------------------------- /pufferlib/ocean/tcg/build_web.sh: -------------------------------------------------------------------------------- 1 | emcc -o build/game.html tcg.c -Os -Wall ./raylib/src/libraylib.a -I./raylib/src -L. -L./raylib/src/libraylib.a -sASSERTIONS=2 -gsource-map -s USE_GLFW=3 -sUSE_WEBGL2=1 -s ASYNCIFY -sFILESYSTEM -s FORCE_FILESYSTEM=1 --shell-file ./raylib/src/minshell.html -DPLATFORM_WEB -DGRAPHICS_API_OPENGL_ES3 2 | -------------------------------------------------------------------------------- /pufferlib/ocean/tcg/tcg.c: -------------------------------------------------------------------------------- 1 | #include "tcg.h" 2 | 3 | int main() { 4 | TCG env = {0}; // MUST ZERO 5 | allocate_tcg(&env); 6 | reset(&env); 7 | 8 | init_client(&env); 9 | 10 | int atn = -1; 11 | while (!WindowShouldClose()) { 12 | if (atn != -1) { 13 | step(&env, atn); 14 | atn = -1; 15 | } 16 | 17 | if (IsKeyPressed(KEY_ONE)) atn = 0; 18 | if (IsKeyPressed(KEY_TWO)) atn = 1; 19 | if (IsKeyPressed(KEY_THREE)) atn = 2; 20 | if (IsKeyPressed(KEY_FOUR)) atn = 3; 21 | if (IsKeyPressed(KEY_FIVE)) atn = 4; 22 | if (IsKeyPressed(KEY_SIX)) atn = 5; 23 | if (IsKeyPressed(KEY_SEVEN)) atn = 6; 24 | if (IsKeyPressed(KEY_EIGHT)) atn = 7; 25 | if (IsKeyPressed(KEY_NINE)) atn = 8; 26 | if (IsKeyPressed(KEY_ZERO)) atn = 9; 27 | if (IsKeyPressed(KEY_ENTER)) atn = 10; 28 | 29 | if (env.turn == 1) { 30 | atn = rand() % 11; 31 | } 32 | 33 | render(&env); 34 | } 35 | free_tcg(&env); 36 | return 0; 37 | } 38 | -------------------------------------------------------------------------------- /pufferlib/ocean/trash_pickup/README.md: -------------------------------------------------------------------------------- 1 | # TrashPickup Environment 2 | 3 | A lightweight multi-agent reinforcement learning (RL) environment designed for coordination and cooperation research. Agents pick up trash and deposit it in bins for rewards. 4 | 5 | ## Key Features 6 | - **Multi-Agent Coordination:** Encourages teamwork, efficient planning, and resource allocation. 
7 | - **Configurable Setup:** Adjustable grid size, number of agents, trash, bins, and episode length. 8 | - **Discrete Action Space:** Actions include `UP`, `DOWN`, `LEFT`, `RIGHT`. 9 | - **Fast and Lightweight:** Optimized for rapid training and testing. 10 | 11 | ## Example Research Goals 12 | - Investigate emergent behaviors like task allocation and coordination. 13 | - Study efficient resource collection and bin-pushing strategies. 14 | 15 | ## Ideal For 16 | - RL researchers exploring multi-agent cooperation. 17 | - Students learning about multi-agent systems. 18 | - Developers testing scalable RL algorithms. 19 | -------------------------------------------------------------------------------- /pufferlib/ocean/tripletriad/tripletriad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gymnasium 3 | 4 | import pufferlib 5 | from pufferlib.ocean.tripletriad.cy_tripletriad import CyTripleTriad 6 | 7 | class TripleTriad(pufferlib.PufferEnv): 8 | def __init__(self, num_envs=1, render_mode=None, report_interval=1, 9 | width=990, height=690, piece_width=192, piece_height=224, buf=None): 10 | self.single_observation_space = gymnasium.spaces.Box(low=0, high=1, 11 | shape=(114,), dtype=np.float32) 12 | self.single_action_space = gymnasium.spaces.Discrete(15) 13 | self.report_interval = report_interval 14 | self.render_mode = render_mode 15 | self.num_agents = num_envs 16 | 17 | super().__init__(buf=buf) 18 | self.c_envs = CyTripleTriad(self.observations, self.actions, 19 | self.rewards, self.terminals, num_envs, width, height, 20 | piece_width, piece_height) 21 | 22 | def reset(self, seed=None): 23 | self.c_envs.reset() 24 | self.tick = 0 25 | return self.observations, [] 26 | 27 | def step(self, actions): 28 | self.actions[:] = actions 29 | self.c_envs.step() 30 | self.tick += 1 31 | 32 | info = [] 33 | if self.tick % self.report_interval == 0: 34 | log = self.c_envs.log() 35 | if log['episode_length'] > 0: 36 | info.append(log) 37 | 38 | return (self.observations, self.rewards, 39 | self.terminals, self.truncations, info) 40 | 41 | def render(self): 42 | self.c_envs.render() 43 | 44 | def close(self): 45 | self.c_envs.close() 46 | 47 | def test_performance(timeout=10, atn_cache=1024): 48 | env = TripleTriad(num_envs=1000) 49 | env.reset() 50 | tick = 0 51 | 52 | actions = np.random.randint(0, 2, (atn_cache, env.num_envs)) 53 | 54 | import time 55 | start = time.time() 56 | while time.time() - start < timeout: 57 | atn = actions[tick % atn_cache] 58 | env.step(atn) 59 | tick += 1 60 | 61 | print(f'SPS: %f', env.num_envs * tick / (time.time() - start)) 62 | 63 | if __name__ == '__main__': 64 | test_performance() 65 | -------------------------------------------------------------------------------- /pufferlib/policy_store.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import os 3 | import torch 4 | 5 | 6 | def get_policy_names(path: str) -> list: 7 | # Assumeing that all pt files other than trainer_state.pt in the path are policy files 8 | names = [] 9 | for file in os.listdir(path): 10 | if file.endswith(".pt") and file != 'trainer_state.pt': 11 | names.append(file[:-3]) 12 | return sorted(names) 13 | 14 | class PolicyStore: 15 | def __init__(self, path: str): 16 | self.path = path 17 | 18 | def policy_names(self) -> list: 19 | return get_policy_names(self.path) 20 | 21 | def get_policy(self, name: str) -> torch.nn.Module: 22 | path = 
/pufferlib/ocean/tripletriad/tripletriad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gymnasium 3 | 4 | import pufferlib 5 | from pufferlib.ocean.tripletriad.cy_tripletriad import CyTripleTriad 6 | 7 | class TripleTriad(pufferlib.PufferEnv): 8 | def __init__(self, num_envs=1, render_mode=None, report_interval=1, 9 | width=990, height=690, piece_width=192, piece_height=224, buf=None): 10 | self.single_observation_space = gymnasium.spaces.Box(low=0, high=1, 11 | shape=(114,), dtype=np.float32) 12 | self.single_action_space = gymnasium.spaces.Discrete(15) 13 | self.report_interval = report_interval 14 | self.render_mode = render_mode 15 | self.num_agents = num_envs 16 | 17 | super().__init__(buf=buf) 18 | self.c_envs = CyTripleTriad(self.observations, self.actions, 19 | self.rewards, self.terminals, num_envs, width, height, 20 | piece_width, piece_height) 21 | 22 | def reset(self, seed=None): 23 | self.c_envs.reset() 24 | self.tick = 0 25 | return self.observations, [] 26 | 27 | def step(self, actions): 28 | self.actions[:] = actions 29 | self.c_envs.step() 30 | self.tick += 1 31 | 32 | info = [] 33 | if self.tick % self.report_interval == 0: 34 | log = self.c_envs.log() 35 | if log['episode_length'] > 0: 36 | info.append(log) 37 | 38 | return (self.observations, self.rewards, 39 | self.terminals, self.truncations, info) 40 | 41 | def render(self): 42 | self.c_envs.render() 43 | 44 | def close(self): 45 | self.c_envs.close() 46 | 47 | def test_performance(timeout=10, atn_cache=1024): 48 | env = TripleTriad(num_envs=1000) 49 | env.reset() 50 | tick = 0 51 | 52 | actions = np.random.randint(0, 2, (atn_cache, env.num_envs)) 53 | 54 | import time 55 | start = time.time() 56 | while time.time() - start < timeout: 57 | atn = actions[tick % atn_cache] 58 | env.step(atn) 59 | tick += 1 60 | 61 | print(f'SPS: {env.num_envs * tick / (time.time() - start):.2f}') 62 | 63 | if __name__ == '__main__': 64 | test_performance() 65 | -------------------------------------------------------------------------------- /pufferlib/policy_store.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import os 3 | import torch 4 | 5 | 6 | def get_policy_names(path: str) -> list: 7 | # Assuming that all .pt files other than trainer_state.pt in the path are policy files 8 | names = [] 9 | for file in os.listdir(path): 10 | if file.endswith(".pt") and file != 'trainer_state.pt': 11 | names.append(file[:-3]) 12 | return sorted(names) 13 | 14 | class PolicyStore: 15 | def __init__(self, path: str): 16 | self.path = path 17 | 18 | def policy_names(self) -> list: 19 | return get_policy_names(self.path) 20 | 21 | def get_policy(self, name: str) -> torch.nn.Module: 22 | path = os.path.join(self.path, name + '.pt') 23 | try: 24 | return torch.load(path) 25 | except Exception: # e.g. weights saved on GPU but loaded on a CPU-only machine 26 | return torch.load(path, map_location=torch.device('cpu')) 27 | -------------------------------------------------------------------------------- /pufferlib/resources/breakout_weights.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/breakout_weights.bin -------------------------------------------------------------------------------- /pufferlib/resources/connect4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/connect4.pt -------------------------------------------------------------------------------- /pufferlib/resources/connect4_weights.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/connect4_weights.bin -------------------------------------------------------------------------------- /pufferlib/resources/enduro/enduro_spritesheet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/enduro/enduro_spritesheet.png -------------------------------------------------------------------------------- /pufferlib/resources/enduro/enduro_weights.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/enduro/enduro_weights.bin -------------------------------------------------------------------------------- /pufferlib/resources/go_weights.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/go_weights.bin -------------------------------------------------------------------------------- /pufferlib/resources/moba/bloom_shader_100.fs: -------------------------------------------------------------------------------- 1 | #version 330 2 | 3 | precision mediump float; 4 | 5 | // Input vertex attributes (from vertex shader) 6 | varying vec2 fragTexCoord; 7 | varying vec4 fragColor; 8 | 9 | // Input uniform values 10 | uniform sampler2D texture0; 11 | uniform vec4 colDiffuse; 12 | 13 | // NOTE: Add here your custom variables 14 | 15 | const vec2 size = vec2(800, 450); // Framebuffer size 16 | const float samples = 5.0; // Pixels per axis; higher = bigger glow, worse performance 17 | const float quality = 2.5; // Defines size factor: Lower = smaller glow, better quality 18 | 19 | void main() 20 | { 21 | vec4 sum = vec4(0); 22 | vec2 sizeFactor = vec2(1)/size*quality; 23 | 24 | // Texel color fetching from texture sampler 25 | vec4 source = texture2D(texture0, fragTexCoord); 26 | 27 | const int range = 2; // should be = (samples - 1)/2; 28 | 29 | for (int x = -range; x <= range; x++) 30 | { 31 | for (int y = -range; y <= range; y++) 32 | { 33 | sum += texture2D(texture0, fragTexCoord + vec2(x, y)*sizeFactor); 34 | } 35 | } 36 | 37 | // Calculate final fragment color 38 | gl_FragColor = ((sum/(samples*samples)) + source)*colDiffuse; 39 | } 40 |
-------------------------------------------------------------------------------- /pufferlib/resources/moba/bloom_shader_330.fs: -------------------------------------------------------------------------------- 1 | #version 330 2 | 3 | //precision mediump float; 4 | 5 | // Input vertex attributes (from vertex shader) 6 | varying vec2 fragTexCoord; 7 | varying vec4 fragColor; 8 | 9 | // Input uniform values 10 | uniform sampler2D texture0; 11 | uniform vec4 colDiffuse; 12 | 13 | // NOTE: Add here your custom variables 14 | 15 | const vec2 size = vec2(800, 450); // Framebuffer size 16 | const float samples = 5.0; // Pixels per axis; higher = bigger glow, worse performance 17 | const float quality = 2.5; // Defines size factor: Lower = smaller glow, better quality 18 | 19 | void main() 20 | { 21 | vec4 sum = vec4(0); 22 | vec2 sizeFactor = vec2(1)/size*quality; 23 | 24 | // Texel color fetching from texture sampler 25 | vec4 source = texture2D(texture0, fragTexCoord); 26 | 27 | const int range = 2; // should be = (samples - 1)/2; 28 | 29 | for (int x = -range; x <= range; x++) 30 | { 31 | for (int y = -range; y <= range; y++) 32 | { 33 | sum += texture2D(texture0, fragTexCoord + vec2(x, y)*sizeFactor); 34 | } 35 | } 36 | 37 | // Calculate final fragment color 38 | gl_FragColor = ((sum/(samples*samples)) + source)*colDiffuse; 39 | } 40 | -------------------------------------------------------------------------------- /pufferlib/resources/moba/dota_map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/moba/dota_map.png -------------------------------------------------------------------------------- /pufferlib/resources/moba/moba_assets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/moba/moba_assets.png -------------------------------------------------------------------------------- /pufferlib/resources/moba/moba_weights.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/moba/moba_weights.bin -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/ASSETS_LICENSE.md: -------------------------------------------------------------------------------- 1 | Characters and assets subject to the license of the original artists. In particular, we use Mana Seed assets by Seliel the Shaper under a valid license purchased from itch.io. You may not repurpose these assets for other projects without purchasing your own license. To mitigate abuse, we release only collated spritesheets as exported by our postprocessor. 
2 | -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_0.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_1.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_2.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_3.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_4.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_5.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_6.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_7.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_8.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/air_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/air_9.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_0.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_1.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_2.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_3.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_4.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_5.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_6.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_7.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_8.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/earth_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/earth_9.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/fire_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_0.png -------------------------------------------------------------------------------- 
/pufferlib/resources/nmmo3/fire_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_1.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/fire_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_2.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/fire_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_3.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/fire_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_4.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/fire_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_5.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/fire_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_6.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/fire_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_7.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/fire_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_8.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/fire_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/fire_9.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/inventory_64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/inventory_64.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/inventory_64_press.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/inventory_64_press.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/inventory_64_selected.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/inventory_64_selected.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/items_condensed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/items_condensed.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/map_shader_100.fs: -------------------------------------------------------------------------------- 1 | precision mediump float; 2 | 3 | // Input uniforms (unchanged from original) 4 | uniform sampler2D terrain; 5 | uniform sampler2D texture_tiles; 6 | uniform vec4 colDiffuse; 7 | uniform vec3 resolution; 8 | uniform vec4 mouse; 9 | uniform float time; 10 | uniform float camera_x; 11 | uniform float camera_y; 12 | uniform float map_width; 13 | uniform float map_height; 14 | 15 | // Constants 16 | const float TILE_SIZE = 64.0; 17 | const float TILES_PER_ROW = 64.0; 18 | 19 | void main() 20 | { 21 | float ts = TILE_SIZE * resolution.z; 22 | // Get the screen pixel coordinates 23 | vec2 pixelPos = gl_FragCoord.xy; 24 | 25 | float x_offset = camera_x/64.0 + pixelPos.x/ts - resolution.x/ts/2.0; 26 | float y_offset = camera_y/64.0 - pixelPos.y/ts + resolution.y/ts/2.0; 27 | float x_floor = floor(x_offset); 28 | float y_floor = floor(y_offset); 29 | float x_frac = x_offset - x_floor; 30 | float y_frac = y_offset - y_floor; 31 | 32 | // Environment size calculation 33 | vec2 uv = vec2( 34 | x_floor/map_width, 35 | y_floor/map_height 36 | ); 37 | 38 | vec2 tile_rg = texture2D(terrain, uv).rg; 39 | float tile_high_byte = floor(tile_rg.r * 255.0); 40 | float tile_low_byte = floor(tile_rg.g * 255.0); 41 | float tile = tile_high_byte * 64.0 + tile_low_byte; 42 | 43 | // Handle animated tiles 44 | if (tile >= 240.0 && tile < (240.0 + 4.0*4.0*4.0*4.0)) { 45 | tile += floor(3.9 * time); 46 | } 47 | 48 | tile_high_byte = floor(tile/64.0); 49 | tile_low_byte = floor(mod(tile, 64.0)); 50 | 51 | vec2 tile_uv = vec2( 52 | tile_low_byte/64.0 + x_frac/64.0, 53 | tile_high_byte/64.0 + y_frac/64.0 54 | ); 55 | 56 | gl_FragColor = texture2D(texture_tiles, tile_uv); 57 | } 58 | -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/map_shader_330.fs: -------------------------------------------------------------------------------- 1 | #version 330 2 | 3 | // Input vertex attributes (from vertex shader) 4 | in vec2 fragTexCoord; 5 | in vec4 fragColor; 6 | 7 | // Input uniform values 8 | uniform sampler2D terrain; 9 | uniform sampler2D texture_tiles; // Tile sprite sheet texture 10 | uniform vec4 colDiffuse; 11 | uniform vec3 resolution; 12 | uniform vec4 mouse; 13 | uniform float time; 14 | uniform float camera_x; 15 | uniform float camera_y; 16 | uniform float map_width; 17 | uniform float map_height; 18 | 19 | // Output fragment color 20 | out vec4 outputColor; 21 | 22 | float TILE_SIZE = 64.0; 23 | 24 | // Number of 
tiles per row in the sprite sheet 25 | const int TILES_PER_ROW = 64; // Adjust this based on your sprite sheet layout 26 | 27 | void main() 28 | { 29 | float ts = TILE_SIZE * resolution.z; 30 | 31 | // Get the screen pixel coordinates 32 | vec2 pixelPos = gl_FragCoord.xy; 33 | 34 | float x_offset = camera_x/64.0 + pixelPos.x/ts - resolution.x/ts/2.0; 35 | float y_offset = camera_y/64.0 - pixelPos.y/ts + resolution.y/ts/2.0; 36 | 37 | float x_floor = floor(x_offset); 38 | float y_floor = floor(y_offset); 39 | 40 | float x_frac = x_offset - x_floor; 41 | float y_frac = y_offset - y_floor; 42 | 43 | // TODO: This is the env size 44 | vec2 uv = vec2( 45 | x_floor/map_width, 46 | y_floor/map_height 47 | ); 48 | vec2 tile_rg = texture(terrain, uv).rg; 49 | 50 | int tile_high_byte = int(tile_rg.r*255.0); 51 | int tile_low_byte = int(tile_rg.g*255.0); 52 | 53 | int tile = tile_high_byte*64 + tile_low_byte; 54 | if (tile >= 240 && tile < 240+4*4*4*4) { 55 | tile += int(3.9*time); 56 | } 57 | 58 | tile_high_byte = int(tile/64.0); 59 | tile_low_byte = int(tile%64); 60 | 61 | vec2 tile_uv = vec2( 62 | tile_low_byte/64.0 + x_frac/64.0, 63 | tile_high_byte/64.0 + y_frac/64.0 64 | ); 65 | 66 | outputColor = texture(texture_tiles, tile_uv); 67 | } 68 | 69 | -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/merged_sheet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/merged_sheet.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_0.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_1.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_2.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_3.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_4.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_5.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_5.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_6.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_7.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_8.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/neutral_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/neutral_9.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/nmmo3_help.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/nmmo3_help.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/nmmo_1500.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/nmmo_1500.bin -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/nmmo_2025.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/nmmo_2025.bin -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_0.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_1.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_2.png 
-------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_3.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_4.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_5.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_6.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_7.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_8.png -------------------------------------------------------------------------------- /pufferlib/resources/nmmo3/water_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/nmmo3/water_9.png -------------------------------------------------------------------------------- /pufferlib/resources/pong_weights.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/pong_weights.bin -------------------------------------------------------------------------------- /pufferlib/resources/puffers_128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/puffers_128.png -------------------------------------------------------------------------------- /pufferlib/resources/robocode/robocode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/robocode/robocode.png -------------------------------------------------------------------------------- /pufferlib/resources/rware_weights.bin: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/rware_weights.bin -------------------------------------------------------------------------------- /pufferlib/resources/snake_weights.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/snake_weights.bin -------------------------------------------------------------------------------- /pufferlib/resources/tripletriad_weights.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/pufferlib/resources/tripletriad_weights.bin -------------------------------------------------------------------------------- /pufferlib/spaces.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import gymnasium 4 | 5 | Box = (gym.spaces.Box, gymnasium.spaces.Box) 6 | Dict = (gym.spaces.Dict, gymnasium.spaces.Dict) 7 | Discrete = (gym.spaces.Discrete, gymnasium.spaces.Discrete) 8 | MultiBinary = (gym.spaces.MultiBinary, gymnasium.spaces.MultiBinary) 9 | MultiDiscrete = (gym.spaces.MultiDiscrete, gymnasium.spaces.MultiDiscrete) 10 | Tuple = (gym.spaces.Tuple, gymnasium.spaces.Tuple) 11 | 12 | def joint_space(space, n): 13 | if isinstance(space, Discrete): 14 | return gymnasium.spaces.MultiDiscrete([space.n] * n) 15 | elif isinstance(space, MultiDiscrete): 16 | return gymnasium.spaces.Box(low=0, 17 | high=np.repeat(space.nvec[None] - 1, n, axis=0), 18 | shape=(n, len(space)), dtype=space.dtype) 19 | elif isinstance(space, Box): 20 | return gymnasium.spaces.Box( 21 | low=np.repeat(space.low[None], n, axis=0), 22 | high=np.repeat(space.high[None], n, axis=0), 23 | shape=(n, *space.shape), dtype=space.dtype) 24 | else: 25 | raise ValueError(f'Unsupported space: {space}') 26 | -------------------------------------------------------------------------------- /pufferlib/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.0.6' 2 | -------------------------------------------------------------------------------- /pufferlib/wrappers.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | class GymToGymnasium: 4 | def __init__(self, env): 5 | self.env = env 6 | self.observation_space = env.observation_space 7 | self.action_space = env.action_space 8 | self.render = env.render 9 | self.metadata = env.metadata 10 | 11 | def reset(self, seed=None, options=None): 12 | if seed is not None: 13 | ob = self.env.reset(seed=seed) 14 | else: 15 | ob = self.env.reset() 16 | return ob, {} 17 | 18 | def step(self, action): 19 | observation, reward, done, info = self.env.step(action) 20 | return observation, reward, done, False, info 21 | 22 | def close(self): 23 | self.env.close() 24 | 25 | class PettingZooTruncatedWrapper: 26 | def __init__(self, env): 27 | self.env = env 28 | self.observation_space = env.observation_space 29 | self.action_space = env.action_space 30 | self.render = env.render 31 | 32 | @property 33 | def render_mode(self): 34 | return self.env.render_mode 35 | 36 | @property 37 | def possible_agents(self): 38 | return self.env.possible_agents 39 | 40 | @property 41 | def agents(self): 42 | return self.env.agents 43 | 44 | def reset(self, 
seed=None): 45 | if seed is not None: 46 | ob, info = self.env.reset(seed=seed) 47 | else: 48 | ob, info = self.env.reset() 49 | info = {k: {} for k in ob} 50 | return ob, info 51 | 52 | def step(self, actions): 53 | observations, rewards, terminals, truncations, infos = self.env.step(actions) 54 | return observations, rewards, terminals, truncations, infos 55 | 56 | def close(self): 57 | self.env.close() 58 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "Cython", "numpy"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /resources: -------------------------------------------------------------------------------- 1 | pufferlib/resources/ -------------------------------------------------------------------------------- /sb3_demo.py: -------------------------------------------------------------------------------- 1 | # Minimal SB3 demo using PufferLib's environment wrappers 2 | 3 | import argparse 4 | 5 | from stable_baselines3 import PPO 6 | from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv 7 | from stable_baselines3.common.env_util import make_vec_env 8 | 9 | from pufferlib.environments import atari 10 | 11 | ''' 12 | elif args.backend == 'sb3': 13 | from stable_baselines3 import PPO 14 | from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv 15 | from stable_baselines3.common.env_util import make_vec_env 16 | from sb3_contrib import RecurrentPPO 17 | 18 | envs = make_vec_env(lambda: make_env(**args.env), 19 | n_envs=args.train.num_envs, seed=args.train.seed, vec_env_cls=DummyVecEnv) 20 | 21 | model = RecurrentPPO("CnnLstmPolicy", envs, verbose=1, 22 | n_steps=args.train.batch_rows*args.train.bptt_horizon, 23 | batch_size=args.train.batch_size, n_epochs=args.train.update_epochs, 24 | gamma=args.train.gamma 25 | ) 26 | 27 | model.learn(total_timesteps=args.train.total_timesteps) 28 | ''' 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') 32 | args = parser.parse_args() 33 | 34 | env_creator = atari.env_creator(args.env) 35 | envs = make_vec_env(lambda: env_creator(), 36 | n_envs=4, seed=0, vec_env_cls=DummyVecEnv) 37 | 38 | model = PPO("CnnPolicy", envs, verbose=1) 39 | model.learn(total_timesteps=2000) 40 | 41 | # Demonstrate loading 42 | model.save(f'ppo_{args.env}') 43 | model = PPO.load(f'ppo_{args.env}') 44 | 45 | # Watch the agent play 46 | env = atari.make_env(args.env, render_mode='human') 47 | terminal = True 48 | for _ in range(1000): 49 | if terminal or truncated: 50 | ob, _ = env.reset() 51 | 52 | ob = ob.reshape(1, *ob.shape) 53 | action, _states = model.predict(ob) 54 | ob, reward, terminal, truncated, info = env.step(action[0]) 55 | env.render() 56 | 57 | -------------------------------------------------------------------------------- /scripts/build_ocean.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Usage: ./build_env.sh pong [local|fast|web] 4 | 5 | ENV=$1 6 | MODE=${2:-local} 7 | PLATFORM="$(uname -s)" 8 | SRC_DIR="pufferlib/ocean/$ENV" 9 | WEB_OUTPUT_DIR="build_web/$ENV" 10 | 11 | # Create build output directory 12 | mkdir -p "$WEB_OUTPUT_DIR" 13 | 14 | if [ "$MODE" = "web" ]; then 15 | echo "Building $ENV for web deployment..." 
16 | emcc \ 17 | -o "$WEB_OUTPUT_DIR/game.html" \ 18 | "$SRC_DIR/$ENV.c" \ 19 | -O3 \ 20 | -Wall \ 21 | ./raylib_wasm/lib/libraylib.a \ 22 | -I./raylib_wasm/include \ 23 | -I./pufferlib\ 24 | -L. \ 25 | -L./raylib_wasm/lib \ 26 | -sASSERTIONS=2 \ 27 | -gsource-map \ 28 | -s USE_GLFW=3 \ 29 | -s USE_WEBGL2=1 \ 30 | -s ASYNCIFY \ 31 | -sFILESYSTEM \ 32 | -s FORCE_FILESYSTEM=1 \ 33 | --shell-file ./scripts/minshell.html \ 34 | -sINITIAL_MEMORY=512MB \ 35 | -sSTACK_SIZE=512KB \ 36 | -DPLATFORM_WEB \ 37 | -DGRAPHICS_API_OPENGL_ES3 \ 38 | --preload-file pufferlib/resources@resources/ 39 | echo "Web build completed: $WEB_OUTPUT_DIR/game.html" 40 | exit 0 41 | fi 42 | 43 | FLAGS=( 44 | -Wall 45 | -I./raylib-5.0_linux_amd64/include 46 | -I./pufferlib 47 | "$SRC_DIR/$ENV.c" -o "$ENV" 48 | ./raylib-5.0_linux_amd64/lib/libraylib.a 49 | -lm 50 | -lpthread 51 | -DPLATFORM_DESKTOP 52 | ) 53 | 54 | 55 | if [ "$PLATFORM" = "Darwin" ]; then 56 | FLAGS+=( 57 | -framework Cocoa 58 | -framework IOKit 59 | -framework CoreVideo 60 | ) 61 | fi 62 | 63 | echo ${FLAGS[@]} 64 | 65 | if [ "$MODE" = "local" ]; then 66 | echo "Building $ENV for local testing..." 67 | if [ "$PLATFORM" = "Linux" ]; then 68 | # These important debug flags don't work on macos 69 | FLAGS+=( 70 | -fsanitize=address,undefined,bounds,pointer-overflow,leak 71 | ) 72 | fi 73 | clang -g -O0 ${FLAGS[@]} 74 | elif [ "$MODE" = "fast" ]; then 75 | echo "Building optimized $ENV for local testing..." 76 | clang -pg -O2 ${FLAGS[@]} 77 | echo "Built to: $ENV" 78 | else 79 | echo "Invalid mode specified: local|fast|web" 80 | exit 1 81 | fi 82 | -------------------------------------------------------------------------------- /scripts/sweep_atari.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | environments=( 4 | "pong" 5 | "breakout" 6 | "beam_rider" 7 | "enduro" 8 | "qbert" 9 | "space_invaders" 10 | "seaquest" 11 | ) 12 | 13 | for env in "${environments[@]}"; do 14 | echo "Training: $env" 15 | python demo.py --mode sweep-carbs --vec multiprocessing --env "$env" 16 | done 17 | -------------------------------------------------------------------------------- /scripts/train_atari.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | environments=( 4 | "pong" 5 | "breakout" 6 | "beam_rider" 7 | "enduro" 8 | "qbert" 9 | "space_invaders" 10 | "seaquest" 11 | ) 12 | 13 | for env in "${environments[@]}"; do 14 | echo "Training: $env" 15 | python demo.py --mode train --vec multiprocessing --track --env "$env" 16 | done 17 | -------------------------------------------------------------------------------- /scripts/train_ocean.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | environments=( 4 | "puffer_breakout" 5 | "puffer_connect4" 6 | "puffer_pong" 7 | "puffer_snake" 8 | "puffer_tripletriad" 9 | "puffer_rware" 10 | "puffer_go" 11 | "puffer_tactics" 12 | "puffer_moba" 13 | ) 14 | 15 | for env in "${environments[@]}"; do 16 | echo "Training: $env" 17 | python demo.py --mode train --vec multiprocessing --track --env "$env" 18 | done 19 | -------------------------------------------------------------------------------- /scripts/train_procgen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | environments=( 4 | "bigfish" 5 | "bossfight" 6 | "caveflyer" 7 | "chaser" 8 | "climber" 9 | "coinrun" 10 | "dodgeball" 11 | "fruitbot" 12 | "heist" 13 | 
"jumper" 14 | "leaper" 15 | "maze" 16 | "miner" 17 | "ninja" 18 | "plunder" 19 | "starpilot" 20 | ) 21 | 22 | for env in "${environments[@]}"; do 23 | echo "Training on environment: $env" 24 | python demo.py --mode train --vec multiprocessing --track --env "$env" 25 | done 26 | -------------------------------------------------------------------------------- /scripts/train_sanity.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | environments=( 4 | "puffer_squared" 5 | "puffer_password" 6 | "puffer_stochastic" 7 | "puffer_memory" 8 | "puffer_multiagent" 9 | "puffer_spaces" 10 | "puffer_bandit" 11 | ) 12 | 13 | for env in "${environments[@]}"; do 14 | echo "Training: $env" 15 | python demo.py --mode train --vec multiprocessing --track --env "$env" 16 | done 17 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/tests/__init__.py -------------------------------------------------------------------------------- /tests/pool/envpool_results.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PufferAI/PufferLib/9e1c5c3c2e1c288b673aaee6dbb42e0d6b154e11/tests/pool/envpool_results.npy -------------------------------------------------------------------------------- /tests/pool/plot_packing.py: -------------------------------------------------------------------------------- 1 | import plotly.graph_objects as go 2 | import numpy as np 3 | 4 | # Parameters 5 | n_bars = 24 6 | mu = 0.002 7 | std = 0.002 8 | 9 | background = '#061a1a' 10 | forground = '#f1f1f1' 11 | 12 | # Sampling from the normal distribution 13 | bar_heights = mu + np.clip(np.random.normal(mu, std, n_bars), 0, np.inf) 14 | 15 | # Creating the bar chart 16 | fig = go.Figure(go.Bar( 17 | x=[i for i in range(n_bars)], 18 | y=bar_heights, 19 | marker_line_width=0, 20 | marker_color=forground, 21 | )) 22 | 23 | # Updating the layout 24 | fig.update_layout({ 25 | 'plot_bgcolor': background, 26 | 'paper_bgcolor': background, 27 | 'showlegend': False, 28 | 'xaxis': {'visible': False}, 29 | 'yaxis': {'visible': False, 'range': [0, max(bar_heights)]}, 30 | 'margin': {'l': 0, 'r': 0, 't': 0, 'b': 0}, 31 | 'height': 400, 32 | 'width': 800, 33 | 'bargap': 0.0, 34 | 'bargroupgap': 0.0, 35 | }) 36 | 37 | 38 | fig.show() 39 | fig.write_image('../docker/env_variance.png', scale=3) 40 | -------------------------------------------------------------------------------- /tests/pool/test_multiprocessing.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import numpy as np 3 | import time 4 | 5 | from pufferlib.vectorization import Multiprocessing 6 | from pufferlib.environments import pokemon_red 7 | 8 | def test_envpool(num_envs, envs_per_worker, envs_per_batch, steps=1000, env_pool=True): 9 | pool = Multiprocessing(pokemon_red.env_creator(), num_envs=num_envs, 10 | envs_per_worker=envs_per_worker, envs_per_batch=envs_per_batch, 11 | env_pool=True, 12 | ) 13 | pool.async_reset() 14 | 15 | a = np.array([pool.single_action_space.sample() for _ in range(envs_per_batch)]) 16 | start = time.time() 17 | for s in range(steps): 18 | o, r, d, t, i, mask, env_id = pool.recv() 19 | pool.send(a) 20 | end = time.time() 21 | print('Steps per second: ', 
envs_per_batch * steps / (end - start)) 22 | pool.close() 23 | 24 | 25 | if __name__ == '__main__': 26 | # 225 sps 27 | #test_envpool(num_envs=1, envs_per_worker=1, envs_per_batch=1, env_pool=False) 28 | 29 | # 600 sps 30 | #test_envpool(num_envs=6, envs_per_worker=1, envs_per_batch=6, env_pool=False) 31 | 32 | # 645 sps 33 | #test_envpool(num_envs=24, envs_per_worker=4, envs_per_batch=24, env_pool=False) 34 | 35 | # 755 sps 36 | # test_envpool(num_envs=24, envs_per_worker=4, envs_per_batch=24) 37 | 38 | # 1050 sps 39 | # test_envpool(num_envs=48, envs_per_worker=4, envs_per_batch=24) 40 | 41 | # 1300 sps 42 | test_envpool(num_envs=48, envs_per_worker=4, envs_per_batch=12) 43 | -------------------------------------------------------------------------------- /tests/test_atari_reset.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | from pufferlib.environments import atari 3 | 4 | 5 | def test_atari_reset(): 6 | '''A common wrapper bug can be detected 7 | by checking that the environment properly resets 8 | after hitting 0 lives''' 9 | env = atari.env_creator('BreakoutNoFrameskip-v4')(4) 10 | 11 | obs, info = env.reset() 12 | prev_lives = 5 13 | 14 | lives = [] 15 | for i in range(1000): 16 | action = env.action_space.sample() 17 | obs, reward, terminal, truncated, info = env.step(action) 18 | 19 | if info['lives'] != prev_lives: 20 | lives.append(i) 21 | prev_lives = info['lives'] 22 | 23 | if terminal or truncated: 24 | obs, info = env.reset() 25 | 26 | assert len(lives) > 10 27 | 28 | if __name__ == '__main__': 29 | test_atari_reset() 30 | -------------------------------------------------------------------------------- /tests/test_import_performance.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | def test_import_speed(): 4 | start = time.time() 5 | import pufferlib 6 | end = time.time() 7 | print(end - start, ' seconds to import pufferlib') 8 | assert end - start < 0.25 9 | 10 | if __name__ == '__main__': 11 | test_import_speed() -------------------------------------------------------------------------------- /tests/test_namespace.py: -------------------------------------------------------------------------------- 1 | from pufferlib import namespace, dataclass 2 | 3 | def test_namespace_as_function(): 4 | ns = namespace(x=1, y=2, z=3) 5 | 6 | assert ns.x == 1 7 | assert ns.y == 2 8 | assert ns.z == 3 9 | assert list(ns.keys()) == ['x', 'y', 'z'] 10 | assert list(ns.values()) == [1, 2, 3] 11 | assert list(ns.items()) == [('x', 1), ('y', 2), ('z', 3)] 12 | 13 | @dataclass 14 | class TestClass: 15 | a: int 16 | b = 1 17 | 18 | def test_namespace_as_decorator(): 19 | obj = TestClass(a=4, b=5) 20 | 21 | assert obj.a == 4 22 | assert obj.b == 5 23 | assert list(obj.keys()) == ['a', 'b'] 24 | assert list(obj.values()) == [4, 5] 25 | assert list(obj.items()) == [('a', 4), ('b', 5)] 26 | 27 | if __name__ == '__main__': 28 | test_namespace_as_function() 29 | test_namespace_as_decorator() 30 | -------------------------------------------------------------------------------- /tests/test_nmmo3_compile.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | import time 3 | import torch 4 | import numpy as np 5 | 6 | 7 | @torch.compile(fullgraph=True, mode='reduce-overhead') 8 | def fast_decode_map(codes, obs, factors, add, div): 9 | codes = codes.view(codes.shape[0], 1, -1) 10 | dec = add + 
(codes//div) % factors 11 | obs.scatter_(1, dec, 1) 12 | return obs 13 | 14 | #@torch.compile(fullgraph=True, mode='reduce-overhead') 15 | def decode_map(codes): 16 | codes = codes.unsqueeze(1).long() 17 | factors = [4, 4, 16, 5, 3, 5, 5, 6, 7, 4] 18 | n_channels = sum(factors) 19 | obs = torch.zeros(codes.shape[0], n_channels, 11, 15, device='cuda') 20 | 21 | add, div = 0, 1 22 | # TODO: check item/tier order 23 | for mod in factors: 24 | obs.scatter_(1, add+(codes//div)%mod, 1) 25 | add += mod 26 | div *= mod 27 | 28 | return obs 29 | 30 | 31 | def test_perf(n=100, agents=1024): 32 | factors = np.array([4, 4, 16, 5, 3, 5, 5, 6, 7, 4]) 33 | n_channels = sum(factors) 34 | add = np.array([0, *np.cumsum(factors).tolist()[:-1]])[None, :, None] 35 | div = np.array([1, *np.cumprod(factors).tolist()[:-1]])[None, :, None] 36 | 37 | factors = torch.tensor(factors)[None, :, None].cuda() 38 | add = torch.tensor(add).cuda() 39 | div = torch.tensor(div).cuda() 40 | 41 | codes = torch.randint(0, 4*4*16*5*3*5*5*6*7*4, (agents, 11, 15)).cuda() 42 | obs = torch.zeros(agents, n_channels, 11*15, device='cuda') 43 | obs_view = obs.view(agents, n_channels, 11, 15) 44 | 45 | # Warm up 46 | decode_map(codes) 47 | fast_decode_map(codes, obs, factors, add, div) 48 | torch.cuda.synchronize() 49 | 50 | start = time.time() 51 | for _ in range(n): 52 | fast_decode_map(codes, obs, factors, add, div) 53 | #obs2 = decode_map(codes) 54 | #print(torch.all(obs_view == obs2)) 55 | 56 | 57 | torch.cuda.synchronize() 58 | end = time.time() 59 | sps = n / (end - start) 60 | print(f'SPS: {sps:.2f}') 61 | 62 | if __name__ == '__main__': 63 | test_perf() 64 | 65 | -------------------------------------------------------------------------------- /tests/test_pokemon_red.py: -------------------------------------------------------------------------------- 1 | from pufferlib.environments.pokemon_red import env_creator 2 | 3 | env = env_creator()() 4 | ob, info = env.reset() 5 | for i in range(100): 6 | ob, reward, terminal, truncated, info = env.step(env.action_space.sample()) 7 | print(f'Step: {i}, Info: {info}') 8 | 9 | env.close() 10 | -------------------------------------------------------------------------------- /tests/test_record_array.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | 4 | # Create a custom Gym space using Dict, Tuple, and Box 5 | space = gym.spaces.Dict({ 6 | "position": gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32), 7 | "velocity": gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32), 8 | "description": gym.spaces.Tuple(( 9 | #gym.spaces.Discrete(10), 10 | gym.spaces.Box(low=0, high=100, shape=(), dtype=np.int32), 11 | gym.spaces.Box(low=0, high=100, shape=(), dtype=np.int32) 12 | )) 13 | }) 14 | 15 | space = gym.spaces.Dict({ 16 | "position": gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32), 17 | }) 18 | 19 | 20 | # Define a function to create a dtype from the Gym space 21 | def create_dtype_from_space(space): 22 | if isinstance(space, gym.spaces.Dict): 23 | dtype_fields = [(name, create_dtype_from_space(subspace)) for name, subspace in space.spaces.items()] 24 | return np.dtype(dtype_fields) 25 | elif isinstance(space, gym.spaces.Tuple): 26 | dtype_fields = [('field' + str(i), create_dtype_from_space(subspace)) for i, subspace in enumerate(space.spaces)] 27 | return np.dtype(dtype_fields) 28 | elif isinstance(space, gym.spaces.Box): 29 | return (space.dtype, space.shape) 30 
| elif isinstance(space, gym.spaces.Discrete): 31 | return np.int64 # Assuming np.int64 for Discrete spaces 32 | 33 | # Compute the dtype from the space 34 | space_dtype = create_dtype_from_space(space) 35 | 36 | sample = dict(space.sample()) 37 | breakpoint() 38 | np.rec.array(sample, dtype=space_dtype) 39 | 40 | # Function to sample from the space and convert to a structured numpy array 41 | def sample_and_convert(space, dtype): 42 | sample = space.sample() 43 | flat_sample = {} 44 | def flatten(sample, name_prefix=""): 45 | for key, item in sample.items(): 46 | full_key = name_prefix + key if name_prefix == "" else name_prefix + "_" + key 47 | if isinstance(item, dict): 48 | flatten(item, full_key) 49 | else: 50 | flat_sample[full_key] = item 51 | flatten(sample) 52 | return np.array(tuple(flat_sample.values()), dtype=dtype) 53 | 54 | num_samples = 3 55 | samples = [sample_and_convert(space, space_dtype) for _ in range(num_samples)] 56 | print("Samples:", samples) 57 | 58 | record_array = np.rec.array(samples) 59 | print("Record Array:", record_array) 60 | 61 | bytes_array = record_array.tobytes() 62 | print("Bytes Array:", bytes_array) 63 | 64 | record_array = np.rec.array(bytes_array, dtype=space_dtype) 65 | print("Record Array from Bytes:", record_array) 66 | -------------------------------------------------------------------------------- /tests/test_record_emulation.py: -------------------------------------------------------------------------------- 1 | import pufferlib.emulation 2 | 3 | from pufferlib.environments.ocean import env_creator 4 | 5 | env = env_creator('spaces')() 6 | env.reset() 7 | env.step([1,0]) 8 | breakpoint() 9 | -------------------------------------------------------------------------------- /tests/test_registry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Loop through all folders in the `registry` directory 4 | for folder in pufferlib/registry/*; do 5 | if [ -d "$folder" ]; then 6 | # Extract folder name 7 | folder_name=$(basename $folder) 8 | 9 | if [[ $folder_name == __* ]]; then 10 | continue 11 | fi 12 | 13 | # Install package with extras 14 | pip install -e .[$folder_name] > /dev/null 2>&1 15 | 16 | # Run tests 17 | python tests/test_registry.py $folder_name 18 | fi 19 | done 20 | -------------------------------------------------------------------------------- /tests/test_render.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as T 2 | 3 | import argparse 4 | import importlib 5 | import time 6 | 7 | import cv2 8 | 9 | 10 | # Tested human: classic_control, atari, minigrid 11 | # Tested rbg_array: atari, pokemon_red, crafter 12 | # Tested ansii: minihack, nethack, squared 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--env', type=str, default='atari') 16 | parser.add_argument('--render-mode', type=str, default='rgb_array') 17 | args = parser.parse_args() 18 | 19 | env_module = importlib.import_module(f'pufferlib.environments.{args.env}') 20 | 21 | if args.render_mode == 'human': 22 | env = env_module.env_creator()(render_mode='human') 23 | else: 24 | env = env_module.env_creator()() 25 | 26 | terminal = True 27 | while True: 28 | start = time.time() 29 | if terminal or truncated: 30 | ob, _ = env.reset() 31 | 32 | if args.render_mode == 'rgb_array': 33 | frame = env.render() 34 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 35 | #if ob.shape[0] in (1, 3, 4): 36 | # ob = 
ob.transpose(1, 2, 0) 37 | cv2.imshow('frame', frame) 38 | 39 | #cv2.imshow('ob', ob) 40 | cv2.waitKey(1) 41 | elif args.render_mode == 'ansi': 42 | chars = env.render() 43 | print("\033c", end="") 44 | print(chars) 45 | 46 | ob = ob.reshape(1, *ob.shape) 47 | action = env.action_space.sample() 48 | ob, reward, terminal, truncated, info = env.step(action) 49 | env.render() 50 | start = time.time() 51 | if time.time() - start < 1/60: 52 | time.sleep(1/60 - (time.time() - start)) 53 | 54 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import gym 3 | 4 | import pufferlib 5 | import pufferlib.utils 6 | 7 | def test_suppress(): 8 | with pufferlib.utils.Suppress(): 9 | gym.make('Breakout-v4') 10 | print('stdout (you should not see this)', file=sys.stdout) 11 | print('stderr (you should not see this)', file=sys.stderr) 12 | 13 | if __name__ == '__main__': 14 | test_suppress() -------------------------------------------------------------------------------- /tests/time_alloc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import timeit 3 | 4 | # Time np.zeros(2, 5) for 100000 iterations 5 | time_zeros = timeit.timeit('np.zeros((2, 5))', setup='import numpy as np', number=100000) 6 | 7 | # Pre-allocate the array 8 | preallocated_array = np.zeros((2, 5)) 9 | 10 | # Time setting the pre-allocated array to zero for 100000 iterations 11 | time_preallocated = timeit.timeit('preallocated_array[:] = 0', setup='import numpy as np; preallocated_array = np.zeros((2, 5))', number=100000) 12 | 13 | print(f"Time for np.zeros(2, 5) over 100000 iterations: {time_zeros} seconds") 14 | print(f"Time for preallocated *= 0 over 100000 iterations: {time_preallocated} seconds") 15 | 16 | --------------------------------------------------------------------------------