├── src ├── envs │ ├── envs.jl │ ├── empty.jl │ ├── gotodoor.jl │ └── fourrooms.jl ├── Gridworld.jl ├── abstract_grid_world.jl ├── render_in_terminal.jl ├── objects.jl ├── render_with_Makie.jl └── grid_world_base.jl ├── test └── runtests.jl ├── .gitignore ├── .github └── workflows │ ├── TagBot.yml │ └── CompatHelper.yml ├── Project.toml ├── .travis.yml ├── README.md ├── .appveyor.yml └── LICENSE /src/envs/envs.jl: -------------------------------------------------------------------------------- 1 | include("empty.jl") 2 | include("fourrooms.jl") 3 | include("gotodoor.jl") -------------------------------------------------------------------------------- /test/runtests.jl: -------------------------------------------------------------------------------- 1 | using Gridworld 2 | using Test 3 | 4 | @testset "Gridworld.jl" begin 5 | # Write your tests here. 6 | end 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.jl.*.cov 2 | *.jl.cov 3 | *.jl.mem 4 | Manifest.toml 5 | 6 | .vscode/* 7 | !.vscode/tasks.json 8 | !.vscode/launch.json 9 | !.vscode/extensions.json 10 | *.code-workspace 11 | 12 | # Local History for Visual Studio Code 13 | .history/ 14 | -------------------------------------------------------------------------------- /.github/workflows/TagBot.yml: -------------------------------------------------------------------------------- 1 | name: TagBot 2 | on: 3 | schedule: 4 | - cron: 0 0 * * * 5 | workflow_dispatch: 6 | jobs: 7 | TagBot: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: JuliaRegistries/TagBot@v1 11 | with: 12 | token: ${{ secrets.GITHUB_TOKEN }} 13 | ssh: ${{ secrets.DOCUMENTER_KEY }} 14 | -------------------------------------------------------------------------------- /src/Gridworld.jl: -------------------------------------------------------------------------------- 1 | module Gridworld 2 | 3 | using Requires 4 | 5 | 
const GW = Gridworld 6 | export GW 7 | 8 | include("objects.jl") 9 | include("grid_world_base.jl") 10 | include("abstract_grid_world.jl") 11 | include("envs/envs.jl") 12 | include("render_in_terminal.jl") 13 | 14 | # function __init__() 15 | # @require Makie="ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" include("render_with_Makie.jl") 16 | # end 17 | include("render_with_Makie.jl") 18 | 19 | end 20 | -------------------------------------------------------------------------------- /.github/workflows/CompatHelper.yml: -------------------------------------------------------------------------------- 1 | name: CompatHelper 2 | on: 3 | schedule: 4 | - cron: 0 0 * * * 5 | workflow_dispatch: 6 | jobs: 7 | CompatHelper: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Pkg.add("CompatHelper") 11 | run: julia -e 'using Pkg; Pkg.add("CompatHelper")' 12 | - name: CompatHelper.main() 13 | env: 14 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 15 | COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} 16 | run: julia -e 'using CompatHelper; CompatHelper.main()' 17 | -------------------------------------------------------------------------------- /Project.toml: -------------------------------------------------------------------------------- 1 | name = "Gridworld" 2 | uuid = "e15a9946-cd7f-4d03-83e2-6c30bacb0043" 3 | authors = ["Sriram"] 4 | version = "0.1.0" 5 | 6 | [deps] 7 | Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" 8 | Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" 9 | MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" 10 | Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" 11 | Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" 12 | Requires = "ae029012-a4dd-5104-9daa-d747884805df" 13 | 14 | [compat] 15 | julia = "1" 16 | 17 | [extras] 18 | Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" 19 | 20 | [targets] 21 | test = ["Test"] 22 | -------------------------------------------------------------------------------- /.travis.yml: 
-------------------------------------------------------------------------------- 1 | # Documentation: http://docs.travis-ci.com/user/languages/julia 2 | language: julia 3 | notifications: 4 | email: false 5 | julia: 6 | - 1.0 7 | - 1.5 8 | - nightly 9 | os: 10 | - linux 11 | - osx 12 | - windows 13 | arch: 14 | - x64 15 | cache: 16 | directories: 17 | - ~/.julia/artifacts 18 | jobs: 19 | fast_finish: true 20 | allow_failures: 21 | - julia: nightly 22 | after_success: 23 | - | 24 | julia -e ' 25 | using Pkg 26 | Pkg.add("Coverage") 27 | using Coverage 28 | Codecov.submit(process_folder())' 29 | - | 30 | julia -e ' 31 | using Pkg 32 | Pkg.add("Coverage") 33 | using Coverage 34 | Coveralls.submit(process_folder())' 35 | -------------------------------------------------------------------------------- /src/envs/empty.jl: --------------------------------------------------------------------------------
export EmptyGridWorld

