├── .gitignore
├── LICENSE
├── README.md
├── Town01_robust_rl_paths.png
├── carle_gym
│   ├── __init__.py
│   ├── envs
│   │   ├── __init__.py
│   │   └── carle_env.py
│   └── setup.py
├── config
│   ├── config_A2C_Town01_mlp.json
│   ├── config_A2C_Town02_mlp.json
│   ├── config_A2C_Town03_mlp.json
│   ├── config_DQN_Town01_mlp.json
│   ├── config_DQN_Town02_mlp.json
│   ├── config_DQN_Town03_mlp.json
│   ├── config_PPO_Town01_mlp.json
│   ├── config_PPO_Town02_mlp.json
│   ├── config_PPO_Town03_mlp.json
│   ├── config_QRDQN_Town01_mlp.json
│   ├── config_QRDQN_Town01_robust_rl_greedy.json
│   ├── config_QRDQN_Town01_robust_rl_ssd.json
│   ├── config_QRDQN_Town01_robust_rl_thres_ssd.json
│   ├── config_QRDQN_Town02_mlp.json
│   ├── config_QRDQN_Town02_robust_rl_greedy.json
│   ├── config_QRDQN_Town02_robust_rl_ssd.json
│   ├── config_QRDQN_Town02_robust_rl_thres_ssd.json
│   └── config_QRDQN_Town03_mlp.json
├── observation.png
├── parameters.md
├── run_stable_baselines3.py
├── scripts
│   ├── carla_docker.sh
│   └── extract_maps.py
└── thirdparty
    ├── sb3_contrib
    │   ├── __init__.py
    │   ├── ars
    │   │   ├── __init__.py
    │   │   ├── ars.py
    │   │   └── policies.py
    │   ├── common
    │   │   ├── __init__.py
    │   │   ├── envs
    │   │   │   ├── __init__.py
    │   │   │   └── invalid_actions_env.py
    │   │   ├── maskable
    │   │   │   ├── __init__.py
    │   │   │   ├── buffers.py
    │   │   │   ├── callbacks.py
    │   │   │   ├── distributions.py
    │   │   │   ├── evaluation.py
    │   │   │   ├── policies.py
    │   │   │   └── utils.py
    │   │   ├── utils.py
    │   │   ├── vec_env
    │   │   │   ├── __init__.py
    │   │   │   └── async_eval.py
    │   │   └── wrappers
    │   │       ├── __init__.py
    │   │       ├── action_masker.py
    │   │       └── time_feature.py
    │   ├── local_modifications.txt
    │   ├── ppo_mask
    │   │   ├── __init__.py
    │   │   ├── policies.py
    │   │   └── ppo_mask.py
    │   ├── py.typed
    │   ├── qrdqn
    │   │   ├── __init__.py
    │   │   ├── policies.py
    │   │   └── qrdqn.py
    │   ├── tqc
    │   │   ├── __init__.py
    │   │   ├── policies.py
    │   │   └── tqc.py
    │   ├── trpo
    │   │   ├── __init__.py
    │   │   ├── policies.py
    │   │   └── trpo.py
    │   └── version.txt
    └── stable_baselines3
        ├── __init__.py
        ├── a2c
        │   ├── __init__.py
        │   ├── a2c.py
        │   └── policies.py
        ├── common
        │   ├── __init__.py
        │   ├── atari_wrappers.py
        │   ├── base_class.py
        │   ├── buffers.py
        │   ├── callbacks.py
        │   ├── distributions.py
        │   ├── env_checker.py
        │   ├── env_util.py
        │   ├── envs
        │   │   ├── __init__.py
        │   │   ├── bit_flipping_env.py
        │   │   ├── identity_env.py
        │   │   └── multi_input_envs.py
        │   ├── evaluation.py
        │   ├── logger.py
        │   ├── monitor.py
        │   ├── noise.py
        │   ├── off_policy_algorithm.py
        │   ├── on_policy_algorithm.py
        │   ├── policies.py
        │   ├── preprocessing.py
        │   ├── results_plotter.py
        │   ├── running_mean_std.py
        │   ├── save_util.py
        │   ├── sb2_compat
        │   │   ├── __init__.py
        │   │   └── rmsprop_tf_like.py
        │   ├── torch_layers.py
        │   ├── type_aliases.py
        │   ├── utils.py
        │   └── vec_env
        │       ├── __init__.py
        │       ├── base_vec_env.py
        │       ├── dummy_vec_env.py
        │       ├── stacked_observations.py
        │       ├── subproc_vec_env.py
        │       ├── util.py
        │       ├── vec_check_nan.py
        │       ├── vec_extract_dict_obs.py
        │       ├── vec_frame_stack.py
        │       ├── vec_monitor.py
        │       ├── vec_normalize.py
        │       ├── vec_transpose.py
        │       └── vec_video_recorder.py
        ├── ddpg
        │   ├── __init__.py
        │   ├── ddpg.py
        │   └── policies.py
        ├── dqn
        │   ├── __init__.py
        │   ├── dqn.py
        │   └── policies.py
        ├── her
        │   ├── __init__.py
        │   ├── goal_selection_strategy.py
        │   └── her_replay_buffer.py
        ├── local_modifications.txt
        ├── ppo
        │   ├── __init__.py
        │   ├── policies.py
        │   └── ppo.py
        ├── py.typed
        ├── sac
        │   ├── __init__.py
        │   ├── policies.py
        │   └── sac.py
        ├── td3
        │   ├── __init__.py
        │   ├── policies.py
        │   └── td3.py
        └── version.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | carla_data
2 | Baselines
3 | carle_gym/carle.egg-info
4 | __pycache__
5 | scripts
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Robust Field Autonomy Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Distributional RL for Route Planning
2 |
3 | This repository provides the code for our UR 2023 paper, available [here](https://arxiv.org/abs/2304.09996). We developed the Stochastic Road Network Environment based on the autonomous vehicle simulator [CARLA](https://github.com/carla-simulator/carla), and proposed a Distributional RL based route planner that plans the shortest routes while minimizing the stochasticity of travel time.
4 |
5 |
6 |
7 |
8 |
9 | If you find this repository useful, please cite our paper
10 | ```
11 | @INPROCEEDINGS{10202222,
12 | author={Lin, Xi and Szenher, Paul and Martin, John D. and Englot, Brendan},
13 | booktitle={2023 20th International Conference on Ubiquitous Robots (UR)},
14 | title={Robust Route Planning with Distributional Reinforcement Learning in a Stochastic Road Network Environment},
15 | year={2023},
16 | volume={},
17 | number={},
18 | pages={287-294},
19 | doi={10.1109/UR57808.2023.10202222}}
20 | ```
21 |
22 | ## The Stochastic Road Network Environment
23 |
24 | The Stochastic Road Network Environment is built upon map structure and simulated sensor data originating from CARLA version 0.9.6. Five such maps are available by default, named Town01 through Town05. The graph topology of Town01 and an example observation provided to the learning system are shown below. For detailed information about the Stochastic Road Network Environment, please refer to our paper.
25 |
26 |
27 |
28 |
29 |
30 | ## Map Data Generation
31 |
32 | The map data needed to run experiments on Town01 to Town05 can be downloaded from [here](https://stevens0-my.sharepoint.com/:f:/g/personal/xlin26_stevens_edu/EioIeHjcj_xNnJJc7ziMAUMBmz6fLFFxblYV2JWNHvAcyQ?e=R1UAjR), or you can generate it with the provided scripts by following the steps below.
33 |
34 | 1. Install [NVIDIA Container Runtime](https://nvidia.github.io/nvidia-container-runtime/)
35 | ```
36 | curl -s -L https://nvidia.github.io/nvidia-container-runtime/gpgkey | \
37 | sudo apt-key add -
38 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
39 | curl -s -L https://nvidia.github.io/nvidia-container-runtime/$distribution/nvidia-container-runtime.list | \
40 | sudo tee /etc/apt/sources.list.d/nvidia-container-runtime.list
41 | sudo apt-get update
42 | sudo apt install nvidia-container-runtime
43 |
44 | sudo systemctl daemon-reload
45 | sudo systemctl restart docker
46 | ```
47 |
48 | 2. Clone this git repo and enter the directory.
49 | ```
50 | git clone git@github.com:RobustFieldAutonomyLab/Stochastic_Road_Network.git
51 | cd Stochastic_Road_Network
52 | ```
53 |
54 | 3. Install relevant system dependencies for CARLA Python Library:
55 | ```
56 | sudo apt install libpng16-16 libjpeg8 libtiff5
57 | ```
58 |
59 | 4. Run the Docker script (starts a headless CARLA server in Docker)
60 | ```
61 | sudo scripts/carla_docker.sh -oe
62 | ```
63 |
64 | 5. Run the data generation script:
65 | ```
66 | python scripts/extract_maps.py
67 | ```
68 |
69 | ## Train Route Planning RL agents
70 |
71 | Our proposed planner uses QR-DQN, and we select A2C, PPO and DQN as traditional RL baselines. We provide configuration files in the **config** directory for training RL agents on different maps.
72 | ```
73 | $ python run_stable_baselines3.py -C [config file (required)] -P [number of processes (optional)] -D [cuda device (optional)]
74 | ```
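For example, to train the proposed QR-DQN planner on Town01 with one of the provided robust RL configurations (an illustrative invocation; `-P 1` and `-D 0` are placeholder values, and the exact accepted formats for the process count and CUDA device depend on `run_stable_baselines3.py`, which is not shown here):
```
python run_stable_baselines3.py -C config/config_QRDQN_Town01_robust_rl_greedy.json -P 1 -D 0
```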
75 |
76 | ## Experiment Parameterization
77 | Example configuration files are provided in the **config** directory; see [parameters.md](parameters.md) for detailed explanations of common parameters.
78 |
79 | ## Third Party Libraries
80 | This project uses the implementations of the A2C, PPO, DQN and QR-DQN agents from [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) and [stable-baselines3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib), with some modifications to adapt them to the proposed environment. The provided configuration files also contain agent-specific parameters; please refer to [on_policy_algorithm.py](https://github.com/RobustFieldAutonomyLab/Stochastic_Road_Network/blob/main/thirdparty/stable_baselines3/common/on_policy_algorithm.py) (A2C and PPO) and [off_policy_algorithm.py](https://github.com/RobustFieldAutonomyLab/Stochastic_Road_Network/blob/main/thirdparty/stable_baselines3/common/off_policy_algorithm.py) (DQN and QR-DQN) for further information.
81 |
--------------------------------------------------------------------------------
/Town01_robust_rl_paths.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/Town01_robust_rl_paths.png
--------------------------------------------------------------------------------
/carle_gym/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 | register(
4 | id='carle-v0',
5 | entry_point='carle_gym.envs:CarleEnv',
6 | )
7 |
--------------------------------------------------------------------------------
/carle_gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from carle_gym.envs.carle_env import CarleEnv
2 |
--------------------------------------------------------------------------------
/carle_gym/envs/carle_env.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Sequence, Tuple, Union
3 | import numpy as np
4 |
5 | import gym
6 |
7 | class CarleEnv(gym.Env):
8 |
9 | def __init__(
10 | self,
11 | is_eval_env: bool,
12 | seed: int,
13 | dataset_dir: Union[str, Path],
14 | goal_states: Sequence[int],
15 | reset_state: int,
16 | discount: float,
17 | crosswalk_states: Sequence[int],
18 | agent: str,
19 | network: str,
20 | r_base: float,
21 | r_loopback: float
22 | ):
23 | # PRNG for random rewards.
24 | self._rand = np.random.RandomState(seed)
25 |
26 | # Count timestep and record stochastic reward
27 | self.count = 0
28 |
29 | # Record path if the environment is for evaluation
30 | self.record = True if is_eval_env else False
31 | self.curr_path = []
32 | self.all_paths = []
33 |
34 | # Record quantiles of all state-action pairs (for QR-DQN agent)
35 | self.agent = agent
36 | self.quantiles = []
37 |
38 | # Set state information for goals and resets
39 | self.goal_states = goal_states
40 | self.reset_state = reset_state
41 | self.state = self.reset_state
42 | self.prev_state = None # used for checking self-transition
43 | self.crosswalk_states = crosswalk_states
44 |
45 | # Initialize environment parameters and data
46 | dataset_path = Path(dataset_dir)
47 | self.waypoint_locations = np.loadtxt(
48 | fname=dataset_path / "waypoint_locations.csv", delimiter=","
49 | )
50 | self.transition_matrix = np.loadtxt(
51 | fname=dataset_path / "transition_matrix.csv", dtype=int, delimiter=","
52 | )
53 | self.observations = np.loadtxt(
54 | fname=dataset_path / "observations.csv", delimiter=","
55 | )
56 |
57 | # Define action space and observation space
58 | self.network = network
59 | self.action_space = gym.spaces.Discrete(np.shape(self.transition_matrix)[1])
60 | if self.network == "MlpPolicy":
61 | self.observation_space = gym.spaces.Box(np.zeros((255*255,)),np.ones((255*255,)),dtype=np.float32)
62 | elif self.network == "CnnPolicy":
63 | self.observation_shape = (1,255,255)
64 | self.observation_space = gym.spaces.Box(np.zeros(self.observation_shape),np.ones(self.observation_shape),dtype=np.float32)
65 | else:
66 | raise RuntimeError("The network strucutre is not available")
67 |
68 | # Initialize transition counter
69 | self.transition_counts = np.zeros_like(self.transition_matrix)
70 |
71 | # Set discount factor
72 | self.discount = discount
73 |
74 | # Set rewards values
75 | self.r_base = r_base
76 | self.r_loopback = r_loopback
77 | self.rewards = -self.r_base * np.ones(len(self.transition_matrix))
78 | self.rewards[np.array(goal_states)] = 0
79 |
80 | def step(self, action: int) -> Tuple[np.ndarray, float, bool, dict]:
81 | """Apply the given action and transition to the next state."""
82 | self.prev_state = self.state
83 | self.transition_counts[self.state, action] += 1
84 | self.state = self.transition_matrix[self.state, action]
85 | self.count += 1
86 |
87 | if self.record:
88 | self.curr_path.append(self.state)
89 |
90 | return self.get_obs(), self.get_reward(), self.get_done(), self.get_state()
91 |
92 | def reset(self) -> np.ndarray:
93 | """Reset the environment and save path if the environment is for evaluation"""
94 | self.transition_counts = np.zeros_like(self.transition_matrix)
95 | self.state = self.reset_state
96 | self.prev_state = None
97 | self.count = 0
98 |
99 | if self.record:
100 | if np.shape(self.curr_path)[0] != 0:
101 | self.all_paths.append(self.curr_path)
102 | self.curr_path = []
103 |
104 | return self.get_obs()
105 |
106 | def get_obs_at_state(self, state:int) -> np.ndarray:
107 | """Returns the observation image (ground) for a given state."""
108 | scans = np.reshape(self.observations[state],(255,255,2))
109 | ground = scans[:,:,1]
110 | if self.network == "MlpPolicy":
111 | return np.array(ground.flatten())
112 | elif self.network == "CnnPolicy":
113 | return np.array([ground])
114 | else:
115 | raise RuntimeError("The network strucutre is not available")
116 |
117 | def get_obs(self) -> np.ndarray:
118 | """Returns the observation image (ground) for the current state."""
119 | #scans = np.reshape(self.observations[self.state],(255,255,2))
120 | #ground = scans[:,:,1]
121 | #return np.array(ground.flatten())
122 | return self.get_obs_at_state(self.state)
123 |
124 | def save_quantiles(self, quantiles:np.ndarray) -> None:
125 | """Save quantiles of all state-action pairs (for QR-DQN agent)"""
126 | assert self.agent == "QRDQN", "save_quantiles is only available to the QR-DQN agent"
127 | self.quantiles.append(quantiles)
128 |
129 | def get_quantiles(self) -> np.ndarray:
130 | """Get quantiles of all state action pair (for QR-DQN agent)"""
131 | assert self.agent == "QRDQN", "get_quantiles is only avaible to the QR-DQN agent"
132 | return np.array(self.quantiles)
133 |
134 | def get_reward(self, ssd_thres:int=15) -> float:
135 | """Returns the reward for reaching the current state."""
136 | # Penalize the self-transition action
137 | if self.prev_state == self.state:
138 | # return self.rewards[self.state] - ssd_thres - 3
139 | return self.rewards[self.state] - self.r_loopback
140 |
141 | # Add noise at the simulated cross walks.
142 | if self.state in self.crosswalk_states:
143 | # Use vonmises distribution as stand-in for wrapped gaussian
144 | # - Interval is bounded from -2*r_base to 0 with below parameters
145 | # - kappa parameter is inversely proportional to variance
146 | # - see:
147 | # https://numpy.org/devdocs/reference/random/generated/numpy.random.vonmises.html
148 | return self.r_base*(self._rand.vonmises(mu=0, kappa=1) / np.pi) - self.r_base
149 |
150 | # Deterministic traffic penalty otherwise.
151 | return self.rewards[self.state]
152 |
153 | def get_done(self) -> bool:
154 | """Returns a done flag if the goal is reached."""
155 | return self.state in self.goal_states
156 |
157 | def get_state(self) -> dict:
158 | """Return current state id"""
159 | info = {"state_id":self.state}
160 | return info
161 |
162 | def get_count(self) -> int:
163 | """Return the timestep count since the last reset"""
164 | return self.count
165 |
166 | def get_all_paths(self) -> list:
167 | """Return all paths in evaluation"""
168 | self.reset()
169 | return self.all_paths
170 |
171 | def get_num_of_states(self) -> int:
172 | return np.shape(self.transition_matrix)[0]
173 |
--------------------------------------------------------------------------------
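A minimal usage sketch (not part of the repository) showing how the environment above could be instantiated through the `carle-v0` id registered in `carle_gym/__init__.py`. The keyword arguments mirror `CarleEnv.__init__`; the Town01 state indices are taken from the provided config files, while the `dataset_dir` layout is an assumption about how the extracted map data is stored, since `run_stable_baselines3.py` is not shown in this dump.
```
# Illustrative sketch only: build the Stochastic Road Network environment directly.
# dataset_dir is assumed to contain waypoint_locations.csv, transition_matrix.csv
# and observations.csv for the chosen town.
import gym
import carle_gym  # noqa: F401  (importing the package registers 'carle-v0')

env = gym.make(
    "carle-v0",
    is_eval_env=False,
    seed=0,
    dataset_dir="carla_data/Town01",   # assumed location of the extracted map data
    goal_states=[112, 113],            # Town01 goal states from the provided configs
    reset_state=209,                   # Town01 start state from the provided configs
    discount=0.99,
    crosswalk_states=[261],
    agent="QRDQN",
    network="MlpPolicy",
    r_base=1,
    r_loopback=0,
)

obs = env.reset()                       # flattened 255x255 ground observation
obs, reward, done, info = env.step(env.action_space.sample())
print(info["state_id"], reward, done)
```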
/carle_gym/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name='carle',
4 | version='0.0.1',
5 | install_requires=['gym']
6 | )
7 |
--------------------------------------------------------------------------------
/config/config_A2C_Town01_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "A2C",
9 | "discount": 0.99,
10 | "alpha": [1e-04],
11 | "buffer_size": 64
12 | },
13 | "environment":{
14 | "map_name": "Town01",
15 | "data_dir": "carla_data",
16 | "start_state": 209,
17 | "goal_states": [112, 113],
18 | "crosswalk_states": [261],
19 | "r_base": 1,
20 | "r_loopback": 0
21 | },
22 | "policy": "MlpPolicy",
23 | "save_dir": "Baselines"
24 | }
25 |
--------------------------------------------------------------------------------
/config/config_A2C_Town02_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "A2C",
9 | "discount": 0.99,
10 | "alpha": [8e-05],
11 | "buffer_size": 64
12 | },
13 | "environment":{
14 | "map_name": "Town02",
15 | "data_dir": "carla_data",
16 | "start_state": 16,
17 | "goal_states": [89,90],
18 | "crosswalk_states": [96],
19 | "r_base": 1,
20 | "r_loopback": 0
21 | },
22 | "policy": "MlpPolicy",
23 | "save_dir": "Baselines"
24 | }
25 |
--------------------------------------------------------------------------------
/config/config_A2C_Town03_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 2000000
6 | },
7 | "agent":{
8 | "name": "A2C",
9 | "discount": 0.99,
10 | "alpha": [1e-04],
11 | "buffer_size": 64
12 | },
13 | "environment":{
14 | "map_name": "Town03",
15 | "data_dir": "carla_data",
16 | "start_state": 546,
17 | "goal_states": [641, 642],
18 | "crosswalk_states": [585],
19 | "r_base": 1,
20 | "r_loopback": 0
21 | },
22 | "policy": "MlpPolicy",
23 | "save_dir": "Baselines"
24 | }
25 |
--------------------------------------------------------------------------------
/config/config_DQN_Town01_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "DQN",
9 | "discount": 0.99,
10 | "alpha": [1e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64
13 | },
14 | "environment":{
15 | "map_name": "Town01",
16 | "data_dir": "carla_data",
17 | "start_state": 209,
18 | "goal_states": [112, 113],
19 | "crosswalk_states": [261],
20 | "r_base": 1,
21 | "r_loopback": 0
22 | },
23 | "policy": "MlpPolicy",
24 | "save_dir": "Baselines"
25 | }
26 |
--------------------------------------------------------------------------------
/config/config_DQN_Town02_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "DQN",
9 | "discount": 0.99,
10 | "alpha": [5e-06],
11 | "buffer_size": 2048,
12 | "batch_size": 64
13 | },
14 | "environment":{
15 | "map_name": "Town02",
16 | "data_dir": "carla_data",
17 | "start_state": 16,
18 | "goal_states": [89,90],
19 | "crosswalk_states": [96],
20 | "r_base": 1,
21 | "r_loopback": 0
22 | },
23 | "policy": "MlpPolicy",
24 | "save_dir": "Baselines"
25 | }
26 |
--------------------------------------------------------------------------------
/config/config_DQN_Town03_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 2000000
6 | },
7 | "agent":{
8 | "name": "DQN",
9 | "discount": 0.99,
10 | "alpha": [3e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64
13 | },
14 | "environment":{
15 | "map_name": "Town03",
16 | "data_dir": "carla_data",
17 | "start_state": 546,
18 | "goal_states": [641, 642],
19 | "crosswalk_states": [585],
20 | "r_base": 1,
21 | "r_loopback": 0
22 | },
23 | "policy": "MlpPolicy",
24 | "save_dir": "Baselines"
25 | }
26 |
--------------------------------------------------------------------------------
/config/config_PPO_Town01_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "PPO",
9 | "discount": 0.99,
10 | "alpha": [2e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_epochs":1
14 | },
15 | "environment":{
16 | "map_name": "Town01",
17 | "data_dir": "carla_data",
18 | "start_state": 209,
19 | "goal_states": [112, 113],
20 | "crosswalk_states": [261],
21 | "r_base": 1,
22 | "r_loopback": 0
23 | },
24 | "policy": "MlpPolicy",
25 | "save_dir": "Baselines"
26 | }
27 |
--------------------------------------------------------------------------------
/config/config_PPO_Town02_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "PPO",
9 | "discount": 0.99,
10 | "alpha": [3e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_epochs":1
14 | },
15 | "environment":{
16 | "map_name": "Town02",
17 | "data_dir": "carla_data",
18 | "start_state": 16,
19 | "goal_states": [89, 90],
20 | "crosswalk_states": [96],
21 | "r_base": 1,
22 | "r_loopback": 0
23 | },
24 | "policy": "MlpPolicy",
25 | "save_dir": "Baselines"
26 | }
27 |
--------------------------------------------------------------------------------
/config/config_PPO_Town03_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 2000000
6 | },
7 | "agent":{
8 | "name": "PPO",
9 | "discount": 0.99,
10 | "alpha": [2e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_epochs":1
14 | },
15 | "environment":{
16 | "map_name": "Town03",
17 | "data_dir": "carla_data",
18 | "start_state": 546,
19 | "goal_states": [641, 642],
20 | "crosswalk_states": [585],
21 | "r_base": 1,
22 | "r_loopback": 0
23 | },
24 | "policy": "MlpPolicy",
25 | "save_dir": "Baselines"
26 | }
27 |
--------------------------------------------------------------------------------
/config/config_QRDQN_Town01_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 2000000
6 | },
7 | "agent":{
8 | "name": "QRDQN",
9 | "discount": 0.99,
10 | "alpha": [6e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_quantiles": 4,
14 | "epsilon": 0.1,
15 | "eps_fraction": 0.02,
16 | "eval_policy": "Greedy"
17 | },
18 | "environment":{
19 | "map_name": "Town01",
20 | "data_dir": "carla_data",
21 | "start_state": 209,
22 | "goal_states": [112, 113],
23 | "crosswalk_states": [261],
24 | "r_base": 1,
25 | "r_loopback": 0
26 | },
27 | "policy": "MlpPolicy",
28 | "save_dir": "Baselines"
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/config/config_QRDQN_Town01_robust_rl_greedy.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "QRDQN",
9 | "discount": 0.99,
10 | "alpha": [5e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_quantiles": 4,
14 | "epsilon": 0.1,
15 | "eps_fraction": 0.02,
16 | "eval_policy": "Greedy"
17 | },
18 | "environment":{
19 | "map_name": "Town01",
20 | "data_dir": "carla_data",
21 | "start_state": 209,
22 | "goal_states": [112, 113],
23 | "crosswalk_states": [261],
24 | "r_base": 3,
25 | "r_loopback": 18
26 | },
27 | "policy": "CnnPolicy",
28 | "save_dir": "Baselines"
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/config/config_QRDQN_Town01_robust_rl_ssd.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "QRDQN",
9 | "discount": 0.99,
10 | "alpha": [5e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_quantiles": 4,
14 | "epsilon": 0.1,
15 | "eps_fraction": 0.02,
16 | "eval_policy": "SSD"
17 | },
18 | "environment":{
19 | "map_name": "Town01",
20 | "data_dir": "carla_data",
21 | "start_state": 209,
22 | "goal_states": [112, 113],
23 | "crosswalk_states": [261],
24 | "r_base": 3,
25 | "r_loopback": 18
26 | },
27 | "policy": "CnnPolicy",
28 | "save_dir": "Baselines"
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/config/config_QRDQN_Town01_robust_rl_thres_ssd.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "QRDQN",
9 | "discount": 0.99,
10 | "alpha": [5e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_quantiles": 4,
14 | "epsilon": 0.1,
15 | "eps_fraction": 0.02,
16 | "eval_policy": "Thresholded_SSD"
17 | },
18 | "environment":{
19 | "map_name": "Town01",
20 | "data_dir": "carla_data",
21 | "start_state": 209,
22 | "goal_states": [112, 113],
23 | "crosswalk_states": [261],
24 | "r_base": 3,
25 | "r_loopback": 18
26 | },
27 | "policy": "CnnPolicy",
28 | "save_dir": "Baselines"
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/config/config_QRDQN_Town02_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 2000000
6 | },
7 | "agent":{
8 | "name": "QRDQN",
9 | "discount": 0.99,
10 | "alpha": [6e-05],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_quantiles": 4,
14 | "epsilon": 0.1,
15 | "eps_fraction": 0.02,
16 | "eval_policy": "Greedy"
17 | },
18 | "environment":{
19 | "map_name": "Town02",
20 | "data_dir": "carla_data",
21 | "start_state": 16,
22 | "goal_states": [89, 90],
23 | "crosswalk_states": [96],
24 | "r_base": 1,
25 | "r_loopback": 0
26 | },
27 | "policy": "MlpPolicy",
28 | "save_dir": "Baselines"
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/config/config_QRDQN_Town02_robust_rl_greedy.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "QRDQN",
9 | "discount": 0.99,
10 | "alpha": [5e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_quantiles": 4,
14 | "epsilon": 0.1,
15 | "eps_fraction": 0.02,
16 | "eval_policy": "Greedy"
17 | },
18 | "environment":{
19 | "map_name": "Town02",
20 | "data_dir": "/home/rfal/code/tro/data/carla_data",
21 | "start_state": 16,
22 | "goal_states": [89, 90],
23 | "crosswalk_states": [96],
24 | "r_base": 3,
25 | "r_loopback": 18
26 | },
27 | "policy": "CnnPolicy",
28 | "save_dir": "/home/rfal/Stochastic_Road_Network/Baselines"
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/config/config_QRDQN_Town02_robust_rl_ssd.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "QRDQN",
9 | "discount": 0.99,
10 | "alpha": [5e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_quantiles": 4,
14 | "epsilon": 0.1,
15 | "eps_fraction": 0.02,
16 | "eval_policy": "SSD"
17 | },
18 | "environment":{
19 | "map_name": "Town02",
20 | "data_dir": "/home/rfal/code/tro/data/carla_data",
21 | "start_state": 16,
22 | "goal_states": [89, 90],
23 | "crosswalk_states": [96],
24 | "r_base": 3,
25 | "r_loopback": 18
26 | },
27 | "policy": "CnnPolicy",
28 | "save_dir": "/home/rfal/Stochastic_Road_Network/Baselines"
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/config/config_QRDQN_Town02_robust_rl_thres_ssd.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 1000000
6 | },
7 | "agent":{
8 | "name": "QRDQN",
9 | "discount": 0.99,
10 | "alpha": [5e-04],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_quantiles": 4,
14 | "epsilon": 0.1,
15 | "eps_fraction": 0.02,
16 | "eval_policy": "Thresholded_SSD",
17 | "ssd_thres": 15
18 | },
19 | "environment":{
20 | "map_name": "Town02",
21 | "data_dir": "/home/rfal/code/tro/data/carla_data",
22 | "start_state": 16,
23 | "goal_states": [89, 90],
24 | "crosswalk_states": [96],
25 | "r_base": 3,
26 | "r_loopback": 18
27 | },
28 | "policy": "CnnPolicy",
29 | "save_dir": "/home/rfal/Stochastic_Road_Network/Baselines"
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/config/config_QRDQN_Town03_mlp.json:
--------------------------------------------------------------------------------
1 | {
2 | "base":{
3 | "seed": [0],
4 | "eval_freq": 10000,
5 | "num_timesteps": 2000000
6 | },
7 | "agent":{
8 | "name": "QRDQN",
9 | "discount": 0.99,
10 | "alpha": [5e-05],
11 | "buffer_size": 2048,
12 | "batch_size": 64,
13 | "n_quantiles": 4,
14 | "epsilon": 0.1,
15 | "eps_fraction": 0.02,
16 | "eval_policy": "Greedy"
17 | },
18 | "environment":{
19 | "map_name": "Town03",
20 | "data_dir": "carla_data",
21 | "start_state": 546,
22 | "goal_states": [641, 642],
23 | "crosswalk_states": [585],
24 | "r_base": 1,
25 | "r_loopback": 0
26 | },
27 | "policy": "MlpPolicy",
28 | "save_dir": "Baselines"
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/observation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/observation.png
--------------------------------------------------------------------------------
/parameters.md:
--------------------------------------------------------------------------------
1 | # Experiment Config File parameters
2 |
3 | - **`base`**: Trial parameters.
4 | - **`seed`**: RNG Seeds.
5 | - **`eval_freq`**: Number of training timesteps between two evaluations (per trial).
6 | - **`num_timesteps`**: Number of total training timesteps (per trial).
7 | - **`agent`**: Agent parameters.
8 | - **`name`**: Method (A2C, PPO, DQN, or QRDQN).
9 | - **`discount`**: Discount factor (gamma).
10 | - **`alpha`**: Learning step size (alpha).
11 | - **`environment`**: Environment Parameters.
12 | - **`map_name`**: Name of the target map.
13 | - **`data_dir`**: Directory where environment data is stored.
14 | - **`start_state`**: Start state.
15 | - **`goal_states`**: List of goal states.
16 | - **`crosswalk_states`**: List of stochastic crosswalk states.
17 | - **`policy`**: Agent network structure ("MlpPolicy" or "CnnPolicy").
18 | - **`save_dir`**: Directory where experiment data is stored.
19 |
--------------------------------------------------------------------------------
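A small illustration of this config schema (a sketch, not part of the repository): loading one of the shipped files and reading the fields documented above. The field names come from the config files themselves; how `run_stable_baselines3.py` actually consumes them is not shown in this dump.
```
# Hypothetical loader sketch: inspect a provided experiment configuration.
import json

with open("config/config_QRDQN_Town01_robust_rl_greedy.json") as f:
    config = json.load(f)

print(config["agent"]["name"])               # "QRDQN"
print(config["environment"]["goal_states"])  # [112, 113]
print(config["policy"], config["save_dir"])  # "CnnPolicy" "Baselines"
```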
/scripts/carla_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ### carla_docker.sh -- run Carla environment in Docker container
4 | ###
5 | ### Usage:
6 | ### ./carla_docker.sh [options]
7 | ###
8 | ### Options:
9 | ### -h Show this message
10 | ### -o Offscreen mode: disable CARLA server window
11 | ### -e Copy the CARLA Python egg file from the Docker image into the scripts directory
12 |
13 | usage() {
14 | # Use the file header as a usage guide
15 | # Reference: https://samizdat.dev/help-message-for-shell-scripts/
16 | sed -rn 's/^### ?/ /;T;p' "$0"
17 | }
18 |
19 | version() {
20 | # Convert decimal version numbers into integers suitable for comparison
21 | echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'
22 | }
23 |
24 |
25 | # Get directory of scripts
26 | script_dir=$(dirname "$(readlink -f "$0")")
27 |
28 | # Set name of CARLA parameters
29 | carla_image="carlasim/carla:0.9.6"
30 | carla_egg_dir="/home/carla/PythonAPI/carla/dist"
31 |
32 | # Set nvidia container runtime flag (dependent on docker version)
33 | if [ "$(version "$(docker version -f '{{.Server.Version}}')")" -gt "$(version 19.03.0)" ]; then
34 | nvidia_flag="--gpus all"
35 | else
36 | nvidia_flag="--runtime=nvidia"
37 | fi
38 |
39 | # Set SDL_VIDEODRIVER env var to make CARLA server visible
40 | sdl_driver=x11
41 |
42 | # Parse arguments with posix-compatible getopt
43 | cmdargs=$(GETOPT_COMPATIBLE=1 getopt hoe "$@")
44 | eval set -- "$cmdargs"
45 |
46 | while [ "$#" -ne 0 ] ; do
47 | case "$1" in
48 | -h|--help)
49 | usage
50 | exit 1 ;;
51 | -o|--offscreen)
52 | sdl_driver=offscreen
53 | shift 1
54 | ;;
55 | -e|--egg)
56 | copy_egg=yes
57 | shift 1
58 | ;;
59 | --)
60 | break ;;
61 | *) echo "Argument parse error, could not parse $1" >&2
62 | exit 1 ;;
63 | esac
64 | done
65 |
66 | # If the CARLA egg file does not exist, copy it from the docker image
67 | if [ -n "$copy_egg" ]; then
68 | echo "CARLA egg file requested, copying from \"$carla_image\" image."
69 | egg_file=$(docker run --rm "$carla_image" bash -c "ls ${carla_egg_dir}/*py3*.egg")
70 | echo "Egg file $egg_file found, copying..."
71 | docker cp "$(docker create --rm $carla_image)":"${egg_file}" "$script_dir"
72 | fi
73 |
74 | echo "Running $carla_image"
75 | docker run \
76 | -p 2000-2002:2000-2002 \
77 | -v /tmp/.X11-unix:/tmp/.X11-unix \
78 | -e DISPLAY="$DISPLAY" \
79 | -e SDL_VIDEODRIVER="$sdl_driver"\
80 | -it \
81 | --rm \
82 | $nvidia_flag \
83 | $carla_image \
84 | ./CarlaUE4.sh -opengl
85 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from sb3_contrib.ars import ARS
4 | from sb3_contrib.ppo_mask import MaskablePPO
5 | from sb3_contrib.qrdqn import QRDQN
6 | from sb3_contrib.tqc import TQC
7 | from sb3_contrib.trpo import TRPO
8 |
9 | # Read version from file
10 | version_file = os.path.join(os.path.dirname(__file__), "version.txt")
11 | with open(version_file, "r") as file_handler:
12 | __version__ = file_handler.read().strip()
13 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/ars/__init__.py:
--------------------------------------------------------------------------------
1 | from sb3_contrib.ars.ars import ARS
2 | from sb3_contrib.ars.policies import LinearPolicy, MlpPolicy
3 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/ars/policies.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Type
2 |
3 | import gym
4 | import torch as th
5 | from stable_baselines3.common.policies import BasePolicy, register_policy
6 | from stable_baselines3.common.preprocessing import get_action_dim
7 | from stable_baselines3.common.torch_layers import create_mlp
8 | from torch import nn
9 |
10 |
11 | class ARSPolicy(BasePolicy):
12 | """
13 | Policy network for ARS.
14 |
15 | :param observation_space: The observation space of the environment
16 | :param action_space: The action space of the environment
17 | :param net_arch: Network architecture, defaults to a 2 layers MLP with 64 hidden nodes.
18 | :param activation_fn: Activation function
19 | :param squash_output: For continuous actions, whether the output is squashed
20 | or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
21 | """
22 |
23 | def __init__(
24 | self,
25 | observation_space: gym.spaces.Space,
26 | action_space: gym.spaces.Space,
27 | net_arch: Optional[List[int]] = None,
28 | activation_fn: Type[nn.Module] = nn.ReLU,
29 | squash_output: bool = True,
30 | ):
31 |
32 | super().__init__(
33 | observation_space,
34 | action_space,
35 | squash_output=isinstance(action_space, gym.spaces.Box) and squash_output,
36 | )
37 |
38 | if net_arch is None:
39 | net_arch = [64, 64]
40 |
41 | self.net_arch = net_arch
42 | self.features_extractor = self.make_features_extractor()
43 | self.features_dim = self.features_extractor.features_dim
44 | self.activation_fn = activation_fn
45 |
46 | if isinstance(action_space, gym.spaces.Box):
47 | action_dim = get_action_dim(action_space)
48 | actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True)
49 | elif isinstance(action_space, gym.spaces.Discrete):
50 | actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn)
51 | else:
52 | raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
53 |
54 | self.action_net = nn.Sequential(*actor_net)
55 |
56 | def _get_constructor_parameters(self) -> Dict[str, Any]:
57 | # data = super()._get_constructor_parameters() this adds normalize_images, which we don't support...
58 | data = dict(
59 | observation_space=self.observation_space,
60 | action_space=self.action_space,
61 | net_arch=self.net_arch,
62 | activation_fn=self.activation_fn,
63 | )
64 | return data
65 |
66 | def forward(self, obs: th.Tensor) -> th.Tensor:
67 |
68 | features = self.extract_features(obs)
69 | if isinstance(self.action_space, gym.spaces.Box):
70 | return self.action_net(features)
71 | elif isinstance(self.action_space, gym.spaces.Discrete):
72 | logits = self.action_net(features)
73 | return th.argmax(logits, dim=1)
74 | else:
75 | raise NotImplementedError()
76 |
77 | def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
78 | # Non deterministic action does not really make sense for ARS, we ignore this parameter for now..
79 | return self(observation)
80 |
81 |
82 | class ARSLinearPolicy(ARSPolicy):
83 | """
84 | Linear policy network for ARS.
85 |
86 | :param observation_space: The observation space of the environment
87 | :param action_space: The action space of the environment
88 | :param with_bias: With or without bias on the output
89 | :param squash_output: For continuous actions, whether the output is squashed
90 | or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
91 | """
92 |
93 | def __init__(
94 | self,
95 | observation_space: gym.spaces.Space,
96 | action_space: gym.spaces.Space,
97 | with_bias: bool = False,
98 | squash_output: bool = False,
99 | ):
100 |
101 | super().__init__(observation_space, action_space, squash_output=squash_output)
102 |
103 | if isinstance(action_space, gym.spaces.Box):
104 | action_dim = get_action_dim(action_space)
105 | self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias)
106 | if squash_output:
107 | self.action_net = nn.Sequential(self.action_net, nn.Tanh())
108 | elif isinstance(action_space, gym.spaces.Discrete):
109 | self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias)
110 | else:
111 | raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
112 |
113 |
114 | MlpPolicy = ARSPolicy
115 | LinearPolicy = ARSLinearPolicy
116 |
117 |
118 | register_policy("LinearPolicy", LinearPolicy)
119 | register_policy("MlpPolicy", MlpPolicy)
120 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/sb3_contrib/common/__init__.py
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from sb3_contrib.common.envs.invalid_actions_env import (
2 | InvalidActionEnvDiscrete,
3 | InvalidActionEnvMultiBinary,
4 | InvalidActionEnvMultiDiscrete,
5 | )
6 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/envs/invalid_actions_env.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import numpy as np
4 | from gym import spaces
5 | from stable_baselines3.common.envs import IdentityEnv
6 |
7 |
8 | class InvalidActionEnvDiscrete(IdentityEnv):
9 | """
10 | Identity env with a discrete action space. Supports action masking.
11 | """
12 |
13 | def __init__(
14 | self,
15 | dim: Optional[int] = None,
16 | ep_length: int = 100,
17 | n_invalid_actions: int = 0,
18 | ):
19 | if dim is None:
20 | dim = 1
21 | assert n_invalid_actions < dim, f"Too many invalid actions: {n_invalid_actions} < {dim}"
22 |
23 | space = spaces.Discrete(dim)
24 | self.n_invalid_actions = n_invalid_actions
25 | self.possible_actions = np.arange(space.n)
26 | self.invalid_actions: List[int] = []
27 | super().__init__(space=space, ep_length=ep_length)
28 |
29 | def _choose_next_state(self) -> None:
30 | self.state = self.action_space.sample()
31 | # Randomly choose invalid actions that are not the current state
32 | potential_invalid_actions = [i for i in self.possible_actions if i != self.state]
33 | self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
34 |
35 | def action_masks(self) -> List[bool]:
36 | return [action not in self.invalid_actions for action in self.possible_actions]
37 |
38 |
39 | class InvalidActionEnvMultiDiscrete(IdentityEnv):
40 | """
41 | Identity env with a multidiscrete action space. Supports action masking.
42 | """
43 |
44 | def __init__(
45 | self,
46 | dims: Optional[List[int]] = None,
47 | ep_length: int = 100,
48 | n_invalid_actions: int = 0,
49 | ):
50 | if dims is None:
51 | dims = [1, 1]
52 |
53 | if n_invalid_actions > sum(dims) - len(dims):
54 | raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {sum(dims) - len(dims)}")
55 |
56 | space = spaces.MultiDiscrete(dims)
57 | self.n_invalid_actions = n_invalid_actions
58 | self.possible_actions = np.arange(sum(dims))
59 | self.invalid_actions: List[int] = []
60 | super().__init__(space=space, ep_length=ep_length)
61 |
62 | def _choose_next_state(self) -> None:
63 | self.state = self.action_space.sample()
64 |
65 | converted_state: List[int] = []
66 | running_total = 0
67 | for i in range(len(self.action_space.nvec)):
68 | converted_state.append(running_total + self.state[i])
69 | running_total += self.action_space.nvec[i]
70 |
71 | # Randomly choose invalid actions that are not the current state
72 | potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
73 | self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
74 |
75 | def action_masks(self) -> List[bool]:
76 | return [action not in self.invalid_actions for action in self.possible_actions]
77 |
78 |
79 | class InvalidActionEnvMultiBinary(IdentityEnv):
80 | """
81 | Identity env with a multibinary action space. Supports action masking.
82 | """
83 |
84 | def __init__(
85 | self,
86 | dims: Optional[int] = None,
87 | ep_length: int = 100,
88 | n_invalid_actions: int = 0,
89 | ):
90 | if dims is None:
91 | dims = 1
92 |
93 | if n_invalid_actions > dims:
94 | raise ValueError(f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}")
95 |
96 | space = spaces.MultiBinary(dims)
97 | self.n_invalid_actions = n_invalid_actions
98 | self.possible_actions = np.arange(2 * dims)
99 | self.invalid_actions: List[int] = []
100 | super().__init__(space=space, ep_length=ep_length)
101 |
102 | def _choose_next_state(self) -> None:
103 | self.state = self.action_space.sample()
104 |
105 | converted_state: List[int] = []
106 | running_total = 0
107 | for i in range(self.action_space.n):
108 | converted_state.append(running_total + self.state[i])
109 | running_total += 2
110 |
111 | # Randomly choose invalid actions that are not the current state
112 | potential_invalid_actions = [i for i in self.possible_actions if i not in converted_state]
113 | self.invalid_actions = np.random.choice(potential_invalid_actions, self.n_invalid_actions, replace=False)
114 |
115 | def action_masks(self) -> List[bool]:
116 | return [action not in self.invalid_actions for action in self.possible_actions]
117 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/maskable/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/sb3_contrib/common/maskable/__init__.py
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/maskable/buffers.py:
--------------------------------------------------------------------------------
1 | from typing import Generator, NamedTuple, Optional, Union
2 |
3 | import numpy as np
4 | import torch as th
5 | from gym import spaces
6 | from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
7 | from stable_baselines3.common.type_aliases import TensorDict
8 | from stable_baselines3.common.vec_env import VecNormalize
9 |
10 |
11 | class MaskableRolloutBufferSamples(NamedTuple):
12 | observations: th.Tensor
13 | actions: th.Tensor
14 | old_values: th.Tensor
15 | old_log_prob: th.Tensor
16 | advantages: th.Tensor
17 | returns: th.Tensor
18 | action_masks: th.Tensor
19 |
20 |
21 | class MaskableDictRolloutBufferSamples(MaskableRolloutBufferSamples):
22 | observations: TensorDict
23 | actions: th.Tensor
24 | old_values: th.Tensor
25 | old_log_prob: th.Tensor
26 | advantages: th.Tensor
27 | returns: th.Tensor
28 | action_masks: th.Tensor
29 |
30 |
31 | class MaskableRolloutBuffer(RolloutBuffer):
32 | """
33 | Rollout buffer that also stores the invalid action masks associated with each observation.
34 |
35 | :param buffer_size: Max number of element in the buffer
36 | :param observation_space: Observation space
37 | :param action_space: Action space
38 | :param device:
39 | :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
40 | Equivalent to classic advantage when set to 1.
41 | :param gamma: Discount factor
42 | :param n_envs: Number of parallel environments
43 | """
44 |
45 | def __init__(self, *args, **kwargs):
46 | self.action_masks = None
47 | super().__init__(*args, **kwargs)
48 |
49 | def reset(self) -> None:
50 | if isinstance(self.action_space, spaces.Discrete):
51 | mask_dims = self.action_space.n
52 | elif isinstance(self.action_space, spaces.MultiDiscrete):
53 | mask_dims = sum(self.action_space.nvec)
54 | elif isinstance(self.action_space, spaces.MultiBinary):
55 | mask_dims = 2 * self.action_space.n # One mask per binary outcome
56 | else:
57 | raise ValueError(f"Unsupported action space {type(self.action_space)}")
58 |
59 | self.mask_dims = mask_dims
60 | self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)
61 |
62 | super().reset()
63 |
64 | def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
65 | """
66 | :param action_masks: Masks applied to constrain the choice of possible actions.
67 | """
68 | if action_masks is not None:
69 | self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))
70 |
71 | super().add(*args, **kwargs)
72 |
73 | def get(self, batch_size: Optional[int] = None) -> Generator[MaskableRolloutBufferSamples, None, None]:
74 | assert self.full, ""
75 | indices = np.random.permutation(self.buffer_size * self.n_envs)
76 | # Prepare the data
77 | if not self.generator_ready:
78 | for tensor in [
79 | "observations",
80 | "actions",
81 | "values",
82 | "log_probs",
83 | "advantages",
84 | "returns",
85 | "action_masks",
86 | ]:
87 | self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
88 | self.generator_ready = True
89 |
90 | # Return everything, don't create minibatches
91 | if batch_size is None:
92 | batch_size = self.buffer_size * self.n_envs
93 |
94 | start_idx = 0
95 | while start_idx < self.buffer_size * self.n_envs:
96 | yield self._get_samples(indices[start_idx : start_idx + batch_size])
97 | start_idx += batch_size
98 |
99 | def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableRolloutBufferSamples:
100 | data = (
101 | self.observations[batch_inds],
102 | self.actions[batch_inds],
103 | self.values[batch_inds].flatten(),
104 | self.log_probs[batch_inds].flatten(),
105 | self.advantages[batch_inds].flatten(),
106 | self.returns[batch_inds].flatten(),
107 | self.action_masks[batch_inds].reshape(-1, self.mask_dims),
108 | )
109 | return MaskableRolloutBufferSamples(*map(self.to_torch, data))
110 |
111 |
112 | class MaskableDictRolloutBuffer(DictRolloutBuffer):
113 | """
114 | Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
115 | Extends the RolloutBuffer to use dictionary observations
116 |
117 | It corresponds to ``buffer_size`` transitions collected
118 | using the current policy.
119 | This experience will be discarded after the policy update.
120 | In order to use PPO objective, we also store the current value of each state
121 | and the log probability of each taken action.
122 |
123 | The term rollout here refers to the model-free notion and should not
124 | be used with the concept of rollout used in model-based RL or planning.
125 | Hence, it is only involved in policy and value function training but not action selection.
126 |
127 | :param buffer_size: Max number of element in the buffer
128 | :param observation_space: Observation space
129 | :param action_space: Action space
130 | :param device:
131 | :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
132 | Equivalent to classic advantage when set to 1.
133 | :param gamma: Discount factor
134 | :param n_envs: Number of parallel environments
135 | """
136 |
137 | def __init__(
138 | self,
139 | buffer_size: int,
140 | observation_space: spaces.Space,
141 | action_space: spaces.Space,
142 | device: Union[th.device, str] = "cpu",
143 | gae_lambda: float = 1,
144 | gamma: float = 0.99,
145 | n_envs: int = 1,
146 | ):
147 | self.action_masks = None
148 | super(MaskableDictRolloutBuffer, self).__init__(
149 | buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs
150 | )
151 |
152 | def reset(self) -> None:
153 | if isinstance(self.action_space, spaces.Discrete):
154 | mask_dims = self.action_space.n
155 | elif isinstance(self.action_space, spaces.MultiDiscrete):
156 | mask_dims = sum(self.action_space.nvec)
157 | elif isinstance(self.action_space, spaces.MultiBinary):
158 | mask_dims = 2 * self.action_space.n # One mask per binary outcome
159 | else:
160 | raise ValueError(f"Unsupported action space {type(self.action_space)}")
161 |
162 | self.mask_dims = mask_dims
163 | self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)
164 |
165 | super(MaskableDictRolloutBuffer, self).reset()
166 |
167 | def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
168 | """
169 | :param action_masks: Masks applied to constrain the choice of possible actions.
170 | """
171 | if action_masks is not None:
172 | self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))
173 |
174 | super(MaskableDictRolloutBuffer, self).add(*args, **kwargs)
175 |
176 | def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
177 | assert self.full, ""
178 | indices = np.random.permutation(self.buffer_size * self.n_envs)
179 | # Prepare the data
180 | if not self.generator_ready:
181 |
182 | for key, obs in self.observations.items():
183 | self.observations[key] = self.swap_and_flatten(obs)
184 |
185 | _tensor_names = ["actions", "values", "log_probs", "advantages", "returns", "action_masks"]
186 |
187 | for tensor in _tensor_names:
188 | self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
189 | self.generator_ready = True
190 |
191 | # Return everything, don't create minibatches
192 | if batch_size is None:
193 | batch_size = self.buffer_size * self.n_envs
194 |
195 | start_idx = 0
196 | while start_idx < self.buffer_size * self.n_envs:
197 | yield self._get_samples(indices[start_idx : start_idx + batch_size])
198 | start_idx += batch_size
199 |
200 | def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> MaskableDictRolloutBufferSamples:
201 |
202 | return MaskableDictRolloutBufferSamples(
203 | observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
204 | actions=self.to_torch(self.actions[batch_inds]),
205 | old_values=self.to_torch(self.values[batch_inds].flatten()),
206 | old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
207 | advantages=self.to_torch(self.advantages[batch_inds].flatten()),
208 | returns=self.to_torch(self.returns[batch_inds].flatten()),
209 | action_masks=self.to_torch(self.action_masks[batch_inds].reshape(-1, self.mask_dims)),
210 | )
211 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/maskable/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from stable_baselines3.common.callbacks import EvalCallback
5 | from stable_baselines3.common.vec_env import sync_envs_normalization
6 |
7 | from sb3_contrib.common.maskable.evaluation import evaluate_policy
8 |
9 |
10 | class MaskableEvalCallback(EvalCallback):
11 | """
12 | Callback for evaluating an agent. Supports invalid action masking.
13 |
14 | :param eval_env: The environment used for initialization
15 | :param callback_on_new_best: Callback to trigger
16 | when there is a new best model according to the ``mean_reward``
17 | :param n_eval_episodes: The number of episodes to test the agent
18 | :param eval_freq: Evaluate the agent every eval_freq call of the callback.
19 | :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
20 | will be saved. It will be updated at each evaluation.
21 | :param best_model_save_path: Path to a folder where the best model
22 | according to performance on the eval env will be saved.
23 | :param deterministic: Whether the evaluation should
24 | use a stochastic or deterministic actions.
25 | :param render: Whether to render or not the environment during evaluation
26 | :param verbose:
27 | :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
28 | wrapped with a Monitor wrapper)
29 | :param use_masking: Whether or not to use invalid action masks during evaluation
30 | """
31 |
32 | def __init__(self, *args, use_masking: bool = True, **kwargs):
33 | super().__init__(*args, **kwargs)
34 | self.use_masking = use_masking
35 |
36 | def _on_step(self) -> bool:
37 | if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
38 | # Sync training and eval env if there is VecNormalize
39 | sync_envs_normalization(self.training_env, self.eval_env)
40 |
41 | # Reset success rate buffer
42 | self._is_success_buffer = []
43 |
44 | # Note that evaluate_policy() has been patched to support masking
45 | episode_rewards, episode_lengths = evaluate_policy(
46 | self.model,
47 | self.eval_env,
48 | n_eval_episodes=self.n_eval_episodes,
49 | render=self.render,
50 | deterministic=self.deterministic,
51 | return_episode_rewards=True,
52 | warn=self.warn,
53 | callback=self._log_success_callback,
54 | use_masking=self.use_masking,
55 | )
56 |
57 | if self.log_path is not None:
58 | self.evaluations_timesteps.append(self.num_timesteps)
59 | self.evaluations_results.append(episode_rewards)
60 | self.evaluations_length.append(episode_lengths)
61 |
62 | kwargs = {}
63 | # Save success log if present
64 | if len(self._is_success_buffer) > 0:
65 | self.evaluations_successes.append(self._is_success_buffer)
66 | kwargs = dict(successes=self.evaluations_successes)
67 |
68 | np.savez(
69 | self.log_path,
70 | timesteps=self.evaluations_timesteps,
71 | results=self.evaluations_results,
72 | ep_lengths=self.evaluations_length,
73 | **kwargs,
74 | )
75 |
76 | mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
77 | mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
78 | self.last_mean_reward = mean_reward
79 |
80 | if self.verbose > 0:
81 | print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
82 | print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
83 | # Add to current Logger
84 | self.logger.record("eval/mean_reward", float(mean_reward))
85 | self.logger.record("eval/mean_ep_length", mean_ep_length)
86 |
87 | if len(self._is_success_buffer) > 0:
88 | success_rate = np.mean(self._is_success_buffer)
89 | if self.verbose > 0:
90 | print(f"Success rate: {100 * success_rate:.2f}%")
91 | self.logger.record("eval/success_rate", success_rate)
92 |
93 | # Dump log so the evaluation results are printed with the correct timestep
94 | self.logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
95 | self.logger.dump(self.num_timesteps)
96 |
97 | if mean_reward > self.best_mean_reward:
98 | if self.verbose > 0:
99 | print("New best mean reward!")
100 | if self.best_model_save_path is not None:
101 | self.model.save(os.path.join(self.best_model_save_path, "best_model"))
102 | self.best_mean_reward = mean_reward
103 | # Trigger callback if needed
104 | if self.callback is not None:
105 | return self._on_event()
106 |
107 | return True
108 |
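
A minimal usage sketch of the callback above (not part of the file), assuming an environment that already exposes an ``action_masks()`` method; ``InvalidActionEnvDiscrete`` from sb3_contrib's bundled test envs is used purely for illustration, and the hyper-parameters are arbitrary:

from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback

# Separate train and eval envs; both expose action_masks().
train_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
eval_env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)

# Evaluate every 1000 calls, 5 episodes per evaluation, with masking enabled.
eval_callback = MaskableEvalCallback(eval_env, eval_freq=1000, n_eval_episodes=5, use_masking=True)

model = MaskablePPO("MlpPolicy", train_env, verbose=0)
model.learn(total_timesteps=10_000, callback=eval_callback)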
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/maskable/evaluation.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3 |
4 | import gym
5 | import numpy as np
6 | from stable_baselines3.common.monitor import Monitor
7 | from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
8 |
9 | from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported
10 | from sb3_contrib.ppo_mask import MaskablePPO
11 |
12 |
13 | def evaluate_policy( # noqa: C901
14 | model: MaskablePPO,
15 | env: Union[gym.Env, VecEnv],
16 | n_eval_episodes: int = 10,
17 | deterministic: bool = True,
18 | render: bool = False,
19 | callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
20 | reward_threshold: Optional[float] = None,
21 | return_episode_rewards: bool = False,
22 | warn: bool = True,
23 | use_masking: bool = True,
24 | ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
25 | """
26 | Runs policy for ``n_eval_episodes`` episodes and returns average reward.
27 | If a vector env is passed in, this divides the episodes to evaluate onto the
28 | different elements of the vector env. This static division of work is done to
29 | remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
30 | details and discussion.
31 |
32 | .. note::
33 | If environment has not been wrapped with ``Monitor`` wrapper, reward and
34 | episode lengths are counted as it appears with ``env.step`` calls. If
35 | the environment contains wrappers that modify rewards or episode lengths
36 | (e.g. reward scaling, early episode reset), these will affect the evaluation
37 | results as well. You can avoid this by wrapping environment with ``Monitor``
38 | wrapper before anything else.
39 |
40 | :param model: The RL agent you want to evaluate.
41 | :param env: The gym environment. In the case of a ``VecEnv``
42 | this must contain only one environment.
43 |     :param n_eval_episodes: Number of episodes to evaluate the agent
44 | :param deterministic: Whether to use deterministic or stochastic actions
45 | :param render: Whether to render the environment or not
46 | :param callback: callback function to do additional checks,
47 | called after each step. Gets locals() and globals() passed as parameters.
48 | :param reward_threshold: Minimum expected reward per episode,
49 | this will raise an error if the performance is not met
50 |     :param return_episode_rewards: If True, a list of rewards and episode lengths
51 | per episode will be returned instead of the mean.
52 | :param warn: If True (default), warns user about lack of a Monitor wrapper in the
53 | evaluation environment.
54 | :param use_masking: Whether or not to use invalid action masks during evaluation
55 | :return: Mean reward per episode, std of reward per episode.
56 | Returns ([float], [int]) when ``return_episode_rewards`` is True, first
57 | list containing per-episode rewards and second containing per-episode lengths
58 | (in number of steps).
59 | """
60 |
61 | if use_masking and not is_masking_supported(env):
62 | raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")
63 |
64 | is_monitor_wrapped = False
65 |
66 | if not isinstance(env, VecEnv):
67 | env = DummyVecEnv([lambda: env])
68 |
69 | is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
70 |
71 | if not is_monitor_wrapped and warn:
72 | warnings.warn(
73 | "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
74 | "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
75 | "Consider wrapping environment first with ``Monitor`` wrapper.",
76 | UserWarning,
77 | )
78 |
79 | n_envs = env.num_envs
80 | episode_rewards = []
81 | episode_lengths = []
82 |
83 | episode_counts = np.zeros(n_envs, dtype="int")
84 | # Divides episodes among different sub environments in the vector as evenly as possible
85 | episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")
86 |
87 | current_rewards = np.zeros(n_envs)
88 | current_lengths = np.zeros(n_envs, dtype="int")
89 | observations = env.reset()
90 | states = None
91 |
92 | while (episode_counts < episode_count_targets).any():
93 | if use_masking:
94 | action_masks = get_action_masks(env)
95 |             actions, states = model.predict(
96 | observations,
97 | state=states,
98 | deterministic=deterministic,
99 | action_masks=action_masks,
100 | )
101 | else:
102 | actions, states = model.predict(observations, state=states, deterministic=deterministic)
103 | observations, rewards, dones, infos = env.step(actions)
104 | current_rewards += rewards
105 | current_lengths += 1
106 | for i in range(n_envs):
107 | if episode_counts[i] < episode_count_targets[i]:
108 |
109 | # unpack values so that the callback can access the local variables
110 | reward = rewards[i]
111 | done = dones[i]
112 | info = infos[i]
113 |
114 | if callback is not None:
115 | callback(locals(), globals())
116 |
117 | if dones[i]:
118 | if is_monitor_wrapped:
119 | # Atari wrapper can send a "done" signal when
120 | # the agent loses a life, but it does not correspond
121 | # to the true end of episode
122 | if "episode" in info.keys():
123 | # Do not trust "done" with episode endings.
124 | # Monitor wrapper includes "episode" key in info if environment
125 | # has been wrapped with it. Use those rewards instead.
126 | episode_rewards.append(info["episode"]["r"])
127 | episode_lengths.append(info["episode"]["l"])
128 | # Only increment at the real end of an episode
129 | episode_counts[i] += 1
130 | else:
131 | episode_rewards.append(current_rewards[i])
132 | episode_lengths.append(current_lengths[i])
133 | episode_counts[i] += 1
134 | current_rewards[i] = 0
135 | current_lengths[i] = 0
136 | if states is not None:
137 | states[i] *= 0
138 |
139 | if render:
140 | env.render()
141 |
142 | mean_reward = np.mean(episode_rewards)
143 | std_reward = np.std(episode_rewards)
144 | if reward_threshold is not None:
145 | assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
146 | if return_episode_rewards:
147 | return episode_rewards, episode_lengths
148 | return mean_reward, std_reward
149 |
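
For reference, a short sketch of calling this patched ``evaluate_policy`` directly; it assumes a briefly trained MaskablePPO model and an env exposing ``action_masks()`` (``InvalidActionEnvDiscrete`` is only an illustrative choice):

from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.evaluation import evaluate_policy

env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
model = MaskablePPO("MlpPolicy", env, verbose=0).learn(2_000)

# With use_masking=True (the default), invalid actions are masked out during evaluation.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, use_masking=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")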
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/maskable/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from stable_baselines3.common.type_aliases import GymEnv
3 | from stable_baselines3.common.vec_env import VecEnv
4 |
5 | EXPECTED_METHOD_NAME = "action_masks"
6 |
7 |
8 | def get_action_masks(env: GymEnv) -> np.ndarray:
9 | """
10 |     Retrieves the action masks exposed by the gym env's ``action_masks`` method
11 |
12 | :param env: the Gym environment to get masks from
13 | :return: A numpy array of the masks
14 | """
15 |
16 | if isinstance(env, VecEnv):
17 | return np.stack(env.env_method(EXPECTED_METHOD_NAME))
18 | else:
19 | return getattr(env, EXPECTED_METHOD_NAME)()
20 |
21 |
22 | def is_masking_supported(env: GymEnv) -> bool:
23 | """
24 | Checks whether gym env exposes a method returning invalid action masks
25 |
26 | :param env: the Gym environment to check
27 | :return: True if the method is found, False otherwise
28 | """
29 |
30 | if isinstance(env, VecEnv):
31 | try:
32 | # TODO: add VecEnv.has_attr()
33 | env.get_attr(EXPECTED_METHOD_NAME)
34 | return True
35 | except AttributeError:
36 | return False
37 | else:
38 | return hasattr(env, EXPECTED_METHOD_NAME)
39 |
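
A quick sketch of both helpers: the raw-env branch calls ``action_masks()`` directly, while the ``VecEnv`` branch stacks one mask per sub-environment (``InvalidActionEnvDiscrete`` is again just an illustrative env):

from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported

env = InvalidActionEnvDiscrete(dim=5, n_invalid_actions=2)
assert is_masking_supported(env)
print(get_action_masks(env))  # one validity flag per action

vec_env = DummyVecEnv([lambda: InvalidActionEnvDiscrete(dim=5, n_invalid_actions=2)])
print(get_action_masks(vec_env).shape)  # (1, 5): stacked over the single sub-env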
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, Sequence
2 |
3 | import torch as th
4 | from torch import nn
5 |
6 |
7 | def quantile_huber_loss(
8 | current_quantiles: th.Tensor,
9 | target_quantiles: th.Tensor,
10 | cum_prob: Optional[th.Tensor] = None,
11 | sum_over_quantiles: bool = True,
12 | ) -> th.Tensor:
13 | """
14 | The quantile-regression loss, as described in the QR-DQN and TQC papers.
15 | Partially taken from https://github.com/bayesgroup/tqc_pytorch.
16 |
17 | :param current_quantiles: current estimate of quantiles, must be either
18 | (batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles)
19 | :param target_quantiles: target of quantiles, must be either (batch_size, n_target_quantiles),
20 | (batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles)
21 | :param cum_prob: cumulative probabilities to calculate quantiles (also called midpoints in QR-DQN paper),
22 | must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles).
23 | (if None, calculating unit quantiles)
24 | :param sum_over_quantiles: if summing over the quantile dimension or not
25 | :return: the loss
26 | """
27 | if current_quantiles.ndim != target_quantiles.ndim:
28 | raise ValueError(
29 |             f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to match "
30 | f"the dimension of target_quantiles ({target_quantiles.ndim})."
31 | )
32 | if current_quantiles.shape[0] != target_quantiles.shape[0]:
33 | raise ValueError(
34 |             f"Error: The batch size of current_quantiles ({current_quantiles.shape[0]}) needs to match "
35 | f"the batch size of target_quantiles ({target_quantiles.shape[0]})."
36 | )
37 | if current_quantiles.ndim not in (2, 3):
38 | raise ValueError(f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to be either 2 or 3.")
39 |
40 | if cum_prob is None:
41 | n_quantiles = current_quantiles.shape[-1]
42 | # Cumulative probabilities to calculate quantiles.
43 | cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float) + 0.5) / n_quantiles
44 | if current_quantiles.ndim == 2:
45 | # For QR-DQN, current_quantiles have a shape (batch_size, n_quantiles), and make cum_prob
46 | # broadcastable to (batch_size, n_quantiles, n_target_quantiles)
47 | cum_prob = cum_prob.view(1, -1, 1)
48 | elif current_quantiles.ndim == 3:
49 | # For TQC, current_quantiles have a shape (batch_size, n_critics, n_quantiles), and make cum_prob
50 | # broadcastable to (batch_size, n_critics, n_quantiles, n_target_quantiles)
51 | cum_prob = cum_prob.view(1, 1, -1, 1)
52 |
53 | # QR-DQN
54 | # target_quantiles: (batch_size, n_target_quantiles) -> (batch_size, 1, n_target_quantiles)
55 | # current_quantiles: (batch_size, n_quantiles) -> (batch_size, n_quantiles, 1)
56 | # pairwise_delta: (batch_size, n_target_quantiles, n_quantiles)
57 | # TQC
58 | # target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles)
59 | # current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1)
60 | # pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles)
61 | # Note: in both cases, the loss has the same shape as pairwise_delta
62 | pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
63 | abs_pairwise_delta = th.abs(pairwise_delta)
64 | huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5)
65 | loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss
66 | if sum_over_quantiles:
67 | loss = loss.sum(dim=-2).mean()
68 | else:
69 | loss = loss.mean()
70 | return loss
71 |
72 |
73 | def conjugate_gradient_solver(
74 | matrix_vector_dot_fn: Callable[[th.Tensor], th.Tensor],
75 | b,
76 | max_iter=10,
77 | residual_tol=1e-10,
78 | ) -> th.Tensor:
79 | """
80 | Finds an approximate solution to a set of linear equations Ax = b
81 |
82 | Sources:
83 | - https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py
84 | - https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L122
85 |
86 | Reference:
87 | - https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6
88 |
89 | :param matrix_vector_dot_fn:
90 | a function that right multiplies a matrix A by a vector v
91 | :param b:
92 | the right hand term in the set of linear equations Ax = b
93 | :param max_iter:
94 | the maximum number of iterations (default is 10)
95 | :param residual_tol:
96 | residual tolerance for early stopping of the solving (default is 1e-10)
97 | :return x:
98 | the approximate solution to the system of equations defined by `matrix_vector_dot_fn`
99 | and b
100 | """
101 |
102 | # The vector is not initialized at 0 because of the instability issues when the gradient becomes small.
103 | # A small random gaussian noise is used for the initialization.
104 | x = 1e-4 * th.randn_like(b)
105 | residual = b - matrix_vector_dot_fn(x)
106 | # Equivalent to th.linalg.norm(residual) ** 2 (L2 norm squared)
107 | residual_squared_norm = th.matmul(residual, residual)
108 |
109 | if residual_squared_norm < residual_tol:
110 | # If the gradient becomes extremely small
111 | # The denominator in alpha will become zero
112 | # Leading to a division by zero
113 | return x
114 |
115 | p = residual.clone()
116 |
117 | for i in range(max_iter):
118 | # A @ p (matrix vector multiplication)
119 | A_dot_p = matrix_vector_dot_fn(p)
120 |
121 | alpha = residual_squared_norm / p.dot(A_dot_p)
122 | x += alpha * p
123 |
124 | if i == max_iter - 1:
125 | return x
126 |
127 | residual -= alpha * A_dot_p
128 | new_residual_squared_norm = th.matmul(residual, residual)
129 |
130 | if new_residual_squared_norm < residual_tol:
131 | return x
132 |
133 | beta = new_residual_squared_norm / residual_squared_norm
134 | residual_squared_norm = new_residual_squared_norm
135 | p = residual + beta * p
136 |
137 |
138 | def flat_grad(
139 | output,
140 | parameters: Sequence[nn.parameter.Parameter],
141 | create_graph: bool = False,
142 | retain_graph: bool = False,
143 | ) -> th.Tensor:
144 | """
145 | Returns the gradients of the passed sequence of parameters into a flat gradient.
146 | Order of parameters is preserved.
147 |
148 | :param output: functional output to compute the gradient for
149 | :param parameters: sequence of ``Parameter``
150 |     :param retain_graph: If ``False``, the graph used to compute the grad will be freed.
151 |         Defaults to the value of ``create_graph``.
152 |     :param create_graph: If ``True``, graph of the derivative will be constructed,
153 |         allowing the computation of higher order derivative products. Default: ``False``.
154 | :return: Tensor containing the flattened gradients
155 | """
156 | grads = th.autograd.grad(
157 | output,
158 | parameters,
159 | create_graph=create_graph,
160 | retain_graph=retain_graph,
161 | allow_unused=True,
162 | )
163 | return th.cat([th.ravel(grad) for grad in grads if grad is not None])
164 |
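
A small shape sanity check for ``quantile_huber_loss`` on random tensors (a sketch, not taken from the repo), using the 2-D QR-DQN case described in the docstring:

import torch as th
from sb3_contrib.common.utils import quantile_huber_loss

batch_size, n_quantiles = 4, 8
current = th.randn(batch_size, n_quantiles)  # (batch_size, n_quantiles)
target = th.randn(batch_size, n_quantiles)   # (batch_size, n_target_quantiles)

# Default: sum over the quantile dimension, then mean over the rest -> scalar.
loss = quantile_huber_loss(current, target)
# Alternative reduction: plain mean over all elements.
loss_mean = quantile_huber_loss(current, target, sum_over_quantiles=False)
print(loss.item(), loss_mean.item())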
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/vec_env/__init__.py:
--------------------------------------------------------------------------------
1 | from sb3_contrib.common.vec_env.async_eval import AsyncEval
2 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/vec_env/async_eval.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import multiprocessing as mp
3 | from collections import defaultdict
4 | from typing import Callable, List, Optional, Tuple, Union
5 |
6 | import numpy as np
7 | import torch as th
8 | from stable_baselines3.common.evaluation import evaluate_policy
9 | from stable_baselines3.common.policies import BasePolicy
10 | from stable_baselines3.common.running_mean_std import RunningMeanStd
11 | from stable_baselines3.common.vec_env import VecEnv, unwrap_vec_normalize
12 | from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper
13 |
14 |
15 | def _worker(
16 | remote: mp.connection.Connection,
17 | parent_remote: mp.connection.Connection,
18 | worker_env_wrapper: CloudpickleWrapper,
19 | train_policy_wrapper: CloudpickleWrapper,
20 | n_eval_episodes: int = 1,
21 | ) -> None:
22 | """
23 | Function that will be run in each process.
24 | It is in charge of creating environments, evaluating candidates
25 | and communicating with the main process.
26 |
27 | :param remote: Pipe to communicate with the parent process.
28 |     :param parent_remote: Pipe end kept by the parent process (closed inside the worker).
29 | :param worker_env_wrapper: Callable used to create the environment inside the process.
30 | :param train_policy_wrapper: Callable used to create the policy inside the process.
31 | :param n_eval_episodes: Number of evaluation episodes per candidate.
32 | """
33 | parent_remote.close()
34 | env = worker_env_wrapper.var()
35 | train_policy = train_policy_wrapper.var
36 | vec_normalize = unwrap_vec_normalize(env)
37 | if vec_normalize is not None:
38 | obs_rms = vec_normalize.obs_rms
39 | else:
40 | obs_rms = None
41 | while True:
42 | try:
43 | cmd, data = remote.recv()
44 | if cmd == "eval":
45 | results = []
46 | # Evaluate each candidate and save results
47 | for weights_idx, candidate_weights in data:
48 | train_policy.load_from_vector(candidate_weights.cpu())
49 | episode_rewards, episode_lengths = evaluate_policy(
50 | train_policy,
51 | env,
52 | n_eval_episodes=n_eval_episodes,
53 | return_episode_rewards=True,
54 | warn=False,
55 | )
56 | results.append((weights_idx, (episode_rewards, episode_lengths)))
57 | remote.send(results)
58 | elif cmd == "seed":
59 | remote.send(env.seed(data))
60 | elif cmd == "get_obs_rms":
61 | remote.send(obs_rms)
62 | elif cmd == "sync_obs_rms":
63 | vec_normalize.obs_rms = data
64 | obs_rms = data
65 | elif cmd == "close":
66 | env.close()
67 | remote.close()
68 | break
69 | else:
70 | raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
71 | except EOFError:
72 | break
73 |
74 |
75 | class AsyncEval(object):
76 | """
77 | Helper class to do asynchronous evaluation of different policies with multiple processes.
78 | It is useful when implementing population based methods like Evolution Strategies (ES),
79 | Cross Entropy Method (CEM) or Augmented Random Search (ARS).
80 |
81 | .. warning::
82 |
83 | Only 'forkserver' and 'spawn' start methods are thread-safe,
84 | which is important to avoid race conditions.
85 | However, compared to
86 | 'fork' they incur a small start-up cost and have restrictions on
87 | global variables. With those methods, users must wrap the code in an
88 | ``if __name__ == "__main__":`` block.
89 | For more information, see the multiprocessing documentation.
90 |
91 | :param envs_fn: Vectorized environments to run in subprocesses (callable)
92 | :param train_policy: The policy object that will load the different candidate
93 | weights.
94 | :param start_method: method used to start the subprocesses.
95 | Must be one of the methods returned by ``multiprocessing.get_all_start_methods()``.
96 | Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
97 | :param n_eval_episodes: The number of episodes to test each agent
98 | """
99 |
100 | def __init__(
101 | self,
102 | envs_fn: List[Callable[[], VecEnv]],
103 | train_policy: BasePolicy,
104 | start_method: Optional[str] = None,
105 | n_eval_episodes: int = 1,
106 | ):
107 | self.waiting = False
108 | self.closed = False
109 | n_envs = len(envs_fn)
110 |
111 | if start_method is None:
112 | # Fork is not a thread safe method (see issue #217)
113 | # but is more user friendly (does not require to wrap the code in
114 | # a `if __name__ == "__main__":`)
115 | forkserver_available = "forkserver" in mp.get_all_start_methods()
116 | start_method = "forkserver" if forkserver_available else "spawn"
117 | ctx = mp.get_context(start_method)
118 |
119 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
120 | self.processes = []
121 | for work_remote, remote, worker_env in zip(self.work_remotes, self.remotes, envs_fn):
122 | args = (
123 | work_remote,
124 | remote,
125 | CloudpickleWrapper(worker_env),
126 | CloudpickleWrapper(train_policy),
127 | n_eval_episodes,
128 | )
129 | # daemon=True: if the main process crashes, we should not cause things to hang
130 | process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
131 | process.start()
132 | self.processes.append(process)
133 | work_remote.close()
134 |
135 | def send_jobs(self, candidate_weights: th.Tensor, pop_size: int) -> None:
136 | """
137 | Send jobs to the workers to evaluate new candidates.
138 |
139 | :param candidate_weights: The weights to be evaluated.
140 |         :param pop_size: The number of candidates (size of the population)
141 | """
142 | jobs_per_worker = defaultdict(list)
143 | for weights_idx in range(pop_size):
144 | jobs_per_worker[weights_idx % len(self.remotes)].append((weights_idx, candidate_weights[weights_idx]))
145 |
146 | for remote_idx, remote in enumerate(self.remotes):
147 | remote.send(("eval", jobs_per_worker[remote_idx]))
148 | self.waiting = True
149 |
150 | def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
151 | """
152 | Seed the environments.
153 |
154 | :param seed: The seed for the pseudo-random generators.
155 | :return:
156 | """
157 | for idx, remote in enumerate(self.remotes):
158 | remote.send(("seed", seed + idx))
159 | return [remote.recv() for remote in self.remotes]
160 |
161 | def get_results(self) -> List[Tuple[int, Tuple[np.ndarray, np.ndarray]]]:
162 | """
163 |         Retrieve episode rewards and lengths from each worker
164 | for all candidates (there might be multiple candidates per worker)
165 |
166 | :return: A list of tuples containing each candidate index and its
167 | result (episodic reward and episode length)
168 | """
169 | results = [remote.recv() for remote in self.remotes]
170 | flat_results = [result for worker_results in results for result in worker_results]
171 | self.waiting = False
172 | return flat_results
173 |
174 | def get_obs_rms(self) -> List[RunningMeanStd]:
175 | """
176 | Retrieve the observation filters (observation running mean std)
177 | of each process, they will be combined in the main process.
178 | Synchronisation is done afterward using ``sync_obs_rms()``.
179 | :return: A list of ``RunningMeanStd`` objects (one per process)
180 | """
181 | for remote in self.remotes:
182 | remote.send(("get_obs_rms", None))
183 | return [remote.recv() for remote in self.remotes]
184 |
185 | def sync_obs_rms(self, obs_rms: RunningMeanStd) -> None:
186 | """
187 | Synchronise (and update) the observation filters
188 | (observation running mean std)
189 | :param obs_rms: The updated ``RunningMeanStd`` to be used
190 | by workers for normalizing observations.
191 | """
192 | for remote in self.remotes:
193 | remote.send(("sync_obs_rms", obs_rms))
194 |
195 | def close(self) -> None:
196 | """
197 | Close the processes.
198 | """
199 | if self.closed:
200 | return
201 | if self.waiting:
202 | for remote in self.remotes:
203 | remote.recv()
204 | for remote in self.remotes:
205 | remote.send(("close", None))
206 | for process in self.processes:
207 | process.join()
208 | self.closed = True
209 |
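
A hedged sketch of the intended use with ARS (population-based search); the env id and hyper-parameters are illustrative, and the ``if __name__ == "__main__"`` guard is required by the 'spawn'/'forkserver' start methods:

from sb3_contrib import ARS
from sb3_contrib.common.vec_env import AsyncEval
from stable_baselines3.common.env_util import make_vec_env

if __name__ == "__main__":
    model = ARS("MlpPolicy", "CartPole-v1", n_delta=2, verbose=0)
    # One callable per worker process; each worker builds its own VecEnv.
    async_eval = AsyncEval([lambda: make_vec_env("CartPole-v1") for _ in range(2)], model.policy)
    model.learn(total_timesteps=10_000, async_eval=async_eval)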
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from sb3_contrib.common.wrappers.action_masker import ActionMasker
2 | from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper
3 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/wrappers/action_masker.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Union
2 |
3 | import gym
4 | import numpy as np
5 |
6 |
7 | class ActionMasker(gym.Wrapper):
8 | """
9 | Env wrapper providing the method required to support masking.
10 |
11 | Exposes a method called action_masks(), which returns masks for the wrapped env.
12 | This wrapper is not needed if the env exposes the expected method itself.
13 |
14 | :param env: the Gym environment to wrap
15 | :param action_mask_fn: A function that takes a Gym environment and returns an action mask,
16 | or the name of such a method provided by the environment.
17 | """
18 |
19 | def __init__(self, env: gym.Env, action_mask_fn: Union[str, Callable[[gym.Env], np.ndarray]]):
20 | super().__init__(env)
21 |
22 | if isinstance(action_mask_fn, str):
23 | found_method = getattr(self.env, action_mask_fn)
24 | if not callable(found_method):
25 | raise ValueError(f"Environment attribute {action_mask_fn} is not a method")
26 |
27 | self._action_mask_fn = found_method
28 | else:
29 | self._action_mask_fn = action_mask_fn
30 |
31 | def action_masks(self) -> np.ndarray:
32 | return self._action_mask_fn(self.env)
33 |
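
An illustrative sketch of wrapping an env with ``ActionMasker``; the mask function below is a made-up rule on CartPole (masking brings no real benefit there) and exists only to show the mechanics:

import gym
import numpy as np
from sb3_contrib.common.wrappers import ActionMasker

def mask_fn(env: gym.Env) -> np.ndarray:
    # Hypothetical rule: allow "push right" only when the pole leans right.
    pole_angle = env.unwrapped.state[2]
    return np.array([True, pole_angle > 0])

env = ActionMasker(gym.make("CartPole-v1"), mask_fn)
env.reset()
print(env.action_masks())  # e.g. [ True False]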
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/common/wrappers/time_feature.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 |
3 | import gym
4 | import numpy as np
5 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
6 |
7 |
8 | class TimeFeatureWrapper(gym.Wrapper):
9 | """
10 | Add remaining, normalized time to observation space for fixed length episodes.
11 | See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.
12 |
13 | .. note::
14 |
15 | Only ``gym.spaces.Box`` and ``gym.spaces.Dict`` (``gym.GoalEnv``) 1D observation spaces
16 | are supported for now.
17 |
18 | :param env: Gym env to wrap.
19 | :param max_steps: Max number of steps of an episode
20 | if it is not wrapped in a ``TimeLimit`` object.
21 | :param test_mode: In test mode, the time feature is constant,
22 |         equal to one. This allows checking that the agent did not overfit this feature
23 |         by learning a deterministic pre-defined sequence of actions.
24 | """
25 |
26 | def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False):
27 | assert isinstance(
28 | env.observation_space, (gym.spaces.Box, gym.spaces.Dict)
29 | ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` and `gym.spaces.Dict` (`gym.GoalEnv`) observation spaces."
30 |
31 | # Add a time feature to the observation
32 | if isinstance(env.observation_space, gym.spaces.Dict):
33 | assert "observation" in env.observation_space.spaces, "No `observation` key in the observation space"
34 | obs_space = env.observation_space.spaces["observation"]
35 | assert isinstance(
36 | obs_space, gym.spaces.Box
37 | ), "`TimeFeatureWrapper` only supports `gym.spaces.Box` observation space."
38 | obs_space = env.observation_space.spaces["observation"]
39 | else:
40 | obs_space = env.observation_space
41 |
42 | assert len(obs_space.shape) == 1, "Only 1D observation spaces are supported"
43 |
44 | low, high = obs_space.low, obs_space.high
45 | low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0]))
46 | self.dtype = obs_space.dtype
47 |
48 | if isinstance(env.observation_space, gym.spaces.Dict):
49 | env.observation_space.spaces["observation"] = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
50 | else:
51 | env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)
52 |
53 | super(TimeFeatureWrapper, self).__init__(env)
54 |
55 | # Try to infer the max number of steps per episode
56 | try:
57 | self._max_steps = env.spec.max_episode_steps
58 | except AttributeError:
59 | self._max_steps = None
60 |
61 | # Fallback to provided value
62 | if self._max_steps is None:
63 | self._max_steps = max_steps
64 |
65 | self._current_step = 0
66 | self._test_mode = test_mode
67 |
68 | def reset(self) -> GymObs:
69 | self._current_step = 0
70 | return self._get_obs(self.env.reset())
71 |
72 | def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
73 | self._current_step += 1
74 | obs, reward, done, info = self.env.step(action)
75 | return self._get_obs(obs), reward, done, info
76 |
77 | def _get_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
78 | """
79 | Concatenate the time feature to the current observation.
80 |
81 | :param obs:
82 | :return:
83 | """
84 | # Remaining time is more general
85 | time_feature = 1 - (self._current_step / self._max_steps)
86 | if self._test_mode:
87 | time_feature = 1.0
88 | time_feature = np.array(time_feature, dtype=self.dtype)
89 |
90 | if isinstance(obs, dict):
91 | obs["observation"] = np.append(obs["observation"], time_feature)
92 | return obs
93 | return np.append(obs, time_feature)
94 |
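
A quick sketch on Pendulum-v1 (assuming the classic-control envs of this gym version are available): the 3-dimensional observation gains a fourth entry holding the normalized remaining time.

import gym
from sb3_contrib.common.wrappers import TimeFeatureWrapper

env = TimeFeatureWrapper(gym.make("Pendulum-v1"), max_steps=200)
obs = env.reset()
print(obs.shape)  # (4,): original 3 dims + time feature (starts at 1.0)

obs, reward, done, info = env.step(env.action_space.sample())
print(obs[-1])    # decreases towards 0.0 as the episode advances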
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/local_modifications.txt:
--------------------------------------------------------------------------------
1 | modified: sb3_contrib/qrdqn/policies.py
2 | sb3_contrib/qrdqn/qrdqn.py
3 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/ppo_mask/__init__.py:
--------------------------------------------------------------------------------
1 | from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO
3 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/ppo_mask/policies.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.common.policies import register_policy
2 |
3 | from sb3_contrib.common.maskable.policies import (
4 | MaskableActorCriticCnnPolicy,
5 | MaskableActorCriticPolicy,
6 | MaskableMultiInputActorCriticPolicy,
7 | )
8 |
9 | MlpPolicy = MaskableActorCriticPolicy
10 | CnnPolicy = MaskableActorCriticCnnPolicy
11 | MultiInputPolicy = MaskableMultiInputActorCriticPolicy
12 |
13 | register_policy("MlpPolicy", MaskableActorCriticPolicy)
14 | register_policy("CnnPolicy", MaskableActorCriticCnnPolicy)
15 | register_policy("MultiInputPolicy", MaskableMultiInputActorCriticPolicy)
16 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/sb3_contrib/py.typed
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/qrdqn/__init__.py:
--------------------------------------------------------------------------------
1 | from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from sb3_contrib.qrdqn.qrdqn import QRDQN
3 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/tqc/__init__.py:
--------------------------------------------------------------------------------
1 | from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from sb3_contrib.tqc.tqc import TQC
3 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/trpo/__init__.py:
--------------------------------------------------------------------------------
1 | from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from sb3_contrib.trpo.trpo import TRPO
3 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/trpo/policies.py:
--------------------------------------------------------------------------------
1 | # This file is here just to define MlpPolicy/CnnPolicy
2 | # that work for TRPO
3 | from stable_baselines3.common.policies import (
4 | ActorCriticCnnPolicy,
5 | ActorCriticPolicy,
6 | MultiInputActorCriticPolicy,
7 | register_policy,
8 | )
9 |
10 | MlpPolicy = ActorCriticPolicy
11 | CnnPolicy = ActorCriticCnnPolicy
12 | MultiInputPolicy = MultiInputActorCriticPolicy
13 |
14 | register_policy("MlpPolicy", ActorCriticPolicy)
15 | register_policy("CnnPolicy", ActorCriticCnnPolicy)
16 | register_policy("MultiInputPolicy", MultiInputPolicy)
17 |
--------------------------------------------------------------------------------
/thirdparty/sb3_contrib/version.txt:
--------------------------------------------------------------------------------
1 | 1.5.0
2 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from stable_baselines3.a2c import A2C
4 | from stable_baselines3.common.utils import get_system_info
5 | from stable_baselines3.ddpg import DDPG
6 | from stable_baselines3.dqn import DQN
7 | from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
8 | from stable_baselines3.ppo import PPO
9 | from stable_baselines3.sac import SAC
10 | from stable_baselines3.td3 import TD3
11 |
12 | # Read version from file
13 | version_file = os.path.join(os.path.dirname(__file__), "version.txt")
14 | with open(version_file, "r") as file_handler:
15 | __version__ = file_handler.read().strip()
16 |
17 |
18 | def HER(*args, **kwargs):
19 | raise ImportError(
20 | "Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n "
21 | "Please check the documentation for more information: https://stable-baselines3.readthedocs.io/"
22 | )
23 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/a2c/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.a2c.a2c import A2C
2 | from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
3 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/a2c/policies.py:
--------------------------------------------------------------------------------
1 | # This file is here just to define MlpPolicy/CnnPolicy
2 | # that work for A2C
3 | from stable_baselines3.common.policies import (
4 | ActorCriticCnnPolicy,
5 | ActorCriticPolicy,
6 | MultiInputActorCriticPolicy,
7 | register_policy,
8 | )
9 |
10 | MlpPolicy = ActorCriticPolicy
11 | CnnPolicy = ActorCriticCnnPolicy
12 | MultiInputPolicy = MultiInputActorCriticPolicy
13 |
14 | register_policy("MlpPolicy", ActorCriticPolicy)
15 | register_policy("CnnPolicy", ActorCriticCnnPolicy)
16 | register_policy("MultiInputPolicy", MultiInputPolicy)
17 |
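
Because the aliases above are registered by name, ``A2C`` can be constructed with the policy given as a string; a minimal sketch (hyper-parameters illustrative):

from stable_baselines3 import A2C

model = A2C("MlpPolicy", "CartPole-v1", verbose=0)
model.learn(total_timesteps=1_000)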
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/stable_baselines3/common/__init__.py
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/atari_wrappers.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | from gym import spaces
4 |
5 | try:
6 | import cv2 # pytype:disable=import-error
7 |
8 | cv2.ocl.setUseOpenCL(False)
9 | except ImportError:
10 | cv2 = None
11 |
12 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
13 |
14 |
15 | class NoopResetEnv(gym.Wrapper):
16 | """
17 | Sample initial states by taking random number of no-ops on reset.
18 | No-op is assumed to be action 0.
19 |
20 | :param env: the environment to wrap
21 | :param noop_max: the maximum value of no-ops to run
22 | """
23 |
24 | def __init__(self, env: gym.Env, noop_max: int = 30):
25 | gym.Wrapper.__init__(self, env)
26 | self.noop_max = noop_max
27 | self.override_num_noops = None
28 | self.noop_action = 0
29 | assert env.unwrapped.get_action_meanings()[0] == "NOOP"
30 |
31 | def reset(self, **kwargs) -> np.ndarray:
32 | self.env.reset(**kwargs)
33 | if self.override_num_noops is not None:
34 | noops = self.override_num_noops
35 | else:
36 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
37 | assert noops > 0
38 | obs = np.zeros(0)
39 | for _ in range(noops):
40 | obs, _, done, _ = self.env.step(self.noop_action)
41 | if done:
42 | obs = self.env.reset(**kwargs)
43 | return obs
44 |
45 |
46 | class FireResetEnv(gym.Wrapper):
47 | """
48 | Take action on reset for environments that are fixed until firing.
49 |
50 | :param env: the environment to wrap
51 | """
52 |
53 | def __init__(self, env: gym.Env):
54 | gym.Wrapper.__init__(self, env)
55 | assert env.unwrapped.get_action_meanings()[1] == "FIRE"
56 | assert len(env.unwrapped.get_action_meanings()) >= 3
57 |
58 | def reset(self, **kwargs) -> np.ndarray:
59 | self.env.reset(**kwargs)
60 | obs, _, done, _ = self.env.step(1)
61 | if done:
62 | self.env.reset(**kwargs)
63 | obs, _, done, _ = self.env.step(2)
64 | if done:
65 | self.env.reset(**kwargs)
66 | return obs
67 |
68 |
69 | class EpisodicLifeEnv(gym.Wrapper):
70 | """
71 | Make end-of-life == end-of-episode, but only reset on true game over.
72 | Done by DeepMind for the DQN and co. since it helps value estimation.
73 |
74 | :param env: the environment to wrap
75 | """
76 |
77 | def __init__(self, env: gym.Env):
78 | gym.Wrapper.__init__(self, env)
79 | self.lives = 0
80 | self.was_real_done = True
81 |
82 | def step(self, action: int) -> GymStepReturn:
83 | obs, reward, done, info = self.env.step(action)
84 | self.was_real_done = done
85 | # check current lives, make loss of life terminal,
86 | # then update lives to handle bonus lives
87 | lives = self.env.unwrapped.ale.lives()
88 | if 0 < lives < self.lives:
89 |             # for Qbert sometimes we stay in lives == 0 condition for a few frames
90 |             # so it's important to keep lives > 0, so that we only reset once
91 | # the environment advertises done.
92 | done = True
93 | self.lives = lives
94 | return obs, reward, done, info
95 |
96 | def reset(self, **kwargs) -> np.ndarray:
97 | """
98 | Calls the Gym environment reset, only when lives are exhausted.
99 | This way all states are still reachable even though lives are episodic,
100 | and the learner need not know about any of this behind-the-scenes.
101 |
102 | :param kwargs: Extra keywords passed to env.reset() call
103 | :return: the first observation of the environment
104 | """
105 | if self.was_real_done:
106 | obs = self.env.reset(**kwargs)
107 | else:
108 | # no-op step to advance from terminal/lost life state
109 | obs, _, _, _ = self.env.step(0)
110 | self.lives = self.env.unwrapped.ale.lives()
111 | return obs
112 |
113 |
114 | class MaxAndSkipEnv(gym.Wrapper):
115 | """
116 | Return only every ``skip``-th frame (frameskipping)
117 |
118 | :param env: the environment
119 | :param skip: number of ``skip``-th frame
120 | """
121 |
122 | def __init__(self, env: gym.Env, skip: int = 4):
123 | gym.Wrapper.__init__(self, env)
124 | # most recent raw observations (for max pooling across time steps)
125 | self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype)
126 | self._skip = skip
127 |
128 | def step(self, action: int) -> GymStepReturn:
129 | """
130 | Step the environment with the given action
131 | Repeat action, sum reward, and max over last observations.
132 |
133 | :param action: the action
134 | :return: observation, reward, done, information
135 | """
136 | total_reward = 0.0
137 | done = None
138 | for i in range(self._skip):
139 | obs, reward, done, info = self.env.step(action)
140 | if i == self._skip - 2:
141 | self._obs_buffer[0] = obs
142 | if i == self._skip - 1:
143 | self._obs_buffer[1] = obs
144 | total_reward += reward
145 | if done:
146 | break
147 | # Note that the observation on the done=True frame
148 | # doesn't matter
149 | max_frame = self._obs_buffer.max(axis=0)
150 |
151 | return max_frame, total_reward, done, info
152 |
153 | def reset(self, **kwargs) -> GymObs:
154 | return self.env.reset(**kwargs)
155 |
156 |
157 | class ClipRewardEnv(gym.RewardWrapper):
158 | """
159 | Clips the reward to {+1, 0, -1} by its sign.
160 |
161 | :param env: the environment
162 | """
163 |
164 | def __init__(self, env: gym.Env):
165 | gym.RewardWrapper.__init__(self, env)
166 |
167 | def reward(self, reward: float) -> float:
168 | """
169 | Bin reward to {+1, 0, -1} by its sign.
170 |
171 | :param reward:
172 | :return:
173 | """
174 | return np.sign(reward)
175 |
176 |
177 | class WarpFrame(gym.ObservationWrapper):
178 | """
179 | Convert to grayscale and warp frames to 84x84 (default)
180 | as done in the Nature paper and later work.
181 |
182 | :param env: the environment
183 | :param width:
184 | :param height:
185 | """
186 |
187 | def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
188 | gym.ObservationWrapper.__init__(self, env)
189 | self.width = width
190 | self.height = height
191 | self.observation_space = spaces.Box(
192 | low=0, high=255, shape=(self.height, self.width, 1), dtype=env.observation_space.dtype
193 | )
194 |
195 | def observation(self, frame: np.ndarray) -> np.ndarray:
196 | """
197 | returns the current observation from a frame
198 |
199 | :param frame: environment frame
200 | :return: the observation
201 | """
202 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
203 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
204 | return frame[:, :, None]
205 |
206 |
207 | class AtariWrapper(gym.Wrapper):
208 | """
209 | Atari 2600 preprocessings
210 |
211 | Specifically:
212 |
213 | * NoopReset: obtain initial state by taking random number of no-ops on reset.
214 | * Frame skipping: 4 by default
215 | * Max-pooling: most recent two observations
216 | * Termination signal when a life is lost.
217 | * Resize to a square image: 84x84 by default
218 | * Grayscale observation
219 | * Clip reward to {-1, 0, 1}
220 |
221 | :param env: gym environment
222 | :param noop_max: max number of no-ops
223 | :param frame_skip: the frequency at which the agent experiences the game.
224 | :param screen_size: resize Atari frame
225 | :param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost.
226 |     :param clip_reward: If True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
227 | """
228 |
229 | def __init__(
230 | self,
231 | env: gym.Env,
232 | noop_max: int = 30,
233 | frame_skip: int = 4,
234 | screen_size: int = 84,
235 | terminal_on_life_loss: bool = True,
236 | clip_reward: bool = True,
237 | ):
238 | env = NoopResetEnv(env, noop_max=noop_max)
239 | env = MaxAndSkipEnv(env, skip=frame_skip)
240 | if terminal_on_life_loss:
241 | env = EpisodicLifeEnv(env)
242 | if "FIRE" in env.unwrapped.get_action_meanings():
243 | env = FireResetEnv(env)
244 | env = WarpFrame(env, width=screen_size, height=screen_size)
245 | if clip_reward:
246 | env = ClipRewardEnv(env)
247 |
248 | super(AtariWrapper, self).__init__(env)
249 |
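
A hedged sketch of applying the full preprocessing stack (this assumes the optional Atari dependencies and ROMs are installed, which this repo does not ship):

import gym
from stable_baselines3.common.atari_wrappers import AtariWrapper

env = AtariWrapper(gym.make("BreakoutNoFrameskip-v4"))
obs = env.reset()
print(obs.shape)  # (84, 84, 1): grayscale, resized frame
print(env.action_space, env.observation_space)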
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/env_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Callable, Dict, Optional, Type, Union
3 |
4 | import gym
5 |
6 | from stable_baselines3.common.atari_wrappers import AtariWrapper
7 | from stable_baselines3.common.monitor import Monitor
8 | from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
9 |
10 |
11 | def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
12 | """
13 |     Retrieve a ``gym.Wrapper`` object by recursively searching.
14 |
15 | :param env: Environment to unwrap
16 | :param wrapper_class: Wrapper to look for
17 | :return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it
18 | """
19 | env_tmp = env
20 | while isinstance(env_tmp, gym.Wrapper):
21 | if isinstance(env_tmp, wrapper_class):
22 | return env_tmp
23 | env_tmp = env_tmp.env
24 | return None
25 |
26 |
27 | def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool:
28 | """
29 | Check if a given environment has been wrapped with a given wrapper.
30 |
31 | :param env: Environment to check
32 | :param wrapper_class: Wrapper class to look for
33 | :return: True if environment has been wrapped with ``wrapper_class``.
34 | """
35 | return unwrap_wrapper(env, wrapper_class) is not None
36 |
37 |
38 | def make_vec_env(
39 | env_id: Union[str, Type[gym.Env]],
40 | n_envs: int = 1,
41 | seed: Optional[int] = None,
42 | start_index: int = 0,
43 | monitor_dir: Optional[str] = None,
44 | wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
45 | env_kwargs: Optional[Dict[str, Any]] = None,
46 | vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
47 | vec_env_kwargs: Optional[Dict[str, Any]] = None,
48 | monitor_kwargs: Optional[Dict[str, Any]] = None,
49 | wrapper_kwargs: Optional[Dict[str, Any]] = None,
50 | ) -> VecEnv:
51 | """
52 | Create a wrapped, monitored ``VecEnv``.
53 | By default it uses a ``DummyVecEnv`` which is usually faster
54 | than a ``SubprocVecEnv``.
55 |
56 | :param env_id: the environment ID or the environment class
57 | :param n_envs: the number of environments you wish to have in parallel
58 | :param seed: the initial seed for the random number generator
59 | :param start_index: start rank index
60 | :param monitor_dir: Path to a folder where the monitor files will be saved.
61 | If None, no file will be written, however, the env will still be wrapped
62 | in a Monitor wrapper to provide additional information about training.
63 | :param wrapper_class: Additional wrapper to use on the environment.
64 |         This can also be a function with a single argument that wraps the environment in many things.
65 | :param env_kwargs: Optional keyword argument to pass to the env constructor
66 | :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
67 | :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
68 | :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
69 | :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor.
70 | :return: The wrapped environment
71 | """
72 | env_kwargs = {} if env_kwargs is None else env_kwargs
73 | vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
74 | monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs
75 | wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs
76 |
77 | def make_env(rank):
78 | def _init():
79 | if isinstance(env_id, str):
80 | env = gym.make(env_id, **env_kwargs)
81 | else:
82 | env = env_id(**env_kwargs)
83 | if seed is not None:
84 | env.seed(seed + rank)
85 | env.action_space.seed(seed + rank)
86 | # Wrap the env in a Monitor wrapper
87 | # to have additional training information
88 | monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
89 | # Create the monitor folder if needed
90 | if monitor_path is not None:
91 | os.makedirs(monitor_dir, exist_ok=True)
92 | env = Monitor(env, filename=monitor_path, **monitor_kwargs)
93 | # Optionally, wrap the environment with the provided wrapper
94 | if wrapper_class is not None:
95 | env = wrapper_class(env, **wrapper_kwargs)
96 | return env
97 |
98 | return _init
99 |
100 | # No custom VecEnv is passed
101 | if vec_env_cls is None:
102 | # Default: use a DummyVecEnv
103 | vec_env_cls = DummyVecEnv
104 |
105 | return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
106 |
107 |
108 | def make_atari_env(
109 | env_id: Union[str, Type[gym.Env]],
110 | n_envs: int = 1,
111 | seed: Optional[int] = None,
112 | start_index: int = 0,
113 | monitor_dir: Optional[str] = None,
114 | wrapper_kwargs: Optional[Dict[str, Any]] = None,
115 | env_kwargs: Optional[Dict[str, Any]] = None,
116 | vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
117 | vec_env_kwargs: Optional[Dict[str, Any]] = None,
118 | monitor_kwargs: Optional[Dict[str, Any]] = None,
119 | ) -> VecEnv:
120 | """
121 | Create a wrapped, monitored VecEnv for Atari.
122 | It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
123 |
124 | :param env_id: the environment ID or the environment class
125 | :param n_envs: the number of environments you wish to have in parallel
126 | :param seed: the initial seed for the random number generator
127 | :param start_index: start rank index
128 | :param monitor_dir: Path to a folder where the monitor files will be saved.
129 | If None, no file will be written, however, the env will still be wrapped
130 | in a Monitor wrapper to provide additional information about training.
131 | :param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper``
132 | :param env_kwargs: Optional keyword argument to pass to the env constructor
133 | :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
134 | :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
135 | :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
136 | :return: The wrapped environment
137 | """
138 | if wrapper_kwargs is None:
139 | wrapper_kwargs = {}
140 |
141 | def atari_wrapper(env: gym.Env) -> gym.Env:
142 | env = AtariWrapper(env, **wrapper_kwargs)
143 | return env
144 |
145 | return make_vec_env(
146 | env_id,
147 | n_envs=n_envs,
148 | seed=seed,
149 | start_index=start_index,
150 | monitor_dir=monitor_dir,
151 | wrapper_class=atari_wrapper,
152 | env_kwargs=env_kwargs,
153 | vec_env_cls=vec_env_cls,
154 | vec_env_kwargs=vec_env_kwargs,
155 | monitor_kwargs=monitor_kwargs,
156 | )
157 |
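
A minimal sketch of ``make_vec_env``: four monitored CartPole copies, each seeded with ``seed + rank``, run in subprocesses (the ``__main__`` guard matters when using ``SubprocVecEnv``):

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":
    vec_env = make_vec_env("CartPole-v1", n_envs=4, seed=0, vec_env_cls=SubprocVecEnv)
    obs = vec_env.reset()
    print(obs.shape)  # (4, 4): one observation per sub-environment
    vec_env.close()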
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv
2 | from stable_baselines3.common.envs.identity_env import (
3 | FakeImageEnv,
4 | IdentityEnv,
5 | IdentityEnvBox,
6 | IdentityEnvMultiBinary,
7 | IdentityEnvMultiDiscrete,
8 | )
9 | from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv
10 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/envs/bit_flipping_env.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Any, Dict, Optional, Union
3 |
4 | import numpy as np
5 | from gym import GoalEnv, spaces
6 | from gym.envs.registration import EnvSpec
7 |
8 | from stable_baselines3.common.type_aliases import GymStepReturn
9 |
10 |
11 | class BitFlippingEnv(GoalEnv):
12 | """
13 | Simple bit flipping env, useful to test HER.
14 | The goal is to flip all the bits to get a vector of ones.
15 | In the continuous variant, if the ith action component has a value > 0,
16 | then the ith bit will be flipped.
17 |
18 | :param n_bits: Number of bits to flip
19 | :param continuous: Whether to use the continuous actions version or not,
20 | by default, it uses the discrete one
21 | :param max_steps: Max number of steps, by default, equal to n_bits
22 | :param discrete_obs_space: Whether to use the discrete observation
23 | version or not, by default, it uses the ``MultiBinary`` one
24 | :param image_obs_space: Use image as input instead of the ``MultiBinary`` one.
25 | :param channel_first: Whether to use channel-first or last image.
26 | """
27 |
28 | spec = EnvSpec("BitFlippingEnv-v0")
29 |
30 | def __init__(
31 | self,
32 | n_bits: int = 10,
33 | continuous: bool = False,
34 | max_steps: Optional[int] = None,
35 | discrete_obs_space: bool = False,
36 | image_obs_space: bool = False,
37 | channel_first: bool = True,
38 | ):
39 | super(BitFlippingEnv, self).__init__()
40 | # Shape of the observation when using image space
41 | self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
42 | # The achieved goal is determined by the current state
43 |         # here, it is a special case where they are equal
44 | if discrete_obs_space:
45 | # In the discrete case, the agent act on the binary
46 | # representation of the observation
47 | self.observation_space = spaces.Dict(
48 | {
49 | "observation": spaces.Discrete(2 ** n_bits),
50 | "achieved_goal": spaces.Discrete(2 ** n_bits),
51 | "desired_goal": spaces.Discrete(2 ** n_bits),
52 | }
53 | )
54 | elif image_obs_space:
55 | # When using image as input,
56 | # one image contains the bits 0 -> 0, 1 -> 255
57 | # and the rest is filled with zeros
58 | self.observation_space = spaces.Dict(
59 | {
60 | "observation": spaces.Box(
61 | low=0,
62 | high=255,
63 | shape=self.image_shape,
64 | dtype=np.uint8,
65 | ),
66 | "achieved_goal": spaces.Box(
67 | low=0,
68 | high=255,
69 | shape=self.image_shape,
70 | dtype=np.uint8,
71 | ),
72 | "desired_goal": spaces.Box(
73 | low=0,
74 | high=255,
75 | shape=self.image_shape,
76 | dtype=np.uint8,
77 | ),
78 | }
79 | )
80 | else:
81 | self.observation_space = spaces.Dict(
82 | {
83 | "observation": spaces.MultiBinary(n_bits),
84 | "achieved_goal": spaces.MultiBinary(n_bits),
85 | "desired_goal": spaces.MultiBinary(n_bits),
86 | }
87 | )
88 |
89 | self.obs_space = spaces.MultiBinary(n_bits)
90 |
91 | if continuous:
92 | self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32)
93 | else:
94 | self.action_space = spaces.Discrete(n_bits)
95 | self.continuous = continuous
96 | self.discrete_obs_space = discrete_obs_space
97 | self.image_obs_space = image_obs_space
98 | self.state = None
99 | self.desired_goal = np.ones((n_bits,))
100 | if max_steps is None:
101 | max_steps = n_bits
102 | self.max_steps = max_steps
103 | self.current_step = 0
104 |
105 | def seed(self, seed: int) -> None:
106 | self.obs_space.seed(seed)
107 |
108 | def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
109 | """
110 | Convert to discrete space if needed.
111 |
112 | :param state:
113 | :return:
114 | """
115 | if self.discrete_obs_space:
116 | # The internal state is the binary representation of the
117 | # observed one
118 | return int(sum([state[i] * 2 ** i for i in range(len(state))]))
119 |
120 | if self.image_obs_space:
121 | size = np.prod(self.image_shape)
122 | image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
123 | return image.reshape(self.image_shape).astype(np.uint8)
124 | return state
125 |
126 | def convert_to_bit_vector(self, state: Union[int, np.ndarray], batch_size: int) -> np.ndarray:
127 | """
128 | Convert to bit vector if needed.
129 |
130 | :param state:
131 | :param batch_size:
132 | :return:
133 | """
134 | # Convert back to bit vector
135 | if isinstance(state, int):
136 | state = np.array(state).reshape(batch_size, -1)
137 | # Convert to binary representation
138 | state = (((state[:, :] & (1 << np.arange(len(self.state))))) > 0).astype(int)
139 | elif self.image_obs_space:
140 | state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
141 | else:
142 | state = np.array(state).reshape(batch_size, -1)
143 |
144 | return state
145 |
146 | def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
147 | """
148 | Helper to create the observation.
149 |
150 | :return: The current observation.
151 | """
152 | return OrderedDict(
153 | [
154 | ("observation", self.convert_if_needed(self.state.copy())),
155 | ("achieved_goal", self.convert_if_needed(self.state.copy())),
156 | ("desired_goal", self.convert_if_needed(self.desired_goal.copy())),
157 | ]
158 | )
159 |
160 | def reset(self) -> Dict[str, Union[int, np.ndarray]]:
161 | self.current_step = 0
162 | self.state = self.obs_space.sample()
163 | return self._get_obs()
164 |
165 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
166 | if self.continuous:
167 | self.state[action > 0] = 1 - self.state[action > 0]
168 | else:
169 | self.state[action] = 1 - self.state[action]
170 | obs = self._get_obs()
171 | reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None))
172 | done = reward == 0
173 | self.current_step += 1
174 |         # Episode terminates when we reach the goal or the max number of steps
175 | info = {"is_success": done}
176 | done = done or self.current_step >= self.max_steps
177 | return obs, reward, done, info
178 |
179 | def compute_reward(
180 | self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]]
181 | ) -> np.float32:
182 | # As we are using a vectorized version, we need to keep track of the `batch_size`
183 | if isinstance(achieved_goal, int):
184 | batch_size = 1
185 | elif self.image_obs_space:
186 | batch_size = achieved_goal.shape[0] if len(achieved_goal.shape) > 3 else 1
187 | else:
188 | batch_size = achieved_goal.shape[0] if len(achieved_goal.shape) > 1 else 1
189 |
190 | desired_goal = self.convert_to_bit_vector(desired_goal, batch_size)
191 | achieved_goal = self.convert_to_bit_vector(achieved_goal, batch_size)
192 |
193 | # Deceptive reward: it is positive only when the goal is achieved
194 | # Here we are using a vectorized version
195 | distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
196 | return -(distance > 0).astype(np.float32)
197 |
198 | def render(self, mode: str = "human") -> Optional[np.ndarray]:
199 | if mode == "rgb_array":
200 | return self.state.copy()
201 | print(self.state)
202 |
203 | def close(self) -> None:
204 | pass
205 |
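A minimal usage sketch of the ``BitFlippingEnv`` defined above, assuming the vendored ``stable_baselines3`` package is importable (observations are dicts suitable for HER):

    from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv

    env = BitFlippingEnv(n_bits=4, continuous=False)
    obs = env.reset()  # dict with "observation", "achieved_goal" and "desired_goal"
    obs, reward, done, info = env.step(env.action_space.sample())  # flip one random bit
    print(reward, done, info["is_success"])  # reward is 0.0 only once all bits are 1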
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/envs/identity_env.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import numpy as np
4 | from gym import Env, Space
5 | from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
6 |
7 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
8 |
9 |
10 | class IdentityEnv(Env):
11 | def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100):
12 | """
13 | Identity environment for testing purposes
14 |
15 | :param dim: the size of the action and observation dimension you want
16 | to learn. Provide at most one of ``dim`` and ``space``. If both are
17 | None, then initialization proceeds with ``dim=1`` and ``space=None``.
18 | :param space: the action and observation space. Provide at most one of
19 | ``dim`` and ``space``.
20 | :param ep_length: the length of each episode in timesteps
21 | """
22 | if space is None:
23 | if dim is None:
24 | dim = 1
25 | space = Discrete(dim)
26 | else:
27 | assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
28 |
29 | self.action_space = self.observation_space = space
30 | self.ep_length = ep_length
31 | self.current_step = 0
32 | self.num_resets = -1 # Becomes 0 after __init__ exits.
33 | self.reset()
34 |
35 | def reset(self) -> GymObs:
36 | self.current_step = 0
37 | self.num_resets += 1
38 | self._choose_next_state()
39 | return self.state
40 |
41 | def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
42 | reward = self._get_reward(action)
43 | self._choose_next_state()
44 | self.current_step += 1
45 | done = self.current_step >= self.ep_length
46 | return self.state, reward, done, {}
47 |
48 | def _choose_next_state(self) -> None:
49 | self.state = self.action_space.sample()
50 |
51 | def _get_reward(self, action: Union[int, np.ndarray]) -> float:
52 | return 1.0 if np.all(self.state == action) else 0.0
53 |
54 | def render(self, mode: str = "human") -> None:
55 | pass
56 |
57 |
58 | class IdentityEnvBox(IdentityEnv):
59 | def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
60 | """
61 | Identity environment for testing purposes
62 |
63 | :param low: the lower bound of the box dim
64 | :param high: the upper bound of the box dim
65 | :param eps: the epsilon bound for correct value
66 | :param ep_length: the length of each episode in timesteps
67 | """
68 | space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
69 | super().__init__(ep_length=ep_length, space=space)
70 | self.eps = eps
71 |
72 | def step(self, action: np.ndarray) -> GymStepReturn:
73 | reward = self._get_reward(action)
74 | self._choose_next_state()
75 | self.current_step += 1
76 | done = self.current_step >= self.ep_length
77 | return self.state, reward, done, {}
78 |
79 | def _get_reward(self, action: np.ndarray) -> float:
80 | return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0
81 |
82 |
83 | class IdentityEnvMultiDiscrete(IdentityEnv):
84 | def __init__(self, dim: int = 1, ep_length: int = 100):
85 | """
86 | Identity environment for testing purposes
87 |
88 | :param dim: the size of the dimensions you want to learn
89 | :param ep_length: the length of each episode in timesteps
90 | """
91 | space = MultiDiscrete([dim, dim])
92 | super().__init__(ep_length=ep_length, space=space)
93 |
94 |
95 | class IdentityEnvMultiBinary(IdentityEnv):
96 | def __init__(self, dim: int = 1, ep_length: int = 100):
97 | """
98 | Identity environment for testing purposes
99 |
100 | :param dim: the size of the dimensions you want to learn
101 | :param ep_length: the length of each episode in timesteps
102 | """
103 | space = MultiBinary(dim)
104 | super().__init__(ep_length=ep_length, space=space)
105 |
106 |
107 | class FakeImageEnv(Env):
108 | """
109 | Fake image environment for testing purposes, it mimics Atari games.
110 |
111 | :param action_dim: Number of discrete actions
112 | :param screen_height: Height of the image
113 | :param screen_width: Width of the image
114 | :param n_channels: Number of color channels
115 | :param discrete: Create discrete action space instead of continuous
116 | :param channel_first: Put channels on first axis instead of last
117 | """
118 |
119 | def __init__(
120 | self,
121 | action_dim: int = 6,
122 | screen_height: int = 84,
123 | screen_width: int = 84,
124 | n_channels: int = 1,
125 | discrete: bool = True,
126 | channel_first: bool = False,
127 | ):
128 | self.observation_shape = (screen_height, screen_width, n_channels)
129 | if channel_first:
130 | self.observation_shape = (n_channels, screen_height, screen_width)
131 | self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
132 | if discrete:
133 | self.action_space = Discrete(action_dim)
134 | else:
135 | self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32)
136 | self.ep_length = 10
137 | self.current_step = 0
138 |
139 | def reset(self) -> np.ndarray:
140 | self.current_step = 0
141 | return self.observation_space.sample()
142 |
143 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
144 | reward = 0.0
145 | self.current_step += 1
146 | done = self.current_step >= self.ep_length
147 | return self.observation_space.sample(), reward, done, {}
148 |
149 | def render(self, mode: str = "human") -> None:
150 | pass
151 |
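A quick sanity check of ``IdentityEnvBox`` (a sketch, not part of the repository): echoing the observation back as the action earns reward 1.0, since the reward is computed against the state observed before the step.

    import numpy as np
    from stable_baselines3.common.envs.identity_env import IdentityEnvBox

    env = IdentityEnvBox(eps=0.05, ep_length=10)
    obs = env.reset()
    obs, reward, done, info = env.step(np.asarray(obs, dtype=np.float32))  # echo the observation
    assert reward == 1.0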
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/envs/multi_input_envs.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 |
3 | import gym
4 | import numpy as np
5 |
6 | from stable_baselines3.common.type_aliases import GymStepReturn
7 |
8 |
9 | class SimpleMultiObsEnv(gym.Env):
10 | """
11 | Base class for GridWorld-based MultiObs Environments 4x4 grid world.
12 |
13 | .. code-block:: text
14 |
15 | ____________
16 | | 0 1 2 3|
17 | | 4|¯5¯¯6¯| 7|
18 | | 8|_9_10_|11|
19 | |12 13 14 15|
20 | ¯¯¯¯¯¯¯¯¯¯¯¯¯¯
21 |
22 | start is 0
23 | states 5, 6, 9, and 10 are blocked
24 | goal is 15
25 | actions are = [left, down, right, up]
26 |
27 | simple linear state env of 15 states but encoded with a vector and an image observation:
28 | each column is represented by a random vector and each row is
29 | represented by a random image, both sampled once at creation time.
30 |
31 | :param num_col: Number of columns in the grid
32 | :param num_row: Number of rows in the grid
33 | :param random_start: If true, agent starts in random position
34 | :param channel_last: If true, the image will be channel last, else it will be channel first
35 | """
36 |
37 | def __init__(
38 | self,
39 | num_col: int = 4,
40 | num_row: int = 4,
41 | random_start: bool = True,
42 | discrete_actions: bool = True,
43 | channel_last: bool = True,
44 | ):
45 | super(SimpleMultiObsEnv, self).__init__()
46 |
47 | self.vector_size = 5
48 | if channel_last:
49 | self.img_size = [64, 64, 1]
50 | else:
51 | self.img_size = [1, 64, 64]
52 |
53 | self.random_start = random_start
54 | self.discrete_actions = discrete_actions
55 | if discrete_actions:
56 | self.action_space = gym.spaces.Discrete(4)
57 | else:
58 | self.action_space = gym.spaces.Box(0, 1, (4,))
59 |
60 | self.observation_space = gym.spaces.Dict(
61 | spaces={
62 | "vec": gym.spaces.Box(0, 1, (self.vector_size,), dtype=np.float64),
63 | "img": gym.spaces.Box(0, 255, self.img_size, dtype=np.uint8),
64 | }
65 | )
66 | self.count = 0
67 | # Timeout
68 | self.max_count = 100
69 | self.log = ""
70 | self.state = 0
71 | self.action2str = ["left", "down", "right", "up"]
72 | self.init_possible_transitions()
73 |
74 | self.num_col = num_col
75 | self.state_mapping = []
76 | self.init_state_mapping(num_col, num_row)
77 |
78 | self.max_state = len(self.state_mapping) - 1
79 |
80 | def init_state_mapping(self, num_col: int, num_row: int) -> None:
81 | """
82 | Initializes the state_mapping array which holds the observation values for each state
83 |
84 | :param num_col: Number of columns.
85 | :param num_row: Number of rows.
86 | """
87 | # Each column is represented by a random vector
88 | col_vecs = np.random.random((num_col, self.vector_size))
89 | # Each row is represented by a random image
90 | row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.uint8)
91 |
92 | for i in range(num_col):
93 | for j in range(num_row):
94 | self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)})
95 |
96 | def get_state_mapping(self) -> Dict[str, np.ndarray]:
97 | """
98 | Uses the state to get the observation mapping.
99 |
100 | :return: observation dict {'vec': ..., 'img': ...}
101 | """
102 | return self.state_mapping[self.state]
103 |
104 | def init_possible_transitions(self) -> None:
105 | """
106 | Initializes the transitions of the environment
107 | The environment exploits the cardinal directions of the grid by noting that
108 | they correspond to simple addition and subtraction from the cell id within the grid
109 |
110 | - up => means moving up a row => means subtracting the length of a column
111 | - down => means moving down a row => means adding the length of a column
112 | - left => means moving left by one => means subtracting 1
113 | - right => means moving right by one => means adding 1
114 |
115 | Thus one only needs to specify in which states each action is possible
116 | in order to define the transitions of the environment
117 | """
118 | self.left_possible = [1, 2, 3, 13, 14, 15]
119 | self.down_possible = [0, 4, 8, 3, 7, 11]
120 | self.right_possible = [0, 1, 2, 12, 13, 14]
121 | self.up_possible = [4, 8, 12, 7, 11, 15]
122 |
123 | def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn:
124 | """
125 | Run one timestep of the environment's dynamics. When end of
126 | episode is reached, you are responsible for calling `reset()`
127 | to reset this environment's state.
128 | Accepts an action and returns a tuple (observation, reward, done, info).
129 |
130 | :param action:
131 | :return: tuple (observation, reward, done, info).
132 | """
133 | if not self.discrete_actions:
134 | action = np.argmax(action)
135 | else:
136 | action = int(action)
137 |
138 | self.count += 1
139 |
140 | prev_state = self.state
141 |
142 | reward = -0.1
143 | # define state transition
144 | if self.state in self.left_possible and action == 0: # left
145 | self.state -= 1
146 | elif self.state in self.down_possible and action == 1: # down
147 | self.state += self.num_col
148 | elif self.state in self.right_possible and action == 2: # right
149 | self.state += 1
150 | elif self.state in self.up_possible and action == 3: # up
151 | self.state -= self.num_col
152 |
153 | got_to_end = self.state == self.max_state
154 | reward = 1 if got_to_end else reward
155 | done = self.count > self.max_count or got_to_end
156 |
157 | self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}"
158 |
159 | return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end}
160 |
161 | def render(self, mode: str = "human") -> None:
162 | """
163 | Prints the log of the environment.
164 |
165 | :param mode:
166 | """
167 | print(self.log)
168 |
169 | def reset(self) -> Dict[str, np.ndarray]:
170 | """
171 | Resets the environment state and step count and returns reset observation.
172 |
173 | :return: observation dict {'vec': ..., 'img': ...}
174 | """
175 | self.count = 0
176 | if not self.random_start:
177 | self.state = 0
178 | else:
179 | self.state = np.random.randint(0, self.max_state)
180 | return self.state_mapping[self.state]
181 |
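A short illustrative sketch of ``SimpleMultiObsEnv``: observations are dicts with a 5-dimensional "vec" entry and a 64x64 "img" entry, and action 2 ("right") moves the agent from cell 0 to cell 1.

    from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv

    env = SimpleMultiObsEnv(random_start=False)   # always start in cell 0
    obs = env.reset()
    print(obs["vec"].shape, obs["img"].shape)     # (5,) and (64, 64, 1) with channel_last=True
    obs, reward, done, info = env.step(2)         # move right: reward -0.1 until the goal cell
    env.render()                                  # prints the transition log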
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/evaluation.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3 |
4 | import gym
5 | import numpy as np
6 |
7 | from stable_baselines3.common import base_class
8 | from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
9 |
10 | ##### local modification #####
11 | def ssd_policy(quantiles: np.ndarray, use_threshold: bool = False, mean_threshold: float = 1e-03):
12 |     means = np.mean(quantiles, axis=0)  # mean predicted return of each action
13 |     sort_idx = np.argsort(-1 * means)   # actions sorted by decreasing mean
14 |     best_1 = sort_idx[0]
15 |     best_2 = sort_idx[1]
16 |     if means[best_1] - means[best_2] > mean_threshold:
17 |         return best_1                    # clear winner: act greedily
18 |     else:
19 |         if use_threshold:
20 |             signed_second_moment = -1 * np.var(quantiles, axis=0)
21 |         else:
22 |             signed_second_moment = -1 * np.mean(quantiles**2, axis=0)
23 |         action = best_1
24 |         if signed_second_moment[best_2] > signed_second_moment[best_1]:
25 |             action = best_2              # prefer the action with the smaller second moment
26 |         return action
27 |
28 |
29 | def evaluate_policy(
30 | model: "base_class.BaseAlgorithm",
31 | env: Union[gym.Env, VecEnv],
32 | n_eval_episodes: int = 10,
33 | deterministic: bool = True,
34 | render: bool = False,
35 | callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
36 | reward_threshold: Optional[float] = None,
37 | return_episode_rewards: bool = False,
38 | warn: bool = True,
39 | ##### local modification #####
40 | eval_policy: str = "Greedy",
41 | ssd_thres: float = 1e-03
42 | ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
43 | """
44 | Runs policy for ``n_eval_episodes`` episodes and returns average reward.
45 | If a vector env is passed in, this divides the episodes to evaluate onto the
46 | different elements of the vector env. This static division of work is done to
47 | remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
48 | details and discussion.
49 |
50 | .. note::
51 | If environment has not been wrapped with ``Monitor`` wrapper, reward and
52 | episode lengths are counted as it appears with ``env.step`` calls. If
53 | the environment contains wrappers that modify rewards or episode lengths
54 | (e.g. reward scaling, early episode reset), these will affect the evaluation
55 | results as well. You can avoid this by wrapping environment with ``Monitor``
56 | wrapper before anything else.
57 |
58 | :param model: The RL agent you want to evaluate.
59 | :param env: The gym environment or ``VecEnv`` environment.
60 | :param n_eval_episodes: Number of episode to evaluate the agent
61 | :param deterministic: Whether to use deterministic or stochastic actions
62 | :param render: Whether to render the environment or not
63 | :param callback: callback function to do additional checks,
64 | called after each step. Gets locals() and globals() passed as parameters.
65 | :param reward_threshold: Minimum expected reward per episode,
66 | this will raise an error if the performance is not met
67 | :param return_episode_rewards: If True, a list of rewards and episode lengths
68 | per episode will be returned instead of the mean.
69 | :param warn: If True (default), warns user about lack of a Monitor wrapper in the
70 | evaluation environment.
71 | :return: Mean reward per episode, std of reward per episode.
72 | Returns ([float], [int]) when ``return_episode_rewards`` is True, first
73 | list containing per-episode rewards and second containing per-episode lengths
74 | (in number of steps).
75 | """
76 | is_monitor_wrapped = False
77 | # Avoid circular import
78 | from stable_baselines3.common.monitor import Monitor
79 |
80 | if not isinstance(env, VecEnv):
81 | env = DummyVecEnv([lambda: env])
82 |
83 | is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
84 |
85 | if not is_monitor_wrapped and warn:
86 | warnings.warn(
87 | "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
88 | "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
89 | "Consider wrapping environment first with ``Monitor`` wrapper.",
90 | UserWarning,
91 | )
92 |
93 | ##### local modification #####
94 | # store quantiles prediction for all state action pair if the agent is QR-DQN
95 | if env.save_q_vals:
96 | print("saving quantiles (QR-DQN)")
97 | all_quantiles = []
98 | for i in range(env.num_states):
99 | obs = env.get_obs_at_state(i)
100 | q_vals = model.predict_quantiles(obs)
101 | all_quantiles.append(q_vals.cpu().data.numpy()[0])
102 |
103 | env.save_quantiles(np.array(all_quantiles))
104 |
105 | n_envs = env.num_envs
106 | episode_rewards = []
107 | episode_lengths = []
108 |
109 | episode_counts = np.zeros(n_envs, dtype="int")
110 | # Divides episodes among different sub environments in the vector as evenly as possible
111 | episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")
112 |
113 | current_rewards = np.zeros(n_envs)
114 | current_lengths = np.zeros(n_envs, dtype="int")
115 | observations = env.reset()
116 | states = None
117 | while (episode_counts < episode_count_targets).any():
118 | ##### local modification #####
119 | if eval_policy == "Greedy":
120 | actions, states = model.predict(observations, state=states, deterministic=deterministic)
121 | # TODO: consider multi environments case
122 | elif eval_policy == "SSD":
123 | q_vals = model.predict_quantiles(observations)
124 | actions = np.array([ssd_policy(q_vals.cpu().data.numpy()[0])])
125 | states = None
126 | elif eval_policy == "Thresholded_SSD":
127 | q_vals = model.predict_quantiles(observations)
128 | actions = np.array([ssd_policy(q_vals.cpu().data.numpy()[0],use_threshold=True,mean_threshold=ssd_thres)])
129 | states = None
130 | else:
131 | raise RuntimeError("The evaluation policy is not available.")
132 |
133 | observations, rewards, dones, infos = env.step(actions)
134 | ##### local modification #####
135 | #current_rewards += rewards
136 | current_rewards += env.discount ** current_lengths[0] * rewards
137 | current_lengths += 1
138 | for i in range(n_envs):
139 | if episode_counts[i] < episode_count_targets[i]:
140 |
141 | # unpack values so that the callback can access the local variables
142 | reward = rewards[i]
143 | done = dones[i]
144 | info = infos[i]
145 |
146 | if callback is not None:
147 | callback(locals(), globals())
148 |
149 | ##### local modification #####
150 | # if dones[i]:
151 | if dones[i] or current_lengths[i] >= 1000:
152 | print("Eval_steps: ",current_lengths[i]," Eval_return: ",current_rewards[i])
153 | # if is_monitor_wrapped:
154 | # # Atari wrapper can send a "done" signal when
155 | # # the agent loses a life, but it does not correspond
156 | # # to the true end of episode
157 | # if "episode" in info.keys():
158 | # # Do not trust "done" with episode endings.
159 | # # Monitor wrapper includes "episode" key in info if environment
160 | # # has been wrapped with it. Use those rewards instead.
161 | # episode_rewards.append(info["episode"]["r"])
162 | # episode_lengths.append(info["episode"]["l"])
163 | # # Only increment at the real end of an episode
164 | # episode_counts[i] += 1
165 | # else:
166 | # episode_rewards.append(current_rewards[i])
167 | # episode_lengths.append(current_lengths[i])
168 | # episode_counts[i] += 1
169 |
170 | episode_rewards.append(current_rewards[i])
171 | episode_lengths.append(current_lengths[i])
172 | episode_counts[i] += 1
173 |
174 | current_rewards[i] = 0
175 | current_lengths[i] = 0
176 | if states is not None:
177 | states[i] *= 0
178 |
179 | if render:
180 | env.render()
181 |
182 | mean_reward = np.mean(episode_rewards)
183 | std_reward = np.std(episode_rewards)
184 | if reward_threshold is not None:
185 | assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
186 | if return_episode_rewards:
187 | return episode_rewards, episode_lengths
188 | return mean_reward, std_reward
189 |
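The local modifications above add two arguments to ``evaluate_policy``: ``eval_policy`` (one of "Greedy", "SSD" or "Thresholded_SSD") and ``ssd_thres``, and they accumulate discounted returns using ``env.discount``. The snippet below is only a toy illustration of ``ssd_policy`` on a made-up quantile matrix of shape (n_quantiles, n_actions): both actions have mean 1.0, so the rule falls back to the second-moment comparison and picks the lower-spread action.

    import numpy as np
    from stable_baselines3.common.evaluation import ssd_policy

    # Rows are quantiles, columns are actions; both columns have mean 1.0,
    # but action 1 has a much smaller spread than action 0.
    quantiles = np.array([[0.0, 0.9],
                          [1.0, 1.0],
                          [2.0, 1.1]])
    print(ssd_policy(quantiles))                                           # -> 1
    print(ssd_policy(quantiles, use_threshold=True, mean_threshold=1e-3))  # -> 1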
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/monitor.py:
--------------------------------------------------------------------------------
1 | __all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"]
2 |
3 | import csv
4 | import json
5 | import os
6 | import time
7 | from glob import glob
8 | from typing import Dict, List, Optional, Tuple, Union
9 |
10 | import gym
11 | import numpy as np
12 | import pandas
13 |
14 | from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
15 |
16 |
17 | class Monitor(gym.Wrapper):
18 | """
19 | A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.
20 |
21 | :param env: The environment
22 | :param filename: the location to save a log file, can be None for no log
23 | :param allow_early_resets: allows the reset of the environment before it is done
24 | :param reset_keywords: extra keywords for the reset call,
25 | if extra parameters are needed at reset
26 | :param info_keywords: extra information to log, from the information return of env.step()
27 | """
28 |
29 | EXT = "monitor.csv"
30 |
31 | def __init__(
32 | self,
33 | env: gym.Env,
34 | filename: Optional[str] = None,
35 | allow_early_resets: bool = True,
36 | reset_keywords: Tuple[str, ...] = (),
37 | info_keywords: Tuple[str, ...] = (),
38 | ):
39 | super(Monitor, self).__init__(env=env)
40 | self.t_start = time.time()
41 | if filename is not None:
42 | self.results_writer = ResultsWriter(
43 | filename,
44 | header={"t_start": self.t_start, "env_id": env.spec and env.spec.id},
45 | extra_keys=reset_keywords + info_keywords,
46 | )
47 | else:
48 | self.results_writer = None
49 | self.reset_keywords = reset_keywords
50 | self.info_keywords = info_keywords
51 | self.allow_early_resets = allow_early_resets
52 | self.rewards = None
53 | self.needs_reset = True
54 | self.episode_returns = []
55 | self.episode_lengths = []
56 | self.episode_times = []
57 | self.total_steps = 0
58 | self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()
59 |
60 | def reset(self, **kwargs) -> GymObs:
61 | """
62 | Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True
63 |
64 | :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords
65 | :return: the first observation of the environment
66 | """
67 | if not self.allow_early_resets and not self.needs_reset:
68 | raise RuntimeError(
69 | "Tried to reset an environment before done. If you want to allow early resets, "
70 | "wrap your env with Monitor(env, path, allow_early_resets=True)"
71 | )
72 | self.rewards = []
73 | self.needs_reset = False
74 | for key in self.reset_keywords:
75 | value = kwargs.get(key)
76 | if value is None:
77 | raise ValueError(f"Expected you to pass keyword argument {key} into reset")
78 | self.current_reset_info[key] = value
79 | return self.env.reset(**kwargs)
80 |
81 | def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
82 | """
83 | Step the environment with the given action
84 |
85 | :param action: the action
86 | :return: observation, reward, done, information
87 | """
88 | if self.needs_reset:
89 | raise RuntimeError("Tried to step environment that needs reset")
90 | observation, reward, done, info = self.env.step(action)
91 | self.rewards.append(reward)
92 | if done:
93 | self.needs_reset = True
94 | ep_rew = sum(self.rewards)
95 | ep_len = len(self.rewards)
96 | ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)}
97 | for key in self.info_keywords:
98 | ep_info[key] = info[key]
99 | self.episode_returns.append(ep_rew)
100 | self.episode_lengths.append(ep_len)
101 | self.episode_times.append(time.time() - self.t_start)
102 | ep_info.update(self.current_reset_info)
103 | if self.results_writer:
104 | self.results_writer.write_row(ep_info)
105 | info["episode"] = ep_info
106 | self.total_steps += 1
107 | return observation, reward, done, info
108 |
109 | def close(self) -> None:
110 | """
111 | Closes the environment
112 | """
113 | super(Monitor, self).close()
114 | if self.results_writer is not None:
115 | self.results_writer.close()
116 |
117 | def get_total_steps(self) -> int:
118 | """
119 | Returns the total number of timesteps
120 |
121 | :return:
122 | """
123 | return self.total_steps
124 |
125 | def get_episode_rewards(self) -> List[float]:
126 | """
127 | Returns the rewards of all the episodes
128 |
129 | :return:
130 | """
131 | return self.episode_returns
132 |
133 | def get_episode_lengths(self) -> List[int]:
134 | """
135 | Returns the number of timesteps of all the episodes
136 |
137 | :return:
138 | """
139 | return self.episode_lengths
140 |
141 | def get_episode_times(self) -> List[float]:
142 | """
143 | Returns the runtime in seconds of all the episodes
144 |
145 | :return:
146 | """
147 | return self.episode_times
148 |
149 |
150 | class LoadMonitorResultsError(Exception):
151 | """
152 | Raised when loading the monitor log fails.
153 | """
154 |
155 | pass
156 |
157 |
158 | class ResultsWriter:
159 | """
160 | A result writer that saves the data from the `Monitor` class
161 |
162 | :param filename: the location to save a log file, can be None for no log
163 | :param header: the header dictionary object of the saved csv
164 |     :param extra_keys: the extra information to log, typically composed of
165 | ``reset_keywords`` and ``info_keywords``
166 | """
167 |
168 | def __init__(
169 | self,
170 | filename: str = "",
171 | header: Optional[Dict[str, Union[float, str]]] = None,
172 | extra_keys: Tuple[str, ...] = (),
173 | ):
174 | if header is None:
175 | header = {}
176 | if not filename.endswith(Monitor.EXT):
177 | if os.path.isdir(filename):
178 | filename = os.path.join(filename, Monitor.EXT)
179 | else:
180 | filename = filename + "." + Monitor.EXT
181 | self.file_handler = open(filename, "wt")
182 | self.file_handler.write("#%s\n" % json.dumps(header))
183 | self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
184 | self.logger.writeheader()
185 | self.file_handler.flush()
186 |
187 | def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
188 | """
189 |         Write a row of episode information to the csv log file and flush it
190 |
191 | :param epinfo: the information on episodic return, length, and time
192 | """
193 | if self.logger:
194 | self.logger.writerow(epinfo)
195 | self.file_handler.flush()
196 |
197 | def close(self) -> None:
198 | """
199 | Close the file handler
200 | """
201 | self.file_handler.close()
202 |
203 |
204 | def get_monitor_files(path: str) -> List[str]:
205 | """
206 | get all the monitor files in the given path
207 |
208 | :param path: the logging folder
209 | :return: the log files
210 | """
211 | return glob(os.path.join(path, "*" + Monitor.EXT))
212 |
213 |
214 | def load_results(path: str) -> pandas.DataFrame:
215 | """
216 | Load all Monitor logs from a given directory path matching ``*monitor.csv``
217 |
218 | :param path: the directory path containing the log file(s)
219 | :return: the logged data
220 | """
221 | monitor_files = get_monitor_files(path)
222 | if len(monitor_files) == 0:
223 | raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
224 | data_frames, headers = [], []
225 | for file_name in monitor_files:
226 | with open(file_name, "rt") as file_handler:
227 | first_line = file_handler.readline()
228 | assert first_line[0] == "#"
229 | header = json.loads(first_line[1:])
230 | data_frame = pandas.read_csv(file_handler, index_col=None)
231 | headers.append(header)
232 | data_frame["t"] += header["t_start"]
233 | data_frames.append(data_frame)
234 | data_frame = pandas.concat(data_frames)
235 | data_frame.sort_values("t", inplace=True)
236 | data_frame.reset_index(inplace=True)
237 | data_frame["t"] -= min(header["t_start"] for header in headers)
238 | return data_frame
239 |
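A hedged usage sketch of ``Monitor`` and ``load_results`` (the log directory is a placeholder, and the example follows the old gym step API used throughout this vendored copy):

    import os
    import gym
    from stable_baselines3.common.monitor import Monitor, load_results

    log_dir = "/tmp/monitor_logs"  # placeholder directory
    os.makedirs(log_dir, exist_ok=True)
    env = Monitor(gym.make("CartPole-v1"), filename=os.path.join(log_dir, "run"))
    obs = env.reset()
    for _ in range(200):  # enough random steps to complete a few episodes
        obs, reward, done, info = env.step(env.action_space.sample())
        if done:
            obs = env.reset()
    df = load_results(log_dir)  # one row per completed episode: columns "r", "l", "t"
    print(df[["r", "l"]].describe())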
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/noise.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from abc import ABC, abstractmethod
3 | from typing import Iterable, List, Optional
4 |
5 | import numpy as np
6 |
7 |
8 | class ActionNoise(ABC):
9 | """
10 | The action noise base class
11 | """
12 |
13 | def __init__(self):
14 | super(ActionNoise, self).__init__()
15 |
16 | def reset(self) -> None:
17 | """
18 | call end of episode reset for the noise
19 | """
20 | pass
21 |
22 | @abstractmethod
23 | def __call__(self) -> np.ndarray:
24 | raise NotImplementedError()
25 |
26 |
27 | class NormalActionNoise(ActionNoise):
28 | """
29 | A Gaussian action noise
30 |
31 | :param mean: the mean value of the noise
32 | :param sigma: the scale of the noise (std here)
33 | """
34 |
35 | def __init__(self, mean: np.ndarray, sigma: np.ndarray):
36 | self._mu = mean
37 | self._sigma = sigma
38 | super(NormalActionNoise, self).__init__()
39 |
40 | def __call__(self) -> np.ndarray:
41 | return np.random.normal(self._mu, self._sigma)
42 |
43 | def __repr__(self) -> str:
44 | return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
45 |
46 |
47 | class OrnsteinUhlenbeckActionNoise(ActionNoise):
48 | """
49 | An Ornstein Uhlenbeck action noise, this is designed to approximate Brownian motion with friction.
50 |
51 | Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
52 |
53 | :param mean: the mean of the noise
54 | :param sigma: the scale of the noise
55 | :param theta: the rate of mean reversion
56 | :param dt: the timestep for the noise
57 | :param initial_noise: the initial value for the noise output, (if None: 0)
58 | """
59 |
60 | def __init__(
61 | self,
62 | mean: np.ndarray,
63 | sigma: np.ndarray,
64 | theta: float = 0.15,
65 | dt: float = 1e-2,
66 | initial_noise: Optional[np.ndarray] = None,
67 | ):
68 | self._theta = theta
69 | self._mu = mean
70 | self._sigma = sigma
71 | self._dt = dt
72 | self.initial_noise = initial_noise
73 | self.noise_prev = np.zeros_like(self._mu)
74 | self.reset()
75 | super(OrnsteinUhlenbeckActionNoise, self).__init__()
76 |
77 | def __call__(self) -> np.ndarray:
78 | noise = (
79 | self.noise_prev
80 | + self._theta * (self._mu - self.noise_prev) * self._dt
81 | + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
82 | )
83 | self.noise_prev = noise
84 | return noise
85 |
86 | def reset(self) -> None:
87 | """
88 | reset the Ornstein Uhlenbeck noise, to the initial position
89 | """
90 | self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu)
91 |
92 | def __repr__(self) -> str:
93 | return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"
94 |
95 |
96 | class VectorizedActionNoise(ActionNoise):
97 | """
98 | A Vectorized action noise for parallel environments.
99 |
100 | :param base_noise: ActionNoise The noise generator to use
101 | :param n_envs: The number of parallel environments
102 | """
103 |
104 | def __init__(self, base_noise: ActionNoise, n_envs: int):
105 | try:
106 | self.n_envs = int(n_envs)
107 | assert self.n_envs > 0
108 | except (TypeError, AssertionError):
109 |             raise ValueError(f"Expected n_envs={n_envs} to be a positive integer")
110 |
111 | self.base_noise = base_noise
112 | self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]
113 |
114 | def reset(self, indices: Optional[Iterable[int]] = None) -> None:
115 | """
116 | Reset all the noise processes, or those listed in indices
117 |
118 | :param indices: Optional[Iterable[int]] The indices to reset. Default: None.
119 | If the parameter is None, then all processes are reset to their initial position.
120 | """
121 | if indices is None:
122 | indices = range(len(self.noises))
123 |
124 | for index in indices:
125 | self.noises[index].reset()
126 |
127 | def __repr__(self) -> str:
128 |         return f"VecNoise(BaseNoise={repr(self.base_noise)}, n_envs={len(self.noises)})"
129 |
130 | def __call__(self) -> np.ndarray:
131 | """
132 | Generate and stack the action noise from each noise object
133 | """
134 | noise = np.stack([noise() for noise in self.noises])
135 | return noise
136 |
137 | @property
138 | def base_noise(self) -> ActionNoise:
139 | return self._base_noise
140 |
141 | @base_noise.setter
142 | def base_noise(self, base_noise: ActionNoise) -> None:
143 | if base_noise is None:
144 | raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
145 | if not isinstance(base_noise, ActionNoise):
146 | raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
147 | self._base_noise = base_noise
148 |
149 | @property
150 | def noises(self) -> List[ActionNoise]:
151 | return self._noises
152 |
153 | @noises.setter
154 | def noises(self, noises: List[ActionNoise]) -> None:
155 | noises = list(noises) # raises TypeError if not iterable
156 | assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."
157 |
158 | different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]
159 |
160 | if len(different_types):
161 | raise ValueError(
162 | f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise)
163 | )
164 |
165 | self._noises = noises
166 | for noise in noises:
167 | noise.reset()
168 |
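A brief sketch of the noise classes above, replicating one Ornstein-Uhlenbeck process per parallel environment:

    import numpy as np
    from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise, VectorizedActionNoise

    base = OrnsteinUhlenbeckActionNoise(mean=np.zeros(2), sigma=0.1 * np.ones(2))
    vec_noise = VectorizedActionNoise(base, n_envs=4)  # deep-copies the base noise 4 times
    sample = vec_noise()                               # shape (4, 2): one correlated stream per env
    vec_noise.reset([0, 2])                            # reset only the processes of envs 0 and 2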
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/preprocessing.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Dict, Tuple, Union
3 |
4 | import numpy as np
5 | import torch as th
6 | from gym import spaces
7 | from torch.nn import functional as F
8 |
9 |
10 | def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
11 | """
12 | Check if an image observation space (see ``is_image_space``)
13 | is channels-first (CxHxW, True) or channels-last (HxWxC, False).
14 |
15 | Use a heuristic that channel dimension is the smallest of the three.
16 | If second dimension is smallest, raise an exception (no support).
17 |
18 | :param observation_space:
19 | :return: True if observation space is channels-first image, False if channels-last.
20 | """
21 | smallest_dimension = np.argmin(observation_space.shape).item()
22 | if smallest_dimension == 1:
23 | warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.")
24 | return smallest_dimension == 0
25 |
26 |
27 | def is_image_space(
28 | observation_space: spaces.Space,
29 | check_channels: bool = False,
30 | ) -> bool:
31 | """
32 | Check if a observation space has the shape, limits and dtype
33 | of a valid image.
34 | The check is conservative, so that it returns False if there is a doubt.
35 |
36 | Valid images: RGB, RGBD, GrayScale with values in [0, 255]
37 |
38 | :param observation_space:
39 | :param check_channels: Whether to do or not the check for the number of channels.
40 | e.g., with frame-stacking, the observation space may have more channels than expected.
41 | :return:
42 | """
43 | if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
44 | # Check the type
45 | if observation_space.dtype != np.uint8:
46 | return False
47 |
48 | # Check the value range
49 | if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
50 | return False
51 |
52 | # Skip channels check
53 | if not check_channels:
54 | return True
55 | # Check the number of channels
56 | if is_image_space_channels_first(observation_space):
57 | n_channels = observation_space.shape[0]
58 | else:
59 | n_channels = observation_space.shape[-1]
60 | # RGB, RGBD, GrayScale
61 | return n_channels in [1, 3, 4]
62 | return False
63 |
64 |
65 | def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
66 | """
67 | Handle the different cases for images as PyTorch use channel first format.
68 |
69 | :param observation:
70 | :param observation_space:
71 | :return: channel first observation if observation is an image
72 | """
73 | # Avoid circular import
74 | from stable_baselines3.common.vec_env import VecTransposeImage
75 |
76 | if is_image_space(observation_space):
77 | if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape):
78 | # Try to re-order the channels
79 | transpose_obs = VecTransposeImage.transpose_image(observation)
80 | if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape:
81 | observation = transpose_obs
82 | return observation
83 |
84 |
85 | def preprocess_obs(
86 | obs: th.Tensor,
87 | observation_space: spaces.Space,
88 | normalize_images: bool = True,
89 | ) -> Union[th.Tensor, Dict[str, th.Tensor]]:
90 | """
91 |     Preprocess observation to be fed to a neural network.
92 | For images, it normalizes the values by dividing them by 255 (to have values in [0, 1])
93 |     For discrete observations, it creates a one hot vector.
94 |
95 | :param obs: Observation
96 | :param observation_space:
97 | :param normalize_images: Whether to normalize images or not
98 | (True by default)
99 | :return:
100 | """
101 | if isinstance(observation_space, spaces.Box):
102 | if is_image_space(observation_space) and normalize_images:
103 | return obs.float() / 255.0
104 | return obs.float()
105 |
106 | elif isinstance(observation_space, spaces.Discrete):
107 | # One hot encoding and convert to float to avoid errors
108 | return F.one_hot(obs.long(), num_classes=observation_space.n).float()
109 |
110 | elif isinstance(observation_space, spaces.MultiDiscrete):
111 | # Tensor concatenation of one hot encodings of each Categorical sub-space
112 | return th.cat(
113 | [
114 | F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
115 | for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))
116 | ],
117 | dim=-1,
118 | ).view(obs.shape[0], sum(observation_space.nvec))
119 |
120 | elif isinstance(observation_space, spaces.MultiBinary):
121 | return obs.float()
122 |
123 | elif isinstance(observation_space, spaces.Dict):
124 | # Do not modify by reference the original observation
125 | preprocessed_obs = {}
126 | for key, _obs in obs.items():
127 | preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
128 | return preprocessed_obs
129 |
130 | else:
131 | raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
132 |
133 |
134 | def get_obs_shape(
135 | observation_space: spaces.Space,
136 | ) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
137 | """
138 | Get the shape of the observation (useful for the buffers).
139 |
140 | :param observation_space:
141 | :return:
142 | """
143 | if isinstance(observation_space, spaces.Box):
144 | return observation_space.shape
145 | elif isinstance(observation_space, spaces.Discrete):
146 | # Observation is an int
147 | return (1,)
148 | elif isinstance(observation_space, spaces.MultiDiscrete):
149 | # Number of discrete features
150 | return (int(len(observation_space.nvec)),)
151 | elif isinstance(observation_space, spaces.MultiBinary):
152 | # Number of binary features
153 | return (int(observation_space.n),)
154 | elif isinstance(observation_space, spaces.Dict):
155 | return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}
156 |
157 | else:
158 | raise NotImplementedError(f"{observation_space} observation space is not supported")
159 |
160 |
161 | def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
162 | """
163 | Get the dimension of the observation space when flattened.
164 | It does not apply to image observation space.
165 |
166 | Used by the ``FlattenExtractor`` to compute the input shape.
167 |
168 | :param observation_space:
169 | :return:
170 | """
171 | # See issue https://github.com/openai/gym/issues/1915
172 | # it may be a problem for Dict/Tuple spaces too...
173 | if isinstance(observation_space, spaces.MultiDiscrete):
174 | return sum(observation_space.nvec)
175 | else:
176 | # Use Gym internal method
177 | return spaces.utils.flatdim(observation_space)
178 |
179 |
180 | def get_action_dim(action_space: spaces.Space) -> int:
181 | """
182 | Get the dimension of the action space.
183 |
184 | :param action_space:
185 | :return:
186 | """
187 | if isinstance(action_space, spaces.Box):
188 | return int(np.prod(action_space.shape))
189 | elif isinstance(action_space, spaces.Discrete):
190 | # Action is an int
191 | return 1
192 | elif isinstance(action_space, spaces.MultiDiscrete):
193 | # Number of discrete actions
194 | return int(len(action_space.nvec))
195 | elif isinstance(action_space, spaces.MultiBinary):
196 | # Number of binary actions
197 | return int(action_space.n)
198 | else:
199 | raise NotImplementedError(f"{action_space} action space is not supported")
200 |
201 |
202 | def check_for_nested_spaces(obs_space: spaces.Space):
203 | """
204 | Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples).
205 | If so, raise an Exception informing that there is no support for this.
206 |
207 | :param obs_space: an observation space
208 | :return:
209 | """
210 | if isinstance(obs_space, (spaces.Dict, spaces.Tuple)):
211 | sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces
212 | for sub_space in sub_spaces:
213 | if isinstance(sub_space, (spaces.Dict, spaces.Tuple)):
214 | raise NotImplementedError(
215 | "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)."
216 | )
217 |
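Illustrative calls to the helpers above (a sketch; shapes and dtypes are chosen so the space passes ``is_image_space``):

    import numpy as np
    import torch as th
    from gym import spaces
    from stable_baselines3.common.preprocessing import is_image_space, preprocess_obs, get_obs_shape

    img_space = spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8)
    print(is_image_space(img_space))                    # True: uint8, 3-D, bounds [0, 255]
    obs = th.as_tensor(img_space.sample()[None])        # add a batch dimension
    print(preprocess_obs(obs, img_space).max() <= 1.0)  # images are rescaled to [0, 1]
    print(get_obs_shape(spaces.Discrete(4)))            # (1,): a discrete obs is stored as one int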
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/results_plotter.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | # import matplotlib
7 | # matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
8 | from matplotlib import pyplot as plt
9 |
10 | from stable_baselines3.common.monitor import load_results
11 |
12 | X_TIMESTEPS = "timesteps"
13 | X_EPISODES = "episodes"
14 | X_WALLTIME = "walltime_hrs"
15 | POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
16 | EPISODES_WINDOW = 100
17 |
18 |
19 | def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
20 | """
21 | Apply a rolling window to a np.ndarray
22 |
23 | :param array: the input Array
24 | :param window: length of the rolling window
25 | :return: rolling window on the input array
26 | """
27 | shape = array.shape[:-1] + (array.shape[-1] - window + 1, window)
28 | strides = array.strides + (array.strides[-1],)
29 | return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
30 |
31 |
32 | def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
33 | """
34 | Apply a function to the rolling window of 2 arrays
35 |
36 | :param var_1: variable 1
37 | :param var_2: variable 2
38 | :param window: length of the rolling window
39 | :param func: function to apply on the rolling window on variable 2 (such as np.mean)
40 | :return: the rolling output with applied function
41 | """
42 | var_2_window = rolling_window(var_2, window)
43 | function_on_var2 = func(var_2_window, axis=-1)
44 | return var_1[window - 1 :], function_on_var2
45 |
46 |
47 | def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
48 | """
49 |     Decompose a data frame variable to xs and ys
50 |
51 | :param data_frame: the input data
52 | :param x_axis: the axis for the x and y output
53 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
54 | :return: the x and y output
55 | """
56 | if x_axis == X_TIMESTEPS:
57 | x_var = np.cumsum(data_frame.l.values)
58 | y_var = data_frame.r.values
59 | elif x_axis == X_EPISODES:
60 | x_var = np.arange(len(data_frame))
61 | y_var = data_frame.r.values
62 | elif x_axis == X_WALLTIME:
63 | # Convert to hours
64 | x_var = data_frame.t.values / 3600.0
65 | y_var = data_frame.r.values
66 | else:
67 | raise NotImplementedError
68 | return x_var, y_var
69 |
70 |
71 | def plot_curves(
72 | xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
73 | ) -> None:
74 | """
75 | plot the curves
76 |
77 | :param xy_list: the x and y coordinates to plot
78 | :param x_axis: the axis for the x and y output
79 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
80 | :param title: the title of the plot
81 | :param figsize: Size of the figure (width, height)
82 | """
83 |
84 | plt.figure(title, figsize=figsize)
85 | max_x = max(xy[0][-1] for xy in xy_list)
86 | min_x = 0
87 | for (_, (x, y)) in enumerate(xy_list):
88 | plt.scatter(x, y, s=2)
89 | # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
90 | if x.shape[0] >= EPISODES_WINDOW:
91 | # Compute and plot rolling mean with window of size EPISODE_WINDOW
92 | x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
93 | plt.plot(x, y_mean)
94 | plt.xlim(min_x, max_x)
95 | plt.title(title)
96 | plt.xlabel(x_axis)
97 | plt.ylabel("Episode Rewards")
98 | plt.tight_layout()
99 |
100 |
101 | def plot_results(
102 | dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)
103 | ) -> None:
104 | """
105 | Plot the results using csv files from ``Monitor`` wrapper.
106 |
107 | :param dirs: the save location of the results to plot
108 | :param num_timesteps: only plot the points below this value
109 | :param x_axis: the axis for the x and y output
110 | (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
111 | :param task_name: the title of the task to plot
112 | :param figsize: Size of the figure (width, height)
113 | """
114 |
115 | data_frames = []
116 | for folder in dirs:
117 | data_frame = load_results(folder)
118 | if num_timesteps is not None:
119 | data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps]
120 | data_frames.append(data_frame)
121 | xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames]
122 | plot_curves(xy_list, x_axis, task_name, figsize)
123 |
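A possible call to ``plot_results`` on Monitor logs (the directory and task name are placeholders; ``plt.show()`` is needed to display the figure interactively):

    from matplotlib import pyplot as plt
    from stable_baselines3.common.results_plotter import plot_results, X_TIMESTEPS

    plot_results(["/tmp/monitor_logs"], num_timesteps=None, x_axis=X_TIMESTEPS, task_name="QRDQN Town01")
    plt.show()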
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/running_mean_std.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 |
5 |
6 | class RunningMeanStd(object):
7 | def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
8 | """
9 |         Calculates the running mean and std of a data stream
10 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
11 |
12 | :param epsilon: helps with arithmetic issues
13 | :param shape: the shape of the data stream's output
14 | """
15 | self.mean = np.zeros(shape, np.float64)
16 | self.var = np.ones(shape, np.float64)
17 | self.count = epsilon
18 |
19 | def update(self, arr: np.ndarray) -> None:
20 | batch_mean = np.mean(arr, axis=0)
21 | batch_var = np.var(arr, axis=0)
22 | batch_count = arr.shape[0]
23 | self.update_from_moments(batch_mean, batch_var, batch_count)
24 |
25 | def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
26 | delta = batch_mean - self.mean
27 | tot_count = self.count + batch_count
28 |
29 | new_mean = self.mean + delta * batch_count / tot_count
30 | m_a = self.var * self.count
31 | m_b = batch_var * batch_count
32 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
33 | new_var = m_2 / (self.count + batch_count)
34 |
35 | new_count = batch_count + self.count
36 |
37 | self.mean = new_mean
38 | self.var = new_var
39 | self.count = new_count
40 |
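A quick numeric check of ``RunningMeanStd`` across two batches (a sketch; the small epsilon prior keeps the result marginally below the exact sample statistics):

    import numpy as np
    from stable_baselines3.common.running_mean_std import RunningMeanStd

    rms = RunningMeanStd(shape=(1,))
    rms.update(np.array([[1.0], [2.0], [3.0]]))
    rms.update(np.array([[4.0], [5.0]]))
    print(rms.mean, rms.var)  # approximately [3.] and [2.], the stats of the stream 1..5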
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/sb2_compat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/stable_baselines3/common/sb2_compat/__init__.py
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Iterable, Optional
2 |
3 | import torch
4 | from torch.optim import Optimizer
5 |
6 |
7 | class RMSpropTFLike(Optimizer):
8 | r"""Implements RMSprop algorithm with closer match to Tensorflow version.
9 |
10 | For reproducibility with original stable-baselines. Use this
11 | version with e.g. A2C for stabler learning than with the PyTorch
12 | RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop.
13 |
14 |     See a more thorough conversion in the pytorch-image-models repository:
15 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py
16 |
17 | Changes to the original RMSprop:
18 | - Move epsilon inside square root
19 | - Initialize squared gradient to ones rather than zeros
20 |
21 | Proposed by G. Hinton in his
22 | `course `_.
23 |
24 | The centered version first appears in `Generating Sequences
25 | With Recurrent Neural Networks `_.
26 |
27 |     The implementation here adds epsilon to the moving average of the squared gradient
28 |     before taking the square root (matching TensorFlow; PyTorch adds epsilon after the
29 |     square root). The effective learning rate is thus :math:`\alpha/\sqrt{v + \epsilon}` where :math:`\alpha`
30 | is the scheduled learning rate and :math:`v` is the weighted moving average
31 | of the squared gradient.
32 |
33 |     :param params: iterable of parameters to optimize or dicts defining
34 | parameter groups
35 | :param lr: learning rate (default: 1e-2)
36 | :param momentum: momentum factor (default: 0)
37 | :param alpha: smoothing constant (default: 0.99)
38 | :param eps: term added to the denominator to improve
39 | numerical stability (default: 1e-8)
40 | :param centered: if ``True``, compute the centered RMSProp,
41 | the gradient is normalized by an estimation of its variance
42 | :param weight_decay: weight decay (L2 penalty) (default: 0)
43 |
44 | """
45 |
46 | def __init__(
47 | self,
48 | params: Iterable[torch.nn.Parameter],
49 | lr: float = 1e-2,
50 | alpha: float = 0.99,
51 | eps: float = 1e-8,
52 | weight_decay: float = 0,
53 | momentum: float = 0,
54 | centered: bool = False,
55 | ):
56 | if not 0.0 <= lr:
57 | raise ValueError("Invalid learning rate: {}".format(lr))
58 | if not 0.0 <= eps:
59 | raise ValueError("Invalid epsilon value: {}".format(eps))
60 | if not 0.0 <= momentum:
61 | raise ValueError("Invalid momentum value: {}".format(momentum))
62 | if not 0.0 <= weight_decay:
63 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
64 | if not 0.0 <= alpha:
65 | raise ValueError("Invalid alpha value: {}".format(alpha))
66 |
67 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
68 | super(RMSpropTFLike, self).__init__(params, defaults)
69 |
70 | def __setstate__(self, state: Dict[str, Any]) -> None:
71 | super(RMSpropTFLike, self).__setstate__(state)
72 | for group in self.param_groups:
73 | group.setdefault("momentum", 0)
74 | group.setdefault("centered", False)
75 |
76 | @torch.no_grad()
77 | def step(self, closure: Optional[Callable[[], None]] = None) -> Optional[torch.Tensor]:
78 | """Performs a single optimization step.
79 |
80 | :param closure: A closure that reevaluates the model
81 | and returns the loss.
82 | :return: loss
83 | """
84 | loss = None
85 | if closure is not None:
86 | with torch.enable_grad():
87 | loss = closure()
88 |
89 | for group in self.param_groups:
90 | for p in group["params"]:
91 | if p.grad is None:
92 | continue
93 | grad = p.grad
94 | if grad.is_sparse:
95 | raise RuntimeError("RMSpropTF does not support sparse gradients")
96 | state = self.state[p]
97 |
98 | # State initialization
99 | if len(state) == 0:
100 | state["step"] = 0
101 | # PyTorch initialized to zeros here
102 | state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format)
103 | if group["momentum"] > 0:
104 | state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
105 | if group["centered"]:
106 | state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
107 |
108 | square_avg = state["square_avg"]
109 | alpha = group["alpha"]
110 |
111 | state["step"] += 1
112 |
113 | if group["weight_decay"] != 0:
114 | grad = grad.add(p, alpha=group["weight_decay"])
115 |
116 | square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
117 |
118 | if group["centered"]:
119 | grad_avg = state["grad_avg"]
120 | grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
121 | # PyTorch added epsilon after square root
122 | # avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps'])
123 | avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_()
124 | else:
125 | # PyTorch added epsilon after square root
126 | # avg = square_avg.sqrt().add_(group['eps'])
127 | avg = square_avg.add(group["eps"]).sqrt_()
128 |
129 | if group["momentum"] > 0:
130 | buf = state["momentum_buffer"]
131 | buf.mul_(group["momentum"]).addcdiv_(grad, avg)
132 | p.add_(buf, alpha=-group["lr"])
133 | else:
134 | p.addcdiv_(grad, avg, value=-group["lr"])
135 |
136 | return loss
137 |
--------------------------------------------------------------------------------
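A minimal usage sketch of ``RMSpropTFLike`` as a drop-in replacement for ``torch.optim.RMSprop``, assuming the vendored ``stable_baselines3`` package is importable (e.g. with ``thirdparty`` on the Python path); the model and data are illustrative only.

```python
import torch

from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike

# Toy regression problem, just to exercise the optimizer.
model = torch.nn.Linear(4, 1)
optimizer = RMSpropTFLike(model.parameters(), lr=1e-2, alpha=0.99, eps=1e-5)

x = torch.randn(32, 4)
y = torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # epsilon is added inside the square root, TensorFlow-style
```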
/thirdparty/stable_baselines3/common/type_aliases.py:
--------------------------------------------------------------------------------
1 | """Common aliases for type hints"""
2 |
3 | from enum import Enum
4 | from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
5 |
6 | import gym
7 | import numpy as np
8 | import torch as th
9 |
10 | from stable_baselines3.common import callbacks, vec_env
11 |
12 | GymEnv = Union[gym.Env, vec_env.VecEnv]
13 | GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
14 | GymStepReturn = Tuple[GymObs, float, bool, Dict]
15 | TensorDict = Dict[Union[str, int], th.Tensor]
16 | OptimizerStateDict = Dict[str, Any]
17 | MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
18 |
19 | # A schedule takes the remaining progress as input
20 | # and outputs a scalar (e.g. learning rate, clip range, ...)
21 | Schedule = Callable[[float], float]
22 |
23 |
24 | class RolloutBufferSamples(NamedTuple):
25 | observations: th.Tensor
26 | actions: th.Tensor
27 | old_values: th.Tensor
28 | old_log_prob: th.Tensor
29 | advantages: th.Tensor
30 | returns: th.Tensor
31 |
32 |
33 | class DictRolloutBufferSamples(RolloutBufferSamples):
34 | observations: TensorDict
35 | actions: th.Tensor
36 | old_values: th.Tensor
37 | old_log_prob: th.Tensor
38 | advantages: th.Tensor
39 | returns: th.Tensor
40 |
41 |
42 | class ReplayBufferSamples(NamedTuple):
43 | observations: th.Tensor
44 | actions: th.Tensor
45 | next_observations: th.Tensor
46 | dones: th.Tensor
47 | rewards: th.Tensor
48 |
49 |
50 | class DictReplayBufferSamples(ReplayBufferSamples):
51 | observations: TensorDict
52 | actions: th.Tensor
53 | next_observations: th.Tensor
54 | dones: th.Tensor
55 | rewards: th.Tensor
56 |
57 |
58 | class RolloutReturn(NamedTuple):
59 | episode_reward: float
60 | episode_timesteps: int
61 | n_episodes: int
62 | continue_training: bool
63 |
64 |
65 | class TrainFrequencyUnit(Enum):
66 | STEP = "step"
67 | EPISODE = "episode"
68 |
69 |
70 | class TrainFreq(NamedTuple):
71 | frequency: int
72 | unit: TrainFrequencyUnit # either "step" or "episode"
73 |
--------------------------------------------------------------------------------
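The ``Schedule`` alias above is just a callable from the remaining progress to a scalar. A sketch of a linear learning-rate schedule (the ``linear_schedule`` helper is illustrative, not part of the module):

```python
from stable_baselines3.common.type_aliases import Schedule


def linear_schedule(initial_value: float) -> Schedule:
    """Return a Schedule that decays linearly from initial_value to 0."""

    def func(progress_remaining: float) -> float:
        # progress_remaining goes from 1.0 (start of training) to 0.0 (end).
        return progress_remaining * initial_value

    return func


lr_schedule = linear_schedule(3e-4)
print(lr_schedule(1.0), lr_schedule(0.5))  # 0.0003 0.00015
```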
/thirdparty/stable_baselines3/common/vec_env/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa F401
2 | import typing
3 | from copy import deepcopy
4 | from typing import Optional, Type, Union
5 |
6 | from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
8 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
9 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
10 | from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
11 | from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
12 | from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
13 | from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
14 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
15 | from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
16 | from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
17 |
18 | # Avoid circular import
19 | if typing.TYPE_CHECKING:
20 | from stable_baselines3.common.type_aliases import GymEnv
21 |
22 |
23 | def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
24 | """
25 | Retrieve a ``VecEnvWrapper`` object by recursively searching.
26 |
27 | :param env:
28 | :param vec_wrapper_class:
29 | :return:
30 | """
31 | env_tmp = env
32 | while isinstance(env_tmp, VecEnvWrapper):
33 | if isinstance(env_tmp, vec_wrapper_class):
34 | return env_tmp
35 | env_tmp = env_tmp.venv
36 | return None
37 |
38 |
39 | def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
40 | """
41 | :param env:
42 | :return:
43 | """
44 | return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
45 |
46 |
47 | def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
48 | """
49 | Check if an environment is already wrapped by a given ``VecEnvWrapper``.
50 |
51 | :param env:
52 | :param vec_wrapper_class:
53 | :return:
54 | """
55 | return unwrap_vec_wrapper(env, vec_wrapper_class) is not None
56 |
57 |
58 | # Define here to avoid circular import
59 | def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
60 | """
61 | Sync eval env and train env when using VecNormalize
62 |
63 | :param env:
64 | :param eval_env:
65 | """
66 | env_tmp, eval_env_tmp = env, eval_env
67 | while isinstance(env_tmp, VecEnvWrapper):
68 | if isinstance(env_tmp, VecNormalize):
69 | eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
70 | eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
71 | env_tmp = env_tmp.venv
72 | eval_env_tmp = eval_env_tmp.venv
73 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/vec_env/dummy_vec_env.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from copy import deepcopy
3 | from typing import Any, Callable, List, Optional, Sequence, Type, Union
4 |
5 | import gym
6 | import numpy as np
7 |
8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
9 | from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
10 |
11 |
12 | class DummyVecEnv(VecEnv):
13 | """
14 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
15 | Python process. This is useful for computationally simple environments such as ``CartPole-v1``,
16 | as the overhead of multiprocessing or multithreading outweighs the environment computation time.
17 | This can also be used for RL methods that
18 | require a vectorized environment, but where you want to train with a single environment.
19 |
20 | :param env_fns: a list of functions
21 | that return environments to vectorize
22 | """
23 |
24 | def __init__(self, env_fns: List[Callable[[], gym.Env]]):
25 | self.envs = [fn() for fn in env_fns]
26 | env = self.envs[0]
27 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
28 | obs_space = env.observation_space
29 | self.keys, shapes, dtypes = obs_space_info(obs_space)
30 |
31 | self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys])
32 | self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
33 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
34 | self.buf_infos = [{} for _ in range(self.num_envs)]
35 | self.actions = None
36 | self.metadata = env.metadata
37 |
38 | ##### local modification #####
39 | self.discount = env.discount
40 | self.num_states = env.get_num_of_states()
41 | self.save_q_vals = True if env.agent == "QRDQN" else False
42 |
43 | ##### local modification #####
44 | def get_obs_at_state(self, state:int) -> np.ndarray:
45 | return self.envs[0].get_obs_at_state(state)
46 |
47 | ##### local modification #####
48 | def save_quantiles(self, quantiles:np.ndarray) -> None:
49 | self.envs[0].save_quantiles(quantiles)
50 |
51 | def step_async(self, actions: np.ndarray) -> None:
52 | self.actions = actions
53 |
54 | def step_wait(self) -> VecEnvStepReturn:
55 | for env_idx in range(self.num_envs):
56 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
57 | self.actions[env_idx]
58 | )
59 | if self.buf_dones[env_idx]:
60 | # save final observation where user can get it, then reset
61 | self.buf_infos[env_idx]["terminal_observation"] = obs
62 | obs = self.envs[env_idx].reset()
63 | self._save_obs(env_idx, obs)
64 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
65 |
66 | def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
67 | seeds = list()
68 | for idx, env in enumerate(self.envs):
69 | seeds.append(env.seed(seed + idx))
70 | return seeds
71 |
72 | def reset(self) -> VecEnvObs:
73 | for env_idx in range(self.num_envs):
74 | obs = self.envs[env_idx].reset()
75 | self._save_obs(env_idx, obs)
76 | return self._obs_from_buf()
77 |
78 | def close(self) -> None:
79 | for env in self.envs:
80 | env.close()
81 |
82 | def get_images(self) -> Sequence[np.ndarray]:
83 | return [env.render(mode="rgb_array") for env in self.envs]
84 |
85 | def render(self, mode: str = "human") -> Optional[np.ndarray]:
86 | """
87 | Gym environment rendering. If there are multiple environments then
88 | they are tiled together in one image via ``BaseVecEnv.render()``.
89 | Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the
90 | underlying environment.
91 |
92 | Therefore, some arguments such as ``mode`` will have values that are valid
93 | only when ``num_envs == 1``.
94 |
95 | :param mode: The rendering type.
96 | """
97 | if self.num_envs == 1:
98 | return self.envs[0].render(mode=mode)
99 | else:
100 | return super().render(mode=mode)
101 |
102 | def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
103 | for key in self.keys:
104 | if key is None:
105 | self.buf_obs[key][env_idx] = obs
106 | else:
107 | self.buf_obs[key][env_idx] = obs[key]
108 |
109 | def _obs_from_buf(self) -> VecEnvObs:
110 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
111 |
112 | def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
113 | """Return attribute from vectorized environment (see base class)."""
114 | target_envs = self._get_target_envs(indices)
115 | return [getattr(env_i, attr_name) for env_i in target_envs]
116 |
117 | def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
118 | """Set attribute inside vectorized environments (see base class)."""
119 | target_envs = self._get_target_envs(indices)
120 | for env_i in target_envs:
121 | setattr(env_i, attr_name, value)
122 |
123 | def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
124 | """Call instance methods of vectorized environments."""
125 | target_envs = self._get_target_envs(indices)
126 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
127 |
128 | def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
129 | """Check if worker environments are wrapped with a given wrapper"""
130 | target_envs = self._get_target_envs(indices)
131 | # Import here to avoid a circular import
132 | from stable_baselines3.common import env_util
133 |
134 | return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
135 |
136 | def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
137 | indices = self._get_indices(indices)
138 | return [self.envs[i] for i in indices]
139 |
--------------------------------------------------------------------------------
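The local modifications above make ``DummyVecEnv`` read a few extra attributes from the wrapped environment. A sketch with a hypothetical ``ToyEnv`` standing in for the CARLA env, showing the minimum interface the modified wrapper expects:

```python
import gym
import numpy as np

from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv


class ToyEnv(gym.Env):
    """Hypothetical stand-in for the CARLA env, exposing the attributes the local modification reads."""

    observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)
    discount = 0.99  # copied into DummyVecEnv.discount
    agent = "QRDQN"  # enables DummyVecEnv.save_q_vals

    def get_num_of_states(self):
        return 10  # copied into DummyVecEnv.num_states

    def reset(self):
        return np.zeros(2, dtype=np.float32)

    def step(self, action):
        return np.zeros(2, dtype=np.float32), 0.0, True, {}


venv = DummyVecEnv([ToyEnv])
obs = venv.reset()  # shape (1, 2): one env, 2-dim observation
obs, rewards, dones, infos = venv.step(np.array([0]))
print(venv.discount, venv.num_states, venv.save_q_vals)  # 0.99 10 True
```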
/thirdparty/stable_baselines3/common/vec_env/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for dealing with vectorized environments.
3 | """
4 | from collections import OrderedDict
5 | from typing import Any, Dict, List, Tuple
6 |
7 | import gym
8 | import numpy as np
9 |
10 | from stable_baselines3.common.preprocessing import check_for_nested_spaces
11 | from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
12 |
13 |
14 | def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
15 | """
16 | Deep-copy a dict of numpy arrays.
17 |
18 | :param obs: a dict of numpy arrays.
19 | :return: a dict of copied numpy arrays.
20 | """
21 | assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
22 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
23 |
24 |
25 | def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
26 | """
27 | Convert an internal representation raw_obs into the appropriate type
28 | specified by space.
29 |
30 | :param obs_space: an observation space.
31 | :param obs_dict: a dict of numpy arrays.
32 | :return: returns an observation of the same type as space.
33 | If space is Dict, function is identity; if space is Tuple, converts dict to Tuple;
34 | otherwise, space is unstructured and returns the value raw_obs[None].
35 | """
36 | if isinstance(obs_space, gym.spaces.Dict):
37 | return obs_dict
38 | elif isinstance(obs_space, gym.spaces.Tuple):
39 | assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
40 | return tuple((obs_dict[i] for i in range(len(obs_space.spaces))))
41 | else:
42 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
43 | return obs_dict[None]
44 |
45 |
46 | def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
47 | """
48 | Get dict-structured information about a gym.Space.
49 |
50 | Dict spaces are represented directly by their dict of subspaces.
51 | Tuple spaces are converted into a dict with keys indexing into the tuple.
52 | Unstructured spaces are represented by {None: obs_space}.
53 |
54 | :param obs_space: an observation space
55 | :return: A tuple (keys, shapes, dtypes):
56 | keys: a list of dict keys.
57 | shapes: a dict mapping keys to shapes.
58 | dtypes: a dict mapping keys to dtypes.
59 | """
60 | check_for_nested_spaces(obs_space)
61 | if isinstance(obs_space, gym.spaces.Dict):
62 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
63 | subspaces = obs_space.spaces
64 | elif isinstance(obs_space, gym.spaces.Tuple):
65 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
66 | else:
67 | assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
68 | subspaces = {None: obs_space}
69 | keys = []
70 | shapes = {}
71 | dtypes = {}
72 | for key, box in subspaces.items():
73 | keys.append(key)
74 | shapes[key] = box.shape
75 | dtypes[key] = box.dtype
76 | return keys, shapes, dtypes
77 |
--------------------------------------------------------------------------------
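A small illustration of ``obs_space_info``, which ``DummyVecEnv`` uses to allocate its observation buffers; the Dict space below is arbitrary.

```python
import numpy as np
from gym import spaces

from stable_baselines3.common.vec_env.util import obs_space_info

obs_space = spaces.Dict(
    {
        "image": spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8),
        "vector": spaces.Box(low=-1.0, high=1.0, shape=(5,), dtype=np.float32),
    }
)
keys, shapes, dtypes = obs_space_info(obs_space)
print(keys)    # ['image', 'vector']
print(shapes)  # {'image': (64, 64, 3), 'vector': (5,)}
print(dtypes)  # {'image': dtype('uint8'), 'vector': dtype('float32')}
```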
/thirdparty/stable_baselines3/common/vec_env/vec_check_nan.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 |
5 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
6 |
7 |
8 | class VecCheckNan(VecEnvWrapper):
9 | """
10 | NaN and inf checking wrapper for vectorized environments. It raises a warning by default,
11 | letting you know where the NaN or inf originated from.
12 |
13 | :param venv: the vectorized environment to wrap
14 | :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning
15 | :param warn_once: Whether or not to only warn once.
16 | :param check_inf: Whether or not to check for +inf or -inf as well
17 | """
18 |
19 | def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True):
20 | VecEnvWrapper.__init__(self, venv)
21 | self.raise_exception = raise_exception
22 | self.warn_once = warn_once
23 | self.check_inf = check_inf
24 | self._actions = None
25 | self._observations = None
26 | self._user_warned = False
27 |
28 | def step_async(self, actions: np.ndarray) -> None:
29 | self._check_val(async_step=True, actions=actions)
30 |
31 | self._actions = actions
32 | self.venv.step_async(actions)
33 |
34 | def step_wait(self) -> VecEnvStepReturn:
35 | observations, rewards, news, infos = self.venv.step_wait()
36 |
37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news)
38 |
39 | self._observations = observations
40 | return observations, rewards, news, infos
41 |
42 | def reset(self) -> VecEnvObs:
43 | observations = self.venv.reset()
44 | self._actions = None
45 |
46 | self._check_val(async_step=False, observations=observations)
47 |
48 | self._observations = observations
49 | return observations
50 |
51 | def _check_val(self, *, async_step: bool, **kwargs) -> None:
52 | # if warn and warn once and have warned once: then stop checking
53 | if not self.raise_exception and self.warn_once and self._user_warned:
54 | return
55 |
56 | found = []
57 | for name, val in kwargs.items():
58 | has_nan = np.any(np.isnan(val))
59 | has_inf = self.check_inf and np.any(np.isinf(val))
60 | if has_inf:
61 | found.append((name, "inf"))
62 | if has_nan:
63 | found.append((name, "nan"))
64 |
65 | if found:
66 | self._user_warned = True
67 | msg = ""
68 | for i, (name, type_val) in enumerate(found):
69 | msg += f"found {type_val} in {name}"
70 | if i != len(found) - 1:
71 | msg += ", "
72 |
73 | msg += ".\r\nOriginated from the "
74 |
75 | if not async_step:
76 | if self._actions is None:
77 | msg += "environment observation (at reset)"
78 | else:
79 | msg += f"environment, Last given value was: \r\n\taction={self._actions}"
80 | else:
81 | msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
82 |
83 | if self.raise_exception:
84 | raise ValueError(msg)
85 | else:
86 | warnings.warn(msg, UserWarning)
87 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/vec_env/vec_extract_dict_obs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
4 |
5 |
6 | class VecExtractDictObs(VecEnvWrapper):
7 | """
8 | A vectorized wrapper for extracting dictionary observations.
9 |
10 | :param venv: The vectorized environment
11 | :param key: The key of the dictionary observation
12 | """
13 |
14 | def __init__(self, venv: VecEnv, key: str):
15 | self.key = key
16 | super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])
17 |
18 | def reset(self) -> np.ndarray:
19 | obs = self.venv.reset()
20 | return obs[self.key]
21 |
22 | def step_wait(self) -> VecEnvStepReturn:
23 | obs, reward, done, info = self.venv.step_wait()
24 | return obs[self.key], reward, done, info
25 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/vec_env/vec_frame_stack.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple, Union
2 |
3 | import numpy as np
4 | from gym import spaces
5 |
6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
7 | from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
8 |
9 |
10 | class VecFrameStack(VecEnvWrapper):
11 | """
12 | Frame stacking wrapper for vectorized environment. Designed for image observations.
13 |
14 | Uses the StackedObservations class, or StackedDictObservations, depending on the observation space
15 |
16 | :param venv: the vectorized environment to wrap
17 | :param n_stack: Number of frames to stack
18 | :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
19 | If None (default), automatically detect the channel to stack over in case of image observations, or default to "last".
20 | Alternatively, channels_order can be a dictionary, for use with environments with Dict observation spaces.
21 | """
22 |
23 | def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None):
24 | self.venv = venv
25 | self.n_stack = n_stack
26 |
27 | wrapped_obs_space = venv.observation_space
28 |
29 | if isinstance(wrapped_obs_space, spaces.Box):
30 | assert not isinstance(
31 | channels_order, dict
32 | ), f"Expected None or string for channels_order but received {channels_order}"
33 | self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
34 |
35 | elif isinstance(wrapped_obs_space, spaces.Dict):
36 | self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
37 |
38 | else:
39 | raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces")
40 |
41 | observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space)
42 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
43 |
44 | def step_wait(
45 | self,
46 | ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
47 |
48 | observations, rewards, dones, infos = self.venv.step_wait()
49 |
50 | observations, infos = self.stackedobs.update(observations, dones, infos)
51 |
52 | return observations, rewards, dones, infos
53 |
54 | def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
55 | """
56 | Reset all environments
57 | """
58 | observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
59 |
60 | observation = self.stackedobs.reset(observation)
61 | return observation
62 |
63 | def close(self) -> None:
64 | self.venv.close()
65 |
--------------------------------------------------------------------------------
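Typical ``VecFrameStack`` usage on image observations, assuming the upstream stable-baselines3 package with the Atari extras installed (the vendored ``DummyVecEnv`` in this repo expects the custom CARLA attributes, so this sketch targets a stock install):

```python
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

venv = make_atari_env("BreakoutNoFrameskip-v4", n_envs=4, seed=0)
venv = VecFrameStack(venv, n_stack=4)  # stack 4 grayscale frames on the channel axis
obs = venv.reset()
print(obs.shape)  # (4, 84, 84, 4): n_envs x H x W x stacked channels
```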
/thirdparty/stable_baselines3/common/vec_env/vec_monitor.py:
--------------------------------------------------------------------------------
1 | import time
2 | import warnings
3 | from typing import Optional, Tuple
4 |
5 | import numpy as np
6 |
7 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
8 |
9 |
10 | class VecMonitor(VecEnvWrapper):
11 | """
12 | A vectorized monitor wrapper for *vectorized* Gym environments;
13 | it is used to record the episode reward, length, time and other data.
14 |
15 | Some environments like `openai/procgen <https://github.com/openai/procgen>`_
16 | or `gym3 <https://github.com/openai/gym3>`_ directly initialize the
17 | vectorized environments, without giving us a chance to use the ``Monitor``
18 | wrapper. So this class simply does the job of the ``Monitor`` wrapper on
19 | a vectorized level.
20 |
21 | :param venv: The vectorized environment
22 | :param filename: the location to save a log file, can be None for no log
23 | :param info_keywords: extra information to log, from the information return of env.step()
24 | """
25 |
26 | def __init__(
27 | self,
28 | venv: VecEnv,
29 | filename: Optional[str] = None,
30 | info_keywords: Tuple[str, ...] = (),
31 | ):
32 | # Avoid circular import
33 | from stable_baselines3.common.monitor import Monitor, ResultsWriter
34 |
35 | # This check is not valid for special `VecEnv`
36 | # like the ones created by Procgen, which do not completely follow
37 | # the `VecEnv` interface
38 | try:
39 | is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
40 | except AttributeError:
41 | is_wrapped_with_monitor = False
42 |
43 | if is_wrapped_with_monitor:
44 | warnings.warn(
45 | "The environment is already wrapped with a `Monitor` wrapper"
46 | "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be"
47 | "overwritten by the `VecMonitor` ones.",
48 | UserWarning,
49 | )
50 |
51 | VecEnvWrapper.__init__(self, venv)
52 | self.episode_returns = None
53 | self.episode_lengths = None
54 | self.episode_count = 0
55 | self.t_start = time.time()
56 |
57 | env_id = None
58 | if hasattr(venv, "spec") and venv.spec is not None:
59 | env_id = venv.spec.id
60 |
61 | if filename:
62 | self.results_writer = ResultsWriter(
63 | filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords
64 | )
65 | else:
66 | self.results_writer = None
67 | self.info_keywords = info_keywords
68 |
69 | def reset(self) -> VecEnvObs:
70 | obs = self.venv.reset()
71 | self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
72 | self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
73 | return obs
74 |
75 | def step_wait(self) -> VecEnvStepReturn:
76 | obs, rewards, dones, infos = self.venv.step_wait()
77 | self.episode_returns += rewards
78 | self.episode_lengths += 1
79 | new_infos = list(infos[:])
80 | for i in range(len(dones)):
81 | if dones[i]:
82 | info = infos[i].copy()
83 | episode_return = self.episode_returns[i]
84 | episode_length = self.episode_lengths[i]
85 | episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
86 | info["episode"] = episode_info
87 | self.episode_count += 1
88 | self.episode_returns[i] = 0
89 | self.episode_lengths[i] = 0
90 | if self.results_writer:
91 | self.results_writer.write_row(episode_info)
92 | new_infos[i] = info
93 | return obs, rewards, dones, new_infos
94 |
95 | def close(self) -> None:
96 | if self.results_writer:
97 | self.results_writer.close()
98 | return self.venv.close()
99 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/common/vec_env/vec_transpose.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Dict, Union
3 |
4 | import numpy as np
5 | from gym import spaces
6 |
7 | from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
8 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
9 |
10 |
11 | class VecTransposeImage(VecEnvWrapper):
12 | """
13 | Re-order channels, from HxWxC to CxHxW.
14 | It is required for PyTorch convolution layers.
15 |
16 | :param venv:
17 | """
18 |
19 | def __init__(self, venv: VecEnv):
20 | assert is_image_space(venv.observation_space) or isinstance(
21 | venv.observation_space, spaces.dict.Dict
22 | ), "The observation space must be an image or dictionary observation space"
23 |
24 | if isinstance(venv.observation_space, spaces.dict.Dict):
25 | self.image_space_keys = []
26 | observation_space = deepcopy(venv.observation_space)
27 | for key, space in observation_space.spaces.items():
28 | if is_image_space(space):
29 | # Keep track of which keys should be transposed later
30 | self.image_space_keys.append(key)
31 | observation_space.spaces[key] = self.transpose_space(space, key)
32 | else:
33 | observation_space = self.transpose_space(venv.observation_space)
34 | super(VecTransposeImage, self).__init__(venv, observation_space=observation_space)
35 |
36 | @staticmethod
37 | def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
38 | """
39 | Transpose an observation space (re-order channels).
40 |
41 | :param observation_space:
42 | :param key: In case of dictionary space, the key of the observation space.
43 | :return:
44 | """
45 | # Sanity checks
46 | assert is_image_space(observation_space), "The observation space must be an image"
47 | assert not is_image_space_channels_first(
48 | observation_space
49 | ), f"The observation space {key} must follow the channel last convention"
50 | height, width, channels = observation_space.shape
51 | new_shape = (channels, height, width)
52 | return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype)
53 |
54 | @staticmethod
55 | def transpose_image(image: np.ndarray) -> np.ndarray:
56 | """
57 | Transpose an image or batch of images (re-order channels).
58 |
59 | :param image:
60 | :return:
61 | """
62 | if len(image.shape) == 3:
63 | return np.transpose(image, (2, 0, 1))
64 | return np.transpose(image, (0, 3, 1, 2))
65 |
66 | def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
67 | """
68 | Transpose (if needed) and return new observations.
69 |
70 | :param observations:
71 | :return: Transposed observations
72 | """
73 | if isinstance(observations, dict):
74 | # Avoid modifying the original object in place
75 | observations = deepcopy(observations)
76 | for k in self.image_space_keys:
77 | observations[k] = self.transpose_image(observations[k])
78 | else:
79 | observations = self.transpose_image(observations)
80 | return observations
81 |
82 | def step_wait(self) -> VecEnvStepReturn:
83 | observations, rewards, dones, infos = self.venv.step_wait()
84 |
85 | # Transpose the terminal observations
86 | for idx, done in enumerate(dones):
87 | if not done:
88 | continue
89 | if "terminal_observation" in infos[idx]:
90 | infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"])
91 |
92 | return self.transpose_observations(observations), rewards, dones, infos
93 |
94 | def reset(self) -> Union[np.ndarray, Dict]:
95 | """
96 | Reset all environments
97 | """
98 | return self.transpose_observations(self.venv.reset())
99 |
100 | def close(self) -> None:
101 | self.venv.close()
102 |
--------------------------------------------------------------------------------
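The channel re-ordering itself can be seen with the static helper alone, no environment needed:

```python
import numpy as np

from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage

image = np.zeros((64, 48, 3), dtype=np.uint8)     # H x W x C
batch = np.zeros((8, 64, 48, 3), dtype=np.uint8)  # N x H x W x C
print(VecTransposeImage.transpose_image(image).shape)  # (3, 64, 48)
print(VecTransposeImage.transpose_image(batch).shape)  # (8, 3, 64, 48)
```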
/thirdparty/stable_baselines3/common/vec_env/vec_video_recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Callable
3 |
4 | from gym.wrappers.monitoring import video_recorder
5 |
6 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
7 | from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
8 | from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
9 |
10 |
11 | class VecVideoRecorder(VecEnvWrapper):
12 | """
13 | Wraps a VecEnv or VecEnvWrapper object to record rendered images as an mp4 video.
14 | It requires ffmpeg or avconv to be installed on the machine.
15 |
16 | :param venv:
17 | :param video_folder: Where to save videos
18 | :param record_video_trigger: Function that defines when to start recording.
19 | The function takes the current step number,
20 | and returns whether we should start recording or not.
21 | :param video_length: Length of recorded videos
22 | :param name_prefix: Prefix to the video name
23 | """
24 |
25 | def __init__(
26 | self,
27 | venv: VecEnv,
28 | video_folder: str,
29 | record_video_trigger: Callable[[int], bool],
30 | video_length: int = 200,
31 | name_prefix: str = "rl-video",
32 | ):
33 |
34 | VecEnvWrapper.__init__(self, venv)
35 |
36 | self.env = venv
37 | # Temp variable to retrieve metadata
38 | temp_env = venv
39 |
40 | # Unwrap to retrieve metadata dict
41 | # that will be used by gym recorder
42 | while isinstance(temp_env, VecEnvWrapper):
43 | temp_env = temp_env.venv
44 |
45 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
46 | metadata = temp_env.get_attr("metadata")[0]
47 | else:
48 | metadata = temp_env.metadata
49 |
50 | self.env.metadata = metadata
51 |
52 | self.record_video_trigger = record_video_trigger
53 | self.video_recorder = None
54 |
55 | self.video_folder = os.path.abspath(video_folder)
56 | # Create output folder if needed
57 | os.makedirs(self.video_folder, exist_ok=True)
58 |
59 | self.name_prefix = name_prefix
60 | self.step_id = 0
61 | self.video_length = video_length
62 |
63 | self.recording = False
64 | self.recorded_frames = 0
65 |
66 | def reset(self) -> VecEnvObs:
67 | obs = self.venv.reset()
68 | self.start_video_recorder()
69 | return obs
70 |
71 | def start_video_recorder(self) -> None:
72 | self.close_video_recorder()
73 |
74 | video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
75 | base_path = os.path.join(self.video_folder, video_name)
76 | self.video_recorder = video_recorder.VideoRecorder(
77 | env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
78 | )
79 |
80 | self.video_recorder.capture_frame()
81 | self.recorded_frames = 1
82 | self.recording = True
83 |
84 | def _video_enabled(self) -> bool:
85 | return self.record_video_trigger(self.step_id)
86 |
87 | def step_wait(self) -> VecEnvStepReturn:
88 | obs, rews, dones, infos = self.venv.step_wait()
89 |
90 | self.step_id += 1
91 | if self.recording:
92 | self.video_recorder.capture_frame()
93 | self.recorded_frames += 1
94 | if self.recorded_frames > self.video_length:
95 | print(f"Saving video to {self.video_recorder.path}")
96 | self.close_video_recorder()
97 | elif self._video_enabled():
98 | self.start_video_recorder()
99 |
100 | return obs, rews, dones, infos
101 |
102 | def close_video_recorder(self) -> None:
103 | if self.recording:
104 | self.video_recorder.close()
105 | self.recording = False
106 | self.recorded_frames = 1
107 |
108 | def close(self) -> None:
109 | VecEnvWrapper.close(self)
110 | self.close_video_recorder()
111 |
112 | def __del__(self):
113 | self.close()
114 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/ddpg/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.ddpg.ddpg import DDPG
2 | from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
3 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/ddpg/ddpg.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Type, Union
2 |
3 | import torch as th
4 |
5 | from stable_baselines3.common.buffers import ReplayBuffer
6 | from stable_baselines3.common.noise import ActionNoise
7 | from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
8 | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
9 | from stable_baselines3.td3.policies import TD3Policy
10 | from stable_baselines3.td3.td3 import TD3
11 |
12 |
13 | class DDPG(TD3):
14 | """
15 | Deep Deterministic Policy Gradient (DDPG).
16 |
17 | Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf
18 | DDPG Paper: https://arxiv.org/abs/1509.02971
19 | Introduction to DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html
20 |
21 | Note: we treat DDPG as a special case of its successor TD3.
22 |
23 | :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
24 | :param env: The environment to learn from (if registered in Gym, can be str)
25 | :param learning_rate: learning rate for the Adam optimizer;
26 | the same learning rate will be used for all networks (Q-Values, Actor and Value function).
27 | It can be a function of the current progress remaining (from 1 to 0).
28 | :param buffer_size: size of the replay buffer
29 | :param learning_starts: how many steps of the model to collect transitions for before learning starts
30 | :param batch_size: Minibatch size for each gradient update
31 | :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
32 | :param gamma: the discount factor
33 | :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
34 | like ``(5, "step")`` or ``(2, "episode")``.
35 | :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``).
36 | Set to ``-1`` to do as many gradient steps as steps done in the environment
37 | during the rollout.
38 | :param action_noise: the action noise type (None by default); this can help
39 | with hard exploration problems. Cf. common.noise for the different action noise types.
40 | :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
41 | If ``None``, it will be automatically selected.
42 | :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
43 | :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
44 | at a cost of more complexity.
45 | See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
46 | :param create_eval_env: Whether to create a second environment that will be
47 | used for evaluating the agent periodically. (Only available when passing string for the environment)
48 | :param policy_kwargs: additional arguments to be passed to the policy on creation
49 | :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
50 | :param seed: Seed for the pseudo random generators
51 | :param device: Device (cpu, cuda, ...) on which the code should be run.
52 | Setting it to auto, the code will be run on the GPU if possible.
53 | :param _init_setup_model: Whether or not to build the network at the creation of the instance
54 | """
55 |
56 | def __init__(
57 | self,
58 | policy: Union[str, Type[TD3Policy]],
59 | env: Union[GymEnv, str],
60 | learning_rate: Union[float, Schedule] = 1e-3,
61 | buffer_size: int = 1_000_000, # 1e6
62 | learning_starts: int = 100,
63 | batch_size: int = 100,
64 | tau: float = 0.005,
65 | gamma: float = 0.99,
66 | train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
67 | gradient_steps: int = -1,
68 | action_noise: Optional[ActionNoise] = None,
69 | replay_buffer_class: Optional[ReplayBuffer] = None,
70 | replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
71 | optimize_memory_usage: bool = False,
72 | tensorboard_log: Optional[str] = None,
73 | create_eval_env: bool = False,
74 | policy_kwargs: Optional[Dict[str, Any]] = None,
75 | verbose: int = 0,
76 | seed: Optional[int] = None,
77 | device: Union[th.device, str] = "auto",
78 | _init_setup_model: bool = True,
79 | ):
80 |
81 | super(DDPG, self).__init__(
82 | policy=policy,
83 | env=env,
84 | learning_rate=learning_rate,
85 | buffer_size=buffer_size,
86 | learning_starts=learning_starts,
87 | batch_size=batch_size,
88 | tau=tau,
89 | gamma=gamma,
90 | train_freq=train_freq,
91 | gradient_steps=gradient_steps,
92 | action_noise=action_noise,
93 | replay_buffer_class=replay_buffer_class,
94 | replay_buffer_kwargs=replay_buffer_kwargs,
95 | policy_kwargs=policy_kwargs,
96 | tensorboard_log=tensorboard_log,
97 | verbose=verbose,
98 | device=device,
99 | create_eval_env=create_eval_env,
100 | seed=seed,
101 | optimize_memory_usage=optimize_memory_usage,
102 | # Remove all tricks from TD3 to obtain DDPG:
103 | # we still need to specify target_policy_noise > 0 to avoid errors
104 | policy_delay=1,
105 | target_noise_clip=0.0,
106 | target_policy_noise=0.1,
107 | _init_setup_model=False,
108 | )
109 |
110 | # Use only one critic
111 | if "n_critics" not in self.policy_kwargs:
112 | self.policy_kwargs["n_critics"] = 1
113 |
114 | if _init_setup_model:
115 | self._setup_model()
116 |
117 | def learn(
118 | self,
119 | total_timesteps: int,
120 | callback: MaybeCallback = None,
121 | log_interval: int = 4,
122 | eval_env: Optional[GymEnv] = None,
123 | eval_freq: int = -1,
124 | n_eval_episodes: int = 5,
125 | tb_log_name: str = "DDPG",
126 | eval_log_path: Optional[str] = None,
127 | reset_num_timesteps: bool = True,
128 | ) -> OffPolicyAlgorithm:
129 |
130 | return super(DDPG, self).learn(
131 | total_timesteps=total_timesteps,
132 | callback=callback,
133 | log_interval=log_interval,
134 | eval_env=eval_env,
135 | eval_freq=eval_freq,
136 | n_eval_episodes=n_eval_episodes,
137 | tb_log_name=tb_log_name,
138 | eval_log_path=eval_log_path,
139 | reset_num_timesteps=reset_num_timesteps,
140 | )
141 |
--------------------------------------------------------------------------------
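A standard DDPG quickstart for reference, assuming the upstream stable-baselines3 package and gym's ``Pendulum-v1`` (this snippet is not meant to run against the vendored copy, whose ``DummyVecEnv`` expects the custom CARLA attributes):

```python
import numpy as np

from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise

# Pendulum-v1 has a single continuous action; add Gaussian exploration noise.
n_actions = 1
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG("MlpPolicy", "Pendulum-v1", action_noise=action_noise, verbose=1)
model.learn(total_timesteps=1_000)
```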
/thirdparty/stable_baselines3/ddpg/policies.py:
--------------------------------------------------------------------------------
1 | # DDPG can be viewed as a special case of TD3
2 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy # noqa:F401
3 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/dqn/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.dqn.dqn import DQN
2 | from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
3 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/her/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
2 | from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
3 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/her/goal_selection_strategy.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class GoalSelectionStrategy(Enum):
5 | """
6 | The strategies for selecting new goals when
7 | creating artificial transitions.
8 | """
9 |
10 | # Select a goal that was achieved
11 | # after the current step, in the same episode
12 | FUTURE = 0
13 | # Select the goal that was achieved
14 | # at the end of the episode
15 | FINAL = 1
16 | # Select a goal that was achieved in the episode
17 | EPISODE = 2
18 |
19 |
20 | # For convenience
21 | # that way, we can use string to select a strategy
22 | KEY_TO_GOAL_STRATEGY = {
23 | "future": GoalSelectionStrategy.FUTURE,
24 | "final": GoalSelectionStrategy.FINAL,
25 | "episode": GoalSelectionStrategy.EPISODE,
26 | }
27 |
--------------------------------------------------------------------------------
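The string-to-enum lookup is what lets configs pass a goal selection strategy by name:

```python
from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy

assert KEY_TO_GOAL_STRATEGY["future"] is GoalSelectionStrategy.FUTURE
```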
/thirdparty/stable_baselines3/local_modifications.txt:
--------------------------------------------------------------------------------
1 | modified: stable_baselines3/common/evaluation.py
2 | stable_baselines3/common/vec_env/dummy_vec_env.py
3 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from stable_baselines3.ppo.ppo import PPO
3 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/ppo/policies.py:
--------------------------------------------------------------------------------
1 | # This file is here just to define MlpPolicy/CnnPolicy
2 | # that work for PPO
3 | from stable_baselines3.common.policies import (
4 | ActorCriticCnnPolicy,
5 | ActorCriticPolicy,
6 | MultiInputActorCriticPolicy,
7 | register_policy,
8 | )
9 |
10 | MlpPolicy = ActorCriticPolicy
11 | CnnPolicy = ActorCriticCnnPolicy
12 | MultiInputPolicy = MultiInputActorCriticPolicy
13 |
14 | register_policy("MlpPolicy", ActorCriticPolicy)
15 | register_policy("CnnPolicy", ActorCriticCnnPolicy)
16 | register_policy("MultiInputPolicy", MultiInputPolicy)
17 |
--------------------------------------------------------------------------------
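Because of ``register_policy``, the string alias and the class are interchangeable when building a model; a sketch assuming the upstream stable-baselines3 package and gym's ``CartPole-v1``:

```python
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# These two constructions are equivalent: "MlpPolicy" resolves to ActorCriticPolicy.
model_a = PPO("MlpPolicy", "CartPole-v1")
model_b = PPO(MlpPolicy, "CartPole-v1")
```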
/thirdparty/stable_baselines3/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobustFieldAutonomyLab/Stochastic_Road_Network/1cdcd41c7311560bf7a5df0d4d8bca829fe2b958/thirdparty/stable_baselines3/py.typed
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/sac/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from stable_baselines3.sac.sac import SAC
3 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/td3/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
2 | from stable_baselines3.td3.td3 import TD3
3 |
--------------------------------------------------------------------------------
/thirdparty/stable_baselines3/version.txt:
--------------------------------------------------------------------------------
1 | 1.3.1a1
2 |
--------------------------------------------------------------------------------