├── .gitattributes
├── .gitignore
├── README.md
├── docs
├── figs
│ └── srl_architecture.png
├── system_components
│ ├── 00_system_overview.md
│ ├── 01_worker_base.md
│ ├── 02_actor_worker.md
│ ├── 03_policy_worker.md
│ ├── 04_trainer_worker.md
│ ├── 05_buffer_worker.md
│ ├── 06_eval_manager.md
│ ├── 07_inference_stream.md
│ ├── 08_sample_stream.md
│ └── 09_basic_utils.md
└── user_guide
│ ├── 00_overview.md
│ ├── 01_experiment_config.md
│ ├── 02_logging.md
│ ├── 03_environments.md
│ ├── 04_policy_algorithm.md
│ └── 05_named_array.md
├── setup.py
├── src
└── rlsrl
│ ├── api
│ ├── __init__.py
│ ├── config.py
│ ├── curriculum.py
│ ├── env_utils.py
│ ├── environment.py
│ ├── policy.py
│ └── trainer.py
│ ├── apps
│ └── main.py
│ ├── base
│ ├── __init__.py
│ ├── buffer.py
│ ├── conditions.py
│ ├── gpu_utils.py
│ ├── lock.py
│ ├── name_resolve.py
│ ├── namedarray.py
│ ├── names.py
│ ├── network.py
│ ├── numpy_utils.py
│ ├── segment_tree.py
│ ├── shared_memory.py
│ ├── timeutil.py
│ └── user.py
│ ├── legacy
│ ├── __init__.py
│ ├── algorithm
│ │ ├── __init__.py
│ │ ├── modules
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── autoreset_rnn.py
│ │ │ ├── cnn.py
│ │ │ ├── gae.py
│ │ │ ├── popart.py
│ │ │ ├── recurrent_backbone.py
│ │ │ └── utils.py
│ │ └── ppo
│ │ │ ├── __init__.py
│ │ │ ├── actor_critic_policies
│ │ │ ├── __init__.py
│ │ │ ├── actor_critic_policy.py
│ │ │ └── utils.py
│ │ │ ├── mappo.py
│ │ │ └── phasic_policy_gradient.py
│ ├── environment
│ │ ├── __init__.py
│ │ └── atari
│ │ │ ├── __init__.py
│ │ │ ├── atari_env.py
│ │ │ └── atari_wrappers.py
│ └── experiments
│ │ ├── __init__.py
│ │ ├── atari_benchmark.py
│ │ └── atari_remote.py
│ ├── system
│ ├── __init__.py
│ ├── api
│ │ ├── inference_stream.py
│ │ ├── parameter_db.py
│ │ ├── sample_stream.py
│ │ ├── worker_base.py
│ │ └── worker_control.py
│ └── impl
│ │ ├── __init__.py
│ │ ├── actor_worker.py
│ │ ├── dummy_worker.py
│ │ ├── eval_manager.py
│ │ ├── inline_inference.py
│ │ ├── local_inference.py
│ │ ├── local_sample.py
│ │ ├── master_worker.py
│ │ ├── policy_worker.py
│ │ ├── remote_inference.py
│ │ ├── remote_sample.py
│ │ └── trainer_worker.py
│ └── testing
│ ├── __init__.py
│ ├── aerochess_env.py
│ ├── null_trainer.py
│ └── random_policy.py
└── tests
└── system
├── actor_worker_test.py
├── eval_manager_test.py
├── inference_stream_test.py
├── policy_worker_test.py
├── sample_stream_test.py
└── trainer_worker_test.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | bazel-*
3 | *__pycache__*
4 | _experiments
5 | .vscode
6 | wandb
7 | *.pyc
8 | *.log
9 | *.sh
10 | *.egg-info
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **Note**
2 | This repository has moved to [https://github.com/openpsi-project/srl](https://github.com/openpsi-project/srl) and is no longer maintained. Please follow our latest updates in the new repository!
3 |
4 | # SRL (**R**ea**L**ly **S**calable **RL**): Scaling Distributed Reinforcement Learning to Over Ten Thousand Cores
5 |
6 | **SRL** is an _efficient_, _scalable_ and _extensible_ distributed **Reinforcement Learning** system. **SRL** supports running several state-of-the-art RL algorithms on common environments from a single simple configuration file, and exposes general APIs for users to develop their own environments, policies and algorithms. **SRL** even allows users to implement new system components to support their algorithm designs when the current system architecture is not sufficient.
7 |
8 | Currently, our `slurm`-based scheduler is not released. We are planning to implement a `ray`-based launcher so that users can easily deploy SRL at large scale!
9 |
10 | ## Algorithms and Environments
11 |
12 | In this repository, one algorithm (**[Proximal Policy Optimization](https://arxiv.org/abs/1707.06347)**) and five environments (**[Gym Atari](https://www.gymlibrary.dev/environments/atari/), [Google football](https://github.com/google-research/football), [Gym MuJoCo](https://www.gymlibrary.dev/environments/mujoco/), [Hide and Seek](https://openai.com/blog/emergent-tool-use/), [SMAC](https://github.com/oxwhirl/smac)**) are implemented as examples. In the future, support for more environments and algorithms will be added to build an RL library with SRL.
13 |
14 | ## Installation
15 |
16 | Before installation, make sure you have `python>=3.8`, `torch>=1.10.0` and `gym` installed. [Wandb](https://wandb.ai/) is also supported; please install the `wandb` package if you intend to use it for logging. You should also install the environments you intend to run; see the links to the supported environments in the previous section for more information. (Note that the **Google football** environment requires an older version, `gym==0.21.0`.)
17 |
18 | The contents of this repository can be installed as a Python package. To install, clone this repository and install the package with:
19 |
20 | `git clone https://github.com/openpsi-projects/srl.git`
21 |
22 | `cd srl && pip install -e .`
23 |
24 | ## Running an Experiment
25 |
26 | After installing **SRL** and the Atari environment, run the simple example experiment we provide:
27 |
28 | `srl-local run -e atari-mini -f test`
29 |
30 | This command starts a simple PPO training run on the Atari environment, defined by:
31 |
32 | - Experiment config: [src/rlsrl/legacy/experiments/atari_benchmark.py](src/rlsrl/legacy/experiments/atari_benchmark.py)
33 |
34 | - Atari environment implementation: [src/rlsrl/legacy/environment/atari/atari_env.py](src/rlsrl/legacy/environment/atari/atari_env.py)
35 |
36 | - Algorithm and policy implementation: [src/rlsrl/legacy/algorithm/ppo/](src/rlsrl/legacy/algorithm/ppo/)
37 |
38 | ## Documentation
39 |
40 | For more user guides:
41 | - [Users Guide](docs/user_guide/00_overview.md)
42 |
43 | For more information about **SRL**:
44 | - [System Components](docs/system_components/00_system_overview.md)
45 |
46 | ## Full paper
47 |
48 | The full paper, **SRL: Scaling Distributed Reinforcement Learning to Over Ten Thousand Cores**, is available on arXiv! Link: **[https://arxiv.org/abs/2306.16688](https://arxiv.org/abs/2306.16688)**
49 |
--------------------------------------------------------------------------------
/docs/figs/srl_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openpsi-projects/srl/97b3ae4e5fab00f6da90f81679dc035a9079eb09/docs/figs/srl_architecture.png
--------------------------------------------------------------------------------
/docs/system_components/00_system_overview.md:
--------------------------------------------------------------------------------
1 | # System Components: Overview
2 |
3 | In this section, basic system components will be introduced to give users a first impression of how the system works.
4 |
5 | ## Workers
6 |
7 | Workers are stateful computational components that follow a pre-defined workflow. In the distributed version of the system, workers are launched at system startup as RPC servers that receive requests from a centralized controller, which calls remote methods on the workers to control their life cycle. To launch a local experiment, `apps/main.py` launches all workers and manages them as `multiprocessing.Process` instances.
8 |
9 | In our original system, there are 3 types of workers: [Actor Workers](02_actor_worker.md), [Policy Workers](03_policy_worker.md) and [Trainer Workers](04_trainer_worker.md), corresponding to the simulation, inference and training tasks in RL. They are connected to each other by data streams (Inference Stream and Sample Stream). The relationship between these components is shown below:
10 |
11 |
12 | ![SRL architecture](../figs/srl_architecture.png)
13 |
14 |
15 | Users can implement their own workers to support self-designed RL algorithms. General APIs for workers are provided by the base class `worker_base.Worker`. See [Worker Base](01_worker_base.md) for details.
16 |
17 | ## Data Streams
18 |
19 |
20 | Data streams are components that handle communication between workers. In our original RL system design, there are two types of data streams: the Inference Stream, which exchanges observations and actions between actor workers and policy workers, and the Sample Stream, which sends training samples from actor workers to trainer workers. In the distributed version, data streams are implemented in ZMQ (https://zeromq.org/), while in the local version, they are implemented in shared memory.
21 |
22 | ### Inference Stream
23 |
24 | Inference streams gather and batch inference requests from actor workers, send them to policy workers, and relay the policy workers' responses back to the actor workers. The inference stream is duplex. _Inference Clients_ and _Inference Servers_ of an inference stream can be created in endpoint workers as handles for sending and receiving data.
25 |
26 | - _Inference Client_ is responsible for gathering/batching requests from actor workers and receiving responses from inference servers.
27 | - _Inference Server_ is responsible for receiving requests and distributing inference results.
28 | - Inline Inference: when the CPU is used as the inference device, actor and policy workers reside in the same process. The inline inference stream avoids unnecessary communication in this situation.
29 |
30 | - The distributed inference stream is implemented in ZMQ (https://zeromq.org/).
31 |     - Requests: PUSH-PULL pattern sockets.
32 |     - Responses: PUB-SUB pattern sockets.
33 |
34 | See [Inference Stream](07_inference_stream.md) for detailed APIs and usage.
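The client/server roles above can be illustrated with a minimal single-process sketch. It assumes a hypothetical API (`ToyInferenceClient`, `ToyInferenceServer`, `post_request`, `flush`, `poll_responses`, `respond`) and uses `queue.Queue` in place of ZMQ sockets or shared memory, so it is not the rlsrl `inference_stream` interface. The client buffers requests and sends them as one batch on `flush()`; the server replies with a dictionary keyed by request id.

```python
import itertools
import queue


class ToyInferenceClient:
    """Hypothetical client: buffers requests, sends them as one batch on flush()."""

    def __init__(self, request_queue, response_queue):
        self._request_queue = request_queue
        self._response_queue = response_queue
        self._ids = itertools.count()
        self._pending = []   # requests posted since the last flush
        self._results = {}   # request id -> result, filled by poll_responses()

    def post_request(self, observation):
        request_id = next(self._ids)
        self._pending.append((request_id, observation))
        return request_id

    def flush(self):
        # Send all buffered requests as a single message.
        if self._pending:
            self._request_queue.put(self._pending)
            self._pending = []

    def poll_responses(self):
        # Drain every response batch that has arrived so far.
        while True:
            try:
                self._results.update(self._response_queue.get_nowait())
            except queue.Empty:
                return

    def is_ready(self, request_id):
        return request_id in self._results

    def consume_result(self, request_id):
        return self._results.pop(request_id)


class ToyInferenceServer:
    """Hypothetical server: receives request batches and replies with results."""

    def __init__(self, request_queue, response_queue):
        self._request_queue = request_queue
        self._response_queue = response_queue

    def poll_requests(self):
        requests = []
        while True:
            try:
                requests.extend(self._request_queue.get_nowait())
            except queue.Empty:
                return requests

    def respond(self, results):
        # results: dict mapping request id -> inference result
        self._response_queue.put(results)


# Usage: one actor-side client and one policy-side server.
requests, responses = queue.Queue(), queue.Queue()
client = ToyInferenceClient(requests, responses)
server = ToyInferenceServer(requests, responses)

ids = [client.post_request(obs) for obs in (1.0, 2.0, 3.0)]
client.flush()
batch = server.poll_requests()                         # [(0, 1.0), (1, 2.0), (2, 3.0)]
server.respond({rid: obs * 2 for rid, obs in batch})   # stand-in "policy"
client.poll_responses()
actions = [client.consume_result(rid) for rid in ids]  # [2.0, 4.0, 6.0]
```

In the distributed stream, responses are published over PUB-SUB sockets, so a real client would also have to ignore results addressed to other clients; the sketch avoids this by having only one client.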
35 |
36 | ### Sample Stream
37 |
38 | Sample streams receive and batch training samples from actor workers and send them to trainer workers. The sample stream is simplex (one-directional). _Sample Producers_ and _Sample Consumers_ of a sample stream can be created in endpoint workers as handles for sending and receiving data.
39 |
40 | - _Sample Producer_ is responsible for batching samples and sending out samples.
41 | - _Sample Consumer_ is responsible for consuming samples (into buffers).
42 |
43 | - Samples are batched twice before being consumed.
44 |     - Upon sending, actors batch the samples along the _time_ dimension.
45 |     - Upon consuming, the buffer batches the samples along the _batch_ dimension.
46 |
47 | - The distributed sample stream is implemented in ZMQ.
48 |     - PUSH-PULL pattern sockets.
49 |
50 | See [Sample Stream](08_sample_stream.md) for detailed APIs and usage.
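The two batching stages can be sketched with NumPy. This is a minimal illustration that assumes plain dictionaries of arrays (rather than the repository's own named-array type) and a time-major `[T, B, ...]` layout; the function names are hypothetical.

```python
import numpy as np


def batch_along_time(steps):
    """Actor side: stack T per-step dicts into one sample with leading dim [T, ...]."""
    return {key: np.stack([step[key] for step in steps], axis=0) for key in steps[0]}


def batch_along_batch(samples):
    """Buffer side: stack B samples of shape [T, ...] into [T, B, ...]."""
    return {key: np.stack([s[key] for s in samples], axis=1) for key in samples[0]}


# Usage: 4 actors each send a sample of 8 steps; observations have 3 features.
rng = np.random.default_rng(0)
samples = []
for _ in range(4):
    steps = [{"obs": rng.normal(size=3), "reward": np.float32(1.0)} for _ in range(8)]
    samples.append(batch_along_time(steps))   # obs: (8, 3), reward: (8,)

train_batch = batch_along_batch(samples)      # obs: (8, 4, 3), reward: (8, 4)
```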
51 |
52 | ### Generality
53 |
54 | Both sample streams and inference streams can be used in a more general way. To support communication between user-implemented workers, the payload of a data stream can be data other than inference requests, replies and training samples. In our experience, sample streams and inference streams can satisfy the communication requirements between any types of workers.
55 |
56 | ## Parameter Database
57 |
58 | The Parameter Database stores model parameters, and workers can push/get parameters via ParameterDBClient.
59 |
60 | - ParameterDB is currently implemented in file systems (NFS for distributed, local file system for local).
61 | - Metadata query is supported via MongoDB (distributed).
62 |
63 | See [Basic Utils](09_basic_utils.md) for further details.
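The file-system flavour of the parameter database can be sketched as follows. `FileParameterDB`, its `push`/`get` methods and the `<root>/<experiment>/<policy>/<version>.pkl` layout are hypothetical illustrations, not the actual `parameter_db.py` API or on-disk format. The sketch only shows the idea of versioned checkpoints on a shared file system, with an atomic rename so that readers never load a half-written file.

```python
import os
import pickle
import tempfile


class FileParameterDB:
    """Hypothetical file-system parameter DB: one pickle file per version."""

    def __init__(self, root):
        self._root = root

    def _policy_dir(self, experiment, policy):
        return os.path.join(self._root, experiment, policy)

    def push(self, experiment, policy, params, version):
        directory = self._policy_dir(experiment, policy)
        os.makedirs(directory, exist_ok=True)
        tmp_path = os.path.join(directory, f".{version}.tmp")
        with open(tmp_path, "wb") as f:
            pickle.dump(params, f)
        # Atomic rename: readers see either no file or a complete one.
        os.replace(tmp_path, os.path.join(directory, f"{version}.pkl"))

    def get(self, experiment, policy, version=None):
        """Return (version, params); the latest version if none is given."""
        directory = self._policy_dir(experiment, policy)
        if version is None:
            versions = [int(name[:-len(".pkl")]) for name in os.listdir(directory)
                        if name.endswith(".pkl")]
            if not versions:
                raise FileNotFoundError(f"no checkpoint under {directory}")
            version = max(versions)
        with open(os.path.join(directory, f"{version}.pkl"), "rb") as f:
            return version, pickle.load(f)


# Usage: push two versions, then read the latest one.
db = FileParameterDB(tempfile.mkdtemp())
db.push("atari-mini", "default", {"w": [0.0]}, version=1)
db.push("atari-mini", "default", {"w": [0.5]}, version=2)
version, params = db.get("atari-mini", "default")
```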
64 |
65 | ## Name Resolving and Name Record Repositories
66 |
67 | Workers exchange system-level metadata, including addresses, ports and peers, via Name Resolving. Name record repositories are databases that store system-level metadata and support various data operations. Name Resolving is currently used mostly in the distributed version of the system.
68 |
69 | Name Resolving is currently used in the following ways:
70 | - Workers save their listening ports to name resolve for the controller.
71 | - Inference servers reveal their inbound addresses to clients.
72 | - Sample consumers reveal their inbound addresses to producers.
73 | - Trainers reveal their identities so that they can find their DDP peers.
74 |
75 | Name record repositories implementations:
76 | - MemoryNameRecordRepository (Local only. No inter-node communication.)
77 | - NFSNameRecordRepository (Distributed, requires NFS support in cluster)
78 | - RedisNameRecordRepository (Distributed, requires [redis](https://redis.io/docs/) support in cluster. Recommended.)
79 |
80 | See [Basic Utils](09_basic_utils.md) for further details.
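A name record repository is essentially a key-value store with hierarchical keys. The sketch below is a hypothetical in-memory version in the spirit of `MemoryNameRecordRepository`; the class name, method names, key layout and example addresses are assumptions, not the rlsrl API. It shows how trainers could register and then discover their DDP peers by listing a key prefix.

```python
import threading


class MemoryNameRecords:
    """Hypothetical in-memory name record repository (single process only)."""

    def __init__(self):
        self._records = {}
        self._lock = threading.Lock()  # threads may share one instance

    def add(self, name, value, replace=False):
        with self._lock:
            if name in self._records and not replace:
                raise KeyError(f"name already registered: {name}")
            self._records[name] = value

    def get(self, name):
        with self._lock:
            return self._records[name]

    def get_subtree(self, prefix):
        """Return the values of all names under `prefix/`, sorted."""
        prefix = prefix.rstrip("/") + "/"
        with self._lock:
            return sorted(v for k, v in self._records.items() if k.startswith(prefix))


# Usage: an inference server and two DDP peers register themselves.
names = MemoryNameRecords()
names.add("exp/trial/inference_server/default", "tcp://10.0.0.5:5555")
names.add("exp/trial/ddp_peers/0", "10.0.0.7:29500")
names.add("exp/trial/ddp_peers/1", "10.0.0.8:29500")
peers = names.get_subtree("exp/trial/ddp_peers")
```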
81 |
82 | # Related References
83 |
84 | - [Worker Base](01_worker_base.md)
85 | - [Actor Worker](02_actor_worker.md)
86 | - [Policy Worker](03_policy_worker.md)
87 | - [Trainer Worker](04_trainer_worker.md)
88 | - [Inference Stream](07_inference_stream.md)
89 | - [Sample Stream](08_sample_stream.md)
90 |
91 | # Related Files or Directories
92 | - [system/api/inference_stream.py](../../src/rlsrl/system/api/inference_stream.py)
93 | - [system/api/parameter_db.py](../../src/rlsrl/system/api/parameter_db.py)
94 | - [system/api/sample_stream.py](../../src/rlsrl/system/api/sample_stream.py)
95 | - [system/api/worker_base.py](../../src/rlsrl/system/api/worker_base.py)
96 |
97 | # What's Next
98 |
99 | - [Worker Base](01_worker_base.md)
100 |
--------------------------------------------------------------------------------
/docs/system_components/01_worker_base.md:
--------------------------------------------------------------------------------
1 | # System Components: Worker Base
2 | ### Base class for workers: `worker_base.Worker`
3 |
4 | The base class of all workers. The common handles implemented in class `worker_base.Worker` are:
5 |
6 | - configure
7 | - start
8 | - pause
9 | - exit
10 | - run
11 |
12 | When the system is launched, a centralized controller launches the workers (in this local version, via Python `multiprocessing`) and calls these handles to control the workers' life cycle.
13 |
14 | ### Worker Status
15 |
16 | A live worker has two status flags: running and exiting. A worker is neither running nor exiting when it has not been configured yet or is paused. By default, the following three handles only change the worker status:
17 |
18 | - `start`: set running to true.
19 | - `pause`: set running to false.
20 | - `exit`: set exiting to true.
21 |
22 | ### Life Cycle
23 |
24 | In the local version, worker life cycles are simplified: after the workers are launched, only three methods are called before they exit, without any pausing: `configure()`, `start()` and `run()`. In the distributed version, the controller uses the other handles to provide functions such as reconfiguring, pausing and monitoring.
25 |
26 | When `configure()` is called, the `_configure()` method implemented in subclasses runs to parse worker-specific configs and perform initialization. After that, the worker runs common configuration, such as `wandb` logging setup.
27 |
28 | When `run()` is called, the worker loops over `_poll()` indefinitely, polling while it is running (i.e. after `start()` has been called). After every `_poll()`, the worker logs the default stats and the worker stats specified in `_stats()` to standard output and/or `wandb`.
29 |
30 | ### Implement subclasses of `worker_base.Worker`
31 |
32 | To implement your own type of worker, you should **at least** override the methods `_configure()` and `_poll()`.
33 |
34 | In `_configure()`, you may initialize your worker by completing the following steps:
35 | 1. Configure the required parameters and variables (by reading them from the configuration file).
36 | 2. Initialize I/O, storage or computation components, and assign devices and computing resources.
37 | 3. Initialize and start threads if your worker is implemented in a threaded fashion. Details are discussed in class `MappingThread`.
38 | 4. Configure worker-specific monitoring.
39 |
40 | In `_poll()`, you should implement the main computing step of your worker. The return values (sample count and batch count) can be defined as whatever you want to log.
41 |
42 | In addition, you can override the handles that change worker status, `pause()`, `start()` and `exit()`, to customize worker behavior on status changes.
43 |
44 | To record extra information while the worker is running, specify it in `_stats()` to log it to standard output or `wandb`.
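The life cycle and subclassing rules above can be condensed into a short sketch. `ToyWorker` is a hypothetical stand-in for `worker_base.Worker`, not its actual implementation: it keeps the `running`/`exiting` flags, calls `_configure()` from `configure()`, and loops over `_poll()` in `run()`, recording `_stats()` in a list instead of printing to standard output or `wandb`. The `max_polls` argument is an assumption added only so the example terminates.

```python
import time


class ToyWorker:
    """Hypothetical stand-in for worker_base.Worker, reduced to its life cycle."""

    def __init__(self):
        self.running = False
        self.exiting = False
        self.logged = []  # stands in for stdout / wandb logging

    # Handles called by the controller.
    def configure(self, config):
        self._configure(config)  # worker-specific parsing and initialization
        # Common configuration (e.g. wandb setup) would follow here.

    def start(self):
        self.running = True

    def pause(self):
        self.running = False

    def exit(self):
        self.exiting = True

    def run(self, max_polls=None):
        polls = 0
        while not self.exiting:
            if not self.running:
                time.sleep(0.01)  # paused or not started yet: idle
                continue
            sample_count, batch_count = self._poll()
            self.logged.append({"sample_count": sample_count,
                                "batch_count": batch_count, **self._stats()})
            polls += 1
            if max_polls is not None and polls >= max_polls:
                break

    # Methods a subclass overrides.
    def _configure(self, config):
        raise NotImplementedError()

    def _poll(self):
        raise NotImplementedError()

    def _stats(self):
        return {}


class CounterWorker(ToyWorker):
    """Example subclass: each poll 'processes' step_size samples in one batch."""

    def _configure(self, config):
        self.step_size = config["step_size"]
        self.total = 0

    def _poll(self):
        self.total += self.step_size
        return self.step_size, 1  # (sample count, batch count)

    def _stats(self):
        return {"total": self.total}


worker = CounterWorker()
worker.configure({"step_size": 4})
worker.start()
worker.run(max_polls=3)
```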
45 |
46 | ### Scheduling
47 | The local version of the system is scheduled with the default Python `multiprocessing` rules. Workers are launched as `multiprocessing.Process` instances and connected by data streams implemented in `multiprocessing.Queue`. In the distributed version, workers are scheduled via the [slurm](https://slurm.schedmd.com/documentation.html) scheduler.
48 |
49 |
50 | ## class `MappingThread`
51 |
52 | Workers can be implemented in a threaded fashion using Python's `threading` module. For example, in our original workers:
53 | - The Policy Worker has a main thread (CPU), an inference thread, and a responding thread (CPU).
54 | - The Trainer Worker has a main thread (CPU, where the buffer resides) and a training thread (GPU).
55 |
56 | `MappingThread` is a wrapper around Python's `threading.Thread` that makes it easier to implement a threaded worker. A mapping thread gets input from its upstream queue, processes the data, and outputs the results into a downstream queue.
57 |
58 | ### Initialization
59 |
60 | `MappingThread` is initialized with:
61 | 1. `map_fn`: A mapping function that takes an input and processes it to produce an output.
62 | 2. `upstream_queue`: Input queue.
63 | 3. `downstream_queue`: Optional. Output queue; should be `None` if `map_fn` produces no output.
64 | 4. `cuda_device`: Optional. CUDA device that should be used in this mapping thread.
65 |
66 | ### Handles
67 |
68 | There are 4 handles for an instance of `MappingThread`, similar to python `threading.Thread`:
69 | 1. `is_alive()`: check whether the thread is alive.
70 | 2. `start()`: Start the thread. After starting, the thread repeatedly takes input from `upstream_queue`, calls `map_fn`, and outputs the result into `downstream_queue`.
71 | 3. `join()`: Join the thread.
72 | 4. `stop()`: Stop the thread.
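A minimal sketch of a mapping thread built on `threading.Thread` and `queue.Queue`. `ToyMappingThread` is hypothetical and mirrors the constructor arguments and handles listed above, but it is not the rlsrl implementation; in particular, `cuda_device` is accepted but ignored in this CPU-only sketch.

```python
import queue
import threading


class ToyMappingThread:
    """Hypothetical mapping thread: upstream queue -> map_fn -> downstream queue."""

    def __init__(self, map_fn, upstream_queue, downstream_queue=None, cuda_device=None):
        self._map_fn = map_fn
        self._upstream = upstream_queue
        self._downstream = downstream_queue
        self._cuda_device = cuda_device  # unused in this CPU-only sketch
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        while not self._stop_event.is_set():
            try:
                # Short timeout so stop() is noticed even when no input arrives.
                item = self._upstream.get(timeout=0.05)
            except queue.Empty:
                continue
            output = self._map_fn(item)
            if self._downstream is not None:
                self._downstream.put(output)

    def is_alive(self):
        return self._thread.is_alive()

    def start(self):
        self._thread.start()

    def join(self):
        self._thread.join()

    def stop(self):
        self._stop_event.set()


# Usage: square numbers in a background thread.
up, down = queue.Queue(), queue.Queue()
mapper = ToyMappingThread(lambda x: x * x, up, down)
mapper.start()
for i in range(5):
    up.put(i)
results = [down.get(timeout=1) for _ in range(5)]
mapper.stop()
mapper.join()
```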
73 |
74 | # Related References
75 |
76 | - [System Overview](00_system_overview.md)
77 |
78 | # Related File and Directory
79 | - [system/api/worker_base.py](../../src/rlsrl/system/api/worker_base.py)
80 |
81 | # What's Next
82 |
83 | - [Actor Worker](02_actor_worker.md)
84 |
85 |
--------------------------------------------------------------------------------
/docs/system_components/02_actor_worker.md:
--------------------------------------------------------------------------------
1 | # System Components: Actor Worker
2 |
3 | _"Reality is merely a simulation of God who himself is the simulation of humanity emulated through a brief history of time."_
4 |
5 | An **Actor Worker** simulates multiple environments to produce samples (training data) for RL training. In a single actor worker, multiple copies of an environment receive actions from [Policy Workers](03_policy_worker.md) via [Inference Clients](07_inference_stream.md), run simulation steps to obtain observations, and send these observations to [Policy Workers](03_policy_worker.md) via [Inference Clients](07_inference_stream.md). After collecting enough simulation data (actions, observations, etc.), actor workers batch the data into samples and send them to [Trainer Workers](04_trainer_worker.md) via [Sample Producers](08_sample_stream.md) as training data.
6 |
7 | ## Agent (class `Agent`)
8 |
9 | An **Agent** is the minimal acting unit in a simulated environment. Each agent has its own policy, which means it communicates with the [Inference Stream](07_inference_stream.md) and [Sample Stream](08_sample_stream.md) corresponding to that policy. An agent has two major tasks:
10 | 1. Get inference results from its Inference Stream, corresponding to `Agent.get_action()`. After retrieving a one-step inference result (an action), the agent puts the result into the last observation in its memory (the request that this result answers) to form a complete [Sample Batch](../user_guide/09_basic_apis.md). It then passes the action to the Environment Target and waits for the next step result. Finally, it updates the policy state from the inference stream.
11 | 2. Process a new step result from its Environment Target. A new step result contains the new observation and the reward of the last step. Upon receiving a new observation, an agent first checks whether it should post the sample batches in its memory to the Sample Stream. The conditions for sending sample batches depend on the agent configuration (see [AgentSpec configuration](../user_guide/config_your_experiment.md#AgentSpec)).
12 |
13 | ### Initialization
14 |
15 | Refer to the initialization docstring of [AgentSpec configuration](../user_guide/config_your_experiment.md#AgentSpec) on how to config an Agent.
16 |
17 | ## Environment Target (class `_EnvTarget`)
18 |
19 | An **Environment Target** holds a single **Environment** instance and manages all Agent instances in that environment. An Environment Target exposes 5 methods for the actor worker to call:
20 | 1. `all_done()`: Check if all agents in this environment are done.
21 | 2. `unmask_all_info()`: Unmask the most recent episode info of all agents.
22 | 3. `reset()`: Reset the environment.
23 | 4. `step()`: Get actions from agents, and perform one environment step.
24 | 5. `ready_to_step()`: Check if all agents are ready to perform a step in the environment.
25 |
26 | ### Initialization
27 |
28 | The initialization of an Environment Target requires an Environment instance, the maximum number of steps that the environment is allowed to run per episode (to avoid an infinite loop in the environment), and a list of Agents in this Environment Target.
29 |
30 | ## Environment Ring (class `_EnvRing`)
31 |
32 | In a trivial implementation of actor workers, simulation and policy inference are executed synchronously: after simulating one step, the actor worker sends inference requests to its corresponding Inference Streams, waits for replies and then continues to simulate the next step. In this procedure, the resources occupied by the actor worker sit idle while waiting for replies. To utilize resources as much as possible, the actor worker can run simulation asynchronously, i.e. simulate other environments while waiting for inference replies. The **Environment Ring** is a data structure that supports asynchronous simulation.
33 |
34 | One Environment Ring contains multiple copies of identical Environment Targets (i.e. environment simulators). Each _Environment Ring_ has a pointer that points to one target. After one simulation step, the ring rotates and the pointer moves to the next target.
35 |
36 | ### Initialization
37 |
38 | Specify the list of Environment Targets in the ring to initialize the Environment Ring.
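The rotating pointer can be sketched in a few lines. `ToyEnvRing` and its `current`/`rotate` members are hypothetical names, not the `_EnvRing` interface, and plain strings stand in for Environment Targets.

```python
class ToyEnvRing:
    """Hypothetical ring of environment targets with a rotating pointer."""

    def __init__(self, targets):
        if not targets:
            raise ValueError("an environment ring needs at least one target")
        self._targets = list(targets)
        self._index = 0

    @property
    def current(self):
        return self._targets[self._index]

    def rotate(self):
        # Advance the pointer; wrap around after the last target.
        self._index = (self._index + 1) % len(self._targets)


ring = ToyEnvRing(["env0", "env1", "env2"])
visited = []
for _ in range(5):
    visited.append(ring.current)
    ring.rotate()
```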
39 |
40 | ## class `ActorWorker`
41 |
42 | This class inherits from `worker_base.Worker` ([Worker Base](01_worker_base.md)).
43 |
44 | In each `_poll()`, the actor worker executes the following procedure `_MAX_POLL_STEPS` (default = 16) times:
45 | 1. Check if all Agents in the current Environment Target are done.
46 |     1. If yes, unmask all info and reset this Environment Target. If using inline inference, load new parameters for the local policy worker.
47 |     2. If not, check if any agent in this Environment Target is waiting for an inference reply.
48 |         1. If yes, break to wait for the agents to be ready.
49 |         2. If no, perform one target step.
50 | 2. Flush inference clients every pre-determined number of steps. (See [Inference Stream](07_inference_stream.md))
51 | 3. Rotate the Environment Ring.
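The `_poll()` procedure above can be simulated with toy targets. `ToyTarget` fakes the five `_EnvTarget` methods, `flush` stands in for flushing inference clients, and `flush_every` is a hypothetical parameter for the flush interval; none of these are the rlsrl API. With a flush that is answered instantly, the loop never blocks; with a flush that never answers, it breaks as soon as it reaches a target that is still waiting.

```python
class ToyTarget:
    """Stand-in for _EnvTarget: an episode ends after `episode_len` steps."""

    def __init__(self, episode_len):
        self.episode_len = episode_len
        self.t = 0
        self.resets = 0
        self.waiting = False  # True while agents await an inference reply

    def all_done(self):
        return self.t >= self.episode_len

    def unmask_all_info(self):
        pass  # a real target would expose episode statistics here

    def reset(self):
        self.t = 0
        self.resets += 1

    def ready_to_step(self):
        return not self.waiting

    def step(self):
        self.t += 1
        self.waiting = True  # new observations were posted as inference requests


def poll(targets, index, flush, max_poll_steps=16, flush_every=1):
    """One _poll() over a ring of targets; returns (new ring index, env steps taken)."""
    env_steps = 0
    for i in range(max_poll_steps):
        target = targets[index]
        if target.all_done():
            target.unmask_all_info()
            target.reset()
        elif not target.ready_to_step():
            break  # current target is still waiting for actions
        else:
            target.step()
            env_steps += 1
        if (i + 1) % flush_every == 0:
            flush()  # send buffered inference requests
        index = (index + 1) % len(targets)  # rotate the ring
    return index, env_steps


# Usage: a flush that is answered instantly, so targets never block.
targets = [ToyTarget(episode_len=3) for _ in range(2)]


def instant_flush():
    for tgt in targets:
        tgt.waiting = False


index, env_steps = poll(targets, 0, instant_flush)
```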
52 |
53 | ### Initialization
54 |
55 | Refer to the initialization docstring of [ActorWorker configuration](../user_guide/config_your_experiment.md#ActorWorker) on how to config an _Actor Worker_.
56 |
57 | The initialization process (`_configure()`) is:
58 | 1. Make Inference Clients and Sample Producers. The actor worker keeps references to these streams.
59 | 2. Make agents as specified by AgentSpec. Each agent is matched with an Inference Client and a Sample Producer.
60 | 3. Create multiple Environment Targets with the specified Environment and the agents created in step 2.
61 | 4. Create the Environment Ring.
62 |
63 | # Related References
64 |
65 | - [System Overview](00_system_overview.md)
66 | - [Worker Base](01_worker_base.md)
67 | - [Policy Worker](03_policy_worker.md)
68 | - [Trainer Worker](04_trainer_worker.md)
69 | - [Inference Stream](07_inference_stream.md)
70 | - [Sample Stream](08_sample_stream.md)
71 |
72 | # Related File and Directory
73 | - [system/impl/actor_worker.py](../../src/rlsrl/system/impl/actor_worker.py)
74 |
75 | # What's Next
76 |
77 | - [Policy Worker](03_policy_worker.md)
78 |
--------------------------------------------------------------------------------
/docs/system_components/03_policy_worker.md:
--------------------------------------------------------------------------------
1 | # System Components: Policy Worker
2 |
3 | _"The best policy is to declare victory and leave."_
4 |
5 | A **Policy Worker** runs policy model inference to produce actions as inputs for environment simulation in [Actor Workers](02_actor_worker.md). Policy Workers receive inference requests from _Inference Servers_, complete the policy inference designated by the requests, and send the results to [Actor Workers](02_actor_worker.md) via _Inference Servers_. Policy Workers are also required to update their policy model at a configured frequency via the _Parameter Database_. (_Inference Server_ and _Parameter Database_ are introduced in [System Overview](00_system_overview.md).)
6 |
7 | ## class `PolicyWorker`
8 | ### Initialization (Configuration)
9 | Refer to the initialization docstring of [PolicyWorker configuration](../user_guide/config_your_experiment.md#PolicyWorker) on how to config a Policy Worker.
10 |
11 | There are 3 steps in configuration (in function `_configure()`):
12 | 1. Make policy model for inference.
13 | 2. Initialize Inference Server and parameter database that stores policy model.
14 | 3. Initialize and start rollout thread and respond thread.
15 |
16 | ### Threaded Implementation
17 | Policy Worker is implemented in a threaded fashion to optimize efficiency. There are three threads for a policy worker, which will be introduced in detail in the following sections.
18 |
19 | #### Main Thread
20 | The main thread is responsible for two tasks (in function `_poll()`):
21 | 1. Pulling parameters from the database or files. The main thread actively fetches policy checkpoints from the database or files at a pre-configured frequency, then puts the parameters into the parameter queue for the inference thread to update the policy.
22 | 2. Receiving and batching rollout requests. The main thread receives inference requests from the Inference Stream and batches all unprocessed inference requests into the inference queue (for the inference thread to process) until the batch reaches a pre-configured batch size. If the number of unprocessed inference requests exceeds the batch size, the main thread puts them into a request buffer and waits for the inference of the prior batch.
23 |
24 | Note that the rollout batch size is dynamically adjusted according to the throughput of inference requests. The inference queue is a queue of size 1; while the queue is full, the main thread accumulates requests. Once the inference queue is cleared, the main thread batches all the pending requests and puts them into the queue.
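The main-thread side of this dynamic batching can be sketched with a `queue.Queue(maxsize=1)`. `ToyRequestBatcher` is hypothetical; the sketch assumes the pre-configured batch size acts as a cap, with overflow kept as pending requests, and it calls `poll()` and takes batches by hand instead of running a real rollout thread.

```python
import queue


class ToyRequestBatcher:
    """Hypothetical main-thread batching for a policy worker."""

    def __init__(self, max_batch_size):
        self.max_batch_size = max_batch_size
        self.inference_queue = queue.Queue(maxsize=1)  # holds at most one batch
        self.pending = []  # requests waiting for the inference queue to free up

    def poll(self, new_requests):
        self.pending.extend(new_requests)
        # Only this thread puts into the queue, so full() -> put_nowait() cannot
        # race with another producer; the rollout thread only removes batches.
        if self.pending and not self.inference_queue.full():
            batch = self.pending[:self.max_batch_size]
            self.pending = self.pending[self.max_batch_size:]
            self.inference_queue.put_nowait(batch)


batcher = ToyRequestBatcher(max_batch_size=4)
batcher.poll(["r1", "r2"])                    # queue empty: batch of 2 enqueued
batcher.poll(["r3", "r4", "r5"])              # queue full: requests accumulate
first = batcher.inference_queue.get_nowait()  # "rollout thread" takes a batch
batcher.poll(["r6", "r7"])                    # 5 pending, capped at 4
second = batcher.inference_queue.get_nowait()
```

While rollout is slow, requests pile up in `pending`, so the next batch is larger; this is how the batch size adapts to throughput in the sketch.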
25 |
26 | The batch count of a policy worker is the number of batches that the worker has processed, and the sample count is the number of inference requests the worker has processed.
27 |
28 | #### Rollout Thread
29 | The rollout thread updates the policy model from the parameter queue and runs `policy.rollout` on the request batch from the inference queue at every step. After rollout, it puts the results into the respond queue.
30 |
31 | #### Respond Thread
32 | The respond thread sends the inference results (from the respond queue) to the Inference Stream.
33 |
34 |
35 | # Related References
36 |
37 | - [System Overview](00_system_overview.md)
38 | - [Actor Worker](02_actor_worker.md)
39 |
40 | # Related Files and Directories
41 |
42 | - [system/impl/policy_worker.py](../../src/rlsrl/system/impl/policy_worker.py)
43 |
44 | # What's Next
45 |
46 | - [Trainer Worker](04_trainer_worker.md)
47 |
--------------------------------------------------------------------------------
/docs/system_components/04_trainer_worker.md:
--------------------------------------------------------------------------------
1 | # System Components: Trainer Worker
2 |
3 | _"Tell me and I forget, teach me and I may remember, involve me and I learn."_
4 |
5 | **Trainer Workers** consume training samples and train policy models. Trainer Workers receive training samples from [Actor Workers](02_actor_worker.md) via Sample Consumers, and update the policy models stored in the Parameter Database after completing a training step. Although this is a local version of the system, multi-GPU training is supported: multiple trainer workers synchronize their gradients through
6 | the [PyTorch DistributedDataParallel framework](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
7 |
8 |
9 | ## class `TrainerWorker`
10 |
11 | ### Initialization(Configuration)
12 | Refer to the initialization docstring of [TrainerWorker configuration](../user_guide/config_your_experiment.md#TrainerWorker) on how to config a Trainer Worker.
13 |
14 | In `_configure()`, steps of initialization are as follows:
15 | 1. Initialize buffer, sample consumer and parameter database.
16 | 2. Reveal DDP identity (register as a DDP peer in name resolve)
17 | 3. Set up DDP: Retrieve DDP peer information from name resolve and initialize DDP connection.
18 | 4. Initialize the policy and algorithm. Try loading a checkpoint for the specified policy. If there is no checkpoint, initialize the policy model and push the first checkpoint to the parameter database.
19 | 5. Initialize GPU thread.
20 |
21 | ### Threaded Implementation
22 | To use the GPU efficiently, there are two threads in the Trainer Worker: the main thread and the GPU thread. The GPU thread runs training tasks. It starts when `start()` is called and stops when `exit()` is called. The GPU thread is implemented as a separate class, which is described in the class `GPUThread` section.
23 |
24 | #### Main Thread
25 | The main thread receives sample batches from the sample streams and controls checkpoint pushing to the parameter database.
26 |
27 | The frequency of parameter checking and pushing is controlled by the amount of data consumed from the sample streams in one `_poll()`. In the default setting, the main thread consumes at most 1024 sample batches from the sample stream in one `_poll()`.
28 |
29 | Tagged (permanent) versions of policy checkpoints are stored in the parameter database at a pre-configured frequency. Other versions are stored as backups, but are cleared regularly.
30 |
31 | ## class `GPUThread`
32 |
33 | Class `GPUThread` is a threaded implementation for training with GPU.
34 |
35 | ### Initialization
36 |
37 | The arguments required for initializing `GPUThread` include:
38 | 1. `buffer`: Buffer instance for storage;
39 | 2. `trainer`: Trainer instance for algorithm;
40 | 3. `is_master`: Boolean value indicating whether this Trainer Worker is DDP master;
41 | 4. `log_frequency_seconds`: Time limit for logging frequency control;
42 | 5. `log_frequency_steps`: Step limit for logging frequency control;
43 | 6. `push_frequency_seconds`: Time limit for checkpoint pushing frequency control;
44 | 7. `push_frequency_steps`: Step limit for checkpoint pushing frequency control;
45 | 8. `dist_kwargs`: DDP related arguments.
46 |
47 | See [Basic Utils](09_basic_utils.md) for more details about frequency control.
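The frequency-control arguments pair a time limit with a step limit. The sketch below assumes that a check fires when *either* limit is reached and then resets both counters; this is an assumption for illustration, so consult [Basic Utils](09_basic_utils.md) for the actual semantics. `ToyFrequencyControl` and its injectable `clock` (used here to make the example deterministic) are hypothetical.

```python
import time


class ToyFrequencyControl:
    """Hypothetical frequency control: fires every N steps or every S seconds."""

    def __init__(self, frequency_seconds=None, frequency_steps=None, clock=time.monotonic):
        self.frequency_seconds = frequency_seconds
        self.frequency_steps = frequency_steps
        self._clock = clock
        self._last_time = clock()
        self._steps = 0

    def check(self, steps=1):
        self._steps += steps
        now = self._clock()
        by_steps = self.frequency_steps is not None and self._steps >= self.frequency_steps
        by_time = (self.frequency_seconds is not None
                   and now - self._last_time >= self.frequency_seconds)
        if by_steps or by_time:
            # Reset both counters whenever the condition fires.
            self._steps = 0
            self._last_time = now
            return True
        return False


# Usage with a fake clock: fire every 3 steps or every 10 seconds.
fake_now = [0.0]
fc = ToyFrequencyControl(frequency_seconds=10, frequency_steps=3, clock=lambda: fake_now[0])
results = [fc.check() for _ in range(7)]
fake_now[0] = 11.0
time_triggered = fc.check()
```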
48 |
49 | ### Main Loop
50 |
51 | The GPU thread runs in an infinite loop until it is interrupted. Each iteration of the loop:
52 | 1. Checks whether the buffer is empty. If it is, sleeps for 5 ms to avoid locking the buffer (TODO: implement a thread-safe buffer); otherwise, gets a replay entry.
53 | 2. Runs one training step on the replay entry.
54 | 3. Checks the logging condition and puts the log into a queue.
55 | 4. Checks the checkpoint-pushing condition and puts the checkpoint into a queue.
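The loop above can be sketched in plain Python. This is a simplified model, not the code in `trainer_worker.py`: `GPUThreadSketch`, `train_step`, `log_every_steps` and `push_every_steps` are hypothetical names, a `queue.Queue` stands in for the buffer, and the logging and pushing conditions are reduced to step counts (the real thread also takes time limits, see `FrequencyControl`).

```python
import queue
import threading
import time


class GPUThreadSketch:
    """Toy model of the GPU thread main loop (all names are hypothetical)."""

    def __init__(self, buffer, train_step, log_every_steps=2, push_every_steps=4):
        self.buffer = buffer                  # needs empty() and get()
        self.train_step = train_step          # callable: replay entry -> stats dict
        self.log_every_steps = log_every_steps
        self.push_every_steps = push_every_steps
        self.log_queue = queue.Queue()        # drained by the main thread for logging
        self.push_queue = queue.Queue()       # drained by the main thread for pushing
        self.steps = 0
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self):
        self._thread.start()

    def stop(self):
        self._stop_event.set()
        self._thread.join()

    def _run(self):
        while not self._stop_event.is_set():
            # 1. If the buffer is empty, back off for 5 ms instead of spinning.
            if self.buffer.empty():
                time.sleep(0.005)
                continue
            entry = self.buffer.get()
            # 2. Run one training step on the replay entry.
            stats = self.train_step(entry)
            self.steps += 1
            # 3. Logging condition (step-based only in this sketch).
            if self.steps % self.log_every_steps == 0:
                self.log_queue.put(stats)
            # 4. Push condition; the step count stands in for a real checkpoint.
            if self.steps % self.push_every_steps == 0:
                self.push_queue.put(self.steps)


buf = queue.Queue()
for i in range(8):
    buf.put(i)
gpu = GPUThreadSketch(buf, train_step=lambda entry: {"loss": float(entry)})
gpu.start()
deadline = time.time() + 5.0
while gpu.steps < 8 and time.time() < deadline:
    time.sleep(0.01)
gpu.stop()
print(gpu.steps, gpu.log_queue.qsize(), gpu.push_queue.qsize())
```

Keeping logging and checkpoint handling behind queues lets the training thread stay busy while the main thread does the slower I/O.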
56 |
57 | ### Logging
58 |
59 | When `TrainerWorker._stats()` is called and the trainer is the DDP master, `GPUThread.stats()` is called to return all stats stored in the logging queue.
60 |
61 | ### Terminating
62 |
63 | When the Trainer Worker is exiting, it tells `GPUThread` when to stop running. DDP peers should stop training within a few steps of each other; otherwise, peers that are still training wait for peers that have already exited. The Trainer Worker therefore calculates when the GPU thread can be stopped safely and tells it to enter an interrupt loop.
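The documentation does not spell out how the stop point is calculated. One plausible way to picture it, shown here only as a hypothetical sketch, is that all peers agree on a shared stop step slightly ahead of the furthest peer; `agree_stop_step`, `should_stop` and `margin` are illustrative names, not part of the system.

```python
from typing import Sequence


def agree_stop_step(peer_steps: Sequence[int], margin: int = 2) -> int:
    """Pick one step count at which every DDP peer stops (hypothetical helper).

    Using the furthest peer's step plus a small margin means no peer is asked
    to stop at a step it has already passed.
    """
    if not peer_steps:
        raise ValueError("need the current step of at least one peer")
    return max(peer_steps) + margin


def should_stop(current_step: int, stop_step: int) -> bool:
    # Each peer checks this after every training step.
    return current_step >= stop_step


stop_at = agree_stop_step([120, 121, 119])
print(stop_at)  # 123
```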
64 |
65 | # Related References
66 |
67 | - [System Overview](00_system_overview.md)
68 | - [Actor Worker](02_actor_worker.md)
69 | - [Basic Utils](09_basic_utils.md)
70 |
71 | # Related Files and Directories
72 |
73 | - [system/impl/trainer_worker.py](../../src/rlsrl/system/impl/trainer_worker.py)
74 |
75 | # What's Next
76 |
77 | - [Buffer Worker](05_buffer_worker.md)
--------------------------------------------------------------------------------
/docs/system_components/05_buffer_worker.md:
--------------------------------------------------------------------------------
1 | # System Components: Buffer Worker
2 |
3 | Not updated yet.
4 |
5 | # What's Next
6 | - [EvalManager](06_eval_manager.md)
--------------------------------------------------------------------------------
/docs/system_components/06_eval_manager.md:
--------------------------------------------------------------------------------
1 | # System Components: Evaluation Manager
2 |
3 | _"Without proper self-evaluation, failure is inevitable."_
4 |
5 | ## Initialization (Configuration)
6 | Refer to the initialization docstring of [EvalManager configuration](../user_guide/01_experiment_config.md#EvalManager) to learn how to configure an Evaluation Manager.
7 |
8 | ## Worker Logic
9 |
10 | Evaluation manager accepts samples of two different kinds:
11 |
12 | 1. Samples of the same policy version as the current `eval_tag`.
13 | 2. Samples of tagged policy versions.
14 |
15 | NOTE: the eval manager discards samples whose policy version is not unique, or that match neither of the above.
16 |
17 | On receiving 1, the sample is considered as an evaluation result.
18 | With the specified frequency, data will be logged to W&B.
19 |
20 | On receiving 2, eval_manager will extract the `episode_info` from the last step and update the metadata on that version
21 | accordingly.
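The dispatch rules above can be sketched as a single function. This is an illustrative model, not the code in `eval_manager.py`: `EvalManagerSketch`, `eval_tag_version`, `tagged_versions` and the returned labels are hypothetical, and checking case 1 before case 2 is an assumption.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence, Set


@dataclass
class EvalManagerSketch:
    eval_tag_version: int                  # version currently pointed to by `eval_tag`
    tagged_versions: Set[int]              # versions that carry a permanent tag
    eval_results: List[Any] = field(default_factory=list)
    version_metadata: Dict[int, List[Any]] = field(default_factory=dict)

    def on_sample(self, policy_versions: Sequence[int], last_episode_info: Any) -> str:
        versions = set(policy_versions)
        if len(versions) != 1:
            return "discarded"             # the sample mixes several policy versions
        (version,) = versions
        if version == self.eval_tag_version:
            self.eval_results.append(last_episode_info)   # case 1: evaluation result
            return "evaluation"
        if version in self.tagged_versions:
            # Case 2: update the metadata of that tagged version.
            self.version_metadata.setdefault(version, []).append(last_episode_info)
            return "metadata"
        return "discarded"                 # matches neither case


manager = EvalManagerSketch(eval_tag_version=5, tagged_versions={2, 3})
print(manager.on_sample([5, 5, 5], {"episode_return": 1.0}))  # evaluation
```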
22 |
23 | # Related Files and Directory
24 | - [system/impl/eval_manager.py](../../src/rlsrl/system/impl/eval_manager.py)
25 |
26 | # What's next
27 | - [Inference Stream](07_inference_stream.md)
--------------------------------------------------------------------------------
/docs/system_components/07_inference_stream.md:
--------------------------------------------------------------------------------
1 | # System Components: Inference Stream
2 |
3 | **Inference Stream** defines the data flow between policy workers and actor workers.
4 |
5 | In our design, actor workers are in charge of executing `env.step()` (typically simulation), while
6 | policy workers run `policy.rollout_step()` (typically neural network inference). The inference
7 | stream is the abstraction of the data flow between them: the actor workers send environment
8 | observations as requests, and the policy workers return actions as responses, both plus other
9 | additional information.
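The request/response pattern described above can be modelled in a single process with queues. This is a toy, not the real `InferenceClient`/`InferenceServer` API (see `system/api/inference_stream.py` for that): `ToyInferenceStream`, `post_request`, `get_action` and `serve_once` are hypothetical names.

```python
import queue
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class InferenceRequest:
    client_id: int
    obs: Any


@dataclass
class InferenceReply:
    client_id: int
    action: Any


class ToyInferenceStream:
    """In-process stand-in for an inference stream (hypothetical API)."""

    def __init__(self):
        self._requests = queue.Queue()   # shared by all actor-side clients
        self._replies = {}               # client_id -> queue.Queue of replies

    def register_client(self, client_id: int) -> None:
        self._replies[client_id] = queue.Queue()

    # Actor-worker side: send an observation, later read back the action.
    def post_request(self, client_id: int, obs: Any) -> None:
        self._requests.put(InferenceRequest(client_id, obs))

    def get_action(self, client_id: int, timeout: float = 1.0) -> Any:
        return self._replies[client_id].get(timeout=timeout).action

    # Policy-worker side: answer one pending request.
    def serve_once(self, policy_fn: Callable[[Any], Any], timeout: float = 1.0) -> None:
        req = self._requests.get(timeout=timeout)
        self._replies[req.client_id].put(InferenceReply(req.client_id, policy_fn(req.obs)))


stream = ToyInferenceStream()
stream.register_client(0)
stream.post_request(0, [0.2, 0.9, 0.1])             # actor sends an observation
stream.serve_once(lambda obs: obs.index(max(obs)))  # policy replies with the argmax action
print(stream.get_action(0))  # 1
```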
10 |
11 | ## class `InferenceClient`
12 |
13 | Interface used by the actor workers to obtain actions given current observation.
14 |
15 | See [system/api/inference_stream.py](../../src/rlsrl/system/api/inference_stream.py) for detailed API description.
16 |
17 | ## class `InferenceServer`
18 |
19 | Interface used by the policy workers to serve inference requests.
20 |
21 | See [system/api/inference_stream.py](../../src/rlsrl/system/api/inference_stream.py) for detailed API description.
22 |
23 | ## Implementations
24 |
25 | ### Shared Memory Local Inference Stream (class `PinnedSharedMemoryInferenceClient`, `PinnedSharedMemoryInferenceServer`)
26 |
27 | Inference Stream implementation with pinned python shared memory.
28 |
29 | ### IP Remote Inference Stream (class `IpInferenceClient`, `IpInferenceServer`)
30 |
31 | Inference Stream implementation with sockets.
32 |
33 | ### Name Resolving Inference Stream (class `NameResolvingInferenceClient`, `NameResolvingInferenceServer`)
34 |
35 | Inference Stream implementation with a name resolving service to match inference clients and servers.
36 |
37 | ### Inline Inference Stream (class `InlineInferenceClient`)
38 |
39 | Inline Inference Stream is a special type of Inference Stream. It is used to do inference on CPU devices.
40 |
41 | GPU inference is usually faster than CPU inference, but not always more efficient. In some cases, for example when GPU resources are not available or when transmitting data between processes is costly (due to bandwidth or efficiency limits), CPU inference is the better choice.
42 |
43 | This is where the Inline Inference Stream comes into play. Conceptually, it is an Inference Stream that connects actor workers to "policy workers" that run inference on CPU devices. In the implementation, however, inference is done when the `flush()` method of `InlineInferenceClient` is called.
44 |
45 | The implementation for inference is similar to [Policy Worker](03_policy_worker.md). The implementation for data stream is similar to normal (local) inference stream.
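The "inference happens in `flush()`" idea can be illustrated with a small client that queues requests and runs a batched CPU policy only when flushed. This is a hypothetical sketch, not `InlineInferenceClient` itself: the class name and the `post_request`, `is_ready` and `consume_result` methods are assumptions, and the policy is a plain NumPy function.

```python
import numpy as np


class InlineInferenceClientSketch:
    """Buffers requests and runs a batched CPU policy only in flush() (hypothetical)."""

    def __init__(self, policy_fn):
        self.policy_fn = policy_fn      # batched policy: (N, obs_dim) array -> N actions
        self._pending = []              # (request_id, observation) pairs not yet inferred
        self._results = {}              # request_id -> action
        self._next_id = 0

    def post_request(self, obs):
        request_id = self._next_id
        self._next_id += 1
        self._pending.append((request_id, np.asarray(obs, dtype=np.float32)))
        return request_id

    def is_ready(self, request_id):
        return request_id in self._results

    def flush(self):
        # Inference runs here, inside the calling (actor) process.
        if not self._pending:
            return
        ids, observations = zip(*self._pending)
        actions = self.policy_fn(np.stack(observations))
        for request_id, action in zip(ids, actions):
            self._results[request_id] = action
        self._pending.clear()

    def consume_result(self, request_id):
        return self._results.pop(request_id)


client = InlineInferenceClientSketch(policy_fn=lambda batch: batch.argmax(axis=1))
a = client.post_request([0.1, 0.9])
b = client.post_request([0.7, 0.3])
print(client.is_ready(a))   # False: nothing is computed before flush()
client.flush()
print(client.consume_result(a), client.consume_result(b))  # 1 0
```

Batching all pending observations into one `policy_fn` call is what makes CPU inference in the actor process reasonably cheap.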
46 |
47 |
48 | # Related References
49 | - [System Overview](00_system_overview.md)
50 | - [Actor Worker](02_actor_worker.md)
51 | - [Policy Worker](03_policy_worker.md)
52 |
53 | # Related Files and Directories
54 | - [system/api/inference_stream.py](../../src/rlsrl/system/api/inference_stream.py)
55 | - [system/impl/local_inference.py](../../src/rlsrl/system/impl/local_inference.py)
56 | - [system/impl/inline_inference.py](../../src/rlsrl/system/impl/inline_inference.py)
57 | - [system/impl/remote_inference.py](../../src/rlsrl/system/impl/remote_inference.py)
58 |
59 | # What's next
60 |
61 | - [Sample Stream](08_sample_stream.md)
62 |
--------------------------------------------------------------------------------
/docs/system_components/08_sample_stream.md:
--------------------------------------------------------------------------------
1 | # System Components: Sample Stream
2 |
3 | **Sample Stream** defines the data flow between the actor workers and the trainers. It is a simple producer-consumer model.
4 |
5 | Note that in our design, actor workers see all the data and post trajectory samples to the trainers, instead of letting the policy workers do so.
6 |
7 | ## class `SampleProducer`
8 |
9 | Interface used by the actor workers to post samples to the trainers.
10 |
11 | See [system/api/sample_stream.py](../../src/rlsrl/system/api/sample_stream.py) for detailed API description.
12 |
13 | ## class `SampleConsumer`
14 |
15 | Interface used by the trainers to acquire samples.
16 |
17 | See [system/api/sample_stream.py](../../src/rlsrl/system/api/sample_stream.py) for detailed API description.
18 |
19 | ## class `ZippedSampleProducer(SampleProducer)`
20 |
21 | Sometimes a copy of each training sample needs to be sent to multiple consumers. `ZippedSampleProducer` wraps a set of sample producers: when `ZippedSampleProducer.post(sample)` is called, `sample` is posted by every wrapped producer.
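The fan-out behavior is a simple composite pattern. The sketch below is illustrative only: `ListProducer` is a hypothetical stand-in for a real `SampleProducer`, and only the `post()` method mentioned above is modelled.

```python
from typing import Any, Iterable, List


class ListProducer:
    """Hypothetical stand-in for a SampleProducer that records posted samples."""

    def __init__(self):
        self.posted: List[Any] = []

    def post(self, sample: Any) -> None:
        self.posted.append(sample)


class ZippedProducerSketch:
    """Posts every sample through each wrapped producer (sketch of the idea)."""

    def __init__(self, producers: Iterable[Any]):
        self.producers = list(producers)

    def post(self, sample: Any) -> None:
        for producer in self.producers:
            producer.post(sample)


to_trainer, to_eval = ListProducer(), ListProducer()
zipped = ZippedProducerSketch([to_trainer, to_eval])
zipped.post({"obs": [0.0, 1.0]})
print(to_trainer.posted == to_eval.posted)  # True
```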
22 |
23 | ## Implementations
24 |
25 | ### Shared Memory Local Sample Stream (class `SharedMemorySampleProducer`, `SharedMemorySampleConsumer`)
26 |
27 | A sample stream implementation with python shared memory.
28 |
29 | ### IP Remote Sample Stream (class `IpSampleProducer`, `IpSampleConsumer`)
30 |
31 | A sample stream implementation with sockets.
32 |
33 | ### Name Resolving Sample Stream (class `NameResolvingSampleProducer`, `NameResolvingSampleConsumer`)
34 |
35 | Sample Stream implementation with a name resolving service to match producers and consumers.
36 |
37 | ### Null Sample Producer (class `NullSampleProducer`)
38 |
39 | A dummy sample producer that discards all samples.
40 |
41 |
42 | # Related References
43 | - [System Overview](00_system_overview.md)
44 | - [Actor Worker](02_actor_worker.md)
45 | - [Trainer Worker](04_trainer_worker.md)
46 |
47 | # Related Files and Directories
48 | - [system/api/sample_stream.py](../../src/rlsrl/system/api/sample_stream.py)
49 | - [system/impl/local_sample.py](../../src/rlsrl/system/impl/local_sample.py)
50 | - [system/impl/remote_sample.py](../../src/rlsrl/system/impl/remote_sample.py)
51 |
52 | # What's next
53 |
54 | - [Basic Utils](09_basic_utils.md)
--------------------------------------------------------------------------------
/docs/system_components/09_basic_utils.md:
--------------------------------------------------------------------------------
1 | # System Components: Basic Utils
2 |
3 | In this section, basic utilities including parameter database in the system will be introduced in detail.
4 |
5 | ## Parameter Server (class `ParameterDBClient`)
6 |
7 | `ParameterDBClient` provides communication between a user and the parameter database (aka parameter server). For details about the naming rules of parameters (policy models) and the handles in the parameter database client, see [system/api/parameter_db.py](../../src/rlsrl/system/api/parameter_db.py).
8 |
9 | ### Implementation: Pytorch Filesystem Parameter Database (class `PytorchFilesystemParameterDB(ParameterDBClient)`)
10 |
11 | An implementation of the parameter DB that stores `pytorch` models in the filesystem. All files are stored in `"$HOME/marl_checkpoints"` by default. To change the directory, change the `ROOT` attribute of this class.
12 |
13 | ## Buffer (class `Buffer`)
14 |
15 | A buffer that stores training samples, used in the [Trainer Worker](04_trainer_worker.md). It has 3 simple methods:
16 |
17 | - `put(x)`: put `x` into buffer storage.
18 | - `get()`: get next element.
19 | - `empty()`: check if the buffer is empty.
20 |
21 | Related file [base/buffer.py](../../src/rlsrl/base/buffer.py)
22 |
23 | ### Implementation
24 |
25 | In all of the following implementations, buffers store samples in units of batches, each containing `batch_size` samples. One batch of samples is stored in a `ReplayEntry` data structure, which has the fields `reuses_left`, `receive_time`, `sample` and `reuses`.
26 |
27 | #### Simple Queue Buffer (class `SimpleQueueBuffer(Buffer)`)
28 |
29 | A simple buffer that is implemented with a python `queue.SimpleQueue()`, following FIFO pattern.
30 |
31 | #### Simple Replay Buffer (class `SimpleReplayBuffer(Buffer)`)
32 |
33 | A buffer that allows each sample batch to be retrieved a pre-configured number of times. Calling `get()` samples a batch uniformly at random. When a sample batch reaches its maximum number of reuses, it is discarded.
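A minimal sketch of this replay behavior, assuming a `ReplayEntry` with the four fields listed above and a `max_reuses` constructor argument (a hypothetical name); it is not the code in `base/buffer.py`.

```python
import random
import time
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ReplayEntry:
    reuses_left: int
    receive_time: float
    sample: Any
    reuses: int = 0


class SimpleReplayBufferSketch:
    """Uniform replay where each batch can be returned `max_reuses` times (sketch)."""

    def __init__(self, max_reuses: int, seed: Optional[int] = None):
        self.max_reuses = max_reuses
        self._entries: List[ReplayEntry] = []
        self._rng = random.Random(seed)

    def put(self, sample_batch: Any) -> None:
        self._entries.append(ReplayEntry(reuses_left=self.max_reuses,
                                         receive_time=time.time(),
                                         sample=sample_batch))

    def empty(self) -> bool:
        return not self._entries

    def get(self) -> ReplayEntry:
        # Callers should check empty() first; randrange(0) raises ValueError.
        idx = self._rng.randrange(len(self._entries))   # uniform choice
        entry = self._entries[idx]
        entry.reuses_left -= 1
        entry.reuses += 1
        if entry.reuses_left <= 0:
            self._entries.pop(idx)                      # reached its maximum reuses
        return entry


buf = SimpleReplayBufferSketch(max_reuses=3, seed=0)
buf.put("batch_a")
buf.put("batch_b")
seen = []
while not buf.empty():
    seen.append(buf.get().sample)
print(sorted(seen))
```

Every batch is returned exactly `max_reuses` times before it disappears, in a random order.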
34 |
35 | #### Priority Buffer (class `PriorityBuffer(Buffer)`)
36 |
37 | A replay buffer whose `get()` returns the batch with the largest `reuses_left`.
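One way to serve "largest `reuses_left` first" is a heap keyed on the negated count. This is a hypothetical sketch, not `PriorityBuffer` itself: decrementing and re-inserting a batch after each `get()`, and breaking ties by insertion order, are assumptions.

```python
import heapq
import itertools
from typing import Any, List, Tuple


class PriorityBufferSketch:
    """Returns the batch with the most reuses left (sketch using heapq)."""

    def __init__(self, max_reuses: int):
        self.max_reuses = max_reuses
        self._heap: List[Tuple[int, int, Any]] = []
        self._counter = itertools.count()   # tie-breaker: earlier insertions first

    def put(self, sample_batch: Any) -> None:
        # heapq is a min-heap, so store -reuses_left to pop the largest first.
        heapq.heappush(self._heap, (-self.max_reuses, next(self._counter), sample_batch))

    def empty(self) -> bool:
        return not self._heap

    def get(self) -> Tuple[Any, int]:
        neg_left, order, sample = heapq.heappop(self._heap)
        reuses_left = -neg_left - 1
        if reuses_left > 0:
            heapq.heappush(self._heap, (-reuses_left, order, sample))
        return sample, reuses_left


buf = PriorityBufferSketch(max_reuses=2)
buf.put("a")
buf.put("b")
order = []
while not buf.empty():
    order.append(buf.get())
print(order)  # [('a', 1), ('b', 1), ('a', 0), ('b', 0)]
```

The unique counter in each heap tuple also keeps Python from ever comparing two sample batches directly.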
38 |
39 | ## GPU utils
40 |
41 | GPU utils, related file: [base/gpu_utils.py](../../src/rlsrl/base/gpu_utils.py).
42 |
43 | ## Named Array (class `NamedArray`)
44 |
45 | **Named Array** is a data structure that is used everywhere in the system. It is modified from the `namedarraytuple` class in the rlpyt repo, see https://github.com/astooke/rlpyt/blob/master/rlpyt/utils/collections.py#L16.
46 |
47 | NamedArray supports dict-like unpacking and string indexing, and exposes integer slicing reads and writes applied to all contained objects, which must share indexing (`__getitem__`) behavior (e.g. numpy arrays or torch tensors).
48 |
49 | Note that namedarray supports nested structure, i.e., the elements of a NamedArray could also be NamedArray.
50 |
51 | See the related file [base/namedarray.py](../../src/rlsrl/base/namedarray.py) for implementation details, and read [User Guide: NamedArray](../user_guide/05_named_array.md) for a usage guide and examples.
52 |
53 | ## Names
54 |
55 | Methods to get names to store in name resolving, related file: [base/names.py](../../src/rlsrl/base/names.py).
56 |
57 | ## Name Record Repository (class `NameRecordRepository`)
58 |
59 | The name record repository, also referred to as **name resolve**, implements a simple name resolving service that can be considered a global key-value dict. See the related file [base/name_resolve.py](../../src/rlsrl/base/name_resolve.py) for detailed info about APIs.
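To make the "global key-value dict" view concrete, here is a process-local toy. It is not the `NameRecordRepository` API: the class and the `add`, `get`, `get_subtree` and `delete` methods are hypothetical, and the slash-separated key layout is invented for illustration (the real naming scheme lives in `base/names.py`).

```python
class InMemoryNameRecordRepository:
    """Process-local dict standing in for a name resolving service (hypothetical API)."""

    def __init__(self):
        self._records = {}

    def add(self, name, value, replace=False):
        if name in self._records and not replace:
            raise KeyError(f"name already registered: {name}")
        self._records[name] = value

    def get(self, name):
        return self._records[name]

    def get_subtree(self, prefix):
        # All values whose key starts with `prefix`, e.g. every server of one stream.
        return [v for k, v in sorted(self._records.items()) if k.startswith(prefix)]

    def delete(self, name):
        del self._records[name]


repo = InMemoryNameRecordRepository()
repo.add("my_exp/trial_0/inference_stream/server_0", "10.0.0.1:7000")
repo.add("my_exp/trial_0/inference_stream/server_1", "10.0.0.2:7000")
print(repo.get_subtree("my_exp/trial_0/inference_stream/"))
```

A real service has to be reachable by every worker process, which a plain dict is not; the toy only shows the lookup pattern used to match clients with servers.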
60 |
61 | ## Network
62 |
63 | Utility functions about networking, related file [base/network.py](../../src/rlsrl/base/network.py)
64 |
65 | ## Numpy Utils
66 |
67 | Utility functions about numpy, related file [base/numpy_utils.py](../../src/rlsrl/base/numpy_utils.py)
68 |
69 | ## Time Utils (class `FrequencyControl`)
70 |
71 | Frequency Control is a utility that controls the execution of code by a time and/or step frequency. Workers use it when they need to limit how often some operation runs. See file [base/timeutil.py](../../src/rlsrl/base/timeutil.py) for detailed usage.
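A minimal sketch of time-or-step frequency control, assuming the check fires when either limit is reached and that both counters reset afterwards. `FrequencyControlSketch`, `frequency_seconds`, `frequency_steps` and the injectable `clock` are hypothetical names chosen to mirror the `*_frequency_seconds` / `*_frequency_steps` arguments of `GPUThread`; see `base/timeutil.py` for the real class.

```python
import time


class FrequencyControlSketch:
    """Fires when either the time interval or the step count is reached (sketch)."""

    def __init__(self, frequency_seconds=None, frequency_steps=None, clock=time.monotonic):
        self.frequency_seconds = frequency_seconds
        self.frequency_steps = frequency_steps
        self._clock = clock
        self._last_time = clock()
        self._steps = 0

    def check(self, steps=1):
        self._steps += steps
        now = self._clock()
        time_ok = (self.frequency_seconds is not None
                   and now - self._last_time >= self.frequency_seconds)
        step_ok = (self.frequency_steps is not None
                   and self._steps >= self.frequency_steps)
        if time_ok or step_ok:
            self._last_time = now   # reset both counters after firing
            self._steps = 0
            return True
        return False


fake_now = [0.0]
fc = FrequencyControlSketch(frequency_seconds=10, frequency_steps=3,
                            clock=lambda: fake_now[0])
print([fc.check() for _ in range(3)])  # [False, False, True]: step limit reached
```

Injecting the clock keeps the sketch testable; a worker would simply use the default `time.monotonic`.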
72 |
73 | ## User Utils
74 |
75 | Utility functions about OS and users, related file [base/user.py](../../src/rlsrl/base/user.py)
76 |
77 |
78 | # Related References
79 | - [System Overview](00_system_overview.md)
80 |
81 | # Related Files and Directories
82 | - [system/api/parameter_db.py](../../src/rlsrl/system/api/parameter_db.py)
83 | - [base/](../../src/rlsrl/base/)
84 |
--------------------------------------------------------------------------------
/docs/user_guide/00_overview.md:
--------------------------------------------------------------------------------
1 | # User Guide: Overview
2 |
3 | This user guide is a detailed manual both for newcomers who wish to train an RL agent with well-implemented algorithms in supported environments, and for advanced users who aim to design and implement their own algorithms and even system components. It documents the APIs exposed by the system and gives instructions for common system usage.
4 |
5 | ## Command Line Options
6 |
7 | To run an experiment in **SRL** after installation:
8 |
9 | `srl-local run -e <experiment_name> -f <trial_name>`
10 |
11 | Options include:
12 |
13 | - `-e (--experiment_name)`: Experiment name, explained in [User Guide: Config Your Experiment](01_experiment_config.md).
14 | - `-f (--trial_name)`: Trial name, explained in [User Guide: Config Your Experiment](01_experiment_config.md).
15 | - `--wandb_mode`: Wandb Logging mode, includes online, offline and disabled (default). Explained in [User Guide: Logging](02_logging.md).
16 | - `--import_files`: Extra files to import, including files that define experiments, environments, policies and algorithms.
17 |
18 | ## For Beginners
19 |
20 | If you are a first-time user, it's highly recommended to try running our system with existing configuration files to observe the logging and behaviors of system components. If you are interested, we also provide [documentation about our system components](../system_components/00_system_overview.md).
21 |
22 | After getting familiar with our system, you might want to try different parameters and training options to get better results in your experiments. Read [User Guide: Config Your Experiment](01_experiment_config.md) to learn how to run an experiment and write your own configuration files. We support logging to terminal output as well as to the visualization tool [wandb](https://wandb.ai/site). Read [User Guide: Logging](02_logging.md) to learn how to configure logging and which system and training data can be logged.
23 |
24 | Experiment Configuration examples: [legacy/experiments/](../../src/rlsrl/legacy/experiments/).
25 |
26 | ## For Advanced Users
27 |
28 | For advanced users who want to run experiments on environments that are not natively supported by our system, we provide environment APIs that wrap the [gym API](https://github.com/openai/gym#api). Read [User Guide: Environments](03_environments.md) to learn how to use our environment APIs. You might also want to re-implement novel RL algorithms in our system, or design and implement your own. To learn how to implement new policies and algorithms, please read [User Guide: Policy and algorithm development](04_policy_algorithm.md).
29 |
30 | In our system, we use `NamedArray` as a basic data structure; it is an extension of `numpy` arrays. `NamedArray` is used almost everywhere in our system, and learning how it works will greatly help you get started with coding. Read [User Guide: NamedArray](05_named_array.md) for detailed information about `NamedArray`.
31 |
32 | Environment implementation examples: [legacy/environment/](../../src/rlsrl/legacy/environment/), Policy and Algorithm implementation examples: [legacy/algorithm/](../../src/rlsrl/legacy/algorithm/).
33 |
34 | ## For System-level Users
35 |
36 | Sometimes you might find that our system's architecture (actor, policy and trainer workers with a simple parameter server) cannot meet the needs of your new algorithm. This usually happens when your algorithm requires additional data processing or transformation, or has computation workloads other than simulation, policy inference and training. For example, MuZero (Reference: [Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model](https://www.nature.com/articles/s41586-020-03051-4)) requires reanalysis before sending samples to trainers. In such cases, you should implement your own workers and plug them into our system. We expose worker APIs and communication APIs for sample/inference streams so that you can implement any worker with any communication pattern you desire. For more detailed information and references, please read [documentation about our system components](../system_components/00_system_overview.md) thoroughly.
37 |
38 | # Related References
39 | - [User Guide: Experiment Config](01_experiment_config.md)
40 | - [User Guide: Logging](02_logging.md)
41 | - [User Guide: Environments](03_environments.md)
42 | - [User Guide: Policy and algorithm development](04_policy_algorithm.md)
43 | - [User Guide: NamedArray](05_named_array.md)
44 |
45 | # What's Next
46 | - [User Guide: Experiment Config](01_experiment_config.md)
47 |
--------------------------------------------------------------------------------
/docs/user_guide/02_logging.md:
--------------------------------------------------------------------------------
1 | # User Guide: Logging
2 |
3 | We support two ways of logging in our system: standard output and [wandb](https://wandb.ai/site). In the system, all statistics logging is done in **workers** and is implemented in the main logic of class `WorkerBase`.
4 |
5 | ## Standard Output Logging
6 |
7 | To enable standard output logging for some worker, set `log_terminal` in `WorkerInformation` spec of this worker to true.
8 |
9 | ## wandb Logging
10 |
11 | To enable wandb logging for some worker, set `log_wandb` in the `WorkerInformation` spec of this worker to true, and complete the following fields:
12 |
13 | - `wandb_entity`
14 | - `wandb_project`
15 | - `wandb_job_type`
16 | - `wandb_group`
17 | - `wandb_name`
18 |
19 | By default, these values will be setup by the system with the following defaults:
20 | - `wandb_entity = None`
21 | - `wandb_project = experiment_name` (-e)
22 | - `wandb_group = trial_name` (-f)
23 | - `wandb_job_type = worker_type` (actor/policy/trainer/eval_manager)
24 | - `wandb_name = policy_name or "unnamed" `
25 |
26 | The worker configuration is also passed as argument `config` to `wandb.init()`. Nested dataclasses are not parsed by W&B. For example, the trainer configuration in `TrainerWorker` currently cannot be used as a filter. This is a
27 | known issue and will be resolved in the future. A workaround is to add the values that you want to filter on to `wandb_name`. See below for configuration instructions.
28 |
29 | You may specify a customized configuration in your experiment configuration, for example:
30 |
31 | ```python
32 |
33 | from rlsrl.api.config import *
34 |
35 | actor_worker_count = 10
36 | actor_worker_configs = [ActorWorker(...,  # other worker configuration
37 |                                     worker_info=WorkerInformation(
38 |                                         wandb_entity="your_entity",
39 |                                         wandb_project="your_project",
40 |                                         wandb_group="your_group",
41 |                                         wandb_job_type="actor_worker",
42 |                                         wandb_name=f"my_perfect_wandb_name_actor_{worker_index}")
43 |                                     ) for worker_index in range(actor_worker_count)]
44 | ```
45 |
46 | ## Logging Arbitrary Data
47 |
48 | When running your own experiments, you may want to log data that is not logged in our original implementation. In this situation, you need to understand how logging works in workers and re-implement the logging function `_stats` in workers (reference: [System Components: Worker Base](../system_components/01_worker_base.md)).
49 |
50 | Every time a worker finishes a round of `_poll()`, it tries to call `_stats()` to retrieve any data that needs to be logged, and logs it via `logger` (terminal logging) or `self.__wandb_run.log(...)` (wandb logging) to complete a logging step. `_stats()` returns a list of `LogEntry` objects. A `LogEntry` has 2 fields: a stats dict `stats` and `step` (used for wandb logging by step). On every logging step, the log entries returned by `_stats()` are logged one by one. For a log entry `e`, if terminal logging is used, `e.stats` is printed. If wandb logging is used and `e.step >= 0`, `self.__wandb_run.log(e.stats, step=e.step)` is called; if `e.step < 0` (the default), `self.__wandb_run.log(e.stats)` is called.
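The dispatch rules above can be written out as a short function. This sketch is not imported from `rlsrl`: `LogEntry` here only mirrors the two fields described above, and `dispatch_log_entries`, `terminal_log` and `wandb_log` are hypothetical names. In a worker, `wandb_log` would be the bound `log` method of the W&B run object.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class LogEntry:
    # Mirrors the two fields described above; not the rlsrl class itself.
    stats: Dict[str, float] = field(default_factory=dict)
    step: int = -1


def dispatch_log_entries(entries: List[LogEntry],
                         terminal_log: Optional[Callable[[dict], None]] = None,
                         wandb_log: Optional[Callable[..., None]] = None) -> None:
    for e in entries:
        if terminal_log is not None:
            terminal_log(e.stats)                  # terminal logging prints the stats
        if wandb_log is not None:
            if e.step >= 0:
                wandb_log(e.stats, step=e.step)    # explicit W&B step
            else:
                wandb_log(e.stats)                 # let W&B advance its own step


printed, wandb_calls = [], []
dispatch_log_entries(
    [LogEntry({"loss": 0.5}, step=10), LogEntry({"fps": 900.0})],
    terminal_log=printed.append,
    wandb_log=lambda stats, **kw: wandb_calls.append((stats, kw)),
)
print(wandb_calls)  # [({'loss': 0.5}, {'step': 10}), ({'fps': 900.0}, {})]
```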
51 |
52 | If you wish to log your own data, you need to modify source code for worker implementations. Store data as attributes of worker class and return them as log entries in `_stats()`.
53 |
54 | # Related References
55 | - [System Components: Worker Base](../system_components/01_worker_base.md)
56 |
57 | # Related Files and Directories
58 | - [system/api/worker_base.py](../../src/rlsrl/system/api/worker_base.py)
59 |
60 | # What's next
61 | - [User Guide: Environments](03_environments.md)
62 |
--------------------------------------------------------------------------------
/docs/user_guide/03_environments.md:
--------------------------------------------------------------------------------
1 | # Environment
2 |
3 | Our **Environment** API provides an interface for users to implement any RL environment, including multi-agent environments. Our environment API has the following features:
4 |
5 | 1. Agents in an environment can have different observation shapes, action spaces, etc.;
6 | 2. Environments can be asynchronous. In some multi-agent environments, e.g. Hanabi, only some of the agents act at each step. We handle these cases by supporting `None` return values for agents; no action is generated for such agents in that environment step.
7 |
8 | and the following limitations:
9 |
10 | 1. Environments have to be homogeneous within an execution: the number of agents in an environment cannot change during an execution;
11 | 2. Observation space of each agent cannot change (There is no dynamic inference stream matching.)
12 |
13 | ## Environment Dataclasses
14 |
15 | In [api/env_utils.py](../../src/rlsrl/api/env_utils.py), we provide utility dataclasses to represent actions and action spaces as data structures in our system. We support both discrete (class `DiscreteAction` and `DiscreteActionSpace`) and continuous (class `ContinuousAction` and `ContinuousActionSpace`) actions.
16 |
17 | Moreover, environments pass the result of each reset and step for a single agent to the system as a dataclass `StepResult`, which has 4 fields:
18 |
19 | - `obs`: one step observation.
20 | - `reward`: one step reward.
21 | - `done`: whether this agent is done in this episode.
22 | - `info`: other information.
23 |
24 | ## Implementing Environments
25 |
26 | The procedure of implementing new environments is similar to writing a new config file. A new **Environment** is defined by a subclass of `Environment` from [api/environment.py](../../src/rlsrl/api/environment.py). To run experiments with a new environment, first register it with method `rlsrl.api.environment.register(name, env_class)`, then specify it in experiment config file with `Environment` spec. For example, a new environment `SomeEnvironment(Environment)` is implemented in file `some_env.py`:
27 |
28 | ```python
29 | # some_env.py
30 |
31 | from rlsrl.api.environment import *
32 |
33 | class SomeEnvironment(Environment):
34 |     def __init__(self, **kwargs):
35 |         ...
36 |
37 |     def step(self, actions):
38 |         ...
39 |         return [StepResult(obs=..., reward=..., done=..., info=...)]  # one StepResult per agent
40 |
41 |     def reset(self):
42 |         ...
43 |         return [StepResult(obs=..., reward=..., done=..., info=...)]  # one StepResult per agent
44 | ```
45 |
46 | You may implement other methods of class `Environment` as well, see [api/environment.py](../../src/rlsrl/api/environment.py) for details. At the end of `some_env.py`, the environment needs to be registered:
47 |
48 | ```python
49 | # some_env.py
50 |
51 | register("some_env", SomeEnvironment)
52 | ```
53 |
54 | After that, you can write your experiment config file, and write `ActorWorker` spec with your environment spec:
55 |
56 | ```python
57 | # some_env_config.py
58 |
59 | class SomeEnvExperiment(Experiment):
60 |     def initial_setup(self):
61 |         ...  # defines `args`, the keyword arguments for SomeEnvironment
62 |         return ExperimentConfig(
63 |             actor_workers=[
64 |                 ActorWorker(
65 |                     env=Environment(type_="some_env",
66 |                                     args=args),
67 |                     # ... other ActorWorker fields
68 |                 ),
69 |                 # ... more actor workers
70 |             ],
71 |             # ... other experiment fields
72 |         )
73 |
74 | register("some_env_expr", SomeEnvExperiment)
75 | ```
76 |
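77 | When running the experiment from the command line, both `some_env.py` and `some_env_config.py` should be included in the `--import_files` option (quote the list so that the shell does not treat `;` as a command separator):
78 |
79 | `srl-local run -e some_env_expr -f hello --import_files "some_env.py;some_env_config.py"`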
80 |
81 | Of course, you can refer to environments that have already been implemented in the system in directory [legacy/environment/](../../src/rlsrl/legacy/environment):
82 |
83 | - **[Gym Atari](../../src/rlsrl/legacy/environment/atari/atari_env.py)**
84 | - **[Google football](../../src/rlsrl/legacy/environment/google_football/gfootball_env.py)**
85 | - **[Gym MuJoCo](../../src/rlsrl/legacy/environment/gym_mujoco/gym_mujoco_env.py)**
86 | - **[Hide and Seek](../../src/rlsrl/legacy/environment/hide_and_seek/hns_env.py)**
87 | - **[SMAC](../../src/rlsrl/legacy/environment/smac/smac_env.py)**
88 |
89 |
90 | # Related References
91 | - [System Components: Actor Worker](../system_components/02_actor_worker.md)
92 |
93 | # Related Files and Directories
94 | - [api/env_utils.py](../../src/rlsrl/api/env_utils.py)
95 | - [api/environment.py](../../src/rlsrl/api/environment.py)
96 | - [legacy/environment/](../../src/rlsrl/legacy/environment/)
97 |
98 | # What's next
99 | - [User Guide: Policy and algorithm development](04_policy_algorithm.md)
100 |
--------------------------------------------------------------------------------
/docs/user_guide/05_named_array.md:
--------------------------------------------------------------------------------
1 | # User Guide: NamedArray
2 |
3 | **NamedArray** is a key data structure in the system and used almost everywhere.
4 | This intro will get you started.
5 |
6 | *We did not come up with the idea. Read the comments in code files for more details.*
7 |
8 | *Debugging named array related problems may become troublesome. Make sure you read this doc carefully.*
9 |
10 | ## Why use NamedArray
11 | Named array extends numpy array in the
12 | following ways.
13 | 1. Each NamedArray aggregates multiple numpy arrays, possibly of different shapes.
14 | 2. Each numpy array is given a name, providing a user-friendly way of indexing to the corresponding data.
15 | 3. Named arrays can be nested.
16 |
17 | All of these come in handy in reinforcement learning, where data of different shapes and nesting-relations
18 | are passed around between system components.
19 |
20 | ## Creating and Indexing a NamedArray
21 |
22 | Let's use the gym API as an example.
23 |
24 | ```python
25 | import numpy as np
26 | from rlsrl.base.namedarray import NamedArray
27 |
28 | # Suppose episode info contains `episode_length` and `episode_return`.
29 | class EpisodeInfo(NamedArray):
30 |     def __init__(
31 |             self,
32 |             episode_length: np.ndarray,
33 |             episode_return: np.ndarray,
34 |     ):
35 |         super(EpisodeInfo, self).__init__(episode_length=episode_length,
36 |                                           episode_return=episode_return)
37 |
38 |
39 | class MiniSampleBatch(NamedArray):
40 |     def __init__(
41 |             self,
42 |             obs: np.ndarray,
43 |             reward: np.ndarray,
44 |             done: np.ndarray,
45 |             info: EpisodeInfo,
46 |     ):
47 |         super(MiniSampleBatch, self).__init__(obs=obs,
48 |                                               reward=reward,
49 |                                               done=done,
50 |                                               info=info)
51 | ```
52 |
53 | As shown, the syntax resembles that of Python dataclasses. However, each field must be a `numpy.ndarray` or another
54 | NamedArray.
55 |
56 | `EpisodeInfo` has two fields: `episode_length` and `episode_return`. Let's see how you can assign and get their
57 | values.
58 |
59 | ```python
60 | ei = EpisodeInfo(episode_length=np.full(shape=(10, 1), fill_value=10),
61 | episode_return=np.ones(shape=(10, 1))
62 | )
63 |
64 | print(ei.episode_return)
65 | print(ei.episode_length.shape)
66 | ei5 = ei[:5]
67 | print(ei5)
68 | print(ei5.shape)
69 | ```
70 |
71 | Not surprisingly, `ei[:5]` returns an `EpisodeInfo` instance. The `shape` attribute of a NamedArray gives the shape
72 | of each field. Now let's see how things work for nested data.
73 |
74 | ```python
75 | msb = MiniSampleBatch(obs=np.random.random(size=(10, 3, 200, 200)),
76 | reward=np.random.random(size=(10, 1)),
77 | done=np.array([False] * 9 + [True]),
78 | info=ei)
79 | print(msb.shape)
80 | print(len(msb))
81 | print(msb["obs"])
82 | print(msb[:5].shape)
83 | ```
84 | As shown, `msb["obs"]` is equivalent to `msb.obs` and indexing will apply to sub-NamedArray.
85 |
86 |
87 | ## Aggregation and Mapping
88 |
89 | Imagine that we are running a gym environment and the result of each step is wrapped in a `MiniSampleBatch`.
90 |
91 | For demonstration purpose, let's give each field a default value.
92 |
93 | ```python
94 | from rlsrl.base.namedarray import recursive_aggregate, NamedArray
95 |
96 | class MiniSampleBatch(NamedArray):
97 |     def __init__(
98 |             self,
99 |             obs: np.ndarray = np.ones(shape=(3, 200, 200)),
100 |             reward: np.ndarray = np.zeros(shape=(1, )),
101 |             done: np.ndarray = np.zeros(shape=(1, )),
102 |             info: EpisodeInfo = None,
103 |     ):
104 |         super(MiniSampleBatch, self).__init__(obs=obs,
105 |                                               reward=reward,
106 |                                               done=done,
107 |                                               info=info)
108 |
109 | msb_list = [MiniSampleBatch() for _ in range(10)]
110 | agg_msb = recursive_aggregate(xs=msb_list,
111 |                               aggregate_fn=np.stack)
112 | print(agg_msb.__class__)
113 | print(agg_msb.shape)
114 | ```
115 |
116 | By using `recursive_aggregate`, np.stack is applied to each field, except for those with value None. The aggregation
117 | result is returned as a new instance of `MiniSampleBatch`.
118 |
119 | In other cases, we may want to apply some function to each field of a single NamedArray instance. For example,
120 | when training, all numpy arrays must be converted to pytorch Tensors, or must be moved to GPU for gradient computation.
121 |
122 | ```python
123 | import torch
124 | from rlsrl.base.namedarray import recursive_apply
125 |
126 | torch_msb = recursive_apply(x=agg_msb,
127 | fn=lambda x: torch.from_numpy(x).to("cuda:0"))
128 | ```
129 |
130 | ## FAQ
131 |
132 | 1. As an algorithm developer, where should I use NamedArray?
133 | - In your environment, `reset`/`step` return a list of `StepResult`. The observation (`obs`) and episode info (`info`) in each
134 | `StepResult` should be NamedArrays. Reward and done are numpy arrays.
135 | - In your policy's `analyze` and `rollout` methods, sample and rollout requests are passed in as NamedArrays.
136 | You will have to index through them to get your data.
137 | - In your trainer, when you implement the `step` method, the sample_batch is passed in as a NamedArray. For specific
138 | cases like data chunking, you will have to use `recursive_apply`.
139 |
140 | 2. What happens if I use `recursive_aggregate` on NamedArrays of different shapes, possibly with some Nones?
141 |    - First, the nesting structure, including the name of every field, must match; otherwise an error is raised.
142 |    - If all entries of a specific data field (a leaf, not a nested NamedArray) have the same shape apart from some `None`s,
143 |      each `None` is filled with `numpy.zeros` of that shape.
144 |    - If any two arrays in the same data field differ in shape, the aggregation raises an error.
145 |      A common cause is an environment that returns values (e.g. observations) of inconsistent shapes.
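
The rules above can be checked against a small stand-alone sketch. It is not rlsrl's implementation. `recursive_aggregate_sketch` is a hypothetical helper that works on plain nested dicts, and `np.zeros_like` stands in for the zero-filling. Keeping a field `None` when it is `None` in every element is an assumption, based on the earlier description of `recursive_aggregate`.

```python
from typing import Any, Callable, List

import numpy as np


def recursive_aggregate_sketch(xs: List[Any], aggregate_fn: Callable = np.stack) -> Any:
    """Aggregate a list of nested dicts field by field (hypothetical helper)."""
    if any(isinstance(x, dict) for x in xs):
        # Nested level: every element must be a dict with the same field names.
        if not all(isinstance(x, dict) for x in xs):
            raise ValueError("Nesting structure does not match.")
        names = set(xs[0])
        if any(set(x) != names for x in xs):
            raise ValueError("Field names do not match.")
        return {name: recursive_aggregate_sketch([x[name] for x in xs], aggregate_fn)
                for name in xs[0]}
    # Leaf level: a list of arrays, some of which may be None.
    present = [x for x in xs if x is not None]
    if not present:
        return None  # None everywhere: the field stays None.
    shapes = {x.shape for x in present}
    if len(shapes) > 1:
        raise ValueError(f"Inconsistent shapes in one field: {sorted(shapes)}")
    # Fill each None with zeros of the common shape (and dtype), then aggregate.
    filled = [np.zeros_like(present[0]) if x is None else x for x in xs]
    return aggregate_fn(filled)


steps = [
    {"obs": np.ones(3), "reward": np.array([1.0]), "info": None},
    {"obs": np.ones(3), "reward": None, "info": None},
]
agg = recursive_aggregate_sketch(steps)
print(agg["obs"].shape, agg["reward"].shape, agg["info"])  # (2, 3) (2, 1) None

try:
    recursive_aggregate_sketch([{"obs": np.ones(3)}, {"obs": np.ones(4)}])
    mismatch_error = None
except ValueError as err:
    mismatch_error = str(err)  # inconsistent observation shapes are rejected
print(mismatch_error)
```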
146 |
147 | # Related References
148 | - [System Components: Basic Utils](../system_components/09_basic_utils.md)
149 |
150 | # Related Files and Directories
151 | - [base/namedarray.py](../../src/rlsrl/base/namedarray.py)
152 |
153 | # What's Next
154 | - [System Components: Overview](../system_components/00_system_overview.md)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r") as f:
4 | long_description = f.read()
5 |
6 | setup(
7 | name="srl-local",
8 | version="0.0.1",
9 | author="openpsi-projects",
10 | author_email="openpsi.projects@gmail.com",
11 | description="open-source SRL",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | python_requires=">=3.8",
15 | packages=find_packages(where="src"),
16 | package_dir={"":"src"},
17 | install_requires=[
18 | "gym",
19 | "torch",
20 | "wandb"
21 | ],
22 | entry_points={
23 | 'console_scripts': [
24 | 'srl-local = rlsrl.apps.main:main',
25 | ]
26 | }
27 | )
--------------------------------------------------------------------------------
/src/rlsrl/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openpsi-projects/srl/97b3ae4e5fab00f6da90f81679dc035a9079eb09/src/rlsrl/api/__init__.py
--------------------------------------------------------------------------------
/src/rlsrl/api/curriculum.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Dict, Union
2 | import logging
3 |
4 | from rlsrl.base.names import curriculum_stage
5 | from rlsrl.base.conditions import make as make_condition
6 | import rlsrl.api.config as config_api
7 | import rlsrl.base.name_resolve as name_resolve
8 |
9 |
10 | class Curriculum:
11 |     """Abstract class of a curriculum. A curriculum controls the stages of training.
12 |     Typically, the evaluation manager passes evaluation results to the curriculum, which
13 |     decides whether to change the global stage. Any change is applied to the environments
14 |     when they are reset.
15 | """
16 |
17 | def submit(self, data: Dict) -> bool:
18 |         """Submit episode info to the curriculum.
19 | Args:
20 | data: episode info, typically results of a batch of evaluations.
21 | Returns:
22 |             done (bool): whether the curriculum is finished.
23 | """
24 | raise NotImplementedError()
25 |
26 | def reset(self):
27 | """Reset the Curriculum.
28 | """
29 | raise NotImplementedError()
30 |
31 | def get_stage(self) -> str:
32 |         """Get the current stage of the curriculum.
33 |         Returns:
34 |             stage_name (str): name of the current stage.
35 | """
36 | raise NotImplementedError()
37 |
38 |
39 | class LinearCurriculum(Curriculum):
40 |
41 | def __init__(self, experiment_name, trial_name, curriculum_name,
42 | stages: Union[str, List[str]],
43 | condition_cfg: List[config_api.Condition]):
44 | self.__experiment_name = experiment_name
45 | self.__trial_name = trial_name
46 | self.__curriculum_name = curriculum_name
47 | self.logger = logging.getLogger(f"Curriculum {self.__curriculum_name}")
48 | self.__conditions = [make_condition(cond) for cond in condition_cfg]
49 | if isinstance(stages, str):
50 | self.__stages = [stages]
51 | else:
52 | self.__stages = stages
53 | self.__stage_index = 0
54 |
55 | def reset(self):
56 | self.__stage_index = 0
57 | self.set_stage(self.__stages[self.__stage_index])
58 |
59 | def set_stage(self, stage):
60 | self.logger.info(f"now on stage {stage}")
61 | name_resolve.add(curriculum_stage(self.__experiment_name,
62 | self.__trial_name,
63 | self.__curriculum_name),
64 | value=stage,
65 | replace=True)
66 |
67 | def submit(self, data):
68 | for cond in self.__conditions:
69 | if not cond.is_met_with(data):
70 | self.logger.info(f"Condition {cond} is not met.")
71 | return False
72 | else:
73 | self.logger.info("All conditions met.")
74 | if self.__stage_index + 1 == len(self.__stages):
75 | self.logger.info(f"All stages cleared: {self.__stages}")
76 | return True
77 | else:
78 | self.__stage_index += 1
79 | self.set_stage(self.__stages[self.__stage_index])
80 | return False
81 |
82 | def get_stage(self) -> Optional[str]:
83 | try:
84 | return name_resolve.get(
85 | curriculum_stage(self.__experiment_name, self.__trial_name,
86 | self.__curriculum_name))
87 | except name_resolve.NameEntryNotFoundError:
88 | return None
89 |
90 |
91 | def make(cfg: config_api.Curriculum,
92 | worker_info: config_api.WorkerInformation):
93 | if cfg.type_ == config_api.Curriculum.Type.Linear:
94 | return LinearCurriculum(
95 | experiment_name=worker_info.experiment_name,
96 | trial_name=worker_info.trial_name,
97 | curriculum_name=cfg.name,
98 | stages=cfg.stages,
99 | condition_cfg=cfg.conditions,
100 | )
101 | else:
102 | raise NotImplementedError()
103 |
--------------------------------------------------------------------------------
/src/rlsrl/api/env_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import gym
3 | import numpy as np
4 |
5 | from rlsrl.base.namedarray import NamedArray
6 | import rlsrl.api.environment as environment
7 |
8 |
9 | class DiscreteAction(NamedArray, environment.Action):
10 |
11 | def __init__(self, x: np.ndarray):
12 | super(DiscreteAction, self).__init__(x=x)
13 |
14 | def __eq__(self, other):
15 | assert isinstance(other, DiscreteAction), \
16 |             "Cannot compare DiscreteAction to object of class {}".format(other.__class__.__name__)
17 | return self.key == other.key
18 |
19 | def __hash__(self):
20 | return hash(self.x.item())
21 |
22 | @property
23 | def key(self):
24 | return self.x.item()
25 |
26 |
27 | class DiscreteActionSpace(environment.ActionSpace):
28 |
29 | def __init__(self,
30 | space: Union[gym.spaces.Discrete, gym.spaces.MultiDiscrete],
31 | shared=False,
32 | n_agents=-1):
33 | """Discrete Action Space Wrapper.
34 | Args:
35 |             space: Action space of one agent. Can be gym.spaces.Discrete or gym.spaces.MultiDiscrete.
36 | shared: concatenate action space for multiple agents.
37 | n_agents: number of agents. Effective only if shared=True.
38 | """
39 | self.__shared = shared
40 | self.__n_agents = n_agents
41 | if shared and n_agents == -1:
42 | raise ValueError("n_agents must be given to a shared action space.")
43 | self.__space = space
44 | assert isinstance(space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)), type(space)
45 |
46 | @property
47 | def n(self):
48 | return self.__space.n if isinstance(self.__space, gym.spaces.Discrete) else self.__space.nvec
49 |
50 | def sample(self, available_action: np.ndarray = None) -> DiscreteAction:
51 | if available_action is None:
52 | if isinstance(self.__space, gym.spaces.Discrete):
53 | if self.__shared:
54 | x = np.array([[self.__space.sample()] for _ in range(self.__n_agents)], dtype=np.int32)
55 | else:
56 | x = np.array([self.__space.sample()], dtype=np.int32)
57 | else:
58 | if self.__shared:
59 | x = np.array([self.__space.sample() for _ in range(self.__n_agents)], dtype=np.int32)
60 | else:
61 | x = np.array(self.__space.sample(), dtype=np.int32)
62 | return DiscreteAction(x)
63 | else:
64 | if self.__shared:
65 | assert available_action.shape == (self.__n_agents, self.__space.n)
66 | x = []
67 | for agent_idx in range(self.__n_agents):
68 | a_x = self.__space.sample()
69 | while not available_action[agent_idx, a_x]:
70 | a_x = self.__space.sample()
71 | x.append([a_x])
72 | x = np.array(x, dtype=np.int32)
73 | else:
74 | assert available_action.shape == (self.__space.n,)
75 | x = self.__space.sample()
76 | while not available_action[x]:
77 | x = self.__space.sample()
78 | x = np.array([x], dtype=np.int32)
79 | return DiscreteAction(x)
80 |
81 |
82 | class ContinuousAction(NamedArray, environment.Action):
83 |
84 |     def __init__(self, x: np.ndarray):
85 | super(ContinuousAction, self).__init__(x=x)
86 |
87 | def __eq__(self, other):
88 | assert isinstance(other, ContinuousAction), \
89 |             "Cannot compare ContinuousAction to object of class {}".format(other.__class__.__name__)
90 |         return bool(np.array_equal(self.key, other.key))
91 |
92 | @property
93 | def key(self):
94 | return self.x
95 |
96 |
97 | class ContinuousActionSpace(environment.ActionSpace):
98 |
99 | def __init__(self, space: gym.spaces.Box, shared=False, n_agents=-1):
100 | """Continuous Action Space wrapper.
101 | Args:
102 | space: Action space of a single agent. Must be of type gym.spaces.Box.
103 | shared: concatenate action space for multiple agents.
104 | n_agents: number of agents. Effective only if shared=True.
105 | """
106 | self.__shared = shared
107 | self.__n_agents = n_agents
108 | if shared and n_agents == -1:
109 | raise ValueError("n_agents must be given to a shared action space.")
110 | self.__space = space
111 | assert isinstance(space, gym.spaces.Box) and len(space.shape) == 1, type(space)
112 |
113 | @property
114 | def n(self):
115 | return self.__space.shape[0]
116 |
117 | def sample(self) -> ContinuousAction:
118 | if self.__shared:
119 | x = np.stack([self.__space.sample() for _ in range(self.__n_agents)])
120 | else:
121 | x = self.__space.sample()
122 | return ContinuousAction(x)
123 |
--------------------------------------------------------------------------------
/src/rlsrl/api/environment.py:
--------------------------------------------------------------------------------
1 | """Abstraction of the RL environment and related concepts.
2 |
3 | This is basically a clone of the gym interface. The reasons for replicating it are:
4 | - Allow easy changing of APIs when necessary.
5 | - Avoid hard dependency on gym.
6 | """
7 | from typing import List, Union, Dict, Type
8 | import dataclasses
9 | import importlib
10 | import numpy as np
11 |
12 | import rlsrl.api.config as config
13 |
14 |
15 | class Action:
16 | pass
17 |
18 |
19 | class ActionSpace:
20 |
21 | def sample(self, *args, **kwargs) -> Action:
22 | raise NotImplementedError()
23 |
24 |
25 | class DataAugmenter:
26 |     """DataAugmenter pre-processes generated samples before they are sent to trainers. Defined per environment.
27 | """
28 |
29 | def process(self, sample):
30 | """Relabel sample. Operation should be in-place.
31 | Args:
32 | sample (algorithm.trainer.SampleBatch) Sample to be augmented.
33 | Return:
34 | augmented_sample (algorithm.trainer.SampleBatch).
35 | """
36 | raise NotImplementedError()
37 |
38 |
39 | class NullAugmenter(DataAugmenter):
40 |
41 | def process(self, sample):
42 | return sample
43 |
44 |
45 | @dataclasses.dataclass
46 | class StepResult:
47 |     """Step result for a single agent. In a multi-agent scenario, env.step() returns
48 | List[StepResult].
49 | """
50 | obs: Dict
51 | reward: np.ndarray
52 | done: np.ndarray
53 | info: Dict
54 |     truncated: np.ndarray = dataclasses.field(default_factory=lambda: np.zeros(shape=(1, ), dtype=np.uint8))
55 |
56 |
57 | class Environment:
58 |
59 | @property
60 | def agent_count(self) -> int:
61 | raise NotImplementedError()
62 |
63 | @property
64 | def observation_spaces(self) -> List[dict]:
65 | """Return a list of observation spaces for all agents.
66 |
67 | Each element in self.observation_spaces is a Dict, which contains
68 | shapes of observation entries specified by the key.
69 | Example:
70 | -------------------------------------------------------------
71 | self.observation_spaces = [{
72 | 'observation_self': (10, ),
73 | 'box_obs': (9, 15),
74 | }, {
75 | 'observation_self': (20, ),
76 | 'box_obs': (9, 15),
77 | }]
78 | -------------------------------------------------------------
79 | Observation spaces of different agents can be different.
80 | In this case, policies *MUST* be *DIFFERENT*
81 | among agents with different observation dimension.
82 | """
83 | raise NotImplementedError()
84 |
85 | @property
86 | def action_spaces(self) -> List[ActionSpace]:
87 | """Return a list of action spaces for all agents.
88 |
89 | Each element in self.action_spaces is an instance of
90 | env_base.ActionSpace, which is basically a wrapped Dict.
91 | The Dict contains shapes of action entries specified by the key.
92 | **We force each action entry to be either gym.spaces.Discrete
93 | or gym.spaces.Box.**
94 | Example:
95 | -------------------------------------------------------------
96 | self.action_spaces = [
97 |             SomeEnvActionSpace(dict(move_x=Discrete(10), move_y=Discrete(10), cursor=Box(2))),
98 |             SomeEnvActionSpace(dict(cursor=Box(2)))
99 | ]
100 | -------------------------------------------------------------
101 | Action spaces of different agents can be different.
102 | In this case, policies *MUST* be *DIFFERENT*
103 | among agents with different action output.
104 | """
105 | raise NotImplementedError()
106 |
107 | def reset(self) -> List[StepResult]:
108 | """Reset the environment, and returns a list of step results for all agents.
109 |
110 | Returns:
111 | List[StepResult]: StepResult with valid Observations only.
112 | """
113 | raise NotImplementedError()
114 |
115 | def step(self, actions: List[Action]) -> List[StepResult]:
116 | """ Consume actions and advance one env step.
117 |
118 | Args:
119 | actions (List[Action]): Actions of all agents.
120 |
121 | Returns:
122 |             List[StepResult]: One StepResult per agent. Besides `truncated` (zeros by default), each has:
123 |                 - obs (namedarray): Observations, available actions, masks, etc.
124 |                 - reward (numpy.ndarray): A numpy array with shape [1].
125 |                 - done (numpy.ndarray): A numpy array with shape [1],
126 |                   indicating whether an episode is done or an agent is dead.
127 |                 - info (namedarray): Customized namedarray recording required summary info.
128 | """
129 | raise NotImplementedError()
130 |
131 | def render(self) -> None:
132 | pass
133 |
134 | def seed(self, seed):
135 | """Set a random seed for the environment.
136 |
137 | Args:
138 | seed (Any): The seed to be set. It could be int,
139 | str or any other types depending on the implementation of
140 | the specific environment. Defaults to None.
141 |
142 | Returns:
143 | Any: The new seed.
144 | """
145 | raise NotImplementedError()
146 |
147 | def set_curriculum_stage(self, stage_name: str):
148 | """Set the environment to be in a certain stage.
149 | Args:
150 | stage_name: name of the stage to be set.
151 | """
152 | raise NotImplementedError()
153 |
154 |
155 | ALL_ENVIRONMENT_CLASSES = {}
156 | ALL_ENVIRONMENT_MODULES = {}
157 | ALL_AUGMENTER_CLASSES = {}
158 |
159 |
160 | def register(name, env_class: Union[Type, str], module=None):
161 |     """Register an environment. If env_class is a string, the module is registered implicitly; the corresponding
162 |     module is only imported when the environment is created.
163 | Args:
164 | name: A reference name of the environment. Use this name in your experiment configuration.
165 | env_class: Class of the environment. If passed as string, its source module is required.
166 |         module: String, the module path in which to find env_class; ignored when env_class is not a string.
167 |
168 | Raises:
169 | KeyError: if name is already registered.
170 |
171 | Examples:
172 | # codespace/implementation/this_is_my_env.py
173 | class ThisIsMyEnv(api.environment.Environment):
174 | ...
175 | register("this-is-my-env", ThisIsMyEnv)
176 | # OR
177 | register("this-is-my-env", "ThisIsMyEnv", "codespace.implementation.this_is_my_env")
178 |
179 | """
180 | if name in ALL_ENVIRONMENT_CLASSES:
181 | raise KeyError(
182 | f"Environment {name} already registered as {ALL_ENVIRONMENT_CLASSES[name]}. "
183 | f"But got another register with env_class={env_class} and module={module}"
184 | )
185 | if isinstance(env_class, str):
186 | assert module is not None, "For safe registration, specify module in api.environment.register."
187 | ALL_ENVIRONMENT_MODULES[name] = module
188 | ALL_ENVIRONMENT_CLASSES[name] = env_class
189 |
190 |
191 | def register_relabler(name, relabeler_class):
192 | ALL_AUGMENTER_CLASSES[name] = relabeler_class
193 |
194 |
195 | def make(cfg: Union[str, config.Environment]) -> Environment:
196 | env_type_ = cfg if isinstance(cfg, str) else cfg.type_
197 | if isinstance(cfg, str):
198 | cfg = config.Environment(type_=cfg)
199 | if isinstance(ALL_ENVIRONMENT_CLASSES[env_type_], str):
200 | if env_type_ not in ALL_ENVIRONMENT_MODULES:
201 | raise RuntimeError(
202 | "Module is not registered correctly for safe registration.")
203 | m_ = importlib.import_module(ALL_ENVIRONMENT_MODULES[env_type_])
204 | ALL_ENVIRONMENT_CLASSES[env_type_] = getattr(
205 | m_, ALL_ENVIRONMENT_CLASSES[env_type_])
206 | cls = ALL_ENVIRONMENT_CLASSES[env_type_]
207 | return cls(**cfg.args)
208 |
209 |
210 | register_relabler("NULL", NullAugmenter)
211 |
212 |
213 | def make_augmenter(cfg: Union[str, config.DataAugmenter]) -> DataAugmenter:
214 | augmenter_type = cfg if isinstance(cfg, str) else cfg.type_
215 | cls = ALL_AUGMENTER_CLASSES[augmenter_type]
216 |     return cls(**({} if isinstance(cfg, str) else cfg.args))
217 |
--------------------------------------------------------------------------------
/src/rlsrl/api/trainer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Union, Dict, Optional, List
3 | import dataclasses
4 | import numpy as np
5 | import torch
6 | import torch.distributed as dist
7 |
8 | from rlsrl.base.namedarray import NamedArray, recursive_apply
9 | from rlsrl.api.environment import Action
10 | import rlsrl.api.policy as policy_api
11 | import rlsrl.api.config as config_api
12 |
13 |
14 | class SampleBatch(NamedArray):
15 | # `SampleBatch` is the general data structure and will be used for ALL the algorithms we implement.
16 | # There could be some entries that may not be used by a specific algorithm,
17 | # e.g. log_probs and value are not used by DQN, which will be left as None.
18 |
19 | # `obs`, `on_reset`, `action`, `reward`, and `info` are environment-related data entries.
20 |     # `obs` and `on_reset` can be obtained once the environment step is performed.
21 |
22 | def __init__(
23 | self,
24 | obs: NamedArray,
25 | on_reset: np.ndarray = None,
26 | done: np.ndarray = None,
27 | truncated: np.ndarray = None,
28 |
29 | # `action` and `reward` can be obtained when the inference is done.
30 | action: Action = None,
31 | reward: np.ndarray = None,
32 |
33 | # Currently we assume info contains all the information we want to gather in an environment.
34 | # It is NOT agent-specific and should include summary information of ALL agents.
35 | info: NamedArray = None,
36 |
37 | # `info_mask` is recorded for correctly recording summary info when there are
38 | # multiple agents and some agents may die before an episode is done.
39 | info_mask: np.ndarray = None,
40 |
41 | # In some cases we may need Policy State. e.g. Partial Trajectory, Mu-Zero w/o reanalyze.
42 | policy_state: policy_api.PolicyState = None,
43 |
44 | # `analyzed_result` records algorithm-related analyzed results
45 | analyzed_result: policy_api.AnalyzedResult = None,
46 |
47 |             # Policy-related info.
48 | policy_name: np.ndarray = None,
49 | policy_version_steps: np.ndarray = None,
50 |
51 | # Metadata.
52 | sampling_weight: np.ndarray = None,
53 | **kwargs):
54 | super(SampleBatch, self).__init__(
55 | obs=obs,
56 | on_reset=on_reset,
57 | done=done,
58 | truncated=truncated,
59 | action=action,
60 | reward=reward,
61 | info=info,
62 | info_mask=info_mask,
63 | policy_state=policy_state,
64 | analyzed_result=analyzed_result,
65 | policy_name=policy_name,
66 | policy_version_steps=policy_version_steps,
67 | )
68 | self.register_metadata(sampling_weight=sampling_weight, )
69 |
70 |
71 | class TrajPostprocessor(ABC):
72 | """Post-process trajectories in actor workers before sending to trainers.
73 |
74 | Basically computing returns, e.g., GAE or n-step return.
75 | """
76 |
77 | def process(self, memory: List[SampleBatch]):
78 | raise NotImplementedError()
79 |
80 |
81 | class NullTrajPostprocessor(TrajPostprocessor):
82 |
83 | def process(self, memory: List[SampleBatch]):
84 | return memory
85 |
86 |
87 | @dataclasses.dataclass
88 | class TrainerStepResult:
89 | stats: Dict # Stats to be logged.
90 | step: int # current step count of trainer.
91 | agree_pushing: Optional[bool] = True # whether agree to push parameters
92 | priorities: Optional[
93 | np.ndarray] = None # New priorities of the PER buffer.
94 |
95 |
96 | class Trainer:
97 |
98 | @property
99 | def policy(self) -> policy_api.Policy:
100 | """Running policy of the trainer.
101 | """
102 | raise NotImplementedError()
103 |
104 | def step(self, samples: SampleBatch) -> TrainerStepResult:
105 | """Advances one training step given samples collected by actor workers.
106 |
107 | Example code:
108 | ...
109 |             some_data = self.policy.analyze(samples)
110 |             loss = loss_fn(some_data, samples)
111 | self.optimizer.zero_grad()
112 | loss.backward()
113 | ...
114 | self.optimizer.step()
115 | ...
116 |
117 | Args:
118 | samples (SampleBatch): A batch of data required for training.
119 |
120 | Returns:
121 | TrainerStepResult: Entry to be logged by trainer worker.
122 | """
123 | raise NotImplementedError()
124 |
125 | def distributed(self, **kwargs):
126 | """Make the trainer distributed.
127 | """
128 | raise NotImplementedError()
129 |
130 | def get_checkpoint(self, *args, **kwargs):
131 | """Get checkpoint of the model, which typically includes:
132 | 1. Policy state (e.g. neural network parameter).
133 | 2. Optimizer state.
134 | Return:
135 | checkpoint to be saved.
136 | """
137 | raise NotImplementedError()
138 |
139 | def load_checkpoint(self, checkpoint, **kwargs):
140 | """Load a saved checkpoint.
141 | Args:
142 | checkpoint: checkpoint to be loaded.
143 | """
144 | raise NotImplementedError()
145 |
146 |
147 | class PytorchTrainer(Trainer, ABC):
148 |
149 | @property
150 | def policy(self) -> policy_api.Policy:
151 | return self._policy
152 |
153 | def __init__(self, policy: policy_api.Policy):
154 | """Initialization method of Pytorch Trainer.
155 | Args:
156 | policy: Policy to be trained.
157 |
158 | Note:
159 | After initialization, access policy from property trainer.policy
160 | """
161 | if policy.device != "cpu":
162 | torch.cuda.set_device(policy.device)
163 | torch.cuda.empty_cache()
164 | self._policy = policy
165 |
166 | def distributed(self, rank, world_size, init_method, **kwargs):
167 | is_gpu_process = all([
168 | torch.distributed.is_nccl_available(),
169 | torch.cuda.is_available(),
170 | self.policy.device != "cpu",
171 | ])
172 | dist.init_process_group(backend="nccl" if is_gpu_process else "gloo",
173 | init_method=init_method,
174 | rank=rank,
175 | world_size=world_size)
176 | self.policy.distributed()
177 |
178 | def __del__(self):
179 | if dist.is_initialized():
180 | dist.destroy_process_group()
181 |
182 |
183 | class PyTorchGPUPrefetcher:
184 | """Prefetch sample into GPU in trainer.
185 |
186 | Reference: https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256.
187 | """
188 |
189 | def __init__(self):
190 | self.stream = torch.cuda.Stream()
191 | self.nex_numpy_sample = None
192 | self.nex_torch_sample = None
193 | self.initialized_prefetching = False
194 |
195 | def _preload(self, sample):
196 | self.nex_numpy_sample = sample
197 | with torch.cuda.stream(self.stream):
198 |             # NOTE: Using `.to(device)` instead of `.cuda()` will not accelerate data loading.
199 | self.nex_torch_sample = recursive_apply(
200 | self.nex_numpy_sample,
201 | lambda x: torch.from_numpy(x).cuda(non_blocking=True))
202 | self.nex_torch_sample = recursive_apply(self.nex_torch_sample,
203 | lambda x: x.float())
204 |
205 | def push(self, sample):
206 | if not self.initialized_prefetching:
207 | self._preload(sample)
208 | self.initialized_prefetching = True
209 | return None
210 | torch.cuda.current_stream().wait_stream(self.stream)
211 | numpy_sample = self.nex_numpy_sample
212 | torch_sample = self.nex_torch_sample
213 | self._preload(sample)
214 | return numpy_sample, torch_sample
215 |
216 |
217 | ALL_TRAINER_CLASSES = {}
218 |
219 |
220 | def register(name, trainer_class):
221 | ALL_TRAINER_CLASSES[name] = trainer_class
222 |
223 |
224 | def make(cfg: Union[str, config_api.Trainer],
225 | policy_cfg: Union[str, config_api.Policy]) -> Trainer:
226 | if isinstance(cfg, str):
227 | cfg = config_api.Trainer(type_=cfg)
228 | if isinstance(policy_cfg, str):
229 | policy_cfg = config_api.Policy(type_=policy_cfg)
230 | cls = ALL_TRAINER_CLASSES[cfg.type_]
231 | policy = policy_api.make(policy_cfg)
232 | policy.train_mode() # To be explicit.
233 | return cls(policy=policy, **cfg.args)
234 |
235 |
236 | ALL_TRAJ_POSTPROCESSOR_CLASSES = {}
237 |
238 |
239 | def register_traj_postprocessor(name, cls_):
240 | ALL_TRAJ_POSTPROCESSOR_CLASSES[name] = cls_
241 |
242 |
243 | register_traj_postprocessor('null', NullTrajPostprocessor)
244 |
245 |
246 | def make_traj_postprocessor(cfg: Union[str, config_api.TrajPostprocessor]):
247 | if isinstance(cfg, str):
248 | cfg = config_api.TrajPostprocessor(cfg)
249 |     postprocessor_type = cfg.type_
250 |     cls = ALL_TRAJ_POSTPROCESSOR_CLASSES[postprocessor_type]
251 | return cls(**cfg.args)
--------------------------------------------------------------------------------
/src/rlsrl/apps/main.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 | import argparse
3 | import itertools
4 | import logging
5 | import multiprocessing
6 | import multiprocessing.connection as mp_connection
7 | import os
8 |
9 | # multiprocessing.set_start_method("spawn", force=True)
10 |
11 | LOG_FORMAT = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s"
12 | DATE_FORMAT = "%Y%m%d-%H:%M:%S"
13 | logging.basicConfig(format=LOG_FORMAT,
14 | datefmt=DATE_FORMAT,
15 | level=os.environ.get("LOGLEVEL", "INFO"))
16 |
17 | import rlsrl.api.config as config_api
18 | import rlsrl.system as system
19 | import rlsrl.system.api.worker_control as worker_control
20 |
21 | # Import runnable legacy code.
22 | import rlsrl.legacy.algorithm
23 | import rlsrl.legacy.environment
24 | import rlsrl.legacy.experiments
25 |
26 | logger = logging.getLogger("SRL")
27 |
28 |
29 | def submit_workers(worker_type, ctrls, experiment_name, trial_name,
30 | env_vars):
31 | count = len(ctrls)
32 | logger.info(f"Submitted {count} {worker_type} worker(s).")
33 | ps = [
34 | multiprocessing.Process(
35 | target=system.run_worker,
36 | args=(
37 | worker_type,
38 | experiment_name,
39 | trial_name,
40 | f"{worker_type}_{i}",
41 | ctrl,
42 | env_vars,
43 | ),
44 | ) for i, ctrl in enumerate(ctrls)
45 | ]
46 | for p in ps:
47 | p.start()
48 | return ps
49 |
50 |
51 | def group_request(cmd: str, cmd_args: List[Tuple],
52 | ctrl_handles: List[mp_connection.Connection]):
53 | for tx, arg in zip(ctrl_handles, cmd_args):
54 | tx.send((cmd, arg))
55 | return [tx.recv() for tx in ctrl_handles]
56 |
57 |
58 | def run_local(args):
59 | exps = config_api.make_experiment(args.experiment_name)
60 | if len(exps) > 1:
61 | # TODO: add support for consecutive multiple experiments.
62 | raise NotImplementedError()
63 |
64 | exp: config_api.Experiment = exps[0]
65 | setup = exp.initial_setup()
66 | setup.set_worker_information(args.experiment_name, args.trial_name)
67 |     # Keep only the workers that are assigned to this node.
68 | for worker_type, spec in system.RL_WORKERS.items():
69 | field_name = spec.config_field_name
70 | worker_setups = getattr(setup, field_name)
71 | filtered = filter(lambda x: x.worker_info.node_rank == args.node_rank,
72 | worker_setups)
73 | setattr(setup, field_name, list(filtered))
74 |
75 | logger.info(f"Node {args.node_rank}: Running {exp.__class__.__name__} "
76 | f"experiment_name: {args.experiment_name}"
77 | f" trial_name {args.trial_name}")
78 |
79 | name_resolve_address = None
80 | name_resolve_port = None
81 |
82 | workers = dict()
83 | worker_tx = dict()
84 |
85 | env_vars = {
86 | "PYTHONPATH": os.path.dirname(os.path.dirname(__file__)),
87 | "NCCL_IB_DISABLE": "1",
88 | "WANDB_MODE": args.wandb_mode,
89 | "LOGLEVEL": os.environ.get("LOGLEVEL", "INFO"),
90 | }
91 |
92 | master_setup = setup.master_worker
93 | if len(master_setup) > 1:
94 | raise RuntimeError("Only one or zero master worker is supported.")
95 | if len(master_setup) == 1:
96 | master_setup = master_setup[0]
97 | name_resolve_address = master_setup.address
98 | name_resolve_port = master_setup.port
99 | master_setup.worker_info.set_name_resolve_address(
100 | name_resolve_address, name_resolve_port)
101 | assert master_setup.worker_info.node_rank == 0, "Master worker should be allocated at master node (rank = 0)."
102 | if args.node_rank == 0:
103 | tx, rx = multiprocessing.Pipe()
104 | procs = submit_workers("master",
105 | [worker_control.WorkerCtrl(rx, None, None)],
106 | args.experiment_name, args.trial_name,
107 | env_vars)
108 | workers['master'] = procs
109 | worker_tx['master'] = [tx]
110 |
111 | group_request("configure", [(master_setup, )], worker_tx['master'])
112 | group_request("start", [()], worker_tx['master'])
113 |
114 | setup, ctrls, tx_handles = worker_control.make_worker_control(
115 | args.experiment_name, args.trial_name, setup)
116 | worker_tx.update(tx_handles)
117 |
118 | # Submit workers.
119 | for worker_type in system.RL_WORKERS:
120 | if worker_type not in ctrls:
121 | continue
122 | procs = submit_workers(worker_type, ctrls[worker_type],
123 | args.experiment_name, args.trial_name, env_vars)
124 | workers[worker_type] = procs
125 | logger.info(f"Node {args.node_rank}: Submitted all workers.")
126 |
127 | # Configure workers.
128 | for worker_type, spec in system.RL_WORKERS.items():
129 | if worker_type not in workers:
130 | continue
131 | worker_setups = getattr(setup, spec.config_field_name)
132 | for worker_setup in worker_setups:
133 | assert worker_setup.worker_info.node_rank == args.node_rank
134 | worker_setup.worker_info.set_name_resolve_address(
135 | name_resolve_address, name_resolve_port)
136 |
137 | group_request("configure", [(w, ) for w in worker_setups],
138 | worker_tx[worker_type])
139 | logger.info(f"Node {args.node_rank}: Configured all workers.")
140 |
141 | # Start workers.
142 | for worker_type in system.RL_WORKERS:
143 | if worker_type not in workers:
144 | continue
145 | group_request("start", [() for _ in worker_tx[worker_type]],
146 | worker_tx[worker_type])
147 | logger.info(
148 | f"Node {args.node_rank}: Experiment successfully started. Check wandb for progress."
149 | )
150 |
151 | for w in itertools.chain.from_iterable(workers.values()):
152 | w.join(timeout=args.timeout)
153 |
154 | for ctrl in itertools.chain.from_iterable(ctrls.values()):
155 | if ctrl.inf_ctrls is not None:
156 | for c in ctrl.inf_ctrls:
157 | c.request_ctrl.close()
158 | c.response_ctrl.close()
159 | if ctrl.spl_ctrls is not None:
160 | for c in ctrl.spl_ctrls:
161 | c.close()
162 |
163 |
164 | def main():
165 | parser = argparse.ArgumentParser(prog="srl-local")
166 | subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
167 | subparsers.required = True
168 |
169 | subparser = subparsers.add_parser("run", help="starts a basic experiment")
170 | subparser.add_argument("--node_rank",
171 | "-n",
172 | type=int,
173 | required=False,
174 | default=0,
175 | help="Rank of the node. 0 = master node.")
176 | subparser.add_argument("--experiment_name",
177 | "-e",
178 | type=str,
179 | required=True,
180 | help="name of the experiment")
181 | subparser.add_argument("--trial_name",
182 | "-f",
183 | type=str,
184 | required=True,
185 | help="name of the trial")
186 | subparser.add_argument("--wandb_mode",
187 | type=str,
188 | default="disabled",
189 | choices=["online", "offline", "disabled"])
190 | subparser.add_argument("--timeout",
191 | "-t",
192 | type=int,
193 | default=3600,
194 | help="Timeout in seconds for waiting on each worker process.")
195 |
196 | subparser.set_defaults(func=run_local)
197 |
198 | args = parser.parse_args()
199 | args.func(args)
200 |
201 |
202 | if __name__ == '__main__':
203 | main()
204 |
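`main()` registers `run_local` as the handler of the `run` sub-command through `set_defaults(func=...)` and then dispatches with `args.func(args)`. Below is a minimal sketch of that argparse pattern. It uses a hypothetical stand-in handler, `fake_run`, in place of `run_local`, and keeps only a subset of the flags:

```python
import argparse


def fake_run(args):
    # Hypothetical stand-in for run_local: just report the parsed options.
    return (args.experiment_name, args.trial_name, args.node_rank, args.timeout)


def build_parser():
    parser = argparse.ArgumentParser(prog="srl-local")
    subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
    subparsers.required = True
    run = subparsers.add_parser("run", help="starts a basic experiment")
    run.add_argument("--node_rank", "-n", type=int, default=0)
    run.add_argument("--experiment_name", "-e", type=str, required=True)
    run.add_argument("--trial_name", "-f", type=str, required=True)
    run.add_argument("--timeout", "-t", type=int, default=3600)
    # The selected sub-command stores its handler on the parsed namespace.
    run.set_defaults(func=fake_run)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["run", "-e", "atari-mini", "-f", "test"])
    print(args.func(args))  # ('atari-mini', 'test', 0, 3600)
```

To add another sub-command, you only need another `add_parser(...)` call with its own `set_defaults(func=...)`. The dispatch line in `main()` stays unchanged.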
--------------------------------------------------------------------------------
/src/rlsrl/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openpsi-projects/srl/97b3ae4e5fab00f6da90f81679dc035a9079eb09/src/rlsrl/base/__init__.py
--------------------------------------------------------------------------------
/src/rlsrl/base/conditions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import queue
3 | import torch
4 |
5 | import rlsrl.api.config as config
6 |
7 |
8 | class Condition:
9 | """Defines a condition to be checked.
10 | """
11 |
12 | def is_met_with(self, data) -> bool:
13 | """Check whether passed data satisfies this condition.
14 | Args:
15 | data (dict-like): key-value data to be checked.
16 | Returns:
17 | check_passed (bool)
18 | """
19 | raise NotImplementedError()
20 |
21 | def reset(self):
22 | """Reset condition to initial state.
23 | """
24 | raise NotImplementedError()
25 |
26 |
27 | class SimpleBoundCondition(Condition):
28 |
29 | def __init__(self, field, lower_limit=None, upper_limit=None):
30 | self.field = field
31 | self.lower_limit = -np.inf if lower_limit is None else lower_limit
32 | self.upper_limit = np.inf if upper_limit is None else upper_limit
33 |
34 | def is_met_with(self, data):
35 | if self.field not in data:
36 | raise ValueError(f"target field {self.field} not found when checking condition {self}")
37 | if isinstance(data[self.field], (np.ndarray, torch.Tensor)):
38 | return self.lower_limit < data[self.field].mean() < self.upper_limit
39 | else:
40 | return self.lower_limit < data[self.field] < self.upper_limit
41 |
42 | def reset(self):
43 | return
44 |
45 | def __str__(self):
46 | return f"{self.lower_limit} < {self.field} < {self.upper_limit}"
47 |
48 |
49 | class ConvergedCondition(Condition):
50 |
51 | def __init__(self, value_field, step_field, warmup_step=0, duration=100, confidence=0.9, threshold=1e-2):
52 | """Check if the target value is converged.
53 | Args:
54 | value_field: target value field. E.g., episode_return.
55 | step_field: step field. E.g., version.
56 | warmup_step: always return False when step < warmup_step.
57 | duration: all values within the last `duration` steps are cached to check convergence.
58 | confidence: fraction of cached values used for the convergence check. E.g., if confidence is
59 | 0.9, only the middle 90% of values are used; the smallest 5% and largest 5% are ignored.
60 | threshold: the maximum acceptable difference between the largest and the smallest value within the
61 | confidence interval.
62 | """
63 |
64 | self.value_field = value_field
65 | self.step_field = step_field
66 | self.warmup_step = warmup_step
67 | self.duration = duration
68 | self.confidence = confidence
69 | self.threshold = threshold
70 | self.__head_step = None
71 | self.__step_queue = None
72 | self.__value_queue = None
73 |
74 | self.reset()
75 |
76 | def reset(self):
77 | self.__head_step = None
78 | self.__step_queue = queue.Queue()
79 | self.__value_queue = queue.Queue()
80 |
81 | def is_met_with(self, data):
82 | if self.value_field not in data or self.step_field not in data:
83 | raise ValueError(f"target field {self.value_field} or {self.step_field} not found when checking "
84 | f"condition {self}")
85 | if isinstance(data[self.step_field], (np.ndarray, torch.Tensor)):
86 | step = data[self.step_field].mean()
87 | else:
88 | step = data[self.step_field]
89 | if isinstance(data[self.value_field], (np.ndarray, torch.Tensor)):
90 | value = data[self.value_field].mean()
91 | else:
92 | value = data[self.value_field]
93 |
94 | if step < self.warmup_step:
95 | return False
96 | self.__step_queue.put(step)
97 | self.__value_queue.put(value)
98 |
99 | if self.__head_step is None:
100 | self.__head_step = self.__step_queue.get()
101 | return False
102 | if step - self.__head_step < self.duration:
103 | return False
104 |
105 | # Check convergence.
106 | values = np.sort(self.__value_queue.queue)
107 | idx = int(len(values) * (1 - self.confidence) / 2)
108 | # Not converged, if the last value falls outside the confidence range.
109 | converged = (max(values[-1 - idx], value) - min(values[idx], value) <= self.threshold)
110 | # Update queues.
111 | while step - self.__head_step >= self.duration:
112 | self.__head_step = self.__step_queue.get()
113 | self.__value_queue.get()
114 | return converged
115 |
116 | def __str__(self):
117 | return f"{self.value_field} converged"
118 |
119 |
120 | def make(cfg: config.Condition):
121 | if cfg.type_ == config.Condition.Type.SimpleBound:
122 | return SimpleBoundCondition(**cfg.args)
123 | elif cfg.type_ == config.Condition.Type.Converged:
124 | return ConvergedCondition(**cfg.args)
125 | else:
126 | raise NotImplementedError()
127 |
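`ConvergedCondition.is_met_with` works as follows:

- It sorts the values cached over the last `duration` steps.
- It skips `idx = int(len(values) * (1 - confidence) / 2)` values at each end.
- It reports convergence when the remaining range, widened to include the newest value, is at most `threshold`.

Below is a standalone sketch of just that test, with the step-window bookkeeping left out. `trimmed_range_converged` is a hypothetical helper, not part of rlsrl:

```python
import numpy as np


def trimmed_range_converged(window, latest, confidence=0.9, threshold=1e-2):
    """Sketch of the trimmed-range check in ConvergedCondition.is_met_with.

    window: values cached over the last `duration` steps (latest included).
    latest: the most recent value, which is never trimmed away.
    """
    values = np.sort(np.asarray(window, dtype=np.float64))
    # Number of values ignored at each end of the sorted window.
    idx = int(len(values) * (1 - confidence) / 2)
    lo = min(values[idx], latest)
    hi = max(values[-1 - idx], latest)
    return hi - lo <= threshold


window = [1.0] * 9 + [5.0]  # one outlier among ten cached values
print(trimmed_range_converged(window, latest=1.0, confidence=0.5))  # True
print(trimmed_range_converged(window, latest=1.0, confidence=1.0))  # False
```

Trimming makes the check tolerant of a few outliers. The newest value is never trimmed, so a sudden jump in the latest value still breaks convergence.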
--------------------------------------------------------------------------------
/src/rlsrl/base/gpu_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import itertools
3 | import logging
4 | import os
5 | import platform
6 |
7 | logger = logging.getLogger("System-GPU")
8 |
9 |
10 | def gpu_count():
11 | """Returns the number of GPUs on a node. On Linux, counts /dev/nvidia<N> files (ad hoc to the frl cluster).
12 | """
13 | if platform.system() == "Darwin":
14 | return 0
15 | elif platform.system() == "Windows":
16 | try:
17 | import torch
18 | return torch.cuda.device_count()
19 | except ImportError:
20 | return 0
21 | else:
22 | dev_directories = list(os.listdir("/dev/"))
23 | for cnt in itertools.count():
24 | if "nvidia" + str(cnt) in dev_directories:
25 | continue
26 | else:
27 | break
28 | return cnt
29 |
30 |
31 | def resolve_cuda_environment():
32 | """PyTorch DDP does not work if more than one process (each with a different CUDA_VISIBLE_DEVICES)
33 | is initialized on the same node (with multiple GPUs). This function works around the issue by storing the assigned devices in MARL_CUDA_DEVICES and making all GPUs visible.
34 | Currently, all workers should use `base.gpu_utils.get_gpu_device()` to get the proper GPU device.
35 | """
36 | if "MARL_CUDA_DEVICES" in os.environ.keys():
37 | return
38 |
39 | cuda_devices = [str(i) for i in range(gpu_count())]
40 | if "CUDA_VISIBLE_DEVICES" not in os.environ:
41 | if len(cuda_devices) > 0:
42 | os.environ["MARL_CUDA_DEVICES"] = "0"
43 | else:
44 | os.environ["MARL_CUDA_DEVICES"] = "cpu"
45 | else:
46 | if os.environ["CUDA_VISIBLE_DEVICES"] != "":
47 | for s in os.environ["CUDA_VISIBLE_DEVICES"].split(","):
48 | assert s.isdigit() and s in cuda_devices, f"Cuda device {s} cannot be resolved."
49 | os.environ["MARL_CUDA_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"] # Store assigned device.
50 | else:
51 | os.environ["MARL_CUDA_DEVICES"] = "cpu" # Use CPU if no cuda device available.
52 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_devices) # Make all devices visible.
53 |
54 |
55 | def get_gpu_device() -> List[str]:
56 | """
57 | Returns:
58 | List of assigned devices.
59 | """
60 | if "MARL_CUDA_DEVICES" not in os.environ:
61 | resolve_cuda_environment()
62 |
63 | if os.environ["MARL_CUDA_DEVICES"] == "cpu":
64 | return ["cpu"]
65 | else:
66 | return [f"cuda:{device}" for device in os.environ["MARL_CUDA_DEVICES"].split(",")]
67 |
68 |
69 | def set_cuda_device(device):
70 | """Set the default cuda-device. Useful on multi-gpu nodes. Should be called in every gpu-thread.
71 | """
72 | logger.info(f"Setting device to {device}.")
73 | if device != "cpu":
74 | import torch
75 | torch.cuda.set_device(device)
76 |
77 |
78 | resolve_cuda_environment()
79 |
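`get_gpu_device()` turns the `MARL_CUDA_DEVICES` value written by `resolve_cuda_environment()` into torch-style device strings. Below is a sketch of the parsing step in isolation. `parse_marl_devices` is a hypothetical name, and it takes the variable's value as an argument instead of reading `os.environ`:

```python
from typing import List


def parse_marl_devices(value: str) -> List[str]:
    # "cpu" means no GPU was assigned to this process.
    if value == "cpu":
        return ["cpu"]
    # Otherwise the value is a comma-separated list of CUDA indices, e.g. "0,2".
    return [f"cuda:{device}" for device in value.split(",")]


print(parse_marl_devices("0,2"))  # ['cuda:0', 'cuda:2']
```

When `CUDA_VISIBLE_DEVICES` was set, `resolve_cuda_environment()` afterwards makes every device visible. An index such as `2` therefore keeps referring to physical GPU 2 inside the process.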
--------------------------------------------------------------------------------
/src/rlsrl/base/lock.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 | import contextlib
3 | import logging
4 | import multiprocessing as mp
5 |
6 | logger = logging.getLogger("Lock")
7 |
8 |
9 | class RWLock:
10 | """ A lock that admits up to `max_concurrent_read` simultaneous readers or up to
11 | `max_concurrent_write` simultaneous writers, but never readers and writers at the same time. """
12 |
13 | def __init__(self,
14 | max_concurrent_read: int = 1,
15 | max_concurrent_write: int = 1,
16 | priority: Optional[Literal['read', 'write']] = None):
17 | self._lock = mp.Condition(mp.Lock())
18 |
19 | self._max_read = max_concurrent_read
20 | self._max_write = max_concurrent_write
21 |
22 | self._n_readers = mp.Value("i", 0)
23 | self._n_writers = mp.Value("i", 0)
24 |
25 | self._n_waiting_readers = mp.Value("i", 0)
26 | self._n_waiting_writers = mp.Value("i", 0)
27 |
28 | self._priority = priority
29 | if priority not in [None, 'read', 'write']:
30 | raise ValueError("Priority must be one of 'read', 'write'.")
31 |
32 | @property
33 | def n_waitings(self):
34 | return self._n_waiting_writers.value + self._n_waiting_readers.value
35 |
36 | def acquire_read(self):
37 | self._lock.acquire()
38 | while (self._n_writers.value > 0
39 | or (self._priority == 'write'
40 | and self._n_waiting_writers.value > 0)
41 | or self._n_readers.value >= self._max_read):
42 | self._n_waiting_readers.value += 1
43 | logger.debug(
44 | f"Waiting for read lock. Number of waiting readers: {self._n_waiting_readers.value}."
45 | )
46 | self._lock.wait()
47 | self._n_waiting_readers.value -= 1
48 | try:
49 | logger.debug(
50 | f"Acquire read. Concurrent readers: {self._n_readers.value}. "
51 | f"Number of waiting readers: {self._n_waiting_readers.value}.")
52 | self._n_readers.value += 1
53 | assert self._n_readers.value <= self._max_read
54 | finally:
55 | self._lock.release()
56 |
57 | def release_read(self):
58 | self._lock.acquire()
59 | try:
60 | self._n_readers.value -= 1
61 | logger.debug(
62 | f"Release read. Remaining readers: {self._n_readers.value}. "
63 | f"Number of waiting readers: {self._n_waiting_readers.value}. "
64 | f"Number of waiting writers: {self._n_waiting_writers.value}.")
65 | if self.n_waitings > 0:
66 | self._lock.notify(self.n_waitings)
67 | finally:
68 | self._lock.release()
69 |
70 | def acquire_write(self):
71 | self._lock.acquire()
72 | while (self._n_readers.value > 0 or
73 | (self._priority == 'read' and self._n_waiting_readers.value > 0)
74 | or self._n_writers.value >= self._max_write):
75 | self._n_waiting_writers.value += 1
76 | logger.debug(
77 | f"Waiting for write lock. "
78 | f"Current writers {self._n_writers.value}. "
79 | f"Number of waiting writers: {self._n_waiting_writers.value}.")
80 | self._lock.wait()
81 | self._n_waiting_writers.value -= 1
82 | try:
83 | self._n_writers.value += 1
84 | logger.debug(
85 | f"Acquire write. Concurrent writers: {self._n_writers.value}. "
86 | f"Number of waiting writers: {self._n_waiting_writers.value}.")
87 | assert self._n_writers.value <= self._max_write
88 | finally:
89 | self._lock.release()
90 |
91 | def release_write(self):
92 | self._lock.acquire()
93 | try:
94 | self._n_writers.value -= 1
95 | logger.debug(
96 | f"Release write. Remaining writers: {self._n_writers.value}. "
97 | f"Number of waiting readers: {self._n_waiting_readers.value}. "
98 | f"Number of waiting writers: {self._n_waiting_writers.value}.")
99 | if self.n_waitings > 0:
100 | self._lock.notify(self.n_waitings)
101 | finally:
102 | self._lock.release()
103 |
104 | @contextlib.contextmanager
105 | def read_locked(self):
106 | self.acquire_read()
107 | try:
108 | yield
109 | finally:
110 | self.release_read()
111 |
112 | @contextlib.contextmanager
113 | def write_locked(self):
114 | self.acquire_write()
115 | try:
116 | yield
117 | finally:
118 | self.release_write()
119 |
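`RWLock` admits readers and writers through a shared `multiprocessing.Condition`, with counters kept in `mp.Value`s. The same admission rule can be sketched for threads, leaving out the `priority` option and the waiting-counter bookkeeping. `ThreadRWLock` is a hypothetical simplification, not an rlsrl class:

```python
import contextlib
import threading


class ThreadRWLock:
    """Up to max_read concurrent readers or max_write concurrent writers, never both."""

    def __init__(self, max_read=1, max_write=1):
        self._cond = threading.Condition(threading.Lock())
        self._max_read, self._max_write = max_read, max_write
        self.n_readers = 0
        self.n_writers = 0

    @contextlib.contextmanager
    def read_locked(self):
        with self._cond:
            # Same admission rule as RWLock.acquire_read with priority=None.
            self._cond.wait_for(
                lambda: self.n_writers == 0 and self.n_readers < self._max_read)
            self.n_readers += 1
        try:
            yield
        finally:
            with self._cond:
                self.n_readers -= 1
                self._cond.notify_all()

    @contextlib.contextmanager
    def write_locked(self):
        with self._cond:
            self._cond.wait_for(
                lambda: self.n_readers == 0 and self.n_writers < self._max_write)
            self.n_writers += 1
        try:
            yield
        finally:
            with self._cond:
                self.n_writers -= 1
                self._cond.notify_all()
```

In this sketch the lock is acquired before the `try` block, so a release only follows a successful acquisition. `notify_all()` lets every waiter re-check its own predicate.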
--------------------------------------------------------------------------------
/src/rlsrl/base/names.py:
--------------------------------------------------------------------------------
1 | # This file standardizes the name-resolve names used by different components of the system.
2 | import getpass
3 |
4 | USER_NAMESPACE = getpass.getuser()
5 |
6 |
7 | def registry_root(user):
8 | return f"trial_registry/{user}"
9 |
10 |
11 | def trial_registry(experiment_name, trial_name):
12 | return f"trial_registry/{USER_NAMESPACE}/{experiment_name}/{trial_name}"
13 |
14 |
15 | def trial_root(experiment_name, trial_name):
16 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}"
17 |
18 |
19 | def worker_status(experiment_name, trial_name, worker_name):
20 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/status/{worker_name}"
21 |
22 |
23 | def worker_root(experiment_name, trial_name):
24 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/worker/"
25 |
26 |
27 | def worker(experiment_name, trial_name, worker_name):
28 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/worker/{worker_name}"
29 |
30 |
31 | def worker2(experiment_name, trial_name, worker_type, worker_index):
32 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/worker/{worker_type}/{worker_index}"
33 |
34 |
35 | def inference_stream(experiment_name, trial_name, stream_name):
36 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/inference_stream/{stream_name}"
37 |
38 |
39 | def inference_stream_constant(experiment_name, trial_name, stream_name,
40 | constant_name):
41 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/inference_stream_consts/{stream_name}/{constant_name}"
42 |
43 |
44 | def sample_stream(experiment_name, trial_name, stream_name):
45 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/sample_stream/{stream_name}"
46 |
47 |
48 | def trainer_ddp_peer(experiment_name, trial_name, policy_name):
49 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/trainer_ddp_peer/{policy_name}"
50 |
51 |
52 | def trainer_ddp_master(experiment_name, trial_name, policy_name):
53 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/trainer_ddp_master/{policy_name}"
54 |
55 |
56 | def curriculum_stage(experiment_name, trial_name, curriculum_name):
57 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/curriculum/{curriculum_name}"
58 |
59 |
60 | def worker_key(experiment_name, trial_name, key):
61 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/worker_key/{key}"
62 |
63 |
64 | def parameter_subscription(experiment_name, trial_name):
65 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/parameter_sub"
66 |
67 |
68 | def parameter_server(experiment_name, trial_name, parameter_id_str):
69 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/parameter_server/{parameter_id_str}"
70 |
71 |
72 | def shared_memory(experiment_name, trial_name, dock_name):
73 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/shared_memory/{dock_name}"
74 |
75 |
76 | def pinned_shm_qsize(experiment_name, trial_name, stream_name):
77 | return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/pinned_shm_qsize/{stream_name}"
78 |
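Every name built in this module, except the `trial_registry` keys, nests under `{USER_NAMESPACE}/{experiment_name}/{trial_name}/`. All keys of one trial can therefore be found with a prefix match. Below is a sketch of that layout. The user is passed explicitly (`"alice"` is a placeholder for `getpass.getuser()`), and `keys_under` is a hypothetical helper:

```python
from typing import List


def trial_root(user: str, experiment_name: str, trial_name: str) -> str:
    return f"{user}/{experiment_name}/{trial_name}"


def worker(user: str, experiment_name: str, trial_name: str, worker_name: str) -> str:
    return f"{trial_root(user, experiment_name, trial_name)}/worker/{worker_name}"


def sample_stream(user: str, experiment_name: str, trial_name: str, stream_name: str) -> str:
    return f"{trial_root(user, experiment_name, trial_name)}/sample_stream/{stream_name}"


def keys_under(prefix: str, keys: List[str]) -> List[str]:
    # A trailing "/" keeps trial "t1" from also matching trial "t10".
    prefix = prefix.rstrip("/") + "/"
    return [k for k in keys if k.startswith(prefix)]


keys = [
    worker("alice", "atari", "t1", "actor/0"),
    sample_stream("alice", "atari", "t1", "default"),
    worker("alice", "atari", "t10", "actor/0"),
]
print(keys_under(trial_root("alice", "atari", "t1"), keys))
# ['alice/atari/t1/worker/actor/0', 'alice/atari/t1/sample_stream/default']
```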
--------------------------------------------------------------------------------
/src/rlsrl/base/network.py:
--------------------------------------------------------------------------------
1 | from contextlib import closing
2 | import socket
3 |
4 |
5 | def find_free_port():
6 | """Adapted from Stack Overflow question 1365265.
7 | """
8 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
9 | s.bind(('', 0))
10 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
11 | return s.getsockname()[1]
12 |
13 |
14 | def gethostname():
15 | return socket.gethostname()
16 |
17 |
18 | def gethostip():
19 | return socket.gethostbyname(socket.gethostname())
20 |
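`find_free_port()` binds to port 0 so that the operating system picks an unused ephemeral port. It then closes the socket and returns the port number. Below is a sketch of the same idea plus a typical use. It sets `SO_REUSEADDR` on the listening socket before `bind`, which is where that option matters:

```python
import socket
from contextlib import closing


def find_free_port() -> int:
    # Binding to port 0 asks the OS to choose an unused ephemeral port.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


port = find_free_port()
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as server:
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind(("127.0.0.1", port))
    server.listen()
    print(server.getsockname()[1] == port)  # True
```

The returned port is only likely to be free. Another process can claim it after `find_free_port()` returns and before the caller binds it, so callers should be ready to retry.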
--------------------------------------------------------------------------------
/src/rlsrl/base/numpy_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import numpy as np
3 |
4 |
5 | def split_to_shapes(x: np.ndarray, shapes: Dict, axis: int = -1):
6 | """Split an array and reshape to desired shapes.
7 |
8 | Args:
9 | x (np.ndarray): The array to be split.
10 | shapes (Dict): Dict of shapes (tuples) specifying how to split.
11 | axis (int): Split dimension.
12 |
13 | Returns:
14 | Dict: Mapping from each key of `shapes` to its split, reshaped array.
15 | """
16 | axis = len(x.shape) + axis if axis < 0 else axis
17 | split_lengths = [np.prod(shape) for shape in shapes.values()]
18 | assert x.shape[axis] == sum(split_lengths)
19 | accum_split_lengths = [
20 | sum(split_lengths[:i]) for i in range(1, len(split_lengths))
21 | ]
22 | splitted_x = np.split(x, accum_split_lengths, axis)
23 | return {
24 | k: x.reshape(*x.shape[:axis], *shape, *x.shape[axis + 1:])
25 | for x, (k, shape) in zip(splitted_x, shapes.items())
26 | }
27 |
28 |
29 | def moving_average(x: np.ndarray, window_size: int):
30 | """Return the moving average of a 1D numpy array.
31 | """
32 | if len(x.shape) != 1:
33 | raise ValueError("Moving average works only on 1D arrays!")
34 | if window_size > x.shape[0]:
35 | raise ValueError(
36 | "Can't average over a window size larger than array length!")
37 | return np.convolve(
38 | np.ones(window_size, dtype=np.float32) / window_size, x, 'valid')
39 |
40 |
41 | def moving_maximum(x: np.ndarray, window_size: int):
42 | if len(x.shape) != 1:
43 | raise ValueError("Moving maximum works only on 1D arrays!")
44 | if window_size > x.shape[0]:
45 | raise ValueError(
46 | "Can't take the maximum over a window size larger than array length!")
47 | shape = (x.shape[0] - window_size + 1, window_size)
48 | strides = x.strides + (x.strides[-1], )
49 | rolling = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
50 | return np.max(rolling, axis=1)
51 |
52 |
53 | def dtype_to_num_bytes(dtype: np.dtype) -> int:
54 | if dtype in [np.int8, np.uint8]:
55 | return 1
56 | if dtype in [np.uint16, np.int16, np.float16]:
57 | return 2
58 | if dtype in [np.uint32, np.int32, np.float32]:
59 | return 4
60 | if dtype in [np.uint64, np.int64, np.float64]:
61 | return 8
62 | if dtype in [np.float128]:
63 | return 16
64 | if str(dtype).startswith(" str:
70 | if dtype == np.uint8 or dtype == bool:
71 | return "uint8"
72 | if dtype == np.float32:
73 | return "float32"
74 | if dtype == np.float64:
75 | return "float64"
76 | if dtype == np.int32:
77 | return "int32"
78 | if dtype == np.int64:
79 | return "int64"
80 | if str(dtype).startswith(" np.dtype:
86 | return np.dtype(dtype_str)
87 |
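`split_to_shapes` computes split points as running sums of the flattened block sizes, then reshapes each block back to its target shape. Below is a sketch of the same technique restricted to the last axis, using `np.cumsum` for the offsets. `split_flat` is a hypothetical name; the rlsrl version additionally supports any `axis`:

```python
from typing import Dict, Tuple

import numpy as np


def split_flat(x: np.ndarray, shapes: Dict[str, Tuple[int, ...]]) -> Dict[str, np.ndarray]:
    """Split the last axis of x into named blocks and reshape each block."""
    sizes = [int(np.prod(shape)) for shape in shapes.values()]
    assert x.shape[-1] == sum(sizes)
    # Split points are the running totals of block sizes, without the final total.
    offsets = np.cumsum(sizes)[:-1]
    parts = np.split(x, offsets, axis=-1)
    return {k: p.reshape(*x.shape[:-1], *shape) for p, (k, shape) in zip(parts, shapes.items())}


obs = np.arange(2 * 7).reshape(2, 7)  # batch of 2 flat observations
out = split_flat(obs, {"pos": (3,), "grid": (2, 2)})
print(out["pos"].shape, out["grid"].shape)  # (2, 3) (2, 2, 2)
```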
--------------------------------------------------------------------------------
/src/rlsrl/base/segment_tree.py:
--------------------------------------------------------------------------------
1 | import operator
2 |
3 |
4 | class SegmentTree(object):
5 |
6 | def __init__(self, capacity, operation, neutral_element):
7 | """Build a Segment Tree data structure.
8 |
9 | https://en.wikipedia.org/wiki/Segment_tree
10 |
11 | Can be used as regular array, but with two
12 | important differences:
13 |
14 | a) setting item's value is slightly slower.
15 | It is O(lg capacity) instead of O(1).
16 | b) user has access to an efficient ( O(log segment size) )
17 | `reduce` operation which reduces `operation` over
18 | a contiguous subsequence of items in the array.
19 |
20 | Parameters
21 | ----------
22 | capacity: int
23 | Total size of the array - must be a power of two.
24 | operation: lambda obj, obj -> obj
25 | an operation for combining elements (e.g. sum, max);
26 | together with the set of possible values for array elements it
27 | must form a monoid (i.e. be associative, with neutral_element as identity)
28 | neutral_element: obj
29 | neutral element for the operation above, e.g. float('-inf')
30 | for max and 0 for sum.
31 | """
32 | assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
33 | self._capacity = capacity
34 | self._value = [neutral_element for _ in range(2 * capacity)]
35 | self._operation = operation
36 |
37 | def _reduce_helper(self, start, end, node, node_start, node_end):
38 | if start == node_start and end == node_end:
39 | return self._value[node]
40 | mid = (node_start + node_end) // 2
41 | if end <= mid:
42 | return self._reduce_helper(start, end, 2 * node, node_start, mid)
43 | else:
44 | if mid + 1 <= start:
45 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
46 | else:
47 | return self._operation(self._reduce_helper(start, mid, 2 * node, node_start, mid),
48 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end))
49 |
50 | def reduce(self, start=0, end=None):
51 | """Returns result of applying `self.operation`
52 | to a contiguous subsequence of the array.
53 |
54 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end - 1])))
55 |
56 | Parameters
57 | ----------
58 | start: int
59 | beginning of the subsequence (inclusive)
60 | end: int
61 | end of the subsequence (exclusive; negative values count from the end, as in slicing)
62 |
63 | Returns
64 | -------
65 | reduced: obj
66 | result of reducing self.operation over the specified range of array elements.
67 | """
68 | if end is None:
69 | end = self._capacity
70 | if end < 0:
71 | end += self._capacity
72 | end -= 1
73 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
74 |
75 | def __setitem__(self, idx, val):
76 | # index of the leaf
77 | idx += self._capacity
78 | self._value[idx] = val
79 | idx //= 2
80 | while idx >= 1:
81 | self._value[idx] = self._operation(self._value[2 * idx], self._value[2 * idx + 1])
82 | idx //= 2
83 |
84 | def __getitem__(self, idx):
85 | assert 0 <= idx < self._capacity
86 | return self._value[self._capacity + idx]
87 |
88 |
89 | class SumSegmentTree(SegmentTree):
90 |
91 | def __init__(self, capacity):
92 | super(SumSegmentTree, self).__init__(capacity=capacity, operation=operator.add, neutral_element=0.0)
93 |
94 | def sum(self, start=0, end=None):
95 | """Returns arr[start] + ... + arr[end - 1]"""
96 | return super(SumSegmentTree, self).reduce(start, end)
97 |
98 | def find_prefixsum_idx(self, prefixsum):
99 | """Find the highest index `i` in the array such that
100 | arr[0] + arr[1] + ... + arr[i - 1] <= prefixsum
101 |
102 | if array values are probabilities, this function
103 | allows to sample indexes according to the discrete
104 | probability efficiently.
105 |
106 | Parameters
107 | ----------
108 | prefixsum: float
109 | upperbound on the sum of array prefix
110 |
111 | Returns
112 | -------
113 | idx: int
114 | highest index satisfying the prefixsum constraint
115 | """
116 | assert 0 <= prefixsum <= self.sum() + 1e-5
117 | idx = 1
118 | while idx < self._capacity: # while non-leaf
119 | if self._value[2 * idx] > prefixsum:
120 | idx = 2 * idx
121 | else:
122 | prefixsum -= self._value[2 * idx]
123 | idx = 2 * idx + 1
124 | return idx - self._capacity
125 |
126 |
127 | class MinSegmentTree(SegmentTree):
128 |
129 | def __init__(self, capacity):
130 | super(MinSegmentTree, self).__init__(capacity=capacity, operation=min, neutral_element=float('inf'))
131 |
132 | def min(self, start=0, end=None):
133 | """Returns min(arr[start], ..., arr[end - 1])"""
134 |
135 | return super(MinSegmentTree, self).reduce(start, end)
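A `SumSegmentTree` supports proportional sampling, for example in prioritized experience replay. Draw `u` uniformly from `[0, total)` and descend with `find_prefixsum_idx(u)`; this returns index `i` with probability `arr[i] / total`. Below is a self-contained sketch of that descent, with `TinySumTree` as a hypothetical stand-in for `SumSegmentTree`:

```python
import random


class TinySumTree:
    """Array-backed sum tree; leaves live at indices [capacity, 2 * capacity)."""

    def __init__(self, capacity: int):
        assert capacity > 0 and capacity & (capacity - 1) == 0
        self.capacity = capacity
        self.value = [0.0] * (2 * capacity)

    def __setitem__(self, idx: int, val: float):
        idx += self.capacity
        self.value[idx] = val
        idx //= 2
        while idx >= 1:  # refresh the partial sum of every ancestor
            self.value[idx] = self.value[2 * idx] + self.value[2 * idx + 1]
            idx //= 2

    def total(self) -> float:
        return self.value[1]

    def find_prefixsum_idx(self, prefixsum: float) -> int:
        idx = 1
        while idx < self.capacity:  # descend until a leaf is reached
            if self.value[2 * idx] > prefixsum:
                idx = 2 * idx
            else:
                prefixsum -= self.value[2 * idx]
                idx = 2 * idx + 1
        return idx - self.capacity


tree = TinySumTree(4)
for i, priority in enumerate([1.0, 0.0, 3.0, 0.0]):
    tree[i] = priority
print(tree.find_prefixsum_idx(0.5), tree.find_prefixsum_idx(2.0))  # 0 2

# Sampling: index 2 should come up about three times as often as index 0.
random.seed(0)
draws = [tree.find_prefixsum_idx(random.random() * tree.total()) for _ in range(1000)]
```

Both updates and draws cost O(log capacity). This is why the tree is preferred over recomputing a cumulative sum after every priority change.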
--------------------------------------------------------------------------------
/src/rlsrl/base/user.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import os
3 | import tempfile
4 |
5 |
6 | def get_user_tmp():
7 | tmp = tempfile.gettempdir()
8 | user = getpass.getuser()
9 | user_tmp = os.path.join(tmp, user)
10 | os.makedirs(user_tmp, exist_ok=True)
11 | return user_tmp
12 |
13 | def get_user_home():
14 | home_dir = os.environ["HOME"]
15 | return home_dir
16 |
17 | def get_random_tmp():
18 | return tempfile.mkdtemp()
19 |
--------------------------------------------------------------------------------
/src/rlsrl/legacy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openpsi-projects/srl/97b3ae4e5fab00f6da90f81679dc035a9079eb09/src/rlsrl/legacy/__init__.py
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/__init__.py:
--------------------------------------------------------------------------------
1 | from rlsrl.legacy.algorithm.ppo import *
2 |
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention import *
2 | from .autoreset_rnn import *
3 | from .gae import *
4 | from .popart import *
5 | from .recurrent_backbone import *
6 | from .utils import *
7 | from .cnn import *
8 |
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/modules/attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class CatSelfEmbedding(nn.Module):
8 |
9 | def __init__(self, self_dim, others_shape_dict, d_embedding, use_orthogonal=True):
10 | super(CatSelfEmbedding, self).__init__()
11 | self.self_dim = self_dim
12 | self.others_shape_dict = others_shape_dict
13 | self.d_embedding = d_embedding
14 |
15 | def get_layer(input_dim, output_dim):
16 | linear = nn.Linear(input_dim, output_dim)
17 | nn.init.xavier_uniform_(linear.weight.data)
18 | return nn.Sequential(linear, nn.ReLU(inplace=True))
19 |
20 | self.others_keys = sorted(self.others_shape_dict.keys())
21 | self.self_embedding = get_layer(self_dim, d_embedding)
22 | for k in self.others_keys:
23 | if 'mask' not in k:
24 | setattr(self, k + '_fc', get_layer(others_shape_dict[k][-1] + self_dim, d_embedding))
25 |
26 | def forward(self, self_vec, **inputs):
27 | other_embeddings = []
28 | self_embedding = self.self_embedding(self_vec)
29 | self_vec_ = self_vec.unsqueeze(-2)
30 | for k, x in inputs.items():
31 | assert k in self.others_keys
32 | expand_shape = [-1 for _ in range(len(x.shape))]
33 | expand_shape[-2] = x.shape[-2]
34 | x_ = torch.cat([self_vec_.expand(*expand_shape), x], -1)
35 | other_embeddings.append(getattr(self, k + '_fc')(x_))
36 |
37 | other_embeddings = torch.cat(other_embeddings, dim=-2)
38 | return self_embedding, other_embeddings
39 |
40 |
41 | def ScaledDotProductAttention(q, k, v, d_k, mask=None, dropout=None):
42 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
43 | if mask is not None:
44 | mask = mask.unsqueeze(-2).unsqueeze(-2)
45 | scores = scores - (1 - mask) * 1e10
46 | # in case of overflow
47 | scores = scores - scores.max(dim=-1, keepdim=True)[0]
48 | scores = F.softmax(scores, dim=-1)
49 | if mask is not None:
50 | # for stability
51 | scores = scores * mask
52 |
53 | if dropout is not None:
54 | scores = dropout(scores)
55 |
56 | output = torch.matmul(scores, v)
57 | return output
58 |
59 |
60 | class MultiHeadSelfAttention(nn.Module):
61 |
62 | def __init__(self, input_dim, heads, d_head, dropout=0.0, use_orthogonal=True):
63 | super(MultiHeadSelfAttention, self).__init__()
64 |
65 | self.d_model = d_head * heads
66 | self.d_head = d_head
67 | self.h = heads
68 |
69 | self.pre_norm = nn.LayerNorm(input_dim)
70 | self.q_linear = nn.Linear(input_dim, self.d_model)
71 | nn.init.normal_(self.q_linear.weight.data, std=math.sqrt(0.125 / input_dim))
72 | self.k_linear = nn.Linear(input_dim, self.d_model)
73 | nn.init.normal_(self.k_linear.weight.data, std=math.sqrt(0.125 / input_dim))
74 |
75 | self.v_linear = nn.Linear(input_dim, self.d_model)
76 | nn.init.normal_(self.v_linear.weight.data, std=math.sqrt(0.125 / input_dim))
77 |
78 | # self.attn_dropout = nn.Dropout(dropout)
79 | self.attn_dropout = None
80 |
81 | def forward(self, x, mask, use_ckpt=False):
82 | x = self.pre_norm(x)
83 | # perform linear operation and split into h heads
84 | k = self.k_linear(x).view(*x.shape[:-1], self.h, self.d_head).transpose(-2, -3)
85 | q = self.q_linear(x).view(*x.shape[:-1], self.h, self.d_head).transpose(-2, -3)
86 | v = self.v_linear(x).view(*x.shape[:-1], self.h, self.d_head).transpose(-2, -3)
87 |
88 | # calculate attention
89 | scores = ScaledDotProductAttention(q, k, v, self.d_head, mask, self.attn_dropout)
90 |
91 | # concatenate heads and put through final linear layer
92 | return scores.transpose(-2, -3).contiguous().view(*x.shape[:-1], self.d_model)
93 |
94 |
95 | class ResidualMultiHeadSelfAttention(nn.Module):
96 |
97 | def __init__(self, input_dim, heads, d_head, dropout=0.0, use_orthogonal=True):
98 | super(ResidualMultiHeadSelfAttention, self).__init__()
99 | self.d_model = heads * d_head
100 | self.attn = MultiHeadSelfAttention(input_dim, heads, d_head, dropout, use_orthogonal)
101 |
102 | post_linear = nn.Linear(self.d_model, self.d_model)
103 | nn.init.normal_(post_linear.weight.data, std=math.sqrt(0.125 / self.d_model))
104 | self.dense = post_linear
105 | self.residual_norm = nn.LayerNorm(self.d_model)
106 | # self.dropout_after_attn = nn.Dropout(dropout)
107 | self.dropout_after_attn = None
108 |
109 | def forward(self, x, mask, use_ckpt=False):
110 | scores = self.dense(self.attn(x, mask, use_ckpt))
111 | if self.dropout_after_attn is not None:
112 | scores = self.dropout_after_attn(scores)
113 | return self.residual_norm(x + scores)
114 |
115 |
116 | def masked_avg_pooling(scores, mask=None):
117 | if mask is None:
118 | return scores.mean(-2)
119 | else:
120 | assert mask.shape[-1] == scores.shape[-2]
121 | masked_scores = scores * mask.unsqueeze(-1)
122 | return masked_scores.sum(-2) / (mask.sum(-1, keepdim=True) + 1e-5)
123 |
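`ScaledDotProductAttention` handles masking in three steps:

- It masks keys by subtracting `1e10` from their logits.
- It subtracts the row maximum before `softmax`.
- It multiplies the weights by the mask afterwards.

Below is a single-head NumPy sketch of those steps. `masked_attention` is a hypothetical name. The rlsrl version works on multi-head torch tensors and broadcasts the mask with two `unsqueeze(-2)` calls:

```python
import numpy as np


def masked_attention(q, k, v, mask=None):
    """q: (n_q, d), k and v: (n_k, d), mask: (n_k,) with 1 = valid key."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)
    if mask is not None:
        # Push masked logits far below the rest so softmax gives them ~0 weight.
        scores = scores - (1 - mask) * 1e10
    scores = scores - scores.max(axis=-1, keepdims=True)  # avoid overflow in exp
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    if mask is not None:
        weights = weights * mask  # exact zeros for masked keys
    return weights @ v


q = np.zeros((1, 2))
k = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
v = np.array([[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
print(masked_attention(q, k, v, np.array([1.0, 1.0, 0.0])))  # [[0.5 0.5]]
```

The final multiplication by the mask gives masked keys exactly zero weight. A row whose keys are all masked returns zeros instead of an average over padding.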
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/modules/autoreset_rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AutoResetRNN(nn.Module):
6 |
7 | def __init__(self,
8 | input_dim,
9 | output_dim,
10 | num_layers=1,
11 | batch_first=False,
12 | rnn_type='lstm'):
13 | super().__init__()
14 | self.__type = rnn_type
15 | if self.__type == 'gru':
16 | self.__net = nn.GRU(input_dim,
17 | output_dim,
18 | num_layers=num_layers,
19 | batch_first=batch_first)
20 | elif self.__type == 'lstm':
21 | self.__net = nn.LSTM(input_dim,
22 | output_dim,
23 | num_layers=num_layers,
24 | batch_first=batch_first)
25 | else:
26 | raise NotImplementedError(
27 | f'RNN type {self.__type} has not been implemented.')
28 |
29 | def __forward(self, x, h):
30 | if self.__type == 'lstm':
31 | h = torch.split(h, h.shape[-1] // 2, dim=-1)
32 | h = (h[0].contiguous(), h[1].contiguous())
33 | x_, h_ = self.__net(x, h)
34 | if self.__type == 'lstm':
35 | h_ = torch.cat(h_, -1)
36 | return x_, h_
37 |
38 | def forward(self, x, h, on_reset=None):
39 | if on_reset is None:
40 | return self.__forward(x, h)
41 |
42 | masks = 1 - on_reset
43 | hxs = h
44 |
45 | has_zeros = (masks[1:] == 0.0).any(dim=1).nonzero(
46 | as_tuple=True)[0].cpu().numpy()
47 | has_zeros = [0] + (has_zeros + 1).tolist() + [x.shape[0]]
48 |
49 | outputs = []
50 | for i in range(len(has_zeros) - 1):
51 | # We can now process steps that don't have any zeros in masks together!
52 | # This is much faster
53 | start_idx = has_zeros[i]
54 | end_idx = has_zeros[i + 1]
55 | rnn_scores, hxs = self.__forward(
56 | x[start_idx:end_idx],
57 | hxs * masks[start_idx].view(1, -1, *((1, ) * (hxs.dim() - 2))))
58 | outputs.append(rnn_scores)
59 |
60 | # outputs holds one chunk per segment; concatenating them along time
61 | # yields a (T, N, -1) tensor aligned with the input x.
62 | x_ = torch.cat(outputs, dim=0)
63 | h_ = hxs
64 | return x_, h_
65 |
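The segment splitting in `AutoResetRNN.forward` can be illustrated without an RNN. The sketch below is a hypothetical helper (`reset_segments`) that reproduces only the `has_zeros` bookkeeping. It assumes `on_reset` is a `(T, N)` 0/1 array, whereas the module itself receives an extra trailing singleton dimension. Each returned `[start, end)` range would be fed to the RNN in a single call, with the hidden state zeroed for environments whose `on_reset[start]` is 1.

```python
import numpy as np

def reset_segments(on_reset):
    """Hypothetical helper: return [start, end) time ranges such that no
    environment resets strictly inside a range (mirrors has_zeros)."""
    masks = 1 - on_reset  # (T, N); 0 marks a reset step
    # Time steps t >= 1 at which at least one environment resets.
    boundaries = np.nonzero((masks[1:] == 0).any(axis=1))[0] + 1
    points = [0] + boundaries.tolist() + [on_reset.shape[0]]
    return list(zip(points[:-1], points[1:]))

# T=5 steps, N=2 environments; env 0 resets at t=4, env 1 at t=2.
on_reset = np.array([[1, 0], [0, 0], [0, 1], [0, 0], [1, 0]], dtype=np.float32)
print(reset_segments(on_reset))  # [(0, 2), (2, 4), (4, 5)]
```

A reset at `t=0` does not create a boundary: it is handled by masking the initial hidden state of the first segment.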
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/modules/cnn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .utils import mlp
7 |
8 |
9 | def cnn_output_dim(dimension, padding, dilation, kernel_size, stride):
10 | """Calculates the output height and width based on the input
11 | height and width to the convolution layer.
12 | ref: https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d
13 | """
14 | out_dimension = []
15 | for i in range(len(dimension)):
16 | out_dimension.append(
17 | int(
18 | np.floor(((dimension[i] + 2 * padding[i] - dilation[i] *
19 | (kernel_size[i] - 1) - 1) / stride[i]) + 1)))
20 | if not all([d > 0 for d in out_dimension]):
21 | raise ValueError(f"CNN Dimension error, got {out_dimension} after convolution")
22 | return tuple(out_dimension)
23 |
24 |
25 | def maxpool_output_dim(dimension, dilation, kernel_size, stride):
26 | """Calculates the output height and width based on the input
27 | height and width to the max-pooling layer (padding is assumed to be 0).
28 | ref: shape formula in the torch.nn.MaxPool2d documentation
29 | """
30 | out_dimension = []
31 | for i in range(len(dimension)):
32 | out_dimension.append(
33 | int(np.floor(((dimension[i] - dilation[i] * (kernel_size[i] - 1) - 1) / stride[i]) + 1)))
34 | if not all([d > 0 for d in out_dimension]):
35 | raise ValueError(f"Max-pooling dimension error, got {out_dimension} after max-pooling")
36 | return tuple(out_dimension)
37 |
38 |
39 | class Convolution(nn.Module):
40 | """A model that uses convolution (Conv1d/2d/3d, chosen by input rank) and optional max-pooling to embed an image into a vector.
41 | """
42 |
43 | def __init__(self, input_shape: Tuple, cnn_layers: Optional[List[Tuple[int, int, int, int, str]]],
44 | use_maxpool: bool, activation, hidden_size: int, use_orthogonal: bool):
45 | Initialization method of Convolution (auto-cnn).
46 | Args:
47 | input_shape: Shape of input image, channel first.
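48 | cnn_layers: List of user-specified cnn layer configurations (out_channels, kernel_size, stride, padding, padding_mode); None selects a default 3-layer stack.
49 | use_maxpool: Whether to insert a max-pooling layer (kernel 2, stride 2) before every conv layer except the last.
50 | activation: Activation class such as nn.ReLU or nn.Tanh; its lowercased name is also passed to nn.init.calculate_gain.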
51 | hidden_size: output dimension.
52 | use_orthogonal: whether to use orthogonal initialization method.
53 | """
54 | super(Convolution, self).__init__()
55 | self.__input_shape = input_shape
56 | self.__input_dims = len(input_shape)
57 | self.__use_maxpool = use_maxpool
58 | self.__cnn_layers = cnn_layers
59 | self.activation_name = activation
60 |
61 | if self.__input_dims == 2:
62 | self.__cnn = nn.Conv1d
63 | self.__max_pool = nn.MaxPool1d
64 | elif self.__input_dims == 3:
65 | self.__cnn = nn.Conv2d
66 | self.__max_pool = nn.MaxPool2d
67 | elif self.__input_dims == 4:
68 | self.__cnn = nn.Conv3d
69 | self.__max_pool = nn.MaxPool3d
70 |
71 | self.__activation = activation
72 |
73 | self.__weights_init_gain = nn.init.calculate_gain(activation.__name__.lower())
74 | self.__weights_init_method = nn.init.orthogonal_ if use_orthogonal else nn.init.xavier_uniform_
75 | self.__bias_init_constant = 0
76 | self.__bias_init_method = nn.init.constant_
77 |
78 | self.__hidden_size = hidden_size
79 | self.__model = self._build_cnn_model()
80 |
81 | def __init_layer(self, layer: nn.Module):
82 | self.__weights_init_method(layer.weight.data, gain=self.__weights_init_gain)
83 | self.__bias_init_method(layer.bias.data, val=self.__bias_init_constant)
84 | return layer
85 |
86 | def __linear_layers(self, input_dim):
87 | linear_sizes = [input_dim]
88 | while linear_sizes[-1] > self.__hidden_size * 8:
89 | linear_sizes.append(linear_sizes[-1] // 2)
90 | linear_sizes.append(self.__hidden_size)
91 | return mlp(linear_sizes)
92 |
93 | def _build_cnn_model(self):
94 | cnn_layers = []
95 | num_channels, *dimension = self.__input_shape
96 | if self.__cnn_layers is None:
97 | self.__cnn_layers = [(num_channels, 5, 1, 0, "zeros"), (num_channels * 2, 3, 1, 0, "zeros"),
98 | (num_channels, 3, 1, 0, "zeros")]
99 | for i, (out_channels, kernel_size, stride, padding, padding_mode) in enumerate(self.__cnn_layers):
100 | if self.__use_maxpool and i != len(self.__cnn_layers) - 1:
101 | # Insert a max-pooling layer before every conv layer except the last one.
102 | cnn_layers.append(self.__max_pool(2))
103 | dimension = maxpool_output_dim(dimension=dimension,
104 | dilation=np.array([1] * self.__input_dims, dtype=np.float32),
105 | kernel_size=np.array([2] * self.__input_dims,
106 | dtype=np.float32),
107 | stride=np.array([2] * self.__input_dims, dtype=np.float32))
108 |
109 | cnn_layers.append(
110 | self.__init_layer(
111 | self.__cnn(in_channels=num_channels,
112 | out_channels=out_channels,
113 | kernel_size=kernel_size,
114 | stride=stride,
115 | padding=padding,
116 | padding_mode=padding_mode)))
117 | dimension = cnn_output_dim(dimension=dimension,
118 | padding=np.array([padding] * self.__input_dims, dtype=np.float32),
119 | dilation=np.array([1] * self.__input_dims, dtype=np.float32),
120 | kernel_size=np.array([kernel_size] * self.__input_dims,
121 | dtype=np.float32),
122 | stride=np.array([stride] * self.__input_dims, dtype=np.float32))
123 |
124 | cnn_layers.append(self.__activation())
125 | num_channels = out_channels
126 | cnn_layers.extend([nn.Flatten(), self.__linear_layers(input_dim=num_channels * np.prod(dimension))])
127 |
128 | return nn.Sequential(*cnn_layers)
129 |
130 | def forward(self, x):
131 | T, B = x.size()[:2]
132 | x = torch.flatten(x, start_dim=0, end_dim=1)
133 | cnn_x = self.__model(x)
134 |
135 | return cnn_x.reshape(T, B, -1)
136 |
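As a worked example of `cnn_output_dim` and `__linear_layers`, the sketch below traces the encoder built for the Atari policy in `atari_benchmark.py` (input `(4, 84, 84)`, conv layers `(16, 8, 4, 0)` and `(32, 4, 2, 0)`), assuming max-pooling is disabled. `conv_out` and `linear_sizes` are hypothetical plain-Python stand-ins for the module's helpers.

```python
import math

def conv_out(size, kernel, stride, padding=0, dilation=1):
    # Same formula as cnn_output_dim, for a single spatial dimension.
    return math.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)

def linear_sizes(input_dim, hidden_size):
    # Mirrors Convolution.__linear_layers: halve until <= 8 * hidden_size,
    # then project to hidden_size.
    sizes = [input_dim]
    while sizes[-1] > hidden_size * 8:
        sizes.append(sizes[-1] // 2)
    sizes.append(hidden_size)
    return sizes

side = 84  # square 84x84 frames
for out_channels, kernel, stride, padding in [(16, 8, 4, 0), (32, 4, 2, 0)]:
    side = conv_out(side, kernel, stride, padding)  # 84 -> 20 -> 9
flat = out_channels * side * side  # channels of the last conv layer
print(side, flat, linear_sizes(flat, 512))  # 9 2592 [2592, 512]
```

With `hidden_size=512` the flattened 2592-dimensional feature maps straight to 512 units. A smaller `hidden_size` such as 64 would trigger the halving loop and give `[2592, 1296, 648, 324, 64]`.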
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/modules/gae.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional
2 | import numpy as np
3 | import torch
4 |
5 |
6 | @torch.no_grad()
7 | def gae_trace(
8 | reward: torch.FloatTensor,
9 | value: torch.FloatTensor,
10 | truncated: torch.FloatTensor,
11 | done: torch.FloatTensor,
12 | on_reset: torch.FloatTensor,
13 | gamma: Union[float, torch.FloatTensor],
14 | lmbda: Union[float, torch.FloatTensor],
15 | vtrace: Optional[bool] = False,
16 | imp_ratio: Optional[torch.FloatTensor] = None,
17 | rho: Optional[float] = 1.0,
18 | c: Optional[float] = 1.0,
19 | high_precision: Optional[bool] = True,
20 | ) -> torch.FloatTensor:
21 | """Compute the Generalized Advantage Estimation.
22 |
23 | Args:
24 | reward (torch.FloatTensor): rewards of shape [T, bs, Nc]
25 | value (torch.FloatTensor): values of shape [T+1, bs, Nc]
26 | truncated (torch.FloatTensor): truncated indicator of shape [T+1, bs, 1]
27 | done (torch.FloatTensor): done (aka terminated) indicator of shape [T+1, bs, 1]
28 | on_reset (torch.FloatTensor): whether on the reset step, shape [T+1, bs, 1]
29 | gamma (Union[float, torch.FloatTensor]): discount factor.
30 | If input is a tensor, it should have shape [T, bs, 1]
31 | lmbda (Union[float, torch.FloatTensor]): GAE lambda.
32 | If input is a tensor, it should have shape [T, bs, 1]
33 | vtrace (Optional[bool]): whether to use V-trace correction. Defaults to False.
34 | imp_ratio (Optional[torch.FloatTensor]): importance sampling ratio of shape [T, bs, 1]
35 | rho (Optional[float]):
36 | Clipping hyperparameter rho as described in the paper. Defaults to 1.0.
37 | c (Optional[float]):
38 | Clipping hyperparameter c as described in the paper. Defaults to 1.0.
39 | high_precision (Optional[bool]): whether to use float64. Defaults to True.
40 | Returns:
41 | torch.FloatTensor: GAE of shape [T, bs, Nc]
42 | """
43 |
44 | if high_precision:
45 | reward, value, truncated, done, on_reset = map(
46 | lambda x: x.to(torch.float64),
47 | [reward, value, truncated, done, on_reset])
48 | if vtrace:
49 | imp_ratio = imp_ratio.to(torch.float64)
50 | if not isinstance(gamma, float):
51 | assert isinstance(gamma, torch.FloatTensor), type(gamma)
52 | assert gamma.shape == on_reset[:-1].shape, gamma.shape
53 | if high_precision:
54 | gamma = gamma.to(torch.float64)
55 | if not isinstance(lmbda, float):
56 | assert isinstance(lmbda, torch.FloatTensor), type(lmbda)
57 | assert lmbda.shape == on_reset[:-1].shape, lmbda.shape
58 | if high_precision:
59 | lmbda = lmbda.to(torch.float64)
60 |
61 | episode_length = int(reward.shape[0])
62 | delta = reward + gamma * value[1:] * (1 - on_reset[1:]) - value[:-1]
63 | if vtrace:
64 | delta *= imp_ratio.clip(max=rho)
65 |
66 | ################## ASSERTIONS START ##################
67 | ###### disable assertions with `python -O xxx` #######
68 | assert (truncated * done == 0).all()
69 | assert ((truncated + done)[:-1] == on_reset[1:]).all()
70 | # the reward should not be amended at the final step
71 | assert (reward * on_reset[1:] == 0).all()
72 | # when an episode is done (not truncated), reward, value, and bootstrapped value should be zero
73 | # hence delta is also zero
74 | assert (delta * on_reset[1:] * (1 - truncated[:-1]) == 0).all()
75 | # when an episode is truncated, reward and bootstrapped value are zero
76 | assert (delta * truncated[:-1] == -value[:-1] * truncated[:-1]).all()
77 | ################### ASSERTIONS END ###################
78 |
79 | gae = torch.zeros_like(reward[0])
80 | adv = torch.zeros_like(reward)
81 |
82 | # 1. If the next step is a new episode, GAE doesn't propagate back
83 | # 2. If the next step is a truncated final step, the backpropagated GAE is -V(t),
84 | # which is not correct. We ignore it such that the current GAE is r(t-1)+ɣV(t)-V(t-1)
85 | # 3. If the next step is a done final step, the backpropagated GAE is zero.
86 | m = gamma * lmbda * (1 - on_reset[1:]) * (1 - truncated[1:])
87 | if vtrace:
88 | m *= imp_ratio.clip(max=c)
89 |
90 | step = episode_length - 1
91 | while step >= 0:
92 | gae = delta[step] + m[step] * gae
93 | adv[step] = gae
94 | step -= 1
95 |
96 | return adv.float()
97 |
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/modules/popart.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .utils import RunningMeanStd
7 |
8 |
9 | class PopArtValueHead(nn.Module):
10 |
11 | def __init__(
12 | self,
13 | input_dim,
14 | critic_dim,
15 | beta=0.99999,
16 | epsilon=1e-5,
17 | burn_in_updates=torch.inf,
18 | high_precision=True,
19 | ):
20 | super().__init__()
21 | self.__rms = RunningMeanStd((critic_dim, ),
22 | beta=beta,
23 | epsilon=epsilon,
24 | high_precision=high_precision)
25 |
26 | self.__weight = nn.Parameter(torch.zeros(critic_dim, input_dim))
27 | self.__bias = nn.Parameter(torch.zeros(critic_dim))
28 | # The same initialization as `nn.Linear`.
29 | torch.nn.init.kaiming_uniform_(self.__weight, a=math.sqrt(5))
30 | torch.nn.init.uniform_(self.__bias, -1 / math.sqrt(input_dim),
31 | 1 / math.sqrt(input_dim))
32 |
33 | self.__burn_in_updates = burn_in_updates
34 | self.__update_cnt = 0
35 |
36 | @property
37 | def weight(self):
38 | return self.__weight
39 |
40 | @property
41 | def bias(self):
42 | return self.__bias
43 |
44 | def forward(self, feature):
45 | return F.linear(feature, self.__weight, self.__bias)
46 |
47 | @torch.no_grad()
48 | def update(self, x, mask):
49 | old_mean, old_std = self.__rms.mean_std()
50 | self.__rms.update(x, mask)
51 | new_mean, new_std = self.__rms.mean_std()
52 | self.__update_cnt += 1
53 |
54 | if self.__update_cnt > self.__burn_in_updates:
55 | self.__weight.data[:] = self.__weight * (old_std /
56 | new_std).unsqueeze(-1)
57 | self.__bias.data[:] = (old_std * self.__bias + old_mean -
58 | new_mean) / new_std
59 |
60 | @torch.no_grad()
61 | def normalize(self, x):
62 | return self.__rms.normalize(x)
63 |
64 | @torch.no_grad()
65 | def denormalize(self, x):
66 | return self.__rms.denormalize(x)
67 |
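The rescaling in `PopArtValueHead.update` is meant to keep the unnormalized value prediction unchanged when the running statistics move. The NumPy check below assumes `denormalize(y) = y * std + mean`, the usual PopArt convention; since `RunningMeanStd` is not shown here, treat that as an assumption. The weights, statistics and input are random or arbitrary placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(2, 4))  # (critic_dim, input_dim), as in PopArtValueHead
b = rng.normal(size=2)
x = rng.normal(size=4)
old_mean, old_std = np.array([1.0, -2.0]), np.array([3.0, 0.5])
new_mean, new_std = np.array([1.5, -1.0]), np.array([4.0, 0.8])

def denormalize(y, mean, std):
    # Assumed RunningMeanStd.denormalize convention.
    return y * std + mean

before = denormalize(W @ x + b, old_mean, old_std)
# Same rescaling as PopArtValueHead.update.
W_new = W * (old_std / new_std)[:, None]
b_new = (old_std * b + old_mean - new_mean) / new_std
after = denormalize(W_new @ x + b_new, new_mean, new_std)
```

Algebraically, `new_std * (W_new @ x + b_new) + new_mean` expands to `old_std * (W @ x + b) + old_mean`, so `before` and `after` agree up to floating-point error.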
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/modules/recurrent_backbone.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 |
4 | from .autoreset_rnn import AutoResetRNN
5 | from .utils import mlp
6 |
7 |
8 | class RecurrentBackbone(nn.Module):
9 |
10 | @property
11 | def feature_dim(self):
12 | return self.__feature_dim
13 |
14 | def __init__(
15 | self,
16 | obs_dim: int,
17 | dense_layers: int,
18 | hidden_dim: int,
19 | rnn_type: str,
20 | num_rnn_layers: int,
21 | dense_layer_gain: float = math.sqrt(2),
22 | activation='relu',
23 | layernorm=True,
24 | ):
25 | super(RecurrentBackbone, self).__init__()
26 |
27 | if activation == 'relu':
28 | act_fn = nn.ReLU
29 | elif activation == 'tanh':
30 | act_fn = nn.Tanh
31 | elif activation == 'elu':
32 | act_fn = nn.ELU
33 | elif activation == 'gelu':
34 | act_fn = nn.GELU
35 | else:
36 | raise NotImplementedError(
37 | f"Activation function {activation} not implemented.")
38 |
39 | self.__feature_dim = hidden_dim
40 | self.__rnn_type = rnn_type
41 | self.fc = mlp([obs_dim, *([hidden_dim] * dense_layers)],
42 | act_fn,
43 | layernorm=layernorm)
44 | for k, p in self.fc.named_parameters():
45 | if 'weight' in k and len(p.data.shape) >= 2:
46 | # filter out layer norm weights
47 | nn.init.orthogonal_(p.data, gain=dense_layer_gain)
48 | if 'bias' in k:
49 | nn.init.zeros_(p.data)
50 |
51 | self.num_rnn_layers = num_rnn_layers
52 | if self.num_rnn_layers:
53 | self.rnn = AutoResetRNN(hidden_dim,
54 | hidden_dim,
55 | num_layers=num_rnn_layers,
56 | rnn_type=self.__rnn_type)
57 | self.rnn_norm = nn.LayerNorm([hidden_dim])
58 | for k, p in self.rnn.named_parameters():
59 | if 'weight' in k and len(p.data.shape) >= 2:
60 | # only 2-D weight matrices get orthogonal initialization
61 | nn.init.orthogonal_(p.data)
62 | if 'bias' in k:
63 | nn.init.zeros_(p.data)
64 |
65 | def forward(self, obs, hx, on_reset=None):
66 | features = self.fc(obs)
67 | if self.num_rnn_layers > 0:
68 | features, hx = self.rnn(features, hx, on_reset)
69 | features = self.rnn_norm(features)
70 | return features, hx
71 |
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | import rlsrl.legacy.algorithm.ppo.actor_critic_policies
2 | import rlsrl.legacy.algorithm.ppo.mappo
3 | import rlsrl.legacy.algorithm.ppo.phasic_policy_gradient
4 |
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/ppo/actor_critic_policies/__init__.py:
--------------------------------------------------------------------------------
1 | import rlsrl.legacy.algorithm.ppo.actor_critic_policies.actor_critic_policy
2 |
--------------------------------------------------------------------------------
/src/rlsrl/legacy/algorithm/ppo/actor_critic_policies/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union, Tuple
2 | import dataclasses
3 | import torch.nn as nn
4 |
5 | from rlsrl.legacy.algorithm import modules
6 |
7 |
8 | @dataclasses.dataclass
9 | class ActionIndex:
10 | start: int
11 | end: int
12 |
13 |
14 | def get_action_indices(act_dims):
15 | """Convert action dimensions to indices. Results can be used to parse model forward results.
16 | Args:
17 | act_dims: dimensions of all value heads.
18 | Returns:
19 | action_indices (List[ActionIndex]): index of each action.
20 | Example:
21 | >>> indices = get_action_indices([3, 5, 7])
22 | >>> print(indices)
23 | [ActionIndex(start=0, end=3), ActionIndex(start=3, end=8), ActionIndex(start=8, end=15)]
24 | """
25 | curr_action_index = 0
26 | action_indices = []
27 | for dim in act_dims:
28 | action_indices.append(ActionIndex(curr_action_index, curr_action_index + dim))
29 | curr_action_index += dim
30 | return action_indices
31 |
32 |
33 | def make_models_for_obs(obs_dim: Dict[str, Union[int, Tuple]], hidden_dim: int, activation,
34 | cnn_layers: Dict[str, Tuple], use_maxpool: Dict[str, bool]):
35 | """Make models based on a dict of observation dimension.
36 | Args:
37 | obs_dim: Key-Value pair of observation_name and dimension of observation {"obs_name": obs_dim},
38 | hidden_dim: Embedding dimension of all observations.
39 | activation: nn.ReLU or nn.Tanh
40 | cnn_layers: Key-Value pair of observation_name and user-specified cnn-layers.
41 | use_maxpool: Key-Value pair of observation_name and whether to use max-pooling; effective only for convolutional encoders.
42 | """
43 | obs_embd_dict = nn.ModuleDict()
44 | for k, v in obs_dim.items():
45 | if isinstance(v, int):
46 | obs_embd_dict.update({
47 | k: nn.Sequential(nn.LayerNorm([v]),
48 | modules.mlp([v, hidden_dim], activation=activation, layernorm=True))
49 | })
50 | elif len(v) in [2, 3, 4]:
51 | obs_embd_dict.update({
52 | k: nn.Sequential(
53 | nn.LayerNorm(v),
54 | modules.Convolution(v,
55 | cnn_layers=cnn_layers.get(k, None),
56 | use_maxpool=use_maxpool.get(k, False),
57 | activation=activation,
58 | hidden_size=hidden_dim,
59 | use_orthogonal=True))
60 | })
61 | else:
62 | raise NotImplementedError(f"Unsupported observation shape {v} for '{k}'.")
63 | return obs_embd_dict
64 |
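`get_action_indices` is meant to slice a concatenated head output back into per-action chunks. The sketch below re-declares `ActionIndex` and `get_action_indices` so it runs on its own, then applies them to a hypothetical `(1, 15)` logits array. The `arange` values stand in for real model outputs.

```python
import dataclasses
import numpy as np

@dataclasses.dataclass
class ActionIndex:
    start: int
    end: int

def get_action_indices(act_dims):
    # Same logic as the repository helper: consecutive [start, end) ranges.
    indices, cur = [], 0
    for dim in act_dims:
        indices.append(ActionIndex(cur, cur + dim))
        cur += dim
    return indices

# Hypothetical model output: concatenated logits for heads of size 3, 5 and 7.
logits = np.arange(15, dtype=np.float32)[None, :]  # shape (1, 15)
heads = [logits[..., idx.start:idx.end] for idx in get_action_indices([3, 5, 7])]
print([h.shape for h in heads])  # [(1, 3), (1, 5), (1, 7)]
```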
--------------------------------------------------------------------------------
/src/rlsrl/legacy/environment/__init__.py:
--------------------------------------------------------------------------------
1 | """Legacy environments are registered safely.
2 | """
3 | from rlsrl.api.environment import register
4 |
5 | register("atari", "AtariEnvironment", "rlsrl.legacy.environment.atari.atari_env")
6 |
--------------------------------------------------------------------------------
/src/rlsrl/legacy/environment/atari/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openpsi-projects/srl/97b3ae4e5fab00f6da90f81679dc035a9079eb09/src/rlsrl/legacy/environment/atari/__init__.py
--------------------------------------------------------------------------------
/src/rlsrl/legacy/environment/atari/atari_env.py:
--------------------------------------------------------------------------------
1 | """Simple wrapper around the Atari environments provided by gym.
2 | """
3 | from typing import List, Union, Optional
4 | import collections
5 | import gym
6 | import logging
7 | import numpy as np
8 | import os
9 | import time
10 |
11 | from .atari_wrappers import make_atari
12 | import rlsrl.api.environment as env_base
13 | import rlsrl.api.env_utils as env_utils
14 |
15 | logger = logging.getLogger("env-atari")
16 |
17 | _HAS_DISPLAY = len(os.environ.get("DISPLAY", "").strip()) > 0
18 |
19 |
20 | class AtariEnvironment(env_base.Environment):
21 |
22 | def __init__(
23 | self,
24 | game_name,
25 | render: bool = False,
26 | pause: bool = False,
27 | noop_max: int = 0,
28 | episode_life: bool = False,
29 | clip_reward: bool = False,
30 | frame_skip: int = 1,
31 | stacked_observations: Union[int, None] = None,
32 | max_episode_steps: int = 108000,
33 | gray_scale: bool = False,
34 | obs_shape: Union[List[int], None] = None,
35 | scale: bool = False,
36 | seed: Optional[int] = None,
37 | obs_include_last_action: bool = False,
38 | stacked_last_actions: int = 1,
39 | obs_include_last_reward: bool = False,
40 | full_action_space: bool = False,
41 | epsilon: Optional[float] = None,
42 | ):
43 | """Atari environment
44 |
45 | DQN training configuration:
46 | - noop_max: 30
47 | - episode_life: True
48 | - clip_reward: True
49 | - frame_skip: 4
50 | - stacked_observations: 4
51 | - max_episode_steps: 108000
52 | - gray_scale: True
53 | - obs_shape: (84, 84)
54 |
55 | R2D2 training configuration:
56 | - noop_max: 30
57 | - episode_life: False
58 | - clip_reward: False
59 | - frame_skip: 4
60 | - stacked_observations: 4
61 | - max_episode_steps: 108000
62 | - gray_scale: True
63 | - obs_shape: (84, 84)
64 | - obs_include_last_action: True
65 | - obs_include_last_reward: True
66 |
67 | Parameters
68 | ----------
69 | noop_max: int
70 | upon reset, do no-op action for a number of steps in [1, noop_max]
71 | episode_life: bool
72 | terminal upon loss of life
73 | clip_reward: bool
74 | reward -> sign(reward)
75 | frame_skip: int
76 | repeat the action for `frame_skip` steps and return max of last two frames
77 | max_episode_steps: int
78 | maximum episode length
79 | gray_scale: bool
80 | use gray image observation
81 | obs_shape: list
82 | resize observation to `obs_shape`
83 | scale: bool
84 | scale frames to [0, 1]
85 | obs_include_last_action:
86 | include one-hot action in observation dict
87 | stacked_last_actions:
88 | stack k latest one-hot actions
89 | obs_include_last_reward:
90 | include last reward in observation dict
91 | """
92 | self.game_name = game_name
93 | self.__render = render
94 | self.__pause = pause
95 | self.__env = make_atari(
96 | env_id=game_name,
97 | seed=seed,
98 | noop_max=noop_max,
99 | frame_skip=frame_skip,
100 | max_episode_steps=max_episode_steps,
101 | episode_life=episode_life,
102 | obs_shape=obs_shape,
103 | gray_scale=gray_scale,
104 | clip_reward=clip_reward,
105 | stacked_observations=stacked_observations,
106 | scale=scale,
107 | full_action_space=full_action_space,
108 | )
109 | self.__frame_skip = frame_skip
110 |
111 | self.__obs_include_last_action = obs_include_last_action
112 | self.__stacked_last_actions = stacked_last_actions
113 | self.__last_actions_queue = collections.deque(maxlen=stacked_last_actions)
114 | self.__obs_include_last_reward = obs_include_last_reward
115 |
116 | self.__step_count = np.zeros(1, dtype=np.int32)
117 | self.__episode_return = np.zeros(1, dtype=np.float32)
118 |
119 | self.__epsilon = epsilon
120 |
121 | @property
122 | def agent_count(self) -> int:
123 | return 1 # We are a simple Atari environment here.
124 |
125 | @property
126 | def observation_spaces(self):
127 | base_space = {"obs": self.__env.observation_space.shape}
128 | if self.__obs_include_last_action:
129 | base_space['action'] = (self.__stacked_last_actions * self.__env.action_space.n,)
130 | if self.__obs_include_last_reward:
131 | base_space['reward'] = (1,)
132 | return [base_space]
133 |
134 | @property
135 | def action_spaces(self):
136 | return [env_utils.DiscreteActionSpace(self.__env.action_space)]
137 |
138 | def _get_obs(self, frame, action, reward):
139 | last_action = np.zeros((self.__env.action_space.n,), dtype=np.uint8)  # fresh array, so deque entries never alias
140 | if action is not None:
141 | last_action[action] = 1
142 | self.__last_actions_queue.append(last_action)
143 |
144 | obs = dict(obs=frame)
145 | if self.__epsilon is not None:
146 | obs['epsilon'] = np.array([self.__epsilon], dtype=np.float32)
147 | if self.__obs_include_last_action:
148 | obs['action'] = np.concatenate(list(self.__last_actions_queue), -1)
149 | if self.__obs_include_last_reward:
150 | obs['reward'] = np.array([reward], dtype=np.float32)
151 | return obs
152 |
153 | def reset(self) -> List[env_base.StepResult]:
154 | self.__step_count[:] = 0
155 | self.__episode_return[:] = 0
156 | self.__last_action = np.zeros((self.__env.action_space.n,), dtype=np.uint8)
157 | for _ in range(self.__stacked_last_actions):
158 | self.__last_actions_queue.append(self.__last_action)
159 |
160 | frame = self.__env.reset()
161 |
162 | return [
163 | env_base.StepResult(obs=self._get_obs(frame, None, 0),
164 | reward=np.array([0.0], dtype=np.float32),
165 | done=np.array([False], dtype=np.uint8),
166 | info=dict(episode_length=self.__step_count.copy(),
167 | episode_return=self.__episode_return.copy()))
168 | ]
169 |
170 | def step(self, actions: List[env_utils.DiscreteAction]) -> List[env_base.StepResult]:
171 |
172 | assert len(actions) == 1, len(actions)
173 | action = int(actions[0].x)
174 |
175 | frame, reward, done, info = self.__env.step(action)
176 |
177 | self.__step_count += self.__frame_skip
178 | self.__episode_return += reward
179 |
180 | if self.__render:
181 | logger.info("Step %d: reward=%.2f, done=%d", self.__step_count[0], reward, done)
182 | if _HAS_DISPLAY:
183 | self.render()
184 | if self.__pause:
185 | input()
186 | else:
187 | time.sleep(0.05)
188 |
189 | return [
190 | env_base.StepResult(
191 | obs=self._get_obs(frame, action, reward),
192 | reward=np.array([reward], dtype=np.float32),
193 | done=np.array([done], dtype=np.uint8),
194 | info=dict(episode_length=self.__step_count.copy(),
195 | episode_return=self.__episode_return.copy()),
196 | )
197 | ]
198 |
199 | def render(self) -> None:
200 | self.__env.render()
201 |
202 | def seed(self, seed=None):
203 | self.__env.seed(seed)
204 | return seed
205 |
206 |
207 | if __name__ == '__main__':
208 | import psutil
209 | import multiprocessing
210 | import time
211 |
212 | # Workload run in a child process; its memory usage is
213 | # monitored from the parent process below.
214 | def app():
215 | config = dict(
216 | full_action_space=False,
217 | noop_max=30,
218 | frame_skip=4,
219 | stacked_observations=4,
220 | gray_scale=True,
221 | obs_shape=(84, 84),
222 | scale=False,
223 | obs_include_last_action=True,
224 | obs_include_last_reward=True,
225 | epsilon=0.1,
226 | )
227 | env = AtariEnvironment('PongNoFrameskip-v4', **config)
228 | step_cnt = 0
229 | import time
230 | srs = env.reset()
231 | tik = time.perf_counter()
232 | for _ in range(100):
233 | done = False
234 | srs = env.reset()
235 | while not done:
236 | srs = env.step([sp.sample() for sp, sr in zip(env.action_spaces, srs)])
237 | done = all(sr.done[0] for sr in srs if sr is not None)
238 | step_cnt += 1
239 | if step_cnt % 100 == 0:
240 | print(f"FPS: {step_cnt * 4 / (time.perf_counter() - tik)}")
241 |
242 | p = multiprocessing.Process(target=app)
243 | p.start()
244 |
245 | # Handle to the current (parent) Python process
246 | main_process = psutil.Process()
247 |
248 | for _ in range(100):
249 | time.sleep(5)
250 | # Get a list of child processes for the main process
251 | child_processes = main_process.children(recursive=True)
252 |
253 | # Total resident memory (MiB) across all child processes
254 | mem = sum([process.memory_info().rss / 1024**2 for process in child_processes])
255 | print(f"Memory: {mem}")
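The stacked last-action feature in `AtariEnvironment` keeps the `stacked_last_actions` most recent one-hot actions in a bounded deque and concatenates them oldest first. This is why `observation_spaces` reports `(stacked_last_actions * n,)` for `'action'`. The standalone sketch below reproduces that bookkeeping with `n_actions=4` and `k=3` (illustrative values; `push_action` is a hypothetical helper). It also shows why each step needs a fresh array: if one shared array is appended and then mutated in place, every deque entry aliases the latest action.

```python
import collections
import numpy as np

n_actions, k = 4, 3  # illustrative sizes
queue = collections.deque(maxlen=k)

def push_action(action):
    # Fresh one-hot array per step, so earlier deque entries stay intact.
    one_hot = np.zeros((n_actions,), dtype=np.uint8)
    if action is not None:
        one_hot[action] = 1
    queue.append(one_hot)
    return np.concatenate(list(queue), -1)  # what obs['action'] would hold

for _ in range(k):  # reset: k all-zero entries
    queue.append(np.zeros((n_actions,), dtype=np.uint8))
push_action(None)  # first observation after reset
push_action(2)
obs_action = push_action(0)  # oldest to newest: zeros, action 2, action 0

# Pitfall: one shared array mutated in place makes every entry identical.
shared = np.zeros((n_actions,), dtype=np.uint8)
aliased = collections.deque([shared] * k, maxlen=k)
for a in (2, 0):
    shared[:] = 0
    shared[a] = 1
    aliased.append(shared)
```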
--------------------------------------------------------------------------------
/src/rlsrl/legacy/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | import rlsrl.legacy.experiments.atari_benchmark
2 | import rlsrl.legacy.experiments.atari_remote
--------------------------------------------------------------------------------
/src/rlsrl/legacy/experiments/atari_benchmark.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import itertools
3 | import functools
4 |
5 | from rlsrl.api.config import *
6 |
7 |
8 | @dataclasses.dataclass
9 | class AtariFPSBenchmarkExperiment(Experiment):
10 | aws: int = 32
11 | pws: int = 4
12 | tws: int = 1
13 |
14 | inference_splits: int = 4
15 | ring_size: int = 40
16 |
17 | shared_memory: bool = True
18 |
19 | seed: int = 1
20 |
21 | def make_envs(self, aw_rank):
22 | return [
23 | Environment(type_="atari",
24 | args=dict(game_name="PongNoFrameskip-v4",
25 | seed=self.seed + 12345 * x +
26 | aw_rank * self.ring_size,
27 | render=False,
28 | pause=False,
29 | noop_max=30,
30 | frame_skip=4,
31 | stacked_observations=4,
32 | max_episode_steps=108000,
33 | gray_scale=True,
34 | obs_shape=(84, 84)))
35 | for x in range(self.ring_size)
36 | ]
37 |
38 | def initial_setup(self):
39 | sample_stream_qsize = 1024
40 | buffer_zero_copy = self.shared_memory
41 |
42 | self.policy_name = "default"
43 | self.policy = Policy(
44 | type_="actor-critic",
45 | args=dict(
46 | obs_dim={"obs": (4, 84, 84)},
47 | action_dim=6,
48 | num_dense_layers=0,
49 | hidden_dim=512,
50 | popart=True,
51 | layernorm=False,
52 | shared_backbone=True,
53 | rnn_type='lstm',
54 | num_rnn_layers=0,
55 | seed=self.seed,
56 | cnn_layers=dict(obs=[(16, 8, 4, 0,
57 | 'zeros'), (32, 4, 2, 0, 'zeros')]),
58 | chunk_len=10,
59 | ))
60 | self.trainer = Trainer(type_="mappo",
61 | args=dict(
62 | discount_rate=0.99,
63 | gae_lambda=0.97,
64 | eps_clip=0.2,
65 | clip_value=True,
66 | dual_clip=False,
67 | vtrace=False,
68 | value_loss='huber',
69 | value_loss_weight=1.0,
70 | value_loss_config=dict(delta=10.0, ),
71 | entropy_bonus_weight=0.01,
72 | optimizer='adam',
73 | optimizer_config=dict(lr=5e-4),
74 | popart=True,
75 | max_grad_norm=40.0,
76 | bootstrap_steps=1,
77 | recompute_adv_among_epochs=False,
78 | recompute_adv_on_reuse=False,
79 | burn_in_steps=0,
80 | ))
81 | self.agent_specs = [
82 | AgentSpec(
83 | index_regex=".*",
84 | inference_stream_idx=0,
85 | sample_stream_idx=0,
86 | send_full_trajectory=False,
87 | send_after_done=False,
88 | sample_steps=50,
89 | bootstrap_steps=1,
90 | )
91 | ]
92 |
93 | if self.shared_memory:
94 | sample_stream = SampleStream(
95 | type_=SampleStream.Type.SHARED_MEMORY,
96 | stream_name=self.policy_name,
97 | plugin=SharedMemorySampleStreamPlugin(
98 | qsize=sample_stream_qsize),
99 | )
100 | inference_stream = InferenceStream(
101 | type_=InferenceStream.Type.SHARED_MEMORY,
102 | stream_name=self.policy_name)
103 | else:
104 | sample_stream = SampleStream(type_=SampleStream.Type.NAME,
105 | stream_name=self.policy_name)
106 | inference_stream = InferenceStream(type_=InferenceStream.Type.NAME,
107 | stream_name=self.policy_name)
108 |
109 | actors = [
110 | ActorWorker(env=self.make_envs(i),
111 | inference_streams=[
112 | inference_stream,
113 | ],
114 | sample_streams=[sample_stream],
115 | agent_specs=self.agent_specs,
116 | max_num_steps=20000,
117 | inference_splits=self.inference_splits,
118 | ring_size=self.ring_size) for i in range(self.aws)
119 | ]
120 | policies = [
121 | PolicyWorker(
122 | policy_name=self.policy_name,
123 | inference_stream=inference_stream,
124 | policy=self.policy,
125 | worker_info=WorkerInformation(device="cuda:0"),
126 | ) for i in range(self.pws)
127 | ]
128 |
129 | return ExperimentConfig(
130 | actors=actors,
131 | policies=policies,
132 | trainers=[
133 | TrainerWorker(
134 | buffer_name='priority_queue',
135 | buffer_zero_copy=buffer_zero_copy,
136 | buffer_args=dict(
137 | max_size=sample_stream_qsize,
138 | reuses=1,
139 | batch_size=32,
140 | ),
141 | policy_name=self.policy_name,
142 | trainer=self.trainer,
143 | policy=self.policy,
144 | log_frequency_seconds=5,
145 | sample_stream=sample_stream,
146 | worker_info=WorkerInformation(
147 | wandb_job_type='tw',
148 | wandb_group='mini',
149 | wandb_project='srl-atari',
150 | wandb_name=f'seed{self.seed}',
151 | log_terminal=True,
152 | device="cuda:0",
153 | ),
154 | ) for _ in range(self.tws)
155 | ],
156 | master_worker=[MasterWorker(
157 | address="localhost",
158 | port=51234,
159 | )])
160 |
161 |
162 | register_experiment("atari-mini", AtariFPSBenchmarkExperiment)
163 | register_experiment("atari-mini-remote", functools.partial(AtariFPSBenchmarkExperiment, shared_memory=False))
164 |
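In `make_envs`, environment `x` on actor worker `aw_rank` is seeded with `seed + 12345 * x + aw_rank * ring_size`. The quick check below (plain Python; `env_seeds` is a hypothetical helper) confirms that the default sizes (`aws=32`, `ring_size=40`) give 1280 environments with no repeated seed. It does not prove uniqueness for arbitrary `aws` and `ring_size`.

```python
def env_seeds(seed, aws, ring_size):
    # Hypothetical helper mirroring make_envs across all actor workers.
    return [seed + 12345 * x + aw_rank * ring_size
            for aw_rank in range(aws) for x in range(ring_size)]

seeds = env_seeds(seed=1, aws=32, ring_size=40)
print(len(seeds), len(set(seeds)))  # 1280 1280
```

The `atari-mini-remote` entry reuses the same class via `functools.partial` with `shared_memory=False`, so it switches to `NAME`-type sample and inference streams and disables `buffer_zero_copy`.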
--------------------------------------------------------------------------------
/src/rlsrl/legacy/experiments/atari_remote.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import itertools
3 | import functools
4 |
5 | from rlsrl.api.config import *
6 |
7 |
8 | @dataclasses.dataclass
9 | class AtariMultiNodeExperiment(Experiment):
10 | aws: int = 32
11 | pws: int = 4
12 | tws: int = 1
13 |
14 | inference_splits: int = 4
15 | ring_size: int = 40
16 |
17 | seed: int = 1
18 |
19 | def make_envs(self, aw_rank):
20 | return [
21 | Environment(type_="atari",
22 | args=dict(game_name="PongNoFrameskip-v4",
23 | seed=self.seed + 12345 * x +
24 | aw_rank * self.ring_size,
25 | render=False,
26 | pause=False,
27 | noop_max=30,
28 | frame_skip=4,
29 | stacked_observations=4,
30 | max_episode_steps=108000,
31 | gray_scale=True,
32 | obs_shape=(84, 84)))
33 | for x in range(self.ring_size)
34 | ]
35 |
36 | def initial_setup(self):
37 | sample_stream_qsize = 1024
38 |
39 | self.policy_name = "default"
40 | self.policy = Policy(
41 | type_="actor-critic",
42 | args=dict(
43 | obs_dim={"obs": (4, 84, 84)},
44 | action_dim=6,
45 | num_dense_layers=0,
46 | hidden_dim=512,
47 | popart=True,
48 | layernorm=False,
49 | shared_backbone=True,
50 | rnn_type='lstm',
51 | num_rnn_layers=0,
52 | seed=self.seed,
53 | cnn_layers=dict(obs=[(16, 8, 4, 0, 'zeros'),
54 | (32, 4, 2, 0, 'zeros')]),
55 | chunk_len=10,
56 | ))
57 | self.trainer = Trainer(type_="mappo",
58 | args=dict(
59 | discount_rate=0.99,
60 | gae_lambda=0.97,
61 | eps_clip=0.2,
62 | clip_value=True,
63 | dual_clip=False,
64 | vtrace=False,
65 | value_loss='huber',
66 | value_loss_weight=1.0,
67 | value_loss_config=dict(delta=10.0, ),
68 | entropy_bonus_weight=0.01,
69 | optimizer='adam',
70 | optimizer_config=dict(lr=5e-4),
71 | popart=True,
72 | max_grad_norm=40.0,
73 | bootstrap_steps=1,
74 | recompute_adv_among_epochs=False,
75 | recompute_adv_on_reuse=False,
76 | burn_in_steps=0,
77 | ))
78 | self.agent_specs = [
79 | AgentSpec(
80 | index_regex=".*",
81 | inference_stream_idx=0,
82 | sample_stream_idx=0,
83 | send_full_trajectory=False,
84 | send_after_done=False,
85 | sample_steps=50,
86 | bootstrap_steps=1,
87 | )
88 | ]
89 |
90 | spl_stream = SampleStream(
91 | type_=SampleStream.Type.NAME,
92 | stream_name=self.policy_name,
93 | )
94 | inf_stream = InferenceStream(type_=InferenceStream.Type.NAME,
95 | stream_name=self.policy_name)
96 |
97 | actors = [
98 | ActorWorker(env=self.make_envs(i),
99 | inference_streams=[inf_stream],
100 | sample_streams=[spl_stream],
101 | agent_specs=self.agent_specs,
102 | max_num_steps=108000,
103 | inference_splits=self.inference_splits,
104 | ring_size=self.ring_size,
105 | worker_info=WorkerInformation(node_rank=1))
106 | for i in range(self.aws)
107 | ]
108 |
109 | policies = [
110 | PolicyWorker(
111 | policy_name=self.policy_name,
112 | inference_stream=inf_stream,
113 | policy=self.policy,
114 | pull_frequency_seconds=None,
115 | parameter_service_client=ParameterServiceClient(),
116 | worker_info=WorkerInformation(device="cuda:0", node_rank=0),
117 | ) for i in range(self.pws)
118 | ]
119 |
120 | return ExperimentConfig(
121 | actors=actors,
122 | policies=policies,
123 | trainers=[
124 | TrainerWorker(
125 | buffer_name='priority_queue',
126 | buffer_zero_copy=False,
127 | buffer_args=dict(
128 | max_size=sample_stream_qsize,
129 | reuses=1,
130 | batch_size=32,
131 | ),
132 | policy_name=self.policy_name,
133 | trainer=self.trainer,
134 | policy=self.policy,
135 | log_frequency_seconds=5,
136 | sample_stream=spl_stream,
137 | worker_info=WorkerInformation(
138 | wandb_job_type='tw',
139 | wandb_group='mini',
140 | wandb_project='srl-atari',
141 | wandb_name=f'seed{self.seed}',
142 | log_terminal=True,
143 | device="cuda:0",
144 | node_rank=0,
145 | ),
146 | ) for _ in range(self.tws)
147 | ],
148 | master_worker=[MasterWorker(
149 | address="10.210.12.42",
150 | port=51234,
151 | )])
152 |
153 |
154 | register_experiment("atari-mini-multinode",
155 | functools.partial(AtariMultiNodeExperiment))
156 |
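`make_envs` seeds environment `x` of actor `aw_rank` with `seed + 12345 * x + aw_rank * ring_size`. The standalone check below reproduces that formula in plain Python, without importing rlsrl. It confirms that with the defaults `aws=32` and `ring_size=40`, all 1280 environments get distinct seeds. It does not cover other configurations; those would need to be rechecked for collisions.

```python
import itertools


def env_seeds(seed: int, aws: int, ring_size: int):
    """Reproduce the make_envs seed formula for every (actor rank, env index) pair."""
    return [
        seed + 12345 * x + aw_rank * ring_size
        for aw_rank, x in itertools.product(range(aws), range(ring_size))
    ]


seeds = env_seeds(seed=1, aws=32, ring_size=40)
print(len(seeds), len(set(seeds)))  # 1280 1280
```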
--------------------------------------------------------------------------------
/src/rlsrl/system/__init__.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import dataclasses
3 | import importlib
4 | import logging
5 | import os
6 | import traceback
7 |
8 |
9 | @dataclasses.dataclass
10 | class WorkerSpec:
11 | """Description of a worker implementation.
12 | """
13 | short_name: str # short name is used in file names.
14 | config_field_name: str # Used in experiment/scheduling configuration (api.config).
15 | class_name: str # The class name of the implementation.
16 | module: str # The module path to find the worker class.
17 |
18 | def load_worker(self):
19 | module = importlib.import_module(self.module)
20 | return getattr(module, self.class_name)
21 |
22 |
23 | actor_worker = WorkerSpec(short_name='aw',
24 | class_name="ActorWorker",
25 | config_field_name="actors",
26 | module="rlsrl.system.impl.actor_worker")
27 | # buffer_worker = WorkerSpec(short_name='bw',
28 | # class_name="BufferWorker",
29 | # config_field_name="buffers",
30 | # module="rlsrl.system.impl.buffer_worker")
31 | eval_manager = WorkerSpec(short_name='em',
32 | class_name="EvalManager",
33 | config_field_name="eval_managers",
34 | module="rlsrl.system.impl.eval_manager")
35 | policy_worker = WorkerSpec(short_name='pw',
36 | class_name="PolicyWorker",
37 | config_field_name="policies",
38 | module="rlsrl.system.impl.policy_worker")
39 | trainer_worker = WorkerSpec(short_name='tw',
40 | class_name="TrainerWorker",
41 | config_field_name="trainers",
42 | module="rlsrl.system.impl.trainer_worker")
43 | # population_manager = WorkerSpec(short_name="pm",
44 | # class_name="PopulationManager",
45 | # config_field_name="population_manager",
46 | # module="rlsrl.system.impl.population_manager")
47 |
48 | RL_WORKERS = collections.OrderedDict()
49 | RL_WORKERS["trainer"] = trainer_worker
50 | # RL_WORKERS["buffer"] = buffer_worker
51 | RL_WORKERS["policy"] = policy_worker
52 | RL_WORKERS["eval_manager"] = eval_manager
53 | # RL_WORKERS["population_manager"] = population_manager
54 | RL_WORKERS["actor"] = actor_worker
55 |
56 |
57 | def run_worker(worker_type,
58 | experiment_name,
59 | trial_name,
60 | worker_name,
61 | ctrl,
62 | env_vars=None):
63 | """Run one worker.
64 | Args:
65 | worker_type: string, "master" or one of the keys of RL_WORKERS,
66 | experiment_name: string, the experiment this worker belongs to,
67 | trial_name: string, the specific trial this worker belongs to,
68 | worker_name: name given to the worker, typically "/"
69 | """
70 | import torch
71 | torch.cuda.init()
72 | if env_vars is not None:
73 | for k, v in env_vars.items():
74 | os.environ[k] = v
75 |
76 | if worker_type == "master":
77 | from rlsrl.system.impl.master_worker import MasterWorker as worker_class
78 | else:
79 | worker_class = RL_WORKERS[worker_type].load_worker()
80 | worker = worker_class(ctrl=ctrl)
81 | try:
82 | worker.run()
83 | except Exception as e:
84 | logging.error("Worker %s failed with exception: %s", worker_name, e)
85 | logging.error(traceback.format_exc())
86 | raise e
87 |
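`WorkerSpec.load_worker` defers importing a worker module until that worker is launched. The sketch below copies the dataclass but points the specs at standard-library classes (`collections.Counter`, `fractions.Fraction`), so it runs without rlsrl. The spec values are illustrative only.

```python
import collections
import dataclasses
import importlib


@dataclasses.dataclass
class WorkerSpec:
    short_name: str
    config_field_name: str
    class_name: str
    module: str

    def load_worker(self):
        # The module is imported only when the worker class is requested, so
        # listing specs in a registry does not pull in heavy dependencies.
        module = importlib.import_module(self.module)
        return getattr(module, self.class_name)


# Illustrative specs pointing at standard-library classes instead of rlsrl workers.
WORKERS = collections.OrderedDict()
WORKERS["counter"] = WorkerSpec("ct", "counters", "Counter", "collections")
WORKERS["fraction"] = WorkerSpec("fr", "fractions", "Fraction", "fractions")

cls = WORKERS["counter"].load_worker()
print(cls("aab"))  # Counter({'a': 2, 'b': 1})
```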
--------------------------------------------------------------------------------
/src/rlsrl/system/api/inference_stream.py:
--------------------------------------------------------------------------------
1 | """This module defines the data flow between policy workers and actor workers.
2 |
3 | In our design, actor workers are in charge of executing env.step() (typically simulation), while
4 | policy workers run policy.rollout_step() (typically neural network inference). The inference
5 | stream is the abstraction of the data flow between them: the actor workers send environment
6 | observations as requests, and the policy workers return actions as responses, each accompanied
7 | by additional information.
8 | """
9 | from typing import List, Optional, Any, Union
10 |
11 | import rlsrl.api.config as config_api
12 | import rlsrl.api.policy as policy_api
13 |
14 |
15 | class InferenceClient:
16 | """Interface used by the actor workers to obtain actions given current observation."""
17 |
18 | def post_request(self, request: policy_api.RolloutRequest,
19 | index: int) -> int:
20 | """Set the client_id and request_id of the request and cache the request.
21 |
22 | Args:
23 | request: RolloutRequest of length 1.
24 | index: Index of the request in the shared memory buffer.
25 | Used only when using shared memory.
26 | """
27 | raise NotImplementedError()
28 |
29 | def poll_responses(self):
30 | """Poll all responses from inference server.
31 | This method is considered thread unsafe and
32 | only called by the main process.
33 | """
34 | raise NotImplementedError()
35 |
36 | def is_ready(self, inference_ids: List[int],
37 | buffer_indices: List[int]) -> bool:
38 | """Check whether all the given requests are ready to be consumed.
39 |
40 | Args:
41 | inference_ids: A list of requests to check.
42 | buffer_indices: The buffer indices of requests to check.
43 |
44 | Outputs:
45 | is_ready: Whether the inference_ids are all ready.
46 | """
47 | raise NotImplementedError()
48 |
49 | def register_agent(self):
50 | return 0
51 |
52 | def consume_result(self, inference_ids: List[int],
53 | buffer_indices: List[int]):
54 | """Consume a result with specific request_id, returns un-pickled message.
55 | Raises KeyError if inference id is not ready. Make sure you call is_ready before consuming.
56 |
57 | Args:
58 | inference_ids: a list of requests to consume.
59 | buffer_indices: the buffer indices of requests to consume.
60 |
61 | Outputs:
62 | results: list of rollout results, in the order of inference_ids.
63 | """
64 | raise NotImplementedError()
65 |
66 | def flush(self):
67 | """Send all cached inference requests to inference server.
68 | Implementations are considered thread-unsafe.
69 | """
70 | raise NotImplementedError()
71 |
72 | def get_constant(self, name: str) -> Any:
73 | """Retrieve the constant value saved by inference server.
74 |
75 | Args:
76 | name: name of the constant to get.
77 |
78 | Returns:
79 | value: the value set by inference server.
80 | """
81 | raise NotImplementedError()
82 |
83 |
84 | class InferenceServer:
85 | """Interface used by the policy workers to serve inference requests."""
86 |
87 | def poll_requests(self) -> List[policy_api.RolloutRequest]:
88 | """Consumes all incoming requests.
89 |
90 | Returns:
91 | RequestPool: A list of requests, already batched by client.
92 | """
93 | raise NotImplementedError()
94 |
95 | def respond(self, response: policy_api.RolloutResult):
96 | """Send rollout results to inference clients.
97 |
98 | Args:
99 | response: rollout result to send.
100 | """
101 | raise NotImplementedError()
102 |
103 | def set_constant(self, name: str, value: Any):
104 | """Set a constant value that inference clients can retrieve via get_constant().
105 |
106 | Args:
107 | name: name of the constant to set.
108 | value: the value to be set; can be any object that can be pickled.
109 | """
110 | raise NotImplementedError()
111 |
112 |
113 | ALL_INFERENCE_CLIENT_CLS = {}
114 | ALL_INFERENCE_SERVER_CLS = {}
115 |
116 |
117 | def register_server(type_: config_api.InferenceStream.Type, cls):
118 | ALL_INFERENCE_SERVER_CLS[type_] = cls
119 |
120 |
121 | def register_client(type_: config_api.InferenceStream.Type, cls):
122 | ALL_INFERENCE_CLIENT_CLS[type_] = cls
123 |
124 |
125 | def make_server(spec: Union[str, config_api.InferenceStream, InferenceServer],
126 | worker_info: Optional[config_api.WorkerInformation] = None,
127 | *args,
128 | **kwargs):
129 | """Initializes an inference stream server.
130 |
131 | Args:
132 | spec: Inference stream specification.
133 | worker_info: The server worker information.
134 | """
135 | if isinstance(spec, InferenceServer):
136 | return spec
137 | if isinstance(spec, str):
138 | spec = config_api.InferenceStream(
139 | type_=config_api.InferenceStream.Type.NAME,
140 | stream_name=spec)
141 | if spec.worker_info is None:
142 | spec.worker_info = worker_info
143 | return ALL_INFERENCE_SERVER_CLS[spec.type_](spec, *args, **kwargs)
144 |
145 |
146 | def make_client(spec: Union[str, config_api.InferenceStream, InferenceClient],
147 | worker_info: Optional[config_api.WorkerInformation] = None,
148 | *args,
149 | **kwargs):
150 | """Initializes an inference stream client.
151 |
152 | Args:
153 | spec: Inference stream specification.
154 | worker_info: The client worker information.
155 | """
156 | if isinstance(spec, InferenceClient):
157 | return spec
158 | if isinstance(spec, str):
159 | spec = config_api.InferenceStream(
160 | type_=config_api.InferenceStream.Type.NAME,
161 | stream_name=spec)
162 | if spec.worker_info is None:
163 | spec.worker_info = worker_info
164 | return ALL_INFERENCE_CLIENT_CLS[spec.type_](spec, *args, **kwargs)
165 |
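The client/server contract above is asynchronous. `post_request` only caches a request, `flush` sends the cached requests, and results are consumed once `is_ready` reports them. The in-process pair below is a hypothetical sketch of that call order, not an rlsrl implementation. It assumes the server can hand responses directly to the client object. The registered implementations (inline, local, remote) move data differently, and their clients collect responses through `poll_responses`.

```python
import itertools
from typing import Any, Dict, List


class ToyInferenceServer:
    """Hypothetical in-process server: receives request batches and returns responses."""

    def __init__(self):
        self.inbox: List[List[dict]] = []
        self.clients: Dict[int, "ToyInferenceClient"] = {}

    def poll_requests(self) -> List[List[dict]]:
        # Take all pending batches, one batch per client flush.
        batches, self.inbox = self.inbox, []
        return batches

    def respond(self, client_id: int, responses: List[dict]):
        self.clients[client_id].receive(responses)


class ToyInferenceClient:
    """Hypothetical client following the post/flush/is_ready/consume protocol."""

    _ids = itertools.count()

    def __init__(self, server: ToyInferenceServer):
        self.client_id = next(self._ids)
        self.server = server
        server.clients[self.client_id] = self
        self._next_request_id = 0
        self._pending: List[dict] = []
        self._ready: Dict[int, Any] = {}

    def post_request(self, obs) -> int:
        # Assign a request id and cache the request; nothing is sent until flush().
        req_id = self._next_request_id
        self._next_request_id += 1
        self._pending.append(
            dict(client_id=self.client_id, request_id=req_id, obs=obs))
        return req_id

    def flush(self):
        if self._pending:
            self.server.inbox.append(self._pending)
            self._pending = []

    def receive(self, responses: List[dict]):
        for r in responses:
            self._ready[r["request_id"]] = r["action"]

    def is_ready(self, ids: List[int]) -> bool:
        return all(i in self._ready for i in ids)

    def consume_result(self, ids: List[int]) -> List[Any]:
        # Raises KeyError for ids that are not ready, as the interface documents.
        return [self._ready.pop(i) for i in ids]


def serve_once(server: ToyInferenceServer, policy):
    """One policy-worker step: answer every pending batch with policy(obs)."""
    for batch in server.poll_requests():
        responses = [dict(request_id=r["request_id"], action=policy(r["obs"]))
                     for r in batch]
        server.respond(batch[0]["client_id"], responses)


server = ToyInferenceServer()
client = ToyInferenceClient(server)
ids = [client.post_request(obs) for obs in (1, 2, 3)]
print(client.is_ready(ids))  # False: the requests are still cached locally
client.flush()
serve_once(server, policy=lambda obs: obs * 10)
print(client.is_ready(ids), client.consume_result(ids))  # True [10, 20, 30]
```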
--------------------------------------------------------------------------------
/src/rlsrl/system/api/sample_stream.py:
--------------------------------------------------------------------------------
1 | """This module defines the data flow between the actor workers and the trainers. It is a simple
2 | producer-consumer model.
3 |
4 | A side note: our design lets the actor workers, which see all the data, post trajectory
5 | samples to the trainers, instead of having the policy workers do so.
6 | """
7 | from typing import Optional, List, Union, Any
8 |
9 | import rlsrl.api.config as config_api
10 | import rlsrl.base.buffer as buffer
11 |
12 |
13 | class NothingToConsume(Exception):
14 | pass
15 |
16 |
17 | class SampleProducer:
18 | """Used by the actor workers to post samples to the trainers.
19 | """
20 |
21 | def post(self, sample):
22 | """Post a sample. Implementation should be thread safe.
23 | Args:
24 | sample: data to be sent.
25 | """
26 | raise NotImplementedError()
27 |
28 | def flush(self):
29 | """Flush all posted samples.
30 | Thread-safety:
31 | The implementation of `flush` is considered thread-unsafe. Therefore, on each producer end, only one
32 | thread should call flush. At the same time, it is safe to call `post` on other threads.
33 | """
34 | raise NotImplementedError()
35 |
36 | def close(self):
37 | """ Explicitly close sample stream. """
38 | pass
39 |
40 |
41 | class SampleConsumer:
42 | """Used by the trainers to acquire samples.
43 | """
44 |
45 | def consume_to(self, buffer: buffer.Buffer, max_iter) -> int:
46 | """Consumes all available samples to a target buffer.
47 |
48 | Returns:
49 | The count of samples added to the buffer.
50 | """
51 | raise NotImplementedError()
52 |
53 | def consume(self) -> Any:
54 | """Consume one item from the stream. Blocking consumption is not supported, as it may cause workers to get stuck.
55 | Returns:
56 | Whatever is sent by the producer.
57 |
58 | Raises:
59 | NothingToConsume: if nothing can be consumed from the sample stream.
60 | """
61 | raise NotImplementedError()
62 |
63 | def close(self):
64 | """ Explicitly close sample stream. """
65 | pass
66 |
67 |
68 | class NullSampleProducer(SampleProducer):
69 | """NullSampleProducer discards all samples.
70 | """
71 |
72 | def flush(self):
73 | pass
74 |
75 | def post(self, sample):
76 | pass
77 |
78 |
79 | class ZippedSampleProducer(SampleProducer):
80 |
81 | def __init__(self, sample_producers: List[SampleProducer]):
82 | self.__producers = sample_producers
83 |
84 | def post(self, sample):
85 | # TODO: With the current implementation, we are pickling samples for multiple times.
86 | for p in self.__producers:
87 | p.post(sample)
88 |
89 | def flush(self):
90 | for p in self.__producers:
91 | p.flush()
92 |
93 |
94 | ALL_SAMPLE_PRODUCER_CLS = {}
95 | ALL_SAMPLE_CONSUMER_CLS = {}
96 |
97 |
98 | def register_producer(type_: config_api.SampleStream.Type, cls):
99 | ALL_SAMPLE_PRODUCER_CLS[type_] = cls
100 |
101 |
102 | def register_consumer(type_: config_api.SampleStream.Type, cls):
103 | ALL_SAMPLE_CONSUMER_CLS[type_] = cls
104 |
105 |
106 | def make_producer(spec: Union[str, config_api.SampleStream, SampleProducer],
107 | worker_info: Optional[config_api.WorkerInformation] = None,
108 | *args,
109 | **kwargs):
110 | """Initializes a sample producer (client).
111 |
112 | Args:
113 | spec: Configuration of the sample stream.
114 | worker_info: Worker information.
115 | """
116 | if isinstance(spec, SampleProducer):
117 | return spec
118 | if isinstance(spec, str):
119 | spec = config_api.SampleStream(type_=config_api.SampleStream.Type.NAME,
120 | stream_name=spec)
121 | if spec.worker_info is None:
122 | spec.worker_info = worker_info
123 | return ALL_SAMPLE_PRODUCER_CLS[spec.type_](spec, *args, **kwargs)
124 |
125 |
126 | def make_consumer(spec: Union[str, config_api.SampleStream, SampleConsumer],
127 | worker_info: Optional[config_api.WorkerInformation] = None,
128 | *args,
129 | **kwargs):
130 | """Initializes a sample consumer (server).
131 |
132 | Args:
133 | spec: Configuration of the sample stream.
134 | worker_info: Worker information.
135 | """
136 | if isinstance(spec, SampleConsumer):
137 | return spec
138 | if isinstance(spec, str):
139 | spec = config_api.SampleStream(type_=config_api.SampleStream.Type.NAME,
140 | stream_name=spec)
141 | if spec.worker_info is None:
142 | spec.worker_info = worker_info
143 | return ALL_SAMPLE_CONSUMER_CLS[spec.type_](spec, *args, **kwargs)
144 |
145 |
146 | def zip_producers(sample_producers: List[SampleProducer]):
147 | return ZippedSampleProducer(sample_producers)
148 |
149 |
150 | register_producer(config_api.SampleStream.Type.NULL,
151 | lambda spec, *args, **kwargs: NullSampleProducer())
152 |
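The producer/consumer contract above has three parts: a thread-safe `post`, a `flush` called from a single thread, and a non-blocking `consume` that raises `NothingToConsume` when the stream is empty. The sketch below implements that contract with an in-memory deque. `ToyProducer`, `ToyConsumer` and `ZippedProducer` are hypothetical. Reading `max_iter` as an upper bound on the number of samples moved per `consume_to` call is an assumption, because the interface does not document it.

```python
import collections
import threading


class NothingToConsume(Exception):
    pass


class ToyProducer:
    """Buffers samples under a lock; flush() publishes them to a shared deque."""

    def __init__(self, channel: collections.deque):
        self._channel = channel
        self._lock = threading.Lock()
        self._pending = []

    def post(self, sample):
        # Safe to call from several threads at once.
        with self._lock:
            self._pending.append(sample)

    def flush(self):
        # Swap the buffer under the lock, then publish outside it.
        with self._lock:
            batch, self._pending = self._pending, []
        self._channel.extend(batch)


class ZippedProducer:
    """Fans every posted sample out to several producers."""

    def __init__(self, producers):
        self._producers = producers

    def post(self, sample):
        for p in self._producers:
            p.post(sample)

    def flush(self):
        for p in self._producers:
            p.flush()


class ToyConsumer:

    def __init__(self, channel: collections.deque):
        self._channel = channel

    def consume(self):
        # Non-blocking: an empty stream is reported with an exception.
        try:
            return self._channel.popleft()
        except IndexError:
            raise NothingToConsume() from None

    def consume_to(self, buffer: list, max_iter: int) -> int:
        count = 0
        for _ in range(max_iter):
            try:
                buffer.append(self.consume())
            except NothingToConsume:
                break
            count += 1
        return count


def post_many(producer, worker_id, n):
    for i in range(n):
        producer.post((worker_id, i))


ch_a, ch_b = collections.deque(), collections.deque()
producer = ZippedProducer([ToyProducer(ch_a), ToyProducer(ch_b)])
threads = [threading.Thread(target=post_many, args=(producer, k, 25)) for k in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
producer.flush()  # flushed from one thread only

buf = []
n_a = ToyConsumer(ch_a).consume_to(buf, max_iter=1000)
print(n_a, len(ch_b))  # 100 100
```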
--------------------------------------------------------------------------------
/src/rlsrl/system/api/worker_control.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple, Union
2 | import collections
3 | import dataclasses
4 | import multiprocessing as mp
5 | import multiprocessing.connection as mp_connection
6 |
7 | from rlsrl.base.shared_memory import (OutOfOrderSharedMemoryControl,
8 | SharedMemoryInferenceStreamCtrl,
9 | PinnedRequestSharedMemoryControl,
10 | PinnedResponseSharedMemoryControl)
11 | import rlsrl.api.config as config_api
12 |
13 |
14 | @dataclasses.dataclass
15 | class WorkerCtrl:
16 | rx: mp_connection.Connection
17 | inf_ctrls: Tuple[SharedMemoryInferenceStreamCtrl] = None
18 | spl_ctrls: Tuple[OutOfOrderSharedMemoryControl] = None
19 |
20 |
21 | def _reveal_shm_stream_identity(setup: config_api.ExperimentConfig):
22 | inf_streams = set()
23 | spl_streams = set()
24 | spl_specs = dict()
25 |
26 | # Collect all shared memory stream names.
27 | for aw in setup.actors:
28 | for i, x in enumerate(aw.inference_streams):
29 | if (not isinstance(x, str) and x.type_
30 | == config_api.InferenceStream.Type.SHARED_MEMORY):
31 | inf_streams.add(aw.inference_streams[i].stream_name)
32 |
33 | for i, x in enumerate(aw.sample_streams):
34 | if (not isinstance(x, str)
35 | and x.type_ == config_api.SampleStream.Type.SHARED_MEMORY):
36 | stream_name = aw.sample_streams[i].stream_name
37 | spl_streams.add(stream_name)
38 | if stream_name not in spl_specs:
39 | spl_specs[stream_name] = (aw.sample_streams[i].qsize,
40 | aw.sample_streams[i].reuses,
41 | aw.sample_streams[i].batch_size)
42 | else:
43 | assert spl_specs[stream_name] == (
44 | aw.sample_streams[i].qsize,
45 | aw.sample_streams[i].reuses,
46 | aw.sample_streams[i].batch_size
47 | ), ("Inconsistent shared memory stream specification. "
48 | "Specs such as qsize, reuses and batch_size must be identical for streams with the same name."
49 | )
50 |
51 | return setup, sorted(inf_streams), sorted(spl_streams), spl_specs
52 |
53 |
54 | def make_worker_control(experiment_name: str, trial_name: str,
55 | setup: config_api.ExperimentConfig):
56 | # Make worker control.
57 | (setup, inf_streams, spl_streams,
58 | spl_specs) = _reveal_shm_stream_identity(setup)
59 |
60 | inf_ctrls = [
61 | SharedMemoryInferenceStreamCtrl(
62 | request_ctrl=PinnedRequestSharedMemoryControl(
63 | experiment_name,
64 | trial_name,
65 | f"{x}_infreq_ctrl",
66 | ),
67 | response_ctrl=PinnedResponseSharedMemoryControl(
68 | experiment_name,
69 | trial_name,
70 | f"{x}_infresp_ctrl",
71 | )) for x in inf_streams
72 | ]
73 | spl_ctrls = [
74 | OutOfOrderSharedMemoryControl(
75 | experiment_name,
76 | trial_name,
77 | f"{x}_spl_ctrl",
78 | qsize=spl_specs[x][0],
79 | reuses=spl_specs[x][1],
80 | ) for x in spl_streams
81 | ]
82 |
83 | tx_handles = collections.defaultdict(list)
84 | ctrls = collections.defaultdict(list)
85 |
86 | # Assign shared memory stream index to each worker.
87 | # Stream ctrls will be indexed correspondingly.
88 | for aw in setup.actors:
89 | for i, x in enumerate(aw.inference_streams):
90 | if (not isinstance(x, str) and x.type_
91 | == config_api.InferenceStream.Type.SHARED_MEMORY):
92 | x.stream_index = inf_streams.index(x.stream_name)
93 | for i, x in enumerate(aw.sample_streams):
94 | if (not isinstance(x, str)
95 | and x.type_ == config_api.SampleStream.Type.SHARED_MEMORY):
96 | x.stream_index = spl_streams.index(x.stream_name)
97 | tx, rx = mp.Pipe()
98 | ctrls['actor'].append(WorkerCtrl(rx, inf_ctrls, spl_ctrls))
99 | tx_handles['actor'].append(tx)
100 |
101 | for pw in setup.policies:
102 | if (not isinstance(pw.inference_stream, str)
103 | and pw.inference_stream.type_
104 | == config_api.InferenceStream.Type.SHARED_MEMORY):
105 | pw.inference_stream.stream_index = inf_streams.index(
106 | pw.inference_stream.stream_name)
107 | tx, rx = mp.Pipe()
108 | ctrls['policy'].append(WorkerCtrl(rx, inf_ctrls, None))
109 | tx_handles['policy'].append(tx)
110 |
111 | for tw in setup.trainers:
112 | if (not isinstance(tw.sample_stream, str) and tw.sample_stream.type_
113 | == config_api.SampleStream.Type.SHARED_MEMORY):
114 | tw.sample_stream.stream_index = spl_streams.index(
115 | tw.sample_stream.stream_name)
116 | tx, rx = mp.Pipe()
117 | ctrls['trainer'].append(WorkerCtrl(rx, None, spl_ctrls))
118 | tx_handles['trainer'].append(tx)
119 |
120 | for em in setup.eval_managers:
121 | if (not isinstance(em.eval_sample_stream, str)
122 | and em.eval_sample_stream.type_
123 | == config_api.SampleStream.Type.SHARED_MEMORY):
124 | em.eval_sample_stream.stream_index = spl_streams.index(
125 | em.eval_sample_stream.stream_name)
126 | tx, rx = mp.Pipe()
127 | ctrls['eval_manager'].append(WorkerCtrl(rx, None, spl_ctrls))
128 | tx_handles['eval_manager'].append(tx)
129 |
130 | return setup, dict(ctrls), dict(tx_handles)
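`make_worker_control` maps each shared-memory stream name to its position in the sorted list of unique names. Each worker can then look up its control object in `inf_ctrls`/`spl_ctrls` by that position. The hypothetical `StreamSpec` and `assign_stream_indices` below reproduce only this indexing step. The sketch also shows the `mp.Pipe()` pair whose receiving end `WorkerCtrl` stores as `rx`.

```python
import dataclasses
import multiprocessing as mp
from typing import List, Optional


@dataclasses.dataclass
class StreamSpec:
    stream_name: str
    shared_memory: bool = True
    stream_index: Optional[int] = None


def assign_stream_indices(per_worker_streams: List[List[StreamSpec]]) -> List[str]:
    """Index every shared-memory stream into the sorted list of unique stream names."""
    names = sorted({s.stream_name for streams in per_worker_streams
                    for s in streams if s.shared_memory})
    for streams in per_worker_streams:
        for s in streams:
            if s.shared_memory:
                s.stream_index = names.index(s.stream_name)
    return names


actor_a = [StreamSpec("policy_b"), StreamSpec("policy_a")]
actor_b = [StreamSpec("policy_a"), StreamSpec("remote", shared_memory=False)]
names = assign_stream_indices([actor_a, actor_b])
print(names, [s.stream_index for s in actor_a + actor_b])
# ['policy_a', 'policy_b'] [1, 0, 0, None]

# One pipe per worker: the controller keeps tx, and the worker receives on rx.
tx, rx = mp.Pipe()
tx.send("start")
msg = rx.recv()
tx.close()
rx.close()
```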
--------------------------------------------------------------------------------
/src/rlsrl/system/impl/__init__.py:
--------------------------------------------------------------------------------
1 | import rlsrl.system.impl.actor_worker
2 | import rlsrl.system.impl.dummy_worker
3 | import rlsrl.system.impl.eval_manager
4 | import rlsrl.system.impl.inline_inference
5 | import rlsrl.system.impl.local_inference
6 | import rlsrl.system.impl.local_sample
7 | import rlsrl.system.impl.master_worker
8 | import rlsrl.system.impl.policy_worker
9 | import rlsrl.system.impl.remote_inference
10 | import rlsrl.system.impl.remote_sample
11 | import rlsrl.system.impl.trainer_worker
12 |
--------------------------------------------------------------------------------
/src/rlsrl/system/impl/dummy_worker.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 |
4 | import rlsrl.api.config
5 | import rlsrl.base.name_resolve
6 | import rlsrl.system.api.worker_base as worker_base
7 |
8 | logger = logging.getLogger("DummyWorker")
9 |
10 |
11 | class DummyWorker(worker_base.Worker):
12 | # dummy worker for testing
13 | def __init__(self, ctrl=None):
14 | super().__init__(ctrl=ctrl)
15 | self.__count = 0
16 |
17 | def _configure(self, config):
18 | self.my_key = self.my_value = config.worker_info.worker_index
19 | logger.info(
20 | f"Before dummy worker add /dummy/{self.my_key} {self.my_value}_0")
21 | rlsrl.base.name_resolve.add(f"/dummy/{self.my_key}",
22 | f"{self.my_value}_0")
23 | logger.info(
24 | f"After dummy worker add /dummy/{self.my_key} {self.my_value}_0")
25 | return config.worker_info
26 |
27 | def _poll(self):
28 | get_subtree_results = rlsrl.base.name_resolve.get_subtree("/dummy")
29 | get_result = rlsrl.base.name_resolve.get(f"/dummy/{self.my_key}")
30 |
31 | logger.info(
32 | f"Before dummy worker add /dummy/{self.my_key} {self.my_value}_{self.__count}"
33 | )
34 | rlsrl.base.name_resolve.add(f"/dummy/{self.my_key}",
35 | f"{self.my_value}_{self.__count}",
36 | replace=True)
37 | logger.info(
38 | f"After dummy worker add /dummy/{self.my_key} {self.my_value}_{self.__count}"
39 | )
40 |
41 | logger.info(
42 | f"get_subtree_results: {get_subtree_results}, get_result: {get_result}"
43 | )
44 | self.__count += 1
45 | time.sleep(1)
46 | return worker_base.PollResult(valid=True)
47 |
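`DummyWorker` exercises the name-resolve API in three ways: it adds `/dummy/<index>`, overwrites that key with `replace=True` on every poll, and reads the `/dummy` subtree. The dictionary-backed `DictNameResolve` below is a hypothetical stand-in for `rlsrl.base.name_resolve`. In particular, raising on a duplicate `add` without `replace=True` is an assumption about the real backend.

```python
class DictNameResolve:
    """Hypothetical in-memory key-value store mimicking add/get/get_subtree."""

    def __init__(self):
        self._store = {}

    def add(self, name, value, replace=False):
        # Assumed behaviour: refuse to overwrite unless replace=True.
        if name in self._store and not replace:
            raise KeyError(f"{name} already exists")
        self._store[name] = value

    def get(self, name):
        return self._store[name]

    def get_subtree(self, prefix):
        # Values of all keys under the prefix, in key order.
        prefix = prefix.rstrip("/") + "/"
        return [v for k, v in sorted(self._store.items()) if k.startswith(prefix)]


store = DictNameResolve()
for idx in (0, 1):        # _configure of two dummy workers
    store.add(f"/dummy/{idx}", f"{idx}_0")
for count in range(3):    # three _poll rounds of worker 0
    store.add("/dummy/0", f"0_{count}", replace=True)
print(store.get("/dummy/0"), store.get_subtree("/dummy"))  # 0_2 ['0_2', '1_0']
```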
--------------------------------------------------------------------------------
/src/rlsrl/system/impl/inline_inference.py:
--------------------------------------------------------------------------------
1 | from typing import List, Any
2 | import logging
3 | import numpy as np
4 | import time
5 |
6 | import rlsrl.api.config as config_api
7 | import rlsrl.api.policy as policy_api
8 | import rlsrl.base.numpy_utils as numpy_utils
9 | import rlsrl.base.namedarray as namedarray
10 | import rlsrl.base.timeutil as timeutil
11 | import rlsrl.system.api.parameter_db as db
12 | import rlsrl.system.api.inference_stream as inference_stream
13 |
14 | _INLINE_PASSIVE_PULL_FREQUENCY_SECONDS = 2
15 | _INLINE_PULL_PARAMETER_ON_START = True
16 |
17 | logger = logging.getLogger("InlineInferenceStream")
18 |
19 |
20 | class InlineInferenceClient(inference_stream.InferenceClient):
21 |
22 | def poll_responses(self):
23 | pass
24 |
25 | def __init__(self,
26 | policy,
27 | policy_name,
28 | param_db,
29 | worker_info,
30 | pull_interval,
31 | policy_identifier,
32 | parameter_service_client=None,
33 | foreign_policy=None,
34 | accept_update_call=True,
35 | population=None,
36 | policy_sample_probs=None):
37 | self.policy_name = policy_name
38 | self.__policy_identifier = policy_identifier
39 | import os
40 | os.environ["MARL_CUDA_DEVICES"] = "cpu"
41 | self.policy = policy_api.make(policy)
42 | self.policy.eval_mode()
43 | self.__logger = logging.getLogger("Inline Inference")
44 | self._request_count = 0
45 | self.__request_buffer = []
46 | self._response_cache = {}
47 | self.__pull_freq_control = timeutil.FrequencyControl(
48 | frequency_seconds=pull_interval,
49 | initial_value=_INLINE_PULL_PARAMETER_ON_START)
50 | self.__passive_pull_freq_control = timeutil.FrequencyControl(
51 | frequency_seconds=_INLINE_PASSIVE_PULL_FREQUENCY_SECONDS,
52 | initial_value=_INLINE_PULL_PARAMETER_ON_START,
53 | )
54 | self.__load_absolute_path = None
55 | self.__accept_update_call = accept_update_call
56 | self.__parameter_service_client = None
57 |
58 | # Parameter DB / Policy name related.
59 | if foreign_policy is not None:
60 | p = foreign_policy
61 | i = worker_info
62 | pseudo_worker_info = config_api.WorkerInformation(
63 | experiment_name=p.foreign_experiment_name or i.experiment_name,
64 | trial_name=p.foreign_trial_name or i.trial_name)
65 | self.__param_db = db.make_db(p.param_db,
66 | worker_info=pseudo_worker_info)
67 | self.__load_absolute_path = p.absolute_path
68 | self.__load_policy_name = p.foreign_policy_name or policy_name
69 | self.__policy_identifier = p.foreign_policy_identifier or policy_identifier
70 | else:
71 | self.__param_db = db.make_db(param_db, worker_info=worker_info)
72 | self.__load_policy_name = policy_name
73 | self.__policy_identifier = policy_identifier
74 |
75 | if parameter_service_client is not None and self.__load_absolute_path is None:
76 | self.__parameter_service_client = db.make_client(
77 | parameter_service_client, worker_info)
78 | self.__parameter_service_client.subscribe(
79 | experiment_name=self.__param_db.experiment_name,
80 | trial_name=self.__param_db.trial_name,
81 | policy_name=self.__load_policy_name,
82 | tag=self.__policy_identifier,
83 | callback_fn=self.policy.load_checkpoint,
84 | use_current_thread=True)
85 |
86 | self.configure_population(population, policy_sample_probs)
87 |
88 | self.__log_frequency_control = timeutil.FrequencyControl(
89 | frequency_seconds=10)
90 |
91 | def configure_population(self, population, policy_sample_probs):
92 | if population is not None:
93 | assert policy_sample_probs is None or len(
94 | policy_sample_probs
95 | ) == len(population), (
96 | f"Size of policy_sample_probs {len(policy_sample_probs)} and population {len(population)} must be the same."
97 | )
98 | self.__population = population
99 | if policy_sample_probs is None:
100 | policy_sample_probs = np.ones(
101 | len(population)) / len(population)
102 | self.__policy_sample_probs = policy_sample_probs
103 | elif self.policy_name is None:
104 | policy_names = self.__param_db.list_names()
105 | if len(policy_names) == 0:
106 | raise ValueError(
107 | "You set policy_name and population to be None, but no existing policies were found."
108 | )
109 | logger.info(f"Auto-detected population {policy_names}")
110 | self.__population = policy_names
111 | self.__policy_sample_probs = np.ones(
112 | len(policy_names)) / len(policy_names)
113 | else:
114 | self.__population = None
115 | self.__policy_sample_probs = None
116 |
117 | def post_request(self, request: policy_api.RolloutRequest, _=None) -> int:
118 | request.request_id = np.array([self._request_count], dtype=np.int64)
119 | req_id = self._request_count
120 | self.__request_buffer.append(request)
121 | self._request_count += 1
122 | self.flush()
123 | return req_id
124 |
125 | def is_ready(self, inference_ids: List[int], _=None) -> bool:
126 | for req_id in inference_ids:
127 | if req_id not in self._response_cache:
128 | return False
129 | return True
130 |
131 | def consume_result(self, inference_ids: List[int], _=None):
132 | return [self._response_cache.pop(req_id) for req_id in inference_ids]
133 |
134 | def load_parameter(self):
135 | """Exposed to the actor worker so that parameters can be reloaded when an environment is done.
136 | """
137 | if self.__passive_pull_freq_control.check(
138 | ) and self.__accept_update_call:
139 | # This reduces the unnecessary workload of mongodb.
140 | self.__load_parameter()
141 |
142 | def __get_checkpoint_from_db(self, block=False):
143 | if self.__load_absolute_path is not None:
144 | return self.__param_db.get_file(self.__load_absolute_path)
145 | else:
146 | return self.__param_db.get(name=self.__load_policy_name,
147 | identifier=self.__policy_identifier,
148 | block=block)
149 |
150 | def __load_parameter(self):
151 | if self.__population is None:
152 | policy_name = self.policy_name
153 | else:
154 | policy_name = np.random.choice(self.__population,
155 | p=self.__policy_sample_probs)
156 | checkpoint = self.__get_checkpoint_from_db(
157 | block=self.policy.version < 0)
158 | self.policy.load_checkpoint(checkpoint)
159 | self.policy_name = policy_name
160 | self.__logger.debug(
161 | f"Loaded {self.policy_name}'s parameter of version {self.policy.version}"
162 | )
163 |
164 | def flush(self):
165 | if self.__pull_freq_control.check():
166 | self.__load_parameter()
167 |
168 | if self.__parameter_service_client is not None:
169 | self.__parameter_service_client.poll()
170 |
171 | if self.__log_frequency_control.check():
172 | self.__logger.debug(f"Policy Version: {self.policy.version}")
173 |
174 | if len(self.__request_buffer) > 0:
175 | agg_req = namedarray.recursive_aggregate(self.__request_buffer,
176 | np.stack)
177 | rollout_results = self.policy.rollout(agg_req)
178 | rollout_results.request_id = agg_req.request_id
179 | rollout_results.policy_version_steps = np.full(
180 | shape=agg_req.client_id.shape, fill_value=self.policy.version)
181 | rollout_results.policy_name = np.full(
182 | shape=agg_req.client_id.shape, fill_value=self.policy_name)
183 | self.__request_buffer = []
184 | for i in range(rollout_results.length(dim=0)):
185 | self._response_cache[rollout_results.request_id[
186 | i, 0]] = rollout_results[i]
187 |
188 | def get_constant(self, name: str) -> Any:
189 | if name == "default_policy_state":
190 | return self.policy.default_policy_state
191 | else:
192 | raise NotImplementedError(name)
193 |
194 |
195 | inference_stream.register_client(
196 | config_api.InferenceStream.Type.INLINE,
197 | lambda spec: InlineInferenceClient(
198 | policy=spec.policy,
199 | policy_name=spec.policy_name,
200 | param_db=spec.param_db,
201 | worker_info=spec.worker_info,
202 | pull_interval=spec.pull_interval_seconds,
203 | policy_identifier=spec.policy_identifier,
204 | foreign_policy=spec.foreign_policy,
205 | accept_update_call=spec.accept_update_call,
206 | population=spec.population,
207 | parameter_service_client=spec.parameter_service_client,
208 | policy_sample_probs=spec.policy_sample_probs,
209 | ),
210 | )
211 |
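`InlineInferenceClient.flush` gathers every pending request into one batch, runs a single `rollout`, tags the results with the policy version and name, and then scatters the per-request results into `_response_cache` keyed by `request_id`. Below is a minimal sketch of that gather/scatter pattern. It makes two assumptions: plain dicts of NumPy arrays stand in for `namedarray`, and `toy_rollout` is a hypothetical stand-in for `policy.rollout`.

```python
import numpy as np


def toy_rollout(obs: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for policy.rollout: one discrete action per request row."""
    return obs.sum(axis=-1).astype(np.int64) % 6


def flush(request_buffer: list, response_cache: dict, policy_version: int) -> None:
    """Batch pending requests, run one rollout, scatter results by request id."""
    if not request_buffer:
        return
    # Stack each field along a new leading batch dimension, analogous to
    # namedarray.recursive_aggregate(request_buffer, np.stack).
    obs = np.stack([r["obs"] for r in request_buffer])
    request_id = np.stack([r["request_id"] for r in request_buffer])  # shape (batch, 1)
    actions = toy_rollout(obs)
    versions = np.full(shape=request_id.shape, fill_value=policy_version)
    request_buffer.clear()
    for i in range(len(actions)):
        # request_id has shape (batch, 1), hence the [i, 0] index.
        response_cache[int(request_id[i, 0])] = {
            "action": actions[i],
            "policy_version_steps": versions[i],
        }
```

Batching amortizes one forward pass over all pending requests, and the cache lets each caller pick up its own result later by id.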
--------------------------------------------------------------------------------
/src/rlsrl/system/impl/local_sample.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import time
4 | import threading
5 |
6 | import rlsrl.api.config as config_api
7 | import rlsrl.base.namedarray as namedarray
8 | import rlsrl.base.shared_memory as shared_memory
9 | import rlsrl.system.api.sample_stream as sample_stream
10 |
11 | logger = logging.getLogger("LocalSampleStream")
12 |
13 |
14 | class SharedMemorySampleProducer(sample_stream.SampleProducer):
15 |
16 | def __init__(
17 | self,
18 | experiment_name,
19 | trial_name,
20 | stream_name,
21 | qsize,
22 | ctrl: shared_memory.OutOfOrderSharedMemoryControl,
23 | ):
24 | self.__shared_memory_writer = shared_memory.SharedMemoryDock(
25 | experiment_name,
26 | trial_name,
27 | stream_name + "_sample",
28 | qsize=qsize,
29 | ctrl=ctrl,
30 | second_dim_index=True,
31 | )
32 | self.__post_lock = threading.Lock()
33 | self.__sample_buffer = []
34 |
35 | def post(self, sample):
36 | with self.__post_lock:
37 | self.__sample_buffer.append(sample)
38 |
39 | def flush(self):
40 | with self.__post_lock:
41 | tmp = self.__sample_buffer
42 | self.__sample_buffer = []
43 | for x in tmp:
44 | self.__shared_memory_writer.write(x)
45 |
46 | def close(self):
47 | self.__shared_memory_writer.close()
48 |
49 |
50 | sample_stream.register_producer(
51 | config_api.SampleStream.Type.SHARED_MEMORY,
52 | lambda spec, ctrl: SharedMemorySampleProducer(
53 | experiment_name=spec.worker_info.experiment_name,
54 | trial_name=spec.worker_info.trial_name,
55 | stream_name=spec.stream_name,
56 | qsize=spec.qsize,
57 | ctrl=ctrl,
58 | ),
59 | )
60 |
61 |
62 | class SharedMemorySampleConsumer(sample_stream.SampleConsumer):
63 |
64 | def __init__(
65 | self,
66 | experiment_name,
67 | trial_name,
68 | stream_name,
69 | qsize,
70 | batch_size,
71 | ctrl: shared_memory.OutOfOrderSharedMemoryControl,
72 | ):
73 | self.__shared_memory_reader = shared_memory.SharedMemoryDock(
74 | experiment_name,
75 | trial_name,
76 | stream_name + "_sample",
77 | qsize=qsize,
78 | ctrl=ctrl,
79 | second_dim_index=True)
80 | self.__batch_size = batch_size
81 |
82 | def consume_to(self, buffer, max_iter=16):
83 | count = 0
84 | for _ in range(max_iter):
85 | try:
86 | sample = self.__shared_memory_reader.read(
87 | batch_size=self.__batch_size)
88 | except shared_memory.NothingToRead:
89 | break
90 | if_batch = buffer.put(sample)
91 | count += 1
92 | return count
93 |
94 | def consume(self):
95 | try:
96 | return self.__shared_memory_reader.read(
97 | batch_size=self.__batch_size)
98 | except shared_memory.NothingToRead:
99 | raise sample_stream.NothingToConsume()
100 |
101 | def close(self):
102 | self.__shared_memory_reader.close()
103 |
104 |
105 | sample_stream.register_consumer(
106 | config_api.SampleStream.Type.SHARED_MEMORY,
107 | lambda spec, ctrl: SharedMemorySampleConsumer(
108 | experiment_name=spec.worker_info.experiment_name,
109 | trial_name=spec.worker_info.trial_name,
110 | stream_name=spec.stream_name,
111 | qsize=spec.qsize,
112 | batch_size=spec.batch_size,
113 | ctrl=ctrl,
114 | ),
115 | )
116 |
117 |
118 | class InlineSampleProducer(sample_stream.SampleProducer):
119 |     """Testing only! Runs the trainer inline; checkpoints go only to a local testing parameter DB.
120 |     """
121 |
122 | def __init__(self, trainer, policy):
123 | from rlsrl.api.trainer import make
124 | from rlsrl.system.api.parameter_db import make_db
125 | from rlsrl.api.config import ParameterDB
126 |
127 | self.trainer = make(trainer, policy)
128 | self.buffer = []
129 | self.logger = logging.getLogger("Inline Training")
130 | self.param_db = make_db(
131 | ParameterDB(type_=ParameterDB.Type.LOCAL_TESTING))
132 | self.param_db.push(name="",
133 | checkpoint=self.trainer.get_checkpoint(),
134 | version=0)
135 |
136 | def post(self, sample):
137 | self.buffer.append(sample)
138 | self.logger.debug("Receive sample.")
139 |
140 | def flush(self):
141 | if len(self.buffer) >= 5:
142 | batch_sample = namedarray.recursive_aggregate(
143 | self.buffer, aggregate_fn=lambda x: np.stack(x, axis=1))
144 | batch_sample.policy_name = None
145 | self.trainer.step(batch_sample)
146 | self.param_db.push(name="",
147 | checkpoint=self.trainer.get_checkpoint(),
148 | version=0)
149 | self.logger.info("Trainer step is successful!")
150 | self.logger.debug(
151 | f"Trainer steps. now on version {self.trainer.policy.version}."
152 | )
153 | self.buffer = []
154 |
155 |
156 | sample_stream.register_producer(
157 | config_api.SampleStream.Type.INLINE_TESTING,
158 | lambda spec: InlineSampleProducer(
159 | trainer=spec.trainer,
160 | policy=spec.policy,
161 | ),
162 | )
163 |
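The shared-memory producer appends samples under a lock in `post` and hands the accumulated list to the writer in `flush`. The consumer's `consume_to` drains at most `max_iter` samples and stops early when nothing is left to read. Below is a minimal sketch of both halves. It assumes a `queue.Queue` as a hypothetical stand-in for `SharedMemoryDock`, `queue.Empty` in the role of `shared_memory.NothingToRead`, and a plain list in place of the replay buffer.

```python
import queue
import threading


class BufferedProducer:
    """post() only appends under a lock; flush() swaps the list out under the
    lock and writes the swapped-out samples after releasing it."""

    def __init__(self, channel: queue.Queue):
        self._channel = channel  # stand-in for SharedMemoryDock
        self._lock = threading.Lock()
        self._buffer = []

    def post(self, sample):
        with self._lock:
            self._buffer.append(sample)

    def flush(self):
        with self._lock:
            pending, self._buffer = self._buffer, []
        for sample in pending:
            self._channel.put(sample)


def consume_to(channel: queue.Queue, buffer: list, max_iter: int = 16) -> int:
    """Move at most max_iter samples into buffer; return how many were moved."""
    count = 0
    for _ in range(max_iter):
        try:
            sample = channel.get_nowait()
        except queue.Empty:  # plays the role of shared_memory.NothingToRead
            break
        buffer.append(sample)
        count += 1
    return count
```

Swapping the list under the lock keeps `post` cheap: writers block only for the swap, never for the writes themselves. Capping `max_iter` bounds how long a single poll can spend draining the stream.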
--------------------------------------------------------------------------------
/src/rlsrl/system/impl/master_worker.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import threading
3 | import time
4 |
5 | import rlsrl.api.config
6 | import rlsrl.base.name_resolve
7 | import rlsrl.system.api.parameter_db as db
8 | import rlsrl.system.api.worker_base as worker_base
9 |
10 | logger = logging.getLogger("MasterWorker")
11 |
12 |
13 | class MasterWorker(worker_base.Worker):
14 |     # A master worker acts as a centralized controller on the master-rank node.
15 |     # It runs the name resolve server; the parameter service below is currently disabled.
16 |     # This implementation is minimal and only demonstrates distributed execution of SRL.
17 |
18 | def _configure(self, cfg: rlsrl.api.config.MasterWorker):
19 | self.__name_resolve_server = rlsrl.base.name_resolve.NameResolveServer(
20 | cfg.port)
21 |
22 | self.__name_resolve_server_thread = threading.Thread(
23 | target=self.__name_resolve_server.run)
24 | self.__name_resolve_server_thread.start()
25 | logger.info("Master worker started RPC name resolve server thread.")
26 |
27 | # self.__parameter_server = db.make_server(cfg.parameter_server,
28 | # cfg.worker_info)
29 | # self.__parameter_server.update_subscription()
30 | # self.__parameter_server.run()
31 |
32 | return cfg.worker_info
33 |
34 | def _poll(self):
35 | # self.__parameter_server.update_subscription()
36 | time.sleep(5)
37 | return worker_base.PollResult(sample_count=0, batch_count=0)
--------------------------------------------------------------------------------
/src/rlsrl/testing/__init__.py:
--------------------------------------------------------------------------------
1 | """The helper module for testing only."""
2 | import random
3 | import os
4 | import sys
5 | import time
6 | import mock
7 | import torch
8 | import threading
9 |
10 | import rlsrl.base.name_resolve as name_resolve
11 | import rlsrl.testing.aerochess_env
12 | import rlsrl.testing.null_trainer
13 | import rlsrl.testing.random_policy
14 |
15 | _IS_GITHUB_WORKFLOW = len(os.environ.get("CI", "").strip()) > 0
16 | os.environ["MARL_CUDA_DEVICES"] = "cpu"
17 | _DEFAULT_WAIT_NETWORK_SECONDS = 0.5 if _IS_GITHUB_WORKFLOW else 0.05
18 | os.environ["MARL_TESTING"] = "1"
19 |
20 | _next_port = 20000 + random.randint(
21 | 0, 10000) # Random port for now, should be ok most of the time.
22 |
23 |
24 | def get_testing_port():
25 | """Returns a local port for testing."""
26 | global _next_port
27 | _next_port += 1
28 | return _next_port
29 |
30 |
31 | def wait_network(length=_DEFAULT_WAIT_NETWORK_SECONDS):
32 | time.sleep(length)
33 |
34 |
35 | def get_test_param(version=0):
36 | return {
37 | "steps": version,
38 | "state_dict": {
39 | "linear_weights": torch.randn(10, 10)
40 | },
41 | }
42 |
43 |
44 | TESTING_RPC_NAME_RESOLVE_SERVER_PORT = get_testing_port()
45 | name_resolve_rpc_server = name_resolve.NameResolveServer(
46 | port=TESTING_RPC_NAME_RESOLVE_SERVER_PORT)
47 | thread = threading.Thread(target=name_resolve_rpc_server.run, daemon=True)
48 | thread.start()
49 |
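`get_testing_port` counts upward from a random base port, so a collision with a port already in use is unlikely but possible. A common alternative is to let the operating system pick an unused port by binding to port 0. The sketch below shows that swapped-in technique. It is not what this module does, and `get_free_port` is a hypothetical helper.

```python
import socket


def get_free_port() -> int:
    """Ask the OS for an unused TCP port by binding to port 0.

    The socket is closed on return, so another process could still claim the
    port before the caller binds it. This narrows the race but does not remove it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]
```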
--------------------------------------------------------------------------------
/src/rlsrl/testing/aerochess_env.py:
--------------------------------------------------------------------------------
1 | """A simplified environment of Aeroplane Chess useful for testing.
2 | """
3 | from typing import List
4 | import gym
5 | import numpy as np
6 | import scipy.stats
7 |
8 | import rlsrl.api.environment as env_base
9 | import rlsrl.api.env_utils as env_utils
10 |
11 |
12 | class AerochessEnvironment(env_base.Environment):
13 |     """Our simple environment has very simple and predictable behaviour:
14 |     - There are `n` players (configurable), all starting at location 0.
15 |     - In each step, each player moves forward by `a_i` in [1, 6] steps, where `a_i` is the action provided.
16 |     - Each player's observation is a vector showing all players' current locations.
17 |     - Each player's reward is its rank (`n` for first, and 1 for last) after the step.
18 |     - Reaching location `length` ends the game for that player.
19 |     - No killing is considered.
20 | """
21 |
22 | def __init__(self, length=10, n=1, max_steps=None):
23 | self.__length = length
24 | self.__n = n
25 | self.__space = env_utils.DiscreteActionSpace(gym.spaces.Discrete(6))
26 | self.__locations = None
27 | self.__steps = None
28 | self.max_steps = max_steps
29 |
30 | @property
31 | def agent_count(self):
32 | return self.__n
33 |
34 | @property
35 | def observation_spaces(self):
36 | return {}
37 |
38 | @property
39 | def action_spaces(self):
40 | return [self.__space for _ in range(self.__n)]
41 |
42 | def reset(self):
43 | self.__locations = np.zeros(self.__n, dtype=np.int64)
44 | self.__steps = 0
45 | return [
46 | env_base.StepResult(obs={"obs": self.__locations},
47 | reward=np.array([0], dtype=np.float32),
48 | done=np.array([False], dtype=np.uint8),
49 | info={}) for _ in range(self.__n)
50 | ]
51 |
52 | def step(self, actions: List[env_utils.DiscreteAction]):
53 | delta = np.array([a.x.item() for a in actions], dtype=np.int64)
54 | assert np.all(delta > 0), delta
55 | self.__locations = np.minimum(self.__locations + delta, self.__length)
56 | rewards = self.__n - scipy.stats.rankdata(self.__locations, method='min') + 1
57 |
58 | self.__steps += 1
59 | if not self.max_steps:
60 | return [
61 | env_base.StepResult(obs={"obs": self.__locations},
62 | reward=np.array([r], dtype=np.float32),
63 | done=np.array([loc >= self.__length], dtype=np.uint8),
64 | info={}) for r, loc in zip(rewards, self.__locations)
65 | ]
66 | else:
67 | if self.__steps < self.max_steps:
68 | return [
69 | env_base.StepResult(obs={"obs": self.__locations},
70 | reward=np.array([r], dtype=np.float32),
71 | done=np.array([False], dtype=np.uint8),
72 | info={}) for r, loc in zip(rewards, self.__locations)
73 | ]
74 | else:
75 |                 # max_steps reached: this is the final step, so every player is done.
76 | return [
77 | env_base.StepResult(obs={"obs": self.__locations},
78 | reward=np.array([r], dtype=np.float32),
79 | done=np.array([True], dtype=np.uint8),
80 | info={}) for r, loc in zip(rewards, self.__locations)
81 | ]
82 |
83 |
84 | env_base.register("aerochess", AerochessEnvironment)
85 |
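The reward line in `step` computes `n - scipy.stats.rankdata(locations, method='min') + 1`. Below is a NumPy-only sketch that reproduces this formula, which makes the tie handling and ordering concrete. The helper name `rank_rewards` is hypothetical. A `'min'` rank is one plus the number of strictly smaller entries, so tied players share the lowest rank of their group. Because ranks ascend with location, the player at the smallest location receives `n`.

```python
import numpy as np


def rank_rewards(locations: np.ndarray) -> np.ndarray:
    """Reproduce n - rankdata(locations, method='min') + 1 without SciPy."""
    n = len(locations)
    # smaller[i, j] is True when locations[j] < locations[i].
    smaller = locations[None, :] < locations[:, None]
    min_rank = 1 + smaller.sum(axis=1)
    return n - min_rank + 1
```

For example, `locations = [3, 5, 5]` yields `[3, 2, 2]`, and `[10, 2, 7]` yields `[1, 3, 2]`.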
--------------------------------------------------------------------------------
/src/rlsrl/testing/null_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from rlsrl.api.trainer import Trainer, TrainerStepResult
4 | from rlsrl.api.trainer import register
5 | from rlsrl.api.policy import make as make_policy
6 | import rlsrl.api.policy as policy_api
7 |
8 |
9 | class NullPolicy(torch.nn.Module):
10 |
11 | def __init__(self):
12 | super().__init__()
13 | self.linear = torch.nn.Linear(3, 3)
14 |
15 | def set_state_dict(self, param):
16 | self.load_state_dict(param)
17 |
18 |
19 | class NullTrainer(Trainer):
20 |
21 | @property
22 | def policy(self) -> policy_api.Policy:
23 | return self._policy
24 |
25 | def __init__(self, policy, **kwargs):
26 | self._policy = policy
27 | self.steps = 0
28 |
29 | def step(self, sample):
30 | self.steps += 1
31 | return TrainerStepResult(stats={}, step=0)
32 |
33 | def distributed(self, **kwargs):
34 | pass
35 |
36 | def get_checkpoint(self, *args, **kwargs):
37 | return {}
38 |
39 | def load_checkpoint(self, *args, **kwargs):
40 | pass
41 |
42 |
43 | register('null_trainer', NullTrainer)
44 |
--------------------------------------------------------------------------------
/src/rlsrl/testing/random_policy.py:
--------------------------------------------------------------------------------
1 | import torch.nn
2 | import numpy as np
3 |
4 | from rlsrl.api.policy import Policy, RolloutResult, register
5 |
6 |
7 | class RandomPolicy(Policy):
8 |     """An untrainable random policy for testing.
9 | """
10 |
11 | def get_checkpoint(self):
12 | return {"state_dict": {}}
13 |
14 | def load_checkpoint(self, checkpoint):
15 | pass
16 |
17 | def train_mode(self):
18 | pass
19 |
20 | def eval_mode(self):
21 | pass
22 |
23 | @property
24 | def net(self):
25 | pass
26 |
27 | @property
28 | def version(self) -> int:
29 | return 0
30 |
31 | def inc_version(self):
32 | pass
33 |
34 | @property
35 | def neural_networks(self):
36 | return []
37 |
38 | def distributed(self):
39 | pass
40 |
41 | @property
42 | def default_policy_state(self):
43 | return None
44 |
45 | def __init__(self, action_space: int):
46 | super(RandomPolicy, self).__init__()
47 | self.action_space = action_space
48 | self.state_dim = 10
49 | self.__net = torch.nn.Module()
50 |
51 | def analyze(self, sample, **kwargs):
52 | pass
53 |
54 | def rollout(self, requests, **kwargs):
55 | num_requests = requests.length(dim=0)
56 | actions_scores = np.random.randint(low=10, high=100, size=(num_requests, self.action_space))
57 |         actions_probs = actions_scores / actions_scores.sum(axis=1, keepdims=True)
58 | actions = actions_probs.argmax(axis=1)
59 | policy_states = np.random.random((num_requests, 1, self.state_dim))
60 |
61 | return RolloutResult(action=actions, log_probs=actions_probs, policy_state=policy_states)
62 |
63 | def parameters(self):
64 | return self.state_dim
65 |
66 |
67 | register("random_policy", RandomPolicy)
68 |
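`RandomPolicy.rollout` turns random integer scores into per-action probabilities. Dividing by the grand total `scores.sum()` normalizes over the whole batch, so no individual row sums to 1. Dividing by a per-row sum with `keepdims=True` gives one distribution per request. The argmax is unchanged either way, because each row is only scaled by a positive constant. Below is a minimal sketch of the per-row version. The helper name `random_action_probs` and the use of `np.random.default_rng` are assumptions made for a reproducible example.

```python
import numpy as np


def random_action_probs(num_requests: int, num_actions: int, rng=None):
    """Sample random scores and normalize each row into a probability distribution."""
    rng = np.random.default_rng() if rng is None else rng
    scores = rng.integers(low=10, high=100, size=(num_requests, num_actions))
    # keepdims=True keeps the row sums at shape (num_requests, 1), so the
    # division broadcasts across each row's actions.
    probs = scores / scores.sum(axis=1, keepdims=True)
    actions = probs.argmax(axis=1)
    return actions, probs
```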
--------------------------------------------------------------------------------
/tests/system/eval_manager_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import socket
3 | import tempfile
4 | import mock
5 | import numpy as np
6 | import unittest
7 |
8 | from rlsrl.testing import *
9 |
10 | from rlsrl.system.impl.remote_inference import IpInferenceClient
11 | from rlsrl.system.impl.eval_manager import EvalManager
12 | import rlsrl.api.config as config_api
13 | import rlsrl.api.trainer as trainer_api
14 | import rlsrl.base.name_resolve as name_resolve
15 | import rlsrl.base.namedarray as namedarray
16 | import rlsrl.system.api.sample_stream as sample_stream
17 | import rlsrl.system.api.parameter_db as parameter_db
18 |
19 |
20 | class TestEpisodeInfo(namedarray.NamedArray):
21 |
22 | def __init__(
23 | self,
24 | hp: np.ndarray = np.array([0], dtype=np.float32),
25 | mana: np.ndarray = np.array([0], dtype=np.float32),
26 | ):
27 | super(TestEpisodeInfo, self).__init__(hp=hp, mana=mana)
28 |
29 |
30 | def make_config(policy_name="test_policy",
31 | eval_stream_name="eval_test_policy",
32 | worker_index=0,
33 | worker_count=1):
34 | return config_api.EvaluationManager(
35 | eval_sample_stream=config_api.SampleStream(
36 | config_api.SampleStream.Type.NAME, stream_name=eval_stream_name),
37 | parameter_db=config_api.ParameterDB(
38 | config_api.ParameterDB.Type.FILESYSTEM),
39 | policy_name=policy_name,
40 | eval_tag="evaluation",
41 | eval_games_per_version=5,
42 | worker_info=config_api.WorkerInformation("test_exp", "test_run",
43 | "trainer", worker_index,
44 | worker_count),
45 | )
46 |
47 |
48 | def random_sample_batch(version=0, hp=0, mana=0, policy_name="test_policy"):
49 | return trainer_api.SampleBatch(
50 | obs=np.random.random(size=(10, 10)),
51 | reward=np.random.random(size=(10, 1)),
52 | policy_version_steps=np.full(shape=(10, 1), fill_value=version),
53 | info=TestEpisodeInfo(hp=np.full(shape=(10, 1), fill_value=hp),
54 | mana=np.full(shape=(10, 1), fill_value=mana)),
55 | info_mask=np.concatenate([np.zeros(
56 | (9, 1)), np.ones((1, 1))], axis=0),
57 | policy_name=np.full(shape=(10, 1), fill_value=policy_name))
58 |
59 |
60 | def make_test_producer(policy_name="test_policy", rank=0):
61 | producer = sample_stream.make_producer(
62 | config_api.SampleStream(config_api.SampleStream.Type.NAME,
63 | stream_name=policy_name),
64 | worker_info=config_api.WorkerInformation("test_exp", "test_run",
65 | "policy", rank, 100),
66 | )
67 | return producer
68 |
69 |
70 | class TestEvalManager(unittest.TestCase):
71 |
72 | def setUp(self) -> None:
73 | IpInferenceClient._shake_hand = mock.Mock()
74 | self.__tmp = tempfile.TemporaryDirectory()
75 | parameter_db.PytorchFilesystemParameterDB.ROOT = os.path.join(
76 | self.__tmp.name, "checkpoints")
77 |
78 | os.environ["WANDB_MODE"] = "disabled"
79 | socket.gethostbyname = mock.MagicMock(return_value="127.0.0.1")
80 | name_resolve.reconfigure("memory", log_events=True)
81 | name_resolve.reconfigure("memory", log_events=True)
82 |
83 | def tearDown(self) -> None:
84 | db = parameter_db.make_db(config_api.ParameterDB(
85 | type_=config_api.ParameterDB.Type.FILESYSTEM),
86 | worker_info=config_api.WorkerInformation(
87 | experiment_name="test_exp",
88 | trial_name="test_run",
89 | ))
90 | try:
91 | db.clear("test_policy")
92 | except FileNotFoundError:
93 | pass
94 |
95 | def test_loginfo(self):
96 | test_parameter_db = parameter_db.make_db(
97 | config_api.ParameterDB(
98 | type_=config_api.ParameterDB.Type.FILESYSTEM),
99 | worker_info=config_api.WorkerInformation(
100 | experiment_name="test_exp",
101 | trial_name="test_run",
102 | ))
103 | try:
104 | test_parameter_db.clear("test_policy")
105 | except FileNotFoundError:
106 | pass
107 | eval_manager = EvalManager()
108 |         eval_manager.configure(make_config("test_policy", "eval"))
109 | producer = make_test_producer(policy_name="eval")
110 | wait_network()
111 | r = eval_manager._poll()
112 | self.assertEqual(r.sample_count, 0)
113 | self.assertEqual(r.batch_count, 0)
114 |
115 | for _ in range(5):
116 | producer.post(random_sample_batch(version=0))
117 | producer.flush()
118 | wait_network()
119 | # Eval manager does not accept sample until the first version is pushed.
120 | for _ in range(5):
121 | r = eval_manager._poll()
122 | self.assertEqual(r.sample_count, 1)
123 | self.assertEqual(r.batch_count, 0)
124 | r = eval_manager._poll()
125 | self.assertEqual(r.sample_count, 0)
126 | self.assertEqual(r.batch_count, 0)
127 |
128 | test_parameter_db.push("test_policy",
129 | get_test_param(version=1),
130 | version="1")
131 | for _ in range(5):
132 | producer.post(random_sample_batch(version=0))
133 | producer.flush()
134 | wait_network()
135 | for _ in range(5):
136 | r = eval_manager._poll()
137 | self.assertEqual(r.sample_count, 1)
138 | self.assertEqual(r.batch_count, 0)
139 | r = eval_manager._poll()
140 | self.assertEqual(r.sample_count, 0)
141 | self.assertEqual(r.batch_count, 0)
142 |
143 | test_parameter_db.push("test_policy", get_test_param(20), version="20")
144 | for _ in range(5):
145 | producer.post(random_sample_batch(version=1))
146 | producer.flush()
147 | wait_network()
148 | for _ in range(4):
149 | r = eval_manager._poll()
150 | self.assertEqual(r.sample_count, 1)
151 | self.assertEqual(r.batch_count, 0)
152 | r = eval_manager._poll()
153 | self.assertEqual(r.sample_count, 1)
154 | self.assertEqual(r.batch_count, 1)
155 | r = eval_manager._poll()
156 | self.assertEqual(r.sample_count, 0)
157 | self.assertEqual(r.batch_count, 0)
158 |
159 | # Evaluation manager loads to version 20. 10 episodes will be logged.
160 | for _ in range(10):
161 | producer.post(random_sample_batch(version=20))
162 | producer.flush()
163 | wait_network()
164 | for __ in range(2):
165 | for _ in range(4):
166 | r = eval_manager._poll()
167 | self.assertEqual(r.sample_count, 1)
168 | self.assertEqual(r.batch_count, 0)
169 | r = eval_manager._poll()
170 | self.assertEqual(r.sample_count, 1)
171 | self.assertEqual(r.batch_count, 1)
172 |
173 | test_parameter_db.push("test_policy", get_test_param(50), version="50")
174 | r = eval_manager._poll()
175 | self.assertEqual(r.sample_count, 0)
176 | self.assertEqual(r.batch_count, 0)
177 |
178 |
179 | if __name__ == '__main__':
180 | unittest.main()
181 |
--------------------------------------------------------------------------------