├── REQUIRE
├── test
│   ├── runtests.jl
│   └── denv.jl
├── .gitignore
├── Project.toml
├── src
│   ├── DistributedEnv.jl
│   └── denv.jl
├── README.md
└── benchmark
    └── benchmark.jl
/REQUIRE:
--------------------------------------------------------------------------------
julia 0.7
Distributed
ReinforcementLearningBase
--------------------------------------------------------------------------------
/test/runtests.jl:
--------------------------------------------------------------------------------
using Test

include("denv.jl")
--------------------------------------------------------------------------------
/test/denv.jl:
--------------------------------------------------------------------------------
using DistributedEnv
using Distributed
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.jl.cov
*.jl.*.cov
*.jl.mem
deps/deps.jl
Manifest.toml
--------------------------------------------------------------------------------
/Project.toml:
--------------------------------------------------------------------------------
name = "DistributedEnv"
uuid = "4c74d760-a68f-11e8-0c49-2ba71e3afc10"
authors = ["Jun Tian "]
version = "0.1.0"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ReinforcementLearningBase = "9b2b9cba-ac73-11e8-02b1-9f0869453fc0"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
--------------------------------------------------------------------------------
/src/DistributedEnv.jl:
--------------------------------------------------------------------------------
module DistributedEnv
export interact!, reset!, getstate, actionspace

import ReinforcementLearningBase: interact!, reset!, getstate, actionspace

include("denv.jl")

# `interact!`, `reset!` and `getstate` are asynchronous: `send` puts a message
# into the remote mailbox and returns a `Future` holding the result.
# `actionspace` returns the space cached locally at construction time.
interact!(denv::RemoteEnv, action) = send(denv, :interact!, action)
reset!(denv::RemoteEnv) = send(denv, :reset!)
getstate(denv::RemoteEnv) = send(denv, :getstate)
actionspace(denv::RemoteEnv) = denv.actionspace

end # module
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DistributedEnv.jl

This package provides a thin wrapper that lets different reinforcement learning environments run in parallel on worker processes.

In the current implementation, an environment runs indefinitely on a worker process as a `Task` bound to a `RemoteChannel`. For lightweight environments (like CartPole), the messaging overhead dominates, so this implementation is not efficient. In the future, a global scheduler will be added to orchestrate remote environments, actors, and learners, similar to the [IMPALA](https://deepmind.com/blog/impala-scalable-distributed-deeprl-dmlab-30/) architecture proposed by DeepMind.
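To make that concrete, here is a minimal sketch of the underlying pattern, not the package's actual API: a long-lived task on a worker is bound to a `Channel` acting as its mailbox, and the master talks to it through a `RemoteChannel`. The payload type and handler are illustrative; the real implementation in `src/denv.jl` sends `Message` structs instead of integers.

```julia
using Distributed
addprocs(1)
pid = first(workers())

# A long-lived task on the worker, bound to a channel that acts as its mailbox.
mailbox = RemoteChannel(pid) do
    Channel(; ctype=Int, csize=32) do c
        while true
            x = take!(c)                        # block until a message arrives
            println("worker ", myid(), " got ", x)
        end
    end
end

put!(mailbox, 42)   # send a message to the worker's task
```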
## Install

```
(v1.0) pkg> add https://github.com/JuliaReinforcementLearning/DistributedEnv.jl.git
```

## How to use?

```julia
julia> using Distributed

julia> addprocs()
4-element Array{Int64,1}:
 2
 3
 4
 5

julia> @everywhere using DistributedEnv

julia> @everywhere using ReinforcementLearningEnvironmentClassicControl

julia> envs = [RemoteEnv(CartPole; pid=x) for x in workers()]
4-element Array{RemoteEnv{CartPole},1}:
 RemoteEnv{CartPole}(RemoteChannel{Channel{DistributedEnv.Message}}(2, 1, 38), ReinforcementLearningBase.DiscreteSpace(2, 1))
 RemoteEnv{CartPole}(RemoteChannel{Channel{DistributedEnv.Message}}(3, 1, 43), ReinforcementLearningBase.DiscreteSpace(2, 1))
 RemoteEnv{CartPole}(RemoteChannel{Channel{DistributedEnv.Message}}(4, 1, 48), ReinforcementLearningBase.DiscreteSpace(2, 1))
 RemoteEnv{CartPole}(RemoteChannel{Channel{DistributedEnv.Message}}(5, 1, 53), ReinforcementLearningBase.DiscreteSpace(2, 1))
```
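Each wrapped call is dispatched to the environment's worker and returns a `Future`; `fetch` it to get the result. For example (constructed from the same API calls used in the benchmark script below):

```julia
env = envs[1]
fetch(reset!(env))

# `interact!` sends the action to the worker and returns a Future;
# fetching it blocks until the remote environment has stepped.
state, reward, isdone = fetch(interact!(env, sample(actionspace(env))))
```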
--------------------------------------------------------------------------------
/benchmark/benchmark.jl:
--------------------------------------------------------------------------------
# Install the packages below first.
using BenchmarkTools
using ReinforcementLearningEnvironmentClassicControl
using DistributedEnv
using Distributed

addprocs(4)
@everywhere using DistributedEnv
@everywhere using ReinforcementLearningEnvironmentClassicControl

function localenv(env, N)
    reset!(env)
    n_actions = 1
    while n_actions < N
        state, reward, isdone = interact!(env, sample(actionspace(env)))
        n_actions += 1
        isdone && reset!(env)   # start a new episode once the current one ends
    end
end

function remoteenv(envs, N)
    map(envs) do env
        fetch(reset!(env))
    end
    n_actions = length(envs)
    while n_actions < N
        # Note: each `fetch` blocks until that worker has stepped, so the
        # environments are stepped round-robin rather than truly in parallel.
        map(envs) do env
            _, _, isdone = fetch(interact!(env, sample(actionspace(env))))
            isdone && reset!(env)
        end
        n_actions += length(envs)
    end
end

N = 10000

@benchmark localenv(env, N) setup=(env = CartPole()) teardown=(env = nothing)

# BenchmarkTools.Trial:
#   memory estimate:  3.20 MiB
#   allocs estimate:  30000
#   --------------
#   minimum time:     2.059 ms (0.00% GC)
#   median time:      2.096 ms (0.00% GC)
#   mean time:        2.281 ms (5.52% GC)
#   maximum time:     50.664 ms (94.55% GC)
#   --------------
#   samples:          2178
#   evals/sample:     1

envs = [RemoteEnv(CartPole; pid=x) for x in workers()]
@benchmark remoteenv(envs, N)

# BenchmarkTools.Trial:
#   memory estimate:  36.58 MiB
#   allocs estimate:  976091
#   --------------
#   minimum time:     143.638 ms (5.72% GC)
#   median time:      177.707 ms (13.51% GC)
#   mean time:        185.569 ms (16.34% GC)
#   maximum time:     289.228 ms (33.72% GC)
#   --------------
#   samples:          27
#   evals/sample:     1
--------------------------------------------------------------------------------
/src/denv.jl:
--------------------------------------------------------------------------------
using Distributed
using ReinforcementLearningBase
export sample
export RemoteEnv, whereis, send

# A single request to a remote environment: which method to call, with what
# arguments, and where to put the result.
struct Message
    resbox::Distributed.AbstractRemoteRef
    method::Symbol
    args::Tuple
    kw::Iterators.Pairs
end

struct RemoteEnv{T <: AbstractEnv}
    mailbox::RemoteChannel{Channel{Message}}
    actionspace::AbstractSpace
end

"""
    RemoteEnv(envtype::Type{T}, args...; pid::Int=myid(), kw...) where T <: AbstractEnv

Create an environment on a worker. `args` and `kw` are passed to `envtype` to
construct the environment on the worker specified by `pid`.
"""
function RemoteEnv(envtype::Type{T}, args...; pid::Int=myid(), kw...) where T <: AbstractEnv
    mailbox = RemoteChannel(pid) do
        Channel(;ctype=Message, csize=Inf) do c
            try
                env = envtype(args...; kw...)
                # Serve requests forever: look up the requested method in the
                # worker's Main module, apply it to the environment, and put
                # the result into the caller's result box.
                while true
                    msg = take!(c)
                    method = getfield(Main, msg.method)
                    put!(msg.resbox, method(env, msg.args...; msg.kw...))
                end
            catch e
                @error "remote environment task failed" exception=(e, catch_backtrace())
            end
        end
    end
    actionspace = send(mailbox, :actionspace) |> fetch
    RemoteEnv{T}(mailbox, actionspace)
end

"Return the worker id of a `RemoteEnv`."
whereis(env::RemoteEnv) = env.mailbox.where

"""
Put a `Message` into the mailbox and return a `Future` that will hold the
result once the remote environment has processed the message.
"""
function send(mailbox::RemoteChannel{Channel{Message}}, method::Symbol, args...; kw...)
    resbox = Future(mailbox.where)
    put!(mailbox, Message(resbox, method, args, kw))
    resbox
end

send(env::RemoteEnv, method::Symbol, args...; kw...) = send(env.mailbox, method, args...; kw...)
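# Illustrative usage (assumes the worker setup from the README, with
# `CartPole` defined on the worker's Main):
#
#   env = RemoteEnv(CartPole; pid = 2)
#   whereis(env)                  # -> 2
#   fetch(send(env, :getstate))   # call any function defined in the worker's Main
--------------------------------------------------------------------------------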