├── AUTHORS ├── requirements_test.txt ├── docs ├── benchmarks │ └── index.md ├── assets │ ├── fonts │ │ ├── BellMT.ttf │ │ ├── sherpa.ttf │ │ ├── sherpa.woff2 │ │ ├── sage-bold.ttf │ │ ├── BellMTItalic.ttf │ │ ├── bell mt fett.ttf │ │ ├── sage-bold.woff2 │ │ ├── Franklin-Gothic-Demi.TTF │ │ ├── Franklin-Gothic-Heavy.TTF │ │ ├── Franklin-Gothic-Medium.ttf │ │ ├── Franklin-Gothic-Demi-Cond.TTF │ │ ├── Franklin-Gothic-Demi-Italic.TTF │ │ ├── Franklin-Gothic-Heavy-Italic.TTF │ │ ├── Franklin-Gothic-Medium-Cond.TTF │ │ └── Franklin-Gothic-Medium-Italic.ttf │ ├── images │ │ └── navix_logo.png │ └── stylesheets │ │ └── extra.css ├── requirements.txt ├── install │ └── index.md ├── examples │ └── customisation.ipynb ├── scripts │ └── gen_doc_stubs.py └── index.md ├── assets ├── sprites │ ├── floor.png │ ├── goal.png │ ├── lava.png │ ├── wall.png │ ├── ball_red.png │ ├── box_blue.png │ ├── box_grey.png │ ├── box_red.png │ ├── key_blue.png │ ├── key_grey.png │ ├── key_red.png │ ├── ball_blue.png │ ├── ball_green.png │ ├── ball_grey.png │ ├── ball_purple.png │ ├── ball_yellow.png │ ├── box_green.png │ ├── box_purple.png │ ├── box_yellow.png │ ├── key_green.png │ ├── key_purple.png │ ├── key_yellow.png │ ├── player_east.png │ ├── player_west.png │ ├── door_open_red.png │ ├── player_north.png │ ├── player_south.png │ ├── door_closed_blue.png │ ├── door_closed_grey.png │ ├── door_closed_red.png │ ├── door_locked_blue.png │ ├── door_locked_grey.png │ ├── door_locked_red.png │ ├── door_open_blue.png │ ├── door_open_green.png │ ├── door_open_grey.png │ ├── door_open_purple.png │ ├── door_open_yellow.png │ ├── door_closed_green.png │ ├── door_closed_purple.png │ ├── door_closed_yellow.png │ ├── door_locked_green.png │ ├── door_locked_purple.png │ └── door_locked_yellow.png ├── COPYRIGHT └── LICENSE ├── NOTICE ├── requirements.txt ├── CHANGELOG.md ├── tests ├── performance │ ├── minigrid_report.txt │ ├── observations_room_report.txt │ ├── observations_keydoor_report.txt │ ├── grid.py 
│ ├── profiling.py │ ├── minigrid.py │ └── observations.py ├── issue_92.py ├── test_entities.py ├── test_spaces.py ├── issue_95.py ├── test_issues.py ├── test_terminations.py ├── test_tasks.py ├── test_environments.py ├── test_grid.py └── test_observations.py ├── navix ├── config.py ├── tasks.py ├── rendering │ ├── __init__.py │ ├── cache.py │ └── registry.py ├── _version.py ├── agents │ ├── __init__.py │ ├── models.py │ └── agent.py ├── environments │ ├── __init__.py │ ├── wrappers.py │ ├── registry.py │ ├── dist_shift.py │ ├── four_rooms.py │ ├── lava_gap.py │ ├── go_to_door.py │ ├── dynamic_obstacles.py │ ├── empty.py │ ├── crossings.py │ ├── key_corridor.py │ └── door_key.py ├── __init__.py ├── events.py ├── components.py ├── terminations.py ├── spaces.py ├── transitions.py ├── rewards.py └── experiment.py ├── CITATION.cff ├── COPYRIGHT ├── scripts └── release.sh ├── baselines └── ppo.py ├── examples ├── ppo.py ├── hparam_search.py └── purejaxrl │ └── wrappers.py ├── .github └── workflows │ ├── CD.yml │ └── CI.yml ├── benchmarks └── navix_.py ├── .gitignore ├── pyproject.toml ├── mkdocs.yml └── README.md /AUTHORS: -------------------------------------------------------------------------------- 1 | Eduardo Pignatelli -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | minigrid -------------------------------------------------------------------------------- /docs/benchmarks/index.md: -------------------------------------------------------------------------------- 1 | **Coming soon** -------------------------------------------------------------------------------- /assets/sprites/floor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/floor.png -------------------------------------------------------------------------------- 
/assets/sprites/goal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/goal.png -------------------------------------------------------------------------------- /assets/sprites/lava.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/lava.png -------------------------------------------------------------------------------- /assets/sprites/wall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/wall.png -------------------------------------------------------------------------------- /assets/sprites/ball_red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/ball_red.png -------------------------------------------------------------------------------- /assets/sprites/box_blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/box_blue.png -------------------------------------------------------------------------------- /assets/sprites/box_grey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/box_grey.png -------------------------------------------------------------------------------- /assets/sprites/box_red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/box_red.png -------------------------------------------------------------------------------- /assets/sprites/key_blue.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/key_blue.png -------------------------------------------------------------------------------- /assets/sprites/key_grey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/key_grey.png -------------------------------------------------------------------------------- /assets/sprites/key_red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/key_red.png -------------------------------------------------------------------------------- /assets/sprites/ball_blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/ball_blue.png -------------------------------------------------------------------------------- /assets/sprites/ball_green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/ball_green.png -------------------------------------------------------------------------------- /assets/sprites/ball_grey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/ball_grey.png -------------------------------------------------------------------------------- /assets/sprites/ball_purple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/ball_purple.png -------------------------------------------------------------------------------- /assets/sprites/ball_yellow.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/ball_yellow.png -------------------------------------------------------------------------------- /assets/sprites/box_green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/box_green.png -------------------------------------------------------------------------------- /assets/sprites/box_purple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/box_purple.png -------------------------------------------------------------------------------- /assets/sprites/box_yellow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/box_yellow.png -------------------------------------------------------------------------------- /assets/sprites/key_green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/key_green.png -------------------------------------------------------------------------------- /assets/sprites/key_purple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/key_purple.png -------------------------------------------------------------------------------- /assets/sprites/key_yellow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/key_yellow.png -------------------------------------------------------------------------------- /assets/sprites/player_east.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/player_east.png -------------------------------------------------------------------------------- /assets/sprites/player_west.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/player_west.png -------------------------------------------------------------------------------- /docs/assets/fonts/BellMT.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/BellMT.ttf -------------------------------------------------------------------------------- /docs/assets/fonts/sherpa.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/sherpa.ttf -------------------------------------------------------------------------------- /docs/assets/fonts/sherpa.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/sherpa.woff2 -------------------------------------------------------------------------------- /assets/sprites/door_open_red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_open_red.png -------------------------------------------------------------------------------- /assets/sprites/player_north.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/player_north.png -------------------------------------------------------------------------------- /assets/sprites/player_south.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/player_south.png -------------------------------------------------------------------------------- /docs/assets/fonts/sage-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/sage-bold.ttf -------------------------------------------------------------------------------- /assets/sprites/door_closed_blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_closed_blue.png -------------------------------------------------------------------------------- /assets/sprites/door_closed_grey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_closed_grey.png -------------------------------------------------------------------------------- /assets/sprites/door_closed_red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_closed_red.png -------------------------------------------------------------------------------- /assets/sprites/door_locked_blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_locked_blue.png -------------------------------------------------------------------------------- /assets/sprites/door_locked_grey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_locked_grey.png -------------------------------------------------------------------------------- 
/assets/sprites/door_locked_red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_locked_red.png -------------------------------------------------------------------------------- /assets/sprites/door_open_blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_open_blue.png -------------------------------------------------------------------------------- /assets/sprites/door_open_green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_open_green.png -------------------------------------------------------------------------------- /assets/sprites/door_open_grey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_open_grey.png -------------------------------------------------------------------------------- /assets/sprites/door_open_purple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_open_purple.png -------------------------------------------------------------------------------- /assets/sprites/door_open_yellow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_open_yellow.png -------------------------------------------------------------------------------- /docs/assets/fonts/BellMTItalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/BellMTItalic.ttf 
-------------------------------------------------------------------------------- /docs/assets/fonts/bell mt fett.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/bell mt fett.ttf -------------------------------------------------------------------------------- /docs/assets/fonts/sage-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/sage-bold.woff2 -------------------------------------------------------------------------------- /docs/assets/images/navix_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/images/navix_logo.png -------------------------------------------------------------------------------- /assets/sprites/door_closed_green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_closed_green.png -------------------------------------------------------------------------------- /assets/sprites/door_closed_purple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_closed_purple.png -------------------------------------------------------------------------------- /assets/sprites/door_closed_yellow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_closed_yellow.png -------------------------------------------------------------------------------- /assets/sprites/door_locked_green.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_locked_green.png -------------------------------------------------------------------------------- /assets/sprites/door_locked_purple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_locked_purple.png -------------------------------------------------------------------------------- /assets/sprites/door_locked_yellow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/assets/sprites/door_locked_yellow.png -------------------------------------------------------------------------------- /docs/assets/fonts/Franklin-Gothic-Demi.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/Franklin-Gothic-Demi.TTF -------------------------------------------------------------------------------- /docs/assets/fonts/Franklin-Gothic-Heavy.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/Franklin-Gothic-Heavy.TTF -------------------------------------------------------------------------------- /docs/assets/fonts/Franklin-Gothic-Medium.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/Franklin-Gothic-Medium.ttf -------------------------------------------------------------------------------- /docs/assets/fonts/Franklin-Gothic-Demi-Cond.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/Franklin-Gothic-Demi-Cond.TTF 
-------------------------------------------------------------------------------- /docs/assets/fonts/Franklin-Gothic-Demi-Italic.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/Franklin-Gothic-Demi-Italic.TTF -------------------------------------------------------------------------------- /docs/assets/fonts/Franklin-Gothic-Heavy-Italic.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/Franklin-Gothic-Heavy-Italic.TTF -------------------------------------------------------------------------------- /docs/assets/fonts/Franklin-Gothic-Medium-Cond.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/Franklin-Gothic-Medium-Cond.TTF -------------------------------------------------------------------------------- /docs/assets/fonts/Franklin-Gothic-Medium-Italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epignatelli/navix/HEAD/docs/assets/fonts/Franklin-Gothic-Medium-Italic.ttf -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocs-material 3 | mkdocs-jupyter 4 | mkdocstrings 5 | mkdocstrings-python 6 | mkdocs-gen-files 7 | mkdocs-literate-nav 8 | mkdocs-section-index -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Apache epignatelli/navix 2 | Copyright 2023 The Navix Authors. 3 | 4 | This product includes software developed at 5 | The Apache Software Foundation (http://www.apache.org/). 
-------------------------------------------------------------------------------- /assets/COPYRIGHT: -------------------------------------------------------------------------------- 1 | Copyright 2024 https://github.com/Farama-Foundation/Minigrid 2 | The following images are under Apache 2.0 License as per https://github.com/Farama-Foundation/Minigrid/LICENSE. 3 | A copy of the license is provided in the file assets/LICENSE. 4 | 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # coding 2 | black 3 | flake8 4 | pylint 5 | typing-extensions 6 | # deploying 7 | setuptools_scm 8 | # testing 9 | pytest 10 | # compute libraries 11 | pillow 12 | jax 13 | flax 14 | rlax 15 | # experiments 16 | tyro 17 | # logging 18 | wandb -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 5 | 6 | ## Unreleased 7 | -------------------------------------------------------------------------------- /tests/performance/minigrid_report.txt: -------------------------------------------------------------------------------- 1 | Profiling navix with `scan`, N_SEEDS = 10000, N_TIMESTEPS = 1000 2 | Compiling ... 3 | Compiled in 0.78s 4 | Running ... 
5 | 0.1666584610939026 ± 0.021019447594881058 6 | Profiling minigrid, N_SEEDS = 1, N_TIMESTEPS = 1000 7 | -------------------------------------------------------------------------------- /navix/config.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | class Config: 5 | """Config class to store global variables.""" 6 | 7 | def __init__(self): 8 | self.ARRAY_CHECKS_ENABLED = False 9 | 10 | def update(self, key: str, value: Any) -> None: 11 | setattr(self, key, value) 12 | 13 | def reset(self) -> None: 14 | self.__init__() 15 | 16 | 17 | config = Config() 18 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: >- 3 | Navix: Accelerated gridworld navigation with JAX 4 | type: software 5 | authors: 6 | - family-names: Pignatelli 7 | given-names: Eduardo 8 | email: edu.pignatelli@gmail.com 9 | affiliation: University College London (UCL) 10 | orcid: 'https://orcid.org/0000-0003-0730-2303' 11 | identifiers: 12 | - type: url 13 | value: 'https://github.com/epignatelli/navix' 14 | repository-code: 'https://github.com/epignatelli/navix' 15 | keywords: 16 | - Reinforcement Learning 17 | - JAX 18 | - Minigrid 19 | license: Apache-2.0 20 | -------------------------------------------------------------------------------- /navix/tasks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from jax import Array 3 | from flax import struct 4 | 5 | from . 
import rewards, terminations 6 | from .states import State 7 | 8 | 9 | class Task(struct.PyTreeNode): 10 | """Task formulation as described in https://arxiv.org/abs/1609.01995""" 11 | 12 | reward_fn: Callable[[State, Array, State], Array] 13 | termination_fn: Callable[[State, Array, State], Array] 14 | 15 | 16 | NAVIGATION = Task(rewards.on_goal_reached, terminations.on_goal_reached) 17 | 18 | GO_TO_DOOR = Task(terminations.on_door_done, terminations.on_door_done) 19 | -------------------------------------------------------------------------------- /tests/issue_92.py: -------------------------------------------------------------------------------- 1 | """Unittest for issue #92: https://github.com/epignatelli/navix/issues/92""" 2 | 3 | import jax 4 | import navix as nx 5 | from navix.states import Event, EventType 6 | 7 | 8 | def test_issue_92(): 9 | # try instantiate Event 10 | event = Event( 11 | position=jax.numpy.asarray([-1, -1]), 12 | colour=jax.numpy.asarray([255, 0, 0]), 13 | happened=jax.numpy.asarray(False), 14 | event_type=EventType.NONE, 15 | ) 16 | # try instantiate environment 17 | env = nx.make("Navix-KeyCorridorS6R3-v0") 18 | timestep = env.reset(jax.random.PRNGKey(0)) 19 | timestep = env.step(timestep, jax.numpy.asarray(1)) 20 | -------------------------------------------------------------------------------- /tests/performance/observations_room_report.txt: -------------------------------------------------------------------------------- 1 | Profiling observation , N_SEEDS = 10, N_TIMESTEPS = 1000 2 | Compiling ... 3 | Compiled in 0.99s 4 | Running ... 5 | 0.04586170241236687 ± 0.00395115977153182 6 | Profiling observation , N_SEEDS = 10, N_TIMESTEPS = 1000 7 | Compiling ... 8 | Compiled in 0.99s 9 | Running ... 10 | 0.058218248188495636 ± 0.0005350122228264809 11 | Profiling observation , N_SEEDS = 10, N_TIMESTEPS = 1000 12 | Compiling ... 13 | Compiled in 1.08s 14 | Running ... 
15 | 0.1171514168381691 ± 0.00043898896547034383 16 | -------------------------------------------------------------------------------- /tests/performance/observations_keydoor_report.txt: -------------------------------------------------------------------------------- 1 | Profiling observation , N_SEEDS = 10, N_TIMESTEPS = 1000 2 | Compiling ... 3 | Compiled in 1.35s 4 | Running ... 5 | 0.07763808965682983 ± 0.004634591285139322 6 | Profiling observation , N_SEEDS = 10, N_TIMESTEPS = 1000 7 | Compiling ... 8 | Compiled in 1.33s 9 | Running ... 10 | 0.09059003740549088 ± 0.0004746166814584285 11 | Profiling observation , N_SEEDS = 10, N_TIMESTEPS = 1000 12 | Compiling ... 13 | Compiled in 1.47s 14 | Running ... 15 | 0.14711560308933258 ± 0.0005664637428708375 16 | -------------------------------------------------------------------------------- /COPYRIGHT: -------------------------------------------------------------------------------- 1 | Copyright 2023 The Navix Authors. 2 | 3 | Licensed to the Apache Software Foundation (ASF) under one 4 | or more contributor license agreements. See the NOTICE file 5 | distributed with this work for additional information 6 | regarding copyright ownership. The ASF licenses this file 7 | to you under the Apache License, Version 2.0 (the 8 | "License"); you may not use this file except in compliance 9 | with the License. You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, 14 | software distributed under the License is distributed on an 15 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | KIND, either express or implied. See the License for the 17 | specific language governing permissions and limitations 18 | under the License. 
-------------------------------------------------------------------------------- /docs/install/index.md: -------------------------------------------------------------------------------- 1 | ## Install JAX 2 | NAVIX depends on JAX. 3 | Follow the official [JAX installation guide](https://github.com/google/jax#installation) for your OS and preferred accelerator. 4 | 5 | For a quick start, you can install JAX for GPU with the following command: 6 | ```bash 7 | pip install -U "jax[cuda12]" 8 | ``` 9 | which will install JAX with CUDA 12 support. 10 | 11 | 12 | ## Install NAVIX 13 | ```bash 14 | pip install navix 15 | ``` 16 | 17 | Or, for the latest version from source: 18 | ```bash 19 | pip install git+https://github.com/epignatelli/navix 20 | ``` 21 | 22 | 23 | ## Installing in a conda environment 24 | We recommend installing NAVIX in a conda environment. 25 | To create a new conda environment and install NAVIX, run the following commands: 26 | ```bash 27 | conda create -n navix python=3.10 28 | conda activate navix 29 | cd 30 | pip install navix 31 | ``` 32 | -------------------------------------------------------------------------------- /docs/examples/customisation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Customising the Environment" 8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.10.12" 28 | } 29 | }, 30 | "nbformat": 4, 31 | "nbformat_minor": 4 32 | } 33 | -------------------------------------------------------------------------------- 
/navix/rendering/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from . 
import cache, registry 22 | -------------------------------------------------------------------------------- /scripts/release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # get current script directory 4 | SCRIPT="$(readlink -f "$0")" 5 | SCRIPT_DIR=$(dirname "$SCRIPT") 6 | echo "Script dir is: $SCRIPT_DIR" 7 | 8 | # get version file 9 | VERSION_FILE="$SCRIPT_DIR/../navix/_version.py" 10 | VERSION_CONTENT="$(cat "$VERSION_FILE")" 11 | echo "Version file found at: $VERSION_FILE and contains:" 12 | echo "$VERSION_CONTENT" 13 | 14 | # extract version from the version file located above (works from any working directory) 15 | VERSION=$(grep "__version__ = " "$VERSION_FILE" | cut -d'=' -f2 | sed 's,\",,g' | sed "s,',,g" | sed 's, ,,g') 16 | echo "Current version is:" 17 | echo "$VERSION" 18 | 19 | # cd to repo dir 20 | REPO_DIR="$(cd "$(dirname -- "$1")" >/dev/null; pwd -P)/$(basename -- "$1")" 21 | echo "Repo dir is: $REPO_DIR" 22 | cd "$REPO_DIR" 23 | 24 | # create tag 25 | git tag -a "$VERSION" -m "Release $VERSION" 26 | git push --tags 27 | 28 | # create release 29 | gh release create "$VERSION" 30 | 31 | # trigger CD 32 | gh workflow run CD -r main 33 | -------------------------------------------------------------------------------- /navix/_version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. 
You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | __version__ = "0.7.4" 22 | __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) 23 | -------------------------------------------------------------------------------- /navix/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | 21 | from .ppo import PPO, PPOHparams as PPOHparams 22 | from .models import MLPEncoder, ConvEncoder, ActorCritic 23 | -------------------------------------------------------------------------------- /tests/test_entities.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from navix.entities import Goal, Player 5 | from navix.rendering.registry import TILE_SIZE 6 | 7 | 8 | def test_indexing(): 9 | # batched entity with batch size 1 10 | entity = Player( 11 | position=jnp.ones((1, 2), dtype=jnp.int32), 12 | direction=jnp.ones((1,), jnp.int32), 13 | pocket=jnp.ones((1,), jnp.int32), 14 | ) 15 | assert jnp.array_equal(entity[0].position, jnp.asarray((1, 1))) 16 | assert jnp.array_equal(entity[0].direction, jnp.asarray(1)) 17 | 18 | 19 | def test_get_sprites(): 20 | # batched entity with batch size 1 21 | entity = Goal.create(position=jnp.ones((1, 2)), probability=jnp.ones((1,))) 22 | assert entity.sprite.shape == (1, TILE_SIZE, TILE_SIZE, 3) 23 | 24 | # batched entity with batch size > 1 25 | entity = Goal.create(position=jnp.ones((5, 2)), probability=jnp.ones((5,))) 26 | assert entity.sprite.shape == (5, TILE_SIZE, TILE_SIZE, 3) 27 | 28 | 29 | if __name__ == "__main__": 30 | test_indexing() 31 | # test_get_sprites() 32 | jax.jit(test_get_sprites)() 33 | -------------------------------------------------------------------------------- /navix/environments/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. 
The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from __future__ import annotations 22 | 23 | 24 | from .environment import Environment, Timestep 25 | from .empty import Room 26 | from .door_key import DoorKey 27 | from .four_rooms import FourRooms 28 | from .key_corridor import KeyCorridor 29 | from .lava_gap import LavaGap 30 | from .crossings import Crossings 31 | from .dynamic_obstacles import DynamicObstacles 32 | from .dist_shift import DistShift 33 | from .go_to_door import GoToDoor 34 | -------------------------------------------------------------------------------- /navix/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. 
You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from . import ( 22 | actions, 23 | components, 24 | entities, 25 | grid, 26 | observations, 27 | rewards, 28 | environments, 29 | terminations, 30 | config, 31 | spaces, 32 | rendering, 33 | transitions, 34 | events, 35 | agents, 36 | ) 37 | 38 | from .environments.registry import make, register_env, registry 39 | from .experiment import Experiment 40 | from .environments.environment import Environment, Timestep, StepType -------------------------------------------------------------------------------- /docs/scripts/gen_doc_stubs.py: -------------------------------------------------------------------------------- 1 | """Generate the code reference pages and navigation.""" 2 | 3 | from pathlib import Path 4 | 5 | import mkdocs_gen_files 6 | 7 | nav = mkdocs_gen_files.nav.Nav() 8 | 9 | root = Path(__file__).parent.parent.parent 10 | src = root / "navix" 11 | out = "api" 12 | 13 | exclude_files = [ 14 | "_version.py", 15 | "config.py" 16 | ] 17 | 18 | for path in sorted(src.rglob("*.py")): 19 | if path.name in exclude_files: 20 | continue 21 | 22 | print("Generating stub for", path) 23 | module_path = path.relative_to(src).with_suffix("") 24 | doc_path = path.relative_to(src).with_suffix(".md") 25 | full_doc_path = Path(out, doc_path) 26 | 27 | parts = tuple(module_path.parts) 28 | parts = ("navix",) + parts 29 | 30 | if parts[-1] == "__init__": 31 | parts = parts[:-1] 32 | doc_path = doc_path.with_name("index.md") 33 | full_doc_path = full_doc_path.with_name("index.md") 34 | elif parts[-1] == "__main__": 35 | 
continue 36 | 37 | if parts: 38 | nav[parts] = doc_path.as_posix() 39 | 40 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 41 | ident = ".".join(parts) 42 | fd.write(f"::: {ident}") 43 | 44 | mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) 45 | 46 | with mkdocs_gen_files.open(f"{out}/index.md", "w") as nav_file: 47 | nav_file.writelines(nav.build_literate_nav()) 48 | -------------------------------------------------------------------------------- /tests/performance/grid.py: -------------------------------------------------------------------------------- 1 | import time 2 | from timeit import repeat 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import navix as nx 7 | 8 | 9 | N_TIMEIT_LOOPS = 10 10 | N_REPEAT = 100 11 | N_TIMESTEPS = 1000 12 | N_SEEDS = 100 13 | 14 | 15 | def test_observation(): 16 | def test(seed): 17 | env = nx.environments.Room.create( 18 | height=10, width=5, max_steps=100, observation_fn=nx.observations.none 19 | ) 20 | key = jax.random.PRNGKey(seed) 21 | timestep = env._reset(key) 22 | 23 | actions = jax.random.randint(key, (100,), 0, 6) 24 | timestep = jax.lax.scan(lambda c, x: (env.step(c, x), ()), timestep, actions)[0] 25 | return timestep 26 | 27 | # profile navix scanned 28 | print("Profiling, N_SEEDS = {}, N_TIMESTEPS = {}".format(N_SEEDS, N_TIMESTEPS)) 29 | 30 | seeds = jnp.arange(N_SEEDS) 31 | 32 | print(f"\tCompiling {test}...") 33 | start = time.time() 34 | test_jit = jax.jit(jax.vmap(test)).lower(seeds).compile() 35 | print("\tCompiled in {:.2f}s".format(time.time() - start)) 36 | 37 | print("\tRunning ...") 38 | res = repeat( 39 | lambda: test_jit(seeds).observation.block_until_ready(), 40 | number=N_TIMEIT_LOOPS, 41 | repeat=N_REPEAT, 42 | ) 43 | res = jnp.asarray(res) 44 | print(f"\t {jnp.mean(res)} ± {jnp.std(res)}") 45 | 46 | 47 | if __name__ == "__main__": 48 | test_observation() 49 | -------------------------------------------------------------------------------- /tests/test_spaces.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from navix.spaces import Continuous, Discrete 6 | 7 | 8 | MAX_INT = 100_000_000 9 | MIN_INT = -100_000_000 10 | 11 | 12 | def test_discrete(): 13 | key = jax.random.PRNGKey(42) 14 | elements = (5, 0, MAX_INT, MIN_INT) 15 | shapes = ((), (0,), (0, 0), (1, 2), (5, 5)) 16 | dtypes = (jnp.int8, jnp.int16, jnp.int32) 17 | for element in elements: 18 | for shape in shapes: 19 | for dtype in dtypes: 20 | space = Discrete.create(element, shape, dtype) 21 | sample = space.sample(key) 22 | print(sample) 23 | assert jnp.all(jnp.logical_not(jnp.isnan(sample))) 24 | 25 | 26 | def test_continuous(): 27 | key = jax.random.PRNGKey(42) 28 | shapes = ((), (0,), (0, 0), (1, 2), (5, 5)) 29 | min_max = [ 30 | (0.0, 1.0), 31 | (0.0, 1), 32 | (0, 1), 33 | (1.0, -1.0), 34 | (MIN_INT, MAX_INT), 35 | ] 36 | for shape in shapes: 37 | for minimum, maximum in min_max: 38 | space = Continuous.create( 39 | shape=shape, minimum=jnp.asarray(minimum), maximum=jnp.asarray(maximum) 40 | ) 41 | sample = space.sample(key) 42 | print(sample) 43 | assert jnp.all(jnp.logical_not(jnp.isnan(sample))) 44 | 45 | 46 | if __name__ == "__main__": 47 | test_discrete() 48 | test_continuous() 49 | -------------------------------------------------------------------------------- /tests/performance/profiling.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import navix as nx 4 | import time 5 | 6 | 7 | N_TIMESTEPS = 10 8 | N_SEEDS = 100 9 | 10 | 11 | def f(seed): 12 | key = jax.random.PRNGKey(seed) 13 | env = nx.environments.Room(16, 16, 8, observation_fn=nx.observations.rgb) 14 | timestep = env._reset(key) 15 | 16 | for _ in range(N_TIMESTEPS): 17 | action = jax.random.randint(timestep.state.key, (), 0, 6) 18 | timestep = env.step(timestep, jnp.asarray(action)) 19 | return timestep 20 | 21 | 
22 | def f_scan(seed): 23 | key = jax.random.PRNGKey(seed) 24 | env = nx.environments.Room(16, 16, 8, observation_fn=nx.observations.rgb) 25 | timestep = env._reset(key) 26 | 27 | def body_fun(carry, x): 28 | timestep = carry 29 | action = jax.random.randint(timestep.state.key, (), 0, 6) 30 | timestep = env.step(timestep, action) 31 | return timestep, () 32 | 33 | timestep = jax.lax.scan( 34 | body_fun, 35 | timestep, 36 | [None] * N_TIMESTEPS, 37 | length=N_TIMESTEPS, 38 | )[0] 39 | return timestep 40 | 41 | 42 | seeds = jnp.arange(N_SEEDS) 43 | function = jax.vmap(f) 44 | 45 | # print(f"\tCompiling {function}...") 46 | # start = time.time() 47 | # f_jit = jax.jit(function).lower(seeds).compile() 48 | # print("\tCompiled in {:.2f}s".format(time.time() - start)) 49 | 50 | 51 | with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): 52 | timestep = function(seeds).observation.block_until_ready() 53 | -------------------------------------------------------------------------------- /tests/issue_95.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from navix.entities import Door 3 | 4 | 5 | def test_open_unlocked(): 6 | # open=1, requires=-1 (unlocked) 7 | door = Door( 8 | position=jnp.zeros((1, 2)), 9 | requires=jnp.asarray([-1]), 10 | open=jnp.asarray([1]), 11 | colour=jnp.asarray([0], dtype=jnp.uint8), 12 | ) 13 | state = door.symbolic_state 14 | assert jnp.all(state == 0), "Expected state to be 0 for open unlocked door." 15 | 16 | 17 | def test_closed_unlocked(): 18 | # open=0, requires=-1 (unlocked) 19 | door = Door( 20 | position=jnp.zeros((1, 2)), 21 | requires=jnp.asarray([-1]), 22 | open=jnp.asarray([0]), 23 | colour=jnp.asarray([0], dtype=jnp.uint8), 24 | ) 25 | state = door.symbolic_state 26 | assert jnp.all(state == 1), "Expected state to be 1 for closed unlocked door." 
27 | 28 | 29 | def test_closed_locked(): 30 | # open=0, requires=5 (locked) 31 | door = Door( 32 | position=jnp.zeros((1, 2)), 33 | requires=jnp.asarray([5]), 34 | open=jnp.asarray([0]), 35 | colour=jnp.asarray([0], dtype=jnp.uint8), 36 | ) 37 | state = door.symbolic_state 38 | assert jnp.all(state == 2), "Expected state to be 2 for closed locked door." 39 | 40 | 41 | def test_open_locked(): 42 | # open=1, requires=5 (locked, but open) 43 | door = Door( 44 | position=jnp.zeros((1, 2)), 45 | requires=jnp.asarray([5]), 46 | open=jnp.asarray([1]), 47 | colour=jnp.asarray([0], dtype=jnp.uint8), 48 | ) 49 | state = door.symbolic_state 50 | assert jnp.all(state == 0), "Expected state to be 0 for open locked door." 51 | -------------------------------------------------------------------------------- /tests/test_issues.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | from __future__ import annotations 21 | 22 | import jax 23 | import jax.numpy as jnp 24 | 25 | import navix as nx 26 | from navix import observations 27 | 28 | 29 | def test_82(): 30 | env = nx.make( 31 | "Navix-DoorKey-5x5-v0", 32 | max_steps=100, 33 | observation_fn=observations.rgb, 34 | ) 35 | key = jax.random.PRNGKey(5) 36 | timestep = env.reset(key) 37 | # Seed 5 is: 38 | # # # # # 39 | # P # . # 40 | # . # . # 41 | # K D G # 42 | # # # # # 43 | 44 | # start agent direction = EAST 45 | prev_pos = timestep.state.entities["player"].position 46 | # action 2 is forward 47 | timestep = env.step(timestep, 2) # should not walk into wall 48 | pos = timestep.state.entities["player"].position 49 | assert jnp.array_equal(prev_pos, pos) 50 | 51 | 52 | if __name__ == "__main__": 53 | test_82() 54 | -------------------------------------------------------------------------------- /baselines/ppo.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | import wandb 4 | 5 | import numpy as np 6 | import jax.numpy as jnp 7 | import flax.linen as nn 8 | from flax.linen.initializers import constant, orthogonal 9 | import tyro 10 | import navix as nx 11 | from navix.environments.environment import Environment 12 | from navix.agents import PPO, PPOHparams, ActorCritic 13 | 14 | 15 | def FlattenObsWrapper(env: Environment): 16 | flatten_obs_fn = lambda x: jnp.ravel(env.observation_fn(x)) 17 | flatten_obs_shape = (int(np.prod(env.observation_space.shape)),) 18 | return env.replace( 19 | observation_fn=flatten_obs_fn, 20 | observation_space=env.observation_space.replace(shape=flatten_obs_shape), 21 | ) 22 | 23 | 24 | @dataclass 25 | class Args: 26 | project_name = "navix-baselines" 27 | seeds_range: Tuple[int, int, int] = (0, 10, 1) 28 | ppo: PPOHparams = PPOHparams() 29 | 30 | 31 | if __name__ == "__main__": 32 | args = tyro.cli(Args) 33 | 34 | # create environments 35 | for 
env_id in nx.registry(): 36 | # init logging 37 | config = {**vars(args), **{"observations": "symbolic"}, **{"algo": "ppo"}} 38 | wandb.init(project=args.project_name, config=config) 39 | 40 | # init environment 41 | env = nx.make(env_id) 42 | env = FlattenObsWrapper(env) 43 | 44 | # create agent 45 | agent = PPO( 46 | hparams=args.ppo, 47 | network=ActorCritic(action_dim=len(env.action_set)), 48 | env=env, 49 | ) 50 | 51 | # run experiment 52 | experiment = nx.Experiment( 53 | name=args.project_name, 54 | agent=agent, 55 | env=env, 56 | env_id=env_id, 57 | seeds=tuple(range(*args.seeds_range)), 58 | ) 59 | experiment.run() 60 | -------------------------------------------------------------------------------- /examples/ppo.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import tyro 3 | import numpy as np 4 | import jax.numpy as jnp 5 | import navix as nx 6 | from navix import observations 7 | from navix.agents import PPO, PPOHparams, ActorCritic 8 | from navix.environments.environment import Environment 9 | 10 | # set persistent compilation cache directory 11 | # jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache/") 12 | 13 | 14 | @dataclass 15 | class Args: 16 | project_name = "navix-examples" 17 | seeds_offset: int = 0 18 | n_seeds: int = 1 19 | # env 20 | env_id: str = "Navix-Empty-Random-5x5-v0" 21 | discount: float = 0.99 22 | # ppo 23 | ppo_config: PPOHparams = field(default_factory=PPOHparams) 24 | 25 | 26 | if __name__ == "__main__": 27 | args = tyro.cli(Args) 28 | 29 | def FlattenObsWrapper(env: Environment): 30 | flatten_obs_fn = lambda x: jnp.ravel(env.observation_fn(x)) 31 | flatten_obs_shape = (int(np.prod(env.observation_space.shape)),) 32 | return env.replace( 33 | observation_fn=flatten_obs_fn, 34 | observation_space=env.observation_space.replace(shape=flatten_obs_shape), 35 | ) 36 | 37 | env = nx.make( 38 | args.env_id, 39 | 
observation_fn=observations.symbolic_first_person, 40 | gamma=args.discount, 41 | ) 42 | env = FlattenObsWrapper(env) 43 | 44 | agent = PPO( 45 | hparams=args.ppo_config, 46 | network=ActorCritic( 47 | action_dim=len(env.action_set), 48 | ), 49 | env=env, 50 | ) 51 | 52 | experiment = nx.Experiment( 53 | name=args.project_name, 54 | agent=agent, 55 | env=env, 56 | env_id=args.env_id, 57 | seeds=tuple(range(args.seeds_offset, args.seeds_offset + args.n_seeds)), 58 | ) 59 | train_state, logs = experiment.run() 60 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 |

A fast, fully jittable, batched MiniGrid reimplemented in JAX for HIGH THROUGHPUT

2 |

Welcome to NAVIX!

3 | 4 | 5 | **NAVIX** is a reimplementation of the [MiniGrid](https://minigrid.farama.org/) environment suite in JAX, and leverages JAX’s intermediate language representation to migrate the computation to different accelerators, such as GPUs and TPUs. 6 | 7 | NAVIX is designed to be a drop-in replacement for the original MiniGrid environment, with the added benefit of being significantly faster. 8 | Experiments that took **1 week** now take **15 minutes**. 9 | 10 | A [`navix.Environment`](api/environments/environment.md) is a [`flax.struct.PyTreeNode`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.PyTreeNode) and supports [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html), [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), and all other JAX transformations. 11 | See some examples [here](examples/getting_started.ipynb). 12 | 13 |
14 | Most of the MiniGrid environments are supported, and the API is designed to be as close as possible to the original MiniGrid API. 15 | However, some features might be missing, or the API might be slightly different. 16 | If you find a gap, please [open an issue](https://github.com/epignatelli/navix/issues/new) or a [pull request](https://github.com/epignatelli/navix/pulls); contributions are welcome! 17 | 18 | 19 | Thanks to JAX's backend, NAVIX offers: 20 | 21 | - Multiple accelerators: NAVIX can run on CPU, GPU, or TPU. 22 | - Performance boost: a 200 000x speed-up in batch mode, or 20x in unbatched mode. 23 | - Parallelisation: NAVIX can run up to 2048 PPO agents (32768 environments!) in parallel on a single Nvidia A100 80GB. 24 | - Full automatic differentiation: NAVIX can compute gradients of the environment with respect to the agent's actions. 25 | 26 | 27 | [Get started with NAVIX](examples/getting_started.ipynb){ .md-button .md-button--primary} -------------------------------------------------------------------------------- /navix/environments/wrappers.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple 2 | import jax 3 | from jax import Array 4 | from flax import struct 5 | from gymnax.environments.environment import ( 6 | Environment as GymnaxEnv, 7 | EnvParams, 8 | EnvState, 9 | ) 10 | from gymnax.environments.spaces import Discrete as GymnaxDiscrete, Box as GymnaxBox 11 | 12 | from .environment import Environment, Timestep 13 | 14 | 15 | @struct.dataclass 16 | class GymnaxState(EnvState): 17 | timestep: Timestep 18 | time: Array 19 | 20 | 21 | class ToGymnax(GymnaxEnv): 22 | def __init__(self, env: Environment): 23 | self.env = env 24 | 25 | @property 26 | def default_params(self) -> EnvParams: 27 | return EnvParams(max_steps_in_episode=self.env.max_steps) 28 | 29 | @classmethod 30 | def wrap(cls, env: Environment) -> Tuple[GymnaxEnv, EnvParams]: 31 | return cls(env=env), 
EnvParams(max_steps_in_episode=env.max_steps) 32 | 33 | def action_space(self, params: Any): 34 | return GymnaxDiscrete(len(self.env.action_set)) 35 | 36 | def observation_space(self, params: Any): 37 | o_space = self.env.observation_space 38 | return GymnaxBox( 39 | low=o_space.minimum, 40 | high=o_space.maximum, 41 | shape=o_space.shape, 42 | dtype=o_space.dtype, 43 | ) 44 | 45 | def reset( 46 | self, key: jax.Array, params: EnvParams | None = None 47 | ) -> Tuple[Array, EnvState]: 48 | timestep = self.env.reset(key) 49 | return ( 50 | timestep.observation, 51 | GymnaxState(time=timestep.t, timestep=timestep), 52 | ) 53 | 54 | def step( 55 | self, key: Array, state: GymnaxState, action: jax.Array, params: EnvParams 56 | ) -> Tuple[Array, EnvState, Array, Array, Dict[str, Any]]: 57 | timestep = self.env.step(state.timestep, action) 58 | return ( 59 | timestep.observation, 60 | GymnaxState(time=timestep.t, timestep=timestep), 61 | timestep.reward, 62 | timestep.is_done(), 63 | timestep.info, 64 | ) 65 | -------------------------------------------------------------------------------- /tests/test_terminations.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | import navix as nx 5 | from navix.entities import Directions 6 | from navix.states import State 7 | from navix.components import EMPTY_POCKET_ID 8 | 9 | 10 | def test_on_navigation_completion(): 11 | grid = jnp.zeros((5, 5), dtype=jnp.int32) 12 | 13 | players = nx.entities.Player( 14 | position=jnp.asarray((1, 1)), direction=Directions.EAST, pocket=EMPTY_POCKET_ID 15 | ) 16 | goals = nx.entities.Goal.create(position=jnp.asarray((1, 2)), probability=jnp.asarray(1)) 17 | entities = { 18 | nx.entities.Entities.PLAYER: players[None], 19 | nx.entities.Entities.GOAL: goals[None], 20 | } 21 | 22 | state = State( 23 | key=jax.random.PRNGKey(0), 24 | grid=grid, 25 | cache=nx.rendering.cache.RenderingCache.init(grid), 26 | 
entities=entities, 27 | ) 28 | # should not terminate 29 | termination = nx.terminations.on_goal_reached(state, jnp.asarray(0), state) 30 | assert not termination, f"Should not terminate, got {termination} instead" 31 | 32 | # move forward 33 | new_state = nx.actions.forward(state) 34 | termination = nx.terminations.on_goal_reached(state, jnp.asarray(0), new_state) 35 | assert termination, f"Should terminate, got {termination} instead" 36 | 37 | 38 | def test_check_truncation(): 39 | terminated = jnp.asarray(False) 40 | truncated = jnp.asarray(False) 41 | assert nx.terminations.check_truncation(terminated, truncated) == jnp.asarray( 42 | 0, dtype=jnp.int32 43 | ) 44 | 45 | terminated = jnp.asarray(True) 46 | truncated = jnp.asarray(False) 47 | assert nx.terminations.check_truncation(terminated, truncated) == jnp.asarray( 48 | 2, dtype=jnp.int32 49 | ) 50 | 51 | terminated = jnp.asarray(False) 52 | truncated = jnp.asarray(True) 53 | assert nx.terminations.check_truncation(terminated, truncated) == jnp.asarray( 54 | 1, dtype=jnp.int32 55 | ) 56 | 57 | terminated = jnp.asarray(True) 58 | truncated = jnp.asarray(True) 59 | assert nx.terminations.check_truncation(terminated, truncated) == jnp.asarray( 60 | 2, dtype=jnp.int32 61 | ) 62 | 63 | 64 | if __name__ == "__main__": 65 | test_on_navigation_completion() 66 | test_check_truncation() 67 | -------------------------------------------------------------------------------- /.github/workflows/CD.yml: -------------------------------------------------------------------------------- 1 | name: CD 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - "main" 8 | 9 | jobs: 10 | release: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | with: 15 | fetch-depth: 0 16 | 17 | - name: Get version from file 18 | run: | 19 | VERSION_FILE="$/navix/_version.py" 20 | NAVIX_VERSION="$(cat navix/_version.py | grep '__version__ = ' | cut -d'=' -f2 | sed 's,\",,g' | sed "s,\',,g" | sed 's, ,,g')" 21 | 
echo "Current version is:" 22 | echo "$NAVIX_VERSION" 23 | 24 | echo "NAVIX_VERSION=$NAVIX_VERSION" >> $GITHUB_ENV 25 | 26 | - name: Create changelog and push tag 27 | id: changelog 28 | uses: TriPSs/conventional-changelog-action@v3 29 | with: 30 | github-token: ${{ secrets.GITHUB_TOKEN }} 31 | output-file: false 32 | fallback-version: ${{ env.NAVIX_VERSION }} 33 | skip-commit: true 34 | 35 | - name: Create Release 36 | uses: ncipollo/release-action@v1 37 | with: 38 | tag: ${{ env.NAVIX_VERSION }} 39 | name: "NAVIX release v${{ env.NAVIX_VERSION }}" 40 | body: ${{ steps.changelog.outputs.clean_changelog }} 41 | 42 | - uses: actions/setup-python@v4 43 | with: 44 | python-version: "3.10" 45 | 46 | - name: Install pypa/build 47 | run: | 48 | python -m pip install build 49 | 50 | - name: Build wheel and sdist 51 | run: | 52 | python -m build --sdist --wheel --outdir dist/ . 53 | 54 | - name: Publish distribution 📦 to PyPI 55 | uses: pypa/gh-action-pypi-publish@release/v1 56 | with: 57 | verbose: true 58 | password: ${{ secrets.PYPI_API_KEY }} 59 | 60 | deploy-docs: 61 | runs-on: ubuntu-latest 62 | steps: 63 | - uses: actions/checkout@v3 64 | with: 65 | fetch-depth: 0 66 | - uses: actions/setup-python@v4 67 | with: 68 | python-version: "3.10" 69 | - run: pip install --upgrade pip && pip install -r docs/requirements.txt 70 | - run: git config user.name 'github-actions[bot]' && git config user.email 'github-actions[bot]@users.noreply.github.com' 71 | - name: Publish docs 72 | run: mkdocs gh-deploy 73 | -------------------------------------------------------------------------------- /tests/performance/minigrid.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import navix as nx 4 | 5 | import gymnasium 6 | import random 7 | import time 8 | 9 | from timeit import repeat 10 | 11 | N_TIMEIT_LOOPS = 3 12 | N_REPEAT = 5 13 | N_TIMESTEPS = 1000 14 | N_SEEDS = 10_000 15 | 16 | 17 | def 
profile_navix_scan(seed): 18 | env = nx.environments.Room.create( 19 | height=10, width=5, max_steps=100, observation_fn=nx.observations.categorical 20 | ) 21 | key = jax.random.PRNGKey(4) 22 | timestep = env._reset(key) 23 | actions = jax.random.randint(key, (N_TIMESTEPS,), 0, 6) 24 | 25 | timestep = jax.lax.scan( 26 | lambda carry, x: (env.step(carry, x), ()), timestep, actions 27 | )[0] 28 | 29 | return timestep 30 | 31 | 32 | def profile_minigrid(seed): 33 | env = gymnasium.make("MiniGrid-Empty-16x16-v0", render_mode=None) 34 | observation, info = env.reset(seed=42) 35 | for _ in range(N_TIMESTEPS): 36 | action = random.randint(0, 4) 37 | observation, reward, terminated, truncated, info = env.step(action) 38 | 39 | if terminated or truncated: 40 | observation, info = env.reset() 41 | env.close() 42 | return observation 43 | 44 | 45 | if __name__ == "__main__": 46 | # profile navix scanned 47 | print( 48 | "Profiling navix with `scan`, N_SEEDS = {}, N_TIMESTEPS = {}".format( 49 | N_SEEDS, N_TIMESTEPS 50 | ) 51 | ) 52 | seeds = jnp.arange(N_SEEDS) 53 | 54 | print(f"\tCompiling {profile_navix_scan}...") 55 | start = time.time() 56 | f_scan = jax.jit(jax.vmap(profile_navix_scan)).lower(seeds).compile() 57 | print("\tCompiled in {:.2f}s".format(time.time() - start)) 58 | 59 | print("\tRunning ...") 60 | res_navix = repeat( 61 | lambda: f_scan(seeds).observation.block_until_ready(), 62 | number=N_TIMEIT_LOOPS, 63 | repeat=N_REPEAT, 64 | ) 65 | res_navix = jnp.asarray(res_navix) 66 | print(f"\t {jnp.mean(res_navix)} ± {jnp.std(res_navix)}") 67 | 68 | # profile minigrid 69 | print("Profiling minigrid, N_SEEDS = 1, N_TIMESTEPS = {}".format(N_TIMESTEPS)) 70 | res_minigrid = repeat( 71 | lambda: profile_minigrid(0), number=N_TIMEIT_LOOPS, repeat=N_REPEAT 72 | ) 73 | res_minigrid = jnp.asarray(res_minigrid) 74 | print(f"\t {jnp.mean(res_minigrid)} ± {jnp.std(res_minigrid)}") 75 | -------------------------------------------------------------------------------- 
/tests/test_tasks.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | import navix as nx 5 | from navix.states import State 6 | from navix.entities import Entities, Player, Goal, Key, Door 7 | from navix.components import EMPTY_POCKET_ID 8 | from navix.rendering.registry import PALETTE 9 | 10 | 11 | def test_navigation(): 12 | """Unittest for https://github.com/epignatelli/navix/pull/47""" 13 | height = 10 14 | width = 10 15 | grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32) 16 | grid = jnp.pad(grid, 1, mode="constant", constant_values=-1) 17 | 18 | players = Player( 19 | position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID 20 | ) 21 | goals = Goal.create( 22 | position=jnp.asarray([(1, 1), (1, 1)]), probability=jnp.asarray([0.0, 0.0]) 23 | ) 24 | keys = Key.create(position=jnp.asarray((2, 2)), id=jnp.asarray(0), colour=PALETTE.YELLOW) 25 | doors = Door.create( 26 | position=jnp.asarray([(1, 5), (1, 6)]), 27 | requires=jnp.asarray((0, 0)), 28 | open=jnp.asarray((False, True)), 29 | colour=PALETTE.YELLOW, 30 | ) 31 | 32 | entities = { 33 | Entities.PLAYER: players[None], 34 | Entities.GOAL: goals, 35 | Entities.KEY: keys[None], 36 | Entities.DOOR: doors, 37 | } 38 | 39 | state = State( 40 | key=jax.random.PRNGKey(0), 41 | grid=grid, 42 | cache=nx.rendering.cache.RenderingCache.init(grid), 43 | entities=entities, 44 | ) 45 | action = jnp.asarray(0) 46 | reward = nx.rewards.on_goal_reached(state, action, state) 47 | assert jnp.array_equal(reward, jnp.asarray(0.0)) 48 | 49 | 50 | def test_tasks_composition(): 51 | reward_fn = nx.rewards.compose( 52 | nx.rewards.on_goal_reached, 53 | nx.rewards.action_cost, 54 | nx.rewards.time_cost, 55 | nx.rewards.wall_hit_cost, 56 | ) 57 | 58 | env = nx.environments.Room.create(height=3, width=3, max_steps=8, reward_fn=reward_fn) 59 | key = jax.random.PRNGKey(0) 60 | 61 | def _test(): 62 | timestep = env._reset(key) 
63 | for _ in range(10): 64 | timestep = env.step(timestep, jax.random.randint(key, (), 0, 7)) 65 | return timestep 66 | 67 | print(jax.jit(_test)()) 68 | 69 | 70 | if __name__ == "__main__": 71 | # test_tasks_composition() 72 | test_navigation() 73 | -------------------------------------------------------------------------------- /tests/performance/observations.py: -------------------------------------------------------------------------------- 1 | import time 2 | from timeit import repeat 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import navix as nx 7 | 8 | 9 | N_TIMEIT_LOOPS = 10 10 | N_REPEAT = 30 11 | N_TIMESTEPS = 10 12 | N_SEEDS = 10000 13 | 14 | 15 | def test_observation(observation_fn): 16 | def test(seed): 17 | env = nx.environments.Room( 18 | height=5, width=10, max_steps=100, gamma=1.0, observation_fn=observation_fn 19 | ) 20 | key = jax.random.PRNGKey(seed) 21 | timestep = env._reset(key) 22 | 23 | # option 1 24 | # actions = jax.random.randint(key, (100,), 0, 6) 25 | # timestep = jax.lax.scan(lambda c, x: (env.step(c, x), ()), timestep, actions)[0] 26 | 27 | # option 2 28 | # for i in range(N_TIMESTEPS): 29 | # action = jax.random.randint(key, (), 0, 6) 30 | # timestep = env.step(timestep, jnp.asarray(action)) 31 | 32 | # option 3 33 | actions = jax.random.randint(key, (100,), 0, 6) 34 | timestep, _ = jax.lax.while_loop( 35 | lambda x: x[1] < N_TIMESTEPS, 36 | lambda x: (env.step(x[0], actions[x[1]]), x[1] + 1), 37 | (timestep, jnp.asarray(0)), 38 | ) 39 | 40 | return timestep 41 | 42 | # profile navix scanned 43 | print( 44 | "Profiling observation {}, N_SEEDS = {}, N_TIMESTEPS = {}".format( 45 | observation_fn, N_SEEDS, N_TIMESTEPS 46 | ) 47 | ) 48 | 49 | seeds = jnp.arange(N_SEEDS) 50 | 51 | print(f"\tCompiling {observation_fn}...") 52 | start = time.time() 53 | test_jit = jax.vmap(test) 54 | test_jit = jax.jit(test_jit) 55 | test_jit = test_jit.lower(seeds) 56 | test_jit = test_jit.compile() 57 | print("\tCompiled in {:.2f}s".format(time.time() - 
start)) 58 | 59 | print(f"\tRunning {observation_fn}...") 60 | res = repeat( 61 | lambda: test_jit(seeds).observation.block_until_ready(), 62 | number=N_TIMEIT_LOOPS, 63 | repeat=N_REPEAT, 64 | ) 65 | res = jnp.asarray(res) 66 | print(f"\t {jnp.mean(res)} ± {jnp.std(res)}") 67 | 68 | 69 | if __name__ == "__main__": 70 | test_observation(nx.observations.none) 71 | test_observation(nx.observations.categorical) 72 | test_observation(nx.observations.categorical_first_person) 73 | test_observation(nx.observations.rgb) 74 | test_observation(nx.observations.rgb_first_person) 75 | -------------------------------------------------------------------------------- /navix/agents/models.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable, Sequence, Tuple 3 | from jax import Array 4 | import jax 5 | import jax.numpy as jnp 6 | import distrax 7 | import flax.linen as nn 8 | from flax.linen.initializers import constant, orthogonal 9 | 10 | 11 | class MLPEncoder(nn.Module): 12 | hidden_size: int = 64 13 | 14 | @nn.compact 15 | def __call__(self, x): 16 | return nn.Sequential( 17 | [ 18 | nn.Dense(self.hidden_size), 19 | nn.tanh, 20 | nn.Dense(self.hidden_size), 21 | nn.tanh, 22 | ] 23 | )(x) 24 | 25 | 26 | class ConvEncoder(nn.Module): 27 | hidden_size: int = 64 28 | 29 | @nn.compact 30 | def __call__(self, x): 31 | return nn.Sequential( 32 | [ 33 | nn.Conv(16, kernel_size=(2, 2)), 34 | nn.relu, 35 | nn.Conv(32, kernel_size=(2, 2)), 36 | nn.relu, 37 | nn.Conv(64, kernel_size=(2, 2)), 38 | nn.relu, 39 | jnp.ravel, 40 | nn.Dense(self.hidden_size), 41 | nn.relu, 42 | ] 43 | )(x) 44 | 45 | 46 | class ActorCritic(nn.Module): 47 | action_dim: int 48 | actor_encoder: nn.Module = MLPEncoder() 49 | critic_encoder: nn.Module = MLPEncoder() 50 | 51 | def setup(self): 52 | self.actor = nn.Sequential( 53 | [ 54 | self.actor_encoder, 55 | nn.Dense( 56 | self.action_dim, 57 | 
kernel_init=orthogonal(0.01), 58 | bias_init=constant(0.0), 59 | ), 60 | # lambda x: distrax.Categorical(logits=x), 61 | ] 62 | ) 63 | 64 | self.critic = nn.Sequential( 65 | [ 66 | self.critic_encoder, 67 | nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0)), 68 | # lambda x: jnp.squeeze(x, axis=-1), 69 | ] 70 | ) 71 | 72 | def __call__(self, x: Array) -> Tuple[distrax.Distribution, Array]: 73 | return distrax.Categorical(self.actor(x)), jnp.squeeze(self.critic(x), -1) 74 | 75 | def policy(self, x: Array) -> distrax.Distribution: 76 | return distrax.Categorical(logits=self.actor(x)) 77 | 78 | def value(self, x: Array) -> Array: 79 | return jnp.squeeze(self.critic(x), -1) 80 | -------------------------------------------------------------------------------- /benchmarks/navix_.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, field 2 | import time 3 | import jax 4 | import tyro 5 | import numpy as np 6 | import jax.numpy as jnp 7 | import wandb 8 | import navix as nx 9 | from navix import observations 10 | from navix.agents import PPO, PPOHparams, ActorCritic 11 | from navix.environments.environment import Environment 12 | 13 | # set persistent compilation cache directory 14 | # jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache/") 15 | 16 | 17 | @dataclass 18 | class Args: 19 | project_name = "navix-benchmarks" 20 | seeds_offset: int = 0 21 | n_seeds: int = 1 22 | # env 23 | env_id: str = "Navix-DoorKey-Random-8x8-v0" 24 | discount: float = 0.99 25 | # ppo 26 | ppo_config: PPOHparams = field(default_factory=PPOHparams) 27 | 28 | 29 | if __name__ == "__main__": 30 | args = tyro.cli(Args) 31 | 32 | def FlattenObsWrapper(env: Environment): 33 | flatten_obs_fn = lambda x: jnp.ravel(env.observation_fn(x)) 34 | flatten_obs_shape = (int(np.prod(env.observation_space.shape)),) 35 | return env.replace( 36 | observation_fn=flatten_obs_fn, 37 | 
observation_space=env.observation_space.replace(shape=flatten_obs_shape), 38 | ) 39 | 40 | env = nx.make( 41 | args.env_id, 42 | observation_fn=observations.symbolic_first_person, 43 | gamma=args.discount, 44 | ) 45 | env = FlattenObsWrapper(env) 46 | 47 | ppo_config = args.ppo_config.replace(budget=10_000_000) 48 | agent = PPO( 49 | hparams=ppo_config, 50 | network=ActorCritic( 51 | action_dim=len(env.action_set), 52 | ), 53 | env=env, 54 | ) 55 | 56 | experiment = nx.Experiment( 57 | name=args.project_name, 58 | agent=agent, 59 | env=env, 60 | env_id=args.env_id, 61 | seeds=tuple(range(args.seeds_offset, args.seeds_offset + args.n_seeds)), 62 | group="navix", 63 | ) 64 | train_state, logs = experiment.run(do_log=False) 65 | 66 | print("Logging final results to wandb...") 67 | start_time = time.time() 68 | # average over seeds 69 | logs_avg = jax.tree.map(lambda x: x.mean(axis=0), logs) 70 | config = {**vars(experiment), **asdict(agent.hparams)} 71 | wandb.init(project=experiment.name, config=config, group=experiment.group) 72 | agent.log_on_train_end(logs_avg) 73 | wandb.finish() 74 | logging_time = time.time() - start_time 75 | print(f"Logging time cost: {logging_time}") 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # developing 2 | playground.ipynb 3 | 4 | # vscode 5 | .vscode/ 6 | wandb/ 7 | .DS_Store 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a 
template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /assets/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 5 | 1. Definitions. 6 | "License" shall mean the terms and conditions for use, reproduction, 7 | and distribution as defined by Sections 1 through 9 of this document. 8 | "Licensor" shall mean the copyright owner or entity authorized by 9 | the copyright owner that is granting the License. 10 | "Legal Entity" shall mean the union of the acting entity and all 11 | other entities that control, are controlled by, or are under common 12 | control with that entity. For the purposes of this definition, 13 | "control" means (i) the power, direct or indirect, to cause the 14 | direction or management of such entity, whether by contract or 15 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 16 | outstanding shares, or (iii) beneficial ownership of such entity. 
17 | "You" (or "Your") shall mean an individual or Legal Entity 18 | @@ -154,49 +194,23 @@ 19 | whether in tort (including negligence), contract, or otherwise, 20 | unless required by applicable law (such as deliberate and grossly 21 | negligent acts) or agreed to in writing, shall any Contributor be 22 | liable to You for damages, including any direct, indirect, special, 23 | incidental, or consequential damages of any character arising as a 24 | result of this License or out of the use or inability to use the 25 | Work (including but not limited to damages for loss of goodwill, 26 | work stoppage, computer failure or malfunction, or any and all 27 | other commercial damages or losses), even if such Contributor 28 | has been advised of the possibility of such damages. 29 | 9. Accepting Warranty or Additional Liability. While redistributing 30 | the Work or Derivative Works thereof, You may choose to offer, 31 | and charge a fee for, acceptance of support, warranty, indemnity, 32 | or other liability obligations and/or rights consistent with this 33 | License. However, in accepting such obligations, You may act only 34 | on Your own behalf and on Your sole responsibility, not on behalf 35 | of any other Contributor, and only if You agree to indemnify, 36 | defend, and hold each Contributor harmless for any liability 37 | incurred by, or claims asserted against, such Contributor by reason 38 | of your accepting any such warranty or additional liability. 
39 | 40 | END OF TERMS AND CONDITIONS 41 | -------------------------------------------------------------------------------- /examples/hparam_search.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict 3 | 4 | import distrax 5 | import tyro 6 | import numpy as np 7 | import jax.numpy as jnp 8 | import navix as nx 9 | from navix import observations 10 | from navix.agents import PPO, PPOHparams, ActorCritic 11 | from navix.environments.environment import Environment 12 | 13 | 14 | @dataclass 15 | class Args: 16 | project_name = "navix-debug" 17 | population_size: int = 10 18 | seed: int = 0 19 | # env 20 | env_id: str = "Navix-DoorKey-Random-6x6-v0" 21 | discount: float = 0.99 22 | # ppo 23 | ppo_config: PPOHparams = field(default_factory=PPOHparams) 24 | 25 | 26 | class CategoricalUniform(distrax.Categorical): 27 | def __init__(self, domain: tuple, dtype=jnp.int32): 28 | self.domain = jnp.asarray(domain) 29 | super().__init__(logits=jnp.zeros(len(domain)), dtype=dtype) 30 | 31 | def sample(self, *, seed, sample_shape=()): 32 | samples = super().sample(seed=seed, sample_shape=sample_shape) 33 | return self.domain[samples] 34 | 35 | def sample_n(self, rng, n): 36 | samples = super().sample(seed=rng, sample_shape=(n,)) 37 | return self.domain[samples] 38 | 39 | 40 | if __name__ == "__main__": 41 | args = tyro.cli(Args) 42 | 43 | def FlattenObsWrapper(env: Environment): 44 | flatten_obs_fn = lambda x: jnp.ravel(env.observation_fn(x)) 45 | flatten_obs_shape = (int(np.prod(env.observation_space.shape)),) 46 | return env.replace( 47 | observation_fn=flatten_obs_fn, 48 | observation_space=env.observation_space.replace(shape=flatten_obs_shape), 49 | ) 50 | 51 | env = nx.make( 52 | args.env_id, 53 | observation_fn=observations.symbolic, 54 | gamma=args.discount, 55 | ) 56 | env = FlattenObsWrapper(env) 57 | 58 | # static hparams 59 | ppo_config = 
args.ppo_config.replace(anneal_lr=False) 60 | 61 | hparams_distr: Dict[str, distrax.Distribution] = { 62 | "gae_lambda": CategoricalUniform((0.7, 0.95, 0.99)), 63 | "clip_eps": CategoricalUniform((0.1, 0.2)), 64 | "ent_coef": CategoricalUniform((0.001, 0.01, 0.1)), 65 | "vf_coef": CategoricalUniform((0.1, 0.5, 0.9)), 66 | "lr": CategoricalUniform((1e-3, 2.5e-4, 1e-4, 1e-5)), 67 | } 68 | 69 | base_hparams = args.ppo_config 70 | experiment = nx.Experiment( 71 | name=args.project_name, 72 | agent=PPO(base_hparams, ActorCritic(len(env.action_set)), env), 73 | env=env, 74 | env_id=args.env_id, 75 | seeds=(args.seed,), 76 | ) 77 | 78 | experiment.run_hparam_search(hparams_distr, args.population_size) 79 | -------------------------------------------------------------------------------- /navix/agents/agent.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import time 3 | from typing import Dict, Tuple 4 | 5 | import numpy as np 6 | import wandb 7 | import jax 8 | import jax.numpy as jnp 9 | from flax import struct 10 | from flax.training.train_state import TrainState 11 | 12 | 13 | class HParams(struct.PyTreeNode): 14 | debug: bool = struct.field(pytree_node=False, default=False) 15 | """Whether to run in debug mode.""" 16 | log_frequency: int = struct.field(pytree_node=False, default=1) 17 | """How often to log results.""" 18 | log_render: bool = struct.field(pytree_node=False, default=False) 19 | 20 | 21 | class Agent(struct.PyTreeNode): 22 | hparams: HParams 23 | 24 | def train(self, rng: jax.Array) -> Tuple[TrainState, Dict[str, jax.Array]]: 25 | raise NotImplementedError 26 | 27 | def log(self, logs, inspectable=None): 28 | if len(logs) == 0 or logs["iter/updates"] % self.hparams.log_frequency != 0: 29 | return 30 | 31 | start_time = time.time() 32 | msg = f"Update Step: {logs['iter/updates']}, Frames: {logs['iter/frames']}" 33 | step = jnp.asarray(logs["iter/updates"], dtype=jnp.int32) 34 | 
35 | # log renders 36 | if self.hparams.log_render: 37 | render_human = logs.pop("render/human") # (T, 3, H, W) 38 | logs["render/human"] = wandb.Video(np.array(render_human), fps=4) 39 | 40 | if "done_mask" in logs: 41 | mask = jnp.asarray(logs.pop("done_mask"), dtype=jnp.bool_) # (T, N) 42 | # log episode length 43 | if "lengths" in logs: 44 | lengths: jax.Array = logs.pop("lengths") # (T, N) 45 | episode_lengths = lengths[mask] # (K,) 46 | logs["perf/episode_length"] = jnp.mean(episode_lengths) 47 | msg += f", Length: {logs['perf/episode_length']}" 48 | 49 | # log returns 50 | if "returns" in logs: 51 | returns = logs.pop("returns") # (T, N) 52 | final_returns = returns[mask] # (K,) 53 | logs["perf/returns"] = jnp.mean(final_returns) 54 | logs["perf/success_rate"] = jnp.mean(final_returns == 1.0) 55 | msg += f", Returns: {logs['perf/returns']}, Success Rate: {logs['perf/success_rate']}" 56 | 57 | msg += f", Logging time cost: {time.time() - start_time}" 58 | wandb.log(logs, step=step) 59 | 60 | def log_on_train_end(self, logs): 61 | print(jax.tree.map(lambda x: x.shape, logs)) 62 | len_logs = len(logs["iter/updates"]) 63 | for step in range(len_logs): 64 | step_logs = {k: jax.tree.map(lambda x: x[step], v) for k, v in logs.items()} 65 | self.log(step_logs) 66 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # This toml file is in compliance with PEP 517, PEP 518 and PEP 621 2 | # However, setuptools support for these is only in Beta 3 | # For now, `setup.py` is still the first option for pip 4 | [build-system] 5 | build-backend = "setuptools.build_meta" 6 | requires = ["setuptools >= 50", "setuptools-scm[toml]>=6.2", "wheel"] 7 | 8 | [project] 9 | name = "Navix" 10 | dynamic = ["version", "dependencies"] 11 | description = "Accelerated gridworld navigation with JAX for deep reinforcement learning" 12 | requires-python = ">=3.9" 
13 | readme = "README.md" 14 | license = {file = "LICENSE", name = "Apache-2.0"} 15 | authors = [ 16 | {name = "Eduardo Pignatelli", email = "edu.pignatelli@gmail.com"}, 17 | ] 18 | maintainers = [ 19 | {name = "Eduardo Pignatelli", email = "edu.pignatelli@gmail.com"}, 20 | ] 21 | classifiers = [ 22 | "Development Status :: 3 - Alpha", 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Science/Research", 25 | "Intended Audience :: Financial and Insurance Industry", 26 | "Intended Audience :: Healthcare Industry", 27 | "Environment :: GPU", 28 | "Environment :: GPU :: NVIDIA CUDA", 29 | "Environment :: GPU :: NVIDIA CUDA :: 11.0", 30 | "Environment :: GPU :: NVIDIA CUDA :: 11.1", 31 | "Environment :: GPU :: NVIDIA CUDA :: 11.2", 32 | "Environment :: GPU :: NVIDIA CUDA :: 11.3", 33 | "Environment :: GPU :: NVIDIA CUDA :: 11.4", 34 | "Environment :: GPU :: NVIDIA CUDA :: 11.5", 35 | "Environment :: GPU :: NVIDIA CUDA :: 11.6", 36 | "Environment :: GPU :: NVIDIA CUDA :: 11.7", 37 | "Environment :: GPU :: NVIDIA CUDA :: 11.8", 38 | "Framework :: Pytest", 39 | "License :: OSI Approved :: Apache Software License", 40 | "Operating System :: OS Independent", 41 | "Programming Language :: Python", 42 | "Programming Language :: Python :: 3", 43 | "Programming Language :: Python :: 3.9", 44 | "Programming Language :: Python :: 3.10", 45 | "Programming Language :: Python :: 3.11", 46 | "Topic :: Scientific/Engineering", 47 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 48 | "Topic :: Software Development", 49 | "Topic :: Software Development :: Libraries", 50 | "Topic :: Software Development :: Libraries :: Python Modules", 51 | ] 52 | 53 | 54 | [project.urls] 55 | homepage = "https://github.com/epignatelli/navix" 56 | repository = "https://github.com/epignatelli/navix" 57 | bug_tracker = "https://github.com/epignatelli/navix/issues" 58 | 59 | 60 | [tool.setuptools.dynamic] 61 | version = {attr = "navix._version.__version__"} 62 | dependencies = 
{file = "./requirements.txt"} 63 | 64 | 65 | [tool.setuptools.packages.find] 66 | include = ["navix*", "assets*"] 67 | exclude = ["tests", "examples", "scripts", "docs"] 68 | 69 | 70 | [tool.distutils.bdist_wheel] 71 | universal = true 72 | 73 | 74 | [tool.black] 75 | line-length = 88 76 | -------------------------------------------------------------------------------- /tests/test_environments.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import navix as nx 4 | 5 | 6 | def test_room(): 7 | def f(): 8 | env = nx.environments.Room.create( 9 | height=3, 10 | width=3, 11 | max_steps=8, 12 | observation_fn=nx.observations.symbolic_first_person, 13 | ) 14 | key = jax.random.PRNGKey(4) 15 | reset = jax.jit(env._reset) 16 | step = jax.jit(env.step) 17 | timestep = reset(key) 18 | # these are optimal actions for navigation + action_cost 19 | actions = ( 20 | 0, # noop sanity check 21 | 2, # rotate_ccw 22 | 3, # forward 23 | 3, # forward 24 | 2, # rotate_ccw 25 | 3, # forward 26 | ) 27 | print(timestep) 28 | print() 29 | for action in actions: 30 | timestep = step(timestep, jnp.asarray(action)) 31 | print() 32 | print(nx.actions.DEFAULT_ACTION_SET[action]) 33 | print(timestep) 34 | return timestep 35 | 36 | f() 37 | timestep = jax.jit(f)() 38 | print(timestep) 39 | 40 | 41 | def test_keydoor(): 42 | def f(): 43 | env = nx.environments.DoorKey.create( 44 | height=5, 45 | width=10, 46 | max_steps=8, 47 | observation_fn=nx.observations.symbolic_first_person, 48 | ) 49 | key = jax.random.PRNGKey(1) 50 | reset = jax.jit(env._reset) 51 | step = jax.jit(env.step) 52 | timestep = reset(key) 53 | # these are optimal actions for navigation + action_cost 54 | actions = ( 55 | 0, # rotate_ccw 56 | 2, # forward 57 | 2, # forward 58 | 2, # forward 59 | 0, # rotate_ccw 60 | 3, # pick-up 61 | 0, # rotate_ccw 62 | 0, # rotate_ccw 63 | 2, # forward 64 | 2, # forward 65 | 1, # rotate_cw 66 | 2, # forward 67 | 
0, # rotate_ccw 68 | 5, # open 69 | 2, # forward 70 | 2, # forward 71 | ) 72 | print(timestep) 73 | for action in actions: 74 | timestep = step(timestep, jnp.asarray(action)) 75 | print() 76 | print(nx.actions.DEFAULT_ACTION_SET[action]) 77 | print(timestep) 78 | return timestep 79 | 80 | f() 81 | jax.jit(f)() 82 | 83 | 84 | def test_keydoor2(): 85 | env = nx.environments.DoorKey.create(5, 7, 100, observation_fn=nx.observations.rgb) 86 | 87 | key = jax.random.PRNGKey(1) 88 | timestep = env._reset(key) 89 | return 90 | 91 | 92 | if __name__ == "__main__": 93 | # test_room() 94 | # jax.jit(test_room)() 95 | test_keydoor() 96 | # test_keydoor2() 97 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | types: [opened, synchronize, reopened] 6 | 7 | jobs: 8 | Test: 9 | runs-on: ${{ matrix.os }}-latest 10 | strategy: 11 | fail-fast: false 12 | max-parallel: 5 13 | matrix: 14 | os: ["ubuntu"] 15 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 16 | continue-on-error: false 17 | steps: 18 | - uses: actions/checkout@v3 19 | - uses: actions/setup-python@v4 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Setup navix 23 | run: | 24 | pip install . 
-v 25 | pip install -r requirements_test.txt 26 | - name: Check code quality 27 | run: | 28 | pip install pylint 29 | MESSAGE=$(pylint -ry $(git ls-files '*.py') ||:) 30 | echo "$MESSAGE" 31 | - name: Run unit tests with pytest 32 | run: | 33 | wandb offline 34 | pytest 35 | - name: Run examples 36 | run: | 37 | for example in examples/*.py; do 38 | python $example 39 | done 40 | 41 | Compliance: 42 | runs-on: ubuntu-latest 43 | steps: 44 | - uses: actions/checkout@v3 45 | with: 46 | fetch-depth: 0 47 | - name: PEP8 Compliance 48 | run: | 49 | pip install pylint 50 | PR_BRANCH=${{ github.event.pull_request.target.ref }} 51 | MAIN_BRANCH=origin/${{ github.event.pull_request.base.ref }} 52 | CURRENT_DIFF=$(git diff --name-only --diff-filter=d $MAIN_BRANCH $PR_BRANCH | grep -E '\.py$' | tr '\n' ' ') 53 | if [[ $CURRENT_DIFF == "" ]]; 54 | then MESSAGE="Diff is empty and there is nothing to pylint." 55 | else 56 | MESSAGE=$(pylint -ry --disable=E0401 $CURRENT_DIFF ||:) 57 | fi 58 | echo 'MESSAGE<<EOF' >> $GITHUB_ENV 59 | echo "
$MESSAGE
" >> $GITHUB_ENV 60 | echo 'EOF' >> $GITHUB_ENV 61 | echo "Printing PR message: $MESSAGE" 62 | - uses: mshick/add-pr-comment@v2 63 | with: 64 | issue: ${{ github.event.pull_request.number }} 65 | message: ${{ env.MESSAGE }} 66 | repo-token: ${{ secrets.GITHUB_TOKEN }} 67 | 68 | Check-next-version: 69 | runs-on: ubuntu-latest 70 | steps: 71 | - uses: actions/checkout@v3 72 | - uses: actions/setup-python@v4 73 | - name: Get version from file 74 | run: | 75 | VERSION_FILE="navix/_version.py" 76 | NAVIX_VERSION="$(cat navix/_version.py | grep '__version__ = ' | cut -d'=' -f2 | sed 's,\",,g' | sed "s,\',,g" | sed 's, ,,g')" 77 | echo "Current version is:" 78 | echo "$NAVIX_VERSION" 79 | echo "NAVIX_VERSION=$NAVIX_VERSION" >> $GITHUB_ENV 80 | - name: Check that git tag does not exist 81 | run: | 82 | NAVIX_VERSION=${{ env.NAVIX_VERSION }} 83 | git fetch --tags 84 | if [ $(git tag -l "$NAVIX_VERSION") ]; then 85 | echo "Tag $NAVIX_VERSION already exists. Please update the version in navix/_version.py file." 86 | exit 1 87 | fi 88 | echo "Tag $NAVIX_VERSION will be the deployed version." 89 | -------------------------------------------------------------------------------- /navix/environments/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. 
You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | from typing import Callable 20 | import difflib 21 | 22 | 23 | _ENVS_REGISTRY = {} 24 | 25 | 26 | def registry(): 27 | return _ENVS_REGISTRY 28 | 29 | 30 | def register_env(name: str, ctor: Callable): 31 | _ENVS_REGISTRY[name] = ctor 32 | 33 | 34 | def make(name: str, max_steps: int = 100, **kwargs): 35 | if name not in registry(): 36 | closest = difflib.get_close_matches(name, registry().keys()) 37 | msg = f"Environment {name} not yet implemented." 38 | if closest: 39 | msg += ( 40 | f" Did you mean one of these? {closest}\n" 41 | + "If not, please open a feature request!" 
42 | + "\nhttps://github.com/epignatelli/navix/issues/new?labels=enhancement" 43 | ) 44 | raise NotImplementedError(msg) 45 | ctor = _ENVS_REGISTRY[name] 46 | return ctor(max_steps=max_steps, **kwargs) 47 | 48 | 49 | NotImplementedEnvs = [ 50 | "MiniGrid-BlockedUnlockPickup-v0", 51 | "MiniGrid-LavaCrossingS9N1-v0", 52 | "MiniGrid-LavaCrossingS9N2-v0", 53 | "MiniGrid-LavaCrossingS9N3-v0", 54 | "MiniGrid-LavaCrossingS11N5-v0", 55 | "MiniGrid-Fetch-5x5-N2-v0", 56 | "MiniGrid-Fetch-6x6-N2-v0", 57 | "MiniGrid-Fetch-8x8-N3-v0", 58 | "MiniGrid-GoToObject-6x6-N2-v0", 59 | "MiniGrid-GoToObject-8x8-N2-v0", 60 | "MiniGrid-LockedRoom-v0", 61 | "MiniGrid-MemoryS17Random-v0", 62 | "MiniGrid-MemoryS13Random-v0", 63 | "MiniGrid-MemoryS13-v0", 64 | "MiniGrid-MemoryS11-v0", 65 | "MiniGrid-MemoryS9-v0", 66 | "MiniGrid-MemoryS7-v0", 67 | "MiniGrid-MultiRoom-N2-S4-v0", 68 | "MiniGrid-MultiRoom-N4-S5-v0", 69 | "MiniGrid-MultiRoom-N6-v0", 70 | "MiniGrid-ObstructedMaze-1Dl-v0", 71 | "MiniGrid-ObstructedMaze-1Dlh-v0", 72 | "MiniGrid-ObstructedMaze-1Dlhb-v0", 73 | "MiniGrid-ObstructedMaze-2Dl-v0", 74 | "MiniGrid-ObstructedMaze-2Dlh-v0", 75 | "MiniGrid-ObstructedMaze-2Dlhb-v0", 76 | "MiniGrid-ObstructedMaze-1Q-v0", 77 | "MiniGrid-ObstructedMaze-2Q-v0", 78 | "MiniGrid-ObstructedMaze-Full-v0", 79 | "MiniGrid-ObstructedMaze-2Dlhb-v1", 80 | "MiniGrid-ObstructedMaze-1Q-v1", 81 | "MiniGrid-ObstructedMaze-2Q-v1", 82 | "MiniGrid-ObstructedMaze-Full-v1", 83 | "MiniGrid-Playground-v0", 84 | "MiniGrid-PutNear-6x6-N2-v0", 85 | "MiniGrid-PutNear-8x8-N3-v0", 86 | "MiniGrid-RedBlueDoors-6x6-v0", 87 | "MiniGrid-RedBlueDoors-8x8-v0", 88 | "MiniGrid-Unlock-v0", 89 | "MiniGrid-UnlockPickup-v0", 90 | ] 91 | -------------------------------------------------------------------------------- /navix/events.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 
2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | from __future__ import annotations 20 | 21 | from jax import Array 22 | import jax.numpy as jnp 23 | 24 | from .states import State 25 | from .grid import positions_equal, translate 26 | from .entities import Entities, Player 27 | 28 | 29 | def on_goal_reached(state: State) -> Array: 30 | """Checks whether the goal has been reached using the `goal_reached` event. 31 | 32 | Args: 33 | state (State): The current state of the game. 34 | 35 | Returns: 36 | Array: A boolean array indicating whether the goal has been reached.""" 37 | return state.events.goal_reached.happened 38 | 39 | 40 | def on_lava_fall(state: State) -> Array: 41 | """Checks whether the player has fallen into lava using the `lava_fall` event. 42 | 43 | Args: 44 | state (State): The current state of the game. 45 | 46 | Returns: 47 | Array: A boolean array indicating whether the player has fallen into lava.""" 48 | return state.events.lava_fall.happened 49 | 50 | 51 | def on_ball_hit(state: State) -> Array: 52 | """Checks whether the ball has hit something using the `ball_hit` event. 53 | 54 | Args: 55 | state (State): The current state of the game. 
56 | 57 | Returns: 58 | Array: A boolean array indicating whether the ball has hit something.""" 59 | return state.events.ball_hit.happened 60 | 61 | 62 | def on_door_done(state: State) -> Array: 63 | """Checks whether the action `done` has been called in front of a `Door` object with the correct colour. 64 | 65 | Args: 66 | state (State): The current state of the game. 67 | 68 | Returns: 69 | Array: A boolean array indicating whether the action `done` has been called in front of a `Door` object with the correct colour. 70 | """ 71 | assert ( 72 | state.mission is not None 73 | ), "Termination on door done requires the state to specify a mission." 74 | player = state.entities[Entities.PLAYER][0] 75 | assert isinstance(player, Player) 76 | 77 | fwd_pos = translate(player.position, player.direction) 78 | if Entities.DOOR not in state.entities: 79 | return jnp.asarray(False) 80 | doors = state.get_doors() 81 | idx = jnp.where(positions_equal(doors.position, fwd_pos), size=1)[0][0] 82 | doors = doors[idx] 83 | pos_match = jnp.array_equal(fwd_pos, state.mission.position) 84 | colour_match = jnp.array_equal(doors.colour, state.mission.colour) 85 | return jnp.logical_and(pos_match, colour_match) 86 | 87 | 88 | def on_wall_hit(state: State) -> Array: 89 | """Checks whether the wall has been hit using the `wall_hit` event. 90 | 91 | Args: 92 | state (State): The current state of the game. 
93 | 94 | Returns: 95 | Array: A boolean array indicating whether the wall has been hit.""" 96 | return state.events.wall_hit.happened 97 | -------------------------------------------------------------------------------- /tests/test_grid.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import navix as nx 4 | 5 | 6 | def test_grid_from_ascii(): 7 | ascii_map = """######## 8 | #1.....# 9 | #......# 10 | #......# 11 | #......# 12 | #......# 13 | #......# 14 | #.....2# 15 | ######## 16 | ######## 17 | ######## 18 | ######## 19 | """ 20 | print(ascii_map) 21 | 22 | grid = nx.grid.from_ascii_map(ascii_map) 23 | print(grid) 24 | 25 | ascii_map = ascii_map.replace("1", "P") 26 | ascii_map = ascii_map.replace("2", "G") 27 | grid = nx.grid.from_ascii_map(ascii_map, mapping={"P": 1, "G": 2}) 28 | print(grid) 29 | 30 | 31 | def test_idx_from_coordinates(): 32 | grid = jnp.zeros((5, 7), dtype=jnp.int32) 33 | grid = jnp.pad(grid, 1, mode="constant", constant_values=-1) 34 | 35 | positions = jnp.array([[1, 1], [2, 2], [3, 3], [4, 4]]) 36 | indices = nx.grid.idx_from_coordinates(grid, positions) 37 | positions_after = nx.grid.coordinates_from_idx(grid, indices) 38 | assert jnp.all(jnp.array_equal(positions, positions_after)), ( 39 | positions, 40 | positions_after, 41 | ) 42 | 43 | 44 | def test_random_positions(): 45 | grid = jnp.zeros((5, 7), dtype=jnp.int32) 46 | grid = jnp.pad(grid, 1, mode="constant", constant_values=-1) 47 | 48 | key = jax.random.PRNGKey(0) 49 | positions = nx.grid.random_positions(key, grid, n=1) 50 | assert positions.shape == (2,), positions.shape 51 | 52 | positions = nx.grid.random_positions(key, grid, n=4) 53 | assert positions.shape == (4, 2), positions.shape 54 | 55 | exclude = jnp.asarray((1, 1)) 56 | positions = nx.grid.random_positions(key, grid, n=50, exclude=exclude) 57 | for position in positions: 58 | assert not jnp.array_equal(position, exclude), position 59 | 
assert jnp.array_equal(grid[tuple(position)], 0), positions 60 | 61 | 62 | def test_position_equal(): 63 | # one to one 64 | a = jnp.array([1, 1]) 65 | b = jnp.array([1, 1]) 66 | assert nx.grid.positions_equal(a, b) 67 | assert nx.grid.positions_equal(b, a) 68 | assert not nx.grid.positions_equal(a, b + 1) 69 | assert not nx.grid.positions_equal(a + 1, b) 70 | assert not nx.grid.positions_equal(b, a + 1) 71 | assert not nx.grid.positions_equal(b + 1, a) 72 | 73 | # one to many 74 | a = jnp.array([1, 1]) 75 | b = jnp.array([[1, 1], [1, 2]]) 76 | assert jnp.array_equal(nx.grid.positions_equal(a, b), jnp.array([True, False])) 77 | assert jnp.array_equal(nx.grid.positions_equal(b, a), jnp.array([True, False])) 78 | assert jnp.array_equal(nx.grid.positions_equal(a, b + 1), jnp.array([False, False])) 79 | assert jnp.array_equal(nx.grid.positions_equal(a + 1, b), jnp.array([False, False])) 80 | assert jnp.array_equal(nx.grid.positions_equal(b, a + 1), jnp.array([False, False])) 81 | assert jnp.array_equal(nx.grid.positions_equal(b + 1, a), jnp.array([False, False])) 82 | 83 | # many to many 84 | a = jnp.array([[1, 1], [1, 2]]) 85 | b = jnp.array([[1, 1], [1, 2]]) 86 | assert jnp.array_equal(nx.grid.positions_equal(a, b), jnp.array([True, True])) 87 | assert jnp.array_equal(nx.grid.positions_equal(b, a), jnp.array([True, True])) 88 | assert jnp.array_equal(nx.grid.positions_equal(a, b + 1), jnp.array([False, False])) 89 | assert jnp.array_equal(nx.grid.positions_equal(a + 1, b), jnp.array([False, False])) 90 | assert jnp.array_equal(nx.grid.positions_equal(b, a + 1), jnp.array([False, False])) 91 | assert jnp.array_equal(nx.grid.positions_equal(b + 1, a), jnp.array([False, False])) 92 | 93 | 94 | if __name__ == "__main__": 95 | # test_grid_from_ascii() 96 | # test_idx_from_coordinates() 97 | # test_random_positions() 98 | test_position_equal() 99 | # jax.jit(test_position_equal)() 100 | -------------------------------------------------------------------------------- 
/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: NAVIX 2 | site_author: Eduardo Pignatelli 3 | site_description: A reimplementation of MiniGrid in JAX 4 | site_url: https://epignatelli/navix 5 | 6 | # GitHub 7 | repo_name: epignatelli/navix 8 | repo_url: https://github.com/epignatelli/navix 9 | use_directory_urls: false 10 | # mkdocstrings 11 | watch: 12 | - navix 13 | 14 | nav: 15 | - Home: 16 | - Welcome: index.md 17 | - Environments: home/environments.md 18 | - Install: install/index.md 19 | - Quickstart: 20 | - "Getting started": examples/getting_started.ipynb 21 | - "PPO": examples/ppo.ipynb 22 | # - "Customizing envs": examples/customisation.ipynb 23 | - Benchmarks: 24 | - "Timesteps": benchmarks/timesteps.ipynb 25 | - "Environments": benchmarks/envs.ipynb 26 | - API: api/ 27 | - Changelog: https://github.com/epignatelli/navix/releases 28 | 29 | # Customization 30 | extra: 31 | social: 32 | - icon: fontawesome/brands/github 33 | link: https://github.com/epignatelli/navix 34 | - icon: fontawesome/brands/python 35 | link: https://pypi.org/project/navix/ 36 | - icon: fontawesome/brands/twitter 37 | link: https://twitter.com/edupignatelli 38 | - icon: fontawesome/brands/google-scholar 39 | link: https://github.com/epignatelli/navix 40 | 41 | extra_css: 42 | - assets/stylesheets/extra.css 43 | - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.7/katex.min.css 44 | 45 | extra_javascript: 46 | - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.7/katex.min.js 47 | - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.7/contrib/auto-render.min.js 48 | - https://unpkg.com/mermaid/dist/mermaid.min.js 49 | 50 | theme: 51 | name: "material" 52 | logo: assets/images/navix_logo.png 53 | font: 54 | text: Roboto 55 | code: Roboto Mono 56 | 57 | features: 58 | - announce.dismiss 59 | - content.action.edit 60 | - content.action.view 61 | - content.code.annotate 62 | - content.code.copy 63 | - content.tooltips 64 | - 
navigation.instant 65 | - navigation.footer 66 | - navigation.tabs 67 | - navigation.top 68 | - navigation.path 69 | - search.highlight 70 | - search.share 71 | - search.suggest 72 | - toc.follow 73 | - toc.integrate 74 | 75 | palette: 76 | - scheme: default 77 | primary: pink 78 | accent: red 79 | toggle: 80 | icon: material/weather-night 81 | name: Switch to dark mode 82 | 83 | - scheme: slate 84 | primary: pink 85 | accent: red 86 | toggle: 87 | icon: material/weather-sunny 88 | name: Switch to light mode 89 | 90 | plugins: 91 | - mkdocs-jupyter: 92 | remove_tag_config: 93 | remove_input_tags: 94 | - hide_input 95 | - mkdocstrings: 96 | default_handler: python 97 | handlers: 98 | python: 99 | options: 100 | docstring_style: google 101 | show_bases: true 102 | show_source: false 103 | heading_level: 3 104 | show_root_full_path: true 105 | show_symbol_type_heading: true 106 | show_symbol_type_toc: true 107 | show_signature: true 108 | show_signature_annotations: false 109 | signature_crossrefs: false 110 | - search 111 | - gen-files: 112 | scripts: 113 | - docs/scripts/gen_doc_stubs.py # or any other name or path 114 | - literate-nav: 115 | nav_file: SUMMARY.md 116 | 117 | markdown_extensions: 118 | - toc: 119 | toc_depth: 5 120 | - pymdownx.highlight 121 | - pymdownx.snippets: 122 | check_paths: true 123 | - admonition 124 | - attr_list 125 | - footnotes 126 | - pymdownx.details # For collapsible admonitions 127 | - pymdownx.superfences 128 | 129 | copyright: Copyright © 2023 - 2024 NAVIX Authors 130 | -------------------------------------------------------------------------------- /docs/assets/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | @font-face { 2 | font-family: 'Sage'; 3 | src: url('../fonts/sage-bold.woff2') format('woff2'); 4 | } 5 | 6 | @font-face { 7 | font-family: 'Sherpa'; 8 | src: url('../fonts/sherpa.woff2') format('woff2'); 9 | } 10 | 11 | /* h1 { 12 | font-family: 'Sage', sans-serif
!important; 13 | font-size: 44px !important; 14 | color: #000 !important; 15 | } 16 | 17 | :root { 18 | --md-text-font: "Sherpa"; 19 | } */ 20 | 21 | .md-header__button.md-logo img, 22 | .md-header__button.md-logo svg { 23 | height: 3rem !important; 24 | margin-bottom: 0em; 25 | padding-bottom: 0; 26 | } 27 | 28 | .no-bottom-margin { 29 | margin-bottom: 0 !important; 30 | } 31 | 32 | .maiusc { 33 | text-transform: uppercase; 34 | } 35 | 36 | .center { 37 | display: block; 38 | margin-left: auto; 39 | margin-right: auto; 40 | } 41 | 42 | .doc-attribute { 43 | border-top: 1px solid #ccc; 44 | } 45 | 46 | .doc-method { 47 | border-top: 1px solid #ccc; 48 | } 49 | 50 | .doc-function { 51 | border-top: 1px solid #ccc; 52 | } 53 | 54 | .doc-class { 55 | border-top: 5px solid var(--md-code-bg-color); 56 | } 57 | 58 | /* Remove the `In` and `Out` block in rendered Jupyter notebooks */ 59 | .md-container .jp-Cell-outputWrapper .jp-OutputPrompt.jp-OutputArea-prompt, 60 | .md-container .jp-Cell-inputWrapper .jp-InputPrompt.jp-InputArea-prompt { 61 | display: none !important; 62 | } 63 | 64 | /* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ 65 | html { 66 | scroll-padding-top: 50px; 67 | } 68 | 69 | /* Emphasise sections of nav on left hand side */ 70 | nav.md-nav { 71 | padding-left: 5px; 72 | } 73 | 74 | nav.md-nav--secondary { 75 | border-left: revert !important; 76 | } 77 | 78 | .md-nav__title { 79 | font-size: 0.9rem; 80 | } 81 | 82 | .md-nav__item--section>.md-nav__link { 83 | font-size: 0.9rem; 84 | } 85 | 86 | /* More space at the bottom of the page */ 87 | 88 | .md-main__inner { 89 | margin-bottom: 1.5rem; 90 | } 91 | 92 | 93 | /* Change font sizes */ 94 | html { 95 | /* Decrease font size for overall webpage 96 | Down from 137.5% which is the Material default */ 97 | font-size: 110%; 98 | } 99 | 100 | .md-typeset .admonition { 101 | /* Increase font size in admonitions */ 102 | font-size: 100% !important; 103 | } 104 | 105 | 
.md-typeset details { 106 | /* Increase font size in details */ 107 | font-size: 100% !important; 108 | } 109 | 110 | .md-typeset h1 { 111 | font-size: 1.6rem; 112 | } 113 | 114 | .md-typeset h2 { 115 | font-size: 1.5rem; 116 | } 117 | 118 | .md-typeset h3 { 119 | font-size: 1.3rem; 120 | } 121 | 122 | .md-typeset h4 { 123 | font-size: 1.1rem; 124 | } 125 | 126 | .md-typeset h5 { 127 | font-size: 0.9rem; 128 | } 129 | 130 | .md-typeset h6 { 131 | font-size: 0.8rem; 132 | } 133 | 134 | 135 | /* Highlight functions, classes etc. type signatures. Really helps to make clear where 136 | one item ends and another begins. */ 137 | 138 | [data-md-color-scheme="default"] { 139 | --doc-heading-color: #DDD; 140 | --doc-heading-border-color: #CCC; 141 | --doc-heading-color-alt: #F0F0F0; 142 | } 143 | 144 | [data-md-color-scheme="slate"] { 145 | --doc-heading-color: rgb(25, 25, 33); 146 | --doc-heading-border-color: rgb(25, 25, 33); 147 | --doc-heading-color-alt: rgb(33, 33, 44); 148 | --md-code-bg-color: rgb(38, 38, 50); 149 | } 150 | 151 | h4.doc-heading { 152 | /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ 153 | background-color: var(--doc-heading-color); 154 | border: solid var(--doc-heading-border-color); 155 | border-width: 1.5pt; 156 | border-radius: 2pt; 157 | padding: 0pt 5pt 2pt 5pt; 158 | } 159 | 160 | h5.doc-heading, 161 | h6.heading { 162 | background-color: var(--doc-heading-color-alt); 163 | border-radius: 2pt; 164 | padding: 0pt 5pt 2pt 5pt; 165 | } -------------------------------------------------------------------------------- /navix/environments/dist_shift.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. 
The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from __future__ import annotations 22 | from typing import Union 23 | 24 | import jax.numpy as jnp 25 | from jax import Array 26 | from flax import struct 27 | 28 | from navix import observations, rewards, terminations 29 | 30 | from ..components import EMPTY_POCKET_ID 31 | from ..entities import Entities, Goal, Lava, Player 32 | from ..states import State 33 | from ..grid import room 34 | from ..rendering.cache import RenderingCache 35 | from .environment import Environment, Timestep 36 | from .registry import register_env 37 | 38 | 39 | class DistShift(Environment): 40 | split_lava: bool = struct.field(pytree_node=False, default=False) 41 | 42 | def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep: 43 | # map 44 | grid = room(height=self.height, width=self.width) 45 | 46 | # goal and player 47 | player_pos = jnp.asarray([1, 1]) 48 | direction = jnp.asarray(0) 49 | player = Player.create( 50 | position=player_pos, 51 | direction=direction, 52 | pocket=EMPTY_POCKET_ID, 53 | ) 54 | # goal 55 | goal_pos = jnp.asarray([1, self.width - 2]) 56 | goal = Goal.create(position=goal_pos, probability=jnp.asarray(1.0)) 57 | 58 | # lava 59 | last_row = 5 if self.split_lava else 2 60 | lava_pos = jnp.asarray( 61 | [[1, 3], [1, 4], [1, 5], [last_row, 3], [last_row, 4], [last_row, 5]] 62 | ) 63 | lava = Lava.create(lava_pos) 64 
| 65 | entities = { 66 | Entities.PLAYER: player[None], 67 | Entities.GOAL: goal[None], 68 | Entities.LAVA: lava, 69 | } 70 | 71 | # systems 72 | state = State( 73 | key=key, 74 | grid=grid, 75 | cache=cache or RenderingCache.init(grid), 76 | entities=entities, 77 | ) 78 | 79 | return Timestep( 80 | t=jnp.asarray(0, dtype=jnp.int32), 81 | observation=self.observation_fn(state), 82 | action=jnp.asarray(0, dtype=jnp.int32), 83 | reward=jnp.asarray(0.0, dtype=jnp.float32), 84 | step_type=jnp.asarray(0, dtype=jnp.int32), 85 | state=state, 86 | ) 87 | 88 | 89 | register_env( 90 | "Navix-DistShift1-v0", 91 | lambda *args, **kwargs: DistShift.create( 92 | height=7, 93 | width=9, 94 | split_lava=False, 95 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 96 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 97 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 98 | *args, 99 | **kwargs, 100 | ), 101 | ) 102 | register_env( 103 | "Navix-DistShift2-v0", 104 | lambda *args, **kwargs: DistShift.create( 105 | height=7, 106 | width=9, 107 | split_lava=True, 108 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 109 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 110 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 111 | *args, 112 | **kwargs, 113 | ), 114 | ) 115 | -------------------------------------------------------------------------------- /navix/rendering/cache.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. 
The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from __future__ import annotations 22 | 23 | from typing import Dict, Tuple 24 | 25 | import jax 26 | from jax import Array 27 | import jax.numpy as jnp 28 | from flax import struct 29 | 30 | from .registry import TILE_SIZE, SPRITES_REGISTRY 31 | 32 | 33 | class RenderingCache(struct.PyTreeNode): 34 | patches: Array 35 | """A flat set of patches representing the RGB values of each tile in the base map""" 36 | 37 | @classmethod 38 | def init(cls, grid: Array) -> RenderingCache: 39 | background = render_background(grid) 40 | patches = flatten_patches(background) 41 | 42 | # add discard pile 43 | patches = jnp.concatenate( 44 | [ 45 | patches, 46 | jnp.zeros((1, TILE_SIZE, TILE_SIZE, 3), dtype=jnp.uint8), 47 | ], 48 | axis=0, 49 | ) 50 | return cls(patches=patches) 51 | 52 | 53 | def render_background( 54 | grid: Array, sprites_registry: Dict[str, Array] = SPRITES_REGISTRY 55 | ) -> Array: 56 | image_height = grid.shape[0] * TILE_SIZE # rows of the grid map to image height 57 | image_width = grid.shape[1] * TILE_SIZE # columns of the grid map to image width 58 | n_channels = 3 59 | 60 | background = jnp.zeros((image_height, image_width, n_channels), dtype=jnp.uint8) 61 | grid_resized = jax.image.resize( 62 | grid, (image_height, image_width), method="nearest" 63 | ) 64 | 65 | mask = jnp.asarray(grid_resized, dtype=bool) # 0 = floor, 1 = wall 66 | # index by [entity_type, direction, open/closed, y,
x, channel] 67 | wall_tile = tile_grid(grid, sprites_registry["wall"]) 68 | floor_tile = tile_grid(grid, sprites_registry["floor"]) 69 | background = jnp.where(mask[..., None], wall_tile, floor_tile) 70 | return background 71 | 72 | 73 | def tile_grid(grid: Array, tile: Array) -> Array: 74 | """Tiles a grid (H, W) with copies of a single `tile` (h, w, 3) to get a final array 75 | of shape (H * h, W * w, 3) and dtype `jnp.uint8`""" 76 | tiled = jnp.tile(tile, (*grid.shape, 1)) 77 | return jnp.asarray(tiled, dtype=jnp.uint8) 78 | 79 | 80 | def flatten_patches( 81 | image: Array, patch_size: Tuple[int, int] = (TILE_SIZE, TILE_SIZE) 82 | ) -> Array: 83 | height = image.shape[0] // patch_size[0] 84 | width = image.shape[1] // patch_size[1] 85 | n_channels = image.shape[2] 86 | 87 | grid = image.reshape(height, patch_size[0], width, patch_size[1], n_channels) 88 | 89 | # Swap axes 1 and 2 so that the two grid axes come before the two patch axes 90 | grid = jnp.swapaxes(grid, 1, 2) 91 | 92 | # Collapse the two grid axes into a flat, row-major list of patches 93 | patches = grid.reshape(height * width, patch_size[0], patch_size[1], n_channels) 94 | 95 | return patches 96 | 97 | 98 | def unflatten_patches(patches: Array, image_size: Tuple[int, int]) -> Array: 99 | image_height = image_size[0] 100 | image_width = image_size[1] 101 | patch_height = patches.shape[1] 102 | patch_width = patches.shape[2] 103 | n_channels = patches.shape[3] 104 | 105 | # Reshape the list of tiles into a 2D grid 106 | grid = patches.reshape( 107 | image_height // patch_height, 108 | image_width // patch_width, 109 | patch_height, 110 | patch_width, 111 | n_channels, 112 | ) 113 | 114 | # Swap axes 1 and 2 so that each grid axis is followed by its patch axis 115 | grid = jnp.swapaxes(grid, 1, 2) 116 | 117 | # Merge each grid axis with its patch axis to form the final image 118 | image = grid.reshape(image_height, image_width, n_channels) 119 | 120 | return image 121 |
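A minimal sketch of the patch round trip in `cache.py`, written with NumPy instead of `jax.numpy` so it runs without JAX installed; this assumes `reshape` and `swapaxes` behave the same in both libraries for this use. The tile size `TILE = 2` and the 4x6 test image are hypothetical values chosen for illustration, not the library's `TILE_SIZE` or sprites.

```python
import numpy as np

TILE = 2  # hypothetical tile size for this sketch, standing in for TILE_SIZE


def flatten_patches(image, patch_size=(TILE, TILE)):
    """Split an (H * h, W * w, C) image into a row-major list of (h, w, C) patches."""
    height = image.shape[0] // patch_size[0]
    width = image.shape[1] // patch_size[1]
    n_channels = image.shape[2]
    grid = image.reshape(height, patch_size[0], width, patch_size[1], n_channels)
    grid = np.swapaxes(grid, 1, 2)  # (H, W, h, w, C): grid axes first
    return grid.reshape(height * width, patch_size[0], patch_size[1], n_channels)


def unflatten_patches(patches, image_size):
    """Inverse of flatten_patches: rebuild the (H * h, W * w, C) image."""
    image_height, image_width = image_size
    patch_height, patch_width, n_channels = patches.shape[1:]
    grid = patches.reshape(
        image_height // patch_height,
        image_width // patch_width,
        patch_height,
        patch_width,
        n_channels,
    )
    grid = np.swapaxes(grid, 1, 2)  # (H, h, W, w, C): pair grid and patch axes
    return grid.reshape(image_height, image_width, n_channels)


# A 4x6 RGB image holds a 2x3 grid of 2x2 tiles.
image = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)
patches = flatten_patches(image)
restored = unflatten_patches(patches, image.shape[:2])
```

Under these assumptions, patch `k` corresponds to grid cell `(k // W, k % W)`, so `patches[1]` is the tile covering rows 0 to 1 and columns 2 to 3 of the image, and `restored` matches `image` element for element.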
-------------------------------------------------------------------------------- /navix/environments/four_rooms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | 21 | from __future__ import annotations 22 | from typing import Union 23 | 24 | import jax 25 | import jax.numpy as jnp 26 | from jax import Array 27 | 28 | from navix import observations, rewards, terminations 29 | 30 | from ..components import EMPTY_POCKET_ID 31 | from ..entities import Entities, Goal, Player, Wall 32 | from ..states import State 33 | from ..grid import ( 34 | random_positions, 35 | random_directions, 36 | room, 37 | horizontal_wall, 38 | vertical_wall, 39 | ) 40 | from ..rendering.cache import RenderingCache 41 | from .environment import Environment, Timestep 42 | from .registry import register_env 43 | 44 | 45 | class FourRooms(Environment): 46 | def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep: 47 | assert self.height > 4, f"Insufficient height for room {self.height} <= 4" 48 | assert self.width > 4, f"Insufficient width for room {self.width} <= 4" 49 | key, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7) # one subkey per random draw 50 | 51 | # map 52 | grid = room(height=self.height, width=self.width) 53 | 54 | # vertical partition 55 | opening_1 = jax.random.randint(k1, shape=(), minval=1, maxval=self.height // 2) 56 | opening_2 = jax.random.randint( 57 | k2, shape=(), minval=self.height // 2 + 2, maxval=self.height 58 | ) 59 | openings = jnp.stack([opening_1, opening_2]) 60 | wall_pos_vert = vertical_wall(grid, 9, openings) 61 | 62 | # horizontal partition 63 | opening_1 = jax.random.randint(k3, shape=(), minval=1, maxval=self.width // 2) 64 | opening_2 = jax.random.randint( 65 | k4, shape=(), minval=self.width // 2 + 2, maxval=self.width 66 | ) 67 | openings = jnp.stack([opening_1, opening_2]) 68 | wall_pos_hor = horizontal_wall(grid, 9, openings) 69 | 70 | walls_pos = jnp.concatenate([wall_pos_vert, wall_pos_hor]) 71 | walls = Wall.create(position=walls_pos) 72 | 73 | # player 74 | player_pos, goal_pos = random_positions(k5, grid, n=2, exclude=walls_pos) 75 | direction = random_directions(k6, n=1) 76 | player = Player.create( 77 |
position=player_pos, 78 | direction=direction, 79 | pocket=EMPTY_POCKET_ID, 80 | ) 81 | # goal 82 | goal = Goal.create(position=goal_pos, probability=jnp.asarray(1.0)) 83 | 84 | entities = { 85 | Entities.PLAYER: player[None], 86 | Entities.GOAL: goal[None], 87 | Entities.WALL: walls, 88 | } 89 | 90 | # systems 91 | state = State( 92 | key=key, 93 | grid=grid, 94 | cache=cache or RenderingCache.init(grid), 95 | entities=entities, 96 | ) 97 | 98 | return Timestep( 99 | t=jnp.asarray(0, dtype=jnp.int32), 100 | observation=self.observation_fn(state), 101 | action=jnp.asarray(0, dtype=jnp.int32), 102 | reward=jnp.asarray(0.0, dtype=jnp.float32), 103 | step_type=jnp.asarray(0, dtype=jnp.int32), 104 | state=state, 105 | ) 106 | 107 | 108 | register_env( 109 | "Navix-FourRooms-v0", 110 | lambda *args, **kwargs: FourRooms.create( 111 | height=19, 112 | width=19, 113 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 114 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 115 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 116 | *args, 117 | **kwargs, 118 | ), 119 | ) 120 | -------------------------------------------------------------------------------- /navix/components.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. 
You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from __future__ import annotations 22 | from typing import Tuple 23 | 24 | 25 | from jax import Array 26 | from flax import struct 27 | import jax.numpy as jnp 28 | import dataclasses 29 | 30 | 31 | DISCARD_PILE_COORDS = jnp.asarray((0, -1), dtype=jnp.int32) 32 | DISCARD_PILE_IDX = jnp.asarray(-1, dtype=jnp.int32) 33 | EMPTY_POCKET_ID = jnp.asarray(-1, dtype=jnp.int32) 34 | UNSET_DIRECTION = jnp.asarray(-1, dtype=jnp.int32) 35 | UNSET_CONSUMED = jnp.asarray(-1, dtype=jnp.int32) 36 | 37 | 38 | def field(shape: Tuple[int, ...], **kwargs): 39 | return dataclasses.field(metadata={"shape": shape}, **kwargs) 40 | 41 | 42 | class Component(struct.PyTreeNode): 43 | """Base class for all components in the game. 
44 | Components are used to store the data of the entities in the game.""" 45 | 46 | def check_ndim(self, batched: bool = False) -> None: 47 | return 48 | 49 | 50 | class Positionable(Component): 51 | """Flags an entity as positionable in the grid, and provides the `position` attribute""" 52 | 53 | position: Array = field(shape=(2,)) 54 | """The (row, column) position of the entity in the grid as a JAX array, defaults to the discard pile (-1, -1)""" 55 | 56 | 57 | class Directional(Component): 58 | """Flags an entity as directional, and provides the `direction` attribute""" 59 | 60 | direction: Array = field(shape=()) 61 | """The direction the entity faces: 0 = east, 1 = south, 2 = west, 3 = north""" 62 | 63 | 64 | class HasColour(Component): 65 | """Flags an entity as having a colour, and provides the `colour` attribute""" 66 | 67 | colour: Array = field(shape=()) 68 | """The colour of the object for rendering.""" 69 | 70 | 71 | class Stochastic(Component): 72 | """Flags an entity as stochastic, and provides the `probability` attribute 73 | 74 | TODO: 75 | * consider replacing `probability` (Array) with a distrax.Distribution 76 | 77 | """ 78 | 79 | probability: Array = field(shape=()) 80 | """The probability of receiving the reward, if reached.""" 81 | 82 | 83 | class Openable(Component): 84 | """Flags an entity as openable, and provides the `requires` and `open` attributes""" 85 | 86 | requires: Array = field(shape=()) 87 | """The id of the item required to open this entity. If set, it must be > 0. 88 | If -1, the door is unlocked and does not require any key to open.""" 89 | open: Array = field(shape=()) 90 | """Open is jnp.asarray(0) if the entity is closed and 1 if open.""" 91 | 92 | 93 | class Pickable(Component): 94 | """Flags an entity as pickable, and provides the `id` attribute, which is used to identify the item in the inventory""" 95 | 96 | id: Array = field(shape=()) 97 | """The id of the item.
If set, it must be >= 1.""" 98 | 99 | 100 | class Holder(Component): 101 | """Flags an entity as a holder, and provides the `pocket` attribute. The pocket stores the id of the item currently held.""" 102 | 103 | pocket: Array = field(shape=()) 104 | """The id of the item in the pocket (`EMPTY_POCKET_ID`, i.e. -1, if empty)""" 105 | 106 | 107 | class HasTag(Component): 108 | """Flags an entity as having a tag, and provides the `tag` attribute. The tag is used to identify the type of the entity in the observations.""" 109 | 110 | @property 111 | def tag(self) -> Array: 112 | """The tag of the component, used to identify the type of the component in `observations.categorical`""" 113 | raise NotImplementedError() 114 | 115 | 116 | class HasSprite(Component): 117 | """Flags an entity as having a sprite, and provides the `sprite` attribute. The sprite is used to render the entity in the game.""" 118 | 119 | @property 120 | def sprite(self) -> Array: 121 | raise NotImplementedError() 122 | -------------------------------------------------------------------------------- /navix/terminations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied.
See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | from __future__ import annotations 20 | 21 | from typing import Callable 22 | from jax import Array 23 | import jax.numpy as jnp 24 | 25 | from . import events 26 | from .states import State 27 | 28 | 29 | def compose( 30 | *term_functions: Callable[[State, Array, State], Array], 31 | operator: Callable = jnp.any, 32 | ) -> Callable: 33 | """Compose termination functions into a single termination function. 34 | 35 | Args: 36 | *term_functions (Callable): List of termination functions. 37 | operator (Callable): Operator to combine the termination functions. 38 | 39 | Returns: 40 | Callable: A single termination function.""" 41 | return lambda prev_state, action, state: operator( 42 | jnp.asarray([term_f(prev_state, action, state) for term_f in term_functions]) 43 | ) 44 | 45 | 46 | def check_truncation(terminated: Array, truncated: Array) -> Array: 47 | """Check if the episode is truncated or terminated, and returns a value 48 | that conforms to the `StepType` enum. 49 | 50 | Args: 51 | terminated (Array): A boolean array indicating whether the episode is terminated. 52 | truncated (Array): A boolean array indicating whether the episode is truncated. 53 | 54 | Returns: 55 | Array: An integer array that represents the step type.""" 56 | result = jnp.asarray(truncated + 2 * terminated, dtype=jnp.int32) 57 | return jnp.clip(result, 0, 2) 58 | 59 | 60 | def on_goal_reached(prev_state: State, action: Array, state: State) -> Array: 61 | """Check if the goal has been reached using the `goal_reached` event. 62 | 63 | Args: 64 | prev_state (State): The previous state of the game. 65 | action (Array): The action taken by the player. 66 | state (State): The current state of the game. 
67 | 68 | Returns: 69 | Array: A boolean array indicating whether the goal has been reached.""" 70 | return jnp.asarray(events.on_goal_reached(state), dtype=jnp.bool_) 71 | 72 | 73 | def on_lava_fall(prev_state: State, action: Array, state: State) -> Array: 74 | """Check if the player has fallen into lava using the `lava_fall` event. 75 | 76 | Args: 77 | prev_state (State): The previous state of the game. 78 | action (Array): The action taken by the player. 79 | state (State): The current state of the game. 80 | 81 | Returns: 82 | Array: A boolean array indicating whether the player has fallen into lava.""" 83 | return jnp.asarray(events.on_lava_fall(state), dtype=jnp.bool_) 84 | 85 | 86 | def on_ball_hit(prev_state: State, action: Array, state: State) -> Array: 87 | """Check if the ball has hit something using the `ball_hit` event. 88 | 89 | Args: 90 | prev_state (State): The previous state of the game. 91 | action (Array): The action taken by the player. 92 | state (State): The current state of the game. 93 | 94 | Returns: 95 | Array: A boolean array indicating whether the ball has hit something.""" 96 | return jnp.asarray(events.on_ball_hit(state), dtype=jnp.bool_) 97 | 98 | 99 | def on_door_done(prev_state: State, action: Array, state: State) -> Array: 100 | """Check if the action `done` has been called in front of a `Door` object with the \ 101 | correct colour. 102 | 103 | Args: 104 | prev_state (State): The previous state of the game. 105 | action (Array): The action taken by the player. 106 | state (State): The current state of the game. 107 | 108 | Returns: 109 | Array: A boolean array indicating whether the action `done` has been called in \ 110 | front of a `Door` object with the correct colour.
111 | """ 112 | return jnp.asarray(events.on_door_done(state), dtype=jnp.bool_) 113 | 114 | 115 | DEFAULT_TERMINATION = compose(on_goal_reached, on_lava_fall, on_ball_hit) 116 | -------------------------------------------------------------------------------- /navix/spaces.py: -------------------------------------------------------------------------------- 1 | # Copyright [2023] The Helx Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from __future__ import annotations 17 | from typing import Tuple 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | from jax import Array 22 | from flax import struct 23 | 24 | Shape = Tuple[int, ...] 25 | 26 | 27 | class Space(struct.PyTreeNode): 28 | """Base class for all spaces in the game. Spaces define the shape and type of the \ 29 | observations, actions and rewards in the game. 30 | The `sample` method is used to generate random samples from the space. 31 | 32 | !!! note 33 | To initialize a space, use the `create` method of the specific space class. 34 | 35 | TODO: 36 | * maximum and minimum should be static objects, not arrays. 37 | But how do we handle the case when they are not scalars? Maybe numpy arrays?""" 38 | 39 | shape: Shape = struct.field(pytree_node=False) 40 | dtype: jnp.dtype = struct.field(pytree_node=False) 41 | minimum: Array 42 | maximum: Array 43 | 44 | def sample(self, key: Array) -> Array: 45 | """Generate a random sample from the space. 
46 | 47 | Args: 48 | key (Array): A random key to generate the sample. 49 | 50 | Returns: 51 | Array: A random sample from the space.""" 52 | raise NotImplementedError() 53 | 54 | 55 | class Discrete(Space): 56 | @classmethod 57 | def create( 58 | cls, n_elements: int | jax.Array, shape: Shape = (), dtype=jnp.int32 59 | ) -> Discrete: 60 | """Create a discrete space with a given number of elements. 61 | 62 | Args: 63 | n_elements (int | jax.Array): The number of elements in the space. 64 | shape (Shape): The shape of the space. 65 | dtype (jnp.dtype): The data type of the space. 66 | 67 | Returns: 68 | Discrete: A discrete space with the given number of elements.""" 69 | return Discrete( 70 | shape=shape, 71 | dtype=dtype, 72 | minimum=jnp.asarray(0), 73 | maximum=jnp.asarray(n_elements) - 1, 74 | ) 75 | 76 | def sample(self, key: Array) -> Array: 77 | """Generate a random sample from the space. 78 | 79 | Args: 80 | key (Array): A random key to generate the sample. 81 | 82 | Returns: 83 | Array: A random sample from the space.""" 84 | item = jax.random.randint(key, self.shape, self.minimum, self.maximum + 1) 85 | # maxval is exclusive, hence the +1; randint cannot draw jnp.uint, so we cast it later 86 | return jnp.asarray(item, dtype=self.dtype) 87 | 88 | @property 89 | def n(self) -> Array: 90 | """The number of elements in the space. 91 | 92 | Returns: 93 | Array: The number of elements in the space.""" 94 | return self.maximum + 1 95 | 96 | 97 | class Continuous(Space): 98 | @classmethod 99 | def create( 100 | cls, shape: Shape, minimum: Array, maximum: Array, dtype=jnp.float32 101 | ) -> Continuous: 102 | """Create a continuous space with a given shape, minimum and maximum values. 103 | 104 | Args: 105 | shape (Shape): The shape of the space. 106 | minimum (Array): The minimum value of the space. 107 | maximum (Array): The maximum value of the space. 108 | dtype (jnp.dtype): The data type of the space. 
109 | 110 | Returns: 111 | Continuous: A continuous space with the given shape, minimum and maximum values. 112 | """ 113 | return Continuous(shape=shape, dtype=dtype, minimum=minimum, maximum=maximum) 114 | 115 | def sample(self, key: Array) -> Array: 116 | """Generate a random sample from the space. 117 | 118 | Args: 119 | key (Array): A random key to generate the sample. 120 | 121 | Returns: 122 | Array: A random sample from the space.""" 123 | assert jnp.issubdtype(self.dtype, jnp.floating) 124 | # see: https://github.com/google/jax/issues/14003 125 | lower = jnp.nan_to_num(self.minimum) 126 | upper = jnp.nan_to_num(self.maximum) 127 | return jax.random.uniform( 128 | key, self.shape, minval=lower, maxval=upper, dtype=self.dtype 129 | ) 130 | -------------------------------------------------------------------------------- /navix/rendering/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | 21 | from __future__ import annotations 22 | 23 | import os 24 | from PIL import Image 25 | 26 | import jax 27 | from jax import Array 28 | import jax.numpy as jnp 29 | 30 | 31 | SPRITES_DIR = os.path.normpath( 32 | os.path.join(__file__, "..", "..", "..", "assets", "sprites") 33 | ) 34 | MIN_TILE_SIZE = 8 35 | TILE_SIZE = MIN_TILE_SIZE 36 | 37 | 38 | def load_sprite(name: str) -> Array: 39 | """Loads a sprite from `SPRITES_DIR` in RGB space and resizes it to the tile size. 40 | Args: 41 | name (str): the name of the sprite, i.e. its filename without the `.png` extension 42 | 43 | Returns: 44 | (Array): a jax.Array of shape (TILE_SIZE, TILE_SIZE, 3)""" 45 | path = os.path.join(SPRITES_DIR, f"{name}.png") 46 | image = Image.open(path) 47 | array = jnp.asarray(image) 48 | resized = jax.image.resize(array, (TILE_SIZE, TILE_SIZE, 3), method="cubic") 49 | return jnp.asarray(resized, dtype=jnp.uint8) 50 | 51 | 52 | class PALETTE: 53 | RED: Array = jnp.asarray(0, dtype=jnp.uint8) 54 | GREEN: Array = jnp.asarray(1, dtype=jnp.uint8) 55 | BLUE: Array = jnp.asarray(2, dtype=jnp.uint8) 56 | PURPLE: Array = jnp.asarray(3, dtype=jnp.uint8) 57 | YELLOW: Array = jnp.asarray(4, dtype=jnp.uint8) 58 | GREY: Array = jnp.asarray(5, dtype=jnp.uint8) 59 | UNSET: Array = jnp.asarray(255, dtype=jnp.uint8) 60 | 61 | @classmethod 62 | def as_string(cls): 63 | return ["red", "green", "blue", "purple", "yellow", "grey"] 64 | 65 | @classmethod 66 | def as_array(cls): 67 | return [cls.RED, cls.GREEN, cls.BLUE, cls.PURPLE, cls.YELLOW, cls.GREY] 68 | 69 | 70 | class SpritesRegistry: 71 | def __init__(self): 72 | self.registry = {} 73 | self.build_registry() 74 | 75 | def build_registry(self): 76 | """Populates the sprites registry for all entities.""" 77 | self.set_wall_sprite() 78 | self.set_floor_sprite() 79 | self.set_goal_sprite() 80 | self.set_key_sprite() 81 | self.set_player_sprite() 82 | self.set_door_sprite() 83 | self.set_lava_sprite() 84 | self.set_ball_sprite() 85 | self.set_box_sprite() 86 | 87 | def set_wall_sprite(self): 88 | self.registry["wall"] = 
load_sprite("wall") 89 | 90 | def set_floor_sprite(self): 91 | self.registry["floor"] = load_sprite("floor") 92 | 93 | def set_goal_sprite(self): 94 | self.registry["goal"] = load_sprite("goal") 95 | 96 | def set_key_sprite(self): 97 | keys_coloured = [ 98 | load_sprite("key" + f"_{colour}") for colour in PALETTE.as_string() 99 | ] 100 | self.registry["key"] = jnp.stack(keys_coloured, axis=0) 101 | 102 | def set_player_sprite(self): 103 | self.registry["player"] = jnp.stack( 104 | [ 105 | load_sprite("player_east"), 106 | load_sprite("player_south"), 107 | load_sprite("player_west"), 108 | load_sprite("player_north"), 109 | ] 110 | ) 111 | 112 | def set_door_sprite(self): 113 | door = jnp.zeros( 114 | (len(PALETTE.as_string()), 3, TILE_SIZE, TILE_SIZE, 3), dtype=jnp.uint8 115 | ) 116 | for c_idx, colour in enumerate(PALETTE.as_string()): 117 | for s_idx, state in enumerate(["closed", "open", "locked"]): 118 | sprite = load_sprite("door" + f"_{state}" + f"_{colour}") 119 | door = door.at[c_idx, s_idx].set(sprite) 120 | self.registry["door"] = door 121 | 122 | def set_lava_sprite(self): 123 | self.registry["lava"] = load_sprite("lava") 124 | 125 | def set_ball_sprite(self): 126 | ball_coloured = [ 127 | load_sprite("ball" + f"_{colour}") for colour in PALETTE.as_string() 128 | ] 129 | self.registry["ball"] = jnp.stack(ball_coloured, axis=0) 130 | 131 | def set_box_sprite(self): 132 | box_coloured = [ 133 | load_sprite("box" + f"_{colour}") for colour in PALETTE.as_string() 134 | ] 135 | self.registry["box"] = jnp.stack(box_coloured, axis=0) 136 | 137 | 138 | # initialise sprites registry 139 | SPRITES_REGISTRY = SpritesRegistry().registry 140 | -------------------------------------------------------------------------------- /navix/transitions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 
2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from __future__ import annotations 22 | 23 | from typing import Callable, Tuple 24 | from jax import Array 25 | import jax 26 | import jax.numpy as jnp 27 | from .entities import Entities, Ball 28 | from .states import EventsManager, State 29 | from .grid import positions_equal, translate 30 | 31 | 32 | def deterministic_transition( 33 | state: State, action: Array, actions_set: Tuple[Callable[[State], State], ...] 34 | ) -> State: 35 | """Deterministic transition function. It selects the action from the set of actions 36 | and applies it to the state. 37 | 38 | Args: 39 | state (State): The current state of the game. 40 | action (Array): The action to be taken. 41 | actions_set (Tuple[Callable[[State], State]): A set of actions that can be taken. 42 | 43 | Returns: 44 | State: The new state of the game.""" 45 | return jax.lax.switch(action, actions_set, state) 46 | 47 | 48 | def stochastic_transition( 49 | state: State, action: Array, actions_set: Tuple[Callable[[State], State], ...] 50 | ) -> State: 51 | """Stochastic transition function. 
It selects the action from the set of actions 52 | and applies it to the state, and updates entities that have stochastic transitions, 53 | such as balls. 54 | 55 | Args: 56 | state (State): The current state of the game. 57 | action (Array): The action to be taken. 58 | actions_set (Tuple[Callable[[State], State]): A set of actions that can be taken. 59 | 60 | Returns: 61 | State: The new state of the game.""" 62 | # actions 63 | state = jax.lax.switch(action, actions_set, state) 64 | 65 | state = update_balls(state) 66 | return state 67 | 68 | 69 | def update_balls(state: State) -> State: 70 | """Update the position of the balls in the game. 71 | Balls move in a random direction if they can, otherwise they stay in place. 72 | 73 | Args: 74 | state (State): The current state of the game. 75 | 76 | Returns: 77 | State: The new state of the game.""" 78 | def update_one(ball: Ball, key: Array) -> Tuple[Array, EventsManager]: 79 | direction = jax.random.randint(key, (), minval=0, maxval=4) 80 | new_position = translate(ball.position, direction) 81 | new_ball = ball.replace(position=new_position) 82 | can_move, events = _can_spawn_there(state, new_ball) 83 | return jnp.where(can_move, new_ball.position, ball.position), events 84 | 85 | if Entities.BALL in state.entities: 86 | balls: Ball = state.entities[Entities.BALL] # type: ignore 87 | keys = jax.random.split(state.key, len(balls.position) + 1) 88 | new_position, new_events = jax.jit(jax.vmap(update_one))(balls, keys[1:]) 89 | # update balls 90 | balls = balls.replace(position=new_position) 91 | state = state.set_balls(balls) 92 | # update events 93 | # take only the first happened event (even if happened already) 94 | idx = jnp.where(new_events.ball_hit.happened, size=1)[0][0] # scalar 95 | ball_hits = jax.tree.map(lambda x: x[idx], new_events.ball_hit) 96 | events = state.events.replace(ball_hit=ball_hits) 97 | state = state.replace(key=keys[0], events=events) 98 | return state 99 | 100 | 101 | def 
_can_spawn_there(state: State, ball: Ball) -> Tuple[Array, EventsManager]: 102 | # according to the grid 103 | walkable = jnp.equal(state.grid[tuple(ball.position)], 0) 104 | 105 | # according to entities 106 | events = state.events 107 | entities = state.entities 108 | for k in state.entities: 109 | obstructs = positions_equal(entities[k].position, ball.position)[0] 110 | if k == Entities.PLAYER: 111 | events = jax.lax.cond( 112 | obstructs, 113 | lambda x: x.record_ball_hit(ball), 114 | lambda x: x, 115 | events, 116 | ) 117 | walkable = jnp.logical_and(walkable, jnp.any(jnp.logical_not(obstructs))) 118 | return jnp.asarray(walkable, dtype=jnp.bool_), events 119 | 120 | 121 | DEFAULT_TRANSITION = stochastic_transition 122 | -------------------------------------------------------------------------------- /tests/test_observations.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | import navix as nx 7 | from navix.states import State 8 | from navix.entities import Entities, Player, Goal, Key, Door 9 | from navix.components import EMPTY_POCKET_ID 10 | from navix.rendering.cache import RenderingCache, TILE_SIZE 11 | from navix.rendering.registry import SPRITES_REGISTRY, PALETTE 12 | 13 | 14 | def test_rgb(): 15 | height = 10 16 | width = 10 17 | grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32) 18 | grid = jnp.pad(grid, 1, mode="constant", constant_values=-1) 19 | 20 | players = Player( 21 | position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID 22 | ) 23 | goals = Goal.create(position=jnp.asarray((4, 4)), probability=jnp.asarray(1.0)) 24 | keys = Key(position=jnp.asarray((2, 2)), id=jnp.asarray(0), colour=PALETTE.YELLOW) 25 | doors = Door( 26 | position=jnp.asarray([(1, 5), (1, 6)]), 27 | requires=jnp.asarray((0, 0)), 28 | open=jnp.asarray((False, True)), 29 | colour=PALETTE.YELLOW[None], 30 | ) 31 | 32 | entities = { 33 | 
Entities.PLAYER: players[None], 34 | Entities.GOAL: goals[None], 35 | Entities.KEY: keys[None], 36 | Entities.DOOR: doors, 37 | } 38 | 39 | state = State( 40 | key=jax.random.PRNGKey(0), 41 | grid=grid, 42 | cache=RenderingCache.init(grid), 43 | entities=entities, 44 | ) 45 | sprites_registry = SPRITES_REGISTRY 46 | 47 | doors = state.get_doors() 48 | doors = doors.replace(open=jnp.asarray((False, True))) 49 | state.entities[Entities.DOOR] = doors 50 | 51 | obs = nx.observations.rgb(state) 52 | expected_obs_shape = ( 53 | height * TILE_SIZE, 54 | width * TILE_SIZE, 55 | 3, 56 | ) 57 | assert ( 58 | obs.shape == expected_obs_shape 59 | ), f"Expected observation {expected_obs_shape}, got {obs.shape} instead" 60 | 61 | def get_tile(position): 62 | x = position[0] * TILE_SIZE 63 | y = position[1] * TILE_SIZE 64 | return obs[x : x + TILE_SIZE, y : y + TILE_SIZE, :] 65 | 66 | player = state.get_player() 67 | player_tile = get_tile(player.position) 68 | assert jnp.array_equal( 69 | player_tile, sprites_registry[Entities.PLAYER][player.direction] 70 | ), player_tile 71 | 72 | goals = state.get_goals() 73 | goal_tile = get_tile(goals.position[0]) 74 | assert jnp.array_equal(goal_tile, sprites_registry[Entities.GOAL]), goal_tile 75 | 76 | keys = state.get_keys() 77 | key_tile = get_tile(keys.position[0]) 78 | colour = keys.colour[0] 79 | assert jnp.array_equal(key_tile, sprites_registry[Entities.KEY][colour]), key_tile 80 | 81 | doors = state.get_doors() 82 | door = doors[0] 83 | door_tile = get_tile(door.position) 84 | colour = door.colour 85 | idx = jnp.asarray(door.open + 2 * door.locked, dtype=jnp.int32) 86 | assert jnp.array_equal( 87 | door_tile, sprites_registry[Entities.DOOR][colour, idx] 88 | ), door_tile 89 | 90 | door = doors[1] 91 | door_tile = get_tile(door.position) 92 | colour = door.colour 93 | idx = jnp.asarray(door.open + 2 * door.locked, dtype=jnp.int32) 94 | assert jnp.array_equal( 95 | door_tile, sprites_registry[Entities.DOOR][colour, idx] 96 | ), 
door_tile 97 | 98 | return 99 | 100 | 101 | def test_categorical_first_person(): 102 | height = 10 103 | width = 10 104 | grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32) 105 | grid = jnp.pad(grid, 1, mode="constant", constant_values=-1) 106 | 107 | players = Player( 108 | position=jnp.asarray((1, 1)), direction=jnp.asarray(0), pocket=EMPTY_POCKET_ID 109 | ) 110 | goals = Goal.create(position=jnp.asarray((4, 4)), probability=jnp.asarray(1.0)) 111 | keys = Key(position=jnp.asarray((2, 2)), id=jnp.asarray(0), colour=PALETTE.YELLOW) 112 | doors = Door( 113 | position=jnp.asarray([(1, 5), (1, 6)]), 114 | requires=jnp.asarray((0, 0)), 115 | open=jnp.asarray((False, True)), 116 | colour=PALETTE.YELLOW, 117 | ) 118 | entities = { 119 | Entities.PLAYER: players[None], 120 | Entities.GOAL: goals[None], 121 | Entities.KEY: keys[None], 122 | Entities.DOOR: doors, 123 | } 124 | 125 | state = State( 126 | key=jax.random.PRNGKey(0), 127 | grid=grid, 128 | cache=RenderingCache.init(grid), 129 | entities=entities, 130 | ) 131 | 132 | obs = nx.observations.categorical_first_person(state) 133 | print(obs) 134 | 135 | 136 | def test_rgb_first_person(): 137 | import gymnasium as gym 138 | import minigrid 139 | 140 | navix_env_id = "Navix-Empty-8x8-v0" 141 | gym_env_id = navix_env_id.replace("Navix", "MiniGrid") 142 | 143 | env = nx.make(navix_env_id, observation_fn=nx.observations.rgb_first_person) 144 | timestep = env.reset(jax.random.PRNGKey(0)) 145 | 146 | env = gym.make(gym_env_id) 147 | env = minigrid.wrappers.RGBImgPartialObsWrapper(env) 148 | obs, _ = env.reset() 149 | obs = obs["image"] 150 | 151 | 152 | if __name__ == "__main__": 153 | test_rgb() 154 | # test_categorical_first_person() 155 | # jax.jit(test_categorical_first_person)() 156 | -------------------------------------------------------------------------------- /examples/purejaxrl/wrappers.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy 
as jnp 3 | import chex 4 | import numpy as np 5 | from flax import struct 6 | from functools import partial 7 | from typing import Optional, Tuple, Union, Any 8 | from gymnax.environments import environment, spaces 9 | import navix as nx 10 | 11 | 12 | class GymnaxWrapper(object): 13 | """Base class for Gymnax wrappers.""" 14 | 15 | def __init__(self, env): 16 | self._env = env 17 | 18 | # provide proxy access to regular attributes of wrapped object 19 | def __getattr__(self, name): 20 | return getattr(self._env, name) 21 | 22 | 23 | class FlattenObservationWrapper(GymnaxWrapper): 24 | """Flatten the observations of the environment.""" 25 | 26 | def __init__(self, env: environment.Environment): 27 | super().__init__(env) 28 | 29 | def observation_space(self, params) -> spaces.Box: 30 | assert isinstance( 31 | self._env.observation_space(params), spaces.Box 32 | ), "Only Box spaces are supported for now." 33 | return spaces.Box( 34 | low=self._env.observation_space(params).low, 35 | high=self._env.observation_space(params).high, 36 | shape=(np.prod(self._env.observation_space(params).shape),), 37 | dtype=self._env.observation_space(params).dtype, 38 | ) 39 | 40 | @partial(jax.jit, static_argnums=(0,)) 41 | def reset( 42 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None 43 | ) -> Tuple[chex.Array, environment.EnvState]: 44 | obs, state = self._env.reset(key, params) 45 | obs = jnp.reshape(obs, (-1,)) 46 | return obs, state 47 | 48 | @partial(jax.jit, static_argnums=(0,)) 49 | def step( 50 | self, 51 | key: chex.PRNGKey, 52 | state: environment.EnvState, 53 | action: Union[int, float], 54 | params: Optional[environment.EnvParams] = None, 55 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: 56 | obs, state, reward, done, info = self._env.step(key, state, action, params) 57 | obs = jnp.reshape(obs, (-1,)) 58 | return obs, state, reward, done, info 59 | 60 | 61 | @struct.dataclass 62 | class LogEnvState: 63 | env_state: 
environment.EnvState 64 | episode_returns: float 65 | episode_lengths: int 66 | returned_episode_returns: float 67 | returned_episode_lengths: int 68 | timestep: int 69 | 70 | 71 | class LogWrapper(GymnaxWrapper): 72 | """Log the episode returns and lengths.""" 73 | 74 | def __init__(self, env: environment.Environment): 75 | super().__init__(env) 76 | 77 | @partial(jax.jit, static_argnums=(0,)) 78 | def reset( 79 | self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None 80 | ) -> Tuple[chex.Array, environment.EnvState]: 81 | obs, env_state = self._env.reset(key, params) 82 | state = LogEnvState(env_state, 0, 0, 0, 0, 0) 83 | return obs, state 84 | 85 | @partial(jax.jit, static_argnums=(0,)) 86 | def step( 87 | self, 88 | key: chex.PRNGKey, 89 | state: environment.EnvState, 90 | action: Union[int, float], 91 | params: Optional[environment.EnvParams] = None, 92 | ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: 93 | obs, env_state, reward, done, info = self._env.step( 94 | key, state.env_state, action, params 95 | ) 96 | new_episode_return = state.episode_returns + reward 97 | new_episode_length = state.episode_lengths + 1 98 | state = LogEnvState( 99 | env_state=env_state, 100 | episode_returns=new_episode_return * (1 - done), 101 | episode_lengths=new_episode_length * (1 - done), 102 | returned_episode_returns=state.returned_episode_returns * (1 - done) 103 | + new_episode_return * done, 104 | returned_episode_lengths=state.returned_episode_lengths * (1 - done) 105 | + new_episode_length * done, 106 | timestep=state.timestep + 1, 107 | ) 108 | info["returned_episode_returns"] = state.returned_episode_returns 109 | info["returned_episode_lengths"] = state.returned_episode_lengths 110 | info["timestep"] = state.timestep 111 | info["returned_episode"] = done 112 | return obs, state, reward, done, info 113 | 114 | class NavixGymnaxWrapper: 115 | def __init__(self, env_name): 116 | self._env = nx.make(env_name) 117 | 118 | def reset(self, 
key, params=None): 119 | timestep = self._env.reset(key) 120 | return timestep.observation, timestep 121 | 122 | def step(self, key, state, action, params=None): 123 | timestep = self._env.step(state, action) 124 | return timestep.observation, timestep, timestep.reward, timestep.is_done(), {} 125 | 126 | def observation_space(self, params): 127 | return spaces.Box( 128 | low=self._env.observation_space.minimum, 129 | high=self._env.observation_space.maximum, 130 | shape=(np.prod(self._env.observation_space.shape),), 131 | dtype=self._env.observation_space.dtype, 132 | ) 133 | 134 | def action_space(self, params): 135 | return spaces.Discrete( 136 | num_categories=self._env.action_space.maximum.item() + 1, 137 | ) 138 | 139 | -------------------------------------------------------------------------------- /navix/environments/lava_gap.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | 21 | from typing import Union 22 | 23 | import jax 24 | import jax.numpy as jnp 25 | from jax import Array 26 | 27 | from navix import observations, rewards, terminations 28 | 29 | from ..components import EMPTY_POCKET_ID 30 | from ..rendering.cache import RenderingCache 31 | from . import Environment 32 | from ..entities import Player, Goal, Lava 33 | from ..states import State 34 | from . import Timestep 35 | from ..grid import room 36 | from .registry import register_env 37 | 38 | 39 | class LavaGap(Environment): 40 | def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep: 41 | # check minimum height and width 42 | assert ( 43 | self.height > 3 44 | ), f"Room height must be greater than 3, got {self.height} instead" 45 | assert ( 46 | self.width > 4 47 | ), f"Room width must be greater than 4, got {self.width} instead" 48 | 49 | key, k1, k2 = jax.random.split(key, num=3) 50 | 51 | grid = room(height=self.height, width=self.width) 52 | 53 | # player 54 | player_pos = jnp.asarray([1, 1]) 55 | player_dir = jnp.asarray(0) 56 | player = Player.create( 57 | position=player_pos, direction=player_dir, pocket=EMPTY_POCKET_ID 58 | ) 59 | # goal 60 | goal_pos = jnp.asarray([self.height - 2, self.width - 2]) 61 | goals = Goal.create(position=goal_pos, probability=jnp.asarray(1.0)) 62 | 63 | # lava positions 64 | gap_row = jax.random.randint(k1, (), 1, self.height - 1) # row of the gap 65 | 66 | col = jax.random.randint(k2, (), minval=2, maxval=self.width - 2) 67 | lava_row = jnp.arange(1, self.height - 1) 68 | lava_cols = jnp.asarray([col] * (self.height - 2)) 69 | lava_pos = jnp.stack((lava_row, lava_cols), axis=1) 70 | # remove lava where the gap is 71 | lava_pos = jnp.delete(lava_pos, gap_row - 1, axis=0, assume_unique_indices=True) 72 | lavas = Lava.create(position=lava_pos) 73 | 74 | entities = { 75 | "player": player[None], 76 | "goal": goals[None], 77 | "lava": lavas, 78 | } 79 | 80 | state = State( 81 | key=key, 82 | grid=grid, 83 | 
cache=cache or RenderingCache.init(grid), 84 | entities=entities, 85 | ) 86 | return Timestep( 87 | t=jnp.asarray(0, dtype=jnp.int32), 88 | observation=self.observation_fn(state), 89 | action=jnp.asarray(-1, dtype=jnp.int32), 90 | reward=jnp.asarray(0.0, dtype=jnp.float32), 91 | step_type=jnp.asarray(0, dtype=jnp.int32), 92 | state=state, 93 | ) 94 | 95 | 96 | register_env( 97 | "Navix-LavaGapS5-v0", 98 | lambda *args, **kwargs: LavaGap.create( 99 | height=5, 100 | width=5, 101 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 102 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 103 | termination_fn=kwargs.pop( 104 | "termination_fn", 105 | terminations.compose( 106 | terminations.on_goal_reached, 107 | terminations.on_lava_fall, 108 | ), 109 | ), 110 | *args, 111 | **kwargs, 112 | ), 113 | ) 114 | register_env( 115 | "Navix-LavaGapS6-v0", 116 | lambda *args, **kwargs: LavaGap.create( 117 | height=6, 118 | width=6, 119 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 120 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 121 | termination_fn=kwargs.pop( 122 | "termination_fn", 123 | terminations.compose( 124 | terminations.on_goal_reached, 125 | terminations.on_lava_fall, 126 | ), 127 | ), 128 | *args, 129 | **kwargs, 130 | ), 131 | ) 132 | register_env( 133 | "Navix-LavaGapS7-v0", 134 | lambda *args, **kwargs: LavaGap.create( 135 | height=7, 136 | width=7, 137 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 138 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 139 | termination_fn=kwargs.pop( 140 | "termination_fn", 141 | terminations.compose( 142 | terminations.on_goal_reached, 143 | terminations.on_lava_fall, 144 | ), 145 | ), 146 | *args, 147 | **kwargs, 148 | ), 149 | ) 150 | -------------------------------------------------------------------------------- /navix/rewards.py: -------------------------------------------------------------------------------- 1 | # 
Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | from __future__ import annotations 20 | from typing import Callable 21 | 22 | 23 | import jax.numpy as jnp 24 | from jax import Array 25 | 26 | from . import events 27 | from .states import State 28 | 29 | 30 | def compose( 31 | *reward_functions: Callable[[State, Array, State], Array], 32 | operator: Callable = jnp.sum, 33 | ) -> Callable: 34 | """Compose multiple reward functions into a single reward function. 35 | The functions are called in order and the results are reduced using the `operator` \ 36 | function. 37 | 38 | Args: 39 | *reward_functions (Callable[[State, Array, State], Array]): A list of reward functions. 40 | operator (Callable): The operator to reduce the results of the reward functions. 41 | It must be a function that takes a list of arrays, or an array and returns an \ 42 | array of size `f32[]`. 
43 | 44 | Returns: 45 | Callable: A composed reward function that applies the `operator` to the results of the \ 46 | reward functions.""" 47 | return lambda prev_state, action, state: operator( 48 | jnp.asarray( 49 | [f(prev_state, action, state) for f in reward_functions], dtype=jnp.float32 50 | ) 51 | ) 52 | 53 | 54 | def free(state: State) -> Array: 55 | """A reward function that always returns 0, to simulate reward-free learning. 56 | 57 | Args: 58 | state (State): The current state of the game. 59 | 60 | Returns: 61 | Array: A scalar array `f32[]` with value 0.""" 62 | return jnp.asarray(0.0, dtype=jnp.float32) 63 | 64 | 65 | def on_goal_reached(prev_state: State, action: Array, state: State) -> Array: 66 | """A reward function that returns 1 when the goal is reached, and 0 otherwise. 67 | 68 | Args: 69 | state (State): The current state of the game. 70 | 71 | Returns: 72 | Array: A scalar array `f32[]` with value 1 if the goal is reached, and 0 otherwise. 73 | """ 74 | return jnp.asarray(events.on_goal_reached(state), dtype=jnp.float32) 75 | 76 | 77 | def action_cost( 78 | prev_state: State, action: Array, new_state: State, cost: float = 0.01 79 | ) -> Array: 80 | """A reward function that returns a negative value when an action is taken. 81 | All actions have a cost of `cost`, except for noops. 82 | 83 | Args: 84 | prev_state (State): The previous state of the game. 85 | action (Array): The action taken. 86 | new_state (State): The new state of the game. 87 | cost (float): The cost of taking an action. 88 | 89 | Returns: 90 | Array: A scalar array `f32[]` with value -`cost` if the action is not a noop, \ 91 | and 0 otherwise.""" 92 | # noops are free 93 | return -jnp.asarray(action != 6, dtype=jnp.float32) * cost 94 | 95 | 96 | def time_cost( 97 | prev_state: State, action: Array, new_state: State, cost: float = 0.01 98 | ) -> Array: 99 | """A reward function that returns a negative value as time passes, paying a cost \ 100 | of `cost` at each time step. 
101 | 102 | Args: 103 | prev_state (State): The previous state of the game. 104 | action (Array): The action taken. 105 | new_state (State): The new state of the game. 106 | cost (float): The cost of time passing. 107 | 108 | Returns: 109 | Array: A scalar array `f32[]` with value -`cost`. 110 | """ 111 | # time always has a cost 112 | return -jnp.asarray(cost, dtype=jnp.float32) 113 | 114 | 115 | def wall_hit_cost( 116 | prev_state: State, action: Array, state: State, cost: float = 0.01 117 | ) -> Array: 118 | """A reward function that returns a negative value when the agent hits a wall, \ 119 | paying a cost of `cost` for each wall hit. 120 | 121 | Args: 122 | state (State): The current state of the game. 123 | cost (float): The cost of hitting a wall. 124 | 125 | Returns: 126 | Array: A scalar array `f32[]` with value -`cost` if the agent hits a wall, \ 127 | and 0 otherwise.""" 128 | return -jnp.asarray(events.on_wall_hit(state), dtype=jnp.float32) * cost 129 | 130 | 131 | def on_door_done(prev_state: State, action: Array, state: State) -> Array: 132 | """A reward function that returns a positive value when the agent uses the action \ 133 | `done` in front of a door. 134 | 135 | Args: 136 | state (State): The current state of the game. 137 | 138 | Returns: 139 | Array: A scalar array `f32[]` with value 1 if the agent uses the action `done` in \ 140 | front of a door, and 0 otherwise.""" 141 | 142 | return jnp.asarray(events.on_door_done(state), dtype=jnp.float32) 143 | 144 | 145 | DEFAULT_TASK = compose(on_goal_reached, action_cost) 146 | """The default task for the game, composed of the `on_goal_reached` and `action_cost` reward functions.""" 147 | -------------------------------------------------------------------------------- /navix/environments/go_to_door.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors.
2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from __future__ import annotations 22 | from typing import Union 23 | 24 | import jax 25 | import jax.numpy as jnp 26 | from jax import Array 27 | from flax import struct 28 | 29 | from navix import observations 30 | 31 | from .. 
import rewards, terminations 32 | from ..components import EMPTY_POCKET_ID 33 | from ..entities import Entities, Door, Player 34 | from ..states import EventType, State, Event 35 | from ..grid import random_colour, random_directions 36 | from ..rendering.cache import RenderingCache 37 | from .environment import Environment, Timestep 38 | from .registry import register_env 39 | 40 | 41 | class GoToDoor(Environment): 42 | split_lava: bool = struct.field(pytree_node=False, default=False) 43 | 44 | def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep: 45 | # map 46 | grid = jnp.zeros((self.height, self.width), dtype=jnp.int32) 47 | 48 | k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6) 49 | room_height = jax.random.randint(k1, (), minval=5, maxval=self.height) 50 | room_width = jax.random.randint(k6, (), minval=5, maxval=self.width) 51 | 52 | # set wall on grid 53 | grid = grid.at[jnp.asarray([0, room_height - 1])].set(-1) 54 | grid = grid.at[:, jnp.asarray([0, room_width - 1])].set(-1) 55 | 56 | # player 57 | player_row = jax.random.randint(k2, (), minval=1, maxval=room_height - 1) 58 | player_col = jax.random.randint(k3, (), minval=1, maxval=room_width - 1) 59 | player_pos = jnp.asarray([player_row, player_col]) 60 | direction = random_directions(k4) 61 | player = Player( 62 | position=player_pos, 63 | direction=direction, 64 | pocket=EMPTY_POCKET_ID, 65 | ) 66 | 67 | # doors 68 | k6, k7 = jax.random.split(k5, num=2) 69 | rows = jax.random.randint(k6, (2,), minval=2, maxval=room_height - 2) 70 | cols = jax.random.randint(k7, (2,), minval=2, maxval=room_width - 2) 71 | positions = jnp.asarray( 72 | [ 73 | [rows[0], room_width - 1], 74 | [room_height - 1, cols[0]], 75 | [rows[1], 0], 76 | [0, cols[1]], 77 | ] 78 | ) 79 | colours = random_colour(key, n=4) 80 | open = jnp.asarray([0] * 4) 81 | requires = jnp.asarray([-1] * 4) 82 | doors = Door.create( 83 | position=positions, requires=requires, colour=colours, open=open 84 | ) 85 | 86 | 
entities = { 87 | Entities.PLAYER: player[None], 88 | Entities.DOOR: doors, 89 | } 90 | 91 | idx = jax.random.randint(k6, (), minval=0, maxval=4) 92 | target_door = doors[idx] 93 | mission = Event( 94 | position=target_door.position, 95 | colour=target_door.colour, 96 | happened=jnp.asarray(False), 97 | event_type=EventType.REACH, 98 | ) 99 | 100 | # systems 101 | state = State( 102 | key=key, 103 | grid=grid, 104 | cache=RenderingCache.init(grid), 105 | entities=entities, 106 | mission=mission, 107 | ) 108 | 109 | return Timestep( 110 | t=jnp.asarray(0, dtype=jnp.int32), 111 | observation=self.observation_fn(state), 112 | action=jnp.asarray(0, dtype=jnp.int32), 113 | reward=jnp.asarray(0.0, dtype=jnp.float32), 114 | step_type=jnp.asarray(0, dtype=jnp.int32), 115 | state=state, 116 | ) 117 | 118 | 119 | register_env( 120 | "Navix-GoToDoor-5x5-v0", 121 | lambda *args, **kwargs: GoToDoor.create( 122 | height=5, 123 | width=5, 124 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 125 | reward_fn=kwargs.pop("reward_fn", rewards.on_door_done), 126 | termination_fn=kwargs.pop("termination_fn", terminations.on_door_done), 127 | *args, 128 | **kwargs, 129 | ), 130 | ) 131 | register_env( 132 | "Navix-GoToDoor-6x6-v0", 133 | lambda *args, **kwargs: GoToDoor.create( 134 | height=6, 135 | width=6, 136 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 137 | reward_fn=kwargs.pop("reward_fn", rewards.on_door_done), 138 | termination_fn=kwargs.pop("termination_fn", terminations.on_door_done), 139 | *args, 140 | **kwargs, 141 | ), 142 | ) 143 | register_env( 144 | "Navix-GoToDoor-8x8-v0", 145 | lambda *args, **kwargs: GoToDoor.create( 146 | height=8, 147 | width=8, 148 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 149 | reward_fn=kwargs.pop("reward_fn", rewards.on_door_done), 150 | termination_fn=kwargs.pop("termination_fn", terminations.on_door_done), 151 | *args, 152 | **kwargs, 153 | ), 154 | ) 155 | 
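The `register_env` calls above all follow one pattern: a lambda fixes the grid size and uses `kwargs.pop(name, default)` so that a caller can override `observation_fn`, `reward_fn` or `termination_fn` without the same keyword being passed twice through `**kwargs`. The following is a minimal, dependency-free sketch of that pattern only. The names `_REGISTRY`, `make`, `create` and the `"Toy-GoToDoor-5x5-v0"` id are hypothetical stand-ins introduced for illustration, and the real `navix` registry and `Environment.create` may behave differently.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in for the registry; the real navix registry may differ.
_REGISTRY: Dict[str, Callable[..., Dict[str, Any]]] = {}


def register_env(name: str, ctor: Callable[..., Dict[str, Any]]) -> None:
    """Store a constructor under a string id."""
    _REGISTRY[name] = ctor


def make(name: str, **kwargs: Any) -> Dict[str, Any]:
    """Look up the constructor and forward any user overrides."""
    return _REGISTRY[name](**kwargs)


def create(**config: Any) -> Dict[str, Any]:
    """Toy replacement for `Environment.create`: it just returns its config."""
    return config


register_env(
    "Toy-GoToDoor-5x5-v0",
    lambda **kwargs: create(
        height=5,
        width=5,
        # pop() removes the key, so it is not passed a second time via **kwargs
        reward_fn=kwargs.pop("reward_fn", "on_door_done"),
        **kwargs,
    ),
)

# No override: the default reward is used.
default_env = make("Toy-GoToDoor-5x5-v0")
# Override the reward and pass an extra option through untouched.
custom_env = make("Toy-GoToDoor-5x5-v0", reward_fn="time_cost", max_steps=100)
```

If the lambda used `kwargs.get` instead of `kwargs.pop`, an override would raise a `TypeError` for a repeated keyword argument, which is presumably why the registrations pop the key before unpacking.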
-------------------------------------------------------------------------------- /navix/environments/dynamic_obstacles.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | 21 | from __future__ import annotations 22 | from typing import Union 23 | 24 | import jax 25 | import jax.numpy as jnp 26 | from jax import Array 27 | from flax import struct 28 | 29 | from navix import observations, rewards, terminations 30 | 31 | from ..components import EMPTY_POCKET_ID 32 | from ..entities import Entities, Goal, Player, Ball 33 | from ..states import State 34 | from ..grid import random_positions, random_directions, room 35 | from ..rendering.cache import RenderingCache 36 | from ..rendering.registry import PALETTE 37 | from .environment import Environment, Timestep 38 | from .registry import register_env 39 | 40 | 41 | class DynamicObstacles(Environment): 42 | random_start: bool = struct.field(pytree_node=False, default=False) 43 | n_obstacles: int = struct.field(pytree_node=False, default=2) 44 | 45 | def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep: 46 | key, k1, k2, k3 = jax.random.split(key, 4) 47 | 48 | # map 49 | grid = room(height=self.height, width=self.width) 50 | 51 | # goal and player 52 | if self.random_start: 53 | player_pos = random_positions(k1, grid) 54 | direction = random_directions(k2, n=1) 55 | else: 56 | player_pos = jnp.asarray([1, 1]) 57 | direction = jnp.asarray(0) 58 | # player 59 | player = Player.create( 60 | position=player_pos, 61 | direction=direction, 62 | pocket=EMPTY_POCKET_ID, 63 | ) 64 | # goal 65 | goal_pos = jnp.asarray([self.height - 2, self.width - 2]) 66 | goal = Goal.create(position=goal_pos, probability=jnp.asarray(1.0)) 67 | 68 | # balls 69 | exclude = jnp.stack([player_pos, goal_pos]) 70 | ball_pos = random_positions(k3, grid, n=self.n_obstacles, exclude=exclude) 71 | balls = Ball.create( 72 | position=ball_pos, 73 | colour=jnp.tile(PALETTE.BLUE, (self.n_obstacles,)), 74 | probability=jnp.ones(self.n_obstacles), 75 | ) 76 | 77 | entities = { 78 | Entities.PLAYER: player[None], 79 | Entities.GOAL: goal[None], 80 | Entities.BALL: balls, 81 | } 82 | 83 | # 
systems 84 | state = State( 85 | key=key, 86 | grid=grid, 87 | cache=cache or RenderingCache.init(grid), 88 | entities=entities, 89 | ) 90 | 91 | return Timestep( 92 | t=jnp.asarray(0, dtype=jnp.int32), 93 | observation=self.observation_fn(state), 94 | action=jnp.asarray(0, dtype=jnp.int32), 95 | reward=jnp.asarray(0.0, dtype=jnp.float32), 96 | step_type=jnp.asarray(0, dtype=jnp.int32), 97 | state=state, 98 | ) 99 | 100 | 101 | register_env( 102 | "Navix-Dynamic-Obstacles-5x5-v0", 103 | lambda *args, **kwargs: DynamicObstacles.create( 104 | height=5, 105 | width=5, 106 | n_obstacles=2, 107 | random_start=False, 108 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 109 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 110 | *args, 111 | **kwargs, 112 | ), 113 | ) 114 | register_env( 115 | "Navix-Dynamic-Obstacles-5x5-Random-v0", 116 | lambda *args, **kwargs: DynamicObstacles.create( 117 | height=5, 118 | width=5, 119 | n_obstacles=2, 120 | random_start=True, 121 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 122 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 123 | *args, 124 | **kwargs, 125 | ), 126 | ) 127 | register_env( 128 | "Navix-Dynamic-Obstacles-6x6-v0", 129 | lambda *args, **kwargs: DynamicObstacles.create( 130 | height=6, 131 | width=6, 132 | n_obstacles=3, 133 | random_start=False, 134 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 135 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 136 | *args, 137 | **kwargs, 138 | ), 139 | ) 140 | register_env( 141 | "Navix-Dynamic-Obstacles-6x6-Random-v0", 142 | lambda *args, **kwargs: DynamicObstacles.create( 143 | height=6, 144 | width=6, 145 | n_obstacles=3, 146 | random_start=True, 147 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 148 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 149 | *args, 150 | **kwargs, 151 | ), 152 | ) 153 | register_env( 154 | 
"Navix-Dynamic-Obstacles-8x8-v0", 155 | lambda *args, **kwargs: DynamicObstacles.create( 156 | height=8, 157 | width=8, 158 | n_obstacles=4, 159 | random_start=False, 160 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 161 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 162 | *args, 163 | **kwargs, 164 | ), 165 | ) 166 | register_env( 167 | "Navix-Dynamic-Obstacles-16x16-v0", 168 | lambda *args, **kwargs: DynamicObstacles.create( 169 | height=16, 170 | width=16, 171 | n_obstacles=8, 172 | random_start=False, 173 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 174 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 175 | *args, 176 | **kwargs, 177 | ), 178 | ) 179 | -------------------------------------------------------------------------------- /navix/environments/empty.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | 21 | from __future__ import annotations 22 | from typing import Union 23 | 24 | import jax 25 | import jax.numpy as jnp 26 | from jax import Array 27 | from flax import struct 28 | 29 | from .. import rewards, observations, terminations 30 | from ..components import EMPTY_POCKET_ID 31 | from ..entities import Entities, Goal, Player 32 | from ..states import State 33 | from ..grid import random_positions, random_directions, room 34 | from ..rendering.cache import RenderingCache 35 | from .environment import Environment, Timestep 36 | from .registry import register_env 37 | 38 | 39 | class Room(Environment): 40 | random_start: bool = struct.field(pytree_node=False, default=False) 41 | 42 | def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep: 43 | key, k1, k2 = jax.random.split(key, 3) 44 | 45 | # map 46 | grid = room(height=self.height, width=self.width) 47 | 48 | # goal and player 49 | if self.random_start: 50 | player_pos, goal_pos = random_positions(k1, grid, n=2) 51 | direction = random_directions(k2, n=1) 52 | else: 53 | goal_pos = jnp.asarray([self.height - 2, self.width - 2]) 54 | player_pos = jnp.asarray([1, 1]) 55 | direction = jnp.asarray(0) 56 | player = Player.create( 57 | position=player_pos, 58 | direction=direction, 59 | pocket=EMPTY_POCKET_ID, 60 | ) 61 | # goal 62 | goal = Goal.create(position=goal_pos, probability=jnp.asarray(1.0)) 63 | 64 | entities = { 65 | Entities.PLAYER: player[None], 66 | Entities.GOAL: goal[None], 67 | } 68 | 69 | # systems 70 | state = State( 71 | key=key, 72 | grid=grid, 73 | cache=cache or RenderingCache.init(grid), 74 | entities=entities, 75 | ) 76 | 77 | return Timestep( 78 | t=jnp.asarray(0, dtype=jnp.int32), 79 | observation=self.observation_fn(state), 80 | action=jnp.asarray(0, dtype=jnp.int32), 81 | reward=jnp.asarray(0.0, dtype=jnp.float32), 82 | step_type=jnp.asarray(0, dtype=jnp.int32), 83 | state=state, 84 | ) 85 | 86 | 87 | register_env( 88 | "Navix-Empty-5x5-v0", 89 | 
lambda *args, **kwargs: Room.create( 90 | height=5, 91 | width=5, 92 | random_start=False, 93 | *args, 94 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 95 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 96 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 97 | **kwargs, 98 | ), 99 | ) 100 | register_env( 101 | "Navix-Empty-6x6-v0", 102 | lambda *args, **kwargs: Room.create( 103 | height=6, 104 | width=6, 105 | random_start=False, 106 | *args, 107 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 108 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 109 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 110 | **kwargs, 111 | ), 112 | ) 113 | register_env( 114 | "Navix-Empty-8x8-v0", 115 | lambda *args, **kwargs: Room.create( 116 | height=8, 117 | width=8, 118 | random_start=False, 119 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 120 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 121 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 122 | *args, 123 | **kwargs, 124 | ), 125 | ) 126 | register_env( 127 | "Navix-Empty-16x16-v0", 128 | lambda *args, **kwargs: Room.create( 129 | height=16, 130 | width=16, 131 | random_start=False, 132 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 133 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 134 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 135 | *args, 136 | **kwargs, 137 | ), 138 | ) 139 | register_env( 140 | "Navix-Empty-Random-5x5-v0", 141 | lambda *args, **kwargs: Room.create( 142 | height=5, 143 | width=5, 144 | random_start=True, 145 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 146 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 147 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 148 | *args, 149 | 
**kwargs, 150 | ), 151 | ) 152 | register_env( 153 | "Navix-Empty-Random-6x6-v0", 154 | lambda *args, **kwargs: Room.create( 155 | height=6, 156 | width=6, 157 | random_start=True, 158 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 159 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 160 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 161 | *args, 162 | **kwargs, 163 | ), 164 | ) 165 | register_env( 166 | "Navix-Empty-Random-8x8-v0", 167 | lambda *args, **kwargs: Room.create( 168 | height=8, 169 | width=8, 170 | random_start=True, 171 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 172 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 173 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 174 | *args, 175 | **kwargs, 176 | ), 177 | ) 178 | register_env( 179 | "Navix-Empty-Random-16x16-v0", 180 | lambda *args, **kwargs: Room.create( 181 | height=16, 182 | width=16, 183 | random_start=True, 184 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 185 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 186 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 187 | *args, 188 | **kwargs, 189 | ), 190 | ) 191 | -------------------------------------------------------------------------------- /navix/environments/crossings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. 
You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from typing import Union 22 | import jax 23 | import jax.numpy as jnp 24 | from jax import Array 25 | from flax import struct 26 | 27 | from navix import observations, rewards, terminations 28 | 29 | from ..components import EMPTY_POCKET_ID 30 | from ..rendering.cache import RenderingCache 31 | from . import Environment 32 | from ..entities import Player, Goal, Lava 33 | from ..states import State 34 | from . import Timestep 35 | from .registry import register_env 36 | 37 | 38 | class Crossings(Environment): 39 | n_crossings: int = struct.field(pytree_node=False, default=1) 40 | is_lava: bool = struct.field(pytree_node=False, default=False) 41 | 42 | def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep: 43 | assert ( 44 | self.height == self.width 45 | ), f"Crossings are only defined for square grids, got height {self.height} and \ 46 | width {self.width}" 47 | # check minimum height and width 48 | key, k1, k2 = jax.random.split(key, num=3) 49 | 50 | grid = jnp.zeros((self.height - 2, self.width - 2), dtype=jnp.int32) 51 | 52 | # player 53 | player_pos = jnp.asarray([1, 1]) 54 | player_dir = jnp.asarray(0) 55 | player = Player.create( 56 | position=player_pos, direction=player_dir, pocket=EMPTY_POCKET_ID 57 | ) 58 | # goal 59 | goal_pos = jnp.asarray([self.height - 2, self.width - 2]) 60 | goals = Goal.create(position=goal_pos, probability=jnp.asarray(1.0)) 61 | 62 | entities = { 63 | "player": player[None], 64 | "goal": goals[None], 65 | } 66 | 67 | # crossings 68 | obstacles_hor 
= jnp.mgrid[ 69 | 1 : self.height - 2 : 2, 1 : self.width - 1 70 | ].transpose(1, 2, 0) 71 | obstacles_ver = jnp.mgrid[ 72 | 1 : self.height - 1, 1 : self.width - 2 : 2 73 | ].transpose(2, 1, 0) 74 | all_obstacles_pos = jnp.concatenate([obstacles_hor, obstacles_ver]) 75 | num_obstacles = min(self.n_crossings, len(all_obstacles_pos)) 76 | obstacles_pos = jax.random.choice( 77 | k1, all_obstacles_pos, (num_obstacles,), replace=False 78 | ) 79 | 80 | if self.is_lava: 81 | entities["lava"] = Lava.create(position=obstacles_pos) 82 | else: 83 | grid = grid.at[tuple(obstacles_pos.T)].set(-1) 84 | 85 | # path to goal 86 | def update(direction, start, grid, step_size): 87 | return jax.lax.cond( 88 | direction == jnp.asarray(0, dtype=jnp.int32), 89 | lambda: ( 90 | start + jnp.asarray([0, step_size]), 91 | jax.lax.dynamic_update_slice( 92 | grid, jnp.zeros((1, step_size), dtype=jnp.int32), tuple(start.T) 93 | ), 94 | ), 95 | lambda: ( 96 | start + jnp.asarray([step_size, 0]), 97 | jax.lax.dynamic_update_slice( 98 | grid, jnp.zeros((step_size, 1), dtype=jnp.int32), tuple(start.T) 99 | ), 100 | ), 101 | ) 102 | 103 | start = jnp.asarray([0, 0], dtype=jnp.int32) 104 | step_size = 3 105 | for i in range(10): 106 | k2, k3 = jax.random.split(k2) 107 | direction = jax.random.randint(k2, (), minval=0, maxval=2) 108 | start, grid = update(direction, start, grid, step_size) 109 | 110 | grid = jnp.pad(grid, 1, mode="constant", constant_values=-1) 111 | 112 | state = State( 113 | key=key, 114 | grid=grid, 115 | cache=RenderingCache.init(grid), 116 | entities=entities, 117 | ) 118 | return Timestep( 119 | t=jnp.asarray(0, dtype=jnp.int32), 120 | observation=self.observation_fn(state), 121 | action=jnp.asarray(-1, dtype=jnp.int32), 122 | reward=jnp.asarray(0.0, dtype=jnp.float32), 123 | step_type=jnp.asarray(0, dtype=jnp.int32), 124 | state=state, 125 | ) 126 | 127 | 128 | register_env( 129 | "Navix-SimpleCrossingS9N1-v0", 130 | lambda *args, **kwargs: Crossings.create( 131 | height=9, 
132 | width=9, 133 | n_crossings=1, 134 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 135 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 136 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 137 | *args, 138 | **kwargs, 139 | ), 140 | ) 141 | register_env( 142 | "Navix-SimpleCrossingS9N2-v0", 143 | lambda *args, **kwargs: Crossings.create( 144 | height=9, 145 | width=9, 146 | n_crossings=2, 147 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 148 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 149 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 150 | *args, 151 | **kwargs, 152 | ), 153 | ) 154 | register_env( 155 | "Navix-SimpleCrossingS9N3-v0", 156 | lambda *args, **kwargs: Crossings.create( 157 | height=9, 158 | width=9, 159 | n_crossings=3, 160 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 161 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 162 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 163 | *args, 164 | **kwargs, 165 | ), 166 | ) 167 | register_env( 168 | "Navix-SimpleCrossingS11N5-v0", 169 | lambda *args, **kwargs: Crossings.create( 170 | height=11, 171 | width=11, 172 | n_crossings=5, 173 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 174 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 175 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 176 | *args, 177 | **kwargs, 178 | ), 179 | ) 180 | -------------------------------------------------------------------------------- /navix/experiment.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, replace, fields 2 | import time 3 | from typing import Dict, Tuple 4 | 5 | import distrax 6 | import jax 7 | import jax.numpy as jnp 8 | import wandb 9 | import wandb.util 10 | from 
navix.agents.agent import Agent 11 | from navix.environments.environment import Environment 12 | 13 | 14 | class Experiment: 15 | """A class to run an experiment with a given agent and environment. 16 | 17 | Args: 18 | name (str): The name of the experiment. 19 | agent (Agent): The agent to use in the experiment. 20 | env (Environment): The environment to use in the experiment. 21 | env_id (str): The ID of the environment. 22 | seeds (Tuple[int, ...]): The seeds to use in the experiment. 23 | group (str): The group to use in the experiment. 24 | 25 | Attributes: 26 | name (str): The name of the experiment. 27 | agent (Agent): The agent to use in the experiment. 28 | env (Environment): The environment to use in the experiment. 29 | env_id (str): The ID of the environment. 30 | seeds (Tuple[int, ...]): The seeds to use in the experiment. 31 | group (str): The group to use in the experiment. 32 | 33 | """ 34 | 35 | def __init__( 36 | self, 37 | name: str, 38 | agent: Agent, 39 | env: Environment, 40 | env_id: str = "", 41 | seeds: Tuple[int, ...] = (0,), 42 | group: str = "", 43 | ): 44 | self.name = name 45 | self.agent = agent 46 | self.env = env 47 | self.env_id = env_id 48 | self.seeds = seeds 49 | self.group = group 50 | 51 | def run(self, do_log: bool = True): 52 | """Default function to run the experiment. This function compiles the training function, trains the agent, and logs the results. 53 | 54 | Args: 55 | do_log (bool): Whether to log the results to wandb. 56 | !!! Warning 57 | Logging to `wandb` is usually much slower than training the agent itself. 58 | The time is linear in the number of seeds. 59 | 60 | Returns: 61 | Tuple: A tuple containing the final training state and the logs. 
62 | """ 63 | print("Running experiment with the following configuration:") 64 | print(vars(self)) 65 | rng = jnp.asarray([jax.random.PRNGKey(seed) for seed in self.seeds]) 66 | 67 | print("Compiling training function...") 68 | start_time = time.time() 69 | train_fn = jax.jit(jax.vmap(self.agent.train)).lower(rng).compile() 70 | compilation_time = time.time() - start_time 71 | print(f"Compilation time cost: {compilation_time}") 72 | 73 | print("Training agent...") 74 | start_time = time.time() 75 | train_state, logs = train_fn(rng) 76 | training_time = time.time() - start_time 77 | print(f"Training time cost: {training_time}") 78 | 79 | if not self.agent.hparams.debug and do_log: 80 | print("Logging final results to wandb...") 81 | start_time = time.time() 82 | for i, seed in enumerate(self.seeds): 83 | config = {**vars(self), **asdict(self.agent.hparams)} 84 | config.update(seed=seed) 85 | wandb.init(project=self.name, config=config, group=self.group) 86 | print("Logging results for seed:", seed) 87 | log = jax.tree.map(lambda x: x[i], logs) 88 | self.agent.log_on_train_end(log) 89 | wandb.finish() 90 | logging_time = time.time() - start_time 91 | print(f"Logging time cost: {logging_time}") 92 | 93 | print("Training complete") 94 | total_time = 0 95 | print(f"Compilation time cost: {compilation_time}") 96 | total_time += compilation_time 97 | print(f"Training time cost: {training_time}") 98 | total_time += training_time 99 | if not self.agent.hparams.debug and do_log: 100 | print(f"Logging time cost: {logging_time}") 101 | total_time += logging_time 102 | print(f"Total time cost: {total_time}") 103 | return train_state, logs 104 | 105 | def run_hparam_search( 106 | self, hparams_distr: Dict[str, distrax.Distribution], pop_size: int 107 | ): 108 | """Function to run a hyperparameter search for the experiment. This function \ 109 | samples hyperparameters from the given distributions, trains the agent, and \ 110 | logs the results. 
111 | 112 | Args: 113 | hparams_distr (Dict[str, distrax.Distribution]): A dictionary of \ 114 | hyperparameter distributions. The keys are the hyperparameter names, which \ 115 | must exist in `self.agent.hparams`, and the values are the corresponding \ 116 | distributions. 117 | pop_size (int): The number of hyperparameter sets to sample. 118 | 119 | Returns: 120 | Tuple: A tuple containing the final training states and the logs, batched \ 121 | over the hyperparameter sets. 122 | """ 123 | hparams_fields = fields(self.agent.hparams) 124 | for k in hparams_distr: 125 | member = list(filter(lambda x: x.name == k, hparams_fields)) 126 | if ( 127 | len(member) > 0 128 | and "pytree_node" in member[0].metadata 129 | and member[0].metadata["pytree_node"] == False 130 | ): 131 | raise ValueError( 132 | f"Hyperparameter {k} is not a traceable pytree node. " 133 | + f"Set pytree_node=True for {k} to include it into the hparam search." 134 | ) 135 | 136 | search_set = [] 137 | for seed in range(pop_size): 138 | hparams = self.agent.hparams 139 | key = jax.random.PRNGKey(seed) 140 | for k, distr in hparams_distr.items(): 141 | hparams = replace(hparams, **{k: distr.sample(seed=key)}) 142 | print("Hparams:", hparams) 143 | search_set.append(hparams) 144 | # transpose search set 145 | len_search_set = len(search_set) 146 | search_set = jax.tree.map(lambda *x: jnp.stack(x), *search_set) 147 | 148 | rngs = jnp.asarray([jax.random.PRNGKey(seed) for seed in self.seeds]) 149 | 150 | def search(hparam_set_sample): 151 | agent = self.agent.replace(hparams=hparam_set_sample) 152 | return jax.vmap(agent.train)(rngs) 153 | 154 | print("Running hyperparameter search with the following configuration:") 155 | print(search_set) 156 | 157 | print("Compiling search function...") 158 | start_time = time.time() 159 | search_fn = jax.jit(jax.vmap(search)).lower(search_set).compile() 160 | compilation_time = time.time() - start_time 161 | print(f"Compilation time cost: {compilation_time}") 162 
| 163 | print("Searching for optimal hyperparameters...") 164 | start_time = time.time() 165 | train_states, logs = search_fn(search_set) 166 | search_time = time.time() - start_time 167 | print(f"Search time cost: {search_time}") 168 | 169 | print("Logging final results to wandb...") 170 | start_time = time.time() 171 | # average over seeds 172 | for i in range(len_search_set): 173 | hparams = jax.tree.map(lambda x: x[i], search_set) 174 | print("Logging results for hparam set:", hparams) 175 | config = {**vars(self), **asdict(hparams)} 176 | wandb.init(project=self.name, config=config, group=self.group) 177 | log = jax.tree.map(lambda x: jnp.mean(x[i], axis=0), logs) 178 | self.agent.log_on_train_end(log) 179 | wandb.finish() 180 | logging_time = time.time() - start_time 181 | 182 | print("Hyperparameter search complete") 183 | total_time = 0 184 | print(f"Compilation time cost: {compilation_time}") 185 | total_time += compilation_time 186 | print(f"Search time cost: {search_time}") 187 | total_time += search_time 188 | print(f"Logging time cost: {logging_time}") 189 | total_time += logging_time 190 | print(f"Total time cost: {total_time}") 191 | return train_states, logs 192 | -------------------------------------------------------------------------------- /navix/environments/key_corridor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. 
You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | from __future__ import annotations 22 | 23 | from typing import Union 24 | import jax 25 | import jax.numpy as jnp 26 | from jax import Array 27 | 28 | from navix import observations, rewards, terminations 29 | 30 | from ..components import EMPTY_POCKET_ID 31 | from ..rendering.cache import RenderingCache 32 | from ..environments import Environment 33 | from ..entities import Ball, Player, Key, Door 34 | from ..states import State 35 | from ..environments import Timestep 36 | from ..grid import random_directions, random_colour, RoomsGrid 37 | from .registry import register_env 38 | 39 | 40 | class KeyCorridor(Environment): 41 | def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep: 42 | n_rows_config = {3: 1, 5: 2} 43 | n_rows = n_rows_config.get(self.height, 3) 44 | room_size = (self.width - 3) // 3 45 | k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6) 46 | 47 | # grid of rooms 48 | grid = RoomsGrid.create(n_rows, 3, (room_size, room_size)) 49 | 50 | # key 51 | key_room_row = jax.random.randint(k1, (), minval=0, maxval=n_rows) 52 | key_pos = grid.position_in_room( 53 | key_room_row, jnp.asarray(0, dtype=jnp.int32), key=k1 54 | ) 55 | key_colour = random_colour(k4) 56 | key_id = jnp.asarray(1) 57 | key_obj = Key.create(key_pos, key_colour, key_id) 58 | 59 | # agent 60 | pk_1, pk_2, pk_3 = jax.random.split(k2, num=3) 61 | agent_room_row = jax.random.randint(pk_1, (), minval=0, maxval=n_rows) 62 | agent_pos = grid.position_in_room(agent_room_row, jnp.asarray(1), key=pk_2) 63 | 
player = Player.create( 64 | agent_pos, random_directions(pk_3), pocket=EMPTY_POCKET_ID 65 | ) 66 | 67 | # ball 68 | ball_room_row = jax.random.randint(k3, (), minval=0, maxval=n_rows) 69 | ball_pos = grid.position_in_room(ball_room_row, jnp.asarray(2), key=k4) 70 | ball = Ball.create(ball_pos, random_colour(k6), probability=jnp.asarray(0.0)) 71 | 72 | # Doors 73 | doors = [] 74 | for row in range(n_rows): 75 | k5, k6, k7, k8, k9 = jax.random.split(k5, num=5) 76 | # left corridor, right wall 77 | door_pos = grid.position_on_border(row, 2, 0, key=k5) 78 | requires, colour, open = jax.lax.cond( 79 | jnp.array_equal(row, ball_room_row), 80 | lambda: (key_id, key_colour, jnp.asarray(2)), 81 | lambda: (jnp.asarray(-1), random_colour(k5), jnp.asarray(0)), 82 | ) 83 | doors.append( 84 | Door.create( 85 | position=door_pos, requires=requires, colour=colour, open=open 86 | ) 87 | ) 88 | # right corridor, left wall 89 | door_pos = grid.position_on_border(row, 0, 1, key=k7) 90 | doors.append( 91 | Door.create( 92 | position=door_pos, 93 | requires=EMPTY_POCKET_ID, 94 | colour=random_colour(k7), 95 | open=jnp.asarray(0), 96 | ) 97 | ) 98 | for row in range(n_rows - 1): 99 | k9, k10, k11, k12 = jax.random.split(k9, num=4) 100 | # first col 101 | door_pos = grid.position_on_border(row, 0, 3, key=k9) 102 | doors.append( 103 | Door.create( 104 | position=door_pos, 105 | requires=EMPTY_POCKET_ID, 106 | colour=random_colour(k10), 107 | open=jnp.asarray(0), 108 | ) 109 | ) 110 | doors.append( 111 | Door.create( 112 | position=door_pos, 113 | requires=EMPTY_POCKET_ID, 114 | colour=random_colour(k12), 115 | open=jnp.asarray(0), 116 | ) 117 | ) 118 | doors = jax.tree.map(lambda *x: jnp.stack(x), *doors) 119 | 120 | entities = { 121 | "player": player[None], 122 | "key": key_obj[None], 123 | "door": doors, 124 | "goal": ball[None], 125 | } 126 | 127 | grid = grid.get_grid() 128 | grid = grid.at[ 129 | 1 + room_size : self.height - 1 : room_size + 1, 130 | 1 + room_size + 1 : 1 + 
room_size + 1 + room_size, 131 | ].set(0) 132 | state = State( 133 | key=key, 134 | grid=grid, 135 | cache=cache or RenderingCache.init(grid), 136 | entities=entities, 137 | ) 138 | return Timestep( 139 | t=jnp.asarray(0, dtype=jnp.int32), 140 | observation=self.observation_fn(state), 141 | action=jnp.asarray(-1, dtype=jnp.int32), 142 | reward=jnp.asarray(0.0, dtype=jnp.float32), 143 | step_type=jnp.asarray(0, dtype=jnp.int32), 144 | state=state, 145 | ) 146 | 147 | 148 | register_env( 149 | "Navix-KeyCorridorS3R1-v0", 150 | lambda *args, **kwargs: KeyCorridor.create( 151 | height=3, 152 | width=7, 153 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 154 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 155 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 156 | *args, 157 | **kwargs, 158 | ), 159 | ) 160 | register_env( 161 | "Navix-KeyCorridorS3R2-v0", 162 | lambda *args, **kwargs: KeyCorridor.create( 163 | height=5, 164 | width=7, 165 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 166 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 167 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 168 | *args, 169 | **kwargs, 170 | ), 171 | ) 172 | register_env( 173 | "Navix-KeyCorridorS3R3-v0", 174 | lambda *args, **kwargs: KeyCorridor.create( 175 | height=7, 176 | width=7, 177 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 178 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 179 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 180 | *args, 181 | **kwargs, 182 | ), 183 | ) 184 | register_env( 185 | "Navix-KeyCorridorS4R3-v0", 186 | lambda *args, **kwargs: KeyCorridor.create( 187 | height=10, 188 | width=10, 189 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 190 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 191 | 
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 192 | *args, 193 | **kwargs, 194 | ), 195 | ) 196 | register_env( 197 | "Navix-KeyCorridorS5R3-v0", 198 | lambda *args, **kwargs: KeyCorridor.create( 199 | height=13, 200 | width=13, 201 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 202 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 203 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 204 | *args, 205 | **kwargs, 206 | ), 207 | ) 208 | register_env( 209 | "Navix-KeyCorridorS6R3-v0", 210 | lambda *args, **kwargs: KeyCorridor.create( 211 | height=16, 212 | width=16, 213 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 214 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 215 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 216 | *args, 217 | **kwargs, 218 | ), 219 | ) 220 | -------------------------------------------------------------------------------- /navix/environments/door_key.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Navix Authors. 2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. 
# See the License for the
# specific language governing permissions and limitations
# under the License.


from typing import Union
import jax
import jax.numpy as jnp
from jax import Array
from flax import struct

from .. import rewards, observations, terminations
from ..components import EMPTY_POCKET_ID
from ..rendering.cache import RenderingCache
from ..rendering.registry import PALETTE
from . import Environment
from ..entities import Player, Key, Door, Goal, Wall
from ..states import State
from . import Timestep
from ..grid import mask_by_coordinates, room, random_positions, random_directions
from .registry import register_env


class DoorKey(Environment):
    random_start: bool = struct.field(pytree_node=False, default=False)

    def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep:
        # check minimum height and width
        assert (
            self.height > 3
        ), f"Room height must be greater than 3, got {self.height} instead"
        assert (
            self.width > 4
        ), f"Room width must be greater than 4, got {self.width} instead"

        key, k1, k2, k3, k4 = jax.random.split(key, 5)

        grid = room(height=self.height, width=self.width)

        # door positions
        # col can be between 2 and width - 3 (randint excludes the upper bound)
        door_col = jax.random.randint(k4, (), 2, self.width - 2)  # col
        # row can be between 1 and height - 2
        door_row = jax.random.randint(k3, (), 1, self.height - 1)  # row
        door_pos = jnp.asarray((door_row, door_col))
        doors = Door.create(
            position=door_pos,
            requires=jnp.asarray(3),
            open=jnp.asarray(False),
            colour=PALETTE.YELLOW,
        )

        # wall positions
        wall_rows = jnp.arange(1, self.height - 1)
        wall_cols = jnp.asarray([door_col] * (self.height - 2))
        wall_pos = jnp.stack((wall_rows, wall_cols), axis=1)
        # remove wall where the door is
wall_pos = jnp.delete( 74 | wall_pos, door_row - 1, axis=0, assume_unique_indices=True 75 | ) 76 | walls = Wall.create(position=wall_pos) 77 | 78 | # get rooms 79 | first_room_mask = mask_by_coordinates( 80 | grid, (jnp.asarray(self.height), door_col), jnp.less 81 | ) 82 | first_room = jnp.where(first_room_mask, grid, -1) # put walls where not mask 83 | second_room_mask = mask_by_coordinates( 84 | grid, (jnp.asarray(0), door_col), jnp.greater 85 | ) 86 | second_room = jnp.where(second_room_mask, grid, -1) # put walls where not mask 87 | 88 | # set player and goal pos 89 | if self.random_start: 90 | player_pos = random_positions(k1, first_room) 91 | player_dir = random_directions(k2) 92 | goal_pos = random_positions(k2, second_room) 93 | else: 94 | player_pos = jnp.asarray([1, 1]) 95 | player_dir = jnp.asarray(0) 96 | goal_pos = jnp.asarray([self.height - 2, self.width - 2]) 97 | 98 | # spawn goal and player 99 | player = Player.create( 100 | position=player_pos, direction=player_dir, pocket=EMPTY_POCKET_ID 101 | ) 102 | goals = Goal.create(position=goal_pos, probability=jnp.asarray(1.0)) 103 | 104 | # spawn key 105 | key_pos = random_positions(k2, first_room, exclude=player_pos) 106 | keys = Key.create(position=key_pos, id=jnp.asarray(3), colour=PALETTE.YELLOW) 107 | 108 | # remove the wall beneath the door 109 | grid = grid.at[tuple(door_pos)].set(0) 110 | 111 | entities = { 112 | "player": player[None], 113 | "key": keys[None], 114 | "door": doors[None], 115 | "goal": goals[None], 116 | "wall": walls, 117 | } 118 | 119 | state = State( 120 | key=key, 121 | grid=grid, 122 | cache=cache or RenderingCache.init(grid), 123 | entities=entities, 124 | ) 125 | return Timestep( 126 | t=jnp.asarray(0, dtype=jnp.int32), 127 | observation=self.observation_fn(state), 128 | action=jnp.asarray(-1, dtype=jnp.int32), 129 | reward=jnp.asarray(0.0, dtype=jnp.float32), 130 | step_type=jnp.asarray(0, dtype=jnp.int32), 131 | state=state, 132 | ) 133 | 134 | 135 | register_env( 136 | 
"Navix-DoorKey-5x5-v0", 137 | lambda *args, **kwargs: DoorKey.create( 138 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 139 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 140 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 141 | height=5, 142 | width=5, 143 | random_start=False, 144 | *args, 145 | **kwargs, 146 | ), 147 | ) 148 | register_env( 149 | "Navix-DoorKey-6x6-v0", 150 | lambda *args, **kwargs: DoorKey.create( 151 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 152 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 153 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 154 | height=6, 155 | width=6, 156 | random_start=False, 157 | *args, 158 | **kwargs, 159 | ), 160 | ) 161 | register_env( 162 | "Navix-DoorKey-8x8-v0", 163 | lambda *args, **kwargs: DoorKey.create( 164 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 165 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 166 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 167 | height=8, 168 | width=8, 169 | random_start=False, 170 | *args, 171 | **kwargs, 172 | ), 173 | ) 174 | register_env( 175 | "Navix-DoorKey-16x16-v0", 176 | lambda *args, **kwargs: DoorKey.create( 177 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 178 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 179 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 180 | height=16, 181 | width=16, 182 | random_start=False, 183 | *args, 184 | **kwargs, 185 | ), 186 | ) 187 | register_env( 188 | "Navix-DoorKey-Random-5x5-v0", 189 | lambda *args, **kwargs: DoorKey.create( 190 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 191 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 192 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 193 | height=5, 
194 | width=5, 195 | random_start=True, 196 | *args, 197 | **kwargs, 198 | ), 199 | ) 200 | register_env( 201 | "Navix-DoorKey-Random-6x6-v0", 202 | lambda *args, **kwargs: DoorKey.create( 203 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 204 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 205 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 206 | height=6, 207 | width=6, 208 | random_start=True, 209 | *args, 210 | **kwargs, 211 | ), 212 | ) 213 | register_env( 214 | "Navix-DoorKey-Random-8x8-v0", 215 | lambda *args, **kwargs: DoorKey.create( 216 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 217 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 218 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 219 | height=8, 220 | width=8, 221 | random_start=True, 222 | *args, 223 | **kwargs, 224 | ), 225 | ) 226 | register_env( 227 | "Navix-DoorKey-Random-16x16-v0", 228 | lambda *args, **kwargs: DoorKey.create( 229 | observation_fn=kwargs.pop("observation_fn", observations.symbolic), 230 | reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), 231 | termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), 232 | height=16, 233 | width=16, 234 | random_start=True, 235 | *args, 236 | **kwargs, 237 | ), 238 | ) 239 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# NAVIX: minigrid in JAX
[![CI](https://github.com/epignatelli/navix/actions/workflows/CI.yml/badge.svg)](https://github.com/epignatelli/navix/actions/workflows/CI.yml)
[![CD](https://github.com/epignatelli/navix/actions/workflows/CD.yml/badge.svg)](https://github.com/epignatelli/navix/actions/workflows/CD.yml)
![PyPI version](https://img.shields.io/pypi/v/navix?label=PyPI&color=%230099ab)
[![arXiv](https://img.shields.io/badge/arXiv-2407.19396-b31b1b.svg?style=flat)](https://arxiv.org/abs/2407.19396)

**[Quickstart](#what-is-navix)** | **[Install](#installation)** | **[Performance](#performance)** | **[Examples](#examples)** | **[Docs](https://epignatelli.com/navix)** | **[The JAX ecosystem](#jax-ecosystem-for-rl)** | **[Contribute](#join-us)** | **[Cite](#cite-us-please)**


## What is NAVIX?
NAVIX is a JAX-powered reimplementation of [MiniGrid](https://github.com/Farama-Foundation/Minigrid). Experiments that took **1 week** now take **15 minutes**.

Speedups of 200 000x over MiniGrid and a throughput of 670 million steps/s are not just a speed improvement. They enable a new paradigm that grants access to experiments that were previously infeasible, e.g., those that would take years to run.

It changes the game.
See the [performance](#performance) section for details, and the [documentation](https://epignatelli.com/navix) for more information.

Key features:
- Performance boost: NAVIX offers a speed increase of **over 1000x** compared to the original MiniGrid implementation, enabling faster experimentation and scaling. You can see a preliminary performance comparison [here](docs/performance.py), and a full benchmark [here](benchmarks/).
- XLA compilation: leverage the power of XLA to optimize NAVIX computations for many accelerators. NAVIX can run on CPU, GPU, and TPU.
- Autograd support: differentiate through environment transitions, opening up new possibilities such as learned world models.
- Batched hyperparameter tuning: run thousands of experiments in parallel, enabling hyperparameter tuning at scale. Quickly find out whether your algorithm fails because of the choice of hyperparameters.
- Focus on research: spend your time on the method, not on the engineering.

The library is in active development, and we are working on adding more environments and features.
If you want to join the development and contribute, please [open a discussion](https://github.com/epignatelli/navix/discussions/new?category=general) and let's have a chat!


## Installation
#### Install JAX
Follow the official installation guide for your OS and preferred accelerator: https://github.com/google/jax#installation.
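
Once JAX is installed, it can be worth checking which devices it actually sees before installing NAVIX. The snippet below is a minimal sketch that relies only on `jax.devices()` and `jax.default_backend()`; the helper name `describe_backend` is a hypothetical name chosen for this illustration and is not part of JAX or NAVIX.

```python
import jax


def describe_backend():
    """Summarise the devices that the installed JAX can see."""
    devices = jax.devices()
    return {
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu" or "tpu"
        "n_devices": len(devices),
        "platforms": sorted({device.platform for device in devices}),
    }


backend_info = describe_backend()
print(backend_info)
```

If the reported backend is `cpu` on a machine with an accelerator, the accelerator-specific JAX build is probably not the one installed.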

#### Install NAVIX
```bash
pip install navix
```

Or, for the latest version from source:
```bash
pip install git+https://github.com/epignatelli/navix
```

## Performance
NAVIX improves on MiniGrid in both execution speed *and* throughput: it can run more than 2048 PPO agents in parallel almost 10 times faster than *a single* PPO agent in the original MiniGrid.

![speedup_env](https://github.com/user-attachments/assets/b221048c-1b98-43d8-b09b-2a240412dd81)

NAVIX performs 2048 × 1M/49s = 668 734 693.88 steps per second (∼ 670 million steps/s) in batch mode,
while the original MiniGrid implementation performs 1M/318.01s = 3 144.65 steps per second. This
is a speedup of over 200 000×.
![throughput_ppo](https://github.com/user-attachments/assets/eea6e312-55b4-41c3-adb0-4207c5e78fd1)


## Examples
You can view a full set of examples [here](examples/) (more coming), but here are the most common use cases.
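
The first example below follows a common JAX pattern: a per-seed rollout written with `jax.lax.scan`, batched over seeds with `jax.vmap`, and compiled with `jax.jit`. The sketch here shows that pattern in isolation. It deliberately uses a hypothetical toy 1-D corridor instead of a NAVIX environment, so `toy_step`, `rollout`, and the constants are illustrative assumptions rather than NAVIX API.

```python
import jax
import jax.numpy as jnp

N_TIMESTEPS = 8  # hypothetical rollout length
N_SEEDS = 16     # hypothetical number of parallel rollouts
GRID_SIZE = 8    # hypothetical length of the toy 1-D corridor


def toy_step(position, action):
    """Toy transition (not NAVIX): action 0 moves left, action 1 moves right."""
    new_position = jnp.clip(position + 2 * action - 1, 0, GRID_SIZE - 1)
    return new_position, new_position  # (carry, per-step output)


def rollout(seed):
    """Runs one fixed-length rollout with random actions."""
    key = jax.random.PRNGKey(seed)
    actions = jax.random.randint(key, (N_TIMESTEPS,), 0, 2)
    start = jnp.asarray(GRID_SIZE // 2, dtype=jnp.int32)
    return jax.lax.scan(toy_step, start, actions)


# Batch the rollout over seeds, then compile the batched function once
final_positions, trajectories = jax.jit(jax.vmap(rollout))(jnp.arange(N_SEEDS))
print(final_positions.shape, trajectories.shape)
```

Because the whole rollout lives inside `scan`, the compiled function contains no Python loop, and `vmap` adds a leading batch axis of size `N_SEEDS` to every output.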

### Compiling a collection step
```python
import jax
import navix as nx
import jax.numpy as jnp

N_TIMESTEPS = 1000  # number of steps per rollout (illustrative value)


def run(seed):
    env = nx.make('MiniGrid-Empty-8x8-v0')  # Create the environment
    key = jax.random.PRNGKey(seed)
    reset_key, action_key = jax.random.split(key)
    timestep = env.reset(reset_key)
    actions = jax.random.randint(action_key, (N_TIMESTEPS,), 0, env.action_space.n)

    def body_fun(timestep, action):
        timestep = env.step(timestep, action)  # Update the environment state
        return timestep, ()

    return jax.lax.scan(body_fun, timestep, actions)[0]

# Compile the rollout and run it in parallel over 1000 seeds
final_timestep = jax.jit(jax.vmap(run))(jnp.arange(1000))
```

### Compiling a full training run
```python
import jax
import navix as nx
import jax.numpy as jnp

N_TIMESTEPS = 100  # steps collected per rollout (illustrative value)
N_ROLLOUTS = 100   # rollouts run in parallel (illustrative value)

env = nx.make('MiniGrid-MultiRoom-N2-S4-v0')


def policy(params, observation, key):
    """Placeholder policy: samples a uniformly random action"""
    # ... your policy logic, e.g. a network applied to `observation` ...
    return jax.random.randint(key, (), 0, env.action_space.n)


def run_rollout(params, key):
    """Collects one fixed-length rollout and returns its total reward"""
    reset_key, policy_key = jax.random.split(key)
    timestep = env.reset(reset_key)

    def body_fun(timestep, step_key):
        action = policy(params, timestep.observation, step_key)
        timestep = env.step(timestep, action)
        return timestep, timestep.reward

    step_keys = jax.random.split(policy_key, N_TIMESTEPS)
    _, rewards = jax.lax.scan(body_fun, timestep, step_keys)
    return rewards.sum()


# Compile the batched rollout once with XLA and reuse it at every iteration
batched_rollout = jax.jit(jax.vmap(run_rollout, in_axes=(None, 0)))


def train_policy(params, num_iterations):
    """Trains a policy over multiple parallel rollouts"""
    key = jax.random.PRNGKey(0)
    for _ in range(num_iterations):
        key, subkey = jax.random.split(key)
        rewards = batched_rollout(params, jax.random.split(subkey, N_ROLLOUTS))
        # ... Update `params` based on `rewards` ...
    return params

# Start the training
train_policy(params=None, num_iterations=100)
```

### Backpropagation through the environment
```python
import jax
import navix as nx
import jax.numpy as jnp
from jax import grad
from flax import linen as nn


class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        # ... your NN here; this placeholder is a single dense layer
        # that predicts an observation with the same shape as its input
        flat = x.reshape(-1).astype(jnp.float32)
        return nn.Dense(flat.size)(flat).reshape(x.shape)

model = Model()
env = nx.environments.Room(16, 16, 8)

def loss(params, timestep):
    action = jnp.asarray(0)
    pred_obs = model.apply(params, timestep.observation)
    timestep = env.step(timestep, action)
    return jnp.square(timestep.observation - pred_obs).mean()

key = jax.random.PRNGKey(0)
timestep = env.reset(key)
params = model.init(key, timestep.observation)

gradients = grad(loss)(params, timestep)
```

## JAX ecosystem for RL
NAVIX is not alone: it is part of an ecosystem of JAX-powered modules for RL.
Check out the following projects:
- Environments:
  - [Gymnax](https://github.com/RobertTLange/gymnax): a broad range of RL environments
  - [Brax](https://github.com/google/brax): a physics engine for robotics experiments
  - [EnvPool](https://github.com/sail-sg/envpool): a set of various batched environments
  - [Craftax](https://github.com/MichaelTMatthews/Craftax): a JAX reimplementation of the game of [Crafter](https://github.com/danijar/crafter)
  - [Jumanji](https://github.com/instadeepai/jumanji): another set of diverse environments
  - [PGX](https://github.com/sotetsuk/pgx): board games commonly used for RL, such as backgammon, chess, shogi, and go
  - [JAX-MARL](https://github.com/FLAIROx/JaxMARL): multi-agent RL environments in JAX
  - [Xland-Minigrid](https://github.com/corl-team/xland-minigrid/): a set of JAX-reimplemented grid-world environments
  - [Minimax](https://github.com/facebookresearch/minimax): a JAX library for RL autocurricula with 120x faster baselines
- Agents:
  - [PureJaxRl](https://github.com/luchris429/purejaxrl): fully-jitted training routines
  - [Rejax](https://github.com/keraJLi/rejax): a suite of diverse agents, including DDPG, DQN, PPO, SAC, and TD3
  - [Stoix](https://github.com/EdanToledo/Stoix): useful implementations of popular single-agent RL algorithms in JAX
  - [JAX-CORL](https://github.com/nissymori/JAX-CORL): lean single-file implementations of offline RL algorithms with solid performance reports
  - [Dopamine](https://github.com/google/dopamine): a research framework for fast prototyping of reinforcement learning algorithms


## Join Us!

NAVIX is actively developed. If you'd like to contribute to this open-source project, we welcome your involvement! Start a discussion or open a pull request.

Please consider starring the project if you like NAVIX!

## Cite us, please!
If you use NAVIX, please cite it as:

```bibtex
@article{pignatelli2024navix,
  title={NAVIX: Scaling MiniGrid Environments with JAX},
  author={Pignatelli, Eduardo and Liesen, Jarek and Lange, Robert Tjarko and Lu, Chris and Castro, Pablo Samuel and Toni, Laura},
  journal={arXiv preprint arXiv:2407.19396},
  year={2024}
}
```
--------------------------------------------------------------------------------