"""
    EmptyGridWorld(; n=8, agent_start_pos=CartesianIndex(2,2), agent_start_dir=RIGHT)

An `n`×`n` grid whose outer ring is `WALL`, whose interior is `EMPTY`, and whose
tile `(n-1, n-1)` holds the `GOAL` instead of `EMPTY`.
"""
mutable struct EmptyGridWorld <: AbstractGridWorld
    world::GridWorldBase{Tuple{Empty,Wall,Goal}}
    agent_pos::CartesianIndex{2}
    agent::Agent
end

function EmptyGridWorld(; n=8, agent_start_pos=CartesianIndex(2, 2), agent_start_dir=RIGHT)
    grid = GridWorldBase((EMPTY, WALL, GOAL), n, n)
    interior, border, full = 2:n-1, [1, n], 1:n

    # Outer ring: first/last row, then first/last column.
    grid[WALL, border, full] .= true
    grid[WALL, full, border] .= true

    # Every interior tile starts out empty ...
    grid[EMPTY, interior, interior] .= true
    # ... except the goal tile, which carries GOAL instead of EMPTY.
    grid[EMPTY, n-1, n-1] = false
    grid[GOAL, n-1, n-1] = true

    return EmptyGridWorld(grid, agent_start_pos, Agent(dir=agent_start_dir))
end

# Step the agent one tile in the direction it faces; a wall at the target tile
# leaves the agent where it is. Returns the world itself.
function (w::EmptyGridWorld)(::MoveForward)
    target = get_dir(w.agent)(w.agent_pos)
    w.world[WALL, target] || (w.agent_pos = target)
    return w
end
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gridworld 2 | 3 | This project aims to provide some simple grid world environments similar to 
[gym-minigrid](https://github.com/maximecb/gym-minigrid) for reinforcement learning research in Julia. 4 | 5 | ## Design 6 | 7 | A `GridWorldBase` is used to represent the whole grid world. Inside of it, a 3-D `BitArray` of size `(n_objects, height, width)` is used to encode objects in each tile. 8 | 9 | ## Usage 10 | 11 | ```julia 12 | using Gridworld 13 | 14 | w = EmptyGridWorld() 15 | 16 | w(MOVE_FORWARD) 17 | w(TURN_LEFT) 18 | w(TURN_RIGHT) 19 | 20 | play(w) # you can also play interactively with the help of Makie 21 | ``` 22 | 23 | ## TODO 24 | 25 | ### Environment list 26 | 27 | - [x] EmptyGridWorld 28 | - [x] FourRooms 29 | - [x] GoToDoor 30 | - 31 | 32 | ### Needs improvement 33 | 34 | - [ ] Add test cases 35 | - [ ] Benchmark (ensure our implementations do not have significant performance issues) 36 | - [ ] A wrapper for ReinforcementLearningBase.jl 37 | - [ ] Gif/Video writer 38 | -------------------------------------------------------------------------------- /.appveyor.yml: -------------------------------------------------------------------------------- 1 | # Documentation: https://github.com/JuliaCI/Appveyor.jl 2 | environment: 3 | matrix: 4 | - julia_version: 1.0 5 | - julia_version: 1.5 6 | - julia_version: nightly 7 | platform: 8 | - x64 9 | cache: 10 | - '%USERPROFILE%\.julia\artifacts' 11 | matrix: 12 | allow_failures: 13 | - julia_version: nightly 14 | branches: 15 | only: 16 | - master 17 | - /release-.*/ 18 | notifications: 19 | - provider: Email 20 | on_build_success: false 21 | on_build_failure: false 22 | on_build_status_changed: false 23 | install: 24 | - ps: iex ((new-object net.webclient).DownloadString("https://raw.githubusercontent.com/JuliaCI/Appveyor.jl/version-1/bin/install.ps1")) 25 | build_script: 26 | - echo "%JL_BUILD_SCRIPT%" 27 | - C:\julia\bin\julia -e "%JL_BUILD_SCRIPT%" 28 | test_script: 29 | - echo "%JL_TEST_SCRIPT%" 30 | - C:\julia\bin\julia -e "%JL_TEST_SCRIPT%" 31 | on_success: 32 | - echo "%JL_CODECOV_SCRIPT%" 33 | - 
C:\julia\bin\julia -e "%JL_CODECOV_SCRIPT%" 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Sriram 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/envs/gotodoor.jl: --------------------------------------------------------------------------------
export GoToDoor

using Random

"""
    GoToDoor(; n=8, agent_start_pos=CartesianIndex(2,2), rng=Random.GLOBAL_RNG)

An `n`×`n` walled room with one door on each of the four sides. Each door
replaces the wall at a random non-corner position and gets a distinct color
drawn from `COLORS`.
"""
mutable struct GoToDoor{W<:GridWorldBase} <: AbstractGridWorld
    world::W
    agent_pos::CartesianIndex{2}
    agent::Agent
end

function GoToDoor(;n=8, agent_start_pos=CartesianIndex(2,2), rng=Random.GLOBAL_RNG)
    objects = (EMPTY, WALL, (Door(c) for c in COLORS)...)
    world = GridWorldBase(objects, n, n)
    world[EMPTY, :, :] .= true
    world[WALL, [1,n], 1:n] .= true
    world[EMPTY, [1,n], 1:n] .= false
    world[WALL, 1:n, [1,n]] .= true
    world[EMPTY, 1:n, [1,n]] .= false

    # One door per side: left column, right column, top row, bottom row.
    # The order of the `rand` calls is kept so a seeded `rng` yields the same layout.
    door_pos = [(rand(rng, 2:n-1),1), (rand(rng, 2:n-1),n), (1,rand(rng, 2:n-1)), (n,rand(rng, 2:n-1))]
    door_colors = COLORS[randperm(rng, length(COLORS))][1:length(door_pos)]
    for (c, p) in zip(door_colors, door_pos)
        world[Door(c), p...] = true
        world[WALL, p...] = false
    end
    return GoToDoor(world, agent_start_pos, Agent(dir=RIGHT))
end

# Step the agent one tile forward unless the target is a wall. Doors sit on the
# border and are not walls, so the agent can stand on one; from there the next
# tile may lie outside the grid, in which case the agent stays put instead of
# raising a BoundsError.
function (w::GoToDoor)(::MoveForward)
    dest = get_dir(w.agent)(w.agent_pos)
    grid = CartesianIndices((size(w.world, 2), size(w.world, 3)))
    if dest ∈ grid && !w.world[WALL, dest]
        w.agent_pos = dest
    end
    return w
end
-------------------------------------------------------------------------------- /src/abstract_grid_world.jl: --------------------------------------------------------------------------------
# Supertype of all environments. The accessors below read the fields `world`,
# `agent` and `agent_pos` of the concrete world.
abstract type AbstractGridWorld end

function get_agent_view end
function get_agent end

Base.convert(::Type{GridWorldBase}, w::AbstractGridWorld) = w.world
get_object(w::AbstractGridWorld) = get_object(convert(GridWorldBase, w))
get_object(w::AbstractGridWorld, x::Type{<:AbstractObject}) = filter(o -> o isa x, get_object(w))
get_object(w::AbstractGridWorld, x::Type{Agent}) = w.agent
get_pos(w::AbstractGridWorld, ::Type{Agent}) = w.agent_pos

get_agent(w::AbstractGridWorld) = get_object(w, Agent)
get_agent_pos(w::AbstractGridWorld) = get_pos(w, Agent)
get_agent_dir(w::AbstractGridWorld) = w |> get_agent |> get_dir

"""
    get_agent_view(w::AbstractGridWorld, agent_view_size=(7,7))

Return a fresh `BitArray{3}` of size `(n_objects, agent_view_size...)` filled
with what the agent currently sees; positions outside the grid stay `false`.
"""
function get_agent_view(w::AbstractGridWorld, agent_view_size=(7,7))
    wb = convert(GridWorldBase, w)
    v = falses(size(wb, 1), agent_view_size...)
    return get_agent_view!(v, w)
end

# Rotate the agent in place and return the world.
# NOTE(review): this attaches a call method to the abstract type itself. Julia
# releases that reject this ("cannot add methods to an abstract type") would
# need the method defined per concrete world instead — confirm on the
# supported Julia versions.
function (w::AbstractGridWorld)(dir::Union{TurnRight, TurnLeft})
    a = get_agent(w)
    set_dir!(a, dir(get_dir(a)))
    return w
end

get_agent_view_inds(w::AbstractGridWorld, s=(7,7)) = get_agent_view_inds(get_agent_pos(w).I, s, get_agent_dir(w))

get_agent_view!(v::BitArray{3}, w::AbstractGridWorld) = get_agent_view!(v, convert(GridWorldBase, w), get_agent_pos(w), get_agent_dir(w))

-------------------------------------------------------------------------------- /src/envs/fourrooms.jl: --------------------------------------------------------------------------------
export FourRooms

"""
    FourRooms(; n=19, agent_start_pos=CartesianIndex(2,2))

An `n`×`n` walled grid split by a wall along row and column `ceil(Int, n/2)`.
Each dividing wall has gaps at indices `ceil(Int, n/4)` and `ceil(Int, 3n/4)`.
The goal is at `(n-1, n-1)`.
"""
mutable struct FourRooms <: AbstractGridWorld
    world::GridWorldBase{Tuple{Empty,Wall,Goal}}
    agent_pos::CartesianIndex{2}
    agent::Agent
end

function FourRooms(;n=19, agent_start_pos=CartesianIndex(2,2))
    objects = (EMPTY, WALL, GOAL)
    world = GridWorldBase(objects, n, n)
    world[EMPTY, 2:n-1, 2:n-1] .= true
    world[WALL, [1,n], 1:n] .= true
    world[WALL, 1:n, [1,n]] .= true

    # Indices along a dividing wall that are solid: everything from 2 to n
    # except the two gaps at q1 and q3.
    mid = ceil(Int, n/2)
    q1, q3 = ceil(Int, n/4), ceil(Int, 3*n/4)
    wall_inds = vcat(2:q1-1, q1+1:mid-1, mid:q3-1, q3+1:n)

    world[WALL, mid, wall_inds] .= true
    world[EMPTY, mid, wall_inds] .= false
    world[WALL, wall_inds, mid] .= true
    world[EMPTY, wall_inds, mid] .= false
    world[GOAL, n-1, n-1] = true
    world[EMPTY, n-1, n-1] = false
    return FourRooms(world, agent_start_pos, Agent(dir=RIGHT))
end

# Step the agent one tile forward unless the target is a wall.
# The facing direction is stored on the agent; the struct has no `agent_dir` field.
function (w::FourRooms)(::MoveForward)
    dest = get_dir(w.agent)(w.agent_pos)
    if !w.world[WALL, dest]
        w.agent_pos = dest
    end
    return w
end
-------------------------------------------------------------------------------- /src/render_in_terminal.jl: --------------------------------------------------------------------------------
using Crayons

# Text rendering: first the whole grid (tiles inside the agent's view get a
# dark-gray background), then the agent's view itself with the agent drawn in
# the middle of the first row.
function Base.show(io::IO, gw::AbstractGridWorld)
    p = get_agent_pos(gw)
    w = convert(GridWorldBase, gw)
    objects = get_object(gw)
    view_inds = get_agent_view_inds(gw)  # computed once, not once per tile

    println(io, "World:")
    for i in 1:size(w, 2)
        for j in 1:size(w, 3)
            bg = CartesianIndex(i, j) ∈ view_inds ? :dark_gray : :black
            if i == p[1] && j == p[2]
                agent = get_agent(gw)
                print(io, Crayon(background=bg, foreground=get_color(agent),reset=true), convert(Char, agent))
            else
                k = findfirst(w.world[:, i, j])
                if k === nothing
                    # Tile without any object: draw a placeholder instead of
                    # indexing the object tuple with `nothing`.
                    print(io, Crayon(background=bg, reset=true), '_')
                else
                    o = objects[k]
                    print(io, Crayon(background=bg, foreground=get_color(o),reset=true), convert(Char, o))
                end
            end
        end
        println(io, Crayon(reset=true))
    end
    println(io)

    println(io, "Agent's view:")
    v = get_agent_view(gw)
    for i in 1:size(v, 2)
        for j in 1:size(v, 3)
            if i == 1 && j == size(v, 3) ÷ 2 + 1
                print(io, Agent(dir=DOWN))
            else
                x = findfirst(v[:, i, j])
                if x === nothing  # `=== nothing` rather than `isnothing` (needs Julia 1.1)
                    print(io, '_')
                else
                    print(io, objects[x])
                end
            end
        end
        println(io)
    end
    println(io)
end
-------------------------------------------------------------------------------- /src/objects.jl: --------------------------------------------------------------------------------
export MOVE_FORWARD, TURN_LEFT, TURN_RIGHT

using Crayons
using Colors

const COLORS = (:red, :green, :blue, :magenta, :yellow, :white)

#####
# Actions
#####

struct MoveForward end
const MOVE_FORWARD = MoveForward()

# Directions are callable singletons: applied to a (row, column) index they
# return the neighbouring index in that direction. Row 1 is the top row.
struct Up end
const UP = Up()
(x::Up)(p::CartesianIndex{2}) = p + CartesianIndex(-1, 0)

struct Down end
const DOWN = Down()
(x::Down)(p::CartesianIndex{2}) = p + CartesianIndex(1, 0)

struct Left end
const LEFT = Left()
(x::Left)(p::CartesianIndex{2}) = p + CartesianIndex(0, -1)

struct Right end
const RIGHT = Right()
(x::Right)(p::CartesianIndex{2}) = p + CartesianIndex(0, 1)

const LRUD = Union{Left, Right, Up, Down}

# Turn actions are callable singletons mapping a direction to the rotated one.
struct TurnRight end
const TURN_RIGHT = TurnRight()
struct TurnLeft end
const TURN_LEFT = TurnLeft()

(x::TurnRight)(::Left) = UP
(x::TurnRight)(::Up) = RIGHT
(x::TurnRight)(::Right) = DOWN
(x::TurnRight)(::Down) = LEFT
(x::TurnLeft)(::Left) = DOWN
(x::TurnLeft)(::Up) = LEFT
(x::TurnLeft)(::Right) = UP
(x::TurnLeft)(::Down) = RIGHT

#####
# Objects
#####

# Every object provides `convert(Char, x)` (its glyph) and `get_color(x)`.
abstract type AbstractObject end

Base.show(io::IO, x::AbstractObject) = print(io, Crayon(foreground=get_color(x), reset=true), convert(Char, x))

struct Empty <: AbstractObject end
const EMPTY = Empty()
Base.convert(::Type{Char}, ::Empty) = '⋅'
get_color(::Empty) = :white

struct Wall <: AbstractObject end
const WALL = Wall()
Base.convert(::Type{Char}, ::Wall) = '█'
get_color(::Wall) = :white

struct Goal <: AbstractObject end
const GOAL = Goal()
Base.convert(::Type{Char}, ::Goal) = '♥'
get_color(::Goal) = :red

# The door color is carried in the type parameter, so doors of different
# colors are different types.
struct Door{C} <: AbstractObject end
Door(c) = Door{c}()
Base.convert(::Type{Char}, ::Door) = '🚪'
get_color(::Door{C}) where C = C

Base.@kwdef mutable struct Agent <: AbstractObject
    color::Symbol=:red
    dir::LRUD
end
# The agent's glyph is an arrow pointing in the direction it faces.
function Base.convert(::Type{Char}, a::Agent)
    if a.dir === UP
        '↑'
    elseif a.dir === DOWN
        '↓'
    elseif a.dir === LEFT
        '←'
    elseif a.dir === RIGHT
        '→'
    end
end
get_color(a::Agent) = a.color
get_dir(a::Agent) = a.dir
set_dir!(a::Agent, d) = a.dir = d

-------------------------------------------------------------------------------- /src/render_with_Makie.jl: --------------------------------------------------------------------------------
export play

using Colors

# coordinate transform for Makie.jl
# Maps a grid index p = (row, column) to CartesianIndex(column, x - row + 1),
# where x is the number of rows: row 1 ends up at the largest second coordinate.
transform(x::Int) = p -> CartesianIndex(p[2], x-p[1]+1)

using Makie

# Build and display a scene for the world held in the observable `w`, and
# return the scene. The plotted quantities are derived from `w` and from the
# scene's pixel area through `@lift`.
function init_screen(w::Observable{<:AbstractGridWorld}; resolution=(1000,1000))
    scene = Scene(resolution = resolution, raw = true, camera = campixel!)

    area = scene.px_area
    grid_size = size(w[].world)[2:3]
    grid_inds = CartesianIndices(grid_size)
    tile_size = @lift((widths($area)[1] / size($w.world, 2), widths($area)[2] / size($w.world, 3)))
    T = transform(size(w[].world, 2))
    # Rectangle covering each tile in `pos`, and the centre point of each tile.
    boxes(pos, s) = [FRect2D((T(p).I .- (1,1)) .* s, s) for p in pos]
    centers(pos, s) = [(T(p).I .- (0.5,0.5)) .* s for p in pos]

    # 1. paint background
    poly!(scene, area)

    # 2. paint each kind of object
    # (EMPTY matches none of the branches and is therefore not drawn)
    for o in get_object(w[])
        if o === WALL
            poly!(scene, @lift(boxes(findall($w.world[WALL, :, :]), $tile_size)), color=:darkgray,)
        elseif o === GOAL
            scatter!(scene, @lift(centers(findall($w.world[GOAL,:,:]), $tile_size)), color=get_color(GOAL), marker=convert(Char, GOAL), markersize=@lift(minimum($tile_size)))
        elseif o isa Door
            scatter!(scene, @lift(centers(findall($w.world[o,:,:]), $tile_size)), color=get_color(o), marker=convert(Char, o), markersize=@lift(minimum($tile_size)))
        end
    end

    # 3. paint stroke
    poly!(scene, @lift(boxes(vec(grid_inds), $tile_size)), color=:transparent, strokecolor = :lightgray, strokewidth = 4)

    # 4. paint agent's view
    # Only the view positions that fall inside the grid are highlighted.
    view_boxes = @lift boxes([p for p in get_agent_view_inds($w) if p ∈ grid_inds], $tile_size)
    poly!(scene, view_boxes, color="rgba(255,255,255,0.2)")

    # 5. paint agent
    agent = @lift(get_agent($w))
    agent_position = @lift((T(get_agent_pos($w)).I .- (0.5, 0.5)).* $tile_size)
    scatter!(scene, agent_position, color=@lift(get_color($agent)), marker=@lift(convert(Char, $agent)), markersize=@lift(minimum($tile_size)))

    display(scene)
    scene
end

# Interactive loop: print the key bindings, open the scene, and apply the
# action bound to each pressed key until `q` is pressed.
function play(environment::AbstractGridWorld)
    print("""
    Key bindings:
    ←: TurnLeft
    →: TurnRight
    ↑: MoveForward
    q: Quit
    """)
    w = environment
    w_node = Node(w)
    scene = init_screen(w_node)
    is_quit = Ref(false)

    on(scene.events.keyboardbuttons) do b
        # The world is mutated in place; it is then stored into `w_node` again
        # after every action (presumably to notify the observers — confirm).
        if ispressed(b, Keyboard.left)
            w(TURN_LEFT)
            w_node[] = w
        elseif ispressed(b, Keyboard.right)
            w(TURN_RIGHT)
            w_node[] = w
        elseif ispressed(b, Keyboard.up)
            w(MOVE_FORWARD)
            w_node[] = w
        elseif ispressed(b, Keyboard.q)
            is_quit[] = true
        end
    end

    # Keep the function alive until the quit key is pressed.
    while !is_quit[]
        sleep(0.5)
    end
end
-------------------------------------------------------------------------------- /src/grid_world_base.jl: --------------------------------------------------------------------------------
using MacroTools:@forward
using Random

"""
    GridWorldBase{O} <: AbstractArray{Bool, 3}

A basic representation of grid world.
The first dimension uses multi-hot encoding to encode objects in a tile.
The second and third dimension means the height and width of the grid.
"""
struct GridWorldBase{O} <: AbstractArray{Bool, 3}
    world::BitArray{3}
    objects::O
end

# The tuple of objects; its order is the order of the layers along dimension 1.
get_object(w::GridWorldBase) = w.objects

# Create a world of size (length(objects), x, y) with every bit set to false.
function GridWorldBase(objects::Tuple{Vararg{AbstractObject}}, x::Int, y::Int)
    world = BitArray{3}(undef, length(objects), x, y)
    fill!(world, false)
    GridWorldBase(world, objects)
end

# size / getindex / setindex! are delegated to the wrapped BitArray.
@forward GridWorldBase.world Base.size, Base.getindex, Base.setindex!

# Translate an object into the index of its layer along the first dimension.
# The position of the object's type inside the tuple type `O` is looked up once
# per type combination, when the method body is generated. Inside the generator
# `x` is bound to the argument's type, which is what the error message shows.
# The lookup uses a broadcast rather than an anonymous function because
# generated functions must not define closures.
@generated function Base.to_index(::GridWorldBase{O}, x::X) where {X<:AbstractObject, O}
    i = findfirst(X .=== O.parameters)
    i === nothing && error("unknown object $x")
    :($i)
end

# Object-addressed element access: `w[obj, x, y]` and `w[obj, CartesianIndex(x, y)]`.
Base.setindex!(w::GridWorldBase, v::Bool, o::AbstractObject, x::Int, y::Int) = setindex!(w.world, v, Base.to_index(w, o), x, y)
Base.setindex!(w::GridWorldBase, v::Bool, o::AbstractObject, i::CartesianIndex{2}) = setindex!(w, v, o, i[1], i[2])

Base.getindex(w::GridWorldBase, o::AbstractObject, x::Int, y::Int) = getindex(w.world, Base.to_index(w, o), x, y)
Base.getindex(w::GridWorldBase, o::AbstractObject, i::CartesianIndex{2}) = getindex(w, o, i[1], i[2])
Base.getindex(w::GridWorldBase, o::AbstractObject, x::Colon, y::Colon) = getindex(w.world, Base.to_index(w, o), x, y)

#####
# utils
#####

# Swap layer `x` between the tiles `src` and `dest`.
# Returns the pair of values that was assigned, as the one-line form did.
function switch!(world::GridWorldBase, x, src::CartesianIndex{2}, dest::CartesianIndex{2})
    world[x, src], world[x, dest] = world[x, dest], world[x, src]
end

# Swap the complete contents (all layers) of the tiles `src` and `dest`.
function switch!(world::GridWorldBase, src::CartesianIndex{2}, dest::CartesianIndex{2})
    for x in axes(world, 1)
        switch!(world, x, src, dest)
    end
end

"""
    rand(f::Function, w::GridWorldBase; max_try=typemax(Int), rng=Random.GLOBAL_RNG)

Draw random tile positions until `f`, called with a view of all layers at that
tile, returns `true`, and return that position. After `max_try` unsuccessful
draws a warning is logged and `nothing` is returned.
"""
function Random.rand(f::Function, w::GridWorldBase; max_try=typemax(Int), rng=Random.GLOBAL_RNG)
    inds = CartesianIndices((size(w, 2), size(w, 3)))
    for _ in 1:max_try
        pos = rand(rng, inds)
        f(view(w, :, pos)) && return pos
    end
    @warn "a rare case happened when sampling from GridWorldBase"
    return nothing
end

#####
# get_agent_view
#####

# Grid indices covered by a view of size (m, n) for an agent at (i, j):
# m tiles along the facing direction, starting at the agent's own tile, and
# n tiles across it, placed around the agent. The result may extend beyond the
# grid; callers filter out invalid positions.
get_agent_view_inds((i, j), (m, n), ::Left) = CartesianIndices((i-(n-1)÷2:i+(n-(n-1)÷2)-1, j-m+1:j))
get_agent_view_inds((i, j), (m, n), ::Right) = CartesianIndices((i-(n-1)÷2:i+(n-(n-1)÷2)-1, j:j+m-1))
get_agent_view_inds((i, j), (m, n), ::Up) = CartesianIndices((i-m+1:i, j-(n-1)÷2:j+(n-(n-1)÷2)-1))
get_agent_view_inds((i, j), (m, n), ::Down) = CartesianIndices((i:i+m-1, j-(n-1)÷2:j+(n-(n-1)÷2)-1))

# Map a 1-based position (i, j) inside the block returned by
# `get_agent_view_inds` to a position in the (m, n) view array. `Down` is the
# identity; the other directions are rotated onto that layout.
ind_map((i,j), (m, n), ::Left) = (m-j+1, i)
ind_map((i,j), (m, n), ::Right) = (j, n-i+1)
ind_map((i,j), (m, n), ::Up) = (m-i+1, n-j+1)
ind_map((i,j), (m, n), ::Down) = (i,j)

"""
    get_agent_view!(v, a, p::CartesianIndex, dir::LRUD)

Copy into `v` the tiles of `a` seen by an agent at `p` facing `dir`; the view
size is taken from the second and third dimension of `v`. Entries of `v` whose
source position lies outside `a` are left untouched. Returns `v`.
"""
function get_agent_view!(v::AbstractArray{Bool,3}, a::AbstractArray{Bool,3}, p::CartesianIndex, dir::LRUD)
    view_size = (size(v, 2), size(v, 3))
    grid_size = (size(a,2),size(a,3))
    inds = get_agent_view_inds(p.I, view_size, dir)
    valid_inds = CartesianIndices(grid_size)
    # `ind` is the 1-based position inside the block, `inds[ind]` the grid position.
    for ind in CartesianIndices(inds)
        if inds[ind] ∈ valid_inds
            v[:, ind_map(ind.I, view_size, dir)...] .= a[:, inds[ind]]
        end
    end
    return v
end
--------------------------------------------------------------------------